def test_example_works_pip_and_dict(path_to_config_sample):
    """Run the pipeline from a config dict loaded from the sample YAML file.

    The test passes if ``pipeline.run`` accepts the already-parsed dict and
    completes without raising.
    """
    with open(path_to_config_sample) as f:
        # safe_load: yaml.load without an explicit Loader is deprecated since
        # PyYAML 5.1 and raises TypeError in PyYAML >= 6.0.
        cfg = yaml.safe_load(f)
    try:
        pipeline.run(cfg)
    finally:
        # Run the cleanup even when the pipeline raises.
        clean_tmp()
def sort(config, logger_level, clean, output_dir, complete, zero_seed,
         global_gpu_memory, calculate_rf, visualize):
    """ Sort recordings using a configuration file located in CONFIG """
    # The docstring above is kept verbatim: if this function is registered as
    # a CLI command (decorators not visible here), it is the --help text.
    #
    # NOTE(review): zero_seed and global_gpu_memory are accepted but neither
    # forwarded to pipeline.run nor used here — confirm this is intended.
    run_options = dict(
        logger_level=logger_level,
        clean=clean,
        output_dir=output_dir,
        complete=complete,
        calculate_rf=calculate_rf,
        visualize=visualize,
    )
    return pipeline.run(config, **run_options)
def sort(config, logger_level, clean, output_dir, complete, zero_seed):
    """ Sort recordings using a configuration file located in CONFIG """
    # The docstring above is kept verbatim: if this function is registered as
    # a CLI command (decorators not visible here), it is the --help text.
    #
    # The CLI's zero_seed flag is forwarded under pipeline.run's keyword
    # name, set_zero_seed.
    run_options = dict(
        logger_level=logger_level,
        clean=clean,
        output_dir=output_dir,
        complete=complete,
        set_zero_seed=zero_seed,
    )
    return pipeline.run(config, **run_options)
def test_threshold_pipeline_returns_expected_results(path_to_threshold_config,
                                                     path_to_data_folder):
    """The threshold pipeline's spike train must match the stored reference.

    The reference array lives at
    ``<path_to_data_folder>/output_reference/threshold_spike_train.npy``.
    """
    path_to_reference = path.join(path_to_data_folder, 'output_reference',
                                  'threshold_spike_train.npy')
    try:
        spike_train = pipeline.run(path_to_threshold_config, clean=True)
        ReferenceTesting.assert_array_equal(spike_train, path_to_reference)
    finally:
        # Run the cleanup even when the pipeline raises or the comparison
        # fails; previously it was skipped on any failure.
        clean_tmp()
def sort(config, logger_level, clean, output_dir, complete, zero_seed,
         global_gpu_memory):
    """ Sort recordings using a configuration file located in CONFIG """
    # The docstring above is kept verbatim: if this function is registered as
    # a CLI command (decorators not visible here), it is the --help text.
    #
    # NOTE(review): zero_seed is accepted but not forwarded in this version —
    # confirm this is intended.

    # Install a custom TensorFlow session config only when a GPU memory
    # fraction other than 1.0 was requested.
    if global_gpu_memory != 1.0:
        session_config = tf.ConfigProto()
        gpu_options = session_config.gpu_options
        gpu_options.per_process_gpu_memory_fraction = global_gpu_memory
        yass.set_tensorflow_config(config=session_config)

    run_options = dict(
        logger_level=logger_level,
        clean=clean,
        output_dir=output_dir,
        complete=complete,
    )
    return pipeline.run(config, **run_options)
def test_works_with_nnet_config(patch_triage_network, path_to_config,
                                make_tmp_folder):
    """Smoke test: the pipeline runs without raising on the nnet config."""
    # patch_triage_network is not referenced in the body; presumably it is a
    # fixture requested only for its side effect — confirm in conftest.
    destination = make_tmp_folder
    pipeline.run(path_to_config, output_dir=destination)
def test_works_with_sample_config_passing_dict(path_to_config,
                                               make_tmp_folder):
    """Run the pipeline from a parsed config dict rather than a file path.

    The test passes if ``pipeline.run`` completes without raising, writing
    its output into the temporary folder fixture.
    """
    with open(path_to_config) as f:
        # safe_load: yaml.load without an explicit Loader is deprecated since
        # PyYAML 5.1 and raises TypeError in PyYAML >= 6.0.
        cfg = yaml.safe_load(f)
    pipeline.run(cfg, output_dir=make_tmp_folder)
def test_works_with_threshold_config(path_to_config_threshold,
                                     make_tmp_folder):
    """Smoke test: the pipeline runs without raising on the threshold config."""
    destination = make_tmp_folder
    pipeline.run(path_to_config_threshold, output_dir=destination)
def main(n_batches=6):
    """Runs the procedure for evaluating yass on retinal data.

    For each of the ``k_tot_data`` EJ49 recordings this:

    1. runs yass on the raw recording ``ej49_data<i>.bin``;
    2. computes yass' mean waveforms and writes an augmented recording
       ``ej49_data<i>.aug.bin``, saving its known spike train to
       ``ej49_data<i>.aug.npy``;
    3. runs yass again on the augmented recording;
    4. scores accuracy against ``groundtruth_ej49_data<i>.mat`` and
       stability against the augmentation's known spike train.

    Finally it renders the accuracy and stability SNR plots.

    Parameters
    ----------
    n_batches: int
        Number of batches each recording is split into while reading it.
    """
    # safe_load: yaml.load without an explicit Loader is deprecated since
    # PyYAML 5.1 and raises TypeError in PyYAML >= 6.0. The context manager
    # also closes the file if parsing fails.
    with open('config_template.yaml', 'r') as config_file:
        config = yaml.safe_load(config_file)

    # Extracting window around spikes.
    sampling_rate = config['recordings']['sampling_rate']
    n_chan = config['recordings']['n_channels']
    dtype = config['recordings']['dtype']
    spike_length = config['recordings']['spike_size_ms']
    # Convert the spike size from milliseconds to a number of samples.
    window_radius = int(spike_length * sampling_rate / 1e3)
    window = range(-window_radius, window_radius)

    k_tot_data = 4
    # Set up the pyplot figures
    stb_plot = EvaluationPlot('EJ Retinal', k_tot_data, eval_type='Stability')
    acc_plot = EvaluationPlot('EJ Retinal', k_tot_data)

    for data_idx, data_number in enumerate(range(1, k_tot_data + 1)):
        # Setting up config file for yass.
        bin_file = 'ej49_data{}.bin'.format(data_number)
        geom_file = 'ej49_geometry{}.txt'.format(data_number)
        config['data']['recordings'] = bin_file
        config['data']['geometry'] = geom_file
        spike_train = run(config=config)

        # Data augmentation setup.
        file_size_bytes = os.path.getsize(bin_file)
        # Number of time samples in the file (bytes / bytes-per-sample-row).
        tot_samples = file_size_bytes / (np.dtype(dtype).itemsize * n_chan)
        radius = 70
        n_batch_samples = int(tot_samples / n_batches)
        batch_reader = RecordingBatchIterator(
            bin_file, geom_file, sample_rate=sampling_rate,
            batch_time_samples=n_batch_samples, n_batches=n_batches,
            n_chan=n_chan, radius=radius, whiten=False)
        mean_wave = MeanWaveCalculator(batch_reader, spike_train,
                                       window=window)
        mean_wave.compute_templates(n_batches=n_batches)

        # Augment with new spikes.
        stab = RecordingAugmentation(mean_wave, augment_rate=0.25,
                                     move_rate=0.2)
        aug_bin_file = 'ej49_data{}.aug.bin'.format(data_number)
        aug_gold_spt, _status = stab.save_augment_recording(
            aug_bin_file, n_batches)
        np.save('ej49_data{}.aug.npy'.format(data_number), aug_gold_spt)

        # Setting up config file for yass to run on augmented data.
        config['data']['recordings'] = aug_bin_file
        config['data']['geometry'] = geom_file
        yass_aug_spike_train = run(config=config)

        # Evaluate accuracy of yass against the ground-truth spike train,
        # built by concatenating the 'spt_gt' and 'L_gt' arrays column-wise.
        gold_std_spike_train_file = 'groundtruth_ej49_data{}.mat'.format(
            data_number)
        gold_std_map = scipy.io.loadmat(gold_std_spike_train_file)
        gold_std_spike_train = np.append(gold_std_map['spt_gt'],
                                         gold_std_map['L_gt'], axis=1)
        gold_standard_mean_wave = MeanWaveCalculator(batch_reader,
                                                     gold_std_spike_train,
                                                     window=window)
        gold_standard_mean_wave.compute_templates(n_batches=n_batches)
        accuracy_eval = SpikeSortingEvaluation(
            gold_std_spike_train, spike_train,
            gold_standard_mean_wave.templates, mean_wave.templates)
        acc_tp = accuracy_eval.true_positive
        acc_plot.add_metric(
            np.log(temp_snr(gold_standard_mean_wave.templates)),
            acc_tp, data_idx)
        batch_reader.close_iterator()

        # Evaluate stability of yass on the augmented recording.
        batch_reader = RecordingBatchIterator(
            aug_bin_file, geom_file, sample_rate=sampling_rate,
            batch_time_samples=n_batch_samples, n_batches=n_batches,
            n_chan=n_chan, radius=radius, whiten=False)
        aug_gold_standard_mean_wave = MeanWaveCalculator(batch_reader,
                                                         aug_gold_spt,
                                                         window=window)
        aug_gold_standard_mean_wave.compute_templates(n_batches=n_batches)
        aug_yass_mean_wave = MeanWaveCalculator(batch_reader,
                                                yass_aug_spike_train,
                                                window=window)
        aug_yass_mean_wave.compute_templates(n_batches=n_batches)
        stability_eval = SpikeSortingEvaluation(
            aug_gold_spt, yass_aug_spike_train,
            aug_gold_standard_mean_wave.templates,
            aug_yass_mean_wave.templates)
        stb_tp = stability_eval.true_positive
        stb_plot.add_metric(
            np.log(temp_snr(aug_gold_standard_mean_wave.templates)),
            stb_tp, data_idx)
        batch_reader.close_iterator()

    # Render the plots and save them.
    acc_plot.generate_snr_metric_plot()
    stb_plot.generate_snr_metric_plot()
def run_stability(self, n_batches=6):
    """Runs stability metric computation for the given config file.

    Each intermediate step is cached on disk and skipped if its output
    already exists. The steps are: yass spike train, augmented recording
    with its known spike train, templates, yass spike train on the augmented
    recording, and augmented templates. The final result is saved to
    ``self.stability_file`` as an array
    ``[true_positive, false_positive, unit_cluster_map]``. If that file
    already exists, the method returns immediately.

    Side effect: ``self.config['data']`` is repointed (``root_folder``,
    ``recordings``, ``geometry``) at the augmented recording in
    ``self.tmp_dir``.

    Parameters
    ----------
    n_batches: int
        Break down the processing of the dataset in these many batches.
    """
    # Check whether this analysis is not done already.
    if os.path.isfile(self.stability_file):
        return
    sampling_rate = self.config['recordings']['sampling_rate']
    n_chan = self.config['recordings']['n_channels']
    dtype = self.config['recordings']['dtype']
    spike_length = self.config['recordings']['spike_size_ms']
    # Extracting window around spikes (spike size converted from ms to
    # samples).
    window_radius = int(spike_length * sampling_rate / 1e3)
    window = range(-window_radius, window_radius)
    bin_file = os.path.join(self.root_dir,
                            self.config['data']['recordings'])
    geom_file = os.path.join(self.root_dir, self.config['data']['geometry'])
    # Check whether spike sorting has already been completed.
    if not os.path.isfile(self.yass_spike_train_file):
        spike_train = run(config=self.config)
    else:
        spike_train = np.load(self.yass_spike_train_file)
    # Data augmentation setup.
    file_size_bytes = os.path.getsize(bin_file)
    # Number of time samples in the file (bytes / bytes-per-sample-row).
    tot_samples = file_size_bytes / (np.dtype(dtype).itemsize * n_chan)
    radius = 70
    n_batch_samples = int(tot_samples / n_batches)
    # The augmented recording keeps the original file's extension.
    bin_extension = os.path.splitext(bin_file)[1]
    aug_file_name = 'augmented_recording{}'.format(bin_extension)
    aug_bin_file = os.path.join(self.tmp_dir, aug_file_name)
    # Check whether data augmentation has been done before or not.
    is_file_aug_bin = os.path.isfile(aug_bin_file)
    is_file_aug_spt = os.path.isfile(self.aug_spike_train_file)
    is_file_yass_temp = os.path.isfile(self.yass_templates_file)
    if is_file_aug_bin and is_file_aug_spt and is_file_yass_temp:
        aug_gold_spt = np.load(self.aug_spike_train_file)
    else:
        batch_reader = RecordingBatchIterator(
            bin_file, geom_file, sample_rate=sampling_rate,
            batch_time_samples=n_batch_samples, n_batches=n_batches,
            n_chan=n_chan, radius=radius, whiten=False)
        mean_wave = MeanWaveCalculator(batch_reader, spike_train,
                                       window=window)
        mean_wave.compute_templates(n_batches=n_batches)
        np.save(self.yass_templates_file, mean_wave.templates)
        # Compute gold standard mean waveforms too.
        gold_standard_mean_wave = MeanWaveCalculator(
            batch_reader, self.gold_std_spike_train, window=window)
        gold_standard_mean_wave.compute_templates(n_batches=n_batches)
        np.save(self.gold_templates_file, gold_standard_mean_wave.templates)
        # Augment with new spikes.
        stab = RecordingAugmentation(mean_wave, augment_rate=0.25,
                                     move_rate=0.2)
        aug_gold_spt, _status = stab.save_augment_recording(
            aug_bin_file, n_batches)
        np.save(self.aug_spike_train_file, aug_gold_spt)
        np.save(os.path.join(self.tmp_dir, 'geom.npy'),
                batch_reader.geometry)
        # Release the reader after its last use (augmentation + geometry
        # save). Previously it was never closed on this path.
        batch_reader.close_iterator()
    # Setting up config file for yass to run on augmented data.
    self.config['data']['root_folder'] = self.tmp_dir
    self.config['data']['recordings'] = aug_file_name
    self.config['data']['geometry'] = 'geom.npy'
    # Check whether spike sorting has already been completed.
    if not os.path.isfile(self.yass_aug_spike_train_file):
        yass_aug_spike_train = run(config=self.config)
    else:
        yass_aug_spike_train = np.load(self.yass_aug_spike_train_file)
    # Evaluate stability of yass.
    # Check whether the mean wave of the yass spike train on the augmented
    # data has been computed before or not.
    is_file_temp_yass = os.path.isfile(self.yass_aug_templates_file)
    is_file_temp_gold = os.path.isfile(self.gold_aug_templates_file)
    if is_file_temp_gold and is_file_temp_yass:
        gold_aug_templates = np.load(self.gold_aug_templates_file)
        yass_aug_templates = np.load(self.yass_aug_templates_file)
    else:
        batch_reader = RecordingBatchIterator(
            aug_bin_file, geom_file, sample_rate=sampling_rate,
            batch_time_samples=n_batch_samples, n_batches=n_batches,
            n_chan=n_chan, radius=radius, filter_std=False, whiten=False)
        aug_gold_standard_mean_wave = MeanWaveCalculator(batch_reader,
                                                         aug_gold_spt,
                                                         window=window)
        aug_gold_standard_mean_wave.compute_templates(n_batches=n_batches)
        gold_aug_templates = aug_gold_standard_mean_wave.templates
        aug_yass_mean_wave = MeanWaveCalculator(batch_reader,
                                                yass_aug_spike_train,
                                                window=window)
        aug_yass_mean_wave.compute_templates(n_batches=n_batches)
        yass_aug_templates = aug_yass_mean_wave.templates
        batch_reader.close_iterator()
        np.save(self.gold_aug_templates_file, gold_aug_templates)
        np.save(self.yass_aug_templates_file, yass_aug_templates)
    # Finally, instantiate a spike train evaluation object for comparisons.
    stability_eval = SpikeSortingEvaluation(aug_gold_spt,
                                            yass_aug_spike_train,
                                            gold_aug_templates,
                                            yass_aug_templates)
    # Saving results of evaluation for stability.
    stability_results = np.array([
        stability_eval.true_positive, stability_eval.false_positive,
        stability_eval.unit_cluster_map
    ])
    np.save(self.stability_file, stability_results)
def test_example_works_default_pipeline(path_to_config_sample):
    """Smoke test: the default pipeline runs on the sample config file."""
    try:
        pipeline.run(path_to_config_sample)
    finally:
        # Run the cleanup even when the pipeline raises; previously it was
        # skipped on failure.
        clean_tmp()