Esempio n. 1
0
def deserialize_keras_object(identifier,
                             module_objects=None,
                             custom_objects=None,
                             printable_module_name='object'):
    """Turns the serialized form of a Keras object back into an actual object.

    Args:
        identifier: The serialized form of the object. One of `None`, a config
            dict (resolved via `class_and_config_for_serialized_keras_object`),
            a string naming a registered object, or an already-deserialized
            function.
        module_objects: Optional dict of built-in objects to look names up in.
            `None` is treated as an empty dict.
        custom_objects: Optional dict of user-provided objects to look names up
            in. Checked before the global custom-object registry and
            `module_objects`.
        printable_module_name: Human-readable name of the object type, used in
            error messages.

    Returns:
        The deserialized object, or `None` if `identifier` is `None`.

    Raises:
        ValueError: If a string name cannot be resolved, or `identifier` is of
            an unsupported type.
    """
    if identifier is None:
        return None

    if isinstance(identifier, dict):
        # In this case we are dealing with a Keras config dictionary.
        config = identifier
        (cls, cls_config) = class_and_config_for_serialized_keras_object(
            config, module_objects, custom_objects, printable_module_name)
        custom_objects = custom_objects or {}

        if hasattr(cls, 'from_config'):
            arg_spec = tf_inspect.getfullargspec(cls.from_config)
            # `custom_objects` may be declared either positionally or as a
            # keyword-only parameter; accept both spellings.
            accepts_custom_objects = (
                'custom_objects' in arg_spec.args or
                'custom_objects' in (arg_spec.kwonlyargs or []))
            if accepts_custom_objects:
                # Merge the global registry with the caller's objects; the
                # caller's entries come last so they win on name collisions.
                return cls.from_config(
                    cls_config,
                    custom_objects=dict(
                        list(_GLOBAL_CUSTOM_OBJECTS.items()) +
                        list(custom_objects.items())))
            with CustomObjectScope(custom_objects):
                return cls.from_config(cls_config)

        # Then `cls` may be a function returning a class.
        # In this case by convention `config` holds
        # the kwargs of the function.
        with CustomObjectScope(custom_objects):
            return cls(**cls_config)

    elif isinstance(identifier, six.string_types):
        object_name = identifier
        if custom_objects and object_name in custom_objects:
            obj = custom_objects.get(object_name)
        elif object_name in _GLOBAL_CUSTOM_OBJECTS:
            obj = _GLOBAL_CUSTOM_OBJECTS[object_name]
        else:
            # `module_objects` defaults to None; treat that as an empty lookup
            # so an unknown name raises the descriptive ValueError below rather
            # than an AttributeError on `None.get`.
            obj = (module_objects or {}).get(object_name)
            if obj is None:
                raise ValueError(
                    'Unknown {}: {}. Please ensure this object is '
                    'passed to the `custom_objects` argument. See '
                    'https://www.tensorflow.org/guide/keras/save_and_serialize'
                    '#registering_the_custom_object for details.'.format(
                        printable_module_name, object_name))

        # Classes passed by name are instantiated with no args, functions are
        # returned as-is.
        if tf_inspect.isclass(obj):
            return obj()
        return obj
    elif tf_inspect.isfunction(identifier):
        # If a function has already been deserialized, return as is.
        return identifier
    else:
        raise ValueError('Could not interpret serialized %s: %s' %
                         (printable_module_name, identifier))
Esempio n. 2
0
def class_and_config_for_serialized_keras_object(
        config,
        module_objects=None,
        custom_objects=None,
        printable_module_name='object'):
    """Returns the class and config for a serialized keras object.

    Args:
        config: A dict with at least `class_name` and `config` keys.
        module_objects: Optional dict of built-in objects used to resolve
            `class_name`.
        custom_objects: Optional dict of user-provided objects used to resolve
            `class_name` and string-valued config entries naming functions.
        printable_module_name: Human-readable name of the object type, used in
            error messages.

    Returns:
        A `(cls, cls_config)` tuple. If `config['config']` is a list (legacy
        Sequential format), it is returned unchanged. Otherwise `cls_config` is
        a new dict: passively-serialized nested dicts are deserialized and
        strings naming registered functions are replaced by those functions.
        The caller's `config` is not modified.

    Raises:
        ValueError: If `config` is malformed or `class_name` cannot be resolved.
    """
    if (not isinstance(config, dict) or 'class_name' not in config
            or 'config' not in config):
        raise ValueError('Improper config format: ' + str(config))

    class_name = config['class_name']
    cls = get_registered_object(class_name, custom_objects, module_objects)
    if cls is None:
        raise ValueError(
            'Unknown {}: {}. Please ensure this object is '
            'passed to the `custom_objects` argument. See '
            'https://www.tensorflow.org/guide/keras/save_and_serialize'
            '#registering_the_custom_object for details.'.format(
                printable_module_name, class_name))

    cls_config = config['config']
    # Check if `cls_config` is a list. If it is a list, return the class and the
    # associated class configs for recursively deserialization. This case will
    # happen on the old version of sequential model (e.g. `keras_version` ==
    # "2.0.6"), which is serialized in a different structure, for example
    # "{'class_name': 'Sequential',
    #   'config': [{'class_name': 'Embedding', 'config': ...}, {}, ...]}".
    if isinstance(cls_config, list):
        return (cls, cls_config)

    deserialized_objects = {}
    for key, item in cls_config.items():
        if key == 'name':
            # The value of 'name' is a plain string and must never be turned
            # into a function, even if a custom function shares that name.
            continue
        if isinstance(item, dict) and '__passive_serialization__' in item:
            deserialized_objects[key] = deserialize_keras_object(
                item,
                module_objects=module_objects,
                custom_objects=custom_objects,
                printable_module_name='config_item')
        elif isinstance(item, six.string_types):
            # Handle custom functions here. When saving functions, we only save
            # the function's name as a string. If we find a matching string in
            # the custom objects during deserialization, we convert the string
            # back to the original function.
            # Note that a potential issue is that a string field could have a
            # naming conflict with a custom function name, but this should be a
            # rare case. This issue does not occur if a string field has a
            # naming conflict with a custom object, since the config of an
            # object will always be a dict.
            # TODO(momernick): Should this also have 'module_objects'?
            registered = get_registered_object(item, custom_objects)
            if tf_inspect.isfunction(registered):
                deserialized_objects[key] = registered

    # Build a new dict rather than writing into the caller's config, so that
    # deserializing the same config twice yields fresh objects each time.
    cls_config = dict(cls_config)
    cls_config.update(deserialized_objects)

    return (cls, cls_config)
Esempio n. 3
0
def deserialize_keras_object(identifier,
                             module_objects=None,
                             custom_objects=None,
                             printable_module_name='object'):
    """Turns the serialized form of a Keras object back into an actual object.

    This function is for mid-level library implementers rather than end users.

    Importantly, this utility requires you to provide the dict of
    `module_objects` to use for looking up the object config; this is not
    populated by default. If you need a deserialization utility that has
    preexisting knowledge of built-in Keras objects, use e.g.
    `keras.layers.deserialize(config)`, `keras.metrics.deserialize(config)`,
    etc.

    Calling `deserialize_keras_object` while underneath the
    `SharedObjectLoadingScope` context manager will cause any already-seen
    shared objects to be returned as-is rather than creating a new object.

    Args:
      identifier: the serialized form of the object. One of `None`, a config
        dict, a string naming a registered object, or an already-deserialized
        function.
      module_objects: A dictionary of built-in objects to look the name up in.
        Generally, `module_objects` is provided by midlevel library
        implementers. `None` is treated as an empty dictionary.
      custom_objects: A dictionary of custom objects to look the name up in.
        Generally, `custom_objects` is provided by the end user.
      printable_module_name: A human-readable string representing the type of
        the object. Printed in case of exception.

    Returns:
      The deserialized object, or `None` if `identifier` is `None`.

    Raises:
      ValueError: If a string name cannot be resolved, or `identifier` is of an
        unsupported type.

    Example:

    A mid-level library implementer might want to implement a utility for
    retrieving an object from its config, as such:

    ```python
    def deserialize(config, custom_objects=None):
        return deserialize_keras_object(
            config,
            module_objects=globals(),
            custom_objects=custom_objects,
            printable_module_name="MyObjectType",
        )
    ```

    This is how e.g. `keras.layers.deserialize()` is implemented.
    """
    if identifier is None:
        return None

    if isinstance(identifier, dict):
        # In this case we are dealing with a Keras config dictionary.
        config = identifier
        (cls, cls_config) = class_and_config_for_serialized_keras_object(
            config, module_objects, custom_objects, printable_module_name)

        # If this object has already been loaded (i.e. it's shared between multiple
        # objects), return the already-loaded object.
        shared_object_id = config.get(SHARED_OBJECT_KEY)
        shared_object = _shared_object_loading_scope().get(shared_object_id)  # pylint: disable=assignment-from-none
        if shared_object is not None:
            return shared_object

        custom_objects = custom_objects or {}
        if hasattr(cls, 'from_config'):
            arg_spec = tf_inspect.getfullargspec(cls.from_config)
            # `custom_objects` may be declared either positionally or as a
            # keyword-only parameter; accept both spellings.
            accepts_custom_objects = (
                'custom_objects' in arg_spec.args or
                'custom_objects' in (arg_spec.kwonlyargs or []))
            if accepts_custom_objects:
                # Merge the global registry with the caller's objects; the
                # caller's entries come last so they win on name collisions.
                deserialized_obj = cls.from_config(
                    cls_config,
                    custom_objects=dict(
                        list(_GLOBAL_CUSTOM_OBJECTS.items()) +
                        list(custom_objects.items())))
            else:
                with CustomObjectScope(custom_objects):
                    deserialized_obj = cls.from_config(cls_config)
        else:
            # Then `cls` may be a function returning a class.
            # In this case by convention `config` holds
            # the kwargs of the function.
            with CustomObjectScope(custom_objects):
                deserialized_obj = cls(**cls_config)

        # Add object to shared objects, in case we find it referenced again.
        _shared_object_loading_scope().set(shared_object_id, deserialized_obj)

        return deserialized_obj

    elif isinstance(identifier, str):
        object_name = identifier
        if custom_objects and object_name in custom_objects:
            obj = custom_objects.get(object_name)
        elif object_name in _GLOBAL_CUSTOM_OBJECTS:
            obj = _GLOBAL_CUSTOM_OBJECTS[object_name]
        else:
            # `module_objects` defaults to None; treat that as an empty lookup
            # so an unknown name raises the descriptive ValueError below rather
            # than an AttributeError on `None.get`.
            obj = (module_objects or {}).get(object_name)
            if obj is None:
                raise ValueError(
                    'Unknown {}: {}. Please ensure this object is '
                    'passed to the `custom_objects` argument. See '
                    'https://www.tensorflow.org/guide/keras/save_and_serialize'
                    '#registering_the_custom_object for details.'.format(
                        printable_module_name, object_name))

        # Classes passed by name are instantiated with no args, functions are
        # returned as-is.
        if tf_inspect.isclass(obj):
            return obj()
        return obj
    elif tf_inspect.isfunction(identifier):
        # If a function has already been deserialized, return as is.
        return identifier
    else:
        raise ValueError('Could not interpret serialized %s: %s' %
                         (printable_module_name, identifier))
Esempio n. 4
0
def class_and_config_for_serialized_keras_object(
    config,
    module_objects=None,
    custom_objects=None,
    printable_module_name="object",
):
    """Returns the class and config for a serialized keras object.

    Args:
        config: A dict with at least `class_name` and `config` keys.
        module_objects: Optional dict of built-in objects used to resolve
            `class_name`.
        custom_objects: Optional dict of user-provided objects used to resolve
            `class_name` and string-valued config entries naming functions.
        printable_module_name: Human-readable name of the object type, used in
            error messages.

    Returns:
        A `(cls, cls_config)` tuple. If `config["config"]` is a list (legacy
        Sequential format), it is returned unchanged. Otherwise `cls_config` is
        a new dict: passively-serialized nested dicts are deserialized and
        strings naming registered functions are replaced by those functions
        (except under the `name` key). The caller's `config` is not modified.

    Raises:
        ValueError: If `config` is malformed or `class_name` cannot be resolved.
    """
    if (not isinstance(config, dict) or "class_name" not in config
            or "config" not in config):
        raise ValueError(
            f"Improper config format for {config}. "
            "Expecting python dict contains `class_name` and `config` as keys")

    class_name = config["class_name"]
    cls = get_registered_object(class_name, custom_objects, module_objects)
    if cls is None:
        raise ValueError(
            f"Unknown {printable_module_name}: {class_name}. "
            "Please ensure this "
            "object is passed to the `custom_objects` argument. See "
            "https://www.tensorflow.org/guide/keras/save_and_serialize"
            "#registering_the_custom_object for details.")

    cls_config = config["config"]
    # Check if `cls_config` is a list. If it is a list, return the class and the
    # associated class configs for recursively deserialization. This case will
    # happen on the old version of sequential model (e.g. `keras_version` ==
    # "2.0.6"), which is serialized in a different structure, for example
    # "{'class_name': 'Sequential',
    #   'config': [{'class_name': 'Embedding', 'config': ...}, {}, ...]}".
    if isinstance(cls_config, list):
        return (cls, cls_config)

    deserialized_objects = {}
    for key, item in cls_config.items():
        if key == "name":
            # Assume that the value of 'name' is a string that should not be
            # deserialized as a function. This avoids the corner case where
            # cls_config['name'] has an identical name to a custom function and
            # gets converted into that function.
            continue
        if isinstance(item, dict) and "__passive_serialization__" in item:
            deserialized_objects[key] = deserialize_keras_object(
                item,
                module_objects=module_objects,
                custom_objects=custom_objects,
                printable_module_name="config_item",
            )
        elif isinstance(item, str):
            # Handle custom functions here. When saving functions, we only save
            # the function's name as a string. If we find a matching string in
            # the custom objects during deserialization, we convert the string
            # back to the original function.
            # Note that a potential issue is that a string field could have a
            # naming conflict with a custom function name, but this should be a
            # rare case.  This issue does not occur if a string field has a
            # naming conflict with a custom object, since the config of an
            # object will always be a dict.
            # TODO(momernick): Should this also have 'module_objects'?
            registered = get_registered_object(item, custom_objects)
            if tf_inspect.isfunction(registered):
                deserialized_objects[key] = registered

    # Build a new dict rather than writing into the caller's config, so that
    # deserializing the same config twice yields fresh objects each time.
    cls_config = dict(cls_config)
    cls_config.update(deserialized_objects)

    return (cls, cls_config)