def add_backward_desc_for_connector(backward_sdfg: dace.SDFG,
                                    forward_node: nd.Node,
                                    context: BackwardContext, connector: str,
                                    input: bool) -> str:
    """ Adds the backward (gradient) array for a connector of ``forward_node``.

    The forward array attached to ``connector`` is looked up in
    ``context.forward_state``. Its descriptor from ``context.forward_sdfg`` is
    deep-copied, marked non-transient and registered in ``backward_sdfg`` under
    ``<forward array name>_grad``. ``find_new_name=True`` is passed, so DaCe may
    pick a different name if that one is taken. Always use the returned name.

    :param backward_sdfg: the sdfg to add the gradient array to.
    :param forward_node: the forward node that owns ``connector``.
    :param context: the backward context; its ``forward_state`` is used to find
                    the connector's edge and its ``forward_sdfg`` to find the
                    forward descriptor.
    :param connector: the connector on ``forward_node`` to add a gradient
                      descriptor for.
    :param input: ``True`` if ``connector`` is an input connector, ``False`` if
                  it is an output connector.
    :return: the name of the newly added array in ``backward_sdfg``.
    """
    find_edge = utils.in_edge_with_name if input else utils.out_edge_with_name
    connector_edge = find_edge(forward_node, context.forward_state, connector)
    forward_array = connector_edge.data.data

    grad_desc = copy.deepcopy(context.forward_sdfg.arrays[forward_array])
    grad_desc.transient = False
    return backward_sdfg.add_datadesc(forward_array + "_grad",
                                      grad_desc,
                                      find_new_name=True)
def add_backward_desc(backward_sdfg: dace.SDFG, forward_sdfg: dace.SDFG,
                      forward_desc: dt.Data, forward_name: str) -> str:
    """ Adds the backward (gradient) array for the given descriptor.

    The descriptor is deep-copied, marked non-transient and registered in
    ``backward_sdfg``. The candidate name ``<forward_name>_grad`` is first made
    unique with respect to ``forward_sdfg``'s arrays. It is then registered
    with ``find_new_name=True``, so a clash with an array already present in
    ``backward_sdfg`` is also resolved instead of raising.

    :param backward_sdfg: the sdfg to add to.
    :param forward_sdfg: the forward sdfg.
    :param forward_desc: the data descriptor of the forward array from
                         ``forward_sdfg``.
    :param forward_name: a name for the forward array (does not have to match
                         its actual name).
    :return: the name of the newly added array in ``backward_sdfg``.
    """
    backward_name = utils.find_str_not_in_set(forward_sdfg.arrays,
                                              forward_name + "_grad")
    new_desc = copy.deepcopy(forward_desc)
    new_desc.transient = False
    # ``backward_name`` only avoids names of the *forward* SDFG. Without
    # ``find_new_name`` a name already present in the backward SDFG would make
    # ``add_datadesc`` raise, so let DaCe pick a fresh name in that case.
    return backward_sdfg.add_datadesc(backward_name,
                                      new_desc,
                                      find_new_name=True)
def apply(self, sdfg: dace.SDFG):
    """ Replace the matched ONNX node by constant(s) stored as weights of the
    parent ``ONNXModel``.

    * ``ONNXShape`` node: the result is taken directly from the static shape
      of the descriptor feeding its ``data`` connector.
    * Any other node: a deep copy of the node is placed in a standalone
      sub-SDFG. The sub-SDFG is executed once with the parent model's
      ``clean_weights`` as inputs (presumably torch tensors, given the
      ``.cpu()``/``.clone()`` calls). Each output is registered as a new weight.

    In both cases the consumers of the node are rewired to read from new
    AccessNodes of the constants. The node is then removed through
    ``remove_node_and_computation``.

    :param sdfg: the SDFG containing the match. It must carry a
                 ``_parent_onnx_model`` attribute, which is read
                 unconditionally at the top of the body.
    """
    # Extract the subgraph, execute it and insert an AccessNode to the result
    # this method of execution is slow but simple. A better option would be to call the ORT
    # C API from a python object (like the OpChecker).

    parent: ONNXModel = sdfg._parent_onnx_model
    state = sdfg.nodes()[self.state_id]
    node = state.nodes()[self.subgraph[ConstantFolding._onnx_node]]
    log.debug(f"Applying constant folding: {node} in {state}")

    if isinstance(node, donnx.ONNXShape):
        # if we have a shape node, replace it with a constant
        assert len(state.in_edges(node)) == 1
        shape_in_edge = state.in_edges(node)[0]
        assert shape_in_edge.dst_conn == "data"
        shape_desc = sdfg.arrays[shape_in_edge.src.data]

        # The SDFG array uses the cleaned name, while the weight is stored
        # under the raw temp name.
        constant_name = sdfg.temp_data_name()
        clean_constant_name = clean_onnx_name(constant_name)
        sdfg.add_array(clean_constant_name, (len(shape_desc.shape), ),
                       dace.int64)

        # NOTE(review): the uniqueness check is against ``clean_weights`` but
        # the write goes to ``weights``; presumably ``clean_weights`` is
        # ``weights`` keyed by clean_onnx_name — confirm.
        assert constant_name not in parent.clean_weights
        parent.weights[constant_name] = torch.from_numpy(
            np.array(shape_desc.shape, np.int64))

        # Feed the single consumer of the Shape output from the new constant.
        assert len(state.out_edges(node)) == 1
        output_edge = state.out_edges(node)[0]
        access_shape = state.add_access(clean_constant_name)
        state.add_edge(access_shape, None, output_edge.dst,
                       output_edge.dst_conn,
                       sdfg.make_array_memlet(clean_constant_name))
    else:
        # otherwise compute the result of the op
        # A module-level counter gives every sub-SDFG a distinct name.
        # Presumably this avoids reusing a previously compiled program of the
        # same name — confirm.
        global UNIQUE_ID
        UNIQUE_ID += 1
        sub_sdfg = dace.SDFG("sub_sdfg_" + str(UNIQUE_ID))
        sub_state = sub_sdfg.add_state()

        node_copy = copy.deepcopy(node)
        sub_state.add_node(node_copy)

        # Build one non-transient input array per in-connector and collect the
        # corresponding weight values as call arguments.
        inputs = {}
        for edge in state.in_edges(node):
            # we know from can_be_applied that all in edges are from AccessNodes
            assert (isinstance(edge.src, nd.AccessNode)
                    and hasattr(sdfg, "_parent_onnx_model") and
                    edge.src.data in sdfg._parent_onnx_model.clean_weights)

            desc = copy.deepcopy(sdfg.arrays[edge.data.data])
            desc.transient = False
            sub_sdfg.add_datadesc('array_' + edge.dst_conn, desc)

            input_value = sdfg._parent_onnx_model.clean_weights[
                edge.src.data]

            if len(input_value.shape) == 0:
                # 0-d values are passed as a NumPy scalar extracted on the CPU.
                inputs['array_' +
                       edge.dst_conn] = input_value.cpu().numpy()[()]
            else:
                # Presumably cloned so the stored weight is never written
                # through by the sub-SDFG — confirm.
                inputs['array_' + edge.dst_conn] = input_value.clone()

            access = sub_state.add_access('array_' + edge.dst_conn)
            sub_state.add_edge(
                access, None, node_copy, edge.dst_conn,
                sub_sdfg.make_array_memlet('array_' + edge.dst_conn))

        # Build one output array per out-connector and allocate a matching
        # (uninitialized) buffer to receive the result.
        outputs = {}
        for edge in state.out_edges(node):
            desc = copy.deepcopy(sdfg.arrays[edge.data.data])
            if isinstance(desc, dt.Scalar):
                # we need to copy to an array of size [1] so that we can "return" the output from the sdfg
                desc.transient = True
                sub_sdfg.add_datadesc('scalar_array_' + edge.src_conn, desc)
                sub_sdfg.add_array('array_' + edge.src_conn, [1],
                                   desc.dtype,
                                   transient=False)

                access_scalar = sub_state.add_access('scalar_array_' +
                                                     edge.src_conn)
                access = sub_state.add_access('array_' + edge.src_conn)
                sub_state.add_edge(
                    node_copy, edge.src_conn, access_scalar, None,
                    sub_sdfg.make_array_memlet('scalar_array_' +
                                               edge.src_conn))

                sub_state.add_edge(
                    access_scalar, None, access, None,
                    sub_sdfg.make_array_memlet('array_' + edge.src_conn))
            else:
                desc.transient = False
                sub_sdfg.add_datadesc('array_' + edge.src_conn, desc)
                access = sub_state.add_access('array_' + edge.src_conn)
                sub_state.add_edge(
                    node_copy, edge.src_conn, access, None,
                    sub_sdfg.make_array_memlet('array_' + edge.src_conn))

            # 0-d descriptors still receive a 1-element buffer.
            if len(desc.shape) == 0:
                empty_array = np.empty((1, ), desc.dtype.as_numpy_dtype())
            else:
                empty_array = np.empty(tuple(desc.shape),
                                       desc.dtype.as_numpy_dtype())

            empty_array = torch.from_numpy(empty_array)

            # Place the output buffer on the GPU when the descriptor lives there.
            if desc.storage is dtypes.StorageType.GPU_Global:
                empty_array = empty_array.cuda()

            outputs['array_' + edge.src_conn] = empty_array

        # Run the single-node program; results land in the ``outputs`` buffers.
        sub_sdfg(**outputs, **inputs)

        # Register every computed output as a new constant and rewire the
        # original consumer to read from it.
        for edge in state.out_edges(node):
            desc = copy.deepcopy(sdfg.arrays[edge.data.data])
            desc.transient = False
            output_value = outputs['array_' + edge.src_conn]

            constant_name = sdfg.temp_data_name()
            clean_constant_name = clean_onnx_name(constant_name)
            sdfg.add_datadesc(clean_constant_name, desc)

            assert constant_name not in parent.weights
            assert type(output_value) is torch.Tensor

            if not dtypes.can_access(dtypes.ScheduleType.CPU_Multicore,
                                     desc.storage):
                # The CPU cannot access this storage. Store the value as a
                # separate non-transient CPU_Heap weight and copy it into the
                # constant via an edge.
                cpu_desc = copy.deepcopy(desc)
                cpu_desc.storage = dtypes.StorageType.CPU_Heap
                cpu_desc.transient = False
                # ``desc`` was passed to add_datadesc above; presumably it is
                # stored by reference, making the device-side constant
                # transient — confirm.
                desc.transient = True
                copy_in_name = sdfg.temp_data_name()
                clean_copy_in_name = clean_onnx_name(copy_in_name)
                sdfg.add_datadesc(clean_copy_in_name, cpu_desc)

                access_constant = state.add_access(clean_constant_name)
                state.add_edge(state.add_read(clean_copy_in_name), None,
                               access_constant, None,
                               sdfg.make_array_memlet(clean_copy_in_name))

                name_to_add = copy_in_name
            else:
                access_constant = state.add_read(clean_constant_name)
                name_to_add = constant_name

            # Scalars were returned through a 1-element buffer; store them 0-d.
            if isinstance(desc, dt.Scalar):
                parent.weights[name_to_add] = output_value.reshape(())
            else:
                parent.weights[name_to_add] = output_value

            state.add_edge(access_constant, None, edge.dst, edge.dst_conn,
                           sdfg.make_array_memlet(clean_constant_name))

    # remove all now useless nodes with a reverse BFS
    remove_node_and_computation(sdfg, state, node)
def apply(self, sdfg: dace.SDFG):
    """ Replace the matched ONNX node by constant(s) stored as weights of the
    parent ``ONNXModel``.

    * ``ONNXShape`` node: the result is taken directly from the static shape
      of the descriptor feeding its ``data`` connector.
    * Any other node: a deep copy of the node is placed in a standalone
      sub-SDFG. The sub-SDFG is executed once with the parent model's
      ``clean_weights`` as inputs, and each NumPy output is registered as a
      new weight.

    The consumers of the node are rewired to the new constants. The node and
    every upstream node left without outgoing edges are then removed with a
    reverse BFS.

    :param sdfg: the SDFG containing the match. It must carry a
                 ``_parent_onnx_model`` attribute, which is read
                 unconditionally at the top of the body.
    """
    # Extract the subgraph, execute it and insert an AccessNode to the result

    parent: ONNXModel = sdfg._parent_onnx_model
    state = sdfg.nodes()[self.state_id]
    node = state.nodes()[self.subgraph[ConstantFolding._onnx_node]]

    if isinstance(node, donnx.ONNXShape):
        # if we have a shape node, replace it with a constant
        assert len(state.in_edges(node)) == 1
        shape_in_edge = state.in_edges(node)[0]
        assert shape_in_edge.dst_conn == "data"
        shape_desc = sdfg.arrays[shape_in_edge.src.data]

        # The SDFG array uses the cleaned name, while the weight is stored
        # under the raw temp name.
        constant_name = sdfg.temp_data_name()
        clean_constant_name = clean_onnx_name(constant_name)
        sdfg.add_array(clean_constant_name, (len(shape_desc.shape), ),
                       dace.int64)

        assert constant_name not in parent.clean_weights
        parent.weights[constant_name] = np.array(shape_desc.shape, np.int64)

        assert len(state.out_edges(node)) == 1
        output_edge = state.out_edges(node)[0]
        access_shape = state.add_access(clean_constant_name)
        state.add_edge(access_shape, None, output_edge.dst,
                       output_edge.dst_conn,
                       sdfg.make_array_memlet(clean_constant_name))
    else:
        # otherwise compute the result of the op
        # NOTE(review): every invocation reuses the SDFG name "sub_sdfg"; confirm
        # DaCe cannot reuse a stale compiled program for a different op.
        sub_sdfg = dace.SDFG("sub_sdfg")
        sub_state = sub_sdfg.add_state()

        node_copy = copy.deepcopy(node)
        sub_state.add_node(node_copy)

        # One non-transient input array per in-connector, fed from the weights.
        inputs = {}
        for edge in state.in_edges(node):
            # we know from can_be_applied that all in edges are from AccessNodes
            assert (isinstance(edge.src, nd.AccessNode)
                    and hasattr(sdfg, "_parent_onnx_model") and
                    edge.src.data in sdfg._parent_onnx_model.clean_weights)

            desc = copy.deepcopy(sdfg.arrays[edge.data.data])
            desc.transient = False
            sub_sdfg.add_datadesc('array_' + edge.dst_conn, desc)

            input_value = sdfg._parent_onnx_model.clean_weights[
                edge.src.data]

            if len(input_value.shape) == 0:
                # 0-d inputs are passed as a scalar value
                inputs['array_' + edge.dst_conn] = input_value[()]
            else:
                inputs['array_' + edge.dst_conn] = input_value.copy()

            access = sub_state.add_access('array_' + edge.dst_conn)
            sub_state.add_edge(
                access, None, node_copy, edge.dst_conn,
                sub_sdfg.make_array_memlet('array_' + edge.dst_conn))

        # One output array per out-connector, plus an uninitialized buffer.
        outputs = {}
        for edge in state.out_edges(node):
            desc = copy.deepcopy(sdfg.arrays[edge.data.data])
            if isinstance(desc, dt.Scalar):
                # we need to copy to an array of size [1] so that we can "return" the output from the sdfg
                desc.transient = True
                sub_sdfg.add_datadesc('scalar_array_' + edge.src_conn, desc)
                sub_sdfg.add_array('array_' + edge.src_conn, [1],
                                   desc.dtype,
                                   transient=False)

                access_scalar = sub_state.add_access('scalar_array_' +
                                                     edge.src_conn)
                access = sub_state.add_access('array_' + edge.src_conn)
                sub_state.add_edge(
                    node_copy, edge.src_conn, access_scalar, None,
                    sub_sdfg.make_array_memlet('scalar_array_' +
                                               edge.src_conn))

                sub_state.add_edge(
                    access_scalar, None, access, None,
                    sub_sdfg.make_array_memlet('array_' + edge.src_conn))
            else:
                desc.transient = False
                sub_sdfg.add_datadesc('array_' + edge.src_conn, desc)
                access = sub_state.add_access('array_' + edge.src_conn)
                sub_state.add_edge(
                    node_copy, edge.src_conn, access, None,
                    sub_sdfg.make_array_memlet('array_' + edge.src_conn))

            # 0-d descriptors still receive a 1-element buffer
            if len(desc.shape) == 0:
                outputs['array_' + edge.src_conn] = np.empty(
                    (1, ), desc.dtype.as_numpy_dtype())
            else:
                outputs['array_' + edge.src_conn] = np.empty(
                    tuple(desc.shape), desc.dtype.as_numpy_dtype())

        # Run the single-node program; results land in the ``outputs`` buffers.
        sub_sdfg(**outputs, **inputs)

        # Register every computed output as a new constant and rewire the
        # original consumer to read from it.
        for edge in state.out_edges(node):
            desc = copy.deepcopy(sdfg.arrays[edge.data.data])
            desc.transient = False
            output_value = outputs['array_' + edge.src_conn]

            constant_name = sdfg.temp_data_name()
            clean_constant_name = clean_onnx_name(constant_name)
            sdfg.add_datadesc(clean_constant_name, desc)

            assert constant_name not in parent.weights
            # Scalars were returned through a 1-element buffer; store them 0-d.
            if isinstance(desc, dt.Scalar):
                parent.weights[constant_name] = output_value.reshape(())
            else:
                parent.weights[constant_name] = output_value

            access_constant = state.add_access(clean_constant_name)
            state.add_edge(access_constant, None, edge.dst, edge.dst_conn,
                           sdfg.make_array_memlet(clean_constant_name))

    # remove all now useless nodes with a reverse BFS
    queue = deque([node])
    # A producer connected to several connectors of the same removed node
    # (e.g. Mul(x, x)) would otherwise be enqueued once per edge. Its second
    # visit would then query and remove a node that is no longer in the state.
    enqueued = {node}
    while queue:
        current_node = queue.popleft()

        edges = state.in_edges(current_node)
        state.remove_node(current_node)
        for e in edges:
            next_node = e.src
            if next_node not in enqueued and len(
                    state.out_edges(next_node)) == 0:
                enqueued.add(next_node)
                queue.append(next_node)