def op_repo_replacement(sdfg: SDFG, state: SDFGState, **kwargs):
    """Add an ONNX library node to ``state`` and connect it to named arrays.

    Keyword arguments are routed by name against ``dace_schema``:

    * names found in ``dace_schema.attributes`` are forwarded to the node
      constructor;
    * names matching a parameter in ``dace_schema.inputs`` are treated as
      names of arrays to read from;
    * names matching a parameter in ``dace_schema.outputs`` are treated as
      names of arrays to write to.

    Keyword arguments matching none of the above are ignored.

    ``cls``, ``cls_name`` and ``dace_schema`` are free variables that must be
    supplied by the enclosing scope (not visible from this block).

    :param sdfg: the SDFG that owns the arrays; used to build the memlets.
    :param state: the state that receives the node, access nodes and edges.
    :param kwargs: attribute values and array names, keyed by schema name.
    :return: an empty list.
    """
    attrs = {
        name: value
        for name, value in kwargs.items()
        if name in dace_schema.attributes
    }
    onnx_node = cls(name=cls_name, **attrs)
    state.add_node(onnx_node)

    input_names = {p.name for p in dace_schema.inputs}
    output_names = {p.name for p in dace_schema.outputs}
    inputs = {
        name: arr_name
        for name, arr_name in kwargs.items()
        if name in input_names
    }
    outputs = {
        name: arr_name
        for name, arr_name in kwargs.items()
        if name in output_names
    }

    for inp, arr_name in inputs.items():
        read = state.add_read(arr_name)
        state.add_edge(read, None, onnx_node, inp,
                       sdfg.make_array_memlet(arr_name))

    for outp, arr_name in outputs.items():
        # The node writes into this array, so create the access node through
        # the write-side API rather than add_read.
        write = state.add_write(arr_name)
        state.add_edge(onnx_node, outp, write, None,
                       sdfg.make_array_memlet(arr_name))
    return []
def op_repo_replacement(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState,
                        **kwargs):
    """Add an ONNX library node to ``state`` and connect it to named arrays.

    Every keyword argument must be consumed by exactly one of three stages,
    applied in this order:

    1. names found in ``dace_schema.attributes`` become constructor
       attributes of the node;
    2. names that are non-variadic schema inputs, or that contain ``"__"``
       and whose ``parse_variadic_param(name)[0]`` is a variadic schema
       input, become input connectors fed from the named array;
    3. the same rule, applied to schema outputs, yields output connectors
       writing to the named array.

    ``cls``, ``cls_name`` and ``dace_schema`` are free variables that must be
    supplied by the enclosing scope (not visible from this block).

    :param pv: the program visitor; not used by this function.
    :param sdfg: the SDFG that owns the arrays; used to build the memlets.
    :param state: the state that receives the node, access nodes and edges.
    :param kwargs: attribute values and array names, keyed by schema name.
    :return: an empty list.
    :raises TypeError: if any keyword argument is left over after the three
                       stages above.
    """
    attrs = {
        name: value
        for name, value in kwargs.items()
        if name in dace_schema.attributes
    }
    # remove used attrs
    kwargs = {k: v for k, v in kwargs.items() if k not in attrs}

    onnx_node = cls(name=cls_name, **attrs)
    state.add_node(onnx_node)

    input_names = dace_schema.non_variadic_inputs()
    variadic_inputs = dace_schema.variadic_inputs()
    output_names = dace_schema.non_variadic_outputs()
    variadic_outputs = dace_schema.variadic_outputs()

    inputs = {
        name: arr_name
        for name, arr_name in kwargs.items()
        if (name in input_names or
            # variadic params
            ("__" in name
             and parse_variadic_param(name)[0] in variadic_inputs))
    }
    kwargs = {k: v for k, v in kwargs.items() if k not in inputs}

    outputs = {
        name: arr_name
        for name, arr_name in kwargs.items()
        if (name in output_names or
            # variadic params
            ("__" in name
             and parse_variadic_param(name)[0] in variadic_outputs))
    }
    kwargs = {k: v for k, v in kwargs.items() if k not in outputs}

    # Anything still left matched neither an attribute, an input nor an output.
    if kwargs:
        raise TypeError(f"Unknown arguments {', '.join(kwargs)}")

    for inp, arr_name in inputs.items():
        read = state.add_read(arr_name)
        state.add_edge(read, None, onnx_node, inp,
                       sdfg.make_array_memlet(arr_name))
        onnx_node.add_in_connector(inp)

    for outp, arr_name in outputs.items():
        # The node writes into this array, so create the access node through
        # the write-side API rather than add_read.
        write = state.add_write(arr_name)
        state.add_edge(onnx_node, outp, write, None,
                       sdfg.make_array_memlet(arr_name))
        onnx_node.add_out_connector(outp)
    return []
def apply(self, sdfg: dace.SDFG):
    """Replace the matched ONNX node with constant data in the parent model.

    For an ``ONNXShape`` node the shape of its input descriptor is stored
    directly as an int64 tensor. Any other node is deep-copied into a fresh
    single-state SDFG, executed on copies of its constant inputs, and each
    result is registered in ``parent.weights`` and wired to the original
    consumers. Finally the node is handed to ``remove_node_and_computation``.

    :param sdfg: the SDFG being transformed; must carry a
                 ``_parent_onnx_model`` attribute.
    """
    # Extract the subgraph, execute it and insert an AccessNode to the result
    # this method of execution is slow but simple. A better option would be to call the ORT
    # C API from a python object (like the OpChecker).
    parent: ONNXModel = sdfg._parent_onnx_model
    state = sdfg.nodes()[self.state_id]
    node = state.nodes()[self.subgraph[ConstantFolding._onnx_node]]
    log.debug(f"Applying constant folding: {node} in {state}")

    if isinstance(node, donnx.ONNXShape):
        # A shape node is folded without execution: its result is simply the
        # shape recorded on the input's data descriptor.
        incoming = state.in_edges(node)
        assert len(incoming) == 1
        shape_in_edge = incoming[0]
        assert shape_in_edge.dst_conn == "data"
        source_desc = sdfg.arrays[shape_in_edge.src.data]

        raw_name = sdfg.temp_data_name()
        folded_name = clean_onnx_name(raw_name)
        sdfg.add_array(folded_name, (len(source_desc.shape), ), dace.int64)

        assert raw_name not in parent.clean_weights
        shape_values = np.array(source_desc.shape, np.int64)
        parent.weights[raw_name] = torch.from_numpy(shape_values)

        outgoing = state.out_edges(node)
        assert len(outgoing) == 1
        consumer_edge = outgoing[0]
        shape_access = state.add_access(folded_name)
        state.add_edge(shape_access, None, consumer_edge.dst,
                       consumer_edge.dst_conn,
                       sdfg.make_array_memlet(folded_name))
    else:
        # Any other op: run a copy of it in a throw-away SDFG.
        global UNIQUE_ID
        UNIQUE_ID += 1
        runner = dace.SDFG("sub_sdfg_" + str(UNIQUE_ID))
        runner_state = runner.add_state()

        node_copy = copy.deepcopy(node)
        runner_state.add_node(node_copy)

        # ---- inputs: one non-transient array per incoming connector ----
        inputs = {}
        for in_edge in state.in_edges(node):
            # can_be_applied guarantees every in edge comes from an
            # AccessNode that refers to a known weight
            assert (isinstance(in_edge.src, nd.AccessNode)
                    and hasattr(sdfg, "_parent_onnx_model")
                    and in_edge.src.data
                    in sdfg._parent_onnx_model.clean_weights)

            in_desc = copy.deepcopy(sdfg.arrays[in_edge.data.data])
            in_desc.transient = False
            arg_name = 'array_' + in_edge.dst_conn
            runner.add_datadesc(arg_name, in_desc)

            weight = parent.clean_weights[in_edge.src.data]
            # zero-dimensional weights are passed as plain scalars, everything
            # else as a cloned tensor
            inputs[arg_name] = (weight.cpu().numpy()[()]
                                if len(weight.shape) == 0 else weight.clone())

            in_access = runner_state.add_access(arg_name)
            runner_state.add_edge(in_access, None, node_copy,
                                  in_edge.dst_conn,
                                  runner.make_array_memlet(arg_name))

        # ---- outputs: one preallocated tensor per outgoing connector ----
        outputs = {}
        for out_edge in state.out_edges(node):
            out_desc = copy.deepcopy(sdfg.arrays[out_edge.data.data])
            result_name = 'array_' + out_edge.src_conn

            if isinstance(out_desc, dt.Scalar):
                # we need to copy to an array of size [1] so that we can "return" the output from the sdfg
                scalar_name = 'scalar_array_' + out_edge.src_conn
                out_desc.transient = True
                runner.add_datadesc(scalar_name, out_desc)
                runner.add_array(result_name, [1],
                                 out_desc.dtype,
                                 transient=False)

                scalar_access = runner_state.add_access(scalar_name)
                result_access = runner_state.add_access(result_name)
                runner_state.add_edge(node_copy, out_edge.src_conn,
                                      scalar_access, None,
                                      runner.make_array_memlet(scalar_name))
                runner_state.add_edge(scalar_access, None, result_access,
                                      None,
                                      runner.make_array_memlet(result_name))
            else:
                out_desc.transient = False
                runner.add_datadesc(result_name, out_desc)
                result_access = runner_state.add_access(result_name)
                runner_state.add_edge(node_copy, out_edge.src_conn,
                                      result_access, None,
                                      runner.make_array_memlet(result_name))

            buffer_shape = ((1, ) if len(out_desc.shape) == 0 else tuple(
                out_desc.shape))
            buffer = torch.from_numpy(
                np.empty(buffer_shape, out_desc.dtype.as_numpy_dtype()))
            if out_desc.storage is dtypes.StorageType.GPU_Global:
                buffer = buffer.cuda()
            outputs[result_name] = buffer

        runner(**outputs, **inputs)

        # ---- publish results as weights and rewire the consumers ----
        for out_edge in state.out_edges(node):
            folded_desc = copy.deepcopy(sdfg.arrays[out_edge.data.data])
            folded_desc.transient = False
            output_value = outputs['array_' + out_edge.src_conn]

            raw_name = sdfg.temp_data_name()
            folded_name = clean_onnx_name(raw_name)
            sdfg.add_datadesc(folded_name, folded_desc)

            assert raw_name not in parent.weights
            assert type(output_value) is torch.Tensor

            cpu_reachable = dtypes.can_access(
                dtypes.ScheduleType.CPU_Multicore, folded_desc.storage)
            if cpu_reachable:
                access_constant = state.add_read(folded_name)
                weight_key = raw_name
            else:
                # The constant's storage cannot be accessed from a CPU
                # schedule: expose a CPU_Heap staging array as the weight and
                # copy it into the (now transient) constant.
                staging_desc = copy.deepcopy(folded_desc)
                staging_desc.storage = dtypes.StorageType.CPU_Heap
                staging_desc.transient = False
                folded_desc.transient = True
                raw_staging_name = sdfg.temp_data_name()
                staging_name = clean_onnx_name(raw_staging_name)
                sdfg.add_datadesc(staging_name, staging_desc)

                access_constant = state.add_access(folded_name)
                staging_read = state.add_read(staging_name)
                state.add_edge(staging_read, None, access_constant, None,
                               sdfg.make_array_memlet(staging_name))
                weight_key = raw_staging_name

            parent.weights[weight_key] = (output_value.reshape(
                ()) if isinstance(folded_desc, dt.Scalar) else output_value)

            state.add_edge(access_constant, None, out_edge.dst,
                           out_edge.dst_conn,
                           sdfg.make_array_memlet(folded_name))

    # remove all now useless nodes with a reverse BFS
    remove_node_and_computation(sdfg, state, node)
def apply(self, sdfg: dace.SDFG):
    """Replace the matched ONNX node with constant data in the parent model.

    For an ``ONNXShape`` node the shape of its input descriptor is stored
    directly as an int64 NumPy array. Any other node is deep-copied into a
    fresh single-state SDFG named ``"sub_sdfg"``, executed on copies of its
    constant inputs, and each result is registered in ``parent.weights`` and
    wired to the original consumers. Afterwards the node, and every
    predecessor left without outgoing edges, is removed from the state.

    :param sdfg: the SDFG being transformed; must carry a
                 ``_parent_onnx_model`` attribute.
    """
    # Extract the subgraph, execute it and insert an AccessNode to the result
    parent: ONNXModel = sdfg._parent_onnx_model
    state = sdfg.nodes()[self.state_id]
    node = state.nodes()[self.subgraph[ConstantFolding._onnx_node]]

    if isinstance(node, donnx.ONNXShape):
        # if we have a shape node, replace it with a constant
        assert len(state.in_edges(node)) == 1
        shape_in_edge = state.in_edges(node)[0]
        assert shape_in_edge.dst_conn == "data"
        shape_desc = sdfg.arrays[shape_in_edge.src.data]

        constant_name = sdfg.temp_data_name()
        clean_constant_name = clean_onnx_name(constant_name)
        sdfg.add_array(clean_constant_name, (len(shape_desc.shape), ),
                       dace.int64)

        assert constant_name not in parent.clean_weights
        parent.weights[constant_name] = np.array(shape_desc.shape, np.int64)

        assert len(state.out_edges(node)) == 1
        output_edge = state.out_edges(node)[0]
        access_shape = state.add_access(clean_constant_name)
        state.add_edge(access_shape, None, output_edge.dst,
                       output_edge.dst_conn,
                       sdfg.make_array_memlet(clean_constant_name))
    else:
        # otherwise compute the result of the op
        sub_sdfg = dace.SDFG("sub_sdfg")
        sub_state = sub_sdfg.add_state()

        node_copy = copy.deepcopy(node)
        sub_state.add_node(node_copy)

        inputs = {}
        for edge in state.in_edges(node):
            # we know from can_be_applied that all in edges are from AccessNodes
            assert (isinstance(edge.src, nd.AccessNode)
                    and hasattr(sdfg, "_parent_onnx_model") and
                    edge.src.data in sdfg._parent_onnx_model.clean_weights)

            desc = copy.deepcopy(sdfg.arrays[edge.data.data])
            desc.transient = False
            sub_sdfg.add_datadesc('array_' + edge.dst_conn, desc)

            input_value = sdfg._parent_onnx_model.clean_weights[
                edge.src.data]

            # zero-dimensional weights are passed as plain scalars,
            # everything else as a private copy of the array
            if len(input_value.shape) == 0:
                inputs['array_' + edge.dst_conn] = input_value[()]
            else:
                inputs['array_' + edge.dst_conn] = input_value.copy()

            access = sub_state.add_access('array_' + edge.dst_conn)
            sub_state.add_edge(
                access, None, node_copy, edge.dst_conn,
                sub_sdfg.make_array_memlet('array_' + edge.dst_conn))

        outputs = {}
        for edge in state.out_edges(node):
            desc = copy.deepcopy(sdfg.arrays[edge.data.data])
            if isinstance(desc, dt.Scalar):
                # we need to copy to an array of size [1] so that we can "return" the output from the sdfg
                desc.transient = True
                sub_sdfg.add_datadesc('scalar_array_' + edge.src_conn, desc)
                sub_sdfg.add_array('array_' + edge.src_conn, [1],
                                   desc.dtype,
                                   transient=False)

                access_scalar = sub_state.add_access('scalar_array_' +
                                                     edge.src_conn)
                access = sub_state.add_access('array_' + edge.src_conn)
                sub_state.add_edge(
                    node_copy, edge.src_conn, access_scalar, None,
                    sub_sdfg.make_array_memlet('scalar_array_' +
                                               edge.src_conn))
                sub_state.add_edge(
                    access_scalar, None, access, None,
                    sub_sdfg.make_array_memlet('array_' + edge.src_conn))
            else:
                desc.transient = False
                sub_sdfg.add_datadesc('array_' + edge.src_conn, desc)
                access = sub_state.add_access('array_' + edge.src_conn)
                sub_state.add_edge(
                    node_copy, edge.src_conn, access, None,
                    sub_sdfg.make_array_memlet('array_' + edge.src_conn))

            # preallocate the buffer the sub-SDFG writes its result into
            if len(desc.shape) == 0:
                outputs['array_' + edge.src_conn] = np.empty(
                    (1, ), desc.dtype.as_numpy_dtype())
            else:
                outputs['array_' + edge.src_conn] = np.empty(
                    tuple(desc.shape), desc.dtype.as_numpy_dtype())

        sub_sdfg(**outputs, **inputs)

        for edge in state.out_edges(node):
            desc = copy.deepcopy(sdfg.arrays[edge.data.data])
            desc.transient = False
            output_value = outputs['array_' + edge.src_conn]

            constant_name = sdfg.temp_data_name()
            clean_constant_name = clean_onnx_name(constant_name)
            sdfg.add_datadesc(clean_constant_name, desc)

            assert constant_name not in parent.weights
            if isinstance(desc, dt.Scalar):
                parent.weights[constant_name] = output_value.reshape(())
            else:
                parent.weights[constant_name] = output_value

            access_constant = state.add_access(clean_constant_name)
            state.add_edge(access_constant, None, edge.dst, edge.dst_conn,
                           sdfg.make_array_memlet(clean_constant_name))

    # remove all now useless nodes with a reverse BFS
    queue = deque([node])
    # Nodes already placed on the queue. Without this, a predecessor that
    # feeds the removed node through several edges would be enqueued once per
    # edge and then removed from the state twice.
    scheduled = {node}
    while queue:
        current_node = queue.popleft()

        # capture the in edges before removal, they vanish with the node
        edges = state.in_edges(current_node)
        state.remove_node(current_node)
        for e in edges:
            next_node = e.src
            if next_node in scheduled:
                continue
            if not state.out_edges(next_node):
                scheduled.add(next_node)
                queue.append(next_node)