Example #1
0
def classification_accuracy_brainplot_mousephase(dataDB,
                                                 exclQueryLst,
                                                 fontsize=20,
                                                 trialType='auto',
                                                 **kwargs):
    """Brain-plot texture classification accuracy for every mouse and phase.

    One figure is saved for every combination of swept parameters other than
    ``mousename`` and ``intervName``; inside a figure rows are mice and columns
    are intervals. For each cell, every requested trial type is loaded with
    ``get_data_avg(..., avgAxes=1)``. Trial types 0 and 1 are pooled into one
    class and trial types 2 and 3 into the other. Then
    ``classification_accuracy_weighted`` is evaluated for every column of the
    pooled arrays (presumably one column per channel/area -- confirm against
    ``get_data_avg``). Values are drawn with a fixed colour range [0.5, 1.0].

    Args:
        dataDB: project data handle.
        exclQueryLst: parameter combinations excluded from the sweep.
        fontsize: font size of the mouse (row) and interval (column) labels.
        trialType: list of trial-type names, or 'auto' to use
            ``dataDB.get_trial_type_names()``. Only the first four entries are
            used; fewer than four raises IndexError.
        **kwargs: sweep parameters; must contain ``intervName``.

    Side effects:
        Writes ``<suffix>.svg`` files to
        ``pics/classification_accuracy/brainplot_mousephase/``.
    """
    assert 'intervName' in kwargs.keys(), 'Requires phases'
    dps = DataParameterSweep(dataDB, exclQueryLst, mousename='auto', **kwargs)
    nMice = dps.param_size('mousename')
    nInterv = dps.param_size('intervName')

    if trialType == 'auto':
        trialType = dataDB.get_trial_type_names()

    for paramVals, dfTmp in dps.sweepDF.groupby(
            dps.invert_param(['mousename', 'intervName'])):
        plotSuffix = param_vals_to_suffix(paramVals)
        print(plotSuffix)

        # squeeze=False keeps `ax` 2D even with a single mouse or interval,
        # so ax[iMouse][iInterv] is always valid.
        fig, ax = plt.subplots(nrows=nMice,
                               ncols=nInterv,
                               figsize=(4 * nInterv, 4 * nMice),
                               squeeze=False)

        # Group by the bare column name: a one-element list makes pandas >= 2
        # yield 1-tuple keys, which would break param_index and the labels.
        for mousename, dfMouse in dfTmp.groupby('mousename'):
            iMouse = dps.param_index('mousename', mousename)

            ax[iMouse][0].set_ylabel(mousename, fontsize=fontsize)
            for _, row in dfMouse.iterrows():
                intervName = row['intervName']
                iInterv = dps.param_index('intervName', intervName)
                ax[0][iInterv].set_title(intervName, fontsize=fontsize)

                kwargsThis = pd_row_to_kwargs(row,
                                              parseNone=True,
                                              dropKeys=['mousename'])

                # Load every trial type separately; the trial type must be
                # passed explicitly, otherwise every entry is the same data.
                dataRPLst = [
                    get_data_avg(dataDB,
                                 mousename,
                                 avgAxes=1,
                                 **dict(kwargsThis, trialType=tt))
                    for tt in trialType
                ]

                # Split two textures: trial types {0, 1} vs {2, 3}
                dataT1 = np.concatenate([dataRPLst[0], dataRPLst[1]])
                dataT2 = np.concatenate([dataRPLst[2], dataRPLst[3]])

                # One accuracy value per column of the pooled arrays
                svcAcc = [
                    classification_accuracy_weighted(x[:, None], y[:, None])
                    for x, y in zip(dataT1.T, dataT2.T)
                ]

                dataDB.plot_area_values(fig,
                                        ax[iMouse][iInterv],
                                        svcAcc,
                                        vmin=0.5,
                                        vmax=1.0,
                                        cmap='jet')

        prefixPath = 'pics/classification_accuracy/brainplot_mousephase/'
        make_path(prefixPath)
        fig.savefig(prefixPath + plotSuffix + '.svg')
        plt.close(fig)
Example #2
0
def activity_brainplot_mousephase_submouse(dataDB,
                                           exclQueryLst=None,
                                           vmin=None,
                                           vmax=None,
                                           fontsize=20,
                                           dpi=200,
                                           **kwargs):
    """Brain-plot each mouse's mean activity minus the across-mouse mean.

    One figure is saved per combination of swept parameters other than
    ``mousename`` and ``intervName``; rows are mice, columns are intervals.
    For each interval the per-mouse data is averaged with
    ``get_data_avg(..., avgAxes=(0, 1))``. The mean of those per-mouse
    results is subtracted from each mouse before plotting.

    Args:
        dataDB: project data handle.
        exclQueryLst: parameter combinations excluded from the sweep.
        vmin, vmax: colour limits passed to ``plot_area_values``.
        fontsize: font size of the row/column labels.
        dpi: currently unused (output is SVG).
        **kwargs: sweep parameters; must contain ``intervName``.

    Side effects:
        Writes ``<suffix>.svg`` files to
        ``pics/activity/brainplot_mousephase/submouse/``.
    """
    assert 'intervName' in kwargs.keys(), 'Requires phases'
    dps = DataParameterSweep(dataDB, exclQueryLst, mousename='auto', **kwargs)
    nMice = dps.param_size('mousename')
    nInterv = dps.param_size('intervName')

    for paramVals, dfTmp in dps.sweepDF.groupby(
            dps.invert_param(['mousename', 'intervName'])):
        plotSuffix = param_vals_to_suffix(paramVals)
        print(plotSuffix)

        # squeeze=False keeps `ax` 2D even with a single mouse or interval.
        fig, ax = plt.subplots(nrows=nMice,
                               ncols=nInterv,
                               figsize=(4 * nInterv, 4 * nMice),
                               tight_layout=True,
                               squeeze=False)

        # Scalar group key: a one-element list yields 1-tuples in pandas >= 2.
        for intervName, dfInterv in dfTmp.groupby('intervName'):
            iInterv = dps.param_index('intervName', intervName)

            ax[0][iInterv].set_title(intervName, fontsize=fontsize)

            # First pass: average activity of every mouse for this interval
            rezDict = {}
            for _, row in dfInterv.iterrows():
                mousename = row['mousename']
                kwargsThis = pd_row_to_kwargs(row,
                                              parseNone=True,
                                              dropKeys=['mousename'])
                rezDict[mousename] = get_data_avg(dataDB,
                                                  mousename,
                                                  avgAxes=(0, 1),
                                                  **kwargsThis)

            # Second pass: plot each mouse relative to the across-mouse mean
            dataPsub = np.mean(list(rezDict.values()), axis=0)
            haveColorBar = iInterv == nInterv - 1
            for _, row in dfInterv.iterrows():
                mousename = row['mousename']
                iMouse = dps.param_index('mousename', mousename)
                ax[iMouse][0].set_ylabel(mousename, fontsize=fontsize)
                dataPDelta = rezDict[mousename] - dataPsub

                dataDB.plot_area_values(fig,
                                        ax[iMouse][iInterv],
                                        dataPDelta,
                                        vmin=vmin,
                                        vmax=vmax,
                                        cmap='jet',
                                        haveColorBar=haveColorBar)

        prefixPath = 'pics/activity/brainplot_mousephase/submouse/'
        make_path(prefixPath)
        fig.savefig(prefixPath + plotSuffix + '.svg')
        plt.close(fig)
Example #3
0
def plot_metric_bulk_1D(dataDB, ds, metricName, nameSuffix, prepFunc=None, xlim=None, ylim=None, yscale=None,
                        verbose=True, xFunc=None, haveTimeLabels=False):
    """Overlay a stored 1D metric for all mice, one figure per parameter set.

    Datasets in ``ds`` matching ``metric == metricName`` and
    ``name == nameSuffix`` are grouped by every column except ``mousename``
    and ``dset``. Each group becomes one line plot with one line per mouse.
    It is saved as PNG (dpi=200) under ``pics/bulk/<metricName>/``, with a
    file name built from the group's column values.

    Args:
        dataDB: project data handle; only used when ``haveTimeLabels``.
        ds: dataset storage exposing ``list_dsets_pd`` and ``get_data``.
        metricName: value of the ``metric`` column to select.
        nameSuffix: value of the ``name`` column to select; also the x label.
        prepFunc: optional transform applied to each loaded array.
        xlim, ylim: axis limits (None leaves them automatic).
        yscale: optional y-axis scale name, e.g. 'log'.
        verbose: print group keys and dataset rows while plotting.
        xFunc: optional ``xFunc(mousename, length)`` returning x coordinates;
            defaults to ``arange(length)``.
        haveTimeLabels: draw ``dataDB.label_plot_timestamps`` markers.

    Raises:
        AssertionError: if a stored dataset is not 1D.
    """
    # 1. Extract all results for this test
    dfAll = ds.list_dsets_pd().fillna('None')

    dfAnalysis = pd_query(dfAll, {'metric' : metricName, "name" : nameSuffix})
    # Move leading columns forwards for more informative printing/saving
    dfAnalysis = pd_move_cols_front(dfAnalysis, ['metric', 'name', 'mousename'])
    dfAnalysis = dfAnalysis.drop(['target_dim', 'datetime', 'shape'], axis=1)

    # Loop over all other columns except mousename. Sorted so that the group
    # order does not depend on set iteration order (string hashing is
    # randomized per interpreter run).
    colsExcl = sorted(set(dfAnalysis.columns) - {'mousename', 'dset'})

    for colVals, dfSub in dfAnalysis.groupby(colsExcl):
        fig, ax = plt.subplots(figsize=(4, 4))

        if verbose:
            print(list(colVals))

        for _, rowMouse in dfSub.sort_values(by='mousename').iterrows():
            if verbose:
                print(list(rowMouse.values))

            dataThis = ds.get_data(rowMouse['dset'])
            assert dataThis.ndim == 1, 'Only using 1D data for this plot function'

            if prepFunc is not None:
                dataThis = prepFunc(dataThis)

            if xFunc is None:
                x = np.arange(len(dataThis))
            else:
                x = np.array(xFunc(rowMouse['mousename'], len(dataThis)))
            x, dataThis = drop_nan_rows([x, dataThis])

            ax.plot(x, dataThis, label=rowMouse['mousename'])

        if yscale is not None:
            ax.set_yscale(yscale)

        if haveTimeLabels:
            dataDB.label_plot_timestamps(ax, linecolor='y', textcolor='k', shX=-0.5, shY=0.05)

        # All rows of a group share the non-mouse columns, so the last row
        # names the whole figure.
        dataName = rowMouse.drop(['dset', 'mousename'])
        dataName = '_'.join([str(el) for el in dataName])

        prefixPath = 'pics/bulk/' + metricName + '/'
        make_path(prefixPath)

        ax.legend()
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.set_xlabel(nameSuffix)
        ax.set_ylabel(metricName)
        fig.savefig(prefixPath + dataName + '.png', dpi=200)
        plt.close(fig)
Example #4
0
def activity_brainplot_mouse(dataDB,
                             xParamName,
                             exclQueryLst=None,
                             vmin=None,
                             vmax=None,
                             fontsize=20,
                             dpi=200,
                             **kwargs):
    """Brain-plot mean activity on a mice x ``xParamName`` grid.

    One figure is saved per combination of swept parameters other than
    ``mousename`` and ``xParamName``. Each cell concatenates the arrays
    returned by ``dataDB.get_neuro_data`` along axis 0 and averages over axes
    (0, 1) before plotting.

    Args:
        dataDB: project data handle.
        xParamName: sweep parameter laid out along the columns.
        exclQueryLst: parameter combinations excluded from the sweep.
        vmin, vmax: colour limits passed to ``plot_area_values``.
        fontsize: font size of the row/column labels.
        dpi: currently unused (output is SVG).
        **kwargs: sweep parameters; must contain ``xParamName``.

    Side effects:
        Writes ``<suffix>.svg`` files to ``pics/activity/brainplot_<xParamName>/``.
    """
    assert xParamName in kwargs.keys(), 'Requires ' + xParamName
    dps = DataParameterSweep(dataDB, exclQueryLst, mousename='auto', **kwargs)
    nMice = dps.param_size('mousename')
    nXParam = dps.param_size(xParamName)

    for paramVals, dfTmp in dps.sweepDF.groupby(
            dps.invert_param(['mousename', xParamName])):
        plotSuffix = param_vals_to_suffix(paramVals)
        print(plotSuffix)

        # squeeze=False keeps `ax` 2D even with a single mouse or column.
        fig, ax = plt.subplots(nrows=nMice,
                               ncols=nXParam,
                               figsize=(4 * nXParam, 4 * nMice),
                               tight_layout=True,
                               squeeze=False)

        # Scalar group key: a one-element list yields 1-tuples in pandas >= 2.
        for mousename, dfMouse in dfTmp.groupby('mousename'):
            iMouse = dps.param_index('mousename', mousename)

            ax[iMouse][0].set_ylabel(mousename, fontsize=fontsize)
            for _, row in dfMouse.iterrows():
                xParamVal = row[xParamName]
                iXParam = dps.param_index(xParamName, xParamVal)

                ax[0][iXParam].set_title(xParamVal, fontsize=fontsize)

                kwargsThis = pd_row_to_kwargs(row,
                                              parseNone=True,
                                              dropKeys=['mousename'])
                dataLst = dataDB.get_neuro_data({'mousename': mousename},
                                                **kwargsThis)
                dataRSP = np.concatenate(dataLst, axis=0)
                dataP = np.mean(dataRSP, axis=(0, 1))

                # Only the right-most column carries a colour bar
                haveColorBar = iXParam == nXParam - 1
                dataDB.plot_area_values(fig,
                                        ax[iMouse][iXParam],
                                        dataP,
                                        vmin=vmin,
                                        vmax=vmax,
                                        cmap='jet',
                                        haveColorBar=haveColorBar)

        prefixPath = 'pics/activity/brainplot_' + xParamName + '/'
        make_path(prefixPath)
        fig.savefig(prefixPath + plotSuffix + '.svg')
        plt.close(fig)
Example #5
0
def plot_consistency_significant_activity_byaction(dataDB,
                                                   ds,
                                                   minTrials=10,
                                                   performance=None,
                                                   dropChannels=None,
                                                   metric='accuracy',
                                                   limits=None):
    """Measure how consistent per-channel action significance is across mice.

    For every (datatype, intervName) group of datasets in ``ds``:

    * Each session must contain exactly two datasets.
    * ``test_metric_by_name(metric)`` is applied channel by channel (column by
      column) to the two datasets, and ``-log10`` of the result is kept.
    * These are averaged over sessions for each mouse.
    * The per-mouse profiles are compared by Pearson correlation
      (``np.corrcoef``) for every pair of mice.

    Outputs (SVG):
        * ``.../byaction/bymouse/<suffix>.svg``: seaborn pairplot of profiles.
        * ``.../byaction/bymouse_corr/<suffix>.svg``: mouse x mouse correlation.
        * ``.../byaction/consistency_<performance>.svg``: heatmap of the mean
          off-diagonal correlation per (datatype, phase).

    Args:
        dataDB: project data handle; maps sessions to mice and filters
            sessions by performance.
        ds: dataset storage exposing ``list_dsets_pd`` and ``get_data``.
        minTrials: a session is skipped if either dataset has fewer rows
            (trials, axis 0) than this.
        performance: if not None, only sessions for which
            ``dataDB.is_matching_performance`` is true are used.
        dropChannels: indices of channels (columns) to exclude, or None.
        metric: name passed to ``test_metric_by_name``.
        limits: currently unused.

    Raises:
        ValueError: if a used session does not have exactly two datasets.
    """
    testFunc = test_metric_by_name(metric)

    rows = ds.list_dsets_pd()
    rows['mousename'] = [
        dataDB.find_mouse_by_session(session) for session in rows['session']
    ]

    dfColumns = ['datatype', 'phase', 'consistency']
    dfConsistency = pd.DataFrame(columns=dfColumns)

    for (datatype,
         intervName), rowsMouse in rows.groupby(['datatype', 'intervName']):
        pSigDict = {}
        # Scalar group keys: a one-element list yields 1-tuples in pandas >= 2.
        for mousename, rowsSession in rowsMouse.groupby('mousename'):
            pSig = []
            for session, rowsTrial in rowsSession.groupby('session'):
                if (performance is not None) and not dataDB.is_matching_performance(
                        session, performance, mousename=mousename):
                    continue

                if len(rowsTrial) != 2:
                    print(mousename, session, rowsTrial)
                    raise ValueError('Expected exactly 2 rows')

                dsetLabels = list(rowsTrial['dset'])
                data1 = ds.get_data(dsetLabels[0])
                data2 = ds.get_data(dsetLabels[1])

                # Trials are along axis 0 of both datasets; axis 1 is channels
                nTrials1 = data1.shape[0]
                nTrials2 = data2.shape[0]

                if (nTrials1 < minTrials) or (nTrials2 < minTrials):
                    print(session, datatype, intervName, 'too few trials',
                          nTrials1, nTrials2, ';; skipping')
                    continue

                if dropChannels is not None:
                    channelMask = np.ones(data1.shape[1], dtype=bool)
                    channelMask[dropChannels] = False
                    data1 = data1[:, channelMask]
                    data2 = data2[:, channelMask]

                # Read from the (possibly masked) data so duplicate entries in
                # dropChannels cannot desynchronize the channel count.
                nChannels = data1.shape[1]

                pvals = [
                    testFunc(data1[:, iCh], data2[:, iCh])
                    for iCh in range(nChannels)
                ]
                pSig += [-np.log10(np.array(pvals))]

            if not pSig:
                # np.mean of an empty list is a NaN scalar, which would make
                # np.corrcoef below fail; leave this mouse out instead.
                print(mousename, datatype, intervName,
                      'no usable sessions ;; skipping')
                continue
            pSigDict[mousename] = np.mean(pSig, axis=0)

        mice = sorted(pSigDict.keys())
        nMice = len(mice)
        corrCoef = np.zeros((nMice, nMice))
        for iMouse, iName in enumerate(mice):
            for jMouse, jName in enumerate(mice):
                corrCoef[iMouse, jMouse] = np.corrcoef(pSigDict[iName],
                                                       pSigDict[jName])[0, 1]

        plotSuffix = '_'.join([datatype, str(performance), intervName])

        sns.pairplot(data=pd.DataFrame(pSigDict), vars=mice)

        prefixPath = 'pics/consistency/significant_activity/byaction/bymouse/'
        make_path(prefixPath)
        plt.savefig(prefixPath + plotSuffix + '.svg')
        plt.close()

        # The imshow helper draws the matrix itself; no separate ax2.imshow.
        fig2, ax2 = plt.subplots()
        imshow(fig2,
               ax2,
               corrCoef,
               title='Significance Correlation',
               haveColorBar=True,
               limits=[0, 1],
               xTicks=mice,
               yTicks=mice)

        prefixPath = 'pics/consistency/significant_activity/byaction/bymouse_corr/'
        make_path(prefixPath)
        fig2.savefig(prefixPath + plotSuffix + '.svg')
        plt.close(fig2)

        avgConsistency = np.round(np.mean(offdiag_1D(corrCoef)), 2)
        dfConsistency = pd_append_row(dfConsistency,
                                      [datatype, intervName, avgConsistency])

    fig, ax = plt.subplots()
    dfPivot = pd_pivot(dfConsistency, *dfColumns)
    sns.heatmap(data=dfPivot, ax=ax, annot=True, vmax=1, cmap='jet')

    prefixPath = 'pics/consistency/significant_activity/byaction/'
    make_path(prefixPath)
    fig.savefig(prefixPath + 'consistency_' + str(performance) + '.svg')
    plt.close(fig)
Example #6
0
def significance_brainplot_mousephase_byaction(
        dataDB,
        ds,
        performance=None,
        metric='accuracy',
        minTrials=10,
        limits=(0.5, 1.0),
        fontsize=20):
    """Brain-plot the per-channel test metric between two actions, mice x phases.

    For every ``datatype`` in ``ds``, a mice x intervals grid is drawn. For
    each cell, ``test_metric_by_name(metric)`` is applied channel by channel
    (column by column) to the first two datasets of every session. The results
    are averaged over sessions and painted with colour range ``limits``.

    Args:
        dataDB: project data handle; provides mice, interval names,
            session->mouse mapping and the performance filter.
        ds: dataset storage exposing ``list_dsets_pd`` and ``get_data``.
        performance: if not None, only sessions for which
            ``dataDB.is_matching_performance`` is true are used.
        metric: name passed to ``test_metric_by_name``.
        minTrials: a session is skipped if either dataset has fewer rows than
            this.
        limits: (vmin, vmax) of the colour map.
        fontsize: font size of the row/column labels.

    Side effects:
        Writes ``<datatype>_<performance>_<metric>.svg`` files to
        ``pics/significance/brainplot_mousephase/byaction/``.
    """
    testFunc = test_metric_by_name(metric)

    rows = ds.list_dsets_pd()
    rows['mousename'] = [
        dataDB.find_mouse_by_session(session) for session in rows['session']
    ]

    intervNames = dataDB.get_interval_names()
    mice = sorted(dataDB.mice)
    nInterv = len(intervNames)
    nMice = len(mice)

    # Scalar group key: with a one-element list pandas >= 2 yields a 1-tuple,
    # which would make the '_'.join below raise TypeError.
    for datatype, dfDataType in rows.groupby('datatype'):
        # squeeze=False keeps `ax` 2D even with a single mouse or interval.
        fig, ax = plt.subplots(nrows=nMice,
                               ncols=nInterv,
                               figsize=(4 * nInterv, 4 * nMice),
                               tight_layout=True,
                               squeeze=False)

        for iInterv, intervName in enumerate(intervNames):
            ax[0][iInterv].set_title(intervName, fontsize=fontsize)
            for iMouse, mousename in enumerate(mice):
                ax[iMouse][0].set_ylabel(mousename, fontsize=fontsize)

                queryDict = {'mousename': mousename, 'intervName': intervName}
                rowsSession = pd_query(dfDataType, queryDict)
                if len(rowsSession) == 0:
                    continue

                pSig = []
                for session, rowsTrial in rowsSession.groupby('session'):
                    if (performance is not None) and not dataDB.is_matching_performance(
                            session, performance, mousename=mousename):
                        continue

                    dataThis = [ds.get_data(dset) for dset in rowsTrial['dset']]

                    nChannels = dataThis[0].shape[1]
                    nTrials1 = dataThis[0].shape[0]
                    nTrials2 = dataThis[1].shape[0]

                    if (nTrials1 < minTrials) or (nTrials2 < minTrials):
                        print(session, datatype, intervName,
                              'too few trials', nTrials1, nTrials2,
                              ';; skipping')
                        continue

                    pSig += [[
                        testFunc(dataThis[0][:, iCh], dataThis[1][:, iCh])
                        for iCh in range(nChannels)
                    ]]

                print(intervName, mousename, np.array(pSig).shape)

                if not pSig:
                    # Every session was filtered out; averaging an empty list
                    # would hand a NaN scalar to the brain plot.
                    print(intervName, mousename, 'no usable sessions ;; skipping')
                    continue

                pSigAvg = np.mean(pSig, axis=0)

                dataDB.plot_area_values(fig,
                                        ax[iMouse][iInterv],
                                        pSigAvg,
                                        vmin=limits[0],
                                        vmax=limits[1],
                                        cmap='jet',
                                        haveColorBar=iInterv == nInterv - 1)

        plotSuffix = '_'.join([datatype, str(performance), metric])
        prefixPath = 'pics/significance/brainplot_mousephase/byaction/'
        make_path(prefixPath)
        fig.savefig(prefixPath + plotSuffix + '.svg')
        plt.close(fig)
Example #7
0
def activity_brainplot_mouse_2DF(dbDict,
                                 intervNameMap,
                                 intervOrdMap,
                                 trialTypes,
                                 vmin,
                                 vmax,
                                 drop6=False,
                                 dpi=200,
                                 fontsize=20):
    """Compare mean activity brain maps of the same mice across databases.

    For every datatype ('bn_trial', 'bn_session'), trial type and interval,
    one figure is drawn with one row per entry of ``dbDict`` and one column
    per mouse. The mice and interval names are taken from the first database.
    Each cell concatenates ``get_neuro_data`` along axis 0 and averages over
    axes (0, 1).

    Args:
        dbDict: mapping of database label -> data handle.
        intervNameMap: optional renaming of interval names for file names.
        intervOrdMap: optional ``(dbName, intervName) -> intervName`` override
            of the interval actually queried in each database.
        trialTypes: trial types to iterate over.
        vmin, vmax: colour limits passed to ``plot_area_values``.
        drop6: if True, skip mouse 'mou_6' for the 'REW' interval.
        dpi: currently unused (output is SVG).
        fontsize: font size of the row/column labels.

    Side effects:
        Writes SVG files to ``pics/activity/brainplot_mousephase/2df/``.
    """
    dbTmp = next(iter(dbDict.values()))
    nDB = len(dbDict)

    mice = sorted(dbTmp.mice)
    nMice = len(mice)
    intervals = dbTmp.get_interval_names()

    for datatype in ['bn_trial', 'bn_session']:
        for trialType in trialTypes:
            for intervName in intervals:
                intervLabel = intervNameMap.get(intervName, intervName)

                # One row per database (not hard-coded to 2). squeeze=False
                # keeps `ax` 2D even with a single mouse or database.
                fig, ax = plt.subplots(nrows=nDB,
                                       ncols=nMice,
                                       figsize=(4 * nMice, 4 * nDB),
                                       tight_layout=True,
                                       squeeze=False)

                for iDB, (dbName, dataDB) in enumerate(dbDict.items()):
                    ax[iDB][0].set_ylabel(dbName, fontsize=fontsize)
                    intervEffName = intervOrdMap.get((dbName, intervName),
                                                     intervName)

                    for iMouse, mousename in enumerate(mice):
                        ax[0][iMouse].set_title(mousename, fontsize=fontsize)

                        # Optionally leave the mou_6 / REW cell empty
                        if drop6 and intervEffName == 'REW' and mousename == 'mou_6':
                            continue

                        print(datatype, intervEffName, dbName, mousename,
                              drop6)
                        dataLst = dataDB.get_neuro_data(
                            {'mousename': mousename},
                            datatype=datatype,
                            intervName=intervEffName,
                            trialType=trialType)
                        dataRSP = np.concatenate(dataLst, axis=0)
                        dataP = np.mean(dataRSP, axis=(0, 1))

                        haveColorBar = iMouse == nMice - 1
                        dataDB.plot_area_values(fig,
                                                ax[iDB][iMouse],
                                                dataP,
                                                vmin=vmin,
                                                vmax=vmax,
                                                cmap='jet',
                                                haveColorBar=haveColorBar)

                prefixPath = 'pics/activity/brainplot_mousephase/2df/'
                make_path(prefixPath)
                fig.savefig(prefixPath +
                            '_'.join([datatype, trialType, intervLabel]) +
                            '.svg')
                plt.close(fig)
Example #8
0
def activity_brainplot_mousephase_subpre(dataDB,
                                         exclQueryLst=None,
                                         vmin=None,
                                         vmax=None,
                                         fontsize=20,
                                         dpi=200,
                                         **kwargs):
    """Brain-plot mean activity of each phase minus the 'PRE' phase, per mouse.

    The sweep is restricted to ``datatype='bn_session'``. For every mouse the
    'PRE' interval is averaged with ``get_data_avg(..., avgAxes=(0, 1))``,
    using the first row's remaining parameters. Every other interval is
    averaged the same way and plotted as its difference from 'PRE'.

    Args:
        dataDB: project data handle.
        exclQueryLst: parameter combinations excluded from the sweep.
        vmin, vmax: colour limits passed to ``plot_area_values``.
        fontsize: font size of the row/column labels.
        dpi: currently unused (output is SVG).
        **kwargs: sweep parameters; must contain ``intervName`` and must not
            contain ``datatype`` (it is fixed here).

    Side effects:
        Writes ``<suffix>.svg`` files to
        ``pics/activity/brainplot_mousephase/subpre/``.
    """
    assert 'intervName' in kwargs.keys(), 'Requires phases'
    dps = DataParameterSweep(dataDB,
                             exclQueryLst,
                             mousename='auto',
                             datatype=['bn_session'],
                             **kwargs)
    nMice = dps.param_size('mousename')
    nInterv = dps.param_size('intervName')

    for paramVals, dfTmp in dps.sweepDF.groupby(
            dps.invert_param(['mousename', 'intervName'])):
        plotSuffix = param_vals_to_suffix(paramVals)
        print(plotSuffix)

        # squeeze=False keeps `ax` 2D even with a single mouse or interval.
        fig, ax = plt.subplots(nrows=nMice,
                               ncols=nInterv,
                               figsize=(4 * nInterv, 4 * nMice),
                               tight_layout=True,
                               squeeze=False)

        # Scalar group key: a one-element list yields 1-tuples in pandas >= 2.
        for mousename, dfMouse in dfTmp.groupby('mousename'):
            iMouse = dps.param_index('mousename', mousename)

            ax[iMouse][0].set_ylabel(mousename, fontsize=fontsize)

            # Baseline: same parameters as this mouse's first row, but PRE
            kwargsPre = pd_row_to_kwargs(pd_first_row(dfMouse)[1],
                                         parseNone=True,
                                         dropKeys=['mousename', 'intervName'])
            kwargsPre['intervName'] = 'PRE'
            dataPPre = get_data_avg(dataDB,
                                    mousename,
                                    avgAxes=(0, 1),
                                    **kwargsPre)

            for _, row in dfMouse.iterrows():
                intervName = row['intervName']
                if intervName == 'PRE':
                    continue

                iInterv = dps.param_index('intervName', intervName)
                ax[0][iInterv].set_title(intervName, fontsize=fontsize)

                kwargsThis = pd_row_to_kwargs(row,
                                              parseNone=True,
                                              dropKeys=['mousename'])
                dataP = get_data_avg(dataDB,
                                     mousename,
                                     avgAxes=(0, 1),
                                     **kwargsThis)

                dataPDelta = dataP - dataPPre

                haveColorBar = iInterv == nInterv - 1
                dataDB.plot_area_values(fig,
                                        ax[iMouse][iInterv],
                                        dataPDelta,
                                        vmin=vmin,
                                        vmax=vmax,
                                        cmap='jet',
                                        haveColorBar=haveColorBar)

        prefixPath = 'pics/activity/brainplot_mousephase/subpre/'
        make_path(prefixPath)
        fig.savefig(prefixPath + plotSuffix + '.svg')
        plt.close(fig)
def plot_corr_mouse(dataDB,
                    mc,
                    estimator,
                    xParamName,
                    nDropPCA=None,
                    dropChannels=None,
                    haveBrain=False,
                    haveMono=True,
                    corrStrategy='mean',
                    exclQueryLst=None,
                    thrMono=0.4,
                    clusterParam=0.5,
                    fontsize=20,
                    **kwargs):
    """Plot channel-correlation matrices and clusterings on a mice x param grid.

    For every combination of swept parameters other than ``mousename`` and
    ``xParamName``, ``calc_corr_mouse`` is evaluated per cell. The results go
    into up to four grids:

    * correlation matrices (always);
    * affinity clustering of the matrix (always);
    * clusters drawn on the brain (if ``haveBrain``);
    * 1D summaries via ``_plot_corr_1D`` (if ``haveMono``).

    Args:
        dataDB: project data handle.
        mc: passed through to ``calc_corr_mouse``.
        estimator: correlation estimator passed to ``calc_corr_mouse``.
        xParamName: column parameter; must be 'intervName' or 'trialType'.
        nDropPCA, dropChannels, corrStrategy: passed to ``calc_corr_mouse``.
        haveBrain, haveMono: enable the optional grids.
        exclQueryLst: parameter combinations excluded from the sweep.
        thrMono: threshold passed to ``_plot_corr_1D``.
        clusterParam: parameter of ``cluster_dist_matrix_max``.
        fontsize: font size of the row/column labels.
        **kwargs: sweep parameters; must contain ``xParamName``.

    Side effects:
        Writes SVG files under ``pics/corr/mouse<xParamName>/dropPCA_<n>/``.
    """
    assert xParamName in ['intervName', 'trialType'], 'Unexpected parameter'
    assert xParamName in kwargs.keys(), 'Requires ' + xParamName
    dps = DataParameterSweep(dataDB, exclQueryLst, mousename='auto', **kwargs)
    nMice = dps.param_size('mousename')
    nXParam = dps.param_size(xParamName)

    def new_grid():
        # One mice x xParam figure. squeeze=False keeps the axes array 2D even
        # with a single mouse or a single parameter value.
        return plt.subplots(nrows=nMice,
                            ncols=nXParam,
                            figsize=(4 * nXParam, 4 * nMice),
                            tight_layout=True,
                            squeeze=False)

    for paramVals, dfTmp in dps.sweepDF.groupby(
            dps.invert_param(['mousename', xParamName])):
        plotSuffix = param_vals_to_suffix(paramVals)
        print(plotSuffix)

        figCorr, axCorr = new_grid()
        figClust, axClust = new_grid()
        if haveBrain:
            figBrain, axBrain = new_grid()
        if haveMono:
            figMono, axMono = new_grid()

        # Scalar group key: a one-element list yields 1-tuples in pandas >= 2.
        for mousename, dfMouse in dfTmp.groupby('mousename'):
            iMouse = dps.param_index('mousename', mousename)

            axCorr[iMouse][0].set_ylabel(mousename, fontsize=fontsize)
            axClust[iMouse][0].set_ylabel(mousename, fontsize=fontsize)
            if haveBrain:
                axBrain[iMouse][0].set_ylabel(mousename, fontsize=fontsize)
            if haveMono:
                axMono[iMouse][0].set_ylabel(mousename, fontsize=fontsize)

            for _, row in dfMouse.iterrows():
                xParamVal = row[xParamName]
                iXParam = dps.param_index(xParamName, xParamVal)

                axCorr[0][iXParam].set_title(xParamVal, fontsize=fontsize)
                axClust[0][iXParam].set_title(xParamVal, fontsize=fontsize)
                if haveBrain:
                    axBrain[0][iXParam].set_title(xParamVal, fontsize=fontsize)
                if haveMono:
                    axMono[0][iXParam].set_title(xParamVal, fontsize=fontsize)

                kwargsThis = pd_row_to_kwargs(row,
                                              parseNone=True,
                                              dropKeys=['mousename'])
                channelLabels, rez2D = calc_corr_mouse(
                    dataDB,
                    mc,
                    mousename,
                    strategy=corrStrategy,
                    nDropPCA=nDropPCA,
                    dropChannels=dropChannels,
                    estimator=estimator,
                    **kwargsThis)

                # Only the right-most column carries a colour bar
                haveColorBar = iXParam == nXParam - 1

                # Plot correlations
                imshow(figCorr,
                       axCorr[iMouse][iXParam],
                       rez2D,
                       limits=[-1, 1],
                       cmap='jet',
                       haveColorBar=haveColorBar)

                # Plot clustering
                clusters = cluster_dist_matrix_max(rez2D,
                                                   clusterParam,
                                                   method='Affinity')
                cluster_plot(figClust,
                             axClust[iMouse][iXParam],
                             rez2D,
                             clusters,
                             channelLabels,
                             limits=[-1, 1],
                             cmap='jet',
                             haveColorBar=haveColorBar)

                if haveBrain:
                    cluster_brain_plot(figBrain,
                                       axBrain[iMouse][iXParam],
                                       dataDB,
                                       clusters,
                                       dropChannels=dropChannels)

                if haveMono:
                    _plot_corr_1D(figMono, axMono[iMouse][iXParam],
                                  channelLabels, rez2D, thrMono)

        # Save image
        prefixPrefixPath = 'pics/corr/mouse' + xParamName + '/dropPCA_' + str(
            nDropPCA) + '/'

        prefixPath = prefixPrefixPath + 'corr/'
        make_path(prefixPath)
        figCorr.savefig(prefixPath + 'corr_' + plotSuffix + '.svg')
        plt.close(figCorr)

        prefixPath = prefixPrefixPath + 'clust/'
        make_path(prefixPath)
        figClust.savefig(prefixPath + 'clust_' + plotSuffix + '.svg')
        plt.close(figClust)

        if haveBrain:
            prefixPath = prefixPrefixPath + 'clust_brainplot/'
            make_path(prefixPath)
            figBrain.savefig(prefixPath + 'clust_brainplot_' + plotSuffix +
                             '.svg')
            plt.close(figBrain)

        if haveMono:
            prefixPath = prefixPrefixPath + '1D/'
            make_path(prefixPath)
            figMono.savefig(prefixPath + '1Dplot_' + plotSuffix + '.svg')
            plt.close(figMono)
Example #10
0
def scatter_metric_bulk(ds, metricName, nameSuffix, prepFunc=None, xlim=None, ylim=None, yscale=None,
                        verbose=True, xFunc=None, haveRegression=False):
    """Scatter a stored metric for all mice, one figure per parameter set.

    Datasets in ``ds`` matching ``metric == metricName`` and
    ``name == nameSuffix`` are selected. If a ``performance`` column exists,
    only rows with performance 'None' are kept. Rows are grouped by every
    column except ``mousename`` and ``dset``. Each group becomes one scatter
    plot with one point series per mouse, optionally with a pooled seaborn
    regression line. It is saved as PNG under ``pics/bulk/<metricName>/``.

    Args:
        ds: dataset storage exposing ``list_dsets_pd`` and ``get_data``.
        metricName: value of the ``metric`` column to select.
        nameSuffix: value of the ``name`` column to select.
        prepFunc: optional transform applied to each loaded array.
        xlim, ylim: axis limits (None leaves them automatic).
        yscale: optional y-axis scale name, e.g. 'log'.
        verbose: print group keys, dataset rows and array shapes.
        xFunc: optional ``xFunc(mousename, length)`` returning x coordinates;
            defaults to ``arange(length)``.
        haveRegression: overlay ``sns.regplot`` fitted on all mice pooled.
    """
    # 1. Extract all results for this test
    dfAll = ds.list_dsets_pd().fillna('None')

    dfAnalysis = pd_query(dfAll, {'metric' : metricName, "name" : nameSuffix})
    # Move leading columns forwards for more informative printing/saving
    dfAnalysis = pd_move_cols_front(dfAnalysis, ['metric', 'name', 'mousename'])
    dfAnalysis = dfAnalysis.drop(['target_dim', 'datetime', 'shape'], axis=1)

    if 'performance' in dfAnalysis.columns:
        dfAnalysis = dfAnalysis[dfAnalysis['performance'] == 'None'].drop(['performance'], axis=1)

    # Loop over all other columns except mousename. Sorted so that the group
    # order does not depend on set iteration order (string hashing is
    # randomized per interpreter run).
    colsExcl = sorted(set(dfAnalysis.columns) - {'mousename', 'dset'})

    for colVals, dfSub in dfAnalysis.groupby(colsExcl):
        fig, ax = plt.subplots()

        if verbose:
            print(list(colVals))

        xLst = []
        yLst = []
        for _, rowMouse in dfSub.sort_values(by='mousename').iterrows():
            if verbose:
                print(list(rowMouse.values))

            dataThis = ds.get_data(rowMouse['dset'])

            if prepFunc is not None:
                dataThis = prepFunc(dataThis)

            if xFunc is None:
                x = np.arange(len(dataThis))
            else:
                x = np.array(xFunc(rowMouse['mousename'], len(dataThis)))

            if verbose:
                print(dataThis.shape)
            x, dataThis = drop_nan_rows([x, dataThis])
            if verbose:
                print(dataThis.shape)

            ax.plot(x, dataThis, '.', label=rowMouse['mousename'])
            xLst += [x]
            yLst += [dataThis]

        if yscale is not None:
            ax.set_yscale(yscale)

        # All rows of a group share the non-mouse columns, so the last row
        # names the whole figure.
        dataName = rowMouse.drop(['dset', 'mousename'])
        dataName = '_'.join([str(el) for el in dataName])

        ax.legend()
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

        if haveRegression:
            sns.regplot(ax=ax, x=np.hstack(xLst), y=np.hstack(yLst), scatter=False)

        prefixPath = 'pics/bulk/' + metricName + '/'
        make_path(prefixPath)

        fig.savefig(prefixPath + dataName + '.png')
        plt.close(fig)
Example #11
0
def barplot_conditions(ds, metricName, nameSuffix, verbose=True, trialTypes=None, intervNames=None):
    '''
    Barplots of a stored metric, one set of figures per (datatype[, performance]) group.

    Sweep over datatypes
    1. (Mouse * trialType) @ {intervName=<each selected interval>}
    2. (Mouse * interv / AVG) @ {trialType='None', then each selected trial type}

    :param ds:          result storage; queried via ``list_dsets_pd()`` and ``get_data(dset)``
    :param metricName:  value of the ``metric`` column to select; also names the output folder
    :param nameSuffix:  value of the ``name`` column to select
    :param verbose:     if True, print the suffix of each sweep group
    :param trialTypes:  optional list restricting (and ordering) trial types; None = all found in the data
    :param intervNames: optional list restricting (and ordering) intervals; None = all found in the data

    Figures are written to ``pics/bulk/<metricName>/barplot_conditions/``.
    '''

    # 1. Extract all results for this test
    dfAll = ds.list_dsets_pd().fillna('None')

    dfAnalysis = pd_query(dfAll, {'metric' : metricName, "name" : nameSuffix})
    dfAnalysis = pd_move_cols_front(dfAnalysis, ['metric', 'name', 'mousename'])  # Move leading columns forwards for more informative printing/saving
    dfAnalysis = dfAnalysis.drop(['target_dim', 'datetime', 'shape'], axis=1)

    if 'performance' in dfAnalysis.columns:
        sweepLst = ['datatype', 'performance']
    else:
        sweepLst = ['datatype']

    prefixPath = 'pics/bulk/' + metricName + '/barplot_conditions/'

    for key, dfDataType in dfAnalysis.groupby(sweepLst):
        # Grouping by a list yields a tuple key for 2+ columns, and a scalar
        # (pandas < 2) or a 1-tuple (pandas >= 2) for one column. Normalize so
        # that joining and "first element" both behave the same in every case.
        keyTuple = key if isinstance(key, tuple) else (key,)
        datatype = keyTuple[0]
        plotSuffix = '_'.join(str(k) for k in keyTuple)

        if verbose:
            print(plotSuffix)

        # Sorted, so that the hue order does not depend on set iteration order
        intervNamesData = sorted(set(dfDataType['intervName']), key=str)
        trialTypesData = sorted(set(dfDataType['trialType']), key=str)

        # Fresh locals: re-assigning the arguments here would make every later
        # group be filtered by whatever the first group happened to contain.
        intervNamesThis = intervNamesData if intervNames is None else [i for i in intervNames if i in intervNamesData]
        trialTypesThis = trialTypesData if trialTypes is None else [i for i in trialTypes if i in trialTypesData]

        #################################
        # Plot 1 ::: Mouse * TrialType
        #################################

        for intervName in intervNamesThis:
            df1 = pd_query(dfDataType, {'intervName' : intervName})
            df1 = df1[df1['trialType'].isin(trialTypesThis)]

            # Long-format table: one row per stored metric value
            dfData1 = pd.DataFrame(columns=['mousename', 'trialType', metricName])
            for idx, row in df1.iterrows():
                for d in ds.get_data(row['dset']):
                    dfData1 = pd_append_row(dfData1, [row['mousename'], row['trialType'], d])

            mice = sorted(set(dfData1['mousename']))
            fig, ax = plt.subplots()
            sns_barplot(ax, dfData1, "mousename", metricName, 'trialType', annotHue=True, xOrd=mice, hOrd=trialTypesThis)

            make_path(prefixPath)
            fig.savefig(prefixPath + 'barplot_trialtype_' + plotSuffix + '_' + str(intervName) + '.png', dpi=300)
            plt.close(fig)

        #################################
        # Plot 2 ::: Mouse * Phase
        #################################

        # 'None' (no trial-type split) always comes first; do not plot it twice
        for trialType in ['None'] + [tt for tt in trialTypesThis if tt != 'None']:
            df2 = pd_query(dfDataType, {'trialType' : trialType})

            df2 = df2[df2['intervName'] != 'AVG']
            if datatype == 'bn_trial':
                # The PRE interval is excluded for bn_trial data
                df2 = df2[df2['intervName'] != 'PRE']

            dfData2 = pd.DataFrame(columns=['mousename', 'phase', metricName])
            for idx, row in df2.iterrows():
                for d in ds.get_data(row['dset']):
                    dfData2 = pd_append_row(dfData2, [row['mousename'], row['intervName'], d])

            dfData2 = dfData2.sort_values('mousename')

            mice = sorted(set(dfData2['mousename']))
            fig, ax = plt.subplots()
            sns_barplot(ax, dfData2, "mousename", metricName, 'phase', annotHue=False, xOrd=mice, hOrd=intervNamesThis)

            make_path(prefixPath)
            fig.savefig(prefixPath + 'barplot_phase_' + plotSuffix + '_' + str(trialType) + '.png', dpi=300)
            plt.close(fig)
def plot_corr_consistency_l1_phase(dataDB,
                                   nDropPCA=None,
                                   dropChannels=None,
                                   performance=None,
                                   datatype=None):
    '''
    Measure how consistent the channel-channel correlation structure is across phases.

    For every (mouse, trialType) pair, the lower-triangular channel correlation
    vector is computed in every phase. These vectors are then correlated between
    phases. The function saves a KDE pairplot of the vectors and the
    phase-by-phase correlation matrix. The rounded mean off-diagonal value of
    that matrix is stored as "consistency", and all consistencies are summarized
    in a final mouse x trialType heatmap.

    :param dataDB:       data source (``mice``, ``get_interval_names``,
                         ``get_trial_type_names``, ``get_neuro_data``)
    :param nDropPCA:     if not None, passed to ``drop_PCA`` before correlating;
                         also part of the output path
    :param dropChannels: optional channel indices removed before correlating
    :param performance:  if not None, forwarded to ``get_neuro_data``
    :param datatype:     forwarded to ``get_neuro_data``; also used in output filenames
    '''
    mice = sorted(dataDB.mice)
    phases = dataDB.get_interval_names()
    nPhases = len(phases)
    trialTypes = dataDB.get_trial_type_names()

    dfColumns = ['mousename', 'trialtype', 'consistency']
    dfConsistency = pd.DataFrame(columns=dfColumns)

    prefixPathBase = 'pics/consistency/corr/phase/dropPCA_' + str(nDropPCA) + '/'

    for mousename in mice:
        for trialType in trialTypes:
            # str() so that the default datatype=None does not break the join
            fnameSuffix = '_'.join([str(datatype), mousename, trialType, str(performance)])
            print(fnameSuffix)

            corrLst = []
            for intervName in phases:
                kwargs = {
                    'datatype': datatype,
                    'intervName': intervName,
                    'trialType': trialType
                }
                if performance is not None:
                    kwargs['performance'] = performance

                dataRSPLst = dataDB.get_neuro_data({'mousename': mousename},
                                                   zscoreDim='rs',
                                                   **kwargs)

                # Stack all returned arrays along axis 0, then average over axis 1.
                # The remaining last axis indexes channels (see the mask below).
                dataRSP = np.concatenate(dataRSPLst, axis=0)
                dataRP = np.mean(dataRSP, axis=1)

                if dropChannels is not None:
                    channelMask = np.ones(dataRP.shape[1], dtype=bool)
                    channelMask[dropChannels] = False
                    dataRP = dataRP[:, channelMask]

                if nDropPCA is not None:
                    dataRP = drop_PCA(dataRP, nDropPCA)

                # Channel x channel correlation, flattened to its lower triangle
                corrLst.append(tril_1D(np.corrcoef(dataRP.T)))

            pairDict = dict(zip(phases, corrLst))
            rezMat = np.zeros((nPhases, nPhases))
            for i, corrI in enumerate(corrLst):
                for j, corrJ in enumerate(corrLst):
                    rezMat[i, j] = np.corrcoef(corrI, corrJ)[0, 1]

            sns.pairplot(data=pd.DataFrame(pairDict), vars=phases, kind='kde')

            prefixPath = prefixPathBase + 'scatter/'
            make_path(prefixPath)
            plt.savefig(prefixPath + 'scatter_' + fnameSuffix + '.svg')
            plt.close()

            fig, ax = plt.subplots()
            imshow(fig,
                   ax,
                   rezMat,
                   haveColorBar=True,
                   limits=[0, 1],
                   xTicks=phases,
                   yTicks=phases,
                   cmap='jet')

            prefixPath = prefixPathBase + 'metric/'
            make_path(prefixPath)
            fig.savefig(prefixPath + 'metric_' + fnameSuffix + '.svg')
            plt.close(fig)

            avgConsistency = np.round(np.mean(offdiag_1D(rezMat)), 2)
            dfConsistency = pd_append_row(dfConsistency, [mousename, trialType, avgConsistency])

    fig, ax = plt.subplots()
    dfPivot = pd_pivot(dfConsistency, *dfColumns)
    sns.heatmap(data=dfPivot, ax=ax, annot=True, vmin=0, vmax=1, cmap='jet')

    make_path(prefixPathBase)
    fig.savefig(prefixPathBase + str(datatype) + '_' + str(performance) + '.svg')
    plt.close(fig)
def plot_corr_consistency_l1_mouse(
        dataDB,
        nDropPCA=None,
        dropChannels=None,
        exclQueryLst=None,
        **kwargs):  # performances=None, trialTypes=None,
    '''
    Measure how consistent the channel-channel correlation structure is across mice.

    For every (datatype, intervName, ...) parameter combination, the
    lower-triangular channel correlation vector of each mouse is computed.
    These vectors are then correlated between mice. The function saves a KDE
    pairplot of the vectors and the mouse-by-mouse correlation matrix. The
    rounded mean off-diagonal value is collected as "consistency" and
    summarized in a datatype x phase heatmap, one per combination of the
    remaining sweep parameters.

    :param dataDB:       data source (``get_neuro_data``)
    :param nDropPCA:     if not None, passed to ``drop_PCA`` before correlating;
                         also part of the output path
    :param dropChannels: optional channel indices removed before correlating
    :param exclQueryLst: forwarded to ``DataParameterSweep``
    :param kwargs:       sweep parameters for ``DataParameterSweep``. Must contain
                         ``intervName``; ``datatype`` defaults to 'auto'.
    '''
    assert 'intervName' in kwargs.keys(), 'Requires phases'

    # Caller-supplied values override the 'auto' defaults. Passing
    # intervName='auto' explicitly next to **kwargs (which must contain
    # intervName) raised "got multiple values for keyword argument".
    sweepKWArgs = {'intervName': 'auto', 'datatype': 'auto', **kwargs}
    dps = DataParameterSweep(dataDB,
                             exclQueryLst,
                             mousename='auto',
                             **sweepKWArgs)

    prefixPathBase = 'pics/consistency/corr/mouse/dropPCA_' + str(nDropPCA) + '/'

    for paramExtraVals, dfTmp in dps.sweepDF.groupby(
            dps.invert_param(['mousename', 'datatype', 'intervName'])):
        plotExtraSuffix = param_vals_to_suffix(paramExtraVals)

        dfColumns = ['datatype', 'phase', 'consistency']
        dfConsistency = pd.DataFrame(columns=dfColumns)

        for paramVals, dfMouse in dfTmp.groupby(dps.invert_param(['mousename'])):
            # Group-level suffix; used for the saved files of this group
            plotSuffix = param_vals_to_suffix(paramVals)
            print(plotSuffix)

            # Keyed by mouse, so a result cannot be assigned to the wrong mouse
            # when rows are unsorted or some mice are excluded for this group
            corrDict = {}
            for idx, row in dfMouse.iterrows():
                rowSuffix = '_'.join([str(s) for s in row.values])
                print(rowSuffix)

                kwargsThis = pd_row_to_kwargs(row,
                                              parseNone=True,
                                              dropKeys=['mousename'])
                dataRSPLst = dataDB.get_neuro_data(
                    {
                        'mousename': row['mousename']
                    },  # NOTE: zscore channels for each session to avoid session-wise effects
                    zscoreDim='rs',
                    **kwargsThis)

                # Stack all returned arrays along axis 0, then average over axis 1.
                # The remaining last axis indexes channels (see the mask below).
                dataRSP = np.concatenate(dataRSPLst, axis=0)
                dataRP = np.mean(dataRSP, axis=1)

                if dropChannels is not None:
                    channelMask = np.ones(dataRP.shape[1], dtype=bool)
                    channelMask[dropChannels] = False
                    dataRP = dataRP[:, channelMask]

                if nDropPCA is not None:
                    dataRP = drop_PCA(dataRP, nDropPCA)

                # Channel x channel correlation, flattened to its lower triangle
                corrDict[row['mousename']] = tril_1D(np.corrcoef(dataRP.T))

            mice = sorted(corrDict)
            nMice = len(mice)
            rezMat = np.zeros((nMice, nMice))
            for iMouse, iName in enumerate(mice):
                for jMouse, jName in enumerate(mice):
                    rezMat[iMouse, jMouse] = np.corrcoef(corrDict[iName], corrDict[jName])[0, 1]

            sns.pairplot(data=pd.DataFrame({m: corrDict[m] for m in mice}),
                         vars=mice,
                         kind='kde')

            prefixPath = prefixPathBase + 'scatter/'
            make_path(prefixPath)
            plt.savefig(prefixPath + 'scatter_' + plotSuffix + '.svg')
            plt.close()

            fig, ax = plt.subplots()
            imshow(fig,
                   ax,
                   rezMat,
                   haveColorBar=True,
                   limits=[0, 1],
                   xTicks=mice,
                   yTicks=mice,
                   cmap='jet')

            prefixPath = prefixPathBase + 'metric/'
            make_path(prefixPath)
            fig.savefig(prefixPath + 'metric_' + plotSuffix + '.svg')
            plt.close(fig)

            # datatype and intervName are constant within this group
            avgConsistency = np.round(np.mean(offdiag_1D(rezMat)), 2)
            dfConsistency = pd_append_row(
                dfConsistency,
                [dfMouse['datatype'].iloc[0], dfMouse['intervName'].iloc[0], avgConsistency])

        fig, ax = plt.subplots()
        dfPivot = pd_pivot(dfConsistency, *dfColumns)
        sns.heatmap(data=dfPivot,
                    ax=ax,
                    annot=True,
                    vmin=0,
                    vmax=1,
                    cmap='jet')

        make_path(prefixPathBase)
        fig.savefig(prefixPathBase + plotExtraSuffix + '.svg')
        plt.close(fig)
def plot_corr_mouse_2DF(dfDict,
                        mc,
                        estimator,
                        intervNameMap,
                        intervOrdMap,
                        corrStrategy='mean',
                        nDropPCA=None,
                        dropChannels=None,
                        exclQueryLst=None):
    '''
    Plot per-mouse channel correlation matrices for several datasets side by side.

    One figure is saved per (trialType, intervName) pair; mice, intervals and
    trial types are taken from the first dataset in ``dfDict``. In each figure,
    rows are datasets and columns are mice. ``datatype='bn_session'`` is always used.

    :param dfDict:        mapping from dataset label to data source
    :param mc:            forwarded to ``calc_corr_mouse``
    :param estimator:     forwarded to ``calc_corr_mouse``
    :param intervNameMap: optional relabeling intervName -> label used in the filename
    :param intervOrdMap:  optional mapping (datasetLabel, intervName) -> interval name
                          actually queried in that dataset
    :param corrStrategy:  forwarded to ``calc_corr_mouse`` as ``strategy``
    :param nDropPCA:      forwarded to ``calc_corr_mouse``; also part of the output path
    :param dropChannels:  forwarded to ``calc_corr_mouse``
    :param exclQueryLst:  list of query dicts; a (mouse, interval, trialType, datatype)
                          combination for which ``subset_dict(excl, query)`` holds for
                          any entry is left blank. None means nothing is excluded.
    '''
    # Iterating over the None default used to raise a TypeError
    if exclQueryLst is None:
        exclQueryLst = []

    dataDBTmp = next(iter(dfDict.values()))

    mice = sorted(dataDBTmp.mice)
    nMice = len(mice)
    nDB = len(dfDict)
    intervNames = dataDBTmp.get_interval_names()
    trialTypes = dataDBTmp.get_trial_type_names()

    prefixPath = 'pics/corr/bystim/dropPCA_' + str(nDropPCA) + '/'

    for trialType in trialTypes:
        for intervName in intervNames:
            intervLabel = intervNameMap.get(intervName, intervName)
            plotSuffix = trialType + '_' + intervLabel
            print(plotSuffix)

            # One row per dataset (previously hard-coded to 2). squeeze=False
            # keeps ax 2D even with a single mouse or a single dataset.
            fig, ax = plt.subplots(nrows=nDB,
                                   ncols=nMice,
                                   figsize=(4 * nMice, 4 * nDB),
                                   tight_layout=True,
                                   squeeze=False)

            for iDB, (dbName, dataDB) in enumerate(dfDict.items()):
                ax[iDB][0].set_ylabel(dbName)

                intervEffName = intervOrdMap.get((dbName, intervName), intervName)

                for iMouse, mousename in enumerate(mice):
                    ax[0][iMouse].set_title(mousename)

                    query = {
                        'mousename': mousename,
                        'intervName': intervEffName,
                        'trialType': trialType,
                        'datatype': 'bn_session'
                    }
                    if any(subset_dict(excl, query) for excl in exclQueryLst):
                        continue

                    # mousename is passed positionally; the string 'None' means "no filter"
                    kwargs = {
                        k: (None if v == 'None' else v)
                        for k, v in query.items() if k != 'mousename'
                    }

                    channelLabels, rez2D = calc_corr_mouse(
                        dataDB,
                        mc,
                        mousename,
                        strategy=corrStrategy,
                        nDropPCA=nDropPCA,
                        dropChannels=dropChannels,
                        estimator=estimator,
                        **kwargs)

                    imshow(fig,
                           ax[iDB][iMouse],
                           rez2D,
                           limits=[-1, 1],
                           cmap='jet',
                           haveColorBar=iMouse == nMice - 1)

            # Save image
            make_path(prefixPath)
            fig.savefig(prefixPath + 'corr_bn_session_' + plotSuffix + '.svg')
            plt.close(fig)
def plot_corr_mousephase_submouse(dataDB,
                                  mc,
                                  estimator,
                                  nDropPCA=None,
                                  dropChannels=None,
                                  exclQueryLst=None,
                                  corrStrategy='mean',
                                  fontsize=20,
                                  **kwargs):
    '''
    Plot per-mouse correlation matrices with the across-mouse mean subtracted.

    For each combination of the non-(mouse, interval) sweep parameters, one
    (mouse x interval) grid is saved. Each cell shows that mouse's correlation
    matrix (from ``calc_corr_mouse``) minus the mean of the matrices of all
    mice for the same interval.

    :param dataDB:       data source, forwarded to ``DataParameterSweep`` and ``calc_corr_mouse``
    :param mc:           forwarded to ``calc_corr_mouse``
    :param estimator:    forwarded to ``calc_corr_mouse``
    :param nDropPCA:     forwarded to ``calc_corr_mouse``; also part of the output path
    :param dropChannels: forwarded to ``calc_corr_mouse``
    :param exclQueryLst: forwarded to ``DataParameterSweep``
    :param corrStrategy: forwarded to ``calc_corr_mouse`` as ``strategy``
    :param fontsize:     font size of row/column labels
    :param kwargs:       sweep parameters for ``DataParameterSweep``; must contain ``intervName``
    '''
    assert 'intervName' in kwargs.keys(), 'Requires phases'
    dps = DataParameterSweep(dataDB, exclQueryLst, mousename='auto', **kwargs)
    nMice = dps.param_size('mousename')
    nInterv = dps.param_size('intervName')

    prefixPath = 'pics/corr/mousephase/dropPCA_' + str(nDropPCA) + '/submouse/'

    for paramVals, dfTmp in dps.sweepDF.groupby(
            dps.invert_param(['mousename', 'intervName'])):
        plotSuffix = param_vals_to_suffix(paramVals)
        print(plotSuffix)

        # squeeze=False keeps axCorr 2D even for a single mouse or interval
        figCorr, axCorr = plt.subplots(nrows=nMice,
                                       ncols=nInterv,
                                       figsize=(4 * nInterv, 4 * nMice),
                                       squeeze=False)

        # Group by the scalar column name: with a 1-element list, pandas >= 2
        # yields 1-tuple keys, which param_index/set_title do not expect.
        for intervName, dfInterv in dfTmp.groupby('intervName'):
            iInterv = dps.param_index('intervName', intervName)
            axCorr[0][iInterv].set_title(intervName, fontsize=fontsize)

            rezDict = {}
            for idx, row in dfInterv.iterrows():
                mousename = row['mousename']

                kwargsThis = pd_row_to_kwargs(row,
                                              parseNone=True,
                                              dropKeys=['mousename'])
                channelLabels, rez2D = calc_corr_mouse(
                    dataDB,
                    mc,
                    mousename,
                    strategy=corrStrategy,
                    nDropPCA=nDropPCA,
                    dropChannels=dropChannels,
                    estimator=estimator,
                    **kwargsThis)

                rezDict[mousename] = rez2D

            # Plot each mouse's deviation from the mean over mice
            rezMean = np.mean(list(rezDict.values()), axis=0)
            haveColorBar = iInterv == nInterv - 1

            for mousename, rezMouse in rezDict.items():
                iMouse = dps.param_index('mousename', mousename)
                axCorr[iMouse][0].set_ylabel(mousename, fontsize=fontsize)
                imshow(figCorr,
                       axCorr[iMouse][iInterv],
                       rezMouse - rezMean,
                       title='corr',
                       haveColorBar=haveColorBar,
                       limits=[-1, 1],
                       cmap='RdBu_r')

        # Save image
        make_path(prefixPath)
        figCorr.savefig(prefixPath + 'corr_' + plotSuffix + '.svg')
        plt.close(figCorr)
Example #16
0
def plot_consistency_significant_activity_byphase(dataDB,
                                                  ds,
                                                  intervals,
                                                  minTrials=10,
                                                  performance=None,
                                                  dropChannels=None):
    '''
    Measure across-mouse consistency of channels that differ between two intervals.

    For each session, the two intervals in ``intervals`` are compared channel by
    channel with ``wilcoxon`` (two-sided). Per mouse, the -log10(p) vectors are
    averaged over sessions, and these per-mouse vectors are correlated between
    mice. For every (datatype, trialType), the function saves a pairplot and the
    mouse-by-mouse correlation matrix. The rounded mean off-diagonal correlation
    ("consistency") is summarized in a final datatype x trialType heatmap.

    :param dataDB:       data source (``mice``, ``find_mouse_by_session``,
                         ``is_matching_performance``)
    :param ds:           result storage; ``list_dsets_pd()`` rows need the columns
                         session, datatype, trialType, intervName and dset
    :param intervals:    pair of interval names to compare
    :param minTrials:    sessions with fewer trials in either interval are skipped
    :param performance:  if not None, only sessions matching this performance are used
    :param dropChannels: optional channel indices excluded before testing
    '''
    rows = ds.list_dsets_pd()
    rows['mousename'] = [
        dataDB.find_mouse_by_session(session) for session in rows['session']
    ]

    dfColumns = ['datatype', 'trialType', 'consistency']
    dfConsistency = pd.DataFrame(columns=dfColumns)

    for (datatype,
         trialType), rowsMouse in rows.groupby(['datatype', 'trialType']):
        pSigDict = {}
        # Scalar column names: a 1-element list yields 1-tuple keys in pandas >= 2
        for mousename, rowsSession in rowsMouse.groupby('mousename'):
            pSig = []
            for session, rowsTrial in rowsSession.groupby('session'):
                if (performance is not None) and not dataDB.is_matching_performance(
                        session, performance, mousename=mousename):
                    continue

                assert intervals[0] in list(rowsTrial['intervName'])
                assert intervals[1] in list(rowsTrial['intervName'])
                dsetLabel1 = pd_is_one_row(
                    pd_query(rowsTrial, {'intervName': intervals[0]}))[1]['dset']
                dsetLabel2 = pd_is_one_row(
                    pd_query(rowsTrial, {'intervName': intervals[1]}))[1]['dset']
                data1 = ds.get_data(dsetLabel1)
                data2 = ds.get_data(dsetLabel2)

                # Axis 0 indexes trials and axis 1 channels (see nChannels below).
                # nTrials2 previously read data2.shape[1], i.e. the channel count.
                nTrials1 = data1.shape[0]
                nTrials2 = data2.shape[0]

                if (nTrials1 < minTrials) or (nTrials2 < minTrials):
                    print(session, datatype, trialType, 'too few trials',
                          nTrials1, nTrials2, ';; skipping')
                    continue

                if dropChannels is not None:
                    channelMask = np.ones(data1.shape[1], dtype=bool)
                    channelMask[dropChannels] = False
                    data1 = data1[:, channelMask]
                    data2 = data2[:, channelMask]
                nChannels = data1.shape[1]

                # Paired test per channel; requires equal trial counts in both intervals
                pvals = [
                    wilcoxon(data1[:, iCh],
                             data2[:, iCh],
                             alternative='two-sided')[1]
                    for iCh in range(nChannels)
                ]
                pSig.append(-np.log10(np.array(pvals)))

            # Mice without any usable session are left out, instead of
            # producing a NaN entry that breaks the correlations below
            if pSig:
                pSigDict[mousename] = np.mean(pSig, axis=0)

        if not pSigDict:
            print(datatype, trialType, 'no usable sessions ;; skipping')
            continue

        mice = sorted(pSigDict)
        nMice = len(mice)
        corrCoef = np.zeros((nMice, nMice))
        for iMouse, iName in enumerate(mice):
            for jMouse, jName in enumerate(mice):
                corrCoef[iMouse, jMouse] = np.corrcoef(pSigDict[iName],
                                                       pSigDict[jName])[0, 1]

        sns.pairplot(data=pd.DataFrame(pSigDict), vars=mice)

        prefixPath = 'pics/consistency/significant_activity/byphase/bymouse/'
        make_path(prefixPath)
        plt.savefig(prefixPath + datatype + '_' + trialType + '.svg')
        plt.close()

        fig2, ax2 = plt.subplots()
        imshow(fig2,
               ax2,
               corrCoef,
               title='Significance Correlation',
               haveColorBar=True,
               limits=[0, 1],
               xTicks=mice,
               yTicks=mice)

        prefixPath = 'pics/consistency/significant_activity/byphase/bymouse_corr/'
        make_path(prefixPath)
        fig2.savefig(prefixPath + datatype + '_' + trialType + '.svg')
        plt.close(fig2)

        avgConsistency = np.round(np.mean(offdiag_1D(corrCoef)), 2)
        dfConsistency = pd_append_row(dfConsistency,
                                      [datatype, trialType, avgConsistency])

    fig, ax = plt.subplots()
    dfPivot = pd_pivot(dfConsistency, *dfColumns)
    sns.heatmap(data=dfPivot, ax=ax, annot=True, vmax=1, cmap='jet')

    prefixPath = 'pics/consistency/significant_activity/byphase/'
    make_path(prefixPath)
    fig.savefig(prefixPath + 'consistency_' + str(performance) + '.svg')
    plt.close(fig)
def plot_pca1_session(dataDB,
                      mousename,
                      session,
                      trialTypesSelected=('Hit', 'CR')):
    '''
    Plot the first principal component of one session for two datatypes.

    Rows: 'bn_trial' and 'bn_session'. Columns:
      0. 1st PCA projection over all (trial, timestep) samples of the session
      1. per-trial-type average of the channel-averaged activity
      2. 1st PCA projection of the per-trial-type average

    The sign of the component is chosen so that the mean loading is
    non-negative. The same sign is applied to columns 0 and 2, so they are
    directly comparable.

    :param dataDB:             data source (``get_absolute_times``, ``get_neuro_data``,
                               ``get_trial_types``, ``get_times``, plot labelling helpers)
    :param mousename:          mouse the session belongs to
    :param session:            session name
    :param trialTypesSelected: trial types to average; at most 4, one colour each
    '''
    plotColors = pylab.cm.gist_heat([0.2, 0.4, 0.6, 0.8])
    plotColorMap = dict(
        zip(trialTypesSelected, plotColors[:len(trialTypesSelected)]))

    # Session-level quantities that do not depend on the datatype
    timesRS = dataDB.get_absolute_times(mousename, session)
    timesS = numpy_merge_dimensions(timesRS, 0, 2)
    trialTypes = dataDB.get_trial_types(session, mousename)

    datatypes = ['bn_trial', 'bn_session']
    fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(12, 8))

    for iDataType, datatype in enumerate(datatypes):
        dataRSP = dataDB.get_neuro_data({'session': session},
                                        datatype=datatype)[0]

        # Train PCA on whole session, but only trial based timesteps
        dataSP = numpy_merge_dimensions(dataRSP, 0, 2)
        pca = PCA(n_components=1)
        dataPCA1Raw = pca.fit_transform(dataSP)[:, 0]
        pcaSig = np.sign(np.mean(pca.components_))
        # Same sign convention for the session trace as for the trial averages
        dataPCA1 = dataPCA1Raw * pcaSig

        def pcaTransform(x):
            # Project onto the 1st component, sign-normalized as above
            return pca.transform(x)[:, 0] * pcaSig

        # Compute 1st PCA during trial-time
        # Note: it is irrelevant whether averaging or PCA-transform comes first
        timesTrial = dataDB.get_times(dataRSP.shape[1])

        for tt in trialTypesSelected:
            trialIdxs = trialTypes == tt
            if np.sum(trialIdxs) > 0:
                dataAvgTTSP = np.mean(dataRSP[trialIdxs], axis=0)
                dataAvgTTPCA = pcaTransform(dataAvgTTSP)
                dataAvgTTAvg = np.mean(dataAvgTTSP, axis=1)

                ax[iDataType, 1].plot(timesTrial,
                                      dataAvgTTAvg,
                                      label=tt,
                                      color=plotColorMap[tt])
                ax[iDataType, 2].plot(timesTrial,
                                      dataAvgTTPCA,
                                      label=tt,
                                      color=plotColorMap[tt])

        ax[iDataType, 0].set_ylabel(datatype)
        ax[iDataType, 0].plot(timesS, dataPCA1)
        ax[iDataType, 0].set_title('1st PCA during session')

        ax[iDataType, 1].set_title('Trial-average activity')
        ax[iDataType, 1].legend()

        ax[iDataType, 2].set_title('1st PCA trial-average')
        ax[iDataType, 2].legend()

        for iCol in (1, 2):
            dataDB.label_plot_timestamps(ax[iDataType, iCol], mousename, session)
        for iCol in (1, 2):
            dataDB.label_plot_intervals(ax[iDataType, iCol], mousename, session)

    prefixPath = 'pics/bulk/traces/bymouse/'
    # The filename keeps its historical tag: the last datatype ('bn_session'),
    # even though the figure contains both datatypes
    suffixPath = '_'.join([datatypes[-1], mousename, session])
    make_path(prefixPath)
    fig.savefig(prefixPath + 'traces_' + suffixPath + '.svg')
    plt.close(fig)
def movie_mouse_trialtype(dataDB,
                          dataKWArgs,
                          calcKWArgs,
                          plotKWArgs,
                          calc_func,
                          plot_func,
                          prefixPath='',
                          exclQueryLst=None,
                          haveDelay=False,
                          fontsize=20,
                          tTrgDelay=2.0,
                          tTrgRew=2.0):
    '''
    Render one PNG frame per timestep as a (mouse x trialType) grid of plots.

    For every combination of the remaining sweep parameters, ``calc_func`` is
    evaluated once per (mouse, trialType). Each result is indexed along axis 0
    by frame, and each slice is drawn by ``plot_func``. A timescale bar is
    added to every frame. Frames whose output file already exists are skipped.

    :param dataDB:       data source (``get_timestamps``, ``targetFPS``)
    :param dataKWArgs:   sweep parameters for ``DataParameterSweep``; must contain
                         ``trialType`` and must not contain ``intervName``
    :param calcKWArgs:   forwarded positionally to ``calc_func``
    :param plotKWArgs:   forwarded as keyword arguments to ``plot_func``
    :param calc_func:    called as ``calc_func(dataDB, mousename, calcKWArgs,
                         haveDelay=..., tTrgDelay=..., tTrgRew=..., **rowKWArgs)``;
                         all results must have the same shape
    :param plot_func:    called as ``plot_func(dataDB, fig, ax, dataP,
                         haveColorBar=..., **plotKWArgs)``
    :param prefixPath:   output prefix; frames go to ``<prefixPath><suffix>_<iTime>.png``
    :param exclQueryLst: forwarded to ``DataParameterSweep``
    :param haveDelay:    forwarded to ``calc_func``
    :param fontsize:     font size of row/column labels
    :param tTrgDelay:    forwarded to ``calc_func``; also sets the 'reward' mark on
                         the timescale bar at ``timestamps['delay'] + tTrgDelay``
    :param tTrgRew:      forwarded to ``calc_func``
    :return: ``prefixPath``
    '''
    assert 'trialType' in dataKWArgs.keys(), 'Requires trial types'
    assert 'intervName' not in dataKWArgs.keys(
    ), 'Movie intended for full range'
    dps = DataParameterSweep(dataDB,
                             exclQueryLst,
                             mousename='auto',
                             **dataKWArgs)
    nMice = dps.param_size('mousename')
    nTrialType = dps.param_size('trialType')
    mouseNames = list(dps.param('mousename'))
    trialTypeNames = list(dps.param('trialType'))

    for paramVals, dfTmp in dps.sweepDF.groupby(
            dps.invert_param(['mousename', 'trialType'])):
        plotSuffix = param_vals_to_suffix(paramVals)

        # Store all preprocessed data first.
        # Scalar column name: a 1-element list yields 1-tuple keys in pandas >= 2,
        # which would not match the string lookups in the frame loop.
        dataDict = {}
        for mousename, dfMouse in dfTmp.groupby('mousename'):
            for idx, row in dfMouse.iterrows():
                trialType = row['trialType']
                print('Reading data, ', plotSuffix, mousename, trialType)

                kwargsThis = pd_row_to_kwargs(row,
                                              parseNone=True,
                                              dropKeys=['mousename'])

                dataDict[(mousename, trialType)] = calc_func(dataDB,
                                                             mousename,
                                                             calcKWArgs,
                                                             haveDelay=haveDelay,
                                                             tTrgDelay=tTrgDelay,
                                                             tTrgRew=tTrgRew,
                                                             **kwargsThis)

        # Test that all datasets have the same duration
        shapeSet = {v.shape for v in dataDict.values()}
        assert len(shapeSet) == 1
        nTimes = shapeSet.pop()[0]

        # Timescale bar labels/positions are identical for every frame; compute once.
        # Timestamps come from the last mouse of the sweep (previously via a
        # leaked loop variable).
        timestamps = dataDB.get_timestamps(mouseNames[-1], session=None)
        tEnd = nTimes / dataDB.targetFPS
        if 'delay' not in timestamps.keys():
            tsKeys = ['PRE'] + list(timestamps.keys())
            tsVals = list(timestamps.values()) + [tEnd]
        else:
            tsKeys = ['PRE'] + list(timestamps.keys()) + ['reward']
            tsVals = list(timestamps.values()) + [
                timestamps['delay'] + tTrgDelay, tEnd
            ]

        make_path(prefixPath)

        progBar = IntProgress(min=0, max=nTimes, description=plotSuffix)
        display(progBar)  # display the bar
        for iTime in range(nTimes):
            outfname = prefixPath + plotSuffix + '_' + str(iTime) + '.png'

            # Resume support: frames rendered in an earlier run are kept
            if os.path.isfile(outfname):
                print('Already calculated', iTime, 'skipping')
                progBar.value += 1
                continue

            # squeeze=False keeps ax 2D even for a single mouse or trial type
            fig, ax = plt.subplots(nrows=nMice,
                                   ncols=nTrialType,
                                   figsize=(4 * nTrialType, 4 * nMice),
                                   tight_layout=True,
                                   squeeze=False)

            for iMouse, mousename in enumerate(mouseNames):
                ax[iMouse][0].set_ylabel(mousename, fontsize=fontsize)
                for iTT, trialType in enumerate(trialTypeNames):
                    ax[0][iTT].set_title(trialType, fontsize=fontsize)

                    dataP = dataDict[(mousename, trialType)][iTime]

                    rightMost = iTT == nTrialType - 1
                    plot_func(dataDB,
                              fig,
                              ax[iMouse][iTT],
                              dataP,
                              haveColorBar=rightMost,
                              **plotKWArgs)

            # Add a timescale bar to the figure
            tNow = iTime / dataDB.targetFPS
            print(tsVals, tNow)
            add_timescale_bar(fig, tsKeys, tsVals, tNow)

            fig.savefig(outfname, bbox_inches='tight')
            plt.cla()
            plt.clf()
            plt.close('all')
            progBar.value += 1
    return prefixPath