def test_compile_with_fp16(self):
    """Compiling with ``data_type: FP16`` reports float16 inputs and outputs.

    The config names no inputs or outputs. The test expects the compiled
    model to expose inputs ``x`` and ``y`` and output ``z``, all float16,
    with dims ``[2, 3, 4]`` and (for inputs) ``ModelInput.FORMAT_NONE``.
    """
    with NamedTemporaryFile(suffix='.pb') as model_file:
        _save_frozen_graph_model(model_file)

        fp16_config = Config.from_json({'max_batch_size': 1, 'data_type': 'FP16'})
        source = FrozenGraphFile(model_path=model_file.name)
        compiled = compiler.compile_source(source, fp16_config)

        half = tf.float16.as_datatype_enum
        expected_inputs = [
            ModelInput(name=input_name,
                       data_type=half,
                       format=ModelInput.FORMAT_NONE,  # pylint: disable=no-member
                       dims=[2, 3, 4])
            for input_name in ('x', 'y')
        ]
        expected_outputs = [ModelOutput(name='z', data_type=half, dims=[2, 3, 4])]

        self.assertEqual(compiled.get_inputs(), expected_inputs)
        self.assertEqual(compiled.get_outputs(), expected_outputs)
def test_compile_with_all_params_with_enable_nhwc_to_nchw_true(self):
    """Compiling with ``enable_nhwc_to_nchw: True`` reports dims ``[4, 2, 3]``.

    Inputs ``x``/``y`` and output ``z`` are named explicitly with batch size 1.
    The test expects float32 tensors whose reported dims are ``[4, 2, 3]``.
    Presumably this reflects the NHWC -> NCHW layout change; TODO confirm
    against the compiler.
    """
    with NamedTemporaryFile(suffix='.pb') as model_file:
        _save_frozen_graph_model(model_file)

        settings = Config.from_json({
            'input_names': ['x', 'y'],
            'output_names': ['z'],
            'enable_nhwc_to_nchw': True,
            'max_batch_size': 1,
        })
        source = FrozenGraphFile(model_path=model_file.name)
        compiled = compiler.compile_source(source, settings)

        single = tf.float32.as_datatype_enum
        expected_inputs = [
            ModelInput(name=input_name,
                       data_type=single,
                       format=ModelInput.FORMAT_NONE,  # pylint: disable=no-member
                       dims=[4, 2, 3])
            for input_name in ('x', 'y')
        ]
        expected_outputs = [ModelOutput(name='z', data_type=single, dims=[4, 2, 3])]

        self.assertEqual(compiled.get_inputs(), expected_inputs)
        self.assertEqual(compiled.get_outputs(), expected_outputs)
def test_compile_with_no_input_name(self):
    """Compiling with ``input_names: None`` yields ``compiled.inputs is None``.

    Only the output name ``z`` is given (batch size 1).
    """
    with NamedTemporaryFile(suffix='.pb') as model_file:
        _save_frozen_graph_model(model_file)
        config = Config.from_json({'input_names': None, 'output_names': ['z'], 'max_batch_size': 1})
        compiled = compiler.compile_source(FrozenGraphFile(model_path=model_file.name), config)

        # assertIsNone checks identity (`is None`) instead of `==`, so an object
        # with a permissive __eq__ cannot pass by accident, and the failure
        # message shows the unexpected value.
        self.assertIsNone(compiled.inputs)
def test_compile_with_no_input_formats(self):
    """Without ``input_formats``, every named input gets ``FORMAT_NONE``.

    Inputs ``x`` and ``y`` are named explicitly (batch size 4). The test
    expects both inputs, in that order, to report ``ModelInput.FORMAT_NONE``.
    """
    with NamedTemporaryFile(suffix='.pb') as model_file:
        _save_frozen_graph_model(model_file)
        settings = Config.from_json({'input_names': ['x', 'y'], 'output_names': ['z'], 'max_batch_size': 4})
        model = compiler.compile_source(FrozenGraphFile(model_path=model_file.name), settings)

        no_format = ModelInput.FORMAT_NONE  # pylint: disable=no-member
        self.assertEqual([entry.format for entry in model.inputs], [no_format] * 2)
        self.assertEqual([entry.name for entry in model.inputs], ['x', 'y'])
def test_compile_with_no_output_name(self):
    """Omitting ``output_names`` still yields the single output ``z:0``.

    Only input names and input formats are configured. The test expects
    the compiled model's output list to be exactly ``['z:0']``.
    """
    config_json = {
        'input_names': ['x:0', 'y:0'],
        'input_formats': ['channels_first', 'channels_last'],
    }
    with NamedTemporaryFile(suffix='.pb') as model_file:
        _save_frozen_graph_model(model_file)
        config = Config.from_json(config_json)
        source = FrozenGraphFile(model_path=model_file.name)
        compiled = compiler.compile_source(source, config)

        output_names = [output.name for output in compiled.outputs]
        self.assertEqual(output_names, ['z:0'])
def test_compile_with_no_input_formats(self):
    """Without ``input_formats``, the single input's ``data_format`` is ``None``.

    NOTE(review): another test in this file has this exact name (the variant
    that checks ``ModelInput.FORMAT_NONE`` with ``max_batch_size: 4``). If both
    definitions are in the same class, this later one silently replaces the
    earlier one and only one of them runs. Confirm they live in different
    classes, or rename one.
    """
    with NamedTemporaryFile(suffix='.pb') as model_file:
        _save_frozen_graph_model(model_file)
        config = Config.from_json({'input_names': ['x:0'], 'output_names': ['z:0']})
        source = FrozenGraphFile(model_path=model_file.name)
        compiled = compiler.compile_source(source, config)

        formats = [item.data_format for item in compiled.inputs]
        self.assertEqual(formats, [None])
def test_compile_with_variables(self):
    """Compile with explicit names and formats, then check session and metadata.

    The test expects:
    - a ``tf.compat.v1.Session`` on the result;
    - input tensors ``x:0`` and ``y:0``, mapped to CHANNELS_FIRST and
      CHANNELS_LAST respectively;
    - the single output ``z:0``.

    NOTE(review): despite the name, this uses the same
    ``_save_frozen_graph_model`` fixture as the neighbouring tests. Confirm
    whether a model with variables was intended here.
    """
    with NamedTemporaryFile(suffix='.pb') as model_file:
        _save_frozen_graph_model(model_file)
        config = Config.from_json({
            'input_names': ['x:0', 'y:0'],
            'output_names': ['z:0'],
            'input_formats': ['channels_first', 'channels_last'],
        })
        source = FrozenGraphFile(model_path=model_file.name)
        compiled = compiler.compile_source(source, config)

        self.assertIsInstance(compiled.session, tf.compat.v1.Session)

        tensor_names = [item.tensor.name for item in compiled.inputs]
        self.assertEqual(tensor_names, ['x:0', 'y:0'])

        data_formats = [item.data_format for item in compiled.inputs]
        self.assertEqual(data_formats, [DataFormat.CHANNELS_FIRST, DataFormat.CHANNELS_LAST])

        output_names = [item.name for item in compiled.outputs]
        self.assertEqual(output_names, ['z:0'])
def test_from_json(self):
    """``FrozenGraphFile.from_json`` takes the model path from ``input_model``."""
    path = 'foo'
    source = FrozenGraphFile.from_json({'input_model': path})
    self.assertEqual(source.model_path, path)
def test_from_env(self):
    """``FrozenGraphFile.from_env`` takes the model path from ``FROZEN_GRAPH_PATH``."""
    path = 'model'
    source = FrozenGraphFile.from_env({'FROZEN_GRAPH_PATH': path})
    self.assertEqual(source.model_path, path)