Esempio n. 1
0
def test_can_use_neural_network_detector(path_to_tests,
                                         path_to_standarized_data):
    """Smoke test for the detector + triage + autoencoder pipeline.

    Loads the three networks described in ``config_nnet.yaml``, restores
    them in a single TensorFlow session and calls
    ``run_detect_triage_featurize`` once over the full recording. The
    output is not inspected, so the test only fails if a call raises.
    """
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    recording = RecordingsReader(path_to_standarized_data,
                                 loader='array').data

    neighbor_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    # thresholds and checkpoint locations of each network
    detect_cfg = CONFIG.detect
    spike_threshold = detect_cfg.neural_network_detector.threshold_spike
    collision_threshold = detect_cfg.neural_network_triage.threshold_collision
    detector_path = detect_cfg.neural_network_detector.filename
    autoencoder_path = detect_cfg.neural_network_autoencoder.filename
    triage_path = detect_cfg.neural_network_triage.filename

    # triage and autoencoder both read the detector's waveform tensor
    detector = NeuralNetDetector.load(detector_path, spike_threshold,
                                      neighbor_index)
    triage = NeuralNetTriage.load(triage_path,
                                  collision_threshold,
                                  input_tensor=detector.waveform_tf)
    autoencoder = AutoEncoder(autoencoder_path,
                              input_tensor=detector.waveform_tf)

    tensors_to_fetch = (autoencoder.score_tf, detector.spike_index_tf,
                        triage.idx_clean)

    with tf.Session() as sess:
        for network in (detector, autoencoder, triage):
            network.restore(sess)

        rotation = autoencoder.load_rotation()
        two_step_neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        neuralnetwork.run_detect_triage_featurize(recording, sess,
                                                  detector.x_tf,
                                                  tensors_to_fetch,
                                                  two_step_neighbors,
                                                  rotation)
Esempio n. 2
0
def train(CONFIG, CONFIG_TRAIN, spike_train, data_folder):
    """
    Train the detector, triage and autoencoder networks

    Parameters
    ----------
    CONFIG
        YASS configuration file
    CONFIG_TRAIN
        YASS Neural Network configuration file
    spike_train: numpy.ndarray
        Spike train, first column is spike index and second is main channel
    data_folder
        Passed unchanged to ``make_training_data``
    """
    logger = logging.getLogger(__name__)

    def checkpoint_path(section):
        # checkpoints are written relative to the current directory
        return './' + CONFIG_TRAIN[section]['name'] + '.ckpt'

    # training data selection
    template_ids = CONFIG_TRAIN['templates']['ids']
    amplitude_floor = CONFIG_TRAIN['templates']['minimum_amplitude']
    n_spikes = CONFIG_TRAIN['training']['n_spikes']

    # detector size and optimization settings shared by the networks
    detector_filters = CONFIG_TRAIN['network_detector']['n_filters']
    iterations = CONFIG_TRAIN['training']['n_iterations']
    batch_size = CONFIG_TRAIN['training']['n_batch']
    l2_scale = CONFIG_TRAIN['training']['l2_regularization_scale']
    step_size = CONFIG_TRAIN['training']['step_size']
    detector_ckpt = checkpoint_path('network_detector')

    # triage and autoencoder
    triage_filters = CONFIG_TRAIN['network_triage']['n_filters']
    triage_ckpt = checkpoint_path('network_triage')
    ae_features = CONFIG_TRAIN['network_autoencoder']['n_features']
    ae_ckpt = checkpoint_path('network_autoencoder')

    # one call produces the inputs/targets for the three networks
    logger.info('Generating training data...')
    training_sets = make_training_data(CONFIG,
                                       spike_train,
                                       template_ids,
                                       amplitude_floor,
                                       n_spikes,
                                       data_folder=data_folder)
    x_detect, y_detect, x_triage, y_triage, x_ae, y_ae = training_sets

    # detector
    NeuralNetDetector.train(x_detect, y_detect, detector_filters, iterations,
                            batch_size, l2_scale, step_size, detector_ckpt)

    # triage
    NeuralNetTriage.train(x_triage, y_triage, triage_filters, iterations,
                          batch_size, l2_scale, step_size, triage_ckpt)

    # autoencoder (trained without the l2 regularization scale)
    AutoEncoder.train(x_ae, y_ae, ae_features, iterations, batch_size,
                      step_size, ae_ckpt)
Esempio n. 3
0
def run_neural_network(standardized_path, standardized_params, whiten_filter,
                       output_directory, if_file_exists, save_results):
    """Run neural network detection and remove duplicated detections

    The recording is cut into chunks of 60 seconds (only the first 20 are
    processed) and every chunk is fed to the detector in pieces of
    ``sampling_rate`` samples. The results of each chunk are written to
    ``detect/detect_<chunk>.npz`` inside the output directory; chunks whose
    file already exists are not recomputed.

    Parameters
    ----------
    standardized_path, standardized_params, whiten_filter, output_directory
        Not used by this function, paths and recording parameters are taken
        from the configuration returned by ``read_config``

    if_file_exists: str
        With 'overwrite' detection is always run. With any other value
        detection only runs when none of scores_clear.npy,
        spike_index_clear.npy and spike_index_all.npy exist in the output
        directory, otherwise spike_index_all.npy is loaded

    save_results
        Not used by this function, spike_index_all.npy is always saved

    Returns
    -------
    spike_index_all
      Spike indexes for all spikes that survive deduplication, as returned
      by ``detect.remove_incomplete_waveforms``
    """
    logger = logging.getLogger(__name__)

    CONFIG = read_config()
    TMP_FOLDER = CONFIG.path_to_output_directory

    # files checked to decide whether detection has to run
    # NOTE(review): only spike_index_all.npy is written below; if just one
    # of the other two exists the final np.load fails - confirm intent
    path_to_score = os.path.join(TMP_FOLDER, 'scores_clear.npy')
    path_to_spike_index_clear = os.path.join(TMP_FOLDER,
                                             'spike_index_clear.npy')
    path_to_spike_index_all = os.path.join(TMP_FOLDER, 'spike_index_all.npy')

    path_to_standardized = os.path.join(TMP_FOLDER, 'preprocess',
                                        'standarized.bin')

    paths = [path_to_score, path_to_spike_index_clear, path_to_spike_index_all]
    exists = [os.path.exists(p) for p in paths]

    if (if_file_exists == 'overwrite' or not any(exists)):

        # parameters needed to build the networks
        detection_th = CONFIG.detect.neural_network_detector.threshold_spike
        detection_fname = CONFIG.detect.neural_network_detector.filename
        ae_fname = CONFIG.detect.neural_network_autoencoder.filename
        n_channels = CONFIG.recordings.n_channels

        # the autoencoder reads the waveforms produced by the detector
        NND = NeuralNetDetector.load(detection_fname, detection_th,
                                     CONFIG.channel_index)
        NNAE = AutoEncoder.load(ae_fname, input_tensor=NND.waveform_tf)

        # neighbors are only needed for deduplication, compute them once
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        # compute len of recording (in samples per channel); floor division
        # keeps the sample count an integer
        filename_dat = os.path.join(CONFIG.data.root_folder,
                                    CONFIG.data.recordings)
        fp = np.memmap(filename_dat, dtype='int16', mode='r')
        fp_len = fp.shape[0] // n_channels

        # compute batch indexes
        buffer_size = 200  # Cat: to set this in CONFIG file
        sampling_rate = CONFIG.recordings.sampling_rate

        # Cat: TODO: Set a different size chunk for clustering vs. detection
        n_sec_chunk = 60

        # chunk boundaries; np.arange leaves out the stop value, so the end
        # of the recording is appended to close the last (shorter) chunk
        chunk_edges = np.arange(0, fp_len, sampling_rate * n_sec_chunk)
        if chunk_edges[-1] != fp_len:
            chunk_edges = np.hstack((chunk_edges, fp_len))

        # each row: start, end, buffer size, chunk length plus buffer
        idx_list = [[start, end, buffer_size, end - start + buffer_size]
                    for start, end in zip(chunk_edges[:-1], chunk_edges[1:])]

        # NOTE(review): only the first 20 chunks are kept - confirm this
        # cap is intended and not a debugging leftover
        idx_list = np.int64(np.vstack(idx_list))[:20]

        logger.info("# of chunks: %i", len(idx_list))

        logger.info(idx_list)

        # progress counters, one unit per sub-chunk fed to the detector
        processing_ctr = 0
        total_processing = len(idx_list) * n_sec_chunk

        # per-chunk results are saved here so a crashed run can resume
        fname_detection = os.path.join(CONFIG.path_to_output_directory,
                                       'detect')
        if not os.path.exists(fname_detection):
            os.mkdir(fname_detection)

        # set tensorflow verbosity level
        tf.logging.set_verbosity(tf.logging.ERROR)

        # a single tensorflow session is kept open for all chunks
        with tf.Session() as sess:
            NND.restore(sess)

            # Cat: TODO: don't save to lists, might want to use numpy
            # arrays directly

            # loop over 60 sec chunks
            for chunk_ctr, idx in enumerate(idx_list):
                # skip chunks saved by a previous run
                if os.path.exists(
                        os.path.join(
                            fname_detection,
                            "detect_" + str(chunk_ctr).zfill(5) + '.npz')):
                    continue

                # reset lists
                spike_index_list = []
                energy_list = []
                TC_list = []
                offset_list = []

                # load chunk of data
                standardized_recording = binary_reader(idx, buffer_size,
                                                       path_to_standardized,
                                                       n_channels,
                                                       CONFIG.data.root_folder)

                # run detection on smaller chunks of data, e.g. 1 sec
                # Cat: TODO: add last bit at end in case short
                sub_edges = np.arange(0, standardized_recording.shape[0],
                                      sampling_rate)

                for start, end in zip(sub_edges[:-1], sub_edges[1:]):

                    # save absolute index of each subchunk
                    offset_list.append(idx[0] + start)

                    data_temp = standardized_recording[start:end]

                    # store size of recordings in case at end of dataset.
                    TC_list.append(data_temp.shape)

                    # run detect nn
                    res = NND.predict_recording(data_temp,
                                                sess=sess,
                                                output_names=('spike_index',
                                                              'waveform'))
                    spike_index, wfs = res

                    # run AE nn; required for remove_axon function
                    # Cat: TODO: Do we really need this: can we get energy
                    # list faster?
                    score = NNAE.predict(wfs, sess)
                    rot = NNAE.load_rotation(sess)

                    spike_index_list.append(spike_index)

                    # peak-to-peak of the rotated scores, used as the
                    # energy of each spike during deduplication
                    energy_ = np.ptp(np.matmul(score[:, :, 0], rot.T), axis=1)
                    energy_list.append(energy_)

                    logger.info('processed chunk: %s/%s,  # spikes: %s',
                                str(processing_ctr), str(total_processing),
                                spike_index.shape)

                    processing_ctr += 1

                # save chunk of data in case crashes occur
                np.savez(os.path.join(fname_detection,
                                      "detect_" + str(chunk_ctr).zfill(5)),
                         spike_index_list=spike_index_list,
                         energy_list=energy_list,
                         TC_list=TC_list,
                         offset_list=offset_list)

        # load all saved data (chunks from this run and previous ones)
        spike_index_list = []
        energy_list = []
        TC_list = []
        offset_list = []
        for ctr in range(len(idx_list)):
            data = np.load(
                os.path.join(fname_detection,
                             'detect_' + str(ctr).zfill(5) + '.npz'))
            spike_index_list.extend(data['spike_index_list'])
            energy_list.extend(data['energy_list'])
            TC_list.extend(data['TC_list'])
            offset_list.extend(data['offset_list'])

        # save all detected spikes pre axon_kill
        spike_index_all_pre_deduplication = fix_indexes_firstbatch_3(
            spike_index_list, offset_list, buffer_size, sampling_rate)
        np.save(
            os.path.join(TMP_FOLDER, 'spike_index_all_pre_deduplication.npy'),
            spike_index_all_pre_deduplication)

        # remove axons - compute axons to be killed
        logger.info(' removing axons in parallel')
        if CONFIG.resources.multi_processing:
            keep = parmap.map(deduplicate,
                              list(
                                  zip(spike_index_list, energy_list, TC_list,
                                      np.arange(len(energy_list)))),
                              neighbors,
                              processes=CONFIG.resources.n_processors,
                              pm_pbar=True)
        else:
            keep = [
                deduplicate(
                    (spike_index_list[k], energy_list[k], TC_list[k], k),
                    neighbors) for k in range(len(energy_list))
            ]

        # Cat: TODO Note that we're killing spike_index_all as well.
        # keep only the events that were not killed in each sub-chunk
        spike_index_all_postkill = [
            spikes[kept[0]] for spikes, kept in zip(spike_index_list, keep)
        ]

        logger.info(' fixing indexes from batching')
        spike_index_all = fix_indexes_firstbatch_3(spike_index_all_postkill,
                                                   offset_list, buffer_size,
                                                   sampling_rate)

        # drop the spikes for which a complete waveform cannot be drawn
        spikes_all, _ = detect.remove_incomplete_waveforms(
            spike_index_all, CONFIG.spike_size + CONFIG.templates.max_shift,
            fp_len)

        np.save(os.path.join(TMP_FOLDER, 'spike_index_all.npy'), spikes_all)

    else:
        spikes_all = np.load(os.path.join(TMP_FOLDER, 'spike_index_all.npy'))

    return spikes_all
Esempio n. 4
0
def run_neural_network(standarized_path, standarized_params, whiten_filter,
                       output_directory, if_file_exists, save_results):
    """Run neural network detection and autoencoder dimensionality reduction

    Parameters
    ----------
    standarized_path
        Path to the recordings, passed to ``BatchProcessor``

    standarized_params: dict
        Recording parameters, the keys 'dtype', 'n_channels' and
        'data_order' are read

    whiten_filter
        Not used by this function

    output_directory
        Results go in the 'detect' folder inside this directory, which is
        itself located relative to CONFIG.data.root_folder

    if_file_exists: str
        One of 'overwrite', 'abort' or 'skip'. 'overwrite' always runs
        detection, 'abort' raises if any output file exists and 'skip'
        loads the output files when all of them exist (and runs detection
        otherwise)

    save_results: bool
        Whether to save spike indexes, rotation matrix and scores

    Returns
    -------
    scores
      Scores for all spikes

    spike_index_clear
      Spike indexes for clear spikes

    spike_index_all
      Spike indexes for all spikes

    Raises
    ------
    ValueError
        If if_file_exists is 'abort' and any output file exists, or if
        if_file_exists is not a valid option
    """
    logger = logging.getLogger(__name__)

    CONFIG = read_config()

    folder = Path(CONFIG.data.root_folder, output_directory, 'detect')
    folder.mkdir(exist_ok=True)

    TMP_FOLDER = str(folder)

    # check if all scores, clear and collision spikes exist..
    path_to_score = os.path.join(TMP_FOLDER, 'scores_clear.npy')
    path_to_spike_index_clear = os.path.join(TMP_FOLDER,
                                             'spike_index_clear.npy')
    path_to_spike_index_all = os.path.join(TMP_FOLDER, 'spike_index_all.npy')
    path_to_rotation = os.path.join(TMP_FOLDER, 'rotation.npy')

    paths = [path_to_score, path_to_spike_index_clear, path_to_spike_index_all]
    exists = [os.path.exists(p) for p in paths]

    if (if_file_exists == 'overwrite'
            or if_file_exists == 'abort' and not any(exists)
            or if_file_exists == 'skip' and not all(exists)):
        max_memory = (CONFIG.resources.max_memory_gpu
                      if GPU_ENABLED else CONFIG.resources.max_memory)

        # instantiate batch processor
        bp = BatchProcessor(standarized_path,
                            standarized_params['dtype'],
                            standarized_params['n_channels'],
                            standarized_params['data_order'],
                            max_memory,
                            buffer_size=CONFIG.spike_size)

        # load parameters
        detection_th = CONFIG.detect.neural_network_detector.threshold_spike
        triage_th = CONFIG.detect.neural_network_triage.threshold_collision
        detection_fname = CONFIG.detect.neural_network_detector.filename
        ae_fname = CONFIG.detect.neural_network_autoencoder.filename
        triage_fname = CONFIG.detect.neural_network_triage.filename

        # instantiate neural networks, triage and autoencoder take the
        # detector's waveform tensor as input
        NND = NeuralNetDetector.load(detection_fname, detection_th,
                                     CONFIG.channel_index)
        NNT = NeuralNetTriage.load(triage_fname,
                                   triage_th,
                                   input_tensor=NND.waveform_tf)
        NNAE = AutoEncoder(ae_fname, input_tensor=NND.waveform_tf)

        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)
        rotation = NNAE.load_rotation()

        # gather all output tensors
        output_tf = (NNAE.score_tf, NND.spike_index_tf, NNT.idx_clean)

        # run detection, one session is shared by all batches
        with tf.Session() as sess:

            # get values of above tensors
            NND.restore(sess)
            NNAE.restore(sess)
            NNT.restore(sess)

            res = bp.multi_channel_apply(
                neuralnetwork.run_detect_triage_featurize,
                mode='memory',
                cleanup_function=neuralnetwork.fix_indexes,
                sess=sess,
                x_tf=NND.x_tf,
                output_tf=output_tf,
                rot=rotation,
                neighbors=neighbors)

        # every element of res is (scores, clear spikes, all spikes)

        # get clear spikes
        clear = np.concatenate([element[1] for element in res], axis=0)
        logger.info('Removing clear indexes outside the allowed range to '
                    'draw a complete waveform...')
        clear, idx = detect.remove_incomplete_waveforms(
            clear, CONFIG.spike_size + CONFIG.templates.max_shift,
            bp.reader._n_observations)

        # get all spikes
        spikes_all = np.concatenate([element[2] for element in res], axis=0)
        logger.info('Removing indexes outside the allowed range to '
                    'draw a complete waveform...')
        spikes_all, _ = detect.remove_incomplete_waveforms(
            spikes_all, CONFIG.spike_size + CONFIG.templates.max_shift,
            bp.reader._n_observations)

        # get scores, idx selects the ones matching the clear spikes kept
        scores = np.concatenate([element[0] for element in res], axis=0)
        logger.info('Removing scores for indexes outside the allowed range to '
                    'draw a complete waveform...')
        scores = scores[idx]

        # transform scores to location + shape feature space
        # TODO: move this to another place

        if CONFIG.cluster.method == 'location':
            threshold = 2
            scores = get_locations_features(scores, rotation, clear[:, 1],
                                            CONFIG.channel_index, CONFIG.geom,
                                            threshold)
            # drop the spikes whose features contain NaN
            idx_nan = np.where(np.isnan(np.sum(scores, axis=(1, 2))))[0]
            scores = np.delete(scores, idx_nan, 0)
            clear = np.delete(clear, idx_nan, 0)

        # save partial results if required
        if save_results:
            # save clear spikes
            np.save(path_to_spike_index_clear, clear)
            logger.info('Saved spike index clear in {}...'.format(
                path_to_spike_index_clear))

            # save all spikes
            np.save(path_to_spike_index_all, spikes_all)
            logger.info('Saved spike index all in {}...'.format(
                path_to_spike_index_all))

            # save rotation
            np.save(path_to_rotation, rotation)
            logger.info(
                'Saved rotation matrix in {}...'.format(path_to_rotation))

            # saves scores
            np.save(path_to_score, scores)
            logger.info('Saved spike scores in {}...'.format(path_to_score))

    elif if_file_exists == 'abort' and any(exists):
        conflict = [p for p, e in zip(paths, exists) if e]
        message = ', '.join(conflict)
        raise ValueError('if_file_exists was set to abort, the '
                         'program halted since the following files '
                         'already exist: {}'.format(message))
    elif if_file_exists == 'skip' and all(exists):
        logger.warning('Skipped execution. All output files exist'
                       ', loading them...')
        scores = np.load(path_to_score)
        clear = np.load(path_to_spike_index_clear)
        spikes_all = np.load(path_to_spike_index_all)

    else:
        raise ValueError(
            'Invalid value for if_file_exists {}, '
            'must be one of overwrite, abort or skip'.format(if_file_exists))

    return scores, clear, spikes_all
Esempio n. 5
0
def prepare_nn(channel_index, whiten_filter, threshold_detect,
               threshold_triage, detector_filename, autoencoder_filename,
               triage_filename):
    """Build the neural net tensors once, before any batch is processed.

    Creating the tensors in advance lets the batch processor reuse them
    instead of recreating tf tensors in every batch.

    Parameters
    ----------
    channel_index: np.array (n_channels, n_neigh)
        Each row indexes its neighboring channels.
        For example, channel_index[c] is the index of
        neighboring channels (including itself)
        If any value is equal to n_channels, it is nothing but
        a space holder in a case that a channel has less than
        n_neigh neighboring channels

    whiten_filter: numpy.ndarray (n_channels, n_neigh, n_neigh)
        whitening matrix such that whiten_filter[c] is the whitening
        filter of channel c and its neighboring channel determined from
        channel_index. Not used by this function.

    threshold_detect: int
        threshold for neural net detection

    threshold_triage: int
        threshold for neural net triage

    detector_filename: str
        location of trained neural net detector

    autoencoder_filename: str
        location of trained neural net autoencoder

    triage_filename: str
        location of trained neural net triage

    Returns
    -------
    x_tf: tf.tensors (n_observations, n_channels)
        placeholder of recording for running tensorflow

    output_tf: tuple of tf.tensors
        (score, spike index with edge spikes removed, clean index from
        triage)

    NND: class
        an instance of class, NeuralNetDetector

    NNAE: class
        an instance of class, AutoEncoder

    NNT: class
        an instance of class, NeuralNetTriage
    """
    # the recording is fed through a float placeholder of unspecified shape
    recording_tf = tf.placeholder("float", [None, None])

    # networks, created in the same order as their files are listed
    detector = NeuralNetDetector(detector_filename)
    autoencoder = AutoEncoder(autoencoder_filename)
    triage = NeuralNetTriage(triage_filename)

    # detection over the whole recording
    detections_tf = detector.make_detection_tf_tensors(
        recording_tf, channel_index, threshold_detect)

    # discard the spike times at the edges of the recording
    inner_detections_tf = remove_edge_spikes(recording_tf, detections_tf,
                                             detector.filters_dict['size'])

    # waveforms around the remaining detections
    waveforms_tf = make_waveform_tf_tensor(recording_tf, inner_detections_tf,
                                           channel_index,
                                           detector.filters_dict['size'])

    # autoencoder scores computed from the waveforms
    scores_tf = autoencoder.make_score_tf_tensor(waveforms_tf)

    # triage only looks at the first n_neighbors entries of the last axis
    n_neighbors = detector.filters_dict['n_neighbors']
    clean_idx_tf = triage.triage_wf(waveforms_tf[:, :, :n_neighbors],
                                    threshold_triage)

    outputs_tf = (scores_tf, inner_detections_tf, clean_idx_tf)

    return recording_tf, outputs_tf, detector, autoencoder, triage
Esempio n. 6
0
def test_splitting_in_batches_does_not_affect(path_to_config,
                                              path_to_sample_pipeline_folder,
                                              make_tmp_folder,
                                              path_to_standardized_data):
    """Run batched detection, triage and featurization twice.

    Both passes use the same ``BatchProcessor`` (100KB batches with a
    buffer of ``CONFIG.spike_size``), a fresh TensorFlow session each time.

    NOTE(review): the results of the two passes are never compared, so the
    test only fails if one of the calls raises - confirm whether asserts
    (and a non-batched reference run) are missing.
    """
    yass.set_config(path_to_config, make_tmp_folder)
    CONFIG = yass.read_config()

    PATH_TO_DATA = path_to_standardized_data

    with open(path.join(path_to_sample_pipeline_folder, 'preprocess',
                        'standardized.yaml')) as f:
        # an explicit Loader is mandatory since PyYAML 6; yaml.Loader is
        # what yaml.load(f) used before a Loader was required.
        # NOTE(review): the file is a project fixture; switch to
        # yaml.safe_load if it only holds plain scalars
        PARAMS = yaml.load(f, Loader=yaml.Loader)

    channel_index = make_channel_index(CONFIG.neigh_channels,
                                       CONFIG.geom)

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename

    # instantiate neural networks
    NND = NeuralNetDetector.load(detection_fname, detection_th,
                                 channel_index)
    triage = KerasModel(triage_fname,
                        allow_longer_waveform_length=True,
                        allow_more_channels=True)
    NNAE = AutoEncoder.load(ae_fname, input_tensor=NND.waveform_tf)

    bp = BatchProcessor(PATH_TO_DATA, PARAMS['dtype'], PARAMS['n_channels'],
                        PARAMS['data_order'], '100KB',
                        buffer_size=CONFIG.spike_size)

    out = ('spike_index', 'waveform')
    fn = neuralnetwork.apply.fix_indexes_spike_index

    def run_pipeline():
        # detector, applied batch by batch inside its own session
        with tf.Session() as sess:
            # get values of above tensors
            NND.restore(sess)

            res = bp.multi_channel_apply(NND.predict_recording,
                                         mode='memory',
                                         sess=sess,
                                         output_names=out,
                                         cleanup_function=fn)

        # every element of res is (spike_index, waveform) for one batch
        spike_index = np.concatenate([element[0] for element in res], axis=0)
        wfs = np.concatenate([element[1] for element in res], axis=0)

        # triage and featurization run on the concatenated waveforms
        idx_clean = triage.predict_with_threshold(x=wfs,
                                                  threshold=triage_th)
        score = NNAE.predict(wfs)
        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        return post_processing(score, spike_index, idx_clean, rot, neighbors)

    (score_clear_new, spike_index_clear_new) = run_pipeline()

    (score_clear_batch, spike_index_clear_batch) = run_pipeline()
Esempio n. 7
0
def test_splitting_in_batches_does_not_affect(path_to_tests,
                                              path_to_standarized_data,
                                              path_to_sample_pipeline_folder):
    """Detection results must not depend on how the data is batched.

    Runs ``run_detect_triage_featurize`` once over the whole recording and
    once in 100KB batches (with a buffer of ``CONFIG.spike_size``), then
    checks that scores, clear spikes and collision spikes are identical.
    """
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    PATH_TO_DATA = path_to_standarized_data

    data = RecordingsReader(PATH_TO_DATA, loader='array').data

    with open(
            path.join(path_to_sample_pipeline_folder, 'preprocess',
                      'standarized.yaml')) as f:
        # an explicit Loader is mandatory since PyYAML 6; yaml.Loader is
        # what yaml.load(f) used before a Loader was required.
        # NOTE(review): the file is a project fixture; switch to
        # yaml.safe_load if it only holds plain scalars
        PARAMS = yaml.load(f, Loader=yaml.Loader)

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename

    # instantiate neural networks, triage and autoencoder take the
    # detector's waveform tensor as input
    NND = NeuralNetDetector.load(detection_fname, detection_th, channel_index)
    NNT = NeuralNetTriage.load(triage_fname,
                               triage_th,
                               input_tensor=NND.waveform_tf)
    NNAE = AutoEncoder(ae_fname, input_tensor=NND.waveform_tf)

    output_tf = (NNAE.score_tf, NND.spike_index_tf, NNT.idx_clean)

    # run all at once
    with tf.Session() as sess:
        # get values of above tensors
        NND.restore(sess)
        NNAE.restore(sess)
        NNT.restore(sess)

        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        (scores, clear, collision) = neuralnetwork.run_detect_triage_featurize(
            data, sess, NND.x_tf, output_tf, neighbors, rot)

    # run in batches - buffer size makes sure we can detect spikes if they
    # appear at the end of any batch
    bp = BatchProcessor(PATH_TO_DATA,
                        PARAMS['dtype'],
                        PARAMS['n_channels'],
                        PARAMS['data_order'],
                        '100KB',
                        buffer_size=CONFIG.spike_size)

    with tf.Session() as sess:
        # get values of above tensors
        NND.restore(sess)
        NNAE.restore(sess)
        NNT.restore(sess)

        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        res = bp.multi_channel_apply(
            neuralnetwork.run_detect_triage_featurize,
            mode='memory',
            cleanup_function=neuralnetwork.fix_indexes,
            sess=sess,
            x_tf=NND.x_tf,
            output_tf=output_tf,
            rot=rot,
            neighbors=neighbors)

    # every element of res is (scores, clear, collision) for one batch
    scores_batch = np.concatenate([element[0] for element in res], axis=0)
    clear_batch = np.concatenate([element[1] for element in res], axis=0)
    collision_batch = np.concatenate([element[2] for element in res], axis=0)

    np.testing.assert_array_equal(clear_batch, clear)
    np.testing.assert_array_equal(collision_batch, collision)
    np.testing.assert_array_equal(scores_batch, scores)