예제 #1
0
 def test_add_1(self):
     """A strategy can be added once; adding it again raises an error."""
     sm = StrategyManager()
     self.assertEqual(len(sm.strategies.keys()), 0)
     sm.add_strategy(self.saa)
     self.assertEqual(len(sm.strategies.keys()), 1)
     # a second add of the very same strategy must be rejected
     self.assertRaises(DuplicateStrategyError, sm.add_strategy, self.saa)
예제 #2
0
 def test_get_strategy_by_name_2(self):
     """get_strategy raises for a name that is not under management."""
     manager = StrategyManager()
     for managed in (self.saa, self.sab, self.sba):
         manager.add_strategy(managed)
     # sabnnn itself was never added, so its name must not resolve
     unmanaged_name = self.sabnnn.name
     self.assertRaises(MissingStrategyError, manager.get_strategy,
             unmanaged_name)
예제 #3
0
 def test_add_7(self):
     """A conflict on methods_used_name cannot be resolved, so adding fails."""
     sm = StrategyManager()
     sm.add_strategy(self.sba)
     self.assertEqual(len(sm.strategies.keys()), 1)
     # sabnba has different methods_used but the same methods_used_name
     # as the Strategy already under management.
     self.assertRaises(MethodsUsedNameForbiddenError, sm.add_strategy,
             self.sabnba)
예제 #4
0
 def test_add_6(self):
     """A settings_name conflict cannot be resolved when methods_used match."""
     sm = StrategyManager()
     sm.add_strategy(self.saa)
     self.assertEqual(len(sm.strategies.keys()), 1)
     # sabnaa has the same methods_used/methods_used_name as saa,
     # different settings, AND the same settings_name.
     self.assertRaises(SettingsNameForbiddenError, sm.add_strategy,
             self.sabnaa)
예제 #5
0
 def test_add_2(self):
     """Two different strategies can be added; re-adding either raises.

     Both strategies share the same methods_used but differ in settings.
     """
     sm = StrategyManager()
     sm.add_strategy(self.saa)
     self.assertEqual(len(sm.strategies.keys()), 1)
     sm.add_strategy(self.sab)
     self.assertEqual(len(sm.strategies.keys()), 2)
     self.assertRaises(DuplicateStrategyError, sm.add_strategy, self.sab)
     self.assertRaises(DuplicateStrategyError, sm.add_strategy, self.saa)
예제 #6
0
 def test_add_5(self):
     """Adding a strategy already under management raises, whatever its name."""
     sm = StrategyManager()
     sm.add_strategy(self.saa)
     self.assertEqual(len(sm.strategies.keys()), 1)
     sm.add_strategy(self.sab)
     self.assertEqual(len(sm.strategies.keys()), 2)
     # sabnba is already there, but known as A(b) rather than B(a)
     self.assertRaises(DuplicateStrategyError, sm.add_strategy, self.sabnba)
     self.assertRaises(DuplicateStrategyError, sm.add_strategy, self.sabnnn)
예제 #7
0
    def __init__(self, module_suffix=None):
        """Set up a session: user directories, managers and a default strategy.

        module_suffix is accepted but not used in this body.
        """
        path_utils.setup_user_directories(app_name="spikepy")

        self.config_manager = config_manager
        self.trial_manager = TrialManager()
        self.plugin_manager = plugin_manager
        self.strategy_manager = StrategyManager()
        self.strategy_manager.load_all_strategies()
        # Initialized before assigning current_strategy, presumably because
        # the current_strategy setter compares against the existing value.
        self._current_strategy = None
        self.current_strategy = self.get_default_strategy()
        self.process_manager = ProcessManager(self.trial_manager)

        # register callback for open_files
        self.process_manager.open_files.add_callback(self._files_opened, takes_target_results=True)
예제 #8
0
    def test_get_strategy_by_name(self):
        """A managed strategy can also be fetched by its name."""
        sm = StrategyManager()
        sm.add_strategy(self.saa)
        sm.add_strategy(self.sab)
        sm.add_strategy(self.sba)
        sab_managed = sm.get_strategy(self.sab.name)

        # the manager hands back an equal object, not the one that was added
        self.assertIsNot(sab_managed, self.sab)
        self.assertEqual(sab_managed, self.sab)
예제 #9
0
    def test_get_strategy_by_name_and_by_strategy(self):
        """Fetching by name or by strategy yields the same managed object."""
        sm = StrategyManager()
        sm.add_strategy(self.saa)
        sm.add_strategy(self.sab)
        sm.add_strategy(self.sba)

        sab_by_name = sm.get_strategy(self.sab.name)
        sab_by_strategy = sm.get_strategy(self.sab)

        self.assertIsNot(sab_by_name, self.sab)
        self.assertIs(sab_by_name, sab_by_strategy)

        # a bool is neither a name nor a Strategy, so it is rejected
        self.assertRaises(ArgumentTypeError, sm.get_strategy, True)
예제 #10
0
    def test_get_strategy_name(self):
        """get_strategy_name reports the name a strategy has or would get."""
        sm = StrategyManager()
        # nothing is managed yet, so saa gets a fully custom name
        name = sm.get_strategy_name(self.saa)
        self.assertEqual(name, msn(pt.CUSTOM_SC, pt.CUSTOM_LC))

        # after adding saa, name should be same as saa
        sm.add_strategy(self.saa)
        name = sm.get_strategy_name(self.saa)
        self.assertEqual(name, self.saa.name)
        # name should be same methods_used part, but custom settings for sab
        name = sm.get_strategy_name(self.sab)
        self.assertEqual(name, msn(self.sab.methods_used_name, pt.CUSTOM_LC))

        # both parts of sba's name are expected to be custom
        name = sm.get_strategy_name(self.sba)
        self.assertEqual(name, msn(pt.CUSTOM_SC, pt.CUSTOM_LC))
예제 #11
0
    def test_get_strategy_by_strategy(self):
        """An equivalent strategy retrieves the managed one, regardless of names.

        Managed strategies are copies of the added strategies, not references.
        """
        sm = StrategyManager()
        sm.add_strategy(self.saa)
        sm.add_strategy(self.sab)
        sm.add_strategy(self.sba)
        sab_managed = sm.get_strategy(self.sab)
        # sabnnn is never added; it is expected to resolve to the managed sab
        sabnnn_managed = sm.get_strategy(self.sabnnn)

        self.assertIsNot(sab_managed, self.sab)
        self.assertEqual(sab_managed, self.sab)

        self.assertIsNot(sabnnn_managed, self.sabnnn)
        self.assertEqual(sabnnn_managed, self.sab)
예제 #12
0
 def test_add_4(self):
     """A strategy whose methods_used already has a name adopts that name."""
     sm = StrategyManager()
     sm.add_strategy(self.saa)
     self.assertEqual(len(sm.strategies.keys()), 1)
     sm.add_strategy(self.sabnnn)
     self.assertEqual(len(sm.strategies.keys()), 2)
     # stored strategy is called A(none) and is same as sab
     smsabnnn = sm.get_strategy(msn('a', 'none'))
     self.assertEqual(smsabnnn, self.sab)
     # because ab is already there, but known as A(none)
     self.assertRaises(DuplicateStrategyError, sm.add_strategy, self.sab)
     self.assertRaises(DuplicateStrategyError, sm.add_strategy, smsabnnn)
     self.assertRaises(DuplicateStrategyError, sm.add_strategy, self.sabnnn)
예제 #13
0
    def test_remove_strategy(self):
        """Strategies can be removed either by strategy or by name."""
        sm = StrategyManager()
        sm.add_strategy(self.saa)
        sm.add_strategy(self.sab)
        sm.add_strategy(self.sba)
        self.assertEqual(len(sm.strategies.keys()), 3)

        # remove by strategy object
        sm.remove_strategy(self.sab)
        self.assertEqual(len(sm.strategies.keys()), 2)
        self.assertRaises(MissingStrategyError, sm.get_strategy, self.sab)

        # re-add, then remove by name
        sm.add_strategy(self.sab)
        self.assertEqual(len(sm.strategies.keys()), 3)
        sm.remove_strategy(self.sab.name)
        self.assertEqual(len(sm.strategies.keys()), 2)
        self.assertRaises(MissingStrategyError, sm.get_strategy, self.sab.name)
예제 #14
0
 def test_get_strategy_by_strategy_2(self):
     """Requesting a strategy that is not under management raises."""
     manager = StrategyManager()
     for managed in (self.saa, self.sba):
         manager.add_strategy(managed)
     # sab was never added to this manager
     self.assertRaises(MissingStrategyError, manager.get_strategy, self.sab)
예제 #15
0
class Session(object):
    def __init__(self, module_suffix=None):
        """Set up a session: user directories, managers and a default strategy.

        module_suffix is accepted but not used in this body.
        """
        path_utils.setup_user_directories(app_name="spikepy")

        self.config_manager = config_manager
        self.trial_manager = TrialManager()
        self.plugin_manager = plugin_manager
        self.strategy_manager = StrategyManager()
        self.strategy_manager.load_all_strategies()
        # Must exist before current_strategy is assigned: the setter reads
        # self.current_strategy (i.e. self._current_strategy) to compare.
        self._current_strategy = None
        self.current_strategy = self.get_default_strategy()
        self.process_manager = ProcessManager(self.trial_manager)

        # register callback for open_files
        self.process_manager.open_files.add_callback(self._files_opened, takes_target_results=True)

    # FILE RELATED
    def export(self, data_interpreter_name, base_path=None, **kwargs):
        """Write the marked trials out with a named data interpreter.

        Inputs:
            data_interpreter_name: key into plugin_manager.data_interpreters.
            base_path: destination handed to write_data_file; defaults to
                    the current working directory.
            kwargs: forwarded unchanged to write_data_file.
        Returns whatever the interpreter's write_data_file returns.
        """
        destination = os.getcwd() if base_path is None else base_path
        interpreters = self.plugin_manager.data_interpreters
        interpreter = interpreters[data_interpreter_name]
        return interpreter.write_data_file(self.marked_trials, destination, **kwargs)

    def load(self, filename):
        """Load <filename> into the session by delegating to open_file."""
        result = self.open_file(filename)
        return result

    def open_file(self, fullpath):
        """Open the single file at <fullpath> through the process manager."""
        manager = self.process_manager
        return manager.open_file(fullpath)

    @supports_callbacks
    def open_files(self, fullpaths):
        """Open every file in <fullpaths> through the process manager.

        Wrapped by supports_callbacks, presumably so listeners can be attached
        the same way __init__ attaches one to process_manager.open_files.
        """
        opened = self.process_manager.open_files(fullpaths)
        return opened

    def save(self, filename, gzipped=True):
        """Save this session's trials and current strategy to a file.

        Inputs:
            filename: Destination path; ".ses" is appended if missing.
            gzipped: If True (default) the file is gzip-compressed.
        Returns:
            The filename actually written (including the ".ses" suffix).
        """
        if not filename.endswith(".ses"):
            filename = "%s.ses" % filename

        trial_dicts = [trial.as_dict for trial in self.trials]
        strategy_dict = self.current_strategy.as_dict
        session_dict = {"trials": trial_dicts, "strategy": strategy_dict}
        if gzipped:
            ofile = gzip.open(filename, "wb")
        else:
            ofile = open(filename, "wb")
        # Close the handle even if pickling fails, so it is never leaked.
        try:
            cPickle.dump(session_dict, ofile, protocol=-1)
        finally:
            ofile.close()
        return filename

    # TRIAL RELATED
    def get_trial(self, name_or_id):
        """Look a trial up by uuid.UUID id, or otherwise by display name."""
        manager = self.trial_manager
        if not isinstance(name_or_id, uuid.UUID):
            return manager.get_trial_with_name(name_or_id)
        return manager.get_trial_with_id(name_or_id)

    def get_trial_with_name(self, name):
        """Return the trial whose display_name is <name>.

        Delegates to trial_manager.get_trial_with_name; per the original
        documentation a RuntimeError is raised when no trial matches
        (raised by the trial manager, not visible here).
        """
        manager = self.trial_manager
        return manager.get_trial_with_name(name)

    def mark_all_trials(self, status=True):
        """Apply mark_trial(<status>) to every trial, best-effort.

        Trials that raise CannotMarkTrialError are skipped silently.
        """
        for each_trial in self.trials:
            try:
                self.mark_trial(each_trial.display_name, status)
            except CannotMarkTrialError:
                # deliberately ignored: keep marking the remaining trials
                continue

    @supports_callbacks
    def mark_trial(self, name_or_id, status=True):
        """Mark the trial identified by <name_or_id> with <status>.

        Returns whatever trial_manager.mark_trial returns.
        """
        display_name = self.get_trial(name_or_id).display_name
        return self.trial_manager.mark_trial(display_name, status=status)

    @property
    def marked_trials(self):
        """The trials the trial manager currently reports as marked."""
        manager = self.trial_manager
        return manager.marked_trials

    def remove_marked_trials(self):
        """Remove all currently marked trials.

        Returns:
            A list holding the result of remove_trial for each trial removed.
        """
        # Snapshot first: marked_trials may be a live view of the trial
        # manager's state, and shrinking it while iterating would skip items.
        to_remove = list(self.marked_trials)
        results = []
        for trial in to_remove:
            # Previously the whole marked_trials collection was passed here
            # instead of the individual trial. display_name is the identifier
            # the other trial helpers in this class use.
            results.append(self.remove_trial(trial.display_name))
        return results

    @supports_callbacks
    def remove_trial(self, name_or_id):
        """Remove the trial identified by <name_or_id> (name or uuid.UUID).

        Returns whatever trial_manager.remove_trial returns.
        """
        return self.trial_manager.remove_trial(self.get_trial(name_or_id))

    @supports_callbacks
    def rename_trial(self, old_name_or_id, proposed_name):
        """Rename the trial found via <old_name_or_id> to <proposed_name>.

        Returns whatever trial_manager.rename_trial returns.
        """
        current_name = self.get_trial(old_name_or_id).display_name
        return self.trial_manager.rename_trial(current_name, proposed_name)

    @property
    def trials(self):
        """Every trial held by the trial manager, marked or not."""
        manager = self.trial_manager
        return manager.trials

    def visualize(self, trial_name, visualization_name, **kwargs):
        """Draw the visualization <visualization_name> for trial <trial_name>.

        The visualization is looked up via plugin_manager.visualizations
        (the original documentation says a name subset is also accepted).
        Extra keyword arguments are forwarded to draw(), whose result is
        returned.
        """
        viz = self.plugin_manager.visualizations[visualization_name]
        target_trial = self.get_trial(trial_name)
        drawn = viz.draw(target_trial, **kwargs)
        return drawn

    # STRATEGY RELATED
    @property
    def current_strategy(self):
        """The Strategy selected for this session (None until one is set)."""
        selected = self._current_strategy
        return selected

    @current_strategy.setter
    def current_strategy(self, strategy_or_name):
        """Set the current strategy with either a name or a Strategy object.

        A Strategy object is renamed in place to the name the strategy
        manager reports for it, and that same object is then selected.
        Any other value is passed to strategy_manager.get_strategy and the
        strategy it returns is selected. Nothing happens when the given
        Strategy object, or the strategy a name resolves to, is already the
        current one.
        """
        if isinstance(strategy_or_name, Strategy):
            if strategy_or_name is not self.current_strategy:
                try:
                    strategy = self.strategy_manager.get_strategy(strategy_or_name)
                except MissingStrategyError:  # it's okay if not under management
                    strategy = strategy_or_name
                # NOTE(review): `strategy` (the managed strategy, if any) is
                # never used below; the caller's object is renamed and
                # selected instead. Confirm which one was intended.
                strategy_or_name.name = self.strategy_manager.get_strategy_name(strategy_or_name)
                self._set_current_strategy(strategy_or_name)
        else:
            strategy = self.strategy_manager.get_strategy(strategy_or_name)
            if strategy is not self.current_strategy:
                self._set_current_strategy(strategy)

    @supports_callbacks
    def _set_current_strategy(self, strategy):
        """Validate <strategy>, store it as the current strategy, return it.

        Wrapped by supports_callbacks, presumably so listeners can react when
        the current strategy changes.
        """
        # Called for its side effect (presumably raising on an invalid
        # strategy); its return value used to be bound to an unused local.
        self.plugin_manager.validate_strategy(strategy)
        self._current_strategy = strategy
        return strategy

    def save_current_strategy(self, strategy_name):
        """Ask the strategy manager to save its current strategy as <strategy_name>.

        NOTE(review): the session's own current_strategy is not passed along;
        confirm that StrategyManager tracks the same strategy.
        """
        manager = self.strategy_manager
        manager.save_current_strategy(strategy_name)

    # RUN RELATED
    def join_run(self):
        """Block until the run thread finishes; no-op if run() never started one."""
        if not hasattr(self, "_run_thread"):
            return
        self._run_thread.join()

    @property
    def is_running(self):
        """True while a thread started by run() is still alive, else False."""
        return hasattr(self, "_run_thread") and self._run_thread.is_alive()

    def run(self, stage_name=None, strategy=None, message_queue=multiprocessing.Queue(), async=False):
        """
            Run the given strategy (defaults to current_strategy), or a stage 
        from that strategy.  Results are placed into the appropriate 
        trial's resources.
        Inputs:
            strategy: A Strategy object.  If not passed, 
                    session.current_strategy will be used.
            stage_name: If passed, only that stage will be run.
            message_queue: If passed, will be populated with run messages.
            async: If True, processing will run in a separate thread.  This 
                    thread can be joined with session.join_run()
        """
        if strategy is None or not isinstance(strategy, Strategy):
            strategy = self.current_strategy

        # if still none, then abort run.
        if strategy is None:
            raise NoCurrentStrategyError("You must supply a strategy or set the session's current strategy.")

        self.process_manager.prepare_to_run_strategy(strategy, stage_name=stage_name)

        self._run_thread = threading.Thread(
            target=self.process_manager.run_tasks, kwargs={"message_queue": message_queue}
        )
        self._run_thread.start()
        if not async:
            self._run_thread.join()