diff --git a/.github/workflows/build-and-publish.yml b/.github/workflows/build-and-publish.yml new file mode 100644 index 0000000..650fb79 --- /dev/null +++ b/.github/workflows/build-and-publish.yml @@ -0,0 +1,43 @@ +--- + +name: Publish Python distributions to PyPI and TestPyPI +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + validate: + name: Run static analysis on the code + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v3 + - name: Lint with flake8 + run: | + pip install flake8 + # stop the build if there are Python syntax errors or undefined names + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + + build-and-publish: + name: Build and publish Python distribution + runs-on: ubuntu-22.04 + if: github.event_name == 'push' + steps: + - uses: actions/checkout@main + - name: Initialize Python 3.11 + uses: actions/setup-python@v1 + with: + python-version: 3.11 + - name: Install dependencies + run: | + python -m pip install --upgrade pip build + - name: Build binary wheel and a source tarball + run: python -m build + - name: Publish distribution to PyPI + uses: pypa/gh-action-pypi-publish@v1.8.5 + with: + password: ${{ secrets.PYPI_PASSWORD }} + diff --git a/README.md b/README.md index aba5cbe..7eaf79c 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,71 @@ -# stabilityai -Client library for the stability.ai REST API +# Stability AI + +An **UNOFFICIAL** client library for the stability.ai REST API. + +## Motivation + +The official `stability-sdk` is based on gRPC and also really hard to use. Like look at this, this +ignores setting up the SDK. 
 + +```python +from stability_sdk import client +import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation + +answers = stability_api.generate( + prompt="a rocket-ship launching from rolling greens with blue daisies", + seed=892226758, + steps=30, + cfg_scale=8.0, + width=512, + height=512, + sampler=generation.SAMPLER_K_DPMPP_2M +) + +for resp in answers: + for artifact in resp.artifacts: + if artifact.finish_reason == generation.FILTER: + warnings.warn( + "Your request activated the API's safety filters and could not be processed." + "Please modify the prompt and try again.") + if artifact.type == generation.ARTIFACT_IMAGE: + global img + img = Image.open(io.BytesIO(artifact.binary)) + img.save(str(artifact.seed)+ ".png") +``` + +This for loop is *magic*. You must loop the results in exactly this way or the gRPC library won't +work. It's about as unpythonic as a library can get. + +## My Take + +```python +# Set the STABILITY_API_KEY environment variable. + +from stabilityai.client import AsyncStabilityClient +from stabilityai.models import Sampler + +async def example(): + async with AsyncStabilityClient() as stability: + results = await stability.text_to_image( + text_prompt="a rocket-ship launching from rolling greens with blue daisies", + # All these are optional and have sane defaults. + seed=892226758, + steps=30, + cfg_scale=8.0, + width=512, + height=512, + sampler=Sampler.K_DPMPP_2M, + ) + + artifact = results.artifacts[0] + + img = Image.open(artifact.file) + img.save(artifact.file.name) +``` + +## Additional Niceties + +* Instead of manually checking `FINISH_REASON` an appropriate exception will automatically be + raised. + +* Full mypy/pyright support for type checking and autocomplete. 
diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..e5397a7 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,50 @@ +# pyproject.toml +# https://packaging.python.org/en/latest/specifications/declaring-project-metadata +# https://pip.pypa.io/en/stable/reference/build-system/pyproject-toml + +[build-system] +requires = ["setuptools", "setuptools-scm", "build"] +build-backend = "setuptools.build_meta" + + +[project] +name = "stabilityai" +description = "*Unofficial* client for the Stability REST API" +authors = [ + {name = "Estelle Poulin", email = "dev@inspiredby.es"}, +] +readme = "README.md" +requires-python = ">=3.11" +keywords = ["stabilityai", "bot"] +license = {text = "GPLv3"} +classifiers = [ + "Programming Language :: Python :: 3", +] +dynamic = ["version", "dependencies"] + + +[project.urls] +homepage = "https://github.com/estheruary/stabilityai" +repository = "https://github.com/estheruary/stabilityai" +changelog = "https://github.com/estheruary/stabilityai/-/blob/main/CHANGELOG.md" + + +[tool.setuptools] +packages = ["stabilityai"] + + +[tool.setuptools.dynamic] +version = {attr = "stabilityai.__version__"} +dependencies = {file = ["requirements.txt"]} + + +[tool.black] +line-length = 100 + + +[tool.isort] +profile = "black" + + +[tool.vulture] +ignore_names = ["self", "cls"] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..61926f6 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +aiohttp==3.8.4 +pydantic [email]==1.10.2 diff --git a/stabilityai/__init__.py b/stabilityai/__init__.py new file mode 100644 index 0000000..5c4105c --- /dev/null +++ b/stabilityai/__init__.py @@ -0,0 +1 @@ +__version__ = "1.0.1" diff --git a/stabilityai/client.py b/stabilityai/client.py new file mode 100644 index 0000000..bec1e5a --- /dev/null +++ b/stabilityai/client.py @@ -0,0 +1,271 @@ +import json +import os +from typing import Optional + +import aiohttp +from pydantic import ( + validate_arguments, +) +from 
stabilityai.constants import ( + DEFAULT_GENERATION_ENGINE, + DEFAULT_STABILITY_API_HOST, + DEFAULT_STABILITY_CLIENT_ID, + DEFAULT_STABILITY_CLIENT_VERSION, + DEFAULT_UPSCALE_ENGINE, +) +from stabilityai.exceptions import YouNeedToUseAContextManager +from stabilityai.models import ( + AccountResponseBody, + BalanceResponseBody, + CfgScale, + ClipGuidancePreset, + DiffuseImageHeight, + DiffuseImageWidth, + Engine, + Extras, + ImageToImageRequestBody, + ImageToImageResponseBody, + ImageToImageUpscaleRequestBody, + ImageToImageUpscaleResponseBody, + InitImage, + InitImageMode, + InitImageStrength, + ListEnginesResponseBody, + Sampler, + Samples, + Seed, + SingleTextPrompt, + Steps, + StylePreset, + TextPrompt, + TextPrompts, + TextToImageRequestBody, + TextToImageResponseBody, + UpscaleImageHeight, + UpscaleImageWidth, +) +from stabilityai.utils import omit_none +import textwrap + + +class AsyncStabilityClient: + def __init__( + self, + api_host: str = os.environ.get("STABILITY_API_HOST", DEFAULT_STABILITY_API_HOST), + api_key: Optional[str] = os.environ.get("STABILITY_API_KEY"), + client_id: str = os.environ.get("STABILITY_CLIENT_ID", DEFAULT_STABILITY_CLIENT_ID), + client_version: str = os.environ.get( + "STABILITY_CLIENT_VERSION", DEFAULT_STABILITY_CLIENT_VERSION + ), + organization: Optional[str] = os.environ.get("STABILITY_CLIENT_ORGANIZATION"), + ) -> None: + self.api_host = api_host + self.api_key = api_key + self.client_id = client_id + self.client_version = client_version + self.organization = organization + + async def __aenter__(self): + self.session = await aiohttp.ClientSession( + base_url=self.api_host, + headers={ + "Content-Type": "application/json", + "Accept": "application/json", + "Authorization": f"Bearer {self.api_key}", + "Stability-Client-ID": self.client_id, + "Stability-Client-Version": self.client_version, + **({"Organization": self.organization} if self.organization else {}), + }, + raise_for_status=True, + ).__aenter__() + + return self + + 
async def __aexit__(self, exc_type, exc_val, exc_tb): + return await self.session.__aexit__(exc_type, exc_val, exc_tb) + + async def get_engines(self) -> ListEnginesResponseBody: + res = await self.session.get("/v1/engines/list") + return await res.json() + + async def get_account(self) -> AccountResponseBody: + res = await self.session.get("/v1/user/account") + return await res.json() + + async def get_account_balance(self) -> BalanceResponseBody: + res = await self.session.get("/v1/user/balance") + return await res.json() + + def _oops_no_session(self): + if getattr(self, "session", None) is None: + raise YouNeedToUseAContextManager( + textwrap.dedent( + f"""\ + {self.__class__.__name__} keeps a aiohttp.ClientSession under + the hood and needs to be closed when you're done with it. But + since there isn't an async version of __del__ we have to use + __aenter__/__aexit__ instead. Apologies. + + Instead of + + myclient = {self.__class__.__name__}() + myclient.text_to_image(...) + + Do this + + async with {self.__class__.__name__}() as myclient: + myclient.text_to_image(...) + + Note that it's `async with` and not `with`. + """ + ) + ) + + def _normalize_text_prompts( + self, + text_prompts: Optional[TextPrompts], + text_prompt: Optional[SingleTextPrompt] + ): + if not bool(text_prompt) ^ bool(text_prompts): + raise RuntimeError( + textwrap.dedent( + f"""\ + You must provide one of text_prompt and text_prompts. + + Do this + + stability.text_to_image( + text_prompt="A beautiful sunrise" + ) + + Or this + + from stabilityai.models import TextPrompt + + stability.text_to_image( + text_prompts=[ + TextPrompt(text="A beautiful sunrise", weight=1.0) + ], + ) + """ + ) + ) + + if text_prompt: + text_prompts = [TextPrompt(text=text_prompt)] + + # After this moment text_prompts can't be None. 
 + + assert text_prompts is not None + return text_prompts + + @validate_arguments + async def text_to_image( + self, + text_prompts: Optional[TextPrompts] = None, + text_prompt: Optional[SingleTextPrompt] = None, + *, + engine: Optional[Engine] = None, + cfg_scale: Optional[CfgScale] = None, + clip_guidance_preset: Optional[ClipGuidancePreset] = None, + height: Optional[DiffuseImageHeight] = None, + sampler: Optional[Sampler] = None, + samples: Optional[Samples] = None, + seed: Optional[Seed] = None, + steps: Optional[Steps] = None, + style_preset: Optional[StylePreset] = None, + width: Optional[DiffuseImageWidth] = None, + extras: Optional[Extras] = None, + ): + self._oops_no_session() + + text_prompts = self._normalize_text_prompts(text_prompts, text_prompt) + engine_id = engine.id if engine else DEFAULT_GENERATION_ENGINE + + request_body = TextToImageRequestBody( + cfg_scale=cfg_scale, + clip_guidance_preset=clip_guidance_preset, + height=height, + sampler=sampler, + samples=samples, + seed=seed, + steps=steps, + style_preset=style_preset, + text_prompts=text_prompts, + width=width, + extras=extras, + ) + + res = await self.session.post( + f"/v1/generation/{engine_id}/text-to-image", + data=json.dumps(omit_none(json.loads(request_body.json()))), + ) + # NOTE(review): removed stray no-op expression (was: Sampler.K_DPMPP_2M) + + return TextToImageResponseBody.parse_obj(await res.json()) + + @validate_arguments + async def image_to_image( + self, + text_prompts: Optional[TextPrompts] = None, + text_prompt: Optional[SingleTextPrompt] = None, + *, + init_image: InitImage, + init_image_mode: Optional[InitImageMode] = None, + image_strength: Optional[InitImageStrength] = None, + engine: Optional[Engine] = None, + cfg_scale: Optional[CfgScale] = None, + clip_guidance_preset: Optional[ClipGuidancePreset] = None, + sampler: Optional[Sampler] = None, + samples: Optional[Samples] = None, + seed: Optional[Seed] = None, + steps: Optional[Steps] = None, + style_preset: Optional[StylePreset] = None, + extras: Optional[Extras] = None, + ): + self._oops_no_session() + + text_prompts = 
self._normalize_text_prompts(text_prompts, text_prompt) + engine_id = engine.id if engine else DEFAULT_GENERATION_ENGINE + + request_body = ImageToImageRequestBody( + cfg_scale=cfg_scale, + clip_guidance_preset=clip_guidance_preset, + init_image=init_image, + init_image_mode=init_image_mode, + image_strength=image_strength, + sampler=sampler, + samples=samples, + seed=seed, + steps=steps, + style_preset=style_preset, + text_prompts=text_prompts, + extras=extras, + ) + + res = await self.session.post( + f"/v1/generation/{engine_id}/image-to-image", + data=json.dumps(omit_none(json.loads(request_body.json()))), + ) + + return ImageToImageResponseBody.parse_obj(await res.json()) + + async def image_to_image_upscale( + self, + image: InitImage, + *, + engine: Optional[Engine] = None, + width: Optional[UpscaleImageWidth] = None, + height: Optional[UpscaleImageHeight] = None, + ): + self._oops_no_session() + + engine_id = engine.id if engine else DEFAULT_UPSCALE_ENGINE + + request_body = ImageToImageUpscaleRequestBody(image=image, width=width, height=height) + + res = await self.session.post( + f"/v1/generation/{engine_id}/image-to-image/upscale", + data=json.dumps(omit_none(json.loads(request_body.json()))), + ) + + return ImageToImageUpscaleResponseBody.parse_obj(await res.json()) diff --git a/stabilityai/constants.py b/stabilityai/constants.py new file mode 100644 index 0000000..84963dc --- /dev/null +++ b/stabilityai/constants.py @@ -0,0 +1,9 @@ +from typing import Final + + +DEFAULT_STABILITY_API_HOST: Final[str] = "https://api.stability.ai" +DEFAULT_STABILITY_CLIENT_ID: Final[str] = "Stability Python SDK" +DEFAULT_STABILITY_CLIENT_VERSION: Final[str] = "1.0.0" + +DEFAULT_GENERATION_ENGINE: Final[str] = "stable-diffusion-xl-beta-v2-2-2" +DEFAULT_UPSCALE_ENGINE: Final[str] = "esrgan-v1-x2plus" diff --git a/stabilityai/exceptions.py b/stabilityai/exceptions.py new file mode 100644 index 0000000..273a9c1 --- /dev/null +++ b/stabilityai/exceptions.py @@ -0,0 +1,2 @@ 
+class YouNeedToUseAContextManager(Exception): + pass diff --git a/stabilityai/models.py b/stabilityai/models.py new file mode 100644 index 0000000..13fa308 --- /dev/null +++ b/stabilityai/models.py @@ -0,0 +1,247 @@ +import base64 +import io +from enum import StrEnum +from typing import Annotated, List, Optional + +from pydantic import ( + AnyUrl, + BaseModel, + EmailStr, + confloat, + conint, + constr, + validator, +) + +class Type(StrEnum): + AUDIO = "AUDIO" + CLASSIFICATION = "CLASSIFICATION" + PICTURE = "PICTURE" + STORAGE = "STORAGE" + TEXT = "TEXT" + VIDEO = "VIDEO" + + +class Engine(BaseModel): + description: str + id: str + name: str + type: Type + + +class Error(BaseModel): + id: str + name: str + message: str + + +CfgScale = Annotated[float, confloat(ge=0.0, le=35.0)] + + +class ClipGuidancePreset(StrEnum): + FAST_BLUE = "FAST_BLUE" + FAST_GREEN = "FAST_GREEN" + NONE = "NONE" + SIMPLE = "SIMPLE" + SLOW = "SLOW" + SLOWER = "SLOWER" + SLOWEST = "SLOWEST" + + +UpscaleImageHeight = Annotated[int, conint(ge=512)] + +UpscaleImageWidth = Annotated[int, conint(ge=512)] + +DiffuseImageHeight = Annotated[int, conint(ge=128, multiple_of=64)] + +DiffuseImageWidth = Annotated[int, conint(ge=128, multiple_of=64)] + + +class Sampler(StrEnum): + DDIM = "DDIM" + DDPM = "DDPM" + K_DPMPP_2M = "K_DPMPP_2M" + K_DPMPP_2S_ANCESTRAL = "K_DPMPP_2S_ANCESTRAL" + K_DPM_2 = "K_DPM_2" + K_DPM_2_ANCESTRAL = "K_DPM_2_ANCESTRAL" + K_EULER = "K_EULER" + K_EULER_ANCESTRAL = "K_EULER_ANCESTRAL" + K_HEUN = "K_HEUN" + K_LMS = "K_LMS" + + +Samples = Annotated[int, conint(ge=1, le=10)] + +Seed = Annotated[int, conint(ge=0, le=4294967295)] + +Steps = Annotated[int, conint(ge=10, le=150)] + +Extras = dict + + +class StylePreset(StrEnum): + enhance = "enhance" + anime = "anime" + photographic = "photographic" + digital_art = "digital-art" + comic_book = "comic-book" + fantasy_art = "fantasy-art" + line_art = "line-art" + analog_film = "analog-film" + neon_punk = "neon-punk" + isometric = 
"isometric" + low_poly = "low-poly" + origami = "origami" + modeling_compound = "modeling-compound" + cinematic = "cinematic" + field_3d_model = "3d-model" + pixel_art = "pixel-art" + tile_texture = "tile-texture" + + +class TextPrompt(BaseModel): + text: Annotated[str, constr(max_length=2000)] + weight: Optional[float] = None + +SingleTextPrompt = str + +TextPrompts = List[TextPrompt] + +InitImage = bytes + +InitImageStrength = Annotated[float, confloat(ge=0.0, le=1.0)] + + +class InitImageMode(StrEnum): + IMAGE_STRENGTH = "IMAGE_STRENGTH" + STEP_SCHEDULE = "STEP_SCHEDULE" + + +StepScheduleStart = Annotated[float, confloat(ge=0.0, le=1.0)] + +StepScheduleEnd = Annotated[float, confloat(ge=0.0, le=1.0)] + +MaskImage = bytes + +MaskSource = str + + +class GenerationRequestOptionalParams(BaseModel): + cfg_scale: Optional[CfgScale] = None + clip_guidance_preset: Optional[ClipGuidancePreset] = None + sampler: Optional[Sampler] = None + samples: Optional[Samples] = None + seed: Optional[Seed] = None + steps: Optional[Steps] = None + style_preset: Optional[StylePreset] = None + extras: Optional[Extras] = None + + +class LatentUpscalerUpscaleRequestBody(BaseModel): + image: InitImage + width: Optional[UpscaleImageWidth] = None + height: Optional[UpscaleImageHeight] = None + text_prompts: Optional[TextPrompts] = None + seed: Optional[Seed] = None + steps: Optional[Steps] = None + cfg_scale: Optional[CfgScale] = None + + +class ImageToImageRequestBody(BaseModel): + text_prompts: TextPrompts + init_image: InitImage + init_image_mode: Optional[InitImageMode] = InitImageMode("IMAGE_STRENGTH") + image_strength: Optional[InitImageStrength] = None + step_schedule_start: Optional[StepScheduleStart] = None + step_schedule_end: Optional[StepScheduleEnd] = None + cfg_scale: Optional[CfgScale] = None + clip_guidance_preset: Optional[ClipGuidancePreset] = None + sampler: Optional[Sampler] = None + samples: Optional[Samples] = None + seed: Optional[Seed] = None + steps: Optional[Steps] 
= None + style_preset: Optional[StylePreset] = None + extras: Optional[Extras] = None + + +class MaskingRequestBody(BaseModel): + init_image: InitImage + mask_source: MaskSource + mask_image: Optional[MaskImage] = None + text_prompts: TextPrompts + cfg_scale: Optional[CfgScale] = None + clip_guidance_preset: Optional[ClipGuidancePreset] = None + sampler: Optional[Sampler] = None + samples: Optional[Samples] = None + seed: Optional[Seed] = None + steps: Optional[Steps] = None + style_preset: Optional[StylePreset] = None + extras: Optional[Extras] = None + + +class TextToImageRequestBody(GenerationRequestOptionalParams): + height: Optional[DiffuseImageHeight] = None + width: Optional[DiffuseImageWidth] = None + text_prompts: TextPrompts + + +class ImageToImageUpscaleRequestBody(BaseModel): + image: InitImage + width: Optional[UpscaleImageWidth] = None + height: Optional[UpscaleImageHeight] = None + + @validator("height", always=True) + def mutually_exclusive(cls, v, values): + if values.get("width") is not None and v is not None: + raise ValueError("You can only specify one of width and height.") + return v + + +class BalanceResponseBody(BaseModel): + credits: float + + +ListEnginesResponseBody = List[Engine] + + +class FinishReason(StrEnum): + SUCCESS = "SUCCESS" + ERROR = "ERROR" + CONTENT_FILTERED = "CONTENT_FILTERED" + + +class Image(BaseModel): + base64: Optional[str] = None + finishReason: Optional[FinishReason] = None + seed: Optional[float] = None + + @property + def file(self): + assert self.base64 is not None + return io.BytesIO(base64.b64decode(self.base64)) + + +class OrganizationMembership(BaseModel): + id: str + is_default: bool + name: str + role: str + + +class AccountResponseBody(BaseModel): + email: EmailStr + id: str + organizations: List[OrganizationMembership] + profile_picture: Optional[AnyUrl] = None + + +class TextToImageResponseBody(BaseModel): + artifacts: List[Image] + + +class ImageToImageResponseBody(BaseModel): + artifacts: List[Image] + + +class 
ImageToImageUpscaleResponseBody(BaseModel): + artifacts: List[Image] diff --git a/stabilityai/utils.py b/stabilityai/utils.py new file mode 100644 index 0000000..ae49f21 --- /dev/null +++ b/stabilityai/utils.py @@ -0,0 +1,6 @@ +from typing import Dict, Mapping, TypeVar + +T, U = TypeVar("T"), TypeVar("U") + +def omit_none(m: Mapping[T, U]) -> Dict[T, U]: + return {k: v for k, v in m.items() if v is not None}