def testExportSingleFunction(self):
    """Exporting one function under two names keeps it intact and registers both.

    Uses `assertEqual` throughout; the `assertEquals` alias is deprecated and
    was removed from `unittest` in Python 3.12.
    """
    export_decorator = tf_export.tf_export('nameA', 'nameB')
    decorated_function = export_decorator(_test_function)
    # The decorator must return the very same function object.
    self.assertEqual(decorated_function, _test_function)
    self.assertEqual(('nameA', 'nameB'), decorated_function._tf_api_names)
    # Positional names are expected to be reported for both V1 and V2.
    self.assertEqual(['nameA', 'nameB'],
                     tf_export.get_v1_names(decorated_function))
    self.assertEqual(['nameA', 'nameB'],
                     tf_export.get_v2_names(decorated_function))
    self.assertEqual(tf_export.get_symbol_from_name('nameA'),
                     decorated_function)
    self.assertEqual(tf_export.get_symbol_from_name('nameB'),
                     decorated_function)
    # The canonical name must round-trip back to the same symbol.
    self.assertEqual(
        tf_export.get_symbol_from_name(
            tf_export.get_canonical_name_for_symbol(decorated_function)),
        decorated_function)
def testExportSingleFunctionV1Only(self):
    """A function exported only via `v1=` gets V1 names and no V2 names.

    Checks that the decorated object is the original function, that both
    names land in `_tf_api_names_v1` and `get_v1_names`, that `get_v2_names`
    is empty, and that the symbol resolves under the `compat.v1.` prefix,
    including via its prefixed canonical name.
    """
    v1_only_export = tf_export.tf_export(v1=['nameA', 'nameB'])
    exported = v1_only_export(_test_function)
    self.assertEqual(exported, _test_function)

    self.assertAllEqual(('nameA', 'nameB'), exported._tf_api_names_v1)
    self.assertAllEqual(['nameA', 'nameB'], tf_export.get_v1_names(exported))
    self.assertEqual([], tf_export.get_v2_names(exported))

    for qualified_name in ('compat.v1.nameA', 'compat.v1.nameB'):
        self.assertEqual(tf_export.get_symbol_from_name(qualified_name),
                         exported)

    canonical = tf_export.get_canonical_name_for_symbol(
        exported, add_prefix_to_v1_names=True)
    self.assertEqual(tf_export.get_symbol_from_name(canonical), exported)
def from_config(cls, config, custom_objects=None):
    """Rebuild an instance, resolving the stored `cls_symbol` name to a symbol.

    Args:
      config: Mapping with a `cls_symbol` entry holding a symbol name that is
        looked up with `get_symbol_from_name`. That entry is removed and
        replaced by `cls_ref` (the resolved symbol); every other entry is
        forwarded as a keyword argument. The caller's mapping is copied, not
        mutated.
      custom_objects: Not used by this implementation.

    Returns:
      `cls(**kwargs)` built from the rewritten copy of `config`.

    Raises:
      ValueError: if the symbol name does not resolve to a truthy object.
    """
    kwargs = config.copy()
    symbol_name = kwargs.pop('cls_symbol')
    resolved = get_symbol_from_name(symbol_name)
    if resolved:
        kwargs['cls_ref'] = resolved
        return cls(**kwargs)
    raise ValueError(f'TensorFlow symbol `{symbol_name}` could not be found.')
def from_config(cls, config, custom_objects=None):
    """Rebuild an instance, replacing the stored `function` name with the symbol.

    Args:
      config: Mapping whose `function` entry holds a symbol name looked up
        with `get_symbol_from_name`. The entry is overwritten with the
        resolved symbol on a copy of the mapping; the caller's mapping is not
        mutated. All entries are forwarded as keyword arguments.
      custom_objects: Not used by this implementation.

    Returns:
      `cls(**kwargs)` built from the rewritten copy of `config`.

    Raises:
      ValueError: if the symbol name does not resolve to a truthy object.
    """
    kwargs = config.copy()
    function_name = kwargs['function']
    resolved_fn = get_symbol_from_name(function_name)
    if resolved_fn:
        kwargs['function'] = resolved_fn
        return cls(**kwargs)
    raise ValueError(f'TF symbol `{function_name}` could not be found.')
def testExportMultipleFunctions(self):
    """Two decorators with disjoint names export two functions independently.

    Each decorated function must come back unchanged, carry its own
    `_tf_api_names`, be retrievable by one of its exported names, and
    round-trip through its canonical name.
    """
    first_export = tf_export.tf_export('nameA', 'nameB')
    second_export = tf_export.tf_export('nameC', 'nameD')
    first = first_export(_test_function)
    second = second_export(_test_function2)

    # (decorated, original, expected _tf_api_names, name to look up)
    expectations = [
        (first, _test_function, ('nameA', 'nameB'), 'nameB'),
        (second, _test_function2, ('nameC', 'nameD'), 'nameD'),
    ]
    for exported, original, api_names, lookup_name in expectations:
        self.assertEqual(exported, original)
        self.assertEqual(api_names, exported._tf_api_names)
        self.assertEqual(tf_export.get_symbol_from_name(lookup_name), exported)
        canonical = tf_export.get_canonical_name_for_symbol(exported)
        self.assertEqual(tf_export.get_symbol_from_name(canonical), exported)
def deserialize_keras_object(config_dict):
    """Retrieve the object by deserializing the config dict.

    The config dict is a python dictionary that consists of a set of key-value
    pairs, and represents a Keras object, such as an `Optimizer`, `Layer`,
    `Metrics`, etc. The saving and loading library uses the following keys to
    record information of a Keras object:

    - `class_name`: String. For classes that have an exported Keras namespace,
      this is the full path that starts with "keras", such as
      "keras.optimizers.Adam". For classes that do not have an exported Keras
      namespace, this is the name of the class, as exactly defined in the
      source code, such as "LossesContainer".
    - `config`: Dict. Library-defined or user-defined key-value pairs that
      store the configuration of the object, as obtained by
      `object.get_config()`.
    - `module`: String. The path of the python module, such as
      "keras.engine.compile_utils". Built-in Keras classes expect to have
      prefix `keras`. For classes that have an exported Keras namespace, this
      is `None` since the class can be fully identified by the full Keras
      path. Optional: a missing `module` key is treated as `None`.
    - `registered_name`: String. The key the class is registered under via
      `keras.utils.register_keras_serializable(package, name)` API. The key
      has the format of '{package}>{name}', where `package` and `name` are the
      arguments passed to `register_keras_serializable()`. If `name` is not
      provided, it defaults to the class name. If `registered_name`
      successfully resolves to a class (that was registered), that class is
      rebuilt from `config` via its `from_config`, and `class_name` and
      `module` are not used. `registered_name` is only used for non-built-in
      classes. Optional: when it is missing or `None`, the registered-object
      lookup is skipped.

    For example, the following dictionary represents the built-in Adam
    optimizer with the relevant config. Note that for built-in (exported
    symbols that have an exported Keras namespace) classes, the library tracks
    the class by the import location of the built-in object in the Keras
    namespace, e.g. `"keras.optimizers.Adam"`, and this information is stored
    in `class_name`:

    ```
    dict_structure = {
        "class_name": "keras.optimizers.Adam",
        "config": {
            "amsgrad": false,
            "beta_1": 0.8999999761581421,
            "beta_2": 0.9990000128746033,
            "decay": 0.0,
            "epsilon": 1e-07,
            "learning_rate": 0.0010000000474974513,
            "name": "Adam"
        },
        "module": null,
        "registered_name": "Adam"
    }
    # Returns an `Adam` instance identical to the original one.
    deserialize_keras_object(dict_structure)
    ```

    If the class does not have an exported Keras namespace, the library tracks
    it by its `module` and `class_name`. For example:

    ```
    dict_structure = {
        "class_name": "LossesContainer",
        "config": {
            "losses": [...],
            "total_loss_mean": {...},
        },
        "module": "keras.engine.compile_utils",
        "registered_name": "LossesContainer"
    }
    # Returns a `LossesContainer` instance identical to the original one.
    deserialize_keras_object(dict_structure)
    ```

    And the following dictionary represents a user-customized
    `MeanSquaredError` loss:

    ```
    @keras.utils.generic_utils.register_keras_serializable(package='my_package')
    class ModifiedMeanSquaredError(keras.losses.MeanSquaredError):
        ...

    dict_structure = {
        "class_name": "ModifiedMeanSquaredError",
        "config": {
            "fn": "mean_squared_error",
            "name": "mean_squared_error",
            "reduction": "auto"
        },
        "registered_name": "my_package>ModifiedMeanSquaredError"
    }
    # Gives `ModifiedMeanSquaredError` object
    deserialize_keras_object(dict_structure)
    ```

    Args:
      config_dict: the python dict structure to deserialize the Keras object
        from. `class_name` and `config` are required; `module` and
        `registered_name` are optional.

    Returns:
      The Keras object that is deserialized from `config_dict`. For a
      `builtins` function whose tracked module lacks the named attribute,
      this is `None`.

    Raises:
      KeyError: if `class_name` or `config` is missing.
      TypeError: if a `builtins` string config is not a string, if a
        `builtins` `class_name` is not recognized, or if the resolved class
        has no `from_config` method.
      ValueError: if a `builtins` function config has no `module` entry.
      ImportError: if the tracked module cannot be imported.
    """
    # TODO(rchao): Design a 'version' key for `config_dict` for defining
    # versions for classes.
    class_name = config_dict["class_name"]
    config = config_dict["config"]
    # `module` and `registered_name` are optional; the registered custom
    # object example above omits `module`, so indexing would raise KeyError.
    module = config_dict.get("module")
    registered_name = config_dict.get("registered_name")

    # Strings and functions will have `builtins` as their module.
    if module == "builtins":
        if class_name == "str":
            if not isinstance(config, str):
                raise TypeError(
                    "Config of string is supposed to be a string. "
                    f"Received: {config}."
                )
            return config
        if class_name == "function":
            custom_function = None
            if registered_name is not None:
                custom_function = generic_utils.get_custom_objects_by_name(
                    registered_name
                )
            if custom_function is not None:
                # If there is a custom function registered (via
                # `register_keras_serializable` API), that takes precedence.
                return custom_function
            # Otherwise, attempt to import the tracked module, and find the
            # function.
            function_module_name = config.get("module", None)
            if function_module_name is None:
                # `importlib.import_module(None)` would fail with an unrelated
                # AttributeError; report the malformed config instead.
                raise ValueError(
                    "Cannot deserialize function: its config has no `module` "
                    f"entry. The config dictionary provided is {config_dict}."
                )
            try:
                function_module = importlib.import_module(function_module_name)
            except ImportError as e:
                raise ImportError(
                    f"The function module {function_module_name} is not "
                    "available. The config dictionary provided is "
                    f"{config_dict}."
                ) from e
            # NOTE(review): yields None if the module has no attribute with
            # this name; callers presumably expect a callable.
            return vars(function_module).get(config["function_name"])
        raise TypeError(f"Unrecognized type: {class_name}")

    custom_class = None
    if registered_name is not None:
        custom_class = generic_utils.get_custom_objects_by_name(registered_name)
    if custom_class is not None:
        # For others (classes), see if there is a custom class registered (via
        # `register_keras_serializable` API). If so, that takes precedence.
        return custom_class.from_config(config)

    # Otherwise, attempt to retrieve the class object given the `module`, and
    # `class_name`.
    if module is None:
        # In the case where `module` is not recorded, the `class_name`
        # represents the full exported Keras namespace (used by
        # `keras_export`) such as "keras.optimizers.Adam".
        cls = tf_export.get_symbol_from_name(class_name)
    else:
        # In the case where `module` is available, the class does not have a
        # Keras namespace (which is the case when the symbol is not exported
        # via `keras_export`). Import the tracked module (that is used for the
        # internal path), find the class, and use its config.
        mod = importlib.import_module(module)
        cls = vars(mod).get(class_name, None)
    if not hasattr(cls, "from_config"):
        # `cls` is None when the lookup failed, so name what was requested.
        raise TypeError(
            f"Unable to reconstruct an instance of {cls}. Could not find "
            f"class `{class_name}` with a `from_config` method "
            f"(module: {module})."
        )
    return cls.from_config(config)