def test_add_1(self):
    # can add one strategy, and adding it again raises error.
    sm = StrategyManager()
    self.assertEqual(len(sm.strategies.keys()), 0)
    sm.add_strategy(self.saa)
    self.assertEqual(len(sm.strategies.keys()), 1)
    # re-adding the very same strategy is rejected.
    self.assertRaises(DuplicateStrategyError, sm.add_strategy, self.saa)
def test_get_strategy_by_name_2(self):
    # Asking the manager for a name that none of its strategies carries
    # must raise MissingStrategyError.
    manager = StrategyManager()
    for managed in (self.saa, self.sab, self.sba):
        manager.add_strategy(managed)
    unknown_name = self.sabnnn.name
    self.assertRaises(MissingStrategyError, manager.get_strategy,
                      unknown_name)
def test_add_7(self):
    # cannot resolve conflict on methods_used name.
    sm = StrategyManager()
    sm.add_strategy(self.sba)
    self.assertEqual(len(sm.strategies.keys()), 1)
    # because sabnba has different methods_used but the same
    # methods_used_name as the existing Strategy.
    self.assertRaises(MethodsUsedNameForbiddenError, sm.add_strategy,
                      self.sabnba)
def test_add_6(self):
    # cannot resolve conflict on settings name if methods_used match.
    sm = StrategyManager()
    sm.add_strategy(self.saa)
    self.assertEqual(len(sm.strategies.keys()), 1)
    # because sabnaa has the same methods_used/methods_used_name,
    # different settings AND the same settings_name.
    self.assertRaises(SettingsNameForbiddenError, sm.add_strategy,
                      self.sabnaa)
def test_add_2(self):
    # can add two different strategies; adding either again raises error.
    # (same methods used, different settings)
    sm = StrategyManager()
    sm.add_strategy(self.saa)
    self.assertEqual(len(sm.strategies.keys()), 1)
    sm.add_strategy(self.sab)
    self.assertEqual(len(sm.strategies.keys()), 2)
    self.assertRaises(DuplicateStrategyError, sm.add_strategy, self.sab)
    self.assertRaises(DuplicateStrategyError, sm.add_strategy, self.saa)
def test_add_5(self):
    # when adding a strategy, it doesn't matter what its name is: if it is
    # already under management, it raises an error.
    sm = StrategyManager()
    sm.add_strategy(self.saa)
    self.assertEqual(len(sm.strategies.keys()), 1)
    sm.add_strategy(self.sab)
    self.assertEqual(len(sm.strategies.keys()), 2)
    # because abnba is already there, but known as A(b) not B(a)
    self.assertRaises(DuplicateStrategyError, sm.add_strategy, self.sabnba)
    self.assertRaises(DuplicateStrategyError, sm.add_strategy, self.sabnnn)
def __init__(self, module_suffix=None):
    """Create the managers, select the default strategy and hook callbacks.

    module_suffix is accepted but not used anywhere in this body.
    """
    # Set up the per-user "spikepy" directories (presumably creating them
    # when missing -- confirm in path_utils) before any manager is built.
    path_utils.setup_user_directories(app_name="spikepy")
    # config_manager and plugin_manager are not created here; they are
    # looked up from the enclosing (module) scope.
    self.config_manager = config_manager
    self.trial_manager = TrialManager()
    self.plugin_manager = plugin_manager
    self.strategy_manager = StrategyManager()
    self.strategy_manager.load_all_strategies()
    # The backing attribute must exist before assigning through the
    # current_strategy property (assumed to read _current_strategy -- confirm).
    self._current_strategy = None
    self.current_strategy = self.get_default_strategy()
    self.process_manager = ProcessManager(self.trial_manager)
    # register callback for open_files
    self.process_manager.open_files.add_callback(self._files_opened, takes_target_results=True)
def test_get_strategy_by_name(self):
    # can get strategy under management given its name as well.
    sm = StrategyManager()
    sm.add_strategy(self.saa)
    sm.add_strategy(self.sab)
    sm.add_strategy(self.sba)
    sab_managed = sm.get_strategy(self.sab.name)
    # the managed strategy is an equal copy, not the object we added.
    self.assertTrue(sab_managed is not self.sab)
    self.assertEqual(sab_managed, self.sab)
def test_get_strategy_by_name_and_by_strategy(self):
    # Looking a strategy up by name or by an equivalent Strategy object must
    # hand back the very same managed object; any other argument type is
    # rejected with ArgumentTypeError.
    manager = StrategyManager()
    for managed in (self.saa, self.sab, self.sba):
        manager.add_strategy(managed)
    found_by_name = manager.get_strategy(self.sab.name)
    found_by_object = manager.get_strategy(self.sab)
    self.assertTrue(found_by_name is not self.sab)
    self.assertTrue(found_by_name is found_by_object)
    self.assertRaises(ArgumentTypeError, manager.get_strategy, True)
def test_get_strategy_name(self):
    # get_strategy_name reports the name a strategy has, or would get,
    # under this manager.
    sm = StrategyManager()
    # nothing is managed yet, so both name parts are the custom placeholders.
    name = sm.get_strategy_name(self.saa)
    self.assertEqual(name, msn(pt.CUSTOM_SC, pt.CUSTOM_LC))
    # after adding saa, name should be same as saa
    sm.add_strategy(self.saa)
    name = sm.get_strategy_name(self.saa)
    self.assertEqual(name, self.saa.name)
    # name should be same methods_used part, but custom settings for sab
    name = sm.get_strategy_name(self.sab)
    self.assertEqual(name, msn(self.sab.methods_used_name, pt.CUSTOM_LC))
    name = sm.get_strategy_name(self.sba)
    self.assertEqual(name, msn(pt.CUSTOM_SC, pt.CUSTOM_LC))
def test_get_strategy_by_strategy(self):
    # can get the strategy from manager if we have equivalent strategy,
    # regardless of names.
    # managed strategies are copies of strategies, not references.
    sm = StrategyManager()
    sm.add_strategy(self.saa)
    sm.add_strategy(self.sab)
    sm.add_strategy(self.sba)
    sab_managed = sm.get_strategy(self.sab)
    sabnnn_managed = sm.get_strategy(self.sabnnn)
    self.assertTrue(sab_managed is not self.sab)
    self.assertEqual(sab_managed, self.sab)
    self.assertTrue(sabnnn_managed is not self.sabnnn)
    self.assertEqual(sabnnn_managed, self.sab)
def test_add_4(self):
    # adding strategy with a methods_used set that already has a name:
    # the added strategy will adopt the existing name.
    sm = StrategyManager()
    sm.add_strategy(self.saa)
    self.assertEqual(len(sm.strategies.keys()), 1)
    sm.add_strategy(self.sabnnn)
    self.assertEqual(len(sm.strategies.keys()), 2)
    # stored strategy is called A(none) and is same as sab
    smsabnnn = sm.get_strategy(msn('a', 'none'))
    self.assertEqual(smsabnnn, self.sab)
    # because ab is already there, but known as A(none)
    self.assertRaises(DuplicateStrategyError, sm.add_strategy, self.sab)
    self.assertRaises(DuplicateStrategyError, sm.add_strategy, smsabnnn)
    self.assertRaises(DuplicateStrategyError, sm.add_strategy, self.sabnnn)
def test_remove_strategy(self):
    # can remove strategies by name or by strategy.
    sm = StrategyManager()
    sm.add_strategy(self.saa)
    sm.add_strategy(self.sab)
    sm.add_strategy(self.sba)
    self.assertEqual(len(sm.strategies.keys()), 3)
    # remove by strategy object
    sm.remove_strategy(self.sab)
    self.assertEqual(len(sm.strategies.keys()), 2)
    self.assertRaises(MissingStrategyError, sm.get_strategy, self.sab)
    # re-add, then remove by name
    sm.add_strategy(self.sab)
    self.assertEqual(len(sm.strategies.keys()), 3)
    sm.remove_strategy(self.sab.name)
    self.assertEqual(len(sm.strategies.keys()), 2)
    self.assertRaises(MissingStrategyError, sm.get_strategy, self.sab.name)
def test_get_strategy_by_strategy_2(self):
    # Looking up a Strategy whose equivalent was never added must raise
    # MissingStrategyError.
    manager = StrategyManager()
    for managed in (self.saa, self.sba):
        manager.add_strategy(managed)
    self.assertRaises(MissingStrategyError, manager.get_strategy, self.sab)
class Session(object): def __init__(self, module_suffix=None): path_utils.setup_user_directories(app_name="spikepy") self.config_manager = config_manager self.trial_manager = TrialManager() self.plugin_manager = plugin_manager self.strategy_manager = StrategyManager() self.strategy_manager.load_all_strategies() self._current_strategy = None self.current_strategy = self.get_default_strategy() self.process_manager = ProcessManager(self.trial_manager) # register callback for open_files self.process_manager.open_files.add_callback(self._files_opened, takes_target_results=True) # FILE RELATED def export(self, data_interpreter_name, base_path=None, **kwargs): if base_path is None: base_path = os.getcwd() di = self.plugin_manager.data_interpreters[data_interpreter_name] return di.write_data_file(self.marked_trials, base_path, **kwargs) def load(self, filename): """Load session from a file.""" return self.open_file(filename) def open_file(self, fullpath): """Open file located at fullpath.""" return self.process_manager.open_file(fullpath) @supports_callbacks def open_files(self, fullpaths): """Open the files located at fullpaths""" return self.process_manager.open_files(fullpaths) def save(self, filename, gzipped=True): """Save this session.""" if not filename.endswith(".ses"): filename = "%s.ses" % filename trial_dicts = [] for trial in self.trials: trial_dicts.append(trial.as_dict) strategy_dict = self.current_strategy.as_dict session_dict = {"trials": trial_dicts, "strategy": strategy_dict} if gzipped: ofile = gzip.open(filename, "wb") else: ofile = open(filename, "wb") cPickle.dump(session_dict, ofile, protocol=-1) ofile.close() return filename # TRIAL RELATED def get_trial(self, name_or_id): """Return the trial with the given name or id""" if isinstance(name_or_id, uuid.UUID): return self.trial_manager.get_trial_with_id(name_or_id) else: return self.trial_manager.get_trial_with_name(name_or_id) def get_trial_with_name(self, name): """ Find the trial with 
display_name=<name> and return it. Raises RuntimeError if trial cannot be found. """ return self.trial_manager.get_trial_with_name(name) def mark_all_trials(self, status=True): """Mark all trials according to <status>""" for trial in self.trials: try: self.mark_trial(trial.display_name, status) except CannotMarkTrialError: pass @supports_callbacks def mark_trial(self, name_or_id, status=True): """Mark trial with name_or_id according to <status>.""" trial = self.get_trial(name_or_id) return self.trial_manager.mark_trial(trial.display_name, status=status) @property def marked_trials(self): """Return all currently marked trials.""" return self.trial_manager.marked_trials def remove_marked_trials(self): """Remove all currently marked trials.""" results = [] for trial in self.marked_trials: results.append(self.remove_trial(self.marked_trials)) return results @supports_callbacks def remove_trial(self, name_or_id): """Remove the trial with name or id given.""" trial = self.get_trial(name_or_id) return self.trial_manager.remove_trial(trial) @supports_callbacks def rename_trial(self, old_name_or_id, proposed_name): """Find trial with <old_name_or_id> and rename it to <proposed_name>.""" trial = self.get_trial(old_name_or_id) return self.trial_manager.rename_trial(trial.display_name, proposed_name) @property def trials(self): """Return all currently marked and unmarked trials.""" return self.trial_manager.trials def visualize(self, trial_name, visualization_name, **kwargs): """ Generate and display the visualization with the given <visualization_name> (or name subset) using the trial with name <trial_name>. 
""" visualization = self.plugin_manager.visualizations[visualization_name] trial = self.get_trial(trial_name) return visualization.draw(trial, **kwargs) # STRATEGY RELATED @property def current_strategy(self): """The currently selected strategy.""" return self._current_strategy @current_strategy.setter def current_strategy(self, strategy_or_name): """Set the current strategy with either a name or a Strategy object.""" if isinstance(strategy_or_name, Strategy): if strategy_or_name is not self.current_strategy: try: strategy = self.strategy_manager.get_strategy(strategy_or_name) except MissingStrategyError: # its okay if not under management strategy = strategy_or_name strategy_or_name.name = self.strategy_manager.get_strategy_name(strategy_or_name) self._set_current_strategy(strategy_or_name) else: strategy = self.strategy_manager.get_strategy(strategy_or_name) if strategy is not self.current_strategy: self._set_current_strategy(strategy) @supports_callbacks def _set_current_strategy(self, strategy): validated_strategy = self.plugin_manager.validate_strategy(strategy) self._current_strategy = strategy return strategy def save_current_strategy(self, strategy_name): """Save the current strategy, giving it the name <strategy_name>""" self.strategy_manager.save_current_strategy(strategy_name) # RUN RELATED def join_run(self): """Join the run thread (if there is one).""" if hasattr(self, "_run_thread"): self._run_thread.join() @property def is_running(self): if hasattr(self, "_run_thread"): return self._run_thread.is_alive() else: return False def run(self, stage_name=None, strategy=None, message_queue=multiprocessing.Queue(), async=False): """ Run the given strategy (defaults to current_strategy), or a stage from that strategy. Results are placed into the appropriate trial's resources. Inputs: strategy: A Strategy object. If not passed, session.current_strategy will be used. stage_name: If passed, only that stage will be run. 
message_queue: If passed, will be populated with run messages. async: If True, processing will run in a separate thread. This thread can be joined with session.join_run() """ if strategy is None or not isinstance(strategy, Strategy): strategy = self.current_strategy # if still none, then abort run. if strategy is None: raise NoCurrentStrategyError("You must supply a strategy or set the session's current strategy.") self.process_manager.prepare_to_run_strategy(strategy, stage_name=stage_name) self._run_thread = threading.Thread( target=self.process_manager.run_tasks, kwargs={"message_queue": message_queue} ) self._run_thread.start() if not async: self._run_thread.join()