def mkc(sdfg: dace.SDFG,
        state_before,
        src_name,
        dst_name,
        src_storage=None,
        dst_storage=None,
        src_shape=None,
        dst_shape=None,
        copy_expr=None,
        src_loc=None,
        dst_loc=None):
    """
    Helper MaKe_Copy that creates and appends a state performing exactly one copy.

    A new state is added after ``state_before`` (or as the start state when
    ``state_before`` is None). It contains a single edge copying ``src_name``
    into ``dst_name`` with the memlet ``copy_expr``.

    If an array with a given name already exists in ``sdfg`` it is reused and
    the storage/shape/location arguments for that side are ignored when
    creating SDFG arrays (the shape is still used for the returned buffer).

    :param sdfg: SDFG to which the state and any new arrays are added.
    :param state_before: state after which the copy state is inserted, or None
                         to create the start state.
    :param src_name: name of the source array.
    :param dst_name: name of the destination array.
    :param src_storage: storage type for a newly created source array.
    :param dst_storage: storage type for a newly created destination array.
                        Arrays whose storage is in ``_FPGA_STORAGE_TYPES`` are
                        created as transients.
    :param src_shape: shape of a newly created source array and its buffer.
    :param dst_shape: shape of a newly created destination array and its buffer.
    :param copy_expr: memlet expression of the copy; defaults to ``src_name``.
    :param src_loc: optional ``(memorytype, bank)`` pair stored in the new
                    source array's ``location``.
    :param dst_loc: same as ``src_loc`` for the destination array.
    :return: ``(state, src_buffer, dst_buffer)``. Each buffer is an int32
             zero-filled numpy array, or None if no shape was given or numpy
             could not build an array from the shape.
    """
    if copy_expr is None:
        copy_expr = src_name

    if state_before is None:
        state = sdfg.add_state(is_start_state=True)
    else:
        state = sdfg.add_state_after(state_before)

    def mkarray(name, shape, storage, loc):
        """Return the descriptor of array ``name``, creating it if missing."""
        if name in sdfg.arrays:
            return sdfg.arrays[name]
        # add_array returns (name, descriptor)
        desc = sdfg.add_array(name,
                              shape,
                              dace.int32,
                              storage,
                              transient=storage in _FPGA_STORAGE_TYPES)[1]
        if loc is not None:
            desc.location["memorytype"] = loc[0]
            desc.location["bank"] = loc[1]
        return desc

    mkarray(src_name, src_shape, src_storage, src_loc)
    mkarray(dst_name, dst_shape, dst_storage, dst_loc)

    src_node = state.add_access(src_name)
    dst_node = state.add_access(dst_name)
    state.add_edge(src_node, None, dst_node, None, mem.Memlet(copy_expr))

    def zeros_or_none(shape):
        """Best-effort int32 zero buffer for ``shape``; None if unavailable."""
        if shape is None:
            return None
        try:
            return np.zeros(shape, dtype=np.int32)
        except (TypeError, ValueError):
            # numpy cannot interpret the shape (e.g. non-integer entries;
            # presumably symbolic sizes are why this fallback exists).
            return None

    return state, zeros_or_none(src_shape), zeros_or_none(dst_shape)
def add_backward_pass(
    sdfg: SDFG,
    state: SDFGState,
    outputs: typing.List[typing.Union[nd.AccessNode, str]],
    inputs: typing.List[typing.Union[nd.AccessNode, str]],
):
    """Experimental: differentiate ``state`` with reverse-mode automatic differentiation.

    A new state is appended directly after ``state`` and filled with the
    backward pass. The backward pass is generated inside the same SDFG.

    Both ``outputs`` and ``inputs`` may contain ``AccessNode`` objects or
    ``str`` data names. A name is resolved by searching the graph for exactly
    one ``AccessNode`` whose data matches it.

    The SDFG must be free of in-place operations. Supported nodes are:

    * Maps
    * AccessNodes
    * Reductions (Sum, Min, Max)
    * ONNXOps
    * NestedSDFGs containing a single SDFGState (subject to the same
      constraints). NestedSDFGs may contain multiple states as long as all
      other states are only used for zero initialization.

    For an :class:`~daceml.onnx.nodes.onnx_op.ONNXOp`, the ONNXBackward
    registry is consulted first for a matching backward implementation.
    Failing that, the ONNXForward registry is searched for a pure
    implementation, which is then differentiated symbolically. If neither
    route succeeds, the method fails.

    :param sdfg: the SDFG that owns ``state``; it is validated before any
                 state is added to it.
    :param state: the forward-pass state. The backward pass is placed in a
                  new state that follows it.
    :param outputs: forward-pass outputs of the function being differentiated.
    :param inputs: inputs with respect to which gradients are produced.
    """
    # Validate the forward SDFG before any state is added to it.
    sdfg.validate()

    generator = BackwardPassGenerator(
        sdfg=sdfg,
        state=state,
        given_gradients=outputs,
        required_gradients=inputs,
        backward_sdfg=sdfg,
        backward_state=sdfg.add_state_after(state),
    )
    generator.backward()
def apply(self, sdfg: SDFG):
    """Turn an augmented-assignment tasklet into a write-conflict resolution.

    Rewrites a C++ tasklet of the form ``out = in <op> expr;`` (where the
    ``in`` edge reads the same subset that ``out`` writes) into
    ``out = expr;``. The ``<op>`` is moved onto the output memlet as a WCR
    lambda, and the now-unneeded input edge is removed. If the written data is
    non-transient inside a nested SDFG, the WCR is propagated onto the memlet
    paths of the enclosing SDFGs.

    With ``self.expr_index == 0`` the match is AccessNode -> Tasklet ->
    AccessNode. Otherwise the tasklet is connected to a map entry and a map
    exit.

    :param sdfg: the SDFG containing the matched subgraph.
    :raises NotImplementedError: if the tasklet is not C++, or no input edge
        matches the augmented-assignment pattern on the output subset.
    """
    input_node: nodes.AccessNode = self.input(sdfg)
    tasklet: nodes.Tasklet = self.tasklet(sdfg)
    output_node: nodes.AccessNode = self.output(sdfg)
    state: SDFGState = sdfg.node(self.state_id)

    # If state fission is necessary to keep semantics, do it first.
    # Triggered when the input node has incoming edges in this state and the
    # output node has no outgoing ones; the tasklet moves to a new state.
    if (self.expr_index == 0 and state.in_degree(input_node) > 0
            and state.out_degree(output_node) == 0):
        newstate = sdfg.add_state_after(state)
        newstate.add_node(tasklet)
        new_input, new_output = None, None

        # Keep old edges for after we remove tasklet from the original state
        in_edges = list(state.in_edges(tasklet))
        out_edges = list(state.out_edges(tasklet))

        for e in in_edges:
            r = newstate.add_read(e.src.data)
            newstate.add_edge(r, e.src_conn, e.dst, e.dst_conn, e.data)
            if e.src is input_node:
                new_input = r
        for e in out_edges:
            w = newstate.add_write(e.dst.data)
            newstate.add_edge(e.src, e.src_conn, w, e.dst_conn, e.data)
            if e.dst is output_node:
                new_output = w

        # Remove tasklet and resulting isolated nodes
        state.remove_node(tasklet)
        for e in in_edges:
            if state.degree(e.src) == 0:
                state.remove_node(e.src)
        for e in out_edges:
            if state.degree(e.dst) == 0:
                state.remove_node(e.dst)

        # Reset state and nodes for rest of transformation
        input_node = new_input
        output_node = new_output
        state = newstate
    # End of state fission

    if self.expr_index == 0:
        inedges = state.edges_between(input_node, tasklet)
        outedge = state.edges_between(tasklet, output_node)[0]
    else:
        me = self.map_entry(sdfg)
        mx = self.map_exit(sdfg)

        inedges = state.edges_between(me, tasklet)
        outedge = state.edges_between(tasklet, mx)[0]

    # Get relevant output connector
    outconn = outedge.src_conn

    # Regex character class matching one operator character
    ops = '[%s]' % ''.join(
        re.escape(o) for o in AugAssignToWCR._EXPRESSIONS)

    # Change tasklet code
    if tasklet.language is dtypes.Language.Python:
        raise NotImplementedError
    elif tasklet.language is dtypes.Language.CPP:
        cstr = tasklet.code.as_string.strip()
        for edge in inedges:
            inconn = edge.dst_conn
            # Only "out = in <op> expr;" is matched; the reversed form
            # "out = expr <op> in;" is not supported.
            match = re.match(
                r'^\s*%s\s*=\s*%s\s*(%s)(.*);$' %
                (re.escape(outconn), re.escape(inconn), ops), cstr)
            if match is None:
                continue
            op = match.group(1)
            expr = match.group(2)

            # The input must read exactly the subset being written
            if edge.data.subset != outedge.data.subset:
                continue

            # Map asymmetric WCRs to symmetric ones if possible
            if op in AugAssignToWCR._EXPR_MAP:
                op, newexpr = AugAssignToWCR._EXPR_MAP[op]
                expr = newexpr.format(expr=expr)

            tasklet.code.code = '%s = %s;' % (outconn, expr)
            inedge = edge
            break
        else:
            # Fail explicitly: otherwise `inedge` (and possibly `op`) would be
            # unbound below, or `op` stale from a subset-mismatched edge.
            raise NotImplementedError(
                'AugAssignToWCR: no tasklet input matches an augmented '
                'assignment on the output subset')
    else:
        raise NotImplementedError

    # Change output edge
    outedge.data.wcr = f'lambda a,b: a {op} b'

    if self.expr_index == 0:
        # Remove input node and connector
        state.remove_edge_and_connectors(inedge)
        if state.degree(input_node) == 0:
            state.remove_node(input_node)
    else:
        # Remove input edge and dst connector, but not necessarily src
        state.remove_memlet_path(inedge)

    # If outedge leads to non-transient, and this is a nested SDFG,
    # propagate outwards
    sd = sdfg
    while (not sd.arrays[outedge.data.data].transient
           and sd.parent_nsdfg_node is not None):
        nsdfg = sd.parent_nsdfg_node
        nstate = sd.parent
        sd = sd.parent_sdfg
        # Assumes the nested SDFG's out connector is named after the inner data
        outedge = next(
            iter(nstate.out_edges_by_connector(nsdfg, outedge.data.data)))
        # NOTE: `outedge` is deliberately reused as the loop variable; after
        # the loop it is the last edge of the memlet path, which the next
        # `while` check reads (presumably the edge into the outer AccessNode).
        for outedge in nstate.memlet_path(outedge):
            outedge.data.wcr = f'lambda a,b: a {op} b'