Exemple #1
0
def mexWtW2(Params, W1, W2, UtU):
    """Launch the ``crossFilter`` kernel of the ``mexWtW2`` module on the GPU.

    Args:
        Params: parameter vector; ``Params[1]`` is read as ``Nfilt`` and
            ``Params[9]`` as ``nt0``. It is also passed to the kernel as float64.
        W1, W2, UtU: arrays moved to the device as float32, Fortran order.

    Returns:
        Device float32 array of shape ``(Nfilt, Nfilt, 2 * nt0 - 1)`` (Fortran
        order) that the kernel writes into.
    """
    code, constants = get_cuda("mexWtW2")
    tile = constants.nblock

    n_filt = int(Params[1])
    n_lags = 2 * int(Params[9]) - 1

    def as_f32(arr):
        return cp.asarray(arr, dtype=np.float32, order="F")

    d_Params = cp.asarray(Params, dtype=np.float64, order="F")
    d_W1 = as_f32(W1)
    d_W2 = as_f32(W2)
    d_UtU = as_f32(UtU)
    d_WtW = cp.zeros((n_filt, n_filt, n_lags), dtype=np.float32, order="F")

    # 1 + floor(Nfilt / nblock) tiles per axis, so a trailing partial tile is covered.
    tiles_per_axis = 1 + int(n_filt // tile)
    cross_filter = cp.RawKernel(code, "crossFilter")
    cross_filter(
        (tiles_per_axis, tiles_per_axis),
        (tile, tile),
        (d_Params, d_W1, d_W2, d_UtU, d_WtW),
    )

    # Drop the input device buffers early; only the result is returned.
    del d_Params, d_W1, d_W2, d_UtU
    return d_WtW
Exemple #2
0
def mexSVDsmall2(Params, dWU, W, iC, iW, Ka, Kb):
    """Run the ``mexSVDsmall2`` kernel pipeline (power-iteration SVD) on the GPU.

    The kernels ``blankdWU``, ``getwtw``, ``getW``, ``getU`` and ``reNormalize``
    are launched in that order, each with a grid of ``Nfilt`` blocks.

    Args:
        Params: parameter vector; reads 1 ``Nfilt``, 4 ``nt0``, 6 ``Nrank`` and
            9 ``Nchan``. It is also passed to every kernel as float64.
        dWU, W, Ka, Kb: moved to the device as float64, Fortran order.
        iC, iW: moved to the device as int32, Fortran order.

    Returns:
        ``(d_W, d_U, d_mu)``. These are the device version of ``W`` that is passed
        to ``getW``/``getU``/``reNormalize``, a float64 ``(Nchan, Nfilt, Nrank)``
        array and a float64 length-``Nfilt`` array.
    """
    code, constants = get_cuda("mexSVDsmall2")
    n_threads = constants.Nthreads

    n_filt, n_t0, n_rank, n_chan = (int(Params[i]) for i in (1, 4, 6, 9))

    def f64(arr):
        return cp.asarray(arr, dtype=np.float64, order="F")

    def i32(arr):
        return cp.asarray(arr, dtype=np.int32, order="F")

    def zeros64(shape):
        return cp.zeros(shape, dtype=np.float64, order="F")

    d_Params = f64(Params)
    d_dWU = f64(dWU)
    d_iC = i32(iC)
    d_iW = i32(iW)
    d_A = f64(Ka)
    d_B = f64(Kb)
    d_U = zeros64((n_chan, n_filt, n_rank))
    d_mu = zeros64(n_filt)
    d_W = f64(W)
    d_wtw = zeros64((n_t0, n_t0, n_filt))
    d_dWUb = zeros64((n_t0, n_chan, n_filt))

    tpS = (n_t0, int(n_threads // n_t0))
    tpK = (n_rank, int(n_threads // n_rank))

    # (kernel name, block shape, kernel arguments), executed in order.
    pipeline = (
        ("blankdWU", tpS, (d_Params, d_dWU, d_iC, d_iW, d_dWUb)),
        # compute dWU * dWU'
        ("getwtw", tpS, (d_Params, d_dWUb, d_wtw)),
        # get W by power svd iterations
        ("getW", (n_t0,), (d_Params, d_wtw, d_W)),
        # compute U by W' * dWU
        ("getU", tpK, (d_Params, d_dWUb, d_W, d_U)),
        # normalize U, get S, get mu, renormalize W
        ("reNormalize", (n_t0,), (d_Params, d_A, d_B, d_W, d_U, d_mu)),
    )
    grid = (n_filt,)
    for kernel_name, block, kernel_args in pipeline:
        cp.RawKernel(code, kernel_name)(grid, block, kernel_args)

    # Release the scratch buffers before returning the results.
    del pipeline, d_wtw, d_Params, d_dWUb
    return d_W, d_U, d_mu
Exemple #3
0
def mexMPnu8(Params, dataRAW, U, W, mu, iC, iW, UtU, iList, wPCA, params):
    """Run the iterative spike-matching pursuit of the ``mexMPnu8`` module on the GPU.

    The data is first filtered with the spatial and temporal templates
    (``spaceFilter``, ``timeFilter``) and ``bestFilter`` is run. The function then
    performs ``int(Params[3])`` rounds of:

    * peak cleanup;
    * optional template-feature extraction;
    * spike subtraction;
    * incremental updates of the filtered data.

    Finally it optionally computes PC features and accumulates spike snippets
    into ``dWU`` with ``average_snips``.

    Args:
        Params: parameter vector, also passed to every kernel as float64. Entries
            read here are:

            * 0 ``NT``, 1 ``Nfilt``, 3 number of pursuit iterations;
            * 4 ``nt0``, 5 ``Nnearest``, 6 ``Nrank``;
            * 9 ``Nchan``, 10 ``NchanU``;
            * 12, the feature switch. ``> 1`` runs ``extractFEAT`` every
              iteration and ``> 0`` runs ``computePCfeatures`` at the end.
        dataRAW, U, W, mu, wPCA: moved to the device as float32 (Fortran order).
        iC, iW, iList: moved to the device as int32 (Fortran order).
        UtU: moved to the device as a boolean array (Fortran order).
        params: settings object. ``stablemode_enabled`` and
            ``deterministicmode_enabled`` select which subtraction kernels run.

    Returns:
        Tuple ``(st, id, y, feat, dWU, draw, nsp, featPC, x)``. The per-spike
        outputs are truncated to ``min(spike count, maxFR)``. The remaining
        outputs are:

        * ``dWU``: the float64 ``(nt0, Nchan, Nfilt)`` buffer passed to
          ``average_snips``;
        * ``draw``: the data buffer after the subtraction kernels have run on it;
        * ``nsp``: the int32 length-``Nfilt`` array passed to ``average_snips``.

    Raises:
        ValueError: deterministic mode requested without stable mode. It is
            raised inside the pursuit loop, so only if at least one iteration runs.
    """
    code, constants = get_cuda("mexMPnu8")
    maxFR = int(constants.maxFR)
    nmaxiter = int(constants.nmaxiter)
    Nthreads = int(constants.Nthreads)

    # TODO: design - let's not use numeric indexing into the Params array. It's much more difficult to read.
    NT = int(Params[0])  # NT = (unsigned int) Params[0];
    Nfilt = int(Params[1])
    n_iter = int(Params[3])
    nt0 = int(Params[4])
    Nnearest = int(Params[5])
    Nrank = int(Params[6])
    Nchan = int(Params[9])
    NchanU = int(Params[10])

    stable = params.stablemode_enabled
    deterministic = params.deterministicmode_enabled

    d_Params = cp.asarray(Params, dtype=np.float64, order="F")

    d_draw = cp.asarray(dataRAW, dtype=np.float32, order="F")
    d_U = cp.asarray(U, dtype=np.float32, order="F")
    d_W = cp.asarray(W, dtype=np.float32, order="F")
    d_mu = cp.asarray(mu, dtype=np.float32, order="F")
    d_iC = cp.asarray(iC, dtype=np.int32, order="F")
    d_iW = cp.asarray(iW, dtype=np.int32, order="F")
    # `np.bool` was only an alias of the builtin `bool` and was removed in
    # NumPy 1.24; the builtin gives the identical dtype on every NumPy version.
    d_UtU = cp.asarray(UtU, dtype=bool, order="F")
    d_iList = cp.asarray(iList, dtype=np.int32, order="F")
    d_wPCA = cp.asarray(wPCA, dtype=np.float32, order="F")

    d_nsp = cp.zeros(Nfilt, dtype=np.int32, order="F")
    d_dWU = cp.zeros((nt0, Nchan, Nfilt), dtype=np.float64, order="F")

    d_dout = cp.zeros((2 * NT, Nfilt), dtype=np.float32, order="F")
    d_data = cp.zeros((NT, Nfilt, Nrank), dtype=np.float32, order="F")
    d_err = cp.zeros(NT, dtype=np.float32, order="F")
    d_ftype = cp.zeros(NT, dtype=np.int32, order="F")
    d_eloss = cp.zeros(NT, dtype=np.float32, order="F")
    d_st = cp.zeros(maxFR, dtype=np.int32, order="F")
    d_id = cp.zeros(maxFR, dtype=np.int32, order="F")
    d_x = cp.zeros(maxFR, dtype=np.float32, order="F")
    d_y = cp.zeros(maxFR, dtype=np.float32, order="F")
    d_z = cp.zeros(maxFR, dtype=np.float32, order="F")

    d_counter = cp.zeros(2, dtype=np.int32, order="F")
    d_count = cp.zeros(nmaxiter, dtype=np.int32, order="F")
    d_feat = cp.zeros((Nnearest, maxFR), dtype=np.float32, order="F")
    d_featPC = cp.zeros((NchanU, Nrank, maxFR), dtype=np.float32, order="F")

    d_idx = cp.zeros(maxFR, dtype=np.int32, order="F")

    # Host mirror of d_counter, refreshed once per pursuit iteration.
    counter = np.zeros(2, dtype=np.int32, order="F")

    tpF = (16, Nnearest)
    tpS = (nt0, 16)
    tpPC = (NchanU, Nrank)
    tpUpdate = (2 * nt0 - 1,)  # block shape shared by the *Update kernels
    gridT = (NT // Nthreads,)

    # Launch configuration of subtract_spikes_v2. Grid sizes must be integers,
    # hence `//` rather than `/`.
    if Nchan < Nthreads:
        sub_grid, sub_block = (1,), (Nchan,)
    else:
        sub_grid, sub_block = (Nchan // Nthreads,), (Nthreads,)

    # CuPy compiles a RawKernel on its first launch, so creating the loop kernels
    # up front (instead of once per iteration) costs nothing for unused ones.
    spaceFilter = cp.RawKernel(code, "spaceFilter")
    timeFilter = cp.RawKernel(code, "timeFilter")
    bestFilter = cp.RawKernel(code, "bestFilter")
    cleanup_spikes = cp.RawKernel(code, "cleanup_spikes")
    extractFEAT = cp.RawKernel(code, "extractFEAT")
    subtract_spikes = cp.RawKernel(code, "subtract_spikes")
    subtract_spikes_v2 = cp.RawKernel(code, "subtract_spikes_v2")
    subtract_spikes_v4 = cp.RawKernel(code, "subtract_spikes_v4")
    spaceFilterUpdate = cp.RawKernel(code, "spaceFilterUpdate")
    spaceFilterUpdate_v2 = cp.RawKernel(code, "spaceFilterUpdate_v2")
    timeFilterUpdate = cp.RawKernel(code, "timeFilterUpdate")
    bestFilterUpdate = cp.RawKernel(code, "bestFilterUpdate")

    # filter the data with the spatial templates
    spaceFilter((Nfilt,), (Nthreads,), (d_Params, d_draw, d_U, d_iC, d_iW, d_data))

    # filter the data with the temporal templates
    timeFilter((Nfilt,), (Nthreads,), (d_Params, d_data, d_W, d_dout))

    # compute the best filter
    bestFilter(gridT, (Nthreads,), (d_Params, d_dout, d_mu, d_err, d_eloss, d_ftype))

    # Stable (non-deterministic) mode subtracts from a float64 copy of the data,
    # which is cast back to float32 after the loop.
    if stable and not deterministic:
        d_draw64 = cp.array(d_draw, dtype=np.float64)

    # loop to find and subtract spikes
    for k in range(n_iter):
        # ignore peaks that are smaller than another nearby peak
        cleanup_spikes(
            gridT,
            (Nthreads,),
            (
                d_Params,
                d_dout,
                d_mu,
                d_err,
                d_eloss,
                d_ftype,
                d_st,
                d_id,
                d_x,
                d_y,
                d_z,
                d_counter,
            ),
        )

        # add new spikes to 2nd counter; clamp at the spike-buffer capacity
        counter[:] = cp.asnumpy(d_counter)
        if counter[0] > maxFR:
            logger.warning("Firing rate limit hit for batch.")
            counter[0] = maxFR
            d_counter[0] = counter[0]
        prev_count = int(counter[1])
        n_new = int(counter[0]) - prev_count

        # extract template features before subtraction
        if Params[12] > 1:
            extractFEAT(
                (64,),
                tpF,
                (d_Params, d_st, d_id, d_counter, d_dout, d_iList, d_mu, d_feat),
            )

        if deterministic:
            if not stable:
                # This is allowed in the MATLAB version runtime but it doesn't really make sense
                # and isn't recommended so let's not anyone fall into the trap without knowing.
                raise ValueError("Stablemode required for deterministic calculations.")
            # Time-sort the spikes found in this iteration, replacing the
            # cudaMemcpy + cdp_simple_quicksort pair of the CUDA implementation.
            d_stSort = d_st[prev_count:prev_count + n_new]
            d_idx[:n_new] = cp.argsort(d_stSort)

            subtract_spikes_v2(
                sub_grid,
                sub_block,
                (d_Params, d_st, d_idx, d_id, d_y, d_counter, d_draw, d_W, d_U),
            )
            spaceFilterUpdate(
                (Nfilt,),
                tpUpdate,
                (
                    d_Params,
                    d_draw,
                    d_U,
                    d_UtU,
                    d_iC,
                    d_iW,
                    d_data,
                    d_st,
                    d_id,
                    d_counter,
                ),
            )
        elif stable:
            subtract_spikes_v4(
                (Nfilt,),
                tpS,
                (d_Params, d_st, d_id, d_y, d_counter, d_draw64, d_W, d_U),
            )
            spaceFilterUpdate_v2(
                (Nfilt,),
                tpUpdate,
                (
                    d_Params,
                    d_draw64,
                    d_U,
                    d_UtU,
                    d_iC,
                    d_iW,
                    d_data,
                    d_st,
                    d_id,
                    d_counter,
                ),
            )
        else:
            # subtract spikes from raw data here
            subtract_spikes(
                (Nfilt,),
                tpS,
                (d_Params, d_st, d_id, d_y, d_counter, d_draw, d_W, d_U),
            )
            # filter the data with the spatial templates
            spaceFilterUpdate(
                (Nfilt,),
                tpUpdate,
                (
                    d_Params,
                    d_draw,
                    d_U,
                    d_UtU,
                    d_iC,
                    d_iW,
                    d_data,
                    d_st,
                    d_id,
                    d_counter,
                ),
            )

        # filter the data with the temporal templates
        timeFilterUpdate(
            (Nfilt,),
            tpUpdate,
            (d_Params, d_data, d_W, d_UtU, d_dout, d_st, d_id, d_counter),
        )

        if n_new > 0:
            bestFilterUpdate(
                (n_new,),
                tpUpdate,
                (
                    d_Params,
                    d_dout,
                    d_mu,
                    d_err,
                    d_eloss,
                    d_ftype,
                    d_st,
                    d_id,
                    d_counter,
                ),
            )

        d_count[k + 1] = d_counter[0]

        # update 1st counter from 2nd counter
        d_counter[1] = d_counter[0]

    if stable and not deterministic:
        d_draw = cp.array(d_draw64, dtype=np.float32)

    # compute PC features from residuals + subtractions
    if Params[12] > 0:
        computePCfeatures = cp.RawKernel(code, "computePCfeatures")
        computePCfeatures(
            (Nfilt,),
            tpPC,
            (
                d_Params,
                d_counter,
                d_draw,
                d_st,
                d_id,
                d_y,
                d_W,
                d_U,
                d_mu,
                d_iW,
                d_iC,
                d_wPCA,
                d_featPC,
            ),
        )

    n_spikes = int(counter[0])
    if stable:
        # d_idx = array of time sorted indices
        d_idx[:n_spikes] = cp.argsort(d_st[:n_spikes])
    else:
        # Identity ordering. d_idx is int32 on every other path (allocated above
        # and filled in place). cp.arange defaults to int64, so pin the dtype to
        # hand average_snips the same element type.
        d_idx = cp.arange(0, n_spikes, dtype=np.int32)

    # update dWU here by adding back to subbed spikes.
    average_snips = cp.RawKernel(code, "average_snips")
    average_snips(
        (Nfilt,),
        tpS,
        (
            d_Params,
            d_st,
            d_idx,
            d_id,
            d_x,
            d_y,
            d_counter,
            d_draw,
            d_W,
            d_U,
            d_dWU,
            d_nsp,
            d_mu,
            d_z,
        ),
    )

    minSize = min(n_spikes, maxFR)

    del d_counter, d_Params, d_ftype, d_err, d_eloss, d_z, d_dout, d_data

    return (
        d_st[:minSize],
        d_id[:minSize],
        d_y[:minSize],
        d_feat[..., :minSize],
        d_dWU,
        d_draw,
        d_nsp,
        d_featPC[..., :minSize],
        d_x[:minSize],
    )
Exemple #4
0
def mexGetSpikes2(Params, drez, wTEMP, iC):
    """Detect spikes in a data batch and extract their snippets on the GPU.

    The kernels of the ``mexGetSpikes2`` module are launched in this order:
    ``Conv1D``, ``sumChannels``, ``bestFilter``, ``cleanup_spikes``,
    ``cleanup_heights`` and ``extract_snips``.

    Args:
        Params: parameter vector; reads 0 ``NT``, 4 ``nt0``, 9 ``Nchan`` and
            14 ``Nrank``. It is also passed to every kernel as float64.
        drez: data batch, moved to the device as float32 (Fortran order).
        wTEMP: templates for ``Conv1D``, moved to the device as float32.
        iC: index array passed to ``sumChannels``, moved to the device as int32.

    Returns:
        ``(d_WU, d_dout)``:

        * ``d_WU`` is a float32 ``(nt0, Nchan, n_spikes)`` array filled by
          ``extract_snips``, where ``n_spikes = min(maxFR, d_counter[1])`` is read
          back after ``cleanup_heights``.
        * ``d_dout`` is the float32 ``(NT, Nchan)`` buffer passed to
          ``sumChannels`` and ``bestFilter``.
    """
    code, constants = get_cuda("mexGetSpikes2")

    NT = int(Params[0])
    nt0 = int(Params[4])
    Nchan = int(Params[9])
    Nrank = int(Params[14])

    maxFR = int(constants.maxFR)
    Nthreads = int(constants.Nthreads)

    # Kernel grid sizes must be integers: use floor division, not `/`.
    gridT = (NT // Nthreads,)
    tpS = (nt0, 16)

    d_Params = cp.asarray(Params, dtype=np.float64, order="F")
    d_data = cp.asarray(drez, dtype=np.float32, order="F")
    d_W = cp.asarray(wTEMP, dtype=np.float32, order="F")
    d_iC = cp.asarray(iC, dtype=np.int32, order="F")

    d_counter = cp.zeros(2, dtype=np.int32, order="F")
    d_dout = cp.zeros((NT, Nchan), dtype=np.float32, order="F")
    d_dfilt = cp.zeros((Nrank, NT, Nchan), dtype=np.float32, order="F")
    d_err = cp.zeros(NT, dtype=np.float32, order="F")
    d_kkmax = cp.zeros((NT, Nchan), dtype=np.int32, order="F")
    d_kk = cp.zeros(NT, dtype=np.int32, order="F")
    d_ftype = cp.zeros(NT, dtype=np.int32, order="F")
    d_st = cp.zeros(maxFR, dtype=np.int32, order="F")
    d_id = cp.zeros(maxFR, dtype=np.int32, order="F")
    d_x = cp.zeros(maxFR, dtype=np.float32, order="F")
    d_st1 = cp.zeros(maxFR, dtype=np.int32, order="F")
    d_id1 = cp.zeros(maxFR, dtype=np.int32, order="F")

    # filter the data with the temporal templates
    Conv1D = cp.RawKernel(code, "Conv1D")
    Conv1D((Nchan,), (Nthreads,), (d_Params, d_data, d_W, d_dfilt))

    # sum each template across channels, square, take max
    sumChannels = cp.RawKernel(code, "sumChannels")
    sumChannels(gridT, (Nthreads,), (d_Params, d_dfilt, d_dout, d_kkmax, d_iC))

    # compute the best filter
    bestFilter = cp.RawKernel(code, "bestFilter")
    bestFilter(gridT, (Nthreads,), (d_Params, d_dout, d_err, d_ftype, d_kkmax, d_kk))

    # ignore peaks that are smaller than another nearby peak
    cleanup_spikes = cp.RawKernel(code, "cleanup_spikes")
    cleanup_spikes(
        gridT,
        (Nthreads,),
        (d_Params, d_err, d_ftype, d_x, d_st, d_id, d_counter),
    )

    # ignore peaks that are smaller than another nearby peak
    cleanup_heights = cp.RawKernel(code, "cleanup_heights")
    cleanup_heights(
        (1 + maxFR // 32,),
        (32,),
        (d_Params, d_x, d_st, d_id, d_st1, d_id1, d_counter),
    )

    # Read the spike count from the 2nd counter with an explicit device-to-host
    # transfer, and clamp it to the capacity of the spike buffers.
    n_spikes = min(maxFR, int(d_counter[1].get()))

    d_WU = cp.zeros((nt0, Nchan, n_spikes), dtype=np.float32, order="F")

    # update dWU here by adding back to subbed spikes
    extract_snips = cp.RawKernel(code, "extract_snips")
    extract_snips((Nchan,), tpS, (d_Params, d_st1, d_id1, d_counter, d_data, d_WU))

    del (
        d_ftype,
        d_kkmax,
        d_err,
        d_st,
        d_id,
        d_st1,
        d_x,
        d_kk,
        d_id1,
        d_counter,
        d_Params,
        d_dfilt,
    )
    return d_WU, d_dout