def sddmm_dsl( A=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.K), B=dsl.TensorDef(dsl.T, dsl.S.K, dsl.S.N), S=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N), C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True)): C[dsl.D.m, dsl.D.n] += S[dsl.D.m, dsl.D.n] * A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]
def matmul_dsl( A=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.K), B=dsl.TensorDef(dsl.T, dsl.S.K, dsl.S.N), C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True) ): """Helper function for mlir sparse matrix multiplication benchmark.""" C[dsl.D.m, dsl.D.n] += A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]
def matmul_dsl( A=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.K), B=dsl.TensorDef(dsl.T, dsl.S.K, dsl.S.N), C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True)): C[dsl.D.m, dsl.D.n] += A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]