Source code for wsgidav.mw.cors

# -*- coding: utf-8 -*-
# (c) 2009-2023 Martin Wendt and contributors; see WsgiDAV https://github.com/mar10/wsgidav
# Licensed under the MIT license:
# http://www.opensource.org/licenses/mit-license.php
"""
WSGI middleware used for CORS support (optional).

Respond to CORS preflight OPTIONS requests and inject CORS headers into regular responses.
"""
from wsgidav import util
from wsgidav.mw.base_mw import BaseMiddleware

#: Markup format declared for the docstrings in this module.
__docformat__ = "reStructuredText"

#: Module-level logger, obtained via ``util.get_module_logger(__name__)``.
_logger = util.get_module_logger(__name__)


class Cors(BaseMiddleware):
    """WSGI middleware that answers CORS preflight requests and adds CORS headers.

    Options are read from the ``cors`` section of the configuration:

    ``allow_origin``
        ``"*"``, a single origin string, or an iterable of origin strings.
    ``allow_headers``, ``allow_methods``, ``expose_headers``
        Passed through ``util.to_set()`` and joined with ``","``.
    ``allow_credentials``
        If truthy, ``Access-Control-Allow-Credentials: true`` is sent.
    ``max_age``
        Converted to ``int`` and sent as ``Access-Control-Max-Age`` on
        preflight responses.
    ``add_always``
        A dict of extra header name/value pairs added to every granted
        response.
    """

    def __init__(self, wsgidav_app, next_app, config):
        """Parse the ``cors`` config section and precompute response headers.

        Args:
            wsgidav_app: passed through to ``BaseMiddleware``.
            next_app: the next WSGI application in the stack (called by
                ``__call__`` for non-preflight requests).
            config: configuration mapping; only its ``cors`` entry is read here.

        Raises:
            ValueError: if ``cors.add_always`` is set but is not a dict.
        """
        super().__init__(wsgidav_app, next_app, config)
        opts = config.get("cors", None)
        if opts is None:
            opts = {}

        allow_origins = opts.get("allow_origin")
        if isinstance(allow_origins, str):
            allow_origins = allow_origins.strip()
            if allow_origins != "*":
                allow_origins = [allow_origins]
        elif allow_origins:
            allow_origins = [ao.strip() for ao in allow_origins]
        else:
            # Not configured: grant no origin. Normalizing to a list keeps
            # the `origin in self.allow_origins` test in __call__ from raising
            # TypeError on None.
            allow_origins = []

        allow_headers = ",".join(util.to_set(opts.get("allow_headers")))
        allow_methods = ",".join(util.to_set(opts.get("allow_methods")))
        expose_headers = ",".join(util.to_set(opts.get("expose_headers")))
        allow_credentials = opts.get("allow_credentials", False)
        max_age = opts.get("max_age")
        always_headers = opts.get("add_always")

        if allow_credentials and allow_origins == "*":
            # Browsers reject credentialed CORS responses that carry
            # 'Access-Control-Allow-Origin: *'.
            _logger.warning(
                "cors.allow_credentials is set together with cors.allow_origin '*'; "
                "browsers will reject credentialed requests"
            )

        # Headers sent on every granted response (preflight and non-preflight).
        add_always = []
        if allow_credentials:
            add_always.append(("Access-Control-Allow-Credentials", "true"))
        if always_headers:
            if not isinstance(always_headers, dict):
                raise ValueError(f"cors.add_always must be a dict: {always_headers}")
            add_always.extend(always_headers.items())
        if expose_headers:
            # Browsers evaluate Access-Control-Expose-Headers on the actual
            # (non-preflight) response, so it must be part of the
            # non-preflight set. It is harmless on preflight responses.
            add_always.append(("Access-Control-Expose-Headers", expose_headers))

        add_non_preflight = add_always[:]

        # Headers that only make sense when answering a preflight request.
        add_preflight = add_always[:]
        if allow_headers:
            add_preflight.append(("Access-Control-Allow-Headers", allow_headers))
        if allow_methods:
            add_preflight.append(("Access-Control-Allow-Methods", allow_methods))
        if max_age:
            add_preflight.append(("Access-Control-Max-Age", str(int(max_age))))

        self.non_preflight_headers = add_non_preflight
        self.preflight_headers = add_preflight
        #: Either '*' or a list of origins
        self.allow_origins = allow_origins

    def __repr__(self):
        allow_origin = self.get_config("cors.allow_origin", None)
        return f"{self.__module__}.{self.__class__.__name__}({allow_origin})"

    def is_disabled(self):
        """Optionally return True to skip this module on startup."""
        return not self.get_config("cors.allow_origin", False)

    def __call__(self, environ, start_response):
        """Handle one WSGI request.

        A preflight request is an ``OPTIONS`` request that carries an
        ``Access-Control-Request-Method`` header. It is answered directly with
        ``204 No Content``, and the CORS headers are included only if the
        origin is allowed.

        Every other request is forwarded to ``next_app``. If the origin is
        allowed (or ``allow_origin`` is ``"*"``), the CORS headers are applied
        to the response headers via ``util.update_headers_in_place``.
        """
        method = environ["REQUEST_METHOD"].upper()
        origin = environ.get("HTTP_ORIGIN")
        ac_req_meth = environ.get("HTTP_ACCESS_CONTROL_REQUEST_METHOD")
        ac_req_headers = environ.get("HTTP_ACCESS_CONTROL_REQUEST_HEADERS")

        acao_headers = None
        if self.allow_origins == "*":
            acao_headers = [("Access-Control-Allow-Origin", "*")]
        elif origin in self.allow_origins:
            acao_headers = [
                ("Access-Control-Allow-Origin", origin),
                ("Vary", "Origin"),
            ]

        if acao_headers:
            _logger.debug(
                f"Granted CORS {method} {environ['PATH_INFO']!r} "
                f"{ac_req_meth!r}, headers: {ac_req_headers}, origin: {origin!r}"
            )
        elif origin is not None:
            # Deny (a preflight still gets a 2xx response, just without CORS
            # headers). Requests without an Origin header are not CORS
            # requests, so they are not worth a warning.
            _logger.warning(
                f"Denied CORS {method} {environ['PATH_INFO']!r} "
                f"{ac_req_meth!r}, headers: {ac_req_headers}, origin: {origin!r}"
            )

        is_preflight = method == "OPTIONS" and ac_req_meth is not None

        # Handle preflight request
        if is_preflight:
            # Always return 2xx, but only add Access-Control-Allow-Origin etc.
            # if Origin is allowed
            resp_headers = [
                ("Content-Length", "0"),
                ("Date", util.get_rfc1123_time()),
            ]
            if acao_headers:
                resp_headers += acao_headers + self.preflight_headers
            start_response("204 No Content", resp_headers)
            return [b""]

        # non_preflight CORS request
        def wrapped_start_response(status, headers, exc_info=None):
            if acao_headers:
                util.update_headers_in_place(
                    headers,
                    acao_headers + self.non_preflight_headers,
                )
            # PEP 3333: start_response must return the write() callable, so
            # pass the server's return value through to the application.
            return start_response(status, headers, exc_info)

        return self.next_app(environ, wrapped_start_response)