def qpu_somatcopy_n(asm, *, num_qpus, unroll, code_offset, align_cond=lambda pos: True):
    """Emit VideoCore VI QPU code for a single-precision out-of-place scaled
    matrix copy (somatcopy; the ``_n`` suffix presumably denotes the
    non-transposed variant -- confirm against the caller).

    Every element loaded from ``a`` through the TMU is multiplied by ``alpha``
    and stored to ``b``.  Rows are distributed round-robin over the QPUs and
    the 16 SIMD lanes cover 16 consecutive columns per TMU access.

    Args:
        asm: the instruction list being assembled; only ``len(asm)`` is read
            here (for the alignment padding before the main loop).
        num_qpus: number of QPUs the kernel runs on; must be 1 or 8.
        unroll: column-loop unroll factor; must be 1, 2, 4 or 8.
        code_offset: position of this code in the final program, combined
            with ``align_cond`` to pad the main loop with ``nop``s.
        align_cond: predicate on the instruction position; ``nop``s are
            emitted until it holds.

    NOTE(review): register aliases ``reg_*`` are written into ``globals()``
    as a side effect and leak to module scope.
    """
    assert unroll in [1, 2, 4, 8]

    # Alias rf0..rf12 with readable names (module-global side effect).
    g = globals()
    for i, v in enumerate([
            'rows', 'cols', 'alpha', 'a', 'lda', 'b', 'ldb', 'qpu_num',
            'a_i', 'b_i', 'inc', 'rest_i', 'rest_j']):
        g[f'reg_{v}'] = rf[i]

    # Fetch the seven kernel parameters from the uniform stream.
    nop(sig=ldunifrf(reg_rows))
    nop(sig=ldunifrf(reg_cols))
    nop(sig=ldunifrf(reg_alpha))
    nop(sig=ldunifrf(reg_a))
    nop(sig=ldunifrf(reg_lda))
    nop(sig=ldunifrf(reg_b))
    nop(sig=ldunifrf(reg_ldb))

    # qpu_num = (tidx >> 2) & 0b1111 when running on 8 QPUs.
    if num_qpus == 1:
        mov(reg_qpu_num, 0)
    elif num_qpus == 8:
        tidx(r0)
        shr(r0, r0, 2)
        band(reg_qpu_num, r0, 0b1111)
    else:
        raise Exception('num_qpus must be 1 or 8')

    # a += 4 * lda * qpu_num + 4 * eidx
    # b += 4 * ldb * qpu_num + 4 * eidx
    eidx(r1).umul24(r0, reg_lda, reg_qpu_num)
    shl(r1, r1, ilog2(4)).add(reg_a, reg_a, r0)
    add(reg_a, reg_a, r1).add(reg_a_i, reg_a, r1)
    umul24(r0, reg_ldb, reg_qpu_num)
    add(reg_b, reg_b, r0)
    add(reg_b, reg_b, r1).add(reg_b_i, reg_b, r1)

    # lda *= num_qpus
    # ldb *= num_qpus
    umul24(reg_lda, reg_lda, num_qpus)
    umul24(reg_ldb, reg_ldb, num_qpus)

    # inc = 4 << 4 = 64 bytes (16 lanes x 4 bytes per TMU access).
    shl(reg_inc, 4, 4)

    # rest_i = rows - qpu_num - 1
    sub(r0, reg_rows, reg_qpu_num)
    sub(reg_rest_i, r0, 1)

    nop(sig=thrsw)
    nop()
    nop()

    # Pad with nops so that the loop head lands on an aligned position.
    while not align_cond(code_offset + len(asm)):
        nop()

    with loop as li:  # row loop

        # r0 = number of full (16 * unroll)-column groups; skip to the
        # remainder handling if there are none.
        shr(r0, reg_cols, ilog2(16) + ilog2(unroll), cond='pushz')
        b(R.rest_j, cond='a0')
        nop()
        nop()
        sub(reg_rest_j, r0, 2, cond='pushn')

        # Prime extra TMU requests ahead of the pipelined column loop.
        for i in range(unroll - 3):
            mov(tmua, reg_a).add(reg_a, reg_a, reg_inc)

        b(R.skip_j_unroll, cond='a0')
        # The next three instructions execute in the branch delay slots;
        # the conditional expressions keep the slot count fixed per unroll.
        mov(tmua, reg_a).add(reg_a, reg_a, reg_inc)
        mov(tmua, reg_a).add(reg_a, reg_a, reg_inc) if unroll >= 2 else nop()
        mov(tmua, reg_a).add(reg_a, reg_a, reg_inc) if unroll >= 3 else nop()

        with loop as lj:  # unrolled column loop: scale and store

            sub(reg_rest_j, reg_rest_j, 1, cond='pushn')
            for i in range(unroll - 1):
                nop(sig=ldtmu(r0))
                fmul(tmud, r0, reg_alpha)
                mov(tmua, reg_b).add(reg_b, reg_b, reg_inc)
                mov(tmua, reg_a).add(reg_a, reg_a, reg_inc)
            nop(sig=ldtmu(r0))
            lj.b(cond='na0')
            # delay slots: final scale/store plus next prefetch
            fmul(tmud, r0, reg_alpha)
            mov(tmua, reg_b).add(reg_b, reg_b, reg_inc)
            mov(tmua, reg_a).add(reg_a, reg_a, reg_inc)

        L.skip_j_unroll

        # Drain the outstanding TMU reads issued ahead of the loop.
        for i in range(unroll):
            nop(sig=ldtmu(r0))
            fmul(tmud, r0, reg_alpha)
            mov(tmua, reg_b).add(reg_b, reg_b, reg_inc)

        L.rest_j

        # r0 = cols mod (16 * unroll); nothing left -> exit the column work.
        mov(r0, 1)
        shl(r0, r0, ilog2(16) + ilog2(unroll))
        sub(r0, r0, 1)
        band(r0, reg_cols, r0, cond='pushz')
        b(R.exit_j, cond='a0')
        sub(r0, r0, 1)
        add(r0, r0, -16, cond='pushn').mov(r1, r0)
        nop()

        with loop as lj0:  # issue TMU reads for the remaining full groups of 16

            lj0.b(cond='na0')
            mov(tmua, reg_a).add(reg_a, reg_a, reg_inc)
            add(r0, r0, -16, cond='pushn')
            nop()

        with loop as lj1:  # scale/store remainder, masking the final partial group

            # r5 = -1 - (rest - 1) = -rest (rotate amount for the tail lanes)
            sub(r5, -1, r1)
            nop(sig=ldtmu(r0))
            mov(r3, reg_b)
            rotate(r0, r0, r5)
            rotate(r3, r3, r5)
            fmul(tmud, r0, reg_alpha)
            mov(broadcast, reg_b).add(reg_b, reg_b, reg_inc)
            # Lanes past the remainder get a redirected store address so the
            # partial group does not overwrite past the row end.
            eidx(r0).add(r2, r1, -16)
            add(null, r0, r2, cond='pushn')
            mov(r3, r5, cond='ifa')
            add(r1, r1, -16, cond='pushn')
            lj1.b(cond='na0')
            mov(tmua, r3)
            nop()
            nop()

        L.exit_j

        # Advance to this QPU's next row; loop while rest_i stays non-negative.
        sub(reg_rest_i, reg_rest_i, num_qpus, cond='pushn')
        li.b(cond='na0')
        add(reg_a_i, reg_a_i, reg_lda)
        add(reg_b_i, reg_b_i, reg_ldb)
        mov(reg_a, reg_a_i).mov(reg_b, reg_b_i)

    # Synchronize all QPUs, then signal program end (thrsw sequence).
    barrierid(syncb, sig=thrsw)
    nop()
    nop()

    nop(sig=thrsw)
    nop(sig=thrsw)
    nop()
    nop()
    nop(sig=thrsw)
    nop()
    nop()
    nop()
def qpu_comatcopy_t(asm, *, num_qpus, tile_rows, tile_cols, subtile_rows, subtile_cols, code_offset, align_cond=lambda pos: True):
    """Emit VideoCore VI QPU code for a single-precision complex scaled
    matrix copy with transposition (comatcopy, ``_t`` variant).

    Elements are 8-byte (real, imag) pairs; each is multiplied by the complex
    scalar (ALPHA_R, ALPHA_I) -- the fmul/fsub/fadd pattern below is a complex
    multiply -- and written to ``B`` with the row/column roles swapped (see
    the transposed index ``j`` in the store loop).

    Per-lane state is packed into r4: lane ``IDX4_x`` of r4 holds scalar
    ``x``; ``rotate(broadcast, r4, -IDX4_x)`` broadcasts it to r5.

    Args:
        asm: instruction list being assembled (only ``len(asm)`` is read).
        num_qpus: 1 or 8.
        tile_rows, tile_cols: register-blocking tile shape;
            ``tile_rows * tile_cols`` values are held in rf0.. (pairs), so
            the product must be in [1, 32].
        subtile_rows, subtile_cols: how the 16 SIMD lanes are arranged inside
            a subtile; their product must be exactly 16.
        code_offset, align_cond: nop-padding of the loop head, as elsewhere.
    """
    tile_size = tile_rows * tile_cols
    assert 1 <= tile_size <= 32
    assert subtile_rows * subtile_cols == 16

    # Lane indices of the per-lane scalar slots packed into r4.
    IDX4_ROWS, IDX4_COLS, IDX4_ALPHA_R, IDX4_ALPHA_I, IDX4_A_I, IDX4_A_J, \
        IDX4_B_I, IDX4_B_J, IDX4_LDA, IDX4_LDB, IDX4_I, IDX4_J = range(12)

    # r0 = eidx
    # r1 = qpu_num
    eidx(r0).mov(r4, 0)
    if num_qpus == 1:
        mov(r1, 0)
    elif num_qpus == 8:
        tidx(r1)
        shr(r1, r1, 2)
        band(r1, r1, 0b1111)
    else:
        raise Exception('num_qpus must be 1 or 8')

    # Scatter each uniform into its dedicated lane of r4.
    for idx in [IDX4_ROWS, IDX4_COLS, IDX4_ALPHA_R, IDX4_ALPHA_I,
                IDX4_A_I, IDX4_LDA, IDX4_B_I, IDX4_LDB]:
        nop(sig=ldunif)
        sub(null, r0, idx, cond='pushz')
        mov(r4, r5, cond='ifa')

    # I = ROWS / subtile_rows / tile_rows - qpu_num - num_qpus - 1
    sub(null, r0, IDX4_I, cond='pushz')
    rotate(broadcast, r4, -IDX4_ROWS)
    shr(r5, r5, ilog2(tile_rows * subtile_rows))
    sub(r5, r5, r1)
    sub(r4, r5, num_qpus + 1, cond='ifa')

    # A_I += LDA * qpu_num * tile_rows * subtile_rows * 8
    # B_I += qpu_num * tile_rows * subtile_rows * 8
    shl(r1, r1, ilog2(tile_rows * subtile_rows))
    rotate(broadcast, r4, -IDX4_LDA)
    sub(null, r0, IDX4_A_I, cond='pushz').umul24(r5, r1, r5)
    add(r4, r4, r5, cond='ifa').umul24(r5, r1, 8)
    sub(null, r0, IDX4_B_I, cond='pushz')
    add(r4, r4, r5, cond='ifa')

    nop(sig=thrsw)
    nop()
    nop()

    # Pad with nops so that the loop head lands on an aligned position.
    while not align_cond(code_offset + len(asm)):
        nop()

    with loop as li:  # outer row-block loop

        # J = COLS / subtile_cols / tile_cols - 1
        eidx(r0).rotate(broadcast, r4, -IDX4_COLS)
        sub(null, r0, IDX4_J, cond='pushz')
        shr(r5, r5, ilog2(tile_cols * subtile_cols))
        sub(r4, r5, 1, cond='ifa')

        # A_J = A_I
        # B_J = B_I
        sub(null, r0, IDX4_A_J, cond='pushz')
        rotate(broadcast, r4, -IDX4_A_I)
        mov(r4, r5, cond='ifa')
        sub(null, r0, IDX4_B_J, cond='pushz')
        rotate(broadcast, r4, -IDX4_B_I)
        mov(r4, r5, cond='ifa')

        with loop as lj:  # column-block loop

            # ---- Load phase: read a whole tile of complex values into rf ----
            # r0 = A_J + eidx / subtile_cols * LDA * 4 + eidx % subtile_cols * 8
            # r1 = subtile_cols * 8
            eidx(r1).rotate(broadcast, r4, -IDX4_A_J)
            band(r0, r1, subtile_cols - 1)
            shl(r0, r0, ilog2(8))
            add(r0, r0, r5).rotate(broadcast, r4, -IDX4_LDA)
            shr(r1, r1, ilog2(subtile_cols))
            mov(r5, 1).umul24(r1, r1, r5)
            shl(r1, r5, ilog2(subtile_cols * 8)).add(r0, r0, r1)

            # Software-pipeline the TMU: keep up to tmu_unroll subtile
            # requests in flight while draining results into rf pairs.
            tmu_unroll = min(tile_size, 4)
            for i in range(tile_size + tmu_unroll):
                if i > 0 and i % tile_cols == 0:
                    # r0 = -tile_cols * subtile_cols * 8 + LDA * subtile_rows * 8
                    shl(r5, r1, ilog2(tile_cols))
                    sub(r0, r0, r5).rotate(broadcast, r4, -IDX4_LDA)
                    shl(r5, r5, ilog2(subtile_rows))
                    add(r0, r0, r5)
                if i < tmu_unroll:
                    # Prologue: issue-only iterations.
                    mov(tmua, r0).mov(r5, 4)
                    add(tmua, r0, r5).add(r0, r0, r1)
                if tmu_unroll <= i:
                    # Receive (real, imag) for subtile i - tmu_unroll.
                    j = i - tmu_unroll
                    nop(sig=ldtmu(rf[0 + j * 2]))
                    nop(sig=ldtmu(rf[1 + j * 2]))
                if tmu_unroll <= i < tile_size:
                    # Steady state: issue the next subtile while receiving.
                    mov(tmua, r0).mov(r5, 4)
                    add(tmua, r0, r5).add(r0, r0, r1)

            # ---- Store phase: complex-scale and write transposed ----
            # r0 = B_J + eidx % subtile_cols * LDB * 8 + eidx / subtile_cols * 8
            # r1 = subtile_rows * 8
            # r3 = ALPHA_R
            # r5 = ALPHA_I
            eidx(r1).rotate(broadcast, r4, -IDX4_B_J)
            shr(r0, r1, ilog2(subtile_cols))
            shl(r0, r0, ilog2(8))
            add(r0, r0, r5).rotate(broadcast, r4, -IDX4_LDB)
            band(r1, r1, subtile_cols - 1)
            mov(r5, 1).umul24(r1, r1, r5)
            shl(r1, r5, ilog2(subtile_rows * 8)).add(r0, r0, r1)
            rotate(broadcast, r4, -IDX4_ALPHA_R)
            mov(r3, r5).rotate(broadcast, r4, -IDX4_ALPHA_I)

            for i in range(tile_size):
                if i > 0 and i % tile_rows == 0:
                    # r0 = -tile_rows * subtile_rows * 8 + LDB * subtile_cols * 8
                    shl(r5, r1, ilog2(tile_rows))
                    sub(r0, r0, r5).rotate(broadcast, r4, -IDX4_LDB)
                    shl(r5, r5, ilog2(subtile_cols))
                    add(r0, r0, r5).rotate(broadcast, r4, -IDX4_ALPHA_R)
                    mov(r3, r5).rotate(broadcast, r4, -IDX4_ALPHA_I)
                # Transposed source index within the tile.
                j = i % tile_rows * tile_cols + i // tile_rows
                # Complex multiply: (re, im) * (ALPHA_R + i*ALPHA_I).
                fmul(r1, rf[0 + j * 2], r3)
                fmul(r2, rf[1 + j * 2], r5)
                fsub(tmud, r1, r2).fmul(r1, rf[0 + j * 2], r5)
                mov(tmua, r0).fmul(r2, rf[1 + j * 2], r3)
                fadd(tmud, r1, r2).mov(r1, 1)
                shl(r1, r1, ilog2(subtile_rows * 8))
                add(tmua, r0, 4).add(r0, r0, r1)

            # A_J += tile_cols * subtile_cols * 8
            # B_J += LDB * tile_cols * subtile_cols * 8
            eidx(r0).mov(r5, 1)
            sub(null, r0, IDX4_A_J, cond='pushz')
            shl(r5, r5, ilog2(tile_cols * subtile_cols * 8))
            add(r4, r4, r5, cond='ifa').rotate(broadcast, r4, -IDX4_LDB)
            sub(null, r0, IDX4_B_J, cond='pushz')
            shl(r5, r5, ilog2(tile_cols * subtile_cols))
            add(r4, r4, r5, cond='ifa')

            # Loop while J (lane IDX4_J of r4) is still non-zero.
            rotate(null, r4, -IDX4_J, cond='pushz')
            lj.b(cond='na0')
            sub(null, r0, IDX4_J, cond='pushz')
            sub(r4, r4, 1, cond='ifa')
            nop()

        # A_I += LDA * num_qpus * tile_rows * subtile_rows * 8
        # B_I += num_qpus * tile_rows * subtile_rows * 8
        eidx(r0).rotate(broadcast, r4, -IDX4_LDA)
        mov(r1, 1)
        shl(r1, r1, ilog2(num_qpus * tile_rows * subtile_rows))
        sub(null, r0, IDX4_A_I, cond='pushz').umul24(r5, r1, r5)
        add(r4, r4, r5, cond='ifa').sub(null, r0, IDX4_B_I, cond='pushz')
        shl(r5, r1, ilog2(8))
        add(r4, r4, r5, cond='ifa')

        # Loop while I (lane IDX4_I of r4) stays non-negative.
        rotate(null, r4, -IDX4_I, cond='pushn')
        li.b(cond='na0')
        sub(null, r0, IDX4_I, cond='pushz')
        sub(r4, r4, num_qpus, cond='ifa')
        nop()

    # Synchronize all QPUs, then signal program end (thrsw sequence).
    barrierid(syncb, sig=thrsw)
    nop()
    nop()

    nop(sig=thrsw)
    nop(sig=thrsw)
    nop()
    nop()
    nop(sig=thrsw)
    nop()
    nop()
    nop()
def qpu_fft2(asm, *, num_qpus, do_unroll):
    """Emit VideoCore VI QPU code for radix-2 FFT passes over complex data.

    The butterfly is visible in the fadd/fsub pairs below; the second output
    is multiplied by the twiddle factor (omega_r, omega_c) streamed in via
    the uniform reader (the uniform pointer is switched to ``omega_addr``).
    Buffers ``x``/``y``/``buf`` are ping-ponged between passes.

    Args:
        asm: instruction list being assembled (not inspected here).
        num_qpus: must be 8; QPUs other than number 0 branch straight to the
            exit, so presumably only one QPU does the work -- confirm.
        do_unroll: when true, an extra iteration of TMU requests is issued
            ahead of the loop and the surplus reads are drained afterwards.

    NOTE(review): register aliases ``reg_*`` are written into ``globals()``
    as a side effect and leak to module scope.
    """
    # Alias rf0.. with readable names (module-global side effect).
    g = globals()
    for i, name in enumerate(['j', 'k', 'l', 'm', 'n', 'x', 'x0', 'x1',
                              'y', 'y0', 'y1', 'buf', 'omega_addr',
                              'omega_r', 'omega_c']):
        g[f'reg_{name}'] = rf[i]

    # Kernel parameters from the uniform stream.
    nop(sig=ldunifrf(reg_n))
    nop(sig=ldunifrf(reg_x))
    nop(sig=ldunifrf(reg_y))
    nop(sig=ldunifrf(reg_buf))
    nop(sig=ldunifrf(reg_omega_addr))

    nop(sig=thrsw)
    nop()
    nop()

    # Only QPU number 0 proceeds; the others jump to the exit sequence.
    tidx(r0)
    shr(r0, r0, 2)
    band(null, r0, 0b1111, cond='pushz')
    if num_qpus == 8:
        b(R.exit, cond='na0')
        nop()
        nop()
        nop()
    else:
        raise Exception('num_qpus must be 8')

    # Redirect the uniform stream to the twiddle-factor table.
    b(R.set_unif, cond='always').unif_addr(reg_omega_addr)
    nop()
    nop()
    nop()
    L.set_unif

    # k = n / 2 groups, j = 1 butterflies per group for the first pass.
    shr(reg_k, reg_n, 1).mov(reg_j, 1)

    with loop as ljk:  # one iteration per FFT pass

        mov(reg_l, 0)

        with loop as ll:  # one iteration per butterfly group

            # m = j - 1
            # x0 = x + (j * l + eidx) * 8
            # x1 = x0 + (n / 2) * 8
            # y0 = y + (2 * j * l + eidx) * 8
            # y1 = y0 + j * 8
            sub(reg_m, reg_j, 1)
            eidx(r0).umul24(r1, reg_j, reg_l)
            add(r0, r0, r1)
            shl(r0, r0, ilog2(8)).add(r1, r0, r1)
            shl(r1, r1, ilog2(8)).add(reg_y0, reg_y, r1)
            shl(r0, reg_n, ilog2(8 // 2))
            add(reg_x1, reg_x0, r0, sig=ldunifrf(reg_omega_r))
            shl(r0, reg_j, ilog2(8))
            add(reg_y1, reg_y0, r0, sig=ldunifrf(reg_omega_c))

            if do_unroll:
                # Prime one extra round of TMU requests before the loop.
                mov(r0, 1).mov(tmua, reg_x0)
                shl(r0, r0, ilog2(8 * 16)).mov(tmua, reg_x1)
                add(reg_x0, reg_x0, r0).add(tmua, reg_x0, 4)
                add(reg_x1, reg_x1, r0).add(tmua, reg_x1, 4)

            with loop as lm:  # 16 butterflies per iteration

                mov(tmua, reg_x0)
                mov(tmua, reg_x1).mov(r2, 4)

                # eidx < 16 - (m + 1) ∴ m + eidx - 15 < 0
                # r5 = -1 - (rest - 1) = -rest
                eidx(r0).add(r1, reg_m, -15)
                add(null, r0, r1, cond='pushn').sub(r5, -1, reg_m)
                # Rotate the store addresses for the final partial group.
                mov(r0, reg_y0).mov(r1, reg_y1)
                add(tmua, reg_x0, r2).rotate(reg_y0, r0, r5)
                add(tmua, reg_x1, r2).rotate(reg_y1, r1, r5)
                mov(reg_y0, r0, cond='ifa').mov(reg_y1, r1, cond='ifa')

                # Radix-2 butterfly on the real parts, then imaginary parts,
                # with twiddle multiply on the difference output.
                nop(sig=ldtmu(r0))
                nop(sig=ldtmu(r1))
                fadd(r2, r0, r1)
                fsub(r0, r0, r1).rotate(tmud, r2, r5)
                mov(tmua, reg_y0, sig=ldtmu(r1))
                nop(sig=ldtmu(r2))
                fadd(r3, r1, r2)
                fsub(r1, r1, r2).rotate(r3, r3, r5)
                mov(tmud, r3).fmul(r2, r0, reg_omega_r)
                add(tmua, reg_y0, 4)
                fmul(r3, r1, reg_omega_c)
                fsub(r2, r2, r3).add(reg_m, reg_m, -16, cond='pushn')
                rotate(tmud, r2, r5)
                mov(tmua, reg_y1).fmul(r2, r0, reg_omega_c)
                fmul(r3, r1, reg_omega_r)
                fadd(r2, r2, r3).mov(r0, 1)
                rotate(tmud, r2, r5)
                add(tmua, reg_y1, 4)
                lm.b(cond='na0')
                # delay slots: advance all four pointers by 16 elements
                shl(r0, r0, ilog2(8 * 16))
                add(reg_x0, reg_x0, r0).add(reg_x1, reg_x1, r0)
                add(reg_y0, reg_y0, r0).add(reg_y1, reg_y1, r0)

            # Drain the surplus TMU reads issued by the unrolled prologue.
            nop(sig=ldtmu(null)) if do_unroll else nop()

            add(reg_l, reg_l, 1)
            sub(null, reg_l, reg_k, cond='pushn')
            ll.b(cond='a0')
            nop(sig=ldtmu(null)) if do_unroll else nop()
            nop(sig=ldtmu(null)) if do_unroll else nop()
            nop(sig=ldtmu(null)) if do_unroll else nop()

        # Next pass: halve the group count, double the butterfly span, and
        # swap the source/destination buffers (ping-pong with buf).
        shr(reg_k, reg_k, 1, cond='pushz')
        ljk.b(cond='na0')
        shl(reg_j, reg_j, 1).sub(null, reg_j, 1, cond='pushz')
        mov(reg_y, reg_buf, cond='ifa').mov(r0, reg_y)
        mov(reg_y, reg_x, cond='ifna').mov(reg_x, r0)

    L.exit

    # Synchronize all QPUs, then signal program end (thrsw sequence).
    barrierid(syncb, sig=thrsw)
    nop()
    nop()

    nop(sig=thrsw)
    nop(sig=thrsw)
    nop()
    nop()
    nop(sig=thrsw)
    nop()
    nop()
    nop()
def qpu_sgemm_rtt(asm, *, num_qpus, code_offset, align_cond=lambda pos: True):
    """Emit VideoCore VI QPU code for single-precision GEMM:

    # α ⋅ (P × Q) ⋅ (Q × R) + β ⋅ (P × R)
    # α ⋅ (m × k) ⋅ (k × n) + β ⋅ (m × n)
    # α ⋅ (i × k) ⋅ (k × j) + β ⋅ (i × j)

    (the ``_rtt`` suffix presumably encodes the storage orders of C/A/B --
    confirm against the caller).  A 16x16 accumulator block is kept in
    rf0..rf15; per-lane scalars (dimensions, pointers, strides, alpha, beta,
    loop counters) are packed into r0, one value per SIMD lane at the
    ``IDX0_*`` positions, and broadcast via ``rotate(broadcast, r0, -IDX0_x)``.

    Args:
        asm: instruction list being assembled (only ``len(asm)`` is read).
        num_qpus: 1 or 8; row blocks are distributed over the QPUs.
        code_offset, align_cond: nop-padding of the loop head, as elsewhere.
    """
    # Lane indices of the per-lane scalar slots packed into r0.
    IDX0_R, IDX0_Q, IDX0_A, IDX0_B, IDX0_C, IDX0_A_CUR, IDX0_B_CUR, \
        IDX0_LDA, IDX0_LDB, IDX0_LDC, IDX0_ALPHA, IDX0_BETA, \
        IDX0_I, IDX0_J, IDX0_K, IDX0_ROTATE_N = range(16)

    # r1 = qpu_num
    if num_qpus == 1:
        mov(r1, 0)
    elif num_qpus == 8:
        tidx(r1)
        shr(r1, r1, 2)
        band(r1, r1, 0b1111)
    else:
        raise Exception('num_qpus must be 1 or 8')

    # Scatter each uniform into its dedicated lane of r0.
    eidx(r2).mov(r0, 0)
    for idx in [IDX0_I, IDX0_R, IDX0_Q, IDX0_A, IDX0_B, IDX0_C,
                IDX0_LDA, IDX0_LDB, IDX0_LDC, IDX0_ALPHA, IDX0_BETA]:
        nop(sig=ldunifrf(r5))
        sub(null, r2, idx, cond='pushz')
        mov(r0, r5, cond='ifa')

    # Convert the strides from elements to bytes.
    # LDA *= 4
    sub(null, r2, IDX0_LDA, cond='pushz')
    shl(r0, r0, ilog2(4), cond='ifa')
    # LDB *= 4
    sub(null, r2, IDX0_LDB, cond='pushz')
    shl(r0, r0, ilog2(4), cond='ifa')
    # LDC *= 4
    sub(null, r2, IDX0_LDC, cond='pushz')
    shl(r0, r0, ilog2(4), cond='ifa')

    # A += 4 * qpu_num
    sub(null, r2, IDX0_A, cond='pushz')
    shl(r3, r1, ilog2(4))
    add(r0, r0, r3, cond='ifa')

    # C += 4 * LDC * qpu_num
    nop()
    rotate(broadcast, r0, -IDX0_LDC)
    sub(null, r2, IDX0_C, cond='pushz').umul24(r3, r5, r1)
    add(r0, r0, r3, cond='ifa')

    # LDC *= num_qpus
    sub(null, r2, IDX0_LDC, cond='pushz')
    shl(r0, r0, ilog2(num_qpus), cond='ifa')

    # Zero the 16x16 accumulator block held in rf0..rf15.
    for i in range(8):
        mov(rf[i], .0).mov(rf[i + 8], .0)

    nop(sig=thrsw)
    nop()
    nop()

    # I = P - 16 * num_qpus - 1
    eidx(r1).mov(r2, 1)
    sub(null, r1, IDX0_I, cond='pushz')
    shl(r2, r2, ilog2(16) + ilog2(num_qpus), cond='ifa')
    sub(r0, r0, r2, cond='ifa')
    sub(r0, r0, 1, cond='ifa')

    # Pad with nops so that the loop head lands on an aligned position.
    while not align_cond(code_offset + len(asm)):
        nop()

    with loop as lm:  # loop over 16-row blocks of C

        # J = R - 16 - 1
        assert IDX0_R == 0
        nop()
        eidx(r1).mov(broadcast, r0)
        sub(null, r1, IDX0_J, cond='pushz')
        add(r5, r5, -16)
        sub(r0, r5, 1, cond='ifa')

        with loop as ln:  # loop over 16-column blocks of C

            # For I = rest_m - 16 * num_qpus - 1 and cur_m = qpu_num:
            # r4 = rest_m - cur_m = I + 16 * num_qpus + 1 - qpu_num
            if num_qpus == 1:
                mov(r5, 0).mov(r4, 1)
            elif num_qpus == 8:
                tidx(r5).mov(r4, 1)
                shr(r5, r5, 2)
                band(r5, r5, 0b1111)
            shl(r4, r4, ilog2(16) + ilog2(num_qpus))
            sub(r4, r4, r5).rotate(broadcast, r0, -IDX0_I)
            # If rest_m <= cur_m ∴ we have nothing to do, then exit the m-loop.
            add(r4, r4, r5, cond='pushn')
            b(R.exit_m, cond='a0')
            # K = Q - 1 (the next three instructions are the delay slots)
            eidx(r1).rotate(broadcast, r0, -IDX0_Q)
            sub(null, r1, IDX0_K, cond='pushz')
            sub(r0, r5, 1, cond='ifa')

            # A_CUR = A
            # B_CUR = B
            sub(null, r1, IDX0_A_CUR, cond='pushz')
            rotate(broadcast, r0, -IDX0_A)
            mov(r0, r5, cond='ifa')
            sub(null, r1, IDX0_B_CUR, cond='pushz')
            rotate(broadcast, r0, -IDX0_B)
            mov(r0, r5, cond='ifa')

            with loop as lk:  # inner product over the K dimension

                # Load *(A_CUR + 4 * num_qpus * eidx)
                # and *(B_CUR + 4 * LDB * eidx)
                eidx(r1)
                shl(r2, r1, ilog2(4) + ilog2(num_qpus))
                rotate(broadcast, r0, -IDX0_A_CUR)
                add(tmua, r5, r2).rotate(broadcast, r0, -IDX0_LDB)
                umul24(r2, r5, r1)
                rotate(broadcast, r0, -IDX0_B_CUR)
                add(tmua, r5, r2)

                # A_CUR += 4 * LDA
                # B_CUR += 4
                eidx(r1).rotate(broadcast, r0, -IDX0_LDA)
                sub(null, r1, IDX0_A_CUR, cond='pushz')
                add(r0, r0, r5, cond='ifa') \
                    .sub(null, r1, IDX0_B_CUR, cond='pushz')
                add(r0, r0, 4, cond='ifa')

                nop(sig=ldtmu(r1))
                nop(sig=ldtmu(r2))

                # For J = rest_n - 16 - 1:
                # If rest_n - eidx < 1 ∴ J + 16 - eidx < 0, then zero-clear B.
                rotate(broadcast, r0, -IDX0_J)
                eidx(r3).sub(r5, r5, -16)
                sub(null, r5, r3, cond='pushn')
                mov(r2, .0, cond='ifa')

                # Rank-1 update: broadcast each lane of the A column (r1) in
                # turn and accumulate a_row * b_row into rf0..rf15.
                mov(broadcast, r1).rotate(null, r0, -IDX0_K, cond='pushz')
                for i in range(16):
                    fmul(r3, r5, r2)
                    fadd(rf[i], rf[i], r3).rotate(broadcast, r1, -i - 1)
                lk.b(cond='na0')
                # delay slots: K -= 1
                eidx(r1)
                sub(null, r1, IDX0_K, cond='pushz')
                sub(r0, r0, 1, cond='ifa')

            # *(C + 4 * eidx + 4 * LDC * num_qpus * l) for l = 0, 1, ..., 15
            nop()
            eidx(r1).rotate(broadcast, r0, -IDX0_C)
            shl(r1, r1, ilog2(4))
            add(r1, r5, r1)
            eidx(r2).rotate(broadcast, r0, -IDX0_J)
            # For J = rest_n - 16 - 1:
            # r5 = -1 - J = 16 - rest_n ≡ -rest_n (mod 16)
            # r3 = J + 1 = rest_n - 16
            sub(r5, -1, r5).sub(r3, r5, -1)
            # If eidx + rest_n - 16 < 0 ...
            add(null, r2, r3, cond='pushn').mov(r3, r5)
            mov(broadcast, r1).rotate(r1, r1, r5)
            mov(r1, r5, cond='ifa').sub(null, r2, IDX0_ROTATE_N, cond='pushz')
            mov(r0, r3, cond='ifa')

            # Write back C: out = alpha * acc + beta * C, one row at a time.
            for i in range(16):
                mov(tmua, r1).rotate(broadcast, r0, -IDX0_ALPHA)
                fmul(r2, rf[i], r5)
                rotate(broadcast, r0, -IDX0_ROTATE_N)
                nop()
                rotate(rf[i], r2, r5)
                nop(sig=ldtmu(r2))
                rotate(broadcast, r0, -IDX0_BETA)
                # If rest_m <= cur_m, then skip the remaining rows of C.
                sub(r4, r4, num_qpus, cond='pushn').fmul(r2, r2, r5)
                b(R.exit_c, cond='a0')
                fadd(tmud, rf[i], r2).mov(rf[i], .0)
                mov(tmua, r1).rotate(broadcast, r0, -IDX0_LDC)
                add(r1, r1, r5)
            L.exit_c

            # B += 4 * 16
            # B += 4 * LDB * 16
            # C += 4 * 16
            eidx(r1).umul24(r2, 8, 8)
            sub(null, r1, IDX0_C, cond='pushz')
            add(r0, r0, r2, cond='ifa').rotate(broadcast, r0, -IDX0_LDB)
            shl(r5, r5, ilog2(16))
            sub(null, r1, IDX0_B, cond='pushz')
            add(r0, r0, r5, cond='ifa').rotate(null, r0, -IDX0_J, cond='pushn')
            ln.b(cond='na0')
            # delay slots: J -= 16
            eidx(r1)
            sub(null, r1, IDX0_J, cond='pushz')
            add(r0, r0, -16, cond='ifa')

        # A += 4 * 16 * num_qpus
        # B -= 4 * LDB * 16 * ⌈ R / 16 ⌉
        # C += 4 * LDC * 16 * num_qpus - 4 * 16 * ⌈ R / 16 ⌉
        eidx(r1).mov(r5, 1)
        sub(null, r1, IDX0_A, cond='pushz')
        shl(r5, r5, ilog2(4) + ilog2(16) + ilog2(num_qpus))
        add(r0, r0, r5, cond='ifa')
        nop()
        rotate(broadcast, r0, -IDX0_LDC)
        sub(null, r1, IDX0_C, cond='pushz')
        shl(r5, r5, ilog2(16))
        add(r0, r0, r5, cond='ifa')
        nop()
        assert IDX0_R == 0
        mov(broadcast, r0)
        add(r5, r5, 15)
        shr(r5, r5, ilog2(16))
        shl(r5, r5, ilog2(4) + ilog2(16))
        sub(r0, r0, r5, cond='ifa').sub(null, r1, IDX0_B, cond='pushz')
        shr(r1, r5, ilog2(4))
        rotate(broadcast, r0, -IDX0_LDB)
        umul24(r5, r1, r5)
        sub(r0, r0, r5, cond='ifa')
        mov(r2, 1)
        eidx(r1).rotate(null, r0, -IDX0_I, cond='pushn')
        lm.b(cond='na0')
        # delay slots: I -= 16 * num_qpus
        sub(null, r1, IDX0_I, cond='pushz')
        shl(r2, r2, ilog2(16) + ilog2(num_qpus))
        sub(r0, r0, r2, cond='ifa')

    L.exit_m

    # Synchronize all QPUs, then signal program end (thrsw sequence).
    barrierid(syncb, sig=thrsw)
    nop()
    nop()

    nop(sig=thrsw)
    nop(sig=thrsw)
    nop()
    nop()
    nop(sig=thrsw)
    nop()
    nop()
    nop()
def qpu_fft4(asm, *, num_qpus, is_forward):
    """Emit VideoCore VI QPU code for radix-4 FFT passes over complex data.

    Four complex inputs c0..c3 per butterfly are combined into outputs
    d0..d3; three of the outputs are multiplied by twiddle factors streamed
    from the uniform reader.  ``is_forward`` flips the sign of the ±i terms,
    i.e. it selects forward vs. inverse transform.  Buffers are ping-ponged
    between passes via ``buf``.

    Args:
        asm: instruction list being assembled (not inspected here).
        num_qpus: must be 8; QPUs other than number 0 branch straight to the
            exit, so presumably only one QPU does the work -- confirm.
        is_forward: choose the forward (True) or inverse (False) butterfly.

    NOTE(review): register aliases ``reg_*`` are written into ``globals()``
    as a side effect and leak to module scope.
    """
    # Alias rf0.. with readable names (module-global side effect).
    g = globals()
    for i, name in enumerate([
            'j', 'k', 'l', 'm', 'n', 'x_orig', 'x', 'y_orig', 'y', 'buf',
            'omega_r', 'omega_c', 'c0r', 'c0c', 'c1r', 'c1c', 'c2r', 'c2c',
            'c3r', 'c3c']):
        g[f'reg_{name}'] = rf[i]

    # Kernel parameters from the uniform stream.
    nop(sig=ldunifrf(reg_n))
    nop(sig=ldunifrf(reg_x_orig))
    nop(sig=ldunifrf(reg_y_orig))
    nop(sig=ldunifrf(reg_buf))

    nop(sig=thrsw)
    nop()
    nop()

    # Only QPU number 0 proceeds; the others jump to the exit sequence.
    tidx(r0)
    shr(r0, r0, 2)
    band(null, r0, 0b1111, cond='pushz')
    if num_qpus == 8:
        b(R.exit, cond='na0')
        nop()
        nop()
        nop()
    else:
        raise Exception('num_qpus must be 8')

    # Switch the uniform pointer (twiddle factors follow in the stream).
    b(R.set_unif, cond='always').unif_addr(absolute=True)
    nop()
    nop()
    nop()
    L.set_unif

    # k = n / 4 groups, j = 1 butterflies per group for the first pass.
    shr(reg_k, reg_n, ilog2(4))
    mov(reg_j, 1)

    with loop as ljk:  # one iteration per radix-4 pass

        mov(reg_l, 0)

        with loop as ll:  # one iteration per butterfly group

            # m = j - 1
            # x = x_orig + (j * l + eidx) * 8
            # y = y_orig + (4 * j * l + eidx) * 8
            sub(reg_m, reg_j, 1)
            eidx(r0).umul24(r1, reg_j, reg_l)
            add(r1, r0, r1).umul24(r2, r1, 4)
            shl(r1, r1, ilog2(8)).add(r2, r2, r0)
            shl(r2, r2, ilog2(8)).add(reg_x, reg_x_orig, r1)
            add(reg_y, reg_y_orig, r2)

            with loop as lm:  # 16 butterflies per iteration

                # Issue TMU reads for c0..c3 (real and imag words each).
                # r0 = n / 4 * 8 = 2 n, r1 = 2 n + 4
                mov(tmua, reg_x).add(r0, reg_n, reg_n)
                add(tmua, reg_x, 4).add(r1, r0, 4)
                # r5 = 4 n, r1 = 4 n + 8
                add(tmua, reg_x, r0).add(r5, r0, r0)
                add(tmua, reg_x, r1).add(r1, r1, r1)
                # r0 = r0 + r5 = 2 n + 4 n = 6 n, r1 = r1 - 4 = 4 n + 4
                add(tmua, reg_x, r5).sub(r1, r1, 4)
                add(tmua, reg_x, r1).add(r0, r0, r5)
                # r0 = r0 + 4 = 6 n + 4
                add(tmua, reg_x, r0).add(r0, r0, 4)
                add(tmua, reg_x, r0)

                # eidx < 16 - (m + 1) ∴ m + eidx - 15 < 0
                # r5 = -1 - (rest - 1) = -rest
                eidx(r0).add(r1, reg_m, -15)
                add(null, r0, r1, cond='pushn').sub(r5, -1, reg_m)
                # Rotate the store address for the final partial group.
                mov(r0, reg_y)
                rotate(reg_y, r0, r5)
                mov(reg_y, r0, cond='ifa')

                # Receive c0..c3 (real, imag interleaved).
                nop(sig=ldtmu(reg_c0r))
                nop(sig=ldtmu(reg_c0c))
                nop(sig=ldtmu(reg_c1r))
                nop(sig=ldtmu(reg_c1c))
                nop(sig=ldtmu(reg_c2r))
                nop(sig=ldtmu(reg_c2c))
                nop(sig=ldtmu(reg_c3r))
                nop(sig=ldtmu(reg_c3c))

                # d0, d1, d2, d3 are stored to c0, c2, c1, c3, resp.
                fadd(r0, reg_c0r, reg_c2r)
                fadd(r1, reg_c0c, reg_c2c)
                fsub(reg_c2r, reg_c0r, reg_c2r).mov(reg_c0r, r0)
                fsub(reg_c2c, reg_c0c, reg_c2c).mov(reg_c0c, r1)
                fadd(r0, reg_c1r, reg_c3r)
                fadd(r1, reg_c1c, reg_c3c)
                # The ±i factor differs between forward and inverse FFT.
                if is_forward:
                    fsub(r1, reg_c1c, reg_c3c).mov(reg_c1c, r1)
                    fsub(r0, reg_c3r, reg_c1r).mov(reg_c1r, r0)
                else:
                    fsub(r1, reg_c3c, reg_c1c).mov(reg_c1c, r1)
                    fsub(r0, reg_c1r, reg_c3r).mov(reg_c1r, r0)
                mov(reg_c3r, r1)
                mov(reg_c3c, r0)

                # Python-level aliases documenting where each butterfly
                # output ended up (no instructions are emitted here).
                reg_d0r = reg_c0r
                reg_d0c = reg_c0c
                reg_d1r = reg_c2r
                reg_d1c = reg_c2c
                reg_d2r = reg_c1r
                reg_d2c = reg_c1c
                reg_d3r = reg_c3r
                reg_d3c = reg_c3c

                # Output 0: d0 + d2, stored without a twiddle factor.
                fadd(r0, reg_d0r, reg_d2r)
                rotate(tmud, r0, r5)
                fadd(r0, reg_d0c, reg_d2c)
                mov(tmua, reg_y)
                rotate(tmud, r0, r5)
                add(tmua, reg_y, 4)
                shl(r0, reg_j, ilog2(8))
                add(reg_y, reg_y, r0)

                # Output 1: (d1 + d3) * omega (complex multiply).
                fadd(r0, reg_d1r, reg_d3r, sig=ldunifrf(reg_omega_r))
                fadd(r1, reg_d1c, reg_d3c, sig=ldunifrf(reg_omega_c))
                fmul(r2, r0, reg_omega_r)
                fmul(r3, r1, reg_omega_c)
                fsub(r2, r2, r3)
                rotate(tmud, r2, r5)
                mov(tmua, reg_y).fmul(r2, r0, reg_omega_c)
                fmul(r3, r1, reg_omega_r)
                fadd(r2, r2, r3)
                rotate(tmud, r2, r5)
                add(tmua, reg_y, 4)
                shl(r0, reg_j, ilog2(8))
                add(reg_y, reg_y, r0)

                # Output 2: (d0 - d2) * omega.
                fsub(r0, reg_d0r, reg_d2r, sig=ldunifrf(reg_omega_r))
                fsub(r1, reg_d0c, reg_d2c, sig=ldunifrf(reg_omega_c))
                fmul(r2, r0, reg_omega_r)
                fmul(r3, r1, reg_omega_c)
                fsub(r2, r2, r3)
                rotate(tmud, r2, r5)
                mov(tmua, reg_y).fmul(r2, r0, reg_omega_c)
                fmul(r3, r1, reg_omega_r)
                fadd(r2, r2, r3)
                rotate(tmud, r2, r5)
                add(tmua, reg_y, 4)
                shl(r0, reg_j, ilog2(8))
                add(reg_y, reg_y, r0)

                # Output 3: (d1 - d3) * omega; also decrement m by 16.
                fsub(r0, reg_d1r, reg_d3r, sig=ldunifrf(reg_omega_r))
                fsub(r1, reg_d1c, reg_d3c, sig=ldunifrf(reg_omega_c))
                fmul(r2, r0, reg_omega_r)
                fmul(r3, r1, reg_omega_c)
                fsub(r2, r2, r3)
                rotate(tmud, r2, r5)
                mov(tmua, reg_y).fmul(r2, r0, reg_omega_c)
                fmul(r3, r1, reg_omega_r)
                fadd(r2, r2, r3).add(reg_m, reg_m, -16, cond='pushn')
                rotate(tmud, r2, r5)
                add(tmua, reg_y, 4)

                # Rewind y by the 3 * j * 8 bytes advanced above, then step
                # both pointers forward by 16 elements.
                add(r0, 12, 12)  # 8 * 3
                mov(r1, 1).umul24(r0, reg_j, r0)
                lm.b(cond='na0').unif_addr(absolute=False)
                sub(reg_y, reg_y, r0)
                shl(r0, r1, ilog2(8 * 16))
                add(reg_x, reg_x, r0).add(reg_y, reg_y, r0)

            add(reg_l, reg_l, 1)
            sub(null, reg_l, reg_k, cond='pushn')
            ll.b(cond='a0')
            nop()
            nop()
            nop()

        # Next pass: k /= 4, j *= 4, and swap the source/destination
        # buffers (ping-pong with buf).
        shr(reg_k, reg_k, ilog2(4), cond='pushz')
        ljk.b(cond='na0')
        shl(reg_j, reg_j, ilog2(4)).sub(null, reg_j, 2, cond='pushn')
        mov(reg_y_orig, reg_buf, cond='ifa').mov(r0, reg_y_orig)
        mov(reg_y_orig, reg_x_orig, cond='ifna').mov(reg_x_orig, r0)

    L.exit

    # Synchronize all QPUs, then signal program end (thrsw sequence).
    barrierid(syncb, sig=thrsw)
    nop()
    nop()

    nop(sig=thrsw)
    nop(sig=thrsw)
    nop()
    nop()
    nop(sig=thrsw)
    nop()
    nop()
    nop()