wip
parent
f9859d1767
commit
ab492417d6
@ -0,0 +1,41 @@
|
||||
---
|
||||
|
||||
name: Publish Python distributions to PyPI and TestPyPI
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
validate:
|
||||
name: Run static analysis on the code
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Lint with flake8
|
||||
run: |
|
||||
pip install flake8
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
|
||||
# exit-zero treats all errors as warnings.
|
||||
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
|
||||
|
||||
build-and-publish:
|
||||
name: Build and publish Python distribution
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@main
|
||||
- name: Initialize Python 3.11
|
||||
uses: actions/setup-python@v1
|
||||
with:
|
||||
python-version: 3.11
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip build
|
||||
- name: Build binary wheel and a source tarball
|
||||
run: python -m build
|
||||
- name: Publish distribution to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@v1.8.5
|
||||
with:
|
||||
password: ${{ secrets.PIPY_PASSWORD }}
|
||||
|
@ -1,2 +1,71 @@
|
||||
# stabilityai
|
||||
Client library for the stability.ai REST API
|
||||
# Stability AI
|
||||
|
||||
An **UNOFFICIAL** client library for the stability.ai REST API.
|
||||
|
||||
## Motivation
|
||||
|
||||
The official `stability-sdk` is a based on gRPC and also really hard to use. Like look at this, this
|
||||
ignores setting up the SDK.
|
||||
|
||||
```python
|
||||
from stability_sdk import client
|
||||
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
|
||||
|
||||
answers = stability_api.generate(
|
||||
prompt="a rocket-ship launching from rolling greens with blue daisies",
|
||||
seed=892226758,
|
||||
steps=30,
|
||||
cfg_scale=8.0,
|
||||
width=512,
|
||||
height=512,
|
||||
sampler=generation.SAMPLER_K_DPMPP_2M
|
||||
)
|
||||
|
||||
for resp in answers:
|
||||
for artifact in resp.artifacts:
|
||||
if artifact.finish_reason == generation.FILTER:
|
||||
warnings.warn(
|
||||
"Your request activated the API's safety filters and could not be processed."
|
||||
"Please modify the prompt and try again.")
|
||||
if artifact.type == generation.ARTIFACT_IMAGE:
|
||||
global img
|
||||
img = Image.open(io.BytesIO(artifact.binary))
|
||||
img.save(str(artifact.seed)+ ".png")
|
||||
```
|
||||
|
||||
This for loop is *magic*. You must loop the results in exactly this way or the gRPC library won't
|
||||
work. It's about an unpythonic as a library can get.
|
||||
|
||||
## My Take
|
||||
|
||||
```python
|
||||
# Set the STABILITY_API_KEY environment variable.
|
||||
|
||||
from stabilityai.client import AsyncStabilityClient
|
||||
from stabilityai.models import Sampler
|
||||
|
||||
async def example():
|
||||
async with AsyncStabilityClient() as stability:
|
||||
results = await stability.text_to_image(
|
||||
text_prompt="a rocket-ship launching from rolling greens with blue daisies",
|
||||
# All these are optional and have sane defaults.
|
||||
seed=892226758,
|
||||
steps=30,
|
||||
cfg_scale=8.0,
|
||||
width=512,
|
||||
height=512,
|
||||
sampler=Sampler.K_DPMPP_2M,
|
||||
)
|
||||
|
||||
artifact = results.artifacts[0]
|
||||
|
||||
img = Image.open(artifact.file)
|
||||
img.save(artifact.file.name)
|
||||
```
|
||||
|
||||
## Additional Nicetieis
|
||||
|
||||
* Instead of manually checking `FINISH_REASON` an appropriate exception will automatically be
|
||||
raised.
|
||||
|
||||
* Full mypy/pyright support for type checking and autocomplete.
|
||||
|
@ -0,0 +1,50 @@
|
||||
# pyproject.toml
|
||||
# https://packaging.python.org/en/latest/specifications/declaring-project-metadata
|
||||
# https://pip.pypa.io/en/stable/reference/build-system/pyproject-toml
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools", "setuptools-scm", "build"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
|
||||
[project]
|
||||
name = "stabilityai"
|
||||
description = "*Unofficial* client for the Stability REST API"
|
||||
authors = [
|
||||
{name = "Estelle Poulin", email = "dev@inspiredby.es"},
|
||||
]
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.11"
|
||||
keywords = ["stabilityai", "bot"]
|
||||
license = {text = "GPLv3"}
|
||||
classifiers = [
|
||||
"Programming Language :: Python :: 3",
|
||||
]
|
||||
dynamic = ["version", "dependencies"]
|
||||
|
||||
|
||||
[project.urls]
|
||||
homepage = "https://github.com/estheruary/stabilityai"
|
||||
repository = "https://github.com/estheruary/stabilityai"
|
||||
changelog = "https://github.com/estheruary/stabilityai/-/blob/main/CHANGELOG.md"
|
||||
|
||||
|
||||
[tool.setuptools]
|
||||
packages = ["stabilityai"]
|
||||
|
||||
|
||||
[tool.setuptools.dynamic]
|
||||
version = {attr = "stabilityai.__version__"}
|
||||
dependencies = {file = ["requirements.txt"]}
|
||||
|
||||
|
||||
[tool.black]
|
||||
line-length = 100
|
||||
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
|
||||
|
||||
[tool.vulture]
|
||||
ignore_names = ["self", "cls"]
|
@ -0,0 +1,2 @@
|
||||
aiohttp==3.8.4
|
||||
pydantic [email]==1.10.2
|
@ -0,0 +1 @@
|
||||
__version__ = "1.0.1"
|
@ -0,0 +1,271 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
from pydantic import (
|
||||
validate_arguments,
|
||||
)
|
||||
from stabilityai.constants import (
|
||||
DEFAULT_GENERATION_ENGINE,
|
||||
DEFAULT_STABILITY_API_HOST,
|
||||
DEFAULT_STABILITY_CLIENT_ID,
|
||||
DEFAULT_STABILITY_CLIENT_VERSION,
|
||||
DEFAULT_UPSCALE_ENGINE,
|
||||
)
|
||||
from stabilityai.exceptions import YouNeedToUseAContextManager
|
||||
from stabilityai.models import (
|
||||
AccountResponseBody,
|
||||
BalanceResponseBody,
|
||||
CfgScale,
|
||||
ClipGuidancePreset,
|
||||
DiffuseImageHeight,
|
||||
DiffuseImageWidth,
|
||||
Engine,
|
||||
Extras,
|
||||
ImageToImageRequestBody,
|
||||
ImageToImageResponseBody,
|
||||
ImageToImageUpscaleRequestBody,
|
||||
ImageToImageUpscaleResponseBody,
|
||||
InitImage,
|
||||
InitImageMode,
|
||||
InitImageStrength,
|
||||
ListEnginesResponseBody,
|
||||
Sampler,
|
||||
Samples,
|
||||
Seed,
|
||||
SingleTextPrompt,
|
||||
Steps,
|
||||
StylePreset,
|
||||
TextPrompt,
|
||||
TextPrompts,
|
||||
TextToImageRequestBody,
|
||||
TextToImageResponseBody,
|
||||
UpscaleImageHeight,
|
||||
UpscaleImageWidth,
|
||||
)
|
||||
from stabilityai.utils import omit_none
|
||||
import textwrap
|
||||
|
||||
|
||||
class AsyncStabilityClient:
|
||||
def __init__(
|
||||
self,
|
||||
api_host: str = os.environ.get("STABILITY_API_HOST", DEFAULT_STABILITY_API_HOST),
|
||||
api_key: Optional[str] = os.environ.get("STABILITY_API_KEY"),
|
||||
client_id: str = os.environ.get("STABILITY_CLIENT_ID", DEFAULT_STABILITY_CLIENT_ID),
|
||||
client_version: str = os.environ.get(
|
||||
"STABILITY_CLIENT_VERSION", DEFAULT_STABILITY_CLIENT_VERSION
|
||||
),
|
||||
organization: Optional[str] = os.environ.get("STABILITY_CLIENT_ORGANIZATION"),
|
||||
) -> None:
|
||||
self.api_host = api_host
|
||||
self.api_key = api_key
|
||||
self.client_id = client_id
|
||||
self.client_version = client_version
|
||||
self.organization = organization
|
||||
|
||||
async def __aenter__(self):
|
||||
self.session = await aiohttp.ClientSession(
|
||||
base_url=self.api_host,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Stability-Client-ID": self.client_id,
|
||||
"Stability-Client-Version": self.client_version,
|
||||
**({"Organization": self.organization} if self.organization else {}),
|
||||
},
|
||||
raise_for_status=True,
|
||||
).__aenter__()
|
||||
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
return await self.session.__aexit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
async def get_engines(self) -> ListEnginesResponseBody:
|
||||
res = await self.session.get("/v1/engines/list")
|
||||
return await res.json()
|
||||
|
||||
async def get_account(self) -> AccountResponseBody:
|
||||
res = await self.session.get("/v1/user/account")
|
||||
return await res.json()
|
||||
|
||||
async def get_account_balance(self) -> BalanceResponseBody:
|
||||
res = await self.session.get("/v1/user/balance")
|
||||
return await res.json()
|
||||
|
||||
def _oops_no_session(self):
|
||||
if not self.session:
|
||||
raise YouNeedToUseAContextManager(
|
||||
textwrap.dedent(
|
||||
f"""\
|
||||
{self.__class__.__name__} keeps a aiohttp.ClientSession under
|
||||
the hood and needs to be closed when you're done with it. But
|
||||
since there isn't an async version of __del__ we have to use
|
||||
__aenter__/__aexit__ instead. Apologies.
|
||||
|
||||
Instead of
|
||||
|
||||
myclient = {self.__class__.__name__}()
|
||||
myclient.text_to_image(...)
|
||||
|
||||
Do this
|
||||
|
||||
async with {self.__class__.__name__} as myclient:
|
||||
myclient.text_to_image(...)
|
||||
|
||||
Note that it's `async with` and not `with`.
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
def _normalize_text_prompts(
|
||||
self,
|
||||
text_prompts: Optional[TextPrompts],
|
||||
text_prompt: Optional[SingleTextPrompt]
|
||||
):
|
||||
if not bool(text_prompt) ^ bool(text_prompts):
|
||||
raise RuntimeError(
|
||||
textwrap.dedent(
|
||||
f"""\
|
||||
You must provide one of text_prompt and text_prompts.
|
||||
|
||||
Do this
|
||||
|
||||
stability.text_to_image(
|
||||
text_prompt="A beautiful sunrise"
|
||||
)
|
||||
|
||||
Or this
|
||||
|
||||
from stabilityai.models import TextPrompt
|
||||
|
||||
stability.text_to_image(
|
||||
text_prompts=[
|
||||
TextPrompt(text="A beautiful sunrise", weight=1.0)
|
||||
],
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
if text_prompt:
|
||||
text_prompts = [TextPrompt(text=text_prompt)]
|
||||
|
||||
# After this moment text_prompts can't be None.
|
||||
assert text_prompts is not None
|
||||
return text_prompts
|
||||
|
||||
@validate_arguments
|
||||
async def text_to_image(
|
||||
self,
|
||||
text_prompts: Optional[TextPrompts] = None,
|
||||
text_prompt: Optional[SingleTextPrompt] = None,
|
||||
*,
|
||||
engine: Optional[Engine] = None,
|
||||
cfg_scale: Optional[CfgScale] = None,
|
||||
clip_guidance_preset: Optional[ClipGuidancePreset] = None,
|
||||
height: Optional[DiffuseImageHeight] = None,
|
||||
sampler: Optional[Sampler] = None,
|
||||
samples: Optional[Samples] = None,
|
||||
seed: Optional[Seed] = None,
|
||||
steps: Optional[Steps] = None,
|
||||
style_preset: Optional[StylePreset] = None,
|
||||
width: Optional[DiffuseImageWidth] = None,
|
||||
extras: Optional[Extras] = None,
|
||||
):
|
||||
self._oops_no_session()
|
||||
|
||||
text_prompts = self._normalize_text_prompts(text_prompts, text_prompt)
|
||||
engine_id = engine.id if engine else DEFAULT_GENERATION_ENGINE
|
||||
|
||||
request_body = TextToImageRequestBody(
|
||||
cfg_scale=cfg_scale,
|
||||
clip_guidance_preset=clip_guidance_preset,
|
||||
height=height,
|
||||
sampler=sampler,
|
||||
samples=samples,
|
||||
seed=seed,
|
||||
steps=steps,
|
||||
style_preset=style_preset,
|
||||
text_prompts=text_prompts,
|
||||
width=width,
|
||||
extras=extras,
|
||||
)
|
||||
|
||||
res = await self.session.post(
|
||||
f"/v1/generation/{engine_id}/text-to-image",
|
||||
data=json.dumps(omit_none(json.loads(request_body.json()))),
|
||||
)
|
||||
Sampler.K_DPMPP_2M
|
||||
|
||||
return TextToImageResponseBody.parse_obj(await res.json())
|
||||
|
||||
@validate_arguments
|
||||
async def image_to_image(
|
||||
self,
|
||||
text_prompts: TextPrompts,
|
||||
text_prompt: SingleTextPrompt,
|
||||
init_image: InitImage,
|
||||
*,
|
||||
init_image_mode: Optional[InitImageMode] = None,
|
||||
image_strength: InitImageStrength,
|
||||
engine: Optional[Engine] = None,
|
||||
cfg_scale: Optional[CfgScale] = None,
|
||||
clip_guidance_preset: Optional[ClipGuidancePreset] = None,
|
||||
sampler: Optional[Sampler] = None,
|
||||
samples: Optional[Samples] = None,
|
||||
seed: Optional[Seed] = None,
|
||||
steps: Optional[Steps] = None,
|
||||
style_preset: Optional[StylePreset] = None,
|
||||
extras: Optional[Extras] = None,
|
||||
):
|
||||
self._oops_no_session()
|
||||
|
||||
text_prompts = self._normalize_text_prompts(text_prompts, text_prompt)
|
||||
engine_id = engine.id if engine else DEFAULT_GENERATION_ENGINE
|
||||
|
||||
request_body = ImageToImageRequestBody(
|
||||
cfg_scale=cfg_scale,
|
||||
clip_guidance_preset=clip_guidance_preset,
|
||||
init_image=init_image,
|
||||
init_image_mode=init_image_mode,
|
||||
image_strength=image_strength,
|
||||
sampler=sampler,
|
||||
samples=samples,
|
||||
seed=seed,
|
||||
steps=steps,
|
||||
style_preset=style_preset,
|
||||
text_prompts=text_prompts,
|
||||
extras=extras,
|
||||
)
|
||||
|
||||
res = await self.session.post(
|
||||
f"/v1/generation/{engine_id}/text-to-image",
|
||||
data=json.dumps(omit_none(json.loads(request_body.json()))),
|
||||
)
|
||||
|
||||
return ImageToImageResponseBody.parse_obj(await res.json())
|
||||
|
||||
async def image_to_image_upscale(
|
||||
self,
|
||||
image: InitImage,
|
||||
*,
|
||||
engine: Optional[Engine] = None,
|
||||
width: Optional[UpscaleImageWidth] = None,
|
||||
height: Optional[UpscaleImageHeight] = None,
|
||||
):
|
||||
self._oops_no_session()
|
||||
|
||||
engine_id = engine.id if engine else DEFAULT_UPSCALE_ENGINE
|
||||
|
||||
request_body = ImageToImageUpscaleRequestBody(image=image, width=width, height=height)
|
||||
|
||||
res = await self.session.post(
|
||||
f"/v1/generation/{engine_id}/image-to-image/upscale",
|
||||
data=json.dumps(omit_none(json.loads(request_body.json()))),
|
||||
)
|
||||
|
||||
return ImageToImageUpscaleResponseBody.parse_obj(await res.json())
|
@ -0,0 +1,9 @@
|
||||
from typing import Final
|
||||
|
||||
|
||||
DEFAULT_STABILITY_API_HOST: Final[str] = "https://api.stability.ai"
|
||||
DEFAULT_STABILITY_CLIENT_ID: Final[str] = "Stability Python SDK"
|
||||
DEFAULT_STABILITY_CLIENT_VERSION: Final[str] = "1.0.0"
|
||||
|
||||
DEFAULT_GENERATION_ENGINE: Final[str] = "stable-diffusion-xl-beta-v2-2-2"
|
||||
DEFAULT_UPSCALE_ENGINE: Final[str] = "esrgan-v1-x2plus"
|
@ -0,0 +1,2 @@
|
||||
class YouNeedToUseAContextManager(Exception):
|
||||
pass
|
@ -0,0 +1,247 @@
|
||||
import base64
|
||||
import io
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, List, Optional
|
||||
|
||||
from pydantic import (
|
||||
AnyUrl,
|
||||
BaseModel,
|
||||
EmailStr,
|
||||
confloat,
|
||||
conint,
|
||||
constr,
|
||||
validator,
|
||||
)
|
||||
|
||||
class Type(StrEnum):
|
||||
AUDIO = "AUDIO"
|
||||
CLASSIFICATION = "CLASSIFICATION"
|
||||
PICTURE = "PICTURE"
|
||||
STORAGE = "STORAGE"
|
||||
TEXT = "TEXT"
|
||||
VIDEO = "VIDEO"
|
||||
|
||||
|
||||
class Engine(BaseModel):
|
||||
description: str
|
||||
id: str
|
||||
name: str
|
||||
type: Type
|
||||
|
||||
|
||||
class Error(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
message: str
|
||||
|
||||
|
||||
CfgScale = Annotated[float, confloat(ge=0.0, le=35.0)]
|
||||
|
||||
|
||||
class ClipGuidancePreset(StrEnum):
|
||||
FAST_BLUE = "FAST_BLUE"
|
||||
FAST_GREEN = "FAST_GREEN"
|
||||
NONE = "NONE"
|
||||
SIMPLE = "SIMPLE"
|
||||
SLOW = "SLOW"
|
||||
SLOWER = "SLOWER"
|
||||
SLOWEST = "SLOWEST"
|
||||
|
||||
|
||||
UpscaleImageHeight = Annotated[int, conint(ge=512)]
|
||||
|
||||
UpscaleImageWidth = Annotated[int, conint(ge=512)]
|
||||
|
||||
DiffuseImageHeight = Annotated[int, conint(ge=128, multiple_of=64)]
|
||||
|
||||
DiffuseImageWidth = Annotated[int, conint(ge=128, multiple_of=64)]
|
||||
|
||||
|
||||
class Sampler(StrEnum):
|
||||
DDIM = "DDIM"
|
||||
DDPM = "DDPM"
|
||||
K_DPMPP_2M = "K_DPMPP_2M"
|
||||
K_DPMPP_2S_ANCESTRAL = "K_DPMPP_2S_ANCESTRAL"
|
||||
K_DPM_2 = "K_DPM_2"
|
||||
K_DPM_2_ANCESTRAL = "K_DPM_2_ANCESTRAL"
|
||||
K_EULER = "K_EULER"
|
||||
K_EULER_ANCESTRAL = "K_EULER_ANCESTRAL"
|
||||
K_HEUN = "K_HEUN"
|
||||
K_LMS = "K_LMS"
|
||||
|
||||
|
||||
Samples = Annotated[int, conint(ge=1, le=10)]
|
||||
|
||||
Seed = Annotated[int, conint(ge=0, le=4294967295)]
|
||||
|
||||
Steps = Annotated[int, conint(ge=10, le=150)]
|
||||
|
||||
Extras = dict
|
||||
|
||||
|
||||
class StylePreset(StrEnum):
|
||||
enhance = "enhance"
|
||||
anime = "anime"
|
||||
photographic = "photographic"
|
||||
digital_art = "digital-art"
|
||||
comic_book = "comic-book"
|
||||
fantasy_art = "fantasy-art"
|
||||
line_art = "line-art"
|
||||
analog_film = "analog-film"
|
||||
neon_punk = "neon-punk"
|
||||
isometric = "isometric"
|
||||
low_poly = "low-poly"
|
||||
origami = "origami"
|
||||
modeling_compound = "modeling-compound"
|
||||
cinematic = "cinematic"
|
||||
field_3d_model = "3d-model"
|
||||
pixel_art = "pixel-art"
|
||||
tile_texture = "tile-texture"
|
||||
|
||||
|
||||
class TextPrompt(BaseModel):
|
||||
text: Annotated[str, constr(max_length=2000)]
|
||||
weight: Optional[float] = None
|
||||
|
||||
SingleTextPrompt = str
|
||||
|
||||
TextPrompts = List[TextPrompt]
|
||||
|
||||
InitImage = bytes
|
||||
|
||||
InitImageStrength = Annotated[float, confloat(ge=0.0, le=1.0)]
|
||||
|
||||
|
||||
class InitImageMode(StrEnum):
|
||||
IMAGE_STRENGTH = "IMAGE_STRENGTH"
|
||||
STEP_SCHEDULE = "STEP_SCHEDULE"
|
||||
|
||||
|
||||
StepScheduleStart = Annotated[float, confloat(ge=0.0, le=1.0)]
|
||||
|
||||
StepScheduleEnd = Annotated[float, confloat(ge=0.0, le=1.0)]
|
||||
|
||||
MaskImage = bytes
|
||||
|
||||
MaskSource = str
|
||||
|
||||
|
||||
class GenerationRequestOptionalParams(BaseModel):
|
||||
cfg_scale: Optional[CfgScale] = None
|
||||
clip_guidance_preset: Optional[ClipGuidancePreset] = None
|
||||
sampler: Optional[Sampler] = None
|
||||
samples: Optional[Samples] = None
|
||||
seed: Optional[Seed] = None
|
||||
steps: Optional[Steps] = None
|
||||
style_preset: Optional[StylePreset] = None
|
||||
extras: Optional[Extras] = None
|
||||
|
||||
|
||||
class LatentUpscalerUpscaleRequestBody(BaseModel):
|
||||
image: InitImage
|
||||
width: Optional[UpscaleImageWidth] = None
|
||||
height: Optional[UpscaleImageHeight] = None
|
||||
text_prompts: Optional[TextPrompts] = None
|
||||
seed: Optional[Seed] = None
|
||||
steps: Optional[Steps] = None
|
||||
cfg_scale: Optional[CfgScale] = None
|
||||
|
||||
|
||||
class ImageToImageRequestBody(BaseModel):
|
||||
text_prompts: TextPrompts
|
||||
init_image: InitImage
|
||||
init_image_mode: Optional[InitImageMode] = InitImageMode("IMAGE_STRENGTH")
|
||||
image_strength: Optional[InitImageStrength] = None
|
||||
step_schedule_start: Optional[StepScheduleStart] = None
|
||||
step_schedule_end: Optional[StepScheduleEnd] = None
|
||||
cfg_scale: Optional[CfgScale] = None
|
||||
clip_guidance_preset: Optional[ClipGuidancePreset] = None
|
||||
sampler: Optional[Sampler] = None
|
||||
samples: Optional[Samples] = None
|
||||
seed: Optional[Seed] = None
|
||||
steps: Optional[Steps] = None
|
||||
style_preset: Optional[StylePreset] = None
|
||||
extras: Optional[Extras] = None
|
||||
|
||||
|
||||
class MaskingRequestBody(BaseModel):
|
||||
init_image: InitImage
|
||||
mask_source: MaskSource
|
||||
mask_image: Optional[MaskImage] = None
|
||||
text_prompts: TextPrompts
|
||||
cfg_scale: Optional[CfgScale] = None
|
||||
clip_guidance_preset: Optional[ClipGuidancePreset] = None
|
||||
sampler: Optional[Sampler] = None
|
||||
samples: Optional[Samples] = None
|
||||
seed: Optional[Seed] = None
|
||||
steps: Optional[Steps] = None
|
||||
style_preset: Optional[StylePreset] = None
|
||||
extras: Optional[Extras] = None
|
||||
|
||||
|
||||
class TextToImageRequestBody(GenerationRequestOptionalParams):
|
||||
height: Optional[DiffuseImageHeight] = None
|
||||
width: Optional[DiffuseImageWidth] = None
|
||||
text_prompts: TextPrompts
|
||||
|
||||
|
||||
class ImageToImageUpscaleRequestBody(BaseModel):
|
||||
image: InitImage
|
||||
width: Optional[UpscaleImageWidth]
|
||||
height: Optional[UpscaleImageHeight]
|
||||
|
||||
@validator("width", always=True)
|
||||
def mutually_exclusive(cls, v, values):
|
||||
if values["height"] is not None and v:
|
||||
raise ValueError("You can only specify one of width and height.")
|
||||
return v
|
||||
|
||||
|
||||
class BalanceResponseBody(BaseModel):
|
||||
credits: float
|
||||
|
||||
|
||||
ListEnginesResponseBody = List[Engine]
|
||||
|
||||
|
||||
class FinishReason(StrEnum):
|
||||
SUCCESS = "SUCCESS"
|
||||
ERROR = "ERROR"
|
||||
CONTENT_FILTERED = "CONTENT_FILTERED"
|
||||
|
||||
|
||||
class Image(BaseModel):
|
||||
base64: Optional[str] = None
|
||||
finishReason: Optional[FinishReason] = None
|
||||
seed: Optional[float] = None
|
||||
|
||||
@property
|
||||
def file(self):
|
||||
assert self.base64 is not None
|
||||
return io.BytesIO(base64.b64decode(self.base64))
|
||||
|
||||
|
||||
class OrganizationMembership(BaseModel):
|
||||
id: str
|
||||
is_default: bool
|
||||
name: str
|
||||
role: str
|
||||
|
||||
|
||||
class AccountResponseBody(BaseModel):
|
||||
email: EmailStr
|
||||
id: str
|
||||
organizations: List[OrganizationMembership]
|
||||
profile_picture: Optional[AnyUrl] = None
|
||||
|
||||
|
||||
class TextToImageResponseBody(BaseModel):
|
||||
artifacts: List[Image]
|
||||
|
||||
|
||||
class ImageToImageResponseBody(BaseModel):
|
||||
artifacts: List[Image]
|
||||
|
||||
|
||||
class ImageToImageUpscaleResponseBody(BaseModel):
|
||||
artifacts: List[Image]
|
@ -0,0 +1,6 @@
|
||||
from typing import Dict, Mapping, TypeVar
|
||||
|
||||
T, U = TypeVar("T"), TypeVar("U")
|
||||
|
||||
def omit_none(m: Mapping[T, U]) -> Dict[T, U]:
|
||||
return {k: v for k, v in m.items() if v is not None}
|
Loading…
Reference in New Issue