Example #1
def load_snr_and_cellids(a1, peg, a1_snr_path, peg_snr_path):
    a1_significant_cells = get_significant_cells(a1,
                                                 SIG_TEST_MODELS,
                                                 as_list=True)
    a1_results = nd.batch_comp(batch=a1,
                               modelnames=SIG_TEST_MODELS,
                               stat=PLOT_STAT,
                               cellids=a1_significant_cells)
    a1_index = a1_results.index.values
    peg_significant_cells = get_significant_cells(peg,
                                                  SIG_TEST_MODELS,
                                                  as_list=True)
    peg_results = nd.batch_comp(batch=peg,
                                modelnames=SIG_TEST_MODELS,
                                stat=PLOT_STAT,
                                cellids=peg_significant_cells)
    peg_index = peg_results.index.values

    a1_snr_df = snr_by_batch(a1, 'ozgf.fs100.ch18',
                             load_path=a1_snr_path).loc[a1_index]
    peg_snr_df = snr_by_batch(peg, 'ozgf.fs100.ch18',
                              load_path=peg_snr_path).loc[peg_index]
    a1_snr = a1_snr_df.values.flatten()
    a1_cellids = a1_snr_df.index
    peg_snr = peg_snr_df.values.flatten()
    peg_idx = np.argsort(peg_snr_df.values, axis=None)
    peg_snr_sample = peg_snr[peg_idx]
    peg_cellids = peg_snr_df.index[peg_idx].values

    return a1_snr, peg_snr_sample, a1_snr_df, peg_snr_df, a1_cellids, peg_cellids
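A minimal usage sketch for the function above, assuming it sits in a module that already defines SIG_TEST_MODELS, PLOT_STAT, nd, get_significant_cells and snr_by_batch; the batch numbers and pickle paths below are placeholders, not values from the original analysis.

import numpy as np

# Hypothetical batch IDs and precomputed SNR pickle paths -- substitute real ones.
A1_BATCH, PEG_BATCH = 322, 323
A1_SNR_PATH = '/path/to/a1_snr.pkl'
PEG_SNR_PATH = '/path/to/peg_snr.pkl'

(a1_snr, peg_snr_sample, a1_snr_df, peg_snr_df,
 a1_cellids, peg_cellids) = load_snr_and_cellids(
    A1_BATCH, PEG_BATCH, A1_SNR_PATH, PEG_SNR_PATH)

# peg_snr_sample is returned sorted by increasing SNR; a1_snr keeps the result index order.
print(f'A1: {len(a1_cellids)} cells, median SNR {np.median(a1_snr):.3f}')
print(f'PEG: {len(peg_cellids)} cells, median SNR {np.median(peg_snr_sample):.3f}')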
Example #2
def sanity_check_LN(batch, modelnames, save_path=None, load_path=None):
    if load_path is not None:
        corrs = pd.read_pickle(load_path)
        return corrs
    else:
        cellids = []
        pop_vs_LN = []
    significant_cells = get_significant_cells(batch,
                                              SIG_TEST_MODELS,
                                              as_list=True)

    for cellid in significant_cells:
        # Load and evaluate each model, pull out validation pred signal for each one.
        contexts = [
            xhelp.load_model_xform(cellid, batch, m, eval_model=True)[1]
            for m in modelnames
        ]
        preds = [
            c['val'].apply_mask()['pred'].as_continuous() for c in contexts
        ]

        # Compute correlation between each pair of models, append to running list.
        # 0: conv2d, 1: conv1dx2+d, 2: LN_pop,  # TODO: if EQUIVALENCE_MODELS changes, this needs to change as well
        pop_vs_LN.append(np.corrcoef(
            preds[0], preds[1])[0, 1])  # correlate conv1dx2+d with LN_pop
        cellids.append(cellid)

        # Convert to dataframe and save after each cell, in case there's a crash.
        corrs = {'cellid': cellids, 'pop_vs_LN': pop_vs_LN}
        corrs = pd.DataFrame.from_dict(corrs)
        corrs.set_index('cellid', inplace=True)
        if save_path is not None:
            corrs.to_pickle(save_path)

    return corrs
Example #3
def generate_psth_correlations_pop(batch,
                                   modelnames,
                                   save_path=None,
                                   load_path=None):
    if load_path is not None:
        corrs = pd.read_pickle(load_path)
        return corrs
    else:
        cellids = []
        c2d_c1d = []
        c2d_LN = []
        c1d_LN = []
    significant_cells = get_significant_cells(batch,
                                              SIG_TEST_MODELS,
                                              as_list=True)

    # Load and evaluate each model, pull out validation pred signal for each one.
    contexts = [
        xhelp.load_model_xform(significant_cells[0], batch, m,
                               eval_model=True)[1] for m in modelnames
    ]
    preds = [c['val'].apply_mask()['pred'] for c in contexts]
    chans = preds[0].chans  # not all the models load chans for some reason
    for i, _ in enumerate(preds[1:]):
        preds[i + 1].chans = chans
    preds = [
        p.extract_channels(significant_cells).as_continuous() for p in preds
    ]

    for i, cellid in enumerate(significant_cells):
        # Compute correlation between each pair of models, append to running list.
        # 0: conv2d, 1: conv1dx2+d, 2: LN_pop,  # TODO: if EQUIVALENCE_MODELS_POP changes, this needs to change as well
        c2d_c1d.append(np.corrcoef(
            preds[0][i], preds[1][i])[0,
                                      1])  # correlate conv2d with conv1dx2+d
        c2d_LN.append(np.corrcoef(
            preds[0][i], preds[2][i])[0, 1])  # correlate conv2d with LN_pop
        c1d_LN.append(np.corrcoef(
            preds[1][i], preds[2][i])[0,
                                      1])  # correlate conv1dx2+d with LN_pop
        cellids.append(cellid)

    # Convert to dataframe and save.
    corrs = {
        'cellid': cellids,
        'c2d_c1d': c2d_c1d,
        'c2d_LN': c2d_LN,
        'c1d_LN': c1d_LN
    }
    corrs = pd.DataFrame.from_dict(corrs)
    corrs.set_index('cellid', inplace=True)
    if save_path is not None:
        corrs.to_pickle(save_path)

    return corrs
Example #4
def plot_pred_scatter(batch, modelnames, labels=None, colors=None, ax=None):

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.figure
    if labels is None:
        labels = ['model 1', 'model 2']
    if colors is None:
        colors = ['black', 'black']

    significant_cells = get_significant_cells(batch,
                                              SIG_TEST_MODELS,
                                              as_list=True)

    sig_scores = nd.batch_comp(batch,
                               modelnames,
                               cellids=significant_cells,
                               stat='r_test')
    se_scores = nd.batch_comp(batch,
                              modelnames,
                              cellids=significant_cells,
                              stat='se_test')
    ceiling_scores = nd.batch_comp(batch,
                                   modelnames,
                                   cellids=significant_cells,
                                   stat=PLOT_STAT)
    nonsig_cells = list(set(sig_scores.index) - set(significant_cells))

    # figure out units with significant differences between models
    sig = (sig_scores[modelnames[1]] - se_scores[modelnames[1]] > sig_scores[modelnames[0]] + se_scores[modelnames[0]]) | \
          (sig_scores[modelnames[0]] - se_scores[modelnames[0]] > sig_scores[modelnames[1]] + se_scores[modelnames[1]])
    group1 = (ceiling_scores.loc[~sig, modelnames[0]].values,
              ceiling_scores.loc[~sig, modelnames[1]].values)
    group2 = (ceiling_scores.loc[sig, modelnames[0]].values,
              ceiling_scores.loc[sig, modelnames[1]].values)
    n_nonsig = group1[0].size
    n_sig = group2[0].size

    scatter_groups([group1, group2], ['lightgray', 'black'],
                   ax=ax,
                   labels=['N.S.', 'p < 0.05'])
    #ax.set_title(f'{batch_str[batch]} {PLOT_STAT}')
    ax.set_xlabel(
        f'{labels[0]}\n(median r={ceiling_scores[modelnames[0]].median():.3f})',
        color=colors[0])
    ax.set_ylabel(
        f'{labels[1]}\n(median r={ceiling_scores[modelnames[1]].median():.3f})',
        color=colors[1])

    return fig, n_sig, n_nonsig
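A hedged usage sketch for plot_pred_scatter, assuming ALL_FAMILY_MODELS and scatter_groups are defined in the surrounding module; the batch number, model indices and labels are placeholders.

import matplotlib.pyplot as plt

# Compare two models from ALL_FAMILY_MODELS for an assumed batch.
fig, ax = plt.subplots(figsize=(4, 4))
_, n_sig, n_nonsig = plot_pred_scatter(
    322, [ALL_FAMILY_MODELS[3], ALL_FAMILY_MODELS[2]],
    labels=['model A', 'model B'], ax=ax)
print(f'{n_sig} units differ significantly, {n_nonsig} do not')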
Example #5
def plot_conv_scatters(batch):
    significant_cells = get_significant_cells(batch,
                                              SIG_TEST_MODELS,
                                              as_list=True)

    sig_scores = nd.batch_comp(batch,
                               ALL_FAMILY_MODELS,
                               cellids=significant_cells,
                               stat='r_test')
    se_scores = nd.batch_comp(batch,
                              ALL_FAMILY_MODELS,
                              cellids=significant_cells,
                              stat='se_test')
    ceiling_scores = nd.batch_comp(batch,
                                   ALL_FAMILY_MODELS,
                                   cellids=significant_cells,
                                   stat=PLOT_STAT)
    nonsig_cells = list(set(sig_scores.index) - set(significant_cells))

    fig, ax = plt.subplots(1, 3, figsize=(12, 4))

    # LN vs DNN-single
    plot_pred_scatter(batch, [ALL_FAMILY_MODELS[3], ALL_FAMILY_MODELS[4]],
                      labels=['pop-LN', 'single-CNN'],
                      ax=ax[0])

    # LN vs conv1dx2+d
    plot_pred_scatter(batch, [ALL_FAMILY_MODELS[3], ALL_FAMILY_MODELS[2]],
                      labels=['pop-LN', '1Dx2-CNN'],
                      ax=ax[1])

    # conv2d vs conv1dx2+d
    sig = (sig_scores[ALL_FAMILY_MODELS[0]] - se_scores[ALL_FAMILY_MODELS[0]] > sig_scores[ALL_FAMILY_MODELS[2]] + se_scores[ALL_FAMILY_MODELS[2]]) | \
        (sig_scores[ALL_FAMILY_MODELS[2]] - se_scores[ALL_FAMILY_MODELS[2]] > sig_scores[ALL_FAMILY_MODELS[0]] + se_scores[ALL_FAMILY_MODELS[0]])
    group3 = (ceiling_scores.loc[~sig, ALL_FAMILY_MODELS[0]].values,
              ceiling_scores.loc[~sig, ALL_FAMILY_MODELS[2]].values)
    group4 = (ceiling_scores.loc[sig, ALL_FAMILY_MODELS[0]].values,
              ceiling_scores.loc[sig, ALL_FAMILY_MODELS[2]].values)
    scatter_groups([group3, group4], ['lightgray', 'black'], ax=ax[2])
    ax[2].set_title('batch %d, %s, conv2d vs conv1dx2+d' % (batch, PLOT_STAT))
    ax[2].set_xlabel(
        f'DNN (2D conv) pred. correlation ({ceiling_scores[ALL_FAMILY_MODELS[0]].mean():.3f})')
    ax[2].set_ylabel(
        f'DNN (1D conv) pred. correlation ({ceiling_scores[ALL_FAMILY_MODELS[2]].mean():.3f})')

    plt.tight_layout()

    return fig
Example #6
def bar_mean(batch, modelnames, stest=SIG_TEST_MODELS, ax=None):

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.figure

    cellids = get_significant_cells(batch, stest, as_list=True)
    r_values = nd.batch_comp(batch,
                             modelnames,
                             cellids=cellids,
                             stat=PLOT_STAT)

    # Bar Plot -- Median for each model
    # NOTE: ordering of names is assuming ALL_FAMILY_MODELS is being used and has not changed.
    bar_colors = [DOT_COLORS[k] for k in shortnames]
    medians = r_values.median(axis=0).values
    ax.bar(np.arange(0, len(modelnames)),
           medians,
           color=bar_colors,
           edgecolor='black',
           linewidth=1,
           tick_label=shortnames)
    ax.set_ylabel('Median prediction\ncorrelation')
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')

    # Test significance for all comparisons
    stats_results = {}
    reduced_modelnames = modelnames.copy()
    reduced_shortnames = shortnames.copy()
    for m1, s1 in zip(modelnames, shortnames):
        # drop the current model so each pair is only tested once
        reduced_modelnames.pop(0)
        reduced_shortnames.pop(0)
        # compare each model to every other model
        for m2, s2 in zip(reduced_modelnames, reduced_shortnames):
            stats_test = st.wilcoxon(r_values[m1],
                                     r_values[m2],
                                     alternative='two-sided')
            #stats_test = arrays_to_p(r_values[m1], r_values[m2], cellids, twosided=True)
            key = f'{s1} vs {s2}'
            stats_results[key] = stats_test

    return ax, medians, stats_results
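A sketch of how the returned medians and Wilcoxon results might be inspected, assuming ALL_FAMILY_MODELS, shortnames and DOT_COLORS exist at module level; the batch number is a placeholder.

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax, medians, stats_results = bar_mean(322, ALL_FAMILY_MODELS, ax=ax)
for comparison, test in stats_results.items():
    # each entry is a scipy.stats Wilcoxon result with .statistic and .pvalue
    print(f'{comparison}: W = {test.statistic:.1f}, p = {test.pvalue:.3e}')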
Example #7
def scatter_titan(batch):
    PLOT_STAT = 'r_test'
    TITAN_MODEL = 'ozgf.fs100.ch18.pop-loadpop-norm.l1-popev_wc.18x70.g-fir.1x15x70-relu.70.f-wc.70x80-fir.1x10x80-relu.80.f-wc.80x100-relu.100-wc.100xR-lvl.R-dexp.R_tfinit.n.mc50.lr1e3.es20-newtf.n.mc100.lr1e4.exa'
    REFERENCE_MODEL = 'ozgf.fs100.ch18.pop-loadpop-norm.l1-popev_wc.18x70.g-fir.1x15x70-relu.70.f-wc.70x80-fir.1x10x80-relu.80.f-wc.80x100-relu.100-wc.100xR-lvl.R-dexp.R_tfinit.n.lr1e3.et3.es20-newtf.n.lr1e4'
    MODELS = [REFERENCE_MODEL, TITAN_MODEL]

    significant_cells = get_significant_cells(batch, MODELS, as_list=True)

    sig_scores = nd.batch_comp(batch,
                               MODELS,
                               cellids=significant_cells,
                               stat='r_test')
    se_scores = nd.batch_comp(batch,
                              MODELS,
                              cellids=significant_cells,
                              stat='se_test')
    ceiling_scores = nd.batch_comp(batch,
                                   MODELS,
                                   cellids=significant_cells,
                                   stat=PLOT_STAT)
    nonsig_cells = list(set(sig_scores.index) - set(significant_cells))

    fig, ax = plt.subplots(1, 1, figsize=(4, 4))

    # LN vs DNN-single
    sig = (sig_scores[MODELS[1]] - se_scores[MODELS[1]] > sig_scores[MODELS[0]] + se_scores[MODELS[0]]) | \
          (sig_scores[MODELS[0]] - se_scores[MODELS[0]] > sig_scores[MODELS[1]] + se_scores[MODELS[1]])
    group1 = (ceiling_scores.loc[~sig, MODELS[0]].values,
              ceiling_scores.loc[~sig, MODELS[1]].values)
    group2 = (ceiling_scores.loc[sig, MODELS[0]].values,
              ceiling_scores.loc[sig, MODELS[1]].values)

    scatter_groups([group1, group2], ['lightgray', 'black'], ax=ax)
    ax.set_title('batch %d, %s, Conv1d vs Titan' % (batch, PLOT_STAT))
    ax.set_xlabel(f'Single stim ({ceiling_scores[MODELS[0]].mean():.3f})',
                  color='orange')
    ax.set_ylabel(f'Titan ({ceiling_scores[MODELS[1]].mean():.3f})',
                  color='lightgreen')

    return fig
Example #8
def generate_heldout_plots(batch,
                           batch_name,
                           sig_test_models=SIG_TEST_MODELS,
                           ax=None,
                           hide_xaxis=False):

    significant_cells = get_significant_cells(batch,
                                              sig_test_models,
                                              as_list=True)
    #short_names = ['conv2d', 'conv1d', 'conv1dx2+d', 'LN_pop', 'dnn1']
    short_names = ['1Dx2-CNN', 'pop-LN', 'single-CNN']
    if len(short_names) != len(HELDOUT):
        raise ValueError(
            'length of short_names must equal number of models in HELDOUT / MATCHED'
        )
    r_ceilings = get_heldout_results(batch, significant_cells, short_names)

    r_ceilings['cellid'] = significant_cells
    reference_results = nd.batch_comp(batch,
                                      sig_test_models,
                                      cellids=significant_cells,
                                      stat=PLOT_STAT)
    reference_medians = [
        reference_results[m].median() for m in sig_test_models
    ]

    heldout_names = [n + ' held' for n in short_names]
    matched_names = [n + ' match' for n in short_names]
    tests = [
        st.mannwhitneyu(r_ceilings[x], r_ceilings[y], alternative='two-sided')
        for x, y in zip(heldout_names, matched_names)
    ]
    median_diffs = [
        r_ceilings[y].median() - r_ceilings[x].median()
        for x, y in zip(heldout_names, matched_names)
    ]

    df = pd.DataFrame.from_dict(r_ceilings)
    value_vars = [
        col for sublist in zip(heldout_names, matched_names) for col in sublist
    ]
    results = pd.melt(df,
                      id_vars='cellid',
                      value_vars=value_vars,
                      value_name=PLOT_STAT)
    results = results.rename(columns={'variable': 'model'})
    results['hue_tag'] = np.zeros(results.shape[0], )
    for i, n in enumerate(short_names):
        results.loc[results['model'].str.contains(n), 'hue_tag'] = i

    if ax is None:
        _, ax = plt.subplots()
    else:
        plt.sca(ax)
    tres = results.loc[(results[PLOT_STAT] < 1) & (results[PLOT_STAT] > -0.05)]
    #                  palette=[DOT_COLORS['conv1dx2+d'],DOT_COLORS['conv1dx2+d'],
    #                          DOT_COLORS['LN_pop'],DOT_COLORS['LN_pop'],
    #                          DOT_COLORS['dnn1_single'],DOT_COLORS['dnn1_single']],
    sns.stripplot(x='model',
                  y=PLOT_STAT,
                  hue='hue_tag',
                  data=tres,
                  zorder=0,
                  order=value_vars,
                  jitter=0.2,
                  ax=ax,
                  palette=[
                      DOT_COLORS['1Dx2-CNN'], DOT_COLORS['pop-LN'],
                      DOT_COLORS['single-CNN']
                  ],
                  size=2)
    ax.legend_.remove()
    sns.boxplot(x='model',
                y=PLOT_STAT,
                data=tres,
                boxprops={
                    'facecolor': 'None',
                    'linewidth': 1
                },
                showcaps=False,
                showfliers=False,
                whiskerprops={'linewidth': 0},
                order=value_vars,
                ax=ax)
    #plt.title('%s' % batch_name)

    labels = [e.get_text() for e in ax.get_xticklabels()]
    ticks = ax.get_xticks()
    w = 0.1
    for idx, model in enumerate(labels):
        j = idx // 2  # one (heldout, matched) column pair per model
        plt.hlines(reference_medians[j],
                   ticks[idx] - w,
                   ticks[idx] + w,
                   color='black',
                   linewidth=2)

    ax.set_ylim(-0.05, 1)
    ax.set_xlabel('')
    plt.xticks(rotation=45, fontsize=6, ha='right')
    if hide_xaxis:
        ax.xaxis.set_visible(False)
    plt.tight_layout()

    return list(tests), significant_cells, r_ceilings, reference_medians, median_diffs
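A usage sketch under the assumption that HELDOUT, get_heldout_results and the other module-level constants are available; the batch number and area name are placeholders.

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
tests, cells, r_ceilings, ref_medians, median_diffs = generate_heldout_plots(
    322, 'A1', ax=ax)
for name, test, diff in zip(['1Dx2-CNN', 'pop-LN', 'single-CNN'], tests, median_diffs):
    print(f'{name}: Mann-Whitney p = {test.pvalue:.3e}, median diff = {diff:.3f}')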
Example #9
def plot_matched_snr(a1,
                     peg,
                     a1_snr_path,
                     peg_snr_path,
                     plot_sanity_check=True,
                     ax=None,
                     inset_ax=None):
    modelnames = SIG_TEST_MODELS

    # Load model performance results for a1 and peg
    a1_significant_cells = get_significant_cells(a1, modelnames, as_list=True)
    a1_results = nd.batch_comp(batch=a1,
                               modelnames=modelnames,
                               stat=PLOT_STAT,
                               cellids=a1_significant_cells)
    a1_index = a1_results.index.values
    a1_medians = [a1_results[m].median() for m in modelnames]

    peg_significant_cells = get_significant_cells(peg,
                                                  modelnames,
                                                  as_list=True)
    peg_results = nd.batch_comp(batch=peg,
                                modelnames=modelnames,
                                stat=PLOT_STAT,
                                cellids=peg_significant_cells)
    peg_index = peg_results.index.values
    peg_medians = [peg_results[m].median() for m in modelnames]

    a1_snr_df = snr_by_batch(a1, 'ozgf.fs100.ch18',
                             load_path=a1_snr_path).loc[a1_index]
    peg_snr_df = snr_by_batch(peg, 'ozgf.fs100.ch18',
                              load_path=peg_snr_path).loc[peg_index]

    # put peg snr in increasing order
    a1_snr = a1_snr_df.values.flatten()
    a1_median_snr = np.median(a1_snr)
    a1_cellids = a1_snr_df.index.values
    peg_snr = peg_snr_df.values.flatten()
    peg_median_snr = np.median(peg_snr)
    peg_idx = np.argsort(peg_snr_df.values, axis=None)
    peg_cellids = peg_snr_df.index[peg_idx].values
    peg_snr_sample = peg_snr[peg_idx]
    test_snr = st.mannwhitneyu(a1_snr, peg_snr, alternative='two-sided')

    # force "exact" distribution match for given histogram bins
    a1_matched_cellids, peg_matched_cellids, bins = get_matched_snr_cells(
        a1_snr, peg_snr_sample, a1_cellids, peg_cellids)
    a1_matched_snr_df = a1_snr_df.loc[a1_matched_cellids]
    peg_matched_snr_df = peg_snr_df.loc[peg_matched_cellids]

    a1_snr_df2 = a1_snr_df.rename(columns={'snr': 'A1'})
    peg_snr_df2 = peg_snr_df.rename(columns={'snr': 'PEG'})
    combined_df = a1_snr_df2.join(peg_snr_df2, how='outer')

    combined_df = combined_df.reset_index(0)
    combined_df = pd.melt(combined_df,
                          value_vars=['A1', 'PEG'],
                          value_name='SNR',
                          id_vars='cellid')
    combined_df['matched'] = 0
    combined_df.loc[combined_df.cellid.isin(
        a1_matched_cellids + peg_matched_cellids), 'matched'] = 1
    combined_df = combined_df.sort_values(by='matched')

    a1_matched_snr = a1_matched_snr_df.values
    peg_matched_snr = peg_matched_snr_df.values
    a1_median_snr_matched = np.median(a1_matched_snr)
    peg_median_snr_matched = np.median(peg_matched_snr)

    if plot_sanity_check:
        # Sanity check: these should definitely be the same
        fig = plt.figure()
        plt.hist(peg_matched_snr, bins=bins, alpha=0.5, label='peg')
        plt.hist(a1_matched_snr, bins=bins, alpha=0.3, label='a1')
        plt.title('sanity check: these should completely overlap')
        plt.legend()

    # Filter by matched cellids,
    # then combine into single dataframe with columns for cellid, a1/peg, modelname, PLOT_STAT
    short_names = ['1Dx2 CNN', 'pop-LN', 'single-CNN']
    a1_short = [s + ' A1' for s in short_names]
    a1_rename = {k: v for k, v in zip(modelnames, a1_short)}
    a1_results = a1_results.rename(columns=a1_rename)
    a1_matched_results = a1_results.loc[a1_matched_cellids].reset_index(
        level=0)
    a1_removed_results = a1_results.loc[
        ~a1_results.index.isin(a1_matched_cellids)].reset_index(level=0)
    #a1_full_results = a1_results.reset_index(level=0)

    peg_short = [s + ' PEG' for s in short_names]
    peg_rename = {k: v for k, v in zip(modelnames, peg_short)}
    peg_results = peg_results.rename(columns=peg_rename)
    peg_matched_results = peg_results.loc[peg_matched_cellids].reset_index(
        level=0)
    peg_removed_results = peg_results[
        ~peg_results.index.isin(peg_matched_cellids)].reset_index(level=0)

    # Test significance after matching distributions
    test_c1 = st.mannwhitneyu(a1_matched_results[a1_short[0]],
                              peg_matched_results[peg_short[0]],
                              alternative='two-sided')
    test_LN = st.mannwhitneyu(a1_matched_results[a1_short[1]],
                              peg_matched_results[peg_short[1]],
                              alternative='two-sided')
    test_dnn = st.mannwhitneyu(a1_matched_results[a1_short[2]],
                               peg_matched_results[peg_short[2]],
                               alternative='two-sided')

    results_matched = pd.concat([a1_matched_results, peg_matched_results],
                                axis=0)
    results_removed = pd.concat([a1_removed_results, peg_removed_results],
                                axis=0)
    alternating_columns = [
        col for sublist in zip(a1_short, peg_short) for col in sublist
    ]

    results_matched = pd.melt(results_matched,
                              id_vars='cellid',
                              value_vars=a1_short + peg_short,
                              value_name=PLOT_STAT)
    results_matched = results_matched.rename(columns={'variable': 'model'})
    results_matched['hue_tag'] = np.zeros(results_matched.shape[0], )
    for i, n in enumerate(short_names):
        results_matched.loc[results_matched['model'].str.contains(n),
                            'hue_tag'] = i

    results_removed = pd.melt(results_removed,
                              id_vars='cellid',
                              value_vars=a1_short + peg_short,
                              value_name=PLOT_STAT)
    results_removed = results_removed.rename(columns={'variable': 'model'})
    results_removed['hue_tag'] = np.zeros(results_removed.shape[0], )
    for i, n in enumerate(short_names):
        results_removed.loc[results_removed['model'].str.contains(n),
                            'hue_tag'] = i

    if ax is None:
        _, ax = plt.subplots()
    else:
        plt.sca(ax)
    jitter = 0.2
    palette = {
        0: DOT_COLORS['1Dx2-CNN'],
        1: DOT_COLORS['pop-LN'],
        2: DOT_COLORS['single-CNN']
    }

    # plot removed cells "under" the remaining ones
    #tres = results_removed.loc[(results_removed[PLOT_STAT]<1) & results_removed[PLOT_STAT]>-0.05]
    # sns.stripplot(x='model', y=PLOT_STAT, data=tres, zorder=0, order=alternating_columns,
    #                    color='gray', alpha=0.5, size=2, jitter=jitter, hue='hue_tag', palette=palette, ax=ax)
    #ax.legend_.remove()

    tres = results_matched.loc[(results_matched[PLOT_STAT] < 1)
                               & (results_matched[PLOT_STAT] > -0.05)]
    sns.stripplot(x='model',
                  y=PLOT_STAT,
                  data=tres,
                  zorder=0,
                  order=alternating_columns,
                  jitter=jitter,
                  hue='hue_tag',
                  palette=palette,
                  ax=ax,
                  size=2)
    ax.legend_.remove()
    sns.boxplot(x='model',
                y=PLOT_STAT,
                data=results_matched,
                boxprops={
                    'facecolor': 'None',
                    'linewidth': 1
                },
                showcaps=False,
                showfliers=False,
                whiskerprops={'linewidth': 0},
                order=alternating_columns,
                ax=ax)

    labels = [e.get_text() for e in ax.get_xticklabels()]
    ticks = ax.get_xticks()
    w = 0.1
    for idx, model in enumerate(labels):
        j = idx // 2  # one (A1, PEG) column pair per model
        if idx % 2 == 0:
            plt.hlines(a1_medians[j],
                       ticks[idx] - w,
                       ticks[idx] + w,
                       color='black',
                       linewidth=2)
        else:
            plt.hlines(peg_medians[j],
                       ticks[idx] - w,
                       ticks[idx] + w,
                       color='black',
                       linewidth=2)

    ax.set(ylim=(None, 1))
    plt.xticks(rotation=45, fontsize=6, ha='right')
    plt.tight_layout()

    # Plot the original distributions and the matched distribution, to visualize what was removed.
    if inset_ax is None:
        _, inset_ax = plt.subplots()
    else:
        plt.sca(inset_ax)

    # Histogram version
    # bins=bins -- this would be the actual bins used for matching the distributions, but it's a bit too fine
    # for the figure size we ended up at.
    plt.hist(a1_snr,
             bins=bins,
             label='a1',
             histtype='stepfilled',
             edgecolor='black',
             color='white')  #DOT_COLORS['pop LN'])
    plt.hist(peg_snr,
             bins=bins,
             label='peg',
             histtype='stepfilled',
             edgecolor='black',
             color='lightgray')  #DOT_COLORS['2D CNN'],)
    plt.hist(a1_matched_snr,
             bins=bins,
             label='matched',
             histtype='stepfilled',
             edgecolor='black',
             fill=False,
             hatch='\\\\\\')
    plt.legend()
    plt.ylabel('Number of neurons')
    plt.xlabel('SNR')

    # Strip plot version
    # sns.stripplot(x='variable', y='SNR', data=combined_df, hue='matched', palette={0: 'lightgray', 1: 'black'},
    #               size=2, ax=inset_ax, jitter=jitter)
    # inset_ax.legend_.remove()
    #sns.stripplot(data=combined_df_matched, palette={'A1': 'black', 'PEG': 'black'}, size=2, ax=inset_ax, jitter=jitter)

    plt.tight_layout()

    return (test_c1, test_LN, test_dnn, test_snr, a1_results.median(),
            a1_matched_results.median(), peg_results.median(),
            peg_matched_results.median(), a1_median_snr, a1_median_snr_matched,
            peg_median_snr, peg_median_snr_matched)
Example #10
def generate_psth_correlations_single(batch,
                                      modelnames,
                                      save_path=None,
                                      load_path=None,
                                      test_limit=None,
                                      force_rerun=False,
                                      skip_new_cells=True):
    if load_path is not None:
        corrs = pd.read_pickle(load_path)
        cellids = corrs.index.values.tolist()
        c2d_c1d = corrs['c2d_c1d'].values.tolist()
        c2d_LN = corrs['c2d_LN'].values.tolist()
        c1d_LN = corrs['c1d_LN'].values.tolist()
    else:
        cellids = []
        c2d_c1d = []
        c2d_LN = []
        c1d_LN = []

    significant_cells = get_significant_cells(batch,
                                              SIG_TEST_MODELS,
                                              as_list=True)
    for cellid in significant_cells[:test_limit]:
        if (cellid in cellids) and (not force_rerun):
            #print(f'skipping cellid: {cellid}')
            continue
        if skip_new_cells and (cellid not in cellids):
            # Don't stop to add new correlations for cells that weren't included in
            # a previous analysis (like if new recordings have been done for the same batch).
            continue

        # Load and evaluate each model, pull out validation pred signal for each one.
        contexts = [
            xhelp.load_model_xform(cellid, batch, m, eval_model=True)[1]
            for m in modelnames
        ]
        preds = [
            c['val'].apply_mask()['pred'].as_continuous() for c in contexts
        ]

        # Compute correlation between each pair of models, append to running list.
        # 0: conv2d, 1: conv1dx2+d, 2: LN_pop,  # TODO: if EQUIVALENCE_MODELS changes, this needs to change as well
        c2d_c1d.append(np.corrcoef(
            preds[0], preds[1])[0, 1])  # correlate conv2d with conv1dx2+d
        c2d_LN.append(np.corrcoef(preds[0],
                                  preds[2])[0,
                                            1])  # correlate conv2d with LN_pop
        c1d_LN.append(np.corrcoef(
            preds[1], preds[2])[0, 1])  # correlate conv1dx2+d with LN_pop
        cellids.append(cellid)

        # Convert to dataframe and save after each cell, in case there's a crash.
        corrs = {
            'cellid': cellids,
            'c2d_c1d': c2d_c1d,
            'c2d_LN': c2d_LN,
            'c1d_LN': c1d_LN
        }
        corrs = pd.DataFrame.from_dict(corrs)
        corrs.set_index('cellid', inplace=True)
        if save_path is not None:
            corrs.to_pickle(save_path)

    return corrs
Example #11
    f"ozgf.fs100.ch18-ld-norm.l1-sev.k25_{half_test_modelspec}_prefit.hs-tfinit.n.lr1e3.et3.es20-newtf.n.lr1e4",
    f"ozgf.fs100.ch18-ld-norm.l1-sev.k50_{half_test_modelspec}_prefit.hs-tfinit.n.lr1e3.et3.es20-newtf.n.lr1e4",
    f"ozgf.fs100.ch18-ld-norm.l1-sev_{half_test_modelspec}_prefit.hm-tfinit.n.lr1e3.et3.es20-newtf.n.lr1e4"
]

# then fit last layer on heldout cell with half the data (same est data as for modelname_half_prefit), run per cell
modelname_half_fullfit = [  #
    f"ozgf.fs100.ch18-ld-norm.l1-sev.k10_{half_test_modelspec}_prefit.htm-tfinit.n.lr1e3.et3.es20-newtf.n.lr1e4",
    f"ozgf.fs100.ch18-ld-norm.l1-sev.k15_{half_test_modelspec}_prefit.hfm-tfinit.n.lr1e3.et3.es20-newtf.n.lr1e4",
    f"ozgf.fs100.ch18-ld-norm.l1-sev.k25_{half_test_modelspec}_prefit.hqm-tfinit.n.lr1e3.et3.es20-newtf.n.lr1e4",
    f"ozgf.fs100.ch18-ld-norm.l1-sev.k50_{half_test_modelspec}_prefit.hhm-tfinit.n.lr1e3.et3.es20-newtf.n.lr1e4",
    f"ozgf.fs100.ch18-ld-norm.l1-sev_{half_test_modelspec}_prefit.hm-tfinit.n.lr1e3.et3.es20-newtf.n.lr1e4"
]

batch = 322
sig_cells = get_significant_cells(batch, SIG_TEST_MODELS, as_list=True)

d = []
pcts = ['10', '15', '25', '50', '100']
mdls = ['LN', 'dnns', 'std', 'prefit']

# don't use heldout model for final fit
_modelname_half_prefit = modelname_half_prefit.copy()
_modelname_half_prefit[-1] = modelname_half_fullfit[-1]

pre = nd.batch_comp(batch,
                    _modelname_half_prefit,
                    cellids=sig_cells,
                    stat=PLOT_STAT)
full = nd.batch_comp(batch,
                     modelname_half_fullfit,