示例#1
0
def get_psd_stats(bird, block, seg, hemi, data_dir='/auto/tdrive/mschachter/data'):
    """Summarize the log-transformed 'locked' auto power spectra of each electrode.

    Loads ``<data_dir>/<bird>/transforms/PairwiseCF_<bird>_<block>_<seg>_<hemi>_raw.h5``
    with ``PairwiseCFTransform.load``. Each distinct ``electrode1`` value is
    handled in turn. Its rows are those with ``electrode1 == electrode2 ==
    electrode`` and ``decomp == 'locked'``. Those rows of ``psds`` (addressed
    through the ``'index'`` column) are passed to ``log_transform`` and then
    summarized.

    Returns:
        dict mapping electrode -> (mean over rows, sample std over rows with
        ddof=1), both taken along axis 0.
    """
    cf_path = os.path.join(data_dir, bird, 'transforms',
                           'PairwiseCF_%s_%s_%s_%s_raw.h5' % (bird, block, seg, hemi))
    transform = PairwiseCFTransform.load(cf_path)
    frame = transform.df

    stats_by_electrode = {}
    for electrode in frame.electrode1.unique():
        # auto-spectrum rows of this electrode, 'locked' decomposition only
        is_auto = (frame.electrode1 == electrode) & (frame.electrode1 == frame.electrode2)
        mask = is_auto & (frame.decomp == 'locked')
        # integer-array indexing yields a copy (assuming psds is a NumPy array),
        # so the in-place log_transform below does not touch transform.psds
        rows = transform.psds[frame['index'][mask].values]
        log_transform(rows)  # defined elsewhere; used for its in-place effect
        stats_by_electrode[electrode] = (rows.mean(axis=0), rows.std(axis=0, ddof=1))
    return stats_by_electrode
示例#2
0
def draw_figures(data_dir='/auto/tdrive/mschachter/data'):
    """Plot normalized 'full' vs 'onewin' LFP power spectra for a sample of syllables.

    Loads the raw PairwiseCF transform for bird GreBlu9508M, Site4, Call1, left
    hemisphere. Each decomposition ('full' and 'onewin') is normalized
    separately, using only its auto-spectra (rows with
    electrode1 == electrode2). The rows are scaled by their global max, passed
    through ``pcf.log_transform`` and z-scored per column (frequency bin; rows
    are plotted against ``pcf.freqs``).

    Only syllables whose ``stim_duration`` lies in [0.050, 0.400] are used, and
    each one is paired up per electrode. The pairs are sorted by duration, in
    random order among ties, and every 20th pair is drawn with ``multi_plot``
    in a 5 x 9 grid.

    Args:
        data_dir: root directory containing ``<bird>/transforms/``.
    """
    bird = 'GreBlu9508M'
    block = 'Site4'
    seg = 'Call1'
    hemi = 'L'
    file_ext = '_'.join([bird, block, seg, hemi])

    pcf_file = os.path.join(data_dir, bird, 'transforms', 'PairwiseCF_%s_%s_new.h5' % (file_ext, 'raw'))
    pcf = PairwiseCFTransform.load(pcf_file)

    def _normalized_auto_psds(decomp):
        """Return (psd_indices, row_of, X) for the auto-spectra of one decomposition.

        ``psd_indices`` lists the selected values of the ``'index'`` column.
        ``X`` holds the matching rows of ``pcf.psds``: scaled by its max,
        log-transformed, then z-scored per column. ``row_of`` maps each psd
        index to its row in ``X``.
        """
        i = (pcf.df.decomp == decomp) & (pcf.df.electrode1 == pcf.df.electrode2)
        assert i.sum() > 0
        psd_indices = list(pcf.df['index'][i].values)
        # list indexing returns a copy (assuming pcf.psds is a NumPy array), so
        # the in-place operations below leave pcf.psds untouched
        X = pcf.psds[psd_indices, :]
        X /= X.max()
        # take log transform of power spectra, then z-score each column
        pcf.log_transform(X)
        X -= X.mean(axis=0)
        X /= X.std(axis=0, ddof=1)
        # O(1) lookup instead of list.index() inside the syllable loop (which was
        # O(n) per call); setdefault keeps the first occurrence, like list.index
        row_of = dict()
        for row, psd_index in enumerate(psd_indices):
            row_of.setdefault(psd_index, row)
        return psd_indices, row_of, X

    xfull_indices, full_row, Xfull = _normalized_auto_psds('full')
    print('len(pcf.df)=%d' % len(pcf.df))
    print('pcf.df[index].max()=%d' % pcf.df['index'].max())
    print('pcf.psds.shape= %s' % (pcf.psds.shape,))
    print('xfull_indices.max()=%d' % max(xfull_indices))
    print('len(xfull_indices)=%d' % len(xfull_indices))

    _, onewin_row, Xonewin = _normalized_auto_psds('onewin')

    plist = list()
    g = pcf.df.groupby(['stim_id', 'order', 'stim_type'])
    for (stim_id, order, stim_type), gdf in g:

        stim_dur = gdf.stim_duration.values[0]
        # only keep syllables between 50 ms and 400 ms long
        if stim_dur < 0.050 or stim_dur > 0.400:
            continue

        for e in gdf.electrode1.unique():
            is_auto = (gdf.electrode1 == e) & (gdf.electrode1 == gdf.electrode2)

            i = is_auto & (gdf.decomp == 'full')
            assert i.sum() == 1
            full_psd = Xfull[full_row[gdf['index'][i].values[0]], :]

            i = is_auto & (gdf.decomp == 'onewin')
            assert i.sum() == 1
            onewin_psd = Xonewin[onewin_row[gdf['index'][i].values[0]], :]

            plist.append({'full_psd':full_psd, 'onewin_psd':onewin_psd, 'stim_id':stim_id, 'stim_order':order,
                          'stim_type':stim_type, 'electrode':e, 'stim_dur':stim_dur})

    # shuffle first so that the (stable) sort leaves equal durations in random order
    np.random.shuffle(plist)
    plist.sort(key=operator.itemgetter('stim_dur'))
    short_plist = [x for k,x in enumerate(plist) if k % 20 == 0]
    print('len(short_plist)=%d' % len(short_plist))

    def _plot_psds(_pdata, _ax):
        # draws on the current axes; _ax is unused (presumably multi_plot makes it current -- confirm)
        absmax = max(np.abs(_pdata['full_psd']).max(), np.abs(_pdata['onewin_psd']).max())
        plt.axhline(0, c='k')
        plt.plot(pcf.freqs, _pdata['full_psd'], 'k-', linewidth=3.0, alpha=0.7)
        plt.plot(pcf.freqs, _pdata['onewin_psd'], 'g-', linewidth=3.0, alpha=0.7)
        plt.title('e%d: %d_%d (%s) %0.3fs' % (_pdata['electrode'], _pdata['stim_id'], _pdata['stim_order'], _pdata['stim_type'], _pdata['stim_dur']))
        plt.axis('tight')
        plt.ylim(-absmax, absmax)

    multi_plot(short_plist, _plot_psds, nrows=5, ncols=9)
    plt.show()
示例#3
0
def get_full_data(bird, block, segment, hemi, stim_id, data_dir='/auto/tdrive/mschachter/data'):
    """Gather spectrogram, LFP, spikes, PSTH and per-syllable features for one stimulus.

    Args:
        bird, block, segment, hemi: identify the recording; used to build the
            transform file names under ``<data_dir>/<bird>/transforms``.
        stim_id: stimulus to extract.
        data_dir: root data directory.

    Returns:
        dict with keys:
          'stim_id', 'spec_freq';
          'spec', 'spec_t': the stimulus spectrogram and its time axis. Both
              are truncated at the first sample after the last segment's
              end_time, or kept whole if the spectrogram ends earlier;
          'lfp': 'raw' LFP (after se.zscore('raw')) of shape
              (ntrials, nelectrodes, nt). It starts at sample
              int(pre_stim_time * lfp_sample_rate) and spans the spectrogram
              duration;
          'spikes': se.spikes_by_stim[stim_id];
          'lfp_sample_rate';
          'psth': one row per cell;
          'syllable_props': one dict per syllable order. Each holds
              start/end time, the z-scored acoustic properties, and
              'lfp_psd' (one normalized 'full' auto-spectrum per electrode,
              in electrode_order);
          'electrode_order', 'psd_freq', 'aprops';
          'cell_index2electrode': always None here.

    Raises:
        ValueError: if ``se.segment_df`` has no rows for ``stim_id``.
    """
    bdir = os.path.join(data_dir, bird)
    tdir = os.path.join(bdir, 'transforms')

    aprops = USED_ACOUSTIC_PROPS

    # load the BioSound
    bs_file = os.path.join(tdir, 'BiosoundTransform_%s.h5' % bird)
    bs = BiosoundTransform.load(bs_file)

    # load the StimEvent transform, z-score its raw LFP and segment stims using the BioSound file
    se_file = os.path.join(tdir, 'StimEvent_%s_%s_%s_%s.h5' % (bird, block, segment, hemi))
    print('Loading %s...' % se_file)
    se = StimEventTransform.load(se_file, rep_types_to_load=['raw'])
    se.zscore('raw')
    se.segment_stims_from_biosound(bs_file)

    # load the pairwise CF transform
    pcf_file = os.path.join(tdir, 'PairwiseCF_%s_%s_%s_%s_raw.h5' % (bird, block, segment, hemi))
    print('Loading %s...' % pcf_file)
    pcf = PairwiseCFTransform.load(pcf_file)

    def log_transform(x, dbnoise=100.):
        """In place: scale to max 1, map positive entries to 20*log10 + dbnoise,
        clip negatives to 0, then rescale to max 1."""
        x /= x.max()
        zi = x > 0
        x[zi] = 20*np.log10(x[zi]) + dbnoise
        x[x < 0] = 0
        x /= x.max()

    # log-transform, then z-score each PSD column over all rows (on a copy of pcf.psds)
    all_lfp_psds = deepcopy(pcf.psds)
    log_transform(all_lfp_psds)
    all_lfp_psds -= all_lfp_psds.mean(axis=0)
    all_lfp_psds /= all_lfp_psds.std(axis=0, ddof=1)

    # overall mean/std of each acoustic property, used to z-score per-syllable values
    bs_stats = dict()
    for aprop in aprops:
        bs_stats[aprop] = (bs.stim_df[aprop].mean(), bs.stim_df[aprop].std(ddof=1))

    # list every (stim_id, stim_type) available in the segmentation
    for (stim_id2, stim_type2), _ in se.segment_df.groupby(['stim_id', 'stim_type']):
        print('%d: %s' % (stim_id2, stim_type2))

    # get the spectrogram, truncated just after the last segment of this stim
    seg_i = se.segment_df.stim_id == stim_id
    if not seg_i.any():
        raise ValueError('no segments found for stim_id=%s' % stim_id)
    last_end_time = se.segment_df.end_time[seg_i].max()

    spec_freq = se.spec_freq
    stim_spec = se.spec_by_stim[stim_id]
    # float arange: an integer sample rate would otherwise truncate under Python 2 division
    spec_t = np.arange(stim_spec.shape[1], dtype=float) / se.lfp_sample_rate
    # index of the first sample strictly after last_end_time; len(spec_t) if none
    # (np.min over an empty np.where result used to raise here)
    speci = np.searchsorted(spec_t, last_end_time, side='right')
    spec_t = spec_t[:speci]
    stim_spec = stim_spec[:, :speci]
    stim_dur = spec_t.max() - spec_t.min()

    # get the raw LFP
    si = int(se.pre_stim_time*se.lfp_sample_rate)
    ei = int(stim_dur*se.lfp_sample_rate) + si
    lfp = se.lfp_reps_by_stim['raw'][stim_id][:, :, si:ei]
    ntrials, nelectrodes, nt = lfp.shape

    # get the raw spikes, spike_mat is ragged array of shape (num_trials, num_cells, num_spikes)
    spike_mat = se.spikes_by_stim[stim_id]
    assert ntrials == len(spike_mat)

    ncells = len(se.cell_df)
    print('ncells=%d' % ncells)

    # compute one PSTH per cell, pooling its spikes across trials
    psth = list()
    for n in range(ncells):
        spikes = [spike_mat[k][n] for k in range(ntrials)]
        _psth_t, _psth = compute_psth(spikes, stim_dur, bin_size=1.0/se.lfp_sample_rate)
        psth.append(_psth)
    psth = np.array(psth)

    if hemi == 'L':
        electrode_order = ROSTRAL_CAUDAL_ELECTRODES_LEFT
    else:
        electrode_order = ROSTRAL_CAUDAL_ELECTRODES_RIGHT

    # get acoustic props and LFP power spectra for each syllable
    syllable_props = list()

    orders = sorted(bs.stim_df.order[bs.stim_df.stim_id == stim_id].values)
    cell_index2electrode = None  # never populated in this function
    for o in orders:
        syl_i = (bs.stim_df.stim_id == stim_id) & (bs.stim_df.order == o)
        assert syl_i.sum() == 1

        d = dict()
        d['start_time'] = bs.stim_df.start_time[syl_i].values[0]
        d['end_time'] = bs.stim_df.end_time[syl_i].values[0]
        d['order'] = o

        # z-score each acoustic property against the all-stimulus statistics
        for aprop in aprops:
            amean, astd = bs_stats[aprop]
            d[aprop] = (bs.stim_df[aprop][syl_i].values[0] - amean) / astd

        # one normalized 'full' auto-spectrum per electrode, in electrode_order
        lfp_psd = list()
        for e in electrode_order:
            psd_i = (pcf.df.stim_id == stim_id) & (pcf.df.order == o) & (pcf.df.decomp == 'full') & \
                    (pcf.df.electrode1 == e) & (pcf.df.electrode2 == e)

            assert psd_i.sum() == 1, "i.sum()=%d" % psd_i.sum()

            index = pcf.df[psd_i]['index'].values[0]
            lfp_psd.append(all_lfp_psds[index, :])
        d['lfp_psd'] = np.array(lfp_psd)

        syllable_props.append(d)

    return {'stim_id':stim_id, 'spec_t':spec_t, 'spec_freq':spec_freq, 'spec':stim_spec,
            'lfp':lfp, 'spikes':spike_mat, 'lfp_sample_rate':se.lfp_sample_rate, 'psth':psth,
            'syllable_props':syllable_props, 'electrode_order':electrode_order, 'psd_freq':pcf.freqs,
            'cell_index2electrode':cell_index2electrode, 'aprops':aprops}
示例#4
0
def draw_figures(data_dir='/auto/tdrive/mschachter/data'):
    """Scatter LFP pairwise cross-CF magnitude against spike synchrony per electrode pair.

    Syllables are the (stim_id, order) groups whose stim_type is not
    'mlnoise'. For every syllable and every unordered electrode pair
    (e1, e2), two values are computed:
      * cf_sum: sum of |pcf.cross_cfs| over the selected lags for the single
        'locked' row of that pair;
      * sync12: max of pcf.spike_synchrony over the 'spike_sync' rows of that
        pair, stored in either orientation.
    Pairs without spike-sync rows are skipped. Points with cf_sum > 0 and
    sync12 > 0 are plotted with spike synchrony on the x axis. Their Pearson
    correlation is shown in the title.

    Args:
        data_dir: root directory containing ``GreBlu9508M/transforms/``.

    Raises:
        ValueError: if no electrode pair passes the > 0 filter.
    """
    pcf_file = os.path.join(data_dir, 'GreBlu9508M', 'transforms', 'PairwiseCF_GreBlu9508M_Site4_Call1_L_raw.h5')
    pcf = PairwiseCFTransform.load(pcf_file)

    # NOTE(review): only lag == 0 is selected, but the y-axis label below says
    # "-20ms to 20ms"; one of the two is out of date -- confirm which.
    lags_index = np.abs(pcf.lags) == 0.
    cross_cfs = pcf.cross_cfs[:, lags_index]
    syncs = pcf.spike_synchrony  # only read below, so no copy is needed

    df = pcf.df[pcf.df.stim_type != 'mlnoise']
    electrodes = pcf.df.electrode1.unique()

    X = list()
    for (stim_id, syllable_order), gdf in df.groupby(['stim_id', 'order']):
        is_locked = gdf.decomp == 'locked'
        is_sync = gdf.decomp == 'spike_sync'

        for k, e1 in enumerate(electrodes):
            for e2 in electrodes[:k]:
                forward = (gdf.electrode1 == e1) & (gdf.electrode2 == e2)
                reverse = (gdf.electrode1 == e2) & (gdf.electrode2 == e1)

                # LFP: exactly one 'locked' row per pair, stored in either orientation
                i = is_locked & forward
                if i.sum() == 0:
                    i = is_locked & reverse
                assert i.sum() == 1
                cf = cross_cfs[gdf['index'][i].values, :]
                cf_sum = np.abs(cf).sum()

                # spikes: maximum synchrony over all rows for this electrode pair.
                # Both electrode tests must hold (&), in either orientation; the
                # previous | matched every row touching e1 or e2.
                i = is_sync & (forward | reverse)
                if i.sum() == 0:
                    continue  # no spike-sync rows for this electrode pair
                sync12 = syncs[gdf['index'][i].values].max()

                if cf_sum > 0 and sync12 > 0:
                    X.append((cf_sum, sync12))

    if len(X) == 0:
        raise ValueError('no electrode pairs with nonzero LFP cross-CF and spike synchrony')
    X = np.array(X)
    cc = np.corrcoef(X[:, 0], X[:, 1])[0, 1]

    plt.figure()
    plt.subplot(1, 1, 1)
    plt.plot(X[:, 1], X[:, 0], 'go')
    plt.xlabel('Spike Synchrony')
    plt.ylabel('LFP Synchrony (-20ms to 20ms)')
    plt.title('cc=%0.2f' % cc)
    plt.axis('tight')

    plt.show()