def testTensorrtRewriterConfig(self):
  """Test case for trt_convert.tensorrt_rewriter_config().

  Builds a rewriter config with non-default arguments. Checks that exactly one
  "TensorRTOptimizer" custom optimizer is registered, and that each argument is
  recorded in that optimizer's parameter_map with the expected value.
  """
  rewriter_cfg = trt_convert.tensorrt_rewriter_config(
      max_batch_size=128,
      max_workspace_size_bytes=1234,
      precision_mode="INT8",
      minimum_segment_size=10,
      is_dynamic_op=True,
      maximum_cached_engines=2,
      cached_engine_batch_sizes=[1, 128])

  # The TensorRT optimizer must be registered exactly once.
  trt_optimizers = [
      optimizer for optimizer in rewriter_cfg.custom_optimizers
      if optimizer.name == "TensorRTOptimizer"
  ]
  self.assertEqual(1, len(trt_optimizers))
  params = trt_optimizers[0].parameter_map

  for key in [
      "minimum_segment_size", "max_batch_size", "is_dynamic_op",
      "max_workspace_size_bytes", "precision_mode", "maximum_cached_engines",
      "cached_engine_batches"
  ]:
    # assertIn names the missing key on failure, unlike assertTrue(key in ...).
    self.assertIn(key, params)

  self.assertEqual(10, params["minimum_segment_size"].i)
  self.assertEqual(128, params["max_batch_size"].i)
  self.assertTrue(params["is_dynamic_op"].b)
  self.assertEqual(1234, params["max_workspace_size_bytes"].i)
  # precision_mode is compared against the bytes form produced by
  # trt_convert._to_bytes.
  self.assertEqual(trt_convert._to_bytes("INT8"), params["precision_mode"].s)
  self.assertEqual(2, params["maximum_cached_engines"].i)
  # The `cached_engine_batch_sizes` argument is stored under the key
  # "cached_engine_batches". The repeated field is converted to a list so a
  # mismatch is reported as a list diff.
  self.assertEqual([1, 128], list(params["cached_engine_batches"].list.i))
def testTensorrtRewriterConfig(self):
  """Test case for trt_convert.tensorrt_rewriter_config().

  Builds a rewriter config with non-default arguments. Checks that exactly one
  "TensorRTOptimizer" custom optimizer is registered, and that each argument is
  recorded in that optimizer's parameter_map with the expected value.
  """
  rewriter_cfg = trt_convert.tensorrt_rewriter_config(
      max_batch_size=128,
      max_workspace_size_bytes=1234,
      precision_mode="INT8",
      minimum_segment_size=10,
      is_dynamic_op=True,
      maximum_cached_engines=2,
      cached_engine_batch_sizes=[1, 128])

  # The TensorRT optimizer must be registered exactly once.
  trt_optimizers = [
      optimizer for optimizer in rewriter_cfg.custom_optimizers
      if optimizer.name == "TensorRTOptimizer"
  ]
  self.assertEqual(1, len(trt_optimizers))
  params = trt_optimizers[0].parameter_map

  for key in [
      "minimum_segment_size", "max_batch_size", "is_dynamic_op",
      "max_workspace_size_bytes", "precision_mode", "maximum_cached_engines",
      "cached_engine_batches"
  ]:
    # assertIn names the missing key on failure, unlike assertTrue(key in ...).
    self.assertIn(key, params)

  self.assertEqual(10, params["minimum_segment_size"].i)
  self.assertEqual(128, params["max_batch_size"].i)
  self.assertTrue(params["is_dynamic_op"].b)
  self.assertEqual(1234, params["max_workspace_size_bytes"].i)
  # precision_mode is compared against the bytes form produced by
  # trt_convert._to_bytes.
  self.assertEqual(trt_convert._to_bytes("INT8"), params["precision_mode"].s)
  self.assertEqual(2, params["maximum_cached_engines"].i)
  # The `cached_engine_batch_sizes` argument is stored under the key
  # "cached_engine_batches". The repeated field is converted to a list so a
  # mismatch is reported as a list diff.
  self.assertEqual([1, 128], list(params["cached_engine_batches"].list.i))