Example #1
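A test (apparently from the tridesclous test suite) of DataIO probe handling: start from a clean working directory, attach the olfactory_bulb demo recording as a raw-data source, download a Neuronexus probe file by its repository path, and check that the probe's basename is recorded in dataio.info and that channel group 0 has 8 channels.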
def test_DataIO_probes():
    # initialize dataio (start from a clean directory)
    if os.path.exists('test_DataIO'):
        shutil.rmtree('test_DataIO')
        
    dataio = DataIO(dirname='test_DataIO')
    print(dataio)


    localdir, filenames, params = download_dataset(name='olfactory_bulb')
    dataio.set_data_source(type='RawData', filenames=filenames, **params)
    
    probe_filename = 'neuronexus/A4x8-5mm-100-400-413-A32.prb'
    # download_probe accepts the path with or without the .prb extension
    dataio.download_probe(probe_filename)
    dataio.download_probe('neuronexus/A4x8-5mm-100-400-413-A32')
    
    #~ print(dataio.channel_groups)
    #~ print(dataio.channels)
    #~ print(dataio.info['probe_filename'])
    
    assert dataio.nb_channel(0) == 8
    assert probe_filename.split('/')[-1] == dataio.info['probe_filename']
    
    # reopen existing
    dataio = DataIO(dirname='test_DataIO')
    print(dataio)
Example #2
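The same test with the probe referenced by bare filename instead of its neuronexus/ repository path; here the assertion compares probe_filename directly, so info['probe_filename'] evidently stores only the basename.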
def test_DataIO_probes():
    # initialize dataio (start from a clean directory)
    if os.path.exists('test_DataIO'):
        shutil.rmtree('test_DataIO')
        
    dataio = DataIO(dirname='test_DataIO')
    print(dataio)


    localdir, filenames, params = download_dataset(name='olfactory_bulb')
    dataio.set_data_source(type='RawData', filenames=filenames, **params)
    
    probe_filename = 'A4x8-5mm-100-400-413-A32.prb'
    dataio.download_probe(probe_filename)
    dataio.download_probe('A4x8-5mm-100-400-413-A32')
    
    #~ print(dataio.channel_groups)
    #~ print(dataio.channels)
    #~ print(dataio.info['probe_filename'])
    
    assert dataio.nb_channel(0) == 8
    assert probe_filename == dataio.info['probe_filename']
    
    # reopen existing
    dataio = DataIO(dirname='test_DataIO')
    print(dataio)
Example #3
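A basic DataIO round trip: attach the demo recording, declare one manual group of 14 channels, and iterate over fixed-size signal chunks both before and after reopening the directory, to check that the configuration persists on disk.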
def test_DataIO():
    # initialize dataio (start from a clean directory)
    if os.path.exists('test_DataIO'):
        shutil.rmtree('test_DataIO')
        
    dataio = DataIO(dirname='test_DataIO')
    print(dataio)


    localdir, filenames, params = download_dataset(name='olfactory_bulb')
    dataio.set_data_source(type='RawData', filenames=filenames, **params)
    #~ dataio.set_channels(range(4))
    # one manual group containing the first 14 channels
    dataio.set_manual_channel_group(range(14))
    
    
    for seg_num in range(dataio.nb_segment):
        for i_stop, sigs_chunk in dataio.iter_over_chunk(seg_num=seg_num, chunksize=1024):
            assert sigs_chunk.shape[0] == 1024
            assert sigs_chunk.shape[1] == 14
            #~ print(seg_num, i_stop, sigs_chunk.shape)
    
    
    # reopen existing
    dataio = DataIO(dirname='test_DataIO')
    print(dataio)
    
    #~ exit()
    
    for seg_num in range(dataio.nb_segment):
        #~ print('seg_num', seg_num)
        for i_stop, sigs_chunk in dataio.iter_over_chunk(seg_num=seg_num, chunksize=1024):
            assert sigs_chunk.shape[0] == 1024
            assert sigs_chunk.shape[1] == 14
Example #4
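A complete catalogue-construction pipeline driven by a single parameter dict: preprocessing, peak detection, waveform extraction and cleaning, noise snippets, global PCA features and k-means clustering, followed by ordering clusters by RMS, trashing one cluster, and exporting the catalogue for the Peeler. dirname and channels are assumed to be defined at module level.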
def make_catalogue():
    if os.path.exists(dirname):
        shutil.rmtree(dirname)

    dataio = DataIO(dirname=dirname)
    localdir, filenames, params = download_dataset(name='olfactory_bulb')
    dataio.set_data_source(type='RawData', filenames=filenames, **params)
    dataio.add_one_channel_group(channels=channels)

    cc = CatalogueConstructor(dataio=dataio)

    params = {
        'duration': 300.,
        'preprocessor': {
            'highpass_freq': 300.,
            'chunksize': 1024,
            'lostfront_chunksize': 100,
        },
        'peak_detector': {
            'peak_sign': '-',
            'relative_threshold': 7.,
            'peak_span': 0.0005,
            #~ 'peak_span' : 0.000,
        },
        'extract_waveforms': {
            'n_left': -25,
            'n_right': 40,
            'nb_max': 10000,
        },
        'clean_waveforms': {
            'alien_value_threshold': 60.,
        },
        'noise_snippet': {
            'nb_snippet': 300,
        },
        'feature_method': 'global_pca',
        'feature_kargs': {
            'n_components': 20
        },
        'cluster_method': 'kmeans',
        'cluster_kargs': {
            'n_clusters': 5
        },
        'clean_cluster': False,
        'clean_cluster_kargs': {},
    }

    apply_all_catalogue_steps(cc, params, verbose=True)

    cc.order_clusters(by='waveforms_rms')
    cc.move_cluster_to_trash(4)
    cc.make_catalogue_for_peeler()
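
The catalogue exported at the end of this example is meant to be consumed by tridesclous's Peeler, the template-matching stage. A minimal sketch of that next step, following the pattern of the tridesclous example scripts; the chan_grp=0 value is an assumption about how channels was grouped:

from tridesclous import Peeler

# reopen the same working directory and load the exported catalogue
dataio = DataIO(dirname=dirname)
catalogue = dataio.load_catalogue(chan_grp=0)  # chan_grp=0 is an assumption

# run template matching over the whole recording
peeler = Peeler(dataio)
peeler.change_params(catalogue=catalogue)
peeler.run()

# spikes are then available per segment
spikes = dataio.get_spikes(seg_num=0, chan_grp=0)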
Example #5
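The same round trip as Example #3, this time exercising the channel-group API in several forms: a group with per-channel geometry, a group without geometry, an extra group added with add_one_channel_group, and a final plain 14-channel group used for the chunk iteration.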
def test_DataIO():

    # initialize dataio (start from a clean directory)
    if os.path.exists('test_DataIO'):
        shutil.rmtree('test_DataIO')

    dataio = DataIO(dirname='test_DataIO')
    print(dataio)

    localdir, filenames, params = download_dataset(name='olfactory_bulb')
    dataio.set_data_source(type='RawData', filenames=filenames, **params)

    # with geometry: per-channel [x, y] positions
    channels = list(range(14))
    channel_groups = {
        0: {
            'channels': range(14),
            'geometry': {c: [0, i]
                         for i, c in enumerate(channels)}
        }
    }
    dataio.set_channel_groups(channel_groups)

    # with no geometry
    channel_groups = {0: {'channels': range(4)}}
    dataio.set_channel_groups(channel_groups)

    # add one group
    dataio.add_one_channel_group(channels=range(4, 8), chan_grp=5)

    channel_groups = {0: {'channels': range(14)}}
    dataio.set_channel_groups(channel_groups)

    for seg_num in range(dataio.nb_segment):
        for i_stop, sigs_chunk in dataio.iter_over_chunk(seg_num=seg_num,
                                                         chunksize=1024):
            assert sigs_chunk.shape[0] == 1024
            assert sigs_chunk.shape[1] == 14
            #~ print(seg_num, i_stop, sigs_chunk.shape)

    # reopen existing
    dataio = DataIO(dirname='test_DataIO')
    print(dataio)

    #~ exit()

    for seg_num in range(dataio.nb_segment):
        #~ print('seg_num', seg_num)
        for i_stop, sigs_chunk in dataio.iter_over_chunk(seg_num=seg_num,
                                                         chunksize=1024):
            assert sigs_chunk.shape[0] == 1024
            assert sigs_chunk.shape[1] == 14
Example #6
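A preprocessing driver for a multi-array Intan recording: each array becomes one channel group, identified by the name of its first amplifier channel, and arrays absent from the recording get an empty group so group numbering stays aligned with array indices. cfg and preprocess_array are assumed to be defined elsewhere in the calling project.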
def preprocess_data(subject, recording_date, data_files):
    output_dir = os.path.join(cfg['single_unit_spike_sorting_dir'], subject,
                              recording_date, 'preprocess')
    if os.path.exists(output_dir):
        # remove if it already exists
        shutil.rmtree(output_dir)

    ## Setup DataIO
    dataio = DataIO(dirname=output_dir)
    dataio.set_data_source(type='Intan',
                           filenames=[x['fname'] for x in data_files])

    # Setup channel groups
    arrays_recorded = []
    grp_idx = 0
    # name of the first amplifier channel on each array
    first_chan_by_array = {0: 'A-000', 1: 'A-032', 2: 'B-000',
                           3: 'B-032', 4: 'C-000', 5: 'C-032'}
    for array_idx in range(len(cfg['arrays'])):
        first_chan = first_chan_by_array.get(array_idx, '')

        # the array was recorded if its first channel appears in the source
        found = any(chan[0] == first_chan
                    for chan in dataio.datasource.sig_channels)

        # recorded arrays get the next contiguous block of channels; absent
        # arrays still get an empty group so chan_grp stays aligned with
        # the array index
        chan_range = []
        if found:
            chan_range = range(grp_idx * cfg['n_channels_per_array'],
                               (grp_idx + 1) * cfg['n_channels_per_array'])
            grp_idx += 1
            arrays_recorded.append(array_idx)
        dataio.add_one_channel_group(channels=chan_range, chan_grp=array_idx)

    print(dataio)

    total_duration = np.sum([x['duration'] for x in data_files])
    for array_idx in arrays_recorded:
        print(array_idx)
        preprocess_array(array_idx, output_dir, total_duration)
Example #7
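Minimal setup for a 192-channel Intan recording: create the DataIO, feed it the files, and put all channels into a single group.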
def initialize_catalogueconstructor(dirname, filenames):
    # create a DataIO
    if os.path.exists(dirname):
        # remove if it already exists
        shutil.rmtree(dirname)
    dataio = DataIO(dirname=dirname)

    # this recording has 192 channels: we use them all in one group

    # feed DataIO
    dataio.set_data_source(type='Intan', filenames=filenames, channel_indexes=list(range(192)))
    #dataio.set_probe_file('/home/bonaiuto/Projects/tool_learning/recordings/rhd2000/betta/default.prb')

    dataio.add_one_channel_group(channels=range(192), chan_grp=0)

    print(dataio)
Example #8
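The most involved example: per-channel catalogue construction over the preprocessed signals written by Example #6. Every channel of an array becomes its own channel group, the processing chain is run step by step with timing printouts, and diagnostic figures are saved after each stage. cfg, plot_waveforms, plot_noise, and plot_cluster_waveforms are assumed to come from the surrounding project.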
def compute_array_catalogue(array_idx, preprocess_dir, subject, recording_date,
                            data_files, cluster_merge_threshold):
    # If data exists for this array
    if os.path.exists(
            os.path.join(preprocess_dir, 'channel_group_%d' % array_idx,
                         'catalogue_constructor')):
        output_dir = os.path.join(cfg['single_unit_spike_sorting_dir'],
                                  subject, recording_date,
                                  'array_%d' % array_idx)
        if os.path.exists(output_dir):
            # remove if it already exists
            shutil.rmtree(output_dir)
        # collect the preprocessed signal file of every segment
        # (all data is used for clustering)
        data_file_names = []
        for seg in range(len(data_files)):
            data_file_names.append(
                os.path.join(preprocess_dir, 'channel_group_%d' % array_idx,
                             'segment_%d' % seg, 'processed_signals.raw'))

        dataio = DataIO(dirname=output_dir)
        dataio.set_data_source(type='RawData',
                               filenames=data_file_names,
                               dtype='float32',
                               sample_rate=cfg['intan_srate'],
                               total_channel=cfg['n_channels_per_array'])
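        # 0.195 µV per bit is the ADC scale factor of the Intan RHD2000 amplifier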
        dataio.datasource.bit_to_microVolt = 0.195
        for ch_grp in range(cfg['n_channels_per_array']):
            dataio.add_one_channel_group(channels=[ch_grp], chan_grp=ch_grp)

        total_duration = np.sum([x['duration'] for x in data_files])

        figure_out_dir = os.path.join(output_dir, 'figures')
        os.mkdir(figure_out_dir)
        for ch_grp in range(cfg['n_channels_per_array']):
            print(ch_grp)
            cc = CatalogueConstructor(dataio=DataIO(dirname=output_dir,
                                                    ch_grp=ch_grp),
                                      chan_grp=ch_grp)

            fullchain_kargs = {
                'duration': total_duration,
                'preprocessor': {
                    'highpass_freq': None,
                    'lowpass_freq': None,
                    'smooth_size': 0,
                    'common_ref_removal': False,
                    'chunksize': 32768,
                    'lostfront_chunksize': 0,
                    'signalpreprocessor_engine': 'numpy',
                },
                'peak_detector': {
                    'peakdetector_engine': 'numpy',
                    'peak_sign': '-',
                    'relative_threshold': 2.,
                    'peak_span': 0.0002,
                },
                'noise_snippet': {
                    'nb_snippet': 300,
                },
                'extract_waveforms': {
                    'n_left': -20,
                    'n_right': 30,
                    'mode': 'all',
                    'nb_max': 2000000,
                    'align_waveform': False,
                },
                'clean_waveforms': {
                    'alien_value_threshold': 100.,
                },
            }
            feat_method = 'pca_by_channel'
            feat_kargs = {'n_components_by_channel': 5}
            clust_method = 'sawchaincut'
            clust_kargs = {
                'max_loop': 1000,
                'nb_min': 20,
                'break_nb_remain': 30,
                'kde_bandwith': 0.01,
                # note: the cluster_merge_threshold argument is unused here;
                # pass 'auto_merge_threshold': cluster_merge_threshold to
                # make this configurable instead of hardcoded
                'auto_merge_threshold': 2.,
                'print_debug': False,
            }

            p = {}
            p.update(fullchain_kargs['preprocessor'])
            p.update(fullchain_kargs['peak_detector'])
            cc.set_preprocessor_params(**p)

            noise_duration = min(10., fullchain_kargs['duration'],
                                 dataio.get_segment_length(seg_num=0)
                                 / dataio.sample_rate * 0.99)
            # ~ print('noise_duration', noise_duration)
            t1 = time.perf_counter()
            cc.estimate_signals_noise(seg_num=0, duration=noise_duration)
            t2 = time.perf_counter()
            print('estimate_signals_noise', t2 - t1)

            t1 = time.perf_counter()
            cc.run_signalprocessor(duration=fullchain_kargs['duration'])
            t2 = time.perf_counter()
            print('run_signalprocessor', t2 - t1)

            t1 = time.perf_counter()
            cc.extract_some_waveforms(**fullchain_kargs['extract_waveforms'])
            t2 = time.perf_counter()
            print('extract_some_waveforms', t2 - t1)

            fname = 'chan_%d_init_waveforms.png' % ch_grp
            fig = plot_waveforms(np.squeeze(cc.some_waveforms).T)
            fig.savefig(os.path.join(figure_out_dir, fname))
            fig.clf()
            plt.close()

            t1 = time.perf_counter()
            # ~ duration = d['duration'] if d['limit_duration'] else None
            # ~ d['clean_waveforms']
            cc.clean_waveforms(**fullchain_kargs['clean_waveforms'])
            t2 = time.perf_counter()
            print('clean_waveforms', t2 - t1)

            fname = 'chan_%d_clean_waveforms.png' % ch_grp
            fig = plot_waveforms(np.squeeze(cc.some_waveforms).T)
            fig.savefig(os.path.join(figure_out_dir, fname))
            fig.clf()
            plt.close()

            # ~ t1 = time.perf_counter()
            # ~ n_left, n_right = cc.find_good_limits(mad_threshold = 1.1,)
            # ~ t2 = time.perf_counter()
            # ~ print('find_good_limits', t2-t1)

            t1 = time.perf_counter()
            cc.extract_some_noise(**fullchain_kargs['noise_snippet'])
            t2 = time.perf_counter()
            print('extract_some_noise', t2 - t1)

            # Plot noise
            fname = 'chan_%d_noise.png' % ch_grp
            fig = plot_noise(cc)
            fig.savefig(os.path.join(figure_out_dir, fname))
            fig.clf()
            plt.close()

            t1 = time.perf_counter()
            cc.extract_some_features(method=feat_method, **feat_kargs)
            t2 = time.perf_counter()
            print('project', t2 - t1)

            t1 = time.perf_counter()
            cc.find_clusters(method=clust_method, **clust_kargs)
            t2 = time.perf_counter()
            print('find_clusters', t2 - t1)

            # Remove empty clusters
            cc.trash_small_cluster(n=0)

            if cc.centroids_median is None:
                cc.compute_all_centroid()

            # order cluster by waveforms rms
            cc.order_clusters(by='waveforms_rms')

            fname = 'chan_%d_init_clusters.png' % ch_grp
            cluster_labels = cc.clusters['cluster_label']
            fig = plot_cluster_waveforms(cc, cluster_labels)
            fig.savefig(os.path.join(figure_out_dir, fname))
            fig.clf()
            plt.close()

            # save the catalogue
            cc.make_catalogue_for_peeler()

            gc.collect()