def save_new_dec(task_entry_id, dec_obj, suffix):
    """Save a new decoder to the DB, attached to the TE its original decoder came from.

    Parameters
    ----------
    task_entry_id : int or str
        Task entry (or decoder-record identifier) the original decoder belongs to.
    dec_obj : Decoder
        New KF decoder object to save.
    suffix : str
        Suffix appended to the saved decoder's name.
    """
    te = dbfn.TaskEntry(task_entry_id)
    try:
        te_id = te.te_id
    except AttributeError:
        # No te_id attribute: fall back to parsing the TE id embedded in the
        # decoder name, e.g. '..._te1234_...' -> 1234.
        dec_nm = te.name
        match = re.search(r'te[0-9]', dec_nm)
        if match is None:
            raise ValueError("Cannot parse a TE id from decoder name: %s" % dec_nm)
        sub_dec_nm = dec_nm[match.start() + 2:]

        # The id runs until the next underscore (or to the end of the string).
        end = sub_dec_nm.find('_')
        if end == -1:
            end = len(sub_dec_nm)
        te_id = int(sub_dec_nm[:end])

    # NOTE(review): te_id is computed but never used below -- preserved for
    # parity with the original; confirm whether it is needed by callers/DB.
    old_dec_obj = te.decoder_record
    if old_dec_obj is None:
        old_dec_obj = faux_decoder_obj(task_entry_id)
    trainbmi.save_new_decoder_from_existing(dec_obj, old_dec_obj, suffix=suffix)
def save_new_dec(task_entry_id, dec_obj, suffix):
    """Save a decoder to the DB against the TE the original decoder came from.

    NOTE(review): this is a verbatim duplicate of the save_new_dec defined
    earlier in this file; consider removing one copy.
    """
    te = dbfn.TaskEntry(task_entry_id)
    try:
        te_id = te.te_id
    except:
        # Parse the TE id out of the decoder name ('...teNNNN...') instead.
        name = te.name
        start = re.search('te[0-9]', name).start() + 2
        tail = name[start:]
        underscore = tail.find('_')
        te_id = int(tail) if underscore == -1 else int(tail[:underscore])

    record = te.decoder_record
    if record is None:
        record = faux_decoder_obj(task_entry_id)
    trainbmi.save_new_decoder_from_existing(dec_obj, record, suffix=suffix)
# Example #3
def add_fa_dict_to_decoder(decoder_training_te, dec_ix, fa_te, return_dec=False):
    """Attach an FA (factor analysis) dict trained from `fa_te` to an existing decoder.

    Parameters
    ----------
    decoder_training_te : int
        TE id the decoder was trained from; its decoder records are searched.
    dec_ix : int
        Two-digit index parsed from the decoder file name (chars after the first '_').
    fa_te : int
        TE id supplying the data used to train the FA matrices.
    return_dec : bool
        If True, return the modified decoder instead of saving it to the DB.

    Raises
    ------
    Exception
        If no decoder record matches `dec_ix`, or unit counts mismatch between
        the FA TE's spike counts and the decoder.
    """
    # First make sure we're training from the correct task entry:
    # spike counts n_units == BMI units.
    from db import dbfunctions as dbfn

    te = dbfn.TaskEntry(fa_te)
    hdf = te.hdf
    sc_n_units = hdf.root.task[0]["spike_counts"].shape[0]

    from db.tracker import models

    # Find the first decoder record whose 2-digit path index matches dec_ix
    # (same first-match semantics as before, without the flag bookkeeping).
    decoder_old = None
    for rec in models.Decoder.objects.filter(entry=decoder_training_te):
        ix = rec.path.find("_")
        if int(rec.path[ix + 1 : ix + 3]) == dec_ix:
            decoder_old = rec
            break

    if decoder_old is None:
        raise Exception("No decoder from ", str(decoder_training_te), " and matching index: ", str(dec_ix))

    from tasks.factor_analysis_tasks import FactorBMIBase

    FA_dict = FactorBMIBase.generate_FA_matrices(fa_te)

    import pickle

    # Open in binary mode (pickles are binary) and close the handle -- the
    # original pickle.load(open(...)) leaked the file object.
    with open(decoder_old.filename, 'rb') as f:
        dec = pickle.load(f)
    dec.trained_fa_dict = FA_dict

    if dec.n_units != sc_n_units:
        raise Exception("Cant use TE for BMI training and FA training -- n_units mismatch")
    if return_dec:
        return dec
    else:
        from db import trainbmi

        trainbmi.save_new_decoder_from_existing(dec, decoder_old, suffix="_w_fa_dict_from_" + str(fa_te))
# Example #4
def conv_KF_to_splitFA_dec(decoder_training_te, dec_ix, fa_te, search_suffix="w_fa_dict_from_"):
    """Retrain a KF decoder carrying an FA dict into a 'split' (latent + private) decoder.

    Finds the decoder record on `decoder_training_te` whose index matches
    `dec_ix` and whose stored FA dict was trained from `fa_te`, splits the
    demeaned neural features into latent and private components, retrains a
    KF decoder on the stacked features, and saves it with suffix '_split'.

    Parameters
    ----------
    decoder_training_te : int
        TE id whose decoder records are searched.
    dec_ix : int
        Two-digit index parsed from the decoder file name.
    fa_te : int
        TE id the FA dict was trained from; also supplies the retraining data.
    search_suffix : str
        Name tag that precedes the FA TE id in the decoder file name.

    Raises
    ------
    Exception
        If no matching FA-dict decoder record is found, or the chosen decoder
        has no FA dict attached.
    """
    from db import dbfunctions as dbfn

    te = dbfn.TaskEntry(fa_te)
    hdf = te.hdf
    sc_n_units = hdf.root.task[0]["spike_counts"].shape[0]

    from db.tracker import models

    # Search all decoder records for one matching dec_ix whose pickled decoder
    # carries an FA dict trained from fa_te. If several match, the last one
    # wins (preserves the original search behavior, which never broke out).
    decoder_old = None
    for rec in models.Decoder.objects.filter(entry=decoder_training_te):
        ix = rec.path.find("_")
        if int(rec.path[ix + 1 : ix + 3]) != dec_ix:
            continue
        with open(rec.filename, 'rb') as f:
            candidate = pickle.load(f)
        if not hasattr(candidate, "trained_fa_dict"):
            continue
        # Bug fix: the search string was hard-coded to 'w_fa_dict_from_' while
        # the slice offsets used len(search_suffix); honor the parameter in both.
        six = rec.path.find(search_suffix)
        if six != -1:
            # NOTE(review): assumes a 4-digit TE id in the file name -- confirm.
            fa_te_train = rec.path[six + len(search_suffix) : six + len(search_suffix) + 4]
            if int(fa_te_train) == fa_te:
                decoder_old = rec

    # Bug fix: previously this check was commented out, so a failed search
    # crashed below with NameError on decoder_old.
    if decoder_old is None:
        raise Exception('No decoder from ', str(decoder_training_te), ' and matching index: ',
                        str(dec_ix), ' with FA training from: ', str(fa_te))

    print("Using old decoder: %s" % decoder_old.path)

    with open(decoder_old.filename, 'rb') as f:
        decoder = pickle.load(f)
    if hasattr(decoder, "trained_fa_dict"):
        FA_dict = decoder.trained_fa_dict
    else:
        raise Exception("Make an FA dict decoder first, then re-train that")

    te_id = dbfn.TaskEntry(fa_te)

    files = dict(plexon=te_id.plx_filename, hdf=te_id.hdf_filename)
    extractor_cls = decoder.extractor_cls
    extractor_kwargs = decoder.extractor_kwargs
    # Keep zero-rate units so unit counts line up with the existing decoder.
    extractor_kwargs["discard_zero_units"] = False
    kin_extractor = get_plant_pos_vel
    ssm = decoder.ssm
    update_rate = binlen = decoder.binlen
    units = decoder.units
    tslice = (0.0, te_id.length)

    ## get kinematic data
    kin_source = "task"
    tmask, rows = _get_tmask(files, tslice, sys_name=kin_source)
    kin = kin_extractor(files, binlen, tmask, pos_key="cursor", vel_key=None)

    ## get neural features
    neural_features, units, extractor_kwargs = get_neural_features(
        files, binlen, extractor_cls.extract_from_file, extractor_kwargs, tslice=tslice, units=units, source=kin_source
    )

    # Remove the FA mean from the (features x time) matrix.
    T = neural_features.shape[0]
    demean = neural_features.T - np.tile(FA_dict["fa_mu"], [1, T])

    # Latent factors -- presumably U^T (U U^T + Psi)^-1 applied to the
    # demeaned features; verify against the FA_dict convention.
    z = FA_dict["u_svd"].T * FA_dict["uut_psi_inv"] * demean

    shar_z = FA_dict["fa_main_shared"] * demean
    priv = demean - shar_z

    # Stack latent + private features and retrain the KF decoder on them.
    neural_features2 = np.vstack((z, priv))
    decoder_split = train_KFDecoder_abstract(ssm, kin.T, neural_features2, units, update_rate, tslice=tslice)
    decoder_split.n_features = len(units)
    decoder_split.trained_fa_dict = FA_dict

    decoder_split.extractor_cls = extractor_cls
    decoder_split.extractor_kwargs = extractor_kwargs

    from db import trainbmi

    trainbmi.save_new_decoder_from_existing(decoder_split, decoder_old, suffix="_split")