コード例 #1
0
    def check_reward_in_data(self, pwd):
        """Print, for every session, how many video timesteps overlap the reward period.

        For each mouse in ``self.dataPaths`` the file ``<pwd>/<mouse>.h5`` is opened
        read-only. For every (day, session) row of that mouse, the length of the second
        axis of ``data/<day>_<session>`` is compared with the index at which the reward
        period is taken to start, ``int((5 + delay) * 20)``. The non-negative overlap is
        printed together with the inputs. Sessions absent from the file's 'metadata'
        group are reported and skipped.

        Args:
            pwd: directory containing the per-mouse HDF5 files.

        Notes:
            NOTE(review): 5 and 20 look like a pre-delay duration in seconds and a
            video sampling rate in Hz. Confirm against the recording setup.
        """
        # Group by the column name rather than a 1-element list: pandas >= 2.0 yields
        # 1-tuple keys for list groupers, which would break the path concatenation below.
        for mouseName, dfMouse in self.dataPaths.groupby('mouse'):
            h5fname = os.path.join(pwd, mouseName + '.h5')

            # Read-only and opened once per mouse: this is a diagnostic pass and must not
            # create or modify files (mode 'a' would create a missing file).
            with h5py.File(h5fname, 'r') as h5f:
                for _, row in dfMouse.iterrows():
                    session = row['day'] + '_' + row['session']
                    if session not in h5f['metadata']:
                        print(mouseName, session, 'has no metadata, skipping')
                        continue

                    # Only the shape is needed, so do not copy the dataset into memory
                    nTimestepVid = h5f['data'][session].shape[1]

                    delay = pd_is_one_row(
                        pd_query(
                            self.dfSession, {
                                'mousename': mouseName,
                                'dateKey': row['day'],
                                'sessionKey': row['session']
                            }))[1]['delay']

                    rewStartIdx = int((5 + delay) * 20)
                    overlap = max(0, nTimestepVid - rewStartIdx)

                    print(mouseName, session, delay, nTimestepVid, rewStartIdx,
                          overlap)
コード例 #2
0
def plot_metric_bulk_1D(dataDB, ds, metricName, nameSuffix, prepFunc=None, xlim=None, ylim=None, yscale=None,
                        verbose=True, xFunc=None, haveTimeLabels=False):
    """Line-plot a 1D metric for every mouse, one figure per combination of the other attributes.

    Selects the datasets listed by ``ds`` whose 'metric' equals ``metricName`` and whose
    'name' equals ``nameSuffix``. The 'target_dim', 'datetime' and 'shape' columns are
    ignored. The rows are grouped by every remaining column except 'mousename' and 'dset'.
    Each group is saved as one PNG in 'pics/bulk/<metricName>/' with one line per mouse.
    The file is named after the group's attribute values joined with '_'. Missing
    attribute values appear as the string 'None'.

    Args:
        dataDB: used only when ``haveTimeLabels`` is set, to draw timestamp labels.
        ds: dataset store providing ``list_dsets_pd`` and ``get_data``.
        metricName: value of the 'metric' column to select; also the y-label and output folder.
        nameSuffix: value of the 'name' column to select; also the x-label.
        prepFunc: optional transform applied to each dataset before plotting.
        xlim, ylim: passed to ``ax.set_xlim`` / ``ax.set_ylim``.
        yscale: optional matplotlib y-scale name, e.g. 'log'.
        verbose: if True, print the group keys and every row being plotted.
        xFunc: optional ``xFunc(mousename, nPoints)`` returning x coordinates; defaults to 0..n-1.
        haveTimeLabels: if True, annotate the axes via ``dataDB.label_plot_timestamps``.
    """
    dfAll = ds.list_dsets_pd().fillna('None')

    dfAnalysis = pd_query(dfAll, {'metric': metricName, 'name': nameSuffix})
    # Move leading columns forwards for more informative printing/saving
    dfAnalysis = pd_move_cols_front(dfAnalysis, ['metric', 'name', 'mousename'])
    dfAnalysis = dfAnalysis.drop(['target_dim', 'datetime', 'shape'], axis=1)

    # One figure per combination of all other columns; sorted so the grouping order is deterministic
    colsExcl = sorted(set(dfAnalysis.columns) - {'mousename', 'dset'})

    for colVals, dfSub in dfAnalysis.groupby(colsExcl):
        fig, ax = plt.subplots(figsize=(4, 4))

        if verbose:
            # A single grouping column may yield a scalar key; wrap it so a string is not split into characters
            print(list(colVals if isinstance(colVals, tuple) else (colVals,)))

        for _, rowMouse in dfSub.sort_values(by='mousename').iterrows():
            if verbose:
                print(list(rowMouse.values))

            dataThis = ds.get_data(rowMouse['dset'])
            assert dataThis.ndim == 1, 'Only using 1D data for this plot function'

            if prepFunc is not None:
                dataThis = prepFunc(dataThis)

            if xFunc is None:
                x = np.arange(len(dataThis))
            else:
                x = np.array(xFunc(rowMouse['mousename'], len(dataThis)))
            x, dataThis = drop_nan_rows([x, dataThis])

            ax.plot(x, dataThis, label=rowMouse['mousename'])

        if yscale is not None:
            ax.set_yscale(yscale)

        if haveTimeLabels:
            dataDB.label_plot_timestamps(ax, linecolor='y', textcolor='k', shX=-0.5, shY=0.05)

        # All rows of a group share the grouping-column values, so the first row names the figure
        dataName = '_'.join(str(el) for el in dfSub.iloc[0].drop(['dset', 'mousename']))

        prefixPath = 'pics/bulk/' + metricName + '/'
        make_path(prefixPath)

        ax.legend()
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.set_xlabel(nameSuffix)
        ax.set_ylabel(metricName)
        fig.savefig(prefixPath + dataName + '.png', dpi=200)
        plt.close(fig)
コード例 #3
0
    def behaviour_tune_resample_kernel(self, mousename, session, sig2,
                                       trialType='Hit', trialIdx=0, srcFreq=30.0, trgFreq=20.0):
        """Run ``prepcommon.behaviour_tune_resample_kernel`` on one session's movement vectors.

        ``session`` is split on '_'. Its first three fields, re-joined, form the 'day' key
        and its fourth field is the 'session' key. Together with ``mousename`` they select a
        row of ``self.dataPaths``. The row is presumably required to be unique, as checked
        by ``pd_is_one_row``. That row's 'pathMovementVectors' is passed on with ``sig2``
        and the remaining keyword arguments.
        """
        fields = session.split('_')
        pathQuery = {'mouse': mousename, 'day': '_'.join(fields[:3]), 'session': fields[3]}
        _, pathRow = pd_is_one_row(pd_query(self.dataPaths, pathQuery))

        prepcommon.behaviour_tune_resample_kernel(
            pathRow['pathMovementVectors'], sig2,
            trialType=trialType, trialIdx=trialIdx, srcFreq=srcFreq, trgFreq=trgFreq)
コード例 #4
0
    def get_data_recent_by_query(self, queryDict, listDF=None):
        """Return the most recently stored dataset matching a query.

        Args:
            queryDict: column -> value filter passed to ``pd_query``.
            listDF: optional precomputed dataset listing; defaults to ``self.list_dsets_pd()``.

        Returns:
            ``(data, attrs)``: the dataset loaded with ``self.get_data`` and its listing row.

        Raises:
            ValueError: if no listed dataset matches ``queryDict``.
        """
        if listDF is None:
            listDF = self.list_dsets_pd()

        rows = pd_query(listDF, queryDict)
        if len(rows) == 0:
            # idxmax on an empty column would fail with an opaque message
            raise ValueError('No datasets match query ' + str(queryDict))

        # Pick the row with the latest 'datetime'
        attrs = rows.loc[rows['datetime'].idxmax()]
        data = self.get_data(attrs['dset'])

        return data, attrs
コード例 #5
0
    def delete_by_query(self, queryDict=None, timestr=None):
        """Delete stored datasets selected by a query and/or a minimum timestamp.

        Args:
            queryDict: optional column -> value filter passed to ``pd_query``.
            timestr: optional time string parsed with ``self.pandasTimeFormat``; only rows
                whose 'datetime' is at or after it are deleted.

        Raises:
            ValueError: if neither ``queryDict`` nor ``timestr`` is given.

        Does nothing when the storage file ``self.fname`` does not exist.
        """
        # Validate arguments first so misuse is reported even when nothing is stored yet
        if queryDict is None and timestr is None:
            raise ValueError('Must specify what to delete')

        if not os.path.isfile(self.fname):
            return

        rows = self.list_dsets_pd()
        if queryDict is not None:
            rows = pd_query(rows, queryDict)
        if timestr is not None:
            timeObj = datetime.strptime(timestr, self.pandasTimeFormat)
            rows = rows[rows['datetime'] >= timeObj]

        self.delete_rows(rows)
コード例 #6
0
def barplot_stacked(ax, df, xKey, yKey):
    """Draw a stacked bar plot of ``yKey`` against ``xKey``.

    Every column other than ``xKey`` and ``yKey`` is treated as a sweep parameter. One
    bar layer is drawn per combination produced by
    ``pandas_helper.outer_product_df`` from the unique values of those columns. Layers
    are stacked in that order and labelled with the combination's values. Heights are
    aligned by x value. A combination with no row for some x contributes 0 there, and
    duplicate x values within one combination raise a ValueError.

    Args:
        ax: matplotlib Axes to draw on.
        df: long-format dataframe.
        xKey: column holding the bar positions/categories.
        yKey: column holding the bar heights.
    """
    # Keep dataframe column order and first-appearance value order, so layers/legend are reproducible
    sweepCols = [col for col in df.columns if col not in (xKey, yKey)]
    sweepVals = {key: list(df[key].unique()) for key in sweepCols}
    sweepDF = pandas_helper.outer_product_df(sweepVals)

    xVals = df[xKey].unique()
    bottom = np.zeros(len(xVals))
    for _, row in sweepDF.iterrows():
        queryDict = dict(row)
        dfThis = pandas_helper.pd_query(df, queryDict)

        # Align this layer to the shared x order instead of relying on row order
        heights = dfThis.set_index(xKey)[yKey].reindex(xVals, fill_value=0).to_numpy(dtype=float)
        ax.bar(xVals,
               heights,
               bottom=bottom,
               label=str(list(queryDict.values())))
        bottom += heights

    ax.set_xlabel(xKey)
    ax.set_ylabel(yKey)
    ax.legend()
コード例 #7
0
def plot_singlets(dataDB, h5fname, dfSummary, nTop=20, dropChannels=None):
    """Bar-plot, per PID type, the mean bits of every channel ("singlet") for each mouse.

    ``dfSummary`` is grouped by every column except 'key' and 'mousename'. Each group gets
    one figure with one row per PID type ('unique', 'syn', 'red'). For every mouse and
    channel label, the column 'muTrue_<mousename>' is averaged over the rows selected by
    ``pd_query(..., {'T': label})``. The per-mouse values are drawn with
    ``barplot_stacked_indexed`` and the figure is shown.

    Args:
        dataDB: provides the channel labels and their canonical names.
        h5fname: HDF5 file passed to ``read_parse_joint_dataframe``.
        dfSummary: summary dataframe with at least 'key' and 'mousename' columns.
        nTop: unused; kept for interface compatibility.
        dropChannels: forwarded to ``read_parse_joint_dataframe``.
    """
    labelMap = dataDB.map_channel_labels_canon()
    dataLabels = dataDB.get_channel_labels()
    dataLabelsCanon = [labelMap[label] for label in dataLabels]

    pidTypes = ['unique', 'syn', 'red']

    groupLst = sorted(set(dfSummary.columns) - {'key', 'mousename'})
    for key, dataMouse in dfSummary.groupby(groupLst):
        # A single grouping column may yield a scalar key (older pandas); normalize to a tuple
        keyTuple = key if isinstance(key, tuple) else (key,)

        mice = sorted(set(dataMouse['mousename']))
        dfJointDict = read_parse_joint_dataframe(dataMouse,
                                                 h5fname,
                                                 mice,
                                                 pidTypes,
                                                 dropChannels=dropChannels)

        fig, ax = plt.subplots(nrows=len(pidTypes),
                               figsize=(len(pidTypes) * 6, 12),
                               tight_layout=True)
        for iPid, pidType in enumerate(pidTypes):
            rezDict = {}
            for mousename in mice:
                rezDict[mousename] = [
                    np.mean(
                        pd_query(dfJointDict[pidType],
                                 {'T': label})['muTrue_' + mousename])
                    for label in dataLabels
                ]

            barplot_stacked_indexed(ax[iPid],
                                    rezDict,
                                    xTickLabels=dataLabelsCanon,
                                    xLabel='singlet',
                                    yLabel='bits',
                                    title=pidType,
                                    iMax=None,
                                    rotation=90)

        # Group values need not be strings (e.g. numbers or None)
        fig.suptitle('_'.join(str(k) for k in keyTuple))
        plt.show()
コード例 #8
0
def plot_2D_outer_bymouse(dataDB, ds, metricName, nameSuffix='bymouse'):
    """Plot a per-dataset scalar metric as a (datatype, mouse) x (interval, trial type) table.

    Datasets whose 'metric' is ``metricName`` and whose 'name' is ``nameSuffix`` are loaded
    with ``ds.get_data`` into a new column named ``metricName``. Rows with intervName 'AVG'
    are excluded. The table is drawn with ``plot_df_2D_outer_product``, with intervals
    ordered as in ``dataDB.get_interval_names()``. It is saved to
    '<metricName>_<nameSuffix>.svg' and shown.
    """
    # Read dataframe, filter out desired metric
    df = ds.list_dsets_pd()
    dfEff = pd_query(df, {'metric': metricName, 'name': nameSuffix}).copy()

    # Read and append the scalar result of every dataset
    dfEff[metricName] = np.array([ds.get_data(dset) for dset in dfEff['dset']])

    # Drop bookkeeping columns (the result must be assigned; drop is not in-place).
    # errors='ignore' because not every listing has all of these columns.
    dfEff = dfEff.drop(['dset', 'shape', 'datetime', 'target_dim', 'zscoreDim'],
                       axis=1, errors='ignore')

    # Drop averages for this plot
    dfEff = dfEff.loc[dfEff['intervName'] != 'AVG']

    fig, ax = plt.subplots(figsize=(4, 4))
    # Plot the column that was just filled, rather than a hard-coded metric name
    plot_df_2D_outer_product(ax, dfEff, ['datatype', 'mousename'], ['intervName', 'trialType'],
                             metricName, orderDict={'intervName': dataDB.get_interval_names()}, vmin=1)
    plt.savefig(metricName + '_' + nameSuffix + '.svg')
    plt.show()
コード例 #9
0
def plot_consistency_bytrialtype(h5fname,
                                 dfSummary,
                                 dropChannels=None,
                                 performance=None,
                                 datatype=None,
                                 trialTypes=None,
                                 kind='point',
                                 fisherThr=0.1,
                                 limits=None):
    """Plot how consistent PID link values are across trial types, per mouse and phase.

    ``dfSummary`` is filtered by ``datatype`` (and by ``performance`` if given). For every
    (mousename, phase) group and every PID type ('unique', 'syn', 'red'), the link values
    in column 'muTrue_<mousename>' returned by ``read_parse_joint_dataframe`` are compared
    between every pair of trial types. The pairwise matrix, its mean off-diagonal value per
    (mouse, phase), and a pair plot are saved under 'pics/'.

    Args:
        h5fname: HDF5 file passed to ``read_parse_joint_dataframe``.
        dfSummary: summary dataframe with at least 'mousename', 'phase', 'trialType' and
            'datatype' columns ('performance' too when filtering by it).
        dropChannels: forwarded to ``read_parse_joint_dataframe``.
        performance: optional performance value to filter on; also used in file names.
        datatype: datatype value to filter on; also used in file names.
        trialTypes: trial types to compare. If None, each (mousename, phase) group uses
            the sorted set of its own trial types.
        kind: how to measure and plot consistency.
            'fisher': binarize at ``fisherThr``; consistency is Cohen's kappa; the pair
            plot shows normalized contingency tables.
            'point': consistency is Pearson correlation; a seaborn pairplot is drawn.
            'heatmap': consistency is Pearson correlation; 2D histograms are drawn.
        fisherThr: binarization threshold used when ``kind == 'fisher'``.
        limits: optional (vmin, vmax) for the consistency heatmaps.

    Raises:
        ValueError: if ``kind`` is not 'fisher', 'point' or 'heatmap'.
    """
    # An unknown kind would otherwise save whatever figure happens to be current
    if kind not in ('fisher', 'point', 'heatmap'):
        raise ValueError("kind must be one of 'fisher', 'point', 'heatmap'; got " + repr(kind))

    pidTypes = ['unique', 'syn', 'red']
    limitKWargs = {'vmin': limits[0], 'vmax': limits[1]} if limits is not None else {}

    queryDict = {'datatype': datatype}
    if performance is not None:
        queryDict['performance'] = performance
    dfSummaryEff = pd_query(dfSummary, queryDict)

    dfColumns = ['mousename', 'phase', 'consistency']
    dfConsistencyDict = {pidType: pd.DataFrame(columns=dfColumns) for pidType in pidTypes}

    for (mousename, phase), df1 in dfSummaryEff.groupby(['mousename', 'phase']):
        fnameSuffix = '_'.join([mousename, str(datatype), phase, str(performance)])

        # Resolve into a new name: overwriting the `trialTypes` argument would make every
        # later (mouse, phase) group reuse the first group's trial types.
        trialTypesThis = trialTypes if trialTypes is not None else sorted(set(df1['trialType']))
        nTrialTypes = len(trialTypesThis)

        # Per PID type: one column of link values per trial type
        dfTrialTypeDict = {pidType: pd.DataFrame() for pidType in pidTypes}

        # Scalar grouper so trialType is a plain value (pandas >= 2.0 yields 1-tuples for list groupers)
        for trialType, dfTrialType in df1.groupby('trialType'):
            if trialType in trialTypesThis:
                dfTmp = read_parse_joint_dataframe(dfTrialType,
                                                   h5fname, [mousename],
                                                   pidTypes,
                                                   dropChannels=dropChannels)
                for pidType in pidTypes:
                    dfTrialTypeDict[pidType][trialType] = dfTmp[pidType]['muTrue_' + mousename]

        for pidType in pidTypes:
            maxRange = 0.35 if pidType == 'syn' else 1.0
            dfPid = dfTrialTypeDict[pidType]

            if kind == 'fisher':
                # Consistency is Cohen's kappa between the binarized (>= fisherThr) values of two
                # trial types. Off-diagonal panels show the normalized contingency tables;
                # diagonal panels show histograms with the threshold marked.
                fisherLabels = ['low', 'high']
                rezMat = np.full((nTrialTypes, nTrialTypes), np.nan)
                # squeeze=False keeps ax 2D even for a single trial type
                fig, ax = plt.subplots(nrows=nTrialTypes,
                                       ncols=nTrialTypes,
                                       figsize=(4 * nTrialTypes, 4 * nTrialTypes),
                                       squeeze=False)
                for idxTTi, iTT in enumerate(trialTypesThis):
                    ax[idxTTi, 0].set_ylabel(iTT)
                    ax[-1, idxTTi].set_xlabel(iTT)
                    for idxTTj, jTT in enumerate(trialTypesThis):
                        if idxTTi == idxTTj:
                            ax[idxTTi][idxTTj].hist(dfPid[iTT], range=[0, maxRange], bins=50)
                            ax[idxTTi][idxTTj].axvline(x=fisherThr, linestyle='--', color='pink')
                        else:
                            iBin = dfPid[iTT] >= fisherThr
                            jBin = dfPid[jTT] >= fisherThr
                            # Fix both labels so M is always 2x2 and matches the two tick labels,
                            # even when all values fall on one side of the threshold
                            M = confusion_matrix(iBin, jBin, labels=[False, True])
                            M = M.astype(float) / np.sum(M)

                            rezMat[idxTTi][idxTTj] = cohen_kappa_score(iBin, jBin)

                            sns.heatmap(ax=ax[idxTTi][idxTTj],
                                        data=M,
                                        annot=True,
                                        cmap='jet',
                                        xticklabels=fisherLabels,
                                        yticklabels=fisherLabels)
            else:
                # Consistency is the correlation coefficient between trial types
                rezMat = np.zeros((nTrialTypes, nTrialTypes))
                for idxTTi, iTT in enumerate(trialTypesThis):
                    for idxTTj, jTT in enumerate(trialTypesThis):
                        rezMat[idxTTi][idxTTj] = np.corrcoef(dfPid[iTT], dfPid[jTT])[0, 1]

                if kind == 'point':
                    # Scatter pair plot (seaborn creates and activates its own figure)
                    sns.pairplot(data=dfPid, vars=trialTypesThis)
                else:
                    # kind == 'heatmap': 2D histograms of binned scatter points
                    fig, ax = plt.subplots(nrows=nTrialTypes,
                                           ncols=nTrialTypes,
                                           figsize=(4 * nTrialTypes, 4 * nTrialTypes),
                                           squeeze=False)
                    for idxTTi, iTT in enumerate(trialTypesThis):
                        ax[idxTTi, 0].set_ylabel(iTT)
                        ax[-1, idxTTi].set_xlabel(iTT)
                        for idxTTj, jTT in enumerate(trialTypesThis):
                            if idxTTi == idxTTj:
                                ax[idxTTi][idxTTj].hist(dfPid[iTT], range=[0, maxRange], bins=50)
                            else:
                                ax[idxTTi][idxTTj].hist2d(dfPid[iTT],
                                                          dfPid[jTT],
                                                          range=[[0, maxRange], [0, maxRange]],
                                                          bins=[50, 50],
                                                          cmap='jet')

            # Saves the current figure: the pair plot / panel grid created just above
            plt.savefig('pics/' + pidType + '_consistency_bymouse_scatter_' + fnameSuffix + '.png')
            plt.close()

            fig, ax = plt.subplots()
            sns.heatmap(ax=ax,
                        data=rezMat,
                        annot=True,
                        cmap='jet',
                        xticklabels=trialTypesThis,
                        yticklabels=trialTypesThis,
                        **limitKWargs)
            plt.savefig('pics/' + pidType + '_consistency_bymouse_metric_' + fnameSuffix + '.png')
            plt.close()

            avgConsistency = np.round(np.mean(offdiag_1D(rezMat)), 2)
            dfConsistencyDict[pidType] = pd_append_row(
                dfConsistencyDict[pidType], [mousename, phase, avgConsistency])

    for pidType in pidTypes:
        fig, ax = plt.subplots()
        dfPivot = pd_pivot(dfConsistencyDict[pidType], *dfColumns)
        sns.heatmap(data=dfPivot, ax=ax, annot=True, cmap='jet', **limitKWargs)
        fig.savefig('pics/' + 'summary_' + pidType + '_consistency_metric_' +
                    str(datatype) + '_' + str(performance) + '.png')
        plt.close()
コード例 #10
0
def test_avg_bits(dataDB, mc, h5fname, h5fnameRand, dfSummary, dfSummaryRand):
    """Compare measured PID bits against shuffled surrogates for every mouse.

    Plan:
    1. Loop over groups of dfSummaryRand (all columns except 'key' and 'mousename').
    2. Query dfSummary with that group's column values.
    3. Loop over groups of the queried rows by their remaining columns.
    4. For every row of such a group, find the mouse's single row in the shuffle group.
    5. Read both results from HDF5 (``h5fname`` / ``h5fnameRand``) and preprocess them.
    6. Violin-plot 'Measured' vs 'Shuffle' per PID type and print a one-sided
       Mann-Whitney U p-value (measured greater than shuffle) per mouse.

    ``dataDB`` and ``mc`` are accepted for interface compatibility but not used.
    """
    pidTypes = ['unique', 'syn', 'red']

    groupLstRand = sorted(set(dfSummaryRand.columns) - {'key', 'mousename'})
    print(set(dfSummaryRand['performance']))

    print(groupLstRand)
    for keyRand, dfRandMouse in dfSummaryRand.groupby(groupLstRand):
        print(keyRand)
        # A single grouping column may yield a scalar key (older pandas); normalize to a tuple
        keyRand = keyRand if isinstance(keyRand, tuple) else (keyRand,)

        selectorLstRand = dict(zip(groupLstRand, keyRand))
        dfSummQueried = pd_query(dfSummary, selectorLstRand)

        groupLstQueried = sorted(
            set(dfSummQueried.columns) - {'key', 'mousename'} - set(groupLstRand))
        for key, dfMouse in dfSummQueried.groupby(groupLstQueried):
            print('--', key)
            key = key if isinstance(key, tuple) else (key,)

            # Collect frames and concatenate once (DataFrame.append was removed in pandas 2.0)
            frames = []
            for _, row in dfMouse.iterrows():
                # Read and preprocess true data
                dfRezTrue = pd.read_hdf(h5fname, row['key'])
                dfRezTrue = preprocess_unique(dfRezTrue)
                dfRezTrue = preprocess_drop_negative(dfRezTrue)
                dfRezTrue['type'] = 'Measured'
                dfRezTrue['mousename'] = row['mousename']
                frames.append(dfRezTrue)

                # Read and preprocess the shuffled data of the same mouse
                rowRand = pd_is_one_row(
                    pd_query(dfRandMouse, {'mousename': row['mousename']}))[1]
                dfRezRand = pd.read_hdf(h5fnameRand, rowRand['key'])
                dfRezRand = preprocess_unique(dfRezRand)
                dfRezRand = preprocess_drop_negative(dfRezRand)
                dfRezRand['type'] = 'Shuffle'
                dfRezRand['mousename'] = rowRand['mousename']
                frames.append(dfRezRand)
            dfTot = pd.concat(frames)

            # Barplot differences
            fig, ax = plt.subplots(ncols=3, figsize=(12, 4))
            # Group values need not be strings
            fig.suptitle('_'.join(str(el) for el in key + keyRand))
            for iPid, pidType in enumerate(pidTypes):
                dfPID = dfTot[dfTot['PID'] == pidType]
                sns.violinplot(ax=ax[iPid],
                               x="mousename",
                               y="muTrue",
                               hue="type",
                               data=dfPID,
                               scale='width',
                               cut=0)

                for mousename in sorted(set(dfPID['mousename'])):
                    dataTrue = pd_query(dfPID, {
                        'mousename': mousename,
                        'type': 'Measured'
                    })['muTrue']
                    dataRand = pd_query(dfPID, {
                        'mousename': mousename,
                        'type': 'Shuffle'
                    })['muTrue']
                    print(
                        'Test:', pidType, mousename, 'pval =',
                        mannwhitneyu(dataTrue, dataRand,
                                     alternative='greater')[1])

                ax[iPid].set_yscale('log')
                ax[iPid].set_ylabel('Bits')
                ax[iPid].set_title(pidType)
            plt.show()
コード例 #11
0
def get_sessions(dfRawH5, mousename):
    """List the entry names of the 'data' group in a mouse's raw HDF5 file.

    The file path is taken from the 'path' column of the ``dfRawH5`` row selected by
    ``mousename`` (via ``pd_query`` / ``pd_is_one_row``).
    """
    _, mouseRow = pd_is_one_row(pd_query(dfRawH5, {'mousename': mousename}))
    with h5py.File(mouseRow['path'], 'r') as rawFile:
        sessionNames = list(rawFile['data'].keys())
    return sessionNames
コード例 #12
0
 def get_delay_length(self, mousename, session):
     """Return the 'delay' entry of the ``self.dfSessions`` row for this mouse and session."""
     sessionQuery = {'mousename': mousename, 'session': session}
     _, sessionRow = pd_is_one_row(pd_query(self.dfSessions, sessionQuery))
     return sessionRow['delay']
コード例 #13
0
def plot_consistency_significant_activity_byphase(dataDB,
                                                  ds,
                                                  intervals,
                                                  minTrials=10,
                                                  performance=None,
                                                  dropChannels=None):
    """Measure how consistent per-channel activity changes between two intervals are across mice.

    The datasets are grouped by (datatype, trialType) and then by mouse. For each session,
    the datasets of ``intervals[0]`` and ``intervals[1]`` are compared channel by channel
    with a two-sided Wilcoxon signed-rank test, and ``-log10(p)`` is averaged over sessions.
    The per-mouse channel profiles are then correlated between every pair of mice. The
    mean off-diagonal correlation is summarized in a final (datatype x trialType) heatmap.

    Args:
        dataDB: provides ``find_mouse_by_session``, ``is_matching_performance`` and ``mice``.
        ds: dataset store providing ``list_dsets_pd`` and ``get_data``.
        intervals: pair of 'intervName' values to compare.
        minTrials: sessions where either dataset has fewer trials (rows) are skipped.
        performance: if given, only sessions matching this performance are used.
        dropChannels: optional channel indices removed before testing.

    Side effects:
        Writes SVG figures under 'pics/consistency/significant_activity/byphase/'.
    """
    rows = ds.list_dsets_pd()
    rows['mousename'] = [
        dataDB.find_mouse_by_session(session) for session in rows['session']
    ]

    dfColumns = ['datatype', 'trialType', 'consistency']
    dfConsistency = pd.DataFrame(columns=dfColumns)

    for (datatype, trialType), rowsMouse in rows.groupby(['datatype', 'trialType']):
        pSigDict = {}
        # Scalar groupers so keys are plain values (pandas >= 2.0 yields 1-tuples for list groupers)
        for mousename, rowsSession in rowsMouse.groupby('mousename'):
            pSig = []
            for session, rowsTrial in rowsSession.groupby('session'):
                if (performance is not None) and not dataDB.is_matching_performance(
                        session, performance, mousename=mousename):
                    continue

                assert intervals[0] in list(rowsTrial['intervName'])
                assert intervals[1] in list(rowsTrial['intervName'])
                dsetLabel1 = pd_is_one_row(
                    pd_query(rowsTrial, {'intervName': intervals[0]}))[1]['dset']
                dsetLabel2 = pd_is_one_row(
                    pd_query(rowsTrial, {'intervName': intervals[1]}))[1]['dset']
                data1 = ds.get_data(dsetLabel1)
                data2 = ds.get_data(dsetLabel2)

                # Data is indexed [trial, channel] (channels are axis 1, see data1[:, iCh]),
                # so the trial count is axis 0 for both datasets
                nTrials1 = data1.shape[0]
                nTrials2 = data2.shape[0]

                if (nTrials1 < minTrials) or (nTrials2 < minTrials):
                    print(session, datatype, trialType, 'too few trials',
                          nTrials1, nTrials2, ';; skipping')
                    continue

                if dropChannels is not None:
                    channelMask = np.ones(data1.shape[1]).astype(bool)
                    channelMask[dropChannels] = 0
                    data1 = data1[:, channelMask]
                    data2 = data2[:, channelMask]
                # Count channels after masking (robust to repeated indices in dropChannels)
                nChannels = data1.shape[1]

                pvals = [
                    wilcoxon(data1[:, iCh], data2[:, iCh], alternative='two-sided')[1]
                    for iCh in range(nChannels)
                ]
                pSig += [-np.log10(np.array(pvals))]
            pSigDict[mousename] = np.mean(pSig, axis=0)

        mice = sorted(dataDB.mice)
        nMice = len(mice)
        corrCoef = np.zeros((nMice, nMice))
        for iMouse, iName in enumerate(mice):
            for jMouse, jName in enumerate(mice):
                corrCoef[iMouse, jMouse] = np.corrcoef(pSigDict[iName],
                                                       pSigDict[jName])[0, 1]

        sns.pairplot(data=pd.DataFrame(pSigDict), vars=mice)

        prefixPath = 'pics/consistency/significant_activity/byphase/bymouse/'
        make_path(prefixPath)
        plt.savefig(prefixPath + datatype + '_' + trialType + '.svg')
        plt.close()

        fig2, ax2 = plt.subplots()
        ax2.imshow(corrCoef, vmin=0, vmax=1)
        imshow(fig2,
               ax2,
               corrCoef,
               title='Significance Correlation',
               haveColorBar=True,
               limits=[0, 1],
               xTicks=mice,
               yTicks=mice)

        prefixPath = 'pics/consistency/significant_activity/byphase/bymouse_corr/'
        make_path(prefixPath)
        plt.savefig(prefixPath + datatype + '_' + trialType + '.svg')
        plt.close()

        avgConsistency = np.round(np.mean(offdiag_1D(corrCoef)), 2)
        dfConsistency = pd_append_row(dfConsistency,
                                      [datatype, trialType, avgConsistency])

    fig, ax = plt.subplots()
    dfPivot = pd_pivot(dfConsistency, *dfColumns)
    sns.heatmap(data=dfPivot, ax=ax, annot=True, vmax=1, cmap='jet')

    prefixPath = 'pics/consistency/significant_activity/byphase/'
    make_path(prefixPath)
    fig.savefig(prefixPath + 'consistency_' + str(performance) + '.svg')
    plt.close()
コード例 #14
0
def significance_brainplot_mousephase_byaction(
        dataDB,
        ds,
        performance=None,
        metric='accuracy',
        minTrials=10,
        limits=(0.5, 1.0),
        fontsize=20):
    """Brain-plot a per-channel metric comparing two datasets, one panel per (mouse, interval).

    For every datatype, one figure is produced with a row per mouse and a column per
    interval. Within each panel, every session contributes ``testFunc(a[:, ch], b[:, ch])``
    for every channel, where ``a`` and ``b`` are the session's first two datasets (in row
    order) and ``testFunc = test_metric_by_name(metric)``. The per-channel values are
    averaged over sessions and drawn with ``dataDB.plot_area_values``.

    Args:
        dataDB: provides mice, interval names, session lookup and area plotting.
        ds: dataset store providing ``list_dsets_pd`` and ``get_data``.
        performance: if given, only sessions matching this performance are used.
        metric: name passed to ``test_metric_by_name``.
        minTrials: sessions where either dataset has fewer trials (rows) are skipped.
        limits: (vmin, vmax) of the color scale.
        fontsize: font size of the row/column labels.

    Side effects:
        Writes one SVG per datatype under 'pics/significance/brainplot_mousephase/byaction/'.
    """
    testFunc = test_metric_by_name(metric)

    rows = ds.list_dsets_pd()
    rows['mousename'] = [
        dataDB.find_mouse_by_session(session) for session in rows['session']
    ]

    intervNames = dataDB.get_interval_names()
    mice = sorted(dataDB.mice)
    nInterv = len(intervNames)
    nMice = len(mice)

    # Scalar groupers so keys are plain values (pandas >= 2.0 yields 1-tuples for list groupers)
    for datatype, dfDataType in rows.groupby('datatype'):
        # squeeze=False keeps ax 2D even with a single mouse or interval
        fig, ax = plt.subplots(nrows=nMice,
                               ncols=nInterv,
                               figsize=(4 * nInterv, 4 * nMice),
                               tight_layout=True,
                               squeeze=False)

        for iInterv, intervName in enumerate(intervNames):
            ax[0][iInterv].set_title(intervName, fontsize=fontsize)
            for iMouse, mousename in enumerate(mice):
                ax[iMouse][0].set_ylabel(mousename, fontsize=fontsize)

                queryDict = {'mousename': mousename, 'intervName': intervName}
                rowsSession = pd_query(dfDataType, queryDict)
                if len(rowsSession) == 0:
                    continue

                pSig = []
                for session, rowsTrial in rowsSession.groupby('session'):
                    if (performance is not None) and not dataDB.is_matching_performance(
                            session, performance, mousename=mousename):
                        continue

                    dataThis = [ds.get_data(dset) for dset in rowsTrial['dset']]
                    # Two datasets are compared; skip sessions that do not provide both
                    if len(dataThis) < 2:
                        print(session, datatype, intervName,
                              'fewer than two datasets', len(dataThis), ';; skipping')
                        continue

                    nChannels = dataThis[0].shape[1]
                    nTrials1 = dataThis[0].shape[0]
                    nTrials2 = dataThis[1].shape[0]

                    if (nTrials1 < minTrials) or (nTrials2 < minTrials):
                        print(session, datatype, intervName,
                              'too few trials', nTrials1, nTrials2,
                              ';; skipping')
                        continue

                    pSig += [[
                        testFunc(dataThis[0][:, iCh], dataThis[1][:, iCh])
                        for iCh in range(nChannels)
                    ]]

                print(intervName, mousename, np.array(pSig).shape)
                if not pSig:
                    # Every session was skipped: nothing to average, leave this panel empty
                    continue

                pSigAvg = np.mean(pSig, axis=0)

                dataDB.plot_area_values(fig,
                                        ax[iMouse][iInterv],
                                        pSigAvg,
                                        vmin=limits[0],
                                        vmax=limits[1],
                                        cmap='jet',
                                        haveColorBar=iInterv == nInterv - 1)

        plotSuffix = '_'.join([datatype, str(performance), metric])
        prefixPath = 'pics/significance/brainplot_mousephase/byaction/'
        make_path(prefixPath)
        fig.savefig(prefixPath + plotSuffix + '.svg')
        plt.close(fig)
コード例 #15
0
 def get_rows(self, frameName, coldict):
     """Return the rows of metadata frame ``frameName`` selected by ``pd_query`` with ``coldict``."""
     metaFrame = self.metaDataFrames[frameName]
     return pd_query(metaFrame, coldict)
コード例 #16
0
 def ping_data(self, name, attrDict):
     """Return the dataset-listing rows matching ``attrDict`` plus ``'name' == name``.

     A copy of ``attrDict`` is extended, so the caller's dict is left unchanged.
     """
     fullQuery = attrDict.copy()
     fullQuery['name'] = name
     listing = self.list_dsets_pd()
     return pd_query(listing, fullQuery)
コード例 #17
0
def scatter_metric_bulk(ds, metricName, nameSuffix, prepFunc=None, xlim=None, ylim=None, yscale=None,
                        verbose=True, xFunc=None, haveRegression=False):
    """Scatter-plot a metric for every mouse, one figure per combination of the other attributes.

    Selects the datasets listed by ``ds`` whose 'metric' equals ``metricName`` and whose
    'name' equals ``nameSuffix``. The 'target_dim', 'datetime' and 'shape' columns are
    ignored. If a 'performance' column exists, only rows where it is 'None' (missing) are
    kept. The rows are grouped by every remaining column except 'mousename' and 'dset'.
    Each group is saved as one PNG in 'pics/bulk/<metricName>/' with one point series per
    mouse, named by the group's attribute values joined with '_'.

    Args:
        ds: dataset store providing ``list_dsets_pd`` and ``get_data``.
        metricName: value of the 'metric' column to select; also the output folder.
        nameSuffix: value of the 'name' column to select.
        prepFunc: optional transform applied to each dataset before plotting.
        xlim, ylim: passed to ``ax.set_xlim`` / ``ax.set_ylim``.
        yscale: optional matplotlib y-scale name, e.g. 'log'.
        verbose: if True, print group keys, plotted rows and data shapes.
        xFunc: optional ``xFunc(mousename, nPoints)`` returning x coordinates; defaults to 0..n-1.
        haveRegression: if True, overlay a seaborn regression line fitted to all mice's points.
    """
    dfAll = ds.list_dsets_pd().fillna('None')

    dfAnalysis = pd_query(dfAll, {'metric': metricName, 'name': nameSuffix})
    # Move leading columns forwards for more informative printing/saving
    dfAnalysis = pd_move_cols_front(dfAnalysis, ['metric', 'name', 'mousename'])
    dfAnalysis = dfAnalysis.drop(['target_dim', 'datetime', 'shape'], axis=1)

    # Keep only results not split by performance (missing values became 'None' above)
    if 'performance' in dfAnalysis.columns:
        dfAnalysis = dfAnalysis[dfAnalysis['performance'] == 'None'].drop(['performance'], axis=1)

    # One figure per combination of all other columns; sorted so the grouping order is deterministic
    colsExcl = sorted(set(dfAnalysis.columns) - {'mousename', 'dset'})

    for colVals, dfSub in dfAnalysis.groupby(colsExcl):
        fig, ax = plt.subplots()

        if verbose:
            # A single grouping column may yield a scalar key; wrap it so a string is not split into characters
            print(list(colVals if isinstance(colVals, tuple) else (colVals,)))

        xLst = []
        yLst = []
        for _, rowMouse in dfSub.sort_values(by='mousename').iterrows():
            if verbose:
                print(list(rowMouse.values))

            dataThis = ds.get_data(rowMouse['dset'])

            if prepFunc is not None:
                dataThis = prepFunc(dataThis)

            if xFunc is None:
                x = np.arange(len(dataThis))
            else:
                x = np.array(xFunc(rowMouse['mousename'], len(dataThis)))
            if verbose:
                print(dataThis.shape)

            x, dataThis = drop_nan_rows([x, dataThis])
            if verbose:
                print(dataThis.shape)

            ax.plot(x, dataThis, '.', label=rowMouse['mousename'])
            xLst.append(x)
            yLst.append(dataThis)

        if yscale is not None:
            # Apply to this figure's axes rather than whatever figure is current
            ax.set_yscale(yscale)

        # All rows of a group share the grouping-column values, so the first row names the figure
        dataName = '_'.join(str(el) for el in dfSub.iloc[0].drop(['dset', 'mousename']))

        ax.legend()
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

        if haveRegression:
            sns.regplot(ax=ax, x=np.hstack(xLst), y=np.hstack(yLst), scatter=False)

        prefixPath = 'pics/bulk/' + metricName + '/'
        make_path(prefixPath)

        fig.savefig(prefixPath + dataName + '.png')
        plt.close(fig)
コード例 #18
0
def test_prediction(dataDB, prepData, prepDF, intervNames=None):
    '''
    Per mouse, compare iGO and iNOGO trials for every interval and session.

    For every mouse two figures are shown, each with one column per interval:
      * test figure  - image of -log10(p) of a two-sided Mann-Whitney U test between
                       iGO and iNOGO, computed separately for every channel (rows)
                       and session (columns)
      * class figure - train and test accuracy of a RidgeClassifier evaluated by
                       ``binary_classifier`` (method "looc") across sessions

    Sessions lacking iGO or iNOGO entries in ``prepDF`` are skipped and marked as NaN.

    :param dataDB:      data database; provides ``mice``, ``get_sessions`` and ``get_interval_names``
    :param prepData:    preprocessed data, indexed by the row index of ``prepDF``. Entries are
                        sliced as ``data[:, iCh]`` - presumably (nTrial, nChannel), TODO confirm
    :param prepDF:      dataframe describing ``prepData``; queried by mousename, session,
                        interval and trialType
    :param intervNames: intervals to analyse; defaults to ``dataDB.get_interval_names()``
    '''
    # NOTE(review): channel count is hard-coded and assumed to match prepData - TODO confirm
    nChannel = 48
    classifier = RidgeClassifier(max_iter=10000, alpha=1.0E-2)

    # Resolve the intervals once and materialize them, so that they can be reused for every mouse
    intervNames = list(dataDB.get_interval_names() if intervNames is None else intervNames)
    nInterv = len(intervNames)

    for mousename in sorted(dataDB.mice):
        sessions = dataDB.get_sessions(mousename)
        nSessions = len(sessions)

        # One column per interval. squeeze=False keeps the axes indexable even for a single interval
        figsize = (10 * nInterv / 3, 5)
        figTest, axTest = plt.subplots(ncols=nInterv, figsize=figsize, squeeze=False)
        figClass, axClass = plt.subplots(ncols=nInterv, figsize=figsize, squeeze=False)
        axTest, axClass = axTest[0], axClass[0]
        figTest.suptitle(mousename)
        figClass.suptitle(mousename)

        for iInterv, intervName in enumerate(intervNames):
            testMat = np.zeros((nChannel, nSessions))
            accLst = []

            for iSession, session in enumerate(sessions):
                print(intervName, session)

                queryDict = {
                    'mousename': mousename,
                    'session': session,
                    'interval': intervName
                }
                rowGo = pd_query(prepDF, {**queryDict, 'trialType': 'iGO'})
                rowNogo = pd_query(prepDF, {**queryDict, 'trialType': 'iNOGO'})

                if (len(rowGo) == 0) or (len(rowNogo) == 0):
                    print('Skipping session', session,
                          'because too few trials')
                    testMat[:, iSession] = np.nan
                    accLst += [{'accTrain': np.nan, 'accTest': np.nan}]
                    continue

                idxRowGO, _ = pd_is_one_row(rowGo)
                idxRowNOGO, _ = pd_is_one_row(rowNogo)
                dataGO = prepData[idxRowGO]
                dataNOGO = prepData[idxRowNOGO]

                # Pairwise testing on individual channels
                for iCh in range(nChannel):
                    p = mannwhitneyu(dataGO[:, iCh],
                                     dataNOGO[:, iCh],
                                     alternative='two-sided')[1]
                    testMat[iCh, iSession] = -np.log10(p)

                # Classification of iGO vs iNOGO using all channels
                accLst += [
                    binary_classifier(dataGO,
                                      dataNOGO,
                                      classifier,
                                      method="looc",
                                      balancing=False)
                ]

            # Plot test
            axTest[iInterv].set_title(intervName)
            img = axTest[iInterv].imshow(testMat, vmin=0, vmax=10)
            imshow_add_color_bar(figTest, axTest[iInterv], img)

            # Plot classification; the dashed line marks chance level for two classes
            axClass[iInterv].set_title(intervName)
            axClass[iInterv].plot([l['accTrain'] for l in accLst],
                                  label='train')
            axClass[iInterv].plot([l['accTest'] for l in accLst], label='test')
            axClass[iInterv].axhline(y=0.5, linestyle='--', color='pink')
            axClass[iInterv].set_xlim(0, nSessions)
            axClass[iInterv].set_ylim(0, 1)
            axClass[iInterv].legend()

        plt.show()
コード例 #19
0
def _barplot_conditions_dataframe(ds, dfDsets, labelCol, columns):
    '''
    Load every dataset listed in ``dfDsets`` into a single long-format dataframe.

    :param ds:       dataset storage providing ``get_data``
    :param dfDsets:  dataframe with 'mousename', 'dset' and ``labelCol`` columns
    :param labelCol: column of ``dfDsets`` copied into the second output column
    :param columns:  names of the three output columns
    :return:         dataframe with one row [mousename, label, value] per element of every dataset
    '''
    # Collect plain rows and build the dataframe once, instead of appending row by row
    rows = [
        [row['mousename'], row[labelCol], value]
        for _, row in dfDsets.iterrows()
        for value in ds.get_data(row['dset'])
    ]
    return pd.DataFrame(rows, columns=columns)


def barplot_conditions(ds, metricName, nameSuffix, verbose=True, trialTypes=None, intervNames=None):
    '''
    Make bar plots of a metric, sweeping over datatypes (and performance, if that column exists).

    For every sweep group two families of plots are saved to
    ``pics/bulk/<metricName>/barplot_conditions/``:
      1. (Mouse * trialType), one plot per interval
      2. (Mouse * interval), 'AVG' excluded, one plot per trialType (always including 'None')

    :param ds:          dataset storage; provides ``list_dsets_pd`` and ``get_data``
    :param metricName:  value of the 'metric' column to select; also the y-axis column name
    :param nameSuffix:  value of the 'name' column to select
    :param verbose:     if True, print the suffix of every sweep group and skipped plots
    :param trialTypes:  trial types to plot, in this order; defaults to all trial types in the group
    :param intervNames: intervals to plot, in this order; defaults to all intervals in the group
    '''

    # 1. Extract all results for this test
    dfAll = ds.list_dsets_pd().fillna('None')

    dfAnalysis = pd_query(dfAll, {'metric' : metricName, "name" : nameSuffix})
    dfAnalysis = pd_move_cols_front(dfAnalysis, ['metric', 'name', 'mousename'])  # Move leading columns forwards for more informative printing/saving
    dfAnalysis = dfAnalysis.drop(['target_dim', 'datetime', 'shape'], axis=1)

    sweepLst = ['datatype', 'performance'] if 'performance' in dfAnalysis.columns else ['datatype']

    prefixPath = 'pics/bulk/' + metricName + '/barplot_conditions/'
    make_path(prefixPath)

    for key, dfDataType in dfAnalysis.groupby(sweepLst):
        # Depending on the pandas version and number of grouping columns the key is a scalar
        # or a tuple. Normalize it to a tuple, so that its first element is always the datatype
        keyTuple = key if isinstance(key, tuple) else (key,)
        datatype = keyTuple[0]
        plotSuffix = '_'.join(str(k) for k in keyTuple)

        if verbose:
            print(plotSuffix)

        # Sorted to get a deterministic plot order
        intervNamesData = sorted(set(dfDataType['intervName']))
        trialTypesData = sorted(set(dfDataType['trialType']))

        # Keep the requested order but drop values absent from this group. New local names are
        # used so that this group's selection does not leak into the next group
        intervNamesThis = intervNamesData if intervNames is None else [i for i in intervNames if i in intervNamesData]
        trialTypesThis = trialTypesData if trialTypes is None else [t for t in trialTypes if t in trialTypesData]

        #################################
        # Plot 1 ::: Mouse * TrialType
        #################################

        for intervName in intervNamesThis:
            df1 = pd_query(dfDataType, {'intervName' : intervName})
            df1 = df1[df1['trialType'].isin(trialTypesThis)]

            dfData1 = _barplot_conditions_dataframe(ds, df1, 'trialType', ['mousename', 'trialType', metricName])
            if len(dfData1) == 0:
                if verbose:
                    print('No data for', plotSuffix, intervName, '- skipping')
                continue

            mice = sorted(set(dfData1['mousename']))
            fig, ax = plt.subplots()
            sns_barplot(ax, dfData1, "mousename", metricName, 'trialType', annotHue=True, xOrd=mice, hOrd=trialTypesThis)

            fig.savefig(prefixPath + 'barplot_trialtype_' + plotSuffix + '_' + intervName + '.png', dpi=300)
            plt.close(fig)

        #################################
        # Plot 2 ::: Mouse * Phase
        #################################

        # 'AVG' is excluded from the phase plots; for 'bn_trial' data 'PRE' is excluded as well.
        # The hue order is built from the same set, so no empty hue slots are requested
        intervExcl = {'AVG', 'PRE'} if datatype == 'bn_trial' else {'AVG'}
        phaseOrd = [i for i in intervNamesThis if i not in intervExcl]

        # 'None' always comes first and is plotted only once
        for trialType in ['None'] + [t for t in trialTypesThis if t != 'None']:
            df2 = pd_query(dfDataType, {'trialType' : trialType})
            df2 = df2[~df2['intervName'].isin(intervExcl)]

            dfData2 = _barplot_conditions_dataframe(ds, df2, 'intervName', ['mousename', 'phase', metricName])
            if len(dfData2) == 0:
                if verbose:
                    print('No data for', plotSuffix, trialType, '- skipping')
                continue

            dfData2 = dfData2.sort_values('mousename')

            mice = sorted(set(dfData2['mousename']))
            fig, ax = plt.subplots()
            sns_barplot(ax, dfData2, "mousename", metricName, 'phase', annotHue=False, xOrd=mice, hOrd=phaseOrd)

            fig.savefig(prefixPath + 'barplot_phase_' + plotSuffix + '_' + trialType + '.png', dpi=300)
            plt.close(fig)