From ac59b1c435e2d13351499a3241f34aefc408c532 Mon Sep 17 00:00:00 2001 From: Estelle Poulin Date: Thu, 26 Oct 2023 22:58:33 -0400 Subject: [PATCH] Seein' it I2I --- stabilityai/__init__.py | 4 +++- stabilityai/client.py | 19 ++++++++++++------- stabilityai/exceptions.py | 1 + stabilityai/utils.py | 38 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 54 insertions(+), 8 deletions(-) create mode 100644 stabilityai/utils.py diff --git a/stabilityai/__init__.py b/stabilityai/__init__.py index 976498a..77a855a 100644 --- a/stabilityai/__init__.py +++ b/stabilityai/__init__.py @@ -1 +1,3 @@ -__version__ = "1.0.3" +from pkg_resources import parse_version + +__version__ = parse_version("1.0.4") diff --git a/stabilityai/client.py b/stabilityai/client.py index f0f8cfe..2ec5de5 100644 --- a/stabilityai/client.py +++ b/stabilityai/client.py @@ -48,6 +48,7 @@ from stabilityai.models import ( UpscaleImageHeight, UpscaleImageWidth, ) +from stabilityai.utils import serialize_for_form class AsyncStabilityClient: @@ -150,7 +151,7 @@ class AsyncStabilityClient: ) ) - assert text_prompts is not None + text_prompts = text_prompts or [TextPrompt(text=text_prompt)] return text_prompts @validate_arguments @@ -227,9 +228,9 @@ class AsyncStabilityClient: @validate_arguments(config={"arbitrary_types_allowed": True}) async def image_to_image( self, - text_prompts: TextPrompts, - text_prompt: SingleTextPrompt, init_image: InitImage, + text_prompts: Optional[TextPrompts] = None, + text_prompt: Optional[SingleTextPrompt] = None, *, init_image_mode: Optional[InitImageMode] = None, image_strength: InitImageStrength, @@ -263,12 +264,16 @@ class AsyncStabilityClient: extras=extras, ) + form = aiohttp.FormData() + form_data = serialize_for_form(request_body.dict()) + + for k, v in form_data.items(): + form.add_field(k, v) + try: res = await self.session.post( - f"/v1/generation/{engine_id}/text-to-image", - data=request_body.json( - exclude_none=True, exclude_defaults=True, exclude_unset=True - ), + f"/v1/generation/{engine_id}/image-to-image", + data=form, ) except ClientResponseError as e: raise figure_out_exception(e) from e diff --git a/stabilityai/exceptions.py b/stabilityai/exceptions.py index 9c99bb3..1e0a3b2 100644 --- a/stabilityai/exceptions.py +++ b/stabilityai/exceptions.py @@ -1,4 +1,5 @@ import textwrap + from aiohttp.client_exceptions import ClientResponseError diff --git a/stabilityai/utils.py b/stabilityai/utils.py new file mode 100644 index 0000000..7d9a114 --- /dev/null +++ b/stabilityai/utils.py @@ -0,0 +1,38 @@ +from numbers import Number +from typing import Mapping, Sequence + + +def serialize_for_form(data, path="", res={}): + print(data, path, res) + if isinstance(data, Mapping): + for key, val in data.items(): + if val is not None: + xxx = f"{path}[{key}]" if path else key + serialize_for_form(val, xxx, res) + elif isinstance(data, Sequence) and not isinstance(data, str): + for i, elem in enumerate(data): + serialize_for_form(elem, f"{path}[{i}]", res) + elif isinstance(data, Number): + res[path] = str(data) + else: + res[path] = data + + return res + + +if __name__ == "__main__": + dat = { + "text_prompts": [{"text": "Birthday party at a gothic church", "weight": None}], + "init_image": None, + "init_image_mode": None, + "image_strength": 0.3, + "step_schedule_start": None, + "step_schedule_end": None, + "cfg_scale": None, + "clip_guidance_preset": None, + "sampler": None, + "sample_extras": None, + } + s = serialize_for_form(dat) + + print(s)