def __init__(self,
             label_dimension=1,
             weight_column=None,
             loss_reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
             loss_fn=None,
             inverse_link_fn=None,
             name=None):
    """Initializes the head.

    Validates the arguments, stores them on private attributes, and
    precomputes the summary keys used for the head's metrics.

    Raises:
      ValueError: If `label_dimension` is less than 1, or if
        `loss_reduction` / `loss_fn` fail `base_head` validation.
    """
    if label_dimension < 1:
        raise ValueError('Invalid label_dimension {}.'.format(label_dimension))
    base_head.validate_loss_reduction(loss_reduction)
    # A falsy loss_fn (i.e. None) simply skips validation.
    if loss_fn:
        base_head.validate_loss_fn_args(loss_fn)
    self._logits_dimension = label_dimension
    self._weight_column = weight_column
    self._loss_reduction = loss_reduction
    self._loss_fn = loss_fn
    self._inverse_link_fn = inverse_link_fn
    self._name = name
    # Precompute summary keys for the metrics this head reports.
    mkeys = metric_keys.MetricKeys
    self._loss_mean_key = self._summary_key(mkeys.LOSS_MEAN)
    self._prediction_mean_key = self._summary_key(mkeys.PREDICTION_MEAN)
    self._label_mean_key = self._summary_key(mkeys.LABEL_MEAN)
    self._loss_regularization_key = self._summary_key(mkeys.LOSS_REGULARIZATION)
def __init__(self,
             n_classes,
             weight_column=None,
             label_vocabulary=None,
             loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE,
             loss_fn=None,
             update_ops=None,
             name=None):
    """Initializes the head.

    Validates the arguments, stores them on private attributes, and
    precomputes the summary keys used for the head's metrics.

    Raises:
      ValueError: If `n_classes` is None or not greater than 2, if
        `label_vocabulary` is neither a list nor a tuple, or if
        `loss_reduction` / `loss_fn` / `update_ops` fail `base_head`
        validation.
    """
    # Multi-class requires strictly more than two classes.
    if n_classes is None or n_classes <= 2:
        raise ValueError('n_classes must be > 2: {}.'.format(n_classes))
    if label_vocabulary is not None and not isinstance(
            label_vocabulary, (list, tuple)):
        raise ValueError(
            'label_vocabulary should be a list or a tuple. Given type: {}'.format(
                type(label_vocabulary)))
    base_head.validate_loss_reduction(loss_reduction)
    if loss_fn:
        base_head.validate_loss_fn_args(loss_fn)
    base_head.validate_update_ops(update_ops)
    self._n_classes = n_classes
    self._weight_column = weight_column
    self._label_vocabulary = label_vocabulary
    self._loss_reduction = loss_reduction
    self._loss_fn = loss_fn
    self._update_ops = update_ops
    self._name = name
    # Precompute summary keys for the metrics this head reports.
    mkeys = metric_keys.MetricKeys
    self._loss_mean_key = self._summary_key(mkeys.LOSS_MEAN)
    self._accuracy_key = self._summary_key(mkeys.ACCURACY)
    self._loss_regularization_key = self._summary_key(mkeys.LOSS_REGULARIZATION)
def __init__(self,
             n_classes,
             weight_column=None,
             label_vocabulary=None,
             loss_reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
             loss_fn=None,
             name=None):
    """Initializes the head.

    Validates the arguments, stores them on private attributes, and
    precomputes the summary keys used for the head's metrics.

    Raises:
      ValueError: If `n_classes` is None, if `label_vocabulary` is neither
        a list nor a tuple or its length differs from `n_classes`, or if
        `loss_reduction` / `loss_fn` / `n_classes` fail `base_head`
        validation.
    """
    if n_classes is None:
        raise ValueError('n_classes cannot be None')
    # When a vocabulary is given it must be a sequence with exactly one
    # entry per class.
    if label_vocabulary is not None:
        if not isinstance(label_vocabulary, (list, tuple)):
            raise ValueError(
                'label_vocabulary should be a list or a tuple. Given type: {}'.format(
                    type(label_vocabulary)))
        if len(label_vocabulary) != n_classes:
            raise ValueError(
                '"label_vocabulary" does not have "n_classes" items. '
                'len(label_vocabulary)={}, n_classes={}, label_vocabulary={}'.format(
                    len(label_vocabulary), n_classes, label_vocabulary))
    base_head.validate_loss_reduction(loss_reduction)
    if loss_fn:
        base_head.validate_loss_fn_args(loss_fn)
    # validate_n_classes returns the validated value; store its result.
    self._n_classes = base_head.validate_n_classes(n_classes)
    self._weight_column = weight_column
    self._label_vocabulary = label_vocabulary
    self._loss_reduction = loss_reduction
    self._loss_fn = loss_fn
    self._name = name
    # Precompute summary keys for the metrics this head reports.
    mkeys = metric_keys.MetricKeys
    self._loss_mean_key = self._summary_key(mkeys.LOSS_MEAN)
    self._accuracy_key = self._summary_key(mkeys.ACCURACY)
    self._loss_regularization_key = self._summary_key(mkeys.LOSS_REGULARIZATION)
def __init__(self,
             weight_column=None,
             thresholds=None,
             label_vocabulary=None,
             loss_reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
             loss_fn=None,
             update_ops=None,
             name=None):
    """Initializes the head.

    Validates the arguments, stores them on private attributes, and
    precomputes the summary keys used for the head's metrics, including
    one accuracy/precision/recall key per requested threshold.

    Raises:
      ValueError: If `label_vocabulary` is neither a list nor a tuple, if
        any threshold falls outside the open interval (0, 1), or if
        `loss_reduction` / `loss_fn` / `update_ops` fail `base_head`
        validation.
    """
    if label_vocabulary is not None and not isinstance(
            label_vocabulary, (list, tuple)):
        raise ValueError(
            'label_vocabulary should be a list or a tuple. Given type: {}'.format(
                type(label_vocabulary)))
    # Normalize thresholds to a (possibly empty) tuple.
    thresholds = tuple(thresholds) if thresholds else tuple()
    # Every threshold must lie strictly inside (0, 1).
    if any(t <= 0.0 or t >= 1.0 for t in thresholds):
        raise ValueError('thresholds not in (0, 1): {}.'.format((thresholds,)))
    base_head.validate_loss_reduction(loss_reduction)
    if loss_fn:
        base_head.validate_loss_fn_args(loss_fn)
    base_head.validate_update_ops(update_ops)
    self._weight_column = weight_column
    self._thresholds = thresholds
    self._label_vocabulary = label_vocabulary
    self._loss_reduction = loss_reduction
    self._loss_fn = loss_fn
    self._update_ops = update_ops
    self._name = name
    # Precompute summary keys for the metrics this head reports.
    mkeys = metric_keys.MetricKeys
    self._loss_mean_key = self._summary_key(mkeys.LOSS_MEAN)
    self._accuracy_key = self._summary_key(mkeys.ACCURACY)
    self._precision_key = self._summary_key(mkeys.PRECISION)
    self._recall_key = self._summary_key(mkeys.RECALL)
    self._prediction_mean_key = self._summary_key(mkeys.PREDICTION_MEAN)
    self._label_mean_key = self._summary_key(mkeys.LABEL_MEAN)
    self._accuracy_baseline_key = self._summary_key(mkeys.ACCURACY_BASELINE)
    self._auc_key = self._summary_key(mkeys.AUC)
    self._auc_pr_key = self._summary_key(mkeys.AUC_PR)
    self._loss_regularization_key = self._summary_key(mkeys.LOSS_REGULARIZATION)
    # One accuracy/precision/recall key per configured threshold.
    self._accuracy_keys = tuple(
        self._summary_key(mkeys.ACCURACY_AT_THRESHOLD % t)
        for t in self._thresholds)
    self._precision_keys = tuple(
        self._summary_key(mkeys.PRECISION_AT_THRESHOLD % t)
        for t in self._thresholds)
    self._recall_keys = tuple(
        self._summary_key(mkeys.RECALL_AT_THRESHOLD % t)
        for t in self._thresholds)
def __init__(self,
             n_classes,
             weight_column=None,
             thresholds=None,
             label_vocabulary=None,
             loss_reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
             loss_fn=None,
             classes_for_class_based_metrics=None,
             name=None):
    """Initializes the multi-label head.

    Validates the arguments, stores them on private attributes, and
    precomputes the summary keys used for the head's metrics: per-threshold
    accuracy/precision/recall keys, and per-class probability/AUC/AUC-PR
    keys for each entry in `classes_for_class_based_metrics` (entries may
    be class ids, or class names when `label_vocabulary` is provided).

    Raises:
      ValueError: If `n_classes` is None or < 2; if any threshold is
        outside (0, 1); if `label_vocabulary` is neither a list nor a
        tuple, or its length differs from `n_classes`; if string
        `classes_for_class_based_metrics` are given without a vocabulary;
        if an integer class id is outside [0, n_classes - 1]; or if
        `loss_fn` / `loss_reduction` fail `base_head` validation.
    """
    if n_classes is None or n_classes < 2:
        raise ValueError(
            'n_classes must be > 1 for multi-label classification. '
            'Given: {}'.format(n_classes))
    # Normalize thresholds to a (possibly empty) tuple and range-check.
    thresholds = tuple(thresholds) if thresholds else tuple()
    for threshold in thresholds:
        if threshold <= 0.0 or threshold >= 1.0:
            raise ValueError(
                'thresholds must be in (0, 1) range. Given: {}'.format(threshold))
    if label_vocabulary is not None:
        if not isinstance(label_vocabulary, (list, tuple)):
            raise ValueError('label_vocabulary must be a list or tuple. '
                             'Given type: {}'.format(type(label_vocabulary)))
        if len(label_vocabulary) != n_classes:
            raise ValueError(
                'Length of label_vocabulary must be n_classes ({}). '
                'Given: {}'.format(n_classes, len(label_vocabulary)))
    if loss_fn:
        base_head.validate_loss_fn_args(loss_fn)
    base_head.validate_loss_reduction(loss_reduction)
    if classes_for_class_based_metrics:
        classes_for_class_based_metrics = tuple(classes_for_class_based_metrics)
        if isinstance(classes_for_class_based_metrics[0], six.string_types):
            if not label_vocabulary:
                # Fixed typo in the error message ('sting' -> 'string').
                raise ValueError(
                    'label_vocabulary must be provided when '
                    'classes_for_class_based_metrics are string.')
            # Map class names to integer ids via the vocabulary; a missing
            # name raises ValueError from list.index, same as before.
            classes_for_class_based_metrics = tuple(
                label_vocabulary.index(class_string)
                for class_string in classes_for_class_based_metrics)
        else:
            for class_id in classes_for_class_based_metrics:
                if class_id < 0 or class_id >= n_classes:
                    raise ValueError(
                        'All classes_for_class_based_metrics must be in range '
                        '[0, {}]. Given: {}'.format(n_classes - 1, class_id))
    else:
        classes_for_class_based_metrics = tuple()
    self._n_classes = n_classes
    self._weight_column = weight_column
    self._thresholds = thresholds
    self._label_vocabulary = label_vocabulary
    self._loss_reduction = loss_reduction
    self._loss_fn = loss_fn
    self._classes_for_class_based_metrics = classes_for_class_based_metrics
    self._name = name
    # Metric keys.
    keys = metric_keys.MetricKeys
    self._loss_mean_key = self._summary_key(keys.LOSS_MEAN)
    self._auc_key = self._summary_key(keys.AUC)
    self._auc_pr_key = self._summary_key(keys.AUC_PR)
    self._loss_regularization_key = self._summary_key(keys.LOSS_REGULARIZATION)
    # One accuracy/precision/recall key per configured threshold.
    self._accuracy_keys = tuple(
        self._summary_key(keys.ACCURACY_AT_THRESHOLD % threshold)
        for threshold in self._thresholds)
    self._precision_keys = tuple(
        self._summary_key(keys.PRECISION_AT_THRESHOLD % threshold)
        for threshold in self._thresholds)
    self._recall_keys = tuple(
        self._summary_key(keys.RECALL_AT_THRESHOLD % threshold)
        for threshold in self._thresholds)
    # Per-class metric keys: use id-based key patterns without a
    # vocabulary, name-based patterns with one.
    prob_keys = []
    auc_keys = []
    auc_pr_keys = []
    for class_id in self._classes_for_class_based_metrics:
        if self._label_vocabulary is None:
            prob_key = keys.PROBABILITY_MEAN_AT_CLASS % class_id
            auc_key = keys.AUC_AT_CLASS % class_id
            auc_pr_key = keys.AUC_PR_AT_CLASS % class_id
        else:
            class_name = self._label_vocabulary[class_id]
            prob_key = keys.PROBABILITY_MEAN_AT_NAME % class_name
            auc_key = keys.AUC_AT_NAME % class_name
            auc_pr_key = keys.AUC_PR_AT_NAME % class_name
        prob_keys.append(self._summary_key(prob_key))
        auc_keys.append(self._summary_key(auc_key))
        auc_pr_keys.append(self._summary_key(auc_pr_key))
    self._prob_keys = tuple(prob_keys)
    self._auc_keys = tuple(auc_keys)
    self._auc_pr_keys = tuple(auc_pr_keys)