def _TestRun(self, sess, batch_size, expect_engine_is_run):
   """Run the graph once and verify its output and the recorded test values.

   Args:
     sess: object whose `run` method is used to fetch "output:0".
     batch_size: how many `[[1.0]]` entries are fed to "input:0"; the same
       number of `[[4.0]]` entries is expected back.
     expect_engine_is_run: if truthy, "my_trt_op_0:ExecuteTrtEngine" must
       read "done" and "my_trt_op_0:ExecuteNativeSegment" must read "";
       otherwise the two expectations are swapped.
   """
   # NOTE(review): presumably the empty pattern wipes every recorded test
   # value so that only this run's markers are visible -- confirm.
   trt_convert.clear_test_values("")

   feeds = {"input:0": [[[1.0]]] * batch_size}
   fetched = sess.run("output:0", feed_dict=feeds)
   wanted = [[[4.0]]] * batch_size
   self.assertAllEqual(wanted, fetched)

   # Exactly one of the two markers is expected to be "done".
   engine_marker, native_marker = (
       ("done", "") if expect_engine_is_run else ("", "done"))

   engine_seen = trt_convert.get_test_value("my_trt_op_0:ExecuteTrtEngine")
   self.assertEqual(engine_marker, engine_seen)
   native_seen = trt_convert.get_test_value(
       "my_trt_op_0:ExecuteNativeSegment")
   self.assertEqual(native_marker, native_seen)
Exemple #2
0
 def _TestRun(self, sess, batch_size, expect_engine_is_run):
   """Execute the graph and check both the result and the test-value markers.

   Args:
     sess: object whose `run` method is used to fetch "output:0".
     batch_size: number of `[[1.0]]` inputs fed to "input:0"; the output is
       expected to hold the same number of `[[4.0]]` entries.
     expect_engine_is_run: if truthy, the "ExecuteTrtEngine" marker of
       "TRTEngineOp_0" must be "done" and the "ExecuteNativeSegment" marker
       must be ""; otherwise the expectations are reversed.
   """
   # NOTE(review): presumably the empty pattern clears all recorded test
   # values before the run -- confirm against trt_convert.
   trt_convert.clear_test_values("")

   batch_input = [[[1.0]]] * batch_size
   output = sess.run("output:0", feed_dict={"input:0": batch_input})
   self.assertAllEqual([[[4.0]]] * batch_size, output)

   engine_expected, native_expected = (
       ("done", "") if expect_engine_is_run else ("", "done"))

   # Checked in this order: engine marker first, then native-segment marker.
   expectations = (
       ("TRTEngineOp_0:ExecuteTrtEngine", engine_expected),
       ("TRTEngineOp_0:ExecuteNativeSegment", native_expected),
   )
   for label, expected in expectations:
     self.assertEqual(expected, trt_convert.get_test_value(label))
 def _ExpectTestValue(self, engine_name, method, expected_value):
     """Assert the test value stored under "<engine_name>:<method>".

     Args:
       engine_name: first half of the lookup label.
       method: second half of the lookup label.
       expected_value: value the looked-up test value must equal.
     """
     key = "%s:%s" % (engine_name, method)
     observed = trt_convert.get_test_value(key)
     # Build the failure text up front so the assertion call stays short.
     failure_text = (
         "Unexpected test value with label %s. Actual: %s; expected: %s"
         % (key, observed, expected_value))
     self.assertEqual(expected_value, observed, msg=failure_text)
 def _ExpectTestValue(self, engine_name, method, expected_value):
   """Check that the test value labelled "<engine_name>:<method>" matches.

   Args:
     engine_name: prefix of the label passed to `trt_convert.get_test_value`.
     method: suffix of the label, joined to the prefix with ":".
     expected_value: value the fetched test value is compared against.
   """
   lookup_label = "%s:%s" % (engine_name, method)
   found = trt_convert.get_test_value(lookup_label)
   # Same three fields, in the same order, that the message template expects.
   report_fields = (lookup_label, found, expected_value)
   self.assertEqual(
       expected_value,
       found,
       msg="Unexpected test value with label %s. Actual: %s; expected: %s"
       % report_fields)