Ejemplo n.º 1
0
    def setUpClass(cls):
        """Register four CL futures contracts in a fresh asset database.

        All four contracts share the ``CL`` root symbol. The finder for the
        populated environment is stored on the class as ``cls.asset_finder``.
        """
        def utc(day):
            # Every date in this fixture is a UTC-localized timestamp.
            return pd.Timestamp(day, tz="UTC")

        clg06 = dict(
            symbol="CLG06",
            root_symbol="CL",
            start_date=utc("2005-12-01"),
            notice_date=utc("2005-12-20"),
            expiration_date=utc("2006-01-20"),
        )
        clk06 = dict(
            root_symbol="CL",
            symbol="CLK06",
            start_date=utc("2005-12-01"),
            notice_date=utc("2006-03-20"),
            expiration_date=utc("2006-04-20"),
        )
        clq06 = dict(
            symbol="CLQ06",
            root_symbol="CL",
            start_date=utc("2005-12-01"),
            notice_date=utc("2006-06-20"),
            expiration_date=utc("2006-07-20"),
        )
        clx06 = dict(
            symbol="CLX06",
            root_symbol="CL",
            start_date=utc("2006-02-01"),
            notice_date=utc("2006-09-20"),
            expiration_date=utc("2006-10-20"),
        )
        metadata = {0: clg06, 1: clk06, 2: clq06, 3: clx06}

        env = TradingEnvironment(load=noop_load)
        env.write_data(futures_data=metadata)
        cls.asset_finder = env.asset_finder
Ejemplo n.º 2
0
    def test_compute_lifetimes(self):
        """``AssetFinder.lifetimes`` agrees with a brute-force expectation.

        Builds rotating equities with overlapping lifetimes. For every
        subindex of the full date range, it compares the finder's boolean
        frame against one computed directly from each asset's start and end
        dates. Both modes are checked: with and without the start date
        counted as alive.
        """
        num_assets = 4
        env = TradingEnvironment()
        trading_day = env.trading_day
        first_start = pd.Timestamp('2015-04-01', tz='UTC')

        frame = make_rotating_asset_info(
            num_assets=num_assets,
            first_start=first_start,
            frequency=trading_day,
            periods_between_starts=3,
            asset_lifetime=5
        )

        env.write_data(equities_df=frame)
        finder = env.asset_finder

        all_dates = pd.date_range(
            start=first_start,
            end=frame.end_date.max(),
            freq=trading_day,
        )

        # The (start, end) pairs do not depend on ``dates``, so read them
        # once. Row position j lines up with column j of the expected frame
        # (columns=frame.index.values), whatever the actual sid values are.
        lifetimes = list(
            frame[['start_date', 'end_date']].itertuples(index=False)
        )

        for dates in all_subindices(all_dates):
            expected_with_start_raw = full(
                shape=(len(dates), num_assets),
                fill_value=False,
                dtype=bool,
            )
            expected_no_start_raw = full(
                shape=(len(dates), num_assets),
                fill_value=False,
                dtype=bool,
            )

            for i, date in enumerate(dates):
                for j, (start, end) in enumerate(lifetimes):
                    # This way of doing the checks is redundant, but very
                    # clear.
                    if start <= date <= end:
                        expected_with_start_raw[i, j] = True
                        if start < date:
                            expected_no_start_raw[i, j] = True

            expected_with_start = pd.DataFrame(
                data=expected_with_start_raw,
                index=dates,
                columns=frame.index.values,
            )
            result = finder.lifetimes(dates, include_start_date=True)
            assert_frame_equal(result, expected_with_start)

            expected_no_start = pd.DataFrame(
                data=expected_no_start_raw,
                index=dates,
                columns=frame.index.values,
            )
            result = finder.lifetimes(dates, include_start_date=False)
            assert_frame_equal(result, expected_no_start)
Ejemplo n.º 3
0
    def setUp(self):
        """Build a constant-price loader and asset finder for three assets.

        Prices are constant across the daily dates 2014-01-01 through
        2014-03-01, and every asset's lifetime spans that whole range.
        """
        # Every day, assume every stock starts at 2, goes down to 1,
        # goes up to 4, and finishes at 3.
        self.constants = {
            USEquityPricing.low: 1,
            USEquityPricing.open: 2,
            USEquityPricing.close: 3,
            USEquityPricing.high: 4,
        }
        self.assets = [1, 2, 3]
        self.dates = date_range('2014-01', '2014-03', freq='D', tz='UTC')
        self.loader = ConstantLoader(
            constants=self.constants, dates=self.dates, assets=self.assets,
        )

        first_day, last_day = self.dates[0], self.dates[-1]
        self.asset_info = make_simple_asset_info(
            self.assets, start_date=first_day, end_date=last_day,
        )
        environment = TradingEnvironment()
        environment.write_data(equities_df=self.asset_info)
        self.asset_finder = environment.asset_finder
Ejemplo n.º 4
0
    def setUpClass(cls):
        """Seed an asset database with four CL futures for the test class.

        The finder for that database is stored on ``cls.asset_finder``.
        """
        def _utc(date_string):
            # All contract dates are UTC-localized.
            return pd.Timestamp(date_string, tz='UTC')

        metadata = {}
        metadata[0] = {
            'symbol': 'CLG06',
            'root_symbol': 'CL',
            'start_date': _utc('2005-12-01'),
            'notice_date': _utc('2005-12-20'),
            'expiration_date': _utc('2006-01-20'),
        }
        metadata[1] = {
            'root_symbol': 'CL',
            'symbol': 'CLK06',
            'start_date': _utc('2005-12-01'),
            'notice_date': _utc('2006-03-20'),
            'expiration_date': _utc('2006-04-20'),
        }
        metadata[2] = {
            'symbol': 'CLQ06',
            'root_symbol': 'CL',
            'start_date': _utc('2005-12-01'),
            'notice_date': _utc('2006-06-20'),
            'expiration_date': _utc('2006-07-20'),
        }
        metadata[3] = {
            'symbol': 'CLX06',
            'root_symbol': 'CL',
            'start_date': _utc('2006-02-01'),
            'notice_date': _utc('2006-09-20'),
            'expiration_date': _utc('2006-10-20'),
        }

        env = TradingEnvironment(load=noop_load)
        env.write_data(futures_data=metadata)
        cls.asset_finder = env.asset_finder
Ejemplo n.º 5
0
    def setUpClass(cls):
        """Write four CL futures contracts and keep the resulting finder.

        Contracts are keyed 0-3 in listing order; the finder is stored as
        ``cls.asset_finder``.
        """
        def stamp(text):
            return pd.Timestamp(text, tz='UTC')

        contracts = [
            {
                'symbol': 'CLG06',
                'root_symbol': 'CL',
                'start_date': stamp('2005-12-01'),
                'notice_date': stamp('2005-12-20'),
                'expiration_date': stamp('2006-01-20'),
            },
            {
                'root_symbol': 'CL',
                'symbol': 'CLK06',
                'start_date': stamp('2005-12-01'),
                'notice_date': stamp('2006-03-20'),
                'expiration_date': stamp('2006-04-20'),
            },
            {
                'symbol': 'CLQ06',
                'root_symbol': 'CL',
                'start_date': stamp('2005-12-01'),
                'notice_date': stamp('2006-06-20'),
                'expiration_date': stamp('2006-07-20'),
            },
            {
                'symbol': 'CLX06',
                'root_symbol': 'CL',
                'start_date': stamp('2006-02-01'),
                'notice_date': stamp('2006-09-20'),
                'expiration_date': stamp('2006-10-20'),
            },
        ]
        # Key each contract by its position in the list above (0..3).
        metadata = dict(enumerate(contracts))

        env = TradingEnvironment(load=noop_load)
        env.write_data(futures_data=metadata)
        cls.asset_finder = env.asset_finder
Ejemplo n.º 6
0
class TestChangeOfSids(TestCase):
    """Batch-transform history should contain every sid seen so far."""

    def setUp(self):
        # Ninety equities with sids 0 through 89.
        self.sids = range(90)
        self.env = TradingEnvironment()
        self.env.write_data(equities_identifiers=self.sids)

        self.sim_params = factory.create_simulation_parameters(
            start=datetime(1990, 1, 1, tzinfo=pytz.utc),
            end=datetime(1990, 1, 8, tzinfo=pytz.utc),
            env=self.env,
        )

    def test_all_sids_passed(self):
        """Each history frame ends on its day and holds all earlier sids."""
        algo = BatchTransformAlgorithmSetSid(
            sim_params=self.sim_params,
            env=self.env,
        )
        source = DifferentSidSource()
        algo.run(source)
        for i, (df, date) in enumerate(zip(algo.history, source.trading_days)):
            # A single literal: a backslash continuation inside the string
            # would embed the next line's indentation in the message.
            self.assertEqual(df.index[-1], date, "Newest event doesn't match.")

            for sid in self.sids[:i]:
                self.assertIn(sid, df.columns)

            # NOTE(review): presumably DifferentSidSource emits value i on
            # its i-th trading day -- confirm against the source fixture.
            self.assertEqual(df.iloc[-1].iloc[-1], i)
Ejemplo n.º 7
0
    def test_yahoo_bars_to_panel_source(self):
        """Yahoo bars wrapped in a DataPanelSource yield well-formed events.

        Loads AAPL and GE bars via ``factory.load_bars_from_yahoo`` and
        relabels the panel items with their sids. It then checks that every
        event carries the expected fields, an integer volume, and one of
        those sids.
        """
        # ``pd.datetime`` was only an alias of this class and is gone in
        # pandas 2.0; import the stdlib class directly.
        from datetime import datetime

        env = TradingEnvironment()
        stocks = ['AAPL', 'GE']
        env.write_data(equities_identifiers=stocks)
        # Build the finder once the identifiers are in the database.
        finder = AssetFinder(env.engine)
        start = datetime(1993, 1, 1, 0, 0, 0, 0, pytz.utc)
        end = datetime(2002, 1, 1, 0, 0, 0, 0, pytz.utc)
        data = factory.load_bars_from_yahoo(stocks=stocks,
                                            indexes={},
                                            start=start,
                                            end=end)
        check_fields = ['sid', 'open', 'high', 'low', 'close',
                        'volume', 'price']

        copy_panel = data.copy()
        sids = finder.map_identifier_index_to_sids(
            data.items, data.major_axis[0]
        )
        copy_panel.items = sids
        source = DataPanelSource(copy_panel)
        for event in source:
            for check_field in check_fields:
                self.assertIn(check_field, event)
            # Dedicated assertions report the offending value on failure.
            self.assertIsInstance(event['volume'], integer_types)
            self.assertIn(event['sid'], sids)
Ejemplo n.º 8
0
    def test_yahoo_bars_to_panel_source(self):
        """Events from a sid-labelled Yahoo DataPanelSource are well formed.

        Every event must expose the OHLCV/price/sid fields, have an integer
        ``volume``, and carry a sid produced by the asset finder.
        """
        # Stdlib datetime; ``pd.datetime`` was an alias removed in pandas 2.0.
        from datetime import datetime

        env = TradingEnvironment()
        stocks = ['AAPL', 'GE']
        env.write_data(equities_identifiers=stocks)
        # Create the finder after the identifiers have been written.
        finder = AssetFinder(env.engine)
        start = datetime(1993, 1, 1, 0, 0, 0, 0, pytz.utc)
        end = datetime(2002, 1, 1, 0, 0, 0, 0, pytz.utc)
        data = factory.load_bars_from_yahoo(stocks=stocks,
                                            indexes={},
                                            start=start,
                                            end=end)
        check_fields = [
            'sid', 'open', 'high', 'low', 'close', 'volume', 'price'
        ]

        copy_panel = data.copy()
        sids = finder.map_identifier_index_to_sids(data.items,
                                                   data.major_axis[0])
        copy_panel.items = sids
        source = DataPanelSource(copy_panel)
        for event in source:
            for check_field in check_fields:
                self.assertIn(check_field, event)
            self.assertIsInstance(event['volume'], integer_types)
            self.assertIn(event['sid'], sids)
Ejemplo n.º 9
0
    def setUp(self):
        """Set up precomputed constant prices and assets for four sids.

        Uses daily dates 2014-01-01 through 2014-03-01; ``self.assets``
        holds the Asset objects retrieved for ``self.asset_ids``.
        """
        low, open_, close, high = (
            USEquityPricing.low,
            USEquityPricing.open,
            USEquityPricing.close,
            USEquityPricing.high,
        )
        # Every day, assume every stock starts at 2, goes down to 1,
        # goes up to 4, and finishes at 3.
        self.constants = {low: 1, open_: 2, close: 3, high: 4}
        self.asset_ids = [1, 2, 3, 4]
        self.dates = date_range('2014-01', '2014-03', freq='D', tz='UTC')
        self.loader = PrecomputedLoader(
            constants=self.constants, dates=self.dates, sids=self.asset_ids,
        )

        first_day = self.dates[0]
        last_day = self.dates[-1]
        self.asset_info = make_simple_equity_info(
            self.asset_ids, start_date=first_day, end_date=last_day,
        )
        environment = TradingEnvironment()
        environment.write_data(equities_df=self.asset_info)
        finder = environment.asset_finder
        self.asset_finder = finder
        self.assets = finder.retrieve_all(self.asset_ids)
Ejemplo n.º 10
0
class TestChangeOfSids(TestCase):
    """A batch transform must see every sid emitted by its source."""

    def setUp(self):
        # Equities with sids 0..89, simulated from 1990-01-01 to 1990-01-08.
        self.sids = range(90)
        self.env = TradingEnvironment()
        self.env.write_data(equities_identifiers=self.sids)

        self.sim_params = factory.create_simulation_parameters(
            start=datetime(1990, 1, 1, tzinfo=pytz.utc),
            end=datetime(1990, 1, 8, tzinfo=pytz.utc),
            env=self.env,
        )

    def test_all_sids_passed(self):
        """Check each day's history frame against that day and prior sids."""
        algo = BatchTransformAlgorithmSetSid(
            sim_params=self.sim_params,
            env=self.env,
        )
        source = DifferentSidSource()
        algo.run(source)
        # Kept on one line: a backslash continuation inside the literal
        # would pull the following indentation into the message.
        newest_msg = "Newest event doesn't match."
        for i, (df, date) in enumerate(zip(algo.history, source.trading_days)):
            self.assertEqual(df.index[-1], date, newest_msg)

            for sid in self.sids[:i]:
                self.assertIn(sid, df.columns)

            # NOTE(review): assumes DifferentSidSource emits value i on its
            # i-th trading day -- confirm against the fixture.
            self.assertEqual(df.iloc[-1].iloc[-1], i)
Ejemplo n.º 11
0
    def setUp(self):
        """Register three equities whose lifetimes span ``self.dates``."""
        self.assets = [1, 2, 3]
        self.dates = date_range("2014-01", "2014-03", freq="D", tz="UTC")

        asset_info = make_simple_asset_info(
            self.assets,
            start_date=self.dates[0],
            end_date=self.dates[-1],
        )
        env = TradingEnvironment()
        env.write_data(equities_df=asset_info)
        self.asset_finder = env.asset_finder
Ejemplo n.º 12
0
    def test_compute_lifetimes(self):
        """Compare ``finder.lifetimes`` with an expectation built by hand.

        For every subindex of the dates covered by the rotating equities,
        the expected mask marks asset j alive on date i when
        ``start <= date <= end``. With ``include_start_date=False``, the
        start date itself is excluded.
        """
        num_assets = 4
        env = TradingEnvironment()
        trading_day = env.trading_day
        first_start = pd.Timestamp('2015-04-01', tz='UTC')

        frame = make_rotating_asset_info(num_assets=num_assets,
                                         first_start=first_start,
                                         frequency=trading_day,
                                         periods_between_starts=3,
                                         asset_lifetime=5)

        env.write_data(equities_df=frame)
        finder = env.asset_finder

        all_dates = pd.date_range(
            start=first_start,
            end=frame.end_date.max(),
            freq=trading_day,
        )

        # Read the (start, end) pairs once instead of once per date, and
        # index the expected arrays by row position. That position matches
        # the column order of ``frame.index.values`` even when sids are not
        # 0..n-1.
        lifetimes = list(
            frame[['start_date', 'end_date']].itertuples(index=False)
        )

        for dates in all_subindices(all_dates):
            expected_with_start_raw = full(
                shape=(len(dates), num_assets),
                fill_value=False,
                dtype=bool,
            )
            expected_no_start_raw = full(
                shape=(len(dates), num_assets),
                fill_value=False,
                dtype=bool,
            )

            for i, date in enumerate(dates):
                for j, (start, end) in enumerate(lifetimes):
                    # This way of doing the checks is redundant, but very
                    # clear.
                    if start <= date <= end:
                        expected_with_start_raw[i, j] = True
                        if start < date:
                            expected_no_start_raw[i, j] = True

            expected_with_start = pd.DataFrame(
                data=expected_with_start_raw,
                index=dates,
                columns=frame.index.values,
            )
            result = finder.lifetimes(dates, include_start_date=True)
            assert_frame_equal(result, expected_with_start)

            expected_no_start = pd.DataFrame(
                data=expected_no_start_raw,
                index=dates,
                columns=frame.index.values,
            )
            result = finder.lifetimes(dates, include_start_date=False)
            assert_frame_equal(result, expected_no_start)
Ejemplo n.º 13
0
 def test_sids(self):
     """Ensure the ``sids`` property of the AssetFinder lists every sid."""
     env = TradingEnvironment()
     env.write_data(equities_identifiers=[1, 2, 3])
     sids = env.asset_finder.sids
     self.assertEqual(3, len(sids))
     # assertIn reports the missing sid and the container on failure,
     # unlike assertTrue(x in y), which only says "False is not true".
     for sid in (1, 2, 3):
         self.assertIn(sid, sids)
Ejemplo n.º 14
0
 def test_sids(self):
     """The AssetFinder's ``sids`` property holds exactly the written sids."""
     expected = [1, 2, 3]
     env = TradingEnvironment()
     env.write_data(equities_identifiers=expected)
     sids = env.asset_finder.sids
     self.assertEqual(3, len(sids))
     # Use assertIn so a failure names the missing sid.
     for sid in expected:
         self.assertIn(sid, sids)
Ejemplo n.º 15
0
    def setUp(self):
        """Register three equities that live across all of ``self.dates``."""
        self.assets = [1, 2, 3]
        self.dates = date_range('2014-01-01', '2014-02-01', freq='D', tz='UTC')

        first_day = self.dates[0]
        last_day = self.dates[-1]
        asset_info = make_simple_asset_info(
            self.assets, start_date=first_day, end_date=last_day,
        )
        env = TradingEnvironment()
        env.write_data(equities_df=asset_info)
        self.asset_finder = env.asset_finder
Ejemplo n.º 16
0
    def setUp(self):
        """Write three equities covering daily dates 2014-01-01..2014-03-01.

        The finder for the populated environment is kept on
        ``self.asset_finder``.
        """
        assets = [1, 2, 3]
        dates = date_range('2014-01', '2014-03', freq='D', tz='UTC')
        self.assets = assets
        self.dates = dates

        env_data = make_simple_asset_info(
            assets,
            start_date=dates[0],
            end_date=dates[-1],
        )
        env = TradingEnvironment()
        env.write_data(equities_df=env_data)
        self.asset_finder = env.asset_finder
Ejemplo n.º 17
0
    def setUp(self):
        """Prepare sids 1-19 over a 2014 trading calendar plus a mask.

        The mask is the finder's ``lifetimes`` over the last ten calendar
        sessions.
        """
        sessions = date_range('2014', '2015', freq=trading_day)
        assets = Int64Index(arange(1, 20))
        self.__calendar = sessions
        self.__assets = assets

        # Set up env for test
        env = TradingEnvironment()
        equities = make_simple_asset_info(assets, sessions[0], sessions[-1])
        env.write_data(equities_df=equities)
        finder = env.asset_finder
        self.__finder = finder
        self.__mask = finder.lifetimes(sessions[-10:])
Ejemplo n.º 18
0
    def setUp(self):
        """Create 19 equities alive through 2014 and a ten-session mask.

        ``self.__mask`` is ``lifetimes`` evaluated on the final ten entries
        of ``self.__calendar``.
        """
        self.__calendar = date_range('2014', '2015', freq=trading_day)
        self.__assets = Int64Index(arange(1, 20))
        first_session = self.__calendar[0]
        last_session = self.__calendar[-1]

        # Set up env for test
        env = TradingEnvironment()
        asset_info = make_simple_asset_info(
            self.__assets, first_session, last_session,
        )
        env.write_data(equities_df=asset_info)
        self.__finder = env.asset_finder
        final_ten = self.__calendar[-10:]
        self.__mask = self.__finder.lifetimes(final_ten)
Ejemplo n.º 19
0
    def setUp(self):
        """Provide constant OHLC prices and a finder for three assets.

        Dates run daily from 2014-01-01 through 2014-02-01.
        """
        # Every day, assume every stock starts at 2, goes down to 1,
        # goes up to 4, and finishes at 3.
        self.constants = dict([
            (USEquityPricing.low, 1),
            (USEquityPricing.open, 2),
            (USEquityPricing.close, 3),
            (USEquityPricing.high, 4),
        ])
        self.assets = [1, 2, 3]
        self.dates = date_range("2014-01-01", "2014-02-01", freq="D", tz="UTC")
        self.loader = ConstantLoader(
            constants=self.constants,
            dates=self.dates,
            assets=self.assets,
        )

        self.asset_info = make_simple_asset_info(
            self.assets,
            start_date=self.dates[0],
            end_date=self.dates[-1],
        )
        environment = TradingEnvironment()
        environment.write_data(equities_df=self.asset_info)
        self.asset_finder = environment.asset_finder
Ejemplo n.º 20
0
    def test_compute_lifetimes(self):
        """``finder.lifetimes(dates)`` matches a mask computed by hand.

        Asset j is expected alive on date i exactly when
        ``start <= date <= end`` for that asset's row in ``frame``.
        """
        num_assets = 4
        env = TradingEnvironment()
        trading_day = env.trading_day
        first_start = pd.Timestamp('2015-04-01', tz='UTC')

        frame = make_rotating_asset_info(
            num_assets=num_assets,
            first_start=first_start,
            frequency=trading_day,
            periods_between_starts=3,
            asset_lifetime=5
        )

        env.write_data(equities_df=frame)
        finder = env.asset_finder

        all_dates = pd.date_range(
            start=first_start,
            end=frame.end_date.max(),
            freq=trading_day,
        )

        # Loop-invariant: read every asset's (start, end) once. Rows are
        # indexed by position so column j lines up with
        # ``frame.index.values[j]`` regardless of the sid values.
        lifetimes = list(
            frame[['start_date', 'end_date']].itertuples(index=False)
        )

        for dates in all_subindices(all_dates):
            expected_mask = full(
                shape=(len(dates), num_assets),
                fill_value=False,
                dtype=bool,
            )

            for i, date in enumerate(dates):
                for j, (start, end) in enumerate(lifetimes):
                    if start <= date <= end:
                        expected_mask[i, j] = True

            # One column per asset in ``frame``, including assets that are
            # never alive within ``dates`` (all-False columns are kept).
            expected_result = pd.DataFrame(
                data=expected_mask,
                index=dates,
                columns=frame.index.values,
            )

            actual_result = finder.lifetimes(dates)
            assert_frame_equal(actual_result, expected_result)
Ejemplo n.º 21
0
    def setUp(self):
        """Write sids 1-19 for 2014 and build a 30-session lifetimes mask.

        The mask covers the last 30 calendar sessions and excludes each
        asset's start date (``include_start_date=False``).
        """
        calendar_ = date_range('2014', '2015', freq=trading_day)
        assets = Int64Index(arange(1, 20))
        self.__calendar = calendar_
        self.__assets = assets

        # Set up env for test
        env = TradingEnvironment()
        equity_info = make_simple_equity_info(
            assets,
            calendar_[0],
            calendar_[-1],
        )
        env.write_data(equities_df=equity_info)
        self.__finder = env.asset_finder

        # Use a 30-day period at the end of the year by default.
        tail = calendar_[-30:]
        self.__mask = self.__finder.lifetimes(tail, include_start_date=False)
Ejemplo n.º 22
0
    def setUp(self):
        """Create 2014 equities for sids 1-19 and a trailing lifetimes mask.

        ``self.__mask`` spans the final 30 entries of ``self.__calendar``
        and does not count an asset's start date as alive.
        """
        self.__calendar = date_range('2014', '2015', freq=trading_day)
        self.__assets = Int64Index(arange(1, 20))
        opening, closing = self.__calendar[0], self.__calendar[-1]

        # Set up env for test
        env = TradingEnvironment()
        info = make_simple_asset_info(self.__assets, opening, closing)
        env.write_data(equities_df=info)
        self.__finder = env.asset_finder

        # Use a 30-day period at the end of the year by default.
        window = self.__calendar[-30:]
        self.__mask = self.__finder.lifetimes(
            window,
            include_start_date=False,
        )
Ejemplo n.º 23
0
 def setUpClass(cls):
     """Create two Future fixtures and an asset finder that knows both.

     ``cls.future`` (OMH15) and ``cls.future2`` (CLG06) are written to a
     fresh TradingEnvironment whose finder is kept on ``cls.asset_finder``.
     """
     cls.future = Future(2468,
                         symbol='OMH15',
                         root_symbol='OM',
                         notice_date=pd.Timestamp('2014-01-20', tz='UTC'),
                         expiration_date=pd.Timestamp('2014-02-20',
                                                      tz='UTC'),
                         auto_close_date=pd.Timestamp('2014-01-18',
                                                      tz='UTC'),
                         contract_multiplier=500)
     cls.future2 = Future(0,
                          symbol='CLG06',
                          root_symbol='CL',
                          start_date=pd.Timestamp('2005-12-01', tz='UTC'),
                          notice_date=pd.Timestamp('2005-12-20', tz='UTC'),
                          expiration_date=pd.Timestamp('2006-01-20',
                                                       tz='UTC'))
     env = TradingEnvironment()
     # Read the fixtures back through ``cls`` (not a hard-coded class name)
     # so a subclass running this hook writes the objects it just created.
     env.write_data(
         futures_identifiers=[cls.future, cls.future2])
     cls.asset_finder = env.asset_finder
Ejemplo n.º 24
0
 def setUpClass(cls):
     """Build the OMH15 and CLG06 Future fixtures and a finder for them.

     Both futures are written to a TradingEnvironment created with
     ``load=noop_load``; its finder is stored on ``cls.asset_finder``.
     """
     cls.future = Future(
         2468,
         symbol="OMH15",
         root_symbol="OM",
         notice_date=pd.Timestamp("2014-01-20", tz="UTC"),
         expiration_date=pd.Timestamp("2014-02-20", tz="UTC"),
         auto_close_date=pd.Timestamp("2014-01-18", tz="UTC"),
         contract_multiplier=500,
     )
     cls.future2 = Future(
         0,
         symbol="CLG06",
         root_symbol="CL",
         start_date=pd.Timestamp("2005-12-01", tz="UTC"),
         notice_date=pd.Timestamp("2005-12-20", tz="UTC"),
         expiration_date=pd.Timestamp("2006-01-20", tz="UTC"),
     )
     env = TradingEnvironment(load=noop_load)
     # Use ``cls`` so subclasses that inherit this hook still find the
     # fixtures that were just assigned above.
     env.write_data(futures_identifiers=[cls.future, cls.future2])
     cls.asset_finder = env.asset_finder
Ejemplo n.º 25
0
 def setUpClass(cls):
     """Create two futures fixtures and keep a finder that indexes them.

     Sets ``cls.future`` (OMH15), ``cls.future2`` (CLG06) and
     ``cls.asset_finder``.
     """
     cls.future = Future(
         2468,
         symbol='OMH15',
         root_symbol='OM',
         notice_date=pd.Timestamp('2014-01-20', tz='UTC'),
         expiration_date=pd.Timestamp('2014-02-20', tz='UTC'),
         auto_close_date=pd.Timestamp('2014-01-18', tz='UTC'),
         contract_multiplier=500
     )
     cls.future2 = Future(
         0,
         symbol='CLG06',
         root_symbol='CL',
         start_date=pd.Timestamp('2005-12-01', tz='UTC'),
         notice_date=pd.Timestamp('2005-12-20', tz='UTC'),
         expiration_date=pd.Timestamp('2006-01-20', tz='UTC')
     )
     env = TradingEnvironment(load=noop_load)
     # Refer to the fixtures via ``cls`` rather than a hard-coded class
     # name, so the hook stays correct when inherited by a subclass.
     env.write_data(futures_identifiers=[cls.future,
                                         cls.future2])
     cls.asset_finder = env.asset_finder
Ejemplo n.º 26
0
    def test_algo_without_rl_violation_after_delete(self):
        """An algo trading BZQ runs once BZQ is deleted from the list.

        A delete statement for BZQ is written inside ``security_list_copy``.
        A fresh environment starting at ``self.extra_knowledge_date`` then
        runs ``RestrictedAlgoWithoutCheck`` on BZQ. The test passes if
        ``algo.run`` completes without raising.
        """
        # TempDirectory is a context manager whose exit calls cleanup(),
        # replacing the manual try/finally. It is entered first and exited
        # last, exactly like the original nesting.
        with TempDirectory() as new_tempdir, security_list_copy():
            # add a delete statement removing bzq
            # write a new delete statement file to disk
            add_security_data([], ['BZQ'])

            # now fast-forward to self.extra_knowledge_date.  requires
            # a new env, simparams, and dataportal
            env = TradingEnvironment()
            sim_params = factory.create_simulation_parameters(
                start=self.extra_knowledge_date, num_days=4, env=env)

            # NOTE(review): the sid key here is the string "0", while other
            # fixtures use int keys -- confirm this is intentional.
            env.write_data(equities_data={
                "0": {
                    'symbol': 'BZQ',
                    'start_date': sim_params.period_start,
                    'end_date': sim_params.period_end,
                }
            })

            data_portal = create_data_portal(
                env,
                new_tempdir,
                sim_params,
                range(0, 5)
            )

            algo = RestrictedAlgoWithoutCheck(
                symbol='BZQ', sim_params=sim_params, env=env
            )
            algo.run(data_portal)
Ejemplo n.º 27
0
def build_lookup_generic_cases():
    """
    Generate test cases for AssetFinder test_lookup_generic.

    Three equities are written: sids 0 and 1 share the symbol
    'duplicated' with one-day lifetimes starting 2013-01-01 and
    2013-01-03, and sid 2 is 'unique' from 2013-01-01 to 2014-01-01.

    Returns a list of ``(finder, query, as_of_date, expected)`` tuples
    covering scalar lookups (Asset, int, symbol) and iterable lookups.
    """
    unique_start = pd.Timestamp('2013-01-01', tz='UTC')
    unique_end = pd.Timestamp('2014-01-01', tz='UTC')

    one_day = timedelta(days=1)
    dupe_0_start = pd.Timestamp('2013-01-01', tz='UTC')
    dupe_1_start = pd.Timestamp('2013-01-03', tz='UTC')

    # (sid, symbol, start, end)
    rows = [
        (0, 'duplicated', dupe_0_start, dupe_0_start + one_day),
        (1, 'duplicated', dupe_1_start, dupe_1_start + one_day),
        (2, 'unique', unique_start, unique_end),
    ]
    records = [
        {
            'sid': sid,
            'symbol': symbol,
            'start_date': start.value,
            'end_date': end.value,
            'exchange': '',
        }
        for sid, symbol, start, end in rows
    ]
    frame = pd.DataFrame.from_records(records, index='sid')

    env = TradingEnvironment()
    env.write_data(equities_df=frame)
    finder = env.asset_finder
    assets = [finder.retrieve_asset(sid) for sid in range(3)]
    dupe_0, dupe_1, unique = assets

    # Resolve the duplicated symbol using the start dates the finder stored.
    dupe_0_start = dupe_0.start_date
    dupe_1_start = dupe_1.start_date

    cases = []
    # Scalars: an Asset object resolves to itself ...
    cases.extend((finder, asset, None, asset) for asset in assets)
    # ... and an int sid resolves to the matching Asset.
    cases.extend(
        (finder, sid, None, asset) for sid, asset in enumerate(assets)
    )
    cases.extend([
        # Duplicated symbol with resolution date
        (finder, 'duplicated', dupe_0_start, dupe_0),
        (finder, 'duplicated', dupe_1_start, dupe_1),
        # Unique symbol, with or without resolution date.
        (finder, 'unique', unique_start, unique),
        (finder, 'unique', None, unique),
    ])
    cases.extend([
        # Iterables of Asset objects.
        (finder, assets, None, assets),
        (finder, iter(assets), None, assets),
        # Iterables of ints
        (finder, (0, 1), None, assets[:-1]),
        (finder, iter((0, 1)), None, assets[:-1]),
        # Iterables of symbols.
        (finder, ('duplicated', 'unique'), dupe_0_start, [dupe_0, unique]),
        (finder, ('duplicated', 'unique'), dupe_1_start, [dupe_1, unique]),
        # Mixed types
        (finder, ('duplicated', 2, 'unique', 1, dupe_1), dupe_0_start,
         [dupe_0, assets[2], unique, assets[1], dupe_1]),
    ])
    return cases
Ejemplo n.º 28
0
class AssetFinderTestCase(TestCase):
    """Tests for AssetFinder symbol lookup, metadata ingestion and lifetimes.

    ``setUp`` creates a fresh TradingEnvironment for every test; each test
    writes its own asset metadata into it and queries it through an
    AssetFinder.
    """

    def setUp(self):
        self.env = TradingEnvironment()

    def test_lookup_symbol_delimited(self):
        """
        A '.'-delimited symbol resolves when looked up with any of the
        '-', '/', '_' or '.' delimiters, and fails without a delimiter or
        with the unsupported '@' delimiter.
        """
        as_of = pd.Timestamp('2013-01-01', tz='UTC')
        frame = pd.DataFrame.from_records([{
            'sid': i,
            'symbol': 'TEST.%d' % i,
            'company_name': "company%d" % i,
            'start_date': as_of.value,
            'end_date': as_of.value,
            'exchange': uuid.uuid4().hex
        } for i in range(3)])
        self.env.write_data(equities_df=frame)
        finder = AssetFinder(self.env.engine)
        asset_0, asset_1, asset_2 = (finder.retrieve_asset(i)
                                     for i in range(3))

        # we do it twice to catch caching bugs
        for _ in range(2):
            with self.assertRaises(SymbolNotFound):
                finder.lookup_symbol('test', as_of)
            with self.assertRaises(SymbolNotFound):
                finder.lookup_symbol('test1', as_of)
            # '@' is not a supported delimiter
            with self.assertRaises(SymbolNotFound):
                finder.lookup_symbol('test@1', as_of)

            # Adding an unnecessary fuzzy shouldn't matter.
            for fuzzy_char in ['-', '/', '_', '.']:
                self.assertEqual(
                    asset_1, finder.lookup_symbol('test%s1' % fuzzy_char,
                                                  as_of))

    def test_lookup_symbol_fuzzy(self):
        """
        'PRTYHRD' only resolves (to 'PRTY_HRD') when fuzzy=True, while
        exact symbols resolve with or without a date and with or without
        fuzzy matching.
        """
        metadata = {
            0: {
                'symbol': 'PRTY_HRD'
            },
            1: {
                'symbol': 'BRKA'
            },
            2: {
                'symbol': 'BRK_A'
            },
        }
        self.env.write_data(equities_data=metadata)
        finder = self.env.asset_finder
        dt = pd.Timestamp('2013-01-01', tz='UTC')

        # Try combos of looking up PRTYHRD with and without a time or fuzzy
        # Both non-fuzzys get no result
        with self.assertRaises(SymbolNotFound):
            finder.lookup_symbol('PRTYHRD', None)
        with self.assertRaises(SymbolNotFound):
            finder.lookup_symbol('PRTYHRD', dt)
        # Both fuzzys work
        self.assertEqual(0, finder.lookup_symbol('PRTYHRD', None, fuzzy=True))
        self.assertEqual(0, finder.lookup_symbol('PRTYHRD', dt, fuzzy=True))

        # Try combos of looking up PRTY_HRD, all returning sid 0
        self.assertEqual(0, finder.lookup_symbol('PRTY_HRD', None))
        self.assertEqual(0, finder.lookup_symbol('PRTY_HRD', dt))
        self.assertEqual(0, finder.lookup_symbol('PRTY_HRD', None, fuzzy=True))
        self.assertEqual(0, finder.lookup_symbol('PRTY_HRD', dt, fuzzy=True))

        # Try combos of looking up BRKA, all returning sid 1
        self.assertEqual(1, finder.lookup_symbol('BRKA', None))
        self.assertEqual(1, finder.lookup_symbol('BRKA', dt))
        self.assertEqual(1, finder.lookup_symbol('BRKA', None, fuzzy=True))
        self.assertEqual(1, finder.lookup_symbol('BRKA', dt, fuzzy=True))

        # Try combos of looking up BRK_A, all returning sid 2
        self.assertEqual(2, finder.lookup_symbol('BRK_A', None))
        self.assertEqual(2, finder.lookup_symbol('BRK_A', dt))
        self.assertEqual(2, finder.lookup_symbol('BRK_A', None, fuzzy=True))
        self.assertEqual(2, finder.lookup_symbol('BRK_A', dt, fuzzy=True))

    def test_lookup_symbol(self):
        """
        A symbol shared by several sids is ambiguous without a date and is
        resolved to the single sid alive on the supplied date.
        """
        # Incrementing by two so that start and end dates for each
        # generated Asset don't overlap (each Asset's end_date is the
        # day after its start date.)
        dates = pd.date_range('2013-01-01', freq='2D', periods=5, tz='UTC')
        df = pd.DataFrame.from_records([
            {
                'sid': i,
                'symbol': 'existing',
                'start_date': date.value,
                'end_date': (date + timedelta(days=1)).value,
                'exchange': 'NYSE',
            }
            for i, date in enumerate(dates)
        ])
        self.env.write_data(equities_df=df)
        finder = AssetFinder(self.env.engine)
        for _ in range(2):  # Run checks twice to test for caching bugs.
            with self.assertRaises(SymbolNotFound):
                finder.lookup_symbol('non_existing', dates[0])

            with self.assertRaises(MultipleSymbolsFound):
                finder.lookup_symbol('existing', None)

            for i, date in enumerate(dates):
                # Verify that we correctly resolve multiple symbols using
                # the supplied date
                result = finder.lookup_symbol('existing', date)
                self.assertEqual(result.symbol, 'EXISTING')
                self.assertEqual(result.sid, i)

    @parameterized.expand(build_lookup_generic_cases())
    def test_lookup_generic(self, finder, symbols, reference_date, expected):
        """
        Ensure that lookup_generic works with various permutations of inputs.
        """
        results, missing = finder.lookup_generic(symbols, reference_date)
        self.assertEqual(results, expected)
        self.assertEqual(missing, [])

    def test_lookup_generic_handle_missing(self):
        """
        lookup_generic reports unknown symbols and symbols that only start
        after the reference date as missing, but still finds assets whose
        end date is before the reference date.
        """
        data = pd.DataFrame.from_records([
            {
                'sid': 0,
                'symbol': 'real',
                'start_date': pd.Timestamp('2013-1-1', tz='UTC'),
                'end_date': pd.Timestamp('2014-1-1', tz='UTC'),
                'exchange': '',
            },
            {
                'sid': 1,
                'symbol': 'also_real',
                'start_date': pd.Timestamp('2013-1-1', tz='UTC'),
                'end_date': pd.Timestamp('2014-1-1', tz='UTC'),
                'exchange': '',
            },
            # Sid whose end date is before our query date.  We should
            # still correctly find it.
            {
                'sid': 2,
                'symbol': 'real_but_old',
                'start_date': pd.Timestamp('2002-1-1', tz='UTC'),
                'end_date': pd.Timestamp('2003-1-1', tz='UTC'),
                'exchange': '',
            },
            # Sid whose start_date is **after** our query date.  We should
            # **not** find it.
            {
                'sid': 3,
                'symbol': 'real_but_in_the_future',
                'start_date': pd.Timestamp('2014-1-1', tz='UTC'),
                'end_date': pd.Timestamp('2020-1-1', tz='UTC'),
                'exchange': 'THE FUTURE',
            },
        ])
        self.env.write_data(equities_df=data)
        finder = AssetFinder(self.env.engine)
        results, missing = finder.lookup_generic(
            ['real', 1, 'fake', 'real_but_old', 'real_but_in_the_future'],
            pd.Timestamp('2013-02-01', tz='UTC'),
        )

        self.assertEqual(len(results), 3)
        self.assertEqual(results[0].symbol, 'REAL')
        self.assertEqual(results[0].sid, 0)
        self.assertEqual(results[1].symbol, 'ALSO_REAL')
        self.assertEqual(results[1].sid, 1)
        self.assertEqual(results[2].symbol, 'REAL_BUT_OLD')
        self.assertEqual(results[2].sid, 2)

        self.assertEqual(len(missing), 2)
        self.assertEqual(missing[0], 'fake')
        self.assertEqual(missing[1], 'real_but_in_the_future')

    def test_insert_metadata(self):
        """
        Dict metadata builds an Equity with the known fields and drops
        unknown fields such as 'foo_data'.
        """
        data = {
            0: {
                'asset_type': 'equity',
                'start_date': '2014-01-01',
                'end_date': '2015-01-01',
                'symbol': "PLAY",
                'foo_data': "FOO"
            }
        }
        self.env.write_data(equities_data=data)
        finder = AssetFinder(self.env.engine)
        # Test proper insertion
        equity = finder.retrieve_asset(0)
        self.assertIsInstance(equity, Equity)
        self.assertEqual('PLAY', equity.symbol)
        self.assertEqual(pd.Timestamp('2015-01-01', tz='UTC'), equity.end_date)

        # Test invalid field
        with self.assertRaises(AttributeError):
            equity.foo_data

    def test_consume_metadata(self):
        """Metadata can be supplied both as a dict and as a DataFrame."""

        # Test dict consumption
        dict_to_consume = {0: {'symbol': 'PLAY'}, 1: {'symbol': 'MSFT'}}
        self.env.write_data(equities_data=dict_to_consume)
        finder = AssetFinder(self.env.engine)

        equity = finder.retrieve_asset(0)
        self.assertIsInstance(equity, Equity)
        self.assertEqual('PLAY', equity.symbol)

        # Test dataframe consumption
        df = pd.DataFrame(columns=['asset_name', 'exchange'], index=[0, 1])
        # Use .loc rather than chained assignment (df[col][row] = ...):
        # chained assignment may write to a temporary copy and, under
        # pandas Copy-on-Write, never updates df at all.
        df.loc[0, 'asset_name'] = "Dave'N'Busters"
        df.loc[0, 'exchange'] = "NASDAQ"
        df.loc[1, 'asset_name'] = "Microsoft"
        df.loc[1, 'exchange'] = "NYSE"
        self.env = TradingEnvironment()
        self.env.write_data(equities_df=df)
        finder = AssetFinder(self.env.engine)
        self.assertEqual('NASDAQ', finder.retrieve_asset(0).exchange)
        self.assertEqual('Microsoft', finder.retrieve_asset(1).asset_name)

    def test_consume_asset_as_identifier(self):
        """Asset objects passed as identifiers round-trip through the finder."""
        # Build some end dates
        eq_end = pd.Timestamp('2012-01-01', tz='UTC')
        fut_end = pd.Timestamp('2008-01-01', tz='UTC')

        # Build some simple Assets
        equity_asset = Equity(1, symbol="TESTEQ", end_date=eq_end)
        future_asset = Future(200, symbol="TESTFUT", end_date=fut_end)

        # Consume the Assets
        self.env.write_data(equities_identifiers=[equity_asset],
                            futures_identifiers=[future_asset])
        finder = AssetFinder(self.env.engine)

        # Test equality with newly built Assets
        self.assertEqual(equity_asset, finder.retrieve_asset(1))
        self.assertEqual(future_asset, finder.retrieve_asset(200))
        self.assertEqual(eq_end, finder.retrieve_asset(1).end_date)
        self.assertEqual(fut_end, finder.retrieve_asset(200).end_date)

    def test_sid_assignment(self):
        """Symbol-only identifiers receive distinct sids when allowed."""

        # This metadata does not contain SIDs
        metadata = ['PLAY', 'MSFT']

        today = normalize_date(pd.Timestamp('2015-07-09', tz='UTC'))

        # Write data with sid assignment
        self.env.write_data(equities_identifiers=metadata,
                            allow_sid_assignment=True)

        # Verify that Assets were built and different sids were assigned
        finder = AssetFinder(self.env.engine)
        play = finder.lookup_symbol('PLAY', today)
        msft = finder.lookup_symbol('MSFT', today)
        self.assertEqual('PLAY', play.symbol)
        self.assertIsNotNone(play.sid)
        self.assertNotEqual(play.sid, msft.sid)

    def test_sid_assignment_failure(self):
        """Symbol-only identifiers raise SidAssignmentError when not allowed."""

        # This metadata does not contain SIDs
        metadata = ['PLAY', 'MSFT']

        # Write data without sid assignment, asserting failure
        with self.assertRaises(SidAssignmentError):
            self.env.write_data(equities_identifiers=metadata,
                                allow_sid_assignment=False)

    def test_security_dates_warning(self):
        """Each deprecated security_* accessor emits a DeprecationWarning."""

        # Build an asset with an end_date
        eq_end = pd.Timestamp('2012-01-01', tz='UTC')
        equity_asset = Equity(1, symbol="TESTEQ", end_date=eq_end)

        # Catch all warnings
        with warnings.catch_warnings(record=True) as w:
            # Cause all warnings to always be triggered
            warnings.simplefilter("always")
            equity_asset.security_start_date
            equity_asset.security_end_date
            equity_asset.security_name
            # Verify the warning
            self.assertEqual(3, len(w))
            for warning in w:
                self.assertTrue(
                    issubclass(warning.category, DeprecationWarning))

    def test_lookup_future_chain(self):
        """
        lookup_future_chain filters contracts by as_of_date and knowledge
        date and returns them in the expected order.
        """
        metadata = {
            # Notice day is today, so not valid
            2: {
                'symbol': 'ADN15',
                'root_symbol': 'AD',
                'asset_type': 'future',
                'notice_date': pd.Timestamp('2015-05-14', tz='UTC'),
                'start_date': pd.Timestamp('2015-01-01', tz='UTC')
            },
            1: {
                'symbol': 'ADV15',
                'root_symbol': 'AD',
                'asset_type': 'future',
                'notice_date': pd.Timestamp('2015-08-14', tz='UTC'),
                'start_date': pd.Timestamp('2015-01-01', tz='UTC')
            },
            # Starts trading today, so should be valid.
            0: {
                'symbol': 'ADF16',
                'root_symbol': 'AD',
                'asset_type': 'future',
                'notice_date': pd.Timestamp('2015-11-16', tz='UTC'),
                'start_date': pd.Timestamp('2015-05-14', tz='UTC')
            },
            # Starts trading in August, so not valid.
            3: {
                'symbol': 'ADX16',
                'root_symbol': 'AD',
                'asset_type': 'future',
                'notice_date': pd.Timestamp('2015-11-16', tz='UTC'),
                'start_date': pd.Timestamp('2015-08-01', tz='UTC')
            },
        }
        self.env.write_data(futures_data=metadata)
        finder = AssetFinder(self.env.engine)
        dt = pd.Timestamp('2015-05-14', tz='UTC')
        last_year = pd.Timestamp('2014-01-01', tz='UTC')
        first_day = pd.Timestamp('2015-01-01', tz='UTC')

        # Check that we get the expected number of contracts, in the
        # right order
        ad_contracts = finder.lookup_future_chain('AD', dt, dt)
        self.assertEqual(len(ad_contracts), 2)
        self.assertEqual(ad_contracts[0].sid, 1)
        self.assertEqual(ad_contracts[1].sid, 0)

        # Check that pd.NaT for knowledge_date uses the value of as_of_date
        ad_contracts = finder.lookup_future_chain('AD', dt, pd.NaT)
        self.assertEqual(len(ad_contracts), 2)

        # Check that we get nothing if our knowledge date is last year
        ad_contracts = finder.lookup_future_chain('AD', dt, last_year)
        self.assertEqual(len(ad_contracts), 0)

        # Check that we get things that start on the knowledge date
        ad_contracts = finder.lookup_future_chain('AD', dt, first_day)
        self.assertEqual(len(ad_contracts), 1)

        # Check that pd.NaT for as_of_date gives the whole chain
        ad_contracts = finder.lookup_future_chain('AD', pd.NaT, first_day)
        self.assertEqual(len(ad_contracts), 4)

    def test_map_identifier_index_to_sids(self):
        """Asset identifiers map to their int sids, preserving order."""
        # Build an empty finder and some Assets
        dt = pd.Timestamp('2014-01-01', tz='UTC')
        finder = AssetFinder(self.env.engine)
        asset1 = Equity(1, symbol="AAPL")
        asset2 = Equity(2, symbol="GOOG")
        asset200 = Future(200, symbol="CLK15")
        asset201 = Future(201, symbol="CLM15")

        # Check for correct mapping and types
        pre_map = [asset1, asset2, asset200, asset201]
        post_map = finder.map_identifier_index_to_sids(pre_map, dt)
        self.assertListEqual([1, 2, 200, 201], post_map)
        for sid in post_map:
            self.assertIsInstance(sid, int)

        # Change order and check mapping again
        pre_map = [asset201, asset2, asset200, asset1]
        post_map = finder.map_identifier_index_to_sids(pre_map, dt)
        self.assertListEqual([201, 2, 200, 1], post_map)

    def test_compute_lifetimes(self):
        """
        finder.lifetimes matches a brute-force [start, end] check, both
        with and without each asset's start date included.
        """
        num_assets = 4
        env = TradingEnvironment()
        trading_day = env.trading_day
        first_start = pd.Timestamp('2015-04-01', tz='UTC')

        frame = make_rotating_asset_info(num_assets=num_assets,
                                         first_start=first_start,
                                         frequency=trading_day,
                                         periods_between_starts=3,
                                         asset_lifetime=5)

        env.write_data(equities_df=frame)
        finder = env.asset_finder

        all_dates = pd.date_range(
            start=first_start,
            end=frame.end_date.max(),
            freq=trading_day,
        )

        for dates in all_subindices(all_dates):
            expected_with_start_raw = full(
                shape=(len(dates), num_assets),
                fill_value=False,
                dtype=bool,
            )
            expected_no_start_raw = full(
                shape=(len(dates), num_assets),
                fill_value=False,
                dtype=bool,
            )

            for i, date in enumerate(dates):
                it = frame[['start_date', 'end_date']].itertuples()
                for j, start, end in it:
                    # This way of doing the checks is redundant, but very
                    # clear.
                    if start <= date <= end:
                        expected_with_start_raw[i, j] = True
                        if start < date:
                            expected_no_start_raw[i, j] = True

            expected_with_start = pd.DataFrame(
                data=expected_with_start_raw,
                index=dates,
                columns=frame.index.values,
            )
            result = finder.lifetimes(dates, include_start_date=True)
            assert_frame_equal(result, expected_with_start)

            expected_no_start = pd.DataFrame(
                data=expected_no_start_raw,
                index=dates,
                columns=frame.index.values,
            )
            result = finder.lifetimes(dates, include_start_date=False)
            assert_frame_equal(result, expected_no_start)

    def test_sids(self):
        """Ensure that the sids property of the AssetFinder is functioning."""
        env = TradingEnvironment()
        env.write_data(equities_identifiers=[1, 2, 3])
        sids = env.asset_finder.sids
        self.assertEqual(3, len(sids))
        self.assertIn(1, sids)
        self.assertIn(2, sids)
        self.assertIn(3, sids)
Ejemplo n.º 29
0
class TradingAlgorithm(object):
    """
    Base class for trading algorithms. Inherit and overload
    initialize() and handle_data(data).

    A new algorithm could look like this:
    ```
    from zipline.api import order, symbol

    def initialize(context):
        context.sid = symbol('AAPL')
        context.amount = 100

    def handle_data(context, data):
        sid = context.sid
        amount = context.amount
        order(sid, amount)
    ```
    To run this algorithm, pass these functions to
    TradingAlgorithm:

    my_algo = TradingAlgorithm(initialize, handle_data)
    stats = my_algo.run(data)

    """
    def __init__(self, *args, **kwargs):
        """Initialize sids and other state variables.

        :Arguments:
        :Optional:
            initialize : function
                Function that is called with a single
                argument at the beginning of the simulation.
            handle_data : function
                Function that is called with 2 arguments
                (context and data) on every bar.
            before_trading_start : function
                Called with (context, data). Read from kwargs only when
                both ``initialize`` and ``handle_data`` are given.
            analyze : function
                Called with (context, perf) by ``analyze``. Read from
                kwargs only when both ``initialize`` and ``handle_data``
                are given.
            script : str
                Algoscript that contains initialize and
                handle_data function definition. Cannot be combined with
                the ``initialize``/``handle_data`` function arguments.
            algo_filename : str
                Filename used when compiling ``script``
                (default ``'<string>'``).
            namespace : dict
                Namespace in which ``script`` is executed (default: a new
                empty dict).
            data_frequency : {'daily', 'minute'}
               The duration of the bars.
            capital_base : float <default: DEFAULT_CAPITAL_BASE>
               How much capital to start with.
            instant_fill : bool <default: False>
               Whether to fill orders immediately or on next bar.
            env : TradingEnvironment
                Environment to use. A new TradingEnvironment is created
                when omitted.
            sim_params :
                Simulation parameters. When omitted they are built with
                create_simulation_parameters from ``capital_base``,
                ``start`` and ``end``. Otherwise they are refreshed from
                the environment.
            ffc_loader :
                Loader passed to ``init_engine``.
            blotter : Blotter
                Defaults to a new Blotter.
            history_container_class : type
                Defaults to HistoryContainer.
            platform : str
                Stored as ``_platform`` (default ``'zipline'``).
            asset_finder :
                NOTE(review): listed historically, but this constructor
                does not read it. The finder is taken from the
                TradingEnvironment.
            equities_metadata : can be either:
                            - dict
                            - pandas.DataFrame
                            - object with 'read' property
                If dict is provided, it must have the following structure:
                * keys are the identifiers
                * values are dicts containing the metadata, with the metadata
                  field name as the key
                If pandas.DataFrame is provided, it must have the
                following structure:
                * column names must be the metadata fields
                * index must be the different asset identifiers
                * array contents should be the metadata value
                If an object with a 'read' property is provided, 'read' must
                return rows containing at least one of 'sid' or 'symbol' along
                with the other metadata fields.
            futures_metadata :
                Futures metadata, written to the environment alongside
                ``equities_metadata``.
            identifiers : List
                Any asset identifiers that are not provided in the
                equities_metadata, but will be traded by this TradingAlgorithm

        Positional arguments, and any keyword arguments not consumed above,
        are stored as ``initialize_args`` / ``initialize_kwargs``.

        :Raises:
            ValueError
                If ``script`` lacks a ``handle_data`` definition, or if
                ``script`` is combined with ``initialize`` or
                ``handle_data`` functions.
        """
        self.sources = []

        # List of trading controls to be used to validate orders.
        self.trading_controls = []

        # List of account controls to be checked on each bar.
        self.account_controls = []

        self._recorded_vars = {}
        # NOTE: .get (not .pop) leaves 'namespace' in kwargs, so it is
        # also forwarded via initialize_kwargs.
        self.namespace = kwargs.get('namespace', {})

        self._platform = kwargs.pop('platform', 'zipline')

        self.logger = None

        self.benchmark_return_source = None

        # default components for transact
        self.slippage = VolumeShareSlippage()
        self.commission = PerShare()

        self.instant_fill = kwargs.pop('instant_fill', False)

        # If an env has been provided, pop it
        self.trading_environment = kwargs.pop('env', None)

        if self.trading_environment is None:
            self.trading_environment = TradingEnvironment()

        # Update the TradingEnvironment with the provided asset metadata
        self.trading_environment.write_data(
            equities_data=kwargs.pop('equities_metadata', {}),
            equities_identifiers=kwargs.pop('identifiers', []),
            futures_data=kwargs.pop('futures_metadata', {}),
        )

        # set the capital base
        self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)
        self.sim_params = kwargs.pop('sim_params', None)
        if self.sim_params is None:
            self.sim_params = create_simulation_parameters(
                capital_base=self.capital_base,
                start=kwargs.pop('start', None),
                end=kwargs.pop('end', None),
                env=self.trading_environment,
            )
        else:
            self.sim_params.update_internal_from_env(self.trading_environment)

        # Build a perf_tracker
        self.perf_tracker = PerformanceTracker(sim_params=self.sim_params,
                                               env=self.trading_environment)

        # Pull in the environment's new AssetFinder for quick reference
        self.asset_finder = self.trading_environment.asset_finder
        self.init_engine(kwargs.pop('ffc_loader', None))

        # Maps from name to Term
        self._filters = {}
        self._factors = {}
        self._classifiers = {}

        self.blotter = kwargs.pop('blotter', None)
        if not self.blotter:
            self.blotter = Blotter()

        # Set the dt initally to the period start by forcing it to change
        self.on_dt_changed(self.sim_params.period_start)

        # The symbol lookup date specifies the date to use when resolving
        # symbols to sids, and can be set using set_symbol_lookup_date()
        self._symbol_lookup_date = None

        self.portfolio_needs_update = True
        self.account_needs_update = True
        self.performance_needs_update = True
        self._portfolio = None
        self._account = None

        self.history_container_class = kwargs.pop(
            'history_container_class',
            HistoryContainer,
        )
        self.history_container = None
        self.history_specs = {}

        # If string is passed in, execute and get reference to
        # functions.
        self.algoscript = kwargs.pop('script', None)

        self._initialize = None
        self._before_trading_start = None
        self._analyze = None

        self.event_manager = EventManager()

        if self.algoscript is not None:
            # A script and API functions are mutually exclusive. Without
            # this check, the script's definitions would silently win over
            # functions passed as keyword arguments.
            if kwargs.get('initialize') or kwargs.get('handle_data'):
                raise ValueError(
                    'You can not set script and initialize/handle_data.'
                )
            filename = kwargs.pop('algo_filename', None)
            if filename is None:
                filename = '<string>'
            # NOTE: this executes arbitrary code from the script; only pass
            # trusted algorithm source.
            code = compile(self.algoscript, filename, 'exec')
            exec_(code, self.namespace)
            self._initialize = self.namespace.get('initialize')
            if 'handle_data' not in self.namespace:
                raise ValueError('You must define a handle_data function.')
            else:
                self._handle_data = self.namespace['handle_data']

            self._before_trading_start = \
                self.namespace.get('before_trading_start')
            # Optional analyze function, gets called after run
            self._analyze = self.namespace.get('analyze')

        elif kwargs.get('initialize') and kwargs.get('handle_data'):
            self._initialize = kwargs.pop('initialize')
            self._handle_data = kwargs.pop('handle_data')
            self._before_trading_start = kwargs.pop('before_trading_start',
                                                    None)
            # Optional analyze function, mirroring the script path above.
            self._analyze = kwargs.pop('analyze', None)

        self.event_manager.add_event(
            zipline.utils.events.Event(
                zipline.utils.events.Always(),
                # We pass handle_data.__func__ to get the unbound method.
                # We will explicitly pass the algorithm to bind it again.
                self.handle_data.__func__,
            ),
            prepend=True,
        )

        # If method not defined, NOOP
        if self._initialize is None:
            self._initialize = lambda x: None

        # Alternative way of setting data_frequency for backwards
        # compatibility.
        if 'data_frequency' in kwargs:
            self.data_frequency = kwargs.pop('data_frequency')

        self._most_recent_data = None

        # Prepare the algo for initialization
        self.initialized = False
        self.initialize_args = args
        self.initialize_kwargs = kwargs

    def init_engine(self, loader):
        """
        Build the FFC engine for this algorithm and store it on ``self.engine``.

        With no loader (``None``) a NoOpFFCEngine is used. Otherwise a
        SimpleFFCEngine is built from the loader, the environment's trading
        days and this algorithm's asset finder.
        """
        if loader is None:
            self.engine = NoOpFFCEngine()
            return

        trading_days = self.trading_environment.trading_days
        self.engine = SimpleFFCEngine(loader, trading_days, self.asset_finder)

    def initialize(self, *args, **kwargs):
        """
        Run the user-supplied initialize function inside a ZiplineAPI
        context, so that Zipline API functions can see this algorithm.

        ``args``/``kwargs`` are accepted but not forwarded: the user
        function receives only the algorithm itself.
        """
        with ZiplineAPI(self):
            user_initialize = self._initialize
            user_initialize(self)

    def before_trading_start(self, data):
        if self._before_trading_start is None:
            return

        self._before_trading_start(self, data)

    def handle_data(self, data):
        self._most_recent_data = data
        if self.history_container:
            self.history_container.update(data, self.datetime)

        self._handle_data(self, data)

        # Unlike trading controls which remain constant unless placing an
        # order, account controls can change each bar. Thus, must check
        # every bar no matter if the algorithm places an order or not.
        self.validate_account_controls()

    def analyze(self, perf):
        if self._analyze is None:
            return

        with ZiplineAPI(self):
            self._analyze(self, perf)

    def __repr__(self):
        """
        N.B. this does not yet represent a string that can be used
        to instantiate an exact copy of an algorithm.

        However, it is getting close, and provides some value as something
        that can be inspected interactively.
        """
        return """
{class_name}(
    capital_base={capital_base}
    sim_params={sim_params},
    initialized={initialized},
    slippage={slippage},
    commission={commission},
    blotter={blotter},
    recorded_vars={recorded_vars})
""".strip().format(class_name=self.__class__.__name__,
                   capital_base=self.capital_base,
                   sim_params=repr(self.sim_params),
                   initialized=self.initialized,
                   slippage=repr(self.slippage),
                   commission=repr(self.commission),
                   blotter=repr(self.blotter),
                   recorded_vars=repr(self.recorded_vars))

    def _create_data_generator(self, source_filter, sim_params=None):
        """
        Create a merged data generator using the sources attached to this
        algorithm.

        ::source_filter:: is a method that receives events in date
        sorted order, and returns True for those events that should be
        processed by the zipline, and False for those that should be
        skipped.

        ::sim_params:: simulation parameters used to bound the benchmark
        events; defaults to ``self.sim_params``.

        Returns the merged benchmark and source events grouped by their
        ``dt`` attribute (itertools-style ``groupby``).
        """
        if sim_params is None:
            sim_params = self.sim_params

        if self.benchmark_return_source is None:
            if sim_params.data_frequency == 'minute' or \
               sim_params.emission_rate == 'minute':

                # Minute mode: stamp each daily benchmark return at the
                # session close.
                def update_time(date):
                    return self.trading_environment.get_open_and_close(date)[1]
            else:

                def update_time(date):
                    return date

            # Hoisted: the simulation bounds do not change per event.
            period_start = sim_params.period_start.date()
            period_end = sim_params.period_end.date()

            # .items() rather than .iteritems(): the latter was removed
            # from pandas Series in pandas 2.0.
            benchmark_return_source = [
                Event({
                    'dt': update_time(dt),
                    'returns': ret,
                    'type': zipline.protocol.DATASOURCE_TYPE.BENCHMARK,
                    'source_id': 'benchmarks'
                }) for dt, ret in
                self.trading_environment.benchmark_returns.items()
                if period_start <= dt.date() <= period_end
            ]
        else:
            benchmark_return_source = self.benchmark_return_source

        date_sorted = date_sorted_sources(*self.sources)

        if source_filter:
            date_sorted = filter(source_filter, date_sorted)

        with_benchmarks = date_sorted_sources(benchmark_return_source,
                                              date_sorted)

        # Group together events with the same dt field. This depends on the
        # events already being sorted.
        return groupby(with_benchmarks, attrgetter('dt'))

    def _create_generator(self, sim_params, source_filter=None):
        """
        Build the simulation generator from this algorithm's sources.

        Runs ``initialize`` once if it has not run yet, creates a
        PerformanceTracker if none is set, and wires up the data generator,
        the AlgorithmSimulator and the transact method. Returns the
        simulator's transform over the data generator.

        ::source_filter:: is a method that receives events in date
        sorted order, and returns True for those events that should be
        processed by the zipline, and False for those that should be
        skipped.
        """
        if not self.initialized:
            self.initialize(*self.initialize_args, **self.initialize_kwargs)
            self.initialized = True

        # HACK: When running with the `run` method, we set perf_tracker to
        # None so that it will be overwritten here.
        if self.perf_tracker is None:
            self.perf_tracker = PerformanceTracker(
                sim_params=sim_params,
                env=self.trading_environment,
            )

        # Reset the *_needs_update flags (presumably marking cached
        # portfolio/account/performance state as stale).
        self.portfolio_needs_update = True
        self.account_needs_update = True
        self.performance_needs_update = True

        self.data_gen = self._create_data_generator(source_filter, sim_params)
        self.trading_client = AlgorithmSimulator(self, sim_params)
        self.set_transact(transact_partial(self.slippage, self.commission))

        return self.trading_client.transform(self.data_gen)

    def get_generator(self):
        """
        Override this method to add new logic to the construction
        of the generator. Overrides can use the _create_generator
        method to get a standard construction generator.
        """
        return self._create_generator(self.sim_params)

    # TODO: make a new subclass, e.g. BatchAlgorithm, and move
    # the run method to the subclass, and refactor to put the
    # generator creation logic into get_generator.
    def run(self,
            source,
            overwrite_sim_params=True,
            benchmark_return_source=None):
        """Run the algorithm.

        :Arguments:
            source : can be either:
                     - pandas.DataFrame
                     - zipline source
                     - list of sources

               If pandas.DataFrame is provided, it must have the
               following structure:
               * column names must be the different asset identifiers
               * index must be DatetimeIndex
               * array contents should be price info.

        :Returns:
            daily_stats : pandas.DataFrame
              Daily performance metrics such as returns, alpha etc.

        """

        # Ensure that source is a DataSource object
        if isinstance(source, list):
            if overwrite_sim_params:
                warnings.warn(
                    """List of sources passed, will not attempt to extract start and end
 dates. Make sure to set the correct fields in sim_params passed to
 __init__().""", UserWarning)
                overwrite_sim_params = False
        elif isinstance(source, pd.DataFrame):
            # if DataFrame provided, map columns to sids and wrap
            # in DataFrameSource
            copy_frame = source.copy()

            # Build new Assets for identifiers that can't be resolved as
            # sids/Assets
            identifiers_to_build = []
            for identifier in source.columns:
                if hasattr(identifier, '__int__'):
                    asset = self.asset_finder.retrieve_asset(sid=identifier,
                                                             default_none=True)
                    if asset is None:
                        identifiers_to_build.append(identifier)
                else:
                    identifiers_to_build.append(identifier)

            self.trading_environment.write_data(
                equities_identifiers=identifiers_to_build)
            copy_frame.columns = \
                self.asset_finder.map_identifier_index_to_sids(
                    source.columns, source.index[0]
                )
            source = DataFrameSource(copy_frame)

        elif isinstance(source, pd.Panel):
            # If Panel provided, map items to sids and wrap
            # in DataPanelSource
            copy_panel = source.copy()

            # Build new Assets for identifiers that can't be resolved as
            # sids/Assets
            identifiers_to_build = []
            for identifier in source.items:
                if hasattr(identifier, '__int__'):
                    asset = self.asset_finder.retrieve_asset(sid=identifier,
                                                             default_none=True)
                    if asset is None:
                        identifiers_to_build.append(identifier)
                else:
                    identifiers_to_build.append(identifier)

            self.trading_environment.write_data(
                equities_identifiers=identifiers_to_build)
            copy_panel.items = self.asset_finder.map_identifier_index_to_sids(
                source.items, source.major_axis[0])
            source = DataPanelSource(copy_panel)

        if isinstance(source, list):
            self.set_sources(source)
        else:
            self.set_sources([source])

        # Override sim_params if params are provided by the source.
        if overwrite_sim_params:
            if hasattr(source, 'start'):
                self.sim_params.period_start = source.start
            if hasattr(source, 'end'):
                self.sim_params.period_end = source.end
            # Changing period_start and period_close might require updating
            # of first_open and last_close.
            self.sim_params.update_internal_from_env(
                env=self.trading_environment)

        # The sids field of the source is the reference for the universe at
        # the start of the run
        self._current_universe = set()
        for source in self.sources:
            for sid in source.sids:
                self._current_universe.add(sid)
        # Check that all sids from the source are accounted for in
        # the AssetFinder. This retrieve call will raise an exception if the
        # sid is not found.
        for sid in self._current_universe:
            self.asset_finder.retrieve_asset(sid)

        # force a reset of the performance tracker, in case
        # this is a repeat run of the algorithm.
        self.perf_tracker = None

        # create zipline
        self.gen = self._create_generator(self.sim_params)

        # Create history containers
        if self.history_specs:
            self.history_container = self.history_container_class(
                self.history_specs,
                self.current_universe(),
                self.sim_params.first_open,
                self.sim_params.data_frequency,
                self.trading_environment,
            )

        # loop through simulated_trading, each iteration returns a
        # perf dictionary
        perfs = []
        for perf in self.gen:
            perfs.append(perf)

        # convert perf dict to pandas dataframe
        daily_stats = self._create_daily_stats(perfs)

        self.analyze(daily_stats)

        return daily_stats

    def _create_daily_stats(self, perfs):
        """
        Build a DataFrame of daily stats from the emitted perf packets.

        Packets with a 'daily_perf' entry contribute one row each. Each
        row's 'recorded_vars' is popped and merged into it, and its
        'cumulative_risk_metrics' are merged in as well; this mutates those
        dicts in place. Any packet without 'daily_perf' is stored as
        ``self.risk_report``, and the last such packet wins. Rows are indexed
        by their 'period_close'.
        """
        # TODO: merging here could overwrite expected keys of daily_perf;
        # consider raising or logging a warning.
        rows = []
        for packet in perfs:
            if 'daily_perf' not in packet:
                self.risk_report = packet
                continue
            day = packet['daily_perf']
            day.update(day.pop('recorded_vars'))
            day.update(packet['cumulative_risk_metrics'])
            rows.append(day)

        # NOTE(review): `utc` is not part of NumPy's documented datetime64
        # signature; confirm the pinned NumPy version accepts it.
        index = [np.datetime64(row['period_close'], utc=True) for row in rows]
        return pd.DataFrame(rows, index=index)

    @api_method
    def add_transform(self, transform, days=None):
        """
        Ensure the history container can service a simple transform.

        The method registers the history specs the transform needs via
        ``add_history``.

        :Arguments:
            transform : string
                The transform to add. Must be one of
                {'mavg', 'stddev', 'vwap', 'returns'}.
            days : int <default=None>
                The maximum number of days of data the transform will need.
                It is required for every transform except 'returns', which
                must not be given a value.

        :Raises:
            ValueError
                If ``transform`` is unknown, if ``days`` is passed for
                'returns', or if ``days`` is missing for any other transform.
        """
        if transform not in {'mavg', 'stddev', 'vwap', 'returns'}:
            raise ValueError('Invalid transform')

        if transform == 'returns':
            if days is not None:
                # 'returns' is always served from the last 2 daily prices.
                raise ValueError('returns does not use days')
            self.add_history(2, '1d', 'price')
            return

        if days is None:
            raise ValueError('no number of days specified')

        if self.sim_params.data_frequency == 'daily':
            bars, freq = 1 * days, '1d'
        else:
            # Minute mode assumes 390 minute bars per trading day.
            bars, freq = 390 * days, '1m'

        self.add_history(bars, freq, 'price')

        if transform == 'vwap':
            # VWAP additionally needs volume over the same window.
            self.add_history(bars, freq, 'volume')

    @api_method
    def get_environment(self, field='platform'):
        """
        Return information about the running environment.

        :Arguments:
            field : str <default='platform'>
                One of 'arena', 'data_frequency', 'start', 'end',
                'capital_base' or 'platform' to get a single value, or '*'
                to get the whole mapping.

        :Raises:
            KeyError if ``field`` is neither '*' nor one of the keys above.
        """
        params = self.sim_params
        env = {
            'arena': params.arena,
            'data_frequency': params.data_frequency,
            'start': params.first_open,
            'end': params.last_close,
            'capital_base': params.capital_base,
            'platform': self._platform,
        }
        return env if field == '*' else env[field]

    def add_event(self, rule=None, callback=None):
        """
        Wrap ``rule`` and ``callback`` in an Event and register it with the
        algorithm's EventManager.
        """
        event = zipline.utils.events.Event(rule, callback)
        self.event_manager.add_event(event)

    @api_method
    def schedule_function(self,
                          func,
                          date_rule=None,
                          time_rule=None,
                          half_days=True):
        """
        Schedule ``func`` to be called according to date and time rules.

        A falsy ``date_rule`` defaults to every day. In minute mode a falsy
        ``time_rule`` defaults to market open. In any other data frequency
        ``time_rule`` is ignored and replaced by an always-true rule.
        """
        date_rule = date_rule or DateRuleFactory.every_day()

        if self.sim_params.data_frequency == 'minute':
            time_rule = time_rule or TimeRuleFactory.market_open()
        else:
            # Daily mode has no intraday timing, so time_rule is ignored.
            time_rule = zipline.utils.events.Always()

        rule = make_eventrule(date_rule, time_rule, half_days)
        self.add_event(rule, func)

    @api_method
    def record(self, *args, **kwargs):
        """
        Record named values for the current bar.

        Positional arguments are consumed as alternating name/value pairs,
        e.g. ``record('a', 1, 'b', 2)``. A trailing unpaired positional
        argument is ignored. Keyword arguments are recorded as name=value.
        Positional pairs are applied before keywords, and each value
        overwrites any previous entry of the same name in
        ``self._recorded_vars``.
        """
        stream = iter(args)
        # zip() pulls from the same iterator twice per tuple, so consecutive
        # items are paired: (a, b), (c, d), ... rather than (a, a), (b, b).
        pairs = zip(stream, stream)
        for name, value in chain(pairs, iteritems(kwargs)):
            self._recorded_vars[name] = value

    @api_method
    def symbol(self, symbol_str):
        """
        Look up the Asset for ``symbol_str``.

        Resolution happens as of the date set by ``set_symbol_lookup_date``.
        If no date has been set, ``sim_params.period_end`` is used.
        """
        lookup_date = self._symbol_lookup_date
        if lookup_date is None:
            lookup_date = self.sim_params.period_end

        return self.asset_finder.lookup_symbol_resolve_multiple(
            symbol_str, as_of_date=lookup_date)

    @api_method
    def symbols(self, *args):
        """
        Look up an Asset for each symbol in ``args``.

        Returns a list in argument order, each entry resolved via
        ``self.symbol``.
        """
        return list(map(self.symbol, args))

    @api_method
    def sid(self, a_sid):
        """
        Return the Asset whose integer sid is ``a_sid``, as retrieved by the
        asset finder.
        """
        finder = self.asset_finder
        return finder.retrieve_asset(a_sid)

    @api_method
    def future_chain(self, root_symbol, as_of_date=None):
        """ Look up a future chain with the specified parameters.

        Parameters
        ----------
        root_symbol : str
            The root symbol of a future chain. It is upper-cased before the
            lookup.
        as_of_date : datetime.datetime or pandas.Timestamp or str, optional
            Date at which the chain determination is rooted, i.e. the
            existing contract whose notice date is first after this date is
            the primary contract, etc. A truthy value is converted to a
            UTC pandas.Timestamp; a falsy value is passed through unchanged.

        Returns
        -------
        FutureChain
            The future chain matching the specified parameters.

        Raises
        ------
        UnsupportedDatetimeFormat
            If ``as_of_date`` cannot be converted to a Timestamp.
        RootSymbolNotFound
            If a future chain could not be found for the given root symbol.
        """
        chain_date = as_of_date
        if chain_date:
            try:
                chain_date = pd.Timestamp(chain_date, tz='UTC')
            except ValueError:
                raise UnsupportedDatetimeFormat(input=as_of_date,
                                                method='future_chain')

        return FutureChain(
            asset_finder=self.asset_finder,
            get_datetime=self.get_datetime,
            root_symbol=root_symbol.upper(),
            as_of_date=chain_date,
        )

    def _calculate_order_value_amount(self, asset, value):
        """
        Convert a desired value (or exposure, for Futures) into a quantity.

        Returns ``value / (price * multiplier)``. The price is the asset's
        current price, and the multiplier is the contract multiplier for
        Futures or 1 otherwise. If the price is 0 (within tolerance), 0 is
        returned and, when a logger is set, a debug message is logged.
        """
        price = self.trading_client.current_data[asset].price

        if tolerant_equals(price, 0):
            if self.logger:
                self.logger.debug(
                    "Price of 0 for {psid}; can't infer value".format(
                        psid=asset))
            # Don't place any order
            return 0

        multiplier = (asset.contract_multiplier
                      if isinstance(asset, Future) else 1)
        return value / (price * multiplier)

    @api_method
    def order(self,
              sid,
              amount,
              limit_price=None,
              stop_price=None,
              style=None):
        """
        Place an order for ``amount`` units of ``sid``.

        ``amount`` is first snapped to the nearest integer when it lies
        within 1e-4 of one, then truncated toward zero (3.9999 -> 4,
        5.5 -> 5, -5.5 -> -5). The parameters are then validated with
        ``validate_order_params``, and the deprecated ``limit_price`` /
        ``stop_price`` arguments are folded into an ExecutionStyle. The
        result of ``self.blotter.order`` is returned.
        """
        nearest = round(amount)
        if abs(amount - nearest) <= 1e-4:
            amount = nearest
        amount = int(amount)

        # Raises a ZiplineError if invalid parameters are detected.
        self.validate_order_params(sid, amount, limit_price, stop_price, style)

        execution_style = self.__convert_order_params_for_blotter(
            limit_price, stop_price, style)
        return self.blotter.order(sid, amount, execution_style)

    def validate_order_params(self, asset, amount, limit_price, stop_price,
                              style):
        """
        Validate the arguments of an order API call.

        Raises OrderDuringInitialize if the algorithm is not yet initialized.
        Raises UnsupportedOrderParameters if ``style`` is combined with a
        limit or stop price, or if ``asset`` is not an Asset. Finally every
        registered trading control validates the order; each may raise.
        """
        if not self.initialized:
            raise OrderDuringInitialize(
                msg="order() can only be called from within handle_data()")

        if style and limit_price:
            raise UnsupportedOrderParameters(
                msg="Passing both limit_price and style is not supported.")
        if style and stop_price:
            raise UnsupportedOrderParameters(
                msg="Passing both stop_price and style is not supported.")

        if not isinstance(asset, Asset):
            raise UnsupportedOrderParameters(
                msg="Passing non-Asset argument to 'order()' is not supported."
                " Use 'sid()' or 'symbol()' methods to look up an Asset.")

        # Portfolio/datetime are fetched per control, exactly as many times
        # as there are controls.
        for trading_control in self.trading_controls:
            trading_control.validate(
                asset,
                amount,
                self.updated_portfolio(),
                self.get_datetime(),
                self.trading_client.current_data,
            )

    @staticmethod
    def __convert_order_params_for_blotter(limit_price, stop_price, style):
        """
        Translate deprecated limit/stop arguments into an ExecutionStyle.

        If ``style`` is truthy it is returned unchanged; the caller must then
        have left both prices as None. Otherwise the truthy prices select
        StopLimitOrder, LimitOrder, StopOrder or, if neither is set,
        MarketOrder.
        """
        # TODO_SS: DeprecationWarning for usage of limit_price and stop_price.
        if style:
            assert (limit_price, stop_price) == (None, None)
            return style
        if limit_price:
            if stop_price:
                return StopLimitOrder(limit_price, stop_price)
            return LimitOrder(limit_price)
        if stop_price:
            return StopOrder(stop_price)
        return MarketOrder()

    @api_method
    def order_value(self,
                    sid,
                    value,
                    limit_price=None,
                    stop_price=None,
                    style=None):
        """
        Place an order sized by value rather than by share count.

        The value is divided by the asset's current price (times the
        contract multiplier for Futures, where 'value' really means
        exposure) to get the quantity, which is then passed to ``order``.

        value > 0 :: Buy/Cover
        value < 0 :: Sell/Short
        Market order:    order(sid, value)
        Limit order:     order(sid, value, limit_price)
        Stop order:      order(sid, value, None, stop_price)
        StopLimit order: order(sid, value, limit_price, stop_price)
        """
        quantity = self._calculate_order_value_amount(sid, value)
        return self.order(
            sid,
            quantity,
            limit_price=limit_price,
            stop_price=stop_price,
            style=style,
        )

    @property
    def recorded_vars(self):
        """Return a shallow copy of the values recorded so far."""
        snapshot = copy(self._recorded_vars)
        return snapshot

    @property
    def portfolio(self):
        """The current portfolio, refreshed lazily via updated_portfolio()."""
        current = self.updated_portfolio()
        return current

    def updated_portfolio(self):
        """
        Return the cached portfolio, recomputing it if flagged as stale.

        On a refresh the perf tracker is asked for the portfolio, with
        ``performance_needs_update`` forwarded. Both the portfolio flag and
        the shared performance flag are then cleared.
        """
        if not self.portfolio_needs_update:
            return self._portfolio

        self._portfolio = self.perf_tracker.get_portfolio(
            self.performance_needs_update)
        self.portfolio_needs_update = False
        self.performance_needs_update = False
        return self._portfolio

    @property
    def account(self):
        """The current account, refreshed lazily via updated_account()."""
        current = self.updated_account()
        return current

    def updated_account(self):
        """
        Return the cached account, recomputing it if flagged as stale.

        On a refresh the perf tracker is asked for the account, with
        ``performance_needs_update`` forwarded. Both the account flag and
        the shared performance flag are then cleared.
        """
        if not self.account_needs_update:
            return self._account

        self._account = self.perf_tracker.get_account(
            self.performance_needs_update)
        self.account_needs_update = False
        self.performance_needs_update = False
        return self._account

    def set_logger(self, logger):
        """Attach ``logger``; it is used for debug messages when truthy."""
        setattr(self, 'logger', logger)

    def on_dt_changed(self, dt):
        """
        Advance the algorithm's clock to ``dt``.

        Called by the simulation loop whenever the current dt changes; any
        logic that should run exactly once at the start of each datetime
        group belongs here. ``dt`` must be a UTC datetime. It is stored as
        ``self.datetime`` and pushed to the perf tracker and then the
        blotter.
        """
        assert isinstance(dt, datetime), \
            "Attempt to set algorithm's current time with non-datetime"
        assert dt.tzinfo == pytz.utc, \
            "Algorithm expects a utc datetime"

        self.datetime = dt
        for component in (self.perf_tracker, self.blotter):
            component.set_date(dt)

    @api_method
    def get_datetime(self, tz=None):
        """
        Return the current simulation datetime.

        The datetime is returned in UTC by default. If ``tz`` is given, as
        a tzinfo or a timezone name string, it is converted to that zone.
        """
        now = self.datetime
        assert now.tzinfo == pytz.utc, "Algorithm should have a utc datetime"

        if tz is None:
            # datetime.datetime objects are immutable; no copy needed.
            return now

        if isinstance(tz, string_types):
            tz = pytz.timezone(tz)
        return now.astimezone(tz)

    def set_transact(self, transact):
        """
        Install ``transact`` on the blotter as the function that turns open
        orders and trade events into transactions.
        """
        blotter = self.blotter
        blotter.transact = transact

    def update_dividends(self, dividend_frame):
        """
        Forward the DataFrame used to process dividends to the perf tracker.

        Its columns should contain at least the entries in
        zp.DIVIDEND_FIELDS.
        """
        tracker = self.perf_tracker
        tracker.update_dividends(dividend_frame)

    @api_method
    def set_slippage(self, slippage):
        """
        Set the slippage model.

        Raises UnsupportedSlippageModel if ``slippage`` is not a
        SlippageModel, and OverrideSlippagePostInit if the algorithm is
        already initialized.
        """
        is_model = isinstance(slippage, SlippageModel)
        if not is_model:
            raise UnsupportedSlippageModel()
        if self.initialized:
            raise OverrideSlippagePostInit()
        self.slippage = slippage

    @api_method
    def set_commission(self, commission):
        """
        Set the commission model.

        Raises UnsupportedCommissionModel unless ``commission`` is a
        PerShare, PerTrade or PerDollar instance, and
        OverrideCommissionPostInit if the algorithm is already initialized.
        """
        supported = (PerShare, PerTrade, PerDollar)
        if not isinstance(commission, supported):
            raise UnsupportedCommissionModel()
        if self.initialized:
            raise OverrideCommissionPostInit()
        self.commission = commission

    @api_method
    def set_symbol_lookup_date(self, dt):
        """
        Set the date used to resolve symbols to sids.

        Symbols may map to different firms or underlying assets at different
        times. ``dt`` is converted to a UTC pandas.Timestamp. If that
        conversion raises ValueError, UnsupportedDatetimeFormat is raised
        instead.
        """
        try:
            lookup_date = pd.Timestamp(dt, tz='UTC')
        except ValueError:
            raise UnsupportedDatetimeFormat(input=dt,
                                            method='set_symbol_lookup_date')
        self._symbol_lookup_date = lookup_date

    def set_sources(self, sources):
        """Replace the algorithm's data sources; ``sources`` must be a list."""
        assert isinstance(sources, list)
        setattr(self, 'sources', sources)

    # Kept for backwards compatibility
    @property
    def data_frequency(self):
        """The simulation data frequency, read from sim_params."""
        params = self.sim_params
        return params.data_frequency

    @data_frequency.setter
    def data_frequency(self, value):
        """Set sim_params.data_frequency; only 'daily' or 'minute' allowed."""
        allowed = ('daily', 'minute')
        assert value in allowed
        self.sim_params.data_frequency = value

    @api_method
    def order_percent(self,
                      sid,
                      percent,
                      limit_price=None,
                      stop_price=None,
                      style=None):
        """
        Order ``sid`` worth ``percent`` of the current portfolio value.

        ``percent`` is a decimal fraction: 0.50 means 50 percent. The
        resulting value is passed to ``order_value``.
        """
        notional = self.portfolio.portfolio_value * percent
        return self.order_value(
            sid,
            notional,
            limit_price=limit_price,
            stop_price=stop_price,
            style=style,
        )

    @api_method
    def order_target(self,
                     sid,
                     target,
                     limit_price=None,
                     stop_price=None,
                     style=None):
        """
        Order enough shares to bring the position in ``sid`` to ``target``.

        If the portfolio already holds ``sid``, the order is for the
        difference between ``target`` and the current amount. Otherwise the
        order is for ``target`` itself.
        """
        if sid in self.portfolio.positions:
            held = self.portfolio.positions[sid].amount
            quantity = target - held
        else:
            quantity = target

        return self.order(
            sid,
            quantity,
            limit_price=limit_price,
            stop_price=stop_price,
            style=style,
        )

    @api_method
    def order_target_value(self,
                           sid,
                           target,
                           limit_price=None,
                           stop_price=None,
                           style=None):
        """
        Adjust the position in ``sid`` to a target value.

        The target value is converted into a target quantity at the current
        price, then delegated to ``order_target``. For Futures the 'value'
        is really the target exposure.
        """
        target_quantity = self._calculate_order_value_amount(sid, target)
        return self.order_target(
            sid,
            target_quantity,
            limit_price=limit_price,
            stop_price=stop_price,
            style=style,
        )

    @api_method
    def order_target_percent(self,
                             sid,
                             target,
                             limit_price=None,
                             stop_price=None,
                             style=None):
        """
        Adjust the position in ``sid`` to a target percent of portfolio
        value.

        ``target`` is a decimal fraction: 0.50 means 50 percent. It is
        converted to a value and delegated to ``order_target_value``.
        """
        desired_value = self.portfolio.portfolio_value * target
        return self.order_target_value(
            sid,
            desired_value,
            limit_price=limit_price,
            stop_price=stop_price,
            style=style,
        )

    @api_method
    def get_open_orders(self, sid=None):
        """
        Return open orders as API objects.

        With ``sid`` given, a list of that asset's open orders is returned,
        or [] if the blotter has no entry for it. Without ``sid``, a dict
        mapping each asset to its list of open orders is returned, skipping
        assets whose order list is empty.
        """
        open_orders = self.blotter.open_orders

        if sid is not None:
            if sid not in open_orders:
                return []
            return [o.to_api_obj() for o in open_orders[sid]]

        return {
            asset: [o.to_api_obj() for o in asset_orders]
            for asset, asset_orders in iteritems(open_orders)
            if asset_orders
        }

    @api_method
    def get_order(self, order_id):
        """
        Return the API object for the order ``order_id``, or None if the
        blotter has no such order.
        """
        orders = self.blotter.orders
        if order_id not in orders:
            return None
        return orders[order_id].to_api_obj()

    @api_method
    def cancel_order(self, order_param):
        """
        Cancel an order given either a zipline.protocol.Order or its id.
        """
        if isinstance(order_param, zipline.protocol.Order):
            order_id = order_param.id
        else:
            order_id = order_param
        self.blotter.cancel(order_id)

    @api_method
    def add_history(self, bar_count, frequency, field, ffill=True):
        """
        Register a HistorySpec for later history() calls.

        The spec is stored in ``self.history_specs`` under its key_str.
        Nothing more happens before initialization. After initialization,
        the existing history container is asked to ensure the spec, or a
        container is created if none exists yet.
        """
        spec = HistorySpec(bar_count,
                           frequency,
                           field,
                           ffill,
                           data_frequency=self.sim_params.data_frequency,
                           env=self.trading_environment)
        self.history_specs[spec.key_str] = spec

        if not self.initialized:
            return

        if self.history_container:
            self.history_container.ensure_spec(
                spec,
                self.datetime,
                self._most_recent_data,
            )
        else:
            self.history_container = self.history_container_class(
                self.history_specs,
                self.current_universe(),
                self.sim_params.first_open,
                self.sim_params.data_frequency,
                env=self.trading_environment,
            )

    def get_history_spec(self, bar_count, frequency, field, ffill):
        """
        Return the HistorySpec for these parameters, creating it on demand.

        A new spec is stored in ``self.history_specs``. A history container
        is created if none exists yet, and the container is then asked to
        ensure the spec at the current datetime.
        """
        key = HistorySpec.spec_key(bar_count, frequency, field, ffill)
        if key in self.history_specs:
            return self.history_specs[key]

        new_spec = HistorySpec(
            bar_count,
            frequency,
            field,
            ffill,
            data_frequency=self.sim_params.data_frequency,
            env=self.trading_environment,
        )
        self.history_specs[key] = new_spec

        if not self.history_container:
            self.history_container = self.history_container_class(
                self.history_specs,
                self.current_universe(),
                self.datetime,
                self.sim_params.data_frequency,
                bar_data=self._most_recent_data,
                env=self.trading_environment,
            )
        self.history_container.ensure_spec(
            new_spec,
            self.datetime,
            self._most_recent_data,
        )
        return self.history_specs[key]

    @api_method
    def history(self, bar_count, frequency, field, ffill=True):
        """
        Return trailing history for ``field`` at the current datetime.

        The matching HistorySpec is looked up or created via
        ``get_history_spec`` and served by the history container.
        """
        spec = self.get_history_spec(bar_count, frequency, field, ffill)
        container = self.history_container
        return container.get_history(spec, self.datetime)

    ####################
    # Account Controls #
    ####################

    def register_account_control(self, control):
        """
        Add an AccountControl to be checked on each bar.

        Raises RegisterAccountControlPostInit once the algorithm is
        initialized.
        """
        if not self.initialized:
            self.account_controls.append(control)
            return
        raise RegisterAccountControlPostInit()

    def validate_account_controls(self):
        """
        Run every registered account control against the current portfolio,
        account, datetime and current data; each control may raise.
        """
        for account_control in self.account_controls:
            account_control.validate(
                self.updated_portfolio(),
                self.updated_account(),
                self.get_datetime(),
                self.trading_client.current_data,
            )

    @api_method
    def set_max_leverage(self, max_leverage=None):
        """
        Register a MaxLeverage account control built from ``max_leverage``.
        """
        self.register_account_control(MaxLeverage(max_leverage))

    ####################
    # Trading Controls #
    ####################

    def register_trading_control(self, control):
        """
        Add a TradingControl to be checked before orders are placed.

        Raises RegisterTradingControlPostInit once the algorithm is
        initialized.
        """
        if not self.initialized:
            self.trading_controls.append(control)
            return
        raise RegisterTradingControlPostInit()

    @api_method
    def set_max_position_size(self,
                              sid=None,
                              max_shares=None,
                              max_notional=None):
        """
        Limit the shares and/or dollar value held in ``sid``.

        Limits are absolute values, enforced when the algorithm tries to
        place an order for ``sid``. Positions can therefore still exceed
        the share limit through splits/dividends, and exceed the notional
        limit through price improvement. An order that would increase the
        absolute position past a limit raises a TradingControlException.
        """
        self.register_trading_control(
            MaxPositionSize(asset=sid,
                            max_shares=max_shares,
                            max_notional=max_notional))

    @api_method
    def set_max_order_size(self, sid=None, max_shares=None, max_notional=None):
        """
        Limit the shares and/or dollar value of any single order for ``sid``.

        Limits are absolute values, enforced when the algorithm tries to
        place an order. An order exceeding a limit raises a
        TradingControlException.
        """
        self.register_trading_control(
            MaxOrderSize(asset=sid,
                         max_shares=max_shares,
                         max_notional=max_notional))

    @api_method
    def set_max_order_count(self, max_count):
        """
        Limit how many orders may be placed, by registering a MaxOrderCount
        trading control built from ``max_count``.
        """
        self.register_trading_control(MaxOrderCount(max_count))

    @api_method
    def set_do_not_order_list(self, restricted_list):
        """
        Forbid ordering the sids in ``restricted_list`` by registering a
        RestrictedListOrder trading control.
        """
        self.register_trading_control(RestrictedListOrder(restricted_list))

    @api_method
    def set_long_only(self):
        """
        Forbid short positions by registering a LongOnly trading control.
        """
        long_only = LongOnly()
        self.register_trading_control(long_only)

    ###########
    # FFC API #
    ###########
    @api_method
    @require_not_initialized(AddTermPostInit())
    def add_factor(self, factor, name):
        """
        Register ``factor`` under ``name``.

        Raises ValueError if ``name`` is already registered as a factor.
        """
        registered = self._factors
        if name in registered:
            raise ValueError("Name %r is already a factor!" % name)
        registered[name] = factor

    @api_method
    @require_not_initialized(AddTermPostInit())
    def add_filter(self, filter):
        """
        Register ``filter`` under an auto-generated name.

        Names are ``anon_filter_<n>``, where ``<n>`` is the number of filters
        registered before this one.
        """
        position = len(self._filters)
        self._filters["anon_filter_%d" % position] = filter

    # Note: add_classifier is not yet implemented since you can't do anything
    # useful with classifiers yet.

    def _all_terms(self):
        # Merge all three dicts.
        return dict(
            chain.from_iterable(
                iteritems(terms) for terms in (self._filters, self._factors,
                                               self._classifiers)))

    def compute_factor_matrix(self, start_date):
        """
        Compute a factor matrix for all registered terms from ``start_date``.

        The window ends 252 trading days after ``start_date``, or at the
        session containing the simulation's last close if that comes
        earlier.

        Returns a ``(factor_matrix, end_date)`` tuple.
        """
        sessions = self.trading_environment.trading_days
        first_idx = sessions.get_loc(start_date)
        final_session = self.sim_params.last_close.normalize()
        # 252 sessions is presumably meant as roughly one trading year.
        last_idx = min(first_idx + 252, sessions.get_loc(final_session))
        end_date = sessions[last_idx]
        matrix = self.engine.factor_matrix(
            self._all_terms(),
            start_date,
            end_date,
        )
        return matrix, end_date

    def current_universe(self):
        """Return the set of sids that make up the current universe."""
        universe = self._current_universe
        return universe

    @classmethod
    def all_api_methods(cls):
        """
        Return a list of the class attributes flagged as API methods, i.e.
        those whose ``is_api_method`` attribute is truthy.
        """
        return list(filter(
            lambda member: getattr(member, 'is_api_method', False),
            itervalues(vars(cls)),
        ))
Ejemplo n.º 30
0
    def transaction_sim(self, **params):
        """Assert the expected conversion of orders into transactions.

        Builds a constant-price (10.1) bar history for a single sid. Minute
        bars are used when ``trade_interval`` is shorter than one day, and
        daily bars otherwise. Every bar is then replayed through a
        ``Blotter``, which places market orders on the requested schedule
        and collects the resulting transactions.

        Required ``params`` keys:
            trade_count, trade_interval : size the minute-bar window;
                ``trade_interval`` also selects minute vs. daily bars.
            order_count, order_amount, order_interval : the order schedule.
            expected_txn_count, expected_txn_volume : expected results.

        Optional ``params`` keys:
            alternate : if truthy, alternate between long and short orders.
            complete_fill : if truthy, expect exactly one transaction per
                order, each matching the ordered amount.
            default_slippage : if truthy, pass ``None`` as the blotter's
                slippage model instead of ``FixedSlippage()``.
        """
        tempdir = TempDirectory()
        try:
            trade_count = params['trade_count']
            trade_interval = params['trade_interval']
            order_count = params['order_count']
            order_amount = params['order_amount']
            order_interval = params['order_interval']
            expected_txn_count = params['expected_txn_count']
            expected_txn_volume = params['expected_txn_volume']

            # optional parameters
            # ---------------------
            # if present, alternate between long and short sales
            alternate = params.get('alternate')

            # if present, expect transaction amounts to match orders exactly.
            complete_fill = params.get('complete_fill')

            env = TradingEnvironment()

            sid = 1

            if trade_interval < timedelta(days=1):
                sim_params = factory.create_simulation_parameters(
                    data_frequency="minute")

                # Enough minutes to cover every trade, plus 100 of slack.
                minutes = env.market_minute_window(
                    sim_params.first_open,
                    int((trade_interval.total_seconds() / 60) * trade_count) +
                    100)

                price_data = np.array([10.1] * len(minutes))
                assets = {
                    sid:
                    pd.DataFrame({
                        "open": price_data,
                        "high": price_data,
                        "low": price_data,
                        "close": price_data,
                        "volume": np.array([100] * len(minutes)),
                        "dt": minutes
                    }).set_index("dt")
                }

                write_bcolz_minute_data(
                    env, env.days_in_range(minutes[0], minutes[-1]),
                    tempdir.path, assets)

                equity_minute_reader = BcolzMinuteBarReader(tempdir.path)

                data_portal = DataPortal(
                    env,
                    equity_minute_reader=equity_minute_reader,
                )
            else:
                sim_params = factory.create_simulation_parameters(
                    data_frequency="daily")

                days = sim_params.trading_days

                assets = {
                    sid:
                    pd.DataFrame(
                        {
                            "open": [10.1] * len(days),
                            "high": [10.1] * len(days),
                            "low": [10.1] * len(days),
                            "close": [10.1] * len(days),
                            "volume": [100] * len(days),
                            "day": [day.value for day in days]
                        },
                        index=days)
                }

                path = os.path.join(tempdir.path, "testdata.bcolz")
                DailyBarWriterFromDataFrames(assets).write(path, days, assets)

                equity_daily_reader = BcolzDailyBarReader(path)

                data_portal = DataPortal(
                    env,
                    equity_daily_reader=equity_daily_reader,
                )

            # FixedSlippage unless the caller asked for the default model.
            if params.get("default_slippage"):
                slippage_func = None
            else:
                slippage_func = FixedSlippage()

            # NOTE(review): the blotter and tracker are built from
            # ``self.env`` while the asset metadata below is written to the
            # local ``env``; confirm ``self.env`` already knows ``sid``.
            blotter = Blotter(sim_params.data_frequency, self.env.asset_finder,
                              slippage_func)

            env.write_data(
                equities_data={
                    sid: {
                        "start_date": sim_params.trading_days[0],
                        "end_date": sim_params.trading_days[-1]
                    }
                })

            start_date = sim_params.first_open

            # -1 ** k flips the order direction on every other order.
            alternator = -1 if alternate else 1

            tracker = PerformanceTracker(sim_params, self.env)

            # replicate what tradesim does by going through every minute or day
            # of the simulation and processing open orders each time
            if sim_params.data_frequency == "minute":
                ticks = minutes
            else:
                ticks = days

            transactions = []

            order_list = []
            order_date = start_date
            for tick in ticks:
                blotter.current_dt = tick
                if tick >= order_date and len(order_list) < order_count:
                    # place an order
                    direction = alternator ** len(order_list)
                    order_id = blotter.order(
                        blotter.asset_finder.retrieve_asset(sid),
                        order_amount * direction, MarketOrder())
                    order_list.append(blotter.orders[order_id])
                    order_date = order_date + order_interval
                    # An order time at or after hour 21 is treated as
                    # after-market and moved to 14:30 on the following day,
                    # intended as just after the next market open.
                    if order_date.hour >= 21:
                        order_date = order_date + timedelta(days=1)
                        order_date = order_date.replace(hour=14, minute=30)
                else:
                    bar_data = BarData(data_portal, lambda: tick,
                                       sim_params.data_frequency)
                    txns, _ = blotter.get_transactions(bar_data)
                    for txn in txns:
                        tracker.process_transaction(txn)
                        transactions.append(txn)

            # Every requested order must have been placed, with the expected
            # sid and (possibly alternating) signed amount.
            self.assertEqual(len(order_list), order_count)
            for i, order in enumerate(order_list):
                self.assertEqual(order.sid, sid)
                self.assertEqual(order.amount, order_amount * alternator ** i)

            if complete_fill:
                self.assertEqual(len(transactions), len(order_list))

            total_volume = 0
            for i, txn in enumerate(transactions):
                total_volume += txn.amount
                if complete_fill:
                    self.assertEqual(order_list[i].amount, txn.amount)

            self.assertEqual(total_volume, expected_txn_volume)

            self.assertEqual(len(transactions), expected_txn_count)

            cumulative_pos = tracker.position_tracker.positions[sid]
            if total_volume == 0:
                self.assertIsNone(cumulative_pos)
            else:
                self.assertEqual(total_volume, cumulative_pos.amount)

            # the open orders should not contain sid.
            oo = blotter.open_orders
            self.assertNotIn(sid, oo, "Entry is removed when no open orders")
        finally:
            tempdir.cleanup()
Ejemplo n.º 31
0
class TradingAlgorithm(object):
    """
    Base class for trading algorithms. Inherit and overload
    initialize() and handle_data(data).

    A new algorithm could look like this:
    ```
    from zipline.api import order, symbol

    def initialize(context):
        context.sid = symbol('AAPL')
        context.amount = 100

    def handle_data(context, data):
        sid = context.sid
        amount = context.amount
        order(sid, amount)
    ```
    To run this algorithm, pass these functions to
    TradingAlgorithm:

    my_algo = TradingAlgorithm(initialize, handle_data)
    stats = my_algo.run(data)

    """

    def __init__(self, *args, **kwargs):
        """Initialize sids and other state variables.

        :Arguments:
        :Optional:
            initialize : function
                Function that is called with a single
                argument at the beginning of the simulation.
            handle_data : function
                Function that is called with 2 arguments
                (context and data) on every bar.
            before_trading_start : function
                Optional function called with the algorithm and data;
                only read when initialize and handle_data are given.
            script : str
                Algoscript that contains initialize and
                handle_data function definition. Cannot be combined with
                the ``initialize`` or ``handle_data`` arguments.
            algo_filename : str <default: '<string>'>
                Filename used when compiling ``script``.
            namespace : dict <default: {}>
                Namespace in which ``script`` is executed.
            data_frequency : {'daily', 'minute'}
               The duration of the bars.
            capital_base : float <default: 1.0e5>
               How much capital to start with.
            instant_fill : bool <default: False>
               Whether to fill orders immediately or on next bar.
            env : TradingEnvironment
                The trading environment to use; a new TradingEnvironment()
                is created when omitted.
            sim_params : SimulationParameters
                If omitted, built from ``capital_base``, ``start`` and
                ``end``.
            equities_metadata : can be either:
                            - dict
                            - pandas.DataFrame
                            - object with 'read' property
                If dict is provided, it must have the following structure:
                * keys are the identifiers
                * values are dicts containing the metadata, with the metadata
                  field name as the key
                If pandas.DataFrame is provided, it must have the
                following structure:
                * column names must be the metadata fields
                * index must be the different asset identifiers
                * array contents should be the metadata value
                If an object with a 'read' property is provided, 'read' must
                return rows containing at least one of 'sid' or 'symbol' along
                with the other metadata fields.
            futures_metadata : dict
                Futures metadata written to the trading environment.
            identifiers : List
                Any asset identifiers that are not provided in the
                equities_metadata, but will be traded by this TradingAlgorithm

        Any remaining positional and keyword arguments are stored and later
        forwarded to the user's ``initialize`` function.

        :Raises:
            ValueError
                If ``script`` is combined with ``initialize`` or
                ``handle_data``, or if ``script`` does not define a
                ``handle_data`` function.
        """
        self.sources = []

        # List of trading controls to be used to validate orders.
        self.trading_controls = []

        # List of account controls to be checked on each bar.
        self.account_controls = []

        self._recorded_vars = {}
        self.namespace = kwargs.pop('namespace', {})

        self._platform = kwargs.pop('platform', 'zipline')

        self.logger = None

        self.benchmark_return_source = None

        # default components for transact
        self.slippage = VolumeShareSlippage()
        self.commission = PerShare()

        self.instant_fill = kwargs.pop('instant_fill', False)

        # If an env has been provided, pop it
        self.trading_environment = kwargs.pop('env', None)

        if self.trading_environment is None:
            self.trading_environment = TradingEnvironment()

        # Update the TradingEnvironment with the provided asset metadata
        self.trading_environment.write_data(
            equities_data=kwargs.pop('equities_metadata', {}),
            equities_identifiers=kwargs.pop('identifiers', []),
            futures_data=kwargs.pop('futures_metadata', {}),
        )

        # set the capital base
        self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)
        self.sim_params = kwargs.pop('sim_params', None)
        if self.sim_params is None:
            self.sim_params = create_simulation_parameters(
                capital_base=self.capital_base,
                start=kwargs.pop('start', None),
                end=kwargs.pop('end', None),
                env=self.trading_environment,
            )
        else:
            self.sim_params.update_internal_from_env(self.trading_environment)

        # Build a perf_tracker
        self.perf_tracker = PerformanceTracker(sim_params=self.sim_params,
                                               env=self.trading_environment)

        # Pull in the environment's new AssetFinder for quick reference
        self.asset_finder = self.trading_environment.asset_finder
        self.init_engine(kwargs.pop('ffc_loader', None))

        # Maps from name to Term
        self._filters = {}
        self._factors = {}
        self._classifiers = {}

        self.blotter = kwargs.pop('blotter', None)
        if not self.blotter:
            self.blotter = Blotter()

        # Set the dt initally to the period start by forcing it to change
        self.on_dt_changed(self.sim_params.period_start)

        # The symbol lookup date specifies the date to use when resolving
        # symbols to sids, and can be set using set_symbol_lookup_date()
        self._symbol_lookup_date = None

        self.portfolio_needs_update = True
        self.account_needs_update = True
        self.performance_needs_update = True
        self._portfolio = None
        self._account = None

        self.history_container_class = kwargs.pop(
            'history_container_class', HistoryContainer,
        )
        self.history_container = None
        self.history_specs = {}

        # If string is passed in, execute and get reference to
        # functions.
        self.algoscript = kwargs.pop('script', None)

        self._initialize = None
        self._before_trading_start = None
        self._analyze = None

        self.event_manager = EventManager()

        if self.algoscript is not None:
            # A script and explicit initialize/handle_data functions are
            # mutually exclusive. This check must live in this branch; in
            # the ``elif`` below it could never fire.
            if 'initialize' in kwargs or 'handle_data' in kwargs:
                raise ValueError(
                    'You can not set script and initialize/handle_data.'
                )
            filename = kwargs.pop('algo_filename', None)
            if filename is None:
                filename = '<string>'
            code = compile(self.algoscript, filename, 'exec')
            exec_(code, self.namespace)
            self._initialize = self.namespace.get('initialize')
            if 'handle_data' not in self.namespace:
                raise ValueError('You must define a handle_data function.')
            else:
                self._handle_data = self.namespace['handle_data']

            self._before_trading_start = \
                self.namespace.get('before_trading_start')
            # Optional analyze function, gets called after run
            self._analyze = self.namespace.get('analyze')

        elif kwargs.get('initialize') and kwargs.get('handle_data'):
            self._initialize = kwargs.pop('initialize')
            self._handle_data = kwargs.pop('handle_data')
            self._before_trading_start = kwargs.pop('before_trading_start',
                                                    None)

        self.event_manager.add_event(
            zipline.utils.events.Event(
                zipline.utils.events.Always(),
                # We pass handle_data.__func__ to get the unbound method.
                # We will explicitly pass the algorithm to bind it again.
                self.handle_data.__func__,
            ),
            prepend=True,
        )

        # If method not defined, NOOP
        if self._initialize is None:
            self._initialize = lambda x: None

        # Alternative way of setting data_frequency for backwards
        # compatibility.
        if 'data_frequency' in kwargs:
            self.data_frequency = kwargs.pop('data_frequency')

        self._most_recent_data = None

        # Prepare the algo for initialization
        self.initialized = False
        self.initialize_args = args
        self.initialize_kwargs = kwargs

    def init_engine(self, loader):
        """
        Build ``self.engine`` from ``loader``.

        With a loader, a SimpleFFCEngine is built over the environment's
        trading days and this algorithm's asset finder. Without one
        (``loader is None``), a NoOpFFCEngine is installed instead.
        """
        if loader is None:
            self.engine = NoOpFFCEngine()
            return

        self.engine = SimpleFFCEngine(
            loader,
            self.trading_environment.trading_days,
            self.asset_finder,
        )

    def initialize(self, *args, **kwargs):
        """
        Run the user's initialize function (``self._initialize``) with this
        algorithm installed as the target of the Zipline API functions.
        """
        api_scope = ZiplineAPI(self)
        with api_scope:
            self._initialize(self, *args, **kwargs)

    def before_trading_start(self, data):
        if self._before_trading_start is None:
            return

        self._before_trading_start(self, data)

    def handle_data(self, data):
        self._most_recent_data = data
        if self.history_container:
            self.history_container.update(data, self.datetime)

        self._handle_data(self, data)

        # Unlike trading controls which remain constant unless placing an
        # order, account controls can change each bar. Thus, must check
        # every bar no matter if the algorithm places an order or not.
        self.validate_account_controls()

    def analyze(self, perf):
        if self._analyze is None:
            return

        with ZiplineAPI(self):
            self._analyze(self, perf)

    def __repr__(self):
        """
        N.B. this does not yet represent a string that can be used
        to instantiate an exact copy of an algorithm.

        However, it is getting close, and provides some value as something
        that can be inspected interactively.
        """
        return """
{class_name}(
    capital_base={capital_base}
    sim_params={sim_params},
    initialized={initialized},
    slippage={slippage},
    commission={commission},
    blotter={blotter},
    recorded_vars={recorded_vars})
""".strip().format(class_name=self.__class__.__name__,
                   capital_base=self.capital_base,
                   sim_params=repr(self.sim_params),
                   initialized=self.initialized,
                   slippage=repr(self.slippage),
                   commission=repr(self.commission),
                   blotter=repr(self.blotter),
                   recorded_vars=repr(self.recorded_vars))

    def _create_data_generator(self, source_filter, sim_params=None):
        """
        Merge this algorithm's sources with a benchmark stream and group the
        resulting events by their ``dt`` attribute.

        ::source_filter:: is an optional predicate applied to the
        date-sorted source events; only events for which it returns True are
        kept. Benchmark events are not filtered.

        ::sim_params:: defaults to ``self.sim_params``.

        Returns the ``groupby`` of the merged stream keyed on ``dt``.
        """
        if sim_params is None:
            sim_params = self.sim_params

        benchmark_return_source = self.benchmark_return_source
        if benchmark_return_source is None:
            # Build benchmark events from the environment's benchmark
            # returns, restricted to the simulation's date range.
            minute_stamps = (sim_params.data_frequency == 'minute' or
                             sim_params.emission_rate == 'minute')
            if minute_stamps:
                # Re-stamp each return with element [1] of the day's
                # get_open_and_close result (the close).
                def update_time(date):
                    return self.trading_environment.get_open_and_close(date)[1]
            else:
                def update_time(date):
                    return date

            benchmark_return_source = [
                Event({'dt': update_time(dt),
                       'returns': ret,
                       'type': zipline.protocol.DATASOURCE_TYPE.BENCHMARK,
                       'source_id': 'benchmarks'})
                for dt, ret in
                self.trading_environment.benchmark_returns.iteritems()
                if (sim_params.period_start.date() <= dt.date() <=
                    sim_params.period_end.date())
            ]

        source_events = date_sorted_sources(*self.sources)
        if source_filter:
            source_events = filter(source_filter, source_events)

        merged = date_sorted_sources(benchmark_return_source, source_events)

        # groupby only merges adjacent events, so this relies on the merged
        # stream already being sorted by dt.
        return groupby(merged, attrgetter('dt'))

    def _create_generator(self, sim_params, source_filter=None):
        """
        Build the simulation generator for this algorithm's sources.

        Runs the user's initialize once, makes sure a performance tracker
        exists, and wires up the data generator, the AlgorithmSimulator and
        the transact method.

        ::source_filter:: is an optional predicate on date-sorted source
        events; only events for which it returns True are processed.
        """
        if not self.initialized:
            self.initialize(*self.initialize_args, **self.initialize_kwargs)
            self.initialized = True

        # HACK: ``run`` resets perf_tracker to None so that a fresh tracker
        # is built here for each run.
        if self.perf_tracker is None:
            self.perf_tracker = PerformanceTracker(
                sim_params=sim_params, env=self.trading_environment,
            )

        # Flag portfolio, account and performance state as needing an update.
        self.portfolio_needs_update = True
        self.account_needs_update = True
        self.performance_needs_update = True

        self.data_gen = self._create_data_generator(source_filter, sim_params)
        self.trading_client = AlgorithmSimulator(self, sim_params)
        self.set_transact(transact_partial(self.slippage, self.commission))

        return self.trading_client.transform(self.data_gen)

    def get_generator(self):
        """
        Override this method to add new logic to the construction
        of the generator. Overrides can use the _create_generator
        method to get a standard construction generator.
        """
        return self._create_generator(self.sim_params)

    # TODO: make a new subclass, e.g. BatchAlgorithm, and move
    # the run method to the subclass, and refactor to put the
    # generator creation logic into get_generator.
    def run(self, source, overwrite_sim_params=True,
            benchmark_return_source=None):
        """Run the algorithm.

        :Arguments:
            source : can be either:
                     - pandas.DataFrame
                     - zipline source
                     - list of sources

               If pandas.DataFrame is provided, it must have the
               following structure:
               * column names must be the different asset identifiers
               * index must be DatetimeIndex
               * array contents should be price info.

        :Returns:
            daily_stats : pandas.DataFrame
              Daily performance metrics such as returns, alpha etc.

        """

        # Ensure that source is a DataSource object
        if isinstance(source, list):
            if overwrite_sim_params:
                warnings.warn("""List of sources passed, will not attempt to extract start and end
 dates. Make sure to set the correct fields in sim_params passed to
 __init__().""", UserWarning)
                overwrite_sim_params = False
        elif isinstance(source, pd.DataFrame):
            # if DataFrame provided, map columns to sids and wrap
            # in DataFrameSource
            copy_frame = source.copy()
            copy_frame.columns = self._write_and_map_id_index_to_sids(
                source.columns, source.index[0],
            )
            source = DataFrameSource(copy_frame)

        elif isinstance(source, pd.Panel):
            # If Panel provided, map items to sids and wrap
            # in DataPanelSource
            copy_panel = source.copy()
            copy_panel.items = self._write_and_map_id_index_to_sids(
                source.items, source.major_axis[0],
            )
            source = DataPanelSource(copy_panel)

        if isinstance(source, list):
            self.set_sources(source)
        else:
            self.set_sources([source])

        # Override sim_params if params are provided by the source.
        if overwrite_sim_params:
            if hasattr(source, 'start'):
                self.sim_params.period_start = source.start
            if hasattr(source, 'end'):
                self.sim_params.period_end = source.end
            # Changing period_start and period_close might require updating
            # of first_open and last_close.
            self.sim_params.update_internal_from_env(
                env=self.trading_environment
            )

        # The sids field of the source is the reference for the universe at
        # the start of the run
        self._current_universe = set()
        for source in self.sources:
            for sid in source.sids:
                self._current_universe.add(sid)
        # Check that all sids from the source are accounted for in
        # the AssetFinder. This retrieve call will raise an exception if the
        # sid is not found.
        for sid in self._current_universe:
            self.asset_finder.retrieve_asset(sid)

        # force a reset of the performance tracker, in case
        # this is a repeat run of the algorithm.
        self.perf_tracker = None

        # create zipline
        self.gen = self._create_generator(self.sim_params)

        # Create history containers
        if self.history_specs:
            self.history_container = self.history_container_class(
                self.history_specs,
                self.current_universe(),
                self.sim_params.first_open,
                self.sim_params.data_frequency,
                self.trading_environment,
            )

        # loop through simulated_trading, each iteration returns a
        # perf dictionary
        perfs = []
        for perf in self.gen:
            perfs.append(perf)

        # convert perf dict to pandas dataframe
        daily_stats = self._create_daily_stats(perfs)

        self.analyze(daily_stats)

        return daily_stats

    def _write_and_map_id_index_to_sids(self, identifiers, as_of_date):
        """
        Map ``identifiers`` to sids, first writing unknown ones to the
        environment.

        Identifiers that are Assets, or that support ``__int__``, are looked
        up by sid. Any identifier that does not resolve to an existing asset
        is written to the trading environment as a new equity identifier.
        The full list is then mapped with the asset finder as of
        ``as_of_date``.
        """
        unresolved = []
        for ident in identifiers:
            if isinstance(ident, Asset):
                found = self.asset_finder.retrieve_asset(sid=ident.sid,
                                                         default_none=True)
            elif hasattr(ident, '__int__'):
                found = self.asset_finder.retrieve_asset(sid=ident,
                                                         default_none=True)
            else:
                found = None

            if found is None:
                unresolved.append(ident)

        self.trading_environment.write_data(
            equities_identifiers=unresolved)

        return self.asset_finder.map_identifier_index_to_sids(
            identifiers, as_of_date,
        )

    def _create_daily_stats(self, perfs):
        # create daily and cumulative stats dataframe
        daily_perfs = []
        # TODO: the loop here could overwrite expected properties
        # of daily_perf. Could potentially raise or log a
        # warning.
        for perf in perfs:
            if 'daily_perf' in perf:

                perf['daily_perf'].update(
                    perf['daily_perf'].pop('recorded_vars')
                )
                perf['daily_perf'].update(perf['cumulative_risk_metrics'])
                daily_perfs.append(perf['daily_perf'])
            else:
                self.risk_report = perf

        daily_dts = [np.datetime64(perf['period_close'], utc=True)
                     for perf in daily_perfs]
        daily_stats = pd.DataFrame(daily_perfs, index=daily_dts)

        return daily_stats

    @api_method
    def add_transform(self, transform, days=None):
        """
        Ensures that the history container will have enough size to service
        a simple transform.

        :Arguments:
            transform : string
                The transform to add. must be an element of:
                {'mavg', 'stddev', 'vwap', 'returns'}.
            days : int <default=None>
                The maximum amount of days you will want for this transform.
                This is not needed for 'returns'.
        """
        if transform not in {'mavg', 'stddev', 'vwap', 'returns'}:
            raise ValueError('Invalid transform')

        if transform == 'returns':
            if days is not None:
                raise ValueError('returns does use days')

            self.add_history(2, '1d', 'price')
            return
        elif days is None:
            raise ValueError('no number of days specified')

        if self.sim_params.data_frequency == 'daily':
            mult = 1
            freq = '1d'
        else:
            mult = 390
            freq = '1m'

        bars = mult * days
        self.add_history(bars, freq, 'price')

        if transform == 'vwap':
            self.add_history(bars, freq, 'volume')

    @api_method
    def get_environment(self, field='platform'):
        env = {
            'arena': self.sim_params.arena,
            'data_frequency': self.sim_params.data_frequency,
            'start': self.sim_params.first_open,
            'end': self.sim_params.last_close,
            'capital_base': self.sim_params.capital_base,
            'platform': self._platform
        }
        if field == '*':
            return env
        else:
            return env[field]

    def add_event(self, rule=None, callback=None):
        """
        Register ``callback`` to fire under ``rule`` in this algorithm's
        EventManager.
        """
        event = zipline.utils.events.Event(rule, callback)
        self.event_manager.add_event(event)

    @api_method
    def schedule_function(self,
                          func,
                          date_rule=None,
                          time_rule=None,
                          half_days=True):
        """
        Schedule ``func`` to run according to the given date and time rules.

        A missing ``date_rule`` defaults to every day. In minute mode, a
        missing ``time_rule`` defaults to market open. In daily mode the
        time rule is ignored and replaced by an always-true rule.
        """
        date_rule = date_rule or DateRuleFactory.every_day()

        if self.sim_params.data_frequency == 'minute':
            time_rule = time_rule or TimeRuleFactory.market_open()
        else:
            # If we are in daily mode the time_rule is ignored.
            time_rule = zipline.utils.events.Always()

        rule = make_eventrule(date_rule, time_rule, half_days)
        self.add_event(rule, func)

    @api_method
    def record(self, *args, **kwargs):
        """
        Track and record local variable (i.e. attributes) each day.
        """
        # Make 2 objects both referencing the same iterator
        args = [iter(args)] * 2

        # Zip generates list entries by calling `next` on each iterator it
        # receives.  In this case the two iterators are the same object, so the
        # call to next on args[0] will also advance args[1], resulting in zip
        # returning (a,b) (c,d) (e,f) rather than (a,a) (b,b) (c,c) etc.
        positionals = zip(*args)
        for name, value in chain(positionals, iteritems(kwargs)):
            self._recorded_vars[name] = value

    @api_method
    def symbol(self, symbol_str):
        """
        Look up the Asset for ``symbol_str`` through ``self.asset_finder``.

        The lookup date is ``self._symbol_lookup_date`` when one has been set
        (see ``set_symbol_lookup_date``). Otherwise it is
        ``self.sim_params.period_end``.
        """
        lookup_date = self._symbol_lookup_date
        if lookup_date is None:
            # No explicit lookup date: resolve against the end of the period.
            lookup_date = self.sim_params.period_end

        return self.asset_finder.lookup_symbol(
            symbol_str,
            as_of_date=lookup_date,
        )

    @api_method
    def symbols(self, *args):
        """
        Resolve several symbols at once.

        Returns a list holding ``self.symbol(identifier)`` for each
        positional argument, in argument order.
        """
        resolved = []
        for identifier in args:
            resolved.append(self.symbol(identifier))
        return resolved

    @api_method
    def sid(self, a_sid):
        """
        Return the Asset whose integer identifier is ``a_sid``.

        Delegates to ``self.asset_finder.retrieve_asset``.
        """
        finder = self.asset_finder
        return finder.retrieve_asset(a_sid)

    @api_method
    def future_chain(self, root_symbol, as_of_date=None):
        """Build a FutureChain for ``root_symbol``.

        Parameters
        ----------
        root_symbol : str
            Root symbol of the chain. It is upper-cased before use.
        as_of_date : datetime.datetime or pandas.Timestamp or str, optional
            Date at which the chain determination is rooted: the existing
            contract whose notice date is first after this date is the
            primary contract, and so on. A truthy value is converted with
            ``pd.Timestamp(as_of_date, tz='UTC')``. A falsy value is passed
            through unchanged.

        Returns
        -------
        FutureChain
            The chain built from this algorithm's asset finder and clock.

        Raises
        ------
        UnsupportedDatetimeFormat
            If ``pd.Timestamp`` raises ValueError for ``as_of_date``.
        RootSymbolNotFound
            If a future chain could not be found for the given root symbol.
        """
        if as_of_date:
            try:
                parsed = pd.Timestamp(as_of_date, tz='UTC')
            except ValueError:
                raise UnsupportedDatetimeFormat(input=as_of_date,
                                                method='future_chain')
            as_of_date = parsed

        chain_root = root_symbol.upper()
        return FutureChain(
            asset_finder=self.asset_finder,
            get_datetime=self.get_datetime,
            root_symbol=chain_root,
            as_of_date=as_of_date,
        )

    def _calculate_order_value_amount(self, asset, value):
        """
        Convert a monetary ``value`` into a share/contract count for ``asset``.

        The count is ``value / (last_price * multiplier)``:

        - ``last_price`` comes from ``self.trading_client.current_data``.
        - ``multiplier`` is ``asset.contract_multiplier`` for a Future and 1
          otherwise.

        The result is not rounded here. When the last price is
        (tolerantly) zero, a debug message is logged if a logger is set and
        0 is returned so that no order is placed.
        """
        last_price = self.trading_client.current_data[asset].price

        if tolerant_equals(last_price, 0):
            message = "Price of 0 for {psid}; can't infer value".format(
                psid=asset
            )
            if self.logger:
                self.logger.debug(message)
            # Don't place any order
            return 0

        multiplier = (asset.contract_multiplier
                      if isinstance(asset, Future) else 1)
        return value / (last_price * multiplier)

    @api_method
    def order(self, sid, amount,
              limit_price=None,
              stop_price=None,
              style=None):
        """
        Place an order for ``amount`` shares/contracts of ``sid``.

        ``amount`` is first converted to an int:

        - values within 1e-4 of an integer snap to that integer;
        - everything else is truncated toward zero
          (3.9999 -> 4, 5.5 -> 5, -5.5 -> -5).

        The parameters are then validated with ``validate_order_params``.
        The deprecated ``limit_price``/``stop_price`` arguments are folded
        into an ExecutionStyle, and the order is passed to
        ``self.blotter.order``, whose result is returned.
        """
        nearest = round(amount)
        if abs(amount - nearest) <= 1e-4:
            amount = nearest
        amount = int(amount)

        # Raises a ZiplineError if invalid parameters are detected.
        self.validate_order_params(sid, amount, limit_price, stop_price,
                                   style)

        # Translate deprecated limit_price / stop_price into an
        # ExecutionStyle object.
        execution_style = self.__convert_order_params_for_blotter(
            limit_price, stop_price, style,
        )
        return self.blotter.order(sid, amount, execution_style)

    def validate_order_params(self,
                              asset,
                              amount,
                              limit_price,
                              stop_price,
                              style):
        """
        Validate the arguments of an order API call.

        Raises
        ------
        OrderDuringInitialize
            If the algorithm has not finished initializing.
        UnsupportedOrderParameters
            If a truthy ``style`` is combined with a truthy ``limit_price``
            or ``stop_price``, or if ``asset`` is not an ``Asset``.

        After these checks, each registered trading control validates the
        order against the current portfolio, datetime and market data. Any
        exception a control raises propagates.
        """
        if not self.initialized:
            raise OrderDuringInitialize(
                msg="order() can only be called from within handle_data()"
            )

        # An explicit ExecutionStyle cannot be mixed with the deprecated
        # price arguments.
        if style and limit_price:
            raise UnsupportedOrderParameters(
                msg="Passing both limit_price and style is not supported."
            )
        if style and stop_price:
            raise UnsupportedOrderParameters(
                msg="Passing both stop_price and style is not supported."
            )

        if not isinstance(asset, Asset):
            raise UnsupportedOrderParameters(
                msg="Passing non-Asset argument to 'order()' is not supported."
                    " Use 'sid()' or 'symbol()' methods to look up an Asset."
            )

        for trading_control in self.trading_controls:
            trading_control.validate(
                asset,
                amount,
                self.updated_portfolio(),
                self.get_datetime(),
                self.trading_client.current_data,
            )

    @staticmethod
    def __convert_order_params_for_blotter(limit_price, stop_price, style):
        """
        Map a (limit_price, stop_price, style) triple onto an ExecutionStyle.

        A truthy ``style`` wins and is returned unchanged. The caller must
        already have ensured that no limit/stop price accompanies it (this
        is asserted). Otherwise:

        - both prices truthy -> StopLimitOrder
        - only a limit price -> LimitOrder
        - only a stop price -> StopOrder
        - neither -> MarketOrder
        """
        # TODO_SS: DeprecationWarning for usage of limit_price and stop_price.
        if style:
            assert (limit_price, stop_price) == (None, None)
            return style

        has_limit = bool(limit_price)
        has_stop = bool(stop_price)
        if has_limit and has_stop:
            return StopLimitOrder(limit_price, stop_price)
        elif has_limit:
            return LimitOrder(limit_price)
        elif has_stop:
            return StopOrder(stop_price)
        return MarketOrder()

    @api_method
    def order_value(self, sid, value,
                    limit_price=None, stop_price=None, style=None):
        """
        Place an order sized by monetary value instead of share count.

        ``value`` is converted into a share/contract count by
        ``_calculate_order_value_amount``, which divides it by the asset's
        price. For a Future the "value" is really the exposure, because
        futures have no value of their own. The resulting amount is passed
        to ``order``.

        value > 0 :: Buy/Cover
        value < 0 :: Sell/Short
        Market order:    order_value(sid, value)
        Limit order:     order_value(sid, value, limit_price)
        Stop order:      order_value(sid, value, None, stop_price)
        StopLimit order: order_value(sid, value, limit_price, stop_price)
        """
        shares = self._calculate_order_value_amount(sid, value)
        return self.order(
            sid,
            shares,
            limit_price=limit_price,
            stop_price=stop_price,
            style=style,
        )

    @property
    def recorded_vars(self):
        """Shallow copy of the values stored by ``record``."""
        snapshot = copy(self._recorded_vars)
        return snapshot

    @property
    def portfolio(self):
        """The current portfolio, as returned by ``updated_portfolio()``."""
        current = self.updated_portfolio()
        return current

    def updated_portfolio(self):
        """
        Return the cached portfolio, rebuilding it first if it is stale.

        When ``self.portfolio_needs_update`` is set, the portfolio is fetched
        from ``self.perf_tracker.get_portfolio``, which receives
        ``self.performance_needs_update``. Both flags are then cleared.
        """
        if not self.portfolio_needs_update:
            return self._portfolio

        self._portfolio = self.perf_tracker.get_portfolio(
            self.performance_needs_update
        )
        self.portfolio_needs_update = False
        self.performance_needs_update = False
        return self._portfolio

    @property
    def account(self):
        """The current account, as returned by ``updated_account()``."""
        current = self.updated_account()
        return current

    def updated_account(self):
        """
        Return the cached account, rebuilding it first if it is stale.

        When ``self.account_needs_update`` is set, the account is fetched
        from ``self.perf_tracker.get_account``, which receives
        ``self.performance_needs_update``. Both flags are then cleared.
        """
        if not self.account_needs_update:
            return self._account

        self._account = self.perf_tracker.get_account(
            self.performance_needs_update
        )
        self.account_needs_update = False
        self.performance_needs_update = False
        return self._account

    def set_logger(self, logger):
        """
        Attach ``logger`` to the algorithm.

        It is used, for example, for the zero-price debug message when
        converting order values into amounts.
        """
        self.logger = logger

    def on_dt_changed(self, dt):
        """
        Advance the algorithm's clock to ``dt``.

        The simulation loop calls this whenever the current datetime
        changes. Anything that must happen exactly once per datetime group
        belongs here. ``dt`` must be a UTC ``datetime`` (checked with
        asserts). It is stored on ``self.datetime`` and forwarded to the
        performance tracker and then the blotter.
        """
        assert isinstance(dt, datetime), (
            "Attempt to set algorithm's current time with non-datetime"
        )
        assert dt.tzinfo == pytz.utc, (
            "Algorithm expects a utc datetime"
        )

        self.datetime = dt
        self.perf_tracker.set_date(dt)
        self.blotter.set_date(dt)

    @api_method
    def get_datetime(self, tz=None):
        """
        Return the current simulation datetime.

        The stored datetime is asserted to be UTC. If ``tz`` is given, either
        as a tzinfo or as a string resolved with ``pytz.timezone``, the
        datetime is converted to that zone before being returned.
        """
        dt = self.datetime
        assert dt.tzinfo == pytz.utc, "Algorithm should have a utc datetime"

        if tz is None:
            return dt  # datetime.datetime objects are immutable; no copy.

        if isinstance(tz, string_types):
            tz = pytz.timezone(tz)
        return dt.astimezone(tz)

    def set_transact(self, transact):
        """
        Install ``transact`` on the blotter.

        The blotter uses it to turn open orders and trade events into
        transactions.
        """
        blotter = self.blotter
        blotter.transact = transact

    def update_dividends(self, dividend_frame):
        """
        Hand ``dividend_frame`` to the performance tracker.

        The tracker uses it for dividend processing. The frame's columns
        should include at least the entries in ``zp.DIVIDEND_FIELDS``.
        """
        tracker = self.perf_tracker
        tracker.update_dividends(dividend_frame)

    @api_method
    def set_slippage(self, slippage):
        """
        Replace the slippage model.

        Raises UnsupportedSlippageModel if ``slippage`` is not a
        SlippageModel. Raises OverrideSlippagePostInit if the algorithm has
        already been initialized.
        """
        if not isinstance(slippage, SlippageModel):
            raise UnsupportedSlippageModel()
        elif self.initialized:
            raise OverrideSlippagePostInit()
        else:
            self.slippage = slippage

    @api_method
    def set_commission(self, commission):
        """
        Replace the commission model.

        Only PerShare, PerTrade and PerDollar instances are accepted;
        anything else raises UnsupportedCommissionModel. Raises
        OverrideCommissionPostInit if the algorithm has already been
        initialized.
        """
        allowed_models = (PerShare, PerTrade, PerDollar)
        if not isinstance(commission, allowed_models):
            raise UnsupportedCommissionModel()
        elif self.initialized:
            raise OverrideCommissionPostInit()
        else:
            self.commission = commission

    @api_method
    def set_symbol_lookup_date(self, dt):
        """
        Set the date ``symbol`` uses to resolve symbols to sids.

        Symbols may map to different firms or underlying assets at
        different times. ``dt`` is converted with
        ``pd.Timestamp(dt, tz='UTC')``. If that raises ValueError,
        UnsupportedDatetimeFormat is raised instead.
        """
        try:
            lookup_date = pd.Timestamp(dt, tz='UTC')
        except ValueError:
            raise UnsupportedDatetimeFormat(input=dt,
                                            method='set_symbol_lookup_date')
        self._symbol_lookup_date = lookup_date

    def set_sources(self, sources):
        """Replace the algorithm's data sources; ``sources`` must be a list."""
        assert isinstance(sources, list)
        self.sources = sources

    # Kept for backwards compatibility.
    @property
    def data_frequency(self):
        """The simulation's data frequency, read from ``sim_params``."""
        params = self.sim_params
        return params.data_frequency

    @data_frequency.setter
    def data_frequency(self, value):
        """
        Store ``value`` on ``sim_params.data_frequency``.

        Only 'daily' and 'minute' are accepted (checked with an assert).
        """
        assert value in ('daily', 'minute')
        params = self.sim_params
        params.data_frequency = value

    @api_method
    def order_percent(self, sid, percent,
                      limit_price=None, stop_price=None, style=None):
        """
        Place an order in the specified asset corresponding to the given
        percent of the current portfolio value.

        ``percent`` must be expressed as a decimal (0.50 means 50%). The
        value ``portfolio_value * percent`` is passed to ``order_value``,
        whose result is returned.
        """
        # The docstring previously spelled "50\%", an invalid escape sequence
        # (DeprecationWarning, SyntaxWarning on Python 3.12+).
        value = self.portfolio.portfolio_value * percent
        return self.order_value(sid, value,
                                limit_price=limit_price,
                                stop_price=stop_price,
                                style=style)

    @api_method
    def order_target(self, sid, target,
                     limit_price=None, stop_price=None, style=None):
        """
        Order just enough shares to bring the position in ``sid`` to
        ``target`` shares.

        With no existing position this is a plain order for ``target``
        shares. Otherwise it orders ``target`` minus the currently held
        amount.
        """
        req_amount = target
        if sid in self.portfolio.positions:
            held = self.portfolio.positions[sid].amount
            req_amount = target - held
        return self.order(sid, req_amount,
                          limit_price=limit_price,
                          stop_price=stop_price,
                          style=style)

    @api_method
    def order_target_value(self, sid, target,
                           limit_price=None, stop_price=None, style=None):
        """
        Adjust the position in ``sid`` to a target monetary value.

        ``target`` is converted into a share/contract count with
        ``_calculate_order_value_amount``, then handed to ``order_target``.
        For a Future the "target value" is really the target exposure,
        because futures have no value of their own.
        """
        target_amount = self._calculate_order_value_amount(sid, target)
        return self.order_target(
            sid,
            target_amount,
            limit_price=limit_price,
            stop_price=stop_price,
            style=style,
        )

    @api_method
    def order_target_percent(self, sid, target,
                             limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target percent of the
        current portfolio value.

        If the position doesn't already exist, this is equivalent to placing
        a new order. If it does exist, this is equivalent to placing an order
        for the difference between the target percent and the current
        percent.

        ``target`` must be expressed as a decimal (0.50 means 50%). The
        value ``portfolio_value * target`` is passed to
        ``order_target_value``, whose result is returned.
        """
        # The docstring previously spelled "50\%", an invalid escape sequence
        # (DeprecationWarning, SyntaxWarning on Python 3.12+).
        target_value = self.portfolio.portfolio_value * target
        return self.order_target_value(sid, target_value,
                                       limit_price=limit_price,
                                       stop_price=stop_price,
                                       style=style)

    @api_method
    def get_open_orders(self, sid=None):
        """
        Return open orders converted with ``to_api_obj()``.

        With ``sid`` None, returns a dict mapping every sid that has a
        non-empty list in ``self.blotter.open_orders`` to its converted
        orders. With a ``sid``, returns that sid's converted orders, or
        ``[]`` if it has no entry.
        """
        open_orders = self.blotter.open_orders
        if sid is not None:
            if sid not in open_orders:
                return []
            return [o.to_api_obj() for o in open_orders[sid]]

        return {
            key: [o.to_api_obj() for o in orders]
            for key, orders in iteritems(open_orders)
            if orders
        }

    @api_method
    def get_order(self, order_id):
        """
        Return the API object for order ``order_id``.

        Returns None if the blotter does not know that id.
        """
        known_orders = self.blotter.orders
        if order_id not in known_orders:
            return None
        return known_orders[order_id].to_api_obj()

    @api_method
    def cancel_order(self, order_param):
        """
        Cancel an order through the blotter.

        ``order_param`` may be a ``zipline.protocol.Order``, in which case
        its ``id`` is used, or an order id, which is passed through as-is.
        """
        if isinstance(order_param, zipline.protocol.Order):
            order_id = order_param.id
        else:
            order_id = order_param
        self.blotter.cancel(order_id)

    @api_method
    def add_history(self, bar_count, frequency, field, ffill=True):
        """
        Register a HistorySpec for later ``history`` calls.

        The spec is stored in ``self.history_specs`` under its ``key_str``.
        If the algorithm is already initialized, the spec is made available
        immediately:

        - an existing history container is asked to ``ensure_spec`` it;
        - otherwise a new container is built from all registered specs.
        """
        spec = HistorySpec(bar_count, frequency, field, ffill,
                           data_frequency=self.sim_params.data_frequency,
                           env=self.trading_environment)
        self.history_specs[spec.key_str] = spec

        if not self.initialized:
            return

        if self.history_container:
            self.history_container.ensure_spec(
                spec, self.datetime, self._most_recent_data,
            )
            return

        self.history_container = self.history_container_class(
            self.history_specs,
            self.current_universe(),
            self.sim_params.first_open,
            self.sim_params.data_frequency,
            env=self.trading_environment,
        )

    def get_history_spec(self, bar_count, frequency, field, ffill):
        """
        Return the HistorySpec for these parameters, creating it on demand.

        A spec not yet in ``self.history_specs`` is built and registered.
        The history container is created if none exists yet, and the new
        spec is ensured on it.
        """
        key = HistorySpec.spec_key(bar_count, frequency, field, ffill)
        if key in self.history_specs:
            return self.history_specs[key]

        new_spec = HistorySpec(
            bar_count,
            frequency,
            field,
            ffill,
            data_frequency=self.sim_params.data_frequency,
            env=self.trading_environment,
        )
        self.history_specs[key] = new_spec

        if not self.history_container:
            self.history_container = self.history_container_class(
                self.history_specs,
                self.current_universe(),
                self.datetime,
                self.sim_params.data_frequency,
                bar_data=self._most_recent_data,
                env=self.trading_environment,
            )
        self.history_container.ensure_spec(
            new_spec, self.datetime, self._most_recent_data,
        )
        return self.history_specs[key]

    @api_method
    def history(self, bar_count, frequency, field, ffill=True):
        """
        Return historical data for the requested window.

        The matching HistorySpec is fetched (or created) by
        ``get_history_spec``, then queried on the history container at
        ``self.datetime``.
        """
        spec = self.get_history_spec(bar_count, frequency, field, ffill)
        container = self.history_container
        return container.get_history(spec, self.datetime)

    ####################
    # Account Controls #
    ####################

    def register_account_control(self, control):
        """
        Add an AccountControl to be checked on each bar.

        Must be called before initialization; afterwards it raises
        RegisterAccountControlPostInit.
        """
        if not self.initialized:
            self.account_controls.append(control)
            return
        raise RegisterAccountControlPostInit()

    def validate_account_controls(self):
        """
        Run every registered account control.

        Each control is given the current portfolio, account, datetime and
        market data. The portfolio/account/datetime getters are called
        afresh for each control.
        """
        for account_control in self.account_controls:
            portfolio = self.updated_portfolio()
            account = self.updated_account()
            now = self.get_datetime()
            account_control.validate(portfolio,
                                     account,
                                     now,
                                     self.trading_client.current_data)

    @api_method
    def set_max_leverage(self, max_leverage=None):
        """
        Cap the algorithm's leverage.

        Registers a ``MaxLeverage(max_leverage)`` account control, so this
        must be called before initialization.
        """
        leverage_control = MaxLeverage(max_leverage)
        self.register_account_control(leverage_control)

    ####################
    # Trading Controls #
    ####################

    def register_trading_control(self, control):
        """
        Add a TradingControl to be checked before every order.

        Must be called before initialization; afterwards it raises
        RegisterTradingControlPostInit.
        """
        if not self.initialized:
            self.trading_controls.append(control)
            return
        raise RegisterTradingControlPostInit()

    @api_method
    def set_max_position_size(self,
                              sid=None,
                              max_shares=None,
                              max_notional=None):
        """
        Limit the shares and/or dollar value held in ``sid``.

        Limits are absolute values, enforced when the algorithm places an
        order for ``sid``. Holdings can still exceed them afterwards:
        splits/dividends can push the share count past ``max_shares``, and
        price improvement can push the notional past ``max_notional``.

        An order that would increase the absolute share count or dollar
        value beyond one of these limits raises a TradingControlException.
        The limit is registered as a MaxPositionSize trading control, so it
        must be set before initialization.
        """
        self.register_trading_control(
            MaxPositionSize(asset=sid,
                            max_shares=max_shares,
                            max_notional=max_notional)
        )

    @api_method
    def set_max_order_size(self, sid=None, max_shares=None, max_notional=None):
        """
        Limit the shares and/or dollar value of any single order for ``sid``.

        Limits are absolute values, enforced when the algorithm places an
        order. An order exceeding one of them raises a
        TradingControlException. The limit is registered as a MaxOrderSize
        trading control, so it must be set before initialization.
        """
        self.register_trading_control(
            MaxOrderSize(asset=sid,
                         max_shares=max_shares,
                         max_notional=max_notional)
        )

    @api_method
    def set_max_order_count(self, max_count):
        """
        Limit how many orders can be placed.

        Registers a ``MaxOrderCount(max_count)`` trading control.
        NOTE(review): the period over which orders are counted is defined
        inside MaxOrderCount, not here — confirm.
        """
        count_control = MaxOrderCount(max_count)
        self.register_trading_control(count_control)

    @api_method
    def set_do_not_order_list(self, restricted_list):
        """
        Forbid ordering the sids in ``restricted_list``.

        Registers a RestrictedListOrder trading control.
        """
        restriction = RestrictedListOrder(restricted_list)
        self.register_trading_control(restriction)

    @api_method
    def set_long_only(self):
        """
        Forbid short positions.

        Registers a LongOnly trading control.
        """
        long_only = LongOnly()
        self.register_trading_control(long_only)

    ###########
    # FFC API #
    ###########
    @api_method
    @require_not_initialized(AddTermPostInit())
    def add_factor(self, factor, name):
        """
        Register ``factor`` under ``name`` in ``self._factors``.

        Registered factors are included in ``_all_terms``. Raises
        ValueError if ``name`` is already registered.
        """
        factors = self._factors
        if name in factors:
            raise ValueError("Name %r is already a factor!" % name)
        factors[name] = factor

    @api_method
    @require_not_initialized(AddTermPostInit())
    def add_filter(self, filter):
        """
        Register an anonymous filter.

        Filters are keyed ``anon_filter_<n>``, where ``n`` is the number of
        filters registered before this one.
        """
        index = len(self._filters)
        self._filters["anon_filter_%d" % index] = filter

    # Note: add_classifier is not yet implemented since you can't do anything
    # useful with classifiers yet.

    def _all_terms(self):
        """
        Return one new dict combining filters, factors and classifiers.

        The dicts are merged in that order, so on a name clash the later
        dict's entry wins.
        """
        merged = {}
        for terms in (self._filters, self._factors, self._classifiers):
            merged.update(iteritems(terms))
        return merged

    def compute_factor_matrix(self, start_date):
        """
        Compute a factor matrix containing at least the data necessary to
        provide values for `start_date`.

        The window starts on `start_date` itself. It ends on whichever comes
        first:

        - the trading day 252 sessions after `start_date`;
        - the trading day of the simulation's last close.

        Returns
        -------
        (factor_matrix, end_date)
            The result of ``self.engine.factor_matrix`` over that window,
            and the last trading day included in it.
        """
        days = self.trading_environment.trading_days

        # The window begins at `start_date` itself, not the previous day as
        # an earlier comment claimed. It is assumed to be one of `days`.
        start_date_loc = days.get_loc(start_date)

        # The window ends at the earlier of the simulation's final trading
        # day (not the day before it) or 252 trading days past the start.
        # 252 is a totally arbitrary choice that seemed reasonable based on
        # napkin math (roughly one trading year).
        sim_end = self.sim_params.last_close.normalize()
        end_loc = min(start_date_loc + 252, days.get_loc(sim_end))
        end_date = days[end_loc]

        return self.engine.factor_matrix(
            self._all_terms(),
            start_date,
            end_date,
        ), end_date

    def current_universe(self):
        """
        Return the algorithm's current universe.

        This is the value stored in ``self._current_universe``.
        """
        universe = self._current_universe
        return universe

    @classmethod
    def all_api_methods(cls):
        """
        Return the TradingAlgorithm API methods.

        These are the attributes defined directly on ``cls`` (via
        ``vars(cls)``) whose ``is_api_method`` attribute is truthy.
        """
        api_methods = []
        for candidate in itervalues(vars(cls)):
            if getattr(candidate, 'is_api_method', False):
                api_methods.append(candidate)
        return api_methods
Ejemplo n.º 32
0
class AssetFinderTestCase(TestCase):
    def setUp(self):
        # Each test exercises the plain AssetFinder against its own
        # TradingEnvironment built with noop_load.
        self.asset_finder_type = AssetFinder
        self.env = TradingEnvironment(load=noop_load)

    def test_lookup_symbol_delimited(self):
        """Symbols such as TEST.1 resolve only through supported delimiters."""
        as_of = pd.Timestamp("2013-01-01", tz="UTC")
        records = []
        for n in range(3):
            records.append({
                "sid": n,
                "symbol": "TEST.%d" % n,
                "company_name": "company%d" % n,
                "start_date": as_of.value,
                "end_date": as_of.value,
                "exchange": uuid.uuid4().hex,
            })
        frame = pd.DataFrame.from_records(records)
        self.env.write_data(equities_df=frame)
        finder = self.asset_finder_type(self.env.engine)
        assets = [finder.retrieve_asset(n) for n in range(3)]
        asset_1 = assets[1]

        # Run every check twice so that caching bugs surface on the 2nd pass.
        for _ in range(2):
            # No delimiter, or '@' (not a supported delimiter), finds nothing.
            for missing in ("TEST", "TEST1", "TEST@1"):
                with self.assertRaises(SymbolNotFound):
                    finder.lookup_symbol(missing, as_of)

            # Any supported delimiter maps back to TEST.1.
            for fuzzy_char in ["-", "/", "_", "."]:
                looked_up = finder.lookup_symbol("TEST%s1" % fuzzy_char, as_of)
                self.assertEqual(asset_1, looked_up)

    def test_lookup_symbol_fuzzy(self):
        """Fuzzy matching ignores '_' only when fuzzy=True is requested."""
        metadata = {
            0: {"symbol": "PRTY_HRD"},
            1: {"symbol": "BRKA"},
            2: {"symbol": "BRK_A"},
        }
        self.env.write_data(equities_data=metadata)
        finder = self.env.asset_finder
        dt = pd.Timestamp("2013-01-01", tz="UTC")

        # "PRTYHRD" (no underscore) is not found without fuzzy matching,
        # with or without a date...
        for as_of in (None, dt):
            with self.assertRaises(SymbolNotFound):
                finder.lookup_symbol("PRTYHRD", as_of)
        # ...but resolves to sid 0 with fuzzy matching.
        for as_of in (None, dt):
            self.assertEqual(0, finder.lookup_symbol("PRTYHRD", as_of,
                                                     fuzzy=True))

        # Exact symbols resolve to their own sid for every combination of
        # date / no date and fuzzy / not fuzzy.
        for symbol, expected_sid in (("PRTY_HRD", 0),
                                     ("BRKA", 1),
                                     ("BRK_A", 2)):
            self.assertEqual(expected_sid, finder.lookup_symbol(symbol, None))
            self.assertEqual(expected_sid, finder.lookup_symbol(symbol, dt))
            self.assertEqual(expected_sid,
                             finder.lookup_symbol(symbol, None, fuzzy=True))
            self.assertEqual(expected_sid,
                             finder.lookup_symbol(symbol, dt, fuzzy=True))

    def test_lookup_symbol(self):
        # Start dates are two days apart and each asset's end_date is the
        # day after its start_date, so no two generated assets overlap.
        dates = pd.date_range("2013-01-01", freq="2D", periods=5, tz="UTC")
        one_day = timedelta(days=1)
        df = pd.DataFrame.from_records([
            {
                "sid": sid,
                "symbol": "existing",
                "start_date": day.value,
                "end_date": (day + one_day).value,
                "exchange": "NYSE",
            }
            for sid, day in enumerate(dates)
        ])
        self.env.write_data(equities_df=df)
        finder = self.asset_finder_type(self.env.engine)

        for _ in range(2):  # Repeat the checks to catch caching bugs.
            with self.assertRaises(SymbolNotFound):
                finder.lookup_symbol("NON_EXISTING", dates[0])

            # Without a date the symbol maps to several assets.
            with self.assertRaises(MultipleSymbolsFound):
                finder.lookup_symbol("EXISTING", None)

            # With a date, the asset alive on that date is chosen. The test
            # expects the lower-case stored symbol to come back upper-case.
            for expected_sid, day in enumerate(dates):
                result = finder.lookup_symbol("EXISTING", day)
                self.assertEqual(result.symbol, "EXISTING")
                self.assertEqual(result.sid, expected_sid)

    def test_lookup_symbol_from_multiple_valid(self):
        # Regression test for how conflicts are resolved when several assets
        # hold the same symbol at the same time:
        #
        # * If multiple SIDs exist for symbol S at time T, return the
        #   candidate SID whose start_date is highest. (200 cases)
        # * If the best candidate SIDs share the highest start_date, return
        #   the SID with the highest end_date. (34 cases)
        #
        # It is the opinion of the author (ssanderson) that this input should
        # be considered malformed and rejected. But this is the currently
        # intended behavior, and it was once broken by accident during a
        # refactor. These checks serve as regression tests until we decide
        # to enforce this as an error.
        #
        # See https://github.com/quantopian/zipline/issues/837 for more
        # details.
        rows = [
            # (sid, start_date, end_date)
            (1, "2010-01-01", "2012-01-01"),
            # Same as sid 1, but with a later end_date.
            (2, "2010-01-01", "2013-01-01"),
            # Same as sid 1, but with a later start_date.
            (3, "2011-01-01", "2012-01-01"),
        ]
        df = pd.DataFrame.from_records([
            {
                "sid": sid,
                "symbol": "multiple",
                "start_date": pd.Timestamp(start),
                "end_date": pd.Timestamp(end),
                "exchange": "NYSE",
            }
            for sid, start, end in rows
        ])

        def check(expected_sid, date):
            result = finder.lookup_symbol("MULTIPLE", date)
            self.assertEqual(result.symbol, "MULTIPLE")
            self.assertEqual(result.sid, expected_sid)

        with tmp_asset_finder(finder_cls=self.asset_finder_type,
                              equities=df) as finder:
            self.assertIsInstance(finder, self.asset_finder_type)

            # Only sids 1 and 2 are eligible. They tie on start_date, so the
            # later end_date (sid 2) wins.
            check(2, pd.Timestamp("2010-12-31"))

            # Sids 1, 2 and 3 are eligible. Sid 3 has the latest start_date.
            check(3, pd.Timestamp("2011-01-01"))

    def test_lookup_generic(self):
        """
        lookup_generic resolves every case from build_lookup_generic_cases
        and reports no missing identifiers.
        """
        with build_lookup_generic_cases(self.asset_finder_type) as cases:
            for finder, symbols, reference_date, expected in cases:
                found, not_found = finder.lookup_generic(symbols,
                                                         reference_date)
                self.assertEqual(found, expected)
                self.assertEqual(not_found, [])

    def test_lookup_generic_handle_missing(self):
        """
        Unresolvable identifiers come back in ``missing``; the rest resolve.

        A symbol whose lifetime ended before the query date still resolves.
        An unknown symbol, and a symbol that only starts after the query
        date, are both reported as missing, in request order.
        """
        rows = [
            # (sid, symbol, start_date, end_date, exchange)
            (0, "real", "2013-1-1", "2014-1-1", ""),
            (1, "also_real", "2013-1-1", "2014-1-1", ""),
            # Ended before the query date; it should still be found.
            (2, "real_but_old", "2002-1-1", "2003-1-1", ""),
            # Starts **after** the query date; it should **not** be found.
            (3, "real_but_in_the_future", "2014-1-1", "2020-1-1", "THE FUTURE"),
        ]
        data = pd.DataFrame.from_records(
            [
                {
                    "sid": sid,
                    "symbol": symbol,
                    "start_date": pd.Timestamp(start, tz="UTC"),
                    "end_date": pd.Timestamp(end, tz="UTC"),
                    "exchange": exchange,
                }
                for sid, symbol, start, end, exchange in rows
            ]
        )
        self.env.write_data(equities_df=data)
        finder = self.asset_finder_type(self.env.engine)

        query = ["REAL", 1, "FAKE", "REAL_BUT_OLD", "REAL_BUT_IN_THE_FUTURE"]
        as_of = pd.Timestamp("2013-02-01", tz="UTC")
        results, missing = finder.lookup_generic(query, as_of)

        expected_found = [("REAL", 0), ("ALSO_REAL", 1), ("REAL_BUT_OLD", 2)]
        self.assertEqual(len(results), len(expected_found))
        for asset, (symbol, sid) in zip(results, expected_found):
            self.assertEqual(asset.symbol, symbol)
            self.assertEqual(asset.sid, sid)

        expected_missing = ["FAKE", "REAL_BUT_IN_THE_FUTURE"]
        self.assertEqual(len(missing), len(expected_missing))
        for got, want in zip(missing, expected_missing):
            self.assertEqual(got, want)

    def test_insert_metadata(self):
        """
        Dict metadata round-trips onto an Equity via the finder.

        Known fields (symbol, end_date) become attributes; the extra
        ``foo_data`` key must not be exposed on the resulting asset.
        """
        play_metadata = {
            "start_date": "2014-01-01",
            "end_date": "2015-01-01",
            "symbol": "PLAY",
            "foo_data": "FOO",
        }
        self.env.write_data(equities_data={0: play_metadata})
        finder = self.asset_finder_type(self.env.engine)

        # Known fields are readable from the retrieved asset.
        play = finder.retrieve_asset(0)
        self.assertIsInstance(play, Equity)
        self.assertEqual("PLAY", play.symbol)
        self.assertEqual(pd.Timestamp("2015-01-01", tz="UTC"), play.end_date)

        # The unknown key did not become an attribute.
        with self.assertRaises(AttributeError):
            play.foo_data

    def test_consume_metadata(self):
        """
        write_data accepts equity metadata both as a dict and as a DataFrame.
        """
        # Test dict consumption
        dict_to_consume = {0: {"symbol": "PLAY"}, 1: {"symbol": "MSFT"}}
        self.env.write_data(equities_data=dict_to_consume)
        finder = self.asset_finder_type(self.env.engine)

        equity = finder.retrieve_asset(0)
        self.assertIsInstance(equity, Equity)
        self.assertEqual("PLAY", equity.symbol)

        # Test dataframe consumption.  Cells are set with ``.loc`` rather than
        # chained indexing (``df[col][row] = ...``).  Chained assignment writes
        # into an intermediate Series: it raises SettingWithCopyWarning, and
        # under pandas Copy-on-Write it leaves ``df`` unchanged.
        df = pd.DataFrame(columns=["asset_name", "exchange"], index=[0, 1])
        df.loc[0, "asset_name"] = "Dave'N'Busters"
        df.loc[0, "exchange"] = "NASDAQ"
        df.loc[1, "asset_name"] = "Microsoft"
        df.loc[1, "exchange"] = "NYSE"
        # Write the DataFrame into a newly constructed environment.
        self.env = TradingEnvironment(load=noop_load)
        self.env.write_data(equities_df=df)
        finder = self.asset_finder_type(self.env.engine)
        self.assertEqual("NASDAQ", finder.retrieve_asset(0).exchange)
        self.assertEqual("Microsoft", finder.retrieve_asset(1).asset_name)

    def test_consume_asset_as_identifier(self):
        """
        Asset objects passed as identifiers are stored and come back intact.
        """
        eq_end = pd.Timestamp("2012-01-01", tz="UTC")
        fut_end = pd.Timestamp("2008-01-01", tz="UTC")

        equity_asset = Equity(1, symbol="TESTEQ", end_date=eq_end)
        future_asset = Future(200, symbol="TESTFUT", end_date=fut_end)

        self.env.write_data(equities_identifiers=[equity_asset], futures_identifiers=[future_asset])
        finder = self.asset_finder_type(self.env.engine)

        # Retrieved assets compare equal to the ones that were written...
        for original, sid in ((equity_asset, 1), (future_asset, 200)):
            self.assertEqual(original, finder.retrieve_asset(sid))
        # ...and they keep their end dates.
        for end, sid in ((eq_end, 1), (fut_end, 200)):
            self.assertEqual(end, finder.retrieve_asset(sid).end_date)

    def test_sid_assignment(self):
        """
        Bare symbols written with allow_sid_assignment=True get distinct sids.
        """
        # Identifiers only -- no sids are supplied.
        symbols = ["PLAY", "MSFT"]
        as_of = normalize_date(pd.Timestamp("2015-07-09", tz="UTC"))

        self.env.write_data(equities_identifiers=symbols, allow_sid_assignment=True)

        # Both symbols resolve, and to different sids.
        finder = self.asset_finder_type(self.env.engine)
        play = finder.lookup_symbol("PLAY", as_of)
        msft = finder.lookup_symbol("MSFT", as_of)
        self.assertEqual("PLAY", play.symbol)
        self.assertIsNotNone(play.sid)
        self.assertNotEqual(play.sid, msft.sid)

    def test_sid_assignment_failure(self):
        """
        Bare symbols without allow_sid_assignment raise SidAssignmentError.
        """
        identifiers = ["PLAY", "MSFT"]  # symbols only, no sids
        with self.assertRaises(SidAssignmentError):
            self.env.write_data(equities_identifiers=identifiers, allow_sid_assignment=False)

    def test_security_dates_warning(self):
        """
        Each deprecated ``security_*`` accessor emits a DeprecationWarning.
        """
        equity_asset = Equity(1, symbol="TESTEQ", end_date=pd.Timestamp("2012-01-01", tz="UTC"))
        deprecated_attrs = ("security_start_date", "security_end_date", "security_name")

        with warnings.catch_warnings(record=True) as caught:
            # "always" makes every access record a warning, even repeats.
            warnings.simplefilter("always")
            for attr in deprecated_attrs:
                getattr(equity_asset, attr)
            self.assertEqual(3, len(caught))
            for record in caught:
                self.assertTrue(issubclass(record.category, DeprecationWarning))

    def test_lookup_future_chain(self):
        """
        lookup_future_chain("AD", as_of) drops contracts as they roll off.

        Six AD contracts are written, with sid equal to their position in
        ``contracts``.  The chain is checked at three dates, and at pd.NaT,
        which is expected to return the whole chain.
        """

        def utc(date_string):
            return pd.Timestamp(date_string, tz="UTC")

        # (symbol, notice_date, expiration_date, start_date); a start_date of
        # None means the contract is written without one.
        contracts = [
            # Notice date equals ``dt`` below; expected first in the chain at dt.
            ("ADN15", "2015-05-14", "2015-06-14", "2015-01-01"),
            ("ADV15", "2015-08-14", "2015-09-14", "2015-01-01"),
            # Starts trading on ``dt`` itself.
            ("ADF16", "2015-11-16", "2015-12-16", "2015-05-14"),
            # Starts trading in August, after ``dt``.  NOTE(review): the chain
            # at ``dt`` is still asserted to hold all six contracts, so this
            # later start_date does not exclude the contract here.
            ("ADX16", "2015-11-16", "2015-12-16", "2015-08-01"),
            # Notice date comes after expiration.
            ("ADZ16", "2016-11-25", "2016-11-16", "2015-08-01"),
            # No start date; expected to be last in every chain checked below.
            ("ADZ20", "2020-11-25", "2020-11-16", None),
        ]
        metadata = {}
        for sid, (symbol, notice, expiration, start) in enumerate(contracts):
            entry = {
                "symbol": symbol,
                "root_symbol": "AD",
                "notice_date": utc(notice),
                "expiration_date": utc(expiration),
            }
            if start is not None:
                entry["start_date"] = utc(start)
            metadata[sid] = entry

        self.env.write_data(futures_data=metadata)
        finder = self.asset_finder_type(self.env.engine)
        dt = utc("2015-05-14")
        dt_2 = utc("2015-10-14")
        dt_3 = utc("2016-11-17")

        def check(as_of, expected_len, expected_positions):
            chain = finder.lookup_future_chain("AD", as_of)
            self.assertEqual(len(chain), expected_len)
            for position, sid in expected_positions:
                self.assertEqual(chain[position].sid, sid)

        # All six contracts are present and ordered.
        check(dt, 6, [(0, 0), (1, 1), (5, 5)])
        # Sids 0 and 1 are past their notice and expiration dates, so the
        # chain has advanced to sid 2.
        check(dt_2, 4, [(0, 2), (3, 5)])
        # Sid 4's expiration has passed while its notice date has not; it is
        # still dropped, leaving only sid 5.
        check(dt_3, 1, [(0, 5)])
        # pd.NaT as the as-of date gives the whole chain.
        check(pd.NaT, 6, [(5, 5)])

    def test_map_identifier_index_to_sids(self):
        """
        map_identifier_index_to_sids maps Asset objects to their int sids,
        preserving the order of the input.
        """
        as_of = pd.Timestamp("2014-01-01", tz="UTC")
        finder = self.asset_finder_type(self.env.engine)
        aapl = Equity(1, symbol="AAPL")
        goog = Equity(2, symbol="GOOG")
        clk15 = Future(200, symbol="CLK15")
        clm15 = Future(201, symbol="CLM15")

        # In-order input maps to in-order sids, each a plain int.
        mapped = finder.map_identifier_index_to_sids([aapl, goog, clk15, clm15], as_of)
        self.assertListEqual([1, 2, 200, 201], mapped)
        for sid in mapped:
            self.assertIsInstance(sid, int)

        # A shuffled input maps to the same shuffle of sids.
        mapped = finder.map_identifier_index_to_sids([clm15, goog, clk15, aapl], as_of)
        self.assertListEqual([201, 2, 200, 1], mapped)

    def test_compute_lifetimes(self):
        """
        Compare ``finder.lifetimes`` with a brute-force expectation.

        Equities from ``make_rotating_equity_info`` (``periods_between_starts=3``,
        ``asset_lifetime=5``) are written.  For each date index yielded by
        ``all_subindices``, the lifetime matrix is checked with and without the
        start date counted.  The expected matrix is built directly from each
        asset's start/end dates.
        """
        num_assets = 4
        trading_day = self.env.trading_day
        first_start = pd.Timestamp("2015-04-01", tz="UTC")

        frame = make_rotating_equity_info(
            num_assets=num_assets,
            first_start=first_start,
            frequency=trading_day,
            periods_between_starts=3,
            asset_lifetime=5,
        )

        self.env.write_data(equities_df=frame)
        finder = self.env.asset_finder

        all_dates = pd.date_range(start=first_start, end=frame.end_date.max(), freq=trading_day)

        # The (index, start_date, end_date) rows and the column labels are the
        # same for every date range checked.  Build them once rather than
        # re-running itertuples() for every single date.
        spans = list(frame[["start_date", "end_date"]].itertuples())
        columns = frame.index.values

        for dates in all_subindices(all_dates):
            expected_with_start_raw = full(shape=(len(dates), num_assets), fill_value=False, dtype=bool)
            expected_no_start_raw = full(shape=(len(dates), num_assets), fill_value=False, dtype=bool)

            for i, date in enumerate(dates):
                for j, start, end in spans:
                    # This way of doing the checks is redundant, but very
                    # clear.  ``j`` is the frame's index label, used here as a
                    # column position; this presumably relies on the index
                    # being 0..num_assets-1.
                    if start <= date <= end:
                        expected_with_start_raw[i, j] = True
                        if start < date:
                            expected_no_start_raw[i, j] = True

            expected_with_start = pd.DataFrame(data=expected_with_start_raw, index=dates, columns=columns)
            result = finder.lifetimes(dates, include_start_date=True)
            assert_frame_equal(result, expected_with_start)

            expected_no_start = pd.DataFrame(data=expected_no_start_raw, index=dates, columns=columns)
            result = finder.lifetimes(dates, include_start_date=False)
            assert_frame_equal(result, expected_no_start)

    def test_sids(self):
        """
        The finder's ``sids`` property lists every written sid.
        """
        self.env.write_data(equities_identifiers=[1, 2, 3])
        sids = self.env.asset_finder.sids
        self.assertEqual(3, len(sids))
        # On failure, assertIn reports the value and the container.
        # assertTrue(x in y) only reports "False is not true".
        for sid in (1, 2, 3):
            self.assertIn(sid, sids)

    def test_group_by_type(self):
        """
        group_by_type splits a mixed sid list into equity and future sets.
        """
        equities = make_simple_equity_info(
            range(5),
            start_date=pd.Timestamp("2014-01-01"),
            end_date=pd.Timestamp("2015-01-01"),
        )
        futures = make_commodity_future_info(first_sid=6, root_symbols=["CL"], years=[2014])

        # The first two queries intersect, so the second exercises loading of
        # partially-cached results.
        first = ([0, 1, 3], [6, 7])
        overlapping = ([0, 2, 3], [7, 10])
        everything = (list(equities.index), list(futures.index))

        with tmp_asset_finder(equities=equities, futures=futures) as finder:
            for equity_sids, future_sids in (first, overlapping, everything):
                grouped = finder.group_by_type(equity_sids + future_sids)
                expected = {"equity": set(equity_sids), "future": set(future_sids)}
                self.assertEqual(grouped, expected)

    @parameterized.expand(
        [
            (Equity, "retrieve_equities", EquitiesNotFound),
            (Future, "retrieve_futures_contracts", FutureContractsNotFound),
        ]
    )
    def test_retrieve_specific_type(self, type_, lookup_name, failure_type):
        """
        The typed retrieval method returns only ``type_`` assets.

        It raises ``failure_type`` if any requested sid is of the other type.
        """
        equities = make_simple_equity_info(
            range(5),
            start_date=pd.Timestamp("2014-01-01"),
            end_date=pd.Timestamp("2015-01-01"),
        )
        first_future = equities.index.max() + 1
        futures = make_commodity_future_info(first_sid=first_future, root_symbols=["CL"], years=[2014])

        equity_sids = [0, 1]
        future_sids = [first_future + offset for offset in range(3)]
        success_sids, fail_sids = (equity_sids, future_sids) if type_ == Equity else (future_sids, equity_sids)

        with tmp_asset_finder(equities=equities, futures=futures) as finder:
            retrieve = getattr(finder, lookup_name)
            # Run twice to exercise caching.
            for _ in range(2):
                found = retrieve(success_sids)
                self.assertIsInstance(found, dict)
                self.assertEqual(set(found.keys()), set(success_sids))
                self.assertEqual(valmap(int, found), {sid: sid for sid in success_sids})
                self.assertEqual({type_}, set(map(type, itervalues(found))))
                with self.assertRaises(failure_type):
                    retrieve(fail_sids)
                # A single wrong-type sid is enough to fail the whole call.
                with self.assertRaises(failure_type):
                    retrieve([success_sids[0], fail_sids[0]])

    def test_retrieve_all(self):
        """
        retrieve_all returns assets in request order, with the right types and
        symbols.

        Covers empty, equity-only, future-only, mixed (cold and cached) and
        full queries.
        """
        equities = make_simple_equity_info(
            range(5),
            start_date=pd.Timestamp("2014-01-01"),
            end_date=pd.Timestamp("2015-01-01"),
        )
        max_equity = equities.index.max()
        futures = make_commodity_future_info(first_sid=max_equity + 1, root_symbols=["CL"], years=[2014])

        with tmp_asset_finder(equities=equities, futures=futures) as finder:
            all_sids = finder.sids
            self.assertEqual(len(all_sids), len(equities) + len(futures))

            mixed = tuple(equities.index[2:]) + tuple(futures.index[3:])
            queries = [
                (),  # empty query
                tuple(equities.index[:2]),  # only equities
                tuple(futures.index[:3]),  # only futures
                mixed,  # mixed, all cache misses
                mixed,  # mixed, all cache hits
                all_sids,  # everything...
                all_sids,  # ...twice
            ]
            for sids in queries:
                equity_sids = [sid for sid in sids if sid <= max_equity]
                future_sids = [sid for sid in sids if sid > max_equity]
                assets = finder.retrieve_all(sids)

                self.assertEqual(sids, tuple(map(int, assets)))
                self.assertEqual(
                    [Equity] * len(equity_sids) + [Future] * len(future_sids),
                    list(map(type, assets)),
                )
                expected_symbols = list(equities.symbol.loc[equity_sids]) + list(futures.symbol.loc[future_sids])
                self.assertEqual(expected_symbols, [asset.symbol for asset in assets])

    @parameterized.expand(
        [
            (EquitiesNotFound, "equity", "equities"),
            (FutureContractsNotFound, "future contract", "future contracts"),
            (SidsNotFound, "asset", "assets"),
        ]
    )
    def test_error_message_plurality(self, error_type, singular, plural):
        """
        The error message uses the singular noun for one sid and the plural
        noun for several.
        """
        with self.assertRaises(error_type) as one_sid:
            raise error_type(sids=[1])
        self.assertEqual(str(one_sid.exception), "No {singular} found for sid: 1.".format(singular=singular))

        with self.assertRaises(error_type) as two_sids:
            raise error_type(sids=[1, 2])
        self.assertEqual(str(two_sids.exception), "No {plural} found for sids: [1, 2].".format(plural=plural))
Ejemplo n.º 33
0
def build_lookup_generic_cases():
    """
    Generate test cases for AssetFinder test_lookup_generic.

    Three equities are written to a new TradingEnvironment:

    - sids 0 and 1 share the symbol 'duplicated', over one-day windows
      starting 2013-01-01 and 2013-01-03;
    - sid 2 holds 'unique' from 2013-01-01 to 2014-01-01.

    Returns
    -------
    list of tuple
        One ``(finder, symbols, reference_date, expected)`` per case.
    """
    one_day = timedelta(days=1)

    unique_start = pd.Timestamp('2013-01-01', tz='UTC')
    unique_end = pd.Timestamp('2014-01-01', tz='UTC')
    dupe_0_start = pd.Timestamp('2013-01-01', tz='UTC')
    dupe_1_start = pd.Timestamp('2013-01-03', tz='UTC')

    lifetimes = [
        (0, 'duplicated', dupe_0_start, dupe_0_start + one_day),
        (1, 'duplicated', dupe_1_start, dupe_1_start + one_day),
        (2, 'unique', unique_start, unique_end),
    ]
    frame = pd.DataFrame.from_records(
        [
            {
                'sid': sid,
                'symbol': symbol,
                'start_date': start.value,
                'end_date': end.value,
                'exchange': '',
            }
            for sid, symbol, start, end in lifetimes
        ],
        index='sid',
    )

    env = TradingEnvironment()
    env.write_data(equities_df=frame)
    finder = env.asset_finder

    assets = [finder.retrieve_asset(sid) for sid in range(3)]
    dupe_0, dupe_1, unique = assets

    # Disambiguate the duplicated symbol using the start dates as stored by
    # the finder.
    dupe_0_start = dupe_0.start_date
    dupe_1_start = dupe_1.start_date

    # Scalars: Asset objects and int sids each resolve to the matching asset.
    scalar_cases = (
        [(finder, asset, None, asset) for asset in assets]
        + [(finder, sid, None, asset) for sid, asset in enumerate(assets)]
        + [
            # Duplicated symbol, resolved by the reference date.
            (finder, 'DUPLICATED', dupe_0_start, dupe_0),
            (finder, 'DUPLICATED', dupe_1_start, dupe_1),
            # Unique symbol, with or without a reference date.
            (finder, 'UNIQUE', unique_start, unique),
            (finder, 'UNIQUE', None, unique),
        ]
    )

    iterable_cases = [
        # A list and a one-shot iterator of Asset objects.
        (finder, assets, None, assets),
        (finder, iter(assets), None, assets),
        # A tuple and a one-shot iterator of ints.
        (finder, (0, 1), None, assets[:-1]),
        (finder, iter((0, 1)), None, assets[:-1]),
        # Symbols, resolved at each duplicate's start date.
        (finder, ('DUPLICATED', 'UNIQUE'), dupe_0_start, [dupe_0, unique]),
        (finder, ('DUPLICATED', 'UNIQUE'), dupe_1_start, [dupe_1, unique]),
        # A mix of symbols, ints and Asset objects.
        (finder,
         ('DUPLICATED', 2, 'UNIQUE', 1, dupe_1),
         dupe_0_start,
         [dupe_0, assets[2], unique, assets[1], dupe_1]),
    ]
    return scalar_cases + iterable_cases
Ejemplo n.º 34
0
class AssetFinderTestCase(TestCase):
    def setUp(self):
        # Each test gets the plain AssetFinder class and a new
        # TradingEnvironment constructed with ``noop_load``.
        self.asset_finder_type = AssetFinder
        self.env = TradingEnvironment(load=noop_load)

    def test_lookup_symbol_delimited(self):
        """
        Symbols stored as 'TEST.<n>' resolve only through supported delimiters.

        'TEST1' and 'TEST@1' must not match 'TEST.1'; '-', '/', '_' and '.'
        all must.
        """
        as_of = pd.Timestamp('2013-01-01', tz='UTC')
        records = []
        for sid in range(3):
            records.append({
                'sid': sid,
                'symbol': 'TEST.%d' % sid,
                'company_name': "company%d" % sid,
                'start_date': as_of.value,
                'end_date': as_of.value,
                'exchange': uuid.uuid4().hex,
            })
        self.env.write_data(equities_df=pd.DataFrame.from_records(records))
        finder = self.asset_finder_type(self.env.engine)
        _, asset_1, _ = [finder.retrieve_asset(sid) for sid in range(3)]

        # Two passes so that caching bugs surface on the second one.
        for _ in range(2):
            # '@' is not a supported delimiter, so 'TEST@1' must not match.
            for unknown in ('TEST', 'TEST1', 'TEST@1'):
                with self.assertRaises(SymbolNotFound):
                    finder.lookup_symbol(unknown, as_of)

            # Adding an unnecessary fuzzy character shouldn't matter.
            for fuzzy_char in ['-', '/', '_', '.']:
                found = finder.lookup_symbol('TEST%s1' % fuzzy_char, as_of)
                self.assertEqual(asset_1, found)

    def test_lookup_symbol_fuzzy(self):
        """
        fuzzy=True lets 'PRTYHRD' match 'PRTY_HRD'; exact symbols always match.
        """
        symbols = ['PRTY_HRD', 'BRKA', 'BRK_A']
        self.env.write_data(
            equities_data={sid: {'symbol': sym} for sid, sym in enumerate(symbols)}
        )
        finder = self.env.asset_finder
        dt = pd.Timestamp('2013-01-01', tz='UTC')
        as_of_options = (None, dt)

        # 'PRTYHRD' only resolves (to sid 0) when fuzzy matching is on.
        for as_of in as_of_options:
            with self.assertRaises(SymbolNotFound):
                finder.lookup_symbol('PRTYHRD', as_of)
        for as_of in as_of_options:
            self.assertEqual(0, finder.lookup_symbol('PRTYHRD', as_of, fuzzy=True))

        # Exact symbols resolve with or without a date, fuzzy or not.
        for symbol, sid in (('PRTY_HRD', 0), ('BRKA', 1), ('BRK_A', 2)):
            for as_of in as_of_options:
                self.assertEqual(sid, finder.lookup_symbol(symbol, as_of))
            for as_of in as_of_options:
                self.assertEqual(sid, finder.lookup_symbol(symbol, as_of, fuzzy=True))

    def test_lookup_symbol(self):
        """
        A symbol held by several sids over disjoint windows resolves by date.

        Without a date the lookup is ambiguous and raises MultipleSymbolsFound.
        An unknown symbol raises SymbolNotFound.
        """
        # Starts are two days apart and each asset lives one day, so the
        # generated windows never overlap.
        dates = pd.date_range('2013-01-01', freq='2D', periods=5, tz='UTC')
        one_day = timedelta(days=1)
        df = pd.DataFrame.from_records([
            {
                'sid': sid,
                'symbol': 'existing',
                'start_date': start.value,
                'end_date': (start + one_day).value,
                'exchange': 'NYSE',
            }
            for sid, start in enumerate(dates)
        ])
        self.env.write_data(equities_df=df)
        finder = self.asset_finder_type(self.env.engine)

        # Two passes so that caching bugs surface on the second one.
        for _ in range(2):
            with self.assertRaises(SymbolNotFound):
                finder.lookup_symbol('NON_EXISTING', dates[0])

            with self.assertRaises(MultipleSymbolsFound):
                finder.lookup_symbol('EXISTING', None)

            # The supplied date selects exactly one of the sids.
            for expected_sid, date in enumerate(dates):
                asset = finder.lookup_symbol('EXISTING', date)
                self.assertEqual(asset.symbol, 'EXISTING')
                self.assertEqual(asset.sid, expected_sid)

    def test_lookup_symbol_from_multiple_valid(self):
        """
        Resolve a symbol held by several sids at the same time.

        Rules asserted here:

        1. If multiple sids exist for symbol S at time T, return the candidate
           whose start_date is highest. (200 cases)
        2. If the best candidates share the highest start_date, return the one
           with the highest end_date. (34 cases)

        The author (ssanderson) considers such input malformed and would
        rather fail here.  This is, however, the current intended behavior of
        the code, and it was accidentally broken during a refactor.  These
        checks serve as regression tests until the time comes that we decide
        to enforce this as an error.

        See https://github.com/quantopian/zipline/issues/837 for more details.
        """
        rows = [
            (1, '2010-01-01', '2012-01-01'),
            # Same as sid 1, but with a later end date.
            (2, '2010-01-01', '2013-01-01'),
            # Same as sid 1, but with a later start_date.
            (3, '2011-01-01', '2012-01-01'),
        ]
        df = pd.DataFrame.from_records([
            {
                'sid': sid,
                'symbol': 'multiple',
                'start_date': pd.Timestamp(start),
                'end_date': pd.Timestamp(end),
                'exchange': 'NYSE'
            }
            for sid, start, end in rows
        ])

        # ``finder`` is looked up when check() runs, i.e. inside the ``with``
        # block below where it is bound.
        def check(expected_sid, date):
            resolved = finder.lookup_symbol('MULTIPLE', date)
            self.assertEqual(resolved.symbol, 'MULTIPLE')
            self.assertEqual(resolved.sid, expected_sid)

        with tmp_asset_finder(finder_cls=self.asset_finder_type,
                              equities=df) as finder:
            self.assertIsInstance(finder, self.asset_finder_type)

            # Only sids 1 and 2 are live; sid 2 wins on the later end_date.
            check(2, pd.Timestamp('2010-12-31'))

            # Sids 1, 2 and 3 are live; sid 3 wins on the later start_date.
            check(3, pd.Timestamp('2011-01-01'))

    def test_lookup_generic(self):
        """
        Run every case from build_lookup_generic_cases through lookup_generic.

        Each ``(finder, symbols, reference_date, expected)`` case must resolve
        exactly to ``expected``, with no missing identifiers.
        """
        # NOTE(review): the build_lookup_generic_cases defined just above this
        # class takes no arguments and returns a plain list.  This call
        # expects a variant that accepts the finder type and is a context
        # manager; confirm which helper this test is paired with.
        with build_lookup_generic_cases(self.asset_finder_type) as cases:
            for finder, symbols, as_of, expected in cases:
                found, unresolved = finder.lookup_generic(symbols, as_of)
                self.assertEqual(found, expected)
                self.assertEqual(unresolved, [])

    def test_lookup_generic_handle_missing(self):
        """
        Unresolvable identifiers come back in ``missing``; the rest resolve.

        A symbol that ended before the query date still resolves.  Unknown
        symbols, and symbols starting after the query date, are reported as
        missing, in request order.
        """
        rows = [
            # (sid, symbol, start_date, end_date, exchange)
            (0, 'real', '2013-1-1', '2014-1-1', ''),
            (1, 'also_real', '2013-1-1', '2014-1-1', ''),
            # Ended before the query date; it should still be found.
            (2, 'real_but_old', '2002-1-1', '2003-1-1', ''),
            # Starts **after** the query date; it should **not** be found.
            (3, 'real_but_in_the_future', '2014-1-1', '2020-1-1', 'THE FUTURE'),
        ]
        data = pd.DataFrame.from_records([
            {
                'sid': sid,
                'symbol': symbol,
                'start_date': pd.Timestamp(start, tz='UTC'),
                'end_date': pd.Timestamp(end, tz='UTC'),
                'exchange': exchange,
            }
            for sid, symbol, start, end, exchange in rows
        ])
        self.env.write_data(equities_df=data)
        finder = self.asset_finder_type(self.env.engine)

        requested = ['REAL', 1, 'FAKE', 'REAL_BUT_OLD', 'REAL_BUT_IN_THE_FUTURE']
        results, missing = finder.lookup_generic(
            requested,
            pd.Timestamp('2013-02-01', tz='UTC'),
        )

        expected_found = [('REAL', 0), ('ALSO_REAL', 1), ('REAL_BUT_OLD', 2)]
        self.assertEqual(len(results), len(expected_found))
        for asset, (symbol, sid) in zip(results, expected_found):
            self.assertEqual(asset.symbol, symbol)
            self.assertEqual(asset.sid, sid)

        expected_missing = ['FAKE', 'REAL_BUT_IN_THE_FUTURE']
        self.assertEqual(len(missing), len(expected_missing))
        for got, want in zip(missing, expected_missing):
            self.assertEqual(got, want)

    def test_insert_metadata(self):
        """
        Dict metadata round-trips onto an Equity via the finder.

        The extra ``foo_data`` key must not be exposed as an attribute.
        """
        fields = {
            'start_date': '2014-01-01',
            'end_date': '2015-01-01',
            'symbol': "PLAY",
            'foo_data': "FOO",
        }
        self.env.write_data(equities_data={0: fields})
        finder = self.asset_finder_type(self.env.engine)

        # Known fields are readable from the retrieved asset.
        play = finder.retrieve_asset(0)
        self.assertIsInstance(play, Equity)
        self.assertEqual('PLAY', play.symbol)
        self.assertEqual(pd.Timestamp('2015-01-01', tz='UTC'), play.end_date)

        # The unknown key did not become an attribute.
        with self.assertRaises(AttributeError):
            play.foo_data

    def test_consume_metadata(self):
        """
        write_data accepts equity metadata both as a dict and as a DataFrame.
        """
        # Test dict consumption
        dict_to_consume = {0: {'symbol': 'PLAY'}, 1: {'symbol': 'MSFT'}}
        self.env.write_data(equities_data=dict_to_consume)
        finder = self.asset_finder_type(self.env.engine)

        equity = finder.retrieve_asset(0)
        self.assertIsInstance(equity, Equity)
        self.assertEqual('PLAY', equity.symbol)

        # Test dataframe consumption.  Cells are set with ``.loc`` rather than
        # chained indexing (``df[col][row] = ...``).  Chained assignment writes
        # into an intermediate Series: it raises SettingWithCopyWarning, and
        # under pandas Copy-on-Write it leaves ``df`` unchanged.
        df = pd.DataFrame(columns=['asset_name', 'exchange'], index=[0, 1])
        df.loc[0, 'asset_name'] = "Dave'N'Busters"
        df.loc[0, 'exchange'] = "NASDAQ"
        df.loc[1, 'asset_name'] = "Microsoft"
        df.loc[1, 'exchange'] = "NYSE"
        # Write the DataFrame into a newly constructed environment.
        self.env = TradingEnvironment(load=noop_load)
        self.env.write_data(equities_df=df)
        finder = self.asset_finder_type(self.env.engine)
        self.assertEqual('NASDAQ', finder.retrieve_asset(0).exchange)
        self.assertEqual('Microsoft', finder.retrieve_asset(1).asset_name)

    def test_consume_asset_as_identifier(self):
        # Asset objects passed as identifiers should be written unchanged.
        eq_end = pd.Timestamp('2012-01-01', tz='UTC')
        fut_end = pd.Timestamp('2008-01-01', tz='UTC')
        equity_asset = Equity(1, symbol="TESTEQ", end_date=eq_end)
        future_asset = Future(200, symbol="TESTFUT", end_date=fut_end)

        self.env.write_data(
            equities_identifiers=[equity_asset],
            futures_identifiers=[future_asset],
        )
        finder = self.asset_finder_type(self.env.engine)

        # (sid, original asset, original end_date) for each written asset.
        expectations = [
            (1, equity_asset, eq_end),
            (200, future_asset, fut_end),
        ]
        # Retrieved assets compare equal to the ones that were written...
        for sid, asset, _ in expectations:
            self.assertEqual(asset, finder.retrieve_asset(sid))
        # ...and keep their end dates.
        for sid, _, end in expectations:
            self.assertEqual(end, finder.retrieve_asset(sid).end_date)

    def test_sid_assignment(self):
        # Bare symbols carry no sid, so the writer has to allocate them.
        symbols = ['PLAY', 'MSFT']
        as_of = normalize_date(pd.Timestamp('2015-07-09', tz='UTC'))

        self.env.write_data(equities_identifiers=symbols,
                            allow_sid_assignment=True)

        # Each symbol must resolve to an asset with its own non-null sid.
        finder = self.asset_finder_type(self.env.engine)
        play, msft = (finder.lookup_symbol(sym, as_of) for sym in symbols)
        self.assertEqual('PLAY', play.symbol)
        self.assertIsNotNone(play.sid)
        self.assertNotEqual(play.sid, msft.sid)

    def test_sid_assignment_failure(self):
        # Without sid assignment, identifiers lacking a sid must be rejected.
        sidless_identifiers = ['PLAY', 'MSFT']
        with self.assertRaises(SidAssignmentError):
            self.env.write_data(
                equities_identifiers=sidless_identifiers,
                allow_sid_assignment=False,
            )

    def test_security_dates_warning(self):
        # The legacy ``security_*`` aliases must each emit a
        # DeprecationWarning when read.
        equity_asset = Equity(
            1,
            symbol="TESTEQ",
            end_date=pd.Timestamp('2012-01-01', tz='UTC'),
        )
        legacy_attrs = (
            'security_start_date',
            'security_end_date',
            'security_name',
        )

        with warnings.catch_warnings(record=True) as caught:
            # "always" so that no repeated warning gets deduplicated away.
            warnings.simplefilter("always")
            for attr in legacy_attrs:
                getattr(equity_asset, attr)
            self.assertEqual(3, len(caught))
            for record in caught:
                self.assertTrue(
                    issubclass(record.category, DeprecationWarning))

    def test_lookup_future_chain(self):
        def contract(symbol, notice, expiration, start=None):
            # One 'AD' futures record; start_date is omitted when not given.
            record = {
                'symbol': symbol,
                'root_symbol': 'AD',
                'notice_date': pd.Timestamp(notice, tz='UTC'),
                'expiration_date': pd.Timestamp(expiration, tz='UTC'),
            }
            if start is not None:
                record['start_date'] = pd.Timestamp(start, tz='UTC')
            return record

        metadata = {
            # Notice date falls exactly on `dt` below.
            0: contract('ADN15', '2015-05-14', '2015-06-14', '2015-01-01'),
            1: contract('ADV15', '2015-08-14', '2015-09-14', '2015-01-01'),
            # Starts trading on `dt`.
            2: contract('ADF16', '2015-11-16', '2015-12-16', '2015-05-14'),
            # Starts trading in August, after `dt`; it is nonetheless part
            # of the six-contract chain expected for `dt`.
            3: contract('ADX16', '2015-11-16', '2015-12-16', '2015-08-01'),
            # Notice date comes after expiration.
            4: contract('ADZ16', '2016-11-25', '2016-11-16', '2015-08-01'),
            # No start date; expected last in every chain.
            5: contract('ADZ20', '2020-11-25', '2020-11-16'),
        }
        self.env.write_data(futures_data=metadata)
        finder = self.asset_finder_type(self.env.engine)
        dt = pd.Timestamp('2015-05-14', tz='UTC')
        dt_2 = pd.Timestamp('2015-10-14', tz='UTC')
        dt_3 = pd.Timestamp('2016-11-17', tz='UTC')

        # On `dt` the whole chain is returned, in the expected order.
        chain = finder.lookup_future_chain('AD', dt)
        self.assertEqual(len(chain), 6)
        self.assertEqual(chain[0].sid, 0)
        self.assertEqual(chain[1].sid, 1)
        self.assertEqual(chain[5].sid, 5)

        # By `dt_2` the two earliest contracts have dropped off the front.
        chain = finder.lookup_future_chain('AD', dt_2)
        self.assertEqual(len(chain), 4)
        self.assertEqual(chain[0].sid, 2)
        self.assertEqual(chain[3].sid, 5)

        # On `dt_3` ADZ16 is past expiration although its notice date is
        # still ahead, and it must be excluded.
        chain = finder.lookup_future_chain('AD', dt_3)
        self.assertEqual(len(chain), 1)
        self.assertEqual(chain[0].sid, 5)

        # A NaT as_of_date yields the full chain.
        chain = finder.lookup_future_chain('AD', pd.NaT)
        self.assertEqual(len(chain), 6)
        self.assertEqual(chain[5].sid, 5)

    def test_map_identifier_index_to_sids(self):
        # A finder with nothing written should still map Asset objects to
        # their sids.
        dt = pd.Timestamp('2014-01-01', tz='UTC')
        finder = self.asset_finder_type(self.env.engine)
        aapl = Equity(1, symbol="AAPL")
        goog = Equity(2, symbol="GOOG")
        clk15 = Future(200, symbol="CLK15")
        clm15 = Future(201, symbol="CLM15")

        # Output follows input order and holds plain ints.
        mapped = finder.map_identifier_index_to_sids(
            [aapl, goog, clk15, clm15], dt,
        )
        self.assertListEqual([1, 2, 200, 201], mapped)
        for sid in mapped:
            self.assertIsInstance(sid, int)

        # Reordering the input reorders the output the same way.
        mapped = finder.map_identifier_index_to_sids(
            [clm15, goog, clk15, aapl], dt,
        )
        self.assertListEqual([201, 2, 200, 1], mapped)

    def test_compute_lifetimes(self):
        """``lifetimes`` flags each asset as alive on the dates it trades.

        Checked for every date index produced by ``all_subindices``, both
        with and without counting an asset's start date as alive.
        """
        num_assets = 4
        trading_day = self.env.trading_day
        first_start = pd.Timestamp('2015-04-01', tz='UTC')

        frame = make_rotating_equity_info(num_assets=num_assets,
                                          first_start=first_start,
                                          frequency=trading_day,
                                          periods_between_starts=3,
                                          asset_lifetime=5)

        self.env.write_data(equities_df=frame)
        # Query the finder type under test, as the other tests in this case
        # do, rather than the environment's default finder.
        finder = self.asset_finder_type(self.env.engine)

        all_dates = pd.date_range(
            start=first_start,
            end=frame.end_date.max(),
            freq=trading_day,
        )
        # (sid, start, end) rows do not depend on the dates being checked,
        # so materialize them once instead of once per date.
        asset_spans = list(frame[['start_date', 'end_date']].itertuples())

        for dates in all_subindices(all_dates):
            expected_with_start_raw = full(
                shape=(len(dates), num_assets),
                fill_value=False,
                dtype=bool,
            )
            expected_no_start_raw = full(
                shape=(len(dates), num_assets),
                fill_value=False,
                dtype=bool,
            )

            for i, date in enumerate(dates):
                # NOTE: the sid is used directly as the column position,
                # which assumes the frame's sids are 0..num_assets-1.
                for j, start, end in asset_spans:
                    # This way of doing the checks is redundant, but very
                    # clear.
                    if start <= date <= end:
                        expected_with_start_raw[i, j] = True
                        if start < date:
                            expected_no_start_raw[i, j] = True

            expected_with_start = pd.DataFrame(
                data=expected_with_start_raw,
                index=dates,
                columns=frame.index.values,
            )
            result = finder.lifetimes(dates, include_start_date=True)
            assert_frame_equal(result, expected_with_start)

            expected_no_start = pd.DataFrame(
                data=expected_no_start_raw,
                index=dates,
                columns=frame.index.values,
            )
            result = finder.lifetimes(dates, include_start_date=False)
            assert_frame_equal(result, expected_no_start)

    def test_sids(self):
        """The finder's ``sids`` property lists every sid that was written."""
        self.env.write_data(equities_identifiers=[1, 2, 3])
        # Use the finder type under test, consistent with the other tests in
        # this case, rather than the environment's default finder.
        sids = self.asset_finder_type(self.env.engine).sids
        self.assertEqual(3, len(sids))
        for sid in (1, 2, 3):
            # assertIn names the missing value on failure, unlike
            # assertTrue(sid in sids).
            self.assertIn(sid, sids)

    def test_group_by_type(self):
        equities = make_simple_equity_info(
            range(5),
            start_date=pd.Timestamp('2014-01-01'),
            end_date=pd.Timestamp('2015-01-01'),
        )
        futures = make_commodity_future_info(
            first_sid=6,
            root_symbols=['CL'],
            years=[2014],
        )
        # The queries overlap so that later ones load results that are
        # already partially cached.
        queries = [
            ([0, 1, 3], [6, 7]),
            ([0, 2, 3], [7, 10]),
            (list(equities.index), list(futures.index)),
        ]
        with tmp_asset_finder(equities=equities, futures=futures) as finder:
            for equity_sids, future_sids in queries:
                expected = {
                    'equity': set(equity_sids),
                    'future': set(future_sids),
                }
                grouped = finder.group_by_type(equity_sids + future_sids)
                self.assertEqual(grouped, expected)

    @parameterized.expand([
        (Equity, 'retrieve_equities', EquitiesNotFound),
        (Future, 'retrieve_futures_contracts', FutureContractsNotFound),
    ])
    def test_retrieve_specific_type(self, type_, lookup_name, failure_type):
        equities = make_simple_equity_info(
            range(5),
            start_date=pd.Timestamp('2014-01-01'),
            end_date=pd.Timestamp('2015-01-01'),
        )
        max_equity = equities.index.max()
        futures = make_commodity_future_info(
            first_sid=max_equity + 1,
            root_symbols=['CL'],
            years=[2014],
        )
        equity_sids = [0, 1]
        future_sids = [max_equity + offset for offset in (1, 2, 3)]
        # Sids of the requested type should resolve; all others should fail.
        if type_ == Equity:
            success_sids, fail_sids = equity_sids, future_sids
        else:
            success_sids, fail_sids = future_sids, equity_sids

        with tmp_asset_finder(equities=equities, futures=futures) as finder:
            lookup = getattr(finder, lookup_name)
            # Two passes, so the second one exercises any caching.
            for _ in range(2):
                found = lookup(success_sids)
                self.assertIsInstance(found, dict)
                self.assertEqual(set(found.keys()), set(success_sids))
                self.assertEqual(
                    valmap(int, found),
                    {sid: sid for sid in success_sids},
                )
                self.assertEqual(
                    {type_},
                    {type(asset) for asset in itervalues(found)},
                )
                with self.assertRaises(failure_type):
                    lookup(fail_sids)
                # A single wrong-typed sid is enough to fail the lookup.
                with self.assertRaises(failure_type):
                    lookup([success_sids[0], fail_sids[0]])

    def test_retrieve_all(self):
        equities = make_simple_equity_info(
            range(5),
            start_date=pd.Timestamp('2014-01-01'),
            end_date=pd.Timestamp('2015-01-01'),
        )
        max_equity = equities.index.max()
        futures = make_commodity_future_info(
            first_sid=max_equity + 1,
            root_symbols=['CL'],
            years=[2014],
        )

        with tmp_asset_finder(equities=equities, futures=futures) as finder:
            all_sids = finder.sids
            self.assertEqual(len(all_sids), len(equities) + len(futures))
            mixed = tuple(equities.index[2:]) + tuple(futures.index[3:])
            queries = [
                (),                         # empty query
                tuple(equities.index[:2]),  # equities only
                tuple(futures.index[:3]),   # futures only
                mixed,                      # mixed, all cache misses
                mixed,                      # mixed, all cache hits
                all_sids,                   # everything, twice
                all_sids,
            ]
            for sids in queries:
                equity_sids = [s for s in sids if s <= max_equity]
                future_sids = [s for s in sids if s > max_equity]
                results = finder.retrieve_all(sids)
                self.assertEqual(sids, tuple(map(int, results)))

                # Equities sort before futures here, since every future sid
                # is greater than every equity sid.
                expected_types = (
                    [Equity] * len(equity_sids) + [Future] * len(future_sids)
                )
                self.assertEqual(expected_types,
                                 [type(asset) for asset in results])

                expected_symbols = (
                    list(equities.symbol.loc[equity_sids]) +
                    list(futures.symbol.loc[future_sids])
                )
                self.assertEqual(expected_symbols,
                                 [asset.symbol for asset in results])

    @parameterized.expand([
        (EquitiesNotFound, 'equity', 'equities'),
        (FutureContractsNotFound, 'future contract', 'future contracts'),
        (SidsNotFound, 'asset', 'assets'),
    ])
    def test_error_message_plurality(self, error_type, singular, plural):
        # A single missing sid uses the singular noun...
        with self.assertRaises(error_type) as ctx:
            raise error_type(sids=[1])
        self.assertEqual(
            str(ctx.exception),
            "No {singular} found for sid: 1.".format(singular=singular))

        # ...and several missing sids use the plural.
        with self.assertRaises(error_type) as ctx:
            raise error_type(sids=[1, 2])
        self.assertEqual(
            str(ctx.exception),
            "No {plural} found for sids: [1, 2].".format(plural=plural))
Ejemplo n.º 35
0
class AssetFinderTestCase(TestCase):
    """Tests for ``AssetFinder`` lookups against a ``TradingEnvironment``."""

    def setUp(self):
        # unittest calls setUp before every test, so each test writes into
        # its own environment.
        self.env = TradingEnvironment()

    def test_lookup_symbol_delimited(self):
        """Delimited symbols resolve with any supported delimiter."""
        as_of = pd.Timestamp('2013-01-01', tz='UTC')
        frame = pd.DataFrame.from_records(
            [
                {
                    'sid': i,
                    'symbol': 'TEST.%d' % i,
                    'company_name': "company%d" % i,
                    'start_date': as_of.value,
                    'end_date': as_of.value,
                    'exchange': uuid.uuid4().hex
                }
                for i in range(3)
            ]
        )
        self.env.write_data(equities_df=frame)
        finder = AssetFinder(self.env.engine)
        # All three assets are retrieved, but only asset 1 is compared below.
        _, asset_1, _ = (finder.retrieve_asset(i) for i in range(3))

        # we do it twice to catch caching bugs
        for _ in range(2):
            with self.assertRaises(SymbolNotFound):
                finder.lookup_symbol('test', as_of)
            with self.assertRaises(SymbolNotFound):
                finder.lookup_symbol('test1', as_of)
            # '@' is not a supported delimiter
            with self.assertRaises(SymbolNotFound):
                finder.lookup_symbol('test@1', as_of)

            # Adding an unnecessary fuzzy shouldn't matter.
            for fuzzy_char in ['-', '/', '_', '.']:
                self.assertEqual(
                    asset_1,
                    finder.lookup_symbol('test%s1' % fuzzy_char, as_of)
                )

    def test_lookup_symbol_fuzzy(self):
        metadata = {
            0: {'symbol': 'PRTY_HRD'},
            1: {'symbol': 'BRKA'},
            2: {'symbol': 'BRK_A'},
        }
        self.env.write_data(equities_data=metadata)
        finder = self.env.asset_finder
        dt = pd.Timestamp('2013-01-01', tz='UTC')

        # Try combos of looking up PRTYHRD with and without a time or fuzzy
        # Both non-fuzzys get no result
        with self.assertRaises(SymbolNotFound):
            finder.lookup_symbol('PRTYHRD', None)
        with self.assertRaises(SymbolNotFound):
            finder.lookup_symbol('PRTYHRD', dt)
        # Both fuzzys work
        self.assertEqual(0, finder.lookup_symbol('PRTYHRD', None, fuzzy=True))
        self.assertEqual(0, finder.lookup_symbol('PRTYHRD', dt, fuzzy=True))

        # Try combos of looking up PRTY_HRD, all returning sid 0
        self.assertEqual(0, finder.lookup_symbol('PRTY_HRD', None))
        self.assertEqual(0, finder.lookup_symbol('PRTY_HRD', dt))
        self.assertEqual(0, finder.lookup_symbol('PRTY_HRD', None, fuzzy=True))
        self.assertEqual(0, finder.lookup_symbol('PRTY_HRD', dt, fuzzy=True))

        # Try combos of looking up BRKA, all returning sid 1
        self.assertEqual(1, finder.lookup_symbol('BRKA', None))
        self.assertEqual(1, finder.lookup_symbol('BRKA', dt))
        self.assertEqual(1, finder.lookup_symbol('BRKA', None, fuzzy=True))
        self.assertEqual(1, finder.lookup_symbol('BRKA', dt, fuzzy=True))

        # Try combos of looking up BRK_A, all returning sid 2
        self.assertEqual(2, finder.lookup_symbol('BRK_A', None))
        self.assertEqual(2, finder.lookup_symbol('BRK_A', dt))
        self.assertEqual(2, finder.lookup_symbol('BRK_A', None, fuzzy=True))
        self.assertEqual(2, finder.lookup_symbol('BRK_A', dt, fuzzy=True))

    def test_lookup_symbol(self):

        # Incrementing by two so that start and end dates for each
        # generated Asset don't overlap (each Asset's end_date is the
        # day after its start date.)
        dates = pd.date_range('2013-01-01', freq='2D', periods=5, tz='UTC')
        df = pd.DataFrame.from_records(
            [
                {
                    'sid': i,
                    'symbol': 'existing',
                    'start_date': date.value,
                    'end_date': (date + timedelta(days=1)).value,
                    'exchange': 'NYSE',
                }
                for i, date in enumerate(dates)
            ]
        )
        self.env.write_data(equities_df=df)
        finder = AssetFinder(self.env.engine)
        for _ in range(2):  # Run checks twice to test for caching bugs.
            with self.assertRaises(SymbolNotFound):
                finder.lookup_symbol('non_existing', dates[0])

            with self.assertRaises(MultipleSymbolsFound):
                finder.lookup_symbol('existing', None)

            for i, date in enumerate(dates):
                # Verify that we correctly resolve multiple symbols using
                # the supplied date
                result = finder.lookup_symbol('existing', date)
                self.assertEqual(result.symbol, 'EXISTING')
                self.assertEqual(result.sid, i)

    @parameterized.expand(
        build_lookup_generic_cases()
    )
    def test_lookup_generic(self, finder, symbols, reference_date, expected):
        """
        Ensure that lookup_generic works with various permutations of inputs.
        """
        results, missing = finder.lookup_generic(symbols, reference_date)
        self.assertEqual(results, expected)
        self.assertEqual(missing, [])

    def test_lookup_generic_handle_missing(self):
        data = pd.DataFrame.from_records(
            [
                {
                    'sid': 0,
                    'symbol': 'real',
                    'start_date': pd.Timestamp('2013-1-1', tz='UTC'),
                    'end_date': pd.Timestamp('2014-1-1', tz='UTC'),
                    'exchange': '',
                },
                {
                    'sid': 1,
                    'symbol': 'also_real',
                    'start_date': pd.Timestamp('2013-1-1', tz='UTC'),
                    'end_date': pd.Timestamp('2014-1-1', tz='UTC'),
                    'exchange': '',
                },
                # Sid whose end date is before our query date.  We should
                # still correctly find it.
                {
                    'sid': 2,
                    'symbol': 'real_but_old',
                    'start_date': pd.Timestamp('2002-1-1', tz='UTC'),
                    'end_date': pd.Timestamp('2003-1-1', tz='UTC'),
                    'exchange': '',
                },
                # Sid whose start_date is **after** our query date.  We should
                # **not** find it.
                {
                    'sid': 3,
                    'symbol': 'real_but_in_the_future',
                    'start_date': pd.Timestamp('2014-1-1', tz='UTC'),
                    'end_date': pd.Timestamp('2020-1-1', tz='UTC'),
                    'exchange': 'THE FUTURE',
                },
            ]
        )
        self.env.write_data(equities_df=data)
        finder = AssetFinder(self.env.engine)
        results, missing = finder.lookup_generic(
            ['real', 1, 'fake', 'real_but_old', 'real_but_in_the_future'],
            pd.Timestamp('2013-02-01', tz='UTC'),
        )

        self.assertEqual(len(results), 3)
        self.assertEqual(results[0].symbol, 'REAL')
        self.assertEqual(results[0].sid, 0)
        self.assertEqual(results[1].symbol, 'ALSO_REAL')
        self.assertEqual(results[1].sid, 1)
        self.assertEqual(results[2].symbol, 'REAL_BUT_OLD')
        self.assertEqual(results[2].sid, 2)

        self.assertEqual(len(missing), 2)
        self.assertEqual(missing[0], 'fake')
        self.assertEqual(missing[1], 'real_but_in_the_future')

    def test_insert_metadata(self):
        data = {0: {'asset_type': 'equity',
                    'start_date': '2014-01-01',
                    'end_date': '2015-01-01',
                    'symbol': "PLAY",
                    'foo_data': "FOO"}}
        self.env.write_data(equities_data=data)
        finder = AssetFinder(self.env.engine)
        # Test proper insertion
        equity = finder.retrieve_asset(0)
        self.assertIsInstance(equity, Equity)
        self.assertEqual('PLAY', equity.symbol)
        self.assertEqual(pd.Timestamp('2015-01-01', tz='UTC'),
                         equity.end_date)

        # Test invalid field
        with self.assertRaises(AttributeError):
            equity.foo_data

    def test_consume_metadata(self):

        # Test dict consumption
        dict_to_consume = {0: {'symbol': 'PLAY'},
                           1: {'symbol': 'MSFT'}}
        self.env.write_data(equities_data=dict_to_consume)
        finder = AssetFinder(self.env.engine)

        equity = finder.retrieve_asset(0)
        self.assertIsInstance(equity, Equity)
        self.assertEqual('PLAY', equity.symbol)

        # Test dataframe consumption.  Build the frame in one step instead of
        # chained assignment (``df['col'][0] = ...``): that form writes into a
        # temporary Series and is a silent no-op under pandas copy-on-write.
        df = pd.DataFrame(
            {
                'asset_name': ["Dave'N'Busters", "Microsoft"],
                'exchange': ["NASDAQ", "NYSE"],
            },
            index=[0, 1],
            columns=['asset_name', 'exchange'],
        )
        self.env = TradingEnvironment()
        self.env.write_data(equities_df=df)
        finder = AssetFinder(self.env.engine)
        self.assertEqual('NASDAQ', finder.retrieve_asset(0).exchange)
        self.assertEqual('Microsoft', finder.retrieve_asset(1).asset_name)

    def test_consume_asset_as_identifier(self):
        # Build some end dates
        eq_end = pd.Timestamp('2012-01-01', tz='UTC')
        fut_end = pd.Timestamp('2008-01-01', tz='UTC')

        # Build some simple Assets
        equity_asset = Equity(1, symbol="TESTEQ", end_date=eq_end)
        future_asset = Future(200, symbol="TESTFUT", end_date=fut_end)

        # Consume the Assets
        self.env.write_data(equities_identifiers=[equity_asset],
                            futures_identifiers=[future_asset])
        finder = AssetFinder(self.env.engine)

        # Test equality with newly built Assets
        self.assertEqual(equity_asset, finder.retrieve_asset(1))
        self.assertEqual(future_asset, finder.retrieve_asset(200))
        self.assertEqual(eq_end, finder.retrieve_asset(1).end_date)
        self.assertEqual(fut_end, finder.retrieve_asset(200).end_date)

    def test_sid_assignment(self):

        # This metadata does not contain SIDs
        metadata = ['PLAY', 'MSFT']

        today = normalize_date(pd.Timestamp('2015-07-09', tz='UTC'))

        # Write data with sid assignment
        self.env.write_data(equities_identifiers=metadata,
                            allow_sid_assignment=True)

        # Verify that Assets were built and different sids were assigned
        finder = AssetFinder(self.env.engine)
        play = finder.lookup_symbol('PLAY', today)
        msft = finder.lookup_symbol('MSFT', today)
        self.assertEqual('PLAY', play.symbol)
        self.assertIsNotNone(play.sid)
        self.assertNotEqual(play.sid, msft.sid)

    def test_sid_assignment_failure(self):

        # This metadata does not contain SIDs
        metadata = ['PLAY', 'MSFT']

        # Write data without sid assignment, asserting failure
        with self.assertRaises(SidAssignmentError):
            self.env.write_data(equities_identifiers=metadata,
                                allow_sid_assignment=False)

    def test_security_dates_warning(self):

        # Build an asset with an end_date
        eq_end = pd.Timestamp('2012-01-01', tz='UTC')
        equity_asset = Equity(1, symbol="TESTEQ", end_date=eq_end)

        # Catch all warnings
        with warnings.catch_warnings(record=True) as w:
            # Cause all warnings to always be triggered
            warnings.simplefilter("always")
            equity_asset.security_start_date
            equity_asset.security_end_date
            equity_asset.security_name
            # Verify the warning
            self.assertEqual(3, len(w))
            for warning in w:
                self.assertTrue(issubclass(warning.category,
                                           DeprecationWarning))

    def test_lookup_future_chain(self):
        metadata = {
            # Notice day is today, so not valid
            2: {
                'symbol': 'ADN15',
                'root_symbol': 'AD',
                'asset_type': 'future',
                'notice_date': pd.Timestamp('2015-05-14', tz='UTC'),
                'start_date': pd.Timestamp('2015-01-01', tz='UTC')
            },
            1: {
                'symbol': 'ADV15',
                'root_symbol': 'AD',
                'asset_type': 'future',
                'notice_date': pd.Timestamp('2015-08-14', tz='UTC'),
                'start_date': pd.Timestamp('2015-01-01', tz='UTC')
            },
            # Starts trading today, so should be valid.
            0: {
                'symbol': 'ADF16',
                'root_symbol': 'AD',
                'asset_type': 'future',
                'notice_date': pd.Timestamp('2015-11-16', tz='UTC'),
                'start_date': pd.Timestamp('2015-05-14', tz='UTC')
            },
            # Copy of the above future, but starts trading in August,
            # so it isn't valid.
            3: {
                'symbol': 'ADF16',
                'root_symbol': 'AD',
                'asset_type': 'future',
                'notice_date': pd.Timestamp('2015-11-16', tz='UTC'),
                'start_date': pd.Timestamp('2015-08-01', tz='UTC')
            },
        }
        self.env.write_data(futures_data=metadata)
        finder = AssetFinder(self.env.engine)
        dt = pd.Timestamp('2015-05-14', tz='UTC')
        last_year = pd.Timestamp('2014-01-01', tz='UTC')
        first_day = pd.Timestamp('2015-01-01', tz='UTC')

        # Check that we get the expected number of contracts, in the
        # right order
        ad_contracts = finder.lookup_future_chain('AD', dt, dt)
        self.assertEqual(len(ad_contracts), 2)
        self.assertEqual(ad_contracts[0].sid, 1)
        self.assertEqual(ad_contracts[1].sid, 0)

        # Check that pd.NaT for knowledge_date uses the value of as_of_date
        ad_contracts = finder.lookup_future_chain('AD', dt, pd.NaT)
        self.assertEqual(len(ad_contracts), 2)

        # Check that we get nothing if our knowledge date is last year
        ad_contracts = finder.lookup_future_chain('AD', dt, last_year)
        self.assertEqual(len(ad_contracts), 0)

        # Check that we get things that start on the knowledge date
        ad_contracts = finder.lookup_future_chain('AD', dt, first_day)
        self.assertEqual(len(ad_contracts), 1)

        # Check that pd.NaT for as_of_date gives the whole chain
        ad_contracts = finder.lookup_future_chain('AD', pd.NaT, first_day)
        self.assertEqual(len(ad_contracts), 4)

    def test_map_identifier_index_to_sids(self):
        # Build an empty finder and some Assets
        dt = pd.Timestamp('2014-01-01', tz='UTC')
        finder = AssetFinder(self.env.engine)
        asset1 = Equity(1, symbol="AAPL")
        asset2 = Equity(2, symbol="GOOG")
        asset200 = Future(200, symbol="CLK15")
        asset201 = Future(201, symbol="CLM15")

        # Check for correct mapping and types
        pre_map = [asset1, asset2, asset200, asset201]
        post_map = finder.map_identifier_index_to_sids(pre_map, dt)
        self.assertListEqual([1, 2, 200, 201], post_map)
        for sid in post_map:
            self.assertIsInstance(sid, int)

        # Change order and check mapping again
        pre_map = [asset201, asset2, asset200, asset1]
        post_map = finder.map_identifier_index_to_sids(pre_map, dt)
        self.assertListEqual([201, 2, 200, 1], post_map)

    def test_compute_lifetimes(self):
        """``lifetimes`` flags each asset as alive on the dates it trades.

        Checked for every date index produced by ``all_subindices``, both
        with and without counting an asset's start date as alive.
        """
        num_assets = 4
        # setUp already provides a fresh environment; reuse it rather than
        # building a second one.
        env = self.env
        trading_day = env.trading_day
        first_start = pd.Timestamp('2015-04-01', tz='UTC')

        frame = make_rotating_asset_info(
            num_assets=num_assets,
            first_start=first_start,
            frequency=trading_day,
            periods_between_starts=3,
            asset_lifetime=5
        )

        env.write_data(equities_df=frame)
        finder = env.asset_finder

        all_dates = pd.date_range(
            start=first_start,
            end=frame.end_date.max(),
            freq=trading_day,
        )
        # (sid, start, end) rows do not depend on the dates being checked,
        # so materialize them once instead of once per date.
        asset_spans = list(frame[['start_date', 'end_date']].itertuples())

        for dates in all_subindices(all_dates):
            expected_with_start_raw = full(
                shape=(len(dates), num_assets),
                fill_value=False,
                dtype=bool,
            )
            expected_no_start_raw = full(
                shape=(len(dates), num_assets),
                fill_value=False,
                dtype=bool,
            )

            for i, date in enumerate(dates):
                # NOTE: the sid is used directly as the column position,
                # which assumes the frame's sids are 0..num_assets-1.
                for j, start, end in asset_spans:
                    # This way of doing the checks is redundant, but very
                    # clear.
                    if start <= date <= end:
                        expected_with_start_raw[i, j] = True
                        if start < date:
                            expected_no_start_raw[i, j] = True

            expected_with_start = pd.DataFrame(
                data=expected_with_start_raw,
                index=dates,
                columns=frame.index.values,
            )
            result = finder.lifetimes(dates, include_start_date=True)
            assert_frame_equal(result, expected_with_start)

            expected_no_start = pd.DataFrame(
                data=expected_no_start_raw,
                index=dates,
                columns=frame.index.values,
            )
            result = finder.lifetimes(dates, include_start_date=False)
            assert_frame_equal(result, expected_no_start)

    def test_sids(self):
        """The finder's ``sids`` property lists every sid that was written."""
        # setUp already provides a fresh environment; reuse it.
        self.env.write_data(equities_identifiers=[1, 2, 3])
        sids = self.env.asset_finder.sids
        self.assertEqual(3, len(sids))
        for sid in (1, 2, 3):
            # assertIn names the missing value on failure, unlike
            # assertTrue(sid in sids).
            self.assertIn(sid, sids)
Ejemplo n.º 36
0
def build_lookup_generic_cases():
    """
    Build the parameter tuples for AssetFinder's ``test_lookup_generic``.

    Three equities are written to a fresh ``TradingEnvironment``:
    sids 0 and 1 share the symbol ``"duplicated"`` over one-day windows
    starting 2013-01-01 and 2013-01-03, and sid 2 has the symbol
    ``"unique"`` from 2013-01-01 to 2014-01-01.

    Returns
    -------
    list of tuple
        Each tuple is ``(finder, symbols, reference_date, expected)``:
        ``symbols`` is passed to ``finder.lookup_generic`` together with
        ``reference_date`` (possibly None), and ``expected`` is the asset or
        list of assets the lookup should resolve to.
    """
    one_day = timedelta(days=1)
    unique_start = pd.Timestamp("2013-01-01", tz="UTC")
    unique_end = pd.Timestamp("2014-01-01", tz="UTC")
    first_dupe_start = pd.Timestamp("2013-01-01", tz="UTC")
    second_dupe_start = pd.Timestamp("2013-01-03", tz="UTC")

    # (sid, symbol, start, end) per equity; dates are passed to the writer as
    # nanosecond integers (Timestamp.value).
    specs = [
        (0, "duplicated", first_dupe_start, first_dupe_start + one_day),
        (1, "duplicated", second_dupe_start, second_dupe_start + one_day),
        (2, "unique", unique_start, unique_end),
    ]
    frame = pd.DataFrame.from_records(
        [
            {
                "sid": sid,
                "symbol": symbol,
                "start_date": start.value,
                "end_date": end.value,
                "exchange": "",
            }
            for sid, symbol, start, end in specs
        ],
        index="sid",
    )

    env = TradingEnvironment()
    env.write_data(equities_df=frame)
    finder = env.asset_finder

    assets = [finder.retrieve_asset(sid) for sid in range(3)]
    dupe_0, dupe_1, unique = assets

    # Resolve the duplicated symbol with the start dates the finder reports.
    dupe_0_start = dupe_0.start_date
    dupe_1_start = dupe_1.start_date

    # Scalars: Asset objects, then integer sids, then symbols.
    scalar_cases = (
        [(finder, asset, None, asset) for asset in assets]
        + [(finder, sid, None, asset) for sid, asset in enumerate(assets)]
        + [
            # Duplicated symbol with resolution date.
            (finder, "DUPLICATED", dupe_0_start, dupe_0),
            (finder, "DUPLICATED", dupe_1_start, dupe_1),
            # Unique symbol, with or without resolution date.
            (finder, "UNIQUE", unique_start, unique),
            (finder, "UNIQUE", None, unique),
        ]
    )

    # Iterables: of Asset objects, of ints, of symbols, and of mixed types.
    iterable_cases = [
        (finder, assets, None, assets),
        (finder, iter(assets), None, assets),
        (finder, (0, 1), None, assets[:-1]),
        (finder, iter((0, 1)), None, assets[:-1]),
        (finder, ("DUPLICATED", "UNIQUE"), dupe_0_start, [dupe_0, unique]),
        (finder, ("DUPLICATED", "UNIQUE"), dupe_1_start, [dupe_1, unique]),
        (finder, ("DUPLICATED", 2, "UNIQUE", 1, dupe_1), dupe_0_start, [dupe_0, assets[2], unique, assets[1], dupe_1]),
    ]

    return scalar_cases + iterable_cases
Ejemplo n.º 37
0
class AssetFinderTestCase(TestCase):
    """Tests for AssetFinder symbol lookup, metadata ingestion and lifetimes."""

    def setUp(self):
        # Every test starts from its own environment built with noop_load.
        self.env = TradingEnvironment(load=noop_load)

    def test_lookup_symbol_delimited(self):
        """Delimited symbols resolve through any supported delimiter.

        Equities ``TEST.0`` .. ``TEST.2`` are written. ``TEST``, ``TEST1``
        and ``TEST@1`` ('@' is unsupported) must not resolve, while
        ``TEST-1``, ``TEST/1``, ``TEST_1`` and ``TEST.1`` resolve to sid 1.
        """
        as_of = pd.Timestamp("2013-01-01", tz="UTC")
        frame = pd.DataFrame.from_records(
            [
                {
                    "sid": i,
                    "symbol": "TEST.%d" % i,
                    "company_name": "company%d" % i,
                    "start_date": as_of.value,
                    "end_date": as_of.value,
                    "exchange": uuid.uuid4().hex,
                }
                for i in range(3)
            ]
        )
        self.env.write_data(equities_df=frame)
        finder = AssetFinder(self.env.engine)
        asset_0, asset_1, asset_2 = (finder.retrieve_asset(i) for i in range(3))

        # we do it twice to catch caching bugs
        for _ in range(2):
            with self.assertRaises(SymbolNotFound):
                finder.lookup_symbol("TEST", as_of)
            with self.assertRaises(SymbolNotFound):
                finder.lookup_symbol("TEST1", as_of)
            # '@' is not a supported delimiter
            with self.assertRaises(SymbolNotFound):
                finder.lookup_symbol("TEST@1", as_of)

            # Adding an unnecessary fuzzy shouldn't matter.
            for fuzzy_char in ["-", "/", "_", "."]:
                self.assertEqual(asset_1, finder.lookup_symbol("TEST%s1" % fuzzy_char, as_of))

    def test_lookup_symbol_fuzzy(self):
        """``fuzzy=True`` lets ``PRTYHRD`` match ``PRTY_HRD``; exact symbols still win."""
        metadata = {0: {"symbol": "PRTY_HRD"}, 1: {"symbol": "BRKA"}, 2: {"symbol": "BRK_A"}}
        self.env.write_data(equities_data=metadata)
        finder = self.env.asset_finder
        dt = pd.Timestamp("2013-01-01", tz="UTC")

        # Try combos of looking up PRTYHRD with and without a time or fuzzy
        # Both non-fuzzys get no result
        with self.assertRaises(SymbolNotFound):
            finder.lookup_symbol("PRTYHRD", None)
        with self.assertRaises(SymbolNotFound):
            finder.lookup_symbol("PRTYHRD", dt)
        # Both fuzzys work
        self.assertEqual(0, finder.lookup_symbol("PRTYHRD", None, fuzzy=True))
        self.assertEqual(0, finder.lookup_symbol("PRTYHRD", dt, fuzzy=True))

        # Try combos of looking up PRTY_HRD, all returning sid 0
        self.assertEqual(0, finder.lookup_symbol("PRTY_HRD", None))
        self.assertEqual(0, finder.lookup_symbol("PRTY_HRD", dt))
        self.assertEqual(0, finder.lookup_symbol("PRTY_HRD", None, fuzzy=True))
        self.assertEqual(0, finder.lookup_symbol("PRTY_HRD", dt, fuzzy=True))

        # Try combos of looking up BRKA, all returning sid 1
        self.assertEqual(1, finder.lookup_symbol("BRKA", None))
        self.assertEqual(1, finder.lookup_symbol("BRKA", dt))
        self.assertEqual(1, finder.lookup_symbol("BRKA", None, fuzzy=True))
        self.assertEqual(1, finder.lookup_symbol("BRKA", dt, fuzzy=True))

        # Try combos of looking up BRK_A, all returning sid 2
        self.assertEqual(2, finder.lookup_symbol("BRK_A", None))
        self.assertEqual(2, finder.lookup_symbol("BRK_A", dt))
        self.assertEqual(2, finder.lookup_symbol("BRK_A", None, fuzzy=True))
        self.assertEqual(2, finder.lookup_symbol("BRK_A", dt, fuzzy=True))

    def test_lookup_symbol(self):
        """A symbol shared by several sids resolves by date, and is ambiguous without one."""

        # Incrementing by two so that start and end dates for each
        # generated Asset don't overlap (each Asset's end_date is the
        # day after its start date.)
        dates = pd.date_range("2013-01-01", freq="2D", periods=5, tz="UTC")
        df = pd.DataFrame.from_records(
            [
                {
                    "sid": i,
                    "symbol": "existing",
                    "start_date": date.value,
                    "end_date": (date + timedelta(days=1)).value,
                    "exchange": "NYSE",
                }
                for i, date in enumerate(dates)
            ]
        )
        self.env.write_data(equities_df=df)
        finder = AssetFinder(self.env.engine)
        for _ in range(2):  # Run checks twice to test for caching bugs.
            with self.assertRaises(SymbolNotFound):
                finder.lookup_symbol("NON_EXISTING", dates[0])

            with self.assertRaises(MultipleSymbolsFound):
                finder.lookup_symbol("EXISTING", None)

            for i, date in enumerate(dates):
                # Verify that we correctly resolve multiple symbols using
                # the supplied date
                result = finder.lookup_symbol("EXISTING", date)
                self.assertEqual(result.symbol, "EXISTING")
                self.assertEqual(result.sid, i)

    @parameterized.expand(build_lookup_generic_cases())
    def test_lookup_generic(self, finder, symbols, reference_date, expected):
        """
        Ensure that lookup_generic works with various permutations of inputs.
        """
        results, missing = finder.lookup_generic(symbols, reference_date)
        self.assertEqual(results, expected)
        self.assertEqual(missing, [])

    def test_lookup_generic_handle_missing(self):
        """Unresolvable identifiers are returned in ``missing``, in request order."""
        data = pd.DataFrame.from_records(
            [
                {
                    "sid": 0,
                    "symbol": "real",
                    "start_date": pd.Timestamp("2013-1-1", tz="UTC"),
                    "end_date": pd.Timestamp("2014-1-1", tz="UTC"),
                    "exchange": "",
                },
                {
                    "sid": 1,
                    "symbol": "also_real",
                    "start_date": pd.Timestamp("2013-1-1", tz="UTC"),
                    "end_date": pd.Timestamp("2014-1-1", tz="UTC"),
                    "exchange": "",
                },
                # Sid whose end date is before our query date.  We should
                # still correctly find it.
                {
                    "sid": 2,
                    "symbol": "real_but_old",
                    "start_date": pd.Timestamp("2002-1-1", tz="UTC"),
                    "end_date": pd.Timestamp("2003-1-1", tz="UTC"),
                    "exchange": "",
                },
                # Sid whose start_date is **after** our query date.  We should
                # **not** find it.
                {
                    "sid": 3,
                    "symbol": "real_but_in_the_future",
                    "start_date": pd.Timestamp("2014-1-1", tz="UTC"),
                    "end_date": pd.Timestamp("2020-1-1", tz="UTC"),
                    "exchange": "THE FUTURE",
                },
            ]
        )
        self.env.write_data(equities_df=data)
        finder = AssetFinder(self.env.engine)
        results, missing = finder.lookup_generic(
            ["REAL", 1, "FAKE", "REAL_BUT_OLD", "REAL_BUT_IN_THE_FUTURE"], pd.Timestamp("2013-02-01", tz="UTC")
        )

        self.assertEqual(len(results), 3)
        self.assertEqual(results[0].symbol, "REAL")
        self.assertEqual(results[0].sid, 0)
        self.assertEqual(results[1].symbol, "ALSO_REAL")
        self.assertEqual(results[1].sid, 1)
        self.assertEqual(results[2].symbol, "REAL_BUT_OLD")
        self.assertEqual(results[2].sid, 2)

        self.assertEqual(len(missing), 2)
        self.assertEqual(missing[0], "FAKE")
        self.assertEqual(missing[1], "REAL_BUT_IN_THE_FUTURE")

    def test_insert_metadata(self):
        """Known metadata fields are stored; unknown ones (``foo_data``) are not exposed."""
        data = {0: {"start_date": "2014-01-01", "end_date": "2015-01-01", "symbol": "PLAY", "foo_data": "FOO"}}
        self.env.write_data(equities_data=data)
        finder = AssetFinder(self.env.engine)
        # Test proper insertion
        equity = finder.retrieve_asset(0)
        self.assertIsInstance(equity, Equity)
        self.assertEqual("PLAY", equity.symbol)
        self.assertEqual(pd.Timestamp("2015-01-01", tz="UTC"), equity.end_date)

        # Test invalid field
        with self.assertRaises(AttributeError):
            equity.foo_data

    def test_consume_metadata(self):
        """Equity metadata can be written from a dict or from a DataFrame."""

        # Test dict consumption
        dict_to_consume = {0: {"symbol": "PLAY"}, 1: {"symbol": "MSFT"}}
        self.env.write_data(equities_data=dict_to_consume)
        finder = AssetFinder(self.env.engine)

        equity = finder.retrieve_asset(0)
        self.assertIsInstance(equity, Equity)
        self.assertEqual("PLAY", equity.symbol)

        # Test dataframe consumption. Assign through .loc: chained indexing
        # (df[col][row] = value) may write to a temporary copy and never
        # reach ``df`` (it never does under pandas Copy-on-Write).
        df = pd.DataFrame(columns=["asset_name", "exchange"], index=[0, 1])
        df.loc[0, "asset_name"] = "Dave'N'Busters"
        df.loc[0, "exchange"] = "NASDAQ"
        df.loc[1, "asset_name"] = "Microsoft"
        df.loc[1, "exchange"] = "NYSE"
        self.env = TradingEnvironment(load=noop_load)
        self.env.write_data(equities_df=df)
        finder = AssetFinder(self.env.engine)
        self.assertEqual("NASDAQ", finder.retrieve_asset(0).exchange)
        self.assertEqual("Microsoft", finder.retrieve_asset(1).asset_name)

    def test_consume_asset_as_identifier(self):
        """Asset objects passed as identifiers round-trip through the finder."""
        # Build some end dates
        eq_end = pd.Timestamp("2012-01-01", tz="UTC")
        fut_end = pd.Timestamp("2008-01-01", tz="UTC")

        # Build some simple Assets
        equity_asset = Equity(1, symbol="TESTEQ", end_date=eq_end)
        future_asset = Future(200, symbol="TESTFUT", end_date=fut_end)

        # Consume the Assets
        self.env.write_data(equities_identifiers=[equity_asset], futures_identifiers=[future_asset])
        finder = AssetFinder(self.env.engine)

        # Test equality with newly built Assets
        self.assertEqual(equity_asset, finder.retrieve_asset(1))
        self.assertEqual(future_asset, finder.retrieve_asset(200))
        self.assertEqual(eq_end, finder.retrieve_asset(1).end_date)
        self.assertEqual(fut_end, finder.retrieve_asset(200).end_date)

    def test_sid_assignment(self):
        """With ``allow_sid_assignment=True``, bare symbols receive distinct sids."""

        # This metadata does not contain SIDs
        metadata = ["PLAY", "MSFT"]

        today = normalize_date(pd.Timestamp("2015-07-09", tz="UTC"))

        # Write data with sid assignment
        self.env.write_data(equities_identifiers=metadata, allow_sid_assignment=True)

        # Verify that Assets were built and different sids were assigned
        finder = AssetFinder(self.env.engine)
        play = finder.lookup_symbol("PLAY", today)
        msft = finder.lookup_symbol("MSFT", today)
        self.assertEqual("PLAY", play.symbol)
        self.assertIsNotNone(play.sid)
        self.assertNotEqual(play.sid, msft.sid)

    def test_sid_assignment_failure(self):
        """Without sid assignment, writing bare symbols raises SidAssignmentError."""

        # This metadata does not contain SIDs
        metadata = ["PLAY", "MSFT"]

        # Write data without sid assignment, asserting failure
        with self.assertRaises(SidAssignmentError):
            self.env.write_data(equities_identifiers=metadata, allow_sid_assignment=False)

    def test_security_dates_warning(self):
        """Each legacy ``security_*`` attribute emits a DeprecationWarning."""

        # Build an asset with an end_date
        eq_end = pd.Timestamp("2012-01-01", tz="UTC")
        equity_asset = Equity(1, symbol="TESTEQ", end_date=eq_end)

        # Catch all warnings
        with warnings.catch_warnings(record=True) as w:
            # Cause all warnings to always be triggered
            warnings.simplefilter("always")
            equity_asset.security_start_date
            equity_asset.security_end_date
            equity_asset.security_name
            # Verify the warning
            self.assertEqual(3, len(w))
            for warning in w:
                self.assertTrue(issubclass(warning.category, DeprecationWarning))

    def test_lookup_future_chain(self):
        """``lookup_future_chain`` returns the contracts valid at a date, in order."""
        metadata = {
            # Notice day is today, so should be valid.
            0: {
                "symbol": "ADN15",
                "root_symbol": "AD",
                "notice_date": pd.Timestamp("2015-05-14", tz="UTC"),
                "expiration_date": pd.Timestamp("2015-06-14", tz="UTC"),
                "start_date": pd.Timestamp("2015-01-01", tz="UTC"),
            },
            1: {
                "symbol": "ADV15",
                "root_symbol": "AD",
                "notice_date": pd.Timestamp("2015-08-14", tz="UTC"),
                "expiration_date": pd.Timestamp("2015-09-14", tz="UTC"),
                "start_date": pd.Timestamp("2015-01-01", tz="UTC"),
            },
            # Starts trading today, so should be valid.
            2: {
                "symbol": "ADF16",
                "root_symbol": "AD",
                "notice_date": pd.Timestamp("2015-11-16", tz="UTC"),
                "expiration_date": pd.Timestamp("2015-12-16", tz="UTC"),
                "start_date": pd.Timestamp("2015-05-14", tz="UTC"),
            },
            # Starts trading in August, so not valid.
            3: {
                "symbol": "ADX16",
                "root_symbol": "AD",
                "notice_date": pd.Timestamp("2015-11-16", tz="UTC"),
                "expiration_date": pd.Timestamp("2015-12-16", tz="UTC"),
                "start_date": pd.Timestamp("2015-08-01", tz="UTC"),
            },
            # Notice date comes after expiration
            4: {
                "symbol": "ADZ16",
                "root_symbol": "AD",
                "notice_date": pd.Timestamp("2016-11-25", tz="UTC"),
                "expiration_date": pd.Timestamp("2016-11-16", tz="UTC"),
                "start_date": pd.Timestamp("2015-08-01", tz="UTC"),
            },
            # This contract has no start date and also this contract should be
            # last in all chains
            5: {
                "symbol": "ADZ20",
                "root_symbol": "AD",
                "notice_date": pd.Timestamp("2020-11-25", tz="UTC"),
                "expiration_date": pd.Timestamp("2020-11-16", tz="UTC"),
            },
        }
        self.env.write_data(futures_data=metadata)
        finder = AssetFinder(self.env.engine)
        dt = pd.Timestamp("2015-05-14", tz="UTC")
        dt_2 = pd.Timestamp("2015-10-14", tz="UTC")
        dt_3 = pd.Timestamp("2016-11-17", tz="UTC")

        # Check that we get the expected number of contracts, in the
        # right order
        # NOTE(review): all 6 contracts are expected at ``dt`` even though the
        # metadata comments above call sid 3 "not valid" (starts in August);
        # confirm whether start_date is meant to filter the chain.
        ad_contracts = finder.lookup_future_chain("AD", dt)
        self.assertEqual(len(ad_contracts), 6)
        self.assertEqual(ad_contracts[0].sid, 0)
        self.assertEqual(ad_contracts[1].sid, 1)
        self.assertEqual(ad_contracts[5].sid, 5)

        # Check that, when some contracts have expired, the chain has advanced
        # properly to the next contracts
        ad_contracts = finder.lookup_future_chain("AD", dt_2)
        self.assertEqual(len(ad_contracts), 4)
        self.assertEqual(ad_contracts[0].sid, 2)
        self.assertEqual(ad_contracts[3].sid, 5)

        # Check that when the expiration_date has passed but the
        # notice_date hasn't, contract is still considered invalid.
        ad_contracts = finder.lookup_future_chain("AD", dt_3)
        self.assertEqual(len(ad_contracts), 1)
        self.assertEqual(ad_contracts[0].sid, 5)

        # Check that pd.NaT for as_of_date gives the whole chain
        ad_contracts = finder.lookup_future_chain("AD", pd.NaT)
        self.assertEqual(len(ad_contracts), 6)
        self.assertEqual(ad_contracts[5].sid, 5)

    def test_map_identifier_index_to_sids(self):
        """Asset objects map to their int sids, preserving input order."""
        # Build an empty finder and some Assets
        dt = pd.Timestamp("2014-01-01", tz="UTC")
        finder = AssetFinder(self.env.engine)
        asset1 = Equity(1, symbol="AAPL")
        asset2 = Equity(2, symbol="GOOG")
        asset200 = Future(200, symbol="CLK15")
        asset201 = Future(201, symbol="CLM15")

        # Check for correct mapping and types
        pre_map = [asset1, asset2, asset200, asset201]
        post_map = finder.map_identifier_index_to_sids(pre_map, dt)
        self.assertListEqual([1, 2, 200, 201], post_map)
        for sid in post_map:
            self.assertIsInstance(sid, int)

        # Change order and check mapping again
        pre_map = [asset201, asset2, asset200, asset1]
        post_map = finder.map_identifier_index_to_sids(pre_map, dt)
        self.assertListEqual([201, 2, 200, 1], post_map)

    def test_compute_lifetimes(self):
        """``finder.lifetimes`` flags each asset as alive on each date of its lifetime.

        For every sub-index produced by ``all_subindices``, asset j is expected
        to be alive on date i when ``start <= date <= end``, or, with
        ``include_start_date=False``, when ``start < date <= end``.
        """
        num_assets = 4
        trading_day = self.env.trading_day
        first_start = pd.Timestamp("2015-04-01", tz="UTC")

        frame = make_rotating_asset_info(
            num_assets=num_assets,
            first_start=first_start,
            frequency=trading_day,
            periods_between_starts=3,
            asset_lifetime=5,
        )

        self.env.write_data(equities_df=frame)
        finder = self.env.asset_finder

        all_dates = pd.date_range(start=first_start, end=frame.end_date.max(), freq=trading_day)

        for dates in all_subindices(all_dates):
            expected_with_start_raw = full(shape=(len(dates), num_assets), fill_value=False, dtype=bool)
            expected_no_start_raw = full(shape=(len(dates), num_assets), fill_value=False, dtype=bool)

            for i, date in enumerate(dates):
                it = frame[["start_date", "end_date"]].itertuples()
                for j, start, end in it:
                    # This way of doing the checks is redundant, but very
                    # clear.
                    if start <= date <= end:
                        expected_with_start_raw[i, j] = True
                        if start < date:
                            expected_no_start_raw[i, j] = True

            expected_with_start = pd.DataFrame(data=expected_with_start_raw, index=dates, columns=frame.index.values)
            result = finder.lifetimes(dates, include_start_date=True)
            assert_frame_equal(result, expected_with_start)

            expected_no_start = pd.DataFrame(data=expected_no_start_raw, index=dates, columns=frame.index.values)
            result = finder.lifetimes(dates, include_start_date=False)
            assert_frame_equal(result, expected_no_start)

    def test_sids(self):
        """The AssetFinder ``sids`` property reports every written equity sid."""
        self.env.write_data(equities_identifiers=[1, 2, 3])
        sids = self.env.asset_finder.sids
        self.assertEqual(3, len(sids))
        # assertIn reports the missing value and the container on failure.
        for expected_sid in (1, 2, 3):
            self.assertIn(expected_sid, sids)
Ejemplo n.º 38
0
class Matcher:
  def __init__(self):
    """Create the trading environment, simulation parameters and calendar.

    The simulation parameters span 2013-01-01 UTC up to the current UTC time
    at minute frequency; the trading calendar is "AlwaysOpen".
    """
    self.env = TradingEnvironment(load=load_nothing)

    # prepare for data portal
    # from zipline/tests/test_finance.py#L238
    # Note that 2013-01-05 and 2013-01-06 were Sat/Sun
    # Also note that in UTC, NYSE starts trading at 14.30
    # TODO tailor for FFA Dubai
    #  start_date=pd.Timestamp('2013-12-08 9:31AM', tz='UTC'),
    start_date = pd.Timestamp('2013-01-01', tz='utc')
    # pd.Timestamp(datetime.now(), tz='utc') would label the host's *local*
    # wall-clock time as UTC, shifting the end by the host's UTC offset.
    # Ask for the current instant in UTC directly.
    end_date = pd.Timestamp.now(tz='utc')
    self.sim_params = zl_factory.create_simulation_parameters(
        start=start_date,
        end=end_date,
        data_frequency="minute"
    )

    # Other calendars tried: get_calendar("NYSE"), get_calendar("ICEUS")
    self.trading_calendar = get_calendar("AlwaysOpen")

  def get_minutes(self,fills,orders):
    if len(fills)==0 and len(orders)==0:
      return []

    # get fill minutes
    minutes_fills = [sid.index.values.tolist() for _, sid in fills.items()]
    minutes_fills = reduce_concatenate(minutes_fills)

    # get order minutes
    minutes_orders = [list(o.values()) for o in list(orders.values())]
    minutes_orders = [o["dt"] for o in reduce_concatenate(minutes_orders)]

    # concatenate
    #print("minutes",[type(x).__name__ for x in minutes])
    #print("minutes",minutes)
    minutes = numpy.concatenate((
      minutes_fills,
      minutes_orders
    ))

    minutes = [pd.Timestamp(x, tz='utc') for x in minutes]
    minutes = list(set(minutes))
    minutes.sort()
    #print("minutes",minutes)
    return minutes

  @staticmethod
  def chopSeconds(fills, orders):
    for sid in fills:
      # for fills, start by adding a floored-minute "index" column, and group by it
      # Combination of
      #    http://stackoverflow.com/a/34297689/4126114
      #    http://stackoverflow.com/a/29583335/4126114
      fills[sid] = fills[sid].reset_index()
      fills[sid]['dt_fl'] = [dt.floor('1Min') for dt in fills[sid]['dt']]
      del fills[sid]['dt']

      # group in preparation of aggregation
      # http://pandas.pydata.org/pandas-docs/stable/groupby.html#applying-different-functions-to-dataframe-columns
      grouped = fills[sid].groupby('dt_fl')

      # weighted-average of price
      # http://stackoverflow.com/a/35327787/4126114
      grouped_close = grouped.apply(lambda g: numpy.average(g['close'],weights=g['volume']))
      grouped_close = pd.DataFrame({'close':grouped_close})

      # add all volumes of fills at the same minute
      grouped_volume = grouped['volume'].aggregate('sum')
      grouped2 = pd.concat([grouped_close, grouped_volume], axis=1)

      # go back to original indexing
      # http://pandas.pydata.org/pandas-docs/stable/groupby.html#aggregation
      grouped2.reset_index().set_index('dt_fl')
      grouped2.sort_index(inplace=True)

      # replace original
      fills[sid] = grouped2

    # For orders, just floor the minute
    for sid in orders:
      for oid,order in orders[sid].items():
        order["dt"]=pd.Timestamp(order["dt"],tz="utc").floor('1Min')

    #print("chop seconds fills ", fills)
    #print("chop seconds orders", orders)
    return fills, orders

  def fills2reader(self, tempdir, minutes, fills, orders):
    """Write ``fills`` as minute bars under ``tempdir.path`` and return a reader.

    Before writing, each fill frame gets temporary open/high/low columns equal
    to its close, and its volume is made non-negative; the sign is kept in a
    temporary "is_neg" column. Sids present in ``orders`` but not in ``fills``
    get a placeholder frame (zero prices and volume at ``minutes[0]``). These
    placeholders are added to ``fills`` and stay there after the call.

    The bars are written with ``BcolzMinuteBarWriter`` for the sessions
    spanning ``minutes[0]`` .. ``minutes[-1]``. Afterwards the temporary
    columns are dropped and the volume signs restored, even if writing fails.

    Returns None when ``minutes`` is empty, otherwise a
    ``BcolzMinuteBarReader`` over the written data.
    """
    if len(minutes)==0:
      return None

    for _,fill in fills.items():
      fill["open"] = fill["close"]
      fill["high"] = fill["close"]
      fill["low"]  = fill["close"]

      # since the below abs affects the original dataframe, storing the sign for later revert
      fill["is_neg"] = fill["volume"]<0

      # take absolute value, since negatives are split in the factory function to begin with
      # and zipline doesnt support negative OHLC volumes (which dont make sense anyway)
      fill["volume"] = abs(fill["volume"])

    # append empty OHLC dataframes for sid's in orders but not (yet) in fills
    # dummy OHLC data with volume=0 so as not to affect orders
    empty = {"open":[0], "high":[0], "low":[0], "close":[0], "volume":[0], "dt":[minutes[0]], "is_neg":[False]}
    for sid in orders:
      if sid not in fills:
        fills[sid]=pd.DataFrame(empty).set_index("dt")

    path = tempdir.path
    try:
      d1 = self.trading_calendar.minute_to_session_label(minutes[0])
      d2 = self.trading_calendar.minute_to_session_label(minutes[-1])
      days = self.trading_calendar.sessions_in_range(d1, d2)

      writer = BcolzMinuteBarWriter(
        rootdir=path,
        calendar=self.trading_calendar,
        start_session=days[0],
        end_session=days[-1],
        minutes_per_day=1440
      )
      writer.write(iteritems(fills))
    finally:
      # Revert the volume sign and drop the temporary columns whether or not
      # the write succeeded, so the caller's frames are not left modified
      # (apart from the placeholder frames added above).
      for _,fill in fills.items():
        del fill["open"]
        del fill["high"]
        del fill["low"]
        if any(fill["is_neg"]):
          fill.loc[fill["is_neg"],"volume"] = -1 * fill["volume"]
        del fill["is_neg"]

    reader = BcolzMinuteBarReader(path)

    return reader

  # save an asset
  def write_assets(self, assets: dict):
    # unique assets by using sid
    # http://stackoverflow.com/a/11092590/4126114
    if not any(assets):
      #raise ValueError("Got empty orders!")
      return

    # make assets unique by "symbol" field also
    assets2 = { a["symbol"]: {"k":k,"a":a} for k,a in assets.items() }
    assets2  = {v["k"]: v["a"] for v in assets2.values() }

    # log dropped sid's
    dropped = [k for k in assets.keys() if k not in assets2.keys()]
    if len(dropped)>0: logger.error("Dropped asset ID with duplicated symbol: %s" % dropped)

    assets = assets2

    # check zipline/zipline/assets/asset_writer.py#write
    df = pd.DataFrame(
        {
          "sid"       : list(assets.keys()),
          "exchange"  : [asset["exchange"] for asset in list(assets.values())],
          "symbol"    : [asset["symbol"] for asset in list(assets.values())],
          "asset_name": [asset["name"] for asset in list(assets.values())],
        }
    ).set_index("sid")
    #print("write data",df)
    self.env.write_data(equities=df)

  def get_blotter(self):
    """Build a ``MyBlotter`` bound to this matcher's asset finder.

    Slippage is ``VolumeShareSlippage(volume_limit=1, price_impact=0)``, and
    the cancel policy is ``NeverCancel()``: the blotter itself never cancels
    orders automatically.
    """
    slippage = VolumeShareSlippage(volume_limit=1, price_impact=0)
    # Cancel policy choice follows
    # https://github.com/quantopian/zipline/blob/3350227f44dcf36b6fe3c509dcc35fe512965183/tests/test_blotter.py#L136
    return MyBlotter(
      data_frequency=self.sim_params.data_frequency,
      asset_finder=self.env.asset_finder,
      slippage_func=slippage,
      cancel_policy=NeverCancel(),
    )

  def _orders2blotter(self, orders, blotter):
    #print("Place orders")
    for sid in orders:

      # append 'id' in object, otherwise it will be lost after the sorting
      for oid,order in orders[sid].items():
        order["id"]=oid

      # sort by field
      # http://stackoverflow.com/questions/72899/ddg#73050
      orders2 = sorted(orders[sid].values(), key=lambda k: k['dt'])

      for order in orders2:
        # 2017-02-17: Actually it's a good idea to allo orders to match with earlier fill
        #             It allows to assign extra fills to an error account for example, or to another client
        #             Note that this is coupled with:
        #             - moving the _orders2blotter out of the for loop in match_orders_fills
        #             - adding the blotter.set_date below
        #             - sorting the orders by ascending time above
        # 2017-02-15: skip orders in the future
        #if order["dt"] > blotter.current_dt:
        #  #logger.debug("Order in future skipped: %s" % order)
        #  continue
        #if oid in blotter.orders:
        #  #logger.debug("Order already included: %s" % order)
        #  continue
        blotter.set_date(order["dt"])

        #logger.debug("Order included: %s" % order)
        asset = self.env.asset_finder.retrieve_asset(sid=sid, default_none=True)

        blotter.order(
          sid=asset,
          amount=order["amount"],
          style=order["style"],
          order_id = order["id"],
          validity = None if "validity" not in order else order["validity"]
        )

    #print("Open orders: %s" % ({k.symbol: len(v) for k,v in iteritems(blotter.open_orders)}))
    return blotter

  def blotter2bardata(self, equity_minute_reader, blotter):
    if equity_minute_reader is None:
      return None

    dp = DataPortal(
      asset_finder=self.env.asset_finder,
      trading_calendar=self.trading_calendar,
      first_trading_day=equity_minute_reader.first_trading_day,
      equity_minute_reader=equity_minute_reader
    )

    restrictions=NoRestrictions()

    bd = BarData(
      data_portal=dp,
      simulation_dt_func=lambda: blotter.current_dt,
      data_frequency=self.sim_params.data_frequency,
      trading_calendar=self.trading_calendar,
      restrictions=restrictions
    )

    return bd

  # Cannot use zipline cancellation policy based on order expiration type, expiration date, and current date because it will cancel all orders (and not filter based on type)
  # Will use cancel orders one-by-one instead, with the NeverCancel being the default
  #
  # Definition of cancellation policies
  # https://github.com/quantopian/zipline/blob/e1b27c45ae4b881e5416a5c50e8945232527ea59/zipline/finance/cancel_policy.py
  #
  # Execution of cancellation policy cancels all
  # https://github.com/quantopian/zipline/blob/3350227f44dcf36b6fe3c509dcc35fe512965183/zipline/finance/blotter.py#L238
  def _cancel_expired(self, blotter):
    """Cancel open orders in ``blotter`` whose validity has expired.

    ``order.validity`` is None or a dict whose 'type' is one of
    ORDER_VALIDITY.GTC / GTD / DAY (GTD also carries a 'date'). Per order:
      * if the blotter clock is on a different calendar date than the wall
        clock (``timezone_django.now()``), DAY orders not placed today are
        cancelled;
      * orders dated after the blotter clock are left alone;
      * orders without validity, and GTC orders, are never cancelled;
      * GTD orders are cancelled once their 'date' is before the blotter clock;
      * DAY orders are cancelled once the blotter clock is on a different
        calendar date than the order.

    Raises ValueError for any other validity type.
    """
    # One wall-clock reading per pass, so all orders are judged against the same "today".
    now = timezone_django.now()
    today = now.date()
    current_date = blotter.current_dt.date()

    # Snapshot the mapping: blotter.cancel() modifies blotter.open_orders.
    for asset, orders in list(iteritems(blotter.open_orders)):
      order_ids = [order.id for order in orders]
      for order_id in order_ids:
        order = blotter.orders[order_id]
        # Dates are compared in full (not .day): the day-of-month alone
        # treats e.g. Jan 5 and Feb 5 as the same day.
        logger.debug('0. look to cancel order %s, now: %s, blotter: %s, order: %s'%(order.id, today, current_date, order.dt.date()))

        # cancel past day orders .. necessary if the latest open order s in the past
        if (
          today != current_date
          and order.validity is not None
          and order.validity['type'] == ORDER_VALIDITY.DAY
          and order.dt.date() != today
        ):
          logger.debug('1. will cancel order %s %s %s'%(order.id, today, current_date))
          blotter.cancel(order.id)
          continue

        # do not cancel orders not reached yet with the clock
        if order.dt > blotter.current_dt:
          continue

        # treat orders without a validity as GTC orders
        if order.validity is None:
          continue

        if order.validity['type']==ORDER_VALIDITY.GTC:
          continue
        if order.validity['type']==ORDER_VALIDITY.GTD:
          if order.validity['date'] < blotter.current_dt:
            logger.debug('2. will cancel order %s'%order.id)
            blotter.cancel(order.id)
          continue
        if order.validity['type']==ORDER_VALIDITY.DAY:
          if order.dt.date() != current_date:
            logger.debug('3. will cancel order %s'%order.id)
            blotter.cancel(order.id)
          continue
        raise ValueError("Invalid order validity type used: %s"%order.validity['type'])

  def match_orders_fills(self, blotter, bar_data, all_minutes, orders):
    """Replay ``all_minutes`` through ``blotter`` and collect what gets filled.

    ``orders`` are loaded into the blotter once, via ``self._orders2blotter``,
    before the replay. For every minute the method does four things:

    * moves the blotter clock to that minute (as a UTC ``pd.Timestamp``);
    * cancels expired orders via ``self._cancel_expired``;
    * asks the blotter for transactions against ``bar_data``;
    * prunes the orders the blotter reports as closed.

    After the replay, an ``order_cancelled`` signal is sent for every order in
    ``blotter.orders`` whose status is CANCELLED.

    Returns:
      (all_closed, all_txns): numpy arrays holding the closed orders and the
      transactions from every minute, in replay order. Both are empty
      (float64) arrays when nothing was closed or filled.
    """
    closed_list = []
    txn_list = []
    self._orders2blotter(orders, blotter)
    for dt in all_minutes:
      logger.debug("======================== %s", dt)
      dt = pd.Timestamp(dt, tz='utc')
      blotter.set_date(dt)

      self._cancel_expired(blotter)

      new_transactions, _new_commissions, closed_orders = blotter.get_transactions(bar_data)
      blotter.prune_orders(closed_orders)

      # Accumulate in plain lists. Concatenating numpy arrays every minute
      # re-copies everything collected so far, which is quadratic overall.
      closed_list.extend(closed_orders)
      txn_list.extend(new_transactions)

    # https://github.com/quantopian/zipline/blob/3350227f44dcf36b6fe3c509dcc35fe512965183/tests/test_blotter.py#L154
    for order in blotter.orders.values():
      if order.status == ORDER_STATUS.CANCELLED:
        order_cancelled.send(sender=None, id=order.id)

    # A single conversion at the end yields the same arrays the per-minute
    # numpy.concatenate produced.
    return numpy.array(closed_list), numpy.array(txn_list)

  # check if any volume was not used for the orders yet
  def unused_fills(self, all_txns, fills):
    """Return the fill volume not yet consumed by transactions, per asset.

    Args:
      all_txns: iterable of transactions. ``txn.sid.sid`` is matched against
        the keys of ``fills`` and ``txn.amount`` is summed per sid.
      fills: mapping ``sid -> fills`` whose ``.volume`` supports ``.sum()``.

    Returns:
      dict mapping the asset returned by ``self.env.asset_finder`` to the
      leftover volume (fill volume minus transacted amount). Only sids with a
      non-zero leftover are included. Sids the asset finder does not know
      are logged and skipped.
    """
    # Aggregate transacted amounts in a single pass, instead of rescanning
    # every transaction for every sid. This also works when all_txns is a
    # one-shot iterator.
    used = {}
    for txn in all_txns:
      key = txn.sid.sid
      used[key] = used.get(key, 0) + txn.amount

    unused = {}
    for sid, fill in fills.items():
      extra = fill.volume.sum() - used.get(sid, 0)
      if extra == 0:
        continue
      asset = self.env.asset_finder.retrieve_asset(sid=sid, default_none=True)
      # If the asset was already dropped because it was a duplicate, ignore it.
      if asset is None:
        logger.warning("Ignoring asset %s as it was not imported into ZlModel (possible duplicate symbol?)", sid)
        continue
      unused[asset] = extra
    return unused

  @staticmethod
  def filterBySign(mySign, fills_all, orders_all):
    """Split out the fills and orders whose quantity has the sign of ``mySign``.

    Args:
      mySign: a positive or negative number; only quantities ``q`` with
        ``q * mySign > 0`` are kept.
      fills_all: mapping ``sid -> DataFrame`` with a ``'volume'`` column.
      orders_all: mapping ``sid -> {order_id: order}`` where each order has an
        ``'amount'`` entry.

    Returns:
      (fills_sub, orders_sub): the same two mappings restricted to matching
      rows/orders. Sids left with nothing are omitted.
    """
    def has_sign(quantity):
      # Works element-wise on a pandas Series as well as on a scalar.
      return quantity * mySign > 0

    fills_sub = {}
    for sid in fills_all:
      frame = fills_all[sid]
      kept_rows = frame[has_sign(frame['volume'])]
      if len(kept_rows) > 0:
        # Take a real copy so later edits don't act on a view of the original
        # http://stackoverflow.com/a/32682095/4126114
        fills_sub[sid] = kept_rows.copy()

    orders_sub = {}
    for sid in orders_all:
      kept_orders = {oid: order
                     for oid, order in orders_all[sid].items()
                     if has_sign(order['amount'])}
      if kept_orders:
        orders_sub[sid] = kept_orders

    return fills_sub, orders_sub
Ejemplo n.º 39
0
    def transaction_sim(self, **params):
        """Run a synthetic trade history through a Blotter and check the fills.

        This is a utility method that asserts expected results for the
        conversion of orders to transactions given a trade history. A single
        equity (sid 1) gets constant bars (price 10.1, volume 100). Minute
        bars are written when ``trade_interval`` is under one day; otherwise
        daily bars are written. Market orders are then placed every
        ``order_interval``, and the blotter is stepped through every tick,
        collecting transactions.

        Required ``params`` keys (a missing key raises KeyError):
            trade_count, trade_interval: ``trade_interval`` selects minute
                vs daily data. In minute mode, both together size the minute
                window (plus 100 extra minutes).
            order_count, order_amount, order_interval: how many orders to
                place, their base size, and the spacing between them.
            expected_txn_count, expected_txn_volume: the asserted number of
                transactions and their summed amount.

        Optional ``params`` keys:
            alternate: if truthy, order direction alternates long/short.
            complete_fill: if truthy, there must be one transaction per
                order, each matching its order's amount exactly.
            default_slippage: if truthy, ``None`` is passed to the Blotter as
                the slippage model instead of ``FixedSlippage()``.

        Raises:
            AssertionError: via the ``self.assert*`` helpers, when the
                observed orders, transactions or position differ from the
                expectations.
        """
        tempdir = TempDirectory()
        try:
            trade_count = params['trade_count']
            trade_interval = params['trade_interval']
            order_count = params['order_count']
            order_amount = params['order_amount']
            order_interval = params['order_interval']
            expected_txn_count = params['expected_txn_count']
            expected_txn_volume = params['expected_txn_volume']

            # optional parameters
            # ---------------------
            # if present, alternate between long and short sales
            alternate = params.get('alternate')

            # if present, expect transaction amounts to match orders exactly.
            complete_fill = params.get('complete_fill')

            # NOTE(review): this local env receives the bar data and the
            # sid-1 equity metadata, but the Blotter and PerformanceTracker
            # below are built on self.env. Confirm self.env also knows sid 1.
            env = TradingEnvironment()

            sid = 1

            if trade_interval < timedelta(days=1):
                # Minute mode: constant minute bars over enough minutes to
                # cover trade_count * trade_interval, plus 100 extra.
                sim_params = factory.create_simulation_parameters(
                    data_frequency="minute"
                )

                minutes = env.market_minute_window(
                    sim_params.first_open,
                    int((trade_interval.total_seconds() / 60) * trade_count)
                    + 100)

                price_data = np.array([10.1] * len(minutes))
                assets = {
                    sid: pd.DataFrame({
                        "open": price_data,
                        "high": price_data,
                        "low": price_data,
                        "close": price_data,
                        "volume": np.array([100] * len(minutes)),
                        "dt": minutes
                    }).set_index("dt")
                }

                write_bcolz_minute_data(
                    env,
                    env.days_in_range(minutes[0], minutes[-1]),
                    tempdir.path,
                    assets
                )

                equity_minute_reader = BcolzMinuteBarReader(tempdir.path)

                data_portal = DataPortal(
                    env,
                    equity_minute_reader=equity_minute_reader,
                )
            else:
                # Daily mode: constant daily bars for every simulation day.
                sim_params = factory.create_simulation_parameters(
                    data_frequency="daily"
                )

                days = sim_params.trading_days

                assets = {
                    1: pd.DataFrame({
                        "open": [10.1] * len(days),
                        "high": [10.1] * len(days),
                        "low": [10.1] * len(days),
                        "close": [10.1] * len(days),
                        "volume": [100] * len(days),
                        "day": [day.value for day in days]
                    }, index=days)
                }

                path = os.path.join(tempdir.path, "testdata.bcolz")
                DailyBarWriterFromDataFrames(assets).write(
                    path, days, assets)

                equity_daily_reader = BcolzDailyBarReader(path)

                data_portal = DataPortal(
                    env,
                    equity_daily_reader=equity_daily_reader,
                )

            # FixedSlippage unless the caller explicitly asks for the
            # blotter's default (slippage_func=None).
            if "default_slippage" not in params or \
               not params["default_slippage"]:
                slippage_func = FixedSlippage()
            else:
                slippage_func = None

            blotter = Blotter(sim_params.data_frequency, self.env.asset_finder,
                              slippage_func)

            env.write_data(equities_data={
                sid: {
                    "start_date": sim_params.trading_days[0],
                    "end_date": sim_params.trading_days[-1]
                }
            })

            start_date = sim_params.first_open

            # Order i is placed with direction alternator ** i: always +1,
            # or alternating +1/-1 when `alternate` is set.
            if alternate:
                alternator = -1
            else:
                alternator = 1

            tracker = PerformanceTracker(sim_params, self.env)

            # replicate what tradesim does by going through every minute or day
            # of the simulation and processing open orders each time
            if sim_params.data_frequency == "minute":
                ticks = minutes
            else:
                ticks = days

            transactions = []

            order_list = []
            order_date = start_date
            for tick in ticks:
                blotter.current_dt = tick
                # On a tick where an order is placed, the blotter is not
                # asked for transactions. Fills are only collected on the
                # other ticks (the `else` branch).
                if tick >= order_date and len(order_list) < order_count:
                    # place an order
                    direction = alternator ** len(order_list)
                    order_id = blotter.order(
                        blotter.asset_finder.retrieve_asset(sid),
                        order_amount * direction,
                        MarketOrder())
                    order_list.append(blotter.orders[order_id])
                    order_date = order_date + order_interval
                    # move after market orders to just after market next
                    # market open.
                    # NOTE(review): presumably UTC times, where 21:00 is the
                    # close and 14:30 the open of the US session - confirm.
                    if order_date.hour >= 21:
                        # NOTE(review): always true (minutes are never
                        # negative), so only the hour check matters.
                        if order_date.minute >= 00:
                            order_date = order_date + timedelta(days=1)
                            order_date = order_date.replace(hour=14, minute=30)
                else:
                    bar_data = BarData(
                        data_portal,
                        lambda: tick,
                        sim_params.data_frequency
                    )
                    txns, _ = blotter.get_transactions(bar_data)
                    for txn in txns:
                        tracker.process_transaction(txn)
                        transactions.append(txn)

            # Every requested order must have been placed with the expected
            # signed amount. This raises IndexError if the ticks ran out
            # before order_count orders were placed.
            for i in range(order_count):
                order = order_list[i]
                self.assertEqual(order.sid, sid)
                self.assertEqual(order.amount, order_amount * alternator ** i)

            if complete_fill:
                self.assertEqual(len(transactions), len(order_list))

            total_volume = 0
            for i in range(len(transactions)):
                txn = transactions[i]
                total_volume += txn.amount
                if complete_fill:
                    order = order_list[i]
                    self.assertEqual(order.amount, txn.amount)

            self.assertEqual(total_volume, expected_txn_volume)

            self.assertEqual(len(transactions), expected_txn_count)

            # NOTE(review): assumes positions[sid] yields None (not KeyError)
            # when nothing was traded - confirm against PositionTracker.
            cumulative_pos = tracker.position_tracker.positions[sid]
            if total_volume == 0:
                self.assertIsNone(cumulative_pos)
            else:
                self.assertEqual(total_volume, cumulative_pos.amount)

            # the open orders should not contain sid.
            oo = blotter.open_orders
            self.assertNotIn(sid, oo, "Entry is removed when no open orders")
        finally:
            # Always remove the temporary bcolz data, even if an assertion failed.
            tempdir.cleanup()