Example #1
def main():

    #################################################
    ## SETUP

    ## Get list of subject files
    subj_files = listdir(DAT_PATH)
    subj_files = [file for file in subj_files if EXT.lower() in file.lower()]

    ## Set up FOOOF Objects
    # Initialize FOOOF settings & model objects
    fooof_settings = FOOOFSettings(peak_width_limits=PEAK_WIDTH_LIMITS, max_n_peaks=MAX_N_PEAKS,
                                   min_peak_amplitude=MIN_PEAK_AMP, peak_threshold=PEAK_THRESHOLD,
                                   aperiodic_mode=APERIODIC_MODE)
    fm = FOOOF(*fooof_settings, verbose=False)
    fg = FOOOFGroup(*fooof_settings, verbose=False)
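    # (fm is used below for single-spectrum fits, fg for groups of spectra)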

    # Save out a settings file
    fg.save('0-FOOOF_Settings', pjoin(RES_PATH, 'FOOOF'), save_settings=True)

    # Set up the dictionary to store all the FOOOF results
    fg_dict = dict()
    for load_label in LOAD_LABELS:
        fg_dict[load_label] = dict()
        for side_label in SIDE_LABELS:
            fg_dict[load_label][side_label] = dict()
            for seg_label in SEG_LABELS:
                fg_dict[load_label][side_label][seg_label] = []

    ## Initialize group level data stores
    n_subjs, n_conds, n_times = len(subj_files), 3, N_TIMES
    group_fooofed_alpha_freqs = np.zeros(shape=[n_subjs])
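    # Note: 999 serves as a sentinel fill value for unused slots below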
    dropped_components = np.ones(shape=[n_subjs, 50]) * 999
    dropped_trials = np.ones(shape=[n_subjs, 1500]) * 999
    canonical_group_avg_dat = np.zeros(shape=[n_subjs, n_conds, n_times])
    fooofed_group_avg_dat = np.zeros(shape=[n_subjs, n_conds, n_times])

    # Set channel types
    ch_types = {'LHor' : 'eog', 'RHor' : 'eog', 'IVer' : 'eog', 'SVer' : 'eog',
                'LMas' : 'misc', 'RMas' : 'misc', 'Nose' : 'misc', 'EXG8' : 'misc'}

    #################################################
    ## RUN ACROSS ALL SUBJECTS

    # Run analysis across each subject
    for s_ind, subj_file in enumerate(subj_files):

        # Get subject label and print status
        subj_label = subj_file.split('.')[0]
        print('\nCURRENTLY RUNNING SUBJECT: ', subj_label, '\n')

        #################################################
        ## LOAD / ORGANIZE / SET-UP DATA

        # Load subject data, apply fixes for channels, etc.
        eeg_dat = mne.io.read_raw_edf(pjoin(DAT_PATH, subj_file),
                                      preload=True, verbose=False)

        # Fix channel name labels
        eeg_dat.info['ch_names'] = [chl[2:] for chl in \
            eeg_dat.ch_names[:-1]] + [eeg_dat.ch_names[-1]]
        for ind, chi in enumerate(eeg_dat.info['chs']):
            eeg_dat.info['chs'][ind]['ch_name'] = eeg_dat.info['ch_names'][ind]

        # Update channel types
        eeg_dat.set_channel_types(ch_types)

        # Set reference - average reference
        eeg_dat = eeg_dat.set_eeg_reference(ref_channels='average',
                                            projection=False, verbose=False)

        # Set channel montage
        chs = mne.channels.read_montage('standard_1020', eeg_dat.ch_names)
        eeg_dat.set_montage(chs)

        # Get event information & check all used event codes
        evs = mne.find_events(eeg_dat, shortest_event=1, verbose=False)

        # Pull out sampling rate
        srate = eeg_dat.info['sfreq']

        #################################################
        ## Pre-Processing: ICA

        # High-pass filter data for running ICA
        eeg_dat.filter(l_freq=1., h_freq=None, fir_design='firwin')

        if RUN_ICA:

            print("\nICA: CALCULATING SOLUTION\n")

            # ICA settings
            method = 'fastica'
            n_components = 0.99
            random_state = 47
            reject = {'eeg': 20e-4}
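            #  (a float n_components selects components by explained variance;
            #   the reject threshold is specified in volts)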

            # Initialize ICA object
            ica = ICA(n_components=n_components, method=method,
                      random_state=random_state)

            # Fit ICA
            ica.fit(eeg_dat, reject=reject)

            # Save out ICA solution
            ica.save(pjoin(RES_PATH, 'ICA', subj_label + '-ica.fif'))

        # Otherwise: load previously saved ICA to apply
        else:
            print("\nICA: USING PRECOMPUTED\n")
            ica = read_ica(pjoin(RES_PATH, 'ICA', subj_label + '-ica.fif'))

        # Find components to drop, based on correlation with EOG channels
        drop_inds = []
        for chi in EOG_CHS:
            inds, _ = ica.find_bads_eog(eeg_dat, ch_name=chi, threshold=2.5,
                                             l_freq=1, h_freq=10, verbose=False)
            drop_inds.extend(inds)
        drop_inds = list(set(drop_inds))

        # Set which components to drop, and collect record of this
        ica.exclude = drop_inds
        dropped_components[s_ind, 0:len(drop_inds)] = drop_inds

        # Apply ICA to data
        eeg_dat = ica.apply(eeg_dat)

        #################################################
        ## SORT OUT EVENT CODES

        # Extract a list of all the event labels
        all_trials = [it for it2 in EV_DICT.values() for it in it2]

        # Create list of new event codes to be used to label correct trials (300s)
        all_trials_new = [it + 100 for it in all_trials]
        # This is an annoying way to collapse across the doubled event markers from above
        all_trials_new = [it - 1 if not ind%2 == 0 else it for ind, it in enumerate(all_trials_new)]
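        #  e.g. codes [201, 202, 203, 204] become [301, 302, 303, 304], which
        #  then collapse pairwise to [301, 301, 303, 303] (illustrative values)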
        # Get labelled dictionary of new event names
        #  (dict.fromkeys de-duplicates the new codes while preserving order)
        ev_dict2 = {k:v for k, v in zip(EV_DICT.keys(), dict.fromkeys(all_trials_new))}

        # Initialize variables to store new event definitions
        evs2 = np.empty(shape=[0, 3], dtype='int64')
        lags = np.array([])

        # Loop through, creating new events for all correct trials
        t_min, t_max = -0.4, 3.0
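        #  (the window, in seconds, within which a target event must follow
        #   each reference event to define a new event)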
        for ref_id, targ_id, new_id in zip(all_trials, CORR_CODES * 6, all_trials_new):

            t_evs, t_lags = mne.event.define_target_events(evs, ref_id, targ_id, srate,
                                                           t_min, t_max, new_id)

            if len(t_evs) > 0:
                evs2 = np.vstack([evs2, t_evs])
                lags = np.concatenate([lags, t_lags])

        #################################################
        ## FOOOF

        # Set channel of interest
        ch_ind = eeg_dat.ch_names.index(CHL)

        # Calculate PSDs over ~ first 2 minutes of data, across all channels
        fmin, fmax = 1, 50
        tmin, tmax = 5, 125
        psds, freqs = mne.time_frequency.psd_welch(eeg_dat, fmin=fmin, fmax=fmax,
                                                   tmin=tmin, tmax=tmax,
                                                   n_fft=int(2*srate), n_overlap=int(srate),
                                                   n_per_seg=int(2*srate),
                                                   verbose=False)

        # Fit FOOOF across all channels
        fg.fit(freqs, psds, FREQ_RANGE, n_jobs=-1)

        # Save out FOOOF results
        fg.save(subj_label + '_fooof', pjoin(RES_PATH, 'FOOOF'), save_results=True)

        # Extract individualized CF from specified channel, add to group collection
        fm = fg.get_fooof(ch_ind, False)
        fooof_freq, _, _ = get_band_peak(fm.peak_params_, [7, 14])
        group_fooofed_alpha_freqs[s_ind] = fooof_freq

        # If no FOOOF alpha peak was extracted, reset the frequency to 10
        if np.isnan(fooof_freq):
            fooof_freq = 10

        #################################################
        ## ALPHA FILTERING

        # CANONICAL: Filter data to canonical alpha band: 8-12 Hz
        alpha_dat = eeg_dat.copy()
        alpha_dat.filter(8, 12, fir_design='firwin', verbose=False)
        alpha_dat.apply_hilbert(envelope=True, verbose=False)

        # FOOOF: Filter data to FOOOF derived alpha band
        fooof_dat = eeg_dat.copy()
        fooof_dat.filter(fooof_freq-2, fooof_freq+2, fir_design='firwin')
        fooof_dat.apply_hilbert(envelope=True)

        #################################################
        ## EPOCH TRIALS

        # Set epoch timings
        tmin, tmax = -0.85, 1.1

        # Epoch trials - raw data for trial rejection
        epochs = mne.Epochs(eeg_dat, evs2, ev_dict2, tmin=tmin, tmax=tmax,
                            baseline=None, preload=True, verbose=False)

        # Epoch trials - filtered version
        epochs_alpha = mne.Epochs(alpha_dat, evs2, ev_dict2, tmin=tmin, tmax=tmax,
                                  baseline=(-0.5, -0.35), preload=True, verbose=False)
        epochs_fooof = mne.Epochs(fooof_dat, evs2, ev_dict2, tmin=tmin, tmax=tmax,
                                  baseline=(-0.5, -0.35), preload=True, verbose=False)

        #################################################
        ## PRE-PROCESSING: AUTO-REJECT
        if RUN_AUTOREJECT:

            print('\nAUTOREJECT: CALCULATING SOLUTION\n')

            # Initialize and run autoreject across epochs
            ar = AutoReject(n_jobs=4, verbose=False)
            ar.fit(epochs)

            # Save out AR solution
            ar.save(pjoin(RES_PATH, 'AR', subj_label + '-ar.hdf5'), overwrite=True)

        # Otherwise: load & apply previously saved AR solution
        else:
            print('\nAUTOREJECT: USING PRECOMPUTED\n')
            ar = read_auto_reject(pjoin(RES_PATH, 'AR', subj_label + '-ar.hdf5'))
            ar.verbose = 'tqdm'

        # Apply autoreject to the original epochs object it was learnt on
        epochs, rej_log = ar.transform(epochs, return_log=True)

        # Apply autoreject to the copies of the data - apply interpolation, then drop same epochs
        _apply_interp(rej_log, epochs_alpha, ar.threshes_, ar.picks_, ar.verbose)
        epochs_alpha.drop(rej_log.bad_epochs)
        _apply_interp(rej_log, epochs_fooof, ar.threshes_, ar.picks_, ar.verbose)
        epochs_fooof.drop(rej_log.bad_epochs)

        # Collect which epochs were dropped
        dropped_trials[s_ind, 0:sum(rej_log.bad_epochs)] = np.where(rej_log.bad_epochs)[0]

        #################################################
        ## SET UP CHANNEL CLUSTERS

        # Set channel clusters - take channels contralateral to stimulus presentation
        #  Note: channels will be used to extract data contralateral to stimulus presentation
        le_chs = ['P3', 'P5', 'P7', 'P9', 'O1', 'PO3', 'PO7']       # Left Side Channels
        le_inds = [epochs.ch_names.index(chn) for chn in le_chs]
        ri_chs = ['P4', 'P6', 'P8', 'P10', 'O2', 'PO4', 'PO8']      # Right Side Channels
        ri_inds = [epochs.ch_names.index(chn) for chn in ri_chs]

        #################################################
        ## TRIAL-RELATED ANALYSIS: CANONICAL vs. FOOOF

        ## Pull out channels of interest for each load level
        #  Channels extracted are those contralateral to stimulus presentation

        # Canonical Data
        lo1_a = np.concatenate([epochs_alpha['LeLo1']._data[:, ri_inds, :],
                                epochs_alpha['RiLo1']._data[:, le_inds, :]], 0)
        lo2_a = np.concatenate([epochs_alpha['LeLo2']._data[:, ri_inds, :],
                                epochs_alpha['RiLo2']._data[:, le_inds, :]], 0)
        lo3_a = np.concatenate([epochs_alpha['LeLo3']._data[:, ri_inds, :],
                                epochs_alpha['RiLo3']._data[:, le_inds, :]], 0)

        # FOOOFed data
        lo1_f = np.concatenate([epochs_fooof['LeLo1']._data[:, ri_inds, :],
                                epochs_fooof['RiLo1']._data[:, le_inds, :]], 0)
        lo2_f = np.concatenate([epochs_fooof['LeLo2']._data[:, ri_inds, :],
                                epochs_fooof['RiLo2']._data[:, le_inds, :]], 0)
        lo3_f = np.concatenate([epochs_fooof['LeLo3']._data[:, ri_inds, :],
                                epochs_fooof['RiLo3']._data[:, le_inds, :]], 0)

        ## Calculate average across trials and channels - add to group data collection

        # Canonical data
        canonical_group_avg_dat[s_ind, 0, :] = np.mean(lo1_a, 1).mean(0)
        canonical_group_avg_dat[s_ind, 1, :] = np.mean(lo2_a, 1).mean(0)
        canonical_group_avg_dat[s_ind, 2, :] = np.mean(lo3_a, 1).mean(0)

        # FOOOFed data
        fooofed_group_avg_dat[s_ind, 0, :] = np.mean(lo1_f, 1).mean(0)
        fooofed_group_avg_dat[s_ind, 1, :] = np.mean(lo2_f, 1).mean(0)
        fooofed_group_avg_dat[s_ind, 2, :] = np.mean(lo3_f, 1).mean(0)

        #################################################
        ## FOOOFING TRIAL AVERAGED DATA

        # Loop across trial segments & load levels
        for seg_label, seg_time in zip(SEG_LABELS, SEG_TIMES):
            tmin, tmax = seg_time[0], seg_time[1]

            # Calculate PSDs across trials, fit FOOOF models to averages
            for le_label, ri_label, load_label in zip(['LeLo1', 'LeLo2', 'LeLo3'],
                                                      ['RiLo1', 'RiLo2', 'RiLo3'],
                                                      LOAD_LABELS):

                ## Calculate trial wise PSDs for left & right side trials
                trial_freqs, le_trial_psds = periodogram(
                    epochs[le_label]._data[:, :, _time_mask(epochs.times, tmin, tmax, srate)],
                    srate, window='hann', nfft=int(4*srate))
                trial_freqs, ri_trial_psds = periodogram(
                    epochs[ri_label]._data[:, :, _time_mask(epochs.times, tmin, tmax, srate)],
                    srate, window='hann', nfft=int(4*srate))

                ## FIT ALL CHANNELS VERSION
                if FIT_ALL_CHANNELS:

                    ## Average spectra across trials within a given load & side
                    le_avg_psd_contra = avg_func(le_trial_psds[:, ri_inds, :], 0)
                    le_avg_psd_ipsi = avg_func(le_trial_psds[:, le_inds, :], 0)
                    ri_avg_psd_contra = avg_func(ri_trial_psds[:, le_inds, :], 0)
                    ri_avg_psd_ipsi = avg_func(ri_trial_psds[:, ri_inds, :], 0)

                    ## Combine spectra across left & right trials for given load
                    ch_psd_contra = np.vstack([le_avg_psd_contra, ri_avg_psd_contra])
                    ch_psd_ipsi = np.vstack([le_avg_psd_ipsi, ri_avg_psd_ipsi])

                    ## Fit FOOOFGroup to all channels, then average & collect results
                    fg.fit(trial_freqs, ch_psd_contra, FREQ_RANGE)
                    fm = avg_fg(fg)
                    fg_dict[load_label]['Contra'][seg_label].append(fm.copy())
                    fg.fit(trial_freqs, ch_psd_ipsi, FREQ_RANGE)
                    fm = avg_fg(fg)
                    fg_dict[load_label]['Ipsi'][seg_label].append(fm.copy())

                ## COLLAPSE ACROSS CHANNELS VERSION
                else:

                    ## Average spectra across trials and channels within a given load & side
                    le_avg_psd_contra = avg_func(avg_func(le_trial_psds[:, ri_inds, :], 0), 0)
                    le_avg_psd_ipsi = avg_func(avg_func(le_trial_psds[:, le_inds, :], 0), 0)
                    ri_avg_psd_contra = avg_func(avg_func(ri_trial_psds[:, le_inds, :], 0), 0)
                    ri_avg_psd_ipsi = avg_func(avg_func(ri_trial_psds[:, ri_inds, :], 0), 0)

                    ## Collapse spectra across left & right trials for given load
                    avg_psd_contra = avg_func(np.vstack([le_avg_psd_contra, ri_avg_psd_contra]), 0)
                    avg_psd_ipsi = avg_func(np.vstack([le_avg_psd_ipsi, ri_avg_psd_ipsi]), 0)

                    ## Fit FOOOF, and collect results
                    fm.fit(trial_freqs, avg_psd_contra, FREQ_RANGE)
                    fg_dict[load_label]['Contra'][seg_label].append(fm.copy())
                    fm.fit(trial_freqs, avg_psd_ipsi, FREQ_RANGE)
                    fg_dict[load_label]['Ipsi'][seg_label].append(fm.copy())

    #################################################
    ## SAVE OUT RESULTS

    # Save out group data
    np.save(pjoin(RES_PATH, 'Group', 'alpha_freqs_group'), group_fooofed_alpha_freqs)
    np.save(pjoin(RES_PATH, 'Group', 'canonical_group'), canonical_group_avg_dat)
    np.save(pjoin(RES_PATH, 'Group', 'fooofed_group'), fooofed_group_avg_dat)
    np.save(pjoin(RES_PATH, 'Group', 'dropped_trials'), dropped_trials)
    np.save(pjoin(RES_PATH, 'Group', 'dropped_components'), dropped_components)

    # Save out second round of FOOOFing
    for load_label in LOAD_LABELS:
        for side_label in SIDE_LABELS:
            for seg_label in SEG_LABELS:
                fg = combine_fooofs(fg_dict[load_label][side_label][seg_label])
                fg.save('Group_' + load_label + '_' + side_label + '_' + seg_label,
                        pjoin(RES_PATH, 'FOOOF'), save_results=True)
Example #2
print(fgs)

###################################################################################################

# Compare the aperiodic exponent results across conditions
for ind, fg in enumerate(fgs):
    print("Aperiodic exponent for condition {} is {:1.4f}".format(
        ind, np.mean(fg.get_params('aperiodic_params', 'exponent'))))

###################################################################################################
# combine_fooofs
# --------------
#
# Depending on how the data is organized, you might also want to collapse
# FOOOF models across dimensions that have been fit.
#
# To do so, you can use the :func:`combine_fooofs` function, which takes
# a list of FOOOF or FOOOFGroup objects and combines them into a single
# FOOOFGroup object (assuming their settings and data definitions are
# consistent).
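
###################################################################################################

# As a minimal, self-contained sketch (assuming fooof >= 1.0, where simulated
# spectra are available from ``fooof.sim.gen``), individual FOOOF objects can
# be combined in the same way:

from fooof import FOOOF
from fooof.sim.gen import gen_power_spectrum

# Simulate a power spectrum, and fit two FOOOF models with consistent settings
freqs, powers = gen_power_spectrum([3, 40], [0, 1], [10, 0.3, 1])
fm1, fm2 = FOOOF(verbose=False), FOOOF(verbose=False)
fm1.fit(freqs, powers)
fm2.fit(freqs, powers)

# Combining returns a single FOOOFGroup holding both sets of results
fm_group = combine_fooofs([fm1, fm2])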

###################################################################################################

# Here, we combine the list of FOOOFGroup objects into a single FOOOFGroup object
all_fg = combine_fooofs(fgs)

# Explore the results from across all FOOOF fits
all_fg.print_results()
all_fg.plot()
Example #3
def main(argv):
    # defining basepaths
    basepath = '/Users/rdgao/Documents/data/CRCNS/fcx1/'
    rec_dirs = [f for f in np.sort(os.listdir(basepath)) if os.path.isdir(basepath+f)]
    result_basepath = '/Users/rdgao/Documents/code/research/field-echos/results/fcx1/wakesleep/'

    if 'do_psds' in argv:
        print('Computing PSDs...')
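        # note: the [21:] slice below skips the first 21 recordings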

        for cur_rec in range(len(rec_dirs))[21:]:
            print(rec_dirs[cur_rec])
            # compute PSDs
            psd_path = result_basepath + rec_dirs[cur_rec] + '/psd/'

            # load data
            ephys_data = io.loadmat(basepath+rec_dirs[cur_rec]+'/'+rec_dirs[cur_rec]+'_ephys.mat', squeeze_me=True)
            behav_data = pd.read_csv(basepath+rec_dirs[cur_rec]+'/'+rec_dirs[cur_rec]+'_wakesleep.csv', index_col=0)

            # get some params
            nchan,nsamp = ephys_data['lfp'].shape
            fs = ephys_data['fs']
            ephys_data['t_lfp'] = np.arange(0,nsamp)/fs
            elec_region = np.unique(ephys_data['elec_regions'])[0]

            # get subset of behavior that marks wake and sleep
            behav_sub = behav_data[behav_data['Label'].isin(['Wake', 'Sleep'])]

            # name, nperseg, noverlap, f_range, outlier_pct
            p_configs = [['1sec', int(fs), int(fs/2), [0., 200.], 5],
                            ['5sec', int(fs*5), int(fs*4), [0., 200.], 5]]

            for p_cfg in p_configs:
                # parameter def
                print(p_cfg)
                saveout_path = psd_path+ p_cfg[0]
                nperseg, noverlap, f_range, outlier_pct = p_cfg[1:]

                psd_mean, psd_med = [], []
                for ind, cur_eps in behav_sub.iterrows():
                    # find indices of LFP that correspond to behavior
                    lfp_inds = np.where(np.logical_and(ephys_data['t_lfp']>=cur_eps['Start'],ephys_data['t_lfp']<cur_eps['End']))[0]

                    # compute mean and median welchPSD
                    p_squished = spectral.compute_spectrum(ephys_data['lfp'][:,lfp_inds], ephys_data['fs'], method='welch',avg_type='mean', nperseg=nperseg, noverlap=noverlap, f_range=f_range, outlier_pct=outlier_pct)
                    f_axis, cur_psd_mean = p_squished[0,:], p_squished[1::2,:] # work-around for ndsp currently squishing together the outputs
                    p_squished = spectral.compute_spectrum(ephys_data['lfp'][:,lfp_inds], ephys_data['fs'], method='welch',avg_type='median', nperseg=nperseg, noverlap=noverlap, f_range=f_range, outlier_pct=outlier_pct)
                    f_axis, cur_psd_med = p_squished[0,:], p_squished[1::2,:]

                    # append to list
                    psd_mean.append(cur_psd_mean)
                    psd_med.append(cur_psd_med)

                # collect, stack, and save out
                psd_mean, psd_med, behav_info = np.array(psd_mean), np.array(psd_med), np.array(behav_sub)
                save_dict = {}
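                # gather each named local variable into the save dictionary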
                for name in ['psd_mean', 'psd_med','nperseg','noverlap','fs','outlier_pct', 'behav_info', 'elec_region', 'f_axis']:
                    save_dict[name] = eval(name)
                utils.makedir(saveout_path, timestamp=False)
                np.savez(file=saveout_path+'/psds.npz', **save_dict)

    if 'do_fooof' in argv:
        fooof_settings = [['knee', 4, (0.1,200)],
                            ['fixed', 4, (0.1,200)],
                            ['fixed', 2, (0.1,10)],
                            ['fixed', 2, (30,55)]]
        for cur_rec in range(len(rec_dirs)):
            print(rec_dirs[cur_rec])
            psd_path = result_basepath + rec_dirs[cur_rec] + '/psd/'
            for psd_win in ['1sec/', '5sec/']:
                psd_folder = psd_path+psd_win
                psd_data = np.load(psd_folder+'psds.npz')
                for psd_mode in ['psd_mean', 'psd_med']:
                    for f_s in fooof_settings:
                        fg = FOOOFGroup(aperiodic_mode=f_s[0], max_n_peaks=f_s[1])
                        fgs = fit_fooof_group_3d(fg, psd_data['f_axis'], psd_data[psd_mode], freq_range=f_s[2])
                        fg_all = combine_fooofs(fgs)
                        fooof_savepath = utils.makedir(psd_folder, '/fooof/'+psd_mode+'/', timestamp=False)
                        fg_all.save('fg_%s_%ipks_%i-%iHz'%(f_s[0],f_s[1],f_s[2][0],f_s[2][1]), fooof_savepath, save_results=True, save_settings=True)
Example #4
def main(argv):
    # defining basepaths
    basepath = '/Users/rdgao/Documents/data/CRCNS/fcx1/'
    rec_dirs = [
        f for f in np.sort(os.listdir(basepath)) if os.path.isdir(basepath + f)
    ]
    result_basepath = '/Users/rdgao/Documents/code/research/field-echos/results/fcx1/wakesleep/'

    if 'do_psds' in argv:
        print('Computing PSDs...')

        for cur_rec in range(len(rec_dirs)):
            print(rec_dirs[cur_rec])
            # compute PSDs
            psd_path = result_basepath + rec_dirs[cur_rec] + '/psd_spikes/'

            # load data
            ephys_data = io.loadmat(basepath + rec_dirs[cur_rec] + '/' +
                                    rec_dirs[cur_rec] + '_ephys.mat',
                                    squeeze_me=True)
            behav_data = pd.read_csv(basepath + rec_dirs[cur_rec] + '/' +
                                     rec_dirs[cur_rec] + '_wakesleep.csv',
                                     index_col=0)
            elec_region = np.unique(ephys_data['elec_regions'])[0]
            elec_shank_map = ephys_data['elec_shank_map']

            # some reorganization of the spike metadata file
            # NOTE that all this had to be done because I was an idiot and
            # organized the spikeinfo table and spikes in some dumb way
            # make spike info into df and access based on cell, and add end time
            df_spkinfo = pd.DataFrame(ephys_data['spike_info'],
                                      columns=ephys_data['spike_info_cols'])
            df_spkinfo.insert(
                len(df_spkinfo.columns) - 1, 'spike_start_ind',
                np.concatenate(
                    ([0], df_spkinfo['num_spikes_cumsum'].iloc[:-1].values)))
            df_spkinfo.rename(columns={'num_spikes_cumsum': 'spike_end_ind'},
                              inplace=True)
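            #  e.g. if num_spikes_cumsum is [10, 25, 40], spike_start_ind is
            #  [0, 10, 25] and spike_end_ind (post-rename) is [10, 25, 40]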

            # this is now a list of N arrays, where N is the number of cells
            #    now we can aggregate arbitrarily based on cell index
            spikes_list = utils.spikes_as_list(ephys_data['spiketrain'],
                                               df_spkinfo)

            # pooling across populations from the same shanks
            df_spkinfo_pooled = df_spkinfo.copy()
            for g_i, g in df_spkinfo.groupby(['shank', 'cell_EI_type']):
                # super python magic that collapses all the spikes of the same pop on the same shank into one array
                spikes_list.append(
                    np.sort(
                        np.hstack(
                            [spikes_list[c_i] for c_i, cell in g.iterrows()])))
                # update spike info dataframe
                df_pop = pd.DataFrame({
                    'shank': g['shank'].head(1),
                    'cell_EI_type': g['cell_EI_type'].head(1),
                    'num_spikes': g['num_spikes'].sum(),
                    'cell_id': 0
                })
                df_spkinfo_pooled = df_spkinfo_pooled.append(df_pop,
                                                             ignore_index=True)

            # pooling across entire recording
            for g_i, g in df_spkinfo.groupby(['cell_EI_type']):
                spikes_list.append(
                    np.sort(
                        np.hstack(
                            [spikes_list[c_i] for c_i, cell in g.iterrows()])))
                df_pop = pd.DataFrame({
                    'shank': 0,
                    'cell_id': 0,
                    'cell_EI_type': g['cell_EI_type'].head(1),
                    'num_spikes': g['num_spikes'].sum()
                })
                df_spkinfo_pooled = df_spkinfo_pooled.append(df_pop,
                                                             ignore_index=True)

            # save spikeinfo table to recording folder
            utils.makedir(psd_path, timestamp=False)
            df_spkinfo_pooled.to_csv(psd_path + '/spike_info.csv')

            ##### ------------- #####
            # compute PSDs across conditions and populations
            # individual cells
            dt = 0.005
            fs = 1 / dt
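            # (dt = 5 ms bins give fs = 200 Hz for the binned spike trains)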

            # name, nperseg, noverlap
            p_configs = [['2sec', int(2 * fs),
                          int(2 * fs * 4 / 5)],
                         ['5sec', int(5 * fs),
                          int(5 * fs * 4 / 5)]]

            behav_sub = behav_data[behav_data['Label'].isin(['Wake', 'Sleep'
                                                             ])].reset_index()
            behav_info = np.array(behav_sub)
            num_block, num_cell = len(behav_sub), len(spikes_list)
            for p_cfg in p_configs:
                print(p_cfg)
                saveout_path = psd_path + p_cfg[0]
                nperseg, noverlap = p_cfg[1:]

                psd_mean = np.zeros(
                    (num_block, num_cell, int(p_cfg[1] / 2 + 1)))
                psd_med = np.zeros(
                    (num_block, num_cell, int(p_cfg[1] / 2 + 1)))
                for cell, spikes in enumerate(spikes_list):
                    print(cell, end='|')
                    for block, cur_eps in behav_sub.iterrows():
                        spikes_eps = spikes[np.logical_and(
                            spikes >= cur_eps['Start'],
                            spikes < cur_eps['End'])]
                        t_spk, spikes_binned = utils.bin_spiketrain(
                            spikes_eps, dt, cur_eps[['Start', 'End']])
                        f_axis, psd_mean[block,
                                         cell, :] = spectral.compute_spectrum(
                                             spikes_binned,
                                             fs,
                                             method='welch',
                                             avg_type='mean',
                                             nperseg=nperseg,
                                             noverlap=noverlap)
                        f_axis, psd_med[block,
                                        cell, :] = spectral.compute_spectrum(
                                            spikes_binned,
                                            fs,
                                            method='welch',
                                            avg_type='median',
                                            nperseg=nperseg,
                                            noverlap=noverlap)

                # save PSDs and spike_info dataframe
                save_dict = {}
                for name in [
                        'psd_mean', 'psd_med', 'nperseg', 'noverlap', 'fs',
                        'behav_info', 'elec_region', 'elec_shank_map', 'f_axis'
                ]:
                    save_dict[name] = eval(name)
                utils.makedir(saveout_path, timestamp=False)
                np.savez(file=saveout_path + '/psds.npz', **save_dict)

    if 'do_fooof' in argv:
        fooof_settings = [['fixed', 2, (.5, 80)], ['fixed', 1, (.5, 5)],
                          ['fixed', 1, (10, 20)], ['fixed', 1, (30, 80)]]

        for cur_rec in range(len(rec_dirs)):
            print(rec_dirs[cur_rec])
            psd_path = result_basepath + rec_dirs[cur_rec] + '/psd_spikes/'
            df_spkinfo_pooled = pd.read_csv(psd_path + '/spike_info.csv',
                                            index_col=0)

            # grab only the aggregate cells
            df_pops = df_spkinfo_pooled[df_spkinfo_pooled['cell_id'] == 0]
            df_pops.to_csv(psd_path + '/pop_spike_info.csv')

            for psd_win in ['2sec/', '5sec/']:
                psd_folder = psd_path + psd_win
                psd_data = np.load(psd_folder + 'psds.npz')
                for psd_mode in ['psd_mean']:
                    psd_spikes = psd_data[psd_mode][:, df_pops.index.values, :]
                    if np.any(np.isinf(np.log(psd_spikes[:, :, 0]))):
                        # if any PSDs are all zeros, set them to ones
                        # (so their log-power is finite)
                        print('Null PSDs found.')
                        zero_inds = np.where(
                            np.isinf(np.log(psd_spikes[:, :, 0])))
                        psd_spikes[zero_inds] = 1.

                    fg_all = []
                    for f_s in fooof_settings:
                        fg = FOOOFGroup(aperiodic_mode=f_s[0],
                                        max_n_peaks=f_s[1],
                                        peak_width_limits=(5, 20))
                        fgs = fit_fooof_group_3d(fg,
                                                 psd_data['f_axis'],
                                                 psd_spikes,
                                                 freq_range=f_s[2])
                        fg_all = combine_fooofs(fgs)
                        fooof_savepath = utils.makedir(psd_folder,
                                                       '/fooof/' + psd_mode +
                                                       '/',
                                                       timestamp=False)
                        fg_all.save('fg_%s_%ipks_%i-%iHz' %
                                    (f_s[0], f_s[1], f_s[2][0], f_s[2][1]),
                                    fooof_savepath,
                                    save_results=True,
                                    save_settings=True)

    print('Done.')