def update_metrics(self,
                       eval_metrics,
                       features,
                       logits,
                       labels,
                       regularization_losses=None):
        """Updates eval metrics. See `base_head.Head` for details.

        Args:
          eval_metrics: Mapping from metric key to metric objects. The entries
            for this head's loss-mean, label-mean, prediction-mean and (when
            `regularization_losses` is given) regularization-loss keys are
            updated.
          features: Input features, forwarded to
            `_unweighted_loss_and_weights` to obtain the loss and weights.
          logits: Logits tensor. Predictions are computed from it, then its
            final dimension is checked against `self.logits_dimension`.
          labels: Labels; passed through `self._processed_labels` before any
            metric sees them.
          regularization_losses: Optional list of regularization loss
            tensors; when not None, their sum updates the regularization-loss
            metric.

        Returns:
          The same `eval_metrics` mapping, after its metrics are updated.
        """
        # Compute predictions.
        predictions = self.predictions(logits)
        predicted_value = predictions[
            prediction_keys.PredictionKeys.PREDICTIONS]
        logits = base_head.check_logits_final_dim(logits,
                                                  self.logits_dimension)
        processed_labels = self._processed_labels(logits, labels)
        unweighted_loss, weights = self._unweighted_loss_and_weights(
            logits, processed_labels, features)

        # Update metrics.
        eval_metrics[self._loss_mean_key].update_state(values=unweighted_loss,
                                                       sample_weight=weights)
        # Average the *processed* labels, i.e. the same tensor the loss and
        # weights were computed from, and broadcast the weights to it the same
        # way as for the prediction mean. The raw `labels` argument has not
        # been through `_processed_labels`, so it need not line up with
        # `weights`.
        base_head.update_metric_with_broadcast_weights(
            eval_metrics[self._label_mean_key], processed_labels, weights)
        base_head.update_metric_with_broadcast_weights(
            eval_metrics[self._prediction_mean_key], predicted_value, weights)
        if regularization_losses is not None:
            regularization_loss = math_ops.add_n(regularization_losses)
            eval_metrics[self._loss_regularization_key].update_state(
                values=regularization_loss)
        return eval_metrics
 def update_metrics(self,
                    eval_metrics,
                    features,
                    logits,
                    labels,
                    regularization_losses=None):
     """Updates the binary-classification eval metrics in place.

     See `base_head.Head` for the full contract.

     Args:
       eval_metrics: Mapping from metric key to metric objects; the entries
         for this head's keys are updated via `update_state`, the
         `base_head` broadcast helper, or this head's `_update_*` helpers.
       features: Input features, forwarded to `_unweighted_loss_and_weights`.
       logits: Logits tensor; its final dimension is checked against
         `self.logits_dimension`.
       labels: Labels, passed through `self._processed_labels` before use.
       regularization_losses: Optional list of regularization losses; when
         not None, their sum updates the regularization-loss metric.

     Returns:
       `eval_metrics`, after updating.
     """
     # Class ids come from the predictions on the unchecked logits.
     all_predictions = self.predictions(logits)
     predicted_class_ids = all_predictions[
         prediction_keys.PredictionKeys.CLASS_IDS]
     checked_logits = base_head.check_logits_final_dim(
         logits, self.logits_dimension)
     processed = self._processed_labels(checked_logits, labels)
     loss_values, example_weights = self._unweighted_loss_and_weights(
         checked_logits, processed, features)

     # Weighted mean of the per-example loss.
     eval_metrics[self._loss_mean_key].update_state(
         values=loss_values, sample_weight=example_weights)
     # Class-id based metrics, updated in order: accuracy, precision, recall.
     for metric_key in (self._accuracy_key, self._precision_key,
                        self._recall_key):
         eval_metrics[metric_key].update_state(
             y_true=processed,
             y_pred=predicted_class_ids,
             sample_weight=example_weights)

     logistic_key = prediction_keys.PredictionKeys.LOGISTIC
     logistic = self.predictions(checked_logits, [logistic_key])[logistic_key]
     base_head.update_metric_with_broadcast_weights(
         eval_metrics[self._prediction_mean_key], logistic, example_weights)
     base_head.update_metric_with_broadcast_weights(
         eval_metrics[self._label_mean_key], processed, example_weights)
     self._update_accuracy_baseline(eval_metrics)
     # ROC AUC first, then PR AUC, both from the logistic predictions.
     for auc_key in (self._auc_key, self._auc_pr_key):
         self._update_auc(auc_metric=eval_metrics[auc_key],
                          labels=processed,
                          predictions=logistic,
                          weights=example_weights)
     if regularization_losses is not None:
         eval_metrics[self._loss_regularization_key].update_state(
             values=tf.math.add_n(regularization_losses))
     # One accuracy/precision/recall triple per configured threshold, each
     # fed the logistic predictions.
     for idx in range(len(self._thresholds)):
         for key_list in (self._accuracy_keys, self._precision_keys,
                          self._recall_keys):
             eval_metrics[key_list[idx]].update_state(
                 y_true=processed,
                 y_pred=logistic,
                 sample_weight=example_weights)
     return eval_metrics
Example #3
0
    def update_metrics(self,
                       eval_metrics,
                       features,
                       logits,
                       labels,
                       regularization_losses=None):
        """Updates eval metrics. See `base_head.Head` for details.

        Args:
          eval_metrics: Mapping from metric key to metric objects; this
            head's loss-mean, regularization-loss, accuracy-at-threshold and
            per-class probability-mean entries are updated.
          features: Input features, forwarded to
            `_unweighted_loss_and_weights`.
          logits: Logits tensor; its final dimension is checked against
            `self.logits_dimension`.
          labels: Labels, passed through `self._processed_labels` before use.
          regularization_losses: Optional list of regularization losses; when
            not None, their sum updates the regularization-loss metric.

        Returns:
          The same `eval_metrics` mapping, after its metrics are updated.
        """
        logits = base_head.check_logits_final_dim(logits,
                                                  self.logits_dimension)
        processed_labels = self._processed_labels(logits, labels)
        unweighted_loss, weights = self._unweighted_loss_and_weights(
            logits, processed_labels, features)
        prob_key = prediction_keys.PredictionKeys.PROBABILITIES
        predictions = self.predictions(logits, [prob_key])
        probabilities = predictions[prob_key]

        # Update metrics.
        eval_metrics[self._loss_mean_key].update_state(values=unweighted_loss,
                                                       sample_weight=weights)
        # TODO(b/118843532): update Keras metrics
        # eval_metrics[self._auc_key].update_state(...)
        # eval_metrics[self._auc_pr_key].update_state(...)
        if regularization_losses is not None:
            regularization_loss = math_ops.add_n(regularization_losses)
            eval_metrics[self._loss_regularization_key].update_state(
                values=regularization_loss)
        for i, _ in enumerate(self._thresholds):
            # Compare against the processed labels, the same labels the loss
            # was computed from. The raw `labels` argument has not been
            # through `_processed_labels` and need not match `probabilities`
            # in shape or encoding.
            eval_metrics[self._accuracy_keys[i]].update_state(
                y_true=processed_labels,
                y_pred=probabilities,
                sample_weight=weights)
            # TODO(b/118843532): update Keras metrics
            # eval_metrics[self._precision_keys[i]].update_state(
            #     ...)
            # eval_metrics[self._recall_keys[i]].update_state(
            #     ...)

        # Slice spec shared by every class: keep all leading (batch)
        # dimensions in full (size -1) and take one entry of the last
        # dimension. Only the offset into the last dimension varies per
        # class, so the rest is built once, outside the loop.
        batch_rank = array_ops.rank(probabilities) - 1
        batch_begin = array_ops.zeros([batch_rank], dtype=dtypes.int32)
        size = array_ops.concat(
            [-1 * array_ops.ones([batch_rank], dtype=dtypes.int32), [1]],
            axis=0)
        for i, class_id in enumerate(self._classes_for_class_based_metrics):
            begin = array_ops.concat([batch_begin, [class_id]], axis=0)
            class_probabilities = array_ops.slice(probabilities,
                                                  begin=begin,
                                                  size=size)
            # TODO(b/118843532): update Keras metrics
            # class_labels = array_ops.slice(processed_labels, begin=begin,
            #                                size=size)
            base_head.update_metric_with_broadcast_weights(
                eval_metrics[self._prob_keys[i]], class_probabilities, weights)
            # eval_metrics[self._auc_keys[i]].update_state(...)
            # eval_metrics[self._auc_pr_keys[i]].update_state(...)
        return eval_metrics
Example #4
0
 def update_metrics(self,
                    eval_metrics,
                    features,
                    logits,
                    labels,
                    regularization_losses=None):
     """Updates this binary head's eval metrics and returns them.

     See `base_head.Head` for the full contract.

     Args:
       eval_metrics: Mapping from metric key to metric objects; this head's
         entries are updated in place.
       features: Input features, forwarded to `_unweighted_loss_and_weights`.
       logits: Logits tensor; its final dimension is checked against
         `self.logits_dimension`.
       labels: Labels, passed through `self._processed_labels` before use.
       regularization_losses: Optional list of regularization losses; when
         not None, their sum updates the regularization-loss metric.

     Returns:
       `eval_metrics`, after updating.
     """
     checked_logits = base_head.check_logits_final_dim(
         logits, self.logits_dimension)
     # Pair every logit with a zero along the last axis; the result is what
     # the accuracy metric receives as predictions.
     two_class_logits = array_ops.concat(
         (array_ops.zeros_like(checked_logits), checked_logits),
         axis=-1,
         name='two_class_logits')
     processed = self._processed_labels(checked_logits, labels)
     loss_values, example_weights = self._unweighted_loss_and_weights(
         checked_logits, processed, features)

     eval_metrics[self._loss_mean_key].update_state(
         values=loss_values, sample_weight=example_weights)
     eval_metrics[self._accuracy_key].update_state(
         y_true=processed,
         y_pred=two_class_logits,
         sample_weight=example_weights)
     # TODO(b/118843532): update Keras metrics
     # eval_metrics[self._precision_key].update(...)
     # eval_metrics[self._recall_key].update(...)
     logistic_key = prediction_keys.PredictionKeys.LOGISTIC
     logistic = self.predictions(checked_logits, [logistic_key])[logistic_key]
     # Prediction mean first, then label mean, both with broadcast weights.
     for mean_key, mean_values in ((self._prediction_mean_key, logistic),
                                   (self._label_mean_key, processed)):
         base_head.update_metric_with_broadcast_weights(
             eval_metrics[mean_key], mean_values, example_weights)
     # TODO(b/118843532): update Keras metrics
     # eval_metrics[self._accuracy_baseline_key].update_state(...)
     # eval_metrics[self._auc_key].update_state(...)
     # eval_metrics[self._auc_pr_key].update_state(...)
     if regularization_losses is not None:
         eval_metrics[self._loss_regularization_key].update_state(
             values=math_ops.add_n(regularization_losses))
     for idx, _ in enumerate(self._thresholds):
         eval_metrics[self._accuracy_keys[idx]].update_state(
             y_true=processed, y_pred=logistic, sample_weight=example_weights)
         # TODO(b/118843532): update Keras metrics
         # eval_metrics[self._precision_keys[idx]].update_state(
         #     ...)
         # eval_metrics[self._recall_keys[idx]].update_state(
         #     ...)
     return eval_metrics
    def update_metrics(self,
                       eval_metrics,
                       features,
                       logits,
                       labels,
                       regularization_losses=None):
        """Updates eval metrics.

        Args:
          eval_metrics: See `base_head.Head` for details.
          features: dict with feature name strings as key, tensor as value.
          logits: estimated observation values, a [batch, time_len, num_obs]
            tensor.
          labels: ground-truth observations: a feature dict keyed by the
            observation and intervention codes, whose values are tensors of
            shape [batch_size, context_window_size].
          regularization_losses: See `base_head.Head` for details.

        Returns:
          eval_metrics: See `base_head.Head` for details.
        """
        processed = self._processed_labels(logits, labels, features)
        loss_values, example_weights = self._unweighted_loss_and_weights(
            logits, processed, features, mode=model_fn.ModeKeys.EVAL)
        eval_metrics[self._loss_mean_key].update_state(
            values=loss_values, sample_weight=example_weights)

        logits_key = prediction_keys.PredictionKeys.LOGITS
        value_predictions = self.predictions(logits, [logits_key])[logits_key]
        base_head.update_metric_with_broadcast_weights(
            eval_metrics[self._prediction_mean_key], value_predictions,
            example_weights)

        # MAE against the `last_obs_len` entries of axis 1 that immediately
        # precede the final entry (presumably the time axis; confirm).
        # This metric is updated without sample weights.
        last_obs_len = self._model_hparams.last_obs_len
        eval_metrics[self._mean_abs_error].update_state(
            processed[:, (-1 - last_obs_len):-1, :], value_predictions)

        # Weighted mean of the processed observation labels.
        base_head.update_metric_with_broadcast_weights(
            eval_metrics[self._label_mean_key], processed, example_weights)

        if regularization_losses is not None:
            eval_metrics[self._loss_regularization_key].update_state(
                values=math_ops.add_n(regularization_losses))
        return eval_metrics
Example #6
0
 def update_metrics(self, eval_metrics, features, logits, labels,
                    regularization_losses=None):
   """Updates and returns the eval metrics. See `Head` for more details.

   Args:
     eval_metrics: Mapping from metric key to metric objects; this head's
       entries are updated in place.
     features: Input features, forwarded to `_unweighted_loss_and_weights`.
     logits: Logits tensor; its final dimension is checked against
       `self.logits_dimension`.
     labels: Labels, passed through `self._processed_labels` before use.
     regularization_losses: Optional list of regularization losses; when not
       None, their sum updates the regularization-loss metric.

   Returns:
     `eval_metrics`, after updating.
   """
   checked_logits = base_head.check_logits_final_dim(
       logits, self.logits_dimension)
   # Pair every logit with a zero along the last axis.
   two_class_logits = array_ops.concat(
       (array_ops.zeros_like(checked_logits), checked_logits),
       axis=-1, name='two_class_logits')
   processed = self._processed_labels(checked_logits, labels)
   loss_values, example_weights = self._unweighted_loss_and_weights(
       checked_logits, processed, features)

   eval_metrics[self._loss_mean_key].update_state(
       values=loss_values, sample_weight=example_weights)
   eval_metrics[self._accuracy_key].update_state(
       y_true=processed, y_pred=two_class_logits,
       sample_weight=example_weights)
   # TODO(b/118843532): update Keras metrics
   # eval_metrics[self._precision_key].update(...)
   # eval_metrics[self._recall_key].update(...)
   logistic_key = prediction_keys.PredictionKeys.LOGISTIC
   logistic = self.predictions(checked_logits, [logistic_key])[logistic_key]
   base_head.update_metric_with_broadcast_weights(
       eval_metrics[self._prediction_mean_key], logistic, example_weights)
   base_head.update_metric_with_broadcast_weights(
       eval_metrics[self._label_mean_key], processed, example_weights)
   # TODO(b/118843532): update Keras metrics
   # eval_metrics[self._accuracy_baseline_key].update_state(...)
   # eval_metrics[self._auc_key].update_state(...)
   # eval_metrics[self._auc_pr_key].update_state(...)
   if regularization_losses is not None:
     eval_metrics[self._loss_regularization_key].update_state(
         values=math_ops.add_n(regularization_losses))
   # Accuracy-at-threshold metrics are looked up by their summary key.
   accuracy_template = metric_keys.MetricKeys.ACCURACY_AT_THRESHOLD
   for threshold in self._thresholds:
     accuracy_metric = eval_metrics[
         self._summary_key(accuracy_template % threshold)]
     accuracy_metric.update_state(
         y_true=processed, y_pred=logistic, sample_weight=example_weights)
     # TODO(b/118843532): update Keras metrics for
     # metric_keys.MetricKeys.PRECISION_AT_THRESHOLD % threshold and
     # metric_keys.MetricKeys.RECALL_AT_THRESHOLD % threshold.
   return eval_metrics
  def update_metrics(self,
                     eval_metrics,
                     features,
                     logits,
                     labels,
                     regularization_losses=None):
    """Updates eval metrics.

    Args:
      eval_metrics: See `base_head.Head` for details.
      features: See `base_head.Head` for details.
      logits: for a single event or independent events, a tensor of shape
        [batch_size, 1]; for correlated events, a dict with event_name as key
        and a tensor of shape [batch_size, 1] as value.
      labels: dict keyed by 'event_name' and 'event_name.time_of_event' with
        values as tensors of shape [batch_size] or [batch_size, 1]. For
        correlated events, labels for all events are provided. Otherwise, only
        the event associated with this head is provided.
      regularization_losses: See `base_head.Head` for details.

    Returns:
      eval_metrics: See `base_head.Head` for details.
    """
    processed_logits = self._processed_logits(logits)
    processed_labels = self._processed_labels(logits, labels)
    time_to_event, censored = processed_labels

    unweighted_loss, weights = self._unweighted_loss_and_weights(
        processed_logits, processed_labels, features)

    # Update metrics.
    eval_metrics[self._loss_mean_key].update_state(
        values=unweighted_loss, sample_weight=weights)
    prob_key = prediction_keys.PredictionKeys.PROBABILITIES
    predictions = self.predictions(logits, [prob_key])
    probabilities = predictions[prob_key]

    base_head.update_metric_with_broadcast_weights(
        eval_metrics[self._prediction_mean_key], probabilities, weights)

    hparams = self._model_hparams
    if hparams.da_tlen > 0 and not hparams.event_relation:
      # The survival model's inputs are fixed for this call, so build it once.
      # It serves every window below and is reused for the 24h/48h AUCs and
      # the time-to-event MAE after the loop.
      model = self._survival_model(
          params=processed_logits,
          labels=processed_labels,
          event_index=self._event_index,
          model_hparams=hparams)
      da_sslot = hparams.da_sslot
      num_windows = int(hparams.da_tlen / SLOT_TO_WINDOW) + 1
      y_true_list = []
      y_pred_list = []
      for i in range(num_windows):
        # Boundaries of window i; consecutive windows share an endpoint.
        window_start_t = i * UNITS_IN_HR * da_sslot * SLOT_TO_WINDOW
        window_end_t = (i + 1) * UNITS_IN_HR * da_sslot * SLOT_TO_WINDOW
        # probabilities_at_window shape [batch_size, 1].
        probabilities_at_window = model.probability_within_window(
            window_start_t=window_start_t, window_end_t=window_end_t)
        base_head.update_metric_with_broadcast_weights(
            eval_metrics[self._probablity_within_window_list[i]],
            probabilities_at_window, weights)
        y_true, y_pred = self._true_and_predict_within_window(
            time_to_event, censored, probabilities_at_window, window_start_t,
            window_end_t)
        # y_true, y_pred shape [batch_size]
        y_true_list.append(y_true)
        y_pred_list.append(y_pred)

      # Pool all windows once and feed the same tensors to both AUC metrics.
      all_y_true = tf.concat(y_true_list, axis=0)
      all_y_pred = tf.concat(y_pred_list, axis=0)
      eval_metrics[self._auc_pr].update_state(all_y_true, all_y_pred)
      eval_metrics[self._auc_roc].update_state(all_y_true, all_y_pred)

      # 24hr and 48hr window AUC, each over the window starting at 0.
      for hours, auc_key in ((24, self._auc_roc_24), (48, self._auc_roc_48)):
        horizon_end_t = UNITS_IN_HR * hours
        probabilities_at_horizon = model.probability_within_window(
            window_start_t=0, window_end_t=horizon_end_t)
        y_true_h, y_pred_h = self._true_and_predict_within_window(
            time_to_event, censored, probabilities_at_horizon, 0,
            horizon_end_t)
        eval_metrics[auc_key].update_state(y_true_h, y_pred_h)

      # MAE between observed and predicted event time, restricted to the
      # examples that are not censored for this head's event.
      uncensored = tf.logical_not(censored[:, self._event_index])
      observed_time = tf.boolean_mask(
          time_to_event[:, self._event_index] / UNITS_IN_HR, uncensored)
      predicted_time = tf.boolean_mask(model.predicted_time(), uncensored)
      eval_metrics[self._mean_abs_error].update_state(observed_time,
                                                      predicted_time)

    # label_mean represents the percentage of censored events. In case of
    # mortality, it is the percentage of survived patients.
    base_head.update_metric_with_broadcast_weights(
        eval_metrics[self._label_mean_key], censored, weights)

    if regularization_losses is not None:
      regularization_loss = math_ops.add_n(regularization_losses)
      eval_metrics[self._loss_regularization_key].update_state(
          values=regularization_loss)
    return eval_metrics