Пример #1
0
def test_can_use_neural_network_detector(path_to_tests,
                                         path_to_standarized_data):
    """Smoke test: detector, triage and autoencoder can be run together.

    The three networks referenced in ``config_nnet.yaml`` are loaded, wired
    to the detector's waveform tensor, restored inside a single session and
    run over the standarized recordings. Nothing is asserted on the output;
    the test passes as long as no step raises.
    """
    config_file = path.join(path_to_tests, 'config_nnet.yaml')
    yass.set_config(config_file)
    CONFIG = yass.read_config()

    reader = RecordingsReader(path_to_standarized_data, loader='array')
    recordings = reader.data

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    # per-network settings taken from the detect section of the config
    detector_cfg = CONFIG.detect.neural_network_detector
    triage_cfg = CONFIG.detect.neural_network_triage
    autoencoder_cfg = CONFIG.detect.neural_network_autoencoder

    # triage and autoencoder consume the waveform tensor of the detector
    detector_net = NeuralNetDetector.load(detector_cfg.filename,
                                          detector_cfg.threshold_spike,
                                          channel_index)
    triage_net = NeuralNetTriage.load(
        triage_cfg.filename,
        triage_cfg.threshold_collision,
        input_tensor=detector_net.waveform_tf)
    autoencoder = AutoEncoder(autoencoder_cfg.filename,
                              input_tensor=detector_net.waveform_tf)

    # tensors evaluated on every run
    tensors_out = (autoencoder.score_tf,
                   detector_net.spike_index_tf,
                   triage_net.idx_clean)

    with tf.Session() as session:
        for network in (detector_net, autoencoder, triage_net):
            network.restore(session)

        rotation = autoencoder.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        neuralnetwork.run_detect_triage_featurize(recordings, session,
                                                  detector_net.x_tf,
                                                  tensors_out, neighbors,
                                                  rotation)
Пример #2
0
def test_can_use_detect_and_triage_after_reload(path_to_tests,
                                                path_to_sample_pipeline_folder,
                                                tmp_folder,
                                                path_to_standarized_data):
    """Train, reload and chain a detector and a triage network.

    Both networks are fitted for a few iterations on the data returned by
    ``make_training_data``, reloaded from disk and then used together: the
    waveforms predicted by the detector are passed to the triage network.
    Nothing is asserted on the predictions; the test passes as long as no
    step raises.
    """
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    # spike_train, chosen_templates, min_amplitude, n_spikes and filters are
    # not defined in this function; they are expected to exist at module
    # level
    (x_detect, y_detect, x_triage, y_triage, x_ae,
     y_ae) = make_training_data(CONFIG, spike_train, chosen_templates,
                                min_amplitude, n_spikes,
                                path_to_sample_pipeline_folder)

    _, waveform_length, n_neighbors = x_detect.shape

    # Each network gets its own checkpoint. Both used to share
    # 'detect-net.ckpt', so the triage model was stored on top of the
    # detector that is used for prediction further down (presumably fit()
    # writes to the path given to the constructor, since load() reads from
    # that same path right after).
    path_to_detector = path.join(tmp_folder, 'detect-net.ckpt')
    path_to_triage = path.join(tmp_folder, 'triage-net.ckpt')

    detector = NeuralNetDetector(path_to_detector,
                                 filters,
                                 waveform_length,
                                 n_neighbors,
                                 threshold=0.5,
                                 channel_index=CONFIG.channel_index,
                                 n_iter=10)

    detector.fit(x_detect, y_detect)

    # replace the trained object with one rebuilt from disk
    detector = NeuralNetDetector.load(path_to_detector,
                                      threshold=0.5,
                                      channel_index=CONFIG.channel_index)

    triage = NeuralNetTriage(path_to_triage,
                             filters,
                             waveform_length,
                             n_neighbors,
                             threshold=0.5,
                             n_iter=10)

    # NOTE(review): the triage network is fitted on the detector data set,
    # which matches the dimensions it was built with above - confirm whether
    # x_triage/y_triage should be used instead
    triage.fit(x_detect, y_detect)

    triage = NeuralNetTriage.load(path_to_triage, threshold=0.5)

    data = RecordingExplorer(path_to_standarized_data).reader.data

    output_names = ('spike_index', 'waveform', 'probability')

    (spike_index, waveform,
     proba) = detector.predict(data, output_names=output_names)

    # keep only the first n_neighbors entries of the last axis, the size the
    # triage network was built with
    triage.predict(waveform[:, :, :n_neighbors])
Пример #3
0
def test_can_reload_triage(path_to_tests, path_to_sample_pipeline_folder,
                           tmp_folder):
    """Fit a triage network for a few iterations and load it back.

    Nothing is asserted on the reloaded model; the test passes as long as
    training and ``NeuralNetTriage.load`` do not raise.
    """
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    # spike_train, chosen_templates, min_amplitude, n_spikes and filters are
    # not defined in this function; they are expected to exist at module
    # level
    (x_detect, y_detect, x_triage, y_triage, x_ae,
     y_ae) = make_training_data(CONFIG, spike_train, chosen_templates,
                                min_amplitude, n_spikes,
                                path_to_sample_pipeline_folder)

    _, waveform_length, n_neighbors = x_triage.shape

    path_to_model = path.join(tmp_folder, 'triage-net.ckpt')

    triage = NeuralNetTriage(path_to_model,
                             filters,
                             waveform_length,
                             n_neighbors,
                             threshold=0.5,
                             n_iter=10)

    # fit on the triage data set: the network dimensions above are taken
    # from x_triage, so training on x_detect/y_detect (as this test used to
    # do) only works if both data sets happen to share the same shape
    triage.fit(x_triage, y_triage)

    NeuralNetTriage.load(path_to_model, threshold=0.5)
Пример #4
0
def run_neural_network(standarized_path, standarized_params, whiten_filter,
                       output_directory, if_file_exists, save_results):
    """Run neural network detection and autoencoder dimensionality reduction

    Parameters
    ----------
    standarized_path
      Path to the standarized recordings, handed to BatchProcessor

    standarized_params: dict
      Must provide the 'dtype', 'n_channels' and 'data_order' keys, which
      are handed to BatchProcessor

    whiten_filter
      Not used inside this function

    output_directory
      Folder inside CONFIG.data.root_folder; results are stored in its
      'detect' subfolder, which is created if needed

    if_file_exists: str
      One of 'overwrite', 'abort' or 'skip'. 'overwrite' always runs
      detection. 'abort' runs detection only if none of the scores, clear
      spike index and all spike index files exist, and raises otherwise.
      'skip' loads the three files if all of them exist and runs detection
      otherwise

    save_results: bool
      Whether to save scores, spike indexes and rotation matrix as .npy
      files in the 'detect' subfolder

    Returns
    -------
    scores
      Scores for clear spikes (transformed with get_locations_features
      when CONFIG.cluster.method is 'location')

    spike_index_clear
      Spike indexes for clear spikes

    spike_index_all
      Spike indexes for all spikes

    Raises
    ------
    ValueError
      If if_file_exists is not a valid value, or if it is 'abort' and any
      of the output files already exists
    """
    logger = logging.getLogger(__name__)

    CONFIG = read_config()

    folder = Path(CONFIG.data.root_folder, output_directory, 'detect')
    # parents=True: the output directory itself may not exist yet
    folder.mkdir(parents=True, exist_ok=True)

    TMP_FOLDER = str(folder)

    # check if all scores, clear and collision spikes exist..
    path_to_score = os.path.join(TMP_FOLDER, 'scores_clear.npy')
    path_to_spike_index_clear = os.path.join(TMP_FOLDER,
                                             'spike_index_clear.npy')
    path_to_spike_index_all = os.path.join(TMP_FOLDER, 'spike_index_all.npy')
    path_to_rotation = os.path.join(TMP_FOLDER, 'rotation.npy')

    # the rotation file is saved below but is not part of the existence check
    paths = [path_to_score, path_to_spike_index_clear, path_to_spike_index_all]
    exists = [os.path.exists(p) for p in paths]

    if if_file_exists not in ('overwrite', 'abort', 'skip'):
        raise ValueError(
            'Invalid value for if_file_exists {}, '
            'must be one of overwrite, abort or skip'.format(if_file_exists))

    if if_file_exists == 'abort' and any(exists):
        conflict = [p for p, e in zip(paths, exists) if e]
        raise ValueError('if_file_exists was set to abort, the '
                         'program halted since the following files '
                         'already exist: {}'.format(', '.join(conflict)))

    if if_file_exists == 'skip' and all(exists):
        logger.warning('Skipped execution. All output files exist'
                       ', loading them...')
        scores = np.load(path_to_score)
        clear = np.load(path_to_spike_index_clear)
        spikes_all = np.load(path_to_spike_index_all)
        return scores, clear, spikes_all

    # from here on: 'overwrite', 'abort' with no existing file, or 'skip'
    # with at least one missing file
    max_memory = (CONFIG.resources.max_memory_gpu
                  if GPU_ENABLED else CONFIG.resources.max_memory)

    # instantiate batch processor
    bp = BatchProcessor(standarized_path,
                        standarized_params['dtype'],
                        standarized_params['n_channels'],
                        standarized_params['data_order'],
                        max_memory,
                        buffer_size=CONFIG.spike_size)

    # load parameters
    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename

    # instantiate neural networks, triage and autoencoder take the waveform
    # tensor of the detector as input
    NND = NeuralNetDetector.load(detection_fname, detection_th,
                                 CONFIG.channel_index)
    NNT = NeuralNetTriage.load(triage_fname,
                               triage_th,
                               input_tensor=NND.waveform_tf)
    NNAE = AutoEncoder(ae_fname, input_tensor=NND.waveform_tf)

    neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)
    rotation = NNAE.load_rotation()

    # gather all output tensors
    output_tf = (NNAE.score_tf, NND.spike_index_tf, NNT.idx_clean)

    # run detection
    with tf.Session() as sess:

        # get values of above tensors
        NND.restore(sess)
        NNAE.restore(sess)
        NNT.restore(sess)

        mc = bp.multi_channel_apply
        res = mc(neuralnetwork.run_detect_triage_featurize,
                 mode='memory',
                 cleanup_function=neuralnetwork.fix_indexes,
                 sess=sess,
                 x_tf=NND.x_tf,
                 output_tf=output_tf,
                 rot=rotation,
                 neighbors=neighbors)

    # arguments shared by both remove_incomplete_waveforms calls below
    margin = CONFIG.spike_size + CONFIG.templates.max_shift
    n_observations = bp.reader._n_observations

    # get clear spikes
    clear = np.concatenate([element[1] for element in res], axis=0)
    logger.info('Removing clear indexes outside the allowed range to '
                'draw a complete waveform...')
    clear, idx = detect.remove_incomplete_waveforms(clear, margin,
                                                    n_observations)

    # get all spikes
    spikes_all = np.concatenate([element[2] for element in res], axis=0)
    logger.info('Removing indexes outside the allowed range to '
                'draw a complete waveform...')
    spikes_all, _ = detect.remove_incomplete_waveforms(spikes_all, margin,
                                                       n_observations)

    # get scores, idx comes from filtering the clear spikes so that scores
    # and clear stay aligned
    scores = np.concatenate([element[0] for element in res], axis=0)
    logger.info('Removing scores for indexes outside the allowed range to '
                'draw a complete waveform...')
    scores = scores[idx]

    # transform scores to location + shape feature space
    # TODO: move this to another place

    if CONFIG.cluster.method == 'location':
        threshold = 2
        scores = get_locations_features(scores, rotation, clear[:, 1],
                                        CONFIG.channel_index, CONFIG.geom,
                                        threshold)
        # drop spikes whose features contain NaN, from scores and clear alike
        idx_nan = np.where(np.isnan(np.sum(scores, axis=(1, 2))))[0]
        scores = np.delete(scores, idx_nan, 0)
        clear = np.delete(clear, idx_nan, 0)

    # save partial results if required
    if save_results:
        # save clear spikes
        np.save(path_to_spike_index_clear, clear)
        logger.info('Saved spike index clear in {}...'.format(
            path_to_spike_index_clear))

        # save all spikes
        np.save(path_to_spike_index_all, spikes_all)
        logger.info('Saved spike index all in {}...'.format(
            path_to_spike_index_all))

        # save rotation
        np.save(path_to_rotation, rotation)
        logger.info(
            'Saved rotation matrix in {}...'.format(path_to_rotation))

        # saves scores
        np.save(path_to_score, scores)
        logger.info('Saved spike scores in {}...'.format(path_to_score))

    return scores, clear, spikes_all
Пример #5
0
def test_splitting_in_batches_does_not_affect(path_to_tests,
                                              path_to_standarized_data,
                                              path_to_sample_pipeline_folder):
    """Detection in batches must match detection over the whole recording.

    The detect + triage + featurize step is run once over the full data and
    once through a BatchProcessor limited to '100KB' per batch; scores,
    clear spikes and collision spikes of both runs must be equal.
    """
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    PATH_TO_DATA = path_to_standarized_data

    data = RecordingsReader(PATH_TO_DATA, loader='array').data

    path_to_params = path.join(path_to_sample_pipeline_folder, 'preprocess',
                               'standarized.yaml')

    # safe_load: calling yaml.load without an explicit Loader is deprecated
    # since PyYAML 5.1 and raises TypeError in PyYAML 6
    with open(path_to_params) as f:
        PARAMS = yaml.safe_load(f)

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename

    # instantiate neural networks, triage and autoencoder take the waveform
    # tensor of the detector as input
    NND = NeuralNetDetector.load(detection_fname, detection_th, channel_index)
    NNT = NeuralNetTriage.load(triage_fname,
                               triage_th,
                               input_tensor=NND.waveform_tf)
    NNAE = AutoEncoder(ae_fname, input_tensor=NND.waveform_tf)

    output_tf = (NNAE.score_tf, NND.spike_index_tf, NNT.idx_clean)

    # same rotation and neighbors for both runs, computed once
    rot = NNAE.load_rotation()
    neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

    # run all at once
    with tf.Session() as sess:
        # get values of above tensors
        NND.restore(sess)
        NNAE.restore(sess)
        NNT.restore(sess)

        (scores, clear, collision) = neuralnetwork.run_detect_triage_featurize(
            data, sess, NND.x_tf, output_tf, neighbors, rot)

    # run in batches - buffer size makes sure we can detect spikes if they
    # appear at the end of any batch
    bp = BatchProcessor(PATH_TO_DATA,
                        PARAMS['dtype'],
                        PARAMS['n_channels'],
                        PARAMS['data_order'],
                        '100KB',
                        buffer_size=CONFIG.spike_size)

    with tf.Session() as sess:
        # get values of above tensors
        NND.restore(sess)
        NNAE.restore(sess)
        NNT.restore(sess)

        res = bp.multi_channel_apply(
            neuralnetwork.run_detect_triage_featurize,
            mode='memory',
            cleanup_function=neuralnetwork.fix_indexes,
            sess=sess,
            x_tf=NND.x_tf,
            output_tf=output_tf,
            rot=rot,
            neighbors=neighbors)

    # each element of res holds (scores, clear, collision) for one batch
    scores_batch = np.concatenate([element[0] for element in res], axis=0)
    clear_batch = np.concatenate([element[1] for element in res], axis=0)
    collision_batch = np.concatenate([element[2] for element in res], axis=0)

    np.testing.assert_array_equal(clear_batch, clear)
    np.testing.assert_array_equal(collision_batch, collision)
    np.testing.assert_array_equal(scores_batch, scores)