Ejemplo n.º 1
0
 def test_graph_models_to_saved_model_accepts_signatures(self):
     """graph_models_to_saved_model should accept signatures"""
     # Only the second (multi-head) model gets an explicit output selection.
     # The 'ignore_this' entry does not name any model in the list and is
     # expected to be ignored by the exporter.
     first_model = testutils.get_path_to(testutils.SIMPLE_MODEL_PATH_NAME)
     second_model = testutils.get_path_to(testutils.MULTI_HEAD_PATH)
     first_tags = [tf.saved_model.SERVING, 'model_1']
     second_tags = [tf.saved_model.SERVING, 'model_2']
     export_dir = tempfile.mkdtemp(suffix='.saved_model')
     try:
         models = [(first_model, first_tags), (second_model, second_tags)]
         signature_spec = {
             'ignore_this': {'': {api.SIGNATURE_OUTPUTS: ['y']}},
             second_model: {'': {api.SIGNATURE_OUTPUTS: ['Identity']}},
         }
         api.graph_models_to_saved_model(models, export_dir, signature_spec)
         self.assertTrue(os.path.exists(export_dir))

         # Load the meta graph exported for the second model; it must carry
         # exactly one signature, so taking the first one is well defined.
         meta_graph = load_meta_graph(export_dir, second_tags)
         signature_map = meta_graph.signature_def
         self.assertEqual(len(signature_map), 1)
         signature = next(iter(signature_map.values()))
         self.assertTrue(is_valid_signature(signature))

         # The requested 'Identity' tensor must be the one and only output.
         self.assertEqual(len(signature.outputs), 1)
         self.assertIn('Identity', signature.outputs)
     finally:
         if os.path.exists(export_dir):
             shutil.rmtree(export_dir)
Ejemplo n.º 2
0
 def test_graph_models_to_saved_model(self):
     """graph_models_to_saved_model should accept model list"""
     simple_model = testutils.get_path_to(testutils.SIMPLE_MODEL_PATH_NAME)
     prelu_model = testutils.get_path_to(testutils.PRELU_MODEL_PATH)
     simple_tags = [tf.saved_model.SERVING, 'model_1']
     prelu_tags = [tf.saved_model.SERVING, 'model_2']
     export_dir = tempfile.mkdtemp(suffix='.saved_model')
     try:
         models = [(simple_model, simple_tags), (prelu_model, prelu_tags)]
         api.graph_models_to_saved_model(models, export_dir)
         self.assertTrue(os.path.exists(export_dir))
         # No signatures were requested, yet each model must still be
         # loadable via its own tag set and expose exactly one valid
         # signature. Models are checked in list order (model 1, then 2).
         for tags in (simple_tags, prelu_tags):
             meta_graph = load_meta_graph(export_dir, tags)
             self.assertIsNotNone(meta_graph)
             signature_map = meta_graph.signature_def
             self.assertEqual(len(signature_map), 1)
             only_signature = next(iter(signature_map.values()))
             self.assertTrue(is_valid_signature(only_signature))
     finally:
         if os.path.exists(export_dir):
             shutil.rmtree(export_dir)
Ejemplo n.º 3
0
 def test_graph_models_to_saved_model_accepts_signature_keys(self):
     """graph_models_to_saved_model should accept signature keys"""
     model_dir_1 = testutils.get_path_to(testutils.SIMPLE_MODEL_PATH_NAME)
     model_dir_2 = testutils.get_path_to(testutils.MULTI_HEAD_PATH)
     tags_1 = [tf.saved_model.SERVING, 'model_1']
     tags_2 = [tf.saved_model.SERVING, 'model_2']
     export_dir = tempfile.mkdtemp(suffix='.saved_model')
     try:
         model_list = [(model_dir_1, tags_1), (model_dir_2, tags_2)]
         # 'ignore_this' matches no model in model_list; only model 2 gets
         # an explicit output selection.
         signatures = {
             'ignore_this': {
                 '': {
                     api.SIGNATURE_OUTPUTS: ['y']
                 }
             },
             model_dir_2: {
                 '': {
                     api.SIGNATURE_OUTPUTS: ['Identity']
                 }
             }
         }
         # Per-model maps from tensor names to the keys under which those
         # tensors should appear in the exported signature.
         signature_keys = {
             model_dir_1: api.RenameMap({
                 'x': 'input',
                 'Identity': 'output'
             }),
             model_dir_2: api.RenameMap({'Identity': 'scores'})
         }
         api.graph_models_to_saved_model(model_list, export_dir, signatures,
                                         signature_keys)
         self.assertTrue(os.path.exists(export_dir))
         # check the signatures of model 1: the renamed keys must still
         # point at the original tensors
         meta_graph_def = load_meta_graph(export_dir, tags_1)
         # require exactly one signature so that picking the first one is
         # not an arbitrary choice (and an empty map fails clearly instead
         # of raising IndexError)
         self.assertEqual(len(meta_graph_def.signature_def), 1)
         signature = list(meta_graph_def.signature_def.values())[0]
         self.assertIn('input', signature.inputs)
         self.assertEqual(signature.inputs['input'].name, 'x:0')
         self.assertIn('output', signature.outputs)
         self.assertEqual(signature.outputs['output'].name, 'Identity:0')
         # check the signatures of model 2
         meta_graph_def = load_meta_graph(export_dir, tags_2)
         self.assertEqual(len(meta_graph_def.signature_def), 1)
         signature = list(meta_graph_def.signature_def.values())[0]
         self.assertIn('scores', signature.outputs)
         self.assertEqual(signature.outputs['scores'].name, 'Identity:0')
     finally:
         if os.path.exists(export_dir):
             shutil.rmtree(export_dir)
Ejemplo n.º 4
0
 def test_graph_models_to_saved_model(self):
     """graph_models_to_saved_model should accept model list"""
     model_dir_1 = testutils.get_path_to(testutils.SIMPLE_MODEL_PATH_NAME)
     model_dir_2 = testutils.get_path_to(testutils.PRELU_MODEL_PATH)
     # Meta-graph tags. tf.saved_model.SERVING ('serve') is the standard
     # serving tag; 'serving_default' is the default *signature* key, not a
     # tag, so it is not used here.
     tags_1 = [tf.saved_model.SERVING, 'model_1']
     tags_2 = [tf.saved_model.SERVING, 'model_2']
     export_dir = tempfile.mkdtemp(suffix='.saved_model')
     try:
         api.graph_models_to_saved_model([(model_dir_1, tags_1),
                                          (model_dir_2, tags_2)],
                                         export_dir)
         self.assertTrue(os.path.exists(export_dir))
         # each model must be loadable through its own tag set
         # (model 1 first, then model 2)
         for tags in (tags_1, tags_2):
             model = tf.saved_model.load(export_dir, tags=tags)
             self.assertIsNotNone(model.graph)
     finally:
         if os.path.exists(export_dir):
             shutil.rmtree(export_dir)