Example #1
    def test_schema_collection_remove(self):
        metadata = MetaData()

        t1 = Table('t1', metadata, Column('x', Integer), schema='foo')
        t2 = Table('t2', metadata, Column('x', Integer), schema='bar')
        t3 = Table('t3', metadata, Column('x', Integer), schema='bar')

        metadata.remove(t3)
        eq_(metadata._schemas, set(['foo', 'bar']))
        eq_(len(metadata.tables), 2)

        metadata.remove(t1)
        eq_(metadata._schemas, set(['bar']))
        eq_(len(metadata.tables), 1)
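
The eq_ helper above comes from SQLAlchemy's internal test utilities. As a rough standalone sketch of the same behavior with plain asserts: removing a table only drops its schema from MetaData._schemas once no remaining table references that schema.

from sqlalchemy import MetaData, Table, Column, Integer

metadata = MetaData()
t1 = Table('t1', metadata, Column('x', Integer), schema='foo')
t2 = Table('t2', metadata, Column('x', Integer), schema='bar')
t3 = Table('t3', metadata, Column('x', Integer), schema='bar')

metadata.remove(t3)
# 'bar' survives because t2 still uses it
assert metadata._schemas == {'foo', 'bar'}
assert len(metadata.tables) == 2

metadata.remove(t1)
# no table references 'foo' any more, so its schema entry is gone
assert metadata._schemas == {'bar'}
assert len(metadata.tables) == 1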
Example #2
def run():
    db = create_engine(engine_name, echo=True)
    db.connect()
    metadata = MetaData(db)
    insp = reflection.Inspector.from_engine(db)
    tables = []
    for table_name in insp.get_table_names():
        table = Table(table_name, metadata, autoload=True, autoload_with=db)
        if not table_name.endswith('_aud'):
            tables.append(table)
        else:
            table.drop(db)
            metadata.remove(table)

    audit_tables = []
    for t in tables:
        audit_table = create_audit_table(t)
        audit_tables.append(audit_table)

    #create_sqlite_backup_db(audit_tables)
    create_triggers(db, tables)
    metadata.create_all()
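
In newer SQLAlchemy the per-table autoload reflection used above can be folded into a single MetaData.reflect() call. A minimal sketch of the same filtering, assuming the same engine_name and '_aud' naming convention as the example:

from sqlalchemy import MetaData, create_engine

engine = create_engine(engine_name)   # engine_name as in the example above
metadata = MetaData()
metadata.reflect(bind=engine)         # reflect every table in one pass

for table in list(metadata.tables.values()):
    if table.name.endswith('_aud'):
        table.drop(engine)            # drop the audit table in the database
        metadata.remove(table)        # and forget it in the MetaData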
Example #3
def run():
    db = create_engine(engine_name, echo=True)
    db.connect()
    metadata = MetaData(db)
    insp = reflection.Inspector.from_engine(db)
    tables = []
    for table_name in insp.get_table_names():
        table = Table(table_name, metadata, autoload=True, autoload_with=db)
        if not table_name.endswith('_aud'):
            tables.append(table)
        else:
            table.drop(db)
            metadata.remove(table)        

    audit_tables = []
    for t in tables:
        audit_table = create_audit_table(t)
        audit_tables.append(audit_table)
        
    create_sqlite_backup_db(audit_tables)
    create_triggers(db, tables)
    metadata.create_all()
Example #4
def entity_model(entity, entity_only=False, with_supporting_document=False):
    """
    Creates a mapped class and corresponding relationships from an entity
    object. Entities of 'EntitySupportingDocument' type are not supported
    since they are already mapped from their parent classes, a TypeError will
    be raised.
    :param entity: Entity
    :type entity: Entity
    :param entity_only: True to only reflect the table corresponding to the
    specified entity. Remote entities and corresponding relationships will
    not be reflected.
    :type entity_only: bool
    :return: An SQLAlchemy model reflected from the table in the database
    corresponding to the specified entity object.
    """
    if entity.TYPE_INFO == 'ENTITY_SUPPORTING_DOCUMENT':
        raise TypeError('<EntitySupportingDocument> type not supported. '
                        'Please use the parent entity.')

    rf_entities = [entity.name]

    if not entity_only:
        parents = [p.name for p in entity.parents()]
        children = [c.name for c in entity.children()]
        associations = [a.name for a in entity.associations()]

        rf_entities.extend(parents)
        rf_entities.extend(children)
        rf_entities.extend(associations)

    _bind_metadata(metadata)

    #We will use a different metadata object just for reflecting 'rf_entities'
    rf_metadata = MetaData(metadata.bind)
    rf_metadata.reflect(only=rf_entities)
    # Remove supporting document tables if the entity supports them; the
    # supporting document models will be set up manually.
    ent_supporting_docs_table = None
    profile_supporting_docs_table = None

    if entity.supports_documents and not entity_only:
        ent_supporting_doc = entity.supporting_doc.name
        profile_supporting_doc = entity.profile.supporting_document.name

        ent_supporting_docs_table = rf_metadata.tables.get(
            ent_supporting_doc, None)
        profile_supporting_docs_table = rf_metadata.tables.get(
            profile_supporting_doc, None)

        #Remove the supporting doc tables from the metadata
        if ent_supporting_docs_table is not None:
            rf_metadata.remove(ent_supporting_docs_table)
        if profile_supporting_docs_table is not None:
            rf_metadata.remove(profile_supporting_docs_table)

    Base = automap_base(metadata=rf_metadata, cls=Model)
    # Supporting document model that corresponds to the primary entity.
    supporting_doc_model = None

    #Setup supporting document models
    if entity.supports_documents and not entity_only:
        supporting_doc_model = configure_supporting_documents_inheritance(
            ent_supporting_docs_table, profile_supporting_docs_table, Base,
            entity.name)

    #Set up mapped classes and relationships
    Base.prepare(
        name_for_collection_relationship=_rename_supporting_doc_collection,
        generate_relationship=_gen_relationship)

    if with_supporting_document and not entity_only:
        return getattr(Base.classes, entity.name, None), supporting_doc_model

    return getattr(Base.classes, entity.name, None)
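
The helpers above (_bind_metadata, Model, configure_supporting_documents_inheritance) are specific to the surrounding project. Stripped of those, the reflect / remove / automap pattern looks roughly like this; the connection URL and table names here are hypothetical:

from sqlalchemy import create_engine, MetaData
from sqlalchemy.ext.automap import automap_base

engine = create_engine('postgresql://localhost/stdm')      # hypothetical URL
rf_metadata = MetaData()
rf_metadata.reflect(bind=engine, only=['parcel', 'parcel_supporting_document'])

# Remove a reflected table before automapping so no class is generated for it.
doc_table = rf_metadata.tables.get('parcel_supporting_document')
if doc_table is not None:
    rf_metadata.remove(doc_table)

Base = automap_base(metadata=rf_metadata)
Base.prepare()                    # map classes from the remaining tables
Parcel = getattr(Base.classes, 'parcel', None)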
Example #5
from multiprocessing import cpu_count
from sqlalchemy import create_engine, MetaData, inspect, \
    Table, Column, Integer, Text
from sqlalchemy.dialects.postgresql import JSON, JSONB

proc_num = cpu_count()
db = create_engine(connection_string, pool_size=proc_num)
engine = db.connect()
metadata = MetaData(engine)

table_name = '_'.join([ontology_name, 'term'])
print('Table name', table_name)
inspector = inspect(engine)
if table_name in inspector.get_table_names():
    table = Table(table_name, metadata, autoload=True)
    table.drop(checkfirst=True)
    metadata.remove(table)
table = Table(table_name, metadata,
              Column('id', Text, primary_key=True, index=True, unique=True),
              Column('terms', JSON), Column('depth', Integer),
              Column('relations', JSON))

table.create(checkfirst=True)
#metadata.create_all()

#term->ont->[{id:id,terms:[terms],depth:depth,relations:{type:[adjacency_list]}}]
posts = map(
    lambda x: {
        'id':
        x,
        'terms':
        new_id_to_terms[x],
Example #6
def create_triggers(db, tables):


    db = create_engine(engine_name)
    db.echo = True
    db.connect()
    metadata = MetaData(db)


    insp = reflection.Inspector.from_engine(db)

    tables = []
    for table_name in insp.get_table_names():
        if not table_name.endswith('_aud'):
            table = Table(table_name, metadata, autoload=True, autoload_with=db)
            tables.append(table)
            #print("TABLE: %s"%table)
            #print table.__repr__
        else:
            table = Table(table_name, metadata, autoload=True, autoload_with=db)
            table.drop(db)
            metadata.remove(table)


    drop_trigger_text = """DROP TRIGGER IF EXISTS %(trigger_name)s;"""
    for table in tables:
        # Trigger names are per-table, so each trigger only needs one DROP.
        for trigger_suffix in ("_ins_trig", "_upd_trig"):
            try:
                db.execute(drop_trigger_text % {
                    'trigger_name': table.name + trigger_suffix,
                })
            except Exception:
                pass
    #metadata.create_all()

    trigger_text = """
                    CREATE TRIGGER
                        %(trigger_name)s
                    AFTER %(action)s ON
                        %(table_name)s
                    FOR EACH ROW
                        BEGIN
                            INSERT INTO %(table_name)s_aud
                            SELECT
                                d.*,
                                '%(action)s',
                                NULL,
                                NOW()
                            FROM
                                %(table_name)s
                                AS d
                            WHERE
                                %(pkd)s;
                        END
                        """

    for table in tables:


        pk_cols = [c.name for c in table.primary_key]
        pkd = []

        for pk_col in pk_cols:
            pkd.append("d.%s = NEW.%s"%(pk_col, pk_col))

        text_dict = {
            'action'       : 'INSERT',
            'trigger_name' : table.name + "_ins_trig",
            'table_name'   : table.name,
            'pkd'           : ' and '.join(pkd),
        }

        logging.info(trigger_text % text_dict)
        trig_ddl = DDL(trigger_text % text_dict)
        trig_ddl.execute_at('after-create', table.metadata)

        text_dict['action'] = 'UPDATE'
        text_dict['trigger_name'] = table.name + "_upd_trig"
        trig_ddl = DDL(trigger_text % text_dict)
        trig_ddl.execute_at('after-create', table.metadata)

    metadata.create_all()
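
DDL.execute_at was deprecated and later removed from SQLAlchemy; the event system is its replacement. Inside the loop above, each registration would look roughly like this in SQLAlchemy 0.7 and later:

from sqlalchemy import event

trig_ddl = DDL(trigger_text % text_dict)
# equivalent of trig_ddl.execute_at('after-create', table.metadata)
event.listen(table.metadata, 'after_create', trig_ddl)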
Example #7
class DBConnection(object):

    def __init__(self, connection_string='sqlite:///:memory:', echo=False):
        """Initialize a database connection."""
        self.engine = create_engine(connection_string, echo=echo)
        self.metadata = MetaData()
        self.metadata.bind = self.engine

    def get_scheme(self, rel_key):
        """Return the schema associated with a relation key."""

        table = self.metadata.tables[str(rel_key)]
        return Scheme((c.name, type_to_raco[type(c.type)])
                      for c in table.columns)

    def add_table(self, rel_key, schema, tuples=None):
        """Add a table to the database."""
        self.delete_table(rel_key, ignore_failure=True)
        assert str(rel_key) not in self.metadata.tables

        columns = [Column(n, raco_to_type[t](), nullable=False)
                   for n, t in schema.attributes]
        table = Table(str(rel_key), self.metadata, *columns)
        table.create(self.engine)
        if tuples:
            tuples = [{n: v for n, v in zip(schema.get_names(), tup)}
                      for tup in tuples]
            self.engine.execute(table.insert(), tuples)

    def append_table(self, rel_key, tuples):
        """Append tuples to an existing relation."""
        scheme = self.get_scheme(rel_key)

        table = self.metadata.tables[str(rel_key)]
        tuples = [{n: v for n, v in zip(scheme.get_names(), tup)}
                  for tup in tuples]
        if tuples:
            self.engine.execute(table.insert(), tuples)

    def num_tuples(self, rel_key):
        """Return number of tuples of rel_key """
        table = self.metadata.tables[str(rel_key)]
        return self.engine.execute(table.count()).scalar()

    def get_table(self, rel_key):
        """Retrieve the contents of a table as a bag (Counter)."""
        table = self.metadata.tables[str(rel_key)]
        s = select([table])
        return collections.Counter(tuple(t) for t in self.engine.execute(s))

    def delete_table(self, rel_key, ignore_failure=False):
        """Delete a table from the database."""
        try:
            table = self.metadata.tables[str(rel_key)]
            table.drop(self.engine)
            self.metadata.remove(table)
        except Exception:
            if not ignore_failure:
                raise

    def get_sql_output(self, sql):
        """Retrieve the result of a query as a bag (Counter)."""
        s = text(sql)
        return collections.Counter(tuple(t) for t in self.engine.execute(s))
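
Scheme, raco_to_type, and the types constants come from the surrounding raco project (see Example #14), so this is only a hypothetical usage sketch; the relation key and tuples are invented:

conn = DBConnection()                                   # in-memory SQLite
schema = Scheme([('id', types.INT_TYPE),                # raco types as in Example #14
                 ('name', types.STRING_TYPE)])
conn.add_table('public:adhoc:emp', schema,
               tuples=[(1, 'alice'), (2, 'bob')])
print(conn.num_tuples('public:adhoc:emp'))              # 2
conn.delete_table('public:adhoc:emp')                   # drop + metadata.remove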
Example #8
class Database(object):

    TYPE_MAPPINGS = {
        types.CHAR: types.Unicode,
        types.VARCHAR: types.Unicode,
        types.Enum: types.Unicode,
        mssql.base.NTEXT: types.Unicode,
        mssql.base.NVARCHAR: types.Unicode,
        mssql.base.NCHAR: types.Unicode,
        mssql.base.VARCHAR: types.Unicode,
        mssql.base.BIT: types.Boolean,
        mssql.base.UNIQUEIDENTIFIER: types.Unicode,
        mssql.base.TIMESTAMP: types.Binary,
        mssql.base.XML: types.Unicode,
        mssql.base.BINARY: types.Binary,
        mssql.base.VARBINARY: types.LargeBinary,
        mssql.base.IMAGE: types.LargeBinary,
        mssql.base.SMALLMONEY: types.Numeric,
        mssql.base.SQL_VARIANT: types.LargeBinary,
        mysql.MEDIUMBLOB: types.LargeBinary,
        mysql.LONGBLOB: types.LargeBinary,
        mysql.MEDIUMINT: types.Integer,
        mysql.BIGINT: types.BigInteger,
        mysql.MEDIUMTEXT: types.Unicode,
        mysql.TINYTEXT: types.Unicode,
        mysql.LONGTEXT: types.Unicode,
        mysql.BLOB: types.LargeBinary,
        types.BLOB: types.LargeBinary,
        types.VARBINARY: types.LargeBinary,
    }

    TYPE_BASES = (
        types.ARRAY,
        types.JSON,
        types.DateTime,
        types.BigInteger,
        types.Numeric,
        types.Float,
        types.Integer,
        types.Enum,
    )

    def __init__(self, uri):
        engine_kwargs = {'poolclass': NullPool}
        self.scheme = urlparse(uri).scheme.lower()
        # self.is_sqlite = 'sqlite' in self.scheme
        # self.is_postgres = 'postgres' in self.scheme
        # self.is_mysql = 'mysql' in self.scheme
        # self.is_mssql = 'mssql' in self.scheme

        self.uri = uri
        self.meta = MetaData()
        self.engine = create_engine(uri, **engine_kwargs)
        self.meta.bind = self.engine
        self.meta.reflect(resolve_fks=False)

    @property
    def tables(self):
        return self.meta.sorted_tables

    def count(self, table):
        return self.engine.execute(table.count()).scalar()

    def _translate_type(self, type_):
        type_ = self.TYPE_MAPPINGS.get(type_, type_)
        for base in self.TYPE_BASES:
            if issubclass(type_, base):
                type_ = base
                break
        return self.TYPE_MAPPINGS.get(type_, type_)

    def create(self, table, mapping, drop=False):
        columns = []
        for column in table.columns:
            cname = mapping.columns.get(column.name)
            ctype = self._translate_type(type(column.type))
            # not reading nullable from source:
            columns.append(Column(cname, ctype, nullable=True))
        if mapping.name in self.meta.tables:
            table = self.meta.tables[mapping.name]
            if not drop:
                return (table, True)
            log.warning("Drop existing table: %s", mapping.name)
            table.drop(self.engine)
            self.meta.remove(table)
        target_table = Table(mapping.name, self.meta, *columns)
        target_table.create(self.engine)
        return (target_table, False)

    def _convert_value(self, value, table, column):
        if isinstance(column.type, (types.DateTime, types.Date)):
            if value in ('0000-00-00 00:00:00', '0000-00-00'):
                value = None
        if isinstance(column.type, (types.String, types.Unicode)):
            if isinstance(value, str):
                value = remove_unsafe_chars(value)
        return value

    def copy(self,
             source_db,
             source_table,
             target_table,
             mapping,
             chunk_size=10000):
        conn = source_db.engine.connect()
        conn = conn.execution_options(stream_results=True)
        proxy = conn.execute(source_table.select())
        # log.info("Chunk size: %d", chunk_size)
        while True:
            rows = proxy.fetchmany(size=chunk_size)
            if not rows:
                break
            chunk = []
            for row in rows:
                item = {}
                for src_name, value in row.items():
                    target_name = mapping.columns.get(src_name)
                    column = target_table.columns[target_name]
                    value = self._convert_value(value, target_table, column)
                    item[target_name] = value
                chunk.append(item)
                yield item
            target_table.insert().execute(chunk)
Example #9
class SqliteWrapper(object):
    def __init__(self,
                 db_path=DataBaseConfig.file_path,
                 clear_db=False,
                 exist_optimize=True,
                 action_limit=True):
        from sqlalchemy import create_engine, MetaData
        from sqlalchemy.orm import sessionmaker
        self.engine = create_engine('sqlite:///{host}'.format(host=db_path),
                                    echo=False)
        self.metadata = MetaData(bind=self.engine)
        self.__table_definition__()
        if clear_db is True:
            self.clear_db()
        self.metadata.create_all(bind=self.engine, checkfirst=True)
        self.__table_mapping__()
        session = sessionmaker(bind=self.engine)
        self.session = session()
        self.__book_dict__ = None
        self.__reader_dict__ = None
        self.__event_dict__ = None
        self.__action_range__ = None
        self.__optimize_check__ = exist_optimize
        self.__action_limit__ = action_limit
        self.optimize()
        # from sqlalchemy.orm import create_session
        # self.session = create_session(bind=self.engine)
        if self.__action_limit__ is True:
            self.__init_action_limit__()

    def __table_definition__(self):
        from structures.Book import define_book_table
        from structures.Event import define_event_table
        from structures.Reader import define_reader_table
        self.user_table = define_reader_table(self.metadata)
        self.book_table = define_book_table(self.metadata)
        self.events_table = define_event_table(self.metadata)

    def __table_mapping__(self):
        from sqlalchemy.orm import mapper
        mapper(Reader, self.user_table)
        mapper(Book, self.book_table)
        mapper(Event, self.events_table)

    def optimize(self):
        if self.__optimize_check__ is True:
            self.__optimize_book_check__()
            self.__optimize_reader_check__()
            self.__optimize_event_check__()

    def __optimize_book_check__(self):
        tmp = self.session.query(Book).all()  # TODO: when the dataset is large, select only the index column
        self.__book_dict__ = {var.index: None for var in tmp}

    def __optimize_reader_check__(self):
        tmp = self.session.query(Reader).all()
        self.__reader_dict__ = {var.index: None for var in tmp}

    def __optimize_event_check__(self):
        tmp = self.session.query(Event).all()
        self.__event_dict__ = dict()
        for event in tmp:
            key = event.hashable_key
            if key not in self.__event_dict__:
                self.__event_dict__[key] = None

    def __init_action_limit__(self):
        tmp = [event.date for event in self.session.query(Event).all()]
        self.__action_range__ = (min(tmp), max(tmp))

    def add_all(self, obj):
        if len(obj) > 0:
            self.session.add_all(obj)
            self.session.commit()
            self.optimize()

    def add(self, obj):
        self.session.add(obj)
        self.session.commit()
        if self.__optimize_check__ is True:
            if isinstance(obj, Book):  # update the existing-record cache
                self.__book_dict__[obj.index] = None
            elif isinstance(obj, Reader):
                self.__reader_dict__[obj.index] = None
            elif isinstance(obj, Event):
                self.__event_dict__[obj.hashable_key] = None
            else:
                raise TypeError

    def drop_tables(self, table):
        from sqlalchemy import Table
        if not isinstance(table, Table):
            raise TypeError('table should be of type sqlalchemy.Table')
        self.metadata.remove(table)

    def clear_db(self):
        self.metadata.drop_all()

    def get_all(self, obj_type: type, **filter_by):
        """get_all all {obj_type} filter by {kwargs} -> list"""
        return self.session.query(obj_type).filter_by(**filter_by).all()

    def get_one(self, obj_type: type, **filter_by):
        return self.session.query(obj_type).filter_by(**filter_by).one()

    def exists_book(self, value: Book):
        from sqlalchemy.orm.exc import NoResultFound
        try:
            return value.index in self.__book_dict__
        except TypeError:
            try:
                self.session.query(Book).filter_by(index=value.index).one()
                return True
            except NoResultFound:
                return False

    def exists_reader(self, value: Reader):
        from sqlalchemy.orm.exc import NoResultFound
        try:
            return value.index in self.__reader_dict__
        except TypeError:
            try:
                self.session.query(Reader).filter_by(index=value.index).one()
                return True
            except NoResultFound:
                return False

    def exists_event(self, event: Event):
        from sqlalchemy.orm.exc import NoResultFound
        try:
            return event.hashable_key in self.__event_dict__
        except TypeError:
            try:
                self.session.query(Event).filter_by(
                    book_id=event.book_id,
                    reader_id=event.reader_id,
                    event_date=event.event_date,
                    event_type=event.event_type).one()
                return True
            except NoResultFound:
                return False

    def exists(self, obj):
        """wrapper.exists(obj) -> bool -- check whether obj exits in database"""
        if isinstance(obj, (list, tuple, set)):
            # a set is not indexable, so iterate rather than index into obj
            return [self.exists(item) for item in obj]
        elif isinstance(obj, Book):
            return self.exists_book(obj)
        elif isinstance(obj, Reader):
            return self.exists_reader(obj)
        elif isinstance(obj, Event):
            return self.exists_event(obj)
        else:
            raise TypeError

    def merge(self, inst, load=True):
        if isinstance(inst, (Book, Reader)):
            self.session.merge(inst, load=load)
        elif isinstance(inst, Event):
            if self.__action_limit__ is True:
                if self.__action_range__[
                        0] <= inst.date <= self.__action_range__[1]:
                    raise PermissionError(
                        'Event {} can be changed since created.'.format(
                            inst.__repr__()))
                else:
                    self.session.merge(inst, load=load)
            else:
                self.session.merge(inst, load=load)
        else:
            raise TypeError
Example #10
File: __init__.py, Project: gltn/stdm
def entity_model(entity, entity_only=False, with_supporting_document=False):
    """
    Creates a mapped class and corresponding relationships from an entity
    object. Entities of 'EntitySupportingDocument' type are not supported
    since they are already mapped from their parent classes, a TypeError will
    be raised.
    :param entity: Entity
    :type entity: Entity
    :param entity_only: True to only reflect the table corresponding to the
    specified entity. Remote entities and corresponding relationships will
    not be reflected.
    :type entity_only: bool
    :return: An SQLAlchemy model reflected from the table in the database
    corresponding to the specified entity object.
    """
    if entity.TYPE_INFO == 'ENTITY_SUPPORTING_DOCUMENT':
        raise TypeError('<EntitySupportingDocument> type not supported. '
                        'Please use the parent entity.')

    rf_entities = [entity.name]

    if not entity_only:
        parents = [p.name for p in entity.parents()]
        children = [c.name for c in entity.children()]
        associations = [a.name for a in entity.associations()]

        rf_entities.extend(parents)
        rf_entities.extend(children)
        rf_entities.extend(associations)

    _bind_metadata(metadata)

    # We will use a different metadata object just for reflecting 'rf_entities'
    rf_metadata = MetaData(metadata.bind)
    rf_metadata.reflect(only=rf_entities)

    # Remove supporting document tables if the entity supports them; the
    # supporting document models will be set up manually.
    ent_supporting_docs_table = None
    profile_supporting_docs_table = None

    if entity.supports_documents and not entity_only:
        ent_supporting_doc = entity.supporting_doc.name
        profile_supporting_doc = entity.profile.supporting_document.name

        ent_supporting_docs_table = rf_metadata.tables.get(
            ent_supporting_doc, None
        )
        profile_supporting_docs_table = rf_metadata.tables.get(
            profile_supporting_doc, None
        )

        # Remove the supporting doc tables from the metadata
        if ent_supporting_docs_table is not None:
            rf_metadata.remove(ent_supporting_docs_table)
        if profile_supporting_docs_table is not None:
            rf_metadata.remove(profile_supporting_docs_table)

    Base = automap_base(metadata=rf_metadata, cls=Model)
    # Supporting document model that corresponds to the primary entity.
    supporting_doc_model = None

    # Setup supporting document models
    if entity.supports_documents and not entity_only:
        supporting_doc_model = configure_supporting_documents_inheritance(
            ent_supporting_docs_table, profile_supporting_docs_table, Base,
            entity.name
        )

    # Set up mapped classes and relationships
    Base.prepare(
        name_for_collection_relationship=_rename_supporting_doc_collection,
        generate_relationship=_gen_relationship
    )

    if with_supporting_document and not entity_only:
        return getattr(Base.classes, entity.name, None), supporting_doc_model

    return getattr(Base.classes, entity.name, None)
Example #11
def create_triggers(db, tables):


    db = create_engine(engine_name)
    db.echo = True
    db.connect()
    metadata = MetaData(db)


    insp = reflection.Inspector.from_engine(db)

    tables = []
    for table_name in insp.get_table_names():
        if not table_name.endswith('_aud'):
            table = Table(table_name, metadata, autoload=True, autoload_with=db)
            tables.append(table)
            #print("TABLE: %s"%table)
            #print table.__repr__
        else:
            table = Table(table_name, metadata, autoload=True, autoload_with=db)
            table.drop(db)
            metadata.remove(table)        


    drop_trigger_text = """DROP TRIGGER IF EXISTS %(trigger_name)s;"""
    for table in tables:
        # Trigger names are per-table, so each trigger only needs one DROP.
        for trigger_suffix in ("_ins_trig", "_upd_trig"):
            try:
                db.execute(drop_trigger_text % {
                    'trigger_name': table.name + trigger_suffix,
                })
            except Exception:
                pass
    #metadata.create_all()

    trigger_text = """
                    CREATE TRIGGER
                        %(trigger_name)s
                    AFTER %(action)s ON
                        %(table_name)s
                    FOR EACH ROW
                        BEGIN
                            INSERT INTO %(table_name)s_aud
                            SELECT
                                d.*,
                                '%(action)s',
                                NULL,
                                date('now')
                            FROM
                                %(table_name)s
                                AS d
                            WHERE
                                %(pkd)s;
                        END
                        """
    
    for table in tables:


        pk_cols = [c.name for c in table.primary_key]
        pkd = []
        
        for pk_col in pk_cols:
            pkd.append("d.%s = NEW.%s"%(pk_col, pk_col))

        text_dict = {
            'action'       : 'INSERT',
            'trigger_name' : table.name + "_ins_trig",
            'table_name'   : table.name,
            'pkd'           : ' and '.join(pkd),
        }

        logging.info(trigger_text % text_dict)
        trig_ddl = DDL(trigger_text % text_dict)
        trig_ddl.execute_at('after-create', table.metadata)  

        text_dict['action'] = 'UPDATE'
        text_dict['trigger_name'] = table.name + "_upd_trig"
        trig_ddl = DDL(trigger_text % text_dict)
        trig_ddl.execute_at('after-create', table.metadata)  

    metadata.create_all()
Example #12
class DatabaseWrapper:
    def __init__(self, connection_name: str):
        self._engine = create_engine(connection_name, echo=False)
        self._connection = self._engine.connect()
        self._metadata = MetaData()
        self._db_name = connection_name[connection_name.rfind('/') + 1:]
        self._tables = {}

    def __del__(self):
        # __delete__ belongs to the descriptor protocol; __del__ is the
        # finalizer that runs when the wrapper is garbage-collected.
        self._connection.close()
        del self._connection
        del self._engine
        del self._metadata

    def create_table(self, table_name, columns):
        """ Create tables for database
         :param table_name name of the table to create
         :param columns list of column objects
         :param drop Drop existing values
         """
        self._tables[table_name] = Table(table_name, self._metadata, *columns)
        self._metadata.drop_all(self._engine)
        self._metadata.create_all(self._engine)
        return table_name

    def resume_table(self, table_name, primarykey_id, columns):
        self._tables[table_name] = Table(table_name, self._metadata, *columns)
        self._metadata.create_all(self._engine)
        return self.fetch_row(table_name, self.row_count(table_name),
                              primarykey_id)

    def fetch_row(self, table_name: str, row_number: int, column: int):
        if not self.has_table(table_name):
            return
        result = self._connection.execute("select * from " + table_name)
        rows = result.fetchall()
        if row_number == -1 or row_number >= len(rows):
            row_number = self.row_count(table_name) - 2
        else:
            row_number -= 1
        if column >= self.column_count(table_name):
            raise IndexError("Column index %d out of range %d" %
                             (column, len(result.keys())))
        if column == -1:
            return rows[row_number][:]
        return rows[row_number][column]

    def fetch_last(self, table_name):
        if not self.has_table(table_name):
            return
        result = self._connection.execute("select * from " + table_name)
        row = result.fetchone()
        return row

    def row_count(self, table_name) -> int:
        res = self._connection.execute("select * from " + table_name)
        rows = res.fetchall()
        return len(rows)

    def column_count(self, table_name) -> int:
        result = self._connection.execute("select * from " + table_name)
        return len(result.keys())

    def has_table(self, table_name) -> bool:
        return self._engine.has_table(table_name)

    def drop_table(self, table_name) -> None:
        self._metadata.remove(self._tables[table_name])

    def insert_values(self, table_name, values: dict):
        self._insert(table_name, values)

    def select_table(self, table_name) -> ResultProxy:
        return self._connection.execute(self._tables[table_name].select())

    def get_dbname(self):
        return self._db_name

    def _insert(self, table_name, values: dict):
        ins = self._tables[table_name].insert().values(values)
        self._connection.execute(ins)

    @staticmethod
    def generate_columns(column_infos: dict, primarykey_index):
        """ Generate database columns
        :param column_infos: dictionary mapping column name to column type
        :type column_infos: dict
        :param primarykey_index: index of the column that is the primary key
        :type primarykey_index: int
        """
        columns = []
        for index, (col_name, col_type) in enumerate(column_infos.items()):
            primary = index == primarykey_index
            columns.append(Column(col_name, col_type, primary_key=primary))
        return columns
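
A brief usage sketch of the wrapper above; the column specs and table name are invented:

from sqlalchemy import Integer, Text

db = DatabaseWrapper('sqlite:///:memory:')
cols = DatabaseWrapper.generate_columns({'id': Integer, 'name': Text},
                                        primarykey_index=0)
db.create_table('people', cols)
db.insert_values('people', {'id': 1, 'name': 'Ada'})
print(db.row_count('people'))   # 1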
Example #13
def upgrade(migrate_engine):
    """Perform sysinv database upgrade migrations (release4).
    """

    meta = MetaData()
    meta.bind = migrate_engine
    migrate_engine.connect()

    # 046_drop_iport.py
    i_port = Table('i_port', meta, autoload=True)
    i_port.drop()

    # 047_install_state.py
    i_host = Table('i_host', meta, autoload=True)
    i_host.create_column(Column('install_state', String(255)))
    i_host.create_column(Column('install_state_info', String(255)))

    # 048 Replace services enum with string (include ceph, platform)
    service_parameter = Table('service_parameter',
                              meta,
                              Column('id', Integer,
                                     primary_key=True, nullable=False),
                              mysql_engine=ENGINE, mysql_charset=CHARSET,
                              autoload=True)
    service_parameter.drop()
    meta.remove(service_parameter)
    service_parameter = Table(
        'service_parameter',
        meta,
        Column('created_at', DateTime),
        Column('updated_at', DateTime),
        Column('deleted_at', DateTime),
        Column('id', Integer, primary_key=True, nullable=False),
        Column('uuid', String(36), unique=True),
        Column('service', String(16)),
        Column('section', String(255)),
        Column('name', String(255)),
        Column('value', String(255)),
        UniqueConstraint('service', 'section', 'name',
                         name='u_servicesectionname'),
        mysql_engine=ENGINE,
        mysql_charset=CHARSET,
    )
    service_parameter.create(migrate_engine, checkfirst=False)

    # 049_add_controllerfs_scratch.py
    controller_fs = Table('controller_fs', meta, autoload=True)
    controller_fs.create_column(Column('scratch_gib', Integer))
    # 052_add_controllerfs_state.py
    controller_fs.create_column(Column('state', String(255)))

    # 050_services.py
    services = Table(
        'services',
        meta,
        Column('created_at', DateTime),
        Column('updated_at', DateTime),
        Column('deleted_at', DateTime),

        Column('id', Integer, primary_key=True, ),

        Column('name', String(255), nullable=False),
        Column('enabled', Boolean, default=False),

        mysql_engine=ENGINE,
        mysql_charset=CHARSET,
    )
    services.create()
    iservicegroup = Table('i_servicegroup', meta, autoload=True)
    iservicegroup.drop()

    # 051_mtce.py Enhance the services enum to include platform;
    # String per 048

    # 053_add_virtual_interface.py
    Table('interfaces', meta, autoload=True)

    virtual_interfaces = Table(
        'virtual_interfaces',
        meta,
        Column('created_at', DateTime),
        Column('updated_at', DateTime),
        Column('deleted_at', DateTime),
        Column('id', Integer, ForeignKey('interfaces.id',
                                         ondelete="CASCADE"),
               primary_key=True, nullable=False),

        Column('imac', String(255)),
        Column('imtu', Integer),
        Column('providernetworks', String(255)),
        Column('providernetworksdict', Text),

        mysql_engine=ENGINE,
        mysql_charset=CHARSET,
    )
    virtual_interfaces.create()

    # 054_system_mode.py
    systems = Table('i_system', meta, autoload=True)
    systems.create_column(Column('system_mode', String(255)))
    _populate_system_mode(systems)

    # 055_tpmconfig.py Seed HTTPS disabled capability in i_system table
    # only one system entry should be populated
    # use isnot(None) so the comparison compiles to SQL "IS NOT NULL";
    # a plain "is not None" is evaluated in Python and is always True
    sys = list(systems.select().where(
        systems.c.uuid.isnot(None)).execute())
    if len(sys) > 0:
        json_dict = json.loads(sys[0].capabilities)
        json_dict['https_enabled'] = 'n'
        systems.update().where(
            systems.c.uuid == sys[0].uuid).values(
            {'capabilities': json.dumps(json_dict)}).execute()

    # Add tpmconfig DB table
    tpmconfig = Table(
        'tpmconfig',
        meta,
        Column('created_at', DateTime),
        Column('updated_at', DateTime),
        Column('deleted_at', DateTime),

        Column('id', Integer, primary_key=True, nullable=False),
        Column('uuid', String(36), unique=True),

        Column('tpm_path', String(255)),

        mysql_engine=ENGINE,
        mysql_charset=CHARSET,
    )
    tpmconfig.create()

    # Add tpmdevice DB table
    tpmdevice = Table(
        'tpmdevice',
        meta,
        Column('created_at', DateTime),
        Column('updated_at', DateTime),
        Column('deleted_at', DateTime),

        Column('id', Integer, primary_key=True, nullable=False),
        Column('uuid', String(36), unique=True),

        Column('state', String(255)),
        Column('host_id', Integer,
               ForeignKey('i_host.id', ondelete='CASCADE')),

        mysql_engine=ENGINE,
        mysql_charset=CHARSET,
    )
    tpmdevice.create()

    # 056_ipv_add_failed_status.py
    # Enhance the pv_state enum to include 'failed'
    if migrate_engine.url.get_dialect() is postgresql.dialect:
        i_pv = Table('i_pv',
                     meta,
                     Column('id', Integer, primary_key=True, nullable=False),
                     mysql_engine=ENGINE, mysql_charset=CHARSET,
                     autoload=True)

        migrate_engine.execute('ALTER TABLE i_pv DROP CONSTRAINT "pvStateEnum"')
        # In 16.10, as DB changes by PATCH are not supported, we use 'reserve1' instead of
        # 'failed'. Therefore, even though upgrades with PVs in 'failed' state should not
        # be allowed, we still have to guard against them by converting 'reserve1' to
        # 'failed' everywhere.
        LOG.info("Migrate pv_state")
        migrate_engine.execute('UPDATE i_pv SET pv_state=\'failed\' WHERE pv_state=\'reserve1\'')

        pv_state_col = i_pv.c.pv_state
        pv_state_col.alter(Column('pv_state', String(32)))

    # 057_idisk_id_path_wwn.py
    i_idisk = Table('i_idisk', meta, autoload=True)

    # Add the columns for persistently identifying devices.
    i_idisk.create_column(Column('device_id', String(255)))
    i_idisk.create_column(Column('device_path', String(255)))
    i_idisk.create_column(Column('device_wwn', String(255)))

    # Remove the device_node unique constraint and add a unique constraint for
    # device_path.
    UniqueConstraint('device_node', 'forihostid', table=i_idisk,
                     name='u_devhost').drop()
    UniqueConstraint('device_path', 'forihostid', table=i_idisk,
                     name='u_devhost').create()

    # 058_system_timezone.py
    systems.create_column(Column('timezone', String(255)))
    _populate_system_timezone(systems)

    # 059 N/A

    # 060_disk_device_path.py
    i_pv = Table('i_pv', meta, autoload=True)
    ceph_mon = Table('ceph_mon', meta, autoload=True)
    journal_table = Table('journal', meta, autoload=True)
    storage_lvm = Table('storage_lvm', meta, autoload=True)
    # Update the i_pv table.
    i_pv.create_column(Column('idisk_device_path', String(255)))
    # Update the ceph_mon table.
    col_resource = getattr(ceph_mon.c, 'device_node')
    col_resource.alter(name='device_path')
    _update_ceph_mon_device_path(ceph_mon)
    # Update the journal table.
    col_resource = getattr(journal_table.c, 'device_node')
    col_resource.alter(name='device_path')
    # Update the storage_lvm table.
    _update_storage_lvm_device_path(storage_lvm)

    # 062_iscsi_initiator_name.py
    i_host = Table('i_host', meta, autoload=True)
    i_host.create_column(Column('iscsi_initiator_name', String(64)))
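
Step 048 above is the general recipe for replacing a table definition under the same name: drop it, remove it from the MetaData so the name is free, then define and create the replacement. Condensed with a toy engine and table:

from sqlalchemy import MetaData, Table, Column, Integer, String, create_engine

engine = create_engine('sqlite://')        # toy engine for illustration
meta = MetaData()
old = Table('service_parameter', meta, Column('id', Integer, primary_key=True))
old.create(engine)

old.drop(engine)    # remove the table from the database...
meta.remove(old)    # ...and from the MetaData, freeing the name
new = Table('service_parameter', meta,
            Column('id', Integer, primary_key=True),
            Column('service', String(16)))
new.create(engine)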
Example #14
class DBConnection(object):

    def __init__(self, connection_string='sqlite:///:memory:', echo=False):
        """Initialize a database connection."""
        self.engine = create_engine(connection_string, echo=echo)
        self.metadata = MetaData()
        self.metadata.bind = self.engine
        self.__add_function_registry__()

    def __add_function_registry__(self):
        functions_schema = Scheme([("name", types.STRING_TYPE),
                                   ("description", types.STRING_TYPE),
                                   ("outputType", types.STRING_TYPE),
                                   ("lang", types.INT_TYPE),
                                   ("binary", types.BLOB_TYPE)])

        columns = [Column(n, raco_to_type[t](), nullable=False)
                   for n, t in functions_schema.attributes]
        table = Table("registered_functions", self.metadata, *columns)
        table.create(self.engine)

    def get_scheme(self, rel_key):
        """Return the schema associated with a relation key."""

        table = self.metadata.tables[str(rel_key)]
        return Scheme((c.name, type_to_raco[type(c.type)])
                      for c in table.columns)

    def add_table(self, rel_key, schema, tuples=None):
        """Add a table to the database."""
        self.delete_table(rel_key, ignore_failure=True)
        assert str(rel_key) not in self.metadata.tables

        columns = [Column(n, raco_to_type[t](), nullable=False)
                   for n, t in schema.attributes]
        table = Table(str(rel_key), self.metadata, *columns)
        table.create(self.engine)
        if tuples:
            tuples = [{n: v for n, v in zip(schema.get_names(), tup)}
                      for tup in tuples]
            self.engine.execute(table.insert(), tuples)

    def append_table(self, rel_key, tuples):
        """Append tuples to an existing relation."""
        scheme = self.get_scheme(rel_key)

        table = self.metadata.tables[str(rel_key)]
        tuples = [{n: v for n, v in zip(scheme.get_names(), tup)}
                  for tup in tuples]
        if tuples:
            self.engine.execute(table.insert(), tuples)

    def num_tuples(self, rel_key):
        """Return number of tuples of rel_key """
        table = self.metadata.tables[str(rel_key)]
        return self.engine.execute(table.count()).scalar()

    def get_table(self, rel_key):
        """Retrieve the contents of a table as a bag (Counter)."""
        table = self.metadata.tables[str(rel_key)]
        s = select([table])
        return collections.Counter(tuple(t) for t in self.engine.execute(s))

    def delete_table(self, rel_key, ignore_failure=False):
        """Delete a table from the database."""
        try:
            table = self.metadata.tables[str(rel_key)]
            table.drop(self.engine)
            self.metadata.remove(table)
        except Exception:
            if not ignore_failure:
                raise

    def get_sql_output(self, sql):
        """Retrieve the result of a query as a bag (Counter)."""
        s = text(sql)
        return collections.Counter(tuple(t) for t in self.engine.execute(s))

    def get_function(self, name):
        """Retrieve a function from catalog."""
        # bind the name as a parameter; naive string concatenation produces
        # invalid (and injectable) SQL for string-valued names
        s = text("select * from registered_functions where name = :name")
        return dict(self.engine.execute(s, name=str(name)).first())

    def register_function(self, tup):
        """Register a function in the catalog."""
        table = self.metadata.tables['registered_functions']
        scheme = self.get_scheme('registered_functions')
        func = [{n: v for n, v in zip(scheme.get_names(), tup)}]
        self.engine.execute(table.insert(), func)
Example #15
try:
    e = create_engine(db_conn_string)
    m = MetaData()
    if args.schema is not None:
        m.reflect(bind=e, schema=args.schema)
    else:
        m.reflect(bind=e)
    tables_in_db = m.tables.keys()
except Exception as e:
    print('Problems accessing database: %s' % e)
    sys.exit(1)

# remove unwanted tables
if args.exclude is not None:
    if isinstance(args.exclude, list):
        for t in args.exclude:
            m.remove(m.tables[t])
    else:
        m.remove(m.tables[args.exclude])

# generate the schema graph
graph = create_schema_graph(
    tables=[m.tables[x] for x in list(m.tables.keys())],
    show_datatypes=False,
    show_indexes=False,
    rankdir='TB',
    concentrate=True,
)

# Write out graph to the corresponding file
if args.file is not None:
    # get file extension
Example #16
from sqlalchemy import Table

class Sqlite(object):
    connection_dict = dict()

    def __init__(self, db_path: str = ':memory:'):
        from sqlalchemy import create_engine, MetaData

        # establish the connection
        self.engine = create_engine('sqlite:///{host}'.format(host=db_path),
                                    echo=False)
        self.metadata = MetaData(bind=self.engine)
        self.__session__ = None

        # create the target tables if they do not exist
        # self.metadata.create_all(bind=self.engine, checkfirst=True)

    @classmethod
    def new(cls, db_path: str = ':memory:'):
        if db_path in cls.connection_dict:
            return cls.connection_dict[db_path]
        inst = cls(db_path=db_path)
        cls.connection_dict[db_path] = inst  # cache the connection for reuse
        return inst

    @property
    def session(self):
        from sqlalchemy.orm import Session
        if self.__session__ is None:
            from sqlalchemy.orm import sessionmaker
            # from sqlalchemy.orm import create_session
            # self.session = create_session(bind=self.engine)
            ses = sessionmaker(bind=self.engine)
            self.__session__ = ses()
        assert isinstance(self.__session__, Session)
        return self.__session__

    def execute(self, clause: str, params=None):
        self.session.execute(
            clause,
            params=params,
        )
        self.session.flush()

    @staticmethod
    def map(obj_type: type, table_def: Table, create_table: bool = True):
        from sqlalchemy.exc import ArgumentError
        from sqlalchemy.orm import mapper
        try:
            mapper(obj_type, table_def)
            table_def.create(checkfirst=create_table)
        except ArgumentError:
            pass

    def add_all(self, obj):
        if len(obj) > 0:
            self.session.add_all(obj)
            self.session.commit()

    def add(self, obj):
        self.session.add(obj)
        self.session.commit()

    def delete_table(self, table: Table):
        self.metadata.remove(table)

    def clean(self):
        self.metadata.drop_all()

    def close(self):
        self.session.flush()
        self.session.close()
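
A rough usage sketch of the class above; the record class and table definition are hypothetical:

from sqlalchemy import Column, Integer, Text

db = Sqlite.new(':memory:')
book_table = Table('book', db.metadata,
                   Column('index', Integer, primary_key=True),
                   Column('title', Text))

class Book(object):     # plain class for classical mapping
    pass

Sqlite.map(Book, book_table)   # mapper() + create(checkfirst=True)

b = Book()
b.index, b.title = 1, 'Dune'
db.add(b)                      # insert and commit
print(db.get_all(Book, title='Dune'))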