class BaseOirSDFGBuilder(ABC):
    has_transients = True

    def __init__(self, name, stencil: Stencil, nodes):
        self._stencil = stencil
        self._sdfg = SDFG(name)
        self._state = self._sdfg.add_state(name + "_state")

        self._extents = nodes_extent_calculation(nodes)

        self._dtypes = {decl.name: decl.dtype for decl in stencil.declarations + stencil.params}
        self._axes = {
            decl.name: decl.dimensions
            for decl in stencil.declarations + stencil.params
            if isinstance(decl, FieldDecl)
        }

        self._recent_write_acc: Dict[str, dace.nodes.AccessNode] = dict()
        self._recent_read_acc: Dict[str, dace.nodes.AccessNode] = dict()

        self._access_nodes: Dict[str, dace.nodes.AccessNode] = dict()
        self._access_collection_cache: Dict[
            int, AccessCollector.CartesianAccessCollection
        ] = dict()
        self._source_nodes: Dict[str, dace.nodes.AccessNode] = dict()
        self._delete_candidates: List[MultiConnectorEdge] = list()

    def _access_space_to_subset(self, name, access_space):
        extent = self._extents[name]
        origin = (extent[0][0], extent[1][0])
        subsets = []
        if self._axes[name][0]:
            subsets.append(
                "{start}:__I{end:+d}".format(
                    start=origin[0] + access_space[0][0],
                    end=origin[0] + access_space[0][1],
                )
            )
        if self._axes[name][1]:
            subsets.append(
                "{start}:__J{end:+d}".format(
                    start=origin[1] + access_space[1][0],
                    end=origin[1] + access_space[1][1],
                )
            )
        return subsets

    def _are_nodes_ordered(self, name, node1, node2):
        assert name in self._access_nodes
        assert node1.data == name
        assert node2.data == name
        return self._access_nodes[name].index(node1) < self._access_nodes[name].index(node2)

    def _get_source(self, name):
        if name not in self._source_nodes:
            self._source_nodes[name] = self._state.add_read(name)
            if name not in self._access_nodes:
                self._access_nodes[name] = []
            self._access_nodes[name].insert(0, self._source_nodes[name])
        return self._source_nodes[name]

    def _get_new_sink(self, name):
        res = self._state.add_access(name)
        if name not in self._access_nodes:
            self._access_nodes[name] = []
        self._access_nodes[name].append(res)
        return res

    def _get_current_sink(self, name):
        if name in self._access_nodes:
            return self._access_nodes[name][-1]
        return None

    def _get_access_collection(
        self, node: "Union[HorizontalExecutionLibraryNode, VerticalLoopLibraryNode, SDFG]"
    ) -> AccessCollector.CartesianAccessCollection:
        if isinstance(node, SDFG):
            res = AccessCollector.CartesianAccessCollection([])
            for node in node.states()[0].nodes():
                if isinstance(node, (HorizontalExecutionLibraryNode, VerticalLoopLibraryNode)):
                    collection = self._get_access_collection(node)
                    res._ordered_accesses.extend(collection._ordered_accesses)
            return res
        elif isinstance(node, HorizontalExecutionLibraryNode):
            if id(node.oir_node) not in self._access_collection_cache:
                self._access_collection_cache[id(node.oir_node)] = AccessCollector.apply(
                    node.oir_node
                ).cartesian_accesses()
            return self._access_collection_cache[id(node.oir_node)]
        else:
            assert isinstance(node, VerticalLoopLibraryNode)
            res = AccessCollector.CartesianAccessCollection([])
            for _, sdfg in node.sections:
                collection = self._get_access_collection(sdfg)
                res._ordered_accesses.extend(collection._ordered_accesses)
            return res

    def _get_recent_reads(self, name, interval):
        if name not in self._recent_read_acc:
            self._recent_read_acc[name] = IntervalMapping()
        if not self._axes[name][2]:
            interval = Interval.full()
        return self._recent_read_acc[name][interval]

    def _get_recent_writes(self, name, interval):
        if name not in self._recent_write_acc:
            self._recent_write_acc[name] = IntervalMapping()
        if not self._axes[name][2]:
            interval = Interval.full()
        return self._recent_write_acc[name][interval]

    def _set_read(self, name, interval, node):
        if name not in self._recent_read_acc:
            self._recent_read_acc[name] = IntervalMapping()
        if not self._axes[name][2]:
            interval = Interval.full()
        self._recent_read_acc[name][interval] = node

    def _set_write(self, name, interval, node):
        if name not in self._recent_write_acc:
            self._recent_write_acc[name] = IntervalMapping()
        if not self._axes[name][2]:
            interval = Interval.full()
        self._recent_write_acc[name][interval] = node

    def _reset_writes(self):
        self._recent_write_acc = dict()

    def _add_read_edges(
        self, node, collections: List[Tuple[Interval, AccessCollector.CartesianAccessCollection]]
    ):
        read_accesses: Dict[str, dace.nodes.AccessNode] = dict()
        for interval, access_collection in collections:
            for name in access_collection.read_fields():
                for offset in access_collection.read_offsets()[name]:
                    read_interval = interval.shifted(offset[2])
                    for candidate_access in self._get_recent_writes(name, read_interval):
                        if name not in read_accesses or self._are_nodes_ordered(
                            name, read_accesses[name], candidate_access
                        ):
                            # candidate_access is downstream of the recorded access,
                            # therefore the candidate is more recent
                            read_accesses[name] = candidate_access

        for interval, access_collection in collections:
            for name in access_collection.read_fields():
                for offset in access_collection.read_offsets()[name]:
                    read_interval = interval.shifted(offset[2])
                    if name not in read_accesses:
                        read_accesses[name] = self._get_source(name)
                    self._set_read(name, read_interval, read_accesses[name])

        for name, recent_access in read_accesses.items():
            node.add_in_connector("IN_" + name)
            self._state.add_edge(recent_access, None, node, "IN_" + name, dace.Memlet())

    def _add_write_edges(
        self, node, collections: List[Tuple[Interval, AccessCollector.CartesianAccessCollection]]
    ):
        write_accesses = dict()
        for interval, access_collection in collections:
            for name in access_collection.write_fields():
                access_node = self._get_current_sink(name)
                if access_node is None or (
                    (name not in write_accesses)
                    and (
                        access_node in self._get_recent_reads(name, interval)
                        or access_node in self._get_recent_writes(name, interval)
                        or nx.has_path(self._state.nx, access_node, node)
                    )
                ):
                    write_accesses[name] = self._get_new_sink(name)
                else:
                    write_accesses[name] = access_node
                self._set_write(name, interval, write_accesses[name])

        for name, access_node in write_accesses.items():
            node.add_out_connector("OUT_" + name)
            self._state.add_edge(node, "OUT_" + name, access_node, None, dace.Memlet())

    def _add_write_after_write_edges(
        self, node, collections: List[Tuple[Interval, AccessCollector.CartesianAccessCollection]]
    ):
        for interval, collection in collections:
            for name in collection.write_fields():
                for src in self._get_recent_writes(name, interval):
                    edge = self._state.add_edge(src, None, node, None, dace.Memlet())
                    self._delete_candidates.append(edge)

    def _add_write_after_read_edges(
        self, node, collections: List[Tuple[Interval, AccessCollector.CartesianAccessCollection]]
    ):
        for interval, collection in collections:
            for name in collection.read_fields():
                for offset in collection.read_offsets()[name]:
                    read_interval = interval.shifted(offset[2])
                    for dst in self._get_recent_writes(name, read_interval):
                        edge = self._state.add_edge(node, None, dst, None, dace.Memlet())
                        self._delete_candidates.append(edge)

        for interval, collection in collections:
            for name in collection.write_fields():
                self._set_write(name, interval, node)

    def add_node(self, node):
        self._state.add_node(node)

    def finalize(self):
        # Remove ordering-only candidate edges; re-add an edge only if no other path
        # already enforces the same ordering between its endpoints.
        for edge in self._delete_candidates:
            assert edge.src_conn is None
            assert edge.dst_conn is None
            self._state.remove_edge(edge)
            if not nx.has_path(self._state.nx, edge.src, edge.dst):
                self._state.add_edge(edge.src, edge.src_conn, edge.dst, edge.dst_conn, edge.data)

        self.add_subsets()
        self.add_arrays()

        for acc in (n for n in self._state.nodes() if isinstance(n, dace.nodes.AccessNode)):
            is_write = len(self._state.in_edges(acc)) > 0 and all(
                edge.data.data is not None for edge in self._state.in_edges(acc)
            )
            is_read = len(self._state.out_edges(acc)) > 0 and all(
                edge.data.data is not None for edge in self._state.out_edges(acc)
            )
            if is_read and is_write:
                acc.access = dace.AccessType.ReadWrite
            elif is_read:
                acc.access = dace.AccessType.ReadOnly
            else:
                assert is_write
                acc.access = dace.AccessType.WriteOnly

    def _get_sdfg(self):
        self.finalize()
        return self._sdfg

    def add_arrays(self):
        shapes = self.get_shapes()
        for decl in self._stencil.params + self._stencil.declarations:
            name = decl.name
            dtype = dace.dtypes.typeclass(np.dtype(data_type_to_typestr(self._dtypes[name])).name)
            if isinstance(decl, ScalarDecl):
                self._sdfg.add_symbol(name, stype=dtype)
            else:
                if name not in self._get_access_collection(self._sdfg).offsets():
                    continue
                assert name in self._dtypes
                strides = tuple(
                    dace.symbolic.pystr_to_symbolic(f"__{name}_{var}_stride")
                    for is_axis, var in zip(self._axes[name], "IJK")
                    if is_axis
                ) + tuple(
                    dace.symbolic.pystr_to_symbolic(f"__{name}_d{dim}_stride")
                    for dim, _ in enumerate(decl.data_dims)
                )
                self._sdfg.add_array(
                    name,
                    dtype=dtype,
                    shape=shapes[name],
                    strides=strides,
                    transient=isinstance(decl, Temporary) and self.has_transients,
                    lifetime=dace.AllocationLifetime.Persistent,
                )

    def add_subsets(self):
        decls = {decl.name: decl for decl in self._stencil.params + self._stencil.declarations}
        for node in self._state.nodes():
            if isinstance(node, dace.nodes.LibraryNode):
                access_spaces_input, access_spaces_output = self.get_access_spaces(node)
                k_subset_strs_input, k_subset_strs_output = self.get_k_subsets(node)
                for edge in self._state.in_edges(node) + self._state.out_edges(node):
                    if edge.dst_conn is not None:
                        name = edge.src.data
                        access_space = access_spaces_input[name]
                        subset_str_k = k_subset_strs_input.get(name, None)
                        dynamic = isinstance(node, HorizontalExecutionLibraryNode) and any(
                            isinstance(stmt, oir.MaskStmt) for stmt in node.oir_node.body
                        )
                    elif edge.src_conn is not None:
                        name = edge.dst.data
                        access_space = access_spaces_output[name]
                        subset_str_k = k_subset_strs_output.get(name, None)
                        dynamic = False
                    else:
                        continue
                    subset_strs = self._access_space_to_subset(name, access_space)
                    if subset_str_k is not None:
                        subset_strs.append(subset_str_k)
                    for dim in decls[name].data_dims:
                        subset_strs.append(f"0:{dim}")
                    edge.data = dace.Memlet.simple(
                        data=name, subset_str=",".join(subset_strs), dynamic=dynamic
                    )

    @abstractmethod
    def get_k_size(self, name):
        pass

    @abstractmethod
    def add_read_edges(self, node):
        pass

    @abstractmethod
    def add_write_edges(self, node):
        pass

    @abstractmethod
    def add_write_after_read_edges(self, node):
        pass

    @abstractmethod
    def add_write_after_write_edges(self, node):
        pass

    @abstractmethod
    def get_k_subsets(self, node):
        pass

    @abstractmethod
    def get_access_spaces(self, node):
        pass

    @abstractmethod
    def get_shapes(self):
        pass

    @classmethod
    def build(cls, name, stencil: Stencil, nodes: List[dace.nodes.LibraryNode]):
        builder = cls(name, stencil, nodes)
        # forward pass: add each node with its write-after-write, read and write edges
        for n in nodes:
            builder.add_node(n)
            builder.add_write_after_write_edges(n)
            builder.add_read_edges(n)
            builder.add_write_edges(n)
        builder._reset_writes()
        # reverse pass: add write-after-read (anti-dependency) edges against later reads
        for n in reversed(nodes):
            builder.add_write_after_read_edges(n)
        res = builder._get_sdfg()
        res.validate()
        return res
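# The builder above keys its "most recent read/write access node" bookkeeping on
# vertical (K) intervals through IntervalMapping. As a rough, self-contained
# illustration of that idea (this is NOT the real gt4py IntervalMapping; intervals
# are simplified to plain (start, stop) tuples), a minimal sketch could look like:


class _ToyIntervalMapping:
    """Hypothetical sketch: remember the last value stored for each K-interval."""

    def __init__(self):
        self._entries = []  # list of ((start, stop), value)

    def __setitem__(self, interval, value):
        start, stop = interval
        kept = []
        for (s, e), v in self._entries:
            if e <= start or s >= stop:
                kept.append(((s, e), v))  # no overlap: keep unchanged
            else:
                # overlap: keep only the uncovered parts of the older entry
                if s < start:
                    kept.append(((s, start), v))
                if e > stop:
                    kept.append(((stop, e), v))
        kept.append(((start, stop), value))
        self._entries = kept

    def __getitem__(self, interval):
        # return the values of all entries overlapping the queried interval
        start, stop = interval
        return [v for (s, e), v in self._entries if s < stop and e > start]


# e.g. m = _ToyIntervalMapping(); m[(0, 10)] = "w0"; m[(5, 15)] = "w1"
# then m[(3, 8)] -> ["w0", "w1"] and m[(8, 9)] -> ["w1"]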
def insert_sdfg_element(sdfg_str, type, parent_uuid, edge_a_uuid):
    sdfg_answer = load_sdfg_from_json(sdfg_str)
    sdfg = sdfg_answer['sdfg']
    uuid = 'error'
    ret = find_graph_element_by_uuid(sdfg, parent_uuid)
    parent = ret['element']

    libname = None
    if type is not None and isinstance(type, str):
        split_type = type.split('|')
        if len(split_type) == 2:
            type = split_type[0]
            libname = split_type[1]

    if type == 'SDFGState':
        if parent is None:
            parent = sdfg
        elif isinstance(parent, nodes.NestedSDFG):
            parent = parent.sdfg
        state = parent.add_state()
        uuid = [get_uuid(state)]
    elif type == 'AccessNode':
        arrays = list(parent.parent.arrays.keys())
        if len(arrays) == 0:
            parent.parent.add_array('tmp', [1], dtype=dtypes.float64)
            arrays = list(parent.parent.arrays.keys())
        node = parent.add_access(arrays[0])
        uuid = [get_uuid(node, parent)]
    elif type == 'Map':
        map_entry, map_exit = parent.add_map('map', dict(i='0:1'))
        uuid = [get_uuid(map_entry, parent), get_uuid(map_exit, parent)]
    elif type == 'Consume':
        consume_entry, consume_exit = parent.add_consume('consume', ('i', '1'))
        uuid = [get_uuid(consume_entry, parent), get_uuid(consume_exit, parent)]
    elif type == 'Tasklet':
        tasklet = parent.add_tasklet(
            name='placeholder',
            inputs={'in'},
            outputs={'out'},
            code='')
        uuid = [get_uuid(tasklet, parent)]
    elif type == 'NestedSDFG':
        sub_sdfg = SDFG('nested_sdfg')
        sub_sdfg.add_array('in', [1], dtypes.float32)
        sub_sdfg.add_array('out', [1], dtypes.float32)

        nsdfg = parent.add_nested_sdfg(sub_sdfg, sdfg, {'in'}, {'out'})
        uuid = [get_uuid(nsdfg, parent)]
    elif type == 'LibraryNode':
        if libname is None:
            return {
                'error': {
                    'message': 'Failed to add library node',
                    'details': 'Must provide a valid library node type',
                },
            }
        libnode_class = pydoc.locate(libname)
        libnode = libnode_class()
        parent.add_node(libnode)
        uuid = [get_uuid(libnode, parent)]
    elif type == 'Edge':
        edge_start_ret = find_graph_element_by_uuid(sdfg, edge_a_uuid)
        edge_start = edge_start_ret['element']
        edge_parent = edge_start_ret['parent']
        if edge_start is not None:
            if edge_parent is None:
                edge_parent = sdfg

            if isinstance(edge_parent, SDFGState):
                if not (isinstance(edge_start, nodes.Node)
                        and isinstance(parent, nodes.Node)):
                    return {
                        'error': {
                            'message': 'Failed to add edge',
                            'details': 'Must connect two nodes or two states',
                        },
                    }
                memlet = Memlet()
                edge_parent.add_edge(edge_start, None, parent, None, memlet)
            elif isinstance(edge_parent, SDFG):
                if not (isinstance(edge_start, SDFGState)
                        and isinstance(parent, SDFGState)):
                    return {
                        'error': {
                            'message': 'Failed to add edge',
                            'details': 'Must connect two nodes or two states',
                        },
                    }
                isedge = InterstateEdge()
                edge_parent.add_edge(edge_start, parent, isedge)
            uuid = ['NONE']
        else:
            raise ValueError('No edge starting point provided')

    old_meta = disable_save_metadata()
    new_sdfg_str = sdfg.to_json()
    restore_save_metadata(old_meta)

    return {
        'sdfg': new_sdfg_str,
        'uuid': uuid,
    }
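# For reference, the element kinds that insert_sdfg_element adds via UUID lookup can
# also be created directly with the dace API. A minimal sketch, assuming a recent dace
# release; the names "example", "main" and "A" are arbitrary illustration choices:

import dace


def _build_example_sdfg():
    sdfg = dace.SDFG("example")
    state = sdfg.add_state("main")                # 'SDFGState' branch
    sdfg.add_array("A", [1], dace.float64)        # backing array for access nodes
    read = state.add_access("A")                  # 'AccessNode' branch
    write = state.add_access("A")
    map_entry, map_exit = state.add_map("map", dict(i="0:1"))          # 'Map' branch
    tasklet = state.add_tasklet("copy", {"inp"}, {"out"}, "out = inp")  # 'Tasklet' branch
    # wire read -> map entry -> tasklet -> map exit -> write ('Edge' branch)
    state.add_memlet_path(read, map_entry, tasklet, dst_conn="inp",
                          memlet=dace.Memlet("A[i]"))
    state.add_memlet_path(tasklet, map_exit, write, src_conn="out",
                          memlet=dace.Memlet("A[i]"))
    sdfg.validate()
    return sdfg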
def apply(self, sdfg: dace.SDFG):
    # Extract the subgraph, execute it and insert an AccessNode to the result
    parent: ONNXModel = sdfg._parent_onnx_model
    state = sdfg.nodes()[self.state_id]
    node = state.nodes()[self.subgraph[ConstantFolding._onnx_node]]

    if isinstance(node, donnx.ONNXShape):
        # if we have a shape node, replace it with a constant
        assert len(state.in_edges(node)) == 1
        shape_in_edge = state.in_edges(node)[0]
        assert shape_in_edge.dst_conn == "data"
        shape_desc = sdfg.arrays[shape_in_edge.src.data]

        constant_name = sdfg.temp_data_name()
        clean_constant_name = clean_onnx_name(constant_name)
        sdfg.add_array(clean_constant_name, (len(shape_desc.shape), ), dace.int64)

        assert constant_name not in parent.clean_weights
        parent.weights[constant_name] = np.array(shape_desc.shape, np.int64)

        assert len(state.out_edges(node)) == 1
        output_edge = state.out_edges(node)[0]
        access_shape = state.add_access(clean_constant_name)
        state.add_edge(access_shape, None, output_edge.dst, output_edge.dst_conn,
                       sdfg.make_array_memlet(clean_constant_name))
    else:
        # otherwise compute the result of the op
        sub_sdfg = dace.SDFG("sub_sdfg")
        sub_state = sub_sdfg.add_state()

        node_copy = copy.deepcopy(node)
        sub_state.add_node(node_copy)

        inputs = {}
        for edge in state.in_edges(node):
            # we know from can_be_applied that all in edges are from AccessNodes
            assert (isinstance(edge.src, nd.AccessNode)
                    and hasattr(sdfg, "_parent_onnx_model")
                    and edge.src.data in sdfg._parent_onnx_model.clean_weights)

            desc = copy.deepcopy(sdfg.arrays[edge.data.data])
            desc.transient = False
            sub_sdfg.add_datadesc('array_' + edge.dst_conn, desc)

            input_value = sdfg._parent_onnx_model.clean_weights[edge.src.data]

            if len(input_value.shape) == 0:
                inputs['array_' + edge.dst_conn] = input_value[()]
            else:
                inputs['array_' + edge.dst_conn] = input_value.copy()

            access = sub_state.add_access('array_' + edge.dst_conn)
            sub_state.add_edge(
                access, None, node_copy, edge.dst_conn,
                sub_sdfg.make_array_memlet('array_' + edge.dst_conn))

        outputs = {}
        for edge in state.out_edges(node):
            desc = copy.deepcopy(sdfg.arrays[edge.data.data])
            if isinstance(desc, dt.Scalar):
                # we need to copy to an array of size [1] so that we can "return" the
                # output from the sdfg
                desc.transient = True
                sub_sdfg.add_datadesc('scalar_array_' + edge.src_conn, desc)
                sub_sdfg.add_array('array_' + edge.src_conn, [1],
                                   desc.dtype,
                                   transient=False)

                access_scalar = sub_state.add_access('scalar_array_' + edge.src_conn)
                access = sub_state.add_access('array_' + edge.src_conn)
                sub_state.add_edge(
                    node_copy, edge.src_conn, access_scalar, None,
                    sub_sdfg.make_array_memlet('scalar_array_' + edge.src_conn))

                sub_state.add_edge(
                    access_scalar, None, access, None,
                    sub_sdfg.make_array_memlet('array_' + edge.src_conn))
            else:
                desc.transient = False
                sub_sdfg.add_datadesc('array_' + edge.src_conn, desc)
                access = sub_state.add_access('array_' + edge.src_conn)
                sub_state.add_edge(
                    node_copy, edge.src_conn, access, None,
                    sub_sdfg.make_array_memlet('array_' + edge.src_conn))

            if len(desc.shape) == 0:
                outputs['array_' + edge.src_conn] = np.empty(
                    (1, ), desc.dtype.as_numpy_dtype())
            else:
                outputs['array_' + edge.src_conn] = np.empty(
                    tuple(desc.shape), desc.dtype.as_numpy_dtype())

        sub_sdfg(**outputs, **inputs)

        for edge in state.out_edges(node):
            desc = copy.deepcopy(sdfg.arrays[edge.data.data])
            desc.transient = False
            output_value = outputs['array_' + edge.src_conn]

            constant_name = sdfg.temp_data_name()
            clean_constant_name = clean_onnx_name(constant_name)
            sdfg.add_datadesc(clean_constant_name, desc)

            assert constant_name not in parent.weights
            if isinstance(desc, dt.Scalar):
                parent.weights[constant_name] = output_value.reshape(())
            else:
                parent.weights[constant_name] = output_value

            access_constant = state.add_access(clean_constant_name)
            state.add_edge(access_constant, None, edge.dst, edge.dst_conn,
                           sdfg.make_array_memlet(clean_constant_name))

    # remove all now useless nodes with a reverse BFS
    queue = deque([node])
    while len(queue) > 0:
        current_node = queue.popleft()
        edges = state.in_edges(current_node)
        state.remove_node(current_node)
        for e in edges:
            next_node = e.src
            if len(state.out_edges(next_node)) == 0:
                queue.append(next_node)
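# The pass above is ordinary constant folding specialized to ONNX library nodes: if every
# input of a node is already a known weight, the node is executed once (via a throwaway
# sub-SDFG) and replaced by an access node for the precomputed result. A toy,
# framework-free sketch of the same idea on a tiny expression "graph" (names and the
# graph/weights encoding are purely illustrative; nodes are assumed to be listed in
# topological order):

import numpy as np


def _fold_constants(graph, weights):
    """graph: dict name -> (op, input names); weights: dict name -> np.ndarray."""
    ops = {"add": np.add, "mul": np.multiply}
    folded = dict(weights)
    for name, (op, inputs) in graph.items():
        if all(i in folded for i in inputs):                        # all inputs constant
            folded[name] = ops[op](*(folded[i] for i in inputs))    # evaluate once
    return folded


# _fold_constants({"y": ("add", ("a", "b"))}, {"a": np.ones(3), "b": np.ones(3)})
# -> {"a": ..., "b": ..., "y": array([2., 2., 2.])}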
def apply(self, sdfg: dace.SDFG):
    # Extract the subgraph, execute it and insert an AccessNode to the result
    # this method of execution is slow but simple. A better option would be to call the ORT
    # C API from a python object (like the OpChecker).
    parent: ONNXModel = sdfg._parent_onnx_model
    state = sdfg.nodes()[self.state_id]
    node = state.nodes()[self.subgraph[ConstantFolding._onnx_node]]
    log.debug(f"Applying constant folding: {node} in {state}")

    if isinstance(node, donnx.ONNXShape):
        # if we have a shape node, replace it with a constant
        assert len(state.in_edges(node)) == 1
        shape_in_edge = state.in_edges(node)[0]
        assert shape_in_edge.dst_conn == "data"
        shape_desc = sdfg.arrays[shape_in_edge.src.data]

        constant_name = sdfg.temp_data_name()
        clean_constant_name = clean_onnx_name(constant_name)
        sdfg.add_array(clean_constant_name, (len(shape_desc.shape), ), dace.int64)

        assert constant_name not in parent.clean_weights
        parent.weights[constant_name] = torch.from_numpy(
            np.array(shape_desc.shape, np.int64))

        assert len(state.out_edges(node)) == 1
        output_edge = state.out_edges(node)[0]
        access_shape = state.add_access(clean_constant_name)
        state.add_edge(access_shape, None, output_edge.dst, output_edge.dst_conn,
                       sdfg.make_array_memlet(clean_constant_name))
    else:
        # otherwise compute the result of the op
        global UNIQUE_ID
        UNIQUE_ID += 1
        sub_sdfg = dace.SDFG("sub_sdfg_" + str(UNIQUE_ID))
        sub_state = sub_sdfg.add_state()

        node_copy = copy.deepcopy(node)
        sub_state.add_node(node_copy)

        inputs = {}
        for edge in state.in_edges(node):
            # we know from can_be_applied that all in edges are from AccessNodes
            assert (isinstance(edge.src, nd.AccessNode)
                    and hasattr(sdfg, "_parent_onnx_model")
                    and edge.src.data in sdfg._parent_onnx_model.clean_weights)

            desc = copy.deepcopy(sdfg.arrays[edge.data.data])
            desc.transient = False
            sub_sdfg.add_datadesc('array_' + edge.dst_conn, desc)

            input_value = sdfg._parent_onnx_model.clean_weights[edge.src.data]

            if len(input_value.shape) == 0:
                inputs['array_' + edge.dst_conn] = input_value.cpu().numpy()[()]
            else:
                inputs['array_' + edge.dst_conn] = input_value.clone()

            access = sub_state.add_access('array_' + edge.dst_conn)
            sub_state.add_edge(
                access, None, node_copy, edge.dst_conn,
                sub_sdfg.make_array_memlet('array_' + edge.dst_conn))

        outputs = {}
        for edge in state.out_edges(node):
            desc = copy.deepcopy(sdfg.arrays[edge.data.data])
            if isinstance(desc, dt.Scalar):
                # we need to copy to an array of size [1] so that we can "return" the
                # output from the sdfg
                desc.transient = True
                sub_sdfg.add_datadesc('scalar_array_' + edge.src_conn, desc)
                sub_sdfg.add_array('array_' + edge.src_conn, [1],
                                   desc.dtype,
                                   transient=False)

                access_scalar = sub_state.add_access('scalar_array_' + edge.src_conn)
                access = sub_state.add_access('array_' + edge.src_conn)
                sub_state.add_edge(
                    node_copy, edge.src_conn, access_scalar, None,
                    sub_sdfg.make_array_memlet('scalar_array_' + edge.src_conn))

                sub_state.add_edge(
                    access_scalar, None, access, None,
                    sub_sdfg.make_array_memlet('array_' + edge.src_conn))
            else:
                desc.transient = False
                sub_sdfg.add_datadesc('array_' + edge.src_conn, desc)
                access = sub_state.add_access('array_' + edge.src_conn)
                sub_state.add_edge(
                    node_copy, edge.src_conn, access, None,
                    sub_sdfg.make_array_memlet('array_' + edge.src_conn))

            if len(desc.shape) == 0:
                empty_array = np.empty((1, ), desc.dtype.as_numpy_dtype())
            else:
                empty_array = np.empty(tuple(desc.shape), desc.dtype.as_numpy_dtype())

            empty_array = torch.from_numpy(empty_array)

            if desc.storage is dtypes.StorageType.GPU_Global:
                empty_array = empty_array.cuda()

            outputs['array_' + edge.src_conn] = empty_array

        sub_sdfg(**outputs, **inputs)

        for edge in state.out_edges(node):
            desc = copy.deepcopy(sdfg.arrays[edge.data.data])
            desc.transient = False
            output_value = outputs['array_' + edge.src_conn]

            constant_name = sdfg.temp_data_name()
            clean_constant_name = clean_onnx_name(constant_name)
            sdfg.add_datadesc(clean_constant_name, desc)

            assert constant_name not in parent.weights
            assert type(output_value) is torch.Tensor

            if not dtypes.can_access(dtypes.ScheduleType.CPU_Multicore, desc.storage):
                cpu_desc = copy.deepcopy(desc)
                cpu_desc.storage = dtypes.StorageType.CPU_Heap
                cpu_desc.transient = False
                desc.transient = True
                copy_in_name = sdfg.temp_data_name()
                clean_copy_in_name = clean_onnx_name(copy_in_name)
                sdfg.add_datadesc(clean_copy_in_name, cpu_desc)

                access_constant = state.add_access(clean_constant_name)
                state.add_edge(state.add_read(clean_copy_in_name), None,
                               access_constant, None,
                               sdfg.make_array_memlet(clean_copy_in_name))

                name_to_add = copy_in_name
            else:
                access_constant = state.add_read(clean_constant_name)
                name_to_add = constant_name

            if isinstance(desc, dt.Scalar):
                parent.weights[name_to_add] = output_value.reshape(())
            else:
                parent.weights[name_to_add] = output_value

            state.add_edge(access_constant, None, edge.dst, edge.dst_conn,
                           sdfg.make_array_memlet(clean_constant_name))

    # remove all now useless nodes with a reverse BFS
    remove_node_and_computation(sdfg, state, node)
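# Compared with the numpy variant above, this version keeps folded values as torch
# tensors and, when an output lives in GPU-global storage, routes it through a CPU
# staging array so the precomputed weight can be fed from host memory. The host/device
# shuffle it relies on is the usual torch pattern; a small sketch (requires a
# CUDA-enabled torch build for the .cuda() branch):

import numpy as np
import torch

host = torch.from_numpy(np.empty((4, ), dtype=np.float32))  # CPU staging buffer
if torch.cuda.is_available():
    device_value = host.cuda()   # value as the folded op would produce it on the GPU
    host = device_value.cpu()    # copy back so a CPU-side consumer can read it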
def backward(
    forward_node: Node, context: BackwardContext,
    given_gradients: typing.List[typing.Optional[str]],
    required_gradients: typing.List[typing.Optional[str]]
) -> typing.Tuple[Node, BackwardResult]:
    reduction_type = detect_reduction_type(forward_node.wcr)

    if len(given_gradients) != 1:
        raise AutoDiffException(
            "received invalid SDFG: reduce node {} should have exactly one output edge"
            .format(forward_node))

    if len(required_gradients) != 1:
        raise AutoDiffException(
            "received invalid SDFG: reduce node {} should have exactly one input edge"
            .format(forward_node))

    input_name = next(iter(required_gradients))
    in_desc = in_desc_with_name(forward_node, context.forward_state,
                                context.forward_sdfg, input_name)

    output_name = next(iter(given_gradients))
    out_desc = out_desc_with_name(forward_node, context.forward_state,
                                  context.forward_sdfg, output_name)

    all_axes: typing.List[int] = list(range(len(in_desc.shape)))
    reduce_axes: typing.List[int] = (
        all_axes if forward_node.axes is None else forward_node.axes)
    non_reduce_axes: typing.List[int] = [i for i in all_axes if i not in reduce_axes]

    result = BackwardResult.empty()

    if reduction_type is dtypes.ReductionType.Sum:
        # in this case, we need to simply scatter the grad across the axes that were reduced

        sdfg = SDFG("_reverse_" + str(reduction_type).replace(".", "_") + "_")
        state = sdfg.add_state()

        rev_input_conn_name = "input_gradient"
        rev_output_conn_name = "output_gradient"
        result.required_grad_names[output_name] = rev_output_conn_name
        result.given_grad_names[input_name] = rev_input_conn_name

        _, rev_input_arr = sdfg.add_array(rev_input_conn_name,
                                          shape=out_desc.shape,
                                          dtype=out_desc.dtype)
        _, rev_output_arr = sdfg.add_array(rev_output_conn_name,
                                           shape=in_desc.shape,
                                           dtype=in_desc.dtype)

        state.add_mapped_tasklet(
            "_distribute_grad_" + str(reduction_type).replace(".", "_") + "_",
            {
                "i" + str(i): "0:{}".format(shape)
                for i, shape in enumerate(in_desc.shape)
            },
            {
                "__in":
                Memlet.simple(
                    rev_input_conn_name,
                    "0" if forward_node.axes is None else ",".join(
                        "i" + str(i) for i in non_reduce_axes))
            },
            "__out = __in",
            {
                "__out":
                Memlet.simple(rev_output_conn_name,
                              ",".join("i" + str(i) for i in all_axes),
                              wcr_str="lambda x, y: x + y")
            },
            external_edges=True)

        return context.backward_state.add_nested_sdfg(
            sdfg, None, {rev_input_conn_name}, {rev_output_conn_name}), result
    else:
        raise AutoDiffException(
            "Unsupported reduction type '{}'".format(reduction_type))
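# The Sum branch above implements the textbook rule that for y = sum(x, axes), the
# gradient w.r.t. x is the incoming gradient broadcast back along the reduced axes
# (every input element contributes with weight 1). A small numpy check of that rule,
# with illustrative shapes only:

import numpy as np

x = np.random.rand(3, 4)
axes = (1, )                                   # reduce over axis 1, like forward_node.axes
dy = np.random.rand(3)                         # gradient w.r.t. y = x.sum(axis=1)
dx = np.broadcast_to(np.expand_dims(dy, axes), x.shape)  # scatter dy across reduced axes
assert dx.shape == x.shape and np.allclose(dx[:, 0], dy)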