Beispiel #1
0
def get_ma_factor_schema(entity_type: str,
                         level: Union[IntervalLevel,
                                      str] = IntervalLevel.LEVEL_1DAY):
    """Look up the MA factor schema for an entity type and interval level.

    The schema name is built as
    ``{entity_type.capitalize()}{level.value.capitalize()}MaFactor``,
    e.g. ``"stock"`` with a level whose value is ``"1d"`` gives ``Stock1dMaFactor``.

    :param entity_type: entity type name, e.g. ``"stock"``.
    :param level: an ``IntervalLevel`` member or its raw string value.
    :return: the result of ``get_schema_by_name`` for the derived name.
    """
    # accept the raw value (e.g. "1d") as well as the enum member
    if isinstance(level, str):
        level = IntervalLevel(level)

    schema_str = '{}{}MaFactor'.format(entity_type.capitalize(),
                                       level.value.capitalize())

    return get_schema_by_name(schema_str)
Beispiel #2
0
def evaluate_size_from_timestamp(start_timestamp,
                                 level: IntervalLevel,
                                 one_day_trading_minutes,
                                 end_timestamp: pd.Timestamp = None):
    """
    Estimate how many kdata bars lie between ``start_timestamp`` and ``end_timestamp``.

    The estimate may be a little bigger than the real size, so fetching that many
    bars is enough to cover the whole range.

    :param start_timestamp: start of the range (anything ``to_pd_timestamp`` accepts)
    :type start_timestamp: pd.Timestamp
    :param level: bar interval
    :type level: IntervalLevel
    :param one_day_trading_minutes: trading minutes in one day
    :type one_day_trading_minutes: int
    :param end_timestamp: end of the range; ``pd.Timestamp.now()`` when not given
    :type end_timestamp: pd.Timestamp
    :return: estimated number of bars
    :rtype: int
    """
    if not end_timestamp:
        end_timestamp = pd.Timestamp.now()
    else:
        end_timestamp = to_pd_timestamp(end_timestamp)

    time_delta = end_timestamp - to_pd_timestamp(start_timestamp)

    one_day_trading_seconds = one_day_trading_minutes * 60

    # day/week/month levels: count calendar periods (months approximated as 30 days),
    # plus one for the current, possibly partial, period
    if level == IntervalLevel.LEVEL_1DAY:
        return time_delta.days + 1

    if level == IntervalLevel.LEVEL_1WEEK:
        return int(math.ceil(time_delta.days / 7)) + 1

    if level == IntervalLevel.LEVEL_1MON:
        return int(math.ceil(time_delta.days / 30)) + 1

    level_seconds = level.to_second()

    if time_delta.days > 0:
        # spans at least one day: treat every day, including the current one,
        # as a full trading day
        seconds = (time_delta.days + 1) * one_day_trading_seconds
        return int(math.ceil(seconds / level_seconds)) + 1

    # within one day: size from the elapsed time, capped at one trading day's bars.
    # The cap is ceiled to an int so the size is always an integer, never a float.
    seconds = time_delta.total_seconds()
    return min(int(math.ceil(seconds / level_seconds)) + 1,
               int(math.ceil(one_day_trading_seconds / level_seconds)) + 1)
Beispiel #3
0
    def __init__(
        self,
        force_update=True,
        sleeping_time=10,
        exchanges=None,
        entity_id=None,
        entity_ids=None,
        code=None,
        codes=None,
        day_data=False,
        entity_filters=None,
        ignore_failed=True,
        real_time=False,
        fix_duplicate_way="ignore",
        start_timestamp=None,
        end_timestamp=None,
        level=IntervalLevel.LEVEL_1DAY,
        kdata_use_begin_time=False,
        one_day_trading_minutes=24 * 60,
        adjust_type=AdjustType.qfq,
    ) -> None:
        """Prepare a stock kdata recorder for the jq provider.

        ``level`` and ``adjust_type`` may be enum members or their raw values; both
        are normalized first and used to pick the target kdata schema. Every other
        argument is passed positionally to the base recorder.
        """
        interval = IntervalLevel(level)
        adjust = AdjustType(adjust_type)

        # target table for this level/adjustment, and the matching jq trading level
        self.data_schema = get_kdata_schema(entity_type="stock", level=interval, adjust_type=adjust)
        self.jq_trading_level = to_jq_trading_level(interval)

        base_args = (
            force_update, sleeping_time, exchanges,
            entity_id, entity_ids, code, codes,
            day_data, entity_filters, ignore_failed, real_time,
            fix_duplicate_way, start_timestamp, end_timestamp,
            interval, kdata_use_begin_time, one_day_trading_minutes,
        )
        super().__init__(*base_args)

        self.adjust_type = adjust

        # authenticate with the configured jq credentials
        jq_user = zvt_config["jq_username"]
        jq_password = zvt_config["jq_password"]
        get_token(jq_user, jq_password, force=True)
Beispiel #4
0
def get_ma_state_stats_schema(entity_type: str,
                              level: Union[IntervalLevel,
                                           str] = IntervalLevel.LEVEL_1DAY):
    """Look up the MA state stats schema class for an entity type and level.

    Schema name rule:
    ``{entity_type.capitalize()}{level.value.capitalize()}MaStateStats``.

    :param entity_type: entity type name, e.g. ``"stock"``.
    :param level: an ``IntervalLevel`` member or its raw string value.
    :return: the class with that name from this module's global namespace.
    :raises NameError: if no such name is defined in this module.
    """
    # accept the raw value (e.g. "1d") as well as the enum member
    if isinstance(level, str):
        level = IntervalLevel(level)

    schema_str = '{}{}MaStateStats'.format(entity_type.capitalize(),
                                           level.value.capitalize())

    # Look the name up directly instead of eval()-ing a string that embeds the
    # caller-supplied entity_type, which would execute arbitrary expressions.
    # A missing name still raises NameError, as eval did.
    schema = globals().get(schema_str)
    if schema is None:
        raise NameError("name '{}' is not defined".format(schema_str))
    return schema
Beispiel #5
0
def get_z_factor_schema(entity_type: str,
                        level: Union[IntervalLevel,
                                     str] = IntervalLevel.LEVEL_1DAY):
    """Look up the Z factor schema for an entity type and interval level.

    Schema name rule:
    ``{entity_type.capitalize()}{level.value.capitalize()}ZFactor``.

    :param entity_type: entity type name, e.g. ``"stock"``.
    :param level: an ``IntervalLevel`` member or its raw string value.
    :return: the result of ``get_schema_by_name`` for the derived name.
    """
    # accept the raw value (e.g. "1d") as well as the enum member
    if isinstance(level, str):
        level = IntervalLevel(level)

    schema_str = "{}{}ZFactor".format(entity_type.capitalize(),
                                      level.value.capitalize())

    return get_schema_by_name(schema_str)
Beispiel #6
0
    def __init__(
        self,
        force_update=True,
        sleeping_time=10,
        exchanges=None,
        entity_id=None,
        entity_ids=None,
        code=None,
        codes=None,
        day_data=False,
        entity_filters=None,
        ignore_failed=True,
        real_time=False,
        fix_duplicate_way="ignore",
        start_timestamp=None,
        end_timestamp=None,
        level=IntervalLevel.LEVEL_1DAY,
        kdata_use_begin_time=False,
        one_day_trading_minutes=24 * 60,
        adjust_type=AdjustType.qfq,
    ) -> None:
        """Prepare a kdata recorder whose entity type comes from ``entity_schema``.

        ``level`` and ``adjust_type`` may be enum members or their raw values. The
        entity type is the lower-cased class name of ``self.entity_schema``, and it
        selects the kdata schema together with the level and adjust type. All
        other arguments are passed positionally to the base recorder.
        """
        interval = IntervalLevel(level)
        self.adjust_type = AdjustType(adjust_type)
        # e.g. an entity schema class named "Stock" gives "stock"
        self.entity_type = self.entity_schema.__name__.lower()

        self.data_schema = get_kdata_schema(
            entity_type=self.entity_type, level=interval, adjust_type=self.adjust_type
        )

        base_args = (
            force_update, sleeping_time, exchanges,
            entity_id, entity_ids, code, codes,
            day_data, entity_filters, ignore_failed, real_time,
            fix_duplicate_way, start_timestamp, end_timestamp,
            interval, kdata_use_begin_time, one_day_trading_minutes,
        )
        super().__init__(*base_args)
Beispiel #7
0
def get_kdata_schema(entity_type: str,
                     level: Union[IntervalLevel, str] = IntervalLevel.LEVEL_1DAY,
                     adjust_type: Union[AdjustType, str] = None):
    """Look up the kdata schema for an entity type, level and adjust type.

    Schema name rule:
    ``{entity_type.capitalize()}{level.value.capitalize()}[{adjust_type.value.capitalize()}]Kdata``.
    The adjust-type part is omitted when ``adjust_type`` is falsy or ``AdjustType.qfq``.

    :param entity_type: entity type name such as ``"stock"``, or an enum member
        whose ``.value`` is that name.
    :param level: an ``IntervalLevel`` member or its raw string value.
    :param adjust_type: an ``AdjustType`` member, its raw string value, or None.
    :return: the result of ``get_schema_by_name`` for the derived name.
    """
    # accept raw values as well as enum members
    if isinstance(level, str):
        level = IntervalLevel(level)
    if isinstance(adjust_type, str):
        adjust_type = AdjustType(adjust_type)
    # callers may pass an entity-type enum member; use its value for the name
    entity_type = getattr(entity_type, 'value', entity_type)

    if adjust_type and (adjust_type != AdjustType.qfq):
        schema_str = '{}{}{}Kdata'.format(entity_type.capitalize(), level.value.capitalize(),
                                          adjust_type.value.capitalize())
    else:
        schema_str = '{}{}Kdata'.format(entity_type.capitalize(), level.value.capitalize())
    return get_schema_by_name(schema_str)
    def __init__(self,
                 exchanges=('sh', 'sz'),
                 entity_ids=None,
                 codes=None,
                 batch_size=10,
                 force_update=True,
                 sleeping_time=0,
                 default_size=2000,
                 real_time=False,
                 fix_duplicate_way='ignore',
                 start_timestamp=None,
                 end_timestamp=None,
                 level=IntervalLevel.LEVEL_1WEEK,
                 kdata_use_begin_time=False,
                 close_hour=15,
                 close_minute=0,
                 one_day_trading_minutes=4 * 60,
                 adjust_type=AdjustType.qfq,
                 share_para=None) -> None:
        """Prepare a stock kdata recorder for the jq provider.

        :param exchanges: exchanges to record, ``('sh', 'sz')`` by default. A list or
            tuple is copied into a new list before it is handed to the base class.
        :param level: interval level, enum member or raw value.
        :param adjust_type: adjust type, enum member or raw value.
        :param share_para: forwarded to the base recorder as a keyword argument.

        The remaining arguments are forwarded positionally to the base recorder.
        """
        # Copy into a fresh list so the default and caller-owned sequences are never
        # shared with (and possibly mutated by) the recorder instance.
        if isinstance(exchanges, (list, tuple)):
            exchanges = list(exchanges)
        level = IntervalLevel(level)
        adjust_type = AdjustType(adjust_type)
        self.data_schema = get_kdata_schema(entity_type=EntityType.Stock,
                                            level=level,
                                            adjust_type=adjust_type)
        self.jq_trading_level = to_jq_trading_level(level)

        super().__init__(EntityType.Stock,
                         exchanges,
                         entity_ids,
                         codes,
                         batch_size,
                         force_update,
                         sleeping_time,
                         default_size,
                         real_time,
                         fix_duplicate_way,
                         start_timestamp,
                         end_timestamp,
                         close_hour,
                         close_minute,
                         level,
                         kdata_use_begin_time,
                         one_day_trading_minutes,
                         share_para=share_para)
        self.adjust_type = adjust_type
        jq_auth()
Beispiel #9
0
def to_em_level_flag(level: IntervalLevel):
    """Translate an interval level into the numeric level flag used for the em provider.

    :param level: an ``IntervalLevel`` member or its raw value.
    :return: 5, 15, 30 or 60 for the 5min/15min/30min/1hour levels, and
        101, 102 or 103 for the 1day/1week/1mon levels.
    :raises ValueError: for any level without a flag. (``assert False`` would be
        stripped under ``python -O`` and the function would silently return None.)
    """
    level = IntervalLevel(level)
    if level == IntervalLevel.LEVEL_5MIN:
        return 5
    elif level == IntervalLevel.LEVEL_15MIN:
        return 15
    elif level == IntervalLevel.LEVEL_30MIN:
        return 30
    elif level == IntervalLevel.LEVEL_1HOUR:
        return 60
    elif level == IntervalLevel.LEVEL_1DAY:
        return 101
    elif level == IntervalLevel.LEVEL_1WEEK:
        return 102
    elif level == IntervalLevel.LEVEL_1MON:
        return 103

    raise ValueError('unsupported level for em level flag: {}'.format(level))
Beispiel #10
0
    def __init__(
        self,
        force_update=True,
        sleeping_time=10,
        exchanges=None,
        entity_id=None,
        entity_ids=None,
        code=None,
        codes=None,
        day_data=False,
        entity_filters=None,
        ignore_failed=True,
        real_time=False,
        fix_duplicate_way="ignore",
        start_timestamp=None,
        end_timestamp=None,
        level=IntervalLevel.LEVEL_1DAY,
        kdata_use_begin_time=False,
        one_day_trading_minutes=24 * 60,
    ) -> None:
        """Initialize the base recorder, then store the kdata-specific settings.

        ``level`` may be an ``IntervalLevel`` member or its raw value; it is
        normalized and stored as ``self.level`` after the base class is set up.
        """
        base_kwargs = dict(
            code=code,
            codes=codes,
            day_data=day_data,
            entity_filters=entity_filters,
            ignore_failed=ignore_failed,
            real_time=real_time,
            fix_duplicate_way=fix_duplicate_way,
            start_timestamp=start_timestamp,
            end_timestamp=end_timestamp,
        )
        super().__init__(force_update, sleeping_time, exchanges, entity_id, entity_ids, **base_kwargs)

        # kdata-only state; assigned after the base initializer has run
        self.level = IntervalLevel(level)
        self.kdata_use_begin_time = kdata_use_begin_time
        self.one_day_trading_minutes = one_day_trading_minutes
Beispiel #11
0
    def __init__(self,
                 exchanges=('sh', 'sz'),
                 entity_ids=None,
                 codes=None,
                 batch_size=10,
                 force_update=True,
                 sleeping_time=0,
                 default_size=2000,
                 real_time=False,
                 fix_duplicate_way='ignore',
                 start_timestamp=None,
                 end_timestamp=None,
                 level=IntervalLevel.LEVEL_1WEEK,
                 kdata_use_begin_time=False,
                 close_hour=15,
                 close_minute=0,
                 one_day_trading_minutes=4 * 60,
                 adjust_type=AdjustType.qfq) -> None:
        """Prepare a stock kdata recorder backed by baostock and log in anonymously.

        :param exchanges: exchanges to record, ``('sh', 'sz')`` by default. A list or
            tuple is copied into a new list before it is handed to the base class.
        :param level: interval level, enum member or raw value.
        :param adjust_type: adjust type, enum member or raw value.

        The remaining arguments are forwarded positionally to the base recorder.
        A failed login is only reported on stdout; construction continues.
        """
        # Copy into a fresh list so the default and caller-owned sequences are never
        # shared with (and possibly mutated by) the recorder instance.
        if isinstance(exchanges, (list, tuple)):
            exchanges = list(exchanges)
        level = IntervalLevel(level)
        adjust_type = AdjustType(adjust_type)
        self.data_schema = get_kdata_schema(entity_type='stock',
                                            level=level,
                                            adjust_type=adjust_type)
        self.bs_trading_level = to_bs_trading_level(level)

        super().__init__('stock', exchanges, entity_ids, codes, batch_size,
                         force_update, sleeping_time, default_size, real_time,
                         fix_duplicate_way, start_timestamp, end_timestamp,
                         close_hour, close_minute, level, kdata_use_begin_time,
                         one_day_trading_minutes)
        self.adjust_type = adjust_type

        print("尝试登陆baostock")
        # login
        lg = bs.login(user_id="anonymous", password="******")
        if lg.error_code == '0':
            print("登陆成功")
        else:
            print("登录失败")
Beispiel #12
0
    def __init__(self,
                 exchanges=('hk',),
                 entity_ids=None,
                 codes=None,
                 batch_size=10,
                 force_update=True,
                 sleeping_time=0,
                 default_size=2000,
                 real_time=False,
                 fix_duplicate_way='ignore',
                 start_timestamp=None,
                 end_timestamp=None,
                 level=IntervalLevel.LEVEL_1WEEK,
                 kdata_use_begin_time=False,
                 close_hour=15,
                 close_minute=0,
                 one_day_trading_minutes=4 * 60,
                 adjust_type=AdjustType.qfq) -> None:
        """Prepare a stock kdata recorder (hk exchange by default) and log in via ``c.start``.

        :param exchanges: exchanges to record, ``('hk',)`` by default. A list or tuple
            is copied into a new list before it is handed to the base class.
        :param level: interval level, enum member or raw value.
        :param adjust_type: adjust type, enum member or raw value.

        The remaining arguments are forwarded positionally to the base recorder.
        If the login returns a non-zero ``ErrorCode``, the process exits
        (raises ``SystemExit``).
        """
        # Copy into a fresh list so the default and caller-owned sequences are never
        # shared with (and possibly mutated by) the recorder instance.
        if isinstance(exchanges, (list, tuple)):
            exchanges = list(exchanges)
        level = IntervalLevel(level)
        adjust_type = AdjustType(adjust_type)
        self.data_schema = get_kdata_schema(entity_type='stock',
                                            level=level,
                                            adjust_type=adjust_type)
        self.jq_trading_level = to_jq_trading_level(level)

        super().__init__('stock', exchanges, entity_ids, codes, batch_size,
                         force_update, sleeping_time, default_size, real_time,
                         fix_duplicate_way, start_timestamp, end_timestamp,
                         close_hour, close_minute, level, kdata_use_begin_time,
                         one_day_trading_minutes)
        self.adjust_type = adjust_type

        # Call the login function (used after activation; no username/password needed)
        login_result = c.start("ForceLogin=1", '')
        if login_result.ErrorCode != 0:
            print("login in fail")
            # sys.exit rather than the site-provided exit(), which is meant for the
            # interactive shell and is missing when Python runs with -S
            import sys
            sys.exit()
Beispiel #13
0
    def __init__(self,
                 exchanges=('sh', 'sz'),
                 entity_ids=None,
                 codes=None,
                 day_data=False,
                 batch_size=10,
                 force_update=True,
                 sleeping_time=0,
                 default_size=2000,
                 real_time=False,
                 fix_duplicate_way='ignore',
                 start_timestamp=None,
                 end_timestamp=None,
                 level=IntervalLevel.LEVEL_1WEEK,
                 kdata_use_begin_time=False,
                 close_hour=15,
                 close_minute=0,
                 one_day_trading_minutes=4 * 60,
                 adjust_type=AdjustType.qfq) -> None:
        """Prepare a stock kdata recorder for the jq provider.

        :param exchanges: exchanges to record, ``('sh', 'sz')`` by default. A list or
            tuple is copied into a new list before it is handed to the base class.
        :param level: interval level, enum member or raw value.
        :param adjust_type: adjust type, enum member or raw value.

        The remaining arguments are forwarded positionally to the base recorder,
        after which a token is requested with the configured jq credentials.
        """
        # Copy into a fresh list so the default and caller-owned sequences are never
        # shared with (and possibly mutated by) the recorder instance.
        if isinstance(exchanges, (list, tuple)):
            exchanges = list(exchanges)
        level = IntervalLevel(level)
        adjust_type = AdjustType(adjust_type)
        self.data_schema = get_kdata_schema(entity_type='stock',
                                            level=level,
                                            adjust_type=adjust_type)
        self.jq_trading_level = to_jq_trading_level(level)

        super().__init__('stock', exchanges, entity_ids, codes, day_data,
                         batch_size, force_update, sleeping_time, default_size,
                         real_time, fix_duplicate_way, start_timestamp,
                         end_timestamp, close_hour, close_minute, level,
                         kdata_use_begin_time, one_day_trading_minutes)
        self.adjust_type = adjust_type

        get_token(zvt_config['jq_username'],
                  zvt_config['jq_password'],
                  force=True)
Beispiel #14
0
def gen_kdata_schema(
    pkg: str,
    providers: List[str],
    entity_type: str,
    levels: List[IntervalLevel],
    adjust_types=None,
    entity_in_submodule: bool = False,
    kdata_module="quotes",
):
    """Generate one kdata schema module per (level, adjust_type) combination.

    Files go to ``./domain[/<kdata_module>][/<entity_type>]``. The directory is
    created if missing, an ``__init__.py`` is written there if absent, and
    ``gen_exports("./domain")`` runs at the end.

    :param pkg: package the generated modules import ``{EntityType}KdataCommon`` from.
    :param providers: provider names written into each ``register_schema`` call.
    :param entity_type: entity type name, e.g. ``"stock"``.
    :param levels: interval levels, enum members or raw values.
    :param adjust_types: adjust types (members, raw values or None). ``None`` means a
        single unsuffixed schema; ``qfq`` also uses the unsuffixed name.
    :param entity_in_submodule: put files in a sub directory named after ``entity_type``.
    :param kdata_module: sub directory/module under ``domain``; falsy to write into
        ``domain`` directly.
    """
    if adjust_types is None:
        adjust_types = [None]

    base_path = "./domain"

    if kdata_module:
        base_path = os.path.join(base_path, kdata_module)
    if entity_in_submodule:
        base_path = os.path.join(base_path, entity_type)

    if not os.path.exists(base_path):
        logger.info(f"create dir {base_path}")
        os.makedirs(base_path)

    cap_entity_type = entity_type.capitalize()
    # you should define {EntityType}KdataCommon in kdata_module at first
    kdata_common = f"{cap_entity_type}KdataCommon"

    for level in levels:
        # the level only depends on the outer loop; normalize it once per level
        level = IntervalLevel(level)
        cap_level = level.value.capitalize()

        for adjust_type in adjust_types:
            # accept raw values such as "hfq" as well as AdjustType members
            if adjust_type:
                adjust_type = AdjustType(adjust_type)

            if adjust_type and (adjust_type != AdjustType.qfq):
                class_name = f"{cap_entity_type}{cap_level}{adjust_type.value.capitalize()}Kdata"
                table_name = f"{entity_type}_{level.value}_{adjust_type.value.lower()}_kdata"

            else:
                class_name = f"{cap_entity_type}{cap_level}Kdata"
                table_name = f"{entity_type}_{level.value}_kdata"

            schema_template = f"""# -*- coding: utf-8 -*-
# this file is generated by gen_kdata_schema function, dont't change it
from sqlalchemy.orm import declarative_base

from zvt.contract.register import register_schema
from {pkg}.domain.{kdata_module} import {kdata_common}

KdataBase = declarative_base()


class {class_name}(KdataBase, {kdata_common}):
    __tablename__ = '{table_name}'


register_schema(providers={providers}, db_name='{table_name}', schema_base=KdataBase, entity_type='{entity_type}')

"""
            # generate the schema
            with open(os.path.join(base_path, f"{table_name}.py"), "w", encoding="utf-8") as outfile:
                outfile.write(schema_template)

    # generate the package marker once; it does not depend on the level
    pkg_file = os.path.join(base_path, "__init__.py")
    if not os.path.exists(pkg_file):
        package_template = """# -*- coding: utf-8 -*-
"""
        with open(pkg_file, "w", encoding="utf-8") as outfile:
            outfile.write(package_template)

    # generate exports
    gen_exports("./domain")
Beispiel #15
0
def to_high_level_kdata(kdata_df: pd.DataFrame, to_level: IntervalLevel):
    """Aggregate one entity's daily-or-finer kdata into weekly bars.

    :param kdata_df: kdata of a single entity. The ``resample`` call below requires
        a datetime-like index. The frame must contain the ``level``, ``entity_id``,
        ``provider``, ``name`` and ``code`` columns, plus the ``close``/``open``/
        ``high``/``low``/``volume``/``turnover`` columns that are aggregated.
    :param to_level: target level (enum member or raw value). It must be higher than
        the source level; only ``IntervalLevel.LEVEL_1WEEK`` is implemented.
    :return: one row per week, labelled with that week's Friday. Columns are the
        aggregated close/open/high/low/volume/turnover plus ``entity_id``,
        ``provider``, ``code`` and ``name``. Rows with missing values are dropped.
    :raises NotImplementedError: if ``to_level`` is not ``IntervalLevel.LEVEL_1WEEK``.
    """
    def to_close(s):
        # last value of the period, by position
        if pd_is_not_null(s):
            return s.iloc[-1]

    def to_open(s):
        # first value of the period, by position
        if pd_is_not_null(s):
            return s.iloc[0]

    def to_high(s):
        return np.max(s)

    def to_low(s):
        return np.min(s)

    def to_sum(s):
        return np.sum(s)

    # Read metadata from the first row by position: on a datetime index, integer
    # keys only worked through a deprecated positional fallback.
    original_level = kdata_df['level'].iloc[0]
    entity_id = kdata_df['entity_id'].iloc[0]
    provider = kdata_df['provider'].iloc[0]
    name = kdata_df['name'].iloc[0]
    code = kdata_df['code'].iloc[0]

    to_level = IntervalLevel(to_level)

    assert IntervalLevel(original_level) <= IntervalLevel.LEVEL_1DAY
    assert IntervalLevel(original_level) < to_level

    if to_level != IntervalLevel.LEVEL_1WEEK:
        raise NotImplementedError('to_high_level_kdata only supports {}, got {}'.format(
            IntervalLevel.LEVEL_1WEEK, to_level))

    # The same aggregation applies to every entity type. 'W' bins end on Sunday;
    # shifting the labels back two days uses each week's Friday as its label.
    # This replaces resample(loffset=...), which pandas 2.0 removed.
    df = kdata_df.resample('W').apply({
        'close': to_close,
        'open': to_open,
        'high': to_high,
        'low': to_low,
        'volume': to_sum,
        'turnover': to_sum,
    })
    df.index = df.index + pd.DateOffset(days=-2)

    df = df.dropna()
    # id        entity_id  timestamp   provider    code  name level
    df['entity_id'] = entity_id
    df['provider'] = provider
    df['code'] = code
    df['name'] = name

    return df
Beispiel #16
0
def to_high_level_kdata(kdata_df: pd.DataFrame, to_level: IntervalLevel):
    """Aggregate one entity's daily-or-finer kdata into weekly bars.

    :param kdata_df: kdata of a single entity. The ``resample`` call below requires
        a datetime-like index. The frame must contain the ``level``, ``entity_id``,
        ``provider``, ``name`` and ``code`` columns, plus the ``close``/``open``/
        ``high``/``low``/``volume``/``turnover`` columns that are aggregated.
    :param to_level: target level (enum member or raw value). It must be higher than
        the source level; only ``IntervalLevel.LEVEL_1WEEK`` is implemented.
    :return: one row per week, labelled with that week's Friday. Columns are the
        aggregated close/open/high/low/volume/turnover plus ``entity_id``,
        ``provider``, ``code`` and ``name``. Rows with missing values are dropped.
    :raises NotImplementedError: if ``to_level`` is not ``IntervalLevel.LEVEL_1WEEK``.
    """
    def to_close(s):
        # last value of the period, by position
        if pd_is_not_null(s):
            return s.iloc[-1]

    def to_open(s):
        # first value of the period, by position
        if pd_is_not_null(s):
            return s.iloc[0]

    def to_high(s):
        return np.max(s)

    def to_low(s):
        return np.min(s)

    def to_sum(s):
        return np.sum(s)

    # Read metadata from the first row by position: on a datetime index, integer
    # keys only worked through a deprecated positional fallback.
    original_level = kdata_df["level"].iloc[0]
    entity_id = kdata_df["entity_id"].iloc[0]
    provider = kdata_df["provider"].iloc[0]
    name = kdata_df["name"].iloc[0]
    code = kdata_df["code"].iloc[0]

    to_level = IntervalLevel(to_level)

    assert IntervalLevel(original_level) <= IntervalLevel.LEVEL_1DAY
    assert IntervalLevel(original_level) < to_level

    if to_level != IntervalLevel.LEVEL_1WEEK:
        raise NotImplementedError("to_high_level_kdata only supports {}, got {}".format(
            IntervalLevel.LEVEL_1WEEK, to_level))

    # The same aggregation applies to every entity type. "W" bins end on Sunday;
    # shifting the labels back two days uses each week's Friday as its label.
    # This replaces resample(loffset=...), which pandas 2.0 removed.
    df = kdata_df.resample("W").apply({
        "close": to_close,
        "open": to_open,
        "high": to_high,
        "low": to_low,
        "volume": to_sum,
        "turnover": to_sum,
    })
    df.index = df.index + pd.DateOffset(days=-2)

    df = df.dropna()
    # id        entity_id  timestamp   provider    code  name level
    df["entity_id"] = entity_id
    df["provider"] = provider
    df["code"] = code
    df["name"] = name

    return df
Beispiel #17
0
    def __init__(self,
                 region: Region,
                 data_schema: Type[Mixin],
                 entity_schema: Type[EntityMixin],
                 provider: Provider = Provider.Default,
                 entity_ids: List[str] = None,
                 exchanges: List[str] = None,
                 codes: List[str] = None,
                 the_timestamp: Union[str, pd.Timestamp] = None,
                 start_timestamp: Union[str, pd.Timestamp] = None,
                 end_timestamp: Union[str, pd.Timestamp] = None,
                 columns: List = None,
                 filters: List = None,
                 order: object = None,
                 limit: int = None,
                 level: IntervalLevel = None,
                 category_field: str = 'entity_id',
                 time_field: str = 'timestamp',
                 computing_window: int = None) -> None:
        """Configure a reader over ``data_schema`` and load the initial data.

        :param region: region used for "now" and for the entity lookup.
        :param data_schema: schema whose rows are read.
        :param entity_schema: entity schema used to resolve ``entity_ids`` from
            ``exchanges``/``codes`` when no ids are given.
        :param the_timestamp: when set, used as both start and end timestamp.
        :param end_timestamp: defaults to ``now_pd_timestamp(region)``.
        :param codes: a list of codes, or a string that is either a JSON list
            (``"[...]"``) or comma separated; spaces are removed.
        :param columns: schema attributes or their names. The category and time
            columns are always added.
        :param level: interval level (enum member or raw value), or None.
        :param category_field: name of the schema attribute used as the category column.
        :param time_field: name of the schema attribute used as the time column.
        """
        super().__init__()
        self.logger = logging.getLogger(self.__class__.__name__)

        self.data_schema = data_schema
        self.entity_schema = entity_schema

        self.region = region
        self.provider = provider

        if end_timestamp is None:
            end_timestamp = now_pd_timestamp(self.region)

        self.the_timestamp = the_timestamp
        if the_timestamp:
            self.start_timestamp = the_timestamp
            self.end_timestamp = the_timestamp
        else:
            self.start_timestamp = start_timestamp
            self.end_timestamp = end_timestamp

        self.start_timestamp = to_pd_timestamp(self.start_timestamp)
        self.end_timestamp = to_pd_timestamp(self.end_timestamp)

        self.exchanges = exchanges

        # codes may come in as a string, e.g. from a command line
        if codes and isinstance(codes, str):
            codes = codes.replace(' ', '')
            if codes.startswith('[') and codes.endswith(']'):
                codes = json.loads(codes)
            else:
                codes = codes.split(',')

        self.codes = codes
        self.entity_ids = entity_ids

        # convert to standard entity_ids
        if entity_schema and not self.entity_ids:
            df = get_entities(region=self.region,
                              entity_schema=entity_schema,
                              provider=self.provider,
                              exchanges=self.exchanges,
                              codes=self.codes)
            if pd_is_not_null(df):
                self.entity_ids = df['entity_id'].to_list()

        self.filters = filters
        self.order = order
        self.limit = limit

        if level:
            self.level = IntervalLevel(level)
        else:
            self.level = level

        self.category_field = category_field
        self.time_field = time_field
        self.computing_window = computing_window

        # Resolve the columns by attribute name. getattr does the same lookup as the
        # former eval("self.data_schema.<field>") without evaluating arbitrary code.
        self.category_col = getattr(self.data_schema, self.category_field)
        self.time_col = getattr(self.data_schema, self.time_field)

        self.columns = columns

        # we store the data in a multiple index(category_column,timestamp) Dataframe
        if self.columns:
            # support str: each name is resolved individually, so mixed lists work too
            self.columns = [getattr(data_schema, col) if isinstance(col, str) else col
                            for col in columns]

            # always add category_column and time_field for normalizing
            self.columns = list(
                set(self.columns) | {self.category_col, self.time_col})

        self.data_listeners: List[DataListener] = []

        self.data_df: pd.DataFrame = None

        self.load_data()
Beispiel #18
0
    def __init__(
        self,
        data_schema: Type[Mixin],
        entity_schema: Type[TradableEntity],
        provider: str = None,
        entity_provider: str = None,
        entity_ids: List[str] = None,
        exchanges: List[str] = None,
        codes: List[str] = None,
        start_timestamp: Union[str, pd.Timestamp] = None,
        end_timestamp: Union[str, pd.Timestamp] = None,
        columns: List = None,
        filters: List = None,
        order: object = None,
        limit: int = None,
        level: IntervalLevel = None,
        category_field: str = "entity_id",
        time_field: str = "timestamp",
        computing_window: int = None,
    ) -> None:
        """Configure a reader over ``data_schema`` and load the initial data.

        :param data_schema: schema whose rows are read.
        :param entity_schema: entity schema used to resolve ``entity_ids`` from
            ``exchanges``/``codes`` (via ``entity_provider``) when no ids are given.
        :param end_timestamp: defaults to ``now_pd_timestamp()`` at call time.
        :param columns: extra columns; ``category_field`` and ``time_field`` are always added.
        :param level: interval level (enum member or raw value), or None.
        :param category_field: name of the schema attribute used as the category column.
        :param time_field: name of the schema attribute used as the time column.
        """
        self.logger = logging.getLogger(self.__class__.__name__)

        # Resolve "now" per call. A now_pd_timestamp() default is evaluated only once,
        # at import time, which leaves long-running processes with a stale end.
        if end_timestamp is None:
            end_timestamp = now_pd_timestamp()

        self.data_schema = data_schema
        self.entity_schema = entity_schema
        self.provider = provider
        self.entity_provider = entity_provider
        self.start_timestamp = to_pd_timestamp(start_timestamp)
        self.end_timestamp = to_pd_timestamp(end_timestamp)
        self.exchanges = exchanges
        self.codes = codes
        self.entity_ids = entity_ids

        # convert to standard entity_ids
        if not self.entity_ids:
            df = get_entities(entity_schema=entity_schema,
                              provider=self.entity_provider,
                              exchanges=self.exchanges,
                              codes=self.codes)
            if pd_is_not_null(df):
                self.entity_ids = df["entity_id"].to_list()

        self.filters = filters
        self.order = order
        self.limit = limit

        if level:
            self.level = IntervalLevel(level)
        else:
            self.level = level

        self.category_field = category_field
        self.time_field = time_field
        self.computing_window = computing_window

        # Resolve the columns by attribute name. getattr does the same lookup as the
        # former eval("self.data_schema.<field>") without evaluating arbitrary code.
        self.category_col = getattr(self.data_schema, self.category_field)
        self.time_col = getattr(self.data_schema, self.time_field)

        self.columns = columns

        if self.columns:
            # always add category_column and time_field for normalizing
            self.columns = list(
                set(self.columns) | {self.category_field, self.time_field})

        self.data_listeners: List[DataListener] = []

        self.data_df: pd.DataFrame = None

        self.load_data()


__all__ = ['BaoChinaStockKdataRecorder']

if __name__ == '__main__':
    # CLI: record kdata (hfq) for the given codes and level, then print the
    # 10 most recent rows of Stock1dHfqKdata for stock_sz_000001.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--level',
                            help='trading level',
                            default='1d',
                            choices=[lv.value for lv in IntervalLevel])
    arg_parser.add_argument('--codes', help='codes', default=['000001'], nargs='+')
    args = arg_parser.parse_args()

    level = IntervalLevel(args.level)
    codes = args.codes

    init_log(f'bao_china_stock_{args.level}_kdata.log')

    recorder = BaoChinaStockKdataRecorder(level=level,
                                          sleeping_time=0,
                                          codes=codes,
                                          real_time=False,
                                          adjust_type=AdjustType.hfq)
    recorder.run()

    latest_kdata = get_kdata(region=Region.CHN,
                             entity_id='stock_sz_000001',
                             limit=10,
                             order=Stock1dHfqKdata.timestamp.desc(),
                             adjust_type=AdjustType.hfq)
    print(latest_kdata)
Beispiel #20
0
class Trader(object):
    """
    Base trader: runs selectors over the interval timestamps of ``level`` and turns their targets into
    trading signals that are sent to the registered listeners (a SimAccountService is always registered).

    Subclasses must set ``entity_schema`` and usually override ``init_selectors``.
    """
    entity_schema: EntityMixin = None

    def __init__(self,
                 entity_ids: List[str] = None,
                 exchanges: List[str] = None,
                 codes: List[str] = None,
                 start_timestamp: Union[str, pd.Timestamp] = None,
                 end_timestamp: Union[str, pd.Timestamp] = None,
                 provider: str = None,
                 level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
                 trader_name: str = None,
                 real_time: bool = False,
                 kdata_use_begin_time: bool = False,
                 draw_result: bool = True) -> None:
        """
        :param entity_ids: entity ids passed to init_selectors and recorded in TraderInfo
        :param exchanges: exchanges passed to init_selectors and recorded in TraderInfo
        :param codes: codes passed to init_selectors and recorded in TraderInfo
        :param start_timestamp: first timestamp of the run (required)
        :param end_timestamp: last timestamp of the run (required; must not be in the past in real_time mode)
        :param provider: data provider, forwarded to SimAccountService and recorded in TraderInfo
        :param level: trading level; must equal the smallest level among the selectors
        :param trader_name: defaults to the lower-cased class name
        :param real_time: in run(), move selectors on to upcoming kdata before handling each timestamp
        :param kdata_use_begin_time: whether a kdata timestamp marks the beginning of its interval
        :param draw_result: plot the account value in on_finish
        :raises ValueError: if start_timestamp/end_timestamp is missing, or end_timestamp is in the past in
            real_time mode
        :raises Exception: if level is not the smallest selector level
        """
        assert self.entity_schema is not None

        self.logger = logging.getLogger(__name__)

        self.trader_name = trader_name or type(self).__name__.lower()

        self.trading_signal_listeners: List[TradingListener] = []

        self.selectors: List[TargetSelector] = []

        self.entity_ids = entity_ids

        self.exchanges = exchanges
        self.codes = codes

        self.provider = provider
        # make sure the min level selector correspond to the provider and level
        self.level = IntervalLevel(level)
        self.real_time = real_time

        # explicit check: `assert False` is stripped under `python -O`, which left start_timestamp unset
        if not (start_timestamp and end_timestamp):
            raise ValueError('start_timestamp and end_timestamp are required')
        self.start_timestamp = to_pd_timestamp(start_timestamp)
        self.end_timestamp = to_pd_timestamp(end_timestamp)

        self.trading_dates = self.entity_schema.get_trading_dates(
            start_date=self.start_timestamp, end_date=self.end_timestamp)

        if real_time:
            self.logger.info(
                'real_time mode, end_timestamp should be future,you could set it big enough for running forever'
            )
            if not (self.end_timestamp >= now_pd_timestamp()):
                raise ValueError(
                    'end_timestamp should not be in the past in real_time mode')

        self.kdata_use_begin_time = kdata_use_begin_time
        self.draw_result = draw_result

        self.account_service = SimAccountService(
            entity_schema=self.entity_schema,
            trader_name=self.trader_name,
            timestamp=self.start_timestamp,
            provider=self.provider,
            level=self.level)

        self.add_trading_signal_listener(self.account_service)

        self.init_selectors(entity_ids=entity_ids,
                            entity_schema=self.entity_schema,
                            exchanges=self.exchanges,
                            codes=self.codes,
                            start_timestamp=self.start_timestamp,
                            end_timestamp=self.end_timestamp)

        # always defined, so run() and handle_targets_slot() also work without selectors
        self.trading_level_asc: List[IntervalLevel] = []
        self.trading_level_desc: List[IntervalLevel] = []

        if self.selectors:
            self.trading_level_asc = sorted(
                {IntervalLevel(selector.level) for selector in self.selectors})

            self.logger.info(
                f'trader level:{self.level},selectors level:{self.trading_level_asc}'
            )

            if self.level != self.trading_level_asc[0]:
                raise Exception(
                    "trader level should be the min of the selectors")

            self.trading_level_desc = list(reversed(self.trading_level_asc))

        self.targets_slot: TargetsSlot = TargetsSlot()

        self.session = get_db_session('zvt', data_schema=TraderInfo)
        self.on_start()

    def on_start(self):
        """Run every selector over history, then persist a TraderInfo row describing this run."""
        # run all the selectors
        for selector in self.selectors:
            # run for the history data at first
            selector.run()

        # list-valued settings are stored as JSON strings (None when empty)
        entity_ids = json.dumps(self.entity_ids) if self.entity_ids else None
        exchanges = json.dumps(self.exchanges) if self.exchanges else None
        codes = json.dumps(self.codes) if self.codes else None

        sim_account = TraderInfo(
            id=self.trader_name,
            entity_id=f'trader_zvt_{self.trader_name}',
            timestamp=self.start_timestamp,
            trader_name=self.trader_name,
            entity_ids=entity_ids,
            exchanges=exchanges,
            codes=codes,
            start_timestamp=self.start_timestamp,
            end_timestamp=self.end_timestamp,
            provider=self.provider,
            level=self.level.value,
            real_time=self.real_time,
            kdata_use_begin_time=self.kdata_use_begin_time)
        self.session.add(sim_account)
        self.session.commit()

    def init_selectors(self, entity_ids, entity_schema, exchanges, codes,
                       start_timestamp, end_timestamp):
        """
        implement this to init selectors (append them to self.selectors)

        """
        pass

    def add_trading_signal_listener(self, listener):
        if listener not in self.trading_signal_listeners:
            self.trading_signal_listeners.append(listener)

    def remove_trading_signal_listener(self, listener):
        if listener in self.trading_signal_listeners:
            self.trading_signal_listeners.remove(listener)

    def handle_targets_slot(self, due_timestamp: pd.Timestamp,
                            happen_timestamp: pd.Timestamp):
        """
        this function would be called in every cycle, you could overwrite it for your custom algorithm to select the
        targets of different levels

        the default implementation intersects the targets of every level that has any; a level with no long (or
        short) targets is skipped for that side, and an empty intersection stays empty

        :param due_timestamp:
        :param happen_timestamp:

        """
        long_selected = None
        short_selected = None
        for level in self.trading_level_desc:
            targets = self.targets_slot.get_targets(level=level)
            if not targets:
                continue

            long_targets = set(targets[0])
            short_targets = set(targets[1])

            # compare with None rather than truthiness: an empty intersection must not be
            # replaced by the next level's targets
            if long_targets:
                long_selected = long_targets if long_selected is None else long_selected & long_targets
            if short_targets:
                short_selected = short_targets if short_selected is None else short_selected & short_targets

        self.logger.debug('timestamp:{},long_selected:{}'.format(
            due_timestamp, long_selected))

        self.logger.debug('timestamp:{},short_selected:{}'.format(
            due_timestamp, short_selected))

        self.trade_the_targets(due_timestamp=due_timestamp,
                               happen_timestamp=happen_timestamp,
                               long_selected=long_selected,
                               short_selected=short_selected)

    def get_current_account(self) -> AccountStats:
        return self.account_service.account

    def _get_long_holdings(self) -> List[str]:
        """Return the entity ids of current positions that still have available long volume."""
        account = self.get_current_account()
        if not account.positions:
            return []
        return [
            position.entity_id for position in account.positions
            if position is not None and position.available_long > 0
        ]

    def buy(self,
            due_timestamp,
            happen_timestamp,
            entity_ids,
            position_pct=1.0,
            ignore_in_position=True):
        """
        send open_long signals for entity_ids, splitting position_pct equally among them

        :param entity_ids: iterable of entity ids; None or empty means nothing to buy
        :param position_pct: total position percentage shared by the bought entities
        :param ignore_in_position: skip entities that are already held long
        """
        if not entity_ids:
            return

        if ignore_in_position:
            entity_ids = set(entity_ids) - set(self._get_long_holdings())
            if not entity_ids:
                return

        position_pct = (1.0 / len(entity_ids)) * position_pct

        for entity_id in entity_ids:
            trading_signal = TradingSignal(
                entity_id=entity_id,
                due_timestamp=due_timestamp,
                happen_timestamp=happen_timestamp,
                trading_signal_type=TradingSignalType.open_long,
                trading_level=self.level,
                position_pct=position_pct)
            self.send_trading_signal(trading_signal)

    def sell(self,
             due_timestamp,
             happen_timestamp,
             entity_ids,
             position_pct=1.0):
        """
        send close_long signals for the entities in entity_ids that are currently held long

        :param entity_ids: iterable of entity ids (list or set); None or empty means nothing to sell
        """
        if not entity_ids:
            return

        # set() on both sides: `set & list` would raise TypeError
        shorted = set(self._get_long_holdings()) & set(entity_ids)

        for entity_id in shorted:
            trading_signal = TradingSignal(
                entity_id=entity_id,
                due_timestamp=due_timestamp,
                happen_timestamp=happen_timestamp,
                trading_signal_type=TradingSignalType.close_long,
                trading_level=self.level,
                position_pct=position_pct)
            self.send_trading_signal(trading_signal)

    def trade_the_targets(self,
                          due_timestamp,
                          happen_timestamp,
                          long_selected,
                          short_selected,
                          long_pct=1.0,
                          short_pct=1.0):
        self.buy(due_timestamp=due_timestamp,
                 happen_timestamp=happen_timestamp,
                 entity_ids=long_selected,
                 position_pct=long_pct)
        self.sell(due_timestamp=due_timestamp,
                  happen_timestamp=happen_timestamp,
                  entity_ids=short_selected,
                  position_pct=short_pct)

    def send_trading_signal(self, signal: TradingSignal):
        """Deliver signal to every listener; a listener's exception is logged and reported back to it."""
        for listener in self.trading_signal_listeners:
            try:
                listener.on_trading_signal(signal)
            except Exception as e:
                self.logger.exception(e)
                listener.on_trading_error(timestamp=signal.happen_timestamp,
                                          error=e)

    def on_finish(self):
        # show the result
        if self.draw_result:
            import plotly.io as pio
            pio.renderers.default = "browser"
            reader = AccountStatsReader(trader_names=[self.trader_name])
            df = reader.data_df
            drawer = Drawer(main_data=NormalData(
                df.copy()[['trader_name', 'timestamp', 'all_value']],
                category_field='trader_name'))
            drawer.draw_line(show=True)

    def select_long_targets(self, long_targets: List[str]) -> List[str]:
        """Keep at most the first 10 long targets of a selector."""
        if len(long_targets) > 10:
            return long_targets[0:10]
        return long_targets

    def select_short_targets(self, short_targets: List[str]) -> List[str]:
        """Keep at most the first 10 short targets of a selector."""
        if len(short_targets) > 10:
            return short_targets[0:10]
        return short_targets

    def in_trading_date(self, timestamp):
        return to_time_str(timestamp) in self.trading_dates

    def on_time(self, timestamp):
        self.logger.debug(f'current timestamp:{timestamp}')

    def run(self):
        """Iterate the timestamps of self.level between start and end, driving selectors and the account."""
        # iterate timestamp of the min level,e.g,9:30,9:35,9.40...for 5min level
        # timestamp represents the timestamp in kdata
        for timestamp in self.entity_schema.get_interval_timestamps(
                start_date=self.start_timestamp,
                end_date=self.end_timestamp,
                level=self.level):

            if not self.in_trading_date(timestamp=timestamp):
                continue

            if self.real_time:
                # all selector move on to handle the coming data
                if self.kdata_use_begin_time:
                    real_end_timestamp = timestamp + pd.Timedelta(
                        seconds=self.level.to_second())
                else:
                    real_end_timestamp = timestamp

                seconds = (now_pd_timestamp() -
                           real_end_timestamp).total_seconds()
                waiting_seconds = self.level.to_second() - seconds
                # meaning the future kdata not ready yet,we could move on to check
                if waiting_seconds > 0:
                    # iterate the selector from min to max which in finished timestamp kdata
                    for level in self.trading_level_asc:
                        if self.entity_schema.is_finished_kdata_timestamp(
                                timestamp=timestamp, level=level):
                            for selector in self.selectors:
                                if selector.level == level:
                                    selector.move_on(timestamp,
                                                     self.kdata_use_begin_time,
                                                     timeout=waiting_seconds +
                                                     20)

            # on_trading_open to setup the account
            if self.level == IntervalLevel.LEVEL_1DAY or self.entity_schema.is_open_timestamp(timestamp):
                self.account_service.on_trading_open(timestamp)

            self.on_time(timestamp=timestamp)

            if self.selectors:
                for level in self.trading_level_asc:
                    # in every cycle, all level selector do its job in its time
                    if not self.entity_schema.is_finished_kdata_timestamp(
                            timestamp=timestamp, level=level):
                        continue

                    all_long_targets = []
                    all_short_targets = []
                    for selector in self.selectors:
                        if selector.level == level:
                            long_targets = selector.get_open_long_targets(
                                timestamp=timestamp)
                            long_targets = self.select_long_targets(
                                long_targets)

                            short_targets = selector.get_open_short_targets(
                                timestamp=timestamp)
                            short_targets = self.select_short_targets(
                                short_targets)

                            all_long_targets += long_targets
                            all_short_targets += short_targets

                    if all_long_targets or all_short_targets:
                        self.targets_slot.input_targets(
                            level, all_long_targets, all_short_targets)
                        # the time always move on by min level step and we could check all level targets in the slot
                        # 1)the targets is generated for next interval
                        # 2)the acceptable price is next interval prices,you could buy it at current price if the time is before the timestamp(due_timestamp) when trading signal received
                        # 3)the suggest price the the close price for generating the signal(happen_timestamp)
                        due_timestamp = timestamp + pd.Timedelta(
                            seconds=self.level.to_second())
                        if level == self.level:
                            self.handle_targets_slot(
                                due_timestamp=due_timestamp,
                                happen_timestamp=timestamp)

            # on_trading_close to calculate date account
            if self.level == IntervalLevel.LEVEL_1DAY or self.entity_schema.is_close_timestamp(timestamp):
                self.account_service.on_trading_close(timestamp)

        self.on_finish()
Beispiel #21
0
    def record_data(cls,
                    provider_index: int = 0,
                    provider: str = None,
                    exchanges=None,
                    entity_ids=None,
                    codes=None,
                    batch_size=None,
                    force_update=None,
                    sleeping_time=None,
                    default_size=None,
                    real_time=None,
                    fix_duplicate_way=None,
                    start_timestamp=None,
                    end_timestamp=None,
                    close_hour=None,
                    close_minute=None,
                    one_day_trading_minutes=None,
                    **kwargs):
        """
        Run the recorder registered for this schema.

        The recorder class is ``cls.provider_map_recorder[provider]``, or the one for
        ``cls.providers[provider_index]`` when no provider is given. Only non-None arguments are forwarded,
        so the recorder class keeps its own defaults for the rest.

        :param provider_index: index into ``cls.providers``, used when ``provider`` is not given
        :param provider: explicit provider name; takes precedence over ``provider_index``
        :param kwargs: extra keyword arguments, forwarded only to FixedCycleDataRecorder subclasses
        :raises KeyError: if no recorder is registered for the chosen provider
        """
        # snapshot the call arguments before any other local is bound; replaces the former
        # eval(arg), which executed names as code
        call_args = dict(locals())

        if not cls.provider_map_recorder:
            print(f'no recorders for {cls.__name__}')
            return

        print(
            f'{cls.__name__} registered recorders:{cls.provider_map_recorder}'
        )

        if provider:
            recorder_class = cls.provider_map_recorder[provider]
        else:
            recorder_class = cls.provider_map_recorder[
                cls.providers[provider_index]]

        # get args for specific recorder class
        from zvt.contract.recorder import TimeSeriesDataRecorder
        if issubclass(recorder_class, TimeSeriesDataRecorder):
            args = [
                item
                for item in inspect.getfullargspec(cls.record_data).args
                if item not in ('cls', 'provider_index', 'provider')
            ]
        else:
            args = ['batch_size', 'force_update', 'sleeping_time']

        # just fill the not-None args to kw, so we could use the recorder_class default args
        kw = {arg: call_args[arg] for arg in args if call_args.get(arg) is not None}

        from zvt.contract.recorder import FixedCycleDataRecorder
        if issubclass(recorder_class, FixedCycleDataRecorder):
            # contract:
            # 1)use FixedCycleDataRecorder to record the data with IntervalLevel
            # 2)the table of schema with IntervalLevel format is {entity}_{level}_[adjust_type]_{event}
            items = cls.__tablename__.split('_')
            if len(items) == 4:
                kw['adjust_type'] = items[2]
            try:
                level = IntervalLevel(items[1])
            except (IndexError, ValueError):
                # schema not in the normal format, but the size of remaining days still needs a level
                level = IntervalLevel.LEVEL_1DAY

            kw['level'] = level

            # add other custom args
            kw.update(kwargs)

        recorder_class(**kw).run()
Beispiel #22
0
    def __init__(
        self,
        data_schema: Type[Mixin],
        entity_schema: Type[TradableEntity],
        provider: str = None,
        entity_provider: str = None,
        entity_ids: List[str] = None,
        exchanges: List[str] = None,
        codes: List[str] = None,
        start_timestamp: Union[str, pd.Timestamp] = None,
        end_timestamp: Union[str, pd.Timestamp] = now_pd_timestamp(),
        columns: List = None,
        filters: List = None,
        order: object = None,
        limit: int = None,
        level: IntervalLevel = None,
        category_field: str = "entity_id",
        time_field: str = "timestamp",
        computing_window: int = None,
    ) -> None:
        """
        Configure the query and load the initial data via load_data().

        :param data_schema: schema the data is read from
        :param entity_schema: schema used to resolve exchanges/codes into entity ids when entity_ids is not given
        :param codes: list of codes, or a string holding either a JSON list ("[...]") or comma-separated codes
        :param end_timestamp: NOTE(review): the default is evaluated once when the module is imported, so a
            long-running process keeps a stale "now"; pass it explicitly if that matters
        :param columns: schema attributes, or attribute-name strings (decided by the type of the first element)
        :param level: interval level; any truthy value is normalized with IntervalLevel(level)
        :param category_field: name of the data_schema attribute used as the category column
        :param time_field: name of the data_schema attribute used as the time column
        """
        self.logger = logging.getLogger(self.__class__.__name__)

        self.data_schema = data_schema
        self.entity_schema = entity_schema

        self.provider = provider
        self.entity_provider = entity_provider

        self.start_timestamp = to_pd_timestamp(start_timestamp)
        self.end_timestamp = to_pd_timestamp(end_timestamp)

        self.exchanges = exchanges

        # codes may also arrive as a string: a JSON list or comma-separated codes
        if codes and isinstance(codes, str):
            codes = codes.replace(" ", "")
            if codes.startswith("[") and codes.endswith("]"):
                codes = json.loads(codes)
            else:
                codes = codes.split(",")

        self.codes = codes
        self.entity_ids = entity_ids

        # convert exchanges/codes to standard entity_ids when none were given
        if entity_schema and not self.entity_ids:
            df = get_entities(entity_schema=entity_schema,
                              provider=self.entity_provider,
                              exchanges=self.exchanges,
                              codes=self.codes)
            if pd_is_not_null(df):
                self.entity_ids = df["entity_id"].to_list()

        self.filters = filters
        self.order = order
        self.limit = limit

        self.level = IntervalLevel(level) if level else level

        self.category_field = category_field
        self.time_field = time_field
        self.computing_window = computing_window

        # getattr performs the same attribute lookup as the former eval() without executing arbitrary strings
        self.category_col = getattr(self.data_schema, self.category_field)
        self.time_col = getattr(self.data_schema, self.time_field)

        self.columns = columns

        # we store the data in a multiple index(category_column,timestamp) Dataframe
        if self.columns:
            # support str
            if isinstance(columns[0], str):
                self.columns = [getattr(data_schema, col) for col in columns]

            # always add category_column and time_field for normalizing
            self.columns = list(
                set(self.columns) | {self.category_col, self.time_col})

        self.data_listeners: List[DataListener] = []

        self.data_df: pd.DataFrame = None

        self.load_data()
Beispiel #23
0
class Trader(object):
    entity_schema: EntityMixin = None

    def __init__(self,
                 region: Region,
                 entity_ids: List[str] = None,
                 exchanges: List[str] = None,
                 codes: List[str] = None,
                 start_timestamp: Union[str, pd.Timestamp] = None,
                 end_timestamp: Union[str, pd.Timestamp] = None,
                 provider: Provider = Provider.Default,
                 level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
                 trader_name: str = None,
                 real_time: bool = False,
                 kdata_use_begin_time: bool = False,
                 draw_result: bool = True,
                 rich_mode: bool = True) -> None:
        """
        :param region: region used for the real_time "now" check
        :param entity_ids: entity ids passed to init_selectors and recorded in TraderInfo
        :param exchanges: exchanges passed to init_selectors and recorded in TraderInfo
        :param codes: codes passed to init_selectors and recorded in TraderInfo
        :param start_timestamp: first timestamp of the run (required)
        :param end_timestamp: last timestamp of the run (required; must not be in the past in real_time mode)
        :param provider: data provider, forwarded to SimAccountService
        :param level: trading level; must equal the smallest level among the selectors
        :param trader_name: defaults to the lower-cased class name
        :param rich_mode: forwarded to SimAccountService
        :raises ValueError: if start_timestamp/end_timestamp is missing, or end_timestamp is in the past in
            real_time mode
        :raises Exception: if level is not the smallest selector level
        """
        assert self.entity_schema is not None

        self.logger = logging.getLogger(__name__)

        self.trader_name = trader_name or type(self).__name__.lower()

        self.trading_signal_listeners: List[TradingListener] = []

        self.selectors: List[TargetSelector] = []

        self.entity_ids = entity_ids

        self.exchanges = exchanges
        self.codes = codes

        self.region = region
        self.provider = provider
        # make sure the min level selector correspond to the provider and level
        self.level = IntervalLevel(level)
        self.real_time = real_time

        # explicit check: `assert False` is stripped under `python -O`, which left start_timestamp unset
        if not (start_timestamp and end_timestamp):
            raise ValueError('start_timestamp and end_timestamp are required')
        self.start_timestamp = to_pd_timestamp(start_timestamp)
        self.end_timestamp = to_pd_timestamp(end_timestamp)

        self.trading_dates = self.entity_schema.get_trading_dates(
            start_date=self.start_timestamp, end_date=self.end_timestamp)

        if real_time:
            logger.info(
                'real_time mode, end_timestamp should be future,you could set it big enough for running forever'
            )
            if not (self.end_timestamp >= now_pd_timestamp(self.region)):
                raise ValueError(
                    'end_timestamp should not be in the past in real_time mode')

        self.kdata_use_begin_time = kdata_use_begin_time
        self.draw_result = draw_result
        self.rich_mode = rich_mode

        self.account_service = SimAccountService(
            entity_schema=self.entity_schema,
            trader_name=self.trader_name,
            timestamp=self.start_timestamp,
            provider=self.provider,
            level=self.level,
            rich_mode=rich_mode)

        self.register_trading_signal_listener(self.account_service)

        self.init_selectors(entity_ids=entity_ids,
                            entity_schema=self.entity_schema,
                            exchanges=self.exchanges,
                            codes=self.codes,
                            start_timestamp=self.start_timestamp,
                            end_timestamp=self.end_timestamp)

        # always defined, so level-iterating methods also work when no selector was registered
        self.trading_level_asc: List[IntervalLevel] = []
        self.trading_level_desc: List[IntervalLevel] = []

        if self.selectors:
            self.trading_level_asc = sorted(
                {IntervalLevel(selector.level) for selector in self.selectors})

            self.logger.info(
                f'trader level:{self.level},selectors level:{self.trading_level_asc}'
            )

            if self.level != self.trading_level_asc[0]:
                raise Exception(
                    "trader level should be the min of the selectors")

            self.trading_level_desc = list(reversed(self.trading_level_asc))

        self.session = get_db_session('zvt', data_schema=TraderInfo)

        self.level_map_long_targets = {}
        self.level_map_short_targets = {}
        self.trading_signals: List[TradingSignal] = []

        self.on_start()

    def on_start(self):
        """
        Run every selector over the history first, then persist a TraderInfo row for this trader.

        List-valued settings (entity_ids, exchanges, codes) are stored as JSON strings, or None when empty.
        """
        for history_selector in self.selectors:
            history_selector.run()

        def as_json(values):
            return json.dumps(values) if values else None

        trader_info = TraderInfo(
            id=self.trader_name,
            entity_id=f'trader_zvt_{self.trader_name}',
            timestamp=self.start_timestamp,
            trader_name=self.trader_name,
            entity_ids=as_json(self.entity_ids),
            exchanges=as_json(self.exchanges),
            codes=as_json(self.codes),
            start_timestamp=self.start_timestamp,
            end_timestamp=self.end_timestamp,
            provider=self.provider,
            level=self.level.value,
            real_time=self.real_time,
            kdata_use_begin_time=self.kdata_use_begin_time)
        self.session.add(trader_info)
        self.session.commit()

    def init_selectors(self, entity_ids, entity_schema, exchanges, codes,
                       start_timestamp, end_timestamp):
        """
        Hook for subclasses: append TargetSelector instances to self.selectors here if you want the
        selector/factor computing model; otherwise leave it empty and write the strategy in on_time.

        The default implementation does nothing.
        """

    def register_trading_signal_listener(self, listener):
        """Add listener to the signal listeners unless it is already registered."""
        if listener in self.trading_signal_listeners:
            return
        self.trading_signal_listeners.append(listener)

    def deregister_trading_signal_listener(self, listener):
        """Remove listener from the signal listeners; unknown listeners are ignored."""
        if listener not in self.trading_signal_listeners:
            return
        self.trading_signal_listeners.remove(listener)

    def set_long_targets_by_level(self, level: IntervalLevel,
                                  targets: List[str]) -> None:
        """Replace the long targets stored for level, logging the old and new values at debug level."""
        previous_targets = self.level_map_long_targets.get(level)
        logger.debug(
            f'level:{level},old long targets:{previous_targets},new long targets:{targets}'
        )
        self.level_map_long_targets[level] = targets

    def set_short_targets_by_level(self, level: IntervalLevel,
                                   targets: List[str]) -> None:
        """Replace the short targets stored for level, logging the old and new values at debug level."""
        previous_targets = self.level_map_short_targets.get(level)
        logger.debug(
            f'level:{level},old short targets:{previous_targets},new short targets:{targets}'
        )
        self.level_map_short_targets[level] = targets

    def get_long_targets_by_level(self, level: IntervalLevel) -> List[str]:
        """Return the long targets stored for level, or None if none were set."""
        long_map = self.level_map_long_targets
        return long_map.get(level, None)

    def get_short_targets_by_level(self, level: IntervalLevel) -> List[str]:
        """Return the short targets stored for level, or None if none were set."""
        short_map = self.level_map_short_targets
        return short_map.get(level, None)

    def select_long_targets_from_levels(self, timestamp):
        """
        overwrite it to select long targets from multiple levels, the default implementation is the
        intersection of the long targets of every level that has any

        levels without long targets are skipped; once the intersection is empty it stays empty (it used to
        be reset to the next level's targets because the running result was tested for truthiness)

        :param timestamp: unused by the default implementation
        :return: set of entity ids, or None when no level has long targets
        """
        long_selected = None

        for level in self.trading_level_desc:
            long_targets = self.level_map_long_targets.get(level)
            if not long_targets:
                continue
            if long_selected is None:
                long_selected = set(long_targets)
            else:
                long_selected &= set(long_targets)
        return long_selected

    def select_short_targets_from_levels(self, timestamp):
        """
        overwrite it to select short targets from multiple levels, the default implementation is the
        intersection of the short targets of every level that has any

        levels without short targets are skipped; once the intersection is empty it stays empty (it used to
        be reset to the next level's targets because the running result was tested for truthiness)

        :param timestamp: unused by the default implementation
        :return: set of entity ids, or None when no level has short targets
        """
        short_selected = None

        for level in self.trading_level_desc:
            short_targets = self.level_map_short_targets.get(level)
            if not short_targets:
                continue
            if short_selected is None:
                short_selected = set(short_targets)
            else:
                short_selected &= set(short_targets)
        return short_selected

    def get_current_account(self) -> AccountStats:
        """Return the account held by the simulated account service."""
        service = self.account_service
        return service.account

    def get_current_positions(self) -> List[Position]:
        """Return the positions of the current account."""
        current_account = self.get_current_account()
        return current_account.positions

    def long_position_control(self):
        """
        Return the total position percentage to use for new long entries.

        0.2 with no positions, 0.5 with at most 10 positions, 1.0 otherwise.
        """
        held_positions = self.get_current_positions()
        if not held_positions:
            return 0.2
        if len(held_positions) <= 10:
            return 0.5
        return 1.0

    def short_position_control(self):
        """Return the position percentage used when closing longs; the default closes them fully."""
        full_position = 1.0
        return full_position

    def buy(self,
            due_timestamp,
            happen_timestamp,
            entity_ids,
            ignore_in_position=True):
        """
        Queue open_long signals for entity_ids in self.trading_signals.

        The percentage from long_position_control() is split equally among the bought entities.

        :param entity_ids: iterable of entity ids; None or empty means nothing to buy
        :param ignore_in_position: skip entities that are already held long
        """
        if not entity_ids:
            return

        if ignore_in_position:
            account = self.get_current_account()
            current_holdings = []
            if account.positions:
                current_holdings = [
                    position.entity_id for position in account.positions
                    if position is not None and position.available_long > 0
                ]

            entity_ids = set(entity_ids) - set(current_holdings)

        if entity_ids:
            position_pct = self.long_position_control()
            position_pct = (1.0 / len(entity_ids)) * position_pct

            for entity_id in entity_ids:
                trading_signal = TradingSignal(
                    entity_id=entity_id,
                    due_timestamp=due_timestamp,
                    happen_timestamp=happen_timestamp,
                    trading_signal_type=TradingSignalType.open_long,
                    trading_level=self.level,
                    position_pct=position_pct)
                self.trading_signals.append(trading_signal)

    def sell(self, due_timestamp, happen_timestamp, entity_ids):
        """
        Queue close_long signals for the entities in entity_ids that are currently held long.

        :param entity_ids: iterable of entity ids; None or empty means nothing to sell
        """
        if not entity_ids:
            return

        # current position
        account = self.get_current_account()
        current_holdings = []
        if account.positions:
            current_holdings = [
                position.entity_id for position in account.positions
                if position is not None and position.available_long > 0
            ]

        shorted = set(current_holdings) & set(entity_ids)

        if shorted:
            position_pct = self.short_position_control()

            for entity_id in shorted:
                trading_signal = TradingSignal(
                    entity_id=entity_id,
                    due_timestamp=due_timestamp,
                    happen_timestamp=happen_timestamp,
                    trading_signal_type=TradingSignalType.close_long,
                    trading_level=self.level,
                    position_pct=position_pct)
                self.trading_signals.append(trading_signal)

    def trade_the_targets(self, due_timestamp, happen_timestamp, long_selected,
                          short_selected):
        """Turn the selected targets into trading signals.

        ``short_selected`` goes to ``sell`` first, then ``long_selected`` to ``buy``;
        an empty or ``None`` selection is skipped.
        """
        plan = ((short_selected, self.sell), (long_selected, self.buy))
        for targets, place_orders in plan:
            if not targets:
                continue
            place_orders(due_timestamp=due_timestamp,
                         happen_timestamp=happen_timestamp,
                         entity_ids=targets)

    def on_finish(self, timestmap):
        """Finish the run: notify listeners, then optionally chart the account value.

        ``on_trading_finish`` is always called with the given timestamp.  When
        ``self.draw_result`` is set, the ``all_value`` column of this trader's account
        stats (read through ``AccountStatsReader``) is drawn as a line chart with the
        plotly renderer set to ``"browser"``.

        The parameter keeps its historical spelling ``timestmap`` so keyword callers
        keep working.
        """
        self.on_trading_finish(timestmap)
        if not self.draw_result:
            return

        # plotly is only imported when a chart is actually requested
        import plotly.io as pio
        pio.renderers.default = "browser"

        stats_df = AccountStatsReader(trader_names=[self.trader_name]).data_df
        value_df = stats_df.copy()[['trader_name', 'timestamp', 'all_value']]
        line_data = NormalData(value_df, category_field='trader_name')
        Drawer(main_data=line_data).draw_line(show=True)

    def filter_selector_long_targets(self, timestamp, selector: TargetSelector,
                                     long_targets: List[str]) -> List[str]:
        """Cap the long targets taken from one selector at its first 10 entries.

        A list of 10 or fewer targets is returned as-is (the same object).
        """
        limit = 10
        return long_targets if len(long_targets) <= limit else long_targets[:limit]

    def filter_selector_short_targets(self, timestamp,
                                      selector: TargetSelector,
                                      short_targets: List[str]) -> List[str]:
        """Cap the short targets taken from one selector at its first 10 entries.

        A list of 10 or fewer targets is returned as-is (the same object).
        """
        limit = 10
        return short_targets if len(short_targets) <= limit else short_targets[:limit]

    def in_trading_date(self, timestamp):
        """Return whether ``to_time_str(timestamp)`` is one of ``self.trading_dates``."""
        day_str = to_time_str(timestamp)
        return day_str in self.trading_dates

    def on_time(self, timestamp):
        """Hook called by ``run`` for every trading timestamp; logs it at debug level."""
        # lazy %-style args: the message is only formatted when debug logging is on,
        # which matters because this runs once per iterated timestamp
        self.logger.debug('current timestamp:%s', timestamp)

    def on_trading_signals(self, trading_signals: List[TradingSignal]):
        """Hand the whole batch of trading signals to each registered listener, in order."""
        for listener in self.trading_signal_listeners:
            listener.on_trading_signals(trading_signals)

    def on_trading_signal(self, trading_signal: TradingSignal):
        """Forward a single trading signal to each registered listener.

        A listener whose ``on_trading_signal`` raises does not stop the others: the
        exception is logged with ``self.logger.exception`` and reported back to that
        listener through ``on_trading_error`` using the signal's ``happen_timestamp``.
        """
        for listener in self.trading_signal_listeners:
            try:
                listener.on_trading_signal(trading_signal)
            except Exception as error:
                self.logger.exception(error)
                listener.on_trading_error(timestamp=trading_signal.happen_timestamp,
                                          error=error)

    def on_trading_open(self, timestamp):
        """Call ``on_trading_open(timestamp)`` on every registered listener, in order."""
        for listener in self.trading_signal_listeners:
            listener.on_trading_open(timestamp)

    def on_trading_close(self, timestamp):
        """Call ``on_trading_close(timestamp)`` on every registered listener, in order."""
        for listener in self.trading_signal_listeners:
            listener.on_trading_close(timestamp)

    def on_trading_finish(self, timestamp):
        """Call ``on_trading_finish(timestamp)`` on every registered listener, in order."""
        for listener in self.trading_signal_listeners:
            listener.on_trading_finish(timestamp)

    def on_trading_error(self, timestamp, error):
        """Report ``error`` at ``timestamp`` to every registered listener, in order."""
        for listener in self.trading_signal_listeners:
            listener.on_trading_error(timestamp, error)

    def run(self):
        """Drive the trader over every ``self.level`` timestamp in [start, end].

        For each timestamp that is a trading date:

        1. optionally wait for kdata (the current day at daily level, or any step in
           ``real_time`` mode) and let selectors of finished levels ``move_on``;
        2. fire ``on_trading_open`` (every step at >= daily level, otherwise only on
           open timestamps) and ``on_time``;
        3. collect long/short targets from each selector whose level finished at this
           timestamp, and trade the selected targets when that level is ``self.level``;
        4. dispatch the queued trading signals, clear them, and fire
           ``on_trading_close`` (every step at >= daily level, otherwise only on close
           timestamps).

        ``on_finish`` is called once at the end with the last iterated timestamp.
        """
        # on_finish receives the last iterated timestamp; start from end_timestamp so
        # an empty iteration does not raise UnboundLocalError
        timestamp = self.end_timestamp
        # iterate timestamp of the min level,e.g,9:30,9:35,9.40...for 5min level
        # timestamp represents the timestamp in kdata
        for timestamp in self.entity_schema.get_interval_timestamps(
                start_date=self.start_timestamp,
                end_date=self.end_timestamp,
                level=self.level):

            if not self.in_trading_date(timestamp=timestamp):
                continue

            waiting_seconds = 0

            if self.level == IntervalLevel.LEVEL_1DAY:
                if is_same_date(timestamp, now_pd_timestamp(self.region)):
                    # current day's bar: poll until the region clock reaches 19:00.
                    # The clock must be re-read on every pass -- a value captured once
                    # never advances, and the loop would never exit.
                    while True:
                        now = now_pd_timestamp(self.region)
                        if now.hour >= 19:
                            waiting_seconds = 20
                            break
                        self.logger.info(
                            f'time is:{now},just smoke for minutes')
                        time.sleep(60)

            elif self.real_time:
                # all selector move on to handle the coming data
                if self.kdata_use_begin_time:
                    real_end_timestamp = timestamp + pd.Timedelta(
                        seconds=self.level.to_second())
                else:
                    real_end_timestamp = timestamp

                # measure against the current time, not a clock read before the loop
                seconds = (now_pd_timestamp(self.region) -
                           real_end_timestamp).total_seconds()
                waiting_seconds = self.level.to_second() - seconds

            # meaning the future kdata not ready yet,we could move on to check
            if waiting_seconds > 0:
                # iterate the selector from min to max which in finished timestamp kdata
                for level in self.trading_level_asc:
                    if self.entity_schema.is_finished_kdata_timestamp(
                            timestamp=timestamp, level=level):
                        for selector in self.selectors:
                            if selector.level == level:
                                selector.move_on(timestamp,
                                                 self.kdata_use_begin_time,
                                                 timeout=waiting_seconds + 20)

            # on_trading_open to setup the account
            if self.level >= IntervalLevel.LEVEL_1DAY or (
                    self.level != IntervalLevel.LEVEL_1DAY
                    and self.entity_schema.is_open_timestamp(timestamp)):
                self.on_trading_open(timestamp)

            self.on_time(timestamp=timestamp)

            if self.selectors:
                for level in self.trading_level_asc:
                    # in every cycle, all level selector do its job in its time
                    if self.entity_schema.is_finished_kdata_timestamp(
                            timestamp=timestamp, level=level):
                        all_long_targets = []
                        all_short_targets = []
                        for selector in self.selectors:
                            if selector.level == level:
                                long_targets = selector.get_open_long_targets(
                                    timestamp=timestamp)
                                long_targets = self.filter_selector_long_targets(
                                    timestamp=timestamp,
                                    selector=selector,
                                    long_targets=long_targets)

                                short_targets = selector.get_open_short_targets(
                                    timestamp=timestamp)
                                short_targets = self.filter_selector_short_targets(
                                    timestamp=timestamp,
                                    selector=selector,
                                    short_targets=short_targets)

                                if long_targets:
                                    all_long_targets += long_targets
                                if short_targets:
                                    all_short_targets += short_targets

                        if all_long_targets:
                            self.set_long_targets_by_level(
                                level, all_long_targets)
                        if all_short_targets:
                            self.set_short_targets_by_level(
                                level, all_short_targets)

                        # the time always move on by min level step and we could check all targets of levels
                        # 1)the targets is generated for next interval
                        # 2)the acceptable price is next interval prices,you could buy it at current price
                        # if the time is before the timestamp(due_timestamp) when trading signal received
                        # 3)the suggest price the the close price for generating the signal(happen_timestamp)
                        due_timestamp = timestamp + pd.Timedelta(
                            seconds=self.level.to_second())
                        if level == self.level:
                            long_selected = self.select_long_targets_from_levels(
                                timestamp)
                            short_selected = self.select_short_targets_from_levels(
                                timestamp)

                            self.logger.debug(
                                'timestamp:{},long_selected:{}'.format(
                                    due_timestamp, long_selected))

                            self.logger.debug(
                                'timestamp:{},short_selected:{}'.format(
                                    due_timestamp, short_selected))

                            self.trade_the_targets(
                                due_timestamp=due_timestamp,
                                happen_timestamp=timestamp,
                                long_selected=long_selected,
                                short_selected=short_selected)

            if self.trading_signals:
                self.on_trading_signals(self.trading_signals)
            # clear
            self.trading_signals = []

            # on_trading_close to calculate date account
            if self.level >= IntervalLevel.LEVEL_1DAY or (
                    self.level != IntervalLevel.LEVEL_1DAY
                    and self.entity_schema.is_close_timestamp(timestamp)):
                self.on_trading_close(timestamp)

        self.on_finish(timestamp)
Beispiel #24
0
def is_finished_kdata_timestamp(timestamp, level: IntervalLevel):
    """Tell whether ``timestamp`` lies exactly on a ``level`` interval boundary.

    The timestamp is normalised with ``to_pd_timestamp`` and compared with
    ``level.floor_timestamp`` of itself; they match only when flooring changes nothing.
    """
    normalized = to_pd_timestamp(timestamp)
    return bool(level.floor_timestamp(normalized) == normalized)
Beispiel #25
0
def next_timestamp(current_timestamp: pd.Timestamp, level: IntervalLevel) -> pd.Timestamp:
    """Return ``current_timestamp`` advanced by one ``level`` interval (``level.to_second()`` seconds)."""
    start = to_pd_timestamp(current_timestamp)
    step = pd.Timedelta(seconds=level.to_second())
    return start + step
Beispiel #26
0
def evaluate_size_from_timestamp(start_timestamp: pd.Timestamp,
                                 end_timestamp: pd.Timestamp,
                                 level: IntervalLevel,
                                 one_day_trading_minutes,
                                 trade_day=None):
    """Estimate how many ``level`` kdata bars lie between two timestamps.

    The estimate may be a little bigger than the real size so that fetching that
    many bars covers all of the kdata.

    When ``trade_day`` is given and non-empty, the position of the start day inside
    it is used as the number of elapsed trading days.  This assumes ``trade_day`` is
    a list of day timestamps ordered newest first, so ``trade_day[-1]`` is the
    oldest day -- TODO confirm against callers.  Otherwise, or when the start day is
    not listed and is not older than the oldest listed day, the estimate is derived
    from the calendar days of ``end_timestamp - start_timestamp``.

    Intraday per-day bar counts assume 240 trading minutes per day
    (e.g. 48 bars for 5min).

    :param start_timestamp: first timestamp to cover; must support ``replace`` and
        comparison with ``trade_day`` entries when ``trade_day`` is used
    :type start_timestamp: pd.Timestamp
    :param end_timestamp: last timestamp to cover (not converted, used as-is)
    :type end_timestamp: pd.Timestamp
    :param level: kdata interval level
    :type level: IntervalLevel
    :param one_day_trading_minutes: trading minutes per day; only used for levels
        without a dedicated rule below
    :type one_day_trading_minutes: int
    :param trade_day: optional list of trading-day timestamps (see above)
    :return: estimated number of bars
    :rtype: int
    """
    time_delta = end_timestamp - to_pd_timestamp(start_timestamp)

    one_day_trading_seconds = one_day_trading_minutes * 60

    # an empty trade_day list has neither a usable index nor a trade_day[-1]
    has_trade_day = trade_day is not None and len(trade_day) > 0

    if level in (IntervalLevel.LEVEL_1MON, IntervalLevel.LEVEL_1WEEK):
        # (trading days per bar, calendar days per bar)
        if level == IntervalLevel.LEVEL_1MON:
            trade_days_per_bar, calendar_days_per_bar = 22, 30
        else:
            trade_days_per_bar, calendar_days_per_bar = 5, 7
        if has_trade_day:
            try:
                elapsed_days = trade_day.index(start_timestamp)
            except ValueError:
                # start day not listed: if it predates the oldest listed day,
                # every listed day is covered
                if start_timestamp < trade_day[-1]:
                    return int(math.ceil(len(trade_day) / trade_days_per_bar))
            else:
                size = int(math.ceil(elapsed_days / trade_days_per_bar))
                # +1 (presumably for the period still in progress) unless nothing elapsed
                return 0 if size == 0 else size + 1
        return int(math.ceil(time_delta.days / calendar_days_per_bar))

    if level == IntervalLevel.LEVEL_1DAY:
        if has_trade_day:
            try:
                return trade_day.index(start_timestamp)
            except ValueError:
                if start_timestamp < trade_day[-1]:
                    return len(trade_day)
        return time_delta.days

    # intraday levels: (bars per trading day, bars per calendar day for the fallback,
    # minutes per bar -- None means count_hours_from_day yields the bar count).
    # NOTE(review): 1hour's fallback uses 8 bars/day vs 4 per trading day; the
    # over-estimate is harmless here, but confirm whether 4 was intended.
    intraday_levels = {
        IntervalLevel.LEVEL_1HOUR: (4, 8, None),
        IntervalLevel.LEVEL_30MIN: (8, 8, 30),
        IntervalLevel.LEVEL_15MIN: (16, 16, 15),
        IntervalLevel.LEVEL_5MIN: (48, 48, 5),
        IntervalLevel.LEVEL_1MIN: (240, 240, 1),
    }
    if level in intraday_levels:
        bars_per_day, fallback_bars_per_day, bar_minutes = intraday_levels[level]
        if has_trade_day:
            start_date = start_timestamp.replace(hour=0, minute=0, second=0)
            try:
                elapsed_days = trade_day.index(start_date)
            except ValueError:
                if start_date < trade_day[-1]:
                    return len(trade_day) * bars_per_day
            else:
                start_time = datetime.datetime.time(start_timestamp)
                if bar_minutes is None:
                    bars_in_start_day = count_hours_from_day(start_time)
                else:
                    # count_mins_from_day appears to return elapsed minutes (the 1min
                    # level uses it directly as a bar count), so divide by the bar
                    # length of *this* level rather than a fixed 5
                    bars_in_start_day = count_mins_from_day(start_time) / bar_minutes
                return elapsed_days * bars_per_day + int(math.ceil(bars_in_start_day))
        return int(math.ceil(time_delta.days * fallback_bars_per_day))

    if time_delta.days > 0:
        seconds = (time_delta.days + 1) * one_day_trading_seconds
        return int(math.ceil(seconds / level.to_second()))

    seconds = time_delta.total_seconds()
    # both candidates are rounded up to int so the result is always an int size
    return min(int(math.ceil(seconds / level.to_second())),
               int(math.ceil(one_day_trading_seconds / level.to_second())))
Beispiel #27
0
    def __init__(self,
                 entity_ids: List[str] = None,
                 exchanges: List[str] = None,
                 codes: List[str] = None,
                 start_timestamp: Union[str, pd.Timestamp] = None,
                 end_timestamp: Union[str, pd.Timestamp] = None,
                 provider: str = None,
                 level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
                 trader_name: str = None,
                 real_time: bool = False,
                 kdata_use_begin_time: bool = False,
                 draw_result: bool = True,
                 rich_mode: bool = False,
                 adjust_type: AdjustType = None,
                 profit_threshold=(3, -0.3),
                 keep_history=False) -> None:
        """Set up the trader, its simulated account service and its selectors.

        :param entity_ids: entity ids handed to ``init_selectors``
        :param exchanges: exchanges handed to ``init_selectors``
        :param codes: codes handed to ``init_selectors``
        :param start_timestamp: required; converted with ``to_pd_timestamp``
        :param end_timestamp: required; converted with ``to_pd_timestamp``
        :param provider: data provider given to ``SimAccountService``
        :param level: trader level; must equal the smallest selector level
        :param trader_name: defaults to the lower-cased class name
        :param real_time: when true, ``end_timestamp`` must not be in the past
        :param kdata_use_begin_time: stored as-is (read by ``run``)
        :param draw_result: stored as-is (read by ``on_finish``)
        :param rich_mode: given to ``SimAccountService``
        :param adjust_type: converted with ``AdjustType(...)``
        :param profit_threshold: stored as-is
        :param keep_history: given to ``SimAccountService``
        :raises AssertionError: if ``entity_schema`` is unset, a timestamp is missing,
            or ``real_time`` is set with an ``end_timestamp`` in the past
        :raises Exception: if ``level`` is not the smallest selector level
        """
        # NOTE(review): validation via assert is stripped under ``python -O``
        assert self.entity_schema is not None
        assert start_timestamp is not None
        assert end_timestamp is not None

        self.logger = logging.getLogger(__name__)

        if trader_name:
            self.trader_name = trader_name
        else:
            self.trader_name = type(self).__name__.lower()

        self.entity_ids = entity_ids
        self.exchanges = exchanges
        self.codes = codes
        self.provider = provider
        # make sure the min level selector correspond to the provider and level
        self.level = IntervalLevel(level)
        self.real_time = real_time
        self.start_timestamp = to_pd_timestamp(start_timestamp)
        self.end_timestamp = to_pd_timestamp(end_timestamp)

        # checked by in_trading_date() via ``in`` membership
        self.trading_dates = self.entity_schema.get_trading_dates(start_date=self.start_timestamp,
                                                                  end_date=self.end_timestamp)

        if real_time:
            self.logger.info(
                'real_time mode, end_timestamp should be future,you could set it big enough for running forever')
            assert self.end_timestamp >= now_pd_timestamp()

        self.kdata_use_begin_time = kdata_use_begin_time
        self.draw_result = draw_result
        self.rich_mode = rich_mode

        # NOTE(review): with the default adjust_type=None this presumably raises
        # ValueError unless AdjustType accepts None -- confirm the intended default
        self.adjust_type = AdjustType(adjust_type)
        self.profit_threshold = profit_threshold
        self.keep_history = keep_history

        self.level_map_long_targets = {}
        self.level_map_short_targets = {}
        self.trading_signals: List[TradingSignal] = []
        self.trading_signal_listeners: List[TradingListener] = []
        self.selectors: List[TargetSelector] = []

        self.account_service = SimAccountService(entity_schema=self.entity_schema,
                                                 trader_name=self.trader_name,
                                                 timestamp=self.start_timestamp,
                                                 provider=self.provider,
                                                 level=self.level,
                                                 rich_mode=self.rich_mode,
                                                 adjust_type=self.adjust_type,
                                                 keep_history=self.keep_history)

        # the account service receives the trading signals like any other listener
        self.register_trading_signal_listener(self.account_service)

        # subclass hook expected to fill self.selectors
        self.init_selectors(entity_ids=self.entity_ids, entity_schema=self.entity_schema, exchanges=self.exchanges,
                            codes=self.codes, start_timestamp=self.start_timestamp, end_timestamp=self.end_timestamp,
                            adjust_type=self.adjust_type)

        if self.selectors:
            # distinct selector levels, smallest first
            self.trading_level_asc = list(set([IntervalLevel(selector.level) for selector in self.selectors]))
            self.trading_level_asc.sort()

            self.logger.info(f'trader level:{self.level},selectors level:{self.trading_level_asc}')

            if self.level != self.trading_level_asc[0]:
                raise Exception("trader level should be the min of the selectors")

            self.trading_level_desc = list(self.trading_level_asc)
            self.trading_level_desc.reverse()

            # run selectors for history data at first
            for selector in self.selectors:
                selector.run()

        self.on_start()
Beispiel #28
0
    def __init__(self,
                 entity_ids: List[str] = None,
                 exchanges: List[str] = None,
                 codes: List[str] = None,
                 start_timestamp: Union[str, pd.Timestamp] = None,
                 end_timestamp: Union[str, pd.Timestamp] = None,
                 provider: str = None,
                 level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
                 trader_name: str = None,
                 real_time: bool = False,
                 kdata_use_begin_time: bool = False,
                 draw_result: bool = True) -> None:
        """Set up the trader, its simulated account service, selectors and DB session.

        :param entity_ids: entity ids handed to ``init_selectors``
        :param exchanges: exchanges handed to ``init_selectors``
        :param codes: codes handed to ``init_selectors``
        :param start_timestamp: required; converted with ``to_pd_timestamp``
        :param end_timestamp: required; converted with ``to_pd_timestamp``
        :param provider: data provider given to ``SimAccountService``
        :param level: trader level; must equal the smallest selector level
        :param trader_name: defaults to the lower-cased class name
        :param real_time: when true, ``end_timestamp`` must not be in the past
        :param kdata_use_begin_time: stored as-is
        :param draw_result: stored as-is
        :raises AssertionError: if ``entity_schema`` is unset, a timestamp is missing,
            or ``real_time`` is set with an ``end_timestamp`` in the past
        :raises Exception: if ``level`` is not the smallest selector level
        """
        assert self.entity_schema is not None

        self.logger = logging.getLogger(__name__)

        if trader_name:
            self.trader_name = trader_name
        else:
            self.trader_name = type(self).__name__.lower()

        self.trading_signal_listeners: List[TradingListener] = []

        self.selectors: List[TargetSelector] = []

        self.entity_ids = entity_ids

        self.exchanges = exchanges
        self.codes = codes

        self.provider = provider
        # make sure the min level selector correspond to the provider and level
        self.level = IntervalLevel(level)
        self.real_time = real_time

        if start_timestamp and end_timestamp:
            self.start_timestamp = to_pd_timestamp(start_timestamp)
            self.end_timestamp = to_pd_timestamp(end_timestamp)
        else:
            assert False, 'start_timestamp and end_timestamp are required'

        self.trading_dates = self.entity_schema.get_trading_dates(
            start_date=self.start_timestamp, end_date=self.end_timestamp)

        if real_time:
            # use the instance logger set up above, not a module-level name
            self.logger.info(
                'real_time mode, end_timestamp should be future,you could set it big enough for running forever'
            )
            assert self.end_timestamp >= now_pd_timestamp()

        self.kdata_use_begin_time = kdata_use_begin_time
        self.draw_result = draw_result

        self.account_service = SimAccountService(
            entity_schema=self.entity_schema,
            trader_name=self.trader_name,
            timestamp=self.start_timestamp,
            provider=self.provider,
            level=self.level)

        # the account service receives the trading signals like any other listener
        self.add_trading_signal_listener(self.account_service)

        self.init_selectors(entity_ids=entity_ids,
                            entity_schema=self.entity_schema,
                            exchanges=self.exchanges,
                            codes=self.codes,
                            start_timestamp=self.start_timestamp,
                            end_timestamp=self.end_timestamp)

        if self.selectors:
            # distinct selector levels, smallest first
            self.trading_level_asc = list(
                set([
                    IntervalLevel(selector.level)
                    for selector in self.selectors
                ]))
            self.trading_level_asc.sort()

            self.logger.info(
                f'trader level:{self.level},selectors level:{self.trading_level_asc}'
            )

            if self.level != self.trading_level_asc[0]:
                raise Exception(
                    "trader level should be the min of the selectors")

            self.trading_level_desc = list(self.trading_level_asc)
            self.trading_level_desc.reverse()

        self.targets_slot: TargetsSlot = TargetsSlot()

        self.session = get_db_session('zvt', data_schema=TraderInfo)
        self.on_start()
Beispiel #29
0
def gen_kdata_schema(pkg: str,
                     providers: List[str],
                     entity_type: str,
                     levels: List[IntervalLevel],
                     adjust_types: List[AdjustType] = None,
                     entity_in_submodule: bool = False,
                     kdata_module='quotes'):
    """Generate one kdata schema module per (level, adjust_type) combination.

    Each module is written to ``./domain[/<kdata_module>][/<entity_type>]/<table_name>.py``
    and declares ``<EntityType><Level>[<Adjust>]Kdata(KdataBase, <EntityType>KdataCommon)``
    registered for ``providers``.  An ``__init__.py`` is created in that directory if
    missing, and ``gen_exports('./domain')`` runs at the end.

    :param pkg: package containing ``domain.<kdata_module>.<EntityType>KdataCommon``
    :param providers: providers written into the generated ``register_schema`` call
    :param entity_type: entity type, used as the table-name prefix
    :param levels: interval levels (``IntervalLevel`` members or their values)
    :param adjust_types: adjust types (``AdjustType`` members or their values);
        ``None`` and ``AdjustType.qfq`` produce the un-suffixed table.  Defaults to
        ``[None]``.
    :param entity_in_submodule: put the files in an ``<entity_type>`` sub-directory
    :param kdata_module: sub-directory/module of ``domain`` holding the schemas
    :return: the generated table names, in generation order
    """
    # avoid a shared mutable default; None means "only the un-suffixed table"
    if adjust_types is None:
        adjust_types = [None]

    tables = []

    base_path = './domain'

    if kdata_module:
        base_path = os.path.join(base_path, kdata_module)
    if entity_in_submodule:
        base_path = os.path.join(base_path, entity_type)

    cap_entity_type = entity_type.capitalize()
    # you should define {EntityType}KdataCommon in kdata_module at first
    kdata_common = f'{cap_entity_type}KdataCommon'

    for level in levels:
        level = IntervalLevel(level)
        cap_level = level.value.capitalize()

        for adjust_type in adjust_types:
            # accept the enum value (e.g. 'hfq') as well as the enum member
            if adjust_type:
                adjust_type = AdjustType(adjust_type)

            if adjust_type and (adjust_type != AdjustType.qfq):
                class_name = f'{cap_entity_type}{cap_level}{adjust_type.value.capitalize()}Kdata'
                table_name = f'{entity_type}_{level.value}_{adjust_type.value.lower()}_kdata'

            else:
                class_name = f'{cap_entity_type}{cap_level}Kdata'
                table_name = f'{entity_type}_{level.value}_kdata'

            tables.append(table_name)

            schema_template = f'''# -*- coding: utf-8 -*-
# this file is generated by gen_kdata_schema function, dont't change it
from sqlalchemy.orm import declarative_base

from zvt.contract.register import register_schema
from {pkg}.domain.{kdata_module} import {kdata_common}

KdataBase = declarative_base()


class {class_name}(KdataBase, {kdata_common}):
    __tablename__ = '{table_name}'


register_schema(providers={providers}, db_name='{table_name}', schema_base=KdataBase, entity_type='{entity_type}')

'''
            # generate the schema; explicit encoding so the output does not depend on locale
            with open(os.path.join(base_path, f'{table_name}.py'),
                      'w', encoding='utf-8') as outfile:
                outfile.write(schema_template)

        # generate the package
        pkg_file = os.path.join(base_path, '__init__.py')
        if not os.path.exists(pkg_file):
            package_template = '''# -*- coding: utf-8 -*-
'''
            with open(pkg_file, 'w', encoding='utf-8') as outfile:
                outfile.write(package_template)

    # generate exports
    gen_exports('./domain')
    return tables
Beispiel #30
0
def get_kdata(entity_id, level=IntervalLevel.LEVEL_1DAY, adjust_type=AdjustType.qfq, limit=10000):
    """Download kline bars for ``entity_id`` from the eastmoney push2his kline API.

    :param entity_id: entity id, split by ``decode_entity_id`` into
        (entity_type, exchange, code)
    :param level: interval level (``IntervalLevel`` member or its value), mapped
        with ``to_em_level_flag``
    :param adjust_type: adjust type, mapped with ``to_em_fq_flag``
    :param limit: sent as the ``lmt`` query parameter
    :return: a ``pd.DataFrame`` with one row per parsed kline when any were parsed
    :raises requests.HTTPError: on a non-2xx response (``raise_for_status``)
    """
    entity_type, exchange, code = decode_entity_id(entity_id)
    level = IntervalLevel(level)

    sec_id = to_em_sec_id(entity_id)
    fq_flag = to_em_fq_flag(adjust_type)
    level_flag = to_em_level_flag(level)
    # f131 settlement price
    # f133 open interest (position)
    # not fetched for now
    url = f"https://push2his.eastmoney.com/api/qt/stock/kline/get?secid={sec_id}&klt={level_flag}&fqt={fq_flag}&lmt={limit}&end=20500000&iscca=1&fields1=f1,f2,f3,f4,f5,f6,f7,f8&fields2=f51,f52,f53,f54,f55,f56,f57,f58,f59,f60,f61,f62,f63,f64&ut=f057cbcbce2a86e2866ab8877db1d059&forcect=1"

    # NOTE(review): no timeout is passed, so a stalled connection can block indefinitely
    resp = requests.get(url, headers=DEFAULT_HEADER)
    resp.raise_for_status()
    results = resp.json()
    data = results["data"]

    kdatas = []

    if data:
        klines = data["klines"]
        name = data["name"]

        for result in klines:
            # "2000-01-28,1005.26,1012.56,1173.12,982.13,3023326,3075552000.00"
            # "2021-08-27,19.39,20.30,20.30,19.25,1688497,3370240912.00,5.48,6.01,1.15,3.98,0,0,0"
            # time,open,close,high,low,volume,turnover
            # "2022-04-13,10708,10664,10790,10638,402712,43124771328,1.43,0.57,60,0.00,4667112399583576064,4690067230254170112,1169270784"
            fields = result.split(",")
            the_timestamp = to_pd_timestamp(fields[0])

            the_id = generate_kdata_id(entity_id=entity_id, timestamp=the_timestamp, level=level)

            # NOTE(review): this local rebinds the builtin ``open`` for the rest of the function
            open = to_float(fields[1])
            close = to_float(fields[2])
            high = to_float(fields[3])
            low = to_float(fields[4])
            volume = to_float(fields[5])
            turnover = to_float(fields[6])
            # 7 amplitude
            change_pct = value_to_pct(to_float(fields[8]))
            # 9 change amount
            turnover_rate = value_to_pct(to_float(fields[10]))

            kdatas.append(
                dict(
                    id=the_id,
                    timestamp=the_timestamp,
                    entity_id=entity_id,
                    provider="em",
                    code=code,
                    name=name,
                    level=level.value,
                    open=open,
                    close=close,
                    high=high,
                    low=low,
                    volume=volume,
                    turnover=turnover,
                    turnover_rate=turnover_rate,
                    change_pct=change_pct,
                )
            )
    if kdatas:
        df = pd.DataFrame.from_records(kdatas)
        return df