def expansion(node, state, sdfg):
    """Expand a MatMul library node into the specific BLAS node that
    matches the ranks of its two inputs.

    Dispatch (by input ranks): 2x2 -> GEMM, (2|3)x3 -> batched matmul,
    2x1 -> GEMV, 1x2 -> transposed GEMV, 1x1 -> dot product.

    :param node: The MatMul library node being expanded.
    :param state: The SDFG state containing the node.
    :param sdfg: The parent SDFG.
    :return: The replacement library node.
    :raises NotImplementedError: For unsupported operand rank combinations.
    """
    a, b, c = _get_matmul_operands(node, state, sdfg)
    rank_a = len(a[2])
    rank_b = len(b[2])

    if rank_a == 2 and rank_b == 2:
        # Matrix times matrix: expand to GEMM.
        from dace.libraries.blas.nodes.gemm import Gemm

        # A sum write-conflict resolution on the output maps onto GEMM's
        # beta=1.0 accumulation; any other WCR is unsupported and only
        # warned about (beta stays 0.0).
        beta = 0.0
        if c[0].data.wcr:
            from dace.frontend import operations
            reduction = operations.detect_reduction_type(c[0].data.wcr)
            if reduction == dace.dtypes.ReductionType.Sum:
                beta = 1.0
            else:
                warnings.warn("Unsupported WCR in output of MatMul "
                              "library node: {}".format(c[0].data.wcr))
        return Gemm(node.name + 'gemm',
                    location=node.location,
                    alpha=1.0,
                    beta=beta)

    if rank_b == 3 and rank_a in (2, 3):
        # Batched matrix times matrix: expand to batched matmul.
        from dace.libraries.blas.nodes.batched_matmul import BatchedMatMul
        return BatchedMatMul(node.name + 'bmm', location=node.location)

    if rank_a == 2 and rank_b == 1:
        # Matrix times vector: expand to GEMV.
        from dace.libraries.blas.nodes.gemv import Gemv
        # Rename connectors to match GEMV's naming.
        a[0].dst_conn = "_A"
        b[0].dst_conn = "_x"
        c[0].src_conn = "_y"
        return Gemv(node.name + 'gemv', location=node.location)

    if rank_a == 1 and rank_b == 2:
        # Vector times matrix: GEMV on the transposed matrix.
        from dace.libraries.blas.nodes.gemv import Gemv
        # Rename connectors to match GEMV's naming.
        a[0].dst_conn = "_x"
        b[0].dst_conn = "_A"
        c[0].src_conn = "_y"
        return Gemv(node.name + 'gemvt', location=node.location, transA=True)

    if rank_a == 1 and rank_b == 1:
        # Vector times vector: expand to a dot product.
        from dace.libraries.blas.nodes.dot import Dot
        # Rename connectors to match Dot's naming.
        a[0].dst_conn = "_x"
        b[0].dst_conn = "_y"
        c[0].src_conn = "_result"
        return Dot(node.name + 'dot', location=node.location)

    raise NotImplementedError("Matrix multiplication not implemented "
                              "for shapes: {} and {}".format(a[2], b[2]))
def expansion(node, state, sdfg):
    """Expand a MatMul library node into the specific BLAS node that
    matches the ranks of its two inputs.

    Dispatch (by input ranks): 2x2 -> GEMM, (2|3)x3 -> batched matmul,
    2x1 -> GEMV, 1x2 -> transposed GEMV, 1x1 -> dot product.

    :param node: The MatMul library node being expanded.
    :param state: The SDFG state containing the node.
    :param sdfg: The parent SDFG.
    :return: The replacement library node.
    :raises NotImplementedError: For unsupported operand rank combinations.
    """
    a, b, c = _get_matmul_operands(node, state, sdfg)
    size_a = a[2]
    size_b = b[2]
    if len(size_a) == 2 and len(size_b) == 2:
        # Matrix and matrix -> GEMM
        from dace.libraries.blas.nodes.gemm import Gemm
        # Fix: honor a write-conflict resolution on the output instead of
        # hardcoding beta=0.0 (which silently discarded accumulation).
        # A sum WCR maps onto GEMM's beta=1.0; other WCRs are unsupported
        # and only warned about. This matches the sibling expansion above.
        beta = 0.0
        if c[0].data.wcr:
            import warnings
            from dace.frontend import operations
            redtype = operations.detect_reduction_type(c[0].data.wcr)
            if redtype == dace.dtypes.ReductionType.Sum:
                beta = 1.0
            else:
                warnings.warn("Unsupported WCR in output of MatMul "
                              "library node: {}".format(c[0].data.wcr))
        gemm = Gemm(node.name + 'gemm',
                    location=node.location,
                    alpha=1.0,
                    beta=beta)
        return gemm
    elif len(size_b) == 3 and (len(size_a) in [2, 3]):
        # Batched matrix and matrix -> batched matrix multiplication
        from dace.libraries.blas.nodes.batched_matmul import BatchedMatMul
        batched = BatchedMatMul(node.name + 'bmm', location=node.location)
        return batched
    elif len(size_a) == 2 and len(size_b) == 1:
        # Matrix and vector -> GEMV
        from dace.libraries.blas.nodes.gemv import Gemv
        # Rename inputs to match GEMV's connector naming
        a[0].dst_conn = "_A"
        b[0].dst_conn = "_x"
        c[0].src_conn = "_y"
        gemv = Gemv(node.name + 'gemv', location=node.location)
        return gemv
    elif len(size_a) == 1 and len(size_b) == 2:
        # Vector and matrix -> GEMV with transposed matrix
        from dace.libraries.blas.nodes.gemv import Gemv
        # Rename inputs to match GEMV's connector naming
        a[0].dst_conn = "_x"
        b[0].dst_conn = "_A"
        c[0].src_conn = "_y"
        gemv = Gemv(node.name + 'gemvt', location=node.location, transA=True)
        return gemv
    elif len(size_a) == 1 and len(size_b) == 1:
        # Vector and vector -> dot product
        from dace.libraries.blas.nodes.dot import Dot
        # Rename inputs to match Dot's connector naming
        a[0].dst_conn = "_x"
        b[0].dst_conn = "_y"
        c[0].src_conn = "_result"
        dot = Dot(node.name + 'dot', location=node.location)
        return dot
    else:
        raise NotImplementedError("Matrix multiplication not implemented "
                                  "for shapes: {} and {}".format(
                                      size_a, size_b))