Ejemplo n.º 1
0
def test_can_make_noise(path_to_tests, path_to_standarized_data):
    """Smoke-test ``make_noise`` end to end on the standarized test data.

    Loads ``config_nnet.yaml`` from ``path_to_tests``, builds templates from
    ``spike_train``, crops/aligns them, estimates the noise covariance and
    finally calls ``make_noise`` on clean spikes generated from the
    templates. The test passes when nothing raises.

    NOTE(review): ``spike_train`` is not a parameter; it must be provided by
    the enclosing test module (not visible here).
    """
    config_path = path.join(path_to_tests, 'config_nnet.yaml')
    yass.set_config(config_path)
    cfg = yass.read_config()

    # append an int32 column of ones (one weight per spike)
    n_spikes, _ = spike_train.shape
    unit_weights = np.ones((n_spikes, 1), 'int32')
    weighted = np.hstack((spike_train, unit_weights))

    uncropped, _ = get_templates(weighted,
                                 path_to_standarized_data,
                                 cfg.resources.max_memory,
                                 4 * cfg.spike_size)

    # reverse the axis order returned by get_templates
    uncropped = np.transpose(uncropped, (2, 1, 0))

    templates = crop_and_align_templates(uncropped,
                                         cfg.spike_size,
                                         cfg.neigh_channels,
                                         cfg.geom)

    spatial_sig, temporal_sig = noise_cov(path_to_standarized_data,
                                          cfg.neigh_channels,
                                          cfg.geom,
                                          templates.shape[1])

    clean = make_clean(templates, min_amp=2, max_amp=10, nk=100)

    make_noise(clean,
               noise_ratio=10,
               templates=templates,
               spatial_SIG=spatial_sig,
               temporal_SIG=temporal_sig)
Ejemplo n.º 2
0
def get_noise_covariance(reader, save_dir, CONFIG, chunk=None):
    """Estimate spatial/temporal noise covariance and cache it on disk.

    If ``save_dir`` already contains ``spatial_sig.npy`` and
    ``temporal_sig.npy`` their paths are returned immediately and no data is
    read. Otherwise the recordings are read for the neighbourhood of one
    reference channel (the first channel with the largest number of
    neighbours), ``noise_cov`` is run on them and both results are saved.

    Parameters
    ----------
    reader
        Recordings reader exposing ``start``, ``end``, ``spike_size`` and
        ``read_data(start, end, channels)``.
    save_dir: str
        Directory where the ``.npy`` files are cached. It is created,
        including any missing parent directories, if it does not exist.
    CONFIG
        Configuration object; ``CONFIG.neigh_channels`` selects the channel
        subset.
    chunk: sequence, optional
        ``(start, end)`` range to read. Defaults to
        ``(reader.start, reader.end)``.

    Returns
    -------
    (str, str)
        Paths of the spatial and the temporal covariance ``.npy`` files.
    """
    # makedirs(exist_ok=True) also creates missing parents and does not fail
    # if another process creates the directory concurrently
    os.makedirs(save_dir, exist_ok=True)

    fname_spatial_sig = os.path.join(save_dir, 'spatial_sig.npy')
    fname_temporal_sig = os.path.join(save_dir, 'temporal_sig.npy')

    # reuse cached results when both files are present
    if os.path.exists(fname_spatial_sig) and os.path.exists(fname_temporal_sig):
        return fname_spatial_sig, fname_temporal_sig

    # only need subset of channels: the neighbourhood of the first channel
    # having the maximum number of neighbours (argmax returns the first max)
    n_neigh_channels = np.sum(CONFIG.neigh_channels, axis=1)
    c_chosen = np.argmax(n_neigh_channels)
    channels = np.where(CONFIG.neigh_channels[c_chosen])[0]

    if chunk is None:
        rec = reader.read_data(reader.start, reader.end, channels)
    else:
        rec = reader.read_data(chunk[0], chunk[1], channels)

    spatial_SIG, temporal_SIG = noise_cov(
        rec,
        temporal_size=reader.spike_size,
        window_size=reader.spike_size,
        sample_size=1000,
        threshold=3.0)

    np.save(fname_spatial_sig, spatial_SIG)
    np.save(fname_temporal_sig, temporal_SIG)

    return fname_spatial_sig, fname_temporal_sig
Ejemplo n.º 3
0
def test_can_compute_noise_cov(path_to_tests, path_to_standarized_data):
    """Smoke test: ``noise_cov`` runs on the standarized test data.

    Templates are built from ``spike_train`` only to obtain the temporal
    size handed to ``noise_cov``; the test passes when nothing raises.

    NOTE(review): ``spike_train`` is not a parameter; it must be provided by
    the enclosing test module (not visible here).
    """
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    cfg = yass.read_config()

    # append an int32 column of ones (one weight per spike)
    n_spikes, _ = spike_train.shape
    weights = np.ones((n_spikes, 1), 'int32')
    weighted_spikes = np.hstack((spike_train, weights))

    raw_templates, _ = get_templates(weighted_spikes,
                                     path_to_standarized_data,
                                     cfg.resources.max_memory,
                                     4 * cfg.spike_size)

    # reverse the axis order returned by get_templates
    raw_templates = np.transpose(raw_templates, (2, 1, 0))

    noise_cov(path_to_standarized_data,
              cfg.neigh_channels,
              cfg.geom,
              raw_templates.shape[1])
Ejemplo n.º 4
0
def test_can_compute_noise_cov(path_to_tests, path_to_standarized_data):
    """Smoke test: ``noise_cov`` accepts an in-memory recordings array.

    ``path_to_tests`` is not used by this test. The call must return two
    values (it is unpacked); nothing else is checked.
    """
    reader = RecordingsReader(path_to_standarized_data, loader='array')
    recordings = reader._data

    _spatial, _temporal = noise_cov(recordings,
                                    temporal_size=10,
                                    sample_size=100,
                                    threshold=3.0,
                                    window_size=10)
Ejemplo n.º 5
0
def test_can_estimate_temporal_and_spatial_sig(path_to_standarized_data):
    """``noise.noise_cov`` must return covariance estimates without NaNs."""
    reader = RecordingsReader(path_to_standarized_data, loader='array')
    recordings = reader._data

    estimates = noise.noise_cov(recordings,
                                temporal_size=40,
                                sample_size=1000,
                                threshold=3.0,
                                window_size=10)
    spatial_SIG, temporal_SIG = estimates

    # check no nans
    assert not np.isnan(spatial_SIG).any()
    assert not np.isnan(temporal_SIG).any()
Ejemplo n.º 6
0
def make_training_data(CONFIG, spike_train, chosen_templates, min_amp, nspikes,
                       data_folder):
    """Make training sets for the detector, triage and autoencoder networks.

    Parameters
    ----------
    CONFIG
        Configuration object; ``spikeSize``, ``neighChannels``, ``geom`` and
        ``recordings.n_channels`` are read from it.
    spike_train: numpy.ndarray
        Spike train handed to ``get_templates``.
    chosen_templates
        Template selection handed to ``choose_templates``.
    min_amp: float
        Smallest amplitude of the generated clean spikes.
    nspikes: int
        Target number of clean spikes; ``ceil(nspikes / K)`` are generated
        for each of the K selected templates.
    data_folder: str
        Folder containing ``standarized.bin`` and ``standarized.yaml``
        (produced by the preprocessor).

    Returns
    -------
    x_detect, y_detect
        Detection inputs (clean, collided and misaligned spikes plus noise,
        then pure noise) cropped to ``2 * CONFIG.spikeSize + 1`` samples
        around the centre. Labels: 1 for clean/collided, 0 for
        misaligned/noise.
    x_triage, y_triage
        Triage inputs (clean and collided spikes plus noise), same crop.
        Labels: 1 for clean, 0 for collided.
    x_ae, y_ae
        ``[number of clean spikes, 2 * CONFIG.spikeSize + 1]`` autoencoder
        inputs (channel 0 of each clean spike plus noise, jittered by up to
        one sample) and their noise-free targets.

    Raises
    ------
    ValueError
        If the standarized data is missing or no good templates are found.
    """

    logger = logging.getLogger(__name__)

    path_to_data = os.path.join(data_folder, 'standarized.bin')
    path_to_config = os.path.join(data_folder, 'standarized.yaml')

    # make sure standarized data already exists
    if not os.path.exists(path_to_data):
        raise ValueError(
            'Standarized data does not exist in: {}, this is '
            'needed to generate training data, run the '
            'preprocesor first to generate it'.format(path_to_data))

    PARAMS = load_yaml(path_to_config)

    logger.info('Getting templates...')

    # get templates
    templates, _ = get_templates(spike_train, path_to_data, CONFIG.spikeSize)

    # reverse the axis order so that axis 0 indexes templates, axis 1 is
    # time and axis 2 is channels (as used by the code below)
    templates = np.transpose(templates, (2, 1, 0))

    logger.info('Got templates ndarray of shape: {}'.format(templates.shape))

    # choose good templates (good looking and big enough)
    templates = choose_templates(templates, chosen_templates)

    if templates.shape[0] == 0:
        raise ValueError("Coulndt find any good templates...")

    logger.info('Good looking templates of shape: {}'.format(templates.shape))

    # align and crop templates
    templates = crop_templates(templates, CONFIG.spikeSize,
                               CONFIG.neighChannels, CONFIG.geom)

    # determine noise covariance structure
    spatial_SIG, temporal_SIG = noise_cov(path_to_data, PARAMS['dtype'],
                                          CONFIG.recordings.n_channels,
                                          PARAMS['data_format'],
                                          CONFIG.neighChannels, CONFIG.geom,
                                          templates.shape[1])

    # make training data set
    K = templates.shape[0]
    R = CONFIG.spikeSize
    amps = np.max(np.abs(templates), axis=1)

    # make clean augmented spikes
    nk = int(np.ceil(nspikes / K))
    max_amp = np.max(amps) * 1.5
    nneigh = templates.shape[2]

    ################
    # clean spikes #
    ################
    # every template is rescaled to nk evenly spaced peak amplitudes in
    # [min_amp, max_amp); the amplitude grid does not depend on the template
    amps_range = (np.arange(nk) * (max_amp - min_amp) / nk +
                  min_amp)[:, np.newaxis, np.newaxis]
    x_clean = np.zeros((nk * K, templates.shape[1], templates.shape[2]))
    for k in range(K):
        tt = templates[k]
        amp_now = np.max(np.abs(tt))
        x_clean[k * nk:(k + 1) *
                nk] = (tt / amp_now)[np.newaxis, :, :] * amps_range

    #############
    # collision #
    #############
    x_collision = np.zeros(x_clean.shape)
    max_shift = 2 * R

    # random shifts in [-max_shift, max_shift) pushed away from zero:
    # negative shifts end up <= -6, non-negative ones >= 6
    temporal_shifts = np.random.randint(max_shift * 2, size=nk * K) - max_shift
    temporal_shifts[
        temporal_shifts < 0] = temporal_shifts[temporal_shifts < 0] - 5
    temporal_shifts[
        temporal_shifts >= 0] = temporal_shifts[temporal_shifts >= 0] + 6

    amp_per_data = np.max(x_clean[:, :, 0], axis=1)
    for j in range(nk * K):
        shift = temporal_shifts[j]

        x_collision[j] = np.copy(x_clean[j])
        # partner: a random clean spike whose channel-0 maximum exceeds 30%
        # of the current spike's, with its channels randomly permuted
        idx_candidate = np.where(amp_per_data > amp_per_data[j] * 0.3)[0]
        idx_match = idx_candidate[np.random.randint(idx_candidate.shape[0],
                                                    size=1)[0]]
        x_clean2 = np.copy(x_clean[idx_match]
                           [:,
                            np.random.choice(nneigh, nneigh, replace=False)])

        # add the partner displaced in time by `shift` samples
        if shift > 0:
            x_collision[j, :(x_collision.shape[1] - shift)] += x_clean2[shift:]

        elif shift < 0:
            x_collision[j,
                        (-shift):] += x_clean2[:(x_collision.shape[1] + shift)]
        else:
            x_collision[j] += x_clean2

    #####################
    # misaligned spikes #
    #####################
    x_misaligned = np.zeros(x_clean.shape)

    temporal_shifts = np.random.randint(max_shift * 2, size=nk * K) - max_shift
    temporal_shifts[
        temporal_shifts < 0] = temporal_shifts[temporal_shifts < 0] - 5
    temporal_shifts[
        temporal_shifts >= 0] = temporal_shifts[temporal_shifts >= 0] + 6

    n_samples = x_misaligned.shape[1]
    for j in range(nk * K):
        shift = temporal_shifts[j]
        # the spike itself with its channels permuted, displaced in time
        x_clean2 = np.copy(
            x_clean[j][:, np.random.choice(nneigh, nneigh, replace=False)])

        if shift > 0:
            x_misaligned[j, :(n_samples - shift)] += x_clean2[shift:]

        elif shift < 0:
            x_misaligned[j, (-shift):] += x_clean2[:(n_samples + shift)]
        else:
            x_misaligned[j] += x_clean2

    #########
    # noise #
    #########

    # get noise: white Gaussian noise multiplied by temporal_SIG along time
    # (per channel), then by spatial_SIG across channels
    noise = np.random.normal(size=x_clean.shape)
    for c in range(noise.shape[2]):
        noise[:, :, c] = np.matmul(noise[:, :, c], temporal_SIG)

    # reshape once, after every channel has been processed
    reshaped_noise = np.reshape(noise, (-1, noise.shape[2]))
    noise = np.reshape(np.matmul(reshaped_noise, spatial_SIG),
                       [x_clean.shape[0], x_clean.shape[1], x_clean.shape[2]])

    y_clean = np.ones((x_clean.shape[0]))
    y_col = np.ones((x_clean.shape[0]))
    y_misaligned = np.zeros((x_clean.shape[0]))
    y_noise = np.zeros((x_clean.shape[0]))

    mid_point = int((x_clean.shape[1] - 1) / 2)

    # get training set for detection
    x = np.concatenate(
        (x_clean + noise,
         x_collision + noise[np.random.permutation(noise.shape[0])],
         x_misaligned + noise[np.random.permutation(noise.shape[0])], noise))

    x_detect = x[:, (mid_point - R):(mid_point + R + 1), :]
    y_detect = np.concatenate((y_clean, y_col, y_misaligned, y_noise))

    # get training set for triage
    x = np.concatenate((
        x_clean + noise,
        x_collision + noise[np.random.permutation(noise.shape[0])],
    ))
    x_triage = x[:, (mid_point - R):(mid_point + R + 1), :]
    y_triage = np.concatenate((y_clean, np.zeros((x_clean.shape[0]))))

    # get training set for auto encoder: channel 0 only, window jittered by
    # a random shift in {-1, 0, 1}
    ae_shift_max = 1
    temporal_shifts_ae = np.random.randint(
        ae_shift_max * 2 + 1, size=x_clean.shape[0]) - ae_shift_max
    y_ae = np.zeros((x_clean.shape[0], 2 * R + 1))
    x_ae = np.zeros((x_clean.shape[0], 2 * R + 1))
    for j in range(x_ae.shape[0]):
        window = slice(mid_point - R + temporal_shifts_ae[j],
                       mid_point + R + 1 + temporal_shifts_ae[j])
        y_ae[j] = x_clean[j, window, 0]
        x_ae[j] = x_clean[j, window, 0] + noise[j, window, 0]

    return x_detect, y_detect, x_triage, y_triage, x_ae, y_ae
Ejemplo n.º 7
0
Archivo: make.py Proyecto: Nomow/yass
def training_data(CONFIG, templates_uncropped, min_amp, max_amp,
                  n_isolated_spikes,
                  path_to_standarized, noise_ratio=10,
                  collision_ratio=1, misalign_ratio=1, misalign_ratio2=1,
                  multi_channel=True, return_metadata=False):
    """Makes training sets for detector, triage and autoencoder

    Parameters
    ----------
    CONFIG: yaml file
        Configuration file. ``neigh_channels``, ``geom`` and ``spike_size``
        are read from it
    templates_uncropped: numpy.ndarray
        3-D array of templates; the first axis indexes templates (the number
        of templates K is read from ``shape[0]``). Presumably laid out as
        (templates, time, channels) -- TODO confirm. It is cropped to each
        template's channel neighbourhood with
        ``TemplatesProcessor.crop_spatially``
    min_amp: float
        Minimum value allowed for the maximum absolute amplitude of the
        isolated spike on its main channel
    max_amp: float
        Maximum value allowed for the maximum absolute amplitude of the
        isolated spike on its main channel
    n_isolated_spikes: int
        Number of isolated spikes to generate. This is different from the
        total number of x_detect. ``ceil(n_isolated_spikes / K)`` spikes are
        generated per template
    path_to_standarized: str
        Folder storing the standarized data (if not exist, run preprocess to
        automatically generate). It is loaded with ``RecordingsReader`` and
        only used to estimate the noise covariance
    noise_ratio: int
        Ratio of number of noise to isolated spikes. For example, if
        n_isolated_spike=1000, noise_ratio=5, then n_noise=5000
    collision_ratio: int
        Ratio of number of collisions to isolated spikes.
    misalign_ratio: int
        Ratio of number of spatially and temporally misaligned spikes to
        isolated spikes
    misalign_ratio2: int
        Ratio of number of only-spatially misaligned spikes to isolated spikes
    multi_channel: bool
        If True, generate training data for multi-channel neural
        network. Otherwise generate single-channel data (only the first
        channel is kept)
    return_metadata: bool
        Forwarded to ``util.make_collided``

    Returns
    -------
    x_detect: numpy.ndarray
        [number of detection training data, temporal length, number of
        channels] Training data for the detect net. The temporal length is
        ``2 * CONFIG.spike_size + 1``; the channel axis is dropped when
        ``multi_channel`` is False
    y_detect: numpy.ndarray
        [number of detection training data] Label for x_detect

    x_triage: numpy.ndarray
        [number of triage training data, temporal length, number of channels]
        Training data for the triage net (channel axis dropped when
        ``multi_channel`` is False)
    y_triage: numpy.ndarray
        [number of triage training data] Label for x_triage
    x_ae: None
        Always None: autoencoder training data is no longer generated (the
        autoencoder was replaced by PCA)
    y_ae: None
        Always None, see x_ae

    Notes
    -----
    * Detection training data
        * Multi channel
            * Positive examples: Clean spikes + noise, Collided spikes + noise
            * Negative examples: Temporally and spatially misaligned spikes +
              noise, Noise
        * Single channel
            * Positive examples: Clean spikes + noise
            * Negative examples: Temporally and spatially misaligned spikes +
              noise, Noise

    * Triage training data (multi and single channel)
        * Positive examples: Clean spikes + noise
        * Negative examples: Collided spikes + noise
    """
    # FIXME: should we add collided spikes with the first spike non-centered
    # to the detection training set?

    logger = logging.getLogger(__name__)

    # STEP1: Load recordings data, and select one channel at random (with the
    # right number of neighbors), then swap the channels so the first one
    # corresponds to the selected channel, then the nearest neighbor, then the
    # second nearest and so on... this is only used for estimating noise
    # structure

    # ##### FIXME: this needs to be removed, the user should already
    # pass data with the desired channels
    rec = RecordingsReader(path_to_standarized, loader='array')
    # number of neighbours per channel (sums along axis 0)
    # NOTE(review): other code sums neigh_channels along axis 1; both agree
    # only if neigh_channels is symmetric -- confirm
    channel_n_neighbors = np.sum(CONFIG.neigh_channels, 0)
    max_neighbors = np.max(channel_n_neighbors)
    channels_with_max_neighbors = np.where(channel_n_neighbors
                                           == max_neighbors)[0]
    logger.debug('The following channels have %i neighbors: %s',
                 max_neighbors, channels_with_max_neighbors)

    # reference channel: channel with max number of neighbors (chosen at
    # random among ties, so results depend on numpy's global RNG state)
    channel_selected = np.random.choice(channels_with_max_neighbors)

    logger.debug('Selected channel %i', channel_selected)

    # neighbors for the reference channel
    channel_neighbors = np.where(CONFIG.neigh_channels[channel_selected])[0]

    # ordered neighbors for reference channel
    channel_idx, _ = order_channels_by_distance(channel_selected,
                                                channel_neighbors,
                                                CONFIG.geom)
    # read the selected channels, in the order returned above
    rec = rec[:, channel_idx]
    # ##### FIXME:end of section to be removed

    # STEP 2: load templates

    processor = TemplatesProcessor(templates_uncropped)

    # swap channels, first channel is main channel, then nearest neighbor
    # and so on, only keep neigh_channels
    templates = (processor.crop_spatially(CONFIG.neigh_channels, CONFIG.geom)
                 .values)

    # TODO: remove, this data can be obtained from other variables
    # (n_channels is currently unused)
    K, _, n_channels = templates_uncropped.shape

    # make training data set
    R = CONFIG.spike_size

    logger.debug('Output will be of size %s', 2 * R + 1)

    # make clean augmented spikes
    nk = int(np.ceil(n_isolated_spikes/K))
    max_shift = 2*R

    # make spikes from templates
    x_templates = util.make_from_templates(templates, min_amp, max_amp, nk)

    # make collided spikes - max shift is set to R since 2 * R + 1 will be
    # the final dimension for the spikes. one of the spikes is kept with the
    # main channel, the other one is shifted and channels are changed
    x_collision = util.make_collided(x_templates, collision_ratio,
                                     multi_channel, max_shift=R,
                                     min_shift=5,
                                     return_metadata=return_metadata)

    # make misaligned spikes
    x_temporally_misaligned = util.make_temporally_misaligned(
        x_templates, misalign_ratio,
        multi_channel=multi_channel,
        max_shift=max_shift)

    # now spatially misalign those
    x_misaligned = util.make_spatially_misaligned(x_temporally_misaligned,
                                                  n_per_spike=misalign_ratio2)

    # determine noise covariance structure from the selected recordings
    spatial_SIG, temporal_SIG = noise_cov(rec,
                                          temporal_size=templates.shape[1],
                                          window_size=templates.shape[1],
                                          sample_size=1000,
                                          threshold=3.0)

    # make noise: noise_ratio noise examples per clean spike
    n_noise = int(x_templates.shape[0] * noise_ratio)
    noise = util.make_noise(n_noise, spatial_SIG, temporal_SIG)

    # make labels (suffix _1 = positive, _0 = negative example)
    y_clean_1 = np.ones((x_templates.shape[0]))
    y_collision_1 = np.ones((x_collision.shape[0]))

    y_misaligned_0 = np.zeros((x_misaligned.shape[0]))
    y_noise_0 = np.zeros((noise.shape[0]))
    y_collision_0 = np.zeros((x_collision.shape[0]))

    # central window of 2 * R + 1 samples along the time axis
    mid_point = int((x_templates.shape[1]-1)/2)
    MID_POINT_IDX = slice(mid_point - R, mid_point + R + 1)

    # TODO: replace _make_noisy for new function
    x_templates_noisy = util._make_noisy(x_templates, noise)
    x_collision_noisy = util._make_noisy(x_collision, noise)
    x_misaligned_noisy = util._make_noisy(x_misaligned,
                                          noise)

    #############
    # Detection #
    #############

    if multi_channel:
        x = yarr.concatenate((x_templates_noisy, x_collision_noisy,
                              x_misaligned_noisy, noise))
        x_detect = x[:, MID_POINT_IDX, :]

        # NOTE(review): np.concatenate here, yarr.concatenate in the
        # single-channel branch -- confirm the difference is intended
        y_detect = np.concatenate((y_clean_1, y_collision_1,
                                   y_misaligned_0, y_noise_0))
    else:
        # single channel: collisions are not part of the detection set and
        # only the first channel is kept
        x = yarr.concatenate((x_templates_noisy, x_misaligned_noisy,
                              noise))
        x_detect = x[:, MID_POINT_IDX, 0]

        y_detect = yarr.concatenate((y_clean_1, y_misaligned_0, y_noise_0))
    ##########
    # Triage #
    ##########

    if multi_channel:
        x = yarr.concatenate((x_templates_noisy, x_collision_noisy))
        x_triage = x[:, MID_POINT_IDX, :]

        y_triage = yarr.concatenate((y_clean_1, y_collision_0))
    else:
        # same examples as the multi-channel case, first channel only
        x = yarr.concatenate((x_templates_noisy, x_collision_noisy,))
        x_triage = x[:, MID_POINT_IDX, 0]

        y_triage = yarr.concatenate((y_clean_1, y_collision_0))

    ###############
    # Autoencoder #
    ###############

    # NOTE(review): the commented-out code below is the former autoencoder
    # data generation; it is dead (x_ae/y_ae are returned as None) and is a
    # candidate for deletion.

    # # TODO: need to abstract this part of the code, create a separate
    # # function and document it
    # neighbors_ae = np.ones((n_channels, n_channels), 'int32')

    # templates_ae = crop_and_align_templates(templates_uncropped,
    #                                         CONFIG.spike_size,
    #                                         neighbors_ae,
    #                                         CONFIG.geom)

    # tt = templates_ae.transpose(1, 0, 2).reshape(templates_ae.shape[1], -1)
    # tt = tt[:, np.ptp(tt, axis=0) > 2]
    # max_amp = np.max(np.ptp(tt, axis=0))

    # y_ae = np.zeros((nk*tt.shape[1], tt.shape[0]))

    # for k in range(tt.shape[1]):
    #     amp_now = np.ptp(tt[:, k])
    #     amps_range = (np.arange(nk)*(max_amp-min_amp)
    #                   / nk+min_amp)[:, np.newaxis, np.newaxis]

    #     y_ae[k*nk:(k+1)*nk] = ((tt[:, k]/amp_now)[np.newaxis, :]
    #                            * amps_range[:, :, 0])

    # noise_ae = np.random.normal(size=y_ae.shape)
    # noise_ae = np.matmul(noise_ae, temporal_SIG)

    # x_ae = y_ae + noise_ae
    # x_ae = x_ae[:, MID_POINT_IDX]
    # y_ae = y_ae[:, MID_POINT_IDX]

    x_ae = None
    y_ae = None

    # FIXME: y_ae is no longer used, autoencoder was replaced by PCA
    return x_detect, y_detect, x_triage, y_triage, x_ae, y_ae
Ejemplo n.º 8
0
def make_training_data(CONFIG,
                       spike_train,
                       chosen_templates,
                       min_amp,
                       nspikes,
                       data_folder,
                       noise_ratio=10,
                       collision_ratio=1,
                       misalign_ratio=1,
                       misalign_ratio2=1,
                       multi=True):
    """Make training sets for the detector, triage and autoencoder networks.

    Parameters
    ----------
    CONFIG: yaml file
        Configuration file; ``spike_size``, ``neigh_channels``, ``geom``,
        ``resources.max_memory`` and ``recordings.n_channels`` are read
    spike_train: numpy.ndarray
        [number of spikes, 2] Ground truth for training. First column is the
        spike time, second column is the spike id
    chosen_templates: list
        List of chosen templates' id's
    min_amp: float
        Minimum value allowed for the maximum absolute amplitude of the
        isolated spike on its main channel
    nspikes: int
        Number of isolated spikes to generate. This is different from the
        total number of x_detect. ``ceil(nspikes / K)`` spikes are generated
        for each of the K chosen templates
    data_folder: str
        Folder containing ``standarized.bin`` and ``standarized.yaml`` (run
        the preprocessor to generate them)
    noise_ratio: int
        Ratio of number of noise to isolated spikes. For example, if
        n_isolated_spike=1000, noise_ratio=5, then n_noise=5000
    collision_ratio: int
        Ratio of number of collisions to isolated spikes.
    misalign_ratio: int
        Ratio of number of spatially and temporally misaligned spikes to
        isolated spikes
    misalign_ratio2: int
        Ratio of number of only-spatially misaligned spikes to isolated spikes
    multi: bool
        If multi= True, generate training data for multi-channel neural
        network. Otherwise generate single-channel data

    Returns
    -------
    x_detect: numpy.ndarray
        [number of detection training data, temporal length, number of
        channels] Training data for the detect net; temporal length is
        ``2 * CONFIG.spike_size + 1``. When ``multi`` is False only the
        first channel is kept and collided spikes are left out.
    y_detect: numpy.ndarray
        [number of detection training data] 1 for clean and collided
        spikes, 0 for misaligned spikes and noise
    x_triage: numpy.ndarray
        [number of triage training data, temporal length, number of channels]
        Clean and collided spikes (plus, when ``multi``, spatially
        misaligned spikes), each with noise. First channel only when
        ``multi`` is False.
    y_triage: numpy.ndarray
        [number of triage training data] 1 for clean spikes, 0 otherwise
    x_ae: numpy.ndarray
        [number of ae training data, temporal length] Training data for the
        autoencoder: single-channel template waveforms plus temporally
        correlated noise
    y_ae: numpy.ndarray
        [number of ae training data, temporal length] Denoised x_ae

    Raises
    ------
    ValueError
        If the standarized data does not exist or no good templates remain
    """

    logger = logging.getLogger(__name__)

    path_to_data = os.path.join(data_folder, 'standarized.bin')
    path_to_config = os.path.join(data_folder, 'standarized.yaml')

    # make sure standarized data already exists
    if not os.path.exists(path_to_data):
        raise ValueError(
            'Standarized data does not exist in: {}, this is '
            'needed to generate training data, run the '
            'preprocesor first to generate it'.format(path_to_data))

    PARAMS = load_yaml(path_to_config)

    logger.info('Getting templates...')

    # get templates; an int32 column of ones is appended to the spike train
    # (presumably per-spike weights expected by get_templates)
    templates, _ = get_templates(
        np.hstack((spike_train, np.ones((spike_train.shape[0], 1), 'int32'))),
        path_to_data, CONFIG.resources.max_memory, 4 * CONFIG.spike_size)

    # reverse the axis order so that axis 0 indexes templates, axis 1 is
    # time and axis 2 is channels (as used by the code below)
    templates = np.transpose(templates, (2, 1, 0))

    logger.info('Got templates ndarray of shape: {}'.format(templates.shape))

    # choose good templates (good looking and big enough)
    templates = choose_templates(templates, chosen_templates)
    # kept for the autoencoder section, which crops over all channels
    templates_uncropped = np.copy(templates)

    if templates.shape[0] == 0:
        raise ValueError("Coulndt find any good templates...")

    logger.info('Good looking templates of shape: {}'.format(templates.shape))

    # align and crop templates
    templates = crop_templates(templates, CONFIG.spike_size,
                               CONFIG.neigh_channels, CONFIG.geom)

    # determine noise covariance structure
    spatial_SIG, temporal_SIG = noise_cov(path_to_data, PARAMS['dtype'],
                                          CONFIG.recordings.n_channels,
                                          PARAMS['data_order'],
                                          CONFIG.neigh_channels, CONFIG.geom,
                                          templates.shape[1])

    # make training data set
    K = templates.shape[0]
    R = CONFIG.spike_size
    amps = np.max(np.abs(templates), axis=1)

    # make clean augmented spikes
    nk = int(np.ceil(nspikes / K))
    max_amp = np.max(amps) * 1.5
    nneigh = templates.shape[2]

    ################
    # clean spikes #
    ################
    # every template is rescaled to nk evenly spaced peak amplitudes in
    # [min_amp, max_amp); the amplitude grid does not depend on the template
    amps_range = (np.arange(nk) * (max_amp - min_amp) / nk +
                  min_amp)[:, np.newaxis, np.newaxis]
    x_clean = np.zeros((nk * K, templates.shape[1], templates.shape[2]))
    for k in range(K):
        tt = templates[k]
        amp_now = np.max(np.abs(tt))
        x_clean[k * nk:(k + 1) *
                nk] = (tt / amp_now)[np.newaxis, :, :] * amps_range

    #############
    # collision #
    #############
    x_collision = np.zeros((x_clean.shape[0] * int(collision_ratio),
                            templates.shape[1], templates.shape[2]))
    max_shift = 2 * R

    # random shifts in [-max_shift, max_shift) pushed away from zero:
    # negative shifts end up <= -6, non-negative ones >= 6
    temporal_shifts = np.random.randint(max_shift * 2,
                                        size=x_collision.shape[0]) - max_shift
    temporal_shifts[
        temporal_shifts < 0] = temporal_shifts[temporal_shifts < 0] - 5
    temporal_shifts[
        temporal_shifts >= 0] = temporal_shifts[temporal_shifts >= 0] + 6

    amp_per_data = np.max(x_clean[:, :, 0], axis=1)
    for j in range(x_collision.shape[0]):
        shift = temporal_shifts[j]

        # base spike: a random clean spike
        pick = np.random.choice(x_clean.shape[0], 1, replace=True)
        x_collision[j] = x_clean[pick[0]]
        # partner: a random clean spike whose channel-0 maximum exceeds 30%
        # of the base spike's
        idx_candidate = np.where(
            amp_per_data > np.max(x_collision[j, :, 0]) * 0.3)[0]
        idx_match = idx_candidate[np.random.randint(idx_candidate.shape[0],
                                                    size=1)[0]]
        if multi:
            # multi-channel: the partner's channels are randomly permuted
            x_clean2 = np.copy(
                x_clean[idx_match]
                [:, np.random.choice(nneigh, nneigh, replace=False)])
        else:
            x_clean2 = np.copy(x_clean[idx_match])

        # add the partner displaced in time by `shift` samples
        if shift > 0:
            x_collision[j, :(x_collision.shape[1] - shift)] += x_clean2[shift:]

        elif shift < 0:
            x_collision[j,
                        (-shift):] += x_clean2[:(x_collision.shape[1] + shift)]
        else:
            x_collision[j] += x_clean2

    ##############################################
    # temporally and spatially misaligned spikes #
    ##############################################
    x_misaligned = np.zeros((x_clean.shape[0] * int(misalign_ratio),
                             templates.shape[1], templates.shape[2]))

    temporal_shifts = np.random.randint(max_shift * 2,
                                        size=x_misaligned.shape[0]) - max_shift
    temporal_shifts[
        temporal_shifts < 0] = temporal_shifts[temporal_shifts < 0] - 5
    temporal_shifts[
        temporal_shifts >= 0] = temporal_shifts[temporal_shifts >= 0] + 6

    for j in range(x_misaligned.shape[0]):
        shift = temporal_shifts[j]
        # index with pick[0] rather than np.squeeze so the (time, channels)
        # shape is kept even when either dimension has size one
        pick = np.random.choice(x_clean.shape[0], 1, replace=True)
        if multi:
            perm = np.random.choice(nneigh, nneigh, replace=False)
            x_clean2 = x_clean[pick[0]][:, perm]
        else:
            x_clean2 = x_clean[pick[0]]

        if shift > 0:
            x_misaligned[j, :(x_misaligned.shape[1] -
                              shift)] += x_clean2[shift:]

        elif shift < 0:
            x_misaligned[j, (-shift):] += x_clean2[:(x_misaligned.shape[1] +
                                                     shift)]
        else:
            x_misaligned[j] += x_clean2

    ###############################
    # spatially misaligned spikes #
    ###############################
    if multi:
        x_misaligned2 = np.zeros((x_clean.shape[0] * int(misalign_ratio2),
                                  templates.shape[1], templates.shape[2]))
        for j in range(x_misaligned2.shape[0]):
            # a random clean spike with its channels randomly permuted
            pick = np.random.choice(x_clean.shape[0], 1, replace=True)
            perm = np.random.choice(nneigh, nneigh, replace=False)
            x_misaligned2[j] = x_clean[pick[0]][:, perm]

    #########
    # noise #
    #########

    # get noise: white Gaussian noise multiplied by temporal_SIG along time
    # (per channel), then by spatial_SIG across channels
    noise = np.random.normal(size=[
        x_clean.shape[0] *
        int(noise_ratio), templates.shape[1], templates.shape[2]
    ])
    for c in range(noise.shape[2]):
        noise[:, :, c] = np.matmul(noise[:, :, c], temporal_SIG)

    # reshape once, after every channel has been processed
    reshaped_noise = np.reshape(noise, (-1, noise.shape[2]))
    noise = np.reshape(np.matmul(reshaped_noise, spatial_SIG),
                       [noise.shape[0], x_clean.shape[1], x_clean.shape[2]])

    y_clean = np.ones((x_clean.shape[0]))
    y_col = np.ones((x_collision.shape[0]))
    y_misaligned = np.zeros((x_misaligned.shape[0]))
    if multi:
        y_misaligned2 = np.zeros((x_misaligned2.shape[0]))
    y_noise = np.zeros((noise.shape[0]))

    mid_point = int((x_clean.shape[1] - 1) / 2)

    # get training set for detection; each group gets distinct noise samples
    # (drawn without replacement, so noise_ratio must cover each group)
    if multi:
        x = np.concatenate(
            (x_clean + noise[np.random.choice(
                noise.shape[0], x_clean.shape[0], replace=False)],
             x_collision + noise[np.random.choice(
                 noise.shape[0], x_collision.shape[0], replace=False)],
             x_misaligned + noise[np.random.choice(
                 noise.shape[0], x_misaligned.shape[0], replace=False)],
             noise))

        x_detect = x[:, (mid_point - R):(mid_point + R + 1), :]
        y_detect = np.concatenate((y_clean, y_col, y_misaligned, y_noise))
    else:
        x = np.concatenate(
            (x_clean + noise[np.random.choice(
                noise.shape[0], x_clean.shape[0], replace=False)],
             x_misaligned + noise[np.random.choice(
                 noise.shape[0], x_misaligned.shape[0], replace=False)],
             noise))
        x_detect = x[:, (mid_point - R):(mid_point + R + 1), 0]
        y_detect = np.concatenate((y_clean, y_misaligned, y_noise))

    # get training set for triage
    if multi:
        x = np.concatenate((
            x_clean + noise[np.random.choice(
                noise.shape[0], x_clean.shape[0], replace=False)],
            x_collision + noise[np.random.choice(
                noise.shape[0], x_collision.shape[0], replace=False)],
            x_misaligned2 + noise[np.random.choice(
                noise.shape[0], x_misaligned2.shape[0], replace=False)],
        ))
        x_triage = x[:, (mid_point - R):(mid_point + R + 1), :]
        y_triage = np.concatenate((y_clean, np.zeros(
            (x_collision.shape[0])), y_misaligned2))
    else:
        x = np.concatenate((
            x_clean + noise[np.random.choice(
                noise.shape[0], x_clean.shape[0], replace=False)],
            x_collision + noise[np.random.choice(
                noise.shape[0], x_collision.shape[0], replace=False)],
        ))
        x_triage = x[:, (mid_point - R):(mid_point + R + 1), 0]
        y_triage = np.concatenate((y_clean, np.zeros((x_collision.shape[0]))))

    ###############
    # Autoencoder #
    ###############

    # crop the uncropped templates with an all-ones neighbourhood matrix,
    # i.e. keeping every channel
    n_channels = templates_uncropped.shape[2]
    templates_ae = crop_templates(templates_uncropped, CONFIG.spike_size,
                                  np.ones((n_channels, n_channels), 'int32'),
                                  CONFIG.geom)

    # one column per (template, channel) waveform; keep those with ptp > 2
    tt = templates_ae.transpose(1, 0, 2).reshape(templates_ae.shape[1], -1)
    tt = tt[:, np.ptp(tt, axis=0) > 2]
    max_amp = np.max(np.ptp(tt, axis=0))

    y_ae = np.zeros((nk * tt.shape[1], tt.shape[0]))
    for k in range(tt.shape[1]):
        amp_now = np.ptp(tt[:, k])
        amps_range = (np.arange(nk) * (max_amp - min_amp) / nk +
                      min_amp)[:, np.newaxis, np.newaxis]

        y_ae[k * nk:(k + 1) * nk] = ((tt[:, k] / amp_now)[np.newaxis, :] *
                                     amps_range[:, :, 0])

    # temporally correlated noise for the single-channel waveforms
    noise_ae = np.random.normal(size=y_ae.shape)
    noise_ae = np.matmul(noise_ae, temporal_SIG)

    x_ae = y_ae + noise_ae
    x_ae = x_ae[:, (mid_point - R):(mid_point + R + 1)]
    y_ae = y_ae[:, (mid_point - R):(mid_point + R + 1)]

    return x_detect, y_detect, x_triage, y_triage, x_ae, y_ae
Ejemplo n.º 9
0
def make_training_data(CONFIG,
                       spike_train,
                       chosen_templates_indexes,
                       min_amp,
                       nspikes,
                       data_folder,
                       noise_ratio=10,
                       collision_ratio=1,
                       misalign_ratio=1,
                       misalign_ratio2=1,
                       multi_channel=True):
    """Makes training sets for detector, triage and autoencoder

    Parameters
    ----------
    CONFIG: yass configuration object
        Configuration (e.g. as returned by ``yass.read_config()``); this
        function reads ``spike_size``, ``neigh_channels``, ``geom`` and
        ``resources.max_memory`` from it
    spike_train: numpy.ndarray
        [number of spikes, 2] Ground truth for training. First column is the
        spike time, second column is the spike id
    chosen_templates_indexes: list
        List of chosen templates' id's
    min_amp: float
        Minimum value allowed for the maximum absolute amplitude of the
        isolated spike on its main channel (also used as the smallest
        amplitude of the autoencoder examples)
    nspikes: int
        Number of isolated spikes to generate. This is different from the
        total number of x_detect
    data_folder: str
        Folder storing the standarized data (if not exist, run preprocess to
        automatically generate)
    noise_ratio: int
        Ratio of number of noise to isolated spikes. For example, if
        n_isolated_spike=1000, noise_ratio=5, then n_noise=5000
    collision_ratio: int
        Ratio of number of collisions to isolated spikes.
    misalign_ratio: int
        Ratio of number of spatially and temporally misaligned spikes to
        isolated spikes
    misalign_ratio2: int
        Ratio of number of only-spatially misaligned spikes to isolated spikes
    multi_channel: bool
        If True, generate training data for multi-channel neural
        network. Otherwise generate single-channel data (only the first
        channel is kept)

    Returns
    -------
    x_detect: numpy.ndarray
        [number of detection training data, temporal length, number of
        channels] Training data for the detect net ([number of detection
        training data, temporal length] when multi_channel is False).
    y_detect: numpy.ndarray
        [number of detection training data] Label for x_detect
    x_triage: numpy.ndarray
        [number of triage training data, temporal length, number of channels]
        Training data for the triage net ([number of triage training data,
        temporal length] when multi_channel is False).
    y_triage: numpy.ndarray
        [number of triage training data] Label for x_triage
    x_ae: numpy.ndarray
        [number of ae training data, temporal length] Training data for the
        autoencoder: noisy spikes
    y_ae: numpy.ndarray
        [number of ae training data, temporal length] Denoised x_ae

    Raises
    ------
    ValueError
        If the standarized data is missing, if no template survives the
        selection, or if no template waveform is large enough to build the
        autoencoder training set

    Notes
    -----
    * Detection training data
        * Multi channel
            * Positive examples: Clean spikes + noise, Collided spikes + noise
            * Negative examples: Temporally misaligned spikes + noise, Noise
        * Single channel
            * Positive examples: Clean spikes + noise
            * Negative examples: Temporally misaligned spikes + noise, Noise

    * Triage training data
        * Multi channel
            * Positive examples: Clean spikes + noise
            * Negative examples: Collided spikes + noise,
                spatially misaligned spikes  + noise
        * Single channel
            * Positive examples: Clean spikes + noise
            * Negative examples: Collided spikes + noise
    """

    logger = logging.getLogger(__name__)

    path_to_data = os.path.join(data_folder, 'preprocess', 'standarized.bin')

    n_spikes, _ = spike_train.shape

    # make sure standarized data already exists
    if not os.path.exists(path_to_data):
        raise ValueError(
            'Standarized data does not exist in: {}, this is '
            'needed to generate training data, run the '
            'preprocessor first to generate it'.format(path_to_data))

    logger.info('Getting templates...')

    # add weight of one to every spike
    weighted_spike_train = np.hstack(
        (spike_train, np.ones((n_spikes, 1), 'int32')))

    # get templates
    templates_uncropped, _ = get_templates(weighted_spike_train, path_to_data,
                                           CONFIG.resources.max_memory,
                                           4 * CONFIG.spike_size)

    templates_uncropped = np.transpose(templates_uncropped, (2, 1, 0))

    # channel count is read before template selection; choose_templates is
    # assumed to drop templates (axis 0) only, not channels - TODO confirm
    _, _, n_channels = templates_uncropped.shape

    logger.info('Got templates ndarray of shape: {}'.format(
        templates_uncropped.shape))

    # choose good templates (user selected and amplitude above threshold)
    # TODO: maybe the minimum_amplitude parameter should be selected by the
    # user
    templates_uncropped = choose_templates(templates_uncropped,
                                           chosen_templates_indexes,
                                           minimum_amplitude=4)

    # K must count the templates that survived selection: nk below is the
    # number of spikes generated per template, so using the pre-selection
    # count would produce fewer than ``nspikes`` clean spikes
    K = templates_uncropped.shape[0]

    if K == 0:
        raise ValueError("Couldn't find any good templates...")

    logger.info('Good looking templates of shape: {}'.format(
        templates_uncropped.shape))

    templates = crop_and_align_templates(templates_uncropped,
                                         CONFIG.spike_size,
                                         CONFIG.neigh_channels, CONFIG.geom)

    # make training data set
    R = CONFIG.spike_size

    # make clean augmented spikes
    nk = int(np.ceil(nspikes / K))
    max_amp = np.max(np.abs(templates)) * 1.5
    nneigh = templates.shape[2]
    max_shift = 2 * R

    # make clean spikes
    x_clean = make_clean(templates, min_amp, max_amp, nk)

    # make collided spikes
    x_collision = make_collided(x_clean, collision_ratio, templates, R,
                                multi_channel, nneigh)

    # make misaligned spikes
    (x_temporally_misaligned,
     x_spatially_misaligned) = make_misaligned(x_clean, templates, max_shift,
                                               misalign_ratio, misalign_ratio2,
                                               multi_channel, nneigh)

    # determine noise covariance structure
    spatial_SIG, temporal_SIG = noise_cov(path_to_data, CONFIG.neigh_channels,
                                          CONFIG.geom, templates.shape[1])

    # make noise
    noise = make_noise(x_clean, noise_ratio, templates, spatial_SIG,
                       temporal_SIG)

    # make labels
    y_clean_1 = np.ones((x_clean.shape[0]))
    y_collision_1 = np.ones((x_collision.shape[0]))

    y_misaligned_0 = np.zeros((x_temporally_misaligned.shape[0]))
    y_noise_0 = np.zeros((noise.shape[0]))
    y_collision_0 = np.zeros((x_collision.shape[0]))

    # every example is cut to the central window of 2 * R + 1 samples
    mid_point = int((x_clean.shape[1] - 1) / 2)
    MID_POINT_IDX = slice(mid_point - R, mid_point + R + 1)

    x_clean_noisy = make_noisy(x_clean, noise)
    x_collision_noisy = make_noisy(x_collision, noise)
    x_temporally_misaligned_noisy = make_noisy(x_temporally_misaligned, noise)
    x_spatially_misaligned_noisy = make_noisy(x_spatially_misaligned, noise)

    #############
    # Detection #
    #############

    if multi_channel:
        x = np.concatenate((x_clean_noisy, x_collision_noisy,
                            x_temporally_misaligned_noisy, noise))
        x_detect = x[:, MID_POINT_IDX, :]

        y_detect = np.concatenate(
            (y_clean_1, y_collision_1, y_misaligned_0, y_noise_0))
    else:
        x = np.concatenate(
            (x_clean_noisy, x_temporally_misaligned_noisy, noise))
        x_detect = x[:, MID_POINT_IDX, 0]

        y_detect = np.concatenate((y_clean_1, y_misaligned_0, y_noise_0))

    ##########
    # Triage #
    ##########

    if multi_channel:
        y_misaligned2_0 = np.zeros((x_spatially_misaligned.shape[0]))

        x = np.concatenate(
            (x_clean_noisy, x_collision_noisy, x_spatially_misaligned_noisy))
        x_triage = x[:, MID_POINT_IDX, :]

        y_triage = np.concatenate((y_clean_1, y_collision_0, y_misaligned2_0))
    else:
        x = np.concatenate((x_clean_noisy, x_collision_noisy))
        x_triage = x[:, MID_POINT_IDX, 0]

        y_triage = np.concatenate((y_clean_1, y_collision_0))

    ###############
    # Autoencoder #
    ###############

    # all-ones neighbor matrix: the autoencoder templates keep every channel
    neighbors_ae = np.ones((n_channels, n_channels), 'int32')

    templates_ae = crop_and_align_templates(templates_uncropped,
                                            CONFIG.spike_size, neighbors_ae,
                                            CONFIG.geom)

    x_ae, y_ae = _make_autoencoder_data(templates_ae, nk, min_amp,
                                        temporal_SIG, MID_POINT_IDX)

    return x_detect, y_detect, x_triage, y_triage, x_ae, y_ae


def _make_autoencoder_data(templates_ae, nk, min_amp, temporal_SIG, window):
    """Build (noisy, clean) single-channel waveform pairs for the autoencoder.

    Every (template, channel) waveform whose peak-to-peak amplitude exceeds 2
    is normalized to unit peak-to-peak and rescaled to ``nk`` evenly spaced
    amplitudes in ``[min_amp, max_amp)``, where ``max_amp`` is the largest
    peak-to-peak amplitude among the kept waveforms. Gaussian white noise
    multiplied by ``temporal_SIG`` is added to form the noisy inputs.

    Parameters
    ----------
    templates_ae: numpy.ndarray
        [number of templates, temporal length, number of channels] templates
    nk: int
        Number of amplitude-scaled copies per waveform
    min_amp: float
        Smallest target peak-to-peak amplitude
    temporal_SIG: numpy.ndarray
        [temporal length, temporal length] matrix the white noise is
        right-multiplied by
    window: slice
        Temporal window kept in the returned arrays

    Returns
    -------
    x_ae: numpy.ndarray
        [number of kept waveforms * nk, window length] noisy waveforms
    y_ae: numpy.ndarray
        [number of kept waveforms * nk, window length] clean waveforms

    Raises
    ------
    ValueError
        If no waveform has a peak-to-peak amplitude above 2
    """
    n_times = templates_ae.shape[1]

    # one column per (template, channel) waveform: [temporal length, K * C]
    waveforms = templates_ae.transpose(1, 0, 2).reshape(n_times, -1)
    waveforms = waveforms[:, np.ptp(waveforms, axis=0) > 2]

    if waveforms.shape[1] == 0:
        raise ValueError('No template waveform has a peak-to-peak amplitude '
                         'above 2, cannot build autoencoder training data')

    ptps = np.ptp(waveforms, axis=0)
    max_amp = np.max(ptps)

    # loop-invariant: the same nk target amplitudes are used for every
    # waveform
    amps_range = np.arange(nk) * (max_amp - min_amp) / nk + min_amp

    # [n_waveforms, temporal length] waveforms with unit peak-to-peak
    unit_waveforms = (waveforms / ptps).T

    # broadcast to [n_waveforms, nk, temporal length] and flatten so that
    # row k * nk + j holds waveform k scaled to amps_range[j]
    y_ae = (unit_waveforms[:, np.newaxis, :] *
            amps_range[np.newaxis, :, np.newaxis]).reshape(-1, n_times)
    y_ae = y_ae.astype(np.float64, copy=False)

    noise_ae = np.random.normal(size=y_ae.shape)
    noise_ae = np.matmul(noise_ae, temporal_SIG)

    x_ae = y_ae + noise_ae

    return x_ae[:, window], y_ae[:, window]