def _TestRun(self, sess, batch_size, expect_engine_is_run):
    """Evaluate "output:0" once and check which TRTEngineOp_0 path recorded a run.

    Args:
      sess: session used to evaluate the "output:0" tensor.
      batch_size: number of [[1.0]] entries fed to "input:0"; the output is
        expected to hold the same number of [[4.0]] entries.
      expect_engine_is_run: when truthy, the "TRTEngineOp_0:ExecuteTrtEngine"
        test value must be "done" and "TRTEngineOp_0:ExecuteNativeSegment"
        must be ""; when falsy, the two expectations are swapped.
    """
    # Empty pattern -- presumably wipes every recorded test value, so only the
    # markers produced by the run below are inspected. TODO confirm semantics.
    trt_convert.clear_test_values("")

    batch_input = [[[1.0]]] * batch_size
    wanted_output = [[[4.0]]] * batch_size
    actual_output = sess.run("output:0", feed_dict={"input:0": batch_input})
    self.assertAllEqual(wanted_output, actual_output)

    # Exactly one of the two execution markers is expected to read "done".
    if expect_engine_is_run:
        engine_marker, native_marker = "done", ""
    else:
        engine_marker, native_marker = "", "done"
    self.assertEqual(
        engine_marker,
        trt_convert.get_test_value("TRTEngineOp_0:ExecuteTrtEngine"))
    self.assertEqual(
        native_marker,
        trt_convert.get_test_value("TRTEngineOp_0:ExecuteNativeSegment"))
def _PrepareRun(self, graph_state):
    """Set up necessary testing environment before calling sess.run().

    Resets the ExecuteTrtEngine, ExecuteCalibration and ExecuteNativeSegment
    test values of every TRTEngineOp_* op. The patterns look like regular
    expressions (".*"), so presumably each call matches all engine ops.

    Args:
      graph_state: not used by this implementation (presumably kept so the
        signature matches other callers/overrides -- verify).
    """
    # Clear test values added by TRTEngineOp, in a fixed order.
    stale_value_patterns = (
        "TRTEngineOp_.*:ExecuteTrtEngine",
        "TRTEngineOp_.*:ExecuteCalibration",
        "TRTEngineOp_.*:ExecuteNativeSegment",
    )
    for pattern in stale_value_patterns:
        trt_convert.clear_test_values(pattern)
def setUp(self):
    """Setup method.

    Runs the parent class setUp, switches the warnings filter to "always",
    and clears recorded TF-TRT test values before each test.
    """
    base = super(TfTrtIntegrationTestBase, self)
    base.setUp()
    # "always" makes every matching warning be reported, not only the first
    # occurrence per location.
    # NOTE(review): this mutates the process-wide filter list and is not
    # restored after the test -- confirm that is intended.
    warnings.simplefilter("always")
    # Empty pattern -- presumably clears all test values left by earlier tests.
    trt_convert.clear_test_values("")
def setUp(self):
    """Setup method.

    After the parent setUp, resets recorded test values and injects a "fail"
    value for "TRTEngineOp_0:CreateTRTNode" so that, per the original intent,
    building the first engine fails.
    """
    super(PartiallyConvertedTestB, self).setUp()
    # Start clean (empty pattern -- presumably clears everything), then
    # register the failure marker for the first engine.
    trt_convert.clear_test_values("")
    first_engine_key = "TRTEngineOp_0:CreateTRTNode"
    trt_convert.add_test_value(first_engine_key, "fail")