Exemplo n.º 1
0
    def run_and_compare(self, output_names_with_port, onnx_feed_dict, origin_proto, op_type,
                        remaining_op_num, debug=False, rtol=1e-07):
        """Optimize a model, then check the surviving op count and numerical parity.

        The original model is saved, passed through
        ``GraphUtil.optimize_graph_with_model_proto``, and the optimized model is
        saved as well. The method then asserts two things: exactly
        ``remaining_op_num`` nodes of type ``op_type`` remain, and both models
        produce matching outputs when run with onnxruntime.

        Args:
            output_names_with_port: output names fetched from both models.
            onnx_feed_dict: input feed given to ``save_onnx_model`` and to both runs.
            origin_proto: the unoptimized ONNX model proto.
            op_type: ONNX op type whose post-optimization count is checked.
            remaining_op_num: expected number of ``op_type`` nodes after optimization.
            debug: accepted for signature compatibility; not used by this method.
            rtol: relative tolerance for value comparison (``atol`` is fixed at 0).

        Raises:
            ValueError: if the configured backend is not onnxruntime.
        """
        utils.make_sure(op_type is not None, "op_type should be specified")
        utils.make_sure(remaining_op_num is not None, "remaining_op_num should be specified")

        origin_model_path = self.save_onnx_model(origin_proto, onnx_feed_dict, postfix="_origin")

        new_proto = GraphUtil.optimize_graph_with_model_proto(origin_proto)

        self.assertTrue(new_proto, msg="model proto after optimizer should not be None")

        new_model_path = self.save_onnx_model(new_proto, onnx_feed_dict, postfix="_opt")
        current = GraphUtil.get_node_count_from_onnx_graph(new_proto.graph)

        # An op type that the optimizer removed completely may be absent from the
        # count mapping; treat it as 0 instead of risking a KeyError on a plain dict.
        remaining = current.get(op_type, 0)
        # assertEqual (unlike assertTrue(a == b)) reports both values on failure.
        self.assertEqual(remaining, remaining_op_num,
                         msg="Expect {} {} ops left, but actually {} left".format(
                             remaining_op_num, op_type, remaining))

        if self.config.is_onnxruntime_backend:
            expected = self.run_onnxruntime(origin_model_path, onnx_feed_dict, output_names_with_port)
            actual = self.run_onnxruntime(new_model_path, onnx_feed_dict, output_names_with_port)
        else:
            raise ValueError("only onnxruntime is supported to test transpose optimizer")

        # zip() silently stops at the shorter sequence; make sure no output is skipped.
        self.assertEqual(len(expected), len(actual))
        for expected_val, actual_val in zip(expected, actual):
            self.assertAllClose(expected_val, actual_val, rtol=rtol, atol=0.)
            self.assertEqual(expected_val.dtype, actual_val.dtype)
            self.assertEqual(expected_val.shape, actual_val.shape)
    def run_and_compare(self, output_names_with_port, onnx_feed_dict, origin_proto,
                        debug=False, rtol=1e-07):
        """Run the transpose optimizer and verify it helped without changing results.

        Both the original and the optimized models are written out through
        ``save_onnx_model``. The optimized graph must contain strictly fewer
        ``Transpose`` nodes than the original. When run with onnxruntime, the two
        models must also agree on every output's values (relative tolerance
        ``rtol``, zero absolute tolerance), dtype, and shape.

        ``debug`` is accepted but not used here.

        Raises:
            ValueError: if the test class's ``BACKEND`` is not "onnxruntime".
        """
        baseline_path = self.save_onnx_model(origin_proto, onnx_feed_dict, postfix="_origin")

        optimized = GraphUtil.opt_transposes_with_model_proto(origin_proto)
        self.assertTrue(optimized, msg="model proto after optimizer should not be None")

        optimized_path = self.save_onnx_model(optimized, onnx_feed_dict, postfix="_opt")

        counts_before = GraphUtil.get_node_count_from_onnx_graph(origin_proto.graph)
        counts_after = GraphUtil.get_node_count_from_onnx_graph(optimized.graph)
        # The optimizer must have removed at least one Transpose node.
        self.assertTrue(counts_after["Transpose"] < counts_before["Transpose"],
                        msg="transpose ops count not changed")

        # Only onnxruntime can execute both models for the numerical comparison.
        if type(self).BACKEND != "onnxruntime":
            raise ValueError("only onnxruntime is supported to test transpose optimizer")

        baseline_outputs = self.run_onnxruntime(baseline_path, onnx_feed_dict, output_names_with_port)
        optimized_outputs = self.run_onnxruntime(optimized_path, onnx_feed_dict, output_names_with_port)

        for want, got in zip(baseline_outputs, optimized_outputs):
            self.assertAllClose(want, got, rtol=rtol, atol=0.)
            self.assertEqual(want.dtype, got.dtype)
            self.assertEqual(want.shape, got.shape)