def deserialize(name, custom_objects=None):
  """Returns activation function denoted by input string.

  Args:
    name: The name of the activation function.
    custom_objects: A {name:value} dictionary for activations not built into
      keras.

  Returns:
    TensorFlow Activation function denoted by input string.

  For example:

  >>> tf.keras.activations.deserialize('linear')
  <function linear at 0x1239596a8>
  >>> tf.keras.activations.deserialize('sigmoid')
  <function sigmoid at 0x123959510>
  >>> tf.keras.activations.deserialize('abcd')
  Traceback (most recent call last):
  ...
  ValueError: Unknown activation function:abcd

  Raises:
    ValueError: `Unknown activation function` if the input string does not
      denote any defined Tensorflow activation function.
  """
  # Names are resolved against this module's globals (where the built-in
  # activations live) plus any caller-supplied `custom_objects`.
  return deserialize_keras_object(
      name,
      module_objects=globals(),
      custom_objects=custom_objects,
      printable_module_name='activation function')
def test_serialize_custom_class_with_default_name(self):
  """A class registered without arguments round-trips as 'Custom>TestClass'."""

  @generic_utils.register_keras_serializable()
  class TestClass(object):

    def __init__(self, value):
      self._value = value

    def get_config(self):
      return {'value': self._value}

  expected_name = 'Custom>TestClass'
  registered_name = generic_utils._GLOBAL_CUSTOM_NAMES[TestClass]
  self.assertEqual(expected_name, registered_name)

  original = TestClass(value=10)
  serialized = generic_utils.serialize_keras_object(original)
  self.assertEqual(registered_name, serialized['class_name'])

  restored = generic_utils.deserialize_keras_object(serialized)
  self.assertIsNot(original, restored)
  self.assertIsInstance(restored, TestClass)
  self.assertEqual(10, restored._value)

  # A second registration under the same default name must be rejected.
  with self.assertRaisesRegex(ValueError, '.*has already been registered.*'):

    @generic_utils.register_keras_serializable()  # pylint: disable=function-redefined
    class TestClass(object):

      def __init__(self, value):
        self._value = value

      def get_config(self):
        return {'value': self._value}
def test_serialize_custom_class_with_custom_name(self):
  """Explicit package/name arguments yield a 'TestPackage>CustomName' key."""

  @generic_utils.register_keras_serializable('TestPackage', 'CustomName')
  class OtherTestClass(object):

    def __init__(self, val):
      self._val = val

    def get_config(self):
      return {'val': self._val}

  expected_name = 'TestPackage>CustomName'
  registered_name = generic_utils._GLOBAL_CUSTOM_NAMES[OtherTestClass]
  self.assertEqual(expected_name, registered_name)

  # The public lookup helpers must agree with the private registry.
  lookup_name = generic_utils.get_registered_name(OtherTestClass)
  self.assertEqual(lookup_name, registered_name)
  looked_up_cls = generic_utils.get_registered_object(lookup_name)
  self.assertEqual(OtherTestClass, looked_up_cls)

  original = OtherTestClass(val=5)
  serialized = generic_utils.serialize_keras_object(original)
  self.assertEqual(registered_name, serialized['class_name'])

  restored = generic_utils.deserialize_keras_object(serialized)
  self.assertIsNot(original, restored)
  self.assertIsInstance(restored, OtherTestClass)
  self.assertEqual(5, restored._val)
def deserialize(config, custom_objects=None):
  """Return an `Initializer` object from its config.

  Args:
    config: Serialized initializer, forwarded to `deserialize_keras_object`.
    custom_objects: Optional mapping of names to custom objects that is also
      consulted during lookup.

  Returns:
    The object produced by `deserialize_keras_object` for `config`.
  """
  if not tf2.enabled():
    lookup = globals()
  else:
    # V1 and V2 classes share names, but this file only holds aliases of the
    # V2 classes, so resolve names directly against `init_ops_v2`.
    lookup = dict(
        (attr, getattr(init_ops_v2, attr)) for attr in dir(init_ops_v2))
  return deserialize_keras_object(
      config,
      module_objects=lookup,
      custom_objects=custom_objects,
      printable_module_name='initializer')
def test_serialize_custom_function(self):
  """A registered plain function serializes to its 'Custom>my_fn' name."""

  @generic_utils.register_keras_serializable()
  def my_fn():
    return 42

  expected_name = 'Custom>my_fn'
  registered_name = generic_utils._GLOBAL_CUSTOM_NAMES[my_fn]
  self.assertEqual(expected_name, registered_name)

  lookup_name = generic_utils.get_registered_name(my_fn)
  self.assertEqual(lookup_name, registered_name)

  # Functions serialize to the bare registered name, not a config dict.
  serialized = generic_utils.serialize_keras_object(my_fn)
  self.assertEqual(registered_name, serialized)

  restored_fn = generic_utils.deserialize_keras_object(serialized)
  self.assertEqual(42, restored_fn())

  looked_up_fn = generic_utils.get_registered_object(lookup_name)
  self.assertEqual(42, looked_up_fn())
def get(identifier):
  """Returns function.

  Args:
    identifier: Function, string name, serialized config dict, or `None`.

  Returns:
    Activation function denoted by input:
    - `Linear activation function` if input is `None`.
    - Function corresponding to the input string or config dict.
    - The input itself if it is callable.

  For example:

  >>> tf.keras.activations.get('softmax')
   <function softmax at 0x1222a3d90>
  >>> tf.keras.activations.get(tf.keras.activations.softmax)
   <function softmax at 0x1222a3d90>
  >>> tf.keras.activations.get(None)
   <function linear at 0x1239596a8>
  >>> tf.keras.activations.get(abs)
   <built-in function abs>
  >>> tf.keras.activations.get('abcd')
  Traceback (most recent call last):
  ...
  ValueError: Unknown activation function:abcd

  Raises:
    ValueError: Input is an unknown string or config, i.e., the input does
      not denote any defined function.
    TypeError: Input is not `None`, a string, a callable, or a dict.
  """
  if identifier is None:
    return linear
  if isinstance(identifier, six.string_types):
    # Coerce string subclasses (e.g. Python 2 `unicode`) to plain `str`.
    identifier = str(identifier)
    return deserialize(identifier)
  elif callable(identifier):
    return identifier
  elif isinstance(identifier, dict):
    # Use the same path as string identifiers. The previous direct call to
    # `deserialize_keras_object` passed no `module_objects`, so built-in
    # activations defined in this module could not be resolved from a config.
    return deserialize(identifier)
  else:
    raise TypeError(
        'Could not interpret activation function identifier: {}'.format(
            repr(identifier)))
def deserialize(config, custom_objects=None):
  """Rebuilds a regularizer from its serialized `config`.

  Args:
    config: Serialized regularizer, forwarded to `deserialize_keras_object`.
    custom_objects: Optional mapping of names to custom objects that is also
      consulted during lookup.

  Returns:
    The object produced by `deserialize_keras_object` for `config`.
  """
  # This module's globals are supplied as the pool of known regularizers.
  lookup_options = dict(
      module_objects=globals(),
      custom_objects=custom_objects,
      printable_module_name='regularizer')
  return deserialize_keras_object(config, **lookup_options)
def test_serialize_none(self):
  """`None` serializes to `None` and deserializes back to `None`."""
  serialized = generic_utils.serialize_keras_object(None)
  # assertIsNone checks identity and reports a clearer failure message than
  # assertEqual(x, None).
  self.assertIsNone(serialized)
  deserialized = generic_utils.deserialize_keras_object(serialized)
  self.assertIsNone(deserialized)