예제 #1
0
def load_with_parms(site, **kwargs):
    """Load a CPN recording for *site* and split it into sub-recordings.

    A fixed set of default loading options is used; any keyword argument
    overrides (or adds to) those defaults before they are forwarded to
    ``get_recording``.

    Returns:
        tuple: ``(recordings, parameters)`` where ``recordings`` is the result
        of ``split_recording`` on the sub-epoch-annotated recording and
        ``parameters`` is ``manager.get_baphy_exptparams()``.
    """
    defaults = dict(batch=316,
                    cellid=site,
                    stimfmt='envelope',
                    rasterfs=100,
                    runclass='CPN',
                    stim=False,
                    resp=True)
    # Caller-supplied options win over the defaults.
    options = {**defaults, **kwargs}

    experiment = BAPHYExperiment(siteid=site, batch=options['batch'])
    raw_recording = experiment.get_recording(recache=True, **options)
    exptparams = experiment.get_baphy_exptparams()

    annotated = cpe.set_recording_subepochs(raw_recording)
    return split_recording(annotated), exptparams
예제 #2
0
def get_pairs(parm, rasterfs=100):
    """Print the BG/FG sound pairs of an experiment and return index maxima.

    Parameters
    ----------
    parm : parmfile path(s) accepted by ``BAPHYExperiment``.
    rasterfs : int
        Raster sampling rate used when loading the recording.

    Returns
    -------
    tuple
        ``(len(pairs) - 1, len(resp.chans) - 1)`` -- one less than the pair
        count and one less than the channel count, i.e. the largest valid
        0-based indices.
    """
    expt = BAPHYExperiment(parm)
    rec = expt.get_recording(rasterfs=rasterfs, resp=True, stim=False)
    resp = rec['resp'].rasterize()
    expt_params = expt.get_baphy_exptparams()  # Using Charlie's manager
    ref_handle = expt_params[0]['TrialObject'][1]['ReferenceHandle'][1]

    # Each pair is (BG name, FG name) with the file extension stripped.
    pairs = [(sp['bg_sound_name'].split('.')[0], sp['fg_sound_name'].split('.')[0])
             for sp in ref_handle['SoundPairs'].values()]

    for c, p in enumerate(pairs):
        print(f"{c} - {p}")
    n_chans = len(resp.chans)
    # NOTE(review): the unit count printed is len(chans) - 1 while the pair
    # count is the full length -- confirm which of the two is intended.
    print(f"There are {n_chans - 1} units and {len(pairs)} sound pairs.")
    print("Returning one value less than channel and pair count.")

    return len(pairs) - 1, n_chans - 1
예제 #3
0
# Per-date result arrays (one slot per entry of uDate), NaN until filled.
# Presumably valid-trial counts for the early (EAR) / late (LAT) windows,
# filled further down this loop -- TODO confirm.
VALID_TRIALS_EAR = np.full(len(uDate), np.nan)
VALID_TRIALS_LAT = np.full(len(uDate), np.nan)
for idx, ud in enumerate(uDate):
    print(f"Loading data from {ud}")
    # All parmfiles recorded on this date.
    parmfiles = d[d.date == ud].parmfile_path.values.tolist()
    # add catch to make sure "siteid" the same for all files
    # (siteid = first 7 characters of the parmfile's file name; files whose
    # siteid differs from the first file's are dropped)
    sid = [p.split(os.path.sep)[-1][:7] for p in parmfiles]
    if np.any(np.array(sid) != sid[0]):
        bad_idx = (np.array(sid) != sid[0])
        parmfiles = np.array(parmfiles)[~bad_idx].tolist()
    manager = BAPHYExperiment(parmfiles)

    # make sure only loaded actives
    # (keep only parmfiles whose BehaveObjectClass is 'RewardTargetLBHB'; the
    # mask is assumed to line up one-to-one with manager.parmfile -- TODO confirm)
    pf_mask = [
        True if k['BehaveObjectClass'] == 'RewardTargetLBHB' else False
        for k in manager.get_baphy_exptparams()
    ]
    if sum(pf_mask) == len(manager.parmfile):
        pass
    else:
        # Rebuild the manager from the active-only subset.
        parmfiles = np.array(manager.parmfile)[pf_mask].tolist()
        manager = BAPHYExperiment(parmfiles)

    # get reaction times of targets, only for "correct" trials
    # NOTE(review): the filter below only drops rows flagged invalidTrial;
    # whether that equals "correct" trials is not visible here.
    bev = manager.get_behavior_events(**options)
    bev = manager._stack_events(bev)
    bev = bev[bev.invalidTrial == False]
    # TBP returns updated events and the parameter set used by get_reaction_times.
    bev, params = TBP(bev, manager.get_baphy_exptparams()[0])
    _rts = get_reaction_times(params, bev, **options)

    # get early / middle / late rts
예제 #4
0
        # Passive-only view of the recording: start from create_mask(True)
        # and AND it with the PASSIVE_EXPERIMENT epochs.
        rp = rec.copy()
        rp = rp.create_mask(True)
        rp = rp.and_mask(['PASSIVE_EXPERIMENT'])

        # find / sort epoch names
        # NOTE(review): `ra` is not defined in this fragment; presumably the
        # active-masked counterpart of `rp` -- confirm.
        if batch in [324, 325]:
            targets = thelp.sort_targets([f for f in ra['resp'].epochs.name.unique() if 'TAR_' in f])
            # Keep targets whose epoch name occurs at least 5 times in `ra`.
            targets = [t for t in targets if (ra['resp'].epochs.name==t).sum()>=5]
            # NOTE(review): str.strip('REM_') removes any of the characters
            # R, E, M, _ from both ends, not the 'REM_' prefix; names ending in
            # those letters would be truncated. Consider str.replace / removeprefix.
            on_center = thelp.get_tar_freqs([f.strip('REM_') for f in ra['resp'].epochs.name.unique() if 'REM_' in f])[0]
            # Restrict targets / catches to those containing the on-center value.
            targets = [t for t in targets if str(on_center) in t]
            catch = [f for f in ra['resp'].epochs.name.unique() if 'CAT_' in f]
            catch = [c for c in catch if str(on_center) in c]
            targets_str = targets
            catch_str = catch
        elif batch == 307:
            params = manager.get_baphy_exptparams()
            # First non-passive parameter set.
            params = [p for p in params if p['BehaveObjectClass']!='Passive'][0]
            tf = params['TrialObject'][1]['TargetHandle'][1]['Names']
            targets = [f'TAR_{t}' for t in tf]
            if params['TrialObject'][1]['OverlapRefTar']=='Yes':
                snrs = params['TrialObject'][1]['RelativeTarRefdB'] 
            else:
                snrs = ['Inf']
            # Infinite SNR values are rendered as the string 'Inf'.
            snrs = [s if (s!=np.inf) else 'Inf' for s in snrs]
            # NOTE(review): when snrs == ['Inf'] this zip stops after the first
            # target, so targets_str is shorter than targets and the reversed
            # lists below no longer pair up element-for-element -- confirm.
            targets_str = [f'TAR_{t}+{snr}dB+Noise' for snr, t in zip(snrs, tf)]
            targets_str = targets_str[::-1]
            targets = targets[::-1]
            # only keep targets w/ at least 5 reps in active
            targets_str = [ts for t, ts in zip(targets, targets_str) if (ra['resp'].epochs.name==t).sum()>=5]
            targets = [t for t in targets if (ra['resp'].epochs.name==t).sum()>=5]
            catch = []
예제 #5
0
def plot_binaural_psths(df, cellid, bg, fg, batch, save=False, close=False):
    '''Plot every spatial PSTH combination of one BG/FG pair for one cell.

    Takes a dataframe from ohel.calc_psth_metrics along with a cellid and a
    pair of sounds, and plots the PSTH of each BG/FG spatial combination
    (11, 12, 21, 22) together with the isolated BG and FG responses, the
    linear sum (LS) of the isolated responses, and spectrograms of both
    sounds. Can save.

    Parameters
    ----------
    df : pandas.DataFrame
        Metrics table. Columns read here: BG, FG, cellid, epoch, kind, SR,
        AcorAB, BcorAB, AcorLin, BcorLin, pref, and (when saving) area.
    cellid : str
        Unit to plot; also used to load the recording.
    bg, fg : str
        Sound names used to select rows of ``df``.
    batch : int
        Batch passed to the loader.
    save : bool
        Save the figure as a PNG under ``/home/hamersky/OLP Binaural/``.
    close : bool
        Close the figure after saving (no effect unless ``save`` is True).

    Raises
    ------
    ValueError
        If no experiment parameters are found, or ``df`` has no rows for this
        cellid / BG / FG combination.
    '''
    manager = BAPHYExperiment(cellid=cellid, batch=batch)
    options = ohel.get_load_options(batch)
    rec = manager.get_recording(**options)

    rec['resp'] = rec['resp'].extract_channels([cellid])
    resp = copy.copy(rec['resp'].rasterize())

    expt_params = manager.get_baphy_exptparams()  # Using Charlie's manager
    if not expt_params:
        raise ValueError(f"No experiment parameters found for {cellid} (batch {batch}).")
    # The last parameter set is used whether there are one or several. The old
    # separate ==1 / >1 branches reduced to this but left ref_handle unbound
    # (NameError) for any other length.
    ref_handle = expt_params[-1]['TrialObject'][1]['ReferenceHandle'][1]
    BG_folder, FG_folder = ref_handle['BG_Folder'], ref_handle['FG_Folder']

    ## I could make this not dependent on DF if I add some code that loads the epochs of the cell
    ## that you inputted and applies that type label function, that's really all I'm using from df

    df_filtered = df[(df.BG == bg) & (df.FG == fg) & (df.cellid == cellid)]
    if len(df_filtered) == 0:
        pairs = get_pair_names(cellid, df, show=False)
        raise ValueError(f"The inputted BG: {bg} and FG: {fg} are not in {cellid}.\n"
                         f"Maybe try one of these:\n{pairs}")

    # epochs[0:4] = BG1 alone, FG1 alone, BG2 alone, FG2 alone. They are derived
    # from the '11' and '22' combination epochs by replacing the other sound's
    # token with 'null'. After the loop bb / ff hold the tokens of the '22'
    # epoch; they are reused below for the wav lookup and spectrogram labels.
    epochs = []
    for kind in ('11', '22'):
        name = df_filtered.epoch.loc[df_filtered.kind == kind].values[0]
        bb, ff = name.split('_')[1], name.split('_')[2]
        epochs.append(name.replace(ff, 'null'))
        epochs.append(name.replace(bb, 'null'))
    # epochs[4:] = the combination epochs themselves, in df_filtered row order.
    epochs.extend(df_filtered.epoch.values)

    r = resp.extract_epochs(epochs)
    SR = df_filtered['SR'].values[0]

    plt.figure(figsize=(18, 9))
    psth11 = plt.subplot2grid((9, 8), (0, 0), rowspan=3, colspan=3)
    psth12 = plt.subplot2grid((9, 8), (0, 3), rowspan=3, colspan=3, sharey=psth11)
    psth21 = plt.subplot2grid((9, 8), (3, 0), rowspan=3, colspan=3, sharey=psth11)
    psth22 = plt.subplot2grid((9, 8), (3, 3), rowspan=3, colspan=3, sharey=psth11)
    specA1 = plt.subplot2grid((9, 8), (7, 0), rowspan=1, colspan=3)
    specB1 = plt.subplot2grid((9, 8), (8, 0), rowspan=1, colspan=3)
    specA2 = plt.subplot2grid((9, 8), (7, 3), rowspan=1, colspan=3)
    specB2 = plt.subplot2grid((9, 8), (8, 3), rowspan=1, colspan=3)
    psthbb = plt.subplot2grid((9, 8), (0, 6), rowspan=3, colspan=2, sharey=psth11)
    psthff = plt.subplot2grid((9, 8), (3, 6), rowspan=3, colspan=2, sharey=psth11)
    ax = [psth11, psth12, psth21, psth22, specA1, specB1, specA2, specB2, psthbb, psthff]

    prestim = resp.epochs[resp.epochs['name'] == 'PreStimSilence'].copy().iloc[0]['end']
    time = (np.arange(0, r[epochs[0]].shape[-1]) / options['rasterfs']) - prestim

    # Trial-averaged response of every epoch with the spontaneous rate removed.
    r_mean = {e: np.squeeze(np.nanmean(r[e], axis=0)) - SR for e in epochs}

    # Linear-sum (LS) predictions for each spatial combination.
    # NOTE(review): assumes df_filtered has exactly four rows so that these
    # names land at epochs[8:12], as indexed by ep_num below -- confirm.
    epochs.extend(['lin11', 'lin12', 'lin21', 'lin22'])
    bg1, fg1, bg2, fg2 = epochs[0], epochs[1], epochs[2], epochs[3]
    r_mean['lin11'], r_mean['lin12'] = r_mean[bg1] + r_mean[fg1], r_mean[bg1] + r_mean[fg2]
    r_mean['lin21'], r_mean['lin22'] = r_mean[bg2] + r_mean[fg1], r_mean[bg2] + r_mean[fg2]

    # Parallel lists: the e-th trace is epochs[ep_num[e]] drawn on ax[ax_num[e]].
    colors = ['deepskyblue'] * 3 + ['violet'] * 3 + ['yellowgreen'] * 3 + ['darksalmon'] * 3 \
             + ['dimgray'] * 4 + ['black'] * 4
    styles = ['-'] * 16 + [':'] * 4
    ax_num = [0, 1, 8, 2, 3, 8, 0, 2, 9, 1, 3, 9, 0, 1, 2, 3, 0, 1, 2, 3]
    ep_num = [0, 0, 0, 2, 2, 2, 1, 1, 1, 3, 3, 3, 4, 5, 6, 7, 8, 9, 10, 11]
    labels = ['BG1'] * 3 + ['BG2'] * 3 + ['FG1'] * 3 + ['FG2'] * 3 \
             + ['BG1+FG1'] + ['BG1+FG2'] + ['BG2+FG1'] + ['BG2+FG2'] + ['LS'] * 4

    for e, a, c, s, l in zip(ep_num, ax_num, colors, styles, labels):
        ax[a].plot(time, sf.gaussian_filter1d(r_mean[epochs[e]], sigma=2)
                   * options['rasterfs'], color=c, linestyle=s, label=l)

    ymin, ymax = ax[0].get_ylim()
    AXS = [0, 1, 2, 3, 8, 9]
    # One legend per combination panel; pairs panel 0..3 with df_filtered rows
    # in their stored order (presumably kinds 11, 12, 21, 22 -- TODO confirm).
    for AX, tt, aab, bab, ali, bli, prf in zip(range(4), df_filtered.kind, df_filtered.AcorAB,
                                              df_filtered.BcorAB, df_filtered.AcorLin,
                                              df_filtered.BcorLin, df_filtered.pref):
        ax[AX].legend((f'BG{tt[0]}, corr={np.around(aab, 3)}',
                       f'FG{tt[1]}, corr={np.around(bab, 3)}',
                       f'BG{tt[0]}+FG{tt[1]}',
                       f'LS, Acorr={np.around(ali, 3)}\nBcorr={np.around(bli, 3)}\npref={np.around(prf, 3)}'))
    for AX in AXS:
        ax[AX].vlines([0, 1.0], ymin, ymax, color='black', lw=0.75, ls='--')
        ax[AX].vlines(0.5, ymax * 0.9, ymax, color='black', lw=0.75, ls=':')
        ax[AX].spines['right'].set_visible(True)
        ax[AX].spines['top'].set_visible(True)
        if AX not in (8, 9):
            ax[AX].set_xlim((-prestim * 0.5), (1 + (prestim * 0.75)))
        else:
            ax[AX].set_xlim((-prestim * 0.15), (1 + (prestim * 0.25)))

        if AX in (0, 1, 8):
            plt.setp(ax[AX].get_xticklabels(), visible=False)
        if AX in (1, 3, 8, 9):
            plt.setp(ax[AX].get_yticklabels(), visible=False)
        if AX in (2, 3, 9):
            ax[AX].set_xlabel('Time(s)', fontweight='bold', fontsize=10)
        if AX in (0, 2):
            ax[AX].set_ylabel('Spikes', fontweight='bold', fontsize=10)

    ax[0].set_title(f"{cellid} - BG: {bg} - FG: {fg}", fontweight='bold', fontsize=12)

    # The first two characters of each sound token identify the wav file.
    bbn, ffn = bb[:2], ff[:2]
    bg_path = glob.glob((f'/auto/users/hamersky/baphy/Config/lbhb/SoundObjects/@OverlappingPairs/'
                         f'{BG_folder}/{bbn}*.wav'))[0]
    fg_path = glob.glob((f'/auto/users/hamersky/baphy/Config/lbhb/SoundObjects/@OverlappingPairs/'
                         f'{FG_folder}/{ffn}*.wav'))[0]

    # Scale the PSTH x-limits (s) into spectrogram columns; xf = 100 presumably
    # matches the 0.01 s hop passed to gtgram below -- TODO confirm.
    xf = 100
    low, high = ax[0].get_xlim()
    low, high = low * xf, high * xf

    # Read each wav and compute its spectrogram once (previously done twice
    # each). BG is shown in panels 4 and 6, FG in panels 5 and 7.
    sfs, W = wavfile.read(bg_path)
    bg_spec = gtgram(W, sfs, 0.02, 0.01, 48, 100, 24000)
    sfs, W = wavfile.read(fg_path)
    fg_spec = gtgram(W, sfs, 0.02, 0.01, 48, 100, 24000)

    for AX, spec in zip(range(4, 8), (bg_spec, fg_spec, bg_spec, fg_spec)):
        ax[AX].imshow(spec, aspect='auto', origin='lower', extent=[0, spec.shape[1], 0, spec.shape[0]])
        ax[AX].set_xlim(low, high)
        ax[AX].set_xticks([])
        ax[AX].set_yticks([])
        ax[AX].set_xticklabels([])
        ax[AX].set_yticklabels([])
        for side in ('top', 'bottom', 'left', 'right'):
            ax[AX].spines[side].set_visible(False)
    ax[4].set_ylabel(f"{bb.split('-')[0]}", fontweight='bold')
    ax[5].set_ylabel(f"{ff.split('-')[0]}", fontweight='bold')

    if save:
        site, animal, area, unit = cellid.split('-')[0], cellid[:3], df.area.loc[df.cellid == cellid].unique()[0], cellid[8:]
        path = f"/home/hamersky/OLP Binaural/{animal}/{area}/{site}/{unit}/"
        Path(path).mkdir(parents=True, exist_ok=True)
        fname = f"{cellid}-{bg}-{fg}.png"
        print(f"Saving to {path + fname}")
        plt.savefig(path + fname)
        if close:
            plt.close()
예제 #6
0
    # find / sort epoch names
    # (FILE_, TAR_ and CAT_ epochs are selected by substring match)
    files = [f for f in rec['resp'].epochs.name.unique() if 'FILE_' in f]
    targets = [f for f in rec['resp'].epochs.name.unique() if 'TAR_' in f]
    catch = [f for f in rec['resp'].epochs.name.unique() if 'CAT_' in f]

    sounds = targets + catch
    # Reference stimuli, ordered numerically by the integer after their last '_'.
    ref_stims = [x for x in rec['resp'].epochs.name.unique() if 'STIM_' in x]
    idx = np.argsort([int(s.split('_')[-1]) for s in ref_stims])
    ref_stims = np.array(ref_stims)[idx].tolist()
    all_stims = ref_stims + sounds

    # ================================================================================================
    # Plot "raw" data -- tuning curves / psth's .
    # PSTHs for REFs
    # Durations converted from seconds to bins at options['rasterfs'].
    # NOTE(review): pre_post reads PostStimSilence but is passed as prestim=;
    # this is only right if pre- and post-stimulus silences are equal -- confirm.
    ref_dur = int(manager.get_baphy_exptparams()[0]['TrialObject'][1]['ReferenceHandle'][1]['Duration'] * options['rasterfs'])
    pre_post = int(manager.get_baphy_exptparams()[0]['TrialObject'][1]['ReferenceHandle'][1]['PostStimSilence'] * options['rasterfs'])
    # (d1+1) x (d1+1) grid with d1 = floor(sqrt(n_channels)) always has at
    # least n_channels panels.
    d1 = int(np.sqrt(rec['resp'].shape[0]))
    f, ax = plt.subplots(d1+1, d1+1, figsize=(16, 12))
    for i in range(rec['resp'].shape[0]):
        # Only the first channel's panel keeps its legend.
        if i == 0:
            br.psth(rec, chan=rec['resp'].chans[i], epochs=ref_stims, ep_dur=ref_dur, cmap='viridis', prestim=pre_post, ax=ax.flatten()[i])
        else:
            br.psth(rec, chan=rec['resp'].chans[i], epochs=ref_stims, ep_dur=ref_dur, cmap='viridis', prestim=pre_post, supp_legend=True, ax=ax.flatten()[i])
    f.tight_layout()

    f.savefig(site_path + '/REF_psth.png')

    # PSTHs for TARs (and CATCH)
    d1 = int(np.sqrt(rec['resp'].shape[0]))
    f, ax = plt.subplots(d1+1, d1+1, figsize=(16, 12))
# Per-date result arrays, NaN until filled further down this loop (not visible
# here). LI_ALL has 4 rows per date -- meaning of the rows not shown; TODO confirm.
LI_ALL = np.full((4, len(uDate)), np.nan)
VALID_TRIALS = np.full(len(uDate), np.nan)

for idx, ud in enumerate(uDate):
    # All parmfiles recorded on this date.
    parmfiles = d[d.date == ud].parmfile_path.values.tolist()
    # add catch to make sure "siteid" the same for all files
    # (siteid = first 7 characters of the parmfile's file name; files whose
    # siteid differs from the first file's are dropped)
    sid = [p.split(os.path.sep)[-1][:7] for p in parmfiles]
    if np.any(np.array(sid) != sid[0]):
        bad_idx = (np.array(sid) != sid[0])
        parmfiles = np.array(parmfiles)[~bad_idx].tolist()
    manager = BAPHYExperiment(parmfiles)

    # make sure only loaded actives
    # (keep only parmfiles whose BehaveObjectClass is 'RewardTargetLBHB')
    pf_mask = [
        True if k['BehaveObjectClass'] == 'RewardTargetLBHB' else False
        for k in manager.get_baphy_exptparams()
    ]
    if sum(pf_mask) == len(manager.parmfile):
        pass
    else:
        # Rebuild the manager from the active-only subset.
        parmfiles = np.array(manager.parmfile)[pf_mask].tolist()
        manager = BAPHYExperiment(parmfiles)

    # Reload from scratch and restrict to ACTIVE_EXPERIMENT epochs.
    rec = manager.get_recording(recache=True, **options)
    rec = rec.and_mask(['ACTIVE_EXPERIMENT'])

    # define overlapping trial windows for behavior metrics
    # divide trials into quantiles (25/50/75 percentile)
    # (edges are trial indices at those quantiles, truncated to int)
    nTrials = rec['fileidx'].extract_epoch('TRIAL', mask=rec['mask']).shape[0]
    edges = np.quantile(range(nTrials), [.25, .5, .75]).astype(int)
    # then define three overlapping ranges. Early, middle, late
예제 #8
0
파일: OLP_helpers.py 프로젝트: LBHB/nems_db
def get_sound_statistics(weight_df, plot=True):
    '''Compute spectral statistics for every sound referenced in ``weight_df``.

    5/12/22. The cellid and batch in the first row of ``weight_df`` are used
    to look up the experiment's BG/FG sound folders. Every BG and FG sound
    that appears in ``weight_df.epoch`` is then loaded, converted to a
    gammatone spectrogram and summarised. Optionally the statistics are
    plotted side by side.

    Parameters
    ----------
    weight_df : pandas.DataFrame
        Needs ``cellid``, ``batch`` and ``epoch`` columns. Each epoch name is
        split on '_' with field 1 taken as the BG and field 2 as the FG. The
        first two characters of each are an integer sound ID used to find
        the wav file.
    plot : bool
        Draw three bar plots (non-stationariness, bandwidth, frequency
        non-stationariness).

    Returns
    -------
    pandas.DataFrame
        One row per sound with columns name, type ('BG'/'FG'), std (spectrogram
        std along axis 1), bandwidth (log2 of 75th / 25th percentile frequency,
        in octaves), 75th, 25th, center (50th percentile), spec, mean_freq
        (spectrogram nanmean along axis 1) and freq_stationary (std of
        mean_freq).
    '''
    lfreq, hfreq, bins = 100, 24000, 48
    cid, btch = weight_df.cellid.iloc[0], weight_df.batch.iloc[0]
    manager = BAPHYExperiment(cellid=cid, batch=btch)
    expt_params = manager.get_baphy_exptparams()  # Using Charlie's manager
    ref_handle = expt_params[-1]['TrialObject'][1]['ReferenceHandle'][1]
    BG_folder, FG_folder = ref_handle['BG_Folder'], ref_handle['FG_Folder']

    # Unique two-character sound IDs, sorted numerically.
    bbs = sorted({ep.split('_')[1][:2] for ep in weight_df.epoch}, key=int)
    ffs = sorted({ep.split('_')[2][:2] for ep in weight_df.epoch}, key=int)

    sound_dir = '/auto/users/hamersky/baphy/Config/lbhb/SoundObjects/@OverlappingPairs/'
    bg_paths = [glob.glob(f'{sound_dir}{BG_folder}/{bb}*.wav')[0] for bb in bbs]
    fg_paths = [glob.glob(f'{sound_dir}{FG_folder}/{ff}*.wav')[0] for ff in ffs]
    paths = bg_paths + fg_paths
    # Sound names = wav file name without directory or extension.
    names = [pth.split('/')[-1].split('.')[0] for pth in paths]
    labels = ['BG'] * len(bg_paths) + ['FG'] * len(fg_paths)

    # Log-spaced frequency axis (Hz) with one entry per gtgram channel; assumed
    # to line up with gtgram's channel ordering. Loop-invariant, so built once.
    x_freq = np.logspace(np.log2(lfreq), np.log2(hfreq), num=bins, base=2)

    sounds = []
    for sn, pth, ll in zip(names, paths, labels):
        sfs, W = wavfile.read(pth)
        spec = gtgram(W, sfs, 0.02, 0.01, bins, lfreq, hfreq)

        dev = np.std(spec, axis=1)
        freq_mean = np.nanmean(spec, axis=1)

        # Frequencies where the cumulative mean power is closest to 25 / 50 /
        # 75 % of its maximum.
        csm = np.cumsum(freq_mean)
        big = np.max(csm)
        freq75 = x_freq[np.abs(csm - (big * 0.75)).argmin()]
        freq25 = x_freq[np.abs(csm - (big * 0.25)).argmin()]
        freq50 = x_freq[np.abs(csm - (big * 0.5)).argmin()]
        bandw = np.log2(freq75 / freq25)

        sounds.append({'name': sn,
                       'type': ll,
                       'std': dev,
                       'bandwidth': bandw,
                       '75th': freq75,
                       '25th': freq25,
                       'center': freq50,
                       'spec': spec,
                       'mean_freq': freq_mean,
                       'freq_stationary': np.std(freq_mean)})

    sound_df = pd.DataFrame(sounds)

    if plot:
        # One row per (sound, channel) so the std bar gets an error bar.
        ss = sound_df.explode('std')
        # Tick labels drop the two-character sound ID prefix.
        snames = [dd[2:] for dd in sound_df.name]
        palette = ["lightskyblue" if x == 'BG' else 'yellowgreen' for x in sound_df.type]

        fig, ax = plt.subplots(1, 3, figsize=(18, 8))
        panels = [
            ('std', ss, 'Non-stationariness', {'ci': 68, 'errwidth': 1}),
            ('bandwidth', sound_df, 'Bandwidth (octaves)', {}),
            ('freq_stationary', sound_df, 'Frequency Non-stationariness', {}),
        ]
        for axis, (column, data, ylabel, extra) in zip(ax, panels):
            sb.barplot(x='name', y=column, palette=palette, data=data, ax=axis, **extra)
            axis.set_xticklabels(snames, rotation=90, fontweight='bold', fontsize=7)
            axis.set_ylabel(ylabel, fontweight='bold', fontsize=12)
            axis.spines['top'].set_visible(True)
            axis.spines['right'].set_visible(True)
            axis.set(xlabel=None)

        fig.tight_layout()

    return sound_df
예제 #9
0
# Reaction-time accumulators keyed by SNR (one list per value in snrs).
rts_middle = {k: [] for k in snrs}
rts_late = {k: [] for k in snrs}
for idx, ud in enumerate(uDate):
    print(f"Loading data from {ud}")
    # All parmfiles recorded on this date.
    parmfiles = d[d.date == ud].parmfile_path.values.tolist()
    # add catch to make sure "siteid" the same for all files
    # (siteid = first 7 characters of the parmfile's file name; files whose
    # siteid differs from the first file's are dropped)
    sid = [p.split(os.path.sep)[-1][:7] for p in parmfiles]
    if np.any(np.array(sid) != sid[0]):
        bad_idx = (np.array(sid) != sid[0])
        parmfiles = np.array(parmfiles)[~bad_idx].tolist()
    manager = BAPHYExperiment(parmfiles)

    # make sure only loaded actives
    # (keep only parmfiles whose BehaveObjectClass is 'RewardTargetLBHB')
    pf_mask = [
        True if k['BehaveObjectClass'] == 'RewardTargetLBHB' else False
        for k in manager.get_baphy_exptparams()
    ]
    if sum(pf_mask) == len(manager.parmfile):
        pass
    else:
        # Rebuild the manager from the active-only subset.
        parmfiles = np.array(manager.parmfile)[pf_mask].tolist()
        manager = BAPHYExperiment(parmfiles)

    # get reaction times of targets, only for "correct" trials
    # NOTE(review): the filter below only drops rows flagged invalidTrial;
    # whether that equals "correct" trials is not visible here.
    bev = manager.get_behavior_events(**options)
    bev = manager._stack_events(bev)
    bev = bev[bev.invalidTrial == False]
    # Uses the first parameter set of the (active-only) manager.
    _rts = get_reaction_times(manager.get_baphy_exptparams()[0], bev,
                              **options)

    # get early / middle / late rts
예제 #10
0
def psth_responses(parm, pair_idx, cell, sigma=2, save=False, rasterfs=100):
    '''Plot one unit's smoothed PSTHs for one OLP sound pair, with spectrograms.

    Four PSTHs are drawn: BG alone, FG alone, BG+FG, and the ``{BG}-0.5-1``
    variant combined with FG. The BG and FG spectrograms are shown below.

    Parameters
    ----------
    parm : parmfile path(s) accepted by ``BAPHYExperiment``.
    pair_idx : int
        Index into the experiment's SoundPairs (negative indices allowed).
    cell : int
        Index into ``resp.chans`` (negative indices allowed).
    sigma : float
        Width of the Gaussian smoothing applied to each PSTH, in bins.
    save : bool
        Save the figure under ``/home/hamersky/Tabor PSTHs/<site>/`` and close it.
    rasterfs : int
        Raster sampling rate; PSTHs are multiplied by it (presumably giving
        spikes/s).

    Raises
    ------
    ValueError
        If ``cell`` or ``pair_idx`` is out of range.
    '''
    expt = BAPHYExperiment(parm)
    rec = expt.get_recording(rasterfs=rasterfs, resp=True, stim=False)
    resp = rec['resp'].rasterize()

    # Validate before indexing: previously resp.chans[cell] ran first, so an
    # out-of-range cell raised a bare IndexError and this check never fired.
    n_chans = len(resp.chans)
    if not -n_chans <= cell < n_chans:
        raise ValueError(f"Cell {cell} is out of range for site with {len(resp.chans) - 1} units")
    site, unit = expt.siteid[:-1], resp.chans[cell]

    expt_params = expt.get_baphy_exptparams()  # Using Charlie's manager
    ref_handle = expt_params[0]['TrialObject'][1]['ReferenceHandle'][1]
    # Each pair is (BG name, FG name) with the file extension stripped.
    pairs = [(sp['bg_sound_name'].split('.')[0], sp['fg_sound_name'].split('.')[0])
             for sp in ref_handle['SoundPairs'].values()]

    if not -len(pairs) <= pair_idx < len(pairs):
        raise ValueError(f"Pair_idx {pair_idx} is out of range for unit with {len(pairs) - 1} sound pairs")

    BG, FG = pairs[pair_idx]
    colors = ['deepskyblue', 'yellowgreen', 'grey', 'silver']
    bg_path = f'/auto/users/hamersky/baphy/Config/lbhb/SoundObjects/@OverlappingPairs/Background2/{BG}.wav'
    fg_path = f'/auto/users/hamersky/baphy/Config/lbhb/SoundObjects/@OverlappingPairs/Foreground3/{FG}.wav'

    # Epoch names use the sound names with spaces removed (file names keep them).
    BG, FG = (name.replace(' ', '') for name in pairs[pair_idx])
    epochs = [f'STIM_{BG}-0-1_null', f'STIM_null_{FG}-0-1', f'STIM_{BG}-0-1_{FG}-0-1',
              f'STIM_{BG}-0.5-1_{FG}-0-1']

    prestim = ref_handle['PreStimSilence']

    plt.figure(figsize=(15, 9))
    psth = plt.subplot2grid((4, 5), (0, 0), rowspan=2, colspan=5)
    specBG = plt.subplot2grid((4, 5), (2, 0), rowspan=1, colspan=5)
    specFG = plt.subplot2grid((4, 5), (3, 0), rowspan=1, colspan=5)

    ax = [psth, specBG, specFG]

    r = resp.extract_epochs(epochs)

    time = (np.arange(0, r[epochs[0]].shape[-1]) / rasterfs) - prestim

    for e, c in zip(epochs, colors):
        ax[0].plot(time, sf.gaussian_filter1d(np.nanmean(r[e][:, cell, :], axis=0), sigma=sigma)
                   * rasterfs, color=c, label=e)
    ax[0].legend()
    ax[0].set_title(f"{resp.chans[cell]} - Pair {pair_idx} - BG: {BG} - FG: {FG} - sigma={sigma}", weight='bold')
    ymin, ymax = ax[0].get_ylim()
    ax[0].vlines([0, 1.0], ymin, ymax, color='black', lw=0.75, ls='--')
    ax[0].vlines(0.5, ymin, ymax, color='black', lw=0.75, ls=':')

    # Scale the PSTH x-limits (s) into spectrogram columns; xf = 100 presumably
    # matches the 0.01 s hop passed to gtgram below -- TODO confirm.
    xf = 100
    low, high = ax[0].get_xlim()
    low, high = low * xf, high * xf

    # Same spectrogram panel layout for BG (ax[1]) and FG (ax[2]).
    for spec_ax, wav_path in ((ax[1], bg_path), (ax[2], fg_path)):
        sfs, W = wavfile.read(wav_path)
        spec = gtgram(W, sfs, 0.02, 0.01, 48, 100, 24000)
        spec_ax.imshow(spec, aspect='auto', origin='lower', extent=[0, spec.shape[1], 0, spec.shape[0]])
        spec_ax.set_xlim(low, high)
        spec_ax.set_xticks([0, 20, 40, 60, 80])
        spec_ax.set_yticks([])
        spec_ax.set_xticklabels([0, 0.2, 0.4, 0.6, 0.8])
        spec_ax.set_yticklabels([])
        for side in ('top', 'bottom', 'left', 'right'):
            spec_ax.spines[side].set_visible(False)
        # Dotted marker at the spectrogram's horizontal midpoint.
        spec_ymin, spec_ymax = spec_ax.get_ylim()
        spec_ax.vlines((spec.shape[-1] + 1) / 2, spec_ymin, spec_ymax, color='white', lw=0.75, ls=':')

    ax[2].set_xlabel('Seconds', weight='bold')
    ax[1].set_ylabel(f"Background:\n{BG}", weight='bold', labelpad=-80, rotation=0)
    ax[2].set_ylabel(f"Foreground:\n{FG}", weight='bold', labelpad=-80, rotation=0)

    if save:
        path = f"/home/hamersky/Tabor PSTHs/{site}/"
        Path(path).mkdir(parents=True, exist_ok=True)

        plt.savefig(path + f"{unit} - Pair {pair_idx} - {BG} - {FG} - sigma{sigma}.png")
        plt.close()



# parm = '/auto/data/daq/Tabor/TBR011/TBR011a17_p_OLP.m'
# # parm = '/auto/data/daq/Tabor/TBR007/TBR007a10_p_OLP.m'
# expt = BAPHYExperiment(parm)
# rec = expt.get_recording(rasterfs=rasterfs, resp=True, stim=False)
# resp = rec['resp'].rasterize()
#
# expt_params = expt.get_baphy_exptparams()  # Using Charlie's manager
# ref_handle = expt_params[0]['TrialObject'][1]['ReferenceHandle'][1]
# soundies = list(ref_handle['SoundPairs'].values())
# pairs = [tuple([j for j in (soundies[s]['bg_sound_name'].split('.')[0],
#                             soundies[s]['fg_sound_name'].split('.')[0])])
#          for s in range(len(soundies))]
# #TBR011a-58-1 = -3
# cell = -3  #TBR011a-58-1
# cell = 2  #
# cell = 3 #large effect
#
# epochs = ['STIM_17Tuning-0-1_13Tsik-0-1', 'STIM_17Tuning-0-1_null', 'STIM_null_13Tsik-0-1',
#           'STIM_17Tuning-0.5-1_13Tsik-0-1', 'STIM_17Tuning-0-1_13Tsik-0.5-1']
# epochs = ['STIM_17Tuning-0-1_13Tsik-0-1', 'STIM_null_13Tsik-0-1',
#           'STIM_17Tuning-0.5-1_13Tsik-0-1']
# epochs = ['STIM_17Tuning-0-1_13Tsik-0-1', 'STIM_17Tuning-0-1_null',
#           'STIM_17Tuning-0-1_13Tsik-0.5-1']
#
# epochs = ['STIM_17Tuning-0-1_13Tsik-0-1', 'STIM_null_13Tsik-0-1',
#           'STIM_17Tuning-0.5-1_13Tsik-0-1', 'STIM_17Tuning-0-1_null']
# colors = ['black', 'yellowgreen', 'lightgrey', 'blue']
#
# #final colors for better plot
# epochs = ['STIM_17Tuning-0-1_null', 'STIM_null_13Tsik-0-1', 'STIM_17Tuning-0-1_13Tsik-0-1',
#           'STIM_17Tuning-0.5-1_13Tsik-0-1']
# colors = ['deepskyblue', 'yellowgreen', 'grey', 'rosybrown']
#
#
# epochs = [f'STIM_17Tuning-0-1_null', 'STIM_null_13Tsik-0-1', 'STIM_17Tuning-0-1_13Tsik-0-1',
#           'STIM_17Tuning-0.5-1_13Tsik-0-1']
# colors = ['deepskyblue', 'yellowgreen', 'grey', 'rosybrown']
#
# bg_path = f'/auto/users/hamersky/baphy/Config/lbhb/SoundObjects/@OverlappingPairs/Background1/{AA}.wav'
# fg_path = f'/auto/users/hamersky/baphy/Config/lbhb/SoundObjects/@OverlappingPairs/Foreground2/{BB}.wav'
#
# smooval = 6
#
# f = plt.figure(figsize=(15, 9))
# psth = plt.subplot2grid((4, 6), (0, 0), rowspan=2, colspan=6)
# specBG = plt.subplot2grid((4, 6), (2, 0), rowspan=1, colspan=6)
# specFG = plt.subplot2grid((4, 6), (3, 0), rowspan=1, colspan=6)
#
# ax = [psth, specBG, specFG]
#
# time = np.arange(0, r[epochs[0]].shape[-1] / rasterfs)
#
# for e, c in zip(epochs, colors):
#     mean = r[e][:,cell,:].mean(axis=0)
#     mean = smooth(mean, smooval)
#     ax[0].plot(mean, label=e, color=c)
# ax[0].legend()
# ax[0].set_title(f"{resp.chans[cell]}")
#
#
#
# r = resp.extract_epochs(epochs)
# fig, ax = plt.subplots()
# for e, c in zip(epochs,colors):
#     mean = r[e][:,cell,:].mean(axis=0)
#     mean = smooth(mean, 7)
#     ax.plot(mean, label=e, color=c)
# ax.legend()
# ax.set_title(f"{resp.chans[cell]}")
# ymin,ymax = ax.get_ylim()
# ax.vlines([50, 100, 150], ymin, ymax, color = 'black', ls=':')