# Example #1
# 0
def qpu_somatcopy_n(asm,
                    *,
                    num_qpus,
                    unroll,
                    code_offset,
                    align_cond=lambda pos: True):
    """Emit the QPU program for a scaled, non-transposed matrix copy.

    The generated program reads seven uniforms, in this order: rows, cols,
    alpha, a, lda, b, ldb.  For every row handled by this QPU it issues
    loads from ``a`` (writes to ``tmua``), multiplies each loaded vector by
    ``alpha`` and pushes the product to ``tmud`` followed by the matching
    ``b`` address to ``tmua``.

    Args:
        asm: Instruction list being assembled.  This function only reads
            ``len(asm)`` (for alignment); the instruction helpers used below
            (``nop``, ``mov``, ...) are defined outside this block and
            presumably append to it.
        num_qpus: Number of cooperating QPUs; must be 1 or 8.
        unroll: Unroll factor of the inner column loop; must be 1, 2, 4 or 8.
        code_offset: Added to ``len(asm)`` to form the position handed to
            ``align_cond``.
        align_cond: Predicate on the position of the outer loop's first
            instruction; ``nop`` is emitted until it returns true.

    Raises:
        ValueError: If ``unroll`` is not one of 1, 2, 4, 8.
        Exception: If ``num_qpus`` is neither 1 nor 8.

    Side effects:
        Binds the module globals ``reg_rows`` ... ``reg_rest_j`` to
        ``rf[0]`` ... ``rf[12]``.
    """

    # Explicit check instead of ``assert`` so the validation is not stripped
    # when Python runs with -O.
    if unroll not in (1, 2, 4, 8):
        raise ValueError(f'unroll must be 1, 2, 4 or 8, got {unroll!r}')

    # Give the register-file slots symbolic names (reg_rows = rf[0], ...).
    g = globals()
    for i, v in enumerate([
            'rows', 'cols', 'alpha', 'a', 'lda', 'b', 'ldb', 'qpu_num', 'a_i',
            'b_i', 'inc', 'rest_i', 'rest_j'
    ]):
        g[f'reg_{v}'] = rf[i]

    # Fetch the seven kernel arguments from the uniform stream.
    nop(sig=ldunifrf(reg_rows))
    nop(sig=ldunifrf(reg_cols))
    nop(sig=ldunifrf(reg_alpha))
    nop(sig=ldunifrf(reg_a))
    nop(sig=ldunifrf(reg_lda))
    nop(sig=ldunifrf(reg_b))
    nop(sig=ldunifrf(reg_ldb))

    # qpu_num = 0 for a single QPU, otherwise (tidx >> 2) & 0b1111.
    if num_qpus == 1:
        mov(reg_qpu_num, 0)
    elif num_qpus == 8:
        tidx(r0)
        shr(r0, r0, 2)
        band(reg_qpu_num, r0, 0b1111)
    else:
        raise Exception('num_qpus must be 1 or 8')

    # a += 4 * lda * qpu_num + 4 * eidx
    # b += 4 * ldb * qpu_num + 4 * eidx
    # NOTE(review): the code adds lda * qpu_num without an extra factor of 4,
    # so lda/ldb are presumably already byte strides - confirm with the host.
    eidx(r1).umul24(r0, reg_lda, reg_qpu_num)
    shl(r1, r1, ilog2(4)).add(reg_a, reg_a, r0)
    # a_i / b_i keep the start of the current row; a / b walk along the row.
    add(reg_a, reg_a, r1).add(reg_a_i, reg_a, r1)
    umul24(r0, reg_ldb, reg_qpu_num)
    add(reg_b, reg_b, r0)
    add(reg_b, reg_b, r1).add(reg_b_i, reg_b, r1)

    # lda *= num_qpus
    # ldb *= num_qpus
    umul24(reg_lda, reg_lda, num_qpus)
    umul24(reg_ldb, reg_ldb, num_qpus)

    # inc = 4 << 4 = 64: the per-step address increment of a and b.
    shl(reg_inc, 4, 4)

    # rest_i = rows - qpu_num - 1
    sub(r0, reg_rows, reg_qpu_num)
    sub(reg_rest_i, r0, 1)

    nop(sig=thrsw)
    nop()
    nop()

    # Pad with nops until the outer loop starts at an accepted position.
    while not align_cond(code_offset + len(asm)):
        nop()

    # Outer loop: one iteration per row handled by this QPU.
    with loop as li:

        # r0 = cols / (16 * unroll); if zero, go straight to the remainder.
        # The three instructions after each branch are laid out as its delay
        # slots throughout this function.
        shr(r0, reg_cols, ilog2(16) + ilog2(unroll), cond='pushz')
        b(R.rest_j, cond='a0')
        nop()
        nop()
        # rest_j = r0 - 2; the negative flag feeds the skip_j_unroll branch.
        sub(reg_rest_j, r0, 2, cond='pushn')

        # Issue the first `unroll` loads: unroll - 3 here plus up to three in
        # the delay slots of the branch below.
        for _ in range(unroll - 3):
            mov(tmua, reg_a).add(reg_a, reg_a, reg_inc)

        b(R.skip_j_unroll, cond='a0')
        mov(tmua, reg_a).add(reg_a, reg_a, reg_inc)
        mov(tmua, reg_a).add(reg_a, reg_a, reg_inc) if unroll >= 2 else nop()
        mov(tmua, reg_a).add(reg_a, reg_a, reg_inc) if unroll >= 3 else nop()

        # Main column loop: per step, receive one vector, scale and store it,
        # and issue the next load.
        with loop as lj:

            sub(reg_rest_j, reg_rest_j, 1, cond='pushn')

            for _ in range(unroll - 1):
                nop(sig=ldtmu(r0))
                fmul(tmud, r0, reg_alpha)
                mov(tmua, reg_b).add(reg_b, reg_b, reg_inc)
                mov(tmua, reg_a).add(reg_a, reg_a, reg_inc)

            # Last step of the unrolled body shares its slots with the
            # loop-back branch.
            nop(sig=ldtmu(r0))
            lj.b(cond='na0')
            fmul(tmud, r0, reg_alpha)
            mov(tmua, reg_b).add(reg_b, reg_b, reg_inc)
            mov(tmua, reg_a).add(reg_a, reg_a, reg_inc)

        L.skip_j_unroll

        # Drain the `unroll` loads still in flight without issuing new ones.
        for _ in range(unroll):
            nop(sig=ldtmu(r0))
            fmul(tmud, r0, reg_alpha)
            mov(tmua, reg_b).add(reg_b, reg_b, reg_inc)

        L.rest_j

        # r0 = cols & (16 * unroll - 1): columns left over after the unrolled
        # part.  Nothing left -> exit_j.
        mov(r0, 1)
        shl(r0, r0, ilog2(16) + ilog2(unroll))
        sub(r0, r0, 1)
        band(r0, reg_cols, r0, cond='pushz')
        b(R.exit_j, cond='a0')
        # Delay slots: r1 = rest - 1, r0 = rest - 17 (negative flag pushed).
        sub(r0, r0, 1)
        add(r0, r0, -16, cond='pushn').mov(r1, r0)
        nop()
        # Issue one load per 16 remaining columns.
        with loop as lj0:
            lj0.b(cond='na0')
            mov(tmua, reg_a).add(reg_a, reg_a, reg_inc)
            add(r0, r0, -16, cond='pushn')
            nop()
        # Receive, scale and store the remaining vectors.  Data and store
        # addresses are rotated by r5 = -1 - r1.
        # NOTE(review): the per-lane select below presumably keeps the final
        # partial vector from writing past the row end - confirm.
        with loop as lj1:
            sub(r5, -1, r1)
            nop(sig=ldtmu(r0))
            mov(r3, reg_b)
            rotate(r0, r0, r5)
            rotate(r3, r3, r5)
            fmul(tmud, r0, reg_alpha)

            mov(broadcast, reg_b).add(reg_b, reg_b, reg_inc)
            # Flag the lanes where eidx + r1 - 16 < 0 and overwrite their
            # store address with r5.
            eidx(r0).add(r2, r1, -16)
            add(null, r0, r2, cond='pushn')
            mov(r3, r5, cond='ifa')

            add(r1, r1, -16, cond='pushn')
            lj1.b(cond='na0')
            mov(tmua, r3)
            nop()
            nop()

        L.exit_j

        # Advance to the next row of this QPU; loop while rest_i has not gone
        # negative.  The delay slots move a_i/b_i down and reset a/b.
        sub(reg_rest_i, reg_rest_i, num_qpus, cond='pushn')
        li.b(cond='na0')
        add(reg_a_i, reg_a_i, reg_lda)
        add(reg_b_i, reg_b_i, reg_ldb)
        mov(reg_a, reg_a_i).mov(reg_b, reg_b_i)

    # Barrier followed by the closing thrsw/nop sequence (presumably the
    # standard program epilogue of this DSL).
    barrierid(syncb, sig=thrsw)
    nop()
    nop()

    nop(sig=thrsw)
    nop(sig=thrsw)
    nop()
    nop()
    nop(sig=thrsw)
    nop()
    nop()
    nop()
# Example #2
# 0
def qpu_comatcopy_t(asm,
                    *,
                    num_qpus,
                    tile_rows,
                    tile_cols,
                    subtile_rows,
                    subtile_cols,
                    code_offset,
                    align_cond=lambda pos: True):
    """Emit the QPU program for a scaled, transposed complex matrix copy.

    The generated program reads eight uniforms, in this order: ROWS, COLS,
    ALPHA_R, ALPHA_I, A_I, LDA, B_I, LDB.  It keeps its scalar state in the
    lanes of ``r4`` (lane numbers are the ``IDX4_*`` constants).  For each
    tile it loads ``tile_rows * tile_cols`` pairs of words into
    ``rf[2 * j]`` / ``rf[2 * j + 1]``, multiplies each pair by
    (ALPHA_R, ALPHA_I) as a complex product and stores the result at the
    transposed position in B.

    Args:
        asm: Instruction list being assembled.  This function only reads
            ``len(asm)`` (for alignment); the instruction helpers used below
            are defined outside this block and presumably append to it.
        num_qpus: Number of cooperating QPUs; must be 1 or 8.
        tile_rows: Tile height in subtiles.
        tile_cols: Tile width in subtiles.  ``tile_rows * tile_cols`` must be
            between 1 and 32 (two registers are used per subtile).
        subtile_rows: Subtile height in elements.
        subtile_cols: Subtile width in elements;
            ``subtile_rows * subtile_cols`` must be 16.
        code_offset: Added to ``len(asm)`` to form the position handed to
            ``align_cond``.
        align_cond: Predicate on the position of the outer loop's first
            instruction; ``nop`` is emitted until it returns true.

    Raises:
        ValueError: If the tile or subtile sizes violate the limits above.
        Exception: If ``num_qpus`` is neither 1 nor 8.
    """

    # Explicit checks instead of ``assert`` so the validation is not stripped
    # when Python runs with -O.
    tile_size = tile_rows * tile_cols
    if not 1 <= tile_size <= 32:
        raise ValueError(
            f'tile_rows * tile_cols must be in 1..32, got {tile_size!r}')

    if subtile_rows * subtile_cols != 16:
        raise ValueError(
            'subtile_rows * subtile_cols must be 16, '
            f'got {subtile_rows * subtile_cols!r}')

    # Lane numbers of the state vector r4.
    IDX4_ROWS, IDX4_COLS, IDX4_ALPHA_R, IDX4_ALPHA_I, IDX4_A_I, IDX4_A_J, \
        IDX4_B_I, IDX4_B_J, IDX4_LDA, IDX4_LDB, IDX4_I, IDX4_J = range(12)

    # r0 = eidx
    # r1 = qpu_num
    eidx(r0).mov(r4, 0)
    if num_qpus == 1:
        mov(r1, 0)
    elif num_qpus == 8:
        tidx(r1)
        shr(r1, r1, 2)
        band(r1, r1, 0b1111)
    else:
        raise Exception('num_qpus must be 1 or 8')

    # Load each uniform and store it into its lane of r4: the zero flag of
    # (eidx - idx) selects exactly lane idx.  The value is read from r5
    # right after the ldunif signal.
    for idx in [
            IDX4_ROWS, IDX4_COLS, IDX4_ALPHA_R, IDX4_ALPHA_I, IDX4_A_I,
            IDX4_LDA, IDX4_B_I, IDX4_LDB
    ]:
        nop(sig=ldunif)
        sub(null, r0, idx, cond='pushz')
        mov(r4, r5, cond='ifa')

    # I = ROWS / subtile_rows / tile_rows - qpu_num - num_qpus - 1
    # (rotate-to-broadcast is used throughout to read one lane of r4; the
    # following instructions read the value from r5.)
    sub(null, r0, IDX4_I, cond='pushz')
    rotate(broadcast, r4, -IDX4_ROWS)
    shr(r5, r5, ilog2(tile_rows * subtile_rows))
    sub(r5, r5, r1)
    sub(r4, r5, num_qpus + 1, cond='ifa')

    # A_I += LDA * qpu_num * tile_rows * subtile_rows * 8
    # B_I += qpu_num * tile_rows * subtile_rows * 8
    # NOTE(review): the A_I update multiplies by LDA without the extra
    # factor of 8, so LDA is presumably already a byte stride - confirm.
    shl(r1, r1, ilog2(tile_rows * subtile_rows))
    rotate(broadcast, r4, -IDX4_LDA)
    sub(null, r0, IDX4_A_I, cond='pushz').umul24(r5, r1, r5)
    add(r4, r4, r5, cond='ifa').umul24(r5, r1, 8)
    sub(null, r0, IDX4_B_I, cond='pushz')
    add(r4, r4, r5, cond='ifa')

    nop(sig=thrsw)
    nop()
    nop()

    # Pad with nops until the outer loop starts at an accepted position.
    while not align_cond(code_offset + len(asm)):
        nop()

    # Outer loop over tile rows handled by this QPU.
    with loop as li:

        # J = COLS / subtile_cols / tile_cols - 1
        eidx(r0).rotate(broadcast, r4, -IDX4_COLS)
        sub(null, r0, IDX4_J, cond='pushz')
        shr(r5, r5, ilog2(tile_cols * subtile_cols))
        sub(r4, r5, 1, cond='ifa')

        # A_J = A_I
        # B_J = B_I
        sub(null, r0, IDX4_A_J, cond='pushz')
        rotate(broadcast, r4, -IDX4_A_I)
        mov(r4, r5, cond='ifa')
        sub(null, r0, IDX4_B_J, cond='pushz')
        rotate(broadcast, r4, -IDX4_B_I)
        mov(r4, r5, cond='ifa')

        # Inner loop over tiles along the columns of A.
        with loop as lj:

            # r0 = A_J + eidx / subtile_cols * LDA * 4 + eidx % subtile_cols * 8
            # r1 = subtile_cols * 8
            eidx(r1).rotate(broadcast, r4, -IDX4_A_J)
            band(r0, r1, subtile_cols - 1)
            shl(r0, r0, ilog2(8))
            add(r0, r0, r5).rotate(broadcast, r4, -IDX4_LDA)
            shr(r1, r1, ilog2(subtile_cols))
            mov(r5, 1).umul24(r1, r1, r5)
            shl(r1, r5, ilog2(subtile_cols * 8)).add(r0, r0, r1)

            # Number of subtile loads issued before the first result is read.
            tmu_unroll = min(tile_size, 4)

            # Load phase: step i issues the load for subtile i (two words,
            # at r0 and r0 + 4) and receives subtile i - tmu_unroll into
            # rf[2 * j] / rf[2 * j + 1].
            for i in range(tile_size + tmu_unroll):
                if i > 0 and i % tile_cols == 0:
                    # r0 = -tile_cols * subtile_cols * 8 + LDA * subtile_rows * 8
                    shl(r5, r1, ilog2(tile_cols))
                    sub(r0, r0, r5).rotate(broadcast, r4, -IDX4_LDA)
                    shl(r5, r5, ilog2(subtile_rows))
                    add(r0, r0, r5)
                if i < tmu_unroll:
                    mov(tmua, r0).mov(r5, 4)
                    add(tmua, r0, r5).add(r0, r0, r1)
                if tmu_unroll <= i:
                    j = i - tmu_unroll
                    nop(sig=ldtmu(rf[0 + j * 2]))
                    nop(sig=ldtmu(rf[1 + j * 2]))
                if tmu_unroll <= i < tile_size:
                    mov(tmua, r0).mov(r5, 4)
                    add(tmua, r0, r5).add(r0, r0, r1)

            # r0 = B_J + eidx % subtile_cols * LDB * 8 + eidx / subtile_cols * 8
            # r1 = subtile_rows * 8
            # r3 = ALPHA_R
            # r5 = ALPHA_I
            eidx(r1).rotate(broadcast, r4, -IDX4_B_J)
            shr(r0, r1, ilog2(subtile_cols))
            shl(r0, r0, ilog2(8))
            add(r0, r0, r5).rotate(broadcast, r4, -IDX4_LDB)
            band(r1, r1, subtile_cols - 1)
            mov(r5, 1).umul24(r1, r1, r5)
            shl(r1, r5, ilog2(subtile_rows * 8)).add(r0, r0, r1)
            rotate(broadcast, r4, -IDX4_ALPHA_R)
            mov(r3, r5).rotate(broadcast, r4, -IDX4_ALPHA_I)

            # Store phase: subtiles are visited in transposed order (index j
            # below).
            for i in range(tile_size):
                if i > 0 and i % tile_rows == 0:
                    # r0 = -tile_rows * subtile_rows * 8 + LDB * subtile_cols * 8
                    shl(r5, r1, ilog2(tile_rows))
                    sub(r0, r0, r5).rotate(broadcast, r4, -IDX4_LDB)
                    shl(r5, r5, ilog2(subtile_cols))
                    # r5 was used as scratch, so ALPHA_R / ALPHA_I are
                    # reloaded into r3 / r5.
                    add(r0, r0, r5).rotate(broadcast, r4, -IDX4_ALPHA_R)
                    mov(r3, r5).rotate(broadcast, r4, -IDX4_ALPHA_I)
                j = i % tile_rows * tile_cols + i // tile_rows
                # Complex product with alpha:
                #   first word  = w0 * ALPHA_R - w1 * ALPHA_I
                #   second word = w0 * ALPHA_I + w1 * ALPHA_R
                # where w0 = rf[2 * j], w1 = rf[2 * j + 1].
                fmul(r1, rf[0 + j * 2], r3)
                fmul(r2, rf[1 + j * 2], r5)
                fsub(tmud, r1, r2).fmul(r1, rf[0 + j * 2], r5)
                mov(tmua, r0).fmul(r2, rf[1 + j * 2], r3)
                # r1 is reused as scratch above, so the step subtile_rows * 8
                # is rebuilt before advancing r0.
                fadd(tmud, r1, r2).mov(r1, 1)
                shl(r1, r1, ilog2(subtile_rows * 8))
                add(tmua, r0, 4).add(r0, r0, r1)

            # A_J += tile_cols * subtile_cols * 8
            # B_J += LDB * tile_cols * subtile_cols * 8
            eidx(r0).mov(r5, 1)
            sub(null, r0, IDX4_A_J, cond='pushz')
            shl(r5, r5, ilog2(tile_cols * subtile_cols * 8))
            add(r4, r4, r5, cond='ifa').rotate(broadcast, r4, -IDX4_LDB)
            sub(null, r0, IDX4_B_J, cond='pushz')
            shl(r5, r5, ilog2(tile_cols * subtile_cols))
            add(r4, r4, r5, cond='ifa')

            # Loop while J != 0; J is decremented in the branch delay slots.
            rotate(null, r4, -IDX4_J, cond='pushz')
            lj.b(cond='na0')
            sub(null, r0, IDX4_J, cond='pushz')
            sub(r4, r4, 1, cond='ifa')
            nop()

        # A_I += LDA * num_qpus * tile_rows * subtile_rows * 8
        # B_I += num_qpus * tile_rows * subtile_rows * 8
        eidx(r0).rotate(broadcast, r4, -IDX4_LDA)
        mov(r1, 1)
        shl(r1, r1, ilog2(num_qpus * tile_rows * subtile_rows))
        sub(null, r0, IDX4_A_I, cond='pushz').umul24(r5, r1, r5)
        add(r4, r4, r5, cond='ifa').sub(null, r0, IDX4_B_I, cond='pushz')
        shl(r5, r1, ilog2(8))
        add(r4, r4, r5, cond='ifa')

        # Loop while I has not gone negative; I -= num_qpus in the branch
        # delay slots.
        rotate(null, r4, -IDX4_I, cond='pushn')
        li.b(cond='na0')
        sub(null, r0, IDX4_I, cond='pushz')
        sub(r4, r4, num_qpus, cond='ifa')
        nop()

    # Barrier followed by the closing thrsw/nop sequence (presumably the
    # standard program epilogue of this DSL).
    barrierid(syncb, sig=thrsw)
    nop()
    nop()

    nop(sig=thrsw)
    nop(sig=thrsw)
    nop()
    nop()
    nop(sig=thrsw)
    nop()
    nop()
    nop()
# Example #3
# 0
def qpu_fft2(asm, *, num_qpus, do_unroll):
    """Emit the QPU program for an FFT built from two-point butterflies.

    The generated program reads five uniforms, in this order: n, x, y, buf,
    omega_addr.  Only the thread whose ``(tidx >> 2) & 0b1111`` is zero does
    the work; every other thread branches straight to the exit label.  The
    stage loop starts with ``j = 1, k = n / 2`` and doubles ``j`` / halves
    ``k`` per stage until ``k`` reaches zero.  The source and destination
    pointers are exchanged after every stage, with ``buf`` brought in as
    the destination after the first stage.

    Args:
        asm: Instruction list being assembled.  Not referenced directly in
            this body; the instruction helpers are defined outside this
            block and presumably append to it.
        num_qpus: Must be 8.
        do_unroll: When true, four extra TMU loads are issued ahead of the
            innermost loop and four results are discarded
            (``ldtmu(null)``) after it.

    Raises:
        Exception: If ``num_qpus`` is not 8 (raised after the preamble
            instructions have already been emitted).

    Side effects:
        Binds the module globals ``reg_j`` ... ``reg_omega_c`` to
        ``rf[0]`` ... ``rf[14]``.
    """

    # Give the register-file slots symbolic names (reg_j = rf[0], ...).
    g = globals()
    for i, name in enumerate(['j', 'k', 'l', 'm', 'n', 'x', 'x0', 'x1', 'y',
                              'y0', 'y1', 'buf', 'omega_addr', 'omega_r',
                              'omega_c']):
        g[f'reg_{name}'] = rf[i]

    # Fetch the kernel arguments from the uniform stream.
    nop(sig=ldunifrf(reg_n))
    nop(sig=ldunifrf(reg_x))
    nop(sig=ldunifrf(reg_y))
    nop(sig=ldunifrf(reg_buf))
    nop(sig=ldunifrf(reg_omega_addr))

    nop(sig=thrsw)
    nop()
    nop()

    # Zero flag is pushed when (tidx >> 2) & 0b1111 == 0; all other threads
    # take the branch to the exit label (three nops fill the delay slots).
    tidx(r0)
    shr(r0, r0, 2)
    band(null, r0, 0b1111, cond='pushz')
    if num_qpus == 8:
        b(R.exit, cond='na0')
        nop()
        nop()
        nop()
    else:
        raise Exception('num_qpus must be 8')

    # Branch to the very next label while pointing the uniform stream at
    # omega_addr; the ldunifrf signals below then read omega_r / omega_c.
    # NOTE(review): presumably the only way to redirect the uniform address
    # is via a branch - confirm against the DSL documentation.
    b(R.set_unif, cond='always').unif_addr(reg_omega_addr)
    nop()
    nop()
    nop()
    L.set_unif

    # k = n / 2, j = 1
    shr(reg_k, reg_n, 1).mov(reg_j, 1)
    # Stage loop.
    with loop as ljk:

        # l runs from 0 to k - 1.
        mov(reg_l, 0)
        with loop as ll:

            # m = j - 1
            # x0 = x + (j * l + eidx) * 8
            # x1 = x0 + (n / 2) * 8
            # y0 = y + (2 * j * l + eidx) * 8
            # y1 = y0 + j * 8
            sub(reg_m, reg_j, 1)
            eidx(r0).umul24(r1, reg_j, reg_l)
            add(r0, r0, r1)
            shl(r0, r0, ilog2(8)).add(r1, r0, r1)
            shl(r1, r1, ilog2(8)).add(reg_x0, reg_x, r0)
            add(reg_y0, reg_y, r1)
            shl(r0, reg_n, ilog2(8 // 2))
            # The two ldunifrf signals fetch this l's (omega_r, omega_c).
            add(reg_x1, reg_x0, r0, sig=ldunifrf(reg_omega_r))
            shl(r0, reg_j, ilog2(8))
            add(reg_y1, reg_y0, r0, sig=ldunifrf(reg_omega_c))

            if do_unroll:
                # Issue the four loads of the first 16-element group up
                # front and advance x0 / x1 by 8 * 16 bytes, so the loop
                # below always requests one group ahead.
                mov(r0, 1).mov(tmua, reg_x0)
                shl(r0, r0, ilog2(8 * 16)).mov(tmua, reg_x1)
                add(reg_x0, reg_x0, r0).add(tmua, reg_x0, 4)
                add(reg_x1, reg_x1, r0).add(tmua, reg_x1, 4)

            # Innermost loop: 16 elements per iteration, m -= 16 each time.
            with loop as lm:

                # Request the first words at x0 and x1; r2 = 4 is the offset
                # of the second words requested a few instructions later.
                mov(tmua, reg_x0)
                mov(tmua, reg_x1).mov(r2, 4)

                # eidx < 16 - (m + 1) ∴ m + eidx - 15 < 0
                # r5 = -1 - (rest - 1) = -rest
                eidx(r0).add(r1, reg_m, -15)
                add(null, r0, r1, cond='pushn').sub(r5, -1, reg_m)
                # y0 / y1 are rotated by r5; the flagged lanes are then
                # restored to the unrotated addresses saved in r0 / r1.
                mov(r0, reg_y0).mov(r1, reg_y1)
                add(tmua, reg_x0, r2).rotate(reg_y0, r0, r5)
                add(tmua, reg_x1, r2).rotate(reg_y1, r1, r5)
                mov(reg_y0, r0, cond='ifa').mov(reg_y1, r1, cond='ifa')

                # First words of both inputs: r2 = sum (stored at y0),
                # r0 = difference.
                nop(sig=ldtmu(r0))
                nop(sig=ldtmu(r1))
                fadd(r2, r0, r1)
                fsub(r0, r0, r1).rotate(tmud, r2, r5)
                # Second words: r3 = sum (stored at y0 + 4), r1 = difference.
                mov(tmua, reg_y0, sig=ldtmu(r1))
                nop(sig=ldtmu(r2))
                fadd(r3, r1, r2)
                fsub(r1, r1, r2).rotate(r3, r3, r5)
                # (r0, r1) * (omega_r, omega_c) as a complex product:
                #   first word  = r0 * omega_r - r1 * omega_c  -> y1
                #   second word = r0 * omega_c + r1 * omega_r  -> y1 + 4
                mov(tmud, r3).fmul(r2, r0, reg_omega_r)
                add(tmua, reg_y0, 4)
                fmul(r3, r1, reg_omega_c)
                # m -= 16; the negative flag decides the loop-back branch.
                fsub(r2, r2, r3).add(reg_m, reg_m, -16, cond='pushn')
                rotate(tmud, r2, r5)
                mov(tmua, reg_y1).fmul(r2, r0, reg_omega_c)
                fmul(r3, r1, reg_omega_r)
                fadd(r2, r2, r3).mov(r0, 1)
                rotate(tmud, r2, r5)
                add(tmua, reg_y1, 4)

                # Loop while m has not gone negative; the delay slots
                # advance all four pointers by 8 * 16 bytes.
                lm.b(cond='na0')
                shl(r0, r0, ilog2(8 * 16))
                add(reg_x0, reg_x0, r0).add(reg_x1, reg_x1, r0)
                add(reg_y0, reg_y0, r0).add(reg_y1, reg_y1, r0)

            # With do_unroll, four loads requested one group ahead are still
            # outstanding; discard them (one here, three in the delay slots).
            nop(sig=ldtmu(null)) if do_unroll else nop()
            # l += 1; loop while l - k < 0.
            add(reg_l, reg_l, 1)
            sub(null, reg_l, reg_k, cond='pushn')
            ll.b(cond='a0')
            nop(sig=ldtmu(null)) if do_unroll else nop()
            nop(sig=ldtmu(null)) if do_unroll else nop()
            nop(sig=ldtmu(null)) if do_unroll else nop()

        # k >>= 1; loop while k != 0.  Delay slots: j <<= 1 and the buffer
        # exchange.  The zero flag of (old j - 1) marks the first stage, after
        # which x = old y and y = buf; on later stages x and y are swapped.
        shr(reg_k, reg_k, 1, cond='pushz')
        ljk.b(cond='na0')
        shl(reg_j, reg_j, 1).sub(null, reg_j, 1, cond='pushz')
        mov(reg_y, reg_buf, cond='ifa').mov(r0, reg_y)
        mov(reg_y, reg_x, cond='ifna').mov(reg_x, r0)

    L.exit

    # Barrier followed by the closing thrsw/nop sequence (presumably the
    # standard program epilogue of this DSL).
    barrierid(syncb, sig=thrsw)
    nop()
    nop()

    nop(sig=thrsw)
    nop(sig=thrsw)
    nop()
    nop()
    nop(sig=thrsw)
    nop()
    nop()
    nop()
# Example #4
# 0
def qpu_sgemm_rtt(asm, *, num_qpus, code_offset, align_cond=lambda pos: True):
    """Emit the QPU program for a single-precision matrix multiply-add.

    Per the formula comments below, the program computes
    ``alpha * (P x Q) * (Q x R) + beta * (P x R)``.  It reads eleven
    uniforms; they are stored, in order, into the lanes IDX0_I, IDX0_R,
    IDX0_Q, IDX0_A, IDX0_B, IDX0_C, IDX0_LDA, IDX0_LDB, IDX0_LDC,
    IDX0_ALPHA, IDX0_BETA of the state vector ``r0`` (the first one is P,
    from which the row counter I is derived).  ``rf[0]`` ... ``rf[15]``
    hold the 16 accumulator rows of the current block of C.

    Args:
        asm: Instruction list being assembled.  This function only reads
            ``len(asm)`` (for alignment); the instruction helpers used below
            are defined outside this block and presumably append to it.
        num_qpus: Number of cooperating QPUs; must be 1 or 8.
        code_offset: Added to ``len(asm)`` to form the position handed to
            ``align_cond``.
        align_cond: Predicate on the position of the outer loop's first
            instruction; ``nop`` is emitted until it returns true.

    Raises:
        Exception: If ``num_qpus`` is neither 1 nor 8.
    """

    # α ⋅ (P × Q) ⋅ (Q × R) + β ⋅ (P × R)
    # α ⋅ (m × k) ⋅ (k × n) + β ⋅ (m × n)
    # α ⋅ (i × k) ⋅ (k × j) + β ⋅ (i × j)

    # Lane numbers of the state vector r0.
    IDX0_R, IDX0_Q, IDX0_A, IDX0_B, IDX0_C, IDX0_A_CUR, IDX0_B_CUR, \
        IDX0_LDA, IDX0_LDB, IDX0_LDC, IDX0_ALPHA, IDX0_BETA, \
        IDX0_I, IDX0_J, IDX0_K, IDX0_ROTATE_N = range(16)

    # r1 = qpu_num
    if num_qpus == 1:
        mov(r1, 0)
    elif num_qpus == 8:
        tidx(r1)
        shr(r1, r1, 2)
        band(r1, r1, 0b1111)
    else:
        raise Exception('num_qpus must be 1 or 8')

    # Load each uniform into r5 and store it into its lane of r0: the zero
    # flag of (eidx - idx) selects exactly lane idx.
    eidx(r2).mov(r0, 0)
    for idx in [
            IDX0_I, IDX0_R, IDX0_Q, IDX0_A, IDX0_B, IDX0_C, IDX0_LDA, IDX0_LDB,
            IDX0_LDC, IDX0_ALPHA, IDX0_BETA
    ]:
        nop(sig=ldunifrf(r5))
        sub(null, r2, idx, cond='pushz')
        mov(r0, r5, cond='ifa')

    # LDA *= 4
    sub(null, r2, IDX0_LDA, cond='pushz')
    shl(r0, r0, ilog2(4), cond='ifa')
    # LDB *= 4
    sub(null, r2, IDX0_LDB, cond='pushz')
    shl(r0, r0, ilog2(4), cond='ifa')
    # LDC *= 4
    sub(null, r2, IDX0_LDC, cond='pushz')
    shl(r0, r0, ilog2(4), cond='ifa')
    # A += 4 * qpu_num
    sub(null, r2, IDX0_A, cond='pushz')
    shl(r3, r1, ilog2(4))
    add(r0, r0, r3, cond='ifa')
    # C += 4 * LDC * qpu_num
    # (rotate-to-broadcast is used throughout to read one lane of r0; the
    # following instructions read the value from r5.)
    nop()
    rotate(broadcast, r0, -IDX0_LDC)
    sub(null, r2, IDX0_C, cond='pushz').umul24(r3, r5, r1)
    add(r0, r0, r3, cond='ifa')
    # LDC *= num_qpus
    sub(null, r2, IDX0_LDC, cond='pushz')
    shl(r0, r0, ilog2(num_qpus), cond='ifa')

    # Clear the 16 accumulators rf[0] ... rf[15].
    for i in range(8):
        mov(rf[i], .0).mov(rf[i + 8], .0)

    nop(sig=thrsw)
    nop()
    nop()

    # I = P - 16 * num_qpus - 1
    eidx(r1).mov(r2, 1)
    sub(null, r1, IDX0_I, cond='pushz')
    shl(r2, r2, ilog2(16) + ilog2(num_qpus), cond='ifa')
    sub(r0, r0, r2, cond='ifa')
    sub(r0, r0, 1, cond='ifa')

    # Pad with nops until the outer loop starts at an accepted position.
    while not align_cond(code_offset + len(asm)):
        nop()

    # Outer loop over blocks of 16 * num_qpus rows.
    with loop as lm:

        # J = R - 16 - 1
        # (a plain broadcast of r0 is used here, which is why R has to live
        # in lane 0.)
        assert IDX0_R == 0
        nop()
        eidx(r1).mov(broadcast, r0)
        sub(null, r1, IDX0_J, cond='pushz')
        add(r5, r5, -16)
        sub(r0, r5, 1, cond='ifa')

        # Loop over blocks of 16 columns.
        with loop as ln:

            # For I = rest_m - 16 * num_qpus - 1 and cur_m = qpu_num:
            # r4 = rest_m - cur_m = I + 16 * num_qpus + 1 - qpu_num
            if num_qpus == 1:
                mov(r5, 0).mov(r4, 1)
            elif num_qpus == 8:
                tidx(r5).mov(r4, 1)
                shr(r5, r5, 2)
                band(r5, r5, 0b1111)
            shl(r4, r4, ilog2(16) + ilog2(num_qpus))
            sub(r4, r4, r5).rotate(broadcast, r0, -IDX0_I)
            # If rest_m <= cur_m ∴ we have nothing to do, then exit the m-loop.
            add(r4, r4, r5, cond='pushn')
            b(R.exit_m, cond='a0')
            # The next three instructions sit in the branch delay slots.
            # K = Q - 1
            eidx(r1).rotate(broadcast, r0, -IDX0_Q)
            sub(null, r1, IDX0_K, cond='pushz')
            sub(r0, r5, 1, cond='ifa')

            # A_CUR = A
            # B_CUR = B
            sub(null, r1, IDX0_A_CUR, cond='pushz')
            rotate(broadcast, r0, -IDX0_A)
            mov(r0, r5, cond='ifa')
            sub(null, r1, IDX0_B_CUR, cond='pushz')
            rotate(broadcast, r0, -IDX0_B)
            mov(r0, r5, cond='ifa')

            # Loop over the shared dimension: one rank-1 update per pass.
            with loop as lk:

                # Load *(A_CUR + 4 * num_qpus * eidx)
                # and *(B_CUR + 4 * LDB * eidx)
                eidx(r1)
                shl(r2, r1, ilog2(4) + ilog2(num_qpus))
                rotate(broadcast, r0, -IDX0_A_CUR)
                add(tmua, r5, r2).rotate(broadcast, r0, -IDX0_LDB)
                umul24(r2, r5, r1)
                rotate(broadcast, r0, -IDX0_B_CUR)
                add(tmua, r5, r2)

                # A_CUR += 4 * LDA
                # B_CUR += 4
                eidx(r1).rotate(broadcast, r0, -IDX0_LDA)
                sub(null, r1, IDX0_A_CUR, cond='pushz')
                add(r0, r0, r5, cond='ifa') \
                    .sub(null, r1, IDX0_B_CUR, cond='pushz')
                add(r0, r0, 4, cond='ifa')

                # r1 = 16 values from A, r2 = 16 values from B.
                nop(sig=ldtmu(r1))
                nop(sig=ldtmu(r2))

                # For J = rest_n - 16 - 1:
                # If rest_n - eidx < 1 ∴ J + 16 - eidx < 0, then zero-clear B.
                rotate(broadcast, r0, -IDX0_J)
                eidx(r3).sub(r5, r5, -16)
                sub(null, r5, r3, cond='pushn')
                mov(r2, .0, cond='ifa')

                # rf[i] += A[lane i] * B for i = 0..15: each step multiplies
                # the currently broadcast A value (r5) by the B vector, then
                # broadcasts the next lane of r1.  The zero flag of K pushed
                # here decides the loop-back branch below.
                mov(broadcast, r1).rotate(null, r0, -IDX0_K, cond='pushz')
                for i in range(16):
                    fmul(r3, r5, r2)
                    fadd(rf[i], rf[i], r3).rotate(broadcast, r1, -i - 1)

                # Loop while K != 0; K -= 1 in the branch delay slots.
                lk.b(cond='na0')
                eidx(r1)
                sub(null, r1, IDX0_K, cond='pushz')
                sub(r0, r0, 1, cond='ifa')

            # *(C + 4 * eidx + 4 * LDC * num_qpus * l) for l = 0, 1, ..., 15
            nop()
            eidx(r1).rotate(broadcast, r0, -IDX0_C)
            shl(r1, r1, ilog2(4))
            add(r1, r5, r1)

            eidx(r2).rotate(broadcast, r0, -IDX0_J)
            # For J = rest_n - 16 - 1:
            # r5 = -1 - J = 16 - rest_n ≡ -rest_n (mod 16)
            # r3 = J + 1 = rest_n - 16
            sub(r5, -1, r5).sub(r3, r5, -1)
            # If eidx + rest_n - 16 < 0 ...
            # ... the flagged lanes of the rotated address vector r1 are
            # overwritten with the broadcast (lane 0) address, and the
            # rotation amount is saved in lane IDX0_ROTATE_N of r0.
            add(null, r2, r3, cond='pushn').mov(r3, r5)
            mov(broadcast, r1).rotate(r1, r1, r5)
            mov(r1, r5, cond='ifa').sub(null, r2, IDX0_ROTATE_N, cond='pushz')
            mov(r0, r3, cond='ifa')

            # Write back the 16 accumulated rows:
            # C_row = rotate(ALPHA * rf[i]) + BETA * C_row, then clear rf[i].
            for i in range(16):
                mov(tmua, r1).rotate(broadcast, r0, -IDX0_ALPHA)
                fmul(r2, rf[i], r5)
                rotate(broadcast, r0, -IDX0_ROTATE_N)
                nop()
                rotate(rf[i], r2, r5)
                nop(sig=ldtmu(r2))
                rotate(broadcast, r0, -IDX0_BETA)
                # If rest_m <= cur_m, then skip the remaining rows of C.
                sub(r4, r4, num_qpus, cond='pushn').fmul(r2, r2, r5)
                b(R.exit_c, cond='a0')
                # Delay slots: the store of the current row still completes
                # and r1 advances by LDC.
                fadd(tmud, rf[i], r2).mov(rf[i], .0)
                mov(tmua, r1).rotate(broadcast, r0, -IDX0_LDC)
                add(r1, r1, r5)

            L.exit_c

            # B += 4 * 16
            # B += 4 * LDB * 16
            # C += 4 * 16
            # NOTE(review): only two updates follow - C += 8 * 8 (= 4 * 16)
            # and B += LDB << 4; the first "B += 4 * 16" line above has no
            # matching instruction.
            eidx(r1).umul24(r2, 8, 8)
            sub(null, r1, IDX0_C, cond='pushz')
            add(r0, r0, r2, cond='ifa').rotate(broadcast, r0, -IDX0_LDB)
            shl(r5, r5, ilog2(16))
            sub(null, r1, IDX0_B, cond='pushz')
            add(r0, r0, r5, cond='ifa').rotate(null, r0, -IDX0_J, cond='pushn')
            # Loop while J has not gone negative; J -= 16 in the delay slots.
            ln.b(cond='na0')
            eidx(r1)
            sub(null, r1, IDX0_J, cond='pushz')
            add(r0, r0, -16, cond='ifa')

        # A += 4 * 16 * num_qpus
        # B -= 4 * LDB * 16 * ⌈ R / 16 ⌉
        # C += 4 * LDC * 16 * num_qpus - 4 * 16 * ⌈ R / 16 ⌉
        eidx(r1).mov(r5, 1)
        sub(null, r1, IDX0_A, cond='pushz')
        shl(r5, r5, ilog2(4) + ilog2(16) + ilog2(num_qpus))
        add(r0, r0, r5, cond='ifa')
        nop()

        rotate(broadcast, r0, -IDX0_LDC)
        sub(null, r1, IDX0_C, cond='pushz')
        shl(r5, r5, ilog2(16))
        add(r0, r0, r5, cond='ifa')
        nop()

        # Plain broadcast of lane 0 again, hence the repeated invariant.
        assert IDX0_R == 0
        mov(broadcast, r0)
        add(r5, r5, 15)
        shr(r5, r5, ilog2(16))
        shl(r5, r5, ilog2(4) + ilog2(16))
        sub(r0, r0, r5, cond='ifa').sub(null, r1, IDX0_B, cond='pushz')
        shr(r1, r5, ilog2(4))
        rotate(broadcast, r0, -IDX0_LDB)
        umul24(r5, r1, r5)
        sub(r0, r0, r5, cond='ifa')

        # Loop while I has not gone negative; I -= 16 * num_qpus in the
        # branch delay slots.
        mov(r2, 1)
        eidx(r1).rotate(null, r0, -IDX0_I, cond='pushn')
        lm.b(cond='na0')
        sub(null, r1, IDX0_I, cond='pushz')
        shl(r2, r2, ilog2(16) + ilog2(num_qpus))
        sub(r0, r0, r2, cond='ifa')

    L.exit_m

    # Barrier followed by the closing thrsw/nop sequence (presumably the
    # standard program epilogue of this DSL).
    barrierid(syncb, sig=thrsw)
    nop()
    nop()

    nop(sig=thrsw)
    nop(sig=thrsw)
    nop()
    nop()
    nop(sig=thrsw)
    nop()
    nop()
    nop()
# Example #5
# 0
def qpu_fft4(asm, *, num_qpus, is_forward):
    """Emit the QPU program for an FFT built from four-point butterflies.

    The generated program reads four uniforms up front, in this order: n,
    x_orig, y_orig, buf.  Further uniforms (omega_r / omega_c pairs) are
    fetched inside the innermost loop.  Only the thread whose
    ``(tidx >> 2) & 0b1111`` is zero does the work; every other thread
    branches straight to the exit label.  The stage loop starts with
    ``j = 1, k = n / 4`` and multiplies ``j`` / divides ``k`` by 4 per stage
    until ``k`` reaches zero.  The source and destination pointers are
    exchanged after every stage, with ``buf`` brought in as the destination
    after the first stage.

    Args:
        asm: Instruction list being assembled.  Not referenced directly in
            this body; the instruction helpers are defined outside this
            block and presumably append to it.
        num_qpus: Must be 8.
        is_forward: Selects the operand order of the two subtractions that
            form the (c1 - c3) term, i.e. the sign of that term.

    Raises:
        Exception: If ``num_qpus`` is not 8 (raised after the preamble
            instructions have already been emitted).

    Side effects:
        Binds the module globals ``reg_j`` ... ``reg_c3c`` to
        ``rf[0]`` ... ``rf[19]``.
    """

    # Give the register-file slots symbolic names (reg_j = rf[0], ...).
    g = globals()
    for i, name in enumerate([
            'j', 'k', 'l', 'm', 'n', 'x_orig', 'x', 'y_orig', 'y', 'buf',
            'omega_r', 'omega_c', 'c0r', 'c0c', 'c1r', 'c1c', 'c2r', 'c2c',
            'c3r', 'c3c'
    ]):
        g[f'reg_{name}'] = rf[i]

    # Fetch the kernel arguments from the uniform stream.
    nop(sig=ldunifrf(reg_n))
    nop(sig=ldunifrf(reg_x_orig))
    nop(sig=ldunifrf(reg_y_orig))
    nop(sig=ldunifrf(reg_buf))

    nop(sig=thrsw)
    nop()
    nop()

    # Zero flag is pushed when (tidx >> 2) & 0b1111 == 0; all other threads
    # take the branch to the exit label (three nops fill the delay slots).
    tidx(r0)
    shr(r0, r0, 2)
    band(null, r0, 0b1111, cond='pushz')
    if num_qpus == 8:
        b(R.exit, cond='na0')
        nop()
        nop()
        nop()
    else:
        raise Exception('num_qpus must be 8')

    # Branch to the very next label while redirecting the uniform stream.
    # NOTE(review): no register is passed to unif_addr here, so the new
    # (absolute) address presumably comes from the uniform stream itself -
    # confirm against the DSL documentation.
    b(R.set_unif, cond='always').unif_addr(absolute=True)
    nop()
    nop()
    nop()
    L.set_unif

    # k = n / 4, j = 1
    shr(reg_k, reg_n, ilog2(4))
    mov(reg_j, 1)
    # Stage loop.
    with loop as ljk:

        # l runs from 0 to k - 1.
        mov(reg_l, 0)
        with loop as ll:

            # m = j - 1
            # x = x_orig + (j * l + eidx) * 8
            # y = y_orig + (4 * j * l + eidx) * 8
            sub(reg_m, reg_j, 1)
            eidx(r0).umul24(r1, reg_j, reg_l)
            add(r1, r0, r1).umul24(r2, r1, 4)
            shl(r1, r1, ilog2(8)).add(r2, r2, r0)
            shl(r2, r2, ilog2(8)).add(reg_x, reg_x_orig, r1)
            add(reg_y, reg_y_orig, r2)

            # Innermost loop: 16 elements per iteration, m -= 16 each time.
            with loop as lm:

                # Eight loads: both words of the inputs at x, x + 2 n,
                # x + 4 n and x + 6 n (the running offsets are tracked in
                # the comments below).
                # r0 = n / 4 * 8 = 2 n, r1 = 2 n + 4
                mov(tmua, reg_x).add(r0, reg_n, reg_n)
                add(tmua, reg_x, 4).add(r1, r0, 4)

                # r5 = 4 n, r1 = 4 n + 8
                add(tmua, reg_x, r0).add(r5, r0, r0)
                add(tmua, reg_x, r1).add(r1, r1, r1)

                # r0 = r0 + r5 = 2 n + 4 n = 6 n, r1 = r1 - 4 = 4 n + 4
                add(tmua, reg_x, r5).sub(r1, r1, 4)
                add(tmua, reg_x, r1).add(r0, r0, r5)

                # r0 = r0 + 4 = 6 n + 4
                add(tmua, reg_x, r0).add(r0, r0, 4)
                add(tmua, reg_x, r0)

                # eidx < 16 - (m + 1) ∴ m + eidx - 15 < 0
                # r5 = -1 - (rest - 1) = -rest
                eidx(r0).add(r1, reg_m, -15)
                add(null, r0, r1, cond='pushn').sub(r5, -1, reg_m)
                # y is rotated by r5; the flagged lanes are then restored to
                # the unrotated addresses saved in r0.
                mov(r0, reg_y)
                rotate(reg_y, r0, r5)
                mov(reg_y, r0, cond='ifa')

                # Receive the eight words in request order: c0 = x,
                # c1 = x + 2 n, c2 = x + 4 n, c3 = x + 6 n (r = first word,
                # c = second word).
                nop(sig=ldtmu(reg_c0r))
                nop(sig=ldtmu(reg_c0c))
                nop(sig=ldtmu(reg_c1r))
                nop(sig=ldtmu(reg_c1c))
                nop(sig=ldtmu(reg_c2r))
                nop(sig=ldtmu(reg_c2c))
                nop(sig=ldtmu(reg_c3r))
                nop(sig=ldtmu(reg_c3c))

                # d0, d1, d2, d3 are stored to c0, c2, c1, c3, resp.
                # d0 = c0 + c2, d1 = c0 - c2
                fadd(r0, reg_c0r, reg_c2r)
                fadd(r1, reg_c0c, reg_c2c)
                fsub(reg_c2r, reg_c0r, reg_c2r).mov(reg_c0r, r0)
                fsub(reg_c2c, reg_c0c, reg_c2c).mov(reg_c0c, r1)
                # d2 = c1 + c3
                fadd(r0, reg_c1r, reg_c3r)
                fadd(r1, reg_c1c, reg_c3c)
                # d3 is built from (c1 - c3) with its two words exchanged;
                # is_forward picks which of the two subtractions is negated.
                if is_forward:
                    fsub(r1, reg_c1c, reg_c3c).mov(reg_c1c, r1)
                    fsub(r0, reg_c3r, reg_c1r).mov(reg_c1r, r0)
                else:
                    fsub(r1, reg_c3c, reg_c1c).mov(reg_c1c, r1)
                    fsub(r0, reg_c1r, reg_c3r).mov(reg_c1r, r0)
                mov(reg_c3r, r1)
                mov(reg_c3c, r0)

                # Python-level aliases only; no instructions are emitted.
                reg_d0r = reg_c0r
                reg_d0c = reg_c0c
                reg_d1r = reg_c2r
                reg_d1c = reg_c2c
                reg_d2r = reg_c1r
                reg_d2c = reg_c1c
                reg_d3r = reg_c3r
                reg_d3c = reg_c3c

                # Output 0: d0 + d2, stored without a twiddle multiply;
                # then y += j * 8.
                fadd(r0, reg_d0r, reg_d2r)
                rotate(tmud, r0, r5)
                fadd(r0, reg_d0c, reg_d2c)
                mov(tmua, reg_y)
                rotate(tmud, r0, r5)
                add(tmua, reg_y, 4)
                shl(r0, reg_j, ilog2(8))
                add(reg_y, reg_y, r0)

                # Output 1: (d1 + d3) times the next (omega_r, omega_c)
                # from the uniform stream, as a complex product;
                # then y += j * 8.
                fadd(r0, reg_d1r, reg_d3r, sig=ldunifrf(reg_omega_r))
                fadd(r1, reg_d1c, reg_d3c, sig=ldunifrf(reg_omega_c))
                fmul(r2, r0, reg_omega_r)
                fmul(r3, r1, reg_omega_c)
                fsub(r2, r2, r3)
                rotate(tmud, r2, r5)
                mov(tmua, reg_y).fmul(r2, r0, reg_omega_c)
                fmul(r3, r1, reg_omega_r)
                fadd(r2, r2, r3)
                rotate(tmud, r2, r5)
                add(tmua, reg_y, 4)
                shl(r0, reg_j, ilog2(8))
                add(reg_y, reg_y, r0)

                # Output 2: (d0 - d2) times the next twiddle pair;
                # then y += j * 8.
                fsub(r0, reg_d0r, reg_d2r, sig=ldunifrf(reg_omega_r))
                fsub(r1, reg_d0c, reg_d2c, sig=ldunifrf(reg_omega_c))
                fmul(r2, r0, reg_omega_r)
                fmul(r3, r1, reg_omega_c)
                fsub(r2, r2, r3)
                rotate(tmud, r2, r5)
                mov(tmua, reg_y).fmul(r2, r0, reg_omega_c)
                fmul(r3, r1, reg_omega_r)
                fadd(r2, r2, r3)
                rotate(tmud, r2, r5)
                add(tmua, reg_y, 4)
                shl(r0, reg_j, ilog2(8))
                add(reg_y, reg_y, r0)

                # Output 3: (d1 - d3) times the next twiddle pair.  m -= 16
                # is folded in; its negative flag decides the loop-back
                # branch.
                fsub(r0, reg_d1r, reg_d3r, sig=ldunifrf(reg_omega_r))
                fsub(r1, reg_d1c, reg_d3c, sig=ldunifrf(reg_omega_c))
                fmul(r2, r0, reg_omega_r)
                fmul(r3, r1, reg_omega_c)
                fsub(r2, r2, r3)
                rotate(tmud, r2, r5)
                mov(tmua, reg_y).fmul(r2, r0, reg_omega_c)
                fmul(r3, r1, reg_omega_r)
                fadd(r2, r2, r3).add(reg_m, reg_m, -16, cond='pushn')
                rotate(tmud, r2, r5)
                add(tmua, reg_y, 4)
                # r0 = j * 24: the total of the three "y += j * 8" steps
                # above, undone in the first delay slot below.
                add(r0, 12, 12)  # 8 * 3
                mov(r1, 1).umul24(r0, reg_j, r0)

                # Loop while m has not gone negative.  The branch also
                # adjusts the uniform address.
                # NOTE(review): the relative offset presumably comes from
                # the uniform stream itself - confirm.
                # Delay slots: y -= j * 24, then x and y advance by 8 * 16
                # bytes.
                lm.b(cond='na0').unif_addr(absolute=False)
                sub(reg_y, reg_y, r0)
                shl(r0, r1, ilog2(8 * 16))
                add(reg_x, reg_x, r0).add(reg_y, reg_y, r0)

            # l += 1; loop while l - k < 0.
            add(reg_l, reg_l, 1)
            sub(null, reg_l, reg_k, cond='pushn')
            ll.b(cond='a0')
            nop()
            nop()
            nop()

        # k >>= 2; loop while k != 0.  Delay slots: j <<= 2 and the buffer
        # exchange.  The negative flag of (old j - 2) marks the first stage,
        # after which x_orig = old y_orig and y_orig = buf; on later stages
        # x_orig and y_orig are swapped.
        shr(reg_k, reg_k, ilog2(4), cond='pushz')
        ljk.b(cond='na0')
        shl(reg_j, reg_j, ilog2(4)).sub(null, reg_j, 2, cond='pushn')
        mov(reg_y_orig, reg_buf, cond='ifa').mov(r0, reg_y_orig)
        mov(reg_y_orig, reg_x_orig, cond='ifna').mov(reg_x_orig, r0)

    L.exit

    # Barrier followed by the closing thrsw/nop sequence (presumably the
    # standard program epilogue of this DSL).
    barrierid(syncb, sig=thrsw)
    nop()
    nop()

    nop(sig=thrsw)
    nop(sig=thrsw)
    nop()
    nop()
    nop(sig=thrsw)
    nop()
    nop()
    nop()