Пример #1
0
    def __init__(self,
                 batch_size: int = 10,
                 force_update: bool = False,
                 sleeping_time: int = 10) -> None:
        """
        Store the recorder options and open the db session used for persistence.

        :param batch_size: batch size used when saving to db
        :type batch_size: int
        :param force_update: whether to force-update the data even if it already exists; set it to True
        if the data needs to be refreshed from the provider
        :type force_update: bool
        :param sleeping_time: sleeping seconds for the recording loop
        :type sleeping_time: int
        """
        # logger is named after the concrete class
        self.logger = logging.getLogger(self.__class__.__name__)

        # `provider` and `data_schema` are not set here; they must already be available on the instance/class
        assert self.provider is not None
        assert self.data_schema is not None

        self.sleeping_time = sleeping_time
        self.force_update = force_update
        self.batch_size = batch_size

        # session used for all db operations of this instance
        self.session = get_db_session(data_schema=self.data_schema,
                                      provider=self.provider)
Пример #2
0
def get_securities_in_blocks(
        provider: str = 'eastmoney',
        categories: List[Union[str, BlockCategory]] = None,
        names=None,
        codes=None,
        ids=None):
    """
    Collect the stock ids of the stocks contained in the matching blocks.

    :param provider: provider used for both the db session and the block lookup
    :param categories: block categories to include, as strings or BlockCategory members;
        defaults to ['concept', 'industry'] when omitted
    :param names: optional block names; when given, only blocks with these names are used
    :param codes: optional block codes, forwarded to get_entities
    :param ids: optional block entity ids, forwarded to get_entities
    :return: list of stock ids, one entry per (block, stock) pair, so a stock that belongs to
        several matching blocks appears several times
    """
    # a None sentinel replaces the former mutable list default
    if categories is None:
        categories = ['concept', 'industry']

    session = get_db_session(provider=provider, data_schema=Index)

    # normalise str / enum input to the plain category values stored in Index.category
    categories = [BlockCategory(category).value for category in categories]

    filters = [Index.category.in_(categories)]

    # add name filters
    if names:
        filters.append(Index.name.in_(names))

    blocks = get_entities(entity_ids=ids,
                          codes=codes,
                          entity_type='index',
                          provider=provider,
                          filters=filters,
                          return_type='domain',
                          session=session)
    securities = []
    for block in blocks:
        securities.extend(item.stock_id for item in block.stocks)

    return securities
Пример #3
0
def get_indices(provider: str = 'sina',
                block_category: Union[str, StockCategory] = 'concept',
                return_type: str = 'df') -> object:
    """
    get indices/blocks on block_category

    :param provider: provider used for the db session and the entity lookup
    :type provider: str
    :param block_category: category compared against Index.category; a StockCategory member is
        replaced by its ``value`` before the comparison
    :type block_category: Union[str, StockCategory]
    :param return_type: forwarded unchanged to get_entities
    :type return_type: str
    :return: the result of get_entities for the requested return_type
    :rtype: object
    """
    # isinstance instead of an exact type comparison
    if isinstance(block_category, StockCategory):
        block_category = block_category.value

    session = get_db_session(provider=provider, data_schema=Index)

    return get_entities(entity_type='index',
                        provider=provider,
                        filters=[Index.category == block_category],
                        session=session,
                        return_type=return_type)
Пример #4
0
def get_group(provider, data_schema, column, group_func=func.count, session=None):
    """
    Group the rows by ``column`` and return the result as a DataFrame.

    :param provider: provider used to open a session when none is supplied
    :param data_schema: schema used to open a session when none is supplied
    :param column: column to select and group by
    :param group_func: aggregate applied to ``column`` (``func.count`` by default); when falsy,
        only the distinct group keys are selected
    :param session: optional existing session; a new one is obtained when falsy
    :return: DataFrame read from the grouped query
    """
    session = session or get_db_session(provider=provider, data_schema=data_schema)

    selected = (column, group_func(column)) if group_func else (column,)
    grouped = session.query(*selected).group_by(column)

    return pd.read_sql(grouped.statement, grouped.session.bind)
Пример #5
0
def get_stock_category(stock_id, session=None):
    """
    Return every Index whose ``stocks`` relation contains an item with id ``stock_id``.

    :param stock_id: id matched against the items of ``Index.stocks``
    :param session: optional existing session; when falsy a 'meta' session is opened here and
        closed again before returning
    :return: list of matching Index objects
    """
    owns_session = not session
    if owns_session:
        session = get_db_session(db_name='meta')

    try:
        matching = session.query(Index).filter(Index.stocks.any(id=stock_id))
        return matching.all()
    finally:
        # only close what this function opened; a caller-supplied session stays open
        if owns_session:
            session.close()
Пример #6
0
    def init_entities(self):
        """Choose the session used for entity lookup and load ``self.entities``."""
        shares_store = (self.entity_provider == self.provider
                        and self.entity_schema == self.data_schema)

        # reuse the data session when entities come from the same provider and schema,
        # otherwise obtain a dedicated session for them
        self.entity_session = (
            self.session if shares_store
            else get_db_session(provider=self.entity_provider, data_schema=self.entity_schema)
        )

        # init the entity list
        lookup = dict(session=self.entity_session,
                      entity_type=self.entity_type,
                      exchanges=self.exchanges,
                      entity_ids=self.entity_ids,
                      codes=self.codes,
                      return_type='domain',
                      provider=self.entity_provider)
        self.entities = get_entities(**lookup)
Пример #7
0
    def __init__(self,
                 trader_name,
                 timestamp,
                 provider='joinquant',
                 level=IntervalLevel.LEVEL_1DAY,
                 base_capital=1000000,
                 buy_cost=0.001,
                 sell_cost=0.001,
                 slippage=0.001):
        """
        Prepare a fresh simulated account for ``trader_name``.

        Any SimAccount, Position and Order rows already stored for the trader are deleted
        before the initial account snapshot is built.

        :param trader_name: name identifying the trader whose account is managed
        :param timestamp: stored as ``start_timestamp`` and used as the initial account timestamp
        :param provider: stored on the instance
        :param level: stored on the instance
        :param base_capital: initial ``cash`` and ``all_value`` of the account
        :param buy_cost: stored on the instance
        :param sell_cost: stored on the instance
        :param slippage: stored on the instance
        """
        self.base_capital = base_capital
        self.buy_cost = buy_cost
        self.sell_cost = sell_cost
        self.slippage = slippage
        self.trader_name = trader_name

        self.session = get_db_session('zvt', 'business')
        self.provider = provider
        self.level = level
        self.start_timestamp = timestamp

        previous_run = get_account(session=self.session,
                                   trader_name=self.trader_name,
                                   return_type='domain',
                                   limit=1)

        if previous_run:
            self.logger.warning(
                "trader:{} has run before,old result would be deleted".format(
                    trader_name))
            # wipe the earlier results, in the order: accounts, positions, orders
            for schema in (SimAccount, Position, Order):
                self.session.query(schema).filter(
                    schema.trader_name == self.trader_name).delete()
            self.session.commit()

        # initial snapshot: everything in cash, no positions
        initial_account = SimAccount(trader_name=self.trader_name,
                                     cash=self.base_capital,
                                     positions=[],
                                     all_value=self.base_capital,
                                     value=0,
                                     closing=False,
                                     timestamp=timestamp)
        self.latest_account = sim_account_schema.dump(initial_account)
Пример #8
0
    def init_entities(self):
        """Open the entity session and load the concept/industry indices into ``self.entities``."""
        self.entity_session = get_db_session(provider=self.entity_provider,
                                             data_schema=self.entity_schema)

        # only fetch concept and industry blocks
        wanted_categories = [StockCategory.industry.value, StockCategory.concept.value]
        category_filter = Index.category.in_(wanted_categories)

        # NOTE(review): the session is opened with `entity_provider` while the lookup below
        # passes `provider` - confirm that the difference is intended
        self.entities = get_entities(
            session=self.entity_session,
            entity_type='index',
            exchanges=self.exchanges,
            codes=self.codes,
            entity_ids=self.entity_ids,
            return_type='domain',
            provider=self.provider,
            filters=[category_filter])
Пример #9
0
def stock_id_in_index(stock_id,
                      index_id,
                      session=None,
                      data_schema=StockIndex,
                      provider='eastmoney'):
    """
    Check whether a row with id ``'{index_id}_{stock_id}'`` exists in ``data_schema``.

    :param stock_id: second part of the composed row id
    :param index_id: first part of the composed row id
    :param session: optional existing session; when falsy one is opened here and closed before returning
    :param data_schema: schema that is searched (StockIndex by default)
    :param provider: provider used when a session has to be opened
    :return: the result of ``data_exist`` for the composed id
    """
    membership_id = '{}_{}'.format(index_id, stock_id)

    owns_session = not session
    if owns_session:
        session = get_db_session(provider=provider, data_schema=data_schema)

    try:
        return data_exist(session=session, schema=data_schema, id=membership_id)
    finally:
        # a caller-supplied session is left open
        if owns_session:
            session.close()
Пример #10
0
from ..context import init_test_context

init_test_context()

from zvt.api.api import get_spo_detail, get_rights_issue_detail, get_dividend_financing
from zvt.domain import SpoDetail, RightsIssueDetail, DividendFinancing
from zvdata.contract import get_db_session
from zvdata.utils.time_utils import to_pd_timestamp

# module-level session on the eastmoney 'dividend_financing' db, used by the tests in this module
session = get_db_session(provider='eastmoney', db_name='dividend_financing')  # type: sqlalchemy.orm.Session


# SPO (additional share issue) details
def test_000778_spo_detial():
    """The latest of the four SPO records of 000778 up to 2018-09-30 carries the expected figures."""
    spo_records = get_spo_detail(session=session,
                                 provider='eastmoney',
                                 return_type='domain',
                                 codes=['000778'],
                                 end_timestamp='2018-09-30',
                                 order=SpoDetail.timestamp.desc())
    assert len(spo_records) == 4

    # records are ordered by timestamp descending, so the first one is the most recent
    latest: SpoDetail = spo_records[0]
    assert latest.timestamp == to_pd_timestamp('2017-04-01')
    assert (latest.spo_issues, latest.spo_price, latest.spo_raising_fund) == (347600000, 5.15, 1766000000)


# 配股详情
Пример #11
0
    def __init__(self,
                 entity_ids: List[str] = None,
                 exchanges: List[str] = ['sh', 'sz'],
                 codes: List[str] = None,
                 start_timestamp: Union[str, pd.Timestamp] = None,
                 end_timestamp: Union[str, pd.Timestamp] = None,
                 provider: str = None,
                 level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
                 trader_name: str = None,
                 real_time: bool = False,
                 kdata_use_begin_time: bool = False,
                 draw_result: bool = True) -> None:
        """
        Configure the trader, its account service and selectors, then call ``on_start``.

        :param entity_ids: entity ids forwarded to ``init_selectors``
        :param exchanges: exchanges forwarded to ``init_selectors``; copied so instances never
            share the default list
        :param codes: codes forwarded to ``init_selectors``
        :param start_timestamp: required; converted with ``to_pd_timestamp``
        :param end_timestamp: required; converted with ``to_pd_timestamp``; in real-time mode it
            must not lie in the past
        :param provider: passed to the account service
        :param level: converted to ``IntervalLevel``
        :param trader_name: defaults to the lower-cased class name
        :param real_time: enables the real-time check on ``end_timestamp``
        :param kdata_use_begin_time: stored on the instance
        :param draw_result: stored on the instance
        :raises AssertionError: if ``entity_schema`` is unset, a timestamp is missing, or
            ``end_timestamp`` is in the past in real-time mode
        """
        # class-level contract: concrete traders must define entity_schema
        assert self.entity_schema is not None

        if trader_name:
            self.trader_name = trader_name
        else:
            self.trader_name = type(self).__name__.lower()

        self.trading_signal_listeners = []

        self.selectors: List[TargetSelector] = []

        self.entity_ids = entity_ids

        # copy: the default list object is shared by every call, so storing it directly would let
        # one instance's mutation leak into all others
        self.exchanges = list(exchanges) if exchanges is not None else None
        self.codes = codes

        self.provider = provider
        # make sure the min level selector correspond to the provider and level
        self.level = IntervalLevel(level)
        self.real_time = real_time

        if not (start_timestamp and end_timestamp):
            # explicit raise instead of `assert False` so the check is not stripped by `python -O`;
            # the exception type is unchanged
            raise AssertionError('start_timestamp and end_timestamp are both required')
        self.start_timestamp = to_pd_timestamp(start_timestamp)
        self.end_timestamp = to_pd_timestamp(end_timestamp)

        if real_time:
            logger.info(
                'real_time mode, end_timestamp should be future,you could set it big enough for running forever'
            )
            if not self.end_timestamp >= now_pd_timestamp():
                raise AssertionError('end_timestamp must not be in the past in real_time mode')

        self.kdata_use_begin_time = kdata_use_begin_time
        self.draw_result = draw_result

        self.account_service = SimAccountService(
            trader_name=self.trader_name,
            timestamp=self.start_timestamp,
            provider=self.provider,
            level=self.level)

        self.add_trading_signal_listener(self.account_service)

        self.init_selectors(entity_ids=entity_ids,
                            entity_schema=self.entity_schema,
                            exchanges=self.exchanges,
                            codes=self.codes,
                            start_timestamp=self.start_timestamp,
                            end_timestamp=self.end_timestamp)

        self.selectors_comparator = self.init_selectors_comparator()

        # distinct levels of all selectors, ascending and descending
        self.trading_level_asc = sorted(
            {IntervalLevel(selector.level) for selector in self.selectors})
        self.trading_level_desc = self.trading_level_asc[::-1]

        self.targets_slot: TargetsSlot = TargetsSlot()

        self.session = get_db_session('zvt', 'business')
        trader = get_trader(session=self.session,
                            trader_name=self.trader_name,
                            return_type='domain',
                            limit=1)

        if trader:
            self.logger.warning(
                "trader:{} has run before,old result would be deleted".format(
                    self.trader_name))
            self.session.query(business.Trader).filter(
                business.Trader.trader_name == self.trader_name).delete()
            self.session.commit()
        self.on_start()
Пример #12
0
                assert a == entity_type
                ccxt_exchange = CCXTAccount.get_ccxt_exchange(exchange_str=exchange)

                if not ccxt_exchange:
                    raise Exception('{} not support'.format(exchange))

                orderbook = ccxt_exchange.fetch_order_book(code)

                bid = orderbook['bids'][0][0] if len(orderbook['bids']) > 0 else None
                ask = orderbook['asks'][0][0] if len(orderbook['asks']) > 0 else None
                entity_id = f'coin_{exchange}_{code}'
                result[entity_id] = (bid, ask)

    return result


if __name__ == '__main__':
    # one-off maintenance script: write each index's name onto the index_money_flow rows
    # that carry the same code
    money_flow_session = get_db_session(provider='sina', data_schema=IndexMoneyFlow)

    entities = get_entities(entity_type='index',
                            return_type='domain', provider='sina',
                            # only fetch concept and industry blocks
                            filters=[Index.category.in_(
                                [StockCategory.industry.value, StockCategory.concept.value])])

    # bound parameters instead of str.format: a name or code containing a quote character no
    # longer breaks (or injects into) the statement
    update_sql = 'UPDATE index_money_flow SET name=:name where code=:code'

    for entity in entities:
        money_flow_session.execute(update_sql, {'name': entity.name, 'code': entity.code})
        money_flow_session.commit()
Пример #13
0
def get_data(data_schema,
             ids: List[str] = None,
             entity_ids: List[str] = None,
             entity_id: str = None,
             codes: List[str] = None,
             code: str = None,
             level: Union[IntervalLevel, str] = None,
             provider: str = None,
             columns: List = None,
             return_type: str = 'df',
             start_timestamp: Union[pd.Timestamp, str] = None,
             end_timestamp: Union[pd.Timestamp, str] = None,
             filters: List = None,
             session: Session = None,
             order=None,
             limit: int = None,
             index: Union[str, list] = None,
             time_field: str = 'timestamp'):
    """
    Query ``data_schema`` with the given filters and return the rows in the requested form.

    :param data_schema: schema class to query; required
    :param ids: restrict to rows whose ``id`` is in this list
    :param entity_ids: restrict to rows whose ``entity_id`` is in this list
    :param entity_id: restrict to rows with exactly this ``entity_id``
    :param codes: restrict to rows whose ``code`` is in this list
    :param code: restrict to rows with exactly this ``code``
    :param level: level filter, applied only when it can be applied to the schema (best effort)
    :param provider: required, must be one of ``global_providers``
    :param columns: columns to select, as schema attributes or attribute names; the time column is
        always added. The caller's list is not modified.
    :param return_type: 'df', 'domain' or 'dict'; any other value yields None
    :param start_timestamp: forwarded to ``common_filter``
    :param end_timestamp: forwarded to ``common_filter``
    :param filters: forwarded to ``common_filter``
    :param session: optional existing session; one is obtained when falsy
    :param order: forwarded to ``common_filter``
    :param limit: forwarded to ``common_filter``
    :param index: when given and the DataFrame is not null, passed to ``index_df``
    :param time_field: name of the schema attribute used as time column
    :return: DataFrame for 'df', list of query results for 'domain', list of each result's
        ``__dict__`` for 'dict'
    """
    assert data_schema is not None
    assert provider is not None
    assert provider in global_providers

    if not session:
        session = get_db_session(provider=provider, data_schema=data_schema)

    # plain attribute lookup; no need to eval a formatted string
    time_col = getattr(data_schema, time_field)

    if columns:
        # work on a copy so appending the time column never leaks into the caller's list
        columns = list(columns)

        # support str
        if isinstance(columns[0], str):
            resolved = []
            for col in columns:
                assert isinstance(col, str)
                resolved.append(getattr(data_schema, col))
            columns = resolved

        # make sure get timestamp
        if time_col not in columns:
            columns.append(time_col)

        query = session.query(*columns)
    else:
        query = session.query(data_schema)

    if entity_id:
        query = query.filter(data_schema.entity_id == entity_id)
    if entity_ids:
        query = query.filter(data_schema.entity_id.in_(entity_ids))
    if code:
        query = query.filter(data_schema.code == code)
    if codes:
        query = query.filter(data_schema.code.in_(codes))
    if ids:
        query = query.filter(data_schema.id.in_(ids))

    # we always store different level in different schema,the level param is not useful now
    if level:
        try:
            # some schema has no level,just ignore it
            data_schema.level
            if isinstance(level, IntervalLevel):
                level = level.value
            query = query.filter(data_schema.level == level)
        except Exception:
            # deliberate best effort: the level filter is skipped when it cannot be applied
            pass

    query = common_filter(query, data_schema=data_schema, start_timestamp=start_timestamp,
                          end_timestamp=end_timestamp, filters=filters, order=order, limit=limit,
                          time_field=time_field)

    if return_type == 'df':
        df = pd.read_sql(query.statement, query.session.bind)
        if pd_is_not_null(df):
            if index:
                df = index_df(df, drop=False, index=index, time_field=time_field)
        return df
    elif return_type == 'domain':
        return query.all()
    elif return_type == 'dict':
        return [item.__dict__ for item in query.all()]
Пример #14
0
def df_to_db(df: pd.DataFrame,
             data_schema: DeclarativeMeta,
             provider: str,
             force_update: bool = False) -> object:
    """
    store the df to db

    Only the columns known to ``data_schema`` are written. Rows are processed in chunks of 5000.
    With ``force_update`` the rows with the same ids are deleted first; otherwise rows whose id
    is already stored are dropped from the chunk before it is appended.

    :param df: data to store; nothing happens when it is null according to ``pd_is_not_null``
    :type df: pd.DataFrame
    :param data_schema: schema whose table receives the rows
    :type data_schema: DeclarativeMeta
    :param provider: provider used to resolve the engine and session
    :type provider: str
    :param force_update: replace existing rows instead of skipping them
    :type force_update: bool
    :return: None
    :rtype: None
    """
    if not pd_is_not_null(df):
        return

    db_engine = get_db_engine(provider, data_schema=data_schema)

    schema_cols = set(get_schema_columns(data_schema))
    # a list, not a set: pandas 2.x refuses sets as indexers. dict.fromkeys de-duplicates the
    # labels like the former set intersection did
    cols = [col for col in dict.fromkeys(df.columns.tolist()) if col in schema_cols]

    if not cols:
        print('wrong cols')
        return

    df = df[cols]

    size = len(df)
    sub_size = 5000

    # the session is only needed for deleting, so obtain it once and only in that mode
    session = get_db_session(provider=provider, data_schema=data_schema) if force_update else None

    # max(size, 1) keeps the former "at least one pass" behaviour
    for start in range(0, max(size, 1), sub_size):
        df_current = df.iloc[start:start + sub_size]
        if force_update:
            ids = df_current["id"].tolist()
            # ORM delete with bound values instead of a string-built statement, so ids containing
            # quotes cannot break or inject into the SQL. synchronize_session=False: nothing has
            # been loaded into this session here that would need syncing
            session.query(data_schema).filter(data_schema.id.in_(ids)).delete(synchronize_session=False)
            session.commit()

        else:
            current = get_data(data_schema=data_schema, columns=[data_schema.id], provider=provider,
                               ids=df_current['id'].tolist())
            if pd_is_not_null(current):
                df_current = df_current[~df_current['id'].isin(current['id'])]

        df_current.to_sql(data_schema.__tablename__, db_engine, index=False, if_exists='append')
Пример #15
0
from zvdata.contract import get_db_session
from zvt.api.api import get_top_ten_holder, get_top_ten_tradable_holder
from ..context import init_test_context

init_test_context()

from typing import List

from zvt.domain import TopTenHolder, TopTenTradableHolder

# module-level session on the eastmoney 'holder' db, used by the tests in this module
session = get_db_session(
    provider='eastmoney',
    db_name='holder',
)  # type: sqlalchemy.orm.Session


# top ten shareholders
def test_000778_top_ten_holder():
    """The 2018-09-30 report of 000778 lists ten holders; the largest one has the expected figures."""
    report_date = '2018-09-30'
    holders: List[TopTenHolder] = get_top_ten_holder(
        session=session,
        provider='eastmoney',
        return_type='domain',
        codes=['000778'],
        end_timestamp=report_date,
        start_timestamp=report_date,
        order=TopTenHolder.shareholding_ratio.desc())
    assert len(holders) == 10

    # ordered by shareholding ratio descending, so the first entry is the largest holder
    largest = holders[0]
    assert largest.holder_name == '新兴际华集团有限公司'
    assert largest.shareholding_numbers == 1595000000
    assert largest.shareholding_ratio == 0.3996
    assert largest.change == 32080000
    assert largest.change_ratio == 0.0205
Пример #16
0
from zvdata.contract import get_db_session
from zvdata import IntervalLevel
from ..context import init_test_context

init_test_context()

from zvt.api import quote

# module-level joinquant sessions, one per kdata db used by the test below
day_k_session = get_db_session(provider='joinquant', db_name='stock_1d_kdata')  # type: sqlalchemy.orm.Session

day_1h_session = get_db_session(provider='joinquant', db_name='stock_1h_kdata')  # type: sqlalchemy.orm.Session


def test_jq_603220_kdata():
    """Fetch and print the joinquant 1-day and 1-hour kdata of stock_sh_603220."""
    session_by_level = (
        (IntervalLevel.LEVEL_1DAY, day_k_session),
        (IntervalLevel.LEVEL_1HOUR, day_1h_session),
    )
    for kdata_level, kdata_session in session_by_level:
        kdata = quote.get_kdata(entity_id='stock_sh_603220',
                                session=kdata_session,
                                level=kdata_level,
                                provider='joinquant')
        print(kdata)