Upscaling works now

pull/1/head
Estelle Poulin 1 year ago
parent f284983368
commit ab7bdec755

@ -1,11 +1,10 @@
import json
import os import os
import textwrap
from typing import Optional from typing import Optional
import aiohttp import aiohttp
from pydantic import ( from pydantic import validate_arguments
validate_arguments,
)
from stabilityai.constants import ( from stabilityai.constants import (
DEFAULT_GENERATION_ENGINE, DEFAULT_GENERATION_ENGINE,
DEFAULT_STABILITY_API_HOST, DEFAULT_STABILITY_API_HOST,
@ -13,7 +12,10 @@ from stabilityai.constants import (
DEFAULT_STABILITY_CLIENT_VERSION, DEFAULT_STABILITY_CLIENT_VERSION,
DEFAULT_UPSCALE_ENGINE, DEFAULT_UPSCALE_ENGINE,
) )
from stabilityai.exceptions import ThisFunctionRequiresAPrompt, YouNeedToUseAContextManager from stabilityai.exceptions import (
ThisFunctionRequiresAPrompt,
YouNeedToUseAContextManager,
)
from stabilityai.models import ( from stabilityai.models import (
AccountResponseBody, AccountResponseBody,
BalanceResponseBody, BalanceResponseBody,
@ -44,8 +46,6 @@ from stabilityai.models import (
UpscaleImageHeight, UpscaleImageHeight,
UpscaleImageWidth, UpscaleImageWidth,
) )
from stabilityai.utils import omit_none
import textwrap
class AsyncStabilityClient: class AsyncStabilityClient:
@ -69,7 +69,6 @@ class AsyncStabilityClient:
self.session = await aiohttp.ClientSession( self.session = await aiohttp.ClientSession(
base_url=self.api_host, base_url=self.api_host,
headers={ headers={
"Content-Type": "application/json",
"Accept": "application/json", "Accept": "application/json",
"Authorization": f"Bearer {self.api_key}", "Authorization": f"Bearer {self.api_key}",
"Stability-Client-ID": self.client_id, "Stability-Client-ID": self.client_id,
@ -122,9 +121,7 @@ class AsyncStabilityClient:
) )
def _normalize_text_prompts( def _normalize_text_prompts(
self, self, text_prompts: Optional[TextPrompts], text_prompt: Optional[SingleTextPrompt]
text_prompts: Optional[TextPrompts],
text_prompt: Optional[SingleTextPrompt]
): ):
if not bool(text_prompt) ^ bool(text_prompts): if not bool(text_prompt) ^ bool(text_prompts):
raise ThisFunctionRequiresAPrompt( raise ThisFunctionRequiresAPrompt(
@ -151,10 +148,6 @@ class AsyncStabilityClient:
) )
) )
if text_prompt:
text_prompts = [TextPrompt(text=text_prompt)]
# After this moment text_prompts can't be None.
assert text_prompts is not None assert text_prompts is not None
return text_prompts return text_prompts
@ -197,7 +190,8 @@ class AsyncStabilityClient:
res = await self.session.post( res = await self.session.post(
f"/v1/generation/{engine_id}/text-to-image", f"/v1/generation/{engine_id}/text-to-image",
data=json.dumps(omit_none(json.loads(request_body.json()))), headers={"Content-Type": "application/json"},
data=request_body.json(exclude_defaults=True, exclude_none=True, exclude_unset=True),
) )
Sampler.K_DPMPP_2M Sampler.K_DPMPP_2M
@ -224,7 +218,7 @@ class AsyncStabilityClient:
height=dim, height=dim,
) )
@validate_arguments @validate_arguments(config={"arbitrary_types_allowed": True})
async def image_to_image( async def image_to_image(
self, self,
text_prompts: TextPrompts, text_prompts: TextPrompts,
@ -265,11 +259,12 @@ class AsyncStabilityClient:
res = await self.session.post( res = await self.session.post(
f"/v1/generation/{engine_id}/text-to-image", f"/v1/generation/{engine_id}/text-to-image",
data=json.dumps(omit_none(json.loads(request_body.json()))), data=request_body.json(exclude_none=True, exclude_defaults=True, exclude_unset=True),
) )
return ImageToImageResponseBody.parse_obj(await res.json()) return ImageToImageResponseBody.parse_obj(await res.json())
@validate_arguments(config={"arbitrary_types_allowed": True})
async def image_to_image_upscale( async def image_to_image_upscale(
self, self,
image: InitImage, image: InitImage,
@ -284,9 +279,13 @@ class AsyncStabilityClient:
request_body = ImageToImageUpscaleRequestBody(image=image, width=width, height=height) request_body = ImageToImageUpscaleRequestBody(image=image, width=width, height=height)
form = aiohttp.FormData()
form.add_field("width", "1024")
form.add_field("image", request_body.image)
res = await self.session.post( res = await self.session.post(
f"/v1/generation/{engine_id}/image-to-image/upscale", f"/v1/generation/{engine_id}/image-to-image/upscale",
data=json.dumps(omit_none(json.loads(request_body.json()))), data=form,
) )
return ImageToImageUpscaleResponseBody.parse_obj(await res.json()) return ImageToImageUpscaleResponseBody.parse_obj(await res.json())

@ -1,6 +1,5 @@
from typing import Final from typing import Final
DEFAULT_STABILITY_API_HOST: Final[str] = "https://api.stability.ai" DEFAULT_STABILITY_API_HOST: Final[str] = "https://api.stability.ai"
DEFAULT_STABILITY_CLIENT_ID: Final[str] = "Stability Python SDK" DEFAULT_STABILITY_CLIENT_ID: Final[str] = "Stability Python SDK"
DEFAULT_STABILITY_CLIENT_VERSION: Final[str] = "1.0.0" DEFAULT_STABILITY_CLIENT_VERSION: Final[str] = "1.0.0"

@ -3,15 +3,8 @@ import io
from enum import StrEnum from enum import StrEnum
from typing import Annotated, List, Optional from typing import Annotated, List, Optional
from pydantic import ( from pydantic import AnyUrl, BaseModel, EmailStr, confloat, conint, constr, validator
AnyUrl,
BaseModel,
EmailStr,
confloat,
conint,
constr,
validator,
)
class Type(StrEnum): class Type(StrEnum):
AUDIO = "AUDIO" AUDIO = "AUDIO"
@ -103,11 +96,12 @@ class TextPrompt(BaseModel):
text: Annotated[str, constr(max_length=2000)] text: Annotated[str, constr(max_length=2000)]
weight: Optional[float] = None weight: Optional[float] = None
SingleTextPrompt = str SingleTextPrompt = str
TextPrompts = List[TextPrompt] TextPrompts = List[TextPrompt]
InitImage = bytes InitImage = io.IOBase
InitImageStrength = Annotated[float, confloat(ge=0.0, le=1.0)] InitImageStrength = Annotated[float, confloat(ge=0.0, le=1.0)]
@ -146,6 +140,9 @@ class LatentUpscalerUpscaleRequestBody(BaseModel):
steps: Optional[Steps] = None steps: Optional[Steps] = None
cfg_scale: Optional[CfgScale] = None cfg_scale: Optional[CfgScale] = None
class Config:
arbitrary_types_allowed = True
class ImageToImageRequestBody(BaseModel): class ImageToImageRequestBody(BaseModel):
text_prompts: TextPrompts text_prompts: TextPrompts
@ -163,6 +160,9 @@ class ImageToImageRequestBody(BaseModel):
style_preset: Optional[StylePreset] = None style_preset: Optional[StylePreset] = None
extras: Optional[Extras] = None extras: Optional[Extras] = None
class Config:
arbitrary_types_allowed = True
class MaskingRequestBody(BaseModel): class MaskingRequestBody(BaseModel):
init_image: InitImage init_image: InitImage
@ -178,6 +178,9 @@ class MaskingRequestBody(BaseModel):
style_preset: Optional[StylePreset] = None style_preset: Optional[StylePreset] = None
extras: Optional[Extras] = None extras: Optional[Extras] = None
class Config:
arbitrary_types_allowed = True
class TextToImageRequestBody(GenerationRequestOptionalParams): class TextToImageRequestBody(GenerationRequestOptionalParams):
height: Optional[DiffuseImageHeight] = None height: Optional[DiffuseImageHeight] = None
@ -196,6 +199,9 @@ class ImageToImageUpscaleRequestBody(BaseModel):
raise ValueError("You can only specify one of width and height.") raise ValueError("You can only specify one of width and height.")
return v return v
class Config:
arbitrary_types_allowed = True
class BalanceResponseBody(BaseModel): class BalanceResponseBody(BaseModel):
credits: float credits: float

@ -1,6 +0,0 @@
from typing import Dict, Mapping, TypeVar
T, U = TypeVar("T"), TypeVar("U")
def omit_none(m: Mapping[T, U]) -> Dict[T, U]:
return {k: v for k, v in m.items() if v is not None}
Loading…
Cancel
Save