import logging
from abc import (
ABC,
abstractmethod,
)
from contextlib import contextmanager
from typing import (
Any,
List,
Optional,
Sequence,
)
import sqlalchemy as sa
from alembic import (
context,
op,
)
from sqlalchemy.exc import OperationalError
# Module-level logger shared by the DDL operations and helper functions in this file.
log = logging.getLogger(name=__name__)
class DDLOperation(ABC):
    """Abstract base for every DDL operation wrapper.

    Subclasses provide :meth:`execute`, :meth:`pre_execute_check` and
    :meth:`log_check_not_passed`. :meth:`run` always executes the operation
    unless the ``--repair`` option is set, in which case it executes only when
    the pre-execution check passes.
    """

    def run(self) -> Optional[Any]:
        """Execute the operation, honouring repair mode.

        Returns whatever :meth:`execute` returns, or ``None`` when repair mode is
        active and :meth:`pre_execute_check` fails (after logging why).
        """
        # The check is consulted only in repair mode; outside it we always execute.
        if self._is_repair_mode() and not self.pre_execute_check():
            self.log_check_not_passed()
            return None
        return self.execute()

    @abstractmethod
    def execute(self) -> Optional[Any]:
        """Perform the DDL operation."""

    @abstractmethod
    def pre_execute_check(self) -> bool:
        """Return True when the operation should run in repair mode."""

    @abstractmethod
    def log_check_not_passed(self) -> None:
        """Log why the operation is skipped when the repair-mode check fails."""

    def _is_repair_mode(self) -> bool:
        """`--repair` option has been passed to the command."""
        repair_option = context.config.get_main_option("repair")
        return bool(repair_option)

    def _log_object_exists_message(self, object_name: str) -> None:
        message = f"{object_name} already exists. Skipping revision."
        log.info(message)

    def _log_object_does_not_exist_message(self, object_name: str) -> None:
        message = f"{object_name} does not exist. Skipping revision."
        log.info(message)
class DDLAlterOperation(DDLOperation):
    """
    Base for DDL operations that emit ALTER statements and need special handling.

    When bound to SQLite the operation runs inside alembic's batch mode (wrapped in
    ``legacy_alter_table``); on any other database the regular ``op`` directive is used.
    In offline mode nothing is generated.

    Ref:
    - https://alembic.sqlalchemy.org/en/latest/ops.html#alembic.operations.Operations.batch_alter_table
    - https://alembic.sqlalchemy.org/en/latest/batch.html
    """

    def __init__(self, table_name: str) -> None:
        self.table_name = table_name

    def run(self) -> Optional[Any]:
        """Skip ALTER generation in offline mode; otherwise defer to :meth:`DDLOperation.run`."""
        if not context.is_offline_mode():
            return super().run()
        log.info("Generation of `alter` statements is disabled in offline mode.")
        return None

    def execute(self) -> Optional[Any]:
        """Dispatch to the batch (SQLite) or non-batch (other databases) implementation."""
        if not _is_sqlite():
            # Non-SQLite databases use the regular op context.
            return self.non_batch_execute()
        with legacy_alter_table(), op.batch_alter_table(self.table_name) as batch_op:
            return self.batch_execute(batch_op)

    @abstractmethod
    def batch_execute(self, batch_op) -> Optional[Any]:
        """Apply the operation through an alembic batch operations object."""

    @abstractmethod
    def non_batch_execute(self) -> Optional[Any]:
        """Apply the operation through the regular alembic ``op`` directives."""
class CreateTable(DDLOperation):
    """Wraps alembic's create_table directive; in repair mode it is skipped if the table exists."""

    def __init__(self, table_name: str, *columns: sa.schema.SchemaItem) -> None:
        self.table_name = table_name
        self.columns = columns

    def execute(self) -> Optional[sa.Table]:
        name, items = self.table_name, self.columns
        return op.create_table(name, *items)

    def pre_execute_check(self) -> bool:
        already_there = table_exists(self.table_name, default=False)
        return not already_there

    def log_check_not_passed(self) -> None:
        description = f"{self.table_name} table"
        self._log_object_exists_message(description)
class DropTable(DDLOperation):
    """Wraps alembic's drop_table directive; in repair mode it runs only if the table exists."""

    def __init__(self, table_name: str) -> None:
        self.table_name = table_name

    def execute(self) -> None:
        target = self.table_name
        op.drop_table(target)

    def pre_execute_check(self) -> bool:
        return table_exists(self.table_name, default=False)

    def log_check_not_passed(self) -> None:
        description = f"{self.table_name} table"
        self._log_object_does_not_exist_message(description)
class CreateIndex(DDLOperation):
    """Wraps alembic's create_index directive; in repair mode it is skipped if the index exists."""

    def __init__(self, index_name: str, table_name: str, columns: Sequence, **kw: Any) -> None:
        self.index_name = index_name
        self.table_name = table_name
        self.columns = columns
        self.kw = kw

    def execute(self) -> None:
        # Extra keyword arguments are forwarded verbatim to alembic.
        op.create_index(self.index_name, self.table_name, self.columns, **self.kw)

    def pre_execute_check(self) -> bool:
        found = index_exists(self.index_name, self.table_name, default=False)
        return not found

    def log_check_not_passed(self) -> None:
        self._log_object_exists_message(_table_object_description(self.index_name, self.table_name))
class DropIndex(DDLOperation):
    """Wraps alembic's drop_index directive; in repair mode it runs only if the index exists."""

    def __init__(self, index_name: str, table_name: str) -> None:
        self.index_name = index_name
        self.table_name = table_name

    def execute(self) -> None:
        op.drop_index(self.index_name, table_name=self.table_name)

    def pre_execute_check(self) -> bool:
        return index_exists(self.index_name, self.table_name, default=False)

    def log_check_not_passed(self) -> None:
        self._log_object_does_not_exist_message(_table_object_description(self.index_name, self.table_name))
class AddColumn(DDLOperation):
    """Wraps alembic's add_column directive; in repair mode it is skipped if the column exists."""

    def __init__(self, table_name: str, column: sa.Column) -> None:
        self.table_name = table_name
        self.column = column

    def execute(self) -> None:
        table, new_column = self.table_name, self.column
        op.add_column(table, new_column)

    def pre_execute_check(self) -> bool:
        present = column_exists(self.table_name, self.column.name, default=False)
        return not present

    def log_check_not_passed(self) -> None:
        self._log_object_exists_message(_table_object_description(self.column.name, self.table_name))
class DropColumn(DDLAlterOperation):
    """Wraps alembic's drop_column directive; in repair mode it runs only if the column exists."""

    def __init__(self, table_name: str, column_name: str) -> None:
        super().__init__(table_name)
        self.column_name = column_name

    def batch_execute(self, batch_op) -> None:
        # The batch context already targets self.table_name.
        batch_op.drop_column(self.column_name)

    def non_batch_execute(self) -> None:
        op.drop_column(self.table_name, self.column_name)

    def pre_execute_check(self) -> bool:
        return column_exists(self.table_name, self.column_name, default=False)

    def log_check_not_passed(self) -> None:
        self._log_object_does_not_exist_message(_table_object_description(self.column_name, self.table_name))
class AlterColumn(DDLAlterOperation):
    """Wraps alembic's alter_column directive.

    ``kw`` is forwarded unchanged to ``alter_column`` in both the batch (SQLite)
    and the regular form. In repair mode the operation runs only if the column
    currently exists.
    """

    def __init__(self, table_name: str, column_name: str, **kw: Any) -> None:
        # Delegate to the base initializer like every other DDLAlterOperation
        # subclass, so base-class setup is never silently bypassed.
        super().__init__(table_name)
        self.column_name = column_name
        self.kw = kw

    def batch_execute(self, batch_op) -> None:
        batch_op.alter_column(self.column_name, **self.kw)

    def non_batch_execute(self) -> None:
        op.alter_column(self.table_name, self.column_name, **self.kw)

    def pre_execute_check(self) -> bool:
        # Assume that if a column exists, it can be altered.
        return column_exists(self.table_name, self.column_name, False)

    def log_check_not_passed(self) -> None:
        name = _table_object_description(self.column_name, self.table_name)
        self._log_object_does_not_exist_message(name)
class CreateForeignKey(DDLAlterOperation):
    """Wraps alembic's create_foreign_key directive; in repair mode it is skipped if the key exists."""

    def __init__(
        self,
        foreign_key_name: str,
        table_name: str,
        referent_table: str,
        local_cols: List[str],
        remote_cols: List[str],
        **kw: Any,
    ) -> None:
        super().__init__(table_name)
        self.foreign_key_name = foreign_key_name
        self.referent_table = referent_table
        self.local_cols = local_cols
        self.remote_cols = remote_cols
        self.kw = kw

    def batch_execute(self, batch_op) -> None:
        # The batch context supplies the source table implicitly.
        positional = (self.foreign_key_name, self.referent_table, self.local_cols, self.remote_cols)
        batch_op.create_foreign_key(*positional, **self.kw)

    def non_batch_execute(self) -> None:
        positional = (
            self.foreign_key_name,
            self.table_name,
            self.referent_table,
            self.local_cols,
            self.remote_cols,
        )
        op.create_foreign_key(*positional, **self.kw)

    def pre_execute_check(self) -> bool:
        present = foreign_key_exists(self.foreign_key_name, self.table_name, default=False)
        return not present

    def log_check_not_passed(self) -> None:
        self._log_object_exists_message(_table_object_description(self.foreign_key_name, self.table_name))
class CreateUniqueConstraint(DDLAlterOperation):
    """Wraps alembic's create_unique_constraint directive; skipped in repair mode if it exists."""

    def __init__(self, constraint_name: str, table_name: str, columns: List[str]) -> None:
        super().__init__(table_name)
        self.constraint_name = constraint_name
        self.columns = columns

    def batch_execute(self, batch_op) -> None:
        batch_op.create_unique_constraint(self.constraint_name, self.columns)

    def non_batch_execute(self) -> None:
        op.create_unique_constraint(self.constraint_name, self.table_name, self.columns)

    def pre_execute_check(self) -> bool:
        present = unique_constraint_exists(self.constraint_name, self.table_name, default=False)
        return not present

    def log_check_not_passed(self) -> None:
        self._log_object_exists_message(_table_object_description(self.constraint_name, self.table_name))
class DropConstraint(DDLAlterOperation):
    """Wraps alembic's drop_constraint directive.

    In repair mode the drop runs only if a unique constraint or a foreign key
    with the given name exists on the table.
    """

    def __init__(self, constraint_name: str, table_name: str) -> None:
        super().__init__(table_name)
        self.constraint_name = constraint_name

    def batch_execute(self, batch_op) -> None:
        batch_op.drop_constraint(self.constraint_name)

    def non_batch_execute(self) -> None:
        op.drop_constraint(self.constraint_name, self.table_name)

    def pre_execute_check(self) -> bool:
        # drop_constraint is not limited to unique constraints. Checking only
        # unique constraints made repair mode skip dropping a foreign key that
        # is actually present, so foreign keys are checked as well.
        if unique_constraint_exists(self.constraint_name, self.table_name, False):
            return True
        return foreign_key_exists(self.constraint_name, self.table_name, False)

    def log_check_not_passed(self) -> None:
        name = _table_object_description(self.constraint_name, self.table_name)
        self._log_object_does_not_exist_message(name)
def create_table(table_name: str, *columns: sa.schema.SchemaItem) -> Optional[sa.Table]:
    """Create ``table_name`` (repair-mode aware) and return alembic's result."""
    operation = CreateTable(table_name, *columns)
    return operation.run()
def drop_table(table_name: str) -> None:
    """Drop ``table_name`` (repair-mode aware)."""
    operation = DropTable(table_name)
    operation.run()
def add_column(table_name: str, column: sa.Column) -> None:
    """Add ``column`` to ``table_name`` (repair-mode aware)."""
    operation = AddColumn(table_name, column)
    operation.run()
def drop_column(table_name: str, column_name: str) -> None:
    """Drop ``column_name`` from ``table_name`` (repair-mode aware; no-op in offline mode)."""
    DropColumn(table_name, column_name).run()
def alter_column(table_name: str, column_name: str, **kw: Any) -> None:
    """Alter ``column_name`` on ``table_name``; ``kw`` is forwarded to alembic's alter_column."""
    AlterColumn(table_name, column_name, **kw).run()
def create_index(index_name: str, table_name: str, columns: Sequence, **kw: Any) -> None:
    """Create ``index_name`` on ``table_name`` (repair-mode aware); ``kw`` goes to alembic."""
    CreateIndex(index_name, table_name, columns, **kw).run()
def drop_index(index_name: str, table_name: str) -> None:
    """Drop ``index_name`` from ``table_name`` (repair-mode aware)."""
    DropIndex(index_name, table_name).run()
def create_foreign_key(
    foreign_key_name: str,
    table_name: str,
    referent_table: str,
    local_cols: List[str],
    remote_cols: List[str],
    **kw: Any,
) -> None:
    """Create a foreign key from ``table_name`` to ``referent_table`` (repair-mode aware)."""
    operation = CreateForeignKey(
        foreign_key_name,
        table_name,
        referent_table,
        local_cols,
        remote_cols,
        **kw,
    )
    operation.run()
def create_unique_constraint(constraint_name: str, table_name: str, columns: List[str]) -> None:
    """Create a unique constraint over ``columns`` of ``table_name`` (repair-mode aware)."""
    operation = CreateUniqueConstraint(constraint_name, table_name, columns)
    operation.run()
def drop_constraint(constraint_name: str, table_name: str) -> None:
    """Drop ``constraint_name`` from ``table_name`` (repair-mode aware)."""
    operation = DropConstraint(constraint_name, table_name)
    operation.run()
def table_exists(table_name: str, default: bool) -> bool:
    """Report whether ``table_name`` exists.

    In offline mode there is no database to inspect, so a message is logged and
    ``default`` is returned instead.
    """
    if not context.is_offline_mode():
        return _inspector().has_table(table_name)
    _log_offline_mode_message(table_exists.__name__, default)
    return default
def column_exists(table_name: str, column_name: str, default: bool) -> bool:
    """Report whether ``table_name`` has a column named ``column_name``.

    In offline mode a message is logged and ``default`` is returned.
    """
    if context.is_offline_mode():
        _log_offline_mode_message(column_exists.__name__, default)
        return default
    for column in _inspector().get_columns(table_name):
        if column["name"] == column_name:
            return True
    return False
def index_exists(index_name: str, table_name: str, default: bool) -> bool:
    """Report whether ``table_name`` has an index named ``index_name``.

    In offline mode a message is logged and ``default`` is returned.
    """
    if context.is_offline_mode():
        _log_offline_mode_message(index_exists.__name__, default)
        return default
    for index in _inspector().get_indexes(table_name):
        if index["name"] == index_name:
            return True
    return False
def foreign_key_exists(constraint_name: str, table_name: str, default: bool) -> bool:
    """Check if a foreign key constraint exists. If running in offline mode, return default.

    Compares ``constraint_name`` against the ``name`` entries reported by the
    inspector's ``get_foreign_keys`` for ``table_name``.
    """
    if context.is_offline_mode():
        _log_offline_mode_message(foreign_key_exists.__name__, default)
        return default
    constraints = _inspector().get_foreign_keys(table_name)
    return any(c["name"] == constraint_name for c in constraints)
def unique_constraint_exists(constraint_name: str, table_name: str, default: bool) -> bool:
    """Report whether ``table_name`` has a unique constraint named ``constraint_name``.

    In offline mode a message is logged and ``default`` is returned.
    """
    if not context.is_offline_mode():
        unique_constraints = _inspector().get_unique_constraints(table_name)
        return any(constraint["name"] == constraint_name for constraint in unique_constraints)
    _log_offline_mode_message(unique_constraint_exists.__name__, default)
    return default
def _table_object_description(object_name: str, table_name: str) -> str:
    """Return a human-readable label such as ``"ix_foo on bar table"`` for log messages."""
    return "{} on {} table".format(object_name, table_name)
def _log_offline_mode_message(function_name: str, return_value: Any) -> None:
    """Log that ``function_name`` returns ``return_value`` without inspecting the database (offline mode)."""
    message = (
        "This script is being executed in offline mode, so it cannot connect to the database. "
        f"Therefore, function `{function_name}` will return the value `{return_value}`, "
        "which is the expected value during normal operation."
    )
    log.info(message)
def _inspector() -> Any:
    """Return a SQLAlchemy inspector for the bind of the current migration context."""
    return sa.inspect(op.get_context().bind)
def _is_sqlite() -> bool:
    """Return True when the migration context has a bind whose engine name is ``sqlite``."""
    connection = op.get_context().bind
    if not connection:
        return False
    return connection.engine.name == "sqlite"
@contextmanager
def legacy_alter_table():
    """
    Temporarily enable SQLite's ``legacy_alter_table`` pragma.

    Wrapper required for add/drop column statements: it prevents an error when the
    column belongs to a table referenced in a view. Relevant to SQLite only. The
    pragma is reset to 0 on exit, whether or not the wrapped block raised.

    Ref: https://github.com/sqlalchemy/alembic/issues/1207
    Ref: https://sqlite.org/pragma.html#pragma_legacy_alter_table
    """
    enable_statement = "PRAGMA legacy_alter_table=1;"
    disable_statement = "PRAGMA legacy_alter_table=0;"
    try:
        op.execute(enable_statement)
        yield
    finally:
        op.execute(disable_statement)
@contextmanager
def transaction():
    """
    Wrap the statements of an upgrade/downgrade revision function in a single
    database transaction, ensuring transactional control.

    Used for SQLite only: although SQLite supports transactional DDL, pysqlite does
    not, so ``BEGIN``/``END`` are issued explicitly. For other databases this is a
    no-op because alembic already provides the transactional context.

    If any exception escapes the wrapped block, or ``END`` itself fails, the
    transaction is rolled back and the exception is re-raised.

    Ref: https://bugs.python.org/issue10740
    """
    if not _is_sqlite():
        yield  # For other databases (e.g. postgresql), alembic ensures transactional context.
        return
    # BEGIN stays outside the try: if it fails there is no transaction to roll back.
    op.execute("BEGIN")
    try:
        yield
        op.execute("END")
    except Exception:
        # Roll back on any failure, not only OperationalError. Otherwise an
        # IntegrityError or a Python error raised inside the block would leave
        # the explicit BEGIN open.
        op.execute("ROLLBACK")
        raise