Example #1
def test_can_use_neural_network_detector(path_to_tests):
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    data = RecordingsReader(path.join(path_to_tests, 'data/standarized.bin'),
                            loader='array').data

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

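    # identity whitening filter: one identity block per channel neighborhood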
    whiten_filter = np.tile(
        np.eye(channel_index.shape[1], dtype='float32')[np.newaxis, :, :],
        [channel_index.shape[0], 1, 1])

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename

    (x_tf, output_tf, NND, NNAE,
     NNT) = neuralnetwork.prepare_nn(channel_index, whiten_filter,
                                     detection_th, triage_th, detection_fname,
                                     ae_fname, triage_fname)

    with tf.Session() as sess:
        # get values of above tensors
        NND.saver.restore(sess, NND.path_to_detector_model)
        NNAE.saver_ae.restore(sess, NNAE.path_to_ae_model)
        NNT.saver.restore(sess, NNT.path_to_triage_model)
        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        neuralnetwork.run_detect_triage_featurize(data, sess, x_tf, output_tf,
                                                  neighbors, rot)
Example #2
def templates_uncropped(path_to_config, make_tmp_folder,
                        path_to_sample_pipeline_folder,
                        path_to_standarized_data):
    spike_train = np.array([100, 0, 150, 0, 200, 1, 250, 1, 300, 2, 350,
                            2]).reshape(-1, 2)

    yass.set_config(path_to_config, make_tmp_folder)
    CONFIG = yass.read_config()

    spike_train = np.load(
        path.join(path_to_sample_pipeline_folder,
                  'spike_train_post_deconv_post_merge.npy'))

    n_spikes, _ = spike_train.shape

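    # append a column of ones so every spike gets the same weight in get_templates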
    weighted_spike_train = np.hstack(
        (spike_train, np.ones((n_spikes, 1), 'int32')))

    templates_uncropped, _ = get_templates(weighted_spike_train,
                                           path_to_standarized_data,
                                           CONFIG.resources.max_memory,
                                           4 * CONFIG.spike_size)

    templates_uncropped_ = np.transpose(templates_uncropped, (2, 1, 0))

    return templates_uncropped_
Example #3
def test_can_make_noise(path_to_tests, path_to_standarized_data):
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    n_spikes, _ = spike_train.shape

    weighted_spike_train = np.hstack(
        (spike_train, np.ones((n_spikes, 1), 'int32')))

    templates_uncropped, _ = get_templates(weighted_spike_train,
                                           path_to_standarized_data,
                                           CONFIG.resources.max_memory,
                                           4 * CONFIG.spike_size)

    templates_uncropped = np.transpose(templates_uncropped, (2, 1, 0))

    templates = crop_and_align_templates(templates_uncropped,
                                         CONFIG.spike_size,
                                         CONFIG.neigh_channels, CONFIG.geom)

    spatial_SIG, temporal_SIG = noise_cov(path_to_standarized_data,
                                          CONFIG.neigh_channels, CONFIG.geom,
                                          templates.shape[1])

    x_clean = make_clean(templates, min_amp=2, max_amp=10, nk=100)

    make_noise(x_clean,
               noise_ratio=10,
               templates=templates,
               spatial_SIG=spatial_SIG,
               temporal_SIG=temporal_SIG)
Example #4
def test_can_reload_detector(path_to_tests, path_to_sample_pipeline_folder,
                             tmp_folder):
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    (x_detect, y_detect, x_triage, y_triage, x_ae,
     y_ae) = make_training_data(CONFIG, spike_train, chosen_templates,
                                min_amplitude, n_spikes,
                                path_to_sample_pipeline_folder)

    _, waveform_length, n_neighbors = x_detect.shape

    path_to_model = path.join(tmp_folder, 'detect-net.ckpt')

    detector = NeuralNetDetector(path_to_model,
                                 filters,
                                 waveform_length,
                                 n_neighbors,
                                 threshold=0.5,
                                 channel_index=CONFIG.channel_index,
                                 n_iter=10)

    detector.fit(x_detect, y_detect)

    NeuralNetDetector.load(path_to_model,
                           threshold=0.5,
                           channel_index=CONFIG.channel_index)
Example #5
def test_can_use_neural_network_detector(path_to_tests,
                                         path_to_standarized_data):
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    data = RecordingsReader(path_to_standarized_data, loader='array').data

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename

    # instantiate neural networks
    NND = NeuralNetDetector.load(detection_fname, detection_th, channel_index)
    NNT = NeuralNetTriage.load(triage_fname,
                               triage_th,
                               input_tensor=NND.waveform_tf)
    NNAE = AutoEncoder(ae_fname, input_tensor=NND.waveform_tf)

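    # per-batch outputs: autoencoder scores, detected spike indices and the
    # triage network's clean-spike indicator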
    output_tf = (NNAE.score_tf, NND.spike_index_tf, NNT.idx_clean)

    with tf.Session() as sess:
        NND.restore(sess)
        NNAE.restore(sess)
        NNT.restore(sess)

        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        neuralnetwork.run_detect_triage_featurize(data, sess, NND.x_tf,
                                                  output_tf, neighbors, rot)
Example #6
def test_can_make_misaligned(path_to_tests, path_to_standarized_data):
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    n_spikes, _ = spike_train.shape

    weighted_spike_train = np.hstack(
        (spike_train, np.ones((n_spikes, 1), 'int32')))

    templates_uncropped, _ = get_templates(weighted_spike_train,
                                           path_to_standarized_data,
                                           CONFIG.resources.max_memory,
                                           4 * CONFIG.spike_size)

    templates_uncropped = np.transpose(templates_uncropped, (2, 1, 0))

    x_clean = make_clean(templates_uncropped, min_amp=2, max_amp=10, nk=100)

    make_misaligned(x_clean,
                    templates_uncropped,
                    max_shift=2 * CONFIG.spike_size,
                    misalign_ratio=1,
                    misalign_ratio2=1,
                    multi=True,
                    nneigh=templates_uncropped.shape[2])
Example #7
def run_neural_network(standardized_path,
                       standardized_dtype,
                       output_directory,
                       run_chunk_sec='full'):
    """Run neural network detection
    """
    logger = logging.getLogger(__name__)

    CONFIG = read_config()

    # load NN detector
    detector = Detect(CONFIG.neuralnetwork.detect.n_filters,
                      CONFIG.spike_size_nn, CONFIG.channel_index)
    detector.load(CONFIG.neuralnetwork.detect.filename)

    # load NN denoiser
    denoiser = Denoise(CONFIG.neuralnetwork.denoise.n_filters,
                       CONFIG.neuralnetwork.denoise.filter_sizes,
                       CONFIG.spike_size_nn)
    denoiser.load(CONFIG.neuralnetwork.denoise.filename)

    # get data reader
    batch_length = CONFIG.resources.n_sec_chunk * CONFIG.resources.n_processors
    n_sec_chunk = CONFIG.resources.n_sec_chunk_gpu_detect
    print("   batch length to (sec): ", batch_length,
          " (longer increase speed a bit)")
    print("   length of each seg (sec): ", n_sec_chunk)
    buffer = CONFIG.spike_size_nn
    if run_chunk_sec == 'full':
        chunk_sec = None
    else:
        chunk_sec = run_chunk_sec

    reader = READER(standardized_path, standardized_dtype, CONFIG,
                    batch_length, buffer, chunk_sec)

    # neighboring channels
    channel_index_dedup = make_channel_index(CONFIG.neigh_channels,
                                             CONFIG.geom,
                                             steps=2)

    # threshold for neuralnet detection
    detect_threshold = CONFIG.detect.threshold

    # loop over each chunk
    batch_ids = np.arange(reader.n_batches)
    batch_ids_split = np.split(batch_ids, len(CONFIG.torch_devices))
    processes = []
    for ii, device in enumerate(CONFIG.torch_devices):
        p = mp.Process(target=run_nn_detction_batch,
                       args=(batch_ids_split[ii], output_directory, reader,
                             n_sec_chunk, detector, denoiser,
                             channel_index_dedup, detect_threshold, device))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
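
A minimal usage sketch for run_neural_network above (paths are hypothetical;
the function assumes a YASS configuration has already been set with
yass.set_config so that read_config() succeeds):

    yass.set_config('config_nnet.yaml', 'tmp/')
    run_neural_network('tmp/preprocess/standardized.bin',
                       'float32',
                       'tmp/detect',
                       run_chunk_sec='full')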
Example #8
def test_can_make_training_data(path_to_tests, path_to_sample_pipeline_folder):
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    make_training_data(CONFIG,
                       spike_train,
                       chosen_templates,
                       min_amplitude,
                       n_spikes_to_make,
                       data_folder=path_to_sample_pipeline_folder)
Example #9
def run_template_update(output_directory,
                        fname_templates,
                        fname_spike_train,
                        fname_shifts,
                        fname_scales,
                        fname_residual,
                        residual_dtype,
                        residual_offset=0,
                        update_weight=50,
                        units_to_update=None):

    fname_templates_out = os.path.join(output_directory, 'templates.npy')

    if not os.path.exists(fname_templates_out):

        print('updating templates')

        CONFIG = read_config()

        # output folder
        if not os.path.exists(output_directory):
            os.makedirs(output_directory)

        # reader
        if CONFIG.deconvolution.deconv_gpu:
            n_sec_chunk = CONFIG.resources.n_sec_chunk_gpu_deconv
        else:
            n_sec_chunk = CONFIG.resources.n_sec_chunk
        reader_residual = READER(fname_residual,
                                 residual_dtype,
                                 CONFIG,
                                 n_sec_chunk,
                                 offset=residual_offset)

        # residual obj that can shift templates in gpu
        residual_comp = RESIDUAL_GPU2(None, CONFIG, None, None, None, None,
                                      None, None, None, None, True)
        residual_comp.load_templates(fname_templates)
        residual_comp.make_bsplines_parallel()

        avg_min_max_vals, weights = get_avg_min_max_vals(
            fname_templates, fname_spike_train, fname_shifts, fname_scales,
            reader_residual, residual_comp, units_to_update)

        templates_updated = update_templates(fname_templates, weights,
                                             avg_min_max_vals, update_weight,
                                             units_to_update)

        np.save(fname_templates_out, templates_updated)

    return fname_templates_out
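
A usage sketch for run_template_update; every file name below is hypothetical
and only illustrates the expected argument order (templates, spike train,
shifts, scales, residual recording and its dtype):

    fname_updated = run_template_update('tmp/template_update',
                                        'tmp/templates.npy',
                                        'tmp/spike_train.npy',
                                        'tmp/shifts.npy',
                                        'tmp/scales.npy',
                                        'tmp/residual.bin',
                                        'float32')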
Example #10
def run(fname_recording, recording_dtype, fname_spike_train,
        output_directory):
    """Prepare detector and denoiser training data from a spike train"""

    logger = logging.getLogger(__name__)

    CONFIG = read_config()

    # make output directory if not exist
    if not os.path.exists(output_directory):
        os.mkdir(output_directory)

    # get reader
    reader = READER(fname_recording,
                    recording_dtype,
                    CONFIG)
    reader.spike_size = CONFIG.spike_size_nn

    # get noise covariance
    logger.info('Compute Noise Covariance')
    save_dir = os.path.join(output_directory, 'noise_cov')
    chunk = [0, np.min((5*60*reader.sampling_rate, reader.end))]
    fname_spatial_sig, fname_temporal_sig = get_noise_covariance(
        reader, save_dir, CONFIG, chunk)
    
    # get processed templates
    logger.info('Crop Templates')
    save_dir = os.path.join(output_directory, 'templates')
    fname_templates_snippets = get_templates_on_local_channels(
        reader, save_dir, fname_spike_train, CONFIG)

    # denoise templates
    fname_templates_denoised = denoise_templates(
        fname_templates_snippets, save_dir)

    # make training data
    logger.info('Make Training Data')
    DetectTD = Detection_Training_Data(
        fname_templates_denoised,
        fname_spatial_sig,
        fname_temporal_sig)
    
    DenoTD = Denoising_Training_Data(
        fname_templates_denoised,
        fname_spatial_sig,
        fname_temporal_sig)
    
    return DetectTD, DenoTD
Example #11
def run_voltage_treshold(standardized_path,
                         standardized_dtype,
                         output_directory,
                         run_chunk_sec='full'):
    """Run detection that thresholds on amplitude
    """
    logger = logging.getLogger(__name__)

    CONFIG = read_config()

    # get data reader
    #n_sec_chunk = CONFIG.resources.n_sec_chunk*CONFIG.resources.n_processors
    batch_length = CONFIG.resources.n_sec_chunk
    n_sec_chunk = 0.5
    print("   batch length to (sec): ", batch_length,
          " (longer increase speed a bit)")
    print("   length of each seg (sec): ", n_sec_chunk)
    buffer = CONFIG.spike_size
    if run_chunk_sec == 'full':
        chunk_sec = None
    else:
        chunk_sec = run_chunk_sec

    reader = READER(standardized_path, standardized_dtype, CONFIG,
                    batch_length, buffer, chunk_sec)

    # number of processed chunks
    n_mini_per_big_batch = int(np.ceil(batch_length / n_sec_chunk))
    total_processing = int(reader.n_batches * n_mini_per_big_batch)

    # neighboring channels
    channel_index = make_channel_index(CONFIG.neigh_channels,
                                       CONFIG.geom,
                                       steps=2)

    if CONFIG.resources.multi_processing:
        parmap.starmap(run_voltage_threshold_parallel,
                       list(zip(np.arange(reader.n_batches))),
                       reader,
                       n_sec_chunk,
                       CONFIG.detect.threshold,
                       channel_index,
                       output_directory,
                       processes=CONFIG.resources.n_processors,
                       pm_pbar=True)
    else:
        for batch_id in range(reader.n_batches):
            run_voltage_threshold_parallel(batch_id, reader, n_sec_chunk,
                                           CONFIG.detect.threshold,
                                           channel_index, output_directory)
Example #12
def test_can_use_detect_and_triage_after_reload(path_to_tests,
                                                path_to_sample_pipeline_folder,
                                                tmp_folder,
                                                path_to_standarized_data):
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    (x_detect, y_detect, x_triage, y_triage, x_ae,
     y_ae) = make_training_data(CONFIG, spike_train, chosen_templates,
                                min_amplitude, n_spikes,
                                path_to_sample_pipeline_folder)

    _, waveform_length, n_neighbors = x_detect.shape

    path_to_model = path.join(tmp_folder, 'detect-net.ckpt')

    detector = NeuralNetDetector(path_to_model,
                                 filters,
                                 waveform_length,
                                 n_neighbors,
                                 threshold=0.5,
                                 channel_index=CONFIG.channel_index,
                                 n_iter=10)

    detector.fit(x_detect, y_detect)

    detector = NeuralNetDetector.load(path_to_model,
                                      threshold=0.5,
                                      channel_index=CONFIG.channel_index)

    triage = NeuralNetTriage(path_to_model,
                             filters,
                             waveform_length,
                             n_neighbors,
                             threshold=0.5,
                             n_iter=10)

    triage.fit(x_detect, y_detect)

    triage = NeuralNetTriage.load(path_to_model, threshold=0.5)

    data = RecordingExplorer(path_to_standarized_data).reader.data

    output_names = ('spike_index', 'waveform', 'probability')

    (spike_index, waveform,
     proba) = detector.predict(data, output_names=output_names)

    triage.predict(waveform[:, :, :n_neighbors])
Example #13
def run():
    """RF computation
    """

    CONFIG = read_config()

    stim_movie_file = os.path.join(CONFIG.data.root_folder, CONFIG.data.stimulus)
    triggers_fname = os.path.join(CONFIG.data.root_folder, CONFIG.data.triggers)
    spike_train_fname = os.path.join(CONFIG.path_to_output_directory,
                                     'spike_train.npy')
    saving_dir = os.path.join(CONFIG.path_to_output_directory, 'rf')
    
    rf = RF(stim_movie_file, triggers_fname, spike_train_fname, saving_dir)
    rf.calculate_STA()
    rf.detect_multi_rf()
    rf.classification()
Example #14
def covariance(recordings, temporal_size, neigbor_steps):
    """Compute noise spatial and temporal covariance

    Parameters
    ----------
    recordings: matrix
        Multi-channel recordings (n observations x n channels)
    temporal_size: int
        Size of the temporal window used for the covariance estimate

    neigbor_steps: int
        Number of steps from the multi-channel geometry to consider two
        channels as neighbors
    """
    CONFIG = read_config()

    # get the neighbor channels at most "neigbor_steps" steps away
    neigh_channels = n_steps_neigh_channels(CONFIG.neighChannels,
                                            neigbor_steps)

    # sum neighbor flags for every channel, this gives the number of neighbors
    # per channel, then find the channel with the most neighbors
    # TODO: why are we selecting this one?
    channel = np.argmax(np.sum(neigh_channels, 0))

    # get the neighbor channels for "channel"
    (neighbords_idx, ) = np.where(neigh_channels[channel])

    # order neighbors by distance
    neighbords_idx, temp = order_channels_by_distance(channel, neighbords_idx,
                                                      CONFIG.geom)

    # from the multi-channel recordings, get the neighbor channels
    # (this includes the channel with the most neighbors itself)
    rec = recordings[:, neighbords_idx]

    # filter recording
    if CONFIG.preprocess.filter == 1:
        rec = butterworth(rec, CONFIG.filter.low_pass_freq,
                          CONFIG.filter.high_factor, CONFIG.filter.order,
                          CONFIG.recordings.sampling_rate)

    # standardize recording
    sd_ = standarize.sd(rec, CONFIG.recordings.sampling_rate)
    rec = standarize.standarize(rec, sd_)

    # compute and return spatial and temporal covariance
    return util.covariance(rec, temporal_size, neigbor_steps, CONFIG.spikeSize)
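
A usage sketch for covariance; the recordings file and the window sizes are
illustrative only, and a YASS configuration must already be set so that
read_config() inside the function succeeds:

    # recordings: (n observations x n channels) array
    recordings = np.load('standarized_recordings.npy')
    noise_cov = covariance(recordings, temporal_size=40, neigbor_steps=2)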
Example #15
def test_can_make_clean(path_to_tests, path_to_standarized_data):
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    n_spikes, _ = spike_train.shape

    weighted_spike_train = np.hstack(
        (spike_train, np.ones((n_spikes, 1), 'int32')))

    templates_uncropped, _ = get_templates(weighted_spike_train,
                                           path_to_standarized_data,
                                           CONFIG.resources.max_memory,
                                           4 * CONFIG.spike_size)

    templates_uncropped = np.transpose(templates_uncropped, (2, 1, 0))

    make_clean(templates_uncropped, min_amp=2, max_amp=10, nk=100)
Example #16
def test_can_use_detector_after_fit(path_to_config,
                                    path_to_sample_pipeline_folder,
                                    make_tmp_folder,
                                    path_to_standardized_data):
    yass.set_config(path_to_config, make_tmp_folder)
    CONFIG = yass.read_config()

    spike_train = np.load(path.join(path_to_sample_pipeline_folder,
                                    'spike_train.npy'))
    chosen_templates = np.unique(spike_train[:, 1])
    min_amplitude = 4
    max_amplitude = 60
    n_spikes_to_make = 100

    templates = make.load_templates(path_to_sample_pipeline_folder,
                                    spike_train, CONFIG, chosen_templates)

    path_to_standardized = path.join(path_to_sample_pipeline_folder,
                                     'preprocess', 'standarized.bin')

    (x_detect, y_detect,
     x_triage, y_triage,
     x_ae, y_ae) = make.training_data(CONFIG, templates,
                                      min_amplitude, max_amplitude,
                                      n_spikes_to_make,
                                      path_to_standardized)

    _, waveform_length, n_neighbors = x_detect.shape

    path_to_model = path.join(make_tmp_folder, 'detect-net.ckpt')
    detector = NeuralNetDetector(path_to_model, [8, 4],
                                 waveform_length, n_neighbors,
                                 threshold=0.5,
                                 channel_index=CONFIG.channel_index,
                                 n_iter=10)
    detector.fit(x_detect, y_detect)

    data = RecordingExplorer(path_to_standardized_data).reader.data

    output_names = ('spike_index', 'waveform', 'probability')

    (spike_index, waveform,
        proba) = detector.predict_recording(data, output_names=output_names)

    detector.predict(x_detect)
Example #17
def test_can_crop_and_align_templates(path_to_tests, path_to_standarized_data):
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    n_spikes, _ = spike_train.shape

    weighted_spike_train = np.hstack(
        (spike_train, np.ones((n_spikes, 1), 'int32')))

    templates_uncropped, _ = get_templates(weighted_spike_train,
                                           path_to_standarized_data,
                                           CONFIG.resources.max_memory,
                                           4 * CONFIG.spike_size)

    templates_uncropped = np.transpose(templates_uncropped, (2, 1, 0))

    crop_and_align_templates(templates_uncropped, CONFIG.spike_size,
                             CONFIG.neigh_channels, CONFIG.geom)
Example #18
def test_deconvolution(patch_triage_network, path_to_config, make_tmp_folder):
    yass.set_config(path_to_config, make_tmp_folder)

    (standarized_path, standarized_params, whiten_filter) = preprocess.run()

    spike_index_all = detect.run(standarized_path, standarized_params,
                                 whiten_filter)

    cluster.run(None, spike_index_all)

    CONFIG = read_config()
    TMP_FOLDER = CONFIG.path_to_output_directory

    path_to_spike_train_cluster = path.join(TMP_FOLDER,
                                            'spike_train_cluster.npy')
    spike_train_cluster = np.load(path_to_spike_train_cluster)
    templates_cluster = np.load(path.join(TMP_FOLDER, 'templates_cluster.npy'))

    spike_train, postdeconv_templates = deconvolve.run(spike_train_cluster,
                                                       templates_cluster)
Example #19
def test_can_reload_detector(path_to_config,
                             path_to_sample_pipeline_folder,
                             make_tmp_folder):
    yass.set_config(path_to_config, make_tmp_folder)
    CONFIG = yass.read_config()

    spike_train = np.load(path.join(path_to_sample_pipeline_folder,
                                    'spike_train.npy'))
    chosen_templates = np.unique(spike_train[:, 1])
    min_amplitude = 4
    max_amplitude = 60
    n_spikes_to_make = 100

    templates = make.load_templates(path_to_sample_pipeline_folder,
                                    spike_train, CONFIG, chosen_templates)

    path_to_standarized = path.join(path_to_sample_pipeline_folder,
                                    'preprocess', 'standarized.bin')

    (x_detect, y_detect,
     x_triage, y_triage,
     x_ae, y_ae) = make.training_data(CONFIG, templates,
                                      min_amplitude, max_amplitude,
                                      n_spikes_to_make,
                                      path_to_standarized)

    _, waveform_length, n_neighbors = x_detect.shape

    path_to_model = path.join(make_tmp_folder, 'detect-net.ckpt')

    detector = NeuralNetDetector(path_to_model, [8, 4],
                                 waveform_length, n_neighbors,
                                 threshold=0.5,
                                 channel_index=CONFIG.channel_index,
                                 n_iter=10)

    detector.fit(x_detect, y_detect)

    NeuralNetDetector.load(path_to_model, threshold=0.5,
                           channel_index=CONFIG.channel_index)
Example #20
def get_o_layer(standarized_path,
                standarized_params,
                output_directory='tmp/',
                output_dtype='float32',
                output_filename='o_layer.bin',
                if_file_exists='skip',
                save_partial_results=False):
    """Get the output of NN detector instead of outputting the spike index
    """

    CONFIG = read_config()
    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom, 1)

    x_tf = tf.placeholder("float", [None, None])

    # load Neural Net's
    detection_fname = CONFIG.detect.neural_network_detector.filename
    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    NND = NeuralNetDetector(detection_fname)
    o_layer_tf = NND.make_o_layer_tf_tensors(x_tf, channel_index, detection_th)

    bp = BatchProcessor(standarized_path,
                        standarized_params['dtype'],
                        standarized_params['n_channels'],
                        standarized_params['data_format'],
                        CONFIG.resources.max_memory,
                        buffer_size=CONFIG.spike_size)

    TMP = os.path.join(CONFIG.data.root_folder, output_directory)
    _output_path = os.path.join(TMP, output_filename)
    (o_path, o_params) = bp.multi_channel_apply(_get_o_layer,
                                                mode='disk',
                                                cleanup_function=fix_indexes,
                                                output_path=_output_path,
                                                cast_dtype=output_dtype,
                                                x_tf=x_tf,
                                                o_layer_tf=o_layer_tf,
                                                NND=NND)

    return o_path, o_params
Example #21
def run_deduplication(batch_files_dir, output_directory):

    CONFIG = read_config()

    neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)
    w = 5

    batch_ids = list(np.arange(len(os.listdir(batch_files_dir))))
    if CONFIG.resources.multi_processing:
        #if False:
        parmap.map(run_deduplication_batch_simple,
                   batch_ids,
                   batch_files_dir,
                   output_directory,
                   neighbors,
                   w,
                   processes=CONFIG.resources.n_processors,
                   pm_pbar=True)
    else:
        for batch_id in batch_ids:
            run_deduplication_batch_simple(batch_id, batch_files_dir,
                                           output_directory, neighbors, w)
Example #22
File: run.py  Project: AkiHase/yass
def run(template_fname, spike_train_fname, shifts_fname, output_directory,
        residual_fname, residual_dtype):

    logger = logging.getLogger(__name__)

    CONFIG = read_config()

    #
    fname_out = os.path.join(output_directory, 'soft_assignment.npy')
    if os.path.exists(fname_out):
        return fname_out

    # output folder
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)

    # reader for residual
    reader_resid = READER(residual_fname, residual_dtype, CONFIG,
                          CONFIG.resources.n_sec_chunk_gpu_deconv)

    # load NN detector
    detector = Detect(CONFIG.neuralnetwork.detect.n_filters,
                      CONFIG.spike_size_nn, CONFIG.channel_index)
    detector.load(CONFIG.neuralnetwork.detect.filename)
    detector = detector.cuda()

    # initialize soft assignment calculator
    threshold = CONFIG.deconvolution.threshold / 0.1
    sna = SOFTNOISEASSIGNMENT(spike_train_fname, template_fname, shifts_fname,
                              reader_resid, detector, CONFIG.channel_index,
                              threshold)

    # compute soft assignment
    probs = sna.compute_soft_assignment()
    np.save(fname_out, probs)

    return fname_out
Example #23
from pathlib import Path
import logging
from datetime import datetime
from memory_profiler import profile
import yass
from yass import deconvolute
import settings

if __name__ == '__main__':
    settings.run()
    start = datetime.now()
    logger = logging.getLogger(__name__)

    CONFIG = yass.read_config()

    logger.info('Deconvolution started at second: %.2f',
                (datetime.now() - start).total_seconds())

    DIRECTORY = Path(CONFIG.data.root_folder, 'profiling')

    spike_index_all = str(DIRECTORY / 'spike_index_all.npy')
    templates = str(DIRECTORY / 'templates.npy')

    profile(deconvolute.run)(spike_index_all,
                             templates,
                             output_directory='profiling')

    logger.info('Deconvolution finished at second: %.2f',
                (datetime.now() - start).total_seconds())
Example #24
File: pipeline.py  Project: Nomow/yass
def run(config,
        logger_level='INFO',
        clean=False,
        output_dir='tmp/',
        complete=False,
        set_zero_seed=False):
    """Run YASS built-in pipeline

    Parameters
    ----------
    config: str or mapping (such as dictionary)
        Path to YASS configuration file or mapping object

    logger_level: str
        Logger level

    clean: bool, optional
        Delete CONFIG.data.root_folder/output_dir/ before running

    output_dir: str, optional
        Output directory to store the output data, defaults to tmp/. If
        relative, it is interpreted relative to CONFIG.data.root_folder; if
        absolute, it is used as is.

    complete: bool, optional
        Generates extra files (needed to generate phy files)

    Notes
    -----
    Running the preprocessor will generate the following files in
    CONFIG.data.root_folder/output_directory/:

    * ``config.yaml`` - Copy of the configuration file
    * ``metadata.yaml`` - Experiment metadata
    * ``filtered.bin`` - Filtered recordings (from preprocess)
    * ``filtered.yaml`` - Filtered recordings metadata (from preprocess)
    * ``standardized.bin`` - Standardized recordings (from preprocess)
    * ``standardized.yaml`` - Standardized recordings metadata (from preprocess)
    * ``whitening.npy`` - Whitening filter (from preprocess)


    Returns
    -------
    numpy.ndarray
        Spike train
    """

    # load yass configuration parameters
    set_config(config, output_dir)
    CONFIG = read_config()
    TMP_FOLDER = CONFIG.path_to_output_directory

    # remove tmp folder if needed
    if os.path.exists(TMP_FOLDER) and clean:
        shutil.rmtree(TMP_FOLDER)

    # create TMP_FOLDER if needed
    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    # load logging config file
    logging_config = load_logging_config_file()
    logging_config['handlers']['file']['filename'] = path.join(
        TMP_FOLDER, 'yass.log')
    logging_config['root']['level'] = logger_level

    # configure logging
    logging.config.dictConfig(logging_config)

    # instantiate logger
    logger = logging.getLogger(__name__)

    # print yass version
    logger.info('YASS version: %s', yass.__version__)
    ''' **********************************************
        ******** SET ENVIRONMENT VARIABLES ***********
        **********************************************
    '''
    os.environ["OPENBLAS_NUM_THREADS"] = "1"
    os.environ["MKL_NUM_THREADS"] = "1"
    os.environ["GIO_EXTRA_MODULES"] = "/usr/lib/x86_64-linux-gnu/gio/modules/"
    ''' **********************************************
        ************** PREPROCESS ********************
        **********************************************
    '''
    # preprocess
    start = time.time()
    (standardized_path, standardized_params, whiten_filter) = (preprocess.run(
        if_file_exists=CONFIG.preprocess.if_file_exists))

    time_preprocess = time.time() - start
    ''' **********************************************
        ************** DETECT EVENTS *****************
        **********************************************
    '''
    # detect
    # Cat: This code now runs with open tensorflow calls
    start = time.time()
    (spike_index_all) = detect.run(standardized_path,
                                   standardized_params,
                                   whiten_filter,
                                   if_file_exists=CONFIG.detect.if_file_exists,
                                   save_results=CONFIG.detect.save_results)
    spike_index_clear = None
    time_detect = time.time() - start
    ''' **********************************************
        ***************** CLUSTER ********************
        **********************************************
    '''

    # cluster
    start = time.time()
    path_to_spike_train_cluster = path.join(TMP_FOLDER,
                                            'spike_train_cluster.npy')
    if not os.path.exists(path_to_spike_train_cluster):
        cluster.run(spike_index_clear, spike_index_all)
    else:
        print("\nClustering completed previously...\n\n")

    spike_train_cluster = np.load(path_to_spike_train_cluster)
    templates_cluster = np.load(
        os.path.join(TMP_FOLDER, 'templates_cluster.npy'))

    time_cluster = time.time() - start
    #print ("Spike train clustered: ", spike_index_cluster.shape, "spike train clear: ",
    #        spike_train_clear.shape, " templates: ", templates.shape)
    ''' **********************************************
        ************** DECONVOLUTION *****************
        **********************************************
    '''

    # run deconvolution
    start = time.time()
    spike_train, postdeconv_templates = deconvolve.run(spike_train_cluster,
                                                       templates_cluster)
    time_deconvolution = time.time() - start

    # save spike train
    path_to_spike_train = path.join(TMP_FOLDER,
                                    'spike_train_post_deconv_post_merge.npy')
    np.save(path_to_spike_train, spike_train)
    logger.info('Spike train saved in: {}'.format(path_to_spike_train))

    # save template
    path_to_templates = path.join(TMP_FOLDER,
                                  'templates_post_deconv_post_merge.npy')
    np.save(path_to_templates, postdeconv_templates)
    logger.info('Templates saved in: {}'.format(path_to_templates))
    ''' **********************************************
        ************** POST PROCESSING****************
        **********************************************
    '''

    # save metadata in tmp
    path_to_metadata = path.join(TMP_FOLDER, 'metadata.yaml')
    logging.info('Saving metadata in {}'.format(path_to_metadata))
    save_metadata(path_to_metadata)

    # save config.yaml copy in tmp/
    path_to_config_copy = path.join(TMP_FOLDER, 'config.yaml')

    if isinstance(config, Mapping):
        with open(path_to_config_copy, 'w') as f:
            yaml.dump(config, f, default_flow_style=False)
    else:
        shutil.copy2(config, path_to_config_copy)

    logging.info('Saving copy of config: {} in {}'.format(
        config, path_to_config_copy))

    # this part loads waveforms for all spikes in the spike train and scores
    # them, this data is needed to later generate phy files
    if complete:
        STANDARIZED_PATH = path.join(TMP_FOLDER, 'standardized.bin')
        PARAMS = load_yaml(path.join(TMP_FOLDER, 'standardized.yaml'))

        # load waveforms for all spikes in the spike train
        logger.info('Loading waveforms from all spikes in the spike train...')
        explorer = RecordingExplorer(STANDARIZED_PATH,
                                     spike_size=CONFIG.spike_size,
                                     dtype=PARAMS['dtype'],
                                     n_channels=PARAMS['n_channels'],
                                     data_order=PARAMS['data_order'])
        waveforms = explorer.read_waveforms(spike_train[:, 0])

        path_to_waveforms = path.join(TMP_FOLDER, 'spike_train_waveforms.npy')
        np.save(path_to_waveforms, waveforms)
        logger.info('Saved all waveforms from the spike train in {}...'.format(
            path_to_waveforms))

        # score all waveforms
        logger.info('Scoring waveforms from all spikes in the spike train...')
        path_to_rotation = path.join(TMP_FOLDER, 'rotation.npy')
        rotation = np.load(path_to_rotation)

        main_channels = explorer.main_channel_for_waveforms(waveforms)
        path_to_main_channels = path.join(TMP_FOLDER,
                                          'waveforms_main_channel.npy')
        np.save(path_to_main_channels, main_channels)
        logger.info('Saved all waveforms main channels in {}...'.format(
            path_to_waveforms))

        waveforms_score = dim_red.score(waveforms, rotation, main_channels,
                                        CONFIG.neigh_channels, CONFIG.geom)
        path_to_waveforms_score = path.join(TMP_FOLDER, 'waveforms_score.npy')
        np.save(path_to_waveforms_score, waveforms_score)
        logger.info('Saved all scores in {}...'.format(path_to_waveforms))

        # score templates
        # TODO: templates should be returned in the right shape to avoid .T
        templates_ = templates.T
        main_channels_tmpls = explorer.main_channel_for_waveforms(templates_)
        path_to_templates_main_c = path.join(TMP_FOLDER,
                                             'templates_main_channel.npy')
        np.save(path_to_templates_main_c, main_channels_tmpls)
        logger.info('Saved all templates main channels in {}...'.format(
            path_to_templates_main_c))

        templates_score = dim_red.score(templates_, rotation,
                                        main_channels_tmpls,
                                        CONFIG.neigh_channels, CONFIG.geom)
        path_to_templates_score = path.join(TMP_FOLDER, 'templates_score.npy')
        np.save(path_to_templates_score, templates_score)
        logger.info(
            'Saved all templates scores in {}...'.format(path_to_waveforms))

    logger.info('Finished YASS execution. Timing summary:')
    total = (time_preprocess + time_detect + time_cluster + time_deconvolution)
    logger.info('\t Preprocess: %s (%.2f %%)',
                human_readable_time(time_preprocess),
                time_preprocess / total * 100)
    logger.info('\t Detection: %s (%.2f %%)', human_readable_time(time_detect),
                time_detect / total * 100)
    logger.info('\t Clustering: %s (%.2f %%)',
                human_readable_time(time_cluster), time_cluster / total * 100)
    logger.info('\t Deconvolution: %s (%.2f %%)',
                human_readable_time(time_deconvolution),
                time_deconvolution / total * 100)

    return spike_train
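
A usage sketch for this pipeline entry point (the configuration path and
output directory are hypothetical):

    spike_train = run('config.yaml',
                      logger_level='INFO',
                      clean=True,
                      output_dir='tmp/',
                      complete=False)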
Example #25
def run(config, logger_level='INFO', clean=False, output_dir='tmp/'):
    """Run YASS built-in pipeline

    Parameters
    ----------
    config: str or mapping (such as dictionary)
        Path to YASS configuration file or mapping object

    logger_level: str
        Logger level

    clean: bool, optional
        Delete CONFIG.data.root_folder/output_dir/ before running

    output_dir: str, optional
        Output directory to store the output data, defaults to tmp/. If
        relative, it is interpreted relative to CONFIG.data.root_folder; if
        absolute, it is used as is.

    complete: bool, optional
        Generates extra files (needed to generate phy files)

    Notes
    -----
    Running the preprocessor will generate the following files in
    CONFIG.data.root_folder/output_directory/:

    * ``config.yaml`` - Copy of the configuration file
    * ``metadata.yaml`` - Experiment metadata
    * ``filtered.bin`` - Filtered recordings (from preprocess)
    * ``filtered.yaml`` - Filtered recordings metadata (from preprocess)
    * ``standardized.bin`` - Standardized recordings (from preprocess)
    * ``standardized.yaml`` - Standardized recordings metadata (from preprocess)
    * ``whitening.npy`` - Whitening filter (from preprocess)


    Returns
    -------
    numpy.ndarray
        Spike train
    """

    # load yass configuration parameters
    CONFIG = Config.from_yaml(config)
    CONFIG._data['cluster']['min_fr'] = 1
    CONFIG._data['clean_up']['mad']['min_var_gap'] = 1.5
    CONFIG._data['clean_up']['mad']['max_violations'] = 5
    CONFIG._data['neuralnetwork']['apply_nn'] = False
    CONFIG._data['detect']['threshold'] = 4

    set_config(CONFIG._data, output_dir)
    CONFIG = read_config()
    TMP_FOLDER = CONFIG.path_to_output_directory

    # remove tmp folder if needed
    if os.path.exists(TMP_FOLDER) and clean:
        shutil.rmtree(TMP_FOLDER)

    # create TMP_FOLDER if needed
    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    # load logging config file
    logging_config = load_logging_config_file()
    logging_config['handlers']['file']['filename'] = os.path.join(
        TMP_FOLDER, 'yass.log')
    logging_config['root']['level'] = logger_level

    # configure logging
    logging.config.dictConfig(logging_config)

    # instantiate logger
    logger = logging.getLogger(__name__)

    # print yass version
    logger.info('YASS version: %s', yass.__version__)
    ''' **********************************************
        ******** SET ENVIRONMENT VARIABLES ***********
        **********************************************
    '''
    os.environ["OPENBLAS_NUM_THREADS"] = "1"
    os.environ["MKL_NUM_THREADS"] = "1"
    os.environ["GIO_EXTRA_MODULES"] = "/usr/lib/x86_64-linux-gnu/gio/modules/"

    # TODO: if input spike train is None, run yass with threshold detector
    #if fname_spike_train is None:
    #    logger.info('Not available yet. You must input spike train')
    #    return
    ''' **********************************************
        ************** PREPROCESS ********************
        **********************************************
    '''
    # preprocess
    start = time.time()
    (standardized_path, standardized_dtype) = preprocess.run(
        os.path.join(TMP_FOLDER, 'preprocess'))

    TMP_FOLDER = os.path.join(TMP_FOLDER, 'nn_train')
    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    if CONFIG.neuralnetwork.training.input_spike_train_filname is None:

        # run on 10 minutes of data
        rec_len = np.min(
            (CONFIG.rec_len, CONFIG.recordings.sampling_rate * 10 * 60))
        # detect
        logger.info('DETECTION')
        spike_index_path = detect.run(standardized_path,
                                      standardized_dtype,
                                      os.path.join(TMP_FOLDER, 'detect'),
                                      run_chunk_sec=[0, rec_len])

        logger.info('CLUSTERING')

        # cluster
        raw_data = True
        full_run = False
        fname_templates, fname_spike_train = cluster.run(
            os.path.join(TMP_FOLDER, 'cluster'),
            standardized_path,
            standardized_dtype,
            fname_spike_index=spike_index_path,
            raw_data=True,
            full_run=True)

        methods = [
            'off_center', 'low_ptp', 'high_mad', 'duplicate', 'duplicate_l2'
        ]
        fname_templates, fname_spike_train = postprocess.run(
            methods, os.path.join(TMP_FOLDER,
                                  'cluster_post_process'), standardized_path,
            standardized_dtype, fname_templates, fname_spike_train)

    else:
        # if there is an input spike train, use it
        fname_spike_train = CONFIG.neuralnetwork.training.input_spike_train_filname

    # Get training data maker
    DetectTD, DenoTD = augment.run(standardized_path, standardized_dtype,
                                   fname_spike_train,
                                   os.path.join(TMP_FOLDER, 'augment'))

    # Train Detector
    detector = Detect(CONFIG.neuralnetwork.detect.n_filters,
                      CONFIG.spike_size_nn, CONFIG.channel_index).cuda()
    detector.train(os.path.join(TMP_FOLDER, 'detect.pt'), DetectTD)

    # Train Denoiser
    denoiser = Denoise(CONFIG.neuralnetwork.denoise.n_filters,
                       CONFIG.neuralnetwork.denoise.filter_sizes,
                       CONFIG.spike_size_nn).cuda()
    denoiser.train(os.path.join(TMP_FOLDER, 'denoise.pt'), DenoTD)
Example #26
def run(output_directory, fname_spike_train, fname_shifts, fname_scales,
        fname_templates, fname_soft_assignment, fname_residual,
        residual_dtype):

    logger = logging.getLogger(__name__)

    CONFIG = read_config()

    # output folder
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)

    fname_spike_train_out = os.path.join(output_directory, 'spike_train.npy')
    fname_templates_out = os.path.join(output_directory, 'templates.npy')
    fname_soft_assignment_out = os.path.join(output_directory,
                                             'soft_assignment.npy')
    fname_shifts_out = os.path.join(output_directory, 'shifts.npy')
    fname_scales_out = os.path.join(output_directory, 'scales.npy')
    if os.path.exists(fname_spike_train_out) and os.path.exists(
            fname_templates_out):
        return (fname_templates_out, fname_spike_train_out, fname_shifts_out,
                fname_scales_out, fname_soft_assignment_out)

    reader_residual = READER(fname_residual, residual_dtype, CONFIG)

    # get whitening filters
    fname_spatial_cov = os.path.join(output_directory, 'spatial_cov.npy')
    fname_temporal_cov = os.path.join(output_directory, 'temporal_cov.npy')
    if not (os.path.exists(fname_spatial_cov)
            and os.path.exists(fname_temporal_cov)):
        spatial_cov, temporal_cov = get_noise_covariance(
            reader_residual, CONFIG)
        np.save(fname_spatial_cov, spatial_cov)
        np.save(fname_temporal_cov, temporal_cov)
    else:
        spatial_cov = np.load(fname_spatial_cov)
        temporal_cov = np.load(fname_temporal_cov)

    # initialize merge: find candidates
    logger.info("finding merge candidates")
    tm = TemplateMerge(output_directory, reader_residual, fname_templates,
                       fname_spike_train, fname_shifts, fname_scales,
                       fname_soft_assignment, fname_spatial_cov,
                       fname_temporal_cov, CONFIG.geom,
                       CONFIG.resources.multi_processing,
                       CONFIG.resources.n_processors)

    # find merge pairs
    logger.info("merging pairs")
    tm.get_merge_pairs()

    # update templates and spike train accordingly
    logger.info("updating templates and spike train")
    (templates_new, spike_train_new, shifts_new, scales_new,
     soft_assignment_new, merge_array) = tm.merge_units()

    # save results
    fname_merge_array = os.path.join(output_directory, 'merge_array.npy')
    np.save(fname_merge_array, merge_array)
    np.save(fname_spike_train_out, spike_train_new)
    np.save(fname_templates_out, templates_new)
    np.save(fname_shifts_out, shifts_new)
    np.save(fname_scales_out, scales_new)
    np.save(fname_soft_assignment_out, soft_assignment_new)

    logger.info('Number of units after merge: {}'.format(
        templates_new.shape[0]))

    return (fname_templates_out, fname_spike_train_out, fname_shifts_out,
            fname_scales_out, fname_soft_assignment_out)
Example #27
def _run_pipeline(config,
                  output_file,
                  logger_level='INFO',
                  clean=True,
                  output_dir='tmp/',
                  complete=False):
    """
    Run the entire pipeline given a path to a config file
    and output path
    """
    # load yass configuration parameters
    set_config(config)
    CONFIG = read_config()
    ROOT_FOLDER = CONFIG.data.root_folder
    TMP_FOLDER = path.join(ROOT_FOLDER, output_dir)

    # remove tmp folder if needed
    if os.path.exists(TMP_FOLDER) and clean:
        shutil.rmtree(TMP_FOLDER)

    # create TMP_FOLDER if needed
    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    # load logging config file
    logging_config = load_logging_config_file()
    logging_config['handlers']['file']['filename'] = path.join(
        TMP_FOLDER, 'yass.log')
    logging_config['root']['level'] = logger_level

    # configure logging
    logging.config.dictConfig(logging_config)

    # instantiate logger
    logger = logging.getLogger(__name__)

    # run preprocessor
    (score, spike_index_clear,
     spike_index_collision) = preprocess.run(output_directory=output_dir)

    # run processor
    (spike_train_clear, templates,
     spike_index_collision) = process.run(score,
                                          spike_index_clear,
                                          spike_index_collision,
                                          output_directory=output_dir)

    # run deconvolution
    spike_train = deconvolute.run(spike_train_clear,
                                  templates,
                                  spike_index_collision,
                                  output_directory=output_dir)

    # save metadata in tmp
    path_to_metadata = path.join(TMP_FOLDER, 'metadata.yaml')
    logging.info('Saving metadata in {}'.format(path_to_metadata))
    save_metadata(path_to_metadata)

    # save config.yaml copy in tmp/
    path_to_config_copy = path.join(TMP_FOLDER, 'config.yaml')
    shutil.copy2(config, path_to_config_copy)
    logging.info('Saving copy of config: {} in {}'.format(
        config, path_to_config_copy))

    # save templates
    path_to_templates = path.join(TMP_FOLDER, 'templates.npy')
    logging.info('Saving templates in {}'.format(path_to_templates))
    np.save(path_to_templates, templates)

    path_to_spike_train = path.join(TMP_FOLDER, output_file)
    np.save(path_to_spike_train, spike_train)
    logger.info('Spike train saved in: {}'.format(path_to_spike_train))

    # this part loads waveforms for all spikes in the spike train and scores
    # them, this data is needed to later generate phy files
    if complete:
        STANDARIZED_PATH = path.join(TMP_FOLDER, 'standarized.bin')
        PARAMS = load_yaml(path.join(TMP_FOLDER, 'standarized.yaml'))

        # load waveforms for all spikes in the spike train
        logger.info('Loading waveforms from all spikes in the spike train...')
        explorer = RecordingExplorer(STANDARIZED_PATH,
                                     spike_size=CONFIG.spikeSize,
                                     dtype=PARAMS['dtype'],
                                     n_channels=PARAMS['n_channels'],
                                     data_format=PARAMS['data_format'])
        waveforms = explorer.read_waveforms(spike_train[:, 0])

        path_to_waveforms = path.join(TMP_FOLDER, 'spike_train_waveforms.npy')
        np.save(path_to_waveforms, waveforms)
        logger.info('Saved all waveforms from the spike train in {}...'.format(
            path_to_waveforms))

        # score all waveforms
        logger.info('Scoring waveforms from all spikes in the spike train...')
        path_to_rotation = path.join(TMP_FOLDER, 'rotation.npy')
        rotation = np.load(path_to_rotation)

        main_channels = explorer.main_channel_for_waveforms(waveforms)
        path_to_main_channels = path.join(TMP_FOLDER,
                                          'waveforms_main_channel.npy')
        np.save(path_to_main_channels, main_channels)
        logger.info('Saved all waveforms main channels in {}...'.format(
            path_to_waveforms))

        waveforms_score = dim_red.score(waveforms, rotation, main_channels,
                                        CONFIG.neighChannels, CONFIG.geom)
        path_to_waveforms_score = path.join(TMP_FOLDER, 'waveforms_score.npy')
        np.save(path_to_waveforms_score, waveforms_score)
        logger.info('Saved all scores in {}...'.format(path_to_waveforms))

        # score templates
        # TODO: templates should be returned in the right shape to avoid .T
        templates_ = templates.T
        main_channels_tmpls = explorer.main_channel_for_waveforms(templates_)
        path_to_templates_main_c = path.join(TMP_FOLDER,
                                             'templates_main_channel.npy')
        np.save(path_to_templates_main_c, main_channels_tmpls)
        logger.info('Saved all templates main channels in {}...'.format(
            path_to_templates_main_c))

        templates_score = dim_red.score(templates_, rotation,
                                        main_channels_tmpls,
                                        CONFIG.neighChannels, CONFIG.geom)
        path_to_templates_score = path.join(TMP_FOLDER, 'templates_score.npy')
        np.save(path_to_templates_score, templates_score)
        logger.info(
            'Saved all templates scores in {}...'.format(path_to_waveforms))
Example #28
File: run.py  Project: hooshmandshr/yass
def _neural_network_detection(standarized_path, standarized_params,
                              n_observations, output_directory):
    """Run neural network detection and autoencoder dimensionality reduction
    """
    logger = logging.getLogger(__name__)

    CONFIG = read_config()
    OUTPUT_DTYPE = CONFIG.preprocess.dtype
    TMP_FOLDER = os.path.join(CONFIG.data.root_folder, output_directory)

    # detect spikes
    bp = BatchProcessor(standarized_path,
                        standarized_params['dtype'],
                        standarized_params['n_channels'],
                        standarized_params['data_format'],
                        CONFIG.resources.max_memory,
                        buffer_size=0)

    # check if all scores, clear and collision spikes exist..
    path_to_score = os.path.join(TMP_FOLDER, 'score_clear.npy')
    path_to_spike_index_clear = os.path.join(TMP_FOLDER,
                                             'spike_index_clear.npy')
    path_to_spike_index_collision = os.path.join(TMP_FOLDER,
                                                 'spike_index_collision.npy')

    if all([
            os.path.exists(path_to_score),
            os.path.exists(path_to_spike_index_clear),
            os.path.exists(path_to_spike_index_collision)
    ]):
        logger.info('Loading "{}", "{}" and "{}"'.format(
            path_to_score, path_to_spike_index_clear,
            path_to_spike_index_collision))

        scores = np.load(path_to_score)
        clear = np.load(path_to_spike_index_clear)
        collision = np.load(path_to_spike_index_collision)

    else:
        logger.info('One or more of "{}", "{}" or "{}" files were missing, '
                    'computing...'.format(path_to_score,
                                          path_to_spike_index_clear,
                                          path_to_spike_index_collision))

        # apply threshold detector on standarized data
        autoencoder_filename = CONFIG.neural_network_autoencoder.filename
        mc = bp.multi_channel_apply
        res = mc(
            neuralnetwork.nn_detection,
            mode='memory',
            cleanup_function=neuralnetwork.fix_indexes,
            neighbors=CONFIG.neighChannels,
            geom=CONFIG.geom,
            temporal_features=CONFIG.spikes.temporal_features,
            # FIXME: what is this?
            temporal_window=3,
            th_detect=CONFIG.neural_network_detector.threshold_spike,
            th_triage=CONFIG.neural_network_triage.threshold_collision,
            detector_filename=CONFIG.neural_network_detector.filename,
            autoencoder_filename=autoencoder_filename,
            triage_filename=CONFIG.neural_network_triage.filename)

        # save clear spikes
        clear = np.concatenate([element[1] for element in res], axis=0)
        logger.info('Removing clear indexes outside the allowed range to '
                    'draw a complete waveform...')
        clear, idx = detect.remove_incomplete_waveforms(
            clear, CONFIG.spikeSize + CONFIG.templatesMaxShift, n_observations)
        np.save(path_to_spike_index_clear, clear)
        logger.info('Saved spike index clear in {}...'.format(
            path_to_spike_index_clear))

        # save collided spikes
        collision = np.concatenate([element[2] for element in res], axis=0)
        logger.info('Removing collision indexes outside the allowed range to '
                    'draw a complete waveform...')
        collision, _ = detect.remove_incomplete_waveforms(
            collision, CONFIG.spikeSize + CONFIG.templatesMaxShift,
            n_observations)
        np.save(path_to_spike_index_collision, collision)
        logger.info('Saved spike index collision in {}...'.format(
            path_to_spike_index_collision))

        if CONFIG.clustering.clustering_method == 'location':
            #######################
            # Waveform extraction #
            #######################

            # TODO: what should the behaviour be for spike indexes that occur
            # near the start/end of the recording, where it is not possible
            # to draw a complete waveform?
            logger.info('Computing whitening matrix...')
            bp = BatchProcessor(standarized_path, standarized_params['dtype'],
                                standarized_params['n_channels'],
                                standarized_params['data_format'],
                                CONFIG.resources.max_memory)
            batches = bp.multi_channel()
            first_batch, _, _ = next(batches)
            Q = whiten.matrix(first_batch, CONFIG.neighChannels,
                              CONFIG.spikeSize)

            path_to_whitening_matrix = os.path.join(TMP_FOLDER,
                                                    'whitening.npy')
            np.save(path_to_whitening_matrix, Q)
            logger.info('Saved whitening matrix in {}'.format(
                path_to_whitening_matrix))

            # apply whitening to every batch
            (whitened_path, whitened_params) = bp.multi_channel_apply(
                np.matmul,
                mode='disk',
                output_path=os.path.join(TMP_FOLDER, 'whitened.bin'),
                if_file_exists='skip',
                cast_dtype=OUTPUT_DTYPE,
                b=Q)

            main_channel = clear[:, 1]

            # load and dump waveforms from clear spikes

            path_to_waveforms_clear = os.path.join(TMP_FOLDER,
                                                   'waveforms_clear.npy')

            if os.path.exists(path_to_waveforms_clear):
                logger.info(
                    'Found clear waveforms in {}, loading them...'.format(
                        path_to_waveforms_clear))
                waveforms_clear = np.load(path_to_waveforms_clear)
            else:
                logger.info(
                    'Did not find clear waveforms in {}, reading them from {}'.
                    format(path_to_waveforms_clear, whitened_path))
                explorer = RecordingExplorer(whitened_path,
                                             spike_size=CONFIG.spikeSize)
                waveforms_clear = explorer.read_waveforms(clear[:, 0], 'all')
                np.save(path_to_waveforms_clear, waveforms_clear)
                logger.info('Saved waveform from clear spikes in: {}'.format(
                    path_to_waveforms_clear))

            # save rotation
            detector_filename = CONFIG.neural_network_detector.filename
            autoencoder_filename = CONFIG.neural_network_autoencoder.filename
            rotation = neuralnetwork.load_rotation(detector_filename,
                                                   autoencoder_filename)
            path_to_rotation = os.path.join(TMP_FOLDER, 'rotation.npy')
            logger.info("rotation_matrix_shape = {}".format(rotation.shape))
            np.save(path_to_rotation, rotation)
            logger.info(
                'Saved rotation matrix in {}...'.format(path_to_rotation))

            logger.info('Denoising...')
            path_to_denoised_waveforms = os.path.join(
                TMP_FOLDER, 'denoised_waveforms.npy')
            if os.path.exists(path_to_denoised_waveforms):
                logger.info(
                    'Found denoised waveforms in {}, loading them...'.format(
                        path_to_denoised_waveforms))
                denoised_waveforms = np.load(path_to_denoised_waveforms)
            else:
                logger.info(
                    'Did not find denoised waveforms in {}, computing them '
                    'from {}'.format(path_to_denoised_waveforms,
                                     path_to_waveforms_clear))
                waveforms_clear = np.load(path_to_waveforms_clear)
                denoised_waveforms = dim_red.denoise(waveforms_clear, rotation,
                                                     CONFIG)
                logger.info('Saving denoised waveforms to {}'.format(
                    path_to_denoised_waveforms))
                np.save(path_to_denoised_waveforms, denoised_waveforms)

            isolated_index, x, y = get_isolated_spikes_and_locations(
                denoised_waveforms, main_channel, CONFIG)
            x = (x - np.mean(x)) / np.std(x)
            y = (y - np.mean(y)) / np.std(y)
            corrupted_index = np.logical_not(
                np.in1d(np.arange(clear.shape[0]), isolated_index))
            collision = np.concatenate([collision, clear[corrupted_index]],
                                       axis=0)
            clear = clear[isolated_index]
            waveforms_clear = waveforms_clear[isolated_index]
            #################################################
            # Dimensionality reduction (Isolated Waveforms) #
            #################################################

            scores = dim_red.main_channel_scores(waveforms_clear, rotation,
                                                 clear, CONFIG)
            scores = (scores - np.mean(scores, axis=0)) / np.std(scores)
            scores = np.concatenate([
                x[:, np.newaxis, np.newaxis], y[:, np.newaxis, np.newaxis],
                scores[:, :, np.newaxis]
            ],
                                    axis=1)

        else:

            # save scores
            scores = np.concatenate([element[0] for element in res], axis=0)

            logger.info(
                'Removing scores for indexes outside the allowed range to '
                'draw a complete waveform...')
            scores = scores[idx]

            # compute Q for whitening
            logger.info('Computing whitening matrix...')
            bp = BatchProcessor(standarized_path, standarized_params['dtype'],
                                standarized_params['n_channels'],
                                standarized_params['data_format'],
                                CONFIG.resources.max_memory)
            batches = bp.multi_channel()
            first_batch, _, _ = next(batches)
            Q = whiten.matrix_localized(first_batch, CONFIG.neighChannels,
                                        CONFIG.geom, CONFIG.spikeSize)

            path_to_whitening_matrix = os.path.join(TMP_FOLDER,
                                                    'whitening.npy')
            np.save(path_to_whitening_matrix, Q)
            logger.info('Saved whitening matrix in {}'.format(
                path_to_whitening_matrix))

            scores = whiten.score(scores, clear[:, 1], Q)

            np.save(path_to_score, scores)
            logger.info('Saved spike scores in {}...'.format(path_to_score))

            # save rotation
            detector_filename = CONFIG.neural_network_detector.filename
            autoencoder_filename = CONFIG.neural_network_autoencoder.filename
            rotation = neuralnetwork.load_rotation(detector_filename,
                                                   autoencoder_filename)
            path_to_rotation = os.path.join(TMP_FOLDER, 'rotation.npy')
            np.save(path_to_rotation, rotation)
            logger.info(
                'Saved rotation matrix in {}...'.format(path_to_rotation))

        np.save(path_to_score, scores)
        logger.info('Saved spike scores in {}...'.format(path_to_score))
    return scores, clear, collision
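This function caches each intermediate result as a .npy file under TMP_FOLDER and reuses it on later runs. A minimal sketch of that load-or-compute pattern, extracted here for illustration only (the load_or_compute helper is hypothetical and not part of yass):

import os
import logging

import numpy as np

logger = logging.getLogger(__name__)


def load_or_compute(path, compute_fn):
    """Load a cached array if it exists; otherwise compute, save and return it."""
    if os.path.exists(path):
        logger.info('Found %s, loading it...', path)
        return np.load(path)
    logger.info('Did not find %s, computing...', path)
    result = compute_fn()
    np.save(path, result)
    logger.info('Saved result in %s...', path)
    return result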
Example #29
File: run.py  Project: hooshmandshr/yass
def run(output_directory='tmp/'):
    """Execute preprocessing pipeline

    Parameters
    ----------
    output_directory: str, optional
      Location to store partial results, relative to CONFIG.data.root_folder,
      defaults to tmp/

    Returns
    -------
    clear_scores: numpy.ndarray (n_spikes, n_features, n_channels)
        3D array with the scores for the clear spikes, first dimension is
        the number of spikes, second is the number of features and third the
        number of channels

    spike_index_clear: numpy.ndarray (n_clear_spikes, 2)
        2D array with indexes for clear spikes, first column contains the
        spike location in the recording and the second the main channel
        (channel whose amplitude is maximum)

    spike_index_collision: numpy.ndarray (n_collided_spikes, 2)
        2D array with indexes for collided spikes, first column contains the
        spike location in the recording and the second the main channel
        (channel whose amplitude is maximum)

    Notes
    -----
    Running the preprocessor will generate the following files in
    CONFIG.data.root_folder/output_directory/:

    * ``config.yaml`` - Copy of the configuration file
    * ``metadata.yaml`` - Experiment metadata
    * ``filtered.bin`` - Filtered recordings
    * ``filtered.yaml`` - Filtered recordings metadata
    * ``standarized.bin`` - Standarized recordings
    * ``standarized.yaml`` - Standarized recordings metadata
    * ``whitened.bin`` - Whitened recordings
    * ``whitened.yaml`` - Whitened recordings metadata
    * ``rotation.npy`` - Rotation matrix for dimensionality reduction
    * ``spike_index_clear.npy`` - Same as spike_index_clear returned
    * ``spike_index_collision.npy`` - Same as spike_index_collision returned
    * ``score_clear.npy`` - Scores for clear spikes
    * ``waveforms_clear.npy`` - Waveforms for clear spikes

    Examples
    --------

    .. literalinclude:: ../examples/preprocess.py
    """

    logger = logging.getLogger(__name__)

    CONFIG = read_config()

    OUTPUT_DTYPE = CONFIG.preprocess.dtype

    logger.info(
        'Output dtype for transformed data will be {}'.format(OUTPUT_DTYPE))

    TMP = os.path.join(CONFIG.data.root_folder, output_directory)

    if not os.path.exists(TMP):
        logger.info('Creating temporary folder: {}'.format(TMP))
        os.makedirs(TMP)
    else:
        logger.info('Temporary folder {} already exists, output will be '
                    'stored there'.format(TMP))

    path = os.path.join(CONFIG.data.root_folder, CONFIG.data.recordings)
    dtype = CONFIG.recordings.dtype

    # initialize pipeline object, one batch per channel
    pipeline = BatchPipeline(path, dtype, CONFIG.recordings.n_channels,
                             CONFIG.recordings.format,
                             CONFIG.resources.max_memory, TMP)

    # add filter transformation if necessary
    if CONFIG.preprocess.filter:
        filter_op = Transform(butterworth,
                              'filtered.bin',
                              mode='single_channel_one_batch',
                              keep=True,
                              if_file_exists='skip',
                              cast_dtype=OUTPUT_DTYPE,
                              low_freq=CONFIG.filter.low_pass_freq,
                              high_factor=CONFIG.filter.high_factor,
                              order=CONFIG.filter.order,
                              sampling_freq=CONFIG.recordings.sampling_rate)

        pipeline.add([filter_op])

    (filtered_path, ), (filtered_params, ) = pipeline.run()

    # standarize
    bp = BatchProcessor(filtered_path, filtered_params['dtype'],
                        filtered_params['n_channels'],
                        filtered_params['data_format'],
                        CONFIG.resources.max_memory)
    batches = bp.multi_channel()
    first_batch, _, _ = next(batches)
    sd = standard_deviation(first_batch, CONFIG.recordings.sampling_rate)

    (standarized_path, standarized_params) = bp.multi_channel_apply(
        standarize,
        mode='disk',
        output_path=os.path.join(TMP, 'standarized.bin'),
        if_file_exists='skip',
        cast_dtype=OUTPUT_DTYPE,
        sd=sd)

    standarized = RecordingsReader(standarized_path)
    n_observations = standarized.observations

    if CONFIG.spikes.detection == 'threshold':
        return _threshold_detection(standarized_path, standarized_params,
                                    n_observations, output_directory)
    elif CONFIG.spikes.detection == 'nn':
        return _neural_network_detection(standarized_path, standarized_params,
                                         n_observations, output_directory)
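The docstring above points to examples/preprocess.py for a full usage example. A minimal sketch of how this preprocessing stage is typically invoked (assuming run() is exposed as yass.preprocess.run, which is an assumption here, and with the config file name as a placeholder):

import logging

import yass
from yass import preprocess  # assumption: this module exports run()

logging.basicConfig(level=logging.INFO)

# point yass at a configuration file describing the recordings
yass.set_config('config.yaml')

# run filtering, standarization, whitening and spike detection;
# partial results are written to CONFIG.data.root_folder/tmp/
scores, spike_index_clear, spike_index_collision = preprocess.run(
    output_directory='tmp/')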
Example #30
File: run.py  Project: hooshmandshr/yass
def _threshold_detection(standarized_path, standarized_params, n_observations,
                         output_directory):
    """Run threshold detector and dimensionality reduction using PCA
    """
    logger = logging.getLogger(__name__)

    CONFIG = read_config()
    OUTPUT_DTYPE = CONFIG.preprocess.dtype
    TMP_FOLDER = os.path.join(CONFIG.data.root_folder, output_directory)

    ###############
    # Whiten data #
    ###############

    # compute Q for whitening
    logger.info('Computing whitening matrix...')
    bp = BatchProcessor(standarized_path, standarized_params['dtype'],
                        standarized_params['n_channels'],
                        standarized_params['data_format'],
                        CONFIG.resources.max_memory)
    batches = bp.multi_channel()
    first_batch, _, _ = next(batches)
    Q = whiten.matrix(first_batch, CONFIG.neighChannels, CONFIG.spikeSize)

    path_to_whitening_matrix = os.path.join(TMP_FOLDER, 'whitening.npy')
    np.save(path_to_whitening_matrix, Q)
    logger.info(
        'Saved whitening matrix in {}'.format(path_to_whitening_matrix))

    # apply whitening to every batch
    (whitened_path, whitened_params) = bp.multi_channel_apply(
        np.matmul,
        mode='disk',
        output_path=os.path.join(TMP_FOLDER, 'whitened.bin'),
        if_file_exists='skip',
        cast_dtype=OUTPUT_DTYPE,
        b=Q)

    ###################
    # Spike detection #
    ###################

    path_to_spike_index_clear = os.path.join(TMP_FOLDER,
                                             'spike_index_clear.npy')

    bp = BatchProcessor(standarized_path,
                        standarized_params['dtype'],
                        standarized_params['n_channels'],
                        standarized_params['data_format'],
                        CONFIG.resources.max_memory,
                        buffer_size=0)

    # clear spikes
    if os.path.exists(path_to_spike_index_clear):
        # if it exists, load it...
        logger.info('Found file in {}, loading it...'.format(
            path_to_spike_index_clear))
        spike_index_clear = np.load(path_to_spike_index_clear)
    else:
        # if it doesn't, detect spikes...
        logger.info('Did not find file in {}, finding spikes using threshold'
                    ' detector...'.format(path_to_spike_index_clear))

        # apply threshold detector on standarized data
        spikes = bp.multi_channel_apply(detect.threshold,
                                        mode='memory',
                                        cleanup_function=detect.fix_indexes,
                                        neighbors=CONFIG.neighChannels,
                                        spike_size=CONFIG.spikeSize,
                                        std_factor=CONFIG.stdFactor)
        spike_index_clear = np.vstack(spikes)

        logger.info('Removing clear indexes outside the allowed range to '
                    'draw a complete waveform...')
        spike_index_clear, _ = (detect.remove_incomplete_waveforms(
            spike_index_clear, CONFIG.spikeSize + CONFIG.templatesMaxShift,
            n_observations))

        logger.info('Saving spikes in {}...'.format(path_to_spike_index_clear))
        np.save(path_to_spike_index_clear, spike_index_clear)

    path_to_spike_index_collision = os.path.join(TMP_FOLDER,
                                                 'spike_index_collision.npy')

    # collided spikes
    if os.path.exists(path_to_spike_index_collision):
        # if it exists, load it...
        logger.info('Found collided spikes in {}, loading them...'.format(
            path_to_spike_index_collision))
        spike_index_collision = np.load(path_to_spike_index_collision)

        if spike_index_collision.shape[0] != 0:
            raise ValueError('Found non-empty collision spike index in {}, '
                             'but threshold detector is selected, collision '
                             'detection is not implemented for threshold '
                             'detector so array must have dimensions (0, 2) '
                             'but had ({}, {})'.format(
                                 path_to_spike_index_collision,
                                 *spike_index_collision.shape))
    else:
        # triage is not implemented on threshold detector, return empty array
        logger.info('Creating empty array for'
                    ' collided spikes (collision detection is not implemented'
                    ' with threshold detector). Saving them in {}'.format(
                        path_to_spike_index_collision))
        spike_index_collision = np.zeros((0, 2), 'int32')
        np.save(path_to_spike_index_collision, spike_index_collision)

    #######################
    # Waveform extraction #
    #######################

    # load and dump waveforms from clear spikes
    path_to_waveforms_clear = os.path.join(TMP_FOLDER, 'waveforms_clear.npy')

    if os.path.exists(path_to_waveforms_clear):
        logger.info('Found clear waveforms in {}, loading them...'.format(
            path_to_waveforms_clear))
        waveforms_clear = np.load(path_to_waveforms_clear)
    else:
        logger.info(
            'Did not find clear waveforms in {}, reading them from {}'.format(
                path_to_waveforms_clear, standarized_path))
        explorer = RecordingExplorer(standarized_path,
                                     spike_size=CONFIG.spikeSize)
        waveforms_clear = explorer.read_waveforms(spike_index_clear[:, 0])
        np.save(path_to_waveforms_clear, waveforms_clear)
        logger.info('Saved waveform from clear spikes in: {}'.format(
            path_to_waveforms_clear))

    #########################
    # PCA - rotation matrix #
    #########################

    # compute per-batch sufficient statistics for PCA on standarized data
    logger.info('Computing PCA sufficient statistics...')
    stats = bp.multi_channel_apply(dim_red.suff_stat,
                                   mode='memory',
                                   spike_index=spike_index_clear,
                                   spike_size=CONFIG.spikeSize)

    suff_stats = reduce(lambda x, y: np.add(x, y), [e[0] for e in stats])

    spikes_per_channel = reduce(lambda x, y: np.add(x, y),
                                [e[1] for e in stats])

    # compute rotation matrix
    logger.info('Computing PCA projection matrix...')
    rotation = dim_red.project(suff_stats, spikes_per_channel,
                               CONFIG.spikes.temporal_features,
                               CONFIG.neighChannels)
    path_to_rotation = os.path.join(TMP_FOLDER, 'rotation.npy')
    np.save(path_to_rotation, rotation)
    logger.info('Saved rotation matrix in {}...'.format(path_to_rotation))

    main_channel = spike_index_clear[:, 1]
    ###########################################
    # PCA - waveform dimensionality reduction #
    ###########################################
    if CONFIG.clustering.clustering_method == 'location':
        logger.info('Denoising...')
        path_to_denoised_waveforms = os.path.join(TMP_FOLDER,
                                                  'denoised_waveforms.npy')
        if os.path.exists(path_to_denoised_waveforms):
            logger.info(
                'Found denoised waveforms in {}, loading them...'.format(
                    path_to_denoised_waveforms))
            denoised_waveforms = np.load(path_to_denoised_waveforms)
        else:
            logger.info(
                'Did not find denoised waveforms in {}, computing them '
                'from {}'.format(path_to_denoised_waveforms,
                                 path_to_waveforms_clear))
            waveforms_clear = np.load(path_to_waveforms_clear)
            denoised_waveforms = dim_red.denoise(waveforms_clear, rotation,
                                                 CONFIG)
            logger.info('Saving denoised waveforms to {}'.format(
                path_to_denoised_waveforms))
            np.save(path_to_denoised_waveforms, denoised_waveforms)

        isolated_index, x, y = get_isolated_spikes_and_locations(
            denoised_waveforms, main_channel, CONFIG)
        x = (x - np.mean(x)) / np.std(x)
        y = (y - np.mean(y)) / np.std(y)
        corrupted_index = np.logical_not(
            np.in1d(np.arange(spike_index_clear.shape[0]), isolated_index))
        spike_index_collision = np.concatenate(
            [spike_index_collision, spike_index_clear[corrupted_index]],
            axis=0)
        spike_index_clear = spike_index_clear[isolated_index]
        waveforms_clear = waveforms_clear[isolated_index]

        #################################################
        # Dimensionality reduction (Isolated Waveforms) #
        #################################################

        scores = dim_red.main_channel_scores(waveforms_clear, rotation,
                                             spike_index_clear, CONFIG)
        scores = (scores - np.mean(scores, axis=0)) / np.std(scores)
        scores = np.concatenate([
            x[:, np.newaxis, np.newaxis], y[:, np.newaxis, np.newaxis],
            scores[:, :, np.newaxis]
        ],
                                axis=1)
    else:
        logger.info('Reducing spikes dimensionality with PCA matrix...')
        scores = dim_red.score(waveforms_clear, rotation,
                               spike_index_clear[:, 1],
                               CONFIG.neighChannels, CONFIG.geom)

    # save scores
    path_to_score = os.path.join(TMP_FOLDER, 'score_clear.npy')
    np.save(path_to_score, scores)
    logger.info('Saved spike scores in {}...'.format(path_to_score))

    return scores, spike_index_clear, spike_index_collision
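In the 'location' branch above, the spike x/y location estimates and the main-channel PCA scores are each standardized and then stacked along the feature axis, giving an array of shape (n_spikes, n_features + 2, 1). A small shape sketch with made-up arrays (values are arbitrary; only the shapes mirror the code):

import numpy as np

n_spikes, n_features = 5, 3

# hypothetical stand-ins for the real quantities
x = np.random.randn(n_spikes)                   # spike x locations
y = np.random.randn(n_spikes)                   # spike y locations
scores = np.random.randn(n_spikes, n_features)  # main-channel PCA scores

# standardize each piece, as the pipeline does
x = (x - x.mean()) / x.std()
y = (y - y.mean()) / y.std()
scores = (scores - scores.mean(axis=0)) / scores.std()

# stack along the feature axis: (n_spikes, n_features + 2, 1)
combined = np.concatenate([x[:, np.newaxis, np.newaxis],
                           y[:, np.newaxis, np.newaxis],
                           scores[:, :, np.newaxis]], axis=1)
print(combined.shape)  # (5, 5, 1)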