Source code for django_sorcery.management.alembic

"""Alembic Django command things."""
import os
from collections import OrderedDict, namedtuple

import alembic
import alembic.config
from django.core.management.base import BaseCommand, CommandError
from django.utils.functional import cached_property
from sqlalchemy.orm import configure_mappers

from ..db import alembic as sorcery_alembic, databases, meta, signals
from ..db.alembic import include_object, process_revision_directives


# Absolute path of the directory that contains the ``..db.alembic`` module.
# ``AlembicCommand.get_app_config`` passes it to alembic as ``script_location``.
SORCERY_ALEMBIC_CONFIG_FOLDER = os.path.abspath(os.path.dirname(sorcery_alembic.__file__))


# Record bundling the alembic state of a single app, as filled in by
# ``AlembicCommand.sorcery_apps``:
#   name         - the app label
#   config       - the alembic config built for the app
#   script       - the alembic script directory created from ``config``
#   db           - the database object whose metadata holds the app's tables
#   app          - the app config object the models belong to
#   version_path - the app's migrations directory
#   tables       - list of the app's tables, appended to in sorted-table order
AlembicAppConfig = namedtuple("AlembicAppConfig", ["name", "config", "script", "db", "app", "version_path", "tables"])


class AlembicCommand(BaseCommand):
    """Base alembic django command.

    Provides the shared plumbing for alembic-backed management commands:
    discovering which apps have migrations, building a per-app alembic
    config and running an alembic environment against the app's database.
    """

    @cached_property
    def sorcery_apps(self):
        """All sorcery apps and their alembic configs.

        Walks every registered database, maps each table back to its model
        and groups the tables by the app label of the owning app. Only apps
        whose version path exists on disk are included.

        Returns:
            An ``OrderedDict`` mapping app label to ``AlembicAppConfig``.
        """
        configs = OrderedDict()

        for db in databases.values():
            table_class_map = {
                model.__table__: model for model in db.models_registry if hasattr(model, "__table__")
            }

            for table in db.metadata.sorted_tables:
                model = table_class_map.get(table)
                if not model:
                    # Table has no registered model, so it cannot be tied to an app.
                    continue

                app = meta.model_info(model).app_config
                if not app:
                    continue

                path = self.get_app_version_path(app)
                if not os.path.exists(path):
                    # Apps without a migrations directory are not managed here.
                    continue

                config = self.get_app_config(app, db)
                appconfig = AlembicAppConfig(
                    name=app.label,
                    config=config,
                    db=db,
                    script=self.get_config_script(config),
                    version_path=path,
                    app=app,
                    tables=[],
                )
                # The first config seen for a label wins; later tables of the
                # same app are appended to that config's table list.
                configs.setdefault(app.label, appconfig).tables.append(table)

        for app in configs.values():
            signals.alembic_app_created.send(app.app)
            signals.alembic_config_created.send(app.config)

        return configs

    def get_app_config(self, app, db):
        """Return alembic config for an app.

        The version table name is taken from the app's ``version_table``
        attribute, then from the ``version_table`` database kwarg, and
        finally derived from the app label.

        Raises:
            CommandError: If the version table name is as long as or longer
                than the dialect's maximum identifier length.
        """
        # TODO: read these from django db settings
        version_table = (
            getattr(app, "version_table", None)
            or db.kwargs.get("version_table")
            or "alembic_version_%s" % app.label.lower().replace(".", "_")
        )

        max_length = db.engine.dialect.max_identifier_length
        if max_length and len(version_table) >= max_length:
            raise CommandError(
                "'{name}' is {length} characters long which is an invalid identifier "
                "in {dialect!r} as its max identifier length is {max_length}".format(
                    name=version_table,
                    dialect=db.engine.dialect.name,
                    length=len(version_table),
                    max_length=max_length,
                )
            )

        version_table_schema = getattr(app, "version_table_schema", None) or db.kwargs.get("version_table_schema")

        config = alembic.config.Config(output_buffer=self.stdout, stdout=self.stdout)
        config.set_main_option("script_location", SORCERY_ALEMBIC_CONFIG_FOLDER)
        config.set_main_option("version_locations", self.get_app_version_path(app))
        config.set_main_option("version_table", version_table)
        # The schema option is never set when the dialect is sqlite.
        if version_table_schema and db.engine.dialect.name != "sqlite":
            config.set_main_option("version_table_schema", version_table_schema)

        return config

    def get_config_script(self, config):
        """Returns the alembic script directory for the config."""
        return alembic.script.ScriptDirectory.from_config(config)

    def lookup_app(self, app_label):
        """Looks up an app's alembic config.

        Raises:
            CommandError: If ``app_label`` is not one of the sorcery apps.
        """
        if app_label not in self.sorcery_apps:
            raise CommandError("App '%s' could not be found. Is it in INSTALLED_APPS?" % app_label)

        return self.sorcery_apps[app_label]

    def get_app_version_path(self, app):
        """Returns the default migration directory location of an app."""
        return os.path.join(app.path, "migrations")

    def get_common_config(self, context):
        """Common alembic configuration.

        Returns the keyword arguments shared by the online and offline
        ``context.configure`` calls, with the version table settings read
        from the context's config.
        """
        config = context.config
        return {
            "include_object": include_object,
            "process_revision_directives": process_revision_directives,
            "version_table": config.get_main_option("version_table"),
            "version_table_schema": config.get_main_option("version_table_schema"),
        }

    def run_env(self, context, appconfig):
        """Executes an alembic context, just like the env.py file of alembic.

        Raises:
            CommandError: If alembic raises its own ``CommandError``; the
                alembic error is kept as the cause.
        """
        configure_mappers()
        try:
            if context.is_offline_mode():
                self.run_migrations_offline(context, appconfig)
            else:
                self.run_migrations_online(context, appconfig)
        except alembic.util.exc.CommandError as e:
            # Translate to django's CommandError while keeping the alembic
            # error chained as the explicit cause.
            raise CommandError(str(e)) from e

    def run_migrations_online(self, context, appconfig):
        """Executes an online alembic context."""
        with appconfig.db.engine.connect() as connection:
            context.configure(
                connection=connection,
                target_metadata=appconfig.db.metadata,
                **self.get_common_config(context)
            )

            with context.begin_transaction():
                context.run_migrations()

    def run_migrations_offline(self, context, appconfig):
        """Executes an offline alembic context."""
        context.configure(
            url=appconfig.db.url,
            literal_binds=True,
            target_metadata=appconfig.db.metadata,
            **self.get_common_config(context),
        )

        with context.begin_transaction():
            context.run_migrations()