Example #1
0
 def test_compile_with_fp16(self):
     """Compile a frozen graph with ``data_type`` set to ``'FP16'``.

     The compiled model is expected to report inputs ``x`` and ``y`` and
     output ``z``, all typed as ``tf.float16`` with dims ``[2, 3, 4]``.
     """
     with NamedTemporaryFile(suffix='.pb') as graph_file:
         _save_frozen_graph_model(graph_file)
         options = Config.from_json({'max_batch_size': 1, 'data_type': 'FP16'})
         source = FrozenGraphFile(model_path=graph_file.name)
         compiled = compiler.compile_source(source, options)

         actual_inputs = compiled.get_inputs()
         half = tf.float16.as_datatype_enum
         no_format = ModelInput.FORMAT_NONE  # pylint: disable=no-member
         expected_inputs = [
             ModelInput(name=tensor_name, data_type=half, format=no_format, dims=[2, 3, 4])
             for tensor_name in ('x', 'y')
         ]
         self.assertEqual(actual_inputs, expected_inputs)

         actual_outputs = compiled.get_outputs()
         expected_outputs = [ModelOutput(name='z', data_type=half, dims=[2, 3, 4])]
         self.assertEqual(actual_outputs, expected_outputs)
Example #2
0
 def test_compile_with_all_params_with_enable_nhwc_to_nchw_true(self):
     """Compile with explicit tensor names and ``enable_nhwc_to_nchw`` set to True.

     The compiled model is expected to report inputs ``x`` and ``y`` and
     output ``z``, all typed as ``tf.float32`` with dims ``[4, 2, 3]``.
     """
     with NamedTemporaryFile(suffix='.pb') as graph_file:
         _save_frozen_graph_model(graph_file)
         options = Config.from_json({
             'input_names': ['x', 'y'],
             'output_names': ['z'],
             'enable_nhwc_to_nchw': True,
             'max_batch_size': 1,
         })
         source = FrozenGraphFile(model_path=graph_file.name)
         compiled = compiler.compile_source(source, options)

         actual_inputs = compiled.get_inputs()
         single = tf.float32.as_datatype_enum
         no_format = ModelInput.FORMAT_NONE  # pylint: disable=no-member
         expected_inputs = [
             ModelInput(name=tensor_name, data_type=single, format=no_format, dims=[4, 2, 3])
             for tensor_name in ('x', 'y')
         ]
         self.assertEqual(actual_inputs, expected_inputs)

         actual_outputs = compiled.get_outputs()
         expected_outputs = [ModelOutput(name='z', data_type=single, dims=[4, 2, 3])]
         self.assertEqual(actual_outputs, expected_outputs)
Example #3
0
 def test_compile_with_no_input_name(self):
     """Compile with ``input_names`` set to None and check ``compiled.inputs`` is None."""
     with NamedTemporaryFile(suffix='.pb') as model_file:
         _save_frozen_graph_model(model_file)
         config = Config.from_json({'input_names': None,
                                    'output_names': ['z'],
                                    'max_batch_size': 1})
         compiled = compiler.compile_source(FrozenGraphFile(model_path=model_file.name), config)
     # Identity check: unlike assertEqual(..., None), this cannot be satisfied
     # by an object whose __eq__ happens to compare equal to None.
     self.assertIsNone(compiled.inputs)
Example #4
0
 def test_compile_with_no_input_formats(self):
     """Compile without specifying input formats.

     Both compiled inputs are expected to carry ``ModelInput.FORMAT_NONE``
     and to be named ``x`` and ``y``, in that order.
     """
     with NamedTemporaryFile(suffix='.pb') as graph_file:
         _save_frozen_graph_model(graph_file)
         options = Config.from_json({
             'input_names': ['x', 'y'],
             'output_names': ['z'],
             'max_batch_size': 4,
         })
         source = FrozenGraphFile(model_path=graph_file.name)
         compiled = compiler.compile_source(source, options)

     actual_formats = [entry.format for entry in compiled.inputs]
     expected_formats = [ModelInput.FORMAT_NONE] * 2  # pylint: disable=no-member
     self.assertEqual(actual_formats, expected_formats)

     actual_names = [entry.name for entry in compiled.inputs]
     self.assertEqual(actual_names, ['x', 'y'])