def test_invalid_init(self):
  """Checks that ModeKeyMap rejects conflicting PREDICT entries.

  The map is constructed with both `KerasModeKeys.PREDICT` and
  `EstimatorModeKeys.PREDICT` as keys; construction is expected to raise a
  `ValueError` whose message matches 'Multiple keys/values found'.
  """
  # `assertRaisesRegex` is used instead of the deprecated alias
  # `assertRaisesRegexp`, which was removed from unittest in Python 3.12.
  with self.assertRaisesRegex(ValueError, 'Multiple keys/values found'):
    _ = mode_keys.ModeKeyMap(
        **{
            mode_keys.KerasModeKeys.PREDICT: 3,
            mode_keys.EstimatorModeKeys.PREDICT: 1
        })
def test_map(self):
  """Exercises lookup, dictionary methods and immutability of ModeKeyMap.

  A map is built from the Keras PREDICT and TEST keys. The test then checks:
    * values can be read back through the Keras keys and through
      `EstimatorModeKeys.PREDICT` / `EstimatorModeKeys.EVAL`;
    * looking up a TRAIN key (either family) raises `KeyError`;
    * looking up the string 'serve' raises `ValueError` matching
      'Invalid mode';
    * `len`, `values()` and `keys()` report the two stored entries;
    * item assignment raises `TypeError`.
  """
  mode_map = mode_keys.ModeKeyMap(**{
      mode_keys.KerasModeKeys.PREDICT: 3,
      mode_keys.KerasModeKeys.TEST: 1
  })

  # Test dictionary __getitem__
  self.assertEqual(3, mode_map[mode_keys.KerasModeKeys.PREDICT])
  self.assertEqual(3, mode_map[mode_keys.EstimatorModeKeys.PREDICT])
  self.assertEqual(1, mode_map[mode_keys.KerasModeKeys.TEST])
  self.assertEqual(1, mode_map[mode_keys.EstimatorModeKeys.EVAL])
  with self.assertRaises(KeyError):
    _ = mode_map[mode_keys.KerasModeKeys.TRAIN]
  with self.assertRaises(KeyError):
    _ = mode_map[mode_keys.EstimatorModeKeys.TRAIN]
  # `assertRaisesRegex` is used instead of the deprecated alias
  # `assertRaisesRegexp`, which was removed from unittest in Python 3.12.
  with self.assertRaisesRegex(ValueError, 'Invalid mode'):
    _ = mode_map['serve']

  # Test common dictionary methods
  self.assertLen(mode_map, 2)
  self.assertEqual({1, 3}, set(mode_map.values()))
  self.assertEqual(
      {mode_keys.KerasModeKeys.TEST, mode_keys.KerasModeKeys.PREDICT},
      set(mode_map.keys()))

  # Map is immutable
  with self.assertRaises(TypeError):
    mode_map[mode_keys.KerasModeKeys.TEST] = 1
import six

from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model.model_utils import export_output as export_output_lib
from tensorflow.python.saved_model.model_utils import mode_keys
from tensorflow.python.saved_model.model_utils.mode_keys import KerasModeKeys as ModeKeys
from tensorflow.python.util import compat

# Mode -> list of MetaGraph tags to use for that mode in the SavedModel.
EXPORT_TAG_MAP = mode_keys.ModeKeyMap(**{
    ModeKeys.PREDICT: [tag_constants.SERVING],
    ModeKeys.TRAIN: [tag_constants.TRAINING],
    ModeKeys.TEST: [tag_constants.EVAL],
})

# Mode -> default key used in the SignatureDef map.
#
# A SignatureDef map should be created for every exported mode using
# `export_outputs_for_mode` and `build_all_signature_defs`. By default that
# map holds a single Signature defining the input tensors and the output
# predictions, losses, and/or metrics (depending on the mode); the keys below
# are the defaults under which that Signature is stored.
SIGNATURE_KEY_MAP = mode_keys.ModeKeyMap(**{
    ModeKeys.PREDICT: signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
    ModeKeys.TRAIN: signature_constants.DEFAULT_TRAIN_SIGNATURE_DEF_KEY,
    ModeKeys.TEST: signature_constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY,
})