示例#1
0
 def test_remove_input(self):
     """Removing the second input of the Add node n4 drops the n3 -> n4 edge."""
     graph = Graph(self.sample_net().node, output_shapes={}, dtypes={})
     add_node = graph.get_node_by_name("n4")
     graph.remove_input(add_node, add_node.input[1])
     rendered = onnx_to_graphviz(graph)
     expected = ('digraph { n1 [op_type=Abs] n2 [op_type=Abs] n3 [op_type=Abs] n4 [op_type=Add] '
                 'n5 [op_type=Abs] n6 [op_type=Identity] input -> n1 n1:0 -> n2 n1:0 -> n3 n2:0 -> n4 '
                 'n4:0 -> n5 n5:0 -> n6 }')
     self.assertEqual(expected, rendered)
示例#2
0
def graphs_from_tf(tf_graph,
                   input_names,
                   output_names,
                   shape_override=None,
                   const_node_values=None,
                   ignore_default=None,
                   use_default=None):
    """Build tf2onnx internal graphs from a tensorflow graph and its functions.

    Every function returned by ``resolve_functions(tf_graph)`` becomes a subgraph
    named after the function, using the function's own input/output tensor names.
    ``tf_graph`` itself becomes the main graph with the caller-supplied IO names;
    unless it is a function, its shapes are inferred first. Constants are folded
    with tensorflow on every resulting Graph.

    Returns:
        tuple (main_graph, subgraphs), subgraphs in ``resolve_functions`` order.
    """
    shape_override = {} if shape_override is None else shape_override

    def convert(graph):
        # Fold constants with tf, then translate the same graph into onnx nodes.
        # NOTE(review): the main graph's output_names is used here even for
        # function subgraphs — confirm that is intended.
        folded_values, _ = compute_const_folding_using_tf(
            graph, const_node_values, output_names)
        nodes, _, _, shapes, types, _ = tensorflow_to_onnx(
            graph, shape_override, const_node_values, ignore_default, use_default)
        return folded_values, nodes, shapes, types

    subgraphs = []
    for func in resolve_functions(tf_graph):
        func_inputs = [t.name for t in func.inputs]
        func_outputs = [t.name for t in func.outputs]
        folded_values, nodes, shapes, types = convert(func)
        sub_g = Graph(nodes,
                      shapes,
                      types,
                      input_names=func_inputs,
                      output_names=func_outputs,
                      is_subgraph=True,
                      graph_name=func.name)
        fold_constants_using_tf(sub_g, folded_values)
        subgraphs.append(sub_g)

    if not is_function(tf_graph):
        tf_graph = infer_shape(tf_graph, shape_override)

    folded_values, nodes, shapes, types = convert(tf_graph)
    utils.check_io(input_names, output_names, shapes.keys())
    main_g = Graph(nodes,
                   shapes,
                   types,
                   input_names=input_names,
                   output_names=output_names)
    fold_constants_using_tf(main_g, folded_values)
    return main_g, subgraphs
示例#3
0
def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=None,
                     opset=None, custom_op_handlers=None, custom_rewriter=None):
    """Convert tensorflow graph to onnx graph.
        Args:
            tf_graph: tensorflow graph
            continue_on_error: if an op can't be processed (aka there is no mapping), continue;
                failures of the topological sort are then tolerated as well
            verbose: print summary stats
            target: list of workarounds applied to help certain platforms
            opset: the opset to be used (int, default is latest)
            custom_op_handlers: dictionary of custom ops handlers
            custom_rewriter: list of custom graph rewriters
        Return:
            onnx graph
    """
    def topological_sort(graph, ops):
        # Sort ``ops`` of ``graph``. When continuing on error, a failing sort
        # (e.g. a graph cycle) is ignored so that all missing ops can still be
        # reported. Only ``Exception`` is caught so Ctrl-C / SystemExit still work.
        if not continue_on_error:
            graph.topological_sort(ops)
            return
        try:
            graph.topological_sort(ops)
        except Exception as ex:  # pylint: disable=broad-except
            if verbose:
                print("topological sort failed, continuing: {}".format(ex))

    if target is None:
        target = DEFAULT_TARGET
    if custom_op_handlers is None:
        custom_op_handlers = {}

    onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes = tensorflow_to_onnx(tf_graph)

    g = Graph(onnx_nodes, output_shapes, dtypes, target, opset)
    ops = g.get_nodes()

    # rewrite graph: each rewriter receives the node list produced by the previous one
    rewriters = [rewrite_transpose, rewrite_flatten, rewrite_random_uniform,
                 rewrite_random_normal, rewrite_dropout]
    if custom_rewriter is not None:
        rewriters.extend(custom_rewriter)
    for rewrite in rewriters:
        ops = rewrite(g, ops)
        g.set_nodes(ops)
    topological_sort(g, g.get_nodes())

    mapped_op, unmapped_op = tensorflow_onnx_mapping(g, continue_on_error, custom_op_handlers)
    topological_sort(g, g.get_nodes())
    g.update_proto()
    if verbose:
        print("tensorflow ops: {}".format(op_cnt))
        print("tensorflow attr: {}".format(attr_cnt))
        print("onnx mapped: {}".format(mapped_op))
        print("onnx unmapped: {}".format(unmapped_op))
    return g
示例#4
0
 def test_insert_node1(self):
     """Inserting Abs node n7 on n2's input from n1 places it between n1 and n2."""
     graph = Graph(self.sample_net().node, output_shapes={}, dtypes={})
     consumer = graph.get_node_by_name("n2")
     inserted = graph.insert_new_node_on_input(consumer, "Abs", "n1:0", name="n7")
     all_ops = graph.get_nodes()
     all_ops.append(inserted)
     graph.topological_sort(all_ops)
     rendered = onnx_to_graphviz(graph)
     expected = ('digraph { n1 [op_type=Abs] n7 [op_type=Abs] n2 [op_type=Abs] n3 [op_type=Abs] '
                 'n4 [op_type=Add] n5 [op_type=Abs] n6 [op_type=Identity] '
                 'input -> n1 n1:0 -> n7 n7:0 -> n2 n1:0 -> n3 n2:0 -> n4 n3:0 -> n4 n4:0 -> n5 n5:0 -> n6 }')
     self.assertEqual(expected, rendered)
示例#5
0
def graphs_from_tflite(tflite_path, input_names=None, output_names=None):
    """
    Convert a tflite model file into graph.py Graph objects.

    The last subgraph read from the model is treated as the main graph; every
    other one becomes a subgraph (``is_subgraph=True``) and is parsed with a
    ``<subgraph name>_`` prefix handed to ``parse_tflite_graph``. Interpreter
    tensor shapes are only supplied when parsing the main graph. The main
    graph's inputs/outputs come from the model unless ``input_names`` /
    ``output_names`` override them; both are first passed to
    ``utils.check_io`` together with the main graph's tensor names.

    Returns:
        tuple (main_graph, subgraphs)
    """
    tflite_graphs, opcodes, model, tensor_shapes = read_tflite_model(tflite_path)
    main_index = len(tflite_graphs) - 1
    main_g = None
    subgraphs = []
    for index, tfl_graph in enumerate(tflite_graphs):
        is_main_g = index == main_index
        if is_main_g:
            prefix, interpreter_shapes = '', tensor_shapes
        else:
            prefix, interpreter_shapes = tfl_graph.Name().decode() + '_', None

        parsed = parse_tflite_graph(tfl_graph, opcodes, model, prefix, interpreter_shapes)
        onnx_nodes, _, _, output_shapes, dtypes, g_inputs, g_outputs, graph_name = parsed

        if is_main_g:
            # Override IO in main graph
            utils.check_io(input_names, output_names, output_shapes.keys())
            if input_names is not None:
                g_inputs = input_names
            if output_names is not None:
                g_outputs = output_names

        graph = Graph(onnx_nodes, output_shapes, dtypes,
                      input_names=g_inputs,
                      output_names=g_outputs,
                      is_subgraph=not is_main_g,
                      graph_name=graph_name)
        if is_main_g:
            main_g = graph
        else:
            subgraphs.append(graph)
    return main_g, subgraphs
示例#6
0
def load_graph(fname):
    """Load a serialized ONNX model into a tf2onnx Graph.

    Nodes with an empty name get a generated one. Every initializer is
    registered on the graph with its dtype and shape. Graph inputs that are not
    initializers are declared as graph inputs. Graph outputs get their
    dtype/shape recorded and are appended to the graph's output names.

    Args:
        fname: path of the serialized onnx.ModelProto to read.

    Returns:
        tuple (graph, producer_name of the model).
    """
    def tensor_shape(value_info):
        # Check the dimension itself (dim_value is a oneof member) so that
        # concrete dims keep their value and unset/symbolic dims become -1.
        return [dim.dim_value if dim.HasField("dim_value") else -1
                for dim in value_info.type.tensor_type.shape.dim]

    # Only reading needs the file handle; parse after it is closed.
    with open(fname, "rb") as f:
        data = f.read()

    model_proto = onnx.ModelProto()
    model_proto.ParseFromString(data)
    onnx_nodes = model_proto.graph.node
    output_names = []

    # some pytorch model had empty names - make one up
    for node in onnx_nodes:
        if not node.name:
            node.name = tf2onnx.utils.make_name("was_empty")

    # NOTE(review): output_names is filled in after the Graph is built, so this
    # presumably relies on Graph keeping a reference to the same list — confirm.
    g = Graph(onnx_nodes, output_shapes={}, dtypes={}, output_names=output_names)
    for init in model_proto.graph.initializer:
        value = numpy_helper.to_array(init)
        name = init.name
        g.initializers[name] = init
        g.set_dtype(name, init.data_type)
        g.set_shape(name, value.shape)
    for inp in model_proto.graph.input:
        name = inp.name
        if name in g.initializers:
            # ignore if it is not a model input
            continue
        shape = tensor_shape(inp)
        dtype = inp.type.tensor_type.elem_type
        g.set_dtype(name, dtype)
        g.set_shape(name, shape)
        g.add_graph_input(name, dtype, shape)
    for out in model_proto.graph.output:
        name = out.name
        g.set_dtype(name, out.type.tensor_type.elem_type)
        g.set_shape(name, tensor_shape(out))
        output_names.append(name)

    # TODO: this is a hack in case a output name does not follow tensorflow convention
    for node in g.get_nodes():
        for name in node.output:
            g._nodes_by_name[name] = node  # pylint: disable=protected-access

    return g, model_proto.producer_name
示例#7
0
def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=None,
                     opset=None, custom_op_handlers=None, custom_rewriter=None,
                     extra_opset=None, shape_override=None, inputs_as_nchw=None,
                     input_names=None, output_names=None, is_subgraph=False):
    """Convert tensorflow graph to onnx graph.
        Args:
            tf_graph: tensorflow graph
            continue_on_error: if an op can't be processed (aka there is no mapping), continue
            verbose: print summary stats (deprecated)
            target: list of workarounds applied to help certain platforms
            opset: the opset to be used (int, default is latest)
            custom_op_handlers: dictionary of custom ops handlers
            custom_rewriter: list of custom graph rewriters
            extra_opset: list of extra opset's, for example the opset's used by custom ops
            shape_override: dict with inputs that override the shapes given by tensorflow
            inputs_as_nchw: transpose inputs in list from nchw to nhwc
            input_names: list of input node names in graph, input name format as node_name:port_id
            output_names: list of output node names in graph, output name format as node_name:port_id
            is_subgraph: True when converting a tensorflow function as a subgraph; skips the
                version logging and the recursive conversion of functions
        Return:
            onnx graph
        Raises:
            ValueError: if a name in input_names/output_names is not a tensor of tf_graph
    """
    if verbose:
        logger.warning("Argument verbose for process_tf_graph is deprecated. Please use --verbose option instead.")
    del verbose

    opset = utils.find_opset(opset)
    if not is_subgraph:
        logger.info("Using tensorflow=%s, onnx=%s, tf2onnx=%s/%s",
                    get_tf_version(), utils.get_onnx_version(), tf2onnx.__version__, tf2onnx.version.git_version[:6])
        logger.info("Using opset <onnx, %s>", opset)
        if opset > schemas.get_max_supported_opset_version():
            logger.warning("Currently installed onnx package %s is too low to support opset %s, "
                           "please upgrade onnx package to avoid potential conversion issue.",
                           utils.get_onnx_version(), opset)

    # NOTE(review): infer_shape sees shape_override before it is defaulted to {} below,
    # so it may receive None — confirm infer_shape accepts that.
    is_func = is_function(tf_graph)
    if not is_func:
        tf_graph = infer_shape(tf_graph, shape_override)

    if shape_override is None:
        shape_override = {}
    if inputs_as_nchw is None:
        inputs_as_nchw = []
    if target is None:
        target = constants.DEFAULT_TARGET

    onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, _ = tensorflow_to_onnx(tf_graph, shape_override)
    if not is_subgraph:
        # make tf2onnx internal subgraphs from the tensorflow subgraphs: each function is
        # converted recursively and registered with set_function under its name
        ordered_func = resolve_functions(tf_graph)
        for func in ordered_func:
            f_inputs_names = [t.name for t in func.inputs]
            f_output_names = [t.name for t in func.outputs]
            fg = process_tf_graph(func, continue_on_error, False, target, opset,
                                  custom_op_handlers, custom_rewriter,
                                  extra_opset, shape_override, inputs_as_nchw,
                                  f_inputs_names, f_output_names, is_subgraph=True)
            fg.graph_name = func.name
            fg.func_inputs = f_inputs_names
            set_function(func.name, fg)

    io_to_check = []
    if input_names:
        io_to_check.extend(input_names)
    if output_names:
        io_to_check.extend(output_names)

    if io_to_check:
        # check input/output existence in case user passed in wrong ids
        non_exists = set(io_to_check) - set(output_shapes.keys())
        if non_exists:
            logger.error("\nFailed to convert: inputs/outputs specified do not exist, make sure you passed "
                         "in format: input/output_node_name:port_id. Problematical inputs/outputs are: %s \n",
                         non_exists)
            raise ValueError("Inputs/Outputs Not Found")

    g = Graph(onnx_nodes, output_shapes, dtypes, target, opset, extra_opset, output_names, is_subgraph=is_subgraph)

    # create ops mapping for the desired opsets
    ops_mapping = handler.tf_op.create_mapping(g.opset, g.extra_opset)

    # apply custom ops on top of the assembled opset. We can either complement the opset
    # or override existing ops with a custom op.
    if custom_op_handlers is not None:
        # below is a bit tricky since there are a few api's:
        # 1. the future way we want custom ops to be registered with the @tf_op decorator. Those handlers will be
        #     registered via the decorator on load of the module ... nothing is required here.
        # 2. the old custom op api: a dictionary of {name: (func, args[])}
        #     We deal with this by using a compat_handler that wraps the old handler with a new style handler.
        #     This is temporary to give people time to move to the new api and after tf2onnx-1.5 we want to remove this

        # FIXME: remove this after tf2onnx-1.5
        def compat_handler(ctx, node, **kwargs):
            # wrap old handler: call it with the old (ctx, node, name, args) signature
            name = node.name
            args = kwargs["args"]
            func = kwargs["func"]
            return func(ctx, node, name, args)

        custom_opset = {}
        for k, v in custom_op_handlers.items():
            args = v[1]
            kwargs = {"func": v[0]}
            if args:
                # a leading extra arg is forwarded to the handler as "onnx_op"
                onnx_op = args[0]
                kwargs["onnx_op"] = onnx_op
                args = args[1:]
            kwargs["args"] = args
            new_handler = handler.tf_op(k,
                                        domain=constants.TENSORFLOW_OPSET.domain,
                                        kwargs=kwargs)
            new_handler.register_compat_handler(compat_handler, 1)
            custom_opset[k] = (compat_handler, kwargs)
        ops_mapping.update(custom_opset)

    if inputs_as_nchw:
        transpose_inputs(g, inputs_as_nchw)

    # pre-processing graph rewrites
    # bi-directional re-writer should be placed after single directional re-writer
    rewriters = [rewrite_transpose, rewrite_flatten, rewrite_gemm,
                 rewrite_random_uniform, rewrite_random_uniform_fold_const,
                 rewrite_random_normal, rewrite_dropout, rewrite_eye,
                 rewrite_leakyrelu, rewrite_thresholded_relu, rewrite_conv2d_with_pad,
                 rewrite_single_direction_lstm, rewrite_bi_direction_lstm,
                 rewrite_single_direction_gru, rewrite_bi_direction_gru,
                 rewrite_custom_rnn_cell, rewrite_generic_loop, rewrite_cond,
                 rewrite_biasadd_with_conv2d,
                 ]

    if custom_rewriter is not None:
        rewriters.extend(custom_rewriter)

    run_rewriters(g, rewriters, continue_on_error)

    # some nodes may already copied into inner Graph, so remove them from main Graph.
    g.delete_unused_nodes(output_names)
    topological_sort(g, continue_on_error)

    mapped_op, unmapped_op, exceptions = tensorflow_onnx_mapping(g, ops_mapping)
    if unmapped_op:
        logger.error("Unsupported ops: %s", unmapped_op)
    if exceptions and not continue_on_error:
        # surface the first mapping failure unless the caller asked to continue
        raise exceptions[0]

    # post-processing rewriters
    late_rewriters = []
    if constants.TARGET_RS5 in target:
        late_rewriters.append(rewrite_incomplete_type_support_rs5)
    if constants.TARGET_RS6 in target:
        late_rewriters.append(rewrite_incomplete_type_support_rs6)
    if late_rewriters:
        run_rewriters(g, late_rewriters, continue_on_error)

    # onnx requires topological sorting
    topological_sort(g, continue_on_error)

    g.update_proto()

    logger.verbose(
        "Summary Stats:\n"
        "\ttensorflow ops: {}\n"
        "\ttensorflow attr: {}\n"
        "\tonnx mapped: {}\n"
        "\tonnx unmapped: {}".format(op_cnt, attr_cnt, mapped_op, unmapped_op))

    return g
示例#8
0
def read_tfjs_graph(nodes,
                    weights,
                    func=None,
                    graph_inputs=None,
                    graph_outputs=None,
                    shape_override=None,
                    ignore_default=None,
                    use_default=None):
    """Creates an onnx graph from the provided tfjs nodes.

    Args:
        nodes: list of tfjs node dicts (read via their 'name', 'op', optional 'input' and 'attr' keys)
        weights: mapping from Const node name to its array value; also consulted by node name
            to find constant inputs
        func: optional tfjs function; when given, read_tfjs_function supplies the dtypes, input
            shapes, graph inputs/outputs and graph name, a Placeholder is created per input, and
            the result is marked as a subgraph
        graph_inputs: input tensor names; defaults to the ':0' outputs of placeholder nodes
        graph_outputs: output tensor names; defaults to every output of the nodes none of whose
            outputs are consumed by another node
        shape_override: dict of tensor name -> shape that takes precedence over computed shapes
        ignore_default: PlaceholderWithDefault node names to turn into Placeholder inputs
        use_default: PlaceholderWithDefault node names to turn into Identity ops (using the default)
    Returns:
        the Graph, with its outputs renamed back to the names in graph_outputs
    """
    if shape_override is None:
        shape_override = {}
    onnx_nodes = []
    output_shapes = {}
    tf_dtypes = {}
    op_info = {}
    graph_name = 'tfjs_model'
    func_name = None

    def update_shapes(new_shapes):
        # Record shapes from a dict or an iterable of (name, shape); overrides win.
        if isinstance(new_shapes, dict):
            new_shapes = new_shapes.items()
        for k, v in new_shapes:
            output_shapes[k] = shape_override.get(k, v)

    if func is not None:
        tf_dtypes, fn_input_shapes, graph_inputs, graph_outputs, func_name = read_tfjs_function(
            func)
        update_shapes(fn_input_shapes)
        graph_name = func_name
        for inp in graph_inputs:
            onnx_nodes.append(
                helper.make_node("Placeholder", [], outputs=[inp], name=inp))

    if graph_inputs is None:
        placeholder_ops = [
            "Placeholder", "PlaceholderWithDefault", "PlaceholderV2"
        ]
        graph_inputs = [
            n['name'] + ':0' for n in nodes if n['op'] in placeholder_ops
        ]

    for node in nodes:
        if node['op'] == "NextIteration":
            # NextIteration nodes can violate the topological sort with cyclic dependencies, so we do them first.
            node_name = node['name']
            output_name = node_name + ':0'
            output_shapes[output_name] = None
            tf_dtypes[output_name] = read_tfjs_attr(node['attr']['T'],
                                                    tf_dtypes=True)
            op_info[node_name] = (node['op'], {
                'dtype': tf_dtypes[output_name]
            }, [tf_dtypes[output_name]])

    for node in nodes:
        op_type = node['op']
        node_name = node['name']
        if op_type == "Const":
            np_arr = weights[node_name]
            out_name = node_name + ':0'
            tf_dtype = read_tfjs_attr(node['attr']['dtype'], tf_dtypes=True)
            onnx_dtype = tf_utils.map_tf_dtype(tf_dtype)
            # The dtype of a Const in tfjs can differ from that of the weight used to get its value
            np_dtype = utils.map_onnx_to_numpy_type(onnx_dtype)
            onnx_tensor = numpy_helper.from_array(np_arr.astype(np_dtype),
                                                  out_name)
            onnx_node = helper.make_node("Const", [],
                                         outputs=[out_name],
                                         name=node_name,
                                         value=onnx_tensor)
            onnx_nodes.append(onnx_node)
            output_shapes[out_name] = shape_override.get(
                out_name, list(np_arr.shape))
            tf_dtypes[out_name] = tf_dtype
            op_info[node_name] = (op_type, {'dtype': tf_dtypes[out_name]}, [])
            continue
        tf_attr = {}
        onnx_attr = {}
        fix_string_attr(node)
        node_def = tfjs_node_to_tf_node_def(node)
        for k, v in node.get('attr', {}).items():
            tf_attr[k] = read_tfjs_attr(v, tf_dtypes=True)
            if k in tf_utils.TF_IGNORED_NODE_ATTRS:
                continue
            if k == 'DstT':
                k = 'to'
            onnx_attr[k] = read_tfjs_attr(v)
        if op_type == "FusedDepthwiseConv2dNative":
            # This op isn't in tensorflow but can be converted to a TF op
            op_type = "_FusedDepthwiseConv2dNative"
            err_msg = "explicit_paddings not supported for _FusedDepthwiseConv2dNative"
            utils.make_sure(len(tf_attr.get('explicit_paddings', [])) == 0, err_msg)
            # The (empty) attr is dropped everywhere; it may be absent, so don't assume a key.
            tf_attr.pop('explicit_paddings', None)
            onnx_attr.pop('explicit_paddings', None)
            if 'explicit_paddings' in node_def.attr:
                del node_def.attr['explicit_paddings']
            node_def.op = op_type

        # control inputs ('^name') carry no data and are dropped
        input_names = [
            inp for inp in node.get('input', []) if not inp.startswith('^')
        ]
        input_names = [
            resolve_output(inp, op_info, func_name) for inp in input_names
        ]
        inp_dtypes = [tf_dtypes[inp] for inp in input_names]
        inp_shapes = [output_shapes[inp] for inp in input_names]
        inp_consts = [weights.get(inp.split(':')[0]) for inp in input_names]
        out_dtypes = get_output_dtypes(op_type, tf_attr, inp_dtypes)
        out_shapes = get_output_shapes(node_def, inp_dtypes, inp_shapes,
                                       inp_consts)
        op_info[node_name] = (op_type, tf_attr, inp_dtypes)

        output_names = [
            node_name + ":" + str(i) for i in range(len(out_dtypes))
        ]
        tf_dtypes.update(zip(output_names, out_dtypes))
        update_shapes(zip(output_names, out_shapes))

        if op_type == "PlaceholderWithDefault":
            remove = False
            if ignore_default and node_name in ignore_default:
                op_type = 'Placeholder'
                input_names = []
            elif use_default and node_name in use_default:
                remove = True
            elif node_name.endswith('keras_learning_phase'):
                logger.warning(
                    "Removing optional input %s that appears to be a keras learning phase parameter. "
                    "Use --ignore_default to force this into an input.",
                    node_name)
                remove = True
            if remove:
                op_type = 'Identity'
                graph_inputs = [
                    inp for inp in graph_inputs if inp != node_name + ":0"
                ]

        onnx_node = helper.make_node(op_type,
                                     input_names,
                                     output_names,
                                     name=node_name,
                                     **onnx_attr)
        onnx_nodes.append(onnx_node)

    dtypes = {k: tf_utils.map_tf_dtype(v) for k, v in tf_dtypes.items()}
    if graph_outputs is None:
        # default outputs: all outputs of nodes whose outputs nobody consumes
        output_to_node = {
            out: node.name
            for node in onnx_nodes for out in node.output
        }
        node_to_outputs = {node.name: list(node.output) for node in onnx_nodes}
        used_nodes = {output_to_node[out] for node in onnx_nodes
                      for out in node.input}
        unused_nodes = [
            node for node in onnx_nodes if node.name not in used_nodes
        ]
        graph_outputs = [
            out for node in unused_nodes for out in node_to_outputs[node.name]
        ]
    graph_outputs_mapped = [
        resolve_output(out, op_info, func_name) for out in graph_outputs
    ]

    g = Graph(onnx_nodes,
              output_shapes,
              dtypes,
              input_names=graph_inputs,
              output_names=graph_outputs_mapped,
              is_subgraph=func is not None,
              graph_name=graph_name)
    g.rename_tensors(dict(zip(graph_outputs_mapped, graph_outputs)))
    return g
示例#9
0
def process_tf_graph(graph, continue_on_error=False, verbose=False, target=None, opset=0):
    """Convert a tensorflow graph into an onnx Graph.

    Applies the built-in rewriters in a fixed order, sorts the nodes, maps the
    tensorflow ops to onnx ops, sorts again and refreshes the protobuf. When
    ``verbose`` is set, the op/attr counts and mapping results are printed.
    """
    target = DEFAULT_TARGET if target is None else target

    onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes = tensorflow_to_onnx(graph)
    onnx_graph = Graph(onnx_nodes, output_shapes, dtypes, target, opset)

    # each rewriter consumes the node list produced by the previous one
    pipeline = (rewrite_flatten, rewrite_random_uniform, rewrite_random_normal,
                rewrite_dropout, rewrite_transpose)
    current_ops = onnx_graph.get_nodes()
    for rewriter in pipeline:
        current_ops = rewriter(onnx_graph, current_ops)
        onnx_graph.set_nodes(current_ops)

    onnx_graph.topological_sort(onnx_graph.get_nodes())
    mapped_op, unmapped_op = tensorflow_onnx_mapping(onnx_graph, continue_on_error)
    onnx_graph.topological_sort(onnx_graph.get_nodes())
    onnx_graph.update_proto()

    if verbose:
        print("tensorflow ops: {}".format(op_cnt))
        print("tensorflow attr: {}".format(attr_cnt))
        print("onnx mapped: {}".format(mapped_op))
        print("onnx unmapped: {}".format(unmapped_op))
    return onnx_graph
    def rewrite(self):
        """Attach a body graph to every Scan node that has pending body-graph info.

        For each ``Scan`` node with an entry in ``BodyGraphDict`` (the entry is
        popped), the nodes of its body are collected from ``self.g``. Const nodes
        among them are moved to the body graph's initializers. Body inputs are
        declared, and every body output is routed through a fresh Identity node.
        The resulting graph is stored in the Scan node's ``body`` attribute.
        Afterwards every collected body node is removed from the outer node list
        or, if it is not in that list, from the outer initializers.

        Returns:
            the outer node list with the body nodes removed.

        Raises:
            ValueError: if a body node is neither in the node list nor an initializer.
        """
        log.debug("enter custom rnn late rewriter")
        nodes = self.g.get_nodes()
        nodes_to_remove = []
        for scan_node in nodes:
            if scan_node.type != "Scan":
                continue
            log.debug("late write for scan node %s", scan_node.name)
            num_scan_inputs = scan_node.get_attr("num_scan_inputs").i
            if not BodyGraphDict.has_body_graph_info(scan_node.name):
                continue

            meta = BodyGraphDict.pop_body_graph_info(scan_node.name)
            body_nodes, _ = LoopRewriterBase.find_subgraph(meta, self.g)
            nodes_to_remove.extend(body_nodes)

            log.debug("start creating body graph for scan node %s ",
                      scan_node.name)
            # Const nodes must not be handed to the body graph as nodes;
            # their values are carried over as initializers instead.
            initializers = {}
            for const in [n for n in body_nodes if n.type in ("Const", "ConstV2")]:
                initializers[const.output[0]] = self.g.initializers[const.output[0]]
                body_nodes.remove(const)

            body_ops = [n.op for n in set(body_nodes)]
            body_g = Graph(body_ops,
                           output_shapes=self.g._output_shapes,  # pylint: disable=protected-access
                           dtypes=self.g._dtypes)  # pylint: disable=protected-access
            body_g._initializers = initializers  # pylint: disable=protected-access

            log.debug("start preparing body graph inputs nodes")
            temp_nodes = body_g.get_nodes()
            first_scan_index = len(meta.input_ids) - num_scan_inputs
            for index, (input_name, init_input_id) in enumerate(
                    zip(meta.input_ids, meta.initial_input_ids)):
                shape = body_g.get_shape(input_name)
                dtype = body_g.get_dtype(input_name)
                if shape is not None:
                    loop_input_shape = list(shape)
                else:
                    outer_shape = list(self.g.get_shape(init_input_id))
                    # the trailing num_scan_inputs inputs drop their leading [1, time] dims
                    loop_input_shape = outer_shape[2:] if index >= first_scan_index else outer_shape

                body_g.add_model_input(
                    input_name,
                    helper.make_tensor_value_info(input_name, dtype,
                                                  utils.make_onnx_shape(loop_input_shape)))

            log.debug("start preparing body graph outputs nodes")
            new_output_names = []
            for out_id in meta.output_ids:
                # An Identity per output: the same output_id may be needed both as a
                # state output and a scan output, but ONNX does not allow one name to
                # appear more than once among the graph outputs.
                identity_name = utils.make_name("Identity")
                identity_output = utils.port_name(identity_name)
                identity_node = Node(
                    helper.make_node("Identity", [out_id], [identity_output],
                                     name=identity_name), body_g)
                body_g.set_dtype(identity_output, body_g.get_dtype(out_id))
                body_g.copy_shape(out_id, identity_output)
                new_output_names.append(identity_output)
                temp_nodes.append(identity_node)

            body_g.set_nodes(temp_nodes)
            body_g.topological_sort(body_g.get_nodes())

            log.debug("start make graph based on body graph nodes")
            body_g.output_names = new_output_names
            scan_node.set_attr("body", body_g.make_graph("scan body graph"))

        # remove nodes in body graph from g
        for stale in set(nodes_to_remove):
            if stale in nodes:
                nodes.remove(stale)
                continue
            if not self.g.is_initializer(stale.output[0]):
                raise ValueError("error when removing nodes")
            del self.g.initializers[stale.output[0]]

        return nodes
示例#11
0
 def _create_empty_graph(self, inputs, shapes, dtypes):
     """Return a node-less Graph that only declares the given graph inputs.

     ``inputs``, ``shapes`` and ``dtypes`` are walked in lockstep (stopping at the
     shortest), and each (name, dtype, shape) is registered via ``add_graph_input``.
     The graph is built with ``self.config.target`` and ``self.config.opset``.
     """
     empty = Graph([], target=self.config.target, opset=self.config.opset)
     for name, shape, dtype in zip(inputs, shapes, dtypes):
         empty.add_graph_input(name, dtype, shape)
     return empty
示例#12
0
def process_tf_graph(tf_graph,
                     continue_on_error=False,
                     verbose=False,
                     target=None,
                     opset=None,
                     custom_op_handlers=None,
                     custom_rewriter=None,
                     extra_opset=None,
                     shape_override=None,
                     inputs_as_nchw=None,
                     input_names=None,
                     output_names=None,
                     ignore_default=None,
                     use_default=None,
                     is_subgraph=False,
                     const_node_values=None,
                     tensors_to_rename=None,
                     initialized_tables=None,
                     tflite_path=None,
                     dequantize=False):
    """Convert tensorflow graph to onnx graph.
        Args:
            tf_graph: tensorflow graph (pass None when converting from tflite_path)
            continue_on_error: if an op can't be processed (aka there is no mapping), continue
            verbose: deprecated and ignored apart from emitting a warning; use the --verbose option instead
            target: list of workarounds applied to help certain platforms (defaults to constants.DEFAULT_TARGET)
            opset: the opset to be used (int, default is latest)
            custom_op_handlers: dictionary of custom ops handlers
            custom_rewriter: list of custom graph rewriters
            extra_opset: list of extra opset's, for example the opset's used by custom ops
            shape_override: dict with inputs that override the shapes given by tensorflow
            inputs_as_nchw: transpose inputs in list from nchw to nhwc
            input_names: list of input node names in graph, input name format as node_name:port_id. Optional.
            output_names: list of output node names in graph, format is node_name:port_id. Optional for tflite.
            ignore_default: list of node names of PlaceholderWithDefault ops to change into Placeholder ops
            use_default: list of node names of PlaceholderWithDefault ops to change into Identity ops using the default
            is_subgraph: True when converting a function body from a recursive call. Skips the version
                logging and, for tensorflow graphs, skips converting nested functions and renaming tensors.
                Also forwarded to Graph.
            const_node_values: a dict returned by compress_graph_def mapping node names to tensor values
            tensors_to_rename: an optional dict (string->string) mapping tensor names to new names
            initialized_tables: mapping from table shared_names to tuple of keys and values of table
            tflite_path: Path to a tflite file to convert. If used, pass None to tf_graph
            dequantize: forwarded to process_parsed_graph; only used for tflite conversion
        Return:
            onnx graph as returned by process_parsed_graph. For tflite this is the graph built from the
            last tflite subgraph (None if the model yields no subgraphs). The other tflite subgraphs are
            registered with set_function.
        Raises:
            ValueError: if any of input_names / output_names is not a tensor of the converted graph.
    """
    # NOTE: process_parsed_graph and Graph are always given tensors post-rename.
    # process_tf_graph (this function) gets tensors pre-rename.
    if verbose:
        logger.warning(
            "Argument verbose for process_tf_graph is deprecated. Please use --verbose option instead."
        )
    del verbose

    opset = utils.find_opset(opset)
    if not is_subgraph:
        logger.info("Using tensorflow=%s, onnx=%s, tf2onnx=%s/%s",
                    get_tf_version(), utils.get_onnx_version(),
                    tf2onnx.__version__, tf2onnx.version.git_version[:6])
        logger.info("Using opset <onnx, %s>", opset)
        if opset > schemas.get_max_supported_opset_version():
            logger.warning(
                "Currently installed onnx package %s is too low to support opset %s, "
                "please upgrade onnx package to avoid potential conversion issue.",
                utils.get_onnx_version(), opset)

    if shape_override is None:
        shape_override = {}
    if inputs_as_nchw is None:
        inputs_as_nchw = []
    if target is None:
        target = constants.DEFAULT_TARGET

    def check_io(input_names, output_names, output_shapes):
        """Raise ValueError if a requested input/output name is not a key of output_shapes."""
        io_to_check = []
        if input_names:
            io_to_check.extend(input_names)
        if output_names:
            io_to_check.extend(output_names)
        if io_to_check:
            # check output existence in case user passed in wrong output ids
            non_exists = set(io_to_check) - set(output_shapes)
            if non_exists:
                # The two literals are concatenated implicitly. Keep the trailing space on the first one.
                logger.error(
                    "\nFailed to convert: inputs/outputs specified do not exist, make sure you passed them "
                    "in format: input/output_node_name:port_id. Problematic inputs/outputs are: %s \n",
                    non_exists)
                raise ValueError("Inputs/Outputs Not Found")

    def rename_tensors_in_dict(d):
        """Return d with keys renamed via tensors_to_rename (unchanged keys kept as-is)."""
        if tensors_to_rename is None:
            return d
        return {tensors_to_rename.get(k, k): v for k, v in d.items()}

    def rename_tensors_in_list(tensors):
        """Return tensors renamed via tensors_to_rename; None passes through unchanged."""
        if tensors_to_rename is None or tensors is None:
            return tensors
        return [tensors_to_rename.get(t, t) for t in tensors]

    def rename_tensors_in_nodes(onnx_nodes):
        """Rename the inputs and outputs of every node in place."""
        if tensors_to_rename is None:
            return
        for n in onnx_nodes:
            n.input[:] = rename_tensors_in_list(n.input)
            n.output[:] = rename_tensors_in_list(n.output)

    if tflite_path is not None:
        tflite_graphs, opcodes, model, tensor_shapes = read_tflite_model(
            tflite_path)
        main_g = None
        inputs_as_nchw = rename_tensors_in_list(inputs_as_nchw)
        for i, tfl_graph in enumerate(tflite_graphs):
            # The last tflite subgraph is treated as the main graph. The others are
            # converted and registered as functions, with their name as a tensor prefix.
            is_main_g = i == len(tflite_graphs) - 1
            prefix = '' if is_main_g else tfl_graph.Name().decode() + '_'
            # Interpreter-derived tensor shapes are only applied to the main graph.
            tensor_shapes_from_interpreter = None
            if is_main_g:
                tensor_shapes_from_interpreter = tensor_shapes
            onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, f_inputs, f_outputs, graph_name = \
                parse_tflite_graph(tfl_graph, opcodes, model, prefix, tensor_shapes_from_interpreter)
            g_inputs = f_inputs
            g_outputs = f_outputs
            if is_main_g:
                # Override IO in main graph
                check_io(input_names, output_names, output_shapes)
                if input_names is not None:
                    g_inputs = input_names
                if output_names is not None:
                    g_outputs = output_names
            rename_tensors_in_nodes(onnx_nodes)
            g_inputs = rename_tensors_in_list(g_inputs)
            g_outputs = rename_tensors_in_list(g_outputs)
            output_shapes = rename_tensors_in_dict(output_shapes)
            dtypes = rename_tensors_in_dict(dtypes)
            g = Graph(onnx_nodes, output_shapes, dtypes, target, opset,
                      extra_opset, g_inputs, g_outputs, is_subgraph)
            fg = process_parsed_graph(g,
                                      custom_op_handlers,
                                      inputs_as_nchw,
                                      continue_on_error,
                                      custom_rewriter,
                                      target,
                                      g_outputs, {}, {}, {},
                                      op_cnt,
                                      attr_cnt,
                                      is_tflite=True,
                                      dequantize=dequantize)
            fg.graph_name = graph_name
            if is_main_g:
                main_g = fg
            else:
                set_function(graph_name, fg)

        return main_g

    is_func = is_function(tf_graph)
    if not is_func:
        tf_graph = infer_shape(tf_graph, shape_override)

    outputs_to_values, outputs_to_dtypes = compute_const_folding_using_tf(
        tf_graph, const_node_values, output_names)

    onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, _ = \
        tensorflow_to_onnx(tf_graph, shape_override, const_node_values, ignore_default, use_default)
    if not is_subgraph:
        # make tf2onnx internal subgraphs from the tensorflow subgraphs
        ordered_func = resolve_functions(tf_graph)
        for func in ordered_func:
            f_inputs_names = [t.name for t in func.inputs]
            f_output_names = [t.name for t in func.outputs]
            fg = process_tf_graph(func,
                                  continue_on_error,
                                  False,
                                  target,
                                  opset,
                                  custom_op_handlers,
                                  custom_rewriter,
                                  extra_opset,
                                  shape_override,
                                  inputs_as_nchw,
                                  f_inputs_names,
                                  f_output_names,
                                  is_subgraph=True,
                                  const_node_values=const_node_values,
                                  tensors_to_rename=tensors_to_rename,
                                  initialized_tables=initialized_tables)
            fg.graph_name = func.name
            set_function(func.name, fg)

    # Validate user-supplied IO against the pre-rename tensor names.
    check_io(input_names, output_names, output_shapes)

    if not is_subgraph:
        rename_tensors_in_nodes(onnx_nodes)
        input_names = rename_tensors_in_list(input_names)
        output_names = rename_tensors_in_list(output_names)
        output_shapes = rename_tensors_in_dict(output_shapes)
        dtypes = rename_tensors_in_dict(dtypes)
        inputs_as_nchw = rename_tensors_in_list(inputs_as_nchw)
    g = Graph(onnx_nodes, output_shapes, dtypes, target, opset, extra_opset,
              input_names, output_names, is_subgraph)
    g = process_parsed_graph(g, custom_op_handlers, inputs_as_nchw,
                             continue_on_error, custom_rewriter, target,
                             output_names, initialized_tables,
                             outputs_to_values, outputs_to_dtypes, op_cnt,
                             attr_cnt)
    return g