示例#1
0
        def four_way_scan(data, sm_masks, sm_blocksum, blksz, valid):
            """Scan per-digit flags across the work-group for one radix pass.

            Each work-item flags which digit its ``data`` equals, the flag
            row of every digit is scanned in ``_WARPSIZE``-wide chunks, and
            the per-digit totals are turned into starting offsets.

            Parameters
            ----------
            data
                Digit held by the calling work-item; compared against every
                value in ``range(RADIX)``.
            sm_masks
                2-D array indexed ``[digit, tid]``. Overwritten with 0/1
                flags and then, in place, with the scanned values.
            sm_blocksum
                Array indexed by digit; receives the running total (``base``)
                accumulated for that digit.
            blksz
                Upper bound for the scan offset; the scan loop runs while
                ``offset < blksz`` in steps of ``_WARPSIZE``.
            valid
                When falsy, the work-item's ``data`` is not flagged.

            Returns
            -------
            (chunk_offset, scanval)
                ``chunk_offset`` is read from the per-digit offset table and
                ``scanval`` from the scanned ``sm_masks`` row of the
                work-item's digit. Both are -1 when no digit was matched.

            Note: contains work-group barriers, so presumably every
            work-item of the group has to call it -- TODO confirm.
            """
            # Shared table of starting offsets, one entry per digit.
            # NOTE(review): sized for four digits -- presumably RADIX == 4.
            sm_chunkoffset = hsa.shared.array(4, dtype=int32)

            tid = hsa.get_local_id(0)

            laneid = tid & (_WARPSIZE - 1)
            # NOTE(review): the shift by 6 is a division by 64, whereas laneid
            # above is derived from _WARPSIZE -- confirm _WARPSIZE == 64.
            warpid = tid >> 6

            # -1 means "this work-item matched no digit".
            my_digit = -1

            # Write this work-item's 0/1 flag into every digit row.
            for digit in range(RADIX):
                sm_masks[digit, tid] = 0
                if valid and data == digit:
                    sm_masks[digit, tid] = 1
                    my_digit = digit

            # All flags must be written before other work-items read them.
            hsa.barrier()

            offset = 0
            # Running total carried from one chunk of the row to the next.
            base = 0
            while offset < blksz:
                # Exclusive scan
                # Only work-items with warpid < RADIX scan; warpid selects
                # the digit row they work on.
                if warpid < RADIX:
                    val = intp(sm_masks[warpid, offset + laneid])
                    # psum is presumably the total of this chunk -- verify
                    # against shuf_wave_exclusive_scan.
                    cur, psum = shuf_wave_exclusive_scan(val)
                    sm_masks[warpid, offset + laneid] = cur + base
                    base += psum

                hsa.barrier()
                offset += _WARPSIZE

            hsa.barrier()

            # Store blocksum from the exclusive scan
            # (one writer per digit row: lane 0 of the scanning work-items)
            if warpid < RADIX and laneid == 0:
                sm_blocksum[warpid] = base

            hsa.barrier()
            # Calc chunk offset (a short exclusive scan)
            # Done serially by work-item 0 over the first three block sums.
            if tid == 0:
                sm_chunkoffset[0] = 0
                sm_chunkoffset[1] = sm_blocksum[0]
                sm_chunkoffset[2] = sm_chunkoffset[1] + sm_blocksum[1]
                sm_chunkoffset[3] = sm_chunkoffset[2] + sm_blocksum[2]

            hsa.barrier()
            # Prepare output
            # Defaults returned by work-items that matched no digit.
            chunk_offset = -1
            scanval = -1

            if my_digit != -1:
                chunk_offset = sm_chunkoffset[my_digit]
                scanval = sm_masks[my_digit, tid]

            hsa.wavebarrier()
            hsa.barrier()

            return chunk_offset, scanval
示例#2
0
def local_inclusive_scan_shuf(tid, value, nelem, temp):
    """Return ``shuf_device_inclusive_scan(value, temp)`` fenced by barriers.

    ``tid`` and ``nelem`` are accepted but not used by this function.

    * temp: shared array
        Size of the array must be at least the number of active waves.

    Note: This function must be called by all threads in the block.
    """
    # Synchronise the whole group, then the wave, before scanning.
    hsa.barrier()
    hsa.wavebarrier()
    scanned = shuf_device_inclusive_scan(value, temp)
    # Second group-wide barrier before the result is handed back.
    hsa.barrier()
    return scanned
示例#3
0
def wave_reduce(val):
    """Combine ``val`` across a wave by repeatedly folding the upper half.

    At each step a lane fetches the value associated with lane
    ``lane + half`` and, if it sits in the lower half, adds it to its own
    accumulator; ``half`` starts at ``WAVESIZE // 2`` and is halved until
    it reaches zero.  The value finally fetched from lane 0 is returned.

    NOTE(review): assumes ``hsa.activelanepermute_wavewidth`` yields the
    value held by the requested lane -- confirm against the HSA docs.
    """
    lane = hsa.get_local_id(0) & (WAVESIZE - 1)
    acc = val

    half = WAVESIZE // 2
    while half > 0:
        hsa.wavebarrier()
        # Every lane takes part in the permute; only the lower half keeps
        # the fetched value.
        partner = hsa.activelanepermute_wavewidth(acc, lane + half, 0, False)
        if lane < half:
            acc += partner
        half >>= 1

    # Lane 0 now holds the combined value; hand it to every lane.
    hsa.wavebarrier()
    return hsa.activelanepermute_wavewidth(acc, 0, 0, False)
示例#4
0
def wave_reduce(val):
    """Combine ``val`` across a wave by repeatedly folding the upper half.

    At each step a lane fetches the value associated with lane
    ``lane + half`` and, if it sits in the lower half, adds it to its own
    accumulator; ``half`` starts at ``WAVESIZE // 2`` and is halved until
    it reaches zero.  The value finally fetched from lane 0 is returned.

    NOTE(review): assumes ``hsa.activelanepermute_wavewidth`` yields the
    value held by the requested lane -- confirm against the HSA docs.
    """
    lane = hsa.get_local_id(0) & (WAVESIZE - 1)
    acc = val

    half = WAVESIZE // 2
    while half > 0:
        hsa.wavebarrier()
        # Every lane takes part in the permute; only the lower half keeps
        # the fetched value.
        partner = hsa.activelanepermute_wavewidth(acc, lane + half, 0, False)
        if lane < half:
            acc += partner
        half >>= 1

    # Lane 0 now holds the combined value; hand it to every lane.
    hsa.wavebarrier()
    return hsa.activelanepermute_wavewidth(acc, 0, 0, False)
示例#5
0
def warp_scan(tid, temp, inclusive):
    """Intra-warp scan of ``temp``, performed in place.

    Each lane adds the entry ``stride`` slots below its own for strides
    1, 2, 4, 8, 16 and 32.  When ``inclusive`` is true the caller's own
    entry is returned; otherwise the entry one slot below is returned, or
    0 for lane 0.

    Note
    ----
    Assume all threads are in lockstep
    """
    lane = tid & (_WARPSIZE - 1)

    stride = 1
    while stride <= 32:
        hsa.wavebarrier()
        # Lanes closer than ``stride`` to the start of the warp have
        # nothing further down to add.
        if lane >= stride:
            temp[tid] += temp[tid - stride]
        stride *= 2

    hsa.wavebarrier()
    if inclusive:
        return temp[tid]
    # Exclusive result: the neighbour's inclusive value, or 0 for lane 0.
    return temp[tid - 1] if lane > 0 else 0
示例#6
0
def shuf_wave_inclusive_scan(val):
    """Scan ``val`` across a wave using ``shuffle_up`` exchanges.

    For distances 1, 2, 4, 8, 16 and 32 every lane obtains
    ``shuffle_up(val, distance)`` and, when its lane index is at least
    ``distance``, adds the obtained value to its running ``val``.  The
    accumulated ``val`` is returned.
    """
    lane = hsa.get_local_id(0) & (_WARPSIZE - 1)

    distance = 1
    while distance <= 32:
        hsa.wavebarrier()
        # All lanes call shuffle_up; only lanes far enough from the start
        # of the wave use what they received.
        upstream = shuffle_up(val, distance)
        if lane >= distance:
            val += upstream
        distance *= 2

    hsa.wavebarrier()
    return val
示例#7
0
def shuffle_up(val, width):
    """Permute ``val`` across the wave with source index ``tid - width``.

    The source index is computed from the work-group local id, not from a
    lane index masked to the wave width.

    NOTE(review): presumably the permute returns the value held by the
    lane at the given index -- confirm against the HSA intrinsic docs.
    """
    source = hsa.get_local_id(0) - width
    hsa.wavebarrier()
    return hsa.activelanepermute_wavewidth(val, source, 0, False)
def shuf_wave_inclusive_scan(val):
    """Scan ``val`` across a wave using ``shuffle_up`` exchanges.

    For distances 1, 2, 4, 8, 16 and 32 every lane obtains
    ``shuffle_up(val, distance)`` and, when its lane index is at least
    ``distance``, adds the obtained value to its running ``val``.  The
    accumulated ``val`` is returned.
    """
    lane = hsa.get_local_id(0) & (_WARPSIZE - 1)

    distance = 1
    while distance <= 32:
        hsa.wavebarrier()
        # All lanes call shuffle_up; only lanes far enough from the start
        # of the wave use what they received.
        upstream = shuffle_up(val, distance)
        if lane >= distance:
            val += upstream
        distance *= 2

    hsa.wavebarrier()
    return val
示例#9
0
def broadcast(val, src):
    """Permute ``val`` across the wave using ``src`` as the source index.

    NOTE(review): presumably every lane receives the value held by lane
    ``src`` -- confirm against the HSA intrinsic docs.
    """
    hsa.wavebarrier()
    received = hsa.activelanepermute_wavewidth(val, src, 0, False)
    return received
示例#10
0
def shuffle_up(val, width):
    """Permute ``val`` across the wave with source index ``tid - width``.

    The source index is computed from the work-group local id, not from a
    lane index masked to the wave width.

    NOTE(review): presumably the permute returns the value held by the
    lane at the given index -- confirm against the HSA intrinsic docs.
    """
    source = hsa.get_local_id(0) - width
    hsa.wavebarrier()
    return hsa.activelanepermute_wavewidth(val, source, 0, False)
示例#11
0
def warp_scan(tid, temp, inclusive):
    """Intra-warp scan of ``temp``, performed in place.

    Each lane adds the entry ``stride`` slots below its own for strides
    1, 2, 4, 8, 16 and 32.  When ``inclusive`` is true the caller's own
    entry is returned; otherwise the entry one slot below is returned, or
    0 for lane 0.

    Note
    ----
    Assume all threads are in lockstep
    """
    lane = tid & (_WARPSIZE - 1)

    stride = 1
    while stride <= 32:
        hsa.wavebarrier()
        # Lanes closer than ``stride`` to the start of the warp have
        # nothing further down to add.
        if lane >= stride:
            temp[tid] += temp[tid - stride]
        stride *= 2

    hsa.wavebarrier()
    if inclusive:
        return temp[tid]
    # Exclusive result: the neighbour's inclusive value, or 0 for lane 0.
    return temp[tid - 1] if lane > 0 else 0