Ejemplo n.º 1
0
 def GetConversionParams(self, run_params):
     """Build the ConversionParams used by this test.

     Takes the parent class's conversion parameters for `run_params` and
     returns them via `_replace` with only the rewriter config overridden.
     """
     base_params = super(SimpleMultiEnginesTest,
                         self).GetConversionParams(run_params)
     # The layout optimizer would add Transpose(Const, Const) to the graph and
     # break the conversion check, so use a rewriter config with it disabled.
     disabled_rewriter = trt_test.OptimizerDisabledRewriterConfig()
     return base_params._replace(rewriter_config=disabled_rewriter)
 def GetConversionParams(self, run_params):
   """Return a ConversionParams for test."""
   conversion_params = super(DynamicInputShapesTest,
                             self).GetConversionParams(run_params)
   return conversion_params._replace(
       maximum_cached_engines=10,
       # Disable layout optimizer, since it will convert BiasAdd with NHWC
       # format to NCHW format under four dimentional input.
       rewriter_config=trt_test.OptimizerDisabledRewriterConfig())
 def GetConversionParams(self, run_params):
     """Return a ConversionParams for test."""
     conversion_params = super(ExcludeUnsupportedInt32Test,
                               self).GetConversionParams(run_params)
     return conversion_params._replace(
         max_batch_size=100,
         maximum_cached_engines=1,
         # Disable layout optimizer, since it will convert BiasAdd with NHWC
         # format to NCHW format under four dimentional input.
         rewriter_config_template=trt_test.OptimizerDisabledRewriterConfig(
         ))