def multi_label_head(n_classes,
                     weight_column=None,
                     thresholds=None,
                     label_vocabulary=None,
                     loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE,
                     loss_fn=None,
                     classes_for_class_based_metrics=None,
                     name=None):
  """Creates a `_Head` for multi-label classification.

  Multi-label classification handles the case where each example may have zero
  or more associated labels, from a discrete set. This is distinct from
  `multi_class_head` which has exactly one label per example.

  Uses `sigmoid_cross_entropy` loss average over classes and weighted sum over
  the batch. Namely, if the input logits have shape `[batch_size, n_classes]`,
  the loss is the average over `n_classes` and the weighted sum over
  `batch_size`.

  The head expects `logits` with shape `[D0, D1, ... DN, n_classes]`. In many
  applications, the shape is `[batch_size, n_classes]`.

  Labels can be:

  * A multi-hot tensor of shape `[D0, D1, ... DN, n_classes]`
  * An integer `SparseTensor` of class indices. The `dense_shape` must be
    `[D0, D1, ... DN, ?]` and the values within `[0, n_classes)`.
  * If `label_vocabulary` is given, a string `SparseTensor`. The `dense_shape`
    must be `[D0, D1, ... DN, ?]` and the values within `label_vocabulary`.

  If `weight_column` is specified, weights must be of shape
  `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`.

  Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
  `(labels, logits, features)` as arguments and returns unreduced loss with
  shape `[D0, D1, ... DN, 1]`. `loss_fn` must support indicator `labels` with
  shape `[D0, D1, ... DN, n_classes]`. Namely, the head applies
  `label_vocabulary` to the input labels before passing them to `loss_fn`.

  The head can be used with a canned estimator. Example:

  ```python
  my_head = tf.contrib.estimator.multi_label_head(n_classes=3)
  my_estimator = tf.contrib.estimator.DNNEstimator(
      head=my_head,
      hidden_units=...,
      feature_columns=...)
  ```

  It can also be used with a custom `model_fn`. Example:

  ```python
  def _my_model_fn(features, labels, mode):
    my_head = tf.contrib.estimator.multi_label_head(n_classes=3)
    logits = tf.keras.Model(...)(features)

    return my_head.create_estimator_spec(
        features=features,
        mode=mode,
        labels=labels,
        optimizer=tf.AdagradOptimizer(learning_rate=0.1),
        logits=logits)

  my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn)
  ```

  Args:
    n_classes: Number of classes, must be greater than 1 (for 1 class, use
      `binary_classification_head`).
    weight_column: A string or a `_NumericColumn` created by
      `tf.feature_column.numeric_column` defining feature column representing
      weights. It is used to down weight or boost examples during training. It
      will be multiplied by the loss of the example. Per-class weighting is
      not supported.
    thresholds: Iterable of floats in the range `(0, 1)`. Accuracy, precision
      and recall metrics are evaluated for each threshold value. The threshold
      is applied to the predicted probabilities, i.e. above the threshold is
      `true`, below is `false`.
    label_vocabulary: A list (or tuple) of strings representing possible label
      values. If it is not given, that means labels are already encoded as
      integer within [0, n_classes) or multi-hot Tensor. If given, labels must
      be SparseTensor string type and have any value in `label_vocabulary`.
      Also there will be errors if vocabulary is not provided and labels are
      string. Its length must equal `n_classes`.
    loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
      to reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`,
      namely weighted sum of losses divided by batch size. See
      `tf.losses.Reduction`.
    loss_fn: Optional loss function.
    classes_for_class_based_metrics: List of integer class IDs or string class
      names for which per-class metrics are evaluated. If integers, all must
      be in the range `[0, n_classes - 1]`. If strings, all must be in
      `label_vocabulary`. The type of the first element decides how the whole
      list is interpreted. String names are converted to their index in
      `label_vocabulary` before being handed to the head.
    name: name of the head. If provided, summary and metrics keys will be
      suffixed by `"/" + name`. Also used as `name_scope` when creating ops.

  Returns:
    An instance of `_Head` for multi-label classification.

  Raises:
    ValueError: if `n_classes`, `thresholds`, `label_vocabulary`,
      `loss_reduction`, `loss_fn` or `classes_for_class_based_metrics` is
      invalid.
  """
  thresholds = tuple(thresholds) if thresholds else tuple()
  if n_classes is None or n_classes < 2:
    raise ValueError(
        'n_classes must be > 1 for multi-class classification. '
        'Given: {}'.format(n_classes))
  for threshold in thresholds:
    # Both endpoints are excluded: thresholds live in the open interval (0, 1).
    if (threshold <= 0.0) or (threshold >= 1.0):
      raise ValueError(
          'thresholds must be in (0, 1) range. Given: {}'.format(threshold))
  if label_vocabulary is not None:
    if not isinstance(label_vocabulary, (list, tuple)):
      raise ValueError(
          'label_vocabulary must be a list or tuple. '
          'Given type: {}'.format(type(label_vocabulary)))
    if len(label_vocabulary) != n_classes:
      raise ValueError(
          'Length of label_vocabulary must be n_classes ({}). '
          'Given: {}'.format(n_classes, len(label_vocabulary)))
  if loss_fn:
    head_lib._validate_loss_fn_args(loss_fn)  # pylint:disable=protected-access
  if (loss_reduction not in losses.Reduction.all() or
      loss_reduction == losses.Reduction.NONE):
    raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction))

  # Normalize to a tuple so that `None` and empty inputs are handled uniformly.
  classes_for_class_based_metrics = tuple(
      [] if classes_for_class_based_metrics is None
      else classes_for_class_based_metrics)
  if classes_for_class_based_metrics:
    # The first element's type selects between class names and class IDs.
    if isinstance(classes_for_class_based_metrics[0], six.string_types):
      if not label_vocabulary:
        # NOTE(review): message text (including the "sting" typo) is kept
        # verbatim in case callers or tests match on it.
        raise ValueError(
            'label_vocabulary must be provided when '
            'classes_for_class_based_metrics are sting.')
      class_ids = []
      for class_string in classes_for_class_based_metrics:
        try:
          class_id = label_vocabulary.index(class_string)
        except ValueError:
          # The sequence `index` error ("x is not in list") does not say which
          # argument is wrong; report it in terms of this function's contract.
          raise ValueError(
              'All classes_for_class_based_metrics must be in '
              'label_vocabulary. Given: {}'.format(class_string))
        class_ids.append(class_id)
      classes_for_class_based_metrics = tuple(class_ids)
    else:
      for class_id in classes_for_class_based_metrics:
        if (class_id < 0) or (class_id >= n_classes):
          raise ValueError(
              'All classes_for_class_based_metrics must be in range [0, {}]. '
              'Given: {}'.format(n_classes - 1, class_id))
  return _MultiLabelHead(
      n_classes=n_classes,
      weight_column=weight_column,
      thresholds=thresholds,
      label_vocabulary=label_vocabulary,
      loss_reduction=loss_reduction,
      loss_fn=loss_fn,
      classes_for_class_based_metrics=classes_for_class_based_metrics,
      name=name)
def multi_label_head(n_classes,
                     weight_column=None,
                     thresholds=None,
                     label_vocabulary=None,
                     loss_reduction=losses.Reduction.SUM,
                     loss_fn=None,
                     name=None):
  """Builds a `_Head` for multi-label classification.

  In multi-label classification each example may carry zero or more labels
  taken from a discrete set. This differs from `multi_class_head`, where every
  example has exactly one label.

  The loss is `sigmoid_cross_entropy`, averaged over classes and
  weighted-summed over the batch. That is, for logits of shape
  `[batch_size, n_classes]` the loss is the average over `n_classes` and the
  weighted sum over `batch_size`.

  `logits` are expected to have shape `[D0, D1, ... DN, n_classes]`; in many
  applications this is simply `[batch_size, n_classes]`.

  Accepted label formats:

  * A multi-hot tensor of shape `[D0, D1, ... DN, n_classes]`.
  * An integer `SparseTensor` of class indices, with `dense_shape`
    `[D0, D1, ... DN, ?]` and values within `[0, n_classes)`.
  * When `label_vocabulary` is given, a string `SparseTensor` with
    `dense_shape` `[D0, D1, ... DN, ?]` and values within `label_vocabulary`.

  When `weight_column` is specified, weights must have shape
  `[D0, D1, ... DN]` or `[D0, D1, ... DN, 1]`.

  A custom `loss_fn` is also supported. It takes `(labels, logits)` or
  `(labels, logits, features)` and returns the unreduced loss with shape
  `[D0, D1, ... DN, 1]`. It must accept indicator `labels` of shape
  `[D0, D1, ... DN, n_classes]`, because the head applies `label_vocabulary`
  to the input labels before handing them to `loss_fn`.

  Args:
    n_classes: Number of classes; must be greater than 1 (for a single class
      use `binary_classification_head`).
    weight_column: A string or a `_NumericColumn` created by
      `tf.feature_column.numeric_column` defining the feature column that
      holds the weights. Used to down weight or boost examples during
      training; it is multiplied by the loss of the example. Per-class
      weighting is not supported.
    thresholds: Iterable of floats in the range `(0, 1)`. Accuracy, precision
      and recall metrics are evaluated for each threshold value. The threshold
      is applied to the predicted probabilities, i.e. above the threshold is
      `true`, below is `false`.
    label_vocabulary: A list of strings giving the possible label values. When
      omitted, labels are assumed to be already encoded as integers within
      [0, n_classes) or as a multi-hot Tensor. When given, labels must be a
      string SparseTensor whose values are in `label_vocabulary`. String
      labels without a vocabulary result in errors.
    loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
      to reduce training loss over batch. Defaults to `SUM`.
    loss_fn: Optional loss function.
    name: Name of the head. If provided, summary and metrics keys will be
      suffixed by `"/" + name`. Also used as `name_scope` when creating ops.

  Returns:
    An instance of `_Head` for multi-label classification.

  Raises:
    ValueError: if `n_classes`, `thresholds`, `loss_reduction` or `loss_fn` is
      invalid.
  """
  # Freeze the thresholds into a tuple; any falsy input (None, empty) maps to
  # the empty tuple.
  if thresholds:
    thresholds = tuple(thresholds)
  else:
    thresholds = tuple()

  if n_classes is None or n_classes < 2:
    n_classes_message = (
        'n_classes must be > 1 for multi-class classification. Given: {}')
    raise ValueError(n_classes_message.format(n_classes))

  # Each cutoff must lie strictly inside (0, 1).
  for cutoff in thresholds:
    if cutoff <= 0.0 or cutoff >= 1.0:
      threshold_message = 'thresholds must be in (0, 1) range. Given: {}'
      raise ValueError(threshold_message.format(cutoff))

  if label_vocabulary is not None:
    if not isinstance(label_vocabulary, (list, tuple)):
      vocabulary_type = type(label_vocabulary)
      raise ValueError(
          'label_vocabulary must be a list or tuple. Given type: {}'.format(
              vocabulary_type))
    vocabulary_size = len(label_vocabulary)
    if vocabulary_size != n_classes:
      raise ValueError(
          'Length of label_vocabulary must be n_classes ({}). Given: {}'
          .format(n_classes, vocabulary_size))

  if loss_fn:
    head_lib._validate_loss_fn_args(loss_fn)  # pylint:disable=protected-access

  # `NONE` is a member of `Reduction.all()` but is not accepted here.
  reduction_is_known = loss_reduction in losses.Reduction.all()
  if not reduction_is_known or loss_reduction == losses.Reduction.NONE:
    raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction))

  return _MultiLabelHead(
      n_classes=n_classes,
      weight_column=weight_column,
      thresholds=thresholds,
      label_vocabulary=label_vocabulary,
      loss_reduction=loss_reduction,
      loss_fn=loss_fn,
      name=name)
def multi_label_head(n_classes,
                     weight_column=None,
                     thresholds=None,
                     label_vocabulary=None,
                     loss_reduction=losses.Reduction.SUM,
                     loss_fn=None,
                     name=None):
  """Builds a `_Head` for multi-label classification.

  In multi-label classification each example may carry zero or more labels
  taken from a discrete set. This differs from `multi_class_head`, where every
  example has exactly one label.

  The loss is `sigmoid_cross_entropy`, averaged over classes and
  weighted-summed over the batch. That is, for logits of shape
  `[batch_size, n_classes]` the loss is the average over `n_classes` and the
  weighted sum over `batch_size`.

  `logits` are expected to have shape `[D0, D1, ... DN, n_classes]`; in many
  applications this is simply `[batch_size, n_classes]`.

  Accepted label formats:

  * A multi-hot tensor of shape `[D0, D1, ... DN, n_classes]`.
  * An integer `SparseTensor` of class indices, with `dense_shape`
    `[D0, D1, ... DN, ?]` and values within `[0, n_classes)`.
  * When `label_vocabulary` is given, a string `SparseTensor` with
    `dense_shape` `[D0, D1, ... DN, ?]` and values within `label_vocabulary`.

  When `weight_column` is specified, weights must have shape
  `[D0, D1, ... DN]` or `[D0, D1, ... DN, 1]`.

  A custom `loss_fn` is also supported. It takes `(labels, logits)` or
  `(labels, logits, features)` and returns the unreduced loss with shape
  `[D0, D1, ... DN, 1]`. It must accept indicator `labels` of shape
  `[D0, D1, ... DN, n_classes]`, because the head applies `label_vocabulary`
  to the input labels before handing them to `loss_fn`.

  Args:
    n_classes: Number of classes; must be greater than 1 (for a single class
      use `binary_classification_head`).
    weight_column: A string or a `_NumericColumn` created by
      `tf.feature_column.numeric_column` defining the feature column that
      holds the weights. Used to down weight or boost examples during
      training; it is multiplied by the loss of the example. Per-class
      weighting is not supported.
    thresholds: Iterable of floats in the range `(0, 1)`. Accuracy, precision
      and recall metrics are evaluated for each threshold value. The threshold
      is applied to the predicted probabilities, i.e. above the threshold is
      `true`, below is `false`.
    label_vocabulary: A list of strings giving the possible label values. When
      omitted, labels are assumed to be already encoded as integers within
      [0, n_classes) or as a multi-hot Tensor. When given, labels must be a
      string SparseTensor whose values are in `label_vocabulary`. String
      labels without a vocabulary result in errors.
    loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
      to reduce training loss over batch. Defaults to `SUM`.
    loss_fn: Optional loss function.
    name: Name of the head. If provided, summary and metrics keys will be
      suffixed by `"/" + name`. Also used as `name_scope` when creating ops.

  Returns:
    An instance of `_Head` for multi-label classification.

  Raises:
    ValueError: if `n_classes`, `thresholds`, `loss_reduction` or `loss_fn` is
      invalid.
  """
  # Freeze the thresholds into a tuple; any falsy input (None, empty) maps to
  # the empty tuple.
  if thresholds:
    thresholds = tuple(thresholds)
  else:
    thresholds = tuple()

  if n_classes is None or n_classes < 2:
    n_classes_message = (
        'n_classes must be > 1 for multi-class classification. Given: {}')
    raise ValueError(n_classes_message.format(n_classes))

  # Each cutoff must lie strictly inside (0, 1).
  for cutoff in thresholds:
    if cutoff <= 0.0 or cutoff >= 1.0:
      threshold_message = 'thresholds must be in (0, 1) range. Given: {}'
      raise ValueError(threshold_message.format(cutoff))

  if label_vocabulary is not None:
    if not isinstance(label_vocabulary, (list, tuple)):
      vocabulary_type = type(label_vocabulary)
      raise ValueError(
          'label_vocabulary must be a list or tuple. Given type: {}'.format(
              vocabulary_type))
    vocabulary_size = len(label_vocabulary)
    if vocabulary_size != n_classes:
      raise ValueError(
          'Length of label_vocabulary must be n_classes ({}). Given: {}'
          .format(n_classes, vocabulary_size))

  if loss_fn:
    head_lib._validate_loss_fn_args(loss_fn)  # pylint:disable=protected-access

  # `NONE` is a member of `Reduction.all()` but is not accepted here.
  reduction_is_known = loss_reduction in losses.Reduction.all()
  if not reduction_is_known or loss_reduction == losses.Reduction.NONE:
    raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction))

  return _MultiLabelHead(
      n_classes=n_classes,
      weight_column=weight_column,
      thresholds=thresholds,
      label_vocabulary=label_vocabulary,
      loss_reduction=loss_reduction,
      loss_fn=loss_fn,
      name=name)
def multi_label_head(n_classes,
                     weight_column=None,
                     thresholds=None,
                     label_vocabulary=None,
                     loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE,
                     loss_fn=None,
                     classes_for_class_based_metrics=None,
                     name=None):
  """Creates a `_Head` for multi-label classification.

  Multi-label classification handles the case where each example may have zero
  or more associated labels, from a discrete set. This is distinct from
  `multi_class_head` which has exactly one label per example.

  Uses `sigmoid_cross_entropy` loss average over classes and weighted sum over
  the batch. Namely, if the input logits have shape `[batch_size, n_classes]`,
  the loss is the average over `n_classes` and the weighted sum over
  `batch_size`.

  The head expects `logits` with shape `[D0, D1, ... DN, n_classes]`. In many
  applications, the shape is `[batch_size, n_classes]`.

  Labels can be:

  * A multi-hot tensor of shape `[D0, D1, ... DN, n_classes]`
  * An integer `SparseTensor` of class indices. The `dense_shape` must be
    `[D0, D1, ... DN, ?]` and the values within `[0, n_classes)`.
  * If `label_vocabulary` is given, a string `SparseTensor`. The `dense_shape`
    must be `[D0, D1, ... DN, ?]` and the values within `label_vocabulary`.

  If `weight_column` is specified, weights must be of shape
  `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`.

  Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
  `(labels, logits, features)` as arguments and returns unreduced loss with
  shape `[D0, D1, ... DN, 1]`. `loss_fn` must support indicator `labels` with
  shape `[D0, D1, ... DN, n_classes]`. Namely, the head applies
  `label_vocabulary` to the input labels before passing them to `loss_fn`.

  The head can be used with a canned estimator. Example:

  ```python
  my_head = tf.contrib.estimator.multi_label_head(n_classes=3)
  my_estimator = tf.contrib.estimator.DNNEstimator(
      head=my_head,
      hidden_units=...,
      feature_columns=...)
  ```

  It can also be used with a custom `model_fn`. Example:

  ```python
  def _my_model_fn(features, labels, mode):
    my_head = tf.contrib.estimator.multi_label_head(n_classes=3)
    logits = tf.keras.Model(...)(features)

    return my_head.create_estimator_spec(
        features=features,
        mode=mode,
        labels=labels,
        optimizer=tf.AdagradOptimizer(learning_rate=0.1),
        logits=logits)

  my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn)
  ```

  Args:
    n_classes: Number of classes, must be greater than 1 (for 1 class, use
      `binary_classification_head`).
    weight_column: A string or a `_NumericColumn` created by
      `tf.feature_column.numeric_column` defining feature column representing
      weights. It is used to down weight or boost examples during training. It
      will be multiplied by the loss of the example. Per-class weighting is
      not supported.
    thresholds: Iterable of floats in the range `(0, 1)`. Accuracy, precision
      and recall metrics are evaluated for each threshold value. The threshold
      is applied to the predicted probabilities, i.e. above the threshold is
      `true`, below is `false`.
    label_vocabulary: A list (or tuple) of strings representing possible label
      values. If it is not given, that means labels are already encoded as
      integer within [0, n_classes) or multi-hot Tensor. If given, labels must
      be SparseTensor string type and have any value in `label_vocabulary`.
      Also there will be errors if vocabulary is not provided and labels are
      string. Its length must equal `n_classes`.
    loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
      to reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`,
      namely weighted sum of losses divided by batch size. See
      `tf.losses.Reduction`.
    loss_fn: Optional loss function.
    classes_for_class_based_metrics: List of integer class IDs or string class
      names for which per-class metrics are evaluated. If integers, all must
      be in the range `[0, n_classes - 1]`. If strings, all must be in
      `label_vocabulary`. The type of the first element decides how the whole
      list is interpreted. String names are converted to their index in
      `label_vocabulary` before being handed to the head.
    name: name of the head. If provided, summary and metrics keys will be
      suffixed by `"/" + name`. Also used as `name_scope` when creating ops.

  Returns:
    An instance of `_Head` for multi-label classification.

  Raises:
    ValueError: if `n_classes`, `thresholds`, `label_vocabulary`,
      `loss_reduction`, `loss_fn` or `classes_for_class_based_metrics` is
      invalid.
  """
  thresholds = tuple(thresholds) if thresholds else tuple()
  if n_classes is None or n_classes < 2:
    raise ValueError(
        'n_classes must be > 1 for multi-class classification. '
        'Given: {}'.format(n_classes))
  for threshold in thresholds:
    # Both endpoints are excluded: thresholds live in the open interval (0, 1).
    if (threshold <= 0.0) or (threshold >= 1.0):
      raise ValueError(
          'thresholds must be in (0, 1) range. Given: {}'.format(threshold))
  if label_vocabulary is not None:
    if not isinstance(label_vocabulary, (list, tuple)):
      raise ValueError(
          'label_vocabulary must be a list or tuple. '
          'Given type: {}'.format(type(label_vocabulary)))
    if len(label_vocabulary) != n_classes:
      raise ValueError(
          'Length of label_vocabulary must be n_classes ({}). '
          'Given: {}'.format(n_classes, len(label_vocabulary)))
  if loss_fn:
    head_lib._validate_loss_fn_args(loss_fn)  # pylint:disable=protected-access
  if (loss_reduction not in losses.Reduction.all() or
      loss_reduction == losses.Reduction.NONE):
    raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction))

  # Normalize to a tuple so that `None` and empty inputs are handled uniformly.
  classes_for_class_based_metrics = tuple(
      [] if classes_for_class_based_metrics is None
      else classes_for_class_based_metrics)
  if classes_for_class_based_metrics:
    # The first element's type selects between class names and class IDs.
    if isinstance(classes_for_class_based_metrics[0], six.string_types):
      if not label_vocabulary:
        # NOTE(review): message text (including the "sting" typo) is kept
        # verbatim in case callers or tests match on it.
        raise ValueError(
            'label_vocabulary must be provided when '
            'classes_for_class_based_metrics are sting.')
      class_ids = []
      for class_string in classes_for_class_based_metrics:
        try:
          class_id = label_vocabulary.index(class_string)
        except ValueError:
          # The sequence `index` error ("x is not in list") does not say which
          # argument is wrong; report it in terms of this function's contract.
          raise ValueError(
              'All classes_for_class_based_metrics must be in '
              'label_vocabulary. Given: {}'.format(class_string))
        class_ids.append(class_id)
      classes_for_class_based_metrics = tuple(class_ids)
    else:
      for class_id in classes_for_class_based_metrics:
        if (class_id < 0) or (class_id >= n_classes):
          raise ValueError(
              'All classes_for_class_based_metrics must be in range [0, {}]. '
              'Given: {}'.format(n_classes - 1, class_id))
  return _MultiLabelHead(
      n_classes=n_classes,
      weight_column=weight_column,
      thresholds=thresholds,
      label_vocabulary=label_vocabulary,
      loss_reduction=loss_reduction,
      loss_fn=loss_fn,
      classes_for_class_based_metrics=classes_for_class_based_metrics,
      name=name)