Ejemplo n.º 1
0
                op.drop_index('ix_%s_%s' % (table, column.name))
            except sa.exc.OperationalError:
                pass

    op.create_table(name, *columns)
    op.execute(
        'insert into %s select %s from %s' % (
            name,
            selection_string,
            tmp_name,
        ),
    )
    op.drop_table(tmp_name)


@preprocess(engine=coerce_string_to_eng(require_exists=True))
def downgrade(engine, desired_version):
    """Downgrades the assets db at the given engine to the desired version.

    Parameters
    ----------
    engine : Engine
        An SQLAlchemy engine to the assets database. The ``preprocess``
        decorator runs this argument through
        ``coerce_string_to_eng(require_exists=True)``; presumably a string
        path/URI is also accepted and must point at an existing db -- confirm.
    desired_version : int
        The desired resulting version for the assets database.

    Notes
    -----
    NOTE(review): the lines shown here only open a transaction and reflect
    the current schema; ``desired_version`` is never used in them, so the
    actual downgrade steps appear to be truncated from this snippet --
    verify against the full source.
    """

    # Check the version of the db at the engine
    with engine.begin() as conn:
        # Bind fresh MetaData to this connection and load the definitions of
        # all tables currently present in the database.
        metadata = sa.MetaData(conn)
        metadata.reflect()
Ejemplo n.º 2
0
class AssetDBWriter(object):
    """Class used to write data to an assets db.

    Parameters
    ----------
    engine : Engine or str
        An SQLAlchemy engine or path to a SQL database.
    """
    DEFAULT_CHUNK_SIZE = SQLITE_MAX_VARIABLE_NUMBER

    @preprocess(engine=coerce_string_to_eng(require_exists=False))
    def __init__(self, engine):
        self.engine = engine

    def _real_write(self,
                    equities,
                    equity_symbol_mappings,
                    equity_supplementary_mappings,
                    futures,
                    exchanges,
                    root_symbols,
                    chunk_size):
        """Write already-normalized frames to the assets db.

        Everything happens inside a single ``engine.begin()`` transaction:
        the tables are created / version-checked via ``init_db`` first, then
        every frame that is not None is appended. Frames that are None are
        skipped.

        The write order is exchanges, root symbols, supplementary mappings,
        futures, then equities (together with their symbol mappings).
        NOTE(review): presumably this order exists so referenced rows are
        written before the rows that reference them -- confirm against the
        table schema before reordering.
        """
        with self.engine.begin() as conn:
            # Create SQL tables if they do not exist.
            self.init_db(conn)

            if exchanges is not None:
                self._write_df_to_table(
                    exchanges_table,
                    exchanges,
                    conn,
                    chunk_size,
                )

            if root_symbols is not None:
                self._write_df_to_table(
                    futures_root_symbols,
                    root_symbols,
                    conn,
                    chunk_size,
                )

            if equity_supplementary_mappings is not None:
                self._write_df_to_table(
                    equity_supplementary_mappings_table,
                    equity_supplementary_mappings,
                    conn,
                    chunk_size,
                )

            if futures is not None:
                self._write_assets(
                    'future',
                    futures,
                    conn,
                    chunk_size,
                )

            if equities is not None:
                self._write_assets(
                    'equity',
                    equities,
                    conn,
                    chunk_size,
                    mapping_data=equity_symbol_mappings,
                )

    def write_direct(self,
                     equities=None,
                     equity_symbol_mappings=None,
                     equity_supplementary_mappings=None,
                     futures=None,
                     exchanges=None,
                     root_symbols=None,
                     chunk_size=DEFAULT_CHUNK_SIZE):
        """Write asset metadata to a sqlite database in the format that it is
        stored in the assets db.

        Parameters
        ----------
        equities : pd.DataFrame, optional
            The equity metadata. The columns for this dataframe are:

              symbol : str
                  The ticker symbol for this equity.
              asset_name : str
                  The full name for this asset.
              start_date : datetime
                  The date when this asset was created.
              end_date : datetime, optional
                  The last date we have trade data for this asset.
              first_traded : datetime, optional
                  The first date we have trade data for this asset.
              auto_close_date : datetime, optional
                  The date on which to close any positions in this asset.
              exchange : str
                  The exchange where this asset is traded.

            The index of this dataframe should contain the sids.
        equity_symbol_mappings : pd.DataFrame, optional
            The symbol mapping data for ``equities``. Required whenever
            ``equities`` is provided; a ``ValueError`` is raised otherwise.
        futures : pd.DataFrame, optional
            The future contract metadata. The columns for this dataframe are:

              symbol : str
                  The ticker symbol for this futures contract.
              root_symbol : str
                  The root symbol, or the symbol with the expiration stripped
                  out.
              asset_name : str
                  The full name for this asset.
              start_date : datetime, optional
                  The date when this asset was created.
              end_date : datetime, optional
                  The last date we have trade data for this asset.
              first_traded : datetime, optional
                  The first date we have trade data for this asset.
              exchange : str
                  The exchange where this asset is traded.
              notice_date : datetime
                  The date when the owner of the contract may be forced
                  to take physical delivery of the contract's asset.
              expiration_date : datetime
                  The date when the contract expires.
              auto_close_date : datetime
                  The date when the broker will automatically close any
                  positions in this contract.
              tick_size : float
                  The minimum price movement of the contract.
              multiplier: float
                  The amount of the underlying asset represented by this
                  contract.
        exchanges : pd.DataFrame, optional
            The exchanges where assets can be traded. The columns of this
            dataframe are:

              exchange : str
                  The full name of the exchange.
              canonical_name : str
                  The canonical name of the exchange.
              country_code : str
                  The ISO 3166 alpha-2 country code of the exchange.
        root_symbols : pd.DataFrame, optional
            The root symbols for the futures contracts. The columns for this
            dataframe are:

              root_symbol : str
                  The root symbol name.
              root_symbol_id : int
                  The unique id for this root symbol.
              sector : string, optional
                  The sector of this root symbol.
              description : string, optional
                  A short description of this root symbol.
              exchange : str
                  The exchange where this root symbol is traded.
        equity_supplementary_mappings : pd.DataFrame, optional
            Additional mappings from values of arbitrary type to assets.
        chunk_size : int, optional
            The amount of rows to write to the SQLite table at once.
            This defaults to the default number of bind params in sqlite.
            If you have compiled sqlite3 with a different maximum number of
            bind params, you may want to pass that value here.

        Raises
        ------
        ValueError
            If ``equities`` is provided without ``equity_symbol_mappings``.
        """
        if equities is not None:
            equities = _generate_output_dataframe(
                equities,
                _direct_equities_defaults,
            )
            if equity_symbol_mappings is None:
                raise ValueError(
                    'equities provided with no symbol mapping data',
                )

            equity_symbol_mappings = _generate_output_dataframe(
                equity_symbol_mappings,
                _equity_symbol_mappings_defaults,
            )
            _check_symbol_mappings(
                equity_symbol_mappings,
                exchanges,
                equities['exchange'],
            )

        if equity_supplementary_mappings is not None:
            equity_supplementary_mappings = _generate_output_dataframe(
                equity_supplementary_mappings,
                _equity_supplementary_mappings_defaults,
            )

        if futures is not None:
            # Data first, defaults second -- the same order used by every
            # other ``_generate_output_dataframe`` call.
            futures = _generate_output_dataframe(futures, _futures_defaults)

        if exchanges is not None:
            exchanges = _generate_output_dataframe(
                exchanges.set_index('exchange'),
                _exchanges_defaults,
            )

        if root_symbols is not None:
            root_symbols = _generate_output_dataframe(
                root_symbols,
                _root_symbols_defaults,
            )

        # Set named identifier columns as indices, if provided.
        _normalize_index_columns_in_place(
            equities=equities,
            equity_supplementary_mappings=equity_supplementary_mappings,
            futures=futures,
            exchanges=exchanges,
            root_symbols=root_symbols,
        )

        self._real_write(
            equities=equities,
            equity_symbol_mappings=equity_symbol_mappings,
            equity_supplementary_mappings=equity_supplementary_mappings,
            futures=futures,
            exchanges=exchanges,
            root_symbols=root_symbols,
            chunk_size=chunk_size,
        )

    def write(self,
              equities=None,
              futures=None,
              exchanges=None,
              root_symbols=None,
              equity_supplementary_mappings=None,
              chunk_size=DEFAULT_CHUNK_SIZE):
        """Write asset metadata to a sqlite database.

        Parameters
        ----------
        equities : pd.DataFrame, optional
            The equity metadata. The columns for this dataframe are:

              symbol : str
                  The ticker symbol for this equity.
              asset_name : str
                  The full name for this asset.
              start_date : datetime
                  The date when this asset was created.
              end_date : datetime, optional
                  The last date we have trade data for this asset.
              first_traded : datetime, optional
                  The first date we have trade data for this asset.
              auto_close_date : datetime, optional
                  The date on which to close any positions in this asset.
              exchange : str
                  The exchange where this asset is traded.

            The index of this dataframe should contain the sids.
        futures : pd.DataFrame, optional
            The future contract metadata. The columns for this dataframe are:

              symbol : str
                  The ticker symbol for this futures contract.
              root_symbol : str
                  The root symbol, or the symbol with the expiration stripped
                  out.
              asset_name : str
                  The full name for this asset.
              start_date : datetime, optional
                  The date when this asset was created.
              end_date : datetime, optional
                  The last date we have trade data for this asset.
              first_traded : datetime, optional
                  The first date we have trade data for this asset.
              exchange : str
                  The exchange where this asset is traded.
              notice_date : datetime
                  The date when the owner of the contract may be forced
                  to take physical delivery of the contract's asset.
              expiration_date : datetime
                  The date when the contract expires.
              auto_close_date : datetime
                  The date when the broker will automatically close any
                  positions in this contract.
              tick_size : float
                  The minimum price movement of the contract.
              multiplier: float
                  The amount of the underlying asset represented by this
                  contract.
        exchanges : pd.DataFrame, optional
            The exchanges where assets can be traded. The columns of this
            dataframe are:

              exchange : str
                  The full name of the exchange.
              canonical_name : str
                  The canonical name of the exchange.
              country_code : str
                  The ISO 3166 alpha-2 country code of the exchange.

            If omitted, a frame is built from the unique ``exchange`` values
            found in ``equities``, ``futures`` and ``root_symbols``.
        root_symbols : pd.DataFrame, optional
            The root symbols for the futures contracts. The columns for this
            dataframe are:

              root_symbol : str
                  The root symbol name.
              root_symbol_id : int
                  The unique id for this root symbol.
              sector : string, optional
                  The sector of this root symbol.
              description : string, optional
                  A short description of this root symbol.
              exchange : str
                  The exchange where this root symbol is traded.
        equity_supplementary_mappings : pd.DataFrame, optional
            Additional mappings from values of arbitrary type to assets.
        chunk_size : int, optional
            The amount of rows to write to the SQLite table at once.
            This defaults to the default number of bind params in sqlite.
            If you have compiled sqlite3 with a different maximum number of
            bind params, you may want to pass that value here.

        See Also
        --------
        zipline.assets.asset_finder
        """
        if exchanges is None:
            # Derive the exchange list from every frame that carries an
            # 'exchange' column.
            exchange_names = [
                df['exchange']
                for df in (equities, futures, root_symbols)
                if df is not None
            ]
            if exchange_names:
                exchanges = pd.DataFrame({
                    'exchange': pd.concat(exchange_names).unique(),
                })

        data = self._load_data(
            equities if equities is not None else pd.DataFrame(),
            futures if futures is not None else pd.DataFrame(),
            exchanges if exchanges is not None else pd.DataFrame(),
            root_symbols if root_symbols is not None else pd.DataFrame(),
            (
                equity_supplementary_mappings
                if equity_supplementary_mappings is not None
                else pd.DataFrame()
            ),
        )
        self._real_write(
            equities=data.equities,
            equity_symbol_mappings=data.equities_mappings,
            equity_supplementary_mappings=data.equity_supplementary_mappings,
            futures=data.futures,
            root_symbols=data.root_symbols,
            exchanges=data.exchanges,
            chunk_size=chunk_size,
        )

    def _write_df_to_table(self, tbl, df, txn, chunk_size):
        """Append ``df`` to the SQL table ``tbl``.

        Datetime64 columns (dtype kind ``'M'``) are converted with
        ``_dt_to_epoch_ns`` on a copy, so the caller's frame is not modified.
        The frame's index is written under the name of the first primary-key
        column of ``tbl``, and rows are inserted ``chunk_size`` at a time.
        """
        df = df.copy()
        # Pair columns with their dtypes directly: ``Series.iteritems`` was
        # deprecated and then removed in pandas 2.0.
        for column, dtype in zip(df.columns, df.dtypes):
            if dtype.kind == 'M':
                df[column] = _dt_to_epoch_ns(df[column])

        df.to_sql(
            tbl.name,
            txn.connection,
            index=True,
            index_label=first(tbl.primary_key.columns).name,
            if_exists='append',
            chunksize=chunk_size,
        )

    def _write_assets(self,
                      asset_type,
                      assets,
                      txn,
                      chunk_size,
                      mapping_data=None):
        """Write ``assets`` of one type plus their ``asset_router`` rows.

        Parameters
        ----------
        asset_type : {'future', 'equity'}
            Selects the destination table.
        assets : pd.DataFrame
            Asset rows indexed by sid.
        txn : sa.engine.Connection
            The connection/transaction to write with.
        chunk_size : int
            Rows per insert batch.
        mapping_data : pd.DataFrame, optional
            Equity symbol mappings; required for equities, forbidden for
            futures.

        Raises
        ------
        TypeError
            If ``mapping_data`` is given for futures or missing for equities.
        ValueError
            If ``asset_type`` is not 'future' or 'equity'.
        """
        if asset_type == 'future':
            tbl = futures_contracts_table
            if mapping_data is not None:
                raise TypeError('no mapping data expected for futures')

        elif asset_type == 'equity':
            tbl = equities_table
            if mapping_data is None:
                raise TypeError('mapping data required for equities')
            # write the symbol mapping data.
            self._write_df_to_table(
                equity_symbol_mappings,
                mapping_data,
                txn,
                chunk_size,
            )

        else:
            raise ValueError(
                "asset_type must be in {'future', 'equity'}, got: %s" %
                asset_type,
            )

        self._write_df_to_table(tbl, assets, txn, chunk_size)

        # Record each sid's asset type so lookups can route to the right
        # table.
        pd.DataFrame({
            asset_router.c.sid.name: assets.index.values,
            asset_router.c.asset_type.name: asset_type,
        }).to_sql(
            asset_router.name,
            txn.connection,
            if_exists='append',
            index=False,
            chunksize=chunk_size
        )

    def _all_tables_present(self, txn):
        """
        Checks if any tables are present in the current assets database.

        Note that, despite the method's name, this returns True as soon as
        *any* table named in ``asset_db_table_names`` exists -- not only when
        all of them do.

        Parameters
        ----------
        txn : Transaction
            The open transaction to check in.

        Returns
        -------
        has_tables : bool
            True if any tables are present, otherwise False.
        """
        conn = txn.connect()
        for table_name in asset_db_table_names:
            if txn.dialect.has_table(conn, table_name):
                return True
        return False

    def init_db(self, txn=None):
        """Connect to database and create tables.

        If any asset table already existed, the stored version info is
        checked with ``check_version_info``; otherwise the current
        ``ASSET_DB_VERSION`` is written with ``write_version_info``.

        Parameters
        ----------
        txn : sa.engine.Connection, optional
            The transaction to execute in. If this is not provided, a new
            transaction will be started with the engine provided.

        Returns
        -------
        metadata : sa.MetaData
            The metadata that describes the new assets db.
        """
        with ExitStack() as stack:
            if txn is None:
                txn = stack.enter_context(self.engine.begin())

            # Must be sampled before ``create_all`` adds any missing tables.
            tables_already_exist = self._all_tables_present(txn)

            # Create the SQL tables if they do not already exist.
            metadata.create_all(txn, checkfirst=True)

            if tables_already_exist:
                check_version_info(txn, version_info, ASSET_DB_VERSION)
            else:
                write_version_info(txn, version_info, ASSET_DB_VERSION)

        # Honor the documented return value (previously ``None``).
        return metadata

    def _normalize_equities(self, equities, exchanges):
        """Normalize raw equity metadata and split out symbol mappings.

        Note: ``equities`` itself is modified in place when it carries a
        ``company_name`` or ``file_name`` column.

        Steps: copy ``company_name`` to ``asset_name`` (when the latter is
        absent); copy ``file_name`` to ``symbol``; fill defaults; split each
        symbol into ``company_symbol`` / ``share_class_symbol``; upper-case
        the symbol columns; convert date columns to epoch nanoseconds; then
        hand off to ``_split_symbol_mappings``, whose result ``_load_data``
        unpacks as ``(equities, mappings)``.
        """
        # HACK: If 'company_name' is provided, map it to asset_name
        if ('company_name' in equities.columns and
                'asset_name' not in equities.columns):
            equities['asset_name'] = equities['company_name']

        # remap 'file_name' to 'symbol' if provided
        if 'file_name' in equities.columns:
            equities['symbol'] = equities['file_name']

        equities_output = _generate_output_dataframe(
            data_subset=equities,
            defaults=_equities_defaults,
        )

        # Split symbols to company_symbols and share_class_symbols
        tuple_series = equities_output['symbol'].apply(split_delimited_symbol)
        split_symbols = pd.DataFrame(
            tuple_series.tolist(),
            columns=['company_symbol', 'share_class_symbol'],
            index=tuple_series.index
        )
        equities_output = pd.concat((equities_output, split_symbols), axis=1)

        # Upper-case all symbol data
        for col in symbol_columns:
            equities_output[col] = equities_output[col].str.upper()

        # Convert date columns to UNIX Epoch integers (nanoseconds)
        for col in ('start_date',
                    'end_date',
                    'first_traded',
                    'auto_close_date'):
            equities_output[col] = _dt_to_epoch_ns(equities_output[col])

        return _split_symbol_mappings(equities_output, exchanges)

    def _normalize_futures(self, futures):
        """Fill defaults, upper-case the symbol columns and convert every date
        column of ``futures`` to epoch nanoseconds."""
        futures_output = _generate_output_dataframe(
            data_subset=futures,
            defaults=_futures_defaults,
        )
        for col in ('symbol', 'root_symbol'):
            futures_output[col] = futures_output[col].str.upper()

        for col in ('start_date',
                    'end_date',
                    'first_traded',
                    'notice_date',
                    'expiration_date',
                    'auto_close_date'):
            futures_output[col] = _dt_to_epoch_ns(futures_output[col])

        return futures_output

    def _normalize_equity_supplementary_mappings(self, mappings):
        """Fill defaults and convert ``start_date``/``end_date`` of the
        supplementary mappings to epoch nanoseconds."""
        mappings_output = _generate_output_dataframe(
            data_subset=mappings,
            defaults=_equity_supplementary_mappings_defaults,
        )

        for col in ('start_date', 'end_date'):
            mappings_output[col] = _dt_to_epoch_ns(mappings_output[col])

        return mappings_output

    def _load_data(self,
                   equities,
                   futures,
                   exchanges,
                   root_symbols,
                   equity_supplementary_mappings):
        """
        Normalize the raw input frames for writing.

        Returns
        -------
        data : AssetData
            With fields ``equities``, ``equities_mappings``, ``futures``,
            ``exchanges``, ``root_symbols`` and
            ``equity_supplementary_mappings``.
        """
        # Set named identifier columns as indices, if provided.
        _normalize_index_columns_in_place(
            equities=equities,
            equity_supplementary_mappings=equity_supplementary_mappings,
            futures=futures,
            exchanges=exchanges,
            root_symbols=root_symbols,
        )

        futures_output = self._normalize_futures(futures)

        equity_supplementary_mappings_output = (
            self._normalize_equity_supplementary_mappings(
                equity_supplementary_mappings,
            )
        )

        exchanges_output = _generate_output_dataframe(
            data_subset=exchanges,
            defaults=_exchanges_defaults,
        )

        # Equities need the normalized exchanges to split symbol mappings.
        equities_output, equities_mappings = self._normalize_equities(
            equities,
            exchanges_output,
        )

        root_symbols_output = _generate_output_dataframe(
            data_subset=root_symbols,
            defaults=_root_symbols_defaults,
        )

        return AssetData(
            equities=equities_output,
            equities_mappings=equities_mappings,
            futures=futures_output,
            exchanges=exchanges_output,
            root_symbols=root_symbols_output,
            equity_supplementary_mappings=equity_supplementary_mappings_output,
        )
Ejemplo n.º 3
0
class AssetFinder(object):
    """
    An AssetFinder is an interface to a database of Asset metadata written by
    an ``AssetDBWriter``.

    This class provides methods for looking up assets by unique integer id or
    by symbol.  For historical reasons, we refer to these unique ids as 'sids'.

    Parameters
    ----------
    engine : str or SQLAlchemy.engine
        An engine with a connection to the asset database to use, or a string
        that can be parsed by SQLAlchemy as a URI.
    future_chain_predicates : dict
        A dict mapping future root symbol to a predicate function which accepts
    a contract as a parameter and returns whether or not the contract should be
    included in the chain.

    See Also
    --------
    :class:`zipline.assets.AssetDBWriter`
    """
    # Token used as a substitute for pickling objects that contain a
    # reference to an AssetFinder.
    PERSISTENT_TOKEN = "<AssetFinder>"

    @preprocess(engine=coerce_string_to_eng(require_exists=True))
    def __init__(self, engine, future_chain_predicates=CHAIN_PREDICATES):
        """Reflect the asset tables from ``engine`` and set up lookup caches.

        Every table named in ``asset_db_table_names`` is reflected and exposed
        as an attribute of the same name. The db's version info is checked
        against ``ASSET_DB_VERSION`` before any cache is created.
        """
        self.engine = engine

        reflected = sa.MetaData(bind=engine)
        reflected.reflect(only=asset_db_table_names)
        tables = reflected.tables
        for name in asset_db_table_names:
            setattr(self, name, tables[name])

        # Check the version info of the db for compatibility
        check_version_info(engine, self.version_info, ASSET_DB_VERSION)

        # Read-through caches: sid -> Asset and sid -> asset type. Assets in
        # the first may be shared with the equity/future lookup caches; the
        # second spares repeated queries against the asset type router.
        self._asset_cache = {}
        self._asset_type_cache = {}
        self._caches = (self._asset_cache, self._asset_type_cache)

        if future_chain_predicates is None:
            future_chain_predicates = {}
        self._future_chain_predicates = future_chain_predicates
        self._ordered_contracts = {}

        # Populated on first call to `lifetimes`.
        self._asset_lifetimes = None

    def _reset_caches(self):
        """
        Reset our asset caches.

        You probably shouldn't call this method.
        """
        # Only exists to work around the in-place mutating behavior of
        # `TradingAlgorithm._write_and_map_id_index_to_sids`; nobody else
        # should be calling it.
        for mapping in self._caches:
            mapping.clear()
        # Also discard the lazily built symbol lookup maps.
        self.reload_symbol_maps()

    def reload_symbol_maps(self):
        """Clear the in memory symbol lookup maps.

        This will make any changes to the underlying db available to the
        symbol maps.
        """
        # Drop this instance's entry from each lazyval cache so the next
        # access re-queries the db. A missing entry only means that map was
        # never built, which is fine.
        owner = type(self)
        for attr_name in ('symbol_ownership_map',
                          'fuzzy_symbol_ownership_map',
                          'equity_supplementary_map',
                          'equity_supplementary_map_by_sid'):
            try:
                del getattr(owner, attr_name)[self]
            except KeyError:
                pass

    @lazyval
    def symbol_ownership_map(self):
        """Ownership map built from ``equity_symbol_mappings``.

        Keyed by ``(company_symbol, share_class_symbol)``; each entry's value
        is the row's full ``symbol``. Built lazily and discarded by
        `reload_symbol_maps`.
        """
        def company_and_share_class(row):
            return (row.company_symbol, row.share_class_symbol)

        def full_symbol(row):
            return row.symbol

        return build_ownership_map(
            table=self.equity_symbol_mappings,
            key_from_row=company_and_share_class,
            value_from_row=full_symbol,
        )

    @lazyval
    def fuzzy_symbol_ownership_map(self):
        """Owners keyed by the concatenated company + share-class symbol.

        Every ``(company_symbol, share_class_symbol)`` entry of
        `symbol_ownership_map` is merged under the key
        ``company_symbol + share_class_symbol``. Each merged owner list is
        sorted. Built lazily and discarded by `reload_symbol_maps`.
        """
        fuzzy_mappings = {}
        for (cs, scs), owners in iteritems(self.symbol_ownership_map):
            fuzzy_mappings.setdefault(cs + scs, []).extend(owners)

        # Sort each merged list once, after all owners have been collected,
        # instead of re-sorting after every extend inside the loop above.
        for fuzzy_owners in fuzzy_mappings.values():
            fuzzy_owners.sort()
        return fuzzy_mappings

    @lazyval
    def equity_supplementary_map(self):
        """Ownership map built from ``equity_supplementary_mappings``.

        Keyed by ``(field, value)``; each entry's value is the row's
        ``value``. Built lazily and discarded by `reload_symbol_maps`.
        """
        def field_and_value(row):
            return (row.field, row.value)

        def mapped_value(row):
            return row.value

        return build_ownership_map(
            table=self.equity_supplementary_mappings,
            key_from_row=field_and_value,
            value_from_row=mapped_value,
        )

    @lazyval
    def equity_supplementary_map_by_sid(self):
        """Ownership map built from ``equity_supplementary_mappings``.

        Keyed by ``(field, sid)``; each entry's value is the row's ``value``.
        Built lazily and discarded by `reload_symbol_maps`.
        """
        def field_and_sid(row):
            return (row.field, row.sid)

        def mapped_value(row):
            return row.value

        return build_ownership_map(
            table=self.equity_supplementary_mappings,
            key_from_row=field_and_sid,
            value_from_row=mapped_value,
        )

    def lookup_asset_types(self, sids):
        """
        Retrieve asset types for a list of sids.

        Parameters
        ----------
        sids : list[int]

        Returns
        -------
        types : dict[sid -> str or None]
            Asset types for the provided sids. Sids absent from the asset
            router map to None. Both found and absent results are stored in
            the asset type cache.
        """
        found = {}
        missing = set()

        for sid in sids:
            try:
                found[sid] = self._asset_type_cache[sid]
            except KeyError:
                missing.add(sid)

        if not missing:
            return found

        router_cols = self.asset_router.c

        # Chunk a snapshot of ``missing`` so that removing resolved sids
        # below cannot interfere with the chunking.
        for assets in group_into_chunks(list(missing)):
            query = sa.select((router_cols.sid, router_cols.asset_type)).where(
                router_cols.sid.in_(map(int, assets))
            )
            for sid, type_ in query.execute().fetchall():
                missing.remove(sid)
                found[sid] = self._asset_type_cache[sid] = type_

        # Only once every chunk has been queried do we know which sids are
        # truly absent from the router. Previously this ran inside the chunk
        # loop, re-marking not-yet-queried sids as None after every chunk.
        for sid in missing:
            found[sid] = self._asset_type_cache[sid] = None

        return found

    def group_by_type(self, sids):
        """
        Bucket a list of sids by asset type.

        Parameters
        ----------
        sids : list[int]

        Returns
        -------
        types : dict[str or None -> list[int]]
            Maps each distinct asset type to the sids (drawn from ``sids``)
            of that type. Sids whose lookup failed are collected under None.
        """
        sid_to_type = self.lookup_asset_types(sids)
        # Flip sid -> type into type -> sids.
        return invert(sid_to_type)

    def retrieve_asset(self, sid, default_none=False):
        """
        Retrieve the Asset for a given sid.

        A cached None, meaning the sid was previously found to be unknown,
        raises `SidsNotFound` unless ``default_none`` is True. Cache misses
        are delegated to `retrieve_all`.
        """
        try:
            cached = self._asset_cache[sid]
        except KeyError:
            return self.retrieve_all((sid,), default_none=default_none)[0]

        if cached is None and not default_none:
            raise SidsNotFound(sids=[sid])
        return cached

    def retrieve_all(self, sids, default_none=False):
        """
        Retrieve all assets in `sids`.

        Parameters
        ----------
        sids : iterable of int
            Assets to retrieve. Any iterable is accepted, including one-shot
            iterators such as generators.
        default_none : bool
            If True, return None for failed lookups.
            If False, raise `SidsNotFound`.

        Returns
        -------
        assets : list[Asset or None]
            A list of the same length as `sids` containing Assets (or Nones)
            corresponding to the requested sids.

        Raises
        ------
        SidsNotFound
            When a requested sid is not found and default_none=False.
        """
        # ``sids`` is walked twice: once to probe the cache and once to build
        # the ordered result. Materialize it so a generator is not exhausted
        # by the first pass, which used to produce an empty result silently.
        sids = list(sids)

        hits, missing = {}, set()
        for sid in sids:
            try:
                asset = self._asset_cache[sid]
                if not default_none and asset is None:
                    # Bail early if we've already cached that we don't know
                    # about an asset.
                    raise SidsNotFound(sids=[sid])
                hits[sid] = asset
            except KeyError:
                missing.add(sid)

        # All requests were cache hits.  Return requested sids in order.
        if not missing:
            return [hits[sid] for sid in sids]

        update_hits = hits.update

        # Look up cache misses by type.
        type_to_assets = self.group_by_type(missing)

        # Sids with no asset type are unknown. Cache them as None so later
        # lookups fail fast from the cache.
        failures = {failure: None for failure in type_to_assets.pop(None, ())}
        update_hits(failures)
        self._asset_cache.update(failures)

        if failures and not default_none:
            raise SidsNotFound(sids=list(failures))

        # We don't update the asset cache here because it should already be
        # updated by `self.retrieve_equities`.
        update_hits(self.retrieve_equities(type_to_assets.pop('equity', ())))
        update_hits(
            self.retrieve_futures_contracts(type_to_assets.pop('future', ()))
        )

        # We shouldn't know about any other asset types.
        if type_to_assets:
            raise AssertionError(
                "Found asset types: %s" % list(type_to_assets.keys())
            )

        return [hits[sid] for sid in sids]

    def retrieve_equities(self, sids):
        """
        Retrieve Equity objects for a list of sids.

        Users generally shouldn't need to use this method (instead, they
        should prefer the more general/friendly `retrieve_asset` or
        `retrieve_all`), but it has a documented interface and tests because
        it's used upstream.

        Parameters
        ----------
        sids : iterable[int]

        Returns
        -------
        equities : dict[int -> Equity]

        Raises
        ------
        EquitiesNotFound
            When any requested asset isn't found.
        """
        return self._retrieve_assets(sids, self.equities, Equity)

    def _retrieve_equity(self, sid):
        """Return the single Equity for ``sid`` via `retrieve_equities`."""
        by_sid = self.retrieve_equities((sid,))
        return by_sid[sid]

    def retrieve_futures_contracts(self, sids):
        """
        Retrieve Future objects for an iterable of sids.

        Users generally shouldn't need to use this method (instead, they
        should prefer the more general/friendly `retrieve_asset` or
        `retrieve_all`), but it has a documented interface and tests because
        it's used upstream.

        Parameters
        ----------
        sids : iterable[int]

        Returns
        -------
        futures : dict[int -> Future]

        Raises
        ------
        Exception
            When any requested asset isn't found. The concrete type is chosen
            by `_retrieve_assets`. TODO(review): confirm it -- the previous
            docstring listed ``EquitiesNotFound``, copied from
            `retrieve_equities`.
        """
        return self._retrieve_assets(sids, self.futures_contracts, Future)

    @staticmethod
    def _select_assets_by_sid(asset_tbl, sids):
        # Build "SELECT * FROM asset_tbl WHERE sid IN (...)"; every sid is
        # passed through int() before being bound into the IN clause.
        int_sids = map(int, sids)
        sid_filter = asset_tbl.c.sid.in_(int_sids)
        return sa.select([asset_tbl]).where(sid_filter)

    @staticmethod
    def _select_asset_by_symbol(asset_tbl, symbol):
        # Build "SELECT * FROM asset_tbl WHERE symbol = :symbol".
        symbol_matches = asset_tbl.c.symbol == symbol
        return sa.select([asset_tbl]).where(symbol_matches)

    def _select_most_recent_symbols_chunk(self, sid_group):
        """Retrieve the most recent symbol for a set of sids.

        Parameters
        ----------
        sid_group : iterable[int]
            The sids to lookup. The length of this sequence must be less than
            or equal to SQLITE_MAX_VARIABLE_NUMBER because the sids will be
            passed in as sql bind params.

        Returns
        -------
        sel : Selectable
            The sqlalchemy selectable that will query for the most recent
            symbol for each sid. Each result row has the ``sid`` column, one
            column per name in ``symbol_columns``, and an extra
            ``max(end_date)`` column.

        Notes
        -----
        The (sid, symbol) mappings are grouped by sid and ``max(end_date)``
        is selected alongside the symbol columns. SQLite documents that when
        ``max()`` is used in an aggregate query, the bare (non-aggregated)
        columns take their values from the input row holding the maximum.
        Each result row therefore carries the symbol fields of the mapping
        with the latest end date.

        Grouping an ordered subquery without any aggregate is not enough:
        SQLite evaluates bare columns against an arbitrary row of each group
        and does not promise to honor a subquery's ORDER BY.
        """
        symbol_cols = self.equity_symbol_mappings.c

        # The columns callers actually read from each result row.
        data_cols = (symbol_cols.sid,) + tuple(
            symbol_cols[name] for name in symbol_columns
        )

        # max(end_date) makes SQLite take the bare columns above from the row
        # with the latest end date within each sid group. See
        # https://www.sqlite.org/lang_select.html#bare_columns_in_an_aggregate_query
        return sa.select(
            data_cols + (sa.func.max(symbol_cols.end_date),),
        ).where(
            symbol_cols.sid.in_(map(int, sid_group)),
        ).group_by(
            symbol_cols.sid,
        )

    def _lookup_most_recent_symbols(self, sids):
        symbols = {
            row.sid: {c: row[c] for c in symbol_columns}
            for row in concat(
                self.engine.execute(
                    self._select_most_recent_symbols_chunk(sid_group),
                ).fetchall()
                for sid_group in partition_all(
                    SQLITE_MAX_VARIABLE_NUMBER,
                    sids
                ),
            )
        }

        if len(symbols) != len(sids):
            raise EquitiesNotFound(
                sids=set(sids) - set(symbols),
                plural=True,
            )
        return symbols

    def _retrieve_asset_dicts(self, sids, asset_tbl, querying_equities):
        """Yield one dict per row of ``asset_tbl`` whose sid is in ``sids``.

        Parameters
        ----------
        sids : sequence[int]
            The sids to load. An empty sequence yields nothing.
        asset_tbl : sqlalchemy.Table
            Table to read the asset rows from.
        querying_equities : bool
            When true, each row is merged with that sid's most recent symbol
            fields (from ``_lookup_most_recent_symbols``) before conversion.

        Yields
        ------
        asset_dict : dict
            The row (possibly merged with symbol fields) after passing through
            ``_convert_asset_timestamp_fields``. Sids with no row are simply
            not yielded.
        """
        if not sids:
            return

        if querying_equities:
            # Fetched once, up front, before any asset rows are read.
            latest_symbols = self._lookup_most_recent_symbols(sids)

            def row_to_dict(row):
                return merge(row, latest_symbols[row['sid']])
        else:
            row_to_dict = dict

        for chunk in group_into_chunks(sids):
            # Load this chunk's rows from the db.
            chunk_query = self._select_assets_by_sid(asset_tbl, chunk)
            chunk_rows = chunk_query.execute().fetchall()
            for row in chunk_rows:
                yield _convert_asset_timestamp_fields(row_to_dict(row))

    def _retrieve_assets(self, sids, asset_tbl, asset_type):
        """
        Build assets of one concrete type from the rows of ``asset_tbl``.

        Every asset built here is also stored in ``self._asset_cache`` under
        its sid.

        Parameters
        ----------
        sids : iterable of int
            Asset ids to look up.
        asset_tbl : sqlalchemy.Table
            Table from which to query assets.
        asset_type : type
            Type of asset to be constructed.

        Returns
        -------
        assets : dict[int -> Asset]
            Dict mapping requested sids to the retrieved assets.

        Raises
        ------
        EquitiesNotFound
            If ``asset_type`` is an Equity subclass and some sid had no row.
        FutureContractsNotFound
            Otherwise, if some sid had no row.
        """
        # Nothing requested: skip the database entirely.
        if not sids:
            return {}

        cache = self._asset_cache
        found = {}

        is_equity = issubclass(asset_type, Equity)
        if is_equity:
            filter_kwargs = _filter_equity_kwargs
        else:
            filter_kwargs = _filter_future_kwargs

        asset_dicts = self._retrieve_asset_dicts(sids, asset_tbl, is_equity)
        for asset_dict in asset_dicts:
            sid = asset_dict['sid']
            asset = asset_type(**filter_kwargs(asset_dict))
            found[sid] = asset
            cache[sid] = asset

        # A caller asked for these sids as this concrete type, yet no row
        # resolved them. That indicates an internal inconsistency rather than
        # bad user input.
        misses = tuple(set(sids) - viewkeys(found))
        if not misses:
            return found
        if is_equity:
            raise EquitiesNotFound(sids=misses)
        raise FutureContractsNotFound(sids=misses)

    def _lookup_symbol_strict(self, symbol, as_of_date):
        """Resolve ``symbol`` exactly, by (company, share class) components.

        Raises SymbolNotFound when no equity ever held the symbol (or none
        held it on ``as_of_date``). Raises MultipleSymbolsFound when no date
        is given and several equities have held it.
        """
        # Split into company / share-class parts; the share-class part is
        # empty when the symbol has no such component.
        company_symbol, share_class_symbol = split_delimited_symbol(symbol)
        ownership_key = (company_symbol, share_class_symbol)
        try:
            owners = self.symbol_ownership_map[ownership_key]
            assert owners, 'empty owners list for %r' % symbol
        except KeyError:
            # No equity has ever held this symbol.
            raise SymbolNotFound(symbol=symbol)

        if as_of_date:
            # Pick whichever equity owned the symbol on as_of_date.
            for start, end, sid, _ in owners:
                if start <= as_of_date < end:
                    return self.retrieve_asset(sid)
            # Nobody held the ticker on that date.
            raise SymbolNotFound(symbol=symbol)

        # Without a date the answer is only unambiguous when a single equity
        # has ever held the ticker.
        if len(owners) > 1:
            raise MultipleSymbolsFound(
                symbol=symbol,
                options=set(map(
                    compose(self.retrieve_asset, attrgetter('sid')),
                    owners,
                )),
            )
        return self.retrieve_asset(owners[0].sid)

    def _lookup_symbol_fuzzy(self, symbol, as_of_date):
        """Resolve ``symbol`` to an equity, tolerating share-class delimiters.

        The symbol is upper-cased and split with ``split_delimited_symbol``.
        The company and share-class parts are concatenated to form the key
        into ``self.fuzzy_symbol_ownership_map``. Exceptions report the
        upper-cased symbol.

        Without ``as_of_date``: a lone owner is returned directly. Otherwise,
        owners whose stored symbol equals the upper-cased ``symbol`` exactly
        are collected, and the result is returned only if there is exactly
        one.

        With ``as_of_date``: only owners active on that date
        (``start <= as_of_date < end``) are considered. A single active sid
        is returned; otherwise an active owner whose split symbol matches the
        requested company and share-class parts is returned.

        Raises
        ------
        SymbolNotFound
            If no equity ever held a matching symbol, or none held one on
            ``as_of_date``.
        MultipleSymbolsFound
            If the candidates cannot be narrowed down to a single equity.
        """
        symbol = symbol.upper()
        company_symbol, share_class_symbol = split_delimited_symbol(symbol)
        try:
            owners = self.fuzzy_symbol_ownership_map[
                company_symbol + share_class_symbol
            ]
            assert owners, 'empty owners list for %r' % symbol
        except KeyError:
            # no equity has ever held a symbol matching the fuzzy symbol
            raise SymbolNotFound(symbol=symbol)

        if not as_of_date:
            if len(owners) == 1:
                # only one valid match
                return self.retrieve_asset(owners[0].sid)

            options = []
            for _, _, sid, sym in owners:
                if sym == symbol:
                    # several owners exist, so keep only the exact matches
                    options.append(self.retrieve_asset(sid))

            if len(options) == 1:
                # there was only one exact match
                return options[0]

            # there is not exactly one exact match for this fuzzy symbol.
            # NOTE(review): when there are zero exact matches, ``options`` is
            # empty and the exception lists no candidates -- confirm this is
            # intended.
            raise MultipleSymbolsFound(
                symbol=symbol,
                options=set(options),
            )

        options = {}
        for start, end, sid, sym in owners:
            if start <= as_of_date < end:
                # see which fuzzy symbols were owned on the asof date.
                options[sid] = sym

        if not options:
            # no equity owned the fuzzy symbol on the date requested
            raise SymbolNotFound(symbol=symbol)

        sid_keys = list(options.keys())
        # If there was only one owner, or there is a fuzzy and non-fuzzy which
        # map to the same sid, return it.
        if len(options) == 1:
            return self.retrieve_asset(sid_keys[0])

        for sid, sym in options.items():
            # Possible to have a scenario where multiple fuzzy matches have the
            # same date. Want to find the one where symbol and share class
            # match.
            if (company_symbol, share_class_symbol) == \
                    split_delimited_symbol(sym):
                return self.retrieve_asset(sid)

        # multiple equities held tickers matching the fuzzy ticker but
        # there are no exact matches
        raise MultipleSymbolsFound(
            symbol=symbol,
            options=[self.retrieve_asset(s) for s in sid_keys],
        )

    def lookup_symbol(self, symbol, as_of_date, fuzzy=False):
        """Resolve a ticker symbol to an equity.

        Parameters
        ----------
        symbol : str
            Ticker symbol to look up.
        as_of_date : datetime or None
            Return the equity that owned ``symbol`` at this datetime. When
            None, the lookup only succeeds if exactly one equity has ever
            owned the ticker.
        fuzzy : bool, optional
            Use fuzzy matching, which tolerates different share-class
            notations (e.g. ``BRK.A`` vs ``BRK_A`` for the ``A`` class of
            ``BRK``).

        Returns
        -------
        equity : Equity
            The owner of ``symbol`` on ``as_of_date``, or its sole owner
            ever when ``as_of_date`` is None.

        Raises
        ------
        SymbolNotFound
            If no equity has ever held the symbol.
        MultipleSymbolsFound
            If ``as_of_date`` is None and several equities have held the
            symbol, or if ``fuzzy=True`` and several candidates match on
            ``as_of_date``.
        TypeError
            If ``symbol`` is None.
        """
        if symbol is None:
            raise TypeError("Cannot lookup asset for symbol of None for "
                            "as of date %s." % as_of_date)

        if fuzzy:
            resolver = self._lookup_symbol_fuzzy
        else:
            resolver = self._lookup_symbol_strict
        return resolver(symbol, as_of_date)

    def lookup_symbols(self, symbols, as_of_date, fuzzy=False):
        """
        Resolve a sequence of ticker symbols to equities.

        Behaves like::

            [finder.lookup_symbol(s, as_of, fuzzy) for s in symbols]

        except that each distinct symbol is looked up only once; repeats
        reuse the first result.

        Parameters
        ----------
        symbols : sequence[str]
            Ticker symbols to resolve.
        as_of_date : pd.Timestamp
            Passed through to ``lookup_symbol``.
        fuzzy : bool, optional
            Passed through to ``lookup_symbol``.

        Returns
        -------
        equities : list[Equity]
            One entry per input symbol, in input order.
        """
        resolved = {}
        results = []
        for sym in symbols:
            if sym not in resolved:
                resolved[sym] = self.lookup_symbol(sym, as_of_date, fuzzy)
            results.append(resolved[sym])
        return results

    def lookup_future_symbol(self, symbol):
        """Resolve a futures contract from its symbol.

        Parameters
        ----------
        symbol : str
            Symbol of the contract to find.

        Returns
        -------
        future : Future
            The contract whose symbol is ``symbol``.

        Raises
        ------
        SymbolNotFound
            If no contract has the symbol ``symbol``.
        """
        query = self._select_asset_by_symbol(self.futures_contracts, symbol)
        row = query.execute().fetchone()

        # No matching row: the symbol is unknown.
        if not row:
            raise SymbolNotFound(symbol=symbol)
        return self.retrieve_asset(row['sid'])

    def lookup_by_supplementary_field(self, field_name, value, as_of_date):
        """Look up the equity that held ``value`` for a supplementary field.

        Parameters
        ----------
        field_name : str
            Name of the supplementary field.
        value : object
            The field value to resolve.
        as_of_date : pd.Timestamp or None
            Return the equity whose ownership period
            (``start <= as_of_date < end``) contains this date. If None, the
            lookup only succeeds when exactly one equity has ever held
            ``value``.

        Returns
        -------
        asset : Asset
            The result of ``self.retrieve_asset`` for the matching sid.

        Raises
        ------
        ValueNotFoundForField
            If no equity ever held ``value``, or none held it on
            ``as_of_date``.
        MultipleValuesFoundForField
            If ``as_of_date`` is None and more than one equity held ``value``.
        """
        try:
            owners = self.equity_supplementary_map[
                field_name,
                value,
            ]
            # Wrap the key pair in a 1-tuple so it fills the single %r. A bare
            # 2-tuple would raise TypeError instead of the intended
            # AssertionError.
            assert owners, 'empty owners list for %r' % ((field_name, value),)
        except KeyError:
            # no equity has ever held this value
            raise ValueNotFoundForField(field=field_name, value=value)

        if not as_of_date:
            if len(owners) > 1:
                # more than one equity has held this value, this is ambiguous
                # without the date
                raise MultipleValuesFoundForField(
                    field=field_name,
                    value=value,
                    options=set(map(
                        compose(self.retrieve_asset, attrgetter('sid')),
                        owners,
                    )),
                )
            # exactly one equity has ever held this value, we may resolve
            # without the date
            return self.retrieve_asset(owners[0].sid)

        for start, end, sid, _ in owners:
            if start <= as_of_date < end:
                # find the equity that owned it on the given asof date
                return self.retrieve_asset(sid)

        # no equity held the value on the given asof date
        raise ValueNotFoundForField(field=field_name, value=value)

    def get_supplementary_field(
        self,
        sid,
        field_name,
        as_of_date,
    ):
        """Get the value of a supplementary field for an asset.

        Parameters
        ----------
        sid : int
            The sid of the asset to query.
        field_name : str
            Name of the supplementary field.
        as_of_date : pd.Timestamp, None
            The value whose period (``start <= as_of_date < end``) contains
            this date is returned. If None, a value is returned only if this
            sid has only ever had one value. If None and there have been
            multiple values, MultipleValuesFoundForSid is raised.

        Returns
        -------
        value : object
            The stored value of ``field_name`` for ``sid``.

        Raises
        ------
        NoValueForSid
            If we have no values for this asset, or no value was known
            on this as_of_date.
        MultipleValuesFoundForSid
            If we have had multiple values for this asset over time, and
            None was passed for as_of_date.
        """
        try:
            periods = self.equity_supplementary_map_by_sid[
                field_name,
                sid,
            ]
            # Wrap the key pair in a 1-tuple so it fills the single %r. A bare
            # 2-tuple would raise TypeError instead of the intended
            # AssertionError.
            assert periods, 'empty periods list for %r' % ((field_name, sid),)
        except KeyError:
            raise NoValueForSid(field=field_name, sid=sid)

        if not as_of_date:
            if len(periods) > 1:
                # This equity has held more than one value, this is ambiguous
                # without the date
                raise MultipleValuesFoundForSid(
                    field=field_name,
                    sid=sid,
                    options={p.value for p in periods},
                )
            # this equity has only ever held this value, we may resolve
            # without the date
            return periods[0].value

        for start, end, _, value in periods:
            if start <= as_of_date < end:
                return value

        # Could not find a value for this sid on the as_of_date.
        raise NoValueForSid(field=field_name, sid=sid)

    def _get_contract_sids(self, root_symbol):
        """Return sids of ``root_symbol``'s contracts, ordered by sid.

        Contracts whose ``start_date`` equals ``pd.NaT.value`` are left out.
        """
        fc_cols = self.futures_contracts.c
        for_root = fc_cols.root_symbol == root_symbol
        has_start = fc_cols.start_date != pd.NaT.value
        query = sa.select((fc_cols.sid,)).where(
            for_root & has_start
        ).order_by(fc_cols.sid)
        return [row.sid for row in query.execute().fetchall()]

    def _get_root_symbol_exchange(self, root_symbol):
        """Return the exchange stored for ``root_symbol``.

        Raises SymbolNotFound when the query yields no exchange (None).
        """
        root_cols = self.futures_root_symbols.c
        query = sa.select((root_cols.exchange,)).where(
            root_cols.root_symbol == root_symbol
        )
        exchange = query.execute().scalar()

        if exchange is None:
            raise SymbolNotFound(symbol=root_symbol)
        return exchange

    def get_ordered_contracts(self, root_symbol):
        """Return (building and memoizing on first use) the OrderedContracts
        for ``root_symbol``.

        The chain predicate registered for ``root_symbol`` in
        ``self._future_chain_predicates`` is used if present, else None.
        """
        cache = self._ordered_contracts
        if root_symbol in cache:
            return cache[root_symbol]

        contract_sids = self._get_contract_sids(root_symbol)
        contracts = deque(self.retrieve_all(contract_sids))
        chain_predicate = self._future_chain_predicates.get(root_symbol, None)
        ordered = OrderedContracts(root_symbol, contracts, chain_predicate)
        cache[root_symbol] = ordered
        return ordered

    def create_continuous_future(self,
                                 root_symbol,
                                 offset,
                                 roll_style,
                                 adjustment):
        """Create, cache, and return a continuous future for ``root_symbol``.

        Three ContinuousFuture variants share the root symbol, offset, roll
        style, start/end dates (from ``get_ordered_contracts``) and exchange:
        unadjusted (``None``), ``'mul'`` and ``'add'``. All three are stored
        in ``self._asset_cache`` under their sids, and the variant matching
        ``adjustment`` is returned.

        Raises
        ------
        ValueError
            If ``adjustment`` is not one of ``ADJUSTMENT_STYLES``.
        """
        if adjustment not in ADJUSTMENT_STYLES:
            raise ValueError(
                'Invalid adjustment style {!r}. Allowed adjustment styles are '
                '{}.'.format(adjustment, list(ADJUSTMENT_STYLES))
            )

        oc = self.get_ordered_contracts(root_symbol)
        exchange = self._get_root_symbol_exchange(root_symbol)

        encode_sid = partial(
            _encode_continuous_future_sid, root_symbol, offset, roll_style
        )
        plain_sid = encode_sid(None)
        # NOTE(review): the 'mul' variant's sid is encoded with the 'div'
        # style key; presumably that is the id the sid encoder expects for
        # multiplicative adjustment -- confirm.
        mul_sid = encode_sid('div')
        add_sid = encode_sid('add')

        make_cf = partial(
            ContinuousFuture,
            root_symbol=root_symbol,
            offset=offset,
            roll_style=roll_style,
            start_date=oc.start_date,
            end_date=oc.end_date,
            exchange=exchange,
        )

        plain_cf = make_cf(sid=plain_sid)
        mul_cf = make_cf(sid=mul_sid, adjustment='mul')
        add_cf = make_cf(sid=add_sid, adjustment='add')

        for variant in (plain_cf, mul_cf, add_cf):
            self._asset_cache[variant.sid] = variant

        by_adjustment = {None: plain_cf, 'mul': mul_cf, 'add': add_cf}
        return by_adjustment[adjustment]

    def _make_sids(tblattr):
        # Class-body helper, not a method: returns a getter that selects every
        # ``sid`` from the table stored on the instance attribute named
        # ``tblattr`` and returns them as a tuple. It is deleted once the
        # properties below are built.
        def _(self):
            return tuple(map(
                itemgetter('sid'),
                sa.select((
                    getattr(self, tblattr).c.sid,
                )).execute().fetchall(),
            ))

        return _

    sids = property(
        _make_sids('asset_router'),
        doc='All the sids in the asset finder.',
    )
    equities_sids = property(
        _make_sids('equities'),
        doc='All of the sids for equities in the asset finder.',
    )
    futures_sids = property(
        _make_sids('futures_contracts'),
        doc='All of the sids for futures contracts in the asset finder.',
    )
    del _make_sids

    @lazyval
    def _symbol_lookups(self):
        """
        Symbol resolvers tried in order by ``_lookup_generic_scalar``:
        equities first, then futures.

        Every entry is called as ``lookup(symbol, as_of_date)``.
        """
        def lookup_future(symbol, _):
            # Futures symbols are looked up without a date; the second
            # argument exists only so this shares lookup_symbol's call shape.
            return self.lookup_future_symbol(symbol)

        return (self.lookup_symbol, lookup_future)

    def _lookup_generic_scalar(self,
                               asset_convertible,
                               as_of_date,
                               matches,
                               missing):
        """
        Convert asset_convertible to an asset.

        On success, append to matches.
        On failure, append to missing.

        Assets are appended as-is. Integral values are resolved as sids via
        ``retrieve_asset``, with SidsNotFound treated as a miss. Strings are
        tried against each resolver in ``self._symbol_lookups``, with
        SymbolNotFound moving on to the next resolver.

        Raises
        ------
        NotAssetConvertible
            If ``asset_convertible`` is not an Asset, integral, or string.
        """
        if isinstance(asset_convertible, Asset):
            matches.append(asset_convertible)

        elif isinstance(asset_convertible, Integral):
            try:
                result = self.retrieve_asset(int(asset_convertible))
            except SidsNotFound:
                missing.append(asset_convertible)
                return None
            matches.append(result)

        elif isinstance(asset_convertible, string_types):
            for lookup in self._symbol_lookups:
                try:
                    matches.append(lookup(asset_convertible, as_of_date))
                    return
                except SymbolNotFound:
                    continue
            else:
                missing.append(asset_convertible)
                return None
        else:
            # Wrap in a 1-tuple so tuple inputs are formatted rather than
            # unpacked into the format string (which would raise TypeError).
            raise NotAssetConvertible(
                "Input was %s, not AssetConvertible."
                % (asset_convertible,)
            )

    def lookup_generic(self,
                       asset_convertible_or_iterable,
                       as_of_date):
        """
        Convert an AssetConvertible, or an iterable of them, into assets.

        Intended as a convenience for user-facing APIs that accept several
        kinds of input; internal code that already knows its input types
        should call the specific lookups instead.

        Returns a pair ``(result, missing)``. For a scalar AssetConvertible,
        ``result`` is the single resolved asset; a scalar that cannot be
        resolved raises SidsNotFound (values with ``__int__``) or
        SymbolNotFound. A ContinuousFuture is returned unchanged. For an
        iterable, ``result`` is the list of resolved assets and ``missing``
        lists the values that could not be resolved.
        """
        matches = []
        missing = []
        target = asset_convertible_or_iterable

        # A single AssetConvertible resolves to a single asset.
        if isinstance(target, AssetConvertible):
            self._lookup_generic_scalar(
                asset_convertible=target,
                as_of_date=as_of_date,
                matches=matches,
                missing=missing,
            )
            if matches:
                return matches[0], missing
            if hasattr(target, '__int__'):
                raise SidsNotFound(sids=[target])
            raise SymbolNotFound(symbol=target)

        # Continuous futures are passed straight through.
        if isinstance(target, ContinuousFuture):
            return target, missing

        # Anything else must be an iterable of convertibles.
        try:
            iterator = iter(target)
        except TypeError:
            raise NotAssetConvertible(
                "Input was not a AssetConvertible "
                "or iterable of AssetConvertible."
            )

        for element in iterator:
            if isinstance(element, ContinuousFuture):
                matches.append(element)
            else:
                self._lookup_generic_scalar(
                    element, as_of_date, matches, missing,
                )
        return matches, missing

    def map_identifier_index_to_sids(self, index, as_of_date):
        """
        This method is for use in sanitizing a user's DataFrame or Panel
        inputs.

        Takes the given index of identifiers, checks their types, builds assets
        if necessary, and returns the sids that correspond to the input index.

        Parameters
        ----------
        index : Sequence
            An indexable sequence containing ints, strings, or Assets. The
            type of the first element is assumed to hold for all elements.
        as_of_date : pandas.Timestamp
            A date to be used to resolve any dual-mapped symbols

        Returns
        -------
        sids : Sequence or list
            ``index`` itself when it is empty or already holds integer sids;
            otherwise a list of the integer sids corresponding to the input
            index.

        Raises
        ------
        MapAssetIdentifierIndexError
            If the first identifier is not AssetConvertible.
        ValueError
            If any identifier could not be resolved to an asset.
        """
        # An empty index has nothing to map (and no first element to inspect),
        # so it passes through unchanged, like an index of sids.
        if len(index) == 0:
            return index

        # This method assumes that the type of the objects in the index is
        # consistent and can, therefore, be taken from the first identifier
        first_identifier = index[0]

        # Ensure that input is AssetConvertible (integer, string, or Asset)
        if not isinstance(first_identifier, AssetConvertible):
            raise MapAssetIdentifierIndexError(obj=first_identifier)

        # If sids are provided, no mapping is necessary
        if isinstance(first_identifier, Integral):
            return index

        # Look up all Assets for mapping
        matches = []
        missing = []
        for identifier in index:
            self._lookup_generic_scalar(identifier, as_of_date,
                                        matches, missing)

        if missing:
            raise ValueError("Missing assets for identifiers: %s" % missing)

        # Return a list of the sids of the found assets
        return [asset.sid for asset in matches]

    def _compute_asset_lifetimes(self):
        """
        Compute a recarray of asset lifetimes.

        Returns
        -------
        lifetimes : np.recarray
            One record per row of the equities table, with int64 fields
            ``sid``, ``start`` and ``end`` taken from the ``sid``,
            ``start_date`` and ``end_date`` columns. A missing start becomes 0
            and a missing end becomes the maximum int64 value.
        """
        equities_cols = self.equities.c
        buf = np.array(
            tuple(
                sa.select((
                    equities_cols.sid,
                    equities_cols.start_date,
                    equities_cols.end_date,
                )).execute(),
            ), dtype='<f8',  # use doubles so we get NaNs
        )
        lifetimes = np.recarray(
            buf=buf,
            shape=(len(buf),),
            dtype=[
                ('sid', '<f8'),
                ('start', '<f8'),
                ('end', '<f8')
            ],
        )
        start = lifetimes.start
        end = lifetimes.end
        missing_end = np.isnan(end)
        start[np.isnan(start)] = 0  # convert missing starts to 0
        # Temporarily zero the missing ends. NaN has no int64 representation,
        # and neither does float(INT64_MAX), which rounds up to 2**63 and
        # overflows the cast below. The exact sentinel is written after the
        # cast instead.
        end[missing_end] = 0
        # Cast the results back down to int.
        out = lifetimes.astype([
            ('sid', '<i8'),
            ('start', '<i8'),
            ('end', '<i8'),
        ])
        # Convert missing ends to INT64_MAX. Use an explicit int64 so the
        # sentinel does not depend on the platform's default integer width.
        out['end'][missing_end] = np.iinfo(np.int64).max
        return out

    def lifetimes(self, dates, include_start_date):
        """
        Build a boolean DataFrame saying which assets were alive on which
        dates.

        Parameters
        ----------
        dates : pd.DatetimeIndex
            Dates to evaluate.
        include_start_date : bool
            Whether an asset counts as alive on its own start_date.

            In backtesting, `lifetimes` typically answers "do I have data for
            this asset as of the morning of this date?". Many metrics (e.g.
            daily close) only become available at the end of an asset's first
            day, which is why the start date may need to be excluded.

        Returns
        -------
        lifetimes : pd.DataFrame
            Boolean frame indexed by `dates`, with one column per asset sid.
            `lifetimes.loc[date, asset]` is True iff `asset` existed on
            `date`. When `include_start_date` is False, the entry is False on
            the day `date == asset.start_date`.

        See Also
        --------
        numpy.putmask
        zipline.pipeline.engine.SimplePipelineEngine._compute_root_mask
        """
        # Lifetimes are computed on first use and memoized. Assets added to
        # the finder after that first call will therefore not appear here.
        if self._asset_lifetimes is None:
            self._asset_lifetimes = self._compute_asset_lifetimes()
        lifetimes = self._asset_lifetimes

        raw_dates = as_column(dates.asi8)
        if include_start_date:
            started = lifetimes.start <= raw_dates
        else:
            started = lifetimes.start < raw_dates
        alive = started & (raw_dates <= lifetimes.end)

        return pd.DataFrame(alive, index=dates, columns=lifetimes.sid)
Ejemplo n.º 4
0
class AssetFinder(object):
    """
    An AssetFinder is an interface to a database of Asset metadata written by
    an ``AssetDBWriter``.

    This class provides methods for looking up assets by unique integer id or
    by symbol.  For historical reasons, we refer to these unique ids as 'sids'.

    Parameters
    ----------
    engine : str or SQLAlchemy.engine
        An engine with a connection to the asset database to use, or a string
        that can be parsed by SQLAlchemy as a URI.
    future_chain_predicates : dict
        A dict mapping future root symbol to a predicate function which
        accepts a contract as a parameter and returns whether or not the
        contract should be included in the chain.

    See Also
    --------
    :class:`zipline.assets.AssetDBWriter`
    """
    @preprocess(engine=coerce_string_to_eng(require_exists=True))
    def __init__(self, engine, future_chain_predicates=CHAIN_PREDICATES):
        self.engine = engine
        metadata = sa.MetaData(bind=engine)
        metadata.reflect(only=asset_db_table_names)
        for table_name in asset_db_table_names:
            setattr(self, table_name, metadata.tables[table_name])

        # Check the version info of the db for compatibility
        check_version_info(engine, self.version_info, ASSET_DB_VERSION)

        # Cache for lookup of assets by sid, the objects in the asset lookup
        # may be shared with the results from equity and future lookup caches.
        #
        # The top level cache exists to minimize lookups on the asset type
        # routing.
        #
        # The caches are read through, i.e. accessing an asset through
        # retrieve_asset will populate the cache on first retrieval.
        self._asset_cache = {}
        self._asset_type_cache = {}
        self._caches = (self._asset_cache, self._asset_type_cache)

        self._future_chain_predicates = future_chain_predicates \
            if future_chain_predicates is not None else {}
        self._ordered_contracts = {}

        # Lifetimes records memoized by `lifetimes`, keyed by the frozenset of
        # country codes they were computed for.
        self._asset_lifetimes = {}

    @lazyval
    def exchange_info(self):
        """Map of exchange name -> ExchangeInfo, loaded once from the db."""
        es = sa.select(self.exchanges.c).execute().fetchall()
        return {
            name: ExchangeInfo(name, canonical_name, country_code)
            for name, canonical_name, country_code in es
        }

    @lazyval
    def symbol_ownership_map(self):
        """Merge the per-country ownership maps into a single map."""
        out = {}
        for mappings in self.symbol_ownership_maps_by_country_code.values():
            for key, ownership_periods in mappings.items():
                out.setdefault(key, []).extend(ownership_periods)

        return out

    @lazyval
    def symbol_ownership_maps_by_country_code(self):
        """Equity symbol ownership maps grouped by exchange country code."""
        sid_to_country_code = dict(
            sa.select((
                self.equities.c.sid,
                self.exchanges.c.country_code,
            )).where(self.equities.c.exchange ==
                     self.exchanges.c.exchange).execute().fetchall(), )

        return build_grouped_ownership_map(
            table=self.equity_symbol_mappings,
            key_from_row=(lambda row:
                          (row.company_symbol, row.share_class_symbol)),
            value_from_row=lambda row: row.symbol,
            group_key=lambda row: sid_to_country_code[row.sid],
        )

    @lazyval
    def country_codes(self):
        """Country codes that have at least one equity symbol mapping."""
        return tuple(self.symbol_ownership_maps_by_country_code)

    def lookup_asset_types(self, sids):
        """
        Retrieve asset types for a list of sids.

        Parameters
        ----------
        sids : list[int]

        Returns
        -------
        types : dict[sid -> str or None]
            Asset types for the provided sids.  Sids not present in the asset
            router map to None; that result is cached as well.
        """
        found = {}
        missing = set()

        for sid in sids:
            try:
                found[sid] = self._asset_type_cache[sid]
            except KeyError:
                missing.add(sid)

        if not missing:
            return found

        router_cols = self.asset_router.c

        # Chunk a snapshot of `missing` because the loop body shrinks it.
        for assets in group_into_chunks(tuple(missing)):
            query = sa.select((router_cols.sid, router_cols.asset_type)).where(
                router_cols.sid.in_([int(sid) for sid in assets]))
            for sid, type_ in query.execute().fetchall():
                missing.discard(sid)
                found[sid] = self._asset_type_cache[sid] = type_

        # Only after every chunk has been queried do we know which sids are
        # genuinely unknown; cache those as None.
        for sid in missing:
            found[sid] = self._asset_type_cache[sid] = None

        return found

    def group_by_type(self, sids):
        """
        Group a list of sids by asset type.

        Parameters
        ----------
        sids : list[int]

        Returns
        -------
        types : dict[str or None -> list[int]]
            A dict mapping unique asset types to lists of sids drawn from sids.
            If we fail to look up an asset, we assign it a key of None.
        """
        return invert(self.lookup_asset_types(sids))

    def retrieve_asset(self, sid, default_none=False):
        """
        Retrieve the Asset for a given sid.

        Parameters
        ----------
        sid : int
            The sid to look up.
        default_none : bool
            If True, return None for an unknown sid instead of raising.

        Raises
        ------
        SidsNotFound
            When the sid is unknown and default_none is False.
        """
        try:
            asset = self._asset_cache[sid]
            if asset is None and not default_none:
                raise SidsNotFound(sids=[sid])
            return asset
        except KeyError:
            return self.retrieve_all((sid, ), default_none=default_none)[0]

    def retrieve_all(self, sids, default_none=False):
        """
        Retrieve all assets in `sids`.

        Parameters
        ----------
        sids : iterable of int
            Assets to retrieve.
        default_none : bool
            If True, return None for failed lookups.
            If False, raise `SidsNotFound`.

        Returns
        -------
        assets : list[Asset or None]
            A list of the same length as `sids` containing Assets (or Nones)
            corresponding to the requested sids.

        Raises
        ------
        SidsNotFound
            When a requested sid is not found and default_none=False.
        """
        sids = list(sids)
        hits, missing = {}, set()
        for sid in sids:
            try:
                asset = self._asset_cache[sid]
                if not default_none and asset is None:
                    # Bail early if we've already cached that we don't know
                    # about an asset.
                    raise SidsNotFound(sids=[sid])
                hits[sid] = asset
            except KeyError:
                missing.add(sid)

        # All requests were cache hits.  Return requested sids in order.
        if not missing:
            return [hits[sid] for sid in sids]

        update_hits = hits.update

        # Look up cache misses by type.
        type_to_assets = self.group_by_type(missing)

        # Handle failures
        failures = {failure: None for failure in type_to_assets.pop(None, ())}
        update_hits(failures)
        self._asset_cache.update(failures)

        if failures and not default_none:
            raise SidsNotFound(sids=list(failures))

        # We don't update the asset cache here because it should already be
        # updated by `self.retrieve_equities`.
        update_hits(self.retrieve_equities(type_to_assets.pop('equity', ())))
        update_hits(
            self.retrieve_futures_contracts(type_to_assets.pop('future', ())))

        # We shouldn't know about any other asset types.
        if type_to_assets:
            raise AssertionError("Found asset types: %s" %
                                 list(type_to_assets.keys()))

        return [hits[sid] for sid in sids]

    def retrieve_equities(self, sids):
        """
        Retrieve Equity objects for a list of sids.

        Users generally shouldn't need to call this method (instead, they
        should prefer the more general/friendly `retrieve_all`), but it has a
        documented interface and tests because it's used upstream.

        Parameters
        ----------
        sids : iterable[int]

        Returns
        -------
        equities : dict[int -> Equity]

        Raises
        ------
        EquitiesNotFound
            When any requested asset isn't found.
        """
        return self._retrieve_assets(sids, self.equities, Equity)

    def _retrieve_equity(self, sid):
        return self.retrieve_equities((sid, ))[sid]

    def retrieve_futures_contracts(self, sids):
        """
        Retrieve Future objects for an iterable of sids.

        Users generally shouldn't need to call this method (instead, they
        should prefer the more general/friendly `retrieve_all`), but it has a
        documented interface and tests because it's used upstream.

        Parameters
        ----------
        sids : iterable[int]

        Returns
        -------
        futures : dict[int -> Future]

        Raises
        ------
        FutureContractsNotFound
            When any requested asset isn't found.
        """
        return self._retrieve_assets(sids, self.futures_contracts, Future)

    @staticmethod
    def _select_assets_by_sid(asset_tbl, sids):
        return sa.select([asset_tbl]).where(asset_tbl.c.sid.in_(map(int,
                                                                    sids)))

    @staticmethod
    def _select_asset_by_symbol(asset_tbl, symbol):
        return sa.select([asset_tbl]).where(asset_tbl.c.symbol == symbol)

    def _select_most_recent_symbols_chunk(self, sid_group):
        """Retrieve the most recent symbol for a set of sids.

        Parameters
        ----------
        sid_group : iterable[int]
            The sids to lookup. The length of this sequence must be less than
            or equal to SQLITE_MAX_VARIABLE_NUMBER because the sids will be
            passed in as sql bind params.

        Returns
        -------
        sel : Selectable
            The sqlalchemy selectable that will query for the most recent
            symbol for each sid.

        Notes
        -----
        This is implemented as an inner select of the columns of interest
        ordered by the end date of the (sid, symbol) mapping. We then group
        that inner select on the sid with no aggregations to select the last
        row per group which gives us the most recently active symbol for all
        of the sids.
        """
        cols = self.equity_symbol_mappings.c

        # These are the columns we actually want.
        data_cols = (cols.sid, ) + tuple(cols[name] for name in symbol_columns)

        # Also select the max of end_date so that all non-grouped fields take
        # on the value associated with the max end_date. The SQLite docs say
        # this:
        #
        # When the min() or max() aggregate functions are used in an aggregate
        # query, all bare columns in the result set take values from the input
        # row which also contains the minimum or maximum. Only the built-in
        # min() and max() functions work this way.
        #
        # See https://www.sqlite.org/lang_select.html#resultset, for more info.
        to_select = data_cols + (sa.func.max(cols.end_date), )

        return sa.select(to_select, ).where(cols.sid.in_(map(
            int, sid_group))).group_by(cols.sid, )

    def _lookup_most_recent_symbols(self, sids):
        """Map sid -> {symbol column: value} for the latest symbol of each."""
        return {
            row.sid: {c: row[c]
                      for c in symbol_columns}
            for row in concat(
                self.engine.execute(
                    self._select_most_recent_symbols_chunk(sid_group),
                ).fetchall() for sid_group in partition_all(
                    SQLITE_MAX_VARIABLE_NUMBER, sids))
        }

    def _retrieve_asset_dicts(self, sids, asset_tbl, querying_equities):
        """Yield one kwargs dict per row of `asset_tbl` matching `sids`."""
        if not sids:
            return

        if querying_equities:

            def mkdict(row,
                       exchanges=self.exchange_info,
                       symbols=self._lookup_most_recent_symbols(sids)):
                d = dict(row)
                d['exchange_info'] = exchanges[d.pop('exchange')]
                # we are not required to have a symbol for every asset, if
                # we don't have any symbols we will just use the empty string
                return merge(d, symbols.get(row['sid'], {}))
        else:

            def mkdict(row, exchanges=self.exchange_info):
                d = dict(row)
                d['exchange_info'] = exchanges[d.pop('exchange')]
                return d

        for assets in group_into_chunks(sids):
            # Load misses from the db.
            query = self._select_assets_by_sid(asset_tbl, assets)

            for row in query.execute().fetchall():
                yield _convert_asset_timestamp_fields(mkdict(row))

    def _retrieve_assets(self, sids, asset_tbl, asset_type):
        """
        Internal function for loading assets from a table.

        This should be the only method of `AssetFinder` that writes Assets into
        self._asset_cache.

        Parameters
        ----------
        sids : iterable of int
            Asset ids to look up.
        asset_tbl : sqlalchemy.Table
            Table from which to query assets.
        asset_type : type
            Type of asset to be constructed.

        Returns
        -------
        assets : dict[int -> Asset]
            Dict mapping requested sids to the retrieved assets.
        """
        # Fastpath for empty request.
        if not sids:
            return {}

        cache = self._asset_cache
        hits = {}

        querying_equities = issubclass(asset_type, Equity)
        filter_kwargs = (_filter_equity_kwargs
                         if querying_equities else _filter_future_kwargs)

        rows = self._retrieve_asset_dicts(sids, asset_tbl, querying_equities)
        for row in rows:
            sid = row['sid']
            asset = asset_type(**filter_kwargs(row))
            hits[sid] = cache[sid] = asset

        # If we get here, it means something in our code thought that a
        # particular sid was an equity/future and called this function with a
        # concrete type, but we couldn't actually resolve the asset.  This is
        # an error in our code, not a user-input error.
        misses = tuple(set(sids) - viewkeys(hits))
        if misses:
            if querying_equities:
                raise EquitiesNotFound(sids=misses)
            else:
                raise FutureContractsNotFound(sids=misses)
        return hits

    def _choose_symbol_ownership_map(self, country_code):
        if country_code is None:
            return self.symbol_ownership_map

        return self.symbol_ownership_maps_by_country_code.get(country_code)

    def lookup_future_symbol(self, symbol):
        """Lookup a future contract by symbol.

        Parameters
        ----------
        symbol : str
            The symbol of the desired contract.

        Returns
        -------
        future : Future
            The future contract referenced by ``symbol``.

        Raises
        ------
        SymbolNotFound
            Raised when no contract named 'symbol' is found.

        """

        data = self._select_asset_by_symbol(self.futures_contracts, symbol)\
                   .execute().fetchone()

        # If no data found, raise an exception
        if not data:
            raise SymbolNotFound(symbol=symbol)
        return self.retrieve_asset(data['sid'])

    def _get_contract_sids(self, root_symbol):
        """
        Sids of contracts for `root_symbol` that have a start date, ordered
        by auto_close_date.
        """
        fc_cols = self.futures_contracts.c

        # The NULL check must be a SQL expression: ``pd.notnull(column)`` is
        # evaluated in Python on the Column object, returns True, and so
        # silently dropped this filter.
        query = sa.select((fc_cols.sid, )).where(
            (fc_cols.root_symbol == root_symbol)
            & (fc_cols.start_date.isnot(None))
        ).order_by(fc_cols.auto_close_date)

        return [r.sid for r in query.execute().fetchall()]

    def _get_root_symbol_exchange(self, root_symbol):
        fc_cols = self.futures_root_symbols.c

        fields = (fc_cols.exchange, )

        exchange = sa.select(fields).where(
            fc_cols.root_symbol == root_symbol).execute().scalar()

        if exchange is not None:
            return exchange
        else:
            raise SymbolNotFound(symbol=root_symbol)

    def get_ordered_contracts(self, root_symbol):
        """Return (and memoize) the OrderedContracts chain for a root."""
        try:
            return self._ordered_contracts[root_symbol]
        except KeyError:
            contract_sids = self._get_contract_sids(root_symbol)
            contracts = deque(self.retrieve_all(contract_sids))
            chain_predicate = self._future_chain_predicates.get(
                root_symbol, None)
            oc = OrderedContracts(root_symbol, contracts, chain_predicate)
            self._ordered_contracts[root_symbol] = oc
            return oc

    def create_continuous_future(self, root_symbol, offset, roll_style,
                                 adjustment):
        if adjustment not in ADJUSTMENT_STYLES:
            raise ValueError(
                'Invalid adjustment style {!r}. Allowed adjustment styles are '
                '{}.'.format(adjustment, list(ADJUSTMENT_STYLES)))

        oc = self.get_ordered_contracts(root_symbol)
        exchange = self._get_root_symbol_exchange(root_symbol)

        sid = _encode_continuous_future_sid(root_symbol, offset, roll_style,
                                            None)
        mul_sid = _encode_continuous_future_sid(root_symbol, offset,
                                                roll_style, 'div')
        add_sid = _encode_continuous_future_sid(root_symbol, offset,
                                                roll_style, 'add')

        cf_template = partial(
            ContinuousFuture,
            root_symbol=root_symbol,
            offset=offset,
            roll_style=roll_style,
            start_date=oc.start_date,
            end_date=oc.end_date,
            exchange_info=self.exchange_info[exchange],
        )

        cf = cf_template(sid=sid)
        mul_cf = cf_template(sid=mul_sid, adjustment='mul')
        add_cf = cf_template(sid=add_sid, adjustment='add')

        self._asset_cache[cf.sid] = cf
        self._asset_cache[mul_cf.sid] = mul_cf
        self._asset_cache[add_cf.sid] = add_cf

        return {None: cf, 'mul': mul_cf, 'add': add_cf}[adjustment]

    def _make_sids(tblattr):
        def _(self):
            return tuple(
                map(
                    itemgetter('sid'),
                    sa.select((getattr(self,
                                       tblattr).c.sid, )).execute().fetchall(),
                ))

        return _

    sids = property(
        _make_sids('asset_router'),
        doc='All the sids in the asset finder.',
    )
    equities_sids = property(
        _make_sids('equities'),
        doc='All of the sids for equities in the asset finder.',
    )
    futures_sids = property(
        _make_sids('futures_contracts'),
        doc='All of the sids for futures contracts in the asset finder.',
    )
    del _make_sids

    def _lifetime_select(self, asset_tbl, country_codes):
        """
        Select (sid, start_date, end_date) from `asset_tbl` for assets whose
        exchange is in one of `country_codes`.
        """
        cols = asset_tbl.c
        exchange_cols = self.exchanges.c
        return sa.select((
            cols.sid,
            cols.start_date,
            cols.end_date,
        )).where((exchange_cols.exchange == cols.exchange)
                 & (exchange_cols.country_code.in_(country_codes)))

    @staticmethod
    def _nullable_to_i8(values, fill):
        """
        Convert possibly-NULL numeric values to an int64 array.

        NULLs (``None``) and NaNs are replaced by `fill` *after* the int64
        cast.  Writing INT64_MAX into a float64 array first rounds it to
        2**63, which is out of range for int64 and casts to a
        platform-dependent value (INT64_MIN on x86-64), so assets with no end
        date would look as if they never existed.
        """
        as_float = np.array(values, dtype='f8')
        missing = np.isnan(as_float)
        as_float[missing] = 0
        out = as_float.astype('i8')
        out[missing] = fill
        return out

    def _compute_asset_lifetimes(self, country_codes, include_futures=True):
        """
        Compute a ``Lifetimes`` record of asset (sid, start, end) arrays.

        Parameters
        ----------
        country_codes : iterable[str]
            Only assets listed on exchanges in these countries are included.
            An empty collection yields empty arrays.
        include_futures : bool, optional
            Whether futures contracts are included alongside equities.
            Defaults to True.

        Returns
        -------
        lifetimes : Lifetimes
            int64 ``sid``, ``start`` and ``end`` arrays.  Missing starts
            become 0 and missing ends become the int64 maximum.

        Notes
        -----
        The result is not cached here; `lifetimes` memoizes it.
        """
        sids = starts = ends = ()
        if country_codes:
            query = self._lifetime_select(self.equities, country_codes)
            if include_futures:
                query = query.union(
                    self._lifetime_select(self.futures_contracts,
                                          country_codes))
            results = query.execute().fetchall()
            if results:
                sids, starts, ends = zip(*results)

        sid = np.array(sids, dtype='i8')
        start = self._nullable_to_i8(starts, fill=0)
        end = self._nullable_to_i8(ends, fill=np.iinfo(np.int64).max)
        return Lifetimes(sid, start, end)

    def lifetimes(self, dates, include_start_date, country_codes):
        """
        Compute a DataFrame representing asset lifetimes for the specified date
        range.

        Parameters
        ----------
        dates : pd.DatetimeIndex
            The dates for which to compute lifetimes.
        include_start_date : bool
            Whether or not to count the asset as alive on its start_date.

            This is useful in a backtesting context where `lifetimes` is being
            used to signify "do I have data for this asset as of the morning of
            this date?"  For many financial metrics, (e.g. daily close), data
            isn't available for an asset until the end of the asset's first
            day.
        country_codes : iterable[str]
            The country codes to get lifetimes for.

        Returns
        -------
        lifetimes : pd.DataFrame
            A frame of dtype bool with `dates` as index and an Int64Index of
            assets as columns.  The value at `lifetimes.loc[date, asset]` will
            be True iff `asset` existed on `date`.  If `include_start_date` is
            False, then lifetimes.loc[date, asset] will be false when date ==
            asset.start_date.

        Raises
        ------
        TypeError
            If `country_codes` is a single string.

        See Also
        --------
        numpy.putmask
        zipline.pipeline.engine.SimplePipelineEngine._compute_root_mask
        """
        if isinstance(country_codes, string_types):
            raise TypeError(
                "Got string {!r} instead of an iterable of strings in "
                "AssetFinder.lifetimes.".format(country_codes), )

        # normalize to a cache-key so that we can memoize results.
        country_codes = frozenset(country_codes)

        lifetimes = self._asset_lifetimes.get(country_codes)
        if lifetimes is None:
            self._asset_lifetimes[country_codes] = lifetimes = (
                self._compute_asset_lifetimes(country_codes))

        raw_dates = as_column(dates.asi8)
        if include_start_date:
            mask = lifetimes.start <= raw_dates
        else:
            mask = lifetimes.start < raw_dates
        mask &= (raw_dates <= lifetimes.end)

        return pd.DataFrame(mask, index=dates, columns=lifetimes.sid)

    def equities_sids_for_country_code(self, country_code):
        """Return all of the equity sids for a given country.

        Parameters
        ----------
        country_code : str
            An ISO 3166 alpha-2 country code.

        Returns
        -------
        tuple[int]
            The sids of equities whose exchanges are in this country.
            Futures contracts are excluded.
        """
        sids = self._compute_asset_lifetimes(
            [country_code], include_futures=False).sid
        return tuple(sids.tolist())
Ejemplo n.º 5
0
            try:
                op.drop_index('ix_%s_%s' % (table, column.name))
            except sa.exc.OperationalError:
                pass

    op.create_table(name, *columns)
    op.execute(
        'insert into %s select %s from %s' % (
            name,
            selection_string,
            tmp_name,
        ), )
    op.drop_table(tmp_name)


@preprocess(engine=coerce_string_to_eng(require_exists=True))
def downgrade(engine, desired_version):
    """Downgrades the assets db at the given engine to the desired version.

    Parameters
    ----------
    engine : Engine
        An SQLAlchemy engine to the assets database.
    desired_version : int
        The desired resulting version for the assets database.
    """

    # Check the version of the db at the engine
    # NOTE(review): the body visible here only opens a transaction and
    # reflects the existing schema; `desired_version` is never read and no
    # version check or downgrade step appears. The snippet looks truncated;
    # confirm against the full source before relying on this function.
    with engine.begin() as conn:
        # Bind a fresh MetaData to this connection and load every table
        # currently present in the database.
        metadata = sa.MetaData(conn)
        metadata.reflect()
Ejemplo n.º 6
0
class FundamentalWriter(object):
    """Writes fundamental reports to a SQL database.

    Quarterly reports are appended to the ``fundamental`` table by
    ``quarter_report``/``write``; ``fill`` expands them into a daily,
    forward-filled ``full`` table.

    Parameters
    ----------
    engine : Engine or str
        An SQLAlchemy engine, or a string that the ``preprocess`` decorator
        presumably coerces into one (as for ``AssetDBWriter``).
    """
    # Tables whose presence ``all_tables_presents`` checks; if any is missing,
    # ``init_db`` creates the schema from ``Base.metadata``.
    table_names = ['fundamental', 'full']

    @preprocess(engine=coerce_string_to_eng(require_exists=False))
    def __init__(self, engine):
        # Every write/fill/init_db call goes through this engine.
        self.engine = engine

    def write(self, start, end):
        """
        Fetch and store every quarterly report between `start` and `end`.

        Years before 2010 are skipped and `end` is capped at today (UTC). All
        four quarters of each year in the range are requested in order, and
        each ``(year, quarter)`` pair is printed after it is written.

        Parameters
        ----------
        start : datetime-like
            First date of the range; only its year is used.
        end : datetime-like
            Last date of the range; only its year (after capping) is used.
        """
        self.init_db(self.engine)

        first_year = max(2010, int(start.strftime('%Y')))
        capped_end = min(pd.to_datetime('today', utc=True), end)
        # range() is half-open, so step one past the last year to include it.
        stop_year = int(capped_end.strftime('%Y')) + 1

        for year in range(first_year, stop_year):
            for quarter in range(1, 5):
                self.quarter_report(year, quarter)
                print((year, quarter))

    def fill(self):
        """
        Expand quarterly rows of ``fundamental`` into daily rows in ``full``.

        For each stock code, the last report per ``report_date`` is kept.
        Rows are reindexed onto every calendar day between the first and last
        report and forward-filled. The result is appended to ``full`` with the
        day stored as ``trade_date``. Re-running this appends again; it does
        not deduplicate rows already in ``full``.
        """
        self.init_db(self.engine)
        df = pd.read_sql("select * from fundamental",
                         self.engine).sort_values(['report_date', 'quarter'])
        df['trade_date'] = df['report_date'] = pd.to_datetime(
            df['report_date'])

        with click.progressbar(df.groupby('code'),
                               label='writing data',
                               item_show_func=lambda x: x[0]
                               if x else None) as bar:
            # click hides the bar when its output is not a terminal; force it.
            bar.is_hidden = False
            for _code, group in bar:
                # Rows are sorted by (report_date, quarter), so keep="last"
                # retains the latest-quarter row for a duplicated date.
                group = group.drop_duplicates(
                    subset='trade_date', keep="last").set_index('trade_date')
                sessions = pd.date_range(group.index[0], group.index[-1])
                # Carry each report forward until the next one. ``ffill()``
                # replaces the deprecated ``fillna(method='pad')``.
                daily = group.reindex(sessions).ffill()
                daily.to_sql('full',
                             self.engine,
                             if_exists='append',
                             index_label='trade_date')

    def all_tables_presents(self, txn):
        """
        Return whether every table named in ``table_names`` exists.

        Parameters
        ----------
        txn : Engine
            The writer's engine; it must expose ``connect()`` and
            ``dialect``.

        Returns
        -------
        bool
            True iff all of ``table_names`` are present.
        """
        # Use the connection as a context manager so it is returned to the
        # pool instead of being leaked.
        with txn.connect() as conn:
            return all(txn.dialect.has_table(conn, table_name)
                       for table_name in self.table_names)

    def init_db(self, txn):
        """
        Create the schema in ``Base.metadata`` if any expected table is
        missing.

        Parameters
        ----------
        txn : Engine
            The writer's engine.
        """
        if not self.all_tables_presents(txn):
            # Run the DDL in an explicit transaction so it is committed and
            # the connection is released, rather than relying on autocommit
            # of a connection that is never closed.
            with txn.begin() as conn:
                Base.metadata.create_all(conn, checkfirst=True)

    def quarter_report(self, year, quarter):
        """
        Fetch one quarter's report data and append it to ``fundamental``.

        Parameters
        ----------
        year : int
            Report year; also prefixed onto ``report_date``.
        quarter : int
            Report quarter (``write`` passes 1 through 4); stored in the
            ``quarter`` column.
        """
        # Each category name is passed to ``call_func`` with year/quarter.
        # NOTE(review): presumably each call returns a DataFrame sharing an
        # index (stock code?) so they can be joined column-wise; confirm
        # against call_func.
        func_names = [
            "report", "profit", "operation", "growth", "debtpaying", "cashflow"
        ]
        dfs = [call_func(name, year, quarter) for name in func_names]

        # Join the category frames side by side and drop rows lacking a
        # report_date, since they cannot be dated below.
        df = pd.concat(dfs, axis=1).dropna(axis=0,
                                           subset=['report_date'
                                                   ])  # drop if no report_date
        # report_date appears to be 'MM-DD'; prefix the year to get a full
        # date. '02-29' is mapped to '02-28' (in every year, leap or not),
        # presumably so non-leap years still parse.
        df['report_date'] = pd.to_datetime(
            str(year) + '-' +
            df['report_date'].apply(lambda x: x if x != '02-29' else '02-28'))
        df['quarter'] = quarter
        df.to_sql('fundamental', self.engine, if_exists='append')