Example #1
File: head.py Project: didukhle/tensorflow
def multi_label_head(n_classes,
                     weight_column=None,
                     thresholds=None,
                     label_vocabulary=None,
                     loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE,
                     loss_fn=None,
                     classes_for_class_based_metrics=None,
                     name=None):
  """Creates a `_Head` for multi-label classification.

  Multi-label classification handles the case where each example may have zero
  or more associated labels, from a discrete set. This is distinct from
  `multi_class_head` which has exactly one label per example.

  Uses `sigmoid_cross_entropy` loss averaged over classes and a weighted sum over
  the batch. Namely, if the input logits have shape `[batch_size, n_classes]`,
  the loss is the average over `n_classes` and the weighted sum over
  `batch_size`.

  The head expects `logits` with shape `[D0, D1, ... DN, n_classes]`. In many
  applications, the shape is `[batch_size, n_classes]`.

  Labels can be any of the following (a short sketch of each format appears
  after the list):
  * A multi-hot tensor of shape `[D0, D1, ... DN, n_classes]`
  * An integer `SparseTensor` of class indices. The `dense_shape` must be
    `[D0, D1, ... DN, ?]` and the values within `[0, n_classes)`.
  * If `label_vocabulary` is given, a string `SparseTensor`. The `dense_shape`
    must be `[D0, D1, ... DN, ?]` and the values within `label_vocabulary`.
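
  For concreteness, here is a sketch of each label format, assuming
  `n_classes=3` and a hypothetical vocabulary `['cat', 'dog', 'bird']`
  (the values are illustrative):

  ```python
  # Multi-hot: example 0 has classes 0 and 2; example 1 has class 1 only.
  multi_hot_labels = tf.constant([[1, 0, 1],
                                  [0, 1, 0]], dtype=tf.int64)

  # Integer SparseTensor of class indices; dense_shape is [batch_size, ?].
  sparse_labels = tf.SparseTensor(
      indices=[[0, 0], [0, 1], [1, 0]],
      values=[0, 2, 1],
      dense_shape=[2, 2])

  # String SparseTensor; valid only when label_vocabulary is given.
  string_labels = tf.SparseTensor(
      indices=[[0, 0], [0, 1], [1, 0]],
      values=['cat', 'bird', 'dog'],
      dense_shape=[2, 2])
  ```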

  If `weight_column` is specified, weights must be of shape
  `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`.
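
  For instance, per-example weights can be read from a feature; the feature
  key `'example_weight'` below is hypothetical:

  ```python
  my_head = tf.contrib.estimator.multi_label_head(
      n_classes=3, weight_column='example_weight')
  ```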

  Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
  `(labels, logits, features)` as arguments and returns unreduced loss with
  shape `[D0, D1, ... DN, 1]`. `loss_fn` must support indicator `labels` with
  shape `[D0, D1, ... DN, n_classes]`. Namely, the head applies
  `label_vocabulary` to the input labels before passing them to `loss_fn`.
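
  For example, a `loss_fn` satisfying this contract might look like the
  following sketch (the mean-over-classes reduction here is illustrative,
  not required):

  ```python
  def _my_loss_fn(labels, logits):
    # labels arrive as indicator (multi-hot) tensors of shape
    # [batch_size, n_classes]; label_vocabulary has already been applied.
    loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.to_float(labels), logits=logits)
    # Return unreduced loss with shape [batch_size, 1].
    return tf.reduce_mean(loss, axis=-1, keepdims=True)

  my_head = tf.contrib.estimator.multi_label_head(
      n_classes=3, loss_fn=_my_loss_fn)
  ```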

  The head can be used with a canned estimator. Example:

  ```python
  my_head = tf.contrib.estimator.multi_label_head(n_classes=3)
  my_estimator = tf.contrib.estimator.DNNEstimator(
      head=my_head,
      hidden_units=...,
      feature_columns=...)
  ```
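
  Training and evaluation then proceed as with any canned estimator; a
  sketch, assuming suitable `train_input_fn` and `eval_input_fn`
  (hypothetical names):

  ```python
  my_estimator.train(input_fn=train_input_fn, steps=1000)
  metrics = my_estimator.evaluate(input_fn=eval_input_fn)
  ```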

  It can also be used with a custom `model_fn`. Example:

  ```python
  def _my_model_fn(features, labels, mode):
    my_head = tf.contrib.estimator.multi_label_head(n_classes=3)
    logits = tf.keras.Model(...)(features)

    return my_head.create_estimator_spec(
        features=features,
        mode=mode,
        labels=labels,
        optimizer=tf.train.AdagradOptimizer(learning_rate=0.1),
        logits=logits)

  my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn)
  ```
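
  Per-class metrics can additionally be requested through
  `classes_for_class_based_metrics`; a minimal sketch with illustrative
  class IDs:

  ```python
  my_head = tf.contrib.estimator.multi_label_head(
      n_classes=3, classes_for_class_based_metrics=[0, 2])
  ```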

  Args:
    n_classes: Number of classes, must be greater than 1 (for 1 class, use
      `binary_classification_head`).
    weight_column: A string or a `_NumericColumn` created by
      `tf.feature_column.numeric_column` defining feature column representing
      weights. It is used to down weight or boost examples during training. It
      will be multiplied by the loss of the example.  Per-class weighting is
      not supported.
    thresholds: Iterable of floats in the range `(0, 1)`. Accuracy, precision
      and recall metrics are evaluated for each threshold value. The threshold
      is applied to the predicted probabilities, i.e. above the threshold is
      `true`, below is `false`.
    label_vocabulary: A list of strings representing possible label values. If
      it is not given, labels must already be encoded as integers within
      [0, n_classes) or as a multi-hot Tensor. If given, labels must be a
      string `SparseTensor` whose values are contained in `label_vocabulary`.
      String labels without a vocabulary raise an error.
    loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
      reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`, namely
      weighted sum of losses divided by batch size. See `tf.losses.Reduction`.
    loss_fn: Optional loss function.
    classes_for_class_based_metrics: List of integer class IDs or string class
      names for which per-class metrics are evaluated. If integers, all must be
      in the range `[0, n_classes - 1]`. If strings, all must be in
      `label_vocabulary`.
    name: name of the head. If provided, summary and metrics keys will be
      suffixed by `"/" + name`. Also used as `name_scope` when creating ops.

  Returns:
    An instance of `_Head` for multi-label classification.

  Raises:
    ValueError: if `n_classes`, `thresholds`, `loss_reduction`, `loss_fn` or
    `classes_for_class_based_metrics` is invalid.
  """
  thresholds = tuple(thresholds) if thresholds else tuple()
  if n_classes is None or n_classes < 2:
    raise ValueError(
        'n_classes must be > 1 for multi-label classification. '
        'Given: {}'.format(n_classes))
  for threshold in thresholds:
    if (threshold <= 0.0) or (threshold >= 1.0):
      raise ValueError(
          'thresholds must be in (0, 1) range. Given: {}'.format(threshold))
  if label_vocabulary is not None:
    if not isinstance(label_vocabulary, (list, tuple)):
      raise ValueError(
          'label_vocabulary must be a list or tuple. '
          'Given type: {}'.format(type(label_vocabulary)))
    if len(label_vocabulary) != n_classes:
      raise ValueError(
          'Length of label_vocabulary must be n_classes ({}). '
          'Given: {}'.format(n_classes, len(label_vocabulary)))
  if loss_fn:
    head_lib._validate_loss_fn_args(loss_fn)  # pylint:disable=protected-access
  if (loss_reduction not in losses.Reduction.all() or
      loss_reduction == losses.Reduction.NONE):
    raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction))
  classes_for_class_based_metrics = tuple(
      [] if classes_for_class_based_metrics is None
      else classes_for_class_based_metrics)
  if classes_for_class_based_metrics:
    if isinstance(classes_for_class_based_metrics[0], six.string_types):
      if not label_vocabulary:
        raise ValueError(
            'label_vocabulary must be provided when '
            'classes_for_class_based_metrics are strings.')
      class_ids = []
      for class_string in classes_for_class_based_metrics:
        class_ids.append(label_vocabulary.index(class_string))
      classes_for_class_based_metrics = tuple(class_ids)
    else:
      for class_id in classes_for_class_based_metrics:
        if (class_id < 0) or (class_id >= n_classes):
          raise ValueError(
              'All classes_for_class_based_metrics must be in range [0, {}]. '
              'Given: {}'.format(n_classes - 1, class_id))
  return _MultiLabelHead(
      n_classes=n_classes, weight_column=weight_column, thresholds=thresholds,
      label_vocabulary=label_vocabulary, loss_reduction=loss_reduction,
      loss_fn=loss_fn,
      classes_for_class_based_metrics=classes_for_class_based_metrics,
      name=name)
Example #2
def multi_label_head(n_classes,
                     weight_column=None,
                     thresholds=None,
                     label_vocabulary=None,
                     loss_reduction=losses.Reduction.SUM,
                     loss_fn=None,
                     name=None):
  """Creates a `_Head` for multi-label classification.

  Multi-label classification handles the case where each example may have zero
  or more associated labels, from a discrete set. This is distinct from
  `multi_class_head` which has exactly one label per example.

  Uses `sigmoid_cross_entropy` loss averaged over classes and a weighted sum over
  the batch. Namely, if the input logits have shape `[batch_size, n_classes]`,
  the loss is the average over `n_classes` and the weighted sum over
  `batch_size`.

  The head expects `logits` with shape `[D0, D1, ... DN, n_classes]`. In many
  applications, the shape is `[batch_size, n_classes]`.

  Labels can be:
  * A multi-hot tensor of shape `[D0, D1, ... DN, n_classes]`
  * An integer `SparseTensor` of class indices. The `dense_shape` must be
    `[D0, D1, ... DN, ?]` and the values within `[0, n_classes)`.
  * If `label_vocabulary` is given, a string `SparseTensor`. The `dense_shape`
    must be `[D0, D1, ... DN, ?]` and the values within `label_vocabulary`.

  If `weight_column` is specified, weights must be of shape
  `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`.

  Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
  `(labels, logits, features)` as arguments and returns unreduced loss with
  shape `[D0, D1, ... DN, 1]`. `loss_fn` must support indicator `labels` with
  shape `[D0, D1, ... DN, n_classes]`. Namely, the head applies
  `label_vocabulary` to the input labels before passing them to `loss_fn`.
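
  A minimal instantiation of this variant, which defaults to `SUM`
  reduction (the threshold and vocabulary values below are illustrative):

  ```python
  my_head = tf.contrib.estimator.multi_label_head(
      n_classes=3,
      thresholds=[0.25, 0.5, 0.75],
      label_vocabulary=['cat', 'dog', 'bird'])
  ```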

  Args:
    n_classes: Number of classes, must be greater than 1 (for 1 class, use
      `binary_classification_head`).
    weight_column: A string or a `_NumericColumn` created by
      `tf.feature_column.numeric_column` defining feature column representing
      weights. It is used to down weight or boost examples during training. It
      will be multiplied by the loss of the example.  Per-class weighting is
      not supported.
    thresholds: Iterable of floats in the range `(0, 1)`. Accuracy, precision
      and recall metrics are evaluated for each threshold value. The threshold
      is applied to the predicted probabilities, i.e. above the threshold is
      `true`, below is `false`.
    label_vocabulary: A list of strings representing possible label values. If
      it is not given, labels must already be encoded as integers within
      [0, n_classes) or as a multi-hot Tensor. If given, labels must be a
      string `SparseTensor` whose values are contained in `label_vocabulary`.
      String labels without a vocabulary raise an error.
    loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
      reduce training loss over batch. Defaults to `SUM`.
    loss_fn: Optional loss function.
    name: name of the head. If provided, summary and metrics keys will be
      suffixed by `"/" + name`. Also used as `name_scope` when creating ops.

  Returns:
    An instance of `_Head` for multi-label classification.

  Raises:
    ValueError: if `n_classes`, `thresholds`, `loss_reduction` or `loss_fn` is
    invalid.
  """
  thresholds = tuple(thresholds) if thresholds else tuple()
  if n_classes is None or n_classes < 2:
    raise ValueError(
        'n_classes must be > 1 for multi-label classification. '
        'Given: {}'.format(n_classes))
  for threshold in thresholds:
    if (threshold <= 0.0) or (threshold >= 1.0):
      raise ValueError(
          'thresholds must be in (0, 1) range. Given: {}'.format(threshold))
  if label_vocabulary is not None:
    if not isinstance(label_vocabulary, (list, tuple)):
      raise ValueError(
          'label_vocabulary must be a list or tuple. '
          'Given type: {}'.format(type(label_vocabulary)))
    if len(label_vocabulary) != n_classes:
      raise ValueError(
          'Length of label_vocabulary must be n_classes ({}). '
          'Given: {}'.format(n_classes, len(label_vocabulary)))
  if loss_fn:
    head_lib._validate_loss_fn_args(loss_fn)  # pylint:disable=protected-access
  if (loss_reduction not in losses.Reduction.all() or
      loss_reduction == losses.Reduction.NONE):
    raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction))
  return _MultiLabelHead(
      n_classes=n_classes, weight_column=weight_column, thresholds=thresholds,
      label_vocabulary=label_vocabulary, loss_reduction=loss_reduction,
      loss_fn=loss_fn, name=name)
Example #3
def multi_label_head(n_classes,
                     weight_column=None,
                     thresholds=None,
                     label_vocabulary=None,
                     loss_reduction=losses.Reduction.SUM,
                     loss_fn=None,
                     name=None):
    """Creates a `_Head` for multi-label classification.

  Multi-label classification handles the case where each example may have zero
  or more associated labels, from a discrete set. This is distinct from
  `multi_class_head` which has exactly one label per example.

  Uses `sigmoid_cross_entropy` loss averaged over classes and a weighted sum over
  the batch. Namely, if the input logits have shape `[batch_size, n_classes]`,
  the loss is the average over `n_classes` and the weighted sum over
  `batch_size`.

  The head expects `logits` with shape `[D0, D1, ... DN, n_classes]`. In many
  applications, the shape is `[batch_size, n_classes]`.

  Labels can be:
  * A multi-hot tensor of shape `[D0, D1, ... DN, n_classes]`
  * An integer `SparseTensor` of class indices. The `dense_shape` must be
    `[D0, D1, ... DN, ?]` and the values within `[0, n_classes)`.
  * If `label_vocabulary` is given, a string `SparseTensor`. The `dense_shape`
    must be `[D0, D1, ... DN, ?]` and the values within `label_vocabulary`.

  If `weight_column` is specified, weights must be of shape
  `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`.

  Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
  `(labels, logits, features)` as arguments and returns unreduced loss with
  shape `[D0, D1, ... DN, 1]`. `loss_fn` must support indicator `labels` with
  shape `[D0, D1, ... DN, n_classes]`. Namely, the head applies
  `label_vocabulary` to the input labels before passing them to `loss_fn`.

  Args:
    n_classes: Number of classes, must be greater than 1 (for 1 class, use
      `binary_classification_head`).
    weight_column: A string or a `_NumericColumn` created by
      `tf.feature_column.numeric_column` defining feature column representing
      weights. It is used to down weight or boost examples during training. It
      will be multiplied by the loss of the example.  Per-class weighting is
      not supported.
    thresholds: Iterable of floats in the range `(0, 1)`. Accuracy, precision
      and recall metrics are evaluated for each threshold value. The threshold
      is applied to the predicted probabilities, i.e. above the threshold is
      `true`, below is `false`.
    label_vocabulary: A list of strings representing possible label values. If
      it is not given, labels must already be encoded as integers within
      [0, n_classes) or as a multi-hot Tensor. If given, labels must be a
      string `SparseTensor` whose values are contained in `label_vocabulary`.
      String labels without a vocabulary raise an error.
    loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
      reduce training loss over batch. Defaults to `SUM`.
    loss_fn: Optional loss function.
    name: name of the head. If provided, summary and metrics keys will be
      suffixed by `"/" + name`. Also used as `name_scope` when creating ops.

  Returns:
    An instance of `_Head` for multi-label classification.

  Raises:
    ValueError: if `n_classes`, `thresholds`, `loss_reduction` or `loss_fn` is
    invalid.
  """
    thresholds = tuple(thresholds) if thresholds else tuple()
    if n_classes is None or n_classes < 2:
        raise ValueError(
            'n_classes must be > 1 for multi-label classification. '
            'Given: {}'.format(n_classes))
    for threshold in thresholds:
        if (threshold <= 0.0) or (threshold >= 1.0):
            raise ValueError(
                'thresholds must be in (0, 1) range. Given: {}'.format(
                    threshold))
    if label_vocabulary is not None:
        if not isinstance(label_vocabulary, (list, tuple)):
            raise ValueError('label_vocabulary must be a list or tuple. '
                             'Given type: {}'.format(type(label_vocabulary)))
        if len(label_vocabulary) != n_classes:
            raise ValueError(
                'Length of label_vocabulary must be n_classes ({}). '
                'Given: {}'.format(n_classes, len(label_vocabulary)))
    if loss_fn:
        head_lib._validate_loss_fn_args(loss_fn)  # pylint:disable=protected-access
    if (loss_reduction not in losses.Reduction.all()
            or loss_reduction == losses.Reduction.NONE):
        raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction))
    return _MultiLabelHead(n_classes=n_classes,
                           weight_column=weight_column,
                           thresholds=thresholds,
                           label_vocabulary=label_vocabulary,
                           loss_reduction=loss_reduction,
                           loss_fn=loss_fn,
                           name=name)
Example #4
def multi_label_head(n_classes,
                     weight_column=None,
                     thresholds=None,
                     label_vocabulary=None,
                     loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE,
                     loss_fn=None,
                     classes_for_class_based_metrics=None,
                     name=None):
    """Creates a `_Head` for multi-label classification.

  Multi-label classification handles the case where each example may have zero
  or more associated labels, from a discrete set. This is distinct from
  `multi_class_head` which has exactly one label per example.

  Uses `sigmoid_cross_entropy` loss averaged over classes and a weighted sum over
  the batch. Namely, if the input logits have shape `[batch_size, n_classes]`,
  the loss is the average over `n_classes` and the weighted sum over
  `batch_size`.

  The head expects `logits` with shape `[D0, D1, ... DN, n_classes]`. In many
  applications, the shape is `[batch_size, n_classes]`.

  Labels can be:

  * A multi-hot tensor of shape `[D0, D1, ... DN, n_classes]`
  * An integer `SparseTensor` of class indices. The `dense_shape` must be
    `[D0, D1, ... DN, ?]` and the values within `[0, n_classes)`.
  * If `label_vocabulary` is given, a string `SparseTensor`. The `dense_shape`
    must be `[D0, D1, ... DN, ?]` and the values within `label_vocabulary`.

  If `weight_column` is specified, weights must be of shape
  `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`.

  Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
  `(labels, logits, features)` as arguments and returns unreduced loss with
  shape `[D0, D1, ... DN, 1]`. `loss_fn` must support indicator `labels` with
  shape `[D0, D1, ... DN, n_classes]`. Namely, the head applies
  `label_vocabulary` to the input labels before passing them to `loss_fn`.

  The head can be used with a canned estimator. Example:

  ```python
  my_head = tf.contrib.estimator.multi_label_head(n_classes=3)
  my_estimator = tf.contrib.estimator.DNNEstimator(
      head=my_head,
      hidden_units=...,
      feature_columns=...)
  ```

  It can also be used with a custom `model_fn`. Example:

  ```python
  def _my_model_fn(features, labels, mode):
    my_head = tf.contrib.estimator.multi_label_head(n_classes=3)
    logits = tf.keras.Model(...)(features)

    return my_head.create_estimator_spec(
        features=features,
        mode=mode,
        labels=labels,
        optimizer=tf.train.AdagradOptimizer(learning_rate=0.1),
        logits=logits)

  my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn)
  ```

  Args:
    n_classes: Number of classes, must be greater than 1 (for 1 class, use
      `binary_classification_head`).
    weight_column: A string or a `_NumericColumn` created by
      `tf.feature_column.numeric_column` defining feature column representing
      weights. It is used to down weight or boost examples during training. It
      will be multiplied by the loss of the example.  Per-class weighting is
      not supported.
    thresholds: Iterable of floats in the range `(0, 1)`. Accuracy, precision
      and recall metrics are evaluated for each threshold value. The threshold
      is applied to the predicted probabilities, i.e. above the threshold is
      `true`, below is `false`.
    label_vocabulary: A list of strings representing possible label values. If
      it is not given, labels must already be encoded as integers within
      [0, n_classes) or as a multi-hot Tensor. If given, labels must be a
      string `SparseTensor` whose values are contained in `label_vocabulary`.
      String labels without a vocabulary raise an error.
    loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
      reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`, namely
      weighted sum of losses divided by batch size. See `tf.losses.Reduction`.
    loss_fn: Optional loss function.
    classes_for_class_based_metrics: List of integer class IDs or string class
      names for which per-class metrics are evaluated. If integers, all must be
      in the range `[0, n_classes - 1]`. If strings, all must be in
      `label_vocabulary`.
    name: name of the head. If provided, summary and metrics keys will be
      suffixed by `"/" + name`. Also used as `name_scope` when creating ops.

  Returns:
    An instance of `_Head` for multi-label classification.

  Raises:
    ValueError: if `n_classes`, `thresholds`, `loss_reduction`, `loss_fn` or
    `classes_for_class_based_metrics` is invalid.
  """
    thresholds = tuple(thresholds) if thresholds else tuple()
    if n_classes is None or n_classes < 2:
        raise ValueError(
            'n_classes must be > 1 for multi-label classification. '
            'Given: {}'.format(n_classes))
    for threshold in thresholds:
        if (threshold <= 0.0) or (threshold >= 1.0):
            raise ValueError(
                'thresholds must be in (0, 1) range. Given: {}'.format(
                    threshold))
    if label_vocabulary is not None:
        if not isinstance(label_vocabulary, (list, tuple)):
            raise ValueError('label_vocabulary must be a list or tuple. '
                             'Given type: {}'.format(type(label_vocabulary)))
        if len(label_vocabulary) != n_classes:
            raise ValueError(
                'Length of label_vocabulary must be n_classes ({}). '
                'Given: {}'.format(n_classes, len(label_vocabulary)))
    if loss_fn:
        head_lib._validate_loss_fn_args(loss_fn)  # pylint:disable=protected-access
    if (loss_reduction not in losses.Reduction.all()
            or loss_reduction == losses.Reduction.NONE):
        raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction))
    classes_for_class_based_metrics = tuple(
        [] if classes_for_class_based_metrics is None else
        classes_for_class_based_metrics)
    if classes_for_class_based_metrics:
        if isinstance(classes_for_class_based_metrics[0], six.string_types):
            if not label_vocabulary:
                raise ValueError('label_vocabulary must be provided when '
                                 'classes_for_class_based_metrics are strings.')
            class_ids = []
            for class_string in classes_for_class_based_metrics:
                class_ids.append(label_vocabulary.index(class_string))
            classes_for_class_based_metrics = tuple(class_ids)
        else:
            for class_id in classes_for_class_based_metrics:
                if (class_id < 0) or (class_id >= n_classes):
                    raise ValueError(
                        'All classes_for_class_based_metrics must be in range [0, {}]. '
                        'Given: {}'.format(n_classes - 1, class_id))
    return _MultiLabelHead(
        n_classes=n_classes,
        weight_column=weight_column,
        thresholds=thresholds,
        label_vocabulary=label_vocabulary,
        loss_reduction=loss_reduction,
        loss_fn=loss_fn,
        classes_for_class_based_metrics=classes_for_class_based_metrics,
        name=name)