    async def run(self):
        # fetch the stock list
        df_entity = self.bao_get_all_securities(
            to_bao_entity_type(EntityType.Stock))

        if pd_valid(df_entity):
            df_stock = self.to_entity(df_entity, entity_type=EntityType.Stock)

            # persist to Stock
            await df_to_db(region=self.region,
                           provider=self.provider,
                           data_schema=Stock,
                           db_session=get_db_session(self.region,
                                                     self.provider, Stock),
                           df=df_stock)

            # persist StockDetail too
            await df_to_db(region=self.region,
                           provider=self.provider,
                           data_schema=StockDetail,
                           db_session=get_db_session(self.region,
                                                     self.provider,
                                                     StockDetail),
                           df=df_stock)

            self.logger.info("persist stock list success")
    async def on_finish(self):
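        # Back-fill DividendFinancing.rights_raising_fund: for each row where
        # it is still null, sum this schema's rights_raising_fund over the
        # same calendar year, publishing progress to Kafka along the way.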
        last_year = str(now_pd_timestamp(self.region).year)
        codes = [item.code for item in self.entities]

        db_session = get_db_session(self.region, self.provider,
                                    DividendFinancing)

        need_fill_items, column_names = DividendFinancing.query_data(
            region=self.region,
            provider=self.provider,
            db_session=db_session,
            codes=codes,
            end_timestamp=last_year,
            filters=[DividendFinancing.rights_raising_fund.is_(None)])

        if need_fill_items:
            desc = self.data_schema.__name__ + ": update relevant table"

            db_session_1 = get_db_session(self.region, self.provider,
                                          self.data_schema)
            kafka_producer = connect_kafka_producer(findy_config['kafka'])

            for item in need_fill_items:
                result, column_names = self.data_schema.query_data(
                    region=self.region,
                    provider=self.provider,
                    db_session=db_session_1,
                    entity_id=item.entity_id,
                    start_timestamp=item.timestamp,
                    end_timestamp=f"{item.timestamp.year}-12-31",
                    func=func.sum(self.data_schema.rights_raising_fund))

                if isinstance(result, (int, float)):
                    item.rights_raising_fund = result

                data = {
                    "task": 'rig',
                    "total": len(need_filleds),
                    "desc": desc,
                    "leave": True,
                    "update": 1
                }
                publish_message(kafka_producer, progress_topic,
                                bytes(progress_key, encoding='utf-8'),
                                bytes(json.dumps(data), encoding='utf-8'))

            try:
                db_session.commit()
            except Exception as e:
                self.logger.error(f'{self.__class__.__name__}, error: {e}')
                db_session.rollback()

        await super().on_finish()
Example #3
    async def on_finish(self):
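        # Back-fill DividendFinancing.ipo_raising_fund from each entity's
        # raising_fund where the IPO fund is missing and shares were issued,
        # publishing per-entity progress to Kafka.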
        desc = DividendFinancing.__name__ + ": update relevant table"
        db_session = get_db_session(self.region, self.provider, DividendFinancing)
        kafka_producer = connect_kafka_producer(findy_config['kafka'])

        for entity in self.entities:
            code_security = {entity.code: entity}

            need_fill_items, column_names = DividendFinancing.query_data(
                region=self.region,
                provider=self.provider,
                db_session=db_session,
                codes=list(code_security.keys()),
                filters=[
                    DividendFinancing.ipo_raising_fund.is_(None),
                    DividendFinancing.ipo_issues != 0])

            if need_fill_items:
                for need_fill_item in need_fill_items:
                    need_fill_item.ipo_raising_fund = code_security[entity.code].raising_fund

            data = {"task": 'div', "total": len(self.entities), "desc": desc, "leave": True, "update": 1}
            publish_message(kafka_producer, progress_topic, bytes(progress_key, encoding='utf-8'), bytes(json.dumps(data), encoding='utf-8'))

        try:
            db_session.commit()
        except Exception as e:
            self.logger.error(f'{self.__class__.__name__}, error: {e}')
            db_session.rollback()

        await super().on_finish()
Example #4
    async def run(self):
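        # Resolve the entity list if it is not set yet, announce the task on
        # Kafka, then crawl all entities concurrently under a semaphore that
        # caps in-flight requests.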
        db_session = get_db_session(self.region, self.provider, self.data_schema)
        kafka_producer = connect_kafka_producer(findy_config['kafka'])

        if not hasattr(self, 'entities'):
            self.entities: List = None
            await self.init_entities(db_session)

        if self.entities:
            http_session = get_async_http_session()
            throttler = asyncio.Semaphore(self.share_para[0])

            (taskid, desc) = self.share_para[1]
            data = {"task": taskid, "total": len(self.entities), "desc": desc, "leave": True, "update": 0}
            publish_message(kafka_producer, progress_topic, bytes(progress_key, encoding='utf-8'), bytes(json.dumps(data), encoding='utf-8'))

            # tasks = [asyncio.ensure_future(self.process_loop(entity, http_session, db_session, throttler)) for entity in self.entities]
            tasks = [self.process_loop(entity, http_session, db_session, kafka_producer, throttler) for entity in self.entities]

            # for result in asyncio.as_completed(tasks):
            #     await result

            [await _ for _ in asyncio.as_completed(tasks)]

            await self.on_finish()

            await http_session.close()
Example #5
    @classmethod
    async def get_stocks(cls, region: Region, provider: Provider, timestamp, code=None, codes=None, ids=None):
        """
        the publishing policy of portfolio positions is different for different types,
        overwrite this function for get the holding stocks in specific date

        :param code: portfolio(etf/block/index...) code
        :param codes: portfolio(etf/block/index...) codes
        :param ids: portfolio(etf/block/index...) ids
        :param timestamp: the date of the holding stocks
        :param provider: the data provider
        :return:
        """
        from findy.database.plugins.register import get_schema_by_name
        from findy.database.context import get_db_session

        schema_str = f'{cls.__name__}Stock'
        portfolio_stock = get_schema_by_name(schema_str)
        db_session = get_db_session(region, provider, data_schema=portfolio_stock)
        data, column_names = portfolio_stock.query_data(
            region=region,
            provider=provider,
            db_session=db_session,
            code=code,
            codes=codes,
            timestamp=timestamp,
            ids=ids)

        if data:
            return pd.DataFrame([s.__dict__ for s in data], columns=column_names)
        else:
            return pd.DataFrame()
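
    # A minimal usage sketch (hypothetical: `Etf` as a concrete portfolio
    # subclass, plus illustrative region/provider/date/code values):
    #
    #     holdings = await Etf.get_stocks(region=Region.CHN,
    #                                     provider=Provider.Exchange,
    #                                     timestamp='2021-12-31',
    #                                     code='510050')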
    async def run(self):
        http_session = get_sync_http_session()
        db_session = get_db_session(self.region, self.provider, self.data_schema)

        # SSE and CSI indices
        await self.fetch_csi_index(http_session, db_session)

        # SZSE indices
        await self.fetch_szse_index(http_session, db_session)
Example #7
    def load_data(self):
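        # Resolve entity_ids via the entity schema when they were not given,
        # load the query result into self.data_df, and notify the listeners.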
        self.logger.info('load_data start')
        start_time = time.time()

        db_session = get_db_session(self.region, self.provider,
                                    self.data_schema)

        # convert to standard entity_id
        if self.entity_schema and not self.entity_ids:
            entities, column_names = get_entities(
                region=self.region,
                provider=self.provider,
                db_session=db_session,
                entity_schema=self.entity_schema,
                exchanges=self.exchanges,
                codes=self.codes,
                columns=[self.entity_schema.entity_id])

            if len(entities) > 0:
                self.entity_ids = [entity.entity_id for entity in entities]
            else:
                return

        data, column_names = self.data_schema.query_data(
            region=self.region,
            provider=self.provider,
            db_session=db_session,
            entity_ids=self.entity_ids,
            columns=self.columns,
            start_timestamp=self.start_timestamp,
            end_timestamp=self.end_timestamp,
            filters=self.filters,
            order=self.order,
            limit=self.limit,
            level=self.level,
            index=[self.category_field, self.time_field],
            time_field=self.time_field)

        if data and not self.columns:
            self.data_df = pd.DataFrame([s.__dict__ for s in data],
                                        columns=column_names)
        else:
            self.data_df = pd.DataFrame(data, columns=column_names)

        cost_time = time.time() - start_time
        self.logger.info(f'load_data finished, cost_time:{cost_time}')

        for listener in self.data_listeners:
            listener.on_data_loaded(self.data_df)
Example #8
def load_company_info(tickers=None):
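    # Load US company info (Yahoo provider) for the given tickers into a
    # DataFrame; with tickers=None no code filter is applied.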
    entity_schema = get_entity_schema_by_type(EntityType.StockDetail)
    db_session = get_db_session(Region.US, Provider.Yahoo, entity_schema)
    entities, column_names = get_entities(region=Region.US,
                                          provider=Provider.Yahoo,
                                          entity_schema=entity_schema,
                                          db_session=db_session,
                                          codes=tickers)

    df = pd.DataFrame([s.__dict__ for s in entities], columns=column_names)
    df.reset_index(drop=True, inplace=True)

    return df
Example #9
    async def run(self):
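        # Preload the trade-day calendar (newest first) before delegating to
        # the base class run.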
        db_session = get_db_session(self.region, self.provider, self.data_schema)
        trade_days, column_names = StockTradeDay.query_data(
            region=self.region,
            provider=self.provider,
            db_session=db_session,
            order=StockTradeDay.timestamp.desc())

        if trade_days:
            self.trade_day = [day.timestamp for day in trade_days]
        else:
            self.trade_day = []
            self.logger.warning("load trade days failed")

        return await super().run()
    async def persist_index(self, df) -> None:
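        # Normalize the raw index frame: parse timestamps, synthesize id and
        # entity_id from the code, tag exchange and entity_type, then persist
        # to the Index schema.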
        df['timestamp'] = df['timestamp'].apply(to_pd_timestamp)
        df['list_date'] = df['list_date'].apply(to_pd_timestamp)
        df['id'] = df['code'].apply(lambda code: f'index_cn_{code}')
        df['entity_id'] = df['id']
        df['exchange'] = 'cn'
        df['entity_type'] = EntityType.Index.value

        df = df.dropna(axis=0, how='any')
        df = df.drop_duplicates(subset='id', keep='last')

        db_session = get_db_session(self.region, self.provider, Index)
        await df_to_db(region=self.region,
                       provider=self.provider,
                       data_schema=Index,
                       db_session=db_session,
                       df=df)
    async def persist(self, df, db_session):
        start_point = time.time()
        # persist to Stock
        saved = await df_to_db(region=self.region,
                               provider=self.provider,
                               data_schema=self.data_schema,
                               db_session=db_session,
                               df=df)

        # persist StockDetail too
        await df_to_db(region=self.region,
                       provider=self.provider,
                       data_schema=StockDetail,
                       db_session=get_db_session(self.region, self.provider, StockDetail),
                       df=df,
                       force_update=True)

        return True, time.time() - start_point, saved
Example #12
async def init_main_index(region: Region, provider=Provider.Exchange):
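    # Seed the Index table with the built-in main-index definitions for the
    # given region; regions without definitions are skipped.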
    if region == Region.CHN:
        for item in CHINA_STOCK_MAIN_INDEX:
            item['timestamp'] = to_pd_timestamp(item['timestamp'])
        df = pd.DataFrame(CHINA_STOCK_MAIN_INDEX)
    elif region == Region.US:
        for item in US_STOCK_MAIN_INDEX:
            item['timestamp'] = to_pd_timestamp(item['timestamp'])
        df = pd.DataFrame(US_STOCK_MAIN_INDEX)
    else:
        print("index not initialized, in file: init_main_index")
        df = pd.DataFrame()

    if pd_valid(df):
        await df_to_db(region=region,
                       provider=provider,
                       data_schema=Index,
                       db_session=get_db_session(region, provider, Index),
                       df=df)
    async def run(self):
        http_session = get_sync_http_session()
        db_session = get_db_session(self.region, self.provider,
                                    self.data_schema)

        # fetch the Shanghai (SSE) ETF list
        url = 'http://query.sse.com.cn/commonQuery.do?sqlId=COMMON_SSE_ZQPZ_ETFLB_L_NEW'
        text = sync_get(http_session,
                        url,
                        headers=DEFAULT_SH_ETF_LIST_HEADER,
                        return_type='text')
        if text is None:
            return

        response_dict = demjson.decode(text)

        df = pd.DataFrame(response_dict.get('result', []))
        await self.persist_etf_list(df, ChnExchange.SSE.value, db_session)
        self.logger.info('SSE ETF list fetch complete...')

        # fetch SSE ETF constituents
        await self.download_sh_etf_component(df, http_session, db_session)
        self.logger.info('SSE ETF constituents fetch complete...')

        # fetch the Shenzhen (SZSE) ETF list
        url = 'http://www.szse.cn/api/report/ShowReport?SHOWTYPE=xlsx&CATALOGID=1945'
        content = sync_get(http_session, url, return_type='content')
        if content is None:
            return

        df = pd.read_excel(io.BytesIO(content), dtype=str)
        await self.persist_etf_list(df, ChnExchange.SZSE.value, db_session)
        self.logger.info('SZSE ETF list fetch complete...')

        # fetch SZSE ETF constituents
        await self.download_sz_etf_component(df, http_session, db_session)
        self.logger.info('SZSE ETF constituents fetch complete...')
Example #14
    async def load_window_df(self, data_schema, window):
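        # Load the most recent `window` rows for every entity and combine
        # them into one DataFrame sorted by the (category, time) index.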
        window_df = None

        db_session = get_db_session(self.region, self.provider, data_schema)

        dfs = []
        for entity_id in self.entity_ids:
            data, column_names = data_schema.query_data(
                region=self.region,
                provider=self.provider,
                db_session=db_session,
                index=[self.category_field, self.time_field],
                order=data_schema.timestamp.desc(),
                entity_id=entity_id,
                limit=window)

            if data:
                df = pd.DataFrame([s.__dict__ for s in data],
                                  columns=column_names)
                dfs.append(df)
        if dfs:
            window_df = pd.concat(dfs)
            window_df = window_df.sort_index(level=[0, 1])
        return window_df
Example #15
async def get_portfolio_stocks(region: Region,
                               provider: Provider,
                               timestamp,
                               portfolio_entity=Fund,
                               code=None,
                               codes=None,
                               ids=None):
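    # Resolve the portfolio's holdings as of `timestamp`: an annual or
    # semi-annual report is complete on its own, while a quarterly report is
    # merged with earlier reports until a half-year/annual report is reached.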
    portfolio_stock = f'{portfolio_entity.__name__}Stock'
    data_schema: PortfolioStockHistory = get_schema_by_name(portfolio_stock)
    db_session = get_db_session(region, provider, data_schema)

    latests, column_names = data_schema.query_data(
        region=region,
        provider=provider,
        db_session=db_session,
        code=code,
        end_timestamp=timestamp,
        order=data_schema.timestamp.desc(),
        limit=1)

    if latests:
        latest_record = latests[0]
        # fetch the latest report
        data, column_names = data_schema.query_data(
            region=region,
            provider=provider,
            db_session=db_session,
            code=code,
            codes=codes,
            ids=ids,
            end_timestamp=timestamp,
            filters=[data_schema.report_date == latest_record.report_date])

        if data:
            df = pd.DataFrame([s.__dict__ for s in data], columns=column_names)

            # the latest report is annual or semi-annual; holdings are complete
            if latest_record.report_period == ReportPeriod.year or latest_record.report_period == ReportPeriod.half_year:
                return df
            # quarterly report: combine with the annual/semi-annual report to compute holdings
            else:
                step = 0
                while step <= 20:
                    report_date = get_recent_report_date(
                        latest_record.report_date, step=step)

                    data, column_names = data_schema.query_data(
                        region=region,
                        provider=provider,
                        db_session=db_session,
                        code=code,
                        codes=codes,
                        ids=ids,
                        end_timestamp=timestamp,
                        filters=[
                            data_schema.report_date == to_pd_timestamp(
                                report_date)
                        ])

                    if data:
                        pre_df = pd.DataFrame.from_records(
                            [s.__dict__ for s in data], columns=column_names)
                        # pandas' DataFrame.append is deprecated; use concat
                        df = pd.concat([df, pre_df])

                        # stop once a semi-annual or annual report is reached
                        if (ReportPeriod.half_year.value
                                in pre_df['report_period'].tolist()) or (
                                    ReportPeriod.year.value
                                    in pre_df['report_period'].tolist()):
                            # keep only the latest holding of each stock
                            df = df.drop_duplicates(subset=['stock_code'],
                                                    keep='first')
                            return df
                    step += 1
    async def on_finish_entity(self, entity, http_session, db_session):
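        # After the regular per-entity work, back-fill the published-date
        # timestamp from FinanceFactor for rows whose timestamp still equals
        # the report_date.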
        total_time = await super().on_finish_entity(entity, http_session,
                                                    db_session)

        if not self.fetch_jq_timestamp:
            return total_time

        now = time.time()

        # fill the timestamp for report published date
        the_data_list, column_names = self.data_schema.query_data(
            region=self.region,
            provider=self.provider,
            db_session=db_session,
            entity_id=entity.id,
            order=self.data_schema.timestamp.asc(),
            filters=[
                self.data_schema.timestamp == self.data_schema.report_date
            ])

        if the_data_list:
            data, column_names = FinanceFactor.query_data(
                region=self.region,
                provider=self.provider,
                db_session=get_db_session(self.region, self.provider,
                                          FinanceFactor),
                entity_id=entity.id,
                columns=[
                    FinanceFactor.timestamp, FinanceFactor.report_date,
                    FinanceFactor.id
                ],
                filters=[
                    FinanceFactor.timestamp != FinanceFactor.report_date,
                    FinanceFactor.report_date >= the_data_list[0].report_date,
                    FinanceFactor.report_date <= the_data_list[-1].report_date
                ])

            if data:
                df = pd.DataFrame(data, columns=column_names)
                df = index_df(df,
                              index='report_date',
                              time_field='report_date')

                if pd_valid(df):
                    for the_data in the_data_list:
                        if the_data.report_date in df.index:
                            the_data.timestamp = df.at[the_data.report_date,
                                                       'timestamp']
                            self.logger.info(
                                'db fill {} {} timestamp:{} for report_date:{}'
                                .format(self.data_schema.__name__, entity.id,
                                        the_data.timestamp,
                                        the_data.report_date))
                    try:
                        db_session.commit()
                    except Exception as e:
                        self.logger.error(
                            f'{self.__class__.__name__}, error: {e}')
                        db_session.rollback()

        return total_time + time.time() - now