def predictions(self, logits, keys=None):
    """Build the requested prediction tensors from `logits`.

    See `base_head.Head` for the general contract.

    Args:
      logits: logits `Tensor` with shape `[D0, D1, ... DN, logits_dimension]`.
        In the common case the shape is `[batch_size, logits_dimension]`.
      keys: optional list of prediction keys. Each entry may be a class
        variable of prediction_keys.PredictionKeys or its string value, e.g.
        prediction_keys.PredictionKeys.LOGITS or 'logits'. When omitted or
        empty, predictions are produced for every supported key (LOGITS,
        PROBABILITIES, CLASSES).

    Returns:
      A dict mapping each requested prediction key to its `Tensor`.
    """
    key_names = prediction_keys.PredictionKeys
    supported = [key_names.LOGITS, key_names.PROBABILITIES, key_names.CLASSES]
    if not keys:
        keys = supported
    else:
        # Requested keys are checked against the supported set; see
        # base_head.check_prediction_keys for the exact failure behavior.
        base_head.check_prediction_keys(keys, supported)
    logits = base_head.check_logits_final_dim(logits, self.logits_dimension)

    result = {}
    with ops.name_scope('predictions', values=(logits,)):
        if key_names.LOGITS in keys:
            result[key_names.LOGITS] = logits
        if key_names.PROBABILITIES in keys:
            # Element-wise sigmoid: each logit is mapped independently.
            result[key_names.PROBABILITIES] = tf.math.sigmoid(
                logits, name=key_names.PROBABILITIES)
        if key_names.CLASSES in keys:
            result[key_names.CLASSES] = base_head.all_classes(
                logits, self._n_classes, self._label_vocabulary)
    return result
def predictions(self, logits, keys=None):
    """Build the requested prediction tensors from `logits`.

    See `base_head.Head` for the general contract.

    Args:
      logits: logits `Tensor` with shape `[D0, D1, ... DN, logits_dimension]`.
        In the common case the shape is `[batch_size, logits_dimension]`.
      keys: optional list or tuple of prediction keys. Each entry may be a
        class variable of prediction_keys.PredictionKeys or its string value,
        e.g. prediction_keys.PredictionKeys.CLASSES or 'classes'. When
        omitted or empty, predictions are produced for every supported key.

    Returns:
      A dict mapping each requested prediction key to its `Tensor`.
    """
    key_names = prediction_keys.PredictionKeys
    supported = [
        key_names.LOGITS,
        key_names.PROBABILITIES,
        key_names.CLASS_IDS,
        key_names.CLASSES,
        key_names.ALL_CLASS_IDS,
        key_names.ALL_CLASSES,
    ]
    if not keys:
        keys = supported
    else:
        # Requested keys are checked against the supported set; see
        # base_head.check_prediction_keys for the exact failure behavior.
        base_head.check_prediction_keys(keys, supported)
    logits = base_head.check_logits_final_dim(logits, self.logits_dimension)

    result = {}
    with ops.name_scope('predictions', values=(logits,)):
        if key_names.LOGITS in keys:
            result[key_names.LOGITS] = logits
        if key_names.PROBABILITIES in keys:
            # Softmax normalizes across the last (class) axis.
            result[key_names.PROBABILITIES] = tf.compat.v1.nn.softmax(
                logits, name=key_names.PROBABILITIES)

        want_ids = key_names.CLASS_IDS in keys
        want_classes = key_names.CLASSES in keys
        if want_ids or want_classes:
            # argmax over the last axis drops it: shape [D0, D1, ... DN].
            top_ids = tf.compat.v1.math.argmax(
                logits, axis=-1, name=key_names.CLASS_IDS)
            # Re-append a trailing size-1 axis: shape [D0, D1, ... DN, 1]
            # (i.e. [batch_size, 1] for rank-2 logits).
            top_ids = tf.compat.v1.expand_dims(top_ids, axis=-1)
            if want_ids:
                result[key_names.CLASS_IDS] = top_ids
            if want_classes:
                if self._label_vocabulary:
                    # Map integer ids to vocabulary strings.
                    result[key_names.CLASSES] = (
                        self._class_string_table.lookup(top_ids))
                else:
                    # No vocabulary: the class "name" is the id as a string.
                    result[key_names.CLASSES] = tf.strings.as_string(
                        top_ids, name='str_classes')

        if key_names.ALL_CLASS_IDS in keys:
            result[key_names.ALL_CLASS_IDS] = base_head.all_class_ids(
                logits, n_classes=self._n_classes)
        if key_names.ALL_CLASSES in keys:
            result[key_names.ALL_CLASSES] = base_head.all_classes(
                logits,
                n_classes=self._n_classes,
                label_vocabulary=self._label_vocabulary)
    return result