# -*- coding: utf-8 -*-
# (c) 2009-2023 Martin Wendt and contributors; see WsgiDAV https://github.com/mar10/wsgidav
# Licensed under the MIT license: http://www.opensource.org/licenses/mit-license.php
"""
Tools that make it easier to implement custom WsgiDAV providers.
"""
import os
import stat
from wsgidav import util
from wsgidav.dav_provider import DAVCollection, DAVNonCollection
# Docstring markup format used in this module.
__docformat__ = "reStructuredText en"
# Module-level logger (used e.g. by FileResource to report missing files).
_logger = util.get_module_logger(__name__)
# ============================================================================
# VirtualCollection
# ============================================================================
class VirtualCollection(DAVCollection):
    """Abstract base class for collections that contain a list of static members.

    Member names are passed to the constructor.
    get_member() is implemented by calling self.provider.get_resource_inst().
    """

    def __init__(self, path, environ, display_info, member_name_list):
        """Create a collection with a fixed list of member names.

        Args:
            path: Resource path of this collection.
            environ: WSGI environment dictionary of the current request.
            display_info: Either a dict (returned verbatim by
                get_display_info()) or a string, which is wrapped as
                ``{"type": display_info}``.
            member_name_list: List of member names, returned verbatim by
                get_member_names().

        Raises:
            TypeError: If display_info is neither a string nor a dict, or
                member_name_list is not a list.
        """
        super().__init__(path, environ)
        if util.is_basestring(display_info):
            display_info = {"type": display_info}
        # Explicit checks instead of `assert`: asserts vanish under `python -O`,
        # and isinstance() also accepts dict/list subclasses.
        if not isinstance(display_info, dict):
            raise TypeError(
                "display_info must be a string or dict, not {}".format(
                    type(display_info).__name__
                )
            )
        if not isinstance(member_name_list, list):
            raise TypeError(
                "member_name_list must be a list, not {}".format(
                    type(member_name_list).__name__
                )
            )
        self.display_info = display_info
        self.member_name_list = member_name_list

    def get_display_info(self):
        """Return the display-info dict given to (or built by) the constructor."""
        return self.display_info

    def get_member_names(self):
        """Return the static list of member names given to the constructor."""
        return self.member_name_list

    def prevent_locking(self):
        """Return True, since we don't want to lock virtual collections."""
        return True

    def get_member(self, name):
        """Return the member resource *name*, resolved via the provider.

        The member path is this collection's path joined with *name*.
        NOTE(review): self.provider is presumably set up by the DAVCollection
        base class (not visible here).
        """
        return self.provider.get_resource_inst(
            util.join_uri(self.path, name), self.environ
        )
# ============================================================================
# _VirtualNonCollection base class
# ============================================================================
class _VirtualNonCollection(DAVNonCollection):
    """Common base for the virtual (non-collection) resources in this module.

    Every property getter has a neutral default: no length, type, dates or
    ETag, and no range support. Subclasses override what they can provide.
    """

    def __init__(self, path, environ):
        super().__init__(path, environ)

    def get_content_length(self):
        """Unknown by default."""
        return None

    def get_content_type(self):
        """Unknown by default."""
        return None

    def get_creation_date(self):
        """Unknown by default."""
        return None

    def get_display_name(self):
        """Use the resource name as display name."""
        return self.name

    def get_display_info(self):
        """Must be provided by subclasses."""
        raise NotImplementedError

    def get_etag(self):
        """No ETag by default."""
        return None

    def get_last_modified(self):
        """Unknown by default."""
        return None

    def support_ranges(self):
        """Range requests are not supported by default."""
        return False

    # handle_delete / handle_move / handle_copy are intentionally not
    # overridden here, so the DAVNonCollection base-class behavior applies.
# ============================================================================
# VirtualTextResource
# ============================================================================
class VirtualTextResource(_VirtualNonCollection):
    """A virtual file, containing a string."""

    def __init__(self, path, environ, content, display_name=None, display_type=None):
        """Create a virtual text file.

        Args:
            path: Resource path.
            environ: WSGI environment dictionary of the current request.
            content: Text served by get_content() (wrapped in util.StringIO).
            display_name: Optional display name; defaults to the resource name.
            display_type: Optional "type" reported by get_display_info();
                defaults to "Virtual info file".
        """
        super().__init__(path, environ)
        self.content = content
        self.display_name = display_name
        self.display_type = display_type

    def get_content_length(self):
        """Return the length of the content (0 if content is None).

        Measures the string directly instead of creating a stream and reading
        it back, which gave the same number at the cost of an extra copy.
        """
        # NOTE(review): this counts characters, not bytes; for non-ASCII text
        # it differs from the encoded body length -- confirm how the response
        # body is encoded before relying on it as an HTTP Content-Length.
        if self.content is None:
            return 0
        return len(self.content)

    def get_content_type(self):
        """Return text/plain for '*.txt' names, otherwise text/html."""
        if self.name.endswith(".txt"):
            return "text/plain"
        return "text/html"

    def get_display_name(self):
        """Return the explicit display name, falling back to the resource name."""
        return self.display_name or self.name

    def get_display_info(self):
        """Return the display type passed to the constructor, or the default."""
        # display_type used to be stored but ignored; honour it like display_name.
        return {"type": self.display_type or "Virtual info file"}

    def prevent_locking(self):
        """Virtual files cannot be locked."""
        return True

    def get_content(self):
        """Return a fresh in-memory stream over the content."""
        return util.StringIO(self.content)
# ============================================================================
# FileResource
# ============================================================================
class FileResource(_VirtualNonCollection):
    """Expose an existing file from the local file system as a resource."""

    BUFFER_SIZE = 8192

    def __init__(self, path, environ, file_path):
        """Wrap *file_path* as the resource *path*.

        A missing file is only logged as an error; the object is still created.
        """
        if not os.path.exists(file_path):
            _logger.error("FileResource({!r}) does not exist.".format(file_path))
        super().__init__(path, environ)
        self.file_path = file_path

    def get_content_length(self):
        """Return the file size in bytes."""
        return os.stat(self.file_path)[stat.ST_SIZE]

    def get_content_type(self):
        """Guess the MIME type from the path; non-regular files get text/html."""
        if os.path.isfile(self.file_path):
            return util.guess_mime_type(self.file_path)
        return "text/html"

    def get_creation_date(self):
        """Return the file's ST_CTIME timestamp (whole seconds)."""
        # NOTE(review): on Unix, ctime is the inode change time rather than the
        # creation time -- confirm this is acceptable for the creation date.
        return os.stat(self.file_path)[stat.ST_CTIME]

    def get_display_info(self):
        """Report this resource as a plain file."""
        return {"type": "File"}

    def get_last_modified(self):
        """Return the file's modification timestamp (whole seconds)."""
        # Tuple-style stat indexing is kept so the value stays an integer.
        return os.stat(self.file_path)[stat.ST_MTIME]

    def get_content(self):
        """Open the file for reading in binary mode.

        Content is always served as binary (GC issue 57), regardless of its MIME
        type. The returned file object is not closed here; presumably the caller
        closes it.
        """
        return open(self.file_path, "rb", buffering=FileResource.BUFFER_SIZE)
# ============================================================================
# Resolvers
# ============================================================================
class DAVResolver:
    """Base class for resolvers that map a request path to a DAVResource.

    Subclasses implement resolve(), returning the resource or None if the
    path is not found.
    """

    def __init__(self, parent_resolver, name):
        self.name = name
        self.parent_resolver = parent_resolver

    def resolve(self, script_name, path_info, environ):
        """Return the DAVResource for the request, or None; must be overridden."""
        raise NotImplementedError