Beispiel #1
0
    def _VerifyGraphDefV1(self, run_params, original_gdef, gdef_to_verify,
                          graph_state):
        """Verifies the TRTEngineOp nodes of a GraphDef against run_params.

        Every node yielded by `self._ChainAllNodes(gdef_to_verify)` is visited
        in name order. For each TRTEngineOp node it checks the following:
        - its segment function name is non-empty;
        - a `<node name>_native_segment` function exists in the library;
        - its (sequence-number-stripped) name is an expected engine;
        - its precision mode, static/dynamic flag and calibration flag match
          `run_params`;
        - its serialized segment and calibration data are present or absent
          as the configuration requires.

        Afterwards it checks the total engine count. For non-ORIGINAL states
        it also checks the expected connections and max-batch-size
        annotations. The test attributes of the function library are always
        checked.

        Args:
          run_params: run parameters; this method reads `precision_mode`,
            `dynamic_engine` and `use_calibration` from it.
          original_gdef: the GraphDef before conversion, used when verifying
            connections and max-batch-size annotations.
          gdef_to_verify: the GraphDef under test.
          graph_state: a `GraphState` value. `GraphState.ORIGINAL` expects
            zero engines. Calibration data is expected only in
            `GraphState.INFERENCE` when quantizing with calibration.
        """
        expected_engines = self.ExpectedEnginesToBuild(run_params)
        # A set gives O(1) membership checks for each engine's native segment.
        functions = {f.signature.name for f in gdef_to_verify.library.function}
        # Both values are invariant over the node loop, so compute them once.
        quantize_with_calibration = IsQuantizationWithCalibration(run_params)
        expect_calibration_data = (quantize_with_calibration
                                   and graph_state == GraphState.INFERENCE)

        num_engines = 0
        # Visit nodes in a deterministic (name-sorted) order.
        for node in sorted(self._ChainAllNodes(gdef_to_verify),
                           key=lambda x: x.name):
            if node.op != "TRTEngineOp":
                continue
            logging.info("Found TRTEngineOp: %s", node.name)
            num_engines += 1

            segment_funcdef_name = node.attr["segment_func"].func.name
            function_name = node.name + "_native_segment"
            is_dynamic_engine = not node.attr["static_engine"].b
            self.assertNotEmpty(segment_funcdef_name, node.name)
            self.assertIn(function_name, functions)
            # A serialized segment is required only for static engines when
            # not quantizing with calibration.
            if not quantize_with_calibration and not is_dynamic_engine:
                self.assertTrue(len(node.attr["serialized_segment"].s),
                                node.name)
            self.assertIn(self._RemoveGraphSequenceNumber(node.name),
                          expected_engines)
            self.assertEqual(self._ToBytes(run_params.precision_mode),
                             node.attr["precision_mode"].s, node.name)

            self.assertEqual(run_params.dynamic_engine, is_dynamic_engine,
                             node.name)
            self.assertEqual(node.attr["use_calibration"].b,
                             run_params.use_calibration, node.name)

            has_calibration_data = len(node.attr["calibration_data"].s)
            if expect_calibration_data:
                self.assertTrue(has_calibration_data, node.name)
            else:
                self.assertFalse(has_calibration_data, node.name)

        if graph_state == GraphState.ORIGINAL:
            self.assertEqual(0, num_engines)
        else:
            self.assertEqual(num_engines, len(expected_engines))
            expected_connections = self.ExpectedConnections(run_params)
            if expected_connections:
                self._VerifyConnections(expected_engines, expected_connections,
                                        original_gdef, gdef_to_verify)
            self._VerifyMaxBatchSizeAnnotations(
                expected_engines=expected_engines,
                original_gdef=original_gdef,
                converted_gdef=gdef_to_verify,
                expected_max_batch_sizes=self.ExpectedMaxBatchSizes(
                    run_params),
                default_max_batch_size=self.GetMaxBatchSize(run_params))
        # Test attributes are verified last, whatever the graph state.
        self._VerifyTestAttrs(
            function_protos=gdef_to_verify.library.function)
Beispiel #2
0
 def ShouldRunTest(self, run_params):
   """Adds a calibration-based skip on top of the base class's decision.

   Returns:
     A `(should_run, reason)` tuple. If the base class already skips the
     test, its verdict is returned unchanged. Otherwise the test runs unless
     `run_params` requests quantization with calibration.
   """
   base_should_run, base_reason = (
       trt_test.TfTrtIntegrationTestBase.ShouldRunTest(self, run_params))
   if base_should_run:
     # TODO(kyungtaek): Calibration currently does not run for nodes nested
     # within functions. If this gets fixed, this method should not override
     # the parent method.
     return (not IsQuantizationWithCalibration(run_params),
             "calibration is not supported for tf.functions")
   return base_should_run, base_reason
 def ShouldRunTest(self, run_params):
     """Decides whether this test configuration should run.

     The base class's check runs first. If it skips the test, its verdict is
     returned unchanged. Otherwise V2 runs are skipped. Any remaining run is
     skipped only when `run_params` requests quantization with calibration.

     Returns:
       A `(should_run, reason)` tuple.
     """
     base_should_run, base_reason = (
         trt_test.TfTrtIntegrationTestBase.ShouldRunTest(self, run_params))
     if not base_should_run:
         return base_should_run, base_reason

     # TODO(kyungtaek): The TRTGraphConverterV2 is incapable of reflecting the
     # _noinline attribute from the SavedModel V1 protobuf. We should skip
     # tests in this case for now. When we fix the code to reflect
     # attributes, we should remove this check.
     if run_params.is_v2:
         return (False, "Disabling function inlining is not supported for V2")

     # TODO(kyungtaek): Calibration currently does not run for nodes nested
     # within functions. If this gets fixed, this method should not override
     # the parent method.
     return (not IsQuantizationWithCalibration(run_params),
             "calibration is not supported for tf.functions")