Example #1
def test_can_use_neural_network_detector(path_to_tests):
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    data = RecordingsReader(path.join(path_to_tests, 'data/standarized.bin'),
                            loader='array').data

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    whiten_filter = np.tile(
        np.eye(channel_index.shape[1], dtype='float32')[np.newaxis, :, :],
        [channel_index.shape[0], 1, 1])

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename

    (x_tf, output_tf, NND, NNAE,
     NNT) = neuralnetwork.prepare_nn(channel_index, whiten_filter,
                                     detection_th, triage_th, detection_fname,
                                     ae_fname, triage_fname)

    with tf.Session() as sess:
        # restore trained network weights into the session
        NND.saver.restore(sess, NND.path_to_detector_model)
        NNAE.saver_ae.restore(sess, NNAE.path_to_ae_model)
        NNT.saver.restore(sess, NNT.path_to_triage_model)
        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        neuralnetwork.run_detect_triage_featurize(data, sess, x_tf, output_tf,
                                                  neighbors, rot)
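
Note that the whiten_filter built in Example #1 is just an identity matrix tiled once per channel, so applying it is effectively a no-op in this test. A minimal standalone sketch of that construction, with made-up sizes not tied to any yass config:

import numpy as np

# hypothetical sizes: 7 channels, neighborhoods of 5 channels each
n_channels, n_neigh = 7, 5

# one (n_neigh x n_neigh) identity per channel -> shape (n_channels, n_neigh, n_neigh)
whiten_filter = np.tile(
    np.eye(n_neigh, dtype='float32')[np.newaxis, :, :],
    [n_channels, 1, 1])

assert whiten_filter.shape == (n_channels, n_neigh, n_neigh)
assert np.array_equal(whiten_filter[0], np.eye(n_neigh, dtype='float32'))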
Example #2
def test_can_use_neural_network_detector(path_to_tests,
                                         path_to_standarized_data):
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    data = RecordingsReader(path_to_standarized_data, loader='array').data

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename

    # instantiate neural networks
    NND = NeuralNetDetector.load(detection_fname, detection_th, channel_index)
    NNT = NeuralNetTriage.load(triage_fname,
                               triage_th,
                               input_tensor=NND.waveform_tf)
    NNAE = AutoEncoder(ae_fname, input_tensor=NND.waveform_tf)

    output_tf = (NNAE.score_tf, NND.spike_index_tf, NNT.idx_clean)

    with tf.Session() as sess:
        NND.restore(sess)
        NNAE.restore(sess)
        NNT.restore(sess)

        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        neuralnetwork.run_detect_triage_featurize(data, sess, NND.x_tf,
                                                  output_tf, neighbors, rot)
Example #3
    def __init__(self, mapping, output_directory=None):
        self._logger = logging.getLogger(__name__)

        # FIXME: not raising errors due to schema validation for now
        mapping = validate(mapping, silent=True)

        _processes = mapping['resources']['processes']
        mapping['resources']['processes'] = (
            multiprocess.cpu_count() if _processes == 'max' else _processes)

        self._frozenjson = FrozenJSON(mapping)

        if output_directory is not None:
            if path.isabs(output_directory):
                self._path_to_output_directory = output_directory
            else:
                _ = Path(self.data.root_folder, output_directory)
                self._path_to_output_directory = str(_)
        else:
            self._path_to_output_directory = None

        # init the rest of the parameters; these are used throughout the
        # pipeline, so we compute them once to avoid redundant computations

        # GEOMETRY PARAMETERS
        path_to_geom = path.join(self.data.root_folder, self.data.geometry)

        self._set_param('geom',
                        geom.parse(path_to_geom, self.recordings.n_channels))

        # check dimensions of the geometry file
        n_channels_geom, _ = self.geom.shape

        if self.recordings.n_channels != n_channels_geom:
            raise ValueError('Number of channels in the geometry file ({}) '
                             'does not match the value in the configuration '
                             'file ({})'.format(n_channels_geom,
                                                self.recordings.n_channels))

        neigh_channels = geom.find_channel_neighbors(
            self.geom, self.recordings.spatial_radius)
        self._set_param('neigh_channels', neigh_channels)

        channel_groups = geom.make_channel_groups(self.recordings.n_channels,
                                                  self.neigh_channels,
                                                  self.geom)
        self._set_param('channel_groups', channel_groups)

        self._set_param(
            'spike_size',
            int(
                np.round(self.recordings.spike_size_ms *
                         self.recordings.sampling_rate / (2 * 1000))))

        channel_index = geom.make_channel_index(self.neigh_channels,
                                                self.geom,
                                                steps=2)

        self._set_param('channel_index', channel_index)
Example #4
def run_neural_network(standardized_path,
                       standardized_dtype,
                       output_directory,
                       run_chunk_sec='full'):
    """Run neural network detection
    """
    logger = logging.getLogger(__name__)

    CONFIG = read_config()

    # load NN detector
    detector = Detect(CONFIG.neuralnetwork.detect.n_filters,
                      CONFIG.spike_size_nn, CONFIG.channel_index)
    detector.load(CONFIG.neuralnetwork.detect.filename)

    # load NN denoiser
    denoiser = Denoise(CONFIG.neuralnetwork.denoise.n_filters,
                       CONFIG.neuralnetwork.denoise.filter_sizes,
                       CONFIG.spike_size_nn)
    denoiser.load(CONFIG.neuralnetwork.denoise.filename)

    # get data reader
    batch_length = CONFIG.resources.n_sec_chunk * CONFIG.resources.n_processors
    n_sec_chunk = CONFIG.resources.n_sec_chunk_gpu_detect
    print("   batch length to (sec): ", batch_length,
          " (longer increase speed a bit)")
    print("   length of each seg (sec): ", n_sec_chunk)
    buffer = CONFIG.spike_size_nn
    if run_chunk_sec == 'full':
        chunk_sec = None
    else:
        chunk_sec = run_chunk_sec

    reader = READER(standardized_path, standardized_dtype, CONFIG,
                    batch_length, buffer, chunk_sec)

    # neighboring channels
    channel_index_dedup = make_channel_index(CONFIG.neigh_channels,
                                             CONFIG.geom,
                                             steps=2)

    # threshold for neuralnet detection
    detect_threshold = CONFIG.detect.threshold

    # split batches across torch devices and run detection in parallel processes
    batch_ids = np.arange(reader.n_batches)
    batch_ids_split = np.split(batch_ids, len(CONFIG.torch_devices))
    processes = []
    for ii, device in enumerate(CONFIG.torch_devices):
        p = mp.Process(target=run_nn_detction_batch,
                       args=(batch_ids_split[ii], output_directory, reader,
                             n_sec_chunk, detector, denoiser,
                             channel_index_dedup, detect_threshold, device))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
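
One caveat in Example #4: np.split raises a ValueError when the number of batches is not evenly divisible by the number of torch devices. If that is a concern, np.array_split tolerates a remainder; a minimal sketch with toy numbers (not the library's code):

import numpy as np

# toy numbers: 10 batch ids spread over 3 devices
batch_ids = np.arange(10)

# np.split(batch_ids, 3) would raise ValueError (10 is not divisible by 3)
splits = np.array_split(batch_ids, 3)
print([len(s) for s in splits])  # [4, 3, 3]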
Example #5
def test_can_compute_whiten_matrix(data, data_info, path_to_geometry):
    geometry = parse(path_to_geometry, data_info['recordings']['n_channels'])
    neighbors = find_channel_neighbors(geometry, radius=70)
    channel_index = make_channel_index(neighbors, geometry)

    # FIXME: using the same formula from yass/config/config.py, might be
    # better to avoid having this hardcoded
    spike_size = int(
        np.round(data_info['recordings']['spike_size_ms'] *
                 data_info['recordings']['sampling_rate'] / (2 * 1000)))

    whiten._matrix(data, channel_index, spike_size)
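
The hardcoded formula in Example #5 converts the spike window from milliseconds to samples and halves it (a half-window in samples). A worked example with made-up values:

import numpy as np

# hypothetical values: 1 ms spike window sampled at 30 kHz
spike_size_ms = 1.0
sampling_rate = 30000

# ms -> samples, divided by two (half-window)
spike_size = int(np.round(spike_size_ms * sampling_rate / (2 * 1000)))
print(spike_size)  # 15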
Example #6
def run_voltage_treshold(standardized_path,
                         standardized_dtype,
                         output_directory,
                         run_chunk_sec='full'):
    """Run detection that thresholds on amplitude
    """
    logger = logging.getLogger(__name__)

    CONFIG = read_config()

    # get data reader
    #n_sec_chunk = CONFIG.resources.n_sec_chunk*CONFIG.resources.n_processors
    batch_length = CONFIG.resources.n_sec_chunk
    n_sec_chunk = 0.5
    print("   batch length to (sec): ", batch_length,
          " (longer increase speed a bit)")
    print("   length of each seg (sec): ", n_sec_chunk)
    buffer = CONFIG.spike_size
    if run_chunk_sec == 'full':
        chunk_sec = None
    else:
        chunk_sec = run_chunk_sec

    reader = READER(standardized_path, standardized_dtype, CONFIG,
                    batch_length, buffer, chunk_sec)

    # number of processed chunks
    n_mini_per_big_batch = int(np.ceil(batch_length / n_sec_chunk))
    total_processing = int(reader.n_batches * n_mini_per_big_batch)

    # neighboring channels
    channel_index = make_channel_index(CONFIG.neigh_channels,
                                       CONFIG.geom,
                                       steps=2)

    if CONFIG.resources.multi_processing:
        parmap.starmap(run_voltage_threshold_parallel,
                       list(zip(np.arange(reader.n_batches))),
                       reader,
                       n_sec_chunk,
                       CONFIG.detect.threshold,
                       channel_index,
                       output_directory,
                       processes=CONFIG.resources.n_processors,
                       pm_pbar=True)
    else:
        for batch_id in range(reader.n_batches):
            run_voltage_threshold_parallel(batch_id, reader, n_sec_chunk,
                                           CONFIG.detect.threshold,
                                           channel_index, output_directory)
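
The chunk bookkeeping in Example #6 is plain arithmetic: each reader batch of batch_length seconds is cut into ceil(batch_length / n_sec_chunk) mini-chunks. A worked example with made-up numbers:

import numpy as np

# hypothetical values: 10 s batches, 0.5 s mini-chunks, 6 reader batches
batch_length, n_sec_chunk, n_batches = 10, 0.5, 6

n_mini_per_big_batch = int(np.ceil(batch_length / n_sec_chunk))  # 20
total_processing = int(n_batches * n_mini_per_big_batch)         # 120
print(n_mini_per_big_batch, total_processing)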
Example #7
    def __init__(self, mapping):
        mapping = validate(mapping)

        super(Config, self).__init__(mapping)

        self._logger = logging.getLogger(__name__)

        # init the rest of the parameters; these are used throughout the
        # pipeline, so we compute them once to avoid redundant computations

        # GEOMETRY PARAMETERS
        path_to_geom = path.join(self.data.root_folder, self.data.geometry)
        self._set_param('geom',
                        geom.parse(path_to_geom, self.recordings.n_channels))

        # check dimensions of the geometry file
        n_channels_geom, _ = self.geom.shape

        if self.recordings.n_channels != n_channels_geom:
            raise ValueError('Number of channels in the geometry file ({}) '
                             'does not match the value in the configuration '
                             'file ({})'.format(n_channels_geom,
                                                self.recordings.n_channels))

        neigh_channels = geom.find_channel_neighbors(
            self.geom, self.recordings.spatial_radius)
        self._set_param('neigh_channels', neigh_channels)

        channel_groups = geom.make_channel_groups(
            self.recordings.n_channels, self.neigh_channels, self.geom)
        self._set_param('channel_groups', channel_groups)

        self._set_param(
            'spike_size',
            int(
                np.round(self.recordings.spike_size_ms *
                         self.recordings.sampling_rate / (2 * 1000))))

        channel_index = geom.make_channel_index(self.neigh_channels,
                                                self.geom, steps=2)
        self._set_param('channel_index', channel_index)
Example #8
def get_o_layer(standarized_path,
                standarized_params,
                output_directory='tmp/',
                output_dtype='float32',
                output_filename='o_layer.bin',
                if_file_exists='skip',
                save_partial_results=False):
    """Get the output of NN detector instead of outputting the spike index
    """

    CONFIG = read_config()
    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom, 1)

    x_tf = tf.placeholder("float", [None, None])

    # load neural networks
    detection_fname = CONFIG.detect.neural_network_detector.filename
    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    NND = NeuralNetDetector(detection_fname)
    o_layer_tf = NND.make_o_layer_tf_tensors(x_tf, channel_index, detection_th)

    bp = BatchProcessor(standarized_path,
                        standarized_params['dtype'],
                        standarized_params['n_channels'],
                        standarized_params['data_format'],
                        CONFIG.resources.max_memory,
                        buffer_size=CONFIG.spike_size)

    TMP = os.path.join(CONFIG.data.root_folder, output_directory)
    _output_path = os.path.join(TMP, output_filename)
    (o_path, o_params) = bp.multi_channel_apply(_get_o_layer,
                                                mode='disk',
                                                cleanup_function=fix_indexes,
                                                output_path=_output_path,
                                                cast_dtype=output_dtype,
                                                x_tf=x_tf,
                                                o_layer_tf=o_layer_tf,
                                                NND=NND)

    return o_path, o_params
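
These snippets use the TensorFlow 1.x graph API (tf.placeholder, tf.Session), which is not exposed in the default TensorFlow 2 namespace. A minimal sketch of the same pattern through the compat layer, assuming TensorFlow 2 is installed (generic TF usage, not yass code):

import tensorflow as tf

# TF2 ships the 1.x graph API under tf.compat.v1; eager execution must be disabled
tf.compat.v1.disable_eager_execution()

x_tf = tf.compat.v1.placeholder('float32', [None, None])
y_tf = x_tf * 2.0

with tf.compat.v1.Session() as sess:
    print(sess.run(y_tf, feed_dict={x_tf: [[1.0, 2.0]]}))  # [[2. 4.]]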
Example #9
    def __init__(self, mapping, output_directory=None):
        self._logger = logging.getLogger(__name__)

        # FIXME: not raising errors due to schema validation for now
        mapping = validate(mapping, silent=True)

        self._frozenjson = FrozenJSON(mapping)

        if output_directory is not None:
            if path.isabs(output_directory):
                self._path_to_output_directory = output_directory
            else:
                _ = Path(self.data.root_folder, output_directory)
                self._path_to_output_directory = str(_)
        else:
            self._path_to_output_directory = None

        # init the rest of the parameters; these are used throughout the
        # pipeline, so we compute them once to avoid redundant computations

        # GEOMETRY PARAMETERS
        path_to_geom = path.join(self.data.root_folder, self.data.geometry)

        self._set_param('geom',
                        geom.parse(path_to_geom, self.recordings.n_channels))

        # check dimensions of the geometry file
        n_channels_geom, _ = self.geom.shape

        if self.recordings.n_channels != n_channels_geom:
            raise ValueError('Number of channels in the geometry file ({}) '
                             'does not match the value in the configuration '
                             'file ({})'.format(n_channels_geom,
                                                self.recordings.n_channels))

        neigh_channels = geom.find_channel_neighbors(
            self.geom, self.recordings.spatial_radius)
        self._set_param('neigh_channels', neigh_channels)

        # spike size long (to cover full axonal propagation)
        spike_size = int(
            np.round(self.recordings.spike_size_ms *
                     self.recordings.sampling_rate / 1000))
        if spike_size % 2 == 0:
            spike_size += 1
        self._set_param('spike_size', spike_size)

        # spike size center
        if self.recordings.center_spike_size_ms is not None:
            center_spike_size = int(
                np.round(self.recordings.center_spike_size_ms *
                         self.recordings.sampling_rate / 1000))
            if center_spike_size % 2 == 0:
                center_spike_size += 1
        else:
            center_spike_size = int(np.copy(spike_size))
        self._set_param('center_spike_size', center_spike_size)

        # channel index for nn
        channel_index = geom.make_channel_index(self.neigh_channels,
                                                self.geom,
                                                steps=1)
        self._set_param('channel_index', channel_index)

        # spike size to nn
        if self.neuralnetwork.apply_nn:
            if self.neuralnetwork.training.spike_size_ms is None:
                detect_saved_file = torch.load(
                    self.neuralnetwork.detect.filename,
                    map_location=lambda storage, loc: storage)
                spike_size_nn_detector = detect_saved_file[
                    'temporal_filter1.0.weight'].shape[2]

                denoised_saved_file = torch.load(
                    self.neuralnetwork.denoise.filename,
                    map_location=lambda storage, loc: storage)
                spike_size_nn_denoiser = denoised_saved_file[
                    'out.weight'].shape[0]

                del detect_saved_file
                del denoised_saved_file
                torch.cuda.empty_cache()

                if spike_size_nn_detector != spike_size_nn_denoiser:
                    raise ValueError(
                        'Input spike sizes of the NN detector and denoiser '
                        'do not match; change the models')

                else:
                    spike_size_nn = spike_size_nn_detector
            else:
                spike_size_nn = int(
                    np.round(self.neuralnetwork.training.spike_size_ms *
                             self.recordings.sampling_rate / 1000))
                if spike_size_nn % 2 == 0:
                    spike_size_nn += 1
            self._set_param('spike_size_nn', spike_size_nn)
        else:
            self._set_param('spike_size_nn', center_spike_size)

        # torch devices
        devices = []
        if torch.cuda.is_available():
            n_processors = np.min(
                (torch.cuda.device_count(), self.resources.n_gpu_processors))
            for j in range(n_processors):
                devices.append(torch.device("cuda:{}".format(j)))
        if len(devices) == 0:
            devices = [torch.device("cpu")]
        self._set_param('torch_devices', devices)

        # compute the length of recording
        filename_dat = os.path.join(self.data.root_folder,
                                    self.data.recordings)
        filesize = os.path.getsize(filename_dat)
        dtype = np.dtype(self.recordings.dtype)
        rec_len = int(filesize / dtype.itemsize / self.recordings.n_channels)
        self._set_param('rec_len', rec_len)

        # final deconvolution chunk [start, end] in seconds
        if self.recordings.final_deconv_chunk is None:
            start = 0
            end = int(np.ceil(self.rec_len / self.recordings.sampling_rate))
        else:
            start = int(np.floor(self.recordings.final_deconv_chunk[0]))
            end = int(np.ceil(self.recordings.final_deconv_chunk[1]))
        self._set_param('final_deconv_chunk', [start, end])

        # clustering chunk [start, end] in seconds
        if self.recordings.clustering_chunk is None:
            start = 0
            end = int(np.ceil(self.rec_len / self.recordings.sampling_rate))
        else:
            start = int(np.floor(self.recordings.clustering_chunk[0]))
            end = int(np.ceil(self.recordings.clustering_chunk[1]))
        self._set_param('clustering_chunk', [start, end])
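
The recording length in Example #9 is recovered from the file size alone: rec_len = filesize / bytes_per_sample / n_channels. A worked example with made-up numbers (not an actual recording):

import numpy as np

# hypothetical recording: int16 samples, 384 channels, 60 s at 30 kHz
dtype = np.dtype('int16')            # 2 bytes per sample
n_channels = 384
filesize = 2 * 384 * 30000 * 60      # bytes on disk

rec_len = int(filesize / dtype.itemsize / n_channels)
print(rec_len)  # 1800000 samples, i.e. 60 s at 30 kHz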
Example #10
def test_can_compute_whiten_matrix(data, data_info, path_to_geometry):
    geometry = parse(path_to_geometry, data_info['n_channels'])
    neighbors = find_channel_neighbors(geometry, radius=70)
    channel_index = make_channel_index(neighbors, geometry)

    whiten._matrix(data, channel_index, data_info['spike_size'])
Example #11
def run(if_file_exists='skip'):
    """Preprocess pipeline: filtering, standarization and whitening filter

    This step (optionally) performs filtering on the data, standarizes it
    and computes a whitening filter. Filtering and standarized data are
    processed in chunks and written to disk.

    Parameters
    ----------
    if_file_exists: str, optional
        One of 'overwrite', 'abort', 'skip'. Controls the behavior for every
        generated file. 'overwrite' replaces the files if they exist,
        'abort' raises a ValueError if any file exists, and 'skip' skips
        the operation (and loads the existing files) if any of them exist

    Returns
    -------
    standarized_path: str
        Path to the standardized data binary file

    standarized_params: dict
        Standardized data parameters (dtype, number of channels, data order)

    whiten_filter: numpy.ndarray
        Whitening matrix

    Notes
    -----
    Running the preprocessor generates the following files in
    CONFIG.path_to_output_directory/preprocess/:

    * ``filtered.bin`` - Filtered recordings
    * ``filtered.yaml`` - Filtered recordings metadata
    * ``standarized.bin`` - Standardized recordings
    * ``standarized.yaml`` - Standardized recordings metadata
    * ``whitening.npy`` - Whitening filter

    Everything is run on CPU.

    Examples
    --------

    .. literalinclude:: ../../examples/pipeline/preprocess.py
    """

    logger = logging.getLogger(__name__)

    CONFIG = read_config()
    OUTPUT_DTYPE = CONFIG.preprocess.dtype
    output_directory = os.path.join(CONFIG.path_to_output_directory,
                                    'preprocess')

    logger.info(
        'Output dtype for transformed data will be {}'.format(OUTPUT_DTYPE))

    if not os.path.exists(output_directory):
        logger.info('Creating temporary folder: {}'.format(output_directory))
        os.makedirs(output_directory)
    else:
        logger.info('Temporary folder {} already exists, output will be '
                    'stored there'.format(output_directory))

    params = dict(dtype=CONFIG.recordings.dtype,
                  n_channels=CONFIG.recordings.n_channels,
                  data_order=CONFIG.recordings.order)

    # Generate params:
    standarized_path = os.path.join(output_directory, "standarized.bin")
    standarized_params = params
    standarized_params['dtype'] = 'float32'

    # Check if data already saved to disk and skip:
    if if_file_exists == 'skip':
        if os.path.exists(standarized_path):

            channel_index = make_channel_index(CONFIG.neigh_channels,
                                               CONFIG.geom, 2)

            # Cat: this is redundant, should save to disk/not recompute
            whiten_filter = whiten.matrix(standarized_path,
                                          standarized_params['dtype'],
                                          standarized_params['n_channels'],
                                          standarized_params['data_order'],
                                          channel_index,
                                          CONFIG.spike_size,
                                          CONFIG.resources.max_memory,
                                          output_directory,
                                          output_filename='whitening.npy',
                                          if_file_exists=if_file_exists)

            path_to_channel_index = os.path.join(output_directory,
                                                 "channel_index.npy")

            return str(standarized_path), standarized_params, whiten_filter

    # read config params
    multi_processing = CONFIG.resources.multi_processing
    n_processors = CONFIG.resources.n_processors
    n_sec_chunk = CONFIG.resources.n_sec_chunk
    n_channels = CONFIG.recordings.n_channels
    sampling_rate = CONFIG.recordings.sampling_rate

    # Read filter params
    low_frequency = CONFIG.preprocess.filter.low_pass_freq
    high_factor = CONFIG.preprocess.filter.high_factor
    order = CONFIG.preprocess.filter.order
    buffer_size = 200

    # compute len of recording
    filename_dat = os.path.join(CONFIG.data.root_folder,
                                CONFIG.data.recordings)
    fp = np.memmap(filename_dat, dtype='int16', mode='r')
    fp_len = fp.shape[0]

    # compute batch indexes
    indexes = np.arange(0, fp_len / n_channels, sampling_rate * n_sec_chunk)
    if indexes[-1] != fp_len / n_channels:
        indexes = np.hstack((indexes, fp_len / n_channels))

    idx_list = []
    for k in range(len(indexes) - 1):
        idx_list.append([
            indexes[k], indexes[k + 1], buffer_size,
            indexes[k + 1] - indexes[k] + buffer_size
        ])

    idx_list = np.int64(np.vstack(idx_list))
    proc_indexes = np.arange(len(idx_list))

    logger.info("# of chunks: %i", len(idx_list))

    # Make directory to hold filtered batch files:
    filtered_location = os.path.join(output_directory, "filtered_files")
    logger.info(filtered_location)
    if not os.path.exists(filtered_location):
        os.makedirs(filtered_location)

    # filter and standardize in one step
    if multi_processing:
        parmap.map(filter_standardize,
                   list(zip(idx_list, proc_indexes)),
                   low_frequency,
                   high_factor,
                   order,
                   sampling_rate,
                   buffer_size,
                   filename_dat,
                   n_channels,
                   output_directory,
                   processes=n_processors,
                   pm_pbar=True)
    else:
        for k in range(len(idx_list)):
            filter_standardize([idx_list[k], k], low_frequency, high_factor,
                               order, sampling_rate, buffer_size, filename_dat,
                               n_channels, output_directory)

    # Merge the chunk filtered files and delete the individual chunks
    merge_filtered_files(output_directory)

    # save yaml file with params
    path_to_yaml = standarized_path.replace('.bin', '.yaml')

    params = dict(dtype=standarized_params['dtype'],
                  n_channels=standarized_params['n_channels'],
                  data_order=standarized_params['data_order'])

    with open(path_to_yaml, 'w') as f:
        logger.info('Saving params...')
        yaml.dump(params, f)

    # TODO: this shouldn't be done here; it would be better to compute this
    # when initializing the config object and then access it from there
    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom, 2)

    # Cat: TODO: make the whitening batches much smaller; such large batches
    # are not needed
    # TODO: remove whiten_filter from the output arguments

    whiten_filter = whiten.matrix(
        standarized_path,
        standarized_params['dtype'],
        standarized_params['n_channels'],
        standarized_params['data_order'],
        channel_index,
        CONFIG.spike_size,
        # CONFIG.resources.max_memory,
        '50MB',
        output_directory,
        output_filename='whitening.npy',
        if_file_exists=if_file_exists)

    path_to_channel_index = os.path.join(output_directory, 'channel_index.npy')
    save_numpy_object(channel_index,
                      path_to_channel_index,
                      if_file_exists=if_file_exists,
                      name='Channel index')

    return str(standarized_path), standarized_params, whiten_filter
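
The batch boundaries in Example #11 are rows of [start, end, buffer, length + buffer] in samples. The same arithmetic as a standalone sketch with toy numbers (not tied to any file):

import numpy as np

# toy numbers: 500 samples at 100 Hz, 2 s chunks, 10-sample buffer
n_samples, sampling_rate, n_sec_chunk, buffer_size = 500, 100, 2, 10

indexes = np.arange(0, n_samples, sampling_rate * n_sec_chunk)
if indexes[-1] != n_samples:
    indexes = np.hstack((indexes, n_samples))

idx_list = np.int64([[start, end, buffer_size, end - start + buffer_size]
                     for start, end in zip(indexes[:-1], indexes[1:])])
print(idx_list)
# [[  0 200  10 210]
#  [200 400  10 210]
#  [400 500  10 110]]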
Example #12
def test_splitting_in_batches_does_not_affect(path_to_config,
                                              path_to_sample_pipeline_folder,
                                              make_tmp_folder,
                                              path_to_standardized_data):
    yass.set_config(path_to_config, make_tmp_folder)
    CONFIG = yass.read_config()

    PATH_TO_DATA = path_to_standardized_data

    with open(path.join(path_to_sample_pipeline_folder, 'preprocess',
                        'standardized.yaml')) as f:
        PARAMS = yaml.load(f)

    channel_index = make_channel_index(CONFIG.neigh_channels,
                                       CONFIG.geom)

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename

    # instantiate neural networks
    NND = NeuralNetDetector.load(detection_fname, detection_th,
                                 channel_index)
    triage = KerasModel(triage_fname,
                        allow_longer_waveform_length=True,
                        allow_more_channels=True)
    NNAE = AutoEncoder.load(ae_fname, input_tensor=NND.waveform_tf)

    bp = BatchProcessor(PATH_TO_DATA, PARAMS['dtype'], PARAMS['n_channels'],
                        PARAMS['data_order'], '100KB',
                        buffer_size=CONFIG.spike_size)

    out = ('spike_index', 'waveform')
    fn = neuralnetwork.apply.fix_indexes_spike_index

    # detector
    with tf.Session() as sess:
        # restore trained network weights into the session
        NND.restore(sess)

        res = bp.multi_channel_apply(NND.predict_recording,
                                     mode='memory',
                                     sess=sess,
                                     output_names=out,
                                     cleanup_function=fn)

    spike_index_new = np.concatenate([element[0] for element in res], axis=0)
    wfs = np.concatenate([element[1] for element in res], axis=0)

    idx_clean = triage.predict_with_threshold(wfs, triage_th)
    score = NNAE.predict(wfs)
    rot = NNAE.load_rotation()
    neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

    (score_clear_new,
        spike_index_clear_new) = post_processing(score,
                                                 spike_index_new,
                                                 idx_clean,
                                                 rot,
                                                 neighbors)

    with tf.Session() as sess:
        # restore trained network weights into the session
        NND.restore(sess)

        res = bp.multi_channel_apply(NND.predict_recording,
                                     mode='memory',
                                     sess=sess,
                                     output_names=('spike_index',
                                                   'waveform'),
                                     cleanup_function=fn)

    spike_index_batch, wfs = zip(*res)

    spike_index_batch = np.concatenate(spike_index_batch, axis=0)
    wfs = np.concatenate(wfs, axis=0)

    idx_clean = triage.predict_with_threshold(x=wfs,
                                              threshold=triage_th)

    score = NNAE.predict(wfs)
    rot = NNAE.load_rotation()
    neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

    (score_clear_batch,
        spike_index_clear_batch) = post_processing(score,
                                                   spike_index_batch,
                                                   idx_clean,
                                                   rot,
                                                   neighbors)
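
A side note on the yaml.load(f) calls in these tests: PyYAML 5.1+ warns when no Loader is given. A hedged equivalent using safe_load (generic PyYAML usage, with made-up parameter values):

import yaml

# made-up parameters standing in for the standardized.yaml contents
params_text = 'dtype: float32\nn_channels: 49\ndata_order: samples'
PARAMS = yaml.safe_load(params_text)   # or yaml.load(params_text, Loader=yaml.SafeLoader)
print(PARAMS['n_channels'])  # 49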
Example #13
def run(output_directory='tmp/', if_file_exists='skip'):
    """Preprocess pipeline: filtering, standarization and whitening filter

    This step (optionally) performs filtering on the data, standarizes it
    and computes a whitening filter. Filtering and standarized data are
    processed in chunks and written to disk.

    Parameters
    ----------
    output_directory: str, optional
        Location to store results, relative to CONFIG.data.root_folder,
        defaults to tmp/. See list of files in Notes section below.

    if_file_exists: str, optional
        One of 'overwrite', 'abort', 'skip'. Controls the behavior for every
        generated file. 'overwrite' replaces the files if they exist,
        'abort' raises a ValueError if any file exists, and 'skip' skips
        the operation (and loads the existing files) if any of them exist

    Returns
    -------
    standarized_path: str
        Path to the standardized data binary file

    standarized_params: dict
        Standardized data parameters (dtype, number of channels, data order)

    channel_index: numpy.ndarray
        Channel index array (neighboring channels for each channel)

    whiten_filter: numpy.ndarray
        Whitening matrix

    Notes
    -----
    Running the preprocessor generates the following files in
    CONFIG.data.root_folder/output_directory/:

    * ``preprocess/filtered.bin`` - Filtered recordings
    * ``preprocess/filtered.yaml`` - Filtered recordings metadata
    * ``preprocess/standarized.bin`` - Standardized recordings
    * ``preprocess/standarized.yaml`` - Standardized recordings metadata
    * ``preprocess/whitening.npy`` - Whitening filter

    Everything is run on CPU.

    Examples
    --------

    .. literalinclude:: ../../examples/pipeline/preprocess.py
    """

    logger = logging.getLogger(__name__)

    CONFIG = read_config()
    OUTPUT_DTYPE = CONFIG.preprocess.dtype
    PROCESSES = CONFIG.resources.processes

    logger.info(
        'Output dtype for transformed data will be {}'.format(OUTPUT_DTYPE))

    TMP = Path(CONFIG.data.root_folder, output_directory, 'preprocess/')
    TMP.mkdir(parents=True, exist_ok=True)
    TMP = str(TMP)

    path = os.path.join(CONFIG.data.root_folder, CONFIG.data.recordings)
    params = dict(dtype=CONFIG.recordings.dtype,
                  n_channels=CONFIG.recordings.n_channels,
                  data_order=CONFIG.recordings.order)

    # filter and standarize
    if CONFIG.preprocess.apply_filter:
        filter_params = CONFIG.preprocess.filter

        (standarized_path,
         standarized_params) = butterworth(path,
                                           params['dtype'],
                                           params['n_channels'],
                                           params['data_order'],
                                           filter_params.low_pass_freq,
                                           filter_params.high_factor,
                                           filter_params.order,
                                           CONFIG.recordings.sampling_rate,
                                           CONFIG.resources.max_memory,
                                           TMP,
                                           OUTPUT_DTYPE,
                                           standarize=True,
                                           output_filename='standarized.bin',
                                           if_file_exists=if_file_exists,
                                           processes=PROCESSES)
    # just standarize
    else:
        (standarized_path,
         standarized_params) = standarize(path,
                                          params['dtype'],
                                          params['n_channels'],
                                          params['data_order'],
                                          CONFIG.recordings.sampling_rate,
                                          CONFIG.resources.max_memory,
                                          TMP,
                                          OUTPUT_DTYPE,
                                          output_filename='standarized.bin',
                                          if_file_exists=if_file_exists,
                                          processes=PROCESSES)

    # TODO: this shouldn't be done here; it would be better to compute this
    # when initializing the config object and then access it from there
    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom, 2)

    # TODO: remove whiten_filter out of output argument
    whiten_filter = whiten.matrix(standarized_path,
                                  standarized_params['dtype'],
                                  standarized_params['n_channels'],
                                  standarized_params['data_order'],
                                  channel_index,
                                  CONFIG.spike_size,
                                  CONFIG.resources.max_memory,
                                  TMP,
                                  output_filename='whitening.npy',
                                  if_file_exists=if_file_exists)

    path_to_channel_index = os.path.join(TMP, 'channel_index.npy')
    save_numpy_object(channel_index,
                      path_to_channel_index,
                      if_file_exists=if_file_exists,
                      name='Channel index')

    return (str(standarized_path), standarized_params, channel_index,
            whiten_filter)
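
The 'overwrite' / 'abort' / 'skip' options documented in Examples #11 and #13 follow a simple three-way pattern. A hypothetical helper (not part of yass) sketching that logic:

import os


def resolve_existing_file(path_to_file, if_file_exists):
    """Hypothetical helper illustrating the documented 'if_file_exists' options."""
    if not os.path.exists(path_to_file):
        return 'compute'            # nothing on disk yet
    if if_file_exists == 'overwrite':
        return 'compute'            # recompute and replace the existing file
    if if_file_exists == 'skip':
        return 'load'               # reuse what is already on disk
    if if_file_exists == 'abort':
        raise ValueError('{} already exists'.format(path_to_file))
    raise ValueError("if_file_exists must be 'overwrite', 'abort' or 'skip'")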
Example #14
def test_splitting_in_batches_does_not_affect_result(path_to_tests):
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    PATH_TO_DATA = path.join(path_to_tests, 'data/standarized.bin')

    data = RecordingsReader(PATH_TO_DATA, loader='array').data

    with open(path.join(path_to_tests, 'data/standarized.yaml')) as f:
        PARAMS = yaml.load(f)

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    whiten_filter = np.tile(
        np.eye(channel_index.shape[1], dtype='float32')[np.newaxis, :, :],
        [channel_index.shape[0], 1, 1])

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename
    (x_tf, output_tf, NND, NNAE, NNT) = neuralnetwork.prepare_nn(
        channel_index,
        whiten_filter,
        detection_th,
        triage_th,
        detection_fname,
        ae_fname,
        triage_fname,
    )

    # run all at once
    with tf.Session() as sess:
        # restore trained network weights into the session
        NND.saver.restore(sess, NND.path_to_detector_model)
        NNAE.saver_ae.restore(sess, NNAE.path_to_ae_model)
        NNT.saver.restore(sess, NNT.path_to_triage_model)
        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        (scores, clear, collision) = neuralnetwork.run_detect_triage_featurize(
            data, sess, x_tf, output_tf, neighbors, rot)

    # run in batches - buffer size makes sure we can detect spikes if they
    # appear at the end of any batch
    bp = BatchProcessor(PATH_TO_DATA,
                        PARAMS['dtype'],
                        PARAMS['n_channels'],
                        PARAMS['data_order'],
                        '100KB',
                        buffer_size=CONFIG.spike_size)

    with tf.Session() as sess:
        # restore trained network weights into the session
        NND.saver.restore(sess, NND.path_to_detector_model)
        NNAE.saver_ae.restore(sess, NNAE.path_to_ae_model)
        NNT.saver.restore(sess, NNT.path_to_triage_model)

        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        res = bp.multi_channel_apply(
            neuralnetwork.run_detect_triage_featurize,
            mode='memory',
            cleanup_function=neuralnetwork.fix_indexes,
            sess=sess,
            x_tf=x_tf,
            output_tf=output_tf,
            rot=rot,
            neighbors=neighbors)

    scores_batch = np.concatenate([element[0] for element in res], axis=0)
    clear_batch = np.concatenate([element[1] for element in res], axis=0)
    collision_batch = np.concatenate([element[2] for element in res], axis=0)

    np.testing.assert_array_equal(clear_batch, clear)
    np.testing.assert_array_equal(collision_batch, collision)
    np.testing.assert_array_equal(scores_batch, scores)
Example #15
def test_splitting_in_batches_does_not_affect(path_to_tests,
                                              path_to_standarized_data,
                                              path_to_sample_pipeline_folder):
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    PATH_TO_DATA = path_to_standarized_data

    data = RecordingsReader(PATH_TO_DATA, loader='array').data

    with open(
            path.join(path_to_sample_pipeline_folder, 'preprocess',
                      'standarized.yaml')) as f:
        PARAMS = yaml.load(f)

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename

    # instantiate neural networks
    NND = NeuralNetDetector.load(detection_fname, detection_th, channel_index)
    NNT = NeuralNetTriage.load(triage_fname,
                               triage_th,
                               input_tensor=NND.waveform_tf)
    NNAE = AutoEncoder(ae_fname, input_tensor=NND.waveform_tf)

    output_tf = (NNAE.score_tf, NND.spike_index_tf, NNT.idx_clean)

    # run all at once
    with tf.Session() as sess:
        # restore trained network weights into the session
        NND.restore(sess)
        NNAE.restore(sess)
        NNT.restore(sess)

        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        (scores, clear, collision) = neuralnetwork.run_detect_triage_featurize(
            data, sess, NND.x_tf, output_tf, neighbors, rot)

    # run in batches - buffer size makes sure we can detect spikes if they
    # appear at the end of any batch
    bp = BatchProcessor(PATH_TO_DATA,
                        PARAMS['dtype'],
                        PARAMS['n_channels'],
                        PARAMS['data_order'],
                        '100KB',
                        buffer_size=CONFIG.spike_size)

    with tf.Session() as sess:
        # restore trained network weights into the session
        NND.restore(sess)
        NNAE.restore(sess)
        NNT.restore(sess)

        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        res = bp.multi_channel_apply(
            neuralnetwork.run_detect_triage_featurize,
            mode='memory',
            cleanup_function=neuralnetwork.fix_indexes,
            sess=sess,
            x_tf=NND.x_tf,
            output_tf=output_tf,
            rot=rot,
            neighbors=neighbors)

    scores_batch = np.concatenate([element[0] for element in res], axis=0)
    clear_batch = np.concatenate([element[1] for element in res], axis=0)
    collision_batch = np.concatenate([element[2] for element in res], axis=0)

    np.testing.assert_array_equal(clear_batch, clear)
    np.testing.assert_array_equal(collision_batch, collision)
    np.testing.assert_array_equal(scores_batch, scores)
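
Examples #14 and #15 both assert that a full-recording run and a batched run produce identical outputs. The same idea with a pure-NumPy stand-in for the network (an elementwise op, so no buffering is needed), just to illustrate the check itself:

import numpy as np

# stand-in "detector": an elementwise op gives identical results per sample
def fake_featurize(chunk):
    return chunk * 2.0

data = np.random.randn(1000, 4).astype('float32')

# run all at once
full = fake_featurize(data)

# run in batches and concatenate
batched = np.concatenate(
    [fake_featurize(b) for b in np.array_split(data, 7, axis=0)], axis=0)

np.testing.assert_array_equal(batched, full)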