def run_split(units_in, ptps, labels, CONFIG, ptp_cut=5):

    spike_index_list = []
    for unit in units_in:
        idx_ = np.where(labels == unit)[0]
        ptps_ = ptps[idx_]

        new_assignment = np.zeros(len(idx_), 'int32')
        idx_big = np.where(ptps_ > ptp_cut)[0]
        if len(idx_big) > 10:
            mask = np.ones((len(idx_big), 1))
            group = np.arange(len(idx_big))
            vbParam = mfm.spikesort(ptps_[idx_big, None, None],
                                    mask,
                                    group,
                                    CONFIG)
            cc_assignment, stability, cc = anneal_clusters(vbParam)

            # get ptp per cc
            mean_ptp_cc = np.zeros(len(cc))
            for k in range(len(cc)):
                mean_ptp_cc[k] = np.mean(ptps_[idx_big][cc_assignment == k])

            # reorder cc label by mean ptp
            cc_assignment_ordered = np.zeros_like(cc_assignment)
            for ii, k in enumerate(np.argsort(mean_ptp_cc)):
                cc_assignment_ordered[cc_assignment == k] = ii

            # cc with the smallest mean ptp will have the same assignment
            # as ptps < ptp_cut
            new_assignment[idx_big] = cc_assignment_ordered

        spike_index_list.append(np.vstack((idx_, new_assignment)).T)

    return spike_index_list
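# Usage sketch (illustrative only): run_split applied to the labels of an
# earlier clustering pass. `ptps_all` and `labels_all` are hypothetical
# arrays, not defined in this module.
def _example_run_split(ptps_all, labels_all, CONFIG):
    """Split every unit present in `labels_all` by its PTP distribution"""
    units_in = np.unique(labels_all)
    spike_index_list = run_split(units_in, ptps_all, labels_all, CONFIG,
                                 ptp_cut=5)
    # each entry is an (n_spikes_in_unit, 2) array: column 0 is the index
    # into `labels_all`, column 1 the sub-assignment within the unit
    return spike_index_list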
def run_mfm(self, gen, pca_wf):

    mask = np.ones((pca_wf.shape[0], 1))
    group = np.arange(pca_wf.shape[0])
    vbParam = mfm.spikesort(pca_wf[:, :, np.newaxis],
                            mask,
                            group,
                            self.CONFIG)
    if self.verbose:
        print("chan " + str(self.channel) + ', gen ' + str(gen) + ", " +
              str(vbParam.rhat.shape[1]) + " clusters from ", pca_wf.shape)

    return vbParam
def run_cluster_location(scores, spike_index, min_spikes, CONFIG):
    """Run clustering algorithm using MFM and location features

    Parameters
    ----------
    scores: list (n_channels)
        A list such that scores[c] contains all scores whose main
        channel is c

    spike_index: list (n_channels)
        A list such that spike_index[c] contains all spike times whose
        main channel is c

    CONFIG: class
        configuration class

    Returns
    -------
    spike_train: np.array (n_data, 2)
        spike_train such that spike_train[j, 0] and spike_train[j, 1]
        are the spike time and spike id of spike j
    """
    logger = logging.getLogger(__name__)

    n_channels = np.max(spike_index[:, 1]) + 1

    global_score = None
    global_vbParam = None
    global_spike_index = None
    global_tmp_loc = None

    # run clustering algorithm per main channel
    for channel in range(n_channels):

        logger.info('Processing channel {}'.format(channel))

        idx_data = np.where(spike_index[:, 1] == channel)[0]
        score_channel = scores[idx_data]
        spike_index_channel = spike_index[idx_data]
        n_data = score_channel.shape[0]

        if n_data > 1:

            # make a fake mask of ones to run clustering algorithm
            mask = np.ones((n_data, 1))
            group = np.arange(n_data)
            vbParam = mfm.spikesort(np.copy(score_channel),
                                    mask,
                                    group,
                                    CONFIG)

            # make rhat more sparse
            vbParam.rhat[vbParam.rhat < 0.1] = 0
            vbParam.rhat = vbParam.rhat / np.sum(vbParam.rhat, 1,
                                                 keepdims=True)

            # clean clusters with nearly no spikes
            vbParam = clean_empty_cluster(vbParam, min_spikes)

            if vbParam.rhat.shape[1] > 0:
                # add changes to global parameters
                (global_vbParam,
                 global_tmp_loc,
                 global_score,
                 global_spike_index) = global_cluster_info(
                    vbParam, channel, score_channel, spike_index_channel,
                    global_vbParam, global_tmp_loc,
                    global_score, global_spike_index)

    return global_vbParam, global_tmp_loc, global_score, global_spike_index
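# Usage sketch (illustrative only): turning the output of
# run_cluster_location into a (spike_time, cluster_id) array by taking the
# argmax of the responsibility matrix rhat, mirroring what run() does below.
# It assumes the returned spike index keeps the (spike_time, main_channel)
# layout of its input.
def _example_spike_train_from_location_clustering(scores, spike_index,
                                                  min_spikes, CONFIG):
    vbParam, tmp_loc, score_out, spike_index_out = run_cluster_location(
        scores, spike_index, min_spikes, CONFIG)
    assignment = np.argmax(vbParam.rhat, axis=1)
    # first column: spike time, second column: cluster id
    return np.hstack((spike_index_out[:, [0]], assignment[:, np.newaxis]))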
def run_cluster(scores, masks, groups, spike_index, min_spikes, CONFIG):
    """Run clustering algorithm using MFM

    Parameters
    ----------
    scores: list (n_channels)
        A list such that scores[c] contains all scores whose main
        channel is c

    masks: list (n_channels)
        mask for each data in scores
        masks[c] is the mask of spikes in scores[c]

    groups: list (n_channels)
        coreset represented as group id.
        groups[c] is the group id of spikes in scores[c]

    spike_index: list (n_channels)
        A list such that spike_index[c] contains all spike times whose
        main channel is c

    CONFIG: class
        configuration class

    Returns
    -------
    spike_train: np.array (n_data, 2)
        spike_train such that spike_train[j, 0] and spike_train[j, 1]
        are the spike time and spike id of spike j
    """
    # FIXME: mutating parameter
    # this function is passing a config object and mutating it,
    # this is not a good idea as having a mutable object lying around the
    # code can break things and make it hard to debug
    # (09/27/17) Eduardo

    logger = logging.getLogger(__name__)

    n_channels = np.max(spike_index[:, 1]) + 1

    global_score = None
    global_vbParam = None
    global_spike_index = None
    global_tmp_loc = None

    # run clustering algorithm per main channel
    for channel in range(n_channels):

        logger.info('Processing channel {}'.format(channel))

        idx_data = np.where(spike_index[:, 1] == channel)[0]
        score_channel = scores[idx_data]
        mask_channel = masks[channel]
        group_channel = groups[channel]
        spike_index_channel = spike_index[idx_data]
        n_data = score_channel.shape[0]

        if n_data > 1:

            # run clustering
            vbParam = mfm.spikesort(score_channel,
                                    mask_channel,
                                    group_channel,
                                    CONFIG)

            # make rhat more sparse
            vbParam.rhat[vbParam.rhat < 0.1] = 0
            vbParam.rhat = vbParam.rhat / np.sum(vbParam.rhat, 1,
                                                 keepdims=True)

            # clean clusters with nearly no spikes
            vbParam = clean_empty_cluster(vbParam, min_spikes)

            # add changes to global parameters
            (global_vbParam,
             global_tmp_loc,
             global_score,
             global_spike_index) = global_cluster_info(
                vbParam, channel, score_channel, spike_index_channel,
                global_vbParam, global_tmp_loc,
                global_score, global_spike_index)

    return global_vbParam, global_tmp_loc, global_score, global_spike_index
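# The "make rhat more sparse" step above (shared with run_cluster_location)
# zeroes out weak responsibilities and renormalizes each row. A minimal
# standalone sketch of that operation on a plain array:
def _sparsify_rhat(rhat, threshold=0.1):
    """Zero responsibilities below `threshold` and renormalize each row"""
    rhat = np.copy(rhat)
    rhat[rhat < threshold] = 0
    return rhat / np.sum(rhat, 1, keepdims=True)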
def runSorter(score_all, mask_all, clr_idx_all, group_all, channel_groups,
              neighbors, n_features, config):
    """Run sorting algorithm for every channel group

    Parameters
    ----------

    Returns
    -------
    spike_train: ?
    """
    # FIXME: mutating parameter
    # this function is passing a config object and mutating it,
    # this is not a good idea as having a mutable object lying around the
    # code can break things and make it hard to debug
    # (09/27/17) Eduardo

    nG = len(channel_groups)
    nmax = 10000
    K = 0
    spike_train = 0

    bar = progressbar.ProgressBar(maxval=nG)

    # iterate over every channel group (this is computed in config.py)
    for g in range(nG):

        # get the channels that conform this group
        ch_idx = channel_groups[g]
        neigh_chan = np.sum(neighbors[ch_idx], axis=0) > 0

        score = np.zeros((nmax * ch_idx.shape[0], n_features,
                          np.sum(neigh_chan)))
        index = np.zeros((nmax * ch_idx.shape[0], 2), 'int32')
        mask = np.zeros((nmax * ch_idx.shape[0], np.sum(neigh_chan)))
        group = np.zeros(nmax * ch_idx.shape[0], 'int16')

        count = 0
        Ngroup = 0

        for j in range(ch_idx.shape[0]):
            c = ch_idx[j]
            if score_all[c].shape[0] > 0:
                ndataTemp = score_all[c].shape[0]
                score[count:(count + ndataTemp), :,
                      neighbors[c][neigh_chan]] = score_all[c]
                clr_idx_temp = clr_idx_all[c]
                index[count:(count + ndataTemp)] = np.concatenate(
                    (np.ones((ndataTemp, 1)) * c,
                     clr_idx_temp[:, np.newaxis]), axis=1)
                mask[count:(count + ndataTemp),
                     neighbors[c][neigh_chan]] = mask_all[c]
                group[count:(count + ndataTemp)] = (group_all[c]
                                                    + Ngroup + 1)
                Ngroup += np.amax(group_all[c]) + 1
                count += ndataTemp

        score = score[:count]
        index = index[:count]
        mask = mask[:count]
        group = group[:count] - 1

        if score.shape[0] > 0:
            L = spikesort(score, mask, group, config)

            idx_triage = L == -1
            L = L[~idx_triage]
            index = index[~idx_triage]

            spikeTrain_temp = np.concatenate((L[:, np.newaxis] + K, index),
                                             axis=1)
            K += np.amax(L) + 1

            if g == 0:
                spike_train = spikeTrain_temp
            else:
                spike_train = np.concatenate((spike_train, spikeTrain_temp))

        bar.update(g + 1)

    bar.finish()

    return spike_train
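# runSorter pre-allocates buffers of size nmax * n_channels_in_group, fills
# them channel by channel while advancing `count`, and finally slices back
# to the filled portion. A minimal sketch of that collect-then-trim pattern
# (the inputs here are hypothetical):
def _example_collect_then_trim(chunks, n_features):
    nmax = 10000
    buffer_ = np.zeros((nmax * len(chunks), n_features))
    count = 0
    for chunk in chunks:
        n = chunk.shape[0]
        buffer_[count:count + n] = chunk
        count += n
    # drop the unused tail
    return buffer_[:count]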
def run(score, spike_index_clear, spike_index_collision,
        output_directory='tmp/', recordings_filename='standarized.bin'):
    """Process spikes

    Parameters
    ----------
    score: numpy.ndarray (n_spikes, n_features, n_channels)
        3D array with the scores for the clear spikes, first dimension is
        the number of spikes, second is the number of features and third
        the number of channels

    spike_index_clear: numpy.ndarray (n_clear_spikes, 2)
        2D array with indexes for clear spikes, first column contains the
        spike location in the recording and the second the main channel
        (channel whose amplitude is maximum)

    spike_index_collision: numpy.ndarray (n_collided_spikes, 2)
        2D array with indexes for collided spikes, first column contains
        the spike location in the recording and the second the main
        channel (channel whose amplitude is maximum)

    output_directory: str, optional
        Output directory (relative to CONFIG.data.root_folder) used to
        load the recordings to generate templates, defaults to tmp/

    recordings_filename: str, optional
        Recordings filename (relative to
        CONFIG.data.root_folder/output_directory) used to generate the
        templates, defaults to standarized.bin

    Returns
    -------
    spike_train_clear: numpy.ndarray (n_clear_spikes, 2)
        A 2D array for clear spikes whose first column indicates the
        spike time and the second column the neuron id determined by the
        clustering algorithm

    templates: numpy.ndarray (n_channels, waveform_size, n_templates)
        A 3D array with the templates

    spike_index_collision: numpy.ndarray (n_collided_spikes, 2)
        A 2D array for collided spikes whose first column indicates the
        spike time and the second column the main channel

    Examples
    --------

    .. literalinclude:: ../examples/process.py
    """
    CONFIG = read_config()
    MAIN_CHANNEL = 1

    startTime = datetime.datetime.now()
    Time = {'t': 0, 'c': 0, 'm': 0, 's': 0, 'e': 0}

    logger = logging.getLogger(__name__)

    nG = len(CONFIG.channelGroups)
    nneigh = np.max(np.sum(CONFIG.neighChannels, 0))
    n_coreset = 0
    K = 0

    # first column: spike_time
    # second column: cluster id
    spike_train_clear = np.zeros((0, 2), 'int32')

    if CONFIG.clustering.clustering_method == 'location':

        spike_index_clear_proc = np.zeros((0, 2), 'int32')
        main_channel_index = spike_index_clear[:, MAIN_CHANNEL]

        for i, c in enumerate(np.unique(main_channel_index)):
            logger.info('Processing channel {}'.format(i))

            idx = main_channel_index == c
            score_c = score[idx]
            spike_index_clear_c = spike_index_clear[idx]

            ##########
            # Triage #
            ##########

            # TODO: refactor this as CONFIG.doTriage was removed
            doTriage = True

            _b = datetime.datetime.now()
            logger.info('Triaging events with main channel {}'.format(c))
            index_keep = triage(score_c, 0, CONFIG.triage.nearest_neighbors,
                                CONFIG.triage.percent, doTriage)
            Time['t'] += (datetime.datetime.now() - _b).total_seconds()

            # add untriaged spike index to spike_index_clear_proc
            # and triaged spike index to spike_index_collision
            spike_index_clear_proc = np.concatenate(
                (spike_index_clear_proc, spike_index_clear_c[index_keep]),
                axis=0)
            spike_index_collision = np.concatenate(
                (spike_index_collision, spike_index_clear_c[~index_keep]),
                axis=0)

            # TODO: add documentation for all of this part, until the
            # "cleaning" comment

            # keep untriaged score only
            score_c = score_c[index_keep]

            group = np.arange(score_c.shape[0])
            mask = np.ones([score_c.shape[0], 1])

            _b = datetime.datetime.now()
            logger.info('Clustering events with main channel {}'.format(c))

            if i == 0:
                global_vbParam, global_maskedData = spikesort(score_c,
                                                              mask,
                                                              group,
                                                              CONFIG)
                score_proc = score_c
            else:
                local_vbParam, local_maskedData = spikesort(score_c,
                                                            mask,
                                                            group,
                                                            CONFIG)
                global_vbParam.muhat = np.concatenate(
                    [global_vbParam.muhat, local_vbParam.muhat], axis=1)
                global_vbParam.Vhat = np.concatenate(
                    [global_vbParam.Vhat, local_vbParam.Vhat], axis=2)
                global_vbParam.invVhat = np.concatenate(
                    [global_vbParam.invVhat, local_vbParam.invVhat], axis=2)
                global_vbParam.lambdahat = np.concatenate(
                    [global_vbParam.lambdahat, local_vbParam.lambdahat],
                    axis=0)
                global_vbParam.nuhat = np.concatenate(
                    [global_vbParam.nuhat, local_vbParam.nuhat], axis=0)
                global_vbParam.ahat = np.concatenate(
                    [global_vbParam.ahat, local_vbParam.ahat], axis=0)
                global_maskedData.sumY = np.concatenate(
                    [global_maskedData.sumY, local_maskedData.sumY], axis=0)
                global_maskedData.sumYSq = np.concatenate(
                    [global_maskedData.sumYSq, local_maskedData.sumYSq],
                    axis=0)
                global_maskedData.sumEta = np.concatenate(
                    [global_maskedData.sumEta, local_maskedData.sumEta],
                    axis=0)
                global_maskedData.weight = np.concatenate(
                    [global_maskedData.weight, local_maskedData.weight],
                    axis=0)
                global_maskedData.groupMask = np.concatenate(
                    [global_maskedData.groupMask,
                     local_maskedData.groupMask], axis=0)
                global_maskedData.meanY = np.concatenate(
                    [global_maskedData.meanY, local_maskedData.meanY],
                    axis=0)
                global_maskedData.meanYSq = np.concatenate(
                    [global_maskedData.meanYSq, local_maskedData.meanYSq],
                    axis=0)
                global_maskedData.meanEta = np.concatenate(
                    [global_maskedData.meanEta, local_maskedData.meanEta],
                    axis=0)

                score_proc = np.concatenate([score_proc, score_c], axis=0)

        logger.info('merging all channels')

        L = np.ones(global_vbParam.muhat.shape[1])
        global_vbParam.update_local(global_maskedData)
        suffStat = suffStatistics(global_maskedData, global_vbParam)
        global_vbParam, suffStat, L = merge_move(global_maskedData,
                                                 global_vbParam, suffStat,
                                                 CONFIG, L, 0)

        assignmentTemp = np.argmax(global_vbParam.rhat, axis=1)
        assignment = np.zeros(score_proc.shape[0], 'int16')
        for j in range(score_proc.shape[0]):
            assignment[j] = assignmentTemp[j]

        idx_triage = cluster_triage(global_vbParam, score_proc, 3)
        assignment[idx_triage] = -1

        Time['s'] += (datetime.datetime.now() - _b).total_seconds()

        ############
        # Cleaning #
        ############

        # TODO: describe this step
        spike_train_clear = np.concatenate(
            [spike_index_clear_proc[~idx_triage, 0:1],
             assignment[~idx_triage, np.newaxis]], axis=1)
        spike_index_collision = np.concatenate(
            [spike_index_collision, spike_index_clear_proc[idx_triage]])

    else:
        # according to the docs, if the clustering method is not 2+3 you can
        # set 3 x neighboring_channels, but I do not see where
        # neighboring_channels is being parsed in this else statement
        c_idx = np.ones((CONFIG.recordings.n_channels, nneigh),
                        'int32') * CONFIG.recordings.n_channels
        for c in range(CONFIG.recordings.n_channels):
            ch_idx, _ = order_channels_by_distance(
                c, np.where(CONFIG.neighChannels[c])[0], CONFIG.geom)
            c_idx[c, :ch_idx.shape[0]] = ch_idx

        # iterate over every channel group [missing documentation for this
        # function]. why is this order needed?
        for g in range(nG):
            logger.info("Processing group {} in {} groups.".format(g + 1,
                                                                   nG))
            logger.info("Processing data (triage, coreset, masking) ...")

            channels = CONFIG.channelGroups[g]
            neigh_chans = np.where(
                np.sum(CONFIG.neighChannels[channels], axis=0) > 0)[0]

            score_group = np.zeros(
                (0, CONFIG.spikes.temporal_features, neigh_chans.shape[0]))
            coreset_id_group = np.zeros((0), 'int32')
            mask_group = np.zeros((0, neigh_chans.shape[0]))
            spike_index_clear_group = np.zeros((0, 2), 'int32')

            # go through every channel in the group
            for c in channels:

                # index of data whose main channel is c
                idx = spike_index_clear[:, MAIN_CHANNEL] == c

                if np.sum(idx) > 0:

                    # score whose main channel is c
                    score_c = score[idx]
                    # spike_index_clear whose main channel is c
                    spike_index_clear_c = spike_index_clear[idx]

                    ##########
                    # Triage #
                    ##########

                    # TODO: refactor this as CONFIG.doTriage was removed
                    doTriage = True

                    _b = datetime.datetime.now()
                    index_keep = triage(score_c, 0,
                                        CONFIG.triage.nearest_neighbors,
                                        CONFIG.triage.percent, doTriage)
                    Time['t'] += (datetime.datetime.now()
                                  - _b).total_seconds()

                    # add untriaged spike index to spike_index_clear_group
                    # and triaged spike index to spike_index_collision
                    spike_index_clear_group = np.concatenate(
                        (spike_index_clear_group,
                         spike_index_clear_c[index_keep]), axis=0)
                    spike_index_collision = np.concatenate(
                        (spike_index_collision,
                         spike_index_clear_c[~index_keep]), axis=0)

                    # keep untriaged score only
                    score_c = score_c[index_keep]

                    ###########
                    # Coreset #
                    ###########

                    # TODO: refactor this as CONFIG.doCoreset was removed
                    doCoreset = True

                    _b = datetime.datetime.now()
                    coreset_id = coreset(score_c, CONFIG.coreset.clusters,
                                         CONFIG.coreset.threshold, doCoreset)
                    Time['c'] += (datetime.datetime.now()
                                  - _b).total_seconds()

                    ###########
                    # Masking #
                    ###########

                    _b = datetime.datetime.now()
                    mask = getmask(score_c, coreset_id,
                                   CONFIG.clustering.masking_threshold,
                                   CONFIG.spikes.temporal_features)
                    Time['m'] += (datetime.datetime.now()
                                  - _b).total_seconds()

                    ################
                    # collect data #
                    ################

                    # restructure score_c and mask to have same number of
                    # channels as score_group
                    score_temp = np.zeros(
                        (score_c.shape[0], CONFIG.spikes.temporal_features,
                         neigh_chans.shape[0]))
                    mask_temp = np.zeros((mask.shape[0],
                                          neigh_chans.shape[0]))

                    nneigh_c = np.sum(c_idx[c] < CONFIG.recordings.n_channels)
                    for j in range(nneigh_c):
                        c_interest = np.where(
                            neigh_chans == c_idx[c, j])[0][0]
                        score_temp[:, :, c_interest] = score_c[:, :, j]
                        mask_temp[:, c_interest] = mask[:, j]

                    # add score, coreset_id, mask to the groups
                    score_group = np.concatenate(
                        (score_group, score_temp), axis=0)
                    mask_group = np.concatenate(
                        (mask_group, mask_temp), axis=0)
                    coreset_id_group = np.concatenate(
                        (coreset_id_group, coreset_id + n_coreset + 1),
                        axis=0)
                    n_coreset += np.max(coreset_id) + 1

            if score_group.shape[0] > 0:

                ##############
                # Clustering #
                ##############

                _b = datetime.datetime.now()
                logger.info("Clustering...")

                coreset_id_group = coreset_id_group - 1
                n_coreset = 0
                cluster_id = spikesort(score_group, mask_group,
                                       coreset_id_group, CONFIG)
                Time['s'] += (datetime.datetime.now() - _b).total_seconds()

                ############
                # Cleaning #
                ############

                # model based triage
                idx_triage = (cluster_id == -1)

                # concatenate spike index with cluster id of untriaged ones
                # to create spike_train_clear
                si_clustered = spike_index_clear_group[~idx_triage]
                spt = si_clustered[:, [0]]
                cluster_id = cluster_id[~idx_triage][:, np.newaxis]

                spike_train_temp = np.concatenate((spt, cluster_id + K),
                                                  axis=1)
                spike_train_clear = np.concatenate(
                    (spike_train_clear,
                     spike_train_temp), axis=0)
                K += np.amax(cluster_id) + 1

                # concatenate triaged spike_index_clear_group
                # into spike_index_collision
                spike_index_collision = np.concatenate(
                    (spike_index_collision,
                     spike_index_clear_group[idx_triage]), axis=0)

    #################
    # Get templates #
    #################

    _b = datetime.datetime.now()
    logger.info("Getting Templates...")

    path_to_recordings = os.path.join(CONFIG.data.root_folder,
                                      output_directory,
                                      recordings_filename)
    merge_threshold = CONFIG.templates.merge_threshold
    spike_train_clear, templates = gam_templates(
        spike_train_clear, path_to_recordings, CONFIG.spikeSize,
        CONFIG.templatesMaxShift, merge_threshold, CONFIG.neighChannels)

    Time['e'] += (datetime.datetime.now() - _b).total_seconds()

    currentTime = datetime.datetime.now()

    if CONFIG.clustering.clustering_method == 'location':
        logger.info("Mainprocess done in {0} seconds.".format(
            (currentTime - startTime).seconds))
        logger.info("\ttriage:\t{0} seconds".format(Time['t']))
        logger.info("\tclustering:\t{0} seconds".format(Time['s']))
        logger.info("\ttemplates:\t{0} seconds".format(Time['e']))
    else:
        logger.info("\ttriage:\t{0} seconds".format(Time['t']))
        logger.info("\tcoreset:\t{0} seconds".format(Time['c']))
        logger.info("\tmasking:\t{0} seconds".format(Time['m']))
        logger.info("\tclustering:\t{0} seconds".format(Time['s']))
        logger.info("\ttemplates:\t{0} seconds".format(Time['e']))

    return spike_train_clear, templates, spike_index_collision
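# Usage sketch (illustrative only): calling run() with the arrays produced
# by a detection/dimensionality-reduction step. The names `score`,
# `clear_index` and `collision_index` are placeholders; see
# examples/process.py for the documented example referenced in the
# docstring above.
def _example_process(score, clear_index, collision_index):
    spike_train_clear, templates, spike_index_collision = run(
        score, clear_index, collision_index,
        output_directory='tmp/',
        recordings_filename='standarized.bin')
    return spike_train_clear, templates, spike_index_collision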