def get_SpikeSample(dataRAW, row, col, params):
    """Return the 1D voltage clips around a set of detected spikes.

    Given a batch of data (time by channels) and, for every spike, a time
    index (``row``) and a channel index (``col``), this gathers the samples of
    that spike's own channel around the spike time.

    Args:
        dataRAW: 2D array of shape ``(nT, nChan)``.
        row: 1D integer array with the time index of each spike.
        col: 1D integer array with the channel index of each spike (same
            length as ``row``). Values outside ``[0, nChan - 1]`` are clamped
            to the nearest valid channel. The caller's array is NOT modified.
        params: object exposing ``nt0`` (number of samples per clip) and
            ``nt0min`` (sample of the clip the detected peak is aligned to).

    Returns:
        Array of shape ``(params.nt0, row.size)``; entry ``[i, s]`` is
        ``dataRAW[row[s] - params.nt0min + i + 1, col[s]]``.

    Note:
        Time indices are not clamped; the original code relies on the time
        buffer around each batch to keep them in range.
    """
    nT, nChan = dataRAW.shape

    # sample offsets around the peak to consider; the negativity is expected
    # at nt0min, so the detected peaks are aligned there
    dt = cp.arange(params.nt0)
    dt = -params.nt0min + dt

    # temporal indices, one column per spike (broadcasting)
    indsT = row + dt[:, np.newaxis] + 1

    # anything that's out of bounds just gets set to the limit. cp.clip builds
    # a new array, so the caller's `col` is left untouched (the previous
    # in-place masked assignment silently overwrote it).
    indsC = cp.clip(col, 0, nChan - 1)

    # reshape to (nt0, 1, nSpikes) and (1, 1, nSpikes) so they broadcast
    indsT = cp.transpose(cp.atleast_3d(indsT), [0, 2, 1])
    indsC = cp.transpose(cp.atleast_3d(indsC), [2, 0, 1])

    # flat index into the channel-major flattening of the data:
    # sample t of channel c lives at position t + c * nT of dataRAW.T.ravel()
    ix = indsT + indsC * nT

    # grab the data and reshape it appropriately (time samples by num spikes)
    clips = dataRAW.T.ravel()[ix[:, 0, :]].reshape((dt.size, row.size), order='F')
    return clips
def _preprocess(labels): label_values, inv_idx = cp.unique(labels, return_inverse=True) if not (label_values == 0).any(): warn('Random walker only segments unlabeled areas, where ' 'labels == 0. No zero valued areas in labels were ' 'found. Returning provided labels.', stacklevel=2) return labels, None, None, None, None # If some labeled pixels are isolated inside pruned zones, prune them # as well and keep the labels for the final output null_mask = labels == 0 pos_mask = labels > 0 mask = labels >= 0 fill = ndi.binary_propagation(null_mask, mask=mask) isolated = cp.logical_and(pos_mask, cp.logical_not(fill)) pos_mask[isolated] = False # If the array has pruned zones, be sure that no isolated pixels # exist between pruned zones (they could not be determined) if label_values[0] < 0 or cp.any(isolated): # synchronize! isolated = cp.logical_and( cp.logical_not(ndi.binary_propagation(pos_mask, mask=mask)), null_mask) labels[isolated] = -1 if cp.all(isolated[null_mask]): warn('All unlabeled pixels are isolated, they could not be ' 'determined by the random walker algorithm.', stacklevel=2) return labels, None, None, None, None mask[isolated] = False mask = cp.atleast_3d(mask) else: mask = None # Reorder label values to have consecutive integers (no gaps) zero_idx = cp.searchsorted(label_values, cp.array(0)) labels = cp.atleast_3d(inv_idx.reshape(labels.shape) - zero_idx) nlabels = label_values[zero_idx + 1:].shape[0] inds_isolated_seeds = cp.nonzero(isolated) isolated_values = labels[inds_isolated_seeds] return labels, nlabels, mask, inds_isolated_seeds, isolated_values
def dstack(tup):
    """Stacks arrays along the third axis.

    Args:
        tup (sequence of arrays): Arrays to be stacked. Each array is converted
            by :func:`cupy.atleast_3d` before stacking.

    Returns:
        cupy.ndarray: Stacked array.

    .. seealso:: :func:`numpy.dstack`

    """
    # Convert every array on its own. ``cupy.atleast_3d(*tup)`` returns a bare
    # array (not a list) when ``tup`` holds exactly one element, and
    # ``concatenate`` would then treat that array itself as the sequence of
    # arrays to join.
    return concatenate([cupy.atleast_3d(m) for m in tup], 2)
def dstack(tup):
    """Stacks arrays along the third axis.

    Args:
        tup (sequence of arrays): Arrays to be stacked. Each array is converted
            by :func:`cupy.atleast_3d` before stacking.

    Returns:
        cupy.ndarray: Stacked array.

    .. seealso:: :func:`numpy.dstack`

    """
    # Convert every array on its own. ``cupy.atleast_3d(*tup)`` returns a bare
    # array (not a list) when ``tup`` holds exactly one element, and
    # ``concatenate`` would then treat that array itself as the sequence of
    # arrays to join.
    return concatenate([cupy.atleast_3d(m) for m in tup], 2)
def clusterSingleBatches(ctx, sanity_plots=False, plot_widgets=None, plot_pos=0):
    """Compute an ordering of the batches according to drift.

    For each batch, spikes are extracted as threshold crossings and clustered
    with k-means. The resulting cluster means are then compared for all pairs
    of batches, and a dissimilarity score is assigned to each pair. The matrix
    of dissimilarity scores is then re-ordered so that low dissimilarity is
    along the diagonal.

    Args:
        ctx: context object. This function reads ``ctx.params``, ``ctx.probe``,
            ``ctx.raw_data`` and ``ctx.intermediate`` (``Nbatch`` and ``proc``).
        sanity_plots: when truthy, the dissimilarity matrices are passed to
            ``plot_dissimilarity_matrices`` before returning.
        plot_widgets: indexable collection of plot targets; must not be
            ``None`` when ``sanity_plots`` is set (checked with an assert).
        plot_pos: index into ``plot_widgets`` of the target to draw on.

    Returns:
        When ``params.reorder`` is falsy: the tuple ``(iorig, None, None)``
        with ``iorig = np.arange(Nbatch)``.
        Otherwise: ``Bunch(iorig=..., ccb0=..., ccbsort=...)`` where ``ccb0``
        is the z-scored, symmetrized dissimilarity matrix and ``iorig`` /
        ``ccbsort`` are the outputs of ``sortBatches2(ccb0)``.

    NOTE(review): the two return paths have different types (tuple vs Bunch);
    confirm that callers handle both.
    """
    Nbatch = ctx.intermediate.Nbatch
    params = ctx.params
    probe = ctx.probe
    raw_data = ctx.raw_data  # NOTE(review): not used below
    ir = ctx.intermediate
    proc = ir.proc

    if not params.reorder:
        # if reordering is turned off, return consecutive order
        iorig = np.arange(Nbatch)
        return iorig, None, None

    nPCs = params.nPCs
    # number of k-means clusters per batch: half the channel count, rounded up
    Nfilt = ceil(probe.Nchan / 2)

    # extract PCA waveforms pooled over channels
    wPCA = extractPCfromSnippets(proc, probe=probe, params=params, Nbatch=Nbatch)

    Nchan = probe.Nchan
    # TODO: move_to_config
    niter = 10  # iterations for k-means. we won't run it to convergence to save time

    nBatches = Nbatch
    NchanNear = min(Nchan, 2 * 8 + 1)

    # initialize big arrays on the GPU to hold the results from each batch
    # this holds the unit norm templates
    Ws = cp.zeros((nPCs, NchanNear, Nfilt, nBatches), dtype=np.float32, order='F')
    # this holds the scalings
    mus = cp.zeros((Nfilt, nBatches), dtype=np.float32, order='F')
    # this holds the number of spikes for that cluster
    ns = cp.zeros((Nfilt, nBatches), dtype=np.float32, order='F')
    # this holds the center channel for each template
    Whs = ones((Nfilt, nBatches), dtype=np.int32, order='F')

    i0 = 0  # NOTE(review): incremented per batch below but never read
    # TODO: move_to_config
    NrankPC = 3  # I am not sure if this gets used, but it goes into the function

    # return an array of closest channels for each channel
    iC = getClosestChannels(probe, params.sigmaMask, NchanNear)[0]

    for ibatch in tqdm(range(nBatches), desc="Clustering spikes"):
        enough_spikes = False

        # extract spikes using PCA waveforms
        # (the batch index handed over is capped at nBatches - 2)
        uproj, call = extractPCbatch2(proc, params, probe, wPCA, min(nBatches - 2, ibatch), iC,
                                      Nbatch)

        if cp.sum(cp.isnan(uproj)) > 0:
            break  # I am not sure what case this safeguards against....

        # k-means is only run when the batch has more spikes than clusters
        if uproj.shape[1] > Nfilt:
            enough_spikes = True

            # this initializes the k-means
            W, mu, Wheights, irand = initializeWdata2(call, uproj, Nchan, nPCs, Nfilt, iC)

            # Params is a whole bunch of parameters sent to the C++ scripts inside a float64 vector
            Params = [
                uproj.shape[1], NrankPC, Nfilt, 0, W.shape[0], 0, NchanNear, Nchan
            ]

            for i in range(niter):

                Wheights = Wheights.reshape((1, 1, -1), order='F')
                # iC stays 3D from here on, including in the distance loop below
                iC = cp.atleast_3d(iC)

                # we only compute distances to clusters on the same channels
                # this tells us which spikes and which clusters might match
                iMatch = cp.min(cp.abs(iC - Wheights), axis=0) < .1

                # get iclust and update W
                # CUDA script to efficiently compute distances for pairs in which iMatch is 1
                dWU, iclust, dx, nsp, dV = mexClustering2(
                    Params, uproj, W, mu, call, iMatch, iC)

                dWU = dWU / (1e-5 + nsp.T)  # divide the cumulative waveform by the number of spikes

                mu = cp.sqrt(cp.sum(dWU**2, axis=0))  # norm of cluster template
                W = dWU / (1e-5 + mu)  # unit normalize templates

                W = W.reshape((nPCs, Nchan, Nfilt), order='F')
                nW = W[0, ...]**2  # compute best channel from the square of the first PC feature
                W = W.reshape((Nchan * nPCs, Nfilt), order='F')

                Wheights = cp.argmax(nW, axis=0)  # the new best channel of each cluster template

            # carefully keep track of cluster templates in dense format
            W = W.reshape((nPCs, Nchan, Nfilt), order='F')
            W0 = cp.zeros((nPCs, NchanNear, Nfilt), dtype=np.float32, order='F')
            for t in range(Nfilt):
                # keep only the channels listed in iC for this template's best channel
                W0[..., t] = W[:, iC[:, Wheights[t]], t].squeeze()

            # I don't really know why this needs another normalization
            W0 = W0 / (1e-5 + cp.sum(cp.sum(W0**2, axis=0)[np.newaxis, ...], axis=1)**.5)

        # NOTE(review): the original comment here said that a batch without
        # enough spikes "gets the cluster templates of the previous batch",
        # but nothing is written for such a batch, so its slices of Ws, mus,
        # ns and Whs keep their initial values.
        if enough_spikes:
            Ws[..., ibatch] = W0
            mus[:, ibatch] = mu
            ns[:, ibatch] = nsp
            Whs[:, ibatch] = Wheights.astype(np.int32)
        else:
            logger.warning('Data batch #%d only had %d spikes.',
                           ibatch, uproj.shape[1])

        i0 = i0 + Nfilt

    # another one of these Params variables transporting parameters to the C++ code
    # NOTE(review): W is only bound once some batch had enough spikes; if none
    # did, this line raises NameError.
    Params = [1, NrankPC, Nfilt, 0, W.shape[0], 0, NchanNear, Nchan]
    # the total number of templates is the number of templates per batch times the number of batches
    Params[0] = Ws.shape[2] * Ws.shape[3]

    # initialize dissimilarity matrix
    ccb = cp.zeros((nBatches, nBatches), dtype=np.float32, order='F')

    for ibatch in tqdm(range(nBatches), desc="Computing distances"):
        # for every batch, compute in parallel its dissimilarity to ALL other batches
        Wh0 = Whs[:, ibatch]  # this one is the primary batch
        W0 = Ws[..., ibatch]  # NOTE(review): not used below
        mu = mus[..., ibatch]

        # embed the templates from the primary batch back into a full, sparse representation
        W = cp.zeros((nPCs, Nchan, Nfilt), dtype=np.float32, order='F')
        for t in range(Nfilt):
            W[:, iC[:, Wh0[t]], t] = cp.atleast_3d(Ws[:, :, t, ibatch])

        # pairs of templates that live on the same channels are potential "matches"
        # TODO: move_to_config? is the 0.1 here important?
        iMatch = cp.min(cp.abs(iC - Wh0.reshape((1, 1, -1), order='F')), axis=0) < .1

        # compute dissimilarities for iMatch = 1
        iclust, ds = mexDistances2(Params, Ws, W, iMatch, iC, Whs, mus, mu)

        # ds are squared Euclidean distances
        ds = ds.reshape((Nfilt, -1), order='F')  # this should just be an Nfilt-long vector
        ds = cp.maximum(0, ds)

        # weigh the distances according to number of spikes in cluster
        ccb[ibatch, :] = cp.mean(cp.sqrt(ds) * ns, axis=0) / cp.mean(ns, axis=0)

    # ccb = cp.asnumpy(ccb)
    # some normalization steps are needed: zscoring, and symmetrizing ccb
    ccb0 = zscore(ccb, axis=0)
    ccb0 = ccb0 + ccb0.T

    # sort by manifold embedding algorithm
    # iorig is the sorting of the batches
    # ccbsort is the resorted matrix (useful for diagnosing drift)
    ccbsort, iorig = sortBatches2(ccb0)

    logger.info("Finished clustering.")

    if sanity_plots:
        assert plot_widgets is not None, "if sanity_plots is set, then plot_widgets cannot be None"
        plot_dissimilarity_matrices(ccb0, ccbsort, plot_widgets[plot_pos])

    return Bunch(iorig=iorig, ccb0=ccb0, ccbsort=ccbsort)
def learnAndSolve8b(ctx):
    """This is the main optimization. Takes the longest time and uses the GPU heavily.

    The batches are visited following a fixed schedule (``irounds``): a
    learning phase over the first ``niter - nBatches`` steps, during which
    templates are updated, triaged and added, followed by a final extraction
    pass over all batches during which spikes and their features are recorded.

    Args:
        ctx: context object. This function reads ``ctx.params``, ``ctx.probe``
            and ``ctx.intermediate`` (``Nbatch``, ``proc``, ``iorig``), and
            uses ``ctx.path`` to locate the ``fW`` and ``fWpc`` feature files.

    Returns:
        Bunch with the keys ``wPCA``, ``wTEMP``, ``st3``, ``simScore``,
        ``iNeigh``, ``iNeighPC``, ``WA``, ``UA``, ``W``, ``U``, ``dWU``,
        ``mu``, ``W_a``, ``W_b``, ``U_a`` and ``U_b``.

    Raises:
        ValueError: if the batch schedule does not reach the middle batch
            (``nhalf``) at the last step of the learning phase.

    Side effects:
        Sets ``W``, ``dWU``, ``U``, ``mu``, ``Wraw``, ``WA``, ``UA``, ``muA``,
        ``st3``, ``simScore``, ``iNeigh`` and ``iNeighPC`` on
        ``ctx.intermediate``, and appends per-spike features to the two
        ``LargeArrayWriter`` files.
    """
    Nbatch = ctx.intermediate.Nbatch
    params = ctx.params
    probe = ctx.probe
    ir = ctx.intermediate
    proc = ir.proc

    iorig = ir.iorig

    # TODO: move_to_config
    NrankPC = 6  # this one is the rank of the PCs, used to detect spikes with threshold crossings
    Nrank = 3  # this one is the rank of the templates

    wTEMP, wPCA = extractTemplatesfromSnippets(proc=proc,
                                               probe=probe,
                                               params=params,
                                               Nbatch=Nbatch,
                                               nPCs=NrankPC)

    # move these to the GPU
    wPCA = cp.asarray(wPCA[:, :Nrank], dtype=np.float32, order='F')
    wTEMP = cp.asarray(wTEMP, dtype=np.float32, order='F')
    wPCAd = cp.asarray(wPCA, dtype=np.float64, order='F')  # convert to double for extra precision

    nt0 = params.nt0
    nt0min = params.nt0min
    nBatches = Nbatch
    NT = params.NT
    Nfilt = params.Nfilt
    Nchan = probe.Nchan

    # two variables for the same thing? number of nearest channels to each primary channel
    # TODO: unclear - let's fix this
    NchanNear = min(probe.Nchan, 32)
    Nnearest = min(probe.Nchan, 32)

    # decay of gaussian spatial mask centered on a channel
    sigmaMask = params.sigmaMask

    # start sample of every batch (in units of NT samples)
    batchstart = list(range(0, NT * nBatches + 1, NT))

    # find the closest NchanNear channels, and the masks for those channels
    iC, mask, C2C = getClosestChannels(probe, sigmaMask, NchanNear)

    # sorting order for the batches
    isortbatches = iorig
    nhalf = int(ceil(nBatches / 2)) - 1  # halfway point

    # this batch order schedule goes through half of the data forward and backward during the model
    # fitting and then goes through the data symmetrically-out from the center during the final
    # pass
    ischedule = np.concatenate(
        (np.arange(nhalf, nBatches), np.arange(nBatches - 1, nhalf - 1, -1)))
    i1 = np.arange(nhalf - 1, -1, -1)
    i2 = np.arange(nhalf, nBatches)

    irounds = np.concatenate((ischedule, i1, i2))

    niter = irounds.size
    if irounds[niter - nBatches - 1] != nhalf:
        # this check is in here in case I do something weird when I try different schedules
        raise ValueError('Mismatch between number of batches')

    # these two flags are used to keep track of what stage of model fitting we're at
    # flag_final = 0
    flag_resort = 1

    # this is the absolute temporal offset in seconds corresponding to the start of the
    # spike sorted time segment
    t0 = 0  # ceil(params.trange(1) * ops.fs)

    nInnerIter = 60  # this is for SVD for the power iteration

    # schedule of learning rates for the model fitting part
    # starts small and goes high, it corresponds approximately to the number of spikes
    # from the past that were averaged to give rise to the current template
    pmi = cp.exp(
        -1. / cp.linspace(params.momentum[0], params.momentum[1], niter - nBatches))

    Nsum = min(Nchan, 7)  # how many channels to extend out the waveform in mexgetspikes

    # lots of parameters passed into the CUDA scripts
    # (entries 1, 2, 8 and 12 are overwritten further down)
    Params = np.array([
        NT, Nfilt, params.Th[0], nInnerIter, nt0, Nnearest, Nrank, params.lam,
        pmi[0], Nchan, NchanNear, params.nt0min, 1, Nsum, NrankPC, params.Th[0]
    ], dtype=np.float64)

    # W0 has to be ordered like this
    W0 = cp.transpose(
        cp.atleast_3d(cp.asarray(wPCA, dtype=np.float64, order='F')), [0, 2, 1])

    # initialize the list of channels each template lives on
    iList = cp.zeros((Nnearest, Nfilt), dtype=np.int32, order='F')

    # initialize average number of spikes per batch for each template
    nsp = cp.zeros((0, 1), dtype=np.float64, order='F')

    # this flag starts 0, is set to 2 at the end of the learning phase
    Params[12] = 0

    # kernels for subsample alignment
    Ka, Kb = getKernels(params)

    p1 = .95  # decay of nsp estimate in each batch

    ntot = 0  # running total of extracted spikes
    # this keeps track of dropped templates for debugging purposes
    ndrop = np.zeros(2, dtype=np.float32, order='F')

    # this is the minimum firing rate that all templates must maintain, or be dropped
    m0 = params.minFR * params.NT / params.fs

    # allocate variables when switching to extraction phase
    # this holds spike times, clusters and other info per spike
    st3 = []  # cp.zeros((int(1e7), 5), dtype=np.float32, order='F')

    # these ones store features per spike
    # Nnearest is the number of nearest templates to store features for
    fW = LargeArrayWriter(ctx.path('fW', ext='.dat'),
                          dtype=np.float32,
                          shape=(Nnearest, -1))
    # NchanNear is the number of nearest channels to take PC features from
    fWpc = LargeArrayWriter(ctx.path('fWpc', ext='.dat'),
                            dtype=np.float32,
                            shape=(NchanNear, Nrank, -1))

    for ibatch in tqdm(range(niter), desc="Optimizing templates"):
        # korder is the index of the batch at this point in the schedule
        korder = int(irounds[ibatch])
        # k is the index of the batch in absolute terms
        k = int(isortbatches[korder])
        logger.debug("Batch %d/%d, %d templates.", ibatch, niter, Nfilt)

        if ibatch > niter - nBatches - 1 and korder == nhalf:
            # this is required to revert back to the template states in the middle of the
            # batches
            W, dWU = ir.W, ir.dWU
            logger.debug('Reverted back to middle timepoint.')

        if ibatch < niter - nBatches:
            # obtained pm for this batch
            # (after the learning phase, pm keeps its last value)
            Params[8] = float(pmi[ibatch])
            pm = pmi[ibatch] * ones((Nfilt, ), dtype=np.float64, order='F')

        # loading a single batch (same as everywhere)
        offset = Nchan * batchstart[k]
        dat = proc.flat[offset:offset + NT * Nchan].reshape((-1, Nchan), order='F')
        dataRAW = cp.asarray(dat, dtype=np.float32) / params.scaleproc

        if ibatch == 0:
            # only on the first batch, we first get a new set of spikes from the residuals,
            # which in this case is the unmodified data because we start with no templates
            # CUDA function to get spatiotemporal clips from spike detections
            dWU, cmap = mexGetSpikes2(Params, dataRAW, wTEMP, iC)

            dWU = cp.asarray(dWU, dtype=np.float64, order='F')

            # project these into the wPCA waveforms
            dWU = cp.reshape(cp.dot(
                wPCAd, cp.dot(wPCAd.T, dWU.reshape((dWU.shape[0], -1), order='F'))),
                dWU.shape, order='F')

            # initialize the low-rank decomposition with standard waves
            W = W0[:, cp.ones(dWU.shape[2], dtype=np.int32), :]
            Nfilt = W.shape[1]  # update the number of filters/templates
            # initialize the number of spikes for new templates with the minimum allowed value,
            # so it doesn't get thrown back out right away
            nsp = _extend(nsp, 0, Nfilt, m0)
            Params[1] = Nfilt  # update in the CUDA parameters

        if flag_resort:
            # this is a flag to resort the order of the templates according to best peak
            # channel
            # this is important in order to have cohesive memory requests from the GPU RAM
            # max channel (either positive or negative peak)
            iW = cp.argmax(cp.abs(dWU[nt0min - 1, :, :]), axis=0)
            # iW = int32(squeeze(iW))

            isort = cp.argsort(iW)  # sort by max abs channel
            iW = iW[isort]
            W = W[:, isort, :]  # use the ordering to resort all the other template variables
            dWU = dWU[:, :, isort]
            nsp = nsp[isort]

        # decompose dWU by svd of time and space (via covariance matrix of 61 by 61 samples)
        # this uses a "warm start" by remembering the W from the previous iteration
        W, U, mu = mexSVDsmall2(Params, dWU, W, iC, iW, Ka, Kb)

        # UtU is the gram matrix of the spatial components of the low-rank SVDs
        # it tells us which pairs of templates are likely to "interfere" with each other
        # such as when we subtract off a template
        # this needs to change (but I don't know why!)
        UtU, maskU = getMeUtU(iW, iC, mask, Nnearest, Nchan)

        # main CUDA function in the whole codebase. does the iterative template matching
        # based on the current templates, gets features for these templates if requested
        # (featW, featPC),
        # gets scores for the template fits to each spike (vexp), outputs the average of
        # waveforms assigned to each cluster (dWU0),
        # and probably a few more things I forget about
        st0, id0, x0, featW, dWU0, drez, nsp0, featPC, vexp = mexMPnu8(
            Params, dataRAW, U, W, mu, iC, iW, UtU, iList, wPCA)

        logger.debug("%d spikes.", x0.size)

        # Sometimes nsp can get transposed (think this has to do with it being
        # a single element in one iteration, to which elements are added
        # nsp, nsp0, and pm must all be row vectors (Nfilt x 1), so force nsp
        # to be a row vector.
        # nsp = cp.atleast_2d(nsp)
        # nsprow, nspcol = nsp.shape
        # if nsprow < nspcol:
        #     nsp = nsp.T
        nsp = nsp.squeeze()

        # updates the templates as a running average weighted by recency
        # since some clusters have different number of spikes, we need to apply the
        # exp(pm) factor several times, and fexp is the resulting update factor
        # for each template
        fexp = np.exp(nsp0 * cp.log(pm[:Nfilt]))
        fexp = cp.reshape(fexp, (1, 1, -1), order='F')
        dWU = dWU * fexp + (1 - fexp) * (
            dWU0 / cp.reshape(cp.maximum(1, nsp0), (1, 1, -1), order='F'))

        # nsp just gets updated according to the fixed factor p1
        nsp = nsp * p1 + (1 - p1) * nsp0

        if ibatch == niter - nBatches - 1:
            # if we reached this point, we need to disable secondary template updates
            # like dropping, and adding new templates. We need to memorize the state of the
            # templates at this timepoint, and set the processing mode to "extraction and
            # tracking"

            flag_resort = 0  # no need to resort templates by channel any more
            # flag_final = 1  # this is the "final" pass

            # final clean up, triage templates one last time
            W, U, dWU, mu, nsp, ndrop = triageTemplates2(
                params, iW, C2C, W, U, dWU, mu, nsp, ndrop)

            # final number of templates
            Nfilt = W.shape[1]
            Params[1] = Nfilt

            # final covariance matrix between all templates
            WtW, iList = getMeWtW(W, U, Nnearest)

            # iW is the final channel assigned to each template
            iW = cp.argmax(cp.abs(dWU[nt0min - 1, :, :]), axis=0)

            # extract ALL features on the last pass
            Params[12] = 2  # this is a flag to output features (PC and template features)

            # different threshold on last pass?
            Params[2] = params.Th[-1]  # usually the threshold is much lower on the last pass

            # memorize the state of the templates
            logger.debug("Memorized middle timepoint.")
            ir.W, ir.dWU, ir.U, ir.mu = W, dWU, U, mu
            ir.Wraw = cp.zeros((U.shape[0], W.shape[0], U.shape[1]),
                               dtype=np.float64, order='F')
            for n in range(U.shape[1]):
                # temporarily use U rather Urot until I have a chance to test it
                ir.Wraw[:, :, n] = mu[n] * cp.dot(U[:, n, :], W[:, n, :].T)

        if ibatch < niter - nBatches - 1:
            # during the main "learning" phase of fitting a model
            if ibatch % 5 == 0:
                # this drops templates based on spike rates and/or similarities to
                # other templates
                W, U, dWU, mu, nsp, ndrop = triageTemplates2(
                    params, iW, C2C, W, U, dWU, mu, nsp, ndrop)

            Nfilt = W.shape[1]  # update the number of filters
            Params[1] = Nfilt

            # this adds new templates if they are detected in the residual
            dWU0, cmap = mexGetSpikes2(Params, drez, wTEMP, iC)

            if dWU0.shape[2] > 0:
                # new templates need to be integrated into the same format as all templates
                # apply PCA for smoothing purposes
                dWU0 = cp.reshape(cp.dot(
                    wPCAd, cp.dot(
                        wPCAd.T, dWU0.reshape(
                            (dWU0.shape[0], dWU0.shape[1] * dWU0.shape[2]), order='F'))),
                    dWU0.shape, order='F')
                dWU = cp.concatenate((dWU, dWU0), axis=2)

                m = dWU0.shape[2]  # number of newly detected templates
                # initialize temporal components of waveforms
                W = _extend(W, Nfilt, Nfilt + m, W0[:, cp.ones(m, dtype=np.int32), :], axis=1)

                # initialize the number of spikes with the minimum allowed
                nsp = _extend(nsp, Nfilt, Nfilt + m, params.minFR * NT / params.fs)
                # initialize the amplitude of this spike with a lowish number
                mu = _extend(mu, Nfilt, Nfilt + m, 10)

                # if the number of filters exceed the maximum allowed, clip it
                Nfilt = min(params.Nfilt, W.shape[1])
                Params[1] = Nfilt

                W = W[:, :Nfilt, :]  # remove any new filters over the maximum allowed
                dWU = dWU[:, :, :Nfilt]  # remove any new filters over the maximum allowed
                nsp = nsp[:Nfilt]  # remove any new filters over the maximum allowed
                mu = mu[:Nfilt]  # remove any new filters over the maximum allowed

        if ibatch > niter - nBatches - 1:
            # during the final extraction pass, this keeps track of all spikes and features

            # we memorize the spatio-temporal decomposition of the waveforms at this batch
            # this is currently only used in the GUI to provide an accurate reconstruction
            # of the raw data at this time
            ir.WA[..., k] = cp.asnumpy(W)
            ir.UA[..., k] = cp.asnumpy(U)
            ir.muA[..., k] = cp.asnumpy(mu)

            # we carefully assign the correct absolute times to spikes found in this batch
            ioffset = params.ntbuff - 1
            if k == 0:
                ioffset = 0  # the first batch is special (no pre-buffer)

            toff = nt0min + t0 - ioffset + (NT - params.ntbuff) * k
            st = toff + st0

            st30 = np.c_[
                cp.asnumpy(st),  # spike times
                cp.asnumpy(id0),  # spike clusters (0-indexing)
                cp.asnumpy(x0),  # template amplitudes
                cp.asnumpy(vexp),  # residual variance of this spike
                korder * np.ones(st.size),  # batch from which this spike was found
            ]
            # Check the number of spikes.
            assert st30.shape[0] == featW.shape[1] == featPC.shape[2]
            st3.append(st30)
            fW.append(featW)
            fWpc.append(featPC)

            ntot = ntot + x0.size  # keeps track of total number of spikes so far

        if ibatch == niter - nBatches - 1:
            # these next three store the low-d template decompositions
            # (allocated here, then filled during the final extraction pass)
            ir.WA = np.zeros((nt0, Nfilt, Nrank, nBatches), dtype=np.float32, order='F')
            ir.UA = np.zeros((Nchan, Nfilt, Nrank, nBatches), dtype=np.float32, order='F')
            ir.muA = np.zeros((Nfilt, nBatches), dtype=np.float32, order='F')

        if ibatch % 100 == 0:
            # this is some of the relevant diagnostic information to be printed during training
            logger.info(('%d / %d batches, %d units, nspks: %2.4f, mu: %2.4f, '
                         'nst0: %d, merges: %2.4f, %2.4f'),
                        ibatch, niter, Nfilt, nsp.sum(), median(mu), st0.size, *ndrop)

        free_gpu_memory()

    # Close the large array writers and save the JSON metadata files to disk.
    fW.close()
    fWpc.close()

    # just display the total number of spikes
    logger.info("Found %d spikes.", ntot)

    # Save results to the ctx.intermediate object.
    ir.st3 = np.concatenate(st3, axis=0)

    # the similarity score between templates is simply the correlation,
    # taken as the max over several consecutive time delays
    ir.simScore = cp.asnumpy(cp.max(WtW, axis=2))

    # NOTE: these are now already saved by LargeArrayWriter
    # fWa = np.concatenate(fW, axis=-1)
    # fWpca = np.concatenate(fWpc, axis=-1)

    # the template features are stored in cProj, like in Kilosort1
    # ir.cProj = fWa.T
    # the neighboring templates indices are stored in iNeigh
    ir.iNeigh = cp.asnumpy(iList)

    # permute the PC projections in the right order
    # ir.cProjPC = np.transpose(fWpca, (2, 1, 0))
    # iNeighPC keeps the indices of the channels corresponding to the PC features
    ir.iNeighPC = cp.asnumpy(iC[:, iW])

    # Number of spikes.
    assert ir.st3.shape[0] == fW.shape[-1] == fWpc.shape[-1]

    # this whole next block is just done to compress the compressed templates
    # we separately svd the time components of each template, and the spatial components
    # this also requires a careful decompression function, available somewhere in the GUI code
    nKeep = min(Nchan * 3, 20)  # how many PCs to keep
    W_a = np.zeros((nt0 * Nrank, nKeep, Nfilt), dtype=np.float32)
    W_b = np.zeros((nBatches, nKeep, Nfilt), dtype=np.float32)
    U_a = np.zeros((Nchan * Nrank, nKeep, Nfilt), dtype=np.float32)
    U_b = np.zeros((nBatches, nKeep, Nfilt), dtype=np.float32)

    for j in tqdm(range(Nfilt), desc='Compressing templates'):
        # do this for every template separately
        WA = np.reshape(ir.WA[:, j, ...], (-1, nBatches), order='F')
        # svd on the GPU was faster for this, but the Python randomized CPU version
        # might be faster still
        # WA = gpuArray(WA)
        A, B, C = svdecon_cpu(WA)
        # W_a times W_b results in a reconstruction of the time components
        W_a[:, :, j] = np.dot(A[:, :nKeep], B[:nKeep, :nKeep])
        W_b[:, :, j] = C[:, :nKeep]

        UA = np.reshape(ir.UA[:, j, ...], (-1, nBatches), order='F')
        # UA = gpuArray(UA)
        A, B, C = svdecon_cpu(UA)
        # U_a times U_b results in a reconstruction of the spatial components
        U_a[:, :, j] = np.dot(A[:, :nKeep], B[:nKeep, :nKeep])
        U_b[:, :, j] = C[:, :nKeep]

    logger.info('Finished compressing time-varying templates.')

    return Bunch(
        wPCA=wPCA[:, :Nrank],
        wTEMP=wTEMP,
        st3=ir.st3,
        simScore=ir.simScore,
        # cProj=ir.cProj,
        # cProjPC=ir.cProjPC,
        iNeigh=ir.iNeigh,
        iNeighPC=ir.iNeighPC,
        WA=ir.WA,
        UA=ir.UA,
        W=ir.W,
        U=ir.U,
        dWU=ir.dWU,
        mu=ir.mu,
        W_a=W_a,
        W_b=W_b,
        U_a=U_a,
        U_b=U_b,
    )
def splitAllClusters(ctx, flag): # I call this algorithm "bimodal pursuit" # split clusters if they have bimodal projections # the strategy is to maximize a bimodality score and find a single vector projection # that maximizes it. If the distribution along that maximal projection crosses a # bimodality threshold, then the cluster is split along that direction # it only uses the PC features for each spike, stored in ir.cProjPC params = ctx.params probe = ctx.probe ir = ctx.intermediate Nchan = ctx.probe.Nchan wPCA = cp.asarray( ir.wPCA ) # use PCA projections to reconstruct templates when we do splits assert wPCA.shape[1] == 3 # Take intermediate arrays from context. st3 = cp.asnumpy(ir.st3_m) cProjPC = ir.cProjPC dWU = ir.dWU # For the following arrays that will be overwritten by this function, try to get # it from a previous call to this function (as it is called twice), otherwise # get it from before (without the _s suffix). W = ir.get('W_s', ir.W) simScore = ir.get('simScore_s', ir.simScore) iNeigh = ir.get('iNeigh_s', ir.iNeigh) iNeighPC = ir.get('iNeighPC_s', ir.iNeighPC) # this is the threshold for splits, and is one of the main parameters users can change ccsplit = params.AUCsplit NchanNear = min(Nchan, 32) Nnearest = min(Nchan, 32) sigmaMask = params.sigmaMask ik = -1 Nfilt = W.shape[1] nsplits = 0 # determine what channels each template lives on iC, mask, C2C = getClosestChannels(probe, sigmaMask, NchanNear) # the waveforms must be aligned to this sample nt0min = params.nt0min # find the peak abs channel for each template iW = np.argmax(np.abs((dWU[nt0min - 1, :, :])), axis=0) # keep track of original cluster for each cluster. starts with all clusters being their # own origin. isplit = np.arange(Nfilt) dt = 1. 
/ 1000 nccg = 0 while ik < Nfilt: if ik % 100 == 0: # periodically write updates logger.info( f'Found {nsplits} splits, checked {ik}/{Nfilt} clusters, nccg {nccg}' ) ik += 1 isp = (st3[:, 1] == ik) # get all spikes from this cluster nSpikes = isp.sum() logger.debug(f"Splitting template {ik}/{Nfilt} with {nSpikes} spikes.") free_gpu_memory() if nSpikes < 300: # TODO: move_to_config # do not split if fewer than 300 spikes (we cannot estimate # cross-correlograms accurately) continue ss = st3[isp, 0] / params.fs # convert to seconds clp0 = cProjPC[isp, :, :] # get the PC projections for these spikes clp0 = cp.asarray(clp0, dtype=cp.float32) # upload to the GPU clp0 = clp0.reshape((clp0.shape[0], -1), order='F') clp = clp0 - mean(clp0, axis=0) # mean center them isp = np.nonzero(isp)[0] # subtract a running average, because the projections are NOT drift corrected clp = clp - my_conv2(clp, 250, 0) # now use two different ways to initialize the bimodal direction # the main script calls this function twice, and does both initializations if flag: u, s, v = svdecon(clp.T) # u, v = -u, -v # change sign for consistency with MATLAB w = u[:, 0] # initialize with the top PC else: w = mean(clp0, axis=0 ) # initialize with the mean of NOT drift-corrected trace w = w / cp.sum(w**2)**0.5 # unit-normalize # initial projections of waveform PCs onto 1D vector x = cp.dot(clp, w) s1 = var( x[x > mean(x)]) # initialize estimates of variance for the first s2 = var(x[ x < mean(x)]) # and second gaussian in the mixture of 1D gaussians mu1 = mean(x[x > mean(x)]) # initialize the means as well mu2 = mean(x[x < mean(x)]) # and the probability that a spike is assigned to the first Gaussian p = mean(x > mean(x)) # initialize matrix of log probabilities that each spike is assigned to the first # or second cluster logp = cp.zeros((nSpikes, 2), order='F') # do 50 pursuit iteration logP = cp.zeros(50) # used to monitor the cost function # TODO: move_to_config - maybe... 
for k in range(50): # for each spike, estimate its probability to come from either Gaussian cluster logp[:, 0] = -1. / 2 * log(s1) - ((x - mu1)**2) / (2 * s1) + log(p) logp[:, 1] = -1. / 2 * log(s2) - ( (x - mu2)**2) / (2 * s2) + log(1 - p) lMax = logp.max(axis=1) logp = logp - lMax[:, cp. newaxis] # subtract the max for floating point accuracy rs = cp.exp(logp) # exponentiate the probabilities pval = cp.log(cp.sum( rs, axis=1)) + lMax # get the normalizer and add back the max logP[k] = mean( pval) # this is the cost function: we can monitor its increase rs = rs / cp.sum( rs, axis=1)[:, cp.newaxis] # normalize so that probabilities sum to 1 p = mean(rs[:, 0]) # mean probability to be assigned to Gaussian 1 # new estimate of mean of cluster 1 (weighted by "responsibilities") mu1 = cp.dot(rs[:, 0], x) / cp.sum(rs[:, 0]) # new estimate of mean of cluster 2 (weighted by "responsibilities") mu2 = cp.dot(rs[:, 1], x) / cp.sum(rs[:, 1]) s1 = cp.dot(rs[:, 0], (x - mu1)**2) / cp.sum( rs[:, 0]) # new estimates of variances s2 = cp.dot(rs[:, 1], (x - mu2)**2) / cp.sum(rs[:, 1]) if (k >= 10) and (k % 2 == 0): # starting at iteration 10, we start re-estimating the pursuit direction # that is, given the Gaussian cluster assignments, and the mean and variances, # we re-estimate w # these equations follow from the model StS = cp.matmul( clp.T, clp * (rs[:, 0] / s1 + rs[:, 1] / s2)[:, cp.newaxis]) / nSpikes StMu = cp.dot( clp.T, rs[:, 0] * mu1 / s1 + rs[:, 1] * mu2 / s2) / nSpikes # this is the new estimate of the best pursuit direction w = cp.linalg.solve(StS.T, StMu) w = w / cp.sum(w**2)**0.5 # which we unit normalize x = cp.dot(clp, w) # these spikes are assigned to cluster 1 ilow = rs[:, 0] > rs[:, 1] # the mean probability of spikes assigned to cluster 1 plow = mean(rs[:, 0][ilow]) phigh = mean(rs[:, 1][~ilow]) # same for cluster 2 # the smallest cluster has this proportion of all spikes nremove = min(mean(ilow), mean(~ilow)) # did this split fix the autocorrelograms? 
# compute the cross-correlogram between spikes in the putative new clusters ilow_cpu = cp.asnumpy(ilow) K, Qi, Q00, Q01, rir = ccg(ss[ilow_cpu], ss[~ilow_cpu], 500, dt) Q12 = (Qi / max(Q00, Q01)).min() # refractoriness metric 1 R = rir.min() # refractoriness metric 2 # if the CCG has a dip, don't do the split. # These thresholds are consistent with the ones from merges. # TODO: move_to_config (or at least a single constant so the are the same as the merges) if (Q12 < 0.25) and (R < 0.05): # if both metrics are below threshold. nccg += 1 # keep track of how many splits were voided by the CCG criterion continue # now decide if the split would result in waveforms that are too similar # the reconstructed mean waveforms for putative cluster 1 # c1 = cp.matmul(wPCA, cp.reshape((mean(clp0[ilow, :], 0), 3, -1), order='F')) c1 = cp.matmul(wPCA, mean(clp0[ilow, :], 0).reshape((3, -1), order='F')) # the reconstructed mean waveforms for putative cluster 2 # c2 = cp.matmul(wPCA, cp.reshape((mean(clp0[~ilow, :], 0), 3, -1), order='F')) c2 = cp.matmul(wPCA, mean(clp0[~ilow, :], 0).reshape((3, -1), order='F')) cc = cp.corrcoef(c1.ravel(), c2.ravel()) # correlation of mean waveforms n1 = sqrt(cp.sum(c1**2)) # the amplitude estimate 1 n2 = sqrt(cp.sum(c2**2)) # the amplitude estimate 2 r0 = 2 * abs((n1 - n2) / (n1 + n2)) # if the templates are correlated, and their amplitudes are similar, stop the split!!! 
# TODO: move_to_config if (cc[0, 1] > 0.9) and (r0 < 0.2): continue # finaly criteria to continue with the split: if the split piece is more than 5% of all # spikes, if the split piece is more than 300 spikes, and if the confidences for # assigning spikes to # both clusters exceeds a preset criterion ccsplit # TODO: move_to_config if (nremove > 0.05) and (min(plow, phigh) > ccsplit) and (min( cp.sum(ilow), cp.sum(~ilow)) > 300): # one cluster stays, one goes Nfilt += 1 # the templates for the splits have been estimated from PC coefficients # (DEV_NOTES) code below involves multiple CuPy arrays changing shape to accomodate # the extra cluster, this could potentially be done more efficiently? dWU = cp.concatenate( (cp.asarray(dWU), cp.zeros((*dWU.shape[:-1], 1), order='F')), axis=2) dWU[:, iC[:, iW[ik]], Nfilt - 1] = c2 dWU[:, iC[:, iW[ik]], ik] = c1 # the temporal components are therefore just the PC waveforms W = cp.asarray(W) W = cp.concatenate((W, cp.transpose(cp.atleast_3d(wPCA), (0, 2, 1))), axis=1) assert W.shape[1] == Nfilt # copy the best channel from the original template iW = cp.asarray(iW) iW = cp.pad(iW, (0, (Nfilt - len(iW))), mode='constant') iW[Nfilt - 1] = iW[ik] assert iW.shape[0] == Nfilt # copy the provenance index to keep track of splits isplit = cp.asarray(isplit) isplit = cp.pad(isplit, (0, (Nfilt - len(isplit))), mode='constant') isplit[Nfilt - 1] = isplit[ik] assert isplit.shape[0] == Nfilt st3[isp[ilow_cpu], 1] = Nfilt - 1 # overwrite spike indices with the new index # copy similarity scores from the original simScore = cp.asarray(simScore) simScore = cp.pad(simScore, (0, (Nfilt - simScore.shape[0])), mode='constant') simScore[:, Nfilt - 1] = simScore[:, ik] simScore[Nfilt - 1, :] = simScore[ik, :] # copy similarity scores from the original simScore[ik, Nfilt - 1] = 1 # set the similarity with original to 1 simScore[Nfilt - 1, ik] = 1 # set the similarity with original to 1 assert simScore.shape == (Nfilt, Nfilt) # copy neighbor template 
list from the original iNeigh = cp.asarray(iNeigh) iNeigh = cp.pad(iNeigh, ((0, 0), (0, (Nfilt - iNeigh.shape[1]))), mode='constant') iNeigh[:, Nfilt - 1] = iNeigh[:, ik] assert iNeigh.shape[1] == Nfilt # copy neighbor channel list from the original iNeighPC = cp.asarray(iNeighPC) iNeighPC = cp.pad(iNeighPC, ((0, 0), (0, (Nfilt - iNeighPC.shape[1]))), mode='constant') iNeighPC[:, Nfilt - 1] = iNeighPC[:, ik] assert iNeighPC.shape[1] == Nfilt # try this cluster again # the cluster piece that stays at this index needs to be tested for splits again # before proceeding ik -= 1 # the piece that became a new cluster will be tested again when we get to the end # of the list nsplits += 1 # keep track of how many splits we did # pbar.update(ik) # pbar.close() logger.info(f'Finished splitting. Found {nsplits} splits, checked ' f'{ik}/{Nfilt} clusters, nccg {nccg}') Nfilt = W.shape[1] # new number of templates Nrank = 3 Nchan = probe.Nchan Params = cp.array( [ 0, Nfilt, 0, 0, W.shape[0], Nnearest, Nrank, 0, 0, Nchan, NchanNear, nt0min, 0 ], dtype=cp.float64) # make a new Params to pass on parameters to CUDA # we need to re-estimate the spatial profiles # we get the time upsampling kernels again Ka, Kb = getKernels(params) # we run SVD W, U, mu = mexSVDsmall2(Params, dWU, W, iC, iW, Ka, Kb) # we re-compute similarity scores between templates WtW, iList = getMeWtW(W.astype(cp.float32), U.astype(cp.float32), Nnearest) # ir.iList = iList # over-write the list of nearest templates isplit = simScore == 1 # overwrite the similarity scores of clusters with same parent simScore = WtW.max(axis=2) simScore[isplit] = 1 # 1 means they come from the same parent iNeigh = iList[:, :Nfilt] # get the new neighbor templates iNeighPC = iC[:, iW[:Nfilt]] # get the new neighbor channels # for Phy, we need to pad the spikes with zeros so the spikes are aligned to the center of # the window Wphy = cp.concatenate((cp.zeros((1 + nt0min, Nfilt, Nrank), order='F'), W), axis=0) # ir.isplit = isplit # 
keep track of origins for each cluster return Bunch( st3_s=st3, W_s=W, U_s=U, mu_s=mu, simScore_s=simScore, iNeigh_s=iNeigh, iNeighPC_s=iNeighPC, Wphy=Wphy, iList=iList, isplit=isplit, )
def random_walker(data, labels, beta=130, mode='cg_j', tol=1.e-3, copy=True,
                  multichannel=False, return_full_prob=False, spacing=None,
                  *, prob_tol=1e-3):
    """Random walker algorithm for segmentation from markers.

    Random walker algorithm is implemented for gray-level or multichannel
    images.

    Parameters
    ----------
    data : array_like
        Image to be segmented in phases. Gray-level `data` can be two- or
        three-dimensional; multichannel data can be three- or four-
        dimensional (multichannel=True) with the highest dimension denoting
        channels. Data spacing is assumed isotropic unless the `spacing`
        keyword argument is used.
    labels : array of ints, of same shape as `data` without channels dimension
        Array of seed markers labeled with different positive integers
        for different phases. Zero-labeled pixels are unlabeled pixels.
        Negative labels correspond to inactive pixels that are not taken
        into account (they are removed from the graph). If labels are not
        consecutive integers, the labels array will be transformed so that
        labels are consecutive. In the multichannel case, `labels` should have
        the same shape as a single channel of `data`, i.e. without the final
        dimension denoting channels.
    beta : float, optional
        Penalization coefficient for the random walker motion
        (the greater `beta`, the more difficult the diffusion).
    mode : string, available options {'cg', 'cg_j', 'cg_mg', 'bf'}
        Mode for solving the linear system in the random walker algorithm.
        `None` also passes validation and is forwarded unchanged to the
        solver.

        - 'bf' (brute force): an LU factorization of the Laplacian is
          computed. This is fast for small images (<1024x1024), but very slow
          and memory-intensive for large images (e.g., 3-D volumes).
        - 'cg' (conjugate gradient): the linear system is solved iteratively
          using the Conjugate Gradient method from scipy.sparse.linalg. This is
          less memory-consuming than the brute force method for large images,
          but it is quite slow.
        - 'cg_j' (conjugate gradient with Jacobi preconditioner): the
          Jacobi preconditioner is applied during the Conjugate
          gradient method iterations. This may accelerate the
          convergence of the 'cg' method.
        - 'cg_mg' (conjugate gradient with multigrid preconditioner): a
          preconditioner is computed using a multigrid solver, then the
          solution is computed with the Conjugate Gradient method. This mode
          requires that the pyamg module is installed.
    tol : float, optional
        Tolerance to achieve when solving the linear system using
        the conjugate gradient based modes ('cg', 'cg_j' and 'cg_mg').
    copy : bool, optional
        If copy is False, the `labels` array will be overwritten with
        the result of the segmentation. Use copy=False if you want to
        save on memory.
    multichannel : bool, optional
        If True, input data is parsed as multichannel data (see 'data' above
        for proper input format in this case).
    return_full_prob : bool, optional
        If True, the probability that a pixel belongs to each of the
        labels will be returned, instead of only the most likely
        label.
    spacing : iterable of floats, optional
        Spacing between voxels in each spatial dimension. If `None`, then
        the spacing between pixels/voxels in each dimension is assumed 1.
    prob_tol : float, optional
        Tolerance on the resulting probability to be in the interval [0, 1].
        If the tolerance is not satisfied, a warning is displayed.

    Returns
    -------
    output : ndarray
        * If `return_full_prob` is False, array of ints of same shape
          and data type as `labels`, in which each pixel has been
          labeled according to the marker that reached the pixel first
          by anisotropic diffusion.
        * If `return_full_prob` is True, array of floats of shape
          `(nlabels, labels.shape)`. `output[label_nb, i, j]` is the
          probability that label `label_nb` reaches the pixel `(i, j)`
          first.

    Raises
    ------
    ValueError
        If `mode` is not one of the accepted values, if `spacing` does not
        have one entry per spatial dimension of `labels`, if `data` has an
        unsupported number of dimensions, or if the shapes of `data` and
        `labels` are incompatible.

    See Also
    --------
    skimage.morphology.watershed : watershed segmentation
        A segmentation algorithm based on mathematical morphology
        and "flooding" of regions from markers.

    Notes
    -----
    Multichannel inputs are scaled with all channel data combined. Ensure all
    channels are separately normalized prior to running this algorithm.

    The `spacing` argument is specifically for anisotropic datasets, where
    data points are spaced differently in one or more spatial dimensions.
    Anisotropic data is commonly encountered in medical imaging.

    The algorithm was first proposed in [1]_.

    The algorithm solves the diffusion equation at infinite times for
    sources placed on markers of each phase in turn. A pixel is labeled with
    the phase that has the greatest probability to diffuse first to the pixel.

    The diffusion equation is solved by minimizing x.T L x for each phase,
    where L is the Laplacian of the weighted graph of the image, and x is
    the probability that a marker of the given phase arrives first at a pixel
    by diffusion (x=1 on markers of the phase, x=0 on the other markers, and
    the other coefficients are looked for). Each pixel is attributed the label
    for which it has a maximal value of x. The Laplacian L of the image
    is defined as:

       - L_ii = d_i, the number of neighbors of pixel i (the degree of i)
       - L_ij = -w_ij if i and j are adjacent pixels

    The weight w_ij is a decreasing function of the norm of the local gradient.
    This ensures that diffusion is easier between pixels of similar values.

    When the Laplacian is decomposed into blocks of marked and unmarked
    pixels::

        L = M B.T
            B A

    with first indices corresponding to marked pixels, and then to unmarked
    pixels, minimizing x.T L x for one phase amount to solving::

        A x = - B x_m

    where x_m = 1 on markers of the given phase, and 0 on other markers.
    This linear system is solved in the algorithm using a direct method for
    small images, and an iterative method for larger images.

    References
    ----------
    .. [1] Leo Grady, Random walks for image segmentation, IEEE Trans Pattern
        Anal Mach Intell. 2006 Nov;28(11):1768-83.
        :DOI:`10.1109/TPAMI.2006.233`.

    Examples
    --------
    >>> import cupy as cp
    >>> cp.random.seed(0)
    >>> a = cp.zeros((10, 10)) + 0.2 * cp.random.rand(10, 10)
    >>> a[5:8, 5:8] += 1
    >>> b = cp.zeros_like(a, dtype=cp.int32)
    >>> b[3, 3] = 1  # Marker for first phase
    >>> b[6, 6] = 2  # Marker for second phase
    >>> random_walker(a, b)
    array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
           [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
           [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
           [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
           [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
           [1, 1, 1, 1, 1, 2, 2, 2, 1, 1],
           [1, 1, 1, 1, 1, 2, 2, 2, 1, 1],
           [1, 1, 1, 1, 1, 2, 2, 2, 1, 1],
           [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
           [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=int32)

    """
    # Parse input data
    if mode not in ('cg_mg', 'cg', 'bf', 'cg_j', None):
        raise ValueError(
            f"{mode} is not a valid mode. Valid modes are 'cg_mg',"
            " 'cg', 'cg_j', 'bf' and None")

    # Floating dtype used for the spacing vector: the input's own float
    # type promoted to at least float32, or float64 for non-float input.
    if data.dtype.kind == 'f':
        float_dtype = cp.promote_types(data.dtype, cp.float32)
    else:
        float_dtype = cp.float64

    # Spacing kwarg checks
    if spacing is None:
        spacing = cp.ones(3, dtype=float_dtype)
    elif len(spacing) == labels.ndim:
        if len(spacing) == 2:
            # Need a dummy spacing for singleton 3rd dim
            spacing = cp.r_[spacing, 1.]
        spacing = cp.asarray(spacing, dtype=float_dtype)
    else:
        raise ValueError('Input argument `spacing` incorrect, should be an '
                         'iterable with one number per spatial dimension.')

    # This algorithm expects 4-D arrays of floats, where the first three
    # dimensions are spatial and the final denotes channels. 2-D images have
    # a singleton placeholder dimension added for the third spatial dimension,
    # and single channel images likewise have a singleton added for channels.
    # The following block ensures valid input and coerces it to the correct
    # form.
    if not multichannel:
        if data.ndim not in (2, 3):
            raise ValueError('For non-multichannel input, data must be of '
                             'dimension 2 or 3.')
        if data.shape != labels.shape:
            raise ValueError('Incompatible data and labels shapes.')
        data = cp.atleast_3d(img_as_float(data))[..., cp.newaxis]
    else:
        if data.ndim not in (3, 4):
            raise ValueError('For multichannel input, data must have 3 or 4 '
                             'dimensions.')
        if data.shape[:-1] != labels.shape:
            raise ValueError('Incompatible data and labels shapes.')
        data = img_as_float(data)
        if data.ndim == 3:  # 2D multispectral, needs singleton in 3rd axis
            data = data[:, :, cp.newaxis, :]

    # Remember the caller-visible shape/dtype: `_preprocess` hands back a
    # reshaped, relabeled array and the output must match the input.
    labels_shape = labels.shape
    labels_dtype = labels.dtype

    if copy:
        labels = cp.copy(labels)

    (labels, nlabels, mask,
     inds_isolated_seeds, isolated_values) = _preprocess(labels)

    if isolated_values is None:
        # No non isolated zero valued areas in labels were
        # found. Returning provided labels.
        if return_full_prob:
            # Return the concatenation of the masks of each unique label
            # NOTE(review): the masks are stacked along the LAST axis here,
            # unlike the (nlabels, ...) layout built at the end of this
            # function -- confirm callers expect that.
            unique_labels = cp.unique(labels)
            labels = cp.atleast_3d(labels)
            return cp.concatenate([labels == lab
                                   for lab in unique_labels if lab > 0],
                                  axis=-1)
        return labels

    # Build the linear system (lap_sparse, B)
    lap_sparse, B = _build_linear_system(data, spacing, labels, nlabels, mask,
                                         beta, multichannel)

    # Solve the linear system lap_sparse X = B
    # where X[i, j] is the probability that a marker of label i arrives
    # first at pixel j by anisotropic diffusion.
    X = _solve_linear_system(lap_sparse, B, tol, mode)

    if X.min() < -prob_tol or X.max() > 1 + prob_tol:
        # stacklevel=2 matches the warnings issued by `_preprocess` and
        # attributes the warning to the caller of `random_walker`.
        warn('The probability range is outside [0, 1] given the tolerance '
             '`prob_tol`. Consider decreasing `beta` and/or decreasing '
             '`tol`.', stacklevel=2)

    # Build the output according to return_full_prob value
    # Put back labels of isolated seeds
    labels[inds_isolated_seeds] = isolated_values
    labels = labels.reshape(labels_shape)

    # Pixels that were solved for: unlabeled and not an isolated seed.
    mask = labels == 0
    mask[inds_isolated_seeds] = False

    if return_full_prob:
        out = cp.zeros((nlabels,) + labels_shape)
        for lab, (label_prob, prob) in enumerate(zip(out, X), start=1):
            label_prob[mask] = prob
            label_prob[labels == lab] = 1
    else:
        # Most likely label per solved pixel (labels are 1-based).
        X = cp.argmax(X, axis=0) + 1
        out = labels.astype(labels_dtype)
        out[mask] = X

    return out