wip
							parent
							
								
									f9859d1767
								
							
						
					
					
						commit
						ab492417d6
					
				@ -0,0 +1,41 @@
 | 
				
			|||||||
 | 
					---
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					name: Publish Python distributions to PyPI and TestPyPI
 | 
				
			||||||
 | 
					on:
 | 
				
			||||||
 | 
					  push:
 | 
				
			||||||
 | 
					    branches: [main]
 | 
				
			||||||
 | 
					  pull_request:
 | 
				
			||||||
 | 
					    branches: [main]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					jobs:
 | 
				
			||||||
 | 
					  validate:
 | 
				
			||||||
 | 
					    name: Run static analysis on the code
 | 
				
			||||||
 | 
					    runs-on: ubuntu-22.04
 | 
				
			||||||
 | 
					    steps:
 | 
				
			||||||
 | 
					      - name: Lint with flake8
 | 
				
			||||||
 | 
					        run: |
 | 
				
			||||||
 | 
					          pip install flake8
 | 
				
			||||||
 | 
					          # stop the build if there are Python syntax errors or undefined names
 | 
				
			||||||
 | 
					          flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
 | 
				
			||||||
 | 
					          # exit-zero treats all errors as warnings.
 | 
				
			||||||
 | 
					          flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  build-and-publish:
 | 
				
			||||||
 | 
					    name: Build and publish Python distribution
 | 
				
			||||||
 | 
					    runs-on: ubuntu-22.04
 | 
				
			||||||
 | 
					    steps:
 | 
				
			||||||
 | 
					      - uses: actions/checkout@main
 | 
				
			||||||
 | 
					      - name: Initialize Python 3.11
 | 
				
			||||||
 | 
					        uses: actions/setup-python@v1
 | 
				
			||||||
 | 
					        with:
 | 
				
			||||||
 | 
					          python-version: 3.11
 | 
				
			||||||
 | 
					      - name: Install dependencies
 | 
				
			||||||
 | 
					        run: |
 | 
				
			||||||
 | 
					          python -m pip install --upgrade pip build
 | 
				
			||||||
 | 
					      - name: Build binary wheel and a source tarball
 | 
				
			||||||
 | 
					        run: python -m build 
 | 
				
			||||||
 | 
					      - name: Publish distribution to PyPI
 | 
				
			||||||
 | 
					        uses: pypa/gh-action-pypi-publish@v1.8.5 
 | 
				
			||||||
 | 
					        with:
 | 
				
			||||||
 | 
					          password: ${{ secrets.PIPY_PASSWORD }}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -1,2 +1,71 @@
 | 
				
			|||||||
# stabilityai
 | 
					# Stability AI
 | 
				
			||||||
Client library for the stability.ai REST API
 | 
					
 | 
				
			||||||
 | 
					An **UNOFFICIAL** client library for the stability.ai REST API.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## Motivation
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					The official `stability-sdk` is a based on gRPC and also really hard to use. Like look at this, this
 | 
				
			||||||
 | 
					ignores setting up the SDK.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```python
 | 
				
			||||||
 | 
					from stability_sdk import client
 | 
				
			||||||
 | 
					import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					answers = stability_api.generate(
 | 
				
			||||||
 | 
					    prompt="a rocket-ship launching from rolling greens with blue daisies",
 | 
				
			||||||
 | 
					    seed=892226758,
 | 
				
			||||||
 | 
					    steps=30,
 | 
				
			||||||
 | 
					    cfg_scale=8.0,
 | 
				
			||||||
 | 
					    width=512,
 | 
				
			||||||
 | 
					    height=512,
 | 
				
			||||||
 | 
					    sampler=generation.SAMPLER_K_DPMPP_2M
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					for resp in answers:
 | 
				
			||||||
 | 
					    for artifact in resp.artifacts:
 | 
				
			||||||
 | 
					        if artifact.finish_reason == generation.FILTER:
 | 
				
			||||||
 | 
					            warnings.warn(
 | 
				
			||||||
 | 
					                "Your request activated the API's safety filters and could not be processed."
 | 
				
			||||||
 | 
					                "Please modify the prompt and try again.")
 | 
				
			||||||
 | 
					        if artifact.type == generation.ARTIFACT_IMAGE:
 | 
				
			||||||
 | 
					            global img
 | 
				
			||||||
 | 
					            img = Image.open(io.BytesIO(artifact.binary))
 | 
				
			||||||
 | 
					            img.save(str(artifact.seed)+ ".png")
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					This for loop is *magic*. You must loop the results in exactly this way or the gRPC library won't
 | 
				
			||||||
 | 
					work. It's about an unpythonic as a library can get.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## My Take
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```python
 | 
				
			||||||
 | 
					# Set the STABILITY_API_KEY environment variable.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from stabilityai.client import AsyncStabilityClient
 | 
				
			||||||
 | 
					from stabilityai.models import Sampler
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def example():
 | 
				
			||||||
 | 
					  async with AsyncStabilityClient() as stability:
 | 
				
			||||||
 | 
					    results = await stability.text_to_image(
 | 
				
			||||||
 | 
					        text_prompt="a rocket-ship launching from rolling greens with blue daisies",
 | 
				
			||||||
 | 
					        # All these are optional and have sane defaults.
 | 
				
			||||||
 | 
					        seed=892226758,
 | 
				
			||||||
 | 
					        steps=30,
 | 
				
			||||||
 | 
					        cfg_scale=8.0,
 | 
				
			||||||
 | 
					        width=512,
 | 
				
			||||||
 | 
					        height=512,
 | 
				
			||||||
 | 
					        sampler=Sampler.K_DPMPP_2M,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    artifact = results.artifacts[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    img = Image.open(artifact.file)
 | 
				
			||||||
 | 
					    img.save(artifact.file.name)
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## Additional Nicetieis
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					* Instead of manually checking `FINISH_REASON` an appropriate exception will automatically be
 | 
				
			||||||
 | 
					    raised.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					* Full mypy/pyright support for type checking and autocomplete.
 | 
				
			||||||
 | 
				
			|||||||
@ -0,0 +1,50 @@
 | 
				
			|||||||
 | 
					# pyproject.toml
 | 
				
			||||||
 | 
					# https://packaging.python.org/en/latest/specifications/declaring-project-metadata
 | 
				
			||||||
 | 
					# https://pip.pypa.io/en/stable/reference/build-system/pyproject-toml
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[build-system]
 | 
				
			||||||
 | 
					requires = ["setuptools", "setuptools-scm", "build"]
 | 
				
			||||||
 | 
					build-backend = "setuptools.build_meta"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[project]
 | 
				
			||||||
 | 
					name = "stabilityai"
 | 
				
			||||||
 | 
					description = "*Unofficial* client for the Stability REST API"
 | 
				
			||||||
 | 
					authors = [
 | 
				
			||||||
 | 
						{name = "Estelle Poulin", email = "dev@inspiredby.es"},
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
 | 
					readme = "README.md"
 | 
				
			||||||
 | 
					requires-python = ">=3.11"
 | 
				
			||||||
 | 
					keywords = ["stabilityai", "bot"]
 | 
				
			||||||
 | 
					license = {text = "GPLv3"}
 | 
				
			||||||
 | 
					classifiers = [
 | 
				
			||||||
 | 
					    "Programming Language :: Python :: 3",
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
 | 
					dynamic = ["version", "dependencies"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[project.urls]
 | 
				
			||||||
 | 
					homepage = "https://github.com/estheruary/stabilityai"
 | 
				
			||||||
 | 
					repository = "https://github.com/estheruary/stabilityai"
 | 
				
			||||||
 | 
					changelog = "https://github.com/estheruary/stabilityai/-/blob/main/CHANGELOG.md"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[tool.setuptools]
 | 
				
			||||||
 | 
					packages = ["stabilityai"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[tool.setuptools.dynamic]
 | 
				
			||||||
 | 
					version = {attr = "stabilityai.__version__"}
 | 
				
			||||||
 | 
					dependencies = {file = ["requirements.txt"]}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[tool.black]
 | 
				
			||||||
 | 
					line-length = 100
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[tool.isort]
 | 
				
			||||||
 | 
					profile = "black"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[tool.vulture]
 | 
				
			||||||
 | 
					ignore_names = ["self", "cls"]
 | 
				
			||||||
@ -0,0 +1,2 @@
 | 
				
			|||||||
 | 
					aiohttp==3.8.4
 | 
				
			||||||
 | 
					pydantic [email]==1.10.2
 | 
				
			||||||
@ -0,0 +1 @@
 | 
				
			|||||||
 | 
					__version__ = "1.0.1"
 | 
				
			||||||
@ -0,0 +1,271 @@
 | 
				
			|||||||
 | 
					import json
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import aiohttp
 | 
				
			||||||
 | 
					from pydantic import (
 | 
				
			||||||
 | 
					    validate_arguments,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					from stabilityai.constants import (
 | 
				
			||||||
 | 
					    DEFAULT_GENERATION_ENGINE,
 | 
				
			||||||
 | 
					    DEFAULT_STABILITY_API_HOST,
 | 
				
			||||||
 | 
					    DEFAULT_STABILITY_CLIENT_ID,
 | 
				
			||||||
 | 
					    DEFAULT_STABILITY_CLIENT_VERSION,
 | 
				
			||||||
 | 
					    DEFAULT_UPSCALE_ENGINE,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					from stabilityai.exceptions import YouNeedToUseAContextManager
 | 
				
			||||||
 | 
					from stabilityai.models import (
 | 
				
			||||||
 | 
					    AccountResponseBody,
 | 
				
			||||||
 | 
					    BalanceResponseBody,
 | 
				
			||||||
 | 
					    CfgScale,
 | 
				
			||||||
 | 
					    ClipGuidancePreset,
 | 
				
			||||||
 | 
					    DiffuseImageHeight,
 | 
				
			||||||
 | 
					    DiffuseImageWidth,
 | 
				
			||||||
 | 
					    Engine,
 | 
				
			||||||
 | 
					    Extras,
 | 
				
			||||||
 | 
					    ImageToImageRequestBody,
 | 
				
			||||||
 | 
					    ImageToImageResponseBody,
 | 
				
			||||||
 | 
					    ImageToImageUpscaleRequestBody,
 | 
				
			||||||
 | 
					    ImageToImageUpscaleResponseBody,
 | 
				
			||||||
 | 
					    InitImage,
 | 
				
			||||||
 | 
					    InitImageMode,
 | 
				
			||||||
 | 
					    InitImageStrength,
 | 
				
			||||||
 | 
					    ListEnginesResponseBody,
 | 
				
			||||||
 | 
					    Sampler,
 | 
				
			||||||
 | 
					    Samples,
 | 
				
			||||||
 | 
					    Seed,
 | 
				
			||||||
 | 
					    SingleTextPrompt,
 | 
				
			||||||
 | 
					    Steps,
 | 
				
			||||||
 | 
					    StylePreset,
 | 
				
			||||||
 | 
					    TextPrompt,
 | 
				
			||||||
 | 
					    TextPrompts,
 | 
				
			||||||
 | 
					    TextToImageRequestBody,
 | 
				
			||||||
 | 
					    TextToImageResponseBody,
 | 
				
			||||||
 | 
					    UpscaleImageHeight,
 | 
				
			||||||
 | 
					    UpscaleImageWidth,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					from stabilityai.utils import omit_none
 | 
				
			||||||
 | 
					import textwrap
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AsyncStabilityClient:
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        api_host: str = os.environ.get("STABILITY_API_HOST", DEFAULT_STABILITY_API_HOST),
 | 
				
			||||||
 | 
					        api_key: Optional[str] = os.environ.get("STABILITY_API_KEY"),
 | 
				
			||||||
 | 
					        client_id: str = os.environ.get("STABILITY_CLIENT_ID", DEFAULT_STABILITY_CLIENT_ID),
 | 
				
			||||||
 | 
					        client_version: str = os.environ.get(
 | 
				
			||||||
 | 
					            "STABILITY_CLIENT_VERSION", DEFAULT_STABILITY_CLIENT_VERSION
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        organization: Optional[str] = os.environ.get("STABILITY_CLIENT_ORGANIZATION"),
 | 
				
			||||||
 | 
					    ) -> None:
 | 
				
			||||||
 | 
					        self.api_host = api_host
 | 
				
			||||||
 | 
					        self.api_key = api_key
 | 
				
			||||||
 | 
					        self.client_id = client_id
 | 
				
			||||||
 | 
					        self.client_version = client_version
 | 
				
			||||||
 | 
					        self.organization = organization
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def __aenter__(self):
 | 
				
			||||||
 | 
					        self.session = await aiohttp.ClientSession(
 | 
				
			||||||
 | 
					            base_url=self.api_host,
 | 
				
			||||||
 | 
					            headers={
 | 
				
			||||||
 | 
					                "Content-Type": "application/json",
 | 
				
			||||||
 | 
					                "Accept": "application/json",
 | 
				
			||||||
 | 
					                "Authorization": f"Bearer {self.api_key}",
 | 
				
			||||||
 | 
					                "Stability-Client-ID": self.client_id,
 | 
				
			||||||
 | 
					                "Stability-Client-Version": self.client_version,
 | 
				
			||||||
 | 
					                **({"Organization": self.organization} if self.organization else {}),
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					            raise_for_status=True,
 | 
				
			||||||
 | 
					        ).__aenter__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return self
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def __aexit__(self, exc_type, exc_val, exc_tb):
 | 
				
			||||||
 | 
					        return await self.session.__aexit__(exc_type, exc_val, exc_tb)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def get_engines(self) -> ListEnginesResponseBody:
 | 
				
			||||||
 | 
					        res = await self.session.get("/v1/engines/list")
 | 
				
			||||||
 | 
					        return await res.json()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def get_account(self) -> AccountResponseBody:
 | 
				
			||||||
 | 
					        res = await self.session.get("/v1/user/account")
 | 
				
			||||||
 | 
					        return await res.json()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def get_account_balance(self) -> BalanceResponseBody:
 | 
				
			||||||
 | 
					        res = await self.session.get("/v1/user/balance")
 | 
				
			||||||
 | 
					        return await res.json()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _oops_no_session(self):
 | 
				
			||||||
 | 
					        if not self.session:
 | 
				
			||||||
 | 
					            raise YouNeedToUseAContextManager(
 | 
				
			||||||
 | 
					                textwrap.dedent(
 | 
				
			||||||
 | 
					                    f"""\
 | 
				
			||||||
 | 
					                {self.__class__.__name__} keeps a aiohttp.ClientSession under
 | 
				
			||||||
 | 
					                the hood and needs to be closed when you're done with it. But
 | 
				
			||||||
 | 
					                since there isn't an async version of __del__ we have to use
 | 
				
			||||||
 | 
					                __aenter__/__aexit__ instead. Apologies.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                Instead of
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    myclient = {self.__class__.__name__}()
 | 
				
			||||||
 | 
					                    myclient.text_to_image(...)
 | 
				
			||||||
 | 
					                                
 | 
				
			||||||
 | 
					                Do this
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    async with {self.__class__.__name__} as myclient:
 | 
				
			||||||
 | 
					                        myclient.text_to_image(...)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                Note that it's `async with` and not `with`.
 | 
				
			||||||
 | 
					                """
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _normalize_text_prompts(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        text_prompts: Optional[TextPrompts],
 | 
				
			||||||
 | 
					        text_prompt: Optional[SingleTextPrompt]
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        if not bool(text_prompt) ^ bool(text_prompts):
 | 
				
			||||||
 | 
					            raise RuntimeError(
 | 
				
			||||||
 | 
					                textwrap.dedent(
 | 
				
			||||||
 | 
					                    f"""\
 | 
				
			||||||
 | 
					                    You must provide one of text_prompt and text_prompts.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    Do this
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                        stability.text_to_image(
 | 
				
			||||||
 | 
					                            text_prompt="A beautiful sunrise"
 | 
				
			||||||
 | 
					                        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    Or this
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                        from stabilityai.models import TextPrompt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                        stability.text_to_image(
 | 
				
			||||||
 | 
					                            text_prompts=[
 | 
				
			||||||
 | 
					                                TextPrompt(text="A beautiful sunrise", weight=1.0)
 | 
				
			||||||
 | 
					                            ],
 | 
				
			||||||
 | 
					                        )
 | 
				
			||||||
 | 
					                    """ 
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if text_prompt:
 | 
				
			||||||
 | 
					            text_prompts = [TextPrompt(text=text_prompt)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # After this moment text_prompts can't be None.
 | 
				
			||||||
 | 
					        assert text_prompts is not None
 | 
				
			||||||
 | 
					        return text_prompts
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @validate_arguments
 | 
				
			||||||
 | 
					    async def text_to_image(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        text_prompts: Optional[TextPrompts] = None,
 | 
				
			||||||
 | 
					        text_prompt: Optional[SingleTextPrompt] = None,
 | 
				
			||||||
 | 
					        *,
 | 
				
			||||||
 | 
					        engine: Optional[Engine] = None,
 | 
				
			||||||
 | 
					        cfg_scale: Optional[CfgScale] = None,
 | 
				
			||||||
 | 
					        clip_guidance_preset: Optional[ClipGuidancePreset] = None,
 | 
				
			||||||
 | 
					        height: Optional[DiffuseImageHeight] = None,
 | 
				
			||||||
 | 
					        sampler: Optional[Sampler] = None,
 | 
				
			||||||
 | 
					        samples: Optional[Samples] = None,
 | 
				
			||||||
 | 
					        seed: Optional[Seed] = None,
 | 
				
			||||||
 | 
					        steps: Optional[Steps] = None,
 | 
				
			||||||
 | 
					        style_preset: Optional[StylePreset] = None,
 | 
				
			||||||
 | 
					        width: Optional[DiffuseImageWidth] = None,
 | 
				
			||||||
 | 
					        extras: Optional[Extras] = None,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        self._oops_no_session()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        text_prompts = self._normalize_text_prompts(text_prompts, text_prompt)
 | 
				
			||||||
 | 
					        engine_id = engine.id if engine else DEFAULT_GENERATION_ENGINE
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        request_body = TextToImageRequestBody(
 | 
				
			||||||
 | 
					            cfg_scale=cfg_scale,
 | 
				
			||||||
 | 
					            clip_guidance_preset=clip_guidance_preset,
 | 
				
			||||||
 | 
					            height=height,
 | 
				
			||||||
 | 
					            sampler=sampler,
 | 
				
			||||||
 | 
					            samples=samples,
 | 
				
			||||||
 | 
					            seed=seed,
 | 
				
			||||||
 | 
					            steps=steps,
 | 
				
			||||||
 | 
					            style_preset=style_preset,
 | 
				
			||||||
 | 
					            text_prompts=text_prompts,
 | 
				
			||||||
 | 
					            width=width,
 | 
				
			||||||
 | 
					            extras=extras,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        res = await self.session.post(
 | 
				
			||||||
 | 
					            f"/v1/generation/{engine_id}/text-to-image",
 | 
				
			||||||
 | 
					            data=json.dumps(omit_none(json.loads(request_body.json()))),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        Sampler.K_DPMPP_2M
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return TextToImageResponseBody.parse_obj(await res.json())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @validate_arguments
 | 
				
			||||||
 | 
					    async def image_to_image(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        text_prompts: TextPrompts,
 | 
				
			||||||
 | 
					        text_prompt: SingleTextPrompt,
 | 
				
			||||||
 | 
					        init_image: InitImage,
 | 
				
			||||||
 | 
					        *,
 | 
				
			||||||
 | 
					        init_image_mode: Optional[InitImageMode] = None,
 | 
				
			||||||
 | 
					        image_strength: InitImageStrength,
 | 
				
			||||||
 | 
					        engine: Optional[Engine] = None,
 | 
				
			||||||
 | 
					        cfg_scale: Optional[CfgScale] = None,
 | 
				
			||||||
 | 
					        clip_guidance_preset: Optional[ClipGuidancePreset] = None,
 | 
				
			||||||
 | 
					        sampler: Optional[Sampler] = None,
 | 
				
			||||||
 | 
					        samples: Optional[Samples] = None,
 | 
				
			||||||
 | 
					        seed: Optional[Seed] = None,
 | 
				
			||||||
 | 
					        steps: Optional[Steps] = None,
 | 
				
			||||||
 | 
					        style_preset: Optional[StylePreset] = None,
 | 
				
			||||||
 | 
					        extras: Optional[Extras] = None,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        self._oops_no_session()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        text_prompts = self._normalize_text_prompts(text_prompts, text_prompt)
 | 
				
			||||||
 | 
					        engine_id = engine.id if engine else DEFAULT_GENERATION_ENGINE
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        request_body = ImageToImageRequestBody(
 | 
				
			||||||
 | 
					            cfg_scale=cfg_scale,
 | 
				
			||||||
 | 
					            clip_guidance_preset=clip_guidance_preset,
 | 
				
			||||||
 | 
					            init_image=init_image,
 | 
				
			||||||
 | 
					            init_image_mode=init_image_mode,
 | 
				
			||||||
 | 
					            image_strength=image_strength,
 | 
				
			||||||
 | 
					            sampler=sampler,
 | 
				
			||||||
 | 
					            samples=samples,
 | 
				
			||||||
 | 
					            seed=seed,
 | 
				
			||||||
 | 
					            steps=steps,
 | 
				
			||||||
 | 
					            style_preset=style_preset,
 | 
				
			||||||
 | 
					            text_prompts=text_prompts,
 | 
				
			||||||
 | 
					            extras=extras,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        res = await self.session.post(
 | 
				
			||||||
 | 
					            f"/v1/generation/{engine_id}/text-to-image",
 | 
				
			||||||
 | 
					            data=json.dumps(omit_none(json.loads(request_body.json()))),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return ImageToImageResponseBody.parse_obj(await res.json())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def image_to_image_upscale(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        image: InitImage,
 | 
				
			||||||
 | 
					        *,
 | 
				
			||||||
 | 
					        engine: Optional[Engine] = None,
 | 
				
			||||||
 | 
					        width: Optional[UpscaleImageWidth] = None,
 | 
				
			||||||
 | 
					        height: Optional[UpscaleImageHeight] = None,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        self._oops_no_session()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        engine_id = engine.id if engine else DEFAULT_UPSCALE_ENGINE
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        request_body = ImageToImageUpscaleRequestBody(image=image, width=width, height=height)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        res = await self.session.post(
 | 
				
			||||||
 | 
					            f"/v1/generation/{engine_id}/image-to-image/upscale",
 | 
				
			||||||
 | 
					            data=json.dumps(omit_none(json.loads(request_body.json()))),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return ImageToImageUpscaleResponseBody.parse_obj(await res.json())
 | 
				
			||||||
@ -0,0 +1,9 @@
 | 
				
			|||||||
 | 
					from typing import Final
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					DEFAULT_STABILITY_API_HOST: Final[str] = "https://api.stability.ai"
 | 
				
			||||||
 | 
					DEFAULT_STABILITY_CLIENT_ID: Final[str] = "Stability Python SDK"
 | 
				
			||||||
 | 
					DEFAULT_STABILITY_CLIENT_VERSION: Final[str] = "1.0.0"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					DEFAULT_GENERATION_ENGINE: Final[str] = "stable-diffusion-xl-beta-v2-2-2"
 | 
				
			||||||
 | 
					DEFAULT_UPSCALE_ENGINE: Final[str] = "esrgan-v1-x2plus"
 | 
				
			||||||
@ -0,0 +1,2 @@
 | 
				
			|||||||
 | 
					class YouNeedToUseAContextManager(Exception):
 | 
				
			||||||
 | 
					    pass
 | 
				
			||||||
@ -0,0 +1,247 @@
 | 
				
			|||||||
 | 
					import base64
 | 
				
			||||||
 | 
					import io
 | 
				
			||||||
 | 
					from enum import StrEnum
 | 
				
			||||||
 | 
					from typing import Annotated, List, Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from pydantic import (
 | 
				
			||||||
 | 
					    AnyUrl,
 | 
				
			||||||
 | 
					    BaseModel,
 | 
				
			||||||
 | 
					    EmailStr,
 | 
				
			||||||
 | 
					    confloat,
 | 
				
			||||||
 | 
					    conint,
 | 
				
			||||||
 | 
					    constr,
 | 
				
			||||||
 | 
					    validator,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Type(StrEnum):
 | 
				
			||||||
 | 
					    AUDIO = "AUDIO"
 | 
				
			||||||
 | 
					    CLASSIFICATION = "CLASSIFICATION"
 | 
				
			||||||
 | 
					    PICTURE = "PICTURE"
 | 
				
			||||||
 | 
					    STORAGE = "STORAGE"
 | 
				
			||||||
 | 
					    TEXT = "TEXT"
 | 
				
			||||||
 | 
					    VIDEO = "VIDEO"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Engine(BaseModel):
 | 
				
			||||||
 | 
					    description: str
 | 
				
			||||||
 | 
					    id: str
 | 
				
			||||||
 | 
					    name: str
 | 
				
			||||||
 | 
					    type: Type
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Error(BaseModel):
 | 
				
			||||||
 | 
					    id: str
 | 
				
			||||||
 | 
					    name: str
 | 
				
			||||||
 | 
					    message: str
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					CfgScale = Annotated[float, confloat(ge=0.0, le=35.0)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ClipGuidancePreset(StrEnum):
 | 
				
			||||||
 | 
					    FAST_BLUE = "FAST_BLUE"
 | 
				
			||||||
 | 
					    FAST_GREEN = "FAST_GREEN"
 | 
				
			||||||
 | 
					    NONE = "NONE"
 | 
				
			||||||
 | 
					    SIMPLE = "SIMPLE"
 | 
				
			||||||
 | 
					    SLOW = "SLOW"
 | 
				
			||||||
 | 
					    SLOWER = "SLOWER"
 | 
				
			||||||
 | 
					    SLOWEST = "SLOWEST"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					UpscaleImageHeight = Annotated[int, conint(ge=512)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					UpscaleImageWidth = Annotated[int, conint(ge=512)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					DiffuseImageHeight = Annotated[int, conint(ge=128, multiple_of=64)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					DiffuseImageWidth = Annotated[int, conint(ge=128, multiple_of=64)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Sampler(StrEnum):
 | 
				
			||||||
 | 
					    DDIM = "DDIM"
 | 
				
			||||||
 | 
					    DDPM = "DDPM"
 | 
				
			||||||
 | 
					    K_DPMPP_2M = "K_DPMPP_2M"
 | 
				
			||||||
 | 
					    K_DPMPP_2S_ANCESTRAL = "K_DPMPP_2S_ANCESTRAL"
 | 
				
			||||||
 | 
					    K_DPM_2 = "K_DPM_2"
 | 
				
			||||||
 | 
					    K_DPM_2_ANCESTRAL = "K_DPM_2_ANCESTRAL"
 | 
				
			||||||
 | 
					    K_EULER = "K_EULER"
 | 
				
			||||||
 | 
					    K_EULER_ANCESTRAL = "K_EULER_ANCESTRAL"
 | 
				
			||||||
 | 
					    K_HEUN = "K_HEUN"
 | 
				
			||||||
 | 
					    K_LMS = "K_LMS"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Samples = Annotated[int, conint(ge=1, le=10)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Seed = Annotated[int, conint(ge=0, le=4294967295)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Steps = Annotated[int, conint(ge=10, le=150)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Extras = dict
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class StylePreset(StrEnum):
 | 
				
			||||||
 | 
					    enhance = "enhance"
 | 
				
			||||||
 | 
					    anime = "anime"
 | 
				
			||||||
 | 
					    photographic = "photographic"
 | 
				
			||||||
 | 
					    digital_art = "digital-art"
 | 
				
			||||||
 | 
					    comic_book = "comic-book"
 | 
				
			||||||
 | 
					    fantasy_art = "fantasy-art"
 | 
				
			||||||
 | 
					    line_art = "line-art"
 | 
				
			||||||
 | 
					    analog_film = "analog-film"
 | 
				
			||||||
 | 
					    neon_punk = "neon-punk"
 | 
				
			||||||
 | 
					    isometric = "isometric"
 | 
				
			||||||
 | 
					    low_poly = "low-poly"
 | 
				
			||||||
 | 
					    origami = "origami"
 | 
				
			||||||
 | 
					    modeling_compound = "modeling-compound"
 | 
				
			||||||
 | 
					    cinematic = "cinematic"
 | 
				
			||||||
 | 
					    field_3d_model = "3d-model"
 | 
				
			||||||
 | 
					    pixel_art = "pixel-art"
 | 
				
			||||||
 | 
					    tile_texture = "tile-texture"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TextPrompt(BaseModel):
 | 
				
			||||||
 | 
					    text: Annotated[str, constr(max_length=2000)]
 | 
				
			||||||
 | 
					    weight: Optional[float] = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					SingleTextPrompt = str
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TextPrompts = List[TextPrompt]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					InitImage = bytes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					InitImageStrength = Annotated[float, confloat(ge=0.0, le=1.0)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class InitImageMode(StrEnum):
 | 
				
			||||||
 | 
					    IMAGE_STRENGTH = "IMAGE_STRENGTH"
 | 
				
			||||||
 | 
					    STEP_SCHEDULE = "STEP_SCHEDULE"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					StepScheduleStart = Annotated[float, confloat(ge=0.0, le=1.0)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					StepScheduleEnd = Annotated[float, confloat(ge=0.0, le=1.0)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					MaskImage = bytes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					MaskSource = str
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class GenerationRequestOptionalParams(BaseModel):
 | 
				
			||||||
 | 
					    cfg_scale: Optional[CfgScale] = None
 | 
				
			||||||
 | 
					    clip_guidance_preset: Optional[ClipGuidancePreset] = None
 | 
				
			||||||
 | 
					    sampler: Optional[Sampler] = None
 | 
				
			||||||
 | 
					    samples: Optional[Samples] = None
 | 
				
			||||||
 | 
					    seed: Optional[Seed] = None
 | 
				
			||||||
 | 
					    steps: Optional[Steps] = None
 | 
				
			||||||
 | 
					    style_preset: Optional[StylePreset] = None
 | 
				
			||||||
 | 
					    extras: Optional[Extras] = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class LatentUpscalerUpscaleRequestBody(BaseModel):
 | 
				
			||||||
 | 
					    image: InitImage
 | 
				
			||||||
 | 
					    width: Optional[UpscaleImageWidth] = None
 | 
				
			||||||
 | 
					    height: Optional[UpscaleImageHeight] = None
 | 
				
			||||||
 | 
					    text_prompts: Optional[TextPrompts] = None
 | 
				
			||||||
 | 
					    seed: Optional[Seed] = None
 | 
				
			||||||
 | 
					    steps: Optional[Steps] = None
 | 
				
			||||||
 | 
					    cfg_scale: Optional[CfgScale] = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ImageToImageRequestBody(BaseModel):
 | 
				
			||||||
 | 
					    text_prompts: TextPrompts
 | 
				
			||||||
 | 
					    init_image: InitImage
 | 
				
			||||||
 | 
					    init_image_mode: Optional[InitImageMode] = InitImageMode("IMAGE_STRENGTH")
 | 
				
			||||||
 | 
					    image_strength: Optional[InitImageStrength] = None
 | 
				
			||||||
 | 
					    step_schedule_start: Optional[StepScheduleStart] = None
 | 
				
			||||||
 | 
					    step_schedule_end: Optional[StepScheduleEnd] = None
 | 
				
			||||||
 | 
					    cfg_scale: Optional[CfgScale] = None
 | 
				
			||||||
 | 
					    clip_guidance_preset: Optional[ClipGuidancePreset] = None
 | 
				
			||||||
 | 
					    sampler: Optional[Sampler] = None
 | 
				
			||||||
 | 
					    samples: Optional[Samples] = None
 | 
				
			||||||
 | 
					    seed: Optional[Seed] = None
 | 
				
			||||||
 | 
					    steps: Optional[Steps] = None
 | 
				
			||||||
 | 
					    style_preset: Optional[StylePreset] = None
 | 
				
			||||||
 | 
					    extras: Optional[Extras] = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MaskingRequestBody(BaseModel):
 | 
				
			||||||
 | 
					    init_image: InitImage
 | 
				
			||||||
 | 
					    mask_source: MaskSource
 | 
				
			||||||
 | 
					    mask_image: Optional[MaskImage] = None
 | 
				
			||||||
 | 
					    text_prompts: TextPrompts
 | 
				
			||||||
 | 
					    cfg_scale: Optional[CfgScale] = None
 | 
				
			||||||
 | 
					    clip_guidance_preset: Optional[ClipGuidancePreset] = None
 | 
				
			||||||
 | 
					    sampler: Optional[Sampler] = None
 | 
				
			||||||
 | 
					    samples: Optional[Samples] = None
 | 
				
			||||||
 | 
					    seed: Optional[Seed] = None
 | 
				
			||||||
 | 
					    steps: Optional[Steps] = None
 | 
				
			||||||
 | 
					    style_preset: Optional[StylePreset] = None
 | 
				
			||||||
 | 
					    extras: Optional[Extras] = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TextToImageRequestBody(GenerationRequestOptionalParams):
 | 
				
			||||||
 | 
					    height: Optional[DiffuseImageHeight] = None
 | 
				
			||||||
 | 
					    width: Optional[DiffuseImageWidth] = None
 | 
				
			||||||
 | 
					    text_prompts: TextPrompts
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ImageToImageUpscaleRequestBody(BaseModel):
 | 
				
			||||||
 | 
					    image: InitImage
 | 
				
			||||||
 | 
					    width: Optional[UpscaleImageWidth]
 | 
				
			||||||
 | 
					    height: Optional[UpscaleImageHeight]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @validator("width", always=True)
 | 
				
			||||||
 | 
					    def mutually_exclusive(cls, v, values):
 | 
				
			||||||
 | 
					        if values["height"] is not None and v:
 | 
				
			||||||
 | 
					            raise ValueError("You can only specify one of width and height.")
 | 
				
			||||||
 | 
					        return v
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class BalanceResponseBody(BaseModel):
 | 
				
			||||||
 | 
					    credits: float
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					ListEnginesResponseBody = List[Engine]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class FinishReason(StrEnum):
 | 
				
			||||||
 | 
					    SUCCESS = "SUCCESS"
 | 
				
			||||||
 | 
					    ERROR = "ERROR"
 | 
				
			||||||
 | 
					    CONTENT_FILTERED = "CONTENT_FILTERED"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Image(BaseModel):
 | 
				
			||||||
 | 
					    base64: Optional[str] = None
 | 
				
			||||||
 | 
					    finishReason: Optional[FinishReason] = None
 | 
				
			||||||
 | 
					    seed: Optional[float] = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def file(self):
 | 
				
			||||||
 | 
					        assert self.base64 is not None
 | 
				
			||||||
 | 
					        return io.BytesIO(base64.b64decode(self.base64))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class OrganizationMembership(BaseModel):
 | 
				
			||||||
 | 
					    id: str
 | 
				
			||||||
 | 
					    is_default: bool
 | 
				
			||||||
 | 
					    name: str
 | 
				
			||||||
 | 
					    role: str
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AccountResponseBody(BaseModel):
 | 
				
			||||||
 | 
					    email: EmailStr
 | 
				
			||||||
 | 
					    id: str
 | 
				
			||||||
 | 
					    organizations: List[OrganizationMembership]
 | 
				
			||||||
 | 
					    profile_picture: Optional[AnyUrl] = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TextToImageResponseBody(BaseModel):
 | 
				
			||||||
 | 
					    artifacts: List[Image]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ImageToImageResponseBody(BaseModel):
 | 
				
			||||||
 | 
					    artifacts: List[Image]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ImageToImageUpscaleResponseBody(BaseModel):
 | 
				
			||||||
 | 
					    artifacts: List[Image]
 | 
				
			||||||
@ -0,0 +1,6 @@
 | 
				
			|||||||
 | 
					from typing import Dict, Mapping, TypeVar
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					T, U = TypeVar("T"), TypeVar("U")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def omit_none(m: Mapping[T, U]) -> Dict[T, U]:
 | 
				
			||||||
 | 
					    return {k: v for k, v in m.items() if v is not None}
 | 
				
			||||||
					Loading…
					
					
				
		Reference in New Issue