示例#1
0
def main():
    """Command-line entry point: convert a frozen TensorFlow graph to ONNX.

    Parses the GraphDef stored at ``args.input``, runs ``tf_optimize`` on it,
    converts it with ``process_tf_graph``, applies the transpose optimizer and,
    when ``args.output`` is set, writes the serialized ONNX model there.
    """
    args = get_args()

    opset = tf2onnx.utils.find_opset(args.opset)
    print("using tensorflow={}, onnx={}, opset={}, tfonnx={}/{}".format(
        tf.__version__, onnx.__version__, opset, tf2onnx.__version__,
        tf2onnx.version.git_version[:6]))

    # Override unknown dimensions (-1) with args.unknown_dim (e.g. batch size 1),
    # since not every runtime supports unknown dimensions.
    tf2onnx.utils.ONNX_UNKNOWN_DIMENSION = args.unknown_dim

    if args.custom_ops:
        # default custom ops for tensorflow-onnx are in the "tf" namespace
        custom_ops = {
            op: default_custom_op_handler
            for op in args.custom_ops.split(",")
        }
        extra_opset = [helper.make_opsetid(_TENSORFLOW_DOMAIN, 1)]
    else:
        custom_ops = {}
        extra_opset = None

    graph_def = tf.GraphDef()
    with tf.gfile.GFile(args.input, 'rb') as f:
        graph_def.ParseFromString(f.read())

    # todo: consider to enable const folding by default?
    graph_def = tf_optimize(args.inputs, args.outputs, graph_def,
                            args.fold_const)
    with tf.Graph().as_default() as tf_graph:
        tf.import_graph_def(graph_def, name='')
    with tf.Session(graph=tf_graph):
        # NOTE(review): this passes the raw args.opset rather than the resolved
        # `opset` printed above; presumably process_tf_graph resolves it the
        # same way -- confirm.
        g = process_tf_graph(tf_graph,
                             continue_on_error=args.continue_on_error,
                             verbose=args.verbose,
                             target=args.target,
                             opset=args.opset,
                             custom_op_handlers=custom_ops,
                             extra_opset=extra_opset,
                             shape_override=args.shape_override,
                             input_names=args.inputs,
                             output_names=args.outputs,
                             inputs_as_nchw=args.inputs_as_nchw)

    doc_string = "converted from {}".format(args.input)
    # Build the unoptimized model first so there is always a model to write:
    # a falsy result from the optimizer is treated as "optimizers failed", and
    # previously left model_proto unbound in that case.
    model_proto = g.make_model(doc_string)
    new_model_proto = GraphUtil.opt_transposes_with_graph(
        g,
        doc_string,
        optimize=not args.continue_on_error)
    if new_model_proto:
        model_proto = new_model_proto
    else:
        print("NON-CRITICAL, optimizers are not applied successfully")

    # write onnx graph
    if args.output:
        with open(args.output, "wb") as f:
            f.write(model_proto.SerializeToString())
        print("\nComplete successfully, the onnx model is generated at " +
              args.output)
示例#2
0
    def _import_from_tf_pb(self, graph_def):
        """Convert a TensorFlow GraphDef into an ONNX ModelProto.

        The graph's input/output names are discovered with
        ``_find_out_terminal_node`` and printed. The graph is passed through
        ``tf_optimize`` with constant folding off, then converted with
        ``process_tf_graph`` at opset 6.

        Returns:
            The transpose-optimized model when the optimizer yields one,
            otherwise the plain model built by ``make_model("tf_model")``.
        """
        terminal_inputs, terminal_outputs = _find_out_terminal_node(
            graph_def, postfix=True)
        print("inputs:{}".format(terminal_inputs))
        print("outputs:{}".format(terminal_outputs))

        # FIXME: constant folding is disabled (fold_const argument is False).
        optimized_def = tf2onnx.tfonnx.tf_optimize(
            terminal_inputs, terminal_outputs, optimized_def_source(graph_def)
            if False else graph_def, False)
        with tf.Graph().as_default() as imported_graph:
            tf.import_graph_def(optimized_def, name='')

        with tf.Session(graph=imported_graph):
            conversion_options = dict(
                continue_on_error=False,
                verbose=False,
                target=",".join(tf2onnx.tfonnx.DEFAULT_TARGET),
                opset=6,
                input_names=terminal_inputs,
                output_names=terminal_outputs,
                inputs_as_nchw=None,
            )
            converted = tf2onnx.tfonnx.process_tf_graph(
                imported_graph, **conversion_options)

        # The plain model is built before the optimizer runs, so it serves as
        # the fallback whenever the optimizer returns a falsy value.
        fallback_proto = converted.make_model("tf_model")
        optimized_proto = GraphUtil.opt_transposes_with_graph(
            converted, 'tf_model', optimize=True)
        return optimized_proto if optimized_proto else fallback_proto
示例#3
0
    def run_test(self,
                 name,
                 backend="caffe2",
                 debug=False,
                 onnx_file=None,
                 opset=None,
                 perf=None,
                 fold_const=None):
        """Run complete test against backend.

        Obtains the model (downloading it when ``self.url`` is set). For
        "checkpoint" / "saved_model" model types it first freezes the model to
        ``<dir>/frozen.pb``. It then builds the input feeds, runs TensorFlow
        (unless ``self.skip_tensorflow``), converts the graph to ONNX, runs the
        ONNX model on ``backend`` and compares the results.

        Args:
            name: test name; printed and forwarded to the backend runner.
            backend: "caffe2", "onnxmsrtnext" or "onnxruntime". Any other value
                is reported as a run_onnx failure.
            debug: forwarded to the transpose optimizer; also dumps the graph.
            onnx_file: when set, forwarded to ``create_onnx_file``.
            opset: forwarded to ``self.to_onnx``.
            perf: stored on ``self.perf``.
            fold_const: forwarded to ``tf_optimize``.

        Returns:
            True when the ONNX run succeeded and its results match TensorFlow
            (or TensorFlow was skipped); False otherwise.
        """
        print(name)
        self.perf = perf

        # get the model
        if self.url:
            _, dir_name = self.download_file()
            model_path = os.path.join(dir_name, self.local)
        else:
            model_path = self.local
            dir_name = os.path.dirname(self.local)
        print("\tdownloaded", model_path)

        if self.model_type in ["checkpoint"]:
            # Checkpoint input: restore it and freeze it into <dir>/frozen.pb.
            saver = tf.train.import_meta_graph(model_path)
            with tf.Session() as sess:
                # Drop the last 5 characters (presumably the ".meta" suffix) to
                # obtain the checkpoint prefix passed to Saver.restore.
                saver.restore(sess, model_path[:-5])
                frozen_graph = freeze_session(sess,
                                              output_names=self.output_names)
                tf.train.write_graph(frozen_graph,
                                     dir_name,
                                     "frozen.pb",
                                     as_text=False)
            model_path = os.path.join(dir_name, "frozen.pb")
        elif self.model_type in ["saved_model"]:
            try:
                from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
                get_signature_def = lambda meta_graph_def, k: \
                    signature_def_utils.get_signature_def_by_key(meta_graph_def, k)
            except ImportError:
                # TF1.12 changed the api
                get_signature_def = lambda meta_graph_def, k: meta_graph_def.signature_def[
                    k]

            # saved_model input: load it, collect the signature tensors and
            # freeze the graph into <dir>/frozen.pb.
            with tf.Session() as sess:
                meta_graph_def = tf.saved_model.loader.load(
                    sess, [tf.saved_model.tag_constants.SERVING], model_path)
                inputs = {}
                outputs = {}
                for k in meta_graph_def.signature_def.keys():
                    inputs_tensor_info = get_signature_def(meta_graph_def,
                                                           k).inputs
                    for _, input_tensor in sorted(inputs_tensor_info.items()):
                        inputs[
                            input_tensor.name] = sess.graph.get_tensor_by_name(
                                input_tensor.name)
                    outputs_tensor_info = get_signature_def(meta_graph_def,
                                                            k).outputs
                    for _, output_tensor in sorted(
                            outputs_tensor_info.items()):
                        outputs[output_tensor.
                                name] = sess.graph.get_tensor_by_name(
                                    output_tensor.name)
                # freeze uses the node name derived from output:0 so only pass in output:0;
                # it will provide all outputs of that node.
                for o in list(outputs.keys()):
                    if not o.endswith(":0"):
                        del outputs[o]
                frozen_graph = freeze_session(sess,
                                              output_names=list(
                                                  outputs.keys()))
                tf.train.write_graph(frozen_graph,
                                     dir_name,
                                     "frozen.pb",
                                     as_text=False)
            model_path = os.path.join(dir_name, "frozen.pb")

        # create the input data
        inputs = {}
        for k, v in self.input_names.items():
            if isinstance(v, six.text_type) and v.startswith("np."):
                # SECURITY: evaluates a string taken from self.input_names;
                # only safe as long as that test configuration is trusted.
                inputs[k] = eval(v)  # pylint: disable=eval-used
            else:
                inputs[k] = self.make_input(v)
        if self.more_inputs:
            for k, v in self.more_inputs.items():
                inputs[k] = v
        tf.reset_default_graph()
        graph_def = graph_pb2.GraphDef()
        with open(model_path, "rb") as f:
            graph_def.ParseFromString(f.read())

        graph_def = tf2onnx.tfonnx.tf_optimize(inputs, self.output_names,
                                               graph_def, fold_const)
        shape_override = {}
        # import_graph_def imports into the (freshly reset) default graph and
        # returns None when no return_elements are requested, so fetch the
        # graph explicitly instead of using its return value.
        tf.import_graph_def(graph_def, name='')
        g = tf.get_default_graph()
        model_proto = None
        with tf.Session(config=tf.ConfigProto(allow_soft_placement=True),
                        graph=g) as sess:

            # Cast every feed to the dtype of the tensor it feeds. The old check
            # compared the builtin `type` to a string and so was always true.
            for k in inputs.keys():  # pylint: disable=consider-iterating-dictionary
                t = sess.graph.get_tensor_by_name(k)
                dtype = tf.as_dtype(t.dtype).name
                v = inputs[k]
                if v.dtype != dtype:
                    inputs[k] = v.astype(dtype)
            if self.force_input_shape:
                shape_override = self.input_names

            # run the model with tensorflow
            if self.skip_tensorflow:
                print("\ttensorflow", "SKIPPED")
            else:
                tf_results = self.run_tensorflow(sess, inputs)
                print("\ttensorflow", "OK")
            try:
                # convert model to onnx
                onnx_graph = self.to_onnx(sess.graph,
                                          opset=opset,
                                          shape_override=shape_override)
                # Keep the unoptimized model as a fallback: a falsy optimizer
                # result used to leave model_proto as None and fail the test.
                model_proto = onnx_graph.make_model("test")
                new_model_proto = GraphUtil.opt_transposes_with_graph(
                    onnx_graph, "test", debug=debug)
                if new_model_proto:
                    model_proto = new_model_proto
                else:
                    print(
                        "\tNON-CRITICAL, optimizers are not applied successfully"
                    )
                print("\tto_onnx", "OK")
                if debug:
                    onnx_graph.dump_graph()
                if onnx_file:
                    self.create_onnx_file(name, model_proto, inputs, onnx_file)
            except Exception as ex:
                tb = traceback.format_exc()
                print("\tto_onnx", "FAIL", ex, tb)

        if model_proto is None:
            # Conversion failed (already reported above); there is nothing to
            # run on the backend.
            return False

        try:
            onnx_results = None
            if backend == "caffe2":
                onnx_results = self.run_caffe2(name, model_proto, inputs)
            elif backend == "onnxmsrtnext":
                onnx_results = self.run_onnxmsrtnext(name, model_proto, inputs)
            elif backend == "onnxruntime":
                onnx_results = self.run_onnxruntime(name, model_proto, inputs)
            else:
                raise ValueError("unknown backend")
            print("\trun_onnx OK")

            try:
                if self.skip_tensorflow:
                    print("\tResults: skipped tensorflow")
                else:
                    if self.check_only_shape:
                        for tf_res, onnx_res in zip(tf_results, onnx_results):
                            np.testing.assert_array_equal(
                                tf_res.shape, onnx_res.shape)
                    else:
                        for tf_res, onnx_res in zip(tf_results, onnx_results):
                            np.testing.assert_allclose(tf_res,
                                                       onnx_res,
                                                       rtol=self.rtol,
                                                       atol=self.atol)
                    print("\tResults: OK")
                return True
            except Exception as ex:
                print("\tResults: ", ex)

        except Exception as ex:
            print("\trun_onnx", "FAIL", ex)

        return False