コード例 #1
0
 def test_compile_with_all_params_with_shape(self):
     """Compile a SavedModel with names, shapes and layout given explicitly.

     The config supplies input/output names and a ``[1, 2, 3, 4]`` shape for
     each input, with NHWC-to-NCHW conversion turned off. The compiled model
     must report float32 inputs ``x`` and ``y`` and output ``z``, each with
     ``dims`` of ``[2, 3, 4]``, i.e. without the leading dimension of the
     configured shape (presumably the batch axis -- confirm).
     """
     with TemporaryDirectory() as model_dir:
         _save_saved_model_file(model_dir)
         config = Config.from_json({
             'input_names': ['x', 'y'],
             'input_shapes': [[1, 2, 3, 4], [1, 2, 3, 4]],
             'output_names': ['z'],
             'enable_nhwc_to_nchw': False,
         })
         source = SavedModelFile(model_path=model_dir)
         compiled = compiler.compile_source(source, config)

         # Protobuf/TF enum members are generated at runtime, hence no-member.
         fp32_enum = tf.float32.as_datatype_enum  # pylint: disable=no-member
         no_format = ModelInput.FORMAT_NONE  # pylint: disable=no-member

         # Both inputs share dtype, format and dims; only the name differs.
         self.assertEqual(compiled.get_inputs(), [
             ModelInput(name=input_name,
                        data_type=fp32_enum,
                        format=no_format,
                        dims=[2, 3, 4])
             for input_name in ('x', 'y')
         ])
         self.assertEqual(compiled.get_outputs(), [
             ModelOutput(name='z', data_type=fp32_enum, dims=[2, 3, 4])
         ])
コード例 #2
0
 def test_compile_with_two_saved_model_tags(self):
     """Compile a SavedModel by selecting an explicit pair of tags.

     The model is written by ``_save_saved_model_file_with_two_tags`` and the
     config picks it with ``saved_model_tags=['serve', 'graph2']``. NHWC-to-NCHW
     conversion is enabled and ``max_batch_size`` is 1. Inputs ``x``/``y``
     and output ``z`` must all be float32 with ``dims`` of ``[4, 2, 3]``.
     """
     with TemporaryDirectory() as model_dir:
         _save_saved_model_file_with_two_tags(model_dir)
         config = Config.from_json({
             'input_names': ['x', 'y'],
             'output_names': ['z'],
             'enable_nhwc_to_nchw': True,
             'max_batch_size': 1,
             'saved_model_tags': ['serve', 'graph2'],
         })
         source = SavedModelFile(model_path=model_dir)
         compiled = compiler.compile_source(source, config)

         # Protobuf/TF enum members are generated at runtime, hence no-member.
         fp32_enum = tf.float32.as_datatype_enum  # pylint: disable=no-member
         no_format = ModelInput.FORMAT_NONE  # pylint: disable=no-member

         # NOTE(review): [4, 2, 3] presumably reflects the NHWC->NCHW layout
         # change applied to the saved tensors; confirm against the helper.
         self.assertEqual(compiled.get_inputs(), [
             ModelInput(name=input_name,
                        data_type=fp32_enum,
                        format=no_format,
                        dims=[4, 2, 3])
             for input_name in ('x', 'y')
         ])
         self.assertEqual(compiled.get_outputs(), [
             ModelOutput(name='z', data_type=fp32_enum, dims=[4, 2, 3])
         ])
コード例 #3
0
 def test_compile_with_fp16(self):
     """Compile a SavedModel with ``data_type='FP16'``.

     No input or output names are configured, so the test expects the
     compiler to discover ``x``, ``y`` and ``z`` from the saved model itself.
     Every tensor must be reported as float16 with ``dims`` of ``[2, 3, 4]``.
     """
     with TemporaryDirectory() as model_dir:
         _save_saved_model_file(model_dir)
         config = Config.from_json({'max_batch_size': 1, 'data_type': 'FP16'})
         source = SavedModelFile(model_path=model_dir)
         compiled = compiler.compile_source(source, config)

         # Protobuf/TF enum members are generated at runtime, hence no-member.
         fp16_enum = tf.float16.as_datatype_enum  # pylint: disable=no-member
         no_format = ModelInput.FORMAT_NONE  # pylint: disable=no-member

         # The FP16 setting must be reflected on inputs and outputs alike.
         self.assertEqual(compiled.get_inputs(), [
             ModelInput(name=input_name,
                        data_type=fp16_enum,
                        format=no_format,
                        dims=[2, 3, 4])
             for input_name in ('x', 'y')
         ])
         self.assertEqual(compiled.get_outputs(), [
             ModelOutput(name='z', data_type=fp16_enum, dims=[2, 3, 4])
         ])