Example #1
    def __init__(self,
                 entity_type='stock',
                 exchanges=['sh', 'sz'],
                 entity_ids=None,
                 codes=None,
                 batch_size=10,
                 force_update=False,
                 sleeping_time=10,
                 default_size=2000,
                 one_shot=False,
                 fix_duplicate_way='add',
                 start_timestamp=None,
                 end_timestamp=None,
                 contain_unfinished_data=False,
                 level=IntervalLevel.LEVEL_1DAY,
                 kdata_use_begin_time=False,
                 close_hour=0,
                 close_minute=0,
                 one_day_trading_minutes=24 * 60) -> None:
        super().__init__(entity_type, exchanges, entity_ids, codes, batch_size,
                         force_update, sleeping_time, default_size, one_shot,
                         fix_duplicate_way, start_timestamp, end_timestamp)

        self.level = IntervalLevel(level)
        # FIXME: unfinished data should be removed when recording; always set this to False for now
        self.contain_unfinished_data = contain_unfinished_data
        self.kdata_use_begin_time = kdata_use_begin_time
        self.close_hour = close_hour
        self.close_minute = close_minute
        self.one_day_trading_minutes = one_day_trading_minutes
Example #2
def is_in_same_interval(t1: pd.Timestamp, t2: pd.Timestamp,
                        level: IntervalLevel):
    t1 = to_pd_timestamp(t1)
    t2 = to_pd_timestamp(t2)
    if level == IntervalLevel.LEVEL_1WEEK:
        return t1.week == t2.week
    if level == IntervalLevel.LEVEL_1MON:
        return t1.month == t2.month

    return level.floor_timestamp(t1) == level.floor_timestamp(t2)
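
For reference, the same "same interval" check can be sketched with plain pandas period arithmetic; the snippet below is only an illustration and, unlike the code above, it also takes the year into account when comparing weeks and months.

# a minimal standalone sketch, not the zvt implementation
import pandas as pd

def in_same_interval_sketch(t1, t2, freq: str) -> bool:
    # freq is a pandas offset alias: 'W' for week, 'M' for month, 'D' for day
    return pd.Timestamp(t1).to_period(freq) == pd.Timestamp(t2).to_period(freq)

# usage
assert in_same_interval_sketch('2020-01-06', '2020-01-10', 'W')      # same week
assert not in_same_interval_sketch('2020-01-31', '2020-02-01', 'M')  # different months
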
Example #3
def to_high_level_kdata(kdata_df: pd.DataFrame, to_level: IntervalLevel):
    def to_close(s):
        if pd_is_not_null(s):
            return s[-1]

    def to_open(s):
        if pd_is_not_null(s):
            return s[0]

    def to_high(s):
        return np.max(s)

    def to_low(s):
        return np.min(s)

    def to_sum(s):
        return np.sum(s)

    original_level = kdata_df['level'][0]
    entity_id = kdata_df['entity_id'][0]
    provider = kdata_df['provider'][0]
    name = kdata_df['name'][0]
    code = kdata_df['code'][0]

    entity_type, _, _ = decode_entity_id(entity_id=entity_id)

    assert IntervalLevel(original_level) <= IntervalLevel.LEVEL_1DAY
    assert IntervalLevel(original_level) < IntervalLevel(to_level)

    df: pd.DataFrame = None
    if to_level == IntervalLevel.LEVEL_1WEEK:
        # loffset=-2 labels each week with its Friday
        if entity_type == 'stock':
            df = kdata_df.resample('W', loffset=pd.DateOffset(days=-2)).apply({'close': to_close,
                                                                               'open': to_open,
                                                                               'high': to_high,
                                                                               'low': to_low,
                                                                               'volume': to_sum,
                                                                               'turnover': to_sum})
        else:
            df = kdata_df.resample('W', loffset=pd.DateOffset(days=-2)).apply({'close': to_close,
                                                                               'open': to_open,
                                                                               'high': to_high,
                                                                               'low': to_low,
                                                                               'volume': to_sum,
                                                                               'turnover': to_sum})
    df = df.dropna()
    # id        entity_id  timestamp   provider    code  name level
    df['entity_id'] = entity_id
    df['provider'] = provider
    df['code'] = code
    df['name'] = name

    return df
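
The weekly aggregation above can also be sketched with plain pandas and no zvt helpers; the column names and the DatetimeIndex are assumptions, and 'W-FRI' is used so that labels fall on Friday, which is roughly what the deprecated loffset=pd.DateOffset(days=-2) trick achieves for trading days.

# a minimal plain-pandas sketch of daily -> weekly OHLCV aggregation
import pandas as pd

def to_weekly_kdata_sketch(day_df: pd.DataFrame) -> pd.DataFrame:
    # day_df is assumed to have a DatetimeIndex and these columns
    weekly = day_df.resample('W-FRI').agg({
        'open': 'first',
        'high': 'max',
        'low': 'min',
        'close': 'last',
        'volume': 'sum',
        'turnover': 'sum',
    })
    return weekly.dropna()
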
Example #4
    def __init__(self,
                 exchanges=['sh', 'sz'],
                 entity_ids=None,
                 codes=None,
                 batch_size=10,
                 force_update=True,
                 sleeping_time=10,
                 default_size=2000,
                 real_time=False,
                 fix_duplicate_way='ignore',
                 start_timestamp=None,
                 end_timestamp=None,
                 level=IntervalLevel.LEVEL_1WEEK,
                 kdata_use_begin_time=False,
                 close_hour=15,
                 close_minute=0,
                 one_day_trading_minutes=4 * 60) -> None:
        level = IntervalLevel(level)
        self.data_schema = get_kdata_schema(entity_type='stock', level=level)
        self.jq_trading_level = to_jq_trading_level(level)

        super().__init__('stock', exchanges, entity_ids, codes, batch_size,
                         force_update, sleeping_time, default_size, real_time,
                         fix_duplicate_way, start_timestamp, end_timestamp,
                         close_hour, close_minute, level, kdata_use_begin_time,
                         one_day_trading_minutes)

        self.factor = 0
        self.last_timestamp = None

        auth(zvt_env['jq_username'], zvt_env['jq_password'])
Example #5
def coin_finished_timestamp(timestamp: pd.Timestamp, level: IntervalLevel):
    timestamp = to_pd_timestamp(timestamp)

    if timestamp.microsecond != 0:
        return False

    return timestamp.minute % level.to_minute() == 0
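
The same boundary test can be written without zvt types; the sketch below takes the level as a plain number of minutes (as level.to_minute() suggests) and, as an extra assumption, also rejects timestamps with non-zero seconds.

# a standalone sketch of the "finished bar boundary" check
import pandas as pd

def finished_timestamp_sketch(timestamp, level_minutes: int) -> bool:
    ts = pd.Timestamp(timestamp)
    if ts.second != 0 or ts.microsecond != 0:
        return False
    return ts.minute % level_minutes == 0

# usage with a 5-minute level
assert finished_timestamp_sketch('2020-01-06 09:35:00', 5)
assert not finished_timestamp_sketch('2020-01-06 09:37:00', 5)
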
Example #6
    def __init__(self,
                 entity_type='stock',
                 exchanges=['sh', 'sz'],
                 entity_ids=None,
                 codes=None,
                 batch_size=10,
                 force_update=True,
                 sleeping_time=10,
                 default_size=2000,
                 real_time=False,
                 fix_duplicate_way='ignore',
                 start_timestamp=None,
                 end_timestamp=None,
                 close_hour=0,
                 close_minute=0,
                 # child add
                 level=IntervalLevel.LEVEL_1DAY,
                 kdata_use_begin_time=False,
                 one_day_trading_minutes=24 * 60) -> None:
        super().__init__(entity_type, exchanges, entity_ids, codes, batch_size, force_update, sleeping_time,
                         default_size, real_time, fix_duplicate_way, start_timestamp, end_timestamp, close_hour,
                         close_minute)

        self.level = IntervalLevel(level)
        self.kdata_use_begin_time = kdata_use_begin_time
        self.one_day_trading_minutes = one_day_trading_minutes
Example #7
def get_zen_factor_schema(entity_type: str,
                          level: Union[IntervalLevel, str] = IntervalLevel.LEVEL_1DAY):
    if type(level) == str:
        level = IntervalLevel(level)

    schema_str = '{}{}ZenFactor'.format(entity_type.capitalize(), level.value.capitalize())

    return eval(schema_str)
Example #8
def get_ma_state_stats_schema(entity_type: str,
                              level: Union[IntervalLevel, str] = IntervalLevel.LEVEL_1DAY):
    if type(level) == str:
        level = IntervalLevel(level)

    # ma state stats schema naming rule:
    # {entity_type.capitalize()}{level.value.capitalize()}MaStateStats
    schema_str = '{}{}MaStateStats'.format(entity_type.capitalize(), level.value.capitalize())

    return eval(schema_str)
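
The naming rule itself is easy to check in isolation; the sketch below only builds the schema name and assumes level values look like the CLI choices in Example #12 (e.g. '1d' for the daily level).

# a sketch of the schema naming rule only, without the eval lookup
def ma_state_stats_schema_name(entity_type: str, level_value: str) -> str:
    return '{}{}MaStateStats'.format(entity_type.capitalize(), level_value.capitalize())

# usage
assert ma_state_stats_schema_name('stock', '1d') == 'Stock1dMaStateStats'
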
Example #9
def level_flag(level: IntervalLevel):
    level = IntervalLevel(level)
    if level == IntervalLevel.LEVEL_1DAY:
        return 101
    if level == IntervalLevel.LEVEL_1WEEK:
        return 102
    if level == IntervalLevel.LEVEL_1MON:
        return 103

    assert False
Example #10
def evaluate_size_from_timestamp(start_timestamp,
                                 level: IntervalLevel,
                                 one_day_trading_minutes,
                                 end_timestamp: pd.Timestamp = None):
    """
    given from timestamp,level,one_day_trading_minutes,this func evaluate size of kdata to current.
    it maybe a little bigger than the real size for fetching all the kdata.

    :param start_timestamp:
    :type start_timestamp: pd.Timestamp
    :param level:
    :type level: IntervalLevel
    :param one_day_trading_minutes:
    :type one_day_trading_minutes: int
    """
    if not end_timestamp:
        end_timestamp = pd.Timestamp.now()
    else:
        end_timestamp = to_pd_timestamp(end_timestamp)

    time_delta = end_timestamp - to_pd_timestamp(start_timestamp)

    one_day_trading_seconds = one_day_trading_minutes * 60

    if level == IntervalLevel.LEVEL_1DAY:
        return time_delta.days + 1

    if level == IntervalLevel.LEVEL_1WEEK:
        return int(math.ceil(time_delta.days / 7)) + 1

    if level == IntervalLevel.LEVEL_1MON:
        return int(math.ceil(time_delta.days / 30)) + 1

    if time_delta.days > 0:
        seconds = (time_delta.days + 1) * one_day_trading_seconds
        return int(math.ceil(seconds / level.to_second())) + 1
    else:
        seconds = time_delta.total_seconds()
        return min(
            int(math.ceil(seconds / level.to_second())) + 1,
            one_day_trading_seconds / level.to_second() + 1)
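
To make the intraday branch concrete, here is a small worked example under assumed inputs: a 5-minute level (300 seconds), 4 * 60 trading minutes per day and a span of 3 whole days; the numbers are purely illustrative.

# worked example of the time_delta.days > 0 branch
import math

days = 3                               # time_delta.days
one_day_trading_seconds = 4 * 60 * 60  # 14400
level_seconds = 5 * 60                 # 300

seconds = (days + 1) * one_day_trading_seconds       # 57600
size = int(math.ceil(seconds / level_seconds)) + 1   # 193
assert size == 193
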
Example #11
def gen_quote_domain(providers: List[str], entity_type: str, levels):
    tables = []
    for level in levels:
        level = IntervalLevel(level)

        cap_entity_type = entity_type.capitalize()
        cap_level = level.value.capitalize()

        if level != IntervalLevel.LEVEL_TICK:
            kdata_common = f'{cap_entity_type}KdataCommon'
        else:
            kdata_common = f'{cap_entity_type}TickCommon'

        class_name = f'{cap_entity_type}{cap_level}Kdata'
        table_name = f'{entity_type}_{level.value}_kdata'
        tables.append(table_name)

        domain_template = f'''# -*- coding: utf-8 -*-
# this file is generated by the gen_quote_domain function, don't change it
from sqlalchemy.ext.declarative import declarative_base

from zvdata.contract import register_schema
from zvt.domain.quotes import {kdata_common}

KdataBase = declarative_base()


class {class_name}(KdataBase, {kdata_common}):
    __tablename__ = '{table_name}'


register_schema(providers={providers}, db_name='{table_name}', schema_base=KdataBase)

__all__ = ['{class_name}']
'''
        # generate the domain
        with open(os.path.join(entity_type, f'{table_name}.py'),
                  'w') as outfile:
            outfile.write(domain_template)

    # generate the package
    imports = [
        f'from zvt.domain.quotes.{entity_type}.{table} import *'
        for table in tables
    ]
    imports_str = '\n'.join(imports)

    package_template = '''# -*- coding: utf-8 -*-
# this file is generated by the gen_quote_domain function, don't change it
''' + imports_str

    with open(os.path.join(entity_type, '__init__.py'), 'w') as outfile:
        outfile.write(package_template)
Example #12
                         keep_all_timestamp, fill_method, effective_number,
                         transformer, acc, persist_factor, dry_run)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--level',
                        help='trading level',
                        default='1d',
                        choices=[item.value for item in IntervalLevel])
    parser.add_argument('--start', help='start code', default='000001')
    parser.add_argument('--end', help='end code', default='000005')

    args = parser.parse_args()

    level = IntervalLevel(args.level)
    start = args.start
    end = args.end

    entities = get_entities(provider='eastmoney',
                            entity_type='stock',
                            columns=[Stock.entity_id, Stock.code],
                            filters=[Stock.code >= start, Stock.code < end])

    codes = entities.index.to_list()

    factor = ZenFactor(codes=codes,
                       start_timestamp='2005-01-01',
                       end_timestamp=now_pd_timestamp(),
                       level=level)
Example #13
class FixedCycleDataRecorder(TimeSeriesDataRecorder):
    def __init__(self,
                 entity_type='stock',
                 exchanges=['sh', 'sz'],
                 entity_ids=None,
                 codes=None,
                 batch_size=10,
                 force_update=False,
                 sleeping_time=10,
                 default_size=2000,
                 one_shot=False,
                 fix_duplicate_way='add',
                 start_timestamp=None,
                 end_timestamp=None,
                 contain_unfinished_data=False,
                 level=IntervalLevel.LEVEL_1DAY,
                 kdata_use_begin_time=False,
                 close_hour=0,
                 close_minute=0,
                 one_day_trading_minutes=24 * 60) -> None:
        super().__init__(entity_type, exchanges, entity_ids, codes, batch_size,
                         force_update, sleeping_time, default_size, one_shot,
                         fix_duplicate_way, start_timestamp, end_timestamp)

        self.level = IntervalLevel(level)
        # FIXME: unfinished data should be removed when recording; always set this to False for now
        self.contain_unfinished_data = contain_unfinished_data
        self.kdata_use_begin_time = kdata_use_begin_time
        self.close_hour = close_hour
        self.close_minute = close_minute
        self.one_day_trading_minutes = one_day_trading_minutes

    def get_latest_saved_record(self, entity):
        order = eval('self.data_schema.{}.desc()'.format(
            self.get_evaluated_time_field()))

        return get_data(entity_id=entity.id,
                        provider=self.provider,
                        data_schema=self.data_schema,
                        order=order,
                        limit=1,
                        return_type='domain',
                        session=self.session,
                        level=self.level.value)

    def evaluate_start_end_size_timestamps(self, entity):
        # get latest record
        latest_saved_record = self.get_latest_saved_record(entity=entity)

        if latest_saved_record:
            latest_timestamp = latest_saved_record[0].timestamp
        else:
            latest_timestamp = entity.timestamp

        if not latest_timestamp:
            return latest_timestamp, None, self.default_size, None

        current_time = pd.Timestamp.now()
        time_delta = current_time - latest_timestamp

        if self.level == IntervalLevel.LEVEL_1DAY:
            if is_same_date(current_time, latest_timestamp):
                return latest_timestamp, None, 0, None
            return latest_timestamp, None, time_delta.days + 1, None

        # up to today: check the closing time
        # close_hour/close_minute of 0,0 means the market never closes, e.g. coin
        if (self.close_hour != 0
                and self.close_minute != 0) and time_delta.days == 0:
            if latest_timestamp.hour == self.close_hour and latest_timestamp.minute == self.close_minute:
                return latest_timestamp, None, 0, None

        if self.kdata_use_begin_time:
            touching_timestamp = latest_timestamp + pd.Timedelta(
                seconds=self.level.to_second())
        else:
            touching_timestamp = latest_timestamp

        waiting_seconds, size = self.level.count_from_timestamp(
            touching_timestamp,
            one_day_trading_minutes=self.one_day_trading_minutes)
        if not self.one_shot and waiting_seconds and (waiting_seconds > 30):
            t = waiting_seconds / 2
            self.logger.info(
                'level:{},recorded_time:{},touching_timestamp:{},current_time:{},next_ok_time:{},just sleep:{} seconds'
                .format(
                    self.level.value, latest_timestamp, touching_timestamp,
                    current_time, touching_timestamp +
                    pd.Timedelta(seconds=self.level.to_second()), t))
            time.sleep(t)

        return latest_timestamp, None, size, None

    def persist(self, entity, domain_list):
        if domain_list:
            if domain_list[0].timestamp >= domain_list[-1].timestamp:
                first_timestamp = domain_list[-1].timestamp
                last_timestamp = domain_list[0].timestamp
            else:
                first_timestamp = domain_list[0].timestamp
                last_timestamp = domain_list[-1].timestamp

            self.logger.info(
                "persist {} for entity_id:{},time interval:[{},{}]".format(
                    self.data_schema, entity.id, first_timestamp,
                    last_timestamp))

            current_timestamp = now_pd_timestamp()

            saving_datas = domain_list

            # FIXME: remove this logic
            # FIXME: unfinished data should be removed when recording; always set this to False for now
            if is_same_date(current_timestamp,
                            last_timestamp) and self.contain_unfinished_data:
                if current_timestamp.hour >= self.close_hour and current_timestamp.minute >= self.close_minute + 2:
                    # after the day's closing time, the last bar is considered finished
                    saving_datas = domain_list
                else:
                    # ignore unfinished kdata
                    saving_datas = domain_list[:-1]
                    self.logger.info(
                        "ignore kdata for entity_id:{},level:{},timestamp:{},current_timestamp"
                        .format(entity.id, self.level, last_timestamp,
                                current_timestamp))

            self.session.add_all(saving_datas)
            self.session.commit()
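
The eval-based attribute lookups used by this recorder (and by the readers further down) can be expressed with getattr instead; the sketch below is a generic illustration rather than the project's code, and the usage comment only mirrors the names that appear above.

# a minimal sketch: resolve a schema column by name without eval
def column_of(data_schema, field_name: str):
    # equivalent to eval('data_schema.{}'.format(field_name))
    return getattr(data_schema, field_name)

# inside get_latest_saved_record the ordering could then read:
#   order = column_of(self.data_schema, self.get_evaluated_time_field()).desc()
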
Example #14
    def __init__(self,
                 data_schema: object,
                 entity_ids: List[str] = None,
                 entity_type: str = 'stock',
                 exchanges: List[str] = ['sh', 'sz'],
                 codes: List[str] = None,
                 the_timestamp: Union[str, pd.Timestamp] = None,
                 start_timestamp: Union[str, pd.Timestamp] = '2018-01-01',
                 end_timestamp: Union[str, pd.Timestamp] = '2019-06-23',
                 columns: List = None,
                 filters: List = None,
                 order: object = None,
                 limit: int = None,
                 provider: str = 'eastmoney',
                 level: IntervalLevel = IntervalLevel.LEVEL_1DAY,
                 category_field: str = 'entity_id',
                 time_field: str = 'timestamp',
                 trip_timestamp: bool = True,
                 auto_load: bool = True) -> None:
        self.data_schema = data_schema

        self.the_timestamp = the_timestamp
        if the_timestamp:
            self.start_timestamp = the_timestamp
            self.end_timestamp = the_timestamp
        else:
            self.start_timestamp = start_timestamp
            self.end_timestamp = end_timestamp

        self.start_timestamp = to_pd_timestamp(self.start_timestamp)
        self.end_timestamp = to_pd_timestamp(self.end_timestamp)

        self.entity_type = entity_type
        self.exchanges = exchanges
        if codes:
            if type(codes) == str:
                codes = codes.replace(' ', '')
                if codes.startswith('[') and codes.endswith(']'):
                    codes = json.loads(codes)
                else:
                    codes = codes.split(',')

        self.codes = codes
        self.entity_ids = entity_ids

        self.provider = provider
        self.filters = filters
        self.order = order
        self.limit = limit
        if level:
            self.level = IntervalLevel(level)
        else:
            self.level = level
        self.category_field = category_field
        self.time_field = time_field
        self.trip_timestamp = trip_timestamp
        self.auto_load = auto_load

        self.category_column = eval('self.data_schema.{}'.format(self.category_field))
        self.columns = columns

        # we store the data in a DataFrame with a MultiIndex (category_column, timestamp)
        if self.columns:
            # support str
            if type(columns[0]) == str:
                self.columns = []
                for col in columns:
                    self.columns.append(eval('data_schema.{}'.format(col)))

            time_col = eval('self.data_schema.{}'.format(self.time_field))

            # always add category_column and time_field for normalizing
            self.columns = list(set(self.columns) | {self.category_column, time_col})

        self.data_listeners: List[DataListener] = []

        self.data_df: pd.DataFrame = None
        self.normal_data: NormalData = None

        if self.auto_load:
            self.load_data()
Example #15
    def __init__(self,
                 entity_ids: List[str] = None,
                 exchanges: List[str] = ['sh', 'sz'],
                 codes: List[str] = None,
                 start_timestamp: Union[str, pd.Timestamp] = None,
                 end_timestamp: Union[str, pd.Timestamp] = None,
                 provider: str = None,
                 level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
                 trader_name: str = None,
                 real_time: bool = False,
                 kdata_use_begin_time: bool = False,
                 draw_result: bool = True) -> None:

        assert self.entity_schema is not None

        if trader_name:
            self.trader_name = trader_name
        else:
            self.trader_name = type(self).__name__.lower()

        self.trading_signal_listeners = []

        self.selectors: List[TargetSelector] = []

        self.entity_ids = entity_ids

        self.exchanges = exchanges
        self.codes = codes

        self.provider = provider
        # make sure the min-level selector corresponds to the provider and level
        self.level = IntervalLevel(level)
        self.real_time = real_time

        if start_timestamp and end_timestamp:
            self.start_timestamp = to_pd_timestamp(start_timestamp)
            self.end_timestamp = to_pd_timestamp(end_timestamp)
        else:
            assert False

        if real_time:
            logger.info(
                'real_time mode: end_timestamp should be in the future; set it far enough ahead to run forever'
            )
            assert self.end_timestamp >= now_pd_timestamp()

        self.kdata_use_begin_time = kdata_use_begin_time
        self.draw_result = draw_result

        self.account_service = SimAccountService(
            trader_name=self.trader_name,
            timestamp=self.start_timestamp,
            provider=self.provider,
            level=self.level)

        self.add_trading_signal_listener(self.account_service)

        self.init_selectors(entity_ids=entity_ids,
                            entity_schema=self.entity_schema,
                            exchanges=self.exchanges,
                            codes=self.codes,
                            start_timestamp=self.start_timestamp,
                            end_timestamp=self.end_timestamp)

        self.selectors_comparator = self.init_selectors_comparator()

        self.trading_level_asc = list(
            set([IntervalLevel(selector.level)
                 for selector in self.selectors]))
        self.trading_level_asc.sort()

        self.trading_level_desc = list(self.trading_level_asc)
        self.trading_level_desc.reverse()

        self.targets_slot: TargetsSlot = TargetsSlot()

        self.session = get_db_session('zvt', 'business')
        trader = get_trader(session=self.session,
                            trader_name=self.trader_name,
                            return_type='domain',
                            limit=1)

        if trader:
            self.logger.warning(
                "trader:{} has run before,old result would be deleted".format(
                    self.trader_name))
            self.session.query(business.Trader).filter(
                business.Trader.trader_name == self.trader_name).delete()
            self.session.commit()
        self.on_start()
Example #16
class Trader(object):
    logger = logging.getLogger(__name__)

    entity_schema: EntityMixin = None

    def __init__(self,
                 entity_ids: List[str] = None,
                 exchanges: List[str] = ['sh', 'sz'],
                 codes: List[str] = None,
                 start_timestamp: Union[str, pd.Timestamp] = None,
                 end_timestamp: Union[str, pd.Timestamp] = None,
                 provider: str = None,
                 level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
                 trader_name: str = None,
                 real_time: bool = False,
                 kdata_use_begin_time: bool = False,
                 draw_result: bool = True) -> None:

        assert self.entity_schema is not None

        if trader_name:
            self.trader_name = trader_name
        else:
            self.trader_name = type(self).__name__.lower()

        self.trading_signal_listeners = []

        self.selectors: List[TargetSelector] = []

        self.entity_ids = entity_ids

        self.exchanges = exchanges
        self.codes = codes

        self.provider = provider
        # make sure the min-level selector corresponds to the provider and level
        self.level = IntervalLevel(level)
        self.real_time = real_time

        if start_timestamp and end_timestamp:
            self.start_timestamp = to_pd_timestamp(start_timestamp)
            self.end_timestamp = to_pd_timestamp(end_timestamp)
        else:
            assert False

        if real_time:
            logger.info(
                'real_time mode: end_timestamp should be in the future; set it far enough ahead to run forever'
            )
            assert self.end_timestamp >= now_pd_timestamp()

        self.kdata_use_begin_time = kdata_use_begin_time
        self.draw_result = draw_result

        self.account_service = SimAccountService(
            trader_name=self.trader_name,
            timestamp=self.start_timestamp,
            provider=self.provider,
            level=self.level)

        self.add_trading_signal_listener(self.account_service)

        self.init_selectors(entity_ids=entity_ids,
                            entity_schema=self.entity_schema,
                            exchanges=self.exchanges,
                            codes=self.codes,
                            start_timestamp=self.start_timestamp,
                            end_timestamp=self.end_timestamp)

        self.selectors_comparator = self.init_selectors_comparator()

        self.trading_level_asc = list(
            set([IntervalLevel(selector.level)
                 for selector in self.selectors]))
        self.trading_level_asc.sort()

        self.trading_level_desc = list(self.trading_level_asc)
        self.trading_level_desc.reverse()

        self.targets_slot: TargetsSlot = TargetsSlot()

        self.session = get_db_session('zvt', 'business')
        trader = get_trader(session=self.session,
                            trader_name=self.trader_name,
                            return_type='domain',
                            limit=1)

        if trader:
            self.logger.warning(
                "trader:{} has run before,old result would be deleted".format(
                    self.trader_name))
            self.session.query(business.Trader).filter(
                business.Trader.trader_name == self.trader_name).delete()
            self.session.commit()
        self.on_start()

    def on_start(self):
        if not self.selectors:
            raise Exception(
                'please setup self.selectors in init_selectors at first')

        # run all the selectors
        for selector in self.selectors:
            # run on the historical data first
            selector.run()

        if self.entity_ids:
            entity_ids = json.dumps(self.entity_ids)
        else:
            entity_ids = None

        if self.exchanges:
            exchanges = json.dumps(self.exchanges)
        else:
            exchanges = None

        if self.codes:
            codes = json.dumps(self.codes)
        else:
            codes = None

        trader_domain = business.Trader(
            id=self.trader_name,
            timestamp=self.start_timestamp,
            trader_name=self.trader_name,
            entity_type=entity_ids,
            exchanges=exchanges,
            codes=codes,
            start_timestamp=self.start_timestamp,
            end_timestamp=self.end_timestamp,
            provider=self.provider,
            level=self.level.value,
            real_time=self.real_time,
            kdata_use_begin_time=self.kdata_use_begin_time)
        self.session.add(trader_domain)
        self.session.commit()

    def init_selectors(self, entity_ids, entity_schema, exchanges, codes,
                       start_timestamp, end_timestamp):
        """
        implement this to init selectors

        """
        raise NotImplementedError

    def init_selectors_comparator(self):
        """
        override this to customize the selectors_comparator

        """
        return LimitSelectorsComparator(self.selectors)

    def add_trading_signal_listener(self, listener):
        if listener not in self.trading_signal_listeners:
            self.trading_signal_listeners.append(listener)

    def remove_trading_signal_listener(self, listener):
        if listener in self.trading_signal_listeners:
            self.trading_signal_listeners.remove(listener)

    def handle_targets_slot(self, timestamp):
        """
        this function would be called in every cycle, you could overwrite it for your custom algorithm to select the
        targets of different levels

        the default implementation is selecting the targets in all levels

        :param timestamp:
        :type timestamp:
        """
        long_selected = None
        short_selected = None
        for level in self.trading_level_desc:
            targets = self.targets_slot.get_targets(level=level)
            if targets:
                long_targets = set(targets[0])
                short_targets = set(targets[1])

                if not long_selected:
                    long_selected = long_targets
                else:
                    long_selected = long_selected & long_targets

                if not short_selected:
                    short_selected = short_targets
                else:
                    short_selected = short_selected & short_targets

        self.logger.debug('timestamp:{},long_selected:{}'.format(
            timestamp, long_selected))

        self.logger.debug('timestamp:{},short_selected:{}'.format(
            timestamp, short_selected))

        self.send_trading_signals(timestamp=timestamp,
                                  long_selected=long_selected,
                                  short_selected=short_selected)

    def send_trading_signals(self, timestamp, long_selected, short_selected):
        # current position
        account = self.account_service.latest_account
        current_holdings = [
            position['entity_id'] for position in account['positions']
            if position['available_long'] > 0
        ]

        if long_selected:
            # only long the securities not already in the positions
            longed = long_selected - set(current_holdings)
            if longed:
                position_pct = 1.0 / len(longed)
                order_money = account['cash'] * position_pct

                for entity_id in longed:
                    trading_signal = TradingSignal(
                        entity_id=entity_id,
                        the_timestamp=timestamp,
                        trading_signal_type=TradingSignalType.
                        trading_signal_open_long,
                        trading_level=self.level,
                        order_money=order_money)
                    for listener in self.trading_signal_listeners:
                        listener.on_trading_signal(trading_signal)

        # only short the securities that are in both current_holdings and short_selected
        if short_selected:
            shorted = set(current_holdings) & short_selected

            for entity_id in shorted:
                trading_signal = TradingSignal(
                    entity_id=entity_id,
                    the_timestamp=timestamp,
                    trading_signal_type=TradingSignalType.
                    trading_signal_close_long,
                    position_pct=1.0,
                    trading_level=self.level)
                for listener in self.trading_signal_listeners:
                    listener.on_trading_signal(trading_signal)

    def on_finish(self):
        # show the result
        if self.draw_result:
            import plotly.io as pio
            pio.renderers.default = "browser"
            reader = AccountReader(trader_names=[self.trader_name])
            drawer = Drawer(main_data=NormalData(reader.data_df.copy()[
                ['trader_name', 'timestamp', 'all_value']],
                                                 category_field='trader_name'))
            drawer.draw_line()

    def run(self):
        # iterate over timestamps of the min level, e.g. 9:30, 9:35, 9:40 ... for the 5min level
        # timestamp represents the timestamp of the kdata
        for timestamp in iterate_timestamps(
                entity_type=self.entity_schema,
                exchange=self.exchanges[0],
                start_timestamp=self.start_timestamp,
                end_timestamp=self.end_timestamp,
                level=self.level):

            if not is_trading_date(entity_type=self.entity_schema,
                                   exchange=self.exchanges[0],
                                   timestamp=timestamp):
                continue
            if self.real_time:
                # all selectors move on to handle the coming data
                if self.kdata_use_begin_time:
                    real_end_timestamp = timestamp + pd.Timedelta(
                        seconds=self.level.to_second())
                else:
                    real_end_timestamp = timestamp

                waiting_seconds, _ = self.level.count_from_timestamp(
                    real_end_timestamp,
                    one_day_trading_minutes=get_one_day_trading_minutes(
                        entity_type=self.entity_schema))
                # the coming kdata is not ready yet; let the selectors move on and wait for it
                if waiting_seconds and (waiting_seconds > 0):
                    # iterate the selectors from min to max level whose kdata timestamp is finished
                    for level in self.trading_level_asc:
                        if (is_in_finished_timestamps(
                                entity_type=self.entity_schema,
                                exchange=self.exchanges[0],
                                timestamp=timestamp,
                                level=level)):
                            for selector in self.selectors:
                                if selector.level == level:
                                    selector.move_on(timestamp,
                                                     self.kdata_use_begin_time,
                                                     timeout=waiting_seconds +
                                                     20)

            # on_trading_open to setup the account
            if self.level == IntervalLevel.LEVEL_1DAY or (
                    self.level != IntervalLevel.LEVEL_1DAY
                    and is_open_time(entity_type=self.entity_schema,
                                     exchange=self.exchanges[0],
                                     timestamp=timestamp)):
                self.account_service.on_trading_open(timestamp)

            # time always moves on by the min level step, so we can check all level targets in the slot
            self.handle_targets_slot(timestamp=timestamp)

            for level in self.trading_level_asc:
                # in every cycle, each level's selector does its job at its own time
                if (is_in_finished_timestamps(entity_type=self.entity_schema,
                                              exchange=self.exchanges[0],
                                              timestamp=timestamp,
                                              level=level)):
                    long_targets, short_targets = self.selectors_comparator.make_decision(
                        timestamp=timestamp, trading_level=level)

                    self.targets_slot.input_targets(level, long_targets,
                                                    short_targets)

            # on_trading_close to calculate the daily account
            if self.level == IntervalLevel.LEVEL_1DAY or (
                    self.level != IntervalLevel.LEVEL_1DAY
                    and is_close_time(entity_type=self.entity_schema,
                                      exchange=self.exchanges[0],
                                      timestamp=timestamp)):
                self.account_service.on_trading_close(timestamp)

        self.on_finish()
Example #17
    def __init__(self,
                 data_schema: Mixin,
                 entity_provider: str = None,
                 entity_ids: List[str] = None,
                 entity_type: str = 'stock',
                 exchanges: List[str] = ['sh', 'sz'],
                 codes: List[str] = None,
                 the_timestamp: Union[str, pd.Timestamp] = None,
                 start_timestamp: Union[str, pd.Timestamp] = None,
                 end_timestamp: Union[str, pd.Timestamp] = now_pd_timestamp(),
                 columns: List = None,
                 filters: List = None,
                 order: object = None,
                 limit: int = None,
                 provider: str = None,
                 level: IntervalLevel = IntervalLevel.LEVEL_1DAY,
                 category_field: str = 'entity_id',
                 time_field: str = 'timestamp',
                 computing_window: int = None) -> None:
        self.logger = logging.getLogger(self.__class__.__name__)

        self.data_schema = data_schema

        self.the_timestamp = the_timestamp
        if the_timestamp:
            self.start_timestamp = the_timestamp
            self.end_timestamp = the_timestamp
        else:
            self.start_timestamp = start_timestamp
            self.end_timestamp = end_timestamp

        self.start_timestamp = to_pd_timestamp(self.start_timestamp)
        self.end_timestamp = to_pd_timestamp(self.end_timestamp)

        self.entity_type = entity_type
        self.entity_provider = entity_provider
        self.exchanges = exchanges
        if codes:
            if type(codes) == str:
                codes = codes.replace(' ', '')
                if codes.startswith('[') and codes.endswith(']'):
                    codes = json.loads(codes)
                else:
                    codes = codes.split(',')

        self.codes = codes
        self.entity_ids = entity_ids

        # convert to standard entity_id
        if not self.entity_ids and self.entity_provider:
            self.entity_ids = get_entity_ids(provider=self.entity_provider,
                                             entity_type=self.entity_type,
                                             exchanges=self.exchanges,
                                             codes=self.codes)

        self.provider = provider
        self.filters = filters
        self.order = order
        self.limit = limit

        if level:
            self.level = IntervalLevel(level)
        else:
            self.level = level

        self.category_field = category_field
        self.time_field = time_field
        self.computing_window = computing_window

        self.category_col = eval('self.data_schema.{}'.format(
            self.category_field))
        self.time_col = eval('self.data_schema.{}'.format(self.time_field))

        self.columns = columns

        # we store the data in a DataFrame with a MultiIndex (category_column, timestamp)
        if self.columns:
            # support str
            if type(columns[0]) == str:
                self.columns = []
                for col in columns:
                    self.columns.append(eval('data_schema.{}'.format(col)))

            # always add category_column and time_field for normalizing
            self.columns = list(
                set(self.columns) | {self.category_col, self.time_col})

        self.data_listeners: List[DataListener] = []

        self.data_df: pd.DataFrame = None

        self.load_data()
Example #18
def next_timestamp(current_timestamp: pd.Timestamp, level: IntervalLevel) -> pd.Timestamp:
    current_timestamp = to_pd_timestamp(current_timestamp)
    return current_timestamp + pd.Timedelta(seconds=level.to_second())
Example #19
def is_finished_kdata_timestamp(timestamp, level: IntervalLevel):
    timestamp = to_pd_timestamp(timestamp)
    if level.floor_timestamp(timestamp) == timestamp:
        return True
    return False
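
For intraday levels, the floor-and-compare test above can be sketched with pandas alone; treating the level as a whole number of minutes is an assumption, and day/week levels would still need the original floor_timestamp.

# a standalone sketch of "timestamp is already on a finished bar boundary"
import pandas as pd

def finished_kdata_timestamp_sketch(timestamp, level_minutes: int) -> bool:
    ts = pd.Timestamp(timestamp)
    return ts.floor('{}min'.format(level_minutes)) == ts

# usage with a 5-minute level
assert finished_kdata_timestamp_sketch('2020-01-06 09:35:00', 5)
assert not finished_kdata_timestamp_sketch('2020-01-06 09:36:00', 5)
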