Exemple #1
0
def delete_input(graph: xpb2.GraphProto, name: str):
    """ Removes an existing input of a graph by name

    Args:
        graph: A graph, onnx.onnx_ml_pb2.GraphProto.
        name: String, the name of the input as used to determine the graph topology.

    Returns:
        The extended graph.

    """
    if type(graph) is not xpb2.GraphProto:
        return graph

    # Remove the named output
    found = False
    try:
        for elem in graph.input:
            if elem.name == name:
                graph.input.remove(elem)
                found = True
    except Exception as e:
        _print("Unable to iterate the inputs. " + str(e))
        return False
    if not found:
        _print("Unable to find the input by name.")
        return False

    return graph
Exemple #2
0
def add_input(graph: xpb2.GraphProto, name: str, data_type: str,
              dimensions: [], **kwargs):
    """ Add an input to a graph

    Args:
        graph: A graph, onnx.onnx_ml_pb2.GraphProto.
        name: String, the name of the input as used to determine the graph topology.
        data_type: String, the data type of the input. Run list_data_types() for an overview.
        dimensions: List[] specifying the dimensions of the input.
        **kwargs

    Returns:
        The extended graph.

    """
    if type(graph) is not xpb2.GraphProto:
        _print("graph is not a valid ONNX graph.")
        return False

    dtype = _data_type(data_type)
    if not dtype:
        return False

    try:
        graph.input.append(
            xhelp.make_tensor_value_info(name, dtype, dimensions, **kwargs),
            *kwargs)
    except Exception as e:
        _print("Unable to add the input: " + str(e))
        return False
    return graph
Exemple #3
0
def rename_input(graph, current_name, new_name):
    """ Rename an input to a graph

    Args:
        graph: A graph, onnx.onnx_ml_pb2.GraphProto.
        current_name: String, the current input name.
        new_name: String, the name desired input name.

    Returns:
        The changed graph.
    """
    if type(graph) is not xpb2.GraphProto:
        _print("graph is not a valid ONNX graph.")
        return False

    found = False
    for input in graph.input:
        if input.name == current_name:
            input.name = new_name
            found = True
    if not found:
        _print("Unable to find the input to rename.")
        return False

    # And rename it in every nodes that takes this as input:
    for node in graph.node:
        for index, name in enumerate(node.input):
            if name == current_name:
                node.input[index] = new_name

    return graph
Exemple #4
0
def rename_output(graph, current_name, new_name):
    """ Rename an output to a graph

    Args:
        graph: A graph, onnx.onnx_ml_pb2.GraphProto.
        current_name: String, the current output name.
        new_name: String, the name desired output name.

    Returns:
        The changed graph.
    """
    if type(graph) is not xpb2.GraphProto:
        _print("graph is not a valid ONNX graph.")
        return False

    found = False
    for output in graph.output:
        if output.name == current_name:
            output.name = new_name
            found = True
    if not found:
        _print("Unable to found the output by name.")
        return False

    for node in graph.node:
        for index, name in enumerate(node.output):
            if name == current_name:
                node.output[index] = new_name

    return graph
Exemple #5
0
def list_operators():
    """ List all available Scailable ONNX operators. """
    try:
        with open(glob.VERSION_INFO_LOCATION, "r") as f:
            glob.ONNX_VERSION_INFO = json.load(f)
    except FileNotFoundError:
        print("Unable to locate the ONNX_VERSION INFO.")
        return False
    _print(json.dumps(glob.ONNX_VERSION_INFO['operators'], indent=2), "MSG")
    return True
Exemple #6
0
def postfix_names(g: xpb2.GraphProto,
                  postfix: str = "_g1",
                  elem: str = "node"):
    """
    postfix_names is a utility function used by concat() to rename parts of an onnx graph.

    When merging (or otherwise manipulating) onnx graphs it is often useful to create unique names of the
    various elements of the graph. This function postfixes each name in supplied graph g of elements of type elem
    by the supplied postfix.

    Args:
        g: The graph
        postfix: (Optional) The postfix for the names of the elements. Default "_g1".
        elem: (Optional) The type of element. Options are "node", "init", "edge", "input", "output", "io", and "all". Default "node".
    """
    if elem == 'node':
        for item in g.node:
            item.name = item.name + postfix
        return g
    elif elem == 'init':
        for init in g.initializer:
            init.name = init.name + postfix
        return g
    elif elem == 'edge':
        for init in g.node:
            for index, name in enumerate(init.input):
                init.input[index] = init.input[index] + postfix
            for index, name in enumerate(init.output):
                init.output[index] = init.output[index] + postfix
        return g
    elif elem == 'input':
        for item in g.input:
            item.name = item.name + postfix
        return g
    elif elem == 'output':
        for item in g.output:
            item.name = item.name + postfix
        return g
    elif elem == 'io':
        cg = postfix_names(g, postfix, "input")
        cg = postfix_names(cg, postfix, "output")
        return cg
    elif elem == 'all':
        cg = postfix_names(g, postfix, "node")
        cg = postfix_names(cg, postfix, "init")
        cg = postfix_names(cg, postfix, "edge")
        cg = postfix_names(cg, postfix, "input")
        cg = postfix_names(cg, postfix, "output")
        return cg
    else:
        _print("No names have been changed; did you select the right element?",
               "MSG")

    return g
Exemple #7
0
def display(graph: xpb2.GraphProto, _tmpfile: str = '.tmp.onnx'):
    """ display a onnx graph using netron.

    Pass a graph to the display function to open it in Netron.
    Note: Due to the complexities of cross platform opening of source and the potential lack of
    a Netron installation this function might not always behave properly.
    Note2: This function might leave a file called .temp.onnx if it fails to remove the file.

    Args:
        graph: an ONNX graph
        _tmpfile: an optional string with the temporary file name. Default .tmp.onnx

    Returns:
        True if one of the 3 methods to open the file did not raise any warnings.

    Raises:
        SclblONNXError
    """
    if type(graph) is not xpb2.GraphProto:
        _print("graph is not a valid ONNX graph.")
        return False

    # store as tmpfile
    graph_to_file(graph, _tmpfile)

    file_open = False
    # try open on unix:
    if not file_open:
        try:
            subprocess.run(['xdg-open', _tmpfile])
            file_open = True
        except Exception:
            file_open = False

    # try open on mac:
    if not file_open:
        try:
            subprocess.run(['open', _tmpfile])
            file_open = True
        except Exception:
            file_open = False

    # try open on windows:
    if not file_open:
        try:
            os.startfile(_tmpfile)
            file_open = True
        except Exception:
            file_open = False

    # Result:
    return file_open
Exemple #8
0
def delete_node(graph: xpb2.GraphProto, node_name: str = "", **kwargs):
    """ Add node appends a node to graph g and returns the extended graph

    Prints a message and returns False if fails.

    Args:
        graph: A graph, onnx.onnx_ml_pb2.GraphProto.
        node_name: Name of the node to remove.
        **kwargs

    Returns:
        The extended graph.
    """
    if type(graph) is not xpb2.GraphProto:
        _print("The graph is not a valid ONNX graph.")
        return False

    if not node_name:
        _print("Please specify a node name.")
        return False

    found = False
    try:
        for elem in graph.node:
            if elem.name == node_name:
                graph.node.remove(elem)
                found = True
    except Exception as e:
        _print("Unable to iterate the nodes. " + str(e))
        return False
    if not found:
        _print("Unable to find the node by name.")
        return False

    return graph
Exemple #9
0
def replace_input(graph: xpb2.GraphProto, name: str, data_type: str,
                  dimensions: [], **kwargs):
    """ Changes an existing input in a graph

    Args:
        graph: A graph, onnx.onnx_ml_pb2.GraphProto.
        name: String, the name of the input as used to determine the graph topology.
        data_type: String, the data type of the input. Run list_data_types() for an overview.
        dimensions: List[] specifying the dimensions of the input.,
        **kwargs

    Returns:
        The extended graph.

    """
    if type(graph) is not xpb2.GraphProto:
        _print("graph is not a valid ONNX graph.")
        return graph

    # Remove the named input
    found = False
    try:
        for elem in graph.input:
            if elem.name == name:
                graph.input.remove(elem)
                found = True
    except Exception as e:
        _print("Unable to iterate the inputs. " + str(e))
        return False
    if not found:
        _print("Unable to find the input by name.")

    # Create the new value
    try:
        val = _value(name, data_type, dimensions, **kwargs)
    except Exception as e:
        _print("Unable to create value. " + str(e))
        return False

    # Add the value to the input
    try:
        graph.input.append(val, *kwargs)
    except Exception as e:
        _print("Unable to add the input: " + str(e))
        return False

    return graph
Exemple #10
0
def empty_graph(_default_name: str = "sclblgraph"):
    """ empty_graph returns an empty graph

    Note, an empty graph does not pass the check() as it does not contain input and output.

    Args:
        _default_name: Graph name, default sclblgraph

    Returns:
        An empty graph
    """
    try:
        graph = xpb2.GraphProto(name=_default_name)
    except Exception as e:
        _print("Unable to create graph: " + str(e))
        return False
    return graph
Exemple #11
0
def graph_to_file(graph: xpb2.GraphProto,
                  filename: str,
                  _producer: str = "sclblonnx",
                  onnx_opset_version=12,
                  **kwargs):
    """ graph_to_file stores an onnx graph to a .onnx file

    Stores a graph to a file

    Args:
        graph: An onnx graph
        filename: The filename of the resulting file
        _producer: Optional string with producer name. Default 'sclblonnx'
        onnx_opset_version: Optional version number for ONNX opset. Default 12
    Returns:
        True if successful, False otherwise.
    """
    if not filename:
        _print("Unable to save: Please specify a filename.")
        return False

    if type(graph) is not xpb2.GraphProto:
        _print("Unable to save: Graph is not an ONNX graph")

    try:
        if not 'opset_imports' in kwargs:
            op = onnx.OperatorSetIdProto()
            op.version = onnx_opset_version
            mod = xhelp.make_model(graph,
                                   producer_name=_producer,
                                   opset_imports=[op],
                                   **kwargs)
        else:
            mod = xhelp.make_model(graph, producer_name=_producer, **kwargs)
    except Exception as e:
        print("Unable to convert graph to model: " + str(e))
        return False

    try:
        xsave(mod, filename, **kwargs)
    except Exception as e:
        print("Unable to save the model: " + str(e))
        return False

    return True
Exemple #12
0
def node(op_type: str, inputs: [], outputs: [], name: str = "", **kwargs):
    """ Create a new node

    Args:
        op_type: Operator type, see https://github.com/onnx/onnx/blob/master/docs/Operators.md
        inputs: [] list of inputs (names to determine the graph topology)
        outputs: [] list of outputs (names to determine the graph topology)
        name: The name of this node (Optional)
        **kwargs
    """
    if not name:
        name = "sclbl-onnx-node" + str(glob.NODE_COUNT)
        glob.NODE_COUNT += 1

    try:
        node = xhelp.make_node(op_type, inputs, outputs, name, **kwargs)
    except Exception as e:
        _print("Unable to create node: " + str(e))
        return False
    return node
Exemple #13
0
def list_inputs(graph: xpb2.GraphProto):
    """ Tries to list the inputs of a given graph.

    Args:
        graph the ONNX graph
    """
    if type(graph) is not xpb2.GraphProto:
        _print("graph is not a valid ONNX graph.")
        return False

    i = 1
    for elem in graph.input:
        name, dtype, shape = _parse_element(elem)
        print("Input {}: Name: '{}', Type: {}, Dimension: {}".format(
            i, name, dtype, shape))
        i += 1

    if i == 1:
        print("No inputs found.")

    return True
Exemple #14
0
def add_node(graph: xpb2.GraphProto, node: xpb2.NodeProto, **kwargs):
    """ Add node appends a node to graph g and returns the extended graph

    Prints a message and returns False if fails.

    Args:
        graph: A graph, onnx.onnx_ml_pb2.GraphProto.
        node: A node, onnx.onnx_ml_pb2.NodeProto.
        **kwargs

    Returns:
        The extended graph.
    """
    if type(graph) is not xpb2.GraphProto:
        _print("The graph is not a valid ONNX graph.")
        return False

    if type(node) is not xpb2.NodeProto:
        _print("The node is not a valid ONNX node.")
        return False

    try:
        graph.node.append(node, **kwargs)
    except Exception as e:
        _print("Unable to extend graph: " + str(e))
        return False
    return graph
Exemple #15
0
def run(graph: xpb2.GraphProto,
        inputs: {},
        outputs: [],
        _tmpfile: str = ".tmp.onnx",
        onnx_opset_version=12,
        **kwargs):
    """ run executes a give graph with the given input and returns the output

    Args:
        graph: The onnx graph
        inputs: an object with the named inputs; please check the data types
        outputs: list of named outputs
        _tmpfile: String the temporary filename for the onnx file to run.
        onnx_opset_version: Optional version number for ONNX opset. Default 12
        
    Returns:
        The result (or False if it fails somewhere)
        """

    store = graph_to_file(graph,
                          _tmpfile,
                          onnx_opset_version=onnx_opset_version)
    if not store:
        _print("Unable to store model for evaluation.")
        return False

    try:
        sess = xrt.InferenceSession(_tmpfile, **kwargs)
        out = sess.run(outputs, inputs)
    except Exception as e:
        _print("Failed to run the model: " + str(e))
        return False

    try:
        os.remove(_tmpfile)
    except Exception as e:
        print("We were unable to delete the file " + _tmpfile, "MSG")

    return out
Exemple #16
0
def graph_from_file(filename: str):
    """ Retrieve a graph object from an onnx file

    Function attempts to open a .onnx file and returns its graph.

    Args:
        filename: String indicating the filename / relative location.

    Returns:
        An ONNX graph or False if unable to open.

    """
    mod_temp = xmp()
    try:
        with open(filename, 'rb') as fid:
            content = fid.read()
            mod_temp.ParseFromString(content)
        graph = mod_temp.graph
    except Exception as e:
        _print("Unable to open the file: " + str(e))
        return False
    return graph
Exemple #17
0
def constant(name: str, value: np.array, data_type: str, **kwargs):
    """ Create a constant node

    Args:
        name: Name of the (output value of the) constant node to determine the graph topology
        value: Values of the node (as a np.array)
        data_type: Data type of the node
        **kwargs

    Returns:
        A constant node.
    """
    if not name:
        _print("Unable to create unnamed constant.")
        return False

    dtype = _data_type(data_type)
    if not dtype:
        return False

    try:
        constant_node = xhelp.make_node('Constant',
                                        inputs=[],
                                        outputs=[name],
                                        name=name + "-constant",
                                        value=xhelp.make_tensor(
                                            name=name + "-values",
                                            data_type=dtype,
                                            dims=value.shape,
                                            vals=value.flatten()),
                                        **kwargs)
    except Exception as e:
        _print("Unable to create the constant node: " + str(e))
        return False

    return constant_node
Exemple #18
0
def add_constant(graph: xpb2.GraphProto, name: str, value: np.array,
                 data_type: str, **kwargs):
    """ Create and add a constant node to an existing graph.

    Note: use add_node() if you want to add an existing constant node to an existing graph

    Args:
        graph: A graph, onnx.onnx_ml_pb2.GraphProto.
        name: Name of the (output value of the) constant node to determine the graph topology
        value: Values of the node (as a np.array)
        data_type: Data type of the node

    Returns:
        The extended graph.
    """
    if type(graph) is not xpb2.GraphProto:
        print("graph is not a valid ONNX graph.")
        return False

    dtype = _data_type(data_type)
    if not dtype:
        return False

    try:
        constant_node = xhelp.make_node('Constant',
                                        inputs=[],
                                        outputs=[name],
                                        name=name + "-constant",
                                        value=xhelp.make_tensor(
                                            name=name + "-values",
                                            data_type=dtype,
                                            dims=value.shape,
                                            vals=value.flatten()),
                                        **kwargs)
    except Exception as e:
        _print("Unable to create the constant node: " + str(e))
        return False

    try:
        graph = add_node(graph, constant_node, **kwargs)
    except Exception as e:
        _print("Unable to add the constant node to the graph: " + str(e))
        return False

    if not graph:
        _print("Unable to add constant node to graph.")
        return False
    return graph
Exemple #19
0
def test__print():
    print("\n")
    _print("Red warning.")
    _print("Normal feedback", "MSG")
    _print("Green literal", "LIT")
    pass
Exemple #20
0
def check(graph: xpb2.GraphProto,
          _producer: str = "sclblonnx",
          _onnx_check: bool = True,
          _sclbl_check: bool = True,
          _verbose: bool = True,
          **kwargs):
    """ check whether or not an existing graph can be converted using the Scailable platform

    We assume that a user will use graph_to_file() in this package to store the model. This

     Args:
        graph: an ONNX graph
        _producer: String optional
        _onnx_check: Bool, default True. Run ONNX checker.check().
        _sclbl_check: Bool, default True.  Run Scailable checks.
        _verbose: Print user feedback; default True (note, errors are always printed).
        **kwargs

    Returns:
        True if the graph passes all the test. False otherwise.
    """
    # Check if this is a valid graph:
    if type(graph) is not xpb2.GraphProto:
        _print("Graph is not a valid ONNX graph.")
        return False

    # Convert to model:
    try:
        if not 'opset_imports' in kwargs:
            op = onnx.OperatorSetIdProto()
            op.version = 12
            mod = xhelp.make_model(graph,
                                   producer_name=_producer,
                                   opset_imports=[op],
                                   **kwargs)
        else:
            mod = xhelp.make_model(graph, producer_name=_producer, **kwargs)
    except Exception as e:
        _print("Unable to create the model: " + str(e))
        return False

    # Standard ONNX checking:
    if _onnx_check and False:
        try:
            checker.check_model(mod, **kwargs)
        except Exception as e:
            _print("Model fails on standard ONNX checker: " + str(e))
            return False

    if _sclbl_check:

        # User feedback
        _print(
            "Running Scailable specific checks for WASM conversion. \nUse _sclbl_check=False to turn off",
            "MSG", (not _verbose))

        # input / output checking:
        if not graph.input:
            _print("This graph does not contain any inputs.")
            return False

        if not graph.output:
            _print("This graph does not contain any outputs.")
            return False

        # Sclbl checking:
        if not glob.ONNX_VERSION_INFO:
            if not _load_version_info():
                _print("Unable to load the ONNX_VERSION INFO.")

        # Check general ONNX version:
        if version.parse(xversion) < version.parse(
                glob.ONNX_VERSION_INFO['onnx_version']['version_min']):
            _print(
                "Your current onnx version is lower then our support minimum. Please update your ONNX to {}"
                .format(glob.ONNX_VERSION_INFO['onnx_version']['version_min']))
            return False

        if version.parse(xversion) > version.parse(
                glob.ONNX_VERSION_INFO['onnx_version']['version_max']):
            _print(
                "Your current onnx version is higher then our support max. Please downgrade your ONNX version to {}"
                .format(glob.ONNX_VERSION_INFO['onnx_version']['version_max']))
            return False

        if mod.ir_version < glob.ONNX_VERSION_INFO['onnx_version'][
                'ir_version_min']:
            _print(
                "Your current IR version is lower then our support minimum. Please update to {}"
                .format(
                    glob.ONNX_VERSION_INFO['onnx_version']['ir_version_min']))
            return False

        if mod.ir_version > glob.ONNX_VERSION_INFO['onnx_version'][
                'ir_version_max']:
            _print(
                "Your current IR version is higher then our support max. Please downgrade to {}"
                .format(
                    glob.ONNX_VERSION_INFO['onnx_version']['ir_version_max']))
            return False

        # Interate through opset and check:
        for key in mod.opset_import:
            v = key.version
            if v < glob.ONNX_VERSION_INFO['onnx_version']['opset_min']:
                _print(
                    "One or more operators use an opset version that is too low. Please update to {}"
                    .format(
                        glob.ONNX_VERSION_INFO['onnx_version']['opset_min']))
                return False

            if v > glob.ONNX_VERSION_INFO['onnx_version']['opset_max']:
                _print(
                    "One or more operators use an opset version that is too high. Please downgrade to {}"
                    .format(
                        glob.ONNX_VERSION_INFO['onnx_version']['opset_max']))
                return False

        # Check individual nodes:
        not_supported = []
        for n in graph.node:
            op = n.op_type
            if op not in glob.ONNX_VERSION_INFO['operators']:
                not_supported.append(op)
        if not_supported:
            _print("The operator(s) {} are currently not supported.".format(
                not_supported))
            return False

        # Check dynamic
        for inputs in graph.input:
            if not inputs.type.tensor_type.shape.dim:
                _print(
                    "Your graph contains dynamically sized inputs, this is currently not supported."
                )
                return False
            for elem in inputs.type.tensor_type.shape.dim:
                if elem.dim_value == 0 or elem.dim_value == "":
                    _print(
                        "Your graph contains dynamically size inputs, this is currently not supported."
                    )

    if not _sclbl_check and not _onnx_check:
        _print("Set _sclbl_check or _onnx_check to True to run any checks.")

    _print("Your graph was successfully checked.", "MSG", (not _verbose))
    return True
Exemple #21
0
def clean(graph: xpb2.GraphProto,
          _optimize: bool = True,
          _simplify: bool = True,
          _remove_initializer: bool = True,
          _producer: str = "sclblonnx",
          _verbose: bool = True,
          **kwargs):
    """ clean cleans an ONNX graph using onnx tooling

    This method will attempt to clean the supplied graph by
    a. Removing initializers from input
    b. Optimizing it using onnxoptimizer.optimize
    c. Simplifying it using onnxsim.simplify

    If one of these fails the method will print an error message and return the unaltered graph.

    Args:
        graph: An ONNX graph
        _optimize: Boolean, default True. Optimize the model using onnxoptimizer.
        _simplify: Boolean, default True. Simplify the model using simplify.
        _remove_initializer: Boolean, default True. Remove initializers from input.
        _producer: Optional string with producer name. Default 'sclblonnx' (used for internal conversion)
        _verbose: Print user feedback; default True (note, errors are always printed).
        **kwargs

    Returns:
        The cleaned ONNX graph, or the old graph if an error occurs.
    """
    try:
        if not 'opset_imports' in kwargs:
            op = onnx.OperatorSetIdProto()
            op.version = 12
            mod = xhelp.make_model(graph,
                                   producer_name=_producer,
                                   opset_imports=[op],
                                   **kwargs)
        else:
            mod = xhelp.make_model(graph, producer_name=_producer, **kwargs)
    except Exception as e:
        _print("Unable to create the model: " + str(e))
        return graph

    if _optimize:
        try:
            mod = onnxoptimizer.optimize(mod, glob.OPTIMIZER_PASSES, **kwargs)
        except Exception as e:
            _print("Unable to optimize your model: " + str(e))
            return graph

    if _simplify:
        try:
            mod, _ = simplify(mod, **kwargs)
        except Exception as e:
            _print("Unable to simplify your model: " + str(e))
            return graph

    # From: onnxruntime/tools/python/remove_initializer_from_input.py
    graph = mod.graph
    if _remove_initializer:
        inputs = graph.input
        name_to_input = {}
        for input in inputs:
            name_to_input[input.name] = input
        for initializer in graph.initializer:
            if initializer.name in name_to_input:
                inputs.remove(name_to_input[initializer.name])

    _print("The graph was successfully cleaned.", "MSG", (not _verbose))
    return graph
Exemple #22
0
def sclbl_input(inputs: {}, example_type: str = "pb", _verbose: bool = True):
    """ input_str returns an example input for a Scailable runtime

    The method takes a valid input object to an onnx graph (i.e., one used for the "inputs" argument
    in the run() function, and returns and prints an example input to a Scailable runtime / REST endpoint

    Args:
        inputs: The input object as supplied to the run() function to test an ONNX grph
        example_type: The type of example string ("raw" for base64 encoded, or "pb" for protobuf, default pb)
        _verbose: Print user feedback; default True (note, errors are always printed).

    Returns:
        An example input to a Scailable runtime.
    """
    if not inputs:
        _print("No input provided.")

    if example_type == "raw":
        if len(inputs) == 1:
            for val in inputs.values():
                bytes = val.tobytes()
                encoded = base64.b64encode(bytes)
                value_str = '"' + encoded.decode('ascii') + '"'
        else:
            value_str = '["'
            for val in inputs.values():
                bytes = val.tobytes()
                encoded = base64.b64encode(bytes)
                value_str += (encoded.decode('ascii') + '","')
            value_str = value_str.rstrip(',"')
            value_str += '"]'

        input_json = '{"input": ' + value_str + ', "type":"raw"}'
        if _verbose:
            _print(
                "The following input string can be used for the Scailable runtime:",
                "MSG")
            _print(input_json, "LIT")
        return input_json

    elif example_type == "pb" or "protobuf":

        if len(inputs) == 1:
            for val in inputs.values():
                tensor = xnp.from_array(val)
                serialized = tensor.SerializeToString()
                encoded = base64.b64encode(serialized)
                value_str = '"' + encoded.decode('ascii') + '"'
        else:
            value_str = '["'
            for val in inputs.values():
                tensor = xnp.from_array(val)
                serialized = tensor.SerializeToString()
                encoded = base64.b64encode(serialized)
                value_str += (encoded.decode('ascii') + '","')
            value_str = value_str.rstrip(',"')
            value_str += '"]'

        input_json = '{"input": ' + value_str + ', "type":"pb"}'
        if _verbose:
            _print(
                "The following input string can be used for the Scailable runtime:",
                "MSG")
            _print(input_json, "LIT")
            _print(
                "The following input string can be used for the web front-end:",
                "MSG")
            _print(value_str, "LIT")
        return input_json
Exemple #23
0
def concat(sg1: xpb2.GraphProto,
           sg2: xpb2.GraphProto,
           complete: bool = False,
           rename_nodes: bool = True,
           io_match: [] = None,
           rename_io: bool = False,
           edge_match: [] = None,
           rename_edges: bool = False,
           rename_init: bool = False,
           _verbose: bool = True,
           **kwargs):
    """
    concat concatenates two graphs.

    Concat is the flexible (but also rather complex) workhorse for the merge, join, and split functions and
    can be used to quite flexibly paste together two (sub)graphs. Contrary to merge, join, and split, concat
    does not by default assume the resulting onnx graph to be complete (i.e., to contain inputs and outputs and to
    pass check()), and it can thus be used as an intermediate function when constructing larger graphs.

    Concat is flexible and versatile, but it takes time to master. See example_merge.py in the examples folder
    for a number of examples.

    Args:
        sg1: Subgraph 1, the parent.
        sg2: Subgraph 2, the child.
        complete: (Optional) Boolean indicating whether the resulting graph should be checked using so.check(). Default False.
        rename_nodes: (Optional) Boolean indicating whether the names of the nodes in the graph should be made unique. Default True.
        io_match: (Optional) Dict containing pairs of outputs of sg1 that should be matched to inputs of sg2. Default [].
        rename_io: (Optional) Boolean indicating whether the inputs and outputs of the graph should be renamed. Default False.
        edge_match: (Optional) Dict containing pairs edge names of sg1 (i.e., node outputs) that should be matched to edges of sg2 (i.e., node inputs). Default [].
        rename_edges: (Optional) Boolean indicating whether the edges should be renamed (default False)
        _verbose: (Optional) Boolean indicating whether verbose output should be printed (default False)
    Returns:
        The concatenated graph g, or False if something goes wrong along the way.
    """
    # immutable defaults:
    if io_match is None:
        io_match = []
    if edge_match is None:
        edge_match = []

    # prevent changes to original
    sg1 = copy.deepcopy(sg1)
    sg2 = copy.deepcopy(sg2)

    # Check input types:
    if type(sg1) is not xpb2.GraphProto:
        _print("Graph sg1 is not an ONNX graph. Abort.")
        return False
    if type(sg2) is not xpb2.GraphProto:
        _print("Graph sg2 is not an ONNX graph. Abort.")
        return False

    # Rename node names if requested (default True)
    if rename_nodes:
        _print("Renaming node names in graph.", "MSG", (not _verbose))
        sg1 = postfix_names(sg1, "_sg1", "node")
        sg2 = postfix_names(sg2, "_sg2", "node")

    if io_match:
        _print("Matching specified inputs and outputs..", "MSG",
               (not _verbose))
        for io_pair in io_match:
            for outputs in sg1.output:
                if outputs.name == io_pair[0]:
                    sg1 = delete_output(sg1, io_pair[0])
            for inputs in sg2.input:
                if inputs.name == io_pair[1]:
                    sg2 = delete_input(sg2, io_pair[1])
            for item in sg2.node:
                for index, name in enumerate(item.input):
                    if name == io_pair[1]:
                        item.input[index] = io_pair[0]

    if rename_io:
        _print("Renaming inputs and outputs.", "MSG", (not _verbose))
        sg1 = postfix_names(sg1, "_sg1", "io")
        sg2 = postfix_names(sg2, "_sg2", "io")

    if edge_match:
        _print("Matching edges.", "MSG", (not _verbose))
        for edge_pair in edge_match:
            for item in sg2.node:
                for index, name in enumerate(item.input):
                    if name == edge_pair[1]:
                        item.input[index] = edge_pair[0]

    if rename_edges:
        _print("Renaming edges.", "MSG", (not _verbose))
        sg1 = postfix_names(sg1, "_sg1", "edge")
        sg2 = postfix_names(sg2, "_sg2", "edge")

    if rename_init:
        _print("Renaming init.", "MSG", (not _verbose))
        sg1 = postfix_names(sg1, "_sg1", "init")
        sg2 = postfix_names(sg2, "_sg2", "init")

    # Paste graphs together:
    _print("Pasting graphs.", "MSG", (not _verbose))
    g = _paste_graphs(sg1, sg2)

    if complete:
        if not check(g, _verbose=_verbose, **kwargs):
            _print(
                "The end result does not pass check(). Are you sure you want a complete result? Set complete=False "
                "to continue concat without checking.")
            return False

    return g
Exemple #24
0
def split(pg: xpb2.GraphProto,
          cg1: xpb2.GraphProto,
          cg2: xpb2.GraphProto,
          cg1_match: [] = None,
          cg2_match: [] = None,
          complete: bool = True,
          _verbose: bool = True,
          **kwargs):
    """
    split takes takes a single parent and matches the outputs to the inputs of two childs (cg1 & cg2)

    Split matches the outputs of pg to the inputs of cg1 and cg2 as specified in cg1_match and cg2_match.
    Desired matches are specified in pairs: [("out1","in1"), ("out2","in2"),...].

    Split by default assumes the resulting joined graph to be complete. Split is merely a wrapper around concat (used
    twice). For more flexible combinations of graphs please see concat().

    Note: ONNX concat operations might give unexpected results if names of elements collide, please use postfix_names()
    to prevent this (and always critically inspect the resulting graph).

    Args:
        pg: The parent graph
        cg1: The left child.
        cg2: The right child.
        cg1_match: (Optional) List of pairs matching outputs of pg to inputs of cg1. Default [].
        cg2_match: (Optional) List of pairs matching outputs of pg to inputs of cg2. Default [].
        complete: (Optional) Boolean indicating whether the resulting graph should be complete (i.e., should pass check). Default True.
        _verbose: (Optional) Boolean indicating whether or not verbose user feedback should be provided. Default True.
    Returns:
        The joined graph g (of False is something fails along the way).
    """
    # immutable defaults:
    if cg1_match is None:
        cg1_match = []
    if cg2_match is None:
        cg2_match = []

    # prevent changes to original
    pg = copy.deepcopy(pg)
    cg1 = copy.deepcopy(cg1)
    cg2 = copy.deepcopy(cg2)

    if type(pg) is not xpb2.GraphProto:
        _print("Graph pg is not an ONNX graph.")
        return False
    if type(cg1) is not xpb2.GraphProto:
        _print("Graph cg1 is not an ONNX graph.")
        return False
    if type(cg2) is not xpb2.GraphProto:
        _print("Graph cg2 is not an ONNX graph.")
        return False

    # Create the split (using concat 2x)
    g1 = concat(pg,
                cg1,
                rename_nodes=True,
                io_match=cg1_match,
                complete=False,
                _verbose=False,
                **kwargs)
    g = concat(g1,
               cg2,
               rename_nodes=True,
               io_match=cg2_match,
               complete=complete,
               _verbose=False,
               **kwargs)
    if not g:
        _print(
            "Graph merge failed. Please checkout concat() for additional options.",
            "MSG", (not _verbose))

    return g
Exemple #25
0
def merge(sg1: xpb2.GraphProto,
          sg2: xpb2.GraphProto,
          outputs: [] = None,
          inputs: [] = None,
          io_match: [] = None,
          complete: bool = True,
          _verbose: bool = True,
          **kwargs):
    """
    merge merges two graphs.

    Given subgraph sg1 and subgraph sg2 merge attempts to link the identified outputs of sg1 to the
    inputs of sg2 resulting in a graph in which sg1 is the parent of sg2.

    Merge expects two complete graphs (i.e., it expects sg1 and sg2 to pass check(). If you would like more
    flexible merge options or partial merge please see the concat function (merge is merely a constrained wrapper
    around concat).

    Note: The args inputs and outputs are present for legacy reasons, we recommend using io_match directly.

    Args:
        sg1: Subgraph 1, the parent.
        sg2: Subgraph 2, the child.
        outputs: (Optional) A list of strings containing the names of the outputs of sg1 that are matched to inputs (in order of the desired match).
        inputs: (Optional) A list of strings containing the names of the inputs of sg2 to which the outputs of sg1 are matched.
        io_match: (Optional) A list of names pairs [("out1","in1"), ("out2","in2"),...]. This is an alternative for the inputs/outputs arguments.
        complete: (Optional) Boolean indicating whether the resulting graph should be complete (i.e., should pass check). Default True.
        _verbose: (Optional) Boolean indicating whether or not verbose user feedback should be provided. Default True.
    Returns:
        The merged graph g, or False (with a printed error message) if something is wrong.
    """
    # immutable defaults:
    if inputs is None:
        inputs = []
    if outputs is None:
        outputs = []
    if io_match is None:
        io_match = []

    # prevent changes to original
    sg1 = copy.deepcopy(sg1)
    sg2 = copy.deepcopy(sg2)

    # Check the inputs:
    if type(sg1) is not xpb2.GraphProto:
        _print("Graph sg1 is not an ONNX graph.")
        return False
    if type(sg2) is not xpb2.GraphProto:
        _print("Graph sg2 is not an ONNX graph.")
        return False
    if len(outputs) != len(inputs):
        _print("The number of outputs and inputs do not match.")
        return False
    if len(inputs) > 0 and len(io_match) > 0:
        _print(
            "Please use either the inputs/outputs arguments OR the io_match argument (not both)."
        )
        return False

    # Construct IO pairs
    if len(inputs) > 0:
        _print("Constructing the io_match list from your input and output.",
               "MSG", (not _verbose))
        io_match = []
        for idx, val in enumerate(outputs):
            io_match.append((val, inputs[idx]))

    # Use concat to do the merge
    g = concat(sg1, sg2, io_match=io_match, complete=complete, **kwargs)
    if not g:
        _print(
            "Graph merge failed. Please checkout concat for additional options.",
            "MSG", (not _verbose))

    return g
Exemple #26
0
def join(pg1: xpb2.GraphProto,
         pg2: xpb2.GraphProto,
         cg: xpb2.GraphProto,
         pg1_match: [] = None,
         pg2_match: [] = None,
         complete: bool = True,
         _verbose: bool = True,
         **kwargs):
    """
    join takes two parent graphs (pg1 & pg2) and merges them with a child graph cg.

    Join matches the outputs of pg1 to the inputs of cg specified in pg1_match, and similarly for pg2 and pg2_match.
    Desired matches are specified in pairs: [("out1","in1"), ("out2","in2"),...].

    Join by default assumes the resulting joined graph to be complete. Join is merely a wrapper around concat (used
    twice). For more flexible combinations of graphs please see concat().

    Note: ONNX concat operations might give unexpected results if names of elements collide, please use postfix_names()
    to prevent this (and always critically inspect the resulting graph).

    Args:
        pg1: Parent graph 1.
        pg2: Parent graph 2.
        cg: Child graph, the graph that will join together pg1 and pg2.
        pg1_match: (Optional) List of pairs matching outputs of pg1 to inputs of cg. Default [].
        pg2_match: (Optional) List of pairs matching outputs of pg2 to inputs of cg. Default [].
        complete: (Optional) Boolean indicating whether the resulting graph should be complete (i.e., should pass check). Default True.
        _verbose: (Optional) Boolean indicating whether or not verbose user feedback should be provided. Default True.
    Returns:
        The joined graph g (of False is something fails along the way).
    """
    # immutable defaults:
    if pg1_match is None:
        pg1_match = []
    if pg2_match is None:
        pg2_match = []

    # prevent changes to original
    pg1 = copy.deepcopy(pg1)
    pg2 = copy.deepcopy(pg2)
    cg = copy.deepcopy(cg)

    if type(pg1) is not xpb2.GraphProto:
        _print("Graph pg1 is not an ONNX graph.")
        return False
    if type(pg2) is not xpb2.GraphProto:
        _print("Graph pg2 is not an ONNX graph.")
        return False
    if type(cg) is not xpb2.GraphProto:
        _print("Graph cg is not an ONNX graph.")
        return False

    # Construct the match list
    io_match = pg1_match
    io_match.extend(pg2_match)

    # Do the joint (2x concat)
    g1 = concat(pg1,
                pg2,
                rename_nodes=True,
                complete=False,
                _verbose=False,
                **kwargs)
    g = concat(g1,
               cg,
               rename_nodes=True,
               io_match=io_match,
               complete=complete,
               _verbose=False,
               **kwargs)
    if not g:
        _print(
            "Graph merge failed. Please checkout concat for additional options.",
            "MSG", (not _verbose))

    return g
Exemple #27
0
def list_data_types():
    """ List all available data types. """
    _print(json.dumps(glob.DATA_TYPES, indent=2), "MSG")
    _print("Note: STRINGS are not supported at this time.", "LIT")
    return True