ranchimallflo-api/py3.7/lib/python3.7/site-packages/quart/flask_patch/_patch.py

120 lines
4.1 KiB
Python

import asyncio
import inspect
import sys
import types
from typing import Any, Callable
from quart.local import LocalStack, TaskLocal
from ._synchronise import sync_with_context
def _patch_asyncio() -> None:
# This patches asyncio to add a sync_wait method to the event
# loop. This method can then be called from within a task
# including a synchronous function called from a task. Sadly it
# requires the python Task and Future implementations, which
# invokes some performance cost.
asyncio.Task = asyncio.tasks._CTask = asyncio.tasks.Task = asyncio.tasks._PyTask # type: ignore
asyncio.Future = asyncio.futures._CFuture = asyncio.futures.Future = asyncio.futures._PyFuture # type: ignore # noqa
current_policy = asyncio.get_event_loop_policy()
if hasattr(asyncio, 'unix_events'):
target_policy = asyncio.unix_events._UnixDefaultEventLoopPolicy # type: ignore
else:
target_policy = object
if not isinstance(current_policy, target_policy):
raise RuntimeError("Flask Patching only works with the default event loop policy")
_patch_loop()
_patch_task()
def _patch_loop() -> None:
def _sync_wait(self, future): # type: ignore
preserved_ready = list(self._ready)
self._ready.clear()
future = asyncio.tasks.ensure_future(future, loop=self)
while not future.done() and not future.cancelled():
self._run_once()
if self._stopping:
break
self._ready.extendleft(preserved_ready)
return future.result()
asyncio.BaseEventLoop.sync_wait = _sync_wait # type: ignore
def _patch_task() -> None:
# Patch the asyncio task to allow it to be re-entered.
def enter_task(loop, task): # type: ignore
asyncio.tasks._current_tasks[loop] = task
asyncio.tasks._enter_task = enter_task # type: ignore
def leave_task(loop, task): # type: ignore
del asyncio.tasks._current_tasks[loop]
asyncio.tasks._leave_task = leave_task # type: ignore
def step(self, exception=None): # type: ignore
current_task = asyncio.tasks._current_tasks.get(self._loop)
try:
self._Task__step_orig(exception)
finally:
if current_task is None:
asyncio.tasks._current_tasks.pop(self._loop, None)
else:
asyncio.tasks._current_tasks[self._loop] = current_task
asyncio.Task._Task__step_orig = asyncio.Task._Task__step # type: ignore
asyncio.Task._Task__step = step # type: ignore
def _context_decorator(func: Callable) -> Callable:
def wrapper(*args: Any, **kwargs: Any) -> Any:
return sync_with_context(func(*args, **kwargs))
return wrapper
def _convert_module(new_name, module): # type: ignore
new_module = types.ModuleType(new_name)
for name, member in inspect.getmembers(module):
if inspect.getmodule(member) == module and inspect.iscoroutinefunction(member):
setattr(new_module, name, _context_decorator(member))
else:
setattr(new_module, name, member)
return new_module
def _patch_modules() -> None:
if 'flask' in sys.modules:
raise ImportError('Cannot mock flask, already imported')
# Create a set of Flask modules, prioritising those within the
# flask_patch namespace over simple references to the Quart
# versions.
flask_modules = {}
for name, module in list(sys.modules.items()):
if name.startswith('quart.flask_patch._'):
continue
elif name.startswith('quart.flask_patch'):
flask_modules[name.replace('quart.flask_patch', 'flask')] = module
elif name.startswith('quart.') and not name.startswith('quart.serving'):
flask_name = name.replace('quart.', 'flask.')
if flask_name not in flask_modules:
flask_modules[flask_name] = _convert_module(flask_name, module)
sys.modules.update(flask_modules)
def _patch_quart_local() -> None:
LocalStack.__ident_func__ = lambda _: TaskLocal._task_identity() # type: ignore
def patch_all() -> None:
_patch_asyncio()
_patch_modules()
_patch_quart_local()