コード例 #1
0
def load_snr_and_cellids(a1, peg, a1_snr_path, peg_snr_path):
    """Load per-cell SNR for the significant cells of an A1 and a PEG batch.

    For each batch, cells are restricted to those returned by
    ``get_significant_cells(batch, SIG_TEST_MODELS)``, then to the row labels
    that ``nd.batch_comp`` returns for them. SNR values come from
    ``snr_by_batch`` with the 'ozgf.fs100.ch18' loader.

    Returns:
        (a1_snr, peg_snr_sample, a1_snr_df, peg_snr_df, a1_cellids,
        peg_cellids):
        - ``a1_snr`` is the flattened A1 SNR values in dataframe order.
        - ``a1_cellids`` is the A1 dataframe index.
        - ``peg_snr_sample`` is the flattened PEG SNR values sorted ascending.
        - ``peg_cellids`` holds the PEG cellids in that same sorted order.
    """
    def _significant_index(batch):
        # Row labels of batch_comp, restricted to this batch's significant cells.
        cells = get_significant_cells(batch, SIG_TEST_MODELS, as_list=True)
        comp = nd.batch_comp(batch=batch,
                             modelnames=SIG_TEST_MODELS,
                             stat=PLOT_STAT,
                             cellids=cells)
        return comp.index.values

    a1_index = _significant_index(a1)
    peg_index = _significant_index(peg)

    loader = 'ozgf.fs100.ch18'
    a1_snr_df = snr_by_batch(a1, loader, load_path=a1_snr_path).loc[a1_index]
    peg_snr_df = snr_by_batch(peg, loader,
                              load_path=peg_snr_path).loc[peg_index]

    a1_snr = a1_snr_df.values.flatten()
    a1_cellids = a1_snr_df.index

    # Ascending SNR order over the flattened values (axis=None).
    order = np.argsort(peg_snr_df.values, axis=None)
    peg_snr_sorted = peg_snr_df.values.flatten()[order]
    peg_cellids = peg_snr_df.index[order].values

    return (a1_snr, peg_snr_sorted, a1_snr_df, peg_snr_df, a1_cellids,
            peg_cellids)
コード例 #2
0
ファイル: svd_utils.py プロジェクト: LBHB/nems_db
def get_significant_cells(batch, models, as_list=False):
    """Find cells whose prediction is significant for every model in *models*.

    A cell is significant for a model when its ``r_test`` exceeds both twice
    its ``se_test`` and its ``r_floor``. Only cells with non-NaN values for
    all three stats (and all models) are considered.

    Parameters:
        batch: batch id passed to ``nd.batch_comp``.
        models: model names; a cell must pass the test for each of them.
        as_list: if True, return a list of significant cellids instead of a
            boolean Series.

    Returns:
        Boolean Series indexed by cellid (sorted), or a list of the cellids
        that are True when ``as_list`` is set.
    """
    tables = [nd.batch_comp(batch, models, stat=stat)
              for stat in ('r_test', 'se_test', 'r_floor')]
    tables = [df.dropna(axis=0, how='any') for df in tables]

    # Each table drops its own NaN rows, so their indices can differ. pandas
    # refuses to compare Series with different labels, so restrict all three
    # tables to the cells they share before comparing.
    common = tables[0].index
    for df in tables[1:]:
        common = common.intersection(df.index)
    common = common.sort_values()
    df_r, df_e, df_f = (df.loc[common] for df in tables)

    masks = [(df_r[m] > df_e[m] * 2) & (df_r[m] > df_f[m]) for m in models]
    all_significant = masks[0]
    for mask in masks[1:]:
        all_significant = all_significant & mask

    if as_list:
        return all_significant[all_significant].index.values.tolist()
    return all_significant
コード例 #3
0
def plot_heldout_a1_vs_peg(a1_snr_path, peg_snr_path, ax=None):
    """Compare two held-out model variants on SNR-matched PEG cells.

    PEG cells are taken from the A1/PEG SNR mapping built by
    ``get_matched_snr_mapping``. Their ``PLOT_STAT`` scores in batch 323 for
    ``HELDOUT_CROSSBATCH`` and ``HELDOUT[0]`` are drawn as a strip plot with
    an overlaid box plot.

    Parameters:
        a1_snr_path, peg_snr_path: SNR file paths forwarded to
            ``load_snr_and_cellids`` for batches 322 (A1) and 323 (PEG).
        ax: axes to draw on; a new figure is created when None.

    Returns:
        (peg_r, peg_r2, test_within_peg): the two renamed score DataFrames
        and the two-sided Wilcoxon signed-rank result comparing them.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=single_column_short)

    a1_snr, peg_snr, a1_snr_df, peg_snr_df, a1_cellids, peg_cellids = load_snr_and_cellids(
        322, 323, a1_snr_path, peg_snr_path)
    cell_map_df = get_matched_snr_mapping(a1_snr, peg_snr, a1_snr_df,
                                          peg_snr_df, a1_cellids, peg_cellids)
    peg_matched_cellids = cell_map_df.PEG_cellid

    peg_r = nd.batch_comp(323, [HELDOUT_CROSSBATCH],
                          stat=PLOT_STAT).loc[peg_matched_cellids]
    peg_r2 = nd.batch_comp(323, HELDOUT[:1],
                           stat=PLOT_STAT).loc[peg_matched_cellids]
    test_within_peg = st.wilcoxon(peg_r.values.flatten(),
                                  peg_r2.values.flatten(),
                                  alternative='two-sided')

    peg_r = peg_r.rename(columns={HELDOUT_CROSSBATCH: 'cross-batch held'})
    peg_r2 = peg_r2.rename(columns={HELDOUT[0]: 'PEG held'})
    combined_r = peg_r.join(peg_r2, how='outer')

    sns.stripplot(data=combined_r,
                  color=DOT_COLORS['1Dx2-CNN'],
                  size=2,
                  ax=ax,
                  zorder=0)
    sns.boxplot(data=combined_r,
                boxprops={
                    'facecolor': 'None',
                    'linewidth': 1
                },
                showcaps=False,
                showfliers=False,
                whiskerprops={'linewidth': 0},
                ax=ax)
    # Style the tick labels of *ax* itself. plt.xticks / plt.tight_layout act
    # on whichever axes/figure is current, which need not be the one passed in.
    plt.setp(ax.get_xticklabels(), rotation=45, fontsize=6, ha='right')
    ax.figure.tight_layout()

    return peg_r, peg_r2, test_within_peg
コード例 #4
0
def plot_pred_scatter(batch, modelnames, labels=None, colors=None, ax=None):
    """Scatter ``PLOT_STAT`` of two models against each other.

    Cells are those returned by ``get_significant_cells(batch,
    SIG_TEST_MODELS)``. A cell is drawn black when the two models'
    ``r_test +/- se_test`` intervals do not overlap, and light gray otherwise.

    Parameters:
        batch: batch id.
        modelnames: two model names; [0] goes on the x axis, [1] on the y axis.
        labels: axis-label text for the two models (default 'model 1'/'model 2').
        colors: axis-label colors for the two models (default black).
        ax: axes to draw on; a new figure is created when None.

    Returns:
        (fig, n_sig, n_nonsig): the figure, then the number of cells with
        non-overlapping and with overlapping intervals.
    """
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.figure
    if labels is None:
        labels = ['model 1', 'model 2']
    if colors is None:
        colors = ['black', 'black']

    significant_cells = get_significant_cells(batch,
                                              SIG_TEST_MODELS,
                                              as_list=True)

    sig_scores = nd.batch_comp(batch,
                               modelnames,
                               cellids=significant_cells,
                               stat='r_test')
    se_scores = nd.batch_comp(batch,
                              modelnames,
                              cellids=significant_cells,
                              stat='se_test')
    ceiling_scores = nd.batch_comp(batch,
                                   modelnames,
                                   cellids=significant_cells,
                                   stat=PLOT_STAT)

    m0, m1 = modelnames[0], modelnames[1]
    # The models differ for a cell when their r_test +/- se_test intervals
    # do not overlap.
    sig = ((sig_scores[m1] - se_scores[m1] > sig_scores[m0] + se_scores[m0]) |
           (sig_scores[m0] - se_scores[m0] > sig_scores[m1] + se_scores[m1]))
    group1 = (ceiling_scores.loc[~sig, m0].values,
              ceiling_scores.loc[~sig, m1].values)
    group2 = (ceiling_scores.loc[sig, m0].values,
              ceiling_scores.loc[sig, m1].values)
    n_nonsig = group1[0].size
    n_sig = group2[0].size

    # NOTE(review): the legend reads 'p < 0.5', but the split above is an
    # SE-interval overlap test, not a p-value threshold -- confirm the label.
    scatter_groups([group1, group2], ['lightgray', 'black'],
                   ax=ax,
                   labels=['N.S.', 'p < 0.5'])
    ax.set_xlabel(
        f'{labels[0]}\n(median r={ceiling_scores[m0].median():.3f})',
        color=colors[0])
    ax.set_ylabel(
        f'{labels[1]}\n(median r={ceiling_scores[m1].median():.3f})',
        color=colors[1])

    return fig, n_sig, n_nonsig
コード例 #5
0
def plot_conv_scatters(batch, colors=None):
    """Draw three model-comparison scatter panels for *batch*.

    Model indices refer to ALL_FAMILY_MODELS. The labels assume the order
    conv2d, conv1d, conv1dx2+d, LN_pop, dnn1_single, which other functions in
    this module also assume.
        ax[0]: [3] vs [4] (LN vs DNN-single), via plot_pred_scatter
        ax[1]: [3] vs [2] (LN vs conv1dx2+d), via plot_pred_scatter
        ax[2]: [0] vs [2] (conv2d vs conv1dx2+d). Cells are black when their
               r_test +/- se_test intervals do not overlap, light gray
               otherwise.

    Parameters:
        batch: batch id.
        colors: colors of the ax[2] x/y axis labels; defaults to black for both.
            (The previous code referenced an undefined ``colors`` name.)

    Returns:
        The matplotlib figure.
    """
    if colors is None:
        colors = ['black', 'black']

    significant_cells = get_significant_cells(batch,
                                              SIG_TEST_MODELS,
                                              as_list=True)

    sig_scores = nd.batch_comp(batch,
                               ALL_FAMILY_MODELS,
                               cellids=significant_cells,
                               stat='r_test')
    se_scores = nd.batch_comp(batch,
                              ALL_FAMILY_MODELS,
                              cellids=significant_cells,
                              stat='se_test')
    ceiling_scores = nd.batch_comp(batch,
                                   ALL_FAMILY_MODELS,
                                   cellids=significant_cells,
                                   stat=PLOT_STAT)

    fig, ax = plt.subplots(1, 3, figsize=(12, 4))

    # LN vs DNN-single
    plot_pred_scatter(batch, [ALL_FAMILY_MODELS[3], ALL_FAMILY_MODELS[4]],
                      labels=['pop-LN', 'DNN-single'],
                      ax=ax[0])

    # LN vs conv1dx2+d -- its own panel (previously it overdrew ax[0]).
    plot_pred_scatter(batch, [ALL_FAMILY_MODELS[3], ALL_FAMILY_MODELS[2]],
                      labels=['pop-LN', '1Dx2-CNN'],
                      ax=ax[1])

    # conv2d vs conv1dx2+d
    conv2d, conv1dx2 = ALL_FAMILY_MODELS[0], ALL_FAMILY_MODELS[2]
    sig = (sig_scores[conv2d] - se_scores[conv2d] > sig_scores[conv1dx2] + se_scores[conv1dx2]) | \
        (sig_scores[conv1dx2] - se_scores[conv1dx2] > sig_scores[conv2d] + se_scores[conv2d])
    group3 = (ceiling_scores.loc[~sig, conv2d].values,
              ceiling_scores.loc[~sig, conv1dx2].values)
    group4 = (ceiling_scores.loc[sig, conv2d].values,
              ceiling_scores.loc[sig, conv1dx2].values)
    scatter_groups([group3, group4], ['lightgray', 'black'], ax=ax[2])
    ax[2].set_title('batch %d, %s, conv2d vs conv1dx2+d' % (batch, PLOT_STAT))
    ax[2].set_xlabel(
        f'DNN (2D conv) pred. correlation ({ceiling_scores[conv2d].mean():.3f})',
        color=colors[0])
    ax[2].set_ylabel(
        f'DNN (1D conv) pred. correlation ({ceiling_scores[conv1dx2].mean():.3f})',
        color=colors[1])

    plt.tight_layout()

    return fig
コード例 #6
0
def pop_selector(recording_uri_list,
                 batch=None,
                 cellid=None,
                 rand_match=False,
                 cell_count=20,
                 best_cells=False,
                 **context):
    """Restrict a recording's 'resp' signal to the cells of one site.

    Parameters:
        recording_uri_list: only the first URI is loaded.
        batch: batch id used to look up each cell's r_test for the reference
            model ``pmodelname`` below.
        cellid: the site id. Cells whose cellid prefix (before the first '-')
            equals it are selected.
        rand_match: if True, put into 'resp' the cells from other sites whose
            r_test is closest to each selected site cell's r_test, instead of
            the site cells themselves.
        cell_count: number of site cells to keep.
        best_cells: if True, keep the top ``cell_count`` cells by r_test (ties
            at the cutoff are all kept). Otherwise keep the first
            ``cell_count`` cells in recording order.
        **context: accepted and ignored.

    Returns:
        {'rec': rec} with 'resp' reduced to the selected channels.
    """
    rec = load_recording(recording_uri_list[0])
    all_cells = rec.meta['cellid']
    this_site = cellid
    cellid = [c for c in all_cells if c.split("-")[0] == this_site]

    pmodelname = "ozgf.fs100.ch18-ld-sev_dlog-wc.18x3.g-fir.3x15-lvl.1-dexp.1_init-basic"
    single_perf = nd.batch_comp(batch=batch,
                                modelnames=[pmodelname],
                                cellids=all_cells,
                                stat='r_test')

    def _perf(cells):
        # r_test of the reference model for each cell, in order.
        return np.array([
            single_perf[single_perf.index == c][pmodelname].values[0]
            for c in cells
        ])

    this_perf = _perf(cellid)

    if best_cells:
        # Clamp so that asking for more cells than the site has keeps all of
        # them, instead of raising IndexError in sidx[-cell_count].
        n_best = min(cell_count, len(this_perf))
        if n_best > 0:
            sidx = np.argsort(this_perf)
            keepidx = (this_perf >= this_perf[sidx[-n_best]])
            cellid = list(np.array(cellid)[keepidx])
            this_perf = this_perf[keepidx]
    else:
        cellid = cellid[:cell_count]
        this_perf = this_perf[:cell_count]

    if rand_match:
        out_cellid = [c for c in all_cells if c.split("-")[0] != this_site]
        out_perf = _perf(out_cellid)

        alt_cellid = []
        alt_perf = []
        for target in this_perf:
            w = np.argmin(np.abs(out_perf - target))
            alt_cellid.append(out_cellid[w])
            alt_perf.append(out_perf[w])
            out_perf[w] = 100  # prevent cell from getting picked again
        log.info("Rand matched cellids: %s", alt_cellid)
        log.info("Mean actual: %.3f", np.mean(this_perf))
        log.info("%s", this_perf)
        log.info("Mean rand: %.3f", np.mean(np.array(alt_perf)))
        log.info("%s", np.array(alt_perf))
        resp_cells = alt_cellid
    else:
        resp_cells = cellid

    rec['resp'] = rec['resp'].extract_channels(resp_cells)
    # NOTE(review): with rand_match, meta keeps the site cellids even though
    # 'resp' holds the matched out-of-site channels -- confirm this is intended.
    rec.meta['cellid'] = cellid

    return {'rec': rec}
コード例 #7
0
ファイル: lv.py プロジェクト: LBHB/nems_db
def get_population_cellids(batch, modelname=None):
    """Pick one representative cellid per recording site in *batch*.

    A cell's site id is the part of its cellid before the first '-'.

    Parameters:
        batch: batch id passed to ``nd.batch_comp``.
        modelname: model whose batch_comp result defines the available cells;
            defaults to ``ref_modelname``.

    Returns:
        (all_siteids, rep_cellids): dicts keyed by *batch*.
        - all_siteids maps to the sorted list of site ids.
        - rep_cellids maps to the sorted list of representative cellids,
          i.e. the first batch_comp row of each site.
    """
    if modelname is None:
        modelname = ref_modelname

    def get_siteid(s):
        return s.split("-")[0]

    d = nd.batch_comp(batch=batch, modelnames=[modelname])
    d['siteid'] = d.index.map(get_siteid)

    # Sorted so the returned order is reproducible (set order is not).
    siteids = sorted(set(d['siteid'].tolist()))

    # Match the exact site id, not a cellid prefix. Otherwise a site id that
    # is a prefix of another site id could pick a cell from the wrong site.
    site_cellids = sorted(d.loc[d['siteid'] == s].index[0] for s in siteids)

    return {batch: siteids}, {batch: site_cellids}
コード例 #8
0
ファイル: utils.py プロジェクト: LBHB/nems_db
def get_dataframes(batch, gc, stp, LN, combined):
    """Load r_test, r_ceiling and se_test tables for four models.

    Rows with NaN for any of the models are dropped, and each table is sorted
    by cellid.

    Parameters:
        batch: batch id passed to ``nd.batch_comp``.
        gc, stp, LN, combined: the four model names (table columns).

    Returns:
        (df_r, df_c, df_e): the r_test, r_ceiling and se_test tables.

    Raises:
        ValueError: if, after cleaning, the tables do not share the same
            cellid index.
    """
    models = [gc, stp, LN, combined]
    df_r = nd.batch_comp(batch, models, stat='r_test')
    df_c = nd.batch_comp(batch, models, stat='r_ceiling')
    df_e = nd.batch_comp(batch, models, stat='se_test')

    # Remove any cellids that have NaN for 1 or more models, then sort.
    df_r, df_e, df_c = (df.dropna(axis=0, how='any').sort_index()
                        for df in (df_r, df_e, df_c))

    # Index.equals copes with tables of different lengths. Comparing the index
    # arrays with `==` does not give a usable elementwise result in that case.
    if not (df_r.index.equals(df_e.index) and df_r.index.equals(df_c.index)):
        raise ValueError('index mismatch in dataframes')

    return df_r, df_c, df_e
コード例 #9
0
ファイル: svd_utils.py プロジェクト: LBHB/nems_db
def get_rceiling_correction(batch):
    """Return per-cell r_ceiling / r_test ratios for the LN model, floored at 1.

    The LN model is ``MODELGROUPS['LN'][3]``. Cells are those returned by
    ``get_significant_cells(batch, SIG_TEST_MODELS)``.

    Returns:
        Series indexed by cellid. Ratios below 1 are set to 1; NaN ratios
        are left as NaN.
    """
    LN_model = MODELGROUPS['LN'][3]

    significant_cells = get_significant_cells(batch,
                                              SIG_TEST_MODELS,
                                              as_list=True)
    rtest = nd.batch_comp(batch, [LN_model],
                          cellids=significant_cells,
                          stat='r_test')
    rceiling = nd.batch_comp(batch, [LN_model],
                             cellids=significant_cells,
                             stat='r_ceiling')

    # clip() leaves NaN untouched, like the previous `.loc[ratio < 1] = 1`.
    return (rceiling[LN_model] / rtest[LN_model]).clip(lower=1)
コード例 #10
0
ファイル: utils.py プロジェクト: LBHB/nems_db
def get_filtered_cellids(batch,
                         gc,
                         stp,
                         LN,
                         combined,
                         se_filter=True,
                         LN_filter=False,
                         as_lists=True):
    """Classify cells by prediction significance for four models.

    Parameters:
        batch: batch id.
        gc, stp, LN, combined: model names.
        se_filter: if True, a cell is "good" only when every model's r_test
            exceeds twice its se_test and its r_floor. If False, every cell
            is good.
        LN_filter: if True, a cell is "bad" when gc, stp or combined has
            r_test + se_test below LN's r_test - se_test. If False, no cell
            is bad.
        as_lists: return cellid lists instead of boolean Series.

    Returns:
        (keep, under_chance, less_LN), where keep = good and not bad,
        under_chance = not good, and less_LN = bad. These are lists of
        cellids when ``as_lists`` is True, otherwise boolean Series indexed
        by cellid.

    Raises:
        ValueError: if the r_floor table's cellids differ from the others.
    """
    df_r, _, df_e = get_dataframes(batch, gc, stp, LN, combined)
    df_f = nd.batch_comp(batch, [gc, stp, LN, combined], stat='r_floor')
    df_f = df_f.dropna(axis=0, how='any').sort_index()
    if not df_f.index.equals(df_r.index):
        raise ValueError('index mismatch in dataframes')

    gc_test, gc_se, gc_floor = [d[gc] for d in [df_r, df_e, df_f]]
    stp_test, stp_se, stp_floor = [d[stp] for d in [df_r, df_e, df_f]]
    ln_test, ln_se, ln_floor = [d[LN] for d in [df_r, df_e, df_f]]
    gc_stp_test, gc_stp_se, gc_stp_floor = [
        d[combined] for d in [df_r, df_e, df_f]
    ]

    if se_filter:
        # Remove if performance not significant at all
        good_cells = ((gc_test > gc_se * 2) & (gc_test > gc_floor) &
                      (stp_test > stp_se * 2) & (stp_test > stp_floor) &
                      (ln_test > ln_se * 2) & (ln_test > ln_floor) &
                      (gc_stp_test > gc_stp_se * 2) &
                      (gc_stp_test > gc_stp_floor))
    else:
        # All True: get_dataframes already dropped every row with a NaN.
        # (The old `gc_test != np.nan` was also all True, but only because
        # NaN never compares equal to anything.)
        good_cells = gc_test.notna()

    if LN_filter:
        # Remove if performance significantly worse than LN
        bad_cells = ((gc_test + gc_se < ln_test - ln_se) |
                     (stp_test + stp_se < ln_test - ln_se) |
                     (gc_stp_test + gc_stp_se < ln_test - ln_se))
    else:
        # All False, for the same reason as above.
        bad_cells = gc_test.isna()

    keep = good_cells & ~bad_cells

    if as_lists:
        cellids = df_r[keep].index.values.tolist()
        under_chance = df_r[~good_cells].index.values.tolist()
        less_LN = df_r[bad_cells].index.values.tolist()
        return cellids, under_chance, less_LN
    return keep, ~good_cells, bad_cells
コード例 #11
0
def scatter_titan(batch):
    """Scatter r_test of a single-stimulus conv1d model against its Titan variant.

    Both model names are defined below. Cells are those that
    ``get_significant_cells`` finds significant for both models. A cell is
    drawn black when the two r_test +/- se_test intervals do not overlap,
    and light gray otherwise.

    Returns:
        The matplotlib figure.
    """
    # Local override: this plot always uses r_test, whatever the module-level PLOT_STAT is.
    PLOT_STAT = 'r_test'
    TITAN_MODEL = 'ozgf.fs100.ch18.pop-loadpop-norm.l1-popev_wc.18x70.g-fir.1x15x70-relu.70.f-wc.70x80-fir.1x10x80-relu.80.f-wc.80x100-relu.100-wc.100xR-lvl.R-dexp.R_tfinit.n.mc50.lr1e3.es20-newtf.n.mc100.lr1e4.exa'
    REFERENCE_MODEL = 'ozgf.fs100.ch18.pop-loadpop-norm.l1-popev_wc.18x70.g-fir.1x15x70-relu.70.f-wc.70x80-fir.1x10x80-relu.80.f-wc.80x100-relu.100-wc.100xR-lvl.R-dexp.R_tfinit.n.lr1e3.et3.es20-newtf.n.lr1e4'
    MODELS = [REFERENCE_MODEL, TITAN_MODEL]

    significant_cells = get_significant_cells(batch, MODELS, as_list=True)

    sig_scores = nd.batch_comp(batch,
                               MODELS,
                               cellids=significant_cells,
                               stat='r_test')
    se_scores = nd.batch_comp(batch,
                              MODELS,
                              cellids=significant_cells,
                              stat='se_test')
    ceiling_scores = nd.batch_comp(batch,
                                   MODELS,
                                   cellids=significant_cells,
                                   stat=PLOT_STAT)

    fig, ax = plt.subplots(1, 1, figsize=(4, 4))

    # Single-stim reference vs Titan
    sig = (sig_scores[MODELS[1]] - se_scores[MODELS[1]] > sig_scores[MODELS[0]] + se_scores[MODELS[0]]) | \
          (sig_scores[MODELS[0]] - se_scores[MODELS[0]] > sig_scores[MODELS[1]] + se_scores[MODELS[1]])
    group1 = (ceiling_scores.loc[~sig, MODELS[0]].values,
              ceiling_scores.loc[~sig, MODELS[1]].values)
    group2 = (ceiling_scores.loc[sig, MODELS[0]].values,
              ceiling_scores.loc[sig, MODELS[1]].values)

    scatter_groups([group1, group2], ['lightgray', 'black'], ax=ax)
    ax.set_title('batch %d, %s, Conv1d vs Titan' % (batch, PLOT_STAT))
    ax.set_xlabel(f'Single stim ({ceiling_scores[MODELS[0]].mean():.3f})',
                  color='orange')
    ax.set_ylabel(f'Titan ({ceiling_scores[MODELS[1]].mean():.3f})',
                  color='lightgreen')

    return fig
コード例 #12
0
def get_heldout_results(batch, significant_cells, short_names):
    """Collect held-out and matched model scores for each short name.

    ``short_names[i]`` is paired with the models ``HELDOUT[i]`` and
    ``MATCHED[i]``.

    Returns:
        dict mapping '<name> held' and '<name> match' to the ``PLOT_STAT``
        column (a Series) that ``nd.batch_comp`` returns for that model,
        restricted to *significant_cells*.
    """
    modelgroups = {
        name + suffix: source[idx]
        for idx, name in enumerate(short_names)
        for suffix, source in ((' held', HELDOUT), (' match', MATCHED))
    }

    r_ceilings = {}
    for name in short_names:
        for suffix in (' held', ' match'):
            key = name + suffix
            model = modelgroups[key]
            frame = nd.batch_comp(batch, [model],
                                  cellids=significant_cells,
                                  stat=PLOT_STAT)
            # Single-model frame: keep just that model's column.
            r_ceilings[key] = frame[model]

    return r_ceilings
コード例 #13
0
def equivalence_effect_size(batch, models, performance_stat='r_ceiling',
                            manual_cellids=None):
    """Per-cell mean gain of the first two models over the third.

    For each cell the effect is ``0.5 * ((m0 - m2) + (m1 - m2))``, where
    ``m0``, ``m1`` and ``m2`` are the *performance_stat* values of
    ``models[0]``, ``models[1]`` and ``models[2]``.

    Returns:
        DataFrame indexed by 'cellid' with a single 'performance_effect'
        column.
    """
    scores = nd.batch_comp(batch=batch, modelnames=models,
                           stat=performance_stat, cellids=manual_cellids)
    columns = [scores[name].values for name in models]
    baseline = columns[2]
    gain_first = columns[0] - baseline
    gain_second = columns[1] - baseline

    out = pd.DataFrame({
        'performance_effect': 0.5 * (gain_first + gain_second),
        'cellid': scores.index.values,
    })
    return out.set_index('cellid')
コード例 #14
0
ファイル: pop_model_utils.py プロジェクト: LBHB/nems_db
def count_fits(models, batch=None):
    """Print, per batch, how many cells have an r_test value for each model.

    Batches 322 and 323 are used when *batch* is None. Each model name is
    split on '_'. The header line shows fields 0 and 2 of the first model
    name. Each following line shows field 1 of a model name and the number
    of non-NaN r_test values for that model.
    """
    batch_list = [322, 323] if batch is None else [batch]

    for current in batch_list:
        df_r = nd.batch_comp(current, models, stat='r_test')
        for position, column in enumerate(df_r.columns):
            fields = column.split("_")
            if position == 0:
                print(f"BATCH {current} -- {fields[0]}_XX_{fields[2]}")
            print(f'{fields[1]}: {df_r[column].count()}')
コード例 #15
0
ファイル: summary_plots.py プロジェクト: LBHB/nems_db
def scatter_bar(batches, modelnames, stest=SIG_TEST_MODELS, axes=None):
    """Scatter LN vs conv1Dx2 accuracy and bar-plot median accuracy per model.

    Cells are the significant cells (per ``stest``) of each batch, pooled
    across *batches*; scores are ``PLOT_STAT``.

    NOTE: assumes *modelnames* follows the ALL_FAMILY_MODELS order
    (conv2d, conv1d, conv1dx2+d, LN_pop, dnn1_single). Indices 3 and 2 and
    the bar labels below depend on it.

    Parameters:
        batches: batch ids.
        modelnames: model names (see note above).
        stest: models passed to get_significant_cells.
        axes: optional (ax1, ax2). A new 2x1 figure is created when None.

    Returns:
        (ax1, ax2)
    """
    if axes is None:
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(3.5, 6))
    else:
        ax1, ax2 = axes
        # Lay out the figure that owns these axes, not whatever figure is current.
        fig = ax1.figure

    cellids = [
        get_significant_cells(batch, stest, as_list=True) for batch in batches
    ]
    r_values = [
        nd.batch_comp(batch, modelnames, cellids=cells, stat=PLOT_STAT)
        for batch, cells in zip(batches, cellids)
    ]
    all_r_values = pd.concat(r_values)

    # Scatter Plot -- LN vs c1dx2+d
    ax1.scatter(all_r_values[modelnames[3]],
                all_r_values[modelnames[2]],
                s=2,
                c='black')
    # Unity line. The old plot([[0, 0], [1, 1]]) drew this diagonal twice.
    ax1.plot([0, 1], [0, 1], c='black', linestyle='dashed', linewidth=1)
    ax_remove_box(ax1)
    ax1.set_ylim(0, 1)
    ax1.set_xlim(0, 1)
    ax1.set_xlabel('LN_pop prediction accuracy')
    ax1.set_ylabel('conv1Dx2 prediction accuracy')

    # Bar Plot -- Median for each model
    short_names = ['conv2d', 'conv1d', 'conv1dx2+d', 'LN_pop', 'dnn1_single']
    bar_colors = [DOT_COLORS[k] for k in short_names]
    ax2.bar(np.arange(0, len(modelnames)),
            all_r_values.median(axis=0).values,
            color=bar_colors,
            tick_label=short_names)
    ax_remove_box(ax2)
    ax2.set_ylabel('Median Prediction Accuracy')
    # Rotation must be numeric; recent matplotlib rejects the string '45'.
    plt.setp(ax2.get_xticklabels(), rotation=45, ha='right')

    fig.tight_layout()

    return ax1, ax2
コード例 #16
0
def bar_mean(batch, modelnames, stest=SIG_TEST_MODELS, ax=None):
    """Bar-plot the median ``PLOT_STAT`` of each model and test every model pair.

    Cells are the significant cells of *batch* according to
    ``get_significant_cells`` with the models in *stest*.

    NOTE(review): bar colors and tick labels come from ``shortnames``, which
    is not a parameter. It is presumably a module-level list parallel to
    *modelnames* (in ALL_FAMILY_MODELS order); confirm.

    Returns:
        (ax, medians, stats_results):
        - medians: per-model medians, in batch_comp column order.
        - stats_results: maps '<short1> vs <short2>' to the two-sided
          Wilcoxon signed-rank result for each unordered model pair, in
          *modelnames* order.
    """
    if ax is None:
        _, ax = plt.subplots()

    cellids = get_significant_cells(batch, stest, as_list=True)
    r_values = nd.batch_comp(batch,
                             modelnames,
                             cellids=cellids,
                             stat=PLOT_STAT)

    # Bar Plot -- Median for each model
    bar_colors = [DOT_COLORS[k] for k in shortnames]
    medians = r_values.median(axis=0).values
    ax.bar(np.arange(0, len(modelnames)),
           medians,
           color=bar_colors,
           edgecolor='black',
           linewidth=1,
           tick_label=shortnames)
    ax.set_ylabel('Median prediction\ncorrelation')
    # Rotation must be numeric; recent matplotlib rejects the string '45'.
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

    # Test significance for every unordered pair of models.
    stats_results = {}
    pairs = list(zip(modelnames, shortnames))
    for i, (m1, s1) in enumerate(pairs):
        for m2, s2 in pairs[i + 1:]:
            stats_results[f'{s1} vs {s2}'] = st.wilcoxon(
                r_values[m1], r_values[m2], alternative='two-sided')

    return ax, medians, stats_results
コード例 #17
0
ファイル: xform_wrappers.py プロジェクト: nadoss/nems_db
def _matching_cells(batch=289, siteid=None, alt_cells_available=None,
                    cell_count=None, best_cells=False):
    """Select a site's cells and a performance-matched set from other sites.

    Performance is the r_test of the reference model ``pmodelname`` below.

    Parameters:
        batch: batch id.
        siteid: site whose cells are selected (cellid prefix before the
            first '-').
        alt_cells_available: candidate cellids. Defaults to every cell in
            the batch_comp result.
        cell_count: if given, the number of site cells to keep.
        best_cells: with cell_count, keep the top cells by r_test (ties at
            the cutoff are all kept) instead of the first cell_count.

    Returns:
        (cellid, this_perf, alt_cellid, alt_perf):
        - cellid, this_perf: the kept site cells and their r_test.
        - alt_cellid, alt_perf: for each kept cell, the not-yet-used
          out-of-site cell with the closest r_test, and that r_test.
        Out-of-site cells missing from the batch_comp result count as
        r_test 0.0.
    """
    pmodelname = "ozgf.fs100.ch18-ld-sev_dlog-wc.18x3.g-fir.3x15-lvl.1-dexp.1_init-basic"
    single_perf = nd.batch_comp(batch=batch, modelnames=[pmodelname], stat='r_test')
    if alt_cells_available is not None:
        all_cells = alt_cells_available
    else:
        all_cells = list(single_perf.index)

    def _perf(c):
        # r_test of the reference model for one cell.
        return single_perf[single_perf.index == c][pmodelname].values[0]

    cellid = [c for c in all_cells if c.split("-")[0] == siteid]
    this_perf = np.array([_perf(c) for c in cellid])

    if cell_count is not None:
        if best_cells:
            # Clamp so that asking for more cells than the site has keeps all
            # of them, instead of raising IndexError in sidx[-cell_count].
            n_best = min(cell_count, len(this_perf))
            if n_best > 0:
                sidx = np.argsort(this_perf)
                keepidx = (this_perf >= this_perf[sidx[-n_best]])
                cellid = list(np.array(cellid)[keepidx])
                this_perf = this_perf[keepidx]
        else:
            cellid = cellid[:cell_count]
            this_perf = this_perf[:cell_count]

    out_cellid = [c for c in all_cells if c.split("-")[0] != siteid]
    out_perf = np.array([_perf(c) if c in single_perf.index else 0.0
                         for c in out_cellid])

    alt_cellid = []
    alt_perf = []
    for target in this_perf:
        w = np.argmin(np.abs(out_perf - target))
        alt_cellid.append(out_cellid[w])
        alt_perf.append(out_perf[w])
        out_perf[w] = 100  # prevent cell from getting picked again
    log.info("Rand matched cellids: %s", alt_cellid)
    log.info("Mean actual: %.3f", np.mean(this_perf))
    log.info("%s", this_perf)
    log.info("Mean rand: %.3f", np.mean(np.array(alt_perf)))
    log.info("%s", np.array(alt_perf))

    return cellid, this_perf, alt_cellid, alt_perf
コード例 #18
0
    def run_fun(self):
        """Plot mean r_test per selected model, with per-site means overlaid.

        Uses the cells from ``self._selected_cells()`` and the models from
        ``self._selected_modelnames()`` for ``self.batch``. Also prints the
        models' common prefix/suffix and, for each model, its cell count and
        mean r_test.
        """
        cellids = self._selected_cells()
        modelnames = self._selected_modelnames()
        batch = self.batch

        d = nd.batch_comp(batch=batch, modelnames=modelnames, cellids=cellids, stat='r_test')

        # Site id = text before the first '-' of each returned cellid.
        d['siteids'] = [c.split("-")[0] for c in d.index]
        mean_r_test = d[modelnames].mean().values

        shortened, prefix, suffix = find_common(modelnames)
        modelcount = len(modelnames)
        plt.figure()
        ax = plt.subplot(1, 1, 1)

        # Group once and select the model columns explicitly. The site lines
        # then follow the same model order as the bars, and the site labels
        # come from the same result as the values.
        site_means = d.groupby('siteids')[modelnames].mean()
        site_r_test = site_means.values.T
        usiteids = site_means.index

        ax.bar(np.arange(len(mean_r_test)), mean_r_test, color='lightgray')
        ax.plot(site_r_test)
        print('{} -- {}'.format(prefix, suffix))
        for i, m in enumerate(modelnames):
            print("{} (n={}): {:.3f}".format(shortened[i], d[m].count(), d[m].mean()))
            s = shortened[i].replace("_", "\n") + "\n{:.3f}".format(mean_r_test[i])
            ax.text(i, 0, s, rotation=90, color='black',
                    ha='left', va='bottom', fontsize=7)

        # Label each site's line at its right-hand end.
        for i, s in enumerate(usiteids):
            ax.text(modelcount - 0.2, site_r_test[-1, i], s)

        plt.title('{} -- {}'.format(prefix, suffix))
        ax.set_xticks(np.arange(modelcount))
        ax.set_xticklabels([])

        nplt.ax_remove_box(ax)
コード例 #19
0
def load_models(cell, batch, models, check_db=True, site=None, site_cache=None):
    '''Load standardized psth from each model, error if not all fits exist.

    Parameters:
        cell: cellid whose prediction channel is extracted.
        batch: batch id.
        models: model names to load.
        check_db: if True, first require a batch_comp result for every model
            for this cell.
        site: key into site_cache.
        site_cache: optional {site: {model: ctx}} dict used to reuse loaded
            contexts. A missing entry for *site* is created rather than
            raising KeyError.

    Returns:
        One 1-D array per model: the validation-set prediction for *cell*
        (after apply_mask), keeping only time points where every model's
        prediction is finite.

    Raises:
        ValueError: if check_db is True and any model has no result for cell.
    '''

    if check_db:
        # Before trying to load, check database to see if a result exists.
        # Should set False if you know model results are not stored in DB,
        # but exist in file storage.
        df = nd.batch_comp(batch=batch, modelnames=models, cellids=[cell])
        if np.sum(df.isna().values) > 0:
            # at least one cell wasn't fit (or at least not stored in db)
            # so skip trying to load any of them since all are required.
            raise ValueError('Not all results exist for: %s, %d' % (cell, batch))

    # Load all models
    ctxs = []
    for model in models:
        if site_cache is None:
            _, ctx = xhelp.load_model_xform(cell, batch, model)
        else:
            cache = site_cache.setdefault(site, {})
            if model in cache:
                log.info("Site %s is cached, skipping load...", site)
                ctx = cache[model]
            else:
                _, ctx = xhelp.load_model_xform(cell, batch, model)
                cache[model] = ctx
        ctxs.append(ctx)

    # Predictions without channel names inherit them from the response.
    for ctx in ctxs:
        if ctx['val']['pred'].chans is None:
            ctx['val']['pred'].chans = copy(ctx['val']['resp'].chans)

    # Pull out model predictions and remove times with nan for at least 1 model
    preds = [ctx['val'].apply_mask()['pred'].extract_channels([cell]).as_continuous() for ctx in ctxs]
    ff = np.isfinite(preds[0])
    for pred in preds[1:]:
        ff &= np.isfinite(pred)
    no_nans = [pred[ff] for pred in preds]

    return no_nans
コード例 #20
0
def stp_v_beh():
    """Figure comparing STP and behavior-state model gains over batches 274 and 275.

    Panels of the 2x2 figure:
      1. r_test of the model labeled 'LN STRF' (modelnames[0]) vs the model
         labeled 'STP+BEH LN STRF' (modelnames[-3]). Only cells with
         r_test > 3*se_test in either model are shown. Cells whose
         improvement exceeds the summed standard errors are highlighted.
      2. Per-cell r_test change from adding STP (modelnames[3] - [0]) vs from
         adding behavior (modelnames[4] - [3]), colored by which change
         exceeds its summed standard errors.
      3. Mean r_test per model over the panel-1 cells, with the Wilcoxon
         p-value (over all cells) for each adjacent model pair.

    Returns:
        (fh, cellids): the figure, and the cellids whose STP change exceeds
        the summed standard errors.
    """
    # Local import: pandas is needed here only for concat.
    import pandas as pd

    batch1 = 274
    batch2 = 275
    modelnames=["env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-fir.1x15-lvl.1-dexp.1_jk.nf5-init.st-basic",
                "env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-fir.1x15-lvl.1-rep.2-dexp.2-mrg_jk.nf5-init.st-basic",
                "env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-rep.2-fir.1x15x2-lvl.2-dexp.2-mrg_jk.nf5-init.st-basic",
                "env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-stp.1-fir.1x15-lvl.1-dexp.1_jk.nf5-init.st-basic",
                "env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-stp.1-fir.1x15-lvl.1-rep.2-dexp.2-mrg_jk.nf5-init.st-basic",
                "env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-stp.1-rep.2-fir.1x15x2-lvl.2-dexp.2-mrg_jk.nf5-init.st-basic",
                "env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-rep.2-stp.2-fir.1x15x2-lvl.2-dexp.2-mrg_jk.nf5-init.st-basic"]
    n1 = modelnames[0]
    n2 = modelnames[-3]

    xc_range = [-0.05, 0.6]

    df1 = nd.batch_comp(batch1, modelnames, stat='r_test').reset_index()
    df1_e = nd.batch_comp(batch1, modelnames, stat='se_test').reset_index()

    df2 = nd.batch_comp(batch2, modelnames, stat='r_test').reset_index()
    df2_e = nd.batch_comp(batch2, modelnames, stat='se_test').reset_index()

    # DataFrame.append was removed in pandas 2.0; concat replaces it.
    # ignore_index gives each row a unique label, so r_test and se_test rows
    # (and masks derived from them) pair up row for row. This relies on both
    # stats listing cells in the same order, as the original code did.
    df = pd.concat([df1, df2], ignore_index=True)
    df_e = pd.concat([df1_e, df2_e], ignore_index=True)

    # Unclipped scores feed the tests. clip() returns new Series, so the
    # plotted copies never write back into df.
    beta1_test = df[n1]
    beta2_test = df[n2]
    se1 = df_e[n1]
    se2 = df_e[n2]
    beta1 = beta1_test.clip(upper=1)
    beta2 = beta2_test.clip(upper=1)

    # test for significant improvement
    improvedcells = (beta2_test - se2 > beta1_test + se1)

    # test for significant prediction at all
    goodcells = ((beta2_test > se2 * 3) | (beta1_test > se1 * 3))

    fh = plt.figure()
    ax = plt.subplot(2, 2, 1)
    stateplots.beta_comp(beta1[goodcells], beta2[goodcells],
                         n1='LN STRF', n2='STP+BEH LN STRF',
                         hist_range=xc_range, ax=ax,
                         highlight=improvedcells[goodcells])

    # LN vs. STP:
    beta1b = df[modelnames[3]]
    beta1a = df[modelnames[0]]
    beta1 = beta1b - beta1a
    se1a = df_e[modelnames[3]]
    se1b = df_e[modelnames[0]]

    # STP vs. STP+behavior:
    b1 = 4
    b0 = 3
    beta2b = df[modelnames[b1]]
    beta2a = df[modelnames[b0]]
    beta2 = beta2b - beta2a
    se2a = df_e[modelnames[b1]]
    se2b = df_e[modelnames[b0]]

    stpgood = (beta1 > se1a + se1b)
    behgood = (beta2 > se2a + se2b)
    neither_good = np.logical_not(stpgood) & np.logical_not(behgood)
    both_good = stpgood & behgood
    stp_only_good = stpgood & np.logical_not(behgood)
    beh_only_good = np.logical_not(stpgood) & behgood

    # Floor the plotted changes at the axis minimum (after the tests above).
    xc_range = np.array([-0.05, 0.15])
    beta1 = beta1.clip(lower=xc_range[0])
    beta2 = beta2.clip(lower=xc_range[0])

    zz = np.zeros(2)
    ax = plt.subplot(2, 2, 2)
    ax.plot(xc_range, zz, 'k--', linewidth=0.5)
    ax.plot(zz, xc_range, 'k--', linewidth=0.5)
    ax.plot(xc_range, xc_range, 'k--', linewidth=0.5)
    l = ax.plot(beta1[neither_good], beta2[neither_good], '.', color='lightgray') +\
        ax.plot(beta1[beh_only_good], beta2[beh_only_good], '.', color='purple') +\
        ax.plot(beta1[stp_only_good], beta2[stp_only_good], '.', color='orange') +\
        ax.plot(beta1[both_good], beta2[both_good], '.', color='black')
    ax_remove_box(ax)
    ax.set_aspect('equal', 'box')
    ax.set_xlim(xc_range)
    ax.set_ylim(xc_range)
    ax.set_xlabel('delta(stp)')
    ax.set_ylabel('delta(beh)')

    # Shuffled stp/beh overlap counts (a permutation null).
    # NOTE(review): olap is never used or returned; kept because it advances
    # the global numpy RNG, which later random draws may depend on.
    olap = np.zeros(100)
    a = stpgood.values.copy()
    b = behgood.values.copy()
    for i in range(100):
        np.random.shuffle(a)
        olap[i] = np.sum(a & b)

    ll = [np.sum(neither_good), np.sum(beh_only_good),
          np.sum(stp_only_good), np.sum(both_good)]
    ax.legend(l, ll)

    ax = plt.subplot(2, 2, 3)
    # Select the model columns before averaging. df also holds the string
    # 'cellid' column, which mean() cannot average.
    m = df.loc[goodcells, modelnames].mean().to_numpy()
    xc_range = [-0.02, 0.2]
    plt.bar(np.arange(len(modelnames)), m, color='black')
    plt.plot(np.array([-1, len(modelnames)]), np.array([0, 0]), 'k--',
             linewidth=0.5)
    plt.ylim(xc_range)
    plt.title("batch {}/{}, n={}/{} good cells".format(
            batch1, batch2, np.sum(goodcells), len(goodcells)))
    plt.ylabel('mean pred corr')
    plt.xlabel('model architecture')
    ax_remove_box(ax)

    for i in range(len(modelnames) - 1):
        d1 = np.array(df[modelnames[i]])
        d2 = np.array(df[modelnames[i + 1]])
        _, p = ss.wilcoxon(d1, d2)
        plt.text(i + 0.5, m[i + 1], "{:.1e}".format(p), ha='center', fontsize=6)

    plt.xticks(np.arange(len(m)), np.round(m, 3))

    return fh, df[stpgood]['cellid'].tolist()
コード例 #21
0
ファイル: encoding_models_tag.py プロジェクト: LBHB/nems_db
# Build a per-cell table (dpred) of r_ceiling scores for phototagged NAT
# cells, annotated with phototag and waveform-shape labels, and load PCA
# results (dpc).
# NOTE(review): `batch`, `modelnames` and `shortnames` are used below but not
# defined in this section -- presumably set earlier in the script; confirm.
pca_file = '/auto/users/svd/python/scripts/NAT_pop_models/dpc322.csv'
# Relative path: resolved against the current working directory.
waveform_labels = pd.read_csv('phototag_waveform_labels.csv', index_col=0)
# waveform_labels = pd.read_csv(pl.Path('/auto/users/mateo/nems_db/nems_lbhb/projects/phototag/phototag_waveform_labels.csv'), index_col=0)

runclass = "NAT"
# sCellFile rows (plus site id and phototag) for recordings of this runclass
# whose cell has a non-null phototag.
sql="select sCellFile.*,gSingleCell.siteid,gSingleCell.phototag from gSingleCell INNER JOIN sCellFile ON gSingleCell.id=sCellFile.singleid" +\
   " INNER JOIN gRunClass on gRunClass.id=sCellFile.runclassid" +\
   f" WHERE gRunClass.name='{runclass}' AND not(isnull(phototag))"
d = nd.pd_query(sql)
d['parmfile'] = d['stimpath'] + d['stimfile']
print(f'cell/file combos with phototag labels={len(d)}')

# One phototag value per cellid: the max over that cell's file rows.
dtag = d[['cellid', 'phototag']].groupby('cellid').max()

dpred = nd.batch_comp(batch=batch, modelnames=modelnames, stat="r_ceiling")
# Positional rename: assumes batch_comp returns one column per modelname in
# the same order as shortnames -- TODO confirm.
dpred.columns = shortnames
dpred['siteid'] = dpred.index
dpred['siteid'] = dpred['siteid'].apply(nd.get_siteid)
# Second model's score minus the first model's.
dpred['diff'] = dpred[shortnames[1]] - dpred[shortnames[0]]

# Inner join: keep only cells that have a phototag label.
dpred = dpred.merge(dtag, how='inner', left_index=True, right_index=True)
# Left join: cells without a waveform label get NaN in 'wshape'.
dpred = dpred.merge(waveform_labels[['cellid', 'wshape']],
                    how='left',
                    left_index=True,
                    right_on='cellid').set_index('cellid')
#dpred['wshape']=dpred['wshape'].fillna("?")
# NOTE(review): with the fillna above disabled, 'pw' is NaN for cells that
# have no wshape.
dpred['pw'] = dpred['phototag'] + " " + dpred['wshape']

dpc = pd.read_csv(pca_file, index_col=0)
# Keep only rows whose 'type' is 'avg'.
dpc = dpc.loc[dpc['type'] == 'avg']
コード例 #22
0
ファイル: fig6.stp_properties.py プロジェクト: nadoss/nems_db
# Fitted parameters (plus the listed meta statistics) for every cell in
# `batch` fit with `modelname`; `batch`, `modelname` and `modelname0` are
# defined earlier in the script (not visible here).
d = nems_db.params.fitted_params_per_batch(
    batch,
    modelname,
    meta=['r_test', 'r_fit', 'se_test', 'r_ceiling'],
    stats_keys=[],
    multi='first')
# Optional reference model, loaded the same way.
if modelname0 is not None:
    d0 = nems_db.params.fitted_params_per_batch(
        batch,
        modelname0,
        meta=['r_test', 'r_fit', 'se_test', 'r_ceiling'],
        stats_keys=[],
        multi='first')

# NOTE(review): these calls run even when modelname0 is None, in which case
# the model list contains None -- confirm nd.batch_comp tolerates that.
df_r = nd.batch_comp(batch, [modelname0, modelname], stat='r_test')
df_e = nd.batch_comp(batch, [modelname0, modelname], stat='se_test')

# Display bounds for the parameter plots (presumably axis limits; used
# further down the script).
u_bounds = np.array([-0.6, 2.1])
tau_bounds = np.array([-0.1, 1.5])
str_bounds = np.array([-0.25, 0.55])
amp_bounds = np.array([-1, 1.5])
amp0_bounds = np.array([-0.5, 0.75])

indices = list(d.index)

# Find the row labels holding the '--u' and '--tau' parameters; if several
# labels match, the last one wins.
for ind in indices:
    if '--u' in ind:
        u_index = ind
    elif '--tau' in ind:
        tau_index = ind
コード例 #23
0
import numpy as np
import matplotlib.pyplot as plt
import svdplots

batch = 303
modelname1 = "psth20pup0prebeh_stategain4_basic-nf"
modelname2 = "psth20pupprebeh_stategain4_basic-nf"

# Comparison settings (presumably histogram bins/range, axis labels and the
# state-channel indices to compare).
bins = 20
# NOTE: this shadows the builtin `range` for the rest of the script; kept
# because code further down presumably refers to this name.
range = [-1, 1]
n1 = 'pre gain'
n2 = 'active gain'
i1 = 2
i2 = 3

# Per-cell results for the two models defined above.
res = nd.batch_comp(batch, [modelname1, modelname2])

d1 = nems_db.params.fitted_params_per_batch(batch, modelname1, stats_keys=[])
d2 = nems_db.params.fitted_params_per_batch(batch, modelname2, stats_keys=[])

# parse modelname: the second '_'-separated keyword (e.g. "stategain4")
# ends in a single-digit state count. Both models share this keyword.
kws = modelname1.split("_")
modname = kws[1]
statecount = int(modname[-1])

# Empty accumulators with one column per state channel (presumably gains
# and baselines for each model, filled further down), and one flag per
# column of d1.
g1 = np.zeros([0, statecount])
b1 = np.zeros([0, statecount])
g2 = np.zeros([0, statecount])
b2 = np.zeros([0, statecount])
sig = np.zeros(len(d1.columns))
i = 0
コード例 #24
0
ファイル: generate_pop_figs.py プロジェクト: LBHB/nems_db
# Relative change per model between the two entries of `means`
# (means[0] / means[1], labelled "a1/peg" below); `means`, `labels` and
# `stats_tests` are defined earlier in the script (not visible here).
relative_changes_per_model = [
    means[0][i] / means[1][i] for i, _ in enumerate(labels)
]
# One "label: mean relative change" line per model.
relative_change_pareto = '\n'.join([
    f'{n}: {changes.mean()}'
    for n, changes in zip(labels, relative_changes_per_model)
])
stats_tests.append('\n\nPareto plot, relative change a1/peg:')
stats_tests.append(relative_change_pareto)

# Compare single-cell LN with population LN on each batch's significant
# cells (322 = A1, 323 = PEG, per the variable names) using a two-sided
# st.wilcoxon test (`st` presumably scipy.stats).
LN_single = MODELGROUPS['LN'][4]
pop_LN = ALL_FAMILY_MODELS[3]
sig_cells_A1 = get_significant_cells(322, SIG_TEST_MODELS, as_list=True)
r1 = nd.batch_comp(322, [LN_single, pop_LN],
                   cellids=sig_cells_A1,
                   stat=PLOT_STAT)
LN_test1 = st.wilcoxon(getattr(r1, LN_single),
                       getattr(r1, pop_LN),
                       alternative='two-sided')

sig_cells_PEG = get_significant_cells(323, SIG_TEST_MODELS, as_list=True)
r2 = nd.batch_comp(323, [LN_single, pop_LN],
                   cellids=sig_cells_PEG,
                   stat=PLOT_STAT)
LN_test2 = st.wilcoxon(getattr(r2, LN_single),
                       getattr(r2, pop_LN),
                       alternative='two-sided')

stats_tests.append('\nLN single vs pop-LN:')
stats_tests.append(f'A1: {LN_test1}')
コード例 #25
0
ファイル: queue_pop_models.py プロジェクト: LBHB/nems_db
# Model whose fitted cells define which cells to queue. `mfb` maps batch ->
# filter model name (batch 323 drops the '.ver2' suffix); the ROUND 1
# block below uses modelname_filter directly rather than `mfb`.
modelname_filter = POP_MODELS[1]
mfb = {322: modelname_filter,
       323: modelname_filter.replace('.ver2','')}

# ROUND 1, all families pop
# Disabled with `if 0`; flip to 1 to enqueue these fits.
if 0:
    modelnames = ALL_FAMILY_POP[:-1]
    useGPU = True

    for batch in batches:
        # GPU jobs get a single population-level cellid ('NAT4v2' for
        # batch 322, 'NAT4' otherwise); CPU jobs get every cellid that
        # nd.batch_comp reports for modelname_filter.
        if useGPU and (batch==322):
            c = ['NAT4v2']
        elif useGPU:
            c = ['NAT4']
        else:
            c = nd.batch_comp(modelnames=[modelname_filter], batch=batch).index.to_list()
        enqueue_exacloud_models(
            cellist=c, batch=batch, modellist=modelnames,
            user=lbhb_user, linux_user=user, force_rerun=force_rerun, priority=1,
            executable_path=executable_path_exa, script_path=script_path_exa, useGPU=useGPU)

# Disabled. Each assignment below overwrites `modelnames`, so only the last
# one is in effect at the end of this block.
if 0:
    # dnn single, round 1
    modelnames = DNN_SINGLE_MODELS[:5]

    # ln single, only 1 round
    modelnames = LN_SINGLE_MODELS[:10]

    # dnn single, round 2
    modelnames = DNN_SINGLE_STAGE2[:5]
コード例 #26
0
ファイル: fig2.example_resps.py プロジェクト: nadoss/nems_db
    'pdf.fonttype': 42,
    'ps.fonttype': 42
}
# Apply the font settings in `params` (e.g. 'pdf.fonttype': 42) globally.
plt.rcParams.update(params)

batch = 259

# shrinkage
# NOTE: these two names are immediately overwritten by the "regular"
# definitions below; swap the order to analyze the shrinkage fits instead.
modelname1 = "env.fs100-ld-sev_dlog.f-fir.2x15-lvl.1-dexp.1_init-mt.shr-basic"
modelname2 = "env.fs100-ld-sev_dlog.f-wc.2x3.c.n-stp.3-fir.3x15-lvl.1-dexp.1_init-mt.shr-basic"
# regular
modelname1 = "env.fs100-ld-sev_dlog.f-fir.2x15-lvl.1-dexp.1_init-basic"
modelname2 = "env.fs100-ld-sev_dlog.f-wc.2x3.c.n-stp.3-fir.3x15-lvl.1-dexp.1_init-basic"

modelnames = [modelname1, modelname2]
# Per-cell results for both models, plus their difference (model2 - model1).
df = nd.batch_comp(batch, modelnames)
df['diff'] = df[modelname2] - df[modelname1]
df['cellid'] = df.index
df.sort_values('cellid', inplace=True, ascending=True)
# List 'por07*' cells whose model2 value exceeds 0.3.
m = df['cellid'].str.startswith('por07') & (df[modelname2] > 0.3)
for index, c in df[m].iterrows():
    print("{}  {:.3f} - {:.3f} = {:.3f}".format(index, c[modelname2],
                                                c[modelname1], c['diff']))

plt.close('all')
outpath = "/auto/users/svd/docs/current/two_band_spn/eps_rev/"

# Disabled; when enabled, the last `cellid` assignment wins.
if 0:
    #cellid="por077a-c1"
    cellid = "por074b-d2"
    cellid = "por020a-c1"
コード例 #27
0
def gd_scatter(batch,
               model1,
               model2,
               se_filter=True,
               gd_threshold=0,
               param='kappa',
               log_gd=False):
    """Scatter the model1/model2 performance ratio against the Gd ratio.

    For each kept cell, the performance ratio is ``meta--r_test`` of
    ``model1`` divided by that of ``model2``. The Gd ratio is
    ``|<param>_mod / <param>|`` taken from ``model1``'s fitted parameters.
    The figure has two rows of plots:

    * the top row uses all kept cells: a scatter of the two ratios and a
      histogram of the Gd ratio;
    * the bottom row repeats both plots for cells whose Gd ratio exceeds
      ``gd_threshold``.

    Parameters
    ----------
    batch
        Batch id passed to ``nd.batch_comp`` / ``fitted_params_per_batch``.
    model1, model2 : str
        Model names. ``model1`` (the "GC" model, per the variable names)
        supplies the Gd parameters and the numerator of the performance
        ratio; ``model2`` (the "LN" model) supplies the denominator.
    se_filter : bool
        If True, keep only cells whose ``r_ceiling`` exceeds twice their
        ``se_test`` for both models. If False, keep every cell with a
        valid ``r_ceiling`` for both models.
    gd_threshold : float
        Gd-ratio threshold for the second row of plots. When ``log_gd`` is
        True the threshold is log-transformed as well, so ``0`` becomes
        ``-inf``.
    param : str
        Parameter name; the Gd ratio is ``<param>_mod / <param>``.
    log_gd : bool
        Plot the natural log of the Gd ratio instead of the ratio itself.
    """
    df_r = nd.batch_comp(batch, [model1, model2], stat='r_ceiling')
    df_e = nd.batch_comp(batch, [model1, model2], stat='se_test')
    # Remove any cellids that have NaN for 1 or more models
    df_r.dropna(axis=0, how='any', inplace=True)
    df_e.dropna(axis=0, how='any', inplace=True)

    if se_filter:
        # Keep cells whose performance beats 2x the standard error for both
        # models. Align the two tables first: after the independent dropna
        # calls their labels can differ, and pandas refuses to compare
        # Series with different labels.
        common = df_r.index.intersection(df_e.index)
        gc_test = df_r.loc[common, model1]
        gc_se = df_e.loc[common, model1]
        ln_test = df_r.loc[common, model2]
        ln_se = df_e.loc[common, model2]
        good_cells = (gc_test > gc_se * 2) & (ln_test > ln_se * 2)
        # Select by the boolean VALUES. `cellid in series` only tests the
        # index, which would let every cell through.
        good_set = set(good_cells[good_cells].index)
    else:
        # No significance filter: keep every cell with a valid r_ceiling.
        good_set = set(df_r.index)

    df1 = fitted_params_per_batch(batch, model1, stats_keys=[])
    df2 = fitted_params_per_batch(batch, model2, stats_keys=[])

    # Use the batch's cell ordering, restricted to the kept cells.
    celldata = nd.get_batch_cells(batch=batch)
    cellids = [c for c in celldata['cellid'].tolist() if c in good_set]

    # Cellids present as columns of each parameter table. The first five
    # entries are skipped; presumably these are summary columns rather than
    # cells (TODO confirm against fitted_params_per_batch).
    df1_cells = set(df1.loc['meta--r_test'].index.values.tolist()[5:])
    df2_cells = set(df2.loc['meta--r_test'].index.values.tolist()[5:])

    # Fill in missing cellids with all-NaN columns so that both tables
    # contain every kept cell.
    df1_nans = 0
    df2_nans = 0
    for c in cellids:
        if c not in df1_cells:
            df1[c] = np.nan
            df1_nans += 1
        if c not in df2_cells:
            df2[c] = np.nan
            df2_nans += 1

    print("# missing cells: %d, %d" % (df1_nans, df2_nans))

    # Force same cellid order now that missing cols are filled in
    df1 = df1[cellids]
    df2 = df2[cellids]

    # Per-cell ratio of test performance, model1 over model2.
    gc_vs_ln = df1.loc['meta--r_test'].values / df2.loc['meta--r_test'].values
    gc_vs_ln = gc_vs_ln.astype('float32')

    # |<param>_mod / <param>| from model1's fitted parameters.
    kappa_mod = df1[df1.index.str.contains('%s_mod' % param)]
    kappa = df1[df1.index.str.contains('%s$' % param)]
    gd_ratio = (np.abs(kappa_mod.values /
                       kappa.values)).astype('float32').flatten()

    # Drop cells where either ratio is NaN or infinite.
    ff = np.isfinite(gc_vs_ln) & np.isfinite(gd_ratio)
    gc_vs_ln = gc_vs_ln[ff]
    gd_ratio = gd_ratio[ff]
    if log_gd:
        gd_ratio = np.log(gd_ratio)

    # drop cells with excessively large/small gd_ratio or gc_vs_ln
    keep = (gd_ratio <= 10) & (gc_vs_ln <= 10) & (gc_vs_ln >= 0.1)
    gd_ratio = gd_ratio[keep]
    gc_vs_ln = gc_vs_ln[keep]

    r = np.corrcoef(gc_vs_ln, gd_ratio)[0, 1]
    n = gc_vs_ln.size

    # Separately do the same comparison but only with cells whose Gd ratio
    # exceeds gd_threshold (i.e. had *some* GC effect). Boolean indexing
    # returns copies, so the full arrays above are left untouched.
    if log_gd:
        gd_threshold = np.log(gd_threshold)
    thresholded = gd_ratio > gd_threshold
    gd_ratio2 = gd_ratio[thresholded]
    gc_vs_ln2 = gc_vs_ln[thresholded]

    r2 = np.corrcoef(gc_vs_ln2, gd_ratio2)[0, 1]
    n2 = gc_vs_ln2.size

    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(8, 9))

    ax1.scatter(gd_ratio, gc_vs_ln, c='black', s=1)
    ax1.set_ylabel("GC/LN R")
    ax1.set_xlabel("Gd ratio")
    ax1.set_title("Performance Improvement vs Gd ratio\nr: %.02f, n: %d" %
                  (r, n))

    ax2.hist(gd_ratio, bins=30, histtype='bar', color=['gray'])
    ax2.set_title('Gd ratio distribution')
    ax2.set_xlabel('Gd ratio')
    ax2.set_ylabel('Count')

    ax3.scatter(gd_ratio2, gc_vs_ln2, c='black', s=1)
    ax3.set_ylabel("GC/LN R")
    ax3.set_xlabel("Gd ratio")
    ax3.set_title("Same, only cells w/ Gd > %.02f\nr: %.02f, n: %d" %
                  (gd_threshold, r2, n2))

    ax4.hist(gd_ratio2, bins=30, histtype='bar', color=['gray'])
    ax4.set_title('Gd ratio distribution, only Gd > %.02f' % gd_threshold)
    ax4.set_xlabel('Gd ratio')
    ax4.set_ylabel('Count')

    fig.suptitle('param: %s' % param)
    fig.tight_layout()
コード例 #28
0
ファイル: pareto_pop_plot.py プロジェクト: LBHB/nems_db
def model_comp_pareto(batch,
                      modelgroups,
                      ax,
                      cellids,
                      nparms_modelgroups=None,
                      dot_colors=None,
                      dot_markers=None,
                      fill_styles=None,
                      plot_stat='r_test',
                      plot_medians=False,
                      labeled_models=None,
                      show_legend=True):
    """Plot model accuracy against free parameters per neuron (Pareto plot).

    Each model group is drawn on ``ax`` as one line.

    * x is the mean ``n_parms`` of each model. For population groups
      (labels without 'single' and other than 'LN'/'stp') it is divided by
      ``len(cellids)``.
    * y is the mean (or median, if ``plot_medians``) of ``plot_stat``
      across ``cellids``.

    Parameters
    ----------
    batch
        Batch id passed to ``nd.batch_comp``.
    modelgroups : dict
        Group label -> list of model names to plot.
    ax
        Matplotlib axes to draw on.
    cellids : list
        Cells to include. Its length is also the cells-per-site divisor.
    nparms_modelgroups : dict, optional
        Same keys as ``modelgroups``; the models whose ``n_parms`` are used.
        Defaults to a copy of ``modelgroups``.
    dot_colors, dot_markers : dict
        Group label -> line color / marker. Must be provided for every
        group.
    fill_styles : dict, optional
        Group label -> marker fill style. Defaults to 'full' for every key
        of ``dot_colors``.
    plot_stat : str
        Statistic requested from ``nd.batch_comp`` for the y axis.
    plot_medians : bool
        Use the median across cells instead of the mean.
    labeled_models : list, optional
        Models to highlight with an open black circle.
    show_legend : bool
        Draw a legend in the lower-right corner.

    Returns
    -------
    tuple
        ``(ax, b_ceiling, all_model_means, labels)``:

        * ``b_ceiling`` is the per-cell ``plot_stat`` table of the last
          group;
        * ``all_model_means`` holds one mean/median Series per group;
        * ``labels`` are the legend labels.
    """
    if labeled_models is None:
        labeled_models = []
    if nparms_modelgroups is None:
        nparms_modelgroups = copy.copy(modelgroups)
    if fill_styles is None:
        fill_styles = {k: 'full' for k in dot_colors}
    # NAT4 dataset, so all cellids are used
    mean_cells_per_site = len(cellids)
    overall_min = 100
    overall_max = -100

    all_model_means = []
    labeled_data = []
    for k, modelnames in modelgroups.items():
        np_modelnames = nparms_modelgroups[k]
        b_ceiling = nd.batch_comp(batch,
                                  modelnames,
                                  cellids=cellids,
                                  stat=plot_stat)
        b_n = nd.batch_comp(batch,
                            np_modelnames,
                            cellids=cellids,
                            stat='n_parms')
        model_mean = b_ceiling.median() if plot_medians else b_ceiling.mean()
        b_m = np.array(model_mean)
        print(
            f"{k} modelcount {len(modelnames)} fits per model: {b_n.count().values}"
        )
        n_parms = np.array([np.mean(b_n[m]) for m in np_modelnames])

        # don't divide by cells per site if only one cell was fit
        if ('single' not in k) and (k != 'LN') and (k != 'stp'):
            n_parms = n_parms / mean_cells_per_site

        y_max = b_m.max() * 1.05
        y_min = b_m.min() * 0.9
        overall_max = max(overall_max, y_max)
        overall_min = min(overall_min, y_min)

        ax.plot(n_parms,
                b_m,
                color=dot_colors[k],
                marker=dot_markers[k],
                label=k,
                markersize=4.5,
                fillstyle=fill_styles[k])
        for m in labeled_models:
            if m in modelnames:
                i = modelnames.index(m)
                labeled_data.append([n_parms[i], b_m[i]])

        all_model_means.append(model_mean)

    # Circle the highlighted models. Skip this when none were found:
    # otherwise ax.plot receives no coordinates and would treat the format
    # string as data.
    if labeled_data:
        label_x, label_y = zip(*labeled_data)
        ax.plot(label_x,
                label_y,
                linestyle='none',
                color='black',
                marker='o',
                fillstyle='none',
                markersize=10)

    handles, labels = ax.get_legend_handles_labels()
    if show_legend:
        ax.legend(handles,
                  labels,
                  loc='lower right',
                  fontsize=7,
                  frameon=False)
    ax.set_xlabel('Free parameters per neuron')
    # Label the summary statistic that was actually plotted.
    summary_name = 'Median' if plot_medians else 'Mean'
    ax.set_ylabel(f'{summary_name} prediction correlation')
    ax.set_ylim((overall_min, overall_max))
    ax_remove_box(ax)
    plt.tight_layout()

    return ax, b_ceiling, all_model_means, labels
コード例 #29
0
    n2 = modelnames[-2]
elif 1:
    batch = 289
    modelnames = [
        "ozgf.fs100.ch18-ld-sev_dlog-wc.18x3-fir.3x15-lvl.1-dexp.1_init-basic",
        "ozgf.fs100.ch18-ld-sev_dlog-wc.18x3-stp.3-fir.3x15-lvl.1-dexp.1_init-basic"
    ]
    #modelnames = ["ozgf.fs100.ch18-ld-sev_dlog-wc.18x3.g-fir.3x15-lvl.1-dexp.1_init-basic",
    #              "ozgf.fs100.ch18-ld-sev_dlog-wc.18x3.g-stp.3-fir.3x15-lvl.1-dexp.1_init-basic"]
    n1 = modelnames[0]
    n2 = modelnames[1]
    fileprefix = "fig9.NAT"

# Correlation range (presumably axis limits for the plots further down).
xc_range = [-0.05, 1.1]

# 'r_ceiling', 'r_test' and 'se_test' per cell for each model; `batch`,
# `modelnames`, `n1` and `n2` are set earlier in the script.
df = nd.batch_comp(batch, modelnames, stat='r_ceiling')
df_r = nd.batch_comp(batch, modelnames, stat='r_test')
df_e = nd.batch_comp(batch, modelnames, stat='se_test')

cellcount = len(df)

beta1 = df[n1]
beta2 = df[n2]
beta1_test = df_r[n1]
beta2_test = df_r[n2]
se1 = df_e[n1]
se2 = df_e[n2]

# Cap r_ceiling values at 1.
# NOTE(review): beta1/beta2 are column selections of df, so depending on the
# pandas version this in-place assignment may also modify df, emit a
# SettingWithCopyWarning, or (under copy-on-write) leave df unchanged --
# confirm which is intended.
beta1[beta1 > 1] = 1
beta2[beta2 > 1] = 1
コード例 #30
0
def sparseness_by_batch(batch,
                        modelnames=None,
                        pop_reference_model=ALL_FAMILY_POP[2],
                        save_path=None,
                        force_regenerate=False,
                        rec=None):
    """Compute response and prediction sparseness for a batch's cells.

    Cells come from ``get_significant_cells(batch, SIG_TEST_MODELS)``. A
    cell is processed only if its best ``nd.batch_comp`` value across
    ``modelnames`` is > 0; cells where that value is NaN are skipped.

    For each processed cell, every model in ``modelnames`` is evaluated on
    the validation recording. The cell's response and the model prediction
    (both scaled by the sampling rate ``fs``) are passed to ``sparseness``.

    Parameters
    ----------
    batch
        Batch id.
    modelnames : list of str, optional
        Models to evaluate. Defaults to entries 0, 2 and 3 of
        ``ALL_FAMILY_MODELS``.
    pop_reference_model : str
        Model whose loaded context supplies the validation recording when
        ``rec`` is None.
    save_path : str, optional
        CSV cache. If it exists and ``force_regenerate`` is False, it is
        loaded and returned. Otherwise the results are written there.
    force_regenerate : bool
        Ignore an existing cache file.
    rec : optional
        Validation recording to use instead of loading one. It is used as
        is (presumably already masked; TODO confirm with callers).

    Returns
    -------
    pandas.DataFrame
        One row per (cell, model) with columns ``cellid``, ``S_r``, ``S_p``,
        ``r_test``, ``r_test_all`` and ``model`` (index into
        ``modelnames``). Empty if no cell qualifies.
    """
    if save_path is not None and not force_regenerate:
        if os.path.exists(save_path):
            return pd.read_csv(save_path, index_col=0)

    if modelnames is None:
        modelnames = [
            ALL_FAMILY_MODELS[0], ALL_FAMILY_MODELS[2], ALL_FAMILY_MODELS[3]
        ]

    d = nd.batch_comp(batch, modelnames)
    cellids = get_significant_cells(batch, SIG_TEST_MODELS, as_list=True)

    if rec is not None:
        val = rec
    elif cellids:
        # Any cell works for loading the population model's validation set.
        _, ctx0 = load_model_xform(cellids[0],
                                   batch,
                                   pop_reference_model,
                                   eval_model=True)
        val = ctx0['val'].apply_mask()
        del ctx0  # release the rest of the loaded context
    else:
        # No significant cells: nothing will be evaluated below.
        val = None

    # Collect rows in a list and build the frame once at the end
    # (DataFrame.append was removed in pandas 2.0 and was quadratic anyway).
    rows = []
    for j, cellid in enumerate(cellids):
        r_test_all = d.loc[cellid].values.max()
        # `not (x > 0)` also skips NaN, matching a plain `if x > 0` guard.
        if not r_test_all > 0:
            continue
        for i, m in enumerate(modelnames):
            xf, ctx = load_model_xform(cellid, batch, m, eval_model=False)
            val_ = ctx['modelspec'].evaluate(val)

            fs = val['resp'].fs
            this_resp = val['resp'].extract_channels(
                chans=[cellid])._data[0, :] * fs
            this_pred = val_['pred']._data[0, :] * fs

            S_r, S_p, c = sparseness(this_resp, this_pred, verbose=False)
            print(
                f"{j} {cellid} r_test={c:.3f} orig r_test={d.loc[cellid].values[i]:.3f} S_r={S_r:.3f} S_p={S_p:.3f}"
            )

            rows.append({
                'cellid': cellid,
                'S_r': S_r,
                'S_p': S_p,
                'r_test': c,
                'r_test_all': r_test_all,
                'model': i
            })

    sparseness_data = pd.DataFrame(rows)

    if save_path is not None:
        print(f"Saving sparseness_data to {save_path}")
        sparseness_data.to_csv(save_path)

    return sparseness_data