def run_and_compare(self, output_names_with_port, onnx_feed_dict, origin_proto, op_type,
                    remaining_op_num, debug=False, rtol=1e-07):
    """Optimize ``origin_proto``, check the remaining op count, and compare outputs.

    The original and the optimized model are both saved via
    ``self.save_onnx_model`` and run through onnxruntime with the same feed.
    Every fetched output must match in value (within ``rtol``, with ``atol=0``),
    dtype and shape.

    Args:
        output_names_with_port: names of the graph outputs to fetch.
        onnx_feed_dict: feed passed to both model runs.
        origin_proto: model proto to optimize with
            ``GraphUtil.optimize_graph_with_model_proto``.
        op_type: op type whose count is checked in the optimized graph.
        remaining_op_num: expected number of ``op_type`` nodes after optimization.
        debug: unused; kept for signature compatibility.
        rtol: relative tolerance for the output comparison.

    Raises:
        ValueError: if the configured backend is not onnxruntime.
    """
    utils.make_sure(op_type is not None, "op_type should be specified")
    utils.make_sure(remaining_op_num is not None, "remaining_op_num should be specified")

    origin_model_path = self.save_onnx_model(origin_proto, onnx_feed_dict, postfix="_origin")

    new_proto = GraphUtil.optimize_graph_with_model_proto(origin_proto)
    self.assertTrue(new_proto, msg="model proto after optimizer should not be None")

    new_model_path = self.save_onnx_model(new_proto, onnx_feed_dict, postfix="_opt")

    current = GraphUtil.get_node_count_from_onnx_graph(new_proto.graph)
    # .get() so an op type the optimizer removed entirely counts as 0 even if
    # the returned mapping is a plain dict rather than a Counter.
    actual_op_num = current.get(op_type, 0)
    self.assertEqual(actual_op_num, remaining_op_num,
                     msg="Expect {} {} ops left, but actually {} left".format(
                         remaining_op_num, op_type, actual_op_num))

    if not self.config.is_onnxruntime_backend:
        raise ValueError("only onnxruntime is supported to test optimizer")

    expected = list(self.run_onnxruntime(origin_model_path, onnx_feed_dict, output_names_with_port))
    actual = list(self.run_onnxruntime(new_model_path, onnx_feed_dict, output_names_with_port))

    # zip() would silently ignore extra outputs on either side.
    self.assertEqual(len(expected), len(actual),
                     msg="number of outputs differs between original and optimized model")
    for expected_val, actual_val in zip(expected, actual):
        self.assertAllClose(expected_val, actual_val, rtol=rtol, atol=0.)
        self.assertEqual(expected_val.dtype, actual_val.dtype)
        self.assertEqual(expected_val.shape, actual_val.shape)
def run_and_compare(self, output_names_with_port, onnx_feed_dict, origin_proto, debug=False, rtol=1e-07):
    """Run the transpose optimizer on ``origin_proto`` and verify it helped.

    Asserts that the optimized graph has strictly fewer Transpose nodes than the
    original. Both models are then saved via ``self.save_onnx_model`` and run
    through onnxruntime with the same feed. Every fetched output must match in
    value (within ``rtol``, with ``atol=0``), dtype and shape.

    Args:
        output_names_with_port: names of the graph outputs to fetch.
        onnx_feed_dict: feed passed to both model runs.
        origin_proto: model proto to optimize with
            ``GraphUtil.opt_transposes_with_model_proto``.
        debug: unused; kept for signature compatibility.
        rtol: relative tolerance for the output comparison.

    Raises:
        ValueError: if the test class BACKEND is not "onnxruntime".
    """
    origin_model_path = self.save_onnx_model(origin_proto, onnx_feed_dict, postfix="_origin")
    # Take the baseline count before optimizing, so the comparison cannot be
    # masked should the optimizer ever modify origin_proto in place.
    previous = GraphUtil.get_node_count_from_onnx_graph(origin_proto.graph)

    new_proto = GraphUtil.opt_transposes_with_model_proto(origin_proto)
    self.assertTrue(new_proto, msg="model proto after optimizer should not be None")

    new_model_path = self.save_onnx_model(new_proto, onnx_feed_dict, postfix="_opt")
    current = GraphUtil.get_node_count_from_onnx_graph(new_proto.graph)
    # .get() avoids a KeyError when no Transpose node exists and the mapping
    # is a plain dict rather than a Counter.
    self.assertLess(current.get("Transpose", 0), previous.get("Transpose", 0),
                    msg="transpose ops count not changed")

    if type(self).BACKEND != "onnxruntime":
        raise ValueError(
            "only onnxruntime is supported to test transpose optimizer")

    expected = list(self.run_onnxruntime(origin_model_path, onnx_feed_dict, output_names_with_port))
    actual = list(self.run_onnxruntime(new_model_path, onnx_feed_dict, output_names_with_port))

    # zip() would silently ignore extra outputs on either side.
    self.assertEqual(len(expected), len(actual),
                     msg="number of outputs differs between original and optimized model")
    for expected_val, actual_val in zip(expected, actual):
        self.assertAllClose(expected_val, actual_val, rtol=rtol, atol=0.)
        self.assertEqual(expected_val.dtype, actual_val.dtype)
        self.assertEqual(expected_val.shape, actual_val.shape)