Ejemplo n.º 1
0
def corr_evaluation(dataDB, mc, estimator, exclQueryLst=None, minTrials=50, **kwargs):
    """Compute channel-pair correlations and their p-values for every sweep row.

    For each row of the parameter sweep (mice are swept automatically), every
    session of that row's mouse is loaded. Sessions with fewer than
    ``minTrials`` trials are skipped. For the remaining sessions the
    time-averaged ``'corr'`` metric of ``mc`` is computed. Both the correlation
    and the p-value matrices are flattened with ``tril_1D``, and the results of
    all sessions are concatenated.

    Parameters
    ----------
    dataDB : object
        Data source; ``get_sessions`` and ``get_neuro_data`` are used.
    mc : object
        Metric calculator; ``set_data`` and ``metric3D`` are used.
    estimator
        Forwarded to ``mc.metric3D`` inside ``metricSettings``.
    exclQueryLst : list, optional
        Forwarded to ``DataParameterSweep`` to exclude parameter combinations.
    minTrials : int
        Sessions with fewer trials than this are skipped.
    **kwargs
        Sweep parameters forwarded to ``DataParameterSweep``.

    Returns
    -------
    dict
        ``{'corr': {key: array}, 'pval': {key: array}}``. Each ``key`` is the
        mouse name followed by the row's other sweep values, joined with
        ``'_'``. Rows without any usable session are omitted.
    """
    resultsDict = {'corr': {}, 'pval': {}}

    dps = DataParameterSweep(dataDB, exclQueryLst, mousename='auto', **kwargs)

    for idx, row in dps.sweepDF.iterrows():
        mousename = row['mousename']
        kwargsThis = pd_row_to_kwargs(row, parseNone=True, dropKeys=['mousename'])

        results = []
        for session in dataDB.get_sessions(mousename):
            dataRSP = dataDB.get_neuro_data({'session': session}, **kwargsThis)[0]

            nTrials, nTime, nChannel = dataRSP.shape

            if nTrials < minTrials:
                print('Too few trials =', nTrials, ' for', session, kwargsThis, ': skipping')
                continue

            mc.set_data(dataRSP, 'rsp')
            metricSettings = {'timeAvg': True, 'havePVal': True, 'estimator': estimator}
            rez2D = mc.metric3D('corr', '', metricSettings=metricSettings)
            # Last axis: index 0 is stored as 'corr', index 1 as 'pval'
            results.append(np.array([tril_1D(rez2D[..., 0]), tril_1D(rez2D[..., 1])]))

        if results:
            # Build the key from this row's own values. The raw ``kwargs`` are the
            # same for every row (and may hold lists or 'auto'), so keying on them
            # let rows of the same mouse overwrite each other or broke str.join.
            otherVals = [str(row[k]) for k in row.index if k != 'mousename']
            dictKey = '_'.join([str(mousename)] + otherVals)
            results = np.hstack(results)
            resultsDict['corr'][dictKey] = results[0]
            resultsDict['pval'][dictKey] = results[1]
    return resultsDict
Ejemplo n.º 2
0
def compute_mean_interval(dataDB,
                          ds,
                          trialTypeTrg,
                          skipExisting=False,
                          exclQueryLst=None,
                          **kwargs):  # intervName=None,
    """Compute and store the mean over axis 1 of each session's neuronal data.

    For every sweep row and every session of the row's mouse, the data
    selected by the row's ``datatype``, ``intervName`` and ``trialType`` is
    loaded. The first array returned by ``get_neuro_data`` is averaged over
    axis 1 (presumably time for trial x time x channel data — confirm) and
    saved in ``ds`` under the name ``'mean'``.

    Parameters
    ----------
    dataDB : object
        Data source; ``get_sessions`` and ``get_neuro_data`` are used.
    ds : object
        Result storage; ``ping_data``, ``delete_rows`` and ``save_data`` are used.
    trialTypeTrg
        Trial type(s) to sweep over, forwarded to ``DataParameterSweep``.
    skipExisting : bool
        If True, combinations already present in ``ds`` are skipped.
        Otherwise they are recomputed and the old rows are replaced.
    exclQueryLst : list, optional
        Parameter combinations to exclude from the sweep.
    **kwargs
        Additional sweep parameters.
    """
    dataName = 'mean'

    dps = DataParameterSweep(dataDB,
                             exclQueryLst,
                             mousename='auto',
                             trialType=trialTypeTrg,
                             **kwargs)
    for idx, row in dps.sweepDF.iterrows():
        print(list(row))

        for session in dataDB.get_sessions(row['mousename'],
                                           datatype=row['datatype']):
            attrsDict = {'session': session, **dict(row)}

            dsDataLabels = ds.ping_data(dataName, attrsDict)
            # Skip only when asked to AND something is already stored
            # (the previous condition was inverted w.r.t. the parameter name)
            if skipExisting and len(dsDataLabels) > 0:
                dsuffix = dataName + '_' + '_'.join(str(v) for v in attrsDict.values())
                print('Skipping existing', dsuffix)
                continue

            dataRSP = dataDB.get_neuro_data({'session': session},
                                            datatype=row['datatype'],
                                            intervName=row['intervName'],
                                            trialType=row['trialType'])[0]

            dataRP = np.mean(dataRSP, axis=1)

            # Replace any previously stored result with the same attributes
            ds.delete_rows(dsDataLabels, verbose=False)
            ds.save_data(dataName, dataRP, attrsDict)

def calc_metric_session(dataDB, mc, ds, metricName, dimOrdTrg, nameSuffix, minTrials=1, skipExisting=False,
                        metricSettings=None, sweepSettings=None, dropChannels=None,
                        verbose=True, exclQueryLst=None, **kwargs):  # dataTypes='auto', trialTypeNames=None, perfNames=None, intervNames=None,
    """Run ``metric_by_session`` for every row of a parameter sweep.

    The sweep runs over all mice plus the given ``kwargs``; ``trialType`` and
    ``performance`` are auto-appended with the value ``'None'``. For each row,
    ``metric_by_session`` is called with the row's parameters (except the
    mouse name, which is passed positionally) and ``timeAvg=True``. Data of
    datatype ``'raw'`` is requested with ``zscoreDim='rs'``; all other
    datatypes use no z-scoring. A progress bar labelled ``nameSuffix`` is
    shown.
    """
    dps = DataParameterSweep(dataDB, exclQueryLst,
                             autoAppendDict={'trialType': ['None'], 'performance': ['None']},
                             mousename='auto', **kwargs)

    progBar = IntProgress(min=0, max=len(dps.sweepDF), description=nameSuffix)
    display(progBar)

    for _, sweepRow in dps.sweepDF.iterrows():
        # All row parameters except the mouse name are forwarded as keywords
        rowKwargs = {key: val for key, val in dict(sweepRow).items() if key != 'mousename'}

        zscoreDim = None
        if rowKwargs['datatype'] == 'raw':
            zscoreDim = 'rs'

        if verbose:
            print(metricName, nameSuffix, rowKwargs)

        metric_by_session(dataDB, mc, ds, sweepRow['mousename'], metricName, dimOrdTrg,
                          skipExisting=skipExisting,
                          dropChannels=dropChannels,
                          dataName=nameSuffix,
                          minTrials=minTrials,
                          zscoreDim=zscoreDim,
                          metricSettings=metricSettings,
                          sweepSettings=sweepSettings,
                          timeAvg=True,
                          **rowKwargs)

        progBar.value += 1
Ejemplo n.º 4
0
def classification_accuracy_brainplot_mousephase(dataDB,
                                                 exclQueryLst,
                                                 fontsize=20,
                                                 trialType='auto',
                                                 **kwargs):
    """Plot the texture classification accuracy on brain maps, mice x phases.

    One SVG is saved to ``pics/classification_accuracy/brainplot_mousephase/``
    for every combination of the sweep parameters other than mouse and phase.
    Rows of the figure are mice and columns are phases (``intervName``).

    In each panel, data of each trial type is loaded with
    ``get_data_avg(..., avgAxes=1)``. Trial types 0 and 1 are pooled as the
    first texture and 2 and 3 as the second (assumes at least four trial
    types in exactly this order — TODO confirm).
    ``classification_accuracy_weighted`` is evaluated column by column
    (presumably per channel), and the accuracies are drawn with
    ``dataDB.plot_area_values`` on a 0.5-1.0 color scale.

    Parameters
    ----------
    dataDB : object
        Data source and plotting helper.
    exclQueryLst : list
        Parameter combinations to exclude from the sweep.
    fontsize : int
        Font size of row labels and column titles.
    trialType : list or 'auto'
        Trial types to load. ``'auto'`` uses ``dataDB.get_trial_type_names()``.
    **kwargs
        Sweep parameters; must contain ``intervName``.
    """
    assert 'intervName' in kwargs.keys(), 'Requires phases'
    dps = DataParameterSweep(dataDB, exclQueryLst, mousename='auto', **kwargs)
    nMice = dps.param_size('mousename')
    nInterv = dps.param_size('intervName')

    trialTypeLst = dataDB.get_trial_type_names() if trialType == 'auto' else trialType

    for paramVals, dfTmp in dps.sweepDF.groupby(
            dps.invert_param(['mousename', 'intervName'])):
        plotSuffix = param_vals_to_suffix(paramVals)
        print(plotSuffix)

        # squeeze=False keeps ``ax`` 2D even for a single mouse or a single phase
        fig, ax = plt.subplots(nrows=nMice,
                               ncols=nInterv,
                               figsize=(4 * nInterv, 4 * nMice),
                               squeeze=False)

        # Group by the column name, not a 1-element list: with a list, pandas >= 2
        # yields 1-tuples as group keys, which breaks param_index and the labels.
        for mousename, dfMouse in dfTmp.groupby('mousename'):
            iMouse = dps.param_index('mousename', mousename)

            ax[iMouse][0].set_ylabel(mousename, fontsize=fontsize)
            for idx, row in dfMouse.iterrows():
                intervName = row['intervName']
                iInterv = dps.param_index('intervName', intervName)
                ax[0][iInterv].set_title(intervName, fontsize=fontsize)

                kwargsThis = pd_row_to_kwargs(row,
                                              parseNone=True,
                                              dropKeys=['mousename'])
                # One dataset per trial type: the trial type must actually be
                # passed to the query, otherwise all entries are identical.
                # (assumes get_data_avg forwards trialType to the data query)
                dataRPLst = [
                    get_data_avg(dataDB, mousename, avgAxes=1,
                                 **{**kwargsThis, 'trialType': tt})
                    for tt in trialTypeLst
                ]

                # Split two textures
                dataT1 = np.concatenate([dataRPLst[0], dataRPLst[1]])
                dataT2 = np.concatenate([dataRPLst[2], dataRPLst[3]])

                svcAcc = [
                    classification_accuracy_weighted(x[:, None], y[:, None])
                    for x, y in zip(dataT1.T, dataT2.T)
                ]

                dataDB.plot_area_values(fig,
                                        ax[iMouse][iInterv],
                                        svcAcc,
                                        vmin=0.5,
                                        vmax=1.0,
                                        cmap='jet')

        prefixPath = 'pics/classification_accuracy/brainplot_mousephase/'
        make_path(prefixPath)
        fig.savefig(prefixPath + plotSuffix + '.svg')
        plt.close(fig)
Ejemplo n.º 5
0
def activity_brainplot_mousephase_submouse(dataDB,
                                           exclQueryLst=None,
                                           vmin=None,
                                           vmax=None,
                                           fontsize=20,
                                           dpi=200,
                                           **kwargs):
    """Plot each mouse's mean activity minus the across-mouse mean, mice x phases.

    One SVG is saved to ``pics/activity/brainplot_mousephase/submouse/`` for
    every combination of the sweep parameters other than mouse and phase.
    Rows are mice and columns are phases (``intervName``). Each panel shows
    the mouse's data averaged over axes (0, 1), minus the mean of that
    quantity over all mice of the same phase.

    Parameters
    ----------
    dataDB : object
        Data source and plotting helper (``plot_area_values``).
    exclQueryLst : list, optional
        Parameter combinations to exclude from the sweep.
    vmin, vmax : float, optional
        Color limits forwarded to ``plot_area_values``.
    fontsize : int
        Font size of row labels and column titles.
    dpi : int
        Resolution passed to ``Figure.savefig``.
    **kwargs
        Sweep parameters; must contain ``intervName``.
    """
    assert 'intervName' in kwargs.keys(), 'Requires phases'
    dps = DataParameterSweep(dataDB, exclQueryLst, mousename='auto', **kwargs)
    nMice = dps.param_size('mousename')
    nInterv = dps.param_size('intervName')

    for paramVals, dfTmp in dps.sweepDF.groupby(
            dps.invert_param(['mousename', 'intervName'])):
        plotSuffix = param_vals_to_suffix(paramVals)
        print(plotSuffix)

        # squeeze=False keeps ``ax`` 2D even for a single mouse or a single phase
        fig, ax = plt.subplots(nrows=nMice,
                               ncols=nInterv,
                               figsize=(4 * nInterv, 4 * nMice),
                               tight_layout=True,
                               squeeze=False)

        # Group by the column name, not a 1-element list: with a list, pandas >= 2
        # yields 1-tuples as group keys, which breaks param_index and the titles.
        for intervName, dfInterv in dfTmp.groupby('intervName'):
            iInterv = dps.param_index('intervName', intervName)

            ax[0][iInterv].set_title(intervName, fontsize=fontsize)

            rezDict = {}
            for idx, row in dfInterv.iterrows():
                mousename = row['mousename']
                kwargsThis = pd_row_to_kwargs(row,
                                              parseNone=True,
                                              dropKeys=['mousename'])
                rezDict[mousename] = get_data_avg(dataDB,
                                                  mousename,
                                                  avgAxes=(0, 1),
                                                  **kwargsThis)

            # Reference: average over all mice for this phase
            dataPsub = np.mean(list(rezDict.values()), axis=0)
            haveColorBar = iInterv == nInterv - 1
            for mousename, dataP in rezDict.items():
                iMouse = dps.param_index('mousename', mousename)
                ax[iMouse][0].set_ylabel(mousename, fontsize=fontsize)

                dataDB.plot_area_values(fig,
                                        ax[iMouse][iInterv],
                                        dataP - dataPsub,
                                        vmin=vmin,
                                        vmax=vmax,
                                        cmap='jet',
                                        haveColorBar=haveColorBar)

        prefixPath = 'pics/activity/brainplot_mousephase/submouse/'
        make_path(prefixPath)
        fig.savefig(prefixPath + plotSuffix + '.svg', dpi=dpi)
        plt.close(fig)
Ejemplo n.º 6
0
def activity_brainplot_mouse(dataDB,
                             xParamName,
                             exclQueryLst=None,
                             vmin=None,
                             vmax=None,
                             fontsize=20,
                             dpi=200,
                             **kwargs):
    """Plot mean activity per mouse on brain maps, mice x values of ``xParamName``.

    One SVG is saved to ``pics/activity/brainplot_<xParamName>/`` for every
    combination of the sweep parameters other than mouse and ``xParamName``.
    For every panel, all data of the mouse matching the row's parameters is
    concatenated along axis 0 and averaged over axes (0, 1).

    Parameters
    ----------
    dataDB : object
        Data source and plotting helper (``plot_area_values``).
    xParamName : str
        Sweep parameter laid out along the columns; must be a key of ``kwargs``.
    exclQueryLst : list, optional
        Parameter combinations to exclude from the sweep.
    vmin, vmax : float, optional
        Color limits forwarded to ``plot_area_values``.
    fontsize : int
        Font size of row labels and column titles.
    dpi : int
        Resolution passed to ``Figure.savefig``.
    **kwargs
        Sweep parameters.
    """
    assert xParamName in kwargs.keys(), 'Requires ' + xParamName
    dps = DataParameterSweep(dataDB, exclQueryLst, mousename='auto', **kwargs)
    nMice = dps.param_size('mousename')
    nXParam = dps.param_size(xParamName)

    for paramVals, dfTmp in dps.sweepDF.groupby(
            dps.invert_param(['mousename', xParamName])):
        plotSuffix = param_vals_to_suffix(paramVals)
        print(plotSuffix)

        # squeeze=False keeps ``ax`` 2D even for a single mouse or a single column
        fig, ax = plt.subplots(nrows=nMice,
                               ncols=nXParam,
                               figsize=(4 * nXParam, 4 * nMice),
                               tight_layout=True,
                               squeeze=False)

        # Group by the column name, not a 1-element list: with a list, pandas >= 2
        # yields 1-tuples as group keys, which breaks param_index and the labels.
        for mousename, dfMouse in dfTmp.groupby('mousename'):
            iMouse = dps.param_index('mousename', mousename)

            ax[iMouse][0].set_ylabel(mousename, fontsize=fontsize)
            for idx, row in dfMouse.iterrows():
                xParamVal = row[xParamName]
                iXParam = dps.param_index(xParamName, xParamVal)

                ax[0][iXParam].set_title(xParamVal, fontsize=fontsize)

                kwargsThis = pd_row_to_kwargs(row,
                                              parseNone=True,
                                              dropKeys=['mousename'])
                dataLst = dataDB.get_neuro_data({'mousename': mousename},
                                                **kwargsThis)
                dataRSP = np.concatenate(dataLst, axis=0)
                dataP = np.mean(dataRSP, axis=(0, 1))

                haveColorBar = iXParam == nXParam - 1
                dataDB.plot_area_values(fig,
                                        ax[iMouse][iXParam],
                                        dataP,
                                        vmin=vmin,
                                        vmax=vmax,
                                        cmap='jet',
                                        haveColorBar=haveColorBar)

        prefixPath = 'pics/activity/brainplot_' + xParamName + '/'
        make_path(prefixPath)
        fig.savefig(prefixPath + plotSuffix + '.svg', dpi=dpi)
        plt.close(fig)
Ejemplo n.º 7
0
def multiprocess_session(dataDB, mc, h5outname, argSweepDict, exclQueryLst,
                         metricType, **kwargsMetric):
    """Compute ``metricType`` for every sweep row and session, writing to HDF5.

    For each sweep row built from ``argSweepDict`` and each session of the
    row's mouse (restricted to the row's datatype), the result is stored in
    ``h5outname`` under a key made of ``metricType``, the row values and the
    session name. A companion key prefixed with ``'Label_'`` holds the result
    indices. A combination is only computed when
    ``h5lock.lock_test_available`` reports its key as not yet calculated.

    Parameters
    ----------
    dataDB : object
        Data source; ``get_sessions`` and ``get_neuro_data`` are used.
    mc : object
        Metric calculator, passed to ``_metric_results``.
    h5outname : str
        Output HDF5 file; created (with a ``'lock'`` group) if missing.
    argSweepDict : dict
        Sweep definition forwarded to ``DataParameterSweep``.
    exclQueryLst : list
        Parameter combinations to exclude from the sweep.
    metricType : str
        Metric name, used in the result keys and passed to ``_metric_results``.
    **kwargsMetric
        Extra arguments for ``_metric_results``.
    """
    # dim=3, nBin=4, metric='BivariatePID', permuteTarget=False, dropChannels=None, timeSweep=False

    # Make sure the output file and its lock group exist
    h5lock.touch_file(h5outname)
    h5lock.touch_group(h5outname, 'lock')

    dps = DataParameterSweep(dataDB, exclQueryLst, **argSweepDict)

    for _, sweepRow in dps.sweepDF.iterrows():
        rowSuffix = '_'.join(str(val) for val in sweepRow.values)
        keyDataMouse = metricType + '_' + rowSuffix
        keyLabelMouse = 'Label_' + rowSuffix

        sessions = dataDB.get_sessions(sweepRow['mousename'],
                                       datatype=sweepRow['datatype'])
        for session in sessions:
            keyDataSession = keyDataMouse + '_' + session
            keyLabelSession = keyLabelMouse + '_' + session
            print(keyDataSession)

            # Only proceed if this parameter combination is not yet calculated
            if not h5lock.lock_test_available(h5outname, keyDataSession):
                continue

            # Query arguments: the row without the mouse, string 'None' -> None
            queryKwargs = {
                key: (None if val == 'None' else val)
                for key, val in dict(sweepRow).items()
                if key != 'mousename'
            }

            dataLst = dataDB.get_neuro_data({'session': session},
                                            zscoreDim=None,
                                            **queryKwargs)

            rezIdxs, rezVals = _metric_results(metricType, dataLst, mc,
                                               **kwargsMetric)

            h5lock.unlock_write(h5outname, keyLabelSession, keyDataSession,
                                np.array(rezIdxs), rezVals)

def compute_store_corr_mouse(dataDB,
                             ds,
                             trialTypeTrg,
                             skipExisting=False,
                             exclQueryLst=None,
                             **kwargs):  # intervName=None,
    """Compute and store a correlation matrix for each mouse and sweep row.

    For every sweep row, the data of all sessions of the mouse (selected by
    the row's ``datatype``, ``intervName`` and ``trialType``) is concatenated
    along axis 0 and averaged over axis 1. The correlation matrix between the
    columns of the averaged array is saved in ``ds`` under the name
    ``'corr_mouse'``.

    Parameters
    ----------
    dataDB : object
        Data source; ``get_neuro_data`` is used.
    ds : object
        Result storage; ``ping_data``, ``delete_rows`` and ``save_data`` are used.
    trialTypeTrg
        Trial type(s) to sweep over, forwarded to ``DataParameterSweep``.
    skipExisting : bool
        If True, combinations already present in ``ds`` are skipped.
        Otherwise they are recomputed and the old rows are replaced.
    exclQueryLst : list, optional
        Parameter combinations to exclude from the sweep.
    **kwargs
        Additional sweep parameters.
    """
    dataName = 'corr_mouse'

    dps = DataParameterSweep(dataDB,
                             exclQueryLst,
                             mousename='auto',
                             trialType=trialTypeTrg,
                             **kwargs)
    for idx, row in dps.sweepDF.iterrows():
        print(list(row))

        mousename = row['mousename']
        queryDict = dict(row)
        del queryDict['mousename']
        attrsDict = {'mousename': mousename, **queryDict}

        dsDataLabels = ds.ping_data(dataName, attrsDict)
        # Skip only when asked to AND something is already stored
        # (the previous condition was inverted w.r.t. the parameter name)
        if skipExisting and len(dsDataLabels) > 0:
            dsuffix = dataName + '_' + '_'.join(str(v) for v in attrsDict.values())
            print('Skipping existing', dsuffix)
            continue

        dataRSPLst = dataDB.get_neuro_data({'mousename': mousename},
                                           datatype=row['datatype'],
                                           intervName=row['intervName'],
                                           trialType=row['trialType'])

        dataRSP = np.concatenate(dataRSPLst, axis=0)
        dataRP = np.mean(dataRSP, axis=1)
        # Rows of dataRP.T are the variables, i.e. the columns of dataRP
        cc = np.corrcoef(dataRP.T)

        # Replace any previously stored result with the same attributes
        ds.delete_rows(dsDataLabels, verbose=False)
        ds.save_data(dataName, cc, attrsDict)
Ejemplo n.º 9
0
def multiprocess_mouse_trgsweep(dataDB,
                                mc,
                                h5outname,
                                argSweepDict,
                                exclQueryLst,
                                metricType,
                                dropChannels=None,
                                **kwargsMetric):
    """Compute a source/target metric per sweep row and per target channel.

    For every sweep row built from ``argSweepDict`` and every channel index
    not listed in ``dropChannels``, that channel is the single target. All
    other channels that are not dropped are the sources. The result is written
    to ``h5outname`` under a key made of ``metricType``, the row values and the
    target index. A companion ``'Label_'`` key holds the result indices. A
    combination is only computed when ``h5lock.lock_test_available`` reports
    its key as not yet calculated.

    Parameters
    ----------
    dataDB : object
        Data source; ``get_channel_labels`` and ``get_interval_names`` are used.
        The data itself is loaded through ``get_data_list``.
    mc : object
        Metric calculator, passed to ``_metric_results``.
    h5outname : str
        Output HDF5 file; created (with a ``'lock'`` group) if missing.
    argSweepDict : dict
        Sweep definition forwarded to ``DataParameterSweep``.
    exclQueryLst : list
        Parameter combinations to exclude from the sweep.
    metricType : str
        Metric name, used in the keys and passed to ``_metric_results``.
    dropChannels : list of int, optional
        Channel indices used neither as target nor as source.
    **kwargsMetric
        Extra arguments for ``_metric_results``.
    """
    # dim=3, nBin=4, metric='BivariatePID', permuteTarget=False, timeSweep=False

    # Make sure the output file and its lock group exist
    h5lock.touch_file(h5outname)
    h5lock.touch_group(h5outname, 'lock')

    dps = DataParameterSweep(dataDB, exclQueryLst, **argSweepDict)

    channelLabels = dataDB.get_channel_labels()
    haveDelay = 'DEL' in dataDB.get_interval_names()

    for _, sweepRow in dps.sweepDF.iterrows():
        rowValueStrs = [str(val) for val in sweepRow.values]

        for iTrg, trgLabel in enumerate(channelLabels):
            # A dropped channel is never used as a target
            if dropChannels is not None and iTrg in dropChannels:
                continue

            keySuffix = '_'.join(rowValueStrs + [str(iTrg)])
            keyDataMouse = metricType + '_' + keySuffix
            keyLabelMouse = 'Label_' + keySuffix

            # Only proceed if this parameter combination is not yet calculated
            if not h5lock.lock_test_available(h5outname, keyDataMouse):
                continue

            # Sources: every channel except the target and the dropped ones
            exclChannels = [iTrg]
            if dropChannels is not None:
                exclChannels.extend(dropChannels)
            srcLabels = [label for iCh, label in enumerate(channelLabels)
                         if iCh not in exclChannels]

            # Query arguments: the row without the mouse, string 'None' -> None
            queryKwargs = {
                key: (None if val == 'None' else val)
                for key, val in dict(sweepRow).items()
                if key != 'mousename'
            }

            dataLst = get_data_list(dataDB,
                                    haveDelay,
                                    sweepRow['mousename'],
                                    zscoreDim=None,
                                    **queryKwargs)

            print(len(dataLst), dataLst[0].shape, haveDelay,
                  sweepRow['mousename'], queryKwargs)

            rezIdxs, rezVals = _metric_results(metricType,
                                               dataLst,
                                               mc,
                                               labelsAll=channelLabels,
                                               labelsSrc=srcLabels,
                                               labelsTrg=[trgLabel],
                                               **kwargsMetric)

            h5lock.unlock_write(h5outname, keyLabelMouse, keyDataMouse,
                                np.array(rezIdxs), rezVals)
Ejemplo n.º 10
0
def activity_brainplot_mousephase_subpre(dataDB,
                                         exclQueryLst=None,
                                         vmin=None,
                                         vmax=None,
                                         fontsize=20,
                                         dpi=200,
                                         **kwargs):
    """Plot per-phase mean activity minus the 'PRE' phase, mice x phases.

    The sweep is restricted to datatype ``'bn_session'``. One SVG is saved to
    ``pics/activity/brainplot_mousephase/subpre/`` for every combination of
    the sweep parameters other than mouse and phase. For each mouse, the data
    of phase ``'PRE'`` is averaged over axes (0, 1) and subtracted from the
    same average of every other phase. The ``'PRE'`` column itself is left
    empty.

    Parameters
    ----------
    dataDB : object
        Data source and plotting helper (``plot_area_values``).
    exclQueryLst : list, optional
        Parameter combinations to exclude from the sweep.
    vmin, vmax : float, optional
        Color limits forwarded to ``plot_area_values``.
    fontsize : int
        Font size of row labels and column titles.
    dpi : int
        Resolution passed to ``Figure.savefig``.
    **kwargs
        Sweep parameters; must contain ``intervName``.
    """
    assert 'intervName' in kwargs.keys(), 'Requires phases'
    dps = DataParameterSweep(dataDB,
                             exclQueryLst,
                             mousename='auto',
                             datatype=['bn_session'],
                             **kwargs)
    nMice = dps.param_size('mousename')
    nInterv = dps.param_size('intervName')

    for paramVals, dfTmp in dps.sweepDF.groupby(
            dps.invert_param(['mousename', 'intervName'])):
        plotSuffix = param_vals_to_suffix(paramVals)
        print(plotSuffix)

        # squeeze=False keeps ``ax`` 2D even for a single mouse or a single phase
        fig, ax = plt.subplots(nrows=nMice,
                               ncols=nInterv,
                               figsize=(4 * nInterv, 4 * nMice),
                               tight_layout=True,
                               squeeze=False)

        # Group by the column name, not a 1-element list: with a list, pandas >= 2
        # yields 1-tuples as group keys, which breaks param_index and the labels.
        for mousename, dfMouse in dfTmp.groupby('mousename'):
            iMouse = dps.param_index('mousename', mousename)

            ax[iMouse][0].set_ylabel(mousename, fontsize=fontsize)

            # Baseline: same parameters as this mouse's rows, but phase 'PRE'
            kwargsPre = pd_row_to_kwargs(pd_first_row(dfMouse)[1],
                                         parseNone=True,
                                         dropKeys=['mousename', 'intervName'])
            kwargsPre['intervName'] = 'PRE'
            dataPPre = get_data_avg(dataDB,
                                    mousename,
                                    avgAxes=(0, 1),
                                    **kwargsPre)

            for idx, row in dfMouse.iterrows():
                intervName = row['intervName']
                if intervName == 'PRE':
                    continue

                iInterv = dps.param_index('intervName', intervName)
                ax[0][iInterv].set_title(intervName, fontsize=fontsize)

                kwargsThis = pd_row_to_kwargs(row,
                                              parseNone=True,
                                              dropKeys=['mousename'])
                dataP = get_data_avg(dataDB,
                                     mousename,
                                     avgAxes=(0, 1),
                                     **kwargsThis)

                haveColorBar = iInterv == nInterv - 1
                dataDB.plot_area_values(fig,
                                        ax[iMouse][iInterv],
                                        dataP - dataPPre,
                                        vmin=vmin,
                                        vmax=vmax,
                                        cmap='jet',
                                        haveColorBar=haveColorBar)

        prefixPath = 'pics/activity/brainplot_mousephase/subpre/'
        make_path(prefixPath)
        fig.savefig(prefixPath + plotSuffix + '.svg', dpi=dpi)
        plt.close(fig)

def plot_corr_consistency_l1_mouse(
        dataDB,
        nDropPCA=None,
        dropChannels=None,
        exclQueryLst=None,
        **kwargs):  # performances=None, trialTypes=None,
    """Plot how consistent channel-channel correlations are across mice.

    The sweep runs over mice, datatypes and phases (``intervName``), all
    ``'auto'`` unless given in ``kwargs``. Each group shares everything except
    the mouse, and for each group:

    * all data of each mouse is loaded with ``zscoreDim='rs'``, concatenated
      along axis 0 and averaged over axis 1;
    * optionally, the channels in ``dropChannels`` are removed and
      ``nDropPCA`` components are dropped via ``drop_PCA``;
    * the channel correlation matrix is flattened with ``tril_1D``;
    * a seaborn KDE pairplot of these vectors (one variable per mouse) is
      saved, together with an image of the mouse x mouse matrix of Pearson
      correlations between the vectors;
    * the off-diagonal mean of that matrix (rounded to 2 decimals) is
      recorded for the group's datatype and phase.

    For every combination of the remaining sweep parameters, the recorded
    values are pivoted with ``pd_pivot`` and saved as an annotated heatmap.
    All files go under ``pics/consistency/corr/mouse/dropPCA_<nDropPCA>/``.

    Parameters
    ----------
    dataDB : object
        Data source; ``get_neuro_data`` is used.
    nDropPCA : int, optional
        Number of components to drop with ``drop_PCA``.
    dropChannels : list of int, optional
        Channel indices removed before computing correlations.
    exclQueryLst : list, optional
        Parameter combinations to exclude from the sweep.
    **kwargs
        Sweep parameters; must contain ``intervName``.
    """
    assert 'intervName' in kwargs.keys(), 'Requires phases'

    # Mice, datatypes and phases default to 'auto'; values from kwargs win.
    # Passing them both explicitly and through **kwargs (as the assert requires
    # for intervName) raised "got multiple values for keyword argument".
    sweepKwargs = {'mousename': 'auto', 'intervName': 'auto', 'datatype': 'auto', **kwargs}
    dps = DataParameterSweep(dataDB, exclQueryLst, **sweepKwargs)

    prefixRoot = 'pics/consistency/corr/mouse/dropPCA_' + str(nDropPCA) + '/'
    dfColumns = ['datatype', 'phase', 'consistency']

    for paramExtraVals, dfTmp in dps.sweepDF.groupby(
            dps.invert_param(['mousename', 'datatype', 'intervName'])):
        plotExtraSuffix = param_vals_to_suffix(paramExtraVals)

        dfConsistency = pd.DataFrame(columns=dfColumns)

        for paramVals, dfMouse in dfTmp.groupby(dps.invert_param(['mousename'])):
            plotSuffix = param_vals_to_suffix(paramVals)
            print(plotSuffix)

            # Keep each correlation vector attached to its mouse name, so labels
            # stay correct regardless of row order or excluded mice
            corrDict = {}
            for idx, row in dfMouse.iterrows():
                # Separate name: reusing ``plotSuffix`` here made the saved file
                # names depend on the last mouse of the group
                rowSuffix = '_'.join([str(s) for s in row.values])
                print(rowSuffix)

                kwargsThis = pd_row_to_kwargs(row,
                                              parseNone=True,
                                              dropKeys=['mousename'])
                dataRSPLst = dataDB.get_neuro_data(
                    {'mousename': row['mousename']},
                    # NOTE: zscore channels for each session to avoid session-wise effects
                    zscoreDim='rs',
                    **kwargsThis)

                dataRSP = np.concatenate(dataRSPLst, axis=0)
                dataRP = np.mean(dataRSP, axis=1)

                if dropChannels is not None:
                    channelMask = np.ones(dataRP.shape[1], dtype=bool)
                    channelMask[dropChannels] = False
                    dataRP = dataRP[:, channelMask]

                if nDropPCA is not None:
                    dataRP = drop_PCA(dataRP, nDropPCA)

                corrDict[row['mousename']] = tril_1D(np.corrcoef(dataRP.T))

            mice = sorted(corrDict)
            nMice = len(mice)

            # Pearson correlation between the correlation vectors of each mouse pair
            rezMat = np.zeros((nMice, nMice))
            for iMouse, mouseI in enumerate(mice):
                for jMouse, mouseJ in enumerate(mice):
                    rezMat[iMouse][jMouse] = np.corrcoef(
                        corrDict[mouseI], corrDict[mouseJ])[0, 1]

            sns.pairplot(data=pd.DataFrame({m: corrDict[m] for m in mice}),
                         vars=mice,
                         kind='kde')

            prefixPath = prefixRoot + 'scatter/'
            make_path(prefixPath)
            plt.savefig(prefixPath + 'scatter_' + plotSuffix + '.svg')
            plt.close()

            fig, ax = plt.subplots()
            imshow(fig,
                   ax,
                   rezMat,
                   haveColorBar=True,
                   limits=[0, 1],
                   xTicks=mice,
                   yTicks=mice,
                   cmap='jet')

            prefixPath = prefixRoot + 'metric/'
            make_path(prefixPath)
            fig.savefig(prefixPath + 'metric_' + plotSuffix + '.svg')
            plt.close(fig)

            # Grouping is by every parameter except the mouse, so datatype and
            # phase are shared by all rows of dfMouse
            firstRow = dfMouse.iloc[0]
            avgConsistency = np.round(np.mean(offdiag_1D(rezMat)), 2)
            dfConsistency = pd_append_row(
                dfConsistency,
                [firstRow['datatype'], firstRow['intervName'], avgConsistency])

        fig, ax = plt.subplots()
        dfPivot = pd_pivot(dfConsistency, *dfColumns)
        sns.heatmap(data=dfPivot,
                    ax=ax,
                    annot=True,
                    vmin=0,
                    vmax=1,
                    cmap='jet')

        make_path(prefixRoot)
        fig.savefig(prefixRoot + plotExtraSuffix + '.svg')
        plt.close(fig)

def plot_corr_mousephase_submouse(dataDB,
                                  mc,
                                  estimator,
                                  nDropPCA=None,
                                  dropChannels=None,
                                  exclQueryLst=None,
                                  corrStrategy='mean',
                                  fontsize=20,
                                  **kwargs):
    """Plot each mouse's correlation matrix minus the across-mouse mean, mice x phases.

    One SVG per combination of the sweep parameters other than mouse and phase
    is saved to ``pics/corr/mousephase/dropPCA_<nDropPCA>/submouse/``. For
    each phase (``intervName``), ``calc_corr_mouse`` is evaluated for every
    mouse. Each panel shows that mouse's matrix minus the mean of the matrices
    of all mice of the same phase, with color limits [-1, 1].

    Parameters
    ----------
    dataDB : object
        Data source, passed to ``calc_corr_mouse``.
    mc : object
        Metric calculator, passed to ``calc_corr_mouse``.
    estimator
        Correlation estimator, passed to ``calc_corr_mouse``.
    nDropPCA : int, optional
        Forwarded to ``calc_corr_mouse``; also part of the output path.
    dropChannels : list of int, optional
        Forwarded to ``calc_corr_mouse``.
    exclQueryLst : list, optional
        Parameter combinations to exclude from the sweep.
    corrStrategy : str
        Forwarded to ``calc_corr_mouse`` as ``strategy``.
    fontsize : int
        Font size of row labels and column titles.
    **kwargs
        Sweep parameters; must contain ``intervName``.
    """
    assert 'intervName' in kwargs.keys(), 'Requires phases'
    dps = DataParameterSweep(dataDB, exclQueryLst, mousename='auto', **kwargs)
    nMice = dps.param_size('mousename')
    nInterv = dps.param_size('intervName')

    for paramVals, dfTmp in dps.sweepDF.groupby(
            dps.invert_param(['mousename', 'intervName'])):
        plotSuffix = param_vals_to_suffix(paramVals)
        print(plotSuffix)

        # squeeze=False keeps ``axCorr`` 2D even for a single mouse or a single phase
        figCorr, axCorr = plt.subplots(nrows=nMice,
                                       ncols=nInterv,
                                       figsize=(4 * nInterv, 4 * nMice),
                                       squeeze=False)

        # Group by the column name, not a 1-element list: with a list, pandas >= 2
        # yields 1-tuples as group keys, which breaks param_index and the titles.
        for intervName, dfInterv in dfTmp.groupby('intervName'):
            iInterv = dps.param_index('intervName', intervName)

            axCorr[0][iInterv].set_title(intervName, fontsize=fontsize)

            rezDict = {}
            for idx, row in dfInterv.iterrows():
                mousename = row['mousename']

                kwargsThis = pd_row_to_kwargs(row,
                                              parseNone=True,
                                              dropKeys=['mousename'])
                channelLabels, rez2D = calc_corr_mouse(
                    dataDB,
                    mc,
                    mousename,
                    strategy=corrStrategy,
                    nDropPCA=nDropPCA,
                    dropChannels=dropChannels,
                    estimator=estimator,
                    **kwargsThis)

                rezDict[mousename] = rez2D

            # Reference: average over all mice of this phase
            rezMean = np.mean(list(rezDict.values()), axis=0)
            haveColorBar = iInterv == nInterv - 1

            for mousename, rezMouse in rezDict.items():
                iMouse = dps.param_index('mousename', mousename)
                axCorr[iMouse][0].set_ylabel(mousename, fontsize=fontsize)

                imshow(figCorr,
                       axCorr[iMouse][iInterv],
                       rezMouse - rezMean,
                       title='corr',
                       haveColorBar=haveColorBar,
                       limits=[-1, 1],
                       cmap='RdBu_r')

        # Save image
        prefixPath = 'pics/corr/mousephase/dropPCA_' + str(
            nDropPCA) + '/submouse/'
        make_path(prefixPath)
        figCorr.savefig(prefixPath + 'corr_' + plotSuffix + '.svg')
        plt.close(figCorr)

def plot_corr_mouse(dataDB,
                    mc,
                    estimator,
                    xParamName,
                    nDropPCA=None,
                    dropChannels=None,
                    haveBrain=False,
                    haveMono=True,
                    corrStrategy='mean',
                    exclQueryLst=None,
                    thrMono=0.4,
                    clusterParam=0.5,
                    fontsize=20,
                    **kwargs):
    """Plot per-mouse correlation matrices swept over one parameter.

    For every combination of the remaining sweep parameters, a grid of
    subplots is drawn with one row per mouse and one column per value of
    ``xParamName``. The figures are saved as SVG files under
    ``pics/corr/mouse<xParamName>/dropPCA_<nDropPCA>/``:

    * ``corr/``            - matrix returned by ``calc_corr_mouse`` (always)
    * ``clust/``           - ``cluster_plot`` of that matrix (always)
    * ``clust_brainplot/`` - ``cluster_brain_plot`` output (if ``haveBrain``)
    * ``1D/``              - ``_plot_corr_1D`` output (if ``haveMono``)

    Parameters
    ----------
    dataDB : data access object
        Forwarded to ``DataParameterSweep``, ``calc_corr_mouse`` and
        ``cluster_brain_plot``.
    mc : metric calculator
        Forwarded to ``calc_corr_mouse``.
    estimator : correlation estimator
        Forwarded to ``calc_corr_mouse``.
    xParamName : str
        Sweep parameter laid out along the columns. Must be
        ``'intervName'`` or ``'trialType'`` and must be present in ``kwargs``.
    nDropPCA : int or None
        Forwarded to ``calc_corr_mouse``. Also used in the output path.
    dropChannels : forwarded to ``calc_corr_mouse`` and ``cluster_brain_plot``.
    haveBrain, haveMono : bool
        Whether to produce the brain-plot and 1D-plot figures.
    corrStrategy : str
        Forwarded to ``calc_corr_mouse`` as ``strategy``.
    exclQueryLst : forwarded to ``DataParameterSweep``.
    thrMono : float
        Threshold forwarded to ``_plot_corr_1D``.
    clusterParam : float
        Forwarded to ``cluster_dist_matrix_max``.
    fontsize : int
        Font size of the row labels (mouse names) and column titles.
    **kwargs
        Sweep parameters forwarded to ``DataParameterSweep``.
    """
    assert xParamName in ['intervName', 'trialType'], 'Unexpected parameter'
    assert xParamName in kwargs.keys(), 'Requires ' + xParamName
    dps = DataParameterSweep(dataDB, exclQueryLst, mousename='auto', **kwargs)
    nMice = dps.param_size('mousename')
    nXParam = dps.param_size(xParamName)

    def _new_grid():
        # One (mouse x xParam) grid of axes. squeeze=False keeps the axes
        # array 2D even for a single mouse or a single parameter value, so
        # ax[iMouse][iXParam] indexing always works.
        return plt.subplots(nrows=nMice,
                            ncols=nXParam,
                            figsize=(4 * nXParam, 4 * nMice),
                            tight_layout=True,
                            squeeze=False)

    for paramVals, dfTmp in dps.sweepDF.groupby(
            dps.invert_param(['mousename', xParamName])):
        plotSuffix = param_vals_to_suffix(paramVals)
        print(plotSuffix)

        figCorr, axCorr = _new_grid()
        figClust, axClust = _new_grid()
        figBrain, axBrain = _new_grid() if haveBrain else (None, None)
        figMono, axMono = _new_grid() if haveMono else (None, None)

        # Every active grid gets the same row labels and column titles
        axGrids = [axCorr, axClust]
        if haveBrain:
            axGrids.append(axBrain)
        if haveMono:
            axGrids.append(axMono)

        # Group by the column name (not a 1-element list): pandas >= 2.0
        # returns tuple keys for list groupers, which would break lookups.
        for mousename, dfMouse in dfTmp.groupby('mousename'):
            iMouse = dps.param_index('mousename', mousename)
            for axGrid in axGrids:
                axGrid[iMouse][0].set_ylabel(mousename, fontsize=fontsize)

            for idx, row in dfMouse.iterrows():
                xParamVal = row[xParamName]
                iXParam = dps.param_index(xParamName, xParamVal)
                for axGrid in axGrids:
                    axGrid[0][iXParam].set_title(xParamVal, fontsize=fontsize)

                kwargsThis = pd_row_to_kwargs(row,
                                              parseNone=True,
                                              dropKeys=['mousename'])
                channelLabels, rez2D = calc_corr_mouse(
                    dataDB,
                    mc,
                    mousename,
                    strategy=corrStrategy,
                    nDropPCA=nDropPCA,
                    dropChannels=dropChannels,
                    estimator=estimator,
                    **kwargsThis)

                # Only the right-most column carries a colour bar
                haveColorBar = iXParam == nXParam - 1

                # Plot correlations
                imshow(figCorr,
                       axCorr[iMouse][iXParam],
                       rez2D,
                       limits=[-1, 1],
                       cmap='jet',
                       haveColorBar=haveColorBar)

                # Plot clustering
                clusters = cluster_dist_matrix_max(rez2D,
                                                   clusterParam,
                                                   method='Affinity')
                cluster_plot(figClust,
                             axClust[iMouse][iXParam],
                             rez2D,
                             clusters,
                             channelLabels,
                             limits=[-1, 1],
                             cmap='jet',
                             haveColorBar=haveColorBar)

                if haveBrain:
                    cluster_brain_plot(figBrain,
                                       axBrain[iMouse][iXParam],
                                       dataDB,
                                       clusters,
                                       dropChannels=dropChannels)

                if haveMono:
                    _plot_corr_1D(figMono, axMono[iMouse][iXParam],
                                  channelLabels, rez2D, thrMono)

        # Save images: (figure, sub-directory, file-name prefix)
        prefixPrefixPath = 'pics/corr/mouse' + xParamName + '/dropPCA_' + str(
            nDropPCA) + '/'
        outputs = [(figCorr, 'corr/', 'corr_'), (figClust, 'clust/', 'clust_')]
        if haveBrain:
            outputs.append((figBrain, 'clust_brainplot/', 'clust_brainplot_'))
        if haveMono:
            outputs.append((figMono, '1D/', '1Dplot_'))

        for fig, subDir, filePrefix in outputs:
            prefixPath = prefixPrefixPath + subDir
            make_path(prefixPath)
            fig.savefig(prefixPath + filePrefix + plotSuffix + '.svg')
            plt.close(fig)
def movie_mouse_trialtype(dataDB,
                          dataKWArgs,
                          calcKWArgs,
                          plotKWArgs,
                          calc_func,
                          plot_func,
                          prefixPath='',
                          exclQueryLst=None,
                          haveDelay=False,
                          fontsize=20,
                          tTrgDelay=2.0,
                          tTrgRew=2.0):
    """Render movie frames showing every mouse x trial type, one PNG per step.

    For every combination of the sweep parameters other than ``mousename``
    and ``trialType``, ``calc_func`` is first evaluated for every
    (mouse, trial type) pair. Then, for each index ``iTime`` along the first
    axis of the results, a (mouse x trial type) grid is drawn with
    ``plot_func``. A timescale bar is added and the frame is saved as
    ``<prefixPath><plotSuffix>_<iTime>.png``. Frames whose file already
    exists are skipped.

    Parameters
    ----------
    dataDB : data access object
        Passed to ``DataParameterSweep``, ``calc_func`` and ``plot_func``. Must
        provide ``get_timestamps`` and ``targetFPS``.
    dataKWArgs : dict
        Sweep parameters for ``DataParameterSweep``. Must contain
        ``'trialType'`` and must not contain ``'intervName'``.
    calcKWArgs : passed positionally as the third argument of ``calc_func``.
    plotKWArgs : dict
        Extra keyword arguments for ``plot_func``.
    calc_func : callable
        Called as ``calc_func(dataDB, mousename, calcKWArgs, haveDelay=...,
        tTrgDelay=..., tTrgRew=..., **rowKwargs)``. It must return an
        array-like with ``.shape``, indexed by frame along its first axis.
        All results within one sweep group must share the same shape.
    plot_func : callable
        Called as ``plot_func(dataDB, fig, ax, frameData, haveColorBar=...,
        **plotKWArgs)``.
    prefixPath : str
        String prefix (not joined as a path) for the output file names.
    exclQueryLst : forwarded to ``DataParameterSweep``.
    haveDelay, tTrgRew : forwarded to ``calc_func``.
    fontsize : int
        Font size of the row labels and column titles.
    tTrgDelay : float
        Forwarded to ``calc_func``. Also used to place the ``'reward'`` mark
        on the timescale bar when the timestamps contain ``'delay'``.

    Returns
    -------
    str
        ``prefixPath``, unchanged.
    """
    assert 'trialType' in dataKWArgs.keys(), 'Requires trial types'
    assert 'intervName' not in dataKWArgs.keys(
    ), 'Movie intended for full range'
    dps = DataParameterSweep(dataDB,
                             exclQueryLst,
                             mousename='auto',
                             **dataKWArgs)
    nMice = dps.param_size('mousename')
    nTrialType = dps.param_size('trialType')
    mouseNames = list(dps.param('mousename'))
    trialTypes = list(dps.param('trialType'))

    # Output location does not depend on the frame; create it once
    make_path(prefixPath)

    for paramVals, dfTmp in dps.sweepDF.groupby(
            dps.invert_param(['mousename', 'trialType'])):
        plotSuffix = param_vals_to_suffix(paramVals)

        # Store all preprocessed data first, keyed by (mousename, trialType).
        # Group by the column name (not a 1-element list): pandas >= 2.0
        # returns tuple keys for list groupers, which would make these keys
        # differ from the scalar names used for lookup below.
        dataDict = {}
        for mousename, dfMouse in dfTmp.groupby('mousename'):
            for idx, row in dfMouse.iterrows():
                trialType = row['trialType']
                print('Reading data, ', plotSuffix, mousename, trialType)

                kwargsThis = pd_row_to_kwargs(row,
                                              parseNone=True,
                                              dropKeys=['mousename'])

                dataDict[(mousename,
                          trialType)] = calc_func(dataDB,
                                                  mousename,
                                                  calcKWArgs,
                                                  haveDelay=haveDelay,
                                                  tTrgDelay=tTrgDelay,
                                                  tTrgRew=tTrgRew,
                                                  **kwargsThis)

        # Test that all datasets have the same duration
        shapeSet = {v.shape for v in dataDict.values()}
        assert len(shapeSet) == 1, \
            'Expected one common data shape, got ' + str(shapeSet)
        nTimes = shapeSet.pop()[0]

        # Timescale bar marks are the same for every frame; compute them once.
        # NOTE(review): the timestamps come from the last mouse in the sweep
        # (as in the original per-frame code), which presumably means they are
        # shared by all mice - confirm.
        timestamps = dataDB.get_timestamps(mouseNames[-1], session=None)
        tsKeys = ['PRE'] + list(timestamps.keys())
        tsVals = list(timestamps.values())
        if 'delay' in timestamps.keys():
            tsKeys += ['reward']
            tsVals += [timestamps['delay'] + tTrgDelay]
        tsVals += [nTimes / dataDB.targetFPS]

        progBar = IntProgress(min=0, max=nTimes, description=plotSuffix)
        display(progBar)  # display the bar
        for iTime in range(nTimes):
            outfname = prefixPath + plotSuffix + '_' + str(iTime) + '.png'

            # Allow resuming an interrupted render
            if os.path.isfile(outfname):
                print('Already calculated', iTime, 'skipping')
                progBar.value += 1
                continue

            # squeeze=False keeps ax 2D for a single mouse or trial type
            fig, ax = plt.subplots(nrows=nMice,
                                   ncols=nTrialType,
                                   figsize=(4 * nTrialType, 4 * nMice),
                                   tight_layout=True,
                                   squeeze=False)

            for iMouse, mousename in enumerate(mouseNames):
                ax[iMouse][0].set_ylabel(mousename, fontsize=fontsize)
                for iTT, trialType in enumerate(trialTypes):
                    ax[0][iTT].set_title(trialType, fontsize=fontsize)

                    dataP = dataDict[(mousename, trialType)][iTime]

                    rightMost = iTT == nTrialType - 1
                    plot_func(dataDB,
                              fig,
                              ax[iMouse][iTT],
                              dataP,
                              haveColorBar=rightMost,
                              **plotKWArgs)

            # Add a timescale bar to the figure
            print(tsVals, iTime / dataDB.targetFPS)
            add_timescale_bar(fig, tsKeys, tsVals, iTime / dataDB.targetFPS)

            fig.savefig(outfname, bbox_inches='tight')
            # Close every open figure so a long render does not accumulate
            # figures, including any that helpers may have opened.
            plt.close('all')
            progBar.value += 1
    return prefixPath