Ejemplo n.º 1
0
def overlay_markers(c1='w',c2='k',FS=1000.,nevents=3,fontsize=14,npad=None,labels=None,clip_on=False):
    '''
    Draw vertical marker lines at the task event times on the current axes.

    Events sit at 1000, 2000 and 4000 (divided by FS to get x-axis units).
    Only the first `nevents` are considered, and events on or outside the
    current x-limits are skipped. Each marker is a thick `c2` line under a
    thin `c1` line. The x-limits are restored on exit.

    Parameters:
        c1: colour of the thin inner line
        c2: colour of the thick outer line
        FS (numeric): divisor converting the event times to x-axis units
        nevents (int): number of events to consider, in order
        fontsize (numeric): label font size
        npad: unused; kept for backwards compatibility
        labels: default None. Can be
            "markers" for symbols or
            "names" for ['Object presented','Grip cued','Go cue']
            "short" for ['Object','Grip','Go cue']
            Any other value draws the lines without text.
        clip_on (bool): passed to plot() for the marker lines
    '''
    a,b = xlim()
    dx,dy = get_ax_pixel()
    debug( 'dx,dy=%s,%s'%(dx,dy))
    if labels=='markers':
        names = [r'$\circ$',r'$\diamond$',r'$\star$']
    elif labels=='names':
        names = ['Object presented','Grip cued','Go cue']
    elif labels=='short':
        names = ['Object','Grip','Go cue']
    else:
        names = (None,)*3
    locations = [1000,2000,4000]
    # list() so slicing also works on Python 3, where zip is an iterator
    for time,label in list(zip(locations,names))[:nevents]:
        time = float(time)/FS
        if time<=a or time>=b: continue
        plot([time,time],ylim(),color=c2,lw=3,zorder=Inf,clip_on=clip_on)
        plot([time,time],ylim(),color=c1,lw=1,zorder=Inf,clip_on=clip_on)
        # Only draw text for a real label: an unrecognised `labels` value
        # leaves label None, which must not be passed to text().
        if label is not None:
            text(time,ylim()[1]+dy*4,label,
                rotation=0,color='k',fontsize=fontsize,
                horizontalalignment='center',verticalalignment='bottom')
    xlim(a,b)
Ejemplo n.º 2
0
def plotISIhistHz(session,area,unit,start,stop,
    FS=1000,style='bar',color='k',nbins=30,label=None,
    fmin=0,FMAX=100,
    linestyle='-',lw=2,w=0.8):
    '''
    Plot a histogram of instantaneous firing frequency (FS / ISI) for one unit.

    ISIs are pooled over every trial of `ByTrialSpikesMS` between start and
    stop. Non-finite frequencies (from zero-length ISIs) are discarded
    before plotting.

    Parameters:
        FS (numeric): numerator converting ISIs to frequency (Hz if ISIs
            are in samples of FS Hz)
        style ('bar' or 'line'): bar histogram with the mode marked in red,
            or a line through the bin counts
        color: bar face colour / line colour
        nbins (int): number of bins between fmin and FMAX
        label: legend label
        fmin, FMAX (numeric): histogram range
        linestyle: line format for style='line'
        lw (numeric): width of the mode line (style='bar')
        w (numeric): relative bar width (style='bar')

    Raises:
        ValueError: if style is not 'bar' or 'line'.
    '''
    ByTrialSpikesMS = metaloadvariable(session,area,'ByTrialSpikesMS')
    NUNITS, NTRIALS = shape(ByTrialSpikesMS)
    isi = []
    for it in range(NTRIALS):
        # getTrial uses 1-based trial numbers
        isi.extend(diff(getTrial(session,area,unit,start,stop,it+1)))
    # float division (no Py2 integer truncation); drop inf/NaN from zero ISIs
    freq = FS/array(isi,dtype=float)
    freq = freq[~isnan(freq) & ~isinf(freq)]
    if style=='bar':
        hist(freq,linspace(fmin,FMAX,nbins+1),facecolor=color,label=label,rwidth=w)
        mf = modefind(freq)
        a,b = ylim()
        plot((mf,)*2,ylim(),color='r',lw=lw)
        ylim(a,b)
        nicelimits()
        if (not isnan(mf)) and (not isinf(mf)):
            xticks(int32([xlim()[0],int(round(mf)),xlim()[1]]))
    elif style=='line':
        # Bin over the same range used for the bin centres below
        h,_ = histogram(freq,bins=nbins,range=(fmin,FMAX))
        x = linspace(fmin,FMAX,1+nbins)
        x = 0.5*(x[1:]+x[:-1])
        plot(x,h,linestyle,lw=1,color=color,label=label)
    else:
        raise ValueError("Plot style must be line or bar")
    bareaxis(gca())
    xlabel('Frequency (Hz)')
    ylabel('N')
    title('ISI histogram\n(frequency)')
    gca().yaxis.labelpad = -20
Ejemplo n.º 3
0
def plotLFP(lfp):
    '''
    Plot an LFP trace on the cleared current axes.

    The x-axis is sample index divided by the module-level `FS` (labelled
    seconds; TODO confirm FS is defined at module level before calling).
    The y-axis is fixed to +/-150 (labelled microvolts).

    Parameters:
        lfp (1-D array-like): the trace to plot
    '''
    cla()
    # Derive the time base from the trace itself; the original used the
    # undefined-here globals `start`/`stop`, which breaks whenever they do
    # not match len(lfp).
    plot(arange(len(lfp))/FS,lfp,'k')
    ylim(-150,150)
    yticks([-150,0,150])
    ylabel('Microvolts')
    xlabel('Time (s)')
    simpleraxis(gca())
Ejemplo n.º 4
0
def allsession_summary_plot(statistic,monkey,area,fa,fb,color1=None,color2=None,colors=None,drawCues=False,smoothat=None,dolegend=False):
    '''
    allsession_summary_plot(array_average_ampltiude)

    Pool `statistic` over every session of `monkey` in `area` (epoch
    (6,-1000,6000), band fa-fb Hz) and plot the across-trial median with
    the interquartile range shaded.

    Parameters:
        statistic: function passed to onsession; its __name__ is the title
        monkey (str): sessions whose name starts with monkey[0] are used
        area (str): area to include
        fa, fb: band edges passed to onsession
        color1: fill colour of the interquartile band
        color2: colour of the median line
        colors: unused
        drawCues (bool): overlay task events
        smoothat: if not None, box-filter width applied to each trial and
            to every summary curve
        dolegend (bool): draw a legend

    Returns:
        m,s,sem,M,Q1,Q2,P0,P9 -- mean, std, 1.96*std/sqrt(N), median,
        lower/upper quartile and ~10th/~90th percentile curves.

    Raises:
        ValueError: if no session/trial matches monkey and area.
    '''
    filter_function = box_filter
    epoch = (6,-1000,6000)
    debug('!>>> applying',statistic,monkey,area,epoch,fa,fb)
    all_res = []
    times = None
    for _s,_a in sessions_areas():
        if _s[0]!=monkey[0]: continue
        if _a!= area: continue
        times,res=onsession(statistic,_s,area,epoch,fa,fb)
        if not smoothat is None:
            res = arr([filter_function(x,smoothat) for x in res])
        all_res.extend(res)
    if times is None or not len(all_res):
        raise ValueError('no data for monkey %s area %s'%(monkey,area))

    # Summaries are over the pooled trials of ALL sessions (the original
    # accidentally used only the last session's `res`).
    res = arr(all_res)
    offset = (7000-len(times))/2
    times = arange(len(times))+offset
    ss = sort(res,axis=0)
    N = shape(res)[0]
    # Order-statistic indices; clamp to the last row because for small N
    # the rounded index can equal N.
    last = N-1
    p0 = min(int(N*10/100+.4),last)
    p9 = min(int(N*90/100+.4),last)
    # floor division keeps these valid integer indices on Python 3 too
    q1 = N//4
    q3 = N*3//4
    Q1 = ss[q1,:]
    Q2 = ss[q3,:]
    P0 = ss[p0,:]
    P9 = ss[p9,:]
    m = mean(res,0)
    M = median(res,0)
    s = std(res,0)
    sem = 1.96*s/sqrt(N)

    if not smoothat is None:
        m   = filter_function(m,smoothat)
        M   = filter_function(M,smoothat)
        s   = filter_function(s,smoothat)
        sem = filter_function(sem,smoothat)
        Q1 = filter_function(Q1,smoothat)
        Q2 = filter_function(Q2,smoothat)
        P0 = filter_function(P0,smoothat)
        P9 = filter_function(P9,smoothat)

    fill_between(times,Q1,Q2,color=color1,lw=0,zorder=2)
    plot(times,M,color=color2,lw=2,zorder=2,label='Median')

    if dolegend: nicelegend()
    name = ' '.join(statistic.__name__.replace('_',' ').split()).title()
    title('%s, %s-%sHz, %s %s'%(name,fa,fb,monkey,area))
    xlabel('Time (ms)')
    ylabel(name)
    xlim(times[0],times[-1])
    if drawCues:
        overlayEvents('k','w',FS=1)
    draw()

    return m,s,sem,M,Q1,Q2,P0,P9
Ejemplo n.º 5
0
def ppc(session,area,unit,start,stop,
    window=100,FMAX=250,color='k',label=None,nTapers=None,lw=1.5,linestyle='-'):
    '''
    Plot the unbiased pairwise phase consistency (PPC) spectrum of one unit.

    Spike times come from the epoch (6, start, stop) of the session and the
    LFP from the raw session signal of the unit's channel (so spikes near
    the edge of a block still get a full LFP window). PPC is computed by
    pairwise_phase_consistancy with a Hann taper, Fs=1000 and delta=100;
    frequencies in (0, 250] Hz are plotted.

    Parameters:
        window: passed through to pairwise_phase_consistancy
        FMAX (numeric): upper x-limit and tick range (Hz)
        color, label, lw, linestyle: line styling
        nTapers: ignored (multitaper is no longer used); warns if given
    '''
    if not nTapers is None:
        warn('WARNING: no longer using multitaper, nTapers will be ignored! using Hann window')
    # Earlier versions computed PPC directly from getSTLFP snippets and from
    # per-block trial LFP; both were replaced (inconsistent with the MATLAB
    # PPC code, and per-block data drops spikes near block edges).
    channel = get_unit_channel(session,area,unit)
    signal = get_raw_lfp_session(session,area,channel)
    times  = get_spikes_session_filtered_by_epoch(session,area,unit,(6,start,stop))

    (freqs,unbiased,phases),snippits = pairwise_phase_consistancy(signal,times,
        window=window,
        Fs=1000,
        multitaper=False,
        biased=False,
        delta=100,
        taper=hann)
    use = (freqs>0.)&(freqs<=250.)
    plot(freqs[use],unbiased[use],linestyle,color=color,label=label,lw=lw)
    ylim(0,0.5)
    ylabel('PPC',labelpad=-10)
    xlim(0,FMAX)
    nicelimits()
    # Tick positions and labels must have the same length (was 6 vs 11).
    ticks = linspace(0,FMAX,6)
    xticks(ticks,['%d'%x for x in ticks])
    xlabel('Frequency (Hz)')
    title('Pairwise phase consistency')
    simpleaxis(gca())
Ejemplo n.º 6
0
def plot_phase_gradient(dz):
    '''
    Show the phase of a 2-D complex field with one segment per cell.

    The image is angle(dz) in the HSV colormap. At each cell (row i,
    column j) a white segment is drawn from (j, i) along 5 * dz[i][j].
    The y-axis is flipped so row 0 is at the top.

    Parameters:
        dz (2-D complex array): the field to display
    '''
    cla()
    imshow(angle(dz),interpolation='nearest')
    hsv()
    scale = 5
    for row_index,row in enumerate(dz):
        for col_index,value in enumerate(row):
            step = value*scale
            plot([col_index,col_index+real(step)],
                 [row_index,row_index+imag(step)],'w',lw=1)
    nrows,ncols = shape(dz)
    xlim(-0.5,ncols-0.5)
    ylim(nrows-0.5,-0.5)
Ejemplo n.º 7
0
def do_unit_ISI_plot(session, area, unit,
    INFOXPOS = 70, LABELSIZE=8, NBINS=20, FMAX = 250):
    '''
    Plot one unit's inter-spike-interval (ISI) histogram with summary text.

    Spikes are taken from the window -1000..0 around events 6 and 8 of
    every good trial. ISIs are binned into NBINS bins from 0 to FMAX (the
    x-axis is labelled ms despite the parameter name), a smoothed density
    (kdepeak on log(ISI+20), rescaled to counts) is overlaid, and mean rate,
    ISI CV, SNR and the ISI mode are written at x=INFOXPOS; the mode is
    also marked with a vertical line.

    Returns:
        (mean_rate, ISI_cv, SNR, mode)
    '''
    cla()
    trials = get_good_trials(session)
    spikes = []
    for trial in trials:
        for event in (6,8):
            spikes.append(cgid.spikes.get_spikes_event(
                session,area,unit,trial,event,-1000,0))
    ISI_events = array(list(flatten(map(diff,spikes))))
    SNR        = get_unit_SNR(session,area,unit)
    histc,edges = histogram(ISI_events, bins = linspace(0,FMAX,NBINS+1))
    dx         = diff(edges)[0]
    bar(edges[:-1]+dx*0.1,histc,width=dx*0.8,color=GATHER[-1],edgecolor=(0,)*4)
    allisi = array(ISI_events)
    # Density estimated in log(ISI+K) space, then mapped back to ISI and
    # scaled to histogram counts (Jacobian 1/(K+x), times N*binwidth).
    K   = 20
    x,y = kdepeak(log(K+allisi[allisi>0]))
    x   = exp(x)-K
    y   = y/(K+x)
    y   = y*len(allisi)*dx
    plot(x,y,color=RUST,lw=1.5)
    # two 1000-ms windows per trial -> spikes per second
    mean_rate = sum(map(len,spikes))/float(len(trials)*2)
    ISI_cv    = std(allisi)/mean(allisi)
    mode      = modefind(allisi)
    LH = LABELSIZE+4
    info = ['Mean rate = %d Hz'%mean_rate,
            'ISI CV = %0.2f'%ISI_cv,
            'SNR = %0.1f'%SNR,
            'Mode freq. = %0.1f'%mode]
    for k,line in enumerate(info):
        text(INFOXPOS,ylim()[1]-pixels_to_yunits(20+LH*k),line,
            horizontalalignment='left',
            verticalalignment  ='bottom',fontsize=LABELSIZE)
    axvline(mode,lw=2,color=TURQUOISE)
    xlabel('ms',fontsize=LABELSIZE)
    ylabel('No. Events',fontsize=LABELSIZE)
    fudgey(10)
    fudgex(5)
    xlim(0,FMAX)
    nicex()
    nicey()
    title('Monkey %s area %s\nsession %s unit %s'%(session[0],area,session[-2:],unit),loc='center',fontsize=7)
    return mean_rate,ISI_cv,SNR,mode
Ejemplo n.º 8
0
def unit_ISI_plot(session,area,unit,epoch=((6,-1000,0),(8,-1000,0)),INFOXPOS=70,LABELSIZE=8,NBINS=20,TMAX=300,INFOYSTART=0,BURST=10):
    '''
    Plot one unit's inter-spike-interval (ISI) histogram with summary text.

    Parameters:
        epoch: one (event, start, stop) triple, or a sequence of such
            triples; start/stop are in ms relative to the event.
        INFOXPOS, INFOYSTART: position of the info text (x in data units,
            y offset in pixels from the top)
        LABELSIZE: font size
        NBINS, TMAX: histogram bins from 0 to TMAX ms
        BURST (ms): ISIs <= BURST are excluded from the CV

    The mean rate is spike count divided by the total analysed time
    (number of good trials times the summed epoch lengths).
    '''
    cla()
    # Accept a single triple or a sequence of triples. The original decided
    # this with a bare try/except, which also hid genuine loading errors.
    if len(epoch) and hasattr(epoch[0],'__len__'):
        epochs = list(epoch)
    else:
        epochs = [epoch]
    trials = get_good_trials(session)
    spikes = []
    for trial in trials:
        for e,a,b in epochs:
            spikes.append(cgid.spikes.get_spikes_event(session,area,unit,trial,e,a,b))
    ISI_events = array(list(flatten(map(diff,spikes))))
    SNR        = cgid.spikes.get_unit_SNR(session,area,unit)
    histc,edges = histogram(ISI_events, bins = linspace(0,TMAX,NBINS+1))
    dx         = diff(edges)[0]
    bar(edges[:-1]+dx*0.1,histc,width=dx*0.8,color=GATHER[-1],edgecolor=(0,)*4)
    allisi = array(ISI_events)
    # Density estimated in log(ISI+K) space, mapped back and scaled to counts
    K   = 20
    x,y = kdepeak(log(K+allisi[allisi>0]))
    x   = exp(x)-K
    y   = y/(K+x)
    y   = y*len(allisi)*dx
    plot(x,y,color=RUST,lw=1.5)
    # total analysed time in seconds (epoch bounds are in ms); equals
    # 2*ntrials for the default two 1000-ms epochs
    total_seconds = len(trials)*sum(b-a for e,a,b in epochs)/1000.0
    mean_rate = sum(map(len,spikes))/float(total_seconds)
    noburst = allisi[allisi>BURST]
    ISI_cv  = std(noburst)/mean(noburst)
    LH = LABELSIZE+4
    info = ['Mean rate = %d Hz'%mean_rate,
            'ISI CV = %0.2f'%ISI_cv,
            'SNR = %0.1f'%SNR]
    for k,line in enumerate(info):
        text(INFOXPOS,ylim()[1]-pixels_to_yunits(INFOYSTART+LH*k),line,
            horizontalalignment='left',
            verticalalignment  ='bottom',fontsize=LABELSIZE)
    xlabel('ms',fontsize=LABELSIZE)
    ylabel('No. Events',fontsize=LABELSIZE)
    fudgey(10)
    fudgex(5)
    xlim(0,TMAX)
    nicex()
    nicey()
    simpleaxis()
    title('Monkey %s area %s\nsession %s unit %s'%(session[0],area,session[-2:],unit),loc='center',fontsize=7)
Ejemplo n.º 9
0
def coherence(session,area,unit,window=100,FMAX=250):
    '''
    Spike-field coherence spectrum -- DISABLED.

    This function is untested and always raises AssertionError. The draft
    below is kept for reference only: it calls getSTLFP without the
    start/stop arguments used elsewhere (e.g. plotSTA) and relies on a
    module-level FS, so it needs fixing before it can be enabled.

    Raises:
        AssertionError: always.
    '''
    # Explicit raise instead of `assert False`: assert statements are
    # stripped under `python -O`, which would let the broken draft run.
    raise AssertionError('coherence() is untested; do not use it')
    snippits  = getSTLFP(session,area,unit,window)
    M         = shape(snippits)[0]
    fs        = fft(snippits)
    raw2  = (abs(sum(fs,0))/sum(abs(fs),0))**2
    freqs = fftfreq(window*2+1,1./FS)
    use   = (freqs>0.)&(freqs<=250.)
    plot(freqs[use],raw2[use],color='r')
    ylim(0,0.5)
    nicelimits()
    ylabel('Coherence',labelpad=-10)
    xlim(0,FMAX)
    xticks(linspace(0,FMAX,11),['%d'%x for x in linspace(0,FMAX,11)])
    xlabel('Frequency (Hz)')
    title('Coherence')
Ejemplo n.º 10
0
def plotISIhist(session,area,unit,start,stop):
    '''
    Plot one unit's ISI histogram (0-120, 30 bins) pooled over all trials.

    The ISI mode (modefind) is drawn as a red vertical line and used as
    the middle x-tick between the axis limits.
    '''
    cla()
    spike_table = metaloadvariable(session,area,'ByTrialSpikesMS')
    _, ntrials = shape(spike_table)
    intervals = []
    # getTrial takes 1-based trial numbers
    for trial_number in range(1,ntrials+1):
        intervals.extend(diff(getTrial(session,area,unit,start,stop,trial_number)))
    hist(intervals,linspace(0,120,31),facecolor='k')
    ylo,yhi = ylim()
    isi_mode = modefind(intervals)
    plot((isi_mode,)*2,ylim(),color='r',lw=2)
    ylim(ylo,yhi)
    nicelimits()
    left,right = xlim()
    xticks(int32([left,int(round(isi_mode)),right]))
    bareaxis(gca())
    xlabel('Time (ms)')
    ylabel('N')
    title('ISI histogram (time)')
    gca().yaxis.labelpad = -20
Ejemplo n.º 11
0
def plotPPC(unit):
    '''
    Plot a precomputed PPC spectrum for `unit` in black.

    Looks up the row of the module-level `allresults[session, area]`
    table whose key equals
        (unit, obj, grp, event, start, stop, channel, perm, permstrategy,
         jitter, window, bandWidth, 0)
    with the fixed values below, and plots it against module-level
    `ppcfreqs`. Relies on module-level globals `session`, `area`,
    `allresults` and `ppcfreqs` (TODO confirm these are set by the caller).
    '''
    # PPC is 1 indexed by unit
    condition     = (3,3)          # (obj, grp)
    event_window  = (6,-1000,0)    # (event, start, stop)
    channel       = (0,)
    perm_info     = (0,0,50)       # (perm, permstrategy, jitter)
    fft_options   = (100,10)       # (window, bandWidth)
    lookup = (unit,)+condition+event_window+channel+perm_info+fft_options+(0,)
    keys,ppcs = allresults[session,area]
    matching_rows = find(all(keys==int32(lookup),1))
    spectrum = squeeze(ppcs[matching_rows])
    plot(ppcfreqs,spectrum,color='k')
    ylim(0,0.5)
    xlim(0,400)
    nicelimits()
    xlabel('Frequency (Hz)')
    ylabel('PPC')
Ejemplo n.º 12
0
def plot_phase_direction(dz,skip=1,lw=1,zorder=None):
    '''
    Show the phase of a 2-D complex field with a direction marker per cell.

    The image is angle(dz) in the HSV colormap. At every `skip`-th row and
    column (starting at skip//2) a two-tone segment of length 0.25*skip is
    drawn through the cell centre along dz's direction: white on the +dz
    side, black on the -dz side. Cells where dz is 0 give NaN directions
    and draw nothing visible.

    Parameters:
        dz (complex128 2-D array): phase gradient
        skip (int): only plot every skip-th row/column
        lw (numeric): line width
        zorder (numeric or None): z-order of the segments
    '''
    cla()
    imshow(angle(dz),interpolation='nearest')
    hsv()
    # Integer offset: `skip/2` is a float on Python 3 and is not a valid
    # slice index.
    first = skip//2
    for i,row in list(enumerate(dz))[first::skip]:
        for j,z in list(enumerate(row))[first::skip]:
            z = 0.25*skip*z/abs(z)
            plot([j,j+real(z)],[i,i+imag(z)],'w',lw=lw,zorder=zorder)
            z = -z
            plot([j,j+real(z)],[i,i+imag(z)],'k',lw=lw,zorder=zorder)
    h,w = shape(dz)
    xlim(0-0.5,w-0.5)
    ylim(h-0.5,0-0.5)
Ejemplo n.º 13
0
def ontrial_summary_plot(statistic,session,area,trial,epoch,fa,fb):
    '''
    ontrial_summary_plot(array_average_ampltiude)

    Plot `statistic` (as computed by overdata) over a single trial in a
    randomly chosen dark hue, with task events overlaid. The title is the
    statistic's function name, title-cased.

    epoch must currently be None (asserted).
    '''
    pretty = ' '.join(statistic.__name__.replace('_',' ').split())
    name = pretty.title()
    assert epoch is None
    times,res = overdata(statistic,session,area,trial,epoch,fa,fb)
    cla()
    hue_index = randint(100)
    light_colour = lighthues(100)[hue_index]   # computed but not used
    dark_colour = darkhues(100)[hue_index]
    plot(times,res,lw=2,color=dark_colour)
    title('%s, %s-%sHz, %s %s'%(name,fa,fb,session,area))
    xlabel('Time (ms)')
    ylabel(name)
    xlim(times[0],times[-1])
    overlayEvents('k','w',FS=1)
    draw()
Ejemplo n.º 14
0
def plotSTA(session,area,unit,start,stop,window=100,texttop=False,color='k',label=None,lw=1.5,linestyle='-'):
    '''
    Plot the spike-triggered LFP average of one unit and return the snippets.

    Snippets come from getSTLFP(session,area,unit,start,stop,window) and
    their mean is plotted against lags -window..window (labelled ms). The
    y-limits are made symmetric about zero with ticks at -max, 0, +max.

    texttop is unused.

    Returns:
        the snippet array returned by getSTLFP
    '''
    snippits = getSTLFP(session,area,unit,start,stop,window)
    lags = arange(-window,window+1)
    average = mean(snippits,0)
    spread = std(snippits,0)
    n_spikes = shape(snippits)[0]
    ci95 = 1.96*spread/sqrt(n_spikes)   # computed but not drawn
    plot(lags,average,linestyle,zorder=5,color=color,label=label,lw=lw)
    title('Spike-triggered LFP average')
    xlabel('Time (ms)')
    nicelimits()
    simpleraxis(gca())
    xlim(-window,window)
    limits = abs(array(ylim()))
    yextreme = max(*limits)
    ylim(-yextreme,yextreme)
    yticks([-yextreme,0,yextreme])
    ylabel('STA (µV)')
    gca().yaxis.labelpad = -20
    return snippits
Ejemplo n.º 15
0
def plotWaveforms(session,area,unit):
    '''
    Plot up to 100 spike waveforms of `unit`, aligned on their minimum.

    Each waveform (48 samples at 30 kHz) is upsampled with time_upsample
    onto a grid of 4 points per sample and shifted so its minimum falls at
    400 microseconds. The x-range covers the original 48-sample window.
    `unit` is 1-indexed.
    '''
    cla()
    all_waveforms = get_waveforms(session,area)
    unit_waveforms = all_waveforms[0,unit-1]
    # original 48-sample window in microseconds
    native_us = arange(48)/30000.*1000*1000
    shown = unit_waveforms[:,:100:]
    # upsampled time grid (4 points per sample) in microseconds
    upsampled_us = arange(48*4)/4./30000.*1000000
    for column in range(shape(shown)[1]):
        trace = time_upsample(shown[:,column])
        shift = upsampled_us[argmin(trace)]-400
        plot(upsampled_us-shift,trace,color=(0.1,0.1,0.1),lw=0.5)
    xlim(native_us[0],native_us[-1])
    xlabel(u'μs')
    ylabel(u'μV')
    nicelimits()
    bareaxis(gca())
    title('Waveforms')
    gca().yaxis.labelpad = -20
Ejemplo n.º 16
0
def plotLFPSpikes(lfp,start,stop,spikes,
    zoom=1000,vrange=100,FS=1000.0,SPIKECOLOR=(1,0.5,0),lw=2):
    '''
    Plot the first `zoom` samples of an LFP trace with spike-time markers.

    Caution: assumes the trace starts at the beginning of the trial.

    Parameters:
        lfp (array-like): LFP samples, plotted against sample/FS
        start, stop: sample bounds; arange(stop-start) is the time base
        spikes: spike times in samples (converted to seconds with FS)
        zoom (int): number of samples shown
        vrange (numeric): y-limits and ticks at +/- vrange
        FS (numeric): sampling rate
        SPIKECOLOR, lw: spike marker colour and width

    Spike markers are short vertical segments in the middle half of the
    y-range; the first drawn one carries the legend label.
    '''
    print('caution assumes starting at beginning of trial')
    cla()
    plot((arange(stop-start)/FS)[:zoom],lfp[:zoom],'k')
    # y-limits follow vrange so they agree with the ticks set below
    ylim(-vrange,vrange)
    a,b = ylim()
    spike_times = float32(spikes)/FS
    mm = (a+b)/2
    c = (a+mm)/2
    d = (b+mm)/2
    labelled = False
    for spt in spike_times:
        if spt>=xlim()[1]: continue
        # Label only the first drawn marker. The original re-drew the last
        # spike with the label, even when that spike had been skipped.
        if labelled:
            plot([spt,spt],(c,d),color=SPIKECOLOR,lw=lw)
        else:
            plot([spt,spt],(c,d),color=SPIKECOLOR,lw=lw,label='Spike times')
            labelled = True
    yticks([-vrange,0,vrange])
    xlim((start+1000)/float(FS),(zoom)/float(FS))
    simpleraxis()
    gca().yaxis.labelpad = -20
    xlabel('Time (s)')
    ylabel('µV')
    if labelled:
        legend(frameon=0,borderpad=-1,ncol=2,fontsize=12)
Ejemplo n.º 17
0
def estimate_beta_band(session,area,bw=8,epoch=None,doplot=False):
    '''
    Experimental coherence-based beta-band estimate -- DISABLED.

    Prints a warning and always raises AssertionError. The draft below,
    if enabled, would average `cohere` (Fs=1000, NFFT=256) over ordered
    pairs of the first two good channels of `area` (or of every area in
    the module-level `areas` when area is None) across all good trials,
    smooth the mean with a boxcar about `bw` Hz wide, find the peak
    between 15+bw/2 and 30-bw/2 Hz, and optionally plot the result.

    Returns (if enabled):
        (betapeak-0.5*bw, betapeak+0.5*bw)

    Raises:
        AssertionError: always.
    '''
    print('THIS IS NOT THE ONE YOU WANT TO USE')
    print('IT IS EXPERIMENTAL COHERENCE BASED IDENTIFICATION OF BETA')
    # Explicit raise: a bare `assert 0` is stripped under `python -O`.
    raise AssertionError('estimate_beta_band is experimental and disabled')
    if epoch is None: epoch = (6,-1000,3000)
    target_areas = areas if area is None else [area]
    allco = []
    for this_area in target_areas:
        chs = get_good_channels(session,this_area)[:2]
        for a in chs:
            for b in chs:
                if a==b: continue
                for tr in get_good_trials(session):
                    x = get_filtered_lfp(session,this_area,a,tr,epoch,None,300)
                    y = get_filtered_lfp(session,this_area,b,tr,epoch,None,300)
                    co,fr = cohere(x,y,Fs=1000,NFFT=256)
                    allco.append(co)
    allco = array(allco)
    m = mean(allco,0)
    sem = std(allco,0)/sqrt(shape(allco)[0])
    # temporary in lieu of multitaper; kernel length must be an integer
    smooth = int(ceil(float(bw)/(diff(fr)[0])))
    smoothed = convolve(m,ones(smooth)/smooth,'same')
    use    = (fr<=56)&(fr>=5)
    betafr = (fr<=30-0.5*bw)&(fr>=15+0.5*bw)
    betapeak = fr[betafr][argmax(smoothed[betafr])]
    if doplot:
        clf()
        plot(fr[use],m[use],lw=2,color='k')
        plot(fr[use],smoothed[use],lw=1,color='r')
        plot(fr[use],(m+sem)[use],lw=1,color='k')
        plot(fr[use],(m-sem)[use],lw=1,color='k')
        positivey()
        xlim(*rangeover(fr[use]))
        shade([[betapeak-0.5*bw],[betapeak+0.5*bw]])
        draw()
    return betapeak-0.5*bw,betapeak+0.5*bw
Ejemplo n.º 18
0
def logpolar_gaussian(frame,doplot=False):
    '''
    Fit a weighted 2-D Gaussian to complex samples in log-polar coordinates.

    The samples are rotated so their mean has zero phase, then mapped to
    x = log|z| and y = angle(z)/4, each weighted by |z| (weights normalised
    to sum to 1). The weighted mean and unbiased weighted covariance give:
      - two principal axes (covariance eigenvectors scaled by the square
        roots of the eigenvalues) through the mean, and
      - a 1-sigma ellipse (unit circle mapped through the Cholesky factor).
    Each curve, as a complex number x+iy, is passed through exp() and
    rotated back by the mean phase. NOTE(review): exp() does not undo the
    /4 phase scaling -- confirm this is intended.

    Parameters:
        frame (1-D complex array): samples
        doplot (bool): draw the three curves in red

    Returns:
        (axis1, axis2, ellipse), each 100 complex points
    '''
    mean_phase = angle(mean(frame))
    rotated = frame*exp(1j*-mean_phase)
    wts = abs(rotated)
    wts = wts/sum(wts)
    lx = log(abs(rotated))
    ly = angle(rotated)/4
    mx = dot(wts,lx)
    my = dot(wts,ly)
    dx = lx - mx
    dy = ly - my
    # unbiased weighted-covariance factor
    unbias = sum(wts)/(sum(wts)**2-sum(wts**2))
    cov_xx = dot(wts,dx*dx)*unbias
    cov_xy = dot(wts,dx*dy)*unbias
    cov_yy = dot(wts,dy*dy)*unbias
    cov = arr([[cov_xx,cov_xy],[cov_xy,cov_yy]])
    chol = cholesky(cov)
    evals,evecs = eig(cov)
    # column k of evecs as a complex number
    evecs = evecs[0,:]+1j*evecs[1,:]
    centre = mx + 1j*my
    radii = sqrt(evals)
    span = linspace(-1,1,100)
    axis1 = centre + evecs[0]*radii[0]*span
    axis2 = centre + evecs[1]*radii[1]*span
    ring = exp(1j*linspace(0,2*pi,100))
    ring = p2c(dot(chol,[real(ring),imag(ring)]))+centre
    rotation = exp(1j*mean_phase)
    curves = (exp(axis1)*rotation,exp(axis2)*rotation,exp(ring)*rotation)
    if doplot:
        for curve in curves:
            plot(*c2p(curve),color='r',lw=2,zorder=Inf)
    return curves
Ejemplo n.º 19
0
def complex_gaussian(frame,doplot=False):
    '''
    Fit a 2-D Gaussian to complex samples in Cartesian (real, imag) form.

    No rotation or log-polar transform is applied. With uniform weights the
    covariance is the ordinary unbiased sample covariance of (real, imag).
    Returns the two principal axes (covariance eigenvectors scaled by the
    square roots of the eigenvalues) through the mean, and the 1-sigma
    ellipse (unit circle mapped through the Cholesky factor, centred on
    the mean).

    Parameters:
        frame (1-D complex array): samples
        doplot (bool): draw the three curves in red

    Returns:
        (axis1, axis2, ellipse), each 100 complex points
    '''
    wts = ones(shape(frame))
    wts = wts/sum(wts)
    re_part = real(frame)
    im_part = imag(frame)
    mx = dot(wts,re_part)
    my = dot(wts,im_part)
    dx = re_part - mx
    dy = im_part - my
    # unbiased weighted-covariance factor (N/(N-1) for uniform weights)
    unbias = sum(wts)/(sum(wts)**2-sum(wts**2))
    cov_xx = dot(wts,dx*dx)*unbias
    cov_xy = dot(wts,dx*dy)*unbias
    cov_yy = dot(wts,dy*dy)*unbias
    cov = arr([[cov_xx,cov_xy],[cov_xy,cov_yy]])
    chol = cholesky(cov)
    evals,evecs = eig(cov)
    evecs = evecs[0,:]+1j*evecs[1,:]
    centre = mx + 1j*my
    radii = sqrt(evals)
    span = linspace(-1,1,100)
    axis1 = centre + evecs[0]*radii[0]*span
    axis2 = centre + evecs[1]*radii[1]*span
    ring = exp(1j*linspace(0,2*pi,100))
    ring = p2c(dot(chol,[real(ring),imag(ring)]))+centre
    if doplot:
        for curve in (axis1,axis2,ring):
            plot(*c2p(curve),color='r',lw=2,zorder=Inf)
    return axis1,axis2,ring
Ejemplo n.º 20
0
def logpolar_stats(frame,doplot=False):
    '''
    Summarise complex samples by log-amplitude and phase statistics.

    Computes rl / rsl (mean / std of log|z|), theta (phase of the mean unit
    phasor) and a phase spread sd = sqrt(-2*log(R)), where
    R = |mean(z)| / mean(|z|) (so 0 <= R <= 1). From these it builds:
        circle: ellipse with half-axes rsl (log-amplitude) and sd (phase)
                around (rl, theta), mapped through exp
        arc:    amplitude exp(rl), phases theta-sd .. theta+sd
        radial: segment from exp(rl-rsl) to exp(rl+rsl) at phase theta

    Parameters:
        frame (1-D complex array): samples
        doplot (bool): draw the three curves in magenta

    Returns:
        (circle, arc, radial)
    '''
    amplitude = abs(frame)
    log_amp = log(amplitude)
    rl = mean(log_amp)
    rsl = std(log_amp)
    theta = angle(mean(frame/amplitude))
    R = abs(mean(frame))/mean(amplitude)
    sd = sqrt(-2*log(R))
    # (A debug print of R and sd was removed: this function is called on
    # every frame of the phase-plane animation.)
    s = exp(rl)*exp(1j*theta)
    arc = exp(rl+theta*1j)*exp(1j*linspace(-sd,sd,100))
    circle = exp(1j*linspace(0,2*pi,100))
    circle = real(circle)*rsl + 1j*imag(circle)*sd
    circle = exp(circle+rl+1j*theta)
    radial = arr([s*exp(-rsl),s*exp(rsl)])
    if doplot:
        plot(*c2p(circle),color='m',lw=2)
        plot(*c2p(arc),color='m',lw=2)
        plot(*c2p(radial),color='m',lw=2)
    return circle,arc,radial
Ejemplo n.º 21
0
def abspolar_stats(frame,doplot=False):
    '''
    Summarise complex samples as an axial (sign-agnostic) distribution.

    The axis phi is half the phase of mean(z**2). Samples with
    cos(angle(z)-phi) < 0 are flipped: their amplitude is negated and their
    phase advanced by pi. Then mr / sr are the mean / std of the signed
    amplitudes and st = sqrt(-2*log|mean(exp(1j*h))|) of the adjusted
    phases h. Returned curves:
        circle: ellipse spanning amplitude mr +/- sr and phase phi +/- st
        arc:    amplitude mr, phases phi-st .. phi+st
        radial: from (mr-sr) to (mr+sr) along phase phi
    With doplot the current figure is cleared (clf) before drawing, and the
    mean point mr*exp(1j*phi) is marked in black.
    '''
    axis_phase = angle(mean(frame**2))/2
    direction = sign(cos(angle(frame)-axis_phase))
    signed_amp = abs(frame)*direction
    phases = angle(frame) + pi*int32(direction==-1)
    mr = mean(signed_amp)
    sr = std(signed_amp)
    st = sqrt(-2*log(abs(mean(exp(1j*phases)))))
    arc = mr*exp(1j*(axis_phase+linspace(-st,st,100)))
    unit_circle = exp(1j*linspace(0,2*pi,100))
    circle = (real(unit_circle)*sr+mr)*exp(1j*(imag(unit_circle)*st+axis_phase))
    radial = arr([(mr-sr)*exp(1j*axis_phase),(mr+sr)*exp(1j*axis_phase)])
    if doplot:
        clf()
        for curve in (circle,arc,radial):
            plot(*c2p(curve),color='m',lw=2)
        scatter(*c2p([mr*exp(1j*axis_phase)]),color='k',s=5**2)
    return circle,arc,radial
Ejemplo n.º 22
0
def phase_plane_animation_distribution(session,tr,areas,fa=10,fb=45,epoch=None,\
    skip=1,saveas=None,hook=None,FPS=30,stabilize=False,M=None,extension='png',markersize=1.5):
    '''
    Animate multi-area analytic LFP in the complex plane with model fits.

    At every (decimated) time point each good channel's fa-fb Hz analytic
    signal is drawn as a point (one colour per area), and logpolar_stats
    (cyan) and complex_gaussian (black) are refitted to the pooled points.

    Parameters:
        session, tr: session name and trial
        areas: list of area names (at most 3; one colour each)
        fa, fb: band edges (Hz)
        epoch: passed to get_trial_times_ms / get_all_analytic_lfp
        skip (int): temporal decimation
        saveas (str or None): if given, frames are saved into
            ./<saveas>_<areas>_<session>_<tr>_<fa>_<fb>/
        hook (callable or None): called with each frame's time
        FPS (numeric): target frame rate (Inf = as fast as possible)
        stabilize (bool): rotate each frame by minus the phase of the mean
            over all channels at that time
        M (numeric or None): axis half-width; default is the max amplitude
            rounded up to a multiple of 10 (at least 10)
        extension (str): image format of saved frames
        markersize (numeric): scatter marker size

    Test code:
    from os.path import expanduser
    session = 'SPK120918'
    area = 'M1'
    trial = 2
    tr = trial
    epoch = None
    fa,fb = 15,30
    fa = int(round(fa))
    fb = int(round(fb))
    close('all')
    phase_plane_animation_distribution(session,tr,['M1','PMv','PMd'],fa=10,fb=45,FPS=Inf,
    M=100,skip=1,extension='pdf',saveas='cgauss')
    '''
    areacolors = [OCHRE,AZURE,RUST]
    models = [logpolar_stats,complex_gaussian]
    modelcolors = ['c','k']
    modellw = 1.5
    if not saveas is None:
        saveas += '_'+'_'.join(map(str,areas))
    print('%s %s'%(session,tr))
    # remember the current figure so we can return to it
    ff=gcf()
    # time base
    times = get_trial_times_ms(session,'M1',tr,epoch)[::skip]
    # band-passed analytic signal, stored time x channel per area
    data = {}
    for a in areas:
        print('loading area %s'%a)
        x = get_all_analytic_lfp(session,a,tr,epoch,fa,fb,onlygood=True)[:,::skip]
        data[a]=x.T
    # Stabilisation phase: angle of the mean over all channels at each time.
    # concatenate needs a list; dict.values() is rejected on Python 3.
    alldata = concatenate([data[a] for a in areas],1)
    phaseshift = angle(mean(alldata,axis=1))
    # PREPARE FIGURE
    figtitle = 'Analytic Signal %s-%sHz\n%s trial %s'%(fa,fb,session,tr)
    if not saveas is None:
        saveas += '_%s_%s_%s_%s'%(session,tr,fa,fb)
    figure('Phase Plane',figsize=(4,4))
    cla()
    if M is None:
        M = 10
        for a in areas:
            M = max(M,int(ceil(np.max(abs(data[a]))/10.))*10)
    complex_axis(M)
    title(figtitle+' t=%dms'%times[0])
    tight_layout()
    # prepare output directory if we're going to save this
    if not saveas is None:
        savedir = './'+saveas
        ensuredir(savedir)
    # initialise points and model curves from the first frame
    scat={}
    frame = []
    for i,a in en(areas):
        x = data[a][0]
        scat[a] = scatter(*c2p(x),s=markersize**2,color=areacolors[i],label=a)
        frame.extend(x)
    frame = arr(frame)
    modellines = []
    for distcolor,fit in zip(modelcolors,[m(frame) for m in models]):
        lines = [plot(*c2p(curve),color=distcolor,lw=modellw,zorder=Inf)[0] for curve in fit]
        modellines.append(lines)
    nice_legend()
    # perform animation
    st = now()
    for i,t in en|times:
        stabilizer = exp(-1j*phaseshift[i]) if stabilize else 1
        frame = []
        for a in areas:
            x = stabilizer*data[a][i]
            scat[a].set_offsets(c2p(x).T)
            frame.extend(x)
        frame = arr(frame)
        # refit each model and move its curves (no shadowing of `i`)
        for lines,model in zip(modellines,models):
            for line,curve in zip(lines,model(frame)):
                line.set_data(*c2p(curve))
        title(figtitle+' t=%sms'%t)
        draw()
        if not saveas is None:
            savefig(savedir+'/'+saveas+'_%s.%s'%(t,extension))
        if not hook is None: hook(t)
        st=waitfor(st+1000/FPS)
    if not ff is None: figure(ff.number)
Ejemplo n.º 23
0
def get_high_beta_events(session,area,channel,epoch,
    MINLEN  = 40,   # ms
    BOXLEN  = 50,   # ms
    THSCALE = 1.5,  # sigma (standard deviations)
    lowf    = 10.0, # Hz
    highf   = 45.0, # Hz
    pad     = 200,  # ms
    clip    = True,
    audit   = False
    ):
    '''
    get_high_beta_events(session,area,channel,epoch) will identify periods of
    elevated beta-frequency power for the given channel.

    The threshold is selected per-channel from all good trials: THSCALE
    times the standard deviation of the lowf-highf band-passed LFP taken in
    the window (6,-1000,0) of each good trial (not the entire trial).
    Events are intervals where the BOXLEN-smoothed Hilbert envelope exceeds
    the threshold for at least MINLEN samples and overlap the epoch. The
    envelope is computed on the raw session LFP padded by `pad` on each side
    of the epoch so events near the edges are detected.

    By default events that extend past the edge of the specified epoch will
    be clipped. Passing clip=False will discard these events.

    returns the event threshold, and a list of event start and stop
    times relative to session time (not per-trial or epoch time)

    passing audit=True will enable previewing each trial and the isolated
    beta events.

    >>> thr,events = get_high_beta_events('SPK120925','PMd',50,(6,-1000,0))
    '''

    # get LFP data
    signal = get_raw_lfp_session(session,area,channel)

    # estimate threshold for beta events
    beta_trials = [get_filtered_lfp(session,area,channel,t,(6,-1000,0),lowf,highf) for t in get_good_trials(session)]
    threshold   = np.std(beta_trials)*THSCALE
    print('threshold= %s'%threshold)

    N = len(signal)
    event,start,stop = epoch
    all_events = []
    for trial in get_good_trials(session):
        evt        = get_trial_event(session,area,trial,event)
        trialstart = get_trial_event(session,area,trial,4)
        epochstart = evt + start + trialstart
        epochstop  = evt + stop  + trialstart
        tstart     = max(0,epochstart-pad)
        tstop      = min(N,epochstop +pad)
        filtered   = bandfilter(signal[tstart:tstop],lowf,highf)
        envelope   = abs(hilbert(filtered))
        smoothed   = convolve(envelope,ones(BOXLEN)/float(BOXLEN),'same')
        # E is 2 x nevents: row 0 = start, row 1 = stop (session time)
        E = array(get_edges(smoothed>threshold))+tstart
        E = E[:,(diff(E,1,0)[0]>=MINLEN)
                & (E[0,:]<epochstop )
                & (E[1,:]>epochstart)]
        if audit: print(E)
        if clip:
            E[0,:] = np.maximum(E[0,:],epochstart)
            E[1,:] = np.minimum(E[1,:],epochstop )
        else:
            E = E[:,(E[1,:]<=epochstop)&(E[0,:]>=epochstart)]
        if audit:
            clf()
            axvspan(epochstart,epochstop,color=(0,0,0,0.25))
            plot(arange(tstart,tstop),filtered,lw=0.7,color='k')
            plot(arange(tstart,tstop),envelope,lw=0.7,color='r')
            plot(arange(tstart,tstop),smoothed,lw=0.7,color='b')
            ylim(-80,80)
            for a,b in E.T:
                axvspan(a,b,color=(1,0,0,0.5))
            axhline(threshold,color='k',lw=1.5)
            xlim(tstart,tstop)
            draw()
            wait()
        # Sanity-check durations (stop - start). The original used
        # diff(E,0,1), a zero-order difference that returns E unchanged, so
        # it checked nothing. Clipping can shorten events below MINLEN but
        # always leaves them non-empty.
        durations = diff(E,1,0)[0]
        if clip:
            assert all(durations>0)
        else:
            assert all(durations>=MINLEN)
        all_events.extend(E.T)
    return threshold, all_events
Ejemplo n.º 24
0
    return freqs, transpose(result, (1, 0, 2))


############################################################################
############################################################################
if __name__ == "__main__":
    # Wavelet power test: for each test frequency, convolve standard-normal
    # noise with normalized_morlet(m, w) for w = 4..29, and compare the
    # output variance with the integrated magnitude spectrum of the wavelet.
    signal = randn(1000)
    Fs = 1000
    for freq in arange(5, 500, 5):
        ws = arange(4, 30)
        M = 2.0 * 1.0 * ws * Fs / float(freq)
        clf()
        variances = []
        bandwidths = []
        for m, w in zip(M, ws):
            wl = normalized_morlet(m, w)
            spectrum_freqs = fftfreq(len(wl), 1.0 / Fs)
            magnitude = abs(fft(wl))
            df = spectrum_freqs[1] - spectrum_freqs[0]
            bandwidths.append(sum(magnitude) * df)
            filtered = convolve(signal, wl, "same")
            variances.append(var(filtered))
        bw = arr(bandwidths)
        x = arr(variances)
        plot(x / bw)
        positivey()
        print(freq, 1 / mean(bw / x))
Ejemplo n.º 25
0
def phase_plane_animation_arraygrid(session,tr,fa=10,fb=45,\
    epoch=None,skip=1,saveas=None,hook=None,FPS=30,stabilize=True,markersize=1.5):
    '''
    phase_plane_animation(session,tr)

    Animate the analytic LFP of every area in the module-level `areas` in
    the complex plane, joining channel pairs from
    get_all_pairs_ordered_as_channel_indecies with thin lines. Frames are
    updated by blitting.

    Parameters:
        session, tr: session name and trial
        fa, fb: band edges (Hz)
        epoch: passed to get_trial_times_ms / get_all_analytic_lfp
        skip (int): temporal decimation
        saveas (str or None): if given, PNG frames are saved into
            ./<saveas>_<session>_<tr>_<fa>_<fb>/
        hook (callable or None): called with each frame's time
        FPS (numeric): target frame rate
        stabilize (bool): rotate each frame by minus the cumulative median
            phase step across all channels (steps passed through rewrap,
            presumably wrapping to +/-pi -- TODO confirm)
        markersize (numeric): scatter marker size
    '''
    warn('Also plots "bad" channels')
    print('%s %s'%(session,tr))
    # remember the current figure so we can return to it
    ff=gcf()
    # time base
    times = get_trial_times_ms(session,'M1',tr,epoch)[::skip]
    # band-passed analytic signal, stored time x channel per area
    data = {}
    for a in areas:
        print('loading area %s'%a)
        x = get_all_analytic_lfp(session,a,tr,epoch,fa,fb,onlygood=True)[:,::skip]
        data[a]=x.T
    # locate all pairs
    pairs = {}
    for a in areas:
        pairs[a] = get_all_pairs_ordered_as_channel_indecies(session,a)
    # Stabilisation: cumulative median phase velocity over all channels.
    # concatenate needs a list; dict.values() is rejected on Python 3.
    alldata = concatenate([data[a] for a in areas],1)
    phasedt = rewrap(diff(angle(alldata),1,0))
    phasev  = median(phasedt,axis=1)
    phaseshift = append(0,cumsum(phasev))
    # PREPARE FIGURE
    figtitle = 'Analytic Signal %s-%sHz\n%s trial %s'%(fa,fb,session,tr)
    if not saveas is None:
        saveas += '_%s_%s_%s_%s'%(session,tr,fa,fb)
    figure('Phase Plane')
    cla()
    # DETERMINE NICE SQUARE AXIS BOUNDS
    M = 10
    for a in areas:
        M = max(M,int(ceil(np.max(abs(data[a]))/10.))*10)
    complex_axis(M)
    title(figtitle+' t=%dms'%times[0])
    # prepare output directory if we're going to save this
    if not saveas is None:
        savedir = './'+saveas
        ensuredir(savedir)
    # Blitting setup: render once so the background snapshot contains the
    # axes decorations, and take it before any animated artist exists.
    aa = gca()
    canvas = aa.figure.canvas
    canvas.draw()
    background = canvas.copy_from_bbox(aa.bbox)
    # initialize points and pair lines
    scat={}
    grid={}
    for i,a in en(areas):
        points = c2p(data[a][0])
        c = darkhues(9)[i*3+2]
        scat[a] = scatter(*points,s=markersize**2,color=c,label=a)
        lines = []
        for ix1,ix2 in pairs[a]:
            p1=points[:,ix1]
            p2=points[:,ix2]
            line = plot([p1[0],p2[0]],[p1[1],p2[1]],color=c,lw=0.4)[0]
            lines.append((line,ix1,ix2))
        grid[a]=lines
    nice_legend()
    # perform animation
    st = now()
    for i,t in en|times:
        stabilizer = exp(-1j*phaseshift[i]) if stabilize else 1
        canvas.restore_region(background)
        for a in areas:
            points = c2p(stabilizer*data[a][i])
            scat[a].set_offsets(points.T)
            # the scatter must be redrawn too, or blitting shows stale points
            aa.draw_artist(scat[a])
            for line,ix1,ix2 in grid[a]:
                p1=points[:,ix1]
                p2=points[:,ix2]
                line.set_data([p1[0],p2[0]],[p1[1],p2[1]])
                aa.draw_artist(line)
        title(figtitle+' t=%sms'%t)
        # Blit the animated axes. The original blitted `ax`, the axes of the
        # figure that was current before this function switched figures.
        canvas.blit(aa.bbox)
        if not saveas is None:
            savefig(savedir+'/'+saveas+'_%s.png'%t)
        if not hook is None: hook(t)
        st=waitfor(st+1000/FPS)
    if not ff is None: figure(ff.number)
Ejemplo n.º 26
0
def compare_ppc_approaches(session,area,unit,start,stop,window=100,FMAX=250):
    '''
    Overlay PPC spectra of one unit computed with different methods.

    Try with
    compare_ppc_approaches('RUS120523','PMv',42,-1000,0,200)

    Spike times come from the epoch (6, start, stop) and the LFP from the
    raw session signal of the unit's channel. Curves, in plotting order:
    Hann unbiased, Hann biased, and unbiased multitaper with k = 1..4
    tapers. Frequencies in (0, 250] Hz are plotted.
    '''
    channel = get_unit_channel(session,area,unit)
    signal = get_raw_lfp_session(session,area,channel)
    times  = get_spikes_session_filtered_by_epoch(session,area,unit,(6,start,stop))

    approaches = [
        ('Hann unbiased', dict(multitaper=False,biased=False,taper=hann)),
        ('Hann biased',   dict(multitaper=False,biased=True,taper=hann)),
    ]
    for k in (1,2,3,4):
        approaches.append(('Multitaper %d taper unbiased'%k,
                           dict(multitaper=True,biased=False,k=k)))

    for label,options in approaches:
        result,_ = pairwise_phase_consistancy(signal,times,
            window=window,Fs=1000,delta=100,**options)
        # ppc() unpacks (freqs,unbiased,phases) from this first element
        # while this function originally unpacked (freqs,unbiased); index
        # so either form works.
        freqs,unbiased = result[0],result[1]
        use = (freqs>0.)&(freqs<=250.)
        plot(freqs[use],unbiased[use],label=label)

    ylim(0,0.5)
    ylabel('PPC',labelpad=-10)
    xlim(0,FMAX)
    nicelimits()
    xticks(linspace(0,FMAX,11),['%d'%x for x in linspace(0,FMAX,11)])
    xlabel('Frequency (Hz)')
    title('Pairwise phase consistency')
    simpleaxis(gca())