flo-token-explorer-historic/py3/lib/python3.6/site-packages/alembic/testing/fixtures.py
2021-12-09 18:16:18 +00:00

278 lines
8.1 KiB
Python

# coding: utf-8
from contextlib import contextmanager
import io
import re
from sqlalchemy import Column
from sqlalchemy import create_engine
from sqlalchemy import inspect
from sqlalchemy import MetaData
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import text
import alembic
from . import config
from . import mock
from .assertions import _get_dialect
from .assertions import eq_
from .plugin.plugin_base import SkipTest
from .. import util
from ..environment import EnvironmentContext
from ..migration import MigrationContext
from ..operations import Operations
from ..util.compat import configparser
from ..util.compat import string_types
from ..util.compat import text_type
# Module-level parser loaded once at import time from "test.cfg".  The path
# is relative, so it is resolved against the current working directory.
# NOTE(review): assuming ``configparser`` from util.compat is the stdlib
# parser, read() silently skips files it cannot open, leaving this empty.
testing_config = configparser.ConfigParser()
testing_config.read(["test.cfg"])
if util.sqla_094:
    # The flag is truthy (presumably SQLAlchemy >= 0.9.4 -- confirm in
    # alembic.util): reuse the base class shipped with SQLAlchemy itself.
    from sqlalchemy.testing.fixtures import TestBase  # noqa
else:

    # Local stand-in used when SQLAlchemy's own testing TestBase is not
    # imported; it carries the class-level markers listed below.
    class TestBase(object):
        # Database names that always run, whatever the constraints
        # declared further down say.
        __whitelist__ = ()

        # Requirement names; each matches a testing.requires decorator.
        __requires__ = ()

        # Dialect names for which this test class is excluded.
        __unsupported_on__ = ()

        # When set, the test class runs for that *single* dialect only.
        # For several dialects, use __unsupported_on__ and invert.
        __only_on__ = None

        # A sequence of no-arg callables; if any returns True the whole
        # test case is skipped.
        __skip_if__ = None

        def assert_(self, val, msg=None):
            """Assert that ``val`` is truthy, reporting ``msg`` on failure."""
            assert val, msg

        # apparently a handful of tests are doing this....OK
        def setup(self):
            """Forward to a unittest-style ``setUp`` when one is defined."""
            if not hasattr(self, "setUp"):
                return
            self.setUp()

        def teardown(self):
            """Forward to a unittest-style ``tearDown`` when one is defined."""
            if not hasattr(self, "tearDown"):
                return
            self.tearDown()
def capture_db():
    """Build a mock-strategy PostgreSQL engine that records SQL text.

    Returns an ``(engine, statements)`` tuple.  Each object handed to the
    engine's executor is compiled against the engine's own dialect and the
    resulting string is appended to ``statements``.
    """
    captured = []

    def _record(sql, *multiparams, **params):
        # ``engine`` is bound below; it is only looked up when a statement
        # is actually executed, by which time the name exists.
        compiled = sql.compile(dialect=engine.dialect)
        captured.append(str(compiled))

    engine = create_engine("postgresql://", strategy="mock", executor=_record)
    return engine, captured
# NOTE(review): not referenced anywhere in the visible part of this module;
# presumably a leftover engine cache -- confirm before removing.
_engs = {}
@contextmanager
def capture_context_buffer(**kw):
    """Patch ``EnvironmentContext.configure`` to write into a memory buffer.

    While the context manager is active, every call to
    ``EnvironmentContext.configure`` has its keyword options overridden
    with ``kw`` plus ``dialect_name="sqlite"`` and ``output_buffer=<buf>``.

    :param bytes_io: optional flag, removed from ``kw`` before forwarding;
     when true the buffer is an ``io.BytesIO``, otherwise an
     ``io.StringIO``.
    :param kw: remaining keyword options forced onto each ``configure``
     call.
    :return: yields the buffer object.
    """
    wants_bytes = kw.pop("bytes_io", False)
    buf = io.BytesIO() if wants_bytes else io.StringIO()

    kw["dialect_name"] = "sqlite"
    kw["output_buffer"] = buf

    # Keep a handle on the real implementation so the wrapper can delegate.
    original_configure = EnvironmentContext.configure

    def configure(*arg, **opt):
        # Fixture-supplied options win over whatever the caller passed.
        opt.update(kw)
        return original_configure(*arg, **opt)

    with mock.patch.object(EnvironmentContext, "configure", configure):
        yield buf
def op_fixture(
    dialect="default",
    as_sql=False,
    naming_convention=None,
    literal_binds=False,
    native_boolean=None,
):
    """Install a SQL-capturing migration context as the ``alembic.op`` proxy.

    A ``MigrationContext`` subclass is built around an in-memory line
    buffer.  In "online" mode (``as_sql`` false) a mock connection compiles
    each executed statement against the chosen dialect and writes it to the
    buffer; in ``as_sql`` mode the buffer is passed as the context's
    ``output_buffer`` and no connection is used.

    :param dialect: dialect name passed to ``_get_dialect``.  The special
     value ``"mariadb"`` selects the ``"mysql"`` dialect and sets its
     ``server_version_info`` to ``(10, 0, 0, "MariaDB")``.
    :param as_sql: when true, run in SQL-script mode.
    :param naming_convention: optional naming convention; when given, a
     ``MetaData`` using it becomes the ``target_metadata`` option.
    :param literal_binds: when true, sets the ``literal_binds`` option.
    :param native_boolean: when not ``None``, overrides the dialect's
     ``supports_native_boolean`` flag.
    :return: the context instance, which additionally offers
     ``assert_(*sql)``, ``assert_contains(sql)`` and
     ``clear_assertions()``.
    :raises SkipTest: if ``naming_convention`` is given and
     ``util.sqla_092`` is false.

    Side effect: ``alembic.op._proxy`` is replaced with an ``Operations``
    object bound to the returned context.
    """

    opts = {}
    if naming_convention:
        if not util.sqla_092:
            raise SkipTest(
                "naming_convention feature requires " "sqla 0.9.2 or greater"
            )
        opts["target_metadata"] = MetaData(naming_convention=naming_convention)

    class buffer_(object):
        """File-like sink that stores each normalized write as one line."""

        def __init__(self):
            self.lines = []

        def write(self, msg):
            msg = msg.strip()
            msg = re.sub(r"[\n\t]", "", msg)
            if as_sql:
                # the impl produces soft tabs, so remove blocks of exactly
                # four spaces.  The count is spelled out explicitly: a
                # single-space pattern would strip every space from the SQL
                # and glue all the keywords together.
                msg = re.sub(r" {4}", "", msg)
                msg = re.sub(r"\;\n*$", "", msg)

            self.lines.append(msg)

        def flush(self):
            # Nothing is buffered beyond ``lines``; present only to satisfy
            # the file-like interface.
            pass

    buf = buffer_()

    class ctx(MigrationContext):
        def clear_assertions(self):
            """Forget everything captured so far (in place)."""
            buf.lines[:] = []

        def assert_(self, *sql):
            """Assert the captured statements equal ``sql`` exactly."""
            # TODO: make this more flexible about
            # whitespace and such
            eq_(buf.lines, list(sql))

        def assert_contains(self, sql):
            """Assert that some captured statement contains ``sql``."""
            if not any(sql in stmt for stmt in buf.lines):
                raise AssertionError(
                    "Could not locate fragment %r in %r" % (sql, buf.lines)
                )

    if as_sql:
        opts["as_sql"] = as_sql
    if literal_binds:
        opts["literal_binds"] = literal_binds

    if dialect == "mariadb":
        ctx_dialect = _get_dialect("mysql")
        ctx_dialect.server_version_info = (10, 0, 0, "MariaDB")
    else:
        ctx_dialect = _get_dialect(dialect)

    if native_boolean is not None:
        ctx_dialect.supports_native_boolean = native_boolean
        # this is new as of SQLAlchemy 1.2.7 and is used by SQL Server,
        # which breaks assumptions in the alembic test suite
        ctx_dialect.non_native_boolean_check_constraint = True

    if not as_sql:

        def execute(stmt, *multiparam, **param):
            # Plain strings are wrapped so they can be compiled like any
            # other executable construct.
            if isinstance(stmt, string_types):
                stmt = text(stmt)
            assert stmt.supports_execution
            sql = text_type(stmt.compile(dialect=ctx_dialect))

            buf.write(sql)

        connection = mock.Mock(dialect=ctx_dialect, execute=execute)
    else:
        opts["output_buffer"] = buf
        connection = None

    context = ctx(ctx_dialect, connection, opts)

    alembic.op._proxy = Operations(context)
    return context
class AlterColRoundTripFixture(object):
    """Fixture that round-trips ``alter_column`` against a live database.

    ``_run_alter_col`` creates a one-column table ``x``, alters the column
    through ``Operations.alter_column``, then reflects the table again and
    checks name, type, nullability and server default of the result.
    """

    # since these tests are about syntax, use more recent SQLAlchemy as some of
    # the type / server default compare logic might not work on older
    # SQLAlchemy versions as seems to be the case for SQLAlchemy 1.1 on Oracle
    __requires__ = ("alter_column", "sqlalchemy_12")

    def setUp(self):
        """Open a connection and bind a migration context / operations."""
        self.conn = config.db.connect()
        self.ctx = MigrationContext.configure(self.conn)
        self.op = Operations(self.ctx)
        self.metadata = MetaData()

    def _compare_type(self, t1, t2):
        """Assert the impl reports types ``t1`` and ``t2`` as not different."""
        c1 = Column("q", t1)
        c2 = Column("q", t2)
        assert not self.ctx.impl.compare_type(
            c1, c2
        ), "Type objects %r and %r didn't compare as equivalent" % (t1, t2)

    def _compare_server_default(self, t1, s1, t2, s2):
        """Assert the impl reports defaults ``s1`` and ``s2`` as equivalent.

        ``t1``/``s1`` build the first column and ``t2``/``s2`` the second;
        the rendered defaults are handed to the impl in ``(s2, s1)`` order.
        """
        c1 = Column("q", t1, server_default=s1)
        c2 = Column("q", t2, server_default=s2)
        assert not self.ctx.impl.compare_server_default(
            c1, c2, s2, s1
        ), "server defaults %r and %r didn't compare as equivalent" % (s1, s2)

    def tearDown(self):
        """Drop the tables created by the test and release the connection."""
        try:
            self.metadata.drop_all(self.conn)
        finally:
            # Close even when drop_all fails, so a broken test does not
            # leak its connection into the rest of the run.
            self.conn.close()

    def _run_alter_col(self, from_, to_, compare=None):
        """Create column per ``from_``, alter it per ``to_``, verify result.

        :param from_: dict describing the initial column; recognized keys
         are ``name`` (default ``"colname"``), ``type`` (default
         ``String(10)``), ``nullable`` (default ``True``) and
         ``server_default`` (default ``None``).
        :param to_: dict of changes passed to ``alter_column``; recognized
         keys are ``name``, ``type``, ``nullable`` and ``server_default``.
         A missing ``server_default`` is sent as ``False``; the other
         missing keys are sent as ``None``.
        :param compare: dict of expected values using the same keys;
         defaults to ``to_``.  A ``server_default`` entry must expose a
         ``.text`` attribute.
        """
        column = Column(
            from_.get("name", "colname"),
            from_.get("type", String(10)),
            nullable=from_.get("nullable", True),
            server_default=from_.get("server_default", None),
            # comment=from_.get("comment", None)
        )
        t = Table("x", self.metadata, column)
        t.create(self.conn)

        # Snapshot the column as the database reports it before altering.
        insp = inspect(self.conn)
        old_col = insp.get_columns("x")[0]

        # TODO: conditional comment support
        self.op.alter_column(
            "x",
            column.name,
            existing_type=column.type,
            existing_server_default=column.server_default
            if column.server_default is not None
            else False,
            existing_nullable=bool(column.nullable),
            # existing_comment=column.comment,
            nullable=to_.get("nullable", None),
            # modify_comment=False,
            server_default=to_.get("server_default", False),
            new_column_name=to_.get("name", None),
            type_=to_.get("type", None),
        )

        # A fresh inspector is created so the altered column is reflected
        # anew rather than taken from the first inspector.
        insp = inspect(self.conn)
        new_col = insp.get_columns("x")[0]

        if compare is None:
            compare = to_

        eq_(
            new_col["name"],
            compare["name"] if "name" in compare else column.name,
        )
        self._compare_type(
            new_col["type"], compare.get("type", old_col["type"])
        )
        eq_(new_col["nullable"], compare.get("nullable", column.nullable))

        # Expected default text: the explicit expectation wins, then the
        # column's original default, otherwise no default at all.
        if "server_default" in compare:
            expected_default = compare["server_default"].text
        elif column.server_default is not None:
            expected_default = column.server_default.arg.text
        else:
            expected_default = None

        self._compare_server_default(
            new_col["type"],
            new_col.get("default", None),
            compare.get("type", old_col["type"]),
            expected_default,
        )