def delete_input(graph: xpb2.GraphProto, name: str):
    """ Removes an existing input of a graph by name

    Every graph input whose name equals `name` is removed.

    Args:
        graph: A graph, onnx.onnx_ml_pb2.GraphProto.
        name: String, the name of the input as used to determine the graph topology.

    Returns:
        The graph with the input removed. False if the input could not be found or the
        inputs could not be iterated. If `graph` is not a GraphProto it is returned unchanged.
    """
    if type(graph) is not xpb2.GraphProto:
        _print("graph is not a valid ONNX graph.")
        return graph

    # Remove the named input. Iterate over a snapshot: removing from the repeated
    # field while iterating it directly can skip the element that follows.
    found = False
    try:
        for elem in list(graph.input):
            if elem.name == name:
                graph.input.remove(elem)
                found = True
    except Exception as e:
        _print("Unable to iterate the inputs. " + str(e))
        return False
    if not found:
        _print("Unable to find the input by name.")
        return False

    return graph
def add_input(graph: xpb2.GraphProto, name: str, data_type: str, dimensions: [], **kwargs):
    """ Add an input to a graph

    Args:
        graph: A graph, onnx.onnx_ml_pb2.GraphProto.
        name: String, the name of the input as used to determine the graph topology.
        data_type: String, the data type of the input. Run list_data_types() for an overview.
        dimensions: List[] specifying the dimensions of the input.
        **kwargs: Passed on to onnx.helper.make_tensor_value_info.

    Returns:
        The extended graph, or False if the graph is invalid or the input cannot be added.
    """
    if type(graph) is not xpb2.GraphProto:
        _print("graph is not a valid ONNX graph.")
        return False

    dtype = _data_type(data_type)
    if not dtype:
        return False

    try:
        # kwargs belong to make_tensor_value_info only; the repeated field's append()
        # accepts exactly one element, so they must not be forwarded to it.
        graph.input.append(xhelp.make_tensor_value_info(name, dtype, dimensions, **kwargs))
    except Exception as e:
        _print("Unable to add the input: " + str(e))
        return False
    return graph
def rename_input(graph, current_name, new_name):
    """ Rename an input to a graph

    Renames every graph input called `current_name` and rewires each node that
    consumes it, so the graph topology stays intact.

    Args:
        graph: A graph, onnx.onnx_ml_pb2.GraphProto.
        current_name: String, the current input name.
        new_name: String, the name desired input name.

    Returns:
        The changed graph, or False if the graph is invalid or no such input exists.
    """
    if type(graph) is not xpb2.GraphProto:
        _print("graph is not a valid ONNX graph.")
        return False

    matches = [inp for inp in graph.input if inp.name == current_name]
    if not matches:
        _print("Unable to find the input to rename.")
        return False
    for inp in matches:
        inp.name = new_name

    # Rewire every node that reads the renamed input.
    for nd in graph.node:
        for pos, edge in enumerate(nd.input):
            if edge == current_name:
                nd.input[pos] = new_name

    return graph
def rename_output(graph, current_name, new_name):
    """ Rename an output to a graph

    Renames the graph output called `current_name`. The producing node output is
    renamed as well, as is every node input that consumes it, so the graph stays
    connected when the output also feeds other nodes.

    Args:
        graph: A graph, onnx.onnx_ml_pb2.GraphProto.
        current_name: String, the current output name.
        new_name: String, the name desired output name.

    Returns:
        The changed graph, or False if the graph is invalid or no such output exists.
    """
    if type(graph) is not xpb2.GraphProto:
        _print("graph is not a valid ONNX graph.")
        return False

    found = False
    for output in graph.output:
        if output.name == current_name:
            output.name = new_name
            found = True
    if not found:
        _print("Unable to find the output by name.")
        return False

    for node in graph.node:
        # Rename the producer of the output...
        for index, name in enumerate(node.output):
            if name == current_name:
                node.output[index] = new_name
        # ...and any node that also consumes it; otherwise those consumers would
        # reference an edge that no longer exists.
        for index, name in enumerate(node.input):
            if name == current_name:
                node.input[index] = new_name

    return graph
def list_operators():
    """ List all available Scailable ONNX operators.

    Loads the version info JSON at glob.VERSION_INFO_LOCATION into
    glob.ONNX_VERSION_INFO and prints its 'operators' entry.

    Returns:
        True on success, False if the version info file cannot be found or parsed.
    """
    try:
        with open(glob.VERSION_INFO_LOCATION, "r", encoding="utf-8") as f:
            glob.ONNX_VERSION_INFO = json.load(f)
    except FileNotFoundError:
        # Use _print like the rest of the module so errors are reported consistently.
        _print("Unable to locate the ONNX_VERSION INFO.")
        return False
    except json.JSONDecodeError as e:
        _print("Unable to parse the ONNX_VERSION INFO: " + str(e))
        return False

    _print(json.dumps(glob.ONNX_VERSION_INFO['operators'], indent=2), "MSG")
    return True
def postfix_names(g: xpb2.GraphProto, postfix: str = "_g1", elem: str = "node"):
    """ postfix_names is a utility function used by concat() to rename parts of an onnx graph.

    When merging (or otherwise manipulating) onnx graphs it is often useful to create
    unique names of the various elements of the graph. This function postfixes each
    name in supplied graph g of elements of type elem by the supplied postfix.

    Args:
        g: The graph (modified in place and returned).
        postfix: (Optional) The postfix for the names of the elements. Default "_g1".
        elem: (Optional) The type of element. Options are "node", "init", "edge",
            "input", "output", "io", and "all". Default "node".

    Returns:
        The graph g with the selected names postfixed.
    """
    if elem == 'node':
        for item in g.node:
            item.name = item.name + postfix
        return g
    elif elem == 'init':
        for init in g.initializer:
            init.name = init.name + postfix
        return g
    elif elem == 'edge':
        for item in g.node:
            # An empty name marks an omitted optional input/output in ONNX. Postfixing
            # it would turn "absent" into a reference to a non-existent edge.
            for index, name in enumerate(item.input):
                if name:
                    item.input[index] = name + postfix
            for index, name in enumerate(item.output):
                if name:
                    item.output[index] = name + postfix
        return g
    elif elem == 'input':
        for item in g.input:
            item.name = item.name + postfix
        return g
    elif elem == 'output':
        for item in g.output:
            item.name = item.name + postfix
        return g
    elif elem == 'io':
        cg = postfix_names(g, postfix, "input")
        cg = postfix_names(cg, postfix, "output")
        return cg
    elif elem == 'all':
        cg = postfix_names(g, postfix, "node")
        cg = postfix_names(cg, postfix, "init")
        cg = postfix_names(cg, postfix, "edge")
        cg = postfix_names(cg, postfix, "input")
        cg = postfix_names(cg, postfix, "output")
        return cg
    else:
        _print("No names have been changed; did you select the right element?", "MSG")
        return g
def display(graph: xpb2.GraphProto, _tmpfile: str = '.tmp.onnx'):
    """ display a onnx graph using netron.

    Pass a graph to the display function to open it in Netron. The graph is written
    to `_tmpfile` and handed to the platform's file opener.

    Note: Due to the complexities of cross platform opening of source and the potential
    lack of a Netron installation this function might not always behave properly.
    Note2: The temporary file is not removed by this function (the viewer may still
    need it after this function returns).

    Args:
        graph: an ONNX graph
        _tmpfile: an optional string with the temporary file name. Default .tmp.onnx

    Returns:
        True if one of the 3 methods to open the file did not raise an exception,
        False otherwise (including when the graph could not be stored).
    """
    if type(graph) is not xpb2.GraphProto:
        _print("graph is not a valid ONNX graph.")
        return False

    # Store as tmpfile; there is nothing to open if this fails.
    if not graph_to_file(graph, _tmpfile):
        _print("Unable to store the graph for display.")
        return False

    # Try the openers in turn: xdg-open (unix), open (mac), os.startfile (windows).
    # A missing opener raises (FileNotFoundError / AttributeError), so fall through.
    openers = (
        lambda: subprocess.run(['xdg-open', _tmpfile]),
        lambda: subprocess.run(['open', _tmpfile]),
        lambda: os.startfile(_tmpfile),
    )
    for opener in openers:
        try:
            opener()
            return True
        except Exception:
            continue
    return False
def delete_node(graph: xpb2.GraphProto, node_name: str = "", **kwargs):
    """ Remove every node named `node_name` from the graph and return the graph.

    Prints a message and returns False if it fails.

    Args:
        graph: A graph, onnx.onnx_ml_pb2.GraphProto.
        node_name: Name of the node to remove.
        **kwargs: Accepted for backward compatibility; currently unused.

    Returns:
        The graph with the node removed, or False on failure.
    """
    if type(graph) is not xpb2.GraphProto:
        _print("The graph is not a valid ONNX graph.")
        return False

    if not node_name:
        _print("Please specify a node name.")
        return False

    # Iterate over a snapshot: removing from the repeated field while iterating
    # it directly can skip the element that follows.
    found = False
    try:
        for elem in list(graph.node):
            if elem.name == node_name:
                graph.node.remove(elem)
                found = True
    except Exception as e:
        _print("Unable to iterate the nodes. " + str(e))
        return False
    if not found:
        _print("Unable to find the node by name.")
        return False

    return graph
def replace_input(graph: xpb2.GraphProto, name: str, data_type: str, dimensions: [], **kwargs):
    """ Changes an existing input in a graph

    Removes the input(s) named `name` and appends a new input with the same name and
    the given type and dimensions. If no input with that name exists, a message is
    printed and the new input is still added.

    Args:
        graph: A graph, onnx.onnx_ml_pb2.GraphProto.
        name: String, the name of the input as used to determine the graph topology.
        data_type: String, the data type of the input. Run list_data_types() for an overview.
        dimensions: List[] specifying the dimensions of the input.
        **kwargs: Passed on to _value() when creating the new input.

    Returns:
        The changed graph, or False on failure. If `graph` is not a GraphProto it is
        returned unchanged.
    """
    if type(graph) is not xpb2.GraphProto:
        _print("graph is not a valid ONNX graph.")
        return graph

    # Remove the named input (iterate a snapshot; removing while iterating the
    # repeated field directly can skip elements).
    found = False
    try:
        for elem in list(graph.input):
            if elem.name == name:
                graph.input.remove(elem)
                found = True
    except Exception as e:
        _print("Unable to iterate the inputs. " + str(e))
        return False
    if not found:
        _print("Unable to find the input by name.")

    # Create the new value
    try:
        val = _value(name, data_type, dimensions, **kwargs)
    except Exception as e:
        _print("Unable to create value. " + str(e))
        return False

    # Add the value to the input. append() takes a single element; kwargs were
    # already consumed by _value() above.
    try:
        graph.input.append(val)
    except Exception as e:
        _print("Unable to add the input: " + str(e))
        return False

    return graph
def empty_graph(_default_name: str = "sclblgraph"):
    """ Create and return a new, empty ONNX graph.

    Note: an empty graph does not pass check(), since it has no inputs or outputs.

    Args:
        _default_name: Graph name, default sclblgraph

    Returns:
        An empty GraphProto, or False if it could not be constructed.
    """
    try:
        return xpb2.GraphProto(name=_default_name)
    except Exception as e:
        _print("Unable to create graph: " + str(e))
        return False
def graph_to_file(graph: xpb2.GraphProto, filename: str, _producer: str = "sclblonnx", onnx_opset_version=12, **kwargs):
    """ graph_to_file stores an onnx graph to a .onnx file

    Wraps the graph in a model and saves it to a file.

    Args:
        graph: An onnx graph
        filename: The filename of the resulting file
        _producer: Optional string with producer name. Default 'sclblonnx'
        onnx_opset_version: Optional version number for ONNX opset. Default 12.
            Ignored when 'opset_imports' is supplied in kwargs.
        **kwargs: Passed on to onnx.helper.make_model.

    Returns:
        True if successful, False otherwise.
    """
    if not filename:
        _print("Unable to save: Please specify a filename.")
        return False

    if type(graph) is not xpb2.GraphProto:
        _print("Unable to save: Graph is not an ONNX graph")
        return False

    try:
        if 'opset_imports' not in kwargs:
            op = onnx.OperatorSetIdProto()
            op.version = onnx_opset_version
            mod = xhelp.make_model(graph, producer_name=_producer, opset_imports=[op], **kwargs)
        else:
            mod = xhelp.make_model(graph, producer_name=_producer, **kwargs)
    except Exception as e:
        _print("Unable to convert graph to model: " + str(e))
        return False

    try:
        # kwargs are model fields consumed by make_model above; the save function
        # does not accept them, so they are not forwarded here.
        xsave(mod, filename)
    except Exception as e:
        _print("Unable to save the model: " + str(e))
        return False

    return True
def node(op_type: str, inputs: [], outputs: [], name: str = "", **kwargs):
    """ Create a new node

    When no name is given, a unique name "sclbl-onnx-node<N>" is generated from the
    global node counter, which is then incremented.

    Args:
        op_type: Operator type, see https://github.com/onnx/onnx/blob/master/docs/Operators.md
        inputs: [] list of inputs (names to determine the graph topology)
        outputs: [] list of outputs (names to determine the graph topology)
        name: The name of this node (Optional)
        **kwargs: Passed on to onnx.helper.make_node (node attributes).

    Returns:
        The new NodeProto, or False if it could not be created.
    """
    if not name:
        name = "sclbl-onnx-node" + str(glob.NODE_COUNT)
        glob.NODE_COUNT += 1

    try:
        created = xhelp.make_node(op_type, inputs, outputs, name, **kwargs)
    except Exception as e:
        _print("Unable to create node: " + str(e))
        return False

    return created
def list_inputs(graph: xpb2.GraphProto):
    """ Print a numbered overview of the inputs of a graph.

    Args:
        graph: the ONNX graph

    Returns:
        True after printing, or False if `graph` is not a GraphProto.
    """
    if type(graph) is not xpb2.GraphProto:
        _print("graph is not a valid ONNX graph.")
        return False

    count = 0
    for count, elem in enumerate(graph.input, start=1):
        name, dtype, shape = _parse_element(elem)
        print("Input {}: Name: '{}', Type: {}, Dimension: {}".format(count, name, dtype, shape))

    if count == 0:
        print("No inputs found.")

    return True
def add_node(graph: xpb2.GraphProto, node: xpb2.NodeProto, **kwargs):
    """ Add node appends a node to graph g and returns the extended graph

    Prints a message and returns False if it fails.

    Args:
        graph: A graph, onnx.onnx_ml_pb2.GraphProto.
        node: A node, onnx.onnx_ml_pb2.NodeProto.
        **kwargs: Accepted for backward compatibility; not used. The repeated field's
            append() takes exactly one element.

    Returns:
        The extended graph, or False on failure.
    """
    if type(graph) is not xpb2.GraphProto:
        _print("The graph is not a valid ONNX graph.")
        return False

    if type(node) is not xpb2.NodeProto:
        _print("The node is not a valid ONNX node.")
        return False

    try:
        graph.node.append(node)
    except Exception as e:
        _print("Unable to extend graph: " + str(e))
        return False

    return graph
def run(graph: xpb2.GraphProto, inputs: {}, outputs: [], _tmpfile: str = ".tmp.onnx", onnx_opset_version=12, **kwargs):
    """ run executes a give graph with the given input and returns the output

    The graph is stored to `_tmpfile`, loaded into an onnxruntime InferenceSession and
    run. The temporary file is removed afterwards, also when inference fails.

    Args:
        graph: The onnx graph
        inputs: an object with the named inputs; please check the data types
        outputs: list of named outputs
        _tmpfile: String the temporary filename for the onnx file to run.
        onnx_opset_version: Optional version number for ONNX opset. Default 12
        **kwargs: Passed on to onnxruntime.InferenceSession.

    Returns:
        The result (or False if it fails somewhere)
    """
    store = graph_to_file(graph, _tmpfile, onnx_opset_version=onnx_opset_version)
    if not store:
        _print("Unable to store model for evaluation.")
        return False

    try:
        sess = xrt.InferenceSession(_tmpfile, **kwargs)
        out = sess.run(outputs, inputs)
    except Exception as e:
        _print("Failed to run the model: " + str(e))
        return False
    finally:
        # Clean up the temporary model on both the success and the failure path.
        try:
            os.remove(_tmpfile)
        except OSError:
            _print("We were unable to delete the file " + _tmpfile, "MSG")

    return out
def graph_from_file(filename: str):
    """ Load a serialized ONNX model from disk and return its graph.

    Args:
        filename: String indicating the filename / relative location.

    Returns:
        The model's GraphProto, or False if the file cannot be read or parsed.
    """
    model = xmp()
    try:
        with open(filename, 'rb') as handle:
            model.ParseFromString(handle.read())
        return model.graph
    except Exception as e:
        _print("Unable to open the file: " + str(e))
        return False
def constant(name: str, value: np.array, data_type: str, **kwargs):
    """ Create a constant node

    Args:
        name: Name of the (output value of the) constant node to determine the graph topology
        value: Values of the node. A np.array, or anything np.asarray() accepts
            (lists, nested lists, scalars).
        data_type: Data type of the node
        **kwargs: Passed on to onnx.helper.make_node.

    Returns:
        A constant node, or False on failure.
    """
    if not name:
        _print("Unable to create unnamed constant.")
        return False

    dtype = _data_type(data_type)
    if not dtype:
        return False

    try:
        # Normalise array-likes so .shape / .flatten() are always available.
        value = np.asarray(value)
        constant_node = xhelp.make_node('Constant', inputs=[], outputs=[name], name=name + "-constant",
                                        value=xhelp.make_tensor(name=name + "-values", data_type=dtype,
                                                                dims=value.shape, vals=value.flatten()),
                                        **kwargs)
    except Exception as e:
        _print("Unable to create the constant node: " + str(e))
        return False

    return constant_node
def add_constant(graph: xpb2.GraphProto, name: str, value: np.array, data_type: str, **kwargs):
    """ Create and add a constant node to an existing graph.

    Note: use add_node() if you want to add an existing constant node to an existing graph

    Args:
        graph: A graph, onnx.onnx_ml_pb2.GraphProto.
        name: Name of the (output value of the) constant node to determine the graph topology
        value: Values of the node (as a np.array)
        data_type: Data type of the node
        **kwargs: Passed on to constant() / onnx.helper.make_node.

    Returns:
        The extended graph, or False on failure.
    """
    if type(graph) is not xpb2.GraphProto:
        _print("graph is not a valid ONNX graph.")
        return False

    # Reuse constant() instead of duplicating the node construction; it prints its
    # own error message on failure.
    constant_node = constant(name, value, data_type, **kwargs)
    if constant_node is False:
        return False

    try:
        # kwargs were meant for the node itself, so they are not forwarded to add_node.
        graph = add_node(graph, constant_node)
    except Exception as e:
        _print("Unable to add the constant node to the graph: " + str(e))
        return False

    if not graph:
        _print("Unable to add constant node to graph.")
        return False

    return graph
def test__print():
    # Smoke test: emit one message in each _print style (default warning, MSG, LIT).
    print("\n")
    _print("Red warning.")
    for message, style in (("Normal feedback", "MSG"), ("Green literal", "LIT")):
        _print(message, style)
def check(graph: xpb2.GraphProto, _producer: str = "sclblonnx", _onnx_check: bool = True, _sclbl_check: bool = True, _verbose: bool = True, onnx_opset_version: int = 12, **kwargs):
    """ check whether or not an existing graph can be converted using the Scailable platform

    We assume that a user will use graph_to_file() in this package to store the model.

    Args:
        graph: an ONNX graph
        _producer: String optional, producer name used when wrapping the graph in a model.
        _onnx_check: Bool, default True. Intended to run the ONNX checker.check_model();
            note that this branch is currently disabled in the code.
        _sclbl_check: Bool, default True. Run Scailable checks.
        _verbose: Print user feedback; default True (note, errors are always printed).
        onnx_opset_version: Opset version used when no 'opset_imports' is given in kwargs.
            Default 12.
        **kwargs: Passed on to onnx.helper.make_model.

    Returns:
        True if the graph passes all the test. False otherwise.
    """
    # Check if this is a valid graph:
    if type(graph) is not xpb2.GraphProto:
        _print("Graph is not a valid ONNX graph.")
        return False

    # Convert to model:
    try:
        if 'opset_imports' not in kwargs:
            op = onnx.OperatorSetIdProto()
            op.version = onnx_opset_version
            mod = xhelp.make_model(graph, producer_name=_producer, opset_imports=[op], **kwargs)
        else:
            mod = xhelp.make_model(graph, producer_name=_producer, **kwargs)
    except Exception as e:
        _print("Unable to create the model: " + str(e))
        return False

    # Standard ONNX checking.
    # NOTE(review): `and False` disables this branch unconditionally, so _onnx_check has
    # no effect. Kept as-is; confirm whether the checker should be re-enabled.
    if _onnx_check and False:
        try:
            checker.check_model(mod, **kwargs)
        except Exception as e:
            _print("Model fails on standard ONNX checker: " + str(e))
            return False

    if _sclbl_check:
        # User feedback
        _print("Running Scailable specific checks for WASM conversion. \nUse _sclbl_check=False to turn off", "MSG", (not _verbose))

        # input / output checking:
        if not graph.input:
            _print("This graph does not contain any inputs.")
            return False

        if not graph.output:
            _print("This graph does not contain any outputs.")
            return False

        # Sclbl checking:
        if not glob.ONNX_VERSION_INFO:
            if not _load_version_info():
                _print("Unable to load the ONNX_VERSION INFO.")
                # Without version info none of the checks below can run.
                return False

        info = glob.ONNX_VERSION_INFO['onnx_version']

        # Check general ONNX version:
        if version.parse(xversion) < version.parse(info['version_min']):
            _print("Your current onnx version is lower than our supported minimum. Please update your ONNX to {}".format(info['version_min']))
            return False

        if version.parse(xversion) > version.parse(info['version_max']):
            _print("Your current onnx version is higher than our supported max. Please downgrade your ONNX version to {}".format(info['version_max']))
            return False

        if mod.ir_version < info['ir_version_min']:
            _print("Your current IR version is lower than our supported minimum. Please update to {}".format(info['ir_version_min']))
            return False

        if mod.ir_version > info['ir_version_max']:
            _print("Your current IR version is higher than our supported max. Please downgrade to {}".format(info['ir_version_max']))
            return False

        # Iterate through opset and check:
        for key in mod.opset_import:
            v = key.version
            if v < info['opset_min']:
                _print("One or more operators use an opset version that is too low. Please update to {}".format(info['opset_min']))
                return False

            if v > info['opset_max']:
                _print("One or more operators use an opset version that is too high. Please downgrade to {}".format(info['opset_max']))
                return False

        # Check individual nodes:
        not_supported = []
        for n in graph.node:
            op_type = n.op_type
            if op_type not in glob.ONNX_VERSION_INFO['operators']:
                not_supported.append(op_type)

        if not_supported:
            _print("The operator(s) {} are currently not supported.".format(not_supported))
            return False

        # Check dynamic
        for inputs in graph.input:
            if not inputs.type.tensor_type.shape.dim:
                _print("Your graph contains dynamically sized inputs, this is currently not supported.")
                return False
            for elem in inputs.type.tensor_type.shape.dim:
                # dim_value is an integer field; 0 means the dimension has no fixed size.
                # NOTE(review): this is only a warning; the check does not fail here.
                if elem.dim_value == 0:
                    _print("Your graph contains dynamically sized inputs, this is currently not supported.")

    if not _sclbl_check and not _onnx_check:
        _print("Set _sclbl_check or _onnx_check to True to run any checks.")

    _print("Your graph was successfully checked.", "MSG", (not _verbose))
    return True
def clean(graph: xpb2.GraphProto, _optimize: bool = True, _simplify: bool = True, _remove_initializer: bool = True, _producer: str = "sclblonnx", _verbose: bool = True, onnx_opset_version: int = 12, **kwargs):
    """ clean cleans an ONNX graph using onnx tooling

    This method will attempt to clean the supplied graph by, in this order:
    a. Optimizing it using onnxoptimizer.optimize
    b. Simplifying it using onnxsim.simplify
    c. Removing initializers from input

    If the model cannot be built or steps a. or b. fail, the method prints an error
    message and returns the unaltered graph.

    Args:
        graph: An ONNX graph
        _optimize: Boolean, default True. Optimize the model using onnxoptimizer.
        _simplify: Boolean, default True. Simplify the model using simplify.
        _remove_initializer: Boolean, default True. Remove initializers from input.
        _producer: Optional string with producer name. Default 'sclblonnx' (used for internal conversion)
        _verbose: Print user feedback; default True (note, errors are always printed).
        onnx_opset_version: Opset version used when no 'opset_imports' is given in kwargs.
            Default 12.
        **kwargs: Passed on to onnx.helper.make_model only. optimize()/simplify() do not
            accept model fields, so forwarding them there would make those steps fail.

    Returns:
        The cleaned ONNX graph, or the old graph if an error occurs.
    """
    try:
        if 'opset_imports' not in kwargs:
            op = onnx.OperatorSetIdProto()
            op.version = onnx_opset_version
            mod = xhelp.make_model(graph, producer_name=_producer, opset_imports=[op], **kwargs)
        else:
            mod = xhelp.make_model(graph, producer_name=_producer, **kwargs)
    except Exception as e:
        _print("Unable to create the model: " + str(e))
        return graph

    if _optimize:
        try:
            mod = onnxoptimizer.optimize(mod, glob.OPTIMIZER_PASSES)
        except Exception as e:
            _print("Unable to optimize your model: " + str(e))
            return graph

    if _simplify:
        try:
            mod, _ = simplify(mod)
        except Exception as e:
            _print("Unable to simplify your model: " + str(e))
            return graph

    graph = mod.graph
    if _remove_initializer:
        # From: onnxruntime/tools/python/remove_initializer_from_input.py
        name_to_input = {inp.name: inp for inp in graph.input}
        for initializer in graph.initializer:
            # pop() so a duplicated initializer name does not try to remove twice.
            matched = name_to_input.pop(initializer.name, None)
            if matched is not None:
                graph.input.remove(matched)

    _print("The graph was successfully cleaned.", "MSG", (not _verbose))
    return graph
def sclbl_input(inputs: {}, example_type: str = "pb", _verbose: bool = True):
    """ input_str returns an example input for a Scailable runtime

    The method takes a valid input object to an onnx graph (i.e., one used for the
    "inputs" argument in the run() function), and returns and prints an example input
    to a Scailable runtime / REST endpoint.

    Args:
        inputs: The input object as supplied to the run() function to test an ONNX graph
        example_type: The type of example string ("raw" for base64 encoded raw bytes, or
            "pb"/"protobuf" for a base64 encoded TensorProto; default pb)
        _verbose: Print user feedback; default True (note, errors are always printed).

    Returns:
        An example input to a Scailable runtime, or False if no inputs are given or the
        example_type is unknown.
    """
    if not inputs:
        _print("No input provided.")
        return False

    # Pick the per-value serializer and the type tag for the JSON payload.
    if example_type == "raw":
        def serialize(val):
            return val.tobytes()
        type_tag = "raw"
    elif example_type in ("pb", "protobuf"):
        def serialize(val):
            return xnp.from_array(val).SerializeToString()
        type_tag = "pb"
    else:
        _print("Unknown example_type '{}'; please use \"raw\" or \"pb\".".format(example_type))
        return False

    encoded = [base64.b64encode(serialize(val)).decode('ascii') for val in inputs.values()]
    if len(encoded) == 1:
        value_str = '"' + encoded[0] + '"'
    else:
        value_str = '[' + ','.join('"' + item + '"' for item in encoded) + ']'
    input_json = '{"input": ' + value_str + ', "type":"' + type_tag + '"}'

    if _verbose:
        _print("The following input string can be used for the Scailable runtime:", "MSG")
        _print(input_json, "LIT")
        if type_tag == "pb":
            _print("The following input string can be used for the web front-end:", "MSG")
            _print(value_str, "LIT")

    return input_json
def concat(sg1: xpb2.GraphProto, sg2: xpb2.GraphProto, complete: bool = False, rename_nodes: bool = True, io_match: [] = None, rename_io: bool = False, edge_match: [] = None, rename_edges: bool = False, rename_init: bool = False, _verbose: bool = True, **kwargs):
    """ concat concatenates two graphs.

    Concat is the flexible (but also rather complex) workhorse for the merge, join, and
    split functions and can be used to quite flexibly paste together two (sub)graphs.
    Contrary to merge, join, and split, concat does not by default assume the resulting
    onnx graph to be complete (i.e., to contain inputs and outputs and to pass check()),
    and it can thus be used as an intermediate function when constructing larger graphs.

    Concat is flexible and versatile, but it takes time to master. See example_merge.py
    in the examples folder for a number of examples.

    Args:
        sg1: Subgraph 1, the parent.
        sg2: Subgraph 2, the child.
        complete: (Optional) Boolean indicating whether the resulting graph should be
            checked using so.check(). Default False.
        rename_nodes: (Optional) Boolean indicating whether the names of the nodes in the
            graph should be made unique. Default True.
        io_match: (Optional) List of pairs [("out1","in1"), ...] matching outputs of sg1
            to inputs of sg2. Default [].
        rename_io: (Optional) Boolean indicating whether the inputs and outputs of the
            graph should be renamed. Default False.
        edge_match: (Optional) List of pairs of edge names of sg1 (i.e., node outputs)
            that should be matched to edges of sg2 (i.e., node inputs). Default [].
        rename_edges: (Optional) Boolean indicating whether the edges should be renamed
            (default False)
        rename_init: (Optional) Boolean indicating whether the initializers should be
            renamed (default False)
        _verbose: (Optional) Boolean indicating whether verbose output should be printed
            (default True)
        **kwargs: Passed on to check() when complete=True.

    Returns:
        The concatenated graph g, or False if something goes wrong along the way.
    """
    # immutable defaults:
    if io_match is None:
        io_match = []
    if edge_match is None:
        edge_match = []

    # prevent changes to original
    sg1 = copy.deepcopy(sg1)
    sg2 = copy.deepcopy(sg2)

    # Check input types:
    if type(sg1) is not xpb2.GraphProto:
        _print("Graph sg1 is not an ONNX graph. Abort.")
        return False
    if type(sg2) is not xpb2.GraphProto:
        _print("Graph sg2 is not an ONNX graph. Abort.")
        return False

    # Rename node names if requested (default True)
    if rename_nodes:
        _print("Renaming node names in graph.", "MSG", (not _verbose))
        sg1 = postfix_names(sg1, "_sg1", "node")
        sg2 = postfix_names(sg2, "_sg2", "node")

    if io_match:
        _print("Matching specified inputs and outputs..", "MSG", (not _verbose))
        for io_pair in io_match:
            out_name, in_name = io_pair[0], io_pair[1]
            # The matched pair becomes an internal edge, so it is removed from the graph
            # interfaces. Test membership first rather than deleting from a repeated
            # field while iterating over it.
            if any(output.name == out_name for output in sg1.output):
                sg1 = delete_output(sg1, out_name)
            if any(inp.name == in_name for inp in sg2.input):
                sg2 = delete_input(sg2, in_name)
            if not sg1 or not sg2:
                _print("Unable to match output '{}' to input '{}'.".format(out_name, in_name))
                return False
            # Rewire the consumers in sg2 to read the matched output of sg1.
            for item in sg2.node:
                for index, name in enumerate(item.input):
                    if name == in_name:
                        item.input[index] = out_name

    if rename_io:
        _print("Renaming inputs and outputs.", "MSG", (not _verbose))
        sg1 = postfix_names(sg1, "_sg1", "io")
        sg2 = postfix_names(sg2, "_sg2", "io")

    if edge_match:
        _print("Matching edges.", "MSG", (not _verbose))
        for edge_pair in edge_match:
            for item in sg2.node:
                for index, name in enumerate(item.input):
                    if name == edge_pair[1]:
                        item.input[index] = edge_pair[0]

    if rename_edges:
        _print("Renaming edges.", "MSG", (not _verbose))
        sg1 = postfix_names(sg1, "_sg1", "edge")
        sg2 = postfix_names(sg2, "_sg2", "edge")

    if rename_init:
        _print("Renaming init.", "MSG", (not _verbose))
        sg1 = postfix_names(sg1, "_sg1", "init")
        sg2 = postfix_names(sg2, "_sg2", "init")

    # Paste graphs together:
    _print("Pasting graphs.", "MSG", (not _verbose))
    g = _paste_graphs(sg1, sg2)

    if complete:
        if not check(g, _verbose=_verbose, **kwargs):
            _print("The end result does not pass check(). Are you sure you want a complete result? Set complete=False "
                   "to continue concat without checking.")
            return False

    return g
def split(pg: xpb2.GraphProto, cg1: xpb2.GraphProto, cg2: xpb2.GraphProto, cg1_match: [] = None, cg2_match: [] = None, complete: bool = True, _verbose: bool = True, **kwargs):
    """ split takes a single parent and matches the outputs to the inputs of two children (cg1 & cg2)

    Split matches the outputs of pg to the inputs of cg1 and cg2 as specified in
    cg1_match and cg2_match. Desired matches are specified in pairs:
    [("out1","in1"), ("out2","in2"),...].

    Split by default assumes the resulting graph to be complete. Split is merely a
    wrapper around concat (used twice). For more flexible combinations of graphs please
    see concat().

    Note: ONNX concat operations might give unexpected results if names of elements
    collide, please use postfix_names() to prevent this (and always critically inspect
    the resulting graph).

    Args:
        pg: The parent graph
        cg1: The left child.
        cg2: The right child.
        cg1_match: (Optional) List of pairs matching outputs of pg to inputs of cg1. Default [].
        cg2_match: (Optional) List of pairs matching outputs of pg to inputs of cg2. Default [].
        complete: (Optional) Boolean indicating whether the resulting graph should be
            complete (i.e., should pass check). Default True.
        _verbose: (Optional) Boolean indicating whether or not verbose user feedback
            should be provided. Default True.
        **kwargs: Passed on to concat().

    Returns:
        The split graph g (or False if something fails along the way).
    """
    # immutable defaults:
    if cg1_match is None:
        cg1_match = []
    if cg2_match is None:
        cg2_match = []

    # prevent changes to original
    pg = copy.deepcopy(pg)
    cg1 = copy.deepcopy(cg1)
    cg2 = copy.deepcopy(cg2)

    if type(pg) is not xpb2.GraphProto:
        _print("Graph pg is not an ONNX graph.")
        return False
    if type(cg1) is not xpb2.GraphProto:
        _print("Graph cg1 is not an ONNX graph.")
        return False
    if type(cg2) is not xpb2.GraphProto:
        _print("Graph cg2 is not an ONNX graph.")
        return False

    # Create the split (using concat 2x)
    g1 = concat(pg, cg1, rename_nodes=True, io_match=cg1_match, complete=False, _verbose=False, **kwargs)
    if not g1:
        # Report the failing step instead of passing False on to the second concat.
        _print("Unable to concatenate pg and cg1. Please checkout concat() for additional options.")
        return False

    g = concat(g1, cg2, rename_nodes=True, io_match=cg2_match, complete=complete, _verbose=False, **kwargs)
    if not g:
        _print("Graph merge failed. Please checkout concat() for additional options.", "MSG", (not _verbose))

    return g
def merge(sg1: xpb2.GraphProto, sg2: xpb2.GraphProto, outputs: [] = None, inputs: [] = None, io_match: [] = None, complete: bool = True, _verbose: bool = True, **kwargs):
    """ merge merges two graphs.

    Given subgraph sg1 and subgraph sg2 merge attempts to link the identified outputs of
    sg1 to the inputs of sg2 resulting in a graph in which sg1 is the parent of sg2.

    Merge expects two complete graphs (i.e., it expects sg1 and sg2 to pass check()).
    If you would like more flexible merge options or partial merge please see the concat
    function (merge is merely a constrained wrapper around concat).

    Note: The args inputs and outputs are present for legacy reasons, we recommend using
    io_match directly.

    Args:
        sg1: Subgraph 1, the parent.
        sg2: Subgraph 2, the child.
        outputs: (Optional) A list of strings containing the names of the outputs of sg1
            that are matched to inputs (in order of the desired match).
        inputs: (Optional) A list of strings containing the names of the inputs of sg2 to
            which the outputs of sg1 are matched.
        io_match: (Optional) A list of names pairs [("out1","in1"), ("out2","in2"),...].
            This is an alternative for the inputs/outputs arguments.
        complete: (Optional) Boolean indicating whether the resulting graph should be
            complete (i.e., should pass check). Default True.
        _verbose: (Optional) Boolean indicating whether or not verbose user feedback
            should be provided. Default True. Also forwarded to concat().
        **kwargs: Passed on to concat().

    Returns:
        The merged graph g, or False (with a printed error message) if something is wrong.
    """
    # immutable defaults:
    if inputs is None:
        inputs = []
    if outputs is None:
        outputs = []
    if io_match is None:
        io_match = []

    # prevent changes to original
    sg1 = copy.deepcopy(sg1)
    sg2 = copy.deepcopy(sg2)

    # Check the inputs:
    if type(sg1) is not xpb2.GraphProto:
        _print("Graph sg1 is not an ONNX graph.")
        return False
    if type(sg2) is not xpb2.GraphProto:
        _print("Graph sg2 is not an ONNX graph.")
        return False
    if len(outputs) != len(inputs):
        _print("The number of outputs and inputs do not match.")
        return False
    if inputs and io_match:
        _print("Please use either the inputs/outputs arguments OR the io_match argument (not both).")
        return False

    # Construct IO pairs
    if inputs:
        _print("Constructing the io_match list from your input and output.", "MSG", (not _verbose))
        io_match = list(zip(outputs, inputs))

    # Use concat to do the merge; forward _verbose so concat's progress messages
    # respect the caller's choice.
    g = concat(sg1, sg2, io_match=io_match, complete=complete, _verbose=_verbose, **kwargs)
    if not g:
        _print("Graph merge failed. Please checkout concat for additional options.", "MSG", (not _verbose))

    return g
def join(pg1: xpb2.GraphProto, pg2: xpb2.GraphProto, cg: xpb2.GraphProto, pg1_match: [] = None, pg2_match: [] = None, complete: bool = True, _verbose: bool = True, **kwargs):
    """ join takes two parent graphs (pg1 & pg2) and merges them with a child graph cg.

    Join matches the outputs of pg1 to the inputs of cg specified in pg1_match, and
    similarly for pg2 and pg2_match. Desired matches are specified in pairs:
    [("out1","in1"), ("out2","in2"),...].

    Join by default assumes the resulting joined graph to be complete. Join is merely a
    wrapper around concat (used twice). For more flexible combinations of graphs please
    see concat().

    Note: ONNX concat operations might give unexpected results if names of elements
    collide, please use postfix_names() to prevent this (and always critically inspect
    the resulting graph).

    Args:
        pg1: Parent graph 1.
        pg2: Parent graph 2.
        cg: Child graph, the graph that will join together pg1 and pg2.
        pg1_match: (Optional) List of pairs matching outputs of pg1 to inputs of cg. Default [].
        pg2_match: (Optional) List of pairs matching outputs of pg2 to inputs of cg. Default [].
        complete: (Optional) Boolean indicating whether the resulting graph should be
            complete (i.e., should pass check). Default True.
        _verbose: (Optional) Boolean indicating whether or not verbose user feedback
            should be provided. Default True.
        **kwargs: Passed on to concat().

    Returns:
        The joined graph g (or False if something fails along the way).
    """
    # immutable defaults:
    if pg1_match is None:
        pg1_match = []
    if pg2_match is None:
        pg2_match = []

    # prevent changes to original
    pg1 = copy.deepcopy(pg1)
    pg2 = copy.deepcopy(pg2)
    cg = copy.deepcopy(cg)

    if type(pg1) is not xpb2.GraphProto:
        _print("Graph pg1 is not an ONNX graph.")
        return False
    if type(pg2) is not xpb2.GraphProto:
        _print("Graph pg2 is not an ONNX graph.")
        return False
    if type(cg) is not xpb2.GraphProto:
        _print("Graph cg is not an ONNX graph.")
        return False

    # Construct the match list as a new list; extending pg1_match in place would
    # mutate the caller's argument.
    io_match = list(pg1_match) + list(pg2_match)

    # Do the joint (2x concat)
    g1 = concat(pg1, pg2, rename_nodes=True, complete=False, _verbose=False, **kwargs)
    g = concat(g1, cg, rename_nodes=True, io_match=io_match, complete=complete, _verbose=False, **kwargs)
    if not g:
        _print("Graph merge failed. Please checkout concat for additional options.", "MSG", (not _verbose))

    return g
def list_data_types():
    """ Print every supported data type, followed by a note on string support.

    Returns:
        True.
    """
    listing = json.dumps(glob.DATA_TYPES, indent=2)
    _print(listing, "MSG")
    _print("Note: STRINGS are not supported at this time.", "LIT")
    return True