def mexWtW2(Params, W1, W2, UtU):
    """Launch the ``crossFilter`` kernel of the ``mexWtW2`` CUDA module.

    ``Params`` is uploaded as float64 and ``W1``, ``W2`` and ``UtU`` as
    float32 (all Fortran-ordered); they are passed to the kernel together
    with a zero-initialised output buffer.

    Host-side values read from ``Params``: ``Params[1]`` (Nfilt) and
    ``Params[9]`` (nt0).

    Returns:
        float32 CuPy array of shape ``(Nfilt, Nfilt, 2 * nt0 - 1)`` containing
        whatever ``crossFilter`` wrote into it.
    """
    kernel_source, kernel_constants = get_cuda("mexWtW2")
    threads_per_side = kernel_constants.nblock

    n_filters = int(Params[1])
    n_samples = int(Params[9])

    def to_gpu(host_array, dtype):
        # Every kernel input is laid out column-major.
        return cp.asarray(host_array, dtype=dtype, order="F")

    gpu_params = to_gpu(Params, np.float64)
    gpu_w1 = to_gpu(W1, np.float32)
    gpu_w2 = to_gpu(W2, np.float32)
    gpu_utu = to_gpu(UtU, np.float32)

    result = cp.zeros(
        (n_filters, n_filters, 2 * n_samples - 1), dtype=np.float32, order="F"
    )

    # Square 2-D launch: one extra block per side so that the grid always
    # covers all Nfilt x Nfilt pairs.
    blocks_per_side = 1 + int(n_filters // threads_per_side)
    grid_dims = (blocks_per_side, blocks_per_side)
    block_dims = (threads_per_side, threads_per_side)

    cross_filter = cp.RawKernel(kernel_source, "crossFilter")
    cross_filter(
        grid_dims, block_dims, (gpu_params, gpu_w1, gpu_w2, gpu_utu, result)
    )

    return result
def mexSVDsmall2(Params, dWU, W, iC, iW, Ka, Kb):
    """Run the ``mexSVDsmall2`` CUDA kernels and return the updated W, U and mu.

    All floating-point inputs are uploaded as float64 and the index arrays
    ``iC`` / ``iW`` as int32, all Fortran-ordered.  Five kernels are then
    launched in sequence, each with one block per filter.

    Host-side values read from ``Params``: ``[1]`` Nfilt, ``[4]`` nt0,
    ``[6]`` Nrank, ``[9]`` Nchan.

    Returns:
        Tuple ``(W, U, mu)`` of float64 CuPy arrays.  ``W`` is the uploaded
        ``W`` after the kernels have run on it, ``U`` has shape
        ``(Nchan, Nfilt, Nrank)`` and ``mu`` has shape ``(Nfilt,)``.
    """
    source, consts = get_cuda("mexSVDsmall2")
    max_threads = consts.Nthreads

    n_filt = int(Params[1])
    n_t0 = int(Params[4])
    n_rank = int(Params[6])
    n_chan = int(Params[9])

    def upload(host_array, dtype):
        return cp.asarray(host_array, dtype=dtype, order="F")

    def blank(shape):
        return cp.zeros(shape, dtype=np.float64, order="F")

    gpu_params = upload(Params, np.float64)
    gpu_dwu = upload(dWU, np.float64)
    gpu_ic = upload(iC, np.int32)
    gpu_iw = upload(iW, np.int32)
    gpu_ka = upload(Ka, np.float64)
    gpu_kb = upload(Kb, np.float64)
    gpu_u = blank((n_chan, n_filt, n_rank))
    gpu_mu = blank(n_filt)
    gpu_w = upload(W, np.float64)
    gpu_wtw = blank((n_t0, n_t0, n_filt))
    gpu_dwu_blanked = blank((n_t0, n_chan, n_filt))

    one_block_per_filter = (n_filt,)
    block_by_sample = (n_t0, int(max_threads // n_t0))
    block_by_rank = (n_rank, int(max_threads // n_rank))
    block_one_per_sample = (n_t0,)

    # (kernel name, block dimensions, kernel arguments), in launch order.
    launch_plan = (
        (
            "blankdWU",
            block_by_sample,
            (gpu_params, gpu_dwu, gpu_ic, gpu_iw, gpu_dwu_blanked),
        ),
        # compute dWU * dWU'
        ("getwtw", block_by_sample, (gpu_params, gpu_dwu_blanked, gpu_wtw)),
        # get W by power svd iterations
        ("getW", block_one_per_sample, (gpu_params, gpu_wtw, gpu_w)),
        # compute U by W' * dWU
        ("getU", block_by_rank, (gpu_params, gpu_dwu_blanked, gpu_w, gpu_u)),
        # normalize U, get S, get mu, renormalize W
        (
            "reNormalize",
            block_one_per_sample,
            (gpu_params, gpu_ka, gpu_kb, gpu_w, gpu_u, gpu_mu),
        ),
    )
    for kernel_name, block_dims, kernel_args in launch_plan:
        kernel = cp.RawKernel(source, kernel_name)
        kernel(one_block_per_filter, block_dims, kernel_args)

    return gpu_w, gpu_u, gpu_mu
def mexMPnu8(Params, dataRAW, U, W, mu, iC, iW, UtU, iList, wPCA, params):
    """Run the ``mexMPnu8`` CUDA kernels that find and subtract spikes.

    The inputs are uploaded to the GPU, the data is filtered with the spatial
    (``U``) and temporal (``W``) templates, and spikes are then detected and
    subtracted for ``int(Params[3])`` iterations.  Afterwards PC features are
    optionally computed and ``average_snips`` is run to fill ``dWU``.

    Args:
        Params: numeric parameter vector; uploaded as float64 and passed to
            every kernel.  Entries read on the host: ``[0]`` NT, ``[1]``
            Nfilt, ``[3]`` number of detection iterations, ``[4]`` nt0,
            ``[5]`` Nnearest, ``[6]`` Nrank, ``[9]`` Nchan, ``[10]`` NchanU,
            ``[12]`` feature switch (``> 1`` runs ``extractFEAT``, ``> 0``
            runs ``computePCfeatures``).
        dataRAW: raw data for the batch (uploaded as float32).
        U, W, mu: template arrays (uploaded as float32).
        iC, iW, iList: index arrays (uploaded as int32).
        UtU: uploaded as a boolean array.
        wPCA: uploaded as float32.
        params: object exposing the boolean attributes ``stablemode_enabled``
            and ``deterministicmode_enabled``.

    Returns:
        Tuple ``(st, id, y, feat, dWU, draw, nsp, featPC, x)`` of CuPy arrays.
        ``st``, ``id``, ``y``, ``x`` and the last axis of ``feat`` and
        ``featPC`` are truncated to the number of detected spikes (at most
        ``maxFR``).  ``draw`` is the float32 data array after subtraction.

    Raises:
        ValueError: if ``params.deterministicmode_enabled`` is set without
            ``params.stablemode_enabled`` (raised inside the detection loop).
    """
    code, constants = get_cuda("mexMPnu8")
    maxFR = int(constants.maxFR)
    nmaxiter = int(constants.nmaxiter)
    Nthreads = int(constants.Nthreads)

    NT = int(Params[0])
    Nfilt = int(Params[1])
    nt0 = int(Params[4])
    Nnearest = int(Params[5])
    Nrank = int(Params[6])
    Nchan = int(Params[9])
    NchanU = int(Params[10])

    stable = params.stablemode_enabled
    deterministic = params.deterministicmode_enabled
    # Stable, non-deterministic mode subtracts from a float64 copy of the data.
    use_draw64 = stable and not deterministic

    def launch(name, grid, block, args):
        # Look the kernel up by name in the module source and run it.
        cp.RawKernel(code, name)(grid, block, args)

    d_Params = cp.asarray(Params, dtype=np.float64, order="F")
    d_draw = cp.asarray(dataRAW, dtype=np.float32, order="F")
    d_U = cp.asarray(U, dtype=np.float32, order="F")
    d_W = cp.asarray(W, dtype=np.float32, order="F")
    d_mu = cp.asarray(mu, dtype=np.float32, order="F")
    d_iC = cp.asarray(iC, dtype=np.int32, order="F")
    d_iW = cp.asarray(iW, dtype=np.int32, order="F")
    # np.bool_ rather than np.bool: the bare alias does not exist in
    # NumPy 1.24-1.26.
    d_UtU = cp.asarray(UtU, dtype=np.bool_, order="F")
    d_iList = cp.asarray(iList, dtype=np.int32, order="F")
    d_wPCA = cp.asarray(wPCA, dtype=np.float32, order="F")

    d_nsp = cp.zeros(Nfilt, dtype=np.int32, order="F")
    d_dWU = cp.zeros((nt0, Nchan, Nfilt), dtype=np.float64, order="F")
    d_dout = cp.zeros((2 * NT, Nfilt), dtype=np.float32, order="F")
    d_data = cp.zeros((NT, Nfilt, Nrank), dtype=np.float32, order="F")
    d_err = cp.zeros(NT, dtype=np.float32, order="F")
    d_ftype = cp.zeros(NT, dtype=np.int32, order="F")
    d_eloss = cp.zeros(NT, dtype=np.float32, order="F")
    d_st = cp.zeros(maxFR, dtype=np.int32, order="F")
    d_id = cp.zeros(maxFR, dtype=np.int32, order="F")
    d_x = cp.zeros(maxFR, dtype=np.float32, order="F")
    d_y = cp.zeros(maxFR, dtype=np.float32, order="F")
    d_z = cp.zeros(maxFR, dtype=np.float32, order="F")

    d_counter = cp.zeros(2, dtype=np.int32, order="F")
    # Running spike count after each iteration; written below but neither
    # passed to a kernel nor returned.
    d_count = cp.zeros(nmaxiter, dtype=np.int32, order="F")
    d_feat = cp.zeros((Nnearest, maxFR), dtype=np.float32, order="F")
    d_featPC = cp.zeros((NchanU, Nrank, maxFR), dtype=np.float32, order="F")
    d_idx = cp.zeros(maxFR, dtype=np.int32, order="F")

    # Host-side mirror of d_counter.
    counter = np.zeros(2, dtype=np.int32, order="F")

    tpF = (16, Nnearest)
    tpS = (nt0, 16)
    tpPC = (NchanU, Nrank)

    # filter the data with the spatial templates
    launch(
        "spaceFilter",
        (Nfilt,),
        (Nthreads,),
        (d_Params, d_draw, d_U, d_iC, d_iW, d_data),
    )

    # filter the data with the temporal templates
    launch("timeFilter", (Nfilt,), (Nthreads,), (d_Params, d_data, d_W, d_dout))

    # compute the best filter
    launch(
        "bestFilter",
        (NT // Nthreads,),
        (Nthreads,),
        (d_Params, d_dout, d_mu, d_err, d_eloss, d_ftype),
    )

    if use_draw64:
        d_draw64 = cp.array(d_draw, dtype=np.float64)

    # loop to find and subtract spikes
    for k in range(int(Params[3])):
        # ignore peaks that are smaller than another nearby peak
        launch(
            "cleanup_spikes",
            (NT // Nthreads,),
            (Nthreads,),
            (
                d_Params,
                d_dout,
                d_mu,
                d_err,
                d_eloss,
                d_ftype,
                d_st,
                d_id,
                d_x,
                d_y,
                d_z,
                d_counter,
            ),
        )

        # add new spikes to 2nd counter
        counter[:] = cp.asnumpy(d_counter[:])
        if counter[0] > maxFR:
            logger.warning("Firing rate limit hit for batch.")
            counter[0] = maxFR
            d_counter[0] = counter[0]

        # Spikes found in this iteration: counter[1] still holds the total
        # from the previous iteration.
        first_new = int(counter[1])
        n_total = int(counter[0])
        n_new = n_total - first_new

        # extract template features before subtraction
        if Params[12] > 1:
            launch(
                "extractFEAT",
                (64,),
                tpF,
                (d_Params, d_st, d_id, d_counter, d_dout, d_iList, d_mu, d_feat),
            )

        if deterministic:
            if not stable:
                # This is allowed in the MATLAB version runtime but it
                # doesn't really make sense and isn't recommended, so let's
                # not let anyone fall into the trap without knowing.
                raise ValueError(
                    "Stablemode required for deterministic calculations."
                )

            # Indices that sort this iteration's new spikes by time.
            d_idx[:n_new] = cp.argsort(d_st[first_new:n_total])

            if Nchan < Nthreads:
                v2_grid, v2_block = (1,), (Nchan,)
            else:
                # Floor division: a launch dimension must be an integer
                # (true division produced a float here).
                v2_grid, v2_block = (Nchan // Nthreads,), (Nthreads,)
            launch(
                "subtract_spikes_v2",
                v2_grid,
                v2_block,
                (d_Params, d_st, d_idx, d_id, d_y, d_counter, d_draw, d_W, d_U),
            )

            launch(
                "spaceFilterUpdate",
                (Nfilt,),
                (2 * nt0 - 1,),
                (
                    d_Params,
                    d_draw,
                    d_U,
                    d_UtU,
                    d_iC,
                    d_iW,
                    d_data,
                    d_st,
                    d_id,
                    d_counter,
                ),
            )
        elif stable:
            launch(
                "subtract_spikes_v4",
                (Nfilt,),
                tpS,
                (d_Params, d_st, d_id, d_y, d_counter, d_draw64, d_W, d_U),
            )

            launch(
                "spaceFilterUpdate_v2",
                (Nfilt,),
                (2 * nt0 - 1,),
                (
                    d_Params,
                    d_draw64,
                    d_U,
                    d_UtU,
                    d_iC,
                    d_iW,
                    d_data,
                    d_st,
                    d_id,
                    d_counter,
                ),
            )
        else:
            # subtract spikes from raw data here
            launch(
                "subtract_spikes",
                (Nfilt,),
                tpS,
                (d_Params, d_st, d_id, d_y, d_counter, d_draw, d_W, d_U),
            )

            # filter the data with the spatial templates
            launch(
                "spaceFilterUpdate",
                (Nfilt,),
                (2 * nt0 - 1,),
                (
                    d_Params,
                    d_draw,
                    d_U,
                    d_UtU,
                    d_iC,
                    d_iW,
                    d_data,
                    d_st,
                    d_id,
                    d_counter,
                ),
            )

        # filter the data with the temporal templates
        launch(
            "timeFilterUpdate",
            (Nfilt,),
            (2 * nt0 - 1,),
            (d_Params, d_data, d_W, d_UtU, d_dout, d_st, d_id, d_counter),
        )

        if n_new > 0:
            launch(
                "bestFilterUpdate",
                (n_new,),
                (2 * nt0 - 1,),
                (
                    d_Params,
                    d_dout,
                    d_mu,
                    d_err,
                    d_eloss,
                    d_ftype,
                    d_st,
                    d_id,
                    d_counter,
                ),
            )

        d_count[k + 1] = d_counter[0]

        # update 1st counter from 2nd counter
        d_counter[1] = d_counter[0]

    if use_draw64:
        d_draw = cp.array(d_draw64, dtype=np.float32)

    # compute PC features from residuals + subtractions
    # TODO: design - let's not use numeric indexing into the Params array.
    # It's much more difficult to read.
    if Params[12] > 0:
        launch(
            "computePCfeatures",
            (Nfilt,),
            tpPC,
            (
                d_Params,
                d_counter,
                d_draw,
                d_st,
                d_id,
                d_y,
                d_W,
                d_U,
                d_mu,
                d_iW,
                d_iC,
                d_wPCA,
                d_featPC,
            ),
        )

    n_spikes = int(counter[0])
    if stable:
        # d_idx = array of time sorted indices
        d_idx[:n_spikes] = cp.argsort(d_st[:n_spikes])
    else:
        # Identity ordering, written into the existing int32 buffer so the
        # kernel sees the same dtype as in the sorted case (a bare cp.arange
        # would rebind d_idx to a default-int array).
        # NOTE(review): assumes average_snips reads idx as 32-bit ints, as
        # the int32 allocation of d_idx implies - confirm against the kernel.
        d_idx[:n_spikes] = cp.arange(n_spikes, dtype=np.int32)

    # update dWU here by adding back to subbed spikes.
    launch(
        "average_snips",
        (Nfilt,),
        tpS,
        (
            d_Params,
            d_st,
            d_idx,
            d_id,
            d_x,
            d_y,
            d_counter,
            d_draw,
            d_W,
            d_U,
            d_dWU,
            d_nsp,
            d_mu,
            d_z,
        ),
    )

    # Only the first minSize entries of the per-spike buffers are returned.
    minSize = min(n_spikes, maxFR)

    return (
        d_st[:minSize],
        d_id[:minSize],
        d_y[:minSize],
        d_feat[..., :minSize],
        d_dWU,
        d_draw,
        d_nsp,
        d_featPC[..., :minSize],
        d_x[:minSize],
    )
def mexGetSpikes2(Params, drez, wTEMP, iC):
    """Run the ``mexGetSpikes2`` CUDA kernels and extract spike snippets.

    The data is filtered with the temporal templates, summed across channels,
    reduced to a best filter per time point and cleaned up in two passes;
    snippets for the remaining spikes are then extracted from the data.

    Args:
        Params: numeric parameter vector; uploaded as float64 and passed to
            every kernel.  Entries read on the host: ``[0]`` NT, ``[4]`` nt0,
            ``[9]`` Nchan, ``[14]`` Nrank.
        drez: data array (uploaded as float32).
        wTEMP: temporal templates (uploaded as float32).
        iC: index array (uploaded as int32).

    Returns:
        Tuple ``(WU, dout)`` of float32 CuPy arrays.  ``WU`` has shape
        ``(nt0, Nchan, n_spikes)`` where ``n_spikes`` is the second device
        counter capped at ``maxFR``; ``dout`` has shape ``(NT, Nchan)``.
    """
    code, constants = get_cuda("mexGetSpikes2")

    NT = int(Params[0])
    nt0 = int(Params[4])
    Nchan = int(Params[9])
    Nrank = int(Params[14])

    maxFR = constants.maxFR
    Nthreads = constants.Nthreads

    tpS = (nt0, 16)
    # Number of blocks for the kernels launched over time points; floor
    # division keeps the computation in integers.
    n_time_blocks = int(NT // Nthreads)

    def launch(name, grid, block, args):
        # Look the kernel up by name in the module source and run it.
        cp.RawKernel(code, name)(grid, block, args)

    d_Params = cp.asarray(Params, dtype=np.float64, order="F")
    d_data = cp.asarray(drez, dtype=np.float32, order="F")
    d_W = cp.asarray(wTEMP, dtype=np.float32, order="F")
    d_iC = cp.asarray(iC, dtype=np.int32, order="F")

    d_counter = cp.zeros(2, dtype=np.int32, order="F")
    d_dout = cp.zeros((NT, Nchan), dtype=np.float32, order="F")
    d_dfilt = cp.zeros((Nrank, NT, Nchan), dtype=np.float32, order="F")
    d_err = cp.zeros(NT, dtype=np.float32, order="F")
    d_kkmax = cp.zeros((NT, Nchan), dtype=np.int32, order="F")
    d_kk = cp.zeros(NT, dtype=np.int32, order="F")
    d_ftype = cp.zeros(NT, dtype=np.int32, order="F")
    d_st = cp.zeros(maxFR, dtype=np.int32, order="F")
    d_id = cp.zeros(maxFR, dtype=np.int32, order="F")
    d_x = cp.zeros(maxFR, dtype=np.float32, order="F")
    d_st1 = cp.zeros(maxFR, dtype=np.int32, order="F")
    d_id1 = cp.zeros(maxFR, dtype=np.int32, order="F")

    # filter the data with the temporal templates
    launch("Conv1D", (Nchan,), (Nthreads,), (d_Params, d_data, d_W, d_dfilt))

    # sum each template across channels, square, take max
    launch(
        "sumChannels",
        (n_time_blocks,),
        (Nthreads,),
        (d_Params, d_dfilt, d_dout, d_kkmax, d_iC),
    )

    # compute the best filter
    launch(
        "bestFilter",
        (n_time_blocks,),
        (Nthreads,),
        (d_Params, d_dout, d_err, d_ftype, d_kkmax, d_kk),
    )

    # ignore peaks that are smaller than another nearby peak
    launch(
        "cleanup_spikes",
        (n_time_blocks,),
        (Nthreads,),
        (d_Params, d_err, d_ftype, d_x, d_st, d_id, d_counter),
    )

    # Second clean-up pass; it receives the first pass's d_st/d_id plus the
    # d_st1/d_id1 buffers that extract_snips reads below.
    # NOTE(review): presumably filters spikes by height - confirm in kernel.
    launch(
        "cleanup_heights",
        (1 + int(maxFR // 32),),
        (32,),
        (d_Params, d_x, d_st, d_id, d_st1, d_id1, d_counter),
    )

    # Read the second counter back explicitly: CuPy does not allow a device
    # scalar to be converted to NumPy implicitly.
    n_spikes = int(min(maxFR, int(d_counter[1])))

    d_WU = cp.zeros((nt0, Nchan, n_spikes), dtype=np.float32, order="F")

    # update dWU here by adding back to subbed spikes
    launch(
        "extract_snips",
        (Nchan,),
        tpS,
        (d_Params, d_st1, d_id1, d_counter, d_data, d_WU),
    )

    return d_WU, d_dout