You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
162 lines
4.2 KiB
Python
162 lines
4.2 KiB
Python
1 year ago
|
import io
|
||
|
import os
|
||
|
import string
|
||
|
import textwrap
|
||
|
from typing import Optional
|
||
|
|
||
|
import aiohttp
|
||
|
from pydantic import validate_arguments
|
||
|
|
||
|
from rtex.constants import DEFAULT_API_HOST, FORMAT_MIME
|
||
|
from rtex.exceptions import YouNeedToUseAContextManager
|
||
|
from rtex.models import (
|
||
|
CreateLaTeXDocumentRequest,
|
||
|
CreateLaTeXDocumentResponse,
|
||
|
RenderDensity,
|
||
|
RenderFormat,
|
||
|
RenderQuality,
|
||
|
)
|
||
|
|
||
|
|
||
|
class AsyncRtexClient:
    """Asynchronous client for an Rtex LaTeX-rendering HTTP API.

    The client owns an ``aiohttp.ClientSession`` that is opened by
    ``__aenter__`` and closed by ``__aexit__``, so it has to be used as an
    asynchronous context manager::

        async with AsyncRtexClient() as client:
            image = await client.render_math("x^2 + y^2")
    """

    def __init__(
        self,
        api_host=os.environ.get("RTEX_API_HOST", DEFAULT_API_HOST),
    ):
        """Store the API host and build the LaTeX document template.

        Args:
            api_host: Base URL handed to ``aiohttp.ClientSession``. The
                default is read from the ``RTEX_API_HOST`` environment
                variable (falling back to ``DEFAULT_API_HOST``) once, when
                this class body is executed, not on every instantiation.
        """
        self.api_host = api_host

        # The session only exists between __aenter__ and __aexit__. Setting
        # the attribute here lets _oops_no_session() raise its explanatory
        # error instead of an AttributeError when the client is used without
        # ``async with``.
        self.session = None

        # Skeleton that wraps user-supplied LaTeX in a complete document;
        # ``$docclass`` and ``$doc`` are filled in by create_render().
        self.latex_template = string.Template(
            textwrap.dedent(
                r"""
                \documentclass{$docclass}
                \begin{document}
                $doc
                \end{document}
                """
            )
        )

    async def __aenter__(self):
        """Open the underlying HTTP session and return this client."""
        self.session = await aiohttp.ClientSession(
            base_url=self.api_host,
            headers={
                "Content-Type": "application/json",
            },
        ).__aenter__()

        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the underlying HTTP session.

        The session reference is cleared before closing so that any later
        call on this client hits the _oops_no_session() guard rather than
        operating on a closed session.
        """
        session, self.session = self.session, None
        return await session.__aexit__(exc_type, exc_val, exc_tb)

    def _oops_no_session(self):
        """Raise a descriptive error when no session is currently open.

        Raises:
            YouNeedToUseAContextManager: If the client is used outside of an
                ``async with`` block.
        """
        if self.session is not None:
            return

        name = self.__class__.__name__
        raise YouNeedToUseAContextManager(
            textwrap.dedent(
                f"""\
                {name} keeps an aiohttp.ClientSession under
                the hood and needs to be closed when you're done with it. But
                since there isn't an async version of __del__ we have to use
                __aenter__/__aexit__ instead. Apologies.

                Instead of

                    myclient = {name}()
                    await myclient.render_math(...)

                Do this

                    async with {name}() as myclient:
                        await myclient.render_math(...)

                Note that it's `async with` and not `with`.
                """
            )
        )

    @validate_arguments
    async def create_render(
        self,
        code: str,
        format: RenderFormat = "png",
        documentclass: str = "minimal",
        quality: Optional[RenderQuality] = None,
        density: Optional[RenderDensity] = None,
    ):
        """Submit LaTeX code for rendering.

        ``code`` is placed between ``\\begin{document}`` and
        ``\\end{document}`` of a document using ``documentclass``, and the
        result is POSTed as JSON to ``/api/v2``.

        Args:
            code: LaTeX source for the document body.
            format: Requested output format.
            documentclass: LaTeX document class for the generated document.
            quality: Optional render quality; left out of the request when
                ``None``.
            density: Optional render density; left out of the request when
                ``None``.

        Returns:
            The ``__root__`` value of the parsed
            ``CreateLaTeXDocumentResponse``.

        Raises:
            YouNeedToUseAContextManager: If no session is open.
        """
        self._oops_no_session()

        final_doc = self.latex_template.substitute(
            docclass=documentclass,
            doc=code,
        )

        request_body = CreateLaTeXDocumentRequest(
            code=final_doc,
            format=format,
            quality=quality,
            density=density,
        )

        # ``async with`` releases the response even when reading or decoding
        # the body raises.
        async with self.session.post(
            "/api/v2",
            headers={"Accept": "application/json"},
            data=request_body.json(
                exclude_defaults=True,
                exclude_none=True,
                exclude_unset=True,
            ),
        ) as res:
            payload = await res.json()

        return CreateLaTeXDocumentResponse.parse_obj(payload).__root__

    @validate_arguments(config={"arbitrary_types_allowed": True})
    async def save_render(
        self,
        filename: str,
        output_fd: io.IOBase,
        format: RenderFormat = "png",
    ):
        """Download a rendered file and write it to ``output_fd``.

        The body of ``GET /api/v2/{filename}`` is streamed chunk by chunk
        into ``output_fd.write``.

        NOTE(review): the HTTP status is not checked, so an error response
        body would be written to ``output_fd`` as-is -- confirm whether
        callers rely on that before adding a status check.

        Args:
            filename: Name of the rendered file, as reported by the API.
            output_fd: Writable object that receives the raw bytes.
            format: Render format; selects the ``Accept`` header through
                ``FORMAT_MIME``.

        Raises:
            YouNeedToUseAContextManager: If no session is open.
        """
        self._oops_no_session()

        # ``async with`` releases the response even if a write fails midway.
        async with self.session.get(
            f"/api/v2/{filename}",
            headers={
                "Accept": FORMAT_MIME[format],
            },
        ) as res:
            async for chunk, _ in res.content.iter_chunks():
                output_fd.write(chunk)

    @validate_arguments
    async def get_render(
        self,
        filename: str,
        format: RenderFormat = "png",
    ):
        """Download a rendered file into memory.

        Args:
            filename: Name of the rendered file, as reported by the API.
            format: Render format passed through to save_render().

        Returns:
            An ``io.BytesIO`` holding the downloaded bytes, rewound to
            position 0.
        """
        buf = io.BytesIO()
        await self.save_render(filename, buf, format)
        buf.seek(0)

        return buf

    @validate_arguments
    async def render_math(
        self,
        code: str,
        format: RenderFormat = "png",
    ):
        """Render a math expression and return the resulting file.

        ``code`` is wrapped in ``\\(`` ... ``\\)`` (LaTeX inline math),
        submitted through create_render(), and the produced file is fetched
        with get_render().

        Args:
            code: LaTeX math source, without math delimiters.
            format: Requested output format.

        Returns:
            An ``io.BytesIO`` with the rendered output, positioned at 0.

        Raises:
            RuntimeError: If the API reports a status of ``"error"``.
        """
        final_doc = rf"\({code}\)"

        res = await self.create_render(
            code=final_doc,
            format=format,
        )

        if res.status == "error":
            raise RuntimeError("Failed to render code")

        return await self.get_render(
            filename=res.filename,
            format=format,
        )