Example #1
0
def reconsitute_rec(batch, cellid_list, modelname):
    '''
    Build a population recording from single-cell model fits.

    Takes a group of single cell recordings (from cells of a population
    recording) including their model predictions, and builds a recording with
    signals containing the responses and predictions of all the cells in the
    population. This makes the recordings compatible with downstream
    dispersion analysis or any analysis working with signals of neuronal
    populations.

    :param batch: int batch number
    :param cellid_list: [str, str ...] list of cell IDs
    :param modelname: str. modelname
    :return: NEMS Recording object. 'resp' and 'pred' hold all loaded cells;
             every other signal is copied from the last loaded cell, except
             'state' and 'state_raw', which are removed.
    :raises ValueError: if no result files are found for the given cells.
    '''
    result_paths = _get_result_paths(batch, cellid_list, modelname)

    # 'resp' is a PointProcess signal whose _data is already a dict keyed by
    # cell, so per-cell dicts are merged with update(). 'pred' is Rasterized
    # (its _data is a matrix), so each matrix is keyed by its cellid for the
    # later stacking step.
    cell_resp_dict = dict()
    cell_pred_dict = dict()
    rec = None

    for filepath in result_paths:
        # Load without full evaluation (only xfspec steps 0-1), then evaluate
        # the modelspecs to obtain 'pred' alongside 'resp'.
        xfspec, ctx = xforms.load_analysis(filepath=filepath,
                                           eval_model=False,
                                           only=slice(0, 2, 1))
        modelspecs = ctx['modelspecs'][0]
        cellid = modelspecs[0]['meta']['cellid']
        rec = ms.evaluate(ctx['rec'].copy(), modelspecs)

        cell_resp_dict.update(rec['resp']._data)
        cell_pred_dict[cellid] = rec['pred']._data

    if rec is None:
        # Without at least one loaded cell there is no template recording.
        raise ValueError('no results found for batch {} model {}'.format(
            batch, modelname))

    # Build population signals from the last cell's signal metadata, updated
    # with the stacked data and the list of contained cells.
    resp_chans = list(cell_resp_dict.keys())
    pop_resp = rec['resp']._modified_copy(data=cell_resp_dict,
                                          chans=resp_chans,
                                          nchans=len(resp_chans))

    pred_chans = list(cell_pred_dict.keys())
    stack_data = np.concatenate(list(cell_pred_dict.values()), axis=0)
    pop_pred = rec['pred']._modified_copy(data=stack_data,
                                          chans=pred_chans,
                                          nchans=len(pred_chans))

    reconstituted_recording = rec.copy()
    reconstituted_recording['resp'] = pop_resp
    reconstituted_recording['pred'] = pop_pred

    # Drop the state signals, tolerating recordings that never had them.
    for name in ('state', 'state_raw'):
        reconstituted_recording.signals.pop(name, None)

    return reconstituted_recording
Example #2
0
def load_model_xform(cellid,
                     batch=271,
                     modelname="ozgf100ch18_wcg18x2_fir15x2_lvl1_dexp1_fit01",
                     eval_model=True,
                     only=None):
    '''
    Load a model that was previously fit via fit_model_xforms

    Parameters
    ----------
    cellid : str
        cellid in celldb database
    batch : int
        batch number in celldb database
    modelname : str
        modelname in celldb database
    eval_model : boolean
        If true, the entire xfspec will be re-evaluated after loading.
    only : int or slice, optional
        Passed through to ``xforms.load_analysis``: index (or slice) of the
        xfspec step(s) to evaluate if eval_model is False.
        For example, only=0 will typically just load the recording.

    Returns
    -------
    xfspec, ctx : nested list, dictionary

    Raises
    ------
    NotImplementedError
        If ``modelname`` uses the old naming format, which would need the
        ``oxf`` loader (not supported by this function).
    '''

    kws = escaped_split(modelname, '_')
    # Old-format modelnames have more than three '_'-separated sections, or
    # three sections whose middle one starts with 'stategain' but not
    # 'stategain.'.
    old = (len(kws) > 3) or ((len(kws) == 3)
                             and kws[1].startswith('stategain')
                             and not kws[1].startswith('stategain.'))
    if old:
        log.info("Using old modelname format ... ")
        # Fail before touching the database: old-format loading is not
        # supported here (load_model_baphy_xform still uses oxf).
        raise NotImplementedError("need to use oxf library.")

    d = nd.get_results_file(batch, [modelname], [cellid])
    filepath = d['modelpath'][0]

    # TODO: add NEMS_BAPHY_API support (not implemented on nems_baphy yet?).
    # HACK: hard-coded assumption that the server stores results under
    # /auto/data/nems_db/results; remap that prefix onto the locally
    # configured NEMS_RESULTS_DIR.
    uri = filepath.replace('/auto/data/nems_db/results',
                           get_setting('NEMS_RESULTS_DIR'))
    xfspec, ctx = xforms.load_analysis(uri,
                                       eval_model=eval_model,
                                       only=only)
    return xfspec, ctx
Example #3
0
def load_model_baphy_xform(
        cellid,
        batch=271,
        modelname="ozgf100ch18_wcg18x2_fir15x2_lvl1_dexp1_fit01",
        eval_model=True,
        only=None):
    '''
    DEPRECATED. Migrated to xhelp.load_model_xform()

    Reload a model saved by fit_model_xforms_baphy.

    Parameters
    ----------
    cellid : str
        cellid in celldb database
    batch : int
        batch number in celldb database
    modelname : str
        modelname in celldb database; old-format names are routed to the
        legacy ``oxf`` loader
    eval_model : boolean
        If true, the entire xfspec will be re-evaluated after loading.
    only : int
        Index of single xfspec step to evaluate if eval_model is False
        (ignored for old-format modelnames).
        For example, only=0 will typically just load the recording.

    Returns
    -------
    xfspec, ctx : nested list, dictionary
    '''
    sections = nems.utils.escaped_split(modelname, '_')

    # Decide which loader applies, based on the shape of the modelname.
    if len(sections) > 3:
        use_legacy_loader = True
    elif len(sections) == 3:
        middle = sections[1]
        use_legacy_loader = (middle.startswith('stategain')
                             and not middle.startswith('stategain.'))
    else:
        use_legacy_loader = False

    if use_legacy_loader:
        log.info("Using old modelname format ... ")

    results = nd.get_results_file(batch, [modelname], [cellid])
    model_path = results['modelpath'][0]

    if not use_legacy_loader:
        xfspec, ctx = xforms.load_analysis(model_path,
                                           eval_model=eval_model,
                                           only=only)
    else:
        xfspec, ctx = oxf.load_analysis(model_path, eval_model=eval_model)
    return xfspec, ctx
Example #4
0
def dstrf_sample(ctx=None, cellid='TAR010c-18-2', savepath=None):
    """
    Compute and plot a dSTRF for one cell of a fitted model.

    Parameters
    ----------
    ctx : dict, optional
        xforms context providing 'val' (recording) and 'modelspec'. If None,
        it is loaded from ``savepath``.
    cellid : str
        Response channel to analyze; must be one of
        ``ctx['val']['resp'].chans``.
    savepath : str, optional
        Saved analysis passed to ``load_analysis`` when ``ctx`` is None.

    Returns
    -------
    f, dstrf
        The figure from ``dstrf_details`` (titled with cellid and r_test)
        and the output of ``compute_dstrf`` for the selected channel.

    Raises
    ------
    ValueError
        If neither ``ctx`` nor ``savepath`` is given, or if ``cellid`` is not
        a channel of the 'resp' signal.
    """
    if ctx is None:
        if savepath is None:
            raise ValueError("dstrf_sample requires either ctx or savepath")
        _, ctx = load_analysis(savepath)

    rec = ctx['val']
    modelspec = ctx['modelspec']
    cellids = list(rec['resp'].chans)
    if cellid not in cellids:
        raise ValueError(
            "cellid {!r} not found in resp channels".format(cellid))
    # index of the first channel equal to cellid
    c = cellids.index(cellid)

    maxbins = 1000
    stepbins = 3  # only reported in the log message below
    memory = 15

    # analyze only the selected cell's output channel
    out_channel = [c]
    channel_count = len(out_channel)

    # Only the number of time bins (capped at maxbins) is used from stim_mag.
    stim_mag = rec['stim'].as_continuous()[:, :maxbins].sum(axis=0)
    index_range = np.arange(0, len(stim_mag), 1)
    log.info(
        'Calculating dstrf for %d channels, %d timepoints (%d steps), memory=%d',
        channel_count, len(index_range), stepbins, memory)
    dstrf = compute_dstrf(modelspec,
                          rec.copy(),
                          out_channel=out_channel,
                          memory=memory,
                          index_range=index_range)

    # Hard-coded sample range / indices handed to dstrf_details; their exact
    # meaning is defined by that helper.
    rr = (150, 550)
    dindex = [52, 54, 56, 65, 133, 165, 215, 220, 233, 240, 250]

    f = dstrf_details(modelspec,
                      rec,
                      cellid,
                      rr,
                      dindex,
                      dstrf=dstrf,
                      dpcs=None,
                      maxbins=maxbins)
    f.suptitle(f"{cellid} r_test={modelspec.meta['r_test'][c][0]:.3f}")

    return f, dstrf
Example #5
0
def kamiak_to_database(cellids,
                       batch,
                       modelnames,
                       source_path,
                       executable_path=None,
                       script_path=None):
    """
    Register model fits copied back from the Kamiak cluster in the database.

    For each (cellid, modelname) pair, the fit stored at
    ``source_path/<batch>/<cellid>/<modelname>`` is loaded, re-saved with
    ``xforms.save_analysis`` and written to the results table via
    ``nd.update_results_table``. The matching tQueue entry (looked up by its
    note ``"<cellid>/<batch>/<modelname>"``) is then reconciled:

    * complete == 2 (dead)        -> marked complete (complete=1, killnow=0)
    * complete == 1, -1 or 0      -> left untouched
    * no entry                    -> inserted as an already-complete job

    Pairs whose fit folder does not exist are logged and skipped entirely.

    Parameters
    ----------
    cellids : list of str
    batch : int or str
    modelnames : list of str
    source_path : str
        Root directory of the copied Kamiak results.
    executable_path, script_path : str, optional
        Only used to build the tQueue command string; default to the
        DEFAULT_EXEC_PATH / DEFAULT_SCRIPT_PATH settings.
    """
    user = '******'
    linux_user = '******'
    allowqueuemaster = 1
    waitid = 0
    parmstring = ''
    rundataid = 0
    priority = 1
    reserve_gb = 0
    codeHash = 'kamiak'

    if executable_path in [None, 'None', 'NONE', '']:
        executable_path = get_setting('DEFAULT_EXEC_PATH')
    if script_path in [None, 'None', 'NONE', '']:
        script_path = get_setting('DEFAULT_SCRIPT_PATH')

    # Parameterized statements: notes and commands embed cellids/modelnames,
    # which must not be spliced into the SQL text (quoting / injection).
    select_sql = "SELECT * FROM tQueue WHERE note=%s"
    update_sql = "UPDATE tQueue SET complete=1, killnow=0 WHERE id=%s"
    insert_sql = ("INSERT INTO tQueue (rundataid,progname,priority,"
                  "reserve_gb,parmstring,allowqueuemaster,user,"
                  "linux_user,note,waitid,codehash,queuedate,complete) "
                  "VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,NOW(),1)")

    engine = nd.Engine()
    for c, m in itertools.product(cellids, modelnames):
        note = '%s/%s/%s' % (c, batch, m)
        commandPrompt = "%s %s %s %s %s" % (executable_path, script_path,
                                            c, batch, m)
        # batch may be an int; os.path.join only accepts str/path-like parts.
        path = os.path.join(source_path, str(batch), c, m)
        if not os.path.exists(path):
            log.warning("missing fit for: \n%s\n%s\n%s\n"
                        "using path: %s\n", batch, c, m, path)
            continue

        xfspec, ctx = xforms.load_analysis(path, eval_model=False)
        preview = ctx['modelspec'].meta.get('figurefile', None)
        if 'log' not in ctx:
            ctx['log'] = 'missing log'
        figures_to_load = ctx.get('figures_to_load', [])
        ctx['figures'] = [xforms.load_resource(f) for f in figures_to_load]
        xforms.save_analysis(None, None, ctx['modelspec'], xfspec,
                             ctx['figures'], ctx['log'])
        nd.update_results_table(ctx['modelspec'], preview=preview)

        conn = engine.connect()
        try:
            r = conn.execute(select_sql, (note,))
            if r.rowcount > 0:
                # Existing job: only a dead job (complete == 2) is changed;
                # complete jobs (1) and running/queued jobs (-1, 0) are kept.
                x = r.fetchone()
                if x['complete'] == 2:
                    conn.execute(update_sql, (x['id'],))
            else:
                # New job, recorded as already complete.
                conn.execute(insert_sql,
                             (rundataid, commandPrompt, priority, reserve_gb,
                              parmstring, allowqueuemaster, user, linux_user,
                              note, waitid, codeHash))
        finally:
            conn.close()
Example #6
0
all_cells = nd.get_batch_cells(batch=310).cellid.tolist()

goodcell = 'BRT037b-39-1'
best_model = 'wc.2x2.c-stp.2-fir.2x15-lvl.1-stategain.18-dexp.1'

test_path = '/auto/data/nems_db/results/310/BRT037b-39-1/BRT037b-39-1.wc.2x2.c_stp.2_fir.2x15_lvl.1_stategain.18_dexp.1.fit_basic.2018-11-14T093820/'

rerun = False

# pickled DataFrame of per-fit metrics: written when rerun, read otherwise
summary_metrics_path = '/home/mateo/code/context_probe_analysis/pickles/summary_metrics'

# compare goodness of fit between models
#   iteratively go through the result files
if rerun:
    # NOTE(review): result_paths is not defined in this snippet; it must be
    # provided by earlier code (a list of saved xforms analysis paths).
    subset_keys = ['cellid', 'modelname', 'r_test', 'r_fit']
    population_metas = list()
    for filepath in result_paths:
        _, ctx = xforms.load_analysis(filepath=filepath,
                                      eval_model=True,
                                      only=None)
        meta = ctx['modelspecs'][0][0]['meta']
        #   extract important values into a dictionary
        d = {k: v for k, v in meta.items() if k in subset_keys}
        d['r_test'] = d['r_test'][0]
        d['r_fit'] = d['r_fit'][0]
        population_metas.append(d)
    meta_DF = pd.DataFrame(population_metas)
    jl.dump(meta_DF, summary_metrics_path)
else:
    meta_DF = jl.load(summary_metrics_path)
Example #7
0
 def load_xfrom_from_folder(self, directory, eval_model=True):
     """Load a saved xforms analysis from ``directory``.

     :param directory: folder of the saved analysis; converted with str()
         before being passed on
     :param eval_model: forwarded unchanged to xforms.load_analysis
     :return: the (xfspec, ctx) pair produced by xforms.load_analysis
     """
     folder = str(directory)
     loaded_xfspec, loaded_ctx = xforms.load_analysis(
         folder, eval_model=eval_model)
     return loaded_xfspec, loaded_ctx
Example #8
0
    def __init__(self, batch, cellids, modelnames, parent=None):
        '''
        Load the fits for every (cellid, modelname) pair of ``batch`` and
        build one comparison tab per cell.

        The loaded contexts are stored in ``self.contexts`` as a nested
        dictionary with the format:
            contexts = {
                    cellid1: {'model1': ctx_a, 'model2': ctx_b},
                    cellid2: {'model1': ctx_c, 'model2': ctx_d},
                    ...
                    }
        Pairs without a modelpath in the results table are reported and
        skipped; cells with no loaded model get no tab.

        :param batch: batch number used for the results-table lookup
        :param cellids: list of cell IDs
        :param modelnames: list of model names to compare
        :param parent: currently unused (not forwarded to the Qt base class)
        '''
        # NOTE(review): super(qw.QWidget, self) starts the MRO search *after*
        # qw.QWidget, and `parent` is not forwarded. Left unchanged because
        # the enclosing class header is not visible here -- confirm intent.
        super(qw.QWidget, self).__init__()
        d = nd.get_results_file(batch=batch,
                                cellids=cellids,
                                modelnames=modelnames)
        contexts = {}
        for c in cellids:
            cell_contexts = {}
            for m in modelnames:
                # One combined mask: chaining d[mask1][mask2] reindexes mask2
                # against the already-filtered frame (pandas warns about it).
                paths = d.loc[(d.cellid == c) & (d.modelname == m),
                              'modelpath'].values
                if len(paths) == 0:
                    print("Coudln't find modelpath for cell: %s model: %s" %
                          (c, m))
                    continue
                xfspec, ctx = xforms.load_analysis(paths[0] + '/',
                                                   eval_model=True)
                cell_contexts[m] = ctx
            contexts[c] = cell_contexts
        self.contexts = contexts
        self.batch = batch
        self.cellids = cellids
        self.modelnames = modelnames

        self.time_scroller = TimeScroller(self)

        # NOTE: this attribute shadows QWidget.layout(); kept for callers.
        self.layout = qw.QVBoxLayout()
        self.tabs = qw.QTabWidget()
        self.comparison_tabs = []
        for cellid, cell_contexts in self.contexts.items():
            if not cell_contexts:
                # no model could be loaded for this cell: no tab
                continue
            names = ['Response'] + list(cell_contexts.keys())
            # The response trace comes from the first loaded model's 'val'.
            resp = next(iter(cell_contexts.values()))['val']['resp']
            times = np.linspace(0, resp.shape[-1] / resp.fs, resp.shape[-1])
            signals = [resp.as_continuous().T]
            signals += [model_ctx['val']['pred'].as_continuous().T
                        for model_ctx in cell_contexts.values()]
            tab = ComparisonFrame(signals, names, times, self)
            self.comparison_tabs.append(tab)
            self.tabs.addTab(tab, cellid)

        self.time_scroller._update_max_time()

        self.layout.addWidget(self.tabs)
        self.layout.addWidget(self.time_scroller)
        self.setLayout(self.layout)
Example #9
0
# Choose the cell to fit; alternatives are kept commented out.
#batch, cellid = 269, 'chn019a-a1'
#batch, cellid = 269, 'oys042c-d1'
#batch, cellid = 273, 'chn041d-b1'
#batch, cellid = 273, 'zee027b-c1'
batch, cellid = 269, 'btn144a-a1'
# Model keywords (middle section of the modelname); alternatives commented out.
#modelspec = 'RDTwcg18x2-RDTfir2x15_RDTstreamgain_lvl1_dexp1'
#keywordstring = 'dlog-wc.18x1.g-fir.1x15-lvl.1'
#keywordstring = 'rdtwc.18x1.g-rdtfir.1x15-rdtgain.relative.NTARGETS-lvl.1'
keywordstring = 'rdtgain.gen.NTARGETS-rdtmerge.stim-wc.18x1.g-fir.1x15-lvl.1'

# Full modelname: '_'-separated loader / modelspec keywords / fit sections
# (presumably RDT-specific loading+preprocessing, then an 'init-basic' fit).
modelname = 'rdtld-rdtshf.rep-rdtsev-rdtfmt_' + keywordstring + '_init-basic'

# Fit without writing to the results DB, then reload from the returned savefile.
savefile = nw.fit_model_xforms_baphy(cellid, batch, modelname, saveInDB=False)
#xf,ctx = nw.load_model_baphy_xform(cellid,batch,modelname)
xf, ctx = xforms.load_analysis(savefile)
# browse_context(ctx, 'val', signals=['stim', 'resp', 'fg_sf', 'bg_sf', 'state'])
"""
# database-free version
recording_uri = '/Users/svd/python/nems/recordings/chn019a_e3a6a2e25b582125a7a6ee98d8f8461557ae0cf7.tgz'
#recording_uri = '/Users/svd/python/nems/recordings/chn019a_16e888cad7fef05b2f51c58874bd07040ae80903.tgz'
shuff_streams=False
shuff_rep=False
xfspec = [
    ('nems.xforms.init_context', {'batch': batch, 'cellid': cellid, 'keywordstring': keywordstring,
                                  'recording_uri': recording_uri}),
    ('nems_lbhb.rdt.io.load_recording', {}),
    ('nems_lbhb.rdt.preprocessing.rdt_shuffle', {'shuff_streams': shuff_streams, 'shuff_rep': shuff_rep}),
    ('nems_lbhb.rdt.preprocessing.split_est_val', {}),
    ('nems_lbhb.rdt.xforms.format_keywordstring', {}),
    ('nems.xforms.init_from_keywords', {}),
Example #10
0
        self.plot_container[id(layer_area)] = layer_area
        self.add_collapsible_dock(layer_area,
                                  window_title=f'{recording}:{signal}')

        # layer_area.plotWidget.update_plot(y_data=signal_data, y_data_name=signal)
        layer_area.parent().parent().set_toggle(True)

        # add the plot types to the combo box
        layer_area.comboBox.blockSignals(True)
        layer_area.comboBox.addItems(PG_PLOTS.keys())
        layer_area.comboBox.blockSignals(False)

        # only link if shapes match
        ## TODO: make this a try in the plot widget
        # if self.ctx['val']['stim']._data.shape[-1] == signal_data.shape[-1]:
        #     self.link_together(layer_area.plotWidget)

        layer_area.update_plot()


if __name__ == '__main__':
    from nems import xforms
    # Demo: open the GUI on a hard-coded example fit (requires access to
    # /auto/data/nems_db).
    demo_model = '/auto/data/nems_db/results/322/ARM030a-28-2/ozgf.fs100.ch18-ld-sev.dlog-wc.18x3.g-fir.3x15-lvl.1-dexp.1.tfinit.n.lr1e3.rb5.es20-newtf.n.lr1e4.es20.2021-06-15T212246'
    xfspec, ctx = xforms.load_analysis(demo_model)

    #xfspec, ctx = xforms.load_context(r'C:\Users\Alex\PycharmProjects\NEMS\results\temp_xform')

    app = QApplication(sys.argv)
    window = MainWindow(ctx=ctx, xfspec=xfspec)
    # Propagate Qt's exit status to the shell instead of discarding it.
    sys.exit(app.exec_())
Example #11
0
def _zero_phi_except(mspec, keep_tag):
    """
    In place: for every phi entry of mspec[0] whose key lacks ``keep_tag``,
    keep only the first coefficient ``phi[k][0, 0]`` and zero the remaining
    ``phi[k].shape[-1] - 1`` ones, storing the result as a 1 x N row.
    """
    phi = mspec[0]['phi']
    for k in [k for k in phi.keys() if keep_tag not in k]:
        n_extra = phi[k].shape[-1] - 1
        phi[k] = np.append(phi[k][0, 0], np.zeros(n_extra))[np.newaxis, :]


def generate_state_corrected_psth(batch=None, modelname=None, cellids=None,
                                  siteid=None, movement_mask=False,
                                  gain_only=False, dc_only=False,
                                  cache_path=None, recache=False):
    """
    Build a recording whose 'psth' signal is the prediction of ``modelname``.
    Designed with stategain models in mind. CRH.

    Existing fits of each cell in ``cellids`` are reloaded from the results
    table of ``batch``. Cells without a saved fit (or whose fit fails to
    load) are logged and skipped; no model fitting happens here. If
    ``cache_path`` is given, the result is saved to / reloaded from
    ``cache_path + siteid + '_<modelname.split('.')[1]>[suffix].tgz'``
    unless ``recache`` is True.

    Parameters
    ----------
    batch : int
    modelname : str
    cellids : list of str
        Cells whose predictions are stacked channel-wise.
    siteid : str
        Used for the cache filename; required.
    movement_mask : bool
        Currently unused.
    gain_only : bool
        Zero all but the first coefficient of every first-module phi entry
        whose key lacks '_g', then re-evaluate the model.
    dc_only : bool
        Same for keys lacking '_d' (ignored when gain_only is True).
    cache_path : str, optional
        Directory prefix for the cache file (concatenated, not joined).
    recache : bool
        Recompute even if the cache file exists.

    Returns
    -------
    Recording with psth_sp, psth, resp, pupil and mask signals, plus
    pupil_raw, rem and pupil_eyespeed when the first loaded recording has
    them. The 'mask' signal is cast to bool.

    Raises
    ------
    ValueError
        If siteid, modelname, batch or cellids is missing, if no fit could be
        loaded, or if the per-cell signals cannot be concatenated.
    """
    from collections import Counter

    if siteid is None:
        raise ValueError("must specify siteid!")
    # modelname is needed both for the cache filename and the results lookup.
    if modelname is None:
        raise ValueError('Must specify batch and modelname!')

    fn = None
    if cache_path is not None:
        fn = cache_path + siteid + '_{}.tgz'.format(modelname.split('.')[1])
        if gain_only:
            fn = fn.replace('.tgz', '_gonly.tgz')
        elif dc_only:
            # dc_only results previously shared the full-model cache file.
            fn = fn.replace('.tgz', '_dconly.tgz')
        if 'mvm' in modelname:
            fn = fn.replace('.tgz', '_mvm.tgz')
        if os.path.isfile(fn) and not recache:
            return Recording.load(fn)

    if batch is None:
        raise ValueError('Must specify batch and modelname!')
    if cellids is None:
        raise ValueError('Must specify cellids!')

    results_table = nd.get_results_file(batch, modelnames=[modelname])
    preds = []
    modelspecs = []
    for cell in cellids:
        log.info(cell)
        paths = results_table.loc[results_table['cellid'] == cell,
                                  'modelpath'].values
        if len(paths) == 0 or not os.path.isdir(paths[0]):
            log.warning("WARNING: fit doesn't exist for cell {0}".format(cell))
            continue
        try:
            _, ctx = xforms.load_analysis(paths[0])
            val, mspec = ctx['val'], ctx['modelspec']
        except Exception:
            # Best effort: one broken fit should not abort the whole site.
            log.warning("WARNING: fit doesn't exist for cell {0}".format(cell),
                        exc_info=True)
            continue
        # append together so preds[i] and modelspecs[i] stay aligned
        preds.append(val)
        modelspecs.append(mspec)

    if not preds:
        raise ValueError('No fits loaded for site {} (batch {}, {})'.format(
            siteid, batch, modelname))

    # Preds may differ in length if, e.g., one cell existed in a prepassive
    # file and another did not, so collect the FILE epochs shared by all.
    file_epochs = [ep for pr in preds for ep in pr.epochs.name
                   if ep.startswith('FILE')]
    file_counts = Counter(file_epochs)
    shared_files = [str(f) for f in sorted(file_counts)
                    if file_counts[f] == len(preds)]

    # For these batches, mask every pred to the shared FILE epochs.
    if int(batch) in (307, 294):
        for i, p in enumerate(preds):
            preds[i] = p.and_mask(shared_files).apply_mask(reset_epochs=True)

    for i, (p, mspec) in enumerate(zip(preds, modelspecs)):
        if gain_only:
            _zero_phi_except(mspec, '_g')
            pred = mspec.evaluate(p)['pred']
        elif dc_only:
            _zero_phi_except(mspec, '_d')
            pred = mspec.evaluate(p)['pred']
        else:
            pred = p['pred']

        if i == 0:
            new_psth_sp = p['psth_sp']
            new_psth = pred
            new_resp = p['resp'].rasterize()
        else:
            try:
                new_psth_sp = new_psth_sp.concatenate_channels(
                    [new_psth_sp, p['psth_sp']])
                new_psth = new_psth.concatenate_channels([new_psth, pred])
                new_resp = new_resp.concatenate_channels(
                    [new_resp, p['resp'].rasterize()])
            except ValueError as err:
                # previously dropped into pdb; surface a useful error instead
                raise ValueError(
                    'Could not concatenate signals of loaded fit #{} '
                    '(possibly mismatched lengths)'.format(i)) from err

    first = preds[0]
    sigs = {'pupil': first['pupil']}
    if 'pupil_raw' in first.signals:
        sigs['pupil_raw'] = first['pupil_raw']
    if 'mask' in first.signals:
        sigs['mask'] = first['mask']
    else:
        sigs['mask'] = first.create_mask(True)['mask']
    for name in ('rem', 'pupil_eyespeed'):
        if name in first.signals:
            sigs[name] = first[name]

    new_psth_sp.name = 'psth_sp'
    new_psth.name = 'psth'
    new_resp.name = 'resp'
    sigs['psth_sp'] = new_psth_sp
    sigs['psth'] = new_psth
    sigs['resp'] = new_resp

    new_rec = Recording(sigs, meta=first.meta)

    # make sure mask is cast to bool
    new_rec['mask'] = new_rec['mask']._modified_copy(
        new_rec['mask']._data.astype(bool))

    if fn is not None:
        log.info('caching {}'.format(fn))
        new_rec.save_targz(fn)

    return new_rec
Example #12
0
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jan 23 10:11:40 2018

@author: svd
"""

import nems.plots.api as nplt
import nems.db as nd
import nems.xforms as xforms
from nems.gui.recording_browser import browse_recording, browse_context

# Fit to inspect (alternative modelnames kept commented out).
cellid = 'TAR010c-18-1'
batch = 271
#modelname = 'wc.18x1.g-fir.1x15-lvl.1'
modelname = 'dlog-wc.18x1.g-fir.1x15-lvl.1'
#modelname = 'dlog-wc.18x1.g-stp.1-fir.1x15-lvl.1-dexp.1'

# Look up the saved fit in the results table.
d = nd.get_results_file(batch=batch, cellids=[cellid], modelnames=[modelname])

# NOTE(review): [0] is a label lookup; assumes the default integer index.
filepath = d['modelpath'][0] + '/'
# Load without evaluation, then run the full xfspec explicitly below.
xfspec, ctx = xforms.load_analysis(filepath, eval_model=False)

ctx, log_xf = xforms.evaluate(xfspec, ctx)

#nplt.quickplot(ctx)
ctx['modelspec'].quickplot(ctx['val'])

# Keep a reference to the browser (presumably so the window is not garbage-collected).
aw = browse_context(ctx, signals=['stim', 'pred', 'resp'])
Example #13
0
# Finish the scatter plot started before this excerpt (ax, f, model and
# lp_model are defined there).
ax.set_xlabel(model, fontsize=6)
ax.set_ylabel(lp_model, fontsize=6)
ax.plot([-1, 1], [-1, 1], 'k--', zorder=-1)
ax.axhline(0, linestyle='--', color='k')
ax.axvline(0, linestyle='--', color='k')
ax.axis('equal')
ax.set_title("gain weights for LV")
f.tight_layout()

plt.show()

import nems.xforms as xforms
from nems.xform_helper import load_model_xform

# load an old model directly from its results folder
modelpath = '/auto/data/nems_db/results/331/AMT020a/psth.fs4.pup-loadpred.cpnmvm-st.pup.pvp0-plgsm.e10.sp-lvnoise.r8-aev.lvnorm.SxR.d.so-inoise.2xR.ccnorm.t5.ss1.2021-06-21T174425'
xf, ctx = xforms.load_analysis(modelpath)

# load a new model
site = 'AMT020a'
batch = 331
modelname = "psth.fs4.pup-ld-st.pup.pvp0-epcpn.old-mvm.t25.w1-hrc-psthfr-plgsm.e10.sp-aev_sdexp2.2xR-lvnorm.SxR.d.so-inoise.2xR_init.xx1.it50000-lvnoise.r8-aev-ccnorm.f0.ss1"
# Look the fit up under the first batch cell whose ID contains the site name.
xfn, ctxn = load_model_xform(
    cellid=[c for c in nd.get_batch_cells(batch).cellid if site in c][0],
    batch=batch,
    modelname=modelname)

# old/new loadpred models
modelname = "psth.fs4.pup-loadpred.cpnOldmvm,t25,w1-st.pup.pvp0-plgsm.e10.sp-lvnoise.r8-aev_lvnorm.SxR.d.so-inoise.2xR_ccnorm.t5.ss1"
# NOTE(review): unlike the lookup above, the site ID itself is passed as
# cellid here -- presumably these fits are stored per site; confirm.
xfnn, ctxnn = load_model_xform(cellid=site, batch=batch, modelname=modelname)
Example #14
0
# fast LV with pupil (gain model with sigmoid nonlinearity)
modelname = 'ns.fs4.pup-ld-st.pup-hrc-apm-pbal-psthfr-ev-addmeta_slogsig.SxR-lv.1xR.f.pred-lvlogsig.2xR_jk.nf5.p-pupLVbasic.constrLVonly.af0:2.sc'

# single module for LV (testing LV modeling architectures)
#modelname = 'ns.fs4.pup-ld-st.pup-hrc-apm-pbal-psthfr-ev-addmeta_puplvmodel.pred.step.dc.R_jk.nf5.p-pupLVbasic.constrLVonly.af0:3.sc.rb10'
#modelname0 = 'ns.fs4.pup-ld-st.pup-hrc-apm-pbal-psthfr-ev-addmeta_puplvmodel.dc.pupOnly.R_jk.nf5.p-pupLVbasic.constrLVonly.af0:0.sc'

# without jackknifing (or cross validation)
# NOTE: this overrides the modelname assigned above. modelname0 is the
# comparison model (its keywords include 'pupOnly'); xforms_model is not
# used within this excerpt.
modelname = 'ns.fs4.pup-ld-st.pup-epsig-hrc-apm-pbal-psthfr-ev-addmeta-aev_puplvmodel.pred.step.g.dc.R_pupLVbasic.constrNC.af0:1.sc.rb2'
modelname0 = 'ns.fs4.pup-ld-st.pup-hrc-apm-pbal-psthfr-ev-addmeta-aev_puplvmodel.g.dc.pupOnly.R_pupLVbasic.constrLVonly.af0:0.sc.rb2'
xforms_model = 'ns.fs4.pup-ld-st.pup-hrc-apm-pbal-psthfr-addmeta-aev_puplvmodel.pred.step.pfix.dc.R_pupLVbasic.constrLVonly.af0:0.sc.rb10'
if load:
    # Reload saved fits of both models for the first batch cell whose ID
    # contains `cellid`.
    c = [c for c in nd.get_batch_cells(batch).cellid if cellid in c][0]
    mp = nd.get_results_file(batch, [modelname], [c]).modelpath[0]
    _, ctx = xforms.load_analysis(mp)
    mp = nd.get_results_file(batch, [modelname0], [c]).modelpath[0]
    _, ctx2 = xforms.load_analysis(mp)
else:
    # Fit both models in memory without saving the analyses.
    ctx = xfit.fit_xforms_model(batch, cellid, modelname, save_analysis=False)
    ctx2 = xfit.fit_xforms_model(batch,
                                 cellid,
                                 modelname0,
                                 save_analysis=False)

# plot lv, pupil, PC1 timecourses
# Key names depend on whether the modelname contains '.g.' and/or '.dc.';
# if both match, key1 ends up as 'pd'.
if '.g.' in modelname:
    key1 = 'pg'
    key2 = 'lvg'
if '.dc.' in modelname:
    key1 = 'pd'
Example #15
0
# xfspec.append(['nems.xforms.average_away_stim_occurrences', {}])
# Add the initialisation, fitting, prediction and reporting steps to the
# xfspec assembled above.
xfspec.extend([
    ['nems.xforms.init_from_keywords',
     {'keywordstring': modelspec_name, 'meta': meta}],
    ['nems.xforms.fit_basic_init', {}],
    ['nems.xforms.fit_basic', {}],
    ['nems.xforms.predict', {}],
    ['nems.xforms.add_summary_statistics', {}],
    ['nems.xforms.plot_summary', {}],
])

# Run every step, then record where the results will be stored.
ctx, log_xf = xforms.evaluate(xfspec)
modelspecs = ctx['modelspecs']
destination = (f"/auto/data/nems_db/results/{batch}/{cellid}/"
               f"{ms.get_modelspec_longname(modelspecs[0])}/")
modelspecs[0][0]['meta'].update(
    modelpath=destination,
    figurefile=destination + 'figure.0000.png',
)
modelspecs[0][0]['meta'].update(meta)

save_data = xforms.save_analysis(destination,
                                 recording=ctx['rec'],
                                 modelspecs=modelspecs,
                                 xfspec=xfspec,
                                 figures=ctx['figures'],
                                 log=log_xf)

savepath = save_data['savepath']
# Reload the saved analysis with full re-evaluation.
loaded = xforms.load_analysis(filepath=savepath, eval_model=True, only=None)
Example #16
0
datapath = '/auto/users/svd/projects/reward_training/nems_export/torcs/'

# find all models, knowing that the folder names should contain "2020"
d = glob.glob(datapath + '*2020*')

# Folder names presumably look like "<cellid>_<modelinfo>.2020...". Split the
# basename, not the full path: datapath itself contains underscores
# ("reward_training", "nems_export"), which corrupted both lists.
basenames = [os.path.basename(f) for f in d]

# get unique list of cells and modelnames
cellids = list(set([f.split("_")[0] for f in basenames]))
modelnames = list(set([f.split("_")[1].split('.2020')[0] for f in basenames]))

# load an example model
cellid = 'NMK020c-29-1'
modelinfo = 'st.pup.fil-'

i, = np.where([(cellid in f) and (modelinfo in f) for f in d])
if i.size == 0:
    raise ValueError('no model folder matching {} / {} in {}'.format(
        cellid, modelinfo, datapath))
i = i[0]  # is an array, just take first entry (should just be 1)

xf, ctx = load_analysis(d[i], eval_model=False)

modelspec = ctx['modelspec']
print('Loaded model {}'.format(os.path.basename(d[i])))
print('Cellid: {}\nModel name: {}'.format(modelspec.meta['cellid'],
                                          modelspec.meta['modelname']))
print('r_test: {:.3f}'.format(modelspec.meta['r_test'][0][0]))

state_channels = modelspec.meta['state_chans']
for i, s in enumerate(state_channels):
    # report offset (phi 'd'), gain (phi 'g') and MI for this state channel
    print("{}: offset={:.3f} gain={:.3f} MI={:.3f}".format(
        s, modelspec.phi[0]['d'][0, i], modelspec.phi[0]['g'][0, i],
        modelspec.meta['state_mod'][i]))
Example #17
0
    best_alpha = pd.read_csv(
        '/auto/users/hellerc/code/projects/nat_pupil_ms_final/dprime/best_alpha.csv',
        index_col=0)
    alpha = best_alpha.loc[site][0]
    alpha = (float(alpha.split(',')[0].replace('(', '')),
             float(alpha.split(',')[1].replace(')', '')))
    a = 'af{0}.as{1}.sc.rb10'.format(
        str(alpha[0]).replace('.', ':'),
        str(alpha[1]).replace('.', ':'))
    modelname = 'ns.fs4.pup-ld-hrc-apm-pbal-psthfr-ev-residual-addmeta_lv.2xR.f.s-lvlogsig.3xR.ipsth_jk.nf5.p-pupLVbasic.constrLVonly.{}'.format(
        a)

    cellid = [c for c in nd.get_batch_cells(batch).cellid if site in c][0]
    mp = nd.get_results_file(batch, [modelname], [cellid]).modelpath[0]

    xfspec, ctx = xforms.load_analysis(mp)

    r = ctx['val'].apply_mask()
    fs = r['resp'].fs
    fast = r['lv'].extract_channels(['lv_fast'])._data.squeeze()
    slow = r['lv'].extract_channels(['lv_slow'])._data.squeeze()
    pupil = r['pupil']._data.squeeze()

    o = ss.periodogram(fast, fs=fs)
    F.append(o[1].squeeze())
    Fm.append(o[0][np.argmax(o[1].squeeze())])

    o = ss.periodogram(slow, fs=fs)
    S.append(o[1].squeeze())
    Sm.append(o[0][np.argmax(o[1].squeeze())])
Example #18
0
# NC constraint
modelname = [
    'ns.fs4.pup-ld-st.pup-hrc-apm-psthfr-ev_slogsig.SxR-lv.1xR-lvlogsig.2xR_jk.nf5.p-pupLVbasic.constrNC.a0:35'
]
# NOTE: overrides the list above (residual / nf2 variant is the one used).
modelname = [
    'ns.fs4.pup-ld-st.pup-hrc-apm-psthfr-ev-residual_slogsig.SxR-lv.1xR-lvlogsig.2xR_jk.nf2.p-pupLVbasic.constrNC.a0:05'
]
# comparison model without the LV modules
modelname2 = ['ns.fs4.pup-ld-st.pup-hrc-psthfr-ev_slogsig.SxR_jk.nf5.p-basic']
batch = 289
cellids = nd.get_batch_cells(batch).cellid
# one-element list: first batch cell whose ID contains `site` (defined before this excerpt)
cellid = [[c for c in cellids if site in c][0]]

mp = nd.get_results_file(batch, modelname, cellid).modelpath[0]
mp2 = nd.get_results_file(batch, modelname2, cellid).modelpath[0]

xfspec, ctx = xforms.load_analysis(mp)
xfspec2, ctx2 = xforms.load_analysis(mp2)

# plot summary of fit
ctx2['modelspec'].quickplot()
ctx['modelspec'].quickplot()

rec = ctx['val'].apply_mask(reset_epochs=True).copy()  # raw recording
# keep only row 1 of the 'lv' signal, as a (1, T) array
rec['lv'] = rec['lv']._modified_copy(rec['lv']._data[1, :][np.newaxis, :])
rec1 = ctx2['val'].apply_mask(
    reset_epochs=True).copy()  # first order regression
rec12 = rec.copy()  # first / second order regression

# NOTE: with the regress_state calls below commented out, rec1/rec12 are
# plain masked copies, not regressed recordings.
#rec1 = preproc.regress_state(rec1, state_sigs=['pupil'], regress=['pupil'])
#rec12 = preproc.regress_state(rec12, state_sigs=['pupil', 'lv'], regress=['pupil', 'lv'])
psth = preproc.generate_psth(rec12)