diff --git a/stabilityai/client.py b/stabilityai/client.py index 8fd1282..f0f8cfe 100644 --- a/stabilityai/client.py +++ b/stabilityai/client.py @@ -3,6 +3,7 @@ import textwrap from typing import Optional import aiohttp +from aiohttp.client_exceptions import ClientResponseError from pydantic import validate_arguments from stabilityai.constants import ( @@ -15,6 +16,7 @@ from stabilityai.constants import ( from stabilityai.exceptions import ( ThisFunctionRequiresAPrompt, YouNeedToUseAContextManager, + figure_out_exception, ) from stabilityai.models import ( AccountResponseBody, @@ -188,12 +190,16 @@ class AsyncStabilityClient: extras=extras, ) - res = await self.session.post( - f"/v1/generation/{engine_id}/text-to-image", - headers={"Content-Type": "application/json"}, - data=request_body.json(exclude_defaults=True, exclude_none=True, exclude_unset=True), - ) - Sampler.K_DPMPP_2M + try: + res = await self.session.post( + f"/v1/generation/{engine_id}/text-to-image", + headers={"Content-Type": "application/json"}, + data=request_body.json( + exclude_defaults=True, exclude_none=True, exclude_unset=True + ), + ) + except ClientResponseError as e: + raise figure_out_exception(e) from e return TextToImageResponseBody.parse_obj(await res.json()) @@ -257,10 +263,15 @@ class AsyncStabilityClient: extras=extras, ) - res = await self.session.post( - f"/v1/generation/{engine_id}/text-to-image", - data=request_body.json(exclude_none=True, exclude_defaults=True, exclude_unset=True), - ) + try: + res = await self.session.post( + f"/v1/generation/{engine_id}/text-to-image", + data=request_body.json( + exclude_none=True, exclude_defaults=True, exclude_unset=True + ), + ) + except ClientResponseError as e: + raise figure_out_exception(e) from e return ImageToImageResponseBody.parse_obj(await res.json()) @@ -283,9 +294,12 @@ class AsyncStabilityClient: form.add_field("width", "1024") form.add_field("image", request_body.image) - res = await self.session.post( - f"/v1/generation/{engine_id}/image-to-image/upscale", - data=form, - ) + try: + res = await self.session.post( + f"/v1/generation/{engine_id}/image-to-image/upscale", + data=form, + ) + except ClientResponseError as e: + raise figure_out_exception(e) from e return ImageToImageUpscaleResponseBody.parse_obj(await res.json()) diff --git a/stabilityai/exceptions.py b/stabilityai/exceptions.py index 8a2458e..9c99bb3 100644 --- a/stabilityai/exceptions.py +++ b/stabilityai/exceptions.py @@ -1,6 +1,124 @@ +import textwrap +from aiohttp.client_exceptions import ClientResponseError + + class YouNeedToUseAContextManager(ValueError): pass class ThisFunctionRequiresAPrompt(ValueError): pass + + +class StabilityAiError(Exception): + retryable: bool = False + + +class YouAreHoldingItWrong(StabilityAiError): + pass + + +class ApiKeyIsMissingOrInvalid(StabilityAiError): + pass + + +class ImAfraidICantLetYouDoThat(StabilityAiError): + pass + + +class YouBrokeTheirServeres(StabilityAiError): + retryable: bool = True + + +class ResourceNotFound(StabilityAiError): + pass + + +class RateLimitOrServerError(StabilityAiError): + retryable: bool = True + + +def figure_out_exception(e: ClientResponseError) -> StabilityAiError: + if e.status == 400: + return YouAreHoldingItWrong( + textwrap.dedent( + """\ + Either you specificed some combination of options that their + server doesn't like or it's a bug in this library. If you + think it's the latter open up a bug report on + + https://github.com/estheruary/stabilityai + """ + ) + ) + elif e.status == 401: + return ApiKeyIsMissingOrInvalid( + textwrap.dedent( + """\ + This could be one of a few problems. If you don't have an API + key you should generate one at + + https://beta.dreamstudio.ai/account + + If you have an API key there are two ways to provide it to this + library. The first and most common is via the environment. You + to set STABILITY_API_KEY. + + The second method is to provide it to the client constructor + directly. + + async with AsyncStabilityClient(api_key="your-key") as stability: + ... + """ + ) + ) + elif e.status == 403: + return ImAfraidICantLetYouDoThat( + textwrap.dedent( + """\ + At the time of writing (Jun 2023) Stability doesn't have any + permissions on their API so if you get this it's likely + a bug in this library or a problem on their servers. If you + think it's the former report it at + + https://github.com/estheruary/stabilityai + """ + ) + ) + elif e.status == 429: + return RateLimitOrServerError( + textwrap.dedent( + """\ + Stability's rate limit is 150 req / 10 sec. If you get this + error on *every* request then it's a bug on their servers. + You need to generate a new API key and delete the old one. + """ + ) + ) + elif e.status == 404: + return ResourceNotFound( + textwrap.dedent( + """\ + This usually happens when you specify models that don't exist + *or* a combination of model/sampler that doesn't work together. + """ + ) + ) + elif e.status == 500: + return YouBrokeTheirServeres( + textwrap.dedent( + """\ + Impressive. This is not an error you can do anything about + except retry. Stability might be having an outage or a + problem on their servers. + """ + ) + ) + else: + return StabilityAiError( + textwrap.dedent( + """\ + Sorry, I don't know what's up with this error. + """ + ) + )