def test_encode_decode_type_spec(self):
    spec = tf.TensorSpec((1, 5), tf.float32)
    string = json_utils.Encoder().encode(spec)
    loaded = json_utils.decode(string)
    self.assertEqual(spec, loaded)

    invalid_type_spec = {
        "class_name": "TypeSpec",
        "type_spec": "Invalid Type",
        "serialized": None,
    }
    string = json_utils.Encoder().encode(invalid_type_spec)
    with self.assertRaisesRegex(
        ValueError, "No TypeSpec has been registered"
    ):
        loaded = json_utils.decode(string)
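# A minimal standalone sketch of the round trip exercised above. The import
# path for `json_utils` is assumed; it has moved between Keras releases.
import tensorflow as tf

from keras.saving.saved_model import json_utils  # path assumed

spec = tf.TensorSpec((1, 5), tf.float32)
payload = json_utils.Encoder().encode(spec)  # plain JSON string
restored = json_utils.decode(payload)
assert restored == spec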
def test_saved_module_paths_and_class_names(self):
    temp_dir = os.path.join(self.get_temp_dir(), "my_model")
    subclassed_model = self._get_subclassed_model()
    x = np.random.random((100, 32))
    y = np.random.random((100, 1))
    subclassed_model.fit(x, y, epochs=1)
    subclassed_model._save_new(temp_dir)

    file_path = os.path.join(temp_dir, saving_lib._CONFIG_FILE)
    with tf.io.gfile.GFile(file_path, "r") as f:
        config_json = f.read()
    config_dict = json_utils.decode(config_json)
    self.assertEqual(
        config_dict["registered_name"], "my_custom_package>CustomModelX"
    )
    self.assertIsNone(config_dict["config"]["optimizer"]["module"])
    self.assertEqual(
        config_dict["config"]["optimizer"]["class_name"],
        "keras.optimizers.Adam",
    )
    self.assertEqual(
        config_dict["config"]["loss"]["module"],
        "keras.engine.compile_utils",
    )
    self.assertEqual(
        config_dict["config"]["loss"]["class_name"], "LossesContainer"
    )
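# Hedged sketch of where a `registered_name` like
# 'my_custom_package>CustomModelX' comes from: registering a class under a
# package yields a '<package>><class_name>' lookup key. `CustomModelX` below
# is a hypothetical stand-in for the test's real model class.
import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package="my_custom_package")
class CustomModelX(tf.keras.Model):
    def call(self, inputs):
        return inputs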
def load(dirpath):
    """Load a saved python model."""
    file_path = os.path.join(dirpath, _CONFIG_FILE)
    with tf.io.gfile.GFile(file_path, "r") as f:
        config_json = f.read()
    config_dict = json_utils.decode(config_json)
    return deserialize_keras_object(config_dict)
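# Hypothetical usage of `load`; "/tmp/my_model" is an illustrative path. Note
# that this function only rebuilds the architecture from the config JSON; it
# does not itself restore weights.
model = load("/tmp/my_model")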
def _load_layer(self, node_id, identifier, metadata):
  """Load a single layer from a SavedUserObject proto."""
  metadata = json_utils.decode(metadata)

  # If node was already created
  if node_id in self.loaded_nodes:
    node, setter = self.loaded_nodes[node_id]

    # Revive setter requires the object to have a `_serialized_attributes`
    # property. Add it here.
    _maybe_add_serialized_attributes(node, metadata)

    config = metadata.get('config')
    if _is_graph_network(node) and generic_utils.validate_config(config):
      child_nodes = self._get_child_layer_node_ids(node_id)
      self.model_layer_dependencies[node_id] = (node, child_nodes)
      if not child_nodes:
        self._models_to_reconstruct.append(node_id)
    return node, setter

  # Detect whether this object can be revived from the config. If not, then
  # revive from the SavedModel instead.
  obj, setter = self._revive_from_config(identifier, metadata, node_id)
  if obj is None:
    obj, setter = revive_custom_object(identifier, metadata)

  # Add an attribute that stores the extra functions/objects saved in the
  # SavedModel. Most of these functions/objects are ignored, but some are
  # used later in the loading process (e.g. the list of regularization
  # losses, or the training config of compiled models).
  _maybe_add_serialized_attributes(obj, metadata)
  return obj, setter
def test_encode_decode_tuple(self):
  metadata = {'key1': (3, 5), 'key2': [(1, (3, 4)), (1,)]}
  string = json_utils.Encoder().encode(metadata)
  loaded = json_utils.decode(string)

  self.assertEqual(set(loaded.keys()), {'key1', 'key2'})
  self.assertAllEqual(loaded['key1'], (3, 5))
  self.assertAllEqual(loaded['key2'], [(1, (3, 4)), (1,)])
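# Why a custom encoder/decoder pair is needed at all: plain JSON has no tuple
# type, so the standard library silently turns tuples into lists.
import json

assert json.loads(json.dumps({"key1": (3, 5)})) == {"key1": [3, 5]}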
def test_encode_decode_enum(self):
    class Enum(enum.Enum):
        CLASS_A = "a"
        CLASS_B = "b"

    config = {"key": Enum.CLASS_A, "key2": Enum.CLASS_B}
    string = json_utils.Encoder().encode(config)
    loaded = json_utils.decode(string)
    self.assertAllEqual({"key": "a", "key2": "b"}, loaded)
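# A minimal sketch of enum handling consistent with the test above: a
# `json.JSONEncoder` subclass that falls back to the enum's value. This
# illustrates the observed behavior only; it is not the actual
# `json_utils.Encoder` implementation.
import enum
import json


class EnumAwareEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, enum.Enum):
            return obj.value  # serialize by value, matching the assertion above
        return super().default(obj)


class Color(enum.Enum):
    RED = "red"


assert EnumAwareEncoder().encode({"key": Color.RED}) == '{"key": "red"}'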
def test_encode_decode_extension_type_tensor(self):
  class MaskedTensor(tf.experimental.ExtensionType):
    __name__ = 'MaskedTensor'
    values: tf.Tensor
    mask: tf.Tensor

  x = MaskedTensor(
      values=[[1, 2, 3], [4, 5, 6]],
      mask=[[True, True, False], [True, False, True]])
  string = json_utils.Encoder().encode(x)
  loaded = json_utils.decode(string)
  self.assertAllEqual(loaded, x)
def test_encode_decode_tensor_shape(self):
    metadata = {
        "key1": tf.TensorShape(None),
        "key2": [tf.TensorShape([None]), tf.TensorShape([3, None, 5])],
    }
    string = json_utils.Encoder().encode(metadata)
    loaded = json_utils.decode(string)

    self.assertEqual(set(loaded.keys()), {"key1", "key2"})
    self.assertAllEqual(loaded["key1"].rank, None)
    self.assertAllEqual(loaded["key2"][0].as_list(), [None])
    self.assertAllEqual(loaded["key2"][1].as_list(), [3, None, 5])
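# What the round trip above has to preserve: `None` marks an unknown
# dimension, and a fully unknown shape has rank None.
import tensorflow as tf

assert tf.TensorShape(None).rank is None
assert tf.TensorShape([3, None, 5]).as_list() == [3, None, 5]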
def _update_to_current_version(metadata):
  """Applies version updates to the metadata proto for backwards compat."""
  for node in metadata.nodes:
    if node.version.producer == 1 and node.identifier in [
        constants.MODEL_IDENTIFIER, constants.SEQUENTIAL_IDENTIFIER,
        constants.NETWORK_IDENTIFIER]:
      node_metadata = json_utils.decode(node.metadata)
      save_spec = node_metadata.get('save_spec')

      if save_spec is not None:
        node_metadata['full_save_spec'] = ([save_spec], {})
        node.metadata = json_utils.Encoder().encode(node_metadata)
  return metadata
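# The upgrade rule in isolation: a producer-1 `save_spec` held a single input
# spec, while `full_save_spec` stores what appears to be a (positional args,
# keyword args) pair, so the old spec becomes a one-element args list plus
# empty kwargs. "SPEC" stands in for a real serialized TensorSpec.
old_metadata = {"save_spec": "SPEC"}
save_spec = old_metadata.get("save_spec")
if save_spec is not None:
    old_metadata["full_save_spec"] = ([save_spec], {})
assert old_metadata["full_save_spec"] == (["SPEC"], {})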
def testAddFullSaveSpec(self):
  save_spec = tf.TensorSpec([3, 5], dtype=tf.int32)
  node_metadata = json_utils.Encoder().encode({'save_spec': save_spec})

  metadata = saved_metadata_pb2.SavedMetadata()
  metadata.nodes.add(
      version=versions_pb2.VersionDef(
          producer=1, min_consumer=1, bad_consumers=[]),
      identifier='_tf_keras_model',
      metadata=node_metadata)

  # pylint: disable=protected-access
  new_metadata = keras_load._update_to_current_version(metadata)
  node_metadata = json_utils.decode(new_metadata.nodes[0].metadata)

  expected_full_spec = ([tf.TensorSpec(shape=(3, 5), dtype=tf.int32)], {})
  self.assertAllEqual(expected_full_spec,
                      node_metadata.get('full_save_spec'))
def _reconstruct_model(self, model_id, model, layers):
  """Reconstructs the network structure."""
  config = json_utils.decode(
      self._proto.nodes[model_id].user_object.metadata)['config']

  # Set up model inputs
  if model.inputs:
    # Inputs may already be created if the model is instantiated in another
    # object's __init__.
    pass
  elif isinstance(model, models_lib.Sequential):
    if not layers or not isinstance(layers[0], input_layer.InputLayer):
      if config['layers'][0]['class_name'] == 'InputLayer':
        layers.insert(0, input_layer.InputLayer.from_config(
            config['layers'][0]['config']))
      elif 'batch_input_shape' in config['layers'][0]['config']:
        batch_input_shape = config['layers'][0]['config']['batch_input_shape']
        layers.insert(0, input_layer.InputLayer(
            input_shape=batch_input_shape[1:],
            batch_size=batch_input_shape[0],
            dtype=layers[0].dtype,
            name=layers[0].name + '_input'))
    model.__init__(layers, name=config['name'])
    if not model.inputs:
      first_layer = self._get_child_layer_node_ids(model_id)[0]
      input_specs = self._infer_inputs(first_layer)
      input_shapes = self._infer_inputs(first_layer, convert_to_shapes=True)
      model._set_inputs(input_specs)  # pylint: disable=protected-access
      if not model.built and not isinstance(input_specs, dict):
        model.build(input_shapes)
  else:
    # Reconstruct functional model
    (inputs, outputs,
     created_layers) = functional_lib.reconstruct_from_config(
         config, created_layers={layer.name: layer for layer in layers})
    model.__init__(inputs, outputs, name=config['name'])
    functional_lib.connect_ancillary_layers(model, created_layers)

  # Set model dtype.
  _set_network_attributes_from_metadata(model)

  # Unblock models that are dependent on this model.
  self._unblock_model_reconstruction(model_id, model)
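# The Sequential branch above in isolation: a hedged sketch of constructing
# the implicit InputLayer from a `batch_input_shape` entry, with made-up
# values.
import tensorflow as tf

batch_input_shape = (None, 32)  # (batch size, features); illustrative only
implicit_input = tf.keras.layers.InputLayer(
    input_shape=batch_input_shape[1:],  # drop the batch dimension
    batch_size=batch_input_shape[0],
)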
def model_from_json(json_string, custom_objects=None):
  """Parses a JSON model configuration string and returns a model instance.

  Usage:

  >>> model = tf.keras.Sequential([
  ...     tf.keras.layers.Dense(5, input_shape=(3,)),
  ...     tf.keras.layers.Softmax()])
  >>> config = model.to_json()
  >>> loaded_model = tf.keras.models.model_from_json(config)

  Args:
      json_string: JSON string encoding a model configuration.
      custom_objects: Optional dictionary mapping names (strings) to custom
          classes or functions to be considered during deserialization.

  Returns:
      A Keras model instance (uncompiled).
  """
  config = json_utils.decode(json_string)
  from keras.layers import deserialize  # pylint: disable=g-import-not-at-top
  return deserialize(config, custom_objects=custom_objects)
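# Round trip with a custom layer: without `custom_objects`, deserialization
# cannot resolve the class name. `MyDense` is a stand-in custom layer.
import tensorflow as tf


class MyDense(tf.keras.layers.Dense):
    pass


model = tf.keras.Sequential([MyDense(5, input_shape=(3,))])
restored = tf.keras.models.model_from_json(
    model.to_json(), custom_objects={"MyDense": MyDense}
)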
def _add_children_recreated_from_config(self, obj, proto, node_id):
  """Recursively records objects recreated from config."""
  # pylint: disable=protected-access
  if node_id in self._traversed_nodes_from_config:
    return

  parent_path = self._node_paths[node_id]
  self._traversed_nodes_from_config.add(node_id)
  obj._maybe_initialize_trackable()

  if isinstance(obj, base_layer.Layer) and not obj.built:
    metadata = json_utils.decode(proto.user_object.metadata)
    self._try_build_layer(obj, node_id, metadata.get('build_input_shape'))

  # Create list of all possible children
  children = []

  # Look for direct children
  for reference in proto.children:
    obj_child = obj._lookup_dependency(reference.local_name)
    children.append((obj_child, reference.node_id, reference.local_name))

  # Add metrics that may have been added to the layer._metrics list.
  # This is stored in the SavedModel as layer.keras_api.layer_metrics in
  # SavedModels created after Tf 2.2.
  metric_list_node_id = self._search_for_child_node(
      node_id, [constants.KERAS_ATTR, 'layer_metrics'])
  if metric_list_node_id is not None and hasattr(obj, '_metrics'):
    obj_metrics = {m.name: m for m in obj._metrics}
    for reference in self._proto.nodes[metric_list_node_id].children:
      metric = obj_metrics.get(reference.local_name)
      if metric is not None:
        metric_path = '{}.layer_metrics.{}'.format(constants.KERAS_ATTR,
                                                   reference.local_name)
        children.append((metric, reference.node_id, metric_path))

  for (obj_child, child_id, child_name) in children:
    child_proto = self._proto.nodes[child_id]

    if not isinstance(obj_child, tf.__internal__.tracking.Trackable):
      continue
    if (child_proto.user_object.identifier in
        revived_types.registered_identifiers()):
      setter = revived_types.get_setter(child_proto.user_object)
    elif obj_child._object_identifier in constants.KERAS_OBJECT_IDENTIFIERS:
      setter = _revive_setter
    else:
      setter = setattr
      # pylint: enable=protected-access

    if child_id in self.loaded_nodes:
      if self.loaded_nodes[child_id][0] is not obj_child:
        # This means that the same trackable object is referenced by two
        # different objects that were recreated from the config.
        logging.warning('Looks like there is an object (perhaps variable or '
                        'layer) that is shared between different layers/models. '
                        'This may cause issues when restoring the variable '
                        'values. Object: {}'.format(obj_child))
      continue

    # Overwrite variable names with the ones saved in the SavedModel.
    if (child_proto.WhichOneof('kind') == 'variable' and
        child_proto.variable.name):
      obj_child._handle_name = child_proto.variable.name + ':0'  # pylint: disable=protected-access

    if isinstance(obj_child,
                  tf.__internal__.tracking.TrackableDataStructure):
      setter = lambda *args: None

    child_path = '{}.{}'.format(parent_path, child_name)
    self._node_paths[child_id] = child_path
    self._add_children_recreated_from_config(
        obj_child, child_proto, child_id)
    self.loaded_nodes[child_id] = obj_child, setter
def test_encode_decode_ragged_tensor(self):
  x = tf.ragged.constant([[1.0, 2.0], [3.0]])
  string = json_utils.Encoder().encode(x)
  loaded = json_utils.decode(string)
  self.assertAllEqual(loaded, x)
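# One natural decomposition an encoder can capture for the ragged round trip
# above: the flat values plus the row partition.
import tensorflow as tf

x = tf.ragged.constant([[1.0, 2.0], [3.0]])
print(x.values)      # [1. 2. 3.]: the flat values
print(x.row_splits)  # [0 2 3]: row i spans values[row_splits[i]:row_splits[i+1]]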
def load_model_from_hdf5(filepath, custom_objects=None, compile=True):  # pylint: disable=redefined-builtin
  """Loads a model saved via `save_model_to_hdf5`.

  Args:
      filepath: One of the following:
          - String, path to the saved model
          - `h5py.File` object from which to load the model
      custom_objects: Optional dictionary mapping names (strings) to custom
          classes or functions to be considered during deserialization.
      compile: Boolean, whether to compile the model after loading.

  Returns:
      A Keras model instance. If an optimizer was found as part of the saved
      model, the model is already compiled. Otherwise, the model is uncompiled
      and a warning will be displayed. When `compile` is set to False, the
      compilation is omitted without any warning.

  Raises:
      ImportError: if h5py is not available.
      ValueError: In case of an invalid savefile.
  """
  if h5py is None:
    raise ImportError('`load_model()` using h5 format requires h5py. Could '
                      'not import h5py.')

  if not custom_objects:
    custom_objects = {}

  opened_new_file = not isinstance(filepath, h5py.File)
  if opened_new_file:
    f = h5py.File(filepath, mode='r')
  else:
    f = filepath

  model = None
  try:
    # instantiate model
    model_config = f.attrs.get('model_config')
    if model_config is None:
      raise ValueError(f'No model config found in the file at {filepath}.')
    if hasattr(model_config, 'decode'):
      model_config = model_config.decode('utf-8')
    model_config = json_utils.decode(model_config)
    model = model_config_lib.model_from_config(model_config,
                                               custom_objects=custom_objects)

    # set weights
    load_weights_from_hdf5_group(f['model_weights'], model)

    if compile:
      # instantiate optimizer
      training_config = f.attrs.get('training_config')
      if hasattr(training_config, 'decode'):
        training_config = training_config.decode('utf-8')
      if training_config is None:
        logging.warning('No training configuration found in the save file, '
                        'so the model was *not* compiled. Compile it '
                        'manually.')
        return model
      training_config = json_utils.decode(training_config)

      # Compile model.
      model.compile(**saving_utils.compile_args_from_training_config(
          training_config, custom_objects), from_serialized=True)
      saving_utils.try_build_compiled_arguments(model)

      # Set optimizer weights.
      if 'optimizer_weights' in f:
        try:
          model.optimizer._create_all_weights(model.trainable_variables)
        except (NotImplementedError, AttributeError):
          logging.warning(
              'Error when creating the weights of the optimizer, making it '
              'impossible to restore the saved optimizer state. As a result, '
              'your model is starting with a freshly initialized optimizer.')

        optimizer_weight_values = load_optimizer_weights_from_hdf5_group(f)
        try:
          model.optimizer.set_weights(optimizer_weight_values)
        except ValueError:
          logging.warning('Error in loading the saved optimizer '
                          'state. As a result, your model is '
                          'starting with a freshly initialized '
                          'optimizer.')
  finally:
    if opened_new_file:
      f.close()
  return model
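# Hedged usage sketch: the public entry point routes `.h5` files to the loader
# above. "model.h5" is an illustrative path.
import tensorflow as tf

# Compiled, if the file carries a training config:
model = tf.keras.models.load_model("model.h5")
# Uncompiled, skipping optimizer-state restoration entirely:
inference_only = tf.keras.models.load_model("model.h5", compile=False)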