Ejemplo n.º 1
0
def clipWave(W,freq,time,freqRange,timeRange,coi = None):
    '''
    Clip a matrix of wavelet coefficients to a frequency and time range.

    Usage:
    W_clip, freq_clip, time_clip, coi_clip = clipWave(W, freq, time,
        freqRange, timeRange, coi=[])

    Inputs:
    W - 2D matrix of wavelet coefficients
    freq - Freq vector of length = number of rows of W
    time - Time vector of length = number of cols of W
    freqRange - 2 element array specifying the min and max freq to clip W to
    timeRange - 2 element array specifying the min and max time to clip W to
    coi - Cone of influence, indexed like time; None or empty to skip it

    Outputs:
    W_clip - Clipped W
    freq_clip - Clipped freq
    time_clip - Clipped time
    coi_clip - Clipped coi (an empty list when no coi was supplied)

    Notes:
    The range bounds are converted to indices by spt.nearestMatchingInds
    (freq) and spt.valsToNearestInds (time) and then used as Python slices,
    so the row/column at the upper index is excluded from the output.
    '''
    import apCode.SignalProcessingTools as spt
    if coi is None:
        # Fresh list per call; a shared default list would leak to callers.
        coi = []
    fInds = spt.nearestMatchingInds(freqRange,freq)
    tInds = spt.valsToNearestInds(timeRange,time)
    # freq may be stored in descending order, in which case the index of the
    # max freq is the smaller one. Ordering the two indices gives a valid
    # slice for either orientation of the freq vector.
    fStart = min(fInds[0], fInds[1])
    fStop = max(fInds[0], fInds[1])
    tStart, tStop = tInds[0], tInds[1]
    W_clip = W[fStart:fStop,tStart:tStop]
    freq_clip = freq[fStart:fStop]
    time_clip = time[tStart:tStop]
    if len(coi) > 0:
        coi_clip = coi[tStart:tStop]
    else:
        coi_clip = coi
    return W_clip, freq_clip,time_clip,coi_clip
Ejemplo n.º 2
0
def basToDf(bas, n_pre_bas=200e-3, n_post_bas=1.5):
    """Build a per-trial dataframe of stimulus info and motor activity.

    Parameters
    ----------
    bas: dict
        Must contain a time vector under 't' and a motor channel whose key
        matches 'den' (the last match found by util.findStrInList is used).
    n_pre_bas: scalar
        Pre-stimulus window; multiplied by the sampling rate to get samples.
    n_post_bas: scalar
        Post-stimulus window; multiplied by the sampling rate to get samples.

    Returns
    -------
    df: pandas dataframe
        One row per stimulus with columns trlNum (1-based), stimInd, stimAmp,
        stimHT, motorActivity and motorTime.
    """
    import apCode.SignalProcessingTools as spt
    import pandas as pd
    import numpy as np
    import apCode.util as util
    df = {}
    dt = bas['t'][1] - bas['t'][0]
    # Round instead of truncating: 1/dt is often a hair below the true rate
    # (e.g. 5999.999...), and int() alone would then give the wrong Fs.
    Fs = int(np.round(1 / dt))
    nPre = int(np.round(n_pre_bas * Fs))
    nPost = int(np.round(n_post_bas * Fs))
    inds_stim, amps_stim, ch_stim = getStimInfo(bas)
    keys = list(bas.keys())
    ind_mtr = util.findStrInList('den', keys)[-1]
    x = np.squeeze(np.array(bas[keys[ind_mtr]]))
    if x.shape[0] == 2:
        # Put samples along axis 0 so that z-scoring runs per channel.
        x = x.T
    motor_trl = spt.segmentByEvents(spt.zscore(x, axis=0), inds_stim,
                                    nPre, nPost)
    motorTime_trl = spt.segmentByEvents(bas['t'], inds_stim, nPre, nPost)
    trlNum = np.arange(len(inds_stim)) + 1
    df['trlNum'] = trlNum
    df['stimInd'] = inds_stim
    df['stimAmp'] = amps_stim
    df['stimHT'] = ch_stim
    df['motorActivity'] = motor_trl
    df['motorTime'] = motorTime_trl
    return pd.DataFrame(df)
Ejemplo n.º 3
0
def clipWave(W, freq, time, freqRange, timeRange, coi=None):
    '''
    W_clip, freq_clip, time_clip, coi_clip = clipWave(W, freq, time,
        freqRange, timeRange, coi=[])
    Clips the matrix of wavelet coefficients to the specified freq and time
        range
    Inputs:
    W - 2D matrix of wavelet coefficients
    freq - Freq vector of length = number of rows of W
    time - Time vector of length = number of cols of W
    freqRange  - 2 element array specifying the min and max freq to clip W to
    timeRange - 2 element array specifying the min and max time to clip W to
    coi - Cone of influence (same indexing as time); None or empty to skip
    Outputs:
    W_clip - Clipped W
    freq_clip - Clipped freq
    time_clip - Clipped time
    coi_clip - Clipped coi; an empty list if no coi was passed in
    Notes:
    Bounds are mapped to indices with spt.nearestMatchingInds (freq) and
    spt.valsToNearestInds (time). The indices are used as slice limits, so
    the entry at the upper index is not part of the clipped output.
    '''
    import apCode.SignalProcessingTools as spt
    if coi is None:
        # Avoid a shared mutable default that would be handed back to callers.
        coi = []
    fInds = spt.nearestMatchingInds(freqRange, freq)
    tInds = spt.valsToNearestInds(timeRange, time)
    # Sort the two freq indices so the slice is valid whether freq is stored
    # in descending or ascending order.
    fLo, fHi = min(fInds[0], fInds[1]), max(fInds[0], fInds[1])
    W_clip = W[fLo:fHi, tInds[0]:tInds[1]]
    freq_clip = freq[fLo:fHi]
    time_clip = time[tInds[0]:tInds[1]]
    if len(coi) > 0:
        coi_clip = coi[tInds[0]:tInds[1]]
    else:
        coi_clip = coi
    return W_clip, freq_clip, time_clip, coi_clip
Ejemplo n.º 4
0
   def plotCentroids(time, centroids, stimTimes, time_ephys, ephys, scaled = False, 
                     colors = None, xlabel = '',ylabel = '', title = ''):
       """
       Plots centroids resulting from some clustering method, stacked
       vertically (one unit apart), with the ephys trace drawn below them.
       Parameters:
       time - Time vector for centroids (optical sampling interval). If None,
           sample indices are used instead.
       centroids - Array of shape (M, N), where M is the number of centroids, and N is the
           number of features (or time points)
       stimTimes - Times of stimulus onsets for overlaying vertical dashed lines
       time_ephys - Time axis for ephys data (usually sampled at higher rate)
       ephys - Ephys time series
       scaled - Boolean; If True, centroids are standardized with axis = 1 (each
           centroid separately), else they are standardized jointly.
       colors - Array of shape (M,3) or (M,4). Colors to use for plotting centroids.
           If None, seaborn's 'colorblind' palette is used; if fewer than M rows
           are given, the rows are repeated.
       xlabel, ylabel, title - Axis labels and title for the figure.
       """
       import apCode.SignalProcessingTools as spt
       import seaborn as sns
       import numpy as np
       import matplotlib.pyplot as plt
       if scaled:
           centroids = spt.standardize(centroids,axis = 1)
       else:
           centroids = spt.standardize(centroids)

       ephys = spt.standardize(ephys)

       n_clusters = np.shape(centroids)[0]
       if colors is None:
           colors = np.array(sns.color_palette('colorblind',n_clusters))
       else:
           colors = np.asarray(colors)
           if np.shape(colors)[0] < n_clusters:
               # Recycle the supplied colors until every centroid has one.
               colors = np.tile(colors,(n_clusters,1))[:n_clusters,:]

       if time is None:
           time = np.arange(np.shape(centroids)[1])

       plt.style.use(['dark_background','seaborn-poster'])
       for cc in np.arange(n_clusters):
           # Demean each centroid and shift it down by its index.
           plt.plot(time,centroids[cc,:]-np.mean(centroids[cc,:])-cc,color = colors[cc,:])
       # The ephys trace goes one slot below the last centroid.
       plt.plot(time_ephys,ephys-np.mean(ephys)-n_clusters,color = colors[0,:])
       yt = np.arange(n_clusters + 1)
       ytl = list(yt)
       ytl[-1] = 'ephys'
       plt.yticks(-yt, ytl)
       plt.xlabel(xlabel,fontsize = 16)
       plt.ylabel(ylabel, fontsize = 16)
       # Pass booleans: the string 'off' is truthy and would switch these on.
       plt.box(False)
       plt.title(title, fontsize = 20)
       plt.grid(False)
       plt.xlim(time[0],time[-1])
       for st in stimTimes:
           plt.axvline(x = st, ymin = 0, ymax =1, alpha = 0.3, 
                       color = 'w',linestyle = '--')  
Ejemplo n.º 5
0
 def getClrMap(x,cMap,maskInds):
     """Return an (len(x), 4) array of colors for the values in x.

     Negative values are mapped onto colormap entries 0-127 and non-negative
     values onto entries 128-255 (via spt.mapToRange, with 0 and -1 / 0 and 1
     appended as anchors before mapping and dropped afterwards). Rows listed
     in maskInds get their last column (alpha) set to 0.

     x is not modified; the colormap indices are kept in a separate array.
     """
     cm = np.full((len(x),4), np.nan)
     negInds = np.where(x<0)[0]
     posInds = np.where(x>=0)[0]
     lutInds = np.zeros(len(x), dtype=int)
     negVals = spt.mapToRange(np.hstack((x[negInds],[0,-1])),[0,127])[0:-2]
     posVals = spt.mapToRange(np.hstack((x[posInds],[0,1])),[128,255])[0:-2]
     lutInds[negInds] = np.asarray(negVals).astype(int)
     lutInds[posInds] = np.asarray(posVals).astype(int)
     cm[negInds] = cMap(lutInds[negInds])
     cm[posInds] = cMap(lutInds[posInds])
     if len(maskInds) != 0:
         cm[maskInds,-1] = 0
     return cm
Ejemplo n.º 6
0
def plotWave(W,
             freq,
             time,
             coi=None,
             powScale='log',
             cmap='coolwarm',
             xlabel='Time (sec)',
             ylabel='Freq (Hz)'):
    '''
    fh = plotWave(W,freq,time,...)
    Plots the matrix of wavelet coefficients W, using the specified freq and time axes

    Inputs:
    W - 2D matrix of wavelet coefficients (rows follow freq, cols follow time)
    freq - Freq vector (numpy array)
    time - Time vector (numpy array)
    coi - Cone of influence; if None or empty, time is used in its place
    powScale - 'log' plots log2(|W|); anything else plots |W|
    cmap - Colormap passed to imshow
    xlabel, ylabel - Axis labels
    Outputs:
    fh - The image object returned by plt.imshow
    '''
    import numpy as np
    import apCode.SignalProcessingTools as spt
    import matplotlib.pyplot as plt
    if powScale.lower() == 'log':
        W = np.log2(np.abs(W))
    else:
        W = np.abs(W)
    period = 1 / freq
    dt = time[1] - time[0]
    tt = np.hstack((time[[0, 0]] - dt * 0.5, time, time[[-1, -1]] + dt * 0.5))
    if coi is None or len(coi) == 0:
        coi = time
    coi_ext = np.log2(np.hstack((freq[[-1, 1]], 1 / coi, freq[[1, -1]])))
    # Convert the coi outline from data units to image (row/col) indices.
    tt = spt.nearestMatchingInds(tt, time)
    coi_ext = spt.nearestMatchingInds(coi_ext, freq)
    # Place y ticks at powers of 2 of the frequency.
    freq_log = np.unique(spt.nextPow2(freq))
    fTick = 2**freq_log
    yTick = 1 / fTick
    inds = spt.nearestMatchingInds(yTick, period)
    ytl = (1 / period[inds]).astype(int).ravel()
    fig = plt.imshow(W, aspect='auto', cmap=cmap)
    fig.axes.set_yticks(inds)
    fig.axes.set_yticklabels(np.array(ytl))
    plt.colorbar()

    xTick = np.linspace(0, len(time) - 1, 5).astype(int)
    # Label the ticks with the actual times at those columns, so the axis is
    # correct for a time vector of any duration.
    xtl = np.round(time[xTick], 2)
    fig.axes.set_xticks(xTick)
    fig.axes.set_xticklabels(xtl.ravel())
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.plot(tt, coi_ext, 'k--')
    return fig
Ejemplo n.º 7
0
    def plot_with_labels_interact(self,
                                  ta,
                                  x=None,
                                  cmap='tab20',
                                  figsize=(20, 10),
                                  marker_size=10,
                                  line_alpha=0.2,
                                  ylim=(-150, 150),
                                  title='Tail angles with GMM labels'):
        """Interactive (plotly) plot of the last row of ta with points
        colored and shaped by the label returned from self.predict.

        Parameters
        ----------
        ta: array
            Tail angles array; self.predict(ta) supplies the labels and
            ta[-1] is the trace that is drawn.
        x: array or None
            x-axis values; defaults to sample indices (ta.shape[1]).
        cmap: str or callable
            Name of a colormap in plt.cm, or a callable mapping a value in
            [0, 1] to an (r, g, b, a) tuple of floats in [0, 1].
        figsize: tuple
            Accepted for interface compatibility; not used here.
        marker_size: scalar
            Marker size for labelled points.
        line_alpha: scalar
            Opacity of the underlying trace.
        ylim: 2-tuple or None
            y-axis range; widened if the data exceed it. None leaves the
            range to plotly.
        title: str
            Figure title.

        Returns
        -------
        fig: plotly.graph_objs.Figure
        """
        import plotly.graph_objs as go
        if isinstance(cmap, str):
            # Attribute lookup instead of eval on a caller-supplied string.
            cmap = getattr(plt.cm, cmap)
        labels, _ = self.predict(ta)
        if x is None:
            x = np.arange(ta.shape[1])
        y = ta[-1]
        if self.pk_thr_ is not None:
            pks = spt.findPeaks(y, thr=self.pk_thr_, pol=0, thrType='rel')[0]
        else:
            pks = np.arange(len(y))

        line = go.Scatter(x=x,
                          y=y,
                          mode='lines',
                          opacity=line_alpha,
                          marker=dict(color='black'),
                          name='ta')
        scatters = [line]
        for lbl in np.unique(labels):
            # CSS-style rgba strings expect 0-255 color channels, whereas the
            # colormap returns floats in [0, 1].
            r, g, b, a = cmap(lbl / self.n_gmm_)
            clr = (f'rgba({int(round(255 * r))}, {int(round(255 * g))}, '
                   f'{int(round(255 * b))}, {float(a)})')
            inds = np.where(labels == lbl)[0]
            inds = np.intersect1d(inds, pks)
            scatter = go.Scatter(x=x[inds],
                                 y=y[inds],
                                 mode='markers',
                                 marker=dict(color=clr,
                                             symbol=lbl,
                                             size=marker_size),
                                 name=f'Lbl-{lbl}')
            scatters.append(scatter)
        fig = go.Figure(scatters)
        if ylim is not None:
            # Float dtype so that widening to y.min()/y.max() is not truncated
            # when ylim was given as integers.
            ylim = np.array(ylim, dtype=float)
            ylim[0] = np.minimum(ylim[0], y.min())
            ylim[1] = np.maximum(ylim[1], y.max())
            fig.layout.yaxis.range = [ylim[0], ylim[1]]
        fig.layout.xaxis.range = [x[0], x[-1]]
        fig.update_layout(title=title)
        return fig
Ejemplo n.º 8
0
 def getClrMapsForEachRegressor(betas,normed = True, cMap = 'PiYG', 
                                scaling =1, betaThr= None):
     """
     Given the coefficients (betas) from regression, returns a list of color
     maps, with each color map corresponding to the betas for a single regressor.
     These can be used by colorCellsInImgStack to create image stacks with
     cells colored by betas.
     Parameters:
     betas - Array with shape (nSamples, nFeatures). Not modified.
     normed - Boolean; If True, betas are passed through spt.standardize
         (preserveSign = True, axis = 0) and multiplied by scaling.
     cMap - Colormap name or colormap object.
     scaling - Multiplier applied to the normalized betas (only when normed
         is True).
     betaThr - None(default),scalar,'auto'; Determines if any thresholding
         should be applied based on beta values. If None, then no thresholding,
         if scalar, then for beta values whose magnitude is less than this
         scalar, the alpha value in the color maps is set to zero. If 'auto'
         then a threshold is computed separately for each regressor with
         volt.getGlobalThr.
     Returns:
     clrMaps - List with one (nSamples, 4) color array per regressor.
     """
     import numpy as np
     import apCode.SignalProcessingTools as spt       
     import matplotlib.pyplot as plt
     import apCode.volTools as volt

     def getClrMap(x,cMap,maskInds):
         # Negative values use colormap entries 0-127, non-negative values
         # use 128-255. Indices live in their own array so x stays untouched.
         cm = np.full((len(x),4), np.nan)
         negInds = np.where(x<0)[0]
         posInds = np.where(x>=0)[0]
         lutInds = np.zeros(len(x), dtype=int)
         negVals = spt.mapToRange(np.hstack((x[negInds],[0,-1])),[0,127])[0:-2]
         posVals = spt.mapToRange(np.hstack((x[posInds],[0,1])),[128,255])[0:-2]
         lutInds[negInds] = np.asarray(negVals).astype(int)
         lutInds[posInds] = np.asarray(posVals).astype(int)
         cm[negInds] = cMap(lutInds[negInds])
         cm[posInds] = cMap(lutInds[posInds])
         if len(maskInds) != 0:
             cm[maskInds,-1] = 0
         return cm

     if isinstance(cMap,str):
         cMap = plt.get_cmap(cMap)

     if normed:
         betas = spt.standardize(betas, preserveSign = True, axis = 0)*scaling
     autoThr = isinstance(betaThr, str) and betaThr == 'auto'
     clrMaps = []
     for beta in betas.T:
         if betaThr is None:
             maskInds = []
         else:
             # Keep the per-regressor threshold in a local so that 'auto' is
             # re-evaluated for every regressor.
             thr = volt.getGlobalThr(np.abs(beta)) if autoThr else betaThr
             maskInds = np.where(np.abs(beta)<thr)[0]
         clrMaps.append(getClrMap(beta,cMap,maskInds))
     return clrMaps
Ejemplo n.º 9
0
def plotWave(W,freq,time,coi = None,powScale = 'log',cmap = 'coolwarm', xlabel = 'Time (sec)', ylabel = 'Freq (Hz)'):
    '''
    fh = plotWave(W,freq,time,...)
    Plots the matrix of wavelet coefficients W, using the specified freq and time axes

    Inputs:
    W - 2D matrix of wavelet coefficients (rows follow freq, cols follow time)
    freq - Freq vector (numpy array)
    time - Time vector (numpy array)
    coi - Cone of influence; when None or empty, time is used instead
    powScale - 'log' to plot log2(|W|), otherwise |W| is plotted
    cmap - Colormap handed to imshow
    xlabel, ylabel - Axis labels
    Outputs:
    fh - Image object returned by plt.imshow
    '''
    import numpy as np
    import apCode.SignalProcessingTools as spt
    import matplotlib.pyplot as plt
    if powScale.lower() == 'log':
        W = np.log2(np.abs(W))
    else:
        W = np.abs(W)
    period = 1/freq
    dt = time[1]-time[0]
    tt = np.hstack((time[[0,0]]-dt*0.5,time,time[[-1,-1]]+dt*0.5))
    if coi is None or len(coi) == 0:
        coi = time
    coi_ext = np.log2(np.hstack((freq[[-1,1]],1/coi,freq[[1,-1]])))
    # Express the coi outline in image (row/col) index coordinates.
    tt = spt.nearestMatchingInds(tt,time)
    coi_ext = spt.nearestMatchingInds(coi_ext,freq)
    # y ticks sit at powers of 2 of the frequency.
    freq_log = np.unique(spt.nextPow2(freq))
    fTick = 2**freq_log
    yTick  = 1/fTick
    inds = spt.nearestMatchingInds(yTick,period)
    ytl = (1/period[inds]).astype(int).ravel()
    fig = plt.imshow(W, aspect = 'auto',cmap = cmap)
    fig.axes.set_yticks(inds)
    fig.axes.set_yticklabels(np.array(ytl))
    plt.colorbar()

    xTick = np.linspace(0,len(time)-1,5).astype(int)
    # Use the real time values at the tick columns as labels so the axis is
    # right for any signal duration.
    xtl = np.round(time[xTick], 2)
    fig.axes.set_xticks(xTick)
    fig.axes.set_xticklabels(xtl.ravel())
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.plot(tt,coi_ext,'k--')
    return fig
Ejemplo n.º 10
0
 def fit(self, ta):
     """Fit model to tail angles. This includes preprocessing wherein
     SVD-based feature extraction is performed, followed by PCA for
     dimensionality reduction, if specified.

     The SVD components and their first and second differences along time
     form the feature matrix. If self.use_envelopes_ is set, the features
     are replaced by max_min_envelopes of that matrix; otherwise the matrix
     is used as is. Features are then scaled, optionally restricted to the
     peaks of ta[-1] (when self.pk_thr_ is not None), optionally reduced by
     PCA (when self.pca_percVar_ is not None) and fed to the GMM.

     Parameters
     ----------
     self: object
         Instance of initiated SvdGmm class
     ta: array, (nPointsAlongTail, nTimePoints)
         Tail angles array
     Returns
     -------
     self: object
         Trained SvdGmm model, with svd, scaler, pca (None if PCA was
         skipped) and gmm attributes set.
     """
     if self.svd is None:
         svd = TruncatedSVD(n_components=self.n_svd_,
                            random_state=self.random_state_).fit(ta.T)
     else:
         # Reuse a previously fitted SVD.
         svd = self.svd
     V = svd.transform(ta.T)
     # np.gradient on a 2D array returns one array per axis; [0] is the
     # derivative along axis 0 (time).
     dv = np.gradient(V)[0]
     ddv = np.gradient(dv)[0]
     X = np.c_[V, dv, ddv]
     if self.use_envelopes_:
         features = max_min_envelopes(X.T).T
     else:
         # Without envelopes, fit on the raw SVD/derivative features.
         features = X
     scaler = StandardScaler(with_mean=self.scaler_withMean_).fit(features)
     features = scaler.transform(features)
     if self.pk_thr_ is not None:
         y = ta[-1]
         pks = spt.findPeaks(y, thr=self.pk_thr_, thrType='rel', pol=0)[0]
         print(f'Peaks are {round(100*len(pks)/len(y), 1)}% of all samples')
         features = features[pks, :]
     if self.pca_percVar_ is not None:
         pca = PCA(n_components=self.pca_percVar_,
                   random_state=self.random_state_).fit(features)
         features = pca.transform(features)
         pca.n_components_ = features.shape[1]
     else:
         pca = None
     print('Fitting GMM..')
     gmm = GMM(n_components=self.n_gmm_,
               random_state=self.random_state_,
               covariance_type=self.covariance_type_,
               **self.gmm_kwargs_)
     gmm = gmm.fit(features)
     self.svd = svd
     self.scaler = scaler
     self.pca = pca
     self.gmm = gmm
     return self
Ejemplo n.º 11
0
def getStimInfo(bas,ch_stim:str = 'patch3',ch_switch:str = 'patch2',\
                ch_camTrig = 'camTrigger', minPkDist =5*6000):
    """Detect stimuli in a BAS channel and describe each one.

    Peaks are found in bas[ch_stim] with spt.findPeaks (thr=0.5, pol=1,
    minPkDist=minPkDist) and only those with positive amplitude are kept.
    Each kept peak is tagged 'H' when bas[ch_switch] exceeds 1 at the peak
    sample and 'T' otherwise (presumably head/tail - confirm with caller).

    ch_camTrig is accepted but not used here.

    Returns
    -------
    (peak indices, peak amplitudes, array of 'H'/'T' tags)
    """
    import numpy as np
    import apCode.SignalProcessingTools as spt
    stimTrace = np.array(bas[ch_stim]).flatten()
    peakInds, peakAmps = spt.findPeaks(stimTrace,
                                       thr=0.5,
                                       pol=1,
                                       minPkDist=minPkDist)
    # Discard peaks whose amplitude is not positive.
    positive = np.where(peakAmps > 0)[0]
    peakInds = peakInds[positive]
    peakAmps = peakAmps[positive]
    switchTrace = np.array(bas[ch_switch]).flatten()
    switchAtPeaks = switchTrace[peakInds]
    # Start with every stimulus tagged 'T', then flip those with switch > 1.
    tags = np.array(['T'] * len(peakInds))
    tags[np.where(switchAtPeaks > 1)] = 'H'
    return peakInds, peakAmps, tags
Ejemplo n.º 12
0
def stimInds(bas,thr = 1, stimCh:str = 'stim0', minStimDist:int =5*6000, 
             normalize:bool = True):
    """
    Return the sample indices of stimuli detected in one channel of bas.

    bas: dict
        BehaveAndScan file read by "importCh".
    thr: scalar or str
        Threshold for detecting stimuli. If 'auto', then estimates a threshold
        from the non-negative part of the signal (assuming that stimulus
        polarity is positive, i.e. large positive values are stimuli).
    stimCh: str
        Name (key in bas) of the stimulus channel.
    minStimDist: int
        Minimum distance (# of samples) between successive stimuli.
    normalize: bool
        Whether to normalize the stimulus channel before detecting stimuli.
        If True, converts stim channel to z-score units. Useful option when
        absolute threshold value is unknown.

    Returns
    -------
    Array of stimulus indices (first output of spt.findPeaks). A message is
    printed when it is empty. None is returned only if spt.findPeaks itself
    returns an empty result.
    """
    import numpy as np
    import apCode.SignalProcessingTools as spt
    from apCode.volTools import getGlobalThr

    x = bas[stimCh].copy()
    if normalize:
        x = spt.zscore(x)

    if isinstance(thr, str) and thr.lower() == 'auto':
        x_pos = x[np.where(x>=0)]
        thr = getGlobalThr(x_pos)
    pks = spt.findPeaks(x,thr = thr, pol =1, minPkDist = minStimDist)
    if len(pks) == 0:
        print('No stimuli found!')
        return None
    inds = pks[0]
    # findPeaks returns a tuple whose first item holds the indices, so the
    # emptiness check has to look at that item, not at the tuple.
    if np.size(inds) == 0:
        print('No stimuli found!')
    return inds
# Script fragment: relies on data, pyData, alreadyRegressed, inputDir,
# pyFileName, tic and the modules os, time, np, spt and ia being defined
# earlier in the session.

# Cast the stored index fields to int and shift them down by one
# (presumably converting 1-based indices to 0-based - TODO confirm).
for cNum, cInds in enumerate(data['info']['inds']):
    data['info']['inds'][cNum] = cInds.astype(int)-1
data['info']['center'] = data['info']['center'].astype(int)-1
data['info']['slice'] = data['info']['slice'].astype(int)-1
data['info']['x_minmax'] = data['info']['x_minmax'].astype(int)-1
data['info']['y_minmax'] = data['info']['y_minmax'].astype(int)-1
# Only the first row of inMask is kept.
data['inds']['inMask'] = data['inds']['inMask'].astype(int)[0]-1

# Carry over existing regression results, if any.
if alreadyRegressed:
    data['regr']= pyData['regr']
data['pathToData'] = os.path.join(inputDir,pyFileName)
# Report elapsed time since tic and the current wall-clock time.
print('\n', int(time.time()-tic),'sec')
print(time.ctime())

print('Geting indices for stacks in ephys data and stims in optical data...')
# Stimulus times (from pyData) matched against the optical time axis.
data['inds']['stim'] = spt.nearestMatchingInds(pyData['time'].ravel()[pyData['stim']['inds']],data['time'])
# Optical time points matched against the pyData time axis.
data['inds']['stackInit'] = spt.nearestMatchingInds(data['time'],pyData['time'],processing = 'parallel')

# Detrending dF/F signals, one column of data['dFF'] at a time
print('Detrending dF/F signals...')
for cc in np.arange(np.shape(data['dFF'])[1]):
    data['dFF'][:,cc] = ia.detrendCa(data['dFF'][:,cc],data['inds']['stim'])

# An adjustment for some datasets: swap the last two axes of the average stack
data['avg'] = np.transpose(data['avg'],[0,2,1])
# Bare expression; only has an effect (echoing the shape) in an interactive cell.
np.shape(data['avg'])

#%% Check to see if stimulus onset indices have been detected correctly
# Parameters for the check that follows (units not visible here;
# presumably an index and two time windows in seconds - TODO confirm).
stimInd = 4
periTime_ephys = 0.1
periTime_opt = 8
Ejemplo n.º 14
0
def dataFrameOfMatchedMtrAndCaTrls_singleFish(bas, ca = None, ch_camTrig = 'camTrigger', ch_stim = 'patch3',
                                   ch_switch = 'patch2', thr_camTrig= 4, Fs_bas = 6000, t_pre_bas = 0.2,\
                                   t_post_bas = 1.5, t_pre_ca = 1, t_post_ca = 10, n_jobs = 20):
    """
    Segment motor (BAS) and Ca2+ imaging data around each stimulus and return
    the matched trials in a dataframe. Trials whose motor or Ca segment
    length differs from the most common length are dropped.

    Parameters
    ----------
    bas: dict
        Dictionary resulting from reading of BAS file
    ca: array, (nRois, nSamples) or (nSamples,), or None
        nSamples is expected to match the number of Camera triggers. If
        None, only frame indices (not Ca signals) are returned per trial.
    ch_camTrig, ch_stim, ch_switch: str
        Keys in bas for the camera trigger, stimulus and switch channels.
    thr_camTrig: scalar
        Threshold passed to getCamTrigInds.
    Fs_bas: scalar
        Sampling rate of the BAS data.
    t_pre_bas, t_post_bas: scalar
        Pre/post-stimulus window for the motor data; multiplied by Fs_bas.
    t_pre_ca, t_post_ca: scalar
        Pre/post-stimulus window for the Ca data; multiplied by the frame
        rate estimated from the camera triggers.
    n_jobs: int
        Passed to spt.segmentByEvents.
    Returns
    -------
    df: Pandas dataframe
        Columns: mtr_trl, caInds_trl, ca_trl (only if ca is given), trlNum
        (index among kept trials), trlNum_actual (index among all trials),
        stimAmp and stimLoc.
    """
    import numpy as np
    from scipy.stats import mode
    import apCode.SignalProcessingTools as spt
    from apCode import util
    import pandas as pd

    def motorFromBas(bas):
        # Use the last channel whose key matches 'den'; otherwise fall back
        # to the raw channel pair ch3/ch4 or ch1/ch2.
        keys = list(bas.keys())
        ind_mtr = util.findStrInList('den', keys)
        if len(ind_mtr) > 0:
            ind_mtr = ind_mtr[-1]
            x = np.squeeze(np.array(bas[keys[ind_mtr]]))
        else:
            if 'ch3' in bas:
                x = np.array([bas['ch3'], bas['ch4']])
            else:
                x = np.array([bas['ch1'], bas['ch2']])
        if x.shape[0] < x.shape[1]:
            # Put samples along axis 0.
            x = x.T
        return x

    mtr = motorFromBas(bas)
    stimInds, stimAmps, stimHt = getStimInfo(bas, ch_stim = ch_stim,\
                                             ch_camTrig= ch_camTrig, ch_switch= ch_switch)
    camTrigInds = getCamTrigInds(bas[ch_camTrig], thr=thr_camTrig)

    dt_ca = np.mean(np.diff(bas['t'][camTrigInds]))
    Fs_ca = int(np.round(1 / dt_ca))
    if ca is None:
        nFrames = len(camTrigInds)
    else:
        # Promote a single 1-D trace before its shape is read.
        if np.ndim(ca) == 1:
            ca = ca[np.newaxis, ...]
        nFrames = ca.shape[1]
        d = len(camTrigInds) - ca.shape[1]
        if d > 0:
            print(f'Check threshold, {d} more camera triggers than expected')
        elif d < 0:
            print(f'Check threshold, {-d} fewer camera triggers than expected')
    n_pre_bas = int(np.round(t_pre_bas * Fs_bas))
    n_post_bas = int(np.round(t_post_bas * Fs_bas))
    n_pre_ca = int(np.round(t_pre_ca * Fs_ca))
    n_post_ca = int(np.round(t_post_ca * Fs_ca))

    stimFrames = spt.nearestMatchingInds(stimInds, camTrigInds) - 2
    mtr_trl = spt.segmentByEvents(mtr,
                                  stimInds,
                                  n_pre_bas,
                                  n_post_bas,
                                  n_jobs=n_jobs)
    indsVec_ca = np.arange(nFrames)
    indsVec_ca_trl = spt.segmentByEvents(indsVec_ca,
                                         stimFrames,
                                         n_pre_ca,
                                         n_post_ca,
                                         n_jobs=n_jobs)
    lens_mtr = np.array([len(mtr_) for mtr_ in mtr_trl])
    lens_ca = np.array([len(inds_) for inds_ in indsVec_ca_trl])
    # Trials with an atypical segment length are dropped from all outputs.
    inds_del_mtr = np.where(lens_mtr != mode(lens_mtr)[0])[0]
    inds_del_ca = np.where(lens_ca != mode(lens_ca)[0])[0]
    inds_del_trl = np.union1d(inds_del_mtr, inds_del_ca)
    # Filter with plain indexing rather than np.delete: the segments being
    # removed are exactly the ones that make the list ragged, and ragged
    # lists cannot be converted to an array.
    drop = set(inds_del_trl.tolist())
    keep = [i for i in range(len(mtr_trl)) if i not in drop]
    trlNum_actual = np.array(keep, dtype=int)
    mtr_trl = [mtr_trl[i] for i in keep]
    indsVec_ca_trl = [indsVec_ca_trl[i] for i in keep]
    dic = dict(mtr_trl=mtr_trl, caInds_trl=indsVec_ca_trl)
    if ca is not None:
        ca_trl = [ca[:, inds_] for inds_ in indsVec_ca_trl]
        dic['ca_trl'] = ca_trl
    trlNum = np.arange(len(mtr_trl))
    dic['trlNum'] = trlNum
    dic['trlNum_actual'] = trlNum_actual
    dic['stimAmp'] = np.delete(stimAmps, inds_del_trl)
    dic['stimLoc'] = np.delete(stimHt, inds_del_trl)
    return pd.DataFrame(dic, columns=dic.keys())
Ejemplo n.º 15
0
def wavelet(y,
            t,
            dj=1 / 32,
            mother='morlet',
            pad=1,
            param=-1,
            freqScale='log',
            freqRange=None,
            **kwargs):
    '''
    Computes the wavelet transform of a timeseries
    [W,period,scale,coi] = wavelet(...)
    Parameters
    ----------
    y: 1D array
        Timeseries to obtain wavelet transform for
    t: Scalar or 1D array
        If scalar,then sampling interval (dt), else time sequence
    pad: Boolean
        Zero padding; pad = 1 results in padding of the signal with zeros till
        the next largest length that is a power of 2.
    dj: Scalar
        Wavelet scale resolution; defaults to 1/32. With freqScale = 'lin'
        it is used as the step of the frequency vector instead.
    mother: String
        Mother wavelet function (Morlet (default), Paul, or DOG)
    param: Scalar
        Wave number, passed on to get_fourier_factor and wave_bases
        (-1 presumably selects the mother wavelet's default, e.g. 6 for
        Morlet - confirm in wave_bases)
    freqScale: String, ['log'] | 'lin'
        Base-2 logarithmic or linear frequency scale.
    freqRange - 2 tuple or list or array, or None:
            Determines the frequency range over which to compute the wavelets.
            If None, then computes for the full possible range, determined by
            s0 and the time length of the timeseries signal.
    **kwargs
        s0 - Smallest scale;  defaults to 2*dt, where dt is sampling interval
        J1 - Index of the largest scale; log scales are s0*2**(j*dj) for
            j = 0..J1
    Outputs:
    W - Matrix of wavelet coefficients
    period - Vector of Fourier periods over which WT is computed
    scale  - Vector of wavelet scales used for computation
    coi  - Cone of influence; a vector of values that show where edge effects become significant
    Raises:
    ValueError - If freqRange leaves no scale to compute.
    '''
    import numpy as np
    import apCode.SignalProcessingTools as spt

    n1 = len(y)
    if np.size(t) == 1:
        dt = t
    else:
        dt = t[1] - t[0]

    # Smallest scale
    s0 = kwargs.get('s0', 2 * dt)

    J1 = kwargs.get('J1', int(np.fix(np.log2(n1 * dt / s0) / dj)))

    #...demean and zero pad timeseries if specified
    y = y - np.mean(y)
    if pad == 1:
        y = spt.zeroPadToNextPowOf2(y)

    #...construct wavenumber array used in transform (eqn 5)
    N = len(y)
    k = (np.arange(N / 2) + 1) * ((2 * np.pi) / (N * dt))
    # Built-in int: the np.int alias no longer exists in current NumPy.
    k_neg = -k[int(np.fix((N - 1) / 2)):0:-1]
    k = np.hstack((0, k, k_neg))

    #... compute fft of the padded timeseries (eqn 3)
    f = np.fft.fft(y)

    #...construct SCALE array & empty PERIOD and WAVE arrays
    fourier_factor = get_fourier_factor(mother=mother, k0=param)

    if freqScale.lower() == 'log':
        scale = s0 * 2**(np.arange(0, J1 + 1) * dj)
        # "is not None" so that an array-valued freqRange is not compared
        # elementwise (which would make the condition ambiguous).
        if freqRange is not None:
            freq = 1 / (scale * fourier_factor)
            minF = np.min(freqRange)
            maxF = np.max(freqRange)
            keepInds = np.where((freq >= minF) & (freq <= maxF))[0]
            freq = freq[keepInds]
            scale = 1. / (freq * fourier_factor)
    else:
        maxF = 1 / s0
        minF = 1 / (n1 * dt)
        freq = np.arange(minF, maxF, dj)
        if freqRange is not None:
            minF = np.min(freqRange)
            maxF = np.max(freqRange)
            keepInds = np.where((freq >= minF) & (freq <= maxF))[0]
            freq = freq[keepInds]
        scale = 1 / (freq * fourier_factor)

    if len(scale) == 0:
        raise ValueError('No wavelet scales fall within freqRange')

    wave = np.zeros((len(scale), N), dtype=complex)

    for sNum, s in enumerate(scale):
        daughter, fourier_factor, coi, dofmin = wave_bases(k,
                                                           s,
                                                           mother=mother,
                                                           param=param)
        wave[sNum, :] = np.fft.ifft(f * daughter)

    period = fourier_factor * scale
    vec1 = np.arange(1, (n1 + 1) / 2 - 1)
    vec2 = np.arange((n1 / 2 - 1), 0, -1)
    vec = np.hstack((1e-5, vec1, vec2, 1e-5))
    coi = coi * dt * vec
    # Drop the zero-padded tail so W has one column per input sample.
    wave = wave[:, 0:n1]
    return wave, period, scale, coi
Ejemplo n.º 16
0
def expand_on_bends(df_trl,
                    Fs=500,
                    tPre_ms=100,
                    bendThr=10,
                    minLat_ms=5,
                    maxGap_ms=100):
    """Takes dataframe where each row contains single trial information and
    expands such that each row contains single bend information

    Bends are the peaks (either polarity) of the band-pass filtered
    (5-60 Hz) last row of each trial's 'tailAngles'. Peaks occurring before
    tPre_ms + minLat_ms, or following the previous peak by more than
    maxGap_ms, are discarded. Trials left with 3 or fewer bends contribute a
    single row with nBends = 0 and NaN bend fields.

    Parameters
    ----------
    df_trl: pandas dataframe, (nTrlsInTotal, nVariables)
        Must have the columns 'trlIdx_glob' and 'tailAngles'.
    Fs: int
        Sampling frequency when collecting data(images)
    tPre_ms: scalar
        Pre-stimulus duration of each trial, in ms.
    bendThr: scalar
        Relative threshold passed to spt.findPeaks.
    minLat_ms: scalar
        Minimum latency after stimulus for a bend to be kept, in ms.
    maxGap_ms: scalar
        Maximum allowed interval between successive bends, in ms.

    Returns
    -------
    pandas dataframe
        df_trl merged (on 'trlIdx_glob') with the per-bend columns nBends,
        bendIdx, bendSampleIdxInTrl, bendAmp, bendAmp_abs, bendAmp_rel,
        bendInt_ms and onset_ms.
    """
    import apCode.SignalProcessingTools as spt
    minPkDist = int((10e-3) * Fs)
    nPre = tPre_ms * 1e-3 * Fs
    minLat = minLat_ms * 1e-3 * Fs
    maxGap = maxGap_ms * 1e-3 * Fs
    df_bend = []
    for iTrl in np.unique(df_trl.trlIdx_glob):
        df_now = df_trl.loc[df_trl.trlIdx_glob == iTrl]
        y = df_now.iloc[0]['tailAngles'][-1]
        y = spt.chebFilt(y, 1 / Fs, (5, 60), btype='bandpass')
        pks = spt.findPeaks(y,
                            thr=bendThr,
                            thrType='rel',
                            pol=0,
                            minPkDist=minPkDist)[0]
        if len(pks) > 3:
            dpks = np.diff(pks)
            tooSoon = np.where(pks < (nPre + minLat))[0]
            tooSparse = np.where(dpks > maxGap)[0] + 1
            inds_del = np.union1d(tooSoon, tooSparse)
            pks = np.delete(pks, inds_del, axis=0)
        if len(pks) > 3:
            bendAmp = y[pks]
            dic = dict(trlIdx_glob=iTrl,
                       nBends=len(pks),
                       bendIdx=np.arange(len(pks)),
                       bendSampleIdxInTrl=pks,
                       bendAmp=bendAmp,
                       bendAmp_abs=np.abs(bendAmp),
                       bendAmp_rel=np.insert(np.abs(np.diff(bendAmp)), 0,
                                             np.abs(bendAmp[0])),
                       bendInt_ms=np.gradient(pks) * (1 / Fs) * 1000,
                       onset_ms=(pks[0] - nPre + 1) * (1 / Fs) * 1000)
            df_now = pd.DataFrame(dic)
        else:
            # Every field is reset here so nothing carries over from the
            # previous trial.
            dic = dict(trlIdx_glob=iTrl,
                       nBends=0,
                       bendIdx=np.nan,
                       bendSampleIdxInTrl=np.nan,
                       bendAmp=np.nan,
                       bendAmp_abs=np.nan,
                       bendAmp_rel=np.nan,
                       bendInt_ms=np.nan,
                       onset_ms=np.nan)
            # All values are scalars, so an explicit one-row index is needed.
            df_now = pd.DataFrame(dic, index=[0])
        df_bend.append(df_now)
    df_bend = pd.concat(df_bend, ignore_index=True)
    return pd.merge(df_trl, df_bend, on='trlIdx_glob')
# Script fragment: relies on data, ax, labels, color_idx and the modules np,
# plt and spt being defined earlier in the session.
# NOTE(review): xtl is used below, but its only visible definition is the
# commented-out line that follows - confirm it is defined before this point.
#xtl = (np.floor(np.linspace(data['time'][0],data['time'][-1],10)/100)*100).astype(int) 
# Convert tick labels (time values) to tick positions in samples.
xt = xtl/(data['time'][1]-data['time'][0])
ax.xaxis.set_ticks(xt)
ax.xaxis.set_ticklabels(xtl)

ax.tick_params(labelsize = 14)
plt.ylabel('Clstr #', fontsize = 16)
plt.xlabel('Time (sec)', fontsize = 16)
plt.title('Cell data, sorted by cluster', fontsize = 18);


#%% Create rgb img stack with cells colored by cluster ID
import nvCode.tifffile as tff
import apCode.volTools as volt

imgStack_norm = spt.standardize(data['avg'].copy())
# Zero-filled stack with a trailing axis of length 3 (one per color channel).
imgStack = np.transpose(np.tile(np.zeros(np.shape(imgStack_norm)),[3,1,1,1]),[1,2,3,0])
sliceList = np.arange(np.shape(data['avg'])[0])
#sliceList = [30]
for z in sliceList:   
    # Cells whose slice index equals z.
    inds_slice = np.where(data['info']['slice']==z)[0]    
    # Start from the grayscale slice converted to RGB.
    imgStack[z,:,:,:] = volt.gray2rgb(imgStack_norm[z])
    for lbl in np.unique(labels):        
        # Cells carrying this cluster label, restricted to the current slice.
        inds_lbl = data['inds']['inMask'].ravel()[np.where(labels == lbl)[0]]
        inds_cell = np.intersect1d(inds_slice,inds_lbl)     
        if len(inds_cell)>0:
            pxls_cell = [data['info']['inds'][ind] for ind in inds_cell]
            #pxls = np.squeeze(np.array(pxls_cell)).ravel().astype(int)
            # RGB part of the colormap entry for this label (alpha dropped).
            rgb= plt.cm.RdYlGn(color_idx[lbl])[0:3]           
            for pxl in pxls_cell:                
                # Flat integer pixel indices of one cell; the snippet ends
                # here, before the pixels are actually colored.
                pxl = pxl.ravel().astype(int)
Ejemplo n.º 18
0
# Script fragment (notebook export): relies on Fs, preStimPer, postStimPer,
# epDir, preFile, postFile, stimCh and the modules importlib, aed, os, time,
# spt, np and plt being defined earlier in the session.
dt = 1 / Fs
# Peri-stimulus window lengths converted to samples (not cast to int here).
preStimPts = preStimPer * Fs
postStimPts = postStimPer * Fs

# In[112]:

# Reload aed so that edits to the module are picked up, then read both files.
importlib.reload(aed)
pre = aed.import10ch(os.path.join(epDir, preFile))
post = aed.import10ch(os.path.join(epDir, postFile))
print(pre.keys())
print(time.ctime())

# In[155]:

#%% Detect and check stimulus indices
# Peaks of the z-scored stimulus channel above 3, shifted 2 samples earlier
# (reason for the shift is not visible here - TODO confirm).
pre['stimInds'] = spt.findPeaks(spt.zscore(pre[stimCh]), thr=3)[0] - 2
post['stimInds'] = spt.findPeaks(spt.zscore(post[stimCh]), thr=3)[0] - 2

# Common x-range: the overlap of the two recordings' time axes.
xmin = np.max((pre['t'][0], post['t'][0]))
xmax = np.min((pre['t'][-1], post['t'][-1]))
plt.style.use(('seaborn-dark', 'seaborn-colorblind', 'seaborn-poster'))
# Top panel: stimulus channel with detected stimulus indices marked.
plt.subplot(2, 1, 1)
plt.plot(pre['t'], pre[stimCh])
plt.plot(pre['t'][pre['stimInds']],
         pre[stimCh][pre['stimInds']],
         'o',
         markersize=15)
plt.xlim(xmin, xmax)
plt.title('Pre-Ablation')

# Bottom panel (its contents lie beyond the end of this snippet).
plt.subplot(2, 1, 2)
Ejemplo n.º 19
0
def estimateCaDecayKinetics(time, signals, p0 = None, thr = 2, preTime = 10, 
                            postTime = 40):
    """
    Estimates fast and slow Ca2+ decay time constants by fitting a double
        exponential to the peak-triggered average of each signal.

    For every row of signals, peaks are detected in the z-scored trace, the
        raw trace is segmented around each peak and the segments are
        averaged. The part of that average from its maximum onward is
        standardized and fit with
            wt_fast*exp(-t/tau_fast) + (1-wt_fast)*exp(-t/tau_slow)
    Parameters:
    time - Time vector of length T (at least 3 points; the sampling interval
        is taken as time[2]-time[1])
    signals - Ca signals array of shape (nSamples,T), or a 1D array of
        length T
    p0 - Array-like, (tau_fast, tau_slow, wt_fast), initial estimate for the
        fit, where tau_fast is the fast decay time constant (in sec),
        tau_slow is the slow decay constant, and wt_fast is the weight of
        the fast exponential (<1). Every value must lie within the fit
        bounds [0, 20]. Default is None, in which case (10, 20, 0.5) is used
    thr - Threshold for peak detection in Ca signals, in units of zscore
    preTime - Pre-peak time length of the Ca signals to include for segmentation
    postTime - Post-peak "           "          "               "
    Returns:
    out - Dictionary with keys
        'raw': Peak-triggered averages (one row per retained sample),
            standardized along the time axis
        'fit': Fitted double exponential curves; only curves whose length
            matches the reference length chosen by listToArray are kept
        'params': Array with rows (tau_fast, tau_slow, wt_fast), where
            tau_fast <= tau_slow
        'excludedSamples': Indices of rows of signals that were dropped,
            either because no peak was detected or because the segment
            length did not match
    Avinash Pujala, JRC, 2017
        
    """
    import numpy as np
    from scipy.optimize import curve_fit as cf
    import apCode.SignalProcessingTools as spt
    import apCode.AnalyzeEphysData as aed
    
    def doubleExp(time, tau1, tau2, wt1):    
        # Weighted sum of two decaying exponentials; the weights sum to 1 and
        # time is re-zeroed at its first sample.
        wt2 = 1-wt1
        time = time - time[0]
        return wt1*np.exp(-time/tau1) + wt2*np.exp(-time/tau2)
    
    def listToArray(x):
        # Stack equal-length items into a 2D array. The reference length is
        # that of the 3rd item (or of the last item when there are fewer
        # than 3); items of any other length are dropped and their positions
        # returned in delInds.
        lens = [len(item) for item in x]
        refLen = lens[min(len(lens)-1, 2)]
        a = np.zeros((len(x),refLen))
        delInds = []
        for itemNum,item in enumerate(x):
            if len(item) == refLen:
                a[itemNum,:] = item
            else:
                delInds.append(itemNum)
        a = np.delete(a,delInds,axis = 0)
        return a, delInds

    # Honor the caller's initial estimate; fall back to the historical one.
    if p0 is None:
        p0 = [10, 20, 0.5]
    if np.ndim(signals)==1:
        signals = np.reshape(signals,(1,len(signals)))
    dt = time[2]-time[1]
    pts_post = np.round(postTime/dt).astype(int)
    pts_pre = np.round(preTime/dt).astype(int) 
    x_norm = spt.zscore(signals,axis = 1)
    x_seg, params, x_seg_fit = [],[],[]
    nSamples = np.shape(signals)[0]
    excludedSamples = np.zeros((nSamples,1))
    for nSample in np.arange(nSamples):
        inds_pk = spt.findPeaks(x_norm[nSample,:],thr = thr,ampType = 'rel')[0]
        if len(inds_pk)==0:
            print('Peak detection failed for sample #', nSample, '. Try lowering threshold')
            excludedSamples[nSample] = 1
        else:
            # Peak-triggered average of the raw (not z-scored) trace
            segs = aed.SegmentDataByEvents(signals[nSample,:],inds_pk,pts_pre,pts_post,axis = 0)
            avgSeg = np.mean(listToArray(segs)[0],axis=0)
            x_seg.append(avgSeg) 
            # Fit only the decay, i.e. from the first maximum onward
            ind_max = np.where(avgSeg == np.max(avgSeg))[0][0]
            y = spt.standardize(avgSeg[ind_max:])
            t = np.arange(len(y))*dt           
            popt,pcov = cf(doubleExp,t,y,p0 = p0, bounds = (0,20))
            if popt[0]> popt[1]:
                # Order the taus as (fast, slow). Fancy indexing copies the
                # right-hand side, so this is a true swap; the weight is
                # flipped so that it keeps referring to the fast component.
                popt[[0,1]] = popt[[1,0]]
                popt[-1] = 1-popt[-1]
            params.append(popt)
            x_seg_fit.append(doubleExp(t,popt[0],popt[1],popt[2]))
    excludedSamples = np.where(excludedSamples)[0]
    includedSamples = np.setdiff1d(np.arange(nSamples),excludedSamples)
    x_seg,delInds = listToArray(x_seg)
    params = np.delete(np.array(params),delInds,axis = 0)
    # Translate positions within the fitted subset back to row indices of signals
    delInds = includedSamples[delInds]
    if len(delInds)>0:
        print('Sample #', delInds, 'excluded for short segment length. Consider decreasing pre-peak time length')
    excludedSamples = np.union1d(delInds,excludedSamples)
    
    x_seg = spt.standardize(np.array(x_seg),axis = 1)    
    x_seg_fit = np.array(listToArray(x_seg_fit)[0])
    out = {'raw': x_seg,'fit': x_seg_fit,'params': np.array(params),'excludedSamples': excludedSamples}
    return out
Ejemplo n.º 20
0
def wavelet(y,t,dj = 1/32, mother = 'morlet',pad = 1, param = -1, freqScale = 'log',
            freqRange = None, **kwargs):
    '''
    Computes the wavelet transform of a timeseries
    [W,period,scale,coi] = wavelet(...)
    Parameters
    ----------
    y: 1D array
        Timeseries to obtain wavelet transform for
    t: Scalar or 1D array
        If scalar, then sampling interval (dt), else time sequence, in which
        case dt is taken as t[1]-t[0]
    dj: Scalar
        Wavelet scale resolution for the log scale; defaults to 1/32. With
        freqScale = 'lin' it is used as the frequency step instead.
    mother: String
        Name of the mother wavelet; forwarded to get_fourier_factor and
        wave_bases (Morlet is the default)
    pad: Boolean
        Zero padding; pad = 1 results in padding of the signal with zeros till
        the next largest length that is a power of 2.
    param: Scalar
        Wavelet parameter forwarded to get_fourier_factor (as k0) and to
        wave_bases. NOTE(review): -1 presumably selects the mother wavelet's
        default (6 for Morlet) inside those helpers -- confirm there.
    freqScale: String, ['log'] | 'lin'
        Base-2 logarithmic or linear frequency scale.
    freqRange - 2 tuple, list or array, or None:
            Determines the frequency range over which to compute the wavelets.
            If None, then computes for the full possible range, determined by
            the smallest scale and the time length of the timeseries signal.
    **kwargs
        s0 - Smallest scale;  defaults to 2*dt, where dt is sampling interval
        J1 - Number of scales minus one for the log scale; defaults to
            fix(log2(n*dt/s0)/dj)
    Outputs:
    W - Matrix of (complex) wavelet coefficients, one row per scale
    period - Vector of Fourier periods over which WT is computed
    scale  - Vector of wavelet scales used for computation
    coi  - Cone of influence; a vector of values that show where edge effects become significant
    Raises:
    ValueError - If no scale is left after applying freqRange
    '''
    import numpy as np
    import apCode.SignalProcessingTools as spt
    
    n1 = len(y)
    if np.ndim(t)==0:
        dt = t
    else:
        dt = t[1]-t[0]    
    
    # Smallest scale
    s0 = kwargs.get('s0')
    if s0 is None:
        s0 = 2*dt    
    
    J1 = kwargs.get('J1')     
    if J1 is None:
        J1 = int(np.fix(np.log2(n1*dt/s0)/dj))    
    
    #...demean and zero pad timeseries if specified
    y = y - np.mean(y)
    if pad ==1:
        y = spt.zeroPadToNextPowOf2(y)
    
    #...construct wavenumber array used in transform (eqn 5)
    # Layout follows the FFT bin order: 0, positive wavenumbers 1..floor(N/2),
    # then the negative ones -floor((N-1)/2)..-1, for a total length of N.
    N = len(y)
    k_pos = np.arange(1, N//2 + 1)*((2*np.pi)/(N*dt))
    k_neg = -k_pos[:(N-1)//2][::-1]
    k = np.hstack((0,k_pos,k_neg))
    
    #... compute fft of the padded timeseries (eqn 3)
    f = np.fft.fft(y)
    
    #...construct SCALE array
    fourier_factor = get_fourier_factor(mother = mother, k0 = param)
    
    if freqScale.lower()== 'log':
        scale = s0*2**(np.arange(0,J1+1)*dj)        
        if freqRange is not None:
            freq = 1/(scale*fourier_factor)
            minF = np.min(freqRange)
            maxF = np.max(freqRange)
            keepInds = np.where((freq>=minF) & (freq <= maxF))[0]
            freq = freq[keepInds]
            scale = 1./(freq*fourier_factor)
    else:
        maxF = 1/s0
        minF = 1/(n1*dt)
        freq = np.arange(minF,maxF,dj)
        if freqRange is not None:
            minF = np.min(freqRange)
            maxF = np.max(freqRange)
            keepInds = np.where((freq>=minF) & (freq <= maxF))[0]
            freq = freq[keepInds]
        scale = 1/(freq*fourier_factor)        
    
    if len(scale) == 0:
        # Without at least one scale the loop below never runs and coi would
        # be undefined.
        raise ValueError('No wavelet scales within the requested frequency range')
    
    wave = np.zeros((len(scale),N), dtype = complex) # instantiate the wavelet array
    
    for sNum, s in enumerate(scale):
        daughter, fourier_factor,coi, dofmin = wave_bases(k,s, mother = mother, param = param)
        wave[sNum,:] = np.fft.ifft(f * daughter)
    
    period = fourier_factor*scale
    # Triangular weighting of the COI: distance (in samples) to the nearest
    # edge of the unpadded series; has length n1 for both even and odd n1.
    vec1 = np.arange(1,(n1+1)//2)
    vec2 = np.arange(n1//2-1,0,-1)
    vec = np.hstack((1e-5,vec1,vec2,1e-5))
    coi = coi*dt*vec
    wave = wave[:,0:n1]    # drop the zero padding
    return wave,period,scale,coi
Ejemplo n.º 21
0
                              replace=True)
# NOTE(review): script fragment; var, n_samples, wts, spt and np are defined
# earlier. Each feature below is drawn uniformly (with replacement) from an
# integer range; the upper end of np.arange is exclusive.
var['sat_score'] = np.random.choice(np.arange(1200, 1601),
                                    size=n_samples,
                                    replace=True)
var['hours_of_research_experience'] = np.random.choice(np.arange(20, 110),
                                                       size=n_samples,
                                                       replace=True)
var['age_of_applicant'] = np.random.choice(np.arange(18, 22),
                                           size=n_samples,
                                           replace=True)
var['harvard_entrance_test_score'] = np.random.choice(np.arange(70, 101), size = n_samples,\
   replace = True)

# Design matrix: stack the 5 features as rows, then transpose so that rows are
# samples and columns are features before standardizing each column.
X = np.array([var['gpa'], var['sat_score'], var['hours_of_research_experience'], var['age_of_applicant'],\
     var['harvard_entrance_test_score']])
X = spt.standardize(X.T, axis=0)

# Target: weighted sum of the features plus uniform noise in [0, 1) scaled by
# the standard deviation of the noiseless target.
y = np.dot(X, wts)
y = y + np.random.rand(len(y)) * np.std(y)
# NOTE(review): presumably spt.standardize rescales to [0, 1], which would put
# y in [0.2, 1.0] -- confirm against spt.
y = spt.standardize(y) * 0.8 + 0.2

var['chances_of_acceptance_into_harvard'] = y

data = var

#%%
# Fit ordinary least squares to the synthetic data and report the estimates.
from sklearn.linear_model import LinearRegression
reg = LinearRegression().fit(X, y)

print('Regression coefficients are {}, and intercept is {}'.format(
    reg.coef_, reg.intercept_))
Ejemplo n.º 22
0
def estimateCaDecayKinetics(time,
                            signals,
                            p0=None,
                            thr=2,
                            preTime=10,
                            postTime=40):
    """
    Estimates fast and slow Ca2+ decay time constants by fitting a double
        exponential to the peak-triggered average of each signal.

    For every row of signals, peaks are detected in the z-scored trace, the
        raw trace is segmented around each peak and the segments are
        averaged. The part of that average from its maximum onward is
        standardized and fit with
            wt_fast*exp(-t/tau_fast) + (1-wt_fast)*exp(-t/tau_slow)
    Parameters:
    time - Time vector of length T (at least 3 points; the sampling interval
        is taken as time[2]-time[1])
    signals - Ca signals array of shape (nSamples,T), or a 1D array of
        length T
    p0 - Array-like, (tau_fast, tau_slow, wt_fast), initial estimate for the
        fit, where tau_fast is the fast decay time constant (in sec),
        tau_slow is the slow decay constant, and wt_fast is the weight of
        the fast exponential (<1). Every value must lie within the fit
        bounds [0, 20]. Default is None, in which case (10, 20, 0.5) is used
    thr - Threshold for peak detection in Ca signals, in units of zscore
    preTime - Pre-peak time length of the Ca signals to include for segmentation
    postTime - Post-peak "           "          "               "
    Returns:
    out - Dictionary with keys
        'raw': Peak-triggered averages (one row per retained sample),
            standardized along the time axis
        'fit': Fitted double exponential curves; only curves whose length
            matches the reference length chosen by listToArray are kept
        'params': Array with rows (tau_fast, tau_slow, wt_fast), where
            tau_fast <= tau_slow
        'excludedSamples': Indices of rows of signals that were dropped,
            either because no peak was detected or because the segment
            length did not match
    Avinash Pujala, JRC, 2017

    """
    import numpy as np
    from scipy.optimize import curve_fit as cf
    import apCode.SignalProcessingTools as spt
    import apCode.AnalyzeEphysData as aed

    def doubleExp(time, tau1, tau2, wt1):
        # Weighted sum of two decaying exponentials; the weights sum to 1 and
        # time is re-zeroed at its first sample.
        wt2 = 1 - wt1
        time = time - time[0]
        return wt1 * np.exp(-time / tau1) + wt2 * np.exp(-time / tau2)

    def listToArray(x):
        # Stack equal-length items into a 2D array. The reference length is
        # that of the 3rd item (or of the last item when there are fewer
        # than 3); items of any other length are dropped and their positions
        # returned in delInds.
        lens = [len(item) for item in x]
        refLen = lens[min(len(lens) - 1, 2)]
        a = np.zeros((len(x), refLen))
        delInds = []
        for itemNum, item in enumerate(x):
            if len(item) == refLen:
                a[itemNum, :] = item
            else:
                delInds.append(itemNum)
        a = np.delete(a, delInds, axis=0)
        return a, delInds

    # Honor the caller's initial estimate; fall back to the historical one.
    if p0 is None:
        p0 = [10, 20, 0.5]
    if np.ndim(signals) == 1:
        signals = np.reshape(signals, (1, len(signals)))
    dt = time[2] - time[1]
    pts_post = np.round(postTime / dt).astype(int)
    pts_pre = np.round(preTime / dt).astype(int)
    x_norm = spt.zscore(signals, axis=1)
    x_seg, params, x_seg_fit = [], [], []
    nSamples = np.shape(signals)[0]
    excludedSamples = np.zeros((nSamples, 1))
    for nSample in np.arange(nSamples):
        inds_pk = spt.findPeaks(x_norm[nSample, :], thr=thr, ampType='rel')[0]
        if len(inds_pk) == 0:
            print('Peak detection failed for sample #', nSample,
                  '. Try lowering threshold')
            excludedSamples[nSample] = 1
        else:
            # Peak-triggered average of the raw (not z-scored) trace
            segs = aed.SegmentDataByEvents(signals[nSample, :],
                                           inds_pk,
                                           pts_pre,
                                           pts_post,
                                           axis=0)
            avgSeg = np.mean(listToArray(segs)[0], axis=0)
            x_seg.append(avgSeg)
            # Fit only the decay, i.e. from the first maximum onward
            ind_max = np.where(avgSeg == np.max(avgSeg))[0][0]
            y = spt.standardize(avgSeg[ind_max:])
            t = np.arange(len(y)) * dt
            popt, pcov = cf(doubleExp, t, y, p0=p0, bounds=(0, 20))
            if popt[0] > popt[1]:
                # Order the taus as (fast, slow). Fancy indexing copies the
                # right-hand side, so this is a true swap; the weight is
                # flipped so that it keeps referring to the fast component.
                popt[[0, 1]] = popt[[1, 0]]
                popt[-1] = 1 - popt[-1]
            params.append(popt)
            x_seg_fit.append(doubleExp(t, popt[0], popt[1], popt[2]))
    excludedSamples = np.where(excludedSamples)[0]
    includedSamples = np.setdiff1d(np.arange(nSamples), excludedSamples)
    x_seg, delInds = listToArray(x_seg)
    params = np.delete(np.array(params), delInds, axis=0)
    # Translate positions within the fitted subset back to row indices of
    # signals
    delInds = includedSamples[delInds]
    if len(delInds) > 0:
        print(
            'Sample #', delInds,
            'excluded for short segment length. Consider decreasing pre-peak time length'
        )
    excludedSamples = np.union1d(delInds, excludedSamples)

    x_seg = spt.standardize(np.array(x_seg), axis=1)
    x_seg_fit = np.array(listToArray(x_seg_fit)[0])
    out = {
        'raw': x_seg,
        'fit': x_seg_fit,
        'params': np.array(params),
        'excludedSamples': excludedSamples
    }
    return out
Ejemplo n.º 23
0
# In[112]:

# NOTE(review): notebook-style fragment; epDir, preFile, postFile, stimCh and
# the modules aed, spt, np, plt, os, time, importlib must already be defined
# by earlier cells.

# Reload the ephys module, then load the two recordings being compared.
importlib.reload(aed)
pre = aed.import10ch(os.path.join(epDir,preFile))
post = aed.import10ch(os.path.join(epDir,postFile))
print(pre.keys())
print(time.ctime())


# In[155]:


#%% Detect and check stimulus indices
# Peaks of the z-scored stimulus channel above thr=3, shifted back by 2 samples
# (presumably to land on the stimulus onset -- TODO confirm).
pre['stimInds'] = spt.findPeaks(spt.zscore(pre[stimCh]),thr=3)[0]-2
post['stimInds'] = spt.findPeaks(spt.zscore(post[stimCh]),thr=3)[0]-2

# Common time window covered by both recordings, used as shared x-limits.
xmin = np.max((pre['t'][0],post['t'][0]))
xmax = np.min((pre['t'][-1],post['t'][-1]))
plt.style.use(('seaborn-dark','seaborn-colorblind','seaborn-poster'))
# Top panel: first recording's stimulus channel with detected stimuli marked.
plt.subplot(2,1,1)
plt.plot(pre['t'],pre[stimCh])
plt.plot(pre['t'][pre['stimInds']],pre[stimCh][pre['stimInds']],'o',markersize =15)
plt.xlim(xmin,xmax)
plt.title('Pre-Ablation')

# Bottom panel: same view for the second recording.
plt.subplot(2,1,2)
plt.plot(post['t'],post[stimCh])
plt.plot(post['t'][post['stimInds']],post[stimCh][post['stimInds']],'o',markersize =15)                                 
plt.xlim(xmin,xmax)
Ejemplo n.º 24
0
def readPeriStimulusTifImages(tifPaths, basPaths, nBasCh = 16, ch_camTrig = 'patch4', ch_stim = 'patch3',\
                  tifNameStr = '', time_preStim = 1, time_postStim = 10, thr_stim = 1.5,\
                  thr_camTrig = 1, maxAllowedTimeBetweenStimAndCamTrig = 0.5, n_jobs = 1):
    """
    Given the paths to .tif files stored by ScanImage (Bessel beam image settings) and the
    paths to the accompanying bas files, returns a dictionary with values holding peri-stimulus
    image data (Ca activity) in trialized format along with some other pertinent info.

    Parameters
    ----------
    tifPaths: list of strings
        Paths to the .tif files written by ScanImage; passed to
        apCode.FileTools.scanImageTifInfo to read the image metadata.
    basPaths: list of strings
        Full paths to the bas (BehavAndScan) files accompanying the imaging session. The
        files are read one by one and joined with concatenateBas.
    nBasCh: scalar
        Number of signal channels in the bas file.
    ch_camTrig: string
        Name of the channel in bas corresponding to the camera trigger signal
    ch_stim: string
        Name of the stimulus signal channel in bas.
    tifNameStr: string
        Currently unused; kept so that existing callers keep working.
    time_preStim: scalar
        The length of the pre-stimulus time to include when reading images
    time_postStim: scalar
        The length of the post-stimulus time.
    thr_stim: scalar
        Threshold to use for detection of stimuli in the stimulus channel of bas.
    thr_camTrig: scalar
        Threshold to use for detection of camera trigger onsets in the camera trigger channel of bas.
    maxAllowedTimeBetweenStimAndCamTrig: scalar
        If a camera trigger is separated in time by the nearest stimulus by longer than this time
        interval, then ignore this stimulus trial.
    n_jobs: scalar int
        If < 2, trials are read serially; otherwise trial reads are wrapped in dask.delayed
        and computed together.
    Returns
    -------
    D: dict or None
        None if the tif files disagree on the number of image channels. If no stimuli or no
        camera triggers are detected, a dictionary holding only the key 'bas'. Otherwise a
        dictionary contaning the following keys:
        'I': array, (nTrials, nTime, nImageChannels, ...image dimensions)
            Image hyperstack arranged in conveniently-accessible trialized format
            (singleton dimensions are squeezed out).
        'tifInfo': dict
            Dictionary holding useful image metadata. Has following keys:
            'filePaths': list of strings
                Paths to .tif files
            'nImagesInfile': scalar int
                Number of images in each .tif file after accounting of number of
                image channels
            'nChannelsInFile': scalar int
                Number of image channels
        'inds_stim': array of integers, (nStim,)
            Indices in bas coordinates where stimuli occurred.
        'inds_stim_img': array of integers, (nStim,)
            Indices in image coordinates where stimuli occurred
        'inds_camTrig': array of integers, (nCameraTriggers,)
            Indices in bas coordinates corresponding to the onsets of camera triggers.
        'bas': dict
            BehavAndScan data
        'inds_trl_excluded': array of integers
            Trials dropped because the nearest camera trigger was too far from the stimulus
            or because the trial was shorter than the longest one.

    """
    import tifffile as tff
    import numpy as np
    import apCode.FileTools as ft
    import apCode.ephys as ephys
    import apCode.SignalProcessingTools as spt
    import apCode.util as util

    def getImgIndsInTifs(tifInfo):
        # One contiguous range of global frame indices (all channels counted)
        # per tif file, in file order.
        nImgsInFile_cum = np.cumsum(tifInfo['nImagesInFile'] *
                                    tifInfo['nChannelsInFile'])
        imgIndsInTifs = []
        start = 0
        for stop in nImgsInFile_cum:
            imgIndsInTifs.append(np.arange(start, stop))
            start = stop
        return imgIndsInTifs

    def trlImages(inds_imgsInTifs, nImgCh, tifInfo, trl):
        # Read all frames (every channel) of one trial, possibly spanning
        # several tif files, and return them as (nTime, nImgCh, ...).
        trl_ = np.arange(trl.min() * nImgCh, (trl.max() + 1) * nImgCh)
        loc = util.locateItemsInSetsOfItems(trl_, inds_imgsInTifs)
        I_ = []
        for subInds, supInd in zip(loc['subInds'], loc['supInds']):
            with tff.TiffFile(tifInfo['filePaths'][supInd]) as tif:
                img = tif.asarray(key=subInds)
            I_.extend(img.reshape(-1, nImgCh, *img.shape[1:]))
        return np.array(I_)

    ### Read relevant metadata from tif files in directory
    print('Reading ScanImage metadata from tif files...')

    tifInfo = ft.scanImageTifInfo(tifPaths)
    nCaImgs = np.sum(tifInfo['nImagesInFile'])

    ### Check for consistency in the number of image channels in all files.
    if len(np.unique(tifInfo['nChannelsInFile'])) > 1:
        print('Different number of image channels across files, check files!')
        return None
    nImgCh = tifInfo['nChannelsInFile'][0]
    print(f'{nCaImgs} {nImgCh}-channel images from all tif files')

    ### Get a list of indices corresponding to images in each of the tif files
    inds_imgsInTifs = getImgIndsInTifs(tifInfo)

    ### Read bas file to get stimulus and camera trigger indices required to align images and behavior
    print(
        'Reading and joining bas files, detecting stimuli and camera triggers...'
    )
    basList = [ephys.importCh(bp, nCh=nBasCh) for bp in basPaths]
    bas = concatenateBas(basList)
    inds_stim = spt.levelCrossings(bas[ch_stim], thr=thr_stim)[0]
    if len(inds_stim) == 0:
        print(
            f'Only {len(inds_stim)} stims detected, check channel specification or threshold'
        )
        return dict(bas=bas)
    inds_camTrig = spt.levelCrossings(bas[ch_camTrig], thr=thr_camTrig)[0]
    if len(inds_camTrig) == 0:
        print(
            f'Only {len(inds_camTrig)} cam trigs detected, check channel specification or threshold'
        )
        return dict(bas=bas)
    # Frame interval estimated from the triggers, rounded to 10 ms; triggers
    # following the previous one by half a frame or less are treated as
    # spurious and removed.
    dt_vec = np.diff(bas['t'][inds_camTrig])
    dt_ca = np.round(np.mean(dt_vec) * 100) / 100
    print('Ca sampling rate = {}'.format(1 / dt_ca))
    inds_del = np.where(dt_vec <= (0.5 * dt_ca))[0] + 1
    inds_camTrig = np.delete(inds_camTrig, inds_del)

    ### Deal with possible mismatch in number of camera trigger indices and number of images in tif files
    nCaImgs_extra = 0
    if nCaImgs < len(inds_camTrig):
        inds_camTrig = inds_camTrig[:nCaImgs]
    elif nCaImgs > len(inds_camTrig):
        nCaImgs_extra = nCaImgs - len(inds_camTrig)
    # Report the surplus in every case (not only when it is zero).
    print('{} extra Ca2+ images'.format(nCaImgs_extra))
    print('{} stimuli and {} camera triggers'.format(len(inds_stim),
                                                     len(inds_camTrig)))

    ### Indices of ca images closest to stimulus
    inds_stim_img = spt.nearestMatchingInds(inds_stim, inds_camTrig)

    ### Find trials where the nearest cam trigger is farther than the stimulus by a certain amount
    inds_camTrigNearStim = inds_camTrig[inds_stim_img]
    t_stim = bas['t'][inds_stim]
    t_camTrigNearStim = bas['t'][inds_camTrigNearStim]
    inds_tooFar = np.where(
        np.abs(t_stim -
               t_camTrigNearStim) > maxAllowedTimeBetweenStimAndCamTrig)[0]
    inds_ca_all = np.arange(nCaImgs)
    nPreStim = int(time_preStim / dt_ca)
    nPostStim = int(time_postStim / dt_ca)
    print("{} pre-stim points, and {} post-stim points".format(
        nPreStim, nPostStim))
    # Keep the trials as a list of 1D index arrays: trials cut short by the
    # start/end of the recording make the collection ragged, which np.array
    # refuses to build on recent NumPy versions.
    inds_ca_trl = [
        np.asarray(trl_) for trl_ in spt.segmentByEvents(
            inds_ca_all, inds_stim_img + nCaImgs_extra, nPreStim, nPostStim)
    ]
    ### Find trials that are too short to include the pre- or post-stimulus period
    trlLens = np.array([len(trl_) for trl_ in inds_ca_trl])
    maxLen = trlLens.max() if trlLens.size > 0 else 0
    inds_tooShort = np.where(trlLens < maxLen)[0]
    inds_trl_del = np.union1d(inds_tooFar, inds_tooShort)
    inds_trl_keep = np.setdiff1d(np.arange(len(inds_ca_trl)), inds_trl_del)

    ### Exclude the above 2 types of trials from consideration
    if len(inds_trl_del) > 0:
        print('Excluding the trials {}'.format(inds_trl_del))
        inds_ca_trl = [inds_ca_trl[ind] for ind in inds_trl_keep]

    I = []
    print('Reading trial-related images from tif files...')
    nTrls = len(inds_ca_trl)

    if n_jobs < 2:
        # Progress is reported about 5 times; at least every trial when
        # there are fewer than 5 trials (avoids a modulo by zero).
        chunkSize = max(1, nTrls // 5)
        for iTrl, trl in enumerate(inds_ca_trl):
            if iTrl % chunkSize == 0:
                print('Trl # {}/{}'.format(iTrl + 1, nTrls))
            I.append(trlImages(inds_imgsInTifs, nImgCh, tifInfo, trl))
    else:
        print('Processing with dask')
        import dask
        from dask.diagnostics import ProgressBar
        for trl in inds_ca_trl:
            I.append(
                dask.delayed(trlImages)(inds_imgsInTifs, nImgCh, tifInfo,
                                        trl))
        with ProgressBar():
            I = dask.compute(*I)

    D = dict(I = np.squeeze(np.array(I)), tifInfo = tifInfo, inds_stim = inds_stim, inds_stim_img = inds_stim_img,\
             inds_camTrig = inds_camTrig,bas = bas, inds_trl_excluded = inds_trl_del)
    return D
Ejemplo n.º 25
0
    def plotCentroids(time,
                      centroids,
                      stimTimes,
                      time_ephys,
                      ephys,
                      scaled=False,
                      colors=None,
                      xlabel='',
                      ylabel='',
                      title=''):
        """
        Plots centroids resulting from some clustering method, stacked
        vertically (one unit apart), with the ephys trace drawn below them.
        Parameters:
        time - Time vector for centroids (optical sampling interval). If None,
            sample indices are used instead.
        centroids - Array of shape (M, N), where M is the number of centroids, and N is the #
            number of features (or time points)
        stimTimes - Times of stimulus onsets for overlaying vertical dashed lines
        time_ephys - Time axis for ephys data (usually sampled at higher rate)
        ephys - Ephys time series
        scaled - Boolean; If true, scales centroids individually, else scales jointly.
        colors - Array of shape (M,3) or (M,4). Colormap to use for plotting centroids.
            If None, the seaborn 'colorblind' palette is used; if it has fewer
            than M rows, the rows are repeated cyclically.

        """
        import apCode.SignalProcessingTools as spt
        import seaborn as sns
        import numpy as np
        import matplotlib.pyplot as plt
        if scaled:
            centroids = spt.standardize(centroids, axis=1)
        else:
            centroids = spt.standardize(centroids)

        ephys = spt.standardize(ephys)

        n_clusters = np.shape(centroids)[0]
        if colors is None:
            colors = np.array(sns.color_palette('colorblind', n_clusters))
        else:
            # Accept lists of tuples as well as arrays
            colors = np.asarray(colors)
            if np.shape(colors)[0] < n_clusters:
                colors = np.tile(colors, (n_clusters, 1))[:n_clusters, :]

        if time is None:
            time = np.arange(np.shape(centroids)[1])

        # Matplotlib renamed its seaborn style sheets ('seaborn-v0_8-*') and
        # later dropped the old names; use whichever one is installed.
        if 'seaborn-poster' in plt.style.available:
            poster_style = 'seaborn-poster'
        else:
            poster_style = 'seaborn-v0_8-poster'
        plt.style.use(['dark_background', poster_style])
        # Centroid number cc is mean-subtracted and drawn around y = -cc
        for cc in np.arange(n_clusters):
            plt.plot(time,
                     centroids[cc, :] - np.mean(centroids[cc, :]) - cc,
                     color=colors[cc, :])
        # Ephys trace goes one row below the last centroid
        plt.plot(time_ephys,
                 ephys - np.mean(ephys) - n_clusters,
                 color=colors[0, :])
        yt = np.arange(n_clusters + 1)
        ytl = list(yt)
        ytl[-1] = 'ephys'
        plt.yticks(-yt, ytl)
        plt.xlabel(xlabel, fontsize=16)
        plt.ylabel(ylabel, fontsize=16)
        # Booleans are required here: the string 'off' is truthy and would
        # switch the frame and the grid ON.
        plt.box(False)
        plt.title(title, fontsize=20)
        plt.grid(False)
        plt.xlim(time[0], time[-1])
        for st in stimTimes:
            plt.axvline(x=st,
                        ymin=0,
                        ymax=1,
                        alpha=0.3,
                        color='w',
                        linestyle='--')
Ejemplo n.º 26
0
# NOTE(review): script fragment; time, sitr, os, np, spt, dir_imgs and
# files_ca are defined earlier in the script.

# Time how long it takes to read the first image file with ScanImageTiffReader.
tic = time.time()
with sitr.ScanImageTiffReader(os.path.join(dir_imgs, files_ca[0])) as reader:
    I_si = reader.data()
print("SI time = {}".format(time.time() - tic))

#%%
ch_camTrig = 'patch1'
ch_stim = 'patch3'
frameRate = 50

# The bas file is looked up in the parent directory of the image directory.
path_bas = os.path.join(os.path.split(dir_imgs)[0], 't01_bas.16ch')
import apCode.ephys as ephys
bas = ephys.importCh(path_bas, nCh=16)
print(bas.keys())

# Stimulus onsets: threshold crossings of the stimulus channel, dropping any
# crossing that follows the previous one by less than 15 (in units of
# bas['t'], presumably seconds -- TODO confirm).
inds_stim = spt.levelCrossings(bas[ch_stim], thr=2)[0]
dInds = np.diff(bas['t'][inds_stim])
inds_del = np.where(dInds < 15)[0] + 1
inds_stim = np.delete(inds_stim, inds_del)
# Camera triggers: same approach, dropping crossings that follow the previous
# one by half a frame interval (0.5 / frameRate) or less.
inds_camTrig = spt.levelCrossings(bas[ch_camTrig], thr=2)[0]
dInds = np.diff(bas['t'][inds_camTrig])
inds_del = np.where(dInds <= (0.5 / frameRate))[0] + 1
inds_camTrig = np.delete(inds_camTrig, inds_del)
# plt.plot(bas[ch_camTrig])
# plt.plot(inds_camTrig,bas[ch_camTrig][inds_camTrig],'o')
print('# stims = {}, # of camera triggers = {}'.format(len(inds_stim),
                                                       len(inds_camTrig)))

#%%
#%%
maxAllowedTimeBetweenStimAndCamTrig = 0.5  # In sec
Ejemplo n.º 27
0
# NOTE(review): script fragment; A, P, paramNMFD, paramSTFT, deltaT, deltaF,
# numComp, outDir, fs and the helpers NMFD, alphaWienerFilter,
# visualizeComponentsNMF, inverseSTFT, wav, spt, np, os come from earlier
# parts of the script.

# NMFD core method
nmfdW, nmfdH, nmfdV, divKL, _ = NMFD(A, paramNMFD)

# alpha-Wiener filtering
nmfdA, _ = alphaWienerFilter(A, nmfdV, 1.0)


#%% 4. Visualize
paramVis = dict()
paramVis['deltaT'] = deltaT
paramVis['deltaF'] = deltaF
paramVis['endeSec'] = 3.8
paramVis['fontSize'] = 14
fh1, _ = visualizeComponentsNMF(A, nmfdW, nmfdH, nmfdA, paramVis)
import matplotlib.pyplot as plt
plt.show()

#%% 5. Write audio
audios = []
# resynthesize results of NMF with soft constraints and score information
for k in range(numComp):
    # Combine the component's filtered magnitude with the phase P and invert
    # the STFT to get that component's time-domain signal.
    Y = nmfdA[k] * np.exp(1j * P);
    y, _ = inverseSTFT(Y, paramSTFT)

    audios.append(y)
    # save result
# NOTE(review): the lines below run once, after the loop. k therefore holds
# the last component index (numComp - 1) in the file name, and only the sum of
# components 0 and 2 is written (requires numComp >= 3) -- confirm intended.
out_filepath = os.path.join(outDir, f'drums_{k}.wav')
y_out = audios[0] + audios[2]
# NOTE(review): presumably spt.standardize maps to [0, 1], so *2-1 rescales
# the audio to [-1, 1] -- confirm against spt.
wav.write(filename=out_filepath, rate=fs, data=spt.standardize(y_out)*2-1)

# NOTE(review): script fragment; data, ax, plt, np, spt, labels and color_idx
# are defined earlier, and the fragment is cut off at the innermost loop
# header in this copy of the file.

# Tick labels every 500 time units; tick positions are the labels divided by
# the sampling interval, i.e. expressed in samples.
xtl = np.arange(0, data['time'][-1], 500).astype(int)
#xtl = (np.floor(np.linspace(data['time'][0],data['time'][-1],10)/100)*100).astype(int)
xt = xtl / (data['time'][1] - data['time'][0])
ax.xaxis.set_ticks(xt)
ax.xaxis.set_ticklabels(xtl)

ax.tick_params(labelsize=14)
plt.ylabel('Clstr #', fontsize=16)
plt.xlabel('Time (sec)', fontsize=16)
plt.title('Cell data, sorted by cluster', fontsize=18)

#%% Create rgb img stack with cells colored by cluster ID
import nvCode.tifffile as tff
import apCode.volTools as volt

imgStack_norm = spt.standardize(data['avg'].copy())
# Zero-filled stack with a trailing axis of length 3 appended to the shape of
# imgStack_norm (tile along a new leading axis, then move that axis last).
# Assumes imgStack_norm is 3D (slice, row, col) -- TODO confirm.
imgStack = np.transpose(
    np.tile(np.zeros(np.shape(imgStack_norm)), [3, 1, 1, 1]), [1, 2, 3, 0])
sliceList = np.arange(np.shape(data['avg'])[0])
#sliceList = [30]
for z in sliceList:
    # Entries of data['info']['slice'] that belong to slice z
    inds_slice = np.where(data['info']['slice'] == z)[0]
    # Start from the grayscale slice converted to rgb
    imgStack[z, :, :, :] = volt.gray2rgb(imgStack_norm[z])
    for lbl in np.unique(labels):
        # Map the rows carrying this label back through the 'inMask' index
        # array, then keep only those that also lie in the current slice.
        inds_lbl = data['inds']['inMask'].ravel()[np.where(labels == lbl)[0]]
        inds_cell = np.intersect1d(inds_slice, inds_lbl)
        if len(inds_cell) > 0:
            pxls_cell = [data['info']['inds'][ind] for ind in inds_cell]
            #pxls = np.squeeze(np.array(pxls_cell)).ravel().astype(int)
            # RGB triplet for this label from the RdYlGn colormap (alpha dropped)
            rgb = plt.cm.RdYlGn(color_idx[lbl])[0:3]
            for pxl in pxls_cell:
Ejemplo n.º 29
0
 def filtImgs(I,filtSize = 5, kernel = 'median', process = 'parallel'):
     '''
     Spatially filters every image of a stack (e.g. to make moving particle
     tracking easier)
     I_flt = filtImgs(I, filtSize = 5, kernel = 'median', process = 'parallel')
     Parameters
     ----------
     I: 2D array of shape (M,N) or 3D array of shape (T,M,N), where T = # of
         images, M, N = # of rows and columns respectively
         Image (stack) to filter
     filtSize: Scalar or 2-tuple
         Size of the kernel to generate if kernel is string. For 'median' only
         the first element is used and an even size is increased by 1.
     kernel: String or 2D array
         ['median'] | 'rect' | 'gauss' or array specifying the kernel. An
         array kernel is scaled to unit sum unless its sum is zero.
     process: String
         ['parallel'] | any other string. 'parallel' filters the images
         concurrently with joblib (at most 32 workers); anything else filters
         them in the calling process.
     Returns
     -------
     I_flt: array
         Filtered images, same size as I; if the stack holds a single image
         the result is squeezed.
     Raises
     ------
     ValueError
         If kernel is a string other than 'median', 'rect' or 'gauss', or an
         array that is not 2D.
     '''
     from scipy import signal
     import numpy as np
     import time

     parFlag = process.lower() == 'parallel'
     if parFlag:
         # joblib is needed (and therefore imported) only for the parallel path
         from joblib import Parallel, delayed
         import multiprocessing
         num_cores = np.min((multiprocessing.cpu_count(),32))

     tic = time.time()
     if np.ndim(I)<3:
         I = I[np.newaxis,:,:]
     N = np.shape(I)[0]

     def _convolveStack(ker2d):
         # Convolves every image with the 2D kernel. Both paths use
         # mode = 'same' so the output keeps the image size, and both return
         # an array (joblib itself returns a list).
         if parFlag:
             out = Parallel(n_jobs = num_cores,verbose = 5)(
                 delayed(signal.convolve)(img,ker2d,mode = 'same') for img in I)
             return np.array(out)
         # A singleton leading axis keeps the kernel from mixing images
         return signal.convolve(I,ker2d[np.newaxis,:,:],mode = 'same')

     print('Filter dimensions: {0}'.format(np.shape(filtSize)))

     if isinstance(kernel, str):
         kerName = kernel.lower()
         if kerName == 'median':
             print('Median filtering...')
             if np.size(filtSize)>1:
                 filtSize = filtSize[0]
             if np.mod(filtSize,2)==0:
                 filtSize = filtSize+1 # For median, the filter size should be odd
                 print('Median filter size must be odd, changed to {}'.format(filtSize))
             if parFlag:
                 print('# of cores = {}'.format(num_cores))
                 I_flt = Parallel(n_jobs = num_cores,verbose = 5)(delayed(signal.medfilt2d)(img,filtSize) for img in I)
                 I_flt = np.array(I_flt)
             else:
                 I_flt = np.zeros(np.shape(I))
                 for imgNum, img in enumerate(I):
                     if np.mod(imgNum,300)==0:
                         print('Img # {0}/{1}'.format(imgNum,N))
                     I_flt[imgNum,:,:] = signal.medfilt2d(img,filtSize)
         elif kerName == 'rect':
             print('Rectangular filtering...')
             if np.size(filtSize)==1:
                 ker = np.ones((filtSize,filtSize))
             else:
                 ker = np.ones(filtSize)
             I_flt = _convolveStack(ker/ker.sum())
         elif kerName == 'gauss':
             print('Gaussian filtering...')
             # Project-local helper, imported here because only this branch uses it
             import apCode.SignalProcessingTools as spt
             if np.size(filtSize)==1:
                 ker = spt.gausswin(filtSize)
                 ker = ker.reshape((-1,1))
                 ker = ker*ker.T
             else:
                 # Separable kernel: rows sized by filtSize[0], columns by filtSize[1]
                 ker1 = spt.gausswin(filtSize[0]).reshape((-1,1))
                 ker2 = spt.gausswin(filtSize[1]).reshape((-1,1))
                 ker = ker1*ker2.T
             I_flt = _convolveStack(ker/ker.sum())
         else:
             # Fail loudly instead of returning an unfiltered, all-zero stack
             raise ValueError("Unknown kernel '{}'; expected 'median', 'rect', 'gauss' or a 2D array".format(kernel))
     else:
         ker = np.asarray(kernel, dtype = float)
         if ker.ndim != 2:
             raise ValueError('Kernel array must be 2D, got {} dimension(s)'.format(ker.ndim))
         kerSum = ker.sum()
         if kerSum != 0:
             # Zero-sum kernels (e.g. edge detectors) cannot be scaled to unit sum
             ker = ker/kerSum
         I_flt = _convolveStack(ker)

     print(int(time.time()-tic), 'sec')
     if np.shape(I_flt)[0]==1:
         I_flt = np.squeeze(I_flt)

     return I_flt