コード例 #1
0
 def test_from_json_no_names(self):
     """Null input names/formats in JSON map to a Config with ``input_info=None``."""
     json_config = {
         'input_names': None,
         'input_formats': None,
         'output_names': ['output'],
         'max_batch_size': 1,
     }
     actual = Config.from_json(json_config)
     expected = Config(input_info=None,
                       output_names=['output'],
                       max_batch_size=1)
     self.assertEqual(actual, expected)
コード例 #2
0
 def test_from_json(self):
     """Parallel ``input_names``/``input_formats`` JSON lists become ``input_info`` pairs."""
     json_config = {
         'input_names': ['input'],
         'input_formats': ['channels_first'],
         'output_names': ['output'],
         'max_batch_size': 1,
     }
     actual = Config.from_json(json_config)
     expected_inputs = [('input', DataFormat.CHANNELS_FIRST)]
     expected = Config(input_info=expected_inputs,
                       output_names=['output'],
                       max_batch_size=1)
     self.assertEqual(actual, expected)
コード例 #3
0
 def test_from_env(self):
     """Comma-separated env vars are split and zipped into ``input_info`` pairs."""
     environment = {
         'INPUT_NAMES': 'input1,input2:0',
         'OUTPUT_NAMES': 'output',
         'INPUT_FORMATS': 'channels_first,channels_first',
         'MAX_BATCH_SIZE': '1',
     }
     actual = Config.from_env(environment)
     expected_inputs = [
         ('input1', DataFormat.CHANNELS_FIRST),
         ('input2:0', DataFormat.CHANNELS_FIRST),
     ]
     expected = Config(input_info=expected_inputs,
                       output_names=['output'],
                       max_batch_size=1)
     self.assertEqual(actual, expected)
コード例 #4
0
 def test_from_env_all_names(self):
     """All supported env vars are parsed into the matching Config fields.

     The string values ('[1, 2, 3, 4]', '1', '0') are expected to become a
     nested shape list, an int batch size and ``False`` respectively.
     """
     environment = {
         'INPUT_NAMES': 'input',
         'INPUT_SHAPES': '[1, 2, 3, 4]',
         'OUTPUT_NAMES': 'output',
         'MAX_BATCH_SIZE': '1',
         'ENABLE_NHWC_TO_NCHW': '0',
     }
     actual = Config.from_env(environment)
     expected = Config(input_names=['input'],
                       input_shapes=[[1, 2, 3, 4]],
                       output_names=['output'],
                       max_batch_size=1,
                       enable_nhwc_to_nchw=False)
     self.assertEqual(actual, expected)
コード例 #5
0
 def test_from_json_all_names(self):
     """Every supported JSON key is carried over unchanged into Config."""
     json_config = {
         'input_names': ['input'],
         'input_shapes': [[1, 2, 3, 4]],
         'output_names': ['output'],
         'max_batch_size': 1,
         'enable_nhwc_to_nchw': False,
     }
     actual = Config.from_json(json_config)
     expected = Config(input_names=['input'],
                       input_shapes=[[1, 2, 3, 4]],
                       output_names=['output'],
                       max_batch_size=1,
                       enable_nhwc_to_nchw=False)
     self.assertEqual(actual, expected)
コード例 #6
0
 def test_compile_with_fp16(self):
     """Compiling with ``data_type: FP16`` makes every input and output float16.

     The assertions expect the frozen graph written by
     ``_save_frozen_graph_model`` to expose inputs ``x`` and ``y`` and output
     ``z``, each with dims ``[2, 3, 4]``.
     """
     with NamedTemporaryFile(suffix='.pb') as model_file:
         _save_frozen_graph_model(model_file)
         config = Config.from_json({
             'max_batch_size': 1,
             'data_type': 'FP16'
         })
         compiled = compiler.compile_source(
             FrozenGraphFile(model_path=model_file.name), config)
         float16 = tf.float16.as_datatype_enum
         # An inline pylint pragma only covers its own physical line, so the
         # disable goes on the line that accesses ModelInput.FORMAT_NONE
         # (presumably a generated constant pylint cannot introspect).
         self.assertEqual(
             compiled.get_inputs(),
             [
                 ModelInput(name=input_name,
                            data_type=float16,
                            format=ModelInput.FORMAT_NONE,  # pylint: disable=no-member
                            dims=[2, 3, 4])
                 for input_name in ('x', 'y')
             ])
         self.assertEqual(compiled.get_outputs(), [
             ModelOutput(name='z',
                         data_type=float16,
                         dims=[2, 3, 4])
         ])
コード例 #7
0
 def test_compile_with_all_params_with_enable_nhwc_to_nchw_true(self):
     """With ``enable_nhwc_to_nchw`` set, the last axis is moved to the front.

     The assertions expect the saved graph's inputs ``x``/``y`` and output
     ``z`` to be reported with dims ``[4, 2, 3]`` and dtype float32.
     """
     with NamedTemporaryFile(suffix='.pb') as model_file:
         _save_frozen_graph_model(model_file)
         config = Config.from_json({
             'input_names': ['x', 'y'],
             'output_names': ['z'],
             'enable_nhwc_to_nchw': True,
             'max_batch_size': 1
         })
         compiled = compiler.compile_source(
             FrozenGraphFile(model_path=model_file.name), config)
         float32 = tf.float32.as_datatype_enum
         # An inline pylint pragma only covers its own physical line, so the
         # disable goes on the line that accesses ModelInput.FORMAT_NONE.
         self.assertEqual(
             compiled.get_inputs(),
             [
                 ModelInput(name=input_name,
                            data_type=float32,
                            format=ModelInput.FORMAT_NONE,  # pylint: disable=no-member
                            dims=[4, 2, 3])
                 for input_name in ('x', 'y')
             ])
         self.assertEqual(compiled.get_outputs(), [
             ModelOutput(name='z',
                         data_type=float32,
                         dims=[4, 2, 3])
         ])
コード例 #8
0
 def test_compile_with_no_input_name(self):
     """With ``input_names: None`` the compiled model's ``inputs`` is None."""
     with NamedTemporaryFile(suffix='.pb') as model_file:
         _save_frozen_graph_model(model_file)
         config = Config.from_json({'input_names': None,
                                    'output_names': ['z'],
                                    'max_batch_size': 1})
         compiled = compiler.compile_source(FrozenGraphFile(model_path=model_file.name), config)
     # assertIsNone checks identity (not ==) and reports the actual value on failure.
     self.assertIsNone(compiled.inputs)
コード例 #9
0
 def test_compile_with_no_input_formats(self):
     """Inputs named without formats compile to FORMAT_NONE, preserving name order."""
     with NamedTemporaryFile(suffix='.pb') as model_file:
         _save_frozen_graph_model(model_file)
         json_config = {'input_names': ['x', 'y'],
                        'output_names': ['z'],
                        'max_batch_size': 4}
         config = Config.from_json(json_config)
         source = FrozenGraphFile(model_path=model_file.name)
         compiled = compiler.compile_source(source, config)
     actual_formats = [entry.format for entry in compiled.inputs]
     no_format = [ModelInput.FORMAT_NONE, ModelInput.FORMAT_NONE]  # pylint: disable=no-member
     self.assertEqual(actual_formats, no_format)
     actual_names = [entry.name for entry in compiled.inputs]
     self.assertEqual(actual_names, ['x', 'y'])
コード例 #10
0
 def test_from_env_no_names(self):
     """Without INPUT_NAMES in the env mapping, Config gets ``input_info=None``."""
     environment = {
         'OUTPUT_NAMES': 'output',
         'MAX_BATCH_SIZE': '1',
     }
     actual = Config.from_env(environment)
     expected = Config(input_info=None,
                       output_names=['output'],
                       max_batch_size=1)
     self.assertEqual(actual, expected)
コード例 #11
0
 def test_from_env_no_names(self):
     """An empty env mapping yields a Config equal to ``Config()`` with no arguments."""
     # NOTE(review): another snippet in this collection defines a test with this
     # same name; if both end up in one TestCase class, the later definition
     # silently replaces the earlier one and only one of them runs.
     parsed = Config.from_env({})
     self.assertEqual(parsed, Config())
コード例 #12
0
 def test_from_json_no_names(self):
     """An empty JSON object yields a Config equal to ``Config()`` with no arguments."""
     # NOTE(review): another snippet in this collection defines a test with this
     # same name; if both end up in one TestCase class, the later definition
     # silently replaces the earlier one and only one of them runs.
     parsed = Config.from_json({})
     self.assertEqual(parsed, Config())