Example #1
    def test_export_keras_estimator(self, checkpoint_format):
        keras_model, (x_train, y_train), (
            _, _), train_input_fn, _ = get_resource_for_simple_model(
                model_type='sequential', is_evaluate=False)

        keras_model.compile(loss='categorical_crossentropy',
                            optimizer='adam',
                            metrics=['accuracy'])
        keras_model.fit(x_train, y_train, epochs=1)
        bias_value = keras.backend.get_value(keras_model.layers[0].bias)

        est_keras = keras_lib.model_to_estimator(
            keras_model=keras_model,
            model_dir=tempfile.mkdtemp(dir=self._base_dir),
            checkpoint_format=checkpoint_format)

        def serving_input_receiver_fn():
            feature_spec = {
                'dense_input':
                parsing_ops.FixedLenFeature([1], dtype=dtypes.float32)
            }
            return export_lib.build_parsing_serving_input_receiver_fn(
                feature_spec)

        # Try exporting immediately, testing that (1) the exported values match the
        # trained model, and (2) the estimator can be exported without first saving
        # a checkpoint into the model directory.
        saved_model_dir = est_keras.export_saved_model(
            tempfile.mkdtemp(dir=self._base_dir), serving_input_receiver_fn())
        variables_path = saved_model_utils.get_variables_path(saved_model_dir)

        variable_name = 'dense/bias'
        if checkpoint_format == 'checkpoint':
            names_to_keys = saver_lib.object_graph_key_mapping(variables_path)
            variable_name = names_to_keys[variable_name]

        self.assertAllClose(
            bias_value, training.load_variable(variables_path, variable_name))

        # Export the estimator after training a bit.
        est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE // 16)
        saved_model_dir = est_keras.export_saved_model(
            tempfile.mkdtemp(dir=self._base_dir), serving_input_receiver_fn())
        variables_path = saved_model_utils.get_variables_path(saved_model_dir)
        self.assertNotAllClose(
            bias_value, training.load_variable(variables_path, variable_name))
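The remapping through object_graph_key_mapping() is the crux of this test: with checkpoint_format='checkpoint' the exported variables live under object-based keys, so the plain name 'dense/bias' has to be translated before training.load_variable() can find it. Below is a minimal, hypothetical sketch of the same lookup using only the public tf.train API; the checkpoint path and the suffix-matching fallback are assumptions for illustration, not part of the original test.

import tensorflow as tf

# Hypothetical checkpoint prefix; the real test takes it from
# saved_model_utils.get_variables_path(saved_model_dir).
ckpt = "/tmp/exported/variables/variables"

# An object-based checkpoint stores keys such as
# 'layer_with_weights-0/bias/.ATTRIBUTES/VARIABLE_VALUE' rather than 'dense/bias',
# which is why the test remaps the name via object_graph_key_mapping().
stored_keys = [name for name, _ in tf.train.list_variables(ckpt)]
key = 'dense/bias'
if key not in stored_keys:
    # Crude stand-in for the object_graph_key_mapping() lookup used in the test.
    key = next(k for k in stored_keys if k.endswith('bias/.ATTRIBUTES/VARIABLE_VALUE'))
bias = tf.train.load_variable(ckpt, key)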
Example #2
def _get_object_checkpoint_renames(path, variable_names):
  """Returns a dictionary mapping variable names to checkpoint keys.

  The warm-starting utility expects variable names to match the names stored in
  the checkpoint. For object-based checkpoints, the in-graph variable names and
  the checkpoint keys differ, so this function is used to obtain the map from
  variable names to checkpoint keys.

  Args:
    path: path to checkpoint directory or file.
    variable_names: list of variable names to load from the checkpoint.

  Returns:
    If the checkpoint is object-based, this function returns a map from variable
    names to their corresponding checkpoint keys.
    If the checkpoint is name-based, this returns an empty dict.

  Raises:
    ValueError: If the object-based checkpoint is missing variables.
  """
  fname = checkpoint_utils._get_checkpoint_filename(path)  # pylint: disable=protected-access
  try:
    names_to_keys = saver_lib.object_graph_key_mapping(fname)
  except errors.NotFoundError:
    # If an error is raised from `object_graph_key_mapping`, then the
    # checkpoint is name-based. There are no renames, so return an empty dict.
    return {}

  missing_names = set(variable_names) - set(names_to_keys.keys())
  if missing_names:
    raise ValueError(
        "Attempting to warm-start from an object-based checkpoint, but found "
        "that the checkpoint did not contain values for all variables. The "
        "following variables were missing: {}"
        .format(missing_names))
  return {name: names_to_keys[name] for name in variable_names}
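A hypothetical sketch of how a rename map like this is typically consumed, here through the public tf.estimator.WarmStartSettings API; the checkpoint path and variable names are made up for illustration and are not taken from the original code.

import tensorflow as tf

# Assumed checkpoint path and variable names, purely illustrative.
ckpt_path = "/tmp/keras_model_dir/keras/keras_model.ckpt"
names = ["dense/kernel", "dense/bias"]

renames = _get_object_checkpoint_renames(ckpt_path, names)
warm_start = tf.estimator.WarmStartSettings(
    ckpt_to_initialize_from=ckpt_path,
    vars_to_warm_start=names,
    # An empty dict means the checkpoint is name-based and nothing needs renaming.
    var_name_to_prev_var_name=renames or None)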
Example #3
def patched_restore(self, sess, save_path, options=None):  # type: ignore
    """
    Restores previously saved variables.

    This method runs the ops added by the constructor for restoring variables.
    It requires a session in which the graph was launched.  The variables to
    restore do not have to have been initialized, as restoring is itself a way
    to initialize variables.

    The `save_path` argument is typically a value previously returned from a
    `save()` call, or a call to `latest_checkpoint()`.

    Args:
      sess: A `Session` to use to restore the parameters. None in eager mode.
      save_path: Path where parameters were previously saved.
      options: Optional `RunOptions` forwarded to `Session.run` when restoring
        (ignored in eager mode).

    Raises:
      ValueError: If save_path is None or not a valid checkpoint.
    """
    if self._is_empty:
        return
    if save_path is None:
        raise ValueError("Can't load save_path when it is None.")

    checkpoint_prefix = compat.as_text(save_path)
    if not checkpoint_management.checkpoint_exists(checkpoint_prefix):
        raise ValueError("The passed save_path is not a valid checkpoint: " +
                         checkpoint_prefix)

    logging.info("Restoring parameters from %s", checkpoint_prefix)
    try:
        if context.executing_eagerly():
            self._build_eager(save_path, build_save=False, build_restore=True)
        else:
            sess.run(
                self.saver_def.restore_op_name,
                {self.saver_def.filename_tensor_name: save_path},
                options=options,
            )
    except errors.NotFoundError as err:
        # There are three common conditions that might cause this error:
        # 0. The file is missing. We ignore this case here, since it is checked above.
        # 1. This is an object-based checkpoint trying name-based loading.
        # 2. The graph has been altered and a variable or other name is missing.

        # Handle condition 1 first: the checkpoint cannot be loaded successfully as
        # is, so try to parse it as an object-based checkpoint.
        try:
            names_to_keys = object_graph_key_mapping(save_path)
        except errors.NotFoundError:
            # 2. This is not an object-based checkpoint, which likely means there
            # is a graph mismatch. Re-raise the original error with
            # a helpful message (b/110263146)
            raise _wrap_restore_error_with_msg(
                err, "a Variable name or other graph key that is missing")

        # This is an object-based checkpoint. We'll print a warning and then do
        # the restore.
        logging.warning(
            "Restoring an object-based checkpoint using a name-based saver. This "
            "may be somewhat fragile, and will re-build the Saver. Instead, "
            "consider loading object-based checkpoints using "
            "tf.train.Checkpoint().")
        self._object_restore_saver = saver_from_object_based_checkpoint(
            checkpoint_path=save_path,
            var_list=self._var_list,
            builder=self._builder,
            names_to_keys=names_to_keys,
            cached_saver=self._object_restore_saver,
        )
        self._object_restore_saver.restore(sess=sess,
                                           save_path=save_path,
                                           options=options)
    except errors.InvalidArgumentError as err:
        # There is a mismatch between the graph and the checkpoint being loaded.
        # We add a more reasonable error message here to help users (b/110263146)
        raise _wrap_restore_error_with_msg(
            err, "a mismatch between the current graph and the graph")