def _GetConfigProto(self, run_params, graph_state):
        """Builds the session ConfigProto for one test run.

        When ``graph_state`` is not ``GraphState.ORIGINAL`` and
        ``run_params.use_optimizer`` is set, the rewrite options are produced
        by ``trt_convert.get_tensorrt_rewriter_config`` from this test's
        conversion parameters. Otherwise default ``GraphOptions`` are used,
        with the conversion parameters' ``rewriter_config`` copied in when it
        is not None.

        Args:
          run_params: run parameters; passed to ``GetConversionParams`` and
            its ``use_optimizer`` flag is read here.
          graph_state: a ``GraphState`` value for the graph being run.

        Returns:
          A ``config_pb2.ConfigProto`` combining ``self._GetGPUOptions()`` with
          the graph options selected above.
        """
        params = self.GetConversionParams(run_params)
        use_trt_rewriter = (graph_state != GraphState.ORIGINAL and
                            run_params.use_optimizer)

        if not use_trt_rewriter:
            graph_options = config_pb2.GraphOptions()
            # Carry over the caller-supplied rewriter config, if any.
            if params.rewriter_config is not None:
                graph_options.rewrite_options.CopyFrom(params.rewriter_config)
        else:
            # Arguments are positional: their order must match the signature
            # of get_tensorrt_rewriter_config.
            trt_rewrite_options = trt_convert.get_tensorrt_rewriter_config(
                params.rewriter_config,
                params.max_batch_size,
                params.max_workspace_size_bytes,
                params.precision_mode,
                params.minimum_segment_size,
                params.is_dynamic_op,
                params.maximum_cached_engines,
                params.cached_engine_batches,
                params.use_calibration)
            graph_options = config_pb2.GraphOptions(
                rewrite_options=trt_rewrite_options)

        return config_pb2.ConfigProto(gpu_options=self._GetGPUOptions(),
                                      graph_options=graph_options)
  def _GetConfigProto(self, run_params, graph_state):
    """Builds the session ConfigProto for one test run.

    When ``graph_state`` is not ``GraphState.ORIGINAL`` and
    ``run_params.use_optimizer`` is set, the rewrite options are produced by
    ``trt_convert.get_tensorrt_rewriter_config`` from this test's conversion
    parameters. Otherwise default ``GraphOptions`` are used, with the
    conversion parameters' ``rewriter_config`` copied in when it is not None.

    Args:
      run_params: run parameters; passed to ``GetConversionParams`` and its
        ``use_optimizer`` flag is read here.
      graph_state: a ``GraphState`` value for the graph being run.

    Returns:
      A ``config_pb2.ConfigProto`` combining ``self._GetGPUOptions()`` with the
      graph options selected above.
    """
    # Needed on both paths: the fallback branch also reads rewriter_config.
    conversion_params = self.GetConversionParams(run_params)
    if graph_state != GraphState.ORIGINAL and run_params.use_optimizer:
      # Keyword arguments keep each value bound to the intended parameter even
      # if the positional order of get_tensorrt_rewriter_config changes.
      rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(
          rewriter_config=conversion_params.rewriter_config,
          max_batch_size=conversion_params.max_batch_size,
          max_workspace_size_bytes=conversion_params.max_workspace_size_bytes,
          precision_mode=conversion_params.precision_mode,
          minimum_segment_size=conversion_params.minimum_segment_size,
          is_dynamic_op=conversion_params.is_dynamic_op,
          maximum_cached_engines=conversion_params.maximum_cached_engines,
          cached_engine_batch_sizes=(
              conversion_params.cached_engine_batch_sizes))

      graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
    else:
      graph_options = config_pb2.GraphOptions()
      # Without this copy the non-TRT run would use default rewrite options
      # while the TRT run starts from the custom rewriter_config above.
      if conversion_params.rewriter_config is not None:
        graph_options.rewrite_options.CopyFrom(
            conversion_params.rewriter_config)

    config = config_pb2.ConfigProto(
        gpu_options=self._GetGPUOptions(), graph_options=graph_options)
    return config
 def testGetTensorrtRewriterConfig(self):
     """Checks the RewriterConfig returned by get_tensorrt_rewriter_config().

     Builds a config from explicit conversion settings, then verifies the
     grappler optimizer list, a single meta-optimizer iteration, and that
     exactly one "TensorRTOptimizer" custom optimizer carries each setting
     in its parameter_map.
     """
     cfg = trt_convert.get_tensorrt_rewriter_config(
         rewriter_config=None,
         max_batch_size=128,
         max_workspace_size_bytes=1234,
         precision_mode="INT8",
         minimum_segment_size=10,
         is_dynamic_op=True,
         maximum_cached_engines=2,
         cached_engine_batches=[1, 128])
     self.assertEqual(["constfold", "layout", "constfold"], cfg.optimizers)
     self.assertEqual(rewriter_config_pb2.RewriterConfig.ONE,
                      cfg.meta_optimizer_iterations)

     # The TensorRT optimizer must be registered exactly once.
     trt_optimizers = [
         custom for custom in cfg.custom_optimizers
         if custom.name == "TensorRTOptimizer"
     ]
     self.assertEqual(1, len(trt_optimizers))
     params = trt_optimizers[0].parameter_map

     # Every expected key must be present before any value is checked.
     for key in ("minimum_segment_size", "max_batch_size", "is_dynamic_op",
                 "max_workspace_size_bytes", "precision_mode",
                 "maximum_cached_engines", "cached_engine_batches"):
         self.assertIn(key, params)

     self.assertEqual(10, params["minimum_segment_size"].i)
     self.assertEqual(128, params["max_batch_size"].i)
     self.assertEqual(True, params["is_dynamic_op"].b)
     self.assertEqual(1234, params["max_workspace_size_bytes"].i)
     self.assertEqual(trt_convert._to_bytes("INT8"),
                      params["precision_mode"].s)
     self.assertEqual(2, params["maximum_cached_engines"].i)
     # The list-valued parameter must hold the requested batches in order.
     self.assertEqual([1, 128], params["cached_engine_batches"].list.i)
 def testGetTensorrtRewriterConfig(self):
   """Checks the RewriterConfig returned by get_tensorrt_rewriter_config().

   Builds a config from explicit conversion settings, then verifies the
   grappler optimizer list, a single meta-optimizer iteration, and that
   exactly one "TensorRTOptimizer" custom optimizer carries each setting in
   its parameter_map.
   """
   cfg = trt_convert.get_tensorrt_rewriter_config(
       rewriter_config=None,
       max_batch_size=128,
       max_workspace_size_bytes=1234,
       precision_mode="INT8",
       minimum_segment_size=10,
       is_dynamic_op=True,
       maximum_cached_engines=2,
       cached_engine_batch_sizes=[1, 128])
   self.assertEqual(["constfold", "layout", "constfold"], cfg.optimizers)
   self.assertEqual(rewriter_config_pb2.RewriterConfig.ONE,
                    cfg.meta_optimizer_iterations)

   # The TensorRT optimizer must be registered exactly once.
   trt_optimizers = [
       custom for custom in cfg.custom_optimizers
       if custom.name == "TensorRTOptimizer"
   ]
   self.assertEqual(1, len(trt_optimizers))
   params = trt_optimizers[0].parameter_map

   # Note: the argument is named cached_engine_batch_sizes, but the
   # parameter_map key checked here is cached_engine_batches.
   for key in ("minimum_segment_size", "max_batch_size", "is_dynamic_op",
               "max_workspace_size_bytes", "precision_mode",
               "maximum_cached_engines", "cached_engine_batches"):
     self.assertIn(key, params)

   self.assertEqual(10, params["minimum_segment_size"].i)
   self.assertEqual(128, params["max_batch_size"].i)
   self.assertEqual(True, params["is_dynamic_op"].b)
   self.assertEqual(1234, params["max_workspace_size_bytes"].i)
   self.assertEqual(trt_convert._to_bytes("INT8"), params["precision_mode"].s)
   self.assertEqual(2, params["maximum_cached_engines"].i)
   # The list-valued parameter must hold the requested batch sizes in order.
   self.assertEqual([1, 128], params["cached_engine_batches"].list.i)