import numpy as np
from scipy.stats import chi2

# `mfm` provides the variational mixture-model routines used below
# (vbPar, maskData, suffStatistics, ELBO_Class, check_merge,
# calc_mahalonobis). The import path is assumed; adjust it to wherever
# the mfm module lives in this package.
from yass import mfm


def calculate_sparse_rhat(vbParam, tmp_loc, scores, spike_index, neighbors):
    """Compute soft cluster assignments (rhat) channel by channel and return
    them as (spike index, template index, responsibility) triplets instead of
    a dense matrix.

    Note: the `neighbors` argument is currently unused; each spike is scored
    only against the templates located on its own channel.
    """
    n_channels = np.max(spike_index[:, 1]) + 1
    n_templates = tmp_loc.shape[0]

    # rhat is accumulated as an (n_entries, 3) array of triplets
    rhat = None

    for channel in range(n_channels):
        # spikes detected on this channel
        idx_data = np.where(spike_index[:, 1] == channel)[0]
        score = scores[idx_data]
        n_data = score.shape[0]

        # templates located on this channel
        ch_idx = [channel]
        cluster_idx = np.zeros(n_templates, 'bool')
        for c in ch_idx:
            cluster_idx[tmp_loc == c] = 1
        cluster_idx = np.where(cluster_idx)[0]

        if n_data > 0 and cluster_idx.shape[0] > 0:
            # restrict the variational parameters to the local templates
            local_vbParam = mfm.vbPar(None)
            local_vbParam.muhat = vbParam.muhat[:, cluster_idx]
            local_vbParam.Vhat = vbParam.Vhat[:, :, cluster_idx]
            local_vbParam.invVhat = vbParam.invVhat[:, :, cluster_idx]
            local_vbParam.nuhat = vbParam.nuhat[cluster_idx]
            local_vbParam.lambdahat = vbParam.lambdahat[cluster_idx]
            local_vbParam.ahat = vbParam.ahat[cluster_idx]

            mask = np.ones([n_data, 1])
            group = np.arange(n_data)
            masked_data = mfm.maskData(score, mask, group)
            local_vbParam.update_local(masked_data)

            # drop weak assignments and renormalize the rows
            local_vbParam.rhat[local_vbParam.rhat < 0.1] = 0
            local_vbParam.rhat = local_vbParam.rhat / np.sum(
                local_vbParam.rhat, axis=1, keepdims=True)

            # convert the local (dense) rhat into global triplets
            row_idx, col_idx = np.where(local_vbParam.rhat > 0)
            val = local_vbParam.rhat[row_idx, col_idx]
            row_idx = idx_data[row_idx]
            col_idx = cluster_idx[col_idx]
            rhat_local = np.hstack((row_idx[:, np.newaxis],
                                    col_idx[:, np.newaxis],
                                    val[:, np.newaxis]))
            if rhat is None:
                rhat = rhat_local
            else:
                rhat = np.vstack((rhat, rhat_local))

    return rhat
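
# Illustrative sketch (not part of the original pipeline): the triplets
# returned by calculate_sparse_rhat can be assembled into a SciPy sparse
# matrix when a full spikes-by-templates view is needed. The helper name and
# arguments below are hypothetical.
def _example_rhat_triplets_to_sparse(rhat_triplets, n_spikes, n_templates):
    """Build a CSR soft-assignment matrix from (row, col, value) triplets."""
    from scipy.sparse import coo_matrix
    rows = rhat_triplets[:, 0].astype(int)
    cols = rhat_triplets[:, 1].astype(int)
    vals = rhat_triplets[:, 2]
    return coo_matrix((vals, (rows, cols)),
                      shape=(n_spikes, n_templates)).tocsr()
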
def recover_spikes(self, vbParam, pca, maha_dist=1):
    """Re-assign all spikes to the current clusters and keep only those that
    fall within a chi-square based Mahalanobis radius of at least one cluster.

    Note: the `maha_dist` argument is currently unused; the radius is fixed
    at the 99% chi-square quantile.
    """
    N, D = pca.shape
    # Cat: TODO: check if this maha thresholding recovering distance is good
    threshold = np.sqrt(chi2.ppf(0.99, D))

    # update rhat on full data
    maskedData = mfm.maskData(pca[:, :, np.newaxis],
                              np.ones([N, 1]),
                              np.arange(N))
    vbParam.update_local(maskedData)

    # calculate mahalanobis distance of every spike to every cluster
    maha = mfm.calc_mahalonobis(vbParam, pca[:, :, np.newaxis])

    # keep spikes that are within the threshold of at least one cluster
    idx_recovered = np.where(~np.all(maha >= threshold, axis=1))[0]
    vbParam.rhat = vbParam.rhat[idx_recovered]

    # zero out low assignment values and renormalize the rows
    vbParam.rhat[vbParam.rhat < self.assignment_delete_threshold] = 0
    vbParam.rhat = vbParam.rhat / np.sum(vbParam.rhat, 1, keepdims=True)

    return idx_recovered, vbParam
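
# Illustrative sketch (hypothetical helper): the recovery rule above keeps a
# spike whenever its Mahalanobis distance to at least one cluster falls below
# sqrt(chi2.ppf(0.99, D)), i.e. the 99% chi-square radius in D dimensions.
def _example_recovery_indices(maha, n_features):
    """Indices of rows of `maha` (spikes x clusters) that lie within the 99%
    chi-square radius of at least one cluster."""
    import numpy as np
    from scipy.stats import chi2
    threshold = np.sqrt(chi2.ppf(0.99, n_features))
    return np.where(~np.all(maha >= threshold, axis=1))[0]
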
def try_merge(k1, k2, scores, vbParam, maha, cfg):
    """Attempt to merge clusters k1 and k2; on success, collapse the pair into
    the lower index and refresh the affected rows/columns of `maha`."""
    ka, kb = min(k1, k2), max(k1, k2)

    # rhat is stored as (spike index, cluster index, responsibility) triplets
    assignment = vbParam.rhat[:, :2].astype('int32')

    idx_ka = assignment[:, 1] == ka
    idx_kb = assignment[:, 1] == kb

    # spikes assigned to either of the two clusters
    indices = np.unique(assignment[np.logical_or(idx_ka, idx_kb), 0])

    # dense two-column responsibility matrix restricted to those spikes
    rhat = np.zeros((scores.shape[0], 2))
    rhat[assignment[idx_ka, 0], 0] = vbParam.rhat[idx_ka, 2]
    rhat[assignment[idx_kb, 0], 1] = vbParam.rhat[idx_kb, 2]
    rhat = rhat[indices]

    # restrict the variational parameters to the two candidate clusters
    local_scores = scores[indices]
    local_vbParam = mfm.vbPar(rhat)
    local_vbParam.muhat = vbParam.muhat[:, [ka, kb]]
    local_vbParam.Vhat = vbParam.Vhat[:, :, [ka, kb]]
    local_vbParam.invVhat = vbParam.invVhat[:, :, [ka, kb]]
    local_vbParam.nuhat = vbParam.nuhat[[ka, kb]]
    local_vbParam.lambdahat = vbParam.lambdahat[[ka, kb]]
    local_vbParam.ahat = vbParam.ahat[[ka, kb]]

    mask = np.ones([local_scores.shape[0], 1])
    group = np.arange(local_scores.shape[0])
    local_maskedData = mfm.maskData(local_scores, mask, group)
    # local_vbParam.update_local(local_maskedData)
    local_suffStat = mfm.suffStatistics(local_maskedData, local_vbParam)

    ELBO = mfm.ELBO_Class(local_maskedData, local_suffStat,
                          local_vbParam, cfg)
    L = np.ones(2)
    (local_vbParam, local_suffStat,
     merged, _, _) = mfm.check_merge(local_maskedData, local_vbParam,
                                     local_suffStat, 0, 1, cfg, L, ELBO)

    if merged:
        print("merging {}, {}".format(ka, kb))

        # replace cluster ka with the merged cluster and delete cluster kb
        vbParam.muhat = np.delete(vbParam.muhat, kb, 1)
        vbParam.muhat[:, ka] = local_vbParam.muhat[:, 0]
        vbParam.Vhat = np.delete(vbParam.Vhat, kb, 2)
        vbParam.Vhat[:, :, ka] = local_vbParam.Vhat[:, :, 0]
        vbParam.invVhat = np.delete(vbParam.invVhat, kb, 2)
        vbParam.invVhat[:, :, ka] = local_vbParam.invVhat[:, :, 0]
        vbParam.nuhat = np.delete(vbParam.nuhat, kb, 0)
        vbParam.nuhat[ka] = local_vbParam.nuhat[0]
        vbParam.lambdahat = np.delete(vbParam.lambdahat, kb, 0)
        vbParam.lambdahat[ka] = local_vbParam.lambdahat[0]
        vbParam.ahat = np.delete(vbParam.ahat, kb, 0)
        vbParam.ahat[ka] = local_vbParam.ahat[0]

        # drop the old triplets for ka/kb, shift cluster indices above kb,
        # and append the merged assignments under index ka
        idx_delete = np.where(np.logical_or(idx_ka, idx_kb))[0]
        vbParam.rhat = np.delete(vbParam.rhat, idx_delete, 0)
        vbParam.rhat[vbParam.rhat[:, 1] > kb, 1] -= 1

        rhat_temp = np.hstack((indices[:, np.newaxis],
                               np.ones((indices.size, 1)) * ka,
                               np.sum(rhat, 1, keepdims=True)))
        vbParam.rhat = np.vstack((vbParam.rhat, rhat_temp))

        # recompute the distances involving the merged cluster
        maha = np.delete(maha, kb, 1)
        maha = np.delete(maha, kb, 0)

        diff = vbParam.muhat[:, :, 0] - local_vbParam.muhat[:, :, 0]

        prec = local_vbParam.Vhat[..., 0] * local_vbParam.nuhat[0]
        maha[ka] = np.squeeze(
            np.matmul(diff.T[:, np.newaxis, :],
                      np.matmul(prec[:, :, 0], diff.T[..., np.newaxis])))

        prec = np.transpose(vbParam.Vhat[..., 0] * vbParam.nuhat, [2, 0, 1])
        maha[:, ka] = np.squeeze(
            np.matmul(diff.T[:, np.newaxis, :],
                      np.matmul(prec, diff.T[..., np.newaxis])))

        maha[ka, ka] = np.inf

    if not merged:
        # mark the pair so it is not tried again
        maha[ka, kb] = maha[kb, ka] = np.inf

    return vbParam, maha
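
# Illustrative sketch (hypothetical helper): the maha refresh above computes a
# precision-weighted squared distance between cluster means, with the
# precision taken as nuhat * Vhat for the reference cluster. A minimal
# single-pair version of that computation:
def _example_pairwise_maha(mu_a, mu_b, Vhat_b, nuhat_b):
    """Squared distance between mean vectors mu_a and mu_b (length D) under
    the precision nuhat_b * Vhat_b (D x D)."""
    import numpy as np
    diff = np.asarray(mu_a) - np.asarray(mu_b)
    prec = nuhat_b * np.asarray(Vhat_b)
    return float(diff @ prec @ diff)
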
def merge_move_patches(cluster, neigh_clusters, scores, vbParam, maha, cfg):
    """Repeatedly try to merge `cluster` with its neighbouring clusters,
    following the chain of merges until no close neighbour remains."""
    while len(neigh_clusters) > 0:
        i = neigh_clusters[-1]
        # indices = np.logical_or(clusterid == cluster, clusterid == i)
        indices, temp = vbParam.rhat[:, [cluster, i]].nonzero()
        indices = np.unique(indices)
        ka, kb = min(cluster, i), max(cluster, i)

        # restrict scores and variational parameters to the two clusters
        local_scores = scores[indices]
        local_vbParam = mfm.vbPar(
            vbParam.rhat[:, [cluster, i]].toarray()[indices])
        local_vbParam.muhat = vbParam.muhat[:, [cluster, i]]
        local_vbParam.Vhat = vbParam.Vhat[:, :, [cluster, i]]
        local_vbParam.invVhat = vbParam.invVhat[:, :, [cluster, i]]
        local_vbParam.nuhat = vbParam.nuhat[[cluster, i]]
        local_vbParam.lambdahat = vbParam.lambdahat[[cluster, i]]
        local_vbParam.ahat = vbParam.ahat[[cluster, i]]

        mask = np.ones([local_scores.shape[0], 1])
        group = np.arange(local_scores.shape[0])
        local_maskedData = mfm.maskData(local_scores, mask, group)
        # local_vbParam.update_local(local_maskedData)
        local_suffStat = mfm.suffStatistics(local_maskedData, local_vbParam)

        ELBO = mfm.ELBO_Class(local_maskedData, local_suffStat,
                              local_vbParam, cfg)
        L = np.ones(2)
        (local_vbParam, local_suffStat,
         merged, _, _) = mfm.check_merge(local_maskedData, local_vbParam,
                                         local_suffStat, 0, 1, cfg, L, ELBO)

        if merged:
            print("merging {}, {}".format(cluster, i))

            # replace cluster ka with the merged cluster and delete cluster kb
            vbParam.muhat = np.delete(vbParam.muhat, kb, 1)
            vbParam.muhat[:, ka] = local_vbParam.muhat[:, 0]
            vbParam.Vhat = np.delete(vbParam.Vhat, kb, 2)
            vbParam.Vhat[:, :, ka] = local_vbParam.Vhat[:, :, 0]
            vbParam.invVhat = np.delete(vbParam.invVhat, kb, 2)
            vbParam.invVhat[:, :, ka] = local_vbParam.invVhat[:, :, 0]
            vbParam.nuhat = np.delete(vbParam.nuhat, kb, 0)
            vbParam.nuhat[ka] = local_vbParam.nuhat[0]
            vbParam.lambdahat = np.delete(vbParam.lambdahat, kb, 0)
            vbParam.lambdahat[ka] = local_vbParam.lambdahat[0]
            vbParam.ahat = np.delete(vbParam.ahat, kb, 0)
            vbParam.ahat[ka] = local_vbParam.ahat[0]

            # fold column kb of rhat into column ka and drop column kb
            vbParam.rhat[:, ka] = vbParam.rhat[:, ka] + vbParam.rhat[:, kb]
            n_data_all, n_templates_all = vbParam.rhat.shape
            to_keep = list(set(np.arange(n_templates_all)) - set([kb]))
            vbParam.rhat = vbParam.rhat[:, to_keep]

            # clusterid[indices] = ka
            # clusterid[clusterid > kb] = clusterid[clusterid > kb] - 1
            neigh_clusters.pop()

            # recompute the distances involving the merged cluster
            maha = np.delete(maha, kb, 1)
            maha = np.delete(maha, kb, 0)

            diff = vbParam.muhat[:, :, 0] - local_vbParam.muhat[:, :, 0]

            prec = local_vbParam.Vhat[..., 0] * local_vbParam.nuhat[0]
            maha[ka] = np.squeeze(
                np.matmul(diff.T[:, np.newaxis, :],
                          np.matmul(prec[:, :, 0], diff.T[..., np.newaxis])))

            prec = np.transpose(vbParam.Vhat[..., 0] * vbParam.nuhat,
                                [2, 0, 1])
            maha[:, ka] = np.squeeze(
                np.matmul(diff.T[:, np.newaxis, :],
                          np.matmul(prec, diff.T[..., np.newaxis])))

            maha[ka, ka] = np.inf

            # rebuild the neighbour list around the merged cluster
            neigh_clusters = list(
                np.where(np.logical_or(maha[ka] < 15, maha.T[ka] < 15))[0])
            cluster = ka

        if not merged:
            # mark the pair so it is not tried again and move on
            maha[ka, kb] = maha[kb, ka] = np.inf
            neigh_clusters.pop()

    return vbParam, maha
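
# Illustrative sketch (hypothetical helper; the threshold of 15 matches the
# constant used above): neighbour candidates of cluster k are the clusters
# whose distance to or from k falls below the threshold in either direction
# of the (asymmetric) maha matrix.
def _example_neighbour_candidates(maha, k, threshold=15.0):
    """Return cluster indices considered neighbours of cluster `k`."""
    import numpy as np
    return list(np.where(np.logical_or(maha[k] < threshold,
                                       maha[:, k] < threshold))[0])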