Estelle Poulin 1 year ago
parent f9859d1767
commit ab492417d6

@ -0,0 +1,41 @@
---
# CI workflow: lint the code, then build the distributions and (on pushes to
# main only) upload them to PyPI.
name: Publish Python distributions to PyPI and TestPyPI
on:
  push:
    branches: [main]
  pull_request:
    branches: [main]
jobs:
  validate:
    name: Run static analysis on the code
    runs-on: ubuntu-22.04
    steps:
      # Without a checkout the workspace is empty and flake8 has nothing to lint.
      - uses: actions/checkout@main
      - name: Lint with flake8
        run: |
          pip install flake8
          # stop the build if there are Python syntax errors or undefined names
          flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
          # exit-zero treats all errors as warnings.
          flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
  build-and-publish:
    name: Build and publish Python distribution
    # Gate the build on the lint job so "stop the build" above actually stops it.
    needs: validate
    runs-on: ubuntu-22.04
    steps:
      - uses: actions/checkout@main
      - name: Initialize Python 3.11
        uses: actions/setup-python@v1
        with:
          # Quoted so YAML keeps it a string (an unquoted 3.10 would become 3.1).
          python-version: "3.11"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip build
      - name: Build binary wheel and a source tarball
        run: python -m build
      - name: Publish distribution to PyPI
        # Pull requests still build (above) but must never upload a release.
        if: github.event_name == 'push'
        uses: pypa/gh-action-pypi-publish@v1.8.5
        with:
          # NOTE(review): the secret is spelled PIPY_PASSWORD, not PYPI_PASSWORD;
          # confirm it matches the name configured in the repository secrets.
          password: ${{ secrets.PIPY_PASSWORD }}

@ -1,2 +1,71 @@
# stabilityai
Client library for the stability.ai REST API
# Stability AI
An **UNOFFICIAL** client library for the stability.ai REST API.
## Motivation
The official `stability-sdk` is based on gRPC and is also really hard to use. Just look at this
example, which doesn't even include setting up the SDK.
```python
from stability_sdk import client
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
answers = stability_api.generate(
prompt="a rocket-ship launching from rolling greens with blue daisies",
seed=892226758,
steps=30,
cfg_scale=8.0,
width=512,
height=512,
sampler=generation.SAMPLER_K_DPMPP_2M
)
for resp in answers:
for artifact in resp.artifacts:
if artifact.finish_reason == generation.FILTER:
warnings.warn(
"Your request activated the API's safety filters and could not be processed."
"Please modify the prompt and try again.")
if artifact.type == generation.ARTIFACT_IMAGE:
global img
img = Image.open(io.BytesIO(artifact.binary))
img.save(str(artifact.seed)+ ".png")
```
This for loop is *magic*. You must loop over the results in exactly this way or the gRPC library
won't work. It's about as unpythonic as a library can get.
## My Take
```python
# Set the STABILITY_API_KEY environment variable.
from stabilityai.client import AsyncStabilityClient
from stabilityai.models import Sampler
async def example():
async with AsyncStabilityClient() as stability:
results = await stability.text_to_image(
text_prompt="a rocket-ship launching from rolling greens with blue daisies",
# All these are optional and have sane defaults.
seed=892226758,
steps=30,
cfg_scale=8.0,
width=512,
height=512,
sampler=Sampler.K_DPMPP_2M,
)
artifact = results.artifacts[0]
img = Image.open(artifact.file)
img.save(artifact.file.name)
```
## Additional Niceties
* Instead of manually checking `FINISH_REASON` an appropriate exception will automatically be
raised.
* Full mypy/pyright support for type checking and autocomplete.

@ -0,0 +1,50 @@
# pyproject.toml
# https://packaging.python.org/en/latest/specifications/declaring-project-metadata
# https://pip.pypa.io/en/stable/reference/build-system/pyproject-toml
[build-system]
requires = ["setuptools", "setuptools-scm", "build"]
build-backend = "setuptools.build_meta"
[project]
name = "stabilityai"
description = "*Unofficial* client for the Stability REST API"
authors = [
{name = "Estelle Poulin", email = "dev@inspiredby.es"},
]
readme = "README.md"
requires-python = ">=3.11"
keywords = ["stabilityai", "bot"]
license = {text = "GPLv3"}
classifiers = [
"Programming Language :: Python :: 3",
]
dynamic = ["version", "dependencies"]
[project.urls]
homepage = "https://github.com/estheruary/stabilityai"
repository = "https://github.com/estheruary/stabilityai"
changelog = "https://github.com/estheruary/stabilityai/blob/main/CHANGELOG.md"
[tool.setuptools]
packages = ["stabilityai"]
[tool.setuptools.dynamic]
version = {attr = "stabilityai.__version__"}
dependencies = {file = ["requirements.txt"]}
[tool.black]
line-length = 100
[tool.isort]
profile = "black"
[tool.vulture]
ignore_names = ["self", "cls"]

@ -0,0 +1,2 @@
aiohttp==3.8.4
pydantic [email]==1.10.2

@ -0,0 +1 @@
# Package version string. pyproject.toml reads it dynamically through
# `version = {attr = "stabilityai.__version__"}`.
__version__ = "1.0.1"

@ -0,0 +1,271 @@
import json
import os
from typing import Optional
import aiohttp
from pydantic import (
validate_arguments,
)
from stabilityai.constants import (
DEFAULT_GENERATION_ENGINE,
DEFAULT_STABILITY_API_HOST,
DEFAULT_STABILITY_CLIENT_ID,
DEFAULT_STABILITY_CLIENT_VERSION,
DEFAULT_UPSCALE_ENGINE,
)
from stabilityai.exceptions import YouNeedToUseAContextManager
from stabilityai.models import (
AccountResponseBody,
BalanceResponseBody,
CfgScale,
ClipGuidancePreset,
DiffuseImageHeight,
DiffuseImageWidth,
Engine,
Extras,
ImageToImageRequestBody,
ImageToImageResponseBody,
ImageToImageUpscaleRequestBody,
ImageToImageUpscaleResponseBody,
InitImage,
InitImageMode,
InitImageStrength,
ListEnginesResponseBody,
Sampler,
Samples,
Seed,
SingleTextPrompt,
Steps,
StylePreset,
TextPrompt,
TextPrompts,
TextToImageRequestBody,
TextToImageResponseBody,
UpscaleImageHeight,
UpscaleImageWidth,
)
from stabilityai.utils import omit_none
import textwrap
class AsyncStabilityClient:
    """Asynchronous client for the stability.ai REST API, built on ``aiohttp``.

    The underlying ``aiohttp.ClientSession`` is opened in ``__aenter__`` and
    closed in ``__aexit__``, so the client has to be used as an async context
    manager::

        async with AsyncStabilityClient() as stability:
            result = await stability.text_to_image(text_prompt="A beautiful sunrise")
    """

    def __init__(
        self,
        api_host: Optional[str] = None,
        api_key: Optional[str] = None,
        client_id: Optional[str] = None,
        client_version: Optional[str] = None,
        organization: Optional[str] = None,
    ) -> None:
        """Store the connection settings; no network activity happens here.

        Every argument left as ``None`` falls back to an environment variable
        and then, where one exists, to the packaged default:

        * ``api_host``: ``STABILITY_API_HOST``, else ``DEFAULT_STABILITY_API_HOST``
        * ``api_key``: ``STABILITY_API_KEY``
        * ``client_id``: ``STABILITY_CLIENT_ID``, else ``DEFAULT_STABILITY_CLIENT_ID``
        * ``client_version``: ``STABILITY_CLIENT_VERSION``, else
          ``DEFAULT_STABILITY_CLIENT_VERSION``
        * ``organization``: ``STABILITY_CLIENT_ORGANIZATION``
        """
        # The environment is consulted here rather than in the signature
        # defaults: defaults are evaluated once at import time, which would
        # silently ignore variables exported after the module was imported.
        env = os.environ
        self.api_host = (
            api_host
            if api_host is not None
            else env.get("STABILITY_API_HOST", DEFAULT_STABILITY_API_HOST)
        )
        self.api_key = api_key if api_key is not None else env.get("STABILITY_API_KEY")
        self.client_id = (
            client_id
            if client_id is not None
            else env.get("STABILITY_CLIENT_ID", DEFAULT_STABILITY_CLIENT_ID)
        )
        self.client_version = (
            client_version
            if client_version is not None
            else env.get("STABILITY_CLIENT_VERSION", DEFAULT_STABILITY_CLIENT_VERSION)
        )
        self.organization = (
            organization
            if organization is not None
            else env.get("STABILITY_CLIENT_ORGANIZATION")
        )
        # Created in __aenter__. Initialised here so that the "no session"
        # guard can detect a client used outside of `async with`.
        self.session: Optional[aiohttp.ClientSession] = None

    async def __aenter__(self) -> "AsyncStabilityClient":
        """Open the HTTP session with the authentication headers and return ``self``."""
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
            "Authorization": f"Bearer {self.api_key}",
            "Stability-Client-ID": self.client_id,
            "Stability-Client-Version": self.client_version,
        }
        # The Organization header is only sent when an organization is configured.
        if self.organization:
            headers["Organization"] = self.organization

        self.session = await aiohttp.ClientSession(
            base_url=self.api_host,
            headers=headers,
            # Non-2xx responses raise instead of being returned.
            raise_for_status=True,
        ).__aenter__()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the HTTP session by delegating to its own ``__aexit__``."""
        return await self.session.__aexit__(exc_type, exc_val, exc_tb)

    # NOTE(review): the three account/engine getters are annotated with pydantic
    # models but return the decoded JSON as-is; parsing it would change the
    # runtime type callers receive, so that is left for a deliberate decision.

    async def get_engines(self) -> ListEnginesResponseBody:
        """GET ``/v1/engines/list`` and return the decoded JSON body."""
        session = self._oops_no_session()
        res = await session.get("/v1/engines/list")
        return await res.json()

    async def get_account(self) -> AccountResponseBody:
        """GET ``/v1/user/account`` and return the decoded JSON body."""
        session = self._oops_no_session()
        res = await session.get("/v1/user/account")
        return await res.json()

    async def get_account_balance(self) -> BalanceResponseBody:
        """GET ``/v1/user/balance`` and return the decoded JSON body."""
        session = self._oops_no_session()
        res = await session.get("/v1/user/balance")
        return await res.json()

    def _oops_no_session(self):
        """Return the open session, or explain how to obtain one.

        Raises:
            YouNeedToUseAContextManager: if the client was not entered with
                ``async with`` and therefore has no session.
        """
        # getattr keeps the guard working even on instances built without __init__.
        session = getattr(self, "session", None)
        if session is None:
            name = self.__class__.__name__
            raise YouNeedToUseAContextManager(
                textwrap.dedent(
                    f"""\
                    {name} keeps a aiohttp.ClientSession under
                    the hood and needs to be closed when you're done with it. But
                    since there isn't an async version of __del__ we have to use
                    __aenter__/__aexit__ instead. Apologies.

                    Instead of

                        myclient = {name}()
                        await myclient.text_to_image(...)

                    Do this

                        async with {name}() as myclient:
                            await myclient.text_to_image(...)

                    Note that it's `async with` and not `with`.
                    """
                )
            )
        return session

    def _normalize_text_prompts(
        self,
        text_prompts: Optional[TextPrompts],
        text_prompt: Optional[SingleTextPrompt],
    ):
        """Collapse the two ways of passing a prompt into a list of ``TextPrompt``.

        Exactly one of the two arguments must be non-empty.

        Raises:
            RuntimeError: if both or neither of the arguments are provided.
        """
        if not bool(text_prompt) ^ bool(text_prompts):
            raise RuntimeError(
                textwrap.dedent(
                    """\
                    You must provide one of text_prompt and text_prompts.

                    Do this

                        stability.text_to_image(
                            text_prompt="A beautiful sunrise"
                        )

                    Or this

                        from stabilityai.models import TextPrompt
                        stability.text_to_image(
                            text_prompts=[
                                TextPrompt(text="A beautiful sunrise", weight=1.0)
                            ],
                        )
                    """
                )
            )

        if text_prompt:
            return [TextPrompt(text=text_prompt)]
        # The XOR check above guarantees text_prompts is a non-empty list here.
        return text_prompts

    @staticmethod
    def _request_payload(request_body) -> str:
        """Serialise a request model to JSON, dropping top-level ``None`` fields."""
        return json.dumps(omit_none(json.loads(request_body.json())))

    @validate_arguments
    async def text_to_image(
        self,
        text_prompts: Optional[TextPrompts] = None,
        text_prompt: Optional[SingleTextPrompt] = None,
        *,
        engine: Optional[Engine] = None,
        cfg_scale: Optional[CfgScale] = None,
        clip_guidance_preset: Optional[ClipGuidancePreset] = None,
        height: Optional[DiffuseImageHeight] = None,
        sampler: Optional[Sampler] = None,
        samples: Optional[Samples] = None,
        seed: Optional[Seed] = None,
        steps: Optional[Steps] = None,
        style_preset: Optional[StylePreset] = None,
        width: Optional[DiffuseImageWidth] = None,
        extras: Optional[Extras] = None,
    ):
        """Generate images from text via ``/v1/generation/{engine}/text-to-image``.

        Pass exactly one of ``text_prompt`` (a plain string) or ``text_prompts``
        (a list of ``TextPrompt``). All keyword-only options are optional and
        are left out of the request when ``None``. ``engine`` defaults to
        ``DEFAULT_GENERATION_ENGINE``.

        Returns:
            The response parsed into a ``TextToImageResponseBody``.

        Raises:
            YouNeedToUseAContextManager: if used outside ``async with``.
            RuntimeError: if both or neither prompt arguments are given.
        """
        session = self._oops_no_session()
        text_prompts = self._normalize_text_prompts(text_prompts, text_prompt)
        engine_id = engine.id if engine else DEFAULT_GENERATION_ENGINE

        request_body = TextToImageRequestBody(
            cfg_scale=cfg_scale,
            clip_guidance_preset=clip_guidance_preset,
            height=height,
            sampler=sampler,
            samples=samples,
            seed=seed,
            steps=steps,
            style_preset=style_preset,
            text_prompts=text_prompts,
            width=width,
            extras=extras,
        )

        res = await session.post(
            f"/v1/generation/{engine_id}/text-to-image",
            data=self._request_payload(request_body),
        )
        return TextToImageResponseBody.parse_obj(await res.json())

    @validate_arguments
    async def image_to_image(
        self,
        text_prompts: Optional[TextPrompts] = None,
        text_prompt: Optional[SingleTextPrompt] = None,
        init_image: Optional[InitImage] = None,
        *,
        init_image_mode: Optional[InitImageMode] = None,
        image_strength: InitImageStrength,
        engine: Optional[Engine] = None,
        cfg_scale: Optional[CfgScale] = None,
        clip_guidance_preset: Optional[ClipGuidancePreset] = None,
        sampler: Optional[Sampler] = None,
        samples: Optional[Samples] = None,
        seed: Optional[Seed] = None,
        steps: Optional[Steps] = None,
        style_preset: Optional[StylePreset] = None,
        extras: Optional[Extras] = None,
    ):
        """Generate images from an initial image via ``.../image-to-image``.

        Pass exactly one of ``text_prompt`` or ``text_prompts``, as with
        ``text_to_image``. ``init_image`` is required; it only carries a default
        so the two prompt arguments in front of it can be optional without
        changing the positional order. ``image_strength`` is a required keyword.

        Returns:
            The response parsed into an ``ImageToImageResponseBody``.

        Raises:
            YouNeedToUseAContextManager: if used outside ``async with``.
            RuntimeError: if both or neither prompt arguments are given.
            ValueError: if ``init_image`` is not provided.
        """
        session = self._oops_no_session()
        if init_image is None:
            raise ValueError("init_image is required.")
        text_prompts = self._normalize_text_prompts(text_prompts, text_prompt)
        engine_id = engine.id if engine else DEFAULT_GENERATION_ENGINE

        request_body = ImageToImageRequestBody(
            cfg_scale=cfg_scale,
            clip_guidance_preset=clip_guidance_preset,
            init_image=init_image,
            init_image_mode=init_image_mode,
            image_strength=image_strength,
            sampler=sampler,
            samples=samples,
            seed=seed,
            steps=steps,
            style_preset=style_preset,
            text_prompts=text_prompts,
            extras=extras,
        )

        # NOTE(review): the image bytes are sent inside a JSON document here;
        # confirm against the API reference whether this endpoint expects
        # multipart/form-data instead.
        res = await session.post(
            f"/v1/generation/{engine_id}/image-to-image",
            data=self._request_payload(request_body),
        )
        return ImageToImageResponseBody.parse_obj(await res.json())

    async def image_to_image_upscale(
        self,
        image: InitImage,
        *,
        engine: Optional[Engine] = None,
        width: Optional[UpscaleImageWidth] = None,
        height: Optional[UpscaleImageHeight] = None,
    ):
        """Upscale an image via ``/v1/generation/{engine}/image-to-image/upscale``.

        ``engine`` defaults to ``DEFAULT_UPSCALE_ENGINE``. The request model
        rejects specifying both ``width`` and ``height``.

        Returns:
            The response parsed into an ``ImageToImageUpscaleResponseBody``.

        Raises:
            YouNeedToUseAContextManager: if used outside ``async with``.
        """
        session = self._oops_no_session()
        engine_id = engine.id if engine else DEFAULT_UPSCALE_ENGINE

        request_body = ImageToImageUpscaleRequestBody(image=image, width=width, height=height)

        # NOTE(review): same JSON-vs-multipart question as image_to_image.
        res = await session.post(
            f"/v1/generation/{engine_id}/image-to-image/upscale",
            data=self._request_payload(request_body),
        )
        return ImageToImageUpscaleResponseBody.parse_obj(await res.json())

@ -0,0 +1,9 @@
from typing import Final
# Fallback values used by AsyncStabilityClient when neither an explicit
# argument nor the matching STABILITY_* environment variable is supplied.

# Base URL of the REST API.
DEFAULT_STABILITY_API_HOST: Final[str] = "https://api.stability.ai"
# Sent as the `Stability-Client-ID` request header.
DEFAULT_STABILITY_CLIENT_ID: Final[str] = "Stability Python SDK"
# Sent as the `Stability-Client-Version` request header.
# NOTE(review): the package's __version__ is "1.0.1"; confirm whether these
# two are meant to track each other.
DEFAULT_STABILITY_CLIENT_VERSION: Final[str] = "1.0.0"
# Engine id used for text-to-image / image-to-image when no engine is passed.
DEFAULT_GENERATION_ENGINE: Final[str] = "stable-diffusion-xl-beta-v2-2-2"
# Engine id used for image upscaling when no engine is passed.
DEFAULT_UPSCALE_ENGINE: Final[str] = "esrgan-v1-x2plus"

@ -0,0 +1,2 @@
class YouNeedToUseAContextManager(Exception):
    """Raised when the client is used without an open HTTP session.

    The client only has a session between ``__aenter__`` and ``__aexit__``,
    i.e. inside an ``async with`` block.
    """

@ -0,0 +1,247 @@
import base64
import io
from enum import StrEnum
from typing import Annotated, List, Optional
from pydantic import (
AnyUrl,
BaseModel,
EmailStr,
confloat,
conint,
constr,
validator,
)
class Type(StrEnum):
AUDIO = "AUDIO"
CLASSIFICATION = "CLASSIFICATION"
PICTURE = "PICTURE"
STORAGE = "STORAGE"
TEXT = "TEXT"
VIDEO = "VIDEO"
class Engine(BaseModel):
    """One engine entry; ``ListEnginesResponseBody`` is a list of these.

    The client reads ``id`` to build the generation endpoint path.
    """

    description: str
    id: str
    name: str
    type: Type  # one of the Type enum members
class Error(BaseModel):
    """Error payload model made of three required string fields."""

    id: str
    name: str
    message: str
# Type of the `cfg_scale` request field: a float annotated with the bounds
# 0.0 <= value <= 35.0.
# NOTE(review): pydantic v1 appears to honour only `Field(...)` metadata inside
# `Annotated`; confirm that a `confloat(...)` placed here is actually enforced.
CfgScale = Annotated[float, confloat(ge=0.0, le=35.0)]
class ClipGuidancePreset(StrEnum):
FAST_BLUE = "FAST_BLUE"
FAST_GREEN = "FAST_GREEN"
NONE = "NONE"
SIMPLE = "SIMPLE"
SLOW = "SLOW"
SLOWER = "SLOWER"
SLOWEST = "SLOWEST"
# Integer dimension types for the request models.
# Upscale targets are annotated as >= 512; diffusion sizes as >= 128 and a
# multiple of 64.
# NOTE(review): confirm pydantic v1 enforces `conint(...)` given as `Annotated`
# metadata; it appears to honour only `Field(...)` there.
UpscaleImageHeight = Annotated[int, conint(ge=512)]
UpscaleImageWidth = Annotated[int, conint(ge=512)]
DiffuseImageHeight = Annotated[int, conint(ge=128, multiple_of=64)]
DiffuseImageWidth = Annotated[int, conint(ge=128, multiple_of=64)]
class Sampler(StrEnum):
DDIM = "DDIM"
DDPM = "DDPM"
K_DPMPP_2M = "K_DPMPP_2M"
K_DPMPP_2S_ANCESTRAL = "K_DPMPP_2S_ANCESTRAL"
K_DPM_2 = "K_DPM_2"
K_DPM_2_ANCESTRAL = "K_DPM_2_ANCESTRAL"
K_EULER = "K_EULER"
K_EULER_ANCESTRAL = "K_EULER_ANCESTRAL"
K_HEUN = "K_HEUN"
K_LMS = "K_LMS"
# Scalar types for the generation request fields of the same (lower-case) name.
# NOTE(review): confirm pydantic v1 enforces `conint(...)` given as `Annotated`
# metadata; it appears to honour only `Field(...)` there.
Samples = Annotated[int, conint(ge=1, le=10)]  # annotated range 1..10
Seed = Annotated[int, conint(ge=0, le=4294967295)]  # annotated range 0..2**32 - 1
Steps = Annotated[int, conint(ge=10, le=150)]  # annotated range 10..150
# Free-form dictionary; no schema is declared for it here.
Extras = dict
class StylePreset(StrEnum):
enhance = "enhance"
anime = "anime"
photographic = "photographic"
digital_art = "digital-art"
comic_book = "comic-book"
fantasy_art = "fantasy-art"
line_art = "line-art"
analog_film = "analog-film"
neon_punk = "neon-punk"
isometric = "isometric"
low_poly = "low-poly"
origami = "origami"
modeling_compound = "modeling-compound"
cinematic = "cinematic"
field_3d_model = "3d-model"
pixel_art = "pixel-art"
tile_texture = "tile-texture"
class TextPrompt(BaseModel):
    """A single prompt: required text plus an optional weight."""

    # Annotated with a 2000-character maximum length.
    text: Annotated[str, constr(max_length=2000)]
    # Left as None unless the caller sets it.
    weight: Optional[float] = None
# Shorthand prompt: a plain string the client wraps into one TextPrompt.
SingleTextPrompt = str
# Full prompt form: a list of TextPrompt models.
TextPrompts = List[TextPrompt]
# Initial image passed as raw bytes.
InitImage = bytes
# Float annotated with the bounds 0.0 <= value <= 1.0.
# NOTE(review): confirm pydantic v1 enforces `confloat(...)` given as
# `Annotated` metadata; it appears to honour only `Field(...)` there.
InitImageStrength = Annotated[float, confloat(ge=0.0, le=1.0)]
class InitImageMode(StrEnum):
IMAGE_STRENGTH = "IMAGE_STRENGTH"
STEP_SCHEDULE = "STEP_SCHEDULE"
# Floats annotated with the bounds 0.0 <= value <= 1.0, used by the
# `step_schedule_start` / `step_schedule_end` request fields.
# NOTE(review): confirm pydantic v1 enforces `confloat(...)` given as
# `Annotated` metadata; it appears to honour only `Field(...)` there.
StepScheduleStart = Annotated[float, confloat(ge=0.0, le=1.0)]
StepScheduleEnd = Annotated[float, confloat(ge=0.0, le=1.0)]
# Mask image passed as raw bytes.
MaskImage = bytes
# Plain string; the set of accepted values is not declared here.
MaskSource = str
class GenerationRequestOptionalParams(BaseModel):
    """Optional generation settings, every one defaulting to ``None``.

    ``TextToImageRequestBody`` inherits these fields.
    """

    cfg_scale: Optional[CfgScale] = None
    clip_guidance_preset: Optional[ClipGuidancePreset] = None
    sampler: Optional[Sampler] = None
    samples: Optional[Samples] = None
    seed: Optional[Seed] = None
    steps: Optional[Steps] = None
    style_preset: Optional[StylePreset] = None
    extras: Optional[Extras] = None
class LatentUpscalerUpscaleRequestBody(BaseModel):
    """Upscale request model that also carries optional prompt settings.

    Only ``image`` is required; every other field defaults to ``None``.
    """

    image: InitImage
    width: Optional[UpscaleImageWidth] = None
    height: Optional[UpscaleImageHeight] = None
    text_prompts: Optional[TextPrompts] = None
    seed: Optional[Seed] = None
    steps: Optional[Steps] = None
    cfg_scale: Optional[CfgScale] = None
class ImageToImageRequestBody(BaseModel):
    """Request model built by ``AsyncStabilityClient.image_to_image``.

    ``text_prompts`` and ``init_image`` are required; the rest are optional.
    """

    # Required inputs.
    text_prompts: TextPrompts
    init_image: InitImage

    # How the initial image is applied; defaults to IMAGE_STRENGTH.
    init_image_mode: Optional[InitImageMode] = InitImageMode.IMAGE_STRENGTH
    image_strength: Optional[InitImageStrength] = None
    step_schedule_start: Optional[StepScheduleStart] = None
    step_schedule_end: Optional[StepScheduleEnd] = None

    # Optional generation settings.
    cfg_scale: Optional[CfgScale] = None
    clip_guidance_preset: Optional[ClipGuidancePreset] = None
    sampler: Optional[Sampler] = None
    samples: Optional[Samples] = None
    seed: Optional[Seed] = None
    steps: Optional[Steps] = None
    style_preset: Optional[StylePreset] = None
    extras: Optional[Extras] = None
class MaskingRequestBody(BaseModel):
    """Request model for a masked generation.

    ``init_image``, ``mask_source`` and ``text_prompts`` are required; the
    rest are optional.
    """

    # Image and mask inputs.
    init_image: InitImage
    mask_source: MaskSource
    mask_image: Optional[MaskImage] = None

    text_prompts: TextPrompts

    # Optional generation settings.
    cfg_scale: Optional[CfgScale] = None
    clip_guidance_preset: Optional[ClipGuidancePreset] = None
    sampler: Optional[Sampler] = None
    samples: Optional[Samples] = None
    seed: Optional[Seed] = None
    steps: Optional[Steps] = None
    style_preset: Optional[StylePreset] = None
    extras: Optional[Extras] = None
class TextToImageRequestBody(GenerationRequestOptionalParams):
    """Request model built by ``AsyncStabilityClient.text_to_image``.

    Adds the image size and the required prompts to the inherited optional
    generation settings.
    """

    height: Optional[DiffuseImageHeight] = None
    width: Optional[DiffuseImageWidth] = None
    text_prompts: TextPrompts
class ImageToImageUpscaleRequestBody(BaseModel):
    """Request model built by ``AsyncStabilityClient.image_to_image_upscale``.

    ``image`` is required. At most one of ``width`` and ``height`` may be
    given; supplying both is rejected with a validation error.
    """

    image: InitImage
    width: Optional[UpscaleImageWidth]
    height: Optional[UpscaleImageHeight]

    # pydantic validates fields in declaration order and `values` only holds
    # the fields validated so far. The check therefore hangs off `height`, the
    # later field, where `width` is already available. Attached to `width` it
    # would index a key that is not in `values` yet and raise KeyError.
    @validator("height", always=True)
    def mutually_exclusive(cls, v, values):
        """Reject requests that set both ``width`` and ``height``."""
        # `.get` also covers the case where `width` failed its own validation
        # and is therefore absent from `values`.
        if v is not None and values.get("width") is not None:
            raise ValueError("You can only specify one of width and height.")
        return v
class BalanceResponseBody(BaseModel):
    """Declared response type of ``get_account_balance``: a single float field."""

    credits: float
# Declared response type of `get_engines`: a list of Engine models.
ListEnginesResponseBody = List[Engine]
class FinishReason(StrEnum):
SUCCESS = "SUCCESS"
ERROR = "ERROR"
CONTENT_FILTERED = "CONTENT_FILTERED"
class Image(BaseModel):
    """One artifact in a generation response; every field is optional."""

    # Base64-encoded payload. Inside methods the bare name `base64` still
    # refers to the imported module, not to this field.
    base64: Optional[str] = None
    finishReason: Optional[FinishReason] = None
    seed: Optional[float] = None

    @property
    def file(self):
        """Decode the payload into a new ``io.BytesIO`` on every access.

        Fails an assertion when ``base64`` is ``None``.
        """
        encoded = self.base64
        assert encoded is not None
        decoded = base64.b64decode(encoded)
        return io.BytesIO(decoded)
class OrganizationMembership(BaseModel):
    """One entry of ``AccountResponseBody.organizations``; all fields required."""

    id: str
    is_default: bool
    name: str
    role: str
class AccountResponseBody(BaseModel):
    """Declared response type of ``get_account``."""

    email: EmailStr
    id: str
    organizations: List[OrganizationMembership]
    # Optional URL; None when absent.
    profile_picture: Optional[AnyUrl] = None
class TextToImageResponseBody(BaseModel):
    """Parsed response of ``text_to_image``: a list of ``Image`` artifacts."""

    artifacts: List[Image]


class ImageToImageResponseBody(BaseModel):
    """Parsed response of ``image_to_image``: a list of ``Image`` artifacts."""

    artifacts: List[Image]


class ImageToImageUpscaleResponseBody(BaseModel):
    """Parsed response of ``image_to_image_upscale``: a list of ``Image`` artifacts."""

    artifacts: List[Image]

@ -0,0 +1,6 @@
from typing import Dict, Mapping, TypeVar
T = TypeVar("T")
U = TypeVar("U")


def omit_none(m: Mapping[T, U]) -> Dict[T, U]:
    """Return a new dict with the entries of *m* whose value is not ``None``.

    Only ``None`` is dropped; other falsy values (``0``, ``""``, ``False``)
    are kept. The filtering is shallow: nested containers are not inspected.
    The input mapping is left untouched.
    """
    kept: Dict[T, U] = {}
    for key, value in m.items():
        if value is None:
            continue
        kept[key] = value
    return kept
Loading…
Cancel
Save