Пример #1
0
 def on_view_model(self):
     """Handle the 'view model' button.

     Shows a transient status message, loads and evaluates the currently
     selected cell/batch/model fit, and opens it in the model browser.
     """
     self.statusbar.showMessage('Loading model...', 5000)
     fit_key = (self.cellid, self.batch, self.modelname)
     xfspec, ctx = xform_helper.load_model_xform(*fit_key, eval_model=True)
     self.launch_model_browser(ctx, xfspec)
Пример #2
0
def stp_sigmoid_pred_matched(cellid, batch, modelname, LN, include_phi=True):
    """Scatter an LN model's pre-output prediction against an STP model's output.

    x axis: prediction of the ``LN`` model evaluated up to (not including)
    its last module. y axis: final validation prediction of ``modelname``.
    Points are colored by a normalized STP effect,
    ``(after - before) / (after + before)``, where before/after are the
    channel-averaged predictions of ``modelname`` just before and just after
    its 'stp' module.

    Parameters
    ----------
    cellid, batch : passed through to ``xhelp.load_model_xform``.
    modelname : str
        Model containing an 'stp' module.
    LN : str
        Reference model without STP.
    include_phi : bool
        If True, print the STP module's parameters beside the axes.

    Returns
    -------
    The matplotlib figure.
    """
    _, stp_ctx = xhelp.load_model_xform(cellid, batch, modelname)
    _, ln_ctx = xhelp.load_model_xform(cellid, batch, LN)

    stp_spec = stp_ctx['modelspec']
    stp_spec.recording = stp_ctx['val']
    stp_val = stp_ctx['val'].apply_mask()

    ln_spec = ln_ctx['modelspec']
    ln_spec.recording = ln_ctx['val']
    ln_val = ln_ctx['val'].apply_mask()

    y_out = stp_val['pred'].as_continuous().flatten()  # with stp
    ln_partial = ms.evaluate(ln_val, ln_spec, stop=-1)
    x_in = ln_partial['pred'].as_continuous().flatten()  # no stp

    def _channel_mean(rec):
        # average over axis 0 of the continuous pred, then flatten to 1-D
        return rec['pred'].as_continuous().mean(axis=0).flatten()

    stp_idx = find_module('stp', stp_spec)
    # Evaluate both partial models before reducing, as the original did.
    upto_stp = ms.evaluate(stp_val, stp_spec, stop=stp_idx)
    through_stp = ms.evaluate(stp_val, stp_spec, stop=stp_idx + 1)
    before = _channel_mean(upto_stp)
    after = _channel_mean(through_stp)
    stp_effect = (after - before) / (after + before)

    fig = plt.figure()
    plt.scatter(x_in,
                y_out,
                c=stp_effect,
                s=2,
                alpha=0.75,
                cmap=plt.get_cmap('plasma'))
    plt.title(cellid)
    plt.xlabel('pred in (no stp)')
    plt.ylabel('pred out (with stp)')

    if include_phi:
        stp_phi = stp_spec.phi[stp_idx]
        # keep %-formatting: it accepts size-1 numpy arrays, f-string specs do not
        phi_lines = ['%s:  %.4E' % (name, value) for name, value in stp_phi.items()]
        fig.text(0.775, 0.9, '\n'.join(phi_lines), va='top', ha='left')
        plt.subplots_adjust(right=0.775, left=0.075)
    plt.colorbar()

    return fig
Пример #3
0
    def load(self):
        """Load the fit named by the batch/cell/model line edits, then close.

        Returns
        -------
        tuple
            ``(xfspec, ctx)`` as returned by ``xhelp.load_model_xform``.
        """
        batch, cellid, modelname = (
            self.batchLE.text(),
            self.cellLE.text(),
            self.modelLE.text(),
        )
        print(f'Loading {cellid}/{batch}/{modelname}')

        xfspec, ctx = xhelp.load_model_xform(cellid, batch, modelname)
        self.close()
        return xfspec, ctx
Пример #4
0
def get_model_preds(cellid, batch, modelname):
    """Load a saved fit without evaluating it, then run all but the last xforms step.

    Returns
    -------
    tuple
        ``(xfspec, ctx)``; ``ctx`` is the context produced by
        ``xforms.evaluate(..., stop=-1)``.
    """
    xfspec, model_ctx = xhelp.load_model_xform(
        cellid, batch, modelname, eval_model=False)
    model_ctx, _eval_log = xforms.evaluate(xfspec, model_ctx, stop=-1)
    return xfspec, model_ctx
Пример #5
0
def dynamic_sigmoid_differences(batch,
                                modelname,
                                hist_bins=60,
                                test_limit=None,
                                save_path=None,
                                load_path=None,
                                use_quartiles=False,
                                avg_bin_count=20):
    """Histogram, across a batch, of low- vs high-contrast output differences.

    For each cell, validation time bins are split by the 'ctpred' signal:
    below vs at/above its median, or (``use_quartiles``) between the 25th
    percentile and the median vs at/above the 75th percentile. The two groups
    don't share x values, so the prediction entering the last module is binned
    (``_binned_xvar``) and the final prediction is averaged per bin for each
    group (``_binned_yavg``). The per-cell value is the nan-mean over bins of
    ``(low - high) / (|low| + |high|)``.

    Parameters
    ----------
    batch, modelname : identify the fits to load.
    hist_bins : int
        Number of histogram bins.
    test_limit : int or None
        Only process the first ``test_limit`` cells.
    save_path, load_path : optional ``.npy`` paths to cache / reload ratios.
    use_quartiles : bool
        Use the quartile split described above instead of the median split.
    avg_bin_count : int
        Number of bins passed to ``_binned_xvar``.

    Returns
    -------
    numpy.ndarray
        One ratio per processed cell.
    """
    if load_path is None:
        cellids = nd.get_batch_cells(batch, as_list=True)
        ratios = []
        for cellid in cellids[:test_limit]:
            xfspec, ctx = xhelp.load_model_xform(cellid, batch, modelname)
            val = ctx['val'].apply_mask()
            ctpred = val['ctpred'].as_continuous().flatten()
            pred_after = val['pred'].as_continuous().flatten()
            val_before = ms.evaluate(val, ctx['modelspec'], stop=-1)
            pred_before = val_before['pred'].as_continuous().flatten()
            median_ct = np.nanmedian(ctpred)
            if use_quartiles:
                # nan-aware like the median: np.percentile returns nan when
                # ctpred has any nan, which would leave both masks empty.
                low_thresh = np.nanpercentile(ctpred, 25)
                high_thresh = np.nanpercentile(ctpred, 75)
                low_mask = (ctpred >= low_thresh) & (ctpred < median_ct)
                high_mask = ctpred >= high_thresh
            else:
                low_mask = ctpred < median_ct
                high_mask = ctpred >= median_ct

            # Bin on the x axis, since the two groups' x values don't overlap.
            mean_before, bin_masks = _binned_xvar(pred_before, avg_bin_count)
            low_avg = _binned_yavg(pred_after, low_mask, bin_masks)
            high_avg = _binned_yavg(pred_after, high_mask, bin_masks)

            ratio = np.nanmean(
                (low_avg - high_avg) / (np.abs(low_avg) + np.abs(high_avg)))
            ratios.append(ratio)

        ratios = np.array(ratios)
        if save_path is not None:
            np.save(save_path, ratios)
    else:
        ratios = np.load(load_path)

    plt.figure()
    plt.hist(ratios,
             bins=hist_bins,
             color=[wsu_gray_light],
             edgecolor='black',
             linewidth=1)
    plt.xlabel('(low - high)/(|low| + |high|)')
    plt.ylabel('cell count')
    plt.title("difference of low-contrast output and high-contrast output\n"
              "positive means low-contrast has higher firing rate on average")

    return ratios
Пример #6
0
def dynamic_sigmoid_range(cellid,
                          batch,
                          modelname,
                          plot=True,
                          show_min_max=True):
    """Evaluate the dynamic sigmoid at the extremes of the contrast signal.

    The last module's parameters are split into base values (keys without
    '_mod') and '_mod' values (with the 4-character suffix stripped). A base
    parameter with no '_mod' partner uses its base value for both. Each
    parameter is interpolated as ``t + (t_mod - t) * c``, at c = min(ctpred)
    ("low") and at c = max(ctpred) ("high"). ``_double_exponential`` is then
    applied to the prediction entering the last module under each parameter
    set.

    Parameters
    ----------
    plot : bool
        If True, scatter low (blue) and high (red) outputs in a new figure.
    show_min_max : bool
        If True (and plotting), circle the points at the time bins where
        ctpred is minimal (blue) and maximal (red).

    Returns
    -------
    (pred_before_dsig, low_out, high_out), each flattened.
    """
    xfspec, ctx = xhelp.load_model_xform(cellid, batch, modelname)
    modelspec = ctx['modelspec']
    modelspec.recording = ctx['val']
    val = ctx['val'].apply_mask()
    ctpred = val['ctpred'].as_continuous().flatten()
    val_before_dsig = ms.evaluate(val, modelspec, stop=-1)
    pred_before_dsig = val_before_dsig['pred'].as_continuous().flatten()
    lows = {k: v for k, v in modelspec[-1]['phi'].items() if '_mod' not in k}
    highs = {k[:-4]: v for k, v in modelspec[-1]['phi'].items() if '_mod' in k}
    for k in lows:
        if k not in highs:
            highs[k] = lows[k]
    # re-sort keys to make sure they're in the same order
    lows = {k: lows[k] for k in sorted(lows)}
    highs = {k: highs[k] for k in sorted(highs)}

    ctmax_val = np.max(ctpred)
    ctmin_val = np.min(ctpred)
    ctmax_idx = np.argmax(ctpred)
    ctmin_idx = np.argmin(ctpred)

    thetas = list(lows.values())
    theta_mods = list(highs.values())
    for t, t_m, k in zip(thetas, theta_mods, list(lows.keys())):
        lows[k] = t + (t_m - t) * ctmin_val
        highs[k] = t + (t_m - t) * ctmax_val

    low_out = _double_exponential(pred_before_dsig, **lows).flatten()
    high_out = _double_exponential(pred_before_dsig, **highs).flatten()

    if plot:
        plt.figure()
        plt.scatter(pred_before_dsig, low_out, color='blue', s=0.7, alpha=0.6)
        plt.scatter(pred_before_dsig, high_out, color='red', s=0.7, alpha=0.6)

        if show_min_max:
            max_pred = pred_before_dsig[ctmax_idx]
            min_pred = pred_before_dsig[ctmin_idx]
            # facecolors='none' (a string) draws hollow rings; facecolors=None
            # falls back to the default filled marker and hides the point.
            plt.scatter(min_pred,
                        low_out[ctmin_idx],
                        facecolors='none',
                        edgecolors='blue',
                        s=60)
            plt.scatter(max_pred,
                        high_out[ctmax_idx],
                        facecolors='none',
                        edgecolors='red',
                        s=60)

    return pred_before_dsig, low_out, high_out
Пример #7
0
def _get_plot_contexts(cellid, batch, gc, stp, LN, combined):
    """Load and evaluate four model fits for one cell.

    Returns
    -------
    tuple
        ``(gc_ctx, stp_ctx, LN_ctx, combined_ctx)``, i.e. the evaluated
        contexts in the same order as the model-name arguments.
    """
    contexts = []
    for name in (gc, stp, LN, combined):
        _xfspec, model_ctx = xhelp.load_model_xform(cellid,
                                                    batch,
                                                    name,
                                                    eval_model=True)
        contexts.append(model_ctx)
    return tuple(contexts)
Пример #8
0
 def preview(self):
     """Show the saved figure for the fit named in the dialog's fields.

     The fit is loaded without evaluation. Its modelspec's
     ``meta['figurefile']`` is then displayed in the image window.
     """
     batch, cellid, modelname = (self.batchLE.text(),
                                 self.cellLE.text(),
                                 self.modelLE.text())
     print(f"Viewing {batch},{cellid},{modelname}")
     _xf, loaded = xhelp.load_model_xform(cellid, batch, modelname,
                                          eval_model=False)
     figurefile = loaded['modelspec'].meta['figurefile']
     print('Figure file: ' + figurefile)
     self.im_window.update_imagepath(imagepath=figurefile)
     self.im_window.show()
Пример #9
0
def generate_psth_correlations_pop(batch,
                                   modelnames,
                                   save_path=None,
                                   load_path=None):
    """Per-cell correlations between validation predictions of three pop models.

    If ``load_path`` is given, the pickled DataFrame there is returned.
    Otherwise:
    1. Each model in ``modelnames`` is loaded and evaluated once, via the
       first significant cell.
    2. Its masked validation 'pred' signal is restricted to the significant
       cells.
    3. Pearson correlations between model pairs are computed for each cell.

    ``modelnames`` is expected in the order 0: conv2d, 1: conv1dx2+d,
    2: LN_pop.

    Returns
    -------
    pandas.DataFrame
        Indexed by cellid, with columns 'c2d_c1d', 'c2d_LN' and 'c1d_LN'.
        The frame is pickled to ``save_path`` (once, at the end) when given.
    """
    if load_path is not None:
        return pd.read_pickle(load_path)

    sig_cells = get_significant_cells(batch,
                                      SIG_TEST_MODELS,
                                      as_list=True)

    # Load and evaluate each model, pull out validation pred signal for each one.
    loaded = [
        xhelp.load_model_xform(sig_cells[0], batch, name,
                               eval_model=True)[1] for name in modelnames
    ]
    pred_sigs = [model_ctx['val'].apply_mask()['pred'] for model_ctx in loaded]
    # not all the models load chans for some reason; reuse the first model's
    shared_chans = pred_sigs[0].chans
    for sig in pred_sigs[1:]:
        sig.chans = shared_chans
    per_cell = [
        sig.extract_channels(sig_cells).as_continuous() for sig in pred_sigs
    ]

    # Model index pairs: 0: conv2d, 1: conv1dx2+d, 2: LN_pop
    # TODO: if EQUIVALENCE_MODELS_POP changes, this needs to change as well
    pairs = {'c2d_c1d': (0, 1), 'c2d_LN': (0, 2), 'c1d_LN': (1, 2)}
    table = {'cellid': []}
    table.update({column: [] for column in pairs})
    for row, cellid in enumerate(sig_cells):
        for column, (a, b) in pairs.items():
            table[column].append(
                np.corrcoef(per_cell[a][row], per_cell[b][row])[0, 1])
        table['cellid'].append(cellid)

    corrs = pd.DataFrame.from_dict(table)
    corrs.set_index('cellid', inplace=True)
    if save_path is not None:
        corrs.to_pickle(save_path)

    return corrs
Пример #10
0
def pred_comparison(cellid, batch, models, subtract=None):
    '''
    Overlay validation predictions from several models for one cell.

    Convention for model order: gc, stp, ln, gc+stp
    subtract: None, or string modelname to subtract preds (like LN)

    Each model's ``ctx['val']['pred']`` is plotted (transposed), optionally
    minus the same signal from the ``subtract`` model. The legend is labelled
    with the model names in plotting order.
    '''
    model_names = list(models)

    subtract_pred = None
    if subtract is not None:
        _, sub_ctx = xhelp.load_model_xform(cellid, batch, subtract)
        # Read the same signal as the models below so the subtraction is
        # like-for-like. Loaded contexts are indexed by recording ('val', ...)
        # and then by signal; the old ctx['pred'] lookup skipped the recording.
        subtract_pred = sub_ctx['val']['pred'].as_continuous().T

    preds = []
    for m in model_names:
        _, ctx = xhelp.load_model_xform(cellid, batch, m)
        p = ctx['val']['pred'].as_continuous().T
        if subtract_pred is not None:
            p = p - subtract_pred
        preds.append(p)

    for p in preds:
        plt.plot(p)
    # legend() expects labels, not the data arrays themselves.
    plt.legend(model_names)
Пример #11
0
def get_strf(cellid, batch, modelname):
    """Return the STRF of a fit as weight-channel coefficients (transposed) @ FIR coefficients.

    The fit is loaded without evaluation. Coefficients are taken with
    ``idx=0`` from ``_get_wc_coefficients`` and ``_get_fir_coefficients``.
    """
    _xfspec, loaded = xhelp.load_model_xform(cellid,
                                             batch,
                                             modelname,
                                             eval_model=False)
    spec = loaded['modelspec']
    return (_get_wc_coefficients(spec, idx=0).T
            @ _get_fir_coefficients(spec, idx=0))
Пример #12
0
def test_comparison(
        cellid="TAR010c-15-5",
        batch=289,
        modelname1="ozgf.fs100.ch18.pop-loadpop.cc20.bth-norm.l1-popev_wc.18x30.g-fir.1x12x30-relu.30-wc.30x30.z-relu.30-wc.30xR.z-lvl.R-dexp.R_tfinit.n.lr1e3.et3-newtf.n.lr1e4-popspc",
        modelname2="ozgf.fs100.ch18.pop-loadpop.cc20.bth-norm.l1-popev_wc.18x40.g-stp.40.q.s-fir.1x12x40-relu.40-wc.40x30.z-relu.30-wc.30xR.z-lvl.R-dexp.R_tfinit.n.lr1e3.et3-newtf.n.lr1e4-popspc",
        modelname_ref="ozgf.fs100.ch18-ld-sev_dlog-wc.18x3.g-fir.3x15-lvl.1-dexp.1_init-basic",
        recname="val"):
    """Load two models plus a reference model for one cell.

    Returns
    -------
    tuple
        ``(modelspec1, modelspec2, modelspec_ref, rec1, rec2, rec_ref)``,
        where each rec is ``ctx[recname]`` of the matching loaded fit.
    """
    contexts = []
    for name in (modelname1, modelname2, modelname_ref):
        _xf, loaded = load_model_xform(cellid, batch, name)
        contexts.append(loaded)

    modelspecs = tuple(c['modelspec'] for c in contexts)
    recordings = tuple(c[recname] for c in contexts)
    return modelspecs + recordings
Пример #13
0
def gd_ratio(cellid, batch, modelname):
    """Return ``kappa_mod / kappa`` from the fit's 'dynamic_sigmoid' module.

    The fit is loaded without evaluation.
    """
    _xfspec, loaded = xhelp.load_model_xform(cellid,
                                             batch,
                                             modelname,
                                             eval_model=False)
    spec = loaded['modelspec']
    params = spec[find_module('dynamic_sigmoid', spec)]['phi']
    return params['kappa_mod'] / params['kappa']
Пример #14
0
def sparseness_example(batch,
                       cellid,
                       modelname,
                       rec=None,
                       ax=None,
                       tag="",
                       colors=None):
    """Plot response vs. prediction sparseness for one cell and one model.

    Parameters
    ----------
    batch, cellid, modelname : identify the fit to load.
    rec : recording or None
        If None, the fit is loaded with evaluation and its masked 'val'
        recording is used. Otherwise the fit is loaded without evaluation and
        ``rec`` is used as-is.
    ax : matplotlib axes or None
        Target axes; a new figure/axes is created when None.
    tag, colors : forwarded to ``sparseness``.

    Returns
    -------
    (S_r, S_p) as returned by ``sparseness``.
    """
    _, ctx = load_model_xform(cellid, batch, modelname, eval_model=rec is None)
    val = ctx['val'].apply_mask() if rec is None else rec

    val_ = ctx['modelspec'].evaluate(val)

    # Both traces are scaled by the sampling rate fs.
    fs = val['resp'].fs
    this_resp = val['resp'].extract_channels(chans=[cellid])._data[0, :] * fs
    # NOTE(review): the response is taken from this cell's channel, but the
    # prediction from channel 0 of 'pred'. That assumes the model's first
    # output is this cell; confirm for population models.
    this_pred = val_['pred']._data[0, :] * fs

    if ax is None:
        _, ax = plt.subplots()
    # ``c`` (prediction correlation) comes from sparseness(); the former
    # separate np.corrcoef result was overwritten here and never used.
    S_r, S_p, c = sparseness(this_resp,
                             this_pred,
                             cellid=cellid,
                             verbose=True,
                             ax=ax,
                             colors=colors,
                             tag=tag)
    print(
        f"{cellid} r_test={c:.3f} orig r_test={ctx['modelspec'].meta['r_test'][0,0]:.3f} S_r={S_r:.3f} S_p={S_p:.3f}"
    )

    return S_r, S_p
Пример #15
0
def load_models(cell, batch, models, check_db=True, site=None, site_cache=None):
    '''Load standardized psth from each model, error if not all fits exist.

    Parameters
    ----------
    cell : str
        Cellid whose prediction channel is extracted from each model.
    batch : batch id passed to ``nd.batch_comp`` and ``load_model_xform``.
    models : list of str
        Model names; all are required.
    check_db : bool
        If True, query the results database first and raise ValueError if
        any model has no stored result for this cell. Set False if you know
        results are not stored in the DB but exist in file storage.
    site, site_cache :
        Optional cache, where ``site_cache[site][model]`` holds an
        already-loaded ctx. The per-site dict is created on first use, and
        newly loaded contexts are stored in it.

    Returns
    -------
    list of arrays
        One array per model: the validation prediction for ``cell`` at the
        time points where every model's prediction is finite.

    Raises
    ------
    ValueError
        If ``check_db`` finds a missing result.
    '''
    if check_db:
        df = nd.batch_comp(batch=batch, modelnames=models, cellids=[cell])
        if np.sum(df.isna().values) > 0:
            # at least one model wasn't fit (or at least not stored in db)
            # so skip trying to load any of them since all are required.
            # %s (not %d) so a string batch id doesn't turn this into a TypeError.
            raise ValueError('Not all results exist for: %s, %s' % (cell, batch))

    # Per-site cache dict; created on demand so an unseen site isn't a KeyError.
    site_models = None if site_cache is None else site_cache.setdefault(site, {})

    # Load all models
    ctxs = []
    for model in models:
        if site_models is not None and model in site_models:
            log.info("Site %s is cached, skipping load...", site)
            ctx = site_models[model]
        else:
            _, ctx = xhelp.load_model_xform(cell, batch, model)
            if site_models is not None:
                site_models[model] = ctx
        ctxs.append(ctx)

    for ctx in ctxs:
        if ctx['val']['pred'].chans is None:
            ctx['val']['pred'].chans = copy(ctx['val']['resp'].chans)

    # Pull out model predictions and remove times with nan for at least 1 model
    preds = [ctx['val'].apply_mask()['pred'].extract_channels([cell]).as_continuous() for ctx in ctxs]
    ff = np.isfinite(preds[0])
    for pred in preds[1:]:
        ff &= np.isfinite(pred)
    no_nans = [pred[ff] for pred in preds]

    return no_nans
Пример #16
0
def cf_batch_rank1(batch,
                   modelname,
                   save_path=None,
                   load_path=None,
                   f_low=0.2,
                   f_high=20,
                   nf=18,
                   test_limit=None):
    """Estimate a characteristic frequency (CF) per cell from a rank-1 fit.

    The CF bin is the rounded 'mean' of module 1's phi, scaled by ``nf`` and
    clipped to ``[0, nf - 1]``. It is mapped onto ``nf`` log-spaced
    frequencies between ``f_low`` and ``f_high`` (kHz). Cells whose fit
    cannot be loaded are skipped.

    Parameters
    ----------
    save_path, load_path : optional pickle paths to cache / reload the result.
    test_limit : int or None
        Only process the first ``test_limit`` cells.

    Returns
    -------
    pandas.DataFrame
        Indexed by cellid, with columns 'cf' and 'cf_bin'.
    """
    if load_path is not None:
        return pd.read_pickle(load_path)

    cells = nd.get_batch_cells(batch, as_list=True)
    # Frequency axis is identical for every cell, so build it once.
    khz_freqs = np.logspace(np.log(f_low),
                            np.log(f_high),
                            num=nf,
                            base=np.e)
    fitted_cells = []
    cfs = []
    cf_bins = []
    for cellid in cells[:test_limit]:
        try:
            _, ctx = xhelp.load_model_xform(cellid,
                                            batch,
                                            modelname,
                                            eval_model=False)
        except Exception:
            # cell probably not fit for this model (no longer a bare except,
            # so KeyboardInterrupt/SystemExit still propagate)
            continue

        modelspec = ctx['modelspec']
        # mult by nf b/c x vals in gaussian coeffs module are divided by
        # number of channels
        # max and min bounds b/c means outside of bin range are allowed
        # np.asscalar was removed in NumPy 1.23; .item() is its replacement.
        raw_mean = np.asarray(modelspec.phi[1]['mean']).item()
        mean = min(max(0, raw_mean * nf), nf - 1)
        cf_bin = int(round(mean))
        fitted_cells.append(cellid)
        cfs.append(khz_freqs[cf_bin])
        cf_bins.append(cf_bin)

    results = {'cellid': fitted_cells, 'cf': cfs, 'cf_bin': cf_bins}
    df = pd.DataFrame.from_dict(results)
    df.set_index('cellid', inplace=True)
    if save_path is not None:
        df.to_pickle(save_path)

    return df
Пример #17
0
def dynamic_sigmoid_distribution(cellid,
                                 batch,
                                 modelname,
                                 sample_every=10,
                                 alpha=0.1):
    """Overlay dynamic-sigmoid output curves for sampled contrast values.

    The last module's parameters are split into base values (keys without
    '_mod') and '_mod' values (suffix stripped); a parameter with no '_mod'
    key uses its base value for both. For every ``sample_every``-th time bin,
    each parameter is interpolated as ``t + (t_mod - t) * ctpred[bin]``, and
    ``_double_exponential`` applied to the full pre-sigmoid prediction is
    scattered in black. Curves at the nan-max (red) and nan-min (blue) of
    ctpred are drawn on top.

    Returns
    -------
    The matplotlib figure.
    """
    xfspec, ctx = xhelp.load_model_xform(cellid, batch, modelname)
    modelspec = ctx['modelspec']
    val = ctx['val'].apply_mask()
    modelspec.recording = val
    val_before_dsig = ms.evaluate(val, modelspec, stop=-1)
    pred_before_dsig = val_before_dsig['pred'].as_continuous().T
    ctpred = val_before_dsig['ctpred'].as_continuous().T

    lows = {k: v for k, v in modelspec[-1]['phi'].items() if '_mod' not in k}
    highs = {k[:-4]: v for k, v in modelspec[-1]['phi'].items() if '_mod' in k}
    for k in lows:
        if k not in highs:
            highs[k] = lows[k]
    # re-sort keys to make sure they're in the same order
    lows = {k: lows[k] for k in sorted(lows)}
    highs = {k: highs[k] for k in sorted(highs)}
    thetas = list(lows.values())
    theta_mods = list(highs.values())

    fig = plt.figure()
    for i in range(int(len(pred_before_dsig) / sample_every)):
        try:
            ts = {}
            for t, t_m, k in zip(thetas, theta_mods, list(lows.keys())):
                ts[k] = t + (t_m - t) * ctpred[i * sample_every]
            y = _double_exponential(pred_before_dsig, **ts)
            plt.scatter(pred_before_dsig,
                        y,
                        color='black',
                        alpha=alpha,
                        s=0.01)
        except IndexError:
            # Defensive only: the range above keeps i * sample_every below
            # len(pred_before_dsig), so indexing ctpred itself cannot overflow.
            pass
    t_max = {}
    t_min = {}
    for t, t_m, k in zip(thetas, theta_mods, list(lows.keys())):
        t_max[k] = t + (t_m - t) * np.nanmax(ctpred)
        t_min[k] = t + (t_m - t) * np.nanmin(ctpred)
    max_out = _double_exponential(pred_before_dsig, **t_max)
    min_out = _double_exponential(pred_before_dsig, **t_min)
    plt.scatter(pred_before_dsig, max_out, color='red', s=0.1)
    plt.scatter(pred_before_dsig, min_out, color='blue', s=0.1)

    return fig
Пример #18
0
 def on_selection_changed(self, event=None):
     """Show the saved figure for the currently selected cell and model.

     Reads the selected cell and model from the list widgets and loads the
     fit without evaluation. The image window is then pointed at the
     modelspec's ``meta['figurefile']``. Errors are caught and reported to
     stdout rather than raised.
     """
     print('on_selection_changed')
     try:
         cellid = self.cells.currentItem().text()
         modelname = self.models.currentItem().text()
         batch = self.batch
         print('Selected cell(s): ' + cellid)
         print('Selected model(s): ' + modelname)
         print('Selected batch: ' + str(batch))
         _, ctx = xhelp.load_model_xform(cellid, batch, modelname, eval_model=False)
         figurefile = ctx['modelspec'].meta['figurefile']
         print('Figure file: ' + figurefile)
         self.im_window.update_imagepath(imagepath=figurefile)
     except Exception:
         # Narrowed from a bare except so KeyboardInterrupt/SystemExit still
         # propagate; print the traceback so the failure is diagnosable.
         import traceback
         print('error?')
         traceback.print_exc()
Пример #19
0
def load_existing_pred(cellid=None,
                       siteid=None,
                       batch=None,
                       modelname_existing=None,
                       **kwargs):
    """
    Load a previously fit model and expose its prediction as signal 'pred0'.

    Designed to be called by the xforms keyword ``loadpred``.

    cellid/siteid - one or the other required. If only siteid is given, the
        first cellid in ``batch`` whose name starts with siteid is used; a
        list of cellids is reduced to its first entry.
    batch - required
    modelname_existing - default
        "psth.fs4.pup-ld-st.pup-hrc-psthfr-aev_sdexp2.SxR_newtf.n.lr1e4.cont.et5.i50000"

    Signals present in ctx['val'] but missing from ctx['rec'] are copied into
    ctx['rec']. Then a copy of 'pred' named 'pred0' is added to ctx['rec'].

    Returns a ctx-compatible dict {'rec': nems.Recording, 'input_name': 'pred0'}.

    Raises ValueError if batch is missing, if neither cellid nor siteid is
    given, or if no cell in batch matches siteid.
    """
    if batch is None:
        raise ValueError("must specify cellid/siteid and batch")

    if cellid is None:
        if siteid is None:
            raise ValueError("must specify cellid/siteid and batch")
        d = nd.pd_query(
            "SELECT batch,cellid FROM Batches WHERE batch=%s AND cellid like %s",
            (
                batch,
                siteid + "%",
            ))
        matches = d['cellid'].values
        if len(matches) == 0:
            # Give a clear error instead of an opaque IndexError on [0].
            raise ValueError(
                "no cellid matching siteid %s found in batch %s" % (siteid, batch))
        cellid = matches[0]
    elif isinstance(cellid, list):
        cellid = cellid[0]

    if modelname_existing is None:
        modelname_existing = "psth.fs4.pup-ld-st.pup-hrc-psthfr-aev_sdexp2.SxR_newtf.n.lr1e4.cont.et5.i50000"

    xf, ctx = xhelp.load_model_xform(cellid, batch, modelname_existing)
    for k in ctx['val'].signals.keys():
        if k not in ctx['rec'].signals.keys():
            ctx['rec'].signals[k] = ctx['val'].signals[k].copy()
    s = ctx['rec']['pred'].copy()
    s.name = 'pred0'
    ctx['rec'].add_signal(s)

    return {'rec': ctx['rec'], 'input_name': 'pred0'}
Пример #20
0
def dynamic_sigmoid_pred_matched(cellid, batch, modelname, include_phi=True):
    """Scatter the dynamic sigmoid's input against its output, colored by ctpred.

    x axis: prediction entering the last module. y axis: final validation
    prediction. Colour: the 'ctpred' signal, with NaNs set to 0.

    Parameters
    ----------
    include_phi : bool
        If True, print the last module's parameters beside the axes. Also
        print a weight ``phi[name + '_mod'] - phi[name]`` for each base
        parameter (0 when no '_mod' partner exists).

    Returns
    -------
    The matplotlib figure.
    """
    xfspec, ctx = xhelp.load_model_xform(cellid, batch, modelname)
    modelspec = ctx['modelspec']
    modelspec.recording = ctx['val']
    val = ctx['val'].apply_mask()
    ctpred = val['ctpred'].as_continuous().flatten()
    # HACK
    # this really shouldn't happen.. but for some reason some  of the
    # batch 263 cells are getting nans, so temporary fix.
    ctpred[np.isnan(ctpred)] = 0
    pred_after_dsig = val['pred'].as_continuous().flatten()
    val_before_dsig = ms.evaluate(val, modelspec, stop=-1)
    pred_before_dsig = val_before_dsig['pred'].as_continuous().flatten()

    fig = plt.figure(figsize=(12, 7))
    plasma = plt.get_cmap('plasma')
    plt.scatter(pred_before_dsig,
                pred_after_dsig,
                c=ctpred,
                s=2,
                alpha=0.75,
                cmap=plasma)
    plt.title(cellid)
    plt.xlabel('pred in')
    plt.ylabel('pred out')

    if include_phi:
        dsig_phi = modelspec.phi[-1]
        phi_string = '\n'.join(
            ['%s:  %.4E' % (k, v) for k, v in dsig_phi.items()])
        # Pair each base parameter with '<name>_mod' by name rather than by
        # dict position, which silently mis-paired when keys weren't strictly
        # alternating. This matches the '_mod' convention of the other
        # dynamic-sigmoid helpers.
        thetas = [k for k in dsig_phi if '_mod' not in k]
        weights = {
            k: (dsig_phi.get(k + '_mod', dsig_phi[k]) - dsig_phi[k])
            for k in thetas
        }
        weights_string = 'weights:\n' + '\n'.join(
            ['%s:  %.4E' % (k, v) for k, v in weights.items()])
        fig.text(0.775, 0.9, phi_string, va='top', ha='left')
        fig.text(0.775, 0.1, weights_string, va='bottom', ha='left')
        plt.subplots_adjust(right=0.775, left=0.075)

    plt.colorbar()
    return fig
Пример #21
0
def random_condition_convergence(cellid,
                                 batch,
                                 modelname,
                                 separate_figures=True):
    """Plot each parameter's initial -> final value across random initial conditions.

    Reads ``meta['random_conditions']``, a sequence of (initial, final) phi
    dict pairs, and ``meta['best_random_idx']`` from the fitted modelspec.

    Parameters
    ----------
    separate_figures : bool
        If False, draw every parameter in one figure (one random colour per
        parameter). If True, draw one figure per parameter. In that case
        condition 0 is drawn in blue (labelled 'mean'), the best condition in
        red (unless it is 0), and the rest in black.
    """
    xfspec, ctx = xhelp.load_model_xform(cellid, batch, modelname)
    meta = ctx['modelspec'].meta
    rcs = meta['random_conditions']
    best_idx = meta['best_random_idx']
    keys = list(rcs[0][0].keys())

    if not separate_figures:
        plt.figure()
        colors = [np.random.rand(3, ) for k in keys]
        for initial, final in rcs:
            starts = list(initial.values())
            ends = list(final.values())
            for i, k in enumerate(keys):
                # np.asscalar was removed in NumPy 1.23; .item() replaces it.
                plt.plot([0, 1],
                         [np.asarray(starts[i]).item(),
                          np.asarray(ends[i]).item()],
                         c=colors[i])
        plt.legend(keys)

    else:
        for k in keys:
            plt.figure()
            for i, (initial, final) in enumerate(rcs):
                start = initial[k]
                end = final[k]
                if i == 0:
                    color = 'blue'
                    label = 'mean'
                elif i == best_idx:
                    color = 'red'
                    label = 'best'
                else:
                    color = 'black'
                    label = None
                plt.plot([0, 1],
                         np.concatenate((start, end)),
                         color=color,
                         label=label)
            plt.legend()
            plt.title("%s, best_idx: %d" % (k, best_idx))
Пример #22
0
def compare_sims(start=0, end=None):
    """Overlay toy GC, STP and LN simulated PSTHs on the default cell's stimulus.

    The spectrogram of the default fit's validation 'stim' is drawn as a band
    between y = 2.1 and 3.4. The LN, GC and STP simulated predictions are
    plotted below it.

    Parameters
    ----------
    start, end : x-axis limits in time bins; ``end`` defaults to the
        stimulus length.

    Returns
    -------
    The matplotlib figure.
    """
    # TODO: set up to compare on synthetic stimuli
    _xfspec, loaded = xhelp.load_model_xform(_DEFAULT_CELL, _DEFAULT_BATCH,
                                             _DEFAULT_MODEL)
    val = loaded['val']
    gc_sim = build_toy_gc_cell(0, 0, 0, -0.5)  # base, amp, shift, kappa
    gc_sim[-2]['fn_kwargs']['compute_contrast'] = True
    stp_sim = build_toy_stp_cell([0, 0.1], [0.08, 0.08])  # u, tau
    LN_sim = build_toy_LN_cell()

    stim = val['stim'].as_continuous()

    def _simulate(sim):
        # evaluate on val, keep the result on the sim, return the flat pred
        sim_val = sim.evaluate(val)
        sim.recording = sim_val
        return sim_val['pred'].as_continuous().flatten()

    gc_psth = _simulate(gc_sim)
    stp_psth = _simulate(stp_sim)
    LN_psth = _simulate(LN_sim)

    fig = plt.figure(figsize=wide_fig)
    n_bins = stim.shape[-1]
    if end is None:
        end = n_bins
    plt.imshow(stim, aspect='auto', cmap=spectrogram_cmap,
               origin='lower', extent=(0, n_bins, 2.1, 3.4))
    lw = 0.75
    thick = lw * 1.25
    plt.plot(LN_psth, color=model_colors['LN'], linewidth=lw)
    plt.plot(gc_psth, color=model_colors['gc'], linewidth=thick)
    plt.plot(stp_psth, color=model_colors['stp'], alpha=0.75,
             linewidth=thick)
    plt.ylim(-0.1, 3.4)
    plt.xlim(start, end)
    ax_remove_box(plt.gca())

    return fig
Пример #23
0
def pop_model_example(figsize=None):
    """Plot layer outputs of a population model for a 50-channel subset.

    Loads the fit for cell "DRX006b-128-2" in batch 322 (model
    ``ALL_FAMILY_POP[2]``). Output channels 50-99 are kept in a copy of the
    modelspec (phi of modules 8, 9 and 10 plus meta) and in the masked
    validation 'resp'/'pred' signals. The result is drawn with
    ``pop_models.plot_layer_outputs``, using the high-resolution stimulus as
    ``altstim``.

    Returns
    -------
    The figure returned by ``plot_layer_outputs``.
    """
    tctx = load_high_res_stim()

    batch, cellid = 322, "DRX006b-128-2"
    modelname = ALL_FAMILY_POP[2]
    _xf, ctx = load_model_xform(cellid, batch, modelname)

    modelspec = ctx['modelspec'].copy()
    val = ctx['val'].apply_mask()

    # extract a subset of channels, since 9xx is too many
    n_keep = 50
    keep = slice(n_keep, n_keep * 2, 1)
    modelspec.phi[8]['coefficients'] = modelspec.phi[8]['coefficients'][keep, :]
    modelspec.phi[9]['level'] = modelspec.phi[9]['level'][keep]
    for pname in ('base', 'amplitude', 'shift', 'kappa'):
        modelspec.phi[10][pname] = modelspec.phi[10][pname][keep]
    for sig_name in ('resp', 'pred'):
        val[sig_name] = val[sig_name]._modified_copy(data=val[sig_name][keep, :])
    # cellid must be read before 'cellids' is sliced below
    modelspec.meta['cellid'] = modelspec.meta['cellids'][n_keep]
    modelspec.meta['cellids'] = modelspec.meta['cellids'][keep]
    modelspec.meta['r_ceiling'] = modelspec.meta['r_test'][keep] * 1.1
    print(val['stim'].shape, val['resp'].shape, tctx['val']['stim'].shape)

    return pop_models.plot_layer_outputs(modelspec,
                                         val,
                                         index_range=np.arange(150, 600),
                                         example_idx=15,
                                         figsize=figsize,
                                         altstim=tctx['val']['stim'])
Пример #24
0
def mean_prior_used(batch, modelname, start=400, stop=500, report_every=25):
    """Print the fraction of fits whose chosen random initial condition was 0.

    For the cells ``start:stop`` of the batch's cell list, collect each fit's
    ``meta.get('best_random_idx', 0)``, skipping cells whose load raises
    ValueError (no result). Then print the proportion equal to 0 (the
    "mean prior").

    Parameters
    ----------
    batch, modelname : identify the fits.
    start, stop : int
        Slice of the batch's cell list to examine (default 400:500).
    report_every : int
        Print a progress line after every ``report_every`` cells.
    """
    choices = []
    cells = nd.get_batch_cells(batch, as_list=True)
    subset = cells[start:stop]
    for i, c in enumerate(subset):
        # Progress every `report_every` cells. The test must be
        # (i + 1) % n, not n % (i + 1), which only fired when i + 1 divided n.
        if (i + 1) % report_every == 0:
            print('cell %d/%d\n' % (i + 1, len(subset)))
        try:
            _, ctx = xhelp.load_model_xform(c,
                                            batch,
                                            modelname,
                                            eval_model=False)
        except ValueError:
            # no result
            continue
        choices.append(ctx['modelspec'].meta.get('best_random_idx', 0))

    if choices:
        choices = np.array(choices).flatten()
        mean_count = np.sum(choices == 0)
        proportion = mean_count / len(choices)
        print('proportion mean prior used: %.4f' % proportion)
    else:
        print('no results found')
Пример #25
0
import nems.gui.editors as gui

log = logging.getLogger(__name__)
# NAT A1 SINGLE NEURON + PUPIL
# NOTE: these three values are overwritten just below (batch/cellid by the
# 308 assignment, modelname by the next line); kept as an alternate example.
batch = 289
cellid = 'TAR009d-42-1'
modelname = "ozgf.fs100.ch18-ld-sev_dlog.f-wc.18x3.g-stp.3-fir.3x15-lvl.1-dexp.1_init-basic"

# Active selection: batch 308 example cell and the model actually loaded.
batch, cellid = 308, 'AMT018a-09-1'
modelname = 'ozgf.fs100.ch18-ld-sev_dlog-wc.18x4.g-fir.2x15x2-relu.2-wc.2x1-lvl.1-dexp.1_tf.n.rb10'
# Optional second model, opened in a second browser that shares the first
# browser's controls; the following None assignment currently disables it.
modelname2 = 'ozgf.fs100.ch18-ld-sev_dlog-wc.18x4.g-fir.2x15x2-relu.2-wc.2x1-lvl.1-dexp.1_tf.n.rb5'
modelname2 = None

# When True, the fit is opened in the interactive model browser below.
GUI = True

# Load the saved fit: xforms spec plus context for the selected cell/model.
xfspec, ctx = xhelp.load_model_xform(cellid, batch, modelname)

if GUI:
    # interactive model browser (matplotlib embedded Qt)
    ex = gui.browse_xform_fit(ctx, xfspec)

    if modelname2 is not None:
        xfspec2, ctx2 = xhelp.load_model_xform(cellid, batch, modelname2)
        ex2 = gui.browse_xform_fit(ctx2,
                                   xfspec2,
                                   control_widget=ex.editor.global_controls)

    #aw = browse_context(ctx, rec='val', signals=['stim', 'pred', 'resp'])
    #aw = browse_context(ctx, signals=['state', 'psth', 'pred', 'resp'])

else:
Пример #26
0
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 22 17:59:33 2018

@author: daniela
"""
#import nems_lbhb.xform_wrappers as nw
import nems.xform_helper as xhelp

# Example fit to load: batch 301, one cell, one state-gain model.
batch = 301
cellid = "BRT036b-06-1"
modelname = "psth.fs20.pup-ld-st.pup.beh-evs.tar.lic0_fir.Nx40-lvl.1-stategain.3_jk.nf10-init.st-basic"

# Load the saved fit: xforms spec and context.
xf, ctx = xhelp.load_model_xform(cellid, batch, modelname)
Пример #27
0
def sparseness_figs():
    """Plot predicted vs. actual response sparseness for the A1 and PEG batches.

    Figure 1 has four panels:
      0/1. ``sparseness_example`` for one A1 cell (batch 322): 1D CNNx2, pop LN;
      2.   per-cell predicted sparseness (``S_p``) in A1, pop LN (x) vs 1D CNN (y);
      3.   strip + box plots of sparseness per area for the actual response
           ("act", from column ``S_r``) and each model's prediction (``S_p``).
    Figure 2 shows mean +/- SEM sparseness against ``r_test`` rounded to one
    decimal, one panel per area.

    Wilcoxon tests between label pairs and per-label medians are printed.

    Returns
    -------
    f1, f2
        The two matplotlib figures.
    tests : list
        ``[ref_label, test_label, scipy.stats.wilcoxon result]`` per comparison.
    sd : pandas.DataFrame
        Long-format data with columns area, cellid, model, S, r_test, label,
        r_approx.
    """
    label_order = [
        'A1 act', 'A1 1D CNN', 'A1 pop LN', 'PEG act', 'PEG 1D CNN',
        'PEG pop LN'
    ]
    # One colour per model, looked up by name so every plot uses the same mapping.
    model_colors = {
        'act': 'gray',
        '1D CNN': DOT_COLORS['1D CNNx2'],
        'pop LN': DOT_COLORS['pop LN'],
    }

    # --- Panels 0/1: sparseness examples for one cell -----------------------
    batch = 322
    cellid = 'ARM030a-40-2'
    modelname = ALL_FAMILY_MODELS[2]
    modelname_ln = ALL_FAMILY_MODELS[3]
    _, ctx = load_model_xform(cellid, batch, modelname, eval_model=True)
    val = ctx['val'].apply_mask()

    f1, ax = plt.subplots(1, 4, figsize=double_column_shorter)
    sparseness_example(batch,
                       cellid,
                       modelname,
                       rec=val,
                       ax=ax[0],
                       tag="1D CNNx2",
                       colors=['lightgray', DOT_COLORS['1D CNNx2']])
    sparseness_example(batch,
                       cellid,
                       modelname_ln,
                       rec=val,
                       ax=ax[1],
                       tag="pop LN",
                       colors=['lightgray', DOT_COLORS['pop LN']])
    ax[1].set_yticks([])
    ax[0].set_box_aspect(1)
    ax[1].set_box_aspect(1)

    # --- Per-cell sparseness for both areas (cached CSV per batch) ----------
    modelnames = [
        ALL_FAMILY_MODELS[0], ALL_FAMILY_MODELS[2], ALL_FAMILY_MODELS[3]
    ]
    pop_reference_model = ALL_FAMILY_POP[2]
    sparseness_data_a1 = sparseness_by_batch(
        a1,
        modelnames=modelnames,
        pop_reference_model=pop_reference_model,
        save_path=int_path / str(a1) / 'sparseness_data.csv',
        force_regenerate=False,
        rec=None)
    sparseness_data_peg = sparseness_by_batch(
        peg,
        modelnames=modelnames,
        pop_reference_model=pop_reference_model,
        save_path=int_path / str(peg) / 'sparseness_data.csv',
        force_regenerate=False,
        rec=None)

    sparseness_data_a1['area'] = 'A1'
    sparseness_data_peg['area'] = 'PEG'
    sparseness_data = pd.concat([sparseness_data_a1, sparseness_data_peg],
                                ignore_index=True)
    # TODO: (maybe) filter out cellids that aren't in the significant-cell lists
    # (get_significant_cells(322 / 323, SIG_TEST_MODELS, as_list=True)), in
    # addition to the r>0.2 check in sparseness_by_batch.

    # --- Panel 2: per-cell predicted sparseness, model 1 vs model 2 (A1) ----
    d = sparseness_data_a1.loc[sparseness_data_a1.model == 1].merge(
        sparseness_data_a1.loc[sparseness_data_a1.model == 2],
        how='inner',
        on='cellid',
        suffixes=('_dnn', '_ln'))
    ax[2].plot([0, 1], [0, 1], 'k--')
    d.plot.scatter('S_p_ln', 'S_p_dnn', s=1, c='black', ax=ax[2])
    ax[2].set_aspect('equal')
    ax[2].set_xlabel('pop LN')
    ax[2].set_ylabel('1D CNNx2')

    # --- Long-format table: one row per (cell, model or "act") --------------
    # Actual-response sparseness (S_r) is relabelled "act" and exact duplicate
    # rows are dropped.
    # NOTE(review): r_test is part of the duplicate key; if r_test differs
    # between a cell's model rows, "act" keeps one row per model rather than
    # one per cell. Confirm this is intended.
    sd_r = sparseness_data[['area', 'cellid', 'model', 'S_r',
                            'r_test_all']].copy()
    sd_r['model'] = "act"
    sd_r.columns = ['area', 'cellid', 'model', 'S', 'r_test']
    sd_r = sd_r.drop_duplicates()
    sd_p = sparseness_data[['area', 'cellid', 'model', 'S_p',
                            'r_test_all']].copy()
    sd_p = sd_p.loc[sd_p['model'] > 0].copy()
    sd_p['model'] = sd_p['model'].astype(str)
    sd_p.columns = ['area', 'cellid', 'model', 'S', 'r_test']
    sd = pd.concat([sd_p, sd_r], ignore_index=True)
    sd.loc[sd['model'] == '1', 'model'] = '1D CNN'
    sd.loc[sd['model'] == '2', 'model'] = 'pop LN'
    sd['label'] = sd['area'] + " " + sd['model']

    # --- Panel 3: strip + box plot per label --------------------------------
    r_test_min = 0.0
    print(f"r_test_min={r_test_min}")
    sd_thr = sd.loc[sd['r_test'] > r_test_min]
    sns.stripplot(
        x='label',
        y='S',
        hue='label',
        data=sd_thr,
        zorder=0,
        palette=[model_colors[lbl.split(' ', 1)[1]] for lbl in label_order],
        hue_order=label_order,
        order=label_order,
        jitter=0.2,
        size=2,
        ax=ax[3])
    sns.boxplot(x='label',
                y='S',
                data=sd_thr,
                boxprops={
                    'facecolor': 'None',
                    'linewidth': 1
                },
                showcaps=False,
                showfliers=False,
                whiskerprops={'linewidth': 0},
                order=label_order,
                ax=ax[3])
    # Style ax[3]'s tick labels explicitly instead of relying on pyplot's
    # implicit "current axes".
    plt.setp(ax[3].get_xticklabels(), rotation=45, fontsize=6, ha='right')
    if ax[3].get_legend() is not None:
        ax[3].get_legend().remove()
    ax[3].set_xlabel('')
    ax[3].set_box_aspect(1)

    # --- Statistics ---------------------------------------------------------
    ref_models = ['A1 act', 'A1 1D CNN', 'PEG act', 'PEG 1D CNN']
    test_models = ['A1 1D CNN', 'A1 pop LN', 'PEG 1D CNN', 'PEG pop LN']
    # NOTE(review): wilcoxon is a paired test. Here samples are paired by row
    # position within each label, which assumes both labels list the same
    # cells in the same order with equal counts. Confirm this, or merge on
    # cellid first.
    tests = [[
        m1, m2,
        st.wilcoxon(sd.loc[sd['label'] == m1, 'S'],
                    sd.loc[sd['label'] == m2, 'S'],
                    alternative='two-sided')
    ] for m1, m2 in zip(ref_models, test_models)]
    print(pd.DataFrame(tests, columns=['ref', 'test', 'Wilcoxon u,p']))
    # Aggregate only numeric columns: pandas >= 2.0 raises TypeError when
    # median/mean/sem reach the string columns (area, cellid, model).
    print(sd.groupby('label')[['S', 'r_test']].median())

    f1.tight_layout()

    # --- Figure 2: sparseness vs. (rounded) prediction accuracy -------------
    sd['r_approx'] = np.round(sd['r_test'], 1)
    grouped_s = sd.groupby(['label', 'r_approx'])['S']
    mean_by_r = grouped_s.mean().reset_index().pivot(index='r_approx',
                                                     columns='label',
                                                     values='S')
    sem_by_r = grouped_s.sem().reset_index().pivot(index='r_approx',
                                                   columns='label',
                                                   values='S')

    f2, ax2 = plt.subplots(1,
                           2,
                           figsize=column_and_half_short,
                           sharex=True,
                           sharey=True)
    for axis, area in zip(ax2, ['A1', 'PEG']):
        for model in ['act', 'pop LN', '1D CNN']:
            label = f"{area} {model}"
            # Colour is looked up by model name so it matches figure 1.
            axis.errorbar(mean_by_r.index,
                          mean_by_r[label].values,
                          sem_by_r[label].values,
                          color=model_colors[model],
                          label=label)
        axis.legend()
    ax2[0].set_xlabel('r_test')
    ax2[0].set_ylabel('Sparseness')
    return f1, f2, tests, sd
Пример #28
0
else:
    modelname_base = "psth.fs4.pup-ld-hrc-psthfr.z-pca.cc1.no.p-{0}-plgsm.p2-aev-rd"+\
                "_stategain.2xR.x1,3-spred-lvnorm.4xR.so.x2-inoise.4xR.x3"+\
                "_tfinit.xx0.n.lr1e4.cont.et4.i20000-lvnoise.r4-aev-ccnorm.md.t5.f0.ss3"
    states = [
        'st.pca0.pup+r1+s0,1', 'st.pca.pup+r1+s0,1', 'st.pca.pup+r1+s1',
        'st.pca.pup+r1'
    ]
    resp_modelname = f"psth.fs4.pup-ld-hrc-psthfr.z-pca.cc1.no.p-{states[-1]}-plgsm.p2-aev-rd.resp"+\
                "_stategain.2xR.x1,3-spred-lvnorm.4xR.so.x2-inoise.4xR.x3"+\
                "_tfinit.xx0.n.lr1e4.cont.et4.i20-lvnoise.r4-aev-ccnorm.md.t1.f0.ss3"

# Model names: resp_modelname, then modelname_base filled in with each state
# string (so the last entry corresponds to states[-1]).
modelnames = [resp_modelname] + [modelname_base.format(s) for s in states]

# Only the last model in the list is loaded here.
xf, ctx = load_model_xform(cellid=cellid,
                           batch=batch,
                           modelname=modelnames[-1])

val = ctx['val'].copy()
# This rasterized `resp` is used only to look up epoch names; `resp` is rebound
# to the extracted-epoch result further down.
resp = val['resp'].rasterize()
epoch_regex = "^STIM_"
epochs = ep.epoch_names_matching(resp.epochs, regex_str=epoch_regex)

# Extract the matching epochs from each signal, passing val['mask'] as the mask.
input_name = 'pred0'
pred0 = val[input_name].extract_epochs(epochs, mask=val['mask'])
pred = val['pred'].extract_epochs(epochs, mask=val['mask'])
resp = val['resp'].extract_epochs(epochs, mask=val['mask'])
pupil = val['pupil'].extract_epochs(epochs, mask=val['mask'])
# Median of the whole pupil signal, ignoring NaNs. Unlike the extractions above,
# no mask is applied here.
pmedian = np.nanmedian(val['pupil'].as_continuous())

# Rebind `epochs` to the keys of the extracted response (presumably the epochs
# that remain after masking — confirm).
epochs = list(resp.keys())
Пример #29
0
# Three stacked panels for one epoch: stimulus spectrogram, response time
# series, and an image of the extracted `resp` epoch. `val`, `rec` and `epoch`
# are defined by code not shown here.
plt.figure()
ax = plt.subplot(3, 1, 1)
nplt.spectrogram_from_epoch(val['stim'], epoch, ax=ax, time_offset=2)

ax = plt.subplot(3, 1, 2)
nplt.timeseries_from_epoch([val['resp']], epoch, ax=ax)

# NOTE(review): this panel reads from `rec`, whereas the panels above read from
# `val`; confirm that mixing the two recordings is intended.
raster = rec['resp'].extract_epoch(epoch)
ax = plt.subplot(3, 1, 3)
# `raster` is indexed as a 3-D array; presumably (repetitions, channels, time)
# — TODO confirm. Only channel index 0 is shown.
plt.imshow(raster[:, 0, :])

plt.tight_layout()

# see what a "traditional" NEMS model looks like
nems_modelname = "ozgf.fs100.ch18-ld-sev_dlog-wc.18x2.g-do.2x15-lvl.1-dexp.1_init-basic"
# Load this model's fit for the current cellid/batch, draw the quick summary
# plot, then open the interactive browser on it.
xfspec, ctx = load_model_xform(cellid, batch=batch, modelname=nems_modelname)
nplt.quickplot(ctx)

ex = gui.browse_xform_fit(ctx, xfspec)

##
# Fit (rather than load) a model for one batch-308 cell, then plot and browse it.
# Note: this rebinds the module-level batch / cellid / modelname.
batch, cellid = 308, 'AMT018a-09-1'
modelname = 'ozgf.fs100.ch18-ld-sev_dlog-wc.18x4.g-fir.2x15x2-relu.2-wc.2x1-lvl.1-dexp.1_init.tf.rb5-tf.n'
xfspec, ctx = fit_model_xform(cellid, batch=batch, modelname=modelname)
nplt.quickplot(ctx)
ex = gui.browse_xform_fit(ctx, xfspec)

### Plot complexity of model versus how effective it was
# Batch to summarise, the performance metric, and the model-complexity metric.
batch, metric, metric2 = 308, 'r_test', 'n_parms'
Пример #30
0
site = 'AMT020a'
batch = 331

# LOAD RAW DATA / MODEL PREDICTIONS
# Model names for the fits loaded below. They differ in their st.* state tokens
# and in the lvnorm / inoise options. Only these definitions are live; no
# earlier assignments of indep / rlv / plv are read before these.
reverything = 'psth.fs4.pup-ld-st.pup0.pvp0-epcpn-mvm.25.2-hrc-psthfr-plgsm.e10.sp-lvnoise.r8-aev_sdexp2.SxR-lvnorm.SxR.d.so-inoise.2xR_ccnorm.r.t5.ss1'
indep = 'psth.fs4.pup-ld-st.pup0.pvp-epcpn-mvm.25.2-hrc-psthfr-plgsm.e10.sp-lvnoise.r8-aev_sdexp2.SxR-lvnorm.2xR.d.so-inoise.3xR_ccnorm.r.t5.ss1'
rlv = 'psth.fs4.pup-ld-st.pup0.pvp-epcpn-mvm.25.2-hrc-psthfr-plgsm.e10.sp-lvnoise.r8-aev_sdexp2.SxR-lvnorm.2xR.d.so-inoise.2xR_ccnorm.r.t5.ss1'
plv = 'psth.fs4.pup-ld-st.pup.pvp0-epcpn-mvm.25.2-hrc-psthfr-plgsm.e10.sp-lvnoise.r8-aev_sdexp2.SxR-lvnorm.SxR.d.so-inoise.2xR_ccnorm.r.t5.ss1'

def _load_lv_models(cellid_):
    """Load the indep, rlv and plv fits for ``cellid_`` in ``batch``.

    Returns six values in order: xf_indep, ctx_indep, xf_rlv, ctx_rlv,
    xf_plv, ctx_plv (the (xfspec, ctx) pair from ``load_model_xform`` for each
    model name in turn).
    """
    loaded = []
    for mname in (indep, rlv, plv):
        xf_, ctx_ = load_model_xform(modelname=mname, batch=batch, cellid=cellid_)
        loaded.extend((xf_, ctx_))
    return loaded


try:
    # First try the site name itself as the cellid.
    cellid = site
    xf_indep, ctx_indep, xf_rlv, ctx_rlv, xf_plv, ctx_plv = _load_lv_models(cellid)
except Exception as err:  # not a bare except: Ctrl-C / SystemExit still propagate
    # Fall back to the first cellid in the batch whose name contains the site.
    matches = [c for c in nd.get_batch_cells(batch).cellid if site in c]
    if not matches:
        raise ValueError(
            f"could not load models for {site!r} and no cellid in batch "
            f"{batch} contains it") from err
    cellid = matches[0]
    print(f"load with cellid={site!r} failed ({err!r}); using {cellid!r}")
    xf_indep, ctx_indep, xf_rlv, ctx_rlv, xf_plv, ctx_plv = _load_lv_models(cellid)

# GET COV MATRICES
# For the first ten stimuli, get the pair of matrices returned by
# fhelp.get_cov_matrices (named *bg / *sm — presumably big / small pupil; TODO
# confirm) for each model's prediction and for the actual response.
stim = np.arange(10)
ibg, ism = fhelp.get_cov_matrices(ctx_indep['val'].copy(), sig='pred', sub='psth_sp', stims=stim, ss=0)
rbg, rsm = fhelp.get_cov_matrices(ctx_rlv['val'].copy(), sig='pred', sub='psth_sp', stims=stim, ss=0)
pbg, psm = fhelp.get_cov_matrices(ctx_plv['val'].copy(), sig='pred', sub='psth_sp', stims=stim, ss=0)
bg, sm = fhelp.get_cov_matrices(ctx_plv['val'].copy(), sig='resp', sub='psth_sp', stims=stim, ss=0)
# Largest magnitude over all matrices (presumably used as a symmetric colour
# limit, +/-mm — confirm). Take abs before max: abs(max(x)) misses a negative
# entry whose magnitude exceeds the largest positive entry.
mm = np.max(np.abs(np.concatenate((ibg, ism, rbg, rsm, pbg, psm, bg, sm))))