Beispiel #1
0
    def test_history_grow_length(self,
                                 freq,
                                 field,
                                 data_frequency,
                                 construct_digest):
        """
        Check that a digest panel grows when longer specs are added.

        A container is built from a single spec; two further specs with the
        same frequency and field but larger bar counts are then registered
        through ``ensure_spec``.  After each registration the digest panel
        for that frequency must have a window length of ``bar_count - 1``.

        ``construct_digest`` selects whether the initial spec has two bars
        (a digest panel of window length 1 is then expected right after
        construction) or a single bar.
        """
        bar_count = 2 if construct_digest else 1
        spec = history.HistorySpec(
            bar_count=bar_count,
            frequency=freq,
            field=field,
            ffill=True,
            data_frequency=data_frequency,
        )
        specs = {spec.key_str: spec}
        initial_sids = [1]
        # Plain 24-hour notation: an 'AM' suffix contradicts hour 13, and the
        # sibling tests use the same '13:31' form.
        initial_dt = pd.Timestamp(
            '2013-06-28 13:31'
            if data_frequency == 'minute'
            else '2013-06-28 12:00AM',
            tz='UTC',
        )

        container = HistoryContainer(
            specs, initial_sids, initial_dt, data_frequency,
        )

        if construct_digest:
            self.assertEqual(
                container.digest_panels[spec.frequency].window_length, 1,
            )

        bar_data = BarData()
        container.update(bar_data, initial_dt)

        to_add = (
            history.HistorySpec(
                bar_count=bar_count + 1,
                frequency=freq,
                field=field,
                ffill=True,
                data_frequency=data_frequency,
            ),
            history.HistorySpec(
                bar_count=bar_count + 2,
                frequency=freq,
                field=field,
                ffill=True,
                data_frequency=data_frequency,
            ),
        )

        # A distinct loop name keeps the initial ``spec`` from being shadowed.
        for longer_spec in to_add:
            container.ensure_spec(longer_spec, initial_dt, bar_data)

            self.assertEqual(
                container.digest_panels[longer_spec.frequency].window_length,
                longer_spec.bar_count - 1,
            )

            self.assert_history(container, longer_spec, initial_dt)
Beispiel #2
0
    def test_history_container(self,
                               name,
                               specs,
                               sids,
                               dt,
                               updates,
                               expected):
        """
        Replay ``updates`` into a minute-frequency container and compare
        every spec's history frame with ``expected`` after each update.

        ``expected`` maps ``spec.key_str`` to a sequence holding one frame
        per update.  ``name`` is not used by the body.
        """
        # Sanity check on test input: one expected frame per update.
        update_total = len(updates)
        for spec in specs:
            self.assertEqual(len(expected[spec.key_str]), update_total)

        specs_by_key = dict((spec.key_str, spec) for spec in specs)
        container = HistoryContainer(specs_by_key, sids, dt, 'minute')

        for position, bar_data in enumerate(updates):
            bar_dt = self.bar_data_dt(bar_data)
            container.update(bar_data, bar_dt)

            for spec in specs:
                actual = container.get_history(spec, bar_dt)
                wanted = expected[spec.key_str][position]
                pd.util.testing.assert_frame_equal(
                    actual,
                    wanted,
                    check_dtype=False,
                    check_column_type=True,
                    check_index_type=True,
                    check_frame_type=True,
                )
Beispiel #3
0
    def test_history_add_freq(self, bar_count, pair, field, data_frequency):
        """
        Check that a spec with a new frequency can be added to a live
        container.

        ``pair`` holds two frequencies: the container starts with a spec for
        the first one, and a spec for the second is registered afterwards
        through ``ensure_spec``.
        """
        initial_freq, added_freq = pair
        # Settings shared by the initial spec and the one added later.
        shared = dict(
            field=field,
            ffill=True,
            data_frequency=data_frequency,
            env=self.env,
        )
        spec = history.HistorySpec(
            bar_count=bar_count, frequency=initial_freq, **shared
        )

        if data_frequency == "minute":
            start_text = "2013-06-28 13:31"
        else:
            start_text = "2013-06-28 12:00AM"
        initial_dt = pd.Timestamp(start_text, tz="UTC")

        container = HistoryContainer(
            {spec.key_str: spec},
            [1],
            initial_dt,
            data_frequency,
            env=self.env,
        )

        # Digest panels are only checked for multi-bar specs.
        expects_digest = bar_count > 1
        if expects_digest:
            initial_panel = container.digest_panels[spec.frequency]
            self.assertEqual(initial_panel.window_length, 1)

        bar_data = BarData()
        container.update(bar_data, initial_dt)

        new_spec = history.HistorySpec(
            bar_count, frequency=added_freq, **shared
        )
        container.ensure_spec(new_spec, initial_dt, bar_data)

        if expects_digest:
            added_panel = container.digest_panels[new_spec.frequency]
            self.assertEqual(added_panel.window_length, bar_count - 1)
        else:
            self.assertNotIn(new_spec.frequency, container.digest_panels)

        self.assert_history(container, new_spec, initial_dt)
Beispiel #4
0
    def test_multiple_specs_on_same_bar(self):
        """
        Check that a forward-filling spec and a non-filling spec both
        report the right value when queried on the same bar.
        """
        def utc_minute(eastern_text):
            # Parse a US/Eastern wall-clock string and express it in UTC.
            stamp = pd.Timestamp(eastern_text, tz="US/Eastern")
            return stamp.tz_convert("UTC")

        # Everything except ``ffill`` is identical for the two specs.
        common = dict(
            bar_count=3,
            frequency="1m",
            field="price",
            data_frequency="minute",
            env=self.env,
        )
        filled_spec = history.HistorySpec(ffill=True, **common)
        unfilled_spec = history.HistorySpec(ffill=False, **common)

        specs = {
            filled_spec.key_str: filled_spec,
            unfilled_spec.key_str: unfilled_spec,
        }
        first_bar_dt = utc_minute("2013-06-28 9:31AM")
        container = HistoryContainer(
            specs, [1], first_bar_dt, "minute", env=self.env
        )

        # First bar: no data at all.
        bar_data = BarData()
        container.update(bar_data, first_bar_dt)

        # Second bar: sid 1 trades at 10.
        second_bar_dt = utc_minute("2013-06-28 9:32AM")
        bar_data[1] = {"price": 10, "dt": second_bar_dt}
        container.update(bar_data, second_bar_dt)

        # Third bar: sid 1 has no data again.
        third_bar_dt = utc_minute("2013-06-28 9:33AM")
        del bar_data[1]
        container.update(bar_data, third_bar_dt)

        filled = container.get_history(filled_spec, third_bar_dt)
        unfilled = container.get_history(unfilled_spec, third_bar_dt)
        self.assertEqual(filled.values[-1], 10)
        self.assertTrue(
            np.isnan(unfilled.values[-1]), "Last price should be np.nan"
        )
Beispiel #5
0
    def test_history_add_field(self, bar_count, freq, pair, data_frequency):
        """
        Check that a spec for a new field can be added to a live container.

        ``pair`` holds two field names: the container starts with a spec for
        the first one, and a spec for the second is registered afterwards
        through ``ensure_spec``.
        """
        initial_field, added_field = pair
        # Settings shared by the initial spec and the one added later.
        shared = dict(
            frequency=freq,
            ffill=True,
            data_frequency=data_frequency,
            env=self.env,
        )
        spec = history.HistorySpec(
            bar_count=bar_count, field=initial_field, **shared
        )

        if data_frequency == 'minute':
            start_text = '2013-06-28 13:31'
        else:
            start_text = '2013-06-28 12:00AM'
        initial_dt = pd.Timestamp(start_text, tz='UTC')

        container = HistoryContainer(
            {spec.key_str: spec},
            [1],
            initial_dt,
            data_frequency,
            env=self.env,
        )

        # Digest panels are only checked for multi-bar specs.
        expects_digest = bar_count > 1
        if expects_digest:
            initial_panel = container.digest_panels[spec.frequency]
            self.assertEqual(initial_panel.window_length, 1)

        bar_data = BarData()
        container.update(bar_data, initial_dt)

        new_spec = history.HistorySpec(
            bar_count, field=added_field, **shared
        )
        container.ensure_spec(new_spec, initial_dt, bar_data)

        if expects_digest:
            digest_panel = container.digest_panels[new_spec.frequency]
            self.assertEqual(digest_panel.window_length, bar_count - 1)
            self.assertIn(added_field, digest_panel.items)
        else:
            self.assertNotIn(new_spec.frequency, container.digest_panels)

        # Warnings raised while checking the history are silenced.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            self.assert_history(container, new_spec, initial_dt)
Beispiel #6
0
    def test_container_nans_and_daily_roll(self):
        """
        Exercise a 3-bar daily price spec (forward-filled) that is fed
        minute data across several days.

        The test walks through bars with and without data for sid 1 and
        checks, after each update, the NaN / forward-filled values listed in
        the inline tables below.
        """

        spec = history.HistorySpec(
            bar_count=3,
            frequency='1d',
            field='price',
            ffill=True,
            data_frequency='minute'
        )
        specs = {spec.key_str: spec}
        initial_sids = [1, ]
        initial_dt = pd.Timestamp(
            '2013-06-28 9:31AM', tz='US/Eastern').tz_convert('UTC')

        container = HistoryContainer(
            specs, initial_sids, initial_dt, 'minute'
        )

        bar_data = BarData()
        container.update(bar_data, initial_dt)
        # Since there was no backfill because of no db.
        # And no first bar of data, so all values should be nans.
        prices = container.get_history(spec, initial_dt)
        nan_values = np.isnan(prices[1])
        self.assertTrue(all(nan_values), nan_values)

        # Add data on bar two of first day.
        second_bar_dt = pd.Timestamp(
            '2013-06-28 9:32AM', tz='US/Eastern').tz_convert('UTC')

        bar_data[1] = {
            'price': 10,
            'dt': second_bar_dt
        }
        container.update(bar_data, second_bar_dt)

        prices = container.get_history(spec, second_bar_dt)
        # Prices should be
        #                             1
        # 2013-06-26 20:00:00+00:00 NaN
        # 2013-06-27 20:00:00+00:00 NaN
        # 2013-06-28 13:32:00+00:00  10

        self.assertTrue(np.isnan(prices[1].ix[0]))
        self.assertTrue(np.isnan(prices[1].ix[1]))
        self.assertEqual(prices[1].ix[2], 10)

        third_bar_dt = pd.Timestamp(
            '2013-06-28 9:33AM', tz='US/Eastern').tz_convert('UTC')

        del bar_data[1]

        container.update(bar_data, third_bar_dt)

        prices = container.get_history(spec, third_bar_dt)
        # The one should be forward filled

        # Prices should be
        #                             1
        # 2013-06-26 20:00:00+00:00 NaN
        # 2013-06-27 20:00:00+00:00 NaN
        # 2013-06-28 13:33:00+00:00  10

        # assertEqual rather than the deprecated assertEquals alias.
        self.assertEqual(prices[1][third_bar_dt], 10)

        # Note that we did not fill in data at the close.
        # There was a bug where a nan was being introduced because of the
        # last value of 'raw' data was used, instead of a ffilled close price.

        day_two_first_bar_dt = pd.Timestamp(
            '2013-07-01 9:31AM', tz='US/Eastern').tz_convert('UTC')

        bar_data[1] = {
            'price': 20,
            'dt': day_two_first_bar_dt
        }

        container.update(bar_data, day_two_first_bar_dt)

        prices = container.get_history(spec, day_two_first_bar_dt)

        # Prices Should Be

        #                              1
        # 2013-06-27 20:00:00+00:00  nan
        # 2013-06-28 20:00:00+00:00   10
        # 2013-07-01 13:31:00+00:00   20

        self.assertTrue(np.isnan(prices[1].ix[0]))
        self.assertEqual(prices[1].ix[1], 10)
        self.assertEqual(prices[1].ix[2], 20)

        # Clear out the bar data

        del bar_data[1]

        day_three_first_bar_dt = pd.Timestamp(
            '2013-07-02 9:31AM', tz='US/Eastern').tz_convert('UTC')

        container.update(bar_data, day_three_first_bar_dt)

        prices = container.get_history(spec, day_three_first_bar_dt)

        #                             1
        # 2013-06-28 20:00:00+00:00  10
        # 2013-07-01 20:00:00+00:00  20
        # 2013-07-02 13:31:00+00:00  20

        # These must compare values: assertTrue(value, 10) would treat 10 as
        # the failure message and pass for any truthy price.
        self.assertEqual(prices[1].ix[0], 10)
        self.assertEqual(prices[1].ix[1], 20)
        self.assertEqual(prices[1].ix[2], 20)

        day_four_first_bar_dt = pd.Timestamp(
            '2013-07-03 9:31AM', tz='US/Eastern').tz_convert('UTC')

        container.update(bar_data, day_four_first_bar_dt)

        prices = container.get_history(spec, day_four_first_bar_dt)

        #                             1
        # 2013-07-01 20:00:00+00:00  20
        # 2013-07-02 20:00:00+00:00  20
        # 2013-07-03 13:31:00+00:00  20

        self.assertEqual(prices[1].ix[0], 20)
        self.assertEqual(prices[1].ix[1], 20)
        self.assertEqual(prices[1].ix[2], 20)
Beispiel #7
0
class TradingAlgorithm(object):
    """
    Base class for trading algorithms. Inherit and overload
    initialize() and handle_data(data).

    A new algorithm could look like this:
    ```
    from zipline.api import order

    def initialize(context):
        context.sid = 'AAPL'
        context.amount = 100

    def handle_data(context, data):
        sid = context.sid
        amount = context.amount
        order(sid, amount)
    ```
    To run this algorithm, pass these functions to TradingAlgorithm
    as keyword arguments:

    my_algo = TradingAlgorithm(initialize=initialize,
                               handle_data=handle_data)
    stats = my_algo.run(data)

    """

    # If this is set to False, the constructor does not mark the
    # algorithm as initialized; it is then the responsibility of the
    # overriding subclass to set self.initialized = True.
    AUTO_INITIALIZE = True

    def __init__(self, *args, **kwargs):
        """Initialize sids and other state variables.

        :Arguments:
        :Optional:
            initialize : function
                Function that is called with a single
                argument at the beginning of the simulation.
            handle_data : function
                Function that is called with 2 arguments
                (context and data) on every bar.
            before_trading_start : function
                Optional hook, only read when initialize and handle_data
                are passed as functions.
            script : str
                Algoscript that contains initialize and
                handle_data function definition.
            namespace : dict <default: {}>
                Namespace the script is executed in.
            data_frequency : str (daily, hourly or minutely)
               The duration of the bars.
            capital_base : float <default: DEFAULT_CAPITAL_BASE>
               How much capital to start with.
            instant_fill : bool <default: False>
               Whether to fill orders immediately or on next bar.
            sim_params : <default: None>
               Simulation parameters; created from capital_base when
               omitted.
            blotter : <default: None>
               Blotter to use; a new Blotter is created when omitted.
            environment : str <default: 'zipline'>
               The environment that this algorithm is running in.

        :Raises:
            ValueError
                If script is combined with initialize/handle_data, or if
                the script does not define handle_data.

        Remaining positional and keyword arguments are passed on to
        self.initialize().
        """
        self.datetime = None

        self.registered_transforms = {}
        self.transforms = []
        self.sources = []

        # List of trading controls to be used to validate orders.
        self.trading_controls = []

        self._recorded_vars = {}
        self.namespace = kwargs.get('namespace', {})

        self._environment = kwargs.pop('environment', 'zipline')

        self.logger = None

        self.benchmark_return_source = None

        # default components for transact
        self.slippage = VolumeShareSlippage()
        self.commission = PerShare()

        self.instant_fill = kwargs.pop('instant_fill', False)

        # set the capital base
        self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)

        self.sim_params = kwargs.pop('sim_params', None)
        if self.sim_params is None:
            self.sim_params = create_simulation_parameters(
                capital_base=self.capital_base
            )
        self.perf_tracker = PerformanceTracker(self.sim_params)

        self.blotter = kwargs.pop('blotter', None)
        if not self.blotter:
            self.blotter = Blotter()

        self.portfolio_needs_update = True
        self.account_needs_update = True
        self.performance_needs_update = True
        self._portfolio = None
        self._account = None

        self.history_container = None
        self.history_specs = {}

        # If string is passed in, execute and get reference to
        # functions.
        self.algoscript = kwargs.pop('script', None)

        self._initialize = None
        self._before_trading_start = None
        self._analyze = None

        self.event_manager = EventManager()

        if self.algoscript is not None:
            # The conflict has to be detected here: the function branch
            # below is only reached when no script was given, so a check
            # placed there could never fire.
            if kwargs.get('initialize') or kwargs.get('handle_data'):
                raise ValueError(
                    'You can not set script and initialize/handle_data.'
                )

            # NOTE: the script is executed as-is in self.namespace; only
            # pass trusted code.
            exec_(self.algoscript, self.namespace)
            self._initialize = self.namespace.get('initialize')
            if 'handle_data' not in self.namespace:
                raise ValueError('You must define a handle_data function.')
            else:
                self._handle_data = self.namespace['handle_data']

            self._before_trading_start = \
                self.namespace.get('before_trading_start')
            # Optional analyze function, gets called after run
            self._analyze = self.namespace.get('analyze')

        elif kwargs.get('initialize') and kwargs.get('handle_data'):
            self._initialize = kwargs.pop('initialize')
            self._handle_data = kwargs.pop('handle_data')
            self._before_trading_start = kwargs.pop('before_trading_start',
                                                    None)

        self.event_manager.add_event(
            zipline.utils.events.Event(
                zipline.utils.events.Always(),
                # We pass handle_data.__func__ to get the unbound method.
                # We will explicitly pass the algorithm to bind it again.
                self.handle_data.__func__,
            ),
            prepend=True,
        )

        # If method not defined, NOOP
        if self._initialize is None:
            self._initialize = lambda x: None

        # Alternative way of setting data_frequency for backwards
        # compatibility.
        if 'data_frequency' in kwargs:
            self.data_frequency = kwargs.pop('data_frequency')

        # Subclasses that override initialize should only worry about
        # setting self.initialized = True if AUTO_INITIALIZE is
        # manually set to False.
        self.initialized = False
        self.initialize(*args, **kwargs)
        if self.AUTO_INITIALIZE:
            self.initialized = True

    def initialize(self, *args, **kwargs):
        """
        Run the stored ``_initialize`` callable with this algorithm as its
        only argument, inside a ``ZiplineAPI(self)`` context.

        Positional and keyword arguments are accepted but not forwarded.
        """
        api_context = ZiplineAPI(self)
        with api_context:
            self._initialize(self)

    def before_trading_start(self):
        """
        Invoke the optional ``_before_trading_start`` hook with this
        algorithm as its argument; do nothing when no hook is set.
        """
        if self._before_trading_start is not None:
            self._before_trading_start(self)

    def handle_data(self, data):
        """
        Feed ``data`` to the history container (when one is set and
        truthy), then call the stored ``_handle_data`` callable with this
        algorithm and ``data``.
        """
        container = self.history_container
        if container:
            # The container is updated before user code runs.
            container.update(data, self.datetime)

        self._handle_data(self, data)

    def analyze(self, perf):
        """
        Call the optional ``_analyze`` callable with this algorithm and
        ``perf`` inside a ``ZiplineAPI(self)`` context; do nothing when no
        analyze callable is set.
        """
        if self._analyze is not None:
            with ZiplineAPI(self):
                self._analyze(self, perf)

    def __repr__(self):
        """
        N.B. this does not yet represent a string that can be used
        to instantiate an exact copy of an algorithm.

        However, it is getting close, and provides some value as something
        that can be inspected interactively.
        """
        return """
{class_name}(
    capital_base={capital_base}
    sim_params={sim_params},
    initialized={initialized},
    slippage={slippage},
    commission={commission},
    blotter={blotter},
    recorded_vars={recorded_vars})
""".strip().format(class_name=self.__class__.__name__,
                   capital_base=self.capital_base,
                   sim_params=repr(self.sim_params),
                   initialized=self.initialized,
                   slippage=repr(self.slippage),
                   commission=repr(self.commission),
                   blotter=repr(self.blotter),
                   recorded_vars=repr(self.recorded_vars))

    def _create_data_generator(self, source_filter, sim_params=None):
        """
        Build the merged, date-grouped event stream from the sources and
        transforms attached to this algorithm.

        ::source_filter:: is a callable that receives events in date
        sorted order and returns True for events to process and False for
        events to skip.  A falsy value disables filtering.

        ::sim_params:: defaults to ``self.sim_params``.

        Returns the result of ``groupby`` over the merged events, keyed on
        each event's ``dt`` attribute.
        """
        if sim_params is None:
            sim_params = self.sim_params

        benchmark_return_source = self.benchmark_return_source
        if benchmark_return_source is None:
            # No explicit benchmark source: derive benchmark events from the
            # trading environment's benchmark returns.
            env = trading.environment
            minute_mode = (sim_params.data_frequency == 'minute'
                           or sim_params.emission_rate == 'minute')
            if minute_mode:
                def update_time(date):
                    # Second element of the (open, close) pair.
                    return env.get_open_and_close(date)[1]
            else:
                def update_time(date):
                    return date

            benchmark_return_source = []
            for dt, ret in env.benchmark_returns.iterkv():
                # Keep only returns dated inside the simulation period.
                if not (sim_params.period_start.date()
                        <= dt.date()
                        <= sim_params.period_end.date()):
                    continue
                benchmark_return_source.append(Event({
                    'dt': update_time(dt),
                    'returns': ret,
                    'type': zipline.protocol.DATASOURCE_TYPE.BENCHMARK,
                    'source_id': 'benchmarks',
                }))

        events = date_sorted_sources(*self.sources)
        if source_filter:
            events = filter(source_filter, events)

        transformed = sequential_transforms(events, *self.transforms)
        merged = date_sorted_sources(benchmark_return_source, transformed)

        # Group together events with the same dt field. This depends on the
        # events already being sorted.
        return groupby(merged, attrgetter('dt'))

    def _create_generator(self, sim_params, source_filter=None):
        """
        Assemble the standard simulation generator from the sources and
        transforms attached to this algorithm.

        ::source_filter:: is a callable that receives events in date
        sorted order and returns True for events to process and False for
        events to skip.

        Returns the result of feeding the data generator to the
        ``AlgorithmSimulator`` stored as ``self.trading_client``.
        """
        # HACK: `run` sets perf_tracker to None so that a fresh tracker is
        # created here.
        if self.perf_tracker is None:
            self.perf_tracker = PerformanceTracker(sim_params)

        # Mark portfolio, account and performance state as needing an
        # update.
        for flag_name in ('portfolio_needs_update',
                          'account_needs_update',
                          'performance_needs_update'):
            setattr(self, flag_name, True)

        self.data_gen = self._create_data_generator(source_filter, sim_params)
        self.trading_client = AlgorithmSimulator(self, sim_params)

        self.set_transact(transact_partial(self.slippage, self.commission))

        return self.trading_client.transform(self.data_gen)

    def get_generator(self):
        """
        Return the simulation generator for ``self.sim_params``.

        Subclasses may override this to customise how the generator is
        built; ``_create_generator`` provides the standard construction.
        """
        params = self.sim_params
        return self._create_generator(params)

    # TODO: make a new subclass, e.g. BatchAlgorithm, and move
    # the run method to the subclass, and refactor to put the
    # generator creation logic into get_generator.
    def run(self, source, overwrite_sim_params=True,
            benchmark_return_source=None):
        """Run the algorithm.

        :Arguments:
            source : can be either:
                     - pandas.DataFrame
                     - zipline source
                     - list of sources

               If pandas.DataFrame is provided, it must have the
               following structure:
               * column names must consist of ints representing the
                 different sids
               * index must be DatetimeIndex
               * array contents should be price info.

        :Returns:
            daily_stats : pandas.DataFrame
              Daily performance metrics such as returns, alpha etc.

        """
        if isinstance(source, list):
            if overwrite_sim_params:
                warnings.warn("""List of sources passed, will not attempt to extract sids, and start and end
 dates. Make sure to set the correct fields in sim_params passed to
 __init__().""", UserWarning)
                overwrite_sim_params = False
        elif isinstance(source, pd.DataFrame):
            # if DataFrame provided, wrap in DataFrameSource
            source = DataFrameSource(source)
        elif isinstance(source, pd.Panel):
            source = DataPanelSource(source)

        if isinstance(source, list):
            self.set_sources(source)
        else:
            self.set_sources([source])

        # Override sim_params if params are provided by the source.
        if overwrite_sim_params:
            if hasattr(source, 'start'):
                self.sim_params.period_start = source.start
            if hasattr(source, 'end'):
                self.sim_params.period_end = source.end
            all_sids = [sid for s in self.sources for sid in s.sids]
            self.sim_params.sids = set(all_sids)
            # Changing period_start and period_close might require updating
            # of first_open and last_close.
            self.sim_params._update_internal()

        # Create history containers
        if len(self.history_specs) != 0:
            self.history_container = HistoryContainer(
                self.history_specs,
                self.sim_params.sids,
                self.sim_params.first_open)

        # Create transforms by wrapping them into StatefulTransforms
        self.transforms = []
        for namestring, trans_descr in iteritems(self.registered_transforms):
            sf = StatefulTransform(
                trans_descr['class'],
                *trans_descr['args'],
                **trans_descr['kwargs']
            )
            sf.namestring = namestring

            self.transforms.append(sf)

        # force a reset of the performance tracker, in case
        # this is a repeat run of the algorithm.
        self.perf_tracker = None

        # create transforms and zipline
        self.gen = self._create_generator(self.sim_params)

        with ZiplineAPI(self):
            # loop through simulated_trading, each iteration returns a
            # perf dictionary
            perfs = []
            for perf in self.gen:
                perfs.append(perf)

            # convert perf dict to pandas dataframe
            daily_stats = self._create_daily_stats(perfs)

        self.analyze(daily_stats)

        return daily_stats

    def _create_daily_stats(self, perfs):
        # create daily and cumulative stats dataframe
        daily_perfs = []
        # TODO: the loop here could overwrite expected properties
        # of daily_perf. Could potentially raise or log a
        # warning.
        for perf in perfs:
            if 'daily_perf' in perf:

                perf['daily_perf'].update(
                    perf['daily_perf'].pop('recorded_vars')
                )
                daily_perfs.append(perf['daily_perf'])
            else:
                self.risk_report = perf

        daily_dts = [np.datetime64(perf['period_close'], utc=True)
                     for perf in daily_perfs]
        daily_stats = pd.DataFrame(daily_perfs, index=daily_dts)

        return daily_stats

    def add_transform(self, transform_class, tag, *args, **kwargs):
        """Register a transform under ``tag``.

        :Arguments:
            transform_class : class
                Which transform to use. E.g. mavg.
            tag : str
                Name of the transform; an earlier registration under the
                same tag is replaced.

        Extra args and kwargs are stored alongside the class and are
        forwarded to the transform when it is instantiated.

        """
        descriptor = {
            'class': transform_class,
            'args': args,
            'kwargs': kwargs,
        }
        self.registered_transforms[tag] = descriptor

    @api_method
    def get_environment(self):
        """
        Return the environment name stored by the constructor from its
        ``environment`` keyword ('zipline' when it was not supplied).
        """
        return self._environment

    def add_event(self, rule=None, callback=None):
        """
        Wrap ``rule`` and ``callback`` in an Event and register it with
        the algorithm's EventManager.
        """
        event = zipline.utils.events.Event(rule, callback)
        self.event_manager.add_event(event)

    @api_method
    def schedule_function(self,
                          func,
                          date_rule=None,
                          time_rule=None,
                          half_days=True):
        """
        Schedule ``func`` to be called according to a date rule and a
        time rule.

        A falsy ``date_rule`` falls back to ``DateRuleFactory.every_day()``
        and a falsy ``time_rule`` to ``TimeRuleFactory.market_open()``.

        Raises IncompatibleScheduleFunctionDataFrequency unless the
        simulation's data frequency is 'minute'.
        """
        if self.sim_params.data_frequency != 'minute':
            raise IncompatibleScheduleFunctionDataFrequency()

        if not date_rule:
            date_rule = DateRuleFactory.every_day()
        if not time_rule:
            time_rule = TimeRuleFactory.market_open()

        rule = make_eventrule(date_rule, time_rule, half_days)
        self.add_event(rule, func)

    @api_method
    def record(self, *args, **kwargs):
        """
        Store named values in the algorithm's recorded variables.

        Positional arguments are read as alternating name, value pairs
        (a trailing unpaired argument is dropped); keyword arguments are
        stored under their keyword name.  Keywords are applied after the
        positional pairs.
        """
        recorded = self._recorded_vars

        # Zipping one iterator with itself consumes two items per step, so
        # (a, b, c, d) is read as (a, b), (c, d).
        cursor = iter(args)
        for name, value in zip(cursor, cursor):
            recorded[name] = value

        for name, value in iteritems(kwargs):
            recorded[name] = value

    @api_method
    def symbol(self, symbol_str, as_of_date=None):
        """
        Default symbol lookup for any source that directly maps the
        symbol to the identifier (e.g. yahoo finance).

        Returns ``symbol_str`` unchanged; the keyword argument
        ``as_of_date`` is accepted but ignored.
        """
        return symbol_str

    @api_method
    def order(self, sid, amount,
              limit_price=None,
              stop_price=None,
              style=None):
        """
        Place an order for ``amount`` shares of ``sid``.

        ``amount`` is converted to a whole share count first.  The
        parameters are then validated, ``limit_price`` / ``stop_price``
        are turned into an execution style, and the order is handed to
        the blotter, whose result is returned.
        """
        # Snap to the nearest integer when within 1e-4 of it, otherwise
        # truncate toward zero.
        # E.g. 3.9999 -> 4.0; 5.5 -> 5.0; -5.5 -> -5.0
        nearest = round(amount)
        if abs(amount - nearest) <= 1e-4:
            amount = nearest
        amount = int(amount)

        # Raises a ZiplineError if invalid parameters are detected.
        self.validate_order_params(
            sid, amount, limit_price, stop_price, style,
        )

        # Convert deprecated limit_price and stop_price parameters to use
        # ExecutionStyle objects.
        execution_style = self.__convert_order_params_for_blotter(
            limit_price, stop_price, style,
        )
        return self.blotter.order(sid, amount, execution_style)

    def validate_order_params(self,
                              sid,
                              amount,
                              limit_price,
                              stop_price,
                              style):
        """
        Validate the parameters passed to the order API function.

        Raises OrderDuringInitialize when the algorithm is not yet
        initialized, and UnsupportedOrderParameters when ``style`` is
        combined with ``limit_price`` or ``stop_price``.  Afterwards every
        registered trading control is asked to validate the order.
        """
        if not self.initialized:
            raise OrderDuringInitialize(
                msg="order() can only be called from within handle_data()"
            )

        if style and limit_price:
            raise UnsupportedOrderParameters(
                msg="Passing both limit_price and style is not supported."
            )

        if style and stop_price:
            raise UnsupportedOrderParameters(
                msg="Passing both stop_price and style is not supported."
            )

        for control in self.trading_controls:
            # Portfolio, datetime and current data are re-read for each
            # control.
            control.validate(
                sid,
                amount,
                self.updated_portfolio(),
                self.get_datetime(),
                self.trading_client.current_data,
            )

    @staticmethod
    def __convert_order_params_for_blotter(limit_price, stop_price, style):
        """
        Turn the deprecated limit_price / stop_price arguments into an
        ExecutionStyle instance.

        A truthy ``style`` is returned as-is; in that case both prices
        are asserted to be None.  Otherwise the result is a
        StopLimitOrder, LimitOrder, StopOrder or MarketOrder depending on
        which prices are truthy.
        """
        # TODO_SS: DeprecationWarning for usage of limit_price and stop_price.
        if style:
            assert (limit_price, stop_price) == (None, None)
            return style

        has_limit = bool(limit_price)
        has_stop = bool(stop_price)

        if has_limit and has_stop:
            return StopLimitOrder(limit_price, stop_price)
        if has_limit:
            return LimitOrder(limit_price)
        if has_stop:
            return StopOrder(stop_price)
        return MarketOrder()

    @api_method
    def order_value(self, sid, value,
                    limit_price=None, stop_price=None, style=None):
        """
        Place an order by desired value rather than desired number of shares.

        The sid's last price is read from the trading client's current data
        and ``value`` is divided by it to obtain the number of shares passed
        on to ``order``. When that price is (close to) zero no order is
        placed, a debug message is logged if a logger is set, and None is
        returned.

        value > 0 :: Buy/Cover
        value < 0 :: Sell/Short
        Market order:    order_value(sid, value)
        Limit order:     order_value(sid, value, limit_price)
        Stop order:      order_value(sid, value, None, stop_price)
        StopLimit order: order_value(sid, value, limit_price, stop_price)
        """
        price = self.trading_client.current_data[sid].price

        if not np.allclose(price, 0):
            shares = value / price
            return self.order(sid, shares,
                              limit_price=limit_price,
                              stop_price=stop_price,
                              style=style)

        # A zero price gives no way to turn a value into a share count,
        # so report it (if possible) and skip the order.
        message = "Price of 0 for {psid}; can't infer value".format(
            psid=sid
        )
        if self.logger:
            self.logger.debug(message)
        return None

    @property
    def recorded_vars(self):
        """Return a shallow copy of the recorded variables mapping."""
        snapshot = copy(self._recorded_vars)
        return snapshot

    @property
    def portfolio(self):
        """Return the portfolio, delegating to ``updated_portfolio``."""
        current = self.updated_portfolio()
        return current

    def updated_portfolio(self):
        """
        Return the cached portfolio, refreshing it from the performance
        tracker first when ``portfolio_needs_update`` is set.

        A refresh passes ``performance_needs_update`` to the tracker and
        then clears both the portfolio and the performance update flags.
        """
        if not self.portfolio_needs_update:
            return self._portfolio

        refresh_performance = self.performance_needs_update
        self._portfolio = self.perf_tracker.get_portfolio(refresh_performance)
        self.portfolio_needs_update = False
        self.performance_needs_update = False
        return self._portfolio

    @property
    def account(self):
        """Return the account, delegating to ``updated_account``."""
        current = self.updated_account()
        return current

    def updated_account(self):
        """
        Return the cached account, refreshing it from the performance
        tracker first when ``account_needs_update`` is set.

        A refresh passes ``performance_needs_update`` to the tracker and
        then clears both the account and the performance update flags.
        """
        if not self.account_needs_update:
            return self._account

        refresh_performance = self.performance_needs_update
        self._account = self.perf_tracker.get_account(refresh_performance)
        self.account_needs_update = False
        self.performance_needs_update = False
        return self._account

    def set_logger(self, logger):
        """Store ``logger`` as the algorithm's logger."""
        self.logger = logger

    def on_dt_changed(self, dt):
        """
        Callback triggered by the simulation loop whenever the current dt
        changes.

        Records ``dt`` as the algorithm's datetime and forwards it to the
        performance tracker and the blotter. ``dt`` is asserted to be a
        datetime whose tzinfo equals pytz.utc.

        Any logic that should happen exactly once at the start of each
        datetime group should happen here.
        """
        assert isinstance(dt, datetime), (
            "Attempt to set algorithm's current time with non-datetime")
        assert dt.tzinfo == pytz.utc, "Algorithm expects a utc datetime"

        self.datetime = dt

        # Keep the components that track time in step with the algorithm.
        self.perf_tracker.set_date(dt)
        self.blotter.set_date(dt)

    @api_method
    def get_datetime(self):
        """
        Return a copy of the algorithm's current datetime.

        The copy is asserted to carry pytz.utc as its tzinfo.
        """
        snapshot = copy(self.datetime)
        assert snapshot.tzinfo == pytz.utc, (
            "Algorithm should have a utc datetime")
        return snapshot

    def set_transact(self, transact):
        """
        Install ``transact`` on the blotter as the callable used to create
        a transaction from open orders and trade events.
        """
        blotter = self.blotter
        blotter.transact = transact

    def update_dividends(self, dividend_frame):
        """
        Hand the DataFrame used to process dividends to the performance
        tracker. Its columns should contain at least the entries in
        zp.DIVIDEND_FIELDS.
        """
        tracker = self.perf_tracker
        tracker.update_dividends(dividend_frame)

    @api_method
    def set_slippage(self, slippage):
        """
        Install ``slippage`` as the algorithm's slippage model.

        Raises UnsupportedSlippageModel when ``slippage`` is not a
        SlippageModel instance, and OverrideSlippagePostInit when the
        algorithm is already flagged as initialized.
        """
        is_supported = isinstance(slippage, SlippageModel)
        if not is_supported:
            raise UnsupportedSlippageModel()

        if self.initialized:
            raise OverrideSlippagePostInit()

        self.slippage = slippage

    @api_method
    def set_commission(self, commission):
        """
        Install ``commission`` as the algorithm's commission model.

        Raises UnsupportedCommissionModel when ``commission`` is not a
        PerShare, PerTrade or PerDollar instance, and
        OverrideCommissionPostInit when the algorithm is already flagged as
        initialized.
        """
        supported_models = (PerShare, PerTrade, PerDollar)
        if not isinstance(commission, supported_models):
            raise UnsupportedCommissionModel()

        if self.initialized:
            raise OverrideCommissionPostInit()

        self.commission = commission

    def set_sources(self, sources):
        """Store ``sources``, which is asserted to be a list."""
        is_list = isinstance(sources, list)
        assert is_list
        self.sources = sources

    def set_transforms(self, transforms):
        """Store ``transforms``, which is asserted to be a list."""
        is_list = isinstance(transforms, list)
        assert is_list
        self.transforms = transforms

    # Retained for backwards compatibility
    @property
    def data_frequency(self):
        """Return the data frequency stored on ``sim_params``."""
        params = self.sim_params
        return params.data_frequency

    @data_frequency.setter
    def data_frequency(self, value):
        """
        Store ``value`` as the data frequency on ``sim_params``.

        ``value`` is asserted to be either 'daily' or 'minute'.
        """
        allowed = ('daily', 'minute')
        assert value in allowed
        self.sim_params.data_frequency = value

    @api_method
    def order_percent(self, sid, percent,
                      limit_price=None, stop_price=None, style=None):
        """
        Place an order in the specified security corresponding to the given
        percent of the current portfolio value.

        The portfolio value is multiplied by ``percent`` and the result is
        passed on to ``order_value``.

        Note that percent must be expressed as a decimal (0.50 means 50%).
        """
        portfolio_value = self.portfolio.portfolio_value
        order_kwargs = dict(limit_price=limit_price,
                            stop_price=stop_price,
                            style=style)
        return self.order_value(sid, portfolio_value * percent,
                                **order_kwargs)

    @api_method
    def order_target(self, sid, target,
                     limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target number of shares.

        When the portfolio holds no position for ``sid`` the full target is
        ordered. Otherwise the order is for the difference between the
        target number of shares and the amount currently held.
        """
        if sid in self.portfolio.positions:
            held = self.portfolio.positions[sid].amount
            shares_to_order = target - held
        else:
            shares_to_order = target

        return self.order(sid, shares_to_order,
                          limit_price=limit_price,
                          stop_price=stop_price,
                          style=style)

    @api_method
    def order_target_value(self, sid, target,
                           limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target value.

        The sid's last price is read from the trading client's current data
        and ``target`` is divided by it to obtain the share count handed to
        ``order_target``. When that price is (close to) zero no order is
        placed, a debug message is logged if a logger is set, and None is
        returned.
        """
        price = self.trading_client.current_data[sid].price

        if not np.allclose(price, 0):
            return self.order_target(sid, target / price,
                                     limit_price=limit_price,
                                     stop_price=stop_price,
                                     style=style)

        # Without a usable price the target value cannot be converted into
        # shares, so no order is placed.
        if self.logger:
            template = "Price of 0 for {psid}; can't infer value"
            self.logger.debug(template.format(psid=sid))
        return None

    @api_method
    def order_target_percent(self, sid, target,
                             limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target percent of the
        current portfolio value.

        The portfolio value is multiplied by ``target`` and the result is
        passed on to ``order_target_value``.

        Note that target must be expressed as a decimal (0.50 means 50%).
        """
        portfolio_value = self.portfolio.portfolio_value
        order_kwargs = dict(limit_price=limit_price,
                            stop_price=stop_price,
                            style=style)
        return self.order_target_value(sid, portfolio_value * target,
                                       **order_kwargs)

    @api_method
    def get_open_orders(self, sid=None):
        """
        Return the blotter's open orders converted with ``to_api_obj``.

        With a ``sid`` the result is the list of open orders for that sid
        (an empty list when the blotter has no entry for it). Without one
        the result is a dict keyed like the blotter's open orders, leaving
        out keys whose order collection is empty.
        """
        if sid is not None:
            if sid not in self.blotter.open_orders:
                return []
            pending = self.blotter.open_orders[sid]
            return [entry.to_api_obj() for entry in pending]

        return {
            key: [entry.to_api_obj() for entry in pending]
            for key, pending in iteritems(self.blotter.open_orders)
            if pending
        }

    @api_method
    def get_order(self, order_id):
        """
        Return the API object for the blotter order stored under
        ``order_id``, or None when the blotter has no such order.
        """
        known_orders = self.blotter.orders
        if order_id not in known_orders:
            return None
        return known_orders[order_id].to_api_obj()

    @api_method
    def cancel_order(self, order_param):
        """
        Cancel an order through the blotter.

        ``order_param`` may be a zipline.protocol.Order, whose ``id`` is
        used, or any other value, which is passed to the blotter as the
        order id unchanged.
        """
        if isinstance(order_param, zipline.protocol.Order):
            order_id = order_param.id
        else:
            order_id = order_param

        self.blotter.cancel(order_id)

    @api_method
    def add_history(self, bar_count, frequency, field,
                    ffill=True):
        """
        Register a HistorySpec for the given bar count, frequency, field and
        ffill setting.

        The spec takes its data frequency from ``sim_params`` and sets
        ``daily_at_midnight`` when that frequency is 'daily'. It is stored
        in ``history_specs`` under the spec's ``key_str``.
        """
        sim_frequency = self.sim_params.data_frequency

        spec = HistorySpec(bar_count, frequency, field, ffill,
                           daily_at_midnight=(sim_frequency == 'daily'),
                           data_frequency=sim_frequency)
        self.history_specs[spec.key_str] = spec

    @api_method
    def history(self, bar_count, frequency, field, ffill=True):
        """
        Return the history container's data for a registered spec.

        The spec is looked up in ``history_specs`` by the key that
        HistorySpec.spec_key builds from the arguments (a KeyError results
        when no such spec was stored), and is queried at the algorithm's
        current datetime.
        """
        key = HistorySpec.spec_key(bar_count, frequency, field, ffill)
        spec = self.history_specs[key]
        return self.history_container.get_history(spec, self.datetime)

    ####################
    # Trading Controls #
    ####################

    def register_trading_control(self, control):
        """
        Append ``control`` to the trading controls checked prior to order
        calls.

        Raises RegisterTradingControlPostInit when the algorithm is already
        flagged as initialized.
        """
        if not self.initialized:
            self.trading_controls.append(control)
            return

        raise RegisterTradingControlPostInit()

    @api_method
    def set_max_position_size(self,
                              sid=None,
                              max_shares=None,
                              max_notional=None):
        """
        Set a limit on the number of shares and/or dollar value held for the
        given sid, by registering a MaxPositionSize trading control.

        Limits are treated as absolute values and are enforced at the time
        that the algo attempts to place an order for sid. This means that
        it's possible to end up with more than the max number of shares due
        to splits/dividends, and more than the max notional due to price
        improvement.

        If an algorithm attempts to place an order that would result in
        increasing the absolute value of shares/dollar value exceeding one
        of these limits, raise a TradingControlException.
        """
        limits = dict(sid=sid,
                      max_shares=max_shares,
                      max_notional=max_notional)
        self.register_trading_control(MaxPositionSize(**limits))

    @api_method
    def set_max_order_size(self, sid=None, max_shares=None, max_notional=None):
        """
        Set a limit on the number of shares and/or dollar value of any single
        order placed for sid, by registering a MaxOrderSize trading control.

        Limits are treated as absolute values and are enforced at the time
        that the algo attempts to place an order for sid.

        If an algorithm attempts to place an order that would result in
        exceeding one of these limits, raise a TradingControlException.
        """
        limits = dict(sid=sid,
                      max_shares=max_shares,
                      max_notional=max_notional)
        self.register_trading_control(MaxOrderSize(**limits))

    @api_method
    def set_max_order_count(self, max_count):
        """
        Set a limit on the number of orders that can be placed, by
        registering a MaxOrderCount trading control built from ``max_count``.

        NOTE(review): the time interval over which orders are counted is
        decided by MaxOrderCount, which is not shown here -- confirm.
        """
        self.register_trading_control(MaxOrderCount(max_count))

    @api_method
    def set_long_only(self):
        """
        Set a rule specifying that this algorithm cannot take short
        positions, by registering a LongOnly trading control.
        """
        long_only_rule = LongOnly()
        self.register_trading_control(long_only_rule)

    @classmethod
    def all_api_methods(cls):
        """
        Return a list of all the TradingAlgorithm API methods.

        These are the values in the class's own ``__dict__`` that carry a
        truthy ``is_api_method`` attribute; attributes inherited from base
        classes are not inspected.
        """
        # ``values()`` exists on the class dict proxy in both Python 2 and
        # Python 3, whereas ``itervalues()`` is Python 2 only.
        return [fn for fn in cls.__dict__.values()
                if getattr(fn, 'is_api_method', False)]
Beispiel #8
0
    def run(self, source, overwrite_sim_params=True,
            benchmark_return_source=None):
        """Run the algorithm.

        :Arguments:
            source : can be either:
                     - pandas.DataFrame
                     - zipline source
                     - list of sources

               If pandas.DataFrame is provided, it must have the
               following structure:
               * column names must consist of ints representing the
                 different sids
               * index must be DatetimeIndex
               * array contents should be price info.

        :Returns:
            daily_stats : pandas.DataFrame
              Daily performance metrics such as returns, alpha etc.

        """
        if isinstance(source, list):
            if overwrite_sim_params:
                warnings.warn("""List of sources passed, will not attempt to extract sids, and start and end
 dates. Make sure to set the correct fields in sim_params passed to
 __init__().""", UserWarning)
                overwrite_sim_params = False
        elif isinstance(source, pd.DataFrame):
            # if DataFrame provided, wrap in DataFrameSource
            source = DataFrameSource(source)
        elif isinstance(source, pd.Panel):
            source = DataPanelSource(source)

        if isinstance(source, list):
            self.set_sources(source)
        else:
            self.set_sources([source])

        # Override sim_params if params are provided by the source.
        if overwrite_sim_params:
            if hasattr(source, 'start'):
                self.sim_params.period_start = source.start
            if hasattr(source, 'end'):
                self.sim_params.period_end = source.end
            all_sids = [sid for s in self.sources for sid in s.sids]
            self.sim_params.sids = set(all_sids)
            # Changing period_start and period_close might require updating
            # of first_open and last_close.
            self.sim_params._update_internal()

        # Create history containers
        if len(self.history_specs) != 0:
            self.history_container = HistoryContainer(
                self.history_specs,
                self.sim_params.sids,
                self.sim_params.first_open)

        # Create transforms by wrapping them into StatefulTransforms
        self.transforms = []
        for namestring, trans_descr in iteritems(self.registered_transforms):
            sf = StatefulTransform(
                trans_descr['class'],
                *trans_descr['args'],
                **trans_descr['kwargs']
            )
            sf.namestring = namestring

            self.transforms.append(sf)

        # force a reset of the performance tracker, in case
        # this is a repeat run of the algorithm.
        self.perf_tracker = None

        # create transforms and zipline
        self.gen = self._create_generator(self.sim_params)

        with ZiplineAPI(self):
            # loop through simulated_trading, each iteration returns a
            # perf dictionary
            perfs = []
            for perf in self.gen:
                perfs.append(perf)

            # convert perf dict to pandas dataframe
            daily_stats = self._create_daily_stats(perfs)

        self.analyze(daily_stats)

        return daily_stats
Beispiel #9
0
    def test_container_nans_and_daily_roll(self):
        """
        Check a 3-bar daily 'price' history built from minute data with
        forward filling, as the container is updated across several days.

        Covers the all-NaN window before any bar arrives, the first price,
        forward filling within a day, and the roll of the daily window on
        the following days, including days without any new bar data.
        """

        spec = history.HistorySpec(
            bar_count=3,
            frequency='1d',
            field='price',
            ffill=True,
            data_frequency='minute'
        )
        specs = {spec.key_str: spec}
        initial_sids = [1, ]
        initial_dt = pd.Timestamp(
            '2013-06-28 9:31AM', tz='US/Eastern').tz_convert('UTC')

        container = HistoryContainer(
            specs, initial_sids, initial_dt, 'minute'
        )

        bar_data = BarData()
        container.update(bar_data, initial_dt)
        # Since there was no backfill because of no db.
        # And no first bar of data, so all values should be nans.
        prices = container.get_history(spec, initial_dt)
        nan_values = np.isnan(prices[1])
        self.assertTrue(all(nan_values), nan_values)

        # Add data on bar two of first day.
        second_bar_dt = pd.Timestamp(
            '2013-06-28 9:32AM', tz='US/Eastern').tz_convert('UTC')

        bar_data[1] = {
            'price': 10,
            'dt': second_bar_dt
        }
        container.update(bar_data, second_bar_dt)

        prices = container.get_history(spec, second_bar_dt)
        # Prices should be
        #                             1
        # 2013-06-26 20:00:00+00:00 NaN
        # 2013-06-27 20:00:00+00:00 NaN
        # 2013-06-28 13:32:00+00:00  10

        self.assertTrue(np.isnan(prices[1].ix[0]))
        self.assertTrue(np.isnan(prices[1].ix[1]))
        self.assertEqual(prices[1].ix[2], 10)

        third_bar_dt = pd.Timestamp(
            '2013-06-28 9:33AM', tz='US/Eastern').tz_convert('UTC')

        del bar_data[1]

        container.update(bar_data, third_bar_dt)

        prices = container.get_history(spec, third_bar_dt)
        # The one should be forward filled

        # Prices should be
        #                             1
        # 2013-06-26 20:00:00+00:00 NaN
        # 2013-06-27 20:00:00+00:00 NaN
        # 2013-06-28 13:33:00+00:00  10

        self.assertEqual(prices[1][third_bar_dt], 10)

        # Note that we did not fill in data at the close.
        # There was a bug where a nan was being introduced because of the
        # last value of 'raw' data was used, instead of a ffilled close price.

        day_two_first_bar_dt = pd.Timestamp(
            '2013-07-01 9:31AM', tz='US/Eastern').tz_convert('UTC')

        bar_data[1] = {
            'price': 20,
            'dt': day_two_first_bar_dt
        }

        container.update(bar_data, day_two_first_bar_dt)

        prices = container.get_history(spec, day_two_first_bar_dt)

        # Prices Should Be

        #                              1
        # 2013-06-27 20:00:00+00:00  nan
        # 2013-06-28 20:00:00+00:00   10
        # 2013-07-01 13:31:00+00:00   20

        self.assertTrue(np.isnan(prices[1].ix[0]))
        self.assertEqual(prices[1].ix[1], 10)
        self.assertEqual(prices[1].ix[2], 20)

        # Clear out the bar data

        del bar_data[1]

        day_three_first_bar_dt = pd.Timestamp(
            '2013-07-02 9:31AM', tz='US/Eastern').tz_convert('UTC')

        container.update(bar_data, day_three_first_bar_dt)

        prices = container.get_history(spec, day_three_first_bar_dt)

        #                             1
        # 2013-06-28 20:00:00+00:00  10
        # 2013-07-01 20:00:00+00:00  20
        # 2013-07-02 13:31:00+00:00  20

        # assertEqual, not assertTrue: assertTrue would take the expected
        # value as its failure message and pass for any truthy price.
        self.assertEqual(prices[1].ix[0], 10)
        self.assertEqual(prices[1].ix[1], 20)
        self.assertEqual(prices[1].ix[2], 20)

        day_four_first_bar_dt = pd.Timestamp(
            '2013-07-03 9:31AM', tz='US/Eastern').tz_convert('UTC')

        container.update(bar_data, day_four_first_bar_dt)

        prices = container.get_history(spec, day_four_first_bar_dt)

        #                             1
        # 2013-07-01 20:00:00+00:00  20
        # 2013-07-02 20:00:00+00:00  20
        # 2013-07-03 13:31:00+00:00  20

        self.assertEqual(prices[1].ix[0], 20)
        self.assertEqual(prices[1].ix[1], 20)
        self.assertEqual(prices[1].ix[2], 20)
Beispiel #10
0
class TradingAlgorithm(object):

    """
    Base class for trading algorithms. Inherit and overload
    initialize() and handle_data(data).

    A new algorithm could look like this:
    ```
    from zipline.api import order

    def initialize(context):
        context.sid = 'AAPL'
        context.amount = 100

    def handle_data(context, data):
        sid = context.sid
        amount = context.amount
        order(sid, amount)
    ```
    To run this algorithm, pass these functions to
    TradingAlgorithm:

    my_algo = TradingAlgorithm(initialize, handle_data)
    stats = my_algo.run(data)

    """

    def __init__(self, *args, **kwargs):
        """Initialize sids and other state variables.

        :Arguments:
        :Optional:
            initialize : function
                Function that is called with a single
                argument at the beginning of the simulation.
            handle_data : function
                Function that is called with 2 arguments
                (context and data) on every bar.
            script : str
                Algoscript that contains initialize and
                handle_data function definition.
            namespace : dict <default: {}>
               Namespace the algoscript is executed in.
            data_frequency : str (daily, hourly or minutely)
               The duration of the bars.
            annualizer : int <optional>
               Which constant to use for annualizing risk metrics.
               If not provided, will extract from data_frequency.
            capital_base : float <default: 1.0e5>
               How much capital to start with.
            instant_fill : bool <default: False>
               Whether to fill orders immediately or on next bar.
            sim_params : <optional>
               Simulation parameters; when given, a PerformanceTracker
               is created from them right away.
            blotter : <optional>
               Blotter to use; a new Blotter is created when omitted.

        :Raises:
            ValueError
               If ``script`` is combined with ``initialize`` and
               ``handle_data`` functions, or if ``script`` does not
               define a ``handle_data`` function.
        """
        self.datetime = None

        self.registered_transforms = {}
        self.transforms = []
        self.sources = []

        self._recorded_vars = {}
        self.namespace = kwargs.get('namespace', {})

        self.logger = None

        self.benchmark_return_source = None
        self.perf_tracker = None

        # default components for transact
        self.slippage = VolumeShareSlippage()
        self.commission = PerShare()

        if 'data_frequency' in kwargs:
            self.set_data_frequency(kwargs.pop('data_frequency'))
        else:
            self.data_frequency = None

        self.instant_fill = kwargs.pop('instant_fill', False)

        # Override annualizer if set
        if 'annualizer' in kwargs:
            self.annualizer = kwargs['annualizer']

        # set the capital base
        self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)

        self.sim_params = kwargs.pop('sim_params', None)
        if self.sim_params:
            # An explicit data_frequency wins over the one in sim_params;
            # otherwise the algorithm adopts the sim_params value.
            if self.data_frequency is None:
                self.data_frequency = self.sim_params.data_frequency
            else:
                self.sim_params.data_frequency = self.data_frequency

            self.perf_tracker = PerformanceTracker(self.sim_params)

        self.blotter = kwargs.pop('blotter', None)
        if not self.blotter:
            self.blotter = Blotter()

        self.portfolio_needs_update = True
        self._portfolio = None

        self.history_container = None
        self.history_specs = {}

        # If string is passed in, execute and get reference to
        # functions.
        self.algoscript = kwargs.pop('script', None)

        self._initialize = None
        self._analyze = None

        functions_given = (kwargs.get('initialize', False) and
                           kwargs.get('handle_data'))

        if self.algoscript is not None:
            # The conflict has to be detected here: once the script branch
            # is taken, the function branch below is never reached.
            if functions_given:
                raise ValueError('You can not set script and '
                                 'initialize/handle_data.')

            # NOTE: this executes the algoscript as Python code; only pass
            # scripts from a trusted origin.
            exec_(self.algoscript, self.namespace)
            self._initialize = self.namespace.get('initialize', None)
            if 'handle_data' not in self.namespace:
                raise ValueError('You must define a handle_data function.')
            else:
                self._handle_data = self.namespace['handle_data']

            # Optional analyze function, gets called after run
            self._analyze = self.namespace.get('analyze', None)

        elif functions_given:
            self._initialize = kwargs.pop('initialize')
            self._handle_data = kwargs.pop('handle_data')

        # If method not defined, NOOP
        if self._initialize is None:
            self._initialize = lambda x: None

        # an algorithm subclass needs to set initialized to True when
        # it is fully initialized.
        self.initialized = False
        self.initialize(*args, **kwargs)

    def initialize(self, *args, **kwargs):
        """
        Invoke the stored ``_initialize`` callable with this algorithm as
        its only argument, inside a ZiplineAPI context for this algorithm.

        The positional and keyword arguments are accepted but not
        forwarded.
        """
        api_context = ZiplineAPI(self)
        with api_context:
            self._initialize(self)

    def handle_data(self, data):
        """
        Process one bar: update the history container (when there is one)
        with ``data`` at the current datetime, then call the stored
        ``_handle_data`` callable with this algorithm and ``data``.
        """
        container = self.history_container
        if container:
            container.update(data, self.datetime)

        self._handle_data(self, data)

    def analyze(self, perf):
        """
        Call the optional ``_analyze`` callable with this algorithm and
        ``perf``, inside a ZiplineAPI context for this algorithm.

        Does nothing when no analyze callable is set.
        """
        if self._analyze is not None:
            with ZiplineAPI(self):
                self._analyze(self, perf)

    def __repr__(self):
        """
        N.B. this does not yet represent a string that can be used
        to instantiate an exact copy of an algorithm.

        However, it is getting close, and provides some value as something
        that can be inspected interactively.
        """
        return """
{class_name}(
    capital_base={capital_base}
    sim_params={sim_params},
    initialized={initialized},
    slippage={slippage},
    commission={commission},
    blotter={blotter},
    recorded_vars={recorded_vars})
""".strip().format(class_name=self.__class__.__name__,
                   capital_base=self.capital_base,
                   sim_params=repr(self.sim_params),
                   initialized=self.initialized,
                   slippage=repr(self.slippage),
                   commission=repr(self.commission),
                   blotter=repr(self.blotter),
                   recorded_vars=repr(self.recorded_vars))

    def _create_data_generator(self, source_filter, sim_params):
        """
        Create a merged data generator using the sources and
        transforms attached to this algorithm.

        ::source_filter:: is a method that receives events in date
        sorted order, and returns True for those events that should be
        processed by the zipline, and False for those that should be
        skipped.

        ::sim_params:: supplies the emission rate and the period start/end
        used to build the benchmark events when the algorithm has no
        ``benchmark_return_source`` of its own.

        Returns the merged events grouped by their ``dt`` attribute.
        """
        if self.benchmark_return_source is None:
            env = trading.environment
            if (self.data_frequency == 'minute'
                    or sim_params.emission_rate == 'minute'):
                # Minute-level runs stamp each benchmark return with the
                # second element of get_open_and_close (presumably the
                # market close).
                update_time = lambda date: env.get_open_and_close(date)[1]
            else:
                update_time = lambda date: date

            # The period bounds do not change while iterating.
            first_day = sim_params.period_start.date()
            last_day = sim_params.period_end.date()

            # six's iteritems replaces ``iterkv()``, which is only a
            # deprecated pandas alias for the same (key, value) iteration.
            benchmark_return_source = [
                Event({'dt': update_time(dt),
                       'returns': ret,
                       'type': zipline.protocol.DATASOURCE_TYPE.BENCHMARK,
                       'source_id': 'benchmarks'})
                for dt, ret in iteritems(env.benchmark_returns)
                if first_day <= dt.date() <= last_day
            ]
        else:
            benchmark_return_source = self.benchmark_return_source

        date_sorted = date_sorted_sources(*self.sources)

        if source_filter:
            date_sorted = filter(source_filter, date_sorted)

        with_tnfms = sequential_transforms(date_sorted,
                                           *self.transforms)

        with_benchmarks = date_sorted_sources(benchmark_return_source,
                                              with_tnfms)

        # Group together events with the same dt field. This depends on the
        # events already being sorted.
        return groupby(with_benchmarks, attrgetter('dt'))

    def _create_generator(self, sim_params, source_filter=None):
        """
        Create a basic generator setup using the sources and
        transforms attached to this algorithm.

        ::source_filter:: is a method that receives events in date
        sorted order, and returns True for those events that should be
        processed by the zipline, and False for those that should be
        skipped.

        Side effects: copies the algorithm's data frequency onto
        ``sim_params``, creates the performance tracker if there is none,
        and stores the data generator, the trading client and the transact
        method on the algorithm.
        """
        sim_params.data_frequency = self.data_frequency

        # A tracker already exists when sim_params was passed to the
        # constructor; otherwise it is built here.
        if self.perf_tracker is None:
            self.perf_tracker = PerformanceTracker(sim_params)

        self.data_gen = self._create_data_generator(source_filter,
                                                    sim_params)

        self.trading_client = AlgorithmSimulator(self, sim_params)

        self.set_transact(transact_partial(self.slippage, self.commission))

        return self.trading_client.transform(self.data_gen)

    def get_generator(self):
        """
        Return the generator built by ``_create_generator`` from the
        algorithm's own ``sim_params``.

        Override this method to add new logic to the construction of the
        generator; overrides can still call ``_create_generator`` to obtain
        the standard construction.
        """
        params = self.sim_params
        return self._create_generator(params)

    # TODO: make a new subclass, e.g. BatchAlgorithm, and move
    # the run method to the subclass, and refactor to put the
    # generator creation logic into get_generator.
    def run(self, source, sim_params=None, benchmark_return_source=None):
        """Run the algorithm.

        :Arguments:
            source : can be either:
                     - pandas.DataFrame
                     - pandas.Panel
                     - zipline source
                     - list (or tuple) of zipline sources

               If pandas.DataFrame is provided, it must have the
               following structure:
               * column names must consist of ints representing the
                 different sids
               * index must be DatetimeIndex
               * array contents should be price info.
            sim_params : <optional>
               Simulation parameters for this run. Falls back to the
               parameters stored on the algorithm and, failing that, to
               parameters created from the source's start and end.
            benchmark_return_source :
               Accepted but not referenced in this method body.

        :Returns:
            daily_stats : pandas.DataFrame
              Daily performance metrics such as returns, alpha etc.

        """
        if isinstance(source, (list, tuple)):
            assert self.sim_params is not None or sim_params is not None, \
                """When providing a list of sources, \
                sim_params have to be specified as a parameter
                or in the constructor."""
        elif isinstance(source, pd.DataFrame):
            # a bare DataFrame is wrapped in a DataFrameSource
            source = DataFrameSource(source)
        elif isinstance(source, pd.Panel):
            source = DataPanelSource(source)

        if isinstance(source, (list, tuple)):
            self.sources = source
        else:
            self.sources = [source]

        # Resolve the simulation parameters: an explicit argument wins,
        # then the parameters stored on the algorithm, and finally
        # parameters derived from the start and end of the source.
        if sim_params is None:
            sim_params = self.sim_params
        if sim_params is None:
            sim_params = create_simulation_parameters(
                start=source.start,
                end=source.end,
                capital_base=self.capital_base,
            )

        # update sim params to ensure it's set
        self.sim_params = sim_params
        if self.sim_params.sids is None:
            self.sim_params.sids = set(
                sid for feed in self.sources for sid in feed.sids)

        # History containers are only needed when specs were registered.
        if len(self.history_specs) > 0:
            self.history_container = HistoryContainer(
                self.history_specs,
                self.sim_params.sids,
                self.sim_params.first_open)

        # Wrap every registered transform into a StatefulTransform.
        self.transforms = []
        for label, description in iteritems(self.registered_transforms):
            stateful = StatefulTransform(
                description['class'],
                *description['args'],
                **description['kwargs']
            )
            stateful.namestring = label
            self.transforms.append(stateful)

        # Dropping the tracker forces a fresh one, in case this is a
        # repeat run of the algorithm.
        self.perf_tracker = None

        # create transforms and zipline
        self.gen = self._create_generator(sim_params)

        with ZiplineAPI(self):
            # Draining the generator yields one perf dictionary per
            # iteration of simulated trading.
            perfs = list(self.gen)

            # convert perf dict to pandas dataframe
            daily_stats = self._create_daily_stats(perfs)

        self.analyze(daily_stats)

        return daily_stats

    def _create_daily_stats(self, perfs):
        """Build the daily statistics DataFrame from the emitted perf dicts.

        Every dict that has a 'daily_perf' entry contributes one row: its
        'recorded_vars' sub-dict is popped and merged into the daily dict
        itself (the dicts in ``perfs`` are modified in place). A dict
        without 'daily_perf' is stored on ``self.risk_report`` instead.
        Rows are indexed by their 'period_close' value.
        """
        daily_rows = []
        # TODO: the merge below could overwrite expected properties
        # of daily_perf. Could potentially raise or log a
        # warning.
        for packet in perfs:
            if 'daily_perf' not in packet:
                self.risk_report = packet
                continue

            row = packet['daily_perf']
            row.update(row.pop('recorded_vars'))
            daily_rows.append(row)

        close_stamps = [np.datetime64(row['period_close'], utc=True)
                        for row in daily_rows]
        return pd.DataFrame(daily_rows, index=close_stamps)

    def add_transform(self, transform_class, tag, *args, **kwargs):
        """Register a single-sid, sequential transform under ``tag``.

        :Arguments:
            transform_class : class
                Which transform to use. E.g. mavg.
            tag : str
                Name of the transform. It keys the registration, so
                registering the same tag again replaces the earlier entry.

        Extra positional and keyword arguments are stored alongside the
        class so they can be handed to the transform when it is built.

        """
        registration = {
            'class': transform_class,
            'args': args,
            'kwargs': kwargs,
        }
        self.registered_transforms[tag] = registration

    @api_method
    def record(self, **kwargs):
        """
        Store every keyword argument as a recorded variable, replacing any
        value previously recorded under the same name.
        """
        self._recorded_vars.update(kwargs)

    @api_method
    def order(self, sid, amount,
              limit_price=None,
              stop_price=None,
              style=None):
        """
        Validate the order parameters and hand the order to the blotter.

        Returns whatever ``self.blotter.order`` returns.
        """
        # An unsupported parameter combination raises a ZiplineError here,
        # before anything reaches the blotter.
        self.validate_order_params(
            sid, amount, limit_price, stop_price, style)

        # Fold the deprecated limit_price / stop_price arguments into a
        # single ExecutionStyle object.
        execution_style = self.__convert_order_params_for_blotter(
            limit_price, stop_price, style)
        return self.blotter.order(sid, amount, execution_style)

    def validate_order_params(self,
                              sid,
                              amount,
                              limit_price,
                              stop_price,
                              style):
        """
        Helper method for validating parameters to the order API function.

        ``style`` may not be combined with ``limit_price`` or ``stop_price``.
        The prices are compared against None rather than tested for
        truthiness, so an explicit price of 0 passed together with a style is
        rejected here instead of slipping through to the assertion in the
        parameter-conversion helper.

        Raises an UnsupportedOrderParameters if invalid arguments are found.
        """
        if not style:
            # Without a style any price combination is acceptable.
            return

        if limit_price is not None:
            raise UnsupportedOrderParameters(
                msg="Passing both limit_price and style is not supported."
            )

        if stop_price is not None:
            raise UnsupportedOrderParameters(
                msg="Passing both stop_price and style is not supported."
            )

    @staticmethod
    def __convert_order_params_for_blotter(limit_price, stop_price, style):
        """
        Turn the deprecated limit_price / stop_price arguments into an
        ExecutionStyle instance.

        A truthy ``style`` is returned unchanged; in that case both prices
        are asserted to be None. Otherwise the result depends on which
        prices are truthy: both -> StopLimitOrder, only limit -> LimitOrder,
        only stop -> StopOrder, neither -> MarketOrder.
        """
        # TODO_SS: DeprecationWarning for usage of limit_price and stop_price.
        if style:
            assert (limit_price, stop_price) == (None, None)
            return style

        if limit_price:
            if stop_price:
                return StopLimitOrder(limit_price, stop_price)
            return LimitOrder(limit_price)

        if stop_price:
            return StopOrder(stop_price)

        return MarketOrder()

    @api_method
    def order_value(self, sid, value,
                    limit_price=None, stop_price=None, style=None):
        """
        Place an order by desired value rather than desired number of shares.
        The value is divided by the sid's current price to obtain the number
        of shares to transact.

        value > 0 :: Buy/Cover
        value < 0 :: Sell/Short
        Market order:    order(sid, value)
        Limit order:     order(sid, value, limit_price)
        Stop order:      order(sid, value, None, stop_price)
        StopLimit order: order(sid, value, limit_price, stop_price)

        When the current price is (close to) 0 no order is placed and None
        is returned; the situation is logged at debug level if a logger
        is set.
        """
        last_price = self.trading_client.current_data[sid].price

        if not np.allclose(last_price, 0):
            return self.order(sid, value / last_price,
                              limit_price=limit_price,
                              stop_price=stop_price,
                              style=style)

        zero_message = "Price of 0 for {psid}; can't infer value".format(
            psid=sid
        )
        if self.logger:
            self.logger.debug(zero_message)
        # Don't place any order
        return

    @property
    def recorded_vars(self):
        """Return a copy of the recorded variables, not the internal dict."""
        snapshot = copy(self._recorded_vars)
        return snapshot

    @property
    def portfolio(self):
        """Return the portfolio reported by the performance tracker."""
        # Asking the tracker for the portfolio internally causes a refresh
        # of the period performance calculations.
        tracker = self.perf_tracker
        return tracker.get_portfolio()

    def updated_portfolio(self):
        """Return the cached portfolio, refreshing it first when flagged.

        When ``portfolio_needs_update`` is set, the portfolio is fetched
        from the performance tracker (which internally refreshes the period
        performance calculations), cached, and the flag is cleared.
        """
        if not self.portfolio_needs_update:
            return self._portfolio

        self._portfolio = self.perf_tracker.get_portfolio()
        self.portfolio_needs_update = False
        return self._portfolio

    def set_logger(self, logger):
        """Attach ``logger`` to the algorithm as ``self.logger``."""
        self.logger = logger

    def set_datetime(self, dt):
        """Set the algorithm's current time to ``dt``.

        ``dt`` is asserted to be a ``datetime`` instance whose tzinfo
        is UTC.
        """
        assert isinstance(dt, datetime), (
            "Attempt to set algorithm's current time with non-datetime")
        assert dt.tzinfo == pytz.utc, (
            "Algorithm expects a utc datetime")
        self.datetime = dt

    @api_method
    def get_datetime(self):
        """
        Returns a copy of the algorithm's current datetime, after asserting
        that the copy carries a UTC tzinfo.
        """
        current = copy(self.datetime)
        assert current.tzinfo == pytz.utc, (
            "Algorithm should have a utc datetime")
        return current

    def set_transact(self, transact):
        """
        Install ``transact`` on the blotter: the method that will be called
        to create a transaction from open orders and trade events.
        """
        blotter = self.blotter
        blotter.transact = transact

    @api_method
    def set_slippage(self, slippage):
        """Install ``slippage`` as the algorithm's slippage model.

        Raises UnsupportedSlippageModel when ``slippage`` is not a
        SlippageModel instance, and OverrideSlippagePostInit when the
        algorithm is already flagged as initialized.
        """
        is_supported = isinstance(slippage, SlippageModel)
        if not is_supported:
            raise UnsupportedSlippageModel()

        if self.initialized:
            raise OverrideSlippagePostInit()

        self.slippage = slippage

    @api_method
    def set_commission(self, commission):
        """Install ``commission`` as the algorithm's commission model.

        Raises UnsupportedCommissionModel unless ``commission`` is a
        PerShare, PerTrade or PerDollar instance, and
        OverrideCommissionPostInit when the algorithm is already flagged as
        initialized.
        """
        supported_models = (PerShare, PerTrade, PerDollar)
        if not isinstance(commission, supported_models):
            raise UnsupportedCommissionModel()

        if self.initialized:
            raise OverrideCommissionPostInit()

        self.commission = commission

    def set_sources(self, sources):
        """Replace the algorithm's sources; ``sources`` must be a list."""
        assert isinstance(sources, list)
        self.sources = sources

    def set_transforms(self, transforms):
        """Replace the algorithm's transforms; ``transforms`` must be a list."""
        assert isinstance(transforms, list)
        self.transforms = transforms

    def set_data_frequency(self, data_frequency):
        """Set the bar frequency and the matching annualizer.

        ``data_frequency`` is asserted to be 'daily' or 'minute'; the
        annualizer is looked up in ANNUALIZER under that frequency.
        """
        supported = ('daily', 'minute')
        assert data_frequency in supported
        self.data_frequency = data_frequency
        self.annualizer = ANNUALIZER[self.data_frequency]

    @api_method
    def order_percent(self, sid, percent,
                      limit_price=None, stop_price=None, style=None):
        """
        Place an order in the specified security corresponding to the given
        percent of the current portfolio value.

        Note that percent must be expressed as a decimal (0.50 means 50%).
        """
        portfolio_value = self.portfolio.portfolio_value
        return self.order_value(sid, portfolio_value * percent,
                                limit_price=limit_price,
                                stop_price=stop_price,
                                style=style)

    @api_method
    def order_target(self, sid, target,
                     limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target number of shares.

        Without an existing position the full ``target`` is ordered; with
        one, only the difference between ``target`` and the shares currently
        held is ordered.
        """
        shares_to_order = target
        if sid in self.portfolio.positions:
            shares_held = self.portfolio.positions[sid].amount
            shares_to_order = target - shares_held

        return self.order(sid, shares_to_order,
                          limit_price=limit_price,
                          stop_price=stop_price,
                          style=style)

    @api_method
    def order_target_value(self, sid, target,
                           limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target value.

        Without an existing position the full ``target`` value is ordered;
        with one, only the difference between ``target`` and the position's
        current value (shares held times current price) is ordered.
        """
        value_to_order = target
        if sid in self.portfolio.positions:
            shares_held = self.portfolio.positions[sid].amount
            price_now = self.trading_client.current_data[sid].price
            value_to_order = target - shares_held * price_now

        return self.order_value(sid, value_to_order,
                                limit_price=limit_price,
                                stop_price=stop_price,
                                style=style)

    @api_method
    def order_target_percent(self, sid, target,
                             limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target percent of the
        current portfolio value.

        The value ordered is the target value (portfolio value times
        ``target``) minus the value of the current position, which counts
        as 0 when no position exists.

        Note that target must be expressed as a decimal (0.50 means 50%).
        """
        held_value = 0
        if sid in self.portfolio.positions:
            shares_held = self.portfolio.positions[sid].amount
            price_now = self.trading_client.current_data[sid].price
            held_value = shares_held * price_now

        desired_value = self.portfolio.portfolio_value * target

        return self.order_value(sid, desired_value - held_value,
                                limit_price=limit_price,
                                stop_price=stop_price,
                                style=style)

    @api_method
    def get_open_orders(self, sid=None):
        """Return the blotter's open orders converted with ``to_api_obj()``.

        With no ``sid``, the result is a dict mapping every key of
        ``blotter.open_orders`` to the list of its converted orders.
        With a ``sid``, the result is the converted list for that sid, or an
        empty list when the blotter holds no open orders for it.
        """
        open_orders = self.blotter.open_orders
        if sid is None:
            # ``items()`` rather than the Python-2-only ``iteritems()`` so
            # that this works on both Python 2 and Python 3 mappings.
            return {key: [open_order.to_api_obj() for open_order in orders]
                    for key, orders in open_orders.items()}
        if sid in open_orders:
            return [open_order.to_api_obj()
                    for open_order in open_orders[sid]]
        return []

    @api_method
    def get_order(self, order_id):
        """Return the API object of the order with ``order_id``.

        Returns None when the blotter has no order under that id.
        """
        known_orders = self.blotter.orders
        if order_id not in known_orders:
            return None
        return known_orders[order_id].to_api_obj()

    @api_method
    def cancel_order(self, order_param):
        """Cancel an order on the blotter.

        ``order_param`` is either an order id or a ``zipline.protocol.Order``
        instance, in which case its ``id`` attribute is used.
        """
        if isinstance(order_param, zipline.protocol.Order):
            order_id = order_param.id
        else:
            order_id = order_param

        self.blotter.cancel(order_id)

    def raw_positions(self):
        """
        Returns the positions object held by the performance tracker's
        cumulative performance.

        N.B. this is not done as a property, so that the function can be
        passed and called from within a source.
        """
        # This is the 'internal' positions object, as in the one that is
        # not passed to the algo, and thus should not have tainted keys.
        cumulative = self.perf_tracker.cumulative_performance
        return cumulative.positions

    def raw_orders(self):
        """
        Returns the blotter's ``open_orders`` object as-is.

        N.B. this is not a property, so that the function can be passed
        and called back from within a source.
        """
        blotter = self.blotter
        return blotter.open_orders

    @api_method
    def add_history(self, bar_count, frequency, field,
                    ffill=True):
        """Register a HistorySpec built from the arguments.

        The spec is stored in ``self.history_specs`` under its ``key_str``.
        """
        spec = HistorySpec(bar_count, frequency, field, ffill)
        self.history_specs[spec.key_str] = spec

    @api_method
    def history(self, bar_count, frequency, field, ffill=True):
        """Return the history for a previously registered spec.

        The spec is looked up in ``self.history_specs`` by the key derived
        from the arguments (a missing key raises KeyError), and the history
        container is queried for it at the algorithm's current datetime.
        """
        key = HistorySpec.spec_key(bar_count, frequency, field, ffill)
        spec = self.history_specs[key]
        return self.history_container.get_history(spec, self.datetime)
Beispiel #11
0
    def run(self, source, sim_params=None, benchmark_return_source=None):
        """Run the algorithm.

        :Arguments:
            source : can be either:
                     - pandas.DataFrame
                     - pandas.Panel
                     - zipline source
                     - list of zipline sources

               If pandas.DataFrame is provided, it must have the
               following structure:
               * column names must consist of ints representing the
                 different sids
               * index must be DatetimeIndex
               * array contents should be price info.
            sim_params : optional
               Simulation parameters for this run. When omitted, the
               parameters already stored on the algorithm are used; when
               those are missing too, new ones are created from
               ``source.start``, ``source.end`` and ``self.capital_base``.
            benchmark_return_source : optional
               When not None, it is stored on
               ``self.benchmark_return_source`` before the generator is
               created. Passing None leaves that attribute untouched.

        :Returns:
            daily_stats : pandas.DataFrame
              Daily performance metrics such as returns, alpha etc.

        """
        if isinstance(source, (list, tuple)):
            assert self.sim_params is not None or sim_params is not None, \
                """When providing a list of sources, \
                sim_params have to be specified as a parameter
                or in the constructor."""
        elif isinstance(source, pd.DataFrame):
            # if DataFrame provided, wrap in DataFrameSource
            source = DataFrameSource(source)
        elif isinstance(source, pd.Panel):
            source = DataPanelSource(source)

        if isinstance(source, (list, tuple)):
            self.sources = source
        else:
            self.sources = [source]

        # This argument used to be accepted and then silently dropped;
        # store it so an explicitly supplied benchmark is not ignored.
        if benchmark_return_source is not None:
            self.benchmark_return_source = benchmark_return_source

        # Check for override of sim_params.
        # If it isn't passed to this function,
        # use the default params set with the algorithm.
        # Else, we create simulation parameters using the start and end of the
        # source provided.
        if sim_params is None:
            if self.sim_params is None:
                sim_params = create_simulation_parameters(
                    start=source.start,
                    end=source.end,
                    capital_base=self.capital_base,
                )
            else:
                sim_params = self.sim_params

        # update sim params to ensure it's set
        self.sim_params = sim_params
        if self.sim_params.sids is None:
            # Union of the sids of every source.
            self.sim_params.sids = set(
                sid for s in self.sources for sid in s.sids)

        # Create history containers
        if self.history_specs:
            self.history_container = HistoryContainer(
                self.history_specs,
                self.sim_params.sids,
                self.sim_params.first_open)

        # Create transforms by wrapping them into StatefulTransforms
        self.transforms = []
        for namestring, trans_descr in iteritems(self.registered_transforms):
            sf = StatefulTransform(
                trans_descr['class'],
                *trans_descr['args'],
                **trans_descr['kwargs']
            )
            sf.namestring = namestring

            self.transforms.append(sf)

        # force a reset of the performance tracker, in case
        # this is a repeat run of the algorithm.
        self.perf_tracker = None

        # create transforms and zipline
        self.gen = self._create_generator(sim_params)

        with ZiplineAPI(self):
            # Exhaust the simulation; each item yielded is a perf
            # dictionary.
            perfs = list(self.gen)

            # convert perf dict to pandas dataframe
            daily_stats = self._create_daily_stats(perfs)

        self.analyze(daily_stats)

        return daily_stats
class TradingAlgorithm(object):

    """
    Base class for trading algorithms. Inherit and overload
    initialize() and handle_data(data).

    A new algorithm could look like this:
    ```
    from zipline.api import order

    def initialize(context):
        context.sid = 'AAPL'
        context.amount = 100

    def handle_data(context, data):
        sid = context.sid
        amount = context.amount
        order(sid, amount)
    ```
    To run this algorithm, pass these functions to
    TradingAlgorithm:

    my_algo = TradingAlgorithm(initialize, handle_data)
    stats = my_algo.run(data)

    """

    def __init__(self, *args, **kwargs):
        """Initialize sids and other state variables.

        :Arguments:
        :Optional:
            initialize : function
                Function that is called with a single
                argument at the beginning of the simulation.
            handle_data : function
                Function that is called with 2 arguments
                (context and data) on every bar.
            script : str
                Algoscript that contains initialize and
                handle_data function definition.
            namespace : dict <default: {}>
                Namespace the script is executed in.
            data_frequency : str
               The duration of the bars; handed to set_data_frequency.
            annualizer : int <optional>
               Which constant to use for annualizing risk metrics.
               If not provided, will extract from data_frequency.
            capital_base : float <default: DEFAULT_CAPITAL_BASE>
               How much capital to start with.
            instant_fill : bool <default: False>
               Whether to fill orders immediately or on next bar.
            sim_params : <optional>
               Simulation parameters; when given, a PerformanceTracker
               is created from them right away.
            blotter : <optional>
               Blotter to use; a new Blotter() is created when omitted.

        :Raises:
            ValueError
                When ``script`` does not define ``handle_data``, or when
                ``script`` is combined with both ``initialize`` and
                ``handle_data``.
        """
        self.datetime = None

        self.registered_transforms = {}
        self.transforms = []
        self.sources = []

        # List of trading controls to be used to validate orders.
        self.trading_controls = []

        self._recorded_vars = {}
        self.namespace = kwargs.get('namespace', {})

        self.logger = None

        self.benchmark_return_source = None
        self.perf_tracker = None

        # default components for transact
        self.slippage = VolumeShareSlippage()
        self.commission = PerShare()

        if 'data_frequency' in kwargs:
            self.set_data_frequency(kwargs.pop('data_frequency'))
        else:
            self.data_frequency = None

        self.instant_fill = kwargs.pop('instant_fill', False)

        # Override annualizer if set
        if 'annualizer' in kwargs:
            self.annualizer = kwargs['annualizer']

        # set the capital base
        self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)

        self.sim_params = kwargs.pop('sim_params', None)
        if self.sim_params:
            # Keep the algorithm and its sim_params agreeing on the data
            # frequency; an explicitly set algorithm frequency wins.
            if self.data_frequency is None:
                self.data_frequency = self.sim_params.data_frequency
            else:
                self.sim_params.data_frequency = self.data_frequency

            self.perf_tracker = PerformanceTracker(self.sim_params)

        self.blotter = kwargs.pop('blotter', None)
        if not self.blotter:
            self.blotter = Blotter()

        self.portfolio_needs_update = True
        self._portfolio = None

        self.history_container = None
        self.history_specs = {}

        # If string is passed in, execute and get reference to
        # functions.
        self.algoscript = kwargs.pop('script', None)

        self._initialize = None
        self._analyze = None

        # The callbacks only count as supplied when both are given.
        callbacks_given = bool(
            kwargs.get('initialize', False) and kwargs.get('handle_data'))

        if self.algoscript is not None:
            # This conflict check used to sit inside the ``elif`` branch
            # below, where it could never be reached.
            if callbacks_given:
                raise ValueError(
                    'You can not set script and initialize/handle_data.')

            exec_(self.algoscript, self.namespace)
            self._initialize = self.namespace.get('initialize', None)
            if 'handle_data' not in self.namespace:
                raise ValueError('You must define a handle_data function.')
            else:
                self._handle_data = self.namespace['handle_data']

            # Optional analyze function, gets called after run
            self._analyze = self.namespace.get('analyze', None)

        elif callbacks_given:
            self._initialize = kwargs.pop('initialize')
            self._handle_data = kwargs.pop('handle_data')

        # If method not defined, NOOP
        if self._initialize is None:
            self._initialize = lambda x: None

        # an algorithm subclass needs to set initialized to True when
        # it is fully initialized.
        self.initialized = False
        self.initialize(*args, **kwargs)

    def initialize(self, *args, **kwargs):
        """
        Run the user-supplied ``_initialize`` callback inside a ZiplineAPI
        context for this algorithm, passing the algorithm itself as the
        only argument. ``args`` and ``kwargs`` are accepted but not used.
        """
        api_context = ZiplineAPI(self)
        with api_context:
            self._initialize(self)

    def handle_data(self, data):
        """Process one bar of ``data``.

        When a history container is present it is updated with the data
        and the current datetime first; then the user-supplied
        ``_handle_data`` callback is invoked with the algorithm and the data.
        """
        container = self.history_container
        if container:
            container.update(data, self.datetime)

        self._handle_data(self, data)

    def analyze(self, perf):
        """Call the optional ``_analyze`` callback with ``perf``.

        Does nothing when no analyze callback was supplied; otherwise the
        callback runs inside a ZiplineAPI context for this algorithm.
        """
        if self._analyze is not None:
            with ZiplineAPI(self):
                self._analyze(self, perf)

    def __repr__(self):
        """
        N.B. this does not yet represent a string that can be used
        to instantiate an exact copy of an algorithm.

        However, it is getting close, and provides some value as something
        that can be inspected interactively.
        """
        return """
{class_name}(
    capital_base={capital_base}
    sim_params={sim_params},
    initialized={initialized},
    slippage={slippage},
    commission={commission},
    blotter={blotter},
    recorded_vars={recorded_vars})
""".strip().format(class_name=self.__class__.__name__,
                   capital_base=self.capital_base,
                   sim_params=repr(self.sim_params),
                   initialized=self.initialized,
                   slippage=repr(self.slippage),
                   commission=repr(self.commission),
                   blotter=repr(self.blotter),
                   recorded_vars=repr(self.recorded_vars))

    def _create_data_generator(self, source_filter, sim_params):
        """
        Create a merged data generator using the sources and
        transforms attached to this algorithm.

        ::source_filter:: is a method that receives events in date
        sorted order, and returns True for those events that should be
        processed by the zipline, and False for those that should be
        skipped.

        ::sim_params:: supplies ``emission_rate``, ``period_start`` and
        ``period_end``, which are only read when the benchmark events
        have to be built here.

        Returns the result of ``groupby`` over the merged stream, keyed
        on each event's ``dt`` attribute.
        """
        if self.benchmark_return_source is None:
            # No benchmark source was supplied: build one event per entry
            # of the trading environment's benchmark returns that falls
            # within the simulation period (both bounds inclusive,
            # compared by date).
            env = trading.environment
            if (self.data_frequency == 'minute'
                    or sim_params.emission_rate == 'minute'):
                # With minute data or minute emission the event is stamped
                # with the second element of get_open_and_close(date).
                update_time = lambda date: env.get_open_and_close(date)[1]
            else:
                update_time = lambda date: date
            benchmark_return_source = [
                Event({'dt': update_time(dt),
                       'returns': ret,
                       'type': zipline.protocol.DATASOURCE_TYPE.BENCHMARK,
                       'source_id': 'benchmarks'})
                for dt, ret in trading.environment.benchmark_returns.iterkv()
                if dt.date() >= sim_params.period_start.date()
                and dt.date() <= sim_params.period_end.date()
            ]
        else:
            benchmark_return_source = self.benchmark_return_source

        # Merge all attached sources into one stream.
        date_sorted = date_sorted_sources(*self.sources)

        if source_filter:
            # NOTE(review): whether this filtering is lazy or eager depends
            # on which ``filter`` is in scope for this module -- confirm.
            date_sorted = filter(source_filter, date_sorted)

        # Pass the (filtered) events through the attached transforms.
        with_tnfms = sequential_transforms(date_sorted,
                                           *self.transforms)

        # Merge the benchmark events in with the transformed events.
        with_benchmarks = date_sorted_sources(benchmark_return_source,
                                              with_tnfms)

        # Group together events with the same dt field. This depends on the
        # events already being sorted.
        return groupby(with_benchmarks, attrgetter('dt'))

    def _create_generator(self, sim_params, source_filter=None):
        """
        Create a basic generator setup using the sources and
        transforms attached to this algorithm.

        ::source_filter:: is a method that receives events in date
        sorted order, and returns True for those events that should be
        processed by the zipline, and False for those that should be
        skipped.

        As a side effect this overwrites ``sim_params.data_frequency`` with
        the algorithm's data frequency and (re)sets ``self.data_gen``,
        ``self.trading_client`` and the blotter's transact method.
        """
        sim_params.data_frequency = self.data_frequency

        # perf_tracker will be instantiated in __init__ if a sim_params
        # is passed to the constructor. If not, we instantiate here.
        if self.perf_tracker is None:
            self.perf_tracker = PerformanceTracker(sim_params)

        self.data_gen = self._create_data_generator(
            source_filter, sim_params)

        self.trading_client = AlgorithmSimulator(self, sim_params)

        # Build the transact method from the current slippage and
        # commission models and install it on the blotter.
        self.set_transact(
            transact_partial(self.slippage, self.commission))

        simulation = self.trading_client.transform(self.data_gen)
        return simulation

    def get_generator(self):
        """
        Return the generator built by ``_create_generator`` from the
        algorithm's own ``sim_params``.

        Override this method to add new logic to the construction
        of the generator. Overrides can use the _create_generator
        method to get a standard construction generator.
        """
        params = self.sim_params
        return self._create_generator(params)

    # TODO: make a new subclass, e.g. BatchAlgorithm, and move
    # the run method to the subclass, and refactor to put the
    # generator creation logic into get_generator.
    def run(self, source, sim_params=None, benchmark_return_source=None):
        """Run the algorithm.

        :Arguments:
            source : can be either:
                     - pandas.DataFrame
                     - pandas.Panel
                     - zipline source
                     - list of zipline sources

               If pandas.DataFrame is provided, it must have the
               following structure:
               * column names must consist of ints representing the
                 different sids
               * index must be DatetimeIndex
               * array contents should be price info.
            sim_params : optional
               Simulation parameters for this run. When omitted, the
               parameters already stored on the algorithm are used; when
               those are missing too, new ones are created from
               ``source.start``, ``source.end`` and ``self.capital_base``.
            benchmark_return_source : optional
               When not None, it is stored on
               ``self.benchmark_return_source`` before the generator is
               created. Passing None leaves that attribute untouched.

        :Returns:
            daily_stats : pandas.DataFrame
              Daily performance metrics such as returns, alpha etc.

        """
        if isinstance(source, (list, tuple)):
            assert self.sim_params is not None or sim_params is not None, \
                """When providing a list of sources, \
                sim_params have to be specified as a parameter
                or in the constructor."""
        elif isinstance(source, pd.DataFrame):
            # if DataFrame provided, wrap in DataFrameSource
            source = DataFrameSource(source)
        elif isinstance(source, pd.Panel):
            source = DataPanelSource(source)

        if isinstance(source, (list, tuple)):
            self.sources = source
        else:
            self.sources = [source]

        # This argument used to be accepted and then silently dropped;
        # store it so an explicitly supplied benchmark is not ignored.
        if benchmark_return_source is not None:
            self.benchmark_return_source = benchmark_return_source

        # Check for override of sim_params.
        # If it isn't passed to this function,
        # use the default params set with the algorithm.
        # Else, we create simulation parameters using the start and end of the
        # source provided.
        if sim_params is None:
            if self.sim_params is None:
                sim_params = create_simulation_parameters(
                    start=source.start,
                    end=source.end,
                    capital_base=self.capital_base,
                )
            else:
                sim_params = self.sim_params

        # update sim params to ensure it's set
        self.sim_params = sim_params
        if self.sim_params.sids is None:
            # Union of the sids of every source.
            self.sim_params.sids = set(
                sid for s in self.sources for sid in s.sids)

        # Create history containers
        if self.history_specs:
            self.history_container = HistoryContainer(
                self.history_specs,
                self.sim_params.sids,
                self.sim_params.first_open)

        # Create transforms by wrapping them into StatefulTransforms
        self.transforms = []
        for namestring, trans_descr in iteritems(self.registered_transforms):
            sf = StatefulTransform(
                trans_descr['class'],
                *trans_descr['args'],
                **trans_descr['kwargs']
            )
            sf.namestring = namestring

            self.transforms.append(sf)

        # force a reset of the performance tracker, in case
        # this is a repeat run of the algorithm.
        self.perf_tracker = None

        # create transforms and zipline
        self.gen = self._create_generator(sim_params)

        with ZiplineAPI(self):
            # Exhaust the simulation; each item yielded is a perf
            # dictionary.
            perfs = list(self.gen)

            # convert perf dict to pandas dataframe
            daily_stats = self._create_daily_stats(perfs)

        self.analyze(daily_stats)

        return daily_stats

    def _create_daily_stats(self, perfs):
        # create daily and cumulative stats dataframe
        daily_perfs = []
        # TODO: the loop here could overwrite expected properties
        # of daily_perf. Could potentially raise or log a
        # warning.
        for perf in perfs:
            if 'daily_perf' in perf:

                perf['daily_perf'].update(
                    perf['daily_perf'].pop('recorded_vars')
                )
                daily_perfs.append(perf['daily_perf'])
            else:
                self.risk_report = perf

        daily_dts = [np.datetime64(perf['period_close'], utc=True)
                     for perf in daily_perfs]
        daily_stats = pd.DataFrame(daily_perfs, index=daily_dts)

        return daily_stats

    def add_transform(self, transform_class, tag, *args, **kwargs):
        """Register a single-sid, sequential transform under ``tag``.

        :Arguments:
            transform_class : class
                Which transform to use. E.g. mavg.
            tag : str
                Name of the transform. It can later be accessed via:
                data[sid].tag()

        Extra positional and keyword arguments are stored alongside the
        class so they can be forwarded to the transform instantiation.

        """
        descriptor = {}
        descriptor['class'] = transform_class
        descriptor['args'] = args
        descriptor['kwargs'] = kwargs
        self.registered_transforms[tag] = descriptor

    @api_method
    def record(self, **kwargs):
        """
        Track and record local variables (i.e. attributes) each day.

        Each keyword argument is stored in the recorded-variables mapping
        under its own name, replacing any earlier value for that name.
        """
        self._recorded_vars.update(kwargs)

    @api_method
    def order(self, sid, amount,
              limit_price=None,
              stop_price=None,
              style=None):
        """
        Place an order for ``amount`` of ``sid`` through the blotter.

        The parameters are validated first, then the deprecated
        ``limit_price``/``stop_price`` arguments are converted into an
        execution style before the order is handed to the blotter. Returns
        whatever ``self.blotter.order`` returns.
        """
        # Raises a ZiplineError if invalid parameters are detected.
        self.validate_order_params(
            sid, amount, limit_price, stop_price, style)

        # Convert deprecated limit_price and stop_price parameters to use
        # ExecutionStyle objects.
        execution_style = self.__convert_order_params_for_blotter(
            limit_price, stop_price, style)

        return self.blotter.order(sid, amount, execution_style)

    def validate_order_params(self,
                              sid,
                              amount,
                              limit_price,
                              stop_price,
                              style):
        """
        Helper method for validating parameters to the order API function.

        Raises an UnsupportedOrderParameters if ``style`` is combined with
        an explicit ``limit_price`` or ``stop_price``. Afterwards every
        registered trading control is asked to validate the order; whatever
        a control raises propagates to the caller.
        """
        if style:
            # Compare against None instead of relying on truthiness: a price
            # of 0 passed together with a style is still an unsupported
            # combination, and the order-parameter conversion step assumes
            # both prices are None whenever a style is given.
            if limit_price is not None:
                raise UnsupportedOrderParameters(
                    msg="Passing both limit_price and style is not supported."
                )

            if stop_price is not None:
                raise UnsupportedOrderParameters(
                    msg="Passing both stop_price and style is not supported."
                )

        for control in self.trading_controls:
            control.validate(sid,
                             amount,
                             self.updated_portfolio(),
                             self.get_datetime(),
                             self.trading_client.current_data)

    @staticmethod
    def __convert_order_params_for_blotter(limit_price, stop_price, style):
        """
        Translate the deprecated limit_price / stop_price arguments into an
        ExecutionStyle instance.

        A truthy ``style`` is returned unchanged; in that case both prices
        are asserted to be None. Otherwise the style is picked from
        whichever of the two prices are truthy, falling back to a market
        order when neither is.
        """
        # TODO_SS: DeprecationWarning for usage of limit_price and stop_price.
        if style:
            assert (limit_price, stop_price) == (None, None)
            return style

        if limit_price:
            if stop_price:
                return StopLimitOrder(limit_price, stop_price)
            return LimitOrder(limit_price)

        if stop_price:
            return StopOrder(stop_price)

        return MarketOrder()

    @api_method
    def order_value(self, sid, value,
                    limit_price=None, stop_price=None, style=None):
        """
        Place an order by desired value rather than desired number of shares.
        The requested value is divided by the sid's current price to imply
        the number of shares to transact. When that price is (close to) 0 no
        order is placed and None is returned.

        value > 0 :: Buy/Cover
        value < 0 :: Sell/Short
        Market order:    order_value(sid, value)
        Limit order:     order_value(sid, value, limit_price)
        Stop order:      order_value(sid, value, None, stop_price)
        StopLimit order: order_value(sid, value, limit_price, stop_price)
        """
        last_price = self.trading_client.current_data[sid].price

        if not np.allclose(last_price, 0):
            return self.order(sid, value / last_price,
                              limit_price=limit_price,
                              stop_price=stop_price,
                              style=style)

        zero_message = "Price of 0 for {psid}; can't infer value".format(
            psid=sid
        )
        if self.logger:
            self.logger.debug(zero_message)
        # Don't place any order
        return None

    @property
    def recorded_vars(self):
        """Copy of the recorded-variables mapping (made with ``copy``)."""
        snapshot = copy(self._recorded_vars)
        return snapshot

    @property
    def portfolio(self):
        """The portfolio, as returned by ``updated_portfolio()``."""
        current = self.updated_portfolio()
        return current

    def updated_portfolio(self):
        """
        Return the cached portfolio, refreshing it from the performance
        tracker first when ``portfolio_needs_update`` is set. A refresh
        clears that flag.
        """
        if not self.portfolio_needs_update:
            return self._portfolio

        self._portfolio = self.perf_tracker.get_portfolio()
        self.portfolio_needs_update = False
        return self._portfolio

    def set_logger(self, logger):
        """Store ``logger`` as the algorithm's logger."""
        setattr(self, 'logger', logger)

    def set_datetime(self, dt):
        """
        Set the algorithm's current time.

        ``dt`` is asserted to be a datetime whose tzinfo equals pytz.utc.
        """
        is_datetime = isinstance(dt, datetime)
        assert is_datetime, \
            "Attempt to set algorithm's current time with non-datetime"

        is_utc = dt.tzinfo == pytz.utc
        assert is_utc, \
            "Algorithm expects a utc datetime"

        self.datetime = dt

    @api_method
    def get_datetime(self):
        """
        Return a copy of the algorithm's current datetime.

        The copy is asserted to carry pytz.utc as its tzinfo.
        """
        current = copy(self.datetime)
        is_utc = current.tzinfo == pytz.utc
        assert is_utc, \
            "Algorithm should have a utc datetime"
        return current

    def set_transact(self, transact):
        """
        Install ``transact`` on the blotter as the method used to create a
        transaction from open orders and trade events.
        """
        blotter = self.blotter
        blotter.transact = transact

    @api_method
    def set_slippage(self, slippage):
        """
        Set the slippage model used by the algorithm.

        Raises UnsupportedSlippageModel when ``slippage`` is not a
        SlippageModel instance, and OverrideSlippagePostInit when the
        algorithm is already initialized.
        """
        if isinstance(slippage, SlippageModel):
            if self.initialized:
                raise OverrideSlippagePostInit()
            self.slippage = slippage
        else:
            raise UnsupportedSlippageModel()

    @api_method
    def set_commission(self, commission):
        """
        Set the commission model used by the algorithm.

        Raises UnsupportedCommissionModel when ``commission`` is not a
        PerShare, PerTrade or PerDollar instance, and
        OverrideCommissionPostInit when the algorithm is already initialized.
        """
        supported_models = (PerShare, PerTrade, PerDollar)
        if not isinstance(commission, supported_models):
            raise UnsupportedCommissionModel()
        if self.initialized:
            raise OverrideCommissionPostInit()

        self.commission = commission

    def set_sources(self, sources):
        """Replace the algorithm's sources; ``sources`` must be a list."""
        is_list = isinstance(sources, list)
        assert is_list
        self.sources = sources

    def set_transforms(self, transforms):
        """Replace the algorithm's transforms; ``transforms`` must be a list."""
        is_list = isinstance(transforms, list)
        assert is_list
        self.transforms = transforms

    def set_data_frequency(self, data_frequency):
        """
        Set the data frequency ('daily' or 'minute') and look up the
        matching annualizer in ANNUALIZER.
        """
        allowed = ('daily', 'minute')
        assert data_frequency in allowed

        self.data_frequency = data_frequency
        # Index with the stored attribute, not the argument, exactly as the
        # assignment above left it.
        self.annualizer = ANNUALIZER[self.data_frequency]

    @api_method
    def order_percent(self, sid, percent,
                      limit_price=None, stop_price=None, style=None):
        """
        Place an order in the specified security corresponding to the given
        percent of the current portfolio value.

        Note that percent must be expressed as a decimal (0.50 means 50%).
        The resulting value is delegated to ``order_value``.
        """
        portfolio_value = self.portfolio.portfolio_value
        return self.order_value(sid, portfolio_value * percent,
                                limit_price=limit_price,
                                stop_price=stop_price,
                                style=style)

    @api_method
    def order_target(self, sid, target,
                     limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target number of shares.

        Without an existing position in ``sid`` the full ``target`` is
        ordered; with one, only the difference between ``target`` and the
        currently held amount is ordered.
        """
        positions = self.portfolio.positions
        if sid in positions:
            req_shares = target - positions[sid].amount
        else:
            req_shares = target

        return self.order(sid, req_shares,
                          limit_price=limit_price,
                          stop_price=stop_price,
                          style=style)

    @api_method
    def order_target_value(self, sid, target,
                           limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target value.

        The target value is divided by the sid's current price and the
        resulting share count is delegated to ``order_target``. When that
        price is (close to) 0 no order is placed and None is returned.
        """
        last_price = self.trading_client.current_data[sid].price

        if not np.allclose(last_price, 0):
            target_amount = target / last_price
            return self.order_target(sid, target_amount,
                                     limit_price=limit_price,
                                     stop_price=stop_price,
                                     style=style)

        # Don't place an order
        if self.logger:
            zero_message = "Price of 0 for {psid}; can't infer value"
            self.logger.debug(zero_message.format(psid=sid))
        return None

    @api_method
    def order_target_percent(self, sid, target,
                             limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target percent of the
        current portfolio value.

        Note that target must be expressed as a decimal (0.50 means 50%).
        The implied target value is delegated to ``order_target_value``.
        """
        portfolio_value = self.portfolio.portfolio_value
        return self.order_target_value(sid, portfolio_value * target,
                                       limit_price=limit_price,
                                       stop_price=stop_price,
                                       style=style)

    @api_method
    def get_open_orders(self, sid=None):
        """
        Return the blotter's open orders converted with ``to_api_obj()``.

        With ``sid`` omitted, a dict mapping each key that has open orders
        to its list of orders is returned. With a ``sid``, only that sid's
        list is returned (empty when the blotter has no entry for it).
        """
        if sid is not None:
            if sid not in self.blotter.open_orders:
                return []
            return [order.to_api_obj()
                    for order in self.blotter.open_orders[sid]]

        by_key = {}
        for key, orders in iteritems(self.blotter.open_orders):
            if orders:
                by_key[key] = [order.to_api_obj() for order in orders]
        return by_key

    @api_method
    def get_order(self, order_id):
        """
        Return the API object of the blotter order with ``order_id``, or
        None when the blotter does not know that id.
        """
        if order_id not in self.blotter.orders:
            return None
        return self.blotter.orders[order_id].to_api_obj()

    @api_method
    def cancel_order(self, order_param):
        """
        Cancel an order through the blotter.

        ``order_param`` may be a zipline.protocol.Order, in which case its
        ``id`` is used, or anything else, which is passed on as the id.
        """
        if isinstance(order_param, zipline.protocol.Order):
            order_id = order_param.id
        else:
            order_id = order_param

        self.blotter.cancel(order_id)

    def raw_positions(self):
        """
        Returns the positions held by the performance tracker's cumulative
        performance.

        N.B. this is not done as a property, so that the function can be
        passed and called from within a source.
        """
        # This is the 'internal' positions object, as in the one that is
        # not passed to the algo, and thus should not have tainted keys.
        cumulative = self.perf_tracker.cumulative_performance
        return cumulative.positions

    def raw_orders(self):
        """
        Returns the blotter's ``open_orders`` object as-is.

        N.B. this is not a property, so that the function can be passed
        and called back from within a source.
        """
        blotter = self.blotter
        return blotter.open_orders

    @api_method
    def add_history(self, bar_count, frequency, field,
                    ffill=True):
        """
        Register a history specification, keyed by the spec's ``key_str``.
        """
        spec = HistorySpec(bar_count, frequency, field, ffill)
        self.history_specs[spec.key_str] = spec

    @api_method
    def history(self, bar_count, frequency, field, ffill=True):
        """
        Return the history for an already registered specification.

        The spec is looked up in ``self.history_specs`` by its key, so an
        unregistered combination surfaces as that mapping's lookup error.
        """
        key = HistorySpec.spec_key(bar_count, frequency, field, ffill)
        spec = self.history_specs[key]
        return self.history_container.get_history(spec, self.datetime)

    ####################
    # Trading Controls #
    ####################

    def register_trading_control(self, control):
        """
        Register a new TradingControl to be checked prior to order calls.

        Raises RegisterTradingControlPostInit when the algorithm is already
        initialized.
        """
        if not self.initialized:
            self.trading_controls.append(control)
            return

        raise RegisterTradingControlPostInit()

    @api_method
    def set_max_position_size(self,
                              sid=None,
                              max_shares=None,
                              max_notional=None):
        """
        Set a limit on the number of shares and/or dollar value held for the
        given sid. Limits are treated as absolute values and are enforced at
        the time that the algo attempts to place an order for sid. This means
        that it's possible to end up with more than the max number of shares
        due to splits/dividends, and more than the max notional due to price
        improvement.

        If an algorithm attempts to place an order that would result in
        increasing the absolute value of shares/dollar value exceeding one of
        these limits, raise a TradingControlException.
        """
        self.register_trading_control(
            MaxPositionSize(sid=sid,
                            max_shares=max_shares,
                            max_notional=max_notional)
        )

    @api_method
    def set_max_order_size(self, sid=None, max_shares=None, max_notional=None):
        """
        Set a limit on the number of shares and/or dollar value of any single
        order placed for sid. Limits are treated as absolute values and are
        enforced at the time that the algo attempts to place an order for sid.

        If an algorithm attempts to place an order that would result in
        exceeding one of these limits, raise a TradingControlException.
        """
        self.register_trading_control(
            MaxOrderSize(sid=sid,
                         max_shares=max_shares,
                         max_notional=max_notional)
        )

    @api_method
    def set_max_order_count(self, max_count):
        """
        Set a limit on the number of orders that can be placed.

        ``max_count`` is handed to a MaxOrderCount control, which is then
        registered; the time interval the count applies to is defined by
        that control, not here.
        """
        self.register_trading_control(MaxOrderCount(max_count))

    @api_method
    def set_long_only(self):
        """
        Set a rule specifying that this algorithm cannot take short positions.
        """
        control = LongOnly()
        self.register_trading_control(control)

    @classmethod
    def all_api_methods(cls):
        """
        Return a list of all the TradingAlgorithm API methods.

        Only attributes defined directly on ``cls`` are inspected (inherited
        ones are not); an attribute counts as an API method when it carries
        a truthy ``is_api_method`` attribute.
        """
        # vars(cls).values() works on both Python 2 and 3, whereas the
        # dict.itervalues() method only exists on Python 2.
        return [fn for fn in vars(cls).values()
                if getattr(fn, 'is_api_method', False)]
    def run(self, source, sim_params=None, benchmark_return_source=None):
        """Run the algorithm.

        :Arguments:
            source : can be either:
                     - pandas.DataFrame
                     - pandas.Panel
                     - zipline source
                     - list (or tuple) of zipline sources

               If pandas.DataFrame is provided, it must have the
               following structure:
               * column names must consist of ints representing the
                 different sids
               * index must be DatetimeIndex
               * array contents should be price info.
            sim_params : optional
               Simulation parameters for this run. When omitted, the
               parameters already set on the algorithm are used; if there
               are none either, they are created from the source's start
               and end.
            benchmark_return_source : optional
               NOTE(review): accepted but never referenced in this method
               body - confirm whether it should be stored on the algorithm.

        :Returns:
            daily_stats : pandas.DataFrame
              Daily performance metrics such as returns, alpha etc.

        """
        if isinstance(source, (list, tuple)):
            # A list/tuple has no start/end of its own to derive simulation
            # parameters from, so they must already be available.
            assert self.sim_params is not None or sim_params is not None, \
                """When providing a list of sources, \
                sim_params have to be specified as a parameter
                or in the constructor."""
        elif isinstance(source, pd.DataFrame):
            # if DataFrame provided, wrap in DataFrameSource
            source = DataFrameSource(source)
        elif isinstance(source, pd.Panel):
            # if Panel provided, wrap in DataPanelSource
            source = DataPanelSource(source)

        # Normalize to a sequence of sources.
        if not isinstance(source, (list, tuple)):
            self.sources = [source]
        else:
            self.sources = source

        # Resolve the simulation parameters, in order of precedence:
        #   1. the sim_params argument passed to this function,
        #   2. the params already set on the algorithm,
        #   3. new params created from the start and end of the source
        #      provided (only reachable for a single source, see the
        #      assertion above).
        if sim_params is None:
            if self.sim_params is None:
                start = source.start
                end = source.end
                sim_params = create_simulation_parameters(
                    start=start,
                    end=end,
                    capital_base=self.capital_base,
                )
            else:
                sim_params = self.sim_params

        # update sim params to ensure it's set
        self.sim_params = sim_params
        if self.sim_params.sids is None:
            # No sids configured: use the union of the sids of all sources.
            all_sids = [sid for s in self.sources for sid in s.sids]
            self.sim_params.sids = set(all_sids)

        # Create history containers (only when history specs were registered)
        if len(self.history_specs) != 0:
            self.history_container = HistoryContainer(
                self.history_specs,
                self.sim_params.sids,
                self.sim_params.first_open)

        # Create transforms by wrapping them into StatefulTransforms
        self.transforms = []
        for namestring, trans_descr in iteritems(self.registered_transforms):
            sf = StatefulTransform(
                trans_descr['class'],
                *trans_descr['args'],
                **trans_descr['kwargs']
            )
            sf.namestring = namestring

            self.transforms.append(sf)

        # force a reset of the performance tracker, in case
        # this is a repeat run of the algorithm.
        self.perf_tracker = None

        # create transforms and zipline
        self.gen = self._create_generator(sim_params)

        with ZiplineAPI(self):
            # loop through simulated_trading, each iteration returns a
            # perf dictionary
            perfs = []
            for perf in self.gen:
                perfs.append(perf)

            # convert perf dict to pandas dataframe
            daily_stats = self._create_daily_stats(perfs)

        # Called after the ZiplineAPI context has been left.
        self.analyze(daily_stats)

        return daily_stats
Beispiel #14
0
    def run(self, source, overwrite_sim_params=True,
            benchmark_return_source=None):
        """Run the algorithm.

        :Arguments:
            source : can be either:
                     - pandas.DataFrame
                     - pandas.Panel
                     - zipline source
                     - list of sources

               If pandas.DataFrame is provided, it must have the
               following structure:
               * column names must consist of ints representing the
                 different sids
               * index must be DatetimeIndex
               * array contents should be price info.
            overwrite_sim_params : bool <default: True>
               Whether start, end and sids of ``self.sim_params`` are
               overwritten with values taken from the source. Forced to
               False (with a UserWarning) when a list of sources is passed.
            benchmark_return_source : optional
               NOTE(review): accepted but never referenced in this method
               body - confirm whether it should be stored on the algorithm.

        :Returns:
            daily_stats : pandas.DataFrame
              Daily performance metrics such as returns, alpha etc.

        """
        if isinstance(source, list):
            if overwrite_sim_params:
                warnings.warn("""List of sources passed, will not attempt to extract sids, and start and end
 dates. Make sure to set the correct fields in sim_params passed to
 __init__().""", UserWarning)
                overwrite_sim_params = False
        elif isinstance(source, pd.DataFrame):
            # if DataFrame provided, wrap in DataFrameSource
            source = DataFrameSource(source)
        elif isinstance(source, pd.Panel):
            # if Panel provided, wrap in DataPanelSource
            source = DataPanelSource(source)

        # NOTE(review): only ``list`` is recognised as a collection here; a
        # tuple of sources would be wrapped as a single source - confirm
        # this is intended.
        if isinstance(source, list):
            self.set_sources(source)
        else:
            self.set_sources([source])

        # Override sim_params if params are provided by the source.
        if overwrite_sim_params:
            if hasattr(source, 'start'):
                self.sim_params.period_start = source.start
            if hasattr(source, 'end'):
                self.sim_params.period_end = source.end
            all_sids = [sid for s in self.sources for sid in s.sids]
            self.sim_params.sids = set(all_sids)
            # Changing period_start and period_end might require updating
            # of first_open and last_close.
            self.sim_params._update_internal()

        # Create history containers (only when history specs were registered)
        if len(self.history_specs) != 0:
            self.history_container = HistoryContainer(
                self.history_specs,
                self.sim_params.sids,
                self.sim_params.first_open)

        # Create transforms by wrapping them into StatefulTransforms
        self.transforms = []
        for namestring, trans_descr in iteritems(self.registered_transforms):
            sf = StatefulTransform(
                trans_descr['class'],
                *trans_descr['args'],
                **trans_descr['kwargs']
            )
            sf.namestring = namestring

            self.transforms.append(sf)

        # force a reset of the performance tracker, in case
        # this is a repeat run of the algorithm.
        self.perf_tracker = None

        # create transforms and zipline
        self.gen = self._create_generator(self.sim_params)

        with ZiplineAPI(self):
            # loop through simulated_trading, each iteration returns a
            # perf dictionary
            perfs = []
            for perf in self.gen:
                perfs.append(perf)

            # convert perf dict to pandas dataframe
            daily_stats = self._create_daily_stats(perfs)

        # Called after the ZiplineAPI context has been left.
        self.analyze(daily_stats)

        return daily_stats
Beispiel #15
0
class TradingAlgorithm(object):
    """
    Base class for trading algorithms. Inherit and overload
    initialize() and handle_data(data).

    A new algorithm could look like this:
    ```
    from zipline.api import order

    def initialize(context):
        context.sid = 'AAPL'
        context.amount = 100

    def handle_data(context, data):
        sid = context.sid
        amount = context.amount
        order(sid, amount)
    ```
    To then run this algorithm, pass these functions to
    TradingAlgorithm:

    my_algo = TradingAlgorithm(initialize, handle_data)
    stats = my_algo.run(data)

    """

    # If this is set to false then it is the responsibility
    # of the overriding subclass to set initialized = true
    AUTO_INITIALIZE = True

    def __init__(self, *args, **kwargs):
        """Initialize sids and other state variables.

        :Arguments:
        :Optional:
            initialize : function
                Function that is called with a single
                argument at the beginning of the simulation.
            handle_data : function
                Function that is called with 2 arguments
                (context and data) on every bar.
            script : str
                Algoscript that contains initialize and
                handle_data function definition.
            data_frequency : str (daily, hourly or minutely)
               The duration of the bars.
            capital_base : float <default: 1.0e5>
               How much capital to start with.
            instant_fill : bool <default: False>
               Whether to fill orders immediately or on next bar.
            environment : str <default: 'zipline'>
               The environment that this algorithm is running in.

        The keywords namespace, sim_params, blotter and
        before_trading_start are also read below. Positional arguments and
        any keyword arguments not popped here are forwarded to
        ``self.initialize``.
        """
        self.datetime = None

        self.registered_transforms = {}
        self.transforms = []
        self.sources = []

        # List of trading controls to be used to validate orders.
        self.trading_controls = []

        self._recorded_vars = {}
        # Read with get(), not pop(): 'namespace' stays in kwargs and is
        # therefore forwarded to self.initialize() below.
        self.namespace = kwargs.get('namespace', {})

        self._environment = kwargs.pop('environment', 'zipline')

        self.logger = None

        self.benchmark_return_source = None

        # default components for transact
        self.slippage = VolumeShareSlippage()
        self.commission = PerShare()

        self.instant_fill = kwargs.pop('instant_fill', False)

        # set the capital base
        self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)

        # Fall back to default simulation parameters when none are passed.
        self.sim_params = kwargs.pop('sim_params', None)
        if self.sim_params is None:
            self.sim_params = create_simulation_parameters(
                capital_base=self.capital_base
            )
        self.perf_tracker = PerformanceTracker(self.sim_params)

        # A falsy blotter (including None) is replaced by a default Blotter.
        self.blotter = kwargs.pop('blotter', None)
        if not self.blotter:
            self.blotter = Blotter()

        self.portfolio_needs_update = True
        self.account_needs_update = True
        self.performance_needs_update = True
        self._portfolio = None
        self._account = None

        self.history_container = None
        self.history_specs = {}

        # If string is passed in, execute and get reference to
        # functions.
        self.algoscript = kwargs.pop('script', None)

        self._initialize = None
        self._before_trading_start = None
        self._analyze = None

        self.event_manager = EventManager()

        # NOTE(review): self._handle_data is only assigned inside the two
        # branches below; when neither a script nor both initialize and
        # handle_data are given it is never set here.
        if self.algoscript is not None:
            # Execute the script in self.namespace and pick up the
            # functions it defines.
            exec_(self.algoscript, self.namespace)
            self._initialize = self.namespace.get('initialize')
            if 'handle_data' not in self.namespace:
                raise ValueError('You must define a handle_data function.')
            else:
                self._handle_data = self.namespace['handle_data']

            self._before_trading_start = \
                self.namespace.get('before_trading_start')
            # Optional analyze function, gets called after run
            self._analyze = self.namespace.get('analyze')

        elif kwargs.get('initialize') and kwargs.get('handle_data'):
            # NOTE(review): this check cannot trigger - the branch is only
            # entered when self.algoscript is None (see the ``if`` above),
            # so passing both a script and initialize/handle_data is not
            # rejected.
            if self.algoscript is not None:
                raise ValueError('You can not set script and \
                initialize/handle_data.')
            self._initialize = kwargs.pop('initialize')
            self._handle_data = kwargs.pop('handle_data')
            self._before_trading_start = kwargs.pop('before_trading_start',
                                                    None)

        # Register handle_data as an event whose rule is Always(); it is
        # prepended to the event manager's events.
        self.event_manager.add_event(
            zipline.utils.events.Event(
                zipline.utils.events.Always(),
                # We pass handle_data.__func__ to get the unbound method.
                # We will explicitly pass the algorithm to bind it again.
                self.handle_data.__func__,
            ),
            prepend=True,
        )

        # If method not defined, NOOP
        if self._initialize is None:
            self._initialize = lambda x: None

        # Alternative way of setting data_frequency for backwards
        # compatibility.
        if 'data_frequency' in kwargs:
            self.data_frequency = kwargs.pop('data_frequency')

        # Subclasses that override initialize should only worry about
        # setting self.initialized = True if AUTO_INITIALIZE
        # is manually set to False.
        self.initialized = False
        self.initialize(*args, **kwargs)
        if self.AUTO_INITIALIZE:
            self.initialized = True

    def initialize(self, *args, **kwargs):
        """
        Run the stored ``_initialize`` callable with this algorithm as its
        only argument, inside a ZiplineAPI context for this algorithm.

        The positional and keyword arguments are accepted but not used.
        """
        api_context = ZiplineAPI(self)
        with api_context:
            self._initialize(self)

    def before_trading_start(self):
        """
        Invoke the optional ``_before_trading_start`` hook with this
        algorithm; do nothing when no hook is set.
        """
        if self._before_trading_start is not None:
            self._before_trading_start(self)

    def handle_data(self, data):
        """
        Feed ``data`` to the history container (when there is one) and then
        call the user's ``_handle_data`` with this algorithm and ``data``.
        """
        container = self.history_container
        if container:
            container.update(data, self.datetime)

        self._handle_data(self, data)

    def analyze(self, perf):
        """
        Call the optional ``_analyze`` function with this algorithm and
        ``perf`` inside a ZiplineAPI context; do nothing when it is unset.
        """
        if self._analyze is not None:
            with ZiplineAPI(self):
                self._analyze(self, perf)

    def __repr__(self):
        """
        N.B. this does not yet represent a string that can be used
        to instantiate an exact copy of an algorithm.

        However, it is getting close, and provides some value as something
        that can be inspected interactively.
        """
        return """
{class_name}(
    capital_base={capital_base}
    sim_params={sim_params},
    initialized={initialized},
    slippage={slippage},
    commission={commission},
    blotter={blotter},
    recorded_vars={recorded_vars})
""".strip().format(class_name=self.__class__.__name__,
                   capital_base=self.capital_base,
                   sim_params=repr(self.sim_params),
                   initialized=self.initialized,
                   slippage=repr(self.slippage),
                   commission=repr(self.commission),
                   blotter=repr(self.blotter),
                   recorded_vars=repr(self.recorded_vars))

    def _create_data_generator(self, source_filter, sim_params=None):
        """
        Create a merged data generator using the sources and
        transforms attached to this algorithm.

        ::source_filter:: is a method that receives events in date
        sorted order, and returns True for those events that should be
        processed by the zipline, and False for those that should be
        skipped.

        ::sim_params:: defaults to ``self.sim_params``. Unless the
        algorithm already has a ``benchmark_return_source``, benchmark
        events are built from the trading environment's benchmark returns
        that fall between the period start and end dates (inclusive).

        Returns the merged events grouped by their ``dt`` attribute.
        """
        params = self.sim_params if sim_params is None else sim_params

        benchmarks = self.benchmark_return_source
        if benchmarks is None:
            env = trading.environment
            minute_mode = (params.data_frequency == 'minute'
                           or params.emission_rate == 'minute')

            def stamp(date):
                # In minute mode the event is stamped with element 1 of
                # get_open_and_close(date); otherwise with the date itself.
                if minute_mode:
                    return env.get_open_and_close(date)[1]
                return date

            def in_period(dt):
                day = dt.date()
                return (params.period_start.date() <= day
                        and day <= params.period_end.date())

            benchmarks = []
            for dt, ret in env.benchmark_returns.iteritems():
                if not in_period(dt):
                    continue
                benchmarks.append(Event({
                    'dt': stamp(dt),
                    'returns': ret,
                    'type': zipline.protocol.DATASOURCE_TYPE.BENCHMARK,
                    'source_id': 'benchmarks',
                }))

        events = date_sorted_sources(*self.sources)
        if source_filter:
            events = filter(source_filter, events)

        transformed = sequential_transforms(events, *self.transforms)
        merged = date_sorted_sources(benchmarks, transformed)

        # Group together events with the same dt field. This depends on the
        # events already being sorted.
        return groupby(merged, attrgetter('dt'))

    def _create_generator(self, sim_params, source_filter=None):
        """
        Assemble the standard simulation generator for this algorithm.

        Creates a performance tracker when none is set, marks the cached
        portfolio/account/performance views as stale, builds the data
        generator and the AlgorithmSimulator, installs the transact method
        on the blotter, and returns the simulator's transform of the data
        generator.

        ::source_filter:: is a method that receives events in date
        sorted order, and returns True for those events that should be
        processed by the zipline, and False for those that should be
        skipped.
        """
        if self.perf_tracker is None:
            # HACK: `run` resets perf_tracker to None so that a fresh
            # tracker is created here.
            self.perf_tracker = PerformanceTracker(sim_params)

        # Force the cached views to be recomputed on next access.
        self.portfolio_needs_update = True
        self.account_needs_update = True
        self.performance_needs_update = True

        data_gen = self._create_data_generator(source_filter, sim_params)
        self.data_gen = data_gen

        simulator = AlgorithmSimulator(self, sim_params)
        self.trading_client = simulator

        self.set_transact(transact_partial(self.slippage, self.commission))

        return simulator.transform(data_gen)

    def get_generator(self):
        """
        Return the generator that drives the simulation.

        Subclasses may override this to customize how the generator is
        built; overrides can delegate to ``_create_generator`` to obtain
        the standard construction.
        """
        params = self.sim_params
        return self._create_generator(params)

    # TODO: make a new subclass, e.g. BatchAlgorithm, and move
    # the run method to the subclass, and refactor to put the
    # generator creation logic into get_generator.
    def run(self, source, overwrite_sim_params=True,
            benchmark_return_source=None):
        """Run the algorithm.

        :Arguments:
            source : can be either:
                     - pandas.DataFrame
                     - zipline source
                     - list of sources

               If pandas.DataFrame is provided, it must have the
               following structure:
               * column names must consist of ints representing the
                 different sids
               * index must be DatetimeIndex
               * array contents should be price info.

        :Returns:
            daily_stats : pandas.DataFrame
              Daily performance metrics such as returns, alpha etc.

        """
        if isinstance(source, list):
            if overwrite_sim_params:
                warnings.warn("""List of sources passed, will not attempt to extract sids, and start and end
 dates. Make sure to set the correct fields in sim_params passed to
 __init__().""", UserWarning)
                overwrite_sim_params = False
        elif isinstance(source, pd.DataFrame):
            # if DataFrame provided, wrap in DataFrameSource
            source = DataFrameSource(source)
        elif isinstance(source, pd.Panel):
            source = DataPanelSource(source)

        if isinstance(source, list):
            self.set_sources(source)
        else:
            self.set_sources([source])

        # Override sim_params if params are provided by the source.
        if overwrite_sim_params:
            if hasattr(source, 'start'):
                self.sim_params.period_start = source.start
            if hasattr(source, 'end'):
                self.sim_params.period_end = source.end
            all_sids = [sid for s in self.sources for sid in s.sids]
            self.sim_params.sids = set(all_sids)
            # Changing period_start and period_close might require updating
            # of first_open and last_close.
            self.sim_params._update_internal()

        # Create history containers
        if len(self.history_specs) != 0:
            self.history_container = HistoryContainer(
                self.history_specs,
                self.sim_params.sids,
                self.sim_params.first_open)

        # Create transforms by wrapping them into StatefulTransforms
        self.transforms = []
        for namestring, trans_descr in iteritems(self.registered_transforms):
            sf = StatefulTransform(
                trans_descr['class'],
                *trans_descr['args'],
                **trans_descr['kwargs']
            )
            sf.namestring = namestring

            self.transforms.append(sf)

        # force a reset of the performance tracker, in case
        # this is a repeat run of the algorithm.
        self.perf_tracker = None

        # create transforms and zipline
        self.gen = self._create_generator(self.sim_params)

        with ZiplineAPI(self):
            # loop through simulated_trading, each iteration returns a
            # perf dictionary
            perfs = []
            for perf in self.gen:
                perfs.append(perf)

            # convert perf dict to pandas dataframe
            daily_stats = self._create_daily_stats(perfs)

        self.analyze(daily_stats)

        return daily_stats

    def _create_daily_stats(self, perfs):
        # create daily and cumulative stats dataframe
        daily_perfs = []
        # TODO: the loop here could overwrite expected properties
        # of daily_perf. Could potentially raise or log a
        # warning.
        for perf in perfs:
            if 'daily_perf' in perf:

                perf['daily_perf'].update(
                    perf['daily_perf'].pop('recorded_vars')
                )
                daily_perfs.append(perf['daily_perf'])
            else:
                self.risk_report = perf

        daily_dts = [np.datetime64(perf['period_close'], utc=True)
                     for perf in daily_perfs]
        daily_stats = pd.DataFrame(daily_perfs, index=daily_dts)

        return daily_stats

    def add_transform(self, transform_class, tag, *args, **kwargs):
        """Register a single-sid, sequential transform under ``tag``.

        :Arguments:
            transform_class : class
                Which transform to use. E.g. mavg.
            tag : str
                How to name the transform. Can later be accessed via:
                data[sid].tag()

        Extra args and kwargs are stored alongside the class and forwarded
        to the transform when it is instantiated.

        """
        description = {
            'class': transform_class,
            'args': args,
            'kwargs': kwargs,
        }
        self.registered_transforms[tag] = description

    @api_method
    def get_environment(self):
        """
        Return the value stored in ``self._environment``.
        """
        environment = self._environment
        return environment

    def add_event(self, rule=None, callback=None):
        """
        Wrap ``rule`` and ``callback`` in a ``zipline.utils.events.Event``
        and add it to the algorithm's EventManager.
        """
        event = zipline.utils.events.Event(rule, callback)
        self.event_manager.add_event(event)

    @api_method
    def schedule_function(self,
                          func,
                          date_rule=None,
                          time_rule=None,
                          half_days=True):
        """
        Schedule ``func`` to be called according to a date and a time rule.

        A falsy ``date_rule`` defaults to ``DateRuleFactory.every_day()``
        and a falsy ``time_rule`` to ``TimeRuleFactory.market_open()``.

        Raises IncompatibleScheduleFunctionDataFrequency unless the
        simulation's data frequency is 'minute'.
        """
        if self.sim_params.data_frequency != 'minute':
            raise IncompatibleScheduleFunctionDataFrequency()

        if not date_rule:
            date_rule = DateRuleFactory.every_day()
        if not time_rule:
            time_rule = TimeRuleFactory.market_open()

        rule = make_eventrule(date_rule, time_rule, half_days)
        self.add_event(rule, func)

    @api_method
    def record(self, *args, **kwargs):
        """
        Track and record local variable (i.e. attributes) each day.

        Positional arguments are read as alternating name, value pairs
        (``record('a', 1, 'b', 2)``); keyword arguments are recorded under
        their keyword names. A trailing positional name without a value is
        ignored.
        """
        # Even positions hold the names, odd positions the values; zip stops
        # at the shorter slice, so an unpaired trailing name is dropped.
        names = args[0::2]
        values = args[1::2]

        for name, value in chain(zip(names, values), iteritems(kwargs)):
            self._recorded_vars[name] = value

    @api_method
    def symbol(self, symbol_str, as_of_date=None):
        """
        Default symbol lookup: the symbol string is returned unchanged.

        Suitable for any source that directly maps the symbol to the
        identifier (e.g. yahoo finance). ``as_of_date`` is accepted but
        ignored.
        """
        identifier = symbol_str
        return identifier

    @api_method
    def order(self, sid, amount,
              limit_price=None,
              stop_price=None,
              style=None):
        """
        Place an order for ``amount`` shares of ``sid``.

        ``amount`` is snapped to the nearest integer when it lies within
        1e-4 of it and is otherwise truncated toward zero,
        e.g. 3.9999 -> 4; 5.5 -> 5; -5.5 -> -5.

        The parameters are checked by ``validate_order_params`` and the
        deprecated ``limit_price``/``stop_price`` arguments are converted
        into an execution style before the order is passed to the blotter.

        Returns the result of ``self.blotter.order``.
        """
        epsilon = 1e-4
        nearest = round(amount)
        if abs(amount - nearest) <= epsilon:
            # Close enough to a whole share count: snap to it.
            amount = nearest
        # int() truncates whatever fractional part remains toward zero.
        amount = int(amount)

        # Raises a ZiplineError if invalid parameters are detected.
        self.validate_order_params(
            sid, amount, limit_price, stop_price, style,
        )

        # Convert deprecated limit_price and stop_price parameters to use
        # ExecutionStyle objects.
        execution_style = self.__convert_order_params_for_blotter(
            limit_price, stop_price, style,
        )
        return self.blotter.order(sid, amount, execution_style)

    def validate_order_params(self,
                              sid,
                              amount,
                              limit_price,
                              stop_price,
                              style):
        """
        Check the arguments of an order API call before it is placed.

        Raises OrderDuringInitialize when the algorithm is not initialized
        yet, and UnsupportedOrderParameters when ``style`` is combined with
        ``limit_price`` or ``stop_price``. Afterwards every registered
        trading control validates the order.
        """
        if not self.initialized:
            raise OrderDuringInitialize(
                msg="order() can only be called from within handle_data()"
            )

        if style and limit_price:
            raise UnsupportedOrderParameters(
                msg="Passing both limit_price and style is not supported."
            )

        if style and stop_price:
            raise UnsupportedOrderParameters(
                msg="Passing both stop_price and style is not supported."
            )

        for control in self.trading_controls:
            portfolio = self.updated_portfolio()
            now = self.get_datetime()
            current_data = self.trading_client.current_data
            control.validate(sid, amount, portfolio, now, current_data)

    @staticmethod
    def __convert_order_params_for_blotter(limit_price, stop_price, style):
        """
        Translate the deprecated limit_price and stop_price arguments into
        an ExecutionStyle instance.

        An explicit ``style`` is returned as is. Otherwise the result is a
        StopLimitOrder, LimitOrder, StopOrder or MarketOrder depending on
        which of the two prices are truthy.

        This function assumes that either style == None or (limit_price,
        stop_price) == (None, None).
        """
        # TODO_SS: DeprecationWarning for usage of limit_price and stop_price.
        if style:
            assert (limit_price, stop_price) == (None, None)
            return style

        if limit_price:
            if stop_price:
                return StopLimitOrder(limit_price, stop_price)
            return LimitOrder(limit_price)

        if stop_price:
            return StopOrder(stop_price)

        return MarketOrder()

    @api_method
    def order_value(self, sid, value,
                    limit_price=None, stop_price=None, style=None):
        """
        Place an order by desired value rather than desired number of shares.
        The requested value is divided by the sid's current price to imply
        the number of shares to transact. When that price is (close to) 0
        no order is placed and None is returned.

        value > 0 :: Buy/Cover
        value < 0 :: Sell/Short
        Market order:    order_value(sid, value)
        Limit order:     order_value(sid, value, limit_price)
        Stop order:      order_value(sid, value, None, stop_price)
        StopLimit order: order_value(sid, value, limit_price, stop_price)
        """
        last_price = self.trading_client.current_data[sid].price

        if not np.allclose(last_price, 0):
            shares = value / last_price
            return self.order(sid, shares,
                              limit_price=limit_price,
                              stop_price=stop_price,
                              style=style)

        # A zero price gives no way to infer a share count: report it when a
        # logger is set and place no order.
        zero_message = "Price of 0 for {psid}; can't infer value".format(
            psid=sid
        )
        if self.logger:
            self.logger.debug(zero_message)
        return

    @property
    def recorded_vars(self):
        # Hand out a shallow copy so callers cannot rebind entries of the
        # internal mapping.
        snapshot = copy(self._recorded_vars)
        return snapshot

    @property
    def portfolio(self):
        # Delegates to updated_portfolio(), which refreshes the cached value
        # only when it is flagged as stale.
        current = self.updated_portfolio()
        return current

    def updated_portfolio(self):
        """
        Return the portfolio, refreshing the cached copy from the
        performance tracker when ``portfolio_needs_update`` is set.

        A refresh clears both ``portfolio_needs_update`` and
        ``performance_needs_update``.
        """
        if not self.portfolio_needs_update:
            return self._portfolio

        self._portfolio = self.perf_tracker.get_portfolio(
            self.performance_needs_update
        )
        self.portfolio_needs_update = False
        self.performance_needs_update = False
        return self._portfolio

    @property
    def account(self):
        # Delegates to updated_account(), which refreshes the cached value
        # only when it is flagged as stale.
        current = self.updated_account()
        return current

    def updated_account(self):
        """
        Return the account, refreshing the cached copy from the performance
        tracker when ``account_needs_update`` is set.

        A refresh clears both ``account_needs_update`` and
        ``performance_needs_update``.
        """
        if not self.account_needs_update:
            return self._account

        self._account = self.perf_tracker.get_account(
            self.performance_needs_update
        )
        self.account_needs_update = False
        self.performance_needs_update = False
        return self._account

    def set_logger(self, logger):
        """
        Store ``logger`` on the algorithm as ``self.logger``.
        """
        self.logger = logger

    def on_dt_changed(self, dt):
        """
        Callback triggered by the simulation loop whenever the current dt
        changes.

        Any logic that should happen exactly once at the start of each
        datetime group belongs here. ``dt`` must be a datetime whose tzinfo
        is UTC; it is stored on the algorithm and pushed to the performance
        tracker and the blotter.
        """
        is_datetime = isinstance(dt, datetime)
        assert is_datetime, \
            "Attempt to set algorithm's current time with non-datetime"
        assert dt.tzinfo == pytz.utc, \
            "Algorithm expects a utc datetime"

        self.datetime = dt
        # Keep the collaborators on the same clock as the algorithm.
        self.perf_tracker.set_date(dt)
        self.blotter.set_date(dt)

    @api_method
    def get_datetime(self, tz=None):
        """
        Returns a copy of the algorithm's current datetime.

        The stored datetime must be in UTC. When ``tz`` is given, the copy
        is converted with its ``tz_convert`` method before being returned.
        """
        current = copy(self.datetime)
        assert current.tzinfo == pytz.utc, \
            "Algorithm should have a utc datetime"

        if tz is None:
            return current
        return current.tz_convert(tz)

    def set_transact(self, transact):
        """
        Install ``transact`` on the blotter as the callable used to create
        a transaction from open orders and trade events.
        """
        blotter = self.blotter
        blotter.transact = transact

    def update_dividends(self, dividend_frame):
        """
        Forward the DataFrame used to process dividends to the performance
        tracker. DataFrame columns should contain at least the entries in
        zp.DIVIDEND_FIELDS.
        """
        tracker = self.perf_tracker
        tracker.update_dividends(dividend_frame)

    @api_method
    def set_slippage(self, slippage):
        """
        Set the slippage model used by the algorithm.

        Raises UnsupportedSlippageModel when ``slippage`` is not a
        SlippageModel instance, and OverrideSlippagePostInit when the
        algorithm is already initialized.
        """
        is_supported = isinstance(slippage, SlippageModel)
        if not is_supported:
            raise UnsupportedSlippageModel()

        if self.initialized:
            raise OverrideSlippagePostInit()

        self.slippage = slippage

    @api_method
    def set_commission(self, commission):
        """
        Set the commission model used by the algorithm.

        Raises UnsupportedCommissionModel when ``commission`` is not a
        PerShare, PerTrade or PerDollar instance, and
        OverrideCommissionPostInit when the algorithm is already
        initialized.
        """
        supported = (PerShare, PerTrade, PerDollar)
        if not isinstance(commission, supported):
            raise UnsupportedCommissionModel()

        if self.initialized:
            raise OverrideCommissionPostInit()

        self.commission = commission

    def set_sources(self, sources):
        """
        Store the list of data sources on the algorithm.

        ``sources`` must be a list; this is enforced with an assertion.
        """
        is_list = isinstance(sources, list)
        assert is_list
        self.sources = sources

    def set_transforms(self, transforms):
        """
        Store the list of transforms on the algorithm.

        ``transforms`` must be a list; this is enforced with an assertion.
        """
        is_list = isinstance(transforms, list)
        assert is_list
        self.transforms = transforms

    # Retained for backwards compatibility.
    @property
    def data_frequency(self):
        # Read-through to the simulation parameters, which own the value.
        params = self.sim_params
        return params.data_frequency

    @data_frequency.setter
    def data_frequency(self, value):
        # Only these two frequencies may be assigned; the value is written
        # through to the simulation parameters.
        allowed = ('daily', 'minute')
        assert value in allowed
        self.sim_params.data_frequency = value

    @api_method
    def order_percent(self, sid, percent,
                      limit_price=None, stop_price=None, style=None):
        """
        Place an order in the specified security corresponding to the given
        percent of the current portfolio value.

        The order value is ``self.portfolio.portfolio_value * percent`` and
        is placed through ``order_value``, whose result is returned.

        Note that percent must be expressed as a decimal (0.50 means 50%).
        """
        value = self.portfolio.portfolio_value * percent
        return self.order_value(sid, value,
                                limit_price=limit_price,
                                stop_price=stop_price,
                                style=style)

    @api_method
    def order_target(self, sid, target,
                     limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target number of shares.

        Without an existing position in ``sid`` the full target is ordered;
        with one, only the difference between the target and the currently
        held amount is ordered. Returns the result of ``order``.
        """
        positions = self.portfolio.positions

        if sid in positions:
            # Only trade the gap between the holding and the target.
            shares = target - positions[sid].amount
        else:
            shares = target

        return self.order(sid, shares,
                          limit_price=limit_price,
                          stop_price=stop_price,
                          style=style)

    @api_method
    def order_target_value(self, sid, target,
                           limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target value.

        The target value is divided by the sid's current price to obtain a
        target share count, which is handed to ``order_target``. When the
        price is (close to) 0 no order is placed and None is returned.
        """
        last_price = self.trading_client.current_data[sid].price

        if not np.allclose(last_price, 0):
            target_amount = target / last_price
            return self.order_target(sid, target_amount,
                                     limit_price=limit_price,
                                     stop_price=stop_price,
                                     style=style)

        # Don't place an order: the share count cannot be inferred.
        if self.logger:
            zero_message = "Price of 0 for {psid}; can't infer value"
            self.logger.debug(zero_message.format(psid=sid))
        return

    @api_method
    def order_target_percent(self, sid, target,
                             limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target percent of the
        current portfolio value. If the position doesn't already exist, this is
        equivalent to placing a new order. If the position does exist, this is
        equivalent to placing an order for the difference between the target
        percent and the current percent.

        The target value is ``self.portfolio.portfolio_value * target`` and
        is placed through ``order_target_value``, whose result is returned.

        Note that target must be expressed as a decimal (0.50 means 50%).
        """
        target_value = self.portfolio.portfolio_value * target
        return self.order_target_value(sid, target_value,
                                       limit_price=limit_price,
                                       stop_price=stop_price,
                                       style=style)

    @api_method
    def get_open_orders(self, sid=None):
        """
        Return the blotter's open orders as API objects.

        With a ``sid``: the list of open orders for that sid, or an empty
        list when the blotter has no entry for it. Without one: a dict
        mapping each key that has at least one open order to its list of
        orders.
        """
        open_orders = self.blotter.open_orders

        if sid is not None:
            if sid not in open_orders:
                return []
            return [order.to_api_obj() for order in open_orders[sid]]

        by_key = {}
        for key, orders in iteritems(open_orders):
            if orders:
                by_key[key] = [order.to_api_obj() for order in orders]
        return by_key

    @api_method
    def get_order(self, order_id):
        """
        Return the API object of the order with ``order_id``, or None when
        the blotter does not know that id.
        """
        orders = self.blotter.orders
        if order_id not in orders:
            return None
        return orders[order_id].to_api_obj()

    @api_method
    def cancel_order(self, order_param):
        """
        Cancel an order through the blotter.

        ``order_param`` may be a ``zipline.protocol.Order``, in which case
        its ``id`` is used, or any other value, which is passed to the
        blotter as the order id itself.
        """
        if isinstance(order_param, zipline.protocol.Order):
            order_id = order_param.id
        else:
            order_id = order_param

        self.blotter.cancel(order_id)

    @api_method
    def add_history(self, bar_count, frequency, field,
                    ffill=True):
        """
        Register a HistorySpec built from the given parameters.

        The spec uses the simulation's data frequency, with
        ``daily_at_midnight`` enabled exactly when that frequency is
        'daily', and is stored in ``self.history_specs`` under its
        ``key_str``.
        """
        data_frequency = self.sim_params.data_frequency

        spec = HistorySpec(bar_count, frequency, field, ffill,
                           daily_at_midnight=(data_frequency == 'daily'),
                           data_frequency=data_frequency)
        self.history_specs[spec.key_str] = spec

    @api_method
    def history(self, bar_count, frequency, field, ffill=True):
        """
        Return the history for the spec identified by the given parameters,
        as of the algorithm's current datetime.

        The spec is looked up in ``self.history_specs`` by its key, so it
        has to be present there already (see ``add_history``).
        """
        key = HistorySpec.spec_key(bar_count, frequency, field, ffill)
        spec = self.history_specs[key]
        return self.history_container.get_history(spec, self.datetime)

    ####################
    # Trading Controls #
    ####################

    def register_trading_control(self, control):
        """
        Add ``control`` to the TradingControls checked prior to order calls.

        Raises RegisterTradingControlPostInit when the algorithm is already
        initialized.
        """
        already_initialized = self.initialized
        if already_initialized:
            raise RegisterTradingControlPostInit()

        self.trading_controls.append(control)

    @api_method
    def set_max_position_size(self,
                              sid=None,
                              max_shares=None,
                              max_notional=None):
        """
        Register a MaxPositionSize trading control.

        Sets a limit on the number of shares and/or dollar value held for
        the given sid. Limits are treated as absolute values and are
        enforced at the time that the algo attempts to place an order for
        sid. This means that it's possible to end up with more than the max
        number of shares due to splits/dividends, and more than the max
        notional due to price improvement.

        If an algorithm attempts to place an order that would result in
        increasing the absolute value of shares/dollar value exceeding one of
        these limits, raise a TradingControlException.
        """
        self.register_trading_control(
            MaxPositionSize(sid=sid,
                            max_shares=max_shares,
                            max_notional=max_notional)
        )

    @api_method
    def set_max_order_size(self, sid=None, max_shares=None, max_notional=None):
        """
        Register a MaxOrderSize trading control.

        Sets a limit on the number of shares and/or dollar value of any
        single order placed for sid. Limits are treated as absolute values
        and are enforced at the time that the algo attempts to place an
        order for sid.

        If an algorithm attempts to place an order that would result in
        exceeding one of these limits, raise a TradingControlException.
        """
        self.register_trading_control(
            MaxOrderSize(sid=sid,
                         max_shares=max_shares,
                         max_notional=max_notional)
        )

    @api_method
    def set_max_order_count(self, max_count):
        """
        Register a ``MaxOrderCount(max_count)`` trading control, limiting
        the number of orders that can be placed within a time interval.
        (The interval itself is not a parameter here; presumably it is
        defined by MaxOrderCount.)
        """
        self.register_trading_control(MaxOrderCount(max_count))

    @api_method
    def set_long_only(self):
        """
        Register a LongOnly trading control, a rule specifying that this
        algorithm cannot take short positions.
        """
        control = LongOnly()
        self.register_trading_control(control)

    @classmethod
    def all_api_methods(cls):
        """
        Return a list of all the TradingAlgorithm API methods.

        Only attributes defined directly in ``cls.__dict__`` are examined
        (attributes inherited from base classes are not); those carrying a
        truthy ``is_api_method`` attribute are returned.
        """
        # ``values()`` is available on the class ``__dict__`` under both
        # Python 2 and Python 3, whereas ``itervalues()`` exists only on
        # Python 2.
        return [fn for fn in cls.__dict__.values()
                if getattr(fn, 'is_api_method', False)]
Beispiel #16
0
    def test_history_grow_length(self,
                                 freq,
                                 field,
                                 data_frequency,
                                 construct_digest):
        # Verifies that ensuring specs with larger bar counts grows the
        # digest panel's window length to bar_count - 1, and that the
        # container's history still checks out after each growth step.

        def make_spec(count):
            # All specs in this test differ only in their bar count.
            return history.HistorySpec(
                bar_count=count,
                frequency=freq,
                field=field,
                ffill=True,
                data_frequency=data_frequency,
                env=self.env,
            )

        bar_count = 2 if construct_digest else 1
        spec = make_spec(bar_count)

        if data_frequency == 'minute':
            start = '2013-06-28 13:31'
        else:
            start = '2013-06-28 12:00AM'
        initial_dt = pd.Timestamp(start, tz='UTC')

        container = HistoryContainer(
            {spec.key_str: spec},
            [1],
            initial_dt,
            data_frequency,
            env=self.env,
        )

        if construct_digest:
            initial_panel = container.digest_panels[spec.frequency]
            self.assertEqual(initial_panel.window_length, 1)

        bar_data = BarData()
        container.update(bar_data, initial_dt)

        # Two progressively longer specs, both built before any is added.
        to_add = tuple(make_spec(bar_count + extra) for extra in (1, 2))

        for spec in to_add:
            container.ensure_spec(spec, initial_dt, bar_data)

            grown_panel = container.digest_panels[spec.frequency]
            self.assertEqual(grown_panel.window_length, spec.bar_count - 1)

            self.assert_history(container, spec, initial_dt)