Пример #1
0
 def __init__(self,
              label_dimension=1,
              weight_column=None,
              loss_reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
              loss_fn=None,
              inverse_link_fn=None,
              name=None):
     """Validates the arguments, stores them, and builds summary keys.

     Args:
       label_dimension: Must be at least 1. Stored as the logits dimension.
       weight_column: Stored unchanged on the instance.
       loss_reduction: Checked with `base_head.validate_loss_reduction`.
       loss_fn: Optional. When truthy, it is checked with
         `base_head.validate_loss_fn_args`.
       inverse_link_fn: Stored unchanged on the instance.
       name: Stored unchanged on the instance.

     Raises:
       ValueError: If `label_dimension` is smaller than 1.
     """
     if label_dimension < 1:
         message = 'Invalid label_dimension {}.'.format(label_dimension)
         raise ValueError(message)
     base_head.validate_loss_reduction(loss_reduction)
     if loss_fn:
         base_head.validate_loss_fn_args(loss_fn)

     # Configuration is kept exactly as given.
     # NOTE(review): `_summary_key` presumably reads `self._name`, so the
     # name is assigned before any summary key is built -- confirm.
     self._name = name
     self._weight_column = weight_column
     self._inverse_link_fn = inverse_link_fn
     self._loss_fn = loss_fn
     self._loss_reduction = loss_reduction
     self._logits_dimension = label_dimension

     # Summary keys derived from the shared metric key names.
     metric_names = metric_keys.MetricKeys
     summarize = self._summary_key
     self._loss_mean_key = summarize(metric_names.LOSS_MEAN)
     self._prediction_mean_key = summarize(metric_names.PREDICTION_MEAN)
     self._label_mean_key = summarize(metric_names.LABEL_MEAN)
     self._loss_regularization_key = summarize(
         metric_names.LOSS_REGULARIZATION)
Пример #2
0
 def __init__(self,
              n_classes,
              weight_column=None,
              label_vocabulary=None,
              loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE,
              loss_fn=None,
              name=None):
     """Validates the arguments, stores them, and builds summary keys.

     Args:
       n_classes: Number of classes; must be greater than 2.
       weight_column: Stored unchanged on the instance.
       label_vocabulary: Optional list or tuple. Stored unchanged.
       loss_reduction: Must be one of `losses.Reduction.all()` other than
         `losses.Reduction.NONE`.
       loss_fn: Optional. When truthy, it is checked with
         `base_head.validate_loss_fn_args`.
       name: Stored unchanged on the instance.

     Raises:
       ValueError: If `n_classes` is None or not greater than 2, if
         `label_vocabulary` is neither None nor a list/tuple, or if
         `loss_reduction` is not an accepted value.
     """
     if n_classes is None or n_classes <= 2:
         raise ValueError('n_classes must be > 2: {}.'.format(n_classes))

     vocabulary_ok = (label_vocabulary is None
                      or isinstance(label_vocabulary, (list, tuple)))
     if not vocabulary_ok:
         vocabulary_type = type(label_vocabulary)
         raise ValueError(
             'label_vocabulary should be a list or a tuple. '
             'Given type: {}'.format(vocabulary_type))

     reduction_ok = (loss_reduction in losses.Reduction.all()
                     and not loss_reduction == losses.Reduction.NONE)
     if not reduction_ok:
         raise ValueError(
             'Invalid loss_reduction: {}'.format(loss_reduction))

     if loss_fn:
         base_head.validate_loss_fn_args(loss_fn)

     # Configuration is kept exactly as given.
     # NOTE(review): `_summary_key` presumably reads `self._name`, so the
     # name is assigned before any summary key is built -- confirm.
     self._name = name
     self._loss_fn = loss_fn
     self._loss_reduction = loss_reduction
     self._label_vocabulary = label_vocabulary
     self._weight_column = weight_column
     self._n_classes = n_classes

     # Summary keys derived from the shared metric key names.
     metric_names = metric_keys.MetricKeys
     summarize = self._summary_key
     self._loss_mean_key = summarize(metric_names.LOSS_MEAN)
     self._accuracy_key = summarize(metric_names.ACCURACY)
     self._loss_regularization_key = summarize(
         metric_names.LOSS_REGULARIZATION)
Пример #3
0
 def __init__(self,
              n_classes,
              weight_column=None,
              label_vocabulary=None,
              loss_reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
              loss_fn=None,
              name=None):
     """Validates the arguments, stores them, and builds summary keys.

     Args:
       n_classes: Number of classes; must not be None. The stored value is
         whatever `base_head.validate_n_classes` returns for it.
       weight_column: Stored unchanged on the instance.
       label_vocabulary: Optional list or tuple with exactly `n_classes`
         items. Stored unchanged.
       loss_reduction: Checked with `base_head.validate_loss_reduction`.
       loss_fn: Optional. When truthy, it is checked with
         `base_head.validate_loss_fn_args`.
       name: Stored unchanged on the instance.

     Raises:
       ValueError: If `n_classes` is None, or if `label_vocabulary` is
         given but is not a list/tuple or its length differs from
         `n_classes`.
     """
     if n_classes is None:
         raise ValueError('n_classes cannot be None')

     if label_vocabulary is not None:
         if not isinstance(label_vocabulary, (list, tuple)):
             raise ValueError(
                 'label_vocabulary should be a list or a tuple. '
                 'Given type: {}'.format(type(label_vocabulary)))
         vocabulary_size = len(label_vocabulary)
         if vocabulary_size != n_classes:
             raise ValueError(
                 '"label_vocabulary" does not have "n_classes" items. '
                 'len(label_vocabulary)={}, n_classes={}, '
                 'label_vocabulary={}'.format(
                     vocabulary_size, n_classes, label_vocabulary))

     base_head.validate_loss_reduction(loss_reduction)
     if loss_fn:
         base_head.validate_loss_fn_args(loss_fn)
     # Runs before any attribute is set, so a failure here leaves the
     # instance untouched.
     checked_n_classes = base_head.validate_n_classes(n_classes)

     # NOTE(review): `_summary_key` presumably reads `self._name`, so the
     # name is assigned before any summary key is built -- confirm.
     self._name = name
     self._loss_fn = loss_fn
     self._loss_reduction = loss_reduction
     self._label_vocabulary = label_vocabulary
     self._weight_column = weight_column
     self._n_classes = checked_n_classes

     # Summary keys derived from the shared metric key names.
     metric_names = metric_keys.MetricKeys
     summarize = self._summary_key
     self._loss_mean_key = summarize(metric_names.LOSS_MEAN)
     self._accuracy_key = summarize(metric_names.ACCURACY)
     self._loss_regularization_key = summarize(
         metric_names.LOSS_REGULARIZATION)
Пример #4
0
    def __init__(self,
                 weight_column=None,
                 thresholds=None,
                 label_vocabulary=None,
                 loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE,
                 loss_fn=None,
                 name=None):
        """Validates the arguments, stores them, and builds summary keys.

        Args:
          weight_column: Stored unchanged on the instance.
          thresholds: Optional iterable of values, each strictly between 0
            and 1. Stored as a tuple (empty when nothing truthy is given).
          label_vocabulary: Optional list or tuple. Stored unchanged.
          loss_reduction: Must be one of `losses.Reduction.all()` other
            than `losses.Reduction.NONE`.
          loss_fn: Optional. When truthy, it is checked with
            `base_head.validate_loss_fn_args`.
          name: Stored unchanged on the instance.

        Raises:
          ValueError: If `label_vocabulary` is neither None nor a
            list/tuple, if any threshold is `<= 0.0` or `>= 1.0`, or if
            `loss_reduction` is not an accepted value.
        """
        if not (label_vocabulary is None
                or isinstance(label_vocabulary, (list, tuple))):
            raise ValueError(
                'label_vocabulary should be a list or a tuple. '
                'Given type: {}'.format(type(label_vocabulary)))

        thresholds = tuple(thresholds or ())
        if any(cutoff <= 0.0 or cutoff >= 1.0 for cutoff in thresholds):
            raise ValueError('thresholds not in (0, 1): {}.'.format(
                (thresholds, )))

        reduction_ok = (loss_reduction in losses.Reduction.all()
                        and not loss_reduction == losses.Reduction.NONE)
        if not reduction_ok:
            raise ValueError(
                'Invalid loss_reduction: {}. See `tf.losses.Reduction` for '
                'valid options.'.format(loss_reduction))

        if loss_fn:
            base_head.validate_loss_fn_args(loss_fn)

        # NOTE(review): `_summary_key` presumably reads `self._name`, so the
        # name is assigned before any summary key is built -- confirm.
        self._name = name
        self._loss_fn = loss_fn
        self._loss_reduction = loss_reduction
        self._label_vocabulary = label_vocabulary
        self._thresholds = thresholds
        self._weight_column = weight_column

        # Summary keys derived from the shared metric key names.
        metric_names = metric_keys.MetricKeys
        summarize = self._summary_key
        self._loss_mean_key = summarize(metric_names.LOSS_MEAN)
        self._accuracy_key = summarize(metric_names.ACCURACY)
        self._precision_key = summarize(metric_names.PRECISION)
        self._recall_key = summarize(metric_names.RECALL)
        self._prediction_mean_key = summarize(metric_names.PREDICTION_MEAN)
        self._label_mean_key = summarize(metric_names.LABEL_MEAN)
        self._accuracy_baseline_key = summarize(
            metric_names.ACCURACY_BASELINE)
        self._auc_key = summarize(metric_names.AUC)
        self._auc_pr_key = summarize(metric_names.AUC_PR)
        self._loss_regularization_key = summarize(
            metric_names.LOSS_REGULARIZATION)

        # One (accuracy, precision, recall) key triple per threshold, built
        # in threshold order and then split into three parallel tuples.
        threshold_keys = [
            (summarize(metric_names.ACCURACY_AT_THRESHOLD % cutoff),
             summarize(metric_names.PRECISION_AT_THRESHOLD % cutoff),
             summarize(metric_names.RECALL_AT_THRESHOLD % cutoff))
            for cutoff in self._thresholds
        ]
        self._accuracy_keys = tuple(row[0] for row in threshold_keys)
        self._precision_keys = tuple(row[1] for row in threshold_keys)
        self._recall_keys = tuple(row[2] for row in threshold_keys)
Пример #5
0
    def __init__(self,
                 n_classes,
                 weight_column=None,
                 thresholds=None,
                 label_vocabulary=None,
                 loss_reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
                 loss_fn=None,
                 classes_for_class_based_metrics=None,
                 name=None):
        """Validates the arguments, stores them, and builds summary keys.

        Args:
          n_classes: Number of classes; must be at least 2.
          weight_column: Stored unchanged on the instance.
          thresholds: Optional iterable of values, each strictly between 0
            and 1. Stored as a tuple (empty when nothing truthy is given).
          label_vocabulary: Optional list or tuple with exactly `n_classes`
            items. Stored unchanged.
          loss_reduction: Checked with `base_head.validate_loss_reduction`.
          loss_fn: Optional. When truthy, it is checked with
            `base_head.validate_loss_fn_args`.
          classes_for_class_based_metrics: Optional iterable of classes. If
            its first element is a string, every entry is treated as a class
            name and translated to its index in `label_vocabulary`;
            otherwise every entry is treated as a class id in
            `[0, n_classes - 1]`. Stored as a tuple of class ids.
          name: Stored unchanged on the instance.

        Raises:
          ValueError: If `n_classes` is None or smaller than 2; if a
            threshold is `<= 0.0` or `>= 1.0`; if `label_vocabulary` is not
            a list/tuple or its length differs from `n_classes`; if class
            names are given without a `label_vocabulary` or a name is not in
            it; or if a class id is outside `[0, n_classes - 1]`.
        """
        if n_classes is None or n_classes < 2:
            raise ValueError(
                'n_classes must be > 1 for multi-label classification. '
                'Given: {}'.format(n_classes))
        thresholds = tuple(thresholds) if thresholds else tuple()
        for threshold in thresholds:
            if (threshold <= 0.0) or (threshold >= 1.0):
                raise ValueError(
                    'thresholds must be in (0, 1) range. Given: {}'.format(
                        threshold))
        if label_vocabulary is not None:
            if not isinstance(label_vocabulary, (list, tuple)):
                raise ValueError('label_vocabulary must be a list or tuple. '
                                 'Given type: {}'.format(
                                     type(label_vocabulary)))
            if len(label_vocabulary) != n_classes:
                raise ValueError(
                    'Length of label_vocabulary must be n_classes ({}). '
                    'Given: {}'.format(n_classes, len(label_vocabulary)))

        if loss_fn:
            base_head.validate_loss_fn_args(loss_fn)
        base_head.validate_loss_reduction(loss_reduction)
        if classes_for_class_based_metrics:
            classes_for_class_based_metrics = tuple(
                classes_for_class_based_metrics)
            # Only the first element decides whether the entries are class
            # names or class ids.
            if isinstance(classes_for_class_based_metrics[0],
                          six.string_types):
                if not label_vocabulary:
                    raise ValueError(
                        'label_vocabulary must be provided when '
                        'classes_for_class_based_metrics are strings.')
                class_ids = []
                for class_string in classes_for_class_based_metrics:
                    # `index` raises a ValueError that does not say which
                    # argument was wrong, so re-raise one that does.
                    try:
                        class_id = label_vocabulary.index(class_string)
                    except ValueError:
                        raise ValueError(
                            'All classes_for_class_based_metrics must be in '
                            'label_vocabulary. Given: {}'.format(
                                class_string))
                    class_ids.append(class_id)
                classes_for_class_based_metrics = tuple(class_ids)
            else:
                for class_id in classes_for_class_based_metrics:
                    if (class_id < 0) or (class_id >= n_classes):
                        raise ValueError(
                            'All classes_for_class_based_metrics must be in range [0, {}]. '
                            'Given: {}'.format(n_classes - 1, class_id))
        else:
            classes_for_class_based_metrics = tuple()
        self._n_classes = n_classes
        self._weight_column = weight_column
        self._thresholds = thresholds
        self._label_vocabulary = label_vocabulary
        self._loss_reduction = loss_reduction
        self._loss_fn = loss_fn
        self._classes_for_class_based_metrics = classes_for_class_based_metrics
        # NOTE(review): `_summary_key` presumably reads `self._name`, so the
        # name is assigned before any summary key is built -- confirm.
        self._name = name
        # Metric keys.
        keys = metric_keys.MetricKeys
        self._loss_mean_key = self._summary_key(keys.LOSS_MEAN)
        self._auc_key = self._summary_key(keys.AUC)
        self._auc_pr_key = self._summary_key(keys.AUC_PR)
        self._loss_regularization_key = self._summary_key(
            keys.LOSS_REGULARIZATION)
        # Per-threshold keys, kept as three parallel tuples in threshold
        # order.
        accuracy_keys = []
        precision_keys = []
        recall_keys = []
        for threshold in self._thresholds:
            accuracy_keys.append(
                self._summary_key(keys.ACCURACY_AT_THRESHOLD % threshold))
            precision_keys.append(
                self._summary_key(keys.PRECISION_AT_THRESHOLD % threshold))
            recall_keys.append(
                self._summary_key(keys.RECALL_AT_THRESHOLD % threshold))
        self._accuracy_keys = tuple(accuracy_keys)
        self._precision_keys = tuple(precision_keys)
        self._recall_keys = tuple(recall_keys)
        # Per-class keys: keyed by class id when there is no vocabulary,
        # otherwise by the vocabulary entry at that id.
        prob_keys = []
        auc_keys = []
        auc_pr_keys = []
        for class_id in self._classes_for_class_based_metrics:
            if self._label_vocabulary is None:
                prob_key = keys.PROBABILITY_MEAN_AT_CLASS % class_id
                auc_key = keys.AUC_AT_CLASS % class_id
                auc_pr_key = keys.AUC_PR_AT_CLASS % class_id
            else:
                class_name = self._label_vocabulary[class_id]
                prob_key = keys.PROBABILITY_MEAN_AT_NAME % class_name
                auc_key = keys.AUC_AT_NAME % class_name
                auc_pr_key = keys.AUC_PR_AT_NAME % class_name
            prob_keys.append(self._summary_key(prob_key))
            auc_keys.append(self._summary_key(auc_key))
            auc_pr_keys.append(self._summary_key(auc_pr_key))
        self._prob_keys = tuple(prob_keys)
        self._auc_keys = tuple(auc_keys)
        self._auc_pr_keys = tuple(auc_pr_keys)