Merge pull request #1 from estheruary/dev

Get Image2Image Working
main
Estelle Poulin 1 year ago committed by GitHub
commit a479e8c23c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1 +1,3 @@
__version__ = "1.0.3" from pkg_resources import parse_version
__version__ = parse_version("1.0.4")

@ -1,11 +1,11 @@
import json
import os import os
import textwrap
from typing import Optional from typing import Optional
import aiohttp import aiohttp
from pydantic import ( from aiohttp.client_exceptions import ClientResponseError
validate_arguments, from pydantic import validate_arguments
)
from stabilityai.constants import ( from stabilityai.constants import (
DEFAULT_GENERATION_ENGINE, DEFAULT_GENERATION_ENGINE,
DEFAULT_STABILITY_API_HOST, DEFAULT_STABILITY_API_HOST,
@ -13,7 +13,11 @@ from stabilityai.constants import (
DEFAULT_STABILITY_CLIENT_VERSION, DEFAULT_STABILITY_CLIENT_VERSION,
DEFAULT_UPSCALE_ENGINE, DEFAULT_UPSCALE_ENGINE,
) )
from stabilityai.exceptions import ThisFunctionRequiresAPrompt, YouNeedToUseAContextManager from stabilityai.exceptions import (
ThisFunctionRequiresAPrompt,
YouNeedToUseAContextManager,
figure_out_exception,
)
from stabilityai.models import ( from stabilityai.models import (
AccountResponseBody, AccountResponseBody,
BalanceResponseBody, BalanceResponseBody,
@ -44,8 +48,7 @@ from stabilityai.models import (
UpscaleImageHeight, UpscaleImageHeight,
UpscaleImageWidth, UpscaleImageWidth,
) )
from stabilityai.utils import omit_none from stabilityai.utils import serialize_for_form
import textwrap
class AsyncStabilityClient: class AsyncStabilityClient:
@ -69,7 +72,6 @@ class AsyncStabilityClient:
self.session = await aiohttp.ClientSession( self.session = await aiohttp.ClientSession(
base_url=self.api_host, base_url=self.api_host,
headers={ headers={
"Content-Type": "application/json",
"Accept": "application/json", "Accept": "application/json",
"Authorization": f"Bearer {self.api_key}", "Authorization": f"Bearer {self.api_key}",
"Stability-Client-ID": self.client_id, "Stability-Client-ID": self.client_id,
@ -122,9 +124,7 @@ class AsyncStabilityClient:
) )
def _normalize_text_prompts( def _normalize_text_prompts(
self, self, text_prompts: Optional[TextPrompts], text_prompt: Optional[SingleTextPrompt]
text_prompts: Optional[TextPrompts],
text_prompt: Optional[SingleTextPrompt]
): ):
if not bool(text_prompt) ^ bool(text_prompts): if not bool(text_prompt) ^ bool(text_prompts):
raise ThisFunctionRequiresAPrompt( raise ThisFunctionRequiresAPrompt(
@ -151,11 +151,7 @@ class AsyncStabilityClient:
) )
) )
if text_prompt: text_prompts = text_prompts or [TextPrompt(text=text_prompt)]
text_prompts = [TextPrompt(text=text_prompt)]
# After this moment text_prompts can't be None.
assert text_prompts is not None
return text_prompts return text_prompts
@validate_arguments @validate_arguments
@ -195,11 +191,16 @@ class AsyncStabilityClient:
extras=extras, extras=extras,
) )
try:
res = await self.session.post( res = await self.session.post(
f"/v1/generation/{engine_id}/text-to-image", f"/v1/generation/{engine_id}/text-to-image",
data=json.dumps(omit_none(json.loads(request_body.json()))), headers={"Content-Type": "application/json"},
data=request_body.json(
exclude_defaults=True, exclude_none=True, exclude_unset=True
),
) )
Sampler.K_DPMPP_2M except ClientResponseError as e:
raise figure_out_exception(e) from e
return TextToImageResponseBody.parse_obj(await res.json()) return TextToImageResponseBody.parse_obj(await res.json())
@ -224,12 +225,12 @@ class AsyncStabilityClient:
height=dim, height=dim,
) )
@validate_arguments @validate_arguments(config={"arbitrary_types_allowed": True})
async def image_to_image( async def image_to_image(
self, self,
text_prompts: TextPrompts,
text_prompt: SingleTextPrompt,
init_image: InitImage, init_image: InitImage,
text_prompts: Optional[TextPrompts] = None,
text_prompt: Optional[SingleTextPrompt] = None,
*, *,
init_image_mode: Optional[InitImageMode] = None, init_image_mode: Optional[InitImageMode] = None,
image_strength: InitImageStrength, image_strength: InitImageStrength,
@ -263,13 +264,23 @@ class AsyncStabilityClient:
extras=extras, extras=extras,
) )
form = aiohttp.FormData()
form_data = serialize_for_form(request_body.dict())
for k, v in form_data.items():
form.add_field(k, v)
try:
res = await self.session.post( res = await self.session.post(
f"/v1/generation/{engine_id}/text-to-image", f"/v1/generation/{engine_id}/image-to-image",
data=json.dumps(omit_none(json.loads(request_body.json()))), data=form,
) )
except ClientResponseError as e:
raise figure_out_exception(e) from e
return ImageToImageResponseBody.parse_obj(await res.json()) return ImageToImageResponseBody.parse_obj(await res.json())
@validate_arguments(config={"arbitrary_types_allowed": True})
async def image_to_image_upscale( async def image_to_image_upscale(
self, self,
image: InitImage, image: InitImage,
@ -284,9 +295,16 @@ class AsyncStabilityClient:
request_body = ImageToImageUpscaleRequestBody(image=image, width=width, height=height) request_body = ImageToImageUpscaleRequestBody(image=image, width=width, height=height)
form = aiohttp.FormData()
form.add_field("width", "1024")
form.add_field("image", request_body.image)
try:
res = await self.session.post( res = await self.session.post(
f"/v1/generation/{engine_id}/image-to-image/upscale", f"/v1/generation/{engine_id}/image-to-image/upscale",
data=json.dumps(omit_none(json.loads(request_body.json()))), data=form,
) )
except ClientResponseError as e:
raise figure_out_exception(e) from e
return ImageToImageUpscaleResponseBody.parse_obj(await res.json()) return ImageToImageUpscaleResponseBody.parse_obj(await res.json())

@ -1,6 +1,5 @@
from typing import Final from typing import Final
DEFAULT_STABILITY_API_HOST: Final[str] = "https://api.stability.ai" DEFAULT_STABILITY_API_HOST: Final[str] = "https://api.stability.ai"
DEFAULT_STABILITY_CLIENT_ID: Final[str] = "Stability Python SDK" DEFAULT_STABILITY_CLIENT_ID: Final[str] = "Stability Python SDK"
DEFAULT_STABILITY_CLIENT_VERSION: Final[str] = "1.0.0" DEFAULT_STABILITY_CLIENT_VERSION: Final[str] = "1.0.0"

@ -1,6 +1,125 @@
import textwrap
from aiohttp.client_exceptions import ClientResponseError
class YouNeedToUseAContextManager(ValueError): class YouNeedToUseAContextManager(ValueError):
pass pass
class ThisFunctionRequiresAPrompt(ValueError): class ThisFunctionRequiresAPrompt(ValueError):
pass pass
class StabilityAiError(Exception):
retryable: bool = False
class YouAreHoldingItWrong(StabilityAiError):
pass
class ApiKeyIsMissingOrInvalid(StabilityAiError):
pass
class ImAfraidICantLetYouDoThat(StabilityAiError):
pass
class YouBrokeTheirServeres(StabilityAiError):
retryable: bool = True
class ResourceNotFound(StabilityAiError):
pass
class RateLimitOrServerError(StabilityAiError):
retryable: bool = True
def figure_out_exception(e: ClientResponseError) -> StabilityAiError:
if e.status == 400:
return YouAreHoldingItWrong(
textwrap.dedent(
"""\
Either you specificed some combination of options that their
server doesn't like or it's a bug in this library. If you
think it's the latter open up a bug report on
https://github.com/estheruary/stabilityai
"""
)
)
elif e.status == 401:
return ApiKeyIsMissingOrInvalid(
textwrap.dedent(
"""\
This could be one of a few problems. If you don't have an API
key you should generate one at
https://beta.dreamstudio.ai/account
If you have an API key there are two ways to provide it to this
library. The first and most common is via the environment. You
to set STABILITY_API_KEY.
The second method is to provide it to the client constructor
directly.
async with AsyncStabilityClient(api_key="your-key") as stability:
...
"""
)
)
elif e.status == 403:
return ImAfraidICantLetYouDoThat(
textwrap.dedent(
"""\
At the time of writing (Jun 2023) Stability doesn't have any
permissions on their API so if you get this it's likely
a bug in this library or a problem on their servers. If you
think it's the former report it at
https://github.com/estheruary/stabilityai
"""
)
)
elif e.status == 429:
return RateLimitOrServerError(
textwrap.dedent(
"""\
Stability's rate limit is 150 req / 10 sec. If you get this
error on *every* request then it's a bug on their servers.
You need to generate a new API key and delete the old one.
"""
)
)
elif e.status == 404:
return ResourceNotFound(
textwrap.dedent(
"""\
This usually happens when you specify models that don't exist
*or* a combination of model/sampler that doesn't work together.
"""
)
)
elif e.status == 500:
return YouBrokeTheirServeres(
textwrap.dedent(
"""\
Impressive. This is not an error you can do anything about
except retry. Stability might be having an outage or a
problem on their servers.
"""
)
)
else:
return StabilityAiError(
textwrap.dedent(
"""\
Sorry, I don't know what's up with this error.
"""
)
)

@ -3,15 +3,8 @@ import io
from enum import StrEnum from enum import StrEnum
from typing import Annotated, List, Optional from typing import Annotated, List, Optional
from pydantic import ( from pydantic import AnyUrl, BaseModel, EmailStr, confloat, conint, constr, validator
AnyUrl,
BaseModel,
EmailStr,
confloat,
conint,
constr,
validator,
)
class Type(StrEnum): class Type(StrEnum):
AUDIO = "AUDIO" AUDIO = "AUDIO"
@ -103,11 +96,12 @@ class TextPrompt(BaseModel):
text: Annotated[str, constr(max_length=2000)] text: Annotated[str, constr(max_length=2000)]
weight: Optional[float] = None weight: Optional[float] = None
SingleTextPrompt = str SingleTextPrompt = str
TextPrompts = List[TextPrompt] TextPrompts = List[TextPrompt]
InitImage = bytes InitImage = io.IOBase
InitImageStrength = Annotated[float, confloat(ge=0.0, le=1.0)] InitImageStrength = Annotated[float, confloat(ge=0.0, le=1.0)]
@ -146,6 +140,9 @@ class LatentUpscalerUpscaleRequestBody(BaseModel):
steps: Optional[Steps] = None steps: Optional[Steps] = None
cfg_scale: Optional[CfgScale] = None cfg_scale: Optional[CfgScale] = None
class Config:
arbitrary_types_allowed = True
class ImageToImageRequestBody(BaseModel): class ImageToImageRequestBody(BaseModel):
text_prompts: TextPrompts text_prompts: TextPrompts
@ -163,6 +160,9 @@ class ImageToImageRequestBody(BaseModel):
style_preset: Optional[StylePreset] = None style_preset: Optional[StylePreset] = None
extras: Optional[Extras] = None extras: Optional[Extras] = None
class Config:
arbitrary_types_allowed = True
class MaskingRequestBody(BaseModel): class MaskingRequestBody(BaseModel):
init_image: InitImage init_image: InitImage
@ -178,6 +178,9 @@ class MaskingRequestBody(BaseModel):
style_preset: Optional[StylePreset] = None style_preset: Optional[StylePreset] = None
extras: Optional[Extras] = None extras: Optional[Extras] = None
class Config:
arbitrary_types_allowed = True
class TextToImageRequestBody(GenerationRequestOptionalParams): class TextToImageRequestBody(GenerationRequestOptionalParams):
height: Optional[DiffuseImageHeight] = None height: Optional[DiffuseImageHeight] = None
@ -192,10 +195,13 @@ class ImageToImageUpscaleRequestBody(BaseModel):
@validator("width", always=True) @validator("width", always=True)
def mutually_exclusive(cls, v, values): def mutually_exclusive(cls, v, values):
if values["height"] is not None and v: if values.get("height") is not None and v:
raise ValueError("You can only specify one of width and height.") raise ValueError("You can only specify one of width and height.")
return v return v
class Config:
arbitrary_types_allowed = True
class BalanceResponseBody(BaseModel): class BalanceResponseBody(BaseModel):
credits: float credits: float

@ -1,6 +1,38 @@
from typing import Dict, Mapping, TypeVar from numbers import Number
from typing import Mapping, Sequence
T, U = TypeVar("T"), TypeVar("U")
def omit_none(m: Mapping[T, U]) -> Dict[T, U]: def serialize_for_form(data, path="", res={}):
return {k: v for k, v in m.items() if v is not None} print(data, path, res)
if isinstance(data, Mapping):
for key, val in data.items():
if val is not None:
xxx = f"{path}[{key}]" if path else key
serialize_for_form(val, xxx, res)
elif isinstance(data, Sequence) and not isinstance(data, str):
for i, elem in enumerate(data):
serialize_for_form(elem, f"{path}[{i}]", res)
elif isinstance(data, Number):
res[path] = str(data)
else:
res[path] = data
return res
if __name__ == "__main__":
dat = {
"text_prompts": [{"text": "Birthday party at a gothic church", "weight": None}],
"init_image": None,
"init_image_mode": None,
"image_strength": 0.3,
"step_schedule_start": None,
"step_schedule_end": None,
"cfg_scale": None,
"clip_guidance_preset": None,
"sampler": None,
"sample_extras": None,
}
s = serialize_for_form(dat)
print(s)

Loading…
Cancel
Save