def update_metrics(self, eval_metrics, features, logits, labels,
                   regularization_losses=None):
    """Feeds one evaluation batch into this head's metric objects.

    See `base_head.Head` for the full contract.

    Args:
      eval_metrics: Mapping from metric key to metric object; the entries
        are updated in place.
      features: Passed through to `_unweighted_loss_and_weights`.
      logits: Model output; validated against `self.logits_dimension`.
      labels: Labels for the batch.
      regularization_losses: Optional list of regularization loss tensors;
        when given, their sum is recorded under the regularization key.

    Returns:
      The `eval_metrics` mapping that was passed in.
    """
    # Predictions are taken from the logits as received, before the
    # final-dimension check below.
    predicted_value = self.predictions(logits)[
        prediction_keys.PredictionKeys.PREDICTIONS]

    logits = base_head.check_logits_final_dim(logits, self.logits_dimension)
    label_ids = self._processed_labels(logits, labels)
    unweighted_loss, weights = self._unweighted_loss_and_weights(
        logits, label_ids, features)

    loss_mean_metric = eval_metrics[self._loss_mean_key]
    loss_mean_metric.update_state(
        values=unweighted_loss, sample_weight=weights)

    # NOTE(review): the label mean is fed the raw `labels`, not the
    # processed `label_ids` -- confirm this is intended.
    label_mean_metric = eval_metrics[self._label_mean_key]
    label_mean_metric.update_state(values=labels, sample_weight=weights)

    base_head.update_metric_with_broadcast_weights(
        eval_metrics[self._prediction_mean_key], predicted_value, weights)

    if regularization_losses is None:
        return eval_metrics

    total_regularization = math_ops.add_n(regularization_losses)
    eval_metrics[self._loss_regularization_key].update_state(
        values=total_regularization)
    return eval_metrics
def update_metrics(self, eval_metrics, features, logits, labels,
                   regularization_losses=None):
    """Feeds one evaluation batch into this head's metric objects.

    See `base_head.Head` for the full contract.

    Args:
      eval_metrics: Mapping from metric key to metric object; the entries
        are updated in place.
      features: Passed through to `_unweighted_loss_and_weights`.
      logits: Model output; validated against `self.logits_dimension`.
      labels: Labels for the batch; replaced by their processed form before
        any metric sees them.
      regularization_losses: Optional list of regularization loss tensors;
        when given, their sum is recorded under the regularization key.

    Returns:
      The `eval_metrics` mapping that was passed in.
    """
    pred_keys = prediction_keys.PredictionKeys

    # Class ids come from the logits as received, before the
    # final-dimension check.
    class_ids = self.predictions(logits)[pred_keys.CLASS_IDS]

    logits = base_head.check_logits_final_dim(logits, self.logits_dimension)
    labels = self._processed_labels(logits, labels)
    unweighted_loss, weights = self._unweighted_loss_and_weights(
        logits, labels, features)

    eval_metrics[self._loss_mean_key].update_state(
        values=unweighted_loss, sample_weight=weights)

    # Accuracy, precision and recall all compare labels with class ids.
    for class_metric_key in (self._accuracy_key,
                             self._precision_key,
                             self._recall_key):
        eval_metrics[class_metric_key].update_state(
            y_true=labels, y_pred=class_ids, sample_weight=weights)

    logistic = self.predictions(
        logits, [pred_keys.LOGISTIC])[pred_keys.LOGISTIC]

    base_head.update_metric_with_broadcast_weights(
        eval_metrics[self._prediction_mean_key], logistic, weights)
    base_head.update_metric_with_broadcast_weights(
        eval_metrics[self._label_mean_key], labels, weights)

    self._update_accuracy_baseline(eval_metrics)

    # Both AUC variants are driven by the logistic output.
    for auc_metric_key in (self._auc_key, self._auc_pr_key):
        self._update_auc(
            auc_metric=eval_metrics[auc_metric_key],
            labels=labels,
            predictions=logistic,
            weights=weights)

    if regularization_losses is not None:
        total_regularization = tf.math.add_n(regularization_losses)
        eval_metrics[self._loss_regularization_key].update_state(
            values=total_regularization)

    # One accuracy/precision/recall metric per configured threshold, each
    # fed the logistic output.
    for idx, _ in enumerate(self._thresholds):
        for keys_by_threshold in (self._accuracy_keys,
                                  self._precision_keys,
                                  self._recall_keys):
            eval_metrics[keys_by_threshold[idx]].update_state(
                y_true=labels, y_pred=logistic, sample_weight=weights)

    return eval_metrics
def update_metrics(self, eval_metrics, features, logits, labels,
                   regularization_losses=None):
    """Updates eval metrics. See `base_head.Head` for details.

    Args:
      eval_metrics: Mapping from metric key to metric object; the entries
        are updated in place.
      features: Passed through to `_unweighted_loss_and_weights`.
      logits: Model output; validated against `self.logits_dimension`.
      labels: Labels for the batch. They are run through
        `_processed_labels` and only the processed form is used for the
        loss and for the metrics.
      regularization_losses: Optional list of regularization loss tensors;
        when given, their sum is recorded under the regularization key.

    Returns:
      The `eval_metrics` mapping that was passed in.
    """
    logits = base_head.check_logits_final_dim(logits, self.logits_dimension)
    processed_labels = self._processed_labels(logits, labels)
    unweighted_loss, weights = self._unweighted_loss_and_weights(
        logits, processed_labels, features)

    prob_key = prediction_keys.PredictionKeys.PROBABILITIES
    predictions = self.predictions(logits, [prob_key])
    probabilities = predictions[prob_key]

    # Update metrics.
    eval_metrics[self._loss_mean_key].update_state(
        values=unweighted_loss, sample_weight=weights)
    # TODO(b/118843532): update Keras metrics
    # eval_metrics[self._auc_key].update_state(...)
    # eval_metrics[self._auc_pr_key].update_state(...)

    if regularization_losses is not None:
        regularization_loss = math_ops.add_n(regularization_losses)
        eval_metrics[self._loss_regularization_key].update_state(
            values=regularization_loss)

    for i, _ in enumerate(self._thresholds):
        # Use the processed labels so the metric sees the same label
        # representation that the loss was computed from, rather than the
        # raw `labels` argument.
        eval_metrics[self._accuracy_keys[i]].update_state(
            y_true=processed_labels,
            y_pred=probabilities,
            sample_weight=weights)
        # TODO(b/118843532): update Keras metrics
        # eval_metrics[self._precision_keys[i]].update_state(...)
        # eval_metrics[self._recall_keys[i]].update_state(...)

    for i, class_id in enumerate(self._classes_for_class_based_metrics):
        # Slice out the column for `class_id` along the last axis while
        # keeping every leading dimension whole: `begin` is zeros followed
        # by class_id, `size` is -1 for each leading dimension followed by 1.
        batch_rank = array_ops.rank(probabilities) - 1
        begin = array_ops.concat(
            [array_ops.zeros([batch_rank], dtype=dtypes.int32), [class_id]],
            axis=0)
        size = array_ops.concat(
            [-1 * array_ops.ones([batch_rank], dtype=dtypes.int32), [1]],
            axis=0)
        class_probabilities = array_ops.slice(
            probabilities, begin=begin, size=size)
        # TODO(b/118843532): update Keras metrics
        # class_labels = array_ops.slice(labels, begin=begin, size=size)
        base_head.update_metric_with_broadcast_weights(
            eval_metrics[self._prob_keys[i]], class_probabilities, weights)
        # eval_metrics[self._auc_keys[i]].update_state(...)
        # eval_metrics[self._auc_pr_key[i]].update_state(...)

    return eval_metrics
def update_metrics(self, eval_metrics, features, logits, labels,
                   regularization_losses=None):
    """Feeds one evaluation batch into this head's metric objects.

    See `base_head.Head` for the full contract.

    Args:
      eval_metrics: Mapping from metric key to metric object; the entries
        are updated in place.
      features: Passed through to `_unweighted_loss_and_weights`.
      logits: Model output; validated against `self.logits_dimension`.
      labels: Labels for the batch; replaced by their processed form before
        any metric sees them.
      regularization_losses: Optional list of regularization loss tensors;
        when given, their sum is recorded under the regularization key.

    Returns:
      The `eval_metrics` mapping that was passed in.
    """
    logits = base_head.check_logits_final_dim(logits, self.logits_dimension)

    # Prepend a zero logit along the last axis, giving a two-column tensor
    # for the accuracy metric.
    zero_logits = array_ops.zeros_like(logits)
    two_class_logits = array_ops.concat(
        (zero_logits, logits), axis=-1, name='two_class_logits')

    labels = self._processed_labels(logits, labels)
    unweighted_loss, weights = self._unweighted_loss_and_weights(
        logits, labels, features)

    loss_mean_metric = eval_metrics[self._loss_mean_key]
    loss_mean_metric.update_state(
        values=unweighted_loss, sample_weight=weights)

    accuracy_metric = eval_metrics[self._accuracy_key]
    accuracy_metric.update_state(
        y_true=labels, y_pred=two_class_logits, sample_weight=weights)
    # TODO(b/118843532): update Keras metrics
    # eval_metrics[self._precision_key].update(...)
    # eval_metrics[self._recall_key].update(...)

    logistic_key = prediction_keys.PredictionKeys.LOGISTIC
    logistic = self.predictions(logits, [logistic_key])[logistic_key]

    mean_updates = (
        (self._prediction_mean_key, logistic),
        (self._label_mean_key, labels),
    )
    for mean_key, mean_values in mean_updates:
        base_head.update_metric_with_broadcast_weights(
            eval_metrics[mean_key], mean_values, weights)
    # TODO(b/118843532): update Keras metrics
    # eval_metrics[self._accuracy_baseline_key].update_state(...)
    # eval_metrics[self._auc_key].update_state(...)
    # eval_metrics[self._auc_pr_key].update_state(...)

    if regularization_losses is not None:
        total_regularization = math_ops.add_n(regularization_losses)
        eval_metrics[self._loss_regularization_key].update_state(
            values=total_regularization)

    # One accuracy metric per configured threshold, fed the logistic output.
    for idx, _ in enumerate(self._thresholds):
        threshold_accuracy = eval_metrics[self._accuracy_keys[idx]]
        threshold_accuracy.update_state(
            y_true=labels, y_pred=logistic, sample_weight=weights)
        # TODO(b/118843532): update Keras metrics
        # eval_metrics[self._precision_keys[i]].update_state(...)
        # eval_metrics[self._recall_keys[i]].update_state(...)

    return eval_metrics
def update_metrics(self, eval_metrics, features, logits, labels,
                   regularization_losses=None):
    """Feeds one evaluation batch into this head's metric objects.

    Args:
      eval_metrics: See `base_head.Head` for details.
      features: dict with feature name strings as key, tensor as value.
      logits: estimated obs. value, [batch, time_len, num_obs] tensor.
      labels: ground truth observation, feature dict with obs. and interv.
        codes as keys, values tensor with shape
        [batch_size, context_window_size].
      regularization_losses: See `base_head.Head` for details.

    Returns:
      eval_metrics: See `base_head.Head` for details.
    """
    processed_labels = self._processed_labels(logits, labels, features)
    unweighted_loss, weights = self._unweighted_loss_and_weights(
        logits,
        processed_labels,
        features,
        mode=model_fn.ModeKeys.EVAL)

    loss_mean_metric = eval_metrics[self._loss_mean_key]
    loss_mean_metric.update_state(
        values=unweighted_loss, sample_weight=weights)

    value_key = prediction_keys.PredictionKeys.LOGITS
    value_predictions = self.predictions(logits, [value_key])[value_key]
    base_head.update_metric_with_broadcast_weights(
        eval_metrics[self._prediction_mean_key], value_predictions, weights)

    # Compare predictions with a slice of the labels along axis 1: it starts
    # `last_obs_len + 1` steps from the end and excludes the final step.
    # The absolute-error metric is updated without sample weights.
    slice_start = -1 - self._model_hparams.last_obs_len
    label_slice = processed_labels[:, slice_start:-1, :]
    eval_metrics[self._mean_abs_error].update_state(
        label_slice, value_predictions)

    # label_mean represents the percentage of censored events. In case of
    # mortality, it is the percentage of survived patients.
    base_head.update_metric_with_broadcast_weights(
        eval_metrics[self._label_mean_key], processed_labels, weights)

    if regularization_losses is None:
        return eval_metrics

    total_regularization = math_ops.add_n(regularization_losses)
    eval_metrics[self._loss_regularization_key].update_state(
        values=total_regularization)
    return eval_metrics
def update_metrics(self, eval_metrics, features, logits, labels,
                   regularization_losses=None):
    """Updates the eval metrics for one batch and returns them.

    See `Head` for more details.

    Args:
      eval_metrics: Mapping from metric key to metric object; the entries
        are updated in place.
      features: Passed through to `_unweighted_loss_and_weights`.
      logits: Model output; validated against `self.logits_dimension`.
      labels: Labels for the batch; replaced by their processed form before
        any metric sees them.
      regularization_losses: Optional list of regularization loss tensors;
        when given, their sum is recorded under the regularization key.

    Returns:
      The `eval_metrics` mapping that was passed in.
    """
    logits = base_head.check_logits_final_dim(logits, self.logits_dimension)

    # Prepend a zero logit along the last axis, giving a two-column tensor
    # for the accuracy metric.
    zero_logits = array_ops.zeros_like(logits)
    two_class_logits = array_ops.concat(
        (zero_logits, logits), axis=-1, name='two_class_logits')

    labels = self._processed_labels(logits, labels)
    unweighted_loss, weights = self._unweighted_loss_and_weights(
        logits, labels, features)

    loss_mean_metric = eval_metrics[self._loss_mean_key]
    loss_mean_metric.update_state(
        values=unweighted_loss, sample_weight=weights)

    accuracy_metric = eval_metrics[self._accuracy_key]
    accuracy_metric.update_state(
        y_true=labels, y_pred=two_class_logits, sample_weight=weights)
    # TODO(b/118843532): update Keras metrics
    # eval_metrics[self._precision_key].update(...)
    # eval_metrics[self._recall_key].update(...)

    logistic_key = prediction_keys.PredictionKeys.LOGISTIC
    logistic = self.predictions(logits, [logistic_key])[logistic_key]

    mean_updates = (
        (self._prediction_mean_key, logistic),
        (self._label_mean_key, labels),
    )
    for mean_key, mean_values in mean_updates:
        base_head.update_metric_with_broadcast_weights(
            eval_metrics[mean_key], mean_values, weights)
    # TODO(b/118843532): update Keras metrics
    # eval_metrics[self._accuracy_baseline_key].update_state(...)
    # eval_metrics[self._auc_key].update_state(...)
    # eval_metrics[self._auc_pr_key].update_state(...)

    if regularization_losses is not None:
        total_regularization = math_ops.add_n(regularization_losses)
        eval_metrics[self._loss_regularization_key].update_state(
            values=total_regularization)

    # Per-threshold accuracy metrics are looked up by a summary key built
    # from the threshold value.
    keys = metric_keys.MetricKeys
    for threshold in self._thresholds:
        accuracy_name = keys.ACCURACY_AT_THRESHOLD % threshold
        accuracy_key = self._summary_key(accuracy_name)
        eval_metrics[accuracy_key].update_state(
            y_true=labels, y_pred=logistic, sample_weight=weights)
        # TODO(b/118843532): update Keras metrics
        # precision_key = keys.PRECISION_AT_THRESHOLD % threshold
        # eval_metrics[precision_key].update_state(...)
        # recall_key = keys.RECALL_AT_THRESHOLD % threshold
        # eval_metrics[recall_key].update_state(...)

    return eval_metrics
def update_metrics(self, eval_metrics, features, logits, labels,
                   regularization_losses=None):
    """Updates eval metrics.

    Args:
      eval_metrics: See `base_head.Head` for details.
      features: See `base_head.Head` for details.
      logits: for a single event or an independent event, a tensor of shape
        [batch_size, 1]; for correlated events, a dict with event_name as
        key and a tensor of shape [batch_size, 1] as value.
      labels: dict keyed by 'event_name' and 'event_name.time_of_event' with
        value as tensors of shape [batch_size] or [batch_size, 1]. For
        correlated events, labels for all events are provided. Otherwise,
        only the event associated with this head is provided.
      regularization_losses: See `base_head.Head` for details.

    Returns:
      eval_metrics: See `base_head.Head` for details.
    """
    hparams = self._model_hparams
    processed_logits = self._processed_logits(logits)
    processed_labels = self._processed_labels(logits, labels)
    time_to_event, censored = processed_labels
    unweighted_loss, weights = self._unweighted_loss_and_weights(
        processed_logits, processed_labels, features)

    # Update metrics.
    eval_metrics[self._loss_mean_key].update_state(
        values=unweighted_loss, sample_weight=weights)

    prob_key = prediction_keys.PredictionKeys.PROBABILITIES
    predictions = self.predictions(logits, [prob_key])
    probabilities = predictions[prob_key]
    base_head.update_metric_with_broadcast_weights(
        eval_metrics[self._prediction_mean_key], probabilities, weights)

    if hparams.da_tlen > 0 and not hparams.event_relation:
        # None of the model's inputs depend on the window, so build it once
        # here instead of once per window; it is also needed after the loop.
        model = self._survival_model(
            params=processed_logits,
            labels=processed_labels,
            event_index=self._event_index,
            model_hparams=hparams)

        y_true_list = []
        y_pred_list = []
        num_windows = int(hparams.da_tlen / SLOT_TO_WINDOW) + 1
        for i in range(num_windows):
            window_start_t = (
                i * UNITS_IN_HR * hparams.da_sslot * SLOT_TO_WINDOW)
            window_end_t = (
                (i + 1) * UNITS_IN_HR * hparams.da_sslot * SLOT_TO_WINDOW)
            # probabilities_at_window shape [batch_size, 1].
            probabilities_at_window = model.probability_within_window(
                window_start_t=window_start_t, window_end_t=window_end_t)
            base_head.update_metric_with_broadcast_weights(
                eval_metrics[self._probablity_within_window_list[i]],
                probabilities_at_window, weights)
            y_true, y_pred = self._true_and_predict_within_window(
                time_to_event, censored, probabilities_at_window,
                window_start_t, window_end_t)
            # y_true, y_pred shape [batch_size]
            y_true_list.append(y_true)
            y_pred_list.append(y_pred)
            tf.logging.info(y_true)
            tf.logging.info(y_pred)

        # AUC over all windows pooled together; concatenate once and reuse
        # the result for both AUC metrics.
        all_y_true = tf.concat(y_true_list, axis=0)
        all_y_pred = tf.concat(y_pred_list, axis=0)
        eval_metrics[self._auc_pr].update_state(all_y_true, all_y_pred)
        eval_metrics[self._auc_roc].update_state(all_y_true, all_y_pred)

        # 24hr and 48hr window AUC, both measured from time 0.
        fixed_horizons = ((self._auc_roc_24, 24), (self._auc_roc_48, 48))
        for auc_key, horizon_hr in fixed_horizons:
            horizon_t = UNITS_IN_HR * horizon_hr
            probabilities_at_horizon = model.probability_within_window(
                window_start_t=0, window_end_t=horizon_t)
            y_true_h, y_pred_h = self._true_and_predict_within_window(
                time_to_event, censored, probabilities_at_horizon,
                0, horizon_t)
            eval_metrics[auc_key].update_state(y_true_h, y_pred_h)

        # Mean absolute error of the predicted time, restricted to rows
        # where this head's event is not censored.
        not_censored = tf.logical_not(censored[:, self._event_index])
        observed_time = tf.boolean_mask(
            time_to_event[:, self._event_index] / UNITS_IN_HR, not_censored)
        predicted_time = tf.boolean_mask(
            model.predicted_time(), not_censored)
        eval_metrics[self._mean_abs_error].update_state(
            observed_time, predicted_time)

    # label_mean represents the percentage of censored events. In case of
    # mortality, it is the percentage of survived patients.
    base_head.update_metric_with_broadcast_weights(
        eval_metrics[self._label_mean_key], censored, weights)

    if regularization_losses is not None:
        regularization_loss = math_ops.add_n(regularization_losses)
        eval_metrics[self._loss_regularization_key].update_state(
            values=regularization_loss)
    return eval_metrics