def four_way_scan(data, sm_masks, sm_blocksum, blksz, valid):
    """Scan the per-digit membership masks of a work-group.

    Each work-item records in ``sm_masks[digit, tid]`` whether its ``data``
    equals ``digit`` (only when ``valid`` is true).  Waves whose id is below
    ``RADIX`` then walk their own row of ``sm_masks`` in steps of
    ``_WARPSIZE`` columns up to ``blksz``, replacing each entry with the
    value returned by ``shuf_wave_exclusive_scan`` plus a running ``base``.
    The final ``base`` of each row is stored in ``sm_blocksum`` and a
    four-entry exclusive prefix of those totals is built in an array
    obtained from ``roc.shared.array``.

    Returns
    -------
    (chunk_offset, scanval)
        For a work-item whose ``data`` matched a digit: the prefix entry for
        that digit and the scanned mask value at its own ``tid``.
        ``(-1, -1)`` when no digit matched.

    Note: contains barriers, so all work-items must reach this call.
    """
    sm_chunkoffset = roc.shared.array(4, dtype=int32)

    tid = roc.get_local_id(0)
    laneid = tid & (_WARPSIZE - 1)
    # NOTE(review): the shift hard-codes 64 work-items per wave; it has to
    # agree with _WARPSIZE.
    warpid = tid >> 6

    # Give every row an explicit 0/1 flag and remember the matching digit.
    my_digit = -1
    for digit in range(RADIX):
        if valid and data == digit:
            sm_masks[digit, tid] = 1
            my_digit = digit
        else:
            sm_masks[digit, tid] = 0

    roc.barrier()

    scans_a_row = warpid < RADIX
    base = 0
    offset = 0
    while offset < blksz:
        # Exclusive scan of one _WARPSIZE-wide slice of this wave's row.
        if scans_a_row:
            column = offset + laneid
            # presumably (lane prefix, slice total) -- verify against
            # shuf_wave_exclusive_scan
            cur, psum = shuf_wave_exclusive_scan(intp(sm_masks[warpid, column]))
            sm_masks[warpid, column] = cur + base
            base += psum

        roc.barrier()
        offset += _WARPSIZE

    roc.barrier()

    # One lane per scanning wave publishes the row total.
    if scans_a_row and laneid == 0:
        sm_blocksum[warpid] = base

    roc.barrier()

    # Short exclusive scan over the first three row totals.
    if tid == 0:
        sm_chunkoffset[0] = 0
        for slot in range(1, 4):
            sm_chunkoffset[slot] = sm_chunkoffset[slot - 1] + sm_blocksum[slot - 1]

    roc.barrier()

    # Pick this work-item's outputs; -1 marks "no digit matched".
    if my_digit == -1:
        chunk_offset = -1
        scanval = -1
    else:
        chunk_offset = sm_chunkoffset[my_digit]
        scanval = sm_masks[my_digit, tid]

    roc.wavebarrier()
    roc.barrier()
    return chunk_offset, scanval
def local_inclusive_scan_shuf(tid, value, nelem, temp):
    """Barrier-wrapped call to ``shuf_device_inclusive_scan(value, temp)``.

    * temp: shared array
        Size of the array must be at least the number of active waves.

    ``tid`` and ``nelem`` are accepted but not read by this function.

    Note: This function must be called by all threads in the block.
    """
    roc.barrier()
    roc.wavebarrier()
    scanned = shuf_device_inclusive_scan(value, temp)
    roc.barrier()
    return scanned
def wave_reduce(val):
    """Sum-reduce ``val`` in place by repeated halving and return ``val[0]``.

    The span starts at ``WAVESIZE // 2`` and is halved each round.  In every
    round a lane below the span adds the element one span above it into its
    own slot and then overwrites that upper element with ``-1`` (debug
    marker).  A wave barrier follows each round.
    """
    lane = roc.get_local_id(0) % WAVESIZE

    span = WAVESIZE // 2
    while span != 0:
        if lane < span:
            partner = lane + span
            val[lane] += val[partner]
            val[partner] = -1  # debug
        roc.wavebarrier()
        span //= 2

    # First thread has the result
    roc.wavebarrier()
    return val[0]
def wave_reduce(val):
    """Sum-reduce ``val`` in place by repeated halving and return ``val[0]``.

    The span starts at ``WAVESIZE // 2`` and is halved each round.  In every
    round a lane below the span adds the element one span above it into its
    own slot and then overwrites that upper element with ``-1`` (debug
    marker).  A wave barrier follows each round.
    """
    lane = roc.get_local_id(0) % WAVESIZE

    span = WAVESIZE // 2
    while span != 0:
        if lane < span:
            partner = lane + span
            val[lane] += val[partner]
            val[partner] = -1  # debug
        roc.wavebarrier()
        span //= 2

    # First thread has the result
    roc.wavebarrier()
    return val[0]
def shuf_wave_inclusive_scan(val):
    """Accumulate ``val`` across a wave with six doubling shuffle steps.

    For distances 1, 2, 4, 8, 16 and 32 the lane fetches
    ``shuffle_up(val, distance)`` and adds it to its own ``val`` when its
    lane index is at least ``distance``; a wave barrier separates the steps.
    Assuming ``shuffle_up`` yields the value held ``distance`` lanes below,
    the result is the inclusive prefix sum for this lane.

    Returns the lane's accumulated value.
    """
    lane = roc.get_local_id(0) & (_WARPSIZE - 1)

    roc.wavebarrier()
    distance = 1
    while distance <= 32:
        incoming = shuffle_up(val, distance)
        if lane >= distance:
            val = val + incoming
        roc.wavebarrier()
        distance *= 2
    return val
def broadcast(val, src):
    """Return ``roc.activelanepermute_wavewidth(val, src, 0, False)``.

    A wave barrier is issued before the permute.
    NOTE(review): presumably this hands every lane the ``val`` held by lane
    ``src``; the meaning of the trailing ``0, False`` arguments should be
    confirmed against the intrinsic's documentation.
    """
    roc.wavebarrier()
    picked = roc.activelanepermute_wavewidth(val, src, 0, False)
    return picked
def shuffle_up(val, width):
    """Permute ``val`` using ``local_id - width`` as the source index.

    Issues a wave barrier, then returns
    ``roc.activelanepermute_wavewidth(val, local_id - width, 0, False)``.
    The index is built from the full local id and is not wrapped or
    clamped here, so it can be negative for the lowest ``width`` ids.
    """
    source_index = roc.get_local_id(0) - width
    roc.wavebarrier()
    return roc.activelanepermute_wavewidth(val, source_index, 0, False)
def broadcast(val, src):
    """Store ``src`` into ``val`` at this work-item's local id.

    A wave barrier is issued before the store.  ``val`` is modified in
    place and the same object is returned.
    """
    slot = roc.get_local_id(0)
    roc.wavebarrier()
    val[slot] = src
    return val
def shuffle_down(val, width):
    """Permute ``val`` through ``roc.ds_permute`` with a wrapped index.

    The index handed to ``ds_permute`` is ``(local_id + width) % WAVESIZE``;
    a wave barrier is issued before the permute.  Returns the permute
    result.
    """
    tid = roc.get_local_id(0)
    roc.wavebarrier()
    return roc.ds_permute((tid + width) % WAVESIZE, val)
def warp_scan(tid, temp, inclusive):
    """Intra-warp scan over ``temp``, performed in place.

    For strides 1, 2, 4, 8, 16 and 32, a lane whose index is at least the
    stride adds ``temp[tid - stride]`` into ``temp[tid]``; a wave barrier
    follows every step.

    Returns ``temp[tid]`` when ``inclusive`` is true.  Otherwise returns the
    preceding element ``temp[tid - 1]``, or ``0`` for lane 0.

    Note
    ----
    Assume all threads are in lockstep
    """
    roc.wavebarrier()
    lane = tid & (_WARPSIZE - 1)

    stride = 1
    while stride <= 32:
        if lane >= stride:
            temp[tid] += temp[tid - stride]
        roc.wavebarrier()
        stride *= 2

    if inclusive:
        return temp[tid]
    if lane > 0:
        return temp[tid - 1]
    return 0
def broadcast(val, from_lane):
    """Return ``roc.ds_bpermute(from_lane, val)`` after a wave barrier.

    Parameters
    ----------
    val
        Value contributed by the calling lane.
    from_lane
        Index passed to ``ds_bpermute`` as the lane selector.

    NOTE(review): relies on ``ds_bpermute`` reading ``val`` from the lane
    selected by ``from_lane`` -- confirm against the ROC intrinsic docs.
    """
    # The previous version also fetched the local id but never used it;
    # that dead call has been dropped.
    roc.wavebarrier()
    return roc.ds_bpermute(from_lane, val)
def broadcast(val, from_lane):
    """Return ``roc.ds_bpermute(from_lane, val)`` after a wave barrier.

    Parameters
    ----------
    val
        Value contributed by the calling lane.
    from_lane
        Index passed to ``ds_bpermute`` as the lane selector.

    NOTE(review): relies on ``ds_bpermute`` reading ``val`` from the lane
    selected by ``from_lane`` -- confirm against the ROC intrinsic docs.
    """
    # The previous version also fetched the local id but never used it;
    # that dead call has been dropped.
    roc.wavebarrier()
    return roc.ds_bpermute(from_lane, val)
def shuffle_down(val, width):
    """Permute ``val`` through ``roc.ds_permute`` with a wrapped index.

    The index handed to ``ds_permute`` is ``(local_id - width) % _WAVESIZE``;
    a wave barrier is issued before the permute.  Returns the permute
    result.
    """
    tid = roc.get_local_id(0)
    roc.wavebarrier()
    return roc.ds_permute((tid - width) % _WAVESIZE, val)
def warp_scan(tid, temp, inclusive):
    """Intra-warp scan over ``temp``, performed in place.

    For strides 1, 2, 4, 8, 16 and 32, a lane whose index is at least the
    stride adds ``temp[tid - stride]`` into ``temp[tid]``; a wave barrier
    follows every step.

    Returns ``temp[tid]`` when ``inclusive`` is true.  Otherwise returns the
    preceding element ``temp[tid - 1]``, or ``0`` for lane 0.

    Note
    ----
    Assume all threads are in lockstep
    """
    roc.wavebarrier()
    lane = tid & (_WARPSIZE - 1)

    stride = 1
    while stride <= 32:
        if lane >= stride:
            temp[tid] += temp[tid - stride]
        roc.wavebarrier()
        stride *= 2

    if inclusive:
        return temp[tid]
    if lane > 0:
        return temp[tid - 1]
    return 0
def shuf_wave_inclusive_scan(val):
    """Accumulate ``val`` across a wave with six doubling shuffle steps.

    For distances 1, 2, 4, 8, 16 and 32 the lane fetches
    ``shuffle_up(val, distance)`` and, when its lane index is at least
    ``distance``, replaces ``val`` with ``dtype(val + fetched)``; a wave
    barrier separates the steps.  Assuming ``shuffle_up`` yields the value
    held ``distance`` lanes below, the result is the inclusive prefix sum
    for this lane.

    Returns the lane's accumulated value.
    """
    lane = roc.get_local_id(0) & (_WARPSIZE - 1)

    roc.wavebarrier()
    distance = 1
    while distance <= 32:
        incoming = shuffle_up(val, distance)
        if lane >= distance:
            # Cast back after every add so the accumulator keeps its type.
            val = dtype(val + incoming)
        roc.wavebarrier()
        distance *= 2
    return val