def multiple_fusions(A: dace.float32[10, 20], B: dace.float32[10, 20], C: dace.float32[10, 20],
                     out: dace.float32[1]):
    # Three element-wise stages over a 10x20 index space, linked by two local buffers:
    #   1. every element of A is accumulated into out[0] (addition lambda) and its
    #      square is stored into both A_prime and A_prime_copy;
    #   2. B = A_prime + 1;
    #   3. C = A_prime_copy + 2.
    A_prime = dace.define_local([10, 20], dtype=A.dtype)
    A_prime_copy = dace.define_local([10, 20], dtype=A.dtype)

    # Stage 1: one read of A feeds three outputs.
    for row, col in dace.map[0:10, 0:20]:
        with dace.tasklet:
            src << A[row, col]
            acc >> out(1, lambda x, y: x + y)[0]
            sq_first >> A_prime[row, col]
            sq_second >> A_prime_copy[row, col]
            acc = src
            sq_first = src * src
            sq_second = src * src

    # Stage 2: consumes the first buffer.
    for row, col in dace.map[0:10, 0:20]:
        with dace.tasklet:
            src << A_prime[row, col]
            plus_one >> B[row, col]
            plus_one = src + 1

    # Stage 3: consumes the second buffer.
    for row, col in dace.map[0:10, 0:20]:
        with dace.tasklet:
            src << A_prime_copy[row, col]
            plus_two >> C[row, col]
            plus_two = src + 2
def fusion(A: dace.float32[10, 20], B: dace.float32[10, 20], out: dace.float32[1]):
    # Chain of three element-wise maps through two local 10x20 buffers:
    #   tmp = A * A, tmp_2 = tmp + B, and every tmp_2 element is accumulated
    #   into out[0] through the addition lambda.
    # NOTE(review): the three maps use different range orders and index names
    # while touching the same elements - presumably on purpose, to exercise
    # fusion across permuted/renamed map parameters; confirm before "tidying".
    tmp = dace.define_local([10, 20], dtype=A.dtype)
    tmp_2 = dace.define_local([10, 20], dtype=A.dtype)
    # Map 1: (i, j) over 10x20, squares A into tmp.
    for i, j in dace.map[0:10, 0:20]:
        with dace.tasklet:
            a << A[i, j]
            b >> tmp[i, j]
            b = a * a

    # Map 2: ranges swapped (i over 0:20, j over 0:10); indexing uses [j, i],
    # so the accessed elements are the same as in map 1.
    for i, j in dace.map[0:20, 0:10]:
        with dace.tasklet:
            a << tmp[j, i]
            b << B[j, i]
            c >> tmp_2[j, i]
            c = a + b

    # Map 3: same ranges as map 1 but with indices named (m, n).
    for m, n in dace.map[0:10, 0:20]:
        with dace.tasklet:
            a << tmp_2[m, n]
            b >> out(1, lambda a, b: a + b)[0]
            b = a
def outer_sqrt_with_intermediate(Y: dace.float32[3, 3]):
    # Pipeline: element-wise sqrt of Y -> middle_sqrt_no_sum (defined elsewhere)
    # -> np.sum of the result, which is returned.
    # Both stages go through explicitly declared 3x3 float32 locals.
    roots = dace.define_local([3, 3], dace.float32)
    nested_out = dace.define_local([3, 3], dace.float32)

    roots[:] = dace.elementwise(lambda v: sqrt(v), Y)
    nested_out[:] = middle_sqrt_no_sum(roots)

    total = np.sum(nested_out)
    return total
def attn_fwd(
    q: dace.float32[batchSize, Qsize, seqLenQ],
    k: dace.float32[batchSize, Qsize, seqLenK],
    v: dace.float32[batchSize, Qsize, seqLenK],
    wq: dace.float32[numHeads, projQsize, Qsize],
    wk: dace.float32[numHeads, projQsize, Qsize],
    wv: dace.float32[numHeads, projQsize, Qsize],
    wo: dace.float32[numHeads, Qsize, projQsize],
    out: dace.float32[batchSize, Qsize, seqLenQ],
):
    # For every batch entry b and every head h: project q/k/v with the per-head
    # weights, combine them as v_bar @ alpha where alpha is derived from
    # k_bar^T @ q_bar, project back with wo[h], and finally sum the per-head
    # results into out[b].
    # sdfg_transpose and dace_softmax are defined outside this block; their
    # behaviour is assumed from their names (transpose into the second
    # argument, softmax of the first argument into the second) - TODO confirm.
    for b in dace.map[0:batchSize]:
        # Per-batch staging buffer holding one Qsize x seqLenQ result per head.
        outs = dace.define_local([numHeads, Qsize, seqLenQ], dace.float32)
        for h in dace.map[0:numHeads]:
            q_bar = wq[h] @ q[b]  # projQsize x seqLenQ
            k_bar = wk[h] @ k[b]  # projQsize x seqLenK
            v_bar = wv[h] @ v[b]  # projQsize x seqLenK

            k_bar_t = dace.define_local([seqLenK, projQsize], dace.float32)
            sdfg_transpose(k_bar, k_bar_t)

            beta = k_bar_t @ q_bar  # seqLenK x seqLenQ

            # alpha is filled one row (index j over seqLenK) at a time.
            alpha = dace.define_local([seqLenK, seqLenQ], dace.float32)
            for j in dace.map[0:seqLenK]:
                dace_softmax(beta[j], alpha[j])

            h_bar = v_bar @ alpha  # projQsize x seqLenQ

            outs[h] = wo[h] @ h_bar  # Qsize x seqLenQ

        # Sum over the head axis (axis 0 of outs) to produce this batch's output.
        out[b] = dace.reduce(lambda a, b: a + b, outs, axis=0, identity=0)
def middle_sqrt_with_intermediate(Y: dace.float32[3, 3]):
    # Pipeline: element-wise sqrt of Y into a 3x3 local, hand that local to
    # inner_sdfg_with_intermediate (defined elsewhere) together with a second
    # 3x3 local passed as its second argument, then return np.sum of that
    # second local.
    roots = dace.define_local([3, 3], dace.float32)
    inner_out = dace.define_local([3, 3], dace.float32)

    roots[:] = dace.elementwise(lambda v: sqrt(v), Y)
    inner_sdfg_with_intermediate(roots, inner_out)

    total = np.sum(inner_out)
    return total
def DFT(X, Y):
    # Computes Y = dft_mat @ X, where dft_mat is an N x N complex128 matrix
    # built below and the product is formed as an explicit element-wise
    # multiply (tmp) followed by a sum-reduction over axis 1.

    # Generate DFT matrix
    dft_mat = dace.define_local([N, N], dtype=dace.complex128)

    @dace.mapscope(_[0:N])
    def out_map_gen(i):
        # Inner range starts at i: only entries with j >= i are computed, and
        # each value is written to both [i, j] and [j, i] (symmetric fill).
        @dace.map(_[i:N])
        def dft_mat_gen(j):
            omega1 >> dft_mat[i, j]
            omega2 >> dft_mat[j, i]

            # exp(-(0 + 2*pi*i*j * 1j) / N), with pi given as a decimal literal.
            omega = exp(-dace.complex128(0, 2 * 3.14159265359 * i * j) / dace.complex128(N))
            omega1 = omega
            omega2 = omega

    # Matrix multiply input vector with DFT matrix
    tmp = dace.define_local([N, N], dtype=dace.complex128)

    @dace.map(_[0:N, 0:N])
    def dft_tasklet(k, n):
        x << X[n]
        omega << dft_mat[k, n]
        out >> tmp[k, n]

        out = x * omega

    # Sum tmp along axis 1 (index n) into Y.
    dace.reduce(lambda a, b: a + b, tmp, Y, axis=1, identity=0)
def prog(A, B):
    # Calls f(A, number) (f is defined elsewhere) and then writes
    # B[i] = 2 * number[i] for i in [0, W).
    # NOTE(review): `number` is used as a shape for `no` on the first line and
    # only afterwards rebound to a local array. Executed as plain Python this
    # would raise UnboundLocalError, so the order presumably relies on an outer
    # `number` being resolved by the DaCe frontend - do not reorder without
    # confirming. `no` is not used again in this block.
    no = dace.define_local([number], dace.float32)
    number = dace.define_local([W], dace.float32)

    f(A, number)

    @dace.map(_[0:W])
    def bla2(i):
        inp << number[i]
        out >> B[i]
        out = 2 * inp
def duplicate_naming(A, B):
    # Calls duplicate_naming_inner(A, number) (defined elsewhere) and then
    # writes B[i] = 2 * number[i] for i in [0, W).
    # NOTE(review): `number` is used as a shape for `no` on the first line and
    # only afterwards rebound to a local array. Executed as plain Python this
    # would raise UnboundLocalError, so the order presumably relies on an outer
    # `number` being resolved by the DaCe frontend - the name clash looks like
    # the point of this program; confirm before changing. `no` is not used
    # again in this block.
    no = dace.define_local([number], dace.float32)
    number = dace.define_local([W], dace.float32)

    duplicate_naming_inner(A, number)

    @dace.map(_[0:W])
    def bla2(i):
        inp << number[i]
        out >> B[i]
        out = 2 * inp
def softmax_backward(output, output_grad, input_grad):
    # Writes into input_grad:
    #   output * output_grad - output * reduce_sum(output * output_grad, axes=[dim], keepdims)
    # (assuming the standard ONNX Mul / ReduceSum semantics for the donnx nodes).
    # output_shape, output_dtype, sums_shape and dim come from the enclosing scope.
    weighted = dace.define_local(output_shape, output_dtype)
    donnx.ONNXMul(A=output, B=output_grad, C=weighted)

    reduced = dace.define_local(sums_shape, output_dtype)
    donnx.ONNXReduceSum(data=weighted, reduced=reduced, keepdims=1, axes=[dim])

    # input_grad temporarily holds output * reduced ...
    donnx.ONNXMul(A=output, B=reduced, C=input_grad)
    # ... and is then finalised with a plain subtraction rather than ONNXSub,
    # because it is unclear how ORT handles that node when used in place.
    input_grad[:] = weighted - input_grad
def logsoftmax_backward(output, output_grad, input_grad):
    # Writes into input_grad:
    #   output_grad - exp(output) * reduce_sum(output_grad, axes=[dim], keepdims)
    # (assuming the standard ONNX Exp / ReduceSum / Sub semantics for the donnx nodes).
    # output_shape, output_dtype, sums_shape and dim come from the enclosing scope.
    probs = dace.define_local(output_shape, output_dtype)
    grad_total = dace.define_local(sums_shape, output_dtype)

    donnx.ONNXExp(input=output, output=probs)
    donnx.ONNXReduceSum(data=output_grad, reduced=grad_total, keepdims=1, axes=[dim])

    # Plain multiplication rather than ONNXMul, because it is unclear how ORT
    # handles that node when used in place.
    probs[:] = probs * grad_total
    donnx.ONNXSub(A=output_grad, B=probs, C=input_grad)
def sftw(A: dace.float64[20]):
    # Seven length-20 element-wise stages. Two chains both pass through the
    # shared local `dup`, which is therefore written twice:
    #   chain 1: B = A;     dup = B;     D = dup + 2
    #   chain 2: C = A + 1; dup = C + 1; E = dup + 3
    # Finally A = D + E.
    B = dace.define_local([20], dace.float64)
    C = dace.define_local([20], dace.float64)
    D = dace.define_local([20], dace.float64)
    E = dace.define_local([20], dace.float64)
    dup = dace.define_local([20], dace.float64)

    # --- chain 1 ---
    for idx in dace.map[0:20]:
        with dace.tasklet:
            src << A[idx]
            dst >> B[idx]
            dst = src

    for idx in dace.map[0:20]:
        with dace.tasklet:
            src << B[idx]
            dst >> dup[idx]
            dst = src

    for idx in dace.map[0:20]:
        with dace.tasklet:
            src << dup[idx]
            dst >> D[idx]
            dst = src + 2

    # --- chain 2 (reuses dup) ---
    for idx in dace.map[0:20]:
        with dace.tasklet:
            src << A[idx]
            dst >> C[idx]
            dst = src + 1

    for idx in dace.map[0:20]:
        with dace.tasklet:
            src << C[idx]
            dst >> dup[idx]
            dst = src + 1

    for idx in dace.map[0:20]:
        with dace.tasklet:
            src << dup[idx]
            dst >> E[idx]
            dst = src + 3

    # --- join: overwrite A with the sum of both chains ---
    for idx in dace.map[0:20]:
        with dace.tasklet:
            left << D[idx]
            right << E[idx]
            joined >> A[idx]
            joined = left + right
def operation(A: dace.float64[M, M], B: dace.float64[M, M], C: dace.float64[M, N], D: dace.float64[M, N]):
    # Step 1: E = sum over the last axis of tmp, where tmp[i, j, k] = A[i, k] * B[k, j]
    #         (an explicit map + reduce formulation of a matrix product).
    # Step 2: C = A @ E @ (A @ B) @ (B @ D). D is only read; C is overwritten.
    tmp = dace.define_local([M, M, M], dtype=A.dtype)
    E = dace.define_local([M, M], dtype=A.dtype)

    @dace.map(_[0:M, 0:M, 0:M])
    def multiplication(row, col, inner):
        lhs << A[row, inner]
        rhs << B[inner, col]
        prod >> tmp[row, col, inner]

        prod = lhs * rhs

    dace.reduce(lambda x, y: x + y, tmp, E, axis=2, identity=0)

    C[:] = A @ E @ (A @ B) @ (B @ D)
def k2mm(A, B, C, D, alpha, beta):
    # In-place update D = beta * D + (alpha * A @ B) @ C, expressed as four maps:
    #   zerotmp  : tmp = 0
    #   mult_tmp : tmp[i, j] += alpha * A[i, k] * B[k, j]
    #   mult_d   : D[i, j] = D[i, j] * beta
    #   comp_d   : D[i, j] += tmp[i, k] * C[k, j]
    # NI, NJ, NK, NL and datatype come from the enclosing scope.
    tmp = dace.define_local([NI, NJ], dtype=datatype)

    @dace.map
    def zerotmp(i: _[0:NI], j: _[0:NJ]):
        cleared >> tmp[i, j]
        cleared = 0.0

    @dace.map
    def mult_tmp(i: _[0:NI], j: _[0:NJ], k: _[0:NK]):
        a_ik << A[i, k]
        b_kj << B[k, j]
        scale << alpha
        contrib >> tmp(1, lambda p, q: p + q)[i, j]

        contrib = scale * a_ik * b_kj

    @dace.map
    def mult_d(i: _[0:NI], j: _[0:NL]):
        d_old << D[i, j]
        scale << beta
        d_new >> D[i, j]

        d_new = d_old * scale

    @dace.map
    def comp_d(i: _[0:NI], j: _[0:NL], k: _[0:NJ]):
        t_ik << tmp[i, k]
        c_kj << C[k, j]
        contrib >> D(1, lambda p, q: p + q)[i, j]

        contrib = t_ik * c_kj
def test_inline_reshape_views_work(A: dace.float64[3, 3], B: dace.float64[9]):
    # Stores nested_add2(A, B) in a flat 9-element local, passes that through
    # reshape_node, and returns the transpose of the result.
    # nested_add2 and reshape_node are defined outside this block.
    flat = dace.define_local([9], dace.float64)
    flat[:] = nested_add2(A, B)

    as_matrix = reshape_node(flat)
    return np.transpose(as_matrix)
def fusion_recomputation(A: dace.float64[20, 20], B: dace.float64[16, 16]):
    # Two back-to-back 3x3 mean filters:
    #   A (20x20) -> tmp (18x18) -> B (16x16).
    # Each output element is the average of the 3x3 neighbourhood centred on
    # the corresponding input element; outputs are stored shifted by (-1, -1)
    # so that they start at index 0.
    tmp = dace.define_local([18, 18], dtype=A.dtype)

    # Pass 1: interior of A (indices 1..18) into tmp[0..17].
    for y, x in dace.map[1:19, 1:19]:
        with dace.tasklet:
            nw << A[y - 1, x - 1]
            n << A[y - 1, x]
            ne << A[y - 1, x + 1]
            w << A[y, x - 1]
            c << A[y, x]
            e << A[y, x + 1]
            sw << A[y + 1, x - 1]
            s << A[y + 1, x]
            se << A[y + 1, x + 1]
            mean >> tmp[y - 1, x - 1]
            mean = (nw + n + ne + w + c + e + sw + s + se) / 9.0

    # Pass 2: interior of tmp (indices 1..16) into B[0..15].
    for y, x in dace.map[1:17, 1:17]:
        with dace.tasklet:
            nw << tmp[y - 1, x - 1]
            n << tmp[y - 1, x]
            ne << tmp[y - 1, x + 1]
            w << tmp[y, x - 1]
            c << tmp[y, x]
            e << tmp[y, x + 1]
            sw << tmp[y + 1, x - 1]
            s << tmp[y + 1, x]
            se << tmp[y + 1, x + 1]
            mean >> B[y - 1, x - 1]
            mean = (nw + n + ne + w + c + e + sw + s + se) / 9.0
def single_state_reshape_same_state(inp: dace.float64[9], target_shape: dace.int64[2]):
    # Reshapes the 9-element input into a 3x3 local through the ONNX Reshape
    # node, applies log(x + 1) element-wise, and returns the sum of the result.
    grid = dace.define_local([3, 3], dace.float64)
    donnx.ONNXReshape(data=inp, shape=target_shape, reshaped=grid)

    logged = dace.elementwise(lambda v: log(v + 1), grid)
    total = np.sum(logged)
    return total
def persistent_transient(A: dace.float32[3, 3]):
    # Returns A @ L, where L is a local 3x5 float32 array requested with
    # Persistent allocation lifetime and GPU_Global storage.
    # NOTE(review): the local is never written before it is read, and it has
    # the same name as the enclosing function. Both look deliberate (the block
    # seems to be about how such a local is allocated, not about its values) -
    # confirm before renaming or initialising it.
    persistent_transient = dace.define_local(
        [3, 5],
        dace.float32,
        lifetime=dace.AllocationLifetime.Persistent,
        storage=dace.StorageType.GPU_Global)
    return A @ persistent_transient
def test_einsum(A: dace.float64[5, 4, 3], B: dace.float64[3, 2]):
    # Runs the ONNX Einsum node with equation "bij, jk -> bik" on A (5x4x3)
    # and B (3x2), writing into a local 5x4x2 array that is then returned.
    contracted = dace.define_local([5, 4, 2], dace.float64)
    donnx.ONNXEinsum(
        Inputs__0=A,
        Inputs__1=B,
        Output=contracted,
        equation="bij, jk -> bik",
    )
    return contracted
def jacobi(A, iterations):
    # 5-point averaging stencil on the interior of an H x W grid, repeated
    # `iterations` times. Every iteration performs two sweeps, A -> tmp and
    # tmp -> A, so A holds the latest result when the loop ends.
    # H and W come from the enclosing scope.

    # Scratch grid, zeroed up front (its border is never written by the sweeps).
    tmp = dace.define_local([H, W], dtype=A.dtype)

    @dace.map(_[0:H, 0:W])
    def reset_tmp(row, col):
        cleared >> tmp[row, col]
        cleared = 0.0

    @dace.iterate(_[0:iterations])
    def step(t):
        # Sweep 1: read A, write tmp.
        @dace.map(_[1:H - 1, 1:W - 1])
        def a2b(row, col):
            north << A[row - 1, col]
            south << A[row + 1, col]
            west << A[row, col - 1]
            east << A[row, col + 1]
            center << A[row, col]
            avg >> tmp[row, col]

            avg = 0.2 * (center + north + south + west + east)

        # Sweep 2: read tmp, write A (second half of the double buffering).
        @dace.map(_[1:H - 1, 1:W - 1])
        def b2a(row, col):
            north << tmp[row - 1, col]
            south << tmp[row + 1, col]
            west << tmp[row, col - 1]
            east << tmp[row, col + 1]
            center << tmp[row, col]
            avg >> A[row, col]

            avg = 0.2 * (center + north + south + west + east)
def nested(A: dace.float64[64]):
    # Fills a local 64-element array with 0, then with 1, and copies it into A,
    # so A ends up holding ones.
    # The local's name (`gpu_A`) is significant: per the comment below it is
    # meant to collide with a name used by the caller - do not rename it.

    # Create local array with the same name as an outer array
    gpu_A = dace.define_local([64], np.float64, storage=dace.StorageType.GPU_Global)
    gpu_A[:] = 0
    gpu_A[:] = 1
    A[:] = gpu_A
def implicit_line_joining(A: dace.float32[N], B: dace.float32[N]):
    # This DaCe program sets B equal to A, going through a local buffer.
    # NOTE(review): the define_local call is intentionally split over several
    # lines with comments inside the parentheses (see the function name) -
    # keep that layout.
    tmp = dace.define_local(
        (N,),  # shape
        dtype=dace.float32  # type
    )
    tmp[:] = A[:]  # for i in 0 .. N-1; tmp[i] = A[i]
    B[:] = tmp[:]  # for i in 0 .. N-1; B[i] = tmp[i]
def transient(A: dace.float64[128, 64]):
    # For each of the 128 rows: fill a local 64-element array with 0, then
    # with 1, and copy it into that row of A, so A ends up holding ones.
    # The local is declared inside the map body, and its name (`gpu_A`) is
    # significant: per the comment below it is meant to collide with a name
    # used by the caller - do not rename it.
    for i in dace.map[0:128]:
        # Create local array with the same name as an outer array
        gpu_A = dace.define_local([64], np.float64, storage=dace.StorageType.GPU_Global)
        gpu_A[:] = 0
        gpu_A[:] = 1
        A[i, :] = gpu_A
def prog(input: dace.float32[2, 2], output: dace.float32[2, 2]):
    # Four-step pipeline along `axis` (taken from the enclosing scope):
    #   1. maximum of the input,
    #   2. exp_minus_max(...) into a local buffer,
    #   3. sum of that buffer,
    #   4. out_tmp_div_sum(...) into `output`.
    # exp_minus_max and out_tmp_div_sum are defined outside this block; from
    # their names they presumably compute exp(input - max) and buffer / sum,
    # i.e. a numerically stabilised softmax - TODO confirm.
    peak = np.max(input, axis=axis)

    shifted_exp = dace.define_local(out_tmp_shape, out_tmp_dtype)
    exp_minus_max(tmp_max=peak, original_input=input, output=shifted_exp)

    denominator = np.sum(shifted_exp, axis=axis)
    out_tmp_div_sum(out_tmp=shifted_exp, tmp_sum=denominator, output=output)
def add_reshape_grad_test_nested(inp: dace.float64[9], bias: dace.float64[3], target_shape: dace.int64[2],
                                 result: dace.float64):
    # result = sum(log((reshape(inp + 1) to 3x3) * bias + 1)),
    # with the reshape done by the ONNX Reshape node.
    shifted = inp + 1

    grid = dace.define_local([3, 3], dace.float64)
    donnx.ONNXReshape(data=shifted, shape=target_shape, reshaped=grid)

    scaled = grid * bias
    logged = dace.elementwise(lambda v: log(v + 1), scaled)
    result[:] = np.sum(logged)
def program(A: dace.float32[N], B: dace.float32[N]):
    # For every i in [0, N): copy A[0] into element 0 of a per-iteration local
    # array, then copy that element into B[i]. Every B[i] therefore receives
    # A[0].
    # NOTE(review): both tasklets deliberately address the fixed element
    # arr[0] of an N-element local; the access pattern looks like the thing
    # under test - confirm before simplifying.
    for i in dace.map[0:N]:
        arr = dace.define_local(N, dace.float32)
        with dace.tasklet:
            a << A[0]
            x_out >> arr[0]
            # now write into an Array access node
            x_out = a

        with dace.tasklet:
            x_in << arr[0]
            b >> B[i]
            b = x_in
def keyword_return(A: dace.float32[N]):
    # Copies A into a new local array B element by element using a
    # `while True` loop driven by an explicit counter, and returns B.
    # The loop is intentionally written with `continue` / `break` (see the
    # function name); `A[i] + i - i` is numerically just A[i] but keeps the
    # counter in the expression.
    i = dace.define_local_scalar(dtype=dace.int32)
    i = 0
    B = dace.define_local((N, ), dtype=dace.float32)
    while True:
        B[i] = A[i] + i - i
        i += 1
        # Keep looping while there are elements left, otherwise leave the loop.
        if i < N:
            continue
        else:
            break
    # Reached once the loop has exited via `break`.
    return B
def vector_reduce(x: dace.float32[N], s: dace.scalar(dace.float32)):
    # Sums all N elements of x into the scalar s: x is first copied into a
    # local buffer by a map, and the buffer is then reduced with addition
    # along axis 0 (identity 0).

    # transient staging buffer
    tmp = dace.define_local([N], dtype=x.dtype)

    @dace.map
    def sum(idx: _[0:N]):
        elem << x[idx]
        staged >> tmp[idx]

        staged = elem

    dace.reduce(lambda lhs, rhs: lhs + rhs, tmp, s, axis=0, identity=0)
def program(A: dace.float32[N], B: dace.float32[N]):
    # For every i in [0, N): the first tasklet reads A[i] and writes it to the
    # fixed element arr[0] of a per-iteration local; the second tasklet reads
    # arr[i] and writes it to B[i].
    # NOTE(review): the write index (0) and read index (i) of `arr` do not
    # match. The existing comment marks this as an expected-failure case, so
    # the mismatch is intentional - which check is expected to reject it is
    # not visible from this block; confirm against the caller.
    for i in dace.map[0:N]:
        arr = dace.define_local(N, dace.float32)
        with dace.tasklet:
            a << A[i]
            # Reading a vector but storing a scalar (must fail)
            x_out >> arr[0]
            x_out = a

        with dace.tasklet:
            x_in << arr[i]
            b >> B[i]
            b = x_in
def transients(A: dace.float32[10]):
    # Filters A: every element >= 0.5 is pushed into a local stream and
    # counted; the stream is then emptied into a zero-initialised local array.
    # Returns the pair (count, array).
    ostream = dace.define_stream(dace.float32, 10)
    oscalar = dace.define_local_scalar(dace.int32)
    oarray = dace.define_local([10], dace.float32)
    oarray[:] = 0
    oscalar = 0
    for i in dace.map[0:10]:
        if A[i] >= 0.5:
            # presumably -1 marks a data-dependent (dynamic) number of pushes - TODO confirm
            A[i] >> ostream(-1)
            # NOTE(review): the counter is updated from inside a map without an
            # explicit conflict-resolution function - confirm how the frontend
            # treats this.
            oscalar += 1
    # Move the stream's contents into oarray.
    ostream >> oarray
    return oscalar, oarray
def gramschmidt(A, R, Q):
    # Column-by-column orthogonalisation of A (M rows, N columns), writing the
    # normalised columns to Q and the coefficients to R; A is updated in place.
    # For each column k:
    #   nrm     = sum_i A[i, k]^2
    #   R[k, k] = sqrt(nrm)
    #   Q[i, k] = A[i, k] / R[k, k]
    #   for every later column j > k:
    #     R[k, j]  = sum_i A[i, j] * Q[i, k]
    #     A[i, j] += -R[k, j] * Q[i, k]
    # M, N and datatype come from the enclosing scope.

    # Single-element accumulator reused for every column's squared norm.
    nrm = dace.define_local([1], datatype)

    for k in range(0, N, 1):

        # Reset the accumulator for this column.
        @dace.tasklet
        def set_nrm():
            out_nrm >> nrm
            out_nrm = datatype(0)

        # Accumulate the squared entries of column k (addition lambda).
        @dace.map
        def set_sum(i: _[0:M]):
            in_A << A[i, k]
            out_nrm >> nrm(1, lambda x, y: x + y)
            out_nrm = in_A * in_A

        # Diagonal entry of R is the column norm.
        @dace.tasklet
        def set_rkk():
            in_nrm << nrm
            out_R >> R[k, k]
            out_R = math.sqrt(in_nrm)

        # Normalised column k goes to Q.
        @dace.map
        def set_q(i: _[0:M]):
            in_A << A[i, k]
            in_R << R[k, k]
            out_Q >> Q[i, k]
            out_Q = in_A / in_R

        # Remove the component along Q[:, k] from every later column j.
        @dace.mapscope
        def set_rna(j: _[k + 1:N]):
            # Equivalent sequential loop: for j in range(k+1, N, 1):
            @dace.tasklet
            def init_r():
                out_R >> R[k, j]
                out_R = datatype(0)

            # R[k, j] accumulates the dot product of column j of A with column k of Q.
            @dace.map
            def set_r(i: _[0:M]):
                in_A << A[i, j]
                in_Q << Q[i, k]
                out_R >> R(1, lambda x, y: x + y)[k, j]
                out_R = in_A * in_Q

            # The negated projection is added onto A[i, j] (addition lambda).
            @dace.map
            def set_a(i: _[0:M]):
                in_R << R[k, j]
                in_Q << Q[i, k]
                out_A >> A(1, lambda x, y: x + y)[i, j]
                out_A = -in_R * in_Q