# Source code for shapeflow.core.db

import abc
import queue
import os
import time
from contextlib import contextmanager
from typing import Optional, List, Type, Any, Tuple
import datetime

from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import scoped_session
from sqlalchemy.orm.util import object_state  # type: ignore
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.sqltypes import Integer, String, DateTime
from sqlalchemy.exc import InvalidRequestError

from shapeflow import get_logger
from shapeflow.core import RootException, RootInstance, Lockable
from shapeflow.util import hash_file

# Module-level logger, named after this module.
log = get_logger(__name__)
# Declarative base class from which the database models below derive.
Base = declarative_base()


class SessionWrapper(object):
    """Wrapper object for a ``SQLAlchemy`` session factory.

    The factory itself is not created by this class: it has to be assigned
    to ``_session_factory`` or shared from another instance through
    :meth:`connect` before :meth:`session` is used.
    """
    # Quoted so the annotation is not evaluated when the class is created.
    _session_factory: 'scoped_session'

    def connect(self, session_wrapper: 'SessionWrapper'):
        """Share the session factory of another ``SessionWrapper`` instance

        Parameters
        ----------
        session_wrapper: SessionWrapper
            the instance whose session factory will be re-used
        """
        self._session_factory = session_wrapper._session_factory

    @contextmanager
    def session(self):
        """``SQLAlchemy`` session context manager.

        Opens a ``SQLAlchemy`` session and commits after the block is done.
        If an exception is raised, changes are rolled back and the exception
        is re-raised to the caller. The session is closed in every case.

        Usage::

            with self.session() as s:
                # interact with the database here
        """
        session = self._session_factory()
        try:
            yield session
            session.commit()
        except BaseException:
            # Roll back on *any* interruption of the block, but never hide
            # it: swallowing the exception here would let the caller carry
            # on as if the block had completed.
            session.rollback()
            raise
        finally:
            session.close()
class DbModel(Base, SessionWrapper, Lockable):
    """Abstract database model class.

    Combines the declarative ``Base`` with ``SessionWrapper`` and
    ``Lockable``: sessions opened through :meth:`session` are held under
    the instance's lock.
    """
    __abstract__ = True

    @property
    def _models(self) -> List['DbModel']:
        """Nested ``DbModel`` attributes of this instance, followed by the
        instance itself. Used in `DbModel.session()` to add them all.
        """
        nested = [
            value for value in vars(self).values()
            if isinstance(value, DbModel)
        ]
        nested.append(self)
        return nested

    def get(self, attr: str) -> Any:
        """Read the attribute named ``attr`` while a session is open

        Parameters
        ----------
        attr: str
            name of the attribute to read

        Returns
        -------
        Any
            the attribute's value
        """
        with self.session():
            value = getattr(self, attr)
        return value

    @contextmanager
    def session(self, add: bool = True):
        """``SQLAlchemy`` session context manager.

        Acquires the instance's lock, opens a ``SQLAlchemy`` session and
        commits after the block is done. If an exception is raised it is
        logged, changes are rolled back and the exception is re-raised.
        The session is closed in every case.

        Usage::

            with self.session() as s:
                # interact with the database here

        Calls ``DbModel._pre()`` before yielding the session and
        ``DbModel._post()`` after the block is completed.

        Parameters
        ----------
        add: bool
            if ``True``, add this model and its nested models after
            opening the session; if ``False``, add only this model
        """
        with self.lock():
            log.vdebug('opening session')
            db_session = self._session_factory()
            try:
                # NOTE(review): with add=False the instance itself is still
                # added and only nested models are skipped -- confirm this
                # is intended.
                to_add = self._models if add else [self]
                for model in to_add:
                    self._retry_add(db_session, model)

                self._pre()
                yield db_session
                self._post()

                log.vdebug('committing')
                db_session.commit()
            except Exception as err:
                log.error(f"error during session: {err.args}")
                log.error('rolling back')
                db_session.rollback()
                raise
            finally:
                log.vdebug('closing session')
                db_session.close()

    @staticmethod
    def _retry_add(session, model: 'DbModel', retry: bool = False):
        """Add ``model`` to ``session``, retrying once after 0.1 s if the
        first attempt raises ``InvalidRequestError``.

        A second ``InvalidRequestError`` is propagated to the caller.
        """
        try:
            session.add(model)
        except InvalidRequestError:
            if retry:
                # This already was the second attempt: give up.
                raise
            time.sleep(0.1)
            DbModel._retry_add(session, model, retry=True)

    def _pre(self):
        """Hook called before the session is yielded.

        Stamps ``added`` (only while it is still ``None``) and ``modified``
        with the current time, on models that define those attributes.
        """
        now = datetime.datetime.now
        # The ``False`` default makes a missing attribute fail the
        # ``is None`` test, same as an attribute that is already set.
        if getattr(self, 'added', False) is None:
            self.added = now()
        if hasattr(self, 'modified'):
            self.modified = now()

    def _post(self):
        """Hook called after the session block completes; no-op here."""
        pass
class FileModel(DbModel):
    """Abstract database model for files.

    Files are hashed and resolved in order to keep a single entry per file.
    """
    __abstract__ = True

    # Queue returned by ``hash_file``; the file's hash is read from it.
    _hash_q: queue.Queue
    # Set once ``resolve()`` has matched or committed this file.
    _resolved: bool
    # Path the instance was created with (not the database column).
    _path: str

    id = Column(Integer, primary_key=True)
    hash = Column(String, unique=True)
    path = Column(String)
    used = Column(DateTime)

    def __init__(self, path: str):
        """Create the model and, unless ``path`` is ``None``, queue a hash
        of the file at ``path``.

        Raises
        ------
        ValueError
            if ``path`` is given but is not an existing file
        """
        self._resolved = False
        if path is not None:
            self._queue_hash(path)

    @property
    def resolved(self) -> bool:
        """Whether the ``FileModel`` has been resolved"""
        # ``_resolved`` may not have been set on every instance, hence the
        # default.
        return getattr(self, '_resolved', False)

    def _queue_hash(self, path: str) -> None:
        """Remember ``path`` and request its hash through ``hash_file``.

        Raises
        ------
        ValueError
            if ``path`` is not an existing file
        """
        self._path = path
        if not self._check_file():
            raise ValueError(f"not a file: {path}")
        self._hash_q = hash_file(self._path)

        if not self._hash_q.qsize():
            log.debug(f"queueing hash for {path}")

    def _get_hash(self) -> str:
        """Take the hash from the hash queue (blocks until one is there).

        Raises
        ------
        RootException
            if no hash was queued for this instance
        """
        try:
            return self._hash_q.get()
        except AttributeError:
            raise RootException(f"{self.__class__.__qualname__}: "
                                f"get_hash() was called before queue_hash()")

    def _check_file(self):
        """Return ``True`` if ``_path`` is set and points at a file."""
        if self._path is not None:
            return os.path.isfile(self._path)
        return False

    def _join(self):
        """Wait until the queued hash is available.

        Returns immediately if the ``hash`` column is already set or if no
        hash was queued.
        """
        if self.hash is not None:
            return
        # ``_hash_q`` only exists once ``_queue_hash()`` has run. Look it up
        # defensively so that a missing queue is reported by ``_get_hash()``
        # and not as a bare ``AttributeError`` from here.
        hash_q = getattr(self, '_hash_q', None)
        if hash_q is None:
            return
        while not hash_q.qsize():
            time.sleep(0.01)

    def resolve(self) -> 'FileModel':
        """Resolve the file by its SHA1 hash ~ :func:`~shapeflow.util.hash_file`.

        If the computed hash is new, the file is committed to the database.
        Otherwise, the original entry is re-used.

        Returns
        -------
        FileModel
            The current instance if the file is new, or a new ``FileModel``
            instance representing the original database entry.

        Raises
        ------
        RootException
            if the file is not resolved yet and no hash was queued for it
        """
        if self.resolved:
            return self

        self._join()
        file_hash = self._get_hash()

        with self.session(add=False) as s:
            match = s.query(self.__class__).filter_by(hash=file_hash).first()

            if match is None:
                # First time this hash is seen: this instance is the entry.
                s.add(self)
                self.hash = file_hash
                self.path = self._path
                file = self
            else:
                file = match
                file.connect(self)

                # The file was seen before under a different path: store
                # the current one on the existing entry.
                if self._path != file.path:
                    file.path = self._path

                # Only an instance that is persistent in the session is
                # deleted in favour of the matching entry.
                if object_state(self).persistent:
                    s.delete(self)

            file._resolved = True
            file.used = datetime.datetime.now()

        return file
class BaseAnalysisModel(DbModel):
    """Interface that analysis models are expected to implement.

    Every method below is declared abstract and has no implementation here.
    """
    # NOTE(review): ``abc.abstractmethod`` is only enforced when the class's
    # metaclass derives from ``abc.ABCMeta`` -- confirm that this holds for
    # ``Base``.
    __abstract__ = True

    @abc.abstractmethod
    def get_name(self) -> str:
        """Return the name of the analysis"""

    @abc.abstractmethod
    def get_config_json(self) -> Optional[str]:
        """Return the current configuration as a JSON string, if any"""

    @abc.abstractmethod
    def load_config(self, video_path: str, design_path: str = None, include: List[str] = None) -> Optional[dict]:
        """Load a configuration from the database

        Parameters
        ----------
        video_path: str
            path of the video file
        design_path: str
            path of the design file (optional)
        include: List[str]
            optional list of strings; presumably the configuration fields
            to include -- TODO confirm against implementations

        Returns
        -------
        Optional[dict]
            the configuration, or ``None``
        """

    @abc.abstractmethod
    def get_undo_config(self, context: str = None) -> Tuple[Optional[dict], Optional[int]]:
        """Undo configuration.

        If a ``context`` is supplied, ensure that the ``context`` field
        changes, but the other fields remain the same
        """

    @abc.abstractmethod
    def get_redo_config(self, context: str = None) -> Tuple[Optional[dict], Optional[int]]:
        """Redo configuration.

        If a ``context`` is supplied, ensure that the ``context`` field
        changes, but the other fields remain the same
        """

    @abc.abstractmethod
    def store(self) -> None:
        """Store analysis information from the wrapped
        ``BaseVideoAnalyzer`` to the database
        """

    @abc.abstractmethod
    def export_result(self, run: int = None) -> None:
        """Export a result from the database

        Parameters
        ----------
        run: int
            the run to export (optional)
        """

    @abc.abstractmethod
    def get_runs(self) -> int:
        """Return the number of runs of this analysis"""

    @abc.abstractmethod
    def get_id(self) -> int:
        """Return the id of this analysis"""