|
|
@ -1,11 +1,10 @@
|
|
|
|
import json
|
|
|
|
|
|
|
|
import os
|
|
|
|
import os
|
|
|
|
|
|
|
|
import textwrap
|
|
|
|
from typing import Optional
|
|
|
|
from typing import Optional
|
|
|
|
|
|
|
|
|
|
|
|
import aiohttp
|
|
|
|
import aiohttp
|
|
|
|
from pydantic import (
|
|
|
|
from pydantic import validate_arguments
|
|
|
|
validate_arguments,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
from stabilityai.constants import (
|
|
|
|
from stabilityai.constants import (
|
|
|
|
DEFAULT_GENERATION_ENGINE,
|
|
|
|
DEFAULT_GENERATION_ENGINE,
|
|
|
|
DEFAULT_STABILITY_API_HOST,
|
|
|
|
DEFAULT_STABILITY_API_HOST,
|
|
|
@ -13,7 +12,10 @@ from stabilityai.constants import (
|
|
|
|
DEFAULT_STABILITY_CLIENT_VERSION,
|
|
|
|
DEFAULT_STABILITY_CLIENT_VERSION,
|
|
|
|
DEFAULT_UPSCALE_ENGINE,
|
|
|
|
DEFAULT_UPSCALE_ENGINE,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
from stabilityai.exceptions import ThisFunctionRequiresAPrompt, YouNeedToUseAContextManager
|
|
|
|
from stabilityai.exceptions import (
|
|
|
|
|
|
|
|
ThisFunctionRequiresAPrompt,
|
|
|
|
|
|
|
|
YouNeedToUseAContextManager,
|
|
|
|
|
|
|
|
)
|
|
|
|
from stabilityai.models import (
|
|
|
|
from stabilityai.models import (
|
|
|
|
AccountResponseBody,
|
|
|
|
AccountResponseBody,
|
|
|
|
BalanceResponseBody,
|
|
|
|
BalanceResponseBody,
|
|
|
@ -44,8 +46,6 @@ from stabilityai.models import (
|
|
|
|
UpscaleImageHeight,
|
|
|
|
UpscaleImageHeight,
|
|
|
|
UpscaleImageWidth,
|
|
|
|
UpscaleImageWidth,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
from stabilityai.utils import omit_none
|
|
|
|
|
|
|
|
import textwrap
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AsyncStabilityClient:
|
|
|
|
class AsyncStabilityClient:
|
|
|
@ -69,7 +69,6 @@ class AsyncStabilityClient:
|
|
|
|
self.session = await aiohttp.ClientSession(
|
|
|
|
self.session = await aiohttp.ClientSession(
|
|
|
|
base_url=self.api_host,
|
|
|
|
base_url=self.api_host,
|
|
|
|
headers={
|
|
|
|
headers={
|
|
|
|
"Content-Type": "application/json",
|
|
|
|
|
|
|
|
"Accept": "application/json",
|
|
|
|
"Accept": "application/json",
|
|
|
|
"Authorization": f"Bearer {self.api_key}",
|
|
|
|
"Authorization": f"Bearer {self.api_key}",
|
|
|
|
"Stability-Client-ID": self.client_id,
|
|
|
|
"Stability-Client-ID": self.client_id,
|
|
|
@ -122,9 +121,7 @@ class AsyncStabilityClient:
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
def _normalize_text_prompts(
|
|
|
|
def _normalize_text_prompts(
|
|
|
|
self,
|
|
|
|
self, text_prompts: Optional[TextPrompts], text_prompt: Optional[SingleTextPrompt]
|
|
|
|
text_prompts: Optional[TextPrompts],
|
|
|
|
|
|
|
|
text_prompt: Optional[SingleTextPrompt]
|
|
|
|
|
|
|
|
):
|
|
|
|
):
|
|
|
|
if not bool(text_prompt) ^ bool(text_prompts):
|
|
|
|
if not bool(text_prompt) ^ bool(text_prompts):
|
|
|
|
raise ThisFunctionRequiresAPrompt(
|
|
|
|
raise ThisFunctionRequiresAPrompt(
|
|
|
@ -151,10 +148,6 @@ class AsyncStabilityClient:
|
|
|
|
)
|
|
|
|
)
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
if text_prompt:
|
|
|
|
|
|
|
|
text_prompts = [TextPrompt(text=text_prompt)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# After this moment text_prompts can't be None.
|
|
|
|
|
|
|
|
assert text_prompts is not None
|
|
|
|
assert text_prompts is not None
|
|
|
|
return text_prompts
|
|
|
|
return text_prompts
|
|
|
|
|
|
|
|
|
|
|
@ -197,7 +190,8 @@ class AsyncStabilityClient:
|
|
|
|
|
|
|
|
|
|
|
|
res = await self.session.post(
|
|
|
|
res = await self.session.post(
|
|
|
|
f"/v1/generation/{engine_id}/text-to-image",
|
|
|
|
f"/v1/generation/{engine_id}/text-to-image",
|
|
|
|
data=json.dumps(omit_none(json.loads(request_body.json()))),
|
|
|
|
headers={"Content-Type": "application/json"},
|
|
|
|
|
|
|
|
data=request_body.json(exclude_defaults=True, exclude_none=True, exclude_unset=True),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
Sampler.K_DPMPP_2M
|
|
|
|
Sampler.K_DPMPP_2M
|
|
|
|
|
|
|
|
|
|
|
@ -224,7 +218,7 @@ class AsyncStabilityClient:
|
|
|
|
height=dim,
|
|
|
|
height=dim,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
@validate_arguments
|
|
|
|
@validate_arguments(config={"arbitrary_types_allowed": True})
|
|
|
|
async def image_to_image(
|
|
|
|
async def image_to_image(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
text_prompts: TextPrompts,
|
|
|
|
text_prompts: TextPrompts,
|
|
|
@ -265,11 +259,12 @@ class AsyncStabilityClient:
|
|
|
|
|
|
|
|
|
|
|
|
res = await self.session.post(
|
|
|
|
res = await self.session.post(
|
|
|
|
f"/v1/generation/{engine_id}/text-to-image",
|
|
|
|
f"/v1/generation/{engine_id}/text-to-image",
|
|
|
|
data=json.dumps(omit_none(json.loads(request_body.json()))),
|
|
|
|
data=request_body.json(exclude_none=True, exclude_defaults=True, exclude_unset=True),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
return ImageToImageResponseBody.parse_obj(await res.json())
|
|
|
|
return ImageToImageResponseBody.parse_obj(await res.json())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@validate_arguments(config={"arbitrary_types_allowed": True})
|
|
|
|
async def image_to_image_upscale(
|
|
|
|
async def image_to_image_upscale(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
image: InitImage,
|
|
|
|
image: InitImage,
|
|
|
@ -284,9 +279,13 @@ class AsyncStabilityClient:
|
|
|
|
|
|
|
|
|
|
|
|
request_body = ImageToImageUpscaleRequestBody(image=image, width=width, height=height)
|
|
|
|
request_body = ImageToImageUpscaleRequestBody(image=image, width=width, height=height)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
form = aiohttp.FormData()
|
|
|
|
|
|
|
|
form.add_field("width", "1024")
|
|
|
|
|
|
|
|
form.add_field("image", request_body.image)
|
|
|
|
|
|
|
|
|
|
|
|
res = await self.session.post(
|
|
|
|
res = await self.session.post(
|
|
|
|
f"/v1/generation/{engine_id}/image-to-image/upscale",
|
|
|
|
f"/v1/generation/{engine_id}/image-to-image/upscale",
|
|
|
|
data=json.dumps(omit_none(json.loads(request_body.json()))),
|
|
|
|
data=form,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
return ImageToImageUpscaleResponseBody.parse_obj(await res.json())
|
|
|
|
return ImageToImageUpscaleResponseBody.parse_obj(await res.json())
|
|
|
|