Exemple #1
0
def load_optimizer_weights(model: Model, load_path: pathlib.Path) -> None:
    """
    Load the optimizer states from a tf.keras model saved with
    tf.keras.models.save_model().

    Nothing happens when the file has no ``optimizer_weights`` entry. Only
    graph networks can have their optimizer state restored here; for any
    other model a warning is logged and the optimizer stays freshly
    initialized. A warning (not an exception) is also logged when the saved
    state cannot be assigned to the optimizer. This implementation is lifted
    from tf.keras.models.load_model().

    Args:
        model: Compiled model whose ``optimizer`` receives the saved state.
        load_path: Path of the HDF5 file written by
            ``tf.keras.models.save_model()``.
    """
    # The context manager guarantees the file handle is released. It stays
    # open until set_weights() returns because the values read from the file
    # may be lazily-backed HDF5 datasets (NOTE(review): based on tf.keras's
    # hdf5_format implementation — confirm for the TF version in use).
    with h5py.File(str(load_path), mode="r") as f:
        if "optimizer_weights" not in f:
            return

        # Build train function (to get weight updates).  Models that aren't
        # graph networks must wait until they are called with data to
        # _make_train_function() and so can't load optimizer weights.
        if not model._is_graph_network:  # pylint: disable=protected-access
            logging.warning("Sequential models without an `input_shape` "
                            "passed to the first layer cannot reload their "
                            "optimizer state. As a result, your model is "
                            "starting with a freshly initialized optimizer.")
            return

        model._make_train_function()  # pylint: disable=protected-access
        optimizer_weight_values = load_optimizer_weights_from_hdf5_group(f)
        try:
            model.optimizer.set_weights(optimizer_weight_values)
        except ValueError:
            logging.warning("Error in loading the saved optimizer "
                            "state. As a result, your model is "
                            "starting with a freshly initialized "
                            "optimizer.")
Exemple #2
0
    def _load_weights(self, checkpoint, gen_or_disc):
        """Load generator or discriminator weights, plus optimizer state if saved.

        Args:
            checkpoint: HDF5 checkpoint file (path-like) holding the network's
                weights and, optionally, an ``optimizer_weights`` entry.
            gen_or_disc: ``'gen'`` to load into ``self.generator`` or
                ``'disc'`` to load into ``self.discriminator``.

        Raises:
            ValueError: If ``gen_or_disc`` is neither ``'gen'`` nor ``'disc'``.
        """
        if gen_or_disc == 'gen':
            network = self.generator
            step_fn = self.gen_step
        elif gen_or_disc == 'disc':
            network = self.discriminator
            step_fn = self.disc_step
        else:
            raise ValueError(gen_or_disc)

        # The context manager closes the checkpoint even if a step below
        # raises. The file stays open until set_weights() returns because the
        # optimizer values read from it may be lazily-backed HDF5 datasets.
        with h5py.File(checkpoint, 'r') as model_file:
            has_optimizer_state = 'optimizer_weights' in model_file

            if has_optimizer_state and not network.optimizer.weights:
                # The optimizer has no variables yet, so set_weights() would
                # have nothing to assign into: perform a single optimization
                # step on zero inputs (batch size 1) to create them.
                # NOTE(review): both networks use the discriminator's input
                # shapes here; presumably gen_step accepts the same
                # (features, targets) pair — confirm.
                features_shape = self.discriminator.inputs[0].shape.as_list()
                targets_shape = self.discriminator.inputs[1].shape.as_list()
                features_shape[0], targets_shape[0] = 1, 1
                step_fn(tf.zeros(features_shape), tf.zeros(targets_shape))

            print(f'Loading {gen_or_disc} weights from {str(checkpoint)}')
            network.load_weights(str(checkpoint))

            if has_optimizer_state:
                print('Also recovering the optimizer state')
                opt_weight_values = (
                    hdf5_format.load_optimizer_weights_from_hdf5_group(
                        model_file))
                network.optimizer.set_weights(opt_weight_values)
Exemple #3
0
def load_optimizer_weights(model, filepath):
    """Loads optimizer weights to compiled model from hdf5 file.

    Arguments:
        model: Compiled model whose ``optimizer`` receives the saved state.
            Nothing is loaded when the model has no optimizer or the file
            has no ``optimizer_weights`` entry.
        filepath: Either a path to an HDF5 file or an already-open
            ``h5py.File``. A path is opened read-only and closed before
            returning; an open file object is left open for the caller.

    Returns:
        The same ``model`` object that was passed in.
    """
    opened_new_file = not isinstance(filepath, h5py.File)
    f = h5py.File(filepath, mode='r') if opened_new_file else filepath

    try:
        if model.optimizer and 'optimizer_weights' in f:
            try:
                # Make sure the optimizer's variables exist so that
                # set_weights() below has something to assign into.
                model.optimizer._create_all_weights(  # pylint: disable=protected-access
                    model.trainable_variables)
            except (NotImplementedError, AttributeError):
                # The message previously contained an unformatted '{}';
                # pass the optimizer as a lazy logging argument instead.
                logging.warning(
                    'Error when creating the weights of optimizer %s, making it '
                    'impossible to restore the saved optimizer state. As a result, '
                    'your model is starting with a freshly initialized optimizer.',
                    model.optimizer)
            optimizer_weight_values = hdf5_format.load_optimizer_weights_from_hdf5_group(f)
            try:
                model.optimizer.set_weights(optimizer_weight_values)
            except ValueError:
                logging.warning('Error in loading the saved optimizer '
                                'state. As a result, your model is '
                                'starting with a freshly initialized '
                                'optimizer.')
    finally:
        # Only close what this function opened.
        if opened_new_file:
            f.close()
    return model
Exemple #4
0
 def restore_optimizer(self, checkpoint=LATEST):
     """Restore weights to optimizer.

     Args:
       checkpoint: Selector passed to ``self.checkpoint()`` to resolve the
         HDF5 checkpoint file; defaults to ``LATEST``. Nothing is restored
         when it resolves to ``None``.
     """
     # pylint: disable=import-error
     from tensorflow.python.keras.saving.hdf5_format import \
         load_optimizer_weights_from_hdf5_group
     # pylint: enable=import-error
     import h5py
     checkpoint = self.checkpoint(checkpoint)
     if checkpoint is None:
         return
     # Open read-only: without an explicit mode, h5py < 3.0 defaults to 'a'
     # (read/write, creating the file if it is missing).
     with h5py.File(checkpoint, mode='r') as f:
         optimizer_weight_values = load_optimizer_weights_from_hdf5_group(f)
         self.model.optimizer.set_weights(optimizer_weight_values)
Exemple #5
0
    def load_optimizer_from_hdf5(filepath, model):
        """Restore ``model.optimizer`` state from the HDF5 file at ``filepath``.

        Does nothing when the file has no ``optimizer_weights`` entry. Failures
        to create or assign the optimizer weights are logged as warnings
        rather than raised.
        """
        with h5py.File(filepath, mode='r') as f:
            if 'optimizer_weights' not in f:
                return

            try:
                # Make sure the optimizer's variables exist before assigning.
                model.optimizer._create_all_weights(  # pylint: disable=protected-access
                    model.trainable_variables)
            except (NotImplementedError, AttributeError):
                # The message previously logged a literal '{}' placeholder.
                logging.warning(
                    'Error when creating the weights of optimizer %s',
                    model.optimizer)

            optimizer_weight_values = load_optimizer_weights_from_hdf5_group(
                f)
            try:
                model.optimizer.set_weights(optimizer_weight_values)
            except ValueError:
                logging.warning(
                    'Error in loading the saved optimizer state.')
def load_optimizer_weights(
    model: Model, h5group: Any, optimizer: tf.keras.optimizers.Optimizer
) -> None:
    """
    Load the optimizer states from a tf.keras model saved with
    tf.keras.models.save_model().

    On TensorFlow >= 2.2 the optimizer's variables are created from
    ``model.trainable_variables`` before the saved values are assigned. On
    older versions only graph networks are supported: for any other model
    (e.g. a Sequential model without an ``input_shape``) a warning is logged
    and the optimizer is left freshly initialized. Failures while creating or
    assigning the optimizer weights are also logged as warnings rather than
    raised. This implementation is lifted from tf.keras.models.load_model().

    Args:
        model: Model whose trainable variables ``optimizer`` updates.
        h5group: HDF5 file/group handed to
            ``load_optimizer_weights_from_hdf5_group``.
        optimizer: Optimizer that receives the saved state.
    """
    tf2_2_or_newer = version.parse(tf.__version__) >= version.parse("2.2.0")

    # Check the version first so the private ``_is_graph_network`` attribute
    # is only consulted on the legacy (< 2.2) path that actually needs it.
    if not (tf2_2_or_newer or model._is_graph_network):  # pylint: disable=protected-access
        logging.warning(
            "Sequential models without an `input_shape` "
            "passed to the first layer cannot reload their "
            "optimizer state. As a result, your model is "
            "starting with a freshly initialized optimizer."
        )
        return

    if tf2_2_or_newer:
        try:
            optimizer._create_all_weights(model.trainable_variables)  # pylint: disable=protected-access
        except (NotImplementedError, AttributeError):
            logging.warning(
                "Error when creating the weights of optimizer, making it "
                "impossible to restore the saved optimizer state. As a result, "
                "your model is starting with a freshly initialized optimizer."
            )
    else:
        # Build train function (to get weight updates).  Models that aren't
        # graph networks must wait until they are called with data to
        # _make_train_function() and so can't load optimizer weights.
        model._make_train_function()  # pylint: disable=protected-access

    optimizer_weight_values = load_optimizer_weights_from_hdf5_group(h5group)
    try:
        optimizer.set_weights(optimizer_weight_values)
    except ValueError:
        logging.warning(
            "Error in loading the saved optimizer "
            "state. As a result, your model is "
            "starting with a freshly initialized "
            "optimizer."
        )
Exemple #7
0
 def _set_optimizer_state(self, optimizer, state, optimizer_name):
     """Copy the optimizer weights stored in the HDF5 file ``state`` into ``optimizer``.

     ``optimizer_name`` is not used by this implementation.
     """
     # The file is kept open until set_weights() has consumed the values.
     with h5py.File(state, 'r') as h5_state:
         optimizer.set_weights(
             hdf5_format.load_optimizer_weights_from_hdf5_group(h5_state))
Exemple #8
0
def load_model(filepath, custom_objects=None, optimizer_name=None):
    """Loads and compiles a Keras model saved as an HDF5 file.

    Same as tf.keras.models.load_model, except it will always compile the model
    and instantiate the Kfac optimizer correctly. If you do not want the model
    to be compiled, or saved without the optimizer, use
    tf.keras.models.load_model instead.

    Example:
    ```python
    import tensorflow as tf
    import kfac

    model = tf.keras.Model(...)
    loss = tf.keras.losses.MeanSquaredError()  # could be a serialized loss function
    optimizer = kfac.keras.optimizers.Kfac(0.001, 0.01, model=model, loss=loss)
    model.compile(optimizer, loss)
    model.fit(...)
    model.save('saved_model.hdf5')  # or use tf.keras.models.save_model
    ...
    loaded_model = kfac.keras.saving_utils.load_model('saved_model.hdf5')
    loaded_model.fit(...)
    ```

    Args:
      filepath: One of the following:
          - String, path to the saved model
          - `h5py.File` object from which to load the model
      custom_objects: Optional dictionary mapping names (strings) to custom
        classes or functions to be considered during deserialization. Kfac
        is added automatically; the caller's dictionary is not modified.
      optimizer_name: Optional string that specifies what variable scope you
        want the KFAC variables to be created in. Useful if you have multiple
        KFAC optimizers on one graph.

    Raises:
      ImportError: If h5py was not imported.
      ValueError: If the file has no training configuration (the saved model
        was not compiled).

    Returns:
      A compiled Keras model with the Kfac optimizer correctly initialized.
    """
    if h5py is None:
        raise ImportError('`load_model` requires h5py.')
    # Work on a copy so registering Kfac never mutates the caller's dict.
    custom_objects = dict(custom_objects or {})
    custom_objects['Kfac'] = optimizers.Kfac

    should_open_file = not isinstance(filepath, h5py.File)
    model_file = h5py.File(filepath,
                           mode='r') if should_open_file else filepath

    # Code below is current as of 2019-06-20 and may break due to future changes.
    # github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/saving/hdf5_format.py
    try:
        # Inside the try so a file opened here is closed even if loading fails.
        model = tf.keras.models.load_model(model_file,
                                           custom_objects=custom_objects,
                                           compile=False)

        training_config = model_file.attrs.get('training_config')
        if hasattr(training_config, 'decode'):
            training_config = training_config.decode('utf-8')
        if training_config is None:
            raise ValueError(
                'No training configuration found in save file, meaning '
                'the model was not compiled. Please use '
                'tf.keras.models.load_model instead.')
        training_config = json.loads(training_config)

        model.compile(**_compile_args_from_training_config(
            training_config, custom_objects))
        model.optimizer.register_layers(model)
        if optimizer_name:
            model.optimizer.name = optimizer_name

        if 'optimizer_weights' in model_file:
            # Build train function (to get weight updates).
            # Models that aren't graph networks must wait until they are called
            # with data to _make_train_function() and so can't load optimizer
            # weights.
            model._make_train_function()  # pylint: disable=protected-access
            opt_weight_vals = hdf5_format.load_optimizer_weights_from_hdf5_group(
                model_file)
            try:
                model.optimizer.set_weights(opt_weight_vals)
            except ValueError:
                # `warn` is a deprecated alias; use `warning`.
                logging.warning(
                    'Error in loading the saved optimizer state. As a '
                    'result, your model is starting with a freshly '
                    'initialized optimizer.')
    finally:
        # Only close what this function opened.
        if should_open_file:
            model_file.close()

    return model