def four_way_scan(data, sm_masks, sm_blocksum, blksz, valid):
    """Work-group-wide 4-way split scan over a digit value.

    Each work-item contributes ``data`` (if ``valid``).
    - Per digit, ``sm_masks[digit, tid]`` is first set to a 0/1 flag saying
      whether this item holds that digit.
    - Warp ``w`` (for ``w < RADIX``) then scans row ``w`` of ``sm_masks`` in
      ``_WARPSIZE``-wide chunks, carrying a running ``base`` between chunks.
    - The per-digit totals go to ``sm_blocksum``, and work-item 0 turns them
      into per-digit starting offsets.

    Parameters
    ----------
    data
        This work-item's digit; compared against ``range(RADIX)``.
    sm_masks
        2-D array indexed ``[digit, tid]``, overwritten in place. Presumably
        work-group shared memory, since it is exchanged across
        ``hsa.barrier()``; needs at least ``RADIX`` rows and ``blksz``
        columns. NOTE(review): confirm against callers.
    sm_blocksum
        Array receiving the per-digit totals, one entry per digit < RADIX.
    blksz
        Number of mask columns to scan; read in steps of ``_WARPSIZE``.
        NOTE(review): appears to assume ``blksz`` is a multiple of
        ``_WARPSIZE`` (otherwise the last chunk indexes past ``blksz``).
    valid
        If false, this work-item contributes no mask bit.

    Returns
    -------
    (chunk_offset, scanval)
        - ``chunk_offset``: the sum of ``sm_blocksum`` over all digits
          smaller than this item's digit.
        - ``scanval``: this item's scanned mask entry for its digit.
        - Both are -1 when the item has no digit (invalid, or ``data`` not in
          ``range(RADIX)``).

    Notes
    -----
    - Must be called by every work-item of the work-group (it contains
      ``hsa.barrier()``).
    - The work-group needs at least ``RADIX`` waves so that every digit row
      gets scanned.
    - The chunk-offset table below is hard-coded for 4 digits.
    """
    sm_chunkoffset = hsa.shared.array(4, dtype=int32)

    tid = hsa.get_local_id(0)

    laneid = tid & (_WARPSIZE - 1)
    # Derive the wave index from _WARPSIZE, consistent with ``laneid``; the
    # previous ``tid >> 6`` silently assumed 64-lane waves.
    warpid = tid // _WARPSIZE

    my_digit = -1

    # One 0/1 flag per digit for this work-item.
    for digit in range(RADIX):
        sm_masks[digit, tid] = 0
        if valid and data == digit:
            sm_masks[digit, tid] = 1
            my_digit = digit

    hsa.barrier()

    offset = 0
    base = 0
    while offset < blksz:
        # Exclusive scan: wave ``warpid`` owns digit row ``warpid``. ``base``
        # carries the running total across chunks of _WARPSIZE columns.
        if warpid < RADIX:
            val = intp(sm_masks[warpid, offset + laneid])
            cur, psum = shuf_wave_exclusive_scan(val)
            sm_masks[warpid, offset + laneid] = cur + base
            base += psum

        # Uniform (outside the ``if``) so every work-item reaches it.
        hsa.barrier()
        offset += _WARPSIZE

    hsa.barrier()

    # Store blocksum from the exclusive scan
    if warpid < RADIX and laneid == 0:
        sm_blocksum[warpid] = base

    hsa.barrier()

    # Calc chunk offset (a short exclusive scan over the 4 digit totals)
    if tid == 0:
        sm_chunkoffset[0] = 0
        sm_chunkoffset[1] = sm_blocksum[0]
        sm_chunkoffset[2] = sm_chunkoffset[1] + sm_blocksum[1]
        sm_chunkoffset[3] = sm_chunkoffset[2] + sm_blocksum[2]

    hsa.barrier()

    # Prepare output
    chunk_offset = -1
    scanval = -1
    if my_digit != -1:
        chunk_offset = sm_chunkoffset[my_digit]
        scanval = sm_masks[my_digit, tid]

    hsa.wavebarrier()
    hsa.barrier()
    return chunk_offset, scanval
def local_inclusive_scan_shuf(tid, value, nelem, temp):
    """Inclusive scan of ``value`` across the work-group, fenced by barriers.

    Delegates the actual scan to ``shuf_device_inclusive_scan(value, temp)``
    and returns its result.

    Parameters
    ----------
    tid, nelem
        Not read by this body; kept in the signature for existing callers.
    value
        This work-item's input to the scan.
    temp
        Shared scratch array. It must hold at least one entry per active
        wave in the work-group.

    Note
    ----
    Every work-item in the work-group must call this function, because it
    executes ``hsa.barrier()``.
    """
    # Fence before the scan so no work-item touches ``temp`` while another
    # may still be using it.
    hsa.barrier()
    hsa.wavebarrier()
    scanned = shuf_device_inclusive_scan(value, temp)
    # Fence after the scan as well (presumably so ``temp`` can be safely
    # reused by whatever runs next).
    hsa.barrier()
    return scanned
def wave_reduce(val):
    """Sum ``val`` over the lanes of the wavefront.

    Tree reduction: each step, lanes below ``stride`` add the partial sum
    held ``stride`` lanes above them, halving ``stride`` from
    ``WAVESIZE // 2`` down to 1. Lane 0 ends up with the total, which is
    then read back by every lane via ``activelanepermute_wavewidth``.
    NOTE(review): assumes WAVESIZE is a power of two.
    """
    lane = hsa.get_local_id(0) & (WAVESIZE - 1)
    acc = val
    stride = WAVESIZE // 2
    while stride > 0:
        hsa.wavebarrier()
        # Every lane takes part in the permute; only the lower half keeps it.
        partner = hsa.activelanepermute_wavewidth(acc, lane + stride, 0, False)
        if lane < stride:
            acc += partner
        stride //= 2
    # Lane 0 now holds the complete sum; fetch it into every lane.
    hsa.wavebarrier()
    return hsa.activelanepermute_wavewidth(acc, 0, 0, False)
def warp_scan(tid, temp, inclusive):
    """Intra-warp (Hillis-Steele) prefix sum over ``temp``, in place.

    For offset = 1, 2, 4, ... while offset < _WARPSIZE, every lane with
    ``lane >= offset`` adds ``temp[tid - offset]`` into ``temp[tid]``.
    Afterwards ``temp[tid]`` holds the inclusive prefix sum of this warp's
    slice of ``temp``.

    Parameters
    ----------
    tid
        Index of this work-item's slot in ``temp``. Its low bits
        (``tid & (_WARPSIZE - 1)``) are used as the lane within the warp.
    temp
        Array shared by the warp's lanes; updated in place.
    inclusive
        If true, return the inclusive prefix ``temp[tid]``. Otherwise return
        the exclusive prefix ``temp[tid - 1]``, or 0 for lane 0.

    Note
    ----
    Assume all threads are in lockstep. A lane reads ``temp[tid - offset]``
    in the same step that another lane updates it. This is only correct
    when every lane's load happens before any lane's store.
    """
    hsa.wavebarrier()
    lane = tid & (_WARPSIZE - 1)
    # Loop over power-of-two offsets up to the warp size. This replaces a
    # hand-unrolled sequence for offsets 1..32, which only covered warps of
    # up to 64 lanes. The step count is uniform across lanes, so every lane
    # reaches the same wavebarriers.
    offset = 1
    while offset < _WARPSIZE:
        if lane >= offset:
            temp[tid] += temp[tid - offset]
        hsa.wavebarrier()
        offset *= 2
    if inclusive:
        return temp[tid]
    else:
        return temp[tid - 1] if lane > 0 else 0
def shuf_wave_inclusive_scan(val):
    """Inclusive prefix sum of ``val`` across the lanes of one wave.

    Register-shuffle variant of the Hillis-Steele scan. For
    offset = 1, 2, 4, ... while offset < _WARPSIZE, each lane fetches the
    running value from ``offset`` lanes below it (via ``shuffle_up``) and
    adds it if ``lane >= offset``. The intended result for lane ``k`` is the
    sum of ``val`` over lanes 0..k of its wave.
    """
    tid = hsa.get_local_id(0)
    lane = tid & (_WARPSIZE - 1)
    hsa.wavebarrier()
    # Loop over power-of-two offsets up to the wave size. This replaces a
    # hand-unrolled sequence for offsets 1..32, which only covered waves of
    # up to 64 lanes.
    offset = 1
    while offset < _WARPSIZE:
        # Every lane shuffles, even those that discard the result. Presumably
        # this is required so the active-lane permute sees all lanes' values.
        shuf = shuffle_up(val, offset)
        if lane >= offset:
            val += shuf
        hsa.wavebarrier()
        offset *= 2
    return val
def shuffle_up(val, width):
    """Fetch ``val`` from the work-item ``width`` positions below the caller.

    The source is ``hsa.get_local_id(0) - width``, read through
    ``hsa.activelanepermute_wavewidth`` (identity 0, identity not used).
    Callers should ignore the result on lanes whose in-wave index is below
    ``width``; what the permute returns there is not established here.

    NOTE(review): the source index is derived from the work-group local id,
    not the in-wave lane (``local_id & (_WARPSIZE - 1)``). This relies on
    the permute resolving the lane id within the wavefront. Confirm against
    the HSAIL ``activelanepermute`` specification.
    """
    src_lane = hsa.get_local_id(0) - width
    hsa.wavebarrier()
    return hsa.activelanepermute_wavewidth(val, src_lane, 0, False)
def broadcast(val, src):
    """Return lane ``src``'s copy of ``val`` to every calling lane of the wave.

    Implemented as a wave-synchronised ``activelanepermute_wavewidth`` whose
    source lane is ``src`` for all lanes (identity 0, identity not used).
    """
    hsa.wavebarrier()
    from_src = hsa.activelanepermute_wavewidth(val, src, 0, False)
    return from_src