def on_finish(self):
        """Back-fill ``ipo_raising_fund`` on DividendFinancing rows, then chain up.

        For every entity in ``self.entities`` this looks up the
        DividendFinancing rows matching the entity's code whose
        ``ipo_raising_fund`` is NULL while ``ipo_issues`` is non-zero, copies
        the entity's ``raising_fund`` onto each of them and commits after
        every row. The progress bar advances once per entity and the parent
        class's ``on_finish`` runs last.
        """
        bar_title = DividendFinancing.__name__ + ": update relevant table"
        with tqdm(total=len(self.entities), ncols=90, desc=bar_title,
                  position=2, leave=True) as progress:
            db_session = get_db_session(region=self.region,
                                        provider=self.provider,
                                        data_schema=self.data_schema)

            for stock in self.entities:
                # Rows of this stock that still lack the IPO raising fund.
                unfilled_rows = DividendFinancing.query_data(
                    region=self.region,
                    provider=self.provider,
                    codes=[stock.code],
                    return_type='domain',
                    filters=[
                        DividendFinancing.ipo_raising_fund.is_(None),
                        DividendFinancing.ipo_issues != 0,
                    ])
                for row in unfilled_rows:
                    row.ipo_raising_fund = stock.raising_fund
                    db_session.commit()
                progress.update()

        super().on_finish()
    def on_finish(self):
        """Back-fill ``rights_raising_fund`` on DividendFinancing rows, then chain up.

        Selects the DividendFinancing rows of all ``self.entities`` codes whose
        ``rights_raising_fund`` is NULL (bounded by the year reported by
        ``now_pd_timestamp``), and for each one sums
        ``RightsIssueDetail.rights_raising_fund`` from the row's timestamp to
        December 31st of the same year. The sum is stored and committed only
        when it comes back as an ``int`` or ``float``.
        """
        from sqlalchemy import func

        year_bound = str(now_pd_timestamp(self.region).year)
        pending_rows = DividendFinancing.query_data(
            region=self.region,
            provider=self.provider,
            codes=[entity.code for entity in self.entities],
            return_type='domain',
            filters=[DividendFinancing.rights_raising_fund.is_(None)],
            end_timestamp=year_bound)

        bar_title = RightsIssueDetail.__name__ + ": update relevant table"
        with tqdm(total=len(pending_rows), ncols=90, desc=bar_title,
                  position=2, leave=True) as progress:
            db_session = get_db_session(region=self.region,
                                        provider=self.provider,
                                        data_schema=self.data_schema)

            for row in pending_rows:
                raised_total = RightsIssueDetail.query_data(
                    region=self.region,
                    provider=self.provider,
                    entity_id=row.entity_id,
                    start_timestamp=row.timestamp,
                    end_timestamp="{}-12-31".format(row.timestamp.year),
                    return_type='func',
                    func=func.sum(RightsIssueDetail.rights_raising_fund))

                # Anything that is not a plain number (e.g. None) is skipped.
                if isinstance(raised_total, (int, float)):
                    row.rights_raising_fund = raised_total
                    db_session.commit()
                progress.update()

        super().on_finish()
Esempio n. 3
0
    def init_account(self) -> AccountStats:
        """Create the initial account stats for this trader.

        If ``get_trader_info`` reports an earlier run under the same
        ``trader_name``, a warning is logged and that trader's TraderInfo,
        AccountStats, Position and Order rows are deleted and committed.

        :return: a fresh AccountStats holding ``base_capital`` as cash,
            input money and total value, with a position value of 0.
        """
        earlier_runs = get_trader_info(self.region,
                                       trader_name=self.trader_name,
                                       return_type='domain',
                                       limit=1)

        if earlier_runs:
            self.logger.warning(
                "trader:{} has run before,old result would be deleted".format(
                    self.trader_name))
            db_session = get_db_session(region=Region.CHN,
                                        provider=Provider.ZVT,
                                        data_schema=TraderInfo)
            # Same deletion order as always: info, stats, positions, orders.
            for schema in (TraderInfo, AccountStats, Position, Order):
                db_session.query(schema).filter(
                    schema.trader_name == self.trader_name).delete()
            db_session.commit()

        account_id = f'trader_zvt_{self.trader_name}'
        return AccountStats(entity_id=account_id,
                            timestamp=self.start_timestamp,
                            trader_name=self.trader_name,
                            cash=self.base_capital,
                            input_money=self.base_capital,
                            all_value=self.base_capital,
                            value=0,
                            closing=False)
Esempio n. 4
0
    def on_finish_entity(self, entity, http_session):
        """Fill missing cumulative net values on the entity's daily ETF kdata.

        Loads the daily Etf1dKdata rows of ``entity`` whose
        ``cumulative_net_value`` is NULL (oldest first), fetches values for
        that time span through ``fetch_cumulative_net_value`` and copies the
        'LJJZ' and 'JZZZL' columns onto every row whose timestamp is present
        in the fetched frame. Commits and logs only when data was fetched.
        """
        missing_rows = get_kdata(
            region=self.region,
            provider=self.provider,
            entity_id=entity.id,
            level=IntervalLevel.LEVEL_1DAY.value,
            order=Etf1dKdata.timestamp.asc(),
            return_type='domain',
            filters=[Etf1dKdata.cumulative_net_value.is_(None)])

        if not missing_rows:
            return

        first_ts = missing_rows[0].timestamp
        last_ts = missing_rows[-1].timestamp

        # Fetch the fund's cumulative net value from Eastmoney.
        nav_df = self.fetch_cumulative_net_value(entity, first_ts, last_ts,
                                                 http_session)
        if not pd_is_not_null(nav_df):
            return

        for row in missing_rows:
            if row.timestamp in nav_df.index:
                row.cumulative_net_value = nav_df.loc[row.timestamp, 'LJJZ']
                row.change_pct = nav_df.loc[row.timestamp, 'JZZZL']

        db_session = get_db_session(region=self.region,
                                    provider=self.provider,
                                    data_schema=self.data_schema)
        db_session.commit()
        self.logger.info(f'{entity.code} - {entity.name}累计净值更新完成...')
Esempio n. 5
0
    def on_finish_entity(self, entity, http_session):
        """Fill the real publication timestamp on finance rows of ``entity``.

        Runs the parent hook first and stops unless ``fetch_jq_timestamp`` is
        set. Rows whose ``timestamp`` still equals ``report_date`` are then
        fixed up:

        * when the schema is FinanceFactor itself, every row goes through
          ``fill_timestamp_with_jq``;
        * otherwise the timestamp is copied from a FinanceFactor row with the
          same ``report_date`` when one exists (committing after each copy),
          and ``fill_timestamp_with_jq`` is used as the fallback.
        """
        super().on_finish_entity(entity, http_session)

        if not self.fetch_jq_timestamp:
            return

        # Rows whose timestamp was never replaced by the published date.
        stale_rows = get_data(
            region=self.region,
            data_schema=self.data_schema,
            provider=self.provider,
            entity_id=entity.id,
            order=self.data_schema.timestamp.asc(),
            return_type='domain',
            filters=[
                self.data_schema.timestamp == self.data_schema.report_date
            ])
        if not stale_rows:
            return

        if self.data_schema == FinanceFactor:
            for row in stale_rows:
                self.fill_timestamp_with_jq(entity, row)
            return

        # FinanceFactor rows in the same report_date range that already
        # carry a timestamp different from their report_date.
        factor_df = FinanceFactor.query_data(
            region=self.region,
            entity_id=entity.id,
            columns=[
                FinanceFactor.timestamp,
                FinanceFactor.report_date,
                FinanceFactor.id,
            ],
            filters=[
                FinanceFactor.timestamp != FinanceFactor.report_date,
                FinanceFactor.report_date >= stale_rows[0].report_date,
                FinanceFactor.report_date <= stale_rows[-1].report_date,
            ])

        factor_available = pd_is_not_null(factor_df)
        if factor_available:
            index_df(factor_df, index='report_date', time_field='report_date')

        for row in stale_rows:
            if factor_available and row.report_date in factor_df.index:
                row.timestamp = factor_df.at[row.report_date, 'timestamp']
                self.logger.info(
                    'db fill {} {} timestamp:{} for report_date:{}'.format(
                        self.data_schema, entity.id, row.timestamp,
                        row.report_date))
                db_session = get_db_session(region=self.region,
                                            provider=self.provider,
                                            data_schema=self.data_schema)
                db_session.commit()
            else:
                # No local source for this report date: ask jq instead.
                self.fill_timestamp_with_jq(entity, row)
Esempio n. 6
0
def get_order_securities(trader_name):
    """Return the distinct ``entity_id`` values found in a trader's orders.

    :param trader_name: value matched against ``Order.trader_name``
    :return: list with one entity id per group of orders
    """
    db_session = get_db_session(region=Region.CHN,
                                provider=Provider.ZVT,
                                data_schema=Order)
    query = db_session.query(Order.entity_id)
    query = query.filter(Order.trader_name == trader_name)
    rows = query.group_by(Order.entity_id).all()

    # Every row is a one-column result; unwrap it.
    return [row[0] for row in rows]
Esempio n. 7
0
def clear_trader(trader_name, session=None):
    """Delete every stored record of a trader and commit.

    Removes the TraderInfo, AccountStats, Position and Order rows whose
    ``trader_name`` matches, in that order.

    :param trader_name: name of the trader to wipe
    :param session: session to use; when falsy one is obtained from
        ``get_db_session`` for the TraderInfo schema
    """
    if not session:
        session = get_db_session(region=Region.CHN,
                                 provider=Provider.ZVT,
                                 data_schema=TraderInfo)

    for schema in (TraderInfo, AccountStats, Position, Order):
        session.query(schema).filter(
            schema.trader_name == trader_name).delete()

    session.commit()
Esempio n. 8
0
    def fill_timestamp_with_jq(self, security_item, the_data):
        """Set ``the_data.timestamp`` to the publication date reported by jq.

        Queries the 'indicator' table for the 'pubDate' of the report period
        of ``the_data.report_date``. When a result comes back, its first
        'pubDate' becomes the new timestamp. The outcome is logged and the
        session is committed whether or not a date was found.

        :param security_item: entity the record belongs to
        :param the_data: record exposing ``report_date`` and ``timestamp``
        """
        jq_code = to_jq_entity_id(security_item)
        report_period = to_jq_report_period(the_data.report_date)

        # Get the report's published date from jq.
        pub_df = jq_get_fundamentals(table='indicator',
                                     code=jq_code,
                                     columns='pubDate',
                                     date=report_period,
                                     count=None,
                                     parse_dates=['pubDate'])
        if pd_is_not_null(pub_df):
            the_data.timestamp = to_pd_timestamp(pub_df['pubDate'][0])

        message = 'jq fill {} {} timestamp:{} for report_date:{}'.format(
            self.data_schema.__name__, security_item.id, the_data.timestamp,
            the_data.report_date)
        self.logger.info(message)

        db_session = get_db_session(region=self.region,
                                    provider=self.provider,
                                    data_schema=self.data_schema)
        db_session.commit()
Esempio n. 9
0
def get_group(region: Region,
              provider: Provider,
              data_schema,
              column,
              group_func=func.count,
              session=None):
    """Group ``data_schema`` rows by ``column`` and return the result as a frame.

    :param region: region used to obtain a session when none is given
    :param provider: provider used to obtain a session when none is given
    :param data_schema: schema used to obtain a session when none is given
    :param column: column to select and group by
    :param group_func: aggregate applied to ``column`` and selected next to
        it; when falsy only the grouped column itself is selected
    :param session: session to use; a falsy value triggers ``get_db_session``
    :return: DataFrame read from the grouped query
    """
    if not session:
        session = get_db_session(region=region,
                                 provider=provider,
                                 data_schema=data_schema)

    selected = (column, group_func(column)) if group_func else (column,)
    grouped = session.query(*selected).group_by(column)

    return pd.read_sql(grouped.statement, grouped.session.bind)
Esempio n. 10
0
def del_data(region: Region,
             data_schema: Type[Mixin],
             filters: List = None,
             provider: Provider = None):
    """Delete rows of ``data_schema`` and commit.

    :param region: region used to pick the session (and the default provider)
    :param data_schema: schema whose rows are deleted
    :param filters: optional filter expressions, applied one after another;
        without any filter every row of the schema is deleted
    :param provider: provider to use; when falsy the first entry of
        ``data_schema.providers[region]`` is taken
    """
    provider = provider or data_schema.providers[region][0]

    db_session = get_db_session(region=region,
                                provider=provider,
                                data_schema=data_schema)

    doomed = db_session.query(data_schema)
    for criterion in filters or []:
        doomed = doomed.filter(criterion)

    doomed.delete()
    db_session.commit()
Esempio n. 11
0
    def on_start(self):
        """Run all selectors, then persist a TraderInfo row for this trader.

        ``entity_ids``, ``exchanges`` and ``codes`` are stored JSON-encoded,
        or as None when the corresponding attribute is empty or unset.
        """
        # Run every selector over the history data first.
        for selector in self.selectors:
            selector.run()

        entity_ids_json = json.dumps(self.entity_ids) if self.entity_ids else None
        exchanges_json = json.dumps(self.exchanges) if self.exchanges else None
        codes_json = json.dumps(self.codes) if self.codes else None

        trader_record = TraderInfo(
            id=self.trader_name,
            entity_id=f'trader_zvt_{self.trader_name}',
            timestamp=self.start_timestamp,
            trader_name=self.trader_name,
            entity_type=self.entity_schema.__name__.lower(),
            entity_ids=entity_ids_json,
            exchanges=exchanges_json,
            codes=codes_json,
            start_timestamp=self.start_timestamp,
            end_timestamp=self.end_timestamp,
            provider=self.provider,
            level=self.level.value,
            real_time=self.real_time,
            kdata_use_begin_time=self.kdata_use_begin_time,
            kdata_adjust_type=self.adjust_type.value)

        db_session = get_db_session(region=Region.CHN,
                                    provider=Provider.ZVT,
                                    data_schema=TraderInfo)
        db_session.add(trader_record)
        db_session.commit()
 def recompute_qfq(self, entity, qfq_factor, last_timestamp):
     """Rescale stored forward-adjusted (qfq) prices by ``qfq_factor``.

     Does nothing when the factor equals 0 or no kdata exists before
     ``last_timestamp``. Otherwise open, close, high and low of every
     earlier row are multiplied by the factor, rounded to two decimals,
     bulk-saved and committed.

     :param entity: entity whose kdata is rescaled
     :param qfq_factor: multiplier applied to the four price fields
     :param last_timestamp: only rows strictly before this are touched
     """
     if qfq_factor == 0:
         return

     earlier_rows = get_kdata(
         region=self.region,
         provider=self.provider,
         entity_id=entity.id,
         level=self.level.value,
         order=self.data_schema.timestamp.asc(),
         return_type='domain',
         filters=[self.data_schema.timestamp < last_timestamp])
     if not earlier_rows:
         return

     self.logger.info('recomputing {} qfq kdata,factor is:{}'.format(entity.code, qfq_factor))
     for row in earlier_rows:
         for price_field in ('open', 'close', 'high', 'low'):
             rescaled = round(getattr(row, price_field) * qfq_factor, 2)
             setattr(row, price_field, rescaled)

     db_session = get_db_session(region=self.region,
                                 provider=self.provider,
                                 data_schema=self.data_schema)
     db_session.bulk_save_objects(earlier_rows)
     db_session.commit()
Esempio n. 13
0
def get_data(region: Region,
             data_schema,
             ids: List[str] = None,
             entity_ids: List[str] = None,
             entity_id: str = None,
             codes: List[str] = None,
             code: str = None,
             level: Union[IntervalLevel, str] = None,
             provider: Provider = Provider.Default,
             columns: List = None,
             col_label: dict = None,
             return_type: str = 'df',
             start_timestamp: Union[pd.Timestamp, str] = None,
             end_timestamp: Union[pd.Timestamp, str] = None,
             filters: List = None,
             session: Session = None,
             order=None,
             limit: int = None,
             index: Union[str, list] = None,
             time_field: str = 'timestamp',
             fun=None):
    """Query ``data_schema`` and return the result in the requested shape.

    :param region: region used to validate the provider and pick the session
    :param data_schema: schema to query; must not be None
    :param ids: keep only rows whose ``id`` is in this list
    :param entity_ids: keep only rows whose ``entity_id`` is in this list
    :param entity_id: keep only rows with this ``entity_id``
    :param codes: keep only rows whose ``code`` is in this list
    :param code: keep only rows with this ``code``
    :param level: accepted but not applied; different levels are stored in
        different schemas
    :param provider: data provider; must be registered for ``region``
    :param columns: columns to select, given either as schema attributes or
        as attribute names; the time column is always added. The caller's
        list is never modified.
    :param col_label: mapping of column name to the label it is selected as
    :param return_type: 'df' (DataFrame indexed by ``id``), 'domain' (list
        of query results), 'dict' (list of each result's ``__dict__``) or
        'func' (scalar of the query); any other value yields None
    :param start_timestamp: forwarded to ``common_filter``
    :param end_timestamp: forwarded to ``common_filter``
    :param filters: forwarded to ``common_filter``
    :param session: session to use; a falsy value triggers ``get_db_session``
    :param order: forwarded to ``common_filter``
    :param limit: forwarded to ``common_filter``
    :param index: for 'df' results, re-index a non-empty frame via ``index_df``
    :param time_field: name of the schema attribute used as time column
    :param fun: when given, the query selects this expression only
    :return: see ``return_type``
    """
    assert data_schema is not None
    assert provider.value is not None
    assert provider in zvt_context.providers[region]

    step1 = time.time()
    precision_str = '{' + ':>{},.{}f'.format(8, 4) + '}'
    debug_timing = zvt_config['debug'] == 2

    def _elapsed():
        # Formatted seconds since the call started, for debug logging.
        return precision_str.format(time.time() - step1)

    if not session:
        session = get_db_session(region=region,
                                 provider=provider,
                                 data_schema=data_schema)

    # Plain attribute lookup; no need to eval() a built-up expression.
    time_col = getattr(data_schema, time_field)

    if fun is not None:
        query = session.query(fun)
    elif columns:
        if isinstance(columns[0], str):
            # Column names were given: resolve them on the schema.
            resolved = []
            for col in columns:
                assert isinstance(col, str)
                resolved.append(getattr(data_schema, col))
            columns = resolved
        else:
            # Work on a copy so the caller's list is left untouched.
            columns = list(columns)

        # Make sure the time column is always selected.
        if time_col not in columns:
            columns.append(time_col)

        if col_label:
            columns = [
                col.label(col_label.get(col.name))
                if col.name in col_label else col
                for col in columns
            ]

        query = session.query(*columns)
    else:
        query = session.query(data_schema)

    if debug_timing:
        logger.debug("get_data query column: {}".format(_elapsed()))

    if entity_id is not None:
        query = query.filter(data_schema.entity_id == entity_id)
    if entity_ids is not None:
        query = query.filter(data_schema.entity_id.in_(entity_ids))
    if code is not None:
        query = query.filter(data_schema.code == code)
    if codes is not None:
        query = query.filter(data_schema.code.in_(codes))
    if ids is not None:
        query = query.filter(data_schema.id.in_(ids))

    # Different levels are always stored in different schemas, so the
    # ``level`` parameter is intentionally not used for filtering.

    query = common_filter(query,
                          data_schema=data_schema,
                          start_timestamp=start_timestamp,
                          end_timestamp=end_timestamp,
                          filters=filters,
                          order=order,
                          limit=limit,
                          time_field=time_field)

    if debug_timing:
        logger.debug("get_data query common: {}".format(_elapsed()))

    if return_type == 'func':
        return query.scalar()

    if return_type == 'df':
        df = pd.read_sql(query.statement, query.session.bind, index_col=['id'])
        if pd_is_not_null(df) and index:
            df = index_df(df, index=index, time_field=time_field)

        if debug_timing:
            logger.debug("get_data do query cost: {} type: {} size: {}".format(
                _elapsed(), return_type, len(df)))
        return df

    if return_type in ('domain', 'dict'):
        if debug_timing:
            with profiled():
                rows = query.all()
        else:
            rows = query.all()

        if return_type == 'domain':
            result = rows
        else:
            result = [item.__dict__ for item in rows]

        if debug_timing:
            res_cnt = len(result) if result else 0
            logger.debug(
                "get_data do query cost: {} type: {} limit: {} size: {}".
                format(_elapsed(), return_type, limit, res_cnt))

        return result

    # Unknown return_type: nothing is returned, as before.
    return None
Esempio n. 14
0
    def process_loop(self, entity, http_session):
        """Fetch company profile and issuance data for one stock and commit.

        Two Eastmoney endpoints are posted to in turn (with ``self.sleep()``
        in between); the answers are copied onto ``entity``. If either
        request yields None the method returns early without committing.

        :param entity: the StockDetail to fill; its ``exchange`` must be
            'sh' or 'sz'
        :param http_session: session handed to ``sync_post``
        :raises ValueError: when ``entity.exchange`` is neither 'sh' nor 'sz'
            (this used to surface as an UnboundLocalError on ``fc``)
        """
        assert isinstance(entity, StockDetail)

        step1 = time.time()
        precision_str = '{' + ':>{},.{}f'.format(8, 4) + '}'

        self.result = None

        if entity.exchange == 'sh':
            fc = "{}01".format(entity.code)
        elif entity.exchange == 'sz':
            fc = "{}02".format(entity.code)
        else:
            raise ValueError("unsupported exchange {!r} for entity {}".format(
                entity.exchange, entity.id))

        # Basic company profile.
        # NOTE(review): "SecurityCode" is a fixed value here rather than being
        # derived from the entity - confirm the endpoint only relies on "fc".
        param = {"color": "w", "fc": fc, "SecurityCode": "SZ300059"}
        url = 'https://emh5.eastmoney.com/api/GongSiGaiKuang/GetJiBenZiLiao'
        json_result = sync_post(http_session, url, json=param)
        if json_result is None:
            return

        resp_json = json_result['JiBenZiLiao']

        entity.profile = resp_json['CompRofile']
        entity.main_business = resp_json['MainBusiness']
        entity.date_of_establishment = to_pd_timestamp(resp_json['FoundDate'])

        # Related industry: '-' separated in the response, stored ',' separated.
        entity.industry = ','.join(resp_json['Industry'].split('-'))

        # Related concepts.
        entity.sector = resp_json['Block']

        # Related area.
        entity.area = resp_json['Provice']

        self.sleep()

        # Issuance-related data.
        param = {"color": "w", "fc": fc}
        url = 'https://emh5.eastmoney.com/api/GongSiGaiKuang/GetFaXingXiangGuan'
        json_result = sync_post(http_session, url, json=param)
        if json_result is None:
            return

        resp_json = json_result['FaXingXiangGuan']

        entity.issue_pe = to_float(resp_json['PEIssued'])
        entity.price = to_float(resp_json['IssuePrice'])
        entity.issues = to_float(resp_json['ShareIssued'])
        entity.raising_fund = to_float((resp_json['NetCollection']))
        entity.net_winning_rate = pct_to_float(resp_json['LotRateOn'])

        session = get_db_session(region=self.region,
                                 provider=self.provider,
                                 data_schema=self.data_schema)
        session.commit()

        cost = precision_str.format(time.time() - step1)

        prefix = "finish~ " if zvt_config['debug'] else ""
        postfix = "\n" if zvt_config['debug'] else ""

        if self.result is not None:
            self.logger.info(
                "{}{}, {}, time: {}, size: {:>7,}, date: [ {}, {} ]{}".format(
                    prefix, self.data_schema.__name__, entity.id, cost,
                    self.result[0], self.result[1], self.result[2], postfix))
        else:
            self.logger.info("{}{}, {}, time: {}{}".format(
                prefix, self.data_schema.__name__, entity.id, cost, postfix))
Esempio n. 15
0
    def update_position(self, current_position, order_amount, current_price,
                        order_type, timestamp):
        """Apply an order to the account cash and the position, then store it.

        :param current_position: the Position being traded; its amounts,
            available amounts and average prices are updated in place
        :param order_amount: quantity being ordered
        :param current_price: price the order is executed at
        :param order_type: one of ORDER_TYPE_LONG, ORDER_TYPE_SHORT,
            ORDER_TYPE_CLOSE_LONG or ORDER_TYPE_CLOSE_SHORT; any other value
            leaves cash and position untouched but still records an order
        :param timestamp: time of the order, used in the order id and stored
            on the Order row
        :raises NotEnoughMoneyError: when opening a position costs more than
            the available cash and ``rich_mode`` is off
        """
        if order_type == ORDER_TYPE_LONG:
            need_money = (order_amount * current_price) * (1 + self.slippage +
                                                           self.buy_cost)
            if self.account.cash < need_money:
                if self.rich_mode:
                    self.input_money()
                else:
                    raise NotEnoughMoneyError()

            self.account.cash -= need_money

            # Compute the new average price; an empty resulting position has
            # no meaningful average and must not be divided by.
            long_amount = current_position.long_amount + order_amount
            if long_amount == 0:
                current_position.average_long_price = 0
            else:
                current_position.average_long_price = (
                    current_position.average_long_price *
                    current_position.long_amount +
                    current_price * order_amount) / long_amount

            current_position.long_amount = long_amount

            if current_position.trading_t == 0:
                current_position.available_long += order_amount

        elif order_type == ORDER_TYPE_SHORT:
            need_money = (order_amount * current_price) * (1 + self.slippage +
                                                           self.buy_cost)
            if self.account.cash < need_money:
                if self.rich_mode:
                    self.input_money()
                else:
                    raise NotEnoughMoneyError()

            self.account.cash -= need_money

            # Same zero guard as on the long side.
            short_amount = current_position.short_amount + order_amount
            if short_amount == 0:
                current_position.average_short_price = 0
            else:
                current_position.average_short_price = (
                    current_position.average_short_price *
                    current_position.short_amount +
                    current_price * order_amount) / short_amount

            current_position.short_amount = short_amount

            if current_position.trading_t == 0:
                current_position.available_short += order_amount

        elif order_type == ORDER_TYPE_CLOSE_LONG:
            self.account.cash += (order_amount * current_price *
                                  (1 - self.slippage - self.sell_cost))
            # FIXME: if the position is not fully sold, recompute the
            # average price.

            current_position.available_long -= order_amount
            current_position.long_amount -= order_amount

        elif order_type == ORDER_TYPE_CLOSE_SHORT:
            self.account.cash += 2 * (order_amount *
                                      current_position.average_short_price)
            self.account.cash -= order_amount * current_price * (
                1 + self.slippage + self.sell_cost)

            current_position.available_short -= order_amount
            current_position.short_amount -= order_amount

        # Save the order info to the db.
        order_id = '{}_{}_{}_{}'.format(
            self.trader_name, order_type, current_position.entity_id,
            to_time_str(timestamp, TIME_FORMAT_ISO8601))
        order = Order(id=order_id,
                      timestamp=to_pd_timestamp(timestamp),
                      trader_name=self.trader_name,
                      entity_id=current_position.entity_id,
                      order_price=current_price,
                      order_amount=order_amount,
                      order_type=order_type,
                      level=self.level.value,
                      status='success')

        session = get_db_session(region=Region.CHN,
                                 provider=Provider.ZVT,
                                 data_schema=Order)
        session.add(order)
        session.commit()
Esempio n. 16
0
    def on_trading_close(self, timestamp):
        """Re-value the account at the close of trading and persist it.

        Positions without any long or short amount are dropped. For each
        remaining position the latest daily close up to ``timestamp`` is
        looked up, the available amounts are reset to the held amounts, and
        value, profit and profit rate are refreshed. Positions and account
        get ids derived from the trader name and ``timestamp``; the account
        is then added to the session and committed.

        When no close price can be found for a position a warning is logged
        and its value/profit are left as they were. A position without long
        cost basis (e.g. short only) gets a ``profit_rate`` of 0 instead of
        triggering a division by zero.

        :param timestamp: the closing time being processed
        """
        self.logger.info('on_trading_close:{}'.format(timestamp))
        # Remove the empty positions.
        self.account.positions = [
            position for position in self.account.positions
            if position.long_amount > 0 or position.short_amount > 0
        ]

        # Id of the account snapshot being (re)computed.
        the_id = '{}_{}'.format(self.trader_name,
                                to_time_str(timestamp, TIME_FORMAT_ISO8601))

        self.account.value = 0
        self.account.all_value = 0
        for position in self.account.positions:
            entity_type, _, _ = decode_entity_id(position.entity_id)
            data_schema = get_kdata_schema(entity_type,
                                           level=IntervalLevel.LEVEL_1DAY,
                                           adjust_type=self.adjust_type)

            kdata = get_kdata(provider=self.provider,
                              level=IntervalLevel.LEVEL_1DAY,
                              entity_id=position.entity_id,
                              order=data_schema.timestamp.desc(),
                              end_timestamp=timestamp,
                              limit=1,
                              adjust_type=self.adjust_type)

            # Take the first row by position; an empty result leaves the
            # price unset so the warning branch below is reached.
            closing_price = None
            if kdata is not None and not kdata.empty:
                closing_price = kdata['close'].iloc[0]

            position.available_long = position.long_amount
            position.available_short = position.short_amount

            if closing_price:
                if (position.long_amount
                        is not None) and position.long_amount > 0:
                    position.value = position.long_amount * closing_price
                    self.account.value += position.value
                elif (position.short_amount
                      is not None) and position.short_amount > 0:
                    position.value = 2 * (position.short_amount *
                                          position.average_short_price)
                    position.value -= position.short_amount * closing_price
                    self.account.value += position.value

                # Refresh profit.
                long_cost = position.average_long_price * position.long_amount
                position.profit = (closing_price - position.average_long_price
                                   ) * position.long_amount
                if long_cost:
                    position.profit_rate = position.profit / long_cost
                else:
                    position.profit_rate = 0

            else:
                self.logger.warning(
                    'could not refresh close value for position:{},timestamp:{}'
                    .format(position.entity_id, timestamp))

            position.id = '{}_{}_{}'.format(
                self.trader_name, position.entity_id,
                to_time_str(timestamp, TIME_FORMAT_ISO8601))
            position.timestamp = to_pd_timestamp(timestamp)
            position.account_stats_id = the_id

        self.account.id = the_id
        self.account.all_value = self.account.value + self.account.cash
        self.account.closing = True
        self.account.timestamp = to_pd_timestamp(timestamp)
        self.account.profit = (
            self.account.all_value -
            self.account.input_money) / self.account.input_money

        session = get_db_session(region=Region.CHN,
                                 provider=Provider.ZVT,
                                 data_schema=AccountStats)
        session.add(self.account)
        session.commit()

        account_info = f'on_trading_close,holding size:{len(self.account.positions)} profit:{self.account.profit} input_money:{self.account.input_money} ' \
                       f'cash:{self.account.cash} value:{self.account.value} all_value:{self.account.all_value}'
        self.logger.info(account_info)
Esempio n. 17
0
def df_to_db(df: pd.DataFrame,
             ref_df: pd.DataFrame,
             region: Region,
             data_schema: DeclarativeMeta,
             provider: Provider,
             drop_duplicates: bool = True,
             fix_duplicate_way: str = 'ignore',
             force_update=False) -> object:
    """Write the rows of ``df`` that belong in ``data_schema``'s table.

    Only the columns that exist in the schema are kept. In force-update mode
    the rows with the same ids are deleted from the table first and the whole
    frame is written; otherwise rows whose id already appears in the index
    of ``ref_df`` are skipped.

    :param df: data to store; must contain an ``id`` column. Duplicate ids
        are dropped in place (last one wins) when ``drop_duplicates`` is set.
    :param ref_df: frame indexed by the ids already stored; when None it is
        loaded through ``get_data``
    :param region: region used for the session and the write
    :param data_schema: schema describing the target table
    :param provider: provider used for the session
    :param drop_duplicates: whether to drop duplicated ids from ``df``
    :param fix_duplicate_way: currently unused
    :param force_update: replace existing rows instead of skipping them
    :return: 0 when there is nothing usable to write, otherwise whatever
        ``to_postgresql`` reports for the written rows
    """
    step1 = time.time()
    precision_str = '{' + ':>{},.{}f'.format(8, 4) + '}'

    if not pd_is_not_null(df):
        return 0

    if drop_duplicates and df.duplicated(subset='id').any():
        df.drop_duplicates(subset='id', keep='last', inplace=True)

    schema_cols = get_schema_columns(data_schema)
    cols = list(set(df.columns.tolist()) & set(schema_cols))

    if not cols:
        logger.error("{} get columns failed".format(data_schema.__tablename__))
        return 0

    df = df[cols]

    if force_update:
        # Force-update mode: delete the rows with the same ids, then write
        # the new data back. The ids are passed as bound parameters rather
        # than being formatted into the SQL text, so quotes inside an id can
        # neither break nor alter the statement.
        ids = df["id"].tolist()

        session = get_db_session(region=region,
                                 provider=provider,
                                 data_schema=data_schema)
        # synchronize_session=False: like the raw statement it replaces, the
        # delete does not try to reconcile objects loaded in the session.
        session.query(data_schema).filter(
            data_schema.id.in_(ids)).delete(synchronize_session=False)
        session.commit()
        df_new = df

    else:
        if ref_df is None:
            ref_df = get_data(region=region,
                              provider=provider,
                              columns=['id', 'timestamp'],
                              data_schema=data_schema,
                              return_type='df')
        if ref_df.empty:
            df_new = df
        else:
            df_new = df[~df.id.isin(ref_df.index)]

        # NOTE: the id alone cannot decide whether a row is new; a full
        # comparison would be needed, so no 'add' duplicate handling is done.

    cost = precision_str.format(time.time() - step1)
    logger.debug("remove duplicated: {}".format(cost))

    saved = 0
    if pd_is_not_null(df_new):
        saved = to_postgresql(region, df_new, data_schema.__tablename__)

    cost = precision_str.format(time.time() - step1)
    logger.debug("write db: {}, size: {}".format(cost, saved))

    return saved