Example #1
0
class CleanableCM(object):
    """Context manager base class that owns an ``ExitStack`` for its cleanups.

    Subclasses implement :meth:`_enter` and register whatever has to be undone
    on ``self.stack``.  The stack is unwound when the ``with`` block ends and
    also when :meth:`_enter` itself fails part-way through.
    """

    def __init__(self):
        super(CleanableCM, self).__init__()
        # Collects the cleanups registered by the subclass.
        self.stack = ExitStack()

    def _enter(self):
        """Acquire resources and return the value bound by ``with ... as``.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    @contextmanager
    def _cleanup_on_error(self):
        """Call ``self.__exit__`` if the guarded block raises, else do nothing."""

        def unwind(exc_type, exc_value, traceback):
            self.__exit__(exc_type, exc_value, traceback)
            # A failure while entering must always reach the caller, even when
            # something on ``self.stack`` would swallow the exception.
            return False

        with ExitStack() as guard:
            guard.push(unwind)
            yield
            # Success: disarm the guard so the cleanups stay pending until
            # the real ``__exit__`` call.
            guard.pop_all()

    def __enter__(self):
        """Enter ``self.stack`` and return whatever :meth:`_enter` returns."""
        with self._cleanup_on_error():
            self.stack.__enter__()
            return self._enter()

    def __exit__(self, exc_type, exc_value, traceback):
        """Unwind ``self.stack`` and report whether it suppressed the exception."""
        return self.stack.__exit__(exc_type, exc_value, traceback)
Example #2
0
 def __call__(s, *args, **kwargs):
     """Run the task body inside the Flask app, request and plugin contexts.

     ``s`` is the task instance; ``self`` is not a parameter here and is
     looked up in the enclosing scope.

     Every context is entered inside a single ``with`` block, so a failure
     at any step (unwrapping the arguments, resolving the plugin, entering
     a later context) still exits the contexts that were already entered.

     Raises:
         ValueError: the task names a plugin for which
             ``plugin_engine.get_plugin`` returns ``None``.
     """
     with ExitStack() as stack:
         stack.enter_context(self.flask_app.app_context())
         if getattr(s, 'request_context', False):
             stack.enter_context(self.flask_app.test_request_context(base_url=config.BASE_URL))
         args = _CelerySAWrapper.unwrap_args(args)
         kwargs = _CelerySAWrapper.unwrap_kwargs(kwargs)
         # A plugin set on the task wins over the one carried by the request.
         plugin = getattr(s, 'plugin', s.request.get('indico_plugin'))
         if isinstance(plugin, basestring):
             plugin_name = plugin
             plugin = plugin_engine.get_plugin(plugin)
             if plugin is None:
                 # Leaving the ``with`` block unwinds the contexts above.
                 raise ValueError('Plugin not active: ' + plugin_name)
         stack.enter_context(plugin_context(plugin))
         clearCache()
         request_stats_request_started()
         return super(IndicoTask, s).__call__(*args, **kwargs)
Example #3
0
 def setUp(self):
     """Create the per-test teardown stack and run ``init_instance_fixtures``.

     If fixture initialisation fails for any reason, or the override never
     reached the base implementation, ``tearDown`` is run before the error
     is re-raised.
     """
     super_not_called = (
         "ZiplineTestCase.init_instance_fixtures() was not called.\n"
         "This probably means that you overrode init_instance_fixtures"
         " without calling super()."
     )
     self._instance_teardown_stack = ExitStack()
     try:
         self._init_instance_fixtures_was_called = False
         self.init_instance_fixtures()
         assert self._init_instance_fixtures_was_called, super_not_called
     except BaseException:
         # Undo whatever was registered before the failure, then propagate.
         self.tearDown()
         raise
Example #4
0
def shell_cmd(verbose, with_req_context):
    """Start an interactive IPython shell with the Indico shell namespace.

    :param verbose: if true, the info lines returned by
                    ``_make_shell_context`` are shown above the banner
    :param with_req_context: if true, the shell runs inside a test request
                             context for ``config.BASE_URL``

    Exits the process with status 1 when IPython cannot be imported.
    """
    try:
        from IPython.terminal.ipapp import TerminalIPythonApp
    except ImportError:
        click.echo(cformat('%{red!}You need to `pip install ipython` to use the Indico shell'))
        sys.exit(1)

    current_app.config['REPL'] = True  # disables e.g. memoize_request
    request_stats_request_started()
    extra_namespace, info_lines = _make_shell_context()
    ready_line = cformat('%{yellow!}Indico v{} is ready for your commands').format(indico.__version__)
    banner = '\n'.join(info_lines + ['', ready_line]) if verbose else ready_line
    namespace = current_app.make_shell_context()
    namespace.update(extra_namespace)
    clearCache()
    with ExitStack() as stack:
        if with_req_context:
            stack.enter_context(current_app.test_request_context(base_url=config.BASE_URL))
        ipython_app = TerminalIPythonApp.instance(user_ns=namespace, display_banner=False)
        ipython_app.initialize(argv=[])
        ipython_app.shell.show_banner(banner)
        ipython_app.start()
Example #5
0
 def __init__(self):
     """Register the plugin's command line options and set up its state."""
     super(NosePlugin, self).__init__()
     self.patterns = []
     self.stderr = False
     self.record = False

     # Flag callbacks: each one just flips the matching attribute.
     def set_stderr(_value):
         self.stderr = True

     def set_record(_value):
         self.record = True

     self.addArgument(self.patterns, 'P', 'pattern',
                      'Add a test matching pattern')
     self.addFlag(set_stderr, 'E', 'stderr',
                  'Enable stderr logging to sub-runners')
     self.addFlag(set_record, 'R', 'rerecord',
                  """Force re-recording of test responses.  Requires
                  Mailman to be running.""")
     # Holds the resources entered during the test run.
     self._resources = ExitStack()
     self._data_path = os.path.join(TOPDIR, 'tests', 'data', 'tape.yaml')
Example #6
0
 def setUp(self):
     """Run ``init_instance_fixtures`` with the class marked as "in setup".

     The instance attributes present on entry are recorded in
     ``_pre_setup_attrs`` and a fresh teardown stack is created.  On any
     failure ``tearDown`` is run before re-raising; the ``_in_setup`` flag
     is cleared in every case.
     """
     test_class = type(self)
     super_not_called = (
         "ZiplineTestCase.init_instance_fixtures() was not called.\n"
         "This probably means that you overrode init_instance_fixtures"
         " without calling super()."
     )
     test_class._in_setup = True
     self._pre_setup_attrs = set(vars(self))
     self._instance_teardown_stack = ExitStack()
     try:
         self._init_instance_fixtures_was_called = False
         self.init_instance_fixtures()
         assert self._init_instance_fixtures_was_called, super_not_called
     except BaseException:
         self.tearDown()
         raise
     finally:
         test_class._in_setup = False
Example #7
0
def train_command(args):
    """Run a training session described by the parsed command line ``args``.

    Steps, in order: validate the ``ooc_*`` options, load the network
    configuration, build a ``TrainConfig`` holding the optimizers and
    monitors, record where the network definition lives in the module-level
    ``_save_parameter_info``, then either train through ``_train`` or, when
    there is no iteration to run, only save the parameters.

    Returns:
        bool: ``False`` when an ``ooc_*`` option is negative or a dataset URI
        is empty.  ``True`` in every other case, including when training
        itself did not succeed; that outcome is only reported through the
        log and ``callback.update_status``.
    """
    # Convert the ooc_* options from their string form and reject negatives.
    if args.ooc_gpu_memory_size is not None:
        ooc_gpu_memory_size = str_to_num(args.ooc_gpu_memory_size)
        if ooc_gpu_memory_size < 0:
            logger.log(
                99,
                f'Fatal error. invalid ooc_gpu_memory_size [{args.ooc_gpu_memory_size}].'
            )
            return False
        args.ooc_gpu_memory_size = ooc_gpu_memory_size
    if args.ooc_window_length is not None:
        ooc_window_length = str_to_num(args.ooc_window_length)
        if ooc_window_length < 0:
            logger.log(
                99,
                f'Fatal error. invalid ooc_window_length [{args.ooc_window_length}].'
            )
            return False
        args.ooc_window_length = ooc_window_length

    callback.update_status(args)

    if single_or_rankzero():
        configure_progress(os.path.join(args.outdir, 'progress.txt'))

    info = load.load([args.config],
                     prepare_data_iterator=None,
                     exclude_parameter=True,
                     context=args.context)

    # Check dataset uri is empty.
    dataset_error = False
    for dataset in info.datasets.values():
        if dataset.uri.strip() == '':
            dataset_error = True
    if dataset_error:
        logger.log(99, 'Fatal error. Dataset URI is empty.')
        return False

    # Plain attribute container for everything handed to _train().
    class TrainConfig:
        pass

    config = TrainConfig()
    config.timelimit = -1
    if args.param:
        # If this parameter file contains optimizer information,
        # it is needed to recover the training state.
        #load.load([args.param], parameter_only=True)
        load_train_state(args.param, info)

    config.timelimit = callback.get_timelimit(args)

    config.global_config = info.global_config
    config.training_config = info.training_config

    if single_or_rankzero():
        logger.log(99, 'Train with contexts {}'.format(available_contexts))

    # Pairs an optimizer with the data iterator instances created for it.
    class OptConfig:
        pass

    config.optimizers = OrderedDict()
    for name, opt in info.optimizers.items():
        o = OptConfig()
        o.optimizer = opt
        o.data_iterators = []
        config.optimizers[name] = o

    # Pairs a monitor with the data iterator instances created for it.
    class MonConfig:
        pass

    config.monitors = OrderedDict()
    for name, mon in info.monitors.items():
        m = MonConfig()
        m.monitor = mon
        m.data_iterators = []
        config.monitors[name] = m

    # Training
    comm = current_communicator()
    # With a communicator, the per-epoch iteration count is divided by the
    # number of processes.
    config.training_config.iter_per_epoch //= comm.size if comm else 1
    max_iteration = config.training_config.max_epoch * \
        config.training_config.iter_per_epoch

    # Remember which network definition file belongs to this run; for an
    # .nnp archive the definition is first extracted into the output dir.
    global _save_parameter_info
    _save_parameter_info = {}
    _, config_ext = os.path.splitext(args.config)
    if config_ext == '.prototxt' or config_ext == '.nntxt':
        _save_parameter_info['config'] = args.config
    elif config_ext == '.nnp':
        with zipfile.ZipFile(args.config, 'r') as nnp:
            for name in nnp.namelist():
                _, ext = os.path.splitext(name)
                if ext == '.nntxt' or ext == '.prototxt':
                    nnp.extract(name, args.outdir)
                    _save_parameter_info['config'] = os.path.join(
                        args.outdir, name)

    result = False
    restart = False
    if max_iteration > 0:
        # Seeded with the process rank (0 without a communicator).
        rng = np.random.RandomState(comm.rank if comm else 0)
        # The stack exits every data iterator entered below once training
        # is over, whether it finished or raised.
        with ExitStack() as stack:
            # Create data_iterator instance only once for each dataset in optimizers
            optimizer_data_iterators = {}
            for name, o in config.optimizers.items():
                for di in o.optimizer.data_iterators.values():
                    if di not in optimizer_data_iterators:
                        di_instance = stack.enter_context(di())
                        if comm and comm.size > 1:
                            # Multi-process run: this process gets its own slice.
                            di_instance = di_instance.slice(
                                rng, comm.size, comm.rank)
                        optimizer_data_iterators[di] = di_instance
                    else:
                        di_instance = optimizer_data_iterators[di]
                    o.data_iterators.append(di_instance)

            # Create data_iterator instance only once for each dataset in monitors
            monitor_data_iterators = {}
            for name, m in config.monitors.items():
                for di in m.monitor.data_iterators.values():
                    if di not in monitor_data_iterators:
                        di_instance = stack.enter_context(di())
                        if comm and comm.size > 1:
                            di_instance = di_instance.slice(
                                rng, comm.size, comm.rank)
                        monitor_data_iterators[di] = di_instance
                    else:
                        di_instance = monitor_data_iterators[di]
                    m.data_iterators.append(di_instance)
            # NOTE(review): the merged dict is not read again in this function.
            monitor_data_iterators.update(optimizer_data_iterators)

            result, restart = _train(args, config)
    else:
        # save parameters without training (0 epoch learning)
        logger.log(99, '0 epoch learning. (Just save parameter.)')
        if single_or_rankzero():
            _save_parameters(args, None, 0, config, True)
        result = True

    # The final status is only reported when no restart was requested.
    if single_or_rankzero() and not restart:
        if result:
            logger.log(99, 'Training Completed.')
            callback.update_status('finished')
        else:
            logger.log(99, 'Training Incompleted.')
            callback.update_status('failed')
    if single_or_rankzero():
        progress(None)
    return True
Example #8
0
class ZiplineTestCase(with_metaclass(FinalMeta, TestCase)):
    """
    Shared extensions to core unittest.TestCase.

    Overrides the default unittest setUp/tearDown functions with versions that
    use ExitStack to correctly clean up resources, even in the face of
    exceptions that occur during setUp/setUpClass.

    Subclasses **should not override setUp or setUpClass**!

    Instead, they should implement `init_instance_fixtures` for per-test-method
    resources, and `init_class_fixtures` for per-class resources.

    Resources that need to be cleaned up should be registered using
    either `enter_{class,instance}_context` or `add_{class,instance}_callback`.
    """
    # True while an instance-level setUp is running; used to reject
    # class-level registrations made from init_instance_fixtures.
    _in_setup = False

    @final
    @classmethod
    def setUpClass(cls):
        """
        Create the class-level teardown stack and run `init_class_fixtures`.

        If that fails, `tearDownClass` is run before the error is re-raised.
        """
        # Hold a set of all the "static" attributes on the class. These are
        # things that are not populated after the class was created like
        # methods or other class level attributes.
        cls._static_class_attributes = set(vars(cls))
        cls._class_teardown_stack = ExitStack()
        try:
            cls._base_init_fixtures_was_called = False
            cls.init_class_fixtures()
            assert cls._base_init_fixtures_was_called, (
                "ZiplineTestCase.init_class_fixtures() was not called.\n"
                "This probably means that you overrode init_class_fixtures"
                " without calling super().")
        except BaseException:
            # Clean up whatever was registered before the failure.
            cls.tearDownClass()
            raise

    @classmethod
    def init_class_fixtures(cls):
        """
        Override and implement this classmethod to register resources that
        should be created and/or torn down on a per-class basis.

        Subclass implementations of this should always invoke this with super()
        to ensure that fixture mixins work properly.
        """
        if cls._in_setup:
            raise ValueError(
                'Called init_class_fixtures from init_instance_fixtures.'
                '\nDid you write super(..., self).init_class_fixtures() instead'
                ' of super(..., self).init_instance_fixtures()?', )
        cls._base_init_fixtures_was_called = True

    @final
    @classmethod
    def tearDownClass(cls):
        """
        Close the class-level teardown stack and drop the attributes that
        were added to the class after `setUpClass` started.
        """
        cls._class_teardown_stack.close()
        for name in set(vars(cls)) - cls._static_class_attributes:
            # Remove all of the attributes that were added after the class was
            # constructed. This cleans up any large test data that is class
            # scoped while still allowing subclasses to access class level
            # attributes.
            delattr(cls, name)

    @final
    @classmethod
    def enter_class_context(cls, context_manager):
        """
        Enter a context manager to be exited during the tearDownClass
        """
        if cls._in_setup:
            raise ValueError(
                'Attempted to enter a class context in init_instance_fixtures.'
                '\nDid you mean to call enter_instance_context?', )
        return cls._class_teardown_stack.enter_context(context_manager)

    @final
    @classmethod
    def add_class_callback(cls, callback):
        """
        Register a callback to be executed during tearDownClass.

        Parameters
        ----------
        callback : callable
            The callback to invoke at the end of the test suite.
        """
        if cls._in_setup:
            raise ValueError(
                'Attempted to add a class callback in init_instance_fixtures.'
                '\nDid you mean to call add_instance_callback?', )
        return cls._class_teardown_stack.callback(callback)

    @final
    def setUp(self):
        """
        Create the per-test teardown stack and run `init_instance_fixtures`.

        The class is flagged as being "in setup" for the duration. If
        initialisation fails, `tearDown` is run before the error is re-raised.
        """
        type(self)._in_setup = True
        # Snapshot taken first so tearDown can remove everything added later,
        # including the bookkeeping attributes set below.
        self._pre_setup_attrs = set(vars(self))
        self._instance_teardown_stack = ExitStack()
        try:
            self._init_instance_fixtures_was_called = False
            self.init_instance_fixtures()
            assert self._init_instance_fixtures_was_called, (
                "ZiplineTestCase.init_instance_fixtures() was not"
                " called.\n"
                "This probably means that you overrode"
                " init_instance_fixtures without calling super().")
        except BaseException:
            self.tearDown()
            raise
        finally:
            type(self)._in_setup = False

    def init_instance_fixtures(self):
        """
        Override and implement this method to register resources that should
        be created and/or torn down for every test method.

        Subclass implementations must call this through super(); setUp
        asserts that the base implementation was reached.
        """
        self._init_instance_fixtures_was_called = True

    @final
    def tearDown(self):
        """
        Close the per-test teardown stack and drop the attributes that were
        added to the instance after `setUp` started.
        """
        self._instance_teardown_stack.close()
        for attr in set(vars(self)) - self._pre_setup_attrs:
            delattr(self, attr)

    @final
    def enter_instance_context(self, context_manager):
        """
        Enter a context manager that should be exited during tearDown.
        """
        return self._instance_teardown_stack.enter_context(context_manager)

    @final
    def add_instance_callback(self, callback):
        """
        Register a callback to be executed during tearDown.

        Parameters
        ----------
        callback : callable
            The callback to invoke at the end of each test.
        """
        return self._instance_teardown_stack.callback(callback)
Example #9
0
def stack():
    """Yield an ``ExitStack`` that is closed once the generator is resumed.

    Lets a test register cleanups without nesting ``with`` blocks.
    """
    cleanups = ExitStack()
    with cleanups:
        yield cleanups
Example #10
0
def profile_command(args):
    """Profile the optimizers of a network and write ``profile.csv``.

    Loads ``args.config``, opens each distinct optimizer data iterator once,
    hands the configuration to ``profile_optimizer`` and writes the rows it
    returns to ``profile.csv`` inside ``args.outdir``.  Progress and status
    are reported through ``configure_progress``/``progress`` and
    ``callback.update_status``.

    Returns:
        bool: always ``True``.
    """
    callback.update_status(args)
    configure_progress(os.path.join(args.outdir, 'progress.txt'))

    # Plain attribute containers for the profiling configuration.
    class TrainConfig:
        pass

    class OptConfig:
        pass

    class MonConfig:
        pass

    config = TrainConfig()
    info = load.load(args.config)
    config.global_config = info.global_config
    config.training_config = info.training_config

    config.optimizers = OrderedDict()
    for opt_name, optimizer in info.optimizers.items():
        opt_entry = OptConfig()
        opt_entry.optimizer = optimizer
        opt_entry.data_iterators = []
        config.optimizers[opt_name] = opt_entry

    config.monitors = OrderedDict()
    for mon_name, monitor in info.monitors.items():
        mon_entry = MonConfig()
        mon_entry.monitor = monitor
        mon_entry.data_iterators = []
        config.monitors[mon_name] = mon_entry

    # The extension module is named by the part of the first backend entry
    # that precedes ':'.
    backend_name = config.global_config.default_context.backend[0].split(':')[0]
    ext_module = import_extension_module(backend_name)

    def synchronize():
        device_id = config.global_config.default_context.device_id
        return ext_module.synchronize(device_id=device_id)

    result_array = [['time in ms']]

    callback.update_status('processing', True)

    # Profile Optimizer
    with ExitStack() as stack:
        # Each distinct data iterator is instantiated and entered only once,
        # then shared by every optimizer that refers to it.
        opened = {}
        for opt_entry in config.optimizers.values():
            for factory in opt_entry.optimizer.data_iterators.values():
                if factory not in opened:
                    opened[factory] = stack.enter_context(factory())
                opt_entry.data_iterators.append(opened[factory])
        result_array = profile_optimizer(config, result_array, synchronize)

    # Write profiling result
    import csv
    with open(args.outdir + os.sep + 'profile.csv', 'w') as f:
        csv.writer(f, lineterminator='\n').writerows(result_array)

    logger.log(99, 'Profile Completed.')
    progress(None)
    callback.update_status('finished')
    return True
Example #11
0
    def transform(self):
        """
        Main generator work loop.

        Iterates over ``self.clock``, which produces ``(dt, action)`` pairs,
        and dispatches on the action. Yields capital change packets, the
        daily and minute messages built by ``_get_daily_message`` /
        ``_get_minute_message`` and, once the clock is exhausted, the value
        returned by ``algo.perf_tracker.handle_simulation_end()``.
        """
        algo = self.algo
        emission_rate = algo.perf_tracker.emission_rate

        # NOTE: the helpers below take their collaborators as default
        # arguments, so those are bound once here, at definition time.
        def every_bar(dt_to_use,
                      current_data=self.current_data,
                      handle_data=algo.event_manager.handle_data):
            """
            Generator run for each BAR action; yields capital changes.

            Relies on ``calculate_minute_capital_changes``, which is only
            defined further down (inside the ``with`` block) and looked up
            when this function runs.
            """
            # called every tick (minute or day).
            algo.on_dt_changed(dt_to_use)

            for capital_change in calculate_minute_capital_changes(dt_to_use):
                yield capital_change

            self.simulation_dt = dt_to_use

            blotter = algo.blotter
            perf_tracker = algo.perf_tracker

            # handle any transactions and commissions coming out new orders
            # placed in the last bar

            new_transactions, new_commissions, closed_orders = \
                blotter.get_transactions(current_data)

            blotter.prune_orders(closed_orders)

            for transaction in new_transactions:
                perf_tracker.process_transaction(transaction)

                # since this order was modified, record it
                order = blotter.orders[transaction.order_id]
                perf_tracker.process_order(order)

            if new_commissions:
                for commission in new_commissions:
                    perf_tracker.process_commission(commission)

            handle_data(algo, current_data, dt_to_use)

            # grab any new orders from the blotter, then clear the list.
            # this includes cancelled orders.
            new_orders = blotter.new_orders
            blotter.new_orders = []

            # if we have any new orders, record them so that we know
            # in what perf period they were placed.
            if new_orders:
                for new_order in new_orders:
                    perf_tracker.process_order(new_order)

            algo.portfolio_needs_update = True
            algo.account_needs_update = True
            algo.performance_needs_update = True

        def once_a_day(midnight_dt,
                       current_data=self.current_data,
                       data_portal=self.data_portal):
            """
            Generator run for each SESSION_START action; yields the capital
            changes computed with ``is_interday=True``, then cleans up
            expired assets and applies splits.
            """

            perf_tracker = algo.perf_tracker

            # Get the positions before updating the date so that prices are
            # fetched for trading close instead of midnight
            positions = algo.perf_tracker.position_tracker.positions
            position_assets = algo.asset_finder.retrieve_all(positions)

            # set all the timestamps
            self.simulation_dt = midnight_dt
            algo.on_dt_changed(midnight_dt)

            # process any capital changes that came overnight
            for capital_change in algo.calculate_capital_changes(
                    midnight_dt, emission_rate=emission_rate,
                    is_interday=True):
                yield capital_change

            # we want to wait until the clock rolls over to the next day
            # before cleaning up expired assets.
            self._cleanup_expired_assets(midnight_dt, position_assets)

            # handle any splits that impact any positions or any open orders.
            assets_we_care_about = \
                viewkeys(perf_tracker.position_tracker.positions) | \
                viewkeys(algo.blotter.open_orders)

            # Only Equity assets are kept for the split lookup; iterate over
            # a copy because the set is modified inside the loop.
            for a in assets_we_care_about.copy():
                if not isinstance(a, Equity):
                    assets_we_care_about.remove(a)

            # TODO GD REMOVE THE NOT !!! TMP HACK !!!
            if assets_we_care_about:
                splits = data_portal.get_splits(assets_we_care_about,
                                                midnight_dt)
                if splits:
                    algo.blotter.process_splits(splits)
                    perf_tracker.position_tracker.handle_splits(splits)

        def handle_benchmark(date, benchmark_source=self.benchmark_source):
            """Store the benchmark value for ``date`` on the perf tracker."""
            algo.perf_tracker.all_benchmark_returns[date] = \
                benchmark_source.get_value(date)

        def on_exit():
            # Remove references to algo, data portal, et al to break cycles
            # and ensure deterministic cleanup of these objects when the
            # simulation finishes.
            self.algo = None
            self.benchmark_source = self.current_data = self.data_portal = None

        with ExitStack() as stack:
            # Registered first, so it runs last when the stack unwinds.
            stack.callback(on_exit)
            stack.enter_context(self.processor)
            stack.enter_context(ZiplineAPI(self.algo))

            # Pick the minute-specific or the no-op variants of the two
            # frequency-dependent hooks.
            if algo.data_frequency == 'minute':

                def execute_order_cancellation_policy():
                    algo.blotter.execute_cancel_policy(SESSION_END)

                def calculate_minute_capital_changes(dt):
                    # process any capital changes that came between the last
                    # and current minutes
                    return algo.calculate_capital_changes(
                        dt, emission_rate=emission_rate, is_interday=False)
            else:

                def execute_order_cancellation_policy():
                    pass

                def calculate_minute_capital_changes(dt):
                    return []

            for dt, action in self.clock:
                if action == BAR:
                    for capital_change_packet in every_bar(dt):
                        yield capital_change_packet
                elif action == SESSION_START:
                    for capital_change_packet in once_a_day(dt):
                        yield capital_change_packet
                elif action == SESSION_END:
                    # End of the session.
                    if emission_rate == 'daily':
                        handle_benchmark(normalize_date(dt))
                    execute_order_cancellation_policy()

                    yield self._get_daily_message(dt, algo, algo.perf_tracker)
                elif action == BEFORE_TRADING_START_BAR:
                    self.simulation_dt = dt
                    algo.on_dt_changed(dt)
                    algo.before_trading_start(self.current_data)
                elif action == MINUTE_END:
                    handle_benchmark(dt)
                    minute_msg = \
                        self._get_minute_message(dt, algo, algo.perf_tracker)

                    yield minute_msg

        # ``self.algo`` has been cleared by on_exit at this point; the local
        # ``algo`` still refers to the algorithm.
        risk_message = algo.perf_tracker.handle_simulation_end()
        yield risk_message
Example #12
0
def run_test(in_file, test_spec, global_cfg):
    """Run a single tavern test

    Note that each tavern test can consist of multiple requests (log in,
    create, update, delete, etc).

    The global configuration is copied and used as an initial configuration for
    this test. Any values which are saved from any tests are saved into this
    test block and can be used for formatting in later stages in the test.
    The "variables" mapping is copied as well, so values stored by this test
    do not end up in global_cfg.

    Args:
        in_file (str): filename containing this test
        test_spec (dict): The specification for this test. Its "stages" entry
            is replaced by the resolved stages.
        global_cfg (dict): Any global configuration for this test

    Raises:
        TavernException: re-raised from a failing stage, after the stage and
            the test block configuration have been attached to it
    """

    # pylint: disable=too-many-locals

    # Initialise test config for this test with the global configuration before
    # starting
    test_block_config = dict(global_cfg)

    # dict() above is only a shallow copy: give this test its own "variables"
    # mapping so nothing written below leaks back into global_cfg
    test_block_config["variables"] = dict(
        test_block_config.get("variables") or {})

    tavern_box = Box({"env_vars": dict(os.environ)})

    if not test_spec:
        logger.warning("Empty test block in %s", in_file)
        return

    # Get included stages and resolve any into the test spec dictionary
    available_stages = test_block_config.get("stages", [])
    included_stages = _get_included_stages(tavern_box, test_block_config,
                                           test_spec, available_stages)
    all_stages = {s["id"]: s for s in available_stages + included_stages}
    test_spec["stages"] = _resolve_test_stages(test_spec, all_stages)

    test_block_config["variables"]["tavern"] = tavern_box

    test_block_name = test_spec["test_name"]

    # Strict on body by default
    default_strictness = test_block_config["strict"]

    logger.info("Running test : %s", test_block_name)

    with ExitStack() as stack:
        sessions = get_extra_sessions(test_spec, test_block_config)

        for name, session in sessions.items():
            logger.debug("Entering context for %s", name)
            stack.enter_context(session)

        def getonly(stage):
            """Return the stage's "only" flag as a boolean (False if unset)"""
            o = stage.get("only")
            if o is None:
                return False
            elif isinstance(o, bool):
                return o
            else:
                return strtobool(o)

        has_only = any(getonly(stage) for stage in test_spec["stages"])

        # Run tests in a path in order
        for stage in test_spec["stages"]:
            if stage.get("skip"):
                continue
            elif has_only and not getonly(stage):
                continue

            test_block_config["strict"] = default_strictness

            # Can be overridden per stage
            # NOTE
            # this is hardcoded to check for the 'response' block. In the far
            # future there might not be a response block, but at the moment it
            # is the hardcoded value for any HTTP request.
            response_block = stage.get("response", {})
            if response_block:
                # Precedence: stage response > test spec > global default
                if response_block.get("strict", None) is not None:
                    stage_strictness = response_block.get("strict", None)
                elif test_spec.get("strict", None) is not None:
                    stage_strictness = test_spec.get("strict", None)
                else:
                    stage_strictness = default_strictness

                logger.debug("Strict key checking for this stage is '%s'",
                             stage_strictness)

                test_block_config["strict"] = stage_strictness
            elif default_strictness:
                logger.debug("Default strictness '%s' ignored for this stage",
                             default_strictness)

            # Wrap run_stage with retry helper
            run_stage_with_retries = retry(stage, test_block_config)(run_stage)

            try:
                run_stage_with_retries(sessions, stage, tavern_box,
                                       test_block_config)
            except exceptions.TavernException as e:
                # Attach context for whoever reports the failure
                e.stage = stage
                e.test_block_config = test_block_config
                raise

            # A stage marked "only" ends the test once it has run
            if getonly(stage):
                break
Example #13
0
 def context(self):
     """Enter every context in ``self.__contexts__`` and return the stack.

     ``self.__contexts__`` is iterated as ``(name, context manager)`` pairs.
     The value returned by entering each context is also stored on the
     stack as an attribute called ``name``.  The caller owns the returned
     ``ExitStack`` and must close it (or use it in a ``with`` statement).

     If entering one of the contexts fails, the contexts entered so far are
     exited before the exception propagates.
     """
     import sys

     stack = ExitStack()
     try:
         for name, context in self.__contexts__:
             setattr(stack, name, stack.enter_context(context))
     except BaseException:
         # Unwind what was already entered; the failure itself always
         # reaches the caller.
         stack.__exit__(*sys.exc_info())
         raise
     return stack
Example #14
0
class NosePlugin(Plugin):
    """Test-runner plugin: pattern-based test selection plus vcr recording.

    Command line options registered in ``__init__``:

    * ``P`` / ``pattern``  -- regular expression appended to ``self.patterns``;
      only matching test classes, test methods and doctest files are run.
    * ``E`` / ``stderr``   -- sets ``self.stderr`` to True.
    * ``R`` / ``rerecord`` -- sets ``self.record`` to True, which makes
      ``startTestRun`` delete the existing cassette before recording.
    """

    configSection = 'mailman'

    def __init__(self):
        super(NosePlugin, self).__init__()
        self.patterns = []
        self.stderr = False
        self.record = False

        # Flag callbacks; the value handed over by the option parser is
        # irrelevant, only the fact that the flag was given matters.
        def enable_stderr(_ignored):
            self.stderr = True

        def enable_record(_ignored):
            self.record = True

        self.addArgument(self.patterns, 'P', 'pattern',
                         'Add a test matching pattern')
        self.addFlag(enable_stderr, 'E', 'stderr',
                     'Enable stderr logging to sub-runners')
        self.addFlag(enable_record, 'R', 'rerecord',
                     """Force re-recording of test responses.  Requires
                     Mailman to be running.""")
        self._data_path = os.path.join(TOPDIR, 'tests', 'data', 'tape.yaml')
        # Holds the cassette context for the duration of the test run.
        self._resources = ExitStack()

    def startTestRun(self, event):
        """Start recording, dropping the old cassette first in record mode."""
        if self.record:
            try:
                os.remove(self._data_path)
            except OSError as exc:
                # A cassette that does not exist yet is fine; anything else
                # (permissions, ...) is a real error.
                already_gone = exc.errno == errno.ENOENT
                if not already_gone:
                    raise
        # Entering the cassette creates the recording file when needed.
        cassette = vcr.use_cassette(self._data_path)
        self._resources.enter_context(cassette)

    def stopTestRun(self, event):
        """Leave the cassette context, which ends all recording."""
        self._resources.close()

    def getTestCaseNames(self, event):
        """Exclude test methods that match none of the filter patterns.

        Nothing is excluded when there are no patterns, or when a pattern
        matches the fully qualified name of the test class itself.
        """
        if not self.patterns:
            # No filter patterns, so everything should be tested.
            return
        test_case = event.testCase
        qualified_class = '{}.{}'.format(
            test_case.__module__, test_case.__name__)
        if any(re.search(regex, qualified_class) for regex in self.patterns):
            # The class as a whole was selected; keep all of its tests.
            return
        for method_name in filter(event.isTestMethod, dir(test_case)):
            qualified_test = '{}.{}.{}'.format(
                test_case.__module__, test_case.__name__, method_name)
            selected = any(
                re.search(regex, qualified_test) for regex in self.patterns)
            if not selected:
                event.excludedNames.append(method_name)

    def handleFile(self, event):
        """Add ``.rst`` files that pass the pattern filter as doctests."""
        # Path relative to TOPDIR (prefix and separator stripped).
        relative_path = event.path[len(TOPDIR)+1:]
        if self.patterns and not any(
                re.search(regex, relative_path) for regex in self.patterns):
            # Skip this doctest.
            return
        extension = os.path.splitext(relative_path)[1]
        if extension != '.rst':
            return
        doc_test = doctest.DocFileTest(
            relative_path,
            package=mailmanclient,
            optionflags=FLAGS,
            setUp=setup,
            tearDown=teardown)
        # Suppress the extra "Doctest: ..." line.
        doc_test.shortDescription = lambda: None
        event.extraTests.append(doc_test)
Example #15
0
    def transform(self):
        """
        Main generator work loop.

        Iterates over ``self.clock``, which produces ``(dt, action)`` pairs,
        and dispatches on the action (``BAR``, ``SESSION_START``,
        ``SESSION_END``, ``BEFORE_TRADING_START_BAR``, ``MINUTE_END``).

        Yields, in order of occurrence:

        * capital change packets produced while handling ``BAR`` and
          ``SESSION_START`` actions,
        * the daily message on ``SESSION_END``,
        * the minute message on ``MINUTE_END``,
        * finally the value returned by
          ``metrics_tracker.handle_simulation_end`` -- or ``None`` when
          ``self.fast_backtest`` is truthy.

        When the ``with`` block is left, ``on_exit`` clears ``self.algo``,
        ``self.benchmark_source``, ``self.current_data`` and
        ``self.data_portal``.
        """
        algo = self.algo
        metrics_tracker = algo.metrics_tracker
        emission_rate = metrics_tracker.emission_rate

        # NOTE: the default arguments below are evaluated once, when the
        # nested function is defined, so they capture the objects that are
        # current at that moment.
        def every_bar(dt_to_use,
                      current_data=self.current_data,
                      handle_data=algo.event_manager.handle_data):
            # Generator run for every BAR action: yields capital changes
            # first, then processes fills and calls the handle_data hook.
            #
            # calculate_minute_capital_changes is not defined yet at this
            # point; it is bound further down, inside the ``with`` block,
            # before every_bar is first called.
            for capital_change in calculate_minute_capital_changes(dt_to_use):
                yield capital_change

            self.simulation_dt = dt_to_use
            # called every tick (minute or day).
            algo.on_dt_changed(dt_to_use)

            blotter = algo.blotter

            # handle any transactions and commissions coming out new orders
            # placed in the last bar
            new_transactions, new_commissions, closed_orders = \
                blotter.get_transactions(current_data)

            blotter.prune_orders(closed_orders)

            for transaction in new_transactions:
                metrics_tracker.process_transaction(transaction)

                # since this order was modified, record it
                order = blotter.orders[transaction.order_id]
                metrics_tracker.process_order(order)

            for commission in new_commissions:
                metrics_tracker.process_commission(commission)

            handle_data(algo, current_data, dt_to_use)

            # grab any new orders from the blotter, then clear the list.
            # this includes cancelled orders.
            new_orders = blotter.new_orders
            blotter.new_orders = []

            # if we have any new orders, record them so that we know
            # in what perf period they were placed.
            for new_order in new_orders:
                metrics_tracker.process_order(new_order)

        def once_a_day(midnight_dt,
                       current_data=self.current_data,
                       data_portal=self.data_portal):
            # Generator run for every SESSION_START action.
            # (current_data is accepted but not used in this body.)

            # process any capital changes that came overnight
            for capital_change in algo.calculate_capital_changes(
                    midnight_dt, emission_rate=emission_rate,
                    is_interday=True):
                yield capital_change

            # set all the timestamps
            self.simulation_dt = midnight_dt
            algo.on_dt_changed(midnight_dt)

            metrics_tracker.handle_market_open(
                midnight_dt,
                algo.data_portal,
            )

            # handle any splits that impact any positions or any open orders.
            assets_we_care_about = (viewkeys(metrics_tracker.positions)
                                    | viewkeys(algo.blotter.open_orders))

            if assets_we_care_about:
                splits = data_portal.get_splits(assets_we_care_about,
                                                midnight_dt)
                if splits:
                    algo.blotter.process_splits(splits)
                    metrics_tracker.handle_splits(splits)

        def on_exit():
            # Remove references to algo, data portal, et al to break cycles
            # and ensure deterministic cleanup of these objects when the
            # simulation finishes.
            self.algo = None
            self.benchmark_source = self.current_data = self.data_portal = None

        with ExitStack() as stack:
            # ExitStack unwinds in reverse registration order, so on_exit
            # (registered first) runs last, after the ZiplineAPI and
            # processor contexts have been exited.
            stack.callback(on_exit)
            stack.enter_context(self.processor)
            stack.enter_context(ZiplineAPI(self.algo))

            # Pick the frequency-specific helpers used by every_bar and by
            # the SESSION_END branch below.
            if algo.data_frequency == 'minute':

                def execute_order_cancellation_policy():
                    algo.blotter.execute_cancel_policy(SESSION_END)

                def calculate_minute_capital_changes(dt):
                    # process any capital changes that came between the last
                    # and current minutes
                    return algo.calculate_capital_changes(
                        dt, emission_rate=emission_rate, is_interday=False)
            else:

                # For any other data frequency both helpers are no-ops.
                def execute_order_cancellation_policy():
                    pass

                def calculate_minute_capital_changes(dt):
                    return []

            for dt, action in self.clock:
                if action == BAR:
                    for capital_change_packet in every_bar(dt):
                        yield capital_change_packet
                elif action == SESSION_START:
                    for capital_change_packet in once_a_day(dt):
                        yield capital_change_packet
                elif action == SESSION_END:
                    # End of the session.
                    positions = metrics_tracker.positions
                    position_assets = algo.asset_finder.retrieve_all(positions)
                    self._cleanup_expired_assets(dt, position_assets)

                    execute_order_cancellation_policy()
                    algo.validate_account_controls()

                    yield self._get_daily_message(dt, algo, metrics_tracker)
                elif action == BEFORE_TRADING_START_BAR:
                    self.simulation_dt = dt
                    algo.on_dt_changed(dt)
                    algo.before_trading_start(self.current_data)
                elif action == MINUTE_END:
                    minute_msg = self._get_minute_message(
                        dt,
                        algo,
                        metrics_tracker,
                    )

                    yield minute_msg

            # The clock is exhausted: emit the final message. It is None
            # when fast_backtest is set, so consumers must tolerate that.
            if not self.fast_backtest:
                risk_message = metrics_tracker.handle_simulation_end(
                    self.data_portal, )
            else:
                risk_message = None
            yield risk_message
Example #16
0
def run_test(in_file, test_spec, global_cfg):
    """Run a single tavern test

    Note that each tavern test can consist of multiple requests (log in,
    create, update, delete, etc).

    The global configuration is copied and used as an initial configuration for
    this test. Any values which are saved from any tests are saved into this
    test block and can be used for formatting in later stages in the test.
    The "variables" mapping is copied as well, so values saved while running
    this test are never written back into ``global_cfg``.

    Args:
        in_file (str): filename containing this test
        test_spec (dict): The specification for this test
        global_cfg (dict): Any global configuration for this test

    Raises:
        TavernException: If any of the tests failed
    """

    # pylint: disable=too-many-locals

    # Initialise test config for this test with the global configuration before
    # starting
    test_block_config = dict(global_cfg)

    # dict() above is only a shallow copy: without this second copy the
    # "variables" dict would still be the caller's object and every update
    # below (tavern box, included variables, saved values) would leak into
    # global_cfg and from there into other tests.
    test_block_config["variables"] = dict(
        test_block_config.get("variables", {}))

    tavern_box = Box({
        "env_vars": dict(os.environ),
    })

    test_block_config["variables"]["tavern"] = tavern_box

    if not test_spec:
        logger.warning("Empty test block in %s", in_file)
        return

    if test_spec.get("includes"):
        for included in test_spec["includes"]:
            if "variables" in included:
                # Included variables may themselves reference the tavern box
                # (e.g. environment variables), so format them first.
                formatted_include = format_keys(included["variables"],
                                                {"tavern": tavern_box})
                test_block_config["variables"].update(formatted_include)

    test_block_name = test_spec["test_name"]

    logger.info("Running test : %s", test_block_name)

    with ExitStack() as stack:
        sessions = get_extra_sessions(test_spec)

        # Every session stays open for the whole test and is closed when the
        # stack unwinds, whether the test passes or raises.
        for session_name, session in sessions.items():
            logger.debug("Entering context for %s", session_name)
            stack.enter_context(session)

        # Run tests in a path in order
        for stage in test_spec["stages"]:
            name = stage["name"]

            try:
                r = get_request_type(stage, test_block_config, sessions)
            except exceptions.MissingFormatError:
                log_fail(stage, None, None)
                raise

            # Expose the request variables of this stage for formatting of
            # the expected response; removed again once the stage passed.
            tavern_box.update(request_vars=r.request_vars)

            try:
                expected = get_expected(stage, test_block_config, sessions)
            except exceptions.TavernException:
                log_fail(stage, None, None)
                raise

            delay(stage, "before")

            logger.info("Running stage : %s", name)

            try:
                response = r.run()
            except exceptions.TavernException:
                log_fail(stage, None, expected)
                raise

            verifiers = get_verifiers(stage, test_block_config, sessions,
                                      expected)

            for v in verifiers:
                try:
                    saved = v.verify(response)
                except exceptions.TavernException:
                    log_fail(stage, v, expected)
                    raise
                else:
                    # Values saved by this stage become available to the
                    # stages that follow it.
                    test_block_config["variables"].update(saved)

            log_pass(stage, verifiers)

            tavern_box.pop("request_vars")
            delay(stage, "after")
Example #17
0
def main(argv=None):
    """Parse the command line, start the canvas application and run it.

    Args:
        argv (list of str, optional): full argument vector including the
            program name at index 0. Defaults to ``sys.argv``.

    Returns:
        int: the value returned by ``app.exec_()``, or ``1`` when the event
        loop terminated with an exception (which is logged, not re-raised).
    """
    if argv is None:
        argv = sys.argv

    usage = "usage: %prog [options] [workflow_file]"
    parser = optparse.OptionParser(usage=usage)

    parser.add_option("--no-discovery",
                      action="store_true",
                      help="Don't run widget discovery "
                      "(use full cache instead)")
    parser.add_option("--force-discovery",
                      action="store_true",
                      help="Force full widget discovery "
                      "(invalidate cache)")
    parser.add_option("--clear-widget-settings",
                      action="store_true",
                      help="Remove stored widget setting")
    parser.add_option("--no-welcome",
                      action="store_true",
                      help="Don't show welcome dialog.")
    parser.add_option("--no-splash",
                      action="store_true",
                      help="Don't show splash screen.")
    parser.add_option("-l",
                      "--log-level",
                      help="Logging level (0, 1, 2, 3, 4)",
                      type="int",
                      default=1)
    parser.add_option("--style",
                      help="QStyle to use",
                      type="str",
                      default=None)
    parser.add_option("--stylesheet",
                      help="Application level CSS style sheet to use",
                      type="str",
                      default="orange.qss")
    parser.add_option("--qt",
                      help="Additional arguments for QApplication",
                      type="str",
                      default=None)

    parser.add_option("--config",
                      help="Configuration namespace",
                      type="str",
                      default="orangecanvas.example")

    # -m canvas orange.widgets
    # -m canvas --config orange.widgets

    (options, args) = parser.parse_args(argv[1:])

    # Indexed by the --log-level option (0 = least verbose).
    levels = [
        logging.CRITICAL, logging.ERROR, logging.WARN, logging.INFO,
        logging.DEBUG
    ]

    # Fix streams before configuring logging (otherwise it will store
    # and write to the old file descriptors)
    fix_win_pythonw_std_stream()

    # Try to fix fonts on OSX Mavericks/Yosemite, ...
    fix_osx_private_font()

    # File handler should always be at least INFO level so we need
    # the application root level to be at least at INFO.
    root_level = min(levels[options.log_level], logging.INFO)
    rootlogger = logging.getLogger(__package__)
    rootlogger.setLevel(root_level)

    # Standard output stream handler at the requested level
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(level=levels[options.log_level])
    rootlogger.addHandler(stream_handler)

    if options.config is not None:
        try:
            cfg = utils.name_lookup(options.config)
        except (ImportError, AttributeError):
            # An unresolvable configuration name is ignored on purpose; the
            # default configuration stays active.
            pass
        else:
            config.set_default(cfg)

            log.info("activating %s", options.config)

    log.info("Starting 'Orange Canvas' application.")

    # Build the argument vector handed to the Qt application: program name,
    # style options, extra --qt arguments, then the positional arguments.
    qt_argv = argv[:1]

    if options.style is not None:
        qt_argv += ["-style", options.style]

    if options.qt is not None:
        qt_argv += shlex.split(options.qt)

    qt_argv += args

    if options.clear_widget_settings:
        log.debug("Clearing widget settings")
        shutil.rmtree(config.widget_settings_dir(), ignore_errors=True)

    if QT_VERSION >= 0x50600:
        CanvasApplication.setAttribute(Qt.AA_UseHighDpiPixmaps)

    log.debug("Starting CanvasApplicaiton with argv = %r.", qt_argv)
    app = CanvasApplication(qt_argv)

    # NOTE: config.init() must be called after the QApplication constructor
    config.init()

    file_handler = logging.FileHandler(
        filename=os.path.join(config.log_dir(), "canvas.log"),
        mode="w")

    file_handler.setLevel(root_level)
    rootlogger.addHandler(file_handler)

    # intercept any QFileOpenEvent requests until the main window is
    # fully initialized.
    # NOTE: The QApplication must have the executable ($0) and filename
    # arguments passed in argv otherwise the FileOpen events are
    # triggered for them (this is done by Cocoa, but QApplicaiton filters
    # them out if passed in argv)

    open_requests = []

    def onrequest(url):
        log.info("Received an file open request %s", url)
        open_requests.append(url)

    app.fileOpenRequest.connect(onrequest)

    settings = QSettings()

    stylesheet = options.stylesheet
    stylesheet_string = None

    if stylesheet != "none":
        if os.path.isfile(stylesheet):
            # An existing file path takes precedence over bundled styles.
            with io.open(stylesheet, "r") as f:
                stylesheet_string = f.read()
        else:
            if not os.path.splitext(stylesheet)[1]:
                # no extension
                stylesheet = os.path.extsep.join([stylesheet, "qss"])

            pkg_name = __package__
            resource = "styles/" + stylesheet

            if pkg_resources.resource_exists(pkg_name, resource):
                stylesheet_string = \
                    pkg_resources.resource_string(pkg_name, resource).decode("utf-8")

                base = pkg_resources.resource_filename(pkg_name, "styles")

                # Lines of the form "@prefix: path;" declare Qt search
                # paths; they are registered and then stripped from the
                # style sheet text.
                pattern = re.compile(
                    r"^\s@([a-zA-Z0-9_]+?)\s*:\s*([a-zA-Z0-9_/]+?);\s*$",
                    flags=re.MULTILINE)

                matches = pattern.findall(stylesheet_string)

                for prefix, search_path in matches:
                    QDir.addSearchPath(prefix, os.path.join(base, search_path))
                    log.info("Adding search path %r for prefix, %r",
                             search_path, prefix)

                stylesheet_string = pattern.sub("", stylesheet_string)

            else:
                log.info("%r style sheet not found.", stylesheet)

    # Add the default canvas_icons search path
    dirpath = os.path.abspath(os.path.dirname(__file__))
    QDir.addSearchPath("canvas_icons", os.path.join(dirpath, "icons"))

    canvas_window = CanvasMainWindow()
    canvas_window.setWindowIcon(config.application_icon())

    if stylesheet_string is not None:
        canvas_window.setStyleSheet(stylesheet_string)

    if not options.force_discovery:
        reg_cache = cache.registry_cache()
    else:
        reg_cache = None

    widget_registry = qt.QtWidgetRegistry()
    widget_discovery = config.widget_discovery(widget_registry,
                                               cached_descriptions=reg_cache)

    want_splash = \
        settings.value("startup/show-splash-screen", True, type=bool) and \
        not options.no_splash

    if want_splash:
        pm, rect = config.splash_screen()
        splash_screen = SplashScreen(pixmap=pm, textRect=rect)
        splash_screen.setAttribute(Qt.WA_DeleteOnClose)
        splash_screen.setFont(QFont("Helvetica", 12))
        color = QColor("#FFD39F")

        def show_message(message):
            splash_screen.showMessage(message, color=color)

        widget_registry.category_added.connect(show_message)
        show_splash = splash_screen.show
        close_splash = splash_screen.close
    else:
        show_splash = close_splash = lambda: None

    log.info("Running widget discovery process.")

    cache_filename = os.path.join(config.cache_dir(), "widget-registry.pck")
    if options.no_discovery:
        # NOTE(review): pickle.load runs arbitrary code from the file; this
        # is only acceptable because the file is the application's own cache.
        with open(cache_filename, "rb") as f:
            widget_registry = pickle.load(f)
        widget_registry = qt.QtWidgetRegistry(widget_registry)
    else:
        show_splash()
        widget_discovery.run(config.widgets_entry_points())
        close_splash()

        # Store cached descriptions
        cache.save_registry_cache(widget_discovery.cached_descriptions)
        with open(cache_filename, "wb") as f:
            pickle.dump(WidgetRegistry(widget_registry), f)

    set_global_registry(widget_registry)
    canvas_window.set_widget_registry(widget_registry)
    canvas_window.show()
    canvas_window.raise_()

    want_welcome = \
        settings.value("startup/show-welcome-screen", True, type=bool) \
        and not options.no_welcome

    # Process events to make sure the canvas_window layout has
    # a chance to activate (the welcome dialog is modal and will
    # block the event queue, plus we need a chance to receive open file
    # signals when running without a splash screen)
    app.processEvents()

    app.fileOpenRequest.connect(canvas_window.open_scheme_file)

    if want_welcome and not args and not open_requests:
        canvas_window.welcome_dialog()

    elif args:
        log.info("Loading a scheme from the command line argument %r", args[0])
        canvas_window.load_scheme(args[0])
    elif open_requests:
        log.info("Loading a scheme from an `QFileOpenEvent` for %r",
                 open_requests[-1])
        canvas_window.load_scheme(open_requests[-1].toLocalFile())

    # Tee stdout and stderr into Output dock
    output_view = canvas_window.output_view()
    stdout = TextStream()
    stdout.stream.connect(output_view.write)
    # sys.stdout / sys.stderr may be falsy (no console attached); only
    # forward to them when they exist.
    if sys.stdout:
        stdout.stream.connect(sys.stdout.write)
        stdout.flushed.connect(sys.stdout.flush)
    stderr = TextStream()
    error_writer = output_view.formated(color=Qt.red)
    stderr.stream.connect(error_writer.write)
    if sys.stderr:
        stderr.stream.connect(sys.stderr.write)
        stderr.flushed.connect(sys.stderr.flush)
    sys.excepthook = ExceptHook(stream=stderr)

    with ExitStack() as stack:
        stack.enter_context(redirect_stdout(stdout))
        stack.enter_context(redirect_stderr(stderr))
        log.info("Entering main event loop.")
        try:
            status = app.exec_()
        except BaseException:
            log.error("Error in main event loop.", exc_info=True)
            # Without this assignment ``status`` would be unbound and the
            # ``return`` below would raise UnboundLocalError, hiding the
            # original failure. Report a generic non-zero exit status.
            status = 1

    canvas_window.deleteLater()
    app.processEvents()
    app.flush()
    del canvas_window

    # Collect any cycles before deleting the QApplication instance
    gc.collect()

    del app
    return status
Example #18
0
    def main(self, orings):
        """Process every input sequence gulp by gulp until shutdown.

        For each tuple of input sequences (one per ring in ``self.irings``)
        this calls ``_on_sequence`` to obtain the output headers and the
        requested input slices, resizes the input sequences accordingly,
        opens the output sequences on ``orings`` and then, for each tuple of
        input spans, reserves output spans, calls ``_on_data`` and commits
        the results. Timing of the acquire / reserve / process phases is
        written to ``self.perf_proclog`` after every gulp, and
        ``_on_sequence_end`` is called once a sequence is finished.

        Both loops stop as soon as ``self.shutdown_event`` is set.
        """
        sequence_readers = [
            iring.read(guarantee=self.guarantee) for iring in self.irings
        ]
        for iseqs in izip(*sequence_readers):
            if self.shutdown_event.is_set():
                break
            for index, iseq in enumerate(iseqs):
                self.sequence_proclogs[index].update(iseq.header)
            oheaders, islices = self._on_sequence(iseqs)
            # Headers without an explicit time tag get the running sequence
            # counter instead.
            for ohdr in oheaders:
                if 'time_tag' not in ohdr:
                    ohdr['time_tag'] = self._seq_count
            self._seq_count += 1

            # None (for the whole list or for a single entry) stands for
            # slice(gulp_nframe).
            if islices is None:
                islices = [None] * len(self.irings)
            fallback_nframes = [
                self.gulp_nframe or iseq.header['gulp_nframe']
                for iseq in iseqs
            ]
            islices = [
                _span_slice(requested or slice(fallback))
                for requested, fallback in zip(islices, fallback_nframes)
            ]

            for iseq, span in zip(iseqs, islices):
                if self.buffer_factor is not None:
                    buffer_factor = self.buffer_factor
                else:
                    # No explicit factor configured: use 1 when the block
                    # owning the source ring is fused with this one.
                    src_block = iseq.ring.owner
                    fused = (src_block is not None
                             and self.is_fused_with(src_block))
                    buffer_factor = 1 if fused else None
                iseq.resize(gulp_nframe=(span.stop - span.start),
                            buf_nframe=self.buffer_nframe,
                            buffer_factor=buffer_factor)

            igulp_nframes = [span.stop - span.start for span in islices]

            with ExitStack() as oseq_stack:
                oseqs = self.begin_sequences(oseq_stack, orings, oheaders,
                                             igulp_nframes)
                prev_time = time.time()
                span_readers = [
                    iseq.read(span.stop - span.start, span.step, span.start)
                    for iseq, span in zip(iseqs, islices)
                ]
                for ispans in izip(*span_readers):
                    if self.shutdown_event.is_set():
                        break
                    # Time spent waiting for the input spans.
                    now = time.time()
                    acquire_time = now - prev_time
                    prev_time = now
                    with ExitStack() as ospan_stack:
                        ospans = self.reserve_spans(ospan_stack, oseqs, ispans)
                        # Time spent reserving the output spans.
                        now = time.time()
                        reserve_time = now - prev_time
                        prev_time = now
                        # TODO: See if multiple on_data calls can be fused
                        #       together here before calling
                        #       stream_synchronize(); consider passing .data
                        #       instead of rings.
                        ostrides = self._on_data(ispans, ospans)
                        # TODO: Default to not spinning the CPU:
                        #       cudaSetDeviceFlags(cudaDeviceScheduleBlockingSync);
                        bf.device.stream_synchronize()
                        # None (overall or per span) means the span was
                        # consumed completely.
                        if ostrides is None:
                            ostrides = [ospan.nframe for ospan in ospans]
                        ostrides = [
                            ospan.nframe if stride is None else stride
                            for (stride, ospan) in zip(ostrides, ospans)
                        ]
                        for ospan, stride in zip(ospans, ostrides):
                            ospan.commit(stride)
                    # Time spent in on_data, synchronisation and commit,
                    # including leaving the span stack.
                    now = time.time()
                    process_time = now - prev_time
                    prev_time = now
                    timings = {
                        'acquire_time': acquire_time,
                        'reserve_time': reserve_time,
                        'process_time': process_time
                    }
                    self.perf_proclog.update(timings)
            self._on_sequence_end(iseqs)
Example #19
0
    def transform(self, stream_in):
        """
        Main generator work loop.

        Consumes ``(date, snapshot)`` pairs from ``stream_in``. Snapshots
        dated before ``self.algo_start`` only update internal state; all
        later ones are handed to ``_process_snapshot`` and the resulting
        messages are yielded. After the stream is exhausted the message
        returned by ``perf_tracker.handle_simulation_end()`` is yielded.
        """
        # Open and close of the market session currently being simulated.
        mkt_open = self.algo.perf_tracker.market_open
        mkt_close = self.algo.perf_tracker.market_close

        # Keep the processor and the API context active for the whole run.
        with ExitStack() as contexts:
            contexts.enter_context(self.processor)
            contexts.enter_context(ZiplineAPI(self.algo))

            data_frequency = self.sim_params.data_frequency

            self._call_before_trading_start(mkt_open)

            for dt, events in stream_in:

                self.simulation_dt = dt
                self.on_dt_changed(dt)

                if dt < self.algo_start:
                    # Still in the warmup period: use the events to update
                    # our universe, but yield no perf messages and send no
                    # snapshot to handle_data.
                    for event in events:
                        kind = event.type
                        if kind == DATASOURCE_TYPE.SPLIT:
                            self.algo.blotter.process_split(event)
                        elif kind == DATASOURCE_TYPE.TRADE:
                            self.update_universe(event)
                            self.algo.perf_tracker.process_trade(event)
                        elif kind == DATASOURCE_TYPE.CUSTOM:
                            self.update_universe(event)
                    continue

                # Perf messages are only emitted if the snapshot contained
                # a benchmark event.
                perf_messages = self._process_snapshot(
                    dt, events, self.algo.instant_fill)
                for message in perf_messages:
                    yield message

                # When emitting minutely, we need to call
                # before_trading_start before the next trading day begins
                if dt == mkt_close:
                    if mkt_close <= self.algo.perf_tracker.last_close:
                        # Evaluated before mkt_close is advanced below.
                        sessions_remain = \
                            mkt_close < self.algo.perf_tracker.last_close
                        try:
                            next_session = trading.environment \
                                .next_open_and_close(mkt_close)
                            mkt_open, mkt_close = next_session
                        except trading.NoFurtherDataError:
                            # At the end of backtest history: leave the
                            # market open / close untouched.
                            pass

                        if sessions_remain:
                            self._call_before_trading_start(mkt_open)

                elif data_frequency == 'daily':
                    next_day = trading.environment.next_trading_day(dt)

                    if next_day is not None:
                        if next_day < self.algo.perf_tracker.last_close:
                            self._call_before_trading_start(next_day)

                # Mark the cached views as stale after every processed
                # snapshot.
                self.algo.portfolio_needs_update = True
                self.algo.account_needs_update = True
                self.algo.performance_needs_update = True

            yield self.algo.perf_tracker.handle_simulation_end()
Example #20
0
def settings(*args, **kwargs):
    """
    Enter several context managers together and/or override ``env`` values.

    `settings` serves two purposes:

    * Keyword arguments are handed, as one dict, to ``_setenv``, and the
      object it returns is entered last. This is what allows temporary
      overriding/updating of ``env``, e.g. ``with settings(user='******'):``.
      Original values, if any, are restored once the ``with`` block closes.

        * The keyword argument ``clean_revert`` is not interpreted here; it
          is forwarded with the other keyword arguments, so its special
          meaning (described below) is provided by ``_setenv``.

    * Positional arguments, which should be other context managers, are
      entered in the order given on a single ``ExitStack``, e.g.
      ``with settings(hide('stderr'), show('stdout')):``. They are exited in
      reverse order when the block closes.

    The values returned by each manager's ``__enter__`` are yielded as one
    tuple, positional managers first and the ``env`` override (if any) last.

    NOTE(review): this body is a generator; it is presumably wrapped by
    ``contextlib.contextmanager`` outside this view -- confirm.

    These behaviors may be specified at the same time if desired. An example
    will hopefully illustrate why this is considered useful::

        def my_task():
            with settings(
                hide('warnings', 'running', 'stdout', 'stderr'),
                warn_only=True
            ):
                if run('ls /etc/lsb-release'):
                    return 'Ubuntu'
                elif run('ls /etc/redhat-release'):
                    return 'RedHat'

    The above task executes a `run` statement, but will warn instead of
    aborting if the ``ls`` fails, and all output -- including the warning
    itself -- is prevented from printing to the user. The end result, in this
    scenario, is a completely silent task that allows the caller to figure out
    what type of system the remote host is, without incurring the handful of
    output that would normally occur.

    Thus, `settings` may be used to set any combination of environment
    variables in tandem with hiding (or showing) specific levels of output, or
    in tandem with any other piece of Fabric functionality implemented as a
    context manager.

    If ``clean_revert`` is set to ``True``, ``settings`` will **not** revert
    keys which are altered within the nested block, instead only reverting keys
    whose values remain the same as those given. More examples will make this
    clear; below is how ``settings`` operates normally::

        # Before the block, env.parallel defaults to False, host_string to None
        with settings(parallel=True, host_string='myhost'):
            # env.parallel is True
            # env.host_string is 'myhost'
            env.host_string = 'otherhost'
            # env.host_string is now 'otherhost'
        # Outside the block:
        # * env.parallel is False again
        # * env.host_string is None again

    The internal modification of ``env.host_string`` is nullified -- not always
    desirable. That's where ``clean_revert`` comes in::

        # Before the block, env.parallel defaults to False, host_string to None
        with settings(parallel=True, host_string='myhost', clean_revert=True):
            # env.parallel is True
            # env.host_string is 'myhost'
            env.host_string = 'otherhost'
            # env.host_string is now 'otherhost'
        # Outside the block:
        # * env.parallel is False again
        # * env.host_string remains 'otherhost'

    Brand new keys which did not exist in ``env`` prior to using ``settings``
    are also preserved if ``clean_revert`` is active. When ``False``, such keys
    are removed when the block exits.

    .. versionadded:: 1.4.1
        The ``clean_revert`` kwarg.
    """
    # Build the full list up front so the env override is created before any
    # manager is entered, and is always the last one entered.
    pending = list(args) + ([_setenv(kwargs)] if kwargs else [])

    with ExitStack() as stack:
        entered = []
        for manager in pending:
            entered.append(stack.enter_context(manager))
        yield tuple(entered)
Example #21
0
    def transform(self):
        """
        Main generator work loop.

        Iterates over ``self.clock``, which produces ``(dt, action)`` pairs,
        and dispatches on ``action``:

        * ``BAR``         -- run ``every_bar(dt)``.
        * ``DAY_START``   -- run ``once_a_day(dt)``.
        * ``DAY_END``     -- record the benchmark value (daily emission only),
          apply the order cancellation policy and yield the daily message.
        * ``MINUTE_END``  -- record the benchmark value and yield the minute
          message.

        After the clock is exhausted, one final message -- the return value of
        ``perf_tracker.handle_simulation_end()`` -- is yielded.

        The loop runs with ``self.processor`` and ``ZiplineAPI(self.algo)``
        entered on an ``ExitStack``; ``on_exit`` is registered first, so it is
        the last cleanup to run when the stack unwinds.
        """
        # Keep a local reference: on_exit() clears self.algo, but the final
        # handle_simulation_end() call below still needs the algorithm.
        algo = self.algo

        # The default arguments below bind the current attribute values at
        # definition time, so the helpers do not go through ``self`` for them.
        def every_bar(dt_to_use,
                      current_data=self.current_data,
                      handle_data=algo.event_manager.handle_data):
            # called every tick (minute or day).

            # process_minute_capital_changes is defined further down, inside
            # the ``with`` block, before the first call to every_bar.
            if dt_to_use in algo.capital_changes:
                process_minute_capital_changes(dt_to_use)

            self.simulation_dt = dt_to_use
            algo.on_dt_changed(dt_to_use)

            blotter = algo.blotter
            perf_tracker = algo.perf_tracker

            # handle any transactions and commissions coming out new orders
            # placed in the last bar
            new_transactions, new_commissions, closed_orders = \
                blotter.get_transactions(current_data)

            blotter.prune_orders(closed_orders)

            for transaction in new_transactions:
                perf_tracker.process_transaction(transaction)

                # since this order was modified, record it
                order = blotter.orders[transaction.order_id]
                perf_tracker.process_order(order)

            if new_commissions:
                for commission in new_commissions:
                    perf_tracker.process_commission(commission)

            # Run the user's per-bar logic only after fills from the previous
            # bar have been recorded.
            handle_data(algo, current_data, dt_to_use)

            # grab any new orders from the blotter, then clear the list.
            # this includes cancelled orders.
            new_orders = blotter.new_orders
            blotter.new_orders = []

            # if we have any new orders, record them so that we know
            # in what perf period they were placed.
            if new_orders:
                for new_order in new_orders:
                    perf_tracker.process_order(new_order)

            # Flag the algo's cached portfolio/account/performance views as
            # stale after this bar.
            algo.portfolio_needs_update = True
            algo.account_needs_update = True
            algo.performance_needs_update = True

        def once_a_day(midnight_dt,
                       current_data=self.current_data,
                       data_portal=self.data_portal):
            # called on every DAY_START action.

            perf_tracker = algo.perf_tracker

            if midnight_dt in algo.capital_changes:
                # process any capital changes that came overnight
                change = algo.capital_changes[midnight_dt]
                log.info('Processing capital change of %s at %s' %
                         (change, midnight_dt))
                perf_tracker.process_capital_changes(change, is_interday=True)

            # Get the positions before updating the date so that prices are
            # fetched for trading close instead of midnight
            positions = algo.perf_tracker.position_tracker.positions
            position_assets = algo.asset_finder.retrieve_all(positions)

            # set all the timestamps
            self.simulation_dt = midnight_dt
            algo.on_dt_changed(midnight_dt)

            # we want to wait until the clock rolls over to the next day
            # before cleaning up expired assets.
            self._cleanup_expired_assets(midnight_dt, position_assets)

            # handle any splits that impact any positions or any open orders.
            assets_we_care_about = \
                viewkeys(perf_tracker.position_tracker.positions) | \
                viewkeys(algo.blotter.open_orders)

            if assets_we_care_about:
                splits = data_portal.get_splits(assets_we_care_about,
                                                midnight_dt)
                if splits:
                    algo.blotter.process_splits(splits)
                    perf_tracker.position_tracker.handle_splits(splits)

            # call before trading start
            algo.before_trading_start(current_data)

        def handle_benchmark(date, benchmark_source=self.benchmark_source):
            # Store the benchmark value for ``date`` on the perf tracker.
            algo.perf_tracker.all_benchmark_returns[date] = \
                benchmark_source.get_value(date)

        def on_exit():
            # Remove references to algo, data portal, et al to break cycles
            # and ensure deterministic cleanup of these objects when the
            # simulation finishes.
            self.algo = None
            self.benchmark_source = self.current_data = self.data_portal = None

        with ExitStack() as stack:
            # Registered first, therefore run last when the stack unwinds.
            stack.callback(on_exit)
            stack.enter_context(self.processor)
            stack.enter_context(ZiplineAPI(self.algo))

            # Pick the minute-mode or no-op variants of the two helpers once,
            # instead of testing the data frequency on every clock event.
            if algo.data_frequency == 'minute':

                def execute_order_cancellation_policy():
                    algo.blotter.execute_cancel_policy(DAY_END)

                def process_minute_capital_changes(dt):
                    # If we are running daily emission, prices won't
                    # necessarily be synced at the end of every minute, and we
                    # need the up-to-date prices for capital change
                    # calculations. We want to sync the prices as of the
                    # last market minute, and this is okay from a data portal
                    # perspective as we have technically not "advanced" to the
                    # current dt yet.
                    algo.perf_tracker.position_tracker.sync_last_sale_prices(
                        self.algo.trading_schedule.previous_execution_minute(
                            dt), False, self.data_portal)

                    # process any capital changes that came between the last
                    # and current minutes
                    change = algo.capital_changes[dt]
                    log.info('Processing capital change of %s at %s' %
                             (change, dt))
                    algo.perf_tracker.process_capital_changes(
                        change, is_interday=False)
            else:

                def execute_order_cancellation_policy():
                    pass

                def process_minute_capital_changes(dt):
                    pass

            # Main dispatch loop over the simulation clock.
            for dt, action in self.clock:
                if action == BAR:
                    every_bar(dt)
                elif action == DAY_START:
                    once_a_day(dt)
                elif action == DAY_END:
                    # End of the day.
                    if algo.perf_tracker.emission_rate == 'daily':
                        handle_benchmark(normalize_date(dt))
                    execute_order_cancellation_policy()

                    yield self._get_daily_message(dt, algo, algo.perf_tracker)
                elif action == MINUTE_END:
                    handle_benchmark(dt)
                    minute_msg = \
                        self._get_minute_message(dt, algo, algo.perf_tracker)

                    yield minute_msg

        # Runs after the ExitStack has unwound (on_exit has already cleared
        # self.algo), hence the use of the local ``algo`` reference.
        risk_message = algo.perf_tracker.handle_simulation_end()
        yield risk_message
Example #22
0
 def __init__(self):
     """Run the parent initialiser, then store a new ``ExitStack`` on ``self.stack``."""
     super(CleanableCM, self).__init__()
     self.stack = ExitStack()
Example #23
0
def open_files(files, **kwargs):
    """A plural form of :func:`open_file`.

    Each element of ``files`` is passed to :func:`open_file` together with
    ``kwargs``; the result is entered on a shared ``ExitStack``. One list
    holding the entered values, in input order, is yielded, and every entered
    context is exited when the generator finishes or is closed.
    """
    with ExitStack() as stack:
        opened = []
        for f in files:
            opened.append(stack.enter_context(open_file(f, **kwargs)))
        yield opened
Example #24
0
class ZiplineTestCase(with_metaclass(FinalMeta, TestCase)):
    """
    Shared extensions to core unittest.TestCase.

    Overrides the default unittest setUp/tearDown functions with versions that
    use ExitStack to correctly clean up resources, even in the face of
    exceptions that occur during setUp/setUpClass.

    Subclasses **should not override setUp or setUpClass**!

    Instead, they should implement `init_instance_fixtures` for per-test-method
    resources, and `init_class_fixtures` for per-class resources.

    Resources that need to be cleaned up should be registered using
    either `enter_{class,instance}_context` or `add_{class,instance}_callback`.
    """
    # True only while ``setUp`` is running; the class-level registration
    # methods use it to reject calls made from per-instance setup.
    _in_setup = False

    @final
    @classmethod
    def setUpClass(cls):
        """
        Create the class-level ExitStack and run ``init_class_fixtures``.

        If fixture initialisation raises (including the assertion that the
        base ``init_class_fixtures`` was reached), everything registered so
        far is torn down before the exception propagates.
        """
        cls._class_teardown_stack = ExitStack()
        try:
            cls._base_init_fixtures_was_called = False
            cls.init_class_fixtures()
            assert cls._base_init_fixtures_was_called, (
                "ZiplineTestCase.init_class_fixtures() was not called.\n"
                "This probably means that you overrode init_class_fixtures"
                " without calling super()."
            )
        except:
            # Deliberately bare: clean up on any failure, then re-raise.
            cls.tearDownClass()
            raise

    @classmethod
    def init_class_fixtures(cls):
        """
        Override and implement this classmethod to register resources that
        should be created and/or torn down on a per-class basis.

        Subclass implementations of this should always invoke this with super()
        to ensure that fixture mixins work properly.

        Raises
        ------
        ValueError
            If called while ``setUp`` is running, i.e. from
            ``init_instance_fixtures``.
        """
        if cls._in_setup:
            # The two sentences are separated by a newline, matching the
            # other misuse messages in this class.
            raise ValueError(
                'Called init_class_fixtures from init_instance_fixtures.'
                '\nDid you write super(..., self).init_class_fixtures() instead'
                ' of super(..., self).init_instance_fixtures()?',
            )
        cls._base_init_fixtures_was_called = True

    @final
    @classmethod
    def tearDownClass(cls):
        """Close the class-level ExitStack, running everything registered on it."""
        cls._class_teardown_stack.close()

    @final
    @classmethod
    def enter_class_context(cls, context_manager):
        """
        Enter a context manager to be exited during the tearDownClass

        Returns the value produced by entering ``context_manager``.

        Raises
        ------
        ValueError
            If called while ``setUp`` is running.
        """
        if cls._in_setup:
            raise ValueError(
                'Attempted to enter a class context in init_instance_fixtures.'
                '\nDid you mean to call enter_instance_context?',
            )
        return cls._class_teardown_stack.enter_context(context_manager)

    @final
    @classmethod
    def add_class_callback(cls, callback):
        """
        Register a callback to be executed during tearDownClass.

        Parameters
        ----------
        callback : callable
            The callback to invoke at the end of the test suite.

        Raises
        ------
        ValueError
            If called while ``setUp`` is running.
        """
        if cls._in_setup:
            raise ValueError(
                'Attempted to add a class callback in init_instance_fixtures.'
                '\nDid you mean to call add_instance_callback?',
            )
        return cls._class_teardown_stack.callback(callback)

    @final
    def setUp(self):
        """
        Create the per-test ExitStack and run ``init_instance_fixtures``.

        ``_in_setup`` is set on the test's class for the duration of the call
        and always reset afterwards. On failure, everything registered so far
        is torn down before the exception propagates.
        """
        type(self)._in_setup = True
        self._instance_teardown_stack = ExitStack()
        try:
            self._init_instance_fixtures_was_called = False
            self.init_instance_fixtures()
            assert self._init_instance_fixtures_was_called, (
                "ZiplineTestCase.init_instance_fixtures() was not"
                " called.\n"
                "This probably means that you overrode"
                " init_instance_fixtures without calling super()."
            )
        except:
            # Deliberately bare: clean up on any failure, then re-raise.
            self.tearDown()
            raise
        finally:
            type(self)._in_setup = False

    def init_instance_fixtures(self):
        """
        Override to register per-test resources.

        Overrides must call this base implementation via super(); ``setUp``
        asserts that the flag set here has been set.
        """
        self._init_instance_fixtures_was_called = True

    @final
    def tearDown(self):
        """Close the per-test ExitStack, running everything registered on it."""
        self._instance_teardown_stack.close()

    @final
    def enter_instance_context(self, context_manager):
        """
        Enter a context manager that should be exited during tearDown.

        Returns the value produced by entering ``context_manager``.
        """
        return self._instance_teardown_stack.enter_context(context_manager)

    @final
    def add_instance_callback(self, callback):
        """
        Register a callback to be executed during tearDown.

        Parameters
        ----------
        callback : callable
            The callback to invoke at the end of each test.
        """
        return self._instance_teardown_stack.callback(callback)
Example #25
0
    def transform(self, stream_in):
        """
        Main generator work loop.

        Consumes ``(date, snapshot)`` pairs from ``stream_in``. Dates before
        ``self.algo_start`` are treated as warm-up: their events update
        internal state but nothing is yielded. For later dates, every message
        returned by ``self._process_snapshot`` is yielded, and
        ``before_trading_start`` is scheduled for the next session where
        applicable. The final item yielded is the return value of
        ``perf_tracker.handle_simulation_end()``.
        """
        # Initialize the mkt_close
        mkt_open = self.algo.perf_tracker.market_open
        mkt_close = self.algo.perf_tracker.market_close

        # inject the current algo
        # snapshot time to any log record generated.
        # ``with ... as`` drives a context manager (enter on the way in, exit
        # on the way out); the ExitStack lets both managers below share one
        # ``with`` block.

        with ExitStack() as stack:
            stack.enter_context(self.processor)
            stack.enter_context(ZiplineAPI(self.algo))

            data_frequency = self.sim_params.data_frequency
            self._call_before_trading_start(mkt_open)

            for date, snapshot in stream_in:
                # snapshot is the iterable of events delivered for this date.

                # Main loop: advance the simulation one date at a time.
                self.simulation_dt = date  # current simulation date
                self.on_dt_changed(date)

                # If we're still in the warmup period.  Use the event to
                # update our universe, but don't yield any perf messages,
                # and don't send a snapshot to handle_data.
                # I.e. check whether the simulation start time has been
                # reached; only from then on is the snapshot processed in full.
                if date < self.algo_start:
                    for event in snapshot:
                        if event.type == DATASOURCE_TYPE.SPLIT:
                            self.algo.blotter.process_split(event)
                        elif event.type == DATASOURCE_TYPE.TRADE:
                            self.update_universe(event)
                            self.algo.perf_tracker.process_trade(event)
                        elif event.type == DATASOURCE_TYPE.CUSTOM:
                            self.update_universe(event)
                    if self.algo.history_container:
                        self.algo.history_container.update(
                            self.current_data, date)
                else:

                    # Regular processing of this date's snapshot.
                    messages = self._process_snapshot(
                        date,
                        snapshot,
                        self.algo.instant_fill,
                    )
                    # Perf messages are only emitted if the snapshot contained
                    # a benchmark event.
                    for message in messages:
                        yield message

                    # When emitting minutely, we need to call
                    # before_trading_start before the next trading day begins
                    if date == mkt_close:
                        if mkt_close <= self.algo.perf_tracker.last_close:
                            # Evaluated before mkt_close is advanced below.
                            before_last_close = \
                                mkt_close < self.algo.perf_tracker.last_close
                            try:
                                mkt_open, mkt_close = \
                                    self.env.next_open_and_close(mkt_close)

                            except NoFurtherDataError:
                                # If at the end of backtest history,
                                # skip advancing market close.
                                pass

                            if before_last_close:
                                self._call_before_trading_start(mkt_open)

                    elif data_frequency == 'daily':
                        next_day = self.env.next_trading_day(date)

                        if next_day is not None and \
                           next_day < self.algo.perf_tracker.last_close:
                            self._call_before_trading_start(
                                next_day
                            )  # only when a next day exists and it is before last_close
                    # Flag the algo's cached portfolio/account/performance
                    # views as stale after this snapshot.
                    self.algo.portfolio_needs_update = True
                    self.algo.account_needs_update = True
                    self.algo.performance_needs_update = True
            # Still inside the ExitStack: emitted once stream_in is exhausted.
            risk_message = self.algo.perf_tracker.handle_simulation_end()
            yield risk_message
Example #26
0
def dedup_tracked1(sess, tt, ofile_reserved, query, fs, skipped):
    """Hash candidate inodes in three passes and deduplicate identical files.

    For every ``comm1`` group in ``query`` the inodes are mini-hashed; for
    each nested ``comm2`` group they are fiemap-hashed; for the single
    ``comm3`` group under a ``comm2`` (if any) all inodes are opened
    read-write, hashed with SHA-1, compared, and cloned onto the first file of
    each identical set via ``clone_data``. Successful clones are recorded as
    ``DedupEvent`` / ``DedupEventInode`` rows and committed.

    Args:
        sess: database session; used for ``execute``, ``flush``, ``delete``,
            ``add`` and ``commit`` (presumably a SQLAlchemy session on
            SQLite -- confirm).
        tt: progress/notification object providing ``update``, ``notify`` and
            ``format``.
        ofile_reserved: number of file descriptors to keep in reserve when
            computing the open-files requirement.
        query: iterable of ``comm1`` groups to process.
        fs: value stored on each ``DedupEvent`` created.
        skipped: list that receives inodes which could not be processed;
            mutated in place.

    Returns:
        None. Results are reported through ``tt``, ``skipped`` and the
        database session.
    """
    # Running totals of size * (inode_count - 1) per pass, reported at the end.
    space_gain1 = space_gain2 = space_gain3 = 0
    ofile_soft, ofile_hard = resource.getrlimit(resource.RLIMIT_OFILE)

    # Hopefully close any files we left around
    gc.collect()

    # The log can cause frequent commits, we don't mind losing them in
    # a crash (no need for durability). SQLite is in WAL mode, so this pragma
    # should disable most commit-time fsync calls without compromising
    # consistency.
    sess.execute('PRAGMA synchronous=NORMAL;')

    for comm1 in query:
        # Flush periodically so the session's identity map stays small.
        if len(sess.identity_map) > 300:
            sess.flush()

        space_gain1 += comm1.size * (comm1.inode_count - 1)
        tt.update(comm1=comm1)
        for inode in comm1.inodes:
            # XXX Need to cope with deleted inodes.
            # We cannot find them in the search-new pass, not without doing
            # some tracking of directory modifications to poke updated
            # directories to find removed elements.

            # rehash everytime for now
            # I don't know enough about how inode transaction numbers are
            # updated (as opposed to extent updates) to be able to actually
            # cache the result
            try:
                path = lookup_ino_path_one(inode.vol.fd, inode.ino)
            except IOError as e:
                if e.errno != errno.ENOENT:
                    raise
                # We have a stale record for a removed inode
                # XXX If an inode number is reused and the second instance
                # is below the size cutoff, we won't update the .size
                # attribute and we won't get an IOError to notify us
                # either.  Inode reuse does happen (with and without
                # inode_cache), so this branch isn't enough to rid us of
                # all stale entries.  We can also get into trouble with
                # regular file inodes being replaced by some other kind of
                # inode.
                sess.delete(inode)
                continue
            with closing(fopenat(inode.vol.fd, path)) as rfile:
                inode.mini_hash_from_file(rfile)

        for comm2 in comm1.comm2:
            space_gain2 += comm2.size * (comm2.inode_count - 1)
            tt.update(comm2=comm2)
            for inode in comm2.inodes:
                try:
                    path = lookup_ino_path_one(inode.vol.fd, inode.ino)
                except IOError as e:
                    if e.errno != errno.ENOENT:
                        raise
                    # Stale record for a removed inode, as in the first pass.
                    sess.delete(inode)
                    continue
                with closing(fopenat(inode.vol.fd, path)) as rfile:
                    inode.fiemap_hash_from_file(rfile)

            if not comm2.comm3:
                continue

            # Exactly one comm3 group is expected here; the unpacking raises
            # ValueError otherwise.
            comm3, = comm2.comm3
            count3 = comm3.inode_count
            space_gain3 += comm3.size * (count3 - 1)
            tt.update(comm3=comm3)
            # Per-group bookkeeping, keyed by file descriptor.
            files = []
            fds = []
            fd_names = {}
            fd_inodes = {}
            by_hash = collections.defaultdict(list)

            # XXX I have no justification for doubling count3
            ofile_req = 2 * count3 + ofile_reserved
            if ofile_req > ofile_soft:
                if ofile_req <= ofile_hard:
                    # Raise the soft limit just enough for this group.
                    resource.setrlimit(resource.RLIMIT_OFILE,
                                       (ofile_req, ofile_hard))
                    ofile_soft = ofile_req
                else:
                    tt.notify(
                        'Too many duplicates (%d at size %d), '
                        'would bring us over the open files limit (%d, %d).' %
                        (count3, comm3.size, ofile_soft, ofile_hard))
                    for inode in comm3.inodes:
                        if inode.has_updates:
                            skipped.append(inode)
                    continue

            for inode in comm3.inodes:
                # Open everything rw, we can't pick one for the source side
                # yet because the crypto hash might eliminate it.
                # We may also want to defragment the source.
                try:
                    path = lookup_ino_path_one(inode.vol.fd, inode.ino)
                except IOError as e:
                    if e.errno == errno.ENOENT:
                        sess.delete(inode)
                        continue
                    raise
                try:
                    afile = fopenat_rw(inode.vol.fd, path)
                except IOError as e:
                    if e.errno == errno.ETXTBSY:
                        # The file contains the image of a running process,
                        # we can't open it in write mode.
                        tt.notify('File %r is busy, skipping' % path)
                        skipped.append(inode)
                        continue
                    elif e.errno == errno.EACCES:
                        # Could be SELinux or immutability
                        tt.notify('Access denied on %r, skipping' % path)
                        skipped.append(inode)
                        continue
                    elif e.errno == errno.ENOENT:
                        # The file was moved or unlinked by a racing process
                        tt.notify('File %r may have moved, skipping' % path)
                        skipped.append(inode)
                        continue
                    raise

                # It's not completely guaranteed we have the right inode,
                # there may still be race conditions at this point.
                # Gets re-checked below (tell and fstat).
                fd = afile.fileno()
                fd_inodes[fd] = inode
                fd_names[fd] = path
                files.append(afile)
                fds.append(fd)

            # NOTE(review): the files opened above are only registered for
            # closing here; an exception raised part-way through the loop
            # above leaves the already-opened ones to the garbage collector.
            with ExitStack() as stack:
                for afile in files:
                    stack.enter_context(closing(afile))
                # Enter this context last
                immutability = stack.enter_context(ImmutableFDs(fds))

                for afile in files:
                    fd = afile.fileno()
                    inode = fd_inodes[fd]
                    if fd in immutability.fds_in_write_use:
                        tt.notify('File %r is in use, skipping' % fd_names[fd])
                        skipped.append(inode)
                        continue
                    # Hash the whole file in BUFSIZE chunks; the read position
                    # afterwards doubles as the size check below.
                    hasher = hashlib.sha1()
                    for buf in iter(lambda: afile.read(BUFSIZE), b''):
                        hasher.update(buf)

                    # Gets rid of a race condition
                    st = os.fstat(fd)
                    if st.st_ino != inode.ino:
                        skipped.append(inode)
                        continue
                    if st.st_dev != inode.vol.st_dev:
                        skipped.append(inode)
                        continue

                    size = afile.tell()
                    if size != comm3.size:
                        if size < inode.vol.size_cutoff:
                            # if we didn't delete this inode, it would cause
                            # spurious comm groups in all future invocations.
                            sess.delete(inode)
                        else:
                            skipped.append(inode)
                        continue

                    by_hash[hasher.digest()].append(afile)

                for fileset in by_hash.itervalues():
                    if len(fileset) < 2:
                        continue
                    # The first file of each identical set is the clone source.
                    sfile = fileset[0]
                    sfd = sfile.fileno()
                    # Commented out, defragmentation can unshare extents.
                    # It can also disable compression as a side-effect.
                    if False:
                        defragment(sfd)
                    dfiles = fileset[1:]
                    dfiles_successful = []
                    for dfile in dfiles:
                        dfd = dfile.fileno()
                        sname = fd_names[sfd]
                        dname = fd_names[dfd]
                        if not cmp_files(sfile, dfile):
                            # Probably a bug since we just used a crypto hash
                            tt.notify('Files differ: %r %r' % (sname, dname))
                            assert False, (sname, dname)
                            # Only reached when assertions are disabled (-O).
                            continue
                        if clone_data(dest=dfd, src=sfd, check_first=True):
                            tt.notify('Deduplicated: %r %r' % (sname, dname))
                            dfiles_successful.append(dfile)
                        else:
                            tt.notify(
                                'Did not deduplicate (same extents): %r %r' %
                                (sname, dname))
                    if dfiles_successful:
                        # Record one event per set, linked to the source and
                        # every destination that was actually cloned.
                        evt = DedupEvent(fs=fs,
                                         item_size=comm3.size,
                                         created=system_now())
                        sess.add(evt)
                        for afile in [sfile] + dfiles_successful:
                            inode = fd_inodes[afile.fileno()]
                            evti = DedupEventInode(event=evt,
                                                   ino=inode.ino,
                                                   vol=inode.vol)
                            sess.add(evti)
                        sess.commit()

    tt.format(None)
    tt.notify('Potential space gain: pass 1 %d, pass 2 %d pass 3 %d' %
              (space_gain1, space_gain2, space_gain3))
    # Restore fsync so that the final commit (in dedup_tracked)
    # will be durable.
    sess.commit()
    sess.execute('PRAGMA synchronous=FULL;')
Example #27
0
def run_test(in_file, test_spec, global_cfg):
    """Run a single tavern test

    A tavern test is a sequence of stages (log in, create, update, delete,
    etc). Stages run in order inside one ``ExitStack`` in which every session
    returned by ``get_extra_sessions`` has been entered.

    The test starts from a shallow copy of the global configuration. Values
    from included files are merged into its ``"variables"`` entry, and the
    ``"tavern"`` variable is set to a Box exposing the current environment
    variables as ``env_vars``.

    Args:
        in_file (str): filename containing this test
        test_spec (dict): The specification for this test; its ``"stages"``
            entry is replaced with the resolved stage list
        global_cfg (dict): Any global configuration for this test

    Returns:
        None

    Raises:
        exceptions.DuplicateStageDefinitionError: if two included stages
            share the same id
        exceptions.TavernException: re-raised from ``run_stage`` after the
            failing ``stage`` and the ``test_block_config`` have been
            attached to the exception
    """

    # pylint: disable=too-many-locals

    # NOTE(review): dict() is a shallow copy, so when global_cfg already has a
    # "variables" dict it is shared with (and mutated through) this test's
    # config -- confirm whether that is intended.
    test_block_config = dict(global_cfg)
    test_block_config.setdefault("variables", {})

    tavern_box = Box({
        "env_vars": dict(os.environ),
    })
    test_block_config["variables"]["tavern"] = tavern_box

    # Nothing to run for an empty block
    if not test_spec:
        logger.warning("Empty test block in %s", in_file)
        return

    # Gather variables and reusable stages from included files
    available_stages = {}
    for included in test_spec.get("includes") or []:
        if "variables" in included:
            formatted_include = format_keys(included["variables"],
                                            {"tavern": tavern_box})
            test_block_config["variables"].update(formatted_include)

        if "stages" in included:
            for included_stage in included["stages"]:
                if included_stage["id"] in available_stages:
                    raise exceptions.DuplicateStageDefinitionError(
                        "Stage with specified id already defined: {}".
                        format(included_stage["id"]))
                available_stages[included_stage["id"]] = included_stage

    test_block_name = test_spec["test_name"]

    # Strictness configured globally; every stage starts from this value
    default_strictness = test_block_config["strict"]

    logger.info("Running test : %s", test_block_name)

    with ExitStack() as stack:
        test_spec["stages"] = _resolve_test_stages(test_spec, available_stages)
        sessions = get_extra_sessions(test_spec, test_block_config)

        for name, session in sessions.items():
            logger.debug("Entering context for %s", name)
            stack.enter_context(session)

        # Stages are executed in the order they were resolved
        for stage in test_spec["stages"]:
            if stage.get('skip'):
                continue

            test_block_config["strict"] = default_strictness

            # Precedence: stage's response block, then the test, then the
            # default. This only looks at the 'response' block, which is the
            # hardcoded location for any HTTP request at the moment.
            response_block = stage.get("response", {})
            if response_block:
                stage_strictness = response_block.get("strict", None)
                if stage_strictness is None:
                    stage_strictness = test_spec.get("strict", None)
                if stage_strictness is None:
                    stage_strictness = default_strictness

                logger.debug("Strict key checking for this stage is '%s'",
                             stage_strictness)

                test_block_config["strict"] = stage_strictness
            elif default_strictness:
                logger.debug("Default strictness '%s' ignored for this stage",
                             default_strictness)

            try:
                run_stage(sessions, stage, tavern_box, test_block_config)
            except exceptions.TavernException as e:
                # Attach context for whoever reports the failure
                e.stage = stage
                e.test_block_config = test_block_config
                raise

            if stage.get('only'):
                break
Example #28
0
def do_dedup(sess, tt, chunk):
    """Hash a group of candidate inodes and clone the ones that match.

    Every inode in ``chunk`` is opened read-write, hashed with SHA-1 and
    grouped by digest.  Within each group of two or more files the first
    file is the clone source for the others; when at least one clone
    succeeds a ``DedupEvent`` (plus one ``DedupEventInode`` per file
    involved) is added to ``sess`` and committed.

    Inodes that cannot be processed right now are appended to the
    ``skipped`` list defined outside this function.  Inodes whose path can
    no longer be resolved, or whose size changed to something below their
    volume's ``size_cutoff``, are deleted from ``sess``.

    Parameters
    ----------
    sess
        Database session; ``delete``, ``add`` and ``commit`` are called
        on it.
    tt
        Progress reporter; only ``tt.notify(message)`` is used.
    chunk
        Sequence of inode records.  The attributes ``vol``, ``ino``,
        ``size`` and ``has_updates`` are read.

    Raises
    ------
    IOError
        Any error from resolving or opening a file other than the
        ENOENT / ETXTBSY / EACCES cases handled below.
    """
    global ofile_soft
    global ofile_hard
    global ofile_reserved
    global fs

    files = []
    fds = []
    fd_names = {}
    fd_inodes = {}
    by_hash = collections.defaultdict(list)

    # Errors from opening a file that only mean "leave this one alone".
    open_skip_messages = {
        # The file contains the image of a running process,
        # we can't open it in write mode.
        errno.ETXTBSY: 'File %r is busy, skipping',
        # Could be SELinux or immutability
        errno.EACCES: 'Access denied on %r, skipping',
        # The file was moved or unlinked by a racing process
        errno.ENOENT: 'File %r may have moved, skipping',
    }

    # XXX I have no justification for doubling the file count
    ofile_req = 2 * len(chunk) + ofile_reserved
    if ofile_req > ofile_soft:
        if ofile_req <= ofile_hard:
            resource.setrlimit(resource.RLIMIT_OFILE, (ofile_req, ofile_hard))
            ofile_soft = ofile_req
        else:
            tt.notify('Too many duplicates (%d at size %d), '
                      'would bring us over the open files limit (%d, %d).' %
                      (len(chunk), chunk[0].size if chunk else 0,
                       ofile_soft, ofile_hard))
            for inode in chunk:
                if inode.has_updates:
                    skipped.append(inode)
            # Opening the files anyway would exhaust the descriptor limit
            # that was just reported, so give up on this chunk.
            return

    with ExitStack() as stack:
        for inode in chunk:
            # Open everything rw, we can't pick one for the source side
            # yet because the crypto hash might eliminate it.
            # We may also want to defragment the source.
            try:
                path = lookup_ino_path_one(inode.vol.fd, inode.ino)
            except IOError as e:
                if e.errno == errno.ENOENT:
                    sess.delete(inode)
                    continue
                raise
            try:
                afile = fopenat_rw(inode.vol.fd, path)
            except IOError as e:
                message = open_skip_messages.get(e.errno)
                if message is None:
                    raise
                tt.notify(message % path)
                skipped.append(inode)
                continue

            # Register the file for closing immediately, so that an error
            # while opening a later file does not leak this descriptor.
            stack.enter_context(closing(afile))

            # It's not completely guaranteed we have the right inode,
            # there may still be race conditions at this point.
            # Gets re-checked below (tell and fstat).
            fd = afile.fileno()
            fd_inodes[fd] = inode
            fd_names[fd] = path
            files.append(afile)
            fds.append(fd)

        # Enter this context last
        immutability = stack.enter_context(ImmutableFDs(fds))

        for afile in files:
            fd = afile.fileno()
            inode = fd_inodes[fd]
            if fd in immutability.fds_in_write_use:
                tt.notify('File %r is in use, skipping' % fd_names[fd])
                skipped.append(inode)
                continue
            hasher = hashlib.sha1()
            for buf in iter(lambda: afile.read(BUFSIZE), b''):
                hasher.update(buf)

            # Gets rid of a race condition: make sure the descriptor still
            # refers to the inode and device that were recorded.
            st = os.fstat(fd)
            if st.st_ino != inode.ino or st.st_dev != inode.vol.st_dev:
                skipped.append(inode)
                continue

            # The whole file was just read, so the position is its size.
            size = afile.tell()
            if size != inode.size:
                if size < inode.vol.size_cutoff:
                    # if we didn't delete this inode, it would cause
                    # spurious comm groups in all future invocations.
                    sess.delete(inode)
                else:
                    skipped.append(inode)
                continue

            by_hash[hasher.digest()].append(afile)

        # values() rather than itervalues() so this runs on Python 2 and 3.
        for fileset in by_hash.values():
            if len(fileset) < 2:
                continue
            sfile = fileset[0]
            sfd = sfile.fileno()
            # Disabled on purpose: defragmentation can unshare extents.
            # It can also disable compression as a side-effect.
            if False:
                defragment(sfd)
            dfiles = fileset[1:]
            dfiles_successful = []
            for dfile in dfiles:
                dfd = dfile.fileno()
                sname = fd_names[sfd]
                dname = fd_names[dfd]
                if not cmp_files(sfile, dfile):
                    # Probably a bug since we just used a crypto hash
                    tt.notify('Files differ: %r %r' % (sname, dname))
                    assert False, (sname, dname)
                    # Only reached when assertions are disabled (-O).
                    continue
                if clone_data(dest=dfd, src=sfd, check_first=True):
                    tt.notify('Deduplicated: %r %r' % (sname, dname))
                    dfiles_successful.append(dfile)
                else:
                    tt.notify(
                        'Did not deduplicate (same extents): %r %i %r %i' %
                        (sname, fd_inodes[sfd].ino, dname, fd_inodes[dfd].ino))
            if dfiles_successful:
                # Use the source file's own record for the size rather than
                # whatever inode an earlier loop happened to leave bound.
                evt = DedupEvent(fs=fs,
                                 item_size=fd_inodes[sfd].size,
                                 created=system_now())
                sess.add(evt)
                for afile in [sfile] + dfiles_successful:
                    member = fd_inodes[afile.fileno()]
                    evti = DedupEventInode(event=evt,
                                           ino=member.ino,
                                           vol=member.vol)
                    sess.add(evti)
                sess.commit()
Example #29
0
    def ingest(name,
               environ=os.environ,
               timestamp=None,
               assets_versions=(),
               show_progress=False):
        """Run the ingest function registered for the bundle ``name``.

        Parameters
        ----------
        name : str
            Key of the bundle in the ``bundles`` registry.
        environ : mapping, optional
            Environment variables used to resolve the data and cache
            paths. Defaults to ``os.environ``.
        timestamp : datetime, optional
            Timestamp identifying this ingestion. The current UTC time is
            used when omitted.
        assets_versions : Iterable[int], optional
            Asset db versions for which a downgraded copy of the freshly
            written asset db is produced.
        show_progress : bool, optional
            Forwarded to the bundle's ingest function.

        Raises
        ------
        UnknownBundle
            If ``name`` is not a registered bundle.
        ValueError
            If ``assets_versions`` is non-empty but the bundle does not
            create writers.
        """
        try:
            bundle = bundles[name]
        except KeyError:
            raise UnknownBundle(name)

        when = pd.Timestamp.utcnow() if timestamp is None else timestamp
        when = when.tz_convert('utc').tz_localize(None)
        timestr = to_bundle_ingest_dirname(when)
        cachepath = cache_path(name, environ=environ)
        pth.ensure_directory(pth.data_path([name, timestr], environ=environ))
        pth.ensure_directory(cachepath)

        # `clean_on_failure=False` keeps the cache directory from being
        # purged if the load fails in the middle.
        with dataframe_cache(cachepath, clean_on_failure=False) as cache:
            with ExitStack() as stack:
                if not bundle.create_writers:
                    if assets_versions:
                        raise ValueError('Need to ingest a bundle that creates '
                                         'writers in order to downgrade the assets'
                                         ' db.')
                    daily_writer = None
                    minute_writer = None
                    asset_writer = None
                    adjustment_writer = None
                else:
                    wd = stack.enter_context(
                        working_dir(pth.data_path([], environ=environ)))

                    daily_rel = daily_equity_relative(
                        name, timestr, environ=environ)
                    daily_bars_path = wd.ensure_dir(*daily_rel)
                    daily_writer = BcolzDailyBarWriter(
                        daily_bars_path,
                        bundle.calendar,
                        bundle.start_session,
                        bundle.end_session,
                    )
                    # An empty write makes sure the daily ctables exist
                    # before the SQLiteAdjustmentWriter is created below:
                    # that writer opens the daily ctables to compute the
                    # adjustment ratios for the dividends.
                    daily_writer.write(())

                    minute_rel = minute_equity_relative(
                        name, timestr, environ=environ)
                    minute_bars_path = wd.ensure_dir(*minute_rel)
                    minute_writer = BcolzMinuteBarWriter(
                        minute_bars_path,
                        bundle.calendar,
                        bundle.start_session,
                        bundle.end_session,
                        minutes_per_day=bundle.minutes_per_day,
                    )

                    asset_rel = asset_db_relative(
                        name, timestr, environ=environ)
                    assets_db_path = wd.getpath(*asset_rel)
                    asset_writer = AssetDBWriter(assets_db_path)

                    adjustment_rel = adjustment_db_relative(
                        name, timestr, environ=environ)
                    adjustment_db_path = wd.getpath(*adjustment_rel)
                    adjustment_writer = stack.enter_context(
                        SQLiteAdjustmentWriter(
                            adjustment_db_path,
                            BcolzDailyBarReader(daily_bars_path),
                            bundle.calendar.all_sessions,
                            overwrite=True,
                        ))

                bundle.ingest(
                    environ,
                    asset_writer,
                    minute_writer,
                    daily_writer,
                    adjustment_writer,
                    bundle.calendar,
                    bundle.start_session,
                    bundle.end_session,
                    cache,
                    show_progress,
                    pth.data_path([name, timestr], environ=environ),
                )

                # Highest requested version first. This loop only has work
                # to do when writers were created, so `wd` and
                # `assets_db_path` are bound whenever its body runs.
                for version in sorted(set(assets_versions), reverse=True):
                    versioned_rel = asset_db_relative(
                        name,
                        timestr,
                        environ=environ,
                        db_version=version,
                    )
                    version_path = wd.getpath(*versioned_rel)
                    with working_file(version_path) as wf:
                        shutil.copy2(assets_db_path, wf.path)
                        downgrade(wf.path, version)
Example #30
0
 def __call__(s, *args, **kwargs):
     """Run the task inside the Flask app, DB and plugin contexts.

     The contexts are entered in this order: app context, global DB
     connection, an optional test request context (when the task has a
     truthy ``request_context`` attribute), and the plugin context.  All
     of them are exited when the task finishes or when anything on the
     way raises.

     Raises
     ------
     ValueError
         If the plugin is given by name and ``plugin_engine.get_plugin``
         returns ``None`` for it.
     """
     # Everything runs inside the ``with`` so that a failure during setup
     # (argument unwrapping, plugin lookup, entering the plugin context,
     # ``clearCache``) still exits the contexts entered so far.
     with ExitStack() as stack:
         stack.enter_context(self.flask_app.app_context())
         stack.enter_context(DBMgr.getInstance().global_connection())
         if getattr(s, 'request_context', False):
             stack.enter_context(self.flask_app.test_request_context())
         args = _CelerySAWrapper.unwrap_args(args)
         kwargs = _CelerySAWrapper.unwrap_kwargs(kwargs)
         # The pop is evaluated unconditionally, so '__current_plugin__'
         # is removed from kwargs even when the task has a `plugin`
         # attribute.
         plugin = getattr(s, 'plugin', kwargs.pop('__current_plugin__', None))
         if isinstance(plugin, basestring):
             plugin_name = plugin
             plugin = plugin_engine.get_plugin(plugin)
             if plugin is None:
                 # Unwind without handing the error to the context
                 # managers; the stack is empty by the time the ``with``
                 # block exits.
                 stack.close()
                 raise ValueError('Plugin not active: ' + plugin_name)
         stack.enter_context(plugin_context(plugin))
         clearCache()
         return super(IndicoTask, s).__call__(*args, **kwargs)
Example #31
0
def compare_with_cpu_command(args):
    """Compare the optimizers of two configurations and write a CSV report.

    ``args.config`` is loaded under the ``'current'`` parameter scope and
    ``args.config2`` under the ``'cpu'`` scope.  The result of
    ``compare_optimizer`` is written to ``compare_with_cpu.csv`` inside
    ``args.outdir``.

    Returns
    -------
    bool
        Always ``True``.
    """
    configure_progress(os.path.join(args.outdir, 'progress.txt'))

    # Plain attribute holders for the loaded settings.
    class TrainConfig:
        pass

    class OptConfig:
        pass

    class MonConfig:
        pass

    def _load_config(scope_name, config_file):
        """Load ``config_file`` under ``scope_name``; return (config, parameters)."""
        with nn.parameter_scope(scope_name):
            loaded = load.load([config_file])
            scope_parameters = get_parameters(grad_only=False)

        cfg = TrainConfig()
        cfg.global_config = loaded.global_config
        cfg.training_config = loaded.training_config

        cfg.optimizers = OrderedDict()
        for opt_name, optimizer in loaded.optimizers.items():
            opt_entry = OptConfig()
            opt_entry.optimizer = optimizer
            opt_entry.data_iterator = None
            cfg.optimizers[opt_name] = opt_entry

        cfg.monitors = OrderedDict()
        for mon_name, monitor in loaded.monitors.items():
            mon_entry = MonConfig()
            mon_entry.monitor = monitor
            mon_entry.data_iterator = None
            cfg.monitors[mon_name] = mon_entry

        return cfg, scope_parameters

    # Load config with current context
    config, parameters = _load_config('current', args.config)
    # Load config with cpu context
    config_cpu, cpu_parameters = _load_config('cpu', args.config2)

    result_array = [['1-Correl']]

    # Profile Optimizer: every data iterator stays open for the comparison
    # and is released when the stack unwinds.
    with ExitStack() as stack:
        for cfg in (config, config_cpu):
            for opt_entry in cfg.optimizers.values():
                opt_entry.data_iterator = stack.enter_context(
                    opt_entry.optimizer.data_iterator())
        result_array = compare_optimizer(config, parameters, config_cpu,
                                         cpu_parameters, result_array)

    # Write profiling result
    import csv
    csv_path = args.outdir + os.sep + 'compare_with_cpu.csv'
    with open(csv_path, 'w') as f:
        csv.writer(f, lineterminator='\n').writerows(result_array)

    logger.log(99, 'Compare with CPU Completed.')
    progress(None)
    return True
Example #32
0
class ZiplineTestCase(with_metaclass(FinalMeta, TestCase)):
    """
    Shared extensions to core unittest.TestCase.

    Overrides the default unittest setUp/tearDown functions with versions that
    use ExitStack to correctly clean up resources, even in the face of
    exceptions that occur during setUp/setUpClass.

    Subclasses **should not override setUp or setUpClass**!

    Instead, they should implement `init_instance_fixtures` for per-test-method
    resources, and `init_class_fixtures` for per-class resources.

    Resources that need to be cleaned up should be registered using
    either `enter_{class,instance}_context` or `add_{class,instance}_callback`.
    """
    # True only while setUp is running; used to reject class-level fixture
    # registration from inside init_instance_fixtures.
    _in_setup = False

    @final
    @classmethod
    def setUpClass(cls):
        """
        Create the class teardown stack and run `init_class_fixtures`.

        If fixture initialisation fails for any reason, `tearDownClass` is
        invoked before the error is re-raised.
        """
        # Hold a set of all the "static" attributes on the class. These are
        # things that are not populated after the class was created like
        # methods or other class level attributes.
        cls._static_class_attributes = set(vars(cls))
        cls._class_teardown_stack = ExitStack()
        try:
            cls._base_init_fixtures_was_called = False
            cls.init_class_fixtures()
            assert cls._base_init_fixtures_was_called, (
                "ZiplineTestCase.init_class_fixtures() was not called.\n"
                "This probably means that you overrode init_class_fixtures"
                " without calling super()."
            )
        except BaseException:
            # Deliberately catches everything (including KeyboardInterrupt):
            # whatever was registered so far is torn down, then the error
            # is re-raised unchanged.
            cls.tearDownClass()
            raise

    @classmethod
    def init_class_fixtures(cls):
        """
        Override and implement this classmethod to register resources that
        should be created and/or torn down on a per-class basis.

        Subclass implementations of this should always invoke this with super()
        to ensure that fixture mixins work properly.
        """
        if cls._in_setup:
            raise ValueError(
                'Called init_class_fixtures from init_instance_fixtures.'
                ' Did you write super(..., self).init_class_fixtures() instead'
                ' of super(..., self).init_instance_fixtures()?',
            )
        cls._base_init_fixtures_was_called = True

    @final
    @classmethod
    def tearDownClass(cls):
        """
        Remove the attributes added since `setUpClass` and close the class
        teardown stack.
        """
        # We need to get this before it's deleted by the loop.
        stack = cls._class_teardown_stack
        try:
            for name in set(vars(cls)) - cls._static_class_attributes:
                # Remove all of the attributes that were added after the
                # class was constructed. This cleans up any large test data
                # that is class scoped while still allowing subclasses to
                # access class level attributes.
                delattr(cls, name)
        finally:
            # Registered resources are released even if attribute removal
            # fails.
            stack.close()

    @final
    @classmethod
    def enter_class_context(cls, context_manager):
        """
        Enter a context manager to be exited during the tearDownClass
        """
        if cls._in_setup:
            raise ValueError(
                'Attempted to enter a class context in init_instance_fixtures.'
                '\nDid you mean to call enter_instance_context?',
            )
        return cls._class_teardown_stack.enter_context(context_manager)

    @final
    @classmethod
    def add_class_callback(cls, callback):
        """
        Register a callback to be executed during tearDownClass.

        Parameters
        ----------
        callback : callable
            The callback to invoke at the end of the test suite.
        """
        if cls._in_setup:
            raise ValueError(
                'Attempted to add a class callback in init_instance_fixtures.'
                '\nDid you mean to call add_instance_callback?',
            )
        return cls._class_teardown_stack.callback(callback)

    @final
    def setUp(self):
        """
        Create the instance teardown stack and run `init_instance_fixtures`.

        If fixture initialisation fails for any reason, `tearDown` is
        invoked before the error is re-raised.
        """
        type(self)._in_setup = True
        # Snapshot of the attributes present before the fixtures run, so
        # that tearDown can remove everything added afterwards.
        self._pre_setup_attrs = set(vars(self))
        self._instance_teardown_stack = ExitStack()
        try:
            self._init_instance_fixtures_was_called = False
            self.init_instance_fixtures()
            assert self._init_instance_fixtures_was_called, (
                "ZiplineTestCase.init_instance_fixtures() was not"
                " called.\n"
                "This probably means that you overrode"
                " init_instance_fixtures without calling super()."
            )
        except BaseException:
            # Deliberately catches everything: clean up, then re-raise.
            self.tearDown()
            raise
        finally:
            type(self)._in_setup = False

    def init_instance_fixtures(self):
        """
        Override and implement this method to register resources that
        should be created and/or torn down for every test method.

        Subclass implementations must call this through super(); setUp
        asserts that the base implementation ran.
        """
        self._init_instance_fixtures_was_called = True

    @final
    def tearDown(self):
        """
        Remove the attributes added since `setUp` and close the instance
        teardown stack.
        """
        # We need to get this before it's deleted by the loop.
        stack = self._instance_teardown_stack
        try:
            for attr in set(vars(self)) - self._pre_setup_attrs:
                delattr(self, attr)
        finally:
            # Registered resources are released even if attribute removal
            # fails.
            stack.close()

    @final
    def enter_instance_context(self, context_manager):
        """
        Enter a context manager that should be exited during tearDown.
        """
        return self._instance_teardown_stack.enter_context(context_manager)

    @final
    def add_instance_callback(self, callback):
        """
        Register a callback to be executed during tearDown.

        Parameters
        ----------
        callback : callable
            The callback to invoke at the end of each test.
        """
        return self._instance_teardown_stack.callback(callback)
Example #33
0
import sys

from .program import TestProgram
from .is_standalone import is_standalone_use

from .utils import PY2, drop_into_debugger

if PY2:
    from contextlib2 import ExitStack
else:
    from contextlib import ExitStack

# Run the test program at import time. Any context entered on the stack
# stays active until TestProgram returns or raises.
with ExitStack() as stack:
    if '--debug' in sys.argv:
        # Only when --debug is on the command line: keep the
        # drop_into_debugger() context open for the whole run.
        # NOTE(review): presumably it opens a debugger on an uncaught
        # error -- confirm in .utils.
        stack.enter_context(drop_into_debugger())
    # NOTE(review): presumably records that this is not a standalone
    # invocation -- confirm in .is_standalone.
    is_standalone_use(False)
    TestProgram(module=None)