def forward(node: onnx_op.ONNXOp, state: SDFGState,
            sdfg: SDFG) -> typing.Union[Node, SDFG]:
    node.validate(sdfg, state)

    def prog(input, output):
        output[:] = dace.elementwise(lambda x: tanh(x), input)

    return program_for_node(prog, sdfg, state, node)

def forward(node: onnx_op.ONNXOp, state: SDFGState,
            sdfg: SDFG) -> typing.Union[Node, SDFG]:
    node.validate(sdfg, state)

    def prog(input, output):
        output[:] = input

    return program_for_node(prog, sdfg, state, node)

def forward(node: onnx_op.ONNXOp, state: SDFGState,
            sdfg: SDFG) -> typing.Union[Node, SDFG]:
    node.validate(sdfg, state)

    def prog(A, B, C):
        C[:] = A / B

    return program_for_node(prog, sdfg, state, node)

def forward(node: onnx_op.ONNXOp, state: SDFGState,
            sdfg: SDFG) -> typing.Union[Node, SDFG]:
    node.validate(sdfg, state)

    def prog(X, Y, Z):
        Z[:] = X**Y

    return program_for_node(prog, sdfg, state, node)

def forward(node: onnx_op.ONNXOp, state: SDFGState,
            sdfg: SDFG) -> typing.Union[Node, SDFG]:
    node.validate(sdfg, state)

    def prog(X, Y):
        Y[:] = dace.elementwise(lambda x: sqrt(x), X)

    return program_for_node(prog, sdfg, state, node)

def forward(node: onnx_op.ONNXOp, state: SDFGState,
            sdfg: SDFG) -> typing.Union[Node, SDFG]:
    node.validate(sdfg, state)
    perm = node.perm

    def prog(data, transposed):
        transposed[:] = np.transpose(data, axes=perm)

    return program_for_node(prog, sdfg, state, node)

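# Quick illustration of the permutation semantics used above (plain NumPy,
# hypothetical shapes): axis p of the output is axis perm[p] of the input.
import numpy as np

data = np.zeros((2, 3, 4))
assert np.transpose(data, axes=(2, 0, 1)).shape == (4, 2, 3)
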
def forward(node: onnx_op.ONNXOp, state: SDFGState,
            sdfg: SDFG) -> typing.Union[Node, SDFG]:
    new_shape = out_desc_with_name(node, state, sdfg, "reshaped").shape
    node.remove_in_connector("shape")

    shape_node = in_edge_with_name(node, state, "shape").src
    constant_folding.remove_node_and_computation(sdfg, state, shape_node)

    def prog(data, reshaped):
        reshaped[:] = np.reshape(data, new_shape)

    return program_for_node(prog, sdfg, state, node)

def forward(node: onnx_op.ONNXOp, state: SDFGState,
            sdfg: SDFG) -> typing.Union[Node, SDFG]:
    node.validate(sdfg, state)
    axes = node.axes

    # when keepdims is true, this works but there is a useless copy. We just
    # leave this for now; this can be fixed with a reshape node when those
    # exist.
    def prog(data, reduced):
        reduced[:] = np.min(data, axis=axes)

    return program_for_node(prog, sdfg, state, node)

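# Illustration of the keepdims copy issue noted above (plain NumPy,
# hypothetical shapes): with keepdims the ONNX output descriptor retains
# singleton axes, while np.min drops them, so the result has to be
# reshaped/copied on write-back.
import numpy as np

data = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
assert np.min(data, axis=(1, )).shape == (2, 4)                    # keepdims=0
assert np.min(data, axis=(1, ), keepdims=True).shape == (2, 1, 4)  # keepdims=1
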
def forward(node: onnx_op.ONNXOp, state: SDFGState,
            sdfg: SDFG) -> typing.Union[Node, SDFG]:
    node.validate(sdfg, state)

    dtype = in_desc_with_name(node, state, sdfg, 'X').dtype
    # build the 1/x lambda as a string, templated on the input dtype
    reciprocal_lambda = "lambda x: dace.{}(1) / x".format(dtype.to_string())

    def prog(X, Y):
        Y[:] = dace.elementwise(reciprocal_lambda, X)

    return program_for_node(prog, sdfg, state, node)

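# The lambda above is built as a string so the constant 1 is typed to match
# the input; for a float32 input X, the generated expression would be:
#
#     "lambda x: dace.float32(1) / x"
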
def forward(node: onnx_op.ONNXOp, state: SDFGState,
            sdfg: SDFG) -> typing.Union[Node, SDFG]:
    node.validate(sdfg, state)
    assert (node.alpha == 1.0 and node.beta == 1.0 and node.transA == 0
            and node.transB == 1)

    # the gemm libnode is broken for now, so we just do it manually
    if "C" in node.in_connectors:

        def prog(A, B, C, Y):
            Y[:] = A @ np.transpose(B) + C
    else:

        def prog(A, B, Y):
            Y[:] = A @ np.transpose(B)

    sdfg = program_for_node(prog, sdfg, state, node)
    sdfg.apply_strict_transformations()
    return sdfg

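# Reference semantics for the restricted case asserted above (alpha == beta
# == 1, transA == 0, transB == 1), in plain NumPy with hypothetical shapes:
import numpy as np

A = np.random.rand(4, 8).astype(np.float32)  # (M, K)
B = np.random.rand(6, 8).astype(np.float32)  # (N, K), since transB == 1
C = np.random.rand(6).astype(np.float32)     # broadcast to (M, N)
Y = A @ B.T + C                              # Gemm: Y = alpha*A@B^T + beta*C
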
def forward(node: onnx_op.ONNXOp, state: SDFGState,
            sdfg: SDFG) -> typing.Union[Node, SDFG]:
    node.validate(sdfg, state)

    nsdfg = dace.SDFG(node.label + "_expansion")
    nstate = nsdfg.add_state()
    for e in node.iter_inputs_in_onnx_order(state):
        nsdfg.add_datadesc(e.dst_conn,
                           in_desc_with_name(node, state, sdfg, e.dst_conn))
    for e in node.iter_outputs_in_onnx_order(state):
        nsdfg.add_datadesc(e.src_conn,
                           out_desc_with_name(node, state, sdfg, e.src_conn))

    create_einsum_sdfg(None,
                       nsdfg,
                       nstate,
                       node.equation.replace(" ", ""),
                       *(e.dst_conn
                         for e in node.iter_inputs_in_onnx_order(state)),
                       output="Output")
    return nsdfg

class ConstantFolding(transformation.Transformation):
    """ Remove nodes where all inputs are known and replace them with constant
        nodes by precomputing the output.
    """
    # pattern matching only checks that the type of the node matches
    _onnx_node = ONNXOp("_")

    @staticmethod
    def expressions():
        return [sdutil.node_path_graph(ConstantFolding._onnx_node)]

    @staticmethod
    def is_constant(sdfg: dace.SDFG, state: dace.SDFGState, node) -> bool:
        if len(state.in_edges(node)) > 0:
            return False

        # the ONNX importer adds a _parent_onnx_model attribute to the sdfg
        if isinstance(node, nd.AccessNode
                      ) and node.data in sdfg._parent_onnx_model.clean_weights:
            return True

        return False

    @staticmethod
    def can_be_applied(graph: dace.sdfg.graph.OrderedMultiDiConnectorGraph,
                       candidate: Dict[nd.Node, int],
                       expr_index: int,
                       sdfg,
                       strict: bool = False):
        node: ONNXOp = graph.nodes()[candidate[ConstantFolding._onnx_node]]

        # SDFG must be imported from an ONNXModel
        if not hasattr(sdfg, "_parent_onnx_model"):
            return False

        # never fold nondeterministic ops
        if 'ONNX' + node.schema.name in NONDETERMINISTIC_OPS:
            return False

        if isinstance(node, donnx.ONNXShape):
            return True

        # all inputs must be constant
        for edge in graph.in_edges(node):
            if not ConstantFolding.is_constant(sdfg, graph, edge.src):
                return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        node: ONNXOp = graph.nodes()[candidate[ConstantFolding._onnx_node]]
        return "Precompute outputs of {}".format(node)

    def apply(self, sdfg: dace.SDFG):
        # Extract the subgraph, execute it and insert an AccessNode to the
        # result. This method of execution is slow but simple. A better option
        # would be to call the ORT C API from a python object (like the
        # OpChecker).

        parent: ONNXModel = sdfg._parent_onnx_model
        state = sdfg.nodes()[self.state_id]
        node = state.nodes()[self.subgraph[ConstantFolding._onnx_node]]
        log.debug(f"Applying constant folding: {node} in {state}")

        if isinstance(node, donnx.ONNXShape):
            # if we have a shape node, replace it with a constant
            assert len(state.in_edges(node)) == 1
            shape_in_edge = state.in_edges(node)[0]
            assert shape_in_edge.dst_conn == "data"
            shape_desc = sdfg.arrays[shape_in_edge.src.data]

            constant_name = sdfg.temp_data_name()
            clean_constant_name = clean_onnx_name(constant_name)
            sdfg.add_array(clean_constant_name, (len(shape_desc.shape), ),
                           dace.int64)

            assert constant_name not in parent.clean_weights
            parent.weights[constant_name] = torch.from_numpy(
                np.array(shape_desc.shape, np.int64))

            assert len(state.out_edges(node)) == 1
            output_edge = state.out_edges(node)[0]
            access_shape = state.add_access(clean_constant_name)
            state.add_edge(access_shape, None, output_edge.dst,
                           output_edge.dst_conn,
                           sdfg.make_array_memlet(clean_constant_name))
        else:
            # otherwise compute the result of the op
            global UNIQUE_ID
            UNIQUE_ID += 1
            sub_sdfg = dace.SDFG("sub_sdfg_" + str(UNIQUE_ID))
            sub_state = sub_sdfg.add_state()

            node_copy = copy.deepcopy(node)
            sub_state.add_node(node_copy)

            inputs = {}
            for edge in state.in_edges(node):
                # we know from can_be_applied that all in edges are from
                # AccessNodes
                assert (isinstance(edge.src, nd.AccessNode)
                        and hasattr(sdfg, "_parent_onnx_model") and
                        edge.src.data in sdfg._parent_onnx_model.clean_weights)
                desc = copy.deepcopy(sdfg.arrays[edge.data.data])
                desc.transient = False
                sub_sdfg.add_datadesc('array_' + edge.dst_conn, desc)

                input_value = sdfg._parent_onnx_model.clean_weights[
                    edge.src.data]

                if len(input_value.shape) == 0:
                    inputs['array_' +
                           edge.dst_conn] = input_value.cpu().numpy()[()]
                else:
                    inputs['array_' + edge.dst_conn] = input_value.clone()

                access = sub_state.add_access('array_' + edge.dst_conn)
                sub_state.add_edge(
                    access, None, node_copy, edge.dst_conn,
                    sub_sdfg.make_array_memlet('array_' + edge.dst_conn))

            outputs = {}
            for edge in state.out_edges(node):
                desc = copy.deepcopy(sdfg.arrays[edge.data.data])
                if isinstance(desc, dt.Scalar):
                    # we need to copy to an array of size [1] so that we can
                    # "return" the output from the sdfg
                    desc.transient = True
                    sub_sdfg.add_datadesc('scalar_array_' + edge.src_conn,
                                          desc)
                    sub_sdfg.add_array('array_' + edge.src_conn, [1],
                                       desc.dtype,
                                       transient=False)

                    access_scalar = sub_state.add_access('scalar_array_' +
                                                         edge.src_conn)
                    access = sub_state.add_access('array_' + edge.src_conn)
                    sub_state.add_edge(
                        node_copy, edge.src_conn, access_scalar, None,
                        sub_sdfg.make_array_memlet('scalar_array_' +
                                                   edge.src_conn))
                    sub_state.add_edge(
                        access_scalar, None, access, None,
                        sub_sdfg.make_array_memlet('array_' + edge.src_conn))
                else:
                    desc.transient = False
                    sub_sdfg.add_datadesc('array_' + edge.src_conn, desc)
                    access = sub_state.add_access('array_' + edge.src_conn)
                    sub_state.add_edge(
                        node_copy, edge.src_conn, access, None,
                        sub_sdfg.make_array_memlet('array_' + edge.src_conn))

                if len(desc.shape) == 0:
                    empty_array = np.empty((1, ), desc.dtype.as_numpy_dtype())
                else:
                    empty_array = np.empty(tuple(desc.shape),
                                           desc.dtype.as_numpy_dtype())

                empty_array = torch.from_numpy(empty_array)

                if desc.storage is dtypes.StorageType.GPU_Global:
                    empty_array = empty_array.cuda()

                outputs['array_' + edge.src_conn] = empty_array

            sub_sdfg(**outputs, **inputs)

            for edge in state.out_edges(node):
                desc = copy.deepcopy(sdfg.arrays[edge.data.data])
                desc.transient = False
                output_value = outputs['array_' + edge.src_conn]

                constant_name = sdfg.temp_data_name()
                clean_constant_name = clean_onnx_name(constant_name)
                sdfg.add_datadesc(clean_constant_name, desc)

                assert constant_name not in parent.weights
                assert type(output_value) is torch.Tensor

                if not dtypes.can_access(dtypes.ScheduleType.CPU_Multicore,
                                         desc.storage):
                    # the result lives in memory the host cannot access
                    # directly; stage it through a CPU copy
                    cpu_desc = copy.deepcopy(desc)
                    cpu_desc.storage = dtypes.StorageType.CPU_Heap
                    cpu_desc.transient = False
                    desc.transient = True
                    copy_in_name = sdfg.temp_data_name()
                    clean_copy_in_name = clean_onnx_name(copy_in_name)
                    sdfg.add_datadesc(clean_copy_in_name, cpu_desc)

                    access_constant = state.add_access(clean_constant_name)
                    state.add_edge(state.add_read(clean_copy_in_name), None,
                                   access_constant, None,
                                   sdfg.make_array_memlet(clean_copy_in_name))

                    name_to_add = copy_in_name
                else:
                    access_constant = state.add_read(clean_constant_name)
                    name_to_add = constant_name

                if isinstance(desc, dt.Scalar):
                    parent.weights[name_to_add] = output_value.reshape(())
                else:
                    parent.weights[name_to_add] = output_value

                state.add_edge(access_constant, None, edge.dst, edge.dst_conn,
                               sdfg.make_array_memlet(clean_constant_name))

        # remove all now-useless nodes with a reverse BFS
        remove_node_and_computation(sdfg, state, node)

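# A minimal usage sketch (hypothetical): ConstantFolding only applies to SDFGs
# imported through an ONNXModel, since the importer attaches the
# _parent_onnx_model attribute that can_be_applied checks for.
sdfg.apply_transformations_repeated(ConstantFolding)
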
class ConstantFolding(transformation.Transformation):
    """ Remove nodes where all inputs are known and replace them with constant
        nodes by precomputing the output.
    """
    # pattern matching only checks that the type of the node matches
    _onnx_node = ONNXOp("_")

    @staticmethod
    def expressions():
        return [sdutil.node_path_graph(ConstantFolding._onnx_node)]

    @staticmethod
    def is_constant(sdfg: dace.SDFG, state: dace.SDFGState, node) -> bool:
        if len(state.in_edges(node)) > 0:
            return False

        # the ONNX importer adds a _parent_onnx_model attribute to the sdfg
        if isinstance(node, nd.AccessNode
                      ) and node.data in sdfg._parent_onnx_model.clean_weights:
            return True

        return False

    @staticmethod
    def can_be_applied(graph: dace.sdfg.graph.OrderedMultiDiConnectorGraph,
                       candidate: Dict[nd.Node, int],
                       expr_index: int,
                       sdfg,
                       strict: bool = False):
        node: ONNXOp = graph.nodes()[candidate[ConstantFolding._onnx_node]]

        # SDFG must be imported from an ONNXModel
        if not hasattr(sdfg, "_parent_onnx_model"):
            return False

        # never fold nondeterministic ops
        if 'ONNX' + node.schema.name in NONDETERMINISTIC_OPS:
            return False

        if isinstance(node, donnx.ONNXShape):
            return True

        # all inputs must be constant
        for edge in graph.in_edges(node):
            if not ConstantFolding.is_constant(sdfg, graph, edge.src):
                return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        node: ONNXOp = graph.nodes()[candidate[ConstantFolding._onnx_node]]
        return "Precompute outputs of {}".format(node)

    def apply(self, sdfg: dace.SDFG):
        # Extract the subgraph, execute it and insert an AccessNode to the
        # result
        parent: ONNXModel = sdfg._parent_onnx_model
        state = sdfg.nodes()[self.state_id]
        node = state.nodes()[self.subgraph[ConstantFolding._onnx_node]]

        if isinstance(node, donnx.ONNXShape):
            # if we have a shape node, replace it with a constant
            assert len(state.in_edges(node)) == 1
            shape_in_edge = state.in_edges(node)[0]
            assert shape_in_edge.dst_conn == "data"
            shape_desc = sdfg.arrays[shape_in_edge.src.data]

            constant_name = sdfg.temp_data_name()
            clean_constant_name = clean_onnx_name(constant_name)
            sdfg.add_array(clean_constant_name, (len(shape_desc.shape), ),
                           dace.int64)

            assert constant_name not in parent.clean_weights
            parent.weights[constant_name] = np.array(shape_desc.shape,
                                                     np.int64)

            assert len(state.out_edges(node)) == 1
            output_edge = state.out_edges(node)[0]
            access_shape = state.add_access(clean_constant_name)
            state.add_edge(access_shape, None, output_edge.dst,
                           output_edge.dst_conn,
                           sdfg.make_array_memlet(clean_constant_name))
        else:
            # otherwise compute the result of the op
            sub_sdfg = dace.SDFG("sub_sdfg")
            sub_state = sub_sdfg.add_state()

            node_copy = copy.deepcopy(node)
            sub_state.add_node(node_copy)

            inputs = {}
            for edge in state.in_edges(node):
                # we know from can_be_applied that all in edges are from
                # AccessNodes
                assert (isinstance(edge.src, nd.AccessNode)
                        and hasattr(sdfg, "_parent_onnx_model") and
                        edge.src.data in sdfg._parent_onnx_model.clean_weights)
                desc = copy.deepcopy(sdfg.arrays[edge.data.data])
                desc.transient = False
                sub_sdfg.add_datadesc('array_' + edge.dst_conn, desc)

                input_value = sdfg._parent_onnx_model.clean_weights[
                    edge.src.data]

                if len(input_value.shape) == 0:
                    inputs['array_' + edge.dst_conn] = input_value[()]
                else:
                    inputs['array_' + edge.dst_conn] = input_value.copy()

                access = sub_state.add_access('array_' + edge.dst_conn)
                sub_state.add_edge(
                    access, None, node_copy, edge.dst_conn,
                    sub_sdfg.make_array_memlet('array_' + edge.dst_conn))

            outputs = {}
            for edge in state.out_edges(node):
                desc = copy.deepcopy(sdfg.arrays[edge.data.data])
                if isinstance(desc, dt.Scalar):
                    # we need to copy to an array of size [1] so that we can
                    # "return" the output from the sdfg
                    desc.transient = True
                    sub_sdfg.add_datadesc('scalar_array_' + edge.src_conn,
                                          desc)
                    sub_sdfg.add_array('array_' + edge.src_conn, [1],
                                       desc.dtype,
                                       transient=False)

                    access_scalar = sub_state.add_access('scalar_array_' +
                                                         edge.src_conn)
                    access = sub_state.add_access('array_' + edge.src_conn)
                    sub_state.add_edge(
                        node_copy, edge.src_conn, access_scalar, None,
                        sub_sdfg.make_array_memlet('scalar_array_' +
                                                   edge.src_conn))
                    sub_state.add_edge(
                        access_scalar, None, access, None,
                        sub_sdfg.make_array_memlet('array_' + edge.src_conn))
                else:
                    desc.transient = False
                    sub_sdfg.add_datadesc('array_' + edge.src_conn, desc)
                    access = sub_state.add_access('array_' + edge.src_conn)
                    sub_state.add_edge(
                        node_copy, edge.src_conn, access, None,
                        sub_sdfg.make_array_memlet('array_' + edge.src_conn))

                if len(desc.shape) == 0:
                    outputs['array_' + edge.src_conn] = np.empty(
                        (1, ), desc.dtype.as_numpy_dtype())
                else:
                    outputs['array_' + edge.src_conn] = np.empty(
                        tuple(desc.shape), desc.dtype.as_numpy_dtype())

            sub_sdfg(**outputs, **inputs)

            for edge in state.out_edges(node):
                desc = copy.deepcopy(sdfg.arrays[edge.data.data])
                desc.transient = False
                output_value = outputs['array_' + edge.src_conn]

                constant_name = sdfg.temp_data_name()
                clean_constant_name = clean_onnx_name(constant_name)
                sdfg.add_datadesc(clean_constant_name, desc)

                assert constant_name not in parent.weights
                if isinstance(desc, dt.Scalar):
                    parent.weights[constant_name] = output_value.reshape(())
                else:
                    parent.weights[constant_name] = output_value

                access_constant = state.add_access(clean_constant_name)
                state.add_edge(access_constant, None, edge.dst, edge.dst_conn,
                               sdfg.make_array_memlet(clean_constant_name))

        # remove all now-useless nodes with a reverse BFS
        queue = deque([node])
        while len(queue) > 0:
            current_node = queue.popleft()
            edges = state.in_edges(current_node)
            state.remove_node(current_node)
            for e in edges:
                next_node = e.src
                if len(state.out_edges(next_node)) == 0:
                    queue.append(next_node)

def forward(node: onnx_op.ONNXOp, state: SDFGState,
            sdfg: SDFG) -> typing.Union[Node, SDFG]:
    node.validate(sdfg, state)

    A_desc = in_desc_with_name(node, state, sdfg, "A")
    B_desc = in_desc_with_name(node, state, sdfg, "B")
    Y_desc = out_desc_with_name(node, state, sdfg, "Y")
    input0_dim = A_desc.shape
    input1_dim = B_desc.shape

    # list containing letters from z to a
    letters = [chr(ord('z') - i) for i in range(26)]
    # i, j and k are reserved for the core matrix multiply dimensions
    letters = [l for l in letters if l not in ['i', 'j', 'k']]

    if len(input0_dim) == 1:
        if len(input1_dim) != 2:
            raise ValueError("invalid dimensions")
        arg1 = 'k'
        arg2 = 'kj'
        result = 'j'
    elif len(input1_dim) == 1:
        if len(input0_dim) != 2:
            raise ValueError("invalid dimensions")
        arg1 = 'ik'
        arg2 = 'k'
        result = 'i'
    else:
        # build the einsum. The last two dimensions are always just the matrix
        # multiply einsum; dace will later specialize to a batched matmul if
        # possible
        arg1 = 'ik'
        arg2 = 'kj'
        result = 'ij'
        if input0_dim[-1] != input1_dim[-2]:
            # the contraction dimensions disagree; if one of them is symbolic,
            # override it with the concrete value of the other
            if dace.symbolic.issymbolic(input0_dim[-1]):
                log.warning(
                    f"overriding symbol {input0_dim[-1]} with value {input1_dim[-2]} in descriptor of input A of node {node}"
                )
                new_shape = list(A_desc.shape)
                new_shape[-1] = input1_dim[-2]
                A_desc.shape = new_shape
            elif dace.symbolic.issymbolic(input1_dim[-2]):
                log.warning(
                    f"overriding symbol {input1_dim[-2]} with value {input0_dim[-1]} in descriptor of input B of node {node}"
                )
                new_shape = list(B_desc.shape)
                new_shape[-2] = input0_dim[-1]
                B_desc.shape = new_shape

        input0_dim = input0_dim[:-2]
        input1_dim = input1_dim[:-2]
        for dim0, dim1 in itertools.zip_longest(reversed(input0_dim),
                                                reversed(input1_dim)):
            if dim0 is None:
                # only dim1 exists
                letter = letters.pop()
                arg2 = letter + arg2
                result = letter + result
            elif dim1 is None:
                # only dim0 exists
                letter = letters.pop()
                arg1 = letter + arg1
                result = letter + result
            else:
                # both exist
                letter = letters.pop()
                arg1 = letter + arg1
                arg2 = letter + arg2
                result = letter + result

    einsum_str = '{},{}->{}'.format(arg1, arg2, result)

    # we lower to an ONNXEinsum node instead of straight to the dace einsum to
    # make the autodiff simpler
    nsdfg = dace.SDFG(node.label + "_expansion")
    nstate = nsdfg.add_state()
    einsum_node: nodes.LibraryNode = onnx_op.ONNXEinsum(
        node.label + "_einsum_expansion", equation=einsum_str)
    nstate.add_node(einsum_node)
    einsum_node.add_in_connector("Inputs__0")
    einsum_node.add_in_connector("Inputs__1")
    nsdfg.add_datadesc("A", copy.deepcopy(A_desc))
    nsdfg.add_datadesc("B", copy.deepcopy(B_desc))
    nsdfg.add_datadesc("Y", copy.deepcopy(Y_desc))
    nsdfg.arrays["A"].transient = False
    nsdfg.arrays["B"].transient = False
    nsdfg.arrays["Y"].transient = False
    nstate.add_edge(nstate.add_read("A"), None, einsum_node, "Inputs__0",
                    nsdfg.make_array_memlet("A"))
    nstate.add_edge(nstate.add_read("B"), None, einsum_node, "Inputs__1",
                    nsdfg.make_array_memlet("B"))
    nstate.add_edge(einsum_node, "Output", nstate.add_write("Y"), None,
                    nsdfg.make_array_memlet("Y"))

    return nsdfg

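# Worked example of the einsum construction above (hypothetical shapes): for
# A of shape (7, 2, 3, 4) and B of shape (4, 5), the core triple is
# arg1='ik', arg2='kj', result='ij'; A's two batch dimensions then prepend
# fresh letters ('a', then 'b') to arg1 and result only, giving
#
#     einsum_str == 'baik,kj->baij'
#
# which np.einsum can sanity-check:
import numpy as np

A = np.random.rand(7, 2, 3, 4)
B = np.random.rand(4, 5)
assert np.allclose(np.einsum('baik,kj->baij', A, B), A @ B)
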