Пример #1
0
 def expansion(node, state, sdfg):
     """Replace a MatMul library node with a BLAS node chosen by operand rank.

     The operands are obtained from ``_get_matmul_operands``; element ``[2]`` of
     each operand tuple is treated as its shape (only its length is used here)
     and element ``[0]`` is the memlet edge whose connector may be renamed.

     Dispatch on (rank of A, rank of B):

     * (2, 2)       -> ``Gemm`` with ``alpha=1.0``; ``beta=1.0`` when the output
                       memlet has a Sum write-conflict resolution, else ``0.0``
     * (2 or 3, 3)  -> ``BatchedMatMul``
     * (2, 1)       -> ``Gemv``
     * (1, 2)       -> ``Gemv`` with ``transA=True``
     * (1, 1)       -> ``Dot``

     For the GEMV and dot cases the edge connectors are renamed in place to the
     names the replacement node uses.

     :param node: The MatMul library node being expanded.
     :param state: The SDFG state containing ``node``.
     :param sdfg: The SDFG containing ``state``.
     :return: The replacement library node.
     :raises NotImplementedError: For any other combination of operand ranks.
     """
     lhs, rhs, out = _get_matmul_operands(node, state, sdfg)
     lhs_shape = lhs[2]
     rhs_shape = rhs[2]
     lhs_rank, rhs_rank = len(lhs_shape), len(rhs_shape)

     if lhs_rank == 2 and rhs_rank == 2:
         # Matrix @ matrix -> GEMM.
         from dace.libraries.blas.nodes.gemm import Gemm
         out_wcr = out[0].data.wcr
         beta = 0.0
         if out_wcr:
             # A Sum WCR on the output means "accumulate into C", which GEMM
             # expresses through beta = 1. Other reductions have no GEMM
             # equivalent, so only a warning is issued and beta stays 0.
             from dace.frontend import operations
             reduction = operations.detect_reduction_type(out_wcr)
             if reduction == dace.dtypes.ReductionType.Sum:
                 beta = 1.0
             else:
                 warnings.warn("Unsupported WCR in output of MatMul "
                               "library node: {}".format(out_wcr))
         return Gemm(node.name + 'gemm',
                     location=node.location,
                     alpha=1.0,
                     beta=beta)

     if rhs_rank == 3 and lhs_rank in (2, 3):
         # (Batched) matrix @ batched matrix -> batched matrix multiplication.
         from dace.libraries.blas.nodes.batched_matmul import BatchedMatMul
         return BatchedMatMul(node.name + 'bmm', location=node.location)

     if lhs_rank == 2 and rhs_rank == 1:
         # Matrix @ vector -> GEMV.
         from dace.libraries.blas.nodes.gemv import Gemv
         # Rename connectors to the Gemv names: matrix _A, vector _x, output _y.
         lhs[0].dst_conn = "_A"
         rhs[0].dst_conn = "_x"
         out[0].src_conn = "_y"
         return Gemv(node.name + 'gemv', location=node.location)

     if lhs_rank == 1 and rhs_rank == 2:
         # Vector @ matrix -> GEMV on the transposed matrix.
         from dace.libraries.blas.nodes.gemv import Gemv
         # Here the vector is the left operand, so it maps to _x and the matrix
         # (right operand) maps to _A.
         lhs[0].dst_conn = "_x"
         rhs[0].dst_conn = "_A"
         out[0].src_conn = "_y"
         return Gemv(node.name + 'gemvt',
                     location=node.location,
                     transA=True)

     if lhs_rank == 1 and rhs_rank == 1:
         # Vector . vector -> dot product.
         from dace.libraries.blas.nodes.dot import Dot
         # Rename connectors to the Dot names: inputs _x/_y, output _result.
         lhs[0].dst_conn = "_x"
         rhs[0].dst_conn = "_y"
         out[0].src_conn = "_result"
         return Dot(node.name + 'dot', location=node.location)

     raise NotImplementedError("Matrix multiplication not implemented "
                               "for shapes: {} and {}".format(
                                   lhs_shape, rhs_shape))
Пример #2
0
 def expansion(node, state, sdfg):
     """Expand a MatMul library node into the matching BLAS library node.

     The operands are obtained from ``_get_matmul_operands``; element ``[2]`` of
     each operand tuple is treated as its shape (only its length is used here)
     and element ``[0]`` is the memlet edge whose connector may be renamed.

     Dispatch on (rank of A, rank of B):

     * (2, 2)       -> ``Gemm`` with ``alpha=1.0``; ``beta=1.0`` when the output
                       memlet has a Sum write-conflict resolution, else ``0.0``
     * (2 or 3, 3)  -> ``BatchedMatMul``
     * (2, 1)       -> ``Gemv``
     * (1, 2)       -> ``Gemv`` with ``transA=True``
     * (1, 1)       -> ``Dot``

     For the GEMV and dot cases the edge connectors are renamed in place to the
     names the replacement node uses.

     :param node: The MatMul library node being expanded.
     :param state: The SDFG state containing ``node``.
     :param sdfg: The SDFG containing ``state``.
     :return: The replacement library node.
     :raises NotImplementedError: For any other combination of operand ranks.
     """
     a, b, c = _get_matmul_operands(node, state, sdfg)
     size_a = a[2]
     size_b = b[2]
     if len(size_a) == 2 and len(size_b) == 2:
         # Matrix and matrix -> GEMM
         from dace.libraries.blas.nodes.gemm import Gemm
         beta = 0.0
         out_wcr = c[0].data.wcr
         if out_wcr:
             # A WCR on the output memlet means the product must be combined
             # with the existing contents of C. Previously beta was always 0,
             # which silently overwrote C. A Sum reduction maps onto GEMM's
             # beta term; other reductions have no GEMM equivalent, so warn
             # and fall back to overwriting (beta = 0).
             import warnings
             from dace.dtypes import ReductionType
             from dace.frontend import operations
             redtype = operations.detect_reduction_type(out_wcr)
             if redtype == ReductionType.Sum:
                 beta = 1.0
             else:
                 warnings.warn("Unsupported WCR in output of MatMul "
                               "library node: {}".format(out_wcr))
         gemm = Gemm(node.name + 'gemm',
                     location=node.location,
                     alpha=1.0,
                     beta=beta)
         return gemm
     elif len(size_b) == 3 and (len(size_a) in [2, 3]):
         # Batched matrix and matrix -> batched matrix multiplication
         from dace.libraries.blas.nodes.batched_matmul import BatchedMatMul
         batched = BatchedMatMul(node.name + 'bmm',
                                 location=node.location)
         return batched
     elif len(size_a) == 2 and len(size_b) == 1:
         # Matrix and vector -> GEMV
         from dace.libraries.blas.nodes.gemv import Gemv
         # Rename connectors to the Gemv names: matrix _A, vector _x, output _y
         a[0].dst_conn = "_A"
         b[0].dst_conn = "_x"
         c[0].src_conn = "_y"
         gemv = Gemv(node.name + 'gemv', location=node.location)
         return gemv
     elif len(size_a) == 1 and len(size_b) == 2:
         # Vector and matrix -> GEMV with transposed matrix
         from dace.libraries.blas.nodes.gemv import Gemv
         # Rename connectors to the Gemv names; the vector is the left operand
         # here, so it maps to _x and the matrix (right operand) maps to _A
         a[0].dst_conn = "_x"
         b[0].dst_conn = "_A"
         c[0].src_conn = "_y"
         gemv = Gemv(node.name + 'gemvt',
                     location=node.location,
                     transA=True)
         return gemv
     elif len(size_a) == 1 and len(size_b) == 1:
         # Vector and vector -> dot product
         from dace.libraries.blas.nodes.dot import Dot
         # Rename connectors to the Dot names: inputs _x/_y, output _result
         a[0].dst_conn = "_x"
         b[0].dst_conn = "_y"
         c[0].src_conn = "_result"
         dot = Dot(node.name + 'dot', location=node.location)
         return dot
     else:
         raise NotImplementedError("Matrix multiplication not implemented "
                                   "for shapes: {} and {}".format(
                                       size_a, size_b))