def test_from_json_no_names(self):
    """Check ``Config.from_json`` when input names and formats are ``None``.

    With both ``input_names`` and ``input_formats`` set to ``None`` the
    resulting config is expected to equal one built with ``input_info=None``.
    """
    # NOTE(review): another method named ``test_from_json_no_names`` appears
    # later in this file. If both belong to the same TestCase, the later
    # definition replaces this one and this test never runs -- confirm.
    payload = {
        'input_names': None,
        'input_formats': None,
        'output_names': ['output'],
        'max_batch_size': 1,
    }

    actual = Config.from_json(payload)
    expected = Config(input_info=None, output_names=['output'], max_batch_size=1)

    self.assertEqual(actual, expected)
def test_from_json(self):
    """Check that ``Config.from_json`` pairs each input name with its format.

    The ``'channels_first'`` format string is expected to come back as
    ``DataFormat.CHANNELS_FIRST`` next to the input name in ``input_info``.
    """
    payload = {
        'input_names': ['input'],
        'input_formats': ['channels_first'],
        'output_names': ['output'],
        'max_batch_size': 1,
    }

    actual = Config.from_json(payload)
    expected = Config(
        input_info=[('input', DataFormat.CHANNELS_FIRST)],
        output_names=['output'],
        max_batch_size=1,
    )

    self.assertEqual(actual, expected)
def test_from_env(self):
    """Check that ``Config.from_env`` parses comma-separated environment values.

    Two names (one of them carrying a ``:0`` suffix) and two formats are
    supplied as comma-separated strings; the expected config holds them as
    ``(name, DataFormat)`` pairs and ``MAX_BATCH_SIZE`` as the integer ``1``.
    """
    environment = {
        'INPUT_NAMES': 'input1,input2:0',
        'OUTPUT_NAMES': 'output',
        'INPUT_FORMATS': 'channels_first,channels_first',
        'MAX_BATCH_SIZE': '1',
    }

    actual = Config.from_env(environment)
    expected = Config(
        input_info=[
            ('input1', DataFormat.CHANNELS_FIRST),
            ('input2:0', DataFormat.CHANNELS_FIRST),
        ],
        output_names=['output'],
        max_batch_size=1,
    )

    self.assertEqual(actual, expected)
def test_from_env_all_names(self):
    """Check ``Config.from_env`` when every supported variable is provided.

    ``INPUT_SHAPES`` is given as the string ``'[1, 2, 3, 4]'`` and is expected
    to become ``[[1, 2, 3, 4]]``; ``ENABLE_NHWC_TO_NCHW`` is given as ``'0'``
    and is expected to become ``False``.
    """
    environment = {
        'INPUT_NAMES': 'input',
        'INPUT_SHAPES': '[1, 2, 3, 4]',
        'OUTPUT_NAMES': 'output',
        'MAX_BATCH_SIZE': '1',
        'ENABLE_NHWC_TO_NCHW': '0',
    }

    actual = Config.from_env(environment)
    expected = Config(
        input_names=['input'],
        input_shapes=[[1, 2, 3, 4]],
        output_names=['output'],
        max_batch_size=1,
        enable_nhwc_to_nchw=False,
    )

    self.assertEqual(actual, expected)
def test_from_json_all_names(self):
    """Check ``Config.from_json`` when every supported key is provided.

    Each JSON key is expected to map onto the ``Config`` keyword argument of
    the same name with its value unchanged.
    """
    payload = {
        'input_names': ['input'],
        'input_shapes': [[1, 2, 3, 4]],
        'output_names': ['output'],
        'max_batch_size': 1,
        'enable_nhwc_to_nchw': False,
    }

    actual = Config.from_json(payload)
    expected = Config(
        input_names=['input'],
        input_shapes=[[1, 2, 3, 4]],
        output_names=['output'],
        max_batch_size=1,
        enable_nhwc_to_nchw=False,
    )

    self.assertEqual(actual, expected)
def test_compile_with_fp16(self):
    """Compile a frozen graph with ``data_type`` set to ``'FP16'``.

    The compiled model is expected to report inputs ``x`` and ``y`` and
    output ``z``, all typed as float16 with dims ``[2, 3, 4]``, and both
    inputs using ``ModelInput.FORMAT_NONE``.
    """
    half_precision = tf.float16.as_datatype_enum
    no_format = ModelInput.FORMAT_NONE  # pylint: disable=no-member
    dims = [2, 3, 4]

    with NamedTemporaryFile(suffix='.pb') as model_file:
        _save_frozen_graph_model(model_file)

        settings = {'max_batch_size': 1, 'data_type': 'FP16'}
        config = Config.from_json(settings)
        source = FrozenGraphFile(model_path=model_file.name)
        compiled = compiler.compile_source(source, config)

        expected_inputs = [
            ModelInput(name=tensor_name, data_type=half_precision, format=no_format, dims=dims)
            for tensor_name in ('x', 'y')
        ]
        self.assertEqual(compiled.get_inputs(), expected_inputs)

        expected_outputs = [ModelOutput(name='z', data_type=half_precision, dims=dims)]
        self.assertEqual(compiled.get_outputs(), expected_outputs)
def test_compile_with_all_params_with_enable_nhwc_to_nchw_true(self):
    """Compile a frozen graph with ``enable_nhwc_to_nchw`` switched on.

    Input and output names are given explicitly.  The compiled model is
    expected to report float32 tensors ``x``, ``y`` and ``z`` with dims
    ``[4, 2, 3]``, and both inputs using ``ModelInput.FORMAT_NONE``.
    """
    single_precision = tf.float32.as_datatype_enum
    no_format = ModelInput.FORMAT_NONE  # pylint: disable=no-member
    dims = [4, 2, 3]

    with NamedTemporaryFile(suffix='.pb') as model_file:
        _save_frozen_graph_model(model_file)

        settings = {
            'input_names': ['x', 'y'],
            'output_names': ['z'],
            'enable_nhwc_to_nchw': True,
            'max_batch_size': 1,
        }
        config = Config.from_json(settings)
        source = FrozenGraphFile(model_path=model_file.name)
        compiled = compiler.compile_source(source, config)

        expected_inputs = [
            ModelInput(name=tensor_name, data_type=single_precision, format=no_format, dims=dims)
            for tensor_name in ('x', 'y')
        ]
        self.assertEqual(compiled.get_inputs(), expected_inputs)

        expected_outputs = [ModelOutput(name='z', data_type=single_precision, dims=dims)]
        self.assertEqual(compiled.get_outputs(), expected_outputs)
def test_compile_with_no_input_name(self):
    """Compile a frozen graph whose config carries no input names.

    With ``input_names`` set to ``None`` the compiled model is expected to
    expose ``inputs`` as ``None``.
    """
    with NamedTemporaryFile(suffix='.pb') as model_file:
        _save_frozen_graph_model(model_file)

        config = Config.from_json({
            'input_names': None,
            'output_names': ['z'],
            'max_batch_size': 1,
        })
        compiled = compiler.compile_source(FrozenGraphFile(model_path=model_file.name), config)

        # assertIsNone checks identity, so an object that merely compares
        # equal to None cannot satisfy it, and failures name the actual value.
        self.assertIsNone(compiled.inputs)
def test_compile_with_no_input_formats(self):
    """Compile a frozen graph whose config omits ``input_formats``.

    Both compiled inputs are expected to use ``ModelInput.FORMAT_NONE`` and
    to keep the configured names ``x`` and ``y`` in order.
    """
    with NamedTemporaryFile(suffix='.pb') as model_file:
        _save_frozen_graph_model(model_file)

        settings = {
            'input_names': ['x', 'y'],
            'output_names': ['z'],
            'max_batch_size': 4,
        }
        config = Config.from_json(settings)
        source = FrozenGraphFile(model_path=model_file.name)
        compiled = compiler.compile_source(source, config)

        actual_formats = [entry.format for entry in compiled.inputs]
        expected_formats = [ModelInput.FORMAT_NONE] * 2  # pylint: disable=no-member
        self.assertEqual(actual_formats, expected_formats)

        actual_names = [entry.name for entry in compiled.inputs]
        self.assertEqual(actual_names, ['x', 'y'])
def test_from_env_no_names(self):
    """Check ``Config.from_env`` when no input variables are set.

    Only ``OUTPUT_NAMES`` and ``MAX_BATCH_SIZE`` are supplied; the resulting
    config is expected to equal one built with ``input_info=None``.
    """
    # NOTE(review): another method named ``test_from_env_no_names`` appears
    # later in this file. If both belong to the same TestCase, the later
    # definition replaces this one and this test never runs -- confirm.
    environment = {'OUTPUT_NAMES': 'output', 'MAX_BATCH_SIZE': '1'}

    actual = Config.from_env(environment)
    expected = Config(input_info=None, output_names=['output'], max_batch_size=1)

    self.assertEqual(actual, expected)
def test_from_env_no_names(self):
    """Check that an empty environment yields a default-constructed ``Config``."""
    # NOTE(review): an earlier method in this file has the same name. If both
    # belong to the same TestCase, this definition replaces the earlier one,
    # which then never runs -- confirm.
    empty_environment = {}

    actual = Config.from_env(empty_environment)
    expected = Config()

    self.assertEqual(actual, expected)
def test_from_json_no_names(self):
    """Check that an empty JSON object yields a default-constructed ``Config``."""
    # NOTE(review): an earlier method in this file has the same name. If both
    # belong to the same TestCase, this definition replaces the earlier one,
    # which then never runs -- confirm.
    empty_payload = {}

    actual = Config.from_json(empty_payload)
    expected = Config()

    self.assertEqual(actual, expected)