Пример #1
0
class SQLiteAdjustmentReader(object):
    """
    Loads adjustments based on corporate actions from a SQLite database.

    Expects data written in the format output by `SQLiteAdjustmentWriter`.

    Parameters
    ----------
    conn : str or sqlite3.Connection
        Connection from which to load data.

    See Also
    --------
    :class:`zipline.data.adjustments.SQLiteAdjustmentWriter`
    """
    # For each table, the columns whose values are dates stored as integer
    # seconds since the Unix epoch.
    _datetime_int_cols = {
        'splits': ('effective_date',),
        'mergers': ('effective_date',),
        'dividends': ('effective_date',),
        'dividend_payouts': (
            'declared_date',
            'ex_date',
            'pay_date',
            'record_date',
        ),
        'stock_dividend_payouts': (
            'declared_date',
            'ex_date',
            'pay_date',
            'record_date',
        ),
    }
    _raw_table_dtypes = {
        # We use any_integer above to be lenient in accepting different dtypes
        # from users. For our outputs, however, we always want to return the
        # same types, and any_integer turns into int32 on some numpy windows
        # builds, so specify int64 explicitly here.
        'splits':
        specialize_any_integer(SQLITE_ADJUSTMENT_COLUMN_DTYPES),
        'mergers':
        specialize_any_integer(SQLITE_ADJUSTMENT_COLUMN_DTYPES),
        'dividends':
        specialize_any_integer(SQLITE_ADJUSTMENT_COLUMN_DTYPES),
        'dividend_payouts':
        specialize_any_integer(SQLITE_DIVIDEND_PAYOUT_COLUMN_DTYPES),
        'stock_dividend_payouts':
        specialize_any_integer(SQLITE_STOCK_DIVIDEND_PAYOUT_COLUMN_DTYPES),
    }

    @preprocess(conn=coerce_string_to_conn(require_exists=True))
    def __init__(self, conn):
        self.conn = conn

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        # Close the connection whether or not an exception is propagating.
        self.close()

    def close(self):
        return self.conn.close()

    def load_adjustments(self, dates, assets, should_include_splits,
                         should_include_mergers, should_include_dividends,
                         adjustment_type):
        """
        Load collection of Adjustment objects from underlying adjustments db.

        Parameters
        ----------
        dates : pd.DatetimeIndex
            Dates for which adjustments are needed.
        assets : pd.Int64Index
            Assets for which adjustments are needed.
        should_include_splits : bool
            Whether split adjustments should be included.
        should_include_mergers : bool
            Whether merger adjustments should be included.
        should_include_dividends : bool
            Whether dividend adjustments should be included.
        adjustment_type : str
            Whether price adjustments, volume adjustments, or both, should be
            included in the output.

        Returns
        -------
        adjustments : dict[str -> dict[int -> Adjustment]]
            A dictionary containing price and/or volume adjustment mappings
            from index to adjustment objects to apply at that index.
        """
        return load_adjustments_from_sqlite(
            self.conn,
            dates,
            assets,
            should_include_splits,
            should_include_mergers,
            should_include_dividends,
            adjustment_type,
        )

    def load_pricing_adjustments(self, columns, dates, assets):
        """Load one adjustment mapping per entry in ``columns``.

        'volume' columns receive volume adjustments; every other column
        receives price adjustments. Only the adjustment types actually
        needed for ``columns`` are loaded from the database.
        """
        if 'volume' not in set(columns):
            adjustment_type = 'price'
        elif len(set(columns)) == 1:
            # ``columns`` contains only 'volume'.
            adjustment_type = 'volume'
        else:
            adjustment_type = 'all'

        adjustments = self.load_adjustments(
            dates,
            assets,
            should_include_splits=True,
            should_include_mergers=True,
            should_include_dividends=True,
            adjustment_type=adjustment_type,
        )
        price_adjustments = adjustments.get('price')
        volume_adjustments = adjustments.get('volume')

        return [
            volume_adjustments if column == 'volume' else price_adjustments
            for column in columns
        ]

    def get_adjustments_for_sid(self, table_name, sid):
        """Return ``[[Timestamp, ratio], ...]`` rows for ``sid``.

        NOTE(review): ``table_name`` is interpolated directly into the SQL
        statement (table names cannot be bound parameters), so it must come
        from trusted code, never from user input.
        """
        c = self.conn.cursor()
        try:
            adjustments_for_sid = c.execute(
                "SELECT effective_date, ratio FROM %s WHERE sid = ?"
                % table_name,
                (sid,),
            ).fetchall()
        finally:
            # Close the cursor even if the query raises.
            c.close()

        return [[Timestamp(adjustment[0], unit='s', tz='UTC'), adjustment[1]]
                for adjustment in adjustments_for_sid]

    def get_dividends_with_ex_date(self, assets, date, asset_finder):
        """Return Dividend objects for ``assets`` with ex_date after ``date``.

        ``date`` is a pd.Timestamp; its nanosecond value is converted to
        epoch seconds for comparison against the stored integer dates.
        """
        # Floor-divide so we bind an int: under Python 3, ``/`` would
        # produce a float here (this code predates the 2->3 division change).
        seconds = date.value // int(1e9)
        c = self.conn.cursor()

        divs = []
        try:
            # Query in chunks to stay under sqlite's bound-parameter limit.
            for chunk in group_into_chunks(assets):
                query = UNPAID_QUERY_TEMPLATE.format(
                    ",".join(['?' for _ in chunk]))
                params = (seconds,) + tuple(map(int, chunk))

                c.execute(query, params)

                for row in c.fetchall():
                    divs.append(Dividend(
                        asset_finder.retrieve_asset(row[0]),
                        row[1],
                        Timestamp(row[2], unit='s', tz='UTC'),
                    ))
        finally:
            # Close the cursor even if a query raises.
            c.close()

        return divs

    def get_stock_dividends_with_ex_date(self, assets, date, asset_finder):
        """Return StockDividend objects for ``assets`` with ex_date after
        ``date``.
        """
        # See get_dividends_with_ex_date: floor division keeps this an int.
        seconds = date.value // int(1e9)
        c = self.conn.cursor()

        stock_divs = []
        try:
            for chunk in group_into_chunks(assets):
                query = UNPAID_STOCK_DIVIDEND_QUERY_TEMPLATE.format(
                    ",".join(['?' for _ in chunk]))
                params = (seconds,) + tuple(map(int, chunk))

                c.execute(query, params)

                for row in c.fetchall():
                    stock_divs.append(StockDividend(
                        asset_finder.retrieve_asset(row[0]),  # asset
                        asset_finder.retrieve_asset(row[1]),  # payment_asset
                        row[2],                               # ratio
                        Timestamp(row[3], unit='s', tz='UTC'),
                    ))
        finally:
            c.close()

        return stock_divs

    def unpack_db_to_component_dfs(self, convert_dates=False):
        """Returns the set of known tables in the adjustments file in DataFrame
        form.

        Parameters
        ----------
        convert_dates : bool, optional
            By default, dates are returned in seconds since EPOCH. If
            convert_dates is True, all ints in date columns will be converted
            to datetimes.

        Returns
        -------
        dfs : dict{str->DataFrame}
            Dictionary which maps table name to the corresponding DataFrame
            version of the table, where all date columns have been coerced back
            from int to datetime.
        """
        return {
            t_name: self.get_df_from_table(t_name, convert_dates)
            for t_name in self._datetime_int_cols
        }

    def get_df_from_table(self, table_name, convert_dates=False):
        """Read a single adjustments table into a DataFrame.

        Parameters
        ----------
        table_name : str
            Table to read; must be a key of ``_datetime_int_cols``.
        convert_dates : bool, optional
            If True, convert integer date columns to UTC datetimes.

        Returns
        -------
        pd.DataFrame

        Raises
        ------
        ValueError
            If ``table_name`` is not a known adjustments table.
        """
        try:
            date_cols = self._datetime_int_cols[table_name]
        except KeyError:
            raise ValueError("Requested table %s not found.\n"
                             "Available tables: %s\n" % (
                                 table_name,
                                 self._datetime_int_cols.keys(),
                             ))

        # Dates are stored in second resolution as ints in adj.db tables.
        # Need to specifically convert them as UTC, not local time.
        kwargs = ({
            'parse_dates': {
                col: {'unit': 's', 'utc': True}
                for col in date_cols
            },
        } if convert_dates else {})

        result = pd.read_sql('select * from "{}"'.format(table_name),
                             self.conn,
                             index_col='index',
                             **kwargs).rename_axis(None)

        if not len(result):
            # An empty read loses dtype information; rebuild an empty frame
            # that carries the expected per-column dtypes.
            dtypes = self._df_dtypes(table_name, convert_dates)
            return empty_dataframe(*keysorted(dtypes))

        return result

    def _df_dtypes(self, table_name, convert_dates):
        """Get dtypes to use when unpacking sqlite tables as dataframes.
        """
        out = self._raw_table_dtypes[table_name]
        if convert_dates:
            # Copy before mutating: ``out`` aliases the shared class-level
            # mapping in _raw_table_dtypes.
            out = out.copy()
            for date_column in self._datetime_int_cols[table_name]:
                out[date_column] = datetime64ns_dtype

        return out
Пример #2
0
class SQLiteAdjustmentReader(object):
    """
    Loads adjustments based on corporate actions from a SQLite database.

    Expects data written in the format output by `SQLiteAdjustmentWriter`.

    Parameters
    ----------
    conn : str or sqlite3.Connection
        Connection from which to load data.

    See Also
    --------
    :class:`zipline.data.adjustments.SQLiteAdjustmentWriter`
    """
    @preprocess(conn=coerce_string_to_conn(require_exists=True))
    def __init__(self, conn):
        self.conn = conn

        # Given the tables in the adjustments.db file, dict which knows which
        # col names contain dates that have been coerced into ints.
        self._datetime_int_cols = {
            'dividend_payouts':
            ('declared_date', 'ex_date', 'pay_date', 'record_date'),
            'dividends': ('effective_date', ),
            'mergers': ('effective_date', ),
            'splits': ('effective_date', ),
            'stock_dividend_payouts':
            ('declared_date', 'ex_date', 'pay_date', 'record_date')
        }

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        # Close the connection whether or not an exception is propagating.
        self.close()

    def close(self):
        return self.conn.close()

    def load_adjustments(self, columns, dates, assets):
        """Load adjustments for ``columns`` over ``dates`` x ``assets``.

        Thin delegating wrapper around ``load_adjustments_from_sqlite``;
        see that function for the shape of the return value.
        """
        return load_adjustments_from_sqlite(
            self.conn,
            list(columns),
            dates,
            assets,
        )

    def get_adjustments_for_sid(self, table_name, sid):
        """Return ``[[Timestamp, ratio], ...]`` rows for ``sid``.

        NOTE(review): ``table_name`` is interpolated directly into the SQL
        statement (table names cannot be bound parameters), so it must come
        from trusted code, never from user input.
        """
        c = self.conn.cursor()
        try:
            adjustments_for_sid = c.execute(
                "SELECT effective_date, ratio FROM %s WHERE sid = ?"
                % table_name,
                (sid,),
            ).fetchall()
        finally:
            # Close the cursor even if the query raises.
            c.close()

        return [[Timestamp(adjustment[0], unit='s', tz='UTC'), adjustment[1]]
                for adjustment in adjustments_for_sid]

    def get_dividends_with_ex_date(self, assets, date, asset_finder):
        """Return Dividend objects for ``assets`` with ex_date after ``date``.
        """
        # Floor-divide so we bind an int: under Python 3, ``/`` would
        # produce a float here (this code predates the 2->3 division change).
        seconds = date.value // int(1e9)
        c = self.conn.cursor()

        divs = []
        try:
            # Query in chunks to stay under sqlite's bound-parameter limit.
            for chunk in group_into_chunks(assets):
                query = UNPAID_QUERY_TEMPLATE.format(
                    ",".join(['?' for _ in chunk]))
                params = (seconds,) + tuple(map(int, chunk))

                c.execute(query, params)

                for row in c.fetchall():
                    divs.append(Dividend(
                        asset_finder.retrieve_asset(row[0]),
                        row[1],
                        Timestamp(row[2], unit='s', tz='UTC'),
                    ))
        finally:
            # Close the cursor even if a query raises.
            c.close()

        return divs

    def get_stock_dividends_with_ex_date(self, assets, date, asset_finder):
        """Return StockDividend objects for ``assets`` with ex_date after
        ``date``.
        """
        # See get_dividends_with_ex_date: floor division keeps this an int.
        seconds = date.value // int(1e9)
        c = self.conn.cursor()

        stock_divs = []
        try:
            for chunk in group_into_chunks(assets):
                query = UNPAID_STOCK_DIVIDEND_QUERY_TEMPLATE.format(
                    ",".join(['?' for _ in chunk]))
                params = (seconds,) + tuple(map(int, chunk))

                c.execute(query, params)

                for row in c.fetchall():
                    stock_divs.append(StockDividend(
                        asset_finder.retrieve_asset(row[0]),  # asset
                        asset_finder.retrieve_asset(row[1]),  # payment_asset
                        row[2],                               # ratio
                        Timestamp(row[3], unit='s', tz='UTC'),
                    ))
        finally:
            c.close()

        return stock_divs

    def unpack_db_to_component_dfs(self, convert_dates=False):
        """Returns the set of known tables in the adjustments file in DataFrame
        form.

        Parameters
        ----------
        convert_dates : bool, optional
            By default, dates are returned in seconds since EPOCH. If
            convert_dates is True, all ints in date columns will be converted
            to datetimes.

        Returns
        -------
        dfs : dict{str->DataFrame}
            Dictionary which maps table name to the corresponding DataFrame
            version of the table, where all date columns have been coerced back
            from int to datetime.
        """
        def _get_df_from_table(table_name, date_cols):
            # Dates are stored in second resolution as ints in adj.db tables.
            # Need to specifically convert them as UTC, not local time.
            kwargs = ({
                'parse_dates': {
                    col: {'unit': 's', 'utc': True}
                    for col in date_cols
                },
            } if convert_dates else {})

            return pd.read_sql('select * from "{}"'.format(table_name),
                               self.conn,
                               index_col='index',
                               **kwargs).rename_axis(None)

        return {
            t_name: _get_df_from_table(t_name, date_cols)
            for t_name, date_cols in self._datetime_int_cols.items()
        }
Пример #3
0
class SQLiteFundamentalsReader(object):
    """
    Loads fundamentals from a SQLite database.

    Expects data written in the format output by `SQLiteFundamentalsWriter`.

    Parameters
    ----------
    conn : str or sqlite3.Connection
        Connection from which to load data.

    See Also
    --------
    :class:`zipline.data.fundamentals.SQLiteFundamentalsWriter`
    """
    @preprocess(conn=coerce_string_to_conn(require_exists=True))
    def __init__(self, conn):
        self.conn = conn

    def read(self, name, dates, assets):
        """Read fundamental ``name`` as a forward-filled dates x assets frame.

        Parameters
        ----------
        name : str
            Fundamental name; selects table ``fundamentals_<name>``.
        dates : pd.DatetimeIndex
            Row index of the result.
        assets : iterable of int
            Column index of the result (sids).

        Returns
        -------
        pd.DataFrame
            Values observed on each date, forward-filled between
            observations; the observation preceding ``dates[0]`` seeds the
            first row.
        """
        # Floor-divide so we get integer epoch-seconds; under Python 3 a
        # plain ``/`` would yield a float.
        start_secs = dates[0].to_datetime64().astype(any_integer) // 1000000000
        end_secs = dates[-1].to_datetime64().astype(any_integer) // 1000000000

        # The table name cannot be a bound parameter, so ``name`` is
        # interpolated and must come from trusted code. The date, however,
        # is bound with '?' rather than formatted into the string.
        sql = ('SELECT sid, value, date FROM fundamentals_%s '
               'WHERE date < ? ORDER BY date' % name)

        df = pd.read_sql_query(
            sql,
            self.conn,
            params=(int(end_secs),),
        )
        result = pd.DataFrame(index=dates, columns=assets)

        for asset in assets:
            df_sid = df[df['sid'] == asset].copy()

            # Last observation strictly before the window start, used to
            # seed the first row so ffill has something to propagate.
            earlier = df_sid[df_sid['date'] < start_secs]['date']
            # ``.empty`` is the correct emptiness test here; the original
            # ``.any()`` tested element truthiness and would misread an
            # epoch-0 date as "no data".
            start_date = earlier.iloc[-1] if not earlier.empty else start_secs

            df_sid = df_sid[df_sid['date'] >= start_date]
            if start_date < start_secs:
                # Assign via .loc on the frame itself to avoid
                # chained-assignment (SettingWithCopy) pitfalls.
                result.loc[dates[0], asset] = df_sid['value'].iloc[0]

            for _, row in df_sid.iterrows():
                date, value = int(row['date']), row['value']
                if date >= end_secs:
                    break
                dtime = np.datetime64(date, 's')
                if dtime in result.index:
                    result.loc[dtime, asset] = value

        # ``fillna(method='ffill')`` is deprecated; ``ffill`` is equivalent.
        return result.ffill()

    def read_fundamentals(self, names, dates, assets):
        """Read the first name in ``names`` from the ``fundamentals`` table
        keyed by announcement date (``f_ann_date``), forward-filled.

        Only ``names[0]`` is used; the remaining names are ignored.
        """
        name = names[0]

        # ``Timestamp.to_datetime()`` was removed from pandas; the
        # Timestamps themselves compare identically, so use them directly.
        start_date = dates[0]
        end_date = dates[-1]

        sql = '''SELECT sid, f_ann_date, end_date, %s From fundamentals''' % name

        df = pd.read_sql_query(
            sql,
            self.conn,
            parse_dates=['f_ann_date'],
        )

        df = df.sort_values(by="f_ann_date", ascending=True)

        result = pd.DataFrame(index=dates, columns=assets)

        for asset in assets:
            df_sid = df[df['sid'] == asset].copy()

            # Last announcement before the window start (see ``read``).
            earlier = df_sid[df_sid['f_ann_date'] < start_date]['f_ann_date']
            st_date = earlier.iloc[-1] if not earlier.empty else start_date

            df_sid = df_sid[df_sid['f_ann_date'] >= st_date]
            # Only localize naive timestamps; calling tz_localize on an
            # already-aware timestamp raises.
            if st_date.tzinfo is None:
                st_date = st_date.tz_localize('utc')
            if st_date < start_date:
                result.loc[dates[0], asset] = df_sid[name].iloc[0]

            for _, row in df_sid.iterrows():
                date, value = row['f_ann_date'], row[name]
                if date.tzinfo is None:
                    date = date.tz_localize('utc')
                if date >= end_date:
                    break
                # NOTE(review): np.datetime64 on a tz-aware Timestamp raises
                # on modern numpy; preserved from the original — confirm the
                # intended index resolution before changing.
                dtime = np.datetime64(date, 's')
                if dtime in result.index:
                    result.loc[dtime, asset] = value

        return result.ffill()