Example #1
0
    def visit_BinOp(self, t):
        """Replace a multiply/add or multiply/sub pattern with a fused SVE call.

        Recognized patterns (the internal name is suffixed below with the SVE
        type suffix of the operands' result type):

        * ``a * b + c``  ->  ``__svmad_(a, b, c)``   MAD: op1 * op2 + op3
        * ``c + a * b``  ->  ``__svmla_(c, a, b)``   MLA: op1 + op2 * op3
        * ``c - a * b``  ->  ``__svmls_(c, a, b)``   MLS: op1 - op2 * op3

        ``a * b - c`` is deliberately NOT fused. MSB computes
        ``op3 - op1 * op2``, so emitting ``__svmsb_(a, b, c)`` for it would
        evaluate ``c - a * b``, the negation of the source expression. The only
        shape MSB matches correctly (``c - a * b``) is already covered by MLS.

        Args:
            t: The ``ast.BinOp`` node being visited.

        Returns:
            An ``ast.Call`` to the fused internal function when a pattern
            matches and at most one of its three operands is a scalar.
            Otherwise, the result of ``self.generic_visit(t)``.
        """
        # NOTE(review): the results of these two visits are discarded, and the
        # non-fused paths below visit the children again via generic_visit.
        # Kept as-is to preserve the existing traversal behavior.
        self.visit(t.left)
        self.visit(t.right)

        if util.only_scalars_involed(self.defined_symbols, t.left, t.right):
            return self.generic_visit(t)

        # Detect fused operations (operand order follows the SVE intrinsics):
        # MAD: multiply the first two inputs and add the result to the third.
        # MLA: multiply the second and third inputs and add the result to the first.
        # MSB: multiply the first two inputs and subtract the result from the third.
        # MLS: multiply the second and third inputs and subtract the result from the first.
        parent_op = t.op.__class__
        left_op = t.left.op.__class__ if isinstance(t.left, ast.BinOp) else None
        right_op = t.right.op.__class__ if isinstance(t.right, ast.BinOp) else None

        name = None
        args = []

        if parent_op == ast.Add:
            if left_op == ast.Mult:
                # a * b + c  ->  MAD(a, b, c)
                name = '__svmad_'
                args = [t.left.left, t.left.right, t.right]
            elif right_op == ast.Mult:
                # c + a * b  ->  MLA(c, a, b)
                name = '__svmla_'
                args = [t.left, t.right.left, t.right.right]
        elif parent_op == ast.Sub and right_op == ast.Mult:
            # c - a * b  ->  MLS(c, a, b). Also handles a * b - c * d correctly,
            # as MLS(a * b, c, d).
            name = '__svmls_'
            args = [t.left, t.right.left, t.right.right]

        if not name:
            return self.generic_visit(t)

        # Fused ops need at least two of the three arguments to be a vector.
        inferred = util.infer_ast(self.defined_symbols, *args)
        scalar_args = sum(util.is_scalar(tp) for tp in inferred)
        if scalar_args > 1:
            return self.generic_visit(t)

        # Add the type suffix for the internal representation.
        name += util.TYPE_TO_SVE_SUFFIX[util.get_base_type(
            dace.dtypes.result_type_of(*inferred))]
        return ast.copy_location(
            ast.Call(func=ast.Name(name, ast.Load()),
                     args=args,
                     keywords=[]), t)
Example #2
0
 def infer(self, *args) -> tuple:
     """Infer the types of the given AST nodes.

     Delegates to ``util.infer_ast`` with this object's defined symbols.
     """
     symbols = self.get_defined_symbols()
     return util.infer_ast(symbols, *args)