class CleanableCM(object):
    """Cleanable context manager (based on ExitStack).

    Subclasses implement :meth:`_enter` and register whatever they acquire on
    ``self.stack``. If ``_enter`` raises, everything already registered on
    ``self.stack`` is unwound before the error propagates. On normal exit,
    ``self.stack`` is unwound by :meth:`__exit__`.
    """

    def __init__(self):
        super(CleanableCM, self).__init__()
        self.stack = ExitStack()

    def _enter(self):
        """Acquire resources; must be overridden by subclasses.

        The return value becomes the target of ``with ... as``.
        """
        raise NotImplementedError

    @contextmanager
    def _cleanup_on_error(self):
        """Run ``self.__exit__`` if the wrapped block raises."""
        def _exit_without_suppressing(*exc_details):
            self.__exit__(*exc_details)
            # Never swallow an error raised during __enter__: a suppressed
            # setup failure would make __enter__ silently return None.
            return False

        with ExitStack() as stack:
            stack.push(_exit_without_suppressing)
            yield
            # Setup succeeded: disarm the error cleanup.
            stack.pop_all()

    def __enter__(self):
        with self._cleanup_on_error():
            self.stack.__enter__()
            return self._enter()

    def __exit__(self, exc_type, exc_value, traceback):
        # Return ExitStack's result so that a context manager registered on
        # the stack which suppresses the exception actually suppresses it.
        return self.stack.__exit__(exc_type, exc_value, traceback)
def __call__(s, *args, **kwargs):
    """Run the task inside the app, (optional) request and plugin contexts.

    ``s`` is the task instance; ``self`` (which provides ``flask_app``) is
    taken from the enclosing scope.

    All contexts are entered on a single ExitStack managed by ``with`` so that
    an error at any step (entering a context, unwrapping arguments, resolving
    the plugin, clearing the cache or running the task) unwinds everything
    entered so far instead of leaking it.

    Raises:
        ValueError: if the task names a plugin that is not active.
    """
    with ExitStack() as stack:
        stack.enter_context(self.flask_app.app_context())
        if getattr(s, 'request_context', False):
            stack.enter_context(self.flask_app.test_request_context(base_url=config.BASE_URL))
        args = _CelerySAWrapper.unwrap_args(args)
        kwargs = _CelerySAWrapper.unwrap_kwargs(kwargs)
        plugin = getattr(s, 'plugin', s.request.get('indico_plugin'))
        if isinstance(plugin, basestring):
            plugin_name = plugin
            plugin = plugin_engine.get_plugin(plugin)
            if plugin is None:
                # The enclosing ``with`` closes the stack on the way out.
                raise ValueError('Plugin not active: ' + plugin_name)
        stack.enter_context(plugin_context(plugin))
        clearCache()
        request_stats_request_started()
        return super(IndicoTask, s).__call__(*args, **kwargs)
def setUp(self):
    """Prepare per-test fixtures.

    Stores a fresh ``ExitStack`` as ``self._instance_teardown_stack`` and runs
    ``self.init_instance_fixtures()``. If that raises, or returns without
    setting ``self._init_instance_fixtures_was_called`` (an override forgot
    to call ``super()``), ``self.tearDown()`` runs and the error is re-raised.
    The missing-``super()`` check is an ``assert``, so ``python -O`` skips it.
    """
    self._instance_teardown_stack = ExitStack()
    missing_super_msg = (
        "ZiplineTestCase.init_instance_fixtures() was not"
        " called.\n"
        "This probably means that you overrode"
        " init_instance_fixtures without calling super()."
    )
    try:
        self._init_instance_fixtures_was_called = False
        self.init_instance_fixtures()
        assert self._init_instance_fixtures_was_called, missing_super_msg
    except:  # noqa: E722 -- deliberately catch everything, clean up, re-raise
        self.tearDown()
        raise
def shell_cmd(verbose, with_req_context):
    """Start an IPython shell with the Indico shell namespace preloaded.

    :param verbose: prepend the context summary lines to the banner.
    :param with_req_context: run the shell inside a test request context
                             (using ``config.BASE_URL``).
    Exits with status 1 if IPython is not installed.
    """
    try:
        from IPython.terminal.ipapp import TerminalIPythonApp
    except ImportError:
        click.echo(cformat('%{red!}You need to `pip install ipython` to use the Indico shell'))
        sys.exit(1)

    current_app.config['REPL'] = True  # disables e.g. memoize_request
    request_stats_request_started()
    extra_ns, info_lines = _make_shell_context()
    ready_line = cformat('%{yellow!}Indico v{} is ready for your commands').format(indico.__version__)
    banner = '\n'.join(info_lines + ['', ready_line]) if verbose else ready_line
    user_ns = current_app.make_shell_context()
    user_ns.update(extra_ns)
    clearCache()

    with ExitStack() as stack:
        if with_req_context:
            stack.enter_context(current_app.test_request_context(base_url=config.BASE_URL))
        shell_app = TerminalIPythonApp.instance(user_ns=user_ns, display_banner=False)
        shell_app.initialize(argv=[])
        shell_app.shell.show_banner(banner)
        shell_app.start()
def __init__(self):
    """Initialise plugin state and register the -P/-E/-R command-line options.

    ``-P/--pattern`` collects into ``self.patterns``; ``-E/--stderr`` and
    ``-R/--rerecord`` set ``self.stderr`` / ``self.record`` to True.
    """
    super(NosePlugin, self).__init__()
    self.patterns = []
    self.stderr = False
    self.record = False

    def _turn_on_stderr(_value):
        self.stderr = True

    def _turn_on_record(_value):
        self.record = True

    self.addArgument(self.patterns, 'P', 'pattern',
                     'Add a test matching pattern')
    self.addFlag(_turn_on_stderr, 'E', 'stderr',
                 'Enable stderr logging to sub-runners')
    self.addFlag(_turn_on_record, 'R', 'rerecord',
                 'Force re-recording of test responses. Requires Mailman to be running.')
    # Location of the recorded-responses file.
    self._data_path = os.path.join(TOPDIR, 'tests', 'data', 'tape.yaml')
    # Collects resources opened for the test run so they can be closed together.
    self._resources = ExitStack()
def setUp(self):
    """Prepare per-test fixtures while flagging the class as "in setup".

    ``type(self)._in_setup`` is True for the duration of this call (reset in
    ``finally``). The pre-existing instance attribute names are remembered in
    ``self._pre_setup_attrs``, and a fresh ``ExitStack`` is stored as
    ``self._instance_teardown_stack``. If ``init_instance_fixtures`` raises, or
    returns without setting ``self._init_instance_fixtures_was_called`` (a
    missing ``super()`` call), ``self.tearDown()`` runs and the error is
    re-raised. The latter check is an ``assert`` (skipped under ``python -O``).
    """
    owner = type(self)
    owner._in_setup = True
    self._pre_setup_attrs = set(vars(self))
    self._instance_teardown_stack = ExitStack()
    missing_super_msg = (
        "ZiplineTestCase.init_instance_fixtures() was not"
        " called.\n"
        "This probably means that you overrode"
        " init_instance_fixtures without calling super()."
    )
    try:
        self._init_instance_fixtures_was_called = False
        self.init_instance_fixtures()
        assert self._init_instance_fixtures_was_called, missing_super_msg
    except:  # noqa: E722 -- deliberately catch everything, clean up, re-raise
        self.tearDown()
        raise
    finally:
        owner._in_setup = False
def train_command(args):
    """Run the ``train`` command.

    Validates the out-of-core options on ``args``, loads the network
    configuration from ``args.config`` (parameters excluded), optionally
    loads training state from ``args.param``, builds the optimizer / monitor
    configuration and trains via ``_train``. When the computed
    ``max_iteration`` is not positive, parameters are only saved
    ("0 epoch learning").

    Args:
        args: parsed command-line namespace. Reads (at least) ``config``,
            ``param``, ``outdir``, ``context``, ``ooc_gpu_memory_size`` and
            ``ooc_window_length``; the two ``ooc_*`` fields are overwritten
            with the values returned by ``str_to_num``.

    Returns:
        False if an ``ooc_*`` option is negative or a dataset URI is empty;
        True otherwise. NOTE(review): True is returned even when ``_train``
        reports failure -- the outcome is only reported through
        ``callback.update_status('finished' / 'failed')``.
    """
    # Normalise the out-of-core options via str_to_num; negative values abort.
    if args.ooc_gpu_memory_size is not None:
        ooc_gpu_memory_size = str_to_num(args.ooc_gpu_memory_size)
        if ooc_gpu_memory_size < 0:
            logger.log(
                99, f'Fatal error. invalid ooc_gpu_memory_size [{args.ooc_gpu_memory_size}].'
            )
            return False
        args.ooc_gpu_memory_size = ooc_gpu_memory_size
    if args.ooc_window_length is not None:
        ooc_window_length = str_to_num(args.ooc_window_length)
        if ooc_window_length < 0:
            logger.log(
                99, f'Fatal error. invalid ooc_window_length [{args.ooc_window_length}].'
            )
            return False
        args.ooc_window_length = ooc_window_length

    callback.update_status(args)

    # Only a single process (or rank 0) writes progress.txt.
    if single_or_rankzero():
        configure_progress(os.path.join(args.outdir, 'progress.txt'))

    info = load.load([args.config], prepare_data_iterator=None,
                     exclude_parameter=True, context=args.context)

    # Check dataset uri is empty.
    dataset_error = False
    for dataset in info.datasets.values():
        if dataset.uri.strip() == '':
            dataset_error = True
    if dataset_error:
        logger.log(99, 'Fatal error. Dataset URI is empty.')
        return False

    # Plain attribute containers used to assemble the training configuration.
    class TrainConfig:
        pass
    config = TrainConfig()
    config.timelimit = -1
    if args.param:
        # The parameter file may also contain optimizer information, which
        # is needed to resume (recover) training.
        #load.load([args.param], parameter_only=True)
        load_train_state(args.param, info)

    config.timelimit = callback.get_timelimit(args)

    config.global_config = info.global_config
    config.training_config = info.training_config

    if single_or_rankzero():
        logger.log(99, 'Train with contexts {}'.format(available_contexts))

    # One entry per optimizer; data_iterators is filled in below.
    class OptConfig:
        pass
    config.optimizers = OrderedDict()
    for name, opt in info.optimizers.items():
        o = OptConfig()
        o.optimizer = opt
        o.data_iterators = []
        config.optimizers[name] = o

    # One entry per monitor; data_iterators is filled in below.
    class MonConfig:
        pass
    config.monitors = OrderedDict()
    for name, mon in info.monitors.items():
        m = MonConfig()
        m.monitor = mon
        m.data_iterators = []
        config.monitors[name] = m

    # Training
    comm = current_communicator()
    # Parses as `//= (comm.size if comm else 1)`: with a communicator, each
    # process runs 1/comm.size of the configured iterations per epoch.
    config.training_config.iter_per_epoch //= comm.size if comm else 1
    max_iteration = config.training_config.max_epoch * \
        config.training_config.iter_per_epoch

    # Module-level record of which network definition file to save with the
    # parameters (presumably read by _save_parameters -- verify).
    global _save_parameter_info
    _save_parameter_info = {}
    _, config_ext = os.path.splitext(args.config)
    if config_ext == '.prototxt' or config_ext == '.nntxt':
        _save_parameter_info['config'] = args.config
    elif config_ext == '.nnp':
        # Extract the network definition from the .nnp archive into outdir;
        # if the archive holds several, the last one listed wins.
        with zipfile.ZipFile(args.config, 'r') as nnp:
            for name in nnp.namelist():
                _, ext = os.path.splitext(name)
                if ext == '.nntxt' or ext == '.prototxt':
                    nnp.extract(name, args.outdir)
                    _save_parameter_info['config'] = os.path.join(
                        args.outdir, name)

    result = False
    restart = False
    if max_iteration > 0:
        # Seeded with the rank so that slicing is deterministic per process.
        rng = np.random.RandomState(comm.rank if comm else 0)
        # All data iterators are entered on this stack and closed when
        # training returns (or raises).
        with ExitStack() as stack:
            # Create data_iterator instance only once for each dataset in optimizers
            optimizer_data_iterators = {}
            for name, o in config.optimizers.items():
                for di in o.optimizer.data_iterators.values():
                    if di not in optimizer_data_iterators:
                        di_instance = stack.enter_context(di())
                        # With several processes, each one gets its own slice.
                        if comm and comm.size > 1:
                            di_instance = di_instance.slice(
                                rng, comm.size, comm.rank)
                        optimizer_data_iterators[di] = di_instance
                    else:
                        di_instance = optimizer_data_iterators[di]
                    o.data_iterators.append(di_instance)

            # Create data_iterator instance only once for each dataset in monitors
            monitor_data_iterators = {}
            for name, m in config.monitors.items():
                for di in m.monitor.data_iterators.values():
                    if di not in monitor_data_iterators:
                        di_instance = stack.enter_context(di())
                        if comm and comm.size > 1:
                            di_instance = di_instance.slice(
                                rng, comm.size, comm.rank)
                        monitor_data_iterators[di] = di_instance
                    else:
                        di_instance = monitor_data_iterators[di]
                    m.data_iterators.append(di_instance)
            # NOTE(review): the merged dict is not read afterwards in this
            # function; this looks like dead code.
            monitor_data_iterators.update(optimizer_data_iterators)

            result, restart = _train(args, config)
    else:
        # save parameters without training (0 epoch learning)
        logger.log(99, '0 epoch learning. (Just save parameter.)')
        if single_or_rankzero():
            _save_parameters(args, None, 0, config, True)
        result = True

    # Final status is reported only by a single process and not on restart.
    if single_or_rankzero() and not restart:
        if result:
            logger.log(99, 'Training Completed.')
            callback.update_status('finished')
        else:
            logger.log(99, 'Training Incompleted.')
            callback.update_status('failed')
    if single_or_rankzero():
        progress(None)
    return True
class ZiplineTestCase(with_metaclass(FinalMeta, TestCase)):
    """
    Shared extensions to core unittest.TestCase.

    Overrides the default unittest setUp/tearDown functions with versions that
    use ExitStack to correctly clean up resources, even in the face of
    exceptions that occur during setUp/setUpClass.

    Subclasses **should not override setUp or setUpClass**!

    Instead, they should implement `init_instance_fixtures` for per-test-method
    resources, and `init_class_fixtures` for per-class resources.

    Resources that need to be cleaned up should be registered using
    either `enter_{class,instance}_context` or `add_{class,instance}_callback`.
    """
    # True only while setUp is running init_instance_fixtures; used to catch
    # class-level fixture calls made from instance-level setup.
    _in_setup = False

    @final
    @classmethod
    def setUpClass(cls):
        # Hold a set of all the "static" attributes on the class. These are
        # things that are not populated after the class was created like
        # methods or other class level attributes.
        cls._static_class_attributes = set(vars(cls))
        cls._class_teardown_stack = ExitStack()
        try:
            cls._base_init_fixtures_was_called = False
            cls.init_class_fixtures()
            assert cls._base_init_fixtures_was_called, (
                "ZiplineTestCase.init_class_fixtures() was not called.\n"
                "This probably means that you overrode init_class_fixtures"
                " without calling super()."
            )
        except:  # noqa: E722 -- clean up on *any* failure, then re-raise
            cls.tearDownClass()
            raise

    @classmethod
    def init_class_fixtures(cls):
        """
        Override and implement this classmethod to register resources that
        should be created and/or torn down on a per-class basis.

        Subclass implementations of this should always invoke this with super()
        to ensure that fixture mixins work properly.
        """
        if cls._in_setup:
            raise ValueError(
                'Called init_class_fixtures from init_instance_fixtures.'
                '\nDid you write super(..., self).init_class_fixtures() instead'
                ' of super(..., self).init_instance_fixtures()?',
            )
        cls._base_init_fixtures_was_called = True

    @final
    @classmethod
    def tearDownClass(cls):
        try:
            cls._class_teardown_stack.close()
        finally:
            # Runs even if a registered cleanup raised, so class-scoped test
            # data is still released.
            for name in set(vars(cls)) - cls._static_class_attributes:
                # Remove all of the attributes that were added after the class
                # was constructed. This cleans up any large test data that is
                # class scoped while still allowing subclasses to access class
                # level attributes.
                delattr(cls, name)

    @final
    @classmethod
    def enter_class_context(cls, context_manager):
        """
        Enter a context manager to be exited during the tearDownClass
        """
        if cls._in_setup:
            raise ValueError(
                'Attempted to enter a class context in init_instance_fixtures.'
                '\nDid you mean to call enter_instance_context?',
            )
        return cls._class_teardown_stack.enter_context(context_manager)

    @final
    @classmethod
    def add_class_callback(cls, callback):
        """
        Register a callback to be executed during tearDownClass.

        Parameters
        ----------
        callback : callable
            The callback to invoke at the end of the test suite.
        """
        if cls._in_setup:
            raise ValueError(
                'Attempted to add a class callback in init_instance_fixtures.'
                '\nDid you mean to call add_instance_callback?',
            )
        return cls._class_teardown_stack.callback(callback)

    @final
    def setUp(self):
        type(self)._in_setup = True
        self._pre_setup_attrs = set(vars(self))
        self._instance_teardown_stack = ExitStack()
        try:
            self._init_instance_fixtures_was_called = False
            self.init_instance_fixtures()
            assert self._init_instance_fixtures_was_called, (
                "ZiplineTestCase.init_instance_fixtures() was not"
                " called.\n"
                "This probably means that you overrode"
                " init_instance_fixtures without calling super()."
            )
        except:  # noqa: E722 -- clean up on *any* failure, then re-raise
            self.tearDown()
            raise
        finally:
            type(self)._in_setup = False

    def init_instance_fixtures(self):
        """Override to register per-test resources; always call super()."""
        self._init_instance_fixtures_was_called = True

    @final
    def tearDown(self):
        try:
            self._instance_teardown_stack.close()
        finally:
            # Drop every attribute added since setUp started, even if a
            # registered cleanup raised.
            for attr in set(vars(self)) - self._pre_setup_attrs:
                delattr(self, attr)

    @final
    def enter_instance_context(self, context_manager):
        """
        Enter a context manager that should be exited during tearDown.
        """
        return self._instance_teardown_stack.enter_context(context_manager)

    @final
    def add_instance_callback(self, callback):
        """
        Register a callback to be executed during tearDown.

        Parameters
        ----------
        callback : callable
            The callback to invoke at the end of each test.
        """
        return self._instance_teardown_stack.callback(callback)
def stack():
    """Yield an ExitStack that is unwound once the consumer is done with it.

    Lets a test register cleanups without nesting its body in a ``with``.
    """
    cleanups = ExitStack()
    with cleanups:
        yield cleanups
def profile_command(args):
    """Profile every optimizer of ``args.config`` and write ``profile.csv``.

    Loads the configuration, opens each optimizer's data iterators once,
    runs ``profile_optimizer`` with a device ``synchronize`` hook and writes
    the resulting rows to ``<args.outdir>/profile.csv``.

    Returns:
        True after the results have been written.
    """
    callback.update_status(args)

    configure_progress(os.path.join(args.outdir, 'progress.txt'))

    class TrainConfig:
        pass
    config = TrainConfig()
    info = load.load(args.config)

    config.global_config = info.global_config
    config.training_config = info.training_config

    class OptConfig:
        pass
    config.optimizers = OrderedDict()
    for name, opt in info.optimizers.items():
        o = OptConfig()
        o.optimizer = opt
        o.data_iterators = []
        config.optimizers[name] = o

    class MonConfig:
        pass
    config.monitors = OrderedDict()
    for name, mon in info.monitors.items():
        m = MonConfig()
        m.monitor = mon
        m.data_iterators = []
        config.monitors[name] = m

    # The extension module is named by the backend prefix before ':'.
    ext_module = import_extension_module(
        config.global_config.default_context.backend[0].split(':')[0])

    def synchronize():
        return ext_module.synchronize(
            device_id=config.global_config.default_context.device_id)

    result_array = [['time in ms']]

    callback.update_status('processing', True)

    # Profile Optimizer
    with ExitStack() as stack:
        # Create data_iterator instance only once for each dataset in optimizers
        optimizer_data_iterators = {}
        for name, o in config.optimizers.items():
            for di in o.optimizer.data_iterators.values():
                if di not in optimizer_data_iterators:
                    di_instance = stack.enter_context(di())
                    optimizer_data_iterators[di] = di_instance
                else:
                    di_instance = optimizer_data_iterators[di]
                o.data_iterators.append(di_instance)
        result_array = profile_optimizer(config, result_array, synchronize)

    # Write profiling result. The csv module requires newline='' on the file;
    # without it, text mode on Windows rewrites the '\n' terminator as '\r\n'.
    import csv
    with open(os.path.join(args.outdir, 'profile.csv'), 'w', newline='') as f:
        writer = csv.writer(f, lineterminator='\n')
        writer.writerows(result_array)

    logger.log(99, 'Profile Completed.')
    progress(None)
    callback.update_status('finished')
    return True
def transform(self):
    """
    Main generator work loop.

    Iterates ``self.clock`` and dispatches on each ``(dt, action)`` event.
    Yields the capital-change packets produced by ``every_bar`` /
    ``once_a_day``, a daily message on SESSION_END and a minute message on
    MINUTE_END. After the clock is exhausted, it yields the result of
    ``algo.perf_tracker.handle_simulation_end()``.
    """
    algo = self.algo
    emission_rate = algo.perf_tracker.emission_rate

    # The default arguments are evaluated once, here, so these closures keep
    # their own references even after on_exit() clears the attributes on self.
    def every_bar(dt_to_use, current_data=self.current_data,
                  handle_data=algo.event_manager.handle_data):
        # called every tick (minute or day).
        algo.on_dt_changed(dt_to_use)

        # calculate_minute_capital_changes is bound later, inside the
        # `with` block below; it is looked up when this generator runs.
        for capital_change in calculate_minute_capital_changes(dt_to_use):
            yield capital_change

        self.simulation_dt = dt_to_use

        blotter = algo.blotter
        perf_tracker = algo.perf_tracker

        # handle any transactions and commissions coming out new orders
        # placed in the last bar
        new_transactions, new_commissions, closed_orders = \
            blotter.get_transactions(current_data)

        blotter.prune_orders(closed_orders)

        for transaction in new_transactions:
            perf_tracker.process_transaction(transaction)

            # since this order was modified, record it
            order = blotter.orders[transaction.order_id]
            perf_tracker.process_order(order)

        if new_commissions:
            for commission in new_commissions:
                perf_tracker.process_commission(commission)

        handle_data(algo, current_data, dt_to_use)

        # grab any new orders from the blotter, then clear the list.
        # this includes cancelled orders.
        new_orders = blotter.new_orders
        blotter.new_orders = []

        # if we have any new orders, record them so that we know
        # in what perf period they were placed.
        if new_orders:
            for new_order in new_orders:
                perf_tracker.process_order(new_order)

        # Mark the algorithm's cached portfolio/account/performance views stale.
        algo.portfolio_needs_update = True
        algo.account_needs_update = True
        algo.performance_needs_update = True

    def once_a_day(midnight_dt, current_data=self.current_data,
                   data_portal=self.data_portal):

        perf_tracker = algo.perf_tracker

        # Get the positions before updating the date so that prices are
        # fetched for trading close instead of midnight
        positions = algo.perf_tracker.position_tracker.positions
        position_assets = algo.asset_finder.retrieve_all(positions)

        # set all the timestamps
        self.simulation_dt = midnight_dt
        algo.on_dt_changed(midnight_dt)

        # process any capital changes that came overnight
        for capital_change in algo.calculate_capital_changes(
                midnight_dt, emission_rate=emission_rate,
                is_interday=True):
            yield capital_change

        # we want to wait until the clock rolls over to the next day
        # before cleaning up expired assets.
        self._cleanup_expired_assets(midnight_dt, position_assets)

        # handle any splits that impact any positions or any open orders.
        assets_we_care_about = \
            viewkeys(perf_tracker.position_tracker.positions) | \
            viewkeys(algo.blotter.open_orders)

        # Only Equity assets are kept for split handling (iterating a copy
        # because the set is mutated in the loop).
        for a in assets_we_care_about.copy():
            if not isinstance(a, Equity):
                assets_we_care_about.remove(a)
        # TODO GD REMOVE THE NOT !!! TMP HACK !!!
        # NOTE(review): the TODO above marks the Equity filter as a temporary
        # hack; confirm whether it is still needed.

        if assets_we_care_about:
            splits = data_portal.get_splits(assets_we_care_about,
                                            midnight_dt)
            if splits:
                algo.blotter.process_splits(splits)
                perf_tracker.position_tracker.handle_splits(splits)

    def handle_benchmark(date, benchmark_source=self.benchmark_source):
        # Record the benchmark value for `date` on the perf tracker.
        algo.perf_tracker.all_benchmark_returns[date] = \
            benchmark_source.get_value(date)

    def on_exit():
        # Remove references to algo, data portal, et al to break cycles
        # and ensure deterministic cleanup of these objects when the
        # simulation finishes.
        self.algo = None
        self.benchmark_source = self.current_data = self.data_portal = None

    with ExitStack() as stack:
        # Callbacks run in reverse order, so on_exit runs last, after the
        # ZiplineAPI context and self.processor have been exited.
        stack.callback(on_exit)
        stack.enter_context(self.processor)
        stack.enter_context(ZiplineAPI(self.algo))

        # Bind frequency-specific helpers used by every_bar and the
        # SESSION_END branch; for non-minute data they are no-ops.
        if algo.data_frequency == 'minute':
            def execute_order_cancellation_policy():
                algo.blotter.execute_cancel_policy(SESSION_END)

            def calculate_minute_capital_changes(dt):
                # process any capital changes that came between the last
                # and current minutes
                return algo.calculate_capital_changes(
                    dt, emission_rate=emission_rate, is_interday=False)
        else:
            def execute_order_cancellation_policy():
                pass

            def calculate_minute_capital_changes(dt):
                return []

        for dt, action in self.clock:
            if action == BAR:
                for capital_change_packet in every_bar(dt):
                    yield capital_change_packet
            elif action == SESSION_START:
                for capital_change_packet in once_a_day(dt):
                    yield capital_change_packet
            elif action == SESSION_END:
                # End of the session.
                if emission_rate == 'daily':
                    handle_benchmark(normalize_date(dt))
                execute_order_cancellation_policy()

                yield self._get_daily_message(dt, algo, algo.perf_tracker)
            elif action == BEFORE_TRADING_START_BAR:
                self.simulation_dt = dt
                algo.on_dt_changed(dt)
                algo.before_trading_start(self.current_data)
            elif action == MINUTE_END:
                handle_benchmark(dt)
                minute_msg = \
                    self._get_minute_message(dt, algo, algo.perf_tracker)

                yield minute_msg

        risk_message = algo.perf_tracker.handle_simulation_end()
        yield risk_message
def run_test(in_file, test_spec, global_cfg):
    """Run a single tavern test

    Note that each tavern test can consist of multiple requests (log in,
    create, update, delete, etc).

    The global configuration is copied and used as an initial configuration
    for this test. Any values which are saved from any tests are saved into
    this test block and can be used for formatting in later stages in the
    test.

    Args:
        in_file (str): filename containing this test
        test_spec (dict): The specification for this test. Its "stages"
            entry is replaced with the resolved list of stages.
        global_cfg (dict): Any global configuration for this test. Its
            "variables" mapping is copied, so values saved by this test are
            not written back into it.

    Raises:
        TavernException: If any stage fails. The exception gets ``stage`` and
            ``test_block_config`` attributes before being re-raised.
    """
    # pylint: disable=too-many-locals

    # Initialise test config for this test with the global configuration
    # before starting. dict() is only a shallow copy, so copy the variables
    # mapping as well; otherwise the 'tavern' box and values saved by this
    # test would be written into global_cfg and leak into later tests.
    test_block_config = dict(global_cfg)
    test_block_config["variables"] = dict(test_block_config.get("variables") or {})

    tavern_box = Box({"env_vars": dict(os.environ)})

    if not test_spec:
        logger.warning("Empty test block in %s", in_file)
        return

    # Get included stages and resolve any into the test spec dictionary
    available_stages = test_block_config.get("stages", [])
    included_stages = _get_included_stages(tavern_box, test_block_config,
                                           test_spec, available_stages)
    all_stages = {s["id"]: s for s in available_stages + included_stages}
    test_spec["stages"] = _resolve_test_stages(test_spec, all_stages)

    test_block_config["variables"]["tavern"] = tavern_box

    test_block_name = test_spec["test_name"]

    # Strict on body by default
    default_strictness = test_block_config["strict"]

    logger.info("Running test : %s", test_block_name)

    with ExitStack() as stack:
        sessions = get_extra_sessions(test_spec, test_block_config)

        for name, session in sessions.items():
            logger.debug("Entering context for %s", name)
            stack.enter_context(session)

        def getonly(stage):
            """Return whether *stage* has a truthy ``only`` key (bool or string)."""
            o = stage.get("only")
            if o is None:
                return False
            elif isinstance(o, bool):
                return o
            else:
                return strtobool(o)

        # If any stage is marked "only", every other stage is skipped.
        has_only = any(getonly(stage) for stage in test_spec["stages"])

        # Run tests in a path in order
        for stage in test_spec["stages"]:
            if stage.get("skip"):
                continue
            elif has_only and not getonly(stage):
                continue

            test_block_config["strict"] = default_strictness

            # Can be overridden per stage
            # NOTE
            # this is hardcoded to check for the 'response' block. In the far
            # future there might not be a response block, but at the moment it
            # is the hardcoded value for any HTTP request.
            response_spec = stage.get("response", {})
            if response_spec:
                # Precedence: stage response > test spec > global default.
                if response_spec.get("strict", None) is not None:
                    stage_strictness = response_spec.get("strict", None)
                elif test_spec.get("strict", None) is not None:
                    stage_strictness = test_spec.get("strict", None)
                else:
                    stage_strictness = default_strictness

                logger.debug("Strict key checking for this stage is '%s'", stage_strictness)

                test_block_config["strict"] = stage_strictness
            elif default_strictness:
                logger.debug("Default strictness '%s' ignored for this stage", default_strictness)

            # Wrap run_stage with retry helper
            run_stage_with_retries = retry(stage, test_block_config)(run_stage)

            try:
                run_stage_with_retries(sessions, stage, tavern_box, test_block_config)
            except exceptions.TavernException as e:
                e.stage = stage
                e.test_block_config = test_block_config
                raise

            if getonly(stage):
                break
def context(self):
    """Enter every context manager in ``self.__contexts__`` on a new ExitStack.

    Each ``(name, manager)`` pair is entered in order. The value returned by
    the manager's ``__enter__`` is stored as attribute ``name`` on the stack.
    The returned stack is still open; the caller owns it and must close it
    (e.g. ``with obj.context() as stack: ...``).

    If entering any manager fails, the ones already entered are unwound
    before the exception propagates, instead of being leaked.
    """
    stack = ExitStack()
    try:
        for name, manager in self.__contexts__:
            setattr(stack, name, stack.enter_context(manager))
    except BaseException:
        stack.close()
        raise
    return stack
class NosePlugin(Plugin):
    """Mailman test-runner plugin.

    Adds test filtering by regex patterns (applied to test classes, test
    methods and ``.rst`` doctest files). It also replays HTTP traffic from a
    vcr cassette, or re-records it with ``-R/--rerecord``.
    """
    configSection = 'mailman'

    def __init__(self):
        super(NosePlugin, self).__init__()
        self.patterns = []
        self.stderr = False
        self.record = False

        def _turn_on_stderr(_value):
            self.stderr = True

        def _turn_on_record(_value):
            self.record = True

        self.addArgument(self.patterns, 'P', 'pattern',
                         'Add a test matching pattern')
        self.addFlag(_turn_on_stderr, 'E', 'stderr',
                     'Enable stderr logging to sub-runners')
        self.addFlag(_turn_on_record, 'R', 'rerecord',
                     'Force re-recording of test responses. Requires Mailman to be running.')
        self._data_path = os.path.join(TOPDIR, 'tests', 'data', 'tape.yaml')
        # Resources opened in startTestRun and released in stopTestRun.
        self._resources = ExitStack()

    def startTestRun(self, event):
        """Open the vcr cassette, discarding any old recording in record mode."""
        if self.record:
            try:
                os.remove(self._data_path)
            except OSError as error:
                # A missing file is fine; anything else is a real error.
                if error.errno != errno.ENOENT:
                    raise
        # vcr creates the recording file on demand.
        self._resources.enter_context(vcr.use_cassette(self._data_path))

    def stopTestRun(self, event):
        """Close everything opened in startTestRun (stops recording)."""
        self._resources.close()

    def getTestCaseNames(self, event):
        """Exclude test methods whose names match none of the patterns.

        A test class whose ``module.Class`` name matches any pattern is kept
        whole; with no patterns nothing is filtered.
        """
        patterns = self.patterns
        if not patterns:
            return
        test_case = event.testCase
        class_name = '{}.{}'.format(test_case.__module__, test_case.__name__)
        if any(re.search(pattern, class_name) for pattern in patterns):
            return
        for method_name in filter(event.isTestMethod, dir(test_case)):
            test_name = '{}.{}.{}'.format(
                test_case.__module__, test_case.__name__, method_name)
            if not any(re.search(pattern, test_name) for pattern in patterns):
                event.excludedNames.append(method_name)

    def handleFile(self, event):
        """Add ``.rst`` files under TOPDIR as doctests, subject to patterns."""
        relpath = event.path[len(TOPDIR) + 1:]
        if self.patterns and not any(
                re.search(pattern, relpath) for pattern in self.patterns):
            # Skip this doctest.
            return
        if os.path.splitext(relpath)[1] != '.rst':
            return
        doc_test = doctest.DocFileTest(
            relpath, package=mailmanclient,
            optionflags=FLAGS,
            setUp=setup, tearDown=teardown)
        # Suppress the extra "Doctest: ..." line.
        doc_test.shortDescription = lambda: None
        event.extraTests.append(doc_test)
def transform(self):
    """
    Main generator work loop.

    Iterates ``self.clock`` and dispatches on each ``(dt, action)`` event.
    Yields the capital-change packets produced by ``every_bar`` /
    ``once_a_day``, a daily message on SESSION_END and a minute message on
    MINUTE_END. After the clock is exhausted, it yields the result of
    ``metrics_tracker.handle_simulation_end(self.data_portal)``, or ``None``
    when ``self.fast_backtest`` is set.
    """
    algo = self.algo
    metrics_tracker = algo.metrics_tracker
    emission_rate = metrics_tracker.emission_rate

    # The default arguments are evaluated once, here, so these closures keep
    # their own references even after on_exit() clears the attributes on self.
    def every_bar(dt_to_use, current_data=self.current_data,
                  handle_data=algo.event_manager.handle_data):
        # calculate_minute_capital_changes is bound later, inside the
        # `with` block below; it is looked up when this generator runs.
        for capital_change in calculate_minute_capital_changes(dt_to_use):
            yield capital_change

        self.simulation_dt = dt_to_use
        # called every tick (minute or day).
        algo.on_dt_changed(dt_to_use)

        blotter = algo.blotter

        # handle any transactions and commissions coming out new orders
        # placed in the last bar
        new_transactions, new_commissions, closed_orders = \
            blotter.get_transactions(current_data)

        blotter.prune_orders(closed_orders)

        for transaction in new_transactions:
            metrics_tracker.process_transaction(transaction)

            # since this order was modified, record it
            order = blotter.orders[transaction.order_id]
            metrics_tracker.process_order(order)

        for commission in new_commissions:
            metrics_tracker.process_commission(commission)

        handle_data(algo, current_data, dt_to_use)

        # grab any new orders from the blotter, then clear the list.
        # this includes cancelled orders.
        new_orders = blotter.new_orders
        blotter.new_orders = []

        # if we have any new orders, record them so that we know
        # in what perf period they were placed.
        for new_order in new_orders:
            metrics_tracker.process_order(new_order)

    def once_a_day(midnight_dt, current_data=self.current_data,
                   data_portal=self.data_portal):
        # process any capital changes that came overnight
        for capital_change in algo.calculate_capital_changes(
                midnight_dt, emission_rate=emission_rate,
                is_interday=True):
            yield capital_change

        # set all the timestamps
        self.simulation_dt = midnight_dt
        algo.on_dt_changed(midnight_dt)

        metrics_tracker.handle_market_open(
            midnight_dt,
            algo.data_portal,
        )

        # handle any splits that impact any positions or any open orders.
        assets_we_care_about = (
            viewkeys(metrics_tracker.positions) |
            viewkeys(algo.blotter.open_orders)
        )

        if assets_we_care_about:
            splits = data_portal.get_splits(assets_we_care_about,
                                            midnight_dt)
            if splits:
                algo.blotter.process_splits(splits)
                metrics_tracker.handle_splits(splits)

    def on_exit():
        # Remove references to algo, data portal, et al to break cycles
        # and ensure deterministic cleanup of these objects when the
        # simulation finishes.
        self.algo = None
        self.benchmark_source = self.current_data = self.data_portal = None

    with ExitStack() as stack:
        # Callbacks run in reverse order, so on_exit runs last, after the
        # ZiplineAPI context and self.processor have been exited.
        stack.callback(on_exit)
        stack.enter_context(self.processor)
        stack.enter_context(ZiplineAPI(self.algo))

        # Bind frequency-specific helpers used by every_bar and the
        # SESSION_END branch; for non-minute data they are no-ops.
        if algo.data_frequency == 'minute':
            def execute_order_cancellation_policy():
                algo.blotter.execute_cancel_policy(SESSION_END)

            def calculate_minute_capital_changes(dt):
                # process any capital changes that came between the last
                # and current minutes
                return algo.calculate_capital_changes(
                    dt, emission_rate=emission_rate, is_interday=False)
        else:
            def execute_order_cancellation_policy():
                pass

            def calculate_minute_capital_changes(dt):
                return []

        for dt, action in self.clock:
            if action == BAR:
                for capital_change_packet in every_bar(dt):
                    yield capital_change_packet
            elif action == SESSION_START:
                for capital_change_packet in once_a_day(dt):
                    yield capital_change_packet
            elif action == SESSION_END:
                # End of the session.
                positions = metrics_tracker.positions
                position_assets = algo.asset_finder.retrieve_all(positions)
                self._cleanup_expired_assets(dt, position_assets)

                execute_order_cancellation_policy()
                algo.validate_account_controls()

                yield self._get_daily_message(dt, algo, metrics_tracker)
            elif action == BEFORE_TRADING_START_BAR:
                self.simulation_dt = dt
                algo.on_dt_changed(dt)
                algo.before_trading_start(self.current_data)
            elif action == MINUTE_END:
                minute_msg = self._get_minute_message(
                    dt,
                    algo,
                    metrics_tracker,
                )

                yield minute_msg

        # NOTE(review): with fast_backtest the final yielded value is None;
        # consumers must tolerate it.
        if not self.fast_backtest:
            risk_message = metrics_tracker.handle_simulation_end(
                self.data_portal,
            )
        else:
            risk_message = None
        yield risk_message
def run_test(in_file, test_spec, global_cfg):
    """Run a single tavern test

    Note that each tavern test can consist of multiple requests (log in,
    create, update, delete, etc).

    The global configuration is copied and used as an initial configuration
    for this test. Any values which are saved from any tests are saved into
    this test block and can be used for formatting in later stages in the
    test.

    Args:
        in_file (str): filename containing this test
        test_spec (dict): The specification for this test
        global_cfg (dict): Any global configuration for this test. Its
            "variables" mapping is copied, so values saved by this test are
            not written back into it.

    Raises:
        TavernException: If any of the tests failed
    """
    # pylint: disable=too-many-locals

    # Initialise test config for this test with the global configuration
    # before starting. dict() is only a shallow copy, so copy the variables
    # mapping as well; otherwise the 'tavern' box, included variables and
    # values saved by verifiers would be written into global_cfg and leak
    # into every later test that shares it.
    test_block_config = dict(global_cfg)
    test_block_config["variables"] = dict(test_block_config.get("variables") or {})

    tavern_box = Box({
        "env_vars": dict(os.environ),
    })

    test_block_config["variables"]["tavern"] = tavern_box

    if not test_spec:
        logger.warning("Empty test block in %s", in_file)
        return

    # Variables from included files are formatted (they may refer to the
    # tavern box) and merged into this test's variables.
    if test_spec.get("includes"):
        for included in test_spec["includes"]:
            if "variables" in included:
                formatted_include = format_keys(included["variables"], {"tavern": tavern_box})
                test_block_config["variables"].update(formatted_include)

    test_block_name = test_spec["test_name"]

    logger.info("Running test : %s", test_block_name)

    with ExitStack() as stack:
        sessions = get_extra_sessions(test_spec)

        for name, session in sessions.items():
            logger.debug("Entering context for %s", name)
            stack.enter_context(session)

        # Run tests in a path in order
        for stage in test_spec["stages"]:
            stage_name = stage["name"]

            try:
                r = get_request_type(stage, test_block_config, sessions)
            except exceptions.MissingFormatError:
                log_fail(stage, None, None)
                raise

            # Expose this stage's request variables for formatting while the
            # stage runs; removed again once the stage has passed.
            tavern_box.update(request_vars=r.request_vars)

            try:
                expected = get_expected(stage, test_block_config, sessions)
            except exceptions.TavernException:
                log_fail(stage, None, None)
                raise

            delay(stage, "before")

            logger.info("Running stage : %s", stage_name)

            try:
                response = r.run()
            except exceptions.TavernException:
                log_fail(stage, None, expected)
                raise

            verifiers = get_verifiers(stage, test_block_config, sessions, expected)

            for v in verifiers:
                try:
                    saved = v.verify(response)
                except exceptions.TavernException:
                    log_fail(stage, v, expected)
                    raise
                else:
                    # Saved values become available to later stages.
                    test_block_config["variables"].update(saved)

            log_pass(stage, verifiers)

            tavern_box.pop("request_vars")
            delay(stage, "after")
def main(argv=None):
    """Run the Orange Canvas application.

    Parses the command line options in ``argv``, configures logging, creates
    the ``CanvasApplication`` and ``CanvasMainWindow``, runs widget discovery
    (or loads the registry from the pickle cache with ``--no-discovery``),
    optionally opens a workflow, and then runs the Qt event loop with
    stdout/stderr tee'd into the main window's output view.

    Parameters
    ----------
    argv : list of str, optional
        Full argument vector including the program name; defaults to
        ``sys.argv``.

    Returns
    -------
    int
        The value returned by ``app.exec_()``, or ``1`` if the event loop
        raised an exception.
    """
    if argv is None:
        argv = sys.argv

    usage = "usage: %prog [options] [workflow_file]"
    parser = optparse.OptionParser(usage=usage)

    parser.add_option("--no-discovery",
                      action="store_true",
                      help="Don't run widget discovery "
                           "(use full cache instead)")
    parser.add_option("--force-discovery",
                      action="store_true",
                      help="Force full widget discovery "
                           "(invalidate cache)")
    parser.add_option("--clear-widget-settings",
                      action="store_true",
                      help="Remove stored widget setting")
    parser.add_option("--no-welcome",
                      action="store_true",
                      help="Don't show welcome dialog.")
    parser.add_option("--no-splash",
                      action="store_true",
                      help="Don't show splash screen.")
    parser.add_option("-l", "--log-level",
                      help="Logging level (0, 1, 2, 3, 4)",
                      type="int", default=1)
    parser.add_option("--style",
                      help="QStyle to use",
                      type="str", default=None)
    parser.add_option("--stylesheet",
                      help="Application level CSS style sheet to use",
                      type="str", default="orange.qss")
    parser.add_option("--qt",
                      help="Additional arguments for QApplication",
                      type="str", default=None)

    parser.add_option("--config",
                      help="Configuration namespace",
                      type="str", default="orangecanvas.example")

    # -m canvas orange.widgets
    # -m canvas --config orange.widgets

    (options, args) = parser.parse_args(argv[1:])

    # Indexed by --log-level (0 = CRITICAL ... 4 = DEBUG).
    levels = [
        logging.CRITICAL,
        logging.ERROR,
        logging.WARN,
        logging.INFO,
        logging.DEBUG
    ]

    # Fix streams before configuring logging (otherwise it will store
    # and write to the old file descriptors)
    fix_win_pythonw_std_stream()

    # Try to fix fonts on OSX Mavericks/Yosemite, ...
    fix_osx_private_font()

    # File handler should always be at least INFO level so we need
    # the application root level to be at least at INFO.
    root_level = min(levels[options.log_level], logging.INFO)
    rootlogger = logging.getLogger(__package__)
    rootlogger.setLevel(root_level)

    # Standard output stream handler at the requested level
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(level=levels[options.log_level])
    rootlogger.addHandler(stream_handler)

    if options.config is not None:
        try:
            cfg = utils.name_lookup(options.config)
        except (ImportError, AttributeError):
            # Best effort: an unknown config namespace keeps the default.
            pass
        else:
            # config.default = cfg
            config.set_default(cfg)

            log.info("activating %s", options.config)

    log.info("Starting 'Orange Canvas' application.")

    # argv[:1] is a new list, so the += below never mutates the caller's argv.
    qt_argv = argv[:1]

    if options.style is not None:
        qt_argv += ["-style", options.style]

    if options.qt is not None:
        qt_argv += shlex.split(options.qt)

    qt_argv += args

    if options.clear_widget_settings:
        log.debug("Clearing widget settings")
        shutil.rmtree(config.widget_settings_dir(), ignore_errors=True)

    if QT_VERSION >= 0x50600:
        CanvasApplication.setAttribute(Qt.AA_UseHighDpiPixmaps)

    log.debug("Starting CanvasApplicaiton with argv = %r.", qt_argv)
    app = CanvasApplication(qt_argv)

    # NOTE: config.init() must be called after the QApplication constructor
    config.init()

    file_handler = logging.FileHandler(
        filename=os.path.join(config.log_dir(), "canvas.log"),
        mode="w"
    )

    file_handler.setLevel(root_level)
    rootlogger.addHandler(file_handler)

    # intercept any QFileOpenEvent requests until the main window is
    # fully initialized.
    # NOTE: The QApplication must have the executable ($0) and filename
    # arguments passed in argv otherwise the FileOpen events are
    # triggered for them (this is done by Cocoa, but QApplicaiton filters
    # them out if passed in argv)

    open_requests = []

    def onrequest(url):
        log.info("Received an file open request %s", url)
        open_requests.append(url)

    app.fileOpenRequest.connect(onrequest)

    settings = QSettings()

    stylesheet = options.stylesheet
    stylesheet_string = None

    if stylesheet != "none":
        if os.path.isfile(stylesheet):
            with io.open(stylesheet, "r") as f:
                stylesheet_string = f.read()
        else:
            if not os.path.splitext(stylesheet)[1]:
                # no extension
                stylesheet = os.path.extsep.join([stylesheet, "qss"])

            pkg_name = __package__
            resource = "styles/" + stylesheet

            if pkg_resources.resource_exists(pkg_name, resource):
                stylesheet_string = \
                    pkg_resources.resource_string(pkg_name, resource).decode("utf-8")

                base = pkg_resources.resource_filename(pkg_name, "styles")

                # Lines of the form "@prefix: path;" register Qt resource
                # search paths and are stripped from the style sheet.
                pattern = re.compile(
                    r"^\s@([a-zA-Z0-9_]+?)\s*:\s*([a-zA-Z0-9_/]+?);\s*$",
                    flags=re.MULTILINE
                )

                matches = pattern.findall(stylesheet_string)

                for prefix, search_path in matches:
                    QDir.addSearchPath(prefix, os.path.join(base, search_path))
                    log.info("Adding search path %r for prefix, %r",
                             search_path, prefix)

                stylesheet_string = pattern.sub("", stylesheet_string)

            else:
                log.info("%r style sheet not found.", stylesheet)

    # Add the default canvas_icons search path
    dirpath = os.path.abspath(os.path.dirname(__file__))
    QDir.addSearchPath("canvas_icons", os.path.join(dirpath, "icons"))

    canvas_window = CanvasMainWindow()
    canvas_window.setWindowIcon(config.application_icon())

    if stylesheet_string is not None:
        canvas_window.setStyleSheet(stylesheet_string)

    if not options.force_discovery:
        reg_cache = cache.registry_cache()
    else:
        reg_cache = None

    widget_registry = qt.QtWidgetRegistry()
    widget_discovery = config.widget_discovery(
        widget_registry, cached_descriptions=reg_cache)

    want_splash = \
        settings.value("startup/show-splash-screen", True, type=bool) and \
        not options.no_splash

    if want_splash:
        pm, rect = config.splash_screen()
        splash_screen = SplashScreen(pixmap=pm, textRect=rect)
        splash_screen.setAttribute(Qt.WA_DeleteOnClose)
        splash_screen.setFont(QFont("Helvetica", 12))
        color = QColor("#FFD39F")

        def show_message(message):
            splash_screen.showMessage(message, color=color)

        widget_registry.category_added.connect(show_message)
        show_splash = splash_screen.show
        close_splash = splash_screen.close
    else:
        show_splash = close_splash = lambda: None

    log.info("Running widget discovery process.")

    cache_filename = os.path.join(config.cache_dir(), "widget-registry.pck")
    if options.no_discovery:
        # NOTE: unpickling trusts the local cache file written below.
        with open(cache_filename, "rb") as f:
            widget_registry = pickle.load(f)
        widget_registry = qt.QtWidgetRegistry(widget_registry)
    else:
        show_splash()
        widget_discovery.run(config.widgets_entry_points())
        close_splash()

        # Store cached descriptions
        cache.save_registry_cache(widget_discovery.cached_descriptions)
        with open(cache_filename, "wb") as f:
            pickle.dump(WidgetRegistry(widget_registry), f)

    set_global_registry(widget_registry)
    canvas_window.set_widget_registry(widget_registry)
    canvas_window.show()
    canvas_window.raise_()

    want_welcome = \
        settings.value("startup/show-welcome-screen", True, type=bool) \
        and not options.no_welcome

    # Process events to make sure the canvas_window layout has
    # a chance to activate (the welcome dialog is modal and will
    # block the event queue, plus we need a chance to receive open file
    # signals when running without a splash screen)
    app.processEvents()

    app.fileOpenRequest.connect(canvas_window.open_scheme_file)

    if want_welcome and not args and not open_requests:
        canvas_window.welcome_dialog()

    elif args:
        log.info("Loading a scheme from the command line argument %r",
                 args[0])
        canvas_window.load_scheme(args[0])
    elif open_requests:
        log.info("Loading a scheme from an `QFileOpenEvent` for %r",
                 open_requests[-1])
        canvas_window.load_scheme(open_requests[-1].toLocalFile())

    # Tee stdout and stderr into Output dock
    output_view = canvas_window.output_view()

    stdout = TextStream()
    stdout.stream.connect(output_view.write)
    if sys.stdout:
        stdout.stream.connect(sys.stdout.write)
        stdout.flushed.connect(sys.stdout.flush)

    stderr = TextStream()
    error_writer = output_view.formated(color=Qt.red)
    stderr.stream.connect(error_writer.write)
    if sys.stderr:
        stderr.stream.connect(sys.stderr.write)
        stderr.flushed.connect(sys.stderr.flush)

    sys.excepthook = ExceptHook(stream=stderr)

    with ExitStack() as stack:
        stack.enter_context(redirect_stdout(stdout))
        stack.enter_context(redirect_stderr(stderr))
        log.info("Entering main event loop.")
        try:
            status = app.exec_()
        except BaseException:
            log.error("Error in main event loop.", exc_info=True)
            # Without a value here, ``return status`` below would raise
            # UnboundLocalError and hide the logged error.
            status = 1

    canvas_window.deleteLater()
    app.processEvents()
    app.flush()
    del canvas_window

    # Collect any cycles before deleting the QApplication instance
    gc.collect()

    del app
    return status
def main(self, orings):
    """Process every set of input sequences, writing results to ``orings``.

    Input sequences are read in lockstep, one per ring in ``self.irings``.
    For each set:

    1. Call ``self._on_sequence`` to get the output headers and input slices.
    2. Resize each input sequence to its requested gulp size.
    3. Begin the output sequences.
    4. Repeatedly read input spans and reserve output spans.
    5. Call ``self._on_data`` and commit the number of frames it reports.

    Per-gulp acquire/reserve/process timings are written to
    ``self.perf_proclog``. Iteration stops early when
    ``self.shutdown_event`` is set.
    """
    for iseqs in izip( * [iring.read(guarantee=self.guarantee)
                          for iring in self.irings]):
        if self.shutdown_event.is_set():
            break
        for i, iseq in enumerate(iseqs):
            self.sequence_proclogs[i].update(iseq.header)
        oheaders, islices = self._on_sequence(iseqs)
        # Output headers without an explicit time_tag get the running
        # sequence counter instead.
        for ohdr in oheaders:
            if 'time_tag' not in ohdr:
                ohdr['time_tag'] = self._seq_count
        self._seq_count += 1

        # Allow passing None to mean slice(gulp_nframe)
        if islices is None:
            islices = [None] * len(self.irings)
        # Block-level gulp_nframe wins; otherwise use the one in each header.
        default_igulp_nframes = [
            self.gulp_nframe or iseq.header['gulp_nframe'] for iseq in iseqs
        ]
        islices = [
            islice or slice(igulp_nframe)
            for (islice, igulp_nframe) in zip(islices, default_igulp_nframes)
        ]

        islices = [_span_slice(slice_) for slice_ in islices]
        for iseq, islice in zip(iseqs, islices):
            # With no explicit buffer factor, use 1 when the block that owns
            # the input ring is fused with this one, otherwise leave it None.
            if self.buffer_factor is None:
                src_block = iseq.ring.owner
                if src_block is not None and self.is_fused_with(src_block):
                    buffer_factor = 1
                else:
                    buffer_factor = None
            else:
                buffer_factor = self.buffer_factor
            iseq.resize(gulp_nframe=(islice.stop - islice.start),
                        buf_nframe=self.buffer_nframe,
                        buffer_factor=buffer_factor)

        igulp_nframes = [islice.stop - islice.start for islice in islices]

        # Output sequences registered on oseq_stack are ended when this
        # block exits (normally, on shutdown break, or on error).
        with ExitStack() as oseq_stack:
            oseqs = self.begin_sequences(oseq_stack, orings, oheaders,
                                         igulp_nframes)
            prev_time = time.time()
            for ispans in izip(*[
                    iseq.read(islice.stop - islice.start, islice.step,
                              islice.start)
                    for (iseq, islice) in zip(iseqs, islices)
            ]):
                if self.shutdown_event.is_set():
                    break
                cur_time = time.time()
                acquire_time = cur_time - prev_time
                prev_time = cur_time

                # Output spans reserved on ospan_stack are released when
                # this inner block exits.
                with ExitStack() as ospan_stack:
                    ospans = self.reserve_spans(ospan_stack, oseqs, ispans)
                    cur_time = time.time()
                    reserve_time = cur_time - prev_time
                    prev_time = cur_time
                    # *TODO: See if can fuse together multiple on_data calls here before
                    #          calling stream_synchronize().
                    #        Consider passing .data instead of rings here
                    ostrides = self._on_data(ispans, ospans)
                    # TODO: // Default to not spinning the CPU: cudaSetDeviceFlags(cudaDeviceScheduleBlockingSync);
                    bf.device.stream_synchronize()
                    # Allow returning None to indicate complete consumption
                    if ostrides is None:
                        ostrides = [ospan.nframe for ospan in ospans]
                    # A None entry likewise means "the whole span".
                    ostrides = [
                        ostride if ostride is not None else ospan.nframe
                        for (ostride, ospan) in zip(ostrides, ospans)
                    ]
                    for ospan, ostride in zip(ospans, ostrides):
                        ospan.commit(ostride)
                cur_time = time.time()
                process_time = cur_time - prev_time
                prev_time = cur_time
                self.perf_proclog.update({
                    'acquire_time': acquire_time,
                    'reserve_time': reserve_time,
                    'process_time': process_time
                })
        self._on_sequence_end(iseqs)
def transform(self, stream_in):
    """
    Main generator work loop.

    Enters ``self.processor`` and ``ZiplineAPI(self.algo)`` for the whole
    simulation, then consumes ``(date, snapshot)`` pairs from ``stream_in``.

    - Before ``self.algo_start`` (warm-up), snapshot events only update state
      (splits, trades, custom data) and nothing is yielded.
    - From ``self.algo_start`` on, each snapshot goes through
      ``_process_snapshot`` and the resulting perf messages are yielded.

    The risk report from ``handle_simulation_end()`` is yielded last.
    """
    # Initialize the mkt_close
    mkt_open = self.algo.perf_tracker.market_open
    mkt_close = self.algo.perf_tracker.market_close

    # inject the current algo
    # snapshot time to any log record generated.

    with ExitStack() as stack:
        stack.enter_context(self.processor)
        stack.enter_context(ZiplineAPI(self.algo))

        data_frequency = self.sim_params.data_frequency

        # before_trading_start for the first session runs before any data.
        self._call_before_trading_start(mkt_open)

        for date, snapshot in stream_in:

            self.simulation_dt = date
            self.on_dt_changed(date)

            # If we're still in the warmup period.  Use the event to
            # update our universe, but don't yield any perf messages,
            # and don't send a snapshot to handle_data.
            if date < self.algo_start:
                for event in snapshot:
                    if event.type == DATASOURCE_TYPE.SPLIT:
                        self.algo.blotter.process_split(event)

                    elif event.type == DATASOURCE_TYPE.TRADE:
                        self.update_universe(event)
                        self.algo.perf_tracker.process_trade(event)
                    elif event.type == DATASOURCE_TYPE.CUSTOM:
                        self.update_universe(event)

            else:
                messages = self._process_snapshot(
                    date,
                    snapshot,
                    self.algo.instant_fill,
                )
                # Perf messages are only emitted if the snapshot contained
                # a benchmark event.
                for message in messages:
                    yield message

                # When emitting minutely, we need to call
                # before_trading_start before the next trading day begins
                if date == mkt_close:
                    if mkt_close <= self.algo.perf_tracker.last_close:
                        # Captured before mkt_close is advanced below.
                        before_last_close = \
                            mkt_close < self.algo.perf_tracker.last_close
                        try:
                            mkt_open, mkt_close = \
                                trading.environment \
                                       .next_open_and_close(mkt_close)

                        except trading.NoFurtherDataError:
                            # If at the end of backtest history,
                            # skip advancing market close.
                            pass

                        if before_last_close:
                            self._call_before_trading_start(mkt_open)

                elif data_frequency == 'daily':
                    next_day = trading.environment.next_trading_day(date)

                    if next_day is not None and \
                       next_day < self.algo.perf_tracker.last_close:
                        self._call_before_trading_start(next_day)

                # Force portfolio/account/performance to be recomputed on
                # next access after this snapshot.
                self.algo.portfolio_needs_update = True
                self.algo.account_needs_update = True
                self.algo.performance_needs_update = True

        risk_message = self.algo.perf_tracker.handle_simulation_end()
        yield risk_message
def settings(*args, **kwargs):
    """
    Nest context managers and/or override ``env`` variables.

    `settings` combines two jobs:

    * Keyword arguments temporarily override or extend ``env``, e.g.
      ``with settings(user='foo'):``. Whatever values existed beforehand are
      put back when the ``with`` block finishes.
    * The keyword ``clean_revert`` is not an ``env`` key; it changes how
      ``settings`` reverts (described below) and is removed before the
      override is applied.
    * Positional arguments must themselves be context managers, and are
      entered in order (via `contextlib.ExitStack`, the modern replacement
      for the old ``contextlib.nested``), e.g.
      ``with settings(hide('stderr'), show('stdout')):``.

    Both forms can be mixed in a single call. For instance::

        def my_task():
            with settings(
                hide('warnings', 'running', 'stdout', 'stderr'),
                warn_only=True
            ):
                if run('ls /etc/lsb-release'):
                    return 'Ubuntu'
                elif run('ls /etc/redhat-release'):
                    return 'RedHat'

    Here `run` only warns (instead of aborting) when ``ls`` fails, and every
    kind of output -- the warning included -- is suppressed. The caller learns
    which kind of system the remote host is without any noise on screen.

    So `settings` can pair any set of ``env`` overrides with output hiding
    or showing, or with any other Fabric feature exposed as a context
    manager.

    With ``clean_revert=True``, ``settings`` will **not** revert keys that
    were changed inside the block; it only reverts keys that still hold the
    values it set. Default behaviour first::

        # Before the block, env.parallel defaults to False, host_string to None
        with settings(parallel=True, host_string='myhost'):
            # env.parallel is True
            # env.host_string is 'myhost'
            env.host_string = 'otherhost'
            # env.host_string is now 'otherhost'
        # Outside the block:
        # * env.parallel is False again
        # * env.host_string is None again

    The assignment made inside the block is undone, which is not always what
    you want. With ``clean_revert``::

        # Before the block, env.parallel defaults to False, host_string to None
        with settings(parallel=True, host_string='myhost', clean_revert=True):
            # env.parallel is True
            # env.host_string is 'myhost'
            env.host_string = 'otherhost'
            # env.host_string is now 'otherhost'
        # Outside the block:
        # * env.parallel is False again
        # * env.host_string remains 'otherhost'

    Keys that did not exist in ``env`` before ``settings`` was used are kept
    when ``clean_revert`` is active; otherwise they are deleted on exit.

    .. versionadded:: 1.4.1
        The ``clean_revert`` kwarg.
    """
    to_enter = list(args)
    if kwargs:
        # All keyword overrides become one extra manager, entered after the
        # positional ones (and therefore exited before them).
        to_enter.append(_setenv(kwargs))
    with ExitStack() as stack:
        entered_values = []
        for manager in to_enter:
            entered_values.append(stack.enter_context(manager))
        yield tuple(entered_values)
def transform(self):
    """
    Main generator work loop.

    Drives the simulation from ``self.clock``:

    - ``BAR`` runs ``every_bar``.
    - ``DAY_START`` runs ``once_a_day``.
    - ``DAY_END`` yields the daily perf message.
    - ``MINUTE_END`` yields the minute perf message.

    After the clock is exhausted, the risk report from
    ``handle_simulation_end()`` is yielded. ``self.processor`` and
    ``ZiplineAPI(self.algo)`` are active for the whole loop. On exit, the
    references held on ``self`` (algo, benchmark source, current data, data
    portal) are cleared.
    """
    algo = self.algo

    # NOTE: default arguments below are evaluated once, at definition time,
    # presumably so that later rebinding of the self.* attributes (e.g. by
    # on_exit) does not affect these helpers.
    def every_bar(dt_to_use, current_data=self.current_data,
                  handle_data=algo.event_manager.handle_data):
        # called every tick (minute or day).

        # process_minute_capital_changes is defined further down, inside the
        # ExitStack block, before the clock loop calls every_bar.
        if dt_to_use in algo.capital_changes:
            process_minute_capital_changes(dt_to_use)

        self.simulation_dt = dt_to_use
        algo.on_dt_changed(dt_to_use)

        blotter = algo.blotter
        perf_tracker = algo.perf_tracker

        # handle any transactions and commissions coming out new orders
        # placed in the last bar
        new_transactions, new_commissions, closed_orders = \
            blotter.get_transactions(current_data)

        blotter.prune_orders(closed_orders)

        for transaction in new_transactions:
            perf_tracker.process_transaction(transaction)

            # since this order was modified, record it
            order = blotter.orders[transaction.order_id]
            perf_tracker.process_order(order)

        if new_commissions:
            for commission in new_commissions:
                perf_tracker.process_commission(commission)

        handle_data(algo, current_data, dt_to_use)

        # grab any new orders from the blotter, then clear the list.
        # this includes cancelled orders.
        new_orders = blotter.new_orders
        blotter.new_orders = []

        # if we have any new orders, record them so that we know
        # in what perf period they were placed.
        if new_orders:
            for new_order in new_orders:
                perf_tracker.process_order(new_order)

        algo.portfolio_needs_update = True
        algo.account_needs_update = True
        algo.performance_needs_update = True

    def once_a_day(midnight_dt, current_data=self.current_data,
                   data_portal=self.data_portal):

        perf_tracker = algo.perf_tracker

        if midnight_dt in algo.capital_changes:
            # process any capital changes that came overnight
            change = algo.capital_changes[midnight_dt]
            log.info('Processing capital change of %s at %s' %
                     (change, midnight_dt))
            perf_tracker.process_capital_changes(change, is_interday=True)

        # Get the positions before updating the date so that prices are
        # fetched for trading close instead of midnight
        positions = algo.perf_tracker.position_tracker.positions
        position_assets = algo.asset_finder.retrieve_all(positions)

        # set all the timestamps
        self.simulation_dt = midnight_dt
        algo.on_dt_changed(midnight_dt)

        # we want to wait until the clock rolls over to the next day
        # before cleaning up expired assets.
        self._cleanup_expired_assets(midnight_dt, position_assets)

        # handle any splits that impact any positions or any open orders.
        assets_we_care_about = \
            viewkeys(perf_tracker.position_tracker.positions) | \
            viewkeys(algo.blotter.open_orders)

        if assets_we_care_about:
            splits = data_portal.get_splits(assets_we_care_about,
                                            midnight_dt)
            if splits:
                algo.blotter.process_splits(splits)
                perf_tracker.position_tracker.handle_splits(splits)

        # call before trading start
        algo.before_trading_start(current_data)

    def handle_benchmark(date, benchmark_source=self.benchmark_source):
        algo.perf_tracker.all_benchmark_returns[date] = \
            benchmark_source.get_value(date)

    def on_exit():
        # Remove references to algo, data portal, et al to break cycles
        # and ensure deterministic cleanup of these objects when the
        # simulation finishes.
        self.algo = None
        self.benchmark_source = self.current_data = self.data_portal = None

    with ExitStack() as stack:
        # Registered first, so (ExitStack being LIFO) it runs last, after
        # the processor and ZiplineAPI contexts have been exited.
        stack.callback(on_exit)
        stack.enter_context(self.processor)
        stack.enter_context(ZiplineAPI(self.algo))

        # Bind frequency-specific behaviour once instead of branching on
        # every clock event.
        if algo.data_frequency == 'minute':
            def execute_order_cancellation_policy():
                algo.blotter.execute_cancel_policy(DAY_END)

            def process_minute_capital_changes(dt):
                # If we are running daily emission, prices won't
                # necessarily be synced at the end of every minute, and we
                # need the up-to-date prices for capital change
                # calculations. We want to sync the prices as of the
                # last market minute, and this is okay from a data portal
                # perspective as we have technically not "advanced" to the
                # current dt yet.
                algo.perf_tracker.position_tracker.sync_last_sale_prices(
                    self.algo.trading_schedule.previous_execution_minute(
                        dt),
                    False,
                    self.data_portal)

                # process any capital changes that came between the last
                # and current minutes
                change = algo.capital_changes[dt]
                log.info('Processing capital change of %s at %s' %
                         (change, dt))
                algo.perf_tracker.process_capital_changes(
                    change,
                    is_interday=False)
        else:
            def execute_order_cancellation_policy():
                pass

            def process_minute_capital_changes(dt):
                pass

        for dt, action in self.clock:
            if action == BAR:
                every_bar(dt)
            elif action == DAY_START:
                once_a_day(dt)
            elif action == DAY_END:
                # End of the day.
                if algo.perf_tracker.emission_rate == 'daily':
                    handle_benchmark(normalize_date(dt))
                execute_order_cancellation_policy()

                yield self._get_daily_message(dt, algo, algo.perf_tracker)
            elif action == MINUTE_END:
                handle_benchmark(dt)
                minute_msg = \
                    self._get_minute_message(dt, algo, algo.perf_tracker)

                yield minute_msg

        risk_message = algo.perf_tracker.handle_simulation_end()
        yield risk_message
def __init__(self):
    """Initialise the base classes, then attach a fresh ``ExitStack``.

    The new stack is stored on ``self.stack``; nothing is registered on it
    at construction time.
    """
    # Cooperative base-class initialisation happens before the stack exists.
    super(CleanableCM, self).__init__()
    # One stack per instance, never shared between instances.
    self.stack = ExitStack()
def open_files(files, **kwargs):
    """A plural form of :func:`open_file`.

    Opens each path in ``files`` with ``open_file(path, **kwargs)``, in
    order, and yields the list of opened objects. Everything opened so far
    is closed when the generator finishes, including when opening a later
    file fails.
    """
    with ExitStack() as stack:
        handles = []
        for path in files:
            handles.append(stack.enter_context(open_file(path, **kwargs)))
        yield handles
class ZiplineTestCase(with_metaclass(FinalMeta, TestCase)):
    """
    Shared extensions to core unittest.TestCase.

    Overrides the default unittest setUp/tearDown functions with versions that
    use ExitStack to correctly clean up resources, even in the face of
    exceptions that occur during setUp/setUpClass.

    Subclasses **should not override setUp or setUpClass**!

    Instead, they should implement `init_instance_fixtures` for per-test-method
    resources, and `init_class_fixtures` for per-class resources.

    Resources that need to be cleaned up should be registered using
    either `enter_{class,instance}_context` or `add_{class,instance}_callback`.
    """
    # True only while setUp() is running init_instance_fixtures(); used to
    # reject class-level registrations made from instance-level setup.
    _in_setup = False

    @final
    @classmethod
    def setUpClass(cls):
        """Create the class teardown stack and run ``init_class_fixtures``.

        If initialisation fails for any reason, everything already registered
        on the class stack is torn down before the exception is re-raised.

        Raises
        ------
        AssertionError
            If an override of ``init_class_fixtures`` did not call super().
        """
        cls._class_teardown_stack = ExitStack()
        try:
            cls._base_init_fixtures_was_called = False
            cls.init_class_fixtures()
            # Explicit raise rather than ``assert`` so the check also runs
            # under ``python -O``.
            if not cls._base_init_fixtures_was_called:
                raise AssertionError(
                    "ZiplineTestCase.init_class_fixtures() was not called.\n"
                    "This probably means that you overrode init_class_fixtures"
                    " without calling super()."
                )
        except BaseException:
            cls.tearDownClass()
            raise

    @classmethod
    def init_class_fixtures(cls):
        """
        Override and implement this classmethod to register resources that
        should be created and/or torn down on a per-class basis.

        Subclass implementations of this should always invoke this with super()
        to ensure that fixture mixins work properly.

        Raises
        ------
        ValueError
            If called while instance fixtures are being initialised.
        """
        if cls._in_setup:
            raise ValueError(
                'Called init_class_fixtures from init_instance_fixtures.'
                '\nDid you write super(..., self).init_class_fixtures() instead'
                ' of super(..., self).init_instance_fixtures()?',
            )
        cls._base_init_fixtures_was_called = True

    @final
    @classmethod
    def tearDownClass(cls):
        """Exit every class-level context and run class callbacks (LIFO)."""
        cls._class_teardown_stack.close()

    @final
    @classmethod
    def enter_class_context(cls, context_manager):
        """
        Enter a context manager to be exited during the tearDownClass

        Returns the value produced by entering ``context_manager``.
        """
        if cls._in_setup:
            raise ValueError(
                'Attempted to enter a class context in init_instance_fixtures.'
                '\nDid you mean to call enter_instance_context?',
            )
        return cls._class_teardown_stack.enter_context(context_manager)

    @final
    @classmethod
    def add_class_callback(cls, callback):
        """
        Register a callback to be executed during tearDownClass.

        Parameters
        ----------
        callback : callable
            The callback to invoke at the end of the test suite.
        """
        if cls._in_setup:
            raise ValueError(
                'Attempted to add a class callback in init_instance_fixtures.'
                '\nDid you mean to call add_instance_callback?',
            )
        return cls._class_teardown_stack.callback(callback)

    @final
    def setUp(self):
        """Create the instance teardown stack and run ``init_instance_fixtures``.

        On failure the instance stack is torn down before re-raising. The
        ``_in_setup`` flag is always reset afterwards.

        Raises
        ------
        AssertionError
            If an override of ``init_instance_fixtures`` did not call super().
        """
        type(self)._in_setup = True
        self._instance_teardown_stack = ExitStack()
        try:
            self._init_instance_fixtures_was_called = False
            self.init_instance_fixtures()
            # Explicit raise rather than ``assert`` so the check also runs
            # under ``python -O``.
            if not self._init_instance_fixtures_was_called:
                raise AssertionError(
                    "ZiplineTestCase.init_instance_fixtures() was not"
                    " called.\n"
                    "This probably means that you overrode"
                    " init_instance_fixtures without calling super()."
                )
        except BaseException:
            self.tearDown()
            raise
        finally:
            type(self)._in_setup = False

    def init_instance_fixtures(self):
        """Override to register per-test resources; always call super()."""
        self._init_instance_fixtures_was_called = True

    @final
    def tearDown(self):
        """Exit every instance-level context and run instance callbacks."""
        self._instance_teardown_stack.close()

    @final
    def enter_instance_context(self, context_manager):
        """
        Enter a context manager that should be exited during tearDown.

        Returns the value produced by entering ``context_manager``.
        """
        return self._instance_teardown_stack.enter_context(context_manager)

    @final
    def add_instance_callback(self, callback):
        """
        Register a callback to be executed during tearDown.

        Parameters
        ----------
        callback : callable
            The callback to invoke at the end of each test.
        """
        return self._instance_teardown_stack.callback(callback)
def transform(self, stream_in):
    """
    Main generator work loop.

    Enters ``self.processor`` and ``ZiplineAPI(self.algo)`` for the whole
    simulation, then consumes ``(date, snapshot)`` pairs from ``stream_in``.

    - Before ``self.algo_start`` (warm-up), snapshot events only update state
      and the history container; nothing is yielded.
    - From ``self.algo_start`` on, each snapshot goes through
      ``_process_snapshot`` and the resulting perf messages are yielded.

    The risk report from ``handle_simulation_end()`` is yielded last.
    """
    # Initialize the mkt_close
    mkt_open = self.algo.perf_tracker.market_open
    mkt_close = self.algo.perf_tracker.market_close

    # inject the current algo
    # snapshot time to any log record generated.
    # ExitStack enters the processor and ZiplineAPI contexts here and
    # guarantees both are exited when the block ends, however it ends.
    with ExitStack() as stack:
        stack.enter_context(self.processor)
        stack.enter_context(ZiplineAPI(self.algo))

        data_frequency = self.sim_params.data_frequency

        self._call_before_trading_start(mkt_open)

        # Main loop, one iteration per date. `snapshot` is the iterable of
        # data events for that date (e.g. trades, splits, custom data).
        for date, snapshot in stream_in:
            self.simulation_dt = date  # simulated date
            self.on_dt_changed(date)

            # If we're still in the warmup period. Use the event to
            # update our universe, but don't yield any perf messages,
            # and don't send a snapshot to handle_data.
            # i.e. this checks whether the simulation start time has been
            # reached; only after that are snapshots sent on to handle_data.
            if date < self.algo_start:
                for event in snapshot:
                    if event.type == DATASOURCE_TYPE.SPLIT:
                        self.algo.blotter.process_split(event)

                    elif event.type == DATASOURCE_TYPE.TRADE:
                        self.update_universe(event)
                        self.algo.perf_tracker.process_trade(event)
                    elif event.type == DATASOURCE_TYPE.CUSTOM:
                        self.update_universe(event)
                if self.algo.history_container:
                    self.algo.history_container.update(
                        self.current_data, date)

            else:
                # Normal (post-warm-up) processing of this date's snapshot.
                messages = self._process_snapshot(
                    date,
                    snapshot,
                    self.algo.instant_fill,
                )
                # Perf messages are only emitted if the snapshot contained
                # a benchmark event.
                for message in messages:
                    yield message

                # When emitting minutely, we need to call
                # before_trading_start before the next trading day begins
                if date == mkt_close:
                    if mkt_close <= self.algo.perf_tracker.last_close:
                        # Captured before mkt_close is advanced below.
                        before_last_close = \
                            mkt_close < self.algo.perf_tracker.last_close
                        try:
                            mkt_open, mkt_close = \
                                self.env.next_open_and_close(mkt_close)

                        except NoFurtherDataError:
                            # If at the end of backtest history,
                            # skip advancing market close.
                            pass

                        if before_last_close:
                            self._call_before_trading_start(mkt_open)

                elif data_frequency == 'daily':
                    next_day = self.env.next_trading_day(date)

                    if next_day is not None and \
                       next_day < self.algo.perf_tracker.last_close:
                        self._call_before_trading_start(
                            next_day
                        )  # next day exists and is before last_close

                self.algo.portfolio_needs_update = True
                self.algo.account_needs_update = True
                self.algo.performance_needs_update = True

        risk_message = self.algo.perf_tracker.handle_simulation_end()
        yield risk_message
def dedup_tracked1(sess, tt, ofile_reserved, query, fs, skipped):
    """Hash candidate duplicates from ``query`` and clone identical files.

    For each ``comm1`` group yielded by ``query``:

    1. Pass 1: compute a mini hash for every inode of ``comm1``.
    2. Pass 2: compute a fiemap hash for every inode of each ``comm2`` in
       ``comm1.comm2``.
    3. Pass 3: for a ``comm2`` that has a ``comm3`` (exactly one is expected),
       open every inode read-write and hash its full contents with SHA-1.
       Then call ``clone_data`` between files whose digests match, and record
       a ``DedupEvent`` (plus ``DedupEventInode`` rows) for each group with
       at least one success.

    Args:
        sess: database session; ``execute``, ``flush``, ``delete``, ``add``
            and ``commit`` are used.
        tt: progress/notification object; ``update``, ``notify`` and
            ``format`` are used.
        ofile_reserved: file descriptors to keep in reserve when computing the
            required RLIMIT_OFILE soft limit.
        query: iterable of comm1 groups.
        fs: value stored as ``fs`` on every ``DedupEvent``.
        skipped: list that inodes which could not be processed are appended
            to (mutated in place).
    """
    space_gain1 = space_gain2 = space_gain3 = 0
    ofile_soft, ofile_hard = resource.getrlimit(resource.RLIMIT_OFILE)

    # Hopefully close any files we left around
    gc.collect()

    # The log can cause frequent commits, we don't mind losing them in
    # a crash (no need for durability). SQLite is in WAL mode, so this pragma
    # should disable most commit-time fsync calls without compromising
    # consistency.
    sess.execute('PRAGMA synchronous=NORMAL;')

    for comm1 in query:
        # Periodically flush to bound the number of objects held in the
        # session's identity map.
        if len(sess.identity_map) > 300:
            sess.flush()

        space_gain1 += comm1.size * (comm1.inode_count - 1)
        tt.update(comm1=comm1)
        for inode in comm1.inodes:
            # XXX Need to cope with deleted inodes.
            # We cannot find them in the search-new pass, not without doing
            # some tracking of directory modifications to poke updated
            # directories to find removed elements.

            # rehash everytime for now
            # I don't know enough about how inode transaction numbers are
            # updated (as opposed to extent updates) to be able to actually
            # cache the result
            try:
                path = lookup_ino_path_one(inode.vol.fd, inode.ino)
            except IOError as e:
                if e.errno != errno.ENOENT:
                    raise
                # We have a stale record for a removed inode
                # XXX If an inode number is reused and the second instance
                # is below the size cutoff, we won't update the .size
                # attribute and we won't get an IOError to notify us
                # either.  Inode reuse does happen (with and without
                # inode_cache), so this branch isn't enough to rid us of
                # all stale entries.  We can also get into trouble with
                # regular file inodes being replaced by some other kind of
                # inode.
                sess.delete(inode)
                continue
            with closing(fopenat(inode.vol.fd, path)) as rfile:
                inode.mini_hash_from_file(rfile)

        for comm2 in comm1.comm2:
            space_gain2 += comm2.size * (comm2.inode_count - 1)
            tt.update(comm2=comm2)
            for inode in comm2.inodes:
                try:
                    path = lookup_ino_path_one(inode.vol.fd, inode.ino)
                except IOError as e:
                    if e.errno != errno.ENOENT:
                        raise
                    sess.delete(inode)
                    continue
                with closing(fopenat(inode.vol.fd, path)) as rfile:
                    inode.fiemap_hash_from_file(rfile)

            if not comm2.comm3:
                continue

            # Single-element unpacking: fails if there is more than one.
            comm3, = comm2.comm3
            count3 = comm3.inode_count
            space_gain3 += comm3.size * (count3 - 1)
            tt.update(comm3=comm3)
            files = []
            fds = []
            fd_names = {}
            fd_inodes = {}
            by_hash = collections.defaultdict(list)

            # XXX I have no justification for doubling count3
            ofile_req = 2 * count3 + ofile_reserved
            if ofile_req > ofile_soft:
                if ofile_req <= ofile_hard:
                    resource.setrlimit(
                        resource.RLIMIT_OFILE, (ofile_req, ofile_hard))
                    ofile_soft = ofile_req
                else:
                    tt.notify(
                        'Too many duplicates (%d at size %d), '
                        'would bring us over the open files limit (%d, %d).'
                        % (count3, comm3.size, ofile_soft, ofile_hard))
                    for inode in comm3.inodes:
                        if inode.has_updates:
                            skipped.append(inode)
                    # Give up on this whole comm3 group.
                    continue

            # NOTE(review): files opened in this loop are only handed to the
            # ExitStack below; if an exception escapes before that point the
            # already-opened files are left for the garbage collector.
            for inode in comm3.inodes:
                # Open everything rw, we can't pick one for the source side
                # yet because the crypto hash might eliminate it.
                # We may also want to defragment the source.
                try:
                    path = lookup_ino_path_one(inode.vol.fd, inode.ino)
                except IOError as e:
                    if e.errno == errno.ENOENT:
                        sess.delete(inode)
                        continue
                    raise

                try:
                    afile = fopenat_rw(inode.vol.fd, path)
                except IOError as e:
                    if e.errno == errno.ETXTBSY:
                        # The file contains the image of a running process,
                        # we can't open it in write mode.
                        tt.notify('File %r is busy, skipping' % path)
                        skipped.append(inode)
                        continue
                    elif e.errno == errno.EACCES:
                        # Could be SELinux or immutability
                        tt.notify('Access denied on %r, skipping' % path)
                        skipped.append(inode)
                        continue
                    elif e.errno == errno.ENOENT:
                        # The file was moved or unlinked by a racing process
                        tt.notify('File %r may have moved, skipping' % path)
                        skipped.append(inode)
                        continue
                    raise

                # It's not completely guaranteed we have the right inode,
                # there may still be race conditions at this point.
                # Gets re-checked below (tell and fstat).
                fd = afile.fileno()
                fd_inodes[fd] = inode
                fd_names[fd] = path
                files.append(afile)
                fds.append(fd)

            with ExitStack() as stack:
                for afile in files:
                    stack.enter_context(closing(afile))
                # Enter this context last
                # (ExitStack is LIFO, so ImmutableFDs exits before any file
                # is closed.)
                immutability = stack.enter_context(ImmutableFDs(fds))

                for afile in files:
                    fd = afile.fileno()
                    inode = fd_inodes[fd]
                    if fd in immutability.fds_in_write_use:
                        tt.notify('File %r is in use, skipping' % fd_names[fd])
                        skipped.append(inode)
                        continue

                    # Full-content SHA-1, read in BUFSIZE chunks until EOF.
                    hasher = hashlib.sha1()
                    for buf in iter(lambda: afile.read(BUFSIZE), b''):
                        hasher.update(buf)

                    # Gets rid of a race condition
                    st = os.fstat(fd)
                    if st.st_ino != inode.ino:
                        skipped.append(inode)
                        continue
                    if st.st_dev != inode.vol.st_dev:
                        skipped.append(inode)
                        continue

                    # After reading to EOF, tell() is the file's size.
                    size = afile.tell()
                    if size != comm3.size:
                        if size < inode.vol.size_cutoff:
                            # if we didn't delete this inode, it would cause
                            # spurious comm groups in all future invocations.
                            sess.delete(inode)
                        else:
                            skipped.append(inode)
                        continue

                    by_hash[hasher.digest()].append(afile)

                for fileset in by_hash.itervalues():
                    if len(fileset) < 2:
                        continue
                    # The first file with a given digest is the clone source.
                    sfile = fileset[0]
                    sfd = sfile.fileno()
                    # Commented out, defragmentation can unshare extents.
                    # It can also disable compression as a side-effect.
                    if False:
                        defragment(sfd)
                    dfiles = fileset[1:]
                    dfiles_successful = []
                    for dfile in dfiles:
                        dfd = dfile.fileno()
                        sname = fd_names[sfd]
                        dname = fd_names[dfd]
                        if not cmp_files(sfile, dfile):
                            # Probably a bug since we just used a crypto hash
                            tt.notify('Files differ: %r %r' % (sname, dname))
                            # NOTE(review): stripped under ``python -O``, in
                            # which case the pair is just skipped.
                            assert False, (sname, dname)
                            continue
                        if clone_data(dest=dfd, src=sfd, check_first=True):
                            tt.notify('Deduplicated: %r %r' % (sname, dname))
                            dfiles_successful.append(dfile)
                        else:
                            tt.notify(
                                'Did not deduplicate (same extents): %r %r' % (
                                    sname, dname))
                    if dfiles_successful:
                        evt = DedupEvent(
                            fs=fs, item_size=comm3.size, created=system_now())
                        sess.add(evt)
                        for afile in [sfile] + dfiles_successful:
                            inode = fd_inodes[afile.fileno()]
                            evti = DedupEventInode(
                                event=evt, ino=inode.ino, vol=inode.vol)
                            sess.add(evti)
                        sess.commit()
    tt.format(None)
    tt.notify('Potential space gain: pass 1 %d, pass 2 %d pass 3 %d' % (
        space_gain1, space_gain2, space_gain3))
    # Restore fsync so that the final commit (in dedup_tracked)
    # will be durable.
    sess.commit()
    sess.execute('PRAGMA synchronous=FULL;')
def run_test(in_file, test_spec, global_cfg):
    """Run a single tavern test

    Note that each tavern test can consist of multiple requests (log in,
    create, update, delete, etc).

    The global configuration is copied and used as an initial configuration for
    this test. Any values which are saved from any tests are saved into this
    test block and can be used for formatting in later stages in the test.

    Args:
        in_file (str): filename containing this test
        test_spec (dict): The specification for this test
        global_cfg (dict): Any global configuration for this test

    Raises:
        exceptions.DuplicateStageDefinitionError: if two included stages
            share the same ``id``
        exceptions.TavernException: re-raised from ``run_stage`` when a stage
            fails; the failing ``stage`` and the ``test_block_config`` are
            attached to the exception before it propagates
    """
    # pylint: disable=too-many-locals

    # Initialise test config for this test with the global configuration.
    # The "variables" mapping is copied as well: dict() is a shallow copy, so
    # without this every write below (tavern box, included variables, values
    # saved by stages) would mutate global_cfg and leak into later tests.
    test_block_config = dict(global_cfg)
    test_block_config["variables"] = dict(
        test_block_config.get("variables") or {})

    tavern_box = Box({
        "env_vars": dict(os.environ),
    })

    test_block_config["variables"]["tavern"] = tavern_box

    if not test_spec:
        logger.warning("Empty test block in %s", in_file)
        return

    # Stages made available by included files, keyed by stage id
    available_stages = {}
    for included in test_spec.get("includes") or []:
        if "variables" in included:
            formatted_include = format_keys(included["variables"],
                                            {"tavern": tavern_box})
            test_block_config["variables"].update(formatted_include)

        if "stages" in included:
            for stage in included["stages"]:
                if stage["id"] in available_stages:
                    raise exceptions.DuplicateStageDefinitionError(
                        "Stage with specified id already defined: {}".format(
                            stage["id"]))
                available_stages[stage["id"]] = stage

    test_block_name = test_spec["test_name"]

    # Strict on body by default
    default_strictness = test_block_config["strict"]

    logger.info("Running test : %s", test_block_name)

    with ExitStack() as stack:
        test_spec["stages"] = _resolve_test_stages(test_spec, available_stages)

        sessions = get_extra_sessions(test_spec, test_block_config)

        # Keep every extra session open for the whole test
        for name, session in sessions.items():
            logger.debug("Entering context for %s", name)
            stack.enter_context(session)

        # Run tests in a path in order
        for stage in test_spec["stages"]:
            if stage.get('skip'):
                continue

            test_block_config["strict"] = default_strictness

            # Can be overridden per stage
            # NOTE
            # this is hardcoded to check for the 'response' block. In the far
            # future there might not be a response block, but at the moment it
            # is the hardcoded value for any HTTP request.
            response = stage.get("response", {})
            if response:
                # Precedence: stage response > test block > default
                if response.get("strict", None) is not None:
                    stage_strictness = response.get("strict", None)
                elif test_spec.get("strict", None) is not None:
                    stage_strictness = test_spec.get("strict", None)
                else:
                    stage_strictness = default_strictness

                logger.debug("Strict key checking for this stage is '%s'",
                             stage_strictness)

                test_block_config["strict"] = stage_strictness
            elif default_strictness:
                logger.debug("Default strictness '%s' ignored for this stage",
                             default_strictness)

            try:
                run_stage(sessions, stage, tavern_box, test_block_config)
            except exceptions.TavernException as e:
                # Attach context for reporting before propagating
                e.stage = stage
                e.test_block_config = test_block_config
                raise

            if stage.get('only'):
                break
def do_dedup(sess, tt, chunk):
    """Try to deduplicate one chunk of candidate inodes.

    Every inode in ``chunk`` is opened read-write. The opened descriptors are
    wrapped in ``ImmutableFDs`` while they are worked on. Their contents are
    hashed with SHA-1, and files with identical digests are cloned onto the
    first file of their group with ``clone_data``. Successful clones are
    recorded as ``DedupEvent`` / ``DedupEventInode`` rows in ``sess``.

    What happens to inodes that cannot be processed:

    * Appended to ``skipped``: busy, access denied, moved, in write use, or
      changed since they were recorded.
    * Deleted from ``sess``: the path no longer resolves, or the size changed
      and is now below the volume's ``size_cutoff``.

    Args:
        sess: database session used for deletes, inserts and commits.
        tt: progress reporter; only ``tt.notify`` is used here.
        chunk: sequence of tracked inode records. Each is read for ``vol``,
            ``ino``, ``size`` and ``has_updates`` below; presumably all share
            the same size — TODO confirm against the caller.
    """
    # Only ofile_soft is rebound here; ofile_hard, ofile_reserved and fs are
    # module-level values that are only read.
    # NOTE(review): ``skipped`` is not defined in this function either; it is
    # presumably a module-level list — confirm.
    global ofile_soft

    if not chunk:
        return

    files = []
    fds = []
    fd_names = {}
    fd_inodes = {}
    by_hash = collections.defaultdict(list)

    # XXX I have no justification for doubling the chunk size
    ofile_req = 2 * len(chunk) + ofile_reserved
    if ofile_req > ofile_soft:
        if ofile_req <= ofile_hard:
            resource.setrlimit(resource.RLIMIT_OFILE, (ofile_req, ofile_hard))
            ofile_soft = ofile_req
        else:
            tt.notify('Too many duplicates (%d at size %d), '
                      'would bring us over the open files limit (%d, %d).'
                      % (len(chunk), chunk[0].size, ofile_soft, ofile_hard))
            for inode in chunk:
                if inode.has_updates:
                    skipped.append(inode)
            # Opening the files anyway would exceed the descriptor limit.
            return

    # errno values from opening a file that mean "skip this inode" rather
    # than abort the whole chunk.
    open_skip_messages = {
        # The file contains the image of a running process,
        # we can't open it in write mode.
        errno.ETXTBSY: 'File %r is busy, skipping',
        # Could be SELinux or immutability
        errno.EACCES: 'Access denied on %r, skipping',
        # The file was moved or unlinked by a racing process
        errno.ENOENT: 'File %r may have moved, skipping',
    }

    with ExitStack() as stack:
        for inode in chunk:
            # Open everything rw, we can't pick one for the source side
            # yet because the crypto hash might eliminate it.
            try:
                path = lookup_ino_path_one(inode.vol.fd, inode.ino)
            except IOError as e:
                if e.errno == errno.ENOENT:
                    sess.delete(inode)
                    continue
                raise
            try:
                afile = fopenat_rw(inode.vol.fd, path)
            except IOError as e:
                message = open_skip_messages.get(e.errno)
                if message is None:
                    raise
                tt.notify(message % path)
                skipped.append(inode)
                continue

            # Register for closing immediately so an exception later in this
            # loop cannot leak the descriptors opened so far.
            stack.enter_context(closing(afile))

            # It's not completely guaranteed we have the right inode,
            # there may still be race conditions at this point.
            # Gets re-checked below (tell and fstat).
            fd = afile.fileno()
            fd_inodes[fd] = inode
            fd_names[fd] = path
            files.append(afile)
            fds.append(fd)

        # Enter this context last (it is exited first, before files close)
        immutability = stack.enter_context(ImmutableFDs(fds))

        for afile in files:
            fd = afile.fileno()
            inode = fd_inodes[fd]
            if fd in immutability.fds_in_write_use:
                tt.notify('File %r is in use, skipping' % fd_names[fd])
                skipped.append(inode)
                continue

            hasher = hashlib.sha1()
            for buf in iter(lambda: afile.read(BUFSIZE), b''):
                hasher.update(buf)

            # Gets rid of a race condition
            st = os.fstat(fd)
            if st.st_ino != inode.ino or st.st_dev != inode.vol.st_dev:
                skipped.append(inode)
                continue

            size = afile.tell()
            if size != inode.size:
                if size < inode.vol.size_cutoff:
                    # if we didn't delete this inode, it would cause
                    # spurious comm groups in all future invocations.
                    sess.delete(inode)
                else:
                    skipped.append(inode)
                continue

            by_hash[hasher.digest()].append(afile)

        for fileset in by_hash.itervalues():
            if len(fileset) < 2:
                continue
            sfile = fileset[0]
            sfd = sfile.fileno()
            sname = fd_names[sfd]
            # Defragmenting the source (defragment(sfd)) is deliberately
            # disabled: defragmentation can unshare extents, and it can also
            # disable compression as a side-effect.
            dfiles_successful = []
            for dfile in fileset[1:]:
                dfd = dfile.fileno()
                dname = fd_names[dfd]
                if not cmp_files(sfile, dfile):
                    # Probably a bug since we just used a crypto hash
                    tt.notify('Files differ: %r %r' % (sname, dname))
                    assert False, (sname, dname)
                    continue
                if clone_data(dest=dfd, src=sfd, check_first=True):
                    tt.notify('Deduplicated: %r %r' % (sname, dname))
                    dfiles_successful.append(dfile)
                else:
                    tt.notify(
                        'Did not deduplicate (same extents): %r %i %r %i' % (
                            sname, fd_inodes[sfd].ino,
                            dname, fd_inodes[dfd].ino))
            if dfiles_successful:
                # Size of this group's source inode (verified above to match
                # its on-disk size), not whatever inode the hashing loop
                # happened to visit last.
                evt = DedupEvent(
                    fs=fs, item_size=fd_inodes[sfd].size,
                    created=system_now())
                sess.add(evt)
                for afile in [sfile] + dfiles_successful:
                    inode = fd_inodes[afile.fileno()]
                    evti = DedupEventInode(
                        event=evt, ino=inode.ino, vol=inode.vol)
                    sess.add(evti)
                sess.commit()
def ingest(name,
           environ=os.environ,
           timestamp=None,
           assets_versions=(),
           show_progress=False):
    """Ingest data for a given bundle.

    Parameters
    ----------
    name : str
        The name of the bundle.
    environ : mapping, optional
        The environment variables. By default this is os.environ.
    timestamp : datetime, optional
        The timestamp to use for the load.
        By default this is the current time. Timezone-aware values are
        converted to UTC; naive values are interpreted as UTC.
    assets_versions : Iterable[int], optional
        Versions of the assets db to which to downgrade.
    show_progress : bool, optional
        Tell the ingest function to display the progress where possible.

    Raises
    ------
    UnknownBundle
        If no bundle is registered under ``name``.
    ValueError
        If ``assets_versions`` is non-empty for a bundle that does not
        create writers.
    """
    try:
        bundle = bundles[name]
    except KeyError:
        raise UnknownBundle(name)

    if timestamp is None:
        timestamp = pd.Timestamp.now(tz='UTC')
    else:
        # Accept plain datetimes (as documented), not only pd.Timestamp, and
        # treat naive values as UTC instead of letting tz_convert raise.
        timestamp = pd.Timestamp(timestamp)
        if timestamp.tzinfo is None:
            timestamp = timestamp.tz_localize('UTC')
    timestamp = timestamp.tz_convert('utc').tz_localize(None)

    timestr = to_bundle_ingest_dirname(timestamp)
    cachepath = cache_path(name, environ=environ)
    output_dir = pth.data_path([name, timestr], environ=environ)
    pth.ensure_directory(output_dir)
    pth.ensure_directory(cachepath)

    with dataframe_cache(cachepath, clean_on_failure=False) as cache, \
            ExitStack() as stack:
        # we use `clean_on_failure=False` so that we don't purge the
        # cache directory if the load fails in the middle
        if bundle.create_writers:
            wd = stack.enter_context(
                working_dir(pth.data_path([], environ=environ)))
            daily_bars_path = wd.ensure_dir(*daily_equity_relative(
                name, timestr, environ=environ,
            ))
            daily_bar_writer = BcolzDailyBarWriter(
                daily_bars_path,
                bundle.calendar,
                bundle.start_session,
                bundle.end_session,
            )
            # Do an empty write to ensure that the daily ctables exist
            # when we create the SQLiteAdjustmentWriter below. The
            # SQLiteAdjustmentWriter needs to open the daily ctables so
            # that it can compute the adjustment ratios for the dividends.
            daily_bar_writer.write(())
            minute_bar_writer = BcolzMinuteBarWriter(
                wd.ensure_dir(*minute_equity_relative(
                    name, timestr, environ=environ)),
                bundle.calendar,
                bundle.start_session,
                bundle.end_session,
                minutes_per_day=bundle.minutes_per_day,
            )
            assets_db_path = wd.getpath(*asset_db_relative(
                name, timestr, environ=environ,
            ))
            asset_db_writer = AssetDBWriter(assets_db_path)

            adjustment_db_writer = stack.enter_context(
                SQLiteAdjustmentWriter(
                    wd.getpath(*adjustment_db_relative(
                        name, timestr, environ=environ)),
                    BcolzDailyBarReader(daily_bars_path),
                    bundle.calendar.all_sessions,
                    overwrite=True,
                ))
        else:
            daily_bar_writer = None
            minute_bar_writer = None
            asset_db_writer = None
            adjustment_db_writer = None
            if assets_versions:
                raise ValueError('Need to ingest a bundle that creates '
                                 'writers in order to downgrade the assets'
                                 ' db.')
        bundle.ingest(
            environ,
            asset_db_writer,
            minute_bar_writer,
            daily_bar_writer,
            adjustment_db_writer,
            bundle.calendar,
            bundle.start_session,
            bundle.end_session,
            cache,
            show_progress,
            output_dir,
        )

        # Only reachable with versions when create_writers is true (checked
        # above), so wd and assets_db_path are bound here.
        for version in sorted(set(assets_versions), reverse=True):
            version_path = wd.getpath(*asset_db_relative(
                name, timestr, environ=environ, db_version=version,
            ))
            with working_file(version_path) as wf:
                shutil.copy2(assets_db_path, wf.path)
                downgrade(wf.path, version)
def __call__(s, *args, **kwargs):
    """Run the task inside the Flask app, DB and plugin contexts.

    ``s`` is the task instance; ``self`` is the enclosing object that owns
    ``flask_app``. All contexts are entered on a single ExitStack, which is
    unwound on every exit path. That includes exceptions raised while
    unwrapping arguments or resolving the plugin, which previously left the
    app context and DB connection open.

    Raises:
        ValueError: if the task names a plugin that is not active.
    """
    with ExitStack() as stack:
        stack.enter_context(self.flask_app.app_context())
        stack.enter_context(DBMgr.getInstance().global_connection())
        if getattr(s, 'request_context', False):
            stack.enter_context(self.flask_app.test_request_context())
        args = _CelerySAWrapper.unwrap_args(args)
        kwargs = _CelerySAWrapper.unwrap_kwargs(kwargs)
        # The pop() argument is evaluated even when the task defines
        # ``plugin``, so '__current_plugin__' is always removed from kwargs
        # before the task runs.
        plugin = getattr(s, 'plugin', kwargs.pop('__current_plugin__', None))
        if isinstance(plugin, basestring):
            plugin_name = plugin
            plugin = plugin_engine.get_plugin(plugin)
            if plugin is None:
                raise ValueError('Plugin not active: ' + plugin_name)
        stack.enter_context(plugin_context(plugin))
        clearCache()
        return super(IndicoTask, s).__call__(*args, **kwargs)
def compare_with_cpu_command(args):
    """Compare optimizer results between two configurations.

    Steps:

    1. Load ``args.config`` under the ``'current'`` parameter scope.
    2. Load ``args.config2`` under the ``'cpu'`` parameter scope.
    3. Open every optimizer's data iterator for both configurations and call
       ``compare_optimizer`` on them.
    4. Write the resulting rows to ``<args.outdir>/compare_with_cpu.csv``.

    Returns:
        bool: ``True`` once the result file has been written.
    """
    configure_progress(os.path.join(args.outdir, 'progress.txt'))

    class TrainConfig:
        pass

    class OptConfig:
        pass

    class MonConfig:
        pass

    def load_config(config_file, scope_name):
        """Load ``config_file`` under ``scope_name``; return (config, parameters)."""
        with nn.parameter_scope(scope_name):
            info = load.load([config_file])
            parameters = get_parameters(grad_only=False)

        config = TrainConfig()
        config.global_config = info.global_config
        config.training_config = info.training_config

        # Data iterators are opened later, inside the ExitStack below
        config.optimizers = OrderedDict()
        for name, opt in info.optimizers.items():
            o = OptConfig()
            o.optimizer = opt
            o.data_iterator = None
            config.optimizers[name] = o

        config.monitors = OrderedDict()
        for name, mon in info.monitors.items():
            m = MonConfig()
            m.monitor = mon
            m.data_iterator = None
            config.monitors[name] = m

        return config, parameters

    # Load config with current context, then with cpu context
    config, parameters = load_config(args.config, 'current')
    config_cpu, cpu_parameters = load_config(args.config2, 'cpu')

    result_array = [['1-Correl']]

    # Profile Optimizer
    with ExitStack() as stack:
        for o in config.optimizers.values():
            o.data_iterator = stack.enter_context(
                o.optimizer.data_iterator())
        for o in config_cpu.optimizers.values():
            o.data_iterator = stack.enter_context(
                o.optimizer.data_iterator())
        result_array = compare_optimizer(
            config, parameters, config_cpu, cpu_parameters, result_array)

    # Write profiling result
    import csv
    with open(os.path.join(args.outdir, 'compare_with_cpu.csv'), 'w') as f:
        writer = csv.writer(f, lineterminator='\n')
        writer.writerows(result_array)

    logger.log(99, 'Compare with CPU Completed.')
    progress(None)
    return True
class ZiplineTestCase(with_metaclass(FinalMeta, TestCase)):
    """
    Shared extensions to core unittest.TestCase.

    Overrides the default unittest setUp/tearDown functions with versions that
    use ExitStack to correctly clean up resources, even in the face of
    exceptions that occur during setUp/setUpClass.

    Subclasses **should not override setUp or setUpClass**!

    Instead, they should implement `init_instance_fixtures` for per-test-method
    resources, and `init_class_fixtures` for per-class resources.

    Resources that need to be cleaned up should be registered using
    either `enter_{class,instance}_context` or `add_{class,instance}_callback`.
    """
    # Set to True on the concrete class while setUp() runs, so class-level
    # registration attempted from instance fixtures can be rejected.
    _in_setup = False

    @final
    @classmethod
    def setUpClass(cls):
        # Hold a set of all the "static" attributes on the class. These are
        # things that are not populated after the class was created like
        # methods or other class level attributes.
        cls._static_class_attributes = set(vars(cls))
        cls._class_teardown_stack = ExitStack()
        try:
            cls._base_init_fixtures_was_called = False
            cls.init_class_fixtures()
            # Explicit raise rather than `assert` so the check survives -O.
            if not cls._base_init_fixtures_was_called:
                raise AssertionError(
                    "ZiplineTestCase.init_class_fixtures() was not called.\n"
                    "This probably means that you overrode init_class_fixtures"
                    " without calling super()."
                )
        except BaseException:
            # Release whatever was registered so far, then propagate.
            cls.tearDownClass()
            raise

    @classmethod
    def init_class_fixtures(cls):
        """
        Override and implement this classmethod to register resources that
        should be created and/or torn down on a per-class basis.

        Subclass implementations of this should always invoke this with super()
        to ensure that fixture mixins work properly.
        """
        if cls._in_setup:
            raise ValueError(
                'Called init_class_fixtures from init_instance_fixtures.\n'
                'Did you write super(..., self).init_class_fixtures() instead'
                ' of super(..., self).init_instance_fixtures()?',
            )
        cls._base_init_fixtures_was_called = True

    @final
    @classmethod
    def tearDownClass(cls):
        # We need to get this before it's deleted by the loop.
        stack = cls._class_teardown_stack
        for name in set(vars(cls)) - cls._static_class_attributes:
            # Remove all of the attributes that were added after the class was
            # constructed. This cleans up any large test data that is class
            # scoped while still allowing subclasses to access class level
            # attributes.
            delattr(cls, name)
        stack.close()

    @final
    @classmethod
    def enter_class_context(cls, context_manager):
        """
        Enter a context manager to be exited during the tearDownClass
        """
        if cls._in_setup:
            raise ValueError(
                'Attempted to enter a class context in init_instance_fixtures.'
                '\nDid you mean to call enter_instance_context?',
            )
        return cls._class_teardown_stack.enter_context(context_manager)

    @final
    @classmethod
    def add_class_callback(cls, callback):
        """
        Register a callback to be executed during tearDownClass.

        Parameters
        ----------
        callback : callable
            The callback to invoke at the end of the test suite.
        """
        if cls._in_setup:
            raise ValueError(
                'Attempted to add a class callback in init_instance_fixtures.'
                '\nDid you mean to call add_instance_callback?',
            )
        return cls._class_teardown_stack.callback(callback)

    @final
    def setUp(self):
        type(self)._in_setup = True
        # Attributes present before fixtures run; tearDown removes the rest.
        self._pre_setup_attrs = set(vars(self))
        self._instance_teardown_stack = ExitStack()
        try:
            self._init_instance_fixtures_was_called = False
            self.init_instance_fixtures()
            # Explicit raise rather than `assert` so the check survives -O.
            if not self._init_instance_fixtures_was_called:
                raise AssertionError(
                    "ZiplineTestCase.init_instance_fixtures() was not"
                    " called.\n"
                    "This probably means that you overrode"
                    " init_instance_fixtures without calling super()."
                )
        except BaseException:
            # Release whatever was registered so far, then propagate.
            self.tearDown()
            raise
        finally:
            type(self)._in_setup = False

    def init_instance_fixtures(self):
        """Override to register per-test resources; always call super()."""
        self._init_instance_fixtures_was_called = True

    @final
    def tearDown(self):
        # We need to get this before it's deleted by the loop.
        stack = self._instance_teardown_stack
        for attr in set(vars(self)) - self._pre_setup_attrs:
            delattr(self, attr)
        stack.close()

    @final
    def enter_instance_context(self, context_manager):
        """
        Enter a context manager that should be exited during tearDown.
        """
        return self._instance_teardown_stack.enter_context(context_manager)

    @final
    def add_instance_callback(self, callback):
        """
        Register a callback to be executed during tearDown.

        Parameters
        ----------
        callback : callable
            The callback to invoke at the end of each test.
        """
        return self._instance_teardown_stack.callback(callback)
# Command-line entry point: runs the test program, optionally wrapped in
# drop_into_debugger() when '--debug' appears on the command line.
import sys

from .program import TestProgram
from .is_standalone import is_standalone_use
from .utils import PY2, drop_into_debugger

# ExitStack comes from the standard library on Python 3 and from the
# contextlib2 backport on Python 2.
if not PY2:
    from contextlib import ExitStack
else:
    from contextlib2 import ExitStack

with ExitStack() as stack:
    # The debugger context is entered only on request and is unwound when
    # the with-block exits.
    if '--debug' in sys.argv:
        stack.enter_context(drop_into_debugger())

    is_standalone_use(False)
    TestProgram(module=None)