def visit_BinOp(self, t):
    """
    Replace vectorized multiply-add/subtract patterns with fused SVE calls.

    The operands are transformed first (bottom-up) and stored back into ``t``
    via ``generic_visit``. Each subtree is therefore visited exactly once, and
    fused operations created in sub-expressions are kept in the result.

    Fused patterns (the emitted name gets a type suffix appended):
      * ``a * b + c`` -> ``__svmad_<suffix>(a, b, c)``
      * ``a + b * c`` -> ``__svmla_<suffix>(a, b, c)``
      * ``a - b * c`` -> ``__svmls_<suffix>(a, b, c)``

    :param t: The ``ast.BinOp`` node being visited.
    :return: Either ``t`` itself (with transformed operands) or an
             ``ast.Call`` to a fused operation that replaces it.
    """
    # Visit (and replace in place) the left operand, the operator and the right
    # operand exactly once. Returning ``t`` below is equivalent to returning
    # ``self.generic_visit(t)`` without re-traversing the operands.
    self.generic_visit(t)

    if util.only_scalars_involed(self.defined_symbols, t.left, t.right):
        return t

    # Detect fused operations (semantics of the SVE intrinsics):
    # MAD: multiply the first two inputs and add the result to the third input.
    # MLA: multiply the second and third inputs and add the result to the first input.
    # MLS: multiply the second and third inputs and subtract the result from the first input.
    # MSB (multiply the first two inputs and subtract the result from the third)
    # is deliberately not used: mapping ``a * b - c`` onto it would compute
    # ``c - a * b``, i.e. the wrong sign. Such expressions stay plain operations.
    parent_op = type(t.op)
    left_op = type(t.left.op) if isinstance(t.left, ast.BinOp) else None
    right_op = type(t.right.op) if isinstance(t.right, ast.BinOp) else None

    name = None
    args = []
    if parent_op is ast.Add:
        if left_op is ast.Mult:
            # a * b + c
            name = '__svmad_'
            args = [t.left.left, t.left.right, t.right]
        elif right_op is ast.Mult:
            # a + b * c
            name = '__svmla_'
            args = [t.left, t.right.left, t.right.right]
    elif parent_op is ast.Sub and right_op is ast.Mult:
        # a - b * c
        name = '__svmls_'
        args = [t.left, t.right.left, t.right.right]

    if name is None:
        return t

    # Fused ops need at least two of three arguments to be a vector
    inferred = util.infer_ast(self.defined_symbols, *args)
    if sum(util.is_scalar(tp) for tp in inferred) > 1:
        return t

    # Add the type suffix for internal representation
    name += util.TYPE_TO_SVE_SUFFIX[util.get_base_type(dace.dtypes.result_type_of(*inferred))]
    return ast.copy_location(ast.Call(func=ast.Name(name, ast.Load()), args=args, keywords=[]), t)
def generate_out_register(self,
                          sdfg: SDFG,
                          state: SDFGState,
                          edge: graph.MultiConnectorEdge[mm.Memlet],
                          code: CodeIOStream,
                          use_data_name: bool = False) -> bool:
    """
    Responsible for generating temporary out registers in a Tasklet, given an outgoing edge.
    Returns `True` if a writeback of this register is needed.

    :param sdfg: The SDFG containing ``state``.
    :param state: The state containing ``edge``.
    :param edge: The outgoing edge of the Tasklet.
    :param code: The stream the register declaration is written to.
    :param use_data_name: If True, name the register after the memlet data
                          (``edge.data.data``) instead of the source connector.
    :return: ``True`` if a register was declared and must be written back;
             ``False`` if the edge has no source connector, ends in a stream,
             or carries a WCR (these are handled elsewhere).
    :raises util.NotSupportedError: if the connector type is neither a vector nor a scalar.
    """
    if edge.src_conn is None:
        # Nothing leaves through a connector, so nothing needs a writeback.
        # Return an explicit bool to honor the annotated return type.
        return False

    dst_node = state.memlet_path(edge)[-1].dst
    src_type = edge.src.out_connectors[edge.src_conn]
    src_name = edge.data.data if use_data_name else edge.src_conn

    if isinstance(dst_node, nodes.AccessNode) and isinstance(dst_node.desc(sdfg), data.Stream):
        # Streams don't need writeback and are treated differently
        self.stream_associations[edge.src_conn] = (edge.data.data, src_type.base_type)
        return False
    if edge.data.wcr is not None:
        # WCR is addressed within the unparser to capture conditionals
        self.wcr_associations[edge.src_conn] = (dst_node, edge, src_type.base_type)
        return False

    # Create temporary registers
    if util.is_vector(src_type):
        ctype = util.TYPE_TO_SVE[src_type.type]
    elif util.is_scalar(src_type):
        ctype = src_type.ctype
    else:
        raise util.NotSupportedError('Unsupported Code->Code edge (pointer)')

    self.dispatcher.defined_vars.add(src_name, DefinedType.Scalar, ctype)
    code.write(f'{ctype} {src_name};')
    return True
def _UnaryOp(self, t):
    """
    Unparse a unary operation, emitting an SVE intrinsic for vector operands.

    Scalar operands are delegated to the base class unparser. A unary ``+`` on
    a vector is a no-op, so only its operand is emitted. Other operators are
    written as ``<intrinsic>_x(<predicate>, <operand>)``.

    :param t: The ``ast.UnaryOp`` node.
    :raises NotImplementedError: if the operator has no entry in ``util.UN_OP_TO_SVE``.
    """
    inf_type = self.infer(t.operand)[0]
    if util.is_scalar(inf_type):
        return super()._UnaryOp(t)

    if isinstance(t.op, ast.UAdd):
        # A + in front is just ignored. The operand has to be dispatched by the
        # unparser itself: AST nodes have no ``dispatch`` method.
        self.dispatch(t.operand)
        return

    op_class = t.op.__class__
    if op_class not in util.UN_OP_TO_SVE:
        raise NotImplementedError(f'Unary operation {op_class.__name__} not implemented')

    self.write('{}_x({}, '.format(util.UN_OP_TO_SVE[op_class], self.pred_name))
    self.dispatch(t.operand)
    self.write(')')
def generate_writeback(self, sdfg: SDFG, state: SDFGState, map: nodes.Map,
                       edge: graph.MultiConnectorEdge[mm.Memlet], code: CodeIOStream):
    """
    Responsible for generating code for a writeback in a Tasklet, given the outgoing edge.
    This is mainly taking the temporary register and writing it back.

    :param sdfg: The SDFG containing ``state``.
    :param state: The state containing ``edge``.
    :param map: The map used to compute the memlet stride for vector stores.
    :param edge: The outgoing edge of the Tasklet.
    :param code: The stream the generated code is written to.
    :raises util.NotSupportedError: for unsupported source/destination combinations.
    """
    if edge.src_conn is None:
        return

    dst_node = state.memlet_path(edge)[-1].dst
    src_type = edge.src.out_connectors[edge.src_conn]
    src_name = edge.src_conn

    if isinstance(dst_node, nodes.Tasklet):
        ##################
        # Code->Code edges
        dst_type = edge.dst.in_connectors[edge.dst_conn]
        if ((util.is_vector(src_type) and util.is_vector(dst_type))
                or (util.is_scalar(src_type) and util.is_scalar(dst_type))):
            # Simply write back to shared register
            code.write(f'{edge.data.data} = {src_name};')
        elif util.is_scalar(src_type) and util.is_vector(dst_type):
            # Scalar broadcast to shared vector register
            code.write(f'{edge.data.data} = svdup_{util.TYPE_TO_SVE_SUFFIX[dst_type.type]}({src_name});')
        else:
            raise util.NotSupportedError('Unsupported Code->Code edge')
    elif isinstance(dst_node, nodes.AccessNode):
        ##################
        # Write to AccessNode
        desc = dst_node.desc(sdfg)
        if isinstance(desc, data.Array):
            ##################
            # Write into Array
            if util.is_pointer(src_type):
                raise util.NotSupportedError('Unsupported writeback')
            elif util.is_vector(src_type):
                ##################
                # Scatter vector store into array
                stride = edge.data.get_stride(sdfg, map)

                # long long fix: 64-bit integer pointers are cast explicitly
                # (presumably the generated C type differs from the one the
                # intrinsics expect -- NOTE(review): confirm)
                ptr_cast = ''
                if src_type.type == np.int64:
                    ptr_cast = '(int64_t*) '
                elif src_type.type == np.uint64:
                    ptr_cast = '(uint64_t*) '

                store_args = '{}, {}'.format(
                    util.get_loop_predicate(sdfg, state, edge.src),
                    ptr_cast + cpp.cpp_ptr_expr(sdfg, edge.data, DefinedType.Pointer, codegen=self.frame),
                )

                if stride == 1:
                    code.write(f'svst1({store_args}, {src_name});')
                else:
                    code.write(
                        f'svst1_scatter_index({store_args}, svindex_s{util.get_base_type(src_type).bytes * 8}(0, {sym2cpp(stride)}), {src_name});'
                    )
            else:
                ##################
                # Scalar write into array
                code.write(f'{cpp.cpp_array_expr(sdfg, edge.data, codegen=self.frame)} = {src_name};')
        elif isinstance(desc, data.Scalar):
            ##################
            # Write into Scalar
            if util.is_pointer(src_type):
                raise util.NotSupportedError('Unsupported writeback')
            elif util.is_vector(src_type):
                if util.is_vector(desc.dtype):
                    ##################
                    # Vector write into vector Scalar access node
                    code.write(f'{edge.data.data} = {src_name};')
                else:
                    raise util.NotSupportedError('Unsupported writeback')
            else:
                if util.is_vector(desc.dtype):
                    ##################
                    # Broadcast into scalar AccessNode. The svdup suffix must
                    # match the destination vector's element type (as in the
                    # Code->Code broadcast above), keyed by its numpy type.
                    code.write(f'{edge.data.data} = svdup_{util.TYPE_TO_SVE_SUFFIX[desc.dtype.type]}({src_name});')
                else:
                    ##################
                    # Scalar write into scalar AccessNode
                    code.write(f'{edge.data.data} = {src_name};')
        # NOTE(review): other data descriptors produce no writeback code here.
    else:
        raise util.NotSupportedError('Only writeback to Tasklets and AccessNodes is supported')
def generate_read(self, sdfg: SDFG, state: SDFGState, map: nodes.Map,
                  edge: graph.MultiConnectorEdge[mm.Memlet], code: CodeIOStream):
    """
    Responsible for generating code for reads into a Tasklet, given the ingoing edge.

    The input connector is declared and initialized from the edge's source.
    The source is either another Tasklet (shared register) or an AccessNode.
    For an Array, the read is a pointer reference, a vector load/gather or a
    scalar element read. For a Scalar, it is a shared variable.

    :param sdfg: The SDFG containing ``state``.
    :param state: The state containing ``edge``.
    :param map: The map used to compute the memlet stride for vector loads.
    :param edge: The incoming edge of the Tasklet.
    :param code: The stream the generated code is written to.
    :raises util.NotSupportedError: for unsupported source/destination combinations.
    """
    if edge.dst_conn is None:
        return

    origin = state.memlet_path(edge)[0].src
    target_type = edge.dst.in_connectors[edge.dst_conn]
    target = edge.dst_conn

    def read_register(origin_type, error_message):
        # Initialize the connector from a shared register, broadcasting a
        # scalar into a vector when needed.
        if util.is_vector(origin_type) and util.is_vector(target_type):
            # Directly read from shared vector register
            code.write(f'{util.TYPE_TO_SVE[target_type.type]} {target} = {edge.data.data};')
        elif util.is_scalar(origin_type) and util.is_scalar(target_type):
            # Directly read from shared scalar register
            code.write(f'{target_type} {target} = {edge.data.data};')
        elif util.is_scalar(origin_type) and util.is_vector(target_type):
            # Scalar broadcast from shared scalar register
            code.write(f'{util.TYPE_TO_SVE[target_type.type]} {target} = '
                       f'svdup_{util.TYPE_TO_SVE_SUFFIX[target_type.type]}({edge.data.data});')
        else:
            raise util.NotSupportedError(error_message)

    if isinstance(origin, nodes.Tasklet):
        # Code->Code edges
        read_register(edge.src.out_connectors[edge.src_conn], 'Unsupported Code->Code edge')
    elif isinstance(origin, nodes.AccessNode):
        desc = origin.desc(sdfg)
        if isinstance(desc, data.Array):
            if util.is_pointer(target_type):
                # Pointer reference
                code.write(
                    f'{target_type} {target} = {cpp.cpp_ptr_expr(sdfg, edge.data, None, codegen=self.frame)};')
            elif util.is_vector(target_type):
                # Vector load (contiguous) or gather (strided)
                stride = edge.data.get_stride(sdfg, map)
                declaration = f'{util.TYPE_TO_SVE[target_type.type]} {target}'

                # 64-bit integer pointers are cast explicitly (long long issue)
                if target_type.type == np.int64:
                    cast = '(int64_t*) '
                elif target_type.type == np.uint64:
                    cast = '(uint64_t*) '
                else:
                    cast = ''

                predicate = util.get_loop_predicate(sdfg, state, edge.dst)
                pointer = cast + cpp.cpp_ptr_expr(sdfg, edge.data, DefinedType.Pointer, codegen=self.frame)
                common_args = f'{predicate}, {pointer}'

                if stride == 1:
                    code.write(f'{declaration} = svld1({common_args});')
                else:
                    element_bits = util.get_base_type(target_type).bytes * 8
                    code.write(f'{declaration} = svld1_gather_index({common_args}, '
                               f'svindex_s{element_bits}(0, {sym2cpp(stride)}));')
            else:
                # Scalar read from array
                code.write(f'{target_type} {target} = {cpp.cpp_array_expr(sdfg, edge.data, codegen=self.frame)};')
        elif isinstance(desc, data.Scalar):
            # Refer to shared variable
            read_register(desc.dtype, 'Unsupported Scalar->Code edge')
    else:
        raise util.NotSupportedError('Only copy from Tasklets and AccessNodes is supported')