Example #1
0
 def testExportSingleFunction(self):
   """Tests exporting a single function under two API names.

   Checks that both names are recorded on the function, are reported as both
   V1 and V2 names, resolve back to the function, and that the function's
   canonical name round-trips through `get_symbol_from_name`.
   """
   # Uses `assertEqual` throughout: the `assertEquals` alias was deprecated
   # and removed from `unittest` in Python 3.12.
   export_decorator = tf_export.tf_export('nameA', 'nameB')
   decorated_function = export_decorator(_test_function)
   # The test expects the decorator to hand back the original function.
   self.assertEqual(decorated_function, _test_function)
   self.assertEqual(('nameA', 'nameB'), decorated_function._tf_api_names)
   self.assertEqual(['nameA', 'nameB'],
                    tf_export.get_v1_names(decorated_function))
   self.assertEqual(['nameA', 'nameB'],
                    tf_export.get_v2_names(decorated_function))
   self.assertEqual(tf_export.get_symbol_from_name('nameA'),
                    decorated_function)
   self.assertEqual(tf_export.get_symbol_from_name('nameB'),
                    decorated_function)
   self.assertEqual(
       tf_export.get_symbol_from_name(
           tf_export.get_canonical_name_for_symbol(decorated_function)),
       decorated_function)
Example #2
0
 def testExportSingleFunctionV1Only(self):
   """A `v1=`-only export is listed as V1-only and lives under `compat.v1.`.

   The function must carry the V1 names, report no V2 names, be reachable
   via each `compat.v1.<name>` path, and round-trip through its canonical
   name when V1 names are prefixed.
   """
   exported = tf_export.tf_export(v1=['nameA', 'nameB'])(_test_function)
   self.assertEqual(exported, _test_function)
   self.assertAllEqual(('nameA', 'nameB'), exported._tf_api_names_v1)
   self.assertAllEqual(['nameA', 'nameB'], tf_export.get_v1_names(exported))
   self.assertEqual([], tf_export.get_v2_names(exported))
   # Without a V2 name the symbol is registered only under the compat path.
   for short_name in ('nameA', 'nameB'):
     self.assertEqual(
         tf_export.get_symbol_from_name('compat.v1.' + short_name), exported)
   canonical = tf_export.get_canonical_name_for_symbol(
       exported, add_prefix_to_v1_names=True)
   self.assertEqual(tf_export.get_symbol_from_name(canonical), exported)
Example #3
0
  def from_config(cls, config, custom_objects=None):
    """Builds an instance from a config that names its class by TF symbol.

    Args:
      config: Mapping of constructor keyword arguments. Its `'cls_symbol'`
        entry holds the exported TensorFlow name to resolve. In a copy of the
        mapping, that entry is removed and a `'cls_ref'` entry holding the
        resolved object is added. The caller's mapping is left untouched.
      custom_objects: Accepted for signature compatibility; unused here.

    Returns:
      `cls(**kwargs)` built from the rewritten copy.

    Raises:
      ValueError: if the symbol name resolves to nothing (a falsy result).
      KeyError: propagated from `pop` when `'cls_symbol'` is absent.
    """
    kwargs = config.copy()
    symbol_name = kwargs.pop('cls_symbol')
    resolved = get_symbol_from_name(symbol_name)
    if resolved:
      kwargs['cls_ref'] = resolved
      return cls(**kwargs)
    raise ValueError(f'TensorFlow symbol `{symbol_name}` could not be found.')
Example #4
0
  def from_config(cls, config, custom_objects=None):
    """Builds an instance whose `'function'` entry is a TF symbol name.

    Args:
      config: Mapping of constructor keyword arguments. In a copy of it, the
        `'function'` entry (a symbol name) is replaced by the object that
        name resolves to. The caller's mapping is left untouched.
      custom_objects: Accepted for signature compatibility; unused here.

    Returns:
      `cls(**kwargs)` built from the rewritten copy.

    Raises:
      ValueError: if the symbol name resolves to nothing (a falsy result).
    """
    kwargs = config.copy()
    symbol_name = kwargs['function']
    resolved = get_symbol_from_name(symbol_name)
    if resolved:
      kwargs['function'] = resolved
      return cls(**kwargs)
    raise ValueError(f'TF symbol `{symbol_name}` could not be found.')
 def testExportMultipleFunctions(self):
   """Two decorators with disjoint names each register their own function.

   Each decorated function keeps its own recorded names, is reachable by one
   of them, and round-trips through its canonical name.
   """
   # Both decorators are constructed before either is applied.
   decorators = (tf_export.tf_export('nameA', 'nameB'),
                 tf_export.tf_export('nameC', 'nameD'))
   first = decorators[0](_test_function)
   second = decorators[1](_test_function2)
   self.assertEqual(first, _test_function)
   self.assertEqual(second, _test_function2)
   self.assertEqual(('nameA', 'nameB'), first._tf_api_names)
   self.assertEqual(('nameC', 'nameD'), second._tf_api_names)
   self.assertEqual(tf_export.get_symbol_from_name('nameB'), first)
   self.assertEqual(tf_export.get_symbol_from_name('nameD'), second)
   for exported in (first, second):
     canonical = tf_export.get_canonical_name_for_symbol(exported)
     self.assertEqual(tf_export.get_symbol_from_name(canonical), exported)
Example #6
0
def deserialize_keras_object(config_dict):
    """Retrieve the object by deserializing the config dict.

    The config dict is a python dictionary that consists of a set of key-value
    pairs, and represents a Keras object, such as an `Optimizer`, `Layer`,
    `Metrics`, etc. The saving and loading library uses the following keys to
    record information of a Keras object:

    - `class_name`: String. For classes that have an exported Keras namespace,
      this is the full path that starts with "keras", such as
      "keras.optimizers.Adam". For classes that do not have an exported Keras
      namespace, this is the name of the class, as exactly defined in the source
      code, such as "LossesContainer". Required.
    - `config`: Dict. Library-defined or user-defined key-value pairs that store
      the configuration of the object, as obtained by `object.get_config()`.
      Required. For plain functions (`module == "builtins"`,
      `class_name == "function"`), this dict holds the function's own
      "module" and "function_name".
    - `module`: String. The path of the python module, such as
      "keras.engine.compile_utils". Built-in Keras classes
      expect to have prefix `keras`. For classes that have an exported Keras
      namespace, this is `None` since the class can be fully identified by the
      full Keras path. May be omitted, which is treated as `None`.
    - `registered_name`: String. The key the class is registered under via
      `keras.utils.register_keras_serializable(package, name)` API. The key has
      the format of '{package}>{name}', where `package` and `name` are the
      arguments passed to `register_keras_serializable()`. If `name` is not
      provided, it defaults to the class name. If `registered_name` successfully
      resolves to a class (that was registered), the `class_name` and `module`
      values are not used; `config` is passed to that class's `from_config`.
      `registered_name` is only used for non-built-in classes. May be omitted,
      which is treated as `None`.

    For example, the following dictionary represents the built-in Adam optimizer
    with the relevant config. Note that for built-in (exported symbols that have
    an exported Keras namespace) classes, the library tracks the class by the
    import location of the built-in object in the Keras namespace, e.g.
    `"keras.optimizers.Adam"`, and this information is stored in `class_name`:

    ```
    dict_structure = {
        "class_name": "keras.optimizers.Adam",
        "config": {
            "amsgrad": false,
            "beta_1": 0.8999999761581421,
            "beta_2": 0.9990000128746033,
            "decay": 0.0,
            "epsilon": 1e-07,
            "learning_rate": 0.0010000000474974513,
            "name": "Adam"
        },
        "module": null,
        "registered_name": "Adam"
    }
    # Returns an `Adam` instance identical to the original one.
    deserialize_keras_object(dict_structure)
    ```

    If the class does not have an exported Keras namespace, the library tracks it
    by its `module` and `class_name`. For example:

    ```
    dict_structure = {
      "class_name": "LossesContainer",
      "config": {
          "losses": [...],
          "total_loss_mean": {...},
      },
      "module": "keras.engine.compile_utils",
      "registered_name": "LossesContainer"
    }

    # Returns a `LossesContainer` instance identical to the original one.
    deserialize_keras_object(dict_structure)
    ```

    And the following dictionary represents a user-customized `MeanSquaredError`
    loss:

    ```
    @keras.utils.generic_utils.register_keras_serializable(package='my_package')
    class ModifiedMeanSquaredError(keras.losses.MeanSquaredError):
      ...

    dict_structure = {
        "class_name": "ModifiedMeanSquaredError",
        "config": {
            "fn": "mean_squared_error",
            "name": "mean_squared_error",
            "reduction": "auto"
        },
        "registered_name": "my_package>ModifiedMeanSquaredError"
    }
    # Gives `ModifiedMeanSquaredError` object
    deserialize_keras_object(dict_structure)
    ```

    Args:
      config_dict: the python dict structure to deserialize the Keras object from.

    Returns:
      The Keras object that is deserialized from `config_dict`. For a builtin
      function whose module imports but lacks `function_name`, this is `None`.

    Raises:
      TypeError: if a string config is not a `str`, the builtin type is not
        recognized, or the target class cannot be located or has no
        `from_config`.
      ImportError: if the module recording a function or class cannot be
        imported, or a function config records no module.
    """
    # TODO(rchao): Design a 'version' key for `config_dict` for defining versions
    # for classes.
    class_name = config_dict["class_name"]
    config = config_dict["config"]
    # `module` and `registered_name` are optional: the registered-class example
    # in the docstring has no "module" key, and indexing would raise KeyError.
    module = config_dict.get("module", None)
    registered_name = config_dict.get("registered_name", None)

    # Strings and functions will have `builtins` as its module.
    if module == "builtins":
        if class_name == "str":
            if not isinstance(config, str):
                raise TypeError("Config of string is supposed to be a string. "
                                f"Received: {config}.")
            return config

        elif class_name == "function":
            custom_function = generic_utils.get_custom_objects_by_name(
                registered_name)
            if custom_function is not None:
                # If there is a custom function registered (via
                # `register_keras_serializable` API), that takes precedence.
                return custom_function

            # Otherwise, attempt to import the tracked module, and find the function.
            function_module_name = config.get("module", None)
            if function_module_name is None:
                # Without this check `importlib` fails with an opaque
                # AttributeError on `None`.
                raise ImportError(
                    "The function config does not record a module to import. "
                    f"The config dictionary provided is {config_dict}.")
            try:
                function_module = importlib.import_module(function_module_name)
            except ImportError as e:
                raise ImportError(
                    f"The function module {function_module_name} is not available. The "
                    f"config dictionary provided is {config_dict}.") from e
            # NOTE(review): yields None when the module lacks `function_name`.
            return vars(function_module).get(config["function_name"])

        raise TypeError(f"Unrecognized type: {class_name}")

    custom_class = generic_utils.get_custom_objects_by_name(registered_name)
    if custom_class is not None:
        # For others (classes), see if there is a custom class registered (via
        # `register_keras_serializable` API). If so, that takes precedence.
        return custom_class.from_config(config)

    # Otherwise, attempt to retrieve the class object given the `module`, and
    # `class_name`.
    if module is None:
        # In the case where `module` is not recorded, the `class_name` represents
        # the full exported Keras namespace (used by `keras_export`) such as
        # "keras.optimizers.Adam".
        cls = tf_export.get_symbol_from_name(class_name)
        location = "the exported Keras namespace"
    else:
        # In the case where `module` is available, the class does not have an
        # Keras namespace (which is the case when the symbol is not exported via
        # `keras_export`). Import the tracked module (that is used for the
        # internal path), find the class, and use its config.
        try:
            mod = importlib.import_module(module)
        except ImportError as e:
            raise ImportError(
                f"The module {module} is not available, so class {class_name} "
                f"cannot be deserialized. The config dictionary provided is "
                f"{config_dict}.") from e
        cls = vars(mod).get(class_name, None)
        location = f"module {module}"
    if cls is None:
        # Report what was looked up instead of "an instance of None".
        raise TypeError(
            f"Unable to reconstruct an instance of {class_name}: it could not "
            f"be found in {location}.")
    if not hasattr(cls, "from_config"):
        raise TypeError(f"Unable to reconstruct an instance of {cls}.")
    return cls.from_config(config)