示例#1
0
    def get_lapwiseerror_peranimal(self):
        """Compute lap-wise decoding error for every animal in BayesFolder.

        Iterates over all result files in self.BayesFolder (one per animal,
        animal name taken as the filename prefix before the first '_') and
        computes per-task lap-wise error via self.calulate_lapwiseerror.

        Returns:
            (accuracy_dict, numlaps_dict): OrderedDicts keyed by animal name;
            each value is a dict keyed by task with the lap-wise error /
            number of laps respectively.
        """
        files = [f for f in os.listdir(self.BayesFolder)]
        accuracy_dict = OrderedDict()
        numlaps_dict = OrderedDict()
        for f in files:
            print(f)
            animalname = f[:f.find('_')]
            # Skip CFC12 unless explicitly enabled via the flag.
            if animalname == 'CFC12' and self.CFC12flag == 0:
                continue
            # Look up animal metadata once (the original queried it twice
            # and also fetched 'trackbins' without ever using it).
            animal_tasks = DataDetails.ExpAnimalDetails(
                animalname)['task_dict']
            data = np.load(os.path.join(self.BayesFolder, f),
                           allow_pickle=True)
            animal_accuracy = {}
            animal_numlaps = {}
            for t in animal_tasks:
                animal_accuracy[t] = self.calulate_lapwiseerror(
                    y_actual=data['fit'].item()[t]['ytest'],
                    y_predicted=data['fit'].item()[t]['yang_pred'],
                    numlaps=data['numlaps'].item()[t],
                    lapframes=data['lapframes'].item()[t])

                animal_numlaps[t] = data['numlaps'].item()[t]
            accuracy_dict[animalname] = animal_accuracy
            numlaps_dict[animalname] = animal_numlaps

        return accuracy_dict, numlaps_dict
示例#2
0
 def __init__(self, DirectoryName, BayesFolder):
     """Compile running/lick behavior and Bayes decoding results per animal.

     Args:
         DirectoryName: root experiment folder; each subfolder (other than
             the excluded bookkeeping folders) is treated as one animal.
         BayesFolder: folder holding per-animal Bayes decoder result files.

     NOTE(review): plots Task2 running data for every animal as a visual
     sanity check; plt.show() blocks until each figure is closed.
     """
     colors = sns.color_palette('muted')
     # Reordered palette — presumably one color per task; TODO confirm.
     self.colors = [colors[0], colors[1], colors[3], colors[2]]
     self.Foldername = DirectoryName
     self.BayesFolder = BayesFolder
     self.SaveFolder = os.path.join(self.Foldername, 'SaveAnalysed')
     # Animal names = subfolders that are not bookkeeping folders.
     self.animalname = [
         f for f in os.listdir(self.Foldername)
         if f not in ['LickData', 'BayesResults_All', 'SaveAnalysed']
     ]
     print(self.animalname)
     self.taskdict = ['Task1', 'Task2', 'Task3', 'Task4']
     self.framespersec = 30.98  # frames per second — TODO confirm rate
     self.tracklength = 200  # presumably cm — TODO confirm units
     self.velocity_in_space, self.bayescompiled = OrderedDict(
     ), OrderedDict()
     self.slope, self.speed_ratio = OrderedDict(), OrderedDict()
     for a in self.animalname:
         animalinfo = DataDetails.ExpAnimalDetails(a)
         animaltasks = animalinfo['task_dict']
         # First Bayes result file whose name contains the animal id.
         bayesfile = [f for f in os.listdir(self.BayesFolder) if a in f][0]
         self.accuracy_dict = self.get_bayes_error(animaltasks, bayesfile)
         self.goodrunningdata, self.running_data, self.good_running_index, self.lickdata, self.lapspeed, self.numlaps, = self.load_runningdata(
             a)
         self.good_lapframes = self.get_lapframes(a, animaltasks)
         # Quick visual check of Task2 running data vs. lap segmentation.
         plt.plot(self.goodrunningdata['Task2'])
         plt.plot(self.good_lapframes['Task2'])
         plt.title(np.max(self.good_lapframes['Task2']))
         plt.show()
         self.velocity_in_space[a], self.bayescompiled[
             a] = self.get_velocity_in_space_bylap(a, animaltasks)
         self.slope[a], self.speed_ratio[a] = self.get_slopeatend(
             a, animaltasks)
     #
     self.save_data()
    def compile_confusion_matrix(self, fs, ax):
        """Plot the population-summed, row-normalized confusion matrix of the
        Bayes decoder for each task, one panel per task.

        Args:
            fs: figure handle (used for the colorbar on the last panel).
            ax: sequence of axes, one per task in self.taskdict.
        """
        nbins = int(self.tracklength / self.trackbins)
        cm_all = {task: np.zeros((nbins, nbins))
                  for task in self.taskdict.keys()}

        # Accumulate each animal's per-task confusion matrix.
        for animal in self.animals:
            animalinfo = DataDetails.ExpAnimalDetails(animal)
            if self.datatype == 'endzonerem':
                modelfile = 'modeloneachtask_withendzonerem.npy'
            else:
                modelfile = 'modeloneachtask.npy'
            bayesmodel = np.load(
                os.path.join(animalinfo['saveresults'], modelfile),
                allow_pickle=True).item()
            for task in animalinfo['task_dict']:
                cm_all[task] += bayesmodel[task]['cm']

        for n, task in enumerate(self.taskdict.keys()):
            # Normalize each actual-position row to a probability.
            row_sums = cm_all[task].sum(axis=1)[:, np.newaxis]
            cm_all[task] = cm_all[task].astype('float') / row_sums
            img = ax[n].imshow(cm_all[task], cmap="Blues", vmin=0, vmax=0.4,
                               interpolation='nearest')
            # Diagonal guide line (perfect prediction).
            ax[n].plot(ax[n].get_xlim()[::-1], ax[n].get_ylim(),
                       ls="--", c=".3", lw=1)
            pf.set_axes_style(ax[n], numticks=3, both=True)
            if n == len(self.taskdict) - 1:
                CommonFunctions.create_colorbar(fighandle=fs, axis=ax[n],
                                                imghandle=img,
                                                title='Probability',
                                                ticks=[0, 0.4])

        ax[0].set_ylabel('Actual')
        ax[0].set_xlabel('Predicted')
示例#4
0
    def get_data_peranimal(self):
        """For each animal: load fluorescence, lap, and behavior data, then
        compute, save, and compile calcium-transient properties for the
        Task1 place cells."""
        for animal in self.animalname:
            print(animal)
            info = DataDetails.ExpAnimalDetails(animal)
            Fcdata, SmFcdata = self.gd.get_dff(animal, info)
            lapframes = self.gd.get_lapframes(animal, info)
            (good_running_index, laps_with_licks,
             laps_without_licks) = self.gd.get_behavior_params(animal)
            # Only Task1 place cells are analyzed here.
            placecell = self.gd.load_placecells(animal)['Task1']

            transient_stats = self.get_data_transients_pertask(
                info,
                placecell,
                Fcdata,
                SmFcdata,
                lapframes=lapframes,
                laps_withlicks=laps_with_licks,
                laps_withoutlicks=laps_without_licks,
                threshold=self.dffthreshold,
                transthreshold=self.transthreshold)
            auc, amplitude, length, frequency, numtransients = transient_stats

            self.save_transient_properties(animal, auc, amplitude, length,
                                           frequency, numtransients)
            self.compile_to_dataframe(animal, auc, amplitude, length,
                                      frequency, numtransients)
def plot_compiled_errorcorrelation(axis, SaveFolder, TaskDict, trackbins, to_plot='R2', classifier_type='Bayes'):
    """Plot lap time against lap-wise decoder error for every animal/task.

    Args:
        axis: sequence of axes, one per task.
        SaveFolder: folder of per-animal classifier result files.
        TaskDict: dict of tasks to compile (keys used).
        trackbins: spatial bin size passed to the error calculation.
        to_plot: 'R2' or accuracy metric selector.
        classifier_type: substring selecting which result files to load.

    Returns:
        (compilelaptime, compilemeanerr): dicts keyed by task with all
        animals' lap times / mean errors concatenated.
    """
    # Renamed from 'l': single-letter ambiguous name (PEP 8 / E741).
    compiler = Compile()
    files = [f for f in os.listdir(SaveFolder) if classifier_type in f]
    compilemeanerr = {k: [] for k in TaskDict.keys()}
    compilelaptime = {k: [] for k in TaskDict.keys()}
    for f in files:
        print(f)
        animalname = f[:f.find('_')]
        animal_tasks = DataDetails.ExpAnimalDetails(animalname)['task_dict']
        data = np.load(os.path.join(SaveFolder, f), allow_pickle=True)
        for n, t in enumerate(animal_tasks):
            m, laptime = compiler.plotlaptime_withbayeserror(
                axis=axis[n],
                task=t,
                lapframes=data['lapframes'].item()[t],
                y_actual=data['fit'].item()[t]['ytest'],
                y_predicted=data['fit'].item()[t]['yang_pred'],
                numlaps=data['numlaps'].item()[t],
                laptime=data['laptime'].item()[t],
                trackbins=trackbins, to_plot=to_plot)
            pf.set_axes_style(axis[n], numticks=4)
            compilemeanerr[t].extend(m)
            compilelaptime[t].extend(laptime)

    return compilelaptime, compilemeanerr
def plot_compiledconfusionmatrix(axis, SaveFolder, TaskDict, classifier_type='Bayes'):
    """Pool decoder predictions across animals and plot row-normalized
    confusion matrices for 'Task1' and 'Task1a'.

    Args:
        axis: sequence of at least two axes (Task1, Task1a).
        SaveFolder: folder of per-animal classifier result files.
        TaskDict: dict of tasks (keys used to initialize accumulators).
        classifier_type: substring selecting which result files to load.
    """
    files = [f for f in os.listdir(SaveFolder) if classifier_type in f]
    all_ytest = {k: [] for k in TaskDict.keys()}
    all_ypred = {k: [] for k in TaskDict.keys()}
    for f in files:
        animal_tasks = DataDetails.ExpAnimalDetails(f[:f.find('_')])['task_dict']
        data = np.load(os.path.join(SaveFolder, f), allow_pickle=True)
        print(f)
        for t in animal_tasks:
            fit = data['fit'].item()[t]
            if t == 'Task1a':
                # Only the last 500 samples of Task1a are pooled —
                # presumably to balance sample counts; TODO confirm.
                all_ytest[t].extend(fit['ytest'][-500:])
                all_ypred[t].extend(fit['yang_pred'][-500:])
            else:
                all_ytest[t].extend(fit['ytest'])
                all_ypred[t].extend(fit['yang_pred'])

    for n, t in enumerate(['Task1', 'Task1a']):
        cm = confusion_matrix(all_ytest[t], all_ypred[t])
        # Row-normalize so each actual-position row sums to 1.
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        axis[n].imshow(cm, cmap="Blues", vmin=0, vmax=0.4, interpolation='nearest')
        # Diagonal guide line (perfect prediction).
        axis[n].plot(axis[n].get_xlim()[::-1], axis[n].get_ylim(), ls="--", c=".3", lw=1)
        pf.set_axes_style(axis[n], numticks=3, both=True)
        axis[n].set_title(t)
    axis[0].set_ylabel('Actual')
    axis[0].set_xlabel('Predicted')
示例#7
0
    def __init__(self, AnimalName, FolderName, CompiledFolderName,
                 classifier_type, taskstoplot):
        """Load one animal's imaging/behavior data and build raster data.

        Args:
            AnimalName: animal identifier used for metadata lookup.
            FolderName: parent data folder (animal subfolder is appended).
            CompiledFolderName: folder with compiled Bayes results.
            classifier_type: which classifier's results to use downstream.
            taskstoplot: tasks included in the raster/plots.
        """
        print('Loading Data')
        self.animalname = AnimalName
        self.animalinfo = DataDetails.ExpAnimalDetails(self.animalname)
        self.FolderName = os.path.join(FolderName, self.animalname)
        self.CompiledFolderName = CompiledFolderName  # For importing Bayes results
        self.Task_Numframes = self.animalinfo['task_numframes']
        self.TaskDict = self.animalinfo['task_dict']
        self.classifier_type = classifier_type
        self.framespersec = 30.98  # frames per second — TODO confirm rate
        self.taskstoplot = taskstoplot
        self.trackbins = 5  # spatial bin size — presumably cm; TODO confirm

        # Run functions
        self.get_data_folders()
        if self.animalinfo['v73_flag']:
            # MATLAB v7.3 files need a different loader.
            self.load_v73_Data()
        else:
            self.load_fluorescentdata()
        self.get_place_cells()
        self.load_behaviordata()
        self.load_lapparams()
        self.rasterdata = self.combinedata_forraster(self.taskstoplot)
        self.colors = sns.color_palette('deep', len(self.taskstoplot))
def plot_histogram_error_bytask(SaveFolder, TaskDict, trackbins, to_plot='R2', classifier_type='Bayes'):
    """Compile lap-wise decoder error per task across animals.

    Task1 contributes only its last three laps; Task2 is split around the
    lap where licking stops (three laps before -> 'Task2', three laps
    after -> 'Task2b'); all other tasks contribute every lap.

    Args:
        SaveFolder: folder of per-animal classifier result files.
        TaskDict: iterable of task keys to compile (must include 'Task2b').
        trackbins: spatial bin size passed to the error calculation.
        to_plot: 'R2' selects lap R^2, anything else selects lap accuracy.
        classifier_type: substring selecting which result files to load.

    Returns:
        compileerror: dict keyed by task with the pooled error values.
    """
    # Renamed from 'l': single-letter ambiguous name (PEP 8 / E741).
    # Also removed: unused 'axis_draw' mapping and 'animalname' accumulator.
    compiler = Compile()
    files = [f for f in os.listdir(SaveFolder) if classifier_type in f]
    compileerror = {k: [] for k in TaskDict}
    for f in files:
        print(f)
        animal_tasks = DataDetails.ExpAnimalDetails(f[:f.find('_')])['task_dict']
        data = np.load(os.path.join(SaveFolder, f), allow_pickle=True)
        for t in animal_tasks:
            lap_r2, lap_accuracy = compiler.calulate_lapwiseerror(
                y_actual=data['fit'].item()[t]['ytest'],
                y_predicted=data['fit'].item()[t]['yang_pred'],
                trackbins=trackbins,
                numlaps=data['numlaps'].item()[t],
                lapframes=data['lapframes'].item()[t])

            errdata = lap_r2 if to_plot == 'R2' else lap_accuracy
            errdata = errdata[~np.isnan(errdata)]
            if t == 'Task1':
                compileerror[t].extend(errdata[-3:])
            elif t == 'Task2':
                # Split around the lap where licking stops.
                lickstop = data['lickstoplap'].item()['Task2']
                compileerror[t].extend(errdata[lickstop - 3:lickstop])
                compileerror['Task2b'].extend(errdata[lickstop:lickstop + 3])
            else:
                compileerror[t].extend(errdata)
    return compileerror
    def combine_placecells_pertask(self, fig, axis, taskstoplot):
        """Stack place-cell activity across 4-task animals and plot it.

        Args:
            fig, axis: figure/axes passed to the plotting helper.
            taskstoplot: list of task keys to stack and plot.

        Returns:
            perccells_peranimal: DataFrame (indexed by animal) of the
            percentage of cells that are place cells per task.
        """
        pc_activity_dict = {keys: np.asarray([]) for keys in taskstoplot}
        perccells_peranimal = {keys: [] for keys in taskstoplot + ['animal']}
        pcsortednum = {keys: [] for keys in taskstoplot}
        for a in self.animals:
            animalinfo = DataDetails.ExpAnimalDetails(a)
            # Only animals that ran all four tasks are included.
            if len(animalinfo['task_dict']) == 4:
                pf_remapping = np.load(os.path.join(self.FolderName, a,
                                                    'PlaceCells',
                                                    '%s_pcs_pertask.npy' % a),
                                       allow_pickle=True).item()
                # Plain %-formatting; the original's f-prefix was spurious
                # (an f-string with no braces).
                pfparams = np.load(os.path.join(self.FolderName, a,
                                                'PlaceCells',
                                                '%s_placecell_data.npz' % a),
                                   allow_pickle=True)
                perccells_peranimal['animal'].append(a)
                for t in taskstoplot:
                    perccells_peranimal[t].append(
                        (np.sum(pfparams['numPFs_incells'].item()[t]) /
                         pfparams['numcells']) * 100)
                    # Grow the stack; first animal seeds the array.
                    pc_activity_dict[t] = np.vstack(
                        (pc_activity_dict[t], pf_remapping[t]
                         )) if pc_activity_dict[t].size else pf_remapping[t]

        for t in taskstoplot:
            print(t, np.shape(pc_activity_dict[t]))
            # Sort cells by the position of their peak activity.
            pcsortednum[t] = np.argsort(np.nanargmax(pc_activity_dict[t], 1))

        self.plpc.plot_placecells_pertask(fig, axis, taskstoplot,
                                          pc_activity_dict, pcsortednum)
        perccells_peranimal = pd.DataFrame.from_dict(perccells_peranimal)
        perccells_peranimal = perccells_peranimal.set_index('animal')
        return perccells_peranimal
    def get_com_allanimal(self, taskA, taskB, vmax=0):
        """Collect weighted place-field centers of mass (COM) for cells found
        in both taskA and taskB across all animals, then plot the COM
        scatter/heatmap.

        Args:
            taskA, taskB: task labels to compare.
            vmax: color scale maximum forwarded to the plot helper.
        """
        com_all_animal = np.array([])
        count = 0
        for f in self.csvfiles_pfs:
            animalname = f[:f.find('_')]
            animalinfo = DataDetails.ExpAnimalDetails(animalname)
            if len(animalinfo['task_dict']) >= self.tasklen:
                df = pd.read_csv(os.path.join(self.CombinedDataFolder, f),
                                 index_col=0)
                t1 = df[df['Task'] == taskA]
                t2 = df[df['Task'] == taskB]
                # Cells present in both tasks; plain %-formatting (the
                # original's f-prefix was spurious).
                combined = pd.merge(t1,
                                    t2,
                                    how='inner',
                                    on=['CellNumber'],
                                    suffixes=('_%s' % taskA, '_%s' % taskB))

                # 2 x ncells stack of COMs (converted from bins to cm).
                this_com = np.vstack(
                    (combined['WeightedCOM_%s' % taskA] * self.trackbins,
                     combined['WeightedCOM_%s' % taskB] * self.trackbins))
                if count == 0:
                    com_all_animal = this_com
                else:
                    com_all_animal = np.hstack((com_all_animal, this_com))
                count += 1
        self.plot_com_scatter_heatmap(com_all_animal, taskA, taskB, vmax=vmax)
    def compile_mean_bderror(self, ax):
        """Compile K-fold decoder metrics (R2, accuracy, mean error) across
        animals into a DataFrame, boxplot them per task, and print KS-test
        comparisons of each task against Task1.

        Args:
            ax: sequence of three axes (R2, BD accuracy, BD error).
        """
        mean_error = {k: [] for k in ['R2', 'Task', 'animalname', 'BD error (cm)', 'BD accuracy']}
        for a in self.animals:
            animalinfo = DataDetails.ExpAnimalDetails(a)
            if self.datatype == 'endzonerem':
                bayesmodel = np.load(
                    os.path.join(animalinfo['saveresults'], 'modeloneachtask_withendzonerem.npy'),
                    allow_pickle=True).item()
            else:
                bayesmodel = np.load(os.path.join(animalinfo['saveresults'], 'modeloneachtask.npy'),
                                     allow_pickle=True).item()
            for t in animalinfo['task_dict'].keys():
                # NOTE(review): iterates folds 0..max(CVIndex)-1 — confirm
                # whether the last fold is intentionally excluded.
                kfold = np.max(bayesmodel[t]['K-foldDataframe']['CVIndex'])
                mean_error['R2'].extend(bayesmodel[t]['K-foldDataframe']['R2_angle'])
                mean_error['BD accuracy'].extend(bayesmodel[t]['K-foldDataframe']['ModelAccuracy'])
                for k in np.arange(kfold):
                    # Mean absolute decoding error in cm for this fold.
                    mean_error['BD error (cm)'].append(
                        np.mean(bayesmodel[t]['K-foldDataframe']['Y_ang_diff'][k] * self.trackbins))
                    mean_error['Task'].append(t)
                    mean_error['animalname'].append(a)

        mean_error_df = pd.DataFrame.from_dict(mean_error)
        for n, i in enumerate(['R2', 'BD accuracy', 'BD error (cm)']):
            sns.boxplot(x='Task', y=i, data=mean_error_df, palette='Blues', ax=ax[n], showfliers=False)
            for t in self.taskdict.keys():
                if t != 'Task1':
                    d, p = Stats.significance_test(mean_error_df[mean_error_df.Task == t][i],
                                                   mean_error_df[mean_error_df.Task == 'Task1'][i],
                                                   type_of_test='KStest')
                    # Fixed: original mixed an f-string prefix with
                    # %-formatting ("f'%s...' % (...)").
                    print('%s: %s: KStest: p-value %0.4f' % (i, t, p))
            ax[n].set_xlabel('')
            pf.set_axes_style(ax[n], numticks=4)
示例#12
0
def plot_compilederror_withlick(axis,
                                SaveFolder,
                                TaskDict,
                                trackbins,
                                separate_lickflag=0,
                                licktype='all',
                                to_plot='R2',
                                classifier_type='Bayes'):
    """Plot lap-wise decoder error alongside lick rate for each task.

    Only animals whose Task2 lick-stop lap is greater than 2 are included.

    Args:
        axis: sequence of axes, one per task.
        SaveFolder: folder of per-animal classifier result files.
        TaskDict: dict of tasks (keys used to initialize accumulators).
        trackbins: spatial bin size passed to the error calculation.
        separate_lickflag: forwarded to the lick-rate plot helper.
        licktype: 'all' uses all licks; otherwise pre-click licks.
        to_plot: 'R2' selects lap R^2, anything else lap accuracy.
        classifier_type: substring selecting which result files to load.

    Returns:
        (compileerror, compilelickdata, lickstoplap array, animalname list).
    """
    # Renamed from 'l': single-letter ambiguous name (PEP 8 / E741).
    compiler = Compile()
    files = [f for f in os.listdir(SaveFolder) if classifier_type in f]
    compilelickdata = {k: [] for k in TaskDict.keys()}
    compileerror = {k: [] for k in TaskDict.keys()}
    lickstoplap = []
    animalname = []
    for f in files:
        animalname.append(f[:f.find('_')])
        animal_tasks = DataDetails.ExpAnimalDetails(
            f[:f.find('_')])['task_dict']
        print(len(animal_tasks))
        data = np.load(os.path.join(SaveFolder, f), allow_pickle=True)
        if data['lickstoplap'].item()['Task2'] > 2:
            print(f)
            for n, t in enumerate(animal_tasks):
                lap_r2, lap_accuracy = compiler.calulate_lapwiseerror(
                    y_actual=data['fit'].item()[t]['ytest'],
                    y_predicted=data['fit'].item()[t]['yang_pred'],
                    trackbins=trackbins,
                    numlaps=data['numlaps'].item()[t],
                    lapframes=data['lapframes'].item()[t])

                if licktype == 'all':
                    licks = data['alllicks'].item()
                    lickstop = data['lickstoplap'].item()['Task2']
                else:
                    licks = data['licks_befclick'].item()
                    lickstop = data['lickstoplap_befclick'].item()['Task2']

                data_compile = lap_r2 if to_plot == 'R2' else lap_accuracy

                # NOTE(review): the plot always receives the 'all'-licks stop
                # lap even when licktype selects the pre-click variant —
                # confirm this asymmetry is intended.
                compiler.plot_bayeserror_with_lickrate(
                    data_compile, licks[t],
                    data['lickstoplap'].item()['Task2'], t, axis[n],
                    separate_lickflag)
                pf.set_axes_style(axis[n], numticks=4)
                compilelickdata[t].append(licks[t])
                compileerror[t].append(np.asarray(data_compile))
            lickstoplap.append(lickstop)

    return compileerror, compilelickdata, np.asarray(lickstoplap), animalname
示例#13
0
    def __init__(self, FolderName, animalname):
        """Load fluorescence, behavior, and lap data for one animal.

        Args:
            FolderName: root data folder.
            animalname: animal identifier used to look up metadata.
        """
        self.FolderName = FolderName
        self.animalname = animalname
        self.animalinfo = DataDetails.ExpAnimalDetails(self.animalname)
        self.Task_Numframes = self.animalinfo['task_numframes']
        self.TaskDict = self.animalinfo['task_dict']
        self.framespersec = 30.98  # frames per second — TODO confirm rate
        self.trackbins = 5  # spatial bin size — presumably cm; TODO confirm
        self.tracklength = 200  # presumably cm — TODO confirm units
        self.trackstartindex = self.animalinfo['trackstart_index']

        # Load everything up front; order matters (velocity needs laps).
        self.get_fluorescence_data()
        self.get_behavior()
        self.load_lapparams()
        self.calculate_velocity()
示例#14
0
    def __init__(self, AnimalName, FolderName, taskstoplot, controlflag=0):
        """Load behavior and lap data for one animal.

        Args:
            AnimalName: animal identifier used for metadata lookup.
            FolderName: parent data folder (animal subfolder is appended).
            taskstoplot: tasks included in downstream plots.
            controlflag: when truthy, use control-animal metadata
                (DataDetails.ControlAnimals) instead of experimental.
        """
        print('Loading Data')
        self.taskstoplot = taskstoplot
        self.colors = sns.color_palette('deep', len(self.taskstoplot))
        self.animalname = AnimalName
        if controlflag:
            self.animalinfo = DataDetails.ControlAnimals(self.animalname)
        else:
            self.animalinfo = DataDetails.ExpAnimalDetails(self.animalname)
        self.FolderName = os.path.join(FolderName, self.animalname)
        self.Task_Numframes = self.animalinfo['task_numframes']
        self.TaskDict = self.animalinfo['task_dict']

        # Run functions
        self.get_data_folders()
        self.load_behaviordata()
        self.load_lapparams()
    def compile_meanerror_bytrack(self, ax):
        """Plot mean Bayesian-decoder error as a function of track position
        for Task1 and Task2, averaged over animals and CV folds, and print a
        KS-test comparing the two tasks.

        Args:
            ax: single axes to draw both task curves on.
        """
        numbins = int(self.tracklength / self.trackbins)
        numanimals = np.size(self.animals)
        kfold = 10  # NOTE(review): assumes every model used 10 CV folds
        Y_diff_by_track = {k: np.zeros((numanimals, kfold, numbins)) for k in ['Task1', 'Task2']}
        Y_diff_by_track_mean = {k: [] for k in ['Task1', 'Task2']}

        for n, a in enumerate(self.animals):
            print(a)
            animalinfo = DataDetails.ExpAnimalDetails(a)
            if self.datatype == 'endzonerem':
                bayesmodel = np.load(
                    os.path.join(animalinfo['saveresults'], 'modeloneachtask_withendzonerem.npy'),
                    allow_pickle=True).item()
            else:
                bayesmodel = np.load(os.path.join(animalinfo['saveresults'], 'modeloneachtask.npy'),
                                     allow_pickle=True).item()
            for t in ['Task1', 'Task2']:
                for k in np.arange(kfold):
                    # Error in cm for this fold, and the bin each test
                    # sample belongs to.
                    y_diff = np.asarray(bayesmodel[t]['K-foldDataframe']['Y_ang_diff'][k]) * self.trackbins
                    y_test = np.asarray(bayesmodel[t]['K-foldDataframe']['y_test'][k])
                    for i in np.arange(numbins):
                        # Mean error over samples in bin i (NaN if empty).
                        Y_indices = np.where(y_test == i)[0]
                        Y_diff_by_track[t][n, k, i] = np.mean(y_diff[Y_indices])

        for t in ['Task1', 'Task2']:
            # Collapse (animal, fold) into one axis; mean +/- SEM per bin.
            Y_diff_by_track_mean[t] = Y_diff_by_track[t].reshape(numanimals * kfold, numbins)
            meandiff, semdiff = np.nanmean(Y_diff_by_track_mean[t], 0), scipy.stats.sem(Y_diff_by_track_mean[t], 0,
                                                                                        nan_policy='omit')
            error1, error2 = meandiff - semdiff, meandiff + semdiff
            ax.plot(np.arange(numbins), meandiff)
            ax.fill_between(np.arange(numbins), error1, error2, alpha=0.5)
            ax.set_xlabel('Track Length (cm)')
            ax.set_ylabel('BD error (cm)')
            ax.set_xlim((0, numbins))

        d, p = Stats.significance_test(np.mean(Y_diff_by_track_mean['Task2'], 0),
                                       np.mean(Y_diff_by_track_mean['Task1'], 0),
                                       type_of_test='KStest')
        # Fixed: original mixed an f-string prefix with %-formatting.
        print('KStest: p-value %0.4f' % p)

        CommonFunctions.convertaxis_to_tracklength(ax, self.tracklength, self.trackbins, convert_axis='x')
        ax.set_xlim((0, self.tracklength / self.trackbins))
        pf.set_axes_style(ax, numticks=4)
 def combineanimaldataframes(self, csvfiles, tasklen):
     """Concatenate per-animal CSV dataframes for animals with at least
     `tasklen` tasks.

     Args:
         csvfiles: filenames within self.CombinedDataFolder; the animal
             name is the prefix before the first '_'.
         tasklen: minimum number of tasks an animal must have.

     Returns:
         A single DataFrame with a fresh integer index.
     """
     frames = []
     for f in csvfiles:
         animalname = f[:f.find('_')]
         if not self.controlflag:
             animalinfo = DataDetails.ExpAnimalDetails(animalname)
         else:
             animalinfo = DataDetails.ControlAnimals(animalname)
         if len(animalinfo['task_dict']) >= tasklen:
             print(f)
             frames.append(
                 pd.read_csv(os.path.join(self.CombinedDataFolder, f),
                             index_col=0))
     # pd.concat replaces DataFrame.append, which was deprecated and then
     # removed in pandas 2.0; concatenating once is also O(n) not O(n^2).
     return pd.concat(frames, ignore_index=True)
    def get_com_allanimal(self, fig, axis, taskA, taskB, vmax=0):
        """Gather weighted place-field COMs for cells common to taskA and
        taskB across all 4-task animals, plot the COM scatter/heatmap, and
        return the per-cell COM shift.

        Args:
            fig, axis: figure/axes for the plot helper.
            taskA, taskB: task labels to compare.
            vmax: color scale maximum forwarded to the plot helper.

        Returns:
            Array of |COM_taskA - COM_taskB| per cell (cm).
        """
        csvfiles_pfs = [
            f for f in os.listdir(self.CombinedDataFolder)
            if f.endswith('.csv') and 'reward' not in f and 'common' not in f
        ]
        com_all_animal = np.array([])
        count = 0
        for f in csvfiles_pfs:
            a = f[:f.find('_')]
            animalinfo = DataDetails.ExpAnimalDetails(a)
            if len(animalinfo['task_dict']) == 4:
                print(f)
                df = pd.read_csv(os.path.join(self.CombinedDataFolder, f),
                                 index_col=0)
                t1 = df[df['Task'] == taskA]
                t2 = df[df['Task'] == taskB]
                # Cells present in both tasks; plain %-formatting (the
                # original's f-prefix was spurious).
                combined = pd.merge(t1,
                                    t2,
                                    how='inner',
                                    on=['CellNumber'],
                                    suffixes=('_%s' % taskA, '_%s' % taskB))

                # 2 x ncells stack of COMs (converted from bins to cm).
                this_com = np.vstack(
                    (combined['WeightedCOM_%s' % taskA] * self.trackbins,
                     combined['WeightedCOM_%s' % taskB] * self.trackbins))
                if count == 0:
                    com_all_animal = this_com
                else:
                    com_all_animal = np.hstack((com_all_animal, this_com))
                count += 1
        self.plpc.plot_com_scatter_heatmap(fig,
                                           axis,
                                           com_all_animal,
                                           taskA,
                                           taskB,
                                           self.tracklength,
                                           vmax=vmax)
        return np.abs(np.subtract(com_all_animal[0, :], com_all_animal[1, :]))
示例#18
0
    def __init__(self, AnimalName, FolderName, SaveFigureFolder, taskstoplot):
        """Load one animal's lap and behavior data for figure generation.

        Args:
            AnimalName: animal identifier used for metadata lookup.
            FolderName: parent data folder (animal subfolder is appended).
            SaveFigureFolder: destination folder for saved figures.
            taskstoplot: tasks included in downstream plots.
        """
        self.taskstoplot = taskstoplot
        self.SaveFigureFolder = SaveFigureFolder
        self.colors = sns.color_palette('deep')
        self.task2_colors = [self.colors[1], self.colors[3]]

        self.animalname = AnimalName
        self.animalinfo = DataDetails.ExpAnimalDetails(self.animalname)

        self.ParentFolderName = FolderName
        self.FolderName = os.path.join(FolderName, self.animalname)
        self.Task_Numframes = self.animalinfo['task_numframes']
        # HACK: hard-coded frame-count override for Task3 — TODO confirm why
        # the metadata value is wrong for this task.
        self.Task_Numframes['Task3'] = 14999
        self.removeframesforbayes = self.animalinfo['task_framestokeep']
        self.TaskDict = self.animalinfo['task_dict']
        self.framespersec = 30.98  # frames per second — TODO confirm rate
        self.trackbins = 5  # spatial bin size — presumably cm; TODO confirm

        self.get_data_folders()
        self.load_lapparams()
        self.load_behaviordata()
    def compile_numcells(self, ax, taskstoplot, placecellflag=0):
        """Plot decoding R^2 against the percentage of cells sampled,
        aggregated across animals.

        Args:
            ax: axes to draw the pointplot on.
            taskstoplot: tasks included in the plot.
            placecellflag: when truthy, use the place-cell subsample
                dataframe instead of the all-cells dataframe.

        Returns:
            g: the grouped (SampleSize, Task, animalname) mean DataFrame.
        """
        percsamples = [5, 10, 20, 50, 80, 100]
        # Plain %-formatting; the original's f-prefix was spurious.
        percsamples = ['%d%%' % p for p in percsamples]

        numcells_combined = pd.DataFrame([])
        for a in self.animals:
            animalinfo = DataDetails.ExpAnimalDetails(a)
            bayesmodel = np.load(os.path.join(animalinfo['saveresults'],
                                              'modeloneachtask_lapwise.npy'),
                                 allow_pickle=True).item()

            for t in animalinfo['task_dict']:
                if not placecellflag:
                    numcells_dataframe = bayesmodel[t]['Numcells_Dataframe']
                else:
                    numcells_dataframe = bayesmodel[t][
                        'Placecells_sample_Dataframe']
                numcells_dataframe['Task'] = t
                numcells_dataframe['animalname'] = a
                numcells_combined = pd.concat(
                    (numcells_combined, numcells_dataframe), ignore_index=True)
        # 'mean' instead of np.mean: passing raw callables to GroupBy.agg is
        # deprecated in modern pandas; the string form is equivalent here.
        g = numcells_combined.groupby(['SampleSize', 'Task', 'animalname'
                                       ]).agg(['mean']).reset_index()
        g.columns = g.columns.droplevel(1)
        if placecellflag:
            g['Type'] = 'Placecells'
        else:
            g['Type'] = 'Allcells'
        sns.pointplot(x='SampleSize',
                      y='R2_angle',
                      data=g[g.Task.isin(taskstoplot)],
                      order=percsamples,
                      hue='Task',
                      ax=ax)
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
        ax.set_xlabel('Percentage of active cells used')
        ax.set_ylabel('R-squared')
        pf.set_axes_style(ax, numticks=4)
        return g
    def combine_placecells_withtask(self,
                                    fig,
                                    axis,
                                    taskstoplot,
                                    tasktocompare='Task1'):
        """Stack place-cell activity across 4-task animals, sort every task
        by each cell's peak location in `tasktocompare`, and plot.

        Args:
            fig, axis: figure/axes passed to the plotting helper.
            taskstoplot: tasks to stack and plot.
            tasktocompare: reference task used for sorting and
                normalisation (default 'Task1').

        Returns:
            (corrcoef_dict, pc_activity_dict, pcsortednum)
        """
        pc_activity_dict = {keys: np.asarray([]) for keys in taskstoplot}
        for a in self.animals:
            animalinfo = DataDetails.ExpAnimalDetails(a)
            # Only animals that ran all four tasks are included.
            if len(animalinfo['task_dict']) == 4:
                pf_remapping = np.load(os.path.join(
                    self.FolderName, a, 'PlaceCells',
                    '%s_pcs_sortedbyTask1.npy' % a),
                                       allow_pickle=True).item()

                for t in taskstoplot:
                    pc_activity_dict[t] = np.vstack(
                        (pc_activity_dict[t], pf_remapping[t]
                         )) if pc_activity_dict[t].size else pf_remapping[t]

        # Every task gets the same cell ordering: sorted by peak position
        # in the reference task.
        pcsortednum = {keys: [] for keys in taskstoplot}
        pcsorted = np.argsort(np.nanargmax(pc_activity_dict[tasktocompare], 1))
        for t in taskstoplot:
            pcsortednum[t] = pcsorted

        # Correlate place cells
        corrcoef_dict = self.find_correlation(pc_activity_dict, taskstoplot,
                                              tasktocompare)

        # Fixed: normalisation previously hard-coded 'Task1' even though
        # sorting uses `tasktocompare`; both now use the reference task
        # (identical behavior for the default tasktocompare='Task1').
        task_data = pc_activity_dict[tasktocompare][pcsorted, :]
        normalise_data = np.nanmax(task_data, 1)[:, np.newaxis]
        self.plpc.plot_placecells_pertask(fig,
                                          axis,
                                          taskstoplot,
                                          pc_activity_dict,
                                          pcsortednum,
                                          normalise_data=normalise_data)

        return corrcoef_dict, pc_activity_dict, pcsortednum
示例#21
0
    def __init__(self, AnimalName, FolderName, SaveFigureFolder, taskstoplot, controlflag=0):
        """Load one animal's data (experimental or control) including the
        Bayes fit, behavior, and lap parameters.

        Args:
            AnimalName: animal identifier used for metadata lookup.
            FolderName: parent data folder (animal subfolder is appended).
            SaveFigureFolder: destination folder for saved figures.
            taskstoplot: tasks included in downstream plots.
            controlflag: when truthy, use control-animal metadata and a
                control color palette; lick-stop attributes are skipped.
        """
        print('Loading Data')
        self.taskstoplot = taskstoplot
        self.SaveFigureFolder = SaveFigureFolder
        self.controlflag = controlflag
        if self.controlflag:
            self.colors = sns.color_palette(["#3498db", "#9b59b6"])
        else:
            self.colors = sns.color_palette('deep')
            self.task2_colors = [self.colors[1], self.colors[3]]

        self.animalname = AnimalName
        if self.controlflag:
            self.animalinfo = DataDetails.ControlAnimals(self.animalname)
        else:
            self.animalinfo = DataDetails.ExpAnimalDetails(self.animalname)

        self.ParentFolderName = FolderName
        self.FolderName = os.path.join(FolderName, self.animalname)
        self.Task_Numframes = self.animalinfo['task_numframes']
        self.removeframesforbayes = self.animalinfo['task_framestokeep']
        self.TaskDict = self.animalinfo['task_dict']
        self.framespersec = 30.98  # frames per second — TODO confirm rate
        self.trackbins = 5  # spatial bin size — presumably cm; TODO confirm

        # Run functions
        self.get_data_folders()
        if self.animalinfo['v73_flag']:
            # MATLAB v7.3 files need a different loader.
            self.load_v73_Data()
        else:
            self.load_fluorescentdata()
        self.load_Bayesfit()
        self.load_behaviordata()
        self.load_lapparams()

        if not self.controlflag:
            # Fixed: builtin int() instead of np.int, which was deprecated
            # in NumPy 1.20 and removed in 1.24.
            self.lickstoplap = int(self.lickstoplap['Task2'] - 1)
            self.lickstopframe = np.where(self.good_lapframes['Task2'] == self.lickstoplap)[0][0]
    def __init__(self, AnimalName, FolderName, SaveFigureFolder, taskstoplot):
        """Load one animal's data and render its rastermap summary figure."""
        # Plot configuration.
        self.SaveFigureFolder = SaveFigureFolder
        self.taskstoplot = taskstoplot
        self.colors = sns.color_palette('deep')
        self.task2_colors = [self.colors[1], self.colors[3]]

        # Animal metadata and folder layout.
        self.animalname = AnimalName
        self.animalinfo = DataDetails.ExpAnimalDetails(self.animalname)
        self.ParentFolderName = FolderName
        self.FolderName = os.path.join(FolderName, self.animalname)
        self.Task_Numframes = self.animalinfo['task_numframes']
        self.removeframesforbayes = self.animalinfo['task_framestokeep']
        self.TaskDict = self.animalinfo['task_dict']
        self.framespersec = 30.98
        self.trackbins = 5

        # Load data; v7.3 .mat files need a dedicated loader.
        self.get_data_folders()
        if self.animalinfo['v73_flag']:
            self.load_v73_Data()
        else:
            self.load_fluorescentdata()
        self.load_lapparams()
        self.load_behaviordata()

        # Build and plot the rastermap figure.
        self.raster_fdata, self.raster_cdata = self.combinedata_correct_forraster()
        self.make_rastermap(self.raster_fdata, self.raster_cdata)
        fig, axes = plt.subplots(3,
                                 sharex='all',
                                 dpi=300,
                                 gridspec_kw={
                                     'height_ratios': [2, 0.5, 0.5],
                                     'hspace': 0.3
                                 })
        self.plot_rastermap(axes,
                            fdata=self.raster_fdata,
                            crop_cellflag=0,
                            ylim_meandff=0.1)
    def compile_numcells(self, ax, taskstoplot, legendflag=0):
        """Plot decoder R-squared vs. percentage of active cells, pooled over animals.

        Loads each animal's saved Bayes-model results, concatenates the
        per-task cell-subsample dataframes, averages per animal, and draws a
        pointplot on ``ax``. Prints a KS-test of each task against Task1.

        Parameters
        ----------
        ax : matplotlib Axes
            Axis to draw the pointplot on.
        taskstoplot : iterable
            Task names to include in the plot.
        legendflag : int, optional
            Non-zero keeps the legend (placed outside the axis).
        """
        percsamples = [1, 5, 10, 20, 50, 80, 100]
        # Plain f-string formatting; the original f'%d%%' % p mixed f-string
        # syntax with %-formatting for no benefit.
        percsamples = [f'{p}%' for p in percsamples]
        numcells_combined = pd.DataFrame([])
        for a in self.animals:
            print(a)
            animalinfo = DataDetails.ExpAnimalDetails(a)
            # Two decoder variants are saved under different filenames.
            if self.datatype == 'endzonerem':
                bayesmodel = np.load(
                    os.path.join(animalinfo['saveresults'], 'modeloneachtask_withendzonerem.npy'),
                    allow_pickle=True).item()
            else:
                bayesmodel = np.load(os.path.join(animalinfo['saveresults'], 'modeloneachtask.npy'),
                                     allow_pickle=True).item()

            for t in animalinfo['task_dict']:
                numcells_dataframe = bayesmodel[t]['Numcells_Dataframe']
                numcells_dataframe['Task'] = t
                numcells_dataframe['animalname'] = a
                numcells_combined = pd.concat((numcells_combined, numcells_dataframe), ignore_index=True)
        # 'mean' (string) instead of np.mean: passing NumPy callables to
        # GroupBy.agg is deprecated in recent pandas; the result is identical.
        g = numcells_combined.groupby(['SampleSize', 'Task', 'animalname']).agg(['mean']).reset_index()
        g.columns = g.columns.droplevel(1)
        sns.pointplot(x='SampleSize', y='R2_angle', data=g[g.Task.isin(taskstoplot)], order=percsamples, hue='Task',
                      ax=ax)
        if legendflag:
            ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
        else:
            ax.get_legend().remove()
        ax.set_xlabel('Percentage of active cells used')
        ax.set_ylabel('R-squared')
        # ax.set_aspect(aspect=1.6)
        pf.set_axes_style(ax, numticks=4)

        # Compare each later task's R2 distribution against Task1.
        # NOTE(review): this tests column 'R2' while the plot uses 'R2_angle' —
        # confirm both columns exist in Numcells_Dataframe.
        for t in self.taskdict:
            if t != 'Task1':
                d, p = Stats.significance_test(g[g.Task == t]['R2'], g[g.Task == 'Task1']['R2'], type_of_test='KStest')
                print(f'%s: KStest : p-value %0.4f' % (t, p))
    def calculate_ratiofiring_atrewzone(self, ax, combined_dataframe,
                                        tasks_to_compare, ranges):
        """Plot the percentage of cells with place fields at track beginning/middle/end.

        For each animal and task, bins place-field centers-of-mass into
        ``ranges`` (track cm), normalizes by the animal's total cell count,
        and draws a pointplot on ``ax``. Prints Mann-Whitney p-values between
        task pairs per track segment.

        Parameters
        ----------
        ax : matplotlib Axes
            Axis to draw on.
        combined_dataframe : pandas.DataFrame
            Pooled place-field data with Task, animalname and WeightedCOM columns.
        tasks_to_compare : list of str
            Task names to include.
        ranges : sequence
            Bin edges (cm) passed to pd.cut over WeightedCOM * 5.

        Returns
        -------
        (df, cellratio_df, cellratio_dict)
            Melted long-form frame, the wide per-animal frame, and a dict of
            raw ratios (currently unused / left empty).
        """
        # Collect rows and build the frame once at the end:
        # DataFrame.append was removed in pandas 2.0.
        rows = []
        cellratio_dict = {k: [] for k in tasks_to_compare}
        for n1, a in enumerate(np.unique(combined_dataframe.animalname)):
            if not self.controlflag:
                animalinfo = DataDetails.ExpAnimalDetails(a)
                # Animals lacking the fourth task skip the last comparison task.
                if len(animalinfo['task_dict']) == 4:
                    compare = tasks_to_compare
                else:
                    compare = tasks_to_compare[:-1]
            else:
                compare = tasks_to_compare
            # Total cell count for this animal; loop-invariant across tasks,
            # so load it once rather than per task.
            normfactor = np.sum(
                np.load(os.path.join(self.CombinedDataFolder,
                                     [f for f in self.npzfiles
                                      if a in f][0]),
                        allow_pickle=True)['numcells'].item())
            for n2, taskname in enumerate(compare):
                data = combined_dataframe[(combined_dataframe.Task == taskname)
                                          &
                                          (combined_dataframe.animalname == a)]
                # Histogram of place-field COM (WeightedCOM is in 5 cm bins).
                g = data.groupby(
                    pd.cut(data.WeightedCOM * 5,
                           ranges)).count()['WeightedCOM'].tolist()
                rows.append({
                    'Beg': (np.mean(g[0:2]) / normfactor) * 100,
                    'Mid': (np.mean(g[2:-2]) / normfactor) * 100,
                    'End': (g[-1] / normfactor) * 100,
                    'Animal': a,
                    'TaskName': taskname
                })
                # cellratio_dict[taskname].append(g / normfactor)
        cellratio_df = pd.DataFrame(
            rows, columns=['Beg', 'Mid', 'End', 'Animal', 'TaskName'])
        df = cellratio_df.melt(id_vars=['Animal', 'TaskName'],
                               var_name='Track',
                               value_name='Ratio')
        # Control vs. experimental cohorts use different task labels.
        if self.controlflag:
            sns.pointplot(x='Track',
                          y='Ratio',
                          hue='TaskName',
                          order=['Beg', 'Mid', 'End'],
                          hue_order=['Task1a', 'Task1b'],
                          data=df,
                          dodge=0.35,
                          capsize=.1,
                          ci=68,
                          lw=0.3,
                          ax=ax)
        else:
            sns.pointplot(x='Track',
                          y='Ratio',
                          hue='TaskName',
                          order=['Beg', 'Mid', 'End'],
                          hue_order=['Task1', 'Task2b', 'Task3'],
                          data=df,
                          dodge=0.35,
                          capsize=.1,
                          ci=68,
                          lw=0.3,
                          ax=ax)
        ax.legend(bbox_to_anchor=(0, -0.5), loc=2, borderaxespad=0., ncol=2)

        # Get p-values per track segment between task pairs.
        for track in ['Beg', 'Mid', 'End']:
            for t1 in tasks_to_compare:
                for t2 in tasks_to_compare[1:]:
                    if t1 != t2:
                        x = df[(df.TaskName == t1)
                               & (df.Track == track)]['Ratio']
                        y = df[(df.TaskName == t2)
                               & (df.Track == track)]['Ratio']
                        t, p = scipy.stats.mannwhitneyu(x, y)
                        print(
                            'Track %s : Between %s and %s: %0.3f, significant %s'
                            % (track, t1, t2, p, p < 0.05))

        pf.set_axes_style(ax)
        return df, cellratio_df, cellratio_dict
    def get_lapwise_correlation_peranimal(self, taskstoplot, axis):
        """Plot lap-by-lap place-field correlation alongside lick counts.

        Builds an animals-by-laps matrix of median place-field correlations
        (laps selected per task: last laps of Task1, laps around the lick-stop
        lap of Task2, first laps otherwise), normalizes both correlation and
        lick data by their Task1 maxima, plots both on twinned axes, and
        marks laps whose correlation differs from lap 0 (rank-sum p < 0.05).

        Parameters
        ----------
        taskstoplot : iterable
            Task names to include (subset of the animal's tasks).
        axis : matplotlib Axes
            Primary axis; a twinx is created for licks.

        Returns
        -------
        (correlation_data, lick_data) : un-normalized animals x laps arrays.
        """
        numlaps = {'Task1': 5, 'Task2': 14, 'Task3': 11}
        # NOTE(review): rows sized len(npyfiles) - 2 — presumably two animals
        # are always excluded by the lickstoplap > 2 filter; confirm.
        correlation_data = np.zeros((len(self.npyfiles) - 2, sum(numlaps.values())))
        lick_data = np.zeros((len(self.npyfiles) - 2, sum(numlaps.values())))
        count = 0
        for n1, f in enumerate(self.npyfiles):
            print(f, count, np.shape(correlation_data))
            animalname = f[: f.find('_')]
            animal_tasks = DataDetails.ExpAnimalDetails(animalname)['task_dict']
            corr_data = self.get_correlation_data(f)
            corr_animal = corr_data['correlation_withTask1'].item()
            sigPFs = corr_data['sig_PFs_cellnum'].item()['Task1']
            # Load behavior once; the original called it twice per file.
            behavior = self.get_animal_behaviordata(animalname)
            lickstoplap = behavior['lick_stop'].item()['Task2']
            lick_per_lap = behavior['numlicks_withinreward_alllicks'].item()
            # Skip animals that stopped licking too early.
            if lickstoplap > 2:
                count_lap = 0
                for n2, t in enumerate(animal_tasks.keys()):
                    if t in taskstoplot:
                        corr_sigPFs = corr_animal[t][sigPFs, :]
                        tasklap = np.size(corr_animal[t], 1)
                        if t == 'Task1':
                            # Fixed window of laps 12..12+numlaps; the original
                            # also drew unused random laps here (dead code, removed).
                            this_task_data = np.nanmedian(corr_sigPFs[:, np.arange(12, 12 + numlaps[t])], 0)
                            this_lick_data = lick_per_lap[t][-numlaps[t]:]
                        elif t == 'Task2':
                            # Window around the lap at which licking stopped.
                            this_task_data = np.nanmedian(corr_sigPFs[:, lickstoplap - 3:lickstoplap + 11], 0)
                            this_lick_data = lick_per_lap[t][lickstoplap - 3:lickstoplap + 11]
                        else:
                            this_task_data = np.nanmedian(corr_sigPFs[:, :numlaps[t]], 0)
                            this_lick_data = lick_per_lap[t][:numlaps[t]]

                        correlation_data[count, count_lap:count_lap + numlaps[t]] = this_task_data
                        lick_data[count, count_lap:count_lap + numlaps[t]] = this_lick_data
                        count_lap += numlaps[t]
                count += 1

        # Normalize and compare for p-value with Task1
        corr_norm = correlation_data / np.max(correlation_data[:, :numlaps['Task1']])
        lick_norm = lick_data / np.max(lick_data[:, :numlaps['Task1']])

        # Plot_traces
        plot_axis = [axis, axis.twinx()]
        colors = sns.color_palette('dark', 2)
        label = ['Mean Correlation', 'Mean Licks']
        for n, d in enumerate([corr_norm, lick_norm]):
            mean = np.mean(d, 0)
            sem = scipy.stats.sem(d, 0)
            if n == 0:
                plot_axis[n].errorbar(np.arange(np.size(mean)), mean, yerr=sem, color=colors[n])

            plot_axis[n].plot(np.arange(np.size(mean)), mean, '.-', color=colors[n])
            plot_axis[n].set_ylabel(label[n], color=colors[n])

        # Get p-values: each lap's correlations vs. lap 0 across animals.
        for l in np.arange(np.size(correlation_data, 1)):
            d, p = scipy.stats.ranksums(correlation_data[:, l], correlation_data[:, 0])
            if np.round(p, 3) < 0.05:
                axis.plot(l, 0.9, '*', color='k')
            print(l, p)
        for a in plot_axis:
            # Fixed: original styled `axis` twice and never the twinx axis.
            pf.set_axes_style(a)
        axis.set_xlabel('Lap Number')

        return correlation_data, lick_data
示例#26
0
def plot_lapwiseerror_withlick(axis,
                               SaveFolder,
                               taskstoplot,
                               trackbins=5,
                               to_plot='R2',
                               classifier_type='Bayes'):
    """Plot lap-wise decoder error alongside lick counts, pooled over animals.

    Builds an animals-by-laps matrix of decoding error (R2 or accuracy) using
    task-specific lap windows (last laps of Task1, laps around the lick-stop
    lap of Task2, first laps otherwise), normalizes by the Task1 maximum,
    plots error and licks on twinned axes, and stars laps whose error differs
    from lap 0 (rank-sum p < 0.01).

    Parameters
    ----------
    axis : matplotlib Axes
        Primary axis; a twinx is created for licks.
    SaveFolder : str
        Directory of saved decoder result files.
    taskstoplot : iterable
        Task names to include.
    trackbins : int, optional
        Track bin size passed to the error calculation.
    to_plot : str, optional
        'R2' for R-squared, anything else for accuracy.
    classifier_type : str, optional
        Substring selecting result files (e.g. 'Bayes').

    Returns
    -------
    (correlation_data, lick_data) : un-normalized animals x laps arrays.
    """
    numlaps = {'Task1': 5, 'Task2': 14, 'Task3': 11}
    l = Compile()
    files = [f for f in os.listdir(SaveFolder) if classifier_type in f]
    # NOTE(review): rows sized len(files) - 2 — presumably two animals are
    # always excluded by the lickstoplap > 2 filter; confirm.
    correlation_data = np.zeros((len(files) - 2, sum(numlaps.values())))
    lick_data = np.zeros((len(files) - 2, sum(numlaps.values())))
    count = 0
    for n1, f in enumerate(files):
        animalname = f[:f.find('_')]
        animal_tasks = DataDetails.ExpAnimalDetails(animalname)['task_dict']
        data = np.load(os.path.join(SaveFolder, f), allow_pickle=True)
        lickstoplap = data['lickstoplap'].item()['Task2']
        lick_per_lap = data['alllicks'].item()
        # Skip animals that stopped licking too early.
        if lickstoplap > 2:
            print(f)
            count_lap = 0
            for t in animal_tasks.keys():

                if t in taskstoplot:
                    lap_r2, lap_accuracy = l.calulate_lapwiseerror(
                        y_actual=data['fit'].item()[t]['ytest'],
                        y_predicted=data['fit'].item()[t]['yang_pred'],
                        trackbins=trackbins,
                        numlaps=data['numlaps'].item()[t],
                        lapframes=data['lapframes'].item()[t])

                    if to_plot == 'R2':
                        decodererror = np.asarray(lap_r2)
                    else:
                        decodererror = np.asarray(lap_accuracy)
                    decodererror = decodererror[~np.isnan(decodererror)]

                    # Task-specific lap windows (see docstring).
                    if t == 'Task1':
                        this_task_data = decodererror[-numlaps[t]:]
                        this_lick_data = lick_per_lap[t][-numlaps[t]:]
                    elif t == 'Task2':
                        this_task_data = decodererror[lickstoplap -
                                                      3:lickstoplap + 11]
                        this_lick_data = lick_per_lap[t][lickstoplap -
                                                         3:lickstoplap + 11]
                    else:
                        this_task_data = decodererror[:numlaps[t]]
                        this_lick_data = lick_per_lap[t][:numlaps[t]]

                    correlation_data[count, count_lap:count_lap +
                                     numlaps[t]] = this_task_data
                    lick_data[count, count_lap:count_lap +
                              numlaps[t]] = this_lick_data
                    count_lap += numlaps[t]
            count += 1

    # Normalize and compare for p-value with Task1
    corr_norm = correlation_data / np.max(
        correlation_data[:, :numlaps['Task1']])
    lick_norm = lick_data / np.max(lick_data[:, :numlaps['Task1']])

    # Plot_traces
    plot_axis = [axis, axis.twinx()]
    color_animal = sns.color_palette('deep', len(taskstoplot))
    color_data = sns.color_palette('dark', 2)
    if to_plot == 'R2':
        label = ['Mean R-squared', 'Mean Licks']
    else:
        label = ['Mean Accuracy', 'Mean Licks']
    for n, d in enumerate([corr_norm, lick_norm]):
        mean = np.mean(d, 0)
        sem = scipy.stats.sem(d, 0)
        count = 0
        for n2, l1 in enumerate(taskstoplot):
            data_m = mean[count:count + numlaps[l1]]
            data_sem = sem[count:count + numlaps[l1]]
            if n == 0:
                # Error trace: per-task colors with SEM error bars.
                plot_axis[n].errorbar(np.arange(count, count + numlaps[l1]),
                                      data_m,
                                      yerr=data_sem,
                                      color=color_animal[n2])
            else:
                # Lick trace: the full mean is (re)drawn each iteration;
                # harmless overplot, kept for identical output.
                plot_axis[n].plot(np.arange(np.size(mean)),
                                  mean,
                                  '.-',
                                  color=color_data[n],
                                  zorder=n)
            plot_axis[n].set_ylabel(label[n], color=color_data[n])
            count += numlaps[l1]
    plot_axis[0].set_ylim((0, 1))
    plot_axis[1].set_ylim((0, 1))

    # Get p-values: each lap's error vs. lap 0 across animals.
    for l in np.arange(np.size(correlation_data, 1)):
        d, p = scipy.stats.ranksums(correlation_data[:, l],
                                    correlation_data[:, 0])
        if np.round(p, 3) < 0.01:
            if to_plot == 'R2':
                axis.plot(l, 1.0, '*', color='k')
            else:
                axis.plot(l, 1.5, '*', color='k')
        print(l, p)
    for a in plot_axis:
        # Fixed: original styled `axis` twice and never the twinx axis.
        pf.set_axes_style(a)
    axis.set_xlabel('Lap Number')

    return correlation_data, lick_data
示例#27
0
def plot_meancorrelation_withshuffle(axis,
                                     SaveFolder,
                                     trackbins,
                                     taskstoplot,
                                     classifier_type='Bayes',
                                     to_plot='R2'):
    """Shuffle-based comparison of mean decoding error across tasks.

    For each animal, repeatedly samples 4 random laps per task (Task2 is
    split at the lick-stop lap into Task2 and Task2b) and averages the
    decoding error over them.  Plots the shuffle distributions as histograms
    (axis[1, 0]) and per-animal means as box/strip plots, raw and normalized
    to Task1 (axis[0, 0] and axis[0, 1]).

    Parameters
    ----------
    axis : 2x2 array of matplotlib Axes
    SaveFolder : str
        Directory of saved decoder result files.
    trackbins : int
        Track bin size passed to the error calculation.
    taskstoplot : iterable
        Task names to include; must contain 'Task2b' for the Task2 split.
    classifier_type : str, optional
        Substring selecting result files.
    to_plot : str, optional
        'R2' for R-squared, anything else for accuracy.

    Returns
    -------
    dict mapping task -> (num_iterations x animals) shuffled mean errors.
    """
    # Choose last 10 laps in Task1, random 4 laps in Task2, and Task2b 100 times to calculate mean decoding error per animal
    num_iterations = 1000
    l = Compile()
    colors = sns.color_palette('deep')
    colors = [colors[0], colors[1], colors[3], colors[2]]
    files = [f for f in os.listdir(SaveFolder) if classifier_type in f]
    # NOTE(review): columns sized len(files) - 2 — presumably two animals are
    # always excluded by the lickstoplap > 2 filter below; confirm.
    shuffle_mean_corr = {
        k: np.zeros((num_iterations, len(files) - 2))
        for k in taskstoplot
    }
    count = 0
    for n, f in enumerate(files):
        animalname = f[:f.find('_')]
        animal_tasks = DataDetails.ExpAnimalDetails(
            f[:f.find('_')])['task_dict']
        data = np.load(os.path.join(SaveFolder, f), allow_pickle=True)
        lickstoplap = data['lickstoplap'].item()['Task2']
        # Skip animals that stopped licking too early.
        if data['lickstoplap'].item()['Task2'] > 2:
            print(f)
            for t in animal_tasks:
                if t in taskstoplot:
                    lap_r2, lap_accuracy = l.calulate_lapwiseerror(
                        y_actual=data['fit'].item()[t]['ytest'],
                        y_predicted=data['fit'].item()[t]['yang_pred'],
                        trackbins=trackbins,
                        numlaps=data['numlaps'].item()[t],
                        lapframes=data['lapframes'].item()[t])

                    if to_plot == 'R2':
                        decodererror = np.asarray(lap_r2)
                    else:
                        decodererror = np.asarray(lap_accuracy)
                    decodererror = decodererror[~np.isnan(decodererror)]

                    tasklap = np.size(decodererror)
                    # Shuffle: 4 random laps per iteration, task-specific pools.
                    for i in np.arange(num_iterations):
                        if t == 'Task2':
                            # Laps before the lick stop -> Task2.
                            randlaps = np.random.choice(np.arange(
                                0, lickstoplap),
                                                        4,
                                                        replace=False)
                            shuffle_mean_corr[t][i, count] = np.nanmean(
                                decodererror[randlaps])
                            # Laps after the lick stop -> Task2b.
                            randlaps = np.random.choice(np.arange(
                                lickstoplap, tasklap),
                                                        4,
                                                        replace=False)
                            shuffle_mean_corr['Task2b'][i, count] = np.nanmean(
                                decodererror[randlaps])
                            if np.any(
                                    np.isnan(np.nanmean(
                                        decodererror[randlaps]))):
                                print(animalname, t, randlaps, decodererror)
                        elif t == 'Task1':
                            # Only the last 5 laps of Task1 are sampled.
                            randlaps = np.random.choice(np.arange(
                                tasklap - 5, tasklap),
                                                        4,
                                                        replace=False)
                            shuffle_mean_corr[t][i, count] = np.nanmean(
                                decodererror[randlaps])
                        else:
                            randlaps = np.random.choice(np.arange(0, tasklap),
                                                        4,
                                                        replace=False)
                            shuffle_mean_corr[t][i, count] = np.nanmean(
                                decodererror[randlaps])
            count += 1

    # Get p-value: fraction of shuffle iterations NOT significant vs Task1.
    p_value_task2 = []
    p_value_task2b = []
    for i in np.arange(num_iterations):
        t, p = scipy.stats.ttest_rel(shuffle_mean_corr['Task1'][i, :],
                                     shuffle_mean_corr['Task2'][i, :])
        p_value_task2.append(p > 0.05)
        t, p = scipy.stats.ttest_rel(shuffle_mean_corr['Task1'][i, :],
                                     shuffle_mean_corr['Task2b'][i, :])
        p_value_task2b.append(p > 0.05)

    # Plot shuffle histogram
    # Remove zeros
    data = {k: [] for k in ['Task1', 'Task2', 'Task2b']}
    for n, t in enumerate(['Task1', 'Task2', 'Task2b']):
        temp = shuffle_mean_corr[t].flatten()
        data[t] = temp
        # NOTE(review): sns.distplot is deprecated (removed in seaborn 0.14);
        # migrate to histplot/displot when updating seaborn.
        if to_plot == 'R2':
            sns.distplot(data[t],
                         label=t,
                         bins=np.linspace(0, 1, 100),
                         ax=axis[1, 0],
                         hist_kws={'color': colors[n]},
                         kde_kws={'color': colors[n]})
        else:
            sns.distplot(data[t],
                         label=t,
                         bins=np.linspace(0, 50, 50),
                         ax=axis[1, 0],
                         hist_kws={'color': colors[n]},
                         kde_kws={'color': colors[n]})
    axis[1, 0].set_title(
        'Shuffled laps P-value with lick %0.3f, without lick %0.3f' %
        (np.size(np.where(p_value_task2)) / num_iterations,
         np.size(np.where(p_value_task2b)) / num_iterations))
    axis[1, 0].legend(loc='center left', bbox_to_anchor=(1, 0.5))
    axis[1, 0].set_xlabel('R-squared')
    axis[1, 0].set_xlim((-0.1, 1.0))
    axis[1, 0].legend(loc='center left', bbox_to_anchor=(1, 0.5))
    t, p1 = scipy.stats.ks_2samp(data['Task1'], data['Task2'])
    t, p2 = scipy.stats.ks_2samp(data['Task1'], data['Task2b'])
    print('Flattened P-value with lick %f, without lick %f' % (p1, p2))

    # Get mean_correlation: per-animal mean over shuffle iterations.
    mean_correlation = {k: [] for k in taskstoplot}
    sem_correlation = {k: [] for k in taskstoplot}
    for t in taskstoplot:
        mean_correlation[t] = np.mean(shuffle_mean_corr[t], 0)
        sem_correlation[t] = scipy.stats.sem(shuffle_mean_corr[t],
                                             0,
                                             nan_policy='omit')
    df = pd.DataFrame.from_dict(mean_correlation)
    # Drop animals with no data (all-zero columns from the skip above).
    df = df.replace(0, np.nan)
    df = df.dropna(how='all')
    # p == 0: raw means; p == 1: normalized to each animal's Task1.
    for p in np.arange(2):
        if p == 0:
            df_melt = df.melt(var_name='Task', value_name='Error')
            for index, row in df.iterrows():
                axis[0, p].plot(
                    [row['Task1'], row['Task2'], row['Task2b'], row['Task3']],
                    'k')
            print(df)
        else:
            df_div = df[df.columns].div(df['Task1'].values, axis=0)
            print(df_div)
            df_melt = df_div.melt(var_name='Task', value_name='Error')
            for index, row in df_div.iterrows():
                axis[0, p].plot(
                    [row['Task1'], row['Task2'], row['Task2b'], row['Task3']],
                    'k')
        sns.boxplot(x='Task',
                    y='Error',
                    data=df_melt,
                    palette=colors,
                    order=['Task1', 'Task2', 'Task2b', 'Task3'],
                    ax=axis[0, p])
        sns.stripplot(x='Task',
                      y='Error',
                      data=df_melt,
                      color='k',
                      size=5,
                      order=['Task1', 'Task2', 'Task2b', 'Task3'],
                      ax=axis[0, p],
                      dodge=False,
                      jitter=False)
        axis[0, p].set_xlabel('')
        axis[0, p].set_ylim((0, 1.1))

    # NOTE(review): Task2 uses a paired t-test but Task2b uses a KS test —
    # confirm this asymmetry is intentional.
    t, p1 = scipy.stats.ttest_rel(df['Task1'], df['Task2'])
    t, p2 = scipy.stats.ks_2samp(df['Task1'], df['Task2b'])
    print('Mean P-value with lick %f, without lick %f' % (p1, p2))

    axis[1, 1].axis('off')
    for a in axis.flatten():
        pf.set_axes_style(a, numticks=4)
    return shuffle_mean_corr
    def get_mean_correlation_withshuffle(self, axis, taskstoplot):
        """Shuffle-based comparison of mean place-field correlation across tasks.

        For each animal, repeatedly samples 4 random laps per task (Task2 is
        split at the lick-stop lap into Task2 and Task2b) and averages the
        Task1-referenced place-field correlations over them.  Plots shuffle
        histograms (axis[1, 0]) and per-animal means as box/strip plots, raw
        and normalized to Task1 (axis[0, 0] and axis[0, 1]).

        Parameters
        ----------
        axis : 2x2 array of matplotlib Axes
        taskstoplot : iterable
            Task names to include; must contain 'Task2b' for the Task2 split.

        Returns
        -------
        (shuffle_mean_corr, mean_correlation, data)
            Shuffled per-animal means per task, their iteration means, and
            the flattened histogram data.
        """
        num_iterations = 1000
        # NOTE(review): columns sized npyfiles - 2 — presumably two animals
        # are always excluded by the lickstoplap > 2 filter below; confirm.
        shuffle_mean_corr = {k: np.zeros((num_iterations, np.size(self.npyfiles) - 2)) for k in taskstoplot}
        count = 0
        for n, f in enumerate(self.npyfiles):
            animalname = f[: f.find('_')]
            animal_tasks = DataDetails.ExpAnimalDetails(animalname)['task_dict']
            corr_data = self.get_correlation_data(f)
            corr_animal = corr_data['correlation_withTask1'].item()
            sigPFs = corr_data['sig_PFs_cellnum'].item()['Task1']
            lickstoplap = self.get_animal_behaviordata(animalname)['lick_stop'].item()['Task2']
            # Skip animals that stopped licking too early.
            if lickstoplap > 2:
                for t in animal_tasks.keys():
                    if t in taskstoplot:
                        tasklap = np.size(corr_animal[t], 1)
                        # Restrict to cells with significant Task1 place fields.
                        corr_data_pfs = corr_animal[t][sigPFs, :]
                        # Shuffle: 4 random laps per iteration.
                        for i in np.arange(num_iterations):
                            if t == 'Task2':
                                # Laps before the lick stop -> Task2,
                                # laps after -> Task2b.
                                randlaps = np.random.choice(np.arange(0, lickstoplap), 4, replace=False)
                                shuffle_mean_corr[t][i, count] = np.mean(corr_data_pfs[:, randlaps].flatten())
                                randlaps = np.random.choice(np.arange(lickstoplap, tasklap), 4, replace=False)
                                shuffle_mean_corr['Task2b'][i, count] = np.mean(
                                    corr_data_pfs[:, randlaps].flatten())
                            else:
                                randlaps = np.random.choice(np.arange(0, tasklap - 5), 4, replace=False)
                                shuffle_mean_corr[t][i, count] = np.mean(corr_data_pfs[:, randlaps].flatten())
                count += 1

        # Get p-value: fraction of shuffle iterations NOT significant vs Task1.
        p_value_task2 = []
        p_value_task2b = []
        for i in np.arange(num_iterations):
            t, p = scipy.stats.ttest_rel(shuffle_mean_corr['Task1'][i, :], shuffle_mean_corr['Task2'][i, :])
            p_value_task2.append(p > 0.01)
            t, p = scipy.stats.ttest_rel(shuffle_mean_corr['Task1'][i, :], shuffle_mean_corr['Task2b'][i, :])
            p_value_task2b.append(p > 0.01)
        print('Shuffled laps P-value with lick %0.3f, without lick %0.3f' % (
            np.size(np.where(p_value_task2)) / num_iterations, np.size(np.where(p_value_task2b)) / num_iterations))

        # Plot shuffle histogram
        # Remove zeros
        data = {k: [] for k in ['Task1', 'Task2', 'Task2b']}
        for t in ['Task1', 'Task2', 'Task2b']:
            temp = shuffle_mean_corr[t].flatten()
            data[t] = temp
            # NOTE(review): sns.distplot is deprecated (removed in seaborn
            # 0.14); migrate to histplot/displot when updating seaborn.
            sns.distplot(data[t], label=t,
                         bins=np.linspace(0, 1, 50), ax=axis[1, 0])
        axis[1, 0].legend(loc='center left', bbox_to_anchor=(1, 0.5))
        t, p1 = scipy.stats.ks_2samp(data['Task1'], data['Task2'])
        t, p2 = scipy.stats.ks_2samp(data['Task1'], data['Task2b'])
        print('Flattened P-value with lick %f, without lick %f' % (p1, p2))

        # Get mean_correlation: per-animal mean over shuffle iterations.
        mean_correlation = {k: [] for k in taskstoplot}
        sem_correlation = {k: [] for k in taskstoplot}
        for t in taskstoplot:
            mean_correlation[t] = np.mean(shuffle_mean_corr[t], 0)
            sem_correlation[t] = scipy.stats.sem(shuffle_mean_corr[t], 0, nan_policy='omit')
        df = pd.DataFrame.from_dict(mean_correlation)
        # Drop animals with no data (all-zero columns from the skip above).
        df = df.replace(0, np.nan)
        df = df.dropna(how='all')
        # p == 0: raw means; p == 1: normalized to each animal's Task1.
        for p in np.arange(2):
            if p == 0:
                df_melt = df.melt(var_name='Task', value_name='Error')
                for index, row in df.iterrows():
                    axis[0, p].plot([row['Task1'], row['Task2'], row['Task2b'], row['Task3']], 'k')
                print(df)
            else:
                df_div = df[df.columns].div(df['Task1'].values, axis=0)
                print(df_div)
                df_melt = df_div.melt(var_name='Task', value_name='Error')
                for index, row in df_div.iterrows():
                    axis[0, p].plot([row['Task1'], row['Task2'], row['Task2b'], row['Task3']], 'k')
            sns.boxplot(x='Task', y='Error', data=df_melt, palette='Blues', order=[
                'Task1', 'Task2', 'Task2b', 'Task3'], ax=axis[0, p])
            sns.stripplot(x='Task', y='Error', data=df_melt, color='k', size=5, order=[
                'Task1', 'Task2', 'Task2b', 'Task3'], ax=axis[0, p], dodge=False, jitter=False)
            axis[0, p].set_xlabel('')

        # NOTE(review): Task2 uses a paired t-test but Task2b uses a KS test —
        # confirm this asymmetry is intentional.
        t, p1 = scipy.stats.ttest_rel(df['Task1'], df['Task2'])
        t, p2 = scipy.stats.ks_2samp(df['Task1'], df['Task2b'])
        print('Mean P-value with lick %f, without lick %f' % (p1, p2))

        axis[1, 1].axis('off')
        for a in axis.flatten():
            pf.set_axes_style(a)
        return shuffle_mean_corr, mean_correlation, data
def plot_error_bytime(self, axis, taskstoplot):
    """Plot lap-wise Bayesian decoding accuracy (R2) at the beginning vs the
    end of each task, as a boxplot with per-animal paired lines.

    Parameters
    ----------
    axis : matplotlib Axes
        Axes to draw the boxplot and paired lines onto.
    taskstoplot : iterable of str
        Task names to include (e.g. ['Task1', 'Task2']).

    Returns
    -------
    pandas.DataFrame
        Per-animal table indexed by 'Animal' with '<Task>_Beg' and
        '<Task>_End' R2 columns, sorted alphabetically by column.

    Also prints paired t-test p-values comparing Beg vs End per task.
    """
    # Accumulate rows in a plain list and build the DataFrame once:
    # DataFrame.append was removed in pandas 2.0 and was O(n^2) anyway.
    records = []
    for a in self.animals:
        animalinfo = DataDetails.ExpAnimalDetails(a)
        bayesmodel = np.load(os.path.join(animalinfo['saveresults'],
                                          'modeloneachtask_lapwise.npy'),
                             allow_pickle=True).item()

        # Only run those with all four tasks
        if len(animalinfo['task_dict']) != 4:
            continue
        for t in animalinfo['task_dict']:
            kfold_df = bayesmodel[t]['K-foldDataframe']
            numlaps = np.unique(kfold_df['CVIndex'])
            # 'Beg' accuracy: Task1 averages the first five rows; other
            # tasks average the rows of the first cross-validation lap.
            if t == 'Task1':
                beg_r2 = np.nanmean(kfold_df['R2_angle'][:5])
            else:
                beg_r2 = np.nanmean(kfold_df['R2_angle'][numlaps[0]])
            records.append({'Animal': a,
                            'R2': beg_r2,
                            'Task': t,
                            'Errortype': 'Beg'})
            # 'End' accuracy: rows of the last cross-validation lap.
            records.append({'Animal': a,
                            'R2': np.nanmean(kfold_df['R2_angle'][numlaps[-1]]),
                            'Task': t,
                            'Errortype': 'End'})
    bayeserror = pd.DataFrame(records,
                              columns=['Animal', 'R2', 'Task', 'Errortype'])

    sns.boxplot(y='R2',
                x='Task',
                hue='Errortype',
                data=bayeserror[bayeserror.Task.isin(taskstoplot)],
                ax=axis,
                showfliers=False)

    # Plot the two by animal: pivot Beg and End into per-task columns,
    # keeping only animals with complete data.
    t1 = bayeserror[(bayeserror.Task.isin(taskstoplot))
                    & (bayeserror.Errortype == 'Beg')]
    t1 = t1.pivot(index='Animal', columns='Task', values='R2')
    t1 = t1.dropna().reset_index()
    t1.columns = [f'{c}_Beg' if c != 'Animal' else c for c in t1.columns]

    t2 = bayeserror[(bayeserror.Task.isin(taskstoplot))
                    & (bayeserror.Errortype == 'End')]
    t2 = t2.pivot(index='Animal', columns='Task', values='R2')
    t2 = t2.dropna().reset_index()
    t2.columns = [f'{c}_End' if c != 'Animal' else c for c in t2.columns]

    df = pd.merge(t1, t2)
    # Alphabetical sort interleaves columns as Task*_Beg, Task*_End pairs.
    df = df.reindex(sorted(df.columns), axis=1)
    df = df.set_index('Animal')
    # NOTE(review): hard-coded manual override of one animal's value —
    # confirm this correction is intentional.
    df.loc['NR23', 'Task3_End'] = 0.9653

    # Connect each animal's Beg/End pair with a line per task.
    for _, row in df.iterrows():
        for count, i in enumerate(range(0, len(row), 2)):
            axis.plot([count - .2, count + .2],
                      row.iloc[i:i + 2],
                      'ko-',
                      markerfacecolor='none',
                      zorder=2)

    # Paired t-test between the Beg and End column of each task.
    for n in range(0, len(df.columns), 2):
        test1 = df[df.columns[n]]
        test2 = df[df.columns[n + 1]]
        t, p = scipy.stats.ttest_rel(test1, test2)
        print('P-value %s and %s is %0.3f' %
              (df.columns[n], df.columns[n + 1], p))

    axis.get_legend().remove()
    pf.set_axes_style(axis)

    return df
    def compile_meanerror_bytrack(self, ax, taskstoplot):
        """Compile Bayesian decoding error as a function of track position and
        plot the per-task mean error plus the Task1-vs-Task2 difference.

        Parameters
        ----------
        ax : sequence of matplotlib Axes
            ax[0] gets the per-task mean error traces; ax[1] gets the
            Task1-Task2 absolute-difference errorbar trace.
        taskstoplot : iterable of str
            Tasks to convert to arrays and plot; must include 'Task1' and
            'Task2' for the difference plot to be computable.

        Returns
        -------
        numpy.ndarray
            Absolute per-fold difference between Task1 and Task2 error,
            shape (folds, numbins).
        """
        numbins = int(self.tracklength / self.trackbins)
        # self.taskdict is a list of task names (see __init__); iterate it
        # directly — calling .keys() on a list raises AttributeError.
        Y_diff_by_track = {k: [] for k in self.taskdict}

        for a in self.animals:
            animalinfo = DataDetails.ExpAnimalDetails(a)
            bayesmodel = np.load(os.path.join(animalinfo['saveresults'],
                                              'modeloneachtask_lapwise.npy'),
                                 allow_pickle=True).item()

            for t in animalinfo['task_dict']:
                # Use the first six cross-validation folds of each task.
                for k in np.arange(6):
                    y_predict = np.asarray(
                        bayesmodel[t]['K-foldDataframe']['y_predict_angle'][k])
                    y_test = np.asarray(
                        bayesmodel[t]['K-foldDataframe']['y_test'][k])
                    # Absolute decoding error, converted from bins to cm.
                    y_diff = np.abs(
                        np.nan_to_num(y_predict) -
                        np.nan_to_num(y_test)) * self.trackbins
                    # Mean error per track bin for this fold.
                    y_diff_append = np.zeros(numbins)
                    for i in np.arange(numbins):
                        Y_indices = np.where(y_test == i)[0]
                        y_diff_append[i] = np.nanmean(y_diff[Y_indices])
                    Y_diff_by_track[t].append(y_diff_append)

        for t in taskstoplot:
            Y_diff_by_track[t] = np.asarray(Y_diff_by_track[t])
        # NOTE(review): assumes 'Task1' and 'Task2' are in taskstoplot (only
        # then are they arrays here) — confirm against callers.
        Y_diff_by_animal = np.abs(Y_diff_by_track['Task1'] -
                                  Y_diff_by_track['Task2'])

        for t in taskstoplot:
            # Mean +/- SEM of decoding error across folds, per track bin.
            meandiff, semdiff = np.nanmean(
                Y_diff_by_track[t], 0), scipy.stats.sem(Y_diff_by_track[t],
                                                        0,
                                                        nan_policy='omit')
            error1, error2 = meandiff - semdiff, meandiff + semdiff
            ax[0].plot(np.arange(numbins), meandiff)
            ax[0].fill_between(np.arange(numbins), error1, error2, alpha=0.5)
            ax[0].set_ylabel('BD error (cm)')

            meandiff, semdiff = np.nanmean(
                Y_diff_by_animal, 0), scipy.stats.sem(Y_diff_by_animal,
                                                      0,
                                                      nan_policy='omit')
            ax[1].errorbar(np.arange(numbins),
                           meandiff,
                           yerr=semdiff,
                           marker='o',
                           markerfacecolor='none',
                           color='k')

        for a in ax:
            pf.set_axes_style(a)
            a.set_xlabel('Track Length (cm)')
            a.set_xlim((1, numbins))
            # Bin ticks relabeled in cm (bin width = self.trackbins).
            a.set_xticks((1, 20, 40))
            a.set_xticklabels((0, 100, 200))
        return Y_diff_by_animal