Beispiel #1
0
def dPrime_from_NIT_site(site, duration, source, position, meta):
    """
    Loads an NTI recording for a site and runs the n-way analysis for one probe.

    The recording is loaded through baphy, cells are filtered by response reliability,
    and a raster is built for every context preceding the probe selected by
    ``duration``, ``source`` and ``position``.

    NOTE(review): ``batch`` and ``load_fs`` are not parameters; they are read from
    module scope and must be defined there.

    :param site: site id passed to baphy as ``siteid``.
    :param duration: probe duration, forwarded to ``NTI_epoch_name``.
    :param source: probe source, forwarded to ``NTI_epoch_name``.
    :param position: probe position, forwarded to ``NTI_epoch_name``.
    :param meta: analysis parameters; reads 'reliability' and 'raster_fs' here and is
        forwarded whole to ``nway_analysis``.
    :return: (real, shuffled, simulated, transitions, contexts). The first three are
        ``None`` when fewer than two contexts are found for the probe.
    """
    options = {
        'batch': batch,
        'siteid': site,
        'stimfmt': 'envelope',
        'rasterfs': load_fs,
        'recache': False,
        'runclass': 'NTI',
        'stim': False
    }
    load_URI = nb.baphy_load_recording_uri(**options)
    rec = recording.load_recording(load_URI)

    rec = set_recording_subepochs(rec)
    sig = rec['resp']

    # calculates response reliability and selects only good cells to improve analysis
    _, goodcells = signal_reliability(sig,
                                      r'\ASTIM_*',
                                      threshold=meta['reliability'])
    goodcells = goodcells.tolist()

    # matches "C<any NTI epoch or PreStimSilence>_P<this probe>" epochs
    probe_regex = NTI_epoch_name(duration, source, position)
    cp_regex = fr'\AC(({NTI_epoch_name()})|(PreStimSilence))_P{probe_regex}\Z'

    full_rast, transitions, contexts = raster_from_sig(
        sig, cp_regex, goodcells, raster_fs=meta['raster_fs'])

    # context discrimination needs at least two contexts to compare
    if len(contexts) < 2:
        real = shuffled = simulated = None
        print(f'only one context for {probe_regex}, skipping analysis')
    else:
        real, shuffled, simulated = nway_analysis(full_rast, meta)

    return real, shuffled, simulated, transitions, contexts
Beispiel #2
0
def fit_transform(site, probe, meta, part):
    """
    Fits a single-component LDA over time on the reliable cells of a site for one probe.

    :param site: site id passed to ``load``; the 'trip0' recording is used.
    :param probe: probe used to build the raster.
    :param meta: analysis parameters; reads 'reliability', 'transitions',
        'smoothing_window', 'raster_fs' and 'zscore'.
    :param part: forwarded to ``raster_from_sig`` as ``part``.
    :return: (dprime, LDA_projection, LDA_weights), where dprime holds the pairwise
        dprimes computed on the squeezed LDA projection.
    """
    recs = load(site)

    if len(recs) > 2:
        print(f'\n\n{recs.keys()}\n\n')

    sig = recs['trip0']['resp']

    # calculates response reliability and selects only good cells to improve analysis
    _, goodcells = signal_reliability(sig,
                                      r'\ASTIM_*',
                                      threshold=meta['reliability'])
    goodcells = goodcells.tolist()

    # get the full data raster Context x Probe x Rep x Neuron x Time
    raster = src.data.rasters.raster_from_sig(
        sig,
        probe,
        channels=goodcells,
        contexts=meta['transitions'],
        smooth_window=meta['smoothing_window'],
        raster_fs=meta['raster_fs'],
        part=part,
        zscore=meta['zscore'])

    # trialR shape: Trial x Cell x Context x Probe x Time
    trialR, _, _ = cdPCA.format_raster(raster)
    # squeeze only the probe axis: a bare squeeze() would also drop the Cell (or Trial)
    # axis whenever a single good cell (or trial) remains, corrupting the shape downstream
    trialR = trialR.squeeze(axis=3)

    # calculates full LDA, i.e. considering all contexts, with a single component
    LDA_projection, LDA_weights = cLDA.fit_transform_over_time(trialR, 1)
    dprime = cDP.pairwise_dprimes(LDA_projection.squeeze())

    return dprime, LDA_projection, LDA_weights
Beispiel #3
0
def dPCA_fourway_analysis(site, probe, meta):
    """
    Runs dPCA across all contexts of one probe and estimates dprime floor and ceiling.

    Real pairwise dprimes are computed on the dPCA projection. They are then compared
    against ``meta['montecarlo']`` repetitions of (a) trials drawn from normal
    distributions with each condition's across-trial mean and std (ceiling) and
    (b) trials with shuffled contexts (floor). Both are projected with the
    transformation fitted on the real data.

    NOTE(review): reads the module-level ``rec_recache``. The shuffle is cumulative:
    each iteration reshuffles the previous iteration's shuffled array. Draws from
    ``np.random.normal`` are interleaved with ``shuffle`` calls, so statement order
    matters for reproducibility under a fixed seed (assuming ``shuffle`` also uses
    numpy's global RNG — TODO confirm).

    :param site: site id passed to ``load``; the 'trip0' recording is used.
    :param probe: probe used to build the raster.
    :param meta: analysis parameters; reads 'raster_fs', 'reliability', 'transitions',
        'smoothing_window', 'zscore' and 'montecarlo'.
    :return: (dprime, shuf_dprime, sim_dprime, goodcells). shuf_dprime and sim_dprime
        have shape (montecarlo,) + dprime.shape. goodcells is the list of channels
        that passed the reliability threshold.
    """
    # recs = load(site, remote=True, rasterfs=meta['raster_fs'], recache=False)
    recs = load(site, rasterfs=meta['raster_fs'], recache=rec_recache)

    if len(recs) > 2:
        print(f'\n\n{recs.keys()}\n\n')

    rec = recs['trip0']
    sig = rec['resp']

    # calculates response reliability and selects only good cells to improve analysis
    r_vals, goodcells = signal_reliability(sig,
                                           r'\ASTIM_*',
                                           threshold=meta['reliability'])
    goodcells = goodcells.tolist()

    # get the full data raster Context x Probe x Rep x Neuron x Time
    raster = src.data.rasters.raster_from_sig(
        sig,
        probe,
        channels=goodcells,
        contexts=meta['transitions'],
        smooth_window=meta['smoothing_window'],
        raster_fs=meta['raster_fs'],
        zscore=meta['zscore'])

    # trialR shape: Trial x Cell x Context x Probe x Time; R shape: Cell x Context x Probe x Time
    trialR, R, _ = cdPCA.format_raster(raster)
    trialR, R = trialR.squeeze(axis=3), R.squeeze(axis=2)  # squeezes out probe
    Re, C, S, T = trialR.shape

    # calculates full dPCA, i.e. considering all contexts (4 categories)
    dPCA_projection, dPCA_transformation = cdPCA.fit_transform(R, trialR)
    dprime = cDP.pairwise_dprimes(dPCA_projection,
                                  observation_axis=0,
                                  condition_axis=1)

    # calculates floor (ctx shuffle) and ceiling (simulated data)
    sim_dprime = np.empty([meta['montecarlo']] + list(dprime.shape))
    shuf_dprime = np.empty([meta['montecarlo']] + list(dprime.shape))

    ctx_shuffle = trialR.copy()
    # pbar = ProgressBar()
    for rr in range(meta['montecarlo']):
        # ceiling: simulates trials from each condition's across-trial mean and std
        # (broadcast over the trial axis), then projects with the fitted transformation
        sim_trial = np.random.normal(np.mean(trialR, axis=0),
                                     np.std(trialR, axis=0),
                                     size=[Re, C, S, T])
        sim_projection = cdPCA.transform(sim_trial, dPCA_transformation)
        sim_dprime[rr, ...] = cDP.pairwise_dprimes(sim_projection,
                                                   observation_axis=0,
                                                   condition_axis=1)

        # floor: shuffles along the context axis (2); ctx_shuffle is reassigned, so the
        # shuffles accumulate across iterations
        ctx_shuffle = shuffle(ctx_shuffle, shuffle_axis=2, indie_axis=0)
        shuf_projection = cdPCA.transform(ctx_shuffle, dPCA_transformation)
        shuf_dprime[rr, ...] = cDP.pairwise_dprimes(shuf_projection,
                                                    observation_axis=0,
                                                    condition_axis=1)

    return dprime, shuf_dprime, sim_dprime, goodcells
Beispiel #4
0
def _load_site_formated_raste(site, contexts, probes, meta, recache_rec=False):
    """
    Wrapper of wrappers: loads a recording, selects the subset of data (triplets or
    permutations) and generates a raster using the selected probes and contexts.

    :param site: site id passed to ``load``.
    :param contexts: contexts (transitions) to include in the raster.
    :param probes: probes to include in the raster.
    :param meta: analysis parameters; reads 'stim_type', 'raster_fs', 'reliability',
        'smoothing_window' and 'zscore'.
    :param recache_rec: forwarded to ``load`` as ``recache``.
    :return: (raster, goodcells). raster is as returned by ``raster_from_sig``;
        goodcells is the list of channels that passed the reliability threshold.
    :raises ValueError: if meta['stim_type'] is neither 'triplets' nor 'permutations'.
    """
    # maps the stimulus type to the key of the matching recording. Validated before
    # loading so a bad configuration fails fast without running the loader.
    type_keys = {'triplets': 'trip0', 'permutations': 'perm0'}
    stim_type = meta['stim_type']
    if stim_type not in type_keys:
        raise ValueError(
            f"unknown stim type {stim_type!r}, use 'triplets' or 'permutations'")
    type_key = type_keys[stim_type]

    recs = load(site, rasterfs=meta['raster_fs'], recache=recache_rec)
    if len(recs) > 2:
        print(f'\n\n{recs.keys()}\n\n')

    sig = recs[type_key]['resp']

    # calculates response reliability and selects only good cells to improve analysis
    _, goodcells = signal_reliability(sig,
                                      r'\ASTIM_*',
                                      threshold=meta['reliability'])
    goodcells = goodcells.tolist()

    # get the full data raster Context x Probe x Rep x Neuron x Time
    raster = src.data.rasters.raster_from_sig(
        sig,
        probes=probes,
        channels=goodcells,
        contexts=contexts,
        smooth_window=meta['smoothing_window'],
        raster_fs=meta['raster_fs'],
        zscore=meta['zscore'],
        part='probe')

    return raster, goodcells
Beispiel #5
0
def fit_transform(site, probe, meta, part):
    """
    Fits dPCA on the reliable cells of a site for one probe and returns pairwise dprimes.

    Only component 0 of the 'ct' (context) marginalization is kept. Its weights are
    tiled over time so they can be used as a time-resolved transformation.

    :param site: site id passed to ``load``; the 'trip0' recording is used.
    :param probe: probe used to build the raster.
    :param meta: analysis parameters; reads 'reliability', 'transitions',
        'smoothing_window', 'raster_fs' and 'zscore'.
    :param part: forwarded to ``raster_from_sig`` as ``part``.
    :return: (dprime, dPCA_projection, dPCA_weights, dpca), where dpca is the fitted
        object returned by ``cdPCA._cpp_dPCA``.
    """
    recs = load(site)

    if len(recs) > 2:
        print(f'\n\n{recs.keys()}\n\n')

    sig = recs['trip0']['resp']

    # calculates response reliability and selects only good cells to improve analysis
    _, goodcells = signal_reliability(sig,
                                      r'\ASTIM_*',
                                      threshold=meta['reliability'])
    goodcells = goodcells.tolist()

    raster = src.data.rasters.raster_from_sig(
        sig,
        probe,
        channels=goodcells,
        contexts=meta['transitions'],
        smooth_window=meta['smoothing_window'],
        raster_fs=meta['raster_fs'],
        part=part,
        zscore=meta['zscore'])

    # trialR shape: Trial x Cell x Context x Probe x Time; R shape: Cell x Context x Probe x Time
    trialR, R, _ = cdPCA.format_raster(raster)
    # squeeze only the probe axes: a bare squeeze() would also drop the Cell axis
    # whenever a single good cell remains
    trialR, R = trialR.squeeze(axis=3), R.squeeze(axis=2)

    _, dPCA_projection, _, dpca = cdPCA._cpp_dPCA(R,
                                                  trialR,
                                                  significance=False,
                                                  dPCA_parms={})
    dPCA_projection = dPCA_projection['ct'][:, 0, ...]
    # repeats the component-0 weights along a trailing time axis of length Time
    dPCA_weights = np.tile(dpca.D['ct'][:, 0][:, None, None],
                           [1, 1, R.shape[-1]])

    dprime = cDP.pairwise_dprimes(dPCA_projection)

    return dprime, dPCA_projection, dPCA_weights, dpca
Beispiel #6
0
def dPCA_fourway_analysis(site, probe, meta):
    """
    Runs dPCA across all contexts of one probe and estimates dprime floor and ceiling.

    Component 0 of the 'ct' (context) marginalization is used as a fixed projection,
    tiled over time. Real pairwise dprimes are compared against ``meta['montecarlo']``
    repetitions of (a) trials drawn from normal distributions with each condition's
    across-trial mean and std (ceiling) and (b) trials with shuffled contexts (floor).
    Both are projected with the same weights.

    :param site: site id passed to ``load``; the 'trip0' recording is used.
    :param probe: probe used to build the raster.
    :param meta: analysis parameters; reads 'reliability', 'transitions',
        'smoothing_window', 'raster_fs', 'zscore' and 'montecarlo'.
    :return: (dprime, shuf_dprime, sim_dprime); the last two have shape
        (montecarlo,) + dprime.shape.
    """
    recs = load(site)

    if len(recs) > 2:
        print(f'\n\n{recs.keys()}\n\n')

    sig = recs['trip0']['resp']

    # calculates response reliability and selects only good cells to improve analysis
    _, goodcells = signal_reliability(sig,
                                      r'\ASTIM_*',
                                      threshold=meta['reliability'])
    goodcells = goodcells.tolist()

    # get the full data raster Context x Probe x Rep x Neuron x Time
    raster = src.data.rasters.raster_from_sig(
        sig,
        probe,
        channels=goodcells,
        contexts=meta['transitions'],
        smooth_window=meta['smoothing_window'],
        raster_fs=meta['raster_fs'],
        zscore=meta['zscore'])

    # trialR shape: Trial x Cell x Context x Probe x Time; R shape: Cell x Context x Probe x Time
    trialR, R, _ = cdPCA.format_raster(raster)
    # squeeze only the probe axes: a bare squeeze() would also drop the Cell axis when a
    # single good cell remains, breaking the 4-way unpacking below
    trialR, R = trialR.squeeze(axis=3), R.squeeze(axis=2)
    Re, C, S, T = trialR.shape

    def _fit_transform(R, trialR):
        """Fits dPCA; returns the 'ct' component-0 projection and its weights tiled over T."""
        _, dPCA_projection, _, dpca = cdPCA._cpp_dPCA(R,
                                                      trialR,
                                                      significance=False,
                                                      dPCA_parms={})
        dPCA_projection = dPCA_projection['ct'][:, 0, ]
        dPCA_transformation = np.tile(dpca.D['ct'][:, 0][:, None, None],
                                      [1, 1, T])
        return dPCA_projection, dPCA_transformation

    # calculates full dPCA, i.e. considering all contexts (4 categories)
    dPCA_projection, dPCA_transformation = _fit_transform(R, trialR)
    dprime = cDP.pairwise_dprimes(dPCA_projection)

    # calculates floor (ctx shuffle) and ceiling (simulated data)
    sim_dprime = np.empty([meta['montecarlo']] + list(dprime.shape))
    shuf_dprime = np.empty([meta['montecarlo']] + list(dprime.shape))

    ctx_shuffle = trialR.copy()

    pbar = ProgressBar()
    for rr in pbar(range(meta['montecarlo'])):
        # ceiling: simulates data, calculates dprimes
        sim_trial = np.random.normal(np.mean(trialR, axis=0),
                                     np.std(trialR, axis=0),
                                     size=[Re, C, S, T])
        sim_projection = cLDA.transform_over_time(
            cLDA._reorder_dims(sim_trial), dPCA_transformation)
        sim_dprime[rr, ...] = cDP.pairwise_dprimes(
            cLDA._recover_dims(sim_projection).squeeze())

        # floor: shuffles contexts; ctx_shuffle is reassigned, so shuffles accumulate
        ctx_shuffle = shuffle(ctx_shuffle, shuffle_axis=2, indie_axis=0)
        shuf_projection = cLDA.transform_over_time(
            cLDA._reorder_dims(ctx_shuffle), dPCA_transformation)
        shuf_dprime[rr, ...] = cDP.pairwise_dprimes(
            cLDA._recover_dims(shuf_projection).squeeze())

    return dprime, shuf_dprime, sim_dprime
Beispiel #7
0
def dPCA_site_summary(site, probe):
    """
    Fits dPCA for one probe of a site and plots the first three components of each
    marginalization, one line per context.

    Relies on the module-level ``meta``, ``rec_recache``, ``times``, ``trans_color_map``
    and the font-size globals ``ax_val_size``, ``ax_lab_size`` and ``sub_title_size``.

    :param site: site id passed to ``load``; the 'trip0' recording is used.
    :param probe: probe used to build the raster.
    :return: (fig, ax, dpca). NOTE(review): ``ax`` is only the last Axes drawn
        (bottom-right of the 2 x 3 grid), not the full grid; confirm this is what callers
        expect. ``dpca`` is the third output of ``cdPCA._cpp_dPCA``.
    """
    # loads the raw data
    recs = load(site, rasterfs=meta['raster_fs'], recache=rec_recache)
    sig = recs['trip0']['resp']

    # calculates response reliability and selects only good cells to improve analysis
    _, goodcells = signal_reliability(sig,
                                      r'\ASTIM_*',
                                      threshold=meta['reliability'])
    goodcells = goodcells.tolist()

    # get the full data raster Context x Probe x Rep x Neuron x Time
    raster = src.data.rasters.raster_from_sig(
        sig,
        probe,
        channels=goodcells,
        contexts=meta['transitions'],
        smooth_window=meta['smoothing_window'],
        raster_fs=meta['raster_fs'],
        zscore=meta['zscore'],
        part='probe')

    # trialR shape: Trial x Cell x Context x Probe x Time; R shape: Cell x Context x Probe x Time
    trialR, R, _ = cdPCA.format_raster(raster)
    trialR, R = trialR.squeeze(axis=3), R.squeeze(axis=2)  # squeezes out probe
    # the second output (trial-level projections) is not plotted
    Z, _, dpca = cdPCA._cpp_dPCA(R, trialR)

    # human-readable names of the marginalization keys, used in the y labels
    marg_labels = {'t': 'probe', 'ct': 'context'}

    fig, axes = plt.subplots(2, 3, sharex='all', sharey='row')
    for vv, (marginalization, means) in enumerate(Z.items()):
        marg_label = marg_labels.get(marginalization, marginalization)

        for pc in range(3):  # first 3 principal components
            ax = axes[vv, pc]
            for tt, trans in enumerate(meta['transitions']):  # one line per context
                ax.plot(times,
                        means[pc, tt, :],
                        label=trans,
                        color=trans_color_map[trans],
                        linewidth=2)

            # formats axes labels and ticks
            if pc == 0:  # y labels only on the first column
                ax.set_ylabel(
                    f'{marg_label} dependent\nfiring rate (z-score)',
                    fontsize=ax_lab_size)

            if vv == 0:
                ax.set_title(f'dPC #{pc + 1}', fontsize=sub_title_size)
            else:
                ax.set_xlabel('time (ms)', fontsize=ax_lab_size)

            # hides the right and top spines
            ax.spines['right'].set_visible(False)
            ax.spines['top'].set_visible(False)
            ax.tick_params(labelsize=ax_val_size)

    # legend in last axis
    axes[-1, -1].legend(loc='upper right',
                        fontsize='x-large',
                        markerscale=10,
                        frameon=False)

    return fig, ax, dpca
Beispiel #8
0
def analysis_steps_plot(id, probe, source):
    """
    Plots the analysis steps for one cell or site and one probe, one column per context pair.

    Row 0 shows the mean responses: the single-cell PSTH plus raster when
    ``source == 'SC'``, the dPCA/LDA projection otherwise. Row 1 shows the real dprime
    against the shuffled-context confidence interval, with its significance bars. For
    'dPCA' and 'LDA', row 2 shows the real dprime against the simulated-data confidence
    interval, with its significance bars.

    Relies on the module-level ``meta``, ``batch_dprimes``, ``times``, ``all_probes``,
    ``trans_color_map``, ``bar_width`` and the font-size globals.

    :param id: cell id when ``source == 'SC'`` (its first 7 characters are the site),
        otherwise the site id.
    :param probe: probe to plot; must be in ``all_probes``.
    :param source: 'SC', 'dPCA' or 'LDA'.
    :return: (fig, axes), where axes is the 2-D array of subplots.
    """
    site = id[:7] if source == 'SC' else id
    # population analyses also get a row comparing against simulated data
    has_sim_row = source in ['dPCA', 'LDA']

    # loads the raw data
    recs = load(site, rasterfs=meta['raster_fs'], recache=False)
    sig = recs['trip0']['resp']
    # calculates response reliability and selects only good cells to improve analysis
    _, goodcells = signal_reliability(sig,
                                      r'\ASTIM_*',
                                      threshold=meta['reliability'])
    goodcells = goodcells.tolist()
    # get the full data raster Context x Probe x Rep x Neuron x Time
    raster = src.data.rasters.raster_from_sig(
        sig,
        probe,
        channels=goodcells,
        contexts=meta['transitions'],
        smooth_window=meta['smoothing_window'],
        raster_fs=meta['raster_fs'],
        zscore=meta['zscore'],
        part='probe')
    # trialR shape: Trial x Cell x Context x Probe x Time; R shape: Cell x Context x Probe x Time
    trialR, R, _ = cdPCA.format_raster(raster)
    trialR, R = trialR.squeeze(axis=3), R.squeeze(axis=2)  # squeezes out probe

    if source == 'dPCA':
        projection, _ = cdPCA.fit_transform(R, trialR)
    elif source == 'LDA':
        projection, _ = cLDA.fit_transform_over_time(trialR)
        projection = projection.squeeze(axis=1)

    if meta['zscore'] is False:
        trialR = trialR * meta['raster_fs']
        # NOTE(review): only the dPCA projection is rescaled, not the LDA one; confirm intended
        if source == 'dPCA':
            projection = projection * meta['raster_fs']

    # flips signs of dprimes and montecarlos as needed
    dprimes, shuffleds = cDP.flip_dprimes(
        batch_dprimes[source]['dprime'][id],
        batch_dprimes[source]['shuffled_dprime'][id],
        flip='max')
    if has_sim_row:
        _, simulations = cDP.flip_dprimes(
            batch_dprimes[source]['dprime'][id],
            batch_dprimes[source]['simulated_dprime'][id],
            flip='max')

    # loop invariants shared by every panel
    t = times[:trialR.shape[-1]]
    transitions = meta['transitions']
    trans_pairs = list(itt.combinations(transitions, 2))
    prb_idx = all_probes.index(probe)
    cell_idx = goodcells.index(id) if source == 'SC' else None

    nrows = 2 if source == 'SC' else 3
    # one column per context pair; squeeze=False keeps `axes` 2-D even with a single pair
    fig, axes = plt.subplots(nrows, max(len(trans_pairs), 1),
                             sharex='all', sharey='row', squeeze=False)

    #  PSTH
    for pair_idx, (ctx0, ctx1) in enumerate(trans_pairs):
        t0_idx = transitions.index(ctx0)
        t1_idx = transitions.index(ctx1)

        if source == 'SC':
            mean0 = trialR[:, cell_idx, t0_idx, :].mean(axis=0)
            mean1 = trialR[:, cell_idx, t1_idx, :].mean(axis=0)
        else:
            mean0 = projection[:, t0_idx, :].mean(axis=0)
            mean1 = projection[:, t1_idx, :].mean(axis=0)

        axes[0, pair_idx].plot(t, mean0, color=trans_color_map[ctx0], linewidth=3)
        axes[0, pair_idx].plot(t, mean1, color=trans_color_map[ctx1], linewidth=3)

    # Raster, dprime, CI
    bottom, top = axes[0, 0].get_ylim()
    half = ((top - bottom) / 2) + bottom
    for pair_idx, (ctx0, ctx1) in enumerate(trans_pairs):
        if source == 'SC':
            # raster: first context in the lower half of the panel, second in the upper half
            t0_idx = transitions.index(ctx0)
            t1_idx = transitions.index(ctx1)

            _ = fplt._raster(t,
                             trialR[:, cell_idx, t0_idx, :],
                             y_offset=0,
                             y_range=(bottom, half),
                             ax=axes[0, pair_idx],
                             scatter_kws={
                                 'color': trans_color_map[ctx0],
                                 'alpha': 0.4,
                                 's': 10
                             })
            _ = fplt._raster(t,
                             trialR[:, cell_idx, t1_idx, :],
                             y_offset=0,
                             y_range=(half, top),
                             ax=axes[0, pair_idx],
                             scatter_kws={
                                 'color': trans_color_map[ctx1],
                                 'alpha': 0.4,
                                 's': 10
                             })

        # plots the real dprime and the shuffled dprime ci
        axes[1, pair_idx].plot(t, dprimes[prb_idx, pair_idx, :], color='black')
        _ = fplt._cint(t,
                       shuffleds[:, prb_idx, pair_idx, :],
                       confidence=0.95,
                       ax=axes[1, pair_idx],
                       fillkwargs={
                           'color': 'black',
                           'alpha': 0.5
                       })

        if has_sim_row:
            # plots the real dprime and simulated dprime ci
            axes[2, pair_idx].plot(t, dprimes[prb_idx, pair_idx, :], color='black')
            _ = fplt._cint(t,
                           simulations[:, prb_idx, pair_idx, :],
                           confidence=0.95,
                           ax=axes[2, pair_idx],
                           fillkwargs={
                               'color': 'black',
                               'alpha': 0.5
                           })

    # significance bars, anchored at the current bottom of each row
    ax1_bottom = axes[1, 0].get_ylim()[0]
    # needed for both 'dPCA' and 'LDA', since both draw the simulated row below
    if has_sim_row:
        ax2_bottom = axes[2, 0].get_ylim()[0]
    for pair_idx, (ctx0, ctx1) in enumerate(trans_pairs):
        # histogram of context discrimination
        axes[1, pair_idx].bar(
            t,
            batch_dprimes[source]['shuffled_significance'][id][prb_idx,
                                                               pair_idx, :],
            width=bar_width,
            align='center',
            edgecolor='white',
            bottom=ax1_bottom)
        if has_sim_row:
            # histogram of population effects
            axes[2, pair_idx].bar(t,
                                  batch_dprimes[source]['simulated_significance'][id]
                                  [prb_idx, pair_idx, :],
                                  width=bar_width,
                                  align='center',
                                  edgecolor='white',
                                  bottom=ax2_bottom)

        # formats labels; y labels only on the first column
        if pair_idx == 0:
            axes[0, pair_idx].set_ylabel('dPC', fontsize=ax_lab_size)
            axes[0, pair_idx].tick_params(labelsize=ax_val_size)
            axes[1, pair_idx].set_ylabel('dprime', fontsize=ax_lab_size)
            axes[1, pair_idx].tick_params(labelsize=ax_val_size)
            if has_sim_row:
                axes[2, pair_idx].set_ylabel('dprime', fontsize=ax_lab_size)
                axes[2, pair_idx].tick_params(labelsize=ax_val_size)

        axes[-1, pair_idx].set_xlabel('time (ms)', fontsize=ax_lab_size)
        axes[-1, pair_idx].tick_params(labelsize=ax_val_size)
        axes[0, pair_idx].set_title(f'{ctx0}_{ctx1}',
                                    fontsize=sub_title_size)

    # hides the right and top spines of every subplot
    for ax in np.ravel(axes):
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)

    return fig, axes
Beispiel #9
0
# Exploratory script: for every site, loads the permutations recording, keeps reliable
# cells and builds the full context/probe array.
# NOTE(review): this rebinds the module-level `meta` (with only two keys) and
# `all_probes`, both of which other functions in this file read as globals.
meta = {
    'reliability': 0.1,  # r value
    'smoothing_window': 20  # ms
}

for site in all_sites:
    # load and format permutations ('perm0') from a site
    recs = load(site)
    rec = recs['perm0']
    sig = rec['resp'].rasterize()

    # calculates response reliability and selects only good cells to improve analysis

    r_vals, goodcells = signal_reliability(sig,
                                           r'\ASTIM_*',
                                           threshold=meta['reliability'])
    goodcells = goodcells.tolist()

    # plots PSTHs of all probes after silence
    # fig, axes = cplot.hybrid(sig, epoch_names=r'\AC0_P\d\Z', channels=goodcells)

    # plots PSTHs of individual best probe after all contexts
    # fig, axes = cplot.hybrid(sig, epoch_names=r'\AC\d_P3\Z', channels=goodcells)

    # builds the full array for all valid context/probe combinations
    full_array, invalid_cp, valid_cp, all_contexts, all_probes = \
        tp.make_full_array(sig, channels=goodcells, smooth_window=meta['smoothing_window'])

    # get a specific probe after a set of different transitions
Beispiel #10
0
def twoway_analysis(site, probe, meta):
    """
    Runs a two-context LDA dprime analysis for every pair of contexts of one probe.

    For each pair, a single-component LDA is fitted over time. The real dprime is then
    compared against ``meta['montecarlo']`` repetitions of a ceiling (trials drawn from
    normal distributions with each condition's across-trial mean and std, projected
    with the real LDA) and a floor (context-shuffled trials, with LDA refitted each time).

    :param site: site id passed to ``load``; the 'trip0' recording is used.
    :param probe: probe used to build the rasters.
    :param meta: analysis parameters; reads 'transitions', 'reliability',
        'smoothing_window', 'raster_fs', 'zscore' and 'montecarlo'.
    :return: (dprime, shuf_dprime, sim_dprime). Per-pair results are concatenated along
        the pair axis: axis 0 for dprime, axis 1 for the Montecarlo x Pair x Time arrays.
    :raises ValueError: if fewer than two transitions are given (no pairs to compare).
    """
    # without at least one pair nothing can be concatenated; fail before loading data
    if len(meta['transitions']) < 2:
        raise ValueError(
            f"twoway_analysis needs at least two transitions, got {len(meta['transitions'])}")

    recs = load(site)

    if len(recs) > 2:
        print(f'\n\n{recs.keys()}\n\n')

    sig = recs['trip0']['resp']

    # calculates response reliability and selects only good cells to improve analysis
    _, goodcells = signal_reliability(sig,
                                      r'\ASTIM_*',
                                      threshold=meta['reliability'])
    goodcells = goodcells.tolist()

    # outer lists to save the dprimes for each pair of ctxs
    dprime = list()
    shuf_dprime = list()
    sim_dprime = list()

    for transitions in itt.combinations(meta['transitions'], 2):

        # get the full data raster Context x Probe x Rep x Neuron x Time
        raster = src.data.rasters.raster_from_sig(
            sig,
            probe,
            channels=goodcells,
            contexts=transitions,
            smooth_window=meta['smoothing_window'],
            raster_fs=meta['raster_fs'],
            zscore=meta['zscore'])

        # trialR shape: Trial x Cell x Context x Probe x Time
        trialR, _, _ = cdPCA.format_raster(raster)
        # squeeze only the probe axis: a bare squeeze() would also drop the Cell axis when
        # a single good cell remains, breaking the 4-way unpacking below
        trialR = trialR.squeeze(axis=3)
        Re, C, S, T = trialR.shape

        # calculates LDA across the two selected transitions categories
        LDA_projection, LDA_transformation = cLDA.fit_transform_over_time(
            trialR, 1)
        dp = cDP.pairwise_dprimes(LDA_projection.squeeze())
        dprime.append(dp)

        # calculates floor (ctx shuffle) and ceiling (simulated data)
        sim_dp = np.empty([meta['montecarlo']] + list(dp.shape))
        shuf_dp = np.empty([meta['montecarlo']] + list(dp.shape))

        ctx_shuffle = trialR.copy()

        pbar = ProgressBar()
        for rr in pbar(range(meta['montecarlo'])):
            # ceiling: simulates data, calculates dprimes
            sim_trial = np.random.normal(np.mean(trialR, axis=0),
                                         np.std(trialR, axis=0),
                                         size=[Re, C, S, T])
            sim_projection = cLDA.transform_over_time(
                cLDA._reorder_dims(sim_trial), LDA_transformation)
            sim_dp[rr, ...] = cDP.pairwise_dprimes(
                cLDA._recover_dims(sim_projection).squeeze())

            # floor: shuffles accumulate across iterations. NOTE(review): LDA is refitted
            # with the default number of components, unlike the real fit (1); confirm intended
            ctx_shuffle = shuffle(ctx_shuffle, shuffle_axis=2, indie_axis=0)
            shuf_projection, _ = cLDA.fit_transform_over_time(ctx_shuffle)
            shuf_dp[rr, ...] = cDP.pairwise_dprimes(shuf_projection.squeeze())

        shuf_dprime.append(shuf_dp)
        sim_dprime.append(sim_dp)

    # orders the lists into arrays of the same shape as the fourway analysis: MonteCarlo x Pair x Time
    dprime = np.concatenate(dprime, axis=0)
    shuf_dprime = np.concatenate(shuf_dprime, axis=1)
    sim_dprime = np.concatenate(sim_dprime, axis=1)

    return dprime, shuf_dprime, sim_dprime
Beispiel #11
0
def cell_dprime(site, probe, meta):
    """
    Computes single-cell dprimes between every pair of contexts, plus a shuffled-context floor.

    NOTE(review): reads the module-level ``rec_recache``.

    :param site: site id passed to ``load``; the 'trip0' recording is used.
    :param probe: probe used to build the raster.
    :param meta: analysis parameters; reads 'transitions', 'raster_fs', 'reliability',
        'smoothing_window', 'zscore' and 'montecarlo'.
    :return: (dprime, shuffled, goodcells, trans_pairs). dprime has shape
        ContextPair x Cell x Time and shuffled has shape Montecarlo x ContextPair x Cell x Time.
        goodcells lists the reliable channels. trans_pairs names each context pair
        '<ctx0>_<ctx1>' in ``itertools.combinations`` order.
    :raises ValueError: if fewer than two transitions are given (no pairs to compare).
    """
    transitions = meta['transitions']
    # without at least one pair there is nothing to stack; fail before loading data
    if len(transitions) < 2:
        raise ValueError(
            f'cell_dprime needs at least two transitions, got {len(transitions)}')

    # recs = load(site, remote=True, rasterfs=meta['raster_fs'], recache=False)
    recs = load(site, rasterfs=meta['raster_fs'], recache=rec_recache)
    if len(recs) > 2:
        print(f'\n\n{recs.keys()}\n\n')

    sig = recs['trip0']['resp']

    # calculates response reliability and selects only good cells to improve analysis
    _, goodcells = signal_reliability(sig,
                                      r'\ASTIM_*',
                                      threshold=meta['reliability'])
    goodcells = goodcells.tolist()

    # get the full data raster Context x Probe x Rep x Neuron x Time
    raster = src.data.rasters.raster_from_sig(
        sig,
        probe,
        channels=goodcells,
        contexts=transitions,
        smooth_window=meta['smoothing_window'],
        raster_fs=meta['raster_fs'],
        zscore=meta['zscore'],
        part='probe')

    # trialR shape: Trial x Cell x Context x Probe x Time
    trialR, _, _ = cdPCA.format_raster(raster)
    trialR = trialR.squeeze(axis=3)  # squeezes out probe
    rep, chn, _, tme = trialR.shape

    # context index pairs, kept as indices so names containing '_' cannot break the lookup
    pair_indices = list(itt.combinations(range(len(transitions)), 2))
    trans_pairs = [f'{transitions[i]}_{transitions[j]}' for i, j in pair_indices]

    dprime = cDP.pairwise_dprimes(
        trialR, observation_axis=0,
        condition_axis=2)  # shape ContextPair x Cell x Time

    # For each context pair, shuffles the pair's raster n times. The shuffles are stored in
    # an array with the raster's shape plus a leading dimension of size n.
    shuffled = list()
    print(f"\nshuffling {meta['montecarlo']} times")
    for pair in pair_indices:
        # only the two contexts of this pair: Trial x Cell x 2 x Time
        pair_trialR = trialR[:, :, list(pair), :].copy()
        shuf_trialR = np.full([meta['montecarlo'], rep, chn, 2, tme], np.nan)

        # every repetition shuffles the same unshuffled pair raster
        for rr in range(meta['montecarlo']):
            shuf_trialR[rr, ...] = shuffle(pair_trialR,
                                           shuffle_axis=2,
                                           indie_axis=0)

        shuffled.append(
            cDP.pairwise_dprimes(shuf_trialR,
                                 observation_axis=1,
                                 condition_axis=3))

    shuffled = np.stack(shuffled, axis=1).squeeze(axis=0).swapaxes(
        0, 1)  # shape Montecarlo x ContextPair x Cell x Time

    return dprime, shuffled, goodcells, trans_pairs