Пример #1
0
  def predictions(self, logits, keys=None):
    """Builds the prediction tensors requested by `keys`.

    See `base_head.Head` for the general contract.

    Args:
      logits: logits `Tensor` with shape `[D0, D1, ... DN, logits_dimension]`.
        For many applications, the shape is `[batch_size, logits_dimension]`.
      keys: a list of prediction keys. Each key may be the class variable of
        prediction_keys.PredictionKeys or its string value, e.g.
        prediction_keys.PredictionKeys.LOGITS or 'logits'. When empty or
        `None`, every supported key (logits, probabilities, classes) is
        produced.

    Returns:
      A dict mapping each requested prediction key to its `Tensor`.
    """
    pk = prediction_keys.PredictionKeys
    supported = [pk.LOGITS, pk.PROBABILITIES, pk.CLASSES]
    # Validate caller-supplied keys before building any ops; fall back to the
    # full supported set when nothing was requested.
    if not keys:
      keys = supported
    else:
      base_head.check_prediction_keys(keys, supported)
    checked_logits = base_head.check_logits_final_dim(
        logits, self.logits_dimension)

    result = {}
    with ops.name_scope('predictions', values=(checked_logits,)):
      if pk.LOGITS in keys:
        result[pk.LOGITS] = checked_logits
      if pk.PROBABILITIES in keys:
        # Element-wise sigmoid: each logit gets its own probability.
        result[pk.PROBABILITIES] = tf.math.sigmoid(
            checked_logits, name=pk.PROBABILITIES)
      if pk.CLASSES in keys:
        result[pk.CLASSES] = base_head.all_classes(
            checked_logits, self._n_classes, self._label_vocabulary)
    return result
Пример #2
0
    def predictions(self, logits, keys=None):
        """Builds the prediction tensors requested by `keys`.

        See `base_head.Head` for the general contract.

        Args:
          logits: logits `Tensor` with shape
            `[D0, D1, ... DN, logits_dimension]`. For many applications, the
            shape is `[batch_size, logits_dimension]`.
          keys: a list or tuple of prediction keys. Each key may be the class
            variable of prediction_keys.PredictionKeys or its string value,
            e.g. prediction_keys.PredictionKeys.CLASSES or 'classes'. When
            empty or `None`, predictions for every valid key are returned.

        Returns:
          A dict mapping each requested prediction key to its `Tensor`.
        """
        pk = prediction_keys.PredictionKeys
        supported = [
            pk.LOGITS, pk.PROBABILITIES, pk.CLASS_IDS,
            pk.CLASSES, pk.ALL_CLASS_IDS, pk.ALL_CLASSES,
        ]
        # Validate caller-supplied keys before building any ops; fall back to
        # the full supported set when nothing was requested.
        if not keys:
            keys = supported
        else:
            base_head.check_prediction_keys(keys, supported)
        checked_logits = base_head.check_logits_final_dim(
            logits, self.logits_dimension)

        want_ids = pk.CLASS_IDS in keys
        want_classes = pk.CLASSES in keys

        result = {}
        with ops.name_scope('predictions', values=(checked_logits, )):
            if pk.LOGITS in keys:
                result[pk.LOGITS] = checked_logits
            if pk.PROBABILITIES in keys:
                # tf.nn.softmax normalizes over the last axis by default.
                result[pk.PROBABILITIES] = tf.compat.v1.nn.softmax(
                    checked_logits, name=pk.PROBABILITIES)
            if want_ids or want_classes:
                # argmax over the last axis yields shape [D0, ... DN];
                # expand_dims re-adds a trailing axis, giving [D0, ... DN, 1]
                # (i.e. [batch_size, 1] in the common 2-D case).
                ids = tf.compat.v1.expand_dims(
                    tf.compat.v1.math.argmax(checked_logits,
                                             axis=-1,
                                             name=pk.CLASS_IDS),
                    axis=-1)
                if want_ids:
                    result[pk.CLASS_IDS] = ids
                if want_classes:
                    if self._label_vocabulary:
                        # NOTE(review): presumably maps integer ids to the
                        # vocabulary strings; table is built elsewhere.
                        labels = self._class_string_table.lookup(ids)
                    else:
                        # No vocabulary: the class label is the id as a string.
                        labels = tf.strings.as_string(ids, name='str_classes')
                    result[pk.CLASSES] = labels
            if pk.ALL_CLASS_IDS in keys:
                result[pk.ALL_CLASS_IDS] = base_head.all_class_ids(
                    checked_logits, n_classes=self._n_classes)
            if pk.ALL_CLASSES in keys:
                result[pk.ALL_CLASSES] = base_head.all_classes(
                    checked_logits,
                    n_classes=self._n_classes,
                    label_vocabulary=self._label_vocabulary)
        return result