Exemple #1
0
    def test_serialize_custom_class_with_default_name(self):
        """Round-trips a class registered under the default 'Custom' package.

        Also checks that registering a second class under the same
        serialized name raises a ValueError.
        """
        @generic_utils.register_keras_serializable()
        class TestClass(object):
            def __init__(self, value):
                self._value = value

            def get_config(self):
                return {'value': self._value}

        expected_name = 'Custom>TestClass'
        original = TestClass(value=10)
        registered_name = generic_utils._GLOBAL_CUSTOM_NAMES[TestClass]
        self.assertEqual(expected_name, registered_name)

        # Serialize, then rebuild an object from the produced config.
        config = generic_utils.serialize_keras_object(original)
        self.assertEqual(registered_name, config['class_name'])
        restored = generic_utils.deserialize_keras_object(config)

        # The rebuilt object must be a distinct instance with the same state.
        self.assertIsNot(original, restored)
        self.assertIsInstance(restored, TestClass)
        self.assertEqual(10, restored._value)

        # A second registration under the same default name must be rejected.
        with self.assertRaisesRegex(ValueError,
                                    '.*has already been registered.*'):

            @generic_utils.register_keras_serializable()  # pylint: disable=function-redefined
            class TestClass(object):
                def __init__(self, value):
                    self._value = value

                def get_config(self):
                    return {'value': self._value}
Exemple #2
0
def serialize(activation):
    """Returns the serialized form of an activation function.

    If `activation` has a `__name__` that appears as a key in
    `_TF_ACTIVATIONS_V2`, the value mapped to that key is returned.
    Otherwise the work is delegated to `serialize_keras_object`.

    Args:
        activation: Function to serialize.

    Returns:
        String denoting the name attribute of the input function.

    Raises:
        ValueError: The input function is not a valid one.

    For example:

    >>> tf.keras.activations.serialize(tf.keras.activations.tanh)
    'tanh'
    >>> tf.keras.activations.serialize(tf.keras.activations.sigmoid)
    'sigmoid'
    >>> tf.keras.activations.serialize('abcd')
    Traceback (most recent call last):
    ...
    ValueError: ('Cannot serialize', 'abcd')
    """
    if hasattr(activation, '__name__'):
        fn_name = activation.__name__
        # Some TF v2 activations are registered under an alias name.
        if fn_name in _TF_ACTIVATIONS_V2:
            return _TF_ACTIVATIONS_V2[fn_name]
    return serialize_keras_object(activation)
Exemple #3
0
    def test_serialize_custom_class_with_custom_name(self):
        """Round-trips a class registered under an explicit package and name."""
        @generic_utils.register_keras_serializable('TestPackage', 'CustomName')
        class OtherTestClass(object):
            def __init__(self, val):
                self._val = val

            def get_config(self):
                return {'val': self._val}

        expected_name = 'TestPackage>CustomName'
        original = OtherTestClass(val=5)
        registered_name = generic_utils._GLOBAL_CUSTOM_NAMES[OtherTestClass]
        self.assertEqual(expected_name, registered_name)

        # The public lookup helpers must agree with the raw registry,
        # both class -> name and name -> class.
        looked_up_name = generic_utils.get_registered_name(OtherTestClass)
        self.assertEqual(looked_up_name, registered_name)
        looked_up_cls = generic_utils.get_registered_object(looked_up_name)
        self.assertEqual(OtherTestClass, looked_up_cls)

        # Serialize, then rebuild a distinct instance carrying the same state.
        config = generic_utils.serialize_keras_object(original)
        self.assertEqual(registered_name, config['class_name'])
        restored = generic_utils.deserialize_keras_object(config)
        self.assertIsNot(original, restored)
        self.assertIsInstance(restored, OtherTestClass)
        self.assertEqual(5, restored._val)
Exemple #4
0
    def test_serialize_custom_function(self):
        """Round-trips a plain function registered under the default package."""
        @generic_utils.register_keras_serializable()
        def my_fn():
            return 42

        expected_name = 'Custom>my_fn'
        registered_name = generic_utils._GLOBAL_CUSTOM_NAMES[my_fn]
        self.assertEqual(expected_name, registered_name)
        looked_up_name = generic_utils.get_registered_name(my_fn)
        self.assertEqual(looked_up_name, registered_name)

        # A registered function serializes to exactly its registered name.
        config = generic_utils.serialize_keras_object(my_fn)
        self.assertEqual(registered_name, config)
        restored_fn = generic_utils.deserialize_keras_object(config)
        self.assertEqual(42, restored_fn())

        # Looking the name up directly must yield an equivalent callable.
        fetched_fn = generic_utils.get_registered_object(looked_up_name)
        self.assertEqual(42, fetched_fn())
Exemple #5
0
def serialize(regularizer):
    """Serializes a regularizer by delegating to `serialize_keras_object`.

    Args:
        regularizer: The regularizer to serialize.

    Returns:
        Whatever `serialize_keras_object(regularizer)` returns.
    """
    serialized = serialize_keras_object(regularizer)
    return serialized
Exemple #6
0
 def test_serialize_none(self):
     """Checks that `None` serializes to `None` and deserializes back to `None`."""
     serialized = generic_utils.serialize_keras_object(None)
     # assertIsNone checks identity, so an object that merely compares equal
     # to None cannot slip through, and failures report the actual value.
     self.assertIsNone(serialized)
     deserialized = generic_utils.deserialize_keras_object(serialized)
     self.assertIsNone(deserialized)
Exemple #7
0
def serialize(initializer):
    """Converts an initializer into its serialized form.

    This is a thin wrapper around `serialize_keras_object`.

    Args:
        initializer: The initializer to serialize.

    Returns:
        The value produced by `serialize_keras_object(initializer)`.
    """
    serialized = serialize_keras_object(initializer)
    return serialized
Exemple #8
0
def serialize(constraint):
    """Converts a weight constraint into its serialized form.

    Delegates entirely to `serialize_keras_object`.

    Args:
        constraint: The constraint to serialize.

    Returns:
        The value produced by `serialize_keras_object(constraint)`.
    """
    serialized = serialize_keras_object(constraint)
    return serialized