diff --git a/stabilityai/__init__.py b/stabilityai/__init__.py index 976498a..77a855a 100644 --- a/stabilityai/__init__.py +++ b/stabilityai/__init__.py @@ -1 +1,3 @@ -__version__ = "1.0.3" +from pkg_resources import parse_version + +__version__ = parse_version("1.0.4") diff --git a/stabilityai/client.py b/stabilityai/client.py index 97fe164..2ec5de5 100644 --- a/stabilityai/client.py +++ b/stabilityai/client.py @@ -1,11 +1,11 @@ -import json import os +import textwrap from typing import Optional import aiohttp -from pydantic import ( - validate_arguments, -) +from aiohttp.client_exceptions import ClientResponseError +from pydantic import validate_arguments + from stabilityai.constants import ( DEFAULT_GENERATION_ENGINE, DEFAULT_STABILITY_API_HOST, @@ -13,7 +13,11 @@ from stabilityai.constants import ( DEFAULT_STABILITY_CLIENT_VERSION, DEFAULT_UPSCALE_ENGINE, ) -from stabilityai.exceptions import ThisFunctionRequiresAPrompt, YouNeedToUseAContextManager +from stabilityai.exceptions import ( + ThisFunctionRequiresAPrompt, + YouNeedToUseAContextManager, + figure_out_exception, +) from stabilityai.models import ( AccountResponseBody, BalanceResponseBody, @@ -44,8 +48,7 @@ from stabilityai.models import ( UpscaleImageHeight, UpscaleImageWidth, ) -from stabilityai.utils import omit_none -import textwrap +from stabilityai.utils import serialize_for_form class AsyncStabilityClient: @@ -69,7 +72,6 @@ class AsyncStabilityClient: self.session = await aiohttp.ClientSession( base_url=self.api_host, headers={ - "Content-Type": "application/json", "Accept": "application/json", "Authorization": f"Bearer {self.api_key}", "Stability-Client-ID": self.client_id, @@ -122,9 +124,7 @@ class AsyncStabilityClient: ) def _normalize_text_prompts( - self, - text_prompts: Optional[TextPrompts], - text_prompt: Optional[SingleTextPrompt] + self, text_prompts: Optional[TextPrompts], text_prompt: Optional[SingleTextPrompt] ): if not bool(text_prompt) ^ bool(text_prompts): raise ThisFunctionRequiresAPrompt( @@ -147,15 +147,11 @@ class AsyncStabilityClient: TextPrompt(text="A beautiful sunrise", weight=1.0) ], ) - """ + """ ) ) - if text_prompt: - text_prompts = [TextPrompt(text=text_prompt)] - - # After this moment text_prompts can't be None. - assert text_prompts is not None + text_prompts = text_prompts or [TextPrompt(text=text_prompt)] return text_prompts @validate_arguments @@ -195,11 +191,16 @@ class AsyncStabilityClient: extras=extras, ) - res = await self.session.post( - f"/v1/generation/{engine_id}/text-to-image", - data=json.dumps(omit_none(json.loads(request_body.json()))), - ) - Sampler.K_DPMPP_2M + try: + res = await self.session.post( + f"/v1/generation/{engine_id}/text-to-image", + headers={"Content-Type": "application/json"}, + data=request_body.json( + exclude_defaults=True, exclude_none=True, exclude_unset=True + ), + ) + except ClientResponseError as e: + raise figure_out_exception(e) from e return TextToImageResponseBody.parse_obj(await res.json()) @@ -224,12 +225,12 @@ class AsyncStabilityClient: height=dim, ) - @validate_arguments + @validate_arguments(config={"arbitrary_types_allowed": True}) async def image_to_image( self, - text_prompts: TextPrompts, - text_prompt: SingleTextPrompt, init_image: InitImage, + text_prompts: Optional[TextPrompts] = None, + text_prompt: Optional[SingleTextPrompt] = None, *, init_image_mode: Optional[InitImageMode] = None, image_strength: InitImageStrength, @@ -263,13 +264,23 @@ class AsyncStabilityClient: extras=extras, ) - res = await self.session.post( - f"/v1/generation/{engine_id}/text-to-image", - data=json.dumps(omit_none(json.loads(request_body.json()))), - ) + form = aiohttp.FormData() + form_data = serialize_for_form(request_body.dict()) + + for k, v in form_data.items(): + form.add_field(k, v) + + try: + res = await self.session.post( + f"/v1/generation/{engine_id}/image-to-image", + data=form, + ) + except ClientResponseError as e: + raise figure_out_exception(e) from e return ImageToImageResponseBody.parse_obj(await res.json()) + @validate_arguments(config={"arbitrary_types_allowed": True}) async def image_to_image_upscale( self, image: InitImage, @@ -284,9 +295,16 @@ class AsyncStabilityClient: request_body = ImageToImageUpscaleRequestBody(image=image, width=width, height=height) - res = await self.session.post( - f"/v1/generation/{engine_id}/image-to-image/upscale", - data=json.dumps(omit_none(json.loads(request_body.json()))), - ) + form = aiohttp.FormData() + form.add_field("width", "1024") + form.add_field("image", request_body.image) + + try: + res = await self.session.post( + f"/v1/generation/{engine_id}/image-to-image/upscale", + data=form, + ) + except ClientResponseError as e: + raise figure_out_exception(e) from e return ImageToImageUpscaleResponseBody.parse_obj(await res.json()) diff --git a/stabilityai/constants.py b/stabilityai/constants.py index 84963dc..8083cff 100644 --- a/stabilityai/constants.py +++ b/stabilityai/constants.py @@ -1,6 +1,5 @@ from typing import Final - DEFAULT_STABILITY_API_HOST: Final[str] = "https://api.stability.ai" DEFAULT_STABILITY_CLIENT_ID: Final[str] = "Stability Python SDK" DEFAULT_STABILITY_CLIENT_VERSION: Final[str] = "1.0.0" diff --git a/stabilityai/exceptions.py b/stabilityai/exceptions.py index 8a2458e..1e0a3b2 100644 --- a/stabilityai/exceptions.py +++ b/stabilityai/exceptions.py @@ -1,6 +1,125 @@ +import textwrap + +from aiohttp.client_exceptions import ClientResponseError + + class YouNeedToUseAContextManager(ValueError): pass class ThisFunctionRequiresAPrompt(ValueError): pass + + +class StabilityAiError(Exception): + retryable: bool = False + + +class YouAreHoldingItWrong(StabilityAiError): + pass + + +class ApiKeyIsMissingOrInvalid(StabilityAiError): + pass + + +class ImAfraidICantLetYouDoThat(StabilityAiError): + pass + + +class YouBrokeTheirServeres(StabilityAiError): + retryable: bool = True + + +class ResourceNotFound(StabilityAiError): + pass + + +class RateLimitOrServerError(StabilityAiError): + retryable: bool = True + + +def figure_out_exception(e: ClientResponseError) -> StabilityAiError: + if e.status == 400: + return YouAreHoldingItWrong( + textwrap.dedent( + """\ + Either you specificed some combination of options that their + server doesn't like or it's a bug in this library. If you + think it's the latter open up a bug report on + + https://github.com/estheruary/stabilityai + """ + ) + ) + elif e.status == 401: + return ApiKeyIsMissingOrInvalid( + textwrap.dedent( + """\ + This could be one of a few problems. If you don't have an API + key you should generate one at + + https://beta.dreamstudio.ai/account + + If you have an API key there are two ways to provide it to this + library. The first and most common is via the environment. You + to set STABILITY_API_KEY. + + The second method is to provide it to the client constructor + directly. + + async with AsyncStabilityClient(api_key="your-key") as stability: + ... + """ + ) + ) + elif e.status == 403: + return ImAfraidICantLetYouDoThat( + textwrap.dedent( + """\ + At the time of writing (Jun 2023) Stability doesn't have any + permissions on their API so if you get this it's likely + a bug in this library or a problem on their servers. If you + think it's the former report it at + + https://github.com/estheruary/stabilityai + """ + ) + ) + elif e.status == 429: + return RateLimitOrServerError( + textwrap.dedent( + """\ + Stability's rate limit is 150 req / 10 sec. If you get this + error on *every* request then it's a bug on their servers. + You need to generate a new API key and delete the old one. + """ + ) + ) + elif e.status == 404: + return ResourceNotFound( + textwrap.dedent( + """\ + This usually happens when you specify models that don't exist + *or* a combination of model/sampler that doesn't work together. + """ + ) + ) + elif e.status == 500: + return YouBrokeTheirServeres( + textwrap.dedent( + """\ + Impressive. This is not an error you can do anything about + except retry. Stability might be having an outage or a + problem on their servers. + """ + ) + ) + else: + return StabilityAiError( + textwrap.dedent( + """\ + Sorry, I don't know what's up with this error. + """ + ) + ) diff --git a/stabilityai/models.py b/stabilityai/models.py index a3c9ab4..5e90e3a 100644 --- a/stabilityai/models.py +++ b/stabilityai/models.py @@ -3,15 +3,8 @@ import io from enum import StrEnum from typing import Annotated, List, Optional -from pydantic import ( - AnyUrl, - BaseModel, - EmailStr, - confloat, - conint, - constr, - validator, -) +from pydantic import AnyUrl, BaseModel, EmailStr, confloat, conint, constr, validator + class Type(StrEnum): AUDIO = "AUDIO" @@ -103,11 +96,12 @@ class TextPrompt(BaseModel): text: Annotated[str, constr(max_length=2000)] weight: Optional[float] = None + SingleTextPrompt = str TextPrompts = List[TextPrompt] -InitImage = bytes +InitImage = io.IOBase InitImageStrength = Annotated[float, confloat(ge=0.0, le=1.0)] @@ -146,6 +140,9 @@ class LatentUpscalerUpscaleRequestBody(BaseModel): steps: Optional[Steps] = None cfg_scale: Optional[CfgScale] = None + class Config: + arbitrary_types_allowed = True + class ImageToImageRequestBody(BaseModel): text_prompts: TextPrompts @@ -163,6 +160,9 @@ class ImageToImageRequestBody(BaseModel): style_preset: Optional[StylePreset] = None extras: Optional[Extras] = None + class Config: + arbitrary_types_allowed = True + class MaskingRequestBody(BaseModel): init_image: InitImage @@ -178,6 +178,9 @@ class MaskingRequestBody(BaseModel): style_preset: Optional[StylePreset] = None extras: Optional[Extras] = None + class Config: + arbitrary_types_allowed = True + class TextToImageRequestBody(GenerationRequestOptionalParams): height: Optional[DiffuseImageHeight] = None @@ -192,10 +195,13 @@ class ImageToImageUpscaleRequestBody(BaseModel): @validator("width", always=True) def mutually_exclusive(cls, v, values): - if values["height"] is not None and v: + if values.get("height") is not None and v: raise ValueError("You can only specify one of width and height.") return v + class Config: + arbitrary_types_allowed = True + class BalanceResponseBody(BaseModel): credits: float diff --git a/stabilityai/utils.py b/stabilityai/utils.py index ae49f21..7d9a114 100644 --- a/stabilityai/utils.py +++ b/stabilityai/utils.py @@ -1,6 +1,38 @@ -from typing import Dict, Mapping, TypeVar +from numbers import Number +from typing import Mapping, Sequence -T, U = TypeVar("T"), TypeVar("U") -def omit_none(m: Mapping[T, U]) -> Dict[T, U]: - return {k: v for k, v in m.items() if v is not None} +def serialize_for_form(data, path="", res={}): + print(data, path, res) + if isinstance(data, Mapping): + for key, val in data.items(): + if val is not None: + xxx = f"{path}[{key}]" if path else key + serialize_for_form(val, xxx, res) + elif isinstance(data, Sequence) and not isinstance(data, str): + for i, elem in enumerate(data): + serialize_for_form(elem, f"{path}[{i}]", res) + elif isinstance(data, Number): + res[path] = str(data) + else: + res[path] = data + + return res + + +if __name__ == "__main__": + dat = { + "text_prompts": [{"text": "Birthday party at a gothic church", "weight": None}], + "init_image": None, + "init_image_mode": None, + "image_strength": 0.3, + "step_schedule_start": None, + "step_schedule_end": None, + "cfg_scale": None, + "clip_guidance_preset": None, + "sampler": None, + "sample_extras": None, + } + s = serialize_for_form(dat) + + print(s)