Пример #1
0
def test_can_make_training_data(path_to_tests, path_to_sample_pipeline_folder):
    """Smoke test: ``make_training_data`` runs on the sample pipeline folder.

    The test loads ``config_nnet.yaml`` from the tests folder, then calls
    ``make_training_data``. It passes as long as nothing raises; the return
    value is not inspected.
    """
    config_file = path.join(path_to_tests, 'config_nnet.yaml')
    yass.set_config(config_file)
    config = yass.read_config()

    # spike_train, chosen_templates, min_amplitude and n_spikes_to_make are
    # not defined in this function; they are looked up in the enclosing scope
    make_training_data(
        config,
        spike_train,
        chosen_templates,
        min_amplitude,
        n_spikes_to_make,
        data_folder=path_to_sample_pipeline_folder,
    )
Пример #2
0
def test_can_reload_detector(path_to_tests, path_to_sample_pipeline_folder,
                             tmp_folder):
    """Smoke test: a fitted detector can be loaded back from its model path.

    Fits a ``NeuralNetDetector`` (``n_iter=10``) on generated training data
    and then calls ``NeuralNetDetector.load`` on the same path. The test
    passes if no step raises; the loaded object is discarded.
    """
    config_file = path.join(path_to_tests, 'config_nnet.yaml')
    yass.set_config(config_file)
    config = yass.read_config()

    # make_training_data must return exactly six items; only the first two
    # (detector inputs and labels) are used here
    x_detect, y_detect, _, _, _, _ = make_training_data(
        config, spike_train, chosen_templates, min_amplitude, n_spikes,
        path_to_sample_pipeline_folder)

    # the network dimensions come from the last two axes of the detector data
    _, waveform_length, n_neighbors = x_detect.shape
    checkpoint = path.join(tmp_folder, 'detect-net.ckpt')

    trained = NeuralNetDetector(
        checkpoint, filters, waveform_length, n_neighbors,
        threshold=0.5, channel_index=config.channel_index, n_iter=10)
    trained.fit(x_detect, y_detect)

    NeuralNetDetector.load(
        checkpoint, threshold=0.5, channel_index=config.channel_index)
Пример #3
0
def test_can_use_detect_and_triage_after_reload(path_to_tests,
                                                path_to_sample_pipeline_folder,
                                                tmp_folder,
                                                path_to_standarized_data):
    """Smoke test: reloaded detector and triage networks can both predict.

    Steps:

    1. Generate training data with ``make_training_data``.
    2. Fit a ``NeuralNetDetector`` and reload it with
       ``NeuralNetDetector.load``.
    3. Fit a ``NeuralNetTriage`` and reload it with ``NeuralNetTriage.load``.
    4. Run the reloaded detector on the standarized recording and feed the
       resulting waveforms (restricted to the first ``n_neighbors`` entries
       of the last axis) to the reloaded triage network.

    The test passes if no step raises; predictions are not checked.
    """
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    # the autoencoder outputs (last two items) are not needed in this test
    (x_detect, y_detect, x_triage, y_triage, _,
     _) = make_training_data(CONFIG, spike_train, chosen_templates,
                             min_amplitude, n_spikes,
                             path_to_sample_pipeline_folder)

    _, waveform_length, n_neighbors = x_detect.shape

    # Each network gets its own checkpoint path. Sharing a single path made
    # the triage fit overwrite the detector model that is used later by
    # detector.predict
    path_to_detector = path.join(tmp_folder, 'detect-net.ckpt')
    path_to_triage = path.join(tmp_folder, 'triage-net.ckpt')

    detector = NeuralNetDetector(path_to_detector,
                                 filters,
                                 waveform_length,
                                 n_neighbors,
                                 threshold=0.5,
                                 channel_index=CONFIG.channel_index,
                                 n_iter=10)

    detector.fit(x_detect, y_detect)

    detector = NeuralNetDetector.load(path_to_detector,
                                      threshold=0.5,
                                      channel_index=CONFIG.channel_index)

    # NOTE(review): the triage network is built with the detector data
    # dimensions, which assumes x_triage has the same trailing shape as
    # x_detect -- confirm against make_training_data
    triage = NeuralNetTriage(path_to_triage,
                             filters,
                             waveform_length,
                             n_neighbors,
                             threshold=0.5,
                             n_iter=10)

    # fit on the triage data set (previously the detector data was used and
    # x_triage / y_triage were left unused)
    triage.fit(x_triage, y_triage)

    triage = NeuralNetTriage.load(path_to_triage, threshold=0.5)

    data = RecordingExplorer(path_to_standarized_data).reader.data

    output_names = ('spike_index', 'waveform', 'probability')

    # only the waveforms are forwarded to the triage network
    _, waveform, _ = detector.predict(data, output_names=output_names)

    triage.predict(waveform[:, :, :n_neighbors])
Пример #4
0
def train_neural_networks(CONFIG, CONFIG_TRAIN, spike_train, data_folder):
    """Train all neural networks

    Generates training data with ``make_training_data`` and then trains the
    detector, triage and autoencoder networks, in that order. After each
    network is trained, its parameters are saved to the path obtained by
    changing the extension of the network's ``.ckpt`` name to ``yaml``.

    Parameters
    ----------
    CONFIG
        Configuration object, forwarded unchanged to ``make_training_data``

    CONFIG_TRAIN
        Nested mapping with the training settings. The following entries are
        read:

        * ``templates``: ``ids``, ``minimum_amplitude``
        * ``training``: ``n_spikes``, ``n_iterations``, ``n_batch``,
          ``l2_regularization_scale``, ``step_size``
        * ``network_detector``: ``n_filters``, ``name``
        * ``network_triage``: ``n_filters``, ``name``
        * ``network_autoencoder``: ``n_features``, ``name``

        Each ``name`` gets ``'.ckpt'`` appended to build the model file name

    spike_train
        Spike train, forwarded unchanged to ``make_training_data``

    data_folder
        Forwarded to ``make_training_data`` as ``data_folder``

    Returns
    -------
    None

    Raises
    ------
    KeyError
        If any of the ``CONFIG_TRAIN`` entries listed above is missing
    """
    logger = logging.getLogger(__name__)

    templates_cfg = CONFIG_TRAIN['templates']
    training_cfg = CONFIG_TRAIN['training']
    detector_cfg = CONFIG_TRAIN['network_detector']
    triage_cfg = CONFIG_TRAIN['network_triage']
    ae_cfg = CONFIG_TRAIN['network_autoencoder']

    chosen_templates = templates_cfg['ids']
    min_amp = templates_cfg['minimum_amplitude']

    nspikes = training_cfg['n_spikes']
    n_iter = training_cfg['n_iterations']
    n_batch = training_cfg['n_batch']
    l2_reg_scale = training_cfg['l2_regularization_scale']
    train_step_size = training_cfg['step_size']

    n_filters_detect = detector_cfg['n_filters']
    detectnet_name = detector_cfg['name'] + '.ckpt'

    n_filters_triage = triage_cfg['n_filters']
    triagenet_name = triage_cfg['name'] + '.ckpt'

    n_features = ae_cfg['n_features']
    ae_name = ae_cfg['name'] + '.ckpt'

    # generate training data
    logger.info('Generating training data...')
    (x_detect, y_detect, x_triage, y_triage, x_ae,
     y_ae) = make_training_data(CONFIG,
                                spike_train,
                                chosen_templates,
                                min_amp,
                                nspikes,
                                data_folder=data_folder)

    # train detector
    logger.info('Training detector network...')
    train_detector(x_detect, y_detect, n_filters_detect, n_iter, n_batch,
                   l2_reg_scale, train_step_size, detectnet_name)

    # save detector model parameters
    logger.info('Saving detector network parameters...')
    save_detect_network_params(filters=n_filters_detect,
                               size=x_detect.shape[1],
                               n_neighbors=x_detect.shape[2],
                               output_path=change_extension(
                                   detectnet_name, 'yaml'))

    # train triage
    logger.info('Training triage network...')
    train_triage(x_triage, y_triage, n_filters_triage, n_iter, n_batch,
                 l2_reg_scale, train_step_size, triagenet_name)

    # save triage model parameters; the dimensions are taken from the data
    # the triage network was actually trained on (x_triage), not from the
    # detector data
    logger.info('Saving triage network parameters...')
    save_triage_network_params(filters=n_filters_triage,
                               size=x_triage.shape[1],
                               n_neighbors=x_triage.shape[2],
                               output_path=change_extension(
                                   triagenet_name, 'yaml'))

    # train autoencoder (no l2 regularization scale is passed here)
    logger.info('Training autoencoder network...')
    train_ae(x_ae, y_ae, n_features, n_iter, n_batch, train_step_size, ae_name)

    # save autoencoder model parameters
    logger.info('Saving autoencoder network parameters...')
    save_ae_network_params(n_input=x_ae.shape[1],
                           n_features=n_features,
                           output_path=change_extension(ae_name, 'yaml'))