Beispiel #1
0
def two_templates_dist_linear_align(templates1,
                                    templates2,
                                    max_shift=5,
                                    step=0.5):
    """Distance between two sets of templates, minimized over a shift grid.

    Every template of ``templates2`` is passed through ``shift_chans`` once
    per value of the grid ``np.arange(-max_shift, max_shift + step, step)``.
    For each unit ``k`` of ``templates1`` the Euclidean distance to every
    shifted candidate is computed and the minimum over the grid is stored.

    A unit of ``templates2`` is a candidate for unit ``k`` only when its
    peak-to-peak amplitude, measured on the max-amplitude channel of
    ``templates1[k]``, differs from that of ``templates1[k]`` by less than
    50% (relative to ``templates1[k]``). Non-candidate entries keep the
    sentinel value ``1e4``.

    Parameters
    ----------
    templates1 : numpy.ndarray, shape (K1, R, C)
    templates2 : numpy.ndarray, shape (K2, R, C)
        Peak-to-peak is taken along axis 1 and the max channel along axis 2,
        so axis 1 is presumably time and axis 2 channels.
    max_shift : float
        Largest shift (both directions) of the grid.
    step : float
        Spacing of the shift grid.

    Returns
    -------
    numpy.ndarray, shape (K1, K2)
        ``distance[i, j]`` is the aligned distance between
        ``templates1[i]`` and ``templates2[j]``, or ``1e4``.
    """
    K1, R, C = templates1.shape
    K2 = templates2.shape[0]

    shifts = np.arange(-max_shift, max_shift + step, step)

    # np.ptp (function) is used instead of the ndarray.ptp method because
    # the method was removed in NumPy 2.0.
    ptps1 = np.ptp(templates1, 1)
    max_chans1 = np.argmax(ptps1, 1)
    ptps2 = np.ptp(templates2, 1)

    # every shifted version of templates2 is computed once, up front
    shifted_templates = np.zeros((len(shifts), K2, R, C))
    for ii, s in enumerate(shifts):
        shifted_templates[ii] = shift_chans(templates2, np.ones(K2) * s)

    distance = np.ones((K1, K2)) * 1e4
    for k in range(K1):
        mc = max_chans1[k]
        ref_ptp = ptps1[k, mc]
        candidates = np.abs(ptps2[:, mc] - ref_ptp) / ref_ptp < 0.5
        if not np.any(candidates):
            # nothing to compare against; the row keeps the sentinel
            continue

        # squared error summed over axes (R, C), then minimized over shifts
        sq_err = np.square(templates1[k][np.newaxis, np.newaxis] -
                           shifted_templates[:, candidates])
        best = np.min(np.sum(sq_err, axis=(2, 3)), 0)
        distance[k, candidates] = np.sqrt(best)

    return distance
Beispiel #2
0
def template_dist_linear_align(templates,
                               distance=None,
                               units=None,
                               max_shift=5,
                               step=0.5):
    """Pairwise squared distance between templates, minimized over shifts.

    All templates are passed through ``shift_chans`` once per value of the
    grid ``np.arange(-max_shift, max_shift + step, step)``. For every unit
    ``k`` in ``units`` the sum of squared differences to each shifted
    candidate is computed and the minimum over the grid is written to both
    ``distance[k, j]`` and ``distance[j, k]``.

    A unit ``j`` is a candidate for ``k`` only when its peak-to-peak
    amplitude on the max-amplitude channel of unit ``k`` differs from that of
    unit ``k`` by less than 50% (relative to unit ``k``).

    Note that no square root is applied: the stored values are squared
    Euclidean distances.

    Parameters
    ----------
    templates : numpy.ndarray, shape (K, R, C)
        Peak-to-peak is taken along axis 1 and the max channel along axis 2.
    distance : numpy.ndarray, shape (K, K), optional
        Matrix to update. It is modified in place and returned. When omitted
        a new matrix filled with the sentinel ``1e4`` is created.
    units : iterable of int, optional
        Units whose rows/columns are (re)computed. Defaults to all units.
    max_shift : float
        Largest shift (both directions) of the grid.
    step : float
        Spacing of the shift grid.

    Returns
    -------
    numpy.ndarray, shape (K, K)
    """
    K, R, C = templates.shape

    shifts = np.arange(-max_shift, max_shift + step, step)

    # np.ptp (function) is used instead of the ndarray.ptp method because
    # the method was removed in NumPy 2.0.
    ptps = np.ptp(templates, 1)
    max_chans = np.argmax(ptps, 1)

    # every shifted version of the templates is computed once, up front
    shifted_templates = np.zeros((len(shifts), K, R, C))
    for ii, s in enumerate(shifts):
        shifted_templates[ii] = shift_chans(templates, np.ones(K) * s)

    if distance is None:
        distance = np.ones((K, K)) * 1e4

    if units is None:
        units = np.arange(K)

    for k in units:
        mc = max_chans[k]
        ref_ptp = ptps[k, mc]
        candidates = np.abs(ptps[:, mc] - ref_ptp) / ref_ptp < 0.5

        # squared error summed over axes (R, C), then minimized over shifts
        sq_err = np.square(templates[k][np.newaxis, np.newaxis] -
                           shifted_templates[:, candidates])
        dist = np.min(np.sum(sq_err, axis=(2, 3)), 0)

        # keep the matrix symmetric
        distance[k, candidates] = dist
        distance[candidates, k] = dist

    return distance
Beispiel #3
0
    def standardize_templates(self):
        """Normalize ``self.templates`` by peak-to-peak and align to the mean.

        Each template (row of ``self.templates``) is divided by its
        peak-to-peak amplitude along axis 1. The mean over templates is then
        used as the reference passed to ``align_get_shifts_with_ref`` and the
        resulting shifts are applied with ``shift_chans``. The result replaces
        ``self.templates``.
        """
        # np.ptp (function) is used instead of the ndarray.ptp method because
        # the method was removed in NumPy 2.0.
        ptp = np.ptp(self.templates, 1)
        self.templates = self.templates / ptp[:, None]

        # align every normalized template to the mean template
        ref = np.mean(self.templates, 0)
        shifts = align_get_shifts_with_ref(self.templates, ref)
        self.templates = shift_chans(self.templates, shifts)
Beispiel #4
0
def align_tc(tc, ref):
    """Shift every unit of ``tc`` by the shift found on its main channel.

    For each unit the channel holding the largest absolute value is
    selected; the absolute trace on that channel is handed, together with
    ``ref``, to ``align_get_shifts_tc`` (with ``upsample_factor=1``). The
    returned shifts are applied to all channels through ``shift_chans``.

    Parameters
    ----------
    tc : numpy.ndarray, shape (n_units, n_timepoints, n_channels)
    ref : reference forwarded unchanged to ``align_get_shifts_tc``

    Returns
    -------
    The output of ``shift_chans(tc, shifts)``.
    """
    n_units, n_timepoints, n_channels = tc.shape

    # per unit: index of the channel with the largest absolute amplitude
    strongest = np.abs(tc).max(1).argmax(1)

    # gather |trace| on each unit's strongest channel in one indexing step;
    # writing into a float64 buffer keeps the dtype of the loop-based version
    main_tc = np.zeros((n_units, n_timepoints))
    main_tc[:] = np.abs(tc[np.arange(n_units), :, strongest])

    shifts = align_get_shifts_tc(main_tc, ref, upsample_factor=1)
    return shift_chans(tc, shifts)
Beispiel #5
0
    def align_step(self, local):
        """Align ``self.wf_global`` using fresh or previously stored shifts.

        When ``local`` is true, shifts are computed by
        ``align_get_shifts_with_ref`` from the traces on ``self.channel``
        (looked up inside ``self.loaded_channels``) and stored in
        ``self.shifts[self.indices_in]``. Otherwise the shifts already stored
        there are reused. In both cases the shifts are applied to
        ``self.wf_global`` through ``shift_chans``.
        """
        if self.verbose:
            print("chan " + str(self.channel) + ", aligning")

        if not local:
            # reuse the shifts stored for these spikes
            best_shifts = self.shifts[self.indices_in]
        else:
            # position of the main channel within the loaded channels
            main_pos = np.where(self.loaded_channels == self.channel)[0][0]
            main_traces = self.wf_global[:, :, main_pos]
            best_shifts = align_get_shifts_with_ref(main_traces)
            self.shifts[self.indices_in] = best_shifts

        self.wf_global = shift_chans(self.wf_global, best_shifts)

        # placeholder branch: nothing is done when the flag is set
        if self.ari_flag:
            pass
Beispiel #6
0
def align_waveforms_parallel(fnames_input_data):
    """Align the waveforms stored in each npz file and rewrite the file.

    Files that already contain a ``'shifts'`` entry are skipped. For the
    others the ``'wf'`` array is aligned on the channel where the mean
    waveform has the largest peak-to-peak (shifts come from
    ``align_get_shifts_with_ref`` with ``nshifts=3`` and are applied with
    ``shift_chans``). The file is then overwritten with all of its original
    entries plus the updated ``'wf'`` and the new ``'shifts'`` (``None`` when
    the file holds no waveforms).

    Parameters
    ----------
    fnames_input_data : iterable of str
        Paths of the npz files to process in place.
    """
    for fname in fnames_input_data:
        # allow_pickle is required because 'shifts' may be stored as None
        # (an object array). NOTE: unpickling runs arbitrary code, so these
        # files must come from a trusted source.
        #
        # The archive is closed before the same path is overwritten below;
        # leaving it open leaks a file handle per file and breaks the write
        # on platforms that lock open files.
        with np.load(fname, allow_pickle=True) as archive:
            if 'shifts' in archive.files:
                continue
            # read every entry into memory while the archive is still open
            contents = dict(archive)

        wf = contents['wf']

        # align
        if wf.shape[0] > 0:
            # np.ptp instead of ndarray.ptp (removed in NumPy 2.0)
            mc = np.ptp(np.mean(wf, 0), 0).argmax()
            shifts = align_get_shifts_with_ref(wf[:, :, mc], nshifts=3)
            wf = shift_chans(wf, shifts)
        else:
            shifts = None

        contents['wf'] = wf
        contents['shifts'] = shifts
        np.savez(fname, **contents)
Beispiel #7
0
    def jitter_templates(self, up_factor=8):
        """Expand ``self.templates`` into sub-sample shifted copies and align.

        Each template is resampled to ``up_factor`` times its length with
        ``scipy.signal.resample`` and split into ``up_factor`` interleaved
        phases of the original length. A second set of phases, rolled by one
        sample along time, is appended, so every input template produces
        ``2 * up_factor`` rows. The rows are finally aligned to their mean
        via ``align_get_shifts_with_ref`` (``upsample_factor=1``) and
        ``shift_chans``; the result replaces ``self.templates``.

        Parameters
        ----------
        up_factor : int
            Upsampling factor, i.e. number of phases per template.
        """
        n_templates, n_times = self.templates.shape

        # upsample along time
        dense = scipy.signal.resample(
            x=self.templates,
            num=n_times*up_factor,
            axis=1)

        # de-interleave: phases[k, t, u] == dense[k, t * up_factor + u]
        phases = dense.reshape(n_templates, n_times, up_factor)

        # append the same phases rolled by one sample along the time axis
        rolled = np.roll(phases, shift=1, axis=1)
        both = np.concatenate((phases, rolled), axis=2)

        # one row per (template, phase): shape (n_templates * 2 * up_factor, n_times)
        self.templates = both.transpose(0, 2, 1).reshape(-1, n_times)

        # align all rows to their mean
        mean_template = np.mean(self.templates, 0)
        shifts = align_get_shifts_with_ref(
            self.templates, mean_template, upsample_factor=1)
        self.templates = shift_chans(self.templates, shifts)
Beispiel #8
0
def template_spike_dist_linear_align(templates, spikes, vis_ptp=2.):
    """Distance between every spike and every template after alignment.

    Templates and spikes are both aligned (``align_get_shifts_with_ref`` with
    ``nshifts=7`` followed by ``shift_chans``) to the max-channel trace of the
    template with the largest peak-to-peak. Edges as wide as the largest
    applied shift are cut, only "visible" channels are kept, and the distance
    is computed on the subset of (time, channel) points where the templates
    differ the most.

    Parameters
    ----------
    templates : numpy.ndarray, shape (K, T, C)
    spikes : numpy.ndarray, shape (M, T, C)
    vis_ptp : float
        A channel is kept when the largest template peak-to-peak on it is at
        least this value. The threshold is capped at the overall largest
        template peak-to-peak so that at least one channel is always kept.

    Returns
    -------
    numpy.ndarray, shape (M, K)
        ``dist[m, k]`` is the distance (``scipy.spatial.distance.cdist``
        default metric) between spike ``m`` and template ``k``.
    """
    # minimum number of points the distance is computed on (when K > 1)
    min_diff_points = 5
    # points where the templates differ by more than this are always used
    diff_threshold = 1.5

    # get ref template: max channel of the template with the largest ptp
    # (np.ptp instead of ndarray.ptp, which was removed in NumPy 2.0)
    max_idx = np.ptp(templates, 1).max(1).argmax(0)
    ref_template = templates[max_idx]
    max_chan = np.ptp(ref_template, 0).argmax(0)
    ref_template = ref_template[:, max_chan]

    # align templates on max channel
    template_shifts = align_get_shifts_with_ref(templates[:, :, max_chan],
                                                ref_template,
                                                nshifts=7)
    templates = shift_chans(templates, template_shifts)

    # align all spikes on max channel
    spike_shifts = align_get_shifts_with_ref(spikes[:, :, max_chan],
                                             ref_template,
                                             nshifts=7)
    spikes = shift_chans(spikes, spike_shifts)

    # if shifted, cut shifted parts because they are contaminated by the
    # shift. Both the template shifts and the spike shifts are considered:
    # looking at the spike shifts only would leave contaminated edges in
    # templates that were moved further than any spike.
    all_shifts = np.concatenate((np.ravel(template_shifts),
                                 np.ravel(spike_shifts)))
    cut = int(np.ceil(np.max(np.abs(all_shifts))))
    if cut > 0:
        templates = templates[:, cut:-cut]
        spikes = spikes[:, cut:-cut]

    # get visible channels
    template_ptps = np.ptp(templates, 1)
    vis_ptp = np.min((vis_ptp, np.max(template_ptps)))
    vis_chan = np.where(template_ptps.max(0) >= vis_ptp)[0]
    templates = templates[:, :, vis_chan].reshape(templates.shape[0], -1)
    spikes = spikes[:, :, vis_chan].reshape(spikes.shape[0], -1)

    # get a subset of locations with maximal difference
    n_templates = templates.shape[0]
    if n_templates == 1:
        # single unit: all timepoints
        idx = np.arange(templates.shape[1])
    else:
        if n_templates == 2:
            # two units: absolute difference between them
            diffs = np.abs(np.diff(templates, axis=0)[0])
        else:
            # more than two units: mean absolute diff to the largest unit
            diffs = np.mean(np.abs(templates - templates[max_idx][None]),
                            axis=0)

        # points with large difference; fall back to the points with the
        # largest differences when too few exceed the threshold
        idx = np.where(diffs > diff_threshold)[0]
        if len(idx) < min_diff_points:
            idx = np.argsort(diffs)[-min_diff_points:]

    templates = templates[:, idx]
    spikes = spikes[:, idx]

    dist = scipy.spatial.distance.cdist(templates, spikes)

    return dist.T
Beispiel #9
0
    def merge_templates_parallel(self, pairs):
        """Decide for every pair of units whether they should be merged.

        For each pair the decision is cached in
        ``<self.save_dir>/unit_<unit1>_<unit2>.npz``; when that file exists
        its ``'merge'`` entry is reused. Otherwise spikes of both units are
        subsampled (weighted by ``self.soft_assignment``), residual waveforms
        are read and turned into clean waveforms by adding back the scaled
        templates, the waveforms are whitened, projected on the difference
        of the two whitened means, and ``dp`` is applied to the pooled
        projections. The pair is merged when the second element returned by
        ``dp`` (stored as ``p_val``; presumably a p-value) exceeds 0.9.

        Parameters
        ----------
        pairs : iterable of (unit1, unit2)

        Returns
        -------
        list
            The elements of ``pairs`` that should be merged.
        """
        n_samples = 2000
        p_val_threshold = 0.9
        merge_pairs = []

        for pair in pairs:
            unit1, unit2 = pair

            fname_out = os.path.join(self.save_dir,
                                     'unit_{}_{}.npz'.format(unit1, unit2))

            # reuse a previously stored decision; the archive is closed
            # right after reading instead of leaking a file handle
            if os.path.exists(fname_out):
                with np.load(fname_out) as cached:
                    merge = bool(cached['merge'])
                if merge:
                    merge_pairs.append(pair)
                continue

            # get spikes times and soft assignment
            idx1 = self.spike_train[:, 1] == unit1
            spt1 = self.spike_train[idx1, 0]
            prob1 = self.soft_assignment[idx1]
            shift1 = self.shifts[idx1]
            scale1 = self.scales[idx1]
            n_spikes1 = self.n_spikes_soft[unit1]

            idx2 = self.spike_train[:, 1] == unit2
            spt2 = self.spike_train[idx2, 0]
            prob2 = self.soft_assignment[idx2]
            shift2 = self.shifts[idx2]
            scale2 = self.scales[idx2]
            n_spikes2 = self.n_spikes_soft[unit2]

            # randomly subsample, keeping the ratio between the two units
            if n_spikes1 + n_spikes2 > n_samples:
                ratio1 = n_spikes1 / float(n_spikes1 + n_spikes2)
                n_samples1 = np.min((int(n_samples * ratio1), n_spikes1))
                n_samples2 = n_samples - n_samples1

            else:
                n_samples1 = n_spikes1
                n_samples2 = n_spikes2
            idx1_ = np.random.choice(len(spt1),
                                     n_samples1,
                                     replace=False,
                                     p=prob1 / np.sum(prob1))
            idx2_ = np.random.choice(len(spt2),
                                     n_samples2,
                                     replace=False,
                                     p=prob2 / np.sum(prob2))
            spt1 = spt1[idx1_]
            spt2 = spt2[idx2_]
            shift1 = shift1[idx1_]
            shift2 = shift2[idx2_]
            scale1 = scale1[idx1_]
            scale2 = scale2[idx2_]

            # channels where either unit has ptp > 1
            ptp_max = self.ptps[[unit1, unit2]].max(0)
            mc = ptp_max.argmax()
            vis_chan = np.where(ptp_max > 1)[0]

            # align two units: offset between their minima on the main
            # channel, applied to the spike times of unit2
            shift_temp = (self.templates[unit2, :, mc].argmin() -
                          self.templates[unit1, :, mc].argmin())
            spt2 += shift_temp

            # load residuals and drop the spikes the reader skipped
            wfs1, skipped_idx1 = self.reader_residual.read_waveforms(
                spt1, self.spike_size, vis_chan)
            spt1 = np.delete(spt1, skipped_idx1)
            shift1 = np.delete(shift1, skipped_idx1)
            scale1 = np.delete(scale1, skipped_idx1)

            wfs2, skipped_idx2 = self.reader_residual.read_waveforms(
                spt2, self.spike_size, vis_chan)
            # unit2 arrays must be filtered with unit2's skipped indices
            spt2 = np.delete(spt2, skipped_idx2)
            shift2 = np.delete(shift2, skipped_idx2)
            scale2 = np.delete(scale2, skipped_idx2)

            # align residuals
            wfs1 = shift_chans(wfs1, -shift1)
            wfs2 = shift_chans(wfs2, -shift2)

            # make clean waveforms: residual + scaled template
            wfs1 += scale1[:, None, None] * self.templates[[unit1], :,
                                                           vis_chan].T
            # the template of unit2 is cropped to match the time offset
            # applied to its spike times above
            if shift_temp > 0:
                temp_2_shfted = self.templates[[unit2], shift_temp:,
                                               vis_chan].T
                wfs2[:, :-shift_temp] += scale2[:, None,
                                                None] * temp_2_shfted
            elif shift_temp < 0:
                temp_2_shfted = self.templates[[unit2], :shift_temp,
                                               vis_chan].T
                wfs2[:,
                     -shift_temp:] += scale2[:, None, None] * temp_2_shfted
            else:
                wfs2 += scale2[:, None, None] * self.templates[[unit2], :,
                                                               vis_chan].T

            # compute spatial covariance
            spatial_whitener = self.get_spatial_whitener(vis_chan)
            # whiten: spatially first, then temporally
            wfs1_w = np.matmul(wfs1, spatial_whitener)
            wfs2_w = np.matmul(wfs2, spatial_whitener)
            wfs1_w = np.matmul(wfs1_w.transpose(0, 2, 1),
                               self.temporal_whitener).transpose(0, 2, 1)
            wfs2_w = np.matmul(wfs2_w.transpose(0, 2, 1),
                               self.temporal_whitener).transpose(0, 2, 1)

            # project every waveform on the difference of the two means
            temp_diff_w = np.mean(wfs1_w, 0) - np.mean(wfs2_w, 0)
            dat1_w = np.sum(wfs1_w * temp_diff_w, (1, 2))
            dat2_w = np.sum(wfs2_w * temp_diff_w, (1, 2))
            dat_all = np.hstack((dat1_w, dat2_w))
            p_val = dp(dat_all)[1]

            merge = bool(p_val > p_val_threshold)
            centers_dist = np.linalg.norm(temp_diff_w)

            np.savez(fname_out,
                     merge=merge,
                     dat1_w=dat1_w,
                     dat2_w=dat2_w,
                     centers_dist=centers_dist,
                     p_val=p_val)

            if merge:
                merge_pairs.append(pair)

        return merge_pairs
Beispiel #10
0
def load_align_waveforms_parallel(labels_in, save_dir, raw_data, fname_splits,
                                  fname_spike_index, fname_labels_input,
                                  fname_templates, reader_raw, reader_resid,
                                  CONFIG):
    """Read, align and store the waveforms of every requested partition.

    For each id in ``labels_in`` the spikes carrying that split label are
    subsampled (``subsample_spikes``), their waveforms are read slightly
    longer than needed, aligned on the partition's channel
    (``align_get_shifts_with_ref`` with ``nshifts=3`` + ``shift_chans``),
    trimmed back to the output length and written to
    ``<save_dir>/partition_<id>.npz``. Partitions whose file already exists
    are left untouched.

    Parameters
    ----------
    labels_in : iterable
        Partition ids to process.
    save_dir : str
        Output directory.
    raw_data : bool
        If true, waveforms are read with ``reader_raw.read_waveforms``;
        otherwise clean waveforms are built with
        ``reader_resid.read_clean_waveforms`` from the templates.
    fname_splits : str or None
        File with one split label per spike. When ``None`` the second column
        of the spike index is used as split label.
    fname_spike_index : str
        File holding the spike index; column 0 is used as spike time and
        column 1 as channel.
    fname_labels_input, fname_templates : str
        Per-spike template ids and the templates; only loaded when
        ``raw_data`` is false.
    reader_raw, reader_resid : reader objects (see ``raw_data``).
    CONFIG : configuration object; the fields read here are
        ``cluster.min_fr``, ``cluster.max_n_spikes``,
        ``recordings.sampling_rate``, ``neuralnetwork.apply_nn``,
        ``spike_size``, ``spike_size_nn`` and ``neigh_channels``.

    Returns
    -------
    list of str
        One output file name per element of ``labels_in``.
    """
    spike_index = np.load(fname_spike_index)
    if fname_splits is None:
        split_labels = spike_index[:, 1]
    else:
        split_labels = np.load(fname_splits)

    # minimum number of spikes per cluster before any subsampling:
    # (span of the spike times / sampling rate) * minimum firing rate
    rec_len = np.ptp(spike_index[:, 0])
    base_min_spikes = int(rec_len * CONFIG.cluster.min_fr /
                          CONFIG.recordings.sampling_rate)

    # first read waveforms in a bigger size
    # then align and cut down edges
    if CONFIG.neuralnetwork.apply_nn:
        spike_size_out = CONFIG.spike_size_nn
    else:
        spike_size_out = CONFIG.spike_size
    spike_size_buffer = 3
    spike_size_read = spike_size_out + 2 * spike_size_buffer

    # load data for making clean wfs
    if not raw_data:
        labels_input = np.load(fname_labels_input)
        templates = np.load(fname_templates)

        n_times_templates = templates.shape[1]
        if n_times_templates > spike_size_out:
            n_times_diff = (n_times_templates - spike_size_out) // 2
            # explicit end index: a "-n_times_diff" end would select nothing
            # when n_times_diff is 0 (templates longer by a single sample)
            templates = templates[:, n_times_diff:
                                  n_times_templates - n_times_diff]

    # get waveforms and align
    fname_outs = []
    for id_ in labels_in:
        fname_out = os.path.join(save_dir, 'partition_{}.npz'.format(id_))
        fname_outs.append(fname_out)

        if os.path.exists(fname_out):
            continue

        idx_ = np.where(split_labels == id_)[0]

        # spike times
        spike_times = spike_index[idx_, 0]

        # if it will be subsampled, min spikes should decrease also.
        # The threshold is derived from the base value for every partition;
        # rescaling one shared variable would compound the ratios of all
        # previously processed partitions.
        subsample_ratio = min(
            1, CONFIG.cluster.max_n_spikes / float(len(spike_times)))
        # min_spikes needs to be at least 20 to cluster
        min_spikes = max(int(base_min_spikes * subsample_ratio), 20)

        # subsample spikes
        (spike_times,
         idx_sampled) = subsample_spikes(spike_times,
                                         CONFIG.cluster.max_n_spikes)

        # max channel and neighbor channels
        channel = int(spike_index[idx_, 1][0])
        neighbor_chans = np.where(CONFIG.neigh_channels[channel])[0]

        if raw_data:
            wf, skipped_idx = reader_raw.read_waveforms(
                spike_times, spike_size_read, neighbor_chans)
            spike_times = np.delete(spike_times, skipped_idx)

        else:

            # get upsampled ids
            template_ids_ = labels_input[idx_][idx_sampled]
            unique_template_ids = np.unique(template_ids_)

            # ids relabelled to 0..n-1 so they index into templates_in
            templates_in = templates[unique_template_ids]
            template_ids_in = np.zeros_like(template_ids_)
            for ii, k in enumerate(unique_template_ids):
                template_ids_in[template_ids_ == k] = ii

            # get clean waveforms
            wf, skipped_idx = reader_resid.read_clean_waveforms(
                spike_times, template_ids_in, templates_in, spike_size_read,
                neighbor_chans)
            spike_times = np.delete(spike_times, skipped_idx)
            template_ids_in = np.delete(template_ids_in, skipped_idx)

        # align on the partition's channel, then drop the read buffer
        if wf.shape[0] > 0:
            mc = np.where(neighbor_chans == channel)[0][0]
            shifts = align_get_shifts_with_ref(wf[:, :, mc], nshifts=3)
            wf = shift_chans(wf, shifts)
            wf = wf[:, spike_size_buffer:-spike_size_buffer]
        else:
            shifts = None

        if raw_data:
            np.savez(fname_out,
                     spike_times=spike_times,
                     wf=wf,
                     shifts=shifts,
                     channel=channel,
                     min_spikes=min_spikes)
        else:
            np.savez(fname_out,
                     spike_times=spike_times,
                     wf=wf,
                     shifts=shifts,
                     upsampled_ids=template_ids_in,
                     up_templates=templates_in,
                     channel=channel,
                     min_spikes=min_spikes)

    return fname_outs
Beispiel #11
0
def crop_and_align_templates(fname_templates, save_dir, CONFIG):
    """Crop (spatially) and align (temporally) templates.

    The templates are aligned on their max-amplitude channels, trimmed in
    time to ``(CONFIG.spike_size_nn - 1) * 2 + 1`` samples around the window
    of highest total energy, and reduced to the neighbors of each unit's
    max channel, ordered by ``order_channels_by_distance``. Units with fewer
    neighbors than the maximum are zero-padded on the channel axis.

    Parameters
    ----------
    fname_templates : str
        File holding templates of shape (n_units, n_times, n_channels).
    save_dir : str
        Output directory; created if missing.
    CONFIG : configuration object; the fields read here are
        ``spike_size_nn``, ``neigh_channels`` and ``geom``.

    Returns
    -------
    str
        Path of the saved array ``templates_cropped.npy`` of shape
        (n_units, spike_size, max number of neighbors).

    Raises
    ------
    ValueError
        If the aligned templates are shorter than the output spike size.
    """
    # makedirs also handles nested paths and an already existing directory
    os.makedirs(save_dir, exist_ok=True)

    # load templates
    templates = np.load(fname_templates)

    n_units = templates.shape[0]
    unit_ids = np.arange(n_units)
    # np.ptp instead of ndarray.ptp (removed in NumPy 2.0)
    mcs = np.ptp(templates, 1).argmax(1)
    spike_size = (CONFIG.spike_size_nn - 1)*2 + 1
    half_size = spike_size // 2

    ########## TEMPORALLY ALIGN TEMPLATES #################

    # template on max channel only (float64, one row per unit)
    templates_max_channel = templates[unit_ids, :, mcs].astype(np.float64)

    # align them
    ref = np.mean(templates_max_channel, axis=0)
    upsample_factor = 8
    nshifts = half_size

    shifts = align_get_shifts_with_ref(
        templates_max_channel, ref, upsample_factor, nshifts)

    templates_aligned = shift_chans(templates, shifts)

    # crop out the edges since they have bad artifacts. The same number of
    # samples is removed on both sides: "-nshifts//2" parses as
    # "(-nshifts)//2", which removes one extra sample at the end for odd
    # nshifts and selects nothing at all when nshifts//2 is 0.
    edge = nshifts // 2
    templates_aligned = templates_aligned[
        :, edge:templates_aligned.shape[1] - edge]

    ########## Find High Energy Center of Templates #################

    n_times_aligned = templates_aligned.shape[1]
    if n_times_aligned < spike_size:
        raise ValueError(
            'aligned templates have {} time points but {} are required'.format(
                n_times_aligned, spike_size))

    templates_max_channel_aligned = templates_aligned[
        unit_ids, :, mcs].astype(np.float64)

    # determine temporal center of templates and crop around it
    total_energy = np.sum(np.square(templates_max_channel_aligned), axis=0)
    center = np.argmax(np.convolve(total_energy, np.ones(half_size), 'same'))
    # keep the full window inside the array: a center too close to either
    # end would otherwise produce a negative start index or a short slice
    center = int(np.clip(center, half_size, n_times_aligned - half_size - 1))
    templates_aligned = templates_aligned[
        :, (center - half_size):(center + half_size + 1)]

    ########## spatially crop (only keep neighbors) #################

    neighbors = CONFIG.neigh_channels
    n_neigh = np.max(np.sum(neighbors, axis=1))
    templates_cropped = np.zeros((n_units, spike_size, n_neigh))

    for k in range(n_units):

        # get neighbors for the main channel in the kth template
        ch_idx = np.where(neighbors[mcs[k]])[0]

        # order channels
        ch_idx, _ = order_channels_by_distance(mcs[k], ch_idx, CONFIG.geom)

        # new kth template is the old kth template by keeping only
        # ordered neighboring channels
        templates_cropped[k, :, :ch_idx.shape[0]] = templates_aligned[k][:, ch_idx]

    fname_templates_cropped = os.path.join(save_dir, 'templates_cropped.npy')
    np.save(fname_templates_cropped, templates_cropped)

    return fname_templates_cropped