Example #1
def draw_single_cell(agg, df_cell):

    aprops_to_display = ['category', 'maxAmp', 'sal', 'meanspect', 'q1', 'q2', 'q3',
                         'entropyspect', 'meantime', 'entropytime']

    plist = list()
    for aprop in aprops_to_display:
        perfs = df_cell['perf_%s' % aprop].values
        lkrats = df_cell['lkrat_%s' % aprop].values
        plist.append({'perfs':perfs, 'lkrats':lkrats, 'aprop':aprop})

    def _plot_perf_hist(pdata, ax):
        plt.sca(ax)
        plt.hist(pdata['perfs'], bins=10, color=COLOR_YELLOW_SPIKE)
        plt.xlabel('R2')
        plt.axis('tight')
        plt.xlim(0, 0.4)
        tks = [0.0, 0.1, 0.2, 0.3, 0.4]
        plt.xticks(tks, ['%0.1f' % x for x in tks])
        plt.title(pdata['aprop'])

    def _plot_lkrat_hist(pdata, ax):
        lkr = pdata['lkrats']
        lkr = lkr[lkr >= 0]
        plt.sca(ax)
        plt.hist(lkr, bins=10, color=COLOR_YELLOW_SPIKE)
        plt.xlabel('Likelihood Ratio')
        plt.axis('tight')
        plt.xlim(0, 60)
        plt.title(pdata['aprop'])

    multi_plot(plist, _plot_perf_hist, nrows=2, ncols=5)
    multi_plot(plist, _plot_lkrat_hist, nrows=2, ncols=5)
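
Every example on this page hands its per-panel data to a multi_plot helper whose source is not shown. The stand-in below is only a sketch inferred from the call sites (multi_plot_sketch and its defaults are assumptions): it lays out an nrows x ncols grid and passes each dict in plist to the plotting function as plot_fn(pdata, ax).

import numpy as np
import matplotlib.pyplot as plt

def multi_plot_sketch(plist, plot_fn, nrows=4, ncols=5, figsize=None,
                      wspace=0.20, hspace=0.20, facecolor='w'):
    # one subplot per entry in plist; unused panels are hidden
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, facecolor=facecolor)
    fig.subplots_adjust(wspace=wspace, hspace=hspace)
    axes = np.atleast_1d(axes).ravel()
    for ax, pdata in zip(axes, plist):
        plot_fn(pdata, ax)
    for ax in axes[len(plist):]:
        ax.set_visible(False)
    return fig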
Example #2
def draw_neural_joint_relationships(preproc_file):

    hf = h5py.File(preproc_file, 'r')
    index2prop = list(hf.attrs['integer2prop'])
    index2type = list(hf.attrs['integer2type'])
    S = np.array(hf['S'])
    X = np.array(hf['X'])
    Y = np.array(hf['Y'])
    hf.close()

    stim_type = [index2type[k] for k in Y[:, 0]]

    sub_i = [index2prop.index('maxAmp'), index2prop.index('sal'), index2prop.index('meanspect')]
    S = S[:, sub_i]

    ncells = X.shape[1]
    nsamps = S.shape[0]

    csv_file = '/auto/tdrive/mschachter/data/aggregate/big3_spike_rate.csv'
    if not os.path.exists(csv_file):
        # write out a csv file for analysis
        cdata = {'maxAmp':list(), 'sal':list(), 'meanspect':list(), 'call_type':list(), 'cell':list(), 'spike_rate':list()}
        for n in range(ncells):
            for k in range(nsamps):
                cdata['maxAmp'].append(S[k, 0])
                cdata['sal'].append(S[k, 1])
                cdata['meanspect'].append(S[k, 2])
                cdata['call_type'].append(stim_type[k])
                cdata['cell'].append(n)
                cdata['spike_rate'].append(X[k, n])
        cdf = pd.DataFrame(cdata)
        cdf.to_csv(csv_file, header=True, index=False)

    plist = list()
    for n in range(ncells):
        plist.append({'n':n, 'x':S[:, 0], 'y':S[:, 1], 'r':X[:, n], 'xlabel':'maxAmp', 'ylabel':'saliency', 'logx':True})
        plist.append({'n': n, 'x': S[:, 0], 'y': S[:, 2], 'r': X[:, n], 'xlabel': 'maxAmp', 'ylabel': 'meanspect', 'logx':True})
        plist.append({'n': n, 'x': S[:, 1], 'y': S[:, 2], 'r': X[:, n], 'xlabel': 'saliency', 'ylabel': 'meanspect', 'logx':False})

    def _plot_scatter(_pdata, _ax):
        plt.sca(_ax)
        if _pdata['logx']:
            _ax.set_xscale('log')
        """
        _clrs = [CALL_TYPE_COLORS[ct] for ct in stim_type]
        _r = _pdata['r']
        _r -= _r.min()
        _r /= _r.max()
        for _x,_y,_r,_c in zip(_pdata['x'], _pdata['y'], _r, _clrs):
            plt.plot(_x, _y, 'o', c=_c, alpha=_r)
        """
        plt.scatter(_pdata['x'], _pdata['y'], marker='o', c=_pdata['r'], cmap=plt.cm.afmhot_r, s=49, alpha=0.7)

        plt.xlabel(_pdata['xlabel'])
        plt.ylabel(_pdata['ylabel'])

        plt.title('cell %d' % _pdata['n'])
        plt.colorbar(label='Spike Rate')

    multi_plot(plist, _plot_scatter, nrows=3, ncols=3, figsize=(23,13), wspace=0.25, hspace=0.25)
    plt.show()
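
The CSV written above (big3_spike_rate.csv) is only generated once and then analyzed elsewhere. Below is a minimal sketch of loading it back with pandas for a quick per-cell summary; the column names come from the cdata dict above, but the groupby itself is just an illustration, not the analysis the comment refers to.

import pandas as pd

cdf = pd.read_csv('/auto/tdrive/mschachter/data/aggregate/big3_spike_rate.csv')
# mean spike rate per cell and call type
rate_by_type = cdf.groupby(['cell', 'call_type'])['spike_rate'].mean().unstack('call_type')
print(rate_by_type.head())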
Example #3
def draw_figure(data_dir='/auto/tdrive/mschachter/data', bird='GreBlu9508M'):

    pp_file = os.path.join(data_dir, bird, 'preprocess', 'preproc_raw_coherence_band0_npc0_self_locked_all_Site4_Call1_L.h5')
    print 'pp_file=%s' % pp_file

    hf = h5py.File(pp_file, 'r')
    X = np.array(hf['X'])
    S = np.array(hf['S'])
    index2electrode = list(hf.attrs['index2electrode'])
    index2aprop = list(hf.attrs['integer2prop'])
    freqs = list(hf.attrs['freqs'])
    hf.close()

    reduced_aprops = ['fund',
                       'sal',
                       'voice2percent',
                       'meanspect',
                       'skewspect',
                       'entropyspect',
                       'q2',
                       'meantime',
                       'skewtime',
                       'entropytime',
                       'maxAmp']

    nelectrodes = len(index2electrode)
    nfreqs = len(freqs)

    electrode_order = ROSTRAL_CAUDAL_ELECTRODES_LEFT

    X -= X.mean()
    X /= X.std(ddof=1)
    S -= S.mean()
    S /= S.std(ddof=1)

    cc_mats = list()
    for k,aprop in enumerate(reduced_aprops):
        # assumes the columns of S follow the reduced_aprops ordering
        # (index2aprop is loaded above but not used to look up the column)
        s = S[:, k]
        CC = np.zeros([nelectrodes, nfreqs])
        for j,e in enumerate(index2electrode):
            for m,f in enumerate(freqs):

                i = j*nfreqs + m
                x = X[:, i]
                jj = electrode_order.index(e)
                CC[jj, m] = np.corrcoef(x, s)[0, 1]
        cc_mats.append({'CC':CC, 'aprop':aprop})

    def _plot_cc_mat(pdata, ax):
        plt.sca(ax)
        plt.imshow(pdata['CC'], interpolation='nearest', aspect='auto', vmin=-0.5, vmax=0.5, cmap=plt.cm.seismic)
        plt.yticks(range(nelectrodes), ['%d' % _e for _e in electrode_order])
        plt.xticks(range(nfreqs), ['%d' % _f for _f in freqs])
        plt.title(pdata['aprop'])
        plt.colorbar()

    multi_plot(cc_mats, _plot_cc_mat, nrows=4, ncols=3)
    plt.show()
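
The nested electrode/frequency loop above computes one Pearson correlation per column of X. A vectorized sketch of the same computation is shown below; it assumes the columns of X are laid out electrode-major, matching the i = j*nfreqs + m indexing above, and it gives the same CC matrix because correlation is unaffected by the earlier global standardization.

def corr_matrix_sketch(X, s, index2electrode, electrode_order, nfreqs):
    # z-score each column of X and the regressor s, then one dot product
    # yields every Pearson correlation at once
    Xz = (X - X.mean(axis=0)) / X.std(axis=0, ddof=1)
    sz = (s - s.mean()) / s.std(ddof=1)
    cc = Xz.T.dot(sz) / (len(sz) - 1)              # one value per (electrode, freq) column
    CC = cc.reshape(len(index2electrode), nfreqs)  # electrode-major layout assumed
    row_perm = [index2electrode.index(e) for e in electrode_order]
    return CC[row_perm, :]                         # rows in the plotted electrode order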
Example #4
def draw_perf_hists(agg, df_me):

    assert isinstance(agg, AggregateLFPAndSpikePSDDecoder)
    freqs = agg.freqs

    aprops_to_display = ['category', 'maxAmp', 'meanspect', 'stdspect', 'q1', 'q2', 'q3', 'skewspect', 'kurtosisspect',
                         'sal', 'entropyspect', 'meantime', 'stdtime', 'entropytime']

    # make histograms of performances across sites for each acoustic property
    perf_list = list()
    i = (df_me.band == 0)
    for aprop in aprops_to_display:
        lfp_perf = df_me[i]['perf_%s_%s' % (aprop, 'lfp')].values
        spike_perf = df_me[i]['perf_%s_%s' % (aprop, 'spike')].values

        perf_list.append({'lfp_perf':lfp_perf, 'spike_perf':spike_perf,
                          'lfp_mean':lfp_perf.mean(), 'spike_mean':spike_perf.mean(),
                          'aprop':aprop})

    # make plots
    print 'len(perf_list)=%d' % len(perf_list)

    def _plot_hist(pdata, ax):
        plt.sca(ax)
        plt.hist(pdata['lfp_perf'], bins=10, color='#00639E', alpha=0.7)
        plt.hist(pdata['spike_perf'], bins=10, color='#F1DE00', alpha=0.7)
        plt.legend(['LFP', 'Spike'], fontsize='x-small')
        plt.title(pdata['aprop'])
        if pdata['aprop'] == 'category':
            plt.xlabel('PCC')
        else:
            plt.xlabel('R2')
        plt.axis('tight')

    multi_plot(perf_list, _plot_hist, nrows=3, ncols=5, hspace=0.30, wspace=0.30)

    def _plot_scatter(pdata, ax):
        # pmax = max(pdata['lfp_perf'].max(), pdata['spike_perf'].max())
        pmax = 0.8
        plt.sca(ax)
        plt.plot(np.linspace(0, pmax, 20), np.linspace(0, pmax, 20), 'k-')
        plt.plot(pdata['lfp_perf'], pdata['spike_perf'], 'ko', alpha=0.7, markersize=10.)
        plt.title(pdata['aprop'])
        pstr = 'R2'
        if pdata['aprop'] == 'category':
            pstr = 'PCC'
        plt.xlabel('LFP %s' % pstr)
        plt.ylabel('Spike %s' % pstr)
        plt.axis('tight')
        plt.xlim(0, pmax)
        plt.ylim(0, pmax)

    multi_plot(perf_list, _plot_scatter, nrows=3, ncols=5, hspace=0.30, wspace=0.30)
Example #5
def draw_all_encoder_perfs_and_decoder_weights(agg, aprops=('sal', 'q2', 'maxAmp', 'meantime', 'entropytime')):

    freqs,lags = get_freqs_and_lags()

    font = {'family':'normal', 'weight':'bold', 'size':10}
    plt.matplotlib.rc('font', **font)

    plist = list()

    for (bird,block,segment,hemi),gdf in agg.df.groupby(['bird', 'block', 'segment', 'hemi']):

        bstr = '%s_%s_%s_%s' % (bird,hemi,block,segment)
        ii = (gdf.decomp == 'self_locked')
        assert ii.sum() == 1
        wkey = gdf[ii]['wkey'].values[0]

        lfp_eperf = agg.encoder_perfs[wkey]

        plist.append({'type':'encoder', 'X':lfp_eperf, 'title':bstr})

        lfp_decoder_weights = agg.decoder_weights[wkey]
        lfp_decoder_perfs = agg.decoder_perfs[wkey]

        for k,aprop in enumerate(aprops):
            ai = USED_ACOUSTIC_PROPS.index(aprop)
            lfp_weights = lfp_decoder_weights[:, :, ai]
            dstr = '%0.2f (%s)' % (lfp_decoder_perfs[ai], aprop)
            plist.append({'type':'decoder', 'X':lfp_weights, 'title':dstr})

    def _plot_X(_pdata, _ax):
        plt.sca(_ax)
        _X = _pdata['X']
        if _pdata['type'] == 'decoder':
            _absmax = np.abs(_X).max()
            _vmin = -_absmax
            _vmax = _absmax
            _cmap = plt.cm.seismic
        else:
            _vmin = 0.
            _vmax = 0.35
            _cmap = magma

        plt.imshow(_X, interpolation='nearest', aspect='auto', cmap=_cmap, vmin=_vmin, vmax=_vmax)
        plt.title(_pdata['title'])
        plt.xticks([])
        plt.yticks([])

    multi_plot(plist, _plot_X, nrows=5, ncols=6, figsize=(23, 13))
    plt.show()
Example #6
def draw_lags_vs_perf(data_dir='/auto/tdrive/mschachter/data'):

    pfile = os.path.join(data_dir, 'GreBlu9508M', 'transforms', 'PairwiseCF_GreBlu9508M_Site4_Call1_L_raw.h5')
    hf = h5py.File(pfile, 'r')
    full_lags_ms = hf.attrs['lags']
    hf.close()

    max_lag = full_lags_ms.max()

    agg_file = os.path.join(data_dir, 'aggregate', 'pard.h5')
    agg = AcousticEncoderDecoderAggregator.load(agg_file)

    print 'keys=',agg.df.keys()

    lag_bnds = [ (-1., 1.), (-6., 6.), (-13., 13.), (-18., 18), (-23., 23), (-28., 28.), (-34., 34), (None,None)]
    perfs_by_bound = list()
    for lb,ub in lag_bnds:
        if lb is None:
            decomp = 'self+cross_locked'
        else:
            decomp = 'self+cross_locked_lim_%d_%d' % (lb, ub)

        i = agg.df.decomp == decomp
        print 'decomp=%s, i.sum()=%d' % (decomp, i.sum())
        wkeys = agg.df.wkey[i].values

        dperfs = np.array([agg.decoder_perfs[wkey] for wkey in wkeys])
        perfs_by_bound.append(dperfs)

    perfs_by_bound = np.array(perfs_by_bound)

    pbb_mean = perfs_by_bound.mean(axis=1)
    pbb_std = perfs_by_bound.std(axis=1)

    plist = list()
    for k,aprop in enumerate(REDUCED_ACOUSTIC_PROPS):
        plist.append({'aprop':aprop, 'mean':pbb_mean[:, k], 'std':pbb_std[:, k]})

    def _plot_pbb(pdata, ax):
        plt.sca(ax)
        _lagw = [2*x[1] for x in lag_bnds if x[1] is not None]
        _lagw.append(max_lag*2)
        plt.plot(_lagw, pdata['mean'], 'k-', linewidth=3.0, alpha=0.7)
        plt.xlabel('Lag Width (ms)')
        plt.ylabel('R2')
        plt.axis('tight')
        plt.title(pdata['aprop'])

    multi_plot(plist, _plot_pbb, nrows=3, ncols=4)
    plt.show()
Example #7
def draw_figures(data_dir='/auto/tdrive/mschachter/data'):

    bird = 'GreBlu9508M'
    block = 'Site4'
    seg = 'Call1'
    hemi = 'L'
    file_ext = '_'.join([bird, block, seg, hemi])

    pcf_file = os.path.join(data_dir, bird, 'transforms', 'PairwiseCF_%s_%s_new.h5' % (file_ext, 'raw'))
    pcf = PairwiseCFTransform.load(pcf_file)

    g = pcf.df.groupby(['stim_id', 'order', 'stim_type'])

    plist = list()

    i = (pcf.df.decomp == 'full') & (pcf.df.electrode1 == pcf.df.electrode2)
    assert i.sum() > 0
    xfull_indices = list(pcf.df['index'][i].values)
    print 'len(pcf.df)=%d' % len(pcf.df)
    print 'pcf.df[index].max()=%d' % pcf.df['index'].max()
    print 'pcf.psds.shape=',pcf.psds.shape
    print 'xfull_indices.max()=%d' % max(xfull_indices)
    print 'len(xfull_indices)=%d' % len(xfull_indices)
    Xfull = pcf.psds[xfull_indices, :]
    Xfull /= Xfull.max()
    pcf.log_transform(Xfull)
    Xfull -= Xfull.mean(axis=0)
    Xfull /= Xfull.std(axis=0, ddof=1)

    # take log transform of power spectra
    i = (pcf.df.decomp == 'onewin') & (pcf.df.electrode1 == pcf.df.electrode2)
    assert i.sum() > 0
    xone_indices = list(pcf.df['index'][i].values)
    Xonewin = pcf.psds[xone_indices, :]
    Xonewin /= Xonewin.max()
    pcf.log_transform(Xonewin)
    Xonewin -= Xonewin.mean(axis=0)
    Xonewin /= Xonewin.std(axis=0, ddof=1)

    for (stim_id,order,stim_type),gdf in g:

        electrodes = gdf.electrode1.unique()
        stim_dur = gdf.stim_duration.values[0]
        if stim_dur < 0.050 or stim_dur > 0.400:
            continue

        for e in electrodes:
            i = (gdf.decomp == 'full') & (gdf.electrode1 == e) & (gdf.electrode1 == gdf.electrode2)
            assert i.sum() == 1

            xi = gdf['index'][i].values[0]
            ii = xfull_indices.index(xi)
            full_psd = Xfull[ii, :]

            i = (gdf.decomp == 'onewin') & (gdf.electrode1 == e) & (gdf.electrode1 == gdf.electrode2)
            assert i.sum() == 1
            xi = gdf['index'][i].values[0]
            ii = xone_indices.index(xi)
            onewin_psd = Xonewin[ii, :]

            plist.append({'full_psd':full_psd, 'onewin_psd':onewin_psd, 'stim_id':stim_id, 'stim_order':order,
                          'stim_type':stim_type, 'electrode':e, 'stim_dur':stim_dur})

    np.random.shuffle(plist)
    plist.sort(key=operator.itemgetter('stim_dur'))
    short_plist = [x for k,x in enumerate(plist) if k % 20 == 0]
    print 'len(short_plist)=%d' % len(short_plist)

    def _plot_psds(_pdata, _ax):
        absmax = max(np.abs(_pdata['full_psd']).max(), np.abs(_pdata['onewin_psd']).max())
        plt.axhline(0, c='k')
        plt.plot(pcf.freqs, _pdata['full_psd'], 'k-', linewidth=3.0, alpha=0.7)
        plt.plot(pcf.freqs, _pdata['onewin_psd'], 'g-', linewidth=3.0, alpha=0.7)
        plt.title('e%d: %d_%d (%s) %0.3fs' % (_pdata['electrode'], _pdata['stim_id'], _pdata['stim_order'], _pdata['stim_type'], _pdata['stim_dur']))
        plt.axis('tight')
        plt.ylim(-absmax, absmax)

    multi_plot(short_plist, _plot_psds, nrows=5, ncols=9)
    plt.show()
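
PairwiseCFTransform.log_transform is called above purely for its in-place side effect on the array; its implementation is not part of this excerpt. The sketch below shows one common pattern for such a helper (a decibel transform with a noise floor) and is strictly a hypothetical stand-in, not the project's code.

import numpy as np

def log_transform_sketch(x, dbnoise=100.0):
    # convert normalized power to dB in place, flooring at zero so that
    # empty bins do not become -inf
    nz = x > 0
    x[nz] = 20.0 * np.log10(x[nz]) + dbnoise
    x[x < 0] = 0.0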
Example #8
def draw_freq_lkrats(agg, df_me):

    # aprops_to_display = ['category', 'maxAmp', 'meanspect', 'stdspect', 'q1', 'q2', 'q3', 'skewspect', 'kurtosisspect',
    #                      'sal', 'entropyspect', 'meantime', 'stdtime', 'entropytime']

    #aprops_to_display = ['category', 'maxAmp', 'sal',
    #                     'entropytime', 'meanspect', 'entropyspect',
    #                     'q1', 'q2', 'q3']
    aprops_to_display = ['category', 'maxAmp', 'sal',
                         'q1', 'q2', 'q3']

    assert isinstance(agg, AggregateLFPAndSpikePSDDecoder)
    freqs = agg.freqs
    nbands = len(freqs)

    # compute the significance threshold for each site, use it to normalize the likelihood ratio
    i = agg.df.bird != 'BlaBro09xxF'
    g = agg.df[i].groupby(['bird', 'block', 'segment', 'hemi'])

    num_sites = len(g)

    normed_lkrat_lfp = dict()
    normed_lkrat_spike = dict()
    for aprop in aprops_to_display:
        normed_lkrat_lfp[aprop] = np.zeros([num_sites, nbands])
        normed_lkrat_spike[aprop] = np.zeros([num_sites, nbands])

    for k,((bird,block,segment,hemi),gdf) in enumerate(g):

        i = (gdf.e1 != -1) & (gdf.e1 == gdf.e2) & (gdf.cell_index != -1) & (gdf.decomp == 'spike_psd') & \
            (gdf.exel == False) & (gdf.exfreq == False) & (gdf.aprop == 'q2')
        ncells = i.sum()
        # print '%s,%s,%s,%s # of cells: %d' % (bird, block, segment, hemi, ncells)

        # get the likelihood ratios and normalize them by threshold
        for aprop in aprops_to_display:

            for b in range(1, nbands+1):
                i = (df_me.bird == bird) & (df_me.block == block) & (df_me.segment == segment) & (df_me.hemi == hemi) & \
                    (df_me.band == b)
                if i.sum() != 1:
                    print 'i.sum()=%d, b=%d, (%s,%s,%s,%s)' % (i.sum(), b, bird, block, segment, hemi)
                assert i.sum() == 1

                lkrats_lfp = df_me[i]['lkrat_%s_%s' % (aprop, 'lfp')].values[0]
                lkrats_spike = df_me[i]['lkrat_%s_%s' % (aprop, 'spike')].values[0]

                normed_lkrat_lfp[aprop][k, b-1] = lkrats_lfp
                normed_lkrat_spike[aprop][k, b-1] = lkrats_spike

    # make a list of data for multi plot
    plist = list()
    for aprop in aprops_to_display:
        plist.append({'lfp':normed_lkrat_lfp[aprop], 'spike':normed_lkrat_spike[aprop], 'freqs':freqs, 'aprop':aprop})

    def _plot_freqs(pdata, ax):
        plt.sca(ax)

        nsamps_lfp = len(pdata['lfp'])
        lkrat_lfp_mean = pdata['lfp'].mean(axis=0)
        lkrat_lfp_std = pdata['lfp'].std(axis=0, ddof=1) / np.sqrt(nsamps_lfp)

        nsamps_spike = len(pdata['spike'])
        lkrat_spike_mean = pdata['spike'].mean(axis=0)
        lkrat_spike_std = pdata['spike'].std(axis=0, ddof=1) / np.sqrt(nsamps_spike)

        if pdata['aprop'] != 'category':
            plt.axhline(1.0, c='k', linestyle='dashed', alpha=0.7, linewidth=2.0)
        plt.errorbar(pdata['freqs'], lkrat_lfp_mean, yerr=lkrat_lfp_std, c=COLOR_BLUE_LFP, linewidth=7.0, alpha=0.9, ecolor='k', elinewidth=2.0)
        plt.errorbar(pdata['freqs']+2., lkrat_spike_mean, yerr=lkrat_spike_std, c=COLOR_YELLOW_SPIKE, linewidth=7.0, alpha=0.9, ecolor='k', elinewidth=2.0)

        plt.xlabel('Frequency (Hz)')
        plt.ylabel('Normalized LR')
        leg = custom_legend([COLOR_BLUE_LFP, COLOR_YELLOW_SPIKE], ['LFP', 'Spike'])
        plt.legend(handles=leg, fontsize='x-small')
        plt.title(pdata['aprop'])
        plt.axis('tight')
        if pdata['aprop'] != 'category':
            plt.ylim(0, 5)
        else:
            plt.axis('tight')

    figsize = (24, 9)
    multi_plot(plist, _plot_freqs, nrows=2, ncols=3, hspace=0.45, wspace=0.45, facecolor='w', figsize=figsize)
    fname = os.path.join(get_this_dir(), 'perf_by_freq.svg')
    plt.savefig(fname, facecolor='w', edgecolor='none')
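
custom_legend builds legend handles directly from colors and labels; the helper itself is not shown here. A minimal stand-in using matplotlib proxy artists (assumed behavior, not the project's actual implementation) would be:

from matplotlib.lines import Line2D

def custom_legend_sketch(colors, labels, linewidth=4.0):
    # proxy Line2D handles, one per color/label pair, for plt.legend(handles=...)
    return [Line2D([0], [0], color=c, linewidth=linewidth, label=l)
            for c, l in zip(colors, labels)]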
Example #9
            t_c = 0.5 * (1.0 / t_freq) - 0.010
            strf = onset_strf(t,
                              f,
                              t_freq=t_freq,
                              t_phase=np.pi,
                              f_c=f_c,
                              f_sigma=1000.0,
                              t_sigma=t_sigma,
                              t_c=t_c)
            title = r'$f_c$=%dHz, $\sigma_t$=%dms, $f_t$=%dHz' % (f_c, t_sigma * 1e3, t_freq)
            onset_plist.append({'strf': strf, 'title': title})

    multi_plot(onset_plist,
               plot_strf,
               nrows=len(onset_f_c),
               ncols=len(onset_t_sigmas))

    #build harmonic stack STRFs
    stack_t_sigma = 0.005
    stack_f_sigma = 1500
    stack_f_c = np.linspace(300.0, 8000.0, 10)
    stack_f_freq = np.linspace(1e-4, 7e-4, 5)

    stack_t_freqs = np.array([20.0, 15.0, 10.0, 5.0])

    stack_plist = list()
    for f_c in stack_f_c:
        for f_freq in stack_f_freq:
            strf = checkerboard_strf(t,
                                     f,
Example #10
    #build onset STRFs of varying center frequency and temporal bandwidths
    onset_f_sigma = 500
    onset_f_c = np.linspace(300.0, 8000.0, 10)
    onset_t_sigmas = np.array([0.005, 0.010, 0.025, 0.050])
    onset_t_freqs = np.array([20.0, 15.0, 10.0, 5.0])

    onset_plist = list()
    for f_c in onset_f_c:
        for t_sigma,t_freq in zip(onset_t_sigmas, onset_t_freqs):

            t_c = 0.5*(1.0 / t_freq) - 0.010
            strf = onset_strf(t, f, t_freq=t_freq, t_phase=np.pi, f_c=f_c, f_sigma=1000.0, t_sigma=t_sigma, t_c=t_c)
            title = r'$f_c$=%dHz, $\sigma_t$=%dms, $f_t$=%dHz' % (f_c, t_sigma*1e3, t_freq)
            onset_plist.append({'strf':strf, 'title':title})

    multi_plot(onset_plist, plot_strf, nrows=len(onset_f_c), ncols=len(onset_t_sigmas))
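
    # onset_strf (and checkerboard_strf below) are project helpers whose code is
    # not part of this excerpt. The sketch below is a generic Gabor-style
    # construction that only mirrors the parameter names used above;
    # onset_strf_sketch is a hypothetical stand-in, not the project's definition
    # (assumes numpy is imported as np, as elsewhere in this function).
    def onset_strf_sketch(t, f, t_freq, t_phase, f_c, f_sigma, t_sigma, t_c):
        T, F = np.meshgrid(t, f)                            # (nfreq, ntime) grids
        f_part = np.exp(-0.5 * ((F - f_c) / f_sigma)**2)    # Gaussian frequency tuning around f_c
        t_part = np.exp(-0.5 * ((T - t_c) / t_sigma)**2) * \
                 np.cos(2*np.pi*t_freq*(T - t_c) + t_phase) # modulated temporal envelope around t_c
        return f_part * t_part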


    #build harmonic stack STRFs
    stack_t_sigma = 0.005
    stack_f_sigma = 1500
    stack_f_c = np.linspace(300.0, 8000.0, 10)
    stack_f_freq = np.linspace(1e-4, 7e-4, 5)

    stack_t_freqs = np.array([20.0, 15.0, 10.0, 5.0])

    stack_plist = list()
    for f_c in stack_f_c:
        for f_freq in stack_f_freq:
            strf = checkerboard_strf(t, f,
                                     t_freq=10.0, t_phase=0.0,