Example 1
def get_SpikeSample(dataRAW, row, col, params):
    """
    given a batch of data (time by channels), and some time (row) and channel (col) indices for
    spikes, this function returns the 1D time clips of voltage around those spike times
    """
    nT, nChan = dataRAW.shape

    # times around the peak to consider
    dt = cp.arange(params.nt0)

    # the negativity is expected at nt0min, so we align the detected peaks there
    dt = -params.nt0min + dt

    # temporal indices (awkward way to index into full matrix of data)
    indsT = row + dt[:, np.newaxis] + 1  # broadcasting
    indsC = col

    # anything that's out of bounds just gets set to the limit
    indsC[indsC < 0] = 0
    # only needed for channels, not time (due to the time buffer)
    indsC[indsC >= nChan] = nChan - 1

    indsT = cp.transpose(cp.atleast_3d(indsT), [0, 2, 1])
    indsC = cp.transpose(cp.atleast_3d(indsC), [2, 0, 1])

    # believe it or not, these indices grab just the right time samples for our spikes
    ix = indsT + indsC * nT

    # grab the data and reshape it appropriately (time samples by num spikes)
    clips = dataRAW.T.ravel()[ix[:, 0, :]].reshape((dt.size, row.size), order='F')
    return clips
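A minimal usage sketch of the function above, assuming CuPy is installed; the `params` object here is a hypothetical stand-in for the pipeline's parameter bunch (only `nt0` and `nt0min` are needed). It also checks that the flattened-index trick is equivalent to direct slicing.

import numpy as np
import cupy as cp
from types import SimpleNamespace

# hypothetical parameters: 61-sample clips, trough aligned to sample 20
params = SimpleNamespace(nt0=61, nt0min=20)

# fake batch of data: 1000 time samples on 4 channels
dataRAW = cp.random.randn(1000, 4).astype(cp.float32)

# two detected spikes, given as (time, channel) pairs away from the edges
row = cp.asarray([100, 500])  # peak times
col = cp.asarray([1, 3])      # peak channels

clips = get_SpikeSample(dataRAW, row, col, params)
print(clips.shape)  # (61, 2): one 61-sample clip per spike

# the flattened-index trick above is equivalent to direct slicing per spike
t0 = int(row[0]) - params.nt0min + 1
assert cp.allclose(clips[:, 0], dataRAW[t0:t0 + params.nt0, int(col[0])])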
Example 2
def _preprocess(labels):

    label_values, inv_idx = cp.unique(labels, return_inverse=True)
    if not (label_values == 0).any():
        warn('Random walker only segments unlabeled areas, where '
             'labels == 0. No zero valued areas in labels were '
             'found. Returning provided labels.',
             stacklevel=2)

        return labels, None, None, None, None

    # If some labeled pixels are isolated inside pruned zones, prune them
    # as well and keep the labels for the final output

    null_mask = labels == 0
    pos_mask = labels > 0
    mask = labels >= 0

    fill = ndi.binary_propagation(null_mask, mask=mask)
    isolated = cp.logical_and(pos_mask, cp.logical_not(fill))

    pos_mask[isolated] = False

    # If the array has pruned zones, be sure that no isolated pixels
    # exist between pruned zones (they could not be determined)
    if label_values[0] < 0 or cp.any(isolated):  # synchronize!
        isolated = cp.logical_and(
            cp.logical_not(ndi.binary_propagation(pos_mask, mask=mask)),
            null_mask)

        labels[isolated] = -1
        if cp.all(isolated[null_mask]):
            warn('All unlabeled pixels are isolated, they could not be '
                 'determined by the random walker algorithm.',
                 stacklevel=2)
            return labels, None, None, None, None

        mask[isolated] = False
        mask = cp.atleast_3d(mask)

    else:
        mask = None

    # Reorder label values to have consecutive integers (no gaps)
    zero_idx = cp.searchsorted(label_values, cp.array(0))
    labels = cp.atleast_3d(inv_idx.reshape(labels.shape) - zero_idx)

    nlabels = label_values[zero_idx + 1:].shape[0]

    inds_isolated_seeds = cp.nonzero(isolated)
    isolated_values = labels[inds_isolated_seeds]

    return labels, nlabels, mask, inds_isolated_seeds, isolated_values
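A small sketch of calling `_preprocess` on a toy label image (CuPy assumed). The imports mirror the module-level names the function relies on (`warn`, `ndi`); the exact import paths are assumptions.

import cupy as cp
from warnings import warn           # assumed module-level import
import cupyx.scipy.ndimage as ndi   # assumed module-level import

labels = cp.zeros((5, 5), dtype=cp.int32)
labels[0, 0] = 1    # seed for phase 1
labels[4, 4] = 2    # seed for phase 2
labels[2, 2] = -1   # inactive pixel, removed from the graph

out, nlabels, mask, inds_isolated, isolated_vals = _preprocess(labels)
print(nlabels)      # 2 phases to solve for
print(out.shape)    # (5, 5, 1): labels are promoted to 3-D
print(mask.shape)   # (5, 5, 1): pixels kept in the graph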
Example 3
def dstack(tup):
    """Stacks arrays along the third axis.

    Args:
        tup (sequence of arrays): Arrays to be stacked. Each array is converted
            by :func:`cupy.atleast_3d` before stacking.

    Returns:
        cupy.ndarray: Stacked array.

    .. seealso:: :func:`numpy.dstack`

    """
    return concatenate(cupy.atleast_3d(*tup), 2)
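A quick usage sketch (CuPy assumed): like `numpy.dstack`, 1-D inputs are promoted by `atleast_3d` to shape `(1, N, 1)` before being concatenated along the third axis.

import cupy

a = cupy.array([1, 2, 3])
b = cupy.array([4, 5, 6])

c = cupy.dstack((a, b))
print(c.shape)  # (1, 3, 2)
print(c)        # [[[1 4]
                #   [2 5]
                #   [3 6]]]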
Example 4
def clusterSingleBatches(ctx,
                         sanity_plots=False,
                         plot_widgets=None,
                         plot_pos=0):
    """
    outputs an ordering of the batches according to drift
    for each batch, it extracts spikes as threshold crossings and clusters them with kmeans
    the resulting cluster means are then compared for all pairs of batches, and a dissimilarity
    score is assigned to each pair
    the matrix of dissimilarity scores is then re-ordered so that low dissimilarity is along
    the diagonal
    """
    Nbatch = ctx.intermediate.Nbatch
    params = ctx.params
    probe = ctx.probe
    raw_data = ctx.raw_data
    ir = ctx.intermediate
    proc = ir.proc

    if not params.reorder:
        # if reordering is turned off, return consecutive order
        iorig = np.arange(Nbatch)
        return iorig, None, None

    nPCs = params.nPCs
    Nfilt = ceil(probe.Nchan / 2)

    # extract PCA waveforms pooled over channels
    wPCA = extractPCfromSnippets(proc,
                                 probe=probe,
                                 params=params,
                                 Nbatch=Nbatch)

    Nchan = probe.Nchan
    # TODO: move_to_config
    niter = 10  # iterations for k-means. we won't run it to convergence to save time

    nBatches = Nbatch
    NchanNear = min(Nchan, 2 * 8 + 1)

    # initialize big arrays on the GPU to hold the results from each batch
    # this holds the unit norm templates
    Ws = cp.zeros((nPCs, NchanNear, Nfilt, nBatches),
                  dtype=np.float32,
                  order='F')
    # this holds the scalings
    mus = cp.zeros((Nfilt, nBatches), dtype=np.float32, order='F')
    # this holds the number of spikes for that cluster
    ns = cp.zeros((Nfilt, nBatches), dtype=np.float32, order='F')
    # this holds the center channel for each template
    Whs = ones((Nfilt, nBatches), dtype=np.int32, order='F')

    i0 = 0
    # TODO: move_to_config
    NrankPC = 3  # I am not sure if this gets used, but it goes into the function

    # return an array of closest channels for each channel
    iC = getClosestChannels(probe, params.sigmaMask, NchanNear)[0]

    for ibatch in tqdm(range(nBatches), desc="Clustering spikes"):
        enough_spikes = False

        # extract spikes using PCA waveforms
        uproj, call = extractPCbatch2(proc, params, probe, wPCA,
                                      min(nBatches - 2, ibatch), iC, Nbatch)

        if cp.sum(cp.isnan(uproj)) > 0:
            break  # I am not sure what case this safeguards against....

        if uproj.shape[1] > Nfilt:
            enough_spikes = True

            # this initializes the k-means
            W, mu, Wheights, irand = initializeWdata2(call, uproj, Nchan, nPCs,
                                                      Nfilt, iC)

            # Params is a whole bunch of parameters sent to the C++ scripts inside a float64 vector
            Params = [
                uproj.shape[1], NrankPC, Nfilt, 0, W.shape[0], 0, NchanNear,
                Nchan
            ]

            for i in range(niter):

                Wheights = Wheights.reshape((1, 1, -1), order='F')
                iC = cp.atleast_3d(iC)

                # we only compute distances to clusters on the same channels
                # this tells us which spikes and which clusters might match
                iMatch = cp.min(cp.abs(iC - Wheights), axis=0) < .1

                # get iclust and update W
                # CUDA script to efficiently compute distances for pairs in which iMatch is 1
                dWU, iclust, dx, nsp, dV = mexClustering2(
                    Params, uproj, W, mu, call, iMatch, iC)

                # divide the cumulative waveform by the number of spikes
                dWU = dWU / (1e-5 + nsp.T)

                mu = cp.sqrt(cp.sum(dWU**2, axis=0))  # norm of cluster template
                W = dWU / (1e-5 + mu)  # unit normalize templates

                W = W.reshape((nPCs, Nchan, Nfilt), order='F')
                # compute best channel from the square of the first PC feature
                nW = W[0, ...]**2
                W = W.reshape((Nchan * nPCs, Nfilt), order='F')

                # the new best channel of each cluster template
                Wheights = cp.argmax(nW, axis=0)

            # carefully keep track of cluster templates in dense format
            W = W.reshape((nPCs, Nchan, Nfilt), order='F')
            W0 = cp.zeros((nPCs, NchanNear, Nfilt),
                          dtype=np.float32,
                          order='F')
            for t in range(Nfilt):
                W0[..., t] = W[:, iC[:, Wheights[t]], t].squeeze()
            # I don't really know why this needs another normalization
            W0 = W0 / (1e-5 + cp.sum(cp.sum(W0**2, axis=0)[np.newaxis, ...],
                                     axis=1)**.5)

        # if a batch doesn't have enough spikes, it gets the cluster templates of the previous batch
        if enough_spikes:
            Ws[..., ibatch] = W0
            mus[:, ibatch] = mu
            ns[:, ibatch] = nsp
            Whs[:, ibatch] = Wheights.astype(np.int32)
        else:
            logger.warning('Data batch #%d only had %d spikes.', ibatch,
                           uproj.shape[1])

        i0 = i0 + Nfilt

    # another one of these Params variables transporting parameters to the C++ code
    Params = [1, NrankPC, Nfilt, 0, W.shape[0], 0, NchanNear, Nchan]
    # the total number of templates is the number of templates per batch times the number of batches
    Params[0] = Ws.shape[2] * Ws.shape[3]

    # initialize dissimilarity matrix
    ccb = cp.zeros((nBatches, nBatches), dtype=np.float32, order='F')

    for ibatch in tqdm(range(nBatches), desc="Computing distances"):
        # for every batch, compute in parallel its dissimilarity to ALL other batches
        Wh0 = Whs[:, ibatch]  # this one is the primary batch
        W0 = Ws[..., ibatch]
        mu = mus[..., ibatch]

        # embed the templates from the primary batch back into a full, sparse representation
        W = cp.zeros((nPCs, Nchan, Nfilt), dtype=np.float32, order='F')
        for t in range(Nfilt):
            W[:, iC[:, Wh0[t]], t] = cp.atleast_3d(Ws[:, :, t, ibatch])

        # pairs of templates that live on the same channels are potential "matches"
        # TODO: move_to_config? is the 0.1 here important?
        iMatch = cp.min(cp.abs(iC - Wh0.reshape((1, 1, -1), order='F')),
                        axis=0) < .1

        # compute dissimilarities for iMatch = 1
        iclust, ds = mexDistances2(Params, Ws, W, iMatch, iC, Whs, mus, mu)

        # ds are squared Euclidean distances
        # this should just be an Nfilt-long vector
        ds = ds.reshape((Nfilt, -1), order='F')
        ds = cp.maximum(0, ds)

        # weigh the distances according to number of spikes in cluster
        ccb[ibatch, :] = cp.mean(cp.sqrt(ds) * ns, axis=0) / cp.mean(ns,
                                                                     axis=0)

    # ccb = cp.asnumpy(ccb)
    # some normalization steps are needed: zscoring, and symmetrizing ccb
    ccb0 = zscore(ccb, axis=0)
    ccb0 = ccb0 + ccb0.T

    # sort by manifold embedding algorithm
    # iorig is the sorting of the batches
    # ccbsort is the resorted matrix (useful for diagnosing drift)
    ccbsort, iorig = sortBatches2(ccb0)

    logger.info("Finished clustering.")

    if sanity_plots:
        assert plot_widgets is not None, "if sanity_plots is set, then plot_widgets cannot be None"
        plot_dissimilarity_matrices(ccb0, ccbsort, plot_widgets[plot_pos])

    return Bunch(iorig=iorig, ccb0=ccb0, ccbsort=ccbsort)
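A toy illustration (CuPy assumed) of the broadcasting trick used twice above to build `iMatch`: a channel and a cluster are considered a potential match when the cluster's centre channel appears in that channel's list of nearest neighbours.

import cupy as cp

# iC[:, c] lists the NchanNear channels closest to channel c (toy values)
iC = cp.array([[0, 1, 2, 3],
               [1, 2, 3, 4],
               [2, 3, 4, 5]])      # shape (NchanNear=3, Nchan=4)

Wheights = cp.array([2, 5])        # centre channel of each of 2 clusters

iC3 = cp.atleast_3d(iC)                        # (3, 4, 1)
Wh3 = Wheights.reshape((1, 1, -1), order='F')  # (1, 1, 2)

iMatch = cp.min(cp.abs(iC3 - Wh3), axis=0) < .1  # (Nchan, Nclusters)
print(iMatch)
# [[ True False]
#  [ True False]
#  [ True False]
#  [False  True]]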
Example 5
def learnAndSolve8b(ctx):
    """This is the main optimization. Takes the longest time and uses the GPU heavily."""

    Nbatch = ctx.intermediate.Nbatch
    params = ctx.params
    probe = ctx.probe
    ir = ctx.intermediate
    proc = ir.proc

    iorig = ir.iorig

    # TODO: move_to_config
    NrankPC = 6  # this one is the rank of the PCs, used to detect spikes with threshold crossings
    Nrank = 3  # this one is the rank of the templates

    wTEMP, wPCA = extractTemplatesfromSnippets(proc=proc,
                                               probe=probe,
                                               params=params,
                                               Nbatch=Nbatch,
                                               nPCs=NrankPC)

    # move these to the GPU
    wPCA = cp.asarray(wPCA[:, :Nrank], dtype=np.float32, order='F')
    wTEMP = cp.asarray(wTEMP, dtype=np.float32, order='F')
    wPCAd = cp.asarray(wPCA, dtype=np.float64,
                       order='F')  # convert to double for extra precision

    nt0 = params.nt0
    nt0min = params.nt0min
    nBatches = Nbatch
    NT = params.NT
    Nfilt = params.Nfilt
    Nchan = probe.Nchan

    # two variables for the same thing? number of nearest channels to each primary channel
    # TODO: unclear - let's fix this
    NchanNear = min(probe.Nchan, 32)
    Nnearest = min(probe.Nchan, 32)

    # decay of gaussian spatial mask centered on a channel
    sigmaMask = params.sigmaMask

    batchstart = list(range(0, NT * nBatches + 1, NT))

    # find the closest NchanNear channels, and the masks for those channels
    iC, mask, C2C = getClosestChannels(probe, sigmaMask, NchanNear)

    # sorting order for the batches
    isortbatches = iorig
    nhalf = int(ceil(nBatches / 2)) - 1  # halfway point

    # this batch order schedule goes through half of the data forward and backward during the model
    # fitting and then goes through the data symmetrically-out from the center during the final
    # pass
    ischedule = np.concatenate(
        (np.arange(nhalf, nBatches), np.arange(nBatches - 1, nhalf - 1, -1)))
    i1 = np.arange(nhalf - 1, -1, -1)
    i2 = np.arange(nhalf, nBatches)

    irounds = np.concatenate((ischedule, i1, i2))

    niter = irounds.size
    if irounds[niter - nBatches - 1] != nhalf:
        # this check is in here in case I do something weird when I try different schedules
        raise ValueError('Mismatch between number of batches')

    # these two flags are used to keep track of what stage of model fitting we're at
    # flag_final = 0
    flag_resort = 1

    # this is the absolute temporal offset in seconds corresponding to the start of the
    # spike sorted time segment
    t0 = 0  # ceil(params.trange(1) * ops.fs)

    nInnerIter = 60  # this is for SVD for the power iteration

    # schedule of learning rates for the model fitting part
    # starts small and goes high, it corresponds approximately to the number of spikes
    # from the past that were averaged to give rise to the current template
    pmi = cp.exp(-1. / cp.linspace(params.momentum[0], params.momentum[1],
                                   niter - nBatches))

    # how many channels to extend out the waveform in mexgetspikes
    Nsum = min(Nchan, 7)
    # lots of parameters passed into the CUDA scripts
    Params = np.array([
        NT, Nfilt, params.Th[0], nInnerIter, nt0, Nnearest, Nrank, params.lam,
        pmi[0], Nchan, NchanNear, params.nt0min, 1, Nsum, NrankPC, params.Th[0]
    ],
                      dtype=np.float64)

    # W0 has to be ordered like this
    W0 = cp.transpose(
        cp.atleast_3d(cp.asarray(wPCA, dtype=np.float64, order='F')),
        [0, 2, 1])

    # initialize the list of channels each template lives on
    iList = cp.zeros((Nnearest, Nfilt), dtype=np.int32, order='F')

    # initialize average number of spikes per batch for each template
    nsp = cp.zeros((0, 1), dtype=np.float64, order='F')

    # this flag starts 0, is set to 1 later
    Params[12] = 0

    # kernels for subsample alignment
    Ka, Kb = getKernels(params)

    p1 = .95  # decay of nsp estimate in each batch

    ntot = 0
    # this keeps track of dropped templates for debugging purposes
    ndrop = np.zeros(2, dtype=np.float32, order='F')

    # this is the minimum firing rate that all templates must maintain, or be dropped
    m0 = params.minFR * params.NT / params.fs

    # allocate variables when switching to extraction phase
    # this holds spike times, clusters and other info per spike
    st3 = []  # cp.zeros((int(1e7), 5), dtype=np.float32, order='F')

    # these ones store features per spike
    # Nnearest is the number of nearest templates to store features for
    fW = LargeArrayWriter(ctx.path('fW', ext='.dat'),
                          dtype=np.float32,
                          shape=(Nnearest, -1))
    # NchanNear is the number of nearest channels to take PC features from
    fWpc = LargeArrayWriter(ctx.path('fWpc', ext='.dat'),
                            dtype=np.float32,
                            shape=(NchanNear, Nrank, -1))

    for ibatch in tqdm(range(niter), desc="Optimizing templates"):
        # korder is the index of the batch at this point in the schedule
        korder = int(irounds[ibatch])
        # k is the index of the batch in absolute terms
        k = int(isortbatches[korder])
        logger.debug("Batch %d/%d, %d templates.", ibatch, niter, Nfilt)

        if ibatch > niter - nBatches - 1 and korder == nhalf:
            # this is required to revert back to the template states in the middle of the
            # batches
            W, dWU = ir.W, ir.dWU
            logger.debug('Reverted back to middle timepoint.')

        if ibatch < niter - nBatches:
            # obtain pm for this batch
            Params[8] = float(pmi[ibatch])
            pm = pmi[ibatch] * ones((Nfilt, ), dtype=np.float64, order='F')

        # loading a single batch (same as everywhere)
        offset = Nchan * batchstart[k]
        dat = proc.flat[offset:offset + NT * Nchan].reshape((-1, Nchan),
                                                            order='F')
        dataRAW = cp.asarray(dat, dtype=np.float32) / params.scaleproc

        if ibatch == 0:
            # only on the first batch, we first get a new set of spikes from the residuals,
            # which in this case is the unmodified data because we start with no templates
            # CUDA function to get spatiotemporal clips from spike detections
            dWU, cmap = mexGetSpikes2(Params, dataRAW, wTEMP, iC)

            dWU = cp.asarray(dWU, dtype=np.float64, order='F')

            # project these into the wPCA waveforms
            dWU = cp.reshape(cp.dot(
                wPCAd,
                cp.dot(wPCAd.T, dWU.reshape((dWU.shape[0], -1), order='F'))),
                             dWU.shape,
                             order='F')

            # initialize the low-rank decomposition with standard waves
            W = W0[:, cp.ones(dWU.shape[2], dtype=np.int32), :]
            Nfilt = W.shape[1]  # update the number of filters/templates
            # initialize the number of spikes for new templates with the minimum allowed value,
            # so it doesn't get thrown back out right away
            nsp = _extend(nsp, 0, Nfilt, m0)
            Params[1] = Nfilt  # update in the CUDA parameters

        if flag_resort:
            # this is a flag to resort the order of the templates according to best peak
            # channel
            # this is important in order to have cohesive memory requests from the GPU RAM
            # max channel (either positive or negative peak)
            iW = cp.argmax(cp.abs(dWU[nt0min - 1, :, :]), axis=0)
            # iW = int32(squeeze(iW))

            isort = cp.argsort(iW)  # sort by max abs channel
            iW = iW[isort]
            # use this ordering to resort all the other template variables
            W = W[:, isort, :]
            dWU = dWU[:, :, isort]
            nsp = nsp[isort]

        # decompose dWU by svd of time and space (via covariance matrix of 61 by 61 samples)
        # this uses a "warm start" by remembering the W from the previous iteration
        W, U, mu = mexSVDsmall2(Params, dWU, W, iC, iW, Ka, Kb)

        # UtU is the gram matrix of the spatial components of the low-rank SVDs
        # it tells us which pairs of templates are likely to "interfere" with each other
        # such as when we subtract off a template
        # this needs to change (but I don't know why!)
        UtU, maskU = getMeUtU(iW, iC, mask, Nnearest, Nchan)

        # main CUDA function in the whole codebase. does the iterative template matching
        # based on the current templates, gets features for these templates if requested
        # (featW, featPC),
        # gets scores for the template fits to each spike (vexp), outputs the average of
        # waveforms assigned to each cluster (dWU0),
        # and probably a few more things I forget about
        st0, id0, x0, featW, dWU0, drez, nsp0, featPC, vexp = mexMPnu8(
            Params, dataRAW, U, W, mu, iC, iW, UtU, iList, wPCA)

        logger.debug("%d spikes.", x0.size)

        # Sometimes nsp can get transposed (think this has to do with it being
        # a single element in one iteration, to which elements are added
        # nsp, nsp0, and pm must all be row vectors (Nfilt x 1), so force nsp
        # to be a row vector.
        # nsp = cp.atleast_2d(nsp)
        # nsprow, nspcol = nsp.shape
        # if nsprow < nspcol:
        #     nsp = nsp.T
        nsp = nsp.squeeze()

        # updates the templates as a running average weighted by recency
        # since some clusters have different number of spikes, we need to apply the
        # exp(pm) factor several times, and fexp is the resulting update factor
        # for each template
        fexp = np.exp(nsp0 * cp.log(pm[:Nfilt]))
        fexp = cp.reshape(fexp, (1, 1, -1), order='F')
        dWU = dWU * fexp + (1 - fexp) * (
            dWU0 / cp.reshape(cp.maximum(1, nsp0), (1, 1, -1), order='F'))

        # nsp just gets updated according to the fixed factor p1
        nsp = nsp * p1 + (1 - p1) * nsp0

        if ibatch == niter - nBatches - 1:
            # if we reached this point, we need to disable secondary template updates
            # like dropping, and adding new templates. We need to memorize the state of the
            # templates at this timepoint, and set the processing mode to "extraction and
            # tracking"

            flag_resort = 0  # no need to resort templates by channel any more
            # flag_final = 1  # this is the "final" pass

            # final clean up, triage templates one last time
            W, U, dWU, mu, nsp, ndrop = triageTemplates2(
                params, iW, C2C, W, U, dWU, mu, nsp, ndrop)

            # final number of templates
            Nfilt = W.shape[1]
            Params[1] = Nfilt

            # final covariance matrix between all templates
            WtW, iList = getMeWtW(W, U, Nnearest)

            # iW is the final channel assigned to each template
            iW = cp.argmax(cp.abs(dWU[nt0min - 1, :, :]), axis=0)

            # extract ALL features on the last pass
            # this is a flag to output features (PC and template features)
            Params[12] = 2

            # different threshold on last pass?
            # usually the threshold is much lower on the last pass
            Params[2] = params.Th[-1]

            # memorize the state of the templates
            logger.debug("Memorized middle timepoint.")
            ir.W, ir.dWU, ir.U, ir.mu = W, dWU, U, mu
            ir.Wraw = cp.zeros((U.shape[0], W.shape[0], U.shape[1]),
                               dtype=np.float64,
                               order='F')
            for n in range(U.shape[1]):
                # temporarily use U rather than Urot until I have a chance to test it
                ir.Wraw[:, :, n] = mu[n] * cp.dot(U[:, n, :], W[:, n, :].T)

        if ibatch < niter - nBatches - 1:
            # during the main "learning" phase of fitting a model
            if ibatch % 5 == 0:
                # this drops templates based on spike rates and/or similarities to
                # other templates
                W, U, dWU, mu, nsp, ndrop = triageTemplates2(
                    params, iW, C2C, W, U, dWU, mu, nsp, ndrop)

            Nfilt = W.shape[1]  # update the number of filters
            Params[1] = Nfilt

            # this adds new templates if they are detected in the residual
            dWU0, cmap = mexGetSpikes2(Params, drez, wTEMP, iC)

            if dWU0.shape[2] > 0:
                # new templates need to be integrated into the same format as all templates
                # apply PCA for smoothing purposes
                dWU0 = cp.reshape(cp.dot(
                    wPCAd,
                    cp.dot(
                        wPCAd.T,
                        dWU0.reshape(
                            (dWU0.shape[0], dWU0.shape[1] * dWU0.shape[2]),
                            order='F'))),
                                  dWU0.shape,
                                  order='F')
                dWU = cp.concatenate((dWU, dWU0), axis=2)

                m = dWU0.shape[2]
                # initialize temporal components of waveforms
                W = _extend(W,
                            Nfilt,
                            Nfilt + m,
                            W0[:, cp.ones(m, dtype=np.int32), :],
                            axis=1)

                # initialize the number of spikes with the minimum allowed
                nsp = _extend(nsp, Nfilt, Nfilt + m,
                              params.minFR * NT / params.fs)
                # initialize the amplitude of this spike with a lowish number
                mu = _extend(mu, Nfilt, Nfilt + m, 10)

                # if the number of filters exceed the maximum allowed, clip it
                Nfilt = min(params.Nfilt, W.shape[1])
                Params[1] = Nfilt

                # remove any new filters over the maximum allowed
                W = W[:, :Nfilt, :]
                dWU = dWU[:, :, :Nfilt]
                nsp = nsp[:Nfilt]
                mu = mu[:Nfilt]

        if ibatch > niter - nBatches - 1:
            # during the final extraction pass, this keeps track of all spikes and features

            # we memorize the spatio-temporal decomposition of the waveforms at this batch
            # this is currently only used in the GUI to provide an accurate reconstruction
            # of the raw data at this time
            ir.WA[..., k] = cp.asnumpy(W)
            ir.UA[..., k] = cp.asnumpy(U)
            ir.muA[..., k] = cp.asnumpy(mu)

            # we carefully assign the correct absolute times to spikes found in this batch
            ioffset = params.ntbuff - 1
            if k == 0:
                ioffset = 0  # the first batch is special (no pre-buffer)

            toff = nt0min + t0 - ioffset + (NT - params.ntbuff) * k
            st = toff + st0

            st30 = np.c_[
                cp.asnumpy(st),  # spike times
                cp.asnumpy(id0),  # spike clusters (0-indexing)
                cp.asnumpy(x0),  # template amplitudes
                cp.asnumpy(vexp),  # residual variance of this spike
                korder *
                np.ones(st.size),  # batch from which this spike was found
            ]
            # Check the number of spikes.
            assert st30.shape[0] == featW.shape[1] == featPC.shape[2]
            st3.append(st30)
            fW.append(featW)
            fWpc.append(featPC)

            ntot = ntot + x0.size  # keeps track of total number of spikes so far

        if ibatch == niter - nBatches - 1:
            # these next three store the low-d template decompositions
            ir.WA = np.zeros((nt0, Nfilt, Nrank, nBatches),
                             dtype=np.float32,
                             order='F')
            ir.UA = np.zeros((Nchan, Nfilt, Nrank, nBatches),
                             dtype=np.float32,
                             order='F')
            ir.muA = np.zeros((Nfilt, nBatches), dtype=np.float32, order='F')

        if ibatch % 100 == 0:
            # this is some of the relevant diagnostic information to be printed during training
            logger.info(('%d / %d batches, %d units, nspks: %2.4f, mu: %2.4f, '
                         'nst0: %d, merges: %2.4f, %2.4f'), ibatch, niter,
                        Nfilt, nsp.sum(), median(mu), st0.size, *ndrop)

        free_gpu_memory()

    # Close the large array writers and save the JSON metadata files to disk.
    fW.close()
    fWpc.close()

    # just display the total number of spikes
    logger.info("Found %d spikes.", ntot)

    # Save results to the ctx.intermediate object.
    ir.st3 = np.concatenate(st3, axis=0)

    # the similarity score between templates is simply the correlation,
    # taken as the max over several consecutive time delays
    ir.simScore = cp.asnumpy(cp.max(WtW, axis=2))

    # NOTE: these are now already saved by LargeArrayWriter
    # fWa = np.concatenate(fW, axis=-1)
    # fWpca = np.concatenate(fWpc, axis=-1)

    # the template features are stored in cProj, like in Kilosort1
    # ir.cProj = fWa.T
    # the neighboring template indices are stored in iNeigh
    ir.iNeigh = cp.asnumpy(iList)

    #  permute the PC projections in the right order
    # ir.cProjPC = np.transpose(fWpca, (2, 1, 0))
    # iNeighPC keeps the indices of the channels corresponding to the PC features
    ir.iNeighPC = cp.asnumpy(iC[:, iW])

    # Number of spikes.
    assert ir.st3.shape[0] == fW.shape[-1] == fWpc.shape[-1]

    # this whole next block is just done to compress the compressed templates
    # we separately svd the time components of each template, and the spatial components
    # this also requires a careful decompression function, available somewhere in the GUI code
    nKeep = min(Nchan * 3, 20)  # how many PCs to keep
    W_a = np.zeros((nt0 * Nrank, nKeep, Nfilt), dtype=np.float32)
    W_b = np.zeros((nBatches, nKeep, Nfilt), dtype=np.float32)
    U_a = np.zeros((Nchan * Nrank, nKeep, Nfilt), dtype=np.float32)
    U_b = np.zeros((nBatches, nKeep, Nfilt), dtype=np.float32)

    for j in tqdm(range(Nfilt), desc='Compressing templates'):
        # do this for every template separately
        WA = np.reshape(ir.WA[:, j, ...], (-1, nBatches), order='F')
        # svd on the GPU was faster for this, but the Python randomized CPU version
        # might be faster still
        # WA = gpuArray(WA)
        A, B, C = svdecon_cpu(WA)
        # W_a times W_b results in a reconstruction of the time components
        W_a[:, :, j] = np.dot(A[:, :nKeep], B[:nKeep, :nKeep])
        W_b[:, :, j] = C[:, :nKeep]

        UA = np.reshape(ir.UA[:, j, ...], (-1, nBatches), order='F')
        # UA = gpuArray(UA)
        A, B, C = svdecon_cpu(UA)
        # U_a times U_b results in a reconstruction of the spatial components
        U_a[:, :, j] = np.dot(A[:, :nKeep], B[:nKeep, :nKeep])
        U_b[:, :, j] = C[:, :nKeep]

    logger.info('Finished compressing time-varying templates.')

    return Bunch(
        wPCA=wPCA[:, :Nrank],
        wTEMP=wTEMP,
        st3=ir.st3,
        simScore=ir.simScore,
        # cProj=ir.cProj,
        # cProjPC=ir.cProjPC,
        iNeigh=ir.iNeigh,
        iNeighPC=ir.iNeighPC,
        WA=ir.WA,
        UA=ir.UA,
        W=ir.W,
        U=ir.U,
        dWU=ir.dWU,
        mu=ir.mu,
        W_a=W_a,
        W_b=W_b,
        U_a=U_a,
        U_b=U_b,
    )
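A plain-NumPy illustration of the batch schedule built near the top of the function above: the learning phase walks the second half of the batches forward and then backward, and the extraction phase then covers the first half and the second half out from the centre batch `nhalf`.

import numpy as np
from math import ceil

nBatches = 6
nhalf = int(ceil(nBatches / 2)) - 1  # = 2

ischedule = np.concatenate(
    (np.arange(nhalf, nBatches), np.arange(nBatches - 1, nhalf - 1, -1)))
i1 = np.arange(nhalf - 1, -1, -1)
i2 = np.arange(nhalf, nBatches)
irounds = np.concatenate((ischedule, i1, i2))

print(ischedule)  # [2 3 4 5 5 4 3 2]              learning phase
print(irounds)    # [2 3 4 5 5 4 3 2 1 0 2 3 4 5]  full schedule
niter = irounds.size
assert irounds[niter - nBatches - 1] == nhalf  # same sanity check as above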
Example 6
def splitAllClusters(ctx, flag):
    # I call this algorithm "bimodal pursuit"
    # split clusters if they have bimodal projections
    # the strategy is to maximize a bimodality score and find a single vector projection
    # that maximizes it. If the distribution along that maximal projection crosses a
    # bimodality threshold, then the cluster is split along that direction
    # it only uses the PC features for each spike, stored in ir.cProjPC

    params = ctx.params
    probe = ctx.probe
    ir = ctx.intermediate
    Nchan = ctx.probe.Nchan

    # use PCA projections to reconstruct templates when we do splits
    wPCA = cp.asarray(ir.wPCA)
    assert wPCA.shape[1] == 3

    # Take intermediate arrays from context.
    st3 = cp.asnumpy(ir.st3_m)
    cProjPC = ir.cProjPC
    dWU = ir.dWU

    # For the following arrays that will be overwritten by this function, try to get
    # them from a previous call to this function (as it is called twice), otherwise
    # get them from before (without the _s suffix).
    W = ir.get('W_s', ir.W)
    simScore = ir.get('simScore_s', ir.simScore)
    iNeigh = ir.get('iNeigh_s', ir.iNeigh)
    iNeighPC = ir.get('iNeighPC_s', ir.iNeighPC)

    # this is the threshold for splits, and is one of the main parameters users can change
    ccsplit = params.AUCsplit

    NchanNear = min(Nchan, 32)
    Nnearest = min(Nchan, 32)
    sigmaMask = params.sigmaMask

    ik = -1
    Nfilt = W.shape[1]
    nsplits = 0

    # determine what channels each template lives on
    iC, mask, C2C = getClosestChannels(probe, sigmaMask, NchanNear)

    # the waveforms must be aligned to this sample
    nt0min = params.nt0min
    # find the peak abs channel for each template
    iW = np.argmax(np.abs((dWU[nt0min - 1, :, :])), axis=0)

    # keep track of original cluster for each cluster. starts with all clusters being their
    # own origin.
    isplit = np.arange(Nfilt)
    dt = 1. / 1000
    nccg = 0

    while ik < Nfilt:
        if ik % 100 == 0:
            # periodically write updates
            logger.info(
                f'Found {nsplits} splits, checked {ik}/{Nfilt} clusters, nccg {nccg}'
            )
        ik += 1

        isp = (st3[:, 1] == ik)  # get all spikes from this cluster
        nSpikes = isp.sum()
        logger.debug(f"Splitting template {ik}/{Nfilt} with {nSpikes} spikes.")
        free_gpu_memory()

        if nSpikes < 300:
            # TODO: move_to_config
            # do not split if fewer than 300 spikes (we cannot estimate
            # cross-correlograms accurately)
            continue

        ss = st3[isp, 0] / params.fs  # convert to seconds

        clp0 = cProjPC[isp, :, :]  # get the PC projections for these spikes
        clp0 = cp.asarray(clp0, dtype=cp.float32)  # upload to the GPU
        clp0 = clp0.reshape((clp0.shape[0], -1), order='F')
        clp = clp0 - mean(clp0, axis=0)  # mean center them

        isp = np.nonzero(isp)[0]

        # subtract a running average, because the projections are NOT drift corrected
        clp = clp - my_conv2(clp, 250, 0)

        # now use two different ways to initialize the bimodal direction
        # the main script calls this function twice, and does both initializations

        if flag:
            u, s, v = svdecon(clp.T)
            #    u, v = -u, -v  # change sign for consistency with MATLAB
            w = u[:, 0]  # initialize with the top PC
        else:
            # initialize with the mean of the NOT drift-corrected trace
            w = mean(clp0, axis=0)
            w = w / cp.sum(w**2)**0.5  # unit-normalize

        # initial projections of waveform PCs onto 1D vector
        x = cp.dot(clp, w)
        # initialize estimates of variance for the first and second gaussian
        # in the mixture of 1D gaussians
        s1 = var(x[x > mean(x)])
        s2 = var(x[x < mean(x)])

        mu1 = mean(x[x > mean(x)])  # initialize the means as well
        mu2 = mean(x[x < mean(x)])
        # and the probability that a spike is assigned to the first Gaussian
        p = mean(x > mean(x))

        # initialize matrix of log probabilities that each spike is assigned to the first
        # or second cluster
        logp = cp.zeros((nSpikes, 2), order='F')

        # do 50 pursuit iterations

        logP = cp.zeros(50)  # used to monitor the cost function

        # TODO: move_to_config - maybe...
        for k in range(50):
            # for each spike, estimate its probability to come from either Gaussian cluster
            logp[:, 0] = -1. / 2 * log(s1) - ((x - mu1)**2) / (2 * s1) + log(p)
            logp[:, 1] = -1. / 2 * log(s2) - ((x - mu2)**2) / (2 * s2) + log(1 - p)

            lMax = logp.max(axis=1)
            # subtract the max for floating point accuracy
            logp = logp - lMax[:, cp.newaxis]
            rs = cp.exp(logp)  # exponentiate the probabilities

            # get the normalizer and add back the max
            pval = cp.log(cp.sum(rs, axis=1)) + lMax
            # this is the cost function: we can monitor its increase
            logP[k] = mean(pval)

            # normalize so that probabilities sum to 1
            rs = rs / cp.sum(rs, axis=1)[:, cp.newaxis]

            p = mean(rs[:, 0])  # mean probability to be assigned to Gaussian 1
            # new estimate of mean of cluster 1 (weighted by "responsibilities")
            mu1 = cp.dot(rs[:, 0], x) / cp.sum(rs[:, 0])
            # new estimate of mean of cluster 2 (weighted by "responsibilities")
            mu2 = cp.dot(rs[:, 1], x) / cp.sum(rs[:, 1])

            # new estimates of variances
            s1 = cp.dot(rs[:, 0], (x - mu1)**2) / cp.sum(rs[:, 0])
            s2 = cp.dot(rs[:, 1], (x - mu2)**2) / cp.sum(rs[:, 1])

            if (k >= 10) and (k % 2 == 0):
                # starting at iteration 10, we start re-estimating the pursuit direction
                # that is, given the Gaussian cluster assignments, and the mean and variances,
                # we re-estimate w
                # these equations follow from the model
                StS = cp.matmul(
                    clp.T,
                    clp *
                    (rs[:, 0] / s1 + rs[:, 1] / s2)[:, cp.newaxis]) / nSpikes
                StMu = cp.dot(
                    clp.T, rs[:, 0] * mu1 / s1 + rs[:, 1] * mu2 / s2) / nSpikes

                # this is the new estimate of the best pursuit direction
                w = cp.linalg.solve(StS.T, StMu)
                w = w / cp.sum(w**2)**0.5  # which we unit normalize
                x = cp.dot(clp, w)

        # these spikes are assigned to cluster 1
        ilow = rs[:, 0] > rs[:, 1]
        # the mean probability of spikes assigned to cluster 1
        plow = mean(rs[:, 0][ilow])
        phigh = mean(rs[:, 1][~ilow])  # same for cluster 2
        # the smallest cluster has this proportion of all spikes
        nremove = min(mean(ilow), mean(~ilow))

        # did this split fix the autocorrelograms?
        # compute the cross-correlogram between spikes in the putative new clusters
        ilow_cpu = cp.asnumpy(ilow)
        K, Qi, Q00, Q01, rir = ccg(ss[ilow_cpu], ss[~ilow_cpu], 500, dt)
        Q12 = (Qi / max(Q00, Q01)).min()  # refractoriness metric 1
        R = rir.min()  # refractoriness metric 2

        # if the CCG has a dip, don't do the split.
        # These thresholds are consistent with the ones from merges.
        # TODO: move_to_config (or at least a single constant so they are the same as the merges)
        if (Q12 < 0.25) and (R < 0.05):  # if both metrics are below threshold.
            nccg += 1  # keep track of how many splits were voided by the CCG criterion
            continue

        # now decide if the split would result in waveforms that are too similar
        # the reconstructed mean waveforms for putative cluster 1
        # c1 = cp.matmul(wPCA, cp.reshape((mean(clp0[ilow, :], 0), 3, -1), order='F'))
        c1 = cp.matmul(wPCA,
                       mean(clp0[ilow, :], 0).reshape((3, -1), order='F'))
        # the reconstructed mean waveforms for putative cluster 2
        # c2 = cp.matmul(wPCA, cp.reshape((mean(clp0[~ilow, :], 0), 3, -1), order='F'))
        c2 = cp.matmul(wPCA,
                       mean(clp0[~ilow, :], 0).reshape((3, -1), order='F'))

        cc = cp.corrcoef(c1.ravel(),
                         c2.ravel())  # correlation of mean waveforms
        n1 = sqrt(cp.sum(c1**2))  # the amplitude estimate 1
        n2 = sqrt(cp.sum(c2**2))  # the amplitude estimate 2

        r0 = 2 * abs((n1 - n2) / (n1 + n2))

        # if the templates are correlated, and their amplitudes are similar, stop the split!!!

        # TODO: move_to_config
        if (cc[0, 1] > 0.9) and (r0 < 0.2):
            continue

        # final criteria to continue with the split: if the split piece is more than 5% of all
        # spikes, if the split piece is more than 300 spikes, and if the confidences for
        # assigning spikes to both clusters exceed a preset criterion ccsplit
        # TODO: move_to_config
        if (nremove > 0.05) and (min(plow, phigh) > ccsplit) and (min(
                cp.sum(ilow), cp.sum(~ilow)) > 300):
            # one cluster stays, one goes
            Nfilt += 1

            # the templates for the splits have been estimated from PC coefficients

            # (DEV_NOTES) code below involves multiple CuPy arrays changing shape to accommodate
            # the extra cluster, this could potentially be done more efficiently?

            dWU = cp.concatenate(
                (cp.asarray(dWU), cp.zeros((*dWU.shape[:-1], 1), order='F')),
                axis=2)
            dWU[:, iC[:, iW[ik]], Nfilt - 1] = c2
            dWU[:, iC[:, iW[ik]], ik] = c1

            # the temporal components are therefore just the PC waveforms
            W = cp.asarray(W)
            W = cp.concatenate((W, cp.transpose(cp.atleast_3d(wPCA),
                                                (0, 2, 1))),
                               axis=1)
            assert W.shape[1] == Nfilt

            # copy the best channel from the original template
            iW = cp.asarray(iW)
            iW = cp.pad(iW, (0, (Nfilt - len(iW))), mode='constant')
            iW[Nfilt - 1] = iW[ik]
            assert iW.shape[0] == Nfilt

            # copy the provenance index to keep track of splits
            isplit = cp.asarray(isplit)
            isplit = cp.pad(isplit, (0, (Nfilt - len(isplit))),
                            mode='constant')
            isplit[Nfilt - 1] = isplit[ik]
            assert isplit.shape[0] == Nfilt

            # overwrite spike indices with the new index
            st3[isp[ilow_cpu], 1] = Nfilt - 1

            # copy similarity scores from the original
            simScore = cp.asarray(simScore)
            simScore = cp.pad(simScore, (0, (Nfilt - simScore.shape[0])),
                              mode='constant')
            simScore[:, Nfilt - 1] = simScore[:, ik]
            simScore[Nfilt - 1, :] = simScore[ik, :]
            # set the similarity with the original to 1
            simScore[ik, Nfilt - 1] = 1
            simScore[Nfilt - 1, ik] = 1
            assert simScore.shape == (Nfilt, Nfilt)

            # copy neighbor template list from the original
            iNeigh = cp.asarray(iNeigh)
            iNeigh = cp.pad(iNeigh, ((0, 0), (0, (Nfilt - iNeigh.shape[1]))),
                            mode='constant')
            iNeigh[:, Nfilt - 1] = iNeigh[:, ik]
            assert iNeigh.shape[1] == Nfilt

            # copy neighbor channel list from the original
            iNeighPC = cp.asarray(iNeighPC)
            iNeighPC = cp.pad(iNeighPC,
                              ((0, 0), (0, (Nfilt - iNeighPC.shape[1]))),
                              mode='constant')
            iNeighPC[:, Nfilt - 1] = iNeighPC[:, ik]
            assert iNeighPC.shape[1] == Nfilt

            # try this cluster again
            # the cluster piece that stays at this index needs to be tested for splits again
            # before proceeding
            ik -= 1
            # the piece that became a new cluster will be tested again when we get to the end
            # of the list
            nsplits += 1  # keep track of how many splits we did
    #         pbar.update(ik)
    # pbar.close()

    logger.info(f'Finished splitting. Found {nsplits} splits, checked '
                f'{ik}/{Nfilt} clusters, nccg {nccg}')

    Nfilt = W.shape[1]  # new number of templates
    Nrank = 3
    Nchan = probe.Nchan
    Params = cp.array(
        [
            0, Nfilt, 0, 0, W.shape[0], Nnearest, Nrank, 0, 0, Nchan,
            NchanNear, nt0min, 0
        ],
        dtype=cp.float64)  # make a new Params to pass on parameters to CUDA

    # we need to re-estimate the spatial profiles

    # we get the time upsampling kernels again
    Ka, Kb = getKernels(params)
    # we run SVD
    W, U, mu = mexSVDsmall2(Params, dWU, W, iC, iW, Ka, Kb)

    # we re-compute similarity scores between templates
    WtW, iList = getMeWtW(W.astype(cp.float32), U.astype(cp.float32), Nnearest)
    # ir.iList = iList  # over-write the list of nearest templates

    isplit = simScore == 1  # overwrite the similarity scores of clusters with same parent
    simScore = WtW.max(axis=2)
    simScore[isplit] = 1  # 1 means they come from the same parent

    iNeigh = iList[:, :Nfilt]  # get the new neighbor templates
    iNeighPC = iC[:, iW[:Nfilt]]  # get the new neighbor channels

    # for Phy, we need to pad the spikes with zeros so the spikes are aligned to the center of
    # the window
    Wphy = cp.concatenate((cp.zeros((1 + nt0min, Nfilt, Nrank), order='F'), W),
                          axis=0)

    # ir.isplit = isplit  # keep track of origins for each cluster

    return Bunch(
        st3_s=st3,
        W_s=W,
        U_s=U,
        mu_s=mu,
        simScore_s=simScore,
        iNeigh_s=iNeigh,
        iNeighPC_s=iNeighPC,
        Wphy=Wphy,
        iList=iList,
        isplit=isplit,
    )
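A standalone sketch (plain NumPy for readability) of the 1-D two-Gaussian EM update that the bimodal-pursuit loop above applies to the projections `x`; on clearly bimodal synthetic data it recovers the two modes.

import numpy as np

rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(-2, 0.5, 1000), rng.normal(2, 0.5, 1000)])

# same initialization as above: split around the overall mean
mu1, mu2 = x[x > x.mean()].mean(), x[x < x.mean()].mean()
s1, s2 = x[x > x.mean()].var(), x[x < x.mean()].var()
p = (x > x.mean()).mean()

for _ in range(50):
    # E-step: responsibilities of each Gaussian for each point
    logp = np.stack([
        -0.5 * np.log(s1) - (x - mu1) ** 2 / (2 * s1) + np.log(p),
        -0.5 * np.log(s2) - (x - mu2) ** 2 / (2 * s2) + np.log(1 - p),
    ], axis=1)
    logp -= logp.max(axis=1, keepdims=True)  # for floating point accuracy
    rs = np.exp(logp)
    rs /= rs.sum(axis=1, keepdims=True)

    # M-step: weighted means, variances and mixing proportion
    p = rs[:, 0].mean()
    mu1 = rs[:, 0] @ x / rs[:, 0].sum()
    mu2 = rs[:, 1] @ x / rs[:, 1].sum()
    s1 = rs[:, 0] @ (x - mu1) ** 2 / rs[:, 0].sum()
    s2 = rs[:, 1] @ (x - mu2) ** 2 / rs[:, 1].sum()

print(round(mu1, 2), round(mu2, 2))  # close to +2 and -2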
Example 7
def random_walker(data, labels, beta=130, mode='cg_j', tol=1.e-3, copy=True,
                  multichannel=False, return_full_prob=False, spacing=None,
                  *, prob_tol=1e-3):
    """Random walker algorithm for segmentation from markers.

    Random walker algorithm is implemented for gray-level or multichannel
    images.

    Parameters
    ----------
    data : array_like
        Image to be segmented in phases. Gray-level `data` can be two- or
        three-dimensional; multichannel data can be three- or four-
        dimensional (multichannel=True) with the highest dimension denoting
        channels. Data spacing is assumed isotropic unless the `spacing`
        keyword argument is used.
    labels : array of ints, of same shape as `data` without channels dimension
        Array of seed markers labeled with different positive integers
        for different phases. Zero-labeled pixels are unlabeled pixels.
        Negative labels correspond to inactive pixels that are not taken
        into account (they are removed from the graph). If labels are not
        consecutive integers, the labels array will be transformed so that
        labels are consecutive. In the multichannel case, `labels` should have
        the same shape as a single channel of `data`, i.e. without the final
        dimension denoting channels.
    beta : float, optional
        Penalization coefficient for the random walker motion
        (the greater `beta`, the more difficult the diffusion).
    mode : string, available options {'cg', 'cg_j', 'cg_mg', 'bf'}
        Mode for solving the linear system in the random walker algorithm.

        - 'bf' (brute force): an LU factorization of the Laplacian is
          computed. This is fast for small images (<1024x1024), but very slow
          and memory-intensive for large images (e.g., 3-D volumes).
        - 'cg' (conjugate gradient): the linear system is solved iteratively
          using the Conjugate Gradient method from scipy.sparse.linalg. This is
          less memory-consuming than the brute force method for large images,
          but it is quite slow.
        - 'cg_j' (conjugate gradient with Jacobi preconditioner): the
          Jacobi preconditioner is applied during the Conjugate
          Gradient method iterations. This may accelerate the
          convergence of the 'cg' method.
        - 'cg_mg' (conjugate gradient with multigrid preconditioner): a
          preconditioner is computed using a multigrid solver, then the
          solution is computed with the Conjugate Gradient method. This mode
          requires that the pyamg module is installed.
    tol : float, optional
        Tolerance to achieve when solving the linear system using
        the conjugate gradient based modes ('cg', 'cg_j' and 'cg_mg').
    copy : bool, optional
        If copy is False, the `labels` array will be overwritten with
        the result of the segmentation. Use copy=False if you want to
        save on memory.
    multichannel : bool, optional
        If True, input data is parsed as multichannel data (see 'data' above
        for proper input format in this case).
    return_full_prob : bool, optional
        If True, the probability that a pixel belongs to each of the
        labels will be returned, instead of only the most likely
        label.
    spacing : iterable of floats, optional
        Spacing between voxels in each spatial dimension. If `None`, then
        the spacing between pixels/voxels in each dimension is assumed 1.
    prob_tol : float, optional
        Tolerance on the resulting probability to be in the interval [0, 1].
        If the tolerance is not satisfied, a warning is displayed.

    Returns
    -------
    output : ndarray
        * If `return_full_prob` is False, array of ints of same shape
          and data type as `labels`, in which each pixel has been
          labeled according to the marker that reached the pixel first
          by anisotropic diffusion.
        * If `return_full_prob` is True, array of floats of shape
          `(nlabels, labels.shape)`. `output[label_nb, i, j]` is the
          probability that label `label_nb` reaches the pixel `(i, j)`
          first.

    See Also
    --------
    skimage.morphology.watershed : watershed segmentation
        A segmentation algorithm based on mathematical morphology
        and "flooding" of regions from markers.

    Notes
    -----
    Multichannel inputs are scaled with all channel data combined. Ensure all
    channels are separately normalized prior to running this algorithm.

    The `spacing` argument is specifically for anisotropic datasets, where
    data points are spaced differently in one or more spatial dimensions.
    Anisotropic data is commonly encountered in medical imaging.

    The algorithm was first proposed in [1]_.

    The algorithm solves the diffusion equation at infinite times for
    sources placed on markers of each phase in turn. A pixel is labeled with
    the phase that has the greatest probability to diffuse first to the pixel.

    The diffusion equation is solved by minimizing x.T L x for each phase,
    where L is the Laplacian of the weighted graph of the image, and x is
    the probability that a marker of the given phase arrives first at a pixel
    by diffusion (x=1 on markers of the phase, x=0 on the other markers, and
    the other coefficients are looked for). Each pixel is attributed the label
    for which it has a maximal value of x. The Laplacian L of the image
    is defined as:

       - L_ii = d_i, the number of neighbors of pixel i (the degree of i)
       - L_ij = -w_ij if i and j are adjacent pixels

    The weight w_ij is a decreasing function of the norm of the local gradient.
    This ensures that diffusion is easier between pixels of similar values.

    When the Laplacian is decomposed into blocks of marked and unmarked
    pixels::

        L = M B.T
            B A

    with first indices corresponding to marked pixels, and then to unmarked
    pixels, minimizing x.T L x for one phase amounts to solving::

        A x = - B x_m

    where x_m = 1 on markers of the given phase, and 0 on other markers.
    This linear system is solved in the algorithm using a direct method for
    small images, and an iterative method for larger images.

    References
    ----------
    .. [1] Leo Grady, Random walks for image segmentation, IEEE Trans Pattern
        Anal Mach Intell. 2006 Nov;28(11):1768-83.
        :DOI:`10.1109/TPAMI.2006.233`.

    Examples
    --------
    >>> import cupy as cp
    >>> cp.random.seed(0)
    >>> a = cp.zeros((10, 10)) + 0.2 * cp.random.rand(10, 10)
    >>> a[5:8, 5:8] += 1
    >>> b = cp.zeros_like(a, dtype=cp.int32)
    >>> b[3, 3] = 1  # Marker for first phase
    >>> b[6, 6] = 2  # Marker for second phase
    >>> random_walker(a, b)
    array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
           [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
           [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
           [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
           [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
           [1, 1, 1, 1, 1, 2, 2, 2, 1, 1],
           [1, 1, 1, 1, 1, 2, 2, 2, 1, 1],
           [1, 1, 1, 1, 1, 2, 2, 2, 1, 1],
           [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
           [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=int32)

    """
    # Parse input data
    if mode not in ('cg_mg', 'cg', 'bf', 'cg_j', None):
        raise ValueError(
            "{mode} is not a valid mode. Valid modes are 'cg_mg',"
            " 'cg', 'cg_j', 'bf' and None".format(mode=mode))

    if data.dtype.kind == 'f':
        float_dtype = cp.promote_types(data.dtype, cp.float32)
    else:
        float_dtype = cp.float64

    # Spacing kwarg checks
    if spacing is None:
        spacing = cp.ones(3, dtype=float_dtype)
    elif len(spacing) == labels.ndim:
        if len(spacing) == 2:
            # Need a dummy spacing for singleton 3rd dim
            spacing = cp.r_[spacing, 1.]
        spacing = cp.asarray(spacing, dtype=float_dtype)
    else:
        raise ValueError('Input argument `spacing` incorrect, should be an '
                         'iterable with one number per spatial dimension.')

    # This algorithm expects 4-D arrays of floats, where the first three
    # dimensions are spatial and the final denotes channels. 2-D images have
    # a singleton placeholder dimension added for the third spatial dimension,
    # and single channel images likewise have a singleton added for channels.
    # The following block ensures valid input and coerces it to the correct
    # form.
    if not multichannel:
        if data.ndim not in (2, 3):
            raise ValueError('For non-multichannel input, data must be of '
                             'dimension 2 or 3.')
        if data.shape != labels.shape:
            raise ValueError('Incompatible data and labels shapes.')
        data = cp.atleast_3d(img_as_float(data))[..., cp.newaxis]
    else:
        if data.ndim not in (3, 4):
            raise ValueError('For multichannel input, data must have 3 or 4 '
                             'dimensions.')
        if data.shape[:-1] != labels.shape:
            raise ValueError('Incompatible data and labels shapes.')
        data = img_as_float(data)
        if data.ndim == 3:  # 2D multispectral, needs singleton in 3rd axis
            data = data[:, :, cp.newaxis, :]

    labels_shape = labels.shape
    labels_dtype = labels.dtype

    if copy:
        labels = cp.copy(labels)

    (labels, nlabels, mask,
     inds_isolated_seeds, isolated_values) = _preprocess(labels)

    if isolated_values is None:
        # No non isolated zero valued areas in labels were
        # found. Returning provided labels.
        if return_full_prob:
            # Return the concatenation of the masks of each unique label
            unique_labels = cp.unique(labels)
            labels = cp.atleast_3d(labels)
            return cp.concatenate([labels == lab
                                   for lab in unique_labels if lab > 0],
                                  axis=-1)
        return labels

    # Build the linear system (lap_sparse, B)
    lap_sparse, B = _build_linear_system(data, spacing, labels, nlabels, mask,
                                         beta, multichannel)

    # Solve the linear system lap_sparse X = B
    # where X[i, j] is the probability that a marker of label i arrives
    # first at pixel j by anisotropic diffusion.
    X = _solve_linear_system(lap_sparse, B, tol, mode)

    if X.min() < -prob_tol or X.max() > 1 + prob_tol:
        warn('The probability range is outside [0, 1] given the tolerance '
             '`prob_tol`. Consider decreasing `beta` and/or decreasing '
             '`tol`.')

    # Build the output according to return_full_prob value
    # Put back labels of isolated seeds
    labels[inds_isolated_seeds] = isolated_values
    labels = labels.reshape(labels_shape)

    mask = labels == 0
    mask[inds_isolated_seeds] = False

    if return_full_prob:
        out = cp.zeros((nlabels,) + labels_shape)
        for lab, (label_prob, prob) in enumerate(zip(out, X), start=1):
            label_prob[mask] = prob
            label_prob[labels == lab] = 1
    else:
        X = cp.argmax(X, axis=0) + 1
        out = labels.astype(labels_dtype)
        out[mask] = X

    return out
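A short follow-up to the docstring example above (CuPy assumed), showing the `return_full_prob=True` output: one probability map per label, from which the hard segmentation can be recovered with an argmax.

import cupy as cp

cp.random.seed(0)
a = cp.zeros((10, 10)) + 0.2 * cp.random.rand(10, 10)
a[5:8, 5:8] += 1
b = cp.zeros_like(a, dtype=cp.int32)
b[3, 3] = 1  # marker for first phase
b[6, 6] = 2  # marker for second phase

probs = random_walker(a, b, return_full_prob=True)
print(probs.shape)  # (2, 10, 10): one probability map per label

seg = cp.argmax(probs, axis=0) + 1
print(int(seg[3, 3]), int(seg[6, 6]))  # 1 2: markers keep their own labels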