diff --git a/stabilityai/client.py b/stabilityai/client.py index 97fe164..8fd1282 100644 --- a/stabilityai/client.py +++ b/stabilityai/client.py @@ -1,11 +1,10 @@ -import json import os +import textwrap from typing import Optional import aiohttp -from pydantic import ( - validate_arguments, -) +from pydantic import validate_arguments + from stabilityai.constants import ( DEFAULT_GENERATION_ENGINE, DEFAULT_STABILITY_API_HOST, @@ -13,7 +12,10 @@ from stabilityai.constants import ( DEFAULT_STABILITY_CLIENT_VERSION, DEFAULT_UPSCALE_ENGINE, ) -from stabilityai.exceptions import ThisFunctionRequiresAPrompt, YouNeedToUseAContextManager +from stabilityai.exceptions import ( + ThisFunctionRequiresAPrompt, + YouNeedToUseAContextManager, +) from stabilityai.models import ( AccountResponseBody, BalanceResponseBody, @@ -44,8 +46,6 @@ from stabilityai.models import ( UpscaleImageHeight, UpscaleImageWidth, ) -from stabilityai.utils import omit_none -import textwrap class AsyncStabilityClient: @@ -69,7 +69,6 @@ class AsyncStabilityClient: self.session = await aiohttp.ClientSession( base_url=self.api_host, headers={ - "Content-Type": "application/json", "Accept": "application/json", "Authorization": f"Bearer {self.api_key}", "Stability-Client-ID": self.client_id, @@ -122,9 +121,7 @@ class AsyncStabilityClient: ) def _normalize_text_prompts( - self, - text_prompts: Optional[TextPrompts], - text_prompt: Optional[SingleTextPrompt] + self, text_prompts: Optional[TextPrompts], text_prompt: Optional[SingleTextPrompt] ): if not bool(text_prompt) ^ bool(text_prompts): raise ThisFunctionRequiresAPrompt( @@ -147,14 +144,10 @@ class AsyncStabilityClient: TextPrompt(text="A beautiful sunrise", weight=1.0) ], ) - """ + """ ) ) - if text_prompt: - text_prompts = [TextPrompt(text=text_prompt)] - - # After this moment text_prompts can't be None. assert text_prompts is not None return text_prompts @@ -197,7 +190,8 @@ class AsyncStabilityClient: res = await self.session.post( f"/v1/generation/{engine_id}/text-to-image", - data=json.dumps(omit_none(json.loads(request_body.json()))), + headers={"Content-Type": "application/json"}, + data=request_body.json(exclude_defaults=True, exclude_none=True, exclude_unset=True), ) Sampler.K_DPMPP_2M @@ -224,7 +218,7 @@ class AsyncStabilityClient: height=dim, ) - @validate_arguments + @validate_arguments(config={"arbitrary_types_allowed": True}) async def image_to_image( self, text_prompts: TextPrompts, @@ -265,11 +259,12 @@ class AsyncStabilityClient: res = await self.session.post( f"/v1/generation/{engine_id}/text-to-image", - data=json.dumps(omit_none(json.loads(request_body.json()))), + data=request_body.json(exclude_none=True, exclude_defaults=True, exclude_unset=True), ) return ImageToImageResponseBody.parse_obj(await res.json()) + @validate_arguments(config={"arbitrary_types_allowed": True}) async def image_to_image_upscale( self, image: InitImage, @@ -284,9 +279,13 @@ class AsyncStabilityClient: request_body = ImageToImageUpscaleRequestBody(image=image, width=width, height=height) + form = aiohttp.FormData() + form.add_field("width", "1024") + form.add_field("image", request_body.image) + res = await self.session.post( f"/v1/generation/{engine_id}/image-to-image/upscale", - data=json.dumps(omit_none(json.loads(request_body.json()))), + data=form, ) return ImageToImageUpscaleResponseBody.parse_obj(await res.json()) diff --git a/stabilityai/constants.py b/stabilityai/constants.py index 84963dc..8083cff 100644 --- a/stabilityai/constants.py +++ b/stabilityai/constants.py @@ -1,6 +1,5 @@ from typing import Final - DEFAULT_STABILITY_API_HOST: Final[str] = "https://api.stability.ai" DEFAULT_STABILITY_CLIENT_ID: Final[str] = "Stability Python SDK" DEFAULT_STABILITY_CLIENT_VERSION: Final[str] = "1.0.0" diff --git a/stabilityai/models.py b/stabilityai/models.py index 695a06e..5e90e3a 100644 --- a/stabilityai/models.py +++ b/stabilityai/models.py @@ -3,15 +3,8 @@ import io from enum import StrEnum from typing import Annotated, List, Optional -from pydantic import ( - AnyUrl, - BaseModel, - EmailStr, - confloat, - conint, - constr, - validator, -) +from pydantic import AnyUrl, BaseModel, EmailStr, confloat, conint, constr, validator + class Type(StrEnum): AUDIO = "AUDIO" @@ -103,11 +96,12 @@ class TextPrompt(BaseModel): text: Annotated[str, constr(max_length=2000)] weight: Optional[float] = None + SingleTextPrompt = str TextPrompts = List[TextPrompt] -InitImage = bytes +InitImage = io.IOBase InitImageStrength = Annotated[float, confloat(ge=0.0, le=1.0)] @@ -146,6 +140,9 @@ class LatentUpscalerUpscaleRequestBody(BaseModel): steps: Optional[Steps] = None cfg_scale: Optional[CfgScale] = None + class Config: + arbitrary_types_allowed = True + class ImageToImageRequestBody(BaseModel): text_prompts: TextPrompts @@ -163,6 +160,9 @@ class ImageToImageRequestBody(BaseModel): style_preset: Optional[StylePreset] = None extras: Optional[Extras] = None + class Config: + arbitrary_types_allowed = True + class MaskingRequestBody(BaseModel): init_image: InitImage @@ -178,6 +178,9 @@ class MaskingRequestBody(BaseModel): style_preset: Optional[StylePreset] = None extras: Optional[Extras] = None + class Config: + arbitrary_types_allowed = True + class TextToImageRequestBody(GenerationRequestOptionalParams): height: Optional[DiffuseImageHeight] = None @@ -196,6 +199,9 @@ class ImageToImageUpscaleRequestBody(BaseModel): raise ValueError("You can only specify one of width and height.") return v + class Config: + arbitrary_types_allowed = True + class BalanceResponseBody(BaseModel): credits: float diff --git a/stabilityai/utils.py b/stabilityai/utils.py deleted file mode 100644 index ae49f21..0000000 --- a/stabilityai/utils.py +++ /dev/null @@ -1,6 +0,0 @@ -from typing import Dict, Mapping, TypeVar - -T, U = TypeVar("T"), TypeVar("U") - -def omit_none(m: Mapping[T, U]) -> Dict[T, U]: - return {k: v for k, v in m.items() if v is not None}