Source code for shapeflow.core

import abc
import threading
from typing import Callable, Dict, List, Tuple, Type, Optional, _GenericAlias, Any  # type: ignore
import collections
from contextlib import contextmanager

import uuid

from shapeflow import get_logger
from shapeflow.util.meta import bind


log = get_logger(__name__)


# todo: move up to shapeflow
[docs]class RootException(Exception): """All ``shapeflow`` exceptions should be subclasses of this one. Automatically logs the exception class and message at the ``ERROR`` level. """ msg = '' """The message to log """ def __init__(self, *args): # https://stackoverflow.com/questions/49224770/ # if no arguments are passed set the first positional argument # to be the default message. To do that, we have to replace the # 'args' tuple with another one, that will only contain the message. # (we cannot do an assignment since tuples are immutable) if not (args): args = (self.msg,) log.error(self.__class__.__name__ + ': ' + ' '.join(args)) super(Exception, self).__init__(*args)
[docs]class DispatchingError(RootException): """An error dispatching a method call or exposing an endpoint. """
[docs]class EnforcedStr(str): """A string that is enforced to be one of several options. Works like a dynamic ``Enum`` -- options can be added at runtime. """ _options: List[str] = [''] _descriptions: Dict[str, str] = {} _str: str _default: Optional[str] = None def __init__(self, string: str = None): super().__init__() if string is not None: if string not in self.options: if string: log.warning(f"Illegal {self.__class__.__name__} '{string}', " f"should be one of {self.options}. " f"Defaulting to '{self.default}'.") self._str = str(self.default) else: self._str = str(string) else: self._str = str(self.default) def __repr__(self): return f"<{self.__class__.__name__} '{self._str}'>" def __str__(self): return str(self._str) # Make SURE it's a string :( def __eq__(self, other): if hasattr(other, '_str'): return self._str == other._str elif isinstance(other, str): return self._str == other else: return False @property def options(self): """The accepted options """ return self._options @property def descriptions(self): """The descriptions of each option """ return self._descriptions @property def describe(self): """The description of the currently selected option """ return self.descriptions[self._str] @property def default(self): """The default option for this :class:`~shapeflow.core.EnforcedStr` """ if self._default is not None: return self._default else: return self._options[0]
[docs] @classmethod def set_default(cls, value: 'EnforcedStr') -> None: """Explicitly sets the default. Parameters ---------- value : EnforcedStr The default value to set """ if isinstance(value, cls) and value in cls().options: log.debug(f"setting default of '{cls.__name__}' to '{value}'") cls._default = value else: raise ValueError( f"cannot set default of '{cls.__name__}' to '{value}'" )
def __hash__(self): # todo: why? return hash(str(self))
[docs] @classmethod def __modify_schema__(cls, field_schema): """Modify ``pydantic`` schema to include default, descriptions and act as an ``Enum`` """ # pydantic temp = cls() field_schema.update( enum=temp.options, default=temp.default, descriptions=temp.descriptions )
class _Streaming(EnforcedStr): _options = ['off', 'image', 'json', 'plain'] stream_off = _Streaming('off') stream_image = _Streaming('image') stream_json = _Streaming('json') stream_plain = _Streaming('plain')
[docs]class Endpoint(object): """An endpoint for an internal method. """ _name: str _registered: bool _signature: Type[Callable] _method: Optional[Callable] _streaming: _Streaming _update: Optional[Callable[['Endpoint'], None]] def __init__(self, signature: _GenericAlias, streaming: _Streaming = stream_off): # todo: type Callable[] correctly try: assert signature.__origin__ == collections.abc.Callable assert hasattr(signature, '__args__') except Exception: raise TypeError('Invalid Endpoint signature') self._method = None self._update = None self._registered = False self._signature = signature self._streaming = streaming
[docs] def compatible(self, method: Callable) -> bool: """Checks whether a method is compatible with the endpoint's signature Parameters ---------- method : Callable Any method or function Returns ------- bool ``True`` if the method is compatible, ``False`` if it isn't. """ if hasattr(method, '__annotations__'): args: List = [] for arg in self.signature: if arg == type(None): arg = None args.append(arg) # Don't be too pedantic unannotated None-type return return tuple(method.__annotations__.values()) == tuple(args) else: return False
[docs] def expose(self): """ Expose a method at this endpoint. Used as a decorator:: @endpoint.expose() def some_method(): pass """ def wrapper(method): if self._method is not None: log.debug( # todo: add traceback f"Exposing '{method.__qualname__}' at endpoint '{self.name}' will override " f"previously exposed method '{self._method.__qualname__}'." ) # todo: keep in mind we're also marking the methods themselves if not self.compatible(method): raise DispatchingError( f"Cannot expose '{method.__qualname__}' at endpoint '{self.name}'. " f"Incompatible signature: {method.__annotations__} vs. {self.signature}" ) method._endpoint = self self._method = method if self._update is not None: self._update(self) return method return wrapper
@property def method(self) -> Optional[Callable]: """The method exposed at this endpoint. Can be ``None`` """ return self._method @property def signature(self) -> tuple: """The signature of this endpoint. """ return self._signature.__args__ # type: ignore @property def streaming(self) -> _Streaming: """What or whether this endpoint streams. """ return self._streaming @property def registered(self) -> bool: """Whether this endpoint is registered. """ return self._registered @property def name(self) -> str: """The name of this endpoint. Taken from its attribute name in the object where it is registered. """ try: return self._name except AttributeError: return ''
[docs] def register(self, name: str, callback: Callable[['Endpoint'], None]): """Register the endpoint in some other object. """ self._registered = True self._name = name self._update = callback
[docs]class Dispatcher(object): # todo: these should also register specific instances & handle dispatching? """Dispatches requests to :class:`shapeflow.core.Endpoint` objects. """ _endpoints: Tuple[Endpoint, ...] #type: ignore _dispatchers: Tuple['Dispatcher', ...] _name: str _parent: Optional['Dispatcher'] _address_space: Dict[str, Optional[Callable]] _update: Optional[Callable[['Dispatcher'], None]] _instance: Optional[object] def __init__(self, instance: object = None): self._update = None if instance is not None: self._set_instance(instance) else: self._address_space = {} self._endpoints = tuple() self._dispatchers = tuple() @property def name(self) -> str: """The name of this dispatcher. """ try: return self._name except AttributeError: return self.__class__.__name__ @property def dispatchers(self) -> Tuple['Dispatcher', ...]: """The dispatchers nested in this dispatcher. """ return self._dispatchers @property def endpoints(self) -> Tuple[Endpoint, ...]: """The endpoints contained in this dispatcher. """ return self._endpoints @property def address_space(self) -> Dict[str, Optional[Callable]]: """The address-method mapping of this dispatcher. """ return self._address_space def _set_instance(self, instance: object): self._instance = instance self._address_space = {} self._endpoints = tuple() self._dispatchers = tuple() for attr, val in self.__class__.__dict__.items(): if isinstance(val, Endpoint): # todo: also register dispatchers self._add_endpoint(attr, val) elif isinstance(val, Dispatcher): self._add_dispatcher(attr, val) def _register(self, name: str, callback: Callable[['Dispatcher'], None]): """Register this dispatcher within another dispatcher. """ self._update = callback self._name = name def _add_endpoint(self, name: str, endpoint: Endpoint): endpoint.register(name=name, callback=self._update_endpoint) if endpoint.method is not None and self._instance is not None: method = bind(self._instance, endpoint.method) else: method = endpoint.method self._address_space[name] = method self._endpoints = tuple(list(self._endpoints) + [endpoint]) setattr(self, name, endpoint) if self._update is not None: self._update(self) def _add_dispatcher(self, name: str, dispatcher: 'Dispatcher'): dispatcher._register(name=name, callback=self._update_dispatcher) self._address_space.update({ "/".join([name, address]): method for address, method in dispatcher.address_space.items() if method is not None and "__" not in address }) self._dispatchers = tuple(list(self._dispatchers) + [dispatcher]) setattr(self, name, dispatcher) if self._update is not None: self._update(self) def _update_endpoint(self, endpoint: Endpoint) -> None: self._address_space.update({ endpoint.name: endpoint.method }) if self._update is not None: self._update(self) def _update_dispatcher(self, dispatcher: 'Dispatcher') -> None: self._address_space.update({ # todo: this doesn't take into account deleted keys! "/".join([dispatcher.name, address]): method for address, method in dispatcher.address_space.items() if method is not None and "__" not in address }) if self._update is not None: self._update(self)
[docs] def dispatch(self, address: str, *args, **kwargs) -> Any: """Dispatch a request to a method. Parameters ---------- address : str The address to dispatch to args Any positional arguments to pass on to the method kwargs Any keyword arguments to pass on to the method Returns ------- Any Whatever the method returns. """ try: method = self.address_space[address] if method is not None: # todo: consider doing some type checking here, args/kwargs vs. method._endpoint.signature return method(*args, **kwargs) except KeyError: raise DispatchingError( f"'{self.name}' can't dispatch address '{address}'." )
def dispatch_async(self, address: str, *args, **kwargs) -> None: def _dispatch(): self.dispatch(address, *args, **kwargs) threading.Thread(target=_dispatch).start() def __getitem__(self, item): return getattr(self, item)
[docs]class Described(object): """A class with a description. This description is taken from the first line of the docstring if there is one or set to the name of the class if there isn't. """ @classmethod def _description(cls): if cls.__doc__ is not None: return cls.__doc__.split('\n')[0] else: return cls.__name__
[docs]class Lockable(object): """Wrapper around :class:`threading.Lock` & :class:`threading.Event` Defines a :class:`~shapeflow.core.Lockable.lock` context to handle locking and unlocking along with a ``_cancel`` and ``_error`` events to communicate with :class:`~shapeflow.core.Lockable` objects from other threads. Doesn't need to initialize; lock & events are created when they're needed. """ _lock: threading.Lock _cancel: threading.Event _error: threading.Event @property def _ensure_lock(self) -> threading.Lock: try: return self._lock except AttributeError: self._lock = threading.Lock() return self._lock @property def _ensure_cancel(self) -> threading.Event: try: return self._cancel except AttributeError: self._cancel = threading.Event() return self._cancel @property def _ensure_error(self) -> threading.Event: try: return self._error except AttributeError: self._error = threading.Event() return self._error
[docs] @contextmanager def lock(self): """Locking context. If ``_lock`` event doesn't exist yet it is instantiated first. Upon exiting the context, the :class:`threading.Lock` object is compared to the original to ensure that no shenanigans took place. """ log.vdebug(f"Acquiring lock {self}...") locked = self._ensure_lock.acquire() original_lock = self._lock log.vdebug(f"Acquired lock {self}") try: log.vdebug(f"Locking {self}") yield locked finally: log.vdebug(f"Unlocking {self}") # Make 'sure' nothing weird happened to self._lock assert self._lock == original_lock self._lock.release()
[docs] def cancel(self): """Sets the ``_cancel`` event. If ``_cancel`` event doesn't exist yet it is instantiated first. """ self._ensure_cancel.set()
[docs] def error(self): """Sets the ``_error`` event. If ``_error`` event doesn't exist yet it is instantiated first. """ self._ensure_error.set()
@property def canceled(self) -> bool: """Returns ``True`` if the ``_cancel`` event is set. If ``_cancel`` event doesn't exist yet it is instantiated first. """ return self._ensure_cancel.is_set() @property def errored(self) -> bool: """Returns ``True`` if the ``_error`` event is set. If ``_error`` event doesn't exist yet it is instantiated first. """ return self._ensure_error.is_set()
[docs] def clear_cancel(self): """Clears the ``_cancel`` event. If ``_cancel`` event doesn't exist yet it is instantiated first. """ return self._ensure_cancel.clear()
[docs] def clear_error(self): """Clears the ``_error`` event. If ``_error`` event doesn't exist yet it is instantiated first. """ return self._ensure_error.clear()
[docs]class RootInstance(Lockable): # todo: basically deprecated _id: str def _set_id(self, id: str): self._id = id @property def id(self): return self._id