def test_graph_models_to_saved_model(self):
    """graph_models_to_saved_model should accept model list

    Both models are exported into one SavedModel directory under distinct
    tag sets; each must be loadable by its tags and expose exactly one
    valid signature.
    """
    first_model = testutils.get_path_to(testutils.SIMPLE_MODEL_PATH_NAME)
    second_model = testutils.get_path_to(testutils.PRELU_MODEL_PATH)
    first_tags = [tf.saved_model.SERVING, 'model_1']
    second_tags = [tf.saved_model.SERVING, 'model_2']
    export_dir = tempfile.mkdtemp(suffix='.saved_model')
    try:
        models = [(first_model, first_tags), (second_model, second_tags)]
        api.graph_models_to_saved_model(models, export_dir)
        self.assertTrue(os.path.exists(export_dir))
        # every exported model is looked up by its own tag set
        for tags in (first_tags, second_tags):
            meta_graph = load_meta_graph(export_dir, tags)
            self.assertIsNotNone(meta_graph)
            saved_signatures = meta_graph.signature_def
            # a single signature is expected per model ...
            self.assertEqual(len(saved_signatures), 1)
            (only_signature,) = saved_signatures.values()
            # ... and it has to be well-formed
            self.assertTrue(is_valid_signature(only_signature))
    finally:
        if os.path.exists(export_dir):
            shutil.rmtree(export_dir)
Example #2
0
 def test_graph_models_to_saved_model_accepts_signatures(self):
     """graph_models_to_saved_model should accept signatures

     The signatures dict is keyed by model path. The 'ignore_this' entry
     matches no model in the list, so model 1 is expected to keep a single
     inferred signature, while the entry keyed by model_dir_2 limits that
     model's signature to the 'Identity' output.
     """
     model_dir_1 = testutils.get_path_to(testutils.SIMPLE_MODEL_PATH_NAME)
     model_dir_2 = testutils.get_path_to(testutils.MULTI_HEAD_PATH)
     tags_1 = [tf.saved_model.SERVING, 'model_1']
     tags_2 = [tf.saved_model.SERVING, 'model_2']
     export_dir = tempfile.mkdtemp(suffix='.saved_model')
     try:
         model_list = [(model_dir_1, tags_1), (model_dir_2, tags_2)]
         signatures = {
             'ignore_this': {'': {api.SIGNATURE_OUTPUTS: ['y']}},
             model_dir_2: {'': {api.SIGNATURE_OUTPUTS: ['Identity']}}
         }
         api.graph_models_to_saved_model(model_list, export_dir, signatures)
         self.assertTrue(os.path.exists(export_dir))
         # model 1 has no matching signature entry: it must still be saved
         # with one valid signature ('ignore_this' must not affect it)
         meta_graph_def = load_meta_graph(export_dir, tags_1)
         self.assertIsNotNone(meta_graph_def)
         self.assertEqual(len(meta_graph_def.signature_def), 1)
         self.assertTrue(
             is_valid_signature(
                 list(meta_graph_def.signature_def.values())[0]))
         # try to load model 2; fail cleanly rather than with an
         # AttributeError if nothing was found for its tags
         meta_graph_def = load_meta_graph(export_dir, tags_2)
         self.assertIsNotNone(meta_graph_def)
         # we want a signature to be present
         self.assertEqual(len(meta_graph_def.signature_def), 1)
         signature = list(meta_graph_def.signature_def.values())[0]
         # the signature should be valid
         self.assertTrue(is_valid_signature(signature))
         # the signature should have a single output, as requested
         self.assertEqual(len(signature.outputs), 1)
         self.assertIn('Identity', signature.outputs.keys())
     finally:
         if os.path.exists(export_dir):
             shutil.rmtree(export_dir)
Example #3
0
 def test_graph_model_to_saved_model_accepts_signature_map(self):
     """graph_model_to_saved_model should accept signature map

     The '' entry should end up under the default serving signature key
     with one output; the 'debug' entry should be kept with two outputs.
     """
     model_dir = testutils.get_path_to(testutils.MULTI_HEAD_PATH)
     export_dir = tempfile.mkdtemp(suffix='.saved_model')
     try:
         tags = [tf.saved_model.SERVING]
         signature_map = {
             '': {api.SIGNATURE_OUTPUTS: ['Identity']},
             'debug': {api.SIGNATURE_OUTPUTS: ['Identity', 'Identity_1']},
         }
         api.graph_model_to_saved_model(
             model_dir, export_dir, tags=tags, signature_def_map=signature_map)
         meta_graph_def = load_meta_graph(export_dir, tags)
         self.assertIsNotNone(meta_graph_def)
         saved_signatures = meta_graph_def.signature_def
         # one saved signature per map entry, each well-formed
         self.assertEqual(len(saved_signatures), 2)
         for sig in saved_signatures.values():
             self.assertTrue(is_valid_signature(sig))
         default_key = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
         self.assertEqual(len(saved_signatures[default_key].outputs), 1)
         # the custom key is kept verbatim and carries both outputs
         self.assertIn('debug', saved_signatures.keys())
         self.assertEqual(len(saved_signatures['debug'].outputs), 2)
     finally:
         if os.path.exists(export_dir):
             shutil.rmtree(export_dir)
Example #4
0
    def test_build_signatures(self):
        """_build_signatures should apply given key and include inputs"""
        model_file = testutils.get_path_to(testutils.MULTI_HEAD_FILE)
        graph = testutils.get_sample_graph(model_file)
        debug_key = 'debug_model'
        signature_map = {
            '': {api.SIGNATURE_OUTPUTS: ['Identity']},
            debug_key: {api.SIGNATURE_OUTPUTS: ['Identity', 'Identity_1']},
        }

        built = api._build_signatures(graph, signature_map)
        # '' is expected under the default serving key; the custom key
        # should appear unchanged
        expected_keys = (tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
                         debug_key)
        for expected_key in expected_keys:
            self.assertIn(expected_key, built)
        for signature_def in built.values():
            self.assertTrue(is_valid_signature(signature_def))
 def test_load_graph_model_and_signature_from_meta_data(self):
     """load_graph_model_and_signature should extract signature def"""
     model_path = testutils.get_path_to(testutils.PRELU_MODEL_PATH)
     _, signature_def = api.load_graph_model_and_signature(model_path)
     self.assertIsInstance(signature_def, util.SignatureDef)
     self.assertTrue(is_valid_signature(signature_def))

     def expect_single_tensor(tensors, key_name, tensor_name, shape):
         # the mapping must hold exactly one float32 entry with the given
         # key, tensor name and shape
         self.assertEqual(len(tensors), 1)
         key, info = list(tensors.items())[0]
         self.assertEqual(key, key_name)
         self.assertEqual(info.name, tensor_name)
         self.assertEqual(info.dtype, tf.dtypes.float32)
         self.assertEqual(_shape_of(info), shape)

     expect_single_tensor(signature_def.inputs,
                          'input_vector', 'input_vector:0', (-1, 7))
     expect_single_tensor(signature_def.outputs,
                          'Identity', 'Identity:0', (-1, 1))
 def test_load_graph_model_and_signature_from_tree(self):
     """load_graph_model_and_signature should infer signature def
        from graph if signature def is incomplete
     """
     model_path = testutils.get_path_to(testutils.SIMPLE_MODEL_PATH_NAME)
     _, signature_def = api.load_graph_model_and_signature(model_path)
     # the simple model's stored signature lacks inputs, so they have to
     # be inferred from the graph itself
     self.assertIsInstance(signature_def, util.SignatureDef)
     self.assertTrue(is_valid_signature(signature_def))

     def expect_single_tensor(tensors, key_name, tensor_name):
         # exactly one float32 entry of shape (-1, 1) with the given key
         # and tensor name
         self.assertEqual(len(tensors), 1)
         key, info = list(tensors.items())[0]
         self.assertEqual(key, key_name)
         self.assertEqual(info.name, tensor_name)
         self.assertEqual(info.dtype, tf.dtypes.float32)
         self.assertEqual(_shape_of(info), (-1, 1))

     expect_single_tensor(signature_def.inputs, 'x', 'x:0')
     expect_single_tensor(signature_def.outputs, 'Identity', 'Identity:0')
 def test_graph_model_to_saved_model(self):
     """graph_model_to_saved_model should save valid SavedModel"""
     source_dir = testutils.get_path_to(testutils.PRELU_MODEL_PATH)
     export_dir = tempfile.mkdtemp(suffix='.saved_model')
     try:
         serving_tags = [tf.saved_model.SERVING]
         api.graph_model_to_saved_model(source_dir, export_dir,
                                        tags=serving_tags)
         # the export directory must exist and be recognised by TF
         self.assertTrue(os.path.exists(export_dir))
         self.assertTrue(tf.saved_model.contains_saved_model(export_dir))
         meta_graph_def = load_meta_graph(export_dir, serving_tags)
         self.assertIsNotNone(meta_graph_def)
         saved = meta_graph_def.signature_def
         # exactly one signature, and it must be well-formed
         self.assertEqual(len(saved), 1)
         first_signature = next(iter(saved.values()))
         self.assertTrue(is_valid_signature(first_signature))
     finally:
         if os.path.exists(export_dir):
             shutil.rmtree(export_dir)