Source code for django_sorcery.db.sqlalchemy

"""SQLAlchemy goodies that provides a nice interface to using sqlalchemy with
django."""
import functools
import inspect

import sqlalchemy as sa
import sqlalchemy.orm  # noqa
from sqlalchemy.ext.declarative import declarative_base

from ..utils import make_args
from . import fields, signals
from .composites import BaseComposite, CompositeField
from .models import Base, BaseMeta
from .query import Query, QueryProperty
from .relations import RelationsMixin
from .session import SignallingSession
from .transaction import TransactionContext


[docs]def instrument(name): def do(self, *args, **kwargs): return getattr(self.registry(), name)(*args, **kwargs) return do
[docs]def makeprop(name): def get(self): return getattr(self.registry(), name) return property(get)
class _sqla_meta(type): def __new__(mcs, name, bases, attrs): typ = super().__new__(mcs, name, bases, attrs) # figure out all props to be proxied dummy = sa.orm.Session() props = {i for i in dir(dummy) if not i.startswith("__")} props.update( [ "__contains__", "__iter__", "add", "add_all", "autocommit", "autoflush", "begin", "begin_nested", "bind", "bulk_insert_mappings", "bulk_save_objects", "bulk_update_mappings", "close", "commit", "connection", "delete", "deleted", "dirty", "execute", "expire", "expire_all", "expunge", "expunge_all", "flush", "get_bind", "identity_map", "info", "is_active", "is_modified", "merge", "new", "no_autoflush", "query", "refresh", "rollback", "scalar", ] ) for i in props: if not hasattr(typ, name): if hasattr(dummy, i) and inspect.isroutine(getattr(dummy, i)): setattr(typ, i, instrument(i)) else: setattr(typ, i, makeprop(i)) return typ
[docs]class SQLAlchemy(RelationsMixin, metaclass=_sqla_meta): """This class itself is a scoped session and provides very thin and useful abstractions and conventions for using sqlalchemy with django.""" session_class = SignallingSession query_class = Query registry_class = sa.util.ThreadLocalRegistry metadata_class = sa.MetaData model_class = Base BaseComposite = BaseComposite CompositeField = CompositeField def __init__(self, url, **kwargs): self.url, self.kwargs = url, kwargs self.alias = kwargs.get("alias") self.session_class = self.kwargs.get("session_class", None) or self.session_class self.query_class = self.kwargs.get("query_class", None) or self.query_class self.registry_class = self.kwargs.get("registry_class", None) or self.registry_class self.metadata_class = self.kwargs.get("metadata_class", None) or self.metadata_class self.model_class = self.kwargs.get("model_class", None) or self.model_class self.engine_options = self.kwargs.get("engine_options", {}) self.session_options = self.kwargs.get("session_options", {}) self.session_options.setdefault("query_cls", self.query_class) self.session_options.setdefault("class_", self.session_class) self.middleware = self.make_middleware() self.models_registry = [] self.metadata = self.metadata_class(**self.kwargs.get("metadata_kwargs", {})) self.Model = self._make_declarative(self.model_class) for module, is_partial in [(sa, False), (sa.sql, False), (sa.orm, False), (fields, True)]: for key in module.__all__: if not hasattr(self, key) and not hasattr(type(self), key): value = getattr(module, key) setattr( self, key, functools.wraps(value)(functools.partial(value, db=self)) if is_partial else value ) self.collections = sa.orm.collections self.event = sa.event self.relationship = self._wrap(self.relationship) self.relation = self._wrap(self.relation) self.dynamic_loader = self._wrap(self.dynamic_loader) self._registry = None def __call__(self, **kwargs): return self.session(**kwargs) @property def registry(self): """Returns scoped registry instance.""" if not self._registry: engine = self._create_engine(self.url, **self.engine_options) self._registry = self.registry_class(sa.orm.sessionmaker(bind=engine, **self.session_options)) return self._registry @property def inspector(self): """Returns engine inspector. Useful for querying for db schema info. """ return sa.inspect(self.engine)
[docs] def session(self, **kwargs): """Return the current session, creating it if necessary using session_factory for the current scope Any kwargs provided will be passed on to the session_factory. If there's already a session in current scope, will raise InvalidRequestError """ if not kwargs: return self.registry() if self.registry.has(): raise sa.exc.InvalidRequestError("Scoped session is already present; " "no new arguments may be specified.") session = self.session_factory(**kwargs) self.registry.set(session) return session
[docs] def Table(self, name, *args, **kwargs): """Returns a sqlalchemy table that is automatically added to metadata.""" assert name is not None, "Table requires `name` argument" assert args, "Table at least takes one column argument" for arg in args: assert not isinstance(arg, sa.MetaData), "Passing a `metadata` is not allowed" return sa.Table(name, self.metadata, *args, **kwargs)
def _set_query_class(self, cls, kwargs): query_class = self.query_class try: query_class = cls.query_class except AttributeError: pass kwargs["query_class"] = query_class def _wrap(self, fn): """Wrapper that sets the 'query_class' argument from the model's query_class attribute.""" @functools.wraps(fn) def func(cls, *args, **kwargs): """Wraps a function setting default `query_class` argument.""" self._set_query_class(cls, kwargs) if "backref" in kwargs: backref = kwargs["backref"] backref_kwargs = {} if isinstance(backref, tuple): backref, backref_kwargs = backref self._set_query_class(backref, backref_kwargs) kwargs["backref"] = (backref, backref_kwargs) return fn(cls, *args, **kwargs) return func def _create_engine(self, url, **kwargs): engine = sa.create_engine(url, **kwargs) signals.engine_created.send(engine) return engine def _make_declarative(self, model): """ Creates the base class that the models should inherit from ---------------------------------------------------------- model: class The base class for the declarative_base to be inherited from """ base = declarative_base( cls=model, metadata=self.metadata, metaclass=type("BaseMeta", (BaseMeta,), {"db": self}) ) # allow to customize things in custom base model if not hasattr(base, "query_class"): base.query_class = self.query_class if not hasattr(base, "query"): base.query = self.queryproperty() if not hasattr(base, "objects"): base.objects = self.queryproperty() return base @property def engine(self): """Current engine.""" return self.bind @property def session_factory(self): """Current session factory to create sessions.""" return self.registry.createfunc def __repr__(self): return "<{} engine={!r}>".format(self.__class__.__name__, self.url)
[docs] def queryproperty(self, *args, **kwargs): """Generate a query property for a model.""" model = None if args and isinstance(args[0], type) and issubclass(args[0], self.Model): model = args[0] args = args[1:] return QueryProperty(self, model, *args, **kwargs)
[docs] def atomic(self, savepoint=True): """Create a savepoint or transaction scope.""" return TransactionContext(self, savepoint=savepoint)
[docs] def make_middleware(self): """Creates a middleware to be used in a django application.""" from .middleware import SQLAlchemyDBMiddleware return type("SQLAlchemyMiddleware", (SQLAlchemyDBMiddleware,), {"db": self})
[docs] def args(self, *args, **kwargs): """Useful for setting table args and mapper args on models.""" return make_args(*args, **kwargs)
[docs] def remove(self): """Remove the current scoped session.""" if self.registry.has(): self.registry().close() self.registry.clear() for signal in signals.all_signals.scoped_signals: signal.cleanup()
[docs] def create_all(self): """Create the schema in db.""" result = signals.before_create_all.send(db=self, bind=self.engine) if all(i[1] in [True, None] for i in result): self.metadata.create_all(bind=self.engine) signals.after_create_all.send(self, db=self, bind=self.engine)
[docs] def drop_all(self): """Drop the schema in db.""" result = signals.before_drop_all.send(db=self, bind=self.engine) if all(i[1] in [True, None] for i in result): self.metadata.drop_all(bind=self.engine) signals.after_drop_all.send(self, db=self, bind=self.engine)