import asyncio
import inspect
import sys
import types
from typing import Any, Callable

from quart.local import LocalStack, TaskLocal
from ._synchronise import sync_with_context


def _patch_asyncio() -> None:
    """Swap asyncio over to its pure-Python Task/Future and patch them.

    Raises:
        RuntimeError: If the active event loop policy is not the default
            unix policy (on platforms where ``asyncio.unix_events`` exists).
    """
    # This patches asyncio to add a sync_wait method to the event
    # loop. This method can then be called from within a task
    # including a synchronous function called from a task. Sadly it
    # requires the python Task and Future implementations, which
    # invokes some performance cost.
    asyncio.Task = asyncio.tasks._CTask = asyncio.tasks.Task = asyncio.tasks._PyTask  # type: ignore
    asyncio.Future = asyncio.futures._CFuture = asyncio.futures.Future = asyncio.futures._PyFuture  # type: ignore # noqa

    current_policy = asyncio.get_event_loop_policy()
    if hasattr(asyncio, 'unix_events'):
        target_policy = asyncio.unix_events._UnixDefaultEventLoopPolicy  # type: ignore
    else:
        # Without unix_events every policy is accepted, since everything
        # is an instance of object.
        target_policy = object
    if not isinstance(current_policy, target_policy):
        raise RuntimeError("Flask Patching only works with the default event loop policy")

    _patch_loop()
    _patch_task()


def _patch_loop() -> None:
    """Add a ``sync_wait`` method to ``asyncio.BaseEventLoop``."""

    def _sync_wait(self, future):  # type: ignore
        """Drive the loop until *future* is done and return its result.

        The callbacks that were already queued on the loop are set
        aside while *future* is driven and put back afterwards, in
        their original order. ``future.result()`` is called even when
        the loop was stopped before completion, so in that case
        whatever ``result()`` raises for an unfinished future
        propagates to the caller.
        """
        preserved_ready = list(self._ready)
        self._ready.clear()
        future = asyncio.tasks.ensure_future(future, loop=self)
        try:
            # A cancelled future is also done, so done() alone suffices.
            while not future.done():
                self._run_once()
                if self._stopping:
                    break
        finally:
            # deque.extendleft inserts items one at a time at the left,
            # which reverses them; feed it the reversed list so that the
            # preserved callbacks keep their original scheduling order.
            # Restoring in ``finally`` ensures they are not lost if
            # _run_once raises.
            self._ready.extendleft(reversed(preserved_ready))
        return future.result()

    asyncio.BaseEventLoop.sync_wait = _sync_wait  # type: ignore


def _patch_task() -> None:
    """Patch the asyncio task to allow it to be re-entered.

    Safe to call more than once: the original ``__step`` is only
    captured and wrapped the first time.
    """
    def enter_task(loop, task):  # type: ignore
        # Unconditionally record the task as current for the loop.
        asyncio.tasks._current_tasks[loop] = task

    asyncio.tasks._enter_task = enter_task  # type: ignore

    def leave_task(loop, task):  # type: ignore
        del asyncio.tasks._current_tasks[loop]

    asyncio.tasks._leave_task = leave_task  # type: ignore

    def step(self, exception=None):  # type: ignore
        # Remember which task (if any) was current before this step so
        # that it can be reinstated once the nested step has finished.
        current_task = asyncio.tasks._current_tasks.get(self._loop)
        try:
            self._Task__step_orig(exception)
        finally:
            if current_task is None:
                asyncio.tasks._current_tasks.pop(self._loop, None)
            else:
                asyncio.tasks._current_tasks[self._loop] = current_task

    # Guard against double patching: capturing the already patched step
    # as the "original" would make every step call recurse forever.
    if not hasattr(asyncio.Task, '_Task__step_orig'):
        asyncio.Task._Task__step_orig = asyncio.Task._Task__step  # type: ignore
        asyncio.Task._Task__step = step  # type: ignore


def _context_decorator(func: Callable) -> Callable:
    """Return a wrapper that passes ``func``'s result to ``sync_with_context``.

    The wrapper returns whatever ``sync_with_context`` returns.
    NOTE(review): presumably that is the awaited result of the
    coroutine; confirm against ``_synchronise``.
    """
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        return sync_with_context(func(*args, **kwargs))
    return wrapper


def _convert_module(new_name, module):  # type: ignore
    """Build a module named *new_name* mirroring the members of *module*.

    Coroutine functions that belong to *module* itself are wrapped with
    ``_context_decorator``; every other member is copied across as is.
    """
    new_module = types.ModuleType(new_name)
    for name, member in inspect.getmembers(module):
        if inspect.getmodule(member) == module and inspect.iscoroutinefunction(member):
            setattr(new_module, name, _context_decorator(member))
        else:
            setattr(new_module, name, member)
    return new_module


def _patch_modules() -> None:
    """Register ``flask`` and ``flask.*`` entries in ``sys.modules``.

    Raises:
        ImportError: If a module named ``flask`` is already imported.
    """
    if 'flask' in sys.modules:
        raise ImportError('Cannot mock flask, already imported')

    # Create a set of Flask modules, prioritising those within the
    # flask_patch namespace over simple references to the Quart
    # versions.
    flask_modules = {}
    # Iterate over a snapshot, as sys.modules may change during the loop.
    for name, module in list(sys.modules.items()):
        if name.startswith('quart.flask_patch._'):
            continue
        elif name.startswith('quart.flask_patch'):
            # Only the leading prefix is rewritten (count=1), any later
            # occurrence within the dotted name is left untouched.
            flask_modules[name.replace('quart.flask_patch', 'flask', 1)] = module
        elif name.startswith('quart.') and not name.startswith('quart.serving'):
            flask_name = name.replace('quart.', 'flask.', 1)
            if flask_name not in flask_modules:
                flask_modules[flask_name] = _convert_module(flask_name, module)

    sys.modules.update(flask_modules)


def _patch_quart_local() -> None:
    """Make ``LocalStack`` identify its context via ``TaskLocal._task_identity``."""
    LocalStack.__ident_func__ = lambda _: TaskLocal._task_identity()  # type: ignore


def patch_all() -> None:
    """Apply the asyncio, module and Quart local patches, in that order."""
    _patch_asyncio()
    _patch_modules()
    _patch_quart_local()