def spatial_info_perm_test(SI,
                           C,
                           position,
                           tstart,
                           tstop,
                           nperms=10000,
                           shuffled_SI=None,
                           win_trial=True):
    '''Run a permutation test on spatial information (SI) values.

    Builds a null distribution of SI by repeatedly recomputing it on
    shuffled activity, then compares each cell's observed SI against it.

    Parameters
    ----------
    SI : array-like
        Observed spatial information, one value per cell.
    C : np.ndarray
        Activity with time along axis 0 and cells along axis 1. A 1-D trace
        is treated as a single cell.
    position : np.ndarray
        Position per time sample. ``position.shape[0]`` bounds the circular
        shift used when ``win_trial`` is False.
    tstart, tstop
        Trial start / stop information forwarded unchanged to
        ``u.make_pos_bin_trial_matrices``.
    nperms : int
        Number of shuffles to compute when ``shuffled_SI`` is None.
    shuffled_SI : np.ndarray or None
        Precomputed null distribution of shape (n_shuffles, n_cells). When
        given, no shuffles are computed and ``nperms`` is not used.
    win_trial : bool
        If True, shuffling is delegated to ``u.make_pos_bin_trial_matrices``
        via ``perm=True``. If False, the whole activity matrix is circularly
        shifted in time by a random multiple of 30 samples before binning.

    Returns
    -------
    p : np.ndarray, shape (n_cells,)
        Fraction of shuffles whose SI is strictly below the observed SI
        (values close to 1 indicate significant spatial information).
    shuffled_SI : np.ndarray
        The null distribution that was used (computed or passed in).
    '''
    if np.ndim(C) < 2:
        # single-cell trace: add a cell axis so C.shape[1] is the cell count
        C = np.expand_dims(C, 1)
    n_cells = C.shape[1]

    if shuffled_SI is None:
        shuffled_SI = np.zeros([nperms, n_cells])

        for perm in range(nperms):
            if win_trial:
                C_tmat, occ_tmat, _, _ = u.make_pos_bin_trial_matrices(
                    C, position, tstart, tstop, perm=True)
            else:
                # circularly shift activity relative to position; the
                # shifted copy (not the original C) must be binned
                C_perm = np.roll(C,
                                 randrange(30, position.shape[0], 30),
                                 axis=0)
                C_tmat, occ_tmat, _, _ = u.make_pos_bin_trial_matrices(
                    C_perm, position, tstart, tstop, perm=False)

            fr = np.squeeze(np.nanmean(C_tmat, axis=0))
            occ = occ_tmat.sum(axis=0)
            occ /= occ.sum()
            shuffled_SI[perm, :] = spatial_info(fr, occ)

    observed = np.ravel(SI)[:n_cells]
    null = np.asarray(shuffled_SI)[:, :n_cells]
    # normalise by the number of shuffles actually present, which differs
    # from nperms when a precomputed shuffled_SI is supplied
    p = np.mean(observed[np.newaxis, :] > null, axis=0)

    return p, shuffled_SI
    def confusion_matrix(self, decode_dict, save=False):
        '''Build and plot a decoding confusion matrix with one column block per morph.

        Parameters
        ----------
        decode_dict : dict
            Must contain ``'pop ix'``, the decoded signal, which is binned
            over trials with ``u.make_pos_bin_trial_matrices`` using
            ``self.pos``, ``self.tstarts`` and ``self.teleports``.
        save : bool
            If True, save the figure as ``confusion_matrix.png`` under
            ``self.prefix`` (the directory is created if missing).

        Returns
        -------
        c : np.ndarray
            Shape (d_trial_mat.shape[-1], d_trial_mat.shape[1] * n_morphs).
            Column block ``n`` holds the trial average for the n-th unique
            morph value, transposed.
        (f, ax) : the matplotlib figure and axes showing ``c``.
        '''
        d_trial_mat, _, _, _ = u.make_pos_bin_trial_matrices(
            decode_dict['pop ix'], self.pos, self.tstarts, self.teleports)
        d_m_dict = u.trial_type_dict(d_trial_mat, self.trial_info['morphs'])

        keys = np.unique(self.trial_info['morphs'])
        n_bins = d_trial_mat.shape[1]
        c = np.zeros([d_trial_mat.shape[-1], n_bins * keys.shape[0]])
        for n, key in enumerate(keys.tolist()):
            # trial average for this morph, transposed so axis 1 of the
            # binned matrix runs along the columns of its block
            c[:, n * n_bins:(n + 1) * n_bins] = np.nanmean(d_m_dict[key],
                                                           axis=0).T

        f, ax = plt.subplots()
        ax.imshow(c, cmap='viridis', vmin=0, vmax=.5)
        ax.set_xlabel('True Label')
        ax.set_ylabel('Decoded Label')
        if save:
            os.makedirs(self.prefix, exist_ok=True)
            f.savefig(os.path.join(self.prefix, "confusion_matrix.png"),
                      format="png")
        return c, (f, ax)
    def plot_llr(self, LLR, save=False, frame_rate=15.46):
        '''Plot log-likelihood ratio (LLR) traces by morph, against position and time.

        Parameters
        ----------
        LLR : np.ndarray
            LLR per time sample; sliced per trial with ``self.tstarts`` /
            ``self.teleports``.
        save : bool
            If True, save the four figures as PNGs under ``self.prefix``
            (the directory is created if missing).
        frame_rate : float
            Samples per second used to convert sample index to time.
            Defaults to 15.46, the value previously hard-coded.

        Returns
        -------
        (f_mp, ax_mp) : per-morph mean LLR +/- nanstd vs position.
        (f_mt, ax_mt) : per-morph mean LLR +/- nanstd vs time.
        (f_pos, ax_pos) : single-trial LLR vs position on a row of one
            axis per unique morph value, drawn via
            ``self._single_line_llr_multiax``.
        (f_time, ax_time) : the same vs time (x limited to [0, 250]).
        '''
        morphs = self.trial_info['morphs']
        keys = np.unique(morphs)
        nm = keys.shape[0]
        colors = [plt.cm.cool(k) for k in keys]

        # position-binned LLR, grouped by morph
        llr_pos, _, _, centers = u.make_pos_bin_trial_matrices(
            LLR, self.pos, self.tstarts, self.teleports)
        d_pos = u.trial_type_dict(llr_pos, morphs)

        # time-binned LLR, grouped by morph
        llr_time = u.make_time_bin_trial_matrices(LLR, self.tstarts,
                                                  self.teleports)
        d_time = u.trial_type_dict(llr_time, morphs)

        # per-morph mean and spread across trials. NOTE: the shaded band is
        # +/- one nanstd, not the standard error of the mean.
        mu_pos = np.zeros([nm, llr_pos.shape[1]])
        sd_pos = np.zeros([nm, llr_pos.shape[1]])
        mu_time = np.zeros([nm, llr_time.shape[1]])
        sd_time = np.zeros([nm, llr_time.shape[1]])
        for j, k in enumerate(keys):
            mu_pos[j, :] = np.nanmean(d_pos[k], axis=0)
            sd_pos[j, :] = np.nanstd(d_pos[k], axis=0)
            mu_time[j, :] = np.nanmean(d_time[k], axis=0)
            sd_time[j, :] = np.nanstd(d_time[k], axis=0)

        time = np.arange(mu_time.shape[1]) / frame_rate

        # morph-averaged traces
        f_mp, ax_mp = plt.subplots()
        f_mt, ax_mt = plt.subplots()
        for z in range(nm):
            ax_mp.plot(centers, mu_pos[z, :], color=colors[z])
            ax_mt.plot(time, mu_time[z, :], color=colors[z])

            ax_mp.fill_between(centers,
                               mu_pos[z, :] + sd_pos[z, :],
                               y2=mu_pos[z, :] - sd_pos[z, :],
                               color=colors[z],
                               alpha=.4)
            ax_mt.fill_between(time,
                               mu_time[z, :] + sd_time[z, :],
                               y2=mu_time[z, :] - sd_time[z, :],
                               color=colors[z],
                               alpha=.4)

        ax_mp.set_xlabel('position')
        ax_mp.set_ylabel('LLR')
        ax_mt.set_xlabel('time')
        ax_mt.set_ylabel('LLR')

        # single-trial traces
        f_pos, ax_pos = plt.subplots(1, nm, figsize=[20, 5])
        f_time, ax_time = plt.subplots(1, nm, figsize=[20, 5])
        for start, stop, m, r in zip(self.tstarts.tolist(),
                                     self.teleports.tolist(),
                                     morphs.tolist(),
                                     self.trial_info['rewards'].tolist()):
            self._single_line_llr_multiax(self.pos[start:stop],
                                          LLR[start:stop], m, r, ax_pos)
            self._single_line_llr_multiax(np.arange(stop - start) / frame_rate,
                                          LLR[start:stop],
                                          m,
                                          r,
                                          ax_time,
                                          xlim=[0, 250])

        ax_pos[0].set_xlabel('position')
        ax_time[0].set_xlabel('time')
        # overlay the mean +/- nanstd bands of the lowest and highest morph
        # value on every single-trial axis as a reference
        for z in [0, -1]:
            for a in range(nm):
                ax_pos[a].fill_between(centers,
                                       mu_pos[z, :] + sd_pos[z, :],
                                       y2=mu_pos[z, :] - sd_pos[z, :],
                                       color=colors[z],
                                       alpha=.4)
                ax_time[a].fill_between(time,
                                        mu_time[z, :] + sd_time[z, :],
                                        y2=mu_time[z, :] - sd_time[z, :],
                                        color=colors[z],
                                        alpha=.4)

        if save:
            os.makedirs(self.prefix, exist_ok=True)
            f_mp.savefig(os.path.join(self.prefix, "LLR_position.png"),
                         format="png")
            f_mt.savefig(os.path.join(self.prefix, "LLR_time.png"),
                         format="png")
            f_pos.savefig(os.path.join(self.prefix, "LLR_pos_st.png"),
                          format="png")
            f_time.savefig(os.path.join(self.prefix, "LLR_time_st.png"),
                           format="png")

        return (f_mp, ax_mp), (f_mt, ax_mt), (f_pos, ax_pos), (f_time,
                                                               ax_time)
        np.save(os.path.join(path, "confusion_matrix.npy"), cm)

        mp, mt, pos, time = s.plot_llr(LLR_pop, save=True)

        # plot a random 200 cells
        llrdir = os.path.join(prefix, "cell_llr")
        postdir = os.path.join(prefix, "cell_post_i")
        try:
            os.makedirs(llrdir)
            os.makedirs(postdir)

        except:
            pass

        cell_trial_mat, o, edges, centers = u.make_pos_bin_trial_matrices(
            post_i, s.pos, s.tstarts, s.teleports)
        cell_d = u.trial_type_dict(cell_trial_mat, s.trial_info['morphs'])
        cell_l_mat, o, edges, centers = u.make_pos_bin_trial_matrices(
            LLR, s.pos, s.tstarts, s.teleports)
        cell_l_d = u.trial_type_dict(cell_l_mat, s.trial_info['morphs'])

        keys = np.unique(s.trial_info['morphs'])
        perm = np.random.permutation(s.C_z.shape[1])[:200]
        for c in perm.tolist():
            f, ax = plt.subplots(1,
                                 keys.shape[0],
                                 figsize=[4 * keys.shape[0], 4])
            ff, aax = plt.subplots(1,
                                   keys.shape[0],
                                   figsize=[4 * keys.shape[0], 4])
            for a, m in enumerate(keys.tolist()):
def place_cells_calc(C,
                     position,
                     trial_info,
                     tstart_inds,
                     teleport_inds,
                     method="all",
                     pthr=.95,
                     correct_only=False,
                     speed=None,
                     win_trial_perm=False,
                     morphlist=(0, 1),
                     nperms=100):
    '''Find cells with significant spatial information, separately per morph.

    Parameters
    ----------
    C : np.ndarray
        Activity with time along axis 0 and cells along axis 1.
    position : np.ndarray
        Position per time sample.
    trial_info : dict
        Must contain ``'morphs'`` (per-trial morph values) and, when
        ``correct_only`` is True, ``'rewards'`` (trials with reward > 0 are
        kept).
    tstart_inds, teleport_inds : np.ndarray
        Per-sample marker vectors; samples equal to 1 mark trial starts /
        teleports.
    method : {'split_halves', 'bootstrap', 'all'}
        - ``'split_halves'``: SI is tested separately on trials [0::2]
          (stored as ``'odd'``) and [1::2] (``'even'``); a cell passes only
          if both p-values exceed ``pthr``.
        - ``'bootstrap'``: SI is the median over 30 random subsamples of
          67% of trials (drawn without replacement) and is then tested.
        - anything else: SI over all trials is tested.
    pthr : float
        Threshold on the permutation p-value (fraction of shuffles below
        the observed SI).
    correct_only : bool
        If True, keep only rewarded trials.
    speed : np.ndarray or None
        Forwarded to ``u.make_pos_bin_trial_matrices`` for the observed
        firing-rate maps.
    win_trial_perm : bool
        Forwarded as ``win_trial`` to ``spatial_info_perm_test``.
    morphlist : iterable
        Morph values to analyse.
    nperms : int
        Number of shuffles per permutation test (default 100).

    Returns
    -------
    masks : dict
        ``masks[m]`` is a boolean array, True for significant cells.
    FR : dict
        ``FR[m]['all']`` (plus ``'odd'``/``'even'`` for split halves) is the
        trial-averaged rate map.
    SI : dict
        ``SI[m]`` holds the spatial information values that were computed.
    '''
    C_trial_mat, occ_trial_mat, _, _ = u.make_pos_bin_trial_matrices(
        C, position, tstart_inds, teleport_inds, speed=speed)

    # turn the 0/1 marker vectors into the sample index of each trial
    tstart_inds = np.where(tstart_inds == 1)[0]
    teleport_inds = np.where(teleport_inds == 1)[0]

    morphs = trial_info['morphs']
    if correct_only:
        mask = trial_info['rewards'] > 0
        morphs = morphs[mask]
        C_trial_mat = C_trial_mat[mask, :, :]
        occ_trial_mat = occ_trial_mat[mask, :]
        # the trial boundaries must drop the same trials, otherwise they
        # no longer line up with the masked morph labels
        tstart_inds = tstart_inds[mask]
        teleport_inds = teleport_inds[mask]

    C_morph_dict = u.trial_type_dict(C_trial_mat, morphs)
    occ_morph_dict = u.trial_type_dict(occ_trial_mat, morphs)
    tstart_morph_dict = u.trial_type_dict(tstart_inds, morphs)
    teleport_morph_dict = u.trial_type_dict(teleport_inds, morphs)

    # NOTE(review): `speed` is applied to the observed rate maps but not to
    # the shuffled maps inside spatial_info_perm_test -- confirm intended.
    FR, masks, SI = {}, {}, {}
    for m in morphlist:
        FR[m] = {}
        SI[m] = {}

        # firing-rate map and SI over all trials of this morph
        FR[m]['all'] = np.nanmean(C_morph_dict[m], axis=0)
        SI[m]['all'] = spatial_info(FR[m]['all'],
                                    _normalized_occupancy(occ_morph_dict[m]))

        if method == 'split_halves':
            FR[m]['odd'] = np.nanmean(C_morph_dict[m][0::2, :, :], axis=0)
            FR[m]['even'] = np.nanmean(C_morph_dict[m][1::2, :, :], axis=0)

            occ_o = _normalized_occupancy(occ_morph_dict[m][0::2, :])
            occ_e = _normalized_occupancy(occ_morph_dict[m][1::2, :])

            SI[m]['odd'] = spatial_info(FR[m]['odd'], occ_o)
            SI[m]['even'] = spatial_info(FR[m]['even'], occ_e)

            p_e, _ = spatial_info_perm_test(SI[m]['even'],
                                            C,
                                            position,
                                            tstart_morph_dict[m][1::2],
                                            teleport_morph_dict[m][1::2],
                                            nperms=nperms,
                                            win_trial=win_trial_perm)
            p_o, _ = spatial_info_perm_test(SI[m]['odd'],
                                            C,
                                            position,
                                            tstart_morph_dict[m][0::2],
                                            teleport_morph_dict[m][0::2],
                                            nperms=nperms,
                                            win_trial=win_trial_perm)

            masks[m] = np.multiply(p_e > pthr, p_o > pthr)

        elif method == 'bootstrap':
            n_boots = 30
            bs_pcnt = .67  # proportion of trials kept in each resample
            tmat = C_morph_dict[m]
            omat = occ_morph_dict[m]
            ntrials = tmat.shape[0]
            bs_thr = int(bs_pcnt * ntrials)  # number of trials to keep

            SI_bs = np.zeros([n_boots, C.shape[1]])
            print("start bootstrap")
            for b in range(n_boots):
                # random subset of trials, drawn without replacement
                bs_inds = np.random.permutation(ntrials)[:bs_thr]
                FR_bs = np.nanmean(tmat[bs_inds, :, :], axis=0)
                SI_bs[b, :] = spatial_info(
                    FR_bs, _normalized_occupancy(omat[bs_inds, :]))
            print("end bootstrap")

            SI[m]['bootstrap'] = np.median(SI_bs, axis=0).ravel()
            p_bs, _ = spatial_info_perm_test(SI[m]['bootstrap'],
                                             C,
                                             position,
                                             tstart_morph_dict[m],
                                             teleport_morph_dict[m],
                                             nperms=nperms,
                                             win_trial=win_trial_perm)
            masks[m] = p_bs > pthr

        else:
            p_all, _ = spatial_info_perm_test(SI[m]['all'],
                                              C,
                                              position,
                                              tstart_morph_dict[m],
                                              teleport_morph_dict[m],
                                              nperms=nperms,
                                              win_trial=win_trial_perm)
            masks[m] = p_all > pthr

    return masks, FR, SI


def _normalized_occupancy(occ_mat):
    '''Sum an occupancy trial matrix over trials (axis 0) and scale it to sum to 1.'''
    occ = occ_mat.sum(axis=0)
    return occ / occ.sum()
def single_session(sess,
                   savefigs=False,
                   fbase=None,
                   deconv=False,
                   correct_only=False,
                   speedThr=False,
                   method='bootstrap',
                   win_trial_perm=False,
                   cell_method='s2p',
                   morphlist=(0, 1)):
    '''Find and summarise place cells for one imaging session.

    Loads the session with ``pp.load_scan_sess``, computes per-morph
    significance masks with ``place_cells_calc``, prints place-cell counts
    for morphs 0 and 1 and their overlap, and plots place-cell maps plus a
    reward-cell scatter of cells significant in both morphs.

    Parameters
    ----------
    sess
        Session descriptor passed to ``pp.load_scan_sess``.
    savefigs : bool
        If True, save figures to ``fbase + "_pc.{pdf,svg}"`` and
        ``fbase + "_rc.{pdf,svg}"``.
    fbase : str or None
        Output path prefix; required when ``savefigs`` is True.
    deconv : bool
        Use the ``S`` output of ``load_scan_sess`` instead of ``u.df(C)``.
    correct_only, method, win_trial_perm, morphlist
        Forwarded to ``place_cells_calc``. The summary below indexes morphs
        0 and 1 directly, so ``morphlist`` must contain both.
    speedThr : bool
        If True, pass ``VRDat.speed`` to ``place_cells_calc`` as ``speed``.
    cell_method : str
        Forwarded as ``analysis`` to ``pp.load_scan_sess``.

    Returns
    -------
    FR, masks, SI : as returned by ``place_cells_calc``.

    Raises
    ------
    ValueError
        If ``savefigs`` is True but ``fbase`` is None (checked up front so
        the expensive analysis is not run for nothing).
    '''
    if savefigs and fbase is None:
        raise ValueError("fbase is required when savefigs=True")

    # load calcium data and aligned VR data
    VRDat, C, S, _ = pp.load_scan_sess(sess,
                                       fneu_coeff=.7,
                                       analysis=cell_method)
    C = S if deconv else u.df(C)

    pos = VRDat['pos']._values
    tstart = VRDat['tstart']._values
    teleport = VRDat['teleport']._values

    # trial-by-trial info and position-binned activity by morph
    trial_info, _, _ = u.by_trial_info(VRDat)
    C_trial_mat, _, _, _ = u.make_pos_bin_trial_matrices(C, pos, tstart,
                                                         teleport)
    C_morph_dict = u.trial_type_dict(C_trial_mat, trial_info['morphs'])

    # significance masks; the test used depends on `method`
    masks, FR, SI = place_cells_calc(
        C,
        pos,
        trial_info,
        tstart,
        teleport,
        method=method,
        correct_only=correct_only,
        speed=VRDat.speed._values if speedThr else None,
        win_trial_perm=win_trial_perm,
        morphlist=morphlist)

    # plot place cells by morph
    f_pc, _ = plot_placecells(C_morph_dict, masks)

    # number of place cells in each environment
    print('morph 0 place cells = %g out of %g , %f ' %
          (masks[0].sum(), masks[0].shape[0],
           masks[0].sum() / masks[0].shape[0]))
    print('morph 1 place cells = %g out of %g, %f' %
          (masks[1].sum(), masks[1].shape[0],
           masks[1].sum() / masks[1].shape[0]))

    # cells significant in both environments
    common_pc = np.multiply(masks[0], masks[1])
    print('common place cells = %g' % common_pc.sum())

    # TODO: further per-cell metrics not implemented here yet -- stability
    # (first vs second half correlation, e.g. stability_split_halves),
    # tuning specificity (e.g. meanvectorlength of circularised tuning
    # curves), spatial topography of place cells, reward zone score,
    # place field width and place cell reliability.

    # reward cell scatter plot for cells that are place cells in both morphs
    FR_0_cpc = FR[0]['all'][:, common_pc]
    FR_1_cpc = FR[1]['all'][:, common_pc]
    f_rc, _ = reward_cell_scatterplot(FR_0_cpc, FR_1_cpc)

    if savefigs:
        for fig, tag in ((f_pc, "_pc"), (f_rc, "_rc")):
            for fmt in ('pdf', 'svg'):
                fig.savefig(fbase + tag + "." + fmt, format=fmt)

    return FR, masks, SI