def test_can_reload_detector(path_to_tests, path_to_sample_pipeline_folder,
                             tmp_folder):
    """Fit a small neural-net detector, checkpoint it, then reload it.

    Passes as long as fitting and ``NeuralNetDetector.load`` do not raise;
    the reloaded detector is not inspected.

    Reads ``spike_train``, ``chosen_templates``, ``min_amplitude``,
    ``n_spikes`` and ``filters`` from an enclosing scope -- presumably
    module-level fixtures defined elsewhere in this file; verify.

    NOTE(review): if this module also contains the later
    ``test_can_reload_detector(path_to_config, ...)``, that ``def`` rebinds
    the name and pytest never collects this version (ruff F811). Consider
    renaming or deleting one of them.
    """
    config_file = path.join(path_to_tests, 'config_nnet.yaml')
    yass.set_config(config_file)
    CONFIG = yass.read_config()

    # Only the detection split is used here; the 6-way unpack still checks
    # that make_training_data returns exactly six arrays.
    x_detect, y_detect, _, _, _, _ = make_training_data(
        CONFIG, spike_train, chosen_templates, min_amplitude, n_spikes,
        path_to_sample_pipeline_folder)
    _, waveform_length, n_neighbors = x_detect.shape

    checkpoint = path.join(tmp_folder, 'detect-net.ckpt')
    shared_kwargs = dict(threshold=0.5, channel_index=CONFIG.channel_index)

    net = NeuralNetDetector(checkpoint, filters, waveform_length, n_neighbors,
                            n_iter=10, **shared_kwargs)
    net.fit(x_detect, y_detect)

    # Reload from the checkpoint written by fit(); the result is discarded.
    NeuralNetDetector.load(checkpoint, **shared_kwargs)
def test_can_use_detect_and_triage_after_reload(path_to_tests,
                                                 path_to_sample_pipeline_folder,
                                                 tmp_folder,
                                                 path_to_standarized_data):
    """Train detector and triage nets, reload both, and run them on a recording.

    The detector is trained on the detection split and the triage net on the
    triage split returned by ``make_training_data``. Each network gets its
    own checkpoint file, so reloading one cannot pick up the other's weights.
    The test passes as long as no call raises; outputs are not asserted.

    Reads ``spike_train``, ``chosen_templates``, ``min_amplitude``,
    ``n_spikes`` and ``filters`` from an enclosing scope -- presumably
    module-level values defined elsewhere in this file; verify.

    NOTE(review): the fixture is requested as ``path_to_standarized_data``,
    while another test in this file uses ``path_to_standardized_data``.
    pytest injects by name, so confirm which spelling conftest defines.
    NOTE(review): this test calls ``detector.predict(data, output_names=...)``,
    whereas another test uses ``predict_recording`` for a full recording;
    confirm which API the current NeuralNetDetector exposes.
    """
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    (x_detect, y_detect,
     x_triage, y_triage,
     _, _) = make_training_data(CONFIG, spike_train, chosen_templates,
                                min_amplitude, n_spikes,
                                path_to_sample_pipeline_folder)

    # --- detector: train, checkpoint, reload --------------------------------
    _, waveform_length, n_neighbors = x_detect.shape
    path_to_detector = path.join(tmp_folder, 'detect-net.ckpt')
    detector = NeuralNetDetector(path_to_detector, filters, waveform_length,
                                 n_neighbors, threshold=0.5,
                                 channel_index=CONFIG.channel_index,
                                 n_iter=10)
    detector.fit(x_detect, y_detect)
    detector = NeuralNetDetector.load(path_to_detector, threshold=0.5,
                                      channel_index=CONFIG.channel_index)

    # --- triage: own training data and own checkpoint -----------------------
    # Triage previously reused x_detect/y_detect and the detector's checkpoint
    # path; a distinct file keeps triage training from overwriting the
    # detector weights that detector.predict relies on below.
    _, triage_waveform_length, triage_n_neighbors = x_triage.shape
    path_to_triage = path.join(tmp_folder, 'triage-net.ckpt')
    triage = NeuralNetTriage(path_to_triage, filters, triage_waveform_length,
                             triage_n_neighbors, threshold=0.5, n_iter=10)
    triage.fit(x_triage, y_triage)
    triage = NeuralNetTriage.load(path_to_triage, threshold=0.5)

    # --- end-to-end: detect on the recording, triage the detected waveforms -
    data = RecordingExplorer(path_to_standarized_data).reader.data
    output_names = ('spike_index', 'waveform', 'probability')
    _spike_index, waveform, _proba = detector.predict(
        data, output_names=output_names)
    # Keep only as many channels as the triage net was built for.
    triage.predict(waveform[:, :, :triage_n_neighbors])
def test_can_use_detector_after_fit(path_to_config,
                                    path_to_sample_pipeline_folder,
                                    make_tmp_folder,
                                    path_to_standardized_data):
    """Fit a small detector on augmented data, then run both predict paths.

    Smoke test: calls ``predict_recording`` on a standardized recording and
    ``predict`` on the training waveforms; nothing is asserted, so the test
    passes as long as no call raises.
    """
    n_spikes_to_make = 100
    min_amplitude, max_amplitude = 4, 60

    yass.set_config(path_to_config, make_tmp_folder)
    CONFIG = yass.read_config()

    pipeline_dir = path_to_sample_pipeline_folder
    spike_train = np.load(path.join(pipeline_dir, 'spike_train.npy'))
    # Column 1 of the spike train holds template ids; use every distinct one.
    chosen_templates = np.unique(spike_train[:, 1])
    templates = make.load_templates(pipeline_dir, spike_train, CONFIG,
                                    chosen_templates)

    standardized_bin = path.join(pipeline_dir, 'preprocess',
                                 'standarized.bin')
    # Only the detection split is needed; the 6-way unpack keeps the arity check.
    x_detect, y_detect, _, _, _, _ = make.training_data(
        CONFIG, templates, min_amplitude, max_amplitude, n_spikes_to_make,
        standardized_bin)
    _, waveform_length, n_neighbors = x_detect.shape

    detector = NeuralNetDetector(path.join(make_tmp_folder, 'detect-net.ckpt'),
                                 [8, 4], waveform_length, n_neighbors,
                                 threshold=0.5,
                                 channel_index=CONFIG.channel_index,
                                 n_iter=10)
    detector.fit(x_detect, y_detect)

    recording = RecordingExplorer(path_to_standardized_data).reader.data
    wanted = ('spike_index', 'waveform', 'probability')
    spike_index, waveform, proba = detector.predict_recording(
        recording, output_names=wanted)
    detector.predict(x_detect)
def test_can_reload_detector(path_to_config, path_to_sample_pipeline_folder,
                             make_tmp_folder):
    """Fit a detector on augmented data, checkpoint it, and reload it.

    Passes as long as fitting and ``NeuralNetDetector.load`` do not raise;
    the reloaded detector is discarded without inspection.
    """
    n_spikes_to_make = 100
    min_amplitude, max_amplitude = 4, 60

    yass.set_config(path_to_config, make_tmp_folder)
    CONFIG = yass.read_config()

    pipeline_dir = path_to_sample_pipeline_folder
    spike_train = np.load(path.join(pipeline_dir, 'spike_train.npy'))
    # Column 1 of the spike train holds template ids; use every distinct one.
    chosen_templates = np.unique(spike_train[:, 1])
    templates = make.load_templates(pipeline_dir, spike_train, CONFIG,
                                    chosen_templates)

    # The on-disk file name really is spelled 'standarized.bin'.
    standardized_bin = path.join(pipeline_dir, 'preprocess',
                                 'standarized.bin')
    x_detect, y_detect, _, _, _, _ = make.training_data(
        CONFIG, templates, min_amplitude, max_amplitude, n_spikes_to_make,
        standardized_bin)
    _, waveform_length, n_neighbors = x_detect.shape

    checkpoint = path.join(make_tmp_folder, 'detect-net.ckpt')
    detector = NeuralNetDetector(checkpoint, [8, 4], waveform_length,
                                 n_neighbors, threshold=0.5,
                                 channel_index=CONFIG.channel_index,
                                 n_iter=10)
    detector.fit(x_detect, y_detect)

    # Reload from the checkpoint path handed to the constructor above.
    NeuralNetDetector.load(checkpoint, threshold=0.5,
                           channel_index=CONFIG.channel_index)