def _BoolOp(self, t):
    """Unparse a boolean operation (``and`` / ``or``) as nested SVE calls.

    If only scalars are involved, the base unparser handles the node.
    Otherwise every operand must infer to a ``dtypes.vector`` type, and an
    n-ary operation such as ``x or y or z`` is emitted as right-nested binary
    calls::

        <fn>_z(<pred>, x, <fn>_z(<pred>, y, z))

    where ``<fn>`` comes from ``util.BOOL_OP_TO_SVE`` for the operator class
    and ``<pred>`` is ``self.pred_name``.

    :param t: The ``ast.BoolOp`` node to unparse.
    :raises util.NotSupportedError: If any operand does not infer to a
        vector type.
    """
    if util.only_scalars_involed(self.get_defined_symbols(), *t.values):
        return super()._BoolOp(t)

    # Bool ops become nested SVE instructions, so every operand must be a vector.
    operand_types = self.infer(*t.values)
    if any(not isinstance(tp, dtypes.vector) for tp in operand_types):
        raise util.NotSupportedError('Non-vectorizable boolean operation')

    sve_op = util.BOOL_OP_TO_SVE[t.op.__class__]

    # Split by position rather than comparing nodes. A value-based "is this
    # the last element?" test stops early if the same node object occurs more
    # than once in t.values.
    *nested, last = t.values

    # Each non-final operand opens one binary call and becomes its first argument.
    for operand in nested:
        self.write('{}_z({}, '.format(sve_op, self.pred_name))
        self.dispatch(operand)
        self.write(', ')

    # The final operand is the innermost second argument and is not wrapped.
    self.dispatch(last)

    # Close every call opened above (one per non-final operand).
    self.write(')' * len(nested))
def visit_BinOp(self, t):
    """Replace multiply-add/subtract patterns on vectors with fused SVE calls.

    The children of ``t`` are transformed first (bottom-up), so nested
    fusable sub-expressions are rewritten as well. Then, if ``t`` matches one
    of the patterns below and at most one of the three operands infers to a
    scalar type, ``t`` is replaced by a call ``__sv<op>_<suffix>(x, y, z)``.
    Here ``<suffix>`` is ``util.TYPE_TO_SVE_SUFFIX`` of the base type of the
    operands' common result type.

    Patterns, using ARM SVE ACLE operand semantics:
      * MAD: ``op1 * op2 + op3``, used for ``a * b + c``
      * MLA: ``op1 + op2 * op3``, used for ``a + b * c``
      * MLS: ``op1 - op2 * op3``, used for ``a - b * c``

    ``a * b - c`` is deliberately left unfused. MSB computes
    ``op3 - op1 * op2``, so passing ``(a, b, c)`` would yield ``c - a * b``
    (the wrong sign).

    :param t: The ``ast.BinOp`` node being visited.
    :return: A fused ``ast.Call`` node, or ``t`` with its children transformed.
    """
    # Visit each child exactly once and write the transformed nodes back
    # into t. This keeps the traversal linear and preserves fusions made
    # inside operands that are later passed to a fused call.
    t = self.generic_visit(t)

    if util.only_scalars_involed(self.defined_symbols, t.left, t.right):
        return t

    parent_op = t.op.__class__
    left_op = t.left.op.__class__ if isinstance(t.left, ast.BinOp) else None
    right_op = t.right.op.__class__ if isinstance(t.right, ast.BinOp) else None

    name = None
    args = []
    if parent_op == ast.Add:
        if left_op == ast.Mult:
            # a * b + c  ->  MAD(a, b, c)
            name = '__svmad_'
            args = [t.left.left, t.left.right, t.right]
        elif right_op == ast.Mult:
            # a + b * c  ->  MLA(a, b, c)
            name = '__svmla_'
            args = [t.left, t.right.left, t.right.right]
    elif parent_op == ast.Sub and right_op == ast.Mult:
        # a - b * c  ->  MLS(a, b, c). This also covers a * b - c * d.
        name = '__svmls_'
        args = [t.left, t.right.left, t.right.right]

    if name is None:
        return t

    # Fused ops need at least two of the three arguments to be vectors.
    inferred = util.infer_ast(self.defined_symbols, *args)
    if sum(util.is_scalar(tp) for tp in inferred) > 1:
        return t

    # Append the type suffix for the internal representation.
    name += util.TYPE_TO_SVE_SUFFIX[util.get_base_type(
        dace.dtypes.result_type_of(*inferred))]
    return ast.copy_location(
        ast.Call(func=ast.Name(name, ast.Load()), args=args, keywords=[]), t)
def _IfExp(self, t):
    """Unparse ``body if test else orelse``.

    Scalar-only expressions are delegated to the base unparser. Otherwise
    the node is emitted as ``svsel(test, body, orelse)``:

    * ``test`` is dispatched as ``dtypes.vector(dace.bool, -1)``.
    * Both branches are dispatched as the common result type of the two
      branches, wrapped in ``dtypes.vector(..., -1)`` if not already a vector.
    """
    symbols = self.get_defined_symbols()
    if util.only_scalars_involed(symbols, t.test, t.body, t.orelse):
        return super()._IfExp(t)

    body_type, orelse_type = self.infer(t.body, t.orelse)
    common = dtypes.result_type_of(body_type, orelse_type)
    if isinstance(common, dtypes.vector):
        branch_type = common
    else:
        branch_type = dtypes.vector(common, -1)

    self.write('svsel(')
    self.dispatch_expect(t.test, dtypes.vector(dace.bool, -1))
    for branch in (t.body, t.orelse):
        self.write(', ')
        self.dispatch_expect(branch, branch_type)
    self.write(')')