Seein' it I2I

pull/1/head
Estelle Poulin 1 year ago
parent 57cbe0b85f
commit ac59b1c435

@ -1 +1,3 @@
__version__ = "1.0.3"
from pkg_resources import parse_version
__version__ = parse_version("1.0.4")

@ -48,6 +48,7 @@ from stabilityai.models import (
UpscaleImageHeight,
UpscaleImageWidth,
)
from stabilityai.utils import serialize_for_form
class AsyncStabilityClient:
@ -150,7 +151,7 @@ class AsyncStabilityClient:
)
)
assert text_prompts is not None
text_prompts = text_prompts or [TextPrompt(text=text_prompt)]
return text_prompts
@validate_arguments
@ -227,9 +228,9 @@ class AsyncStabilityClient:
@validate_arguments(config={"arbitrary_types_allowed": True})
async def image_to_image(
self,
text_prompts: TextPrompts,
text_prompt: SingleTextPrompt,
init_image: InitImage,
text_prompts: Optional[TextPrompts] = None,
text_prompt: Optional[SingleTextPrompt] = None,
*,
init_image_mode: Optional[InitImageMode] = None,
image_strength: InitImageStrength,
@ -263,12 +264,16 @@ class AsyncStabilityClient:
extras=extras,
)
form = aiohttp.FormData()
form_data = serialize_for_form(request_body.dict())
for k, v in form_data.items():
form.add_field(k, v)
try:
res = await self.session.post(
f"/v1/generation/{engine_id}/text-to-image",
data=request_body.json(
exclude_none=True, exclude_defaults=True, exclude_unset=True
),
f"/v1/generation/{engine_id}/image-to-image",
data=form,
)
except ClientResponseError as e:
raise figure_out_exception(e) from e

@ -1,4 +1,5 @@
import textwrap
from aiohttp.client_exceptions import ClientResponseError

@ -0,0 +1,38 @@
from numbers import Number
from typing import Mapping, Sequence
def serialize_for_form(data, path="", res={}):
print(data, path, res)
if isinstance(data, Mapping):
for key, val in data.items():
if val is not None:
xxx = f"{path}[{key}]" if path else key
serialize_for_form(val, xxx, res)
elif isinstance(data, Sequence) and not isinstance(data, str):
for i, elem in enumerate(data):
serialize_for_form(elem, f"{path}[{i}]", res)
elif isinstance(data, Number):
res[path] = str(data)
else:
res[path] = data
return res
if __name__ == "__main__":
dat = {
"text_prompts": [{"text": "Birthday party at a gothic church", "weight": None}],
"init_image": None,
"init_image_mode": None,
"image_strength": 0.3,
"step_schedule_start": None,
"step_schedule_end": None,
"cfg_scale": None,
"clip_guidance_preset": None,
"sampler": None,
"sample_extras": None,
}
s = serialize_for_form(dat)
print(s)
Loading…
Cancel
Save