Example #1
0
    def register(cls):
        """
        Register ``cls`` as an entity schema when it subclasses ``EntityMixin``.

        The entity type is the ``entity_type`` captured from the enclosing
        scope, or the lower-cased class name when that value is empty. The same
        resolved type is used as the key for every registry updated here.

        :param cls: the class being decorated
        :return: ``cls`` unchanged, so this works as a class decorator
        """
        # register the entity
        if issubclass(cls, EntityMixin):
            entity_type_ = entity_type or cls.__name__.lower()

            if entity_type_ not in zvt_context.entity_types:
                zvt_context.entity_types.append(entity_type_)
            zvt_context.entity_schema_map[entity_type_] = cls

            # Use the resolved type here as well: the raw ``entity_type`` may be
            # empty, which would file every unnamed entity under the same key.
            add_to_map_list(the_map=zvt_context.entity_map_schemas, key=entity_type_, value=cls)
        return cls
Example #2
0
def register_schema(regions: List[Region],
                    providers: Dict[(Region, List[Provider])],
                    db_name: str,
                    schema_base: DeclarativeMeta,
                    entity_type: EntityType = EntityType.Stock):
    """
    Register the schemas declared on ``schema_base``; declare them before calling this.

    For every region this:

    * registers the region's providers on each mapped ``Mixin`` class;
    * records each mapped class in ``zvt_context``, once per call even when
      several regions are given;
    * for every (region, provider) whose engine is available, creates the
      tables and binds the session factory to that engine;
    * creates any missing single-column and composite indexes.

    :param regions: the regions to register the schema for
    :type regions: List[Region]
    :param providers: the supported providers for the schema, keyed by region
    :type providers: Dict[Region, List[Provider]]
    :param db_name: database name for the schema
    :type db_name: str
    :param schema_base: declarative base holding the schema classes
    :type schema_base: DeclarativeMeta
    :param entity_type: the schema related entity_type
    :type entity_type: EntityType
    :return: None
    """
    # Extend the list already recorded for this db instead of replacing it,
    # so a call that finds no classes does not wipe earlier registrations.
    schemas = zvt_context.dbname_map_schemas.get(db_name) or []
    # Classes recorded by this call; without this each class would be
    # appended once per region.
    registered = set()

    for region in regions:
        for cls in schema_base._decl_class_registry.values():
            # skip registry entries that are not declarative classes
            if type(cls) is not DeclarativeMeta:
                continue

            # register provider to the schema
            for provider in providers[region]:
                if issubclass(cls, Mixin):
                    cls.register_provider(region, provider)

            if cls in registered:
                continue
            registered.add(cls)
            zvt_context.schemas.append(cls)
            add_to_map_list(the_map=zvt_context.entity_map_schemas,
                            key=entity_type,
                            value=cls)
            schemas.append(cls)

        for provider in providers[region]:
            # track the provider in zvt_context.providers
            region_providers = zvt_context.providers.setdefault(region, [])
            if provider not in region_providers:
                region_providers.append(provider)

            zvt_context.provider_map_dbnames.setdefault(provider, []).append(db_name)
            zvt_context.dbname_map_base[db_name] = schema_base

            # create the db & table
            engine = get_db_engine(region, provider, db_name=db_name)
            if engine is None:
                continue
            schema_base.metadata.create_all(engine)

            session_fac = get_db_session_factory(region,
                                                 provider,
                                                 db_name=db_name)
            session_fac.configure(bind=engine)

        for provider in providers[region]:
            engine = get_db_engine(region, provider, db_name=db_name)
            if engine is None:
                continue
            inspector = Inspector.from_engine(engine)

            # create single-column indexes for 'id','timestamp','entity_id','code',
            # 'report_period','created_timestamp','updated_timestamp', plus the
            # composite (timestamp, entity_id) / (timestamp, code) indexes
            for table_name, table in schema_base.metadata.tables.items():
                index_column_names = [
                    index['name']
                    for index in inspector.get_indexes(table_name)
                ]

                logger.debug('engine:{},table:{},index:{}'.format(
                    engine, table_name, index_column_names))

                for col in ('id', 'timestamp', 'entity_id', 'code',
                            'report_period', 'created_timestamp',
                            'updated_timestamp'):
                    if col not in table.c:
                        continue
                    index_name = '{}_{}_index'.format(table_name, col)
                    if index_name not in index_column_names:
                        sqlalchemy.schema.Index(index_name, table.c[col]).create(engine)

                for col0, col1 in (('timestamp', 'entity_id'),
                                   ('timestamp', 'code')):
                    if col0 not in table.c or col1 not in table.c:
                        continue
                    index_name = '{}_{}_{}_index'.format(table_name, col0, col1)
                    if index_name not in index_column_names:
                        sqlalchemy.schema.Index(
                            index_name, table.c[col0], table.c[col1]).create(engine)

    zvt_context.dbname_map_schemas[db_name] = schemas
Example #3
0
def register_schema(providers: List[str],
                    db_name: str,
                    schema_base: DeclarativeMeta,
                    entity_type: str = None):
    """
    Register the schemas declared on ``schema_base``; declare them before calling this.

    This:

    * registers ``providers`` on every mapped ``Mixin`` class;
    * records the mapped classes in ``zvt_context``;
    * creates the tables and binds the session factory for each provider;
    * creates any missing single-column and composite indexes.

    Existing indexes are read with ``PRAGMA INDEX_LIST``, so the engines
    are assumed to be SQLite.

    :param providers: the supported providers for the schema
    :type providers: List[str]
    :param db_name: database name for the schema
    :type db_name: str
    :param schema_base: declarative base holding the schema classes
    :type schema_base: DeclarativeMeta
    :param entity_type: the schema related entity_type; when empty the classes
        are not added to ``zvt_context.entity_map_schemas``
    :type entity_type: str
    :return: None
    """
    # Extend the list already recorded for this db instead of replacing it,
    # so a call that finds no classes does not wipe earlier registrations.
    schemas = zvt_context.dbname_map_schemas.get(db_name) or []
    for mapper in schema_base.registry.mappers:
        cls = mapper.class_
        if type(cls) is not DeclarativeMeta:
            continue

        # register provider to the schema
        if issubclass(cls, Mixin):
            for provider in providers:
                cls.register_provider(provider)

        zvt_context.schemas.append(cls)
        if entity_type:
            add_to_map_list(the_map=zvt_context.entity_map_schemas,
                            key=entity_type,
                            value=cls)
        schemas.append(cls)

    zvt_context.dbname_map_schemas[db_name] = schemas

    for provider in providers:
        if provider not in zvt_context.providers:
            zvt_context.providers.append(provider)

        zvt_context.provider_map_dbnames.setdefault(provider, []).append(db_name)
        zvt_context.dbname_map_base[db_name] = schema_base

        # create the db & table
        engine = get_db_engine(provider, db_name=db_name)
        schema_base.metadata.create_all(engine)

        session_fac = get_db_session_factory(provider, db_name=db_name)
        session_fac.configure(bind=engine)

    for provider in providers:
        engine = get_db_engine(provider, db_name=db_name)

        # create single-column indexes for 'timestamp','entity_id','code',
        # 'report_period','created_timestamp','updated_timestamp', plus the
        # composite (timestamp, entity_id) / (timestamp, code) indexes
        for table_name, table in schema_base.metadata.tables.items():
            # SQLite-specific: column 1 of each PRAGMA INDEX_LIST row is the
            # index name. Wrapped in text() because plain-string execution is
            # deprecated in SQLAlchemy 1.4 and rejected in 2.0.
            with engine.connect() as con:
                rs = con.execute(sqlalchemy.text(
                    "PRAGMA INDEX_LIST('{}')".format(table_name)))
                index_list = [row[1] for row in rs]

            logger.debug('engine:{},table:{},index:{}'.format(
                engine, table_name, index_list))

            for col in ('timestamp', 'entity_id', 'code', 'report_period',
                        'created_timestamp', 'updated_timestamp'):
                if col not in table.c:
                    continue
                index_name = '{}_{}_index'.format(table_name, col)
                if index_name not in index_list:
                    sqlalchemy.schema.Index(index_name, table.c[col]).create(engine)

            for col0, col1 in (('timestamp', 'entity_id'), ('timestamp', 'code')):
                if col0 not in table.c or col1 not in table.c:
                    continue
                index_name = '{}_{}_{}_index'.format(table_name, col0, col1)
                if index_name not in index_list:
                    sqlalchemy.schema.Index(
                        index_name, table.c[col0], table.c[col1]).create(engine)
Example #4
0
def register_schema(regions: List[Region],
                    providers: Dict[(Region, List[Provider])],
                    db_name: str,
                    schema_base: DeclarativeMeta,
                    entity_type: EntityType = None):
    """
    Register the schemas declared on ``schema_base``; declare them before calling this.

    For every region this:

    * registers the region's providers on each mapped ``Mixin`` class;
    * records each mapped class, once per call even when several regions are
      given;
    * for regions with an available engine, tracks the providers and binds
      their session factories to that engine;
    * creates any missing single-column and composite indexes for tables not
      yet listed in ``dbname_map_index[region]``.

    :param regions: the regions to register the schema for
    :type regions: List[Region]
    :param providers: the supported providers for the schema, keyed by region
    :type providers: Dict[Region, List[Provider]]
    :param db_name: database name for the schema
    :type db_name: str
    :param schema_base: declarative base holding the schema classes
    :type schema_base: DeclarativeMeta
    :param entity_type: the schema related entity_type; when empty the classes
        are not added to ``zvt_context.entity_map_schemas``
    :type entity_type: EntityType
    :return: None
    """
    # Extend the list already recorded for this db instead of replacing it,
    # so a call that finds no classes does not wipe earlier registrations.
    schemas = dbname_map_schemas.get(db_name) or []
    # Classes recorded by this call; without this each class would be
    # appended once per region.
    registered = set()

    for region in regions:
        for mapper in schema_base.registry.mappers:
            cls = mapper.class_
            if type(cls) is not DeclarativeMeta:
                continue

            # register provider to the schema
            for provider in providers[region]:
                if issubclass(cls, Mixin):
                    cls.register_provider(region, provider)

            if cls in registered:
                continue
            registered.add(cls)
            zvt_context.schemas.append(cls)
            if entity_type:
                add_to_map_list(the_map=zvt_context.entity_map_schemas,
                                key=entity_type,
                                value=cls)
            schemas.append(cls)

        # NOTE(review): this block never calls metadata.create_all; presumably
        # table creation happens inside get_db_engine — confirm.
        engine = get_db_engine(region, schema_base, db_name=db_name)
        if engine is None:
            continue

        for provider in providers[region]:
            # track the provider in zvt_context.providers
            region_providers = zvt_context.providers.setdefault(region, [])
            if provider not in region_providers:
                region_providers.append(provider)

            provider_map_dbnames.setdefault(provider, []).append(db_name)

            session_fac = get_db_session_factory(region,
                                                 provider,
                                                 db_name=db_name)
            session_fac.configure(bind=engine)

        set_db_name(db_name, schema_base)
        inspector = Inspector.from_engine(engine)

        # tables already handled for this region are skipped on later calls
        indexed_tables = dbname_map_index.setdefault(region, [])

        # create single-column indexes for 'timestamp','entity_id','code',
        # 'report_period','created_timestamp','updated_timestamp', plus the
        # composite (timestamp, entity_id) / (timestamp, code) indexes
        for table_name, table in schema_base.metadata.tables.items():
            if table_name in indexed_tables:
                continue
            indexed_tables.append(table_name)

            index_column_names = [
                index['name'] for index in inspector.get_indexes(table_name)
            ]

            if zvt_config['debug'] == 2:
                logger.debug(
                    f'create index -> engine: {engine}, table: {table_name}, index: {index_column_names}'
                )

            for col in ('timestamp', 'entity_id', 'code', 'report_period',
                        'created_timestamp', 'updated_timestamp'):
                if col not in table.c:
                    continue
                index_name = '{}_{}_index'.format(table_name, col)
                if index_name in index_column_names:
                    continue
                # timestamp is indexed in descending order
                column = table.c[col].desc() if col == 'timestamp' else table.c[col]
                sqlalchemy.schema.Index(index_name, column).create(engine)

            for col0, col1 in (('timestamp', 'entity_id'), ('timestamp', 'code')):
                if col0 not in table.c or col1 not in table.c:
                    continue
                index_name = f'{table_name}_{col0}_{col1}_index'
                if index_name not in index_column_names:
                    sqlalchemy.schema.Index(
                        index_name, table.c[col0], table.c[col1]).create(engine)

    dbname_map_schemas[db_name] = schemas