Esempio n. 1
0
def test_can_use_neural_network_detector(path_to_tests,
                                         path_to_standarized_data):
    """Smoke test: detect/triage/featurize runs end to end without raising.

    The detector, triage and autoencoder networks are built from the file
    names and thresholds found in ``config_nnet.yaml``, restored inside a
    TensorFlow session and run over the standarized recordings. Nothing is
    asserted about the output.
    """
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    recordings = RecordingsReader(path_to_standarized_data,
                                  loader='array').data

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    # thresholds and model files, all taken from the configuration
    spike_th = CONFIG.detect.neural_network_detector.threshold_spike
    collision_th = CONFIG.detect.neural_network_triage.threshold_collision
    detector_path = CONFIG.detect.neural_network_detector.filename
    autoencoder_path = CONFIG.detect.neural_network_autoencoder.filename
    triage_path = CONFIG.detect.neural_network_triage.filename

    # triage and autoencoder are fed the detector's waveform tensor
    detector = NeuralNetDetector.load(detector_path, spike_th, channel_index)
    triage = NeuralNetTriage.load(triage_path,
                                  collision_th,
                                  input_tensor=detector.waveform_tf)
    autoencoder = AutoEncoder(autoencoder_path,
                              input_tensor=detector.waveform_tf)

    output_tf = (autoencoder.score_tf,
                 detector.spike_index_tf,
                 triage.idx_clean)

    with tf.Session() as sess:
        # load the trained weights into this session
        detector.restore(sess)
        autoencoder.restore(sess)
        triage.restore(sess)

        rotation = autoencoder.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        neuralnetwork.run_detect_triage_featurize(recordings, sess,
                                                  detector.x_tf, output_tf,
                                                  neighbors, rotation)
Esempio n. 2
0
def test_can_kill_signal(path_to_standarized_data):
    """Smoke test: ``noise.kill_signal`` runs on the standarized recordings.

    No assertion is made on the result; the test passes if nothing raises.
    """
    reader = RecordingsReader(path_to_standarized_data, loader='array')

    # kill_signal receives the reader's ``_data`` attribute, not the reader
    noise.kill_signal(reader._data, threshold=3.0, window_size=10)
Esempio n. 3
0
def test_can_use_neural_network_detector(path_to_tests):
    """Smoke test: the networks built by ``prepare_nn`` run without raising.

    Uses ``config_nnet.yaml`` and ``data/standarized.bin`` from the tests
    folder, builds the networks with an identity whitening filter, restores
    their weights and runs detect/triage/featurize. The output is not
    checked.
    """
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    recordings = RecordingsReader(
        path.join(path_to_tests, 'data/standarized.bin'),
        loader='array').data

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    # identity whitening: one identity matrix (side = channel_index.shape[1])
    # repeated once per row of channel_index
    identity = np.eye(channel_index.shape[1], dtype='float32')
    whiten_filter = np.tile(identity[np.newaxis, :, :],
                            [channel_index.shape[0], 1, 1])

    # thresholds and model files, all taken from the configuration
    spike_th = CONFIG.detect.neural_network_detector.threshold_spike
    collision_th = CONFIG.detect.neural_network_triage.threshold_collision
    detector_path = CONFIG.detect.neural_network_detector.filename
    autoencoder_path = CONFIG.detect.neural_network_autoencoder.filename
    triage_path = CONFIG.detect.neural_network_triage.filename

    networks = neuralnetwork.prepare_nn(channel_index, whiten_filter,
                                        spike_th, collision_th,
                                        detector_path, autoencoder_path,
                                        triage_path)
    x_tf, output_tf, NND, NNAE, NNT = networks

    with tf.Session() as sess:
        # load the trained weights of each network into this session
        NND.saver.restore(sess, NND.path_to_detector_model)
        NNAE.saver_ae.restore(sess, NNAE.path_to_ae_model)
        NNT.saver.restore(sess, NNT.path_to_triage_model)

        rotation = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        neuralnetwork.run_detect_triage_featurize(recordings, sess, x_tf,
                                                  output_tf, neighbors,
                                                  rotation)
Esempio n. 4
0
    def __init__(self,
                 path_to_recordings,
                 path_to_geom=None,
                 spike_size=None,
                 neighbor_radius=None,
                 dtype=None,
                 n_channels=None,
                 data_format=None,
                 mmap=True,
                 waveform_dtype='float32'):
        """Open the recordings and, optionally, load the channel geometry.

        Parameters
        ----------
        path_to_recordings: str
            Path to the recordings file, forwarded to RecordingsReader
        path_to_geom: str, optional
            Path to the geometry file. When given, the geometry is parsed and
            the channel neighbors matrix is computed from it
        spike_size: int, optional
            Stored as ``self.spike_size``
        neighbor_radius: optional
            Radius passed to ``geom.find_channel_neighbors``
        dtype: str, optional
            Recordings dtype, forwarded to RecordingsReader
        n_channels: int, optional
            Number of channels, forwarded to RecordingsReader and to
            ``geom.parse``
        data_format: str, optional
            Recordings format, forwarded to RecordingsReader
        mmap: bool, optional
            Forwarded to RecordingsReader
        waveform_dtype: str, optional
            dtype stored as ``self.waveform_dtype``; the special value
            'default' means "same as ``dtype``"
        """
        self.data = RecordingsReader(path_to_recordings,
                                     dtype,
                                     n_channels,
                                     data_format,
                                     mmap,
                                     output_shape='long')

        # Always define the geometry attributes so that reading them on an
        # instance created without a geometry file yields None instead of
        # raising AttributeError
        self.geom = None
        self.neigh_matrix = None
        self.neighbor_radius = neighbor_radius

        if path_to_geom is not None:
            self.geom = geom.parse(path_to_geom, n_channels)
            self.neigh_matrix = geom.find_channel_neighbors(
                self.geom, neighbor_radius)

        # the channel count comes from the reader, not from the n_channels
        # argument (which may be None)
        self.n_channels = self.data.channels
        self.spike_size = spike_size

        if waveform_dtype == 'default':
            waveform_dtype = dtype

        self.waveform_dtype = waveform_dtype

        self.logger = logging.getLogger(__name__)
Esempio n. 5
0
    def __init__(self,
                 path_to_recordings,
                 path_to_geom=None,
                 spike_size=None,
                 neighbor_radius=None,
                 dtype=None,
                 n_channels=None,
                 data_order=None,
                 loader='memmap',
                 waveform_dtype='float32'):
        """Open the recordings and, optionally, load the channel geometry.

        Parameters
        ----------
        path_to_recordings: str
            Path to the recordings file, forwarded to RecordingsReader
        path_to_geom: str, optional
            Path to the geometry file. When given, the geometry is parsed and
            the channel neighbors matrix is computed from it
        spike_size: int, optional
            Stored as ``self.spike_size``
        neighbor_radius: optional
            Radius passed to ``geom.find_channel_neighbors``
        dtype: str, optional
            Recordings dtype, forwarded to RecordingsReader
        n_channels: int, optional
            Number of channels, forwarded to RecordingsReader and to
            ``geom.parse``
        data_order: str, optional
            Recordings order, forwarded to RecordingsReader
        loader: str, optional
            Loading strategy forwarded to RecordingsReader, defaults to
            'memmap'
        waveform_dtype: str, optional
            dtype stored as ``self.waveform_dtype``; the special value
            'default' means "same as ``dtype``"
        """
        self.reader = RecordingsReader(path_to_recordings, dtype, n_channels,
                                       data_order, loader)

        # Always define the geometry attributes so that reading them on an
        # instance created without a geometry file yields None instead of
        # raising AttributeError
        self.geom = None
        self.neigh_matrix = None
        self.neighbor_radius = neighbor_radius

        if path_to_geom is not None:
            self.geom = geom.parse(path_to_geom, n_channels)
            self.neigh_matrix = geom.find_channel_neighbors(
                self.geom, neighbor_radius)

        # the channel count comes from the reader, not from the n_channels
        # argument (which may be None)
        self.n_channels = self.reader.channels
        self.spike_size = spike_size

        if waveform_dtype == 'default':
            waveform_dtype = dtype

        self.waveform_dtype = waveform_dtype

        self.logger = logging.getLogger(__name__)
Esempio n. 6
0
def test_can_compute_noise_cov(path_to_tests, path_to_standarized_data):
    """Smoke test: ``noise_cov`` runs on the standarized recordings.

    The two returned estimates are unpacked but not checked; the test passes
    if the call returns a pair without raising.
    """
    reader = RecordingsReader(path_to_standarized_data, loader='array')

    settings = dict(temporal_size=10,
                    sample_size=100,
                    threshold=3.0,
                    window_size=10)

    # noise_cov receives the reader's ``_data`` attribute, not the reader
    spatial_SIG, temporal_SIG = noise_cov(reader._data, **settings)
Esempio n. 7
0
def test_can_read_in_wide_format(path_to_wide, path_to_tests):
    """Indexing a 'channels'-ordered file returns the reference slice.

    Reads observations 1000-1019 of channels 1 and 5 and compares them with
    the stored reference array.
    """
    indexer = RecordingsReader(path_to_wide,
                               n_channels=10,
                               data_order='channels',
                               dtype='float64',
                               loader='array')
    res = indexer[1000:1020, [1, 5]]

    # ``expected`` was never defined here, so the test failed with NameError.
    # NOTE(review): the reference is loaded and transposed exactly as in the
    # sibling long-format test, on the assumption that the reader returns the
    # same (observations, channels) slice regardless of the on-disk order -
    # confirm against the stored fixture.
    expected = np.load(
        os.path.join(path_to_tests, 'data/test_indexer/wide.npy')).T

    np.testing.assert_equal(res, expected)
Esempio n. 8
0
def test_can_read_in_long_format(path_to_long, path_to_tests):
    """Indexing a 'samples'-ordered file returns the reference slice.

    Reads observations 1000-1019 of channels 1 and 5 and compares them with
    the transposed content of ``data/test_indexer/wide.npy``.
    """
    reader = RecordingsReader(path_to_long,
                              n_channels=10,
                              data_order='samples',
                              dtype='float64',
                              loader='array')
    observed = reader[1000:1020, [1, 5]]

    # the stored reference is transposed so it lines up with the reader's
    # output
    reference_path = os.path.join(path_to_tests,
                                  'data/test_indexer/wide.npy')
    reference = np.load(reference_path).T

    np.testing.assert_equal(observed, reference)
Esempio n. 9
0
def test_can_read_in_wide_format(path_to_wide, path_to_tests):
    """Indexing a 'wide' file with 'wide' output returns the reference slice.

    Reads ``[1000:1020, [1, 5]]`` from the reader and compares it with the
    content of ``data/test_indexer/wide.npy`` as stored (no transpose).
    """
    reader = RecordingsReader(path_to_wide,
                              n_channels=10,
                              data_format='wide',
                              dtype='float64',
                              mmap=False,
                              output_shape='wide')
    observed = reader[1000:1020, [1, 5]]

    reference_path = os.path.join(path_to_tests,
                                  'data/test_indexer/wide.npy')
    reference = np.load(reference_path)

    np.testing.assert_equal(observed, reference)
Esempio n. 10
0
def test_can_estimate_temporal_and_spatial_sig(path_to_standarized_data):
    """``noise.noise_cov`` returns two estimates that contain no NaN."""
    reader = RecordingsReader(path_to_standarized_data, loader='array')

    # noise_cov receives the reader's ``_data`` attribute, not the reader
    estimates = noise.noise_cov(reader._data,
                                temporal_size=40,
                                sample_size=1000,
                                threshold=3.0,
                                window_size=10)
    spatial_SIG, temporal_SIG = estimates

    # neither estimate may contain a NaN entry
    assert not np.isnan(spatial_SIG).any()
    assert not np.isnan(temporal_SIG).any()
Esempio n. 11
0
# Apply ``dummy`` to rows 0..1,999,999 of column 1 of ``big_long`` and write
# the result back in place, then flush the array
# (presumably ``big_long`` is a numpy memmap - confirm where it is created)
x_long = dummy(big_long[(slice(0, 2000000, None), 1)])
big_long[(slice(0, 2000000, None), 1)] = x_long
big_long.flush()

# Same transformation over the whole of column 1 (not flushed here)
x_long = dummy(big_long[:, 1])
big_long[:, 1] = x_long

# Batch processor over the 'long' file, limited to 500MB per batch
bp_long = BatchProcessor(path_to_long,
                         dtype='int64',
                         n_channels=50,
                         data_format='long',
                         max_memory='500MB')

# Apply ``dummy`` channel by channel; the call returns the output path
path = bp_long.single_channel_apply(dummy, path_to_out)

# Open the output with a reader; the bare name displays it when run
# interactively
out = RecordingsReader(path)
out

# Batch processor over the 'wide' file with the same settings
bp_wide = BatchProcessor(path_to_wide,
                         dtype='int64',
                         n_channels=50,
                         data_format='wide',
                         max_memory='500MB')

# NOTE: every apply call below writes to the same ``path_to_out``
path = bp_wide.single_channel_apply(dummy, path_to_out)
out = RecordingsReader(path)
out

# Multi-channel variant on the 'long' file
path = bp_long.multi_channel_apply(dummy, path_to_out)
out = RecordingsReader(path)
out
Esempio n. 12
0

# create a batch processor for the data: int16 samples, 385 channels, 'wide'
# format, at most 500MB per batch
bp = BatchProcessor(path_to_neuropixel_data,
                    dtype='int16', n_channels=385, data_format='wide',
                    max_memory='500MB')


# apply a single channel transformation, each batch will be all observations
# from one channel, results are saved to disk; only channels 0, 1 and 2 are
# processed and the remaining keyword arguments are the filter settings
bp.single_channel_apply(butterworth,
                        mode='disk',
                        output_path=path_to_filtered_data,
                        low_freq=300, high_factor=0.1,
                        order=3, sampling_freq=30000,
                        channels=[0, 1, 2])


# let's visualize the results
raw = RecordingsReader(path_to_neuropixel_data, dtype='int16',
                       n_channels=385, data_format='wide')

# you do not need to specify the format since single_channel_apply
# saves a yaml file with such parameters
filtered = RecordingsReader(path_to_filtered_data)

# plot the first 2000 observations of channel 0: raw on top, filtered below
fig, (ax1, ax2) = plt.subplots(2, 1)
ax1.plot(raw[:2000, 0])
ax2.plot(filtered[:2000, 0])
plt.show()
Esempio n. 13
0
def noise_cov(path_to_data, dtype, n_channels, data_format, neighbors, geom,
              temporal_size):
    """Estimate spatial and temporal noise structure from the recordings

    Samples above a fixed amplitude (and a window around them) are masked
    out, the remaining samples are used to estimate a spatial covariance
    over the neighborhood of a reference channel and, after spatial
    whitening, a temporal covariance from randomly drawn clean windows.

    Parameters
    ----------
    path_to_data: str
        Path to recordings data
    dtype: str
        dtype for recordings
    n_channels: int
        Number of channels in the recordings
    data_format: str
        Recordings shape ('wide', 'long')
    neighbors: numpy.ndarray
        Neighbors matrix
    geom: numpy.ndarray
        Cartesian coordinates for the channels
    temporal_size:
        Waveform size

    Returns
    -------
    spatial_SIG: numpy.ndarray (C, C)
        V diag(sqrt(w)) V^T built from the eigen-decomposition of the
        spatial covariance, where C is the number of channels in the
        reference neighborhood
    temporal_SIG: numpy.ndarray (temporal_size, temporal_size)
        Same construction applied to the covariance of the sampled noise
        windows

    Notes
    -----
    The sampling loop only stops once 1000 fully clean windows were found,
    so it does not terminate if no such window exists.
    """
    # reference channel: the one with the largest column sum in neighbors
    c_ref = np.argmax(np.sum(neighbors, 0))
    # its neighborhood, reordered by order_channels_by_distance
    # (presumably by distance to c_ref - confirm in that helper)
    ch_idx = np.where(neighbors[c_ref])[0]
    ch_idx, temp = order_channels_by_distance(c_ref, ch_idx, geom)

    rec = RecordingsReader(path_to_data, dtype=dtype, n_channels=n_channels,
                           data_format=data_format, mmap=False)
    # keep only the neighborhood channels; rows are observations
    rec = rec[:, ch_idx]

    T, C = rec.shape
    # 1 where a sample is kept as noise, 0 where it was masked
    idxNoise = np.zeros((T, C))

    # half window masked on each side of a supra-threshold sample
    R = int((temporal_size-1)/2)
    for c in range(C):
        # only values above +3 are flagged (no absolute value is taken)
        # NOTE(review): presumably rec is standarized so 3 means 3 sd -
        # confirm
        idx_temp = np.where(rec[:, c] > 3)[0]
        for j in range(-R, R+1):
            idx_temp2 = idx_temp + j
            # drop shifted indexes that fall outside the recording
            idx_temp2 = idx_temp2[np.logical_and(
                idx_temp2 >= 0, idx_temp2 < T)]
            rec[idx_temp2, c] = np.nan
        # NaN != NaN, so this is True exactly on the samples that were kept
        idxNoise_temp = (rec[:, c] == rec[:, c])
        # scale the channel by the standard deviation of the kept samples
        rec[:, c] = rec[:, c]/np.nanstd(rec[:, c])

        # masked samples contribute zero to the products below
        rec[~idxNoise_temp, c] = 0
        idxNoise[idxNoise_temp, c] = 1

    # element-wise: sum of products divided by the number of observations
    # where both channels were kept
    spatial_cov = np.divide(np.matmul(rec.T, rec),
                            np.matmul(idxNoise.T, idxNoise))

    w, v = np.linalg.eig(spatial_cov)
    spatial_SIG = np.matmul(np.matmul(v, np.diag(np.sqrt(w))), v.T)

    # V diag(1/sqrt(w)) V^T, applied to the recordings before the temporal
    # estimate
    spatial_whitener = np.matmul(np.matmul(v, np.diag(1/np.sqrt(w))), v.T)
    rec = np.matmul(rec, spatial_whitener)

    # rejection sampling: draw random (time, channel) windows and keep the
    # ones that contain no masked sample until 1000 were collected
    noise_wf = np.zeros((1000, temporal_size))
    count = 0
    while count < 1000:
        tt = np.random.randint(T-temporal_size)
        cc = np.random.randint(C)
        temp = rec[tt:(tt+temporal_size), cc]
        temp_idxnoise = idxNoise[tt:(tt+temporal_size), cc]
        if np.sum(temp_idxnoise == 0) == 0:
            noise_wf[count] = temp
            count += 1

    # covariance across the time points of the sampled windows
    w, v = np.linalg.eig(np.cov(noise_wf.T))

    temporal_SIG = np.matmul(np.matmul(v, np.diag(np.sqrt(w))), v.T)

    return spatial_SIG, temporal_SIG
Esempio n. 14
0
from yass.batch import RecordingsReader

# generate some big files: 50 channels x 1,000,000 observations of random
# data, plus its transpose
# NOTE: output_folder is a hard-coded, machine-specific path
output_folder = '/Users/Edu/data/yass-benchmarks'
wide_data = np.random.rand(50, 1000000)
long_data = wide_data.T

path_to_wide = os.path.join(output_folder, 'wide.bin')
path_to_long = os.path.join(output_folder, 'long.bin')

# write both arrays as raw binary files
wide_data.tofile(path_to_wide)
long_data.tofile(path_to_long)

# load the files using the readers, they are agnostic on the data format
# ('wide' or 'long') and will behave exactly the same
reader_wide = RecordingsReader(path_to_wide, dtype='float64',
                               channels=50, data_format='wide')

reader_long = RecordingsReader(path_to_long, dtype='float64',
                               channels=50, data_format='long')

# bare expressions: they display the shapes when run interactively
reader_wide.shape
reader_long.shape

# first index is for observations and second index for channels
obs = reader_wide[10000:20000, 20:30]
obs, obs.shape

# same applies even if your data is in 'long' format, first index for
# observations, second for channels, the output is converted to 'wide'
obs = reader_long[10000:20000, 20:30]
obs, obs.shape
Esempio n. 15
0
def test_splitting_in_batches_does_not_affect_result(path_to_tests):
    """Batched detection must give the same output as a single pass.

    Runs detect/triage/featurize once over the whole standarized recording
    and once through a BatchProcessor limited to 100KB per batch, then checks
    that scores, clear spikes and collided spikes are identical.
    """
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    PATH_TO_DATA = path.join(path_to_tests, 'data/standarized.bin')

    data = RecordingsReader(PATH_TO_DATA, loader='array').data

    # safe_load: plain ``yaml.load(f)`` without a Loader is deprecated and
    # raises TypeError on PyYAML >= 6; the file only holds plain metadata
    with open(path.join(path_to_tests, 'data/standarized.yaml')) as f:
        PARAMS = yaml.safe_load(f)

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    # identity whitening: one identity matrix (side = channel_index.shape[1])
    # repeated once per row of channel_index
    whiten_filter = np.tile(
        np.eye(channel_index.shape[1], dtype='float32')[np.newaxis, :, :],
        [channel_index.shape[0], 1, 1])

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename
    (x_tf, output_tf, NND, NNAE, NNT) = neuralnetwork.prepare_nn(
        channel_index,
        whiten_filter,
        detection_th,
        triage_th,
        detection_fname,
        ae_fname,
        triage_fname,
    )

    # run all at once
    with tf.Session() as sess:
        # load the trained weights of each network into this session
        NND.saver.restore(sess, NND.path_to_detector_model)
        NNAE.saver_ae.restore(sess, NNAE.path_to_ae_model)
        NNT.saver.restore(sess, NNT.path_to_triage_model)
        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        (scores, clear, collision) = neuralnetwork.run_detect_triage_featurize(
            data, sess, x_tf, output_tf, neighbors, rot)

    # run in batches - buffer size makes sure we can detect spikes if they
    # appear at the end of any batch
    bp = BatchProcessor(PATH_TO_DATA,
                        PARAMS['dtype'],
                        PARAMS['n_channels'],
                        PARAMS['data_order'],
                        '100KB',
                        buffer_size=CONFIG.spike_size)

    with tf.Session() as sess:
        # weights must be restored again: this is a new session
        NND.saver.restore(sess, NND.path_to_detector_model)
        NNAE.saver_ae.restore(sess, NNAE.path_to_ae_model)
        NNT.saver.restore(sess, NNT.path_to_triage_model)

        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        res = bp.multi_channel_apply(
            neuralnetwork.run_detect_triage_featurize,
            mode='memory',
            cleanup_function=neuralnetwork.fix_indexes,
            sess=sess,
            x_tf=x_tf,
            output_tf=output_tf,
            rot=rot,
            neighbors=neighbors)

    # each element of res is the (scores, clear, collision) triple of one
    # batch; stitch the batches back together
    scores_batch = np.concatenate([element[0] for element in res], axis=0)
    clear_batch = np.concatenate([element[1] for element in res], axis=0)
    collision_batch = np.concatenate([element[2] for element in res], axis=0)

    np.testing.assert_array_equal(clear_batch, clear)
    np.testing.assert_array_equal(collision_batch, collision)
    np.testing.assert_array_equal(scores_batch, scores)
Esempio n. 16
0
def test_nn_output(path_to_tests):
    """Test that the pipeline configured by config_nn_49.yaml reproduces the
    stored reference results

    Every stage (preprocess, detect, cluster, templates, deconvolute) is run
    in turn and its output compared with the arrays saved under PATH_TO_REF.
    The cluster comparison is currently disabled.
    """
    logger = logging.getLogger(__name__)

    yass.set_config(path.join(path_to_tests, 'config_nn_49.yaml'))

    CONFIG = read_config()
    TMP = Path(CONFIG.data.root_folder, 'tmp')

    # start from a clean tmp folder; on a first run the folder does not
    # exist yet and an unguarded rmtree would abort the test
    if path.exists(str(TMP)):
        logger.info('Removing %s', TMP)
        shutil.rmtree(str(TMP))

    # NOTE: hard-coded, machine-specific location of the reference results
    PATH_TO_REF = '/home/Edu/data/nnet'

    # fixed seed so the stages that draw random numbers are repeatable
    np.random.seed(0)

    # run preprocess
    (standarized_path, standarized_params, whiten_filter) = preprocess.run()

    # load preprocess output
    path_to_standarized = path.join(PATH_TO_REF, 'preprocess',
                                    'standarized.bin')
    path_to_whitening = path.join(PATH_TO_REF, 'preprocess', 'whitening.npy')

    whitening_saved = np.load(path_to_whitening)
    standarized_saved = RecordingsReader(path_to_standarized,
                                         loader='array').data
    standarized = RecordingsReader(standarized_path, loader='array').data

    # test preprocess
    np.testing.assert_array_equal(whitening_saved, whiten_filter)
    np.testing.assert_array_equal(standarized_saved, standarized)

    # run detect
    (score, spike_index_clear,
     spike_index_all) = detect.run(standarized_path, standarized_params,
                                   whiten_filter)
    # load detect output
    path_to_scores = path.join(PATH_TO_REF, 'detect', 'scores_clear.npy')
    path_to_spike_index_clear = path.join(PATH_TO_REF, 'detect',
                                          'spike_index_clear.npy')
    path_to_spike_index_all = path.join(PATH_TO_REF, 'detect',
                                        'spike_index_all.npy')

    scores_saved = np.load(path_to_scores)
    spike_index_clear_saved = np.load(path_to_spike_index_clear)
    spike_index_all_saved = np.load(path_to_spike_index_all)

    # test detect output
    np.testing.assert_array_equal(scores_saved, score)
    np.testing.assert_array_equal(spike_index_clear_saved, spike_index_clear)
    np.testing.assert_array_equal(spike_index_all_saved, spike_index_all)

    # run cluster
    (spike_train_clear, tmp_loc,
     vbParam) = cluster.run(score, spike_index_clear)

    # load cluster output (still loaded so a missing reference file is
    # noticed even though the comparison below is disabled)
    path_to_spike_train_cluster = path.join(PATH_TO_REF, 'cluster',
                                            'spike_train_cluster.npy')
    spike_train_cluster_saved = np.load(path_to_spike_train_cluster)

    # test cluster
    # TODO: comparison disabled in the original test, reason not recorded:
    # np.testing.assert_array_equal(spike_train_cluster_saved,
    #                               spike_train_clear)

    # run templates
    (templates_, spike_train, groups,
     idx_good_templates) = templates.run(spike_train_clear,
                                         tmp_loc,
                                         save_results=True)

    # load templates output
    path_to_templates = path.join(PATH_TO_REF, 'templates', 'templates.npy')
    templates_saved = np.load(path_to_templates)

    # test templates (approximate: 4 decimals)
    np.testing.assert_array_almost_equal(templates_saved,
                                         templates_,
                                         decimal=4)

    # run deconvolution
    spike_train = deconvolute.run(spike_index_all, templates_)

    # load deconvolution output
    path_to_spike_train = path.join(PATH_TO_REF, 'spike_train.npy')
    spike_train_saved = np.load(path_to_spike_train)

    # test deconvolution
    np.testing.assert_array_equal(spike_train_saved, spike_train)
Esempio n. 17
0
def run(spike_train_clear,
        templates,
        spike_index_collision,
        output_directory='tmp/',
        recordings_filename='standarized.bin'):
    """Deconvolute spikes

    Parameters
    ----------
    spike_train_clear: numpy.ndarray (n_clear_spikes, 2)
        A 2D array for clear spikes whose first column indicates the spike
        time and the second column the neuron id determined by the clustering
        algorithm

    templates: numpy.ndarray (n_channels, waveform_size, n_templates)
        A 3D array with the templates

    spike_index_collision: numpy.ndarray (n_collided_spikes, 2)
        A 2D array for collided spikes whose first column indicates the spike
        time and the second column the neuron id determined by the clustering
        algorithm

    output_directory: str, optional
        Output directory (relative to CONFIG.data.root_folder) used to load
        the recordings to generate templates, defaults to tmp/

    recordings_filename: str, optional
        Recordings filename (relative to CONFIG.data.root_folder/
        output_directory) used to draw the waveforms from, defaults to
        standarized.bin

    Returns
    -------
    spike_train: numpy.ndarray (n_spikes, 2)
        A 2D array with the merged (deconvolved + clear) spike train sorted
        by spike time, first column indicates the spike time and the second
        column the neuron ID. Within each neuron, a spike is dropped when its
        time is at most one sample after the previous spike of that neuron

    Examples
    --------

    .. literalinclude:: ../examples/deconvolute.py
    """
    logger = logging.getLogger(__name__)

    CONFIG = read_config()

    recordings = RecordingsReader(
        os.path.join(CONFIG.data.root_folder, output_directory,
                     recordings_filename))

    # use the module logger (not the root ``logging`` functions) so this
    # message follows the same configuration as the other ones
    logger.debug('Starting deconvolution. templates.shape: {}, '
                 'spike_index_collision.shape: {}'.format(
                     templates.shape, spike_index_collision.shape))

    # Deconvolution is given the templates with the first two axes swapped
    deconv = Deconvolution(CONFIG, np.transpose(templates, [1, 0, 2]),
                           spike_index_collision, recordings)
    spike_train_deconv = deconv.fullMPMU()

    logger.debug('spike_train_deconv.shape: {}'.format(
        spike_train_deconv.shape))

    # merge spikes in one array, sorted by spike time
    spike_train = np.concatenate((spike_train_deconv, spike_train_clear))
    spike_train = spike_train[np.argsort(spike_train[:, 0])]

    logger.debug('spike_train.shape: {}'.format(spike_train.shape))

    idx_keep = np.zeros(spike_train.shape[0], 'bool')

    # TODO: check if we can remove this
    for k in range(templates.shape[2]):
        idx_c = np.where(spike_train[:, 1] == k)[0]

        # a template without spikes has nothing to deduplicate; without this
        # guard the one-element mask built below does not match the empty
        # index array and numpy raises IndexError
        if idx_c.size == 0:
            continue

        # keep the first spike of the template and every spike that is more
        # than one sample after the previous one
        idx_keep[idx_c[np.concatenate(
            ([True], np.diff(spike_train[idx_c, 0]) > 1))]] = 1

    spike_train = spike_train[idx_keep]

    # logged after filtering so the reported shape is the deduplicated one
    logger.debug('deduplicated spike_train_deconv.shape: {}'.format(
        spike_train.shape))

    return spike_train
Esempio n. 18
0
def noise_cov(path_to_data, dtype, n_channels, data_order, neighbors, geom,
              temporal_size):
    """Estimate spatial and temporal noise structure from the recordings

    Samples above a fixed amplitude (and a window around them) are masked
    out, the remaining samples are used to estimate a spatial covariance
    over the neighborhood of a reference channel and, after spatial
    whitening, a temporal covariance from randomly drawn clean windows.

    Parameters
    ----------
    path_to_data: str
        Path to recordings data

    dtype: str
        dtype for recordings

    n_channels: int
        Number of channels in the recordings

    data_order: str
        Recordings order, one of ('channels', 'samples'). In a dataset with k
        observations per channel and j channels: 'channels' means first k
        contiguous observations come from channel 0, then channel 1, and so
        on. 'samples' means first j contiguous data are the first
        observations from all channels, then the second observations from
        all channels and so on

    neighbors: numpy.ndarray
        Neighbors matrix

    geom: numpy.ndarray
        Cartesian coordinates for the channels

    temporal_size:
        Waveform size

    Returns
    -------
    spatial_SIG: numpy.ndarray (C, C)
        V diag(sqrt(w)) V^T built from the eigen-decomposition of the
        spatial covariance, where C is the number of channels in the
        reference neighborhood

    temporal_SIG: numpy.ndarray (temporal_size, temporal_size)
        Same construction applied to the covariance of the sampled noise
        windows

    Notes
    -----
    The sampling loop only stops once 1000 fully clean windows were found,
    so it does not terminate if no such window exists.
    """
    # reference channel: the one with the largest column sum in neighbors
    c_ref = np.argmax(np.sum(neighbors, 0))
    # its neighborhood, reordered by order_channels_by_distance
    # (presumably by distance to c_ref - confirm in that helper)
    ch_idx = np.where(neighbors[c_ref])[0]
    ch_idx, temp = order_channels_by_distance(c_ref, ch_idx, geom)

    rec = RecordingsReader(path_to_data, dtype=dtype, n_channels=n_channels,
                           data_order=data_order, loader='array')
    # keep only the neighborhood channels; rows are observations
    rec = rec[:, ch_idx]

    T, C = rec.shape
    # 1 where a sample is kept as noise, 0 where it was masked
    idxNoise = np.zeros((T, C))

    # half window masked on each side of a supra-threshold sample
    R = int((temporal_size-1)/2)
    for c in range(C):
        # only values above +3 are flagged (no absolute value is taken)
        # NOTE(review): presumably rec is standarized so 3 means 3 sd -
        # confirm
        idx_temp = np.where(rec[:, c] > 3)[0]
        for j in range(-R, R+1):
            idx_temp2 = idx_temp + j
            # drop shifted indexes that fall outside the recording
            idx_temp2 = idx_temp2[np.logical_and(
                idx_temp2 >= 0, idx_temp2 < T)]
            rec[idx_temp2, c] = np.nan
        # NaN != NaN, so this is True exactly on the samples that were kept
        idxNoise_temp = (rec[:, c] == rec[:, c])
        # scale the channel by the standard deviation of the kept samples
        rec[:, c] = rec[:, c]/np.nanstd(rec[:, c])

        # masked samples contribute zero to the products below
        rec[~idxNoise_temp, c] = 0
        idxNoise[idxNoise_temp, c] = 1

    # element-wise: sum of products divided by the number of observations
    # where both channels were kept
    spatial_cov = np.divide(np.matmul(rec.T, rec),
                            np.matmul(idxNoise.T, idxNoise))

    w, v = np.linalg.eig(spatial_cov)
    spatial_SIG = np.matmul(np.matmul(v, np.diag(np.sqrt(w))), v.T)

    # V diag(1/sqrt(w)) V^T, applied to the recordings before the temporal
    # estimate
    spatial_whitener = np.matmul(np.matmul(v, np.diag(1/np.sqrt(w))), v.T)
    rec = np.matmul(rec, spatial_whitener)

    # rejection sampling: draw random (time, channel) windows and keep the
    # ones that contain no masked sample until 1000 were collected
    noise_wf = np.zeros((1000, temporal_size))
    count = 0
    while count < 1000:
        tt = np.random.randint(T-temporal_size)
        cc = np.random.randint(C)
        temp = rec[tt:(tt+temporal_size), cc]
        temp_idxnoise = idxNoise[tt:(tt+temporal_size), cc]
        if np.sum(temp_idxnoise == 0) == 0:
            noise_wf[count] = temp
            count += 1

    # covariance across the time points of the sampled windows
    w, v = np.linalg.eig(np.cov(noise_wf.T))

    temporal_SIG = np.matmul(np.matmul(v, np.diag(np.sqrt(w))), v.T)

    return spatial_SIG, temporal_SIG
Esempio n. 19
0
def run_threshold(standarized_path, standarized_params, channel_index,
                  whiten_filter, output_directory, if_file_exists,
                  save_results, gmm_params):
    """Run threshold detector and dimensionality reduction using PCA

    Score whitening is currently switched off, so ``whiten_filter`` is not
    used and the PCA scores are returned (and saved) as they are.

    Returns
    -------
    scores
      Scores for all spikes; transformed by
      ``get_locations_features_threshold`` when CONFIG.cluster.method is
      'location'

    spike_index_clear
      Spike indexes for clear spikes

    spike_index_all
      Spike indexes for all spikes (a copy of spike_index_clear)
    """
    logger.debug('Running threshold detector...')

    CONFIG = read_config()

    # Without an output folder every step below skips saving its results
    if save_results:
        out_dir = os.path.join(CONFIG.data.root_folder, output_directory)
    else:
        out_dir = None

    dtype = standarized_params['dtype']
    n_channels = standarized_params['n_channels']
    data_order = standarized_params['data_order']

    # ---- spike detection -------------------------------------------------
    # clear indexes are saved by the detector under out_dir
    spike_index = threshold(standarized_path,
                            dtype,
                            n_channels,
                            data_order,
                            CONFIG.resources.max_memory,
                            CONFIG.neigh_channels,
                            CONFIG.spike_size,
                            CONFIG.spike_size + CONFIG.templates.max_shift,
                            CONFIG.detect.threshold_detector.std_factor,
                            out_dir,
                            spike_index_clear_filename='spike_index_clear.npy',
                            if_file_exists=if_file_exists)

    # ---- PCA -------------------------------------------------------------
    reader = RecordingsReader(standarized_path)

    # rotation matrix and pca scores are saved under out_dir; pca also
    # returns the spike index, which replaces the detector's one
    # TODO: remove clear as input for PCA and create an independent function
    scores_clear, spike_index, _ = pca(
        standarized_path, dtype, n_channels, data_order, reader, spike_index,
        CONFIG.spike_size, CONFIG.detect.temporal_features,
        CONFIG.neigh_channels, channel_index, CONFIG.resources.max_memory,
        gmm_params, out_dir, 'scores_pca.npy', 'rotation.npy',
        'spike_index_clear_pca.npy', if_file_exists)

    # ---- save results ----------------------------------------------------
    if out_dir is not None:
        # spike_index_all is the same as spike_index_clear for the threshold
        # detector
        to_save = ((scores_clear, 'scores_clear.npy', 'scores'),
                   (spike_index, 'spike_index_all.npy', 'Spike index all'))

        for array, filename, label in to_save:
            save_numpy_object(array,
                              os.path.join(out_dir, filename),
                              if_file_exists,
                              name=label)

    # TODO: this shouldn't be here
    # transform scores to location + shape feature space
    if CONFIG.cluster.method == 'location':
        scores = get_locations_features_threshold(scores_clear,
                                                  spike_index[:, 1],
                                                  channel_index, CONFIG.geom)
    else:
        scores = scores_clear

    return scores, spike_index, np.copy(spike_index)
Esempio n. 20
0
# generate data: 50 channels x 100,000 observations of random values, plus
# its transpose; files are written under ~/data/yass
output_folder = os.path.join(os.path.expanduser('~'), 'data/yass')
wide_data = np.random.rand(50, 100000)
long_data = wide_data.T

path_to_wide = os.path.join(output_folder, 'wide.bin')
path_to_long = os.path.join(output_folder, 'long.bin')

# write both arrays as raw binary files
wide_data.tofile(path_to_wide)
long_data.tofile(path_to_long)

# load the files using the readers, they are agnostic on the data shape
# and will behave exactly the same
reader_wide = RecordingsReader(path_to_wide,
                               dtype='float64',
                               n_channels=50,
                               data_order='channels')

reader_long = RecordingsReader(path_to_long,
                               dtype='float64',
                               n_channels=50,
                               data_order='samples')

# bare expression: displays both shapes when run interactively
reader_wide.shape, reader_long.shape

# first index is for observations and second index for channels
obs = reader_wide[10000:20000, 20:30]
obs, obs.shape

# same applies even if your data is in 'wide' shape, first index for
# observations, second for channels, the output is converted to 'long'
Esempio n. 21
0
                                     high_factor=0.1,
                                     order=3,
                                     sampling_freq=30000)

# second pipeline step: run ``standarize`` and write 'standardized.bin';
# the remaining keyword arguments are passed to PipedTransformation as shown
standarize_op = PipedTransformation(standarize,
                                    'standardized.bin',
                                    mode='single_channel_one_batch',
                                    keep=True,
                                    sampling_freq=30000)

# register the steps in execution order: filter first, then standarize
pipeline.add([butterworth_op, standarize_op])

pipeline.run()

# open the raw input and the two files read back from path_output
raw = RecordingsReader(path_to_neuropixel_data,
                       dtype='int16',
                       n_channels=385,
                       data_format='wide')
filtered = RecordingsReader(os.path.join(path_output, 'filtered.bin'))
standardized = RecordingsReader(os.path.join(path_output, 'standardized.bin'))

# plot results: first 2000 observations of channel 0 at each stage
fig, (ax1, ax2, ax3) = plt.subplots(3, 1)
ax1.plot(raw[:2000, 0])
ax1.set_title('Raw data')
ax2.plot(filtered[:2000, 0])
ax2.set_title('Filtered data')
ax3.plot(standardized[:2000, 0])
ax3.set_title('Standarized data')
plt.tight_layout()
plt.show()
Esempio n. 22
0
def run(output_directory='tmp/'):
    """Execute preprocessing pipeline

    Filters the recordings (when enabled in the configuration), standarizes
    them and then hands the standarized data to the spike detector selected
    in ``CONFIG.spikes.detection``.

    Parameters
    ----------
    output_directory: str, optional
      Location to store partial results, relative to CONFIG.data.root_folder,
      defaults to tmp/

    Returns
    -------
    clear_scores: numpy.ndarray (n_spikes, n_features, n_channels)
        3D array with the scores for the clear spikes, first dimension is
        the number of spikes, second is the number of features and third the
        number of channels

    spike_index_clear: numpy.ndarray (n_clear_spikes, 2)
        2D array with indexes for clear spikes, first column contains the
        spike location in the recording and the second the main channel
        (channel whose amplitude is maximum)

    spike_index_collision: numpy.ndarray (n_collided_spikes, 2)
        2D array with indexes for collided spikes, first column contains the
        spike location in the recording and the second the main channel
        (channel whose amplitude is maximum)

    Raises
    ------
    ValueError
        If ``CONFIG.spikes.detection`` is neither ``'threshold'`` nor
        ``'nn'``

    Notes
    -----
    Running the preprocessor will generate the following files in
    CONFIG.data.root_folder/output_directory/:

    * ``config.yaml`` - Copy of the configuration file
    * ``metadata.yaml`` - Experiment metadata
    * ``filtered.bin`` - Filtered recordings
    * ``filtered.yaml`` - Filtered recordings metadata
    * ``standarized.bin`` - Standarized recordings
    * ``standarized.yaml`` - Standarized recordings metadata
    * ``whitened.bin`` - Whitened recordings
    * ``whitened.yaml`` - Whitened recordings metadata
    * ``rotation.npy`` - Rotation matrix for dimensionality reduction
    * ``spike_index_clear.npy`` - Same as spike_index_clear returned
    * ``spike_index_collision.npy`` - Same as spike_index_collision returned
    * ``score_clear.npy`` - Scores for clear spikes
    * ``waveforms_clear.npy`` - Waveforms for clear spikes

    Examples
    --------

    .. literalinclude:: ../examples/preprocess.py
    """

    logger = logging.getLogger(__name__)

    CONFIG = read_config()

    OUTPUT_DTYPE = CONFIG.preprocess.dtype

    logger.info('Output dtype for transformed data will be %s', OUTPUT_DTYPE)

    TMP = os.path.join(CONFIG.data.root_folder, output_directory)

    if not os.path.exists(TMP):
        logger.info('Creating temporary folder: %s', TMP)
        os.makedirs(TMP)
    else:
        logger.info('Temporary folder %s already exists, output will be '
                    'stored there', TMP)

    path = os.path.join(CONFIG.data.root_folder, CONFIG.data.recordings)
    dtype = CONFIG.recordings.dtype

    # initialize pipeline object, one batch per channel
    pipeline = BatchPipeline(path, dtype, CONFIG.recordings.n_channels,
                             CONFIG.recordings.format,
                             CONFIG.resources.max_memory, TMP)

    # add filter transformation if necessary
    if CONFIG.preprocess.filter:
        filter_op = Transform(butterworth,
                              'filtered.bin',
                              mode='single_channel_one_batch',
                              keep=True,
                              if_file_exists='skip',
                              cast_dtype=OUTPUT_DTYPE,
                              low_freq=CONFIG.filter.low_pass_freq,
                              high_factor=CONFIG.filter.high_factor,
                              order=CONFIG.filter.order,
                              sampling_freq=CONFIG.recordings.sampling_rate)

        pipeline.add([filter_op])

    # the unpacking expects exactly one output path and one params dict
    # NOTE(review): when CONFIG.preprocess.filter is false no operation is
    # added - confirm pipeline.run() still returns one-element sequences
    (filtered_path, ), (filtered_params, ) = pipeline.run()

    # standarize: the standard deviation is estimated from the first batch
    # only and then applied to the whole recording
    bp = BatchProcessor(filtered_path, filtered_params['dtype'],
                        filtered_params['n_channels'],
                        filtered_params['data_format'],
                        CONFIG.resources.max_memory)
    batches = bp.multi_channel()
    first_batch, _, _ = next(batches)
    sd = standard_deviation(first_batch, CONFIG.recordings.sampling_rate)

    (standarized_path, standarized_params) = bp.multi_channel_apply(
        standarize,
        mode='disk',
        output_path=os.path.join(TMP, 'standarized.bin'),
        if_file_exists='skip',
        cast_dtype=OUTPUT_DTYPE,
        sd=sd)

    standarized = RecordingsReader(standarized_path)
    n_observations = standarized.observations

    # dispatch to the configured detector
    detection = CONFIG.spikes.detection

    if detection == 'threshold':
        return _threshold_detection(standarized_path, standarized_params,
                                    n_observations, output_directory)

    if detection == 'nn':
        return _neural_network_detection(standarized_path, standarized_params,
                                         n_observations, output_directory)

    # fail loudly instead of implicitly returning None, which would only
    # surface later as an unpacking error in the caller
    raise ValueError("Unknown spike detection method {!r}, expected "
                     "'threshold' or 'nn'".format(detection))
Esempio n. 23
0
def test_splitting_in_batches_does_not_affect(path_to_tests,
                                              path_to_standarized_data,
                                              path_to_sample_pipeline_folder):
    """Detection output must not depend on how the data is batched

    Runs detect/triage/featurize once over the whole recording and once
    through a BatchProcessor limited to 100KB batches, then checks that
    scores, clear spikes and collided spikes are identical.
    """
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    PATH_TO_DATA = path_to_standarized_data

    data = RecordingsReader(PATH_TO_DATA, loader='array').data

    # safe_load: the file only holds plain metadata, and yaml.load without
    # an explicit Loader is unsafe/deprecated (an error in PyYAML >= 6)
    with open(
            path.join(path_to_sample_pipeline_folder, 'preprocess',
                      'standarized.yaml')) as f:
        PARAMS = yaml.safe_load(f)

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename

    # instantiate neural networks, triage and autoencoder consume the
    # waveforms produced by the detector
    NND = NeuralNetDetector.load(detection_fname, detection_th, channel_index)
    NNT = NeuralNetTriage.load(triage_fname,
                               triage_th,
                               input_tensor=NND.waveform_tf)
    NNAE = AutoEncoder(ae_fname, input_tensor=NND.waveform_tf)

    output_tf = (NNAE.score_tf, NND.spike_index_tf, NNT.idx_clean)

    # run all at once
    with tf.Session() as sess:
        # restore the trained weights into this session
        NND.restore(sess)
        NNAE.restore(sess)
        NNT.restore(sess)

        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        (scores, clear, collision) = neuralnetwork.run_detect_triage_featurize(
            data, sess, NND.x_tf, output_tf, neighbors, rot)

    # run in batches - buffer size makes sure we can detect spikes if they
    # appear at the end of any batch
    bp = BatchProcessor(PATH_TO_DATA,
                        PARAMS['dtype'],
                        PARAMS['n_channels'],
                        PARAMS['data_order'],
                        '100KB',
                        buffer_size=CONFIG.spike_size)

    with tf.Session() as sess:
        # restore the trained weights into this session
        NND.restore(sess)
        NNAE.restore(sess)
        NNT.restore(sess)

        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        res = bp.multi_channel_apply(
            neuralnetwork.run_detect_triage_featurize,
            mode='memory',
            cleanup_function=neuralnetwork.fix_indexes,
            sess=sess,
            x_tf=NND.x_tf,
            output_tf=output_tf,
            rot=rot,
            neighbors=neighbors)

    # every element of res is one batch's (scores, clear, collision) triple,
    # stitch the batches back together before comparing
    scores_batch = np.concatenate([element[0] for element in res], axis=0)
    clear_batch = np.concatenate([element[1] for element in res], axis=0)
    collision_batch = np.concatenate([element[2] for element in res], axis=0)

    np.testing.assert_array_equal(clear_batch, clear)
    np.testing.assert_array_equal(collision_batch, collision)
    np.testing.assert_array_equal(scores_batch, scores)
Esempio n. 24
0
File: make.py Progetto: Nomow/yass
def training_data(CONFIG, templates_uncropped, min_amp, max_amp,
                  n_isolated_spikes,
                  path_to_standarized, noise_ratio=10,
                  collision_ratio=1, misalign_ratio=1, misalign_ratio2=1,
                  multi_channel=True, return_metadata=False):
    """Makes training sets for detector, triage and autoencoder

    Parameters
    ----------
    CONFIG: configuration object
        Configuration, this function reads ``neigh_channels``, ``geom`` and
        ``spike_size`` from it
    templates_uncropped: numpy.ndarray
        3D array with the templates used to generate spikes. The first
        dimension is taken as the number of templates (K); presumably the
        second is time and the third channels - verify against caller
    min_amp: float
        Minimum value allowed for the maximum absolute amplitude of the
        isolated spike on its main channel
    max_amp: float
        Maximum value allowed for the maximum absolute amplitude of the
        isolated spike on its main channel
    n_isolated_spikes: int
        Number of isolated spikes to generate. This is different from the
        total number of x_detect. ceil(n_isolated_spikes / K) spikes are
        requested per template
    path_to_standarized: str
        Path to the standarized data, passed to RecordingsReader (if it does
        not exist, run preprocess to automatically generate it). Only used
        to estimate the noise structure
    noise_ratio: int
        Ratio of number of noise to isolated spikes. For example, if
        n_isolated_spike=1000, noise_ratio=5, then n_noise=5000
    collision_ratio: int
        Ratio of number of collisions to isolated spikes.
    misalign_ratio: int
        Ratio of number of spatially and temporally misaligned spikes to
        isolated spikes
    misalign_ratio2: int
        Ratio of number of only-spatially misaligned spikes to isolated spikes
    multi_channel: bool
        If True, generate training data for multi-channel neural
        network. Otherwise generate single-channel data
    return_metadata: bool
        Forwarded as-is to util.make_collided

    Returns
    -------
    x_detect: numpy.ndarray
        [number of detection training data, temporal length, number of
        channels] Training data for the detect net. In single-channel mode
        only the first channel is kept, so the last dimension is dropped
    y_detect: numpy.ndarray
        [number of detection training data] Label for x_detect

    x_triage: numpy.ndarray
        [number of triage training data, temporal length, number of channels]
        Training data for the triage net. In single-channel mode only the
        first channel is kept, so the last dimension is dropped
    y_triage: numpy.ndarray
        [number of triage training data] Label for x_triage
    x_ae: None
        Always None, the autoencoder training set is no longer generated
    y_ae: None
        Always None, the autoencoder training set is no longer generated

    Notes
    -----
    * Detection training data
        * Multi channel
            * Positive examples: Clean spikes + noise, Collided spikes + noise
            * Negative examples: Temporally misaligned spikes + noise, Noise
        * Single channel
            * Positive examples: Clean spikes + noise
            * Negative examples: Temporally misaligned spikes + noise, Noise

    * Triage training data
        * Multi channel and single channel
            * Positive examples: Clean spikes + noise
            * Negative examples: Collided spikes + noise
    """
    # FIXME: should we add collided spikes with the first spike non-centered
    # to the detection training set?

    logger = logging.getLogger(__name__)

    # STEP 1: Load recordings data, and select one channel at random (among
    # the ones with the largest number of neighbors), then swap the channels
    # so the first one corresponds to the selected channel, then the nearest
    # neighbor, then the second nearest and so on... this is only used for
    # estimating noise structure

    # ##### FIXME: this needs to be removed, the user should already
    # pass data with the desired channels
    rec = RecordingsReader(path_to_standarized, loader='array')
    channel_n_neighbors = np.sum(CONFIG.neigh_channels, 0)
    max_neighbors = np.max(channel_n_neighbors)
    channels_with_max_neighbors = np.where(channel_n_neighbors
                                           == max_neighbors)[0]
    logger.debug('The following channels have %i neighbors: %s',
                 max_neighbors, channels_with_max_neighbors)

    # reference channel: channel with max number of neighbors, ties are
    # broken at random
    channel_selected = np.random.choice(channels_with_max_neighbors)

    logger.debug('Selected channel %i', channel_selected)

    # neighbors for the reference channel
    channel_neighbors = np.where(CONFIG.neigh_channels[channel_selected])[0]

    # ordered neighbors for reference channel
    channel_idx, _ = order_channels_by_distance(channel_selected,
                                                channel_neighbors,
                                                CONFIG.geom)
    # read the selected channels
    rec = rec[:, channel_idx]
    # ##### FIXME:end of section to be removed

    # STEP 2: load templates

    processor = TemplatesProcessor(templates_uncropped)

    # swap channels, first channel is main channel, then nearest neighbor
    # and so on, only keep neigh_channels
    templates = (processor.crop_spatially(CONFIG.neigh_channels, CONFIG.geom)
                 .values)

    # TODO: remove, this data can be obtained from other variables
    # (n_channels is only referenced by the commented-out autoencoder code)
    K, _, n_channels = templates_uncropped.shape

    # make training data set
    R = CONFIG.spike_size

    logger.debug('Output will be of size %s', 2 * R + 1)

    # make clean augmented spikes, nk spikes per template
    nk = int(np.ceil(n_isolated_spikes/K))
    max_shift = 2*R

    # make spikes from templates
    x_templates = util.make_from_templates(templates, min_amp, max_amp, nk)

    # make collided spikes - max shift is set to R since 2 * R + 1 will be
    # the final dimension for the spikes. one of the spikes is kept with the
    # main channel, the other one is shifted and channels are changed
    # NOTE(review): x_collision.shape is read below - confirm make_collided
    # still returns a plain array when return_metadata=True
    x_collision = util.make_collided(x_templates, collision_ratio,
                                     multi_channel, max_shift=R,
                                     min_shift=5,
                                     return_metadata=return_metadata)

    # make misaligned spikes
    x_temporally_misaligned = util.make_temporally_misaligned(
        x_templates, misalign_ratio,
        multi_channel=multi_channel,
        max_shift=max_shift)

    # now spatially misalign those
    x_misaligned = util.make_spatially_misaligned(x_temporally_misaligned,
                                                  n_per_spike=misalign_ratio2)

    # determine noise covariance structure, using the temporal length of the
    # cropped templates as both the waveform and the window size
    spatial_SIG, temporal_SIG = noise_cov(rec,
                                          temporal_size=templates.shape[1],
                                          window_size=templates.shape[1],
                                          sample_size=1000,
                                          threshold=3.0)

    # make noise
    n_noise = int(x_templates.shape[0] * noise_ratio)
    noise = util.make_noise(n_noise, spatial_SIG, temporal_SIG)

    # make labels, the suffix is the label value: 1 positive, 0 negative.
    # collisions get both since they are positive for detection and
    # negative for triage
    y_clean_1 = np.ones((x_templates.shape[0]))
    y_collision_1 = np.ones((x_collision.shape[0]))

    y_misaligned_0 = np.zeros((x_misaligned.shape[0]))
    y_noise_0 = np.zeros((noise.shape[0]))
    y_collision_0 = np.zeros((x_collision.shape[0]))

    # window of 2 * R + 1 samples centered at the temporal midpoint of the
    # generated spikes
    mid_point = int((x_templates.shape[1]-1)/2)
    MID_POINT_IDX = slice(mid_point - R, mid_point + R + 1)

    # TODO: replace _make_noisy for new function
    x_templates_noisy = util._make_noisy(x_templates, noise)
    x_collision_noisy = util._make_noisy(x_collision, noise)
    x_misaligned_noisy = util._make_noisy(x_misaligned,
                                          noise)

    #############
    # Detection #
    #############

    if multi_channel:
        # all channels are kept, collisions count as positive examples
        x = yarr.concatenate((x_templates_noisy, x_collision_noisy,
                              x_misaligned_noisy, noise))
        x_detect = x[:, MID_POINT_IDX, :]

        y_detect = np.concatenate((y_clean_1, y_collision_1,
                                   y_misaligned_0, y_noise_0))
    else:
        # only the first channel is kept and collisions are left out
        x = yarr.concatenate((x_templates_noisy, x_misaligned_noisy,
                              noise))
        x_detect = x[:, MID_POINT_IDX, 0]

        y_detect = yarr.concatenate((y_clean_1, y_misaligned_0, y_noise_0))
    ##########
    # Triage #
    ##########

    # both branches use the same examples and labels, they only differ in
    # the channels kept
    if multi_channel:
        x = yarr.concatenate((x_templates_noisy, x_collision_noisy))
        x_triage = x[:, MID_POINT_IDX, :]

        y_triage = yarr.concatenate((y_clean_1, y_collision_0))
    else:
        x = yarr.concatenate((x_templates_noisy, x_collision_noisy,))
        x_triage = x[:, MID_POINT_IDX, 0]

        y_triage = yarr.concatenate((y_clean_1, y_collision_0))

    ###############
    # Autoencoder #
    ###############

    # the former autoencoder training set generation is kept below for
    # reference only, it is disabled and both outputs are set to None

    # # TODO: need to abstract this part of the code, create a separate
    # # function and document it
    # neighbors_ae = np.ones((n_channels, n_channels), 'int32')

    # templates_ae = crop_and_align_templates(templates_uncropped,
    #                                         CONFIG.spike_size,
    #                                         neighbors_ae,
    #                                         CONFIG.geom)

    # tt = templates_ae.transpose(1, 0, 2).reshape(templates_ae.shape[1], -1)
    # tt = tt[:, np.ptp(tt, axis=0) > 2]
    # max_amp = np.max(np.ptp(tt, axis=0))

    # y_ae = np.zeros((nk*tt.shape[1], tt.shape[0]))

    # for k in range(tt.shape[1]):
    #     amp_now = np.ptp(tt[:, k])
    #     amps_range = (np.arange(nk)*(max_amp-min_amp)
    #                   / nk+min_amp)[:, np.newaxis, np.newaxis]

    #     y_ae[k*nk:(k+1)*nk] = ((tt[:, k]/amp_now)[np.newaxis, :]
    #                            * amps_range[:, :, 0])

    # noise_ae = np.random.normal(size=y_ae.shape)
    # noise_ae = np.matmul(noise_ae, temporal_SIG)

    # x_ae = y_ae + noise_ae
    # x_ae = x_ae[:, MID_POINT_IDX]
    # y_ae = y_ae[:, MID_POINT_IDX]

    x_ae = None
    y_ae = None

    # FIXME: y_ae is no longer used, autoencoder was replaced by PCA
    return x_detect, y_detect, x_triage, y_triage, x_ae, y_ae
Esempio n. 25
0
def noise_cov(path_to_data, neighbors, geom, temporal_size):
    """Compute noise temporal and spatial covariance

    The channel with the most neighbors is used as reference and only its
    neighboring channels (ordered by distance) are read. Samples within
    ``(temporal_size - 1) / 2`` observations of any value above 3 are
    excluded, the rest is treated as noise. The spatial covariance is
    estimated from the noise samples, the recordings are then spatially
    whitened and 1000 randomly located noise-only windows are used to
    estimate the temporal covariance.

    Parameters
    ----------
    path_to_data: str
        Path to recordings data

    neighbors: numpy.ndarray
        Neighbors matrix

    geom: numpy.ndarray
        Cartesian coordinates for the channels

    temporal_size:
        Waveform size

    Returns
    -------
    spatial_SIG: numpy.ndarray
        Square root of the spatial noise covariance, one row and column per
        selected channel

    temporal_SIG: numpy.ndarray (temporal_size, temporal_size)
        Square root of the temporal noise covariance
    """
    logger = logging.getLogger(__name__)

    logger.debug('Computing noise_cov. Neighbors shape: %s, geom shape: %s '
                 'temporal_size: %s', neighbors.shape, geom.shape,
                 temporal_size)

    # reference channel: the one with the most neighbors
    c_ref = np.argmax(np.sum(neighbors, 0))
    ch_idx = np.where(neighbors[c_ref])[0]
    ch_idx, _ = order_channels_by_distance(c_ref, ch_idx, geom)

    rec = RecordingsReader(path_to_data, loader='array')
    rec = rec[:, ch_idx]

    T, C = rec.shape
    # 1 where the observation is considered noise, 0 otherwise
    idxNoise = np.zeros((T, C))
    R = int((temporal_size-1)/2)

    for c in range(C):

        # observations above the threshold are considered signal
        idx_temp = np.where(rec[:, c] > 3)[0]

        # mask every observation within R steps of a signal observation
        for j in range(-R, R+1):

            idx_temp2 = idx_temp + j
            idx_temp2 = idx_temp2[np.logical_and(idx_temp2 >= 0,
                                                 idx_temp2 < T)]
            rec[idx_temp2, c] = np.nan

        idxNoise_temp = ~np.isnan(rec[:, c])
        # scale using the noise observations only
        rec[:, c] = rec[:, c]/np.nanstd(rec[:, c])

        # masked observations must not contribute to the products below
        rec[~idxNoise_temp, c] = 0
        idxNoise[idxNoise_temp, c] = 1

    # sum of products divided by the number of observations where both
    # channels are noise
    spatial_cov = np.divide(np.matmul(rec.T, rec),
                            np.matmul(idxNoise.T, idxNoise))

    # the covariance is symmetric: eigh returns real eigenvalues and
    # orthonormal eigenvectors, which is what makes v * f(w) * v.T a valid
    # matrix function (np.linalg.eig guarantees neither)
    w, v = np.linalg.eigh(spatial_cov)
    spatial_SIG = np.matmul(np.matmul(v, np.diag(np.sqrt(w))), v.T)

    spatial_whitener = np.matmul(np.matmul(v, np.diag(1/np.sqrt(w))), v.T)
    rec = np.matmul(rec, spatial_whitener)

    noise_wf = np.zeros((1000, temporal_size))
    count = 0

    # draw random windows until 1000 of them contain noise only
    # NOTE(review): this never terminates if the recordings have no
    # noise-only window of length temporal_size
    while count < 1000:

        tt = np.random.randint(T-temporal_size)
        cc = np.random.randint(C)
        temp = rec[tt:(tt+temporal_size), cc]
        temp_idxnoise = idxNoise[tt:(tt+temporal_size), cc]

        if np.sum(temp_idxnoise == 0) == 0:
            noise_wf[count] = temp
            count += 1

    w, v = np.linalg.eigh(np.cov(noise_wf.T))

    temporal_SIG = np.matmul(np.matmul(v, np.diag(np.sqrt(w))), v.T)

    logger.debug('spatial_SIG shape: %s temporal_SIG shape: %s',
                 spatial_SIG.shape, temporal_SIG.shape)

    return spatial_SIG, temporal_SIG