Exemplo n.º 1
0
def get_fixture_data_portal(**kwargs):
    """Build a ``DataPortal`` backed by a freshly constructed ``Backend``.

    Every keyword argument is forwarded unchanged to ``Backend``. The same
    backend instance feeds the ``AssetFinder`` and supplies the calendar
    through its ``_calendar`` attribute. ``False`` is passed as the portal's
    fourth positional argument (NOTE(review): its meaning is defined by
    ``DataPortal`` -- confirm against that class).
    """
    backend = Backend(**kwargs)
    return DataPortal(backend, AssetFinder(backend), backend._calendar, False)
Exemplo n.º 2
0
    def __init__(self, *args, **kwargs):
        '''
        Configure the algorithm from keyword options.

        Positional arguments are accepted but never read. Recognised keyword
        arguments are popped from ``kwargs``; anything left over is ignored.

        data_frequency: 'minute' or 'daily', defaults to 'minute'
        algoname: str, defaults to 'algo'
        statefile: value passed to StateStore, defaults to
                   '<algoname>-state.pkl'
        backend: str or Backend instance, defaults to 'alpaca'
                 (str is either backend module name under
                  'pylivetrader.backend', or global import path)
        backend_options: keyword arguments for the backend module's
                         ``Backend`` class; only read when ``backend`` is a
                         str
        trading_calendar: trading calendar, defaults to get_calendar('NYSE')
        initialize: initialize function
        handle_data: handle_data function
        before_trading_start: before_trading_start function

        Raises RuntimeError when ``backend`` is a str that can be imported
        neither as 'pylivetrader.backend.<name>' nor as '<name>', and
        AssertionError when ``data_frequency`` is not one of the two
        supported values.
        '''
        self._recorded_vars = {}

        self.data_frequency = kwargs.pop('data_frequency', 'minute')
        assert self.data_frequency in ('minute', 'daily')

        self._algoname = kwargs.pop('algoname', 'algo')

        # An explicit (truthy) ``statefile`` wins; otherwise the name is
        # derived from the algorithm name.
        self._state_store = StateStore(
            kwargs.pop('statefile', None)
            or '{}-state.pkl'.format(self._algoname))

        self._pipelines = {}

        backend_param = kwargs.pop('backend', 'alpaca')
        if not isinstance(backend_param, str):
            # Anything that is not a str is used as the backend as-is and
            # named after its class. ``backend_options`` is not consumed on
            # this path.
            self._backend = backend_param
            self._backend_name = backend_param.__class__.__name__
        else:
            self._backend_name = backend_param
            try:
                # First, try to import an official backend package.
                backendmod = importlib.import_module(
                    'pylivetrader.backend.{}'.format(self._backend_name))
            except ImportError:
                # If that fails, try to find the package in the global
                # package namespace.
                try:
                    backendmod = importlib.import_module(self._backend_name)
                except ImportError:
                    raise RuntimeError(
                        "Could not find backend package `{}`.".format(
                            self._backend_name))

            # A missing or falsy ``backend_options`` means "no options".
            backend_options = kwargs.pop('backend_options', None) or {}
            self._backend = backendmod.Backend(**backend_options)

        self.asset_finder = AssetFinder(self._backend)

        # NOTE: get_calendar('NYSE') is evaluated even when the caller
        # supplies a calendar, because it is the default-argument expression.
        self.trading_calendar = kwargs.pop('trading_calendar',
                                           get_calendar('NYSE'))

        self.data_portal = DataPortal(self._backend, self.asset_finder,
                                      self.trading_calendar)

        self.event_manager = EventManager()

        self.trading_controls = []

        self.account_controls = []

        self.restrictions = NoRestrictions()

        # User callbacks; each falls back to ``noop`` when not supplied.
        self._initialize = kwargs.pop('initialize', noop)
        self._handle_data = kwargs.pop('handle_data', noop)
        self._before_trading_start = kwargs.pop('before_trading_start', noop)

        # Register handle_data as an always-firing event, placed ahead of
        # any other event (prepend=True).
        self.event_manager.add_event(
            events.Event(
                events.Always(),
                # We pass handle_data.__func__ to get the unbound method.
                self.handle_data.__func__,
            ),
            prepend=True,
        )

        # Both cached views start out stale.
        self._account_needs_update = True
        self._portfolio_needs_update = True

        self._in_before_trading_start = False

        self._assets_from_source = []

        self._context_persistence_excludes = []

        # Upper bound on a single order's share count (1e11).
        self._max_shares = int(1e+11)

        self.initialized = False
Exemplo n.º 3
0
class Algorithm:
    """Provides algorithm compatible with zipline.
    """
    def __init__(self, *args, **kwargs):
        '''
        Configure the algorithm from keyword options.

        Positional arguments are accepted but never read. Recognised keyword
        arguments are popped from ``kwargs``; anything left over is ignored.

        data_frequency: 'minute' or 'daily', defaults to 'minute'
        algoname: str, defaults to 'algo'
        statefile: value passed to StateStore, defaults to
                   '<algoname>-state.pkl'
        backend: str or Backend instance, defaults to 'alpaca'
                 (str is either backend module name under
                  'pylivetrader.backend', or global import path)
        backend_options: keyword arguments for the backend module's
                         ``Backend`` class; only read when ``backend`` is a
                         str
        trading_calendar: trading calendar, defaults to get_calendar('NYSE')
        initialize: initialize function
        handle_data: handle_data function
        before_trading_start: before_trading_start function

        Raises RuntimeError when ``backend`` is a str that can be imported
        neither as 'pylivetrader.backend.<name>' nor as '<name>', and
        AssertionError when ``data_frequency`` is not one of the two
        supported values.
        '''
        self._recorded_vars = {}

        self.data_frequency = kwargs.pop('data_frequency', 'minute')
        assert self.data_frequency in ('minute', 'daily')

        self._algoname = kwargs.pop('algoname', 'algo')

        # An explicit (truthy) ``statefile`` wins; otherwise the name is
        # derived from the algorithm name.
        self._state_store = StateStore(
            kwargs.pop('statefile', None)
            or '{}-state.pkl'.format(self._algoname))

        self._pipelines = {}

        backend_param = kwargs.pop('backend', 'alpaca')
        if not isinstance(backend_param, str):
            # Anything that is not a str is used as the backend as-is and
            # named after its class. ``backend_options`` is not consumed on
            # this path.
            self._backend = backend_param
            self._backend_name = backend_param.__class__.__name__
        else:
            self._backend_name = backend_param
            try:
                # First, try to import an official backend package.
                backendmod = importlib.import_module(
                    'pylivetrader.backend.{}'.format(self._backend_name))
            except ImportError:
                # If that fails, try to find the package in the global
                # package namespace.
                try:
                    backendmod = importlib.import_module(self._backend_name)
                except ImportError:
                    raise RuntimeError(
                        "Could not find backend package `{}`.".format(
                            self._backend_name))

            # A missing or falsy ``backend_options`` means "no options".
            backend_options = kwargs.pop('backend_options', None) or {}
            self._backend = backendmod.Backend(**backend_options)

        self.asset_finder = AssetFinder(self._backend)

        # NOTE: get_calendar('NYSE') is evaluated even when the caller
        # supplies a calendar, because it is the default-argument expression.
        self.trading_calendar = kwargs.pop('trading_calendar',
                                           get_calendar('NYSE'))

        self.data_portal = DataPortal(self._backend, self.asset_finder,
                                      self.trading_calendar)

        self.event_manager = EventManager()

        self.trading_controls = []

        self.account_controls = []

        self.restrictions = NoRestrictions()

        # User callbacks; each falls back to ``noop`` when not supplied.
        self._initialize = kwargs.pop('initialize', noop)
        self._handle_data = kwargs.pop('handle_data', noop)
        self._before_trading_start = kwargs.pop('before_trading_start', noop)

        # Register handle_data as an always-firing event, placed ahead of
        # any other event (prepend=True).
        self.event_manager.add_event(
            events.Event(
                events.Always(),
                # We pass handle_data.__func__ to get the unbound method.
                self.handle_data.__func__,
            ),
            prepend=True,
        )

        # Both cached views start out stale.
        self._account_needs_update = True
        self._portfolio_needs_update = True

        self._in_before_trading_start = False

        self._assets_from_source = []

        self._context_persistence_excludes = []

        # Upper bound on a single order's share count (1e11).
        self._max_shares = int(1e+11)

        self.initialized = False

    def initialize(self, *args, **kwargs):
        """Restore persisted state and run the user's initialize callback.

        The attribute names present on the instance right now, plus
        ``'executor'``, are recorded as the names to leave out when state is
        saved. Stored state is then loaded, the user callback is invoked with
        this instance followed by ``*args``/``**kwargs`` inside a
        ``LiveTraderAPI`` context, state is saved, and ``initialized`` is set.
        """
        framework_attrs = list(self.__dict__)
        self._context_persistence_excludes = framework_attrs + ['executor']
        self._state_store.load(self, self._algoname)

        with LiveTraderAPI(self):
            self._initialize(self, *args, **kwargs)
            self._state_store.save(
                self,
                self._algoname,
                self._context_persistence_excludes,
            )

        self.initialized = True

    def handle_data(self, data):
        if self._handle_data:
            self._handle_data(self, data)
            self._state_store.save(self, self._algoname,
                                   self._context_persistence_excludes)

    def before_trading_start(self, data):
        if self._before_trading_start is None:
            return

        self._in_before_trading_start = True

        with handle_non_market_minutes(data) if \
                self.data_frequency == "minute" else ExitStack():
            self._before_trading_start(self, data)
            self._state_store.save(self, self._algoname,
                                   self._context_persistence_excludes)

        self._in_before_trading_start = False

    def run(self, retry=True):
        """Start the live-trading loop and return the executor's result.

        Logs the backend name and data frequency, loads every asset the
        finder knows about, runs ``initialize()`` when that has not happened
        yet, builds an ``AlgorithmExecutor`` and hands control to its ``run``
        method, forwarding ``retry``.
        """
        banner = ("livetrader start running with "
                  "backend = {} "
                  "data-frequency = {}")
        log.info(banner.format(self._backend_name, self.data_frequency))

        # for compatibility with zipline to provide history api
        all_sids = self.asset_finder.sids
        self._assets_from_source = self.asset_finder.retrieve_all(all_sids)

        if not self.initialized:
            self.initialize()

        self.executor = AlgorithmExecutor(self, self.data_portal)
        return self.executor.run(retry=retry)

    @api_method
    def get_environment(self, field='platform'):
        """Not supported: always raises ``APINotSupported``.

        ``field`` is accepted so the signature can be called, but is ignored.
        """
        raise APINotSupported

    @api_method
    def fetch_csv(self,
                  url,
                  pre_func=None,
                  post_func=None,
                  date_column='date',
                  date_format=None,
                  timezone=pytz.utc.zone,
                  symbol=None,
                  mask=True,
                  symbol_column=None,
                  special_params_checker=None,
                  **kwargs):
        """Not supported: always raises ``APINotSupported``.

        Every parameter is accepted so the signature can be called, but none
        of them is read.
        """
        raise APINotSupported

    @api_method
    @disallowed_in_before_trading_start(OrderInBeforeTradingStart())
    def order(self,
              asset,
              amount,
              limit_price=None,
              stop_price=None,
              style=None):
        """Submit an order for ``amount`` shares of ``asset`` to the backend.

        Returns the id of the order object the backend hands back, or
        ``None`` when the asset cannot be ordered, the rounded amount is
        zero, or the backend returns a falsy result.

        Raises ``OverflowError`` when the rounded amount exceeds
        ``self._max_shares``.
        """
        if not self._can_order_asset(asset):
            return None

        shares, execution_style = self._calculate_order(
            asset, amount, limit_price, stop_price, style)

        if shares == 0:
            return None

        # Arbitrary limit of 100 billion (US) shares will never be
        # exceeded except by a buggy algorithm.
        if shares > self._max_shares:
            raise OverflowError(
                "Can't order more than %d shares" % self._max_shares)

        placed = self._backend.order(asset, shares, execution_style)
        return placed.id if placed else None

    @api_method
    def add_event(self, rule=None, callback=None):
        """Wrap ``rule`` and ``callback`` in an ``events.Event`` and register
        it with the event manager.
        """
        register = self.event_manager.add_event
        register(events.Event(rule, callback))

    @api_method
    def schedule_function(self,
                          func,
                          date_rule=None,
                          time_rule=None,
                          half_days=True,
                          calendar=None):
        """Schedule a function to be called according to some timed rules.

        Parameters
        ----------
        func : callable[(context, data) -> None]
            The function to execute when the rule is triggered.
        date_rule : EventRule, optional
            The rule for the dates to execute this function. Defaults to
            ``date_rules.every_day()``.
        time_rule : EventRule, optional
            The rule for the times to execute this function. Defaults to
            ``time_rules.every_minute()``; outside 'minute' mode that
            default is always used.
        half_days : bool, optional
            Should this rule fire on half days?
        calendar : Sentinel, optional
            Calendar used to reconcile date and time rules. ``None`` selects
            the algorithm's own trading calendar.

        Raises
        ------
        ScheduleFunctionInvalidCalendar
            When ``calendar`` is neither ``None`` nor one of
            ``calendars.US_EQUITIES`` / ``calendars.US_FUTURES``.

        See Also
        --------
        :class:`zipline.api.date_rules`
        :class:`zipline.api.time_rules`
        """
        # A time rule in the date_rule slot most likely means the caller
        # wanted a time rule and no date rule. Only a warning is issued;
        # the arguments are not swapped.
        if isinstance(date_rule, (AfterOpen, BeforeClose)) and not time_rule:
            warnings.warn(
                'Got a time rule for the second positional argument '
                'date_rule. You should use keyword argument '
                'time_rule= when calling schedule_function without '
                'specifying a date_rule',
                stacklevel=3)

        if not date_rule:
            date_rule = date_rules.every_day()

        if self.data_frequency == 'minute':
            time_rule = time_rule or time_rules.every_minute()
        else:
            # If we are in daily mode the time_rule is ignored.
            time_rule = time_rules.every_minute()

        # Resolve the calendar sentinel to a concrete calendar.
        if calendar is None:
            resolved_calendar = self.trading_calendar
        elif calendar is calendars.US_EQUITIES:
            resolved_calendar = get_calendar('NYSE')
        elif calendar is calendars.US_FUTURES:
            resolved_calendar = get_calendar('us_futures')
        else:
            raise ScheduleFunctionInvalidCalendar(
                given_calendar=calendar,
                allowed_calendars=(
                    '[calendars.US_EQUITIES, calendars.US_FUTURES]'),
            )

        self.add_event(
            make_eventrule(date_rule, time_rule, resolved_calendar,
                           half_days),
            func,
        )

    @api_method
    def record(self, *args, **kwargs):
        """Store named values in the algorithm's recorded variables.

        Parameters
        ----------
        *args
            Alternating ``name, value, name, value, ...`` positionals. A
            trailing name without a value is silently dropped.
        **kwargs
            The names and values to record.

        Notes
        -----
        A later value for the same name overwrites the earlier one; keyword
        arguments are applied after the positional pairs.
        """
        # Feeding one iterator to zip twice makes it consume the positionals
        # two at a time: (a, b), (c, d), ... instead of (a, a), (b, b), ...
        stream = iter(args)
        paired = zip(stream, stream)

        for name, value in chain(paired, kwargs.items()):
            self._recorded_vars[name] = value

    @api_method
    def set_benchmark(self, benchmark):
        '''No-op kept for compatibility; ``benchmark`` is ignored.'''
        pass

    @api_method
    @preprocess(symbol=ensure_upper_case)
    def symbol(self, symbol):
        """Look up an equity by its ticker symbol.

        Parameters:
            symbol (string): The ticker symbol for the asset.

        Returns:
            equity (Equity): Whatever the asset finder's ``lookup_symbol``
            returns for ``symbol`` (called with ``as_of_date=None``).

        Raises:
            AssetNotFound: When the ``Asset`` cannot be resolved by
            ``symbol``.
        """
        finder = self.asset_finder
        return finder.lookup_symbol(symbol, as_of_date=None)

    @api_method
    def continuous_future(self, *args, **kwargs):
        """Not supported: always raises ``APINotSupported``."""
        raise APINotSupported

    @api_method
    def symbols(self, *args, **kwargs):
        """Look up several equities by ticker symbol.

        Parameters:
            args (iterable[str]): Ticker symbols, one per positional
            argument.
            kwargs: Forwarded unchanged to every ``self.symbol`` call.

        Returns:
            equities (List[Equity]): One ``self.symbol`` result per ticker,
            in argument order.

        Raises:
            AssetNotFound: When an ``Asset`` cannot be resolved by its
            symbol.
        """
        resolved = []
        for ticker in args:
            resolved.append(self.symbol(ticker, **kwargs))
        return resolved

    @api_method
    def sid(self, sid):
        """Look up an equity by its unique asset identifier.

        Parameters:
            sid: asset unique identifier.

        Returns:
            equity (Equity): Whatever the asset finder's ``retrieve_asset``
            returns for ``sid``.

        Raises:
            AssetNotFound: When the ``Asset`` cannot be resolved by ``sid``.
        """
        finder = self.asset_finder
        return finder.retrieve_asset(sid)

    @api_method
    def future_symbol(self, symbol):
        """Not supported: always raises ``APINotSupported``."""
        raise APINotSupported

    @api_method
    def batch_order(self, order_arg_list):
        """Place one order per entry and return the list of results.

        Each entry of ``order_arg_list`` is unpacked as the positional
        arguments of ``self.order``; results keep the input order.
        """
        results = []
        for positional in order_arg_list:
            results.append(self.order(*positional))
        return results

    @api_method
    @disallowed_in_before_trading_start(OrderInBeforeTradingStart())
    def order_value(self,
                    asset,
                    value,
                    limit_price=None,
                    stop_price=None,
                    style=None):
        """Order as many shares of ``asset`` as ``value`` buys at the last
        price.

        Returns ``None`` when the asset cannot be ordered; otherwise the
        result of ``self.order`` for the computed share amount.
        """
        if self._can_order_asset(asset):
            shares = self._calculate_order_value_amount(asset, value)
            return self.order(
                asset,
                shares,
                limit_price=limit_price,
                stop_price=stop_price,
                style=style,
            )
        return None

    @property
    def recorded_vars(self):
        return copy(self._recorded_vars)

    @property
    def portfolio(self):
        if self._portfolio_needs_update:
            self._portfolio = self._backend.portfolio
            self._portfolio_needs_update = False
        return self._portfolio

    @property
    def account(self):
        if self._account_needs_update:
            self._account = self._backend.account
            self._account_needs_update = False
        return self._account

    def set_logger(self, logger):
        """Store ``logger`` on the instance as ``self.logger``."""
        self.logger = logger

    def on_dt_changed(self, dt):
        self._portfolio_needs_update = True
        self._account_needs_update = True
        self.datetime = dt

    @api_method
    @preprocess(tz=coerce_string(pytz.timezone))
    @expect_types(tz=optional(tzinfo))
    def get_datetime(self, tz=None):
        """Return the algorithm's current datetime.

        The stored datetime must be in UTC (checked with an assertion). When
        ``tz`` is given, the value is converted to that timezone first.
        """
        now = self.datetime
        assert now.tzinfo == pytz.utc, "Algorithm should have a utc datetime"
        return now if tz is None else now.astimezone(tz)

    @api_method
    def set_slippage(self, **kwargs):
        '''No-op kept for compatibility; all keyword arguments are ignored.'''
        pass

    @api_method
    def set_commission(self, **kwargs):
        '''No-op kept for compatibility; all keyword arguments are ignored.'''
        pass

    @api_method
    def set_cancel_policy(self, *args):
        '''No-op kept for compatibility; all positional arguments are
        ignored.'''
        pass

    @api_method
    def set_symbol_lookup_date(self, dt):
        '''No-op kept for compatibility; ``dt`` is ignored.'''
        pass

    @api_method
    def order_percent(self,
                      asset,
                      percent,
                      limit_price=None,
                      stop_price=None,
                      style=None):
        """Order shares of ``asset`` worth ``percent`` of the portfolio value.

        Returns ``None`` when the asset cannot be ordered; otherwise the
        result of ``self.order`` for the computed share amount.
        """
        if self._can_order_asset(asset):
            shares = self._calculate_order_percent_amount(asset, percent)
            return self.order(
                asset,
                shares,
                limit_price=limit_price,
                stop_price=stop_price,
                style=style,
            )
        return None

    @api_method
    def order_target(self,
                     asset,
                     target,
                     limit_price=None,
                     stop_price=None,
                     style=None):
        """Order the difference between ``target`` shares and the current
        position in ``asset``.

        Returns ``None`` when the asset cannot be ordered; otherwise the
        result of ``self.order`` for the computed difference.
        """
        if self._can_order_asset(asset):
            delta = self._calculate_order_target_amount(asset, target)
            return self.order(
                asset,
                delta,
                limit_price=limit_price,
                stop_price=stop_price,
                style=style,
            )
        return None

    @api_method
    def order_target_value(self,
                           asset,
                           target,
                           limit_price=None,
                           stop_price=None,
                           style=None):
        """Order towards a position in ``asset`` worth ``target``.

        The target value is first converted to a share count at the last
        price; the current position is then subtracted from it. Returns
        ``None`` when the asset cannot be ordered; otherwise the result of
        ``self.order``.
        """
        if self._can_order_asset(asset):
            wanted_shares = self._calculate_order_value_amount(asset, target)
            delta = self._calculate_order_target_amount(asset, wanted_shares)
            return self.order(
                asset,
                delta,
                limit_price=limit_price,
                stop_price=stop_price,
                style=style,
            )
        return None

    @api_method
    def order_target_percent(self,
                             asset,
                             target,
                             limit_price=None,
                             stop_price=None,
                             style=None):
        """Order towards a position in ``asset`` worth ``target`` (a fraction
        of the portfolio value).

        Returns ``None`` when the asset cannot be ordered; otherwise the
        result of ``self.order`` for the computed difference.
        """
        if self._can_order_asset(asset):
            delta = self._calculate_order_target_percent_amount(asset, target)
            return self.order(
                asset,
                delta,
                limit_price=limit_price,
                stop_price=stop_price,
                style=style,
            )
        return None

    @api_method
    @expect_types(share_counts=pd.Series)
    @expect_dtypes(share_counts=np.dtype('float64'))
    def batch_market_order(self, share_counts):
        """Send one market order per non-zero entry of ``share_counts``.

        ``share_counts`` maps assets (index) to share amounts (values).
        Entries with a falsy amount are skipped. All orders share a single
        ``MarketOrder`` style instance and are passed to the backend's
        ``batch_order`` in one call, whose result is returned.
        """
        market = MarketOrder()
        requests = []
        for asset, amount in share_counts.items():
            if amount:
                requests.append((asset, amount, market))
        return self._backend.batch_order(requests)

    @api_method
    def get_open_orders(self, asset=None):
        '''
        Return open orders by delegating to ``get_all_orders`` with
        ``status='open'``.

        If asset is unspecified or None, returns a dictionary keyed by
        asset. The dictionary contains a list of orders for each asset,
        oldest first. If an asset is specified, returns a list of open
        orders for that asset, oldest first.
        '''
        return self.get_all_orders(asset=asset, status='open')

    @api_method
    def get_recent_orders(self, days_back=2):
        '''
        Returns all orders from the past ``days_back`` days (default 2), in
        the dictionary form produced by ``get_all_orders``.
        '''
        return self.get_all_orders(days_back=days_back)

    @api_method
    def get_all_orders(self,
                       asset=None,
                       before=None,
                       status='all',
                       days_back=None):
        """Fetch orders from the backend, grouped by asset, oldest first.

        ``before``, ``status`` and ``days_back`` are passed straight to the
        backend's ``all_orders``. The orders it returns are sorted by their
        ``dt`` attribute and converted with ``to_api_obj()``.

        If ``asset`` is unspecified or None, returns a dictionary mapping
        each asset to its list of orders. If an asset is specified, returns
        only that asset's list (empty when it has no orders).
        """
        fetched = self._backend.all_orders(before, status, days_back)

        by_asset = {}
        for entry in sorted(fetched.values(), key=lambda o: o.dt):
            by_asset.setdefault(entry.asset, []).append(entry.to_api_obj())

        return by_asset if asset is None else by_asset.get(asset, [])

    @api_method
    def get_order(self, order_id):
        """Return the API form of the order the backend has for ``order_id``.

        NOTE(review): ``to_api_obj()`` is called unconditionally, so this
        assumes the backend's ``get_order`` never returns None -- confirm.
        """
        return self._backend.get_order(order_id).to_api_obj()

    @api_method
    def cancel_order(self, order_param):
        """Ask the backend to cancel an order.

        ``order_param`` may be a ``proto.Order`` (its ``id`` is used) or
        anything else, which is passed to the backend unchanged as the
        order id.
        """
        if isinstance(order_param, proto.Order):
            target_id = order_param.id
        else:
            target_id = order_param
        self._backend.cancel_order(target_id)

    @api_method
    @require_initialized(HistoryInInitialize())
    def history(self, bar_count, frequency, field, ffill=True):
        """DEPRECATED: use ``data.history`` instead.

        Returns the history window for every asset in the algorithm's
        universe by delegating to ``get_history_window``.
        """
        universe = self._calculate_universe()
        return self.get_history_window(
            bar_count, frequency, universe, field, ffill)

    def get_history_window(self, bar_count, frequency, assets, field, ffill):
        return self.data_portal.get_history_window(
            assets,
            self.datetime,
            bar_count,
            frequency,
            field,
            self.data_frequency,
            ffill,
        )

    def _calculate_order(self,
                         asset,
                         amount,
                         limit_price=None,
                         stop_price=None,
                         style=None):
        """Normalise order parameters into ``(amount, style)``.

        The amount is rounded to an integer share count, the parameters are
        validated (which raises on invalid combinations), and the deprecated
        ``limit_price``/``stop_price`` arguments are converted to an
        execution style object.
        """
        shares = self.round_order(amount)

        # Raises a ZiplineError if invalid parameters are detected.
        self.validate_order_params(
            asset, shares, limit_price, stop_price, style)

        # Convert deprecated limit_price and stop_price parameters to use
        # ExecutionStyle objects.
        execution_style = self.__convert_order_params_for_blotter(
            limit_price, stop_price, style)

        return shares, execution_style

    def validate_order_params(self, asset, amount, limit_price, stop_price,
                              style):
        """
        Helper method for validating parameters to the order API function.

        Raises an UnsupportedOrderParameters if invalid arguments are found.
        """

        if not self.initialized:
            raise OrderDuringInitialize(
                msg="order() can only be called from within handle_data()")

        if style:
            if limit_price:
                raise UnsupportedOrderParameters(
                    msg="Passing both limit_price and style is not supported.")

            if stop_price:
                raise UnsupportedOrderParameters(
                    msg="Passing both stop_price and style is not supported.")

        for control in self.trading_controls:
            control.validate(asset, amount, self.portfolio,
                             self.get_datetime(), self.executor.current_data)

    @staticmethod
    def round_order(amount):
        """
        Convert a number of shares to an integer.

        ``round_if_near_integer`` is applied first, then ``int`` truncates
        toward zero, so values within its tolerance of an integer snap to it
        and everything else is truncated.

        E.g. 3.9999 -> 4.0; 5.5 -> 5.0; -5.5 -> -5.0
        """
        snapped = round_if_near_integer(amount)
        return int(snapped)

    @staticmethod
    def __convert_order_params_for_blotter(limit_price, stop_price, style):
        """
        Helper method for converting deprecated limit_price and stop_price
        arguments into ExecutionStyle instances.

        This function assumes that either style == None or (limit_price,
        stop_price) == (None, None).
        """
        if style:
            assert (limit_price, stop_price) == (None, None)
            return style
        if limit_price and stop_price:
            return StopLimitOrder(limit_price, stop_price)
        if limit_price:
            return LimitOrder(limit_price)
        if stop_price:
            return StopOrder(stop_price)
        else:
            return MarketOrder()

    def _calculate_universe(self):
        """Return the asset list stored in ``self._assets_from_source``."""
        # this exists to provide backwards compatibility for older,
        # deprecated APIs, particularly around the iterability of
        # BarData (ie, 'for sid in data`).

        # our universe is all the assets passed into `run`.
        return self._assets_from_source

    def _calculate_order_value_amount(self, asset, value):
        """
        Calculates how many shares/contracts to order based on the type of
        asset being ordered.

        Returns ``value / last_price``, or ``0`` when the last price is
        (tolerantly) zero, in which case no order should be placed.

        Raises ``CannotOrderDelistedAsset`` when the asset cannot currently
        be traded or has no last price (NaN).
        """
        if not self.executor.current_data.can_trade(asset):
            raise CannotOrderDelistedAsset(
                msg="Cannot order {0}, as it not tradable".format(
                    asset.symbol))

        last_price = \
            self.executor.current_data.current(asset, "price")

        if np.isnan(last_price):
            raise CannotOrderDelistedAsset(
                msg="Cannot order {0} on {1} as there is no last "
                "price for the security.".format(asset.symbol, self.datetime))

        if tolerant_equals(last_price, 0):
            zero_message = "Price of 0 for {psid}; can't infer value".format(
                psid=asset)
            # ``logger`` only exists after set_logger() has been called, so
            # read it defensively instead of risking an AttributeError.
            logger = getattr(self, 'logger', None)
            if logger:
                logger.debug(zero_message)
            # Don't place any order
            return 0

        return value / last_price

    def _calculate_order_percent_amount(self, asset, percent):
        value = self.portfolio.portfolio_value * percent
        return self._calculate_order_value_amount(asset, value)

    def _calculate_order_target_amount(self, asset, target):
        if asset in self.portfolio.positions:
            current_position = self.portfolio.positions[asset].amount
            target -= current_position

        return target

    def _calculate_order_target_percent_amount(self, asset, target):
        target_amount = self._calculate_order_percent_amount(asset, target)
        return self._calculate_order_target_amount(asset, target_amount)

    def _can_order_asset(self, asset):
        """Tell whether an order for ``asset`` may be placed right now.

        Raises ``UnsupportedOrderParameters`` when ``asset`` is not an
        ``Asset``. Returns ``False`` (after logging a warning) when the
        asset has an auto-close date and the current day is past the earlier
        of its end date and auto-close date; ``True`` otherwise.
        """
        if not isinstance(asset, Asset):
            raise UnsupportedOrderParameters(
                msg="Passing non-Asset argument to 'order()' is not supported."
                " Use 'sid()' or 'symbol()' methods to look up an Asset.")

        if not asset.auto_close_date:
            return True

        today = normalize_date(self.get_datetime())
        cutoff = min(asset.end_date, asset.auto_close_date)

        if today > cutoff:
            # We are after the asset's end date or auto close date: warn the
            # user that no order can be placed for this asset.
            notice = ("Cannot place order for {0}"
                      ", as it is not tradable.")
            log.warn(notice.format(asset.symbol))
            return False

        return True

    #
    # Account Controls
    #
    def register_account_control(self, control):
        """
        Register a new AccountControl to be checked on each bar.
        """
        if self.initialized:
            raise RegisterAccountControlPostInit()
        self.account_controls.append(control)

    def validate_account_controls(self):
        """Run every registered account control against the current state.

        Each control's ``validate`` receives the portfolio, the account, the
        current algorithm datetime and ``self.executor.current_data``; all
        four are re-read for every control.
        """
        for control in self.account_controls:
            control.validate(self.portfolio, self.account, self.get_datetime(),
                             self.executor.current_data)

    @api_method
    def set_max_leverage(self, max_leverage):
        """Set a limit on the maximum leverage of the algorithm.

        A ``MaxLeverage`` control built from the argument is registered as
        an account control.

        Parameters
        ----------
        max_leverage : float
            The maximum leverage for the algorithm.
        """
        leverage_guard = MaxLeverage(max_leverage)
        self.register_account_control(leverage_guard)

    #
    # Trading Controls
    #
    def register_trading_control(self, control):
        """
        Register a new TradingControl to be checked prior to order calls.
        """
        if self.initialized:
            raise RegisterTradingControlPostInit()
        self.trading_controls.append(control)

    @api_method
    def set_max_position_size(self,
                              asset=None,
                              max_shares=None,
                              max_notional=None,
                              on_error='fail'):
        """Limit the number of shares and/or dollar value held in an asset.

        A ``MaxPositionSize`` control is built from the arguments and
        registered as a trading control, so it is consulted when the
        algorithm attempts to place an order. Limits are treated as absolute
        values. Because the check happens at order time, holdings can still
        end up above the share limit through splits/dividends, or above the
        notional limit through price improvement.

        Parameters
        ----------
        asset : Asset, optional
            If provided, this sets the guard only on positions in the given
            asset.
        max_shares : int, optional
            The maximum number of shares to hold for an asset.
        max_notional : float, optional
            The maximum value to hold for an asset.
        on_error : str, optional
            Passed through to ``MaxPositionSize``; defaults to ``'fail'``.
        """
        limits = dict(
            asset=asset,
            max_shares=max_shares,
            max_notional=max_notional,
            on_error=on_error,
        )
        self.register_trading_control(MaxPositionSize(**limits))

    @api_method
    def set_max_order_size(self,
                           asset=None,
                           max_shares=None,
                           max_notional=None,
                           on_error='fail'):
        """Limit the number of shares and/or dollar value of any single
        order.

        A ``MaxOrderSize`` control is built from the arguments and
        registered as a trading control, so it is consulted when the
        algorithm attempts to place an order. Limits are treated as absolute
        values.

        Parameters
        ----------
        asset : Asset, optional
            If provided, this sets the guard only on positions in the given
            asset.
        max_shares : int, optional
            The maximum number of shares that can be ordered at one time.
        max_notional : float, optional
            The maximum value that can be ordered at one time.
        on_error : str, optional
            Passed through to ``MaxOrderSize``; defaults to ``'fail'``.
        """
        limits = dict(
            asset=asset,
            max_shares=max_shares,
            max_notional=max_notional,
            on_error=on_error,
        )
        self.register_trading_control(MaxOrderSize(**limits))

    @api_method
    def set_max_order_count(self, max_count, on_error='fail'):
        """Limit how many orders can be placed in a single day.

        A ``MaxOrderCount`` control is built and handed to
        ``register_trading_control``.

        Parameters
        ----------
        max_count : int
            The maximum number of orders that can be placed on any single day.
        on_error : str, optional
            Forwarded unchanged to ``MaxOrderCount``. Defaults to ``'fail'``.
        """
        # MaxOrderCount takes on_error first, then the limit.
        order_count_guard = MaxOrderCount(on_error, max_count)
        self.register_trading_control(order_count_guard)

    @api_method
    def set_do_not_order_list(self, restricted_list, on_error='fail'):
        """Deprecated way of restricting which assets can be ordered.

        Always emits a ``DeprecationWarning``, wraps ``restricted_list`` in
        the matching restrictions object and delegates to
        ``set_asset_restrictions``.

        Parameters
        ----------
        restricted_list : container[Asset], SecurityList
            The assets that cannot be ordered.
        on_error : str, optional
            Passed through to ``set_asset_restrictions``. Defaults to
            ``'fail'``.
        """
        # Choose the deprecation text and the wrapper type together so the
        # warning and the delegation below are each written only once.
        if isinstance(restricted_list, SecurityList):
            deprecation_text = (
                "`set_do_not_order_list(security_lists.leveraged_etf_list)` "
                "is deprecated. Use `set_asset_restrictions("
                "security_lists.restrict_leveraged_etfs)` instead."
            )
            wrapper_type = SecurityListRestrictions
        else:
            deprecation_text = (
                "`set_do_not_order_list(container_of_assets)` is deprecated. "
                "Create a zipline.finance.asset_restrictions."
                "StaticRestrictions object with a container of assets and use "
                "`set_asset_restrictions(StaticRestrictions("
                "container_of_assets))` instead."
            )
            wrapper_type = StaticRestrictions

        # stacklevel=2 attributes the warning to the caller of this method.
        warnings.warn(deprecation_text,
                      category=DeprecationWarning,
                      stacklevel=2)

        wrapped = wrapper_type(restricted_list)
        self.set_asset_restrictions(wrapped, on_error)

    @api_method
    @expect_types(
        restrictions=Restrictions,
        on_error=str,
    )
    def set_asset_restrictions(self, restrictions, on_error='fail'):
        """Set a restriction on which assets can be ordered.

        Registers a ``RestrictedListOrder`` trading control for
        ``restrictions`` and merges ``restrictions`` into
        ``self.restrictions`` with ``|=``.

        Parameters
        ----------
        restrictions : Restrictions
            An object providing information about restricted assets.
        on_error : str, optional
            Forwarded unchanged to ``RestrictedListOrder``. Defaults to
            ``'fail'``.

        See Also
        --------
        zipline.finance.asset_restrictions.Restrictions
        """
        # RestrictedListOrder takes on_error first, then the restrictions.
        restricted_order_guard = RestrictedListOrder(on_error, restrictions)
        self.register_trading_control(restricted_order_guard)

        # Accumulate onto whatever restrictions are already in effect.
        self.restrictions |= restrictions

    @api_method
    def set_long_only(self, on_error='fail'):
        """Set a rule specifying that this algorithm cannot take short
        positions.

        Registers a ``LongOnly`` trading control built with ``on_error``.
        """
        register = self.register_trading_control
        register(LongOnly(on_error))

    @api_method
    def attach_pipeline(self, pipeline, name, chunks=None):
        """Store ``pipeline`` under ``name`` in ``self._pipelines``.

        An existing entry with the same name is overwritten. ``chunks`` is
        accepted but not used by this implementation.
        """
        registry = self._pipelines
        registry[name] = pipeline

    @api_method
    def pipeline_output(self, name):
        """Run the pipeline stored under ``name`` and return its output.

        ``LivePipelineEngine`` is imported lazily from
        ``pipeline_live.engine``. The engine is given a callable that lists
        the sorted symbols found in the asset finder's ``_asset_cache``.
        Before returning, the output's index is replaced by the result of
        ``asset_finder.lookup_symbols`` applied to the original index.

        Raises
        ------
        RuntimeError
            If ``pipeline_live.engine`` cannot be imported.
        """
        try:
            from pipeline_live.engine import LivePipelineEngine
        except ImportError:
            raise RuntimeError('pipeline-live is not installed')

        asset_finder = self.asset_finder

        def list_symbols():
            # Read the cache at call time, not when the engine is created.
            symbols = [
                cached.symbol
                for cached in asset_finder._asset_cache.values()
            ]
            symbols.sort()
            return symbols

        engine = LivePipelineEngine(list_symbols)
        result = engine.run_pipeline(self._pipelines[name])

        # Swap the index the engine produced for the looked-up assets.
        looked_up = asset_finder.lookup_symbols(result.index)
        result.index = pd.Index(looked_up)
        return result
Exemplo n.º 4
0
    def __init__(self, *args, **kwargs):
        '''
        Build the algorithm context from keyword options. Positional
        arguments are accepted but not used here.

        data_frequency: 'minute' or 'daily', defaults to 'minute'
        algoname: str, defaults to 'algo'
        quantopian_compatible: forwarded to DataPortal, defaults to True
        backend: str or Backend instance, defaults to 'alpaca'
                 (str is either backend module name under
                  'pylivetrader.backend', or global import path)
        backend_options: dict of keyword arguments for the backend
                         module's Backend class; only applied when
                         `backend` is a str. Always stored on
                         self.backend_options (empty dict if omitted).
        trading_calendar: trading calendar object, defaults to
                          get_calendar('NYSE')
        initialize: initialize function
        handle_data: handle_data function
        before_trading_start: before_trading_start function
        log_level: 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'
        storage_engine: 'file', 'redis'
        statefile: path given to the file storage engine, defaults to
                   '<algoname>-state.pkl'; not read when storage_engine
                   is 'redis'
        pipeline_hook: pipeline_output hook function to enable smoke like
                       functionality. it is not meant to be used by the
                       CLI
        '''
        log.level = lookup_level(kwargs.pop('log_level', 'INFO'))
        self._recorded_vars = {}

        self.data_frequency = kwargs.pop('data_frequency', 'minute')
        assert self.data_frequency in ('minute', 'daily')

        self._algoname = kwargs.pop('algoname', 'algo')

        self.quantopian_compatible = kwargs.pop('quantopian_compatible', True)

        # Anything other than 'redis' falls back to the file-backed store.
        storage_engine = kwargs.pop('storage_engine', 'file')
        if storage_engine == 'redis':
            storage_engine = RedisStore()
        else:
            storage_engine = FileStore(
                kwargs.pop('statefile', None)
                or '{}-state.pkl'.format(self._algoname))
        self._state_store = StateStore(storage_engine=storage_engine)

        self._pipelines = {}

        # Read the options before branching on the backend type so that
        # `self.backend_options` exists even when a ready-made Backend
        # instance is supplied.
        self.backend_options = kwargs.pop('backend_options', None) or {}

        backend_param = kwargs.pop('backend', 'alpaca')
        if not isinstance(backend_param, str):
            # A pre-built backend object: use it as-is.
            self._backend = backend_param
            self._backend_name = backend_param.__class__.__name__
        else:
            self._backend_name = backend_param
            try:
                # First, try to import an official backend package.
                backendmod = importlib.import_module(
                    'pylivetrader.backend.{}'.format(self._backend_name))
            except ImportError:
                # If that fails, try to find the package in the global
                # package namespace.
                try:
                    backendmod = importlib.import_module(self._backend_name)
                except ImportError:
                    raise RuntimeError(
                        "Could not find backend package `{}`.".format(
                            self._backend_name))

            self._backend = backendmod.Backend(**self.backend_options)

        self.asset_finder = AssetFinder(self._backend)

        # Only build the default NYSE calendar when the caller did not
        # supply one; an explicitly passed value (even None) is kept.
        if 'trading_calendar' in kwargs:
            self.trading_calendar = kwargs.pop('trading_calendar')
        else:
            self.trading_calendar = get_calendar('NYSE')

        self.data_portal = DataPortal(self._backend, self.asset_finder,
                                      self.trading_calendar,
                                      self.quantopian_compatible)

        self.event_manager = EventManager()

        self.trading_controls = []

        self.account_controls = []

        self.restrictions = NoRestrictions()

        # User callbacks; each falls back to `noop` when not supplied.
        self._initialize = kwargs.pop('initialize', noop)
        self._handle_data = kwargs.pop('handle_data', noop)
        self._before_trading_start = kwargs.pop('before_trading_start', noop)
        self._pipeline_hook = kwargs.get('pipeline_hook')

        self.event_manager.add_event(
            events.Event(
                events.Always(),
                # We pass handle_data.__func__ to get the unbound method.
                self.handle_data.__func__,
            ),
            prepend=True,
        )

        self._account_needs_update = True
        self._portfolio_needs_update = True

        self._in_before_trading_start = False

        self._assets_from_source = []

        self._context_persistence_excludes = []

        self._max_shares = int(1e+11)

        self.initialized = False

        # Names of every callable attribute found on the Algorithm class.
        self.api_methods = [
            func for func in dir(Algorithm)
            if callable(getattr(Algorithm, func))
        ]
Exemplo n.º 5
0
def test_finder():
    """Exercise AssetFinder retrieval, caching and symbol lookup against a
    stub broker that reports a single equity."""
    asset = Equity('asset-id', 'NSDQ', symbol='AAPL')

    class DummyBroker:
        # Stub broker exposing only get_equities.
        def get_equities(self):
            return [asset]

    # retrieve_asset
    finder = AssetFinder(DummyBroker())
    assert finder.retrieve_asset('asset-id') == asset

    with pytest.raises(SidsNotFound):
        finder.retrieve_asset('invalid')

    # retrieve_all
    assert finder.retrieve_all(['asset-id']) == [asset]

    # No `assert` inside the raises-blocks: if the call unexpectedly
    # returned, pytest should report "did not raise" rather than an
    # unrelated AssertionError on the return value.
    with pytest.raises(SidsNotFound):
        finder.retrieve_all(['asset-id', 'invalid'])

    assert finder.retrieve_all(['asset-id', 'invalid'], default_none=True) \
        == [asset, None]

    # retrieve_equities
    assert finder.retrieve_equities(['asset-id'])['asset-id'] == asset
    with pytest.raises(EquitiesNotFound):
        finder.retrieve_equities(['inv'])

    # asset should be cached until cleared
    assert hasattr(finder, 'asset_cache')

    finder.clear_cache()
    assert not hasattr(finder, 'asset_cache')

    # lookup_symbol

    assert finder.lookup_symbol('AAPL', None) == asset
    assert finder.lookup_symbol('AAPL', None, fuzzy=True) == asset

    with pytest.raises(SymbolNotFound):
        finder.lookup_symbol('invalid', None)

    with pytest.raises(SymbolNotFound):
        finder.lookup_symbol('invalid', None, fuzzy=True)

    # lookup_symbols

    assert finder.lookup_symbols(['AAPL'], None) == [asset]
    assert finder.lookup_symbols(['AAPL'], None, fuzzy=True) == [asset]

    with pytest.raises(SymbolNotFound):
        finder.lookup_symbols(['AAPL', 'invalid'], None)

    # Fuzzy variant of the negative case, mirroring the lookup_symbol
    # checks above.
    with pytest.raises(SymbolNotFound):
        finder.lookup_symbols(['AAPL', 'invalid'], None, fuzzy=True)

    # sids
    assert finder.sids == ['asset-id']