Esempio n. 1
0
  def head_ops(self, features, labels, mode, train_op_fn, logits=None,
               logits_input=None):
    """Returns ops for a model_fn.

    Dispatches on `mode` and builds the matching `estimator.ModelFnOps` for
    training, inference or evaluation. Every returned `ModelFnOps` carries
    this head's default metrics and signature function.

    Args:
      features: input dict.
      labels: labels dict or tensor.
      mode: estimator's ModeKeys
      train_op_fn: function that takes a scalar loss and returns an op to
          optimize with the loss. Only called when `mode` is TRAIN.
      logits: logits to be used for the head.
      logits_input: tensor to build logits from.

    Returns:
      `estimator.ModelFnOps`

    Raises:
      ValueError: if mode is not recognized.
    """
    # NOTE(review): presumably rejects unsupported logits/logits_input
    # combinations before any op is built -- confirm against the helper.
    _check_logits_input_not_supported(logits, logits_input)
    if mode == estimator.ModeKeys.TRAIN:
      loss, additional_train_op = self._training_loss(features, labels,
                                                      logits, logits_input)

      train_op = train_op_fn(loss)

      # `additional_train_op` is unpacked with `*` below, so it is treated as
      # a (possibly empty or None) collection of ops to run with the train op.
      if additional_train_op:
        # Compare against None rather than truth-testing: `train_op_fn` may
        # return a `Tensor`, and TensorFlow raises TypeError when a `Tensor`
        # is implicitly converted to a Python bool.
        if train_op is not None:
          train_op = control_flow_ops.group(train_op, *additional_train_op)
        else:
          train_op = control_flow_ops.group(*additional_train_op)

      return estimator.ModelFnOps(
          mode=estimator.ModeKeys.TRAIN,
          loss=loss,
          training_op=train_op,
          default_metrics=self._default_metric(),
          signature_fn=self._create_signature_fn())

    if mode == estimator.ModeKeys.INFER:
      # Inference needs neither labels nor a loss; only predictions are built.
      return estimator.ModelFnOps(
          mode=estimator.ModeKeys.INFER,
          predictions=self._infer_op(logits, logits_input),
          default_metrics=self._default_metric(),
          signature_fn=self._create_signature_fn())

    if mode == estimator.ModeKeys.EVAL:
      predictions, loss = self._eval_op(features, labels, logits, logits_input)
      return estimator.ModelFnOps(
          mode=estimator.ModeKeys.EVAL,
          predictions=predictions,
          loss=loss,
          default_metrics=self._default_metric(),
          signature_fn=self._create_signature_fn())

    raise ValueError("mode=%s unrecognized." % str(mode))
Esempio n. 2
0
 def _dynamic_rnn_model_fn(features, labels, mode):
     """The model to be passed to an `Estimator`.

     Builds the RNN, its predictions, loss, evaluation metrics and training
     op. The configuration names used here (`name`, `cell`, `target_column`,
     `prediction_type`, `optimizer`, ...) are not parameters of this
     function; they are expected to come from the enclosing scope.

     Args:
       features: mapping supporting `.get`; queried for the initial state
         and sequence length, and used to build the sequence input.
       labels: labels forwarded to the loss and to the metric ops.
       mode: estimator's ModeKeys; dropout is applied to the cell only when
         this equals `estimator.ModeKeys.TRAIN`.

     Returns:
       `estimator.ModelFnOps` holding predictions (including the final RNN
       state under `RNNKeys.FINAL_STATE_KEY`), loss, train op and eval
       metric ops.

     Raises:
       ValueError: if `prediction_type` is neither
         `PredictionType.MULTIPLE_VALUE` nor `PredictionType.SINGLE_VALUE`.
     """
     with ops.name_scope(name):
         # `.get` is used, so both entries are optional and may be None.
         initial_state = features.get(initial_state_key)
         sequence_length = features.get(sequence_length_key)
         sequence_input = build_sequence_input(features,
                                               sequence_feature_columns,
                                               context_feature_columns)
         # Dropout is a training-only regulariser; other modes use the raw
         # cell.
         if mode == estimator.ModeKeys.TRAIN:
             cell_for_mode = apply_dropout(cell, input_keep_probability,
                                           output_keep_probability)
         else:
             cell_for_mode = cell
         rnn_activations, final_state = construct_rnn(
             initial_state,
             sequence_input,
             cell_for_mode,
             target_column.num_label_columns,
             dtype=dtype,
             parallel_iterations=parallel_iterations,
             swap_memory=swap_memory)
         if prediction_type == PredictionType.MULTIPLE_VALUE:
             prediction_dict = _multi_value_predictions(
                 rnn_activations, target_column, predict_probabilities)
             loss = _multi_value_loss(rnn_activations, labels,
                                      sequence_length, target_column,
                                      features)
         elif prediction_type == PredictionType.SINGLE_VALUE:
             prediction_dict = _single_value_predictions(
                 rnn_activations, sequence_length, target_column,
                 predict_probabilities)
             loss = _single_value_loss(rnn_activations, labels,
                                       sequence_length, target_column,
                                       features)
         else:
             # Without this branch `prediction_dict` and `loss` would stay
             # unbound and fail further down with an opaque
             # UnboundLocalError.
             raise ValueError(
                 "prediction_type=%s unrecognized." % str(prediction_type))
         # TODO(roumposg): Return eval_metric_ops here, instead of default_metrics.
         default_metrics = _get_default_metrics(problem_type,
                                                prediction_type,
                                                sequence_length)
         # Added before the metric ops are built, so they are constructed
         # from a prediction dict that already contains the final state.
         prediction_dict[RNNKeys.FINAL_STATE_KEY] = final_state
         eval_metric_ops = estimator._make_metrics_ops(  # pylint: disable=protected-access
             default_metrics, features, labels, prediction_dict)
         # The train op is built for every mode, not just TRAIN.
         # NOTE(review): global_step=None presumably makes optimize_loss
         # look up the graph's global step itself -- confirm.
         train_op = optimizers.optimize_loss(
             loss=loss,
             global_step=None,
             learning_rate=learning_rate,
             optimizer=optimizer,
             clip_gradients=gradient_clipping_norm,
             summaries=optimizers.OPTIMIZER_SUMMARIES)
     return estimator.ModelFnOps(mode=mode,
                                 predictions=prediction_dict,
                                 loss=loss,
                                 train_op=train_op,
                                 eval_metric_ops=eval_metric_ops)
 def _dynamic_rnn_model_fn(features, labels, mode):
     """The model to be passed to an `Estimator`.

     Builds the RNN, its predictions, loss, default metrics and training op.
     The configuration names used here (`name`, `cell`, `target_column`,
     `prediction_type`, `optimizer`, ...) are not parameters of this
     function; they are expected to come from the enclosing scope.

     Args:
       features: mapping supporting `.get`; queried for the initial state
         and sequence length, and used to build the sequence input.
       labels: labels forwarded to the loss function.
       mode: estimator's ModeKeys; only forwarded to the returned
         `ModelFnOps`, the graph built here does not depend on it.

     Returns:
       `estimator.ModelFnOps` holding predictions (including the final RNN
       state under `RNNKeys.FINAL_STATE_KEY`), loss, training op and
       default metrics.

     Raises:
       ValueError: if `prediction_type` is neither
         `PredictionType.MULTIPLE_VALUE` nor `PredictionType.SINGLE_VALUE`.
     """
     with ops.name_scope(name):
         # `.get` is used, so both entries are optional and may be None.
         initial_state = features.get(initial_state_key)
         sequence_length = features.get(sequence_length_key)
         sequence_input = build_sequence_input(features,
                                               sequence_feature_columns,
                                               context_feature_columns)
         rnn_activations, final_state = construct_rnn(
             initial_state,
             sequence_input,
             cell,
             target_column.num_label_columns,
             dtype=dtype,
             parallel_iterations=parallel_iterations,
             swap_memory=swap_memory)
         if prediction_type == PredictionType.MULTIPLE_VALUE:
             prediction_dict = _multi_value_predictions(
                 rnn_activations, target_column, predict_probabilities)
             loss = _multi_value_loss(rnn_activations, labels,
                                      sequence_length, target_column,
                                      features)
         elif prediction_type == PredictionType.SINGLE_VALUE:
             prediction_dict = _single_value_predictions(
                 rnn_activations, sequence_length, target_column,
                 predict_probabilities)
             loss = _single_value_loss(rnn_activations, labels,
                                       sequence_length, target_column,
                                       features)
         else:
             # Without this branch `prediction_dict` and `loss` would stay
             # unbound and fail further down with an opaque
             # UnboundLocalError.
             raise ValueError(
                 "prediction_type=%s unrecognized." % str(prediction_type))
         default_metrics = _get_default_metrics(problem_type,
                                                prediction_type,
                                                sequence_length)
         prediction_dict[RNNKeys.FINAL_STATE_KEY] = final_state
         # The training op is built for every mode, not just TRAIN.
         # NOTE(review): global_step=None presumably makes optimize_loss
         # look up the graph's global step itself -- confirm.
         training_op = optimizers.optimize_loss(
             loss=loss,
             global_step=None,
             learning_rate=learning_rate,
             optimizer=optimizer,
             clip_gradients=gradient_clipping_norm,
             summaries=optimizers.OPTIMIZER_SUMMARIES)
     return estimator.ModelFnOps(mode=mode,
                                 predictions=prediction_dict,
                                 loss=loss,
                                 training_op=training_op,
                                 default_metrics=default_metrics)