You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

168 lines
4.4 KiB
Python

import io
import os
import string
import textwrap
from typing import Optional
import aiohttp
from pydantic import validate_arguments
from rtex.constants import DEFAULT_API_HOST, FORMAT_MIME
from rtex.exceptions import YouNeedToUseAContextManager
from rtex.models import (
CreateLaTeXDocumentRequest,
CreateLaTeXDocumentResponse,
RenderDensity,
RenderFormat,
RenderQuality,
)
class AsyncRtexClient:
def __init__(
self,
api_host=os.environ.get("RTEX_API_HOST", DEFAULT_API_HOST),
):
self.api_host = api_host
self.latex_template = string.Template(
textwrap.dedent(
r"""
\documentclass{$docclass}
\begin{document}
$doc
\end{document}
"""
)
)
async def __aenter__(self):
self.session = await aiohttp.ClientSession(
base_url=self.api_host,
headers={
"Content-Type": "application/json",
},
).__aenter__()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
return await self.session.__aexit__(exc_type, exc_val, exc_tb)
def _oops_no_session(self):
if not self.session:
raise YouNeedToUseAContextManager(
textwrap.dedent(
f"""\
{self.__class__.__name__} keeps a aiohttp.ClientSession under
the hood and needs to be closed when you're done with it. But
since there isn't an async version of __del__ we have to use
__aenter__/__aexit__ instead. Apologies.
Instead of
myclient = {self.__class__.__name__}()
myclient.text_to_image(...)
Do this
async with {self.__class__.__name__} as myclient:
myclient.text_to_image(...)
Note that it's `async with` and not `with`.
"""
)
)
@validate_arguments
async def create_render(
self,
code: str,
format: RenderFormat = "png",
documentclass: str = "minimal",
quality: Optional[RenderQuality] = None,
density: Optional[RenderDensity] = None,
):
self._oops_no_session()
final_doc = self.latex_template.substitute(
docclass=documentclass,
doc=code,
)
request_body = CreateLaTeXDocumentRequest(
code=final_doc,
format=format,
quality=quality,
density=density,
)
res = await self.session.post(
"/api/v2",
headers={"Accept": "application/json"},
data=request_body.json(
exclude_defaults=True,
exclude_none=True,
exclude_unset=True,
),
)
return CreateLaTeXDocumentResponse.parse_obj(await res.json()).__root__
@validate_arguments(config={"arbitrary_types_allowed": True})
async def save_render(
self,
filename: str,
output_fd: io.IOBase,
format: RenderFormat = "png",
):
self._oops_no_session()
res = await self.session.get(
f"/api/v2/{filename}",
headers={
"Accept": FORMAT_MIME[format],
},
)
async for chunk, _ in res.content.iter_chunks():
output_fd.write(chunk)
@validate_arguments
async def get_render(
self,
filename: str,
format: RenderFormat = "png",
):
buf = io.BytesIO()
await self.save_render(filename, buf, format)
buf.seek(0)
return buf
@validate_arguments
async def render_math(
self,
code: str,
format: RenderFormat = "png",
documentclass: str = "minimal",
quality: Optional[RenderQuality] = None,
density: Optional[RenderDensity] = None,
):
final_doc = rf"\({code}\)"
res = await self.create_render(
code=final_doc,
format=format,
documentclass=documentclass,
quality=quality,
density=density,
)
if res.status == "error":
raise RuntimeError("Failed to render code")
return await self.get_render(
filename=res.filename,
format=format,
)