Example #1
def extract_wave_new(IndList, FilteredArr, s_before, s_after, n_ch, s_start,Threshold):
    IndArr = np.array(IndList, dtype=np.int32)
    SampArr = IndArr[:, 0]
    ChArr = IndArr[:, 1]
    n_ch = FilteredArr.shape[1]
    log_fd = GlobalVariables['log_fd']
    if np.amax(SampArr)-np.amin(SampArr)>Parameters['CHUNK_OVERLAP']/2:
        s = '''
        ************ ERROR **********************************************
        Connected component found with width larger than CHUNK_OVERLAP/2.
        Spikes could be repeatedly detected, increase the size of
        CHUNK_OVERLAP and re-run.
        Component sample range: {sample_range}
        *****************************************************************
        '''.format(sample_range=(s_start+np.amin(SampArr),
                                 s_start+np.amax(SampArr)))
        log_warning(s, multiline=True)
        #exit()

    bc = np.bincount(ChArr)
    # convert to bool and force it to have the right type
    ChMask = np.zeros(n_ch, dtype=np.bool8)
    ChMask[:len(bc)] = bc.astype(np.bool8)
    
    # Find peak sample:
    # 1. upsample channels we're using on thresholded range
    # 2. find weighted mean peak sample
    SampArrMin, SampArrMax = np.amin(SampArr)-3, np.amax(SampArr)+4
    WavePlus = get_padded(FilteredArr, SampArrMin, SampArrMax)
    WavePlus = WavePlus[:, ChMask]
    # upsample WavePlus
    upsampling_factor = Parameters['UPSAMPLING_FACTOR']
    if upsampling_factor>1:
        old_s = np.arange(WavePlus.shape[0])
        new_s_i = np.arange((WavePlus.shape[0]-1)*upsampling_factor+1)
        new_s = np.array(new_s_i*(1.0/upsampling_factor), dtype=np.float32)
        f = interp1d(old_s, WavePlus, bounds_error=True, kind='cubic', axis=0)
        UpsampledWavePlus = f(new_s)
    else:
        UpsampledWavePlus = WavePlus
    # find weighted mean peak for each channel above threshold
    if Parameters['USE_WEIGHTED_MEAN_PEAK_SAMPLE']:
        peak_sum = 0.0
        total_weight = 0.0
        for ch in xrange(WavePlus.shape[1]):
            X = UpsampledWavePlus[:, ch]
            if Parameters['DETECT_POSITIVE']:
                X = -np.abs(X)
            i_intpeak = np.argmin(X)
            left, right = i_intpeak-1, i_intpeak+2
            if right>len(X):
                left, right = left+len(X)-right, len(X)
            elif left<0:
                left, right = 0, right-left
            a_b_c = abc(np.arange(left, right, dtype=np.float32),
                        X[left:right])
            s_fracpeak = max_t(a_b_c)
            if Parameters['USE_SINGLE_THRESHOLD']:
                weight = -(X[i_intpeak]+Threshold)
            else:
                weight = -(X[i_intpeak]+Threshold[ch])
            if weight<0:
                weight = 0
            peak_sum += s_fracpeak*weight
            total_weight += weight
        s_fracpeak = (peak_sum/total_weight)
    else:
        if Parameters['DETECT_POSITIVE']:
            X = -np.abs(UpsampledWavePlus)
        else:
            X = UpsampledWavePlus
        s_fracpeak = 1.0*np.argmin(np.amin(X, axis=1))
    # s_fracpeak currently in coords of UpsampledWavePlus
    s_fracpeak = s_fracpeak/upsampling_factor
    # s_fracpeak now in coordinates of WavePlus
    s_fracpeak += SampArrMin
    # s_fracpeak now in coordinates of FilteredArr
    
    # get block of given size around peaksample
    try:
        s_peak = int(s_fracpeak)
    except ValueError:
        # This is a bit of a hack. Essentially, the problem here is that
        # s_fracpeak is a nan because the interpolation didn't work, and
        # therefore we want to skip the spike. There's already code in
        # core.extract_spikes that does this if a LinAlgError is raised,
        # so we just use that to skip this spike (and write a message to the
        # log).
        raise np.linalg.LinAlgError 
    WaveBlock = get_padded(FilteredArr,
                           s_peak-s_before-1, s_peak+s_after+2)
    
    # Perform interpolation around the fractional peak
    old_s = np.arange(s_peak-s_before-1, s_peak+s_after+2)
    new_s = np.arange(s_peak-s_before, s_peak+s_after)+(s_fracpeak-s_peak)
    f = interp1d(old_s, WaveBlock, bounds_error=True, kind='cubic', axis=0)
    Wave = f(new_s)
    
    return Wave, s_peak, ChMask
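
The helpers abc and max_t are not included in this listing; from the way they are called, they presumably fit a parabola through the three samples around the integer peak and return the fractional position of its vertex. A minimal standalone sketch of the upsample-then-parabola step under that assumption (the abc/max_t bodies here are illustrative, not the project's actual implementation):

import numpy as np
from scipy.interpolate import interp1d

def abc(t, y):
    # Fit y = a*t**2 + b*t + c through three points; returns (a, b, c).
    return np.polyfit(t, y, 2)

def max_t(a_b_c):
    # Vertex of the fitted parabola: the fractional sample of the extremum.
    a, b, _ = a_b_c
    return -b / (2.0 * a)

# Toy single-channel waveform with a minimum at sample 4.4
WavePlus = ((np.arange(10.0) - 4.4) ** 2).reshape((-1, 1))

upsampling_factor = 4
old_s = np.arange(WavePlus.shape[0])
new_s = np.arange((WavePlus.shape[0] - 1) * upsampling_factor + 1) / float(upsampling_factor)
f = interp1d(old_s, WavePlus, bounds_error=True, kind='cubic', axis=0)
X = f(new_s)[:, 0]

i = np.argmin(X)
s_fracpeak = max_t(abc(np.arange(i - 1, i + 2, dtype=np.float32), X[i - 1:i + 2]))
print(s_fracpeak / upsampling_factor)   # ~4.4, back in original sample coordinates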
Example #2
def extract_wave(IndList, FilteredArr, s_before, s_after, n_ch, s_start,Threshold):
    '''
    Extract an aligned wave corresponding to a spike.
    
    Arguments:
    
    IndList
        A list of pairs (sample_number, channel_number) returned from the
        thresholding and flood filling algorithm
    FilteredArr
        An array of shape (numsamples, numchannels) containing the filtered
        wave data
    s_before, s_after
        The number of samples to return before and after the peak
        
    Returns a tuple (Wave, PeakSample, ChMask):
    
    Wave
        The wave aligned around the peak (with interpolation to give subsample
        alignment), consisting of s_before+s_after+1 samples.
    PeakSample
        The index of the peak sample in FilteredArr (the peak sample in Wave
        will always be s_before).
    ChMask
        The mask for this spike: a boolean array with one entry per channel,
        True if the channel is used and False otherwise.
    '''
    if Parameters['USE_WEIGHTED_MEAN_PEAK_SAMPLE'] or Parameters['UPSAMPLING_FACTOR']>1:
        return extract_wave_new(IndList, FilteredArr,
                                s_before, s_after, n_ch, s_start,Threshold)
    IndArr = np.array(IndList, dtype=np.int32)
    SampArr = IndArr[:, 0]
    log_fd = GlobalVariables['log_fd']
    if np.amax(SampArr)-np.amin(SampArr)>Parameters['CHUNK_OVERLAP']/2:
        s = '''
        ************ ERROR **********************************************
        Connected component found with width larger than CHUNK_OVERLAP/2.
        Spikes could be repeatedly detected, increase the size of
        CHUNK_OVERLAP and re-run.
        Component sample range: {sample_range}
        *****************************************************************
        '''.format(sample_range=(s_start+np.amin(SampArr),
                                 s_start+np.amax(SampArr)))
        log_warning(s, multiline=True)
        #exit()
    ChArr = IndArr[:, 1]
    n_ch = FilteredArr.shape[1]
    
    # Find peak sample and channel
    # TODO: argmin only works for negative threshold crossings
    PeakInd = FilteredArr[SampArr, ChArr].argmin()
    PeakSample, PeakChannel = SampArr[PeakInd], ChArr[PeakInd]
    
    # Ensure that we get a fixed size chunk of the wave, padded with zeroes if
    # the segment from PeakSample-s_before-1 to PeakSample+s_after+1 goes
    # outside the bounds of FilteredArr.
    WavePlus = get_padded(FilteredArr,
                          PeakSample-s_before-1, PeakSample+s_after+1)
    # Perform interpolation around the fractional peak
    Wave = interp_around_peak(WavePlus, s_before+1,
                              PeakChannel, s_before, s_after)
    # Return the aligned wave, the peak sample index and the associated mask
    # which is computed by counting the number of times each channel index
    # appears in IndList and then converting to a bool (so that channel i is
    # True if channel i features at least once).
    bc = np.bincount(ChArr)
    # convert to bool and force it to have the right type
    ChMask = np.zeros(n_ch, dtype=np.bool8)
    ChMask[:len(bc)] = bc.astype(np.bool8)
    
    return Wave, PeakSample, ChMask
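
interp_around_peak is also not reproduced in this listing. Judging from its call site (WavePlus is padded by one extra sample on each side, and the peak lands at index s_before+1), it plausibly finds the fractional peak on the peak channel and then cubic-resamples every channel on a grid shifted by that fraction, much as extract_wave_new does. A sketch under those assumptions; the body below is a guess, not the project's code:

import numpy as np
from scipy.interpolate import interp1d

def interp_around_peak(WavePlus, i_peak, ch_peak, s_before, s_after):
    # Fractional offset of the peak on the peak channel, from a parabola
    # fitted through the three samples around it.
    y = WavePlus[i_peak - 1:i_peak + 2, ch_peak]
    a, b, _ = np.polyfit([-1.0, 0.0, 1.0], y, 2)
    frac = np.clip(-b / (2.0 * a), -1.0, 1.0)
    # Resample all channels on a grid shifted by that fractional offset;
    # the one-sample padding keeps the shifted grid inside the block.
    old_s = np.arange(WavePlus.shape[0])
    new_s = np.arange(i_peak - s_before, i_peak + s_after) + frac
    f = interp1d(old_s, WavePlus, bounds_error=True, kind='cubic', axis=0)
    return f(new_s)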
Example #3
def extract_spikes(h5s, basename, DatFileNames, n_ch_dat,
                   ChannelsToUse, ChannelGraph,
                   max_spikes=None):
    # some global variables we use
    CHUNK_SIZE = Parameters['CHUNK_SIZE']
    CHUNKS_FOR_THRESH = Parameters['CHUNKS_FOR_THRESH']
    DTYPE = Parameters['DTYPE']
    CHUNK_OVERLAP = Parameters['CHUNK_OVERLAP']
    N_CH = Parameters['N_CH']
    S_JOIN_CC = Parameters['S_JOIN_CC']
    S_BEFORE = Parameters['S_BEFORE']
    S_AFTER = Parameters['S_AFTER']
    THRESH_SD = Parameters['THRESH_SD']
    THRESH_SD_LOWER = Parameters['THRESH_SD_LOWER']

    # filter coefficents for the high pass filtering
    filter_params = get_filter_params()
    print filter_params

    progress_bar = ProgressReporter()
    
    # A writer that outputs a high-pass filtered version of the raw data (.fil file)
    fil_writer = FilWriter(DatFileNames, n_ch_dat)

    # Just use first dat file for getting the thresholding data
    with open(DatFileNames[0], 'rb') as fd:
        # Use the first CHUNKS_FOR_THRESH chunks to estimate the threshold
        DatChunk = get_chunk_for_thresholding(fd, n_ch_dat, ChannelsToUse,
                                              num_samples(DatFileNames[0],
                                                          n_ch_dat))
        FilteredChunk = apply_filtering(filter_params, DatChunk)
        # get the STD of the beginning of the filtered data
        if Parameters['USE_HILBERT']:
            first_chunks_std = np.std(FilteredChunk)
            print 'first_chunks_std',  first_chunks_std, '\n'
        else:
            # 0.6745 converts the median absolute value into an SD estimate
            # for Gaussian noise
            if Parameters['USE_SINGLE_THRESHOLD']:
                ThresholdSDFactor = np.median(np.abs(FilteredChunk))/.6745
            else:
                ThresholdSDFactor = np.median(np.abs(FilteredChunk), axis=0)/.6745
            Threshold = ThresholdSDFactor*THRESH_SD
            print 'Threshold = ', Threshold, '\n' 
            Parameters['THRESHOLD'] = Threshold #Record the absolute Threshold used
            
        
    # set the high and low thresholds
    do_pickle = False
    if Parameters['USE_HILBERT']:
        ThresholdStrong = Parameters['THRESH_STRONG']
        ThresholdWeak = Parameters['THRESH_WEAK']
        do_pickle = True
    elif Parameters['USE_COMPONENT_ALIGNFLOATMASK']:  # to be used with a single threshold only
        ThresholdStrong = Threshold
        ThresholdWeak = ThresholdSDFactor*THRESH_SD_LOWER
        do_pickle = True

    if do_pickle:
        picklefile = open("threshold.p", "wb")
        pickle.dump([ThresholdStrong, ThresholdWeak], picklefile)
        threshold_outputstring = ('Threshold strong = ' + repr(ThresholdStrong) + '\n'
                                  + 'Threshold weak = ' + repr(ThresholdWeak))
        log_message(threshold_outputstring)
        
    n_samples = num_samples(DatFileNames, n_ch_dat)
    spike_count = 0
    for (DatChunk, s_start, s_end,
         keep_start, keep_end) in chunks(DatFileNames, n_ch_dat, ChannelsToUse):
        ############## FILTERING ########################################
        FilteredChunk = apply_filtering(filter_params, DatChunk)
        
        # write filtered output to file
        if Parameters['WRITE_FIL_FILE']:
            fil_writer.write(FilteredChunk, s_start, s_end, keep_start, keep_end)

        ############## THRESHOLDING #####################################
        
        
        # NEW: HILBERT TRANSFORM
        if Parameters['USE_HILBERT']:
            FilteredChunkHilbert = np.abs(signal.hilbert(FilteredChunk, axis=0) / first_chunks_std) ** 2
            BinaryChunkWeak = FilteredChunkHilbert > ThresholdWeak
            BinaryChunkStrong = FilteredChunkHilbert > ThresholdStrong
            BinaryChunkWeak = BinaryChunkWeak.astype(np.int8)
            BinaryChunkStrong = BinaryChunkStrong.astype(np.int8)
        else:  # usual method
            # (the chunk has already been filtered above; a redundant second
            # call to apply_filtering was removed from this branch)
            if Parameters['USE_COMPONENT_ALIGNFLOATMASK']:
                if Parameters['DETECT_POSITIVE']:
                    BinaryChunkWeak = FilteredChunk > ThresholdWeak
                    BinaryChunkStrong = FilteredChunk > ThresholdStrong
                else:
                    BinaryChunkWeak = FilteredChunk < -ThresholdWeak
                    BinaryChunkStrong = FilteredChunk < -ThresholdStrong
                BinaryChunkWeak = BinaryChunkWeak.astype(np.int8)
                BinaryChunkStrong = BinaryChunkStrong.astype(np.int8)
            else:
                if Parameters['DETECT_POSITIVE']:
                    BinaryChunk = np.abs(FilteredChunk)>Threshold
                else:
                    BinaryChunk = (FilteredChunk<-Threshold)
                BinaryChunk = BinaryChunk.astype(np.int8)

        ############### FLOOD FILL  ######################################
        ChannelGraphToUse = complete_if_none(ChannelGraph, N_CH)
        if (Parameters['USE_HILBERT'] or Parameters['USE_COMPONENT_ALIGNFLOATMASK']):
            if Parameters['USE_OLD_CC_CODE']:
                IndListsChunkOld = connected_components(BinaryChunkWeak,
                            ChannelGraphToUse, S_JOIN_CC)
                # Final list of connected components: keep a 'weak' component
                # only if some of its samples also exceed the strong threshold.
                # This method works better than connected_components_twothresholds.
                IndListsChunk = []
                for IndListWeak in IndListsChunkOld:
                    i, j = np.array(IndListWeak).transpose()
                    if np.sum(BinaryChunkStrong[i, j]) != 0:
                        IndListsChunk.append(IndListWeak)
            else:
                IndListsChunk = connected_components_twothresholds(BinaryChunkWeak, BinaryChunkStrong,
                            ChannelGraphToUse, S_JOIN_CC)
            BinaryChunk = 1 * BinaryChunkWeak + 1 * BinaryChunkStrong
        else:
            IndListsChunk = connected_components(BinaryChunk,
                            ChannelGraphToUse, S_JOIN_CC)
            
        
        if Parameters['DEBUG']:  # TODO: change plot_diagnostics for the HILBERT case
            if Parameters['USE_HILBERT']:
                plot_diagnostics_twothresholds(s_start, IndListsChunk,
                        BinaryChunkWeak, BinaryChunkStrong, BinaryChunk,
                        DatChunk, FilteredChunk, FilteredChunkHilbert,
                        ThresholdStrong, ThresholdWeak)
            elif Parameters['USE_COMPONENT_ALIGNFLOATMASK']:
                # TODO: rename the Hilbert argument of plot_diagnostics_twothresholds
                plot_diagnostics_twothresholds(s_start, IndListsChunk,
                        BinaryChunkWeak, BinaryChunkStrong, BinaryChunk,
                        DatChunk, FilteredChunk, -FilteredChunk,
                        ThresholdStrong, ThresholdWeak)
            else:
                plot_diagnostics(s_start, IndListsChunk, BinaryChunk,
                        DatChunk, FilteredChunk, Threshold)
        if Parameters['WRITE_BINFIL_FILE']:
            fil_writer.write_bin(BinaryChunk, s_start, s_end, keep_start, keep_end)
        
        ############## ALIGN AND INTERPOLATE WAVES #######################
        nextbits = []
        if Parameters['USE_HILBERT']:
            
            for IndList in IndListsChunk:
                try:
                    wave, s_peak, sf_peak, cm, fcm = extract_wave_hilbert_new(IndList, FilteredChunk,
                                                    FilteredChunkHilbert,
                                                    S_BEFORE, S_AFTER, N_CH,
                                                    s_start, ThresholdStrong, ThresholdWeak)
                    s_offset = s_start + s_peak
                    sf_offset = s_start + sf_peak
                    if keep_start<=s_offset<keep_end:
                        spike_count += 1
                        nextbits.append((wave, s_offset, sf_offset, cm, fcm))
                except np.linalg.LinAlgError:
                    s = '*** WARNING *** Unalignable spike discarded in chunk {chunk}.'.format(
                            chunk=(s_start, s_end))
                    log_warning(s)
                except InterpolationError:
                    s = '*** WARNING *** Interpolation error in chunk {chunk}.'.format(
                            chunk=(s_start, s_end))
                    log_warning(s)
            # and return them in time sorted order
            nextbits.sort(key=lambda (wave, s, s_frac, cm, fcm): s_frac)
            for wave, s, s_frac, cm, fcm in nextbits:
                uwave = get_padded(DatChunk, int(s)-S_BEFORE-s_start,
                                   int(s)+S_AFTER-s_start).astype(np.int32)
                # cm = add_penumbra(cm, ChannelGraphToUse,
                                  # Parameters['PENUMBRA_SIZE'])
                # fcm = get_float_mask(wave, cm, ChannelGraphToUse,
                                     # 1.)
                yield uwave, wave, s, s_frac, cm, fcm
                # unfiltered wave,wave, s_peak, ChMask, FloatChMask
        elif Parameters['USE_COMPONENT_ALIGNFLOATMASK']:
            for IndList in IndListsChunk:
                try:
                    if Parameters['DETECT_POSITIVE']:
                        wave, s_peak, sf_peak, cm, fcm, comp_normalised, comp_normalised_power = extract_wave_twothresholds(IndList, FilteredChunk,
                                                    FilteredChunk,
                                                    S_BEFORE, S_AFTER, N_CH,
                                                    s_start, ThresholdStrong, ThresholdWeak) 
                    else:
                        wave, s_peak, sf_peak, cm, fcm,comp_normalised, comp_normalised_power = extract_wave_twothresholds(IndList, FilteredChunk,
                                                    -FilteredChunk,
                                                    S_BEFORE, S_AFTER, N_CH,
                                                    s_start, ThresholdStrong, ThresholdWeak)
                    s_offset = s_start+s_peak
                    sf_offset = s_start + sf_peak
                    if keep_start<=s_offset<keep_end:
                        spike_count += 1
                        nextbits.append((wave, s_offset, sf_offset, cm, fcm))
                except np.linalg.LinAlgError:
                    s = '*** WARNING *** Unalignable spike discarded in chunk {chunk}.'.format(
                            chunk=(s_start, s_end))
                    log_warning(s)
                except InterpolationError:
                    s = '*** WARNING *** Interpolation error in chunk {chunk}.'.format(
                            chunk=(s_start, s_end))
                    log_warning(s)
            # and return them in time sorted order
            nextbits.sort(key=lambda (wave, s, s_frac, cm, fcm): s_frac)
            for wave, s, s_frac, cm, fcm in nextbits:
                uwave = get_padded(DatChunk, int(s)-S_BEFORE-s_start,
                                   int(s)+S_AFTER-s_start).astype(np.int32)
                # cm = add_penumbra(cm, ChannelGraphToUse,
                                  # Parameters['PENUMBRA_SIZE'])
                # fcm = get_float_mask(wave, cm, ChannelGraphToUse,
                                     # 1.)
                yield uwave, wave, s, s_frac, cm, fcm   
                # unfiltered wave,wave, s_peak, ChMask, FloatChMask
        else:  # original SpikeDetekt; this code duplication is regrettable but probably easier to deal with
            
            for IndList in IndListsChunk:
                try:
                    wave, s_peak, sf_peak, cm = extract_wave(IndList, FilteredChunk,
                                                    S_BEFORE, S_AFTER, N_CH,
                                                    s_start,Threshold)
                    s_offset = s_start+s_peak
                    sf_offset = s_start + sf_peak
                    if keep_start<=s_offset<keep_end:
                        spike_count += 1
                        nextbits.append((wave, s_offset, sf_offset, cm))
                except np.linalg.LinAlgError:
                    s = '*** WARNING *** Unalignable spike discarded in chunk {chunk}.'.format(
                            chunk=(s_start, s_end))
                    log_warning(s)
            # and return them in time sorted order
            nextbits.sort(key=lambda (wave, s, s_frac, cm): s_frac)
            for wave, s, s_frac, cm in nextbits:
                uwave = get_padded(DatChunk, int(s)-S_BEFORE-s_start,
                                   int(s)+S_AFTER-s_start).astype(np.int32)
                cm = add_penumbra(cm, ChannelGraphToUse,
                                  Parameters['PENUMBRA_SIZE'])
                fcm = get_float_mask(wave, cm, ChannelGraphToUse,
                                     ThresholdSDFactor)
                yield uwave, wave, s, s_frac, cm, fcm    
                # unfiltered wave,wave, s_peak, ChMask, FloatChMask

        progress_bar.update(float(s_end)/n_samples,
            '%d/%d samples, %d spikes found'%(s_end, n_samples, spike_count))
        if max_spikes is not None and spike_count>=max_spikes:
            break
    
    progress_bar.finish()
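
Two pieces of the thresholding logic above deserve a standalone illustration: the squared, normalised Hilbert envelope used when USE_HILBERT is set, and the USE_OLD_CC_CODE filter that keeps a weak-threshold component only if it contains at least one strong crossing. A sketch with made-up sizes, thresholds and components:

import numpy as np
from scipy import signal

# Squared, normalised Hilbert envelope, as in the USE_HILBERT branch
rng = np.random.RandomState(0)
FilteredChunk = rng.randn(64, 4).astype(np.float32)
first_chunks_std = np.std(FilteredChunk)
FilteredChunkHilbert = np.abs(signal.hilbert(FilteredChunk, axis=0)
                              / first_chunks_std) ** 2

# Weak components survive only if some (sample, channel) entry is strong;
# the strong crossings here are placed by hand so the outcome is obvious.
BinaryChunkStrong = np.zeros((64, 4), dtype=np.int8)
BinaryChunkStrong[3, 1] = 1
IndListsChunkOld = [
    [(2, 1), (3, 1), (4, 1)],   # contains the strong crossing -> kept
    [(10, 2), (11, 2)],         # weak only -> dropped
]
IndListsChunk = []
for IndListWeak in IndListsChunkOld:
    i, j = np.array(IndListWeak).transpose()
    if np.sum(BinaryChunkStrong[i, j]) != 0:
        IndListsChunk.append(IndListWeak)
print(len(IndListsChunk))       # 1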
Example #4
def extract_wave_hilbert_old(IndList, FilteredArr, FilteredHilbertArr, s_before, 
                     s_after, n_ch, s_start, ThresholdStrong, ThresholdWeak):
    IndArr = np.array(IndList, dtype=np.int32)
    SampArr = IndArr[:, 0]
    ChArr = IndArr[:, 1]
    n_ch = FilteredArr.shape[1]
    log_fd = GlobalVariables['log_fd']
    if np.amax(SampArr)-np.amin(SampArr)>Parameters['CHUNK_OVERLAP']/2:
        s = '''
        ************ ERROR **********************************************
        Connected component found with width larger than CHUNK_OVERLAP/2.
        Spikes could be repeatedly detected, increase the size of
        CHUNK_OVERLAP and re-run.
        Component sample range: {sample_range}
        *****************************************************************
        '''.format(sample_range=(s_start+np.amin(SampArr),
                                 s_start+np.amax(SampArr)))
        log_warning(s, multiline=True)
        #exit()

    bc = np.bincount(ChArr)
    # convert to bool and force it to have the right type
    ChMask = np.zeros(n_ch, dtype=np.bool8)
    ChMask[:len(bc)] = bc.astype(np.bool8)
    n_unmasked_ch = np.sum(ChMask)
    
    # Find peak sample:
    # 1. upsample channels we're using on thresholded range
    # 2. find weighted mean peak sample
    SampArrMin, SampArrMax = np.amin(SampArr)-3, np.amax(SampArr)+4
    # ChArrMin, ChArrMax = np.amin(ChArr), np.amax(ChArr)
    
    
    WavePlus = get_padded(FilteredArr, SampArrMin, SampArrMax)
    WavePlus = WavePlus[:, ChMask]
    
    # upsample WavePlus
    upsampling_factor = Parameters['UPSAMPLING_FACTOR']
    if upsampling_factor>1:
        old_s = np.arange(WavePlus.shape[0])
        new_s_i = np.arange((WavePlus.shape[0]-1)*upsampling_factor+1)
        new_s = np.array(new_s_i*(1.0/upsampling_factor), dtype=np.float32)
        f = interp1d(old_s, WavePlus, bounds_error=True, kind='cubic', axis=0)
        UpsampledWavePlus = f(new_s)
    else:
        UpsampledWavePlus = WavePlus
        
    # find weighted mean peak for each channel above threshold
    if Parameters['USE_WEIGHTED_MEAN_PEAK_SAMPLE']:
        peak_sum = 0.0
        total_weight = 0.0
        for ch in xrange(WavePlus.shape[1]):
            X = UpsampledWavePlus[:, ch]
            if Parameters['DETECT_POSITIVE']:
                X = -np.abs(X)
            i_intpeak = np.argmin(X)
            left, right = i_intpeak-1, i_intpeak+2
            if right>len(X):
                left, right = left+len(X)-right, len(X)
            elif left<0:
                left, right = 0, right-left
            a_b_c = abc(np.arange(left, right, dtype=np.float32),
                        X[left:right])
            s_fracpeak = max_t(a_b_c)
            weight = -X[i_intpeak]
            if weight<0:
                weight = 0
            peak_sum += s_fracpeak*weight
            total_weight += weight
        s_fracpeak = (peak_sum/total_weight)
    else:
        if Parameters['DETECT_POSITIVE']:
            X = -np.abs(UpsampledWavePlus)
        else:
            X = UpsampledWavePlus
        s_fracpeak = 1.0*np.argmin(np.amin(X, axis=1))
        
    # s_fracpeak currently in coords of UpsampledWavePlus
    s_fracpeak = s_fracpeak/upsampling_factor
    # s_fracpeak now in coordinates of WavePlus
    s_fracpeak += SampArrMin
    # s_fracpeak now in coordinates of FilteredArr
    
    
    
    #################################
    # NEW: FLOAT MASK
    #################################
    # connected component as window in chunk with Hilbert
    # contains values only on weak threshold-exceeding points, 
    # zeros everywhere else
    comp = np.zeros((SampArrMax - SampArrMin, n_ch), dtype=FilteredHilbertArr.dtype)
    comp[SampArr - SampArrMin, ChArr] = FilteredHilbertArr[SampArr, ChArr]
    # 1D array: for each channel, the peak of the Hilbert, relative to the
    # start of the chunk
    peaks = np.argmax(comp, axis=0) + SampArrMin
    # 1D array: values of the peaks, on each channel
    peaks_values = FilteredHilbertArr[peaks, np.arange(0, n_ch)] * ChMask
    FloatChMask = np.clip((peaks_values - ThresholdWeak) / (ThresholdStrong - ThresholdWeak), 0, 1)
    
    
    
    # #################################
    # # New alignment
    # #################################
    # # In the window of the chunk (connected component), we take the clipped Hilbert 
    # # (masks between 0 and 1).
    # comp_clipped = np.clip((comp - ThresholdWeak) / (ThresholdStrong - ThresholdWeak), 0, 1)
    # # now we take the weighted average of the sample times in the component
    # s_fracpeak = np.sum(comp_clipped * np.arange(SampArrMax - SampArrMin).reshape((-1, 1))) / np.sum(comp_clipped)
    # s_fracpeak += SampArrMin
    
    
    #################################
    # Realign spike with respect to s_fracpeak
    #################################
    # get block of given size around peaksample
    try:
        s_peak = int(s_fracpeak)
    except ValueError:
        # This is a bit of a hack. Essentially, the problem here is that
        # s_fracpeak is a nan because the interpolation didn't work, and
        # therefore we want to skip the spike. There's already code in
        # core.extract_spikes that does this if a LinAlgError is raised,
        # so we just use that to skip this spike (and write a message to the
        # log).
        raise np.linalg.LinAlgError 
    WaveBlock = get_padded(FilteredArr,
                           s_peak-s_before-1, s_peak+s_after+2)
    # Perform interpolation around the fractional peak
    old_s = np.arange(s_peak-s_before-1, s_peak+s_after+2)
    new_s = np.arange(s_peak-s_before, s_peak+s_after)+(s_fracpeak-s_peak)
    try:
        f = interp1d(old_s, WaveBlock, bounds_error=True, kind='cubic', axis=0)
    except ValueError: 
        #  File "/usr/lib/python2.7/dist-packages/scipy/interpolate/interpolate.py", line 509, in _dot0
        #  return dot(a, b)
        #ValueError: matrices are not aligned
        raise InterpolationError
    Wave = f(new_s)
    
    
    
    return Wave, s_peak, s_fracpeak, ChMask, FloatChMask
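
The float mask block maps each channel's Hilbert peak linearly onto [0, 1] between the weak and strong thresholds: peaks below the weak threshold give 0, peaks at or above the strong threshold give 1. A numeric sketch with arbitrary threshold values:

import numpy as np

ThresholdWeak, ThresholdStrong = 2.0, 5.0
peaks_values = np.array([1.0, 3.5, 6.0, 0.0])   # per-channel Hilbert peaks

FloatChMask = np.clip((peaks_values - ThresholdWeak)
                      / (ThresholdStrong - ThresholdWeak), 0, 1)
print(FloatChMask)   # [ 0.   0.5  1.   0. ]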
Example #5
def extract_wave_hilbert_new(IndList, FilteredArr, FilteredHilbertArr, s_before, 
                     s_after, n_ch, s_start, ThresholdStrong, ThresholdWeak):
    IndArr = np.array(IndList, dtype=np.int32)
    SampArr = IndArr[:, 0]
    ChArr = IndArr[:, 1]
    n_ch = FilteredArr.shape[1]
    log_fd = GlobalVariables['log_fd']
    if np.amax(SampArr)-np.amin(SampArr)>Parameters['CHUNK_OVERLAP']/2:
        s = '''
        ************ ERROR **********************************************
        Connected component found with width larger than CHUNK_OVERLAP/2.
        Spikes could be repeatedly detected, increase the size of
        CHUNK_OVERLAP and re-run.
        Component sample range: {sample_range}
        *****************************************************************
        '''.format(sample_range=(s_start+np.amin(SampArr),
                                 s_start+np.amax(SampArr)))
        log_warning(s, multiline=True)
        #exit()

    bc = np.bincount(ChArr)
    # convert to bool and force it to have the right type
    ChMask = np.zeros(n_ch, dtype=np.bool8)
    ChMask[:len(bc)] = bc.astype(np.bool8)
    n_unmasked_ch = np.sum(ChMask)
    
    # Find peak sample:
    # 1. upsample channels we're using on thresholded range
    # 2. find weighted mean peak sample
    SampArrMin, SampArrMax = np.amin(SampArr)-3, np.amax(SampArr)+4
    # ChArrMin, ChArrMax = np.amin(ChArr), np.amax(ChArr)
    
    
    #################################
    # NEW: FLOAT MASK
    #################################
    # connected component as window in chunk with Hilbert
    # contains values only on weak threshold-exceeding points, 
    # zeros everywhere else
    comp = np.zeros((SampArrMax - SampArrMin, n_ch), dtype=FilteredHilbertArr.dtype)
    comp[SampArr - SampArrMin, ChArr] = FilteredHilbertArr[SampArr, ChArr]
    # 1D array: for each channel, the peak of the Hilbert, relative to the
    # start of the chunk
    peaks = np.argmax(comp, axis=0) + SampArrMin
    # 1D array: values of the peaks, on each channel
    peaks_values = FilteredHilbertArr[peaks, np.arange(0, n_ch)] * ChMask
    FloatChMask = np.clip((peaks_values - ThresholdWeak) / (ThresholdStrong - ThresholdWeak), 0, 1)
    
    
    #################################
    # New alignment
    #################################
    # In the window of the chunk (the connected component), normalise the
    # Hilbert values between the weak and strong thresholds.
    # comp_clipped (values clipped to [0, 1]) is computed but not used below:
    # clipping can make things worse because you lose the peaks.
    comp_clipped = np.clip((comp - ThresholdWeak) / (ThresholdStrong - ThresholdWeak), 0, 1)
    comp_normalised = (comp - ThresholdWeak) / (ThresholdStrong - ThresholdWeak)
    
    # now we take the weighted average of the sample times in the component
    s_fracpeak = np.sum(comp_normalised * np.arange(SampArrMax - SampArrMin).reshape((-1, 1))) / np.sum(comp_normalised)
    s_fracpeak += SampArrMin
    
    
    #################################
    # Realign spike with respect to s_fracpeak
    #################################
    # get block of given size around peaksample
    try:
        s_peak = int(s_fracpeak)
    except ValueError:
        # This is a bit of a hack. Essentially, the problem here is that
        # s_fracpeak is a nan because the interpolation didn't work, and
        # therefore we want to skip the spike. There's already code in
        # core.extract_spikes that does this if a LinAlgError is raised,
        # so we just use that to skip this spike (and write a message to the
        # log).
        raise np.linalg.LinAlgError 
    WaveBlock = get_padded(FilteredArr,
                           s_peak-s_before-1, s_peak+s_after+2)
    # Perform interpolation around the fractional peak
    old_s = np.arange(s_peak-s_before-1, s_peak+s_after+2)
    new_s = np.arange(s_peak-s_before, s_peak+s_after)+(s_fracpeak-s_peak)
    try:
        f = interp1d(old_s, WaveBlock, bounds_error=True, kind='cubic', axis=0)
    except ValueError: 
        #  File "/usr/lib/python2.7/dist-packages/scipy/interpolate/interpolate.py", line 509, in _dot0
        #  return dot(a, b)
        #ValueError: matrices are not aligned
        raise InterpolationError
    Wave = f(new_s)
    
    return Wave, s_peak, s_fracpeak, ChMask, FloatChMask
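
The 'new alignment' above is a centre of mass: every entry of the normalised component weights its own sample index, and the weighted average gives the fractional peak time. A toy computation on a made-up 3-sample, 2-channel component:

import numpy as np

comp_normalised = np.array([[0.0, 0.2],
                            [0.0, 1.0],
                            [0.0, 0.8]])   # (samples, channels)
SampArrMin = 100                           # window start in chunk coordinates

times = np.arange(comp_normalised.shape[0]).reshape((-1, 1))
s_fracpeak = np.sum(comp_normalised * times) / np.sum(comp_normalised)
s_fracpeak += SampArrMin
print(s_fracpeak)   # 101.3 = 100 + (0*0.2 + 1*1.0 + 2*0.8) / 2.0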
Example #6
def extract_spikes(h5s, basename, DatFileNames, n_ch_dat,
                   ChannelsToUse, ChannelGraph,
                   max_spikes=None):
    # some global variables we use
    CHUNK_SIZE = Parameters['CHUNK_SIZE']
    CHUNKS_FOR_THRESH = Parameters['CHUNKS_FOR_THRESH']
    DTYPE = Parameters['DTYPE']
    CHUNK_OVERLAP = Parameters['CHUNK_OVERLAP']
    N_CH = Parameters['N_CH']
    S_JOIN_CC = Parameters['S_JOIN_CC']
    S_BEFORE = Parameters['S_BEFORE']
    S_AFTER = Parameters['S_AFTER']
    THRESH_SD = Parameters['THRESH_SD']

    # filter coefficents for the high pass filtering
    filter_params = get_filter_params()

    progress_bar = ProgressReporter()

    # A writer that outputs a high-pass filtered version of the raw data
    # (.fil file)
    fil_writer = FilWriter(DatFileNames, n_ch_dat)

    # Just use first dat file for getting the thresholding data
    with open(DatFileNames[0], 'rb') as fd:
        # Use the first CHUNKS_FOR_THRESH chunks to estimate the threshold
        DatChunk = get_chunk_for_thresholding(fd, n_ch_dat, ChannelsToUse,
                                              num_samples(DatFileNames[0],
                                                          n_ch_dat))
        FilteredChunk = apply_filtering(filter_params, DatChunk)
        # 0.6745 converts the median absolute value into an SD estimate (Gaussian noise)
        if Parameters['USE_SINGLE_THRESHOLD']:
            ThresholdSDFactor = np.median(np.abs(FilteredChunk)) / .6745
        else:
            ThresholdSDFactor = np.median(
                np.abs(FilteredChunk),
                axis=0) / .6745
        Threshold = ThresholdSDFactor * THRESH_SD

        print 'Threshold = ', Threshold, '\n'
        # Record the absolute Threshold used
        Parameters['THRESHOLD'] = Threshold

    n_samples = num_samples(DatFileNames, n_ch_dat)

    spike_count = 0
    for (DatChunk, s_start, s_end,
         keep_start, keep_end) in chunks(DatFileNames, n_ch_dat, ChannelsToUse):
        ############## FILTERING ########################################
        FilteredChunk = apply_filtering(filter_params, DatChunk)

        # write filtered output to file (done unconditionally in this version,
        # without checking Parameters['WRITE_FIL_FILE'])
        fil_writer.write(FilteredChunk, s_start, s_end, keep_start, keep_end)

        ############## THRESHOLDING #####################################
        if Parameters['DETECT_POSITIVE']:
            BinaryChunk = np.abs(FilteredChunk) > Threshold
        else:
            BinaryChunk = (FilteredChunk < -Threshold)
        BinaryChunk = BinaryChunk.astype(np.int8)
        # write binary chunk filtered output to file
        if Parameters['WRITE_BINFIL_FILE']:
            fil_writer.write_bin(
                BinaryChunk,
                s_start,
                s_end,
                keep_start,
                keep_end)
        ############### FLOOD FILL  ######################################
        ChannelGraphToUse = complete_if_none(ChannelGraph, N_CH)
        IndListsChunk = connected_components(BinaryChunk,
                                             ChannelGraphToUse, S_JOIN_CC)
        if Parameters['DEBUG']:
            plot_diagnostics(
                s_start,
                IndListsChunk,
                BinaryChunk,
                DatChunk,
                FilteredChunk,
                Threshold)
            fil_writer.write_bin(
                BinaryChunk,
                s_start,
                s_end,
                keep_start,
                keep_end)

        ############## ALIGN AND INTERPOLATE WAVES #######################
        nextbits = []
        for IndList in IndListsChunk:
            try:
                wave, s_peak, cm = extract_wave(IndList, FilteredChunk,
                                                S_BEFORE, S_AFTER, N_CH,
                                                s_start, Threshold)
                s_offset = s_start + s_peak
                if keep_start <= s_offset < keep_end:
                    spike_count += 1
                    nextbits.append((wave, s_offset, cm))
            except np.linalg.LinAlgError:
                s = '*** WARNING *** Unalignable spike discarded in chunk {chunk}.'.format(
                    chunk=(s_start, s_end))
                log_warning(s)
        # and return them in time sorted order
        nextbits.sort(key=lambda wave_s_cm: wave_s_cm[1])
        for wave, s, cm in nextbits:
            uwave = get_padded(DatChunk, int(s) - S_BEFORE - s_start,
                               int(s) + S_AFTER - s_start).astype(np.int32)
            cm = add_penumbra(cm, ChannelGraphToUse,
                              Parameters['PENUMBRA_SIZE'])
            fcm = get_float_mask(wave, cm, ChannelGraphToUse,
                                 ThresholdSDFactor)
            yield uwave, wave, s, cm, fcm
        progress_bar.update(float(s_end) / n_samples,
                            '%d/%d samples, %d spikes found' % (s_end, n_samples, spike_count))
        if max_spikes is not None and spike_count >= max_spikes:
            break

    progress_bar.finish()
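
Both versions of extract_spikes derive the detection threshold from a robust noise estimate: for Gaussian noise, the median absolute value divided by 0.6745 approximates the standard deviation, and unlike a plain SD it is barely biased by the spikes themselves. A quick standalone check (the THRESH_SD value is made up here):

import numpy as np

rng = np.random.RandomState(0)
noise = rng.randn(100000)                    # unit-variance Gaussian noise

ThresholdSDFactor = np.median(np.abs(noise)) / .6745
print(ThresholdSDFactor)                     # ~1.0

THRESH_SD = 4.5                              # illustrative, not from this listing
Threshold = ThresholdSDFactor * THRESH_SD    # ~4.5 in signal units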