from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from email.utils import parsedate_to_datetime
from hashlib import md5
from inspect import isasyncgen
from os import PathLike
from types import TracebackType
from typing import (
    AnyStr,
    AsyncGenerator,
    AsyncIterable,
    AsyncIterator,
    Iterable,
    Optional,
    Tuple,
    TYPE_CHECKING,
    Union,
)
from wsgiref.handlers import format_date_time

from aiofiles import open as async_open
from aiofiles.base import AiofilesContextManager
from aiofiles.threadpool import AsyncFileIO

from ._base import _BaseRequestResponse, JSONMixin
from ..datastructures import (
    CIMultiDict,
    ContentRange,
    Headers,
    HeaderSet,
    Range,
    ResponseAccessControl,
    ResponseCacheControl,
)
from ..utils import create_cookie, file_path_to_path

if TYPE_CHECKING:
    from ..routing import Rule  # noqa


class ResponseBody(ABC):
    """Base class wrapper for response body data.

    This ensures that the following is possible (as Quart assumes so
    when returning the body to the ASGI server)::

        async with wrapper as response:
            async for data in response:
                send(data)

    """

    @abstractmethod
    async def __aenter__(self) -> AsyncIterable:
        pass

    @abstractmethod
    async def __aexit__(self, exc_type: type, exc_value: BaseException, tb: TracebackType) -> None:
        pass

    @abstractmethod
    async def convert_to_sequence(self) -> bytes:
        pass


def _raise_if_invalid_range(begin: int, end: int, size: int) -> None:
    """Raise RequestRangeNotSatisfiable unless ``[begin, end)`` fits in ``size`` bytes."""
    if begin >= end or abs(begin) > size or end > size:
        from ..exceptions import RequestRangeNotSatisfiable
        raise RequestRangeNotSatisfiable()


def _resolve_range(
    begin: int,
    end: Optional[int],
    size: int,
    max_partial_size: Optional[int] = None,
) -> Tuple[int, int]:
    """Resolve a requested byte range into absolute ``(begin, end)`` offsets.

    ``end`` is exclusive and ``None`` means "to the end of the data". A
    negative ``begin`` counts back from ``size`` (as in Python slicing), so
    ``(-3, None)`` on 6 bytes resolves to ``(3, 6)``. ``max_partial_size``,
    when given, caps the length of the resolved range.

    Raises:
        RequestRangeNotSatisfiable: If the range is empty, a negative
            ``begin`` reaches back further than ``size``, or ``end`` lies
            beyond ``size``.
    """
    if begin < 0:
        # Reject a suffix longer than the data, as the abs(begin) > size
        # check did before offsets were resolved.
        if -begin > size:
            from ..exceptions import RequestRangeNotSatisfiable
            raise RequestRangeNotSatisfiable()
        begin += size
    if end is None:
        end = size
    if max_partial_size is not None:
        # Applied after resolving begin so a capped suffix range stays valid.
        end = min(begin + max_partial_size, end)
    _raise_if_invalid_range(begin, end, size)
    return begin, end


class DataBody(ResponseBody):
    """An in-memory ``bytes`` body that can be restricted to a byte range."""

    def __init__(self, data: bytes) -> None:
        self.data = data
        self.begin = 0
        self.end = len(self.data)

    async def __aenter__(self) -> "DataBody":
        return self

    async def __aexit__(self, exc_type: type, exc_value: BaseException, tb: TracebackType) -> None:
        pass

    def __aiter__(self) -> AsyncIterator:
        async def _aiter() -> AsyncGenerator[bytes, None]:
            yield self.data[self.begin:self.end]

        return _aiter()

    async def convert_to_sequence(self) -> bytes:
        return self.data[self.begin:self.end]

    async def make_conditional(
        self,
        begin: int,
        end: Optional[int],
        max_partial_size: Optional[int]=None,
    ) -> int:
        """Restrict the body to the given range and return the full length.

        Raises RequestRangeNotSatisfiable (leaving the current range
        untouched) if the range does not fit the data.
        """
        size = len(self.data)
        self.begin, self.end = _resolve_range(begin, end, size, max_partial_size)
        return size


class IterableBody(ResponseBody):
    """A body wrapping an async generator or a synchronous iterable of bytes.

    The wrapped iterator can only be consumed once.
    """

    def __init__(self, iterable: Union[AsyncGenerator[bytes, None], Iterable]) -> None:
        self.iter: AsyncGenerator[bytes, None]
        if isasyncgen(iterable):
            self.iter = iterable  # type: ignore
        else:
            # Adapt a synchronous iterable into an async generator.
            async def _aiter() -> AsyncGenerator[bytes, None]:
                for data in iterable:  # type: ignore
                    yield data

            self.iter = _aiter()

    async def __aenter__(self) -> "IterableBody":
        return self

    async def __aexit__(self, exc_type: type, exc_value: BaseException, tb: TracebackType) -> None:
        pass

    def __aiter__(self) -> AsyncIterator:
        return self.iter

    async def convert_to_sequence(self) -> bytes:
        result = bytearray()
        async for data in self.iter:
            result.extend(data)
        return bytes(result)


class FileBody(ResponseBody):
    """Provides an async file accessor with range setting.

    The :attr:`Response.response` attribute must be async-iterable and
    yield bytes, which this wrapper does for a file. In addition it
    allows a range to be set on the file, thereby supporting
    conditional requests.

    Attributes:
        buffer_size: Size in bytes to load per iteration.
    """

    buffer_size = 8192

    def __init__(
        self,
        file_path: Union[str, PathLike],
        *,
        buffer_size: Optional[int] = None,
    ) -> None:
        self.file_path = file_path_to_path(file_path)
        self.size = self.file_path.stat().st_size
        self.begin = 0
        self.end = self.size
        if buffer_size is not None:
            self.buffer_size = buffer_size
        self.file: Optional[AsyncFileIO] = None
        self.file_manager: Optional[AiofilesContextManager] = None

    async def __aenter__(self) -> "FileBody":
        self.file_manager = async_open(self.file_path, mode="rb")
        self.file = await self.file_manager.__aenter__()
        await self.file.seek(self.begin)
        return self

    async def __aexit__(self, exc_type: type, exc_value: BaseException, tb: TracebackType) -> None:
        await self.file_manager.__aexit__(exc_type, exc_value, tb)

    def __aiter__(self) -> 'FileBody':
        return self

    async def __anext__(self) -> bytes:
        current = await self.file.tell()
        if current >= self.end:
            raise StopAsyncIteration()
        # Never read past the end of the configured range.
        read_size = min(self.buffer_size, self.end - current)
        chunk = await self.file.read(read_size)

        if chunk:
            return chunk
        else:
            raise StopAsyncIteration()

    async def convert_to_sequence(self) -> bytes:
        result = bytearray()
        async with self as response:
            async for data in response:
                result.extend(data)
        return bytes(result)

    async def make_conditional(
        self,
        begin: int,
        end: Optional[int],
        max_partial_size: Optional[int]=None,
    ) -> int:
        """Restrict the body to the given range and return the file size.

        Raises RequestRangeNotSatisfiable (leaving the current range
        untouched) if the range does not fit the file.
        """
        self.begin, self.end = _resolve_range(begin, end, self.size, max_partial_size)
        return self.size


class Response(_BaseRequestResponse, JSONMixin):
    """This class represents a response.

    It can be subclassed and the subclassed used in preference by
    replacing the :attr:`~quart.Quart.response_class` with your
    subclass.

    Attributes:
        automatically_set_content_length: If False the content length
            header must be provided.
        default_status: The status code to use if not provided.
        default_mimetype: The mimetype to use if not provided.
        implicit_sequence_conversion: Implicitly buffer the response
            into a data body in the get_data method, to allow multiple
            iterations.
    """

    automatically_set_content_length = True
    default_status = 200
    default_mimetype = 'text/html'
    data_body_class = DataBody
    file_body_class = FileBody
    implicit_sequence_conversion = True
    iterable_body_class = IterableBody

    def __init__(
        self,
        response: Union[ResponseBody, AnyStr, Iterable],
        status: Optional[int]=None,
        headers: Optional[Union[dict, CIMultiDict, Headers]]=None,
        mimetype: Optional[str]=None,
        content_type: Optional[str]=None,
        *,
        timeout: Optional[int]=None,
    ) -> None:
        """Create a response object.

        The response itself can be a chunk of data or a
        iterable/generator of data chunks.

        The Content-Type can either be specified as a mimetype or
        content_type header or omitted to use the
        :attr:`default_mimetype`.

        Arguments:
            response: The response data or iterable over the data.
            status: Status code of the response.
            headers: Headers to attach to the response.
            mimetype: Mimetype of the response.
            content_type: Content-Type header value.
            timeout: Optional argument to specify timeout when sending
                response data.

        Attributes:
            response: An iterable of the response bytes-data.

        Raises:
            ValueError: If ``status`` cannot be converted to an integer.
        """
        super().__init__(headers)
        self.timeout = timeout

        if status is None:
            status = self.default_status
        try:
            self.status_code = int(status)
        except ValueError as error:
            raise ValueError('Quart does not support non-integer status values') from error

        if content_type is None:
            if mimetype is None and 'content-type' not in self.headers:
                mimetype = self.default_mimetype
            if mimetype is not None:
                self.mimetype = mimetype

        if content_type is not None:
            self.headers['Content-Type'] = content_type

        self.response: ResponseBody
        if isinstance(response, ResponseBody):
            self.response = response
        elif isinstance(response, (str, bytes)):
            self.set_data(response)  # type: ignore
        else:
            self.response = self.iterable_body_class(response)

    async def get_data(self, raw: bool=True) -> AnyStr:
        """Return the body data.

        The body bytes are collected first and decoded once with
        :attr:`charset` when ``raw`` is False, so a multi-byte character
        split across chunks decodes correctly.
        """
        if self.implicit_sequence_conversion:
            self.response = self.data_body_class(await self.response.convert_to_sequence())
        buffer = bytearray()
        async with self.response as body:  # type: ignore
            async for data in body:
                buffer.extend(data)
        result = bytes(buffer)
        return result if raw else result.decode(self.charset)  # type: ignore

    def set_data(self, data: AnyStr) -> None:
        """Set the response data.

        This will encode using the :attr:`charset`.
        """
        if isinstance(data, str):
            bytes_data = data.encode(self.charset)
        else:
            bytes_data = data
        self.response = self.data_body_class(bytes_data)
        if self.automatically_set_content_length:
            self.content_length = len(bytes_data)

    async def make_conditional(
        self,
        request_range: Range,
        max_partial_size: Optional[int]=None,
    ) -> None:
        """Make the response conditional to the requested byte range.

        Arguments:
            request_range: The range as requested by the request.
            max_partial_size: The maximum length the server is willing
                to serve in a single response. Defaults to unlimited.

        Raises:
            RequestRangeNotSatisfiable: If the range uses a unit other
                than bytes, requests more than one range, or does not
                fit the body.
        """
        self.accept_ranges = "bytes"  # Advertise this ability
        if len(request_range.ranges) == 0:  # Not a conditional request
            return

        if request_range.units != "bytes" or len(request_range.ranges) > 1:
            from ..exceptions import RequestRangeNotSatisfiable
            raise RequestRangeNotSatisfiable()

        begin, end = request_range.ranges[0]
        if not hasattr(self.response, 'make_conditional'):
            # Bodies without range support (e.g. IterableBody) are buffered
            # into a data body, which can be sliced.
            self.response = self.data_body_class(await self.response.convert_to_sequence())
        complete_length = await self.response.make_conditional(  # type: ignore
            begin, end, max_partial_size,
        )
        self.content_length = self.response.end - self.response.begin  # type: ignore
        if self.content_length != complete_length:
            self.content_range = ContentRange(
                request_range.units,
                self.response.begin,  # type: ignore
                self.response.end - 1,  # type: ignore  # Content-Range end is inclusive
                complete_length,
            )
            self.status_code = 206

    async def freeze(self) -> None:
        """Freeze this object ready for pickling."""
        self.set_data((await self.get_data()))

    def set_cookie(  # type: ignore
        self,
        key: str,
        value: AnyStr='',
        max_age: Optional[Union[int, timedelta]]=None,
        expires: Optional[datetime]=None,
        path: str='/',
        domain: Optional[str]=None,
        secure: bool=False,
        httponly: bool=False,
    ) -> None:
        """Set a cookie in the response headers.

        The arguments are the standard cookie morsels and this is a
        wrapper around the stdlib SimpleCookie code.
        """
        if isinstance(value, bytes):
            value = value.decode()  # type: ignore
        cookie = create_cookie(key, value, max_age, expires, path, domain, secure, httponly)  # type: ignore # noqa: E501
        self.headers.add('Set-Cookie', cookie.output(header=''))

    def delete_cookie(self, key: str, path: str='/', domain: Optional[str]=None) -> None:
        """Delete a cookie (set to expire immediately)."""
        self.set_cookie(key, expires=datetime.utcnow(), max_age=0, path=path, domain=domain)

    async def add_etag(self, overwrite: bool=False, weak: bool=False) -> None:
        """Set an ETag from the MD5 of the body, unless one is already set.

        With ``overwrite`` True an existing ETag is replaced.
        """
        if overwrite or 'etag' not in self.headers:
            self.set_etag(md5((await self.get_data())).hexdigest(), weak)  # type: ignore

    def get_etag(self) -> Tuple[Optional[str], Optional[bool]]:
        """Return ``(etag, weak)`` from the ETag header, or ``(None, None)`` if absent."""
        etag = self.headers.get('ETag')
        if etag is None:
            return None, None
        else:
            weak = False
            if etag.upper().startswith('W/'):
                weak = True
                etag = etag[2:]
            return etag.strip('"'), weak

    def set_etag(self, etag: str, weak: bool=False) -> None:
        """Set the ETag header, quoting the value and prefixing ``W/`` when weak."""
        if weak:
            self.headers['ETag'] = f"W/\"{etag}\""
        else:
            self.headers['ETag'] = f"\"{etag}\""

    @property
    def access_control(self) -> ResponseAccessControl:
        def on_update(value: ResponseAccessControl) -> None:
            self.access_control = value

        return ResponseAccessControl.from_headers(
            self.headers.get('Access-Control-Allow-Credentials', ''),
            self.headers.get('Access-Control-Allow-Headers', ''),
            self.headers.get('Access-Control-Allow-Methods', ''),
            self.headers.get('Access-Control-Allow-Origin', ''),
            self.headers.get('Access-Control-Expose-Headers', ''),
            self.headers.get('Access-Control-Max-Age', ''),
            on_update=on_update,
        )

    @access_control.setter
    def access_control(self, value: ResponseAccessControl) -> None:
        max_age = value.max_age
        if max_age is None:
            self.headers.pop('Access-Control-Max-Age', None)
        else:
            # Header values are strings, as in the other setters.
            self.headers['Access-Control-Max-Age'] = str(max_age)
        if value.allow_credentials:
            self.headers['Access-Control-Allow-Credentials'] = 'true'
        else:
            self.headers.pop('Access-Control-Allow-Credentials', None)
        self._set_or_pop_header('Access-Control-Allow-Headers', value.allow_headers.to_header())
        self._set_or_pop_header('Access-Control-Allow-Methods', value.allow_methods.to_header())
        self._set_or_pop_header('Access-Control-Allow-Origin', value.allow_origin.to_header())
        self._set_or_pop_header('Access-Control-Expose-Headers', value.expose_headers.to_header())

    @property
    def accept_ranges(self) -> Optional[str]:
        return self.headers.get('Accept-Ranges')

    @accept_ranges.setter
    def accept_ranges(self, value: str) -> None:
        self.headers['Accept-Ranges'] = value

    @property
    def age(self) -> Optional[int]:
        """The Age header as whole seconds, or None if missing or invalid."""
        try:
            value = int(self.headers.get('Age', ''))
        except (TypeError, ValueError):
            return None
        # Age is a non-negative delta-seconds value; 0 is valid.
        return value if value >= 0 else None

    @age.setter
    def age(self, value: Union[int, timedelta]) -> None:
        if isinstance(value, timedelta):
            # delta-seconds must be an integer, not e.g. "10.0".
            self.headers['Age'] = str(int(value.total_seconds()))
        else:
            self.headers['Age'] = str(value)

    @property
    def allow(self) -> HeaderSet:
        def on_update(header_set: HeaderSet) -> None:
            self.allow = header_set

        return HeaderSet.from_header(self.headers.get('Allow', ''), on_update=on_update)

    @allow.setter
    def allow(self, value: HeaderSet) -> None:
        self._set_or_pop_header('Allow', value.to_header())

    @property
    def cache_control(self) -> ResponseCacheControl:
        def on_update(cache_control: ResponseCacheControl) -> None:
            self.cache_control = cache_control

        return ResponseCacheControl.from_header(self.headers.get('Cache-Control', ''), on_update)  # type: ignore # noqa: E501

    @cache_control.setter
    def cache_control(self, value: ResponseCacheControl) -> None:
        self._set_or_pop_header('Cache-Control', value.to_header())

    @property
    def content_encoding(self) -> Optional[str]:
        return self.headers.get('Content-Encoding')

    @content_encoding.setter
    def content_encoding(self, value: str) -> None:
        self.headers['Content-Encoding'] = value

    @property
    def content_language(self) -> HeaderSet:
        def on_update(header_set: HeaderSet) -> None:
            self.content_language = header_set

        return HeaderSet.from_header(self.headers.get('Content-Language', ''), on_update=on_update)

    @content_language.setter
    def content_language(self, value: HeaderSet) -> None:
        self._set_or_pop_header('Content-Language', value.to_header())

    @property
    def content_length(self) -> Optional[int]:
        try:
            return int(self.headers.get('Content-Length'))
        except (ValueError, TypeError):
            return None

    @content_length.setter
    def content_length(self, value: int) -> None:
        self.headers['Content-Length'] = str(value)

    @property
    def content_location(self) -> Optional[str]:
        return self.headers.get('Content-Location')

    @content_location.setter
    def content_location(self, value: str) -> None:
        self.headers['Content-Location'] = value

    @property
    def content_md5(self) -> Optional[str]:
        return self.headers.get('Content-MD5')

    @content_md5.setter
    def content_md5(self, value: str) -> None:
        self.headers['Content-MD5'] = value

    @property
    def content_range(self) -> ContentRange:
        def on_update(content_range: ContentRange) -> None:
            self.content_range = content_range

        return ContentRange.from_header(self.headers.get('Content-Range', ''), on_update)

    @content_range.setter
    def content_range(self, value: ContentRange) -> None:
        self._set_or_pop_header('Content-Range', value.to_header())

    @property
    def content_type(self) -> Optional[str]:
        return self.headers.get('Content-Type')

    @content_type.setter
    def content_type(self, value: str) -> None:
        self.headers['Content-Type'] = value

    @property
    def date(self) -> Optional[datetime]:
        return self._get_date_header('Date')

    @date.setter
    def date(self, value: datetime) -> None:
        # NOTE: datetime.timestamp() treats naive datetimes as local time.
        self.headers['Date'] = format_date_time(value.timestamp())

    @property
    def expires(self) -> Optional[datetime]:
        return self._get_date_header('Expires')

    @expires.setter
    def expires(self, value: datetime) -> None:
        # NOTE: datetime.timestamp() treats naive datetimes as local time.
        self.headers['Expires'] = format_date_time(value.timestamp())

    @property
    def last_modified(self) -> Optional[datetime]:
        return self._get_date_header('Last-Modified')

    @last_modified.setter
    def last_modified(self, value: datetime) -> None:
        # NOTE: datetime.timestamp() treats naive datetimes as local time.
        self.headers['Last-Modified'] = format_date_time(value.timestamp())

    @property
    def location(self) -> Optional[str]:
        return self.headers.get('Location')

    @location.setter
    def location(self, value: str) -> None:
        self.headers['Location'] = value

    @property
    def referrer(self) -> Optional[str]:
        return self.headers.get('Referer')

    @referrer.setter
    def referrer(self, value: str) -> None:
        self.headers['Referer'] = value

    @property
    def retry_after(self) -> Optional[datetime]:
        """The Retry-After header as a datetime, or None if missing or invalid.

        NOTE(review): the delta-seconds branch returns a naive UTC datetime,
        while an HTTP-date is parsed by parsedate_to_datetime (which may
        return an aware one) — callers comparing both forms should confirm.
        """
        value = self.headers.get('Retry-After', '')
        if value.isdigit():
            return datetime.utcnow() + timedelta(seconds=int(value))
        try:
            return parsedate_to_datetime(value)
        except (TypeError, ValueError):
            return None

    @retry_after.setter
    def retry_after(self, value: Union[datetime, int]) -> None:
        if isinstance(value, datetime):
            self.headers['Retry-After'] = format_date_time(value.timestamp())
        else:
            self.headers['Retry-After'] = str(value)

    @property
    def vary(self) -> HeaderSet:
        def on_update(header_set: HeaderSet) -> None:
            self.vary = header_set

        return HeaderSet.from_header(self.headers.get('Vary', ''), on_update=on_update)

    @vary.setter
    def vary(self, value: HeaderSet) -> None:
        self._set_or_pop_header('Vary', value.to_header())

    async def _load_json_data(self) -> str:
        """Return the data after decoding."""
        return await self.get_data(raw=False)

    def _get_date_header(self, key: str) -> Optional[datetime]:
        """Parse the HTTP-date stored in header ``key``, or return None.

        parsedate_to_datetime raises TypeError for unparseable input on
        older Pythons and ValueError from 3.10, so both mean "not a date".
        """
        try:
            return parsedate_to_datetime(self.headers.get(key, ''))
        except (TypeError, ValueError):
            return None

    def _set_or_pop_header(self, key: str, value: str) -> None:
        """Set header ``key`` to ``value``, or remove it when ``value`` is empty."""
        if value == '':
            self.headers.pop(key, None)
        else:
            self.headers[key] = value

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.status_code})"