def test_compile_with_all_params_with_shape(self):
    """Compile a SavedModel with explicit names and input shapes, NHWC->NCHW off.

    Both inputs are declared with shape ``[1, 2, 3, 4]``. The compiled model is
    expected to report float32 inputs ``x`` and ``y`` and a float32 output ``z``,
    each with dims ``[2, 3, 4]``. These expected dims omit the leading ``1`` of
    the declared shape.
    """
    float32_enum = tf.float32.as_datatype_enum

    with TemporaryDirectory() as model_dir:
        _save_saved_model_file(model_dir)

        compile_config = Config.from_json({
            'input_names': ['x', 'y'],
            'input_shapes': [[1, 2, 3, 4], [1, 2, 3, 4]],
            'output_names': ['z'],
            'enable_nhwc_to_nchw': False
        })
        result = compiler.compile_source(SavedModelFile(model_path=model_dir), compile_config)

        # One expected entry per declared input name, in declaration order.
        expected_inputs = [
            ModelInput(name=input_name,
                       data_type=float32_enum,
                       format=ModelInput.FORMAT_NONE,  # pylint: disable=no-member
                       dims=[2, 3, 4])
            for input_name in ('x', 'y')
        ]
        expected_outputs = [ModelOutput(name='z', data_type=float32_enum, dims=[2, 3, 4])]

        self.assertEqual(result.get_inputs(), expected_inputs)
        self.assertEqual(result.get_outputs(), expected_outputs)
def test_compile_with_two_saved_model_tags(self):
    """Compile a SavedModel saved with two tag sets, selecting ``['serve', 'graph2']``.

    With ``enable_nhwc_to_nchw`` on and ``max_batch_size`` set to 1, the compiled
    model is expected to report float32 inputs ``x`` and ``y`` and a float32
    output ``z``, each with dims ``[4, 2, 3]``.
    """
    float32_enum = tf.float32.as_datatype_enum

    with TemporaryDirectory() as model_dir:
        _save_saved_model_file_with_two_tags(model_dir)

        compile_config = Config.from_json({
            'input_names': ['x', 'y'],
            'output_names': ['z'],
            'enable_nhwc_to_nchw': True,
            'max_batch_size': 1,
            'saved_model_tags': ['serve', 'graph2']
        })
        result = compiler.compile_source(SavedModelFile(model_path=model_dir), compile_config)

        # One expected entry per declared input name, in declaration order.
        expected_inputs = [
            ModelInput(name=input_name,
                       data_type=float32_enum,
                       format=ModelInput.FORMAT_NONE,  # pylint: disable=no-member
                       dims=[4, 2, 3])
            for input_name in ('x', 'y')
        ]
        expected_outputs = [ModelOutput(name='z', data_type=float32_enum, dims=[4, 2, 3])]

        self.assertEqual(result.get_inputs(), expected_inputs)
        self.assertEqual(result.get_outputs(), expected_outputs)
def test_compile_with_fp16(self):
    """Compile a SavedModel with ``data_type`` set to ``'FP16'``.

    Only ``max_batch_size`` and ``data_type`` are configured. The compiled model
    is expected to report inputs ``x`` and ``y`` and output ``z`` as float16,
    each with dims ``[2, 3, 4]``.
    """
    float16_enum = tf.float16.as_datatype_enum

    with TemporaryDirectory() as model_dir:
        _save_saved_model_file(model_dir)

        compile_config = Config.from_json({
            'max_batch_size': 1,
            'data_type': 'FP16'
        })
        result = compiler.compile_source(SavedModelFile(model_path=model_dir), compile_config)

        # One expected entry per input name, in the order the compiler reports them.
        expected_inputs = [
            ModelInput(name=input_name,
                       data_type=float16_enum,
                       format=ModelInput.FORMAT_NONE,  # pylint: disable=no-member
                       dims=[2, 3, 4])
            for input_name in ('x', 'y')
        ]
        expected_outputs = [ModelOutput(name='z', data_type=float16_enum, dims=[2, 3, 4])]

        self.assertEqual(result.get_inputs(), expected_inputs)
        self.assertEqual(result.get_outputs(), expected_outputs)