def upgrade():
    # Create DraftThread table
    op.create_table(
        "draftthread",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("master_public_id", mysql.BINARY(16), nullable=False),
        sa.Column("thread_id", sa.Integer()),
        sa.Column("message_id", sa.Integer()),
        sa.Column("created_at", sa.DateTime(), nullable=False),
        sa.Column("updated_at", sa.DateTime(), nullable=False),
        sa.Column("deleted_at", sa.DateTime(), nullable=True),
        sa.Column("public_id", mysql.BINARY(16), nullable=False, index=True),
        sa.ForeignKeyConstraint(["thread_id"], ["thread.id"],
                                ondelete="CASCADE"),
        sa.ForeignKeyConstraint(["message_id"], ["message.id"],
                                ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
    )

    # Add columns to SpoolMessage table
    op.add_column("spoolmessage",
                  sa.Column("parent_draft_id", sa.Integer(), nullable=True))
    op.create_foreign_key(
        "spoolmessage_ibfk_3",
        "spoolmessage",
        "spoolmessage",
        ["parent_draft_id"],
        ["id"],
    )

    op.add_column("spoolmessage",
                  sa.Column("draft_copied_from", sa.Integer(), nullable=True))
    op.create_foreign_key(
        "spoolmessage_ibfk_4",
        "spoolmessage",
        "spoolmessage",
        ["draft_copied_from"],
        ["id"],
    )

    op.add_column("spoolmessage",
                  sa.Column("replyto_thread_id", sa.Integer(), nullable=True))
    op.create_foreign_key(
        "spoolmessage_ibfk_5",
        "spoolmessage",
        "draftthread",
        ["replyto_thread_id"],
        ["id"],
    )

    op.add_column(
        "spoolmessage",
        sa.Column(
            "state",
            sa.Enum("draft", "sending", "sending failed", "sent"),
            server_default="draft",
            nullable=False,
        ),
    )
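
A matching downgrade() is not shown in this example; a minimal sketch, assuming
the constraint and column names used above, reverses the steps in order:

def downgrade():
    # Drop each foreign key before the column it covers; drop draftthread
    # last, since spoolmessage.replyto_thread_id references it.
    op.drop_column("spoolmessage", "state")
    op.drop_constraint("spoolmessage_ibfk_5", "spoolmessage",
                       type_="foreignkey")
    op.drop_column("spoolmessage", "replyto_thread_id")
    op.drop_constraint("spoolmessage_ibfk_4", "spoolmessage",
                       type_="foreignkey")
    op.drop_column("spoolmessage", "draft_copied_from")
    op.drop_constraint("spoolmessage_ibfk_3", "spoolmessage",
                       type_="foreignkey")
    op.drop_column("spoolmessage", "parent_draft_id")
    op.drop_table("draftthread")
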
def upgrade():
    op.add_column("transaction",
                  sa.Column("public_id", mysql.BINARY(16), nullable=True))
    op.add_column(
        "transaction",
        sa.Column("object_public_id", sa.String(length=191), nullable=True),
    )
    op.create_index("ix_transaction_public_id",
                    "transaction", ["public_id"],
                    unique=False)

    import sys
    from gc import collect as garbage_collect

    from sqlalchemy.ext.declarative import declarative_base

    # TODO(emfree) reflect
    from inbox.ignition import main_engine
    from inbox.models.session import session_scope
    from inbox.sqlalchemy_ext.util import b36_to_bin, generate_public_id

    engine = main_engine(pool_size=1, max_overflow=0)
    Base = declarative_base()
    Base.metadata.reflect(engine)

    class Transaction(Base):
        __table__ = Base.metadata.tables["transaction"]

    with session_scope(versioned=False) as db_session:
        count = 0
        (num_transactions,) = db_session.query(sa.func.max(
            Transaction.id)).one()
        # MAX(id) is NULL on an empty table; treat that as zero rows.
        num_transactions = num_transactions or 0
        print("Adding public ids to {} transactions".format(num_transactions))
        for pointer in range(0, num_transactions + 1, 500):
            for entry in db_session.query(Transaction).filter(
                    Transaction.id >= pointer, Transaction.id < pointer + 500):
                entry.public_id = b36_to_bin(generate_public_id())
                count += 1
                if not count % 500:
                    sys.stdout.write(".")
                    sys.stdout.flush()
                    db_session.commit()
                    garbage_collect()

    op.alter_column("transaction",
                    "public_id",
                    existing_type=mysql.BINARY(16),
                    nullable=False)

    op.add_column(
        "transaction",
        sa.Column("public_snapshot", sa.Text(length=4194304), nullable=True),
    )
    op.add_column(
        "transaction",
        sa.Column("private_snapshot", sa.Text(length=4194304), nullable=True),
    )
    op.drop_column("transaction", u"additional_data")
Example #3
def upgrade():
    op.create_table(
        "webhookparameters",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("public_id", mysql.BINARY(16), nullable=False),
        sa.Column("namespace_id", sa.Integer(), nullable=False),
        sa.Column("callback_url", sa.Text(), nullable=False),
        sa.Column("failure_notify_url", sa.Text(), nullable=True),
        sa.Column("to_addr", sa.String(length=255), nullable=True),
        sa.Column("from_addr", sa.String(length=255), nullable=True),
        sa.Column("cc_addr", sa.String(length=255), nullable=True),
        sa.Column("bcc_addr", sa.String(length=255), nullable=True),
        sa.Column("email", sa.String(length=255), nullable=True),
        sa.Column("subject", sa.String(length=255), nullable=True),
        sa.Column("thread", mysql.BINARY(16), nullable=True),
        sa.Column("filename", sa.String(length=255), nullable=True),
        sa.Column("started_before", sa.DateTime(), nullable=True),
        sa.Column("started_after", sa.DateTime(), nullable=True),
        sa.Column("last_message_before", sa.DateTime(), nullable=True),
        sa.Column("last_message_after", sa.DateTime(), nullable=True),
        sa.Column("include_body", sa.Boolean(), nullable=False),
        sa.Column("max_retries",
                  sa.Integer(),
                  server_default="3",
                  nullable=False),
        sa.Column("retry_interval",
                  sa.Integer(),
                  server_default="60",
                  nullable=False),
        sa.Column(
            "active",
            sa.Boolean(),
            server_default=sa.sql.expression.true(),
            nullable=False,
        ),
        sa.Column("min_processed_id",
                  sa.Integer(),
                  server_default="0",
                  nullable=False),
        sa.ForeignKeyConstraint(["namespace_id"], ["namespace.id"],
                                ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(
        "ix_webhookparameters_public_id",
        "webhookparameters",
        ["public_id"],
        unique=False,
    )
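
The corresponding downgrade() is omitted from this example; assuming the names
above, it would be:

def downgrade():
    # The index would also disappear with the table, but dropping it
    # explicitly keeps the migration symmetric.
    op.drop_index("ix_webhookparameters_public_id",
                  table_name="webhookparameters")
    op.drop_table("webhookparameters")
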
def upgrade():
    op.add_column('transaction',
                  sa.Column('public_id', mysql.BINARY(16), nullable=True))
    op.add_column(
        'transaction',
        sa.Column('object_public_id', sa.String(length=191), nullable=True))
    op.create_index('ix_transaction_public_id',
                    'transaction', ['public_id'],
                    unique=False)

    import sys
    from gc import collect as garbage_collect

    from sqlalchemy.ext.declarative import declarative_base

    from inbox.sqlalchemy_ext.util import generate_public_id, b36_to_bin
    # TODO(emfree) reflect
    from inbox.models.session import session_scope
    from inbox.ignition import main_engine
    engine = main_engine(pool_size=1, max_overflow=0)
    Base = declarative_base()
    Base.metadata.reflect(engine)

    class Transaction(Base):
        __table__ = Base.metadata.tables['transaction']

    with session_scope(versioned=False,
                       ignore_soft_deletes=False) as db_session:
        count = 0
        num_transactions, = db_session.query(sa.func.max(Transaction.id)).one()
        # MAX(id) is NULL on an empty table; treat that as zero rows.
        num_transactions = num_transactions or 0
        print('Adding public ids to {} transactions'.format(num_transactions))
        for pointer in range(0, num_transactions + 1, 500):
            for entry in db_session.query(Transaction).filter(
                    Transaction.id >= pointer, Transaction.id < pointer + 500):
                entry.public_id = b36_to_bin(generate_public_id())
                count += 1
                if not count % 500:
                    sys.stdout.write('.')
                    sys.stdout.flush()
                    db_session.commit()
                    garbage_collect()

    op.alter_column('transaction',
                    'public_id',
                    existing_type=mysql.BINARY(16),
                    nullable=False)

    op.add_column(
        'transaction',
        sa.Column('public_snapshot', sa.Text(length=4194304), nullable=True))
    op.add_column(
        'transaction',
        sa.Column('private_snapshot', sa.Text(length=4194304), nullable=True))
    op.drop_column('transaction', u'additional_data')
Example #5
class Block(Base):
    __tablename__ = "blocks"
    id = Column(mysql.BIGINT, primary_key=True)
    hsh = Column(mysql.BINARY(32))
    time = Column(mysql.BIGINT)
    received_time = Column(mysql.BIGINT)
    relayed_by = Column(mysql.VARBINARY(16))
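
A usage sketch for this model; the engine URL and session setup are
illustrative, not part of the original snippet:

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

engine = create_engine("mysql+pymysql://user:password@localhost/chain")  # hypothetical DSN
Session = sessionmaker(bind=engine)
session = Session()

# Ten most recently received blocks.
recent = (session.query(Block)
          .order_by(Block.received_time.desc())
          .limit(10)
          .all())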
Example #6
def upgrade():
    import sys
    from gc import collect as garbage_collect

    from inbox.server.config import load_config
    load_config()
    from inbox.sqlalchemy.util import generate_public_id
    from inbox.server.models import session_scope

    # These all inherit HasPublicID
    from inbox.server.models.tables.base import (
        Account, Block, Contact, Message, Namespace,
        SharedFolder, Thread, User, UserSession, HasPublicID)

    # Commit interval used below (value here is illustrative; the original
    # migration defines chunk_size at module scope).
    chunk_size = 500

    classes = [
        Account, Block, Contact, Message, Namespace,
        SharedFolder, Thread, User, UserSession]

    for c in classes:
        assert issubclass(c, HasPublicID)
        print('[{0}] adding public_id column... '.format(c.__tablename__),
              end='')
        sys.stdout.flush()
        op.add_column(c.__tablename__, sa.Column(
            'public_id', mysql.BINARY(16), nullable=False))

        print('adding index... ', end='')
        op.create_index(
            'ix_{0}_public_id'.format(c.__tablename__),
            c.__tablename__,
            ['public_id'],
            unique=False)

        print('Done!')
        sys.stdout.flush()

    print('Finished adding columns.\nNow generating public_ids')

    with session_scope() as db_session:
        count = 0
        for c in classes:
            garbage_collect()
            print('[{0}] Loading rows. '.format(c.__name__), end='')
            sys.stdout.flush()
            print('Generating public_ids', end='')
            sys.stdout.flush()
            for r in db_session.query(c).yield_per(chunk_size):
                count += 1
                r.public_id = generate_public_id()
                if not count % chunk_size:
                    sys.stdout.write('.')
                    sys.stdout.flush()
                    db_session.commit()
                    garbage_collect()
            sys.stdout.write(' Saving. ')
            sys.stdout.flush()
            db_session.commit()
            sys.stdout.write('Done!\n')
            sys.stdout.flush()
        print('\nUpgraded OK!\n')
Example #7
def upgrade():
    op.drop_table('blocks')
    op.create_table(
        'blocks',
        sa.Column('id', mysql.BIGINT, nullable=False),
        sa.Column('hash', mysql.BINARY(32), nullable=True),
        sa.Column('time', mysql.BIGINT, nullable=True),
        sa.Column('received_time', mysql.BIGINT, nullable=True),
        sa.Column('relayed_by', mysql.VARBINARY(16), nullable=True),
        sa.PrimaryKeyConstraint('id'))
Example #8
class TestcaseInputValue(db.Table):
  id_t = sql.Integer
  __tablename__ = "testcase_input_values"

  # Columns.
  id: int = sql.Column(id_t, primary_key=True)
  date_added: datetime.datetime = sql.Column(
    sql.DateTime().with_variant(mysql.DATETIME(fsp=3), "mysql"),
    nullable=False,
    default=labdate.GetUtcMillisecondsNow,
  )
  md5: bytes = sql.Column(
    sql.Binary(16).with_variant(mysql.BINARY(16), "mysql"),
    nullable=False,
    index=True,
    unique=True,
  )
  charcount = sql.Column(sql.Integer, nullable=False)
  linecount = sql.Column(sql.Integer, nullable=False)
  string: str = sql.Column(
    sql.UnicodeText().with_variant(sql.UnicodeText(2 ** 31), "mysql"),
    nullable=False,
  )

  # Relationships.
  inputs: typing.List[TestcaseInput] = orm.relationship(
    TestcaseInput, back_populates="value"
  )

  @classmethod
  def GetOrAdd(cls, session: db.session_t, string: str) -> "TestcaseInputValue":
    """Instantiate a TestcaseInputValue entry from a string.

    Args:
      session: A database session.
      string: The string.

    Returns:
      A TestcaseInputValue instance.
    """
    md5 = hashlib.md5()
    md5.update(string.encode("utf-8"))

    return labm8.py.sqlutil.GetOrAdd(
      session,
      cls,
      md5=md5.digest(),
      charcount=len(string),
      linecount=string.count("\n"),
      string=string,
    )

  def __repr__(self):
    return self.string[:50] or ""
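
A minimal usage sketch for GetOrAdd; the session setup is assumed to exist
elsewhere, and the names here are illustrative:

# `session` stands in for an open db.session_t.
value = TestcaseInputValue.GetOrAdd(session, "int main() { return 0; }")
session.commit()

# A second call with the same string resolves to the existing row via the
# unique md5 index rather than inserting a duplicate.
again = TestcaseInputValue.GetOrAdd(session, "int main() { return 0; }")
assert again.id == value.id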
Example #9
class Transaction(Base):
    __tablename__ = "transactions"
    id = Column(mysql.BIGINT, primary_key=True)
    hsh = Column(mysql.BINARY(32))
    block_id = Column(mysql.BIGINT)
    received_time = Column(mysql.BIGINT)
    fee = Column(mysql.BIGINT)
    total_out_value = Column(mysql.BIGINT)
    num_inputs = Column(mysql.INTEGER)
    num_outputs = Column(mysql.INTEGER)
    coinbase = Column(mysql.BIT)
    lock_time = Column(mysql.BIGINT)
    relayed_by = Column(mysql.VARBINARY(16))
Example #10
def upgrade():
    op.create_table(
        'webhookparameters',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('public_id', mysql.BINARY(16), nullable=False),
        sa.Column('namespace_id', sa.Integer(), nullable=False),
        sa.Column('callback_url', sa.Text(), nullable=False),
        sa.Column('failure_notify_url', sa.Text(), nullable=True),
        sa.Column('to_addr', sa.String(length=255), nullable=True),
        sa.Column('from_addr', sa.String(length=255), nullable=True),
        sa.Column('cc_addr', sa.String(length=255), nullable=True),
        sa.Column('bcc_addr', sa.String(length=255), nullable=True),
        sa.Column('email', sa.String(length=255), nullable=True),
        sa.Column('subject', sa.String(length=255), nullable=True),
        sa.Column('thread', mysql.BINARY(16), nullable=True),
        sa.Column('filename', sa.String(length=255), nullable=True),
        sa.Column('started_before', sa.DateTime(), nullable=True),
        sa.Column('started_after', sa.DateTime(), nullable=True),
        sa.Column('last_message_before', sa.DateTime(), nullable=True),
        sa.Column('last_message_after', sa.DateTime(), nullable=True),
        sa.Column('include_body', sa.Boolean(), nullable=False),
        sa.Column('max_retries', sa.Integer(), server_default='3',
                  nullable=False),
        sa.Column('retry_interval', sa.Integer(), server_default='60',
                  nullable=False),
        sa.Column('active', sa.Boolean(),
                  server_default=sa.sql.expression.true(),
                  nullable=False),
        sa.Column('min_processed_id', sa.Integer(), server_default='0',
                  nullable=False),
        sa.ForeignKeyConstraint(['namespace_id'], ['namespace.id'],
                                ondelete='CASCADE'),
        sa.PrimaryKeyConstraint('id')
    )
    op.create_index('ix_webhookparameters_public_id', 'webhookparameters',
                    ['public_id'], unique=False)
def load_dialect_impl(self, dialect):
    if dialect.name == 'postgresql':
        # Use a BYTEA type for postgresql.
        impl = postgresql.BYTEA(60)
    elif dialect.name == 'oracle':
        # Use a RAW type for oracle.
        impl = oracle.RAW(60)
    elif dialect.name == 'sqlite':
        # Use a BLOB type for sqlite.
        impl = sqlite.BLOB(60)
    elif dialect.name == 'mysql':
        # Use a BINARY type for mysql.
        impl = mysql.BINARY(60)
    else:
        impl = types.VARBINARY(60)
    return dialect.type_descriptor(impl)
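
This method only does something useful inside a sqlalchemy.types.TypeDecorator
subclass. A minimal sketch of the wrapper it could live in; the class name and
the VARBINARY(60) fallback are assumptions, and cache_ok requires
SQLAlchemy 1.4+:

import sqlalchemy.types as types
from sqlalchemy.dialects import mysql, oracle, postgresql, sqlite


class SixtyByteValue(types.TypeDecorator):
    """A fixed 60-byte value stored with a dialect-appropriate type."""

    impl = types.VARBINARY(60)  # generic fallback
    cache_ok = True  # SQLAlchemy 1.4+: type instance is safe to cache

    def load_dialect_impl(self, dialect):
        if dialect.name == 'postgresql':
            return dialect.type_descriptor(postgresql.BYTEA(60))
        elif dialect.name == 'oracle':
            return dialect.type_descriptor(oracle.RAW(60))
        elif dialect.name == 'sqlite':
            return dialect.type_descriptor(sqlite.BLOB(60))
        elif dialect.name == 'mysql':
            return dialect.type_descriptor(mysql.BINARY(60))
        return dialect.type_descriptor(types.VARBINARY(60))

A Column(SixtyByteValue()) then renders as BINARY(60) on MySQL, BYTEA on
PostgreSQL, RAW(60) on Oracle, and VARBINARY(60) elsewhere.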
Example #12
class Directory(Base):
    """Directory cache entry."""
    __tablename__ = "directories"

    relpath_md5: str = Column(Binary(16), primary_key=True)
    checksum: bytes = Column(sql.Binary(16).with_variant(
        mysql.BINARY(16), 'mysql'),
                             nullable=False)
    date_added: datetime.datetime = Column(DateTime,
                                           nullable=False,
                                           default=datetime.datetime.utcnow)

    def __repr__(self):
        # The original body referenced attributes (relpath, message, category)
        # that are not columns of this table; report the actual columns.
        return (f'{self.relpath_md5}:  '
                f'checksum={self.checksum}  added={self.date_added}')
Example #13
def upgrade():
    from inbox.sqlalchemy_ext.util import JSON

    op.add_column('actionlog',
                  sa.Column('extra_args', JSON(), nullable=True))

    op.add_column('message',
                  sa.Column('version', mysql.BINARY(16), nullable=True))

    from sqlalchemy.ext.declarative import declarative_base

    from inbox.ignition import main_engine
    from inbox.models.session import session_scope

    engine = main_engine(pool_size=1, max_overflow=0)
    Base = declarative_base()
    Base.metadata.reflect(engine)

    class Message(Base):
        __table__ = Base.metadata.tables['message']

    # Delete old draft versions, set message.version=public_id on the latest
    # one.
    with session_scope(ignore_soft_deletes=False, versioned=False) as \
            db_session:
        q = db_session.query(Message).filter(
            Message.is_created == True,
            Message.is_draft == True)

        for d in page_query(q):
            if d.child_draft is not None:
                db_session.delete(d)
            else:
                d.version = d.public_id
                db_session.add(d)

        db_session.commit()

    op.drop_constraint('message_ibfk_3', 'message', type_='foreignkey')
    op.drop_column('message', 'parent_draft_id')
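
page_query is a batching helper that isn't part of this snippet; a minimal
sketch of the usual pattern (the original helper may differ, and when rows are
deleted while iterating, id-windowed paging is safer than offsets):

def page_query(q, page_size=100):
    # Iterate over a query in fixed-size pages to bound memory use.
    offset = 0
    while True:
        page = q.limit(page_size).offset(offset).all()
        if not page:
            break
        for row in page:
            yield row
        offset += page_size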
Example #14
import sqlalchemy as sql
from sqlalchemy import orm
from sqlalchemy.dialects import mysql

import deeplearning.deepsmith.generator
import deeplearning.deepsmith.harness
import deeplearning.deepsmith.profiling_event
import deeplearning.deepsmith.toolchain
import phd.lib.labm8.sqlutil
from deeplearning.deepsmith import db
from deeplearning.deepsmith.proto import deepsmith_pb2
from phd.lib.labm8 import labdate, pbutil

# The index types for tables defined in this file.
_TestcaseId = sql.Integer
_TestcaseInputSetId = sql.Binary(16).with_variant(mysql.BINARY(16), 'mysql')
_TestcaseInputId = sql.Integer
_TestcaseInputNameId = db.StringTable.id_t
_TestcaseInputValueId = sql.Integer
_TestcaseInvariantOptSetId = sql.Binary(16).with_variant(
    mysql.BINARY(16), 'mysql')
_TestcaseInvariantOptId = sql.Integer
_TestcaseInvariantOptNameId = db.StringTable.id_t
_TestcaseInvariantOptValueId = db.StringTable.id_t


class Testcase(db.Table):
    """A testcase is a set of parameters for a runnable test.

  It is a tuple of <toolchain,generator,harness,inputs,invariant_opts>.
  """
Example #15
import datetime
import hashlib
import typing

import sqlalchemy as sql
from sqlalchemy import orm
from sqlalchemy.dialects import mysql

import labm8.sqlutil
from deeplearning.deepsmith import db
from deeplearning.deepsmith.proto import deepsmith_pb2
from labm8 import labdate

# The index types for tables defined in this file.
_GeneratorId = sql.Integer
_GeneratorOptSetId = sql.Binary(16).with_variant(mysql.BINARY(16), 'mysql')
_GeneratorOptId = sql.Integer
_GeneratorOptNameId = db.StringTable.id_t
_GeneratorOptValueId = db.StringTable.id_t


class Generator(db.Table):
    id_t = _GeneratorId
    __tablename__ = 'generators'

    # Columns.
    id: int = sql.Column(id_t, primary_key=True)
    date_added: datetime.datetime = sql.Column(
        sql.DateTime().with_variant(mysql.DATETIME(fsp=3), 'mysql'),
        nullable=False,
        default=labdate.GetUtcMillisecondsNow)
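
The sql.Binary(16).with_variant(mysql.BINARY(16), 'mysql') pattern used for
these id types picks a dialect-specific column type at DDL time. A small
standalone demonstration; table and column names are illustrative:

import sqlalchemy as sql
from sqlalchemy.dialects import mysql
from sqlalchemy.schema import CreateTable

metadata = sql.MetaData()
demo = sql.Table(
    'demo', metadata,
    # Generic binary type everywhere, but exactly BINARY(16) on MySQL.
    sql.Column('opt_set_id',
               sql.Binary(16).with_variant(mysql.BINARY(16), 'mysql')),
)

# Renders "opt_set_id BINARY(16)" under the MySQL dialect.
print(CreateTable(demo).compile(dialect=mysql.dialect()))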
def upgrade():
    import sys
    from gc import collect as garbage_collect

    # These all inherit HasPublicID
    from inbox.models import (
        Account,
        Block,
        Contact,
        HasPublicID,
        Message,
        Namespace,
        SharedFolder,
        Thread,
        User,
        UserSession,
    )
    from inbox.models.session import session_scope
    from inbox.sqlalchemy_ext.util import generate_public_id

    # Commit interval used below (value here is illustrative; the original
    # migration defines chunk_size at module scope).
    chunk_size = 500

    classes = [
        Account,
        Block,
        Contact,
        Message,
        Namespace,
        SharedFolder,
        Thread,
        User,
        UserSession,
    ]

    for c in classes:
        assert issubclass(c, HasPublicID)
        print("[{0}] adding public_id column... ".format(c.__tablename__),
              end="")
        sys.stdout.flush()
        op.add_column(c.__tablename__,
                      sa.Column("public_id", mysql.BINARY(16), nullable=False))

        print("adding index... ", end="")
        op.create_index(
            "ix_{0}_public_id".format(c.__tablename__),
            c.__tablename__,
            ["public_id"],
            unique=False,
        )

        print("Done!")
        sys.stdout.flush()

    print("Finished adding columns.\nNow generating public_ids")

    with session_scope() as db_session:
        count = 0
        for c in classes:
            garbage_collect()
            print("[{0}] Loading rows. ".format(c.__name__), end="")
            sys.stdout.flush()
            print("Generating public_ids", end="")
            sys.stdout.flush()
            for r in db_session.query(c).yield_per(chunk_size):
                count += 1
                r.public_id = generate_public_id()
                if not count % chunk_size:
                    sys.stdout.write(".")
                    sys.stdout.flush()
                    db_session.commit()
                    garbage_collect()
            sys.stdout.write(" Saving. ")
            sys.stdout.flush()
            db_session.commit()
            sys.stdout.write("Done!\n")
            sys.stdout.flush()
        print("\nUpgraded OK!\n")
Example #17
import typing

import sqlalchemy as sql
from sqlalchemy import orm
from sqlalchemy.dialects import mysql

import deeplearning.deepsmith.toolchain
import phd.lib.labm8.sqlutil
from deeplearning.deepsmith import db
from deeplearning.deepsmith.proto import deepsmith_pb2
from phd.lib.labm8 import labdate


# The index types for tables defined in this file.
_TestbedId = sql.Integer
_TestbedOptSetId = sql.Binary(16).with_variant(mysql.BINARY(16), 'mysql')
_TestbedOptId = sql.Integer
_TestbedOptNameId = db.StringTable.id_t
_TestbedOptValueId = db.StringTable.id_t


class Testbed(db.Table):
  """A Testbed is a system on which testcases may be run.

  Each testbed is a <toolchain,name,opts> tuple.
  """
  id_t = _TestbedId
  __tablename__ = 'testbeds'

  # Columns.
  id: int = sql.Column(id_t, primary_key=True)
def upgrade():
    easupdate = False

    print "Creating new tables and columns..."
    op.create_table(
        "folder",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("account_id", sa.Integer(), nullable=False),
        sa.Column("name",
                  sa.String(length=191, collation="utf8mb4_general_ci"),
                  nullable=True),
        sa.ForeignKeyConstraint(["account_id"], ["account.id"],
                                ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("account_id", "name"),
    )
    op.create_table(
        "internaltag",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("public_id", mysql.BINARY(16), nullable=False),
        sa.Column("namespace_id", sa.Integer(), nullable=False),
        sa.Column("name", sa.String(length=191), nullable=False),
        sa.Column("thread_id", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(["namespace_id"], ["namespace.id"],
                                ondelete="CASCADE"),
        sa.ForeignKeyConstraint(["thread_id"], ["thread.id"],
                                ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("namespace_id", "name"),
    )
    op.add_column("folderitem",
                  sa.Column("folder_id", sa.Integer(), nullable=True))
    op.create_foreign_key(
        "fk_folder_id",
        "folderitem",
        "folder",
        ["folder_id"],
        ["id"],
        ondelete="CASCADE",
    )

    op.add_column("account",
                  sa.Column("inbox_folder_id", sa.Integer, nullable=True))
    op.add_column("account",
                  sa.Column("sent_folder_id", sa.Integer, nullable=True))
    op.add_column("account",
                  sa.Column("drafts_folder_id", sa.Integer, nullable=True))
    op.add_column("account",
                  sa.Column("spam_folder_id", sa.Integer, nullable=True))
    op.add_column("account",
                  sa.Column("trash_folder_id", sa.Integer, nullable=True))
    op.add_column("account",
                  sa.Column("archive_folder_id", sa.Integer, nullable=True))
    op.add_column("account",
                  sa.Column("all_folder_id", sa.Integer, nullable=True))
    op.add_column("account",
                  sa.Column("starred_folder_id", sa.Integer, nullable=True))
    op.create_foreign_key("account_ibfk_2", "account", "folder",
                          ["inbox_folder_id"], ["id"])
    op.create_foreign_key("account_ibfk_3", "account", "folder",
                          ["sent_folder_id"], ["id"])
    op.create_foreign_key("account_ibfk_4", "account", "folder",
                          ["drafts_folder_id"], ["id"])
    op.create_foreign_key("account_ibfk_5", "account", "folder",
                          ["spam_folder_id"], ["id"])
    op.create_foreign_key("account_ibfk_6", "account", "folder",
                          ["trash_folder_id"], ["id"])
    op.create_foreign_key("account_ibfk_7", "account", "folder",
                          ["archive_folder_id"], ["id"])
    op.create_foreign_key("account_ibfk_8", "account", "folder",
                          ["all_folder_id"], ["id"])
    op.create_foreign_key("account_ibfk_9", "account", "folder",
                          ["starred_folder_id"], ["id"])

    op.add_column("imapuid", sa.Column("folder_id", sa.Integer, nullable=True))
    op.create_foreign_key("imapuid_ibfk_3", "imapuid", "folder", ["folder_id"],
                          ["id"])

    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import backref, relationship

    from inbox.ignition import main_engine
    from inbox.models.session import session_scope

    engine = main_engine(pool_size=1, max_overflow=0)

    Base = declarative_base()
    Base.metadata.reflect(engine)

    if "easuid" in Base.metadata.tables:
        easupdate = True
        print "Adding new EASUid columns..."

        op.add_column("easuid",
                      sa.Column("fld_uid", sa.Integer(), nullable=True))

        op.add_column("easuid",
                      sa.Column("folder_id", sa.Integer(), nullable=True))

        op.create_foreign_key("easuid_ibfk_3", "easuid", "folder",
                              ["folder_id"], ["id"])

        op.create_unique_constraint(
            "uq_easuid_folder_id_msg_uid_easaccount_id",
            "easuid",
            ["folder_id", "msg_uid", "easaccount_id"],
        )

        op.create_index("easuid_easaccount_id_folder_id", "easuid",
                        ["easaccount_id", "folder_id"])

    # Include our changes to the EASUid table:
    Base = declarative_base()
    Base.metadata.reflect(engine)

    class Folder(Base):
        __table__ = Base.metadata.tables["folder"]
        account = relationship("Account",
                               foreign_keys="Folder.account_id",
                               backref="folders")

    class FolderItem(Base):
        __table__ = Base.metadata.tables["folderitem"]
        folder = relationship("Folder", backref="threads", lazy="joined")

    class Thread(Base):
        __table__ = Base.metadata.tables["thread"]
        folderitems = relationship(
            "FolderItem",
            backref="thread",
            single_parent=True,
            cascade="all, delete, delete-orphan",
        )
        namespace = relationship("Namespace", backref="threads")

    class Namespace(Base):
        __table__ = Base.metadata.tables["namespace"]
        account = relationship("Account",
                               backref=backref("namespace", uselist=False))

    class Account(Base):
        __table__ = Base.metadata.tables["account"]
        inbox_folder = relationship("Folder",
                                    foreign_keys="Account.inbox_folder_id")
        sent_folder = relationship("Folder",
                                   foreign_keys="Account.sent_folder_id")
        drafts_folder = relationship("Folder",
                                     foreign_keys="Account.drafts_folder_id")
        spam_folder = relationship("Folder",
                                   foreign_keys="Account.spam_folder_id")
        trash_folder = relationship("Folder",
                                    foreign_keys="Account.trash_folder_id")
        starred_folder = relationship("Folder",
                                      foreign_keys="Account.starred_folder_id")
        archive_folder = relationship("Folder",
                                      foreign_keys="Account.archive_folder_id")
        all_folder = relationship("Folder",
                                  foreign_keys="Account.all_folder_id")

    class ImapUid(Base):
        __table__ = Base.metadata.tables["imapuid"]
        folder = relationship("Folder", backref="imapuids", lazy="joined")

    if easupdate:

        class EASUid(Base):
            __table__ = Base.metadata.tables["easuid"]
            folder = relationship(
                "Folder",
                foreign_keys="EASUid.folder_id",
                backref="easuids",
                lazy="joined",
            )

    print "Creating Folder rows and migrating FolderItems..."
    # not many folders per account, so shouldn't grow that big
    with session_scope(versioned=False) as db_session:
        folders = dict([((i.account_id, i.name), i)
                        for i in db_session.query(Folder).all()])
        count = 0
        for folderitem in (db_session.query(FolderItem).join(Thread).join(
                Namespace).yield_per(CHUNK_SIZE)):
            account_id = folderitem.thread.namespace.account_id
            if folderitem.thread.namespace.account.provider == "gmail":
                if folderitem.folder_name in folder_name_subst_map:
                    new_folder_name = folder_name_subst_map[
                        folderitem.folder_name]
                else:
                    new_folder_name = folderitem.folder_name
            elif folderitem.thread.namespace.account.provider == "eas":
                new_folder_name = folderitem.folder_name.title()

            if (account_id, new_folder_name) in folders:
                f = folders[(account_id, new_folder_name)]
            else:
                f = Folder(account_id=account_id, name=new_folder_name)
                folders[(account_id, new_folder_name)] = f
            folderitem.folder = f
            count += 1
            if count > CHUNK_SIZE:
                db_session.commit()
                count = 0
        db_session.commit()

        print "Migrating ImapUids to reference Folder rows..."
        for imapuid in db_session.query(ImapUid).yield_per(CHUNK_SIZE):
            account_id = imapuid.imapaccount_id
            if imapuid.folder_name in folder_name_subst_map:
                new_folder_name = folder_name_subst_map[imapuid.folder_name]
            else:
                new_folder_name = imapuid.folder_name
            if (account_id, new_folder_name) in folders:
                f = folders[(account_id, new_folder_name)]
            else:
                f = Folder(account_id=account_id, name=new_folder_name)
                folders[(account_id, new_folder_name)] = f
            imapuid.folder = f
            count += 1
            if count > CHUNK_SIZE:
                db_session.commit()
                count = 0
        db_session.commit()

        if easupdate:
            print "Migrating EASUids to reference Folder rows..."

            for easuid in db_session.query(EASUid).yield_per(CHUNK_SIZE):
                account_id = easuid.easaccount_id
                new_folder_name = easuid.folder_name

                if (account_id, new_folder_name) in folders:
                    f = folders[(account_id, new_folder_name)]
                else:
                    f = Folder(account_id=account_id, name=new_folder_name)
                    folders[(account_id, new_folder_name)] = f
                easuid.folder = f
                count += 1
                if count > CHUNK_SIZE:
                    db_session.commit()
                    count = 0
            db_session.commit()

        print "Migrating *_folder_name fields to reference Folder rows..."
        for account in db_session.query(Account).filter_by(provider="gmail"):
            if account.inbox_folder_name:
                # hard replace INBOX with canonicalized caps
                k = (account.id, "Inbox")
                if k in folders:
                    account.inbox_folder = folders[k]
                else:
                    account.inbox_folder = Folder(
                        account_id=account.id,
                        name=folder_name_subst_map[account.inbox_folder_name],
                    )
            if account.sent_folder_name:
                k = (account.id, account.sent_folder_name)
                if k in folders:
                    account.sent_folder = folders[k]
                else:
                    account.sent_folder = Folder(account_id=account.id,
                                                 name=account.sent_folder_name)
            if account.drafts_folder_name:
                k = (account.id, account.drafts_folder_name)
                if k in folders:
                    account.drafts_folder = folders[k]
                else:
                    account.drafts_folder = Folder(
                        account_id=account.id, name=account.drafts_folder_name)
            # all/archive mismatch is intentional; semantics have changed
            if account.archive_folder_name:
                k = (account.id, account.archive_folder_name)
                if k in folders:
                    account.all_folder = folders[k]
                else:
                    account.all_folder = Folder(
                        account_id=account.id,
                        name=account.archive_folder_name)
        db_session.commit()

        if easupdate:
            print "Migrating EAS accounts' *_folder_name fields to reference " "Folder rows..."

            for account in db_session.query(Account).filter_by(provider="eas"):
                if account.inbox_folder_name:
                    k = (account.id, account.inbox_folder_name)
                    if k in folders:
                        account.inbox_folder = folders[k]
                    else:
                        account.inbox_folder = Folder(
                            account_id=account.id,
                            name=account.inbox_folder_name)
                if account.sent_folder_name:
                    k = (account.id, account.sent_folder_name)
                    if k in folders:
                        account.sent_folder = folders[k]
                    else:
                        account.sent_folder = Folder(
                            account_id=account.id,
                            name=account.sent_folder_name)
                if account.drafts_folder_name:
                    k = (account.id, account.drafts_folder_name)
                    if k in folders:
                        account.drafts_folder = folders[k]
                    else:
                        account.drafts_folder = Folder(
                            account_id=account.id,
                            name=account.drafts_folder_name)
                if account.archive_folder_name:
                    k = (account.id, account.archive_folder_name)
                    if k in folders:
                        account.archive_folder = folders[k]
                    else:
                        account.archive_folder = Folder(
                            account_id=account.id,
                            name=account.archive_folder_name)
            db_session.commit()

    print "Final schema tweaks and new constraint enforcement"
    op.alter_column("folderitem",
                    "folder_id",
                    existing_type=sa.Integer(),
                    nullable=False)
    op.drop_constraint("folder_name", "folderitem", type_="unique")
    op.drop_constraint("folder_name", "imapuid", type_="unique")
    op.create_unique_constraint(
        "uq_imapuid_folder_id_msg_uid_imapaccount_id",
        "imapuid",
        ["folder_id", "msg_uid", "imapaccount_id"],
    )
    op.drop_column("folderitem", "folder_name")
    op.drop_column("imapuid", "folder_name")
    op.drop_column("account", "inbox_folder_name")
    op.drop_column("account", "drafts_folder_name")
    op.drop_column("account", "sent_folder_name")
    op.drop_column("account", "archive_folder_name")

    if easupdate:
        print "Dropping old EASUid columns..."

        op.drop_constraint("folder_name", "easuid", type_="unique")
        op.drop_index("easuid_easaccount_id_folder_name", "easuid")
        op.drop_column("easuid", "folder_name")
class TestEndToEnd(object):
    timeout_seconds = 60

    @pytest.fixture
    def table_name(self, replhandler):
        return '{0}_biz'.format(replhandler)

    @pytest.fixture
    def avro_schema(self, table_name):
        return {
            u'type': u'record',
            u'namespace': u'',
            u'name': table_name,
            u'pkey': [u'id'],
            u'fields': [{
                u'type': u'int',
                u'name': u'id',
                u'pkey': 1
            }, {
                u'default': None,
                u'maxlen': 64,
                u'type': [u'null', u'string'],
                u'name': u'name'
            }],
        }
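
    # NOTE: the create_table_query fixture referenced by the tests below is
    # defined elsewhere in the suite. A plausible sketch, inferred from the
    # avro_schema fixture above (the exact DDL is an assumption):
    #
    # @pytest.fixture
    # def create_table_query(self):
    #     return """CREATE TABLE {table_name}
    #     (
    #         `id` int(11) NOT NULL PRIMARY KEY,
    #         `name` varchar(64) DEFAULT NULL
    #     ) ENGINE=InnoDB DEFAULT CHARSET=utf8
    #     """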

    @pytest.fixture(params=[{
        'table_name': 'test_complex_table',
        'test_schema': [
            # test_bit
            # ColumnInfo('BIT(8)', mysql.BIT, 3),

            # test_tinyint
            ColumnInfo('TINYINT', mysql.TINYINT(), 127),
            ColumnInfo('TINYINT(3) SIGNED',
                       mysql.TINYINT(display_width=3, unsigned=False), -128),
            ColumnInfo('TINYINT(3) UNSIGNED',
                       mysql.TINYINT(display_width=3, unsigned=True), 255),
            ColumnInfo(
                'TINYINT(3) UNSIGNED ZEROFILL',
                mysql.TINYINT(display_width=3, unsigned=True, zerofill=True),
                5),
            ColumnInfo('BOOL', mysql.BOOLEAN(), 1),
            ColumnInfo('BOOLEAN', mysql.BOOLEAN(), 1),

            # test_smallint
            ColumnInfo('SMALLINT', mysql.SMALLINT(), 32767),
            ColumnInfo('SMALLINT(5) SIGNED',
                       mysql.SMALLINT(display_width=5, unsigned=False),
                       -32768),
            ColumnInfo('SMALLINT(5) UNSIGNED',
                       mysql.SMALLINT(display_width=5, unsigned=True), 65535),
            ColumnInfo(
                'SMALLINT(3) UNSIGNED ZEROFILL',
                mysql.SMALLINT(display_width=3, unsigned=True, zerofill=True),
                5),

            # test_mediumint
            ColumnInfo('MEDIUMINT', mysql.MEDIUMINT(), 8388607),
            ColumnInfo('MEDIUMINT(7) SIGNED',
                       mysql.MEDIUMINT(display_width=7, unsigned=False),
                       -8388608),
            ColumnInfo('MEDIUMINT(8) UNSIGNED',
                       mysql.MEDIUMINT(display_width=8, unsigned=True),
                       16777215),
            ColumnInfo(
                'MEDIUMINT(3) UNSIGNED ZEROFILL',
                mysql.MEDIUMINT(display_width=3, unsigned=True, zerofill=True),
                5),

            # test_int
            ColumnInfo('INT', mysql.INTEGER(), 2147483647),
            ColumnInfo('INT(10) SIGNED',
                       mysql.INTEGER(display_width=10, unsigned=False),
                       -2147483648),
            ColumnInfo('INT(11) UNSIGNED',
                       mysql.INTEGER(display_width=11, unsigned=True),
                       4294967295),
            ColumnInfo(
                'INT(3) UNSIGNED ZEROFILL',
                mysql.INTEGER(display_width=3, unsigned=True, zerofill=True),
                5),
            ColumnInfo('INTEGER(3)', mysql.INTEGER(display_width=3), 3),

            # test_bigint
            ColumnInfo('BIGINT(19)', mysql.BIGINT(display_width=19),
                       23372854775807),
            ColumnInfo('BIGINT(19) SIGNED',
                       mysql.BIGINT(display_width=19, unsigned=False),
                       -9223372036854775808),
            # ColumnInfo('BIGINT(20) UNSIGNED', mysql.INTEGER(display_width=20, unsigned=True), 18446744073709551615),
            ColumnInfo(
                'BIGINT(3) UNSIGNED ZEROFILL',
                mysql.BIGINT(display_width=3, unsigned=True, zerofill=True),
                5),

            # test_decimal
            ColumnInfo('DECIMAL(9, 2)', mysql.DECIMAL(precision=9, scale=2),
                       101.41),
            ColumnInfo('DECIMAL(12, 11) SIGNED',
                       mysql.DECIMAL(precision=12, scale=11, unsigned=False),
                       -3.14159265359),
            ColumnInfo('DECIMAL(2, 1) UNSIGNED',
                       mysql.DECIMAL(precision=2, scale=1, unsigned=True),
                       0.0),
            ColumnInfo(
                'DECIMAL(9, 2) UNSIGNED ZEROFILL',
                mysql.DECIMAL(precision=9,
                              scale=2,
                              unsigned=True,
                              zerofill=True), 5.22),
            ColumnInfo('DEC(9, 3)', mysql.DECIMAL(precision=9, scale=3),
                       5.432),
            ColumnInfo('FIXED(9, 3)', mysql.DECIMAL(precision=9, scale=3),
                       45.432),

            # test_float
            ColumnInfo('FLOAT', mysql.FLOAT(), 3.14),
            ColumnInfo('FLOAT(5, 3) SIGNED',
                       mysql.FLOAT(precision=5, scale=3, unsigned=False),
                       -2.14),
            ColumnInfo('FLOAT(5, 3) UNSIGNED',
                       mysql.FLOAT(precision=5, scale=3, unsigned=True), 2.14),
            ColumnInfo(
                'FLOAT(5, 3) UNSIGNED ZEROFILL',
                mysql.FLOAT(precision=5, scale=3, unsigned=True,
                            zerofill=True), 24.00),
            ColumnInfo('FLOAT(5)', mysql.FLOAT(5), 24.01),
            ColumnInfo('FLOAT(30)', mysql.FLOAT(30), 24.01),

            # test_double
            ColumnInfo('DOUBLE', mysql.DOUBLE(), 3.14),
            ColumnInfo('DOUBLE(5, 3) SIGNED',
                       mysql.DOUBLE(precision=5, scale=3, unsigned=False),
                       -3.14),
            ColumnInfo('DOUBLE(5, 3) UNSIGNED',
                       mysql.DOUBLE(precision=5, scale=3, unsigned=True),
                       2.14),
            ColumnInfo(
                'DOUBLE(5, 3) UNSIGNED ZEROFILL',
                mysql.DOUBLE(precision=5,
                             scale=3,
                             unsigned=True,
                             zerofill=True), 24.00),
            ColumnInfo('DOUBLE PRECISION', mysql.DOUBLE(), 3.14),
            ColumnInfo('REAL', mysql.DOUBLE(), 3.14),

            # test_date_time
            ColumnInfo('DATE', mysql.DATE(), datetime.date(1901, 1, 1)),
            ColumnInfo('DATE', mysql.DATE(), datetime.date(2050, 12, 31)),
            ColumnInfo('DATETIME', mysql.DATETIME(),
                       datetime.datetime(1970, 1, 1, 0, 0, 1, 0)),
            ColumnInfo('DATETIME', mysql.DATETIME(),
                       datetime.datetime(2038, 1, 19, 3, 14, 7, 0)),
            ColumnInfo('DATETIME(6)', mysql.DATETIME(fsp=6),
                       datetime.datetime(1970, 1, 1, 0, 0, 1, 111111)),
            ColumnInfo('DATETIME(6)', mysql.DATETIME(fsp=6),
                       datetime.datetime(2038, 1, 19, 3, 14, 7, 999999)),
            ColumnInfo('TIMESTAMP', mysql.TIMESTAMP(),
                       datetime.datetime(1970, 1, 1, 0, 0, 1, 0)),
            ColumnInfo('TIMESTAMP', mysql.TIMESTAMP(),
                       datetime.datetime(2038, 1, 19, 3, 14, 7, 0)),
            ColumnInfo('TIMESTAMP(6)', mysql.TIMESTAMP(fsp=6),
                       datetime.datetime(1970, 1, 1, 0, 0, 1, 111111)),
            ColumnInfo('TIMESTAMP(6)', mysql.TIMESTAMP(fsp=6),
                       datetime.datetime(2038, 1, 19, 3, 14, 7, 999999)),
            ColumnInfo('TIME', mysql.TIME(), datetime.timedelta(0, 0, 0)),
            ColumnInfo('TIME', mysql.TIME(),
                       datetime.timedelta(0, 23 * 3600 + 59 * 60 + 59, 0)),
            ColumnInfo('TIME(6)', mysql.TIME(fsp=6),
                       datetime.timedelta(0, 0, 111111)),
            ColumnInfo('TIME(6)', mysql.TIME(fsp=6),
                       datetime.timedelta(0, 23 * 3600 + 59 * 60 + 59,
                                          999999)),
            ColumnInfo('YEAR', mysql.YEAR(), 2000),
            ColumnInfo('YEAR(4)', mysql.YEAR(display_width=4), 2000),

            # test_char
            ColumnInfo('CHAR', mysql.CHAR(), 'a'),
            ColumnInfo('CHARACTER', mysql.CHAR(), 'a'),
            ColumnInfo('NATIONAL CHAR', mysql.CHAR(), 'a'),
            ColumnInfo('NCHAR', mysql.CHAR(), 'a'),
            ColumnInfo('CHAR(0)', mysql.CHAR(length=0), ''),
            ColumnInfo('CHAR(10)', mysql.CHAR(length=10), '1234567890'),
            ColumnInfo('VARCHAR(1000)', mysql.VARCHAR(length=1000), 'asdasdd'),
            ColumnInfo('CHARACTER VARYING(1000)', mysql.VARCHAR(length=1000),
                       'test dsafnskdf j'),
            ColumnInfo('NATIONAL VARCHAR(1000)', mysql.VARCHAR(length=1000),
                       'asdkjasd'),
            ColumnInfo('NVARCHAR(1000)', mysql.VARCHAR(length=1000),
                       'asdkjasd'),
            ColumnInfo('VARCHAR(10000)', mysql.VARCHAR(length=10000),
                       '1234567890'),

            # test_binary
            ColumnInfo('BINARY(5)', mysql.BINARY(length=5), 'hello'),
            ColumnInfo('VARBINARY(100)', mysql.VARBINARY(length=100), 'hello'),
            ColumnInfo('TINYBLOB', mysql.TINYBLOB(), 'hello'),
            ColumnInfo('TINYTEXT', mysql.TINYTEXT(), 'hello'),
            ColumnInfo('BLOB', mysql.BLOB(), 'hello'),
            ColumnInfo('BLOB(100)', mysql.BLOB(length=100), 'hello'),
            ColumnInfo('TEXT', mysql.TEXT(), 'hello'),
            ColumnInfo('TEXT(100)', mysql.TEXT(length=100), 'hello'),
            ColumnInfo('MEDIUMBLOB', mysql.MEDIUMBLOB(), 'hello'),
            ColumnInfo('MEDIUMTEXT', mysql.MEDIUMTEXT(), 'hello'),
            ColumnInfo('LONGBLOB', mysql.LONGBLOB(), 'hello'),
            ColumnInfo('LONGTEXT', mysql.LONGTEXT(), 'hello'),

            # test_enum
            ColumnInfo("ENUM('ONE', 'TWO')", mysql.ENUM(['ONE', 'TWO']),
                       'ONE'),

            # test_set
            ColumnInfo("SET('ONE', 'TWO')", mysql.SET(['ONE', 'TWO']),
                       set(['ONE', 'TWO']))
        ]
    }])
    def complex_table(self, request):
        return request.param

    @pytest.fixture
    def complex_table_name(self, replhandler, complex_table):
        return "{0}_{1}".format(replhandler, complex_table['table_name'])

    @pytest.fixture
    def complex_table_schema(self, complex_table):
        return complex_table['test_schema']

    def _build_sql_column_name(self, complex_column_name):
        return 'test_{}'.format(complex_column_name)

    def _build_complex_column_create_query(self, complex_column_name,
                                           complex_column_schema):
        return '`{0}` {1}'.format(complex_column_name, complex_column_schema)

    @pytest.fixture
    def complex_table_create_query(self, complex_table_schema):
        return ", ".join([
            self._build_complex_column_create_query(
                self._build_sql_column_name(indx), complex_column_schema.type)
            for indx, complex_column_schema in enumerate(complex_table_schema)
        ])

    @pytest.fixture
    def sqla_objs(self, complex_table_schema):
        return [
            complex_column_schema.sqla_obj
            for complex_column_schema in complex_table_schema
        ]

    @pytest.fixture
    def create_complex_table(self, containers, rbrsource, complex_table_name,
                             complex_table_create_query):
        if complex_table_create_query.strip():
            complex_table_create_query = ", {}".format(
                complex_table_create_query)
        query = """CREATE TABLE {complex_table_name}
        (
            `id` int(11) NOT NULL PRIMARY KEY
            {complex_table_create_query}
        ) ENGINE=InnoDB DEFAULT CHARSET=utf8
        """.format(complex_table_name=complex_table_name,
                   complex_table_create_query=complex_table_create_query)

        execute_query_get_one_row(containers, rbrsource, query)

    @pytest.fixture
    def ComplexModel(self, complex_table_name, create_complex_table,
                     complex_table_schema):
        class Model(Base):
            __tablename__ = complex_table_name
            id = Column('id', Integer, primary_key=True)

        for indx, complex_column_schema in enumerate(complex_table_schema):
            col_name = self._build_sql_column_name(indx)
            setattr(Model, col_name,
                    Column(col_name, complex_column_schema.sqla_obj))
        return Model

    @pytest.fixture
    def actual_complex_data(self, complex_table_schema):
        res = {'id': 1}
        for indx, complex_column_schema in enumerate(complex_table_schema):
            if isinstance(complex_column_schema.sqla_obj, mysql.DATE):
                data = complex_column_schema.data.strftime('%Y-%m-%d')
            elif isinstance(complex_column_schema.sqla_obj, mysql.DATETIME):
                data = complex_column_schema.data.strftime(
                    '%Y-%m-%d %H:%M:%S.%f')
            elif isinstance(complex_column_schema.sqla_obj, mysql.TIMESTAMP):
                data = complex_column_schema.data.strftime(
                    '%Y-%m-%d %H:%M:%S.%f')
            elif isinstance(complex_column_schema.sqla_obj, mysql.TIME):
                # Integer division: datetime.time() rejects float arguments
                # (the original relied on Python 2's integer "/").
                time = datetime.time(
                    complex_column_schema.data.seconds // 3600,
                    (complex_column_schema.data.seconds // 60) % 60,
                    complex_column_schema.data.seconds % 60,
                    complex_column_schema.data.microseconds)
                data = time.strftime('%H:%M:%S.%f')
            else:
                data = complex_column_schema.data
            res.update({self._build_sql_column_name(indx): data})
        return res

    @pytest.fixture
    def expected_complex_data(self, actual_complex_data, complex_table_schema):
        expected_complex_data_dict = {'id': 1}
        for indx, complex_column_schema in enumerate(complex_table_schema):
            column_name = self._build_sql_column_name(indx)
            if isinstance(complex_column_schema.sqla_obj, mysql.SET):
                expected_complex_data_dict[column_name] = \
                    sorted(actual_complex_data[column_name])
            elif isinstance(complex_column_schema.sqla_obj, mysql.DATETIME):
                date_time_obj = \
                    complex_column_schema.data.isoformat()
                expected_complex_data_dict[column_name] = date_time_obj
            elif isinstance(complex_column_schema.sqla_obj, mysql.TIMESTAMP):
                date_time_obj = \
                    complex_column_schema.data.replace(tzinfo=pytz.utc)
                expected_complex_data_dict[column_name] = date_time_obj
            elif isinstance(complex_column_schema.sqla_obj, mysql.TIME):
                number_of_micros = transform_timedelta_to_number_of_microseconds(
                    complex_column_schema.data)
                expected_complex_data_dict[column_name] = number_of_micros
            else:
                expected_complex_data_dict[column_name] = \
                    complex_column_schema.data
        return expected_complex_data_dict
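
    # transform_timedelta_to_number_of_microseconds is a module-level helper
    # that isn't part of this snippet; a minimal sketch of what it must do:
    #
    # def transform_timedelta_to_number_of_microseconds(td):
    #     # datetime.timedelta -> total number of microseconds, as an int.
    #     return (td.days * 24 * 3600 + td.seconds) * 10**6 + td.microseconds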

    def test_complex_table(self, containers, rbrsource, complex_table_name,
                           ComplexModel, actual_complex_data,
                           expected_complex_data, schematizer, namespace,
                           rbr_source_session, gtid_enabled):
        if not gtid_enabled:
            increment_heartbeat(containers, rbrsource)

        complex_instance = ComplexModel(**actual_complex_data)
        rbr_source_session.add(complex_instance)
        rbr_source_session.commit()
        messages = _fetch_messages(containers, schematizer, namespace,
                                   complex_table_name, 1)
        expected_messages = [
            {
                'message_type': MessageType.create,
                'payload_data': expected_complex_data
            },
        ]

        _verify_messages(messages, expected_messages)

    def test_create_table(self, containers, rbrsource, schematracker,
                          create_table_query, avro_schema, table_name,
                          namespace, schematizer, rbr_source_session,
                          gtid_enabled):
        if not gtid_enabled:
            increment_heartbeat(containers, rbrsource)
        execute_query_get_one_row(
            containers, rbrsource,
            create_table_query.format(table_name=table_name))

        # Need to poll for the creation of the table
        _wait_for_table(containers, schematracker, table_name)

        # Check the schematracker db also has the table.
        verify_create_table_query = "SHOW CREATE TABLE {table_name}".format(
            table_name=table_name)
        verify_create_table_result = execute_query_get_one_row(
            containers, schematracker, verify_create_table_query)
        expected_create_table_result = execute_query_get_one_row(
            containers, rbrsource, verify_create_table_query)
        self.assert_expected_result(verify_create_table_result,
                                    expected_create_table_result)

        # It's necessary to insert data for the topic to actually be created.
        Biz = _generate_basic_model(table_name)
        rbr_source_session.add(Biz(id=1, name='insert'))
        rbr_source_session.commit()

        _wait_for_schematizer_topic(schematizer, namespace, table_name)

        # Check schematizer.
        self.check_schematizer_has_correct_source_info(table_name=table_name,
                                                       avro_schema=avro_schema,
                                                       namespace=namespace,
                                                       schematizer=schematizer)

    def test_create_table_with_row_format(self, containers, rbrsource,
                                          schematracker, replhandler,
                                          gtid_enabled):
        table_name = '{0}_row_format_tester'.format(replhandler)
        create_table_stmt = """
        CREATE TABLE {name}
        ( id int(11) primary key) ROW_FORMAT=COMPRESSED ENGINE=InnoDB
        """.format(name=table_name)
        if not gtid_enabled:
            increment_heartbeat(containers, rbrsource)
        execute_query_get_one_row(containers, rbrsource, create_table_stmt)

        _wait_for_table(containers, schematracker, table_name)
        # Check the schematracker db also has the table.
        verify_create_table_query = "SHOW CREATE TABLE {table_name}".format(
            table_name=table_name)
        verify_create_table_result = execute_query_get_one_row(
            containers, schematracker, verify_create_table_query)
        expected_create_table_result = execute_query_get_one_row(
            containers, rbrsource, verify_create_table_query)
        self.assert_expected_result(verify_create_table_result,
                                    expected_create_table_result)

    def test_alter_table(self, containers, rbrsource, schematracker,
                         alter_table_query, table_name, gtid_enabled):
        if not gtid_enabled:
            increment_heartbeat(containers, rbrsource)
        execute_query_get_one_row(
            containers, rbrsource,
            alter_table_query.format(table_name=table_name))
        execute_query_get_one_row(
            containers, rbrsource,
            "ALTER TABLE {name} ROW_FORMAT=COMPRESSED".format(name=table_name))

        time.sleep(2)

        # Check the schematracker db also has the table.
        verify_describe_table_query = "DESCRIBE {table_name}".format(
            table_name=table_name)
        verify_alter_table_result = execute_query_get_all_rows(
            containers, schematracker, verify_describe_table_query)
        expected_alter_table_result = execute_query_get_all_rows(
            containers, rbrsource, verify_describe_table_query)

        if 'address' in verify_alter_table_result[0].values():
            actual_result = verify_alter_table_result[0]
        elif 'address' in verify_alter_table_result[1].values():
            actual_result = verify_alter_table_result[1]
        else:
            raise AssertionError('The alter table query did not succeed')

        if 'address' in expected_alter_table_result[0].values():
            expected_result = expected_alter_table_result[0]
        else:
            expected_result = expected_alter_table_result[1]

        self.assert_expected_result(actual_result, expected_result)

    def test_basic_table(self, containers, replhandler, rbrsource,
                         create_table_query, namespace, schematizer,
                         rbr_source_session, gtid_enabled):
        if not gtid_enabled:
            increment_heartbeat(containers, rbrsource)

        source = "{0}_basic_table".format(replhandler)
        execute_query_get_one_row(containers, rbrsource,
                                  create_table_query.format(table_name=source))

        BasicModel = _generate_basic_model(source)
        model_1 = BasicModel(id=1, name='insert')
        model_2 = BasicModel(id=2, name='insert')
        rbr_source_session.add(model_1)
        rbr_source_session.add(model_2)
        rbr_source_session.commit()
        model_1.name = 'update'
        rbr_source_session.delete(model_2)
        rbr_source_session.commit()

        messages = _fetch_messages(containers, schematizer, namespace, source,
                                   4)
        expected_messages = [
            {
                'message_type': MessageType.create,
                'payload_data': {
                    'id': 1,
                    'name': 'insert'
                }
            },
            {
                'message_type': MessageType.create,
                'payload_data': {
                    'id': 2,
                    'name': 'insert'
                }
            },
            {
                'message_type': MessageType.update,
                'payload_data': {
                    'id': 1,
                    'name': 'update'
                },
                'previous_payload_data': {
                    'id': 1,
                    'name': 'insert'
                }
            },
            {
                'message_type': MessageType.delete,
                'payload_data': {
                    'id': 2,
                    'name': 'insert'
                }
            },
        ]
        _verify_messages(messages, expected_messages)

    def test_table_with_contains_pii(self, containers, replhandler, rbrsource,
                                     create_table_query, namespace,
                                     schematizer, rbr_source_session,
                                     gtid_enabled):
        with reconfigure(encryption_type='AES_MODE_CBC-1',
                         key_location='acceptance/configs/data_pipeline/'):
            if not gtid_enabled:
                increment_heartbeat(containers, rbrsource)

            source = "{}_secret_table".format(replhandler)
            execute_query_get_one_row(
                containers, rbrsource,
                create_table_query.format(table_name=source))

            BasicModel = _generate_basic_model(source)
            model_1 = BasicModel(id=1, name='insert')
            model_2 = BasicModel(id=2, name='insert')
            rbr_source_session.add(model_1)
            rbr_source_session.add(model_2)
            rbr_source_session.commit()

            messages = _fetch_messages(containers, schematizer, namespace,
                                       source, 2)
            expected_messages = [{
                'message_type': MessageType.create,
                'payload_data': {
                    'id': 1,
                    'name': 'insert'
                }
            }, {
                'message_type': MessageType.create,
                'payload_data': {
                    'id': 2,
                    'name': 'insert'
                }
            }]
            _verify_messages(messages, expected_messages)

    def check_schematizer_has_correct_source_info(self, table_name,
                                                  avro_schema, namespace,
                                                  schematizer):
        sources = schematizer.get_sources_by_namespace(namespace)
        source = next(src for src in reversed(sources)
                      if src.name == table_name)
        topic = schematizer.get_topics_by_source_id(source.source_id)[-1]
        schema = schematizer.get_latest_schema_by_topic_name(topic.name)
        assert schema.topic.source.name == table_name
        assert schema.topic.source.namespace.name == namespace
        assert schema.schema_json == avro_schema

    def assert_expected_result(self, result, expected):
        # dict.items() works on Python 2 and 3; iteritems() was Python 2 only.
        for key, value in expected.items():
            assert result[key] == value