示例#1
0
def assert_global_step(global_step_tensor):
  """Validate a global step tensor via `training_util.assert_global_step`.

  Args:
    global_step_tensor: The candidate global step; forwarded unchanged to
      `training_util.assert_global_step`.

  Any exception raised by `training_util.assert_global_step` propagates to
  the caller. This wrapper itself always returns None.
  """
  checker = training_util.assert_global_step
  checker(global_step_tensor)
示例#2
0
  def _model_fn_from_saved_model(self, features, labels, mode):
    """Load a SavedModel graph and return an EstimatorSpec.

    The SavedModel graph for `mode` is loaded into the current default graph.
    Its input tensors are replaced by `features` and `labels`. The tensors
    listed in the mode's SignatureDef outputs are then packaged into an
    EstimatorSpec.

    Args:
      features: Input features; mapped onto the SavedModel inputs by
        `_generate_input_map`.
      labels: Input labels; mapped onto the SavedModel inputs by
        `_generate_input_map`.
      mode: Estimator mode. Selects the SignatureDef, the export tags
        (`model_fn_lib.EXPORT_TAG_MAP[mode]`) and the MetaGraphDef to use.

    Returns:
      A `model_fn_lib.EstimatorSpec` whose scaffold, loss, train_op,
      predictions and eval_metric_ops come from the loaded graph.

    Raises:
      RuntimeError: If the default graph already contains a global step
        before loading, or if the train_op collection holds more than one op.
      Errors from `self._validate_mode`, `training_util.assert_global_step`
      (e.g. when the loaded graph yields no global step) and
      `_validate_and_extract_outputs` propagate unchanged.
    """
    # TODO(kathywu): Model function loads placeholders from the graph. Calling
    # export_all_saved_models creates another placeholder for the inputs, on top
    # of the original placeholders. There should be a way to avoid this.
    self._validate_mode(mode)

    graph = ops.get_default_graph()
    # The global step must come from the SavedModel, so refuse to load into a
    # graph that already has one.
    if training_util.get_global_step(graph) is not None:
      raise RuntimeError(
          'Graph must not contain a global step tensor before the SavedModel is'
          ' loaded. Please make sure that the input function does not create a '
          'global step.')

    # The SignatureDef describes the input and output tensors for this mode.
    signature_def = self._get_signature_def_for_mode(mode)

    # Map the SavedModel's input tensors onto the provided features and labels.
    input_map = _generate_input_map(signature_def, features, labels)

    # Output tensor names may be remapped when the graph is loaded. Request
    # them by SignatureDef name so the returned tensors are the loaded ones.
    signature_outputs = signature_def.outputs
    requested_names = [
        tensor_info.name for tensor_info in six.itervalues(signature_outputs)]

    # `loaded_tensors` follows the same order as `requested_names`.
    export_tags = model_fn_lib.EXPORT_TAG_MAP[mode]
    _, loaded_tensors = self.saved_model_loader.load_graph(
        graph,
        export_tags,
        input_map=input_map,
        return_elements=requested_names)

    # Build the Scaffold's local init op from the MetaGraphDef. This should
    # mirror _add_meta_graph_for_mode(), which creates the MetaGraphDef from
    # the EstimatorSpec's scaffold.
    meta_graph_def = self._get_meta_graph_def_for_mode(mode)
    local_init_op = loader_impl._get_legacy_init_op_tensor(  # pylint: disable=protected-access
        meta_graph_def)
    scaffold = monitored_session.Scaffold(local_init_op=local_init_op)

    # Loading the SavedModel must have created a global step tensor.
    training_util.assert_global_step(training_util.get_global_step(graph))

    # Re-key the loaded tensors by their SignatureDef output keys.
    tensor_by_name = dict(zip(requested_names, loaded_tensors))
    outputs = {
        output_key: tensor_by_name[tensor_info.name]
        for output_key, tensor_info in six.iteritems(signature_outputs)}

    loss, predictions, metrics = _validate_and_extract_outputs(
        mode, outputs, signature_def.method_name)

    # The train_op collection may hold at most one op.
    train_ops = ops.get_collection(constants.TRAIN_OP_KEY)
    if len(train_ops) > 1:
      raise RuntimeError('Multiple ops found in the train_op collection.')
    train_op = train_ops[0] if train_ops else None

    _clear_saved_model_collections()
    return model_fn_lib.EstimatorSpec(
        mode=mode,
        scaffold=scaffold,
        loss=loss,
        predictions=predictions,
        eval_metric_ops=metrics,
        train_op=train_op)
  def _model_fn_from_saved_model(self, features, labels, mode):
    """Load a SavedModel graph and return an EstimatorSpec.

    The SavedModel graph for `mode` is loaded into the current default graph.
    Its input tensors are replaced by `features` and `labels`. The tensors
    listed in the mode's SignatureDef outputs are then packaged into an
    EstimatorSpec.

    Args:
      features: Input features; mapped onto the SavedModel inputs by
        `_generate_input_map`.
      labels: Input labels; mapped onto the SavedModel inputs by
        `_generate_input_map`.
      mode: Estimator mode. Selects the SignatureDef, the export tags
        (`model_fn_lib.EXPORT_TAG_MAP[mode]`) and the MetaGraphDef to use.

    Returns:
      A `model_fn_lib.EstimatorSpec` whose scaffold, loss, train_op,
      predictions and eval_metric_ops come from the loaded graph.

    Raises:
      RuntimeError: If the default graph already contains a global step
        before loading, or if the train_op collection holds more than one op.
      Errors from `self._validate_mode`, `training_util.assert_global_step`
      (e.g. when the loaded graph yields no global step) and
      `_validate_and_extract_outputs` propagate unchanged.
    """
    # TODO(kathywu): Model function loads placeholders from the graph. Calling
    # export_all_saved_models creates another placeholder for the inputs, on top
    # of the original placeholders. There should be a way to avoid this.
    self._validate_mode(mode)

    graph = ops.get_default_graph()
    # The global step must come from the SavedModel, so refuse to load into a
    # graph that already has one.
    if training_util.get_global_step(graph) is not None:
      raise RuntimeError(
          'Graph must not contain a global step tensor before the SavedModel is'
          ' loaded. Please make sure that the input function does not create a '
          'global step.')

    # The SignatureDef describes the input and output tensors for this mode.
    signature_def = self._get_signature_def_for_mode(mode)

    # Map the SavedModel's input tensors onto the provided features and labels.
    input_map = _generate_input_map(signature_def, features, labels)

    # Output tensor names may be remapped when the graph is loaded. Request
    # them by SignatureDef name so the returned tensors are the loaded ones.
    signature_outputs = signature_def.outputs
    requested_names = [
        tensor_info.name for tensor_info in six.itervalues(signature_outputs)]

    # `loaded_tensors` follows the same order as `requested_names`.
    export_tags = model_fn_lib.EXPORT_TAG_MAP[mode]
    _, loaded_tensors = self.saved_model_loader.load_graph(
        graph,
        export_tags,
        input_map=input_map,
        return_elements=requested_names)

    # Build the Scaffold's local init op from the MetaGraphDef. This should
    # mirror _add_meta_graph_for_mode(), which creates the MetaGraphDef from
    # the EstimatorSpec's scaffold.
    meta_graph_def = self._get_meta_graph_def_for_mode(mode)
    local_init_op = loader_impl._get_main_op_tensor(  # pylint: disable=protected-access
        meta_graph_def)
    scaffold = monitored_session.Scaffold(local_init_op=local_init_op)

    # Loading the SavedModel must have created a global step tensor.
    training_util.assert_global_step(training_util.get_global_step(graph))

    # Re-key the loaded tensors by their SignatureDef output keys.
    tensor_by_name = dict(zip(requested_names, loaded_tensors))
    outputs = {
        output_key: tensor_by_name[tensor_info.name]
        for output_key, tensor_info in six.iteritems(signature_outputs)}

    loss, predictions, metrics = _validate_and_extract_outputs(
        mode, outputs, signature_def.method_name)

    # The train_op collection may hold at most one op.
    train_ops = ops.get_collection(constants.TRAIN_OP_KEY)
    if len(train_ops) > 1:
      raise RuntimeError('Multiple ops found in the train_op collection.')
    train_op = train_ops[0] if train_ops else None

    _clear_saved_model_collections()
    return model_fn_lib.EstimatorSpec(
        mode=mode,
        scaffold=scaffold,
        loss=loss,
        predictions=predictions,
        eval_metric_ops=metrics,
        train_op=train_op)