Source code for galaxy.model.migrations.util

import logging
from abc import (
    ABC,
    abstractmethod,
)
from contextlib import contextmanager
from typing import (
    Any,
    List,
    Optional,
    Sequence,
)

import sqlalchemy as sa
from alembic import (
    context,
    op,
)
from sqlalchemy.exc import OperationalError

log = logging.getLogger(__name__)


class DDLOperation(ABC):
    """Base class for all DDL operations."""

    def run(self) -> Optional[Any]:
        if not self._is_repair_mode():
            return self.execute()
        else:
            if self.pre_execute_check():
                return self.execute()
            else:
                self.log_check_not_passed()
                return None

    @abstractmethod
    def execute(self) -> Optional[Any]: ...

    @abstractmethod
    def pre_execute_check(self) -> bool: ...

    @abstractmethod
    def log_check_not_passed(self) -> None: ...

    def _is_repair_mode(self) -> bool:
        """`--repair` option has been passed to the command."""
        return bool(context.config.get_main_option("repair"))

    def _log_object_exists_message(self, object_name: str) -> None:
        log.info(f"{object_name} already exists. Skipping revision.")

    def _log_object_does_not_exist_message(self, object_name: str) -> None:
        log.info(f"{object_name} does not exist. Skipping revision.")


class DDLAlterOperation(DDLOperation):
    """
    Base class for DDL operations that implement special handling of ALTER statements.
    Ref:
    - https://alembic.sqlalchemy.org/en/latest/ops.html#alembic.operations.Operations.batch_alter_table
    - https://alembic.sqlalchemy.org/en/latest/batch.html
    """

    def __init__(self, table_name: str) -> None:
        self.table_name = table_name

    def run(self) -> Optional[Any]:
        if context.is_offline_mode():
            log.info("Generation of `alter` statements is disabled in offline mode.")
            return None
        return super().run()

    def execute(self) -> Optional[Any]:
        if _is_sqlite():
            with legacy_alter_table(), op.batch_alter_table(self.table_name) as batch_op:
                return self.batch_execute(batch_op)
        else:
            return self.non_batch_execute()  # use regular op context for non-sqlite db

    @abstractmethod
    def batch_execute(self, batch_op) -> Optional[Any]: ...

    @abstractmethod
    def non_batch_execute(self) -> Optional[Any]: ...


class CreateTable(DDLOperation):
    """Wraps alembic's create_table directive."""

    def __init__(self, table_name: str, *columns: sa.schema.SchemaItem) -> None:
        self.table_name = table_name
        self.columns = columns

    def execute(self) -> Optional[sa.Table]:
        return op.create_table(self.table_name, *self.columns)

    def pre_execute_check(self) -> bool:
        return not table_exists(self.table_name, False)

    def log_check_not_passed(self) -> None:
        self._log_object_exists_message(f"{self.table_name} table")


class DropTable(DDLOperation):
    """Wraps alembic's drop_table directive."""

    def __init__(self, table_name: str) -> None:
        self.table_name = table_name

    def execute(self) -> None:
        op.drop_table(self.table_name)

    def pre_execute_check(self) -> bool:
        return table_exists(self.table_name, False)

    def log_check_not_passed(self) -> None:
        self._log_object_does_not_exist_message(f"{self.table_name} table")


class CreateIndex(DDLOperation):
    """Wraps alembic's create_index directive."""

    def __init__(self, index_name: str, table_name: str, columns: Sequence, **kw: Any) -> None:
        self.index_name = index_name
        self.table_name = table_name
        self.columns = columns
        self.kw = kw

    def execute(self) -> None:
        op.create_index(self.index_name, self.table_name, self.columns, **self.kw)

    def pre_execute_check(self) -> bool:
        return not index_exists(self.index_name, self.table_name, False)

    def log_check_not_passed(self) -> None:
        name = _table_object_description(self.index_name, self.table_name)
        self._log_object_exists_message(name)


class DropIndex(DDLOperation):
    """Wraps alembic's drop_index directive."""

    def __init__(self, index_name: str, table_name: str) -> None:
        self.index_name = index_name
        self.table_name = table_name

    def execute(self) -> None:
        op.drop_index(self.index_name, table_name=self.table_name)

    def pre_execute_check(self) -> bool:
        return index_exists(self.index_name, self.table_name, False)

    def log_check_not_passed(self) -> None:
        name = _table_object_description(self.index_name, self.table_name)
        self._log_object_does_not_exist_message(name)


class AddColumn(DDLOperation):
    """Wraps alembic's add_column directive."""

    def __init__(self, table_name: str, column: sa.Column) -> None:
        self.table_name = table_name
        self.column = column

    def execute(self) -> None:
        op.add_column(self.table_name, self.column)

    def pre_execute_check(self) -> bool:
        return not column_exists(self.table_name, self.column.name, False)

    def log_check_not_passed(self) -> None:
        name = _table_object_description(self.column.name, self.table_name)
        self._log_object_exists_message(name)


class DropColumn(DDLAlterOperation):
    """Wraps alembic's drop_column directive."""

    def __init__(self, table_name: str, column_name: str) -> None:
        super().__init__(table_name)
        self.column_name = column_name

    def batch_execute(self, batch_op) -> None:
        batch_op.drop_column(self.column_name)

    def non_batch_execute(self) -> None:
        op.drop_column(self.table_name, self.column_name)

    def pre_execute_check(self) -> bool:
        return column_exists(self.table_name, self.column_name, False)

    def log_check_not_passed(self) -> None:
        name = _table_object_description(self.column_name, self.table_name)
        self._log_object_does_not_exist_message(name)


class AlterColumn(DDLAlterOperation):
    """Wraps alembic's alter_column directive."""

    def __init__(self, table_name: str, column_name: str, **kw: Any) -> None:
        self.table_name = table_name
        self.column_name = column_name
        self.kw = kw

    def batch_execute(self, batch_op) -> None:
        batch_op.alter_column(self.column_name, **self.kw)

    def non_batch_execute(self) -> None:
        op.alter_column(self.table_name, self.column_name, **self.kw)

    def pre_execute_check(self) -> bool:
        # Assume that if a column exists, it can be altered.
        return column_exists(self.table_name, self.column_name, False)

    def log_check_not_passed(self) -> None:
        name = _table_object_description(self.column_name, self.table_name)
        self._log_object_does_not_exist_message(name)


class CreateForeignKey(DDLAlterOperation):
    """Wraps alembic's create_foreign_key directive."""

    def __init__(
        self,
        foreign_key_name: str,
        table_name: str,
        referent_table: str,
        local_cols: List[str],
        remote_cols: List[str],
        **kw: Any,
    ) -> None:
        super().__init__(table_name)
        self.foreign_key_name = foreign_key_name
        self.referent_table = referent_table
        self.local_cols = local_cols
        self.remote_cols = remote_cols
        self.kw = kw

    def batch_execute(self, batch_op) -> None:
        batch_op.create_foreign_key(
            self.foreign_key_name, self.referent_table, self.local_cols, self.remote_cols, **self.kw
        )

    def non_batch_execute(self) -> None:
        op.create_foreign_key(
            self.foreign_key_name, self.table_name, self.referent_table, self.local_cols, self.remote_cols, **self.kw
        )

    def pre_execute_check(self) -> bool:
        return not foreign_key_exists(self.foreign_key_name, self.table_name, False)

    def log_check_not_passed(self) -> None:
        name = _table_object_description(self.foreign_key_name, self.table_name)
        self._log_object_exists_message(name)


class CreateUniqueConstraint(DDLAlterOperation):
    """Wraps alembic's create_unique_constraint directive."""

    def __init__(self, constraint_name: str, table_name: str, columns: List[str]) -> None:
        super().__init__(table_name)
        self.constraint_name = constraint_name
        self.columns = columns

    def batch_execute(self, batch_op) -> None:
        batch_op.create_unique_constraint(self.constraint_name, self.columns)

    def non_batch_execute(self) -> None:
        op.create_unique_constraint(self.constraint_name, self.table_name, self.columns)

    def pre_execute_check(self) -> bool:
        return not unique_constraint_exists(self.constraint_name, self.table_name, False)

    def log_check_not_passed(self) -> None:
        name = _table_object_description(self.constraint_name, self.table_name)
        self._log_object_exists_message(name)


class DropConstraint(DDLAlterOperation):
    """Wraps alembic's drop_constraint directive."""

    def __init__(self, constraint_name: str, table_name: str) -> None:
        super().__init__(table_name)
        self.constraint_name = constraint_name

    def batch_execute(self, batch_op) -> None:
        batch_op.drop_constraint(self.constraint_name)

    def non_batch_execute(self) -> None:
        op.drop_constraint(self.constraint_name, self.table_name)

    def pre_execute_check(self) -> bool:
        return unique_constraint_exists(self.constraint_name, self.table_name, False)

    def log_check_not_passed(self) -> None:
        name = _table_object_description(self.constraint_name, self.table_name)
        self._log_object_does_not_exist_message(name)


def create_table(table_name: str, *columns: sa.schema.SchemaItem) -> Optional[sa.Table]:
    return CreateTable(table_name, *columns).run()


def drop_table(table_name: str) -> None:
    DropTable(table_name).run()


def add_column(table_name: str, column: sa.Column) -> None:
    AddColumn(table_name, column).run()


def drop_column(table_name, column_name) -> None:
    DropColumn(table_name, column_name).run()


def alter_column(table_name: str, column_name: str, **kw) -> None:
    AlterColumn(table_name, column_name, **kw).run()


def create_index(index_name, table_name, columns, **kw) -> None:
    CreateIndex(index_name, table_name, columns, **kw).run()


def drop_index(index_name, table_name) -> None:
    DropIndex(index_name, table_name).run()


def create_foreign_key(
    foreign_key_name: str,
    table_name: str,
    referent_table: str,
    local_cols: List[str],
    remote_cols: List[str],
    **kw: Any,
) -> None:
    CreateForeignKey(foreign_key_name, table_name, referent_table, local_cols, remote_cols, **kw).run()


def create_unique_constraint(constraint_name: str, table_name: str, columns: List[str]) -> None:
    CreateUniqueConstraint(constraint_name, table_name, columns).run()


def drop_constraint(constraint_name: str, table_name: str) -> None:
    DropConstraint(constraint_name, table_name).run()
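

# Illustrative usage (not part of this module): a minimal sketch of how an
# Alembic revision script might call the convenience wrappers above from its
# upgrade()/downgrade() functions. The table, column, and index names below
# are invented for the example. When the script is run with the `--repair`
# option, each wrapper first checks for the target object and skips the
# statement (with a log message) if the database is already in the expected
# state.
#
#     import sqlalchemy as sa
#     from galaxy.model.migrations.util import (
#         add_column,
#         create_index,
#         drop_column,
#         drop_index,
#     )
#
#     def upgrade():
#         add_column("job", sa.Column("note", sa.TEXT))       # hypothetical column
#         create_index("ix_job_note", "job", ["note"])        # hypothetical index
#
#     def downgrade():
#         drop_index("ix_job_note", "job")
#         drop_column("job", "note")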


def table_exists(table_name: str, default: bool) -> bool:
    """Check if table exists. If running in offline mode, return default."""
    if context.is_offline_mode():
        _log_offline_mode_message(table_exists.__name__, default)
        return default
    return _inspector().has_table(table_name)


def column_exists(table_name: str, column_name: str, default: bool) -> bool:
    """Check if column exists. If running in offline mode, return default."""
    if context.is_offline_mode():
        _log_offline_mode_message(column_exists.__name__, default)
        return default
    columns = _inspector().get_columns(table_name)
    return any(c["name"] == column_name for c in columns)


def index_exists(index_name: str, table_name: str, default: bool) -> bool:
    """Check if index exists. If running in offline mode, return default."""
    if context.is_offline_mode():
        _log_offline_mode_message(index_exists.__name__, default)
        return default
    indexes = _inspector().get_indexes(table_name)
    return any(index["name"] == index_name for index in indexes)


def foreign_key_exists(constraint_name: str, table_name: str, default: bool) -> bool:
    """Check if foreign key constraint exists. If running in offline mode, return default."""
    if context.is_offline_mode():
        _log_offline_mode_message(foreign_key_exists.__name__, default)
        return default
    constraints = _inspector().get_foreign_keys(table_name)
    return any(c["name"] == constraint_name for c in constraints)


def unique_constraint_exists(constraint_name: str, table_name: str, default: bool) -> bool:
    """Check if unique constraint exists. If running in offline mode, return default."""
    if context.is_offline_mode():
        _log_offline_mode_message(unique_constraint_exists.__name__, default)
        return default
    constraints = _inspector().get_unique_constraints(table_name)
    return any(c["name"] == constraint_name for c in constraints)


def _table_object_description(object_name: str, table_name: str) -> str:
    return f"{object_name} on {table_name} table"


def _log_offline_mode_message(function_name: str, return_value: Any) -> None:
    log.info(
        f"This script is being executed in offline mode, so it cannot connect to the database. "
        f"Therefore, function `{function_name}` will return the value `{return_value}`, "
        f"which is the expected value during normal operation."
    )


def _inspector() -> Any:
    bind = op.get_context().bind
    return sa.inspect(bind)


def _is_sqlite() -> bool:
    bind = op.get_context().bind
    return bool(bind and bind.engine.name == "sqlite")


@contextmanager
def legacy_alter_table():
    """
    Wrapper required for add/drop column statements.
    Prevents error when column belongs to a table referenced in a view.
    Relevant to sqlite only.
    Ref: https://github.com/sqlalchemy/alembic/issues/1207
    Ref: https://sqlite.org/pragma.html#pragma_legacy_alter_table
    """
    try:
        op.execute("PRAGMA legacy_alter_table=1;")
        yield
    finally:
        op.execute("PRAGMA legacy_alter_table=0;")
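

# Illustrative only: under SQLite, altering a table that is referenced by a
# view can fail without this pragma. A migration would typically wrap a batch
# alter in this context manager, which is the same pattern used by
# DDLAlterOperation.execute() above (the table and column names here are
# hypothetical):
#
#     with legacy_alter_table(), op.batch_alter_table("job") as batch_op:
#         batch_op.drop_column("note")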


@contextmanager
def transaction():
    """
    Wraps multiple statements in upgrade/downgrade revision script functions
    in a database transaction, ensuring transactional control.

    Used for SQLite only. Although SQLite supports transactional DDL,
    pysqlite does not. Ref: https://bugs.python.org/issue10740
    """
    if not _is_sqlite():
        yield  # For postgresql, alembic ensures transactional context.
    else:
        try:
            op.execute("BEGIN")
            yield
            op.execute("END")
        except OperationalError:
            op.execute("ROLLBACK")
            raise
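

# Illustrative usage (not part of this module): a sketch of how a revision
# script might group several statements inside transaction() so that, when the
# database is SQLite, a failure rolls back the whole group instead of leaving
# it half-applied. Table and column names are invented for the example.
#
#     import sqlalchemy as sa
#     from galaxy.model.migrations.util import add_column, transaction
#
#     def upgrade():
#         with transaction():
#             add_column("job", sa.Column("note", sa.TEXT))
#             add_column("job", sa.Column("note_extended", sa.TEXT))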