| 
						
						
						
					 | 
				
			
			 | 
			 | 
			
				@ -1,11 +1,10 @@
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				import json
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				import os
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				import textwrap
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				from typing import Optional
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				import aiohttp
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				from pydantic import (
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    validate_arguments,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				from pydantic import validate_arguments
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				from stabilityai.constants import (
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    DEFAULT_GENERATION_ENGINE,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    DEFAULT_STABILITY_API_HOST,
 | 
			
		
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
			
			 | 
			 | 
			
				@ -13,7 +12,10 @@ from stabilityai.constants import (
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    DEFAULT_STABILITY_CLIENT_VERSION,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    DEFAULT_UPSCALE_ENGINE,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				from stabilityai.exceptions import ThisFunctionRequiresAPrompt, YouNeedToUseAContextManager
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				from stabilityai.exceptions import (
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    ThisFunctionRequiresAPrompt,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    YouNeedToUseAContextManager,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				from stabilityai.models import (
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    AccountResponseBody,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    BalanceResponseBody,
 | 
			
		
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
			
			 | 
			 | 
			
				@ -44,8 +46,6 @@ from stabilityai.models import (
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    UpscaleImageHeight,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    UpscaleImageWidth,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				from stabilityai.utils import omit_none
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				import textwrap
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				class AsyncStabilityClient:
 | 
			
		
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
			
			 | 
			 | 
			
				@ -69,7 +69,6 @@ class AsyncStabilityClient:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self.session = await aiohttp.ClientSession(
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            base_url=self.api_host,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            headers={
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                "Content-Type": "application/json",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                "Accept": "application/json",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                "Authorization": f"Bearer {self.api_key}",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                "Stability-Client-ID": self.client_id,
 | 
			
		
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
			
			 | 
			 | 
			
				@ -122,9 +121,7 @@ class AsyncStabilityClient:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    def _normalize_text_prompts(
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        text_prompts: Optional[TextPrompts],
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        text_prompt: Optional[SingleTextPrompt]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self, text_prompts: Optional[TextPrompts], text_prompt: Optional[SingleTextPrompt]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    ):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        if not bool(text_prompt) ^ bool(text_prompts):
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            raise ThisFunctionRequiresAPrompt(
 | 
			
		
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
			
			 | 
			 | 
			
				@ -147,14 +144,10 @@ class AsyncStabilityClient:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                                TextPrompt(text="A beautiful sunrise", weight=1.0)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                            ],
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                        )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                    """ 
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                    """
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				                )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        if text_prompt:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            text_prompts = [TextPrompt(text=text_prompt)]
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        # After this moment text_prompts can't be None.
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        assert text_prompts is not None
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        return text_prompts
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
			
			 | 
			 | 
			
				@ -197,7 +190,8 @@ class AsyncStabilityClient:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        res = await self.session.post(
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            f"/v1/generation/{engine_id}/text-to-image",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            data=json.dumps(omit_none(json.loads(request_body.json()))),
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            headers={"Content-Type": "application/json"},
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            data=request_body.json(exclude_defaults=True, exclude_none=True, exclude_unset=True),
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        Sampler.K_DPMPP_2M
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
			
			 | 
			 | 
			
				@ -224,7 +218,7 @@ class AsyncStabilityClient:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            height=dim,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    @validate_arguments
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    @validate_arguments(config={"arbitrary_types_allowed": True})
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    async def image_to_image(
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        text_prompts: TextPrompts,
 | 
			
		
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
			
			 | 
			 | 
			
				@ -265,11 +259,12 @@ class AsyncStabilityClient:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        res = await self.session.post(
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            f"/v1/generation/{engine_id}/text-to-image",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            data=json.dumps(omit_none(json.loads(request_body.json()))),
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            data=request_body.json(exclude_none=True, exclude_defaults=True, exclude_unset=True),
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        return ImageToImageResponseBody.parse_obj(await res.json())
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    @validate_arguments(config={"arbitrary_types_allowed": True})
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				    async def image_to_image_upscale(
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        self,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        image: InitImage,
 | 
			
		
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
			
			 | 
			 | 
			
				@ -284,9 +279,13 @@ class AsyncStabilityClient:
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        request_body = ImageToImageUpscaleRequestBody(image=image, width=width, height=height)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        form = aiohttp.FormData()
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        form.add_field("width", "1024")
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        form.add_field("image", request_body.image)
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        res = await self.session.post(
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            f"/v1/generation/{engine_id}/image-to-image/upscale",
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            data=json.dumps(omit_none(json.loads(request_body.json()))),
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				            data=form,
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        )
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				
 | 
			
		
		
	
		
			
				 | 
				 | 
			
			 | 
			 | 
			
				        return ImageToImageUpscaleResponseBody.parse_obj(await res.json())
 | 
			
		
		
	
	
		
			
				
					| 
						
						
						
					 | 
				
			
			 | 
			 | 
			
				
 
 |