def load_optimizer_weights(model: Model, load_path: pathlib.Path) -> None:
    """
    Load the optimizer states from a tf.keras model saved with
    tf.keras.models.save_model(). This implementation is lifted from
    tf.keras.models.load_model().

    Does nothing if the file has no ``optimizer_weights`` group. If the model
    is not a graph network (e.g. a Sequential model built without an
    ``input_shape``), a warning is logged and the optimizer is left freshly
    initialized. A ``ValueError`` raised by ``set_weights`` is likewise
    downgraded to a warning.

    Args:
        model: Compiled model whose ``optimizer`` receives the saved state.
        load_path: Path to the HDF5 file written by ``save_model()``.
    """
    # Context manager guarantees the HDF5 handle is released on every path.
    with h5py.File(str(load_path), mode="r") as f:
        if "optimizer_weights" not in f:
            return
        if not model._is_graph_network:  # pylint: disable=protected-access
            logging.warning("Sequential models without an `input_shape` "
                            "passed to the first layer cannot reload their "
                            "optimizer state. As a result, your model is "
                            "starting with a freshly initialized optimizer.")
            return
        # Build train function (to get weight updates). Models that aren't
        # graph networks must wait until they are called with data to
        # _make_train_function() and so can't load optimizer weights.
        model._make_train_function()  # pylint: disable=protected-access
        optimizer_weight_values = load_optimizer_weights_from_hdf5_group(f)
        # Applied while the file is still open, in case the loaded values are
        # lazily-read datasets backed by ``f``.
        try:
            model.optimizer.set_weights(optimizer_weight_values)
        except ValueError:
            logging.warning("Error in loading the saved optimizer "
                            "state. As a result, your model is "
                            "starting with a freshly initialized "
                            "optimizer.")
def _load_weights(self, checkpoint, gen_or_disc):
    """Restore generator or discriminator weights from an HDF5 checkpoint.

    If the checkpoint also stores optimizer state, that state is restored into
    the network's optimizer. When the optimizer has no weights yet, a single
    step on zero tensors is run first so its weights exist before
    ``set_weights`` is called.

    Args:
        checkpoint: Path to the HDF5 checkpoint file.
        gen_or_disc: ``'gen'`` to restore ``self.generator`` or ``'disc'`` to
            restore ``self.discriminator``.

    Raises:
        ValueError: If ``gen_or_disc`` is neither ``'gen'`` nor ``'disc'``.
    """
    if gen_or_disc == 'gen':
        network = self.generator
        step_fn = self.gen_step
    elif gen_or_disc == 'disc':
        network = self.discriminator
        step_fn = self.disc_step
    else:
        raise ValueError(gen_or_disc)

    # Context manager ensures the checkpoint file handle is closed.
    with h5py.File(checkpoint, 'r') as model_file:
        has_optimizer_state = 'optimizer_weights' in model_file
        if not network.optimizer.weights and has_optimizer_state:
            # perform single optimization step to init optimizer weights.
            # NOTE(review): both step functions are fed dummy tensors shaped
            # like the discriminator's two inputs — presumably intentional.
            features_shape = self.discriminator.inputs[0].shape.as_list()
            targets_shape = self.discriminator.inputs[1].shape.as_list()
            features_shape[0], targets_shape[0] = 1, 1
            step_fn(tf.zeros(features_shape), tf.zeros(targets_shape))

        print(f'Loading {gen_or_disc} weights from {str(checkpoint)}')
        network.load_weights(str(checkpoint))
        if has_optimizer_state:
            print('Also recovering the optimizer state')
            opt_weight_values = (
                hdf5_format.load_optimizer_weights_from_hdf5_group(model_file))
            network.optimizer.set_weights(opt_weight_values)
def load_optimizer_weights(model, filepath):
    """Loads optimizer weights into a compiled model from an HDF5 file.

    Arguments:
        model: Compiled model whose ``optimizer`` receives the saved state.
            Nothing is loaded if ``model.optimizer`` is falsy or the file has
            no ``optimizer_weights`` group.
        filepath: Path to an HDF5 file, or an already-open ``h5py.File``.
            A file opened here is closed before returning; a caller-supplied
            ``h5py.File`` is left open.

    Returns:
        The same ``model`` object that was passed in.
    """
    opened_new_file = not isinstance(filepath, h5py.File)
    f = h5py.File(filepath, mode='r') if opened_new_file else filepath
    try:
        if model.optimizer and 'optimizer_weights' in f:
            try:
                # Optimizer weights must exist before set_weights can fill them.
                model.optimizer._create_all_weights(model.trainable_variables)  # pylint: disable=protected-access
            except (NotImplementedError, AttributeError):
                # The placeholder was previously never filled in.
                logging.warning(
                    ('Error when creating the weights of optimizer {}, making it '
                     'impossible to restore the saved optimizer state. As a result, '
                     'your model is starting with a freshly initialized optimizer.'
                     ).format(model.optimizer))
            optimizer_weight_values = (
                hdf5_format.load_optimizer_weights_from_hdf5_group(f))
            try:
                model.optimizer.set_weights(optimizer_weight_values)
            except ValueError:
                logging.warning('Error in loading the saved optimizer '
                                'state. As a result, your model is '
                                'starting with a freshly initialized '
                                'optimizer.')
    finally:
        if opened_new_file:
            f.close()
    return model
def restore_optimizer(self, checkpoint=LATEST):
    """Restore saved optimizer state into ``self.model.optimizer``.

    Args:
        checkpoint: Checkpoint selector resolved through ``self.checkpoint()``;
            defaults to ``LATEST``. Nothing is restored when it resolves to
            ``None``.
    """
    # pylint: disable=import-error
    from tensorflow.python.keras.saving.hdf5_format import \
        load_optimizer_weights_from_hdf5_group
    # pylint: enable=import-error
    import h5py

    checkpoint = self.checkpoint(checkpoint)
    if checkpoint is None:
        return
    # Open read-only explicitly: before h5py 3.0 the default mode was 'a',
    # which requires write access and silently creates a missing file.
    with h5py.File(checkpoint, mode='r') as f:
        optimizer_weight_values = load_optimizer_weights_from_hdf5_group(f)
        # Applied while ``f`` is open, in case the values are lazy datasets.
        self.model.optimizer.set_weights(optimizer_weight_values)
def load_optimizer_from_hdf5(filepath, model):
    """Load optimizer state from an HDF5 file into ``model.optimizer``.

    Does nothing if the file has no ``optimizer_weights`` group. Failures to
    create the optimizer's weights, or a ``ValueError`` from ``set_weights``,
    are logged as warnings rather than raised.

    Args:
        filepath: Path to the HDF5 file to read.
        model: Model whose ``optimizer`` receives the saved state.
    """
    with h5py.File(filepath, mode='r') as f:
        if 'optimizer_weights' not in f:
            return
        try:
            # Optimizer weights must exist before set_weights can fill them.
            model.optimizer._create_all_weights(model.trainable_variables)  # pylint: disable=protected-access
        except (NotImplementedError, AttributeError):
            # The placeholder was previously logged literally as '{}'.
            logging.warning(
                'Error when creating the weights of optimizer {}'.format(
                    model.optimizer))
        optimizer_weight_values = load_optimizer_weights_from_hdf5_group(f)
        try:
            model.optimizer.set_weights(optimizer_weight_values)
        except ValueError:
            logging.warning('Error in loading the saved optimizer state.')
def load_optimizer_weights(
    model: Model, h5group: Any, optimizer: tf.keras.optimizers.Optimizer
) -> None:
    """
    Load the optimizer states from a tf.keras model saved with
    tf.keras.models.save_model(). This implementation is lifted from
    tf.keras.models.load_model().

    On TensorFlow >= 2.2 the optimizer's weights are created from
    ``model.trainable_variables``; on older versions the model's train
    function is built instead, which is only possible for graph networks.
    On TensorFlow < 2.2 with a non-graph-network model, a warning is logged
    and the optimizer is left freshly initialized. A ``ValueError`` from
    ``optimizer.set_weights`` is likewise downgraded to a warning.

    Args:
        model: Model whose trainable variables the optimizer updates.
        h5group: HDF5 group passed to ``load_optimizer_weights_from_hdf5_group``
            (presumably the root of a file written by ``save_model()``).
        optimizer: Optimizer that receives the saved state.
    """
    tf2_2_or_newer = version.parse(tf.__version__) >= version.parse("2.2.0")
    # Check the version first so ``_is_graph_network`` is only consulted on
    # the legacy path, where it actually decides whether loading is possible.
    if not (tf2_2_or_newer or model._is_graph_network):  # pylint: disable=protected-access
        logging.warning(
            "Sequential models without an `input_shape` "
            "passed to the first layer cannot reload their "
            "optimizer state. As a result, your model is "
            "starting with a freshly initialized optimizer."
        )
        return

    if tf2_2_or_newer:
        try:
            optimizer._create_all_weights(model.trainable_variables)  # pylint: disable=protected-access
        except (NotImplementedError, AttributeError):
            logging.warning(
                "Error when creating the weights of optimizer, making it "
                "impossible to restore the saved optimizer state. As a result, "
                "your model is starting with a freshly initialized optimizer."
            )
    else:
        # Build train function (to get weight updates). Models that aren't
        # graph networks must wait until they are called with data to
        # _make_train_function() and so can't load optimizer weights.
        model._make_train_function()  # pylint: disable=protected-access

    optimizer_weight_values = load_optimizer_weights_from_hdf5_group(h5group)
    try:
        optimizer.set_weights(optimizer_weight_values)
    except ValueError:
        logging.warning(
            "Error in loading the saved optimizer "
            "state. As a result, your model is "
            "starting with a freshly initialized "
            "optimizer."
        )
def _set_optimizer_state(self, optimizer, state, optimizer_name):
    """Load optimizer weights from the HDF5 file ``state`` into ``optimizer``.

    Args:
        optimizer: Optimizer whose ``set_weights`` receives the saved values.
        state: HDF5 file (path or file-like) containing the saved optimizer
            weights.
        optimizer_name: Not used by this implementation; kept for interface
            compatibility.
    """
    with h5py.File(state, 'r') as f:
        weights = hdf5_format.load_optimizer_weights_from_hdf5_group(f)
        # Apply while the file is still open: the loader may return lazily
        # read h5py datasets that become unreadable once ``f`` is closed.
        optimizer.set_weights(weights)
def load_model(filepath, custom_objects=None, optimizer_name=None):
    """Loads and compiles a Keras model saved as an HDF5 file.

    Same as tf.keras.models.load_model, except it will always compile the model
    and instantiate the Kfac optimizer correctly. If you do not want the model
    to be compiled, or it was saved without the optimizer, use
    tf.keras.models.load_model instead.

    Example:
    ```python:
    import tensorflow as tf
    import kfac

    model = tf.keras.Model(...)
    loss = tf.keras.losses.MSE()  # could be a serialized loss function
    optimizer = kfac.keras.optimizers.Kfac(0.001, 0.01, model=model, loss=loss)
    model.compile(optimizer, loss)
    model.fit(...)
    model.save('saved_model.hdf5')  # or use tf.keras.models.save_model
    ...
    loaded_model = kfac.keras.saving_utils.load_model('saved_model.hdf5')
    loaded_model.fit(...)
    ```

    Args:
      filepath: One of the following:
        - String, path to the saved model
        - `h5py.File` object from which to load the model. A caller-supplied
          file is left open; a file opened here is closed before returning.
      custom_objects: Optional dictionary mapping names (strings) to custom
        classes or functions to be considered during deserialization. Kfac
        is added to a copy of this dictionary automatically; the caller's
        dictionary is not modified.
      optimizer_name: Optional string that specifies what variable scope you
        want the KFAC variables to be created in. Useful if you have multiple
        KFAC optimizers on one graph.

    Raises:
      ImportError: If h5py was not imported.
      ValueError: If the file has no training configuration (the model was
        saved uncompiled).

    Returns:
      A compiled Keras model with the Kfac optimizer correctly initialized.
    """
    if h5py is None:
        raise ImportError('`load_model` requires h5py.')

    # Copy so the caller's dictionary is not mutated by the 'Kfac' entry.
    custom_objects = dict(custom_objects or {})
    custom_objects['Kfac'] = optimizers.Kfac

    should_open_file = not isinstance(filepath, h5py.File)
    model_file = h5py.File(filepath, mode='r') if should_open_file else filepath

    # Code below is current as of 2019-06-20 and may break due to future changes.
    # github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/saving/hdf5_format.py
    try:
        # Inside the try so the file is closed even if deserialization fails.
        model = tf.keras.models.load_model(model_file,
                                           custom_objects=custom_objects,
                                           compile=False)
        training_config = model_file.attrs.get('training_config')
        if hasattr(training_config, 'decode'):
            training_config = training_config.decode('utf-8')
        if training_config is None:
            raise ValueError(
                'No training configuration found in save file, meaning '
                'the model was not compiled. Please use '
                'tf.keras.models.load_model instead.')
        training_config = json.loads(training_config)
        model.compile(**_compile_args_from_training_config(
            training_config, custom_objects))
        model.optimizer.register_layers(model)
        if optimizer_name:
            model.optimizer.name = optimizer_name

        if 'optimizer_weights' in model_file:
            # Build train function (to get weight updates).
            # Models that aren't graph networks must wait until they are called
            # with data to _make_train_function() and so can't load optimizer
            # weights.
            model._make_train_function()  # pylint: disable=protected-access
            opt_weight_vals = hdf5_format.load_optimizer_weights_from_hdf5_group(
                model_file)
            try:
                model.optimizer.set_weights(opt_weight_vals)
            except ValueError:
                # `warning` rather than the deprecated `warn` alias.
                logging.warning(
                    'Error in loading the saved optimizer state. As a '
                    'result, your model is starting with a freshly '
                    'initialized optimizer.')
    finally:
        if should_open_file:
            model_file.close()
    return model