Example #1
def spike_detection_from_raw_data(basename, DatFileNames, n_ch_dat, Channels_dat,
                                  ChannelGraph, probe, max_spikes):
    """
    Filter, detect, extract from raw data.
    """
    ### Detect spikes. For each detected spike, send it to spike writer, which
    ### writes it to a spk file. List of times is small (memorywise) so we just
    ### store the list and write it later.

    np.savetxt("dat_channels.txt", Channels_dat, fmt="%i")
    
    # Create HDF5 files
    h5s = {}
    h5s_filenames = {}
    for n in ['main', 'waves']:
        filename = basename+'.'+n+'.h5'
        h5s[n] = tables.openFile(filename, 'w')
        h5s_filenames[n] = filename
    for n in ['raw', 'high', 'low']:
        if Parameters['RECORD_'+n.upper()]:
            filename = basename+'.'+n+'.h5'
            h5s[n] = tables.openFile(filename, 'w')
            h5s_filenames[n] = filename
    main_h5 = h5s['main']
    # Shanks groups
    shanks_group = {}
    shank_group = {}
    shank_table = {}
    for k in ['main', 'waves']:
        h5 = h5s[k]
        shanks_group[k] = h5.createGroup('/', 'shanks')
        for i in probe.shanks_set:
            shank_group[k, i] = h5.createGroup(shanks_group[k], 'shank_'+str(i))
    # waveform data for wave file
    for i in probe.shanks_set:
        shank_table['waveforms', i] = h5s['waves'].createTable(
            shank_group['waves', i], 'waveforms',
            waveform_description(len(probe.channel_set[i])))
    # spikedetekt data for main file, and links to waveforms
    for i in probe.shanks_set:
        shank_table['spikedetekt', i] = main_h5.createTable(shank_group['main', i],
            'spikedetekt', shank_description(len(probe.channel_set[i])))
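        # link each main-file shank group to the waveform table stored in the waves file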
        main_h5.createExternalLink(shank_group['main', i], 'waveforms', 
                                   shank_table['waveforms', i])
    # Metadata
    n_samples = np.array([num_samples(DatFileName, n_ch_dat) for DatFileName in DatFileNames])
    for k, h5 in h5s.items():
        metadata_group = h5.createGroup('/', 'metadata')
        parameters_group = h5.createGroup(metadata_group, 'parameters')
        for pname, v in Parameters.items():  # avoid shadowing the outer loop variable k
            if not pname.startswith('_'):
                if isinstance(v, bool):
                    r = int(v)
                elif isinstance(v, (int, float)):
                    r = v
                else:
                    r = repr(v)
                h5.setNodeAttr(parameters_group, pname, r)
        h5.setNodeAttr(metadata_group, 'probe', json.dumps(probe.probes))
        h5.createArray(metadata_group, 'datfiles_offsets_samples',
                       np.hstack((0, np.cumsum(n_samples)))[:-1])
    
    ########## MAIN TIME CONSUMING LOOP OF PROGRAM ########################
    for (USpk, Spk, PeakSample, FracPeakSample,
         ChannelMask, FloatChannelMask) in extract_spikes(h5s, basename,
                                                          DatFileNames,
                                                          n_ch_dat,
                                                          Channels_dat,
                                                          ChannelGraph,
                                                          max_spikes,
                                                          ):
        # what shank are we in?
        nzc, = ChannelMask.nonzero()
        internzc = list(set(nzc).intersection(probe.channel_to_shank.keys()))
        if internzc:
            shank = probe.channel_to_shank[internzc[0]]
        else:
            continue
        # write only the channels of this shank
        channel_list = np.array(sorted(list(probe.channel_set[shank])))
        t = shank_table['spikedetekt', shank]
        t.row['time'] = PeakSample
        t.row['float_time'] = FracPeakSample
        t.row['mask_binary'] = ChannelMask[channel_list]
        t.row['mask_float'] = FloatChannelMask[channel_list]
        t.row.append()
        # and the waveforms
        t = shank_table['waveforms', shank]
        t.row['wave'] = Spk[:, channel_list]
        t.row['unfiltered_wave'] = USpk[:, channel_list]
        t.row.append()
        
    for h5 in h5s.values():
        h5.flush()

    # Feature extraction
    for shank in probe.shanks_set:
        X = shank_table['waveforms', shank].cols.wave[:Parameters['PCA_MAXWAVES']]
        if len(X) == 0:
            continue
        PC_3s = reget_features(X)
        for sd_row, w_row in izip(shank_table['spikedetekt', shank],
                                  shank_table['waveforms', shank]):
            f = project_features(PC_3s, w_row['wave'])

            # store the PCA components used for this shank alongside each spike
            sd_row['PC_3s'] = PC_3s.flatten()

            # feature vector = flattened PCA projections with the spike time appended
            sd_row['features'] = np.hstack((f.flatten(), sd_row['time']))
            sd_row.update()
            
    main_h5.flush()
            
    klusters_files(h5s, shank_table, basename, probe)


    for key, h5 in h5s.iteritems():
        h5.close()
        if not Parameters['KEEP_OLD_HDF5_FILES']:
            # erase the HDF5 files at the end, because a direct conversion
            # tool in KlustaViewa is used for now
            os.remove(h5s_filenames[key])
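
A minimal sketch (not part of the examples above) of how the HDF5 files written by spike_detection_from_raw_data could be read back. The basename mydata and shank 0 are placeholders; the calls use the same PyTables 2.x API (openFile, getNode) as the example itself, and the group, table and column names match what the code above creates.

import tables

main_h5 = tables.openFile('mydata.main.h5', 'r')    # hypothetical basename
waves_h5 = tables.openFile('mydata.waves.h5', 'r')

# per-shank tables created by the example above
spikes = main_h5.getNode('/shanks/shank_0/spikedetekt')
waves = waves_h5.getNode('/shanks/shank_0/waveforms')

print spikes.nrows, 'spikes on shank 0'
for sd_row, w_row in zip(spikes.iterrows(), waves.iterrows()):
    peak_sample = sd_row['time']      # integer peak sample
    features = sd_row['features']     # PCA projections plus spike time
    waveform = w_row['wave']          # filtered waveform, channels of this shank only

main_h5.close()
waves_h5.close()
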
Example #2
def extract_spikes(h5s, basename, DatFileNames, n_ch_dat,
                   ChannelsToUse, ChannelGraph,
                   max_spikes=None):
    # some global variables we use
    CHUNK_SIZE = Parameters['CHUNK_SIZE']
    CHUNKS_FOR_THRESH = Parameters['CHUNKS_FOR_THRESH']
    DTYPE = Parameters['DTYPE']
    CHUNK_OVERLAP = Parameters['CHUNK_OVERLAP']
    N_CH = Parameters['N_CH']
    S_JOIN_CC = Parameters['S_JOIN_CC']
    S_BEFORE = Parameters['S_BEFORE']
    S_AFTER = Parameters['S_AFTER']
    THRESH_SD = Parameters['THRESH_SD']
    THRESH_SD_LOWER = Parameters['THRESH_SD_LOWER']

    # filter coefficents for the high pass filtering
    filter_params = get_filter_params()
    print filter_params

    progress_bar = ProgressReporter()
    
    # A writer that outputs a high-pass filtered version of the raw data (.fil file)
    fil_writer = FilWriter(DatFileNames, n_ch_dat)

    # Just use first dat file for getting the thresholding data
    with open(DatFileNames[0], 'rb') as fd:
        # Use the first few chunks (CHUNKS_FOR_THRESH) to estimate the threshold
        DatChunk = get_chunk_for_thresholding(fd, n_ch_dat, ChannelsToUse,
                                              num_samples(DatFileNames[0],
                                                          n_ch_dat))
        FilteredChunk = apply_filtering(filter_params, DatChunk)
        # get the STD of the beginning of the filtered data
        if Parameters['USE_HILBERT']:
            first_chunks_std = np.std(FilteredChunk)
            print 'first_chunks_std',  first_chunks_std, '\n'
        else:
            # dividing the median absolute value by 0.6745 gives a robust
            # estimate of the noise standard deviation (cf. Example #3)
            if Parameters['USE_SINGLE_THRESHOLD']:
                ThresholdSDFactor = np.median(np.abs(FilteredChunk))/.6745
            else:
                ThresholdSDFactor = np.median(np.abs(FilteredChunk), axis=0)/.6745
            Threshold = ThresholdSDFactor*THRESH_SD
            print 'Threshold = ', Threshold, '\n'
            Parameters['THRESHOLD'] = Threshold  # record the absolute threshold used
            
        
    # set the high and low thresholds
    do_pickle = False
    if Parameters['USE_HILBERT']:
        ThresholdStrong = Parameters['THRESH_STRONG']
        ThresholdWeak = Parameters['THRESH_WEAK']
        do_pickle = True
    elif Parameters['USE_COMPONENT_ALIGNFLOATMASK']:  # to be used with a single threshold only
        ThresholdStrong = Threshold
        ThresholdWeak = ThresholdSDFactor*THRESH_SD_LOWER
        do_pickle = True

    if do_pickle:
        with open("threshold.p", "wb") as picklefile:
            pickle.dump([ThresholdStrong, ThresholdWeak], picklefile)
        threshold_outputstring = ('Threshold strong = ' + repr(ThresholdStrong) +
                                  '\nThreshold weak = ' + repr(ThresholdWeak))
        log_message(threshold_outputstring)
        
    n_samples = num_samples(DatFileNames, n_ch_dat)
    spike_count = 0
    for (DatChunk, s_start, s_end,
         keep_start, keep_end) in chunks(DatFileNames, n_ch_dat, ChannelsToUse):
        ############## FILTERING ########################################
        FilteredChunk = apply_filtering(filter_params, DatChunk)
        
        # write filtered output to file
        if Parameters['WRITE_FIL_FILE']:
            fil_writer.write(FilteredChunk, s_start, s_end, keep_start, keep_end)

        ############## THRESHOLDING #####################################
        
        
        # NEW: HILBERT TRANSFORM
        if Parameters['USE_HILBERT']:
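            # squared Hilbert envelope of the filtered data (magnitude of the analytic
            # signal), normalised by the noise SD estimated from the thresholding chunk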
            FilteredChunkHilbert = np.abs(signal.hilbert(FilteredChunk, axis=0) / first_chunks_std) ** 2
            BinaryChunkWeak = FilteredChunkHilbert > ThresholdWeak
            BinaryChunkStrong = FilteredChunkHilbert > ThresholdStrong
            BinaryChunkWeak = BinaryChunkWeak.astype(np.int8)
            BinaryChunkStrong = BinaryChunkStrong.astype(np.int8)
        else:  # usual (non-Hilbert) method
            if Parameters['USE_COMPONENT_ALIGNFLOATMASK']:
                if Parameters['DETECT_POSITIVE']:
                    BinaryChunkWeak = FilteredChunk > ThresholdWeak
                    BinaryChunkStrong = FilteredChunk > ThresholdStrong
                else:
                    BinaryChunkWeak = FilteredChunk < -ThresholdWeak
                    BinaryChunkStrong = FilteredChunk < -ThresholdStrong
                BinaryChunkWeak = BinaryChunkWeak.astype(np.int8)
                BinaryChunkStrong = BinaryChunkStrong.astype(np.int8)
            else:
                if Parameters['DETECT_POSITIVE']:
                    BinaryChunk = np.abs(FilteredChunk)>Threshold
                else:
                    BinaryChunk = (FilteredChunk<-Threshold)
                BinaryChunk = BinaryChunk.astype(np.int8)

        ############### FLOOD FILL  ######################################
        ChannelGraphToUse = complete_if_none(ChannelGraph, N_CH)
        if (Parameters['USE_HILBERT'] or Parameters['USE_COMPONENT_ALIGNFLOATMASK']):
            if Parameters['USE_OLD_CC_CODE']:
                IndListsChunkOld = connected_components(BinaryChunkWeak,
                            ChannelGraphToUse, S_JOIN_CC)
                # Final list of connected components: keep a 'weak' component only
                # if some of its samples also exceed the strong threshold.
                # This method works better than connected_components_twothresholds.
                IndListsChunk = []
                for IndListWeak in IndListsChunkOld:
                    i, j = np.array(IndListWeak).transpose()
                    if sum(BinaryChunkStrong[i, j]) != 0:
                        IndListsChunk.append(IndListWeak)
            else:
                IndListsChunk = connected_components_twothresholds(BinaryChunkWeak, BinaryChunkStrong,
                            ChannelGraphToUse, S_JOIN_CC)
            BinaryChunk = BinaryChunkWeak + BinaryChunkStrong  # 2 = strong, 1 = weak only, 0 = subthreshold
        else:
            IndListsChunk = connected_components(BinaryChunk,
                            ChannelGraphToUse, S_JOIN_CC)
            
        
        if Parameters['DEBUG']:  # TODO: adapt plot_diagnostics for the Hilbert case
            if Parameters['USE_HILBERT']:
                plot_diagnostics_twothresholds(s_start, IndListsChunk,
                        BinaryChunkWeak, BinaryChunkStrong, BinaryChunk,
                        DatChunk, FilteredChunk, FilteredChunkHilbert,
                        ThresholdStrong, ThresholdWeak)
            elif Parameters['USE_COMPONENT_ALIGNFLOATMASK']:
                plot_diagnostics_twothresholds(s_start, IndListsChunk,
                        BinaryChunkWeak, BinaryChunkStrong, BinaryChunk,
                        DatChunk, FilteredChunk, -FilteredChunk,
                        ThresholdStrong, ThresholdWeak)
            else:
                plot_diagnostics(s_start, IndListsChunk, BinaryChunk,
                        DatChunk, FilteredChunk, Threshold)
        if Parameters['WRITE_BINFIL_FILE']:
            fil_writer.write_bin(BinaryChunk, s_start, s_end, keep_start, keep_end)
        
        ############## ALIGN AND INTERPOLATE WAVES #######################
        nextbits = []
        if Parameters['USE_HILBERT']:
            
            for IndList in IndListsChunk:
                try:
                    wave, s_peak, sf_peak, cm, fcm = extract_wave_hilbert_new(IndList, FilteredChunk,
                                                    FilteredChunkHilbert,
                                                    S_BEFORE, S_AFTER, N_CH,
                                                    s_start, ThresholdStrong, ThresholdWeak)
                    s_offset = s_start + s_peak
                    sf_offset = s_start + sf_peak
                    if keep_start<=s_offset<keep_end:
                        spike_count += 1
                        nextbits.append((wave, s_offset, sf_offset, cm, fcm))
                except np.linalg.LinAlgError:
                    s = '*** WARNING *** Unalignable spike discarded in chunk {chunk}.'.format(
                            chunk=(s_start, s_end))
                    log_warning(s)
                except InterpolationError:
                    s = '*** WARNING *** Interpolation error in chunk {chunk}.'.format(
                            chunk=(s_start, s_end))
                    log_warning(s)
            # and return them in time sorted order
            nextbits.sort(key=lambda (wave, s, s_frac, cm, fcm): s_frac)
            for wave, s, s_frac, cm, fcm in nextbits:
                uwave = get_padded(DatChunk, int(s)-S_BEFORE-s_start,
                                   int(s)+S_AFTER-s_start).astype(np.int32)
                # cm = add_penumbra(cm, ChannelGraphToUse,
                                  # Parameters['PENUMBRA_SIZE'])
                # fcm = get_float_mask(wave, cm, ChannelGraphToUse,
                                     # 1.)
                yield uwave, wave, s, s_frac, cm, fcm
                # (unfiltered wave, filtered wave, peak sample, fractional peak sample, channel mask, float channel mask)
        elif Parameters['USE_COMPONENT_ALIGNFLOATMASK']:
            for IndList in IndListsChunk:
                try:
                    if Parameters['DETECT_POSITIVE']:
                        wave, s_peak, sf_peak, cm, fcm, comp_normalised, comp_normalised_power = extract_wave_twothresholds(IndList, FilteredChunk,
                                                    FilteredChunk,
                                                    S_BEFORE, S_AFTER, N_CH,
                                                    s_start, ThresholdStrong, ThresholdWeak) 
                    else:
                        wave, s_peak, sf_peak, cm, fcm,comp_normalised, comp_normalised_power = extract_wave_twothresholds(IndList, FilteredChunk,
                                                    -FilteredChunk,
                                                    S_BEFORE, S_AFTER, N_CH,
                                                    s_start, ThresholdStrong, ThresholdWeak)
                    s_offset = s_start+s_peak
                    sf_offset = s_start + sf_peak
                    if keep_start<=s_offset<keep_end:
                        spike_count += 1
                        nextbits.append((wave, s_offset, sf_offset, cm, fcm))
                except np.linalg.LinAlgError:
                    s = '*** WARNING *** Unalignable spike discarded in chunk {chunk}.'.format(
                            chunk=(s_start, s_end))
                    log_warning(s)
                except InterpolationError:
                    s = '*** WARNING *** Interpolation error in chunk {chunk}.'.format(
                            chunk=(s_start, s_end))
                    log_warning(s)
            # and return them in time sorted order
            nextbits.sort(key=lambda (wave, s, s_frac, cm, fcm): s_frac)
            for wave, s, s_frac, cm, fcm in nextbits:
                uwave = get_padded(DatChunk, int(s)-S_BEFORE-s_start,
                                   int(s)+S_AFTER-s_start).astype(np.int32)
                # cm = add_penumbra(cm, ChannelGraphToUse,
                                  # Parameters['PENUMBRA_SIZE'])
                # fcm = get_float_mask(wave, cm, ChannelGraphToUse,
                                     # 1.)
                yield uwave, wave, s, s_frac, cm, fcm
                # (unfiltered wave, filtered wave, peak sample, fractional peak sample, channel mask, float channel mask)
        else:  # Original SpikeDetekt. This code duplication is regrettable but probably easier to deal with.
            
            for IndList in IndListsChunk:
                try:
                    wave, s_peak, sf_peak, cm = extract_wave(IndList, FilteredChunk,
                                                    S_BEFORE, S_AFTER, N_CH,
                                                    s_start,Threshold)
                    s_offset = s_start+s_peak
                    sf_offset = s_start + sf_peak
                    if keep_start<=s_offset<keep_end:
                        spike_count += 1
                        nextbits.append((wave, s_offset, sf_offset, cm))
                except np.linalg.LinAlgError:
                    s = '*** WARNING *** Unalignable spike discarded in chunk {chunk}.'.format(
                            chunk=(s_start, s_end))
                    log_warning(s)
            # and return them in time sorted order
            nextbits.sort(key=lambda (wave, s, s_frac, cm): s_frac)
            for wave, s, s_frac, cm in nextbits:
                uwave = get_padded(DatChunk, int(s)-S_BEFORE-s_start,
                                   int(s)+S_AFTER-s_start).astype(np.int32)
                cm = add_penumbra(cm, ChannelGraphToUse,
                                  Parameters['PENUMBRA_SIZE'])
                fcm = get_float_mask(wave, cm, ChannelGraphToUse,
                                     ThresholdSDFactor)
                yield uwave, wave, s, s_frac, cm, fcm
                # (unfiltered wave, filtered wave, peak sample, fractional peak sample, channel mask, float channel mask)

        progress_bar.update(float(s_end)/n_samples,
            '%d/%d samples, %d spikes found'%(s_end, n_samples, spike_count))
        if max_spikes is not None and spike_count>=max_spikes:
            break
    
    progress_bar.finish()
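
The thresholding above reduces to a robust, median-based estimate of the noise standard deviation and two multiples of it (weak and strong). The following self-contained sketch illustrates that idea only; the multipliers and the synthetic trace are made up, not the library's defaults.

import numpy as np

# synthetic 'filtered' trace: Gaussian noise plus one negative-going spike (illustrative only)
rng = np.random.RandomState(0)
x = rng.normal(0.0, 10.0, size=20000)
x[5000:5010] -= 120.0

# robust noise SD estimate: median absolute value / 0.6745, as in the examples above
sd = np.median(np.abs(x)) / .6745

# weak and strong thresholds; THRESH_SD and THRESH_SD_LOWER play this role above
thresh_strong = 4.5 * sd
thresh_weak = 2.0 * sd

# negative-going spikes (the DETECT_POSITIVE == False branch above)
binary_weak = (x < -thresh_weak).astype(np.int8)
binary_strong = (x < -thresh_strong).astype(np.int8)

print 'estimated SD:', sd
print 'weak crossings:', binary_weak.sum(), 'strong crossings:', binary_strong.sum()
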
Example #3
def extract_spikes(h5s, basename, DatFileNames, n_ch_dat,
                   ChannelsToUse, ChannelGraph,
                   max_spikes=None):
    # some global variables we use
    CHUNK_SIZE = Parameters['CHUNK_SIZE']
    CHUNKS_FOR_THRESH = Parameters['CHUNKS_FOR_THRESH']
    DTYPE = Parameters['DTYPE']
    CHUNK_OVERLAP = Parameters['CHUNK_OVERLAP']
    N_CH = Parameters['N_CH']
    S_JOIN_CC = Parameters['S_JOIN_CC']
    S_BEFORE = Parameters['S_BEFORE']
    S_AFTER = Parameters['S_AFTER']
    THRESH_SD = Parameters['THRESH_SD']

    # filter coefficents for the high pass filtering
    filter_params = get_filter_params()

    progress_bar = ProgressReporter()

    # A writer that outputs a high-pass filtered version of the raw data
    # (.fil file)
    fil_writer = FilWriter(DatFileNames, n_ch_dat)

    # Just use first dat file for getting the thresholding data
    with open(DatFileNames[0], 'rb') as fd:
        # Use the first few chunks (CHUNKS_FOR_THRESH) to estimate the threshold
        DatChunk = get_chunk_for_thresholding(fd, n_ch_dat, ChannelsToUse,
                                              num_samples(DatFileNames[0],
                                                          n_ch_dat))
        FilteredChunk = apply_filtering(filter_params, DatChunk)
        # dividing the median absolute value by 0.6745 gives a robust estimate
        # of the standard deviation
        if Parameters['USE_SINGLE_THRESHOLD']:
            ThresholdSDFactor = np.median(np.abs(FilteredChunk)) / .6745
        else:
            ThresholdSDFactor = np.median(
                np.abs(FilteredChunk),
                axis=0) / .6745
        Threshold = ThresholdSDFactor * THRESH_SD

        print 'Threshold = ', Threshold, '\n'
        # Record the absolute Threshold used
        Parameters['THRESHOLD'] = Threshold

    n_samples = num_samples(DatFileNames, n_ch_dat)

    spike_count = 0
    for (DatChunk, s_start, s_end,
         keep_start, keep_end) in chunks(DatFileNames, n_ch_dat, ChannelsToUse):
        ############## FILTERING ########################################
        FilteredChunk = apply_filtering(filter_params, DatChunk)

        # write filtered output to file (unconditionally in this version)
        fil_writer.write(FilteredChunk, s_start, s_end, keep_start, keep_end)

        ############## THRESHOLDING #####################################
        if Parameters['DETECT_POSITIVE']:
            BinaryChunk = np.abs(FilteredChunk) > Threshold
        else:
            BinaryChunk = (FilteredChunk < -Threshold)
        BinaryChunk = BinaryChunk.astype(np.int8)
        # write binary chunk filtered output to file
        if Parameters['WRITE_BINFIL_FILE']:
            fil_writer.write_bin(
                BinaryChunk,
                s_start,
                s_end,
                keep_start,
                keep_end)
        ############### FLOOD FILL  ######################################
        ChannelGraphToUse = complete_if_none(ChannelGraph, N_CH)
        IndListsChunk = connected_components(BinaryChunk,
                                             ChannelGraphToUse, S_JOIN_CC)
        if Parameters['DEBUG']:
            plot_diagnostics(
                s_start,
                IndListsChunk,
                BinaryChunk,
                DatChunk,
                FilteredChunk,
                Threshold)
            fil_writer.write_bin(
                BinaryChunk,
                s_start,
                s_end,
                keep_start,
                keep_end)

        ############## ALIGN AND INTERPOLATE WAVES #######################
        nextbits = []
        for IndList in IndListsChunk:
            try:
                wave, s_peak, cm = extract_wave(IndList, FilteredChunk,
                                                S_BEFORE, S_AFTER, N_CH,
                                                s_start, Threshold)
                s_offset = s_start + s_peak
                if keep_start <= s_offset < keep_end:
                    spike_count += 1
                    nextbits.append((wave, s_offset, cm))
            except np.linalg.LinAlgError:
                s = '*** WARNING *** Unalignable spike discarded in chunk {chunk}.'.format(
                    chunk=(s_start, s_end))
                log_warning(s)
        # and return them in time sorted order
        nextbits.sort(key=lambda wave_s_cm: wave_s_cm[1])
        for wave, s, cm in nextbits:
            uwave = get_padded(DatChunk, int(s) - S_BEFORE - s_start,
                               int(s) + S_AFTER - s_start).astype(np.int32)
            cm = add_penumbra(cm, ChannelGraphToUse,
                              Parameters['PENUMBRA_SIZE'])
            fcm = get_float_mask(wave, cm, ChannelGraphToUse,
                                 ThresholdSDFactor)
            yield uwave, wave, s, cm, fcm
        progress_bar.update(float(s_end) / n_samples,
                            '%d/%d samples, %d spikes found' % (s_end, n_samples, spike_count))
        if max_spikes is not None and spike_count >= max_spikes:
            break

    progress_bar.finish()
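
Both versions of extract_spikes accept a spike only if its peak falls inside the chunk's keep_start..keep_end window; because consecutive chunks overlap (CHUNK_OVERLAP), this keeps a spike in the overlap region from being reported twice. The toy sketch below shows that bookkeeping in isolation; the chunk sizes, the exact keep-window convention and the spike times are invented for illustration and are not the library's own chunks() implementation.

# toy chunking with overlap: each sample belongs to exactly one keep window
chunk_size = 1000
overlap = 200
n_samples = 3000
spike_peaks = [150, 950, 1010, 1990, 2500]   # hypothetical detected peak samples

accepted = []
s_start = 0
while s_start < n_samples:
    s_end = min(s_start + chunk_size, n_samples)
    # first chunk keeps from 0, last chunk keeps to the end; otherwise trim half the overlap
    keep_start = s_start if s_start == 0 else s_start + overlap // 2
    keep_end = n_samples if s_end == n_samples else s_end - overlap // 2
    for s in spike_peaks:
        if keep_start <= s < keep_end:   # same acceptance test as in extract_spikes
            accepted.append(s)
    s_start += chunk_size - overlap

print accepted   # each hypothetical peak is accepted exactly once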