def _TestRun(self, sess, batch_size, expect_engine_is_run,
             engine_name="TRTEngineOp_0"):
  """Runs the graph once and checks which execution path the engine took.

  Feeds `batch_size` copies of [[1.0]] into "input:0", checks that
  "output:0" equals `batch_size` copies of [[4.0]], then inspects the
  recorded test values to see whether the TensorRT engine path or the
  native-segment path was taken.

  Args:
    sess: session used to run "output:0".
    batch_size: number of [[1.0]] elements fed to "input:0".
    expect_engine_is_run: if True, expect "<engine_name>:ExecuteTrtEngine"
      to have recorded "done" and "<engine_name>:ExecuteNativeSegment" to
      have recorded ""; if False, expect the reverse.
    engine_name: engine-op name used as the prefix of the test-value labels.
      Defaults to "TRTEngineOp_0", the value previously hard-coded here.
  """
  # Reset recorded test values so only this run's execution path is observed.
  clear_test_values("")
  result = sess.run("output:0", feed_dict={"input:0": [[[1.0]]] * batch_size})
  self.assertAllEqual([[[4.0]]] * batch_size, result)

  # Exactly one of the two paths is expected to record "done".
  if expect_engine_is_run:
    execute_engine_test_value, execute_native_segment_test_value = "done", ""
  else:
    execute_engine_test_value, execute_native_segment_test_value = "", "done"
  # Delegate to _ExpectTestValue so a failure reports the offending label
  # together with the actual and expected values.
  self._ExpectTestValue(engine_name, "ExecuteTrtEngine",
                        execute_engine_test_value)
  self._ExpectTestValue(engine_name, "ExecuteNativeSegment",
                        execute_native_segment_test_value)
def _ExpectTestValue(self, engine_name, method, expected_value):
  """Asserts the test value recorded under "<engine_name>:<method>".

  Looks up the value via `get_test_value` and compares it with
  `expected_value`. On mismatch, the failure message names the label and
  shows both the actual and the expected value.

  Args:
    engine_name: engine-op name forming the first part of the label.
    method: method name forming the second part of the label.
    expected_value: value that should have been recorded for the label.
  """
  key = "%s:%s" % (engine_name, method)
  recorded = get_test_value(key)
  failure_msg = (
      "Unexpected test value with label %s. Actual: %s; expected: %s" %
      (key, recorded, expected_value))
  self.assertEqual(expected_value, recorded, msg=failure_msg)