Beispiel #1
0
 def test_compile_with_fp16(self):
     """Compiling with data_type 'FP16' reports float16 inputs and outputs.

     Expects inputs x and y (FORMAT_NONE) and output z, all float16 with
     dims [2, 3, 4].
     """
     fp16_enum = tf.float16.as_datatype_enum  # pylint: disable=no-member
     with NamedTemporaryFile(suffix='.pb') as model_file:
         _save_frozen_graph_model(model_file)
         config = Config.from_json({'max_batch_size': 1, 'data_type': 'FP16'})
         source = FrozenGraphFile(model_path=model_file.name)
         compiled = compiler.compile_source(source, config)
         actual_inputs = compiled.get_inputs()
         expected_inputs = [
             ModelInput(name=input_name,
                        data_type=fp16_enum,
                        format=ModelInput.FORMAT_NONE,  # pylint: disable=no-member
                        dims=[2, 3, 4])
             for input_name in ('x', 'y')
         ]
         self.assertEqual(actual_inputs, expected_inputs)
         self.assertEqual(
             compiled.get_outputs(),
             [ModelOutput(name='z', data_type=fp16_enum, dims=[2, 3, 4])])
Beispiel #2
0
 def test_compile_with_all_params_with_enable_nhwc_to_nchw_true(self):
     """With enable_nhwc_to_nchw set, tensors are reported with dims [4, 2, 3].

     Expects float32 inputs x and y (FORMAT_NONE) and a float32 output z.
     """
     fp32_enum = tf.float32.as_datatype_enum  # pylint: disable=no-member
     with NamedTemporaryFile(suffix='.pb') as model_file:
         _save_frozen_graph_model(model_file)
         config = Config.from_json({
             'input_names': ['x', 'y'],
             'output_names': ['z'],
             'enable_nhwc_to_nchw': True,
             'max_batch_size': 1
         })
         source = FrozenGraphFile(model_path=model_file.name)
         compiled = compiler.compile_source(source, config)
         actual_inputs = compiled.get_inputs()
         expected_inputs = [
             ModelInput(name=input_name,
                        data_type=fp32_enum,
                        format=ModelInput.FORMAT_NONE,  # pylint: disable=no-member
                        dims=[4, 2, 3])
             for input_name in ('x', 'y')
         ]
         self.assertEqual(actual_inputs, expected_inputs)
         self.assertEqual(
             compiled.get_outputs(),
             [ModelOutput(name='z', data_type=fp32_enum, dims=[4, 2, 3])])
Beispiel #3
0
 def test_compile_with_no_input_name(self):
     """Compiling with ``input_names`` set to None leaves ``inputs`` as None.

     Only the output name ('z') and max_batch_size are configured. The
     compiled model must report ``None`` for its inputs.
     """
     with NamedTemporaryFile(suffix='.pb') as model_file:
         _save_frozen_graph_model(model_file)
         config = Config.from_json({'input_names': None,
                                    'output_names': ['z'],
                                    'max_batch_size': 1})
         compiled = compiler.compile_source(
             FrozenGraphFile(model_path=model_file.name), config)
     # assertIsNone checks identity (the idiomatic None test) and reports
     # the actual value on failure.
     self.assertIsNone(compiled.inputs)
Beispiel #4
0
 def test_compile_with_no_input_formats(self):
     """Inputs compiled without explicit formats default to FORMAT_NONE.

     Also checks that the inputs keep the configured names, in order.
     """
     with NamedTemporaryFile(suffix='.pb') as model_file:
         _save_frozen_graph_model(model_file)
         config = Config.from_json({'input_names': ['x', 'y'],
                                    'output_names': ['z'],
                                    'max_batch_size': 4})
         source = FrozenGraphFile(model_path=model_file.name)
         compiled = compiler.compile_source(source, config)
     input_formats = [entry.format for entry in compiled.inputs]
     self.assertEqual(input_formats,
                      [ModelInput.FORMAT_NONE] * 2)  # pylint: disable=no-member
     input_names = [entry.name for entry in compiled.inputs]
     self.assertEqual(input_names, ['x', 'y'])
    def test_compile_with_no_output_name(self):
        """Without output_names, the compiled outputs are expected to be ['z:0']."""
        with NamedTemporaryFile(suffix='.pb') as model_file:
            _save_frozen_graph_model(model_file)
            config = Config.from_json({
                'input_names': ['x:0', 'y:0'],
                'input_formats': ['channels_first', 'channels_last']
            })
            source = FrozenGraphFile(model_path=model_file.name)
            compiled = compiler.compile_source(source, config)

        output_names = [entry.name for entry in compiled.outputs]
        self.assertEqual(output_names, ['z:0'])
    def test_compile_with_no_input_formats(self):
        """Without input_formats, each compiled input's data_format is None."""
        with NamedTemporaryFile(suffix='.pb') as model_file:
            _save_frozen_graph_model(model_file)
            config = Config.from_json({
                'input_names': ['x:0'],
                'output_names': ['z:0']
            })
            source = FrozenGraphFile(model_path=model_file.name)
            compiled = compiler.compile_source(source, config)

        data_formats = [entry.data_format for entry in compiled.inputs]
        self.assertEqual(data_formats, [None])
    def test_compile_with_variables(self):
        """Compile a frozen graph with explicit names and formats.

        Checks the session type, input tensor names, input data formats
        and output names.

        NOTE(review): despite the test name, this uses the same
        ``_save_frozen_graph_model`` fixture as the other tests. Confirm
        whether a variables-based model was intended.
        """
        with NamedTemporaryFile(suffix='.pb') as model_file:
            _save_frozen_graph_model(model_file)
            config = Config.from_json({
                'input_names': ['x:0', 'y:0'],
                'output_names': ['z:0'],
                'input_formats': ['channels_first', 'channels_last']
            })
            source = FrozenGraphFile(model_path=model_file.name)
            compiled = compiler.compile_source(source, config)

        self.assertIsInstance(compiled.session, tf.compat.v1.Session)

        tensor_names = [entry.tensor.name for entry in compiled.inputs]
        self.assertEqual(tensor_names, ['x:0', 'y:0'])

        data_formats = [entry.data_format for entry in compiled.inputs]
        self.assertEqual(data_formats,
                         [DataFormat.CHANNELS_FIRST, DataFormat.CHANNELS_LAST])

        output_names = [entry.name for entry in compiled.outputs]
        self.assertEqual(output_names, ['z:0'])
Beispiel #8
0
 def test_from_json(self):
     """from_json copies the 'input_model' value into model_path."""
     self.assertEqual(
         FrozenGraphFile.from_json({'input_model': 'foo'}).model_path, 'foo')
Beispiel #9
0
 def test_from_env(self):
     """from_env takes model_path from the FROZEN_GRAPH_PATH key of the mapping."""
     self.assertEqual(
         FrozenGraphFile.from_env({'FROZEN_GRAPH_PATH': 'model'}).model_path,
         'model')