Example #1
0
def get_securities_in_blocks(
        provider: str = 'eastmoney',
        categories: List[Union[str, StockCategory]] = ['concept', 'industry'],
        names=None,
        codes=None,
        ids=None):
    """
    Collect the stock ids of every stock attached to the matching index blocks.

    :param provider: provider used for both the db session and the entity query
    :param categories: block categories to match; each item is normalised
        through ``StockCategory(item).value``
    :param names: optional block names; when truthy only blocks with these
        names are considered
    :param codes: forwarded to ``get_entities`` as ``codes``
    :param ids: forwarded to ``get_entities`` as ``entity_ids``
    :return: flat list of ``stock_id`` values, one per item in each block's
        ``stocks``
    """
    session = get_db_session(provider=provider, data_schema=Index)

    category_values = [StockCategory(item).value for item in categories]
    conditions = [Index.category.in_(category_values)]

    # restrict by block name only when names were supplied
    if names:
        conditions.append(Index.name.in_(names))

    matched_blocks = get_entities(entity_ids=ids,
                                  codes=codes,
                                  entity_type='index',
                                  provider=provider,
                                  filters=conditions,
                                  return_type='domain',
                                  session=session)

    return [member.stock_id
            for matched in matched_blocks
            for member in matched.stocks]
Example #2
0
File: api.py Project: lunvs/zvt
def df_to_db(df, data_schema, provider, force=False):
    """
    Append the rows of ``df`` to the table mapped by ``data_schema``.

    :param df: DataFrame to store; it must carry an ``id`` column. Nothing is
        done when ``df_is_not_null(df)`` is false.
    :param data_schema: mapped class providing ``id`` and ``__tablename__``
    :param provider: provider used to resolve the engine and session
    :param force: when True, rows whose id appears in ``df`` are deleted first
        and then re-inserted; when False, rows whose id is already stored are
        dropped from ``df`` before inserting
    :return: None
    """
    if not df_is_not_null(df):
        return

    db_engine = get_db_engine(provider, data_schema=data_schema)

    if force:
        session = get_db_session(provider=provider, data_schema=data_schema)
        ids = df["id"].tolist()

        # Delete through the ORM with bound parameters instead of formatting
        # tuple(ids) into the SQL text: a one-element tuple renders as
        # "('x',)", which is not valid SQL, and bound parameters also take
        # care of quoting. Batches keep the number of bound parameters per
        # statement small.
        batch_size = 500
        for start in range(0, len(ids), batch_size):
            batch = ids[start:start + batch_size]
            session.query(data_schema).filter(
                data_schema.id.in_(batch)).delete(synchronize_session=False)
        session.commit()

    else:
        current = get_data(data_schema=data_schema, columns=[data_schema.id], provider=provider)
        if df_is_not_null(current):
            # keep only the rows that are not stored yet
            df = df[~df['id'].isin(current['id'])]

    df.to_sql(data_schema.__tablename__, db_engine, index=False, if_exists='append')
Example #3
0
def get_indices(provider: str = 'sina',
                block_category: Union[str, StockCategory] = 'concept',
                return_type: str = 'df') -> object:
    """
    Query the index/block entities that belong to one category.

    :param provider: provider used for both the db session and the entity query
    :type provider: str
    :param block_category: category to match, either the raw value or a
        StockCategory member (its ``value`` is used)
    :type block_category: Union[str, StockCategory]
    :param return_type: forwarded unchanged to ``get_entities``
    :type return_type: str
    :return: whatever ``get_entities`` returns for ``return_type``
    :rtype: object
    """
    if isinstance(block_category, StockCategory):
        category_value = block_category.value
    else:
        category_value = block_category

    index_session = get_db_session(provider=provider, data_schema=Index)

    return get_entities(entity_type='index',
                        provider=provider,
                        filters=[Index.category == category_value],
                        session=index_session,
                        return_type=return_type)
Example #4
0
    def __init__(self,
                 batch_size: int = 10,
                 force_update: bool = False,
                 sleeping_time: int = 10) -> None:
        """
        Store the recorder settings and open the db session.

        ``self.provider`` and ``self.data_schema`` must already be available
        on the instance; this method only checks them, it never assigns them.

        :param batch_size: batch size used when saving to the db
        :type batch_size: int
        :param force_update: whether to force-update the data even if it
            exists; set it to True if the data needs to be refreshed from the
            provider
        :type force_update: bool
        :param sleeping_time: sleeping seconds for the recording loop
        :type sleeping_time: int
        """
        assert self.provider is not None
        assert self.data_schema is not None

        self.batch_size, self.force_update, self.sleeping_time = (
            batch_size, force_update, sleeping_time)

        # session used for the db operations of this instance
        self.session = get_db_session(provider=self.provider, data_schema=self.data_schema)
Example #5
0
File: api.py Project: lunvs/zvt
def get_group(provider, data_schema, column, group_func=func.count, session=None):
    """
    Group the table by ``column`` and return the result as a DataFrame.

    :param provider: provider used to open a session when none is passed
    :param data_schema: schema used to open a session when none is passed
    :param column: column to group by
    :param group_func: aggregate applied to ``column`` and selected next to
        it; when falsy only the distinct group keys are selected
    :param session: optional session to query with
    :return: DataFrame read from the grouped query
    """
    if not session:
        session = get_db_session(provider=provider, data_schema=data_schema)

    selected = (column, group_func(column)) if group_func else (column,)
    grouped = session.query(*selected).group_by(column)

    return pd.read_sql(grouped.statement, grouped.session.bind)
    def init_entities(self):
        """
        Load the industry and concept index entities into ``self.entities``.

        The session is opened for ``self.entity_provider`` /
        ``self.entity_schema`` and kept on ``self.entity_session``; the entity
        query itself is issued with ``self.provider``.
        """
        self.entity_session = get_db_session(provider=self.entity_provider, data_schema=self.entity_schema)

        wanted_categories = [StockCategory.industry.value, StockCategory.concept.value]
        category_filter = Index.category.in_(wanted_categories)

        self.entities = get_entities(session=self.entity_session,
                                     entity_type='index',
                                     exchanges=self.exchanges,
                                     codes=self.codes,
                                     entity_ids=self.entity_ids,
                                     return_type='domain',
                                     provider=self.provider,
                                     filters=[category_filter])
Example #7
0
def get_stock_category(stock_id, session=None):
    """
    Return every Index whose ``stocks`` contain an entry with id ``stock_id``.

    :param stock_id: id matched against the entries of ``Index.stocks``
    :param session: optional session; when omitted a session on the 'meta'
        db is opened here and closed again before returning
    :return: list of matching Index objects
    """
    owns_session = not session
    if owns_session:
        session = get_db_session(db_name='meta')

    try:
        matches = session.query(Index).filter(Index.stocks.any(id=stock_id))
        return matches.all()
    finally:
        # only close what this function opened
        if owns_session:
            session.close()
Example #8
0
def stock_id_in_index(stock_id, index_id, session=None, data_schema=StockIndex, provider='eastmoney'):
    """
    Check whether a record with id ``<index_id>_<stock_id>`` exists.

    :param stock_id: second part of the composed record id
    :param index_id: first part of the composed record id
    :param session: optional session; when omitted one is opened here and
        closed again before returning
    :param data_schema: schema passed to ``data_exist``
    :param provider: provider used when a session has to be opened
    :return: the result of ``data_exist`` for the composed id
    """
    composed_id = '{}_{}'.format(index_id, stock_id)

    owns_session = not session
    if owns_session:
        session = get_db_session(provider=provider, data_schema=data_schema)

    try:
        return data_exist(session=session, schema=data_schema, id=composed_id)
    finally:
        # only close what this function opened
        if owns_session:
            session.close()
Example #9
0
    def init_entities(self):
        """
        Pick the entity session and load the entity list into ``self.entities``.

        ``self.session`` is reused when the entity provider and schema equal
        the data provider and schema; otherwise a separate session is opened.
        """
        same_source = (self.entity_provider == self.provider
                       and self.entity_schema == self.data_schema)
        self.entity_session = self.session if same_source else get_db_session(
            provider=self.entity_provider, data_schema=self.entity_schema)

        # arguments of the entity query
        query_kwargs = dict(session=self.entity_session,
                            entity_type=self.entity_type,
                            exchanges=self.exchanges,
                            entity_ids=self.entity_ids,
                            codes=self.codes,
                            return_type='domain',
                            provider=self.entity_provider)
        self.entities = get_entities(**query_kwargs)
Example #10
0
    def __init__(
            self,
            data_schema: object,
            entity_ids: List[str] = None,
            entity_type: str = 'stock',
            exchanges: List[str] = ['sh', 'sz'],
            codes: List[str] = None,
            the_timestamp: Union[str, pd.Timestamp] = None,
            start_timestamp: Union[str, pd.Timestamp] = None,
            end_timestamp: Union[str, pd.Timestamp] = None,
            columns: List = None,
            filters: List = None,
            order: object = None,
            limit: int = None,
            provider: str = 'eastmoney',
            level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
            category_field: str = 'entity_id',
            time_field: str = 'timestamp',
            trip_timestamp: bool = True,
            auto_load: bool = True,
            # child added arguments
            keep_all_timestamp: bool = False,
            fill_method: str = 'ffill',
            effective_number: int = 10) -> None:
        """
        Initialise the parent reader, then set up the factor-specific state.

        All arguments from ``data_schema`` through ``auto_load`` are forwarded
        positionally, in order, to ``super().__init__``. The remaining three
        (``keep_all_timestamp``, ``fill_method``, ``effective_number``) are
        only stored on the instance here.

        Besides that, this method:

        * sets ``self.factors`` to the ``key`` of every item in ``columns``
          (only when ``columns`` is truthy; otherwise this method does not
          assign the attribute),
        * registers the instance via ``register_instance``,
        * opens ``self.session`` for provider 'zvdata' / ``FactorDomain``,
        * sets ``self.factor_name`` to the lower-cased class name,
        * initialises ``self.depth_df`` and ``self.result_df`` to None,
        * registers itself as a data listener.
        """
        # derive the factor names before the parent initialiser runs
        if columns:
            self.factors = [item.key for item in columns]

        super().__init__(data_schema, entity_ids, entity_type, exchanges,
                         codes, the_timestamp, start_timestamp, end_timestamp,
                         columns, filters, order, limit, provider, level,
                         category_field, time_field, trip_timestamp, auto_load)

        register_instance(self.__class__, self)

        # session used for db operations
        self.session = get_db_session(provider='zvdata',
                                      data_schema=FactorDomain)

        # e.g. a class named MyFactor gets the factor name 'myfactor'
        self.factor_name = type(self).__name__.lower()

        self.keep_all_timestamp = keep_all_timestamp
        self.fill_method = fill_method
        self.effective_number = effective_number

        # both frames start empty; nothing in this method fills them
        self.depth_df: pd.DataFrame = None

        self.result_df: pd.DataFrame = None

        # the instance listens to its own data events
        self.register_data_listener(self)
Example #11
0
    def __init__(self,
                 trader_name,
                 timestamp,
                 provider='netease',
                 level=IntervalLevel.LEVEL_1DAY,
                 base_capital=1000000,
                 buy_cost=0.001,
                 sell_cost=0.001,
                 slippage=0.001):
        """
        Set up a simulated account for ``trader_name`` starting at ``timestamp``.

        If ``get_account`` finds an existing account for the trader, the
        SimAccount, Position and Order rows of that trader are deleted and
        the deletion is committed. A fresh SimAccount holding ``base_capital``
        as cash is then built and its dumped form is kept in
        ``self.latest_account``.

        :param trader_name: name identifying the trader's records
        :param timestamp: start timestamp, also used for the fresh account
        :param provider: stored on the instance
        :param level: stored on the instance
        :param base_capital: initial cash and total value of the account
        :param buy_cost: stored on the instance
        :param sell_cost: stored on the instance
        :param slippage: stored on the instance
        """
        self.base_capital = base_capital
        self.buy_cost = buy_cost
        self.sell_cost = sell_cost
        self.slippage = slippage
        self.trader_name = trader_name

        self.session = get_db_session('zvt', 'business')
        self.provider = provider
        self.level = level
        self.start_timestamp = timestamp

        previous_run = get_account(session=self.session,
                                   trader_name=self.trader_name,
                                   return_type='domain',
                                   limit=1)

        if previous_run:
            self.logger.warning(
                "trader:{} has run before,old result would be deleted".format(
                    trader_name))
            # wipe the records of the earlier run, one table at a time
            for stale_model in (SimAccount, Position, Order):
                self.session.query(stale_model).filter(
                    stale_model.trader_name == self.trader_name).delete()
            self.session.commit()

        fresh_account = SimAccount(trader_name=self.trader_name,
                                   cash=self.base_capital,
                                   positions=[],
                                   all_value=self.base_capital,
                                   value=0,
                                   closing=False,
                                   timestamp=timestamp)
        self.latest_account = sim_account_schema.dump(fresh_account).data
Example #12
0
    def init_entities(self):
        """
        Pick the entity session and load the users that are due for a refresh.

        ``self.session`` is reused when the entity provider and schema equal
        the data provider and schema; otherwise a separate session is opened.
        Only users whose ``updated_timestamp`` is unset or earlier than
        ``day_offset_today(-7)`` are loaded into ``self.entities``.
        """
        shares_storage = (self.entity_provider == self.provider
                          and self.entity_schema == self.data_schema)
        if shares_storage:
            entity_session = self.session
        else:
            entity_session = get_db_session(
                provider=self.entity_provider, data_schema=self.entity_schema)
        self.entity_session = entity_session

        # skip the users that were updated within the last 7 days
        stale_cutoff = day_offset_today(-7)
        needs_refresh = ((GithubUser.updated_timestamp < stale_cutoff) |
                         (GithubUser.updated_timestamp.is_(None)))

        # init the entity list
        self.entities = get_entities(
            session=self.entity_session,
            entity_type=self.entity_type,
            entity_ids=self.entity_ids,
            codes=self.codes,
            return_type='domain',
            provider=self.entity_provider,
            filters=[needs_refresh],
            start_timestamp=self.start_timestamp,
            end_timestamp=self.end_timestamp)
Example #13
0
def get_securities_in_blocks(block_names=['HS300_'],
                             block_category='concept',
                             provider='eastmoney'):
    """
    Collect the stock ids of every stock attached to the named index blocks.

    :param block_names: names of the blocks to look up; when empty no name
        restriction is applied and every block of ``block_category`` is used.
        The default list is only read, never modified.
    :param block_category: category the blocks must belong to
    :param provider: provider used for both the db session and the entity query
    :return: flat list of ``stock_id`` values, one per item in each block's
        ``stocks``
    """
    session = get_db_session(provider=provider, data_schema=Index)

    filters = [Index.category == block_category]

    # One IN clause matches any of the names. Building the condition by
    # OR-ing `==` clauses behind an `if name_filters:` check is unreliable,
    # because the truth value of a SQLAlchemy clause does not tell whether it
    # has been set. Skipping the clause for an empty list also avoids
    # handing a None filter to get_entities.
    if block_names:
        filters.append(Index.name.in_(list(block_names)))

    blocks = get_entities(entity_type='index',
                          # honour the caller's provider, the same one the
                          # session above was opened for
                          provider=provider,
                          filters=filters,
                          return_type='domain',
                          session=session)

    return [item.stock_id for block in blocks for item in block.stocks]
Example #14
0
def get_group(provider,
              data_schema,
              column,
              group_func=func.count,
              session=None):
    """
    Group the table by ``column`` and return the result as a DataFrame.

    :param provider: provider used to open a session when none is passed
    :param data_schema: schema whose db name (via ``get_db_name``) selects the
        store when a session has to be opened
    :param column: column to group by
    :param group_func: aggregate applied to ``column`` and selected next to
        it; when falsy only the distinct group keys are selected
    :param session: optional session; when omitted one is opened here and
        closed again before returning
    :return: DataFrame read from the grouped query
    """
    owns_session = not session
    if owns_session:
        session = get_db_session(provider=provider,
                                 store_category=get_db_name(data_schema))

    try:
        selected = [column, group_func(column)] if group_func else [column]
        grouped = session.query(*selected).group_by(column)
        return pd.read_sql(grouped.statement, grouped.session.bind)
    finally:
        # only close what this function opened
        if owns_session:
            session.close()
Example #15
0
from zvdata.domain import get_db_session
from zvdata.structs import IntervalLevel
from ..context import init_context

# Runs at import time, before the zvt.api import below.
init_context()

from zvt.api import technical

# Module-level session on the joinquant 'stock_1d_kdata' db, used by the test below.
day_k_session = get_db_session(provider='joinquant',
                               db_name='stock_1d_kdata')  # type: sqlalchemy.orm.Session

# Module-level session on the joinquant 'stock_1h_kdata' db, used by the test below.
day_1h_session = get_db_session(provider='joinquant',
                                db_name='stock_1h_kdata')  # type: sqlalchemy.orm.Session


def test_jq_603220_kdata():
    """Fetch and print joinquant kdata of stock_sh_603220, daily then hourly."""
    runs = (
        (day_k_session, IntervalLevel.LEVEL_1DAY),
        (day_1h_session, IntervalLevel.LEVEL_1HOUR),
    )
    for kdata_session, kdata_level in runs:
        df = technical.get_kdata(entity_id='stock_sh_603220',
                                 session=kdata_session,
                                 level=kdata_level,
                                 provider='joinquant')
        print(df)
Example #16
0
from ..context import init_context

# Runs at import time, before the zvt / zvdata imports below.
init_context()

from zvt.api.api import get_balance_sheet, get_income_statement, get_cash_flow_statement, get_finance_factors
from zvt.domain import FinanceFactor, BalanceSheet, IncomeStatement, CashFlowStatement
from zvdata.domain import get_db_session
from zvdata.utils.time_utils import to_time_str

# Module-level session on the eastmoney 'finance' db.
session = get_db_session(provider='eastmoney',
                         db_name='finance')  # type: sqlalchemy.orm.Session


# Bank indicators
def test_000001_finance_factor():
    correct_timestamps = [
        '2018-09-30', '2018-06-30', '2018-03-31', '2017-12-31', '2017-09-30',
        '2017-06-30', '2017-03-31', '2016-12-31', '2016-09-30', '2016-06-30',
        '2016-03-31', '2015-12-31', '2015-09-30', '2015-06-30', '2015-03-31',
        '2014-12-31', '2014-09-30', '2014-06-30', '2014-03-31', '2013-12-31',
        '2013-09-30', '2013-06-30', '2013-03-31', '2012-12-31', '2012-09-30',
        '2012-06-30', '2012-03-31', '2011-12-31', '2011-09-30', '2011-06-30',
        '2011-03-31', '2010-12-31', '2010-09-30', '2010-06-30', '2010-03-31',
        '2009-12-31', '2009-09-30', '2009-06-30', '2009-03-31', '2008-12-31',
        '2008-09-30', '2008-06-30', '2008-03-31', '2007-12-31', '2007-09-30',
        '2007-06-30', '2007-03-31', '2006-12-31', '2006-09-30', '2006-06-30',
        '2006-03-31', '2005-12-31', '2005-09-30', '2005-06-30', '2005-03-31',
        '2004-12-31', '2004-09-30', '2004-06-30', '2004-03-31', '2003-12-31',
        '2003-09-30', '2003-06-30', '2003-03-31', '2002-12-31', '2002-09-30',
        '2002-06-30', '2002-03-31', '2001-12-31', '2001-09-30', '2001-06-30',
        '2001-03-31', '2000-12-31', '2000-06-30', '1999-12-31', '1999-06-30',
Example #17
0
                    raise Exception('{} not support'.format(exchange))

                orderbook = ccxt_exchange.fetch_order_book(code)

                bid = orderbook['bids'][0][0] if len(
                    orderbook['bids']) > 0 else None
                ask = orderbook['asks'][0][0] if len(
                    orderbook['asks']) > 0 else None
                entity_id = f'coin_{exchange}_{code}'
                result[entity_id] = (bid, ask)

    return result


if __name__ == '__main__':
    # Script entry: for every industry/concept index entity from sina, write
    # the entity's name onto the index_money_flow rows that carry its code.
    money_flow_session = get_db_session(provider='sina',
                                        data_schema=IndexMoneyFlow)

    entities = get_entities(
        entity_type='index',
        return_type='domain',
        provider='sina',
        # only fetch the concept and industry blocks
        filters=[
            Index.category.in_(
                [StockCategory.industry.value, StockCategory.concept.value])
        ])

    for entity in entities:
        # NOTE(review): the statement is built by string formatting, so a name
        # or code containing a double quote would break it - consider bound
        # parameters.
        # NOTE(review): no commit is visible in this excerpt - confirm the
        # updates are committed somewhere.
        sql = 'UPDATE index_money_flow SET name="{}" where code="{}"'.format(
            entity.name, entity.code)
        money_flow_session.execute(sql)
Example #18
0
def get_data(data_schema,
             entity_ids: List[str] = None,
             entity_id: str = None,
             codes: List[str] = None,
             level: Union[IntervalLevel, str] = None,
             provider: str = None,
             columns: List = None,
             return_type: str = 'df',
             start_timestamp: Union[pd.Timestamp, str] = None,
             end_timestamp: Union[pd.Timestamp, str] = None,
             filters: List = None,
             session: Session = None,
             order=None,
             limit: int = None,
             index: str = 'timestamp',
             index_is_time: bool = True,
             time_field: str = 'timestamp'):
    """
    Query ``data_schema`` and return the rows as a DataFrame, objects or dicts.

    :param data_schema: mapped class to query; must not be None
    :param entity_ids: when truthy, keep rows whose ``entity_id`` is in the list
    :param entity_id: when truthy, keep rows with exactly this ``entity_id``
    :param codes: when truthy, keep rows whose ``code`` is in the list
    :param level: when truthy and the schema has a ``level`` attribute, keep
        rows of that level (an IntervalLevel is reduced to its ``value``);
        ignored for schemas without ``level``
    :param provider: must not be None and must be in ``global_providers``
    :param columns: optional columns to select, given as attribute names of
        ``data_schema`` or as column objects; the time column is added when
        missing. The caller's list is not modified.
    :param return_type: 'df', 'domain' or 'dict'
    :param start_timestamp: forwarded to ``common_filter``
    :param end_timestamp: forwarded to ``common_filter``
    :param filters: forwarded to ``common_filter``
    :param session: optional session; when omitted one is opened here and
        closed again before returning
    :param order: forwarded to ``common_filter``
    :param limit: forwarded to ``common_filter``
    :param index: forwarded to ``index_df`` for the 'df' return type
    :param index_is_time: forwarded to ``index_df`` for the 'df' return type
    :param time_field: name of the schema attribute used as the time column
    :return: for 'df' a DataFrame (passed through ``index_df`` when it is not
        null), for 'domain' the list from ``query.all()``, for 'dict' the
        ``__dict__`` of every result; None for any other ``return_type``
    :raises AssertionError: if ``data_schema`` or ``provider`` is None or the
        provider is not in ``global_providers``
    """
    assert data_schema is not None
    assert provider is not None
    assert provider in global_providers

    local_session = False
    if not session:
        session = get_db_session(provider=provider, data_schema=data_schema)
        local_session = True

    try:
        # plain attribute lookup; no need to evaluate generated source code
        time_col = getattr(data_schema, time_field)

        if columns:
            # Build a new list so the time column appended below never leaks
            # into the list object owned by the caller; names given as str
            # are resolved to the schema attribute of that name.
            columns = [getattr(data_schema, col) if isinstance(col, str) else col
                       for col in columns]

            if time_col not in columns:
                columns.append(time_col)
            query = session.query(*columns)
        else:
            query = session.query(data_schema)

        if entity_id:
            query = query.filter(data_schema.entity_id == entity_id)
        if codes:
            query = query.filter(data_schema.code.in_(codes))
        if entity_ids:
            query = query.filter(data_schema.entity_id.in_(entity_ids))

        # we always store different level in different schema,the level param is not useful now
        # some schema has no level, just ignore it for those
        if level and hasattr(data_schema, 'level'):
            if isinstance(level, IntervalLevel):
                level = level.value
            query = query.filter(data_schema.level == level)

        query = common_filter(query,
                              data_schema=data_schema,
                              start_timestamp=start_timestamp,
                              end_timestamp=end_timestamp,
                              filters=filters,
                              order=order,
                              limit=limit,
                              time_field=time_field)

        if return_type == 'df':
            df = pd.read_sql(query.statement, query.session.bind)
            if df_is_not_null(df):
                return index_df(df,
                                drop=False,
                                index=index,
                                index_is_time=index_is_time)
            return df
        elif return_type == 'domain':
            return query.all()
        elif return_type == 'dict':
            return [item.__dict__ for item in query.all()]
    finally:
        # only close what this function opened
        if local_session:
            session.close()
Example #19
0
from zvdata.domain import get_db_session
from zvt.api.api import get_holder_trading, get_manager_trading
from ..context import init_test_context

# Runs at import time, before the zvt.domain import below.
init_test_context()

from typing import List

from zvt.domain import HolderTrading, ManagerTrading

# Module-level session on the eastmoney 'trading' db, used by the test below.
session = get_db_session(provider='eastmoney',
                         db_name='trading')  # type: sqlalchemy.orm.Session


# Shareholder trading
def test_000778_holder_trading():
    """Check the 2018-09-30 holder trading records of code 000778."""
    query_args = dict(session=session,
                      provider='eastmoney',
                      return_type='domain',
                      codes=['000778'],
                      end_timestamp='2018-09-30',
                      start_timestamp='2018-09-30',
                      order=HolderTrading.holding_pct.desc())
    result: List[HolderTrading] = get_holder_trading(**query_args)
    assert len(result) == 6

    # first record under the descending holding_pct order
    top_holder = result[0]
    assert top_holder.holder_name == '新兴际华集团有限公司'
    assert top_holder.change_pct == 0.0205
    assert top_holder.volume == 32080000
    assert top_holder.holding_pct == 0.3996

from ..context import init_context

# Runs at import time, before the zvt / zvdata imports below.
init_context()

from zvt.api.api import get_spo_detail, get_rights_issue_detail, get_dividend_financing
from zvt.domain import SpoDetail, RightsIssueDetail, DividendFinancing
from zvdata.domain import get_db_session
from zvt.utils.time_utils import to_pd_timestamp

# Module-level session on the eastmoney 'dividend_financing' db, used by the test below.
session = get_db_session(
    provider='eastmoney',
    db_name='dividend_financing')  # type: sqlalchemy.orm.Session


# Secondary public offering (SPO) details
def test_000778_spo_detial():
    """Check the latest SPO detail of code 000778 up to 2018-09-30."""
    result = get_spo_detail(session=session,
                            provider='eastmoney',
                            return_type='domain',
                            codes=['000778'],
                            end_timestamp='2018-09-30',
                            order=SpoDetail.timestamp.desc())
    assert len(result) == 4

    # first record under the descending timestamp order
    latest: SpoDetail = result[0]
    expected_fields = (
        ('timestamp', to_pd_timestamp('2017-04-01')),
        ('spo_issues', 347600000),
        ('spo_price', 5.15),
        ('spo_raising_fund', 1766000000),
    )
    for field_name, expected_value in expected_fields:
        assert getattr(latest, field_name) == expected_value


# Rights issue details
Example #21
0
    def __init__(self,
                 entity_ids: List[str] = None,
                 exchanges: List[str] = ['sh', 'sz'],
                 codes: List[str] = None,
                 start_timestamp: Union[str, pd.Timestamp] = None,
                 end_timestamp: Union[str, pd.Timestamp] = None,
                 provider: str = 'joinquant',
                 level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
                 trader_name: str = None,
                 real_time: bool = False,
                 kdata_use_begin_time: bool = False) -> None:
        """
        Set up the trader: targets, time range, account service and selectors.

        ``self.entity_type`` must already be available on the instance (it is
        asserted, not assigned from an argument). The steps run in this order:

        1. choose the trader name (argument, or the lower-cased class name),
        2. store the target entities / exchanges / codes,
        3. store provider, level and the time range,
        4. create the ``SimAccountService`` and register it as a trading
           signal listener,
        5. initialise the selectors and their comparator, and derive the
           ascending / descending lists of selector levels,
        6. open the 'zvt' / 'business' session and delete an existing Trader
           record of the same name,
        7. call ``self.on_start()``.

        :param entity_ids: target entity ids; the first one is decoded to fill
            in a missing entity type or exchange list
        :param exchanges: target exchanges
        :param codes: target codes
        :param start_timestamp: required; converted with ``to_pd_timestamp``
        :param end_timestamp: required; converted with ``to_pd_timestamp``
        :param provider: stored and passed to the account service
        :param level: converted to ``IntervalLevel``, stored and passed to the
            account service
        :param trader_name: optional explicit trader name
        :param real_time: when True, ``end_timestamp`` must not be earlier
            than ``now_pd_timestamp()``
        :param kdata_use_begin_time: stored on the instance
        :raises AssertionError: if ``self.entity_type`` is None, if either
            timestamp is missing, or if ``real_time`` is set with an
            ``end_timestamp`` in the past
        """

        assert self.entity_type is not None

        if trader_name:
            self.trader_name = trader_name
        else:
            # fall back to the lower-cased class name
            self.trader_name = type(self).__name__.lower()

        self.trading_signal_listeners = []
        self.state_listeners = []

        self.selectors: List[TargetSelector] = []

        self.entity_ids = entity_ids

        self.exchanges = exchanges
        self.codes = codes

        # FIXME:handle this case gracefully
        # Only the first entity id is decoded; it fills in the entity type and
        # the exchange list when those are still falsy.
        if self.entity_ids:
            entity_type, exchange, code = decode_entity_id(self.entity_ids[0])
            if not self.entity_type:
                self.entity_type = entity_type
            if not self.exchanges:
                self.exchanges = [exchange]

        self.provider = provider
        # make sure the min level selector correspond to the provider and level
        self.level = IntervalLevel(level)
        self.real_time = real_time

        # both ends of the time range are mandatory
        if start_timestamp and end_timestamp:
            self.start_timestamp = to_pd_timestamp(start_timestamp)
            self.end_timestamp = to_pd_timestamp(end_timestamp)
        else:
            assert False

        if real_time:
            logger.info(
                'real_time mode, end_timestamp should be future,you could set it big enough for running forever')
            assert self.end_timestamp >= now_pd_timestamp()

        self.kdata_use_begin_time = kdata_use_begin_time

        # simulated account for this trader, starting at the start timestamp
        self.account_service = SimAccountService(trader_name=self.trader_name,
                                                 timestamp=self.start_timestamp,
                                                 provider=self.provider,
                                                 level=self.level)

        # the account service receives the trading signals
        self.add_trading_signal_listener(self.account_service)

        self.init_selectors(entity_ids=entity_ids, entity_type=self.entity_type, exchanges=self.exchanges,
                            codes=self.codes, start_timestamp=self.start_timestamp, end_timestamp=self.end_timestamp)

        self.selectors_comparator = self.init_selectors_comparator()

        # distinct levels of the registered selectors, sorted ascending
        self.trading_level_asc = list(set([IntervalLevel(selector.level) for selector in self.selectors]))
        self.trading_level_asc.sort()

        # the same levels in descending order
        self.trading_level_desc = list(self.trading_level_asc)
        self.trading_level_desc.reverse()

        self.targets_slot: TargetsSlot = TargetsSlot()

        self.session = get_db_session('zvt', 'business')
        trader = get_trader(session=self.session, trader_name=self.trader_name, return_type='domain', limit=1)

        # a trader record of the same name is removed before starting
        if trader:
            self.logger.warning("trader:{} has run before,old result would be deleted".format(self.trader_name))
            self.session.query(business.Trader).filter(business.Trader.trader_name == self.trader_name).delete()
            self.session.commit()
        self.on_start()