def generate_out_register(self,
                          sdfg: SDFG,
                          state: SDFGState,
                          edge: graph.MultiConnectorEdge[mm.Memlet],
                          code: CodeIOStream,
                          use_data_name: bool = False) -> bool:
    """
    Generate the temporary output register of a Tasklet for an outgoing edge.

    Edges whose memlet path ends in a Stream access node, and edges that carry
    a write-conflict resolution (WCR), do not get a register here. They are
    recorded in ``self.stream_associations`` / ``self.wcr_associations``
    (keyed by the source connector) instead.

    :param sdfg: The SDFG containing ``state``.
    :param state: The state containing ``edge``.
    :param edge: The outgoing edge of the Tasklet.
    :param code: The stream the register declaration is written to.
    :param use_data_name: If True, name the register after the memlet's data
                          container instead of the source connector.
    :return: ``True`` if a writeback of this register is needed, ``False``
             otherwise (including edges without a source connector).
    :raises util.NotSupportedError: If the connector type is neither a vector
                                    nor a scalar (e.g., a pointer).
    """
    if edge.src_conn is None:
        # Return an explicit bool to honor the ``-> bool`` contract (a bare
        # ``return`` would yield None).
        return False

    dst_node = state.memlet_path(edge)[-1].dst
    src_type = edge.src.out_connectors[edge.src_conn]
    src_name = edge.data.data if use_data_name else edge.src_conn

    if isinstance(dst_node, nodes.AccessNode) and isinstance(dst_node.desc(sdfg), data.Stream):
        # Streams don't need writeback and are treated differently
        self.stream_associations[edge.src_conn] = (edge.data.data, src_type.base_type)
        return False
    if edge.data.wcr is not None:
        # WCR is addressed within the unparser to capture conditionals
        self.wcr_associations[edge.src_conn] = (dst_node, edge, src_type.base_type)
        return False

    # Create temporary registers: SVE vector type for vectors, C type for scalars
    if util.is_vector(src_type):
        ctype = util.TYPE_TO_SVE[src_type.type]
    elif util.is_scalar(src_type):
        ctype = src_type.ctype
    else:
        raise util.NotSupportedError('Unsupported Code->Code edge (pointer)')

    self.dispatcher.defined_vars.add(src_name, DefinedType.Scalar, ctype)
    code.write(f'{ctype} {src_name};')
    return True
def generate_writeback(self, sdfg: SDFG, state: SDFGState, map: nodes.Map,
                       edge: graph.MultiConnectorEdge[mm.Memlet], code: CodeIOStream):
    """
    Generate the writeback of a Tasklet's temporary output register.

    The destination is the last node on the memlet path of ``edge``:

    * Tasklet: assign the register to the shared register named after the
      memlet's data. A scalar feeding a vector connector is broadcast with
      ``svdup``.
    * AccessNode of an Array: vector registers are stored with ``svst1``
      (unit stride) or ``svst1_scatter_index`` (other strides) under the loop
      predicate. Scalar registers are assigned to the array element.
    * AccessNode of a Scalar: direct assignment, or an ``svdup`` broadcast
      when a scalar register is written into a vector-typed container.

    :param sdfg: The SDFG containing ``state``.
    :param state: The state containing ``edge``.
    :param map: Map passed to ``edge.data.get_stride`` to compute the store stride.
    :param edge: The outgoing edge of the Tasklet.
    :param code: The stream the writeback code is written to.
    :raises util.NotSupportedError: On unsupported type combinations or
                                    destination node types.
    """
    if edge.src_conn is None:
        return

    dst_node = state.memlet_path(edge)[-1].dst
    src_type = edge.src.out_connectors[edge.src_conn]
    src_name = edge.src_conn

    if isinstance(dst_node, nodes.Tasklet):
        ##################
        # Code->Code edges
        dst_type = edge.dst.in_connectors[edge.dst_conn]
        if ((util.is_vector(src_type) and util.is_vector(dst_type))
                or (util.is_scalar(src_type) and util.is_scalar(dst_type))):
            # Simply write back to shared register
            code.write(f'{edge.data.data} = {src_name};')
        elif util.is_scalar(src_type) and util.is_vector(dst_type):
            # Scalar broadcast to shared vector register
            code.write(f'{edge.data.data} = svdup_{util.TYPE_TO_SVE_SUFFIX[dst_type.type]}({src_name});')
        else:
            raise util.NotSupportedError('Unsupported Code->Code edge')

    elif isinstance(dst_node, nodes.AccessNode):
        ##################
        # Write to AccessNode
        desc = dst_node.desc(sdfg)
        if isinstance(desc, data.Array):
            ##################
            # Write into Array
            if util.is_pointer(src_type):
                raise util.NotSupportedError('Unsupported writeback')
            if util.is_vector(src_type):
                ##################
                # (Scatter) vector store into array
                stride = edge.data.get_stride(sdfg, map)

                # "long long fix": explicitly cast 64-bit integer pointers.
                # NOTE(review): presumably the generated pointer expression's C type
                # does not match the int64_t*/uint64_t* the SVE store intrinsics
                # expect - confirm.
                ptr_cast = ''
                if src_type.type == np.int64:
                    ptr_cast = '(int64_t*) '
                elif src_type.type == np.uint64:
                    ptr_cast = '(uint64_t*) '

                predicate = util.get_loop_predicate(sdfg, state, edge.src)
                pointer = ptr_cast + cpp.cpp_ptr_expr(sdfg, edge.data, DefinedType.Pointer, codegen=self.frame)
                store_args = f'{predicate}, {pointer}'

                if stride == 1:
                    code.write(f'svst1({store_args}, {src_name});')
                else:
                    index_bits = util.get_base_type(src_type).bytes * 8
                    code.write(f'svst1_scatter_index({store_args}, '
                               f'svindex_s{index_bits}(0, {sym2cpp(stride)}), {src_name});')
            else:
                ##################
                # Scalar write into array
                code.write(f'{cpp.cpp_array_expr(sdfg, edge.data, codegen=self.frame)} = {src_name};')

        elif isinstance(desc, data.Scalar):
            ##################
            # Write into Scalar
            if util.is_pointer(src_type):
                raise util.NotSupportedError('Unsupported writeback')
            if util.is_vector(src_type):
                if not util.is_vector(desc.dtype):
                    # A vector register cannot be written into a non-vector Scalar
                    raise util.NotSupportedError('Unsupported writeback')
                # Vector write into vector Scalar access node
                code.write(f'{edge.data.data} = {src_name};')
            elif util.is_vector(desc.dtype):
                # Broadcast a scalar register into a vector-typed Scalar AccessNode.
                # The suffix is looked up via ``.type`` of the destination vector,
                # like the Code->Code broadcast above and every other
                # TYPE_TO_SVE/TYPE_TO_SVE_SUFFIX lookup here. Indexing with the
                # typeclass object itself does not match those (presumably numpy-type)
                # keys.
                code.write(f'{edge.data.data} = svdup_{util.TYPE_TO_SVE_SUFFIX[desc.dtype.type]}({src_name});')
            else:
                # Scalar write into scalar AccessNode
                code.write(f'{edge.data.data} = {src_name};')
        # Other data descriptor types: no writeback code is emitted here.

    else:
        raise util.NotSupportedError('Only writeback to Tasklets and AccessNodes is supported')
def generate_read(self, sdfg: SDFG, state: SDFGState, map: nodes.Map, edge: graph.MultiConnectorEdge[mm.Memlet],
                  code: CodeIOStream):
    """
    Generate the code that reads an incoming edge into a Tasklet input register.

    The source is the first node on the memlet path of ``edge``:

    * Tasklet, or AccessNode of a Scalar: declare the input register from the
      shared register named after the memlet's data. A scalar source feeding a
      vector connector is broadcast with ``svdup``.
    * AccessNode of an Array: a pointer connector gets a pointer expression.
      A vector connector is loaded with ``svld1`` (unit stride) or
      ``svld1_gather_index`` (other strides) under the loop predicate. A scalar
      connector reads the array element.

    :param sdfg: The SDFG containing ``state``.
    :param state: The state containing ``edge``.
    :param map: Map passed to ``edge.data.get_stride`` to compute the load stride.
    :param edge: The incoming edge of the Tasklet.
    :param code: The stream the read code is written to.
    :raises util.NotSupportedError: On unsupported type combinations or
                                    source node types.
    """
    if edge.dst_conn is None:
        return

    src_node = state.memlet_path(edge)[0].src
    dst_type = edge.dst.in_connectors[edge.dst_conn]
    dst_name = edge.dst_conn

    def read_shared_register(src_type, error_message: str):
        # Shared by Code->Code and Scalar->Code edges: declare ``dst_name`` from
        # the shared register ``edge.data.data``, broadcasting scalar -> vector.
        if util.is_vector(src_type) and util.is_vector(dst_type):
            # Directly read from shared vector register
            code.write(f'{util.TYPE_TO_SVE[dst_type.type]} {dst_name} = {edge.data.data};')
        elif util.is_scalar(src_type) and util.is_scalar(dst_type):
            # Directly read from shared scalar register
            code.write(f'{dst_type} {dst_name} = {edge.data.data};')
        elif util.is_scalar(src_type) and util.is_vector(dst_type):
            # Scalar broadcast from shared scalar register
            code.write(f'{util.TYPE_TO_SVE[dst_type.type]} {dst_name} = '
                       f'svdup_{util.TYPE_TO_SVE_SUFFIX[dst_type.type]}({edge.data.data});')
        else:
            raise util.NotSupportedError(error_message)

    if isinstance(src_node, nodes.Tasklet):
        ##################
        # Code->Code edges
        read_shared_register(edge.src.out_connectors[edge.src_conn], 'Unsupported Code->Code edge')

    elif isinstance(src_node, nodes.AccessNode):
        ##################
        # Read from AccessNode
        desc = src_node.desc(sdfg)
        if isinstance(desc, data.Array):
            # Copy from array
            if util.is_pointer(dst_type):
                ##################
                # Pointer reference
                code.write(f'{dst_type} {dst_name} = {cpp.cpp_ptr_expr(sdfg, edge.data, None, codegen=self.frame)};')
            elif util.is_vector(dst_type):
                ##################
                # Vector load
                stride = edge.data.get_stride(sdfg, map)

                # First part of the declaration is `type name`
                load_lhs = f'{util.TYPE_TO_SVE[dst_type.type]} {dst_name}'

                # "long long issue casting": explicitly cast 64-bit integer pointers.
                # NOTE(review): presumably the generated pointer expression's C type
                # does not match the int64_t*/uint64_t* the SVE load intrinsics
                # expect - confirm.
                ptr_cast = ''
                if dst_type.type == np.int64:
                    ptr_cast = '(int64_t*) '
                elif dst_type.type == np.uint64:
                    ptr_cast = '(uint64_t*) '

                # Regular load and gather share the first arguments
                predicate = util.get_loop_predicate(sdfg, state, edge.dst)
                pointer = ptr_cast + cpp.cpp_ptr_expr(sdfg, edge.data, DefinedType.Pointer, codegen=self.frame)
                load_args = f'{predicate}, {pointer}'

                if stride == 1:
                    code.write(f'{load_lhs} = svld1({load_args});')
                else:
                    index_bits = util.get_base_type(dst_type).bytes * 8
                    code.write(f'{load_lhs} = svld1_gather_index({load_args}, '
                               f'svindex_s{index_bits}(0, {sym2cpp(stride)}));')
            else:
                ##################
                # Scalar read from array
                code.write(f'{dst_type} {dst_name} = {cpp.cpp_array_expr(sdfg, edge.data, codegen=self.frame)};')

        elif isinstance(desc, data.Scalar):
            # Refer to shared variable
            read_shared_register(desc.dtype, 'Unsupported Scalar->Code edge')
        # Other data descriptor types: no read code is emitted here.

    else:
        raise util.NotSupportedError('Only copy from Tasklets and AccessNodes is supported')