async def run(self):
        """Fetch the stock list from baostock and persist it.

        The same entity frame is written to both the ``Stock`` and the
        ``StockDetail`` tables. Nothing is written (and nothing is logged) when
        the fetched frame is not valid.
        """
        raw_entities = self.bao_get_all_securities(
            to_bao_entity_type(EntityType.Stock))

        if not pd_valid(raw_entities):
            return

        stocks = self.to_entity(raw_entities, entity_type=EntityType.Stock)

        # Stock first, then StockDetail, each with its own db session.
        for schema in (Stock, StockDetail):
            await df_to_db(region=self.region,
                           provider=self.provider,
                           data_schema=schema,
                           db_session=get_db_session(self.region,
                                                     self.provider, schema),
                           df=stocks)

        self.logger.info("persist stock list success")
예제 #2
0
    async def eval_fetch_timestamps(self, entity, ref_record, http_session):
        """Work out which timestamps still need fetching for ``entity``.

        Candidate timestamps are cached per entity id in
        ``self.security_timestamps_map``. On a cache miss (or an empty cached
        list) they are rebuilt via ``self.init_timestamps`` and clipped to
        ``start_timestamp`` / ``end_timestamp`` when those are set. If
        ``ref_record`` yields a newest stored ``pd.Timestamp``, only timestamps
        at or after it are kept.

        Returns:
            ``(first, last, count, timestamps)``; the first three are
            ``None, None, 0`` when nothing is left to fetch.
        """
        cached = self.security_timestamps_map.get(entity.id)
        if not cached:
            cached = self.init_timestamps(entity, http_session)
            lower, upper = self.start_timestamp, self.end_timestamp
            if lower:
                cached = [ts for ts in cached if ts >= lower]
            if upper:
                cached = [ts for ts in cached if ts <= upper]
            self.security_timestamps_map[entity.id] = cached

        if not cached:
            return None, None, 0, cached

        # In-place sort: the list stored in the cache map is ordered as well.
        cached.sort()

        newest = None
        try:
            if pd_valid(ref_record):
                time_field = self.get_evaluated_time_field()
                newest = ref_record[time_field].max(axis=0)
        except Exception as e:
            self.logger.warning(f'get ref_record failed with error: {e}')

        if newest is None or not isinstance(newest, pd.Timestamp):
            return cached[0], cached[-1], len(cached), cached

        remaining = [ts for ts in cached if ts >= newest]
        if not remaining:
            return None, None, 0, None
        return remaining[0], remaining[-1], len(remaining), remaining
예제 #3
0
    async def eval_fetch_timestamps(self, entity, ref_record, http_session):
        """Decide the fetch window for ``entity`` from its newest stored record.

        The newest value of the evaluated time field in ``ref_record`` is used,
        falling back to ``entity.timestamp``. Missing values (``None``,
        falsy, or ``pd.NaT``) fall through to the configured defaults.

        Returns:
            ``(start, end, size, None)``:
            * ``(self.start_timestamp, self.end_timestamp, self.default_size)``
              when no reference time is known;
            * ``(latest, None, 0)`` when the newest record is already dated today
              (or later) or on/after ``self.trade_day[0]``;
            * otherwise ``latest`` raised to ``self.start_timestamp``, with
              ``size`` 0 if that is past ``self.end_timestamp`` and
              ``self.default_size`` otherwise.
        """
        latest_timestamp = None
        try:
            if pd_valid(ref_record):
                time_field = self.get_evaluated_time_field()
                latest_timestamp = ref_record[time_field].max(axis=0)
        except Exception as e:
            self.logger.warning("get ref_record failed with error: {}".format(e))

        # pd.NaT (e.g. the max of an all-null column) is truthy, so ``not x``
        # alone would let it through into ``.date()`` / ``max()`` below.
        if latest_timestamp is pd.NaT or not latest_timestamp:
            latest_timestamp = entity.timestamp

        if latest_timestamp is pd.NaT or not latest_timestamp:
            return self.start_timestamp, self.end_timestamp, self.default_size, None

        # already have data for today: nothing to fetch
        if latest_timestamp.date() >= now_pd_timestamp(self.region).date():
            return latest_timestamp, None, 0, None

        # presumably self.trade_day is newest-first; if the newest stored
        # record already reaches that trade day there is nothing new yet
        if len(self.trade_day) > 0 and \
           latest_timestamp.date() >= self.trade_day[0].date():
            return latest_timestamp, None, 0, None

        if self.start_timestamp:
            latest_timestamp = max(latest_timestamp, self.start_timestamp)

        if self.end_timestamp and latest_timestamp > self.end_timestamp:
            size = 0
        else:
            size = self.default_size

        return latest_timestamp, self.end_timestamp, size, None
예제 #4
0
    async def record(self, entity, http_session, db_session, para):
        """Fetch raw records for ``entity`` from ``self.url``.

        ``para`` is ``(ref_record, start, end, size, timestamps)``.

        * If ``timestamps`` is non-empty, one request is issued per timestamp,
          each returned row gets the evaluated time field set to that
          timestamp, and requests stop once at least ``self.batch_size`` rows
          have been collected. Failed requests are logged and skipped.
        * Otherwise a single request is made. If the response has rows, the
          original time field is parsed with ``PD_TIME_FORMAT_DAY`` and the
          frame is passed through ``self.format``.

        Returns:
            ``(is_finished, elapsed_seconds, payload)``.
            * Timestamp mode always returns ``is_finished`` False with
              ``payload`` ``(ref_record, DataFrame)``.
            * Single-request mode returns ``(True, elapsed, None)`` when nothing
              usable came back.
        """
        start_point = time.time()

        (ref_record, start, end, size, timestamps) = para

        if timestamps:
            time_field = self.get_evaluated_time_field()
            original_list = []
            for the_timestamp in timestamps:
                param = self.generate_request_param(entity, start, end, size,
                                                    the_timestamp,
                                                    http_session)
                tmp_list = None
                try:
                    tmp_list = await self.api_wrapper.request(
                        http_session,
                        url=self.url,
                        param=param,
                        method=self.request_method,
                        path_fields=self.path_fields)
                except Exception as e:
                    self.logger.error(f"url: {self.url}, error: {e}")

                if tmp_list is None:
                    continue

                # stamp every row with the timestamp it was requested for
                for tmp in tmp_list:
                    tmp[time_field] = the_timestamp
                original_list += tmp_list
                # A single response can add several rows, so the running total
                # may jump past batch_size; an ``==`` test would never fire then.
                if self.batch_size is not None and \
                        len(original_list) >= self.batch_size:
                    break
            return False, time.time() - start_point, (
                ref_record, pd.DataFrame.from_records(original_list))

        param = self.generate_request_param(entity, start, end, size, None,
                                            http_session)
        try:
            result = await self.api_wrapper.request(
                http_session,
                url=self.url,
                param=param,
                method=self.request_method,
                path_fields=self.path_fields)
            df = pd.DataFrame.from_records(result)

            if pd_valid(df):
                timefield = self.get_original_time_field()
                df[timefield] = pd.to_datetime(df[timefield],
                                               format=PD_TIME_FORMAT_DAY)
                return False, time.time() - start_point, (ref_record,
                                                          self.format(
                                                              entity, df))

        except Exception as e:
            self.logger.error(f"url: {self.url}, error: {e}")

        return True, time.time() - start_point, None
예제 #5
0
    async def record(self, entity, http_session, db_session, para):
        """Fetch basic info for ``entity`` through ``self.tushare_get_info``.

        Returns ``(False, elapsed_seconds, formatted_info)`` when the info is
        valid, otherwise ``(True, elapsed_seconds, None)``.
        """
        began = time.time()

        info = self.tushare_get_info(entity)
        if not pd_valid(info):
            return True, time.time() - began, None

        return False, time.time() - began, self.format(entity, info)
예제 #6
0
    async def record(self, entity, http_session, db_session, para):
        """Fetch bars for ``entity`` via ``self.yh_get_bars``.

        ``para`` is ``(ref_record, start, end, size, timestamps)``; only
        ``ref_record`` and ``start`` are used. The upper bound comes from
        ``self.end_timestamp`` (as a time string), not from ``para``.

        Returns ``(False, elapsed_seconds, (ref_record, formatted_df))`` when
        bars came back, otherwise ``(True, elapsed_seconds, None)``.
        """
        began = time.time()

        ref_record, start, _end, _size, _timestamps = para

        end_timestamp = None
        if self.end_timestamp:
            end_timestamp = to_time_str(self.end_timestamp)

        bars = await self.yh_get_bars(http_session, entity,
                                      start=start, end=end_timestamp)

        if not pd_valid(bars):
            return True, time.time() - began, None

        return False, time.time() - began, (ref_record,
                                            self.format(entity, bars))
예제 #7
0
    async def record(self, entity, http_session, db_session, para):
        """Fetch trade days newer than the latest one already stored.

        Queries ``max(StockTradeDay.timestamp)`` for this region/provider and
        requests trade days starting from it, or from "1990-12-19" when the
        query returns nothing.

        Returns ``(False, elapsed_seconds, formatted_df)`` when data came back,
        otherwise ``(True, elapsed_seconds, None)``.
        """
        began = time.time()

        latest_day, _ = StockTradeDay.query_data(
            region=self.region,
            provider=self.provider,
            db_session=db_session,
            func=func.max(StockTradeDay.timestamp))

        start = "1990-12-19"
        if latest_day:
            start = to_time_str(latest_day)

        trade_days = self.bao_get_trade_days(start_date=start)

        if not pd_valid(trade_days):
            return True, time.time() - began, None

        return False, time.time() - began, self.format(entity, trade_days)
    def record(self, entity, http_session, db_session, para):
        """Fetch the balance sheet for ``entity.code`` (synchronous variant).

        ``para`` must be the 5-tuple ``(ref_record, start, end, size,
        timestamps)``; only ``ref_record`` is used. The fetched frame is
        transposed and its resulting index is copied into a ``timestamp``
        column before formatting.

        Returns ``(False, elapsed_seconds, (ref_record, formatted_df))`` on
        success, otherwise ``(True, elapsed_seconds, None)``.
        """
        began = time.time()

        ref_record, _start, _end, _size, _timestamps = para

        sheet = self.yh_get_balance_sheet(entity.code)
        if sheet is None or len(sheet) == 0:
            return True, time.time() - began, None

        sheet = sheet.T
        sheet['timestamp'] = sheet.index

        if not pd_valid(sheet):
            return True, time.time() - began, None

        return False, time.time() - began, (ref_record,
                                            self.format(entity, sheet))
예제 #9
0
async def init_main_index(region: Region, provider=Provider.Exchange):
    """Seed the ``Index`` table with the built-in main-index list for ``region``.

    Supported regions are ``Region.CHN`` (``CHINA_STOCK_MAIN_INDEX``) and
    ``Region.US`` (``US_STOCK_MAIN_INDEX``). Each entry's ``timestamp`` is
    converted with ``to_pd_timestamp``. Any other region prints a notice and
    writes nothing.
    """
    if region == Region.CHN:
        source = CHINA_STOCK_MAIN_INDEX
    elif region == Region.US:
        source = US_STOCK_MAIN_INDEX
    else:
        print("index not initialized, in file: init_main_index")
        source = []

    # Build converted copies rather than rewriting the module-level constant
    # dicts in place, so other readers of the constants (and repeated calls)
    # always see the original values.
    rows = [{**item, 'timestamp': to_pd_timestamp(item['timestamp'])}
            for item in source]
    df = pd.DataFrame(rows)

    if pd_valid(df):
        await df_to_db(region=region,
                       provider=provider,
                       data_schema=Index,
                       db_session=get_db_session(region, provider, Index),
                       df=df)
예제 #10
0
    async def persist(self, entity, http_session, db_session, para):
        """Write fetched records for ``entity`` and decide whether to stop.

        ``para`` is ``(ref_record, df_record)``. Valid records are written via
        ``df_to_db``; fetching is finished in these cases:
        * nothing new was saved;
        * no data came back and the recorder is not real-time;
        * real-time, and it is at least 5 minutes past
          ``close_hour:close_minute``;
        * the recorder is a ``KDataRecorder``.

        Side effect: sets ``self.result`` to
        ``[saved_counts, min_time_str, max_time_str]``; the time strings are
        None when ``df_record`` has no valid ``timestamp`` data.

        Returns:
            ``(is_finished, elapsed_seconds, saved_counts)``.
        """
        start_point = time.time()

        (ref_record, df_record) = para
        saved_counts = 0
        is_finished = False
        has_data = pd_valid(df_record)

        if has_data:
            assert 'id' in df_record.columns
            saved_counts = await df_to_db(region=self.region,
                                          provider=self.provider,
                                          data_schema=self.data_schema,
                                          db_session=db_session,
                                          df=df_record,
                                          ref_df=ref_record,
                                          fix_duplicate_way=self.fix_duplicate_way)
            # everything fetched was already stored
            if saved_counts == 0:
                is_finished = True

        # could not get more data, and not realtime
        elif not self.real_time:
            is_finished = True

        # realtime: stop once we are at least 5 minutes past the close time.
        # Compare total minutes; checking hour and minute separately missed
        # e.g. 16:02 for a 15:00 close.
        elif (self.close_hour is not None) and (self.close_minute is not None):
            now = now_pd_timestamp(self.region)
            minutes_past_close = ((now.hour - self.close_hour) * 60
                                  + (now.minute - self.close_minute))
            if minutes_past_close >= 5:
                self.logger.info(f'{entity.id} now is the close time: {now}')
                is_finished = True

        if isinstance(self, KDataRecorder):
            is_finished = True

        # df_record may be None or empty here; only read its range when present
        if has_data and 'timestamp' in df_record.columns:
            start_timestamp = to_time_str(df_record['timestamp'].min(axis=0))
            end_timestamp = to_time_str(df_record['timestamp'].max(axis=0))
        else:
            start_timestamp = end_timestamp = None

        self.result = [saved_counts, start_timestamp, end_timestamp]

        return is_finished, time.time() - start_point, saved_counts
    async def process_loop(self, entity, http_session, db_session, kafka_producer, throttler):
        """Download, parse and persist the category map for ``entity``.

        Returns immediately when ``self.category_map_url`` has no URL for
        ``entity``. Under ``throttler``, the URL is fetched with the entity's
        headers, and a non-200 status returns without publishing progress.
        Otherwise the body is parsed with ``self.format`` and persisted when
        valid. A progress message is then published to Kafka.
        """
        url = self.category_map_url.get(entity, None)
        if url is None:
            return

        async with throttler:
            async with http_session.get(url, headers=self.category_map_header[entity]) as resp:
                if resp.status != 200:
                    return

                body = await resp.read()
                parsed = self.format(resp=body, exchange=entity)

                if pd_valid(parsed):
                    await self.persist(parsed, db_session)

            taskid, desc = self.share_para[1]
            progress = {"task": taskid, "total": len(self.entities), "desc": desc, "leave": True, "update": 1}
            publish_message(kafka_producer, progress_topic,
                            bytes(progress_key, encoding='utf-8'),
                            bytes(json.dumps(progress), encoding='utf-8'))
예제 #12
0
    async def eval_fetch_timestamps(self, entity, referenced_record,
                                    http_session):
        """Decide how many records to request for ``entity``.

        Returns ``(None, None, size, None)`` where ``size`` is:
        * 1000 when there is no stored record to compare against;
        * 0 when the remote latest record is unavailable, or has the same id
          as the newest stored record;
        * 10 otherwise.
        """
        # get latest record
        latest_record = None
        try:
            if pd_valid(referenced_record):
                latest_record = referenced_record.loc[referenced_record[
                    self.get_evaluated_time_field()].idxmax()]
        except Exception as e:
            self.logger.warning(
                f"get referenced_record failed with error: {e}")

        if latest_record is None:
            return None, None, 1000, None

        remote_record = self.get_remote_latest_record(entity, http_session)
        if remote_record is None:
            return None, None, 0, None

        # ``latest_record`` is a row Series whose ``.index`` holds the column
        # labels, so the old ``latest_record.index == remote_record.id`` gave
        # an array with an ambiguous truth value. Compare the row's own id
        # (falling back to its index label if there is no ``id`` column).
        if 'id' in latest_record.index:
            latest_id = latest_record['id']
        else:
            latest_id = latest_record.name

        if latest_id == remote_record.id:
            return None, None, 0, None
        return None, None, 10, None
예제 #13
0
    async def record(self, entity, http_session, db_session, para):
        """Fetch baostock bars for ``entity``.

        ``para`` is ``(ref_record, start, end, size, timestamps)``; only
        ``ref_record``, ``start`` and ``end`` are used. ``start`` is raised to a
        per-level floor date: "1990-12-19" for 'd'/'w'/'m' levels and
        "1999-07-26" otherwise. A missing ``start`` uses that floor.

        Returns ``(False, elapsed_seconds, (ref_record, formatted_df))`` when
        bars came back, otherwise ``(True, elapsed_seconds, None)``.
        """
        start_point = time.time()

        (ref_record, start, end, size, timestamps) = para

        if self.bao_trading_level in ['d', 'w', 'm']:
            earliest = "1990-12-19"
        else:
            earliest = "1999-07-26"

        # ``start`` can be None (e.g. no stored data and no configured start);
        # ``max`` of None and a str would raise, so use the floor directly.
        # NOTE(review): the string ``max`` assumes to_time_str yields
        # ISO-ordered "YYYY-MM-DD..." strings.
        if start is None:
            start = earliest
        else:
            start = max(to_time_str(start), earliest)

        df = self.bao_get_bars(to_bao_entity_id(entity),
                               start=start,
                               end=end if end is None else to_time_str(end),
                               frequency=self.bao_trading_level,
                               fields=to_bao_trading_field(self.bao_trading_level),
                               adjustflag=to_bao_adjust_flag(self.adjust_type))

        if pd_valid(df):
            return False, time.time() - start_point, (ref_record, self.format(entity, df))

        return True, time.time() - start_point, None
예제 #14
0
    async def eval_fetch_timestamps(self, entity, ref_record, http_session):
        """Compute the ``(start, end, size, None)`` fetch window for ``entity``.

        The start is the day after the newest stored record (from
        ``ref_record``, falling back to ``entity.timestamp``), raised to
        ``self.start_timestamp`` if set. The end is ``self.trade_day[0]``, or
        ``self.trade_day[1]`` when ``trade_day[0]`` is today and it is before
        18:00; it is ``now`` when no trade days are known. ``size`` comes from
        ``self.eval_size_of_timestamp``, or is 0 when ``start >= end``. Without
        any reference time, the configured
        ``(start_timestamp, end_timestamp, default_size)`` is returned.
        """
        latest_timestamp = None
        try:
            if pd_valid(ref_record):
                time_field = self.get_evaluated_time_field()
                latest_timestamp = ref_record[time_field].max(axis=0)
        except Exception as e:
            self.logger.warning(f'get ref record failed with error: {e}')

        # pd.NaT (e.g. the max of an all-null column) is truthy, so ``not x``
        # alone would let it through into next_dates()/comparisons below.
        if latest_timestamp is pd.NaT or not latest_timestamp:
            latest_timestamp = entity.timestamp

        if latest_timestamp is pd.NaT or not latest_timestamp:
            return self.start_timestamp, self.end_timestamp, self.default_size, None

        now = now_pd_timestamp(self.region)
        now_end = now.replace(hour=18, minute=0, second=0)

        # before 18:00 on a trading day, today's data is presumably not final
        # yet, so stop at the previous trade day (trade_day seems newest-first)
        trade_day_index = 0
        if len(self.trade_day) > 0:
            if is_same_date(self.trade_day[trade_day_index], now) and now < now_end:
                trade_day_index = 1
            end = self.trade_day[trade_day_index]
        else:
            end = now

        start_timestamp = next_dates(latest_timestamp)
        start = max(self.start_timestamp, start_timestamp) if self.start_timestamp else start_timestamp

        if start >= end:
            return start, end, 0, None

        size = self.eval_size_of_timestamp(start_timestamp=start,
                                           end_timestamp=end,
                                           level=self.level,
                                           one_day_trading_minutes=self.one_day_trading_minutes)

        return start, end, size, None
예제 #15
0
async def df_to_db(region: Region,
                   provider: Provider,
                   data_schema: DeclarativeMeta,
                   db_session,
                   df: pd.DataFrame,
                   ref_df: pd.DataFrame = None,
                   drop_duplicates: bool = True,
                   fix_duplicate_way: str = 'ignore',
                   force_update=False) -> object:
    """Persist the rows of ``df`` that are not yet stored for ``data_schema``.

    Args:
        region: region passed to ``query_data`` / ``to_postgresql``.
        provider: provider passed to ``query_data``.
        data_schema: ORM model whose table receives the rows.
        db_session: session used for the existing-id query and forced delete.
        df: rows to write; must contain an ``id`` column. Not modified.
        ref_df: already-stored rows with an ``id`` column. When None, ids and
            timestamps are queried from the table.
        drop_duplicates: drop repeated ids in ``df``, keeping the last one.
        fix_duplicate_way: currently unused (the handling below is disabled).
        force_update: delete stored rows sharing ``df``'s ids, then write all
            of ``df``.

    Returns:
        The count returned by ``to_postgresql``, or 0 when nothing was written.
    """
    now = time.time()

    if not pd_valid(df):
        return 0

    # De-duplicate into a new frame instead of in place, so the caller's
    # DataFrame is left untouched.
    if drop_duplicates and df.duplicated(subset='id').any():
        df = df.drop_duplicates(subset='id', keep='last')

    # Keep only columns the table knows about, preserving df's column order.
    schema_cols = set(get_schema_columns(data_schema))
    cols = [col for col in df.columns if col in schema_cols]

    if not cols:
        logger.error(f"{data_schema.__tablename__} get columns failed")
        return 0

    df = df[cols]

    # force update mode, delete duplicate id data, and rewrite new data back
    if force_update:
        ids = df["id"].tolist()
        try:
            # ORM bulk delete: ids are bound as parameters rather than pasted
            # into a SQL string (quotes in ids broke the old statement and made
            # it injectable).
            db_session.query(data_schema).filter(
                data_schema.id.in_(ids)).delete(synchronize_session=False)
        except Exception as e:
            logger.error(
                f"query {data_schema.__tablename__} failed with error: {e}")

        try:
            db_session.commit()
        except Exception as e:
            logger.error(f'df_to_db {data_schema.__tablename__}, error: {e}')
            db_session.rollback()
        df_new = df

    else:
        if ref_df is None:
            data, column_names = data_schema.query_data(
                region=region,
                provider=provider,
                db_session=db_session,
                columns=[data_schema.id, data_schema.timestamp])

            if data and len(data) > 0:
                ref_df = pd.DataFrame(data, columns=column_names)

        # keep only rows whose id is not stored yet
        if pd_valid(ref_df):
            df_new = df[~df.id.isin(ref_df.id)]
        else:
            df_new = df

        # An id alone cannot decide whether a row is new; a full comparison
        # would be needed. Kept disabled, so fix_duplicate_way has no effect:
        # if fix_duplicate_way == 'add':
        #     df_add = df[df.id.isin(ref_df.id)]
        #     if not df_add.empty:
        #         df_add.id = uuid.uuid1()
        #         df_new = pd.concat([df_new, df_add])

    cost = PRECISION_STR.format(time.time() - now)
    logger.debug(f"remove duplicated: {cost}")

    saved = 0
    if pd_valid(df_new):
        saved = to_postgresql(region, df_new, data_schema.__tablename__)

    cost = PRECISION_STR.format(time.time() - now)
    logger.debug(f"write db: {cost}, size: {saved}")

    return saved
예제 #16
0
 def to_open(s):
     """Return the first value of ``s`` (the "open"), or None if ``s`` is invalid.

     Positional ``iloc`` is used when available: ``s[0]`` on a pandas Series
     is a label lookup, which fails or is deprecated for indexes that are not
     zero-based integers (e.g. a DatetimeIndex).
     """
     if pd_valid(s):
         return s.iloc[0] if hasattr(s, 'iloc') else s[0]
예제 #17
0
 def to_close(s):
     """Return the last value of ``s`` (the "close"), or None if ``s`` is invalid.

     Positional ``iloc`` is used when available: ``s[-1]`` on a pandas Series
     is a label lookup, which raises KeyError on a default RangeIndex and is
     deprecated on other indexes.
     """
     if pd_valid(s):
         return s.iloc[-1] if hasattr(s, 'iloc') else s[-1]
    async def on_finish_entity(self, entity, http_session, db_session):
        """Run the parent hook, then back-fill report publish timestamps.

        When ``self.fetch_jq_timestamp`` is set, rows of ``self.data_schema``
        for ``entity`` whose ``timestamp`` still equals ``report_date`` have
        their ``timestamp`` replaced. The replacement is the ``FinanceFactor``
        timestamp for the same ``report_date`` (only FinanceFactor rows whose
        timestamp differs from their report_date are considered). The session
        is then committed, rolling back on failure.

        Returns:
            The parent hook's time plus the seconds spent here.
        """
        base_time = await super().on_finish_entity(entity, http_session,
                                                   db_session)

        if not self.fetch_jq_timestamp:
            return base_time

        began = time.time()

        # rows whose timestamp was never moved off the report date
        pending, _ = self.data_schema.query_data(
            region=self.region,
            provider=self.provider,
            db_session=db_session,
            entity_id=entity.id,
            order=self.data_schema.timestamp.asc(),
            filters=[
                self.data_schema.timestamp == self.data_schema.report_date
            ])

        if pending and len(pending) > 0:
            # pending is ordered by timestamp ascending, so its ends bound the
            # report_date range to look up
            factors, factor_columns = FinanceFactor.query_data(
                region=self.region,
                provider=self.provider,
                db_session=get_db_session(self.region, self.provider,
                                          FinanceFactor),
                entity_id=entity.id,
                columns=[
                    FinanceFactor.timestamp, FinanceFactor.report_date,
                    FinanceFactor.id
                ],
                filters=[
                    FinanceFactor.timestamp != FinanceFactor.report_date,
                    FinanceFactor.report_date >= pending[0].report_date,
                    FinanceFactor.report_date <= pending[-1].report_date
                ])

            if factors and len(factors) > 0:
                by_report_date = index_df(
                    pd.DataFrame(factors, columns=factor_columns),
                    index='report_date',
                    time_field='report_date')

                if pd_valid(by_report_date):
                    for item in pending:
                        if item.report_date not in by_report_date.index:
                            continue
                        item.timestamp = by_report_date.at[item.report_date,
                                                           'timestamp']
                        self.logger.info(
                            'db fill {} {} timestamp:{} for report_date:{}'
                            .format(self.data_schema.__name__, entity.id,
                                    item.timestamp, item.report_date))
                    try:
                        db_session.commit()
                    except Exception as e:
                        self.logger.error(
                            f'{self.__class__.__name__}, error: {e}')
                        db_session.rollback()

        return base_time + time.time() - began