Example #1
0
    def _BoolOp(self, t):
        """Unparse a boolean operation (``and`` / ``or``) as nested SVE calls.

        Scalar-only operations are delegated to the base unparser. Otherwise
        every operand must infer to a vector type, and ``x op y op z`` is
        emitted right-nested as
        ``<op>_z(<pred>, x, <op>_z(<pred>, y, z))``, where ``<op>`` comes from
        ``util.BOOL_OP_TO_SVE`` and ``<pred>`` is ``self.pred_name``.

        Raises:
            util.NotSupportedError: if any operand does not infer to a vector.
        """
        if util.only_scalars_involed(self.get_defined_symbols(), *t.values):
            return super()._BoolOp(t)

        # Bool ops are nested SVE instructions, so they must all act on vectors
        for operand_type in self.infer(*t.values):
            if not isinstance(operand_type, dtypes.vector):
                raise util.NotSupportedError(
                    'Non-vectorizable boolean operation')

        # Every operand except the last opens one nested call. Slicing (rather
        # than comparing each node against t.values[-1]) keeps the number of
        # opened and closed brackets in sync even if a node object repeats.
        nested = t.values[:-1]
        for operand in nested:
            self.write('{}_z({}, '.format(util.BOOL_OP_TO_SVE[t.op.__class__],
                                          self.pred_name))
            self.dispatch(operand)
            self.write(', ')

        # The last operand is the innermost argument and needs no nesting
        if t.values:
            self.dispatch(t.values[-1])

        # Close one bracket per nested call
        self.write(')' * len(nested))
Example #2
0
    def visit_BinOp(self, t):
        """Replace vectorized multiply-add/subtract patterns by fused calls.

        The operands are visited first and the results are written back into
        ``t``. This keeps replacements made for a direct operand (e.g. both
        levels of ``a * b + (c * d + e)`` are fused), and no subtree is
        traversed twice.

        Internal fused functions and the source patterns mapped onto them:

        * MAD ``(x, y, z)`` = ``x * y + z``  <-  ``a * b + c``
        * MLA ``(x, y, z)`` = ``x + y * z``  <-  ``a + b * c``
        * MLS ``(x, y, z)`` = ``x - y * z``  <-  ``a - b * c``
        * MSB ``(x, y, z)`` = ``z - x * y`` is not emitted. The only pattern
          it fits, ``c - a * b``, is already covered by MLS, and ``a * b - c``
          would come out with the wrong sign.

        Returns:
            An ``ast.Call`` to ``__sv{mad,mla,mls}_<suffix>`` (``<suffix>``
            taken from ``util.TYPE_TO_SVE_SUFFIX``) when a fused operation
            applies, otherwise ``t`` with its operands transformed.
        """
        # Keep the transformed operands: discarding them loses any fused
        # replacement of a direct child.
        t.left = self.visit(t.left)
        t.right = self.visit(t.right)

        if util.only_scalars_involed(self.defined_symbols, t.left, t.right):
            return t

        parent_op = t.op.__class__
        left_op = t.left.op.__class__ if isinstance(t.left, ast.BinOp) else None
        right_op = (t.right.op.__class__
                    if isinstance(t.right, ast.BinOp) else None)

        name = None
        args = []
        if parent_op == ast.Add:
            if left_op == ast.Mult:
                name = '__svmad_'
                args = [t.left.left, t.left.right, t.right]
            elif right_op == ast.Mult:
                name = '__svmla_'
                args = [t.left, t.right.left, t.right.right]
        elif parent_op == ast.Sub and right_op == ast.Mult:
            # a - b * c -> MLS(a, b, c). ``a * b - c`` is intentionally NOT
            # mapped to MSB, which would compute c - a * b (negated result).
            # This also lets ``a * b - c * d`` fuse correctly as
            # MLS(a * b, c, d).
            name = '__svmls_'
            args = [t.left, t.right.left, t.right.right]

        if name is None:
            return t

        # Fused ops need at least two of the three arguments to be vectors
        inferred = util.infer_ast(self.defined_symbols, *args)
        if sum(util.is_scalar(tp) for tp in inferred) > 1:
            return t

        # Add the type suffix for the internal representation
        name += util.TYPE_TO_SVE_SUFFIX[util.get_base_type(
            dace.dtypes.result_type_of(*inferred))]
        return ast.copy_location(
            ast.Call(func=ast.Name(name, ast.Load()), args=args, keywords=[]),
            t)
Example #3
0
    def _IfExp(self, t):
        """Unparse ``body if test else orelse``.

        Scalar-only expressions go to the base unparser. Otherwise the
        expression becomes ``svsel(test, body, orelse)``. The test is emitted
        as a boolean vector, and both branches are emitted as vectors of the
        branches' common result type.
        """
        operands = (t.test, t.body, t.orelse)
        if util.only_scalars_involed(self.get_defined_symbols(), *operands):
            return super()._IfExp(t)

        body_type, orelse_type = self.infer(t.body, t.orelse)
        common_type = dtypes.result_type_of(body_type, orelse_type)
        # svsel selects between vectors; promote a scalar common type
        branch_type = (common_type if isinstance(common_type, dtypes.vector)
                       else dtypes.vector(common_type, -1))

        self.write('svsel(')
        expectations = [
            (t.test, dtypes.vector(dace.bool, -1)),
            (t.body, branch_type),
            (t.orelse, branch_type),
        ]
        for position, (node, expected) in enumerate(expectations):
            if position > 0:
                self.write(', ')
            self.dispatch_expect(node, expected)
        self.write(')')