def test_from_env(self):
    """Config.from_env splits comma-separated names and maps INPUT_FORMATS to DataFormat."""
    environment = {
        'INPUT_NAMES': 'input1,input2:0',
        'OUTPUT_NAMES': 'output',
        'INPUT_FORMATS': 'channels_last',
    }
    actual = Config.from_env(environment)
    expected = Config(
        input_names=['input1', 'input2:0'],
        data_formats=[DataFormat.CHANNELS_LAST],
        output_names=['output'],
    )
    self.assertEqual(actual, expected)
def test_from_json(self):
    """Config.from_json maps the JSON lists onto a Config, converting input_formats to DataFormat."""
    document = {
        'input_names': ['input'],
        'output_names': ['output'],
        'input_formats': ['channels_first'],
    }
    actual = Config.from_json(document)
    expected = Config(
        input_names=['input'],
        data_formats=[DataFormat.CHANNELS_FIRST],
        output_names=['output'],
    )
    self.assertEqual(actual, expected)
def test_compile_with_no_output_name(self):
    """Compiling without 'output_names' yields the single output tensor 'z:0'."""
    with NamedTemporaryFile(suffix='.pb') as pb_file:
        _save_frozen_graph_model(pb_file)
        # NOTE(review): the compiler reopens pb_file.name while the handle is still open;
        # this presumably relies on the helper flushing and on a platform that allows a
        # second open of a NamedTemporaryFile — confirm.
        settings = Config.from_json({
            'input_names': ['x:0', 'y:0'],
            'input_formats': ['channels_first', 'channels_last'],
        })
        source = FrozenGraphFile(model_path=pb_file.name)
        compiled = compiler.compile_source(source, settings)
        output_names = [tensor.name for tensor in compiled.outputs]
        self.assertEqual(output_names, ['z:0'])
def test_compile_with_no_input_formats(self):
    """Omitting 'input_formats' leaves the compiled input's data_format as None."""
    with NamedTemporaryFile(suffix='.pb') as pb_file:
        _save_frozen_graph_model(pb_file)
        settings = Config.from_json({
            'input_names': ['x:0'],
            'output_names': ['z:0'],
        })
        source = FrozenGraphFile(model_path=pb_file.name)
        compiled = compiler.compile_source(source, settings)
        formats = [entry.data_format for entry in compiled.inputs]
        self.assertEqual(formats, [None])
def test_compile_with_variables(self):
    """A fully specified config compiles to a TF1 session with ordered inputs, formats and outputs."""
    with NamedTemporaryFile(suffix='.pb') as pb_file:
        _save_frozen_graph_model(pb_file)
        settings = Config.from_json({
            'input_names': ['x:0', 'y:0'],
            'output_names': ['z:0'],
            'input_formats': ['channels_first', 'channels_last'],
        })
        source = FrozenGraphFile(model_path=pb_file.name)
        compiled = compiler.compile_source(source, settings)

        self.assertIsInstance(compiled.session, tf.compat.v1.Session)

        # Inputs keep the order given in 'input_names'.
        input_names = [entry.tensor.name for entry in compiled.inputs]
        self.assertEqual(input_names, ['x:0', 'y:0'])

        # Each input carries the DataFormat parsed from the matching 'input_formats' entry.
        formats = [entry.data_format for entry in compiled.inputs]
        self.assertEqual(formats, [DataFormat.CHANNELS_FIRST, DataFormat.CHANNELS_LAST])

        output_names = [tensor.name for tensor in compiled.outputs]
        self.assertEqual(output_names, ['z:0'])
def test_from_env_no_names(self):
    """An empty environment gives a Config with no names and an empty data_formats list."""
    actual = Config.from_env({})
    expected = Config(input_names=None, data_formats=[], output_names=None)
    self.assertEqual(actual, expected)