def test_rename_map_ctor_verifies_input_format(self):
     """RenameMap constructor rejects malformed input with ValueError"""
     malformed_inputs = [
         # not a dictionary at all
         5,
         # a key that is not a str
         {'valid': 'ok', True: 'that cannot work'},
         # a value that is not a str
         {'valid': 'ok', 'no-go': ['invalid']},
         # an empty key
         {'valid': 'ok', '': 'invalid'},
         # a value consisting only of whitespace
         {'valid': 'ok', 'invalid': '\n  \r\t'},
     ]
     for malformed in malformed_inputs:
         # construction happens inside the guarded region, exactly as the
         # lambda-based assertRaises form would call it
         with self.assertRaises(ValueError, msg=repr(malformed)):
             api.RenameMap(malformed)
 def test_rename_map_ctor_accepts_str_to_str_dict(self):
     """RenameMap can be built from a str->str dict or from (old, new) pairs"""
     from_dict = api.RenameMap({'Identity': 'stylised_image'})
     self.assertEqual(from_dict.mapping['Identity'], 'stylised_image')
     # the same constructor also takes an iterable of (old, new) tuples
     pairs = [('Identity', 'stylised_image'), ('input', 'original_image')]
     from_pairs = api.RenameMap(pairs)
     for old_name, new_name in pairs:
         self.assertEqual(from_pairs.mapping[old_name], new_name)
Example #3
0
 def test_graph_model_to_saved_model_accepts_signature_key_map(self):
     """graph_model_to_saved_model renames signature keys via a RenameMap"""
     model_dir = testutils.get_path_to(testutils.MULTI_HEAD_PATH)
     export_dir = tempfile.mkdtemp(suffix='.saved_model')
     try:
         tags = [tf.saved_model.SERVING]
         signature_def_map = {
             '': {api.SIGNATURE_OUTPUTS: ['Identity']},
             'debug': {api.SIGNATURE_OUTPUTS: ['Identity', 'Identity_1']},
         }
         signature_key_map = api.RenameMap([
             ('Identity', 'output'),
             ('Identity_1', 'autoencoder_output'),
         ])
         api.graph_model_to_saved_model(model_dir, export_dir,
                                        tags=tags,
                                        signature_def_map=signature_def_map,
                                        signature_key_map=signature_key_map)
         # reload the exported model so the written signatures can be checked
         signatures = load_meta_graph(export_dir, tags).signature_def
         # 'Identity' appears in every signature, so it must be renamed in all
         for signature in signatures.values():
             self.assertIn('output', signature.outputs)
             self.assertEqual(signature.outputs['output'].name, 'Identity:0')
         # 'Identity_1' is only part of the 'debug' signature
         debug_outputs = signatures['debug'].outputs
         self.assertIn('autoencoder_output', debug_outputs)
         self.assertEqual(debug_outputs['autoencoder_output'].name,
                          'Identity_1:0')
     finally:
         # mkdtemp leaves the directory behind; remove it explicitly
         if os.path.exists(export_dir):
             shutil.rmtree(export_dir)
 def test_graph_models_to_saved_model_accepts_signature_keys(self):
     """graph_models_to_saved_model applies a RenameMap per model path"""
     model_dir_1 = testutils.get_path_to(testutils.SIMPLE_MODEL_PATH_NAME)
     model_dir_2 = testutils.get_path_to(testutils.MULTI_HEAD_PATH)
     tags_1 = [tf.saved_model.SERVING, 'model_1']
     tags_2 = [tf.saved_model.SERVING, 'model_2']
     export_dir = tempfile.mkdtemp(suffix='.saved_model')
     try:
         model_list = [(model_dir_1, tags_1), (model_dir_2, tags_2)]
         # 'ignore_this' is not one of the exported model paths (presumably
         # it should be ignored); model_dir_1 has no explicit signature entry
         signatures = {
             'ignore_this': {'': {api.SIGNATURE_OUTPUTS: ['y']}},
             model_dir_2: {'': {api.SIGNATURE_OUTPUTS: ['Identity']}},
         }
         signature_keys = {
             model_dir_1: api.RenameMap({'x': 'input', 'Identity': 'output'}),
             model_dir_2: api.RenameMap({'Identity': 'scores'}),
         }
         api.graph_models_to_saved_model(model_list, export_dir, signatures,
                                         signature_keys)
         self.assertTrue(os.path.exists(export_dir))

         def first_signature(tags):
             # load the meta graph stored under `tags`, take its first signature
             meta_graph = load_meta_graph(export_dir, tags)
             return list(meta_graph.signature_def.values())[0]

         # model 1: both the input and the output key were renamed
         signature = first_signature(tags_1)
         self.assertIn('input', signature.inputs)
         self.assertEqual(signature.inputs['input'].name, 'x:0')
         self.assertIn('output', signature.outputs)
         self.assertEqual(signature.outputs['output'].name, 'Identity:0')
         # model 2: only the output key was renamed
         signature = first_signature(tags_2)
         self.assertIn('scores', signature.outputs)
         self.assertEqual(signature.outputs['scores'].name, 'Identity:0')
     finally:
         # mkdtemp leaves the directory behind; remove it explicitly
         if os.path.exists(export_dir):
             shutil.rmtree(export_dir)
 def test_rename_map_apply(self):
     """RenameMap.apply renames signature keys but keeps tensor names"""
     model_path = testutils.get_path_to(testutils.SIMPLE_MODEL_PATH_NAME)
     _, signature_def = api.load_graph_model_and_signature(model_path)
     renamed = api.RenameMap({'x': 'input', 'Identity': 'output'}).apply(
         signature_def)
     # input key 'x' becomes 'input' but still refers to tensor 'x:0'
     self.assertNotIn('x', renamed.inputs)
     self.assertIn('input', renamed.inputs)
     self.assertEqual(renamed.inputs['input'].name, 'x:0')
     # output key 'Identity' becomes 'output' but still refers to 'Identity:0'
     self.assertNotIn('Identity', renamed.outputs)
     self.assertIn('output', renamed.outputs)
     self.assertEqual(renamed.outputs['output'].name, 'Identity:0')
 def test_rename_map_apply_requires_signature_def(self):
     """RenameMap.apply rejects arguments that are not a SignatureDef"""
     not_a_signature_def = 5
     # construction stays inside the guarded region, as in a lambda
     with self.assertRaises(ValueError):
         rename_map = api.RenameMap({})
         rename_map.apply(not_a_signature_def)
 def test_rename_map_ctor_empty(self):
     """RenameMap constructor accepts empty arguments"""
     # Check emptiness directly. The former `not any(mapping)` only checked
     # that no key is truthy, so it would also pass for a non-empty mapping
     # whose keys are all falsy (e.g. {'': ...}).
     rename_map = api.RenameMap({})
     self.assertEqual(len(rename_map.mapping), 0)
     rename_map = api.RenameMap(dict())
     self.assertEqual(len(rename_map.mapping), 0)
def _get_signature_keys(
        namespace: argparse.Namespace) -> Optional[api.RenameMap]:
    """Build a signature-key RenameMap from the parsed arguments.

    Args:
        namespace: parsed arguments; only its ``rename`` attribute is read.

    Returns:
        ``api.RenameMap(namespace.rename)`` if ``rename`` was set,
        otherwise ``None``.
    """
    rename_spec = namespace.rename
    return None if rename_spec is None else api.RenameMap(rename_spec)