def _get_eval_ops(self, features, labels, metrics=None):
    """See base class."""
    features = self._get_feature_dict(features)
    features, labels = self._feature_engineering_fn(features, labels)
    logits = self._logits(features)

    eval_ops = self._head.head_ops(features,
                                   labels,
                                   model_fn.ModeKeys.EVAL,
                                   None,
                                   logits=logits)
    custom_metrics = {}
    if metrics:
        for name, metric in six.iteritems(metrics):
            if not isinstance(name, tuple):
                # TODO(zakaria): remove once deprecation is finished (b/31229024)
                custom_metrics[(name,
                                self._default_prediction_key)] = metric
            else:
                custom_metrics[name] = metric
    # TODO(zakaria): Remove this once we refactor this class to delegate
    #   to estimator.
    eval_ops.eval_metric_ops.update(
        estimator._make_metrics_ops(  # pylint: disable=protected-access
            custom_metrics, features, labels, eval_ops.predictions))
    return eval_ops
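For reference, a minimal sketch of the kind of `metrics` dict this method normalizes before handing it to `estimator._make_metrics_ops` (tf.contrib.learn, TF 1.x). The metric names and the 'probabilities' prediction key below are illustrative assumptions, not values taken from the snippet:

from tensorflow.contrib import metrics as metrics_lib

# Hypothetical caller-side metrics dict (names and prediction keys are illustrative).
metrics = {
    # Bare string key: the snippet above pairs it with the head's
    # default prediction key (the deprecated form flagged by the TODO).
    'my_accuracy': metrics_lib.streaming_accuracy,
    # (metric name, prediction key) tuple: passed through unchanged and
    # resolved against eval_ops.predictions by _make_metrics_ops.
    ('my_auc', 'probabilities'): metrics_lib.streaming_auc,
}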
Example #2
def _eval_metric_ops(metrics, features, labels, predictions):
    with ops.name_scope(None, "metrics",
                        (tuple(six.itervalues(features)) +
                         (labels, ) + tuple(six.itervalues(predictions)))):
        # pylint: disable=protected-access
        return estimator._make_metrics_ops(metrics, features, labels,
                                           predictions)
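A minimal, self-contained sketch of calling this helper (TF 1.x graph mode). The tensor names, shapes, and the single streaming metric are illustrative assumptions:

import tensorflow as tf
from tensorflow.contrib import metrics as metrics_lib

# Hypothetical tensors; real callers pass the model's feature, label, and
# prediction tensors.
features = {'x': tf.placeholder(tf.float32, [None, 3], name='x')}
labels = tf.placeholder(tf.int64, [None], name='labels')
predictions = {'classes': tf.placeholder(tf.int64, [None], name='classes')}

# Tuple keys select which entry of `predictions` each metric consumes.
metric_ops = _eval_metric_ops(
    metrics={('accuracy', 'classes'): metrics_lib.streaming_accuracy},
    features=features,
    labels=labels,
    predictions=predictions)
# metric_ops['accuracy'] is the usual (value_op, update_op) pair.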
Example #3
def _dynamic_rnn_model_fn(features, labels, mode):
    """The model to be passed to an `Estimator`."""
    with ops.name_scope(name):
        initial_state = features.get(initial_state_key)
        sequence_length = features.get(sequence_length_key)
        sequence_input = build_sequence_input(features,
                                              sequence_feature_columns,
                                              context_feature_columns)
        if mode == model_fn.ModeKeys.TRAIN:
            cell_for_mode = apply_dropout(cell, input_keep_probability,
                                          output_keep_probability)
        else:
            cell_for_mode = cell
        rnn_activations, final_state = construct_rnn(
            initial_state,
            sequence_input,
            cell_for_mode,
            target_column.num_label_columns,
            dtype=dtype,
            parallel_iterations=parallel_iterations,
            swap_memory=swap_memory)
        if prediction_type == PredictionType.MULTIPLE_VALUE:
            prediction_dict = _multi_value_predictions(
                rnn_activations, target_column, predict_probabilities)
            loss = _multi_value_loss(rnn_activations, labels,
                                     sequence_length, target_column,
                                     features)
        elif prediction_type == PredictionType.SINGLE_VALUE:
            prediction_dict = _single_value_predictions(
                rnn_activations, sequence_length, target_column,
                predict_probabilities)
            loss = _single_value_loss(rnn_activations, labels,
                                      sequence_length, target_column,
                                      features)
        # TODO(roumposg): Return eval_metric_ops here, instead of default_metrics.
        default_metrics = _get_default_metrics(problem_type,
                                               prediction_type,
                                               sequence_length)
        prediction_dict[RNNKeys.FINAL_STATE_KEY] = final_state
        eval_metric_ops = estimator._make_metrics_ops(  # pylint: disable=protected-access
            default_metrics, features, labels, prediction_dict)
        train_op = optimizers.optimize_loss(
            loss=loss,
            global_step=None,
            learning_rate=learning_rate,
            optimizer=optimizer,
            clip_gradients=gradient_clipping_norm,
            summaries=optimizers.OPTIMIZER_SUMMARIES)
    return model_fn.ModelFnOps(mode=mode,
                               predictions=prediction_dict,
                               loss=loss,
                               train_op=train_op,
                               eval_metric_ops=eval_metric_ops)
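Here the metrics come from `_get_default_metrics` rather than the caller. A heavily hedged sketch of the shape such a dict could take; the key names are illustrative, since the real helper keys its metrics against the entries that `_multi_value_predictions`/`_single_value_predictions` put in `prediction_dict`:

from tensorflow.contrib import metrics as metrics_lib

# Hypothetical default metrics for a classification problem: tuple keys
# let _make_metrics_ops pick the matching entry out of prediction_dict.
default_metrics = {
    ('accuracy', 'predictions'): metrics_lib.streaming_accuracy,
}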
Example #4
def _eval_metric_ops(self, features, labels, logits):
    """Returns a dict of metric ops keyed by name."""
    labels = _check_labels(labels, self._label_name)
    predictions = self._predictions(logits)
    return estimator._make_metrics_ops(  # pylint: disable=protected-access
        self._default_metrics(), features, labels, predictions)
Example #5
    def _rnn_model_fn(features, labels, mode):
        """The model to be passed to an `Estimator`."""
        with ops.name_scope(name):
            if mode == model_fn.ModeKeys.TRAIN:
                cell_for_mode = apply_dropout(cell, input_keep_probability,
                                              output_keep_probability)
            else:
                cell_for_mode = cell

            batch = _read_batch(
                cell=cell_for_mode,
                features=features,
                labels=labels,
                mode=mode,
                num_unroll=num_unroll,
                num_layers=num_layers,
                batch_size=batch_size,
                input_key_column_name=input_key_column_name,
                sequence_feature_columns=sequence_feature_columns,
                context_feature_columns=context_feature_columns,
                num_threads=num_threads,
                queue_capacity=queue_capacity)
            sequence_features = batch.sequences
            context_features = batch.context
            if mode != model_fn.ModeKeys.INFER:
                labels = sequence_features.pop(RNNKeys.LABELS_KEY)
            inputs = _prepare_inputs_for_rnn(sequence_features,
                                             context_features,
                                             sequence_feature_columns,
                                             num_unroll)
            state_name = _get_lstm_state_names(num_layers)
            rnn_activations, final_state = construct_state_saving_rnn(
                cell=cell_for_mode,
                inputs=inputs,
                num_label_columns=target_column.num_label_columns,
                state_saver=batch,
                state_name=state_name)

            loss = None  # Created below for modes TRAIN and EVAL.
            prediction_dict = _multi_value_predictions(rnn_activations,
                                                       target_column,
                                                       predict_probabilities)
            if mode != model_fn.ModeKeys.INFER:
                loss = _multi_value_loss(rnn_activations, labels, batch.length,
                                         target_column, features)

            eval_metric_ops = None
            if mode != model_fn.ModeKeys.INFER:
                default_metrics = _get_default_metrics(problem_type,
                                                       batch.length)
                eval_metric_ops = estimator._make_metrics_ops(  # pylint: disable=protected-access
                    default_metrics, features, labels, prediction_dict)
            state_dict = state_tuple_to_dict(final_state)
            prediction_dict.update(state_dict)

            train_op = None
            if mode == model_fn.ModeKeys.TRAIN:
                train_op = optimizers.optimize_loss(
                    loss=loss,
                    global_step=None,  # Get it internally.
                    learning_rate=learning_rate,
                    optimizer=optimizer,
                    clip_gradients=gradient_clipping_norm,
                    summaries=optimizers.OPTIMIZER_SUMMARIES)

        return model_fn.ModelFnOps(mode=mode,
                                   predictions=prediction_dict,
                                   loss=loss,
                                   train_op=train_op,
                                   eval_metric_ops=eval_metric_ops)
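This last example makes the per-mode contract of `ModelFnOps` explicit: predictions are always populated, loss and eval_metric_ops only outside INFER, and train_op only in TRAIN. A minimal, hedged sketch of that pattern, assuming TF 1.x with tf.contrib.learn; the toy feature name 'x', the single streaming metric, and the optimizer settings are illustrative, not taken from the snippets:

import tensorflow as tf
from tensorflow.contrib import layers
from tensorflow.contrib import metrics as metrics_lib
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import model_fn

def _toy_model_fn(features, labels, mode):
    """Illustrative model_fn showing which ModelFnOps fields each mode fills."""
    logits = layers.fully_connected(features['x'], 1, activation_fn=None)
    prediction_dict = {'scores': logits}

    loss = None
    eval_metric_ops = None
    if mode != model_fn.ModeKeys.INFER:
        loss = tf.losses.mean_squared_error(labels, logits)
        eval_metric_ops = estimator._make_metrics_ops(  # pylint: disable=protected-access
            {('mse', 'scores'): metrics_lib.streaming_mean_squared_error},
            features, labels, prediction_dict)

    train_op = None
    if mode == model_fn.ModeKeys.TRAIN:
        train_op = layers.optimize_loss(
            loss=loss,
            global_step=None,  # Get it internally.
            learning_rate=0.1,
            optimizer='SGD')

    return model_fn.ModelFnOps(mode=mode,
                               predictions=prediction_dict,
                               loss=loss,
                               train_op=train_op,
                               eval_metric_ops=eval_metric_ops)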