def test_build_signatures_applies_defaults(self):
    """A None signature key and an empty method name both fall back to defaults.

    The first call omits the method entirely and checks two things:
    ``api._build_signatures`` stores the signature under
    ``tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY``, and its method name
    is ``tf.saved_model.PREDICT_METHOD_NAME``. The second call passes an
    explicit empty-string method name and checks that it resolves to the same
    default method name.
    """
    sample_graph = testutils.get_sample_graph()
    expected_key = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    expected_method = tf.saved_model.PREDICT_METHOD_NAME

    # No method supplied at all.
    built = api._build_signatures(
        sample_graph, {None: {api.SIGNATURE_OUTPUTS: ['Identity']}})
    self.assertIn(expected_key, built)
    self.assertEqual(built[expected_key].method_name, expected_method)

    # An explicitly empty method name is treated like a missing one.
    built = api._build_signatures(
        sample_graph,
        {None: {api.SIGNATURE_OUTPUTS: ['Identity'],
                api.SIGNATURE_METHOD: ''}})
    self.assertEqual(built[expected_key].method_name, expected_method)
def test_build_signatures(self):
    """_build_signatures should apply given key and include inputs.

    Builds signatures for the multi-head sample graph in two ways:
    - under the empty key, which the test expects to appear under
      ``tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY``;
    - under an explicit custom key with two outputs.

    Every resulting SignatureDef must pass ``is_valid_signature``.
    """
    graph = testutils.get_sample_graph(
        testutils.get_path_to(testutils.MULTI_HEAD_FILE))
    default_key = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    debug_key = 'debug_model'
    signature_map = {
        '': {api.SIGNATURE_OUTPUTS: ['Identity']},
        debug_key: {api.SIGNATURE_OUTPUTS: ['Identity', 'Identity_1']}}
    signature_def_map = api._build_signatures(graph, signature_map)
    self.assertIn(default_key, signature_def_map)
    self.assertIn(debug_key, signature_def_map)
    for key, signature_def in signature_def_map.items():
        # Report which signature failed so the assertion is actionable.
        self.assertTrue(is_valid_signature(signature_def),
                        msg='invalid signature for key {!r}'.format(key))
def test_build_signatures_verifies_outputs(self):
    """_build_signatures should not accept invalid output names.

    Requesting the output ``'Not_A_Tensor'``, which is intended not to name
    any tensor in the sample graph, must raise ValueError.
    """
    graph = testutils.get_sample_graph()
    signature_map = {'': {api.SIGNATURE_OUTPUTS: ['Not_A_Tensor']}}
    # Context-manager form of assertRaises: no lambda wrapper needed.
    with self.assertRaises(ValueError):
        api._build_signatures(graph, signature_map)