Example #1
def _export_mode(mode, has_saved_vars, builder, model, custom_objects,
                 checkpoint_path, input_signature):
    """Exports a model, and optionally saves new vars from the clone model.

    Args:
      mode: A `tf.estimator.ModeKeys` string.
      has_saved_vars: A `boolean` indicating whether the SavedModel has already
        exported variables.
      builder: A `SavedModelBuilder` object.
      model: A `tf.keras.Model` object.
      custom_objects: A dictionary mapping string names to custom classes
        or functions.
      checkpoint_path: String path to checkpoint.
      input_signature: Nested TensorSpec containing the expected inputs. Can be
        `None`, in which case the signature will be inferred from the model.

    Raises:
      ValueError: If the train/eval mode is being exported, but the model does
        not have an optimizer.
    """
    from tensorflow.python.keras import models as models_lib  # pylint: disable=g-import-not-at-top
    compile_clone = (mode != mode_keys.ModeKeys.PREDICT)
    if compile_clone and not model.optimizer:
        raise ValueError(
            'Model does not have an optimizer. Cannot export mode %s' % mode)

    model_graph = ops.get_default_graph()
    with ops.Graph().as_default() as g, K.learning_phase_scope(
            mode == mode_keys.ModeKeys.TRAIN):

        if input_signature is None:
            input_tensors = None
        else:
            input_tensors = nest.map_structure(create_placeholder,
                                               input_signature)

        # Clone the model into a blank graph. This will create placeholders
        # for inputs and targets.
        clone = models_lib.clone_and_build_model(model,
                                                 input_tensors=input_tensors,
                                                 custom_objects=custom_objects,
                                                 compile_clone=compile_clone)

        # Make sure the iterations variable is added to the global step collection,
        # to ensure that, when the SavedModel graph is loaded, the iterations
        # variable is returned by `tf.train.get_global_step()`. This is required for
        # compatibility with the SavedModelEstimator.
        if compile_clone:
            g.add_to_collection(ops.GraphKeys.GLOBAL_STEP,
                                clone.optimizer.iterations)

        # Extract update and train ops from train/test/predict functions.
        train_op = None
        if mode == mode_keys.ModeKeys.TRAIN:
            clone._make_train_function()
            train_op = clone.train_function.updates_op
        elif mode == mode_keys.ModeKeys.TEST:
            clone._make_test_function()
        else:
            clone._make_predict_function()
        g.get_collection_ref(ops.GraphKeys.UPDATE_OPS).extend(
            clone.state_updates)

        clone_var_list = checkpointable_utils.named_saveables(clone)

        with session.Session().as_default():
            if has_saved_vars:
                # Confirm all variables in the clone have an entry in the checkpoint.
                status = clone.load_weights(checkpoint_path)
                status.assert_existing_objects_matched()
            else:
                # Confirm that variables between the clone and model match up exactly,
                # not counting optimizer objects. Optimizer objects are ignored because
                # if the model has not trained, the slot variables will not have been
                # created yet.
                # TODO(b/113179535): Replace with checkpointable equivalence.
                _assert_same_non_optimizer_objects(model, model_graph, clone,
                                                   g)

                # TODO(b/113178242): Use value transfer for checkpointable objects.
                clone.load_weights(checkpoint_path)

                # Add graph and variables to SavedModel.
                # TODO(b/113134168): Switch to add_meta_graph_and_variables.
                clone.save_weights(checkpoint_path,
                                   save_format='tf',
                                   overwrite=True)
                builder._has_saved_variables = True

        # Add graph to the SavedModel builder.
        builder.add_meta_graph(model_utils.EXPORT_TAG_MAP[mode],
                               signature_def_map=_create_signature_def_map(
                                   clone, mode),
                               saver=saver_lib.Saver(clone_var_list),
                               init_op=variables.local_variables_initializer(),
                               train_op=train_op)
        return None
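Example #1 calls a helper `create_placeholder` that is defined elsewhere in the same TensorFlow module (which also supplies the imports the snippet assumes: `ops`, `K`, `nest`, `session`, `saver_lib`, `variables`, `checkpointable_utils`, `model_utils`, and `mode_keys`). A minimal sketch consistent with how it is used above, mapping each `tf.TensorSpec` in `input_signature` to a graph placeholder:

from tensorflow.python.keras import backend as K


def create_placeholder(spec):
  # Build a placeholder with the spec's shape, dtype, and name, so that
  # nest.map_structure(create_placeholder, input_signature) produces the
  # input tensors handed to clone_and_build_model.
  return K.placeholder(shape=spec.shape, dtype=spec.dtype, name=spec.name)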
Example #2
def _export_mode(
    mode, has_saved_vars, builder, model, custom_objects, checkpoint_path):
  """Export a model, and optionally save new vars from the clone model.

  Args:
    mode: A `tf.estimator.ModeKeys` string.
    has_saved_vars: A `boolean` indicating whether the SavedModel has already
      exported variables.
    builder: A `SavedModelBuilder` object.
    model: A `tf.keras.Model` object.
    custom_objects: A dictionary mapping string names to custom classes
      or functions.
    checkpoint_path: String path to checkpoint.

  Raises:
    ValueError: If the train/eval mode is being exported, but the model does
      not have an optimizer.
  """
  compile_clone = (mode != model_fn_lib.ModeKeys.PREDICT)
  if compile_clone and not model.optimizer:
    raise ValueError(
        'Model does not have an optimizer. Cannot export mode %s' % mode)

  model_graph = ops.get_default_graph()
  with ops.Graph().as_default() as g:

    K.set_learning_phase(mode == model_fn_lib.ModeKeys.TRAIN)

    # Clone the model into a blank graph. This will create placeholders for
    # inputs and targets.
    clone = models_lib.clone_and_build_model(
        model, custom_objects=custom_objects, compile_clone=compile_clone)

    # Make sure the iterations variable is added to the global step collection,
    # to ensure that, when the SavedModel graph is loaded, the iterations
    # variable is returned by `tf.train.get_global_step()`. This is required for
    # compatibility with the SavedModelEstimator.
    if compile_clone:
      g.add_to_collection(ops.GraphKeys.GLOBAL_STEP, clone.optimizer.iterations)

    # Extract update and train ops from train/test/predict functions.
    train_op = None
    if mode == model_fn_lib.ModeKeys.TRAIN:
      clone._make_train_function()
      train_op = clone.train_function.updates_op
    elif mode == model_fn_lib.ModeKeys.EVAL:
      clone._make_test_function()
    else:
      clone._make_predict_function()
    g.get_collection_ref(ops.GraphKeys.UPDATE_OPS).extend(clone.state_updates)

    clone_var_list = checkpointable_utils.named_saveables(clone)

    with session.Session().as_default():
      if has_saved_vars:
        # Confirm all variables in the clone have an entry in the checkpoint.
        status = clone.load_weights(checkpoint_path)
        status.assert_existing_objects_matched()
      else:
        # Confirm that variables between the clone and model match up exactly,
        # not counting optimizer objects. Optimizer objects are ignored because
        # if the model has not trained, the slot variables will not have been
        # created yet.
        # TODO(b/113179535): Replace with checkpointable equivalence.
        _assert_same_non_optimizer_objects(model, model_graph, clone, g)

        # TODO(b/113178242): Use value transfer for checkpointable objects.
        clone.load_weights(checkpoint_path)

        # Add graph and variables to SavedModel.
        # TODO(b/113134168): Switch to add_meta_graph_and_variables.
        clone.save_weights(checkpoint_path, save_format='tf', overwrite=True)
        builder._has_saved_variables = True

    # Add graph to the SavedModel builder.
    builder.add_meta_graph(
        model_fn_lib.EXPORT_TAG_MAP[mode],
        signature_def_map=_create_signature_def_map(clone, mode),
        saver=saver_lib.Saver(clone_var_list),
        init_op=variables.local_variables_initializer(),
        train_op=train_op)
    return None
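Both versions of `_export_mode` expect a caller that creates the `SavedModelBuilder`, saves an object-based checkpoint, and then invokes them once per mode. A minimal driver sketch along the lines of TF 1.x's `tf.contrib.saved_model.save_keras_model` (the name `export_keras_model` and the checkpoint layout here are illustrative, not the exact library API):

import os

from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.saved_model import builder as saved_model_builder


def export_keras_model(model, saved_model_path, custom_objects=None):
  # Create the builder first: SavedModelBuilder refuses an existing
  # directory and creates `saved_model_path` itself.
  builder = saved_model_builder.SavedModelBuilder(saved_model_path)

  # Save an object-based checkpoint of the current weights; every cloned
  # graph inside _export_mode restores from this path.
  checkpoint_path = os.path.join(saved_model_path, 'variables', 'variables')
  model.save_weights(checkpoint_path, save_format='tf', overwrite=True)

  # The first export passes has_saved_vars=False, so _export_mode re-saves
  # the checkpoint (now including optimizer slot variables created while
  # compiling the clone) and sets builder._has_saved_variables. Later calls
  # pass True and only assert the checkpoint covers the clone's variables.
  has_saved_vars = False
  if model.optimizer:
    _export_mode(model_fn_lib.ModeKeys.TRAIN, has_saved_vars, builder, model,
                 custom_objects, checkpoint_path)
    has_saved_vars = True
    _export_mode(model_fn_lib.ModeKeys.EVAL, has_saved_vars, builder, model,
                 custom_objects, checkpoint_path)
  _export_mode(model_fn_lib.ModeKeys.PREDICT, has_saved_vars, builder, model,
               custom_objects, checkpoint_path)

  builder.save()
  return saved_model_path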
Example #3
def _get_var_list(model):
    """Returns list of all checkpointed saveable objects in the model."""
    return checkpointable_utils.named_saveables(model)
Example #4
def _get_var_list(model):
  """Return list of all checkpointed saveable objects in the model."""
  return checkpointable_utils.named_saveables(model)
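This helper is the same call `_export_mode` uses to build its per-mode saver: `named_saveables` walks the model's checkpointable object graph and returns the saveable objects for a name-based `tf.train.Saver`. A short usage sketch, assuming `model` is a built `tf.keras.Model` in the current graph:

from tensorflow.python.training import saver as saver_lib

var_list = _get_var_list(model)
# _export_mode hands exactly such a list to the Saver it attaches to each
# exported meta graph (saver=saver_lib.Saver(clone_var_list) above).
saver = saver_lib.Saver(var_list)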