def test_rename_map_ctor_verifies_input_format(self):
    """RenameMap constructor verifies argument types

    Every argument below must be rejected with ValueError. Each case runs
    in its own subTest, so one failing case does not hide the others and
    the failure report names the offending case.
    """
    invalid_arguments = {
        'non-dictionary': 5,
        'key is not a str': {'valid': 'ok', True: 'that cannot work'},
        'value is not a str': {'valid': 'ok', 'no-go': ['invalid']},
        'empty key': {'valid': 'ok', '': 'invalid'},
        # not literally empty: the value consists only of whitespace
        'blank value': {'valid': 'ok', 'invalid': '\n \r\t'},
    }
    for case, argument in invalid_arguments.items():
        with self.subTest(case=case), self.assertRaises(ValueError):
            api.RenameMap(argument)
def test_rename_map_ctor_accepts_str_to_str_dict(self):
    """RenameMap can be built from a str->str dict or from (old, new) pairs"""
    # construction from a plain dictionary
    from_dict = api.RenameMap({'Identity': 'stylised_image'})
    self.assertEqual(from_dict.mapping['Identity'], 'stylised_image')
    # construction from an iterable of (old_name, new_name) tuples
    pairs = [('Identity', 'stylised_image'), ('input', 'original_image')]
    from_pairs = api.RenameMap(pairs)
    for old_name, new_name in pairs:
        self.assertEqual(from_pairs.mapping[old_name], new_name)
def test_graph_model_to_saved_model_accepts_signature_key_map(self):
    """graph_model_to_saved_model renames signature keys via a RenameMap

    The multi-head test model is exported with two signatures (keys '' and
    'debug'); its 'Identity' and 'Identity_1' outputs are renamed, and the
    renamed keys must still refer to the original tensor names.
    """
    multi_head_dir = testutils.get_path_to(testutils.MULTI_HEAD_PATH)
    saved_model_dir = tempfile.mkdtemp(suffix='.saved_model')
    try:
        serving_tags = [tf.saved_model.SERVING]
        signature_defs = {
            '': {api.SIGNATURE_OUTPUTS: ['Identity']},
            'debug': {api.SIGNATURE_OUTPUTS: ['Identity', 'Identity_1']},
        }
        key_renames = api.RenameMap([
            ('Identity', 'output'),
            ('Identity_1', 'autoencoder_output'),
        ])
        api.graph_model_to_saved_model(multi_head_dir, saved_model_dir,
                                       tags=serving_tags,
                                       signature_def_map=signature_defs,
                                       signature_key_map=key_renames)
        # read the exported model back
        exported = load_meta_graph(saved_model_dir, serving_tags)
        signatures = exported.signature_def
        # every signature exposes 'Identity' under its new key
        for signature_def in signatures.values():
            self.assertIn('output', signature_def.outputs)
            self.assertEqual(signature_def.outputs['output'].name,
                             'Identity:0')
        # only the 'debug' signature carries the second head
        debug_outputs = signatures['debug'].outputs
        self.assertIn('autoencoder_output', debug_outputs)
        self.assertEqual(debug_outputs['autoencoder_output'].name,
                         'Identity_1:0')
    finally:
        if os.path.exists(saved_model_dir):
            shutil.rmtree(saved_model_dir)
def test_graph_models_to_saved_model_accepts_signature_keys(self):
    """graph_models_to_saved_model should accept signature keys

    Two models are exported into one SavedModel under distinct tag sets,
    each with its own RenameMap for signature keys. The renamed keys must
    still refer to the original tensor names.
    """
    model_dir_1 = testutils.get_path_to(testutils.SIMPLE_MODEL_PATH_NAME)
    model_dir_2 = testutils.get_path_to(testutils.MULTI_HEAD_PATH)
    tags_1 = [tf.saved_model.SERVING, 'model_1']
    tags_2 = [tf.saved_model.SERVING, 'model_2']
    export_dir = tempfile.mkdtemp(suffix='.saved_model')
    try:
        model_list = [(model_dir_1, tags_1), (model_dir_2, tags_2)]
        # 'ignore_this' is not one of the model paths in model_list, and
        # model_dir_1 has no entry in this dict at all
        signatures = {
            'ignore_this': {
                '': {api.SIGNATURE_OUTPUTS: ['y']}
            },
            model_dir_2: {
                '': {api.SIGNATURE_OUTPUTS: ['Identity']}
            }
        }
        signature_keys = {
            model_dir_1: api.RenameMap({'x': 'input', 'Identity': 'output'}),
            model_dir_2: api.RenameMap({'Identity': 'scores'})
        }
        api.graph_models_to_saved_model(model_list, export_dir, signatures,
                                        signature_keys)
        # mkdtemp() already created export_dir, so os.path.exists() would
        # always pass; check that a SavedModel was actually written into it
        self.assertTrue(tf.saved_model.contains_saved_model(export_dir))
        # check the signatures of model 1
        meta_graph_def = load_meta_graph(export_dir, tags_1)
        signature = list(meta_graph_def.signature_def.values())[0]
        self.assertIn('input', signature.inputs)
        self.assertEqual(signature.inputs['input'].name, 'x:0')
        self.assertIn('output', signature.outputs)
        self.assertEqual(signature.outputs['output'].name, 'Identity:0')
        # check the signatures of model 2
        meta_graph_def = load_meta_graph(export_dir, tags_2)
        signature = list(meta_graph_def.signature_def.values())[0]
        self.assertIn('scores', signature.outputs)
        self.assertEqual(signature.outputs['scores'].name, 'Identity:0')
    finally:
        if os.path.exists(export_dir):
            shutil.rmtree(export_dir)
def test_rename_map_apply(self):
    """RenameMap.apply replaces signature keys but keeps tensor names"""
    model_path = testutils.get_path_to(testutils.SIMPLE_MODEL_PATH_NAME)
    _, signature_def = api.load_graph_model_and_signature(model_path)
    renames = api.RenameMap({'x': 'input', 'Identity': 'output'})
    renamed = renames.apply(signature_def)
    # inputs: old key gone, new key present, same underlying tensor
    inputs = renamed.inputs
    self.assertNotIn('x', inputs)
    self.assertIn('input', inputs)
    self.assertEqual(inputs['input'].name, 'x:0')
    # outputs: old key gone, new key present, same underlying tensor
    outputs = renamed.outputs
    self.assertNotIn('Identity', outputs)
    self.assertIn('output', outputs)
    self.assertEqual(outputs['output'].name, 'Identity:0')
def test_rename_map_apply_requires_signature_def(self):
    """RenameMap.apply rejects anything that is not a SignatureDef"""
    not_a_signature_def = 5
    with self.assertRaises(ValueError):
        api.RenameMap({}).apply(not_a_signature_def)
def test_rename_map_ctor_empty(self):
    """RenameMap constructor accepts empty arguments

    Both accepted input forms are covered: an empty dict and an empty
    iterable of (old_name, new_name) pairs. The original second case,
    dict(), was identical to {} and left the empty-iterable form untested.
    """
    for empty_argument in ({}, []):
        with self.subTest(argument=empty_argument):
            rename_map = api.RenameMap(empty_argument)
            # assertEqual on the length gives a clearer failure report
            # than assertTrue(not any(...))
            self.assertEqual(len(rename_map.mapping), 0)
def _get_signature_keys(
        namespace: argparse.Namespace) -> Optional[api.RenameMap]:
    """Build the signature-key rename map from parsed CLI arguments.

    Args:
        namespace: parsed arguments; only its ``rename`` attribute is read.

    Returns:
        ``api.RenameMap(namespace.rename)`` when ``rename`` was given,
        otherwise ``None``.
    """
    renames = namespace.rename
    return None if renames is None else api.RenameMap(renames)