Source code for django_sorcery.db.session

"""sqlalchemy session related things."""
from itertools import chain

from sqlalchemy import event, orm

from ..utils import setdefaultattr
from . import signals


def before_flush(session, flush_context, instances):
    """Re-emit SQLAlchemy's ``before_flush`` session event as a signal.

    :param session: the flushing session; used as the signal sender.
    :param flush_context: forwarded unchanged to signal receivers.
    :param instances: forwarded unchanged to signal receivers.
    """
    payload = {"flush_context": flush_context, "instances": instances}
    signals.before_flush.send(session, **payload)
def after_flush(session, flush_context):
    """Re-emit SQLAlchemy's ``after_flush`` session event as a signal.

    :param session: the session that flushed; used as the signal sender.
    :param flush_context: forwarded unchanged to signal receivers.
    """
    relay = signals.after_flush
    relay.send(session, flush_context=flush_context)
def before_commit(session):
    """Send ``before_commit`` and then ``before_scoped_commit`` for *session*.

    Nothing is sent when the session has no current transaction, or when the
    current transaction is ``nested`` and has a parent (presumably a
    SAVEPOINT created via ``begin_nested`` — TODO confirm).

    :param session: the committing session; used as the signal sender.
    """
    txn = session.transaction
    if not txn:
        return
    # Skip inner nested transactions; only non-nested or root ones signal.
    if txn._parent is not None and txn.nested:
        return
    signals.before_commit.send(session)
    signals.before_scoped_commit.send(session)
def after_commit(session):
    """Send ``after_scoped_commit`` then ``after_commit``, then reset tracking.

    Uses the same filter as ``before_commit``. Nothing happens when there is
    no current transaction, or when the current transaction is ``nested`` and
    has a parent. Otherwise both signals are sent and the session's
    ``models_committed`` / ``models_deleted`` sets are cleared, so the next
    transaction starts with empty tracking.

    :param session: the committed session; used as the signal sender.
    """
    txn = session.transaction
    if not txn:
        return
    if txn._parent is not None and txn.nested:
        return
    signals.after_scoped_commit.send(session)
    signals.after_commit.send(session)
    # Clear only after receivers have had a chance to inspect the sets.
    # setdefaultattr presumably creates an empty set when the attribute is
    # missing, so clearing is safe even if nothing was ever flushed.
    for attr_name in ("models_committed", "models_deleted"):
        setdefaultattr(session, attr_name, set()).clear()
def after_rollback(session):
    """Send ``after_scoped_rollback`` then ``after_rollback``, then reset tracking.

    Fires for a root transaction (no parent) or for a ``nested`` one. It is
    skipped when there is no current transaction, or when the transaction has
    a parent but is not nested.

    :param session: the rolled-back session; used as the signal sender.
    """
    # NOTE(review): this filter is the mirror of the commit handlers' filter.
    # Here nested transactions DO signal and non-nested children do not.
    # It appears intentional (a nested rollback is a real rollback), but
    # confirm before changing.
    txn = session.transaction
    if not txn:
        return
    if txn._parent is not None and not txn.nested:
        return
    signals.after_scoped_rollback.send(session)
    signals.after_rollback.send(session)
    for attr_name in ("models_committed", "models_deleted"):
        setdefaultattr(session, attr_name, set()).clear()
def record_models(session, flush_context=None, instances=None):
    """Accumulate the instances touched by a flush onto the session.

    New and dirty instances go into ``session.models_committed``. Deleted
    instances go into ``session.models_deleted`` and are removed from
    ``models_committed`` if they were recorded there earlier. The sets
    persist across flushes until a commit/rollback handler clears them.

    :param session: the session that flushed.
    :param flush_context: unused; accepted so the signature fits the
        flush-event listener call.
    :param instances: unused; accepted for the same reason.
    """
    setdefaultattr(session, "models_committed", set())
    setdefaultattr(session, "models_deleted", set())
    committed = session.models_committed
    deleted = session.models_deleted

    committed.update(chain(session.new, session.dirty))
    for obj in session.deleted:
        deleted.add(obj)
        # A deleted object must not also be reported as committed.
        committed.discard(obj)
class SignallingSession(orm.Session):
    """A custom sqlalchemy session implementation that provides signals.

    On construction, each instance attaches this module's event handlers to
    itself. :meth:`query` prefers a model's own ``query_class`` when one is
    declared.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Registration order is preserved from the original implementation.
        # record_models is attached to "after_flush" before the after_flush
        # relay, presumably so receivers see up-to-date tracking sets (this
        # assumes listeners run in registration order — TODO confirm).
        listeners = (
            ("after_flush", record_models),
            ("before_flush", before_flush),
            ("after_flush", after_flush),
            ("before_commit", before_commit),
            ("after_commit", after_commit),
            ("after_rollback", after_rollback),
        )
        for event_name, handler in listeners:
            event.listen(self, event_name, handler)

    def query(self, *args, **kwargs):
        """Override to try to use the model.query_class.

        When called with exactly one positional argument that has a non-None
        ``query_class`` attribute, that class is instantiated with the same
        arguments plus ``session=self``. Otherwise the base ``Session.query``
        is used.
        """
        if len(args) == 1:
            custom_query_class = getattr(args[0], "query_class", None)
            if custom_query_class is not None:
                return custom_query_class(*args, session=self, **kwargs)
        return super().query(*args, **kwargs)