示例#1
0
def test_example_works_pip_and_dict(path_to_config_sample):
    """Run the pipeline with the sample config passed as a parsed dict.

    The YAML file at ``path_to_config_sample`` is parsed here, and the
    resulting mapping is handed to ``pipeline.run`` instead of the path.
    Temporary output is cleaned up even if the pipeline raises.
    """
    # yaml.load without an explicit Loader is unsafe and is rejected by
    # PyYAML >= 6; a config file only needs plain YAML types.
    with open(path_to_config_sample) as f:
        cfg = yaml.safe_load(f)

    try:
        pipeline.run(cfg)
    finally:
        # Always remove temporary artifacts so later tests start clean.
        clean_tmp()
示例#2
0
def sort(config, logger_level, clean, output_dir, complete, zero_seed,
         global_gpu_memory, calculate_rf, visualize):
    """
    Sort recordings using a configuration file located in CONFIG
    """
    # NOTE(review): ``zero_seed`` and ``global_gpu_memory`` are accepted by
    # this command but are neither used nor forwarded to the pipeline here.
    run_options = dict(logger_level=logger_level,
                       clean=clean,
                       output_dir=output_dir,
                       complete=complete,
                       calculate_rf=calculate_rf,
                       visualize=visualize)
    return pipeline.run(config, **run_options)
示例#3
0
def sort(config, logger_level, clean, output_dir, complete, zero_seed):
    """
    Sort recordings using a configuration file located in CONFIG
    """
    # Every option is forwarded unchanged, except that the ``zero_seed`` flag
    # is passed to the pipeline under the keyword ``set_zero_seed``.
    options = {
        'logger_level': logger_level,
        'clean': clean,
        'output_dir': output_dir,
        'complete': complete,
        'set_zero_seed': zero_seed,
    }
    return pipeline.run(config, **options)
示例#4
0
def test_threshold_pipeline_returns_expected_results(path_to_threshold_config,
                                                     path_to_data_folder):
    """Compare the threshold pipeline's spike train against a stored reference.

    The reference lives at
    ``<path_to_data_folder>/output_reference/threshold_spike_train.npy``.
    Temporary output is removed even if the run or the comparison fails.
    """
    try:
        spike_train = pipeline.run(path_to_threshold_config, clean=True)

        path_to_reference = path.join(path_to_data_folder,
                                      'output_reference',
                                      'threshold_spike_train.npy')

        ReferenceTesting.assert_array_equal(spike_train, path_to_reference)
    finally:
        # Previously skipped on failure, leaking tmp state into later tests.
        clean_tmp()
示例#5
0
def sort(config, logger_level, clean, output_dir, complete, zero_seed,
         global_gpu_memory):
    """
    Sort recordings using a configuration file located in CONFIG
    """
    # A custom TensorFlow session config is installed only when the requested
    # per-process GPU memory fraction is something other than 1.0.
    if global_gpu_memory != 1.0:
        session_config = tf.ConfigProto()
        session_config.gpu_options.per_process_gpu_memory_fraction = (
            global_gpu_memory)
        yass.set_tensorflow_config(config=session_config)

    # NOTE(review): ``zero_seed`` is accepted but not used by this command.
    return pipeline.run(config,
                        logger_level=logger_level,
                        clean=clean,
                        output_dir=output_dir,
                        complete=complete)
示例#6
0
def test_works_with_nnet_config(patch_triage_network, path_to_config,
                                make_tmp_folder):
    """Smoke test: ``pipeline.run`` completes on ``path_to_config``.

    Output is written to the ``make_tmp_folder`` directory.
    ``patch_triage_network`` is not referenced in the body. It is presumably
    requested only for its fixture side effect (patching the triage network).
    """
    destination = make_tmp_folder
    pipeline.run(path_to_config, output_dir=destination)
示例#7
0
def test_works_with_sample_config_passing_dict(path_to_config,
                                               make_tmp_folder):
    """Run the pipeline with the config passed as a parsed dict.

    The YAML at ``path_to_config`` is loaded here, and the mapping (not the
    path) is handed to ``pipeline.run``, writing into ``make_tmp_folder``.
    """
    # yaml.load without an explicit Loader is unsafe and is rejected by
    # PyYAML >= 6; safe_load covers plain config files.
    with open(path_to_config) as f:
        cfg = yaml.safe_load(f)

    pipeline.run(cfg, output_dir=make_tmp_folder)
示例#8
0
def test_works_with_threshold_config(path_to_config_threshold,
                                     make_tmp_folder):
    """Smoke test: ``pipeline.run`` completes on ``path_to_config_threshold``.

    Results are written into the ``make_tmp_folder`` directory.
    """
    destination = make_tmp_folder
    pipeline.run(path_to_config_threshold, output_dir=destination)
示例#9
0
def main(n_batches=6):
    """Runs the procedure for evaluating yass on retinal data.

    For each dataset ``ej49_data{1..4}.bin`` (with its
    ``ej49_geometry{N}.txt``), this function performs four steps:

    1. Run yass on the raw recording.
    2. Build an augmented recording from yass's mean waveforms, save it
       (plus its spike train as ``ej49_data{N}.aug.npy``), and rerun yass on
       it.
    3. Score yass against ``groundtruth_ej49_data{N}.mat`` (accuracy).
    4. Score yass on the augmented recording against the augmentation's
       spike train (stability).

    One accuracy plot and one stability plot are then rendered.

    Parameters
    ----------
    n_batches: int
        Number of batches each recording is split into for processing.
    """
    # yaml.load without an explicit Loader is unsafe and is rejected by
    # PyYAML >= 6. The ``with`` block also closes the file if parsing fails.
    with open('config_template.yaml', 'r') as config_file:
        config = yaml.safe_load(config_file)

    # Extracting window around spikes.
    sampling_rate = config['recordings']['sampling_rate']
    n_chan = config['recordings']['n_channels']
    dtype = config['recordings']['dtype']
    spike_length = config['recordings']['spike_size_ms']
    # Convert the spike size from milliseconds to samples.
    window_radius = int(spike_length * sampling_rate / 1e3)
    window = range(-window_radius, window_radius)

    k_tot_data = 4
    radius = 70

    def make_batch_reader(recording_file, geometry_file, n_batch_samples):
        """Build a non-whitened batch iterator over one recording file."""
        return RecordingBatchIterator(
            recording_file,
            geometry_file,
            sample_rate=sampling_rate,
            batch_time_samples=n_batch_samples,
            n_batches=n_batches,
            n_chan=n_chan,
            radius=radius,
            whiten=False)

    # Set up the pyplot figures
    stb_plot = EvaluationPlot('EJ Retinal', k_tot_data, eval_type='Stability')
    acc_plot = EvaluationPlot('EJ Retinal', k_tot_data)

    for data_idx, data_number in enumerate(range(1, k_tot_data + 1)):
        # Setting up config file for yass.
        bin_file = 'ej49_data{}.bin'.format(data_number)
        geom_file = 'ej49_geometry{}.txt'.format(data_number)
        config['data']['recordings'] = bin_file
        config['data']['geometry'] = geom_file
        spike_train = run(config=config)

        # Data augmentation setup: derive the number of time samples from the
        # file size, the sample dtype and the channel count.
        file_size_bytes = os.path.getsize(bin_file)
        tot_samples = file_size_bytes / (np.dtype(dtype).itemsize * n_chan)
        n_batch_samples = int(tot_samples / n_batches)
        batch_reader = make_batch_reader(bin_file, geom_file, n_batch_samples)
        mean_wave = MeanWaveCalculator(batch_reader,
                                       spike_train,
                                       window=window)
        mean_wave.compute_templates(n_batches=n_batches)

        # Augment with new spikes.
        stab = RecordingAugmentation(mean_wave,
                                     augment_rate=0.25,
                                     move_rate=0.2)
        aug_bin_file = 'ej49_data{}.aug.bin'.format(data_number)
        aug_gold_spt, _ = stab.save_augment_recording(aug_bin_file, n_batches)
        np.save('ej49_data{}.aug.npy'.format(data_number), aug_gold_spt)

        # Setting up config file for yass to run on augmented data.
        config['data']['recordings'] = aug_bin_file
        config['data']['geometry'] = geom_file
        yass_aug_spike_train = run(config=config)

        # Evaluate accuracy of yass.
        gold_std_spike_train_file = 'groundtruth_ej49_data{}.mat'.format(
            data_number)
        gold_std_map = scipy.io.loadmat(gold_std_spike_train_file)
        # Join spike times and labels column-wise into a single array.
        gold_std_spike_train = np.append(gold_std_map['spt_gt'],
                                         gold_std_map['L_gt'],
                                         axis=1)
        gold_standard_mean_wave = MeanWaveCalculator(batch_reader,
                                                     gold_std_spike_train,
                                                     window=window)
        gold_standard_mean_wave.compute_templates(n_batches=n_batches)
        accuracy_eval = SpikeSortingEvaluation(
            gold_std_spike_train, spike_train,
            gold_standard_mean_wave.templates, mean_wave.templates)
        acc_plot.add_metric(
            np.log(temp_snr(gold_standard_mean_wave.templates)),
            accuracy_eval.true_positive,
            data_idx)
        batch_reader.close_iterator()

        # Evaluate stability of yass on the augmented recording.
        batch_reader = make_batch_reader(aug_bin_file, geom_file,
                                         n_batch_samples)
        aug_gold_standard_mean_wave = MeanWaveCalculator(batch_reader,
                                                         aug_gold_spt,
                                                         window=window)
        aug_gold_standard_mean_wave.compute_templates(n_batches=n_batches)
        aug_yass_mean_wave = MeanWaveCalculator(batch_reader,
                                                yass_aug_spike_train,
                                                window=window)
        aug_yass_mean_wave.compute_templates(n_batches=n_batches)
        stability_eval = SpikeSortingEvaluation(
            aug_gold_spt, yass_aug_spike_train,
            aug_gold_standard_mean_wave.templates,
            aug_yass_mean_wave.templates)
        stb_plot.add_metric(
            np.log(temp_snr(aug_gold_standard_mean_wave.templates)),
            stability_eval.true_positive,
            data_idx)
        batch_reader.close_iterator()

    # Render the plots and save them.
    acc_plot.generate_snr_metric_plot()
    stb_plot.generate_snr_metric_plot()
示例#10
0
    def run_stability(self, n_batches=6):
        """Runs stability metric computation for the given config file.

        Intermediate artifacts are reused from disk when the corresponding
        files already exist. These are the yass spike trains, the templates,
        the augmented recording and its spike train. If
        ``self.stability_file`` already exists, the method returns at once.

        Side effect: ``self.config['data']`` is rewritten to point at the
        augmented recording (``root_folder`` becomes ``self.tmp_dir``).

        Parameters
        ----------
        n_batches: int
            Break down the processing of the dataset in these many batches.
        """
        # Check whether this analysis is not done already.
        if os.path.isfile(self.stability_file):
            return
        sampling_rate = self.config['recordings']['sampling_rate']
        n_chan = self.config['recordings']['n_channels']
        dtype = self.config['recordings']['dtype']
        spike_length = self.config['recordings']['spike_size_ms']
        # Extracting window around spikes (spike size converted ms -> samples).
        window_radius = int(spike_length * sampling_rate / 1e3)
        window = range(-window_radius, window_radius)

        bin_file = os.path.join(self.root_dir,
                                self.config['data']['recordings'])
        geom_file = os.path.join(self.root_dir,
                                 self.config['data']['geometry'])
        # Check whether spike sorting has already been completed.
        if not os.path.isfile(self.yass_spike_train_file):
            spike_train = run(config=self.config)
        else:
            spike_train = np.load(self.yass_spike_train_file)
        # Data augmentation setup: the number of time samples follows from the
        # file size, the sample dtype and the channel count.
        file_size_bytes = os.path.getsize(bin_file)
        tot_samples = file_size_bytes / (np.dtype(dtype).itemsize * n_chan)
        radius = 70
        n_batch_samples = int(tot_samples / n_batches)

        bin_extension = os.path.splitext(bin_file)[1]
        aug_file_name = 'augmented_recording{}'.format(bin_extension)
        aug_bin_file = os.path.join(self.tmp_dir, aug_file_name)

        # Check whether data augmentation has been done before or not.
        is_file_aug_bin = os.path.isfile(aug_bin_file)
        is_file_aug_spt = os.path.isfile(self.aug_spike_train_file)
        is_file_yass_temp = os.path.isfile(self.yass_templates_file)
        if is_file_aug_bin and is_file_aug_spt and is_file_yass_temp:
            aug_gold_spt = np.load(self.aug_spike_train_file)
        else:
            batch_reader = RecordingBatchIterator(
                bin_file,
                geom_file,
                sample_rate=sampling_rate,
                batch_time_samples=n_batch_samples,
                n_batches=n_batches,
                n_chan=n_chan,
                radius=radius,
                whiten=False)
            try:
                mean_wave = MeanWaveCalculator(batch_reader,
                                               spike_train,
                                               window=window)
                mean_wave.compute_templates(n_batches=n_batches)
                np.save(self.yass_templates_file, mean_wave.templates)
                # Compute gold standard mean waveforms too.
                gold_standard_mean_wave = MeanWaveCalculator(
                    batch_reader, self.gold_std_spike_train, window=window)
                gold_standard_mean_wave.compute_templates(n_batches=n_batches)
                np.save(self.gold_templates_file,
                        gold_standard_mean_wave.templates)
                # Augment with new spikes.
                stab = RecordingAugmentation(mean_wave,
                                             augment_rate=0.25,
                                             move_rate=0.2)
                aug_gold_spt, _ = stab.save_augment_recording(
                    aug_bin_file, n_batches)
                np.save(self.aug_spike_train_file, aug_gold_spt)
                # Geometry is saved next to the augmented recording so the
                # rewritten config below can refer to it as 'geom.npy'.
                np.save(os.path.join(self.tmp_dir, 'geom.npy'),
                        batch_reader.geometry)
            finally:
                # Previously this reader was never closed (resource leak).
                batch_reader.close_iterator()

        # Setting up config file for yass to run on augmented data.
        self.config['data']['root_folder'] = self.tmp_dir
        self.config['data']['recordings'] = aug_file_name
        self.config['data']['geometry'] = 'geom.npy'
        # Check whether spike sorting has already been completed.
        if not os.path.isfile(self.yass_aug_spike_train_file):
            yass_aug_spike_train = run(config=self.config)
        else:
            yass_aug_spike_train = np.load(self.yass_aug_spike_train_file)
        # Evaluate stability of yass.
        # Check whether the mean wave of the yass spike train on the augmented
        # data has been computed before or not.
        is_file_temp_yass = os.path.isfile(self.yass_aug_templates_file)
        is_file_temp_gold = os.path.isfile(self.gold_aug_templates_file)
        if is_file_temp_gold and is_file_temp_yass:
            gold_aug_templates = np.load(self.gold_aug_templates_file)
            yass_aug_templates = np.load(self.yass_aug_templates_file)
        else:
            batch_reader = RecordingBatchIterator(
                aug_bin_file,
                geom_file,
                sample_rate=sampling_rate,
                batch_time_samples=n_batch_samples,
                n_batches=n_batches,
                n_chan=n_chan,
                radius=radius,
                filter_std=False,
                whiten=False)
            try:
                aug_gold_standard_mean_wave = MeanWaveCalculator(
                    batch_reader, aug_gold_spt, window=window)
                aug_gold_standard_mean_wave.compute_templates(
                    n_batches=n_batches)
                gold_aug_templates = aug_gold_standard_mean_wave.templates
                aug_yass_mean_wave = MeanWaveCalculator(batch_reader,
                                                        yass_aug_spike_train,
                                                        window=window)
                aug_yass_mean_wave.compute_templates(n_batches=n_batches)
                yass_aug_templates = aug_yass_mean_wave.templates
            finally:
                batch_reader.close_iterator()
            np.save(self.gold_aug_templates_file, gold_aug_templates)
            np.save(self.yass_aug_templates_file, yass_aug_templates)
        # Finally, instantiate a spike train evaluation object for comparisons.
        stability_eval = SpikeSortingEvaluation(aug_gold_spt,
                                                yass_aug_spike_train,
                                                gold_aug_templates,
                                                yass_aug_templates)
        # Saving results of evaluation for stability.
        # NOTE(review): if these three entries differ in length or type, NumPy
        # >= 1.24 refuses to build this array without ``dtype=object``.
        # Confirm their shapes before upgrading NumPy.
        stability_results = np.array([
            stability_eval.true_positive, stability_eval.false_positive,
            stability_eval.unit_cluster_map
        ])
        np.save(self.stability_file, stability_results)
示例#11
0
def test_example_works_default_pipeline(path_to_config_sample):
    """Smoke test: the default pipeline runs on the sample config file.

    Temporary output is removed even if the pipeline raises.
    """
    try:
        pipeline.run(path_to_config_sample)
    finally:
        # Previously skipped on failure, leaking tmp state into later tests.
        clean_tmp()