示例#1
0
def deserialize(name, custom_objects=None):
    """Returns the activation function denoted by `name`.

    `name` is forwarded to `deserialize_keras_object`, which resolves it
    against this module's globals and `custom_objects`.

    For example:

    >>> tf.keras.activations.deserialize('linear')
    <function linear at 0x...>
    >>> tf.keras.activations.deserialize('sigmoid')
    <function sigmoid at 0x...>
    >>> tf.keras.activations.deserialize('abcd')
    Traceback (most recent call last):
    ...
    ValueError: Unknown activation function:abcd

    Args:
        name: The name of the activation function.
        custom_objects: A {name: value} dictionary for activations not built
            into keras.

    Returns:
        TensorFlow activation function denoted by `name`.

    Raises:
        ValueError: `Unknown activation function` if `name` does not denote
            any defined TensorFlow activation function.
    """
    return deserialize_keras_object(
        name,
        module_objects=globals(),
        custom_objects=custom_objects,
        printable_module_name='activation function')
示例#2
0
    def test_serialize_custom_class_with_default_name(self):
        """A class registered with no arguments round-trips as 'Custom>TestClass'.

        Also checks that registering a second class under the same name fails.
        """

        @generic_utils.register_keras_serializable()
        class TestClass(object):

            def __init__(self, value):
                self._value = value

            def get_config(self):
                return {'value': self._value}

        expected_name = 'Custom>TestClass'
        registered_name = generic_utils._GLOBAL_CUSTOM_NAMES[TestClass]
        self.assertEqual(expected_name, registered_name)

        original = TestClass(value=10)
        serialized = generic_utils.serialize_keras_object(original)
        self.assertEqual(registered_name, serialized['class_name'])

        restored = generic_utils.deserialize_keras_object(serialized)
        self.assertIsNot(original, restored)
        self.assertIsInstance(restored, TestClass)
        self.assertEqual(10, restored._value)

        # Registering another class under the same name must be rejected.
        duplicate_pattern = '.*has already been registered.*'
        with self.assertRaisesRegex(ValueError, duplicate_pattern):

            @generic_utils.register_keras_serializable()  # pylint: disable=function-redefined
            class TestClass(object):

                def __init__(self, value):
                    self._value = value

                def get_config(self):
                    return {'value': self._value}
示例#3
0
    def test_serialize_custom_class_with_custom_name(self):
        """A class registered with an explicit package and name round-trips.

        The registry must map the class to 'TestPackage>CustomName' and back.
        """

        @generic_utils.register_keras_serializable('TestPackage', 'CustomName')
        class OtherTestClass(object):

            def __init__(self, val):
                self._val = val

            def get_config(self):
                return {'val': self._val}

        expected_name = 'TestPackage>CustomName'
        registered_name = generic_utils._GLOBAL_CUSTOM_NAMES[OtherTestClass]
        self.assertEqual(expected_name, registered_name)

        # The public lookup helpers must agree with the private registry.
        public_name = generic_utils.get_registered_name(OtherTestClass)
        self.assertEqual(public_name, registered_name)
        looked_up_cls = generic_utils.get_registered_object(public_name)
        self.assertEqual(OtherTestClass, looked_up_cls)

        original = OtherTestClass(val=5)
        serialized = generic_utils.serialize_keras_object(original)
        self.assertEqual(registered_name, serialized['class_name'])

        restored = generic_utils.deserialize_keras_object(serialized)
        self.assertIsNot(original, restored)
        self.assertIsInstance(restored, OtherTestClass)
        self.assertEqual(5, restored._val)
示例#4
0
def deserialize(config, custom_objects=None):
    """Returns an `Initializer` object built from its serialized `config`.

    Args:
        config: Serialized initializer, forwarded to `deserialize_keras_object`.
        custom_objects: Optional dictionary of custom objects, forwarded to
            `deserialize_keras_object`.

    Returns:
        The deserialized initializer.
    """
    if not tf2.enabled():
        lookup = globals()
    else:
        # V1 and V2 initializer classes share class names, but the V2 classes
        # are aliased in this module, so read them directly off `init_ops_v2`.
        lookup = {attr: getattr(init_ops_v2, attr) for attr in dir(init_ops_v2)}
    return deserialize_keras_object(
        config,
        module_objects=lookup,
        custom_objects=custom_objects,
        printable_module_name='initializer',
    )
示例#5
0
    def test_serialize_custom_function(self):
        """A registered plain function serializes to its registered name string."""

        @generic_utils.register_keras_serializable()
        def my_fn():
            return 42

        registered_name = generic_utils._GLOBAL_CUSTOM_NAMES[my_fn]
        self.assertEqual('Custom>my_fn', registered_name)
        looked_up_name = generic_utils.get_registered_name(my_fn)
        self.assertEqual(looked_up_name, registered_name)

        # A function serializes to the registered name string itself.
        serialized = generic_utils.serialize_keras_object(my_fn)
        self.assertEqual(registered_name, serialized)
        restored_fn = generic_utils.deserialize_keras_object(serialized)
        self.assertEqual(42, restored_fn())

        # The registry can also be queried directly by name.
        registry_fn = generic_utils.get_registered_object(looked_up_name)
        self.assertEqual(42, registry_fn())
示例#6
0
def get(identifier):
    """Returns the activation function denoted by `identifier`.

    For example:

    >>> tf.keras.activations.get('softmax')
    <function softmax at 0x...>
    >>> tf.keras.activations.get(tf.keras.activations.softmax)
    <function softmax at 0x...>
    >>> tf.keras.activations.get(None)
    <function linear at 0x...>
    >>> tf.keras.activations.get(abs)
    <built-in function abs>
    >>> tf.keras.activations.get('abcd')
    Traceback (most recent call last):
    ...
    ValueError: Unknown activation function:abcd

    Args:
        identifier: `None`, a string naming an activation function, a
            serialized activation config (dict), or a callable.

    Returns:
        Activation function denoted by input:
        - `linear` if `identifier` is `None`.
        - The function named by a string or described by a config dict.
        - `identifier` itself if it is already callable.

    Raises:
        ValueError: Input is an unknown function name or config, i.e., it does
            not denote any defined function.
        TypeError: Input is not `None`, a string, a dict, or a callable.
    """
    if identifier is None:
        return linear
    if isinstance(identifier, six.string_types):
        return deserialize(str(identifier))
    if isinstance(identifier, dict):
        # Route config dicts through `deserialize`, exactly like strings, so
        # they are resolved against this module's globals (`module_objects`)
        # and reported as 'activation function' in error messages.
        return deserialize(identifier)
    if callable(identifier):
        return identifier
    raise TypeError(
        'Could not interpret activation function identifier: {}'.format(
            repr(identifier)))
示例#7
0
def deserialize(config, custom_objects=None):
    """Rebuilds a regularizer from its serialized form.

    Args:
        config: Serialized regularizer, forwarded to `deserialize_keras_object`.
        custom_objects: Optional dictionary of custom objects, forwarded to
            `deserialize_keras_object`.

    Returns:
        The object `deserialize_keras_object` produces for `config`, resolved
        against this module's globals.
    """
    module_namespace = globals()
    return deserialize_keras_object(
        config,
        custom_objects=custom_objects,
        module_objects=module_namespace,
        printable_module_name='regularizer',
    )
示例#8
0
 def test_serialize_none(self):
     """`None` serializes to `None` and deserializes back to `None`."""
     serialized = generic_utils.serialize_keras_object(None)
     # assertIsNone checks identity and gives a clearer failure message than
     # assertEqual(x, None).
     self.assertIsNone(serialized)
     deserialized = generic_utils.deserialize_keras_object(serialized)
     self.assertIsNone(deserialized)