def crop_spatially(self, neighbors, geometry, inplace=False):
    """Restrict every template to its main channel's neighborhood.

    For template k, the channels flagged in ``neighbors[main_channels[k]]``
    are ordered by distance from the main channel (via
    ``order_channels_by_distance``) and only those channels are kept. The
    output channel axis has the size of the largest neighborhood; shorter
    neighborhoods are zero-padded on the right.

    Parameters
    ----------
    neighbors: numpy.ndarray (n_channels, n_channels)
        Neighbors matrix
    geometry: numpy.ndarray
        Channel coordinates, passed to ``order_channels_by_distance``
    inplace: bool, optional
        If True, replace this object's templates via ``_update_templates``
        and return None; otherwise return a new ``TemplatesProcessor``

    Returns
    -------
    TemplatesProcessor or None
    """
    n_templates, waveform_length, _ = self.templates.shape

    widest = np.max(np.sum(neighbors, 0))
    cropped = np.zeros((n_templates, waveform_length, widest))

    for template_id in range(n_templates):
        main_channel = self.main_channels[template_id]
        candidates = np.where(neighbors[main_channel])[0]
        ordered, _ = order_channels_by_distance(main_channel, candidates,
                                                geometry)
        n_kept = ordered.shape[0]
        cropped[template_id, :, :n_kept] = \
            self.templates[template_id][:, ordered]

    if not inplace:
        return TemplatesProcessor(cropped)

    self._update_templates(cropped)
def matrix_localized(ts, neighbors, geom, spike_size, threshold=4, eps=1e-6):
    """Compute a localized spatial whitening filter for every channel

    Spike-like events (local minima below ``-threshold``) are blanked out of
    the recording. A channel-by-channel correlation matrix is estimated from
    the remaining samples. Then, for every channel, a whitening matrix is
    computed from the sub-matrix restricted to that channel's neighbors,
    ordered by distance.

    [How is this different from the other method?]

    Parameters
    ----------
    ts: numpy.ndarray (T, C)
        Recording: T time samples by C channels. The code treats the
        covariance as a correlation, i.e. it expects a standardized recording
    neighbors: numpy.ndarray (C, C)
        Neighbors matrix; row c flags the neighbors of channel c
    geom: numpy.ndarray
        Channel coordinates, passed to ``order_channels_by_distance``
    spike_size: int
        Half-width (in samples) of the region blanked around each spike
    threshold: float, optional
        Samples below ``-threshold`` are spike candidates (default 4, the
        previously hard-coded value)
    eps: float, optional
        Added to the singular values before taking the inverse square root
        (default 1e-6, the previously hard-coded value)

    Returns
    -------
    numpy.ndarray (n_neigh, n_neigh, C)
        ``Q[:, :, c]`` is the whitening matrix for channel c's ordered
        neighbors, zero-padded up to n_neigh (the largest neighborhood size)
    """
    T, C = ts.shape
    R = spike_size * 2 + 1
    nneigh = np.max(np.sum(neighbors, 0))

    # 1 where a sample is kept, 0 where it lies inside a spike window
    spikes_rec = np.ones(ts.shape)

    for i in range(C):
        idx_crossing = np.where(ts[:, i] < -threshold)[0]
        # drop crossings too close to the edges, so the blanking window and
        # the +/-1 neighbors compared below stay inside the recording
        idx_crossing = idx_crossing[np.logical_and(
            idx_crossing >= (R + 1), idx_crossing <= (T - R - 1))]
        # a spike time is a crossing that is also a local minimum
        spike_time = idx_crossing[np.logical_and(
            ts[idx_crossing, i] <= ts[idx_crossing - 1, i],
            ts[idx_crossing, i] <= ts[idx_crossing + 1, i])]

        # blank (set to zero) the samples around every spike
        for j in np.arange(-spike_size, spike_size + 1):
            spikes_rec[spike_time + j, i] = 0

    # covariance over the samples that are not blanked on either channel;
    # the denominator counts those samples for every channel pair
    blanked_rec = ts * spikes_rec
    M = (np.matmul(blanked_rec.T, blanked_rec)
         / np.matmul(spikes_rec.T, spikes_rec))

    # normalize to unit diagonal (correlation matrix)
    invhalf_var = np.diag(np.power(np.diag(M), -0.5))
    M = np.matmul(np.matmul(invhalf_var, M), invhalf_var)

    Q = np.zeros((nneigh, nneigh, C))
    for c in range(C):
        ch_idx, _ = order_channels_by_distance(c, np.where(neighbors[c])[0],
                                               geom)
        nneigh_c = ch_idx.shape[0]

        # inverse square root of the local correlation matrix:
        # U diag(1 / sqrt(s + eps)) U^T
        U, s, _ = np.linalg.svd(M[np.ix_(ch_idx, ch_idx)])
        inv_sqrt = np.diag(1 / np.power(s + eps, 0.5))
        # single indexing expression instead of the chained Q[:n][:, :n, c]
        Q[:nneigh_c, :nneigh_c, c] = np.matmul(np.matmul(U, inv_sqrt), U.T)

    return Q
def crop_templates(templatesBig, R, neighbors, geom):
    """Align templates in time, then crop them temporally and spatially

    Processing steps:

    1. Normalize each template on its main channel (the channel with the
       largest absolute value).
    2. Shift each template by the value ``align_templates`` returns, using
       the largest-amplitude template as the reference. Samples exposed by
       the shift are zero-filled.
    3. Crop every template to the samples ``center - 3 * R`` through
       ``center + 3 * R``. The center is the middle of the
       ``2 * int(R / 2) + 1``-sample window with the highest summed energy.
    4. Keep only the neighbors of each template's main channel, ordered by
       distance.

    Parameters
    ----------
    templatesBig: numpy.ndarray (n_templates, waveform_length, n_channels)
        Templates. This array is NOT modified.
    R: int
        Temporal half-width unit (see step 3)
    neighbors: numpy.ndarray (n_channels, n_channels)
        Neighbors matrix
    geom: numpy.ndarray
        Channel coordinates, passed to ``order_channels_by_distance``

    Returns
    -------
    numpy.ndarray (n_templates, 6 * R + 1, n_neigh)
        Cropped templates. n_neigh is the largest neighborhood size; channels
        beyond a template's own neighborhood are zero.
    """
    def _shift(x, shift):
        # Move x by -shift samples along axis 0 and zero-fill the samples
        # exposed at the edge (instead of leaving stale, duplicated values).
        out = np.zeros_like(x)
        n_samples = x.shape[0]
        if shift > 0:
            out[:n_samples - shift] = x[shift:]
        elif shift < 0:
            out[-shift:] = x[:n_samples + shift]
        else:
            out[:] = x
        return out

    K, n_times, _ = templatesBig.shape

    # main channel and amplitude of every template
    mainC = np.argmax(np.amax(np.abs(templatesBig), axis=1), axis=1)
    amps = np.amax(np.abs(templatesBig), axis=(1, 2))

    # reference waveform: largest template on its main channel, unit norm
    K_big = np.argmax(amps)
    t_rec = templatesBig[K_big, :, mainC[K_big]]
    t_rec = t_rec / np.sqrt(np.sum(np.square(t_rec)))

    # write into new arrays so the caller's templates are left untouched
    aligned = np.empty_like(templatesBig)
    templates_mainc = np.zeros((K, n_times))
    for k in range(K):
        t1 = templatesBig[k, :, mainC[k]]
        t1 = t1 / np.sqrt(np.sum(np.square(t1)))
        shift = align_templates(t1, t_rec)
        templates_mainc[k] = _shift(t1, shift)
        aligned[k] = _shift(templatesBig[k], shift)

    # determine temporal center of templates and crop around it
    R2 = int(R / 2)
    center = np.argmax(
        np.convolve(np.sum(np.square(templates_mainc), 0),
                    np.ones(2 * R2 + 1), 'valid')) + R2
    # NOTE(review): if center < 3 * R the start index is negative and the
    # slice will not contain 6 * R + 1 samples; confirm inputs are long enough
    aligned = aligned[:, (center - 3 * R):(center + 3 * R + 1)]

    # spatially crop: keep the main channel's neighbors ordered by distance
    nneigh = np.max(np.sum(neighbors, 0))
    cropped = np.zeros((K, aligned.shape[1], nneigh))
    for k in range(K):
        ch_idx = np.where(neighbors[mainC[k]])[0]
        ch_idx, _ = order_channels_by_distance(mainC[k], ch_idx, geom)
        cropped[k, :, :ch_idx.shape[0]] = aligned[k][:, ch_idx]

    return cropped
def covariance(recordings, temporal_size, neigbor_steps):
    """Compute noise spatial and temporal covariance

    The channel with the most neighbors (within ``neigbor_steps`` steps) is
    selected, together with its neighbors ordered by distance. That subset
    is optionally filtered, then standardized, and passed to
    ``util.covariance``.

    Parameters
    ----------
    recordings: numpy.ndarray (n_observations, n_channels)
        Multi-channel recordings
    temporal_size:
        Passed unchanged to ``util.covariance``
    neigbor_steps: int
        Number of steps on the multi-channel geometry within which two
        channels are considered neighbors

    Returns
    -------
    Whatever ``util.covariance`` returns (spatial and temporal covariance)
    """
    config = read_config()

    # neighbor flags for channels up to `neigbor_steps` steps apart
    step_neighbors = n_steps_neigh_channels(config.neighChannels,
                                            neigbor_steps)

    # TODO: why are we selecting this one?
    # reference = first channel with the largest neighborhood (column sums)
    reference = np.argmax(np.sum(step_neighbors, 0))

    # the reference channel's neighborhood (the reference itself included),
    # sorted by distance from it
    neighborhood = np.where(step_neighbors[reference])[0]
    ordered, _ = order_channels_by_distance(reference, neighborhood,
                                            config.geom)

    subset = recordings[:, ordered]

    if config.preprocess.filter == 1:
        subset = butterworth(subset,
                             config.filter.low_pass_freq,
                             config.filter.high_factor,
                             config.filter.order,
                             config.recordings.sampling_rate)

    scale = standarize.sd(subset, config.recordings.sampling_rate)
    subset = standarize.standarize(subset, scale)

    return util.covariance(subset, temporal_size, neigbor_steps,
                           config.spikeSize)
def training_data(CONFIG, templates_uncropped, min_amp, max_amp,
                  n_isolated_spikes,
                  path_to_standarized, noise_ratio=10,
                  collision_ratio=1, misalign_ratio=1, misalign_ratio2=1,
                  multi_channel=True, return_metadata=False):
    """Makes training sets for detector and triage

    Parameters
    ----------
    CONFIG: yaml file
        Configuration file
    templates_uncropped: numpy.ndarray
        3D array of templates, one template per entry along the first axis.
        Before use, each one is spatially cropped to the neighbors of its
        main channel (``TemplatesProcessor.crop_spatially``).
    min_amp: float
        Minimum value allowed for the maximum absolute amplitude of the
        isolated spike on its main channel
    max_amp: float
        Maximum value allowed for the maximum absolute amplitude of the
        isolated spike on its main channel
    n_isolated_spikes: int
        Number of isolated spikes to generate. This is different from the
        total number of x_detect
    path_to_standarized: str
        Folder storing the standarized data (if not exist, run preprocess to
        automatically generate)
    noise_ratio: int
        Ratio of number of noise to isolated spikes. For example, if
        n_isolated_spike=1000, noise_ratio=5, then n_noise=5000
    collision_ratio: int
        Ratio of number of collisions to isolated spikes.
    misalign_ratio: int
        Ratio of number of spatially and temporally misaligned spikes to
        isolated spikes
    misalign_ratio2: int
        Ratio of number of only-spatially misaligned spikes to isolated spikes
    multi_channel: bool
        If True, generate training data for multi-channel neural
        network. Otherwise generate single-channel data
    return_metadata: bool
        Passed unchanged to ``util.make_collided``

    Returns
    -------
    x_detect: numpy.ndarray
        [number of detection training data, temporal length, number of
        channels] Training data for the detect net (the channel axis is
        dropped when multi_channel is False).
    y_detect: numpy.ndarray
        [number of detection training data] Label for x_detect

    x_triage: numpy.ndarray
        [number of triage training data, temporal length, number of channels]
        Training data for the triage net (the channel axis is dropped when
        multi_channel is False).
    y_triage: numpy.ndarray
        [number of triage training data] Label for x_triage

    x_ae: None
        Always None: the autoencoder was replaced by PCA and its training
        data is no longer generated. Kept for backward compatibility.
    y_ae: None
        Always None (see x_ae)

    Notes
    -----
    * Detection training data
        * Multi channel
            * Positive examples: Clean spikes + noise, Collided spikes + noise
            * Negative examples: Temporally and spatially misaligned
              spikes + noise, Noise
        * Single channel: same, but without the collided spikes

    * Triage training data
        * Multi channel
            * Positive examples: Clean spikes + noise
            * Negative examples: Collided spikes + noise
    """
    # FIXME: should we add collided spikes with the first spike non-centered
    # to the detection training set?

    logger = logging.getLogger(__name__)

    # STEP1: Load recordings data, and select one channel at random (with the
    # right number of neighbors, then swap the channels so the first one
    # corresponds to the selected channel, then the nearest neighbor, then the
    # second nearest and so on... this is only used for estimating noise
    # structure

    # ##### FIXME: this needs to be removed, the user should already
    # pass data with the desired channels
    rec = RecordingsReader(path_to_standarized, loader='array')
    channel_n_neighbors = np.sum(CONFIG.neigh_channels, 0)
    max_neighbors = np.max(channel_n_neighbors)
    channels_with_max_neighbors = np.where(channel_n_neighbors ==
                                           max_neighbors)[0]
    logger.debug('The following channels have %i neighbors: %s',
                 max_neighbors, channels_with_max_neighbors)

    # reference channel: channel with max number of neighbors
    channel_selected = np.random.choice(channels_with_max_neighbors)

    logger.debug('Selected channel %i', channel_selected)

    # neighbors for the reference channel
    channel_neighbors = np.where(CONFIG.neigh_channels[channel_selected])[0]

    # ordered neighbors for reference channel
    channel_idx, _ = order_channels_by_distance(channel_selected,
                                                channel_neighbors,
                                                CONFIG.geom)
    # read the selected channels
    rec = rec[:, channel_idx]
    # ##### FIXME:end of section to be removed

    # STEP 2: load templates

    processor = TemplatesProcessor(templates_uncropped)

    # swap channels, first channel is main channel, then nearest neighbor
    # and so on, only keep neigh_channels
    templates = (processor.crop_spatially(CONFIG.neigh_channels, CONFIG.geom)
                 .values)

    # TODO: remove, this data can be obtained from other variables
    K, _, _ = templates_uncropped.shape

    # make training data set
    R = CONFIG.spike_size

    logger.debug('Output will be of size %s', 2 * R + 1)

    # make clean augmented spikes
    nk = int(np.ceil(n_isolated_spikes/K))
    max_shift = 2*R

    # make spikes from templates
    x_templates = util.make_from_templates(templates, min_amp, max_amp, nk)

    # make collided spikes - max shift is set to R since 2 * R + 1 will be
    # the final dimension for the spikes. one of the spikes is kept with the
    # main channel, the other one is shifted and channels are changed
    x_collision = util.make_collided(x_templates, collision_ratio,
                                     multi_channel, max_shift=R,
                                     min_shift=5,
                                     return_metadata=return_metadata)

    # make misaligned spikes
    x_temporally_misaligned = util.make_temporally_misaligned(
        x_templates, misalign_ratio,
        multi_channel=multi_channel,
        max_shift=max_shift)

    # now spatially misalign those
    x_misaligned = util.make_spatially_misaligned(x_temporally_misaligned,
                                                  n_per_spike=misalign_ratio2)

    # determine noise covariance structure
    spatial_SIG, temporal_SIG = noise_cov(rec,
                                          temporal_size=templates.shape[1],
                                          window_size=templates.shape[1],
                                          sample_size=1000,
                                          threshold=3.0)

    # make noise
    n_noise = int(x_templates.shape[0] * noise_ratio)
    noise = util.make_noise(n_noise, spatial_SIG, temporal_SIG)

    # make labels
    y_clean_1 = np.ones((x_templates.shape[0]))
    y_collision_1 = np.ones((x_collision.shape[0]))

    y_misaligned_0 = np.zeros((x_misaligned.shape[0]))
    y_noise_0 = np.zeros((noise.shape[0]))
    y_collision_0 = np.zeros((x_collision.shape[0]))

    # final outputs keep the 2 * R + 1 samples around the waveform midpoint
    mid_point = int((x_templates.shape[1]-1)/2)
    MID_POINT_IDX = slice(mid_point - R, mid_point + R + 1)

    # TODO: replace _make_noisy for new function
    x_templates_noisy = util._make_noisy(x_templates, noise)
    x_collision_noisy = util._make_noisy(x_collision, noise)
    x_misaligned_noisy = util._make_noisy(x_misaligned, noise)

    #############
    # Detection #
    #############

    if multi_channel:
        x = yarr.concatenate((x_templates_noisy, x_collision_noisy,
                              x_misaligned_noisy, noise))
        x_detect = x[:, MID_POINT_IDX, :]

        y_detect = np.concatenate((y_clean_1, y_collision_1,
                                   y_misaligned_0, y_noise_0))
    else:
        x = yarr.concatenate((x_templates_noisy, x_misaligned_noisy,
                              noise))
        x_detect = x[:, MID_POINT_IDX, 0]

        y_detect = yarr.concatenate((y_clean_1, y_misaligned_0, y_noise_0))

    ##########
    # Triage #
    ##########

    # same examples in both modes; single-channel keeps only channel 0
    x = yarr.concatenate((x_templates_noisy, x_collision_noisy))
    if multi_channel:
        x_triage = x[:, MID_POINT_IDX, :]
    else:
        x_triage = x[:, MID_POINT_IDX, 0]
    y_triage = yarr.concatenate((y_clean_1, y_collision_0))

    ###############
    # Autoencoder #
    ###############

    # The autoencoder was replaced by PCA; its training data is no longer
    # generated (see version control history for the previous code)
    x_ae = None
    y_ae = None

    return x_detect, y_detect, x_triage, y_triage, x_ae, y_ae
def nn_detection(recordings, neighbors, geom, temporal_features,
                 temporal_window, th_detect, th_triage, detector_filename,
                 autoencoder_filename, triage_filename):
    """Detect spikes using a neural network

    Parameters
    ----------
    recordings: numpy.ndarray (n_observations, n_channels)
        Neural recordings

    neighbors: numpy.ndarray (n_channels, n_channels)
        Channels neighbors matrix

    geom: numpy.ndarray (n_channels, 2)
        Cartesian coordinates for the channels

    temporal_features: int
        Number of features per channel in the detector's score train (last
        dimension of the score placeholder)

    temporal_window: int
        Passed to ``NeuralNetDetector.get_spikes`` and to
        ``remove_duplicate_spikes_by_energy`` (presumably the window, in
        samples, used for local-maximum search and duplicate removal;
        TODO confirm)

    th_detect: float
        Detection threshold, passed to ``NeuralNetDetector.get_spikes``
        [improve this explanation]

    th_triage: float
        Triage threshold: a detected spike is "clear" when its triage
        probability is strictly greater than this value, otherwise it is
        a collision

    detector_filename: str
        Path to neural network detector

    autoencoder_filename: str
        Path to neural network autoencoder

    triage_filename: str
        Path to triage neural network

    Returns
    -------
    score: numpy.ndarray (n_spikes, n_features, n_channels)
        3D array with the scores for the clear spikes; the first dimension
        is the number of spikes, the second the number of features and the
        third the number of channels

    spike_index_clear: numpy.ndarray (n_clear_spikes, 2)
        2D array with indexes for clear spikes; the first column contains
        the spike location in the recording and the second the main channel
        (channel whose amplitude is maximum)

    spike_index_collision: numpy.ndarray (n_collided_spikes, 2)
        2D array with indexes for collided spikes; the first column contains
        the spike location in the recording and the second the main channel
        (channel whose amplitude is maximum)

    Raises
    ------
    ValueError
        If ``neighbors`` is not square or its size does not match the number
        of channels in ``recordings``
    """
    # validate inputs before doing the expensive network construction
    T, C = recordings.shape
    a, b = neighbors.shape

    if a != b:
        raise ValueError('neighbors is not a square matrix, verify')

    if a != C:
        raise ValueError(
            'Number of channels in recording are {} but the '
            'neighbors matrix has {} elements, they must match'.format(C, a))

    nnd = NeuralNetDetector(detector_filename, autoencoder_filename)
    nnt = NeuralNetTriage(triage_filename)

    # neighboring channel info: row c lists c's neighbors ordered by
    # distance; unused slots hold C, which is not a valid channel index
    nneigh = np.max(np.sum(neighbors, 0))
    c_idx = np.ones((C, nneigh), 'int32') * C
    for c in range(C):
        ch_idx, _ = order_channels_by_distance(c,
                                               np.where(neighbors[c])[0],
                                               geom)
        c_idx[c, :ch_idx.shape[0]] = ch_idx

    # input
    x_tf = tf.placeholder("float", [T, C])

    # detect spike index
    local_max_idx_tf = nnd.get_spikes(x_tf, T, nneigh, c_idx,
                                      temporal_window, th_detect)

    # get score train
    score_train_tf = nnd.get_score_train(x_tf)

    # get energy for detected index
    energy_tf = tf.reduce_sum(tf.square(score_train_tf), axis=2)
    energy_val_tf = tf.gather_nd(energy_tf, local_max_idx_tf)

    # get triage probability
    triage_prob_tf = nnt.triage_prob(x_tf, T, nneigh, c_idx)

    # gather all results above
    result = (local_max_idx_tf, score_train_tf, energy_val_tf,
              triage_prob_tf)

    # remove duplicates
    energy_train_tf = tf.placeholder("float", [T, C])
    spike_index_tf = remove_duplicate_spikes_by_energy(energy_train_tf,
                                                       T, c_idx,
                                                       temporal_window)

    # get score
    score_train_placeholder = tf.placeholder("float",
                                             [T, C, temporal_features])
    spike_index_clear_tf = tf.placeholder("int64", [None, 2])
    score_tf = get_score(score_train_placeholder,
                         spike_index_clear_tf, T,
                         temporal_features, c_idx)

    ###############################
    # get values of above tensors #
    ###############################

    with tf.Session() as sess:

        nnd.saver.restore(sess, nnd.path_to_detector_model)
        nnd.saver_ae.restore(sess, nnd.path_to_ae_model)
        nnt.saver.restore(sess, nnt.path_to_triage_model)

        local_max_idx, score_train, energy_val, triage_prob = sess.run(
            result, feed_dict={x_tf: recordings})

        # dense (T, C) energy map, non-zero only at detected local maxima
        energy_train = np.zeros((T, C))
        energy_train[local_max_idx[:, 0], local_max_idx[:, 1]] = energy_val
        spike_index = sess.run(spike_index_tf,
                               feed_dict={energy_train_tf: energy_train})

        idx_clean = (triage_prob[spike_index[:, 0], spike_index[:, 1]]
                     > th_triage)

        spike_index_clear = spike_index[idx_clean]
        spike_index_collision = spike_index[~idx_clean]

        score = sess.run(score_tf, feed_dict={
            score_train_placeholder: score_train,
            spike_index_clear_tf: spike_index_clear})

    return score, spike_index_clear, spike_index_collision
def crop_and_align_templates(big_templates, R, neighbors, geom,
                             crop_spatially=True):
    """Crop (spatially) and align (temporally) templates

    Processing steps:

    1. Shift each template by the value ``align_templates`` returns, using
       the largest-amplitude template's normalized waveform on its main
       channel as the reference. Samples exposed by the shift are set to
       zero.
    2. Crop every template to the samples ``center - 3 * R`` through
       ``center + 3 * R``. The center is the middle of the
       ``2 * int(R / 2) + 1``-sample window with the highest summed energy
       of the aligned main-channel waveforms.
    3. If ``crop_spatially`` is True, keep only the neighbors of each
       template's main channel, ordered by distance.

    Parameters
    ----------
    big_templates: numpy.ndarray (n_templates, waveform_length, n_channels)
        Templates. They are copied, so the input is not modified.
    R: int
        Temporal half-width unit (see step 2)
    neighbors: numpy.ndarray (n_channels, n_channels)
        Neighbors matrix (only used when crop_spatially is True)
    geom: numpy.ndarray
        Channel coordinates, passed to ``order_channels_by_distance``
    crop_spatially: bool, optional
        Whether to perform step 3

    Returns
    -------
    numpy.ndarray
        (n_templates, 6 * R + 1, n_channels) if crop_spatially is False,
        otherwise (n_templates, 6 * R + 1, n_neigh), where n_neigh is the
        largest neighborhood size and unused channels are zero.
    """
    def _shift(x, shift):
        # Move x by -shift samples along axis 0 and zero-fill the samples
        # exposed at the edge (instead of leaving stale, duplicated values).
        out = np.zeros_like(x)
        n_samples = x.shape[0]
        if shift > 0:
            out[:n_samples - shift] = x[shift:]
        elif shift < 0:
            out[-shift:] = x[:n_samples + shift]
        else:
            out[:] = x
        return out

    logger = logging.getLogger(__name__)

    logger.debug('crop and align input shape %s', big_templates.shape)

    # copy templates to avoid modifying the original ones
    big_templates = np.copy(big_templates)
    n_templates, n_times, _ = big_templates.shape

    # main channels and amplitudes for each template
    main_ch = main_channels(big_templates)
    amps = amplitudes(big_templates)

    # reference: largest template on its main channel, unit-normalized
    K_big = np.argmax(amps)
    templates_mainc = np.zeros((n_templates, n_times))
    t_rec = big_templates[K_big, :, main_ch[K_big]]
    t_rec = t_rec / np.sqrt(np.sum(np.square(t_rec)))

    for k in range(n_templates):
        t1 = big_templates[k, :, main_ch[k]]
        t1 = t1 / np.sqrt(np.sum(np.square(t1)))
        shift = align_templates(t1, t_rec)

        logger.debug('Template %i will be shifted by %i', k, shift)

        templates_mainc[k] = _shift(t1, shift)
        big_templates[k] = _shift(big_templates[k], shift)

    # determine temporal center of templates and crop around it
    R2 = int(R / 2)
    center = np.argmax(
        np.convolve(np.sum(np.square(templates_mainc), 0),
                    np.ones(2 * R2 + 1), 'valid')) + R2

    # crop templates: keep samples center - 3*R .. center + 3*R (6*R + 1)
    # NOTE(review): if center < 3 * R the start index is negative and the
    # slice will not contain 6 * R + 1 samples; confirm inputs are long enough
    logger.debug('6*R+1 %s', 6 * R + 1)
    big_templates = big_templates[:, (center - 3 * R):(center + 3 * R + 1)]

    if not crop_spatially:
        return big_templates

    # spatially crop (only keep neighbors)
    n_neigh_to_keep = np.max(np.sum(neighbors, 0))
    small = np.zeros((n_templates, big_templates.shape[1], n_neigh_to_keep))

    for k in range(n_templates):
        # get neighbors for the main channel in the kth template
        ch_idx = np.where(neighbors[main_ch[k]])[0]

        # order channels
        ch_idx, _ = order_channels_by_distance(main_ch[k], ch_idx, geom)

        # new kth template is the old kth template by keeping only
        # ordered neighboring channels
        small[k, :, :ch_idx.shape[0]] = big_templates[k][:, ch_idx]

    return small
def noise_cov(path_to_data, dtype, n_channels, data_order,
              neighbors, geom, temporal_size):
    """Estimate spatial and temporal noise covariance square roots

    The recordings are read and restricted to the channel with the most
    neighbors plus its neighbors, ordered by distance. Every sample within
    ``R = int((temporal_size - 1) / 2)`` samples of a value above 3 is
    blanked, and each channel is scaled by the std of its non-blanked
    samples. Two quantities are then estimated:

    * the spatial covariance, from the non-blanked samples;
    * the temporal covariance, from 1000 randomly drawn windows of
      ``temporal_size`` samples of the spatially whitened recording. Each
      window contains no blanked sample.

    Parameters
    ----------
    path_to_data: str
        Path to recordings data

    dtype: str
        dtype for recordings

    n_channels: int
        Number of channels in the recordings

    data_order: str
        Recordings order, one of ('channels', 'samples'). In a dataset with k
        observations per channel and j channels: 'channels' means first k
        contiguous observations come from channel 0, then channel 1, and so
        on. 'sample' means first j contiguous data are the first observations
        from all channels, then the second observations from all channels and
        so on

    neighbors: numpy.ndarray
        Neighbors matrix

    geom: numpy.ndarray
        Cartesian coordinates for the channels

    temporal_size: int
        Waveform size (in samples)

    Returns
    -------
    spatial_SIG: numpy.ndarray (n_neigh, n_neigh)
        Symmetric square root of the spatial noise covariance
    temporal_SIG: numpy.ndarray (temporal_size, temporal_size)
        Symmetric square root of the temporal noise covariance

    Raises
    ------
    ValueError
        If no window of ``temporal_size`` samples free of blanked samples
        exists (the sampling loop would otherwise never terminate)
    """
    # reference channel: first channel with the largest neighborhood
    c_ref = np.argmax(np.sum(neighbors, 0))
    ch_idx = np.where(neighbors[c_ref])[0]
    ch_idx, _ = order_channels_by_distance(c_ref, ch_idx, geom)

    rec = RecordingsReader(path_to_data, dtype=dtype, n_channels=n_channels,
                           data_order=data_order, loader='array')
    rec = np.asarray(rec[:, ch_idx])
    # NaN is used below as a "blanked" marker, which integer arrays cannot
    # store; promote non-float recordings
    if not np.issubdtype(rec.dtype, np.floating):
        rec = rec.astype(np.float64)

    T, C = rec.shape
    # 1 where a sample is usable noise, 0 where it was blanked
    idxNoise = np.zeros((T, C))

    R = int((temporal_size - 1) / 2)
    for c in range(C):
        # blank +/- R samples around every value above 3
        above = np.where(rec[:, c] > 3)[0]
        for j in range(-R, R + 1):
            blank_idx = above + j
            blank_idx = blank_idx[np.logical_and(blank_idx >= 0,
                                                 blank_idx < T)]
            rec[blank_idx, c] = np.nan

        is_noise = ~np.isnan(rec[:, c])
        rec[:, c] = rec[:, c] / np.nanstd(rec[:, c])
        rec[~is_noise, c] = 0
        idxNoise[is_noise, c] = 1

    # covariance of every channel pair over samples where both are noise
    spatial_cov = np.divide(np.matmul(rec.T, rec),
                            np.matmul(idxNoise.T, idxNoise))

    # spatial_cov is symmetric: eigh gives real eigenvalues and orthonormal
    # eigenvectors, which the v diag(.) v^T reconstructions rely on
    w, v = np.linalg.eigh(spatial_cov)
    spatial_SIG = np.matmul(np.matmul(v, np.diag(np.sqrt(w))), v.T)

    spatial_whitener = np.matmul(np.matmul(v, np.diag(1 / np.sqrt(w))), v.T)
    rec = np.matmul(rec, spatial_whitener)

    # make sure at least one fully-noise window exists for the rejection
    # sampling below; start positions drawn are in [0, T - temporal_size)
    n_starts = T - temporal_size
    if n_starts > 0:
        n_blanked = np.concatenate(
            (np.zeros((1, C)), np.cumsum(idxNoise == 0, axis=0)), axis=0)
        blanked_per_window = (n_blanked[temporal_size:temporal_size + n_starts]
                              - n_blanked[:n_starts])
        if not np.any(blanked_per_window == 0):
            raise ValueError('No noise-only window of {} samples found, '
                             'cannot estimate the temporal covariance'
                             .format(temporal_size))

    n_windows = 1000
    noise_wf = np.zeros((n_windows, temporal_size))
    count = 0
    while count < n_windows:
        tt = np.random.randint(T - temporal_size)
        cc = np.random.randint(C)
        window = rec[tt:(tt + temporal_size), cc]
        window_noise = idxNoise[tt:(tt + temporal_size), cc]
        if np.sum(window_noise == 0) == 0:
            noise_wf[count] = window
            count += 1

    w, v = np.linalg.eigh(np.cov(noise_wf.T))
    temporal_SIG = np.matmul(np.matmul(v, np.diag(np.sqrt(w))), v.T)

    return spatial_SIG, temporal_SIG
def run(score, spike_index_clear, spike_index_collision,
        output_directory='tmp/', recordings_filename='standarized.bin'):
    """Process spikes: triage, cluster and compute templates

    Parameters
    ----------
    score: numpy.ndarray (n_spikes, n_features, n_channels)
        3D array with the scores for the clear spikes; the first dimension
        is the number of spikes, the second the number of features and the
        third the number of channels

    spike_index_clear: numpy.ndarray (n_clear_spikes, 2)
        2D array with indexes for clear spikes; the first column contains
        the spike location in the recording and the second the main channel
        (channel whose amplitude is maximum)

    spike_index_collision: numpy.ndarray (n_collided_spikes, 2)
        2D array with indexes for collided spikes; the first column contains
        the spike location in the recording and the second the main channel
        (channel whose amplitude is maximum). Spikes removed by triage are
        appended to it.

    output_directory: str, optional
        Output directory (relative to CONFIG.data.root_folder) used to load
        the recordings to generate templates, defaults to tmp/

    recordings_filename: str, optional
        Recordings filename (relative to CONFIG.data.root_folder/
        output_directory) used to generate the templates, defaults to
        standarized.bin

    Returns
    -------
    spike_train_clear: numpy.ndarray (n_clear_spikes, 2)
        A 2D array for clear spikes whose first column indicates the spike
        time and the second column the neuron id determined by the
        clustering algorithm

    templates: numpy.ndarray (n_channels, waveform_size, n_templates)
        A 3D array with the templates (as returned by ``gam_templates``;
        NOTE(review): axis order not verified here)

    spike_index_collision: numpy.ndarray (n_collided_spikes, 2)
        A 2D array for collided spikes whose first column indicates the
        spike time and the second column the main channel, including the
        spikes removed by triage

    Examples
    --------

    .. literalinclude:: ../examples/process.py
    """
    CONFIG = read_config()

    # column of spike_index_clear holding the main channel
    MAIN_CHANNEL = 1

    startTime = datetime.datetime.now()

    # accumulated seconds per stage: t=triage, c=coreset, m=masking,
    # s=clustering, e=templates (see the log messages at the end)
    Time = {'t': 0, 'c': 0, 'm': 0, 's': 0, 'e': 0}

    logger = logging.getLogger(__name__)

    nG = len(CONFIG.channelGroups)
    nneigh = np.max(np.sum(CONFIG.neighChannels, 0))
    n_coreset = 0
    # running offset so cluster ids are unique across channel groups
    K = 0

    # first column: spike_time
    # second column: cluster id
    spike_train_clear = np.zeros((0, 2), 'int32')

    if CONFIG.clustering.clustering_method == 'location':
        spike_index_clear_proc = np.zeros((0, 2), 'int32')
        main_channel_index = spike_index_clear[:, MAIN_CHANNEL]
        # NOTE(review): global_vbParam / score_proc are only bound inside this
        # loop; if spike_index_clear is empty, the merge step below raises
        # NameError
        for i, c in enumerate(np.unique(main_channel_index)):
            logger.info('Processing channel {}'.format(i))
            idx = main_channel_index == c
            score_c = score[idx]
            spike_index_clear_c = spike_index_clear[idx]

            ##########
            # Triage #
            ##########

            # TODO: refactor this as CONFIG.doTriage was removed
            doTriage = True

            _b = datetime.datetime.now()
            logger.info('Triaging events with main channel {}'.format(c))
            index_keep = triage(score_c, 0, CONFIG.triage.nearest_neighbors,
                                CONFIG.triage.percent, doTriage)
            Time['t'] += (datetime.datetime.now() - _b).total_seconds()

            # add untriaged spike index to spike_index_clear_proc
            # and triaged spike index to spike_index_collision
            spike_index_clear_proc = np.concatenate(
                (spike_index_clear_proc, spike_index_clear_c[index_keep]),
                axis=0)
            spike_index_collision = np.concatenate(
                (spike_index_collision, spike_index_clear_c[~index_keep]),
                axis=0)

            # TODO: add documentation for all of this part, until the
            # "cleaning" comment

            # keep untriaged score only
            score_c = score_c[index_keep]
            group = np.arange(score_c.shape[0])
            mask = np.ones([score_c.shape[0], 1])
            _b = datetime.datetime.now()
            logger.info('Clustering events with main channel {}'.format(c))
            if i == 0:
                # first channel initializes the global model state
                global_vbParam, global_maskedData = spikesort(
                    score_c, mask, group, CONFIG)
                score_proc = score_c
            else:
                # later channels: cluster locally, then append every
                # parameter/statistic to the global state
                local_vbParam, local_maskedData = spikesort(
                    score_c, mask, group, CONFIG)
                global_vbParam.muhat = np.concatenate(
                    [global_vbParam.muhat, local_vbParam.muhat], axis=1)
                global_vbParam.Vhat = np.concatenate(
                    [global_vbParam.Vhat, local_vbParam.Vhat], axis=2)
                global_vbParam.invVhat = np.concatenate(
                    [global_vbParam.invVhat, local_vbParam.invVhat],
                    axis=2)
                global_vbParam.lambdahat = np.concatenate(
                    [global_vbParam.lambdahat, local_vbParam.lambdahat],
                    axis=0)
                global_vbParam.nuhat = np.concatenate(
                    [global_vbParam.nuhat, local_vbParam.nuhat],
                    axis=0)
                global_vbParam.ahat = np.concatenate(
                    [global_vbParam.ahat, local_vbParam.ahat],
                    axis=0)
                global_maskedData.sumY = np.concatenate(
                    [global_maskedData.sumY, local_maskedData.sumY],
                    axis=0)
                global_maskedData.sumYSq = np.concatenate(
                    [global_maskedData.sumYSq, local_maskedData.sumYSq],
                    axis=0)
                global_maskedData.sumEta = np.concatenate(
                    [global_maskedData.sumEta, local_maskedData.sumEta],
                    axis=0)
                global_maskedData.weight = np.concatenate(
                    [global_maskedData.weight, local_maskedData.weight],
                    axis=0)
                global_maskedData.groupMask = np.concatenate(
                    [global_maskedData.groupMask, local_maskedData.groupMask],
                    axis=0)
                global_maskedData.meanY = np.concatenate(
                    [global_maskedData.meanY, local_maskedData.meanY],
                    axis=0)
                global_maskedData.meanYSq = np.concatenate(
                    [global_maskedData.meanYSq, local_maskedData.meanYSq],
                    axis=0)
                global_maskedData.meanEta = np.concatenate(
                    [global_maskedData.meanEta, local_maskedData.meanEta],
                    axis=0)
                score_proc = np.concatenate([score_proc, score_c], axis=0)

        logger.info('merging all channels')
        L = np.ones(global_vbParam.muhat.shape[1])
        global_vbParam.update_local(global_maskedData)
        suffStat = suffStatistics(global_maskedData, global_vbParam)
        global_vbParam, suffStat, L = merge_move(global_maskedData,
                                                 global_vbParam, suffStat,
                                                 CONFIG, L, 0)
        # hard assignment: most responsible cluster for each spike
        assignmentTemp = np.argmax(global_vbParam.rhat, axis=1)
        assignment = np.zeros(score_proc.shape[0], 'int16')
        for j in range(score_proc.shape[0]):
            assignment[j] = assignmentTemp[j]

        idx_triage = cluster_triage(global_vbParam, score_proc, 3)
        assignment[idx_triage] = -1
        # NOTE: _b was set before clustering the last channel, so this covers
        # that clustering plus the merge step
        Time['s'] += (datetime.datetime.now() - _b).total_seconds()

        ############
        # Cleaning #
        ############

        # TODO: describe this step
        # NOTE(review): `~idx_triage` below assumes cluster_triage returns a
        # boolean mask (on an integer index array `~` is bitwise NOT) --
        # confirm

        spike_train_clear = np.concatenate([
            spike_index_clear_proc[~idx_triage, 0:1:],
            assignment[~idx_triage, np.newaxis]
        ], axis=1)
        spike_index_collision = np.concatenate(
            [spike_index_collision, spike_index_clear_proc[idx_triage]])

    else:
        # according to the docs if clustering method is not 2+3, you can set
        # 3 x neighboring_channels, but I do not see where the
        # neighboring_channels is being parsed on this else statement

        # row c: channel c's neighbors ordered by distance; unused slots hold
        # n_channels (not a valid channel index)
        c_idx = np.ones((CONFIG.recordings.n_channels, nneigh),
                        'int32') * CONFIG.recordings.n_channels

        for c in range(CONFIG.recordings.n_channels):
            ch_idx, _ = order_channels_by_distance(
                c, np.where(CONFIG.neighChannels[c])[0], CONFIG.geom)
            c_idx[c, :ch_idx.shape[0]] = ch_idx

        # iterate over every channel group [missing documentation for this
        # function]. why is this order needed?
        for g in range(nG):
            logger.info("Processing group {} in {} groups.".format(g + 1, nG))
            logger.info("Processiing data (triage, coreset, masking) ...")
            channels = CONFIG.channelGroups[g]
            # union of the neighborhoods of all channels in the group
            neigh_chans = np.where(
                np.sum(CONFIG.neighChannels[channels], axis=0) > 0)[0]

            score_group = np.zeros(
                (0, CONFIG.spikes.temporal_features, neigh_chans.shape[0]))
            coreset_id_group = np.zeros((0), 'int32')
            mask_group = np.zeros((0, neigh_chans.shape[0]))
            spike_index_clear_group = np.zeros((0, 2), 'int32')

            # go through every channel in the group
            for c in channels:

                # index of data whose main channel is c
                idx = spike_index_clear[:, MAIN_CHANNEL] == c
                if np.sum(idx) > 0:

                    # score whose main channel is c
                    score_c = score[idx]
                    # spike_index_clear whose main channel is c
                    spike_index_clear_c = spike_index_clear[idx]

                    ##########
                    # Triage #
                    ##########

                    # TODO: refactor this as CONFIG.doTriage was removed
                    doTriage = True

                    _b = datetime.datetime.now()
                    index_keep = triage(score_c, 0,
                                        CONFIG.triage.nearest_neighbors,
                                        CONFIG.triage.percent, doTriage)
                    Time['t'] += (datetime.datetime.now() - _b).total_seconds()

                    # add untriaged spike index to spike_index_clear_group
                    # and triaged spike index to spike_index_collision
                    spike_index_clear_group = np.concatenate(
                        (spike_index_clear_group,
                         spike_index_clear_c[index_keep]), axis=0)
                    spike_index_collision = np.concatenate(
                        (spike_index_collision,
                         spike_index_clear_c[~index_keep]), axis=0)

                    # keep untriaged score only
                    score_c = score_c[index_keep]

                    ###########
                    # Coreset #
                    ###########

                    # TODO: refactor this as CONFIG.doCoreset was removed
                    doCoreset = True

                    _b = datetime.datetime.now()
                    coreset_id = coreset(score_c, CONFIG.coreset.clusters,
                                         CONFIG.coreset.threshold, doCoreset)
                    Time['c'] += (datetime.datetime.now() - _b).total_seconds()

                    ###########
                    # Masking #
                    ###########

                    _b = datetime.datetime.now()
                    mask = getmask(score_c, coreset_id,
                                   CONFIG.clustering.masking_threshold,
                                   CONFIG.spikes.temporal_features)
                    Time['m'] += (datetime.datetime.now() - _b).total_seconds()

                    ################
                    # collect data #
                    ################

                    # restructure score_c and mask to have same number of
                    # channels as score_group
                    score_temp = np.zeros(
                        (score_c.shape[0], CONFIG.spikes.temporal_features,
                         neigh_chans.shape[0]))
                    mask_temp = np.zeros((mask.shape[0], neigh_chans.shape[0]))
                    # number of real (non-padding) neighbors of channel c
                    nneigh_c = np.sum(c_idx[c] < CONFIG.recordings.n_channels)
                    for j in range(nneigh_c):
                        # column of neigh_chans holding c's j-th neighbor
                        c_interest = np.where(neigh_chans == c_idx[c, j])[0][0]
                        score_temp[:, :, c_interest] = score_c[:, :, j]
                        mask_temp[:, c_interest] = mask[:, j]

                    # add score, coreset_id, mask to the groups; coreset ids
                    # are offset by n_coreset + 1 to be unique in the group
                    score_group = np.concatenate(
                        (score_group, score_temp), axis=0)
                    mask_group = np.concatenate((mask_group, mask_temp), axis=0)
                    coreset_id_group = np.concatenate(
                        (coreset_id_group, coreset_id + n_coreset + 1), axis=0)
                    n_coreset += np.max(coreset_id) + 1

            if score_group.shape[0] > 0:
                ##############
                # Clustering #
                ##############

                _b = datetime.datetime.now()
                logger.info("Clustering...")
                coreset_id_group = coreset_id_group - 1
                n_coreset = 0
                cluster_id = spikesort(score_group, mask_group,
                                       coreset_id_group, CONFIG)
                Time['s'] += (datetime.datetime.now() - _b).total_seconds()

                ############
                # Cleaning #
                ############

                # model based triage
                idx_triage = (cluster_id == -1)

                # concatenate spike index with cluster id of untriaged ones
                # to create spike_train_clear
                si_clustered = spike_index_clear_group[~idx_triage]
                spt = si_clustered[:, [0]]
                cluster_id = cluster_id[~idx_triage][:, np.newaxis]

                spike_train_temp = np.concatenate((spt, cluster_id + K), axis=1)
                spike_train_clear = np.concatenate(
                    (spike_train_clear, spike_train_temp), axis=0)

                K += np.amax(cluster_id) + 1

                # concatenate triaged spike_index_clear_group
                # into spike_index_collision
                spike_index_collision = np.concatenate(
                    (spike_index_collision,
                     spike_index_clear_group[idx_triage]), axis=0)

    #################
    # Get templates #
    #################

    _b = datetime.datetime.now()
    logger.info("Getting Templates...")
    path_to_recordings = os.path.join(CONFIG.data.root_folder,
                                      output_directory,
                                      recordings_filename)

    merge_threshold = CONFIG.templates.merge_threshold

    spike_train_clear, templates = gam_templates(
        spike_train_clear, path_to_recordings, CONFIG.spikeSize,
        CONFIG.templatesMaxShift, merge_threshold, CONFIG.neighChannels)

    Time['e'] += (datetime.datetime.now() - _b).total_seconds()

    currentTime = datetime.datetime.now()

    if CONFIG.clustering.clustering_method == 'location':
        logger.info("Mainprocess done in {0} seconds.".format(
            (currentTime - startTime).seconds))
        logger.info("\ttriage:\t{0} seconds".format(Time['t']))
        logger.info("\tclustering:\t{0} seconds".format(Time['s']))
        logger.info("\ttemplates:\t{0} seconds".format(Time['e']))
    else:
        logger.info("\ttriage:\t{0} seconds".format(Time['t']))
        logger.info("\tcoreset:\t{0} seconds".format(Time['c']))
        logger.info("\tmasking:\t{0} seconds".format(Time['m']))
        logger.info("\tclustering:\t{0} seconds".format(Time['s']))
        logger.info("\ttemplates:\t{0} seconds".format(Time['e']))

    return spike_train_clear, templates, spike_index_collision
def crop_and_align_templates(fname_templates, save_dir, CONFIG):
    """Crop (spatially) and align (temporally) templates stored on disk

    Templates are aligned on their main (largest peak-to-peak) channel,
    cropped in time around the highest-energy window of the aligned
    main-channel waveforms, and cropped spatially to the neighbors of each
    template's main channel, ordered by distance. The result is saved as
    ``templates_cropped.npy`` inside ``save_dir``.

    Parameters
    ----------
    fname_templates: str
        Path to a ``.npy`` file holding templates of shape
        (n_units, n_times, n_channels)
    save_dir: str
        Output directory; created (including parents) if missing
    CONFIG
        Configuration object; ``spike_size_nn``, ``neigh_channels`` and
        ``geom`` are read from it

    Returns
    -------
    str
        Path of the saved array of shape (n_units, spike_size, n_neigh),
        where ``spike_size = 2 * (CONFIG.spike_size_nn - 1) + 1`` and
        ``n_neigh`` is the largest row sum of ``CONFIG.neigh_channels``.
        Units whose main channel has fewer neighbors are zero-padded.

    Raises
    ------
    ValueError
        If, after alignment and edge trimming, templates have fewer than
        ``spike_size`` samples
    """
    os.makedirs(save_dir, exist_ok=True)

    # load templates
    templates = np.load(fname_templates)
    n_units, n_times, _ = templates.shape

    # main channel = largest peak-to-peak amplitude; the function form of
    # ptp is used because ndarray.ptp was removed in NumPy 2.0
    mcs = np.ptp(templates, axis=1).argmax(1)
    spike_size = (CONFIG.spike_size_nn - 1)*2 + 1
    half_size = spike_size // 2

    ########## TEMPORALLY ALIGN TEMPLATES #################

    # template on max channel only
    templates_max_channel = np.zeros((n_units, n_times))
    for k in range(n_units):
        templates_max_channel[k] = templates[k, :, mcs[k]]

    # align them to the mean max-channel waveform
    ref = np.mean(templates_max_channel, axis=0)
    upsample_factor = 8
    nshifts = spike_size//2

    shifts = align_get_shifts_with_ref(
        templates_max_channel, ref, upsample_factor, nshifts)

    templates_aligned = shift_chans(templates, shifts)

    # crop out the edges since they have bad artifacts: nshifts//2 samples
    # at the start and ceil(nshifts/2) at the end (the amounts the
    # expression ``-nshifts//2`` == ``(-nshifts)//2`` yields). Explicit
    # bounds keep nshifts == 0 from producing an empty ``[:, 0:0]`` slice.
    trim_start = nshifts // 2
    trim_end = (nshifts + 1) // 2
    templates_aligned = templates_aligned[
        :, trim_start:templates_aligned.shape[1] - trim_end]

    ########## Find High Energy Center of Templates #################

    n_aligned = templates_aligned.shape[1]
    if n_aligned < spike_size:
        raise ValueError(
            "templates have {} samples after alignment, fewer than "
            "spike_size={}".format(n_aligned, spike_size))

    templates_max_channel_aligned = np.zeros((n_units, n_aligned))
    for k in range(n_units):
        templates_max_channel_aligned[k] = templates_aligned[k, :, mcs[k]]

    # determine temporal center of templates (peak of the smoothed summed
    # energy) and crop a spike_size window around it
    total_energy = np.sum(np.square(templates_max_channel_aligned), axis=0)
    center = np.argmax(
        np.convolve(total_energy, np.ones(max(half_size, 1)), 'same'))
    # keep the window inside the array so every unit gets exactly
    # spike_size samples (a negative start would wrap around)
    center = int(np.clip(center, half_size, n_aligned - half_size - 1))
    templates_aligned = templates_aligned[
        :, (center - half_size):(center + half_size + 1)]

    ########## spatially crop (only keep neighbors) #################

    neighbors = CONFIG.neigh_channels
    n_neigh = np.max(np.sum(neighbors, axis=1))
    templates_cropped = np.zeros((n_units, spike_size, n_neigh))

    for k in range(n_units):

        # get neighbors for the main channel in the kth template
        ch_idx = np.where(neighbors[mcs[k]])[0]

        # order channels
        ch_idx, _ = order_channels_by_distance(mcs[k], ch_idx, CONFIG.geom)

        # new kth template is the old kth template by keeping only
        # ordered neighboring channels
        templates_cropped[k, :, :ch_idx.shape[0]] = \
            templates_aligned[k][:, ch_idx]

    fname_templates_cropped = os.path.join(save_dir, 'templates_cropped.npy')
    np.save(fname_templates_cropped, templates_cropped)

    return fname_templates_cropped
def get_waveforms(recording, spike_index, proj, neighbors, geom, nnt, th):
    """Extract waveforms from detected spikes, triage them and project them

    Parameters
    ----------
    recording: matrix (observations, number of channels)
        Multi-channel recordings
    spike_index: matrix (number of spikes, 2)
        Spike index matrix, as returned from any of the detectors. Column 0
        is the spike time, column 1 the main channel.
    proj: matrix (waveform temporal length, number of features)
        Projection matrix that reduces the dimension of waveform
    neighbors: matrix (number of channels, number of channel)
        Neighbors matrix
    geom: matrix (number of channels, 2)
        Each row is the x,y coordinate of each channel
    nnt: class
        Class for Neural Network based triage
    th: int
        Threshold for Neural Network triage algorithm

    Returns
    -------
    spike_index_clear: matrix (number of clear spikes, 2)
        Rows of spike_index that the triage network marked as clear
    score: matrix (number of clear spikes, number of features,
                   number of neighbors)
        Projected waveforms of the clear spikes, in the same order as
        spike_index_clear. Channels are the main channel's neighbors
        ordered by distance; missing neighbors are zero-filled.
    spike_index_collision: matrix (number of other spikes, 2)
        Rows of spike_index that the triage network did not mark as clear

    Raises
    ------
    ValueError
        If a spike is too close to either end of the recording for a full
        waveform to be extracted

    Notes
    -----
    Let's consider a single channel recording V, where V is a vector of
    length 1 x T. When a spike is detected at time t, then (V_(t-R),
    V_(t-R+1), ..., V_t, V_(t+1),...V_(t+R)) is going to be a waveform
    (a small snippet from the recording around the spike time), with
    R = (waveform temporal length - 1) // 2.
    """
    # column ids for index matrix
    SPIKE_TIME, MAIN_CHANNEL = 0, 1

    n_times, n_channels = recording.shape
    n_spikes, _ = spike_index.shape
    window_size, n_features = proj.shape
    spike_size = (window_size - 1) // 2
    nneigh = np.max(np.sum(neighbors, 0))

    # fail with a clear message instead of a shape-mismatch error deep in
    # the extraction loop
    spike_times = spike_index[:, SPIKE_TIME]
    if (np.any(spike_times < spike_size)
            or np.any(spike_times + spike_size >= n_times)):
        raise ValueError(
            "spike times must satisfy {} <= t < {} to extract a full "
            "waveform of {} samples".format(
                spike_size, n_times - spike_size, window_size))

    # append an all-zero channel; its index (n_channels) pads the neighbor
    # list of channels that have fewer than nneigh neighbors
    recording = np.concatenate((recording, np.zeros((n_times, 1))), axis=1)
    c_idx = np.ones((n_channels, nneigh), 'int32') * n_channels
    for c in range(n_channels):
        ch_idx, _ = order_channels_by_distance(c, np.where(neighbors[c])[0],
                                               geom)
        c_idx[c, :ch_idx.shape[0]] = ch_idx

    spike_index_clear = np.zeros((0, 2), 'int32')
    spike_index_collision = np.zeros((0, 2), 'int32')
    score = np.zeros((0, n_features, nneigh), 'float32')

    # waveforms are triaged in batches of at most 500000; never allocate
    # more rows than there are spikes
    nbuff = min(500000, n_spikes)
    wf = np.zeros((nbuff, window_size, nneigh), 'float32')

    count = 0
    for j in range(n_spikes):
        t = spike_index[j, SPIKE_TIME]
        c = c_idx[spike_index[j, MAIN_CHANNEL]]
        wf[count] = recording[(t - spike_size):(t + spike_size + 1), c]
        count += 1

        # flush when the buffer is full or the last spike was added
        if count < nbuff and j != n_spikes - 1:
            continue
        batch = wf[:count]

        # going through triage NN. Its output (1/True for a clear spike)
        # is converted to bool: an integer 0/1 array would otherwise be
        # used as fancy indices, and ~ would be a bitwise NOT
        clear_spike = np.asarray(nnt.nn_triage(batch, th), dtype=bool)

        # collect clear and colliding spikes (the last `count` spikes)
        spike_index_buff = spike_index[(j - count + 1):(j + 1)]
        spike_index_clear = np.concatenate(
            (spike_index_clear, spike_index_buff[clear_spike]))
        spike_index_collision = np.concatenate(
            (spike_index_collision, spike_index_buff[~clear_spike]))

        # project every (spike, channel) waveform with proj, then reorder
        # to (spikes, features, channels)
        reshaped_wf = np.reshape(np.transpose(batch[clear_spike], (0, 2, 1)),
                                 (-1, window_size))
        score_temp = np.transpose(
            np.reshape(np.matmul(reshaped_wf, proj),
                       (-1, nneigh, n_features)), (0, 2, 1))
        score = np.concatenate((score, score_temp), axis=0)

        # set counter back to zero
        count = 0

    return spike_index_clear, score, spike_index_collision
def noise_cov(path_to_data, dtype, n_channels, data_format, neighbors,
              geom, temporal_size):
    """Compute square roots of the noise spatial and temporal covariances

    The reference channel is the one with the most neighbors; only it and
    its neighbors (ordered by distance) are used. On each channel, samples
    within R = (temporal_size - 1) // 2 of a value above 3 are masked. The
    remaining samples are scaled by their standard deviation, and the
    covariances are estimated from unmasked samples only.

    Parameters
    ----------
    path_to_data: str
        Path to recordings data
    dtype: str
        dtype for recordings
    n_channels: int
        Number of channels in the recordings
    data_format: str
        Recordings shape ('wide', 'long')
    neighbors: numpy.ndarray
        Neighbors matrix
    geom: numpy.ndarray
        Cartesian coordinates for the channels
    temporal_size: int
        Waveform size

    Returns
    -------
    spatial_SIG: numpy.ndarray (n_neigh, n_neigh)
        Symmetric square root of the spatial noise covariance
    temporal_SIG: numpy.ndarray (temporal_size, temporal_size)
        Symmetric square root of the covariance of 1000 randomly sampled,
        spatially whitened, fully unmasked noise snippets

    Raises
    ------
    ValueError
        If the recording is not longer than temporal_size, or if no window
        of temporal_size samples is entirely unmasked on any channel
    """
    c_ref = np.argmax(np.sum(neighbors, 0))
    ch_idx = np.where(neighbors[c_ref])[0]
    ch_idx, _ = order_channels_by_distance(c_ref, ch_idx, geom)

    rec = RecordingsReader(path_to_data, dtype=dtype, n_channels=n_channels,
                           data_format=data_format, mmap=False)
    rec = np.asarray(rec[:, ch_idx])
    # masked samples are set to NaN, which integer arrays cannot hold
    if not np.issubdtype(rec.dtype, np.floating):
        rec = rec.astype(np.float64)

    T, C = rec.shape
    if T <= temporal_size:
        raise ValueError("recording has {} samples; need more than "
                         "temporal_size={}".format(T, temporal_size))

    idxNoise = np.zeros((T, C))

    R = (temporal_size - 1) // 2
    for c in range(C):
        # NOTE(review): only excursions above +3 are masked; matrix_localized
        # masks ts < -th instead. Confirm the sign convention.
        idx_temp = np.where(rec[:, c] > 3)[0]
        for j in range(-R, R + 1):
            idx_temp2 = idx_temp + j
            idx_temp2 = idx_temp2[np.logical_and(idx_temp2 >= 0,
                                                 idx_temp2 < T)]
            rec[idx_temp2, c] = np.nan
        is_noise = ~np.isnan(rec[:, c])
        rec[:, c] = rec[:, c] / np.nanstd(rec[:, c])

        rec[~is_noise, c] = 0
        idxNoise[is_noise, c] = 1

    # pairwise covariance, normalized by the number of jointly unmasked
    # samples of each channel pair
    spatial_cov = np.divide(np.matmul(rec.T, rec),
                            np.matmul(idxNoise.T, idxNoise))

    # spatial_cov is symmetric: eigh returns real eigenvalues and
    # orthonormal eigenvectors, which V diag(f(w)) V^T relies on
    w, v = np.linalg.eigh(spatial_cov)
    spatial_SIG = np.matmul(np.matmul(v, np.diag(np.sqrt(w))), v.T)

    spatial_whitener = np.matmul(np.matmul(v, np.diag(1 / np.sqrt(w))), v.T)
    rec = np.matmul(rec, spatial_whitener)

    # a window starting at tt is accepted below only if it has no masked
    # sample; check up front that at least one exists, otherwise the
    # sampling loop would never terminate
    n_starts = T - temporal_size
    masked_cumsum = np.concatenate(
        (np.zeros((1, C)), np.cumsum(1 - idxNoise, axis=0)), axis=0)
    masked_per_window = (masked_cumsum[temporal_size:temporal_size + n_starts]
                         - masked_cumsum[:n_starts])
    if not np.any(masked_per_window == 0):
        raise ValueError("no window of {} samples free of masked samples "
                         "was found".format(temporal_size))

    noise_wf = np.zeros((1000, temporal_size))
    count = 0
    while count < 1000:
        tt = np.random.randint(T - temporal_size)
        cc = np.random.randint(C)
        if np.sum(idxNoise[tt:(tt + temporal_size), cc] == 0) == 0:
            noise_wf[count] = rec[tt:(tt + temporal_size), cc]
            count += 1

    # np.cov is positive semi-definite; clip round-off negatives so the
    # square root does not produce NaN
    w, v = np.linalg.eigh(np.cov(noise_wf.T))
    temporal_SIG = np.matmul(np.matmul(v, np.diag(np.sqrt(np.clip(w, 0, None)))),
                             v.T)

    return spatial_SIG, temporal_SIG
def crop_and_align_templates(big_templates, R, neighbors, geom):
    """Crop (spatially) and align (temporally) templates

    Each template's main-channel waveform (channel with the largest
    absolute value) is aligned to that of the largest-amplitude template,
    the whole template being shifted accordingly. Templates are then
    cropped in time to 6 * R + 1 samples around the highest-energy window,
    and spatially to the neighbors of their main channel, ordered by
    distance.

    Parameters
    ----------
    big_templates: numpy.ndarray (n_templates, n_times, n_channels)
        Templates; the input is not modified (a copy is aligned)
    R: int
        Temporal half-width unit. The crop keeps 6 * R + 1 samples and the
        energy window has 2 * int(R / 2) + 1 samples.
    neighbors: numpy.ndarray (n_channels, n_channels)
        Neighbors matrix
    geom: numpy.ndarray
        Channel coordinates, passed to order_channels_by_distance

    Returns
    -------
    numpy.ndarray (n_templates, min(6 * R + 1, n_times), n_neigh)
        Aligned and cropped templates. Channels beyond a template's number
        of neighbors are zero.
    """
    # copy templates to avoid modifying the original ones
    big_templates = np.copy(big_templates)
    K, n_times, _ = big_templates.shape

    # main channel for each template and amplitudes
    abs_templates = np.abs(big_templates)
    mainC = np.argmax(np.amax(abs_templates, axis=1), axis=1)
    amps = np.amax(abs_templates, axis=(1, 2))

    # reference: unit-norm main-channel waveform of the largest template
    K_big = np.argmax(amps)
    t_rec = big_templates[K_big, :, mainC[K_big]]
    t_rec = t_rec / np.sqrt(np.sum(np.square(t_rec)))

    templates_mainc = np.zeros((K, n_times))
    for k in range(K):
        t1 = big_templates[k, :, mainC[k]]
        t1 = t1 / np.sqrt(np.sum(np.square(t1)))
        shift = align_templates(t1, t_rec)
        if shift > 0:
            templates_mainc[k, :(n_times - shift)] = t1[shift:]
            big_templates[k, :(n_times - shift)] = big_templates[k, shift:]
            # zero the vacated tail (as templates_mainc does) instead of
            # leaving stale, duplicated samples there
            big_templates[k, (n_times - shift):] = 0
        elif shift < 0:
            templates_mainc[k, (-shift):] = t1[:(n_times + shift)]
            big_templates[k, (-shift):] = big_templates[k, :(n_times + shift)]
            # zero the vacated head
            big_templates[k, :(-shift)] = 0
        else:
            templates_mainc[k] = t1

    # determine temporal center of templates and crop around it
    R2 = int(R / 2)
    center = np.argmax(np.convolve(
        np.sum(np.square(templates_mainc), 0),
        np.ones(2 * R2 + 1), 'valid')) + R2
    # keep the crop window inside the array: a negative start would wrap
    # around and yield an empty crop
    crop_len = 6 * R + 1
    start = max(min(center - 3 * R, n_times - crop_len), 0)
    big_templates = big_templates[:, start:(start + crop_len)]

    # spatially crop
    nneigh = np.max(np.sum(neighbors, 0))
    small_templates = np.zeros((K, big_templates.shape[1], nneigh))
    for k in range(K):
        ch_idx = np.where(neighbors[mainC[k]])[0]
        ch_idx, _ = order_channels_by_distance(mainC[k], ch_idx, geom)
        small_templates[k, :, :ch_idx.shape[0]] = big_templates[k][:, ch_idx]

    return small_templates
def noise_cov(path_to_data, neighbors, geom, temporal_size):
    """Compute noise temporal and spatial covariance

    The reference channel is the one with the most neighbors; only it and
    its neighbors (ordered by distance) are used. On each channel, samples
    within R = (temporal_size - 1) // 2 of a value above 3 are masked. The
    remaining samples are scaled by their standard deviation, and the
    covariances are estimated from unmasked samples only.

    Parameters
    ----------
    path_to_data: str
        Path to recordings data
    neighbors: numpy.ndarray
        Neighbors matrix
    geom: numpy.ndarray
        Cartesian coordinates for the channels
    temporal_size: int
        Waveform size

    Returns
    -------
    spatial_SIG: numpy.ndarray (n_neigh, n_neigh)
        Symmetric square root of the spatial noise covariance
    temporal_SIG: numpy.ndarray (temporal_size, temporal_size)
        Symmetric square root of the covariance of 1000 randomly sampled,
        spatially whitened, fully unmasked noise snippets

    Raises
    ------
    ValueError
        If the recording is not longer than temporal_size, or if no window
        of temporal_size samples is entirely unmasked on any channel
    """
    logger = logging.getLogger(__name__)

    logger.debug('Computing noise_cov. Neighbors shape: {}, geom shape: {} '
                 'temporal_size: {}'.format(neighbors.shape, geom.shape,
                                            temporal_size))

    c_ref = np.argmax(np.sum(neighbors, 0))
    ch_idx = np.where(neighbors[c_ref])[0]
    ch_idx, _ = order_channels_by_distance(c_ref, ch_idx, geom)

    rec = RecordingsReader(path_to_data, loader='array')
    rec = np.asarray(rec[:, ch_idx])
    # masked samples are set to NaN, which integer arrays cannot hold
    if not np.issubdtype(rec.dtype, np.floating):
        rec = rec.astype(np.float64)

    T, C = rec.shape
    if T <= temporal_size:
        raise ValueError("recording has {} samples; need more than "
                         "temporal_size={}".format(T, temporal_size))

    idxNoise = np.zeros((T, C))

    R = (temporal_size - 1) // 2
    for c in range(C):
        # NOTE(review): only excursions above +3 are masked; matrix_localized
        # masks ts < -th instead. Confirm the sign convention.
        idx_temp = np.where(rec[:, c] > 3)[0]
        for j in range(-R, R + 1):
            idx_temp2 = idx_temp + j
            idx_temp2 = idx_temp2[np.logical_and(idx_temp2 >= 0,
                                                 idx_temp2 < T)]
            rec[idx_temp2, c] = np.nan
        is_noise = ~np.isnan(rec[:, c])
        rec[:, c] = rec[:, c] / np.nanstd(rec[:, c])

        rec[~is_noise, c] = 0
        idxNoise[is_noise, c] = 1

    # pairwise covariance, normalized by the number of jointly unmasked
    # samples of each channel pair
    spatial_cov = np.divide(np.matmul(rec.T, rec),
                            np.matmul(idxNoise.T, idxNoise))

    # spatial_cov is symmetric: eigh returns real eigenvalues and
    # orthonormal eigenvectors, which V diag(f(w)) V^T relies on
    w, v = np.linalg.eigh(spatial_cov)
    spatial_SIG = np.matmul(np.matmul(v, np.diag(np.sqrt(w))), v.T)

    spatial_whitener = np.matmul(np.matmul(v, np.diag(1 / np.sqrt(w))), v.T)
    rec = np.matmul(rec, spatial_whitener)

    # a window starting at tt is accepted below only if it has no masked
    # sample; check up front that at least one exists, otherwise the
    # sampling loop would never terminate
    n_starts = T - temporal_size
    masked_cumsum = np.concatenate(
        (np.zeros((1, C)), np.cumsum(1 - idxNoise, axis=0)), axis=0)
    masked_per_window = (masked_cumsum[temporal_size:temporal_size + n_starts]
                         - masked_cumsum[:n_starts])
    if not np.any(masked_per_window == 0):
        raise ValueError("no window of {} samples free of masked samples "
                         "was found".format(temporal_size))

    noise_wf = np.zeros((1000, temporal_size))
    count = 0
    while count < 1000:
        tt = np.random.randint(T - temporal_size)
        cc = np.random.randint(C)
        if np.sum(idxNoise[tt:(tt + temporal_size), cc] == 0) == 0:
            noise_wf[count] = rec[tt:(tt + temporal_size), cc]
            count += 1

    # np.cov is positive semi-definite; clip round-off negatives so the
    # square root does not produce NaN
    w, v = np.linalg.eigh(np.cov(noise_wf.T))
    temporal_SIG = np.matmul(np.matmul(v, np.diag(np.sqrt(np.clip(w, 0, None)))),
                             v.T)

    logger.debug('spatial_SIG shape: {} temporal_SIG shape: {}'
                 .format(spatial_SIG.shape, temporal_SIG.shape))

    return spatial_SIG, temporal_SIG