Exemple #1
0
    def standardize_templates(self):
        """Normalize the templates by amplitude and align them to their mean.

        Each template is divided by its peak-to-peak amplitude (taken along
        axis 1 of ``self.templates``). The templates are then shifted with
        ``shift_chans`` by the shifts that ``align_get_shifts_with_ref``
        computes against the mean template. ``self.templates`` is replaced
        with the result.
        """
        # np.ptp works on all NumPy versions; the ndarray.ptp method was
        # removed in NumPy 2.0.
        ptp = np.ptp(self.templates, axis=1)
        # NOTE(review): a flat template (ptp == 0) becomes inf/nan here.
        self.templates = self.templates / ptp[:, None]

        # align every template to the mean template
        ref = np.mean(self.templates, 0)
        shifts = align_get_shifts_with_ref(self.templates, ref)
        self.templates = shift_chans(self.templates, shifts)
Exemple #2
0
    def align_step(self, local):
        """Temporally align the waveforms in ``self.wf_global``.

        Parameters
        ----------
        local : bool
            When true, shifts are computed with ``align_get_shifts_with_ref``
            from the waveforms on the position of ``self.channel`` within
            ``self.loaded_channels``. They are stored into
            ``self.shifts[self.indices_in]``. When false, the shifts already
            stored there are reused.

        ``self.wf_global`` is replaced by its shifted version.
        """
        if self.verbose:
            print("chan " + str(self.channel) + ", aligning")

        if not local:
            # reuse the shifts recorded earlier for these spikes
            best_shifts = self.shifts[self.indices_in]
        else:
            # column of the current channel among the loaded channels
            main_chan_pos = np.where(self.loaded_channels == self.channel)[0][0]
            main_chan_wf = self.wf_global[:, :, main_chan_pos]
            best_shifts = align_get_shifts_with_ref(main_chan_wf)
            self.shifts[self.indices_in] = best_shifts

        self.wf_global = shift_chans(self.wf_global, best_shifts)

        # NOTE(review): placeholder branch; it only evaluates the flag and
        # does nothing else.
        if self.ari_flag:
            pass
Exemple #3
0
def align_waveforms_parallel(fnames_input_data):
    """Align the waveforms stored in each ``.npz`` file and write them back.

    Files that already contain a ``'shifts'`` entry are skipped. For every
    other file, the ``'wf'`` array is aligned on the channel where its mean
    waveform has the largest peak-to-peak amplitude. The file is then
    rewritten with the shifted ``'wf'``, the applied ``'shifts'`` (``None``
    when ``'wf'`` holds no waveforms) and all entries it already had.

    Parameters
    ----------
    fnames_input_data : iterable of str
        Paths of ``.npz`` files, each holding at least a ``'wf'`` array
        indexed as ``wf[n, t, channel]``.
    """
    for fname in fnames_input_data:
        # Read everything and close the archive before rewriting the same
        # path; an unclosed NpzFile keeps its file handle open.
        with np.load(fname, allow_pickle=True) as archive:
            if 'shifts' in archive.files:
                # already aligned on a previous run
                continue
            contents = dict(archive)

        wf = contents['wf']

        if wf.shape[0] > 0:
            # channel with the largest peak-to-peak of the mean waveform
            # (np.ptp: the ndarray.ptp method is gone in NumPy 2.0)
            mc = np.ptp(np.mean(wf, 0), axis=0).argmax()
            shifts = align_get_shifts_with_ref(wf[:, :, mc], nshifts=3)
            wf = shift_chans(wf, shifts)
        else:
            shifts = None

        contents['wf'] = wf
        contents['shifts'] = shifts
        np.savez(fname, **contents)
Exemple #4
0
    def jitter_templates(self, up_factor=8):
        """Replace ``self.templates`` by sub-sample-jittered copies of itself.

        Every template (a row of the 2-D ``self.templates``) is upsampled by
        ``up_factor`` with ``scipy.signal.resample``. It is then decimated
        back to its original length at each of the ``up_factor`` phases,
        once as-is and once circularly rolled by one sample. The resulting
        ``n_templates * 2 * up_factor`` rows are aligned to their mean.

        Parameters
        ----------
        up_factor : int
            Upsampling factor, i.e. the number of sub-sample phases.
        """
        n_templates, n_times = self.templates.shape

        upsampled = scipy.signal.resample(
            x=self.templates, num=n_times * up_factor, axis=1)

        # phases[k, j, t] == upsampled[k, t * up_factor + j]
        phases = upsampled.reshape(n_templates, n_times, up_factor)
        phases = phases.transpose(0, 2, 1)

        # each phase both unchanged and delayed by one sample along time
        delayed = np.roll(phases, shift=1, axis=2)
        jittered = np.concatenate((phases, delayed), axis=1)
        self.templates = jittered.reshape(-1, n_times)

        # align all jittered copies to their mean
        mean_template = self.templates.mean(axis=0)
        shifts = align_get_shifts_with_ref(
            self.templates, mean_template, upsample_factor=1)
        self.templates = shift_chans(self.templates, shifts)
Exemple #5
0
def template_spike_dist_linear_align(templates, spikes, vis_ptp=2.):
    """compares the templates and spikes.

    Templates and spikes are aligned on the max channel of the template with
    the largest peak-to-peak amplitude. The edges affected by the shifts are
    cut. Only channels whose largest template peak-to-peak reaches
    ``vis_ptp`` are kept, and the comparison is restricted to the sample
    points where the templates differ most. Euclidean distances are computed
    with ``scipy.spatial.distance.cdist``.

    parameters:
    -----------
    templates: numpy.array shape (K, T, C)
    spikes: numpy.array shape (M, T, C)
    vis_ptp: float
        Peak-to-peak threshold for a channel to be used. It is lowered to the
        largest template peak-to-peak if needed, so at least one channel is
        always kept.

    returns:
    --------
    numpy.array shape (M, K)
        Distance between every spike and every template.
    """
    # reference: max channel of the template with the largest peak-to-peak
    # (np.ptp: the ndarray.ptp method is gone in NumPy 2.0)
    max_idx = np.ptp(templates, axis=1).max(1).argmax(0)
    ref_template = templates[max_idx]
    max_chan = np.ptp(ref_template, axis=0).argmax(0)
    ref_template = ref_template[:, max_chan]

    # align templates on max channel
    template_shifts = align_get_shifts_with_ref(templates[:, :, max_chan],
                                                ref_template,
                                                nshifts=7)
    templates = shift_chans(templates, template_shifts)

    # align all spikes on max channel
    spike_shifts = align_get_shifts_with_ref(spikes[:, :, max_chan],
                                             ref_template,
                                             nshifts=7)
    spikes = shift_chans(spikes, spike_shifts)

    # if shifted, cut shifted parts because they are contaminated by the
    # roll function. Both templates and spikes were shifted, so the cut must
    # cover the largest shift of either.
    all_shifts = np.concatenate((np.ravel(template_shifts),
                                 np.ravel(spike_shifts)))
    cut = int(np.ceil(np.max(np.abs(all_shifts))))
    if cut > 0:
        templates = templates[:, cut:-cut]
        spikes = spikes[:, cut:-cut]

    # get visible channels
    template_ptp = np.ptp(templates, axis=1)
    vis_ptp = np.min((vis_ptp, np.max(template_ptp)))
    vis_chan = np.where(template_ptp.max(0) >= vis_ptp)[0]
    templates = templates[:, :, vis_chan].reshape(templates.shape[0], -1)
    spikes = spikes[:, :, vis_chan].reshape(spikes.shape[0], -1)

    # get a subset of locations with maximal difference
    min_diff_points = 5
    if templates.shape[0] == 1:
        # single unit: all timepoints
        idx = np.arange(templates.shape[1])
    else:
        if templates.shape[0] == 2:
            # two units: absolute difference between them
            diffs = np.abs(np.diff(templates, axis=0)[0])
        else:
            # more than two units: mean abs difference to the largest unit
            diffs = np.mean(np.abs(templates - templates[max_idx][None]),
                            axis=0)
        # points with large difference, or at least the top few
        idx = np.where(diffs > 1.5)[0]
        if len(idx) < min_diff_points:
            idx = np.argsort(diffs)[-min_diff_points:]

    templates = templates[:, idx]
    spikes = spikes[:, idx]

    dist = scipy.spatial.distance.cdist(templates, spikes)

    return dist.T
Exemple #6
0
def load_align_waveforms_parallel(labels_in, save_dir, raw_data, fname_splits,
                                  fname_spike_index, fname_labels_input,
                                  fname_templates, reader_raw, reader_resid,
                                  CONFIG):
    """Read, align and save the waveforms of each cluster in ``labels_in``.

    For each label ``id_`` the file ``partition_{id_}.npz`` is written in
    ``save_dir``; files that already exist are left untouched. Each file
    holds:

    - the subsampled spike times;
    - the aligned waveforms, cut back to the output spike size;
    - the applied shifts;
    - the cluster's channel;
    - its minimum-spike threshold.

    When ``raw_data`` is false, waveforms come from
    ``reader_resid.read_clean_waveforms`` and the template ids and templates
    used are saved too.

    Parameters
    ----------
    labels_in : iterable
        Cluster labels to process.
    save_dir : str
        Output directory for the ``partition_*.npz`` files.
    raw_data : bool
        Read waveforms with ``reader_raw`` instead of ``reader_resid``.
    fname_splits : str or None
        ``.npy`` file with one label per spike. When ``None``, column 1 of
        the spike index is used as the label.
    fname_spike_index : str
        ``.npy`` spike index; column 0 is used as spike time, column 1 as
        channel.
    fname_labels_input, fname_templates : str
        Per-spike template ids and templates. Only loaded when ``raw_data``
        is false.
    reader_raw, reader_resid
        Waveform readers.
    CONFIG
        Configuration object.

    Returns
    -------
    list of str
        Output path for every label in ``labels_in``, in order.
    """
    spike_index = np.load(fname_spike_index)
    if fname_splits is None:
        split_labels = spike_index[:, 1]
    else:
        split_labels = np.load(fname_splits)

    # recording-wide minimum number of spikes per cluster. This base value is
    # rescaled per cluster below and must not be overwritten in the loop.
    # (the ptp of spike times is divided by sampling_rate, so it is
    # presumably in samples)
    rec_len = np.ptp(spike_index[:, 0])
    min_spikes_base = int(rec_len * CONFIG.cluster.min_fr /
                          CONFIG.recordings.sampling_rate)

    # first read waveforms in a bigger size
    # then align and cut down edges
    if CONFIG.neuralnetwork.apply_nn:
        spike_size_out = CONFIG.spike_size_nn
    else:
        spike_size_out = CONFIG.spike_size
    spike_size_buffer = 3
    spike_size_read = spike_size_out + 2 * spike_size_buffer

    # load data for making clean wfs
    if not raw_data:
        labels_input = np.load(fname_labels_input)
        templates = np.load(fname_templates)

        # trim templates symmetrically towards spike_size_out. The explicit
        # end index avoids the empty `d:-0` slice when n_times_diff == 0.
        n_times_templates = templates.shape[1]
        if n_times_templates > spike_size_out:
            n_times_diff = (n_times_templates - spike_size_out) // 2
            templates = templates[
                :, n_times_diff:n_times_templates - n_times_diff]

    # get waveforms and align
    fname_outs = []
    for id_ in labels_in:
        fname_out = os.path.join(save_dir, 'partition_{}.npz'.format(id_))
        fname_outs.append(fname_out)

        if os.path.exists(fname_out):
            continue

        idx_ = np.where(split_labels == id_)[0]

        # spike times
        spike_times = spike_index[idx_, 0]

        # if it will be subsampled, min spikes should decrease also.
        # Start from the base value so clusters do not compound each
        # other's ratios.
        subsample_ratio = np.min(
            (1, CONFIG.cluster.max_n_spikes / float(len(spike_times))))
        min_spikes = int(min_spikes_base * subsample_ratio)
        # min_spikes needs to be at least 20 to cluster
        min_spikes = np.max((min_spikes, 20))

        # subsample spikes
        (spike_times,
         idx_sampled) = subsample_spikes(spike_times,
                                         CONFIG.cluster.max_n_spikes)

        # max channel and neighbor channels
        channel = int(spike_index[idx_, 1][0])
        neighbor_chans = np.where(CONFIG.neigh_channels[channel])[0]

        if raw_data:
            wf, skipped_idx = reader_raw.read_waveforms(
                spike_times, spike_size_read, neighbor_chans)
            spike_times = np.delete(spike_times, skipped_idx)

        else:

            # get upsampled ids
            template_ids_ = labels_input[idx_][idx_sampled]
            unique_template_ids = np.unique(template_ids_)

            # ids relabelled to 0..len(unique_template_ids)-1
            templates_in = templates[unique_template_ids]
            template_ids_in = np.zeros_like(template_ids_)
            for ii, k in enumerate(unique_template_ids):
                template_ids_in[template_ids_ == k] = ii

            # get clean waveforms
            wf, skipped_idx = reader_resid.read_clean_waveforms(
                spike_times, template_ids_in, templates_in, spike_size_read,
                neighbor_chans)
            spike_times = np.delete(spike_times, skipped_idx)
            template_ids_in = np.delete(template_ids_in, skipped_idx)

        # align on the main channel, then drop the read buffer on both ends
        if wf.shape[0] > 0:
            mc = np.where(neighbor_chans == channel)[0][0]
            shifts = align_get_shifts_with_ref(wf[:, :, mc], nshifts=3)
            wf = shift_chans(wf, shifts)
            wf = wf[:, spike_size_buffer:-spike_size_buffer]
        else:
            shifts = None

        if raw_data:
            np.savez(fname_out,
                     spike_times=spike_times,
                     wf=wf,
                     shifts=shifts,
                     channel=channel,
                     min_spikes=min_spikes)
        else:
            np.savez(fname_out,
                     spike_times=spike_times,
                     wf=wf,
                     shifts=shifts,
                     upsampled_ids=template_ids_in,
                     up_templates=templates_in,
                     channel=channel,
                     min_spikes=min_spikes)

    return fname_outs
Exemple #7
0
def crop_and_align_templates(fname_templates, save_dir, CONFIG):
    """Crop (spatially) and align (temporally) templates

    Parameters
    ----------
    fname_templates : str
        ``.npy`` file with templates indexed as (unit, time, channel).
    save_dir : str
        Output directory; created (with parents) if missing.
    CONFIG
        Configuration object; ``spike_size_nn``, ``neigh_channels`` and
        ``geom`` are read.

    Returns
    -------
    str
        Path of the saved ``templates_cropped.npy``. Its shape is
        (n_units, spike_size, n_neigh), with
        ``spike_size = (CONFIG.spike_size_nn - 1) * 2 + 1``. The channels are
        the neighbours of each unit's main channel, ordered by
        ``order_channels_by_distance`` and zero-padded on the right.
    """
    os.makedirs(save_dir, exist_ok=True)

    # load templates
    templates = np.load(fname_templates)

    n_units, n_times, _ = templates.shape
    # main channel: largest peak-to-peak (np.ptp: the ndarray.ptp method is
    # gone in NumPy 2.0)
    mcs = np.ptp(templates, axis=1).argmax(1)
    spike_size = (CONFIG.spike_size_nn - 1) * 2 + 1
    half = spike_size // 2
    unit_ids = np.arange(n_units)

    ########## TEMPORALLY ALIGN TEMPLATES #################

    # template on max channel only, shape (n_units, n_times), as float64
    templates_max_channel = templates[unit_ids, :, mcs].astype(np.float64)

    # align them
    ref = np.mean(templates_max_channel, axis=0)
    upsample_factor = 8
    nshifts = spike_size // 2

    shifts = align_get_shifts_with_ref(
        templates_max_channel, ref, upsample_factor, nshifts)

    templates_aligned = shift_chans(templates, shifts)

    # crop out the edges since they have bad artifacts. The explicit end
    # index keeps the crop symmetric ((-n)//2 floors for odd n) and avoids
    # an empty `0:0` slice when nshifts < 2.
    edge = nshifts // 2
    templates_aligned = templates_aligned[
        :, edge:templates_aligned.shape[1] - edge]

    ########## Find High Energy Center of Templates #################

    templates_max_channel_aligned = (
        templates_aligned[unit_ids, :, mcs].astype(np.float64))

    # determine temporal center of templates and crop around it
    total_energy = np.sum(np.square(templates_max_channel_aligned), axis=0)
    center = np.argmax(np.convolve(total_energy, np.ones(half), 'same'))
    # keep the spike_size window inside the array. A center closer than
    # `half` to an edge would give a short/empty slice that cannot be stored
    # into templates_cropped below.
    n_aligned = templates_aligned.shape[1]
    center = int(np.clip(center, half, n_aligned - half - 1))
    templates_aligned = templates_aligned[:, (center - half):(center + half + 1)]

    ########## spatially crop (only keep neighbors) #################

    neighbors = CONFIG.neigh_channels
    n_neigh = np.max(np.sum(neighbors, axis=1))
    templates_cropped = np.zeros((n_units, spike_size, n_neigh))

    for k in range(n_units):

        # get neighbors for the main channel in the kth template
        ch_idx = np.where(neighbors[mcs[k]])[0]

        # order channels
        ch_idx, _ = order_channels_by_distance(mcs[k], ch_idx, CONFIG.geom)

        # new kth template is the old kth template by keeping only
        # ordered neighboring channels
        templates_cropped[k, :, :ch_idx.shape[0]] = templates_aligned[k][:, ch_idx]

    fname_templates_cropped = os.path.join(save_dir, 'templates_cropped.npy')
    np.save(fname_templates_cropped, templates_cropped)

    return fname_templates_cropped