Exemplo n.º 1
0
    def test_trans_output_as_graph_outputs(self):
        """
        A Transpose whose output is also a graph output must survive the optimizer.
        """
        transpose_node = helper.make_node("Transpose", ["X"], ["Y"], name="trans", perm=[0, 2, 3, 1])
        input_infos = [helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3, 4, 5))]
        output_infos = [helper.make_tensor_value_info("Y", TensorProto.FLOAT, (2, 4, 5, 3))]
        onnx_graph = helper.make_graph([transpose_node], "trans-to-graph-output", input_infos, output_infos)

        graph = GraphUtil.create_graph_from_onnx_graph(onnx_graph)

        # Bypass the node currently producing the graph output (treated here as
        # an identity op) so that the Transpose output itself becomes the
        # graph output, then drop that node.
        output_producer = graph.get_node_by_output(graph.outputs[0])
        graph.outputs = [output_producer.input[0]]
        graph.remove_node(output_producer.name)

        optimized_graph = GraphUtil.optimize_graph(graph)
        self.assertTrue(optimized_graph, msg="graph after optimizer should not be None")

        # Exactly one Transpose must remain after optimization.
        nodes_by_type = group_nodes_by_type(optimized_graph)
        trans_cnt = len(nodes_by_type["Transpose"])
        failure_msg = "Expect 1 Transpose ops left, but actually " + str(trans_cnt) + " left"
        self.assertTrue(trans_cnt == 1, msg=failure_msg)
    def run_test(self,
                 name,
                 backend="caffe2",
                 debug=False,
                 onnx_file=None,
                 opset=None,
                 perf=None,
                 fold_const=None):
        """Run complete test against backend.

        Loads the TensorFlow model, builds the input feed, runs the model with
        TensorFlow (unless ``self.skip_tensorflow`` is set), converts the graph
        to ONNX, runs the ONNX model on the selected backend and compares the
        backend results with the TensorFlow results.

        Args:
            name: label of the test; printed and forwarded to the backend
                runners and to ``create_onnx_file``.
            backend: one of "caffe2", "onnxmsrtnext" or "onnxruntime". Any
                other value is reported as a run_onnx failure.
            debug: forwarded to ``GraphUtil.optimize_graph``; when true the
                converted graph is also dumped.
            onnx_file: when truthy, forwarded to ``create_onnx_file`` together
                with the converted model and the inputs.
            opset: forwarded to ``to_onnx``.
            perf: stored on ``self.perf``.
            fold_const: forwarded to ``tf2onnx.tfonnx.tf_optimize``.

        Returns:
            True when the backend ran and the result comparison passed (or
            was skipped because TensorFlow was skipped); False otherwise.
            Failures are printed rather than raised.
        """
        print(name)
        self.perf = perf

        # get the model
        if self.url:
            _, dir_name = self.download_file()
            model_path = os.path.join(dir_name, self.local)
        else:
            model_path = self.local
        print("\tdownloaded", model_path)

        input_names = list(self.input_names.keys())
        if self.model_type in ["checkpoint"]:
            load_model = loader.from_checkpoint
        elif self.model_type in ["saved_model"]:
            load_model = loader.from_saved_model
        else:
            load_model = loader.from_graphdef
        # Only the graph definition is used; the input/output names returned
        # by the loader are discarded, the feed below is keyed by
        # self.input_names instead.
        graph_def, _, _ = load_model(model_path, input_names, self.output_names)

        # create the input data
        inputs = {}
        for k, v in self.input_names.items():
            if isinstance(v, six.text_type) and v.startswith("np."):
                # SECURITY NOTE: evaluates a string taken from the test
                # configuration; only safe while that configuration is trusted.
                inputs[k] = eval(v)  # pylint: disable=eval-used
            else:
                inputs[k] = self.make_input(v)
        if self.more_inputs:
            for k, v in self.more_inputs.items():
                inputs[k] = v

        graph_def = tf2onnx.tfonnx.tf_optimize(inputs.keys(),
                                               self.output_names, graph_def,
                                               fold_const)
        shape_override = {}
        tf_results = None
        # NOTE(review): tf.import_graph_def presumably returns None when no
        # return_elements are requested, in which case the session below
        # falls back to the default graph - confirm against the TF version
        # in use.
        g = tf.import_graph_def(graph_def, name='')
        with tf.Session(config=tf.ConfigProto(allow_soft_placement=True),
                        graph=g) as sess:

            # fix inputs if needed: cast each feed to the dtype of the graph
            # tensor it feeds. The dtype names are compared so that a feed
            # which already matches is left untouched (the previous check
            # compared the builtin `type` with a string and was always true).
            for k, v in list(inputs.items()):
                t = sess.graph.get_tensor_by_name(k)
                dtype = tf.as_dtype(t.dtype).name
                if v.dtype.name != dtype:
                    inputs[k] = v.astype(dtype)
            if self.force_input_shape:
                for k, v in inputs.items():
                    shape_override[k] = list(v.shape)

            # run the model with tensorflow
            if self.skip_tensorflow:
                print("\ttensorflow", "SKIPPED")
            else:
                tf_results = self.run_tensorflow(sess, inputs)
                print("\ttensorflow", "OK")
            model_proto = None
            try:
                # convert model to onnx
                onnx_graph = self.to_onnx(sess.graph,
                                          opset=opset,
                                          shape_override=shape_override,
                                          input_names=inputs.keys())
                model_proto = onnx_graph.make_model("converted from tf2onnx")
                # The optimizer is best effort: keep the unoptimized model
                # when it yields nothing.
                new_model_proto = GraphUtil.optimize_graph(onnx_graph,
                                                           "test",
                                                           debug=debug)
                if new_model_proto:
                    model_proto = new_model_proto
                else:
                    print(
                        "\tNON-CRITICAL, optimizers are not applied successfully"
                    )
                print("\tto_onnx", "OK")
                if debug:
                    onnx_graph.dump_graph()
                if onnx_file:
                    self.create_onnx_file(name, model_proto, inputs, onnx_file)
            except Exception as ex:
                # Conversion errors are reported, not raised; the backend run
                # below then fails (or uses the unoptimized model if one was
                # already produced).
                tb = traceback.format_exc()
                print("\tto_onnx", "FAIL", ex, tb)

        try:
            onnx_results = None
            if backend == "caffe2":
                onnx_results = self.run_caffe2(name, model_proto, inputs)
            elif backend == "onnxmsrtnext":
                onnx_results = self.run_onnxmsrtnext(name, model_proto, inputs)
            elif backend == "onnxruntime":
                onnx_results = self.run_onnxruntime(name, model_proto, inputs)
            else:
                raise ValueError("unknown backend")
            print("\trun_onnx OK")

            try:
                if self.skip_tensorflow:
                    print("\tResults: skipped tensorflow")
                else:
                    if self.check_only_shape:
                        for tf_res, onnx_res in zip(tf_results, onnx_results):
                            np.testing.assert_array_equal(
                                tf_res.shape, onnx_res.shape)
                    else:
                        for tf_res, onnx_res in zip(tf_results, onnx_results):
                            np.testing.assert_allclose(tf_res,
                                                       onnx_res,
                                                       rtol=self.rtol,
                                                       atol=self.atol)
                    print("\tResults: OK")
                return True
            except Exception as ex:
                # Mismatching results are reported and turn into a False
                # return value below.
                print("\tResults: ", ex)

        except Exception as ex:
            print("\trun_onnx", "FAIL", ex)

        return False