Example #1
0
    def update_state(self, labels, probabilities, **kwargs):
        """Adds one batch of predictions to the per-bin accumulators.

        Labels and probabilities are flattened first, so every prediction in
        the batch contributes to the statistics the ECE is computed from.

        Args:
          labels: Tensor of shape [..., ] of class labels in [0, k-1].
          probabilities: Tensor of shape [..., ], [..., 1] or [..., k] of
            normalized probabilities associated with the True class in the
            binary case, or with each of k classes in the multiclass case.
          **kwargs: Other potential keywords, which will be ignored by this
            method.
        """
        del kwargs  # unused
        metric_dtype = self.dtype
        labels = tf.convert_to_tensor(labels)
        probabilities = tf.cast(probabilities, metric_dtype)

        # Bring labels to shape [N, ] and probabilities to [N, 1] or [N, k].
        if tf.rank(labels) != 1:
            labels = tf.reshape(labels, [-1])
        num_examples = tf.shape(labels)[0]
        if (tf.rank(probabilities) != 2 or
                tf.shape(probabilities)[0] != num_examples):
            probabilities = tf.reshape(probabilities, [num_examples, -1])

        # A single column is P(True class); widen [N, 1] to [N, 2].
        # NOTE: XLA does not allow for different shapes in the branches of a
        # conditional statement. Slicing the last `num_classes` columns makes
        # the widening branch produce the same shape as the pass-through one.
        num_given_classes = tf.shape(probabilities)[-1]
        num_classes = tf.math.maximum(2, num_given_classes)

        def _with_complement():
            widened = tf.concat([1. - probabilities, probabilities], axis=-1)
            return widened[:, -num_classes:]

        def _unchanged():
            return probabilities

        probabilities = tf.cond(num_given_classes < 2,
                                true_fn=_with_complement,
                                false_fn=_unchanged)

        # The prediction is the most likely class; its probability is the
        # confidence that gets binned.
        predicted_class = tf.math.argmax(probabilities, axis=-1)
        confidence = tf.math.reduce_max(probabilities, axis=-1)
        is_correct = tf.cast(
            tf.math.equal(predicted_class,
                          tf.cast(labels, predicted_class.dtype)),
            metric_dtype)

        bin_ids = tf.histogram_fixed_width_bins(
            confidence,
            tf.constant([0., 1.], dtype=metric_dtype),
            nbins=self.num_bins)

        def _sum_per_bin(values):
            # Sums `values` over the predictions that fall into each bin.
            return tf.math.unsorted_segment_sum(
                data=values,
                segment_ids=bin_ids,
                num_segments=self.num_bins)

        correct_per_bin = _sum_per_bin(is_correct)
        confidence_per_bin = _sum_per_bin(confidence)
        count_per_bin = tf.cast(_sum_per_bin(tf.ones_like(bin_ids)),
                                metric_dtype)

        self.correct_sums.assign_add(correct_per_bin)
        self.prob_sums.assign_add(confidence_per_bin)
        self.counts.assign_add(count_per_bin)
Example #2
0
def _compute_calibration_bin_statistics(num_bins,
                                        logits=None,
                                        labels_true=None,
                                        labels_predicted=None):
    """Bins predictions by confidence and tallies the per-bin statistics.

    Args:
      num_bins: int, number of probability bins, e.g. 10.
      logits: Tensor, (n,nlabels), with logits for n instances and nlabels.
      labels_true: Tensor, (n,), with tf.int32 or tf.int64 elements containing
        ground truth class labels in the range [0,nlabels).
      labels_predicted: Tensor, (n,), with tf.int32 or tf.int64 elements
        containing decisions of the predictive system.  If `None`, the argmax
        decision over `logits` is used.

    Returns:
      bz: Tensor, shape (2,num_bins), tf.int32, counts of incorrect (row 0)
        and correct (row 1) predictions in each of the `num_bins` probability
        bins.
      pmean_observed: Tensor, shape (num_bins,), same dtype as `logits`, the
        mean predictive probabilities in each probability bin.
    """
    if labels_predicted is not None:
        decisions = labels_predicted
    else:
        # Without explicit decisions, pick the most probable label, which is
        # the optimal decision under 0/1 loss.
        decisions = tf.argmax(logits, axis=1, output_type=labels_true.dtype)

    is_correct = tf.cast(tf.equal(decisions, labels_true), tf.int32)

    # p(decision | x): the probability assigned to each chosen label.
    class_probs = tf.nn.softmax(logits, axis=1)
    decision_probs = tf.gather(class_probs,
                               decisions[:, tf.newaxis],
                               batch_dims=1)
    decision_probs = tf.reshape(decision_probs, (ps.size(decision_probs), ))

    # One flat histogram over (correctness, bin) pairs: the first `num_bins`
    # cells count incorrect predictions, the last `num_bins` correct ones.
    bin_ids = tf.histogram_fixed_width_bins(decision_probs, [0.0, 1.0],
                                            nbins=num_bins)
    num_cells = 2 * num_bins
    flat_counts = tf.math.bincount(is_correct * num_bins + bin_ids,
                                   minlength=num_cells,
                                   maxlength=num_cells)
    event_bin_counts = tf.reshape(flat_counts, (2, num_bins))

    # Mean confidence per bin; `tiny` keeps the division finite for bins
    # that received no predictions.
    prob_sum_per_bin = tf.math.unsorted_segment_sum(decision_probs, bin_ids,
                                                    num_bins)
    tiny = np.finfo(dtype_util.as_numpy_dtype(logits.dtype)).tiny
    count_per_bin = tf.cast(tf.reduce_sum(event_bin_counts, axis=0),
                            logits.dtype)
    pmean_observed = prob_sum_per_bin / (count_per_bin + tiny)

    return event_bin_counts, pmean_observed
Example #3
0
    def update_state(self, labels, probabilities, **kwargs):
        """Adds one batch of predictions to the per-bin accumulators.

        Args:
          labels: Tensor of shape (N,) of class labels, one per example.
          probabilities: Tensor of shape (N,) or (N, k) of normalized
            probabilities associated with the True class in the binary case
            or with each of k classes in the multiclass case.
          **kwargs: Other potential keywords, which will be ignored by this
            method.
        """
        del kwargs  # unused
        metric_dtype = self.dtype
        labels = tf.squeeze(tf.convert_to_tensor(labels))
        probabilities = tf.convert_to_tensor(probabilities,
                                             dtype=metric_dtype)

        if self.num_classes == 2:
            # Binary case: make sure the probabilities are [n, 2] rather
            # than [n, 1] or [n,].
            batch_size = tf.shape(probabilities)[0]
            num_columns = tf.size(probabilities) // batch_size
            probabilities = tf.reshape(probabilities,
                                       [batch_size, num_columns])

            def _with_complement():
                # Prepend P(False class) = 1 - P(True class).
                return tf.concat([1. - probabilities, probabilities], axis=1)

            def _unchanged():
                return probabilities

            probabilities = tf.cond(num_columns < 2,
                                    true_fn=_with_complement,
                                    false_fn=_unchanged)

        # The prediction is the most likely class; its probability is the
        # confidence that gets binned.
        predicted_class = tf.argmax(probabilities, axis=1)
        confidence = tf.reduce_max(probabilities, axis=1)
        is_correct = tf.cast(
            tf.equal(predicted_class,
                     tf.cast(labels, predicted_class.dtype)),
            metric_dtype)

        bin_ids = tf.histogram_fixed_width_bins(
            confidence,
            tf.constant([0., 1.], dtype=metric_dtype),
            nbins=self.num_bins)

        def _sum_per_bin(values):
            # Sums `values` over the predictions that fall into each bin.
            return tf.math.unsorted_segment_sum(
                data=values,
                segment_ids=bin_ids,
                num_segments=self.num_bins)

        correct_per_bin = _sum_per_bin(is_correct)
        confidence_per_bin = _sum_per_bin(confidence)
        count_per_bin = tf.cast(_sum_per_bin(tf.ones_like(bin_ids)),
                                metric_dtype)

        self.correct_sums.assign_add(correct_per_bin)
        self.prob_sums.assign_add(confidence_per_bin)
        self.counts.assign_add(count_per_bin)
def _compute_calibration_bin_statistics(
    num_bins, logits=None, probabilities=None,
    labels_true=None, labels_predicted=None):
  """Compute binning statistics required for calibration measures.

  Exactly one of `logits` or `probabilities` must be provided.  Logits are
  turned into probabilities with a softmax over axis 1.

  Args:
    num_bins: int, number of probability bins, e.g. 10.
    logits: Tensor, (n,nlabels), with logits for n instances and nlabels.
    probabilities: Tensor, (n,nlabels), with probs for n instances and nlabels.
    labels_true: Tensor, (n,), with tf.int32 or tf.int64 elements containing
      ground truth class labels in the range [0,nlabels).
    labels_predicted: Tensor, (n,), with tf.int32 or tf.int64 elements
      containing decisions of the predictive system.  If `None`, we will use
      the argmax decision using the `probabilities`.

  Returns:
    bz: Tensor, shape (2,num_bins), tf.int32, counts of incorrect (row 0) and
      correct (row 1) predictions in each of the `num_bins` probability bins.
    pmean_observed: Tensor, shape (num_bins,), same dtype as the
      probabilities, the mean predictive probabilities in each probability
      bin.

  Raises:
    ValueError: if both or neither of `logits` and `probabilities` are given,
      or if the last dimension has size 1 (or the probabilities are rank 1),
      i.e. binary inputs are not of shape (n, 2).
  """
  if (logits is None) == (probabilities is None):
    raise ValueError(
        "_compute_calibration_bin_statistics expects exactly one of logits or "
        "probabilities.")
  if probabilities is None:
    if logits.get_shape().as_list()[-1] == 1:
      # The shape is formatted into the message; passing it as a second
      # exception argument would render the error as a tuple.
      raise ValueError(
          "_compute_calibration_bin_statistics expects logits for binary"
          " classification of shape (n, 2) for nlabels=2 but got "
          "{}".format(logits.get_shape()))
    probabilities = tf.math.softmax(logits, axis=1)
  prob_shape = probabilities.get_shape().as_list()
  if prob_shape[-1] == 1 or len(prob_shape) == 1:
    raise ValueError(
        "_compute_calibration_bin_statistics expects probabilities for binary"
        " classification of shape (n, 2) for nlabels=2 but got "
        "{}".format(probabilities.get_shape()))

  if labels_predicted is None:
    # If no labels are provided, we take the label with the maximum probability
    # decision.  This corresponds to the optimal expected minimum loss decision
    # under 0/1 loss.
    pred_y = tf.cast(tf.argmax(probabilities, axis=1), tf.int32)
  else:
    pred_y = tf.convert_to_tensor(labels_predicted)

  # Compare in the dtype of the decisions so that int64 ground-truth labels
  # can be matched against int32 decisions (and vice versa).
  correct = tf.cast(
      tf.equal(pred_y, tf.cast(labels_true, pred_y.dtype)), tf.int32)

  # Collect predicted probabilities of decisions
  prob_y = tf.gather(probabilities,
                     tf.expand_dims(pred_y, 1),
                     batch_dims=1)  # p(pred_y | x)
  prob_y = tf.reshape(prob_y, (tf.size(prob_y),))

  # All floating point work below follows the dtype of the probabilities
  # instead of assuming float32.
  float_dtype = probabilities.dtype

  # Compute b/z histogram statistics:
  # bz[0,bin] contains counts of incorrect predictions in the probability bin.
  # bz[1,bin] contains counts of correct predictions in the probability bin.
  bins = tf.histogram_fixed_width_bins(
      prob_y, tf.constant([0.0, 1.0], dtype=float_dtype), nbins=num_bins)
  event_bin_counts = tf.math.bincount(
      correct*num_bins + bins,
      minlength=2*num_bins,
      maxlength=2*num_bins)
  event_bin_counts = tf.reshape(event_bin_counts, (2, num_bins))

  # Compute mean predicted probability value in each of the `num_bins` bins.
  # `tiny` keeps the division finite for bins that received no predictions.
  pmean_observed = tf.math.unsorted_segment_sum(prob_y, bins, num_bins)
  tiny = np.finfo(float_dtype.as_numpy_dtype).tiny
  pmean_observed = pmean_observed / (
      tf.cast(tf.reduce_sum(event_bin_counts, axis=0), float_dtype) + tiny)

  return event_bin_counts, pmean_observed