Example #1
0
    def testEval(self):
        """Checks MNIST eval accuracy natively and, when supported, via TF-TRT.

        Returns early (so the test passes without asserting anything) when
        is_tensorrt_enabled() is false or the TensorRT version is >= 7.1.3.
        The TF-TRT accuracy check additionally requires TensorRT 5 or newer;
        on older versions only the native accuracy is asserted.
        """
        if not is_tensorrt_enabled():
            return

        # TODO(b/162447069): Enable the test for TRT 7.1.3.
        if trt_test.IsTensorRTVersionGreaterEqual(7, 1, 3):
            return

        mnist_dir = test.test_src_dir_path(
            'python/compiler/tensorrt/test/testdata/mnist')

        def eval_accuracy(with_trt):
            # One evaluation pass (is_training=False) using the model in
            # `mnist_dir`; only the 'accuracy' metric is consumed here.
            metrics = self._Run(is_training=False,
                                use_trt=with_trt,
                                batch_size=128,
                                num_epochs=None,
                                model_dir=mnist_dir)
            return metrics['accuracy']

        native_accuracy = eval_accuracy(False)
        logging.info('accuracy_tf_native: %f', native_accuracy)
        self.assertAllClose(0.9662, native_accuracy, rtol=3e-3, atol=3e-3)

        # The TF-TRT comparison is only exercised on TensorRT 5 or newer.
        if trt_test.IsTensorRTVersionGreaterEqual(5):
            trt_accuracy = eval_accuracy(True)
            logging.info('accuracy_tf_trt: %f', trt_accuracy)
            self.assertAllClose(0.9675, trt_accuracy, rtol=1e-3, atol=1e-3)
 def ShouldRunTest(self, run_params):
   """Decides whether this test runs for `run_params`.

   Returns:
     A `(should_run, reason)` tuple.
   """
   # TODO(b/162447069): Enable the test for TRT 7.1.3.
   if trt_test.IsTensorRTVersionGreaterEqual(7, 1, 3):
     return (False, 'Skip test due to b/162447069')

   # CombinedNonMaxSuppression has no GPU op at the moment, so calibration
   # would fail; quantized precision modes are therefore excluded, and only
   # TensorRT 5.1 and above is targeted.
   # TODO(laigd): fix this.
   supported_trt = trt_test.IsTensorRTVersionGreaterEqual(5, 1)
   should_run = supported_trt and not trt_test.IsQuantizationMode(
       run_params.precision_mode)
   return should_run, 'test >=TRT5.1 and non-INT8'
Example #3
0
 def ShouldRunTest(self, run_params):
     """Narrows the parent's decision to non-INT8 modes on TRT >= 7.1.3.

     Returns:
       A `(should_run, reason)` tuple whose reason extends the parent's.
     """
     base_ok, base_reason = super().ShouldRunTest(run_params)
     not_int8 = base_ok and not trt_test.IsQuantizationMode(
         run_params.precision_mode)
     # Only run for TRT 7.1.3 and above.
     verdict = not_int8 and trt_test.IsTensorRTVersionGreaterEqual(7, 1, 3)
     reason = base_reason + ' and precision != INT8'
     return verdict, reason + ' and >= TRT 7.1.3'
Example #4
0
 def ShouldRunTest(self, run_params):
     """Allows only non-quantized precision modes on TensorRT 5.1 or newer.

     Returns:
       A `(should_run, reason)` tuple.
     """
     # CombinedNonMaxSuppression has no GPU op at the moment, so calibration
     # would fail; quantized precision modes are therefore excluded.
     # TODO(laigd): fix this.
     has_trt_5_1 = trt_test.IsTensorRTVersionGreaterEqual(5, 1)
     eligible = has_trt_5_1 and not trt_test.IsQuantizationMode(
         run_params.precision_mode)
     return eligible, 'test >=TRT5.1 and non-INT8'
Example #5
0
 def ShouldRunTest(self, run_params):
     """Skips on TensorRT >= 7.1.3 (b/162448349); otherwise defers upward.

     Returns:
       `(False, reason)` on TensorRT >= 7.1.3, otherwise whatever the parent
       class's ShouldRunTest returns for `run_params`.
     """
     # TODO(b/162448349): Enable the test for TRT 7.1.3.
     blocked = trt_test.IsTensorRTVersionGreaterEqual(7, 1, 3)
     if not blocked:
         return super().ShouldRunTest(run_params)
     return (False, "Skip test due to b/162448349")
 def ShouldRunTest(self, run_params):
   """Decides whether this test runs for `run_params`.

   Returns:
     A `(should_run, reason)` tuple.
   """
   # Requires a v2 run (run_params.is_v2), TensorRT 6 or newer, and no
   # calibration.
   eligible = (run_params.is_v2 and
               trt_test.IsTensorRTVersionGreaterEqual(6) and
               not run_params.use_calibration)
   return eligible, "test v2 >=TRT6 and non-calibration"
Example #7
0
 def ShouldRunTest(self, run_params):
   """Restricts the test to offline INT8 conversion on TensorRT 5 or newer.

   Returns:
     A `(should_run, reason)` tuple.
   """
   # Test static/dynamic engine with/without calibration.
   offline_int8 = (trt_test.IsTensorRTVersionGreaterEqual(5)
                   and trt_test.IsQuantizationMode(run_params.precision_mode)
                   and not run_params.convert_online)
   return offline_int8, "test offline conversion and INT8"
Example #8
0
 def setUp(self):
     """Sets up the test; on TensorRT >= 7 also sets an environment flag.

     When trt_test.IsTensorRTVersionGreaterEqual(7) is true, this sets
     os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] to "True".
     NOTE(review): nothing here restores the previous value, so the variable
     stays set for the rest of the process unless other code resets it.
     """
     # super() is given trt_test.TfTrtIntegrationTestBase rather than this
     # class, so method lookup starts *after* TfTrtIntegrationTestBase in the
     # MRO and TfTrtIntegrationTestBase.setUp itself is bypassed. This looks
     # deliberate (presumably that setUp does something this test must not
     # inherit) -- confirm before "simplifying" it to super().setUp().
     super(trt_test.TfTrtIntegrationTestBase, self).setUp()
     if trt_test.IsTensorRTVersionGreaterEqual(7):
         os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "True"
Example #9
0
 def ShouldRunTest(self, run_params):
     """Runs only dynamic-engine, non-quantized cases below TensorRT 7.1.3.

     Returns:
       A `(should_run, reason)` tuple.
     """
     # TODO(b/162448349): Enable the test for TRT 7.1.3.
     if trt_test.IsTensorRTVersionGreaterEqual(7, 1, 3):
         return (False, "Skip test due to b/162448349")
     dynamic_non_int8 = run_params.dynamic_engine and (
         not trt_test.IsQuantizationMode(run_params.precision_mode))
     return dynamic_non_int8, "test dynamic engine and non-INT8"
Example #10
0
 def ShouldRunTest(self, run_params):
     """Adds a TensorRT >= 5.1 requirement on top of the parent's decision.

     Returns:
       A `(should_run, reason)` tuple whose reason extends the parent's.
     """
     parent_ok, parent_reason = super().ShouldRunTest(run_params)
     # Only run for TRT 5.1 and above.
     trt_ok = parent_ok and trt_test.IsTensorRTVersionGreaterEqual(5, 1)
     return trt_ok, parent_reason + ' and >=TRT5.1'