Example #1
    def crop_spatially(self, neighbors, geometry, inplace=False):
        n_templates, waveform_length, _ = self.templates.shape

        # spatially crop (only keep neighbors)
        n_neigh_to_keep = np.max(np.sum(neighbors, 0))
        new_templates = np.zeros(
            (n_templates, waveform_length, n_neigh_to_keep))

        for k in range(n_templates):

            # get neighbors for the main channel in the kth template
            ch_idx = np.where(neighbors[self.main_channels[k]])[0]

            # order channels
            ch_idx, _ = order_channels_by_distance(self.main_channels[k],
                                                   ch_idx, geometry)

            # new kth template is the old kth template by keeping only
            # ordered neighboring channels
            new_templates[k, :, :ch_idx.shape[0]] = self.templates[k][:,
                                                                      ch_idx]

        if inplace:
            self._update_templates(new_templates)
        else:
            return TemplatesProcessor(new_templates)
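A minimal usage sketch for the method above (not part of the original source). It assumes TemplatesProcessor can be built directly from a (n_templates, waveform_length, n_channels) array, as in Example #5, and exposes the cropped array through its .values attribute; the geometry and neighbors matrix are toy values for illustration.

import numpy as np

# toy data: 2 templates, 61 samples, 4 channels on a straight line
templates = np.random.randn(2, 61, 4)
geometry = np.array([[0.0, 0.0], [10.0, 0.0], [20.0, 0.0], [30.0, 0.0]])

# boolean adjacency: channels within 15 um of each other are neighbors
dist = np.linalg.norm(geometry[:, None] - geometry[None, :], axis=2)
neighbors = dist <= 15.0

processor = TemplatesProcessor(templates)
cropped = processor.crop_spatially(neighbors, geometry)  # returns a new processor
print(cropped.values.shape)  # (2, 61, max number of neighbors per channel)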
Example #2
def matrix_localized(ts, neighbors, geom, spike_size):
    """Spatial whitening filter for time series
    [How is this different from the other method?]

    Parameters
    ----------
    ts: np.array
        T x C numpy array, where T is the number of time samples and
        C is the number of channels

    Returns
    -------
    numpy.ndarray (n_neighbors, n_neighbors, n_channels)
        localized whitening matrix for each channel, restricted to its
        distance-ordered neighboring channels
    """
    # get all necessary parameters
    [T, C] = ts.shape
    R = spike_size * 2 + 1
    th = 4
    nneigh = np.max(np.sum(neighbors, 0))

    # masked recording
    spikes_rec = np.ones(ts.shape)
    for i in range(0, C):
        idxCrossing = np.where(ts[:, i] < -th)[0]
        idxCrossing = idxCrossing[np.logical_and(idxCrossing >= (R + 1),
                                                 idxCrossing <= (T - R - 1))]
        spike_time = idxCrossing[np.logical_and(
            ts[idxCrossing, i] <= ts[idxCrossing - 1, i],
            ts[idxCrossing, i] <= ts[idxCrossing + 1, i])]

        # the portion of the recording where spikes are present is masked (set to 0)
        for j in np.arange(-spike_size, spike_size + 1):
            spikes_rec[spike_time + j, i] = 0

    # get covariance matrix
    blanked_rec = ts * spikes_rec
    M = np.matmul(blanked_rec.transpose(), blanked_rec) / \
        np.matmul(spikes_rec.transpose(), spikes_rec)

    # since ts is a standardized recording, covariance = correlation
    invhalf_var = np.diag(np.power(np.diag(M), -0.5))
    M = np.matmul(np.matmul(invhalf_var, M), invhalf_var)

    # get localized whitening filter
    Q = np.zeros((nneigh, nneigh, C))
    for c in range(0, C):
        ch_idx, _ = order_channels_by_distance(c,
                                               np.where(neighbors[c])[0], geom)
        nneigh_c = ch_idx.shape[0]

        V, D, _ = np.linalg.svd(M[ch_idx, :][:, ch_idx])
        eps = 1e-6
        Epsilon = np.diag(1 / np.power((D + eps), 0.5))
        Q_small = np.matmul(np.matmul(V, Epsilon), V.transpose())
        Q[:nneigh_c, :nneigh_c, c] = Q_small

    return Q
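A quick sanity check for the function above (illustrative only, not from the original source): on a purely random standardized recording the estimated correlations are close to the identity, so each localized whitening filter should also end up close to an identity matrix. The neighbors matrix is built by hand from a toy geometry.

import numpy as np

T, C = 20000, 4
ts = np.random.randn(T, C)        # synthetic standardized recording
geom = np.array([[0.0, 0.0], [10.0, 0.0], [20.0, 0.0], [30.0, 0.0]])

dist = np.linalg.norm(geom[:, None] - geom[None, :], axis=2)
neighbors = dist <= 15.0

Q = matrix_localized(ts, neighbors, geom, spike_size=15)
print(Q.shape)   # (max neighbors, max neighbors, C)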
Example #3
def crop_templates(templatesBig, R, neighbors, geom):
    """[Description]

    Parameters
    ----------

    Returns
    -------
    """
    # number of templates
    K = templatesBig.shape[0]

    # main channel for each template and amplitudes
    mainC = np.argmax(np.amax(np.abs(templatesBig), axis=1), axis=1)
    amps = np.amax(np.abs(templatesBig), axis=(1, 2))

    # get a template on a main channel and align them
    K_big = np.argmax(amps)
    templates_mainc = np.zeros((K, templatesBig.shape[1]))
    t_rec = templatesBig[K_big, :, mainC[K_big]]
    t_rec = t_rec / np.sqrt(np.sum(np.square(t_rec)))
    for k in range(K):
        t1 = templatesBig[k, :, mainC[k]]
        t1 = t1 / np.sqrt(np.sum(np.square(t1)))
        shift = align_templates(t1, t_rec)
        if shift > 0:
            templates_mainc[k, :(templatesBig.shape[1] - shift)] = t1[shift:]
            templatesBig[k, :(templatesBig.shape[1] -
                              shift)] = templatesBig[k, shift:]

        elif shift < 0:
            templates_mainc[k,
                            (-shift):] = t1[:(templatesBig.shape[1] + shift)]
            templatesBig[k,
                         (-shift):] = templatesBig[k, :(templatesBig.shape[1] +
                                                        shift)]

        else:
            templates_mainc[k] = t1

    # determine the temporal center of templates and crop around it
    R2 = int(R / 2)
    center = np.argmax(
        np.convolve(np.sum(np.square(templates_mainc), 0), np.ones(2 * R2 + 1),
                    'valid')) + R2
    templatesBig = templatesBig[:, (center - 3 * R):(center + 3 * R + 1)]

    # spatially crop
    nneigh = np.max(np.sum(neighbors, 0))
    templatesBig2 = np.zeros(
        (templatesBig.shape[0], templatesBig.shape[1], nneigh))
    for k in range(K):
        ch_idx = np.where(neighbors[mainC[k]])[0]
        ch_idx, temp = order_channels_by_distance(mainC[k], ch_idx, geom)
        templatesBig2[k, :, :ch_idx.shape[0]] = templatesBig[k][:, ch_idx]

    return templatesBig2
Example #4
def covariance(recordings, temporal_size, neigbor_steps):
    """Compute noise spatial and temporal covariance

    Parameters
    ----------
    recordings: matrix
        Multi-channel recordings (n observations x n channels)
    temporal_size: int
        Waveform size
    neigbor_steps: int
        Number of steps from the multi-channel geometry to consider two
        channels as neighbors
    """
    CONFIG = read_config()

    # get the neighbor channels at a max "neigbor_steps" steps
    neigh_channels = n_steps_neigh_channels(CONFIG.neighChannels,
                                            neigbor_steps)

    # sum neighbor flags for every channel, this gives the number of neighbors
    # per channel, then find the channel with the most neighbors
    # TODO: why are we selecting this one?
    channel = np.argmax(np.sum(neigh_channels, 0))

    # get the neighbor channels for "channel"
    (neighbords_idx, ) = np.where(neigh_channels[channel])

    # order neighbors by distance
    neighbords_idx, temp = order_channels_by_distance(channel, neighbords_idx,
                                                      CONFIG.geom)

    # from the multi-channel recordings, get the neighbor channels
    # (this includes the channel with the most neighbors itself)
    rec = recordings[:, neighbords_idx]

    # filter recording
    if CONFIG.preprocess.filter == 1:
        rec = butterworth(rec, CONFIG.filter.low_pass_freq,
                          CONFIG.filter.high_factor, CONFIG.filter.order,
                          CONFIG.recordings.sampling_rate)

    # standardize recording
    sd_ = standarize.sd(rec, CONFIG.recordings.sampling_rate)
    rec = standarize.standarize(rec, sd_)

    # compute and return spatial and temporal covariance
    return util.covariance(rec, temporal_size, neigbor_steps, CONFIG.spikeSize)
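A hypothetical call (not in the original source). The return structure is delegated to util.covariance, so treating it as a (spatial, temporal) pair is an assumption based on the one-line summary; recordings would be the in-memory multi-channel array, and the YASS configuration must already be set up so that read_config() works inside the function. The numeric values are illustrative.

# assumed return order: spatial covariance first, temporal second
spatial_cov, temporal_cov = covariance(recordings,
                                       temporal_size=61,
                                       neigbor_steps=2)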
Example #5
File: make.py Project: Nomow/yass
def training_data(CONFIG, templates_uncropped, min_amp, max_amp,
                  n_isolated_spikes,
                  path_to_standarized, noise_ratio=10,
                  collision_ratio=1, misalign_ratio=1, misalign_ratio2=1,
                  multi_channel=True, return_metadata=False):
    """Makes training sets for detector, triage and autoencoder

    Parameters
    ----------
    CONFIG: yaml file
        Configuration file
    min_amp: float
        Minimum value allowed for the maximum absolute amplitude of the
        isolated spike on its main channel
    max_amp: float
        Maximum value allowed for the maximum absolute amplitude of the
        isolated spike on its main channel
    n_isolated_spikes: int
        Number of isolated spikes to generate. This is different from the
        total number of x_detect
    path_to_standarized: str
        Folder storing the standarized data (if it does not exist, run the
        preprocess step to generate it automatically)
    noise_ratio: int
        Ratio of number of noise to isolated spikes. For example, if
        n_isolated_spike=1000, noise_ratio=5, then n_noise=5000
    collision_ratio: int
        Ratio of number of collisions to isolated spikes.
    misalign_ratio: int
        Ratio of number of spatially and temporally misaligned spikes to
        isolated spikes
    misalign_ratio2: int
        Ratio of number of only-spatially misaligned spikes to isolated spikes
    multi_channel: bool
        If True, generate training data for multi-channel neural
        network. Otherwise generate single-channel data

    Returns
    -------
    x_detect: numpy.ndarray
        [number of detection training data, temporal length, number of
        channels] Training data for the detect net.
    y_detect: numpy.ndarray
        [number of detection training data] Label for x_detect

    x_triage: numpy.ndarray
        [number of triage training data, temporal length, number of channels]
        Training data for the triage net.
    y_triage: numpy.ndarray
        [number of triage training data] Label for x_triage
    x_ae: numpy.ndarray
        [number of ae training data, temporal length] Training data for the
        autoencoder: noisy spikes
    y_ae: numpy.ndarray
        [number of ae training data, temporal length] Denoised x_ae

    Notes
    -----
    * Detection training data
        * Multi channel
            * Positive examples: Clean spikes + noise, Collided spikes + noise
            * Negative examples: Temporally misaligned spikes + noise, Noise

    * Triage training data
        * Multi channel
            * Positive examples: Clean spikes + noise
            * Negative examples: Collided spikes + noise
    """
    # FIXME: should we add collided spikes with the first spike non-centered
    # to the detection training set?

    logger = logging.getLogger(__name__)

    # STEP 1: Load recordings data and select one channel at random (with the
    # right number of neighbors), then swap the channels so the first one
    # corresponds to the selected channel, then the nearest neighbor, then the
    # second nearest and so on... this is only used for estimating the noise
    # structure

    # ##### FIXME: this needs to be removed, the user should already
    # pass data with the desired channels
    rec = RecordingsReader(path_to_standarized, loader='array')
    channel_n_neighbors = np.sum(CONFIG.neigh_channels, 0)
    max_neighbors = np.max(channel_n_neighbors)
    channels_with_max_neighbors = np.where(channel_n_neighbors
                                           == max_neighbors)[0]
    logger.debug('The following channels have %i neighbors: %s',
                 max_neighbors, channels_with_max_neighbors)

    # reference channel: channel with max number of neighbors
    channel_selected = np.random.choice(channels_with_max_neighbors)

    logger.debug('Selected channel %i', channel_selected)

    # neighbors for the reference channel
    channel_neighbors = np.where(CONFIG.neigh_channels[channel_selected])[0]

    # ordered neighbors for reference channel
    channel_idx, _ = order_channels_by_distance(channel_selected,
                                                channel_neighbors,
                                                CONFIG.geom)
    # read the selected channels
    rec = rec[:, channel_idx]
    # ##### FIXME:end of section to be removed

    # STEP 2: load templates

    processor = TemplatesProcessor(templates_uncropped)

    # swap channels, first channel is main channel, then nearest neighbor
    # and so on, only keep neigh_channels
    templates = (processor.crop_spatially(CONFIG.neigh_channels, CONFIG.geom)
                 .values)

    # TODO: remove, this data can be obtained from other variables
    K, _, n_channels = templates_uncropped.shape

    # make training data set
    R = CONFIG.spike_size

    logger.debug('Output will be of size %s', 2 * R + 1)

    # make clean augmented spikes
    nk = int(np.ceil(n_isolated_spikes/K))
    max_shift = 2*R

    # make spikes from templates
    x_templates = util.make_from_templates(templates, min_amp, max_amp, nk)

    # make collided spikes - max shift is set to R since 2 * R + 1 will be
    # the final dimension for the spikes. one of the spikes is kept with the
    # main channel, the other one is shifted and channels are changed
    x_collision = util.make_collided(x_templates, collision_ratio,
                                     multi_channel, max_shift=R,
                                     min_shift=5,
                                     return_metadata=return_metadata)

    # make misaligned spikes
    x_temporally_misaligned = util.make_temporally_misaligned(
        x_templates, misalign_ratio,
        multi_channel=multi_channel,
        max_shift=max_shift)

    # now spatially misalign those
    x_misaligned = util.make_spatially_misaligned(x_temporally_misaligned,
                                                  n_per_spike=misalign_ratio2)

    # determine noise covariance structure
    spatial_SIG, temporal_SIG = noise_cov(rec,
                                          temporal_size=templates.shape[1],
                                          window_size=templates.shape[1],
                                          sample_size=1000,
                                          threshold=3.0)

    # make noise
    n_noise = int(x_templates.shape[0] * noise_ratio)
    noise = util.make_noise(n_noise, spatial_SIG, temporal_SIG)

    # make labels
    y_clean_1 = np.ones((x_templates.shape[0]))
    y_collision_1 = np.ones((x_collision.shape[0]))

    y_misaligned_0 = np.zeros((x_misaligned.shape[0]))
    y_noise_0 = np.zeros((noise.shape[0]))
    y_collision_0 = np.zeros((x_collision.shape[0]))

    mid_point = int((x_templates.shape[1]-1)/2)
    MID_POINT_IDX = slice(mid_point - R, mid_point + R + 1)

    # TODO: replace _make_noisy for new function
    x_templates_noisy = util._make_noisy(x_templates, noise)
    x_collision_noisy = util._make_noisy(x_collision, noise)
    x_misaligned_noisy = util._make_noisy(x_misaligned,
                                          noise)

    #############
    # Detection #
    #############

    if multi_channel:
        x = yarr.concatenate((x_templates_noisy, x_collision_noisy,
                              x_misaligned_noisy, noise))
        x_detect = x[:, MID_POINT_IDX, :]

        y_detect = np.concatenate((y_clean_1, y_collision_1,
                                   y_misaligned_0, y_noise_0))
    else:
        x = yarr.concatenate((x_templates_noisy, x_misaligned_noisy,
                              noise))
        x_detect = x[:, MID_POINT_IDX, 0]

        y_detect = yarr.concatenate((y_clean_1, y_misaligned_0, y_noise_0))
    ##########
    # Triage #
    ##########

    if multi_channel:
        x = yarr.concatenate((x_templates_noisy, x_collision_noisy))
        x_triage = x[:, MID_POINT_IDX, :]

        y_triage = yarr.concatenate((y_clean_1, y_collision_0))
    else:
        x = yarr.concatenate((x_templates_noisy, x_collision_noisy,))
        x_triage = x[:, MID_POINT_IDX, 0]

        y_triage = yarr.concatenate((y_clean_1, y_collision_0))

    ###############
    # Autoencoder #
    ###############

    # # TODO: need to abstract this part of the code, create a separate
    # # function and document it
    # neighbors_ae = np.ones((n_channels, n_channels), 'int32')

    # templates_ae = crop_and_align_templates(templates_uncropped,
    #                                         CONFIG.spike_size,
    #                                         neighbors_ae,
    #                                         CONFIG.geom)

    # tt = templates_ae.transpose(1, 0, 2).reshape(templates_ae.shape[1], -1)
    # tt = tt[:, np.ptp(tt, axis=0) > 2]
    # max_amp = np.max(np.ptp(tt, axis=0))

    # y_ae = np.zeros((nk*tt.shape[1], tt.shape[0]))

    # for k in range(tt.shape[1]):
    #     amp_now = np.ptp(tt[:, k])
    #     amps_range = (np.arange(nk)*(max_amp-min_amp)
    #                   / nk+min_amp)[:, np.newaxis, np.newaxis]

    #     y_ae[k*nk:(k+1)*nk] = ((tt[:, k]/amp_now)[np.newaxis, :]
    #                            * amps_range[:, :, 0])

    # noise_ae = np.random.normal(size=y_ae.shape)
    # noise_ae = np.matmul(noise_ae, temporal_SIG)

    # x_ae = y_ae + noise_ae
    # x_ae = x_ae[:, MID_POINT_IDX]
    # y_ae = y_ae[:, MID_POINT_IDX]

    x_ae = None
    y_ae = None

    # FIXME: y_ae is no longer used, autoencoder was replaced by PCA
    return x_detect, y_detect, x_triage, y_triage, x_ae, y_ae
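A hypothetical call showing how the generator above fits together (all numeric values and the path are illustrative); CONFIG is the parsed YASS configuration and templates_uncropped a (n_templates, waveform_length, n_channels) array. Note that x_ae and y_ae come back as None because the autoencoder was replaced by PCA.

x_detect, y_detect, x_triage, y_triage, x_ae, y_ae = training_data(
    CONFIG,
    templates_uncropped,
    min_amp=3,
    max_amp=60,
    n_isolated_spikes=10000,
    path_to_standarized='tmp/standarized.bin',   # hypothetical path
    noise_ratio=10,
    collision_ratio=1,
    misalign_ratio=1,
    misalign_ratio2=1,
    multi_channel=True)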
Example #6
def nn_detection(recordings, neighbors, geom, temporal_features,
                 temporal_window, th_detect, th_triage, detector_filename,
                 autoencoder_filename, triage_filename):
    """Detect spikes using a neural network

    Parameters
    ----------
    recordings: numpy.ndarray (n_observations, n_channels)
        Neural recordings

    neighbors: numpy.ndarray (n_channels, n_channels)
        Channel neighbors matrix

    geom: numpy.ndarray (n_channels, 2)
        Cartesian coordinates for the channels

    temporal_features: int
        ?

    temporal_window: int
        ?

    th_detect: float?
        Spike threshold [improve this explanation]

    th_triage: float?
        Triage threshold [improve this explanation]

    detector_filename: str
        Path to neural network detector

    autoencoder_filename: str
        Path to neural network autoencoder

    triage_filename: str
        Path to triage neural network

    Returns
    -------
    score: numpy.ndarray (n_spikes, n_features, n_channels)
        3D array with the scores for the clear spikes, first dimension is
        the number of spikes, second is the number of features and third the
        number of channels

    spike_index_clear: numpy.ndarray (n_clear_spikes, 2)
        2D array with indexes for clear spikes, first column contains the
        spike location in the recording and the second the main channel
        (channel whose amplitude is maximum)

    spike_index_collision: numpy.ndarray (n_collided_spikes, 2)
        2D array with indexes for collided spikes, first column contains the
        spike location in the recording and the second the main channel
        (channel whose amplitude is maximum)
    """
    nnd = NeuralNetDetector(detector_filename, autoencoder_filename)
    nnt = NeuralNetTriage(triage_filename)

    T, C = recordings.shape

    a, b = neighbors.shape

    if a != b:
        raise ValueError('neighbors is not a square matrix, verify')

    if a != C:
        raise ValueError(
            'Number of channels in recording are {} but the '
            'neighbors matrix has {} elements, they must match'.format(C, a))

    # neighboring channel info
    nneigh = np.max(np.sum(neighbors, 0))
    c_idx = np.ones((C, nneigh), 'int32') * C

    for c in range(C):
        ch_idx, temp = order_channels_by_distance(c,
                                                  np.where(neighbors[c])[0],
                                                  geom)
        c_idx[c, :ch_idx.shape[0]] = ch_idx

    # input
    x_tf = tf.placeholder("float", [T, C])

    # detect spike index
    local_max_idx_tf = nnd.get_spikes(x_tf, T, nneigh, c_idx, temporal_window,
                                      th_detect)

    # get score train
    score_train_tf = nnd.get_score_train(x_tf)

    # get energy for detected index
    energy_tf = tf.reduce_sum(tf.square(score_train_tf), axis=2)
    energy_val_tf = tf.gather_nd(energy_tf, local_max_idx_tf)

    # get triage probability
    triage_prob_tf = nnt.triage_prob(x_tf, T, nneigh, c_idx)

    # gather all results above
    result = (local_max_idx_tf, score_train_tf, energy_val_tf, triage_prob_tf)

    # remove duplicates
    energy_train_tf = tf.placeholder("float", [T, C])
    spike_index_tf = remove_duplicate_spikes_by_energy(energy_train_tf, T,
                                                       c_idx, temporal_window)

    # get score
    score_train_placeholder = tf.placeholder("float",
                                             [T, C, temporal_features])
    spike_index_clear_tf = tf.placeholder("int64", [None, 2])
    score_tf = get_score(score_train_placeholder, spike_index_clear_tf, T,
                         temporal_features, c_idx)

    ###############################
    # get values of above tensors #
    ###############################

    with tf.Session() as sess:

        nnd.saver.restore(sess, nnd.path_to_detector_model)
        nnd.saver_ae.restore(sess, nnd.path_to_ae_model)
        nnt.saver.restore(sess, nnt.path_to_triage_model)

        local_max_idx, score_train, energy_val, triage_prob = sess.run(
            result, feed_dict={x_tf: recordings})

        energy_train = np.zeros((T, C))
        energy_train[local_max_idx[:, 0], local_max_idx[:, 1]] = energy_val
        spike_index = sess.run(spike_index_tf,
                               feed_dict={energy_train_tf: energy_train})

        idx_clean = triage_prob[spike_index[:, 0], spike_index[:,
                                                               1]] > th_triage

        spike_index_clear = spike_index[idx_clean]
        spike_index_collision = spike_index[~idx_clean]

        score = sess.run(score_tf,
                         feed_dict={
                             score_train_placeholder: score_train,
                             spike_index_clear_tf: spike_index_clear
                         })

    return score, spike_index_clear, spike_index_collision
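A hypothetical call; the checkpoint paths, thresholds and window sizes below are placeholders (not values from the original project), and recordings, neighbors and geom follow the shapes given in the docstring.

score, spike_index_clear, spike_index_collision = nn_detection(
    recordings,                       # (n_observations, n_channels)
    neighbors,                        # (n_channels, n_channels) boolean
    geom,                             # (n_channels, 2)
    temporal_features=3,
    temporal_window=7,
    th_detect=0.5,
    th_triage=0.5,
    detector_filename='detect.ckpt',      # hypothetical checkpoint paths
    autoencoder_filename='ae.ckpt',
    triage_filename='triage.ckpt')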
Example #7
File: crop.py Project: Nomow/yass
def crop_and_align_templates(big_templates,
                             R,
                             neighbors,
                             geom,
                             crop_spatially=True):
    """Crop (spatially) and align (temporally) templates

    Parameters
    ----------

    Returns
    -------
    """
    logger = logging.getLogger(__name__)

    logger.debug('crop and align input shape %s', big_templates.shape)

    # copy templates to avoid modifying the original ones
    big_templates = np.copy(big_templates)

    n_templates, _, _ = big_templates.shape

    # main channels and amplitudes for each template
    main_ch = main_channels(big_templates)
    amps = amplitudes(big_templates)

    # get a template on a main channel and align them
    K_big = np.argmax(amps)
    templates_mainc = np.zeros((n_templates, big_templates.shape[1]))
    t_rec = big_templates[K_big, :, main_ch[K_big]]
    t_rec = t_rec / np.sqrt(np.sum(np.square(t_rec)))

    for k in range(n_templates):
        t1 = big_templates[k, :, main_ch[k]]
        t1 = t1 / np.sqrt(np.sum(np.square(t1)))
        shift = align_templates(t1, t_rec)

        logger.debug('Template %i will be shifted by %i', k, shift)

        if shift > 0:
            templates_mainc[k, :(big_templates.shape[1] - shift)] = t1[shift:]
            big_templates[k, :(big_templates.shape[1] -
                               shift)] = big_templates[k, shift:]

        elif shift < 0:
            templates_mainc[k,
                            (-shift):] = t1[:(big_templates.shape[1] + shift)]
            big_templates[k, (-shift):] = big_templates[k, :(
                big_templates.shape[1] + shift)]

        else:
            templates_mainc[k] = t1

    # determine the temporal center of templates and crop around it
    R2 = int(R / 2)
    center = np.argmax(
        np.convolve(np.sum(np.square(templates_mainc), 0), np.ones(2 * R2 + 1),
                    'valid')) + R2

    # crop templates around the energy center (final length: 6 * R + 1 samples)
    logger.debug('6*R+1 %s', 6 * R + 1)
    big_templates = big_templates[:, (center - 3 * R):(center + 3 * R + 1)]

    if not crop_spatially:

        return big_templates

    else:
        # spatially crop (only keep neighbors)
        n_neigh_to_keep = np.max(np.sum(neighbors, 0))
        small = np.zeros(
            (n_templates, big_templates.shape[1], n_neigh_to_keep))

        for k in range(n_templates):

            # get neighbors for the main channel in the kth template
            ch_idx = np.where(neighbors[main_ch[k]])[0]

            # order channels
            ch_idx, _ = order_channels_by_distance(main_ch[k], ch_idx, geom)

            # new kth template is the old kth template by keeping only
            # ordered neighboring channels
            small[k, :, :ch_idx.shape[0]] = big_templates[k][:, ch_idx]

        return small
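A hypothetical call on real uncropped templates (the CONFIG attribute names follow those used in Example #5 from the same project); the result keeps 6 * R + 1 samples around the common energy center and, with crop_spatially=True, only the distance-ordered neighbors of each template's main channel.

small = crop_and_align_templates(big_templates,   # (n_templates, n_samples, n_channels)
                                 R=CONFIG.spike_size,
                                 neighbors=CONFIG.neigh_channels,
                                 geom=CONFIG.geom)
# small: (n_templates, 6 * R + 1, max number of neighbors)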
Example #8
def noise_cov(path_to_data, dtype, n_channels, data_order, neighbors, geom,
              temporal_size):
    """[Description]

    Parameters
    ----------
    path_to_data: str
        Path to recordings data

    dtype: str
        dtype for recordings

    n_channels: int
        Number of channels in the recordings

    data_order: str
        Recordings order, one of ('channels', 'samples'). In a dataset with k
        observations per channel and j channels: 'channels' means first k
        contiguous observations come from channel 0, then channel 1, and so
        on. 'samples' means first j contiguous data are the first observations
        from all channels, then the second observations from all channels and
        so on

    neighbors: numpy.ndarray
        Neighbors matrix

    geom: numpy.ndarray
        Cartesian coordinates for the channels

    temporal_size:
        Waveform size

    Returns
    -------
    spatial_SIG: numpy.ndarray

    temporal_SIG: numpy.ndarray
    """
    c_ref = np.argmax(np.sum(neighbors, 0))
    ch_idx = np.where(neighbors[c_ref])[0]
    ch_idx, temp = order_channels_by_distance(c_ref, ch_idx, geom)

    rec = RecordingsReader(path_to_data, dtype=dtype, n_channels=n_channels,
                           data_order=data_order, loader='array')
    rec = rec[:, ch_idx]

    T, C = rec.shape
    idxNoise = np.zeros((T, C))

    R = int((temporal_size-1)/2)
    for c in range(C):
        idx_temp = np.where(rec[:, c] > 3)[0]
        for j in range(-R, R+1):
            idx_temp2 = idx_temp + j
            idx_temp2 = idx_temp2[np.logical_and(
                idx_temp2 >= 0, idx_temp2 < T)]
            rec[idx_temp2, c] = np.nan
        idxNoise_temp = (rec[:, c] == rec[:, c])
        rec[:, c] = rec[:, c]/np.nanstd(rec[:, c])

        rec[~idxNoise_temp, c] = 0
        idxNoise[idxNoise_temp, c] = 1

    spatial_cov = np.divide(np.matmul(rec.T, rec),
                            np.matmul(idxNoise.T, idxNoise))

    w, v = np.linalg.eig(spatial_cov)
    spatial_SIG = np.matmul(np.matmul(v, np.diag(np.sqrt(w))), v.T)

    spatial_whitener = np.matmul(np.matmul(v, np.diag(1/np.sqrt(w))), v.T)
    rec = np.matmul(rec, spatial_whitener)

    noise_wf = np.zeros((1000, temporal_size))
    count = 0
    while count < 1000:
        tt = np.random.randint(T-temporal_size)
        cc = np.random.randint(C)
        temp = rec[tt:(tt+temporal_size), cc]
        temp_idxnoise = idxNoise[tt:(tt+temporal_size), cc]
        if np.sum(temp_idxnoise == 0) == 0:
            noise_wf[count] = temp
            count += 1

    w, v = np.linalg.eig(np.cov(noise_wf.T))

    temporal_SIG = np.matmul(np.matmul(v, np.diag(np.sqrt(w))), v.T)

    return spatial_SIG, temporal_SIG
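The two returned matrices are symmetric square roots of the estimated spatial and temporal noise covariances. Below is a minimal sketch (mirroring the commented-out autoencoder block in Example #5, with illustrative names) of how temporally correlated noise waveforms could be drawn from temporal_SIG, assuming the arguments have been prepared as described in the docstring above.

import numpy as np

temporal_size = 61   # must match the value passed to noise_cov
spatial_SIG, temporal_SIG = noise_cov(path_to_data, dtype, n_channels,
                                      data_order, neighbors, geom,
                                      temporal_size)

# color white Gaussian samples; each row then has covariance
# temporal_SIG @ temporal_SIG, i.e. the estimated temporal covariance
white = np.random.normal(size=(5000, temporal_size))
noise_waveforms = np.matmul(white, temporal_SIG)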
Example #9
def run(score,
        spike_index_clear,
        spike_index_collision,
        output_directory='tmp/',
        recordings_filename='standarized.bin'):
    """Process spikes

    Parameters
    ----------
    score: numpy.ndarray (n_spikes, n_features, n_channels)
        3D array with the scores for the clear spikes, first dimension is
        the number of spikes, second is the number of features and third the
        number of channels

    spike_index_clear: numpy.ndarray (n_clear_spikes, 2)
        2D array with indexes for clear spikes, first column contains the
        spike location in the recording and the second the main channel
        (channel whose amplitude is maximum)

    spike_index_collision: numpy.ndarray (n_collided_spikes, 2)
        2D array with indexes for collided spikes, first column contains the
        spike location in the recording and the second the main channel
        (channel whose amplitude is maximum)

    output_directory: str, optional
        Output directory (relative to CONFIG.data.root_folder) used to load
        the recordings to generate templates, defaults to tmp/

    recordings_filename: str, optional
        Recordings filename (relative to CONFIG.data.root_folder/
        output_directory) used to generate the templates, defaults to
        standarized.bin

    Returns
    -------
    spike_train_clear: numpy.ndarray (n_clear_spikes, 2)
        A 2D array for clear spikes whose first column indicates the spike
        time and the second column the neuron id determined by the clustering
        algorithm

    templates: numpy.ndarray (n_channels, waveform_size, n_templates)
        A 3D array with the templates

    spike_index_collision: numpy.ndarray (n_collided_spikes, 2)
        A 2D array for collided spikes whose first column indicates the spike
        time and the second column the neuron id determined by the clustering
        algorithm

    Examples
    --------

    .. literalinclude:: ../examples/process.py

    """
    CONFIG = read_config()
    MAIN_CHANNEL = 1

    startTime = datetime.datetime.now()

    Time = {'t': 0, 'c': 0, 'm': 0, 's': 0, 'e': 0}

    logger = logging.getLogger(__name__)

    nG = len(CONFIG.channelGroups)
    nneigh = np.max(np.sum(CONFIG.neighChannels, 0))
    n_coreset = 0
    K = 0

    # first column: spike_time
    # second column: cluster id
    spike_train_clear = np.zeros((0, 2), 'int32')

    if CONFIG.clustering.clustering_method == 'location':
        spike_index_clear_proc = np.zeros((0, 2), 'int32')
        main_channel_index = spike_index_clear[:, MAIN_CHANNEL]
        for i, c in enumerate(np.unique(main_channel_index)):
            logger.info('Processing channel {}'.format(i))
            idx = main_channel_index == c
            score_c = score[idx]
            spike_index_clear_c = spike_index_clear[idx]

            ##########
            # Triage #
            ##########

            # TODO: refactor this as CONFIG.doTriage was removed
            doTriage = True
            _b = datetime.datetime.now()
            logger.info('Triaging events with main channel {}'.format(c))
            index_keep = triage(score_c, 0, CONFIG.triage.nearest_neighbors,
                                CONFIG.triage.percent, doTriage)
            Time['t'] += (datetime.datetime.now() - _b).total_seconds()

            # add untriaged spike index to spike_index_clear_group
            # and triaged spike index to spike_index_collision
            spike_index_clear_proc = np.concatenate(
                (spike_index_clear_proc, spike_index_clear_c[index_keep]),
                axis=0)
            spike_index_collision = np.concatenate(
                (spike_index_collision, spike_index_clear_c[~index_keep]),
                axis=0)

            # TODO: add documentation for all of this part, until the
            # "cleaning" commend

            # keep untriaged score only
            score_c = score_c[index_keep]
            group = np.arange(score_c.shape[0])
            mask = np.ones([score_c.shape[0], 1])
            _b = datetime.datetime.now()
            logger.info('Clustering events with main channel {}'.format(c))
            if i == 0:
                global_vbParam, global_maskedData = spikesort(
                    score_c, mask, group, CONFIG)
                score_proc = score_c
            else:
                local_vbParam, local_maskedData = spikesort(
                    score_c, mask, group, CONFIG)
                global_vbParam.muhat = np.concatenate(
                    [global_vbParam.muhat, local_vbParam.muhat], axis=1)
                global_vbParam.Vhat = np.concatenate(
                    [global_vbParam.Vhat, local_vbParam.Vhat], axis=2)
                global_vbParam.invVhat = np.concatenate(
                    [global_vbParam.invVhat, local_vbParam.invVhat], axis=2)
                global_vbParam.lambdahat = np.concatenate(
                    [global_vbParam.lambdahat, local_vbParam.lambdahat],
                    axis=0)
                global_vbParam.nuhat = np.concatenate(
                    [global_vbParam.nuhat, local_vbParam.nuhat], axis=0)
                global_vbParam.ahat = np.concatenate(
                    [global_vbParam.ahat, local_vbParam.ahat], axis=0)
                global_maskedData.sumY = np.concatenate(
                    [global_maskedData.sumY, local_maskedData.sumY], axis=0)
                global_maskedData.sumYSq = np.concatenate(
                    [global_maskedData.sumYSq, local_maskedData.sumYSq],
                    axis=0)
                global_maskedData.sumEta = np.concatenate(
                    [global_maskedData.sumEta, local_maskedData.sumEta],
                    axis=0)
                global_maskedData.weight = np.concatenate(
                    [global_maskedData.weight, local_maskedData.weight],
                    axis=0)
                global_maskedData.groupMask = np.concatenate(
                    [global_maskedData.groupMask, local_maskedData.groupMask],
                    axis=0)
                global_maskedData.meanY = np.concatenate(
                    [global_maskedData.meanY, local_maskedData.meanY], axis=0)
                global_maskedData.meanYSq = np.concatenate(
                    [global_maskedData.meanYSq, local_maskedData.meanYSq],
                    axis=0)
                global_maskedData.meanEta = np.concatenate(
                    [global_maskedData.meanEta, local_maskedData.meanEta],
                    axis=0)
                score_proc = np.concatenate([score_proc, score_c], axis=0)

        logger.info('merging all channels')
        L = np.ones(global_vbParam.muhat.shape[1])
        global_vbParam.update_local(global_maskedData)
        suffStat = suffStatistics(global_maskedData, global_vbParam)
        global_vbParam, suffStat, L = merge_move(global_maskedData,
                                                 global_vbParam, suffStat,
                                                 CONFIG, L, 0)
        assignmentTemp = np.argmax(global_vbParam.rhat, axis=1)
        assignment = np.zeros(score_proc.shape[0], 'int16')

        for j in range(score_proc.shape[0]):
            assignment[j] = assignmentTemp[j]

        idx_triage = cluster_triage(global_vbParam, score_proc, 3)
        assignment[idx_triage] = -1
        Time['s'] += (datetime.datetime.now() - _b).total_seconds()

        ############
        # Cleaning #
        ############

        # TODO: describe this step

        spike_train_clear = np.concatenate(
            [spike_index_clear_proc[~idx_triage, 0:1],
             assignment[~idx_triage, np.newaxis]],
            axis=1)
        spike_index_collision = np.concatenate(
            [spike_index_collision, spike_index_clear_proc[idx_triage]])

    else:

        # according to the docs, if the clustering method is not 2+3 you can
        # set 3 x neighboring_channels, but I do not see where
        # neighboring_channels is being parsed in this else statement

        c_idx = np.ones((CONFIG.recordings.n_channels, nneigh),
                        'int32') * CONFIG.recordings.n_channels

        for c in range(CONFIG.recordings.n_channels):
            ch_idx, _ = order_channels_by_distance(
                c,
                np.where(CONFIG.neighChannels[c])[0], CONFIG.geom)
            c_idx[c, :ch_idx.shape[0]] = ch_idx

        # iterate over every channel group [missing documentation for this
        # function]. why is this order needed?
        for g in range(nG):
            logger.info("Processing group {} in {} groups.".format(g + 1, nG))
            logger.info("Processiing data (triage, coreset, masking) ...")
            channels = CONFIG.channelGroups[g]
            neigh_chans = np.where(
                np.sum(CONFIG.neighChannels[channels], axis=0) > 0)[0]

            score_group = np.zeros(
                (0, CONFIG.spikes.temporal_features, neigh_chans.shape[0]))
            coreset_id_group = np.zeros((0), 'int32')
            mask_group = np.zeros((0, neigh_chans.shape[0]))
            spike_index_clear_group = np.zeros((0, 2), 'int32')

            # go through every channel in the group
            for c in channels:

                # index of data whose main channel is c
                idx = spike_index_clear[:, MAIN_CHANNEL] == c
                if np.sum(idx) > 0:

                    # score whose main channel is c
                    score_c = score[idx]
                    # spike_index_clear whose main channel is c
                    spike_index_clear_c = spike_index_clear[idx]

                    ##########
                    # Triage #
                    ##########

                    # TODO: refactor this as CONFIG.doTriage was removed
                    doTriage = True

                    _b = datetime.datetime.now()
                    index_keep = triage(score_c, 0,
                                        CONFIG.triage.nearest_neighbors,
                                        CONFIG.triage.percent, doTriage)
                    Time['t'] += (datetime.datetime.now() - _b).total_seconds()

                    # add untriaged spike index to spike_index_clear_group
                    # and triaged spike index to spike_index_collision
                    spike_index_clear_group = np.concatenate(
                        (spike_index_clear_group,
                         spike_index_clear_c[index_keep]),
                        axis=0)
                    spike_index_collision = np.concatenate(
                        (spike_index_collision,
                         spike_index_clear_c[~index_keep]),
                        axis=0)

                    # keep untriaged score only
                    score_c = score_c[index_keep]

                    ###########
                    # Coreset #
                    ###########

                    # TODO: refactor this as CONFIG.doCoreset was removed
                    doCoreset = True

                    _b = datetime.datetime.now()
                    coreset_id = coreset(score_c, CONFIG.coreset.clusters,
                                         CONFIG.coreset.threshold, doCoreset)
                    Time['c'] += (datetime.datetime.now() - _b).total_seconds()

                    ###########
                    # Masking #
                    ###########

                    _b = datetime.datetime.now()
                    mask = getmask(score_c, coreset_id,
                                   CONFIG.clustering.masking_threshold,
                                   CONFIG.spikes.temporal_features)
                    Time['m'] += (datetime.datetime.now() - _b).total_seconds()

                    ################
                    # collect data #
                    ################

                    # restructure score_c and mask to have same number of
                    # channels as score_group
                    score_temp = np.zeros(
                        (score_c.shape[0], CONFIG.spikes.temporal_features,
                         neigh_chans.shape[0]))
                    mask_temp = np.zeros((mask.shape[0], neigh_chans.shape[0]))
                    nneigh_c = np.sum(c_idx[c] < CONFIG.recordings.n_channels)
                    for j in range(nneigh_c):
                        c_interest = np.where(neigh_chans == c_idx[c, j])[0][0]
                        score_temp[:, :, c_interest] = score_c[:, :, j]
                        mask_temp[:, c_interest] = mask[:, j]

                    # add score, coreset_id, mask to the groups
                    score_group = np.concatenate((score_group, score_temp),
                                                 axis=0)
                    mask_group = np.concatenate((mask_group, mask_temp),
                                                axis=0)
                    coreset_id_group = np.concatenate(
                        (coreset_id_group, coreset_id + n_coreset + 1), axis=0)
                    n_coreset += np.max(coreset_id) + 1

            if score_group.shape[0] > 0:
                ##############
                # Clustering #
                ##############

                _b = datetime.datetime.now()
                logger.info("Clustering...")
                coreset_id_group = coreset_id_group - 1
                n_coreset = 0
                cluster_id = spikesort(score_group, mask_group,
                                       coreset_id_group, CONFIG)
                Time['s'] += (datetime.datetime.now() - _b).total_seconds()

                ############
                # Cleaning #
                ############

                # model based triage
                idx_triage = (cluster_id == -1)

                # concatenate spike index with cluster id of untriaged ones
                # to create spike_train_clear
                si_clustered = spike_index_clear_group[~idx_triage]
                spt = si_clustered[:, [0]]
                cluster_id = cluster_id[~idx_triage][:, np.newaxis]

                spike_train_temp = np.concatenate((spt, cluster_id + K),
                                                  axis=1)
                spike_train_clear = np.concatenate(
                    (spike_train_clear, spike_train_temp), axis=0)
                K += np.amax(cluster_id) + 1

                # concatenate triaged spike_index_clear_group
                # into spike_index_collision
                spike_index_collision = np.concatenate(
                    (spike_index_collision,
                     spike_index_clear_group[idx_triage]),
                    axis=0)

    #################
    # Get templates #
    #################

    _b = datetime.datetime.now()
    logger.info("Getting Templates...")
    path_to_recordings = os.path.join(CONFIG.data.root_folder,
                                      output_directory, recordings_filename)
    merge_threshold = CONFIG.templates.merge_threshold
    spike_train_clear, templates = gam_templates(
        spike_train_clear, path_to_recordings, CONFIG.spikeSize,
        CONFIG.templatesMaxShift, merge_threshold, CONFIG.neighChannels)

    Time['e'] += (datetime.datetime.now() - _b).total_seconds()

    currentTime = datetime.datetime.now()

    if CONFIG.clustering.clustering_method == 'location':
        logger.info("Mainprocess done in {0} seconds.".format(
            (currentTime - startTime).seconds))
        logger.info("\ttriage:\t{0} seconds".format(Time['t']))
        logger.info("\tclustering:\t{0} seconds".format(Time['s']))
        logger.info("\ttemplates:\t{0} seconds".format(Time['e']))
    else:
        logger.info("\ttriage:\t{0} seconds".format(Time['t']))
        logger.info("\tcoreset:\t{0} seconds".format(Time['c']))
        logger.info("\tmasking:\t{0} seconds".format(Time['m']))
        logger.info("\tclustering:\t{0} seconds".format(Time['s']))
        logger.info("\ttemplates:\t{0} seconds".format(Time['e']))

    return spike_train_clear, templates, spike_index_collision
Example #10
def crop_and_align_templates(fname_templates, save_dir, CONFIG):
    """Crop (spatially) and align (temporally) templates

    Parameters
    ----------

    Returns
    -------
    """
    logger = logging.getLogger(__name__)
    
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)

    # load templates
    templates = np.load(fname_templates)
    
    n_units, n_times, n_channels = templates.shape
    mcs = templates.ptp(1).argmax(1)
    spike_size = (CONFIG.spike_size_nn - 1)*2 + 1

    ########## TEMPORALLY ALIGN TEMPLATES #################
    
    # template on max channel only
    templates_max_channel = np.zeros((n_units, n_times))
    for k in range(n_units):
        templates_max_channel[k] = templates[k, :, mcs[k]]

    # align them
    ref = np.mean(templates_max_channel, axis=0)
    upsample_factor = 8
    nshifts = spike_size//2

    shifts = align_get_shifts_with_ref(
        templates_max_channel, ref, upsample_factor, nshifts)

    templates_aligned = shift_chans(templates, shifts)
    
    # crop out the edges since they have bad artifacts
    templates_aligned = templates_aligned[:, nshifts//2:-nshifts//2]

    ########## Find High Energy Center of Templates #################

    templates_max_channel_aligned = np.zeros((n_units, templates_aligned.shape[1]))
    for k in range(n_units):
        templates_max_channel_aligned[k] = templates_aligned[k, :, mcs[k]]

    # determine the temporal center of templates and crop around it
    total_energy = np.sum(np.square(templates_max_channel_aligned), axis=0)
    center = np.argmax(np.convolve(total_energy, np.ones(spike_size//2), 'same'))
    templates_aligned = templates_aligned[:, (center-spike_size//2):(center+spike_size//2+1)]
    
    ########## spatially crop (only keep neighbors) #################

    neighbors = CONFIG.neigh_channels
    n_neigh = np.max(np.sum(CONFIG.neigh_channels, axis=1))
    templates_cropped = np.zeros((n_units, spike_size, n_neigh))

    for k in range(n_units):

        # get neighbors for the main channel in the kth template
        ch_idx = np.where(neighbors[mcs[k]])[0]

        # order channels
        ch_idx, _ = order_channels_by_distance(mcs[k], ch_idx, CONFIG.geom)

        # new kth template is the old kth template by keeping only
        # ordered neighboring channels
        templates_cropped[k, :, :ch_idx.shape[0]] = templates_aligned[k][:, ch_idx]

    fname_templates_cropped = os.path.join(save_dir, 'templates_cropped.npy')
    np.save(fname_templates_cropped, templates_cropped)

    return fname_templates_cropped
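A hypothetical call (the paths are placeholders); CONFIG is assumed to expose spike_size_nn, neigh_channels and geom as used inside the function, and 'templates.npy' to hold a saved (n_units, n_times, n_channels) array.

import numpy as np

fname_cropped = crop_and_align_templates('templates.npy',
                                         save_dir='templates_cropped_dir',
                                         CONFIG=CONFIG)
templates_cropped = np.load(fname_cropped)
print(templates_cropped.shape)   # (n_units, spike_size, max number of neighbors)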
Example #11
def get_waveforms(recording, spike_index, proj, neighbors, geom, nnt, th):
    """Extract waveforms from detected spikes

    Parameters
    ----------
    recording: matrix (observations, number of channels)
        Multi-channel recordings
    spike_index: matrix (number of spikes, 2)
        Spike index matrix, as returned from any of the detectors
    proj: matrix (waveform temporal length, number of features)
        Projection matrix that reduces the dimension of waveform
    neighbors: matrix (number of channels, number of channel)
        Neighbors matrix
    geom: matrix (number of channels, 2)
        Each row is the x,y coordinate of each channel
    nnt: class
        Class for Neural Network based triage
    th: int
        Threshold for Neural Network triage algorithm

    Returns
    -------
    spike_index_clear: matrix (number of clear spikes, 2)
        Spike index matrix for the spikes the triage network labels as clear
    score: matrix (number of clear spikes, number of features,
        number of neighboring channels)
        Scores for the clear spikes
    spike_index_collision: matrix (number of collided spikes, 2)
        Spike index matrix for the spikes labeled as collisions

    Notes
    -----
    Let's consider a single channel recording V, where V is a vector of
    length 1 x T. When a spike is detected at time t, then (V_(t-R),
    V_(t-R+1), ..., V_t, V_(t+1),...V_(t+R)) is going to be a waveform.
    (a small snippet from the recording around the spike time)
    """
    # column ids for index matrix
    SPIKE_TIME, MAIN_CHANNEL = 0, 1

    n_times, n_channels = recording.shape
    n_spikes, _ = spike_index.shape
    window_size, n_features = proj.shape
    spike_size = int((window_size - 1) / 2)
    nneigh = np.max(np.sum(neighbors, 0))

    recording = np.concatenate((recording, np.zeros((n_times, 1))), axis=1)
    c_idx = np.ones((n_channels, nneigh), 'int32') * n_channels
    for c in range(n_channels):
        ch_idx, _ = order_channels_by_distance(c,
                                               np.where(neighbors[c])[0], geom)
        c_idx[c, :ch_idx.shape[0]] = ch_idx

    spike_index_clear = np.zeros((0, 2), 'int32')
    spike_index_collision = np.zeros((0, 2), 'int32')
    score = np.zeros((0, n_features, nneigh), 'float32')

    nbuff = 500000
    wf = np.zeros((nbuff, window_size, nneigh), 'float32')

    count = 0
    for j in range(n_spikes):
        t = spike_index[j, SPIKE_TIME]
        c = c_idx[spike_index[j, MAIN_CHANNEL]]
        wf[count] = recording[(t - spike_size):(t + spike_size + 1), c]
        count += 1

        # when we gathered enough spikes, go through triage NN and save score
        if (count == nbuff) or (j == n_spikes - 1):
            # if we reach the last spike before filling the buffer,
            # the buffer size becomes the number of leftover spikes
            if j == n_spikes - 1:
                nbuff = count
                wf = wf[:nbuff]

            # going through triage NN.
            # The output is 1 for clear spike and 0 otherwise
            clear_spike = nnt.nn_triage(wf, th)

            # collect clear and colliding spikes
            spike_index_buff = spike_index[(j - nbuff + 1):(j + 1)]
            spike_index_clear = np.concatenate(
                (spike_index_clear, spike_index_buff[clear_spike]))
            spike_index_collision = np.concatenate(
                (spike_index_collision, spike_index_buff[~clear_spike]))

            # calculate score and collect into variable 'score'
            reshaped_wf = np.reshape(np.transpose(wf[clear_spike], (0, 2, 1)),
                                     (-1, window_size))
            score_temp = np.transpose(
                np.reshape(np.matmul(reshaped_wf, proj),
                           (-1, nneigh, n_features)), (0, 2, 1))
            score = np.concatenate((score, score_temp), axis=0)

            # set counter back to zero
            count = 0

    return spike_index_clear, score, spike_index_collision
Example #12
def noise_cov(path_to_data, dtype, n_channels, data_format, neighbors, geom,
              temporal_size):
    """[Description]

    Parameters
    ----------
    path_to_data: str
        Path to recordings data
    dtype: str
        dtype for recordings
    n_channels: int
        Number of channels in the recordings
    data_format: str
        Recordings shape ('wide', 'long')
    neighbors: numpy.ndarray
        Neighbors matrix
    geom: numpy.ndarray
        Cartesian coordinates for the channels
    temporal_size:
        Waveform size

    Returns
    -------
    spatial_SIG: numpy.ndarray

    temporal_SIG: numpy.ndarray
    """
    c_ref = np.argmax(np.sum(neighbors, 0))
    ch_idx = np.where(neighbors[c_ref])[0]
    ch_idx, temp = order_channels_by_distance(c_ref, ch_idx, geom)

    rec = RecordingsReader(path_to_data, dtype=dtype, n_channels=n_channels,
                           data_format=data_format, mmap=False)
    rec = rec[:, ch_idx]

    T, C = rec.shape
    idxNoise = np.zeros((T, C))

    R = int((temporal_size-1)/2)
    for c in range(C):
        idx_temp = np.where(rec[:, c] > 3)[0]
        for j in range(-R, R+1):
            idx_temp2 = idx_temp + j
            idx_temp2 = idx_temp2[np.logical_and(
                idx_temp2 >= 0, idx_temp2 < T)]
            rec[idx_temp2, c] = np.nan
        idxNoise_temp = (rec[:, c] == rec[:, c])
        rec[:, c] = rec[:, c]/np.nanstd(rec[:, c])

        rec[~idxNoise_temp, c] = 0
        idxNoise[idxNoise_temp, c] = 1

    spatial_cov = np.divide(np.matmul(rec.T, rec),
                            np.matmul(idxNoise.T, idxNoise))

    w, v = np.linalg.eig(spatial_cov)
    spatial_SIG = np.matmul(np.matmul(v, np.diag(np.sqrt(w))), v.T)

    spatial_whitener = np.matmul(np.matmul(v, np.diag(1/np.sqrt(w))), v.T)
    rec = np.matmul(rec, spatial_whitener)

    noise_wf = np.zeros((1000, temporal_size))
    count = 0
    while count < 1000:
        tt = np.random.randint(T-temporal_size)
        cc = np.random.randint(C)
        temp = rec[tt:(tt+temporal_size), cc]
        temp_idxnoise = idxNoise[tt:(tt+temporal_size), cc]
        if np.sum(temp_idxnoise == 0) == 0:
            noise_wf[count] = temp
            count += 1

    w, v = np.linalg.eig(np.cov(noise_wf.T))

    temporal_SIG = np.matmul(np.matmul(v, np.diag(np.sqrt(w))), v.T)

    return spatial_SIG, temporal_SIG
Example #13
def crop_and_align_templates(big_templates, R, neighbors, geom):
    """Crop (spatially) and align (temporally) templates

    Parameters
    ----------

    Returns
    -------
    """
    # copy templates to avoid modifying the original ones
    big_templates = np.copy(big_templates)

    K, _, _ = big_templates.shape

    # main channel for each template and amplitudes
    mainC = np.argmax(np.amax(np.abs(big_templates), axis=1), axis=1)
    amps = np.amax(np.abs(big_templates), axis=(1, 2))

    # get a template on a main channel and align them
    K_big = np.argmax(amps)
    templates_mainc = np.zeros((K, big_templates.shape[1]))
    t_rec = big_templates[K_big, :, mainC[K_big]]
    t_rec = t_rec/np.sqrt(np.sum(np.square(t_rec)))

    for k in range(K):
        t1 = big_templates[k, :, mainC[k]]
        t1 = t1/np.sqrt(np.sum(np.square(t1)))
        shift = align_templates(t1, t_rec)

        if shift > 0:
            templates_mainc[k, :(big_templates.shape[1]-shift)] = t1[shift:]
            big_templates[k, :(big_templates.shape[1]-shift)
                          ] = big_templates[k, shift:]

        elif shift < 0:
            templates_mainc[k, (-shift):] = t1[:(big_templates.shape[1]+shift)]
            big_templates[k,
                          (-shift):] = big_templates[k,
                                                     :(big_templates.shape[1]
                                                       + shift)]

        else:
            templates_mainc[k] = t1

    # determine the temporal center of templates and crop around it
    R2 = int(R/2)
    center = np.argmax(np.convolve(
        np.sum(np.square(templates_mainc), 0), np.ones(2*R2+1), 'valid')) + R2
    big_templates = big_templates[:, (center-3*R):(center+3*R+1)]

    # spatially crop
    nneigh = np.max(np.sum(neighbors, 0))

    small_templates = np.zeros((K, big_templates.shape[1], nneigh))

    for k in range(K):
        ch_idx = np.where(neighbors[mainC[k]])[0]
        ch_idx, temp = order_channels_by_distance(mainC[k], ch_idx, geom)
        small_templates[k, :, :ch_idx.shape[0]] = big_templates[k][:, ch_idx]

    return small_templates
Example #14
def noise_cov(path_to_data, neighbors, geom, temporal_size):
    """Compute noise temporal and spatial covariance

    Parameters
    ----------
    path_to_data: str
        Path to recordings data

    neighbors: numpy.ndarray
        Neighbors matrix

    geom: numpy.ndarray
        Cartesian coordinates for the channels

    temporal_size:
        Waveform size

    Returns
    -------
    spatial_SIG: numpy.ndarray

    temporal_SIG: numpy.ndarray
    """
    logger = logging.getLogger(__name__)

    logger.debug('Computing noise_cov. Neighbors shape: {}, geom shape: {} '
                 'temporal_size: {}'.format(neighbors.shape, geom.shape,
                                            temporal_size))

    c_ref = np.argmax(np.sum(neighbors, 0))
    ch_idx = np.where(neighbors[c_ref])[0]
    ch_idx, temp = order_channels_by_distance(c_ref, ch_idx, geom)

    rec = RecordingsReader(path_to_data, loader='array')
    rec = rec[:, ch_idx]

    T, C = rec.shape
    idxNoise = np.zeros((T, C))
    R = int((temporal_size-1)/2)

    for c in range(C):

        idx_temp = np.where(rec[:, c] > 3)[0]

        for j in range(-R, R+1):

            idx_temp2 = idx_temp + j
            idx_temp2 = idx_temp2[np.logical_and(idx_temp2 >= 0,
                                                 idx_temp2 < T)]
            rec[idx_temp2, c] = np.nan

        idxNoise_temp = (rec[:, c] == rec[:, c])
        rec[:, c] = rec[:, c]/np.nanstd(rec[:, c])

        rec[~idxNoise_temp, c] = 0
        idxNoise[idxNoise_temp, c] = 1

    spatial_cov = np.divide(np.matmul(rec.T, rec),
                            np.matmul(idxNoise.T, idxNoise))

    w, v = np.linalg.eig(spatial_cov)
    spatial_SIG = np.matmul(np.matmul(v, np.diag(np.sqrt(w))), v.T)

    spatial_whitener = np.matmul(np.matmul(v, np.diag(1/np.sqrt(w))), v.T)
    rec = np.matmul(rec, spatial_whitener)

    noise_wf = np.zeros((1000, temporal_size))
    count = 0

    while count < 1000:

        tt = np.random.randint(T-temporal_size)
        cc = np.random.randint(C)
        temp = rec[tt:(tt+temporal_size), cc]
        temp_idxnoise = idxNoise[tt:(tt+temporal_size), cc]

        if np.sum(temp_idxnoise == 0) == 0:
            noise_wf[count] = temp
            count += 1

    w, v = np.linalg.eig(np.cov(noise_wf.T))

    temporal_SIG = np.matmul(np.matmul(v, np.diag(np.sqrt(w))), v.T)

    logger.debug('spatial_SIG shape: {} temporal_SIG shape: {}'
                 .format(spatial_SIG.shape, temporal_SIG.shape))

    return spatial_SIG, temporal_SIG
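Because spatial_SIG and temporal_SIG are built as V · diag(sqrt(w)) · V.T from the eigendecomposition of the covariance estimates, squaring either matrix recovers the corresponding covariance up to numerical error. A small sanity-check sketch (not from the original source), assuming the function above has already been run:

import numpy as np

# spatial_SIG, temporal_SIG = noise_cov(path_to_standarized, neighbors, geom, 61)

spatial_cov_hat = spatial_SIG @ spatial_SIG      # estimated spatial covariance
temporal_cov_hat = temporal_SIG @ temporal_SIG   # estimated temporal covariance
print(np.allclose(spatial_cov_hat, spatial_cov_hat.T))   # True: symmetric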