wip
							parent
							
								
									f9859d1767
								
							
						
					
					
						commit
						ab492417d6
					
				@ -0,0 +1,41 @@
 | 
			
		||||
---
 | 
			
		||||
 | 
			
		||||
name: Publish Python distributions to PyPI and TestPyPI
 | 
			
		||||
on:
 | 
			
		||||
  push:
 | 
			
		||||
    branches: [main]
 | 
			
		||||
  pull_request:
 | 
			
		||||
    branches: [main]
 | 
			
		||||
 | 
			
		||||
jobs:
 | 
			
		||||
  validate:
 | 
			
		||||
    name: Run static analysis on the code
 | 
			
		||||
    runs-on: ubuntu-22.04
 | 
			
		||||
    steps:
 | 
			
		||||
      - name: Lint with flake8
 | 
			
		||||
        run: |
 | 
			
		||||
          pip install flake8
 | 
			
		||||
          # stop the build if there are Python syntax errors or undefined names
 | 
			
		||||
          flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
 | 
			
		||||
          # exit-zero treats all errors as warnings.
 | 
			
		||||
          flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
 | 
			
		||||
 | 
			
		||||
  build-and-publish:
 | 
			
		||||
    name: Build and publish Python distribution
 | 
			
		||||
    runs-on: ubuntu-22.04
 | 
			
		||||
    steps:
 | 
			
		||||
      - uses: actions/checkout@main
 | 
			
		||||
      - name: Initialize Python 3.11
 | 
			
		||||
        uses: actions/setup-python@v1
 | 
			
		||||
        with:
 | 
			
		||||
          python-version: 3.11
 | 
			
		||||
      - name: Install dependencies
 | 
			
		||||
        run: |
 | 
			
		||||
          python -m pip install --upgrade pip build
 | 
			
		||||
      - name: Build binary wheel and a source tarball
 | 
			
		||||
        run: python -m build 
 | 
			
		||||
      - name: Publish distribution to PyPI
 | 
			
		||||
        uses: pypa/gh-action-pypi-publish@v1.8.5 
 | 
			
		||||
        with:
 | 
			
		||||
          password: ${{ secrets.PIPY_PASSWORD }}
 | 
			
		||||
 | 
			
		||||
@ -1,2 +1,71 @@
 | 
			
		||||
# stabilityai
 | 
			
		||||
Client library for the stability.ai REST API
 | 
			
		||||
# Stability AI
 | 
			
		||||
 | 
			
		||||
An **UNOFFICIAL** client library for the stability.ai REST API.
 | 
			
		||||
 | 
			
		||||
## Motivation
 | 
			
		||||
 | 
			
		||||
The official `stability-sdk` is a based on gRPC and also really hard to use. Like look at this, this
 | 
			
		||||
ignores setting up the SDK.
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
from stability_sdk import client
 | 
			
		||||
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
 | 
			
		||||
 | 
			
		||||
answers = stability_api.generate(
 | 
			
		||||
    prompt="a rocket-ship launching from rolling greens with blue daisies",
 | 
			
		||||
    seed=892226758,
 | 
			
		||||
    steps=30,
 | 
			
		||||
    cfg_scale=8.0,
 | 
			
		||||
    width=512,
 | 
			
		||||
    height=512,
 | 
			
		||||
    sampler=generation.SAMPLER_K_DPMPP_2M
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
for resp in answers:
 | 
			
		||||
    for artifact in resp.artifacts:
 | 
			
		||||
        if artifact.finish_reason == generation.FILTER:
 | 
			
		||||
            warnings.warn(
 | 
			
		||||
                "Your request activated the API's safety filters and could not be processed."
 | 
			
		||||
                "Please modify the prompt and try again.")
 | 
			
		||||
        if artifact.type == generation.ARTIFACT_IMAGE:
 | 
			
		||||
            global img
 | 
			
		||||
            img = Image.open(io.BytesIO(artifact.binary))
 | 
			
		||||
            img.save(str(artifact.seed)+ ".png")
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
This for loop is *magic*. You must loop the results in exactly this way or the gRPC library won't
 | 
			
		||||
work. It's about an unpythonic as a library can get.
 | 
			
		||||
 | 
			
		||||
## My Take
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
# Set the STABILITY_API_KEY environment variable.
 | 
			
		||||
 | 
			
		||||
from stabilityai.client import AsyncStabilityClient
 | 
			
		||||
from stabilityai.models import Sampler
 | 
			
		||||
 | 
			
		||||
async def example():
 | 
			
		||||
  async with AsyncStabilityClient() as stability:
 | 
			
		||||
    results = await stability.text_to_image(
 | 
			
		||||
        text_prompt="a rocket-ship launching from rolling greens with blue daisies",
 | 
			
		||||
        # All these are optional and have sane defaults.
 | 
			
		||||
        seed=892226758,
 | 
			
		||||
        steps=30,
 | 
			
		||||
        cfg_scale=8.0,
 | 
			
		||||
        width=512,
 | 
			
		||||
        height=512,
 | 
			
		||||
        sampler=Sampler.K_DPMPP_2M,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    artifact = results.artifacts[0]
 | 
			
		||||
 | 
			
		||||
    img = Image.open(artifact.file)
 | 
			
		||||
    img.save(artifact.file.name)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## Additional Nicetieis
 | 
			
		||||
 | 
			
		||||
* Instead of manually checking `FINISH_REASON` an appropriate exception will automatically be
 | 
			
		||||
    raised.
 | 
			
		||||
 | 
			
		||||
* Full mypy/pyright support for type checking and autocomplete.
 | 
			
		||||
 | 
			
		||||
@ -0,0 +1,50 @@
 | 
			
		||||
# pyproject.toml
 | 
			
		||||
# https://packaging.python.org/en/latest/specifications/declaring-project-metadata
 | 
			
		||||
# https://pip.pypa.io/en/stable/reference/build-system/pyproject-toml
 | 
			
		||||
 | 
			
		||||
[build-system]
 | 
			
		||||
requires = ["setuptools", "setuptools-scm", "build"]
 | 
			
		||||
build-backend = "setuptools.build_meta"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
[project]
 | 
			
		||||
name = "stabilityai"
 | 
			
		||||
description = "*Unofficial* client for the Stability REST API"
 | 
			
		||||
authors = [
 | 
			
		||||
	{name = "Estelle Poulin", email = "dev@inspiredby.es"},
 | 
			
		||||
]
 | 
			
		||||
readme = "README.md"
 | 
			
		||||
requires-python = ">=3.11"
 | 
			
		||||
keywords = ["stabilityai", "bot"]
 | 
			
		||||
license = {text = "GPLv3"}
 | 
			
		||||
classifiers = [
 | 
			
		||||
    "Programming Language :: Python :: 3",
 | 
			
		||||
]
 | 
			
		||||
dynamic = ["version", "dependencies"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
[project.urls]
 | 
			
		||||
homepage = "https://github.com/estheruary/stabilityai"
 | 
			
		||||
repository = "https://github.com/estheruary/stabilityai"
 | 
			
		||||
changelog = "https://github.com/estheruary/stabilityai/-/blob/main/CHANGELOG.md"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
[tool.setuptools]
 | 
			
		||||
packages = ["stabilityai"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
[tool.setuptools.dynamic]
 | 
			
		||||
version = {attr = "stabilityai.__version__"}
 | 
			
		||||
dependencies = {file = ["requirements.txt"]}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
[tool.black]
 | 
			
		||||
line-length = 100
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
[tool.isort]
 | 
			
		||||
profile = "black"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
[tool.vulture]
 | 
			
		||||
ignore_names = ["self", "cls"]
 | 
			
		||||
@ -0,0 +1,2 @@
 | 
			
		||||
aiohttp==3.8.4
 | 
			
		||||
pydantic [email]==1.10.2
 | 
			
		||||
@ -0,0 +1 @@
 | 
			
		||||
__version__ = "1.0.1"
 | 
			
		||||
@ -0,0 +1,271 @@
 | 
			
		||||
import json
 | 
			
		||||
import os
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
import aiohttp
 | 
			
		||||
from pydantic import (
 | 
			
		||||
    validate_arguments,
 | 
			
		||||
)
 | 
			
		||||
from stabilityai.constants import (
 | 
			
		||||
    DEFAULT_GENERATION_ENGINE,
 | 
			
		||||
    DEFAULT_STABILITY_API_HOST,
 | 
			
		||||
    DEFAULT_STABILITY_CLIENT_ID,
 | 
			
		||||
    DEFAULT_STABILITY_CLIENT_VERSION,
 | 
			
		||||
    DEFAULT_UPSCALE_ENGINE,
 | 
			
		||||
)
 | 
			
		||||
from stabilityai.exceptions import YouNeedToUseAContextManager
 | 
			
		||||
from stabilityai.models import (
 | 
			
		||||
    AccountResponseBody,
 | 
			
		||||
    BalanceResponseBody,
 | 
			
		||||
    CfgScale,
 | 
			
		||||
    ClipGuidancePreset,
 | 
			
		||||
    DiffuseImageHeight,
 | 
			
		||||
    DiffuseImageWidth,
 | 
			
		||||
    Engine,
 | 
			
		||||
    Extras,
 | 
			
		||||
    ImageToImageRequestBody,
 | 
			
		||||
    ImageToImageResponseBody,
 | 
			
		||||
    ImageToImageUpscaleRequestBody,
 | 
			
		||||
    ImageToImageUpscaleResponseBody,
 | 
			
		||||
    InitImage,
 | 
			
		||||
    InitImageMode,
 | 
			
		||||
    InitImageStrength,
 | 
			
		||||
    ListEnginesResponseBody,
 | 
			
		||||
    Sampler,
 | 
			
		||||
    Samples,
 | 
			
		||||
    Seed,
 | 
			
		||||
    SingleTextPrompt,
 | 
			
		||||
    Steps,
 | 
			
		||||
    StylePreset,
 | 
			
		||||
    TextPrompt,
 | 
			
		||||
    TextPrompts,
 | 
			
		||||
    TextToImageRequestBody,
 | 
			
		||||
    TextToImageResponseBody,
 | 
			
		||||
    UpscaleImageHeight,
 | 
			
		||||
    UpscaleImageWidth,
 | 
			
		||||
)
 | 
			
		||||
from stabilityai.utils import omit_none
 | 
			
		||||
import textwrap
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AsyncStabilityClient:
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        api_host: str = os.environ.get("STABILITY_API_HOST", DEFAULT_STABILITY_API_HOST),
 | 
			
		||||
        api_key: Optional[str] = os.environ.get("STABILITY_API_KEY"),
 | 
			
		||||
        client_id: str = os.environ.get("STABILITY_CLIENT_ID", DEFAULT_STABILITY_CLIENT_ID),
 | 
			
		||||
        client_version: str = os.environ.get(
 | 
			
		||||
            "STABILITY_CLIENT_VERSION", DEFAULT_STABILITY_CLIENT_VERSION
 | 
			
		||||
        ),
 | 
			
		||||
        organization: Optional[str] = os.environ.get("STABILITY_CLIENT_ORGANIZATION"),
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        self.api_host = api_host
 | 
			
		||||
        self.api_key = api_key
 | 
			
		||||
        self.client_id = client_id
 | 
			
		||||
        self.client_version = client_version
 | 
			
		||||
        self.organization = organization
 | 
			
		||||
 | 
			
		||||
    async def __aenter__(self):
 | 
			
		||||
        self.session = await aiohttp.ClientSession(
 | 
			
		||||
            base_url=self.api_host,
 | 
			
		||||
            headers={
 | 
			
		||||
                "Content-Type": "application/json",
 | 
			
		||||
                "Accept": "application/json",
 | 
			
		||||
                "Authorization": f"Bearer {self.api_key}",
 | 
			
		||||
                "Stability-Client-ID": self.client_id,
 | 
			
		||||
                "Stability-Client-Version": self.client_version,
 | 
			
		||||
                **({"Organization": self.organization} if self.organization else {}),
 | 
			
		||||
            },
 | 
			
		||||
            raise_for_status=True,
 | 
			
		||||
        ).__aenter__()
 | 
			
		||||
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    async def __aexit__(self, exc_type, exc_val, exc_tb):
 | 
			
		||||
        return await self.session.__aexit__(exc_type, exc_val, exc_tb)
 | 
			
		||||
 | 
			
		||||
    async def get_engines(self) -> ListEnginesResponseBody:
 | 
			
		||||
        res = await self.session.get("/v1/engines/list")
 | 
			
		||||
        return await res.json()
 | 
			
		||||
 | 
			
		||||
    async def get_account(self) -> AccountResponseBody:
 | 
			
		||||
        res = await self.session.get("/v1/user/account")
 | 
			
		||||
        return await res.json()
 | 
			
		||||
 | 
			
		||||
    async def get_account_balance(self) -> BalanceResponseBody:
 | 
			
		||||
        res = await self.session.get("/v1/user/balance")
 | 
			
		||||
        return await res.json()
 | 
			
		||||
 | 
			
		||||
    def _oops_no_session(self):
 | 
			
		||||
        if not self.session:
 | 
			
		||||
            raise YouNeedToUseAContextManager(
 | 
			
		||||
                textwrap.dedent(
 | 
			
		||||
                    f"""\
 | 
			
		||||
                {self.__class__.__name__} keeps a aiohttp.ClientSession under
 | 
			
		||||
                the hood and needs to be closed when you're done with it. But
 | 
			
		||||
                since there isn't an async version of __del__ we have to use
 | 
			
		||||
                __aenter__/__aexit__ instead. Apologies.
 | 
			
		||||
 | 
			
		||||
                Instead of
 | 
			
		||||
 | 
			
		||||
                    myclient = {self.__class__.__name__}()
 | 
			
		||||
                    myclient.text_to_image(...)
 | 
			
		||||
                                
 | 
			
		||||
                Do this
 | 
			
		||||
 | 
			
		||||
                    async with {self.__class__.__name__} as myclient:
 | 
			
		||||
                        myclient.text_to_image(...)
 | 
			
		||||
 | 
			
		||||
                Note that it's `async with` and not `with`.
 | 
			
		||||
                """
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    def _normalize_text_prompts(
 | 
			
		||||
        self,
 | 
			
		||||
        text_prompts: Optional[TextPrompts],
 | 
			
		||||
        text_prompt: Optional[SingleTextPrompt]
 | 
			
		||||
    ):
 | 
			
		||||
        if not bool(text_prompt) ^ bool(text_prompts):
 | 
			
		||||
            raise RuntimeError(
 | 
			
		||||
                textwrap.dedent(
 | 
			
		||||
                    f"""\
 | 
			
		||||
                    You must provide one of text_prompt and text_prompts.
 | 
			
		||||
 | 
			
		||||
                    Do this
 | 
			
		||||
 | 
			
		||||
                        stability.text_to_image(
 | 
			
		||||
                            text_prompt="A beautiful sunrise"
 | 
			
		||||
                        )
 | 
			
		||||
 | 
			
		||||
                    Or this
 | 
			
		||||
 | 
			
		||||
                        from stabilityai.models import TextPrompt
 | 
			
		||||
 | 
			
		||||
                        stability.text_to_image(
 | 
			
		||||
                            text_prompts=[
 | 
			
		||||
                                TextPrompt(text="A beautiful sunrise", weight=1.0)
 | 
			
		||||
                            ],
 | 
			
		||||
                        )
 | 
			
		||||
                    """ 
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        if text_prompt:
 | 
			
		||||
            text_prompts = [TextPrompt(text=text_prompt)]
 | 
			
		||||
 | 
			
		||||
        # After this moment text_prompts can't be None.
 | 
			
		||||
        assert text_prompts is not None
 | 
			
		||||
        return text_prompts
 | 
			
		||||
 | 
			
		||||
    @validate_arguments
 | 
			
		||||
    async def text_to_image(
 | 
			
		||||
        self,
 | 
			
		||||
        text_prompts: Optional[TextPrompts] = None,
 | 
			
		||||
        text_prompt: Optional[SingleTextPrompt] = None,
 | 
			
		||||
        *,
 | 
			
		||||
        engine: Optional[Engine] = None,
 | 
			
		||||
        cfg_scale: Optional[CfgScale] = None,
 | 
			
		||||
        clip_guidance_preset: Optional[ClipGuidancePreset] = None,
 | 
			
		||||
        height: Optional[DiffuseImageHeight] = None,
 | 
			
		||||
        sampler: Optional[Sampler] = None,
 | 
			
		||||
        samples: Optional[Samples] = None,
 | 
			
		||||
        seed: Optional[Seed] = None,
 | 
			
		||||
        steps: Optional[Steps] = None,
 | 
			
		||||
        style_preset: Optional[StylePreset] = None,
 | 
			
		||||
        width: Optional[DiffuseImageWidth] = None,
 | 
			
		||||
        extras: Optional[Extras] = None,
 | 
			
		||||
    ):
 | 
			
		||||
        self._oops_no_session()
 | 
			
		||||
 | 
			
		||||
        text_prompts = self._normalize_text_prompts(text_prompts, text_prompt)
 | 
			
		||||
        engine_id = engine.id if engine else DEFAULT_GENERATION_ENGINE
 | 
			
		||||
 | 
			
		||||
        request_body = TextToImageRequestBody(
 | 
			
		||||
            cfg_scale=cfg_scale,
 | 
			
		||||
            clip_guidance_preset=clip_guidance_preset,
 | 
			
		||||
            height=height,
 | 
			
		||||
            sampler=sampler,
 | 
			
		||||
            samples=samples,
 | 
			
		||||
            seed=seed,
 | 
			
		||||
            steps=steps,
 | 
			
		||||
            style_preset=style_preset,
 | 
			
		||||
            text_prompts=text_prompts,
 | 
			
		||||
            width=width,
 | 
			
		||||
            extras=extras,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        res = await self.session.post(
 | 
			
		||||
            f"/v1/generation/{engine_id}/text-to-image",
 | 
			
		||||
            data=json.dumps(omit_none(json.loads(request_body.json()))),
 | 
			
		||||
        )
 | 
			
		||||
        Sampler.K_DPMPP_2M
 | 
			
		||||
 | 
			
		||||
        return TextToImageResponseBody.parse_obj(await res.json())
 | 
			
		||||
 | 
			
		||||
    @validate_arguments
 | 
			
		||||
    async def image_to_image(
 | 
			
		||||
        self,
 | 
			
		||||
        text_prompts: TextPrompts,
 | 
			
		||||
        text_prompt: SingleTextPrompt,
 | 
			
		||||
        init_image: InitImage,
 | 
			
		||||
        *,
 | 
			
		||||
        init_image_mode: Optional[InitImageMode] = None,
 | 
			
		||||
        image_strength: InitImageStrength,
 | 
			
		||||
        engine: Optional[Engine] = None,
 | 
			
		||||
        cfg_scale: Optional[CfgScale] = None,
 | 
			
		||||
        clip_guidance_preset: Optional[ClipGuidancePreset] = None,
 | 
			
		||||
        sampler: Optional[Sampler] = None,
 | 
			
		||||
        samples: Optional[Samples] = None,
 | 
			
		||||
        seed: Optional[Seed] = None,
 | 
			
		||||
        steps: Optional[Steps] = None,
 | 
			
		||||
        style_preset: Optional[StylePreset] = None,
 | 
			
		||||
        extras: Optional[Extras] = None,
 | 
			
		||||
    ):
 | 
			
		||||
        self._oops_no_session()
 | 
			
		||||
 | 
			
		||||
        text_prompts = self._normalize_text_prompts(text_prompts, text_prompt)
 | 
			
		||||
        engine_id = engine.id if engine else DEFAULT_GENERATION_ENGINE
 | 
			
		||||
 | 
			
		||||
        request_body = ImageToImageRequestBody(
 | 
			
		||||
            cfg_scale=cfg_scale,
 | 
			
		||||
            clip_guidance_preset=clip_guidance_preset,
 | 
			
		||||
            init_image=init_image,
 | 
			
		||||
            init_image_mode=init_image_mode,
 | 
			
		||||
            image_strength=image_strength,
 | 
			
		||||
            sampler=sampler,
 | 
			
		||||
            samples=samples,
 | 
			
		||||
            seed=seed,
 | 
			
		||||
            steps=steps,
 | 
			
		||||
            style_preset=style_preset,
 | 
			
		||||
            text_prompts=text_prompts,
 | 
			
		||||
            extras=extras,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        res = await self.session.post(
 | 
			
		||||
            f"/v1/generation/{engine_id}/text-to-image",
 | 
			
		||||
            data=json.dumps(omit_none(json.loads(request_body.json()))),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        return ImageToImageResponseBody.parse_obj(await res.json())
 | 
			
		||||
 | 
			
		||||
    async def image_to_image_upscale(
 | 
			
		||||
        self,
 | 
			
		||||
        image: InitImage,
 | 
			
		||||
        *,
 | 
			
		||||
        engine: Optional[Engine] = None,
 | 
			
		||||
        width: Optional[UpscaleImageWidth] = None,
 | 
			
		||||
        height: Optional[UpscaleImageHeight] = None,
 | 
			
		||||
    ):
 | 
			
		||||
        self._oops_no_session()
 | 
			
		||||
 | 
			
		||||
        engine_id = engine.id if engine else DEFAULT_UPSCALE_ENGINE
 | 
			
		||||
 | 
			
		||||
        request_body = ImageToImageUpscaleRequestBody(image=image, width=width, height=height)
 | 
			
		||||
 | 
			
		||||
        res = await self.session.post(
 | 
			
		||||
            f"/v1/generation/{engine_id}/image-to-image/upscale",
 | 
			
		||||
            data=json.dumps(omit_none(json.loads(request_body.json()))),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        return ImageToImageUpscaleResponseBody.parse_obj(await res.json())
 | 
			
		||||
@ -0,0 +1,9 @@
 | 
			
		||||
from typing import Final
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
DEFAULT_STABILITY_API_HOST: Final[str] = "https://api.stability.ai"
 | 
			
		||||
DEFAULT_STABILITY_CLIENT_ID: Final[str] = "Stability Python SDK"
 | 
			
		||||
DEFAULT_STABILITY_CLIENT_VERSION: Final[str] = "1.0.0"
 | 
			
		||||
 | 
			
		||||
DEFAULT_GENERATION_ENGINE: Final[str] = "stable-diffusion-xl-beta-v2-2-2"
 | 
			
		||||
DEFAULT_UPSCALE_ENGINE: Final[str] = "esrgan-v1-x2plus"
 | 
			
		||||
@ -0,0 +1,2 @@
 | 
			
		||||
class YouNeedToUseAContextManager(Exception):
 | 
			
		||||
    pass
 | 
			
		||||
@ -0,0 +1,247 @@
 | 
			
		||||
import base64
 | 
			
		||||
import io
 | 
			
		||||
from enum import StrEnum
 | 
			
		||||
from typing import Annotated, List, Optional
 | 
			
		||||
 | 
			
		||||
from pydantic import (
 | 
			
		||||
    AnyUrl,
 | 
			
		||||
    BaseModel,
 | 
			
		||||
    EmailStr,
 | 
			
		||||
    confloat,
 | 
			
		||||
    conint,
 | 
			
		||||
    constr,
 | 
			
		||||
    validator,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
class Type(StrEnum):
 | 
			
		||||
    AUDIO = "AUDIO"
 | 
			
		||||
    CLASSIFICATION = "CLASSIFICATION"
 | 
			
		||||
    PICTURE = "PICTURE"
 | 
			
		||||
    STORAGE = "STORAGE"
 | 
			
		||||
    TEXT = "TEXT"
 | 
			
		||||
    VIDEO = "VIDEO"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Engine(BaseModel):
 | 
			
		||||
    description: str
 | 
			
		||||
    id: str
 | 
			
		||||
    name: str
 | 
			
		||||
    type: Type
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Error(BaseModel):
 | 
			
		||||
    id: str
 | 
			
		||||
    name: str
 | 
			
		||||
    message: str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
CfgScale = Annotated[float, confloat(ge=0.0, le=35.0)]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ClipGuidancePreset(StrEnum):
 | 
			
		||||
    FAST_BLUE = "FAST_BLUE"
 | 
			
		||||
    FAST_GREEN = "FAST_GREEN"
 | 
			
		||||
    NONE = "NONE"
 | 
			
		||||
    SIMPLE = "SIMPLE"
 | 
			
		||||
    SLOW = "SLOW"
 | 
			
		||||
    SLOWER = "SLOWER"
 | 
			
		||||
    SLOWEST = "SLOWEST"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
UpscaleImageHeight = Annotated[int, conint(ge=512)]
 | 
			
		||||
 | 
			
		||||
UpscaleImageWidth = Annotated[int, conint(ge=512)]
 | 
			
		||||
 | 
			
		||||
DiffuseImageHeight = Annotated[int, conint(ge=128, multiple_of=64)]
 | 
			
		||||
 | 
			
		||||
DiffuseImageWidth = Annotated[int, conint(ge=128, multiple_of=64)]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Sampler(StrEnum):
 | 
			
		||||
    DDIM = "DDIM"
 | 
			
		||||
    DDPM = "DDPM"
 | 
			
		||||
    K_DPMPP_2M = "K_DPMPP_2M"
 | 
			
		||||
    K_DPMPP_2S_ANCESTRAL = "K_DPMPP_2S_ANCESTRAL"
 | 
			
		||||
    K_DPM_2 = "K_DPM_2"
 | 
			
		||||
    K_DPM_2_ANCESTRAL = "K_DPM_2_ANCESTRAL"
 | 
			
		||||
    K_EULER = "K_EULER"
 | 
			
		||||
    K_EULER_ANCESTRAL = "K_EULER_ANCESTRAL"
 | 
			
		||||
    K_HEUN = "K_HEUN"
 | 
			
		||||
    K_LMS = "K_LMS"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
Samples = Annotated[int, conint(ge=1, le=10)]
 | 
			
		||||
 | 
			
		||||
Seed = Annotated[int, conint(ge=0, le=4294967295)]
 | 
			
		||||
 | 
			
		||||
Steps = Annotated[int, conint(ge=10, le=150)]
 | 
			
		||||
 | 
			
		||||
Extras = dict
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class StylePreset(StrEnum):
 | 
			
		||||
    enhance = "enhance"
 | 
			
		||||
    anime = "anime"
 | 
			
		||||
    photographic = "photographic"
 | 
			
		||||
    digital_art = "digital-art"
 | 
			
		||||
    comic_book = "comic-book"
 | 
			
		||||
    fantasy_art = "fantasy-art"
 | 
			
		||||
    line_art = "line-art"
 | 
			
		||||
    analog_film = "analog-film"
 | 
			
		||||
    neon_punk = "neon-punk"
 | 
			
		||||
    isometric = "isometric"
 | 
			
		||||
    low_poly = "low-poly"
 | 
			
		||||
    origami = "origami"
 | 
			
		||||
    modeling_compound = "modeling-compound"
 | 
			
		||||
    cinematic = "cinematic"
 | 
			
		||||
    field_3d_model = "3d-model"
 | 
			
		||||
    pixel_art = "pixel-art"
 | 
			
		||||
    tile_texture = "tile-texture"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TextPrompt(BaseModel):
 | 
			
		||||
    text: Annotated[str, constr(max_length=2000)]
 | 
			
		||||
    weight: Optional[float] = None
 | 
			
		||||
 | 
			
		||||
SingleTextPrompt = str
 | 
			
		||||
 | 
			
		||||
TextPrompts = List[TextPrompt]
 | 
			
		||||
 | 
			
		||||
InitImage = bytes
 | 
			
		||||
 | 
			
		||||
InitImageStrength = Annotated[float, confloat(ge=0.0, le=1.0)]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class InitImageMode(StrEnum):
 | 
			
		||||
    IMAGE_STRENGTH = "IMAGE_STRENGTH"
 | 
			
		||||
    STEP_SCHEDULE = "STEP_SCHEDULE"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
StepScheduleStart = Annotated[float, confloat(ge=0.0, le=1.0)]
 | 
			
		||||
 | 
			
		||||
StepScheduleEnd = Annotated[float, confloat(ge=0.0, le=1.0)]
 | 
			
		||||
 | 
			
		||||
MaskImage = bytes
 | 
			
		||||
 | 
			
		||||
MaskSource = str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GenerationRequestOptionalParams(BaseModel):
 | 
			
		||||
    cfg_scale: Optional[CfgScale] = None
 | 
			
		||||
    clip_guidance_preset: Optional[ClipGuidancePreset] = None
 | 
			
		||||
    sampler: Optional[Sampler] = None
 | 
			
		||||
    samples: Optional[Samples] = None
 | 
			
		||||
    seed: Optional[Seed] = None
 | 
			
		||||
    steps: Optional[Steps] = None
 | 
			
		||||
    style_preset: Optional[StylePreset] = None
 | 
			
		||||
    extras: Optional[Extras] = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LatentUpscalerUpscaleRequestBody(BaseModel):
 | 
			
		||||
    image: InitImage
 | 
			
		||||
    width: Optional[UpscaleImageWidth] = None
 | 
			
		||||
    height: Optional[UpscaleImageHeight] = None
 | 
			
		||||
    text_prompts: Optional[TextPrompts] = None
 | 
			
		||||
    seed: Optional[Seed] = None
 | 
			
		||||
    steps: Optional[Steps] = None
 | 
			
		||||
    cfg_scale: Optional[CfgScale] = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ImageToImageRequestBody(BaseModel):
 | 
			
		||||
    text_prompts: TextPrompts
 | 
			
		||||
    init_image: InitImage
 | 
			
		||||
    init_image_mode: Optional[InitImageMode] = InitImageMode("IMAGE_STRENGTH")
 | 
			
		||||
    image_strength: Optional[InitImageStrength] = None
 | 
			
		||||
    step_schedule_start: Optional[StepScheduleStart] = None
 | 
			
		||||
    step_schedule_end: Optional[StepScheduleEnd] = None
 | 
			
		||||
    cfg_scale: Optional[CfgScale] = None
 | 
			
		||||
    clip_guidance_preset: Optional[ClipGuidancePreset] = None
 | 
			
		||||
    sampler: Optional[Sampler] = None
 | 
			
		||||
    samples: Optional[Samples] = None
 | 
			
		||||
    seed: Optional[Seed] = None
 | 
			
		||||
    steps: Optional[Steps] = None
 | 
			
		||||
    style_preset: Optional[StylePreset] = None
 | 
			
		||||
    extras: Optional[Extras] = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MaskingRequestBody(BaseModel):
 | 
			
		||||
    init_image: InitImage
 | 
			
		||||
    mask_source: MaskSource
 | 
			
		||||
    mask_image: Optional[MaskImage] = None
 | 
			
		||||
    text_prompts: TextPrompts
 | 
			
		||||
    cfg_scale: Optional[CfgScale] = None
 | 
			
		||||
    clip_guidance_preset: Optional[ClipGuidancePreset] = None
 | 
			
		||||
    sampler: Optional[Sampler] = None
 | 
			
		||||
    samples: Optional[Samples] = None
 | 
			
		||||
    seed: Optional[Seed] = None
 | 
			
		||||
    steps: Optional[Steps] = None
 | 
			
		||||
    style_preset: Optional[StylePreset] = None
 | 
			
		||||
    extras: Optional[Extras] = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TextToImageRequestBody(GenerationRequestOptionalParams):
 | 
			
		||||
    height: Optional[DiffuseImageHeight] = None
 | 
			
		||||
    width: Optional[DiffuseImageWidth] = None
 | 
			
		||||
    text_prompts: TextPrompts
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ImageToImageUpscaleRequestBody(BaseModel):
 | 
			
		||||
    image: InitImage
 | 
			
		||||
    width: Optional[UpscaleImageWidth]
 | 
			
		||||
    height: Optional[UpscaleImageHeight]
 | 
			
		||||
 | 
			
		||||
    @validator("width", always=True)
 | 
			
		||||
    def mutually_exclusive(cls, v, values):
 | 
			
		||||
        if values["height"] is not None and v:
 | 
			
		||||
            raise ValueError("You can only specify one of width and height.")
 | 
			
		||||
        return v
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BalanceResponseBody(BaseModel):
 | 
			
		||||
    credits: float
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
ListEnginesResponseBody = List[Engine]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FinishReason(StrEnum):
 | 
			
		||||
    SUCCESS = "SUCCESS"
 | 
			
		||||
    ERROR = "ERROR"
 | 
			
		||||
    CONTENT_FILTERED = "CONTENT_FILTERED"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Image(BaseModel):
 | 
			
		||||
    base64: Optional[str] = None
 | 
			
		||||
    finishReason: Optional[FinishReason] = None
 | 
			
		||||
    seed: Optional[float] = None
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def file(self):
 | 
			
		||||
        assert self.base64 is not None
 | 
			
		||||
        return io.BytesIO(base64.b64decode(self.base64))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class OrganizationMembership(BaseModel):
 | 
			
		||||
    id: str
 | 
			
		||||
    is_default: bool
 | 
			
		||||
    name: str
 | 
			
		||||
    role: str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AccountResponseBody(BaseModel):
 | 
			
		||||
    email: EmailStr
 | 
			
		||||
    id: str
 | 
			
		||||
    organizations: List[OrganizationMembership]
 | 
			
		||||
    profile_picture: Optional[AnyUrl] = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TextToImageResponseBody(BaseModel):
 | 
			
		||||
    artifacts: List[Image]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ImageToImageResponseBody(BaseModel):
 | 
			
		||||
    artifacts: List[Image]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ImageToImageUpscaleResponseBody(BaseModel):
 | 
			
		||||
    artifacts: List[Image]
 | 
			
		||||
@ -0,0 +1,6 @@
 | 
			
		||||
from typing import Dict, Mapping, TypeVar
 | 
			
		||||
 | 
			
		||||
T, U = TypeVar("T"), TypeVar("U")
 | 
			
		||||
 | 
			
		||||
def omit_none(m: Mapping[T, U]) -> Dict[T, U]:
 | 
			
		||||
    return {k: v for k, v in m.items() if v is not None}
 | 
			
		||||
					Loading…
					
					
				
		Reference in New Issue