def _create_tpu_estimator_spec(self,
                               features,
                               mode,
                               logits,
                               labels=None,
                               optimizer=None,
                               trainable_variables=None,
                               train_op_fn=None,
                               update_ops=None,
                               regularization_losses=None):
  """Returns a `model_fn._TPUEstimatorSpec`.

  Args:
    features: Input `dict` of `Tensor` or `SparseTensor` objects.
    mode: Estimator's `ModeKeys`.
    logits: logits `Tensor` with shape `[D0, D1, ... DN, n_classes]`. For many
      applications, the shape is `[batch_size, n_classes]`.
    labels: Labels with shape matching `logits`. Can be multi-hot `Tensor` with
      shape `[D0, D1, ... DN, n_classes]` or `SparseTensor` with `dense_shape`
      `[D0, D1, ... DN, ?]`. `labels` is required argument when `mode` equals
      `TRAIN` or `EVAL`.
    optimizer: An `tf.keras.optimizers.Optimizer` instance to optimize the loss
      in TRAIN mode. Namely, sets `train_op = optimizer.get_updates(loss,
      trainable_variables)`, which updates variables to minimize `loss`.
    trainable_variables: A list or tuple of `Variable` objects to update to
      minimize `loss`. In Tensorflow 1.x, by default these are the list of
      variables collected in the graph under the key
      `GraphKeys.TRAINABLE_VARIABLES`. As Tensorflow 2.x doesn't have
      collections and GraphKeys, trainable_variables need to be passed
      explicitly here.
    train_op_fn: Function that takes a scalar loss `Tensor` and returns
      `train_op`. Used if `optimizer` is `None`.
    update_ops: A list or tuple of update ops to be run at training time. For
      example, layers such as BatchNormalization create mean and variance
      update ops that need to be run at training time. In Tensorflow 1.x,
      these are thrown into an UPDATE_OPS collection. As Tensorflow 2.x
      doesn't have collections, update_ops need to be passed explicitly here.
    regularization_losses: A list of additional scalar losses to be added to
      the training loss, such as regularization losses. These losses are
      usually expressed as a batch average, so for best results users need to
      set `loss_reduction=SUM_OVER_BATCH_SIZE` when creating the head to avoid
      scaling errors.

  Returns:
    `model_fn._TPUEstimatorSpec`.

  Raises:
    ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN mode,
      or if both are set.
  """
  with ops.name_scope(self._name, 'head'):
    # Predict.
    pred_keys = prediction_keys.PredictionKeys
    predictions = self.predictions(logits)
    if mode == ModeKeys.PREDICT:
      probabilities = predictions[pred_keys.PROBABILITIES]
      classifier_output = base_head.classification_output(
          scores=probabilities,
          n_classes=self._n_classes,
          label_vocabulary=self._label_vocabulary)
      return model_fn._TPUEstimatorSpec(  # pylint:disable=protected-access
          mode=ModeKeys.PREDICT,
          predictions=predictions,
          export_outputs={
              base_head.DEFAULT_SERVING_KEY: classifier_output,
              base_head.CLASSIFY_SERVING_KEY: classifier_output,
              base_head.PREDICT_SERVING_KEY: (
                  export_output.PredictOutput(predictions))
          })
    regularized_training_loss = self.loss(
        logits=logits,
        labels=labels,
        features=features,
        mode=mode,
        regularization_losses=regularization_losses)
    # Eval.
    if mode == ModeKeys.EVAL:
      eval_metrics = self.metrics(regularization_losses=regularization_losses)
      return model_fn._TPUEstimatorSpec(  # pylint:disable=protected-access
          mode=ModeKeys.EVAL,
          predictions=predictions,
          loss=regularized_training_loss,
          eval_metrics=base_head.create_eval_metrics_tuple(
              self.update_metrics, {
                  'eval_metrics': eval_metrics,
                  'features': features,
                  'logits': logits,
                  'labels': labels,
                  'regularization_losses': regularization_losses
              }))
    # Train.
    train_op = base_head.create_estimator_spec_train_op(
        head_name=self._name,
        optimizer=optimizer,
        train_op_fn=train_op_fn,
        update_ops=update_ops,
        trainable_variables=trainable_variables,
        regularized_training_loss=regularized_training_loss,
        loss_reduction=self._loss_reduction)
    # Create summary.
    base_head.create_estimator_spec_summary(
        regularized_training_loss=regularized_training_loss,
        regularization_losses=regularization_losses,
        summary_key_fn=self._summary_key)
    return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
        mode=ModeKeys.TRAIN,
        predictions=predictions,
        loss=regularized_training_loss,
        train_op=train_op)
def _create_tpu_estimator_spec(self,
                               features,
                               mode,
                               logits,
                               labels=None,
                               optimizer=None,
                               trainable_variables=None,
                               train_op_fn=None,
                               update_ops=None,
                               regularization_losses=None):
  """See superclass for description.

  Builds a `model_fn._TPUEstimatorSpec` for PREDICT, EVAL or TRAIN. In EVAL
  and TRAIN the reported loss (and the summary) uses the mean of the
  regularized training loss, while the train op is built from the
  regularized training loss as returned by `self.loss`.
  """
  with tf.compat.v1.name_scope(self._name, 'head'):
    keys = prediction_keys.PredictionKeys
    predictions = self.predictions(logits)

    if mode == ModeKeys.PREDICT:
      # Serve classification, regression (logistic value) and raw outputs.
      class_scores = predictions[keys.PROBABILITIES]
      logistic_value = predictions[keys.LOGISTIC]
      classify_output = base_head.classification_output(
          scores=class_scores,
          n_classes=2,
          label_vocabulary=self._label_vocabulary)
      serving_outputs = {
          base_head.DEFAULT_SERVING_KEY: classify_output,
          base_head.CLASSIFY_SERVING_KEY: classify_output,
          base_head.REGRESS_SERVING_KEY:
              export_output.RegressionOutput(value=logistic_value),
          base_head.PREDICT_SERVING_KEY:
              export_output.PredictOutput(predictions),
      }
      return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
          mode=ModeKeys.PREDICT,
          predictions=predictions,
          export_outputs=serving_outputs)

    training_loss = self.loss(
        logits=logits,
        labels=labels,
        features=features,
        mode=mode,
        regularization_losses=regularization_losses)
    mean_loss = tf.reduce_mean(training_loss)

    if mode == ModeKeys.EVAL:
      metric_objects = self.metrics(
          regularization_losses=regularization_losses)
      metric_fn_kwargs = {
          'eval_metrics': metric_objects,
          'features': features,
          'logits': logits,
          'labels': labels,
          'regularization_losses': regularization_losses,
      }
      return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
          mode=ModeKeys.EVAL,
          predictions=predictions,
          loss=mean_loss,
          eval_metrics=base_head.create_eval_metrics_tuple(
              self.update_metrics, metric_fn_kwargs))

    # Anything else is treated as TRAIN.
    train_op = base_head.create_estimator_spec_train_op(
        head_name=self._name,
        optimizer=optimizer,
        train_op_fn=train_op_fn,
        update_ops=update_ops,
        trainable_variables=trainable_variables,
        regularized_training_loss=training_loss,
        loss_reduction=self._loss_reduction)
    base_head.create_estimator_spec_summary(
        regularized_training_loss=mean_loss,
        regularization_losses=regularization_losses,
        summary_key_fn=self._summary_key)
    return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
        mode=ModeKeys.TRAIN,
        predictions=predictions,
        loss=mean_loss,
        train_op=train_op)
def create_estimator_spec(self,
                          features,
                          mode,
                          logits,
                          labels=None,
                          optimizer=None,
                          trainable_variables=None,
                          train_op_fn=None,
                          update_ops=None,
                          regularization_losses=None):
  """Returns a `model_fn.EstimatorSpec`.

  Args:
    features: Input `dict` of `Tensor` or `SparseTensor` objects.
    mode: Estimator's `ModeKeys`.
    logits: Input `dict` keyed by head name, or logits `Tensor` with shape
      `[D0, D1, ... DN, logits_dimension]`. For many applications, the
      `Tensor` shape is `[batch_size, logits_dimension]`. If logits is a
      `Tensor`, it will split the `Tensor` along the last dimension and
      distribute it appropriately among the heads. Check `MultiHead` for
      examples.
    labels: Input `dict` keyed by head name. For each head, the label value
      can be integer or string `Tensor` with shape matching its corresponding
      `logits`. `labels` is a required argument when `mode` equals `TRAIN` or
      `EVAL`.
    optimizer: An `tf.keras.optimizers.Optimizer` instance to optimize the
      loss in TRAIN mode. Namely, sets `train_op = optimizer.get_updates(loss,
      trainable_variables)`, which updates variables to minimize `loss`.
    trainable_variables: A list or tuple of `Variable` objects to update to
      minimize `loss`. In Tensorflow 1.x, by default these are the list of
      variables collected in the graph under the key
      `GraphKeys.TRAINABLE_VARIABLES`. As Tensorflow 2.x doesn't have
      collections and GraphKeys, trainable_variables need to be passed
      explicitly here.
    train_op_fn: Function that takes a scalar loss `Tensor` and returns
      `train_op`. Used if `optimizer` is `None`.
    update_ops: A list or tuple of update ops to be run at training time. For
      example, layers such as BatchNormalization create mean and variance
      update ops that need to be run at training time. In Tensorflow 1.x,
      these are thrown into an UPDATE_OPS collection. As Tensorflow 2.x
      doesn't have collections, update_ops need to be passed explicitly here.
      When given, they are grouped with the train op.
    regularization_losses: A list of additional scalar losses to be added to
      the training loss, such as regularization losses. These losses are
      usually expressed as a batch average, so for best results, in each head,
      users need to use the default `loss_reduction=SUM_OVER_BATCH_SIZE` or
      set `loss_reduction=SUM_OVER_NONZERO_WEIGHTS` to avoid scaling errors.
      Compared to the regularization losses for each head, this loss is to
      regularize the merged loss of all heads in multi head, and will be added
      to the overall training loss of multi head.

  Returns:
    A `model_fn.EstimatorSpec` instance.

  Raises:
    ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN mode,
      or if both are set.
      If `mode` is not in Estimator's `ModeKeys`.
  """
  per_head_label_map = _get_per_head_label(labels) if labels else None
  with ops.name_scope(self.name, 'multi_survival_head'):
    logits_dict = self._check_logits_and_labels(logits, labels)
    # Get all estimator spec. Per-head specs use a no-op train fn; the real
    # train op is built below from the merged loss.
    all_estimator_spec = []
    for head in self._heads:
      tf.logging.info(head.name)
      all_estimator_spec.append(
          head.create_estimator_spec(
              features=features,
              mode=mode,
              logits=logits_dict[head.name],
              labels=per_head_label_map[head.name] if labels else None,
              train_op_fn=_no_op_train_fn))
    # Predict.
    predictions = self.predictions(logits)
    if mode == model_fn.ModeKeys.PREDICT:
      export_outputs = self._merge_predict_export_outputs(all_estimator_spec)
      return model_fn.EstimatorSpec(
          mode=model_fn.ModeKeys.PREDICT,
          predictions=predictions,
          export_outputs=export_outputs)
    loss = self.loss(logits, labels, features, mode, regularization_losses)
    # Eval.
    if mode == model_fn.ModeKeys.EVAL:
      eval_metrics = self.metrics(regularization_losses=regularization_losses)
      updated_metrics = self.update_metrics(
          eval_metrics,
          features,
          logits,
          labels,
          regularization_losses=regularization_losses)
      return model_fn.EstimatorSpec(
          mode=model_fn.ModeKeys.EVAL,
          predictions=predictions,
          loss=loss,
          eval_metric_ops=updated_metrics)
    # Train.
    if mode == model_fn.ModeKeys.TRAIN:
      # train_op.
      if optimizer is not None:
        if train_op_fn is not None:
          raise ValueError('train_op_fn and optimizer cannot both be set.')
        if isinstance(optimizer, optimizer_v2.OptimizerV2):
          base_head.validate_trainable_variables(trainable_variables)
          # Keras `OptimizerV2.get_updates` returns a list of update ops,
          # but `EstimatorSpec` needs a single Operation/Tensor as train_op.
          train_op = optimizer.get_updates(loss, trainable_variables)[0]
        else:
          train_op = optimizer.minimize(
              loss, global_step=training_util.get_global_step())
      elif train_op_fn is not None:
        train_op = train_op_fn(loss)
      else:
        raise ValueError('train_op_fn and optimizer cannot both be None.')
      # Run the caller-supplied update ops (e.g. BatchNormalization moving
      # statistics) as part of every training step; previously ignored.
      if update_ops is not None:
        train_op = tf.group(train_op, *update_ops)
      # Create summary.
      base_head.create_estimator_spec_summary(loss, regularization_losses)
      # eval_metrics.
      eval_metrics = {}
      for spec in all_estimator_spec:
        eval_metrics.update(spec.eval_metric_ops or {})
      # predictions can be used to access the logits in `TRAIN` mode
      return model_fn.EstimatorSpec(
          mode=model_fn.ModeKeys.TRAIN,
          loss=loss,
          train_op=train_op,
          predictions=predictions,
          eval_metric_ops=eval_metrics)
    raise ValueError('mode={} unrecognized'.format(mode))
def _create_tpu_estimator_spec(self,
                               features,
                               mode,
                               logits,
                               labels=None,
                               optimizer=None,
                               train_op_fn=None,
                               regularization_losses=None):
  """Returns a `model_fn._TPUEstimatorSpec`.

  Args:
    features: Input `dict` of `Tensor` or `SparseTensor` objects.
    mode: Estimator's `ModeKeys`.
    logits: logits `Tensor` with shape `[D0, D1, ... DN, n_classes]`. For many
      applications, the shape is `[batch_size, n_classes]`.
    labels: Labels with shape matching `logits`. Can be multi-hot `Tensor` with
      shape `[D0, D1, ... DN, n_classes]` or `SparseTensor` with `dense_shape`
      `[D0, D1, ... DN, ?]`. `labels` is required argument when `mode` equals
      `TRAIN` or `EVAL`.
    optimizer: `Optimizer` instance to optimize the loss in TRAIN mode. Namely,
      sets `train_op = optimizer.minimize(loss, global_step)`, which updates
      variables and increments `global_step`.
    train_op_fn: Function that takes a scalar loss `Tensor` and returns
      `train_op`. Used if `optimizer` is `None`.
    regularization_losses: A list of additional scalar losses to be added to
      the training loss, such as regularization losses. These losses are
      usually expressed as a batch average, so for best results users need to
      set `loss_reduction=SUM_OVER_BATCH_SIZE` or
      `loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to
      avoid scaling errors.

  Returns:
    `model_fn._TPUEstimatorSpec`.

  Raises:
    ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN mode,
      or if both are set.
  """
  with ops.name_scope(self._name, 'head'):
    # Predict.
    pred_keys = prediction_keys.PredictionKeys
    predictions = self.predictions(logits)
    if mode == model_fn.ModeKeys.PREDICT:
      probabilities = predictions[pred_keys.PROBABILITIES]
      classifier_output = base_head.classification_output(
          scores=probabilities,
          n_classes=self._n_classes,
          label_vocabulary=self._label_vocabulary)
      return model_fn._TPUEstimatorSpec(  # pylint:disable=protected-access
          mode=model_fn.ModeKeys.PREDICT,
          predictions=predictions,
          export_outputs={
              base_head.DEFAULT_SERVING_KEY: classifier_output,
              base_head.CLASSIFY_SERVING_KEY: classifier_output,
              base_head.PREDICT_SERVING_KEY: (
                  export_output.PredictOutput(predictions))
          })
    regularized_training_loss = self.loss(
        logits=logits,
        labels=labels,
        features=features,
        mode=mode,
        regularization_losses=regularization_losses)
    # Eval.
    if mode == model_fn.ModeKeys.EVAL:
      eval_metrics = self.metrics(regularization_losses=regularization_losses)
      return model_fn._TPUEstimatorSpec(  # pylint:disable=protected-access
          mode=model_fn.ModeKeys.EVAL,
          predictions=predictions,
          loss=regularized_training_loss,
          eval_metrics=base_head.create_eval_metrics_tuple(
              self.update_metrics, {
                  'eval_metrics': eval_metrics,
                  'features': features,
                  'logits': logits,
                  'labels': labels,
                  'regularization_losses': regularization_losses
              }))
    # Train.
    # Pass by keyword: the helper's signature may carry extra parameters
    # (`trainable_variables`, `update_ops`) between `optimizer` and the loss,
    # which would silently misbind positional arguments.
    train_op = base_head.create_estimator_spec_train_op(
        head_name=self._name,
        optimizer=optimizer,
        train_op_fn=train_op_fn,
        regularized_training_loss=regularized_training_loss)
    # Create summary.
    base_head.create_estimator_spec_summary(
        regularized_training_loss=regularized_training_loss,
        regularization_losses=regularization_losses,
        summary_key_fn=self._summary_key)
    return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
        mode=model_fn.ModeKeys.TRAIN,
        predictions=predictions,
        loss=regularized_training_loss,
        train_op=train_op)
def _create_tpu_estimator_spec(self,
                               features,
                               mode,
                               logits,
                               labels=None,
                               optimizer=None,
                               trainable_variables=None,
                               train_op_fn=None,
                               update_ops=None,
                               regularization_losses=None):
  """Builds the `model_fn._TPUEstimatorSpec` for the sequence head.

  Args:
    features: Input `dict` of `Tensor` or `SparseTensor` objects.
    mode: Estimator's `ModeKeys`.
    logits: estimated obs. value, [batch, time_len, num_obs] tensor.
    labels: ground truth observation, feature dict with obs. and interv. codes
      as keys, values tensor with shape [batch_size, context_window_size].
    optimizer: An `tf.keras.optimizers.Optimizer` instance used to minimize
      the loss in TRAIN mode, via `optimizer.get_updates(loss,
      trainable_variables)`.
    trainable_variables: A list or tuple of `Variable` objects to update to
      minimize `loss`. Tensorflow 2.x has no graph collections, so these must
      be passed explicitly.
    train_op_fn: Function that takes a scalar loss `Tensor` and returns
      `train_op`. Used if `optimizer` is `None`.
    update_ops: A list or tuple of update ops (e.g. BatchNormalization
      mean/variance updates) to run at training time. Tensorflow 2.x has no
      UPDATE_OPS collection, so these must be passed explicitly.
    regularization_losses: A list of additional scalar losses to be added to
      the training loss. They are usually batch averages, so set
      `loss_reduction=SUM_OVER_BATCH_SIZE` or
      `loss_reduction=SUM_OVER_NONZERO_WEIGHTS` on the head to avoid scaling
      errors.

  Returns:
    `model_fn._TPUEstimatorSpec`.

  Raises:
    ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN mode,
      or if both are set.
  """
  with ops.name_scope(self._name, 'sequence_head'):
    spec_predictions = self.predictions(logits)

    if mode == model_fn.ModeKeys.PREDICT:
      # Both the default and the predict serving signatures expose the raw
      # prediction dict (each through its own PredictOutput instance).
      default_signature = export_output.PredictOutput(spec_predictions)
      predict_signature = export_output.PredictOutput(spec_predictions)
      signatures = {
          base_head.DEFAULT_SERVING_KEY: default_signature,
          base_head.PREDICT_SERVING_KEY: predict_signature,
      }
      return model_fn._TPUEstimatorSpec(  # pylint:disable=protected-access
          mode=model_fn.ModeKeys.PREDICT,
          predictions=spec_predictions,
          export_outputs=signatures)

    total_loss = self.loss(
        logits=logits,
        labels=labels,
        features=features,
        mode=mode,
        regularization_losses=regularization_losses)

    if mode == model_fn.ModeKeys.EVAL:
      metric_objects = self.metrics(
          regularization_losses=regularization_losses)
      metric_inputs = {
          'eval_metrics': metric_objects,
          'features': features,
          'logits': logits,
          'labels': labels,
          'regularization_losses': regularization_losses,
      }
      metrics_tuple = base_head.create_eval_metrics_tuple(
          self.update_metrics, metric_inputs)
      return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
          mode=model_fn.ModeKeys.EVAL,
          predictions=spec_predictions,
          loss=total_loss,
          eval_metrics=metrics_tuple)

    # Remaining mode: TRAIN.
    step_op = base_head.create_estimator_spec_train_op(
        self._name,
        optimizer=optimizer,
        trainable_variables=trainable_variables,
        train_op_fn=train_op_fn,
        update_ops=update_ops,
        regularized_training_loss=total_loss)
    base_head.create_estimator_spec_summary(
        total_loss, regularization_losses, self._summary_key)
    return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
        mode=model_fn.ModeKeys.TRAIN,
        predictions=spec_predictions,
        loss=total_loss,
        train_op=step_op)
def _create_tpu_estimator_spec(self,
                               features,
                               mode,
                               logits,
                               labels=None,
                               optimizer=None,
                               train_op_fn=None,
                               regularization_losses=None):
  """Returns a `model_fn._TPUEstimatorSpec`.

  Args:
    features: Input `dict` mapping string feature names to `Tensor` or
      `SparseTensor` objects containing the values for that feature in a
      minibatch. Often to be used to fetch example-weight tensor.
    mode: Estimator's `ModeKeys`.
    logits: logits `Tensor` with shape `[D0, D1, ... DN, logits_dimension]`.
      For many applications, the shape is `[batch_size, logits_dimension]`.
    labels: Labels `Tensor` with shape matching `logits`, namely
      `[D0, D1, ... DN, logits_dimension]`. When `logits_dimension=1`, shape
      `[D0, D1, ... DN]` is also supported. `labels` is a required argument
      when `mode` equals `TRAIN` or `EVAL`.
    optimizer: `Optimizer` instance to optimize the loss in TRAIN mode. Namely,
      sets `train_op = optimizer.minimize(loss, global_step)`, which updates
      variables and increments `global_step`.
    train_op_fn: Function that takes a scalar loss `Tensor` and returns
      `train_op`. Used if `optimizer` is `None`.
    regularization_losses: A list of additional scalar losses to be added to
      the training loss, such as regularization losses. These losses are
      usually expressed as a batch average, so for best results users need to
      set `loss_reduction=SUM_OVER_BATCH_SIZE` when creating the head to avoid
      scaling errors.

  Returns:
    A `model_fn._TPUEstimatorSpec` instance.

  Raises:
    ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN mode,
      or if both are set.
  """
  with ops.name_scope(self._name, 'head'):
    # Predict.
    predictions = self.predictions(logits)
    if mode == model_fn.ModeKeys.PREDICT:
      keys = prediction_keys.PredictionKeys
      regression_output = export_output.RegressionOutput(
          value=predictions[keys.PREDICTIONS])
      return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
          mode=model_fn.ModeKeys.PREDICT,
          predictions=predictions,
          export_outputs={
              base_head.DEFAULT_SERVING_KEY: regression_output,
              base_head.REGRESS_SERVING_KEY: regression_output,
              base_head.PREDICT_SERVING_KEY: export_output.PredictOutput(
                  predictions)
          })
    regularized_training_loss = self.loss(
        logits=logits,
        labels=labels,
        features=features,
        mode=mode,
        regularization_losses=regularization_losses)
    # Eval.
    if mode == model_fn.ModeKeys.EVAL:
      eval_metrics = self.metrics(regularization_losses=regularization_losses)
      return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
          mode=model_fn.ModeKeys.EVAL,
          predictions=predictions,
          loss=regularized_training_loss,
          eval_metrics=base_head.create_eval_metrics_tuple(
              self.update_metrics, {
                  'eval_metrics': eval_metrics,
                  'features': features,
                  'logits': logits,
                  'labels': labels,
                  'regularization_losses': regularization_losses
              }))
    # Train. Update ops come from the head itself (`self._update_ops`), not
    # from the caller.
    train_op = base_head.create_estimator_spec_train_op(
        head_name=self._name,
        optimizer=optimizer,
        train_op_fn=train_op_fn,
        update_ops=self._update_ops,
        regularized_training_loss=regularized_training_loss)
    # Create summary.
    base_head.create_estimator_spec_summary(
        regularized_training_loss=regularized_training_loss,
        regularization_losses=regularization_losses,
        summary_key_fn=self._summary_key)
    return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
        mode=model_fn.ModeKeys.TRAIN,
        predictions=predictions,
        loss=regularized_training_loss,
        train_op=train_op)
def create_estimator_spec(
    self, features, mode, logits, labels=None, optimizer=None,
    train_op_fn=None, regularization_losses=None):
  """Builds the merged `model_fn.EstimatorSpec` for all sub-heads.

  Args:
    features: Input `dict` of `Tensor` or `SparseTensor` objects.
    mode: Estimator's `ModeKeys`.
    logits: Input `dict` keyed by head name, or logits `Tensor` with shape
      `[D0, D1, ... DN, logits_dimension]`. For many applications, the
      `Tensor` shape is `[batch_size, logits_dimension]`. A `Tensor` is split
      along its last dimension and distributed among the heads. Check
      `MultiHead` for examples.
    labels: Input `dict` keyed by head name. For each head, the label value
      can be integer or string `Tensor` with shape matching its corresponding
      `logits`. Required when `mode` equals `TRAIN` or `EVAL`.
    optimizer: `Optimizer` instance used in TRAIN mode as
      `optimizer.minimize(loss, global_step)`, which updates variables and
      increments `global_step`.
    train_op_fn: Function that takes a scalar loss `Tensor` and returns
      `train_op`. Used if `optimizer` is `None`.
    regularization_losses: A list of additional scalar losses added to the
      merged training loss of all heads (in addition to any per-head
      regularization). They are usually batch averages, so each head should
      keep the default `loss_reduction=SUM_OVER_BATCH_SIZE` to avoid scaling
      errors.

  Returns:
    A `model_fn.EstimatorSpec` instance.

  Raises:
    ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN mode,
      or if both are set, or if `mode` is not in Estimator's `ModeKeys`.
  """
  with ops.name_scope(self.name, 'multi_head'):
    per_head_logits = self._check_logits_and_labels(logits, labels)

    # Per-head specs are built with a no-op train fn; only their export
    # outputs and eval metric ops are reused below.
    head_specs = [
        head.create_estimator_spec(
            features=features,
            mode=mode,
            logits=per_head_logits[head.name],
            labels=labels[head.name] if labels else None,
            train_op_fn=_no_op_train_fn)
        for head in self._heads
    ]

    merged_predictions = self.predictions(logits)
    if mode == ModeKeys.PREDICT:
      return model_fn.EstimatorSpec(
          mode=ModeKeys.PREDICT,
          predictions=merged_predictions,
          export_outputs=self._merge_predict_export_outputs(head_specs))

    merged_loss = self.loss(
        logits, labels, features, mode, regularization_losses)

    if mode == ModeKeys.EVAL:
      metric_objects = self.metrics(
          regularization_losses=regularization_losses)
      metric_ops = self.update_metrics(
          metric_objects,
          features,
          logits,
          labels,
          regularization_losses=regularization_losses)
      return model_fn.EstimatorSpec(
          mode=ModeKeys.EVAL,
          predictions=merged_predictions,
          loss=merged_loss,
          eval_metric_ops=metric_ops)

    if mode != ModeKeys.TRAIN:
      raise ValueError('mode={} unrecognized'.format(mode))

    # Exactly one of `optimizer` / `train_op_fn` must be supplied.
    if optimizer is not None and train_op_fn is not None:
      raise ValueError('train_op_fn and optimizer cannot both be set.')
    if optimizer is not None:
      step_op = optimizer.minimize(
          merged_loss, global_step=training_util.get_global_step())
    elif train_op_fn is not None:
      step_op = train_op_fn(merged_loss)
    else:
      raise ValueError('train_op_fn and optimizer cannot both be None.')

    base_head.create_estimator_spec_summary(merged_loss, regularization_losses)

    # Later heads win on duplicate metric keys, as with successive updates.
    train_metric_ops = {}
    for head_spec in head_specs:
      train_metric_ops.update(head_spec.eval_metric_ops or {})

    # predictions can be used to access the logits in `TRAIN` mode
    return model_fn.EstimatorSpec(
        mode=ModeKeys.TRAIN,
        loss=merged_loss,
        train_op=step_op,
        predictions=merged_predictions,
        eval_metric_ops=train_metric_ops)
def _create_tpu_estimator_spec(self,
                               features,
                               mode,
                               logits,
                               labels=None,
                               optimizer=None,
                               trainable_variables=None,
                               train_op_fn=None,
                               update_ops=None,
                               regularization_losses=None):
  """Returns a `model_fn._TPUEstimatorSpec`.

  Args:
    features: Input `dict` of `Tensor` or `SparseTensor` objects.
    mode: Estimator's `ModeKeys`.
    logits: for single event, independent event, logits is a tensor of shape
      [batch_size, 1], for correlated event, a dict with event_name as key,
      value as tensor of shape [batch_size, 1].
    labels: dict keyed by 'event_name' and 'event_name.time_of_event' with
      value as tensors of shape [batch_size] or [batch_size, 1]. For
      correlated events, labels for all events are provided. Otherwise, only
      the event associated with this head is provided. Here is one example
      label: {u'respiration_failure.time_to_event': <tf.Tensor 'Cast:0'
      shape=(32,) dtype=float32>, u'respiration_failure': <tf.Tensor
      'Batch/batch:110' shape=(32,) dtype=int64>}
      `labels` is required argument when `mode` equals `TRAIN` or `EVAL`.
    optimizer: An `tf.keras.optimizers.Optimizer` instance to optimize the loss
      in TRAIN mode. Namely, sets `train_op = optimizer.get_updates(loss,
      trainable_variables)`, which updates variables to minimize `loss`.
    trainable_variables: A list or tuple of `Variable` objects to update to
      minimize `loss`. In Tensorflow 1.x, by default these are the list of
      variables collected in the graph under the key
      `GraphKeys.TRAINABLE_VARIABLES`. As Tensorflow 2.x doesn't have
      collections and GraphKeys, trainable_variables need to be passed
      explicitly here.
    train_op_fn: Function that takes a scalar loss `Tensor` and returns
      `train_op`. Used if `optimizer` is `None`.
    update_ops: A list or tuple of update ops to be run at training time. For
      example, layers such as BatchNormalization create mean and variance
      update ops that need to be run at training time. In Tensorflow 1.x,
      these are thrown into an UPDATE_OPS collection. As Tensorflow 2.x
      doesn't have collections, update_ops need to be passed explicitly here.
    regularization_losses: A list of additional scalar losses to be added to
      the training loss, such as regularization losses. These losses are
      usually expressed as a batch average, so for best results users need to
      set `loss_reduction=SUM_OVER_BATCH_SIZE` or
      `loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to
      avoid scaling errors.

  Returns:
    `model_fn._TPUEstimatorSpec`.

  Raises:
    ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN mode,
      or if both are set.
  """
  tf.logging.info(mode)
  with ops.name_scope(self._name, 'survival_head'):
    # Predict.
    predictions = self.predictions(logits)
    if mode == model_fn.ModeKeys.PREDICT:
      return model_fn._TPUEstimatorSpec(  # pylint:disable=protected-access
          mode=model_fn.ModeKeys.PREDICT,
          predictions=predictions,
          export_outputs={
              base_head.DEFAULT_SERVING_KEY: (
                  export_output.PredictOutput(predictions)),
              base_head.PREDICT_SERVING_KEY: (
                  export_output.PredictOutput(predictions))
          })
    regularized_training_loss = self.loss(
        logits=logits,
        labels=labels,
        features=features,
        mode=mode,
        regularization_losses=regularization_losses)
    # Eval.
    if mode == model_fn.ModeKeys.EVAL:
      eval_metrics = self.metrics(regularization_losses=regularization_losses)
      return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
          mode=model_fn.ModeKeys.EVAL,
          predictions=predictions,
          loss=regularized_training_loss,
          eval_metrics=base_head.create_eval_metrics_tuple(
              self.update_metrics, {
                  'eval_metrics': eval_metrics,
                  'features': features,
                  'logits': logits,
                  'labels': labels,
                  'regularization_losses': regularization_losses
              }))
    # Train.
    train_op = base_head.create_estimator_spec_train_op(
        self._name,
        optimizer=optimizer,
        trainable_variables=trainable_variables,
        train_op_fn=train_op_fn,
        update_ops=update_ops,
        regularized_training_loss=regularized_training_loss)
    # Create summary.
    base_head.create_estimator_spec_summary(
        regularized_training_loss, regularization_losses, self._summary_key)
    return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
        mode=model_fn.ModeKeys.TRAIN,
        predictions=predictions,
        loss=regularized_training_loss,
        train_op=train_op)