Exemple #1
0
def load_model_xform(cellid,
                     batch=271,
                     modelname="ozgf100ch18_wcg18x2_fir15x2_lvl1_dexp1_fit01",
                     eval_model=True,
                     only=None):
    '''
    Load a model that was previously fit via fit_model_xforms

    Parameters
    ----------
    cellid : str
        cellid in celldb database
    batch : int
        batch number in celldb database
    modelname : str
        modelname in celldb database
    eval_model : boolean
        If true, the entire xfspec will be re-evaluated after loading.
    only : int
        Index of single xfspec step to evaluate if eval_model is False.
        For example, only=0 will typically just load the recording.

    Returns
    -------
    xfspec, ctx : nested list, dictionary

    Raises
    ------
    NotImplementedError
        If modelname uses the old format, whose loader (oxf) is not
        supported by this function.
    '''

    kws = escaped_split(modelname, '_')
    # Old format: more than three '_'-separated sections, or exactly three
    # whose middle section starts with 'stategain' but not 'stategain.'.
    old = (len(kws) > 3) or ((len(kws) == 3) and kws[1].startswith('stategain')
                              and not kws[1].startswith('stategain.'))
    if old:
        log.info("Using old modelname format ... ")
        # Fail before touching the results database: this path is unsupported.
        raise NotImplementedError("need to use oxf library.")

    d = nd.get_results_file(batch, [modelname], [cellid])
    filepath = d['modelpath'][0]
    # TODO add BAPHY_API support . Not implemented on nems_baphy yet?

    # hack: hard-coded assumption that server will use this data root;
    # re-root the stored path under the NEMS_RESULTS_DIR setting.
    uri = filepath.replace('/auto/data/nems_db/results',
                           get_setting('NEMS_RESULTS_DIR'))
    xfspec, ctx = xforms.load_analysis(uri,
                                       eval_model=eval_model,
                                       only=only)
    return xfspec, ctx
Exemple #2
0
def load_model_baphy_xform(
        cellid,
        batch=271,
        modelname="ozgf100ch18_wcg18x2_fir15x2_lvl1_dexp1_fit01",
        eval_model=True,
        only=None):
    """
    DEPRECATED. Migrated to xhelp.load_model_xform()

    Load a model that was previously fit via fit_model_xforms_baphy, reading
    it from the 'modelpath' recorded in the results table.

    Parameters
    ----------
    cellid : str
        cellid in celldb database
    batch : int
        batch number in celldb database
    modelname : str
        modelname in celldb database
    eval_model : boolean
        If true, the entire xfspec will be re-evaluated after loading.
    only : int
        Index of single xfspec step to evaluate if eval_model is False.
        For example, only=0 will typically just load the recording.
        Not forwarded for old-format modelnames.

    Returns
    -------
    xfspec, ctx : nested list, dictionary
    """
    sections = nems.utils.escaped_split(modelname, '_')
    n_sections = len(sections)
    # Legacy names: more than three '_'-separated sections, or exactly three
    # whose middle one starts with 'stategain' but not 'stategain.'.
    legacy_stategain = (n_sections == 3
                        and sections[1].startswith('stategain')
                        and not sections[1].startswith('stategain.'))
    use_old_loader = n_sections > 3 or legacy_stategain
    if use_old_loader:
        log.info("Using old modelname format ... ")

    results = nd.get_results_file(batch, [modelname], [cellid])
    model_path = results['modelpath'][0]

    if use_old_loader:
        loaded = oxf.load_analysis(model_path, eval_model=eval_model)
    else:
        loaded = xforms.load_analysis(model_path,
                                      eval_model=eval_model,
                                      only=only)
    xfspec, ctx = loaded
    return xfspec, ctx
Exemple #3
0
def _get_result_paths(batch, cellid_list, modelname):
    """
    Return the saved-model paths for the cells in *cellid_list* fit with *modelname*.

    Parameters
    ----------
    batch : int
        batch number in celldb database
    cellid_list : list of str
        cellids to look up
    modelname : str
        modelname in celldb database

    Returns
    -------
    list of str
        'modelpath' values of the matching result rows, in results-table order.

    Raises
    ------
    ValueError
        If no cell has a fit for *modelname*, or if the number of matching rows
        differs from ``len(cellid_list)`` (some cells unfit, or fit repeatedly).
    """
    results_file = nd.get_results_file(batch)
    matches = results_file.loc[results_file.cellid.isin(cellid_list)
                               & (results_file.modelname == modelname)]
    result_paths = matches['modelpath'].tolist()

    if not result_paths:
        raise ValueError(
            'no cells in cellid_list fit with model {}'.format(modelname))
    if len(result_paths) != len(cellid_list):
        found = set(matches['cellid'])
        missing = [c for c in cellid_list if c not in found]
        raise ValueError(
            'inconsistent cells in cellid_list and loaded paths for model {}: '
            '{} cells requested, {} paths found; missing cells: {}'.format(
                modelname, len(cellid_list), len(result_paths), missing))

    return result_paths
Exemple #4
0
def load_batch_modelpaths(batch, modelnames, cellids=None, eval_model=True):
    """
    Return the saved-model paths for every fit of *modelnames* in *batch*.

    Parameters
    ----------
    batch : int
        batch number in celldb database
    modelnames : str or list of str
        A single modelname or a list of modelnames.
    cellids : list of str, optional
        Passed through to nd.get_results_file (None = no cell filter).
    eval_model : bool
        Unused; kept for backward compatibility with existing callers.

    Returns
    -------
    list of str
        The 'modelpath' column of the matching results.
    """
    # nd.get_results_file expects a flat list of names; wrapping an existing
    # list would produce a nested list.
    if isinstance(modelnames, str):
        modelnames = [modelnames]
    d = nd.get_results_file(batch, list(modelnames), cellids=cellids)
    return d['modelpath'].tolist()
Exemple #5
0
    # Default fit-option string; replaced below by the per-site best alpha.
    a = 'af0:4.as0:4.sc.rb10'
    # Per-site best alpha pair, stored in the CSV as a "(x, y)" string.
    best_alpha = pd.read_csv(
        '/auto/users/hellerc/code/projects/nat_pupil_ms_final/dprime/best_alpha.csv',
        index_col=0)
    # NOTE(review): `[0]` on the row Series is a positional lookup through
    # Series.__getitem__, which recent pandas deprecates; `.iloc[0]` is explicit.
    alpha = best_alpha.loc[site][0]
    # parse "(x, y)" into a tuple of two floats
    alpha = (float(alpha.split(',')[0].replace('(', '')),
             float(alpha.split(',')[1].replace(')', '')))
    # encode the alphas into the af/as options, writing '.' as ':'
    # (presumably because '.' separates options inside a modelname keyword)
    a = 'af{0}.as{1}.sc.rb10'.format(
        str(alpha[0]).replace('.', ':'),
        str(alpha[1]).replace('.', ':'))
    modelname = 'ns.fs4.pup-ld-hrc-apm-pbal-psthfr-ev-residual-addmeta_lv.2xR.f.s-lvlogsig.3xR.ipsth_jk.nf5.p-pupLVbasic.constrLVonly.{}'.format(
        a)

    # first cell in the batch whose cellid contains the site name
    cellid = [c for c in nd.get_batch_cells(batch).cellid if site in c][0]
    mp = nd.get_results_file(batch, [modelname], [cellid]).modelpath[0]

    xfspec, ctx = xforms.load_analysis(mp)

    # validation recording with its mask applied
    r = ctx['val'].apply_mask()
    fs = r['resp'].fs
    # the 'lv_fast' and 'lv_slow' channels of the 'lv' signal, as flat arrays
    fast = r['lv'].extract_channels(['lv_fast'])._data.squeeze()
    slow = r['lv'].extract_channels(['lv_slow'])._data.squeeze()
    # NOTE(review): `pupil` is not used in the visible lines.
    pupil = r['pupil']._data.squeeze()

    # Assuming ss is scipy.signal, periodogram returns (freqs, power).
    # F / Fm / S are accumulators defined outside this view: power spectra
    # and, for the fast LV, the frequency of the spectral peak.
    o = ss.periodogram(fast, fs=fs)
    F.append(o[1].squeeze())
    Fm.append(o[0][np.argmax(o[1].squeeze())])

    o = ss.periodogram(slow, fs=fs)
    S.append(o[1].squeeze())
Exemple #6
0
import joblib as jl
import nems.db as nd
import numpy as np
import pandas as pd

import nems.xforms as xforms
import matplotlib.pyplot as plt
import itertools as itt

batch = 310
results_file = nd.get_results_file(batch)

all_models = results_file.modelname.unique().tolist()
result_paths = results_file.modelpath.tolist()
# model names with every '-' replaced by '_'
mod_modelnames = [ss.replace('-', '_') for ss in all_models]

# short display labels for the four model architectures
models_shortname = {
    'wc.2x2.c-fir.2x15-lvl.1-dexp.1': 'LN',
    'wc.2x2.c-stp.2-fir.2x15-lvl.1-dexp.1': 'STP',
    'wc.2x2.c-fir.2x15-lvl.1-stategain.18-dexp.1': 'pop',
    'wc.2x2.c-stp.2-fir.2x15-lvl.1-stategain.18-dexp.1': 'STP_pop'
}

# use the same batch as above instead of repeating the literal
all_cells = nd.get_batch_cells(batch=batch).cellid.tolist()

goodcell = 'BRT037b-39-1'
best_model = 'wc.2x2.c-stp.2-fir.2x15-lvl.1-stategain.18-dexp.1'

test_path = '/auto/data/nems_db/results/310/BRT037b-39-1/BRT037b-39-1.wc.2x2.c_stp.2_fir.2x15_lvl.1_stategain.18_dexp.1.fit_basic.2018-11-14T093820/'

rerun = False
Exemple #7
0
# name of script that you'd like to run
script_path = '/auto/users/mateo/context_probe_analysis/cluster_script.py'
# Interpreter used to run script_path. It was referenced below but never
# defined (NameError). NOTE(review): None is assumed to make
# nd.enqueue_models fall back to its configured default — confirm, or set
# the path to the intended python executable.
executable_path = None

# parameters that will be passed to script.
force_rerun = True
batch = 310

# define modelspec_name
modelnames = ['wc.2x2.c-fir.2x15-lvl.1-dexp.1',
              'wc.2x2.c-stp.2-fir.2x15-lvl.1-dexp.1',
              'wc.2x2.c-fir.2x15-lvl.1-stategain.S-dexp.1',
              'wc.2x2.c-stp.2-fir.2x15-lvl.1-stategain.S-dexp.1']

# define cellids
batch_cells = set(nd.get_batch_cells(batch=batch).cellid)
full_analysis = nd.get_results_file(batch)
already_analyzed = full_analysis.cellid.unique().tolist()
# batch_cells = ['BRT037b-39-1'] # best cell


# For every model, find the batch cells not yet fit with it and queue them.
for model in modelnames:
    ff_model = full_analysis.modelname == model
    already_fitted_cells = set(full_analysis.loc[ff_model, 'cellid'])

    # sorted so the queue order does not depend on set iteration order
    cells_to_fit = sorted(batch_cells.difference(already_fitted_cells))

    print('model {}, cells to fit:\n{}'.format(model, cells_to_fit))
    if not cells_to_fit:
        # nothing left to fit for this model
        continue

    out = nd.enqueue_models(celllist=cells_to_fit, batch=batch, modellist=[model], user='******', force_rerun=force_rerun,
                            executable_path=executable_path, script_path=script_path)
Exemple #8
0
    def __init__(self, batch, cellids, modelnames, parent=None):
        '''
        Tabbed viewer comparing each model's prediction with the response.

        Saved xforms fits are loaded (and re-evaluated) for every cell/model
        pair. One ComparisonFrame tab is added per cell with at least one fit.
        It shows the validation 'resp' signal followed by each model's 'pred'
        signal. Cell/model pairs with no saved fit are reported and skipped.

        self.contexts is a nested dictionary with the format:
            contexts = {
                    cellid1: {'model1': ctx_a, 'model2': ctx_b},
                    cellid2: {'model1': ctx_c, 'model2': ctx_d},
                    ...
                    }

        Parameters
        ----------
        batch : int
        cellids : list of str
        modelnames : list of str
        parent : QWidget, optional
            Parent widget, passed to the base-class constructor.
        '''
        # Zero-argument super() so the Qt base class's own __init__ runs and
        # receives `parent`. super(qw.QWidget, self) would start the MRO
        # lookup *after* QWidget.
        super().__init__(parent)
        d = nd.get_results_file(batch=batch,
                                cellids=cellids,
                                modelnames=modelnames)
        contexts = {}
        for c in cellids:
            cell_contexts = {}
            for m in modelnames:
                # One combined mask: chained d[mask1][mask2] would reindex
                # mask2 against the already-filtered frame.
                paths = d.loc[(d.cellid == c) & (d.modelname == m),
                              'modelpath'].values
                if len(paths) == 0:
                    print("Coudln't find modelpath for cell: %s model: %s" %
                          (c, m))
                    continue
                _, ctx = xforms.load_analysis(paths[0] + '/',
                                              eval_model=True)
                cell_contexts[m] = ctx
            contexts[c] = cell_contexts
        self.contexts = contexts
        self.batch = batch
        self.cellids = cellids
        self.modelnames = modelnames

        self.time_scroller = TimeScroller(self)

        # NOTE(review): this instance attribute shadows QWidget.layout().
        self.layout = qw.QVBoxLayout()
        self.tabs = qw.QTabWidget()
        self.comparison_tabs = []
        for k, v in self.contexts.items():
            if not v:
                # no fits were loaded for this cell; nothing to compare
                continue
            names = ['Response'] + list(v.keys())
            # the response is taken from the first model's validation set
            resp = next(iter(v.values()))['val']['resp']
            times = np.linspace(0, resp.shape[-1] / resp.fs,
                                resp.shape[-1])
            signals = [resp.as_continuous().T]
            signals.extend(ctx['val']['pred'].as_continuous().T
                           for ctx in v.values())
            tab = ComparisonFrame(signals, names, times, self)
            self.comparison_tabs.append(tab)
            self.tabs.addTab(tab, k)

        self.time_scroller._update_max_time()

        self.layout.addWidget(self.tabs)
        self.layout.addWidget(self.time_scroller)
        self.setLayout(self.layout)
else:
    # No regression. Load any model.
    # (keeps only the first xforms model; the matching `if` is outside this view)
    xforms_models = [xforms_models[0]]

# Load the validation recording from each selected xforms fit.
for xforms_modelname in xforms_models:
    log.info("Load recording from xforms model {}".format(xforms_modelname))
    # Pick the first cell at this site, optionally restricted to channels
    # whose number (2nd '-'-separated field of the cellid) is in chan_nums.
    # Wrapped in a list because get_results_file takes a list of cellids.
    if chan_nums is not None:
        cellid = [[
            c for c in nd.get_batch_cells(batch).cellid
            if (site in c) & (c.split('-')[1] in chan_nums)
        ][0]]
    else:
        cellid = [[c for c in nd.get_batch_cells(batch).cellid
                   if site in c][0]]

    mp = nd.get_results_file(batch, [xforms_modelname], cellid).modelpath[0]
    xfspec, ctx = xforms.load_analysis(mp)
    # apply hrc (and balanced, if exists) and evoked masks from xforms fit
    rec = ctx['val'].apply_mask(reset_epochs=True)

    # only necessary if using correction method 1
    # Drops lv channel 0 when there is more than one channel; the pupil
    # signal is only used as a template for _modified_copy.
    if lv_regress:
        if rec['lv']._data.shape[0] > 1:
            rec['lv'] = rec['pupil']._modified_copy(rec['lv']._data[1:, :])

    # fresh all-True mask on the already-masked recording
    rec = rec.create_mask(True)

    # filtering / pupil regression must always go first!
    if pupil_regressed & lv_regress:

        if regression_method1:
Exemple #10
0
def _site_cellids(site, batch):
    """Return every cellid in *batch* whose name contains *site*."""
    return [c for c in db.get_batch_cells(batch).cellid if site in c]


def _load_decoding_df(batch, modelname, cellid):
    """
    Load the saved TDR decoding results for one (batch, modelname, cellid) fit.

    Returns
    -------
    raw_res : object returned by decoding.DecodingResults().load_results
    raw_df : DataFrame
        Copy of ``raw_res.numeric_results`` restricted to the evoked stimulus
        pairs and to second index level == 2 (presumably the decoding
        dimensionality — confirm).
    """
    modelpath = db.get_results_file(batch=batch,
                                    modelnames=[modelname],
                                    cellids=[cellid]).iloc[0]["modelpath"]
    raw_res = decoding.DecodingResults().load_results(
        os.path.join(modelpath, "decoding_TDR.pickle"))
    raw_df = raw_res.numeric_results
    raw_df = raw_df.loc[pd.IndexSlice[raw_res.evoked_stimulus_pairs,
                                      2], :].copy()
    return raw_res, raw_df


def ddr_pred_site_sim(site,
                      batch=None,
                      modelname_base=None,
                      just_return_mean_dp=False,
                      save_fig=False,
                      skip_plot=False):
    """
    Compare big- vs small-pupil decoding d' between actual data and model
    simulations for one recording site.

    The first modelname from ``parse_modelname_base`` is labelled 'actual'.
    Every later modelname is a prediction compared against it.

    Parameters
    ----------
    site : str
        Site name; the first cellid in the batch containing it is used.
    batch : int, optional
        Defaults to 331, or to 322 if the site has no cells in batch 331.
    modelname_base : str, optional
        Modelname template for ``parse_modelname_base``. The defaults contain
        a '{0}' slot, presumably filled per state — confirm. If None, a
        batch-specific default is used.
    just_return_mean_dp : bool
        If True, skip plotting and return per-model mean d' shares.
    save_fig : bool
        Save the figure as a jpg under /auto/users/svd/projects/pop_state/.
    skip_plot : bool
        Close the figure after drawing (and saving).

    Returns
    -------
    If ``just_return_mean_dp``: ``(labels, dp)``, where ``dp[i]`` is
    ``[mean(sp_dp / (sp_dp + bp_dp)), mean(bp_dp / (sp_dp + bp_dp))]``.
    Otherwise: ``(labels[1:], cc, mse, pupil_range)``. Rows of ``cc``/``mse``
    are the predicted models. Columns are (pooled big+small d', normalized
    delta d', raw delta d'). ``mse`` holds the standard deviation of the
    prediction error, not a mean squared error. ``pupil_range`` comes from
    the last model's results.

    Raises
    ------
    ValueError
        If no cell in ``batch`` matches ``site``.
    """
    if batch is None:
        # default batch 331; fall back to 322 when the site isn't in 331
        batch = 331 if _site_cellids(site, 331) else 322

    site_cells = _site_cellids(site, batch)
    if not site_cells:
        raise ValueError(f"No match for site {site} batch {batch}")
    cellid = site_cells[0]

    if modelname_base is None:
        if batch == 331:
            modelname_base = "psth.fs4.pup-ld-epcpn-hrc-psthfr.z-pca.cc1.no.p-{0}-plgsm.p2-aev-rd" + \
                             "_stategain.2xR.x1,3-spred-lvnorm.4xR.so.x2-inoise.4xR.x3" + \
                             "_tfinit.xx0.n.lr1e4.cont.et4.i50000-lvnoise.r8-aev-ccnorm.t4.f0.ss3"
        else:
            modelname_base = "psth.fs4.pup-ld-hrc-psthfr.z-pca.cc1.no.p-{0}-plgsm.p2-aev-rd" + \
                             "_stategain.2xR.x1,3-spred-lvnorm.4xR.so.x2-inoise.4xR.x3" + \
                             "_tfinit.xx0.n.lr1e4.cont.et4.i50000-lvnoise.r8-aev-ccnorm.md.t5.f0.ss3"
    log.info(f"site {site} modelname_base: {modelname_base}")

    modelnames, states = parse_modelname_base(modelname_base)
    labels = ['actual'] + states

    # NOTE(review): the fit/validation split in raw_res.meta['mask_bins'],
    # when present, is not applied; all evoked stimulus pairs are used.

    if just_return_mean_dp:
        dp = np.zeros((len(modelnames), 2))
        for i, m in enumerate(modelnames):
            _, raw_df = _load_decoding_df(batch, m, cellid)
            total = raw_df["sp_dp"] + raw_df["bp_dp"]
            dp[i, :] = [(raw_df["sp_dp"] / total).mean(),
                        (raw_df["bp_dp"] / total).mean()]
        return labels, dp

    mse = np.zeros((len(modelnames) - 1, 3))
    cc = np.zeros((len(modelnames) - 1, 3))

    f, ax = plt.subplots(4,
                         len(modelnames),
                         figsize=(8, 6),
                         sharex='row',
                         sharey='row')
    for i, m in enumerate(modelnames):
        raw_res, raw_df = _load_decoding_df(batch, m, cellid)

        # row 0: big- vs small-pupil d' per stimulus pair, with a unity line
        # scaled to the first (actual) model's maximum
        if i == 0:
            mmraw0 = raw_df[["sp_dp", "bp_dp"]].values.max()
        ax[0, i].plot([0, mmraw0], [0, mmraw0], 'k--', lw=0.5)
        ax[0, i].scatter(raw_df["sp_dp"], raw_df["bp_dp"], s=3)

        if i == 0:
            # actual data: store reference values that the predictions
            # are compared against
            raw_df.loc[:,
                       'delta_act_raw'] = (raw_df["bp_dp"] - raw_df["sp_dp"])
            raw_df.loc[:,
                       'delta_act'] = (raw_df["bp_dp"] - raw_df["sp_dp"]) / (
                           raw_df["bp_dp"] + raw_df["sp_dp"])
            raw_df.loc[:, 'bp_dp_act'] = raw_df["bp_dp"]
            raw_df.loc[:, 'sp_dp_act'] = raw_df["sp_dp"]
            resp_df = raw_df[[
                'sp_dp_act', 'bp_dp_act', 'delta_act_raw', 'delta_act'
            ]]
            for row in (1, 2, 3):
                ax[row, 0].set_axis_off()
            mmraw = np.max(np.abs(raw_df[['delta_act_raw']].values))
            mmnorm = np.max(np.abs(raw_df[['delta_act']].values))
            ax[0, i].set_title(f"{labels[i]} n={len(raw_df)}")
            continue

        raw_df.loc[:,
                   'delta_pred_raw'] = (raw_df["bp_dp"] - raw_df["sp_dp"])
        raw_df.loc[:,
                   'delta_pred'] = (raw_df["bp_dp"] - raw_df["sp_dp"]) / (
                       raw_df["bp_dp"] + raw_df["sp_dp"])
        # align predicted and actual values on the stimulus-pair index
        raw_df = raw_df.merge(resp_df,
                              how='inner',
                              left_index=True,
                              right_index=True)

        # row 1: predicted vs actual d' (big and small pupil pooled for stats)
        x = np.concatenate((raw_df['bp_dp'], raw_df['sp_dp']))
        y = np.concatenate((raw_df['bp_dp_act'], raw_df['sp_dp_act']))
        cc[i - 1, 0] = np.corrcoef(x, y)[0, 1]
        mse[i - 1, 0] = np.std(x - y)

        ax[1, i].scatter(raw_df['bp_dp'],
                         raw_df['bp_dp_act'],
                         s=3,
                         alpha=0.4)
        ax[1, i].set_title(f"{cc[i-1, 0]:.3f}")

        # row 2: normalized delta d'
        a, b = 'delta_pred', 'delta_act'
        ax[2, i].plot([-mmnorm, mmnorm], [-mmnorm, mmnorm], 'k--', lw=0.5)
        ax[2, i].scatter(raw_df[a], raw_df[b], s=3, alpha=0.4)
        cc[i - 1, 1] = np.corrcoef(raw_df[a], raw_df[b])[0, 1]
        mse[i - 1, 1] = np.std(raw_df[a] - raw_df[b])
        ax[2, i].set_title(f"e={mse[i-1,1]:.1f} cc={cc[i-1,1]:.3f}")

        # row 3: raw delta d'
        a, b = 'delta_pred_raw', 'delta_act_raw'
        ax[3, i].plot([-mmraw, mmraw], [-mmraw, mmraw], 'k--', lw=0.5)
        ax[3, i].scatter(raw_df[a], raw_df[b], s=3, alpha=0.4)
        cc[i - 1, 2] = np.corrcoef(raw_df[a], raw_df[b])[0, 1]
        mse[i - 1, 2] = np.std(raw_df[a] - raw_df[b])
        ax[3, i].set_title(f"e={mse[i-1,2]:.3f} cc={cc[i-1,2]:.3f}")

        ax[0, i].set_title(f"{labels[i]}")

    ax[0, 0].set_ylabel('big pupil dp')
    ax[0, 0].set_xlabel('small pupil dp')
    ax[1, 1].set_ylabel('actual big dp')
    ax[1, 1].set_xlabel('pred big dp')
    # row 2 plots the normalized deltas, row 3 the raw deltas
    ax[2, 1].set_xlabel('pred delta norm')
    ax[2, 1].set_ylabel('act delta norm')
    ax[3, 1].set_xlabel('pred delta raw')
    ax[3, 1].set_ylabel('act delta raw')
    pupil_range = raw_res.pupil_range['range'].mean()
    f.suptitle(f"{site} - {batch} - puprange {pupil_range:.3f}")
    plt.tight_layout()

    if save_fig:
        f.savefig(
            f'/auto/users/svd/projects/pop_state/ddr_pred_{site}_{batch}.jpg')
    if skip_plot:
        plt.close(f)

    return labels[1:], cc, mse, pupil_range
Exemple #11
0
def _keep_only_state_terms(mspec, keep_tag):
    """
    Zero most coefficients of the first module's phi entries, in place.

    For every phi key that does NOT contain *keep_tag*, keep element [0, 0],
    replace the remaining entries along the last axis with zeros, and store
    the result as a single row.
    """
    phi = mspec[0]['phi']
    for k in [k for k in phi.keys() if keep_tag not in k]:
        n = phi[k].shape[-1]
        phi[k] = np.append(phi[k][0, 0], np.zeros(n - 1))[np.newaxis, :]


def generate_state_corrected_psth(batch=None, modelname=None, cellids=None, siteid=None, movement_mask=False,
                                        gain_only=False, dc_only=False, cache_path=None, recache=False):
    """
    Build a Recording whose 'psth' signal is the prediction specified by
    modelname. Designed with stategain models in mind. CRH.

    Saved xforms fits are loaded for each cell in ``cellids``. Cells without a
    loadable fit are skipped with a logged warning; this function does not
    fit missing models.

    Parameters
    ----------
    batch : int
        Required unless a cached recording is returned.
    modelname : str
        Required. Its second '.'-separated field is used in the cache filename.
    cellids : list of str
        Cells whose fits are combined (channel order follows successful loads).
    siteid : str
        Required; used in the cache filename.
    movement_mask : bool
        Unused.
    gain_only : bool
        Re-evaluate each modelspec after zeroing the non-baseline entries of
        phi terms whose name lacks '_g'.
    dc_only : bool
        Same as gain_only but for terms lacking '_d' (ignored if gain_only).
    cache_path : str, optional
        Filename prefix. If given, ``<cache_path><siteid>_<tag>.tgz`` (with
        _gonly/_dconly/_mvm suffixes) is loaded if present, else written.
    recache : bool
        Ignore an existing cache file and rebuild.

    Returns
    -------
    Recording
        Signals psth, psth_sp, resp, pupil, mask, plus pupil_raw, rem and
        pupil_eyespeed when the first prediction has them.

    Raises
    ------
    ValueError
        Missing siteid/batch/modelname/cellids, no fit loaded for any cell,
        or signals that cannot be concatenated across cells.
    """
    from collections import Counter

    if siteid is None:
        raise ValueError("must specify siteid!")
    if cache_path is not None:
        if modelname is None:
            raise ValueError('Must specify batch and modelname!')
        fn = cache_path + siteid + '_{}.tgz'.format(modelname.split('.')[1])
        if gain_only:
            fn = fn.replace('.tgz', '_gonly.tgz')
        elif dc_only:
            # separate cache name so dc-only predictions don't collide with
            # the full-model cache file
            fn = fn.replace('.tgz', '_dconly.tgz')
        if 'mvm' in modelname:
            fn = fn.replace('.tgz', '_mvm.tgz')
        if os.path.isfile(fn) and not recache:
            return Recording.load(fn)

    if batch is None or modelname is None:
        raise ValueError('Must specify batch and modelname!')
    if cellids is None:
        raise ValueError('Must specify cellids!')
    results_table = nd.get_results_file(batch, modelnames=[modelname])
    preds = []
    ms = []
    loaded_cells = []
    for cell in cellids:
        log.info(cell)
        paths = results_table.loc[results_table['cellid'] == cell,
                                  'modelpath'].values
        if len(paths) == 0 or not os.path.isdir(paths[0]):
            log.info("WARNING: fit doesn't exist for cell {0}".format(cell))
            continue
        try:
            xfspec, ctx = xforms.load_analysis(paths[0])
            # fetch both before appending so preds and ms stay aligned
            val, mspec = ctx['val'], ctx['modelspec']
        except Exception:
            # best effort: skip cells whose fit can't be loaded, but keep the cause
            log.info("WARNING: fit doesn't exist for cell {0}".format(cell),
                     exc_info=True)
            continue
        preds.append(val)
        ms.append(mspec)
        loaded_cells.append(cell)

    if not preds:
        raise ValueError(
            'No fits for model {} could be loaded for any of cellids'.format(modelname))

    # Preds may differ in length if cells were recorded in different files
    # (e.g. one existed in a prepassive run and another didn't). Find the FILE
    # epochs present in every prediction.
    file_counts = Counter(ep for pr in preds for ep in pr.epochs.name
                          if ep.startswith('FILE'))
    shared_files = [str(f) for f in sorted(file_counts)
                    if file_counts[f] == len(preds)]

    # mask all file epochs for all preds with the shared file epochs
    # and adjust epochs
    if int(batch) in (294, 307):
        for i, p in enumerate(preds):
            preds[i] = p.and_mask(shared_files)
            preds[i] = preds[i].apply_mask(reset_epochs=True)

    sigs = {}
    for i, p in enumerate(preds):
        if gain_only:
            _keep_only_state_terms(ms[i], '_g')
            pred = ms[i].evaluate(p)['pred']
        elif dc_only:
            _keep_only_state_terms(ms[i], '_d')
            pred = ms[i].evaluate(p)['pred']
        else:
            pred = p['pred']
        if i == 0:
            new_psth_sp = p['psth_sp']
            new_psth = pred
            new_resp = p['resp'].rasterize()
            continue
        try:
            new_psth_sp = new_psth_sp.concatenate_channels([new_psth_sp, p['psth_sp']])
            new_psth = new_psth.concatenate_channels([new_psth, pred])
            new_resp = new_resp.concatenate_channels([new_resp, p['resp'].rasterize()])
        except ValueError as err:
            raise ValueError(
                'Could not concatenate signals of cell {} with the previously '
                'loaded cells'.format(loaded_cells[i])) from err

    # state/mask signals are shared, so take them from the first prediction
    sigs['pupil'] = preds[0]['pupil']

    if 'pupil_raw' in preds[0].signals.keys():
        sigs['pupil_raw'] = preds[0]['pupil_raw']

    if 'mask' in preds[0].signals:
        sigs['mask'] = preds[0]['mask']
    else:
        sigs['mask'] = preds[0].create_mask(True)['mask']

    if 'rem' in preds[0].signals.keys():
        sigs['rem'] = preds[0]['rem']

    if 'pupil_eyespeed' in preds[0].signals.keys():
        sigs['pupil_eyespeed'] = preds[0]['pupil_eyespeed']

    new_psth_sp.name = 'psth_sp'
    new_psth.name = 'psth'
    new_resp.name = 'resp'
    sigs['psth_sp'] = new_psth_sp
    sigs['psth'] = new_psth
    sigs['resp'] = new_resp

    new_rec = Recording(sigs, meta=preds[0].meta)

    # make sure mask is cast to bool
    new_rec['mask'] = new_rec['mask']._modified_copy(new_rec['mask']._data.astype(bool))

    if cache_path is not None:
        log.info('caching {}'.format(fn))
        new_rec.save_targz(fn)

    return new_rec
Exemple #12
0
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jan 23 10:11:40 2018

@author: svd
"""

import nems.plots.api as nplt
import nems.db as nd
import nems.xforms as xforms
from nems.gui.recording_browser import browse_recording, browse_context

cellid = 'TAR010c-18-1'
batch = 271
#modelname = 'wc.18x1.g-fir.1x15-lvl.1'
modelname = 'dlog-wc.18x1.g-fir.1x15-lvl.1'
#modelname = 'dlog-wc.18x1.g-stp.1-fir.1x15-lvl.1-dexp.1'

d = nd.get_results_file(batch=batch, cellids=[cellid], modelnames=[modelname])
if d.empty:
    raise ValueError('No saved fit for cell {} batch {} model {}'.format(
        cellid, batch, modelname))

# first matching fit, by position (independent of the frame's index labels)
filepath = d['modelpath'].iloc[0] + '/'
# load without evaluating, then run the full xforms pipeline explicitly
xfspec, ctx = xforms.load_analysis(filepath, eval_model=False)

ctx, log_xf = xforms.evaluate(xfspec, ctx)

#nplt.quickplot(ctx)
ctx['modelspec'].quickplot(ctx['val'])

# keep a reference to the browser (presumably so it isn't garbage-collected)
aw = browse_context(ctx, signals=['stim', 'pred', 'resp'])
Exemple #13
0
# NOTE(review): not used in the visible lines; the branch below tests `load`,
# which is not defined in this view — confirm which flag is intended.
load_dprime = False

# fast LV with pupil (gain model with sigmoid nonlinearity)
# (overwritten below; kept as a record of an earlier choice)
modelname = 'ns.fs4.pup-ld-st.pup-hrc-apm-pbal-psthfr-ev-addmeta_slogsig.SxR-lv.1xR.f.pred-lvlogsig.2xR_jk.nf5.p-pupLVbasic.constrLVonly.af0:2.sc'

# single module for LV (testing LV modeling architectures)
#modelname = 'ns.fs4.pup-ld-st.pup-hrc-apm-pbal-psthfr-ev-addmeta_puplvmodel.pred.step.dc.R_jk.nf5.p-pupLVbasic.constrLVonly.af0:3.sc.rb10'
#modelname0 = 'ns.fs4.pup-ld-st.pup-hrc-apm-pbal-psthfr-ev-addmeta_puplvmodel.dc.pupOnly.R_jk.nf5.p-pupLVbasic.constrLVonly.af0:0.sc'

# without jackknifing (or cross validation)
# modelname: pupil + LV model; modelname0: pupil-only comparison model
modelname = 'ns.fs4.pup-ld-st.pup-epsig-hrc-apm-pbal-psthfr-ev-addmeta-aev_puplvmodel.pred.step.g.dc.R_pupLVbasic.constrNC.af0:1.sc.rb2'
modelname0 = 'ns.fs4.pup-ld-st.pup-hrc-apm-pbal-psthfr-ev-addmeta-aev_puplvmodel.g.dc.pupOnly.R_pupLVbasic.constrLVonly.af0:0.sc.rb2'
# NOTE(review): xforms_model is not used in the visible lines.
xforms_model = 'ns.fs4.pup-ld-st.pup-hrc-apm-pbal-psthfr-addmeta-aev_puplvmodel.pred.step.pfix.dc.R_pupLVbasic.constrLVonly.af0:0.sc.rb10'
if load:
    # load saved fits for the first batch cell whose id contains `cellid`
    c = [c for c in nd.get_batch_cells(batch).cellid if cellid in c][0]
    mp = nd.get_results_file(batch, [modelname], [c]).modelpath[0]
    _, ctx = xforms.load_analysis(mp)
    mp = nd.get_results_file(batch, [modelname0], [c]).modelpath[0]
    _, ctx2 = xforms.load_analysis(mp)
else:
    # fit both models now without saving the results
    ctx = xfit.fit_xforms_model(batch, cellid, modelname, save_analysis=False)
    ctx2 = xfit.fit_xforms_model(batch,
                                 cellid,
                                 modelname0,
                                 save_analysis=False)

# plot lv, pupil, PC1 timecourses
# choose phi key names by model variant (presumably gain ('.g.') vs dc ('.dc.') terms)
if '.g.' in modelname:
    key1 = 'pg'
    key2 = 'lvg'
if '.dc.' in modelname:
Exemple #14
0
# mns = [
#     'env.fs100-SPOld-stSPO.nb-SPOsev_dlog-stategain.2x2.g.o1.b0d001:5-fir.2x15.z-lvl.1_SDB-init.rb5-basic.t7-SPOpf.Exa',
#     'env.fs100-SPOld-stSPO.nb-SPOsev_dlog-stategain.2x2.g.o1.b0d001:5-fir.2x15.z-lvl.1_SDB-init.rb10-basic.t7-SPOpf.Exa'
#     ]
# mns=['env.fs200-SPOld-stSPO.nb-SPOsev-shuf.st_dlog-stategain.2x2.g.o1.b0d001:5-fir.2x30.z-lvl.1-dexp.1_SDB-init.t5.rb5-basic.t6-SPOpf.Exa',
#      'env.fs200-SPOld-stSPO.nb-SPOsev_dlog-stategain.2x2.g.o1.b0d001:5-fir.2x30.z-lvl.1-dexp.1_SDB-init.t5.rb5-basic.t6-SPOpf.Exa',
#      'env.fs200-SPOld-stSPO.nb-SPOsev-shuf.st_dlog-stategain.2x2.g.o1.b0d001:5-wc.2x2.c-stp.2-fir.2x30.z-lvl.1-dexp.1_SDB-init.t5.rb5-basic.t6-SPOpf.Exa',
#      'env.fs200-SPOld-stSPO.nb-SPOsev_dlog-stategain.2x2.g.o1.b0d001:5-wc.2x2.c-stp.2-fir.2x30.z-lvl.1-dexp.1_SDB-init.t5.rb5-basic.t6-SPOpf.Exa'
#      ]
# mns = [mn.replace('fs200','fs100').replace('2x30','2x15') for mn in mns]
# comparisons=((0,1),(2,3),(0,2))

## Get df of model fits
# NOTE(review): `mns`, `batch`, `imageax` and `imageax2` are not defined in
# this view (mns appears only in the commented-out lists above).
#cells = sp.get_significant_cells(batch,mns,as_list=True) #Sig cells across all models
#cells = sp.get_significant_cells(batch,mns[:1],as_list=True) #Sig cells in the first model
# NOTE(review): the `cells` assigned here is replaced two lines below, so this
# query only serves to populate `dfc`.
dfc = nd.get_results_file(batch,mns[:1]); cells = list(dfc['cellid'].values) # All cells fit in the first model
# NOTE(review): this filter keeps cells WITHOUT 'fre', despite its message.
#cells = [cell for cell in cells if 'fre' not in cell]; print('Keeping only fred cells')
# Cell subset loaded from a local .npy file (allow_pickle=True: trusted file only).
cells = list(np.load('/auto/users/luke/Projects/SPS/NEMS/fre_oldfit_newfit_common_subset.npy', allow_pickle=True))
print(f'{len(cells)} cells')
df = nd.get_results_file(batch,mns,cells)

# keep only cells that were fit with every model in mns
cells_fit_in_all = set(df.loc[df['modelname']==mns[0],'cellid'])
for mn_ in mns[1:]:
    cells_fit_in_all = cells_fit_in_all.intersection(set(df.loc[df['modelname']==mn_,'cellid']))
df = df[df['cellid'].isin(cells_fit_in_all)]
print(f'Dropped not fit, down to {len(df)/len(mns)} cells per model')


## Define fnargs, arguments that will be passed to a function called when you click on a point
fnargs = [{'ax': imageax, 'ft': 5, 'data_series_dict': 'dsx'},
          {'ax': imageax2, 'ft': 5, 'data_series_dict': 'dsy'}]
Exemple #15
0
# Candidate LV models; only the last `modelname` assignment takes effect,
# and the earlier ones are kept as a record of alternatives.
modelname = [
    'ns.fs4.pup-ld-st.pup-hrc-pbal-psthfr-ev_slogsig.SxR-lv.1xR-lvlogsig.2xR_jk.nf5.p-pupLVbasic.a0:009'
]
# NC constraint
modelname = [
    'ns.fs4.pup-ld-st.pup-hrc-apm-psthfr-ev_slogsig.SxR-lv.1xR-lvlogsig.2xR_jk.nf5.p-pupLVbasic.constrNC.a0:35'
]
modelname = [
    'ns.fs4.pup-ld-st.pup-hrc-apm-psthfr-ev-residual_slogsig.SxR-lv.1xR-lvlogsig.2xR_jk.nf2.p-pupLVbasic.constrNC.a0:05'
]
# comparison model without the LV modules
modelname2 = ['ns.fs4.pup-ld-st.pup-hrc-psthfr-ev_slogsig.SxR_jk.nf5.p-basic']
batch = 289
cellids = nd.get_batch_cells(batch).cellid
# first cell at the site (`site` is defined outside this view), as a 1-item list
cellid = [[c for c in cellids if site in c][0]]

mp = nd.get_results_file(batch, modelname, cellid).modelpath[0]
mp2 = nd.get_results_file(batch, modelname2, cellid).modelpath[0]

xfspec, ctx = xforms.load_analysis(mp)
xfspec2, ctx2 = xforms.load_analysis(mp2)

# plot summary of fit
ctx2['modelspec'].quickplot()
ctx['modelspec'].quickplot()

rec = ctx['val'].apply_mask(reset_epochs=True).copy()  # raw recording
# keep only lv channel index 1, as a single-row array
rec['lv'] = rec['lv']._modified_copy(rec['lv']._data[1, :][np.newaxis, :])
rec1 = ctx2['val'].apply_mask(
    reset_epochs=True).copy()  # first order regression
rec12 = rec.copy()  # first / second order regression