示例#1
0
    def test_parse_tflite_graph(self):
        """Round-trip a small TF function through TFLite and check parse_tflite_graph.

        The function computes ``1.1 * matmul(a, b) + 2.3 * c``. It is exported to a
        GraphDef, converted to a .tflite flatbuffer (written under
        ``self.test_data_directory``), read back with ``read_tflite_model`` and parsed
        with ``parse_tflite_graph``. The test checks op/attribute counts, the graph
        input/output names, and the shapes/dtypes reported for the inputs.
        """

        def func(a, b, c):
            alpha = tf.constant(1.1, dtype=tf.float32)
            beta = tf.constant(2.3, dtype=tf.float32)
            mul1 = tf.multiply(alpha, tf.matmul(a, b))
            mul2 = tf.multiply(beta, c)
            x_ = mul1 + mul2
            return tf.identity(x_, name="output")

        inp_shapes = [[2, 3], [3, 1], [2, 1]]
        inp_dtypes = [tf.float32, tf.float32, tf.float32]
        names = ['a', 'b', 'c']
        names_with_port = ['a:0', 'b:0', 'c:0']
        output_names = ['output']
        output_names_with_port = ['output:0']

        input_tensors = [tf.TensorSpec(shape=s, dtype=d, name=n) for s, d, n in zip(inp_shapes, inp_dtypes, names)]

        concrete_func = tf.function(func, input_signature=tuple(input_tensors))
        concrete_func = concrete_func.get_concrete_function()
        graph_def = from_function(concrete_func,
                                  input_names=names_with_port,
                                  output_names=output_names_with_port)
        with tf_session() as sess:
            tf.import_graph_def(graph_def, name='')
            sess_inputs = [sess.graph.get_tensor_by_name(k) for k in names_with_port]
            sess_outputs = [sess.graph.get_tensor_by_name(n) for n in output_names_with_port]
            converter = tf.compat.v1.lite.TFLiteConverter.from_session(sess, sess_inputs, sess_outputs)

        # Convert exactly once; the result is the serialized .tflite model bytes.
        # NOTE(review): convert() runs after the session context has exited; this relies
        # on from_session having captured the graph at construction time — confirm.
        tflite_model = converter.convert()
        tflite_path = os.path.join(self.test_data_directory, self._testMethodName + ".tflite")
        os.makedirs(os.path.dirname(tflite_path), exist_ok=True)
        with open(tflite_path, 'wb') as f:
            f.write(tflite_model)

        tflite_graphs, opcodes_map, model, tensor_shapes = read_tflite_model(tflite_path)
        self.assertEqual(1, len(tflite_graphs))
        onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, inputs, outputs, _ = \
            parse_tflite_graph(tflite_graphs[0], opcodes_map, model, tensor_shapes_override=tensor_shapes)
        # The two tf.multiply calls and the add are expected as MUL/ADD; the matmul is
        # expected to show up as a single FULLY_CONNECTED op after TFLite conversion.
        self.assertEqual(2, op_cnt['MUL'])
        self.assertEqual(1, op_cnt['ADD'])
        self.assertEqual(1, op_cnt['FULLY_CONNECTED'])

        self.assertEqual(1, attr_cnt['WeightsFormat'])
        self.assertEqual(names, inputs)
        self.assertEqual(output_names, outputs)

        # Graph inputs must keep the shapes/dtypes declared in the TensorSpecs.
        for name, shape, dtype in zip(names, inp_shapes, inp_dtypes):
            self.assertEqual(shape, output_shapes[name])
            self.assertEqual(dtype, dtypes[name])

        self.assertGreaterEqual(len(onnx_nodes), 4)
示例#2
0
def process_tf_graph(tf_graph,
                     continue_on_error=False,
                     verbose=False,
                     target=None,
                     opset=None,
                     custom_op_handlers=None,
                     custom_rewriter=None,
                     extra_opset=None,
                     shape_override=None,
                     inputs_as_nchw=None,
                     input_names=None,
                     output_names=None,
                     ignore_default=None,
                     use_default=None,
                     is_subgraph=False,
                     const_node_values=None,
                     tensors_to_rename=None,
                     initialized_tables=None,
                     tflite_path=None,
                     dequantize=False):
    """Convert tensorflow graph to onnx graph.
        Args:
            tf_graph: tensorflow graph
            continue_on_error: if an op can't be processed (aka there is no mapping), continue
            verbose: print summary stats (deprecated)
            target: list of workarounds applied to help certain platforms
            opset: the opset to be used (int, default is latest)
            custom_op_handlers: dictionary of custom ops handlers
            custom_rewriter: list of custom graph rewriters
            extra_opset: list of extra opset's, for example the opset's used by custom ops
            shape_override: dict with inputs that override the shapes given by tensorflow
            inputs_as_nchw: transpose inputs in list from nchw to nhwc
            input_names: list of input node names in graph, input name format as node_name:port_id. Optional.
            output_names: list of output node names in graph, format is node_name:port_id. Optional for tflite.
            ignore_default: list of node names of PlaceholderWithDefault ops to change into Placeholder ops
            use_default: list of node names of PlaceholderWithDefault ops to change into Identity ops using the default
            is_subgraph: True when converting a function body recursively; skips version logging,
                nested function resolution and tensor renaming, and is forwarded to Graph
            const_node_values: a dict returned by compress_graph_def mapping node names to tensor values
            tensors_to_rename: an optional dict (string->string) mapping tensor names to new names
            initialized_tables: mapping from table shared_names to tuple of keys and values of table
            tflite_path: Path to a tflite file to convert. If used, pass None to tf_graph
            dequantize: forwarded to process_parsed_graph on the tflite path only
        Return:
            onnx graph. On the tflite path, the last graph returned by read_tflite_model is
            the main graph that is returned (None if the model contains no graphs); the other
            graphs are registered as functions via set_function.
    """
    # NOTE: process_parsed_graph and Graph are always given tensors post-rename.
    # process_tf_graph (this function) gets tensors pre-rename.
    if verbose:
        logger.warning(
            "Argument verbose for process_tf_graph is deprecated. Please use --verbose option instead."
        )
    del verbose

    opset = utils.find_opset(opset)
    if not is_subgraph:
        logger.info("Using tensorflow=%s, onnx=%s, tf2onnx=%s/%s",
                    get_tf_version(), utils.get_onnx_version(),
                    tf2onnx.__version__, tf2onnx.version.git_version[:6])
        logger.info("Using opset <onnx, %s>", opset)
        if opset > schemas.get_max_supported_opset_version():
            logger.warning(
                "Currently installed onnx package %s is too low to support opset %s, "
                "please upgrade onnx package to avoid potential conversion issue.",
                utils.get_onnx_version(), opset)

    if shape_override is None:
        shape_override = {}
    if inputs_as_nchw is None:
        inputs_as_nchw = []
    if target is None:
        target = constants.DEFAULT_TARGET

    def check_io(input_names, output_names, output_shapes):
        """Raise ValueError if any requested input/output tensor is absent from output_shapes."""
        io_to_check = []
        if input_names:
            io_to_check.extend(input_names)
        if output_names:
            io_to_check.extend(output_names)
        if io_to_check:
            # check output existence in case user passed in wrong output ids
            non_exists = set(io_to_check) - set(output_shapes.keys())
            if non_exists:
                # The two literals are concatenated, so the trailing space matters.
                logger.error(
                    "\nFailed to convert: inputs/outputs specified do not exist, make sure you passed them "
                    "in format: input/output_node_name:port_id. Problematic inputs/outputs are: %s \n",
                    non_exists)
                raise ValueError("Inputs/Outputs Not Found")

    def rename_tensors_in_dict(d):
        """Return d with keys mapped through tensors_to_rename (unchanged if no mapping)."""
        if tensors_to_rename is None:
            return d
        return {tensors_to_rename.get(k, k): v for k, v in d.items()}

    def rename_tensors_in_list(tensors):
        """Return tensors mapped through tensors_to_rename; None passes through."""
        if tensors_to_rename is None or tensors is None:
            return tensors
        return [tensors_to_rename.get(t, t) for t in tensors]

    def rename_tensors_in_nodes(onnx_nodes):
        """Rename node inputs/outputs in place according to tensors_to_rename."""
        if tensors_to_rename is None:
            return
        for n in onnx_nodes:
            n.input[:] = rename_tensors_in_list(n.input)
            n.output[:] = rename_tensors_in_list(n.output)

    if tflite_path is not None:
        tflite_graphs, opcodes, model, tensor_shapes = read_tflite_model(
            tflite_path)
        main_g = None
        inputs_as_nchw = rename_tensors_in_list(inputs_as_nchw)
        for i, tfl_graph in enumerate(tflite_graphs):
            # The last graph is treated as the main graph; the others become functions.
            is_main_g = i == len(tflite_graphs) - 1
            prefix = '' if is_main_g else tfl_graph.Name().decode() + '_'
            # Interpreter-derived tensor shapes are only applied to the main graph.
            tensor_shapes_from_interpreter = None
            if is_main_g:
                tensor_shapes_from_interpreter = tensor_shapes
            onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, f_inputs, f_outputs, graph_name = \
                parse_tflite_graph(tfl_graph, opcodes, model, prefix, tensor_shapes_from_interpreter)
            g_inputs = f_inputs
            g_outputs = f_outputs
            if is_main_g:
                # Override IO in main graph
                check_io(input_names, output_names, output_shapes)
                if input_names is not None:
                    g_inputs = input_names
                if output_names is not None:
                    g_outputs = output_names
            rename_tensors_in_nodes(onnx_nodes)
            g_inputs = rename_tensors_in_list(g_inputs)
            g_outputs = rename_tensors_in_list(g_outputs)
            output_shapes = rename_tensors_in_dict(output_shapes)
            dtypes = rename_tensors_in_dict(dtypes)
            g = Graph(onnx_nodes, output_shapes, dtypes, target, opset,
                      extra_opset, g_inputs, g_outputs, is_subgraph)
            fg = process_parsed_graph(g,
                                      custom_op_handlers,
                                      inputs_as_nchw,
                                      continue_on_error,
                                      custom_rewriter,
                                      target,
                                      g_outputs, {}, {}, {},
                                      op_cnt,
                                      attr_cnt,
                                      is_tflite=True,
                                      dequantize=dequantize)
            fg.graph_name = graph_name
            if is_main_g:
                main_g = fg
            else:
                set_function(graph_name, fg)

        return main_g

    is_func = is_function(tf_graph)
    if not is_func:
        tf_graph = infer_shape(tf_graph, shape_override)

    outputs_to_values, outputs_to_dtypes = compute_const_folding_using_tf(
        tf_graph, const_node_values, output_names)

    onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, _ = \
        tensorflow_to_onnx(tf_graph, shape_override, const_node_values, ignore_default, use_default)
    if not is_subgraph:
        # make tf2onnx internal subgraphs from the tensorflow subgraphs
        ordered_func = resolve_functions(tf_graph)
        for func in ordered_func:
            f_inputs_names = [t.name for t in func.inputs]
            f_output_names = [t.name for t in func.outputs]
            fg = process_tf_graph(func,
                                  continue_on_error,
                                  False,
                                  target,
                                  opset,
                                  custom_op_handlers,
                                  custom_rewriter,
                                  extra_opset,
                                  shape_override,
                                  inputs_as_nchw,
                                  f_inputs_names,
                                  f_output_names,
                                  is_subgraph=True,
                                  const_node_values=const_node_values,
                                  tensors_to_rename=tensors_to_rename,
                                  initialized_tables=initialized_tables)
            fg.graph_name = func.name
            set_function(func.name, fg)

    # IO names are validated against pre-rename tensor names.
    check_io(input_names, output_names, output_shapes)

    if not is_subgraph:
        rename_tensors_in_nodes(onnx_nodes)
        input_names = rename_tensors_in_list(input_names)
        output_names = rename_tensors_in_list(output_names)
        output_shapes = rename_tensors_in_dict(output_shapes)
        dtypes = rename_tensors_in_dict(dtypes)
        inputs_as_nchw = rename_tensors_in_list(inputs_as_nchw)
    g = Graph(onnx_nodes, output_shapes, dtypes, target, opset, extra_opset,
              input_names, output_names, is_subgraph)
    g = process_parsed_graph(g, custom_op_handlers, inputs_as_nchw,
                             continue_on_error, custom_rewriter, target,
                             output_names, initialized_tables,
                             outputs_to_values, outputs_to_dtypes, op_cnt,
                             attr_cnt)
    return g