import abc
import copy
from typing import Optional, Union, Type, Dict, Mapping, List
import numpy as np
from pydantic import BaseModel
from shapeflow import get_logger, __version__
from shapeflow.core import EnforcedStr, Described
from shapeflow.util import ndarray2str
log = get_logger(__name__)

# Metadata tags written into / stripped from configuration dicts
VERSION: str = 'config_version'
CLASS: str = 'config_class'
TAGS = (VERSION, CLASS)

# Metadata file extension
__meta_ext__ = '.meta'

# Excel sheet name used for metadata
__meta_sheet__ = 'metadata'
# todo: move up to shapeflow.core
class Factory(EnforcedStr, metaclass=abc.ABCMeta):  # todo: add a _class & issubclass check
    """An enforced string which maps its options to types.

    Included types should be subclasses of :class:`~shapeflow.core.Described`
    in order to generate descriptions for all options.

    Subclasses should declare their own ``_mapping`` (and usually ``_type``).
    If they don't, :meth:`extend` gives them a private copy of the inherited
    mapping on first use, so registrations never leak into ``Factory`` itself
    or into sibling factories.
    """
    # Maps option strings to the types they stand for.
    _mapping: Mapping[str, Type[Described]] = {}
    # Explicit default option; when ``None`` the first mapped key is used.
    _default: Optional[str] = None
    # Every type added through ``extend`` must be a subclass of this.
    _type: Type[Described] = Described

    def get(self) -> Type[Described]:
        """Get the type associated with the current string.

        Raises
        ------
        ValueError
            If the current string is not mapped to a type.
        """
        try:
            return self._mapping[self._str]
        except KeyError:
            raise ValueError(f"Factory {self.__class__.__name__} doesn't map "
                             f"{self._str} to a class.") from None

    @classmethod
    def get_str(cls, mapped_value: Type[Described]) -> Optional[str]:
        """Get the string for a specific type.

        Returns ``None`` if ``mapped_value`` is not mapped by this factory.
        """
        for key, value in cls._mapping.items():
            if mapped_value == value:
                return key
        return None

    @property
    def options(self) -> List[str]:
        """The options for this factory.
        """
        return list(self._mapping.keys())

    @property
    def descriptions(self) -> Dict[str, str]:
        """The descriptions for this factory.
        """
        return {k: v._description() for k, v in self._mapping.items()}

    @property
    def default(self) -> Optional[str]:
        """The default for this factory.

        ``_default`` if it is set, otherwise the first mapped option,
        otherwise ``None`` (empty mapping).
        """
        if self._default is not None:
            return self._default
        # ``_mapping`` is always present (declared on Factory), so only
        # emptiness needs to be checked.
        return next(iter(self._mapping), None)

    @classmethod
    def extend(cls, key: str, extension: Type[Described]):
        """Add a new type to this factory.

        Used to dynamically add options e.g. for including plugins.

        Raises
        ------
        TypeError
            If ``extension`` is not a subclass of this factory's ``_type``.
        """
        if '_mapping' not in vars(cls):
            # ``hasattr`` would always be True here because Factory declares
            # ``_mapping``; updating the inherited dict would register ``key``
            # on the parent and on every subclass sharing it. Give this class
            # its own copy (keeping the options it inherited so far).
            cls._mapping = dict(cls._mapping)
        assert isinstance(cls._mapping, dict)  # to put MyPy at ease
        if issubclass(extension, cls._type):
            log.debug(f"Extending Factory '{cls.__name__}' "
                      f"with {{'{key}': {extension}}}")
            cls._mapping.update({key: extension})
        else:
            raise TypeError(f"Attempting to extend Factory '{cls.__name__}' "
                            f"with incompatible class {extension.__name__}")

    @abc.abstractmethod
    def config_schema(self) -> dict:
        """The ``pydantic`` configuration schema for
        the members of this factory
        """
class extend(object):  # todo: can this be a function instead? look at the @dataclass decorator, something weird is going on there with * and /
    """Decorator to extend :class:`~shapeflow.core.config.Factory` classes.

    The decorated class is registered with the factory under its own name,
    or, if ``module_as_key`` is set, under the last dotted component of the
    module it is defined in. The class itself is returned unchanged.

    Usage::

        from shapeflow.core.config import extend

        @extend(SomeFactory)
        class SomeClass:
            pass
    """
    _factory: Type[Factory]
    module_as_key: bool

    def __init__(self, factory: Type[Factory], module_as_key: bool = False):
        self._factory = factory
        self.module_as_key = module_as_key

    def __call__(self, cls):
        key: str = (
            cls.__module__.rsplit('.', 1)[-1]
            if self.module_as_key
            else cls.__name__
        )
        self._factory.extend(key, cls)
        return cls
def untag(d: dict) -> dict:
    """Strip the metadata tags from a configuration ``dict``, in place.

    Parameters
    ----------
    d : dict
        Any configuration dict

    Returns
    -------
    dict
        The same ``dict`` object, with the class and version entries
        removed (entries that are absent are ignored)
    """
    for tag in TAGS:
        d.pop(tag, None)
    return d
class BaseConfig(BaseModel, Described):
    """Abstract configuration class.

    All other configuration classes should derive from this one.

    Usage, where ``SomeConfig`` is a subclass of ``BaseConfig``::

        # instantiating
        config = SomeConfig()
        config = SomeConfig(field1=1.0, field2='text')
        config = SomeConfig(**dict_with_fields_and_values)

        # updating
        config(field1=1.0, field2='text')
        config(**dict_with_fields_and_values)

        # saving
        dict_with_fields_and_values = config.to_dict()

    When writing ``BaseConfig`` subclasses, use the
    :class:`~shapeflow.core.config.extend` decorator to make your
    configuration class accessible through the
    :class:`~shapeflow.core.config.ConfigType` factory. Configuration fields
    are declared as ``pydantic.Field`` instances and must be type-annotated
    for type resolution to work properly.

    Example::

        from pydantic import Field
        from shapeflow.core.config import BaseConfig, ConfigType, extend

        @extend(ConfigType)
        class SomeConfig(BaseConfig):
            field1: int = Field(default=42)
            field2: SomeNestedConfig = Field(default_factory=SomeNestedConfig)
    """

    class Config:
        """``pydantic`` configuration class
        """
        arbitrary_types_allowed = False
        # NOTE(review): pydantic spells this option ``use_enum_values``; as
        # written this key is most likely ignored. Renaming it would change
        # how Enum fields are stored, so confirm against usages first.
        use_enum_value = True
        # Re-run validators on attribute assignment (relied on by __call__).
        validate_assignment = True
        json_encoders = {
            np.ndarray: list,
            EnforcedStr: str,
        }

    @classmethod
    def _resolve_enforcedstr(cls, value, field):
        """Resolve :class:`~shapeflow.core.EnforcedStr` objects
        from regular ``str`` objects. To be used in ``pydantic`` validators.

        Values that already have the field's type are returned as-is; plain
        ``str`` values are wrapped in the field's type.

        Raises
        ------
        NotImplementedError
            For any other kind of value.
        """
        if isinstance(value, field.type_):
            return value
        elif isinstance(value, str):
            return field.type_(value)
        else:
            raise NotImplementedError

    @classmethod
    def _odd_add(cls, value):
        """Make sure a value stays odd by incrementing even values.
        To be used in ``pydantic`` validators.

        Falsy values (e.g. ``0`` or ``None``) are mapped to ``0``.
        """
        if value:
            if not (value % 2):
                return value + 1
            else:
                return value
        else:
            return 0

    @classmethod
    def _int_limits(cls, value, field):
        """Enforce ``pydantic`` field limits (`le`, `lt`, `ge`, `gt`)
        for ``int`` fields. To be used in ``pydantic`` validators.

        Out-of-range values are clamped to the nearest allowed integer
        instead of raising. Limits are checked in the order
        `le`, `lt`, `ge`, `gt` and only the first violated one is applied.
        """
        if field.field_info.le is not None and not value <= field.field_info.le:
            return field.field_info.le
        elif field.field_info.lt is not None and not value < field.field_info.lt:
            return field.field_info.lt - 1
        elif field.field_info.ge is not None and not value >= field.field_info.ge:
            return field.field_info.ge
        elif field.field_info.gt is not None and not value > field.field_info.gt:
            return field.field_info.gt + 1
        else:
            return value

    @classmethod
    def _float_limits(cls, value, field):
        """Enforce ``pydantic`` field limits (`le`, `lt`, `ge`, `gt`)
        for ``float`` fields. To be used in ``pydantic`` validators.

        Out-of-range values are clamped instead of raising. Exclusive limits
        (`lt`, `gt`) have no "nearest" float here, so the limit itself is
        returned (treated as `le` / `ge`) and a warning is logged.
        """
        if field.field_info.le is not None and not value <= field.field_info.le:
            return field.field_info.le
        elif field.field_info.lt is not None and not value < field.field_info.lt:
            log.warning(f"resolving float 'lt' as 'le' for field {field}")
            return field.field_info.lt
        elif field.field_info.ge is not None and not value >= field.field_info.ge:
            return field.field_info.ge
        elif field.field_info.gt is not None and not value > field.field_info.gt:
            log.warning(f"resolving float 'gt' as 'ge' for field {field}")
            return field.field_info.gt
        else:
            return value

    def __call__(self, **kwargs) -> None:
        """Update this configuration in place.

        Fields are visited in declaration order so validators run in a
        predictable order. A nested ``BaseConfig`` field given a ``dict`` is
        updated recursively instead of being replaced; any other value is
        assigned (and thus validated, see ``validate_assignment``).
        Keyword arguments that don't name a field are ignored.
        """
        for field in self.__fields__:
            if field not in kwargs:
                continue
            value = kwargs[field]
            current = getattr(self, field)
            if isinstance(current, BaseConfig) and isinstance(value, dict):
                # Nested config: resolve in place
                current(**value)
            else:
                # Otherwise, let the validators handle it
                setattr(self, field, value)

    def to_dict(self, do_tag: bool = False) -> dict:  # todo: should be replaced by pydantic internals + serialization
        """Return the configuration as a serializable dict.

        Parameters
        ----------
        do_tag : bool
            If `True`, add configuration class and version fields to the dict

        Returns
        -------
        dict
            A serializable representation of this configuration object.
            Fields whose value is ``None`` are omitted, as are fields whose
            conversion raises ``ValueError`` (logged at debug level).
        """
        output: dict = {}

        def _represent(obj) -> object:
            if isinstance(obj, BaseConfig):
                # Recurse, but don't tag
                return obj.to_dict(do_tag=False)
            if isinstance(obj, EnforcedStr):
                # Return str value
                try:
                    return str(obj)
                except TypeError:
                    return ''
            if isinstance(obj, tuple):
                # Convert to str to bypass YAML tuple representation
                return str(obj)
            if isinstance(obj, np.ndarray):
                # Convert to str to bypass YAML list representation
                return ndarray2str(obj)
            # Assume that `obj` is serializable
            return obj

        for attr, val in self.__dict__.items():
            if val is None:
                continue
            try:
                if isinstance(val, (list, tuple)):
                    # Keep the container type, represent each element
                    output[_represent(attr)] = type(val)([*map(_represent, val)])
                elif isinstance(val, dict):
                    output[_represent(attr)] = {
                        _represent(k): _represent(v) for k, v in val.items()
                    }
                else:
                    output[_represent(attr)] = _represent(val)
            except ValueError:
                log.debug(f"Config.to_dict() - skipping '{attr}': {val}")

        if do_tag:
            # Only this (top) level is tagged: nested configs are serialized
            # with do_tag=False in _represent.
            self.tag(output)
        return output

    def tag(self, d: dict) -> dict:
        """Tag a ``dict`` with this object's class and the library version.

        This information is used to deserialize correctly later on.
        ``d`` is modified in place and also returned.
        """
        d[VERSION] = __version__
        d[CLASS] = self.__class__.__name__
        return d
class ConfigType(Factory):
    """Configuration type factory

    Maps option strings to :class:`~shapeflow.core.config.BaseConfig`
    subclasses.
    """
    _type = BaseConfig
    _mapping: Mapping[str, Type[Described]] = {}

    def get(self) -> Type[BaseConfig]:
        """Return the configuration type.

        Raises
        ------
        ValueError
            If the current string is not mapped (raised by ``Factory.get``).
        TypeError
            If the mapped type is not a ``BaseConfig`` subclass.
        """
        config = super().get()
        if not issubclass(config, BaseConfig):
            raise TypeError(
                f"'{self.__class__.__name__}' tried to return an unexpected type '{config}'. "
                f"This is very weird and shouldn't happen, really."
            )
        return config

    def config_schema(self) -> dict:
        """Return the ``pydantic`` schema of the selected configuration type.
        """
        selected = self.get()
        return selected.schema()
class Configurable(Described):
    """A class with an associated configuration type.
    """
    #: The configuration class as a class attribute. When subclassing, set this
    #: attribute to a specific :class:`~shapeflow.core.config.BaseConfig` type to
    #: associate it with this class.
    _config_class: Type[BaseConfig] = BaseConfig

    @classmethod
    def config_class(cls):
        """The configuration class (the ``_config_class`` attribute).
        """
        return cls._config_class

    @classmethod
    def config_schema(cls):
        """The ``pydantic`` schema of this class's configuration class.
        """
        config_type = cls.config_class()
        return config_type.schema()
class Instance(Configurable):  # todo: why isn't this just in Configurable?
    """A :class:`Configurable` that holds its own configuration object.

    The configuration passed to ``__init__`` may be an instance of
    ``_config_class`` (deep-copied), a ``dict`` (validated into
    ``_config_class``; the caller's dict is not modified) or ``None``
    (``_config_class`` defaults).
    """
    _config: BaseConfig

    @property
    def config(self) -> BaseConfig:
        """This instance's configuration object.
        """
        return self._config

    def __init__(self, config: Optional[Union[BaseConfig, dict]] = None):
        self._configure(config)
        super(Instance, self).__init__()
        log.debug(f'Initialized {self.__class__.__qualname__} with {self._config}')

    def _configure(self, config: Optional[Union[BaseConfig, dict]] = None):  # todo: adapt to dataclass implementation
        """Set ``self._config`` from ``config``.

        Raises
        ------
        TypeError
            If ``config`` is not ``None``, a ``dict`` or an instance of
            ``_config_class``.
        """
        _type = self._config_class
        if config is None:
            self._config = _type()
        elif isinstance(config, _type):
            # Each instance should have a *copy* of the config, not references to the actual values
            self._config = copy.deepcopy(config)
        elif isinstance(config, dict):
            log.warning(f"Initializing '{self.__class__.__name__}' from a dict, "
                        f"please initialize from '{_type}' instead.")
            # untag() pops keys in place; work on a shallow copy so the
            # caller's dict keeps its class/version tags.
            self._config = _type(**untag(dict(config)))
        else:
            raise TypeError(f"Tried to initialize '{self.__class__.__name__}' with {type(config).__name__} '{config}'.")