コード例 #1
0
    def runtask(self,
                base_class=experiment.Experiment,
                feats=None,
                cli=False,
                **kwargs):
        '''
        Begin running of task

        Parameters
        ----------
        base_class : class
            Task class to be combined with the requested features.
        feats : iterable of feature classes, optional
            Features to mix into ``base_class``. The iterable is copied, so
            neither the caller's sequence nor a shared default is modified
            when the websocket notification feature is prepended.
        cli : bool
            If False, the websocket is started before the task is launched.
        **kwargs
            Must contain 'params' (a string, a dict, or an object exposing
            ``trait_norm`` and ``params``). 'saveid' and 'seq' are inspected
            if present. Everything except 'params' is forwarded to the
            TaskWrapper.

        Raises
        ------
        ValueError
            If any entry of ``feats`` is None.
        '''
        # Work on a private copy: NotifyFeat may be inserted below, and
        # mutating the argument would leak into the caller (or into a shared
        # default list) and accumulate across calls.
        feats = [] if feats is None else list(feats)

        now = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
        log_str("\n{}\n-------------------\nRunning new task: {}\n".format(
            now, base_class))

        if not cli:
            self.start_websock()

        if None in feats:
            raise ValueError("Features not found properly in database!")

        # initialize task status
        self.status.value = b"running" if 'saveid' in kwargs else b"testing"

        if 'seq' in kwargs:
            kwargs['seq_params'] = kwargs['seq'].params
            # retrieve the database data on this end of the pipe
            kwargs['seq'] = kwargs['seq'].get()

        if self.websock is not None:
            feats.insert(0, websocket.NotifyFeat)
            # NOTE(review): item assignment assumes kwargs['params'] is a
            # mutable mapping at this point; a string would fail here - confirm
            kwargs['params']['websock'] = self.websock
            kwargs['params']['tracker_status'] = self.status

        task_class = experiment.make(base_class, feats=feats)

        # process parameters; 'params' is removed from kwargs so that it is
        # not passed twice to the TaskWrapper below
        params = kwargs.pop('params')
        if isinstance(params, str):
            params = Parameters(params)
        elif isinstance(params, dict):
            params = Parameters.from_dict(params)
        params.trait_norm(task_class.class_traits())
        params = params.params  # dict

        log_str("Spawning process...")
        log_str(str(kwargs))

        # Spawn the process
        self.proc = experiment.task_wrapper.TaskWrapper(
            log_filename=log_filename,
            params=params,
            target_class=task_class,
            websock=self.websock,
            status=self.status,
            **kwargs)

        # create a proxy for interacting with attributes/functions of the task.
        # The task runs in a separate process and we cannot directly access python
        # attributes of objects in other processes
        self.task_proxy, _ = self.proc.start()
コード例 #2
0
    def test_save_hdf(self):
        """Run TestExp with SaveHDF and check the messages and data in the output HDF file."""
        task_cls = experiment.make(TestExp, feats=[SaveHDF])
        task = task_cls()
        task.add_dtype("dummy_feat_for_test", "f8", (1, ))

        # run the task and wait for it to finish
        task.start()
        task.join()

        task.cleanup(mocks.MockDatabase(), "saveHDF_test_output")

        hdf = tables.open_file("saveHDF_test_output.hdf")

        # messages are stored as bytes; decode before comparing with str
        decoded_msgs = []
        for raw_msg in hdf.root.task_msgs[:]["msg"]:
            decoded_msgs.append(raw_msg.decode('utf-8'))
        expected_msgs = ['wait', 'trial', 'reward', 'wait', 'None']
        self.assertEqual(decoded_msgs, expected_msgs)

        msg_times = hdf.root.task_msgs[:]["time"].tolist()
        self.assertEqual(msg_times, [0, 100, 101, 201, 601])

        # the dummy feature is expected to be -1 only in samples 251..299
        self.assertTrue(
            np.all(hdf.root.task[:]["dummy_feat_for_test"][251:300] == -1))
        self.assertTrue(
            np.all(hdf.root.task[:]["dummy_feat_for_test"][:251] == 0))
        self.assertTrue(
            np.all(hdf.root.task[:]["dummy_feat_for_test"][300:] == 0))
コード例 #3
0
def consolerun(base_class='',
               feats=None,
               exp_params=None,
               gen_fn=None,
               gen_params=None):
    """Build an experiment class, instantiate it with a target sequence and run it.

    Parameters
    ----------
    base_class : str or class
        Task class, or the name of a task stored in the database.
    feats : iterable, optional
        Feature classes, or names of features stored in the database. The
        iterable is not modified; names are resolved into a new list.
    exp_params : dict, optional
        Keyword arguments passed to the experiment constructor.
    gen_fn : callable, optional
        Sequence generator; defaults to ``Exp.get_default_seq_generator()``.
    gen_params : dict, optional
        Keyword arguments passed to ``gen_fn``.
    """
    exp_params = {} if exp_params is None else exp_params
    gen_params = {} if gen_params is None else gen_params

    if isinstance(base_class, str):
        # assume that it's the name of a task as stored in the database
        base_class = models.Task.objects.get(name=base_class).get()

    # Resolve feature names into a fresh list so the caller's sequence is left
    # untouched and immutable sequences (e.g. tuples) are accepted as well.
    feats = [
        models.Feature.objects.get(name=feat).get()
        if isinstance(feat, str) else feat
        for feat in (() if feats is None else feats)
    ]

    # Run the pseudo-metaclass constructor
    Exp = experiment.make(base_class, feats=feats)

    # create the sequence of targets
    if gen_fn is None:
        gen_fn = Exp.get_default_seq_generator()
    targ_seq = gen_fn(**gen_params)

    # instantiate the experiment FSM
    exp = Exp(targ_seq, **exp_params)

    # run!
    exp.run_sync()
コード例 #4
0
def ismore_sim_bmi(baseline_data,
                   decoder,
                   targets_matrix=None,
                   session_length=0.):
    """Run a SimBMIControlReplayFile task with HDF saving and return ``save_dec_enc(task)``.

    Parameters
    ----------
    baseline_data
        Passed to the task as ``replay_neural_features``.
    decoder
        Passed to the task as ``decoder``.
    targets_matrix : optional
        Forwarded to the task only when it is not None.
    session_length : float
        Passed to the task as ``session_length``.
    """
    import ismore.invasive.bmi_ismoretasks as bmi_ismoretasks
    from riglib import experiment
    from features.hdf_features import SaveHDF
    # NOTE(review): the three project imports below are not referenced in this
    # function; they are kept in case importing them has side effects - confirm.
    # The unused datetime/numpy/matplotlib/multiprocessing imports were dropped
    # so the simulation no longer requires matplotlib.
    from ismore.brainamp_features import SimBrainAmpData
    from features.blackrock_features import BlackrockBMI
    from ismore.exo_3D_visualization import Exo3DVisualizationInvasive

    targets = bmi_ismoretasks.SimBMIControlReplayFile.sleep_gen(length=100)
    plant_type = 'IsMore'
    kwargs = dict(session_length=session_length,
                  replay_neural_features=baseline_data,
                  decoder=decoder)

    if targets_matrix is not None:
        kwargs['targets_matrix'] = targets_matrix

    Task = experiment.make(bmi_ismoretasks.SimBMIControlReplayFile,
                           [SaveHDF])  #, Exo3DVisualizationInvasive])
    task = Task(targets, plant_type=plant_type, **kwargs)
    task.run_sync()
    return save_dec_enc(task)
コード例 #5
0
    def test_mock_seq_with_features(self):
        """Run MockSequenceWithGenerators with SaveHDF and compare the saved states to a reference."""
        from riglib.experiment.mocks import MockSequenceWithGenerators
        from features.hdf_features import SaveHDF
        import h5py

        task_cls = experiment.make(MockSequenceWithGenerators,
                                   feats=(SaveHDF, ))
        exp = task_cls(MockSequenceWithGenerators.gen_fn1())
        exp.run_sync()

        # The OS may not be ready to let us open the freshly created HDF file,
        # so try up to three times with a one-second pause between attempts.
        max_attempts = 3
        for attempt in range(1, max_attempts + 1):
            try:
                hdf = h5py.File(exp.h5file.name)
            except OSError:
                if attempt >= max_attempts:
                    raise Exception("Unable to open HDF file!")
                time.sleep(1)
            else:
                break

        expected_states = np.array([
            0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 1, 0, 0,
            0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0,
            0, 1, 0, 0, 0, 0, 0, 2, 0, 0, 0
        ])
        saved_states = hdf["/task"]["current_state"].ravel()
        self.assertTrue(np.array_equal(saved_states, expected_states))
コード例 #6
0
def run_task(session_length):
    """Run a SimObsCLDA task with HDF saving and copy its HDF file under $BMI3D/tests/sim_clda.

    Returns the path of the copied, timestamp-named HDF file.
    """
    task_cls = experiment.make(SimObsCLDA, [SaveHDF])
    task_cls.pre_init()

    target_seq = bmimultitasks.BMIResettingObstacles.centerout_2D_discrete_w_obstacle()

    task_params = {
        'assist_level_time': session_length,
        'assist_level': (0.2, 0.),
        'session_length': session_length,
        'half_life': (300., 300.),
        'half_life_time': session_length,
        'timeout_time': 60.,
    }

    task = task_cls(target_seq, plant_type='cursor_14x14', **task_params)
    task.init()
    task.run()

    task.decoder.save()

    timestamp = datetime.datetime.now().strftime("%m%d%y_%H%M")
    pnm = os.path.expandvars('$BMI3D/tests/sim_clda/' + timestamp + '.hdf')

    # Opening and immediately closing the temp file raises early if it cannot be opened
    open(task.h5file.name).close()
    time.sleep(1.)

    # Wait after HDF cleaned up
    task.cleanup_hdf()
    time.sleep(1.)

    # Copy temp file to actual desired location, then check the copy can be opened
    shutil.copy(task.h5file.name, pnm)
    open(pnm).close()
    return pnm
コード例 #7
0
def run_task(session_length):
    """Simulate SimObsCLDA with SaveHDF, then copy the resulting HDF file to a timestamped path.

    The destination is ``$BMI3D/tests/sim_clda/<mmddyy_HHMM>.hdf``; that path is returned.
    """
    sim_cls = experiment.make(SimObsCLDA, [SaveHDF])
    sim_cls.pre_init()

    obstacle_targets = bmimultitasks.BMIResettingObstacles.centerout_2D_discrete_w_obstacle()

    sim = sim_cls(
        obstacle_targets,
        plant_type='cursor_14x14',
        assist_level_time=session_length,
        assist_level=(0.2, 0.),
        session_length=session_length,
        half_life=(300., 300.),
        half_life_time=session_length,
        timeout_time=60.,
    )
    sim.init()
    sim.run()

    sim.decoder.save()

    now = datetime.datetime.now()
    pnm = os.path.expandvars('$BMI3D/tests/sim_clda/' + now.strftime("%m%d%y_%H%M") + '.hdf')

    # Touch the temporary HDF file (open + close) before cleanup
    with open(sim.h5file.name):
        pass
    time.sleep(1.)

    # Wait after HDF cleaned up
    sim.cleanup_hdf()
    time.sleep(1.)

    # Copy temp file to actual desired location and confirm it opens
    shutil.copy(sim.h5file.name, pnm)
    with open(pnm):
        pass
    return pnm
コード例 #8
0
def init_exp(base_class, feats):
    """Create and initialize an experiment built from ``base_class`` and ``feats``.

    The experiment is given a ManualControl center-out sequence generated
    with arguments (1, 3) and is returned after ``init()`` has been called.
    """
    n_blocks, n_targets = 1, 3
    target_seq = ManualControl.centerout_2D_discrete(n_blocks, n_targets)
    exp = experiment.make(base_class, feats=feats)(target_seq)
    exp.init()
    return exp
コード例 #9
0
def main_Y(session_length):
    """Run a SimVFB task (with SaveHDF) on a StateSpaceEndptVelY state space and return the task."""
    state_space = StateSpaceEndptVelY()
    task_cls = experiment.make(SimVFB, [SaveHDF])
    target_seq = SimVFB.centerout_Y_discrete()
    #target_seq = manualcontrolmultitasks.ManualControlMulti.centerout_2D_discrete()
    task_options = {
        'plant_type': 'cursor_14x14',
        'session_length': session_length,
    }
    task = task_cls(state_space, SimVFB.y_sim, target_seq, **task_options)
    task.run_sync()
    return task
コード例 #10
0
    def test_start_experiment_python(self):
        """Create the database records for a test task and launch it through the tracker.

        The test returns early (after printing a notice) when pygame cannot be imported.
        """
        import json
        from built_in_tasks.passivetasks import TargetCaptureVFB2DWindow
        from riglib import experiment
        from features import Autostart
        from db.tracker import json_param

        try:
            import pygame
        except ImportError:
            print("Skipping test due to pygame missing")
            return

        # --- database entries: subject, task, generator, sequence, task entry ---
        subject_rec = models.Subject(name="test_subject")
        subject_rec.save()

        new_task_rec = models.Task(
            name="test_vfb",
            import_path="built_in_tasks.passivetasks.TargetCaptureVFB2DWindow")
        new_task_rec.save()

        models.Generator.populate()
        generator_rec = models.Generator.objects.get(
            name='centerout_2D_discrete')

        sequence_rec = models.Sequence(
            generator=generator_rec,
            params=json.dumps(dict(nblocks=1, ntargets=1)),
            task=new_task_rec)
        sequence_rec.save()
        print(sequence_rec)

        fetched_task_rec = models.Task.objects.get(name='test_vfb')
        entry = models.TaskEntry(task=fetched_task_rec, subject=subject_rec)
        entry.save()

        # the sequence record yields the sequence and its parameters
        _seq, seq_params = sequence_rec.get()

        # --- build the task class and normalize the parameters against its traits ---
        base_class = new_task_rec.get_base_class()
        task_cls = experiment.make(base_class, feats=[])

        window_params = dict(window_size=(480, 240))
        normalized = json_param.Parameters.from_dict(window_params)
        normalized.trait_norm(task_cls.class_traits())

        # --- start the task through the tracker ---
        start_args = {
            'subj': subject_rec.id,
            'base_class': base_class,
            'feats': [Autostart],
            'params': dict(window_size=(480, 240)),
            'seq': sequence_rec,
            'seq_params': seq_params,
            'saveid': entry.id,
        }

        tracker = tasktrack.Track.get_instance()
        tracker.runtask(cli=True, **start_args)
コード例 #11
0
def main_xz_CL_obstacles(session_length, task_kwargs=None):
    """Run a Sim_FA_Obs_BMI task (with SaveHDF) on obstacle center-out targets and return it.

    Parameters
    ----------
    session_length
        Passed to the task as ``session_length``.
    task_kwargs : dict, optional
        Extra keyword arguments for the task constructor. ``None`` (the
        default) means no extra arguments.
    """
    # ``**None`` raises TypeError, so fall back to an empty mapping
    task_kwargs = {} if task_kwargs is None else task_kwargs

    ssm_xz = StateSpaceEndptVel2D()
    Task = experiment.make(Sim_FA_Obs_BMI, [SaveHDF])
    targets = BMIResettingObstacles.centerout_2D_discrete_w_obstacle()
    task = Task(ssm_xz,
                targets,
                plant_type='cursor_14x14',
                session_length=session_length,
                **task_kwargs)
    task.run_sync()
    return task
コード例 #12
0
def main_xz_CL(session_length, task_kwargs=None):
    """Run a Sim_FA_BMI task (with SaveHDF) on 2D center-out targets and return it.

    Parameters
    ----------
    session_length
        Passed to the task as ``session_length``.
    task_kwargs : dict, optional
        Extra keyword arguments for the task constructor. ``None`` (the
        default) means no extra arguments.
    """
    # ``**None`` raises TypeError, so fall back to an empty mapping
    task_kwargs = {} if task_kwargs is None else task_kwargs

    ssm_xz = StateSpaceEndptVel2D()
    Task = experiment.make(Sim_FA_BMI, [SaveHDF])
    targets = manualcontrolmultitasks.ManualControlMulti.centerout_2D_discrete()
    task = Task(ssm_xz,
                targets,
                plant_type='cursor_14x14',
                session_length=session_length,
                **task_kwargs)
    task.run_sync()
    return task
コード例 #13
0
def main_Y(session_length):
    """Run a SimVFB task (with SaveHDF) on a StateSpaceEndptPos1D state space and return the task."""
    state_space = StateSpaceEndptPos1D()
    task_cls = experiment.make(SimVFB, [SaveHDF])
    target_seq = SimVFB.centerout_Y_discrete()
    #target_seq = manualcontrolmultitasks.ManualControlMulti.centerout_2D_discrete()
    positional_args = (state_space, SimVFB.y_sim, target_seq)
    task = task_cls(*positional_args,
                    plant_type='cursor_14x14',
                    session_length=session_length)
    task.run_sync()
    return task
コード例 #14
0
def main_xz_obs(session_length, task_kwargs=None):
    """Run a SimVFB_obs task (with SaveHDF) on obstacle center-out targets and return it.

    Parameters
    ----------
    session_length
        Passed to the task as ``session_length``.
    task_kwargs : dict, optional
        Extra keyword arguments for the task constructor. ``None`` (the
        default) means no extra arguments.
    """
    # ``**None`` raises TypeError, so fall back to an empty mapping
    task_kwargs = {} if task_kwargs is None else task_kwargs

    ssm_xz = StateSpaceEndptVel2D()
    Task = experiment.make(SimVFB_obs, [SaveHDF])
    targets = BMIResettingObstacles.centerout_2D_discrete_w_obstacle()
    # random (20, 7) matrix drawn from N(0, 2), passed as the second positional argument
    C = np.random.normal(0, 2, (20, 7))
    task = Task(ssm_xz,
                C,
                targets,
                plant_type='cursor_14x14',
                session_length=session_length,
                **task_kwargs)

    task.run_sync()
    return task
コード例 #15
0
 def get(self, feats=()):
     """Rebuild the experiment object described by this record.

     Parameters
     ----------
     feats : iterable, optional
         Additional feature classes appended to the features stored with
         this record. Any iterable is accepted (it is converted to a tuple).

     Returns
     -------
     The instantiated experiment, with ``event_log`` set from the JSON in
     ``self.report``.
     """
     from json_param import Parameters
     from riglib import experiment
     # tuple(feats) lets callers pass a list without a tuple+list TypeError
     feature_classes = tuple(f.get() for f in self.feats.all()) + tuple(feats)
     Exp = experiment.make(self.task.get(), feature_classes)
     params = Parameters(self.params)
     params.trait_norm(Exp.class_traits())
     if issubclass(Exp, experiment.Sequence):
         # sequence-based experiments take the generated sequence as first argument
         gen, gp = self.sequence.get()
         seq = gen(Exp, **gp)
         exp = Exp(seq, **params.params)
     else:
         exp = Exp(**params.params)
     exp.event_log = json.loads(self.report)
     return exp
コード例 #16
0
    def test_n_successful_targets(self):
        """Run an 8-target center-out block and check that 8 'reward' messages were saved."""
        task_cls = experiment.make(TargetCaptureVFB2DWindow, feats=[SaveHDF])

        n_targets = 8
        target_seq = TargetCaptureVFB2DWindow.centerout_2D_discrete(
            nblocks=1, ntargets=n_targets)
        task = task_cls(target_seq, window_size=(480, 240))

        task.run_sync()

        time.sleep(1)  # small delay to allow HDF5 file to be written
        hdf = tables.open_file(task.h5file.name)

        saved_msgs = hdf.root.task_msgs[:]['msg']
        n_rewards = np.sum(saved_msgs == b'reward')
        self.assertEqual(n_targets, n_rewards)
コード例 #17
0
def main_xz(session_length, task_kwargs=None):
    """Run a SimVFB task (with SaveHDF) on 2D center-out targets and return it.

    Parameters
    ----------
    session_length
        Passed to the task as ``session_length``.
    task_kwargs : dict, optional
        Extra keyword arguments for the task constructor. ``None`` (the
        default) means no extra arguments.
    """
    # ``**None`` raises TypeError, so fall back to an empty mapping
    task_kwargs = {} if task_kwargs is None else task_kwargs

    ssm_xz = StateSpaceEndptVel2D()
    Task = experiment.make(SimVFB, [SaveHDF])
    targets = manualcontrolmultitasks.ManualControlMulti.centerout_2D_discrete()
    # random (20, 7) matrix drawn from N(0, 2), passed as the second positional argument
    C = np.random.normal(0, 2, (20, 7))
    task = Task(ssm_xz,
                C,
                targets,
                plant_type='cursor_14x14',
                session_length=session_length,
                **task_kwargs)

    task.run_sync()
    return task
コード例 #18
0
    def test_experiment_unfixed(self):
        """Run SimBMICosEncLinDec on a no-center target sequence and check that rewards were earned."""
        for task_base in [SimBMICosEncLinDec]:
            n_targets, n_trials = 8, 16
            target_seq = task_base.sim_target_no_center(n_targets, n_trials)

            exp_cls = experiment.make(task_base, feats=[])
            exp = exp_cls(target_seq)
            exp.init()

            exp.run()

            rewards, time_penalties, hold_penalties = calculate_rewards(exp)
            total_events = rewards + time_penalties + hold_penalties
            self.assertTrue(rewards <= total_events)
            self.assertTrue(rewards > 0)
コード例 #19
0
    def test_experiment(self):
        """Run each simulated BMI task class and check that rewards were earned."""
        for task_base in [SimBMICosEncLinDec, SimBMIVelocityLinDec]:
            n_targets, n_trials = 8, 16
            target_seq = task_base.sim_target_seq_generator_multi(
                n_targets, n_trials)

            exp_cls = experiment.make(task_base, feats=[])
            exp = exp_cls(target_seq)
            exp.init()
            # called after init() and before run(), as in the sibling tests' setup
            exp.decoder.filt.fix_norm_attr()

            exp.run()

            rewards, time_penalties, hold_penalties = calculate_rewards(exp)
            total_events = rewards + time_penalties + hold_penalties
            self.assertTrue(rewards <= total_events)
            self.assertTrue(rewards > 0)
コード例 #20
0
    def test_mock_seq_with_features(self):
        """Run MockSequenceWithGenerators with SaveHDF and check the saved annotation and states."""
        exp_cls = experiment.make(MockSequenceWithGenerators, feats=(SaveHDF,))
        exp = exp_cls(MockSequenceWithGenerators.gen_fn1())
        exp.run_sync()

        # give the HDF file time to be written before opening it
        time.sleep(2)
        hdf = h5py.File(exp.h5file_name)

        # test that the annotation appears in the messages
        annotation_found = b'annotation: test annotation' in hdf["/task_msgs"]["msg"]
        self.assertTrue(annotation_found)

        expected_states = np.array([
            0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 1, 0, 0,
            0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0,
            0, 1, 0, 0, 0, 0, 0, 2, 0, 0, 0,
        ])
        saved_states = hdf["/task"]["current_state"].ravel()

        # TODO: truncating to the common length should presumably not be needed,
        # but the saved vector appears to be short sometimes
        n_common = min(len(expected_states), len(saved_states))
        states_match = np.array_equal(expected_states[:n_common],
                                      saved_states[:n_common])
        self.assertTrue(states_match)
コード例 #21
0
def test_make_params():
    """Return a JSON description of the editable traits of a TargetDirection experiment.

    The experiment class is built from ``manualcontrol.TargetDirection`` with
    the Autostart, MotionSimulate and SaveHDF features. For every editable
    trait the JSON object holds its trait type name ('type'), default value
    ('default') and description ('desc').
    """
    import json

    from riglib import experiment

    from features.generator_features import Autostart
    from features.phasespace_features import MotionSimulate
    from features.hdf_features import SaveHDF

    from tasks import manualcontrol

    # Use the feature classes imported above directly; the previous
    # ``features.X`` lookups relied on a ``features`` name this function never imports.
    Exp = experiment.make(manualcontrol.TargetDirection,
                          (Autostart, MotionSimulate, SaveHDF))
    ctraits = Exp.class_traits()

    jsdesc = {
        trait: {
            'type': ctraits[trait].trait_type.__class__.__name__,
            'default': ctraits[trait].default,
            'desc': ctraits[trait].desc,
        }
        for trait in Exp.class_editable_traits()
    }

    return json.dumps(jsdesc)
コード例 #22
0
    def get(self, feats=()):
        """Return the experiment class for this task combined with ``feats``.

        The module that defines the base task class is reloaded before the
        class is built. Fallbacks:

        * if building the class raises, the problem is printed and
          ``experiment.Experiment`` is returned;
        * if the task is known but some feature resolved to None, the plain
          task class from ``namelist.tasks`` is returned;
        * if the task name is unknown, ``experiment.Experiment`` is returned.

        Single-argument ``print(...)`` calls are used so the output is the
        same under Python 2 and Python 3.
        """
        print("models.Task.get()")
        from namelist import tasks
        if len(tasks) == 0:
            print('Import error in tracker.models.Task.get: from namelist import task returning empty -- likely error in task')

        feature_classes = Feature.getall(feats)

        if self.name in tasks and None not in feature_classes:
            try:
                # reload the module which contains the base task class
                task_cls = tasks[self.name]

                # __import__ returns the top-level package, so walk down the
                # dotted path to reach the defining module (a name without
                # dots simply skips the loop)
                module_names = task_cls.__module__.split('.')
                mod = __import__(module_names[0])
                for submod in module_names[1:]:
                    mod = getattr(mod, submod)

                task_cls_module = imp.reload(mod)
                task_cls = getattr(task_cls_module, task_cls.__name__)

                # run the metaclass constructor
                return experiment.make(task_cls, feature_classes)
            except Exception:
                # Deliberate best-effort fallback, but do not swallow
                # KeyboardInterrupt/SystemExit as a bare ``except:`` would.
                print("Problem making the task class!")
                traceback.print_exc()
                print(self.name)
                print(feats)
                print(feature_classes)
                print("*******")
                return experiment.Experiment
        elif self.name in tasks:
            return tasks[self.name]
        else:
            return experiment.Experiment
コード例 #23
0
def consolerun(base_class='', feats=None, exp_params=None, gen_fn=None, gen_params=None):
    """Build an experiment class, instantiate it with a target sequence and run it.

    Parameters
    ----------
    base_class : str, unicode or class
        Task class, or the name of a task stored in the database.
    feats : iterable, optional
        Feature classes, or names of features stored in the database. The
        iterable is not modified; names are resolved into a new list.
    exp_params : dict, optional
        Keyword arguments passed to the experiment constructor.
    gen_fn : callable, optional
        Sequence generator; defaults to ``Exp.get_default_seq_generator()``.
    gen_params : dict, optional
        Keyword arguments passed to ``gen_fn``.
    """
    exp_params = {} if exp_params is None else exp_params
    gen_params = {} if gen_params is None else gen_params

    if isinstance(base_class, (str, unicode)):
        # assume that it's the name of a task as stored in the database
        base_class = models.Task.objects.get(name=base_class).get()

    # Resolve feature names into a fresh list so the caller's sequence is left
    # untouched and immutable sequences (e.g. tuples) are accepted as well.
    resolved_feats = []
    for feat in (() if feats is None else feats):
        if isinstance(feat, (str, unicode)):
            feat = models.Feature.objects.get(name=feat).get()
        resolved_feats.append(feat)

    # Run the pseudo-metaclass constructor
    Exp = experiment.make(base_class, feats=resolved_feats)

    # create the sequence of targets
    if gen_fn is None:
        gen_fn = Exp.get_default_seq_generator()
    targ_seq = gen_fn(**gen_params)

    # instantiate the experiment FSM
    exp = Exp(targ_seq, **exp_params)

    # run!
    exp.run_sync()
コード例 #24
0
    def test_save_hdf(self):
        """Run TestExp with SaveHDF and check the messages and data in the output HDF file."""
        task_cls = experiment.make(TestExp, feats=[SaveHDF])
        task = task_cls()
        task.add_dtype("dummy_feat_for_test", "f8", (1,))

        # run the task and wait for it to finish
        task.start()
        task.join()

        task.cleanup(mocks.MockDatabase(), "saveHDF_test_output")

        hdf = tables.open_file("saveHDF_test_output.hdf")

        saved_msgs = hdf.root.task_msgs[:]["msg"].tolist()
        expected_msgs = ['wait', 'trial', 'reward', 'wait', 'None']
        self.assertEqual(saved_msgs, expected_msgs)

        saved_times = hdf.root.task_msgs[:]["time"].tolist()
        self.assertEqual(saved_times, [0, 100, 101, 201, 601])

        # the dummy feature is expected to be -1 only in samples 251..299
        self.assertTrue(
            np.all(hdf.root.task[:]["dummy_feat_for_test"][251:300] == -1))
        self.assertTrue(
            np.all(hdf.root.task[:]["dummy_feat_for_test"][:251] == 0))
        self.assertTrue(
            np.all(hdf.root.task[:]["dummy_feat_for_test"][300:] == 0))
コード例 #25
0
def ismore_sim_bmi(baseline_data, decoder, targets_matrix=None, session_length=0.):
    """Run a SimBMIControlReplayFile task with HDF saving and return ``save_dec_enc(task)``.

    Parameters
    ----------
    baseline_data
        Passed to the task as ``replay_neural_features``.
    decoder
        Passed to the task as ``decoder``.
    targets_matrix : optional
        Forwarded to the task only when it is not None.
    session_length : float
        Passed to the task as ``session_length``.
    """
    import ismore.invasive.bmi_ismoretasks as bmi_ismoretasks
    from riglib import experiment
    from features.hdf_features import SaveHDF
    # NOTE(review): the three project imports below are not referenced in this
    # function; they are kept in case importing them has side effects - confirm.
    # The unused datetime/numpy/matplotlib/multiprocessing imports were dropped
    # so the simulation no longer requires matplotlib.
    from ismore.brainamp_features import SimBrainAmpData
    from features.blackrock_features import BlackrockBMI
    from ismore.exo_3D_visualization import Exo3DVisualizationInvasive

    targets = bmi_ismoretasks.SimBMIControlReplayFile.sleep_gen(length=100)
    plant_type = 'IsMore'
    kwargs = dict(session_length=session_length, replay_neural_features=baseline_data, decoder=decoder)

    if targets_matrix is not None:
        kwargs['targets_matrix'] = targets_matrix

    Task = experiment.make(bmi_ismoretasks.SimBMIControlReplayFile, [SaveHDF])  #, Exo3DVisualizationInvasive])
    task = Task(targets, plant_type=plant_type, **kwargs)
    task.run_sync()
    return save_dec_enc(task)
コード例 #26
0
    def setUp(self):
        """Reset the sink manager and start a one-trial TargetCaptureVFB2DWindow task in a TaskWrapper.

        Stores the task proxy, the wrapper and the expected number of trials on ``self``.
        """
        sink.SinkManager.get_instance().reset()

        n_targets, nblocks = 1, 1
        target_seq = TargetCaptureVFB2DWindow.centerout_2D_discrete(
            nblocks=nblocks, ntargets=n_targets)

        window_params = {'window_size': (480, 240)}
        wrapped_cls = experiment.make(TargetCaptureVFB2DWindow, feats=[SaveHDF])

        wrapper = experiment.task_wrapper.TaskWrapper(
            subj=None,
            params=window_params,
            target_class=wrapped_cls,
            seq=target_seq,
            log_filename='tasktrack_log')

        # start() returns a pair; only the first element (the task proxy) is kept
        proxy, _data_proxy = wrapper.start()
        self.task_proxy = proxy
        self.task_wrapper = wrapper
        self.n_trials = n_targets * nblocks
コード例 #27
0
 def test_metaclass_constructor(self):
     """Check that a LogExperiment combined with SaveHDF can be built and instantiated."""
     from features.hdf_features import SaveHDF
     exp_cls = experiment.make(experiment.LogExperiment, feats=(SaveHDF, ))
     # instantiating the generated class must not raise
     exp_cls()
コード例 #28
0
        samples = samples.astype(np.float32)

        # play the audio
        self.audio_stream.write(samples)
        # self.audio_t += self.audio_duration
        super()._cycle()

    def run(self):
        """Run the parent implementation, then shut down the audio output.

        The stream is stopped and closed and the audio backend terminated
        even when the parent ``run()`` raises, so the audio device is not
        left open after a failed run.
        """
        try:
            super().run()
        finally:
            self.audio_stream.stop_stream()
            self.audio_stream.close()
            self.audio_p.terminate()


# Script section: build a TargetCaptureVFB2DWindow task class with the SaveHDF,
# SimLFPSensorFeature and SimLFPOutput features, run it, and load the saved data.
TestFeat = experiment.make(TargetCaptureVFB2DWindow, feats=[SaveHDF, SimLFPSensorFeature, SimLFPOutput])
# TestFeat.fps = 5
# center-out target sequence: 2 blocks of 8 targets
seq = TargetCaptureVFB2DWindow.centerout_2D_discrete(nblocks=2, ntargets=8) 
feat = TestFeat(seq, window_size=(480, 240))

# NOTE(review): run_sync() presumably blocks until the task has finished - confirm
feat.run_sync()

# short pause before opening the task's HDF file
time.sleep(1)
hdf = tables.open_file(feat.h5file.name)
# NOTE(review): the file is renamed while `hdf` is still open and is read from
# afterwards; this relies on the OS allowing an open file to be renamed - confirm
os.rename(feat.h5file.name, "test_vfb_real_time_audio_feedback.hdf")

# task messages are stored as bytes; decode them to str
saved_msgs = [x.decode('utf-8') for x in hdf.root.task_msgs[:]["msg"]]

# 'lfp' and 'ts' columns of the sim_lfp table
lfp = hdf.root.sim_lfp[:]['lfp'][:]
ts = hdf.root.sim_lfp[:]['ts']
コード例 #29
0
from riglib import experiment
import pickle
from features.plexon_features import PlexonBMI
from features.hdf_features import SaveHDF
from riglib.bmi import train

# decoder = pickle.load(open('/Users/preeyakhanna/Dropbox/Carmena_Lab/FA_exp/grom_data/grom20151201_01_RMLC12011916.pkl'))

# Task = experiment.make(factor_analysis_tasks.FactorBMIBase, [CorticalBMI, SaveHDF])
# targets = factor_analysis_tasks.FactorBMIBase.generate_catch_trials()
# kwargs=dict(session_length=20.)
# task = Task(targets, plant_type="cursor_25x14", **kwargs)
# task.decoder = decoder

# import riglib.plexon
# task.sys_module = riglib.plexon

# task.init()
# task.run()

from tasks import choice_fa_tasks
# Load the pickled decoder. The file is opened in binary mode (pickle.load
# needs a bytes stream on Python 3) and closed as soon as loading finishes.
# NOTE(review): unpickling executes code from the file; only load decoder
# files from a trusted location.
with open(
        '/storage/decoders/grom20160201_01_RMLC02011515_w_fa_dict_from_4048.pkl',
        'rb') as decoder_file:
    decoder = pickle.load(decoder_file)

# Build the free-choice task with the PlexonBMI and SaveHDF features and attach
# the loaded decoder to the task instance.
Task = experiment.make(choice_fa_tasks.FreeChoiceFA, [PlexonBMI, SaveHDF])
targets = choice_fa_tasks.FreeChoiceFA.centerout_2D_discrete_w_free_choice()
task = Task(targets, plant_type="cursor_25x14")
task.decoder = decoder
コード例 #30
0
def run_iter_feat_addition(
        total_exp_time=60,
        n_neurons=128,
        fraction_snr=0.25,
        percent_high_SNR_noises=np.arange(0.7, 0.6, -0.2),
        data_dump_folder='/home/sijia-aw/BMi3D_my/operation_funny_chicken/sim_data/neurons_128/run_3/',
        random_seed=0):

    #percent_high_SNR_noises[-1] = 0
    num_noises = len(percent_high_SNR_noises)

    percent_high_SNR_noises_labels = [
        f'{s:.2f}' for s in percent_high_SNR_noises
    ]

    import numpy as np
    np.set_printoptions(precision=2, suppress=True)

    mean_firing_rate_low = 50
    mean_firing_rate_high = 100
    noise_mode = 'fixed_gaussian'
    fixed_noise_level = 5  #Hz

    neuron_types = ['noisy', 'non_noisy']

    n_neurons_no_noise_group = int(n_neurons * fraction_snr)
    n_neurons_noisy_group = n_neurons - n_neurons_no_noise_group

    no_noise_neuron_ind = np.arange(n_neurons_no_noise_group)
    noise_neuron_ind = np.arange(
        n_neurons_no_noise_group,
        n_neurons_noisy_group + n_neurons_no_noise_group)

    neuron_type_indices_in_a_list = [noise_neuron_ind, no_noise_neuron_ind]

    noise_neuron_list = np.full(n_neurons, False, dtype=bool)
    no_noise_neuron_list = np.full(n_neurons, False, dtype=bool)

    noise_neuron_list[noise_neuron_ind] = True
    no_noise_neuron_list[no_noise_neuron_ind] = True

    N_TYPES_OF_NEURONS = 2

    print('We have two types of indices: ')
    for t, l in enumerate(neuron_type_indices_in_a_list):
        print(f'{neuron_types[t]}:{l}')

    # make percent of count into a list
    percent_of_count_in_a_list = list()

    for i in range(num_noises):
        percent_of_count = np.ones(n_neurons)[:, np.newaxis]

        percent_of_count[noise_neuron_ind] = 1
        percent_of_count[no_noise_neuron_ind] = percent_high_SNR_noises[i]

        percent_of_count_in_a_list.append(percent_of_count)

    #for comparision
    #for comparision
    exp_conds_add = [
        f'iter_{s}_{random_seed}_{n_neurons}' for s in percent_high_SNR_noises
    ]
    exp_conds_keep = [
        f'same_{s}_{random_seed}_{n_neurons}' for s in percent_high_SNR_noises
    ]
    exp_conds = [
        f'wo_FS_{s}_{random_seed}_{n_neurons}' for s in percent_high_SNR_noises
    ]

    exp_conds.extend(exp_conds_add)
    exp_conds.extend(exp_conds_keep)
    print(f'we have experimental conditions {exp_conds}')

    # In[7]:

    # CHANGE: game mechanics: generate task params
    N_TARGETS = 8
    N_TRIALS = 2000

    NUM_EXP = len(exp_conds)  # how many experiments we are running.

    # # Config the experiments
    #
    # this section largely copyied and pasted from
    # bmi3d-sijia(branch)-bulti_in_experiemnts
    # https://github.com/sijia66/brain-python-interface/blob/master/built_in_tasks/sim_task_KF.py

    # import libraries
    # make sure these directories are in the python path.,
    from bmimultitasks import SimBMIControlMulti, SimBMICosEncKFDec, BMIControlMultiNoWindow, SimpleTargetCapture, SimpleTargetCaptureWithHold
    from features import SaveHDF
    from features.simulation_features import get_enc_setup, SimKFDecoderRandom, SimIntentionLQRController, SimClockTick
    from features.simulation_features import SimHDF, SimTime

    from riglib import experiment

    from riglib.stereo_opengl.window import FakeWindow
    from riglib.bmi import train

    import weights

    import time
    import copy
    import numpy as np
    import matplotlib.pyplot as plt
    import sympy as sp
    import itertools  #for identical sequences

    np.set_printoptions(precision=2, suppress=True)

    # ##  behaviour and task setup

    # In[10]:

    #seq = SimBMIControlMulti.sim_target_seq_generator_multi(
    #N_TARGETS, N_TRIALS)

    from target_capture_task import ConcreteTargetCapture
    seq = ConcreteTargetCapture.out_2D()

    #create a second version of the tasks
    seqs = itertools.tee(seq, NUM_EXP + 1)
    target_seq = list(seqs[NUM_EXP])

    seqs = seqs[:NUM_EXP]

    SAVE_HDF = True
    SAVE_SIM_HDF = True  #this makes the task data available as exp.task_data_hist
    DEBUG_FEATURE = False

    #base_class = SimBMIControlMulti
    base_class = SimpleTargetCaptureWithHold

    #for adding experimental features such as encoder, decoder
    feats = []
    feats_2 = []
    feats_set = []  # this is a going to be a list of lists

    # ## Additional task setup

    # In[11]:

    from simulation_features import TimeCountDown
    from features.sync_features import HDFSync

    feats.append(HDFSync)
    feats_2.append(HDFSync)

    feats.append(TimeCountDown)
    feats_2.append(TimeCountDown)

    # ## encoder
    #
    # the cosine tuned encoder uses a poisson process, right
    # https://en.wikipedia.org/wiki/Poisson_distribution
    # so if the lambda is 1, then it's very likely

    # In[12]:

    from features.simulation_features import get_enc_setup

    ENCODER_TYPE = 'cosine_tuned_encoder_with_poisson_noise'

    #neuron set up : 'std (20 neurons)' or 'toy (4 neurons)'
    N_NEURONS, N_STATES, sim_C = get_enc_setup(sim_mode='rot_90',
                                               n_neurons=n_neurons)

    #multiply our the neurons
    sim_C[noise_neuron_list] = sim_C[noise_neuron_list] * mean_firing_rate_low
    sim_C[no_noise_neuron_list] = sim_C[
        no_noise_neuron_list] * mean_firing_rate_high

    #set up the encoder
    from features.simulation_features import SimCosineTunedEncWithNoise
    #set up intention feedbackcontroller
    #this ideally set before the encoder
    feats.append(SimIntentionLQRController)
    feats.append(SimCosineTunedEncWithNoise)

    feats_2.append(SimIntentionLQRController)
    feats_2.append(SimCosineTunedEncWithNoise)

    # ## decoder setup

    # In[13]:

    #clda on random
    DECODER_MODE = 'random'  # random

    #take care the decoder setup
    if DECODER_MODE == 'random':
        feats.append(SimKFDecoderRandom)
        feats_2.append(SimKFDecoderRandom)
        print(f'{__name__}: set base class ')
        print(f'{__name__}: selected SimKFDecoderRandom \n')
    else:  #defaul to a cosEnc and a pre-traind KF DEC
        from features.simulation_features import SimKFDecoderSup
        feats.append(SimKFDecoderSup)
        feats_2.append(SimKFDecoderSup)
        print(f'{__name__}: set decoder to SimKFDecoderSup\n')

    # ##  clda: learner and updater

    # In[14]:

    #setting clda parameters
    ##learner: collects paired data at batch_sizes
    RHO = 0.5
    batch_size = 100

    #learner and updater: actualy set up rho
    UPDATER_BATCH_TIME = 1
    UPDATER_HALF_LIFE = np.log(RHO) * UPDATER_BATCH_TIME / np.log(0.5)

    LEARNER_TYPE = 'feedback'  # to dumb or not dumb it is a question 'feedback'
    UPDATER_TYPE = 'smooth_batch'  #none or "smooth_batch"

    #you know what?
    #learner only collects firing rates labeled with estimated estimates
    #we would also need to use the labeled data
    #now, we can set up a dumb/or not-dumb learner
    if LEARNER_TYPE == 'feedback':
        from features.simulation_features import SimFeedbackLearner
        feats.append(SimFeedbackLearner)
        feats_2.append(SimFeedbackLearner)
    else:
        from features.simulation_features import SimDumbLearner
        feats.append(SimDumbLearner)
        feats_2.append(SimDumbLearner)

    #to update the decoder.
    if UPDATER_TYPE == 'smooth_batch':
        from features.simulation_features import SimSmoothBatch
        feats.append(SimSmoothBatch)
        feats_2.append(SimSmoothBatch)
    else:  #defaut to none
        print(f'{__name__}: need to specify an updater')

    # ## feature selector setup

    # In[15]:

    from feature_selection_feature import FeatureTransformer, TransformerBatchToFit
    from feature_selection_feature import FeatureSelector, LassoFeatureSelector, SNRFeatureSelector, IterativeFeatureSelector
    from feature_selection_feature import ReliabilityFeatureSelector

    #pass the real time limit on clock
    feats.append(FeatureSelector)
    feats_2.append(IterativeFeatureSelector)

    feature_x_meth_arg = [
        ('transpose', None),
    ]

    kwargs_feature = dict()
    kwargs_feature = {
        'transform_x_flag': False,
        'transform_y_flag': False,
        'n_starting_feats': n_neurons,
        'n_states': 7,
        "train_high_SNR_time": 60
    }

    print('kwargs will be updated in a later time')
    print(
        f'the feature adaptation project is tracking {kwargs_feature.keys()} ')

    #assistor set up assist level
    assist_level = (0.0, 0.0)

    exp_feats = [feats] * num_noises

    e_f_2 = [feats_2] * num_noises

    e_f_3 = [feats] * num_noises

    exp_feats.extend(e_f_2)
    exp_feats.extend(e_f_3)

    if DEBUG_FEATURE:
        from features.simulation_features import DebugFeature
        feats.append(DebugFeature)

    if SAVE_HDF:
        feats.append(SaveHDF)
        feats_2.append(SaveHDF)
    if SAVE_SIM_HDF:
        feats.append(SimHDF)
        feats_2.append(SimHDF)

    #pass the real time limit on clock
    feats.append(SimClockTick)
    feats.append(SimTime)

    feats_2.append(SimClockTick)
    feats_2.append(SimTime)

    # In[19]:

    kwargs_exps = list()

    for i in range(num_noises):
        d = dict()

        d['total_exp_time'] = total_exp_time

        d['assist_level'] = assist_level
        d['sim_C'] = sim_C

        d['noise_mode'] = noise_mode
        d['percent_noise'] = percent_of_count_in_a_list[i]
        d['fixed_noise_level'] = fixed_noise_level

        d['batch_size'] = batch_size

        d['batch_time'] = UPDATER_BATCH_TIME
        d['half_life'] = UPDATER_HALF_LIFE
        d['no_noise_neuron_ind'] = no_noise_neuron_ind
        d['noise_neuron_ind'] = noise_neuron_ind

        d.update(kwargs_feature)

        kwargs_exps.append(d)

    kwargs_exps_add = copy.deepcopy(kwargs_exps)
    kwargs_exps_start = copy.deepcopy(kwargs_exps)

    for k in kwargs_exps_add:

        k['init_feat_set'] = np.full(N_NEURONS, False, dtype=bool)
        k['init_feat_set'][no_noise_neuron_list] = True

    for k in kwargs_exps_start:

        k['init_feat_set'] = np.full(N_NEURONS, False, dtype=bool)
        k['init_feat_set'][no_noise_neuron_list] = True

    kwargs_exps.extend(kwargs_exps_add)
    kwargs_exps.extend(kwargs_exps_start)

    print(f'we have got {len(kwargs_exps)} exps')
    kwargs_exps[1]['init_feat_set']

    # ## make and initalize experiment instances

    # In[20]:

    #seed the experiment
    np.random.seed(0)

    exps = list()  #create a list of experiment

    for i, s in enumerate(seqs):
        #spawn the task
        f = exp_feats[i]
        Exp = experiment.make(base_class, feats=f)

        e = Exp(s, **kwargs_exps[i])
        exps.append(e)

    exps_np = np.array(exps, dtype='object')

    def get_KF_C_Q_from_decoder(first_decoder, diag_val=10000):
        """
        Get the Kalman-filter matrices C and Q from a decoder instance.

        Args:
            first_decoder: decoder instance exposing ``filt.C`` and ``filt.Q``
                (described by the original author as a riglib.bmi.decoder).
            diag_val: value written onto the main diagonal of the returned Q.
                Defaults to 10000, the value that used to be hard-coded here.
        Returns:
            target_C, target_Q: ``target_C`` is the very same object as
            ``first_decoder.filt.C`` (it is NOT copied); ``target_Q`` is a copy
            of ``first_decoder.filt.Q`` with its diagonal set to ``diag_val``.
        """
        kalman_filter = first_decoder.filt
        target_C = kalman_filter.C

        # Work on a copy so that overwriting the diagonal does not modify the
        # decoder's own Q matrix.
        target_Q = np.copy(kalman_filter.Q)
        np.fill_diagonal(target_Q, diag_val)

        return (target_C, target_Q)

    from feature_selection_feature import run_exp_loop

    WAIT_FOR_HDF_FILE_TO_STOP = 10

    for i, e in enumerate(exps):
        np.random.seed(random_seed)

        e.init()

        # save the decoder if it is the first one.
        if i == 0:
            (target_C, target_Q) = get_KF_C_Q_from_decoder(e.decoder)

            weights.change_target_kalman_filter_with_a_C_mat(e.decoder.filt,
                                                             target_C,
                                                             Q=target_Q,
                                                             debug=False)

        else:  # otherwise, just replace it.
            weights.change_target_kalman_filter_with_a_C_mat(e.decoder.filt,
                                                             target_C,
                                                             Q=target_Q,
                                                             debug=False)

        e.select_decoder_features(e.decoder)
        e.record_feature_active_set(e.decoder)

        #################################################################
        # actual experiment begins
        run_exp_loop(e, **kwargs_exps[i])

        e.hdf.stop()
        print(f'wait for {WAIT_FOR_HDF_FILE_TO_STOP}s for hdf file to save')
        time.sleep(WAIT_FOR_HDF_FILE_TO_STOP)

        e.save_feature_params()

        time.sleep(WAIT_FOR_HDF_FILE_TO_STOP)

        e.cleanup_hdf()

        e.sinks.reset()

        print(f'Finished running  {exp_conds[i]}')

        print()

    import shutil

    import os
    import subprocess

    for i, e in enumerate(exps):

        import subprocess
        old = e.h5file.name
        new = data_dump_folder + exp_conds[i] + '.h5'
        process = "cp {} {}".format(old, new)

        print(process)
        subprocess.run(
            process,
            shell=True)  # do not remember, assign shell value to True.

    import os
    import aopy
    import tables

    exp_data_all = list()
    exp_data_metadata_all = list()

    for i, e in enumerate(exp_conds):
        files = {'hdf': e + '.h5'}

        file_name = os.path.join(data_dump_folder, files['hdf'])

        # write in the exp processing files

        aopy.data.save_hdf(data_dump_folder,
                           file_name,
                           kwargs_exps[i],
                           data_group="/feature_selection",
                           append=True)

        with tables.open_file(file_name, mode='r') as f:
            print(f)

        try:
            d, m = aopy.preproc.parse_bmi3d(data_dump_folder, files)
        except:
            print(f'cannot parse {e}')
コード例 #31
0


import os
import pickle

# Set the DISPLAY environment variable before the task is created
# (presumably so the task window opens on the local X display -- confirm).
os.environ['DISPLAY'] = ':0'

save = True

# Alternative tasks previously used with this script are left commented out.
#task = models.Task.objects.get(name='lfp_mod')
task = models.Task.objects.get(name='lfp_mod_mc_reach_out')
#task = models.Task.objects.get(name='manual_control_multi')

base_class = task.get()

feats = [SaveHDF, Autostart, PlexonBMI, MotionData]
Exp = experiment.make(base_class, feats=feats)

#params.trait_norm(Exp.class_traits())
# Keyword arguments passed to the experiment constructor below.
# (The original was missing the comma after 'mc_plant_type', a SyntaxError.)
params = dict(
    session_length=10,
    plant_visible=True,
    lfp_plant_type='cursor_onedimLFP',
    mc_plant_type='cursor_14x14',
    rand_start=(0., 0.),
    max_tries=1,
)

gen = SimBMIControlMulti.sim_target_seq_generator_multi(8, 1000)
exp = Exp(gen, **params)

# Alternative decoder files are left commented out.
#decoder = pickle.load(open('/storage/decoders/cart20141206_06_test_lfp1d2.pkl'))
#decoder = pickle.load(open('/storage/decoders/cart20141208_12_test_PK.pkl'))
# NOTE: pickle can execute arbitrary code while loading; only load decoder
# files from a trusted location.
# Open in binary mode and close the handle deterministically.
with open('/storage/decoders/cart20141209_08_cart_2015_pilot_2.pkl', 'rb') as decoder_file:
    decoder = pickle.load(decoder_file)
exp.decoder = decoder

exp.init()
コード例 #32
0
 def test_metaclass_constructor(self):
     """Call experiment.make with LogExperiment and SaveHDF, then instantiate the result."""
     feature_classes = (SaveHDF, )
     task_class = experiment.make(experiment.LogExperiment,
                                  feats=feature_classes)
     # Instantiating with no arguments is the whole test: it must not raise.
     task_class()
コード例 #33
0
from ismore.invasive import bmi_ismoretasks
from riglib import experiment
from features.hdf_features import SaveHDF
from features.arduino_features import BlackrockSerialDIORowByte
import numpy as np

# Target sequence produced by SimBMIControl.rehand_simple with length=100.
targets = bmi_ismoretasks.SimBMIControl.rehand_simple(length=100)

# Task class: VisualFeedback combined with the SaveHDF and
# BlackrockSerialDIORowByte features.
Task = experiment.make(
    bmi_ismoretasks.VisualFeedback,
    [SaveHDF, BlackrockSerialDIORowByte],
)

# Keyword arguments forwarded to the task constructor.
kwargs = {
    'session_length': 15.,
    'assist_level': (1., 1.),
    'assist_level_time': 60.,
    'timeout_time': 60.,
}

task = Task(targets, plant_type='ReHand', **kwargs)
task.init()
コード例 #34
0
ファイル: tmp.py プロジェクト: pkhanna104/fa_analysis
from tasks import manualcontrolmultitasks
from riglib.bmi.state_space_models import StateSpaceEndptVel2D
from features.hdf_features import SaveHDF
from tasks import sim_fa_decoding
from riglib import experiment

import os

# State-space model and target sequence handed to the task constructor below.
ssm = StateSpaceEndptVel2D()
targ = manualcontrolmultitasks.ManualControlMulti.centerout_2D_discrete()

# Path of the encoder file; $FA_GROM_DATA is expanded from the environment.
kw = {
    'encoder_name': os.path.expandvars(
        '$FA_GROM_DATA/sims/enc030816_2325vfb_red_spk.pkl'),
}

from sim_neurons import parse_hdf

# Of the three values returned, only dim_red_dict is used further down;
# kf_decoder and hdf remain available as script-level names.
kf_decoder, dim_red_dict, hdf = parse_hdf.train_xz_KF(kw['encoder_name'])
kw['fa_dict'] = dim_red_dict

# Task class: Sim_FA_BMI combined with the SaveHDF feature.
Task = experiment.make(sim_fa_decoding.Sim_FA_BMI, [SaveHDF])
task = Task(ssm, targ, **kw)
コード例 #35
0
    def __init__(self, subj, task_rec, feats, params, seq=None, seq_params=None, saveid=None):
        '''
        Build the task class from the base task and features, instantiate it and start it.

        Parameters
        ----------
        subj : tracker.models.Subject instance
            Database record for subject performing the task
        task_rec : tracker.models.Task instance
            Database record for base task being run (without features)
        feats : list 
            List of features to enable for the task
        params : json_param.Parameters, dict, or string representation of JSON object
            user input on configurable task parameters
        seq : models.Sequence instance, or tuple
            Database record of Sequence parameters/static target sequence
            If passed in as a tuple, then it's the result of calling 'seq.get' on the models.Sequence instance
        seq_params: params from seq (see above)

        saveid : int, optional
            ID number of db.tracker.models.TaskEntry associated with this task
            if None specified, then the data saved will not be linked to the
            database entry and will be lost after the program exits

        Raises
        ------
        TypeError
            If 'params' is not a Parameters instance, a dict or a string
        ValueError
            If 'feats' contains None, or if the task is a Sequence and 'seq'
            is neither a tuple nor a models.Sequence instance
        '''
        self.saveid = saveid
        self.taskname = task_rec.name
        self.subj = subj

        # 'unicode' only exists on Python 2; on Python 3 all text is 'str'.
        # (The original tested against the name 'string', which is not a type.)
        try:
            string_types = (str, unicode)
        except NameError:
            string_types = (str,)

        if isinstance(params, Parameters):
            self.params = params
        elif isinstance(params, string_types):
            self.params = Parameters(params)
        elif isinstance(params, dict):
            self.params = Parameters.from_dict(params)
        else:
            # Fail before any pre_init side effects rather than with an
            # AttributeError on self.params further down.
            raise TypeError(
                "params must be a Parameters instance, a dict or a JSON string, not %s"
                % type(params).__name__)

        base_class = task_rec.get()

        if None in feats:
            raise ValueError("Features not found properly in database!")
        Task = experiment.make(base_class, feats=feats)

        # Run commands which must be executed before the experiment class can be instantiated (e.g., starting neural recording)
        Task.pre_init(saveid=saveid)

        self.params.trait_norm(Task.class_traits())
        if issubclass(Task, experiment.Sequence):
            # retrieve the sequence data from the db, or from the input argument if the input arg was a tuple
            if isinstance(seq, tuple):
                gen_constructor, gen_params = seq
            elif isinstance(seq, models.Sequence):
                gen_constructor, gen_params = seq.get()
                # Typically, 'gen_constructor' is the experiment.generate.runseq function (not an element of namelist.generators)
            else:
                raise ValueError("Unrecognized type for seq")

            gen = gen_constructor(Task, **gen_params)
            self.params.params['seq_params'] = seq_params

            # 'gen' is now a true python generator usable by experiment.Sequence
            self.task = Task(gen, **self.params.params)

            with open(log_filename, 'a') as f:
                f.write("instantiating task with a generator\n")

        else:
            self.task = Task(**self.params.params)
        self.task.start()
コード例 #36
0
                    #print tmp-bound
                elif tmp[i] < 0:
                    tmp[i] = 1
                    #print tmp-bound
            self.current_pt = tmp - bound

            #Update screen
            self._update(np.array(self.current_pt))

            #Update 'last_pt' for next iteration
            self.last_pt = self.current_pt.copy()

        #write to screen
        self.draw_world()

        #Save Cursor position to HDF
        self.task_data['cursor'] = np.array(self.current_pt)

        #Save Target position to HDF (super form )
        self.update_target_location()


if __name__ == "__main__":
    from riglib.experiment import make
    from riglib.experiment.generate import runseq
    from riglib.experiment.features import SimulatedEyeData, MotionData, RewardSystem

    # Target sequence for this stand-alone run.
    seq = rand_target_sequence()

    # Experiment class: ManualControl combined with the three features.
    Exp = make(
        ManualControl,
        (SimulatedEyeData, MotionData, RewardSystem),
    )

    # runseq(Exp, seq) is passed as the first constructor argument.
    exp = Exp(
        runseq(Exp, seq),
        fixations=[(0, 0)],
    )
    exp.run()
コード例 #37
0
from ismore.invasive import bmi_ismoretasks
from riglib import experiment
from features.hdf_features import SaveHDF
from features.arduino_features import BlackrockSerialDIORowByte
import numpy as np

# Target sequence produced by SimBMIControl.rehand_simple with length=100.
targets = bmi_ismoretasks.SimBMIControl.rehand_simple(length=100)

# Task class: VisualFeedback combined with the SaveHDF and
# BlackrockSerialDIORowByte features.
Task = experiment.make(
    bmi_ismoretasks.VisualFeedback,
    [SaveHDF, BlackrockSerialDIORowByte],
)

# Keyword arguments forwarded to the task constructor.
kwargs = {
    'session_length': 15.,
    'assist_level': (1., 1.),
    'assist_level_time': 60.,
    'timeout_time': 60.,
}

task = Task(targets, plant_type='ReHand', **kwargs)
task.init()
コード例 #38
0
    def __init__(self, subj, task_rec, feats, params, seq=None, seq_params=None, saveid=None):
        '''
        Build the task class from the base task and features, instantiate it and start it.

        Parameters
        ----------
        subj : tracker.models.Subject instance
            Database record for subject performing the task
        task_rec : tracker.models.Task instance
            Database record for base task being run (without features)
        feats : list 
            List of features to enable for the task
        params : json_param.Parameters, dict, or string representation of JSON object
            user input on configurable task parameters
        seq : models.Sequence instance, or tuple
            Database record of Sequence parameters/static target sequence
            If passed in as a tuple, then it's the result of calling 'seq.get' on the models.Sequence instance
        seq_params: params from seq (see above)

        saveid : int, optional
            ID number of db.tracker.models.TaskEntry associated with this task
            if None specified, then the data saved will not be linked to the
            database entry and will be lost after the program exits

        Raises
        ------
        TypeError
            If 'params' is not a Parameters instance, a dict or a string
        ValueError
            If 'feats' contains None, or if the task is a Sequence and 'seq'
            is neither a tuple nor a models.Sequence instance
        '''
        self.saveid = saveid
        self.taskname = task_rec.name
        self.subj = subj

        # The original tested against (string, str); 'string' is not a type,
        # so only 'str' is checked here.
        if isinstance(params, Parameters):
            self.params = params
        elif isinstance(params, str):
            self.params = Parameters(params)
        elif isinstance(params, dict):
            self.params = Parameters.from_dict(params)
        else:
            # Fail before any pre_init side effects rather than with an
            # AttributeError on self.params further down.
            raise TypeError(
                "params must be a Parameters instance, a dict or a JSON string, not {}"
                .format(type(params).__name__))

        base_class = task_rec.get()

        if None in feats:
            raise ValueError("Features not found properly in database!")
        Task = experiment.make(base_class, feats=feats)

        # Run commands which must be executed before the experiment class can be instantiated (e.g., starting neural recording)
        Task.pre_init(saveid=saveid)

        self.params.trait_norm(Task.class_traits())
        if issubclass(Task, experiment.Sequence):
            # retrieve the sequence data from the db, or from the input argument if the input arg was a tuple
            if isinstance(seq, tuple):
                gen_constructor, gen_params = seq
            elif isinstance(seq, models.Sequence):
                gen_constructor, gen_params = seq.get()
                # Typically, 'gen_constructor' is the experiment.generate.runseq function (not an element of namelist.generators)
            else:
                raise ValueError("Unrecognized type for seq")

            gen = gen_constructor(Task, **gen_params)
            self.params.params['seq_params'] = seq_params

            # 'gen' is now a true python generator usable by experiment.Sequence
            self.task = Task(gen, **self.params.params)

            with open(log_filename, 'a') as f:
                f.write("instantiating task with a generator\n")

        else:
            self.task = Task(**self.params.params)
        self.task.start()
コード例 #39
0
# Look up the trained decoder record and unpickle the decoder it points to.
decoder_list = models.Decoder.objects.filter(entry=decoder_trained_id)
Decoder = decoder_list[decoder_list_ix]
# NOTE: pickle can execute arbitrary code while loading; only load decoder
# files from a trusted location.
# Open in binary mode and close the handle deterministically.
with open('/storage/decoders/'+Decoder.path, 'rb') as decoder_file:
    decoder = pickle.load(decoder_file)

kw=dict(decoder=decoder)
s = models.Subject.objects.filter(name='Gromit')
t = models.Task.objects.filter(name='rat_bmi')

# Create and save a TaskEntry so that the run has a database id (saveid).
entry = models.TaskEntry(subject_id=s[0].id, task=t[0])
entry.sequence_id = -1

params = Parameters.from_dict(dict(decoder=Decoder.pk, decoder_path=Decoder.path))
entry.params = params.to_json()
entry.save()

saveid = entry.id

Task = experiment.make(rat_bmi_tasks.RatBMI, [Autostart, PlexonBMI, PlexonSerialDIORowByte, RewardSystem, SaveHDF])
# pre_init is called on the class before the task is instantiated.
Task.pre_init(saveid=saveid)
task = Task(plant_type='aud_cursor', session_length=session_length, reward_time=reward_time, **kw)
task.subj=s[0]

task.init()
task.run()

#Cleanup
from db.tracker import dbq
cleanup_successful = task.cleanup(dbq, saveid, subject=task.subj)
task.decoder.save()
task.cleanup_hdf()
task.terminate()
# Bare expression: presumably a notebook cell that displayed kwargs_exps.
kwargs_exps

# ## seed the experiment

# In[7]:

# Seed NumPy's global random number generator.
np.random.seed(0)

# ## make our experiment class

# In[8]:

# Build the experiment class from the base class and the selected features.
Exp = experiment.make(base_class, feats=feats)

# # create experiments with different C batch lengths

# ## create exps

# In[9]:

# One experiment instance per target sequence; kwargs_exps[i] supplies the
# keyword arguments for the i-th sequence.
exps = []

for i, s in enumerate(seqs):
    e = Exp(s, **kwargs_exps[i])
    exps.append(e)
# In[10]:
コード例 #41
0
from tasks import factor_analysis_tasks
from riglib import experiment
import pickle
from features.plexon_features import PlexonBMI
from features.hdf_features import SaveHDF
from riglib.bmi import train

# decoder = pickle.load(open('/Users/preeyakhanna/Dropbox/Carmena_Lab/FA_exp/grom_data/grom20151201_01_RMLC12011916.pkl'))

# Task = experiment.make(factor_analysis_tasks.FactorBMIBase, [CorticalBMI, SaveHDF])
# targets = factor_analysis_tasks.FactorBMIBase.generate_catch_trials()
# kwargs=dict(session_length=20.)
# task = Task(targets, plant_type="cursor_25x14", **kwargs)
# task.decoder = decoder

# import riglib.plexon
# task.sys_module = riglib.plexon

# task.init()
# task.run()


from tasks import choice_fa_tasks

# NOTE: pickle can execute arbitrary code while loading; only load decoder
# files from a trusted location.
# Open in binary mode and close the handle deterministically.
with open('/storage/decoders/grom20160201_01_RMLC02011515_w_fa_dict_from_4048.pkl', 'rb') as decoder_file:
    decoder = pickle.load(decoder_file)

# Task class: FreeChoiceFA combined with the PlexonBMI and SaveHDF features.
Task = experiment.make(choice_fa_tasks.FreeChoiceFA, [PlexonBMI, SaveHDF])
targets = choice_fa_tasks.FreeChoiceFA.centerout_2D_discrete_w_free_choice()
task = Task(targets, plant_type="cursor_25x14")
task.decoder = decoder