def expansion(node, parent_state: SDFGState, parent_sdfg: SDFG):
    """
    Build the nested SDFG used as the expansion of ``node``.

    The returned SDFG has two non-transient containers, ``_a`` and ``_b``,
    whose descriptors are deep copies of the data attached to the first
    input edge and the first output edge of ``node``. A tasklet computing
    ``outp = inp + 1`` is placed inside a map over ``i = 0``.

    :param node: The node being expanded.
    :param parent_state: The state that contains ``node``.
    :param parent_sdfg: The SDFG whose ``arrays`` hold the outer descriptors.
    :return: The newly created SDFG.
    """
    first_in = parent_state.in_edges(node)[0]
    first_out = parent_state.out_edges(node)[0]

    sdfg = dace.SDFG("nested")

    # Copy the outer descriptors under the inner names and expose them as
    # non-transient containers.
    for inner_name, outer_edge in (("_a", first_in), ("_b", first_out)):
        outer_desc = parent_sdfg.arrays[outer_edge.data.data]
        sdfg.add_datadesc(inner_name, copy.deepcopy(outer_desc))
        sdfg.arrays[inner_name].transient = False

    state = sdfg.add_state()
    read_a = state.add_access("_a")
    write_b = state.add_access("_b")
    map_entry, map_exit = state.add_map("useless_map", {"i": "0"})
    tasklet = state.add_tasklet("add", {"inp"}, {"outp"}, "outp = inp + 1")

    # _a -> map entry -> tasklet
    state.add_edge(read_a, None, map_entry, None,
                   sdfg.make_array_memlet("_a"))
    state.add_edge(map_entry, None, tasklet, "inp",
                   sdfg.make_array_memlet("_a"))

    # tasklet -> map exit -> _b
    state.add_edge(tasklet, "outp", map_exit, None, dace.Memlet("_b[0]"))
    state.add_edge(map_exit, None, write_b, None, dace.Memlet("_b[0]"))

    # The edges above were added without scope connectors; let the SDFG
    # create them.
    sdfg.fill_scope_connectors()
    return sdfg
def expressions():
    """
    Return the pattern graphs of this transformation.

    :return: A one-element list with a state that contains only
             ``MergeSourceSinkArrays._array1``.
    """
    # Matching
    #   o  o
    pattern = SDFGState()
    pattern.add_node(MergeSourceSinkArrays._array1)
    return [pattern]
def can_be_applied(graph: dace.SDFGState,
                   candidate: Dict[pm.PatternNode, int],
                   expr_index: int,
                   sdfg: dace.SDFG,
                   permissive: bool = False):
    """
    Check whether the matched map only produces write-conflict-resolved
    outputs that each touch a single element.

    :param graph: The state containing the matched map.
    :param candidate: Mapping from pattern nodes to node IDs in ``graph``.
    :param expr_index: Index of the matched pattern expression (unused).
    :param sdfg: The SDFG whose ``arrays`` describe the written data.
    :param permissive: Unused.
    :return: True if every memlet leaving the map exit has a WCR and every
             written container is a scalar, or an array/view accessed with
             single-element subsets; False otherwise.
    """
    map_entry = graph.node(candidate[Reduction1Operation.map_entry])
    map_exit = graph.exit_node(map_entry)

    # Group the written subsets by data descriptor.
    outputs = dict()
    for _, _, _, _, memlet in graph.out_edges(map_exit):
        if not memlet.wcr:
            return False
        outputs.setdefault(sdfg.arrays[memlet.data], []).append(memlet.subset)

    for desc, accesses in outputs.items():
        if isinstance(desc, dace.data.Scalar):
            continue
        if not isinstance(desc, (dace.data.Array, dace.data.View)):
            return False
        for access in accesses:
            if access.num_elements() != 1:
                return False

    return True
def fusion(sdfg: dace.SDFG,
           graph: dace.SDFGState,
           subgraph: Union[SubgraphView, List[SubgraphView]] = None,
           **kwargs):
    """
    Apply ``SubgraphFusion`` to the outermost-scope maps of one or several
    subgraphs.

    :param sdfg: The SDFG that contains ``graph``.
    :param graph: The state to operate on; also used as the subgraph when
                  ``subgraph`` is falsy.
    :param subgraph: A subgraph or a list of subgraphs.
    :param kwargs: Attributes that are set on the ``SubgraphFusion``
                   instance before fusing.
    """
    subgraph = subgraph if subgraph else graph
    subgraphs = subgraph if isinstance(subgraph, list) else [subgraph]

    map_fusion = SubgraphFusion(subgraphs[0])
    for attr_name, attr_value in kwargs.items():
        setattr(map_fusion, attr_name, attr_value)

    for sg in subgraphs:
        is_view = isinstance(sg, SubgraphView)
        map_entries = helpers.get_outermost_scope_maps(sdfg, graph, sg)

        # Remove the map entries and their corresponding exits from the
        # subgraph view already before applying the transformation.
        if is_view:
            for entry in map_entries:
                sg.nodes().remove(entry)
                exit_node = graph.exit_node(entry)
                if exit_node in sg.nodes():
                    sg.nodes().remove(exit_node)

        print(f"Subgraph Fusion on map entries {map_entries}")
        map_fusion.fuse(sdfg, graph, map_entries)

        # Register the fused map entry in the subgraph view.
        if is_view:
            sg.nodes().append(map_fusion._global_map_entry)
def get_node_name_mapping(state: dace.SDFGState,
                          node: dace.nodes.LibraryNode):
    """
    Map the field names used inside ``node`` to the names of the outer data
    containers connected to it.

    Input connectors are expected to be named ``IN_<field>`` and output
    connectors ``OUT_<field>``; edges without a connector are ignored.

    :param state: The state that contains ``node``.
    :param node: The library node whose connectors are inspected.
    :return: A dictionary from field name to outer container name.
    :raises AssertionError: If a connector lacks the expected prefix, or if
                            the same field is connected to two different
                            containers. Raised explicitly so the checks
                            are not stripped when running with ``-O``.
    """
    name_mapping = dict()

    def record(connector, prefix, outer_name):
        # Strip the prefix and make sure the field maps to a single array.
        if not connector.startswith(prefix):
            raise AssertionError(
                f"connector '{connector}' of node '{node.name}' does not "
                f"start with '{prefix}'")
        internal_name = connector[len(prefix):]
        known_name = name_mapping.setdefault(internal_name, outer_name)
        if not known_name == outer_name:
            raise AssertionError(
                f"input and output of field '{internal_name}' to node "
                f"'{node.name}' refer to different arrays")

    for edge in state.in_edges(node):
        if edge.dst_conn is not None:
            record(edge.dst_conn, "IN_", edge.data.data)

    for edge in state.out_edges(node):
        if edge.src_conn is not None:
            record(edge.src_conn, "OUT_", edge.data.data)

    return name_mapping
def __init__(self,
             sdfg: SDFG,
             state: SDFGState,
             map_entry: nodes.MapEntry,
             vec_len,
             initial_constraints: Dict[Union[Tuple[nodes.Tasklet, str, bool],
                                             nodes.AccessNode], int] = None,
             flags: VectorInferenceFlags = None):
    """
    Builds a vector inference graph for a Map to infer vectorizable Tasklet
    connectors and AccessNodes in polynomial time.

    :param sdfg: The SDFG where the Map resides.
    :param state: The state where the Map resides.
    :param map_entry: The entry node of the Map.
    :param vec_len: The vector length that should be used when creating a
                    `dtypes.vector`.
    :param initial_constraints: A dictionary mapping from a connector
                                specified using `(node, name, is_input)` or
                                an `AccessNode` to either
                                `InferenceNode.Scalar` or
                                `InferenceNode.Vector`.
    :param flags: Additional flags to limit the vectorization.
    """
    super().__init__()

    self.sdfg = sdfg
    self.state = state

    # Scope contents without, and with, the map entry/exit nodes.
    inner_scope = state.scope_subgraph(map_entry,
                                       include_entry=False,
                                       include_exit=False)
    self.subgraph = inner_scope
    self.subgraph_with_scope = state.scope_subgraph(map_entry)

    self.map = map_entry.map

    # Infer connectors on the entire subgraph (including the entry and exit)
    self.inf: infer_types.TypeInferenceDict = infer_types.infer_connector_types(
        sdfg, state, self.subgraph_with_scope)

    # Use the innermost loop param
    self.param = self.map.params[-1]

    self.vec_len = vec_len

    # Stores a mapping from SDFG nodes/connectors to InferenceNode's
    # Used when constructing the internal inference graph
    self.conn_to_node = DefaultDict[Union[Tuple[nodes.Tasklet, str, bool],
                                          nodes.AccessNode],
                                    InferenceNode](lambda: None)

    self.flags = flags

    self._build()
    self._detect_constraints()

    # Apply the caller-supplied constraints last, on top of the detected ones.
    if initial_constraints is None:
        return
    for target, constraint in initial_constraints.items():
        self.set_constraint(target, constraint)
def unwire_access_node(
    state: SDFGState,
    left: HorizontalExecutionLibraryNode,
    access: dace.nodes.AccessNode,
    right: HorizontalExecutionLibraryNode,
) -> None:
    """
    Remove every edge going from ``access`` to ``right``, together with the
    connectors those edges use.

    :param state: The state that contains the nodes.
    :param left: Not used by this function.
    :param access: Source node of the edges to remove.
    :param right: Destination node of the edges to remove.
    """
    # De-duplicate first so that each edge is removed exactly once.
    for edge in set(state.edges_between(access, right)):
        state.remove_edge_and_connectors(edge)
def can_be_applied(self,
                   graph: dace.SDFGState,
                   expr_index: int,
                   sdfg: dace.SDFG,
                   permissive: bool = False):
    """
    Check the memlets entering and leaving the scope of ``self.map_entry``.

    The check fails if a non-WCR memlet entering the map exit has no memlet
    leaving the map entry with the same container and subset, if a written
    container is not an array or view, if a written subset has more than
    one element, or if none of the map parameters appears among the
    elements of a written subset's ``min_element()``.

    :param graph: The state containing the map.
    :param expr_index: Index of the matched pattern expression (unused).
    :param sdfg: The SDFG whose ``arrays`` describe the accessed data.
    :param permissive: Unused.
    :return: True if all checks pass, False otherwise.
    """
    entry = self.map_entry
    exit_node = graph.exit_node(entry)
    params = [dace.symbol(p) for p in entry.map.params]

    # Subsets on the edges leaving the map entry, grouped by descriptor.
    inputs = dict()
    for _, _, _, _, memlet in graph.out_edges(entry):
        if memlet.data:
            inputs.setdefault(sdfg.arrays[memlet.data],
                              []).append(memlet.subset)

    # Subsets on the edges entering the map exit, grouped by descriptor.
    outputs = dict()
    for _, _, _, _, memlet in graph.in_edges(exit_node):
        if memlet.is_empty():
            continue
        desc = sdfg.arrays[memlet.data]
        if not memlet.wcr:
            # Without a WCR, the same subset of the same container must
            # also appear on an edge leaving the map entry.
            if desc not in inputs:
                return False
            if not any(read == memlet.subset for read in inputs[desc]):
                return False
        outputs.setdefault(desc, []).append(memlet.subset)

    for desc, accesses in outputs.items():
        if not isinstance(desc, (dace.data.Array, dace.data.View)):
            return False
        for access in accesses:
            if access.num_elements() != 1:
                return False
            # Map parameters that do not occur as an index of this access.
            unmatched = set(params).difference(access.min_element())
            if len(unmatched) == len(params):
                return False

    return True
def validate(self, sdfg: dace.SDFG, state: dace.SDFGState):
    """
    Validate this node and derive its packing configuration.

    The node must have exactly one input and one output edge, both
    connected to streams. One stream's type must be a vector of the other's
    type, or both must be vectors whose lengths are integer multiples of
    each other (with a ratio greater than one).

    :param sdfg: The SDFG whose ``arrays`` and ``constants`` are consulted.
    :param state: The state that contains this node.
    :return: The tuple ``(in_edge, in_desc, out_edge, out_desc, is_pack,
             gear_factor)``. ``is_pack`` is True when the output is the
             wider side and False when the input is.
    :raises ValueError: If ``self.size`` evaluates to less than 1, or if
                        the node does not have exactly one input and one
                        output edge.
    :raises TypeError: If either side is not a stream, or if the two types
                       cannot be converted into each other.
    """
    try:
        size = dace.symbolic.evaluate(self.size, sdfg.constants)
        if size < 1:
            raise ValueError(f"Invalid size parameter for {self}: {size}")
    except TypeError:
        pass  # Not a constant

    in_edge = state.in_edges(self)
    if len(in_edge) != 1:
        raise ValueError(
            f"Expected only one input edge, found {len(in_edge)} edges.")
    out_edge = state.out_edges(self)
    if len(out_edge) != 1:
        raise ValueError(
            f"Expected only one output edge, found {len(out_edge)} edges.")

    in_edge = in_edge[0]
    in_desc = sdfg.arrays[in_edge.data.data]
    if not isinstance(in_desc, dace.data.Stream):
        raise TypeError(
            f"Expected input to be a stream, got {type(in_desc)}.")
    out_edge = out_edge[0]
    out_desc = sdfg.arrays[out_edge.data.data]
    if not isinstance(out_desc, dace.data.Stream):
        raise TypeError(
            f"Expected output to be a stream, got {type(out_desc)}.")

    # The type of one side must be a vector of the other, or a vector of the
    # same type with a vector size that is a multiple of the other
    in_is_vector = isinstance(in_desc.dtype, dace.vector)
    out_is_vector = isinstance(out_desc.dtype, dace.vector)
    if in_is_vector and in_desc.dtype.base_type == out_desc.dtype:
        is_pack = False  # Is unpack
        gear_factor = in_desc.dtype.veclen
    elif out_is_vector and out_desc.dtype.base_type == in_desc.dtype:
        is_pack = True
        gear_factor = out_desc.dtype.veclen
    elif (in_is_vector and out_is_vector
          and in_desc.veclen // out_desc.veclen > 1
          and in_desc.veclen % out_desc.veclen == 0):
        is_pack = False  # Is unpack
        gear_factor = in_desc.veclen // out_desc.veclen
    elif (in_is_vector and out_is_vector
          and out_desc.veclen // in_desc.veclen > 1
          and out_desc.veclen % in_desc.veclen == 0):
        is_pack = True
        gear_factor = out_desc.veclen // in_desc.veclen
    else:
        raise TypeError(
            f"Cannot gearbox between {in_desc.dtype} for {in_edge.dst_conn}"
            f" and {out_desc.dtype} for {out_edge.src_conn}.")

    return (in_edge, in_desc, out_edge, out_desc, is_pack, gear_factor)
def apply(self, state: SDFGState, sdfg: SDFG):
    """
    Remove the candidate containers reported by ``self._candidates`` from
    both the outer and the nested SDFG.

    :param state: The state that contains the nested SDFG node.
    :param sdfg: The outer SDFG.
    """
    nested_node = self.nsdfg
    candidates, candidate_nodes = self._candidates(nested_node)

    # Outer side: drop the memlet paths leaving the candidate connectors
    # and the data they point to.
    for outer_edge in state.out_edges(nested_node):
        if outer_edge.src_conn not in candidates:
            continue
        state.remove_memlet_path(outer_edge)
        sdfg.remove_data(outer_edge.data.data, validate=False)

    # Inner side: drop every memlet path that ends in a candidate node.
    for inner_state, inner_node in candidate_nodes:
        for incoming in inner_state.in_edges(inner_node):
            inner_state.remove_memlet_path(incoming)

    # Finally remove the candidate containers from the nested SDFG.
    for name in candidates:
        nested_node.sdfg.remove_data(name, validate=False)
def gemv_libnode(sdfg: SDFG,
                 state: SDFGState,
                 A,
                 B,
                 C,
                 alpha,
                 beta,
                 trans_a=False,
                 trans_b=False):
    """
    Add a ``Gemm`` library node to ``state`` that reads ``A`` and ``B`` and
    writes ``C``.

    :param sdfg: The SDFG being built (not used here).
    :param state: The state to which the nodes are added.
    :param A: Name of the container connected to ``_a``.
    :param B: Name of the container connected to ``_b``.
    :param C: Name of the container connected to ``_c``.
    :param alpha: Passed to the library node as ``alpha``.
    :param beta: Passed to the library node as ``beta``.
    :param trans_a: Passed to the library node as ``transA``.
    :param trans_b: Passed to the library node as ``transB``.
    :return: An empty list.
    """
    # Add nodes
    read_a = state.add_read(A)
    read_b = state.add_read(B)
    write_c = state.add_write(C)

    gemm_node = Gemm('gemm',
                     transA=trans_a,
                     transB=trans_b,
                     alpha=alpha,
                     beta=beta)
    state.add_node(gemm_node)

    # Connect nodes
    state.add_edge(read_a, None, gemm_node, '_a', mm.Memlet(A))
    state.add_edge(read_b, None, gemm_node, '_b', mm.Memlet(B))
    state.add_edge(gemm_node, '_c', write_c, None, mm.Memlet(C))

    # TODO: Bring back C as input connector if beta is not 0
    # if beta != 0:
    #     C_in = state.add_read(C)
    #     state.add_edge(C_in, None, libnode, '_c', mm.Memlet(C))

    return []
def op_repo_replacement(sdfg: SDFG, state: SDFGState, **kwargs):
    """
    Create the ONNX library node ``cls`` in ``state`` and connect it to the
    containers named in ``kwargs``.

    Keyword arguments are classified by name against ``dace_schema``:
    attributes are forwarded to the node constructor, inputs get a read
    access node and outputs get a write access node.

    :param sdfg: The SDFG that owns the containers.
    :param state: The state to which the nodes are added.
    :param kwargs: Attribute values and container names keyed by the schema
                   parameter name.
    :return: An empty list.
    """
    attrs = {
        name: value
        for name, value in kwargs.items()
        if name in dace_schema.attributes
    }

    onnx_node = cls(name=cls_name, **attrs)
    state.add_node(onnx_node)

    input_names = {p.name for p in dace_schema.inputs}
    output_names = {p.name for p in dace_schema.outputs}
    inputs = {
        name: arr_name
        for name, arr_name in kwargs.items() if name in input_names
    }
    outputs = {
        name: arr_name
        for name, arr_name in kwargs.items() if name in output_names
    }

    for inp, arr_name in inputs.items():
        read = state.add_read(arr_name)
        state.add_edge(read, None, onnx_node, inp,
                       sdfg.make_array_memlet(arr_name))

    for outp, arr_name in outputs.items():
        # Outputs are written by the node, so use a write access node.
        write = state.add_write(arr_name)
        state.add_edge(onnx_node, outp, write, None,
                       sdfg.make_array_memlet(arr_name))

    return []
def count_matmul(node: MatMul, symbols: Dict[str, Any],
                 state: dace.SDFGState) -> int:
    """
    Count the operations of a matrix multiplication node as twice the
    product of the evaluated extents (batch, if present, M, N and K).

    :param node: The matrix multiplication node.
    :param symbols: Symbol values passed to ``symeval``.
    :param state: The state that contains ``node``.
    :return: The evaluated operation count.
    """
    a_edge = next(e for e in state.in_edges(node) if e.dst_conn == '_a')
    # The '_b' edge does not contribute to the count, but looking it up
    # raises StopIteration when the node has no such input.
    next(e for e in state.in_edges(node) if e.dst_conn == '_b')
    c_edge = next(e for e in state.out_edges(node) if e.src_conn == '_c')

    c_subset = c_edge.data.subset
    c_size = c_subset.size()

    extents = []
    # Batch
    if len(c_subset) == 3:
        extents.append(c_size[0])
    # M*N
    extents.append(c_size[-2])
    extents.append(c_size[-1])
    # K
    extents.append(a_edge.data.subset.size()[-1])

    result = 2  # Multiply, add
    for extent in extents:
        result *= symeval(extent, symbols)
    return result
def expand_reduce(sdfg: dace.SDFG,
                  graph: dace.SDFGState,
                  subgraph: Union[SubgraphView, List[SubgraphView]] = None,
                  **kwargs):
    """
    Expand the ``Reduce`` nodes of one or several subgraphs with
    ``ReduceExpansion``.

    Nodes for which ``can_be_applied`` fails are reported with a printed
    warning and left untouched.

    :param sdfg: The SDFG that contains ``graph``.
    :param graph: The state to operate on; also used as the subgraph when
                  ``subgraph`` is falsy.
    :param subgraph: A subgraph or a list of subgraphs.
    :param kwargs: Attributes that are set on the ``ReduceExpansion``
                   instance used for expanding.
    """
    subgraph = subgraph if subgraph else graph
    subgraphs = subgraph if isinstance(subgraph, list) else [subgraph]

    for sg in subgraphs:
        # Collect the reduce nodes that can actually be expanded.
        expandable = []
        for node in sg.nodes():
            if not isinstance(node, stdlib.Reduce):
                continue
            probe = ReduceExpansion(
                sdfg, sdfg.sdfg_id, sdfg.node_id(graph),
                {ReduceExpansion.reduce: graph.node_id(node)}, 0)
            if probe.can_be_applied(graph, 0, sdfg):
                expandable.append(node)
            else:
                print(f"WARNING: Cannot expand reduce node {node}:"
                      "can_be_applied() failed.")

        expansion = ReduceExpansion(sdfg, sdfg.sdfg_id, sdfg.node_id(graph),
                                    {}, 0)
        for attr_name, attr_value in kwargs.items():
            setattr(expansion, attr_name, attr_value)

        for reduce_node in expandable:
            expansion.expand(sdfg, graph, reduce_node)
            if not isinstance(sg, SubgraphView):
                continue
            # Replace the expanded node in the view by the nodes the
            # expansion exposes.
            view_nodes = sg.nodes()
            view_nodes.remove(reduce_node)
            view_nodes.append(expansion._reduce)
            view_nodes.append(expansion._outer_entry)
def can_be_applied(graph: dace.SDFGState,
                   candidate: Dict[Any, int],
                   expr_index: int,
                   sdfg: dace.SDFG,
                   strict=False):
    """
    Check whether the matched one-dimensional map can be nested into the
    matched stencil.

    :param graph: The state containing the match.
    :param candidate: Mapping from pattern nodes to node IDs in ``graph``.
    :param expr_index: Index of the matched pattern expression (unused).
    :param sdfg: The SDFG; its ``constants`` are used when checking index
                 expressions.
    :param strict: Unused.
    :return: True if the map has a single parameter and no dynamic inputs,
             all edges leaving the map entry lead to the stencil, the map
             parameter is used in the same dimension of every
             three-dimensional memlet of the stencil, and the stencil shape
             is 1 in that dimension.
    """
    entry: nodes.MapEntry = graph.node(candidate[NestK._map_entry])
    stencil: Stencil = graph.node(candidate[NestK._stencil])

    if len(entry.map.params) != 1:
        return False
    if sd.has_dynamic_map_inputs(graph, entry):
        return False
    pname = entry.map.params[0]  # Usually "k"

    # The map may only feed the stencil.
    if any(edge.dst != stencil for edge in graph.out_edges(entry)):
        return False

    dim_index = None
    for edge in graph.all_edges(stencil):
        memlet = edge.data
        if memlet.data is None:  # Empty memlet
            continue
        # TODO: Use bitmap to verify lower-dimensional arrays
        if len(memlet.subset) != 3:
            continue
        for dim, rng in enumerate(memlet.subset.ndrange()):
            for expr in rng:
                if pname not in map(str, expr.free_symbols):
                    continue
                # k dimension must match in all memlets
                if dim_index is not None and dim_index != dim:
                    return False
                if str(expr) != pname and symbolic.issymbolic(
                        expr - symbolic.symbol(pname), sdfg.constants):
                    warnings.warn('k expression is nontrivial')
                dim_index = dim

    # No nesting dimension found
    if dim_index is None:
        return False

    # Ensure the stencil shape is 1 for the found dimension
    if stencil.shape[dim_index] != 1:
        return False

    return True
def can_be_applied(graph: dace.SDFGState,
                   candidate: Dict[Any, int],
                   expr_index: int,
                   sdfg: dace.SDFG,
                   strict=False):
    """
    Check whether the two matched stencils can be fused across the matched
    temporary array.

    :param graph: The state containing the match.
    :param candidate: Mapping from pattern nodes to node IDs in ``graph``.
    :param expr_index: Index of the matched pattern expression (unused).
    :param sdfg: The SDFG that contains ``graph``.
    :param strict: Unused.
    :return: True if all checks below pass, False otherwise.
    """
    first: Stencil = graph.node(candidate[StencilFusion._stencil_a])
    second: Stencil = graph.node(candidate[StencilFusion._stencil_b])
    array: nodes.AccessNode = graph.node(candidate[StencilFusion._tmp_array])

    # Ensure the stencil shapes match
    if len(first.shape) != len(second.shape):
        return False
    for extent_a, extent_b in zip(first.shape, second.shape):
        if extent_a != extent_b:
            return False

    # Ensure that the transient is not used anywhere else and can be
    # removed
    if len(graph.all_edges(array)) != 2:
        return False
    if not sdfg.arrays[array.data].transient:
        return False
    occurrences = sum(
        1 for state in sdfg.nodes() for n in state.nodes()
        if isinstance(n, nodes.AccessNode) and n.data == array.data)
    if occurrences > 1:
        return False

    # Ensure that second stencil only has one input access of the
    # candidate transient to remove
    consumer_edge = graph.out_edges(array)[0]
    offsets = second.accesses[consumer_edge.dst_conn][1]
    if len(offsets) > 1:
        return False

    # TODO: Remove check once stencils can be offset
    if any(component != 0 for component in offsets[0]):
        return False

    # Code languages must match
    if first.code.language != second.code.language:
        return False

    # TODO: Boundary condition matching checks

    return True
def op_repo_replacement(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState,
                        **kwargs):
    """
    Create the ONNX library node ``cls`` in ``state`` and connect it to the
    containers named in ``kwargs``.

    Keyword arguments are consumed in three passes against ``dace_schema``:
    attributes (forwarded to the node constructor), inputs and outputs.
    Variadic parameters are recognized by a ``__`` in the argument name.

    :param pv: The program visitor (not used here).
    :param sdfg: The SDFG that owns the containers.
    :param state: The state to which the nodes are added.
    :param kwargs: Attribute values and container names keyed by the schema
                   parameter name.
    :return: An empty list.
    :raises TypeError: If an argument is neither an attribute, an input nor
                       an output of the schema.
    """
    attrs = {
        name: value
        for name, value in kwargs.items()
        if name in dace_schema.attributes
    }
    # remove used attrs
    kwargs = {k: v for k, v in kwargs.items() if k not in attrs}

    onnx_node = cls(name=cls_name, **attrs)
    state.add_node(onnx_node)

    input_names = dace_schema.non_variadic_inputs()
    variadic_inputs = dace_schema.variadic_inputs()

    output_names = dace_schema.non_variadic_outputs()
    variadic_outputs = dace_schema.variadic_outputs()

    inputs = {
        name: arr_name
        for name, arr_name in kwargs.items()
        if (name in input_names or
            # variadic params
            ("__" in name
             and parse_variadic_param(name)[0] in variadic_inputs))
    }
    kwargs = {k: v for k, v in kwargs.items() if k not in inputs}

    outputs = {
        name: arr_name
        for name, arr_name in kwargs.items()
        if (name in output_names or
            # variadic params
            ("__" in name
             and parse_variadic_param(name)[0] in variadic_outputs))
    }
    kwargs = {k: v for k, v in kwargs.items() if k not in outputs}

    # Everything that is left matched nothing in the schema.
    if len(kwargs) > 0:
        raise TypeError(f"Unknown arguments {', '.join(kwargs)}")

    for inp, arr_name in inputs.items():
        read = state.add_read(arr_name)
        state.add_edge(read, None, onnx_node, inp,
                       sdfg.make_array_memlet(arr_name))
        onnx_node.add_in_connector(inp)

    for outp, arr_name in outputs.items():
        # Outputs are written by the node, so use a write access node.
        write = state.add_write(arr_name)
        state.add_edge(onnx_node, outp, write, None,
                       sdfg.make_array_memlet(arr_name))
        onnx_node.add_out_connector(outp)

    return []
def iter_edges(
        self,
        state: SDFGState) -> Iterator[Tuple[MultiConnectorEdge, bool]]:
    """
    Returns an iterator over tuples of an edge and a boolean that indicates
    whether that edge is an input, ordered by the order required by the
    schema. All input edges come first, followed by all output edges.

    This method assumes that this node has been validated.

    :param state: the state containing this node.
    :raises ValueError: if a connector name does not match exactly one
                        schema parameter.
    """
    def schema_position(parameters, conn):
        # Variadic connectors carry their index after a double underscore.
        if '__' in conn:
            base, offset = parse_variadic_param(conn)
        else:
            base, offset = conn, 0

        hits = [
            idx for idx, param in enumerate(parameters) if param.name == base
        ]

        # since validation passed, we know there will only be one
        if len(hits) != 1:
            raise ValueError(
                "Found {} connectors with name '{}', expected to find exactly one"
                .format(len(hits), base))

        # add on the variadic parameter index
        return hits[0] + offset

    ordered_in = list(state.in_edges(self))
    ordered_out = list(state.out_edges(self))
    ordered_in.sort(
        key=lambda e: schema_position(self.schema.inputs, e.dst_conn))
    ordered_out.sort(
        key=lambda e: schema_position(self.schema.outputs, e.src_conn))

    return itertools.chain(((edge, True) for edge in ordered_in),
                           ((edge, False) for edge in ordered_out))
def create_zero_initialization(init_state: dace.SDFGState, array_name):
    """
    Add a mapped tasklet to ``init_state`` that assigns 0 to every element
    of ``array_name``.

    :param init_state: The state that receives the initialization.
    :param array_name: Name of the container in the state's parent SDFG.
    """
    parent_sdfg = init_state.parent
    shape = parent_sdfg.arrays[array_name].shape

    write_node = init_state.add_write(array_name)

    # One map parameter i0, i1, ... per dimension, each over 0:<extent>.
    indices = ["i" + str(dim) for dim in range(len(shape))]
    ranges = {}
    for index, extent in zip(indices, shape):
        ranges[index] = "0:" + str(extent)

    out_memlet = dace.Memlet.simple(write_node.data, ",".join(indices))

    init_state.add_mapped_tasklet(name=(array_name + "_init_tasklet"),
                                  map_ranges=ranges,
                                  inputs={},
                                  code='val = 0',
                                  outputs=dict(val=out_memlet),
                                  output_nodes={array_name: write_node},
                                  external_edges=True)
def can_be_applied(graph: dace.SDFGState,
                   candidate: Dict[pm.PatternNode, int],
                   expr_index: int,
                   sdfg: dace.SDFG,
                   strict: bool = False):
    """
    Check whether the matched map has more than one parameter.

    :param graph: The state containing the matched map entry.
    :param candidate: Mapping from pattern nodes to node IDs in ``graph``.
    :param expr_index: Index of the matched pattern expression (unused).
    :param sdfg: The SDFG that contains ``graph`` (unused).
    :param strict: Unused.
    :return: True if the map is N-dimensional with N greater than one.
    """
    # A candidate subgraph matches the map-expansion pattern when it
    # includes an N-dimensional map, with N greater than one.
    entry = graph.node(candidate[MapExpansion.map_entry])
    dimensions = entry.map.get_param_num()
    return dimensions > 1
def get_connector_edges(dfg: dace.SDFGState, node: nodes.Node, conn: str,
                        is_in_conn: bool) -> graph.Edge:
    """
    Collect the edges attached to one connector of ``node``.

    NOTE(review): the return annotation names a single edge, but the
    function returns a list of edges.

    :param dfg: The state that contains ``node``.
    :param node: The node whose edges are inspected.
    :param conn: The connector name to look for.
    :param is_in_conn: If true, match edges that enter ``node`` through
                       ``conn``; otherwise match edges that leave ``node``
                       through ``conn``.
    :return: The matching edges, in the order of ``dfg.all_edges(node)``.
    """
    if is_in_conn:
        return [
            e for e in dfg.all_edges(node)
            if e.dst == node and e.dst_conn == conn
        ]
    return [
        e for e in dfg.all_edges(node)
        if e.src == node and e.src_conn == conn
    ]
def is_constant(sdfg: dace.SDFG, state: dace.SDFGState, node) -> bool:
    """
    Check whether ``node`` is an access node without incoming edges whose
    data is one of the weights of the parent ONNX model.

    :param sdfg: The SDFG; it must carry a ``_parent_onnx_model`` attribute.
    :param state: The state that contains ``node``.
    :param node: The node to check.
    :return: True if the node refers to a weight, False otherwise.
    """
    # Anything that is written to in this state is not a constant.
    if len(state.in_edges(node)) > 0:
        return False

    # the ONNX importer adds a _parent_onnx_model attribute to the sdfg
    return (isinstance(node, nd.AccessNode)
            and node.data in sdfg._parent_onnx_model.clean_weights)
def axpy_libnode(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, a, x,
                 y, result):
    """
    Add an ``Axpy`` library node to ``state`` that reads ``x`` and ``y``
    and writes ``result``.

    :param pv: The program visitor (not used here).
    :param sdfg: The SDFG being built (not used here).
    :param state: The state to which the nodes are added.
    :param a: Passed to the library node as ``a``.
    :param x: Name of the container connected to ``_x``.
    :param y: Name of the container connected to ``_y``.
    :param result: Name of the container connected to ``_res``.
    :return: An empty list.
    """
    # Add nodes
    read_x = state.add_read(x)
    read_y = state.add_read(y)
    write_result = state.add_write(result)

    axpy_node = Axpy('axpy', a=a)
    state.add_node(axpy_node)

    # Connect nodes
    state.add_edge(read_x, None, axpy_node, '_x', mm.Memlet(x))
    state.add_edge(read_y, None, axpy_node, '_y', mm.Memlet(y))
    state.add_edge(axpy_node, '_res', write_result, None, mm.Memlet(result))

    return []
def dot_libnode(sdfg: SDFG, state: SDFGState, x, y, result):
    """
    Add a ``Dot`` library node to ``state`` that reads ``x`` and ``y`` and
    writes ``result``.

    :param sdfg: The SDFG whose ``arrays`` provide the shape of ``x``.
    :param state: The state to which the nodes are added.
    :param x: Name of the container connected to ``_x``; the first entry of
              its shape is passed to the library node as ``n``.
    :param y: Name of the container connected to ``_y``.
    :param result: Name of the container connected to ``_result``.
    :return: An empty list.
    """
    # Add nodes
    read_x = state.add_read(x)
    read_y = state.add_read(y)
    write_result = state.add_write(result)

    length = sdfg.arrays[x].shape[0]
    dot_node = Dot('dot', n=length)
    state.add_node(dot_node)

    # Connect nodes
    state.add_edge(read_x, None, dot_node, '_x', mm.Memlet(x))
    state.add_edge(read_y, None, dot_node, '_y', mm.Memlet(y))
    state.add_edge(dot_node, '_result', write_result, None,
                   mm.Memlet(result))

    return []
def count_moved_data_state(state: dace.SDFGState, symbols: Dict[str,
                                                                Any]) -> int:
    """
    Sum the ``num_accesses`` of the memlets that enter and leave the
    top-level nodes of ``state``.

    Code nodes, library nodes and reduce nodes contribute their own input
    and output edges. Scope entry nodes contribute the input edges of the
    entry and the output edges of the matching exit node. Each edge is
    counted at most once.

    :param state: The state to analyze.
    :param symbols: Not used by this function.
    :return: The accumulated access count.
    """
    sdict = state.scope_dict(node_to_children=True)
    edges_counted = set()

    def uncounted_accesses(edges):
        # Accesses on the given edges that have not been accounted for yet.
        return sum(e.data.num_accesses for e in edges
                   if e not in edges_counted)

    result = 0
    for node in sdict[None]:
        node_result = 0
        if isinstance(node, (CodeNode, LibraryNode, Reduce)):
            inputs = uncounted_accesses(state.in_edges(node))
            outputs = uncounted_accesses(state.out_edges(node))
            # Do not count edges twice
            edges_counted.update(state.all_edges(node))

            iprint(
                type(node).__name__, node, 'inputs:', inputs, 'outputs:',
                outputs)
            node_result += inputs + outputs
        elif isinstance(node, dace.nodes.EntryNode):
            # Gather inputs from entry node
            inputs = uncounted_accesses(state.in_edges(node))
            # Do not count edges twice
            edges_counted.update(state.in_edges(node))

            # Gather outputs from exit node
            exit_node = state.exit_nodes(node)[0]
            outputs = uncounted_accesses(state.out_edges(exit_node))
            edges_counted.update(state.out_edges(exit_node))

            iprint('Scope',
                   type(node).__name__, node, 'inputs:', inputs, 'outputs:',
                   outputs)
            node_result += inputs + outputs
        result += node_result
    return result
def create_einsum(state: dace.SDFGState,
                  map_ranges,
                  code,
                  inputs,
                  outputs=None,
                  wcr_outputs=None):
    """
    Add a mapped tasklet named ``einsum_tasklet`` to ``state``.

    Tasklet connectors are named after the accessed container with an
    ``_inp`` or ``_out`` suffix. Memlets for ``wcr_outputs`` use the
    write-conflict resolution ``lambda x, y: x + y``.

    :param state: The state that receives the mapped tasklet.
    :param map_ranges: Map ranges forwarded to ``add_mapped_tasklet``.
    :param code: Tasklet code forwarded to ``add_mapped_tasklet``.
    :param inputs: Pairs of ``(access node, access range)`` to read.
    :param outputs: Pairs of ``(access node, access range)`` to write.
    :param wcr_outputs: Pairs of ``(access node, access range)`` to
                        accumulate into.
    """
    outputs = outputs or []
    wcr_outputs = wcr_outputs or []

    # Existing access nodes to connect, keyed by container name.
    read_nodes = {node.data: node for node, _ in inputs}
    write_nodes = {node.data: node for node, _ in (outputs + wcr_outputs)}

    in_memlets = {}
    for node, access_range in inputs:
        in_memlets[node.data + "_inp"] = dace.Memlet.simple(
            node.data, access_range)

    # Plain outputs first, then the accumulating ones; a later entry with
    # the same connector name replaces an earlier one.
    out_memlets = {}
    for node, access_range in outputs:
        out_memlets[node.data + "_out"] = dace.Memlet.simple(
            node.data, access_range)
    for node, access_range in wcr_outputs:
        out_memlets[node.data + "_out"] = dace.Memlet.simple(
            node.data, access_range, wcr_str='lambda x, y: x + y')

    state.add_mapped_tasklet(name="einsum_tasklet",
                             map_ranges=map_ranges,
                             inputs=in_memlets,
                             code=code,
                             outputs=out_memlets,
                             input_nodes=read_nodes,
                             output_nodes=write_nodes,
                             external_edges=True)
def _iter_params_in_onnx_order(
        self,
        state: SDFGState,
        inputs: bool = False) -> List[MultiConnectorEdge]:
    """
    Return the input or output edges of this node in the order of the
    schema parameters.

    If the last schema parameter is variadic, its connectors are expected
    to be named ``<name>__0``, ``<name>__1``, and so on. Only as many
    parameter names as there are edges are looked up.

    :param state: the state containing this node.
    :param inputs: if true, return the input edges; otherwise the output
                   edges.
    :return: the edges ordered by schema parameter. An empty list is
             returned when the schema declares no parameters on that side
             and the node has no edges there.
    :raises KeyError: if no edge is connected to an expected connector.
    """
    schema_params = list(
        self.schema.inputs if inputs else self.schema.outputs)
    edges = state.in_edges(self) if inputs else state.out_edges(self)

    # Guard against schemas without parameters on this side: indexing the
    # last parameter of an empty list would raise IndexError.
    if (schema_params
            and schema_params[-1].param_type == ONNXParameterType.Variadic):
        variadic_name = schema_params[-1].name
        names = itertools.chain(
            (param.name for param in schema_params[:-1]),
            (variadic_name + "__" + str(i) for i in itertools.count()))
    else:
        names = (param.name for param in schema_params)

    conn_to_edge = {
        (edge.dst_conn if inputs else edge.src_conn): edge
        for edge in edges
    }

    # The variadic name stream is infinite, so cut it at the edge count.
    return [
        conn_to_edge[name] for name in itertools.islice(names, len(edges))
    ]
def dot_libnode(pv: 'ProgramVisitor',
                sdfg: SDFG,
                state: SDFGState,
                x,
                y,
                result,
                acctype=None):
    """
    Add a ``Dot`` library node to ``state`` that reads ``x`` and ``y`` and
    writes ``result``.

    :param pv: The program visitor (not used here).
    :param sdfg: The SDFG whose ``arrays`` provide the shape of ``x``.
    :param state: The state to which the nodes are added.
    :param x: Name of the container connected to ``_x``; the first entry of
              its shape is passed to the library node as ``n``.
    :param y: Name of the container connected to ``_y``.
    :param result: Name of the container connected to ``_result``.
    :param acctype: Passed to the library node as ``accumulator_type``.
    :return: An empty list.
    """
    # Add nodes
    read_x = state.add_read(x)
    read_y = state.add_read(y)
    write_result = state.add_write(result)

    length = sdfg.arrays[x].shape[0]
    dot_node = Dot('dot', n=length, accumulator_type=acctype)
    state.add_node(dot_node)

    # Connect nodes
    state.add_edge(read_x, None, dot_node, '_x', mm.Memlet(x))
    state.add_edge(read_y, None, dot_node, '_y', mm.Memlet(y))
    state.add_edge(dot_node, '_result', write_result, None,
                   mm.Memlet(result))

    return []
def map_descriptor(state: dace.SDFGState,
                   map_entry: dace.nodes.MapEntry) -> str:
    """
    Build a textual descriptor of a map from the labels of the tasklets
    that are direct successors of its entry node.

    The last two underscore-separated components of each tasklet label are
    dropped, and the shortened labels are joined with ``:``. The tasklets
    are collected in a set, so their order in the result is the set's
    iteration order.

    :param state: The state that contains the map.
    :param map_entry: The entry node of the map.
    :return: The joined descriptor string.
    """
    tasklets = {
        edge.dst
        for edge in state.out_edges(map_entry)
        if isinstance(edge.dst, dace.nodes.Tasklet)
    }
    return ":".join("_".join(tasklet.label.split("_")[:-2])
                    for tasklet in tasklets)
def frag_fill(pv: ProgramVisitor, sdfg: dace.SDFG, state: dace.SDFGState,
              frag: str, fill: Any) -> List[str]:
    """
    Add a C++ tasklet to ``state`` that calls ``wmma::fill_fragment`` on
    ``frag`` with the value ``fill``.

    :param pv: The program visitor (not used here).
    :param sdfg: The SDFG that owns ``frag``.
    :param state: The state to which the nodes are added.
    :param frag: Name of the array to fill.
    :param fill: What to fill the array with; inserted into the tasklet
                 code.
    :return: An empty list, since the function has no return value.
    """
    # Replacement functions receive the SDFG and the current state as the
    # first two arguments, followed by all the other arguments. Here we
    # treat them as two strings representing the array name to fill and
    # what to fill it with.

    # NOTE: If a slice is used in the `frag` argument, the Python frontend
    # automatically creates a new array for it, and uses the correct string
    # as the argument.
    fragment_node = state.add_write(frag)

    fill_code = ''' wmma::fill_fragment(out, %s);''' % fill
    fill_tasklet = state.add_tasklet('fill',
                                     set(), {'out'},
                                     fill_code,
                                     language=dace.Language.CPP)

    fragment_memlet = dace.Memlet.from_array(frag, fragment_node.desc(sdfg))
    state.add_edge(fill_tasklet, 'out', fragment_node, None, fragment_memlet)

    _include_mma(sdfg)

    # Function has no return value
    return []