# Imports used by this section; helpers such as shift_chans,
# align_get_shifts_with_ref, align_get_shifts_tc, order_channels_by_distance,
# subsample_spikes and dp are defined elsewhere in this module/package.
import os
import logging

import numpy as np
import scipy.signal
import scipy.spatial


def two_templates_dist_linear_align(templates1, templates2, max_shift=5, step=0.5):

    K1, R, C = templates1.shape
    K2 = templates2.shape[0]

    shifts = np.arange(-max_shift, max_shift + step, step)

    ptps1 = templates1.ptp(1)
    max_chans1 = np.argmax(ptps1, 1)
    ptps2 = templates2.ptp(1)
    max_chans2 = np.argmax(ptps2, 1)

    # precompute every linearly shifted version of the second template set
    shifted_templates = np.zeros((len(shifts), K2, R, C))
    for ii, s in enumerate(shifts):
        shifted_templates[ii] = shift_chans(templates2, np.ones(K2) * s)

    distance = np.ones((K1, K2)) * 1e4
    for k in range(K1):
        # only compare pairs whose max-channel ptps are within 50%
        candidates = np.abs(
            ptps2[:, max_chans1[k]] - ptps1[k, max_chans1[k]]
        ) / ptps1[k, max_chans1[k]] < 0.5

        # Euclidean distance at the best shift
        dist = np.min(np.sum(np.square(
            templates1[k][np.newaxis, np.newaxis]
            - shifted_templates[:, candidates]),
            axis=(2, 3)), 0)
        dist = np.sqrt(dist)
        distance[k, candidates] = dist

    return distance
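# A minimal usage sketch (not part of the original module): synthetic arrays,
# assuming shift_chans is in scope here. Entries left at 1e4 mark pairs whose
# max-channel ptps differ by more than 50% and were therefore never compared.
def _demo_two_templates_dist():
    rng = np.random.default_rng(0)
    templates1 = rng.normal(size=(4, 61, 10))   # (K1, T, C)
    templates2 = rng.normal(size=(6, 61, 10))   # (K2, T, C)
    dist = two_templates_dist_linear_align(templates1, templates2,
                                           max_shift=5, step=0.5)
    return dist                                 # shape (4, 6)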
def template_dist_linear_align(templates, distance=None, units=None, max_shift=5, step=0.5):

    K, R, C = templates.shape

    shifts = np.arange(-max_shift, max_shift + step, step)

    ptps = templates.ptp(1)
    max_chans = np.argmax(ptps, 1)

    shifted_templates = np.zeros((len(shifts), K, R, C))
    for ii, s in enumerate(shifts):
        shifted_templates[ii] = shift_chans(templates, np.ones(K) * s)

    if distance is None:
        distance = np.ones((K, K)) * 1e4

    if units is None:
        units = np.arange(K)

    for k in units:
        # only compare pairs whose max-channel ptps are within 50%
        candidates = np.abs(
            ptps[:, max_chans[k]] - ptps[k, max_chans[k]]
        ) / ptps[k, max_chans[k]] < 0.5

        # squared distance at the best shift (no sqrt here, unlike
        # two_templates_dist_linear_align above)
        dist = np.min(np.sum(np.square(
            templates[k][np.newaxis, np.newaxis]
            - shifted_templates[:, candidates]),
            axis=(2, 3)), 0)
        distance[k, candidates] = dist
        distance[candidates, k] = dist

    return distance
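# Hedged sketch of the incremental-update path: the distance and units
# arguments let the caller recompute only the rows/columns of units whose
# templates changed, reusing the previous matrix. Synthetic shapes; assumes
# shift_chans is in scope.
def _demo_template_dist_update():
    rng = np.random.default_rng(0)
    templates = rng.normal(size=(8, 61, 10))        # (K, T, C)
    dist = template_dist_linear_align(templates)    # full K x K pass
    templates[3] *= 1.1                             # only unit 3 changed
    dist = template_dist_linear_align(
        templates, distance=dist, units=np.array([3]))
    return dist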
def standardize_templates(self):

    # standardize templates
    ptp = self.templates.ptp(1)
    self.templates = self.templates / ptp[:, None]

    ref = np.mean(self.templates, 0)
    shifts = align_get_shifts_with_ref(self.templates, ref)
    self.templates = shift_chans(self.templates, shifts)
def align_tc(tc, ref):

    n_units, n_timepoints, n_channels = tc.shape

    max_channels = np.abs(tc).max(1).argmax(1)

    main_tc = np.zeros((n_units, n_timepoints))
    for j in range(n_units):
        main_tc[j] = np.abs(tc[j][:, max_channels[j]])

    best_shifts = align_get_shifts_tc(main_tc, ref, upsample_factor=1)
    shifted_tc = shift_chans(tc, best_shifts)

    return shifted_tc
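# Hedged usage sketch: aligns a set of (time x channel) components to a 1-D
# reference trace on each component's dominant channel. Assumes
# align_get_shifts_tc and shift_chans (defined elsewhere in this module) are
# in scope; the data below is synthetic.
def _demo_align_tc():
    rng = np.random.default_rng(0)
    tc = rng.normal(size=(5, 61, 10))   # (n_units, n_timepoints, n_channels)
    ref = np.hanning(61)                # any 1-D reference of length n_timepoints
    return align_tc(tc, ref)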
def align_step(self, local):

    if self.verbose:
        print("chan " + str(self.channel) + ", aligning")

    # align waveforms by finding best shifts
    if local:
        mc = np.where(self.loaded_channels == self.channel)[0][0]
        best_shifts = align_get_shifts_with_ref(self.wf_global[:, :, mc])
        self.shifts[self.indices_in] = best_shifts
    else:
        best_shifts = self.shifts[self.indices_in]

    self.wf_global = shift_chans(self.wf_global, best_shifts)

    if self.ari_flag:
        pass
def align_waveforms_parallel(fnames_input_data):

    for fname in fnames_input_data:
        temp = np.load(fname, allow_pickle=True)
        if 'shifts' in temp.files:
            continue

        wf = temp['wf']

        # align
        if wf.shape[0] > 0:
            mc = np.mean(wf, 0).ptp(0).argmax()
            shifts = align_get_shifts_with_ref(wf[:, :, mc], nshifts=3)
            wf = shift_chans(wf, shifts)
        else:
            shifts = None

        temp = dict(temp)
        temp['wf'] = wf
        temp['shifts'] = shifts
        np.savez(fname, **temp)
def jitter_templates(self, up_factor=8):

    n_templates, n_times = self.templates.shape

    # upsample templates
    up_temp = scipy.signal.resample(
        x=self.templates,
        num=n_times * up_factor,
        axis=1)
    up_temp = up_temp.T

    # strided indexing turns each upsampled template into up_factor copies
    # offset by fractional-sample jitters
    idx = (np.arange(0, n_times)[:, None] * up_factor
           + np.arange(up_factor))
    up_shifted_temps = up_temp[idx].transpose(2, 0, 1)

    # add the one-sample rolled versions, doubling the jitter set
    up_shifted_temps = np.concatenate(
        (up_shifted_temps,
         np.roll(up_shifted_temps, shift=1, axis=1)),
        axis=2)
    self.templates = up_shifted_temps.transpose(0, 2, 1).reshape(-1, n_times)

    ref = np.mean(self.templates, 0)
    shifts = align_get_shifts_with_ref(
        self.templates, ref, upsample_factor=1)
    self.templates = shift_chans(self.templates, shifts)
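# A standalone sketch (synthetic data, hypothetical demo name) of the indexing
# trick used above: after resampling by up_factor, indexing with
# arange(n_times)[:, None] * up_factor + arange(up_factor) yields up_factor
# fractional-sample-shifted copies of each template at the original rate.
def _demo_jitter_indexing(up_factor=4):
    rng = np.random.default_rng(0)
    n_templates, n_times = 3, 60
    templates = rng.normal(size=(n_templates, n_times))
    up_temp = scipy.signal.resample(templates, n_times * up_factor, axis=1).T
    idx = np.arange(n_times)[:, None] * up_factor + np.arange(up_factor)
    # shape (n_templates, n_times, up_factor): one slice per sub-sample jitter
    return up_temp[idx].transpose(2, 0, 1)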
def template_spike_dist_linear_align(templates, spikes, vis_ptp=2.):
    """Compares the templates and spikes.

    Parameters
    ----------
    templates: numpy.array shape (K, T, C)
    spikes: numpy.array shape (M, T, C)
    vis_ptp: float
        minimum peak-to-peak amplitude for a channel to count as visible.
    """

    # get ref template
    max_idx = templates.ptp(1).max(1).argmax(0)
    ref_template = templates[max_idx]
    max_chan = ref_template.ptp(0).argmax(0)
    ref_template = ref_template[:, max_chan]

    # align templates on max channel
    best_shifts = align_get_shifts_with_ref(
        templates[:, :, max_chan], ref_template, nshifts=7)
    templates = shift_chans(templates, best_shifts)

    # align all spikes on max channel
    best_shifts = align_get_shifts_with_ref(
        spikes[:, :, max_chan], ref_template, nshifts=7)
    spikes = shift_chans(spikes, best_shifts)

    # if shifted, cut shifted parts
    # because it is contaminated by roll function
    cut = int(np.ceil(np.max(np.abs(best_shifts))))
    if cut > 0:
        templates = templates[:, cut:-cut]
        spikes = spikes[:, cut:-cut]

    # get visible channels
    vis_ptp = np.min((vis_ptp, np.max(templates.ptp(1))))
    vis_chan = np.where(templates.ptp(1).max(0) >= vis_ptp)[0]
    templates = templates[:, :, vis_chan].reshape(templates.shape[0], -1)
    spikes = spikes[:, :, vis_chan].reshape(spikes.shape[0], -1)

    # get a subset of locations with maximal difference
    min_diff_points = 5
    if templates.shape[0] == 1:
        # single unit: all timepoints
        idx = np.arange(templates.shape[1])
    elif templates.shape[0] == 2:
        # two units: points where the difference is large
        diffs = np.abs(np.diff(templates, axis=0)[0])
        idx = np.where(diffs > 1.5)[0]
        if len(idx) < min_diff_points:
            idx = np.argsort(diffs)[-min_diff_points:]
    else:
        # more than two units: diff is mean diff to the largest unit
        diffs = np.mean(np.abs(templates - templates[max_idx][None]),
                        axis=0)
        idx = np.where(diffs > 1.5)[0]
        if len(idx) < min_diff_points:
            idx = np.argsort(diffs)[-min_diff_points:]
    templates = templates[:, idx]
    spikes = spikes[:, idx]

    dist = scipy.spatial.distance.cdist(templates, spikes)

    return dist.T
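# Hedged usage sketch with synthetic arrays; assumes align_get_shifts_with_ref
# and shift_chans are in scope. The result is the (M, K) matrix of distances
# from each spike to each template, computed on aligned, visible,
# high-difference features only.
def _demo_template_spike_dist():
    rng = np.random.default_rng(0)
    templates = rng.normal(size=(3, 61, 10))  # (K, T, C)
    spikes = rng.normal(size=(20, 61, 10))    # (M, T, C)
    dist = template_spike_dist_linear_align(templates, spikes)
    return dist                               # shape (20, 3)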
def merge_templates_parallel(self, pairs):
    """Decide for each pair of units whether their templates should merge."""

    n_samples = 2000
    p_val_threshold = 0.9

    merge_pairs = []

    for pair in pairs:
        unit1, unit2 = pair

        fname_out = os.path.join(
            self.save_dir,
            'unit_{}_{}.npz'.format(unit1, unit2))
        if os.path.exists(fname_out):
            if np.load(fname_out)['merge']:
                merge_pairs.append(pair)
        else:
            # get spike times and soft assignments
            idx1 = self.spike_train[:, 1] == unit1
            spt1 = self.spike_train[idx1, 0]
            prob1 = self.soft_assignment[idx1]
            shift1 = self.shifts[idx1]
            scale1 = self.scales[idx1]
            n_spikes1 = self.n_spikes_soft[unit1]

            idx2 = self.spike_train[:, 1] == unit2
            spt2 = self.spike_train[idx2, 0]
            prob2 = self.soft_assignment[idx2]
            shift2 = self.shifts[idx2]
            scale2 = self.scales[idx2]
            n_spikes2 = self.n_spikes_soft[unit2]

            # randomly subsample, proportionally to cluster sizes
            if n_spikes1 + n_spikes2 > n_samples:
                ratio1 = n_spikes1 / float(n_spikes1 + n_spikes2)
                n_samples1 = np.min((int(n_samples * ratio1), n_spikes1))
                n_samples2 = n_samples - n_samples1
            else:
                n_samples1 = n_spikes1
                n_samples2 = n_spikes2

            idx1_ = np.random.choice(
                len(spt1), n_samples1, replace=False,
                p=prob1 / np.sum(prob1))
            idx2_ = np.random.choice(
                len(spt2), n_samples2, replace=False,
                p=prob2 / np.sum(prob2))

            spt1 = spt1[idx1_]
            spt2 = spt2[idx2_]
            shift1 = shift1[idx1_]
            shift2 = shift2[idx2_]
            scale1 = scale1[idx1_]
            scale2 = scale2[idx2_]

            ptp_max = self.ptps[[unit1, unit2]].max(0)
            mc = ptp_max.argmax()
            vis_chan = np.where(ptp_max > 1)[0]

            # align two units on the trough of the max channel
            shift_temp = (self.templates[unit2, :, mc].argmin()
                          - self.templates[unit1, :, mc].argmin())
            spt2 += shift_temp

            # load residuals
            wfs1, skipped_idx1 = self.reader_residual.read_waveforms(
                spt1, self.spike_size, vis_chan)
            spt1 = np.delete(spt1, skipped_idx1)
            shift1 = np.delete(shift1, skipped_idx1)
            scale1 = np.delete(scale1, skipped_idx1)

            wfs2, skipped_idx2 = self.reader_residual.read_waveforms(
                spt2, self.spike_size, vis_chan)
            # bug fix: spt2 must be pruned with skipped_idx2, not skipped_idx1
            spt2 = np.delete(spt2, skipped_idx2)
            shift2 = np.delete(shift2, skipped_idx2)
            scale2 = np.delete(scale2, skipped_idx2)

            # align residuals
            wfs1 = shift_chans(wfs1, -shift1)
            wfs2 = shift_chans(wfs2, -shift2)

            # make clean waveforms (residual + scaled template)
            wfs1 += scale1[:, None, None] * self.templates[[unit1], :, vis_chan].T
            if shift_temp > 0:
                temp_2_shifted = self.templates[[unit2], shift_temp:, vis_chan].T
                wfs2[:, :-shift_temp] += scale2[:, None, None] * temp_2_shifted
            elif shift_temp < 0:
                temp_2_shifted = self.templates[[unit2], :shift_temp, vis_chan].T
                wfs2[:, -shift_temp:] += scale2[:, None, None] * temp_2_shifted
            else:
                wfs2 += scale2[:, None, None] * self.templates[[unit2], :, vis_chan].T

            # whiten spatially, then temporally
            spatial_whitener = self.get_spatial_whitener(vis_chan)
            wfs1_w = np.matmul(wfs1, spatial_whitener)
            wfs2_w = np.matmul(wfs2, spatial_whitener)
            wfs1_w = np.matmul(wfs1_w.transpose(0, 2, 1),
                               self.temporal_whitener).transpose(0, 2, 1)
            wfs2_w = np.matmul(wfs2_w.transpose(0, 2, 1),
                               self.temporal_whitener).transpose(0, 2, 1)

            # project on the whitened template difference and dip-test the
            # pooled 1-D projections for unimodality
            temp_diff_w = np.mean(wfs1_w, 0) - np.mean(wfs2_w, 0)
            c_w = np.sum(0.5 * (np.mean(wfs1_w, 0) + np.mean(wfs2_w, 0))
                         * temp_diff_w)
            dat1_w = np.sum(wfs1_w * temp_diff_w, (1, 2))
            dat2_w = np.sum(wfs2_w * temp_diff_w, (1, 2))
            dat_all = np.hstack((dat1_w, dat2_w))
            p_val = dp(dat_all)[1]

            if p_val > p_val_threshold:
                merge = True
            else:
                merge = False
            centers_dist = np.linalg.norm(temp_diff_w)

            np.savez(fname_out,
                     merge=merge,
                     dat1_w=dat1_w,
                     dat2_w=dat2_w,
                     centers_dist=centers_dist,
                     p_val=p_val)

            if merge:
                merge_pairs.append(pair)

    return merge_pairs
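# Standalone sketch of the decision rule above: project both clusters of
# whitened waveforms onto the template difference and test the pooled 1-D
# projection for unimodality. dp is assumed to be the dip-test helper used in
# merge_templates_parallel, returning (statistic, p_value).
def _demo_projection_test(wfs1_w, wfs2_w, p_val_threshold=0.9):
    temp_diff_w = np.mean(wfs1_w, 0) - np.mean(wfs2_w, 0)
    dat1_w = np.sum(wfs1_w * temp_diff_w, (1, 2))
    dat2_w = np.sum(wfs2_w * temp_diff_w, (1, 2))
    p_val = dp(np.hstack((dat1_w, dat2_w)))[1]
    return p_val > p_val_threshold  # unimodal projection -> merge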
def load_align_waveforms_parallel(
        labels_in, save_dir, raw_data, fname_splits,
        fname_spike_index, fname_labels_input,
        fname_templates, reader_raw, reader_resid, CONFIG):

    spike_index = np.load(fname_spike_index)
    if fname_splits is None:
        split_labels = spike_index[:, 1]
    else:
        split_labels = np.load(fname_splits)

    # minimum number of spikes per cluster, from the firing-rate floor
    # (spike times are in samples)
    rec_len_sec = np.ptp(spike_index[:, 0]) / CONFIG.recordings.sampling_rate
    min_spikes = int(rec_len_sec * CONFIG.cluster.min_fr)

    # first read waveforms in a bigger size,
    # then align and cut down edges
    if CONFIG.neuralnetwork.apply_nn:
        spike_size_out = CONFIG.spike_size_nn
    else:
        spike_size_out = CONFIG.spike_size
    spike_size_buffer = 3
    spike_size_read = spike_size_out + 2 * spike_size_buffer

    # load data for making clean wfs
    if not raw_data:
        labels_input = np.load(fname_labels_input)
        templates = np.load(fname_templates)
        n_times_templates = templates.shape[1]
        if n_times_templates > spike_size_out:
            n_times_diff = (n_times_templates - spike_size_out) // 2
            templates = templates[:, n_times_diff:-n_times_diff]

    # get waveforms and align
    fname_outs = []
    for id_ in labels_in:
        fname_out = os.path.join(save_dir, 'partition_{}.npz'.format(id_))
        fname_outs.append(fname_out)
        if os.path.exists(fname_out):
            continue

        idx_ = np.where(split_labels == id_)[0]

        # spike times
        spike_times = spike_index[idx_, 0]

        # if it will be subsampled, min spikes should decrease also
        # (bug fix: use a per-partition variable so min_spikes does not
        # shrink cumulatively across iterations)
        subsample_ratio = np.min(
            (1, CONFIG.cluster.max_n_spikes / float(len(spike_times))))
        min_spikes_ = int(min_spikes * subsample_ratio)

        # min_spikes needs to be at least 20 to cluster
        min_spikes_ = np.max((min_spikes_, 20))

        # subsample spikes
        (spike_times,
         idx_sampled) = subsample_spikes(
            spike_times,
            CONFIG.cluster.max_n_spikes)

        # max channel and neighbor channels
        channel = int(spike_index[idx_, 1][0])
        neighbor_chans = np.where(CONFIG.neigh_channels[channel])[0]

        if raw_data:
            wf, skipped_idx = reader_raw.read_waveforms(
                spike_times, spike_size_read, neighbor_chans)
            spike_times = np.delete(spike_times, skipped_idx)
        else:
            # get upsampled ids
            template_ids_ = labels_input[idx_][idx_sampled]
            unique_template_ids = np.unique(template_ids_)

            # ids relabelled
            templates_in = templates[unique_template_ids]
            template_ids_in = np.zeros_like(template_ids_)
            for ii, k in enumerate(unique_template_ids):
                template_ids_in[template_ids_ == k] = ii

            # get clean waveforms
            wf, skipped_idx = reader_resid.read_clean_waveforms(
                spike_times, template_ids_in, templates_in,
                spike_size_read, neighbor_chans)
            spike_times = np.delete(spike_times, skipped_idx)
            template_ids_in = np.delete(template_ids_in, skipped_idx)

        # align
        if wf.shape[0] > 0:
            mc = np.where(neighbor_chans == channel)[0][0]
            shifts = align_get_shifts_with_ref(wf[:, :, mc], nshifts=3)
            wf = shift_chans(wf, shifts)
            wf = wf[:, spike_size_buffer:-spike_size_buffer]
        else:
            shifts = None

        if raw_data:
            np.savez(fname_out,
                     spike_times=spike_times,
                     wf=wf,
                     shifts=shifts,
                     channel=channel,
                     min_spikes=min_spikes_)
        else:
            np.savez(fname_out,
                     spike_times=spike_times,
                     wf=wf,
                     shifts=shifts,
                     upsampled_ids=template_ids_in,
                     up_templates=templates_in,
                     channel=channel,
                     min_spikes=min_spikes_)

    return fname_outs
def crop_and_align_templates(fname_templates, save_dir, CONFIG):
    """Crop (spatially) and align (temporally) templates.

    Parameters
    ----------
    fname_templates: str
        path to a templates npy file, shape (n_units, n_times, n_channels)
    save_dir: str
        directory in which the cropped templates are saved
    CONFIG:
        configuration object providing spike_size_nn, neigh_channels and geom

    Returns
    -------
    str
        path to the saved cropped templates
    """
    logger = logging.getLogger(__name__)

    if not os.path.exists(save_dir):
        os.mkdir(save_dir)

    # load templates
    templates = np.load(fname_templates)

    n_units, n_times, n_channels = templates.shape
    mcs = templates.ptp(1).argmax(1)

    spike_size = (CONFIG.spike_size_nn - 1) * 2 + 1

    ########## TEMPORALLY ALIGN TEMPLATES #################

    # template on max channel only
    templates_max_channel = np.zeros((n_units, n_times))
    for k in range(n_units):
        templates_max_channel[k] = templates[k, :, mcs[k]]

    # align them
    ref = np.mean(templates_max_channel, axis=0)
    upsample_factor = 8
    nshifts = spike_size // 2

    shifts = align_get_shifts_with_ref(
        templates_max_channel, ref, upsample_factor, nshifts)

    templates_aligned = shift_chans(templates, shifts)

    # crop out the edges since they have bad artifacts
    templates_aligned = templates_aligned[:, nshifts//2:-nshifts//2]

    ########## Find High Energy Center of Templates #################

    templates_max_channel_aligned = np.zeros(
        (n_units, templates_aligned.shape[1]))
    for k in range(n_units):
        templates_max_channel_aligned[k] = templates_aligned[k, :, mcs[k]]

    # determine temporal center of templates and crop around it
    total_energy = np.sum(np.square(templates_max_channel_aligned), axis=0)
    center = np.argmax(np.convolve(
        total_energy, np.ones(spike_size//2), 'same'))
    templates_aligned = templates_aligned[
        :, (center - spike_size//2):(center + spike_size//2 + 1)]

    ########## spatially crop (only keep neighbors) #################

    neighbors = CONFIG.neigh_channels
    n_neigh = np.max(np.sum(CONFIG.neigh_channels, axis=1))

    templates_cropped = np.zeros((n_units, spike_size, n_neigh))
    for k in range(n_units):

        # get neighbors for the main channel in the kth template
        ch_idx = np.where(neighbors[mcs[k]])[0]

        # order channels by distance from the main channel
        ch_idx, _ = order_channels_by_distance(mcs[k], ch_idx, CONFIG.geom)

        # new kth template is the old kth template, keeping only
        # the ordered neighboring channels
        templates_cropped[k, :, :ch_idx.shape[0]] = templates_aligned[k][:, ch_idx]

    fname_templates_cropped = os.path.join(save_dir, 'templates_cropped.npy')
    np.save(fname_templates_cropped, templates_cropped)

    return fname_templates_cropped
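# Hedged usage sketch: the paths and the CONFIG object below are placeholders
# for whatever the surrounding pipeline provides.
#
#     fname_cropped = crop_and_align_templates(
#         'templates.npy',   # (n_units, n_times, n_channels)
#         'crop_dir',        # output directory, created if missing
#         CONFIG)            # needs spike_size_nn, neigh_channels, geom
#     templates_cropped = np.load(fname_cropped)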