def head_ops(self, features, labels, mode, train_op_fn, logits=None,
             logits_input=None):
  """Returns ops for a model_fn.

  Args:
    features: input dict.
    labels: labels dict or tensor.
    mode: estimator's `ModeKeys`.
    train_op_fn: function that takes a scalar loss and returns an op to
      optimize with the loss. Only called in TRAIN mode.
    logits: logits to be used for the head.
    logits_input: tensor to build logits from. Checked by
      `_check_logits_input_not_supported` before any ops are built.

  Returns:
    `estimator.ModelFnOps` for `mode`. In TRAIN mode the training op is the
    op returned by `train_op_fn`, grouped with any additional train ops
    returned by `self._training_loss`.

  Raises:
    ValueError: if mode is not recognized.
  """
  _check_logits_input_not_supported(logits, logits_input)

  if mode == estimator.ModeKeys.TRAIN:
    loss, additional_train_op = self._training_loss(features, labels, logits,
                                                    logits_input)
    train_op = train_op_fn(loss)
    if additional_train_op:
      # `train_op_fn` may hand back a `tf.Tensor` instead of an `Operation`,
      # and a Tensor cannot be evaluated as a Python bool (TensorFlow raises
      # TypeError), so test for presence explicitly.
      if train_op is not None:
        train_op = control_flow_ops.group(train_op, *additional_train_op)
      else:
        train_op = control_flow_ops.group(*additional_train_op)
    return estimator.ModelFnOps(
        mode=estimator.ModeKeys.TRAIN,
        loss=loss,
        training_op=train_op,
        default_metrics=self._default_metric(),
        signature_fn=self._create_signature_fn())

  if mode == estimator.ModeKeys.INFER:
    return estimator.ModelFnOps(
        mode=estimator.ModeKeys.INFER,
        predictions=self._infer_op(logits, logits_input),
        default_metrics=self._default_metric(),
        signature_fn=self._create_signature_fn())

  if mode == estimator.ModeKeys.EVAL:
    predictions, loss = self._eval_op(features, labels, logits, logits_input)
    return estimator.ModelFnOps(
        mode=estimator.ModeKeys.EVAL,
        predictions=predictions,
        loss=loss,
        default_metrics=self._default_metric(),
        signature_fn=self._create_signature_fn())

  raise ValueError("mode=%s unrecognized." % str(mode))
def _dynamic_rnn_model_fn(features, labels, mode):
  """The model to be passed to an `Estimator`.

  Builds the RNN over the sequence input and returns a `ModelFnOps`. Only the
  ops required by `mode` are created:
    * predictions (plus the final RNN state) in every mode;
    * loss and eval metric ops only in TRAIN and EVAL modes;
    * the optimizer/training op only in TRAIN mode.
  Dropout is applied to the cell only in TRAIN mode.

  Args:
    features: input dict; the optional `initial_state_key` and
      `sequence_length_key` entries are looked up with `.get`, so missing
      keys yield `None`.
    labels: labels passed to the loss and metric builders; not used in
      INFER mode.
    mode: an `estimator.ModeKeys` value.

  Returns:
    An `estimator.ModelFnOps`.

  Raises:
    ValueError: if `prediction_type` is neither
      `PredictionType.MULTIPLE_VALUE` nor `PredictionType.SINGLE_VALUE`.
  """
  with ops.name_scope(name):
    initial_state = features.get(initial_state_key)
    sequence_length = features.get(sequence_length_key)
    sequence_input = build_sequence_input(features, sequence_feature_columns,
                                          context_feature_columns)
    if mode == estimator.ModeKeys.TRAIN:
      cell_for_mode = apply_dropout(cell, input_keep_probability,
                                    output_keep_probability)
    else:
      cell_for_mode = cell
    rnn_activations, final_state = construct_rnn(
        initial_state,
        sequence_input,
        cell_for_mode,
        target_column.num_label_columns,
        dtype=dtype,
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory)

    # Loss and metrics need labels and are only consumed in TRAIN/EVAL.
    needs_loss = mode != estimator.ModeKeys.INFER
    loss = None
    if prediction_type == PredictionType.MULTIPLE_VALUE:
      prediction_dict = _multi_value_predictions(
          rnn_activations, target_column, predict_probabilities)
      if needs_loss:
        loss = _multi_value_loss(rnn_activations, labels, sequence_length,
                                 target_column, features)
    elif prediction_type == PredictionType.SINGLE_VALUE:
      prediction_dict = _single_value_predictions(
          rnn_activations, sequence_length, target_column,
          predict_probabilities)
      if needs_loss:
        loss = _single_value_loss(rnn_activations, labels, sequence_length,
                                  target_column, features)
    else:
      # Without this branch `prediction_dict` would be unbound below.
      raise ValueError(
          "Unrecognized prediction_type: %s" % str(prediction_type))

    prediction_dict[RNNKeys.FINAL_STATE_KEY] = final_state

    eval_metric_ops = None
    if needs_loss:
      default_metrics = _get_default_metrics(problem_type, prediction_type,
                                             sequence_length)
      eval_metric_ops = estimator._make_metrics_ops(  # pylint: disable=protected-access
          default_metrics, features, labels, prediction_dict)

    # Building the optimizer outside TRAIN mode would only add unused
    # gradient/update ops to the graph.
    train_op = None
    if mode == estimator.ModeKeys.TRAIN:
      train_op = optimizers.optimize_loss(
          loss=loss,
          global_step=None,
          learning_rate=learning_rate,
          optimizer=optimizer,
          clip_gradients=gradient_clipping_norm,
          summaries=optimizers.OPTIMIZER_SUMMARIES)

    return estimator.ModelFnOps(mode=mode,
                                predictions=prediction_dict,
                                loss=loss,
                                train_op=train_op,
                                eval_metric_ops=eval_metric_ops)
def _dynamic_rnn_model_fn(features, labels, mode):
  """The model to be passed to an `Estimator`.

  Builds the RNN over the sequence input and returns a `ModelFnOps`. Only the
  ops required by `mode` are created:
    * predictions (plus the final RNN state) in every mode;
    * loss only in TRAIN and EVAL modes;
    * the optimizer/training op only in TRAIN mode.

  Args:
    features: input dict; the optional `initial_state_key` and
      `sequence_length_key` entries are looked up with `.get`, so missing
      keys yield `None`.
    labels: labels passed to the loss builders; not used in INFER mode.
    mode: an `estimator.ModeKeys` value.

  Returns:
    An `estimator.ModelFnOps` carrying `default_metrics` for evaluation.

  Raises:
    ValueError: if `prediction_type` is neither
      `PredictionType.MULTIPLE_VALUE` nor `PredictionType.SINGLE_VALUE`.
  """
  with ops.name_scope(name):
    initial_state = features.get(initial_state_key)
    sequence_length = features.get(sequence_length_key)
    sequence_input = build_sequence_input(features, sequence_feature_columns,
                                          context_feature_columns)
    rnn_activations, final_state = construct_rnn(
        initial_state,
        sequence_input,
        cell,
        target_column.num_label_columns,
        dtype=dtype,
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory)

    # The loss needs labels and is only consumed in TRAIN/EVAL.
    needs_loss = mode != estimator.ModeKeys.INFER
    loss = None
    if prediction_type == PredictionType.MULTIPLE_VALUE:
      prediction_dict = _multi_value_predictions(
          rnn_activations, target_column, predict_probabilities)
      if needs_loss:
        loss = _multi_value_loss(rnn_activations, labels, sequence_length,
                                 target_column, features)
    elif prediction_type == PredictionType.SINGLE_VALUE:
      prediction_dict = _single_value_predictions(
          rnn_activations, sequence_length, target_column,
          predict_probabilities)
      if needs_loss:
        loss = _single_value_loss(rnn_activations, labels, sequence_length,
                                  target_column, features)
    else:
      # Without this branch `prediction_dict` would be unbound below.
      raise ValueError(
          "Unrecognized prediction_type: %s" % str(prediction_type))

    default_metrics = _get_default_metrics(problem_type, prediction_type,
                                           sequence_length)
    prediction_dict[RNNKeys.FINAL_STATE_KEY] = final_state

    # Building the optimizer outside TRAIN mode would only add unused
    # gradient/update ops to the graph.
    training_op = None
    if mode == estimator.ModeKeys.TRAIN:
      training_op = optimizers.optimize_loss(
          loss=loss,
          global_step=None,
          learning_rate=learning_rate,
          optimizer=optimizer,
          clip_gradients=gradient_clipping_norm,
          summaries=optimizers.OPTIMIZER_SUMMARIES)

    return estimator.ModelFnOps(mode=mode,
                                predictions=prediction_dict,
                                loss=loss,
                                training_op=training_op,
                                default_metrics=default_metrics)