예제 #1
0
def mixup(batch_size, alpha, same_mix_weight_per_batch, use_truncated_beta,
          use_random_shuffling, images, labels):
    """Applies Mixup regularization to a batch of images and labels.

    Each example is replaced by a convex combination of itself and a partner
    example drawn from the same batch.

    [1] Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz
      Mixup: Beyond Empirical Risk Minimization.
      ICLR'18, https://arxiv.org/abs/1710.09412

    Arguments:
      batch_size: The input batch size for images and labels.
      alpha: Float that controls the strength of Mixup regularization.
      same_mix_weight_per_batch: whether to use a single mix coefficient
        shared by the whole batch.
      use_truncated_beta: whether to sample from Beta_[0,1](alpha, alpha) or
        from the truncated distribution Beta_[1/2, 1](alpha, alpha).
      use_random_shuffling: Whether to pair images by random shuffling
        (default is a deterministic pairing by reversing the batch).
      images: A batch of images of shape [batch_size, ...]
      labels: A batch of labels of shape [batch_size, num_classes]

    Returns:
      A tuple of (images, labels) with the same dimensions as the input with
      Mixup regularization applied.
    """
    if same_mix_weight_per_batch:
        # Draw one coefficient and replicate it across the batch.
        lam = tf.tile(ed.Beta(alpha, alpha, sample_shape=[1, 1]),
                      [batch_size, 1])
    else:
        lam = ed.Beta(alpha, alpha, sample_shape=[batch_size, 1])

    if use_truncated_beta:
        # Fold the sample into [1/2, 1] so each example dominates its own mix.
        lam = tf.maximum(lam, 1. - lam)

    # Broadcastable [batch, 1, 1, 1] weight for NHWC image tensors.
    pixel_lam = tf.cast(tf.reshape(lam, [batch_size, 1, 1, 1]), images.dtype)

    if use_random_shuffling:
        partner_index = tf.random.shuffle(tf.range(batch_size))
    else:
        # Deterministic pairing: each example mixes with its mirror position
        # in the reversed batch.
        partner_index = tf.reverse(tf.range(batch_size), axis=[0])

    images_mix = (images * pixel_lam +
                  tf.gather(images, partner_index) * (1. - pixel_lam))
    label_lam = tf.cast(lam, labels.dtype)
    labels_mix = (labels * label_lam +
                  tf.gather(labels, partner_index) * (1. - label_lam))
    return images_mix, labels_mix
    def mixup(self, batch_size, alpha, images, labels):
        """Applies Mixup regularization to a batch of images and labels.

        Each example is mixed with its mirror-position partner in the
        reversed batch.

        [1] Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz
          Mixup: Beyond Empirical Risk Minimization.
          ICLR'18, https://arxiv.org/abs/1710.09412

        Arguments:
          batch_size: The input batch size for images and labels.
          alpha: Float that controls the strength of Mixup regularization.
          images: A batch of images of shape [batch_size, ...]
          labels: A batch of labels of shape [batch_size, num_classes]

        Returns:
          A tuple of (images, labels) with the same dimensions as the input
          with Mixup regularization applied.
        """
        # One Beta(alpha, alpha) coefficient per example, folded into
        # [1/2, 1] so the original example always dominates its mix.
        lam = ed.Beta(alpha, alpha, sample_shape=[batch_size, 1])
        lam = tf.maximum(lam, 1. - lam)
        # Broadcastable [batch, 1, 1, 1] weight for NHWC image tensors.
        pixel_lam = tf.cast(tf.reshape(lam, [batch_size, 1, 1, 1]),
                            images.dtype)
        # Mixup pairs each example with the same batch in reverse order.
        images_mix = (images * pixel_lam +
                      images[::-1] * (1. - pixel_lam))
        label_lam = tf.cast(lam, labels.dtype)
        labels_mix = labels * label_lam + labels[::-1] * (1. - label_lam)
        return images_mix, labels_mix
예제 #3
0
def mixup(batch_size, aug_params, images, labels):
    """Applies Mixup regularization to a batch of images and labels.

    [1] Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz
      Mixup: Beyond Empirical Risk Minimization.
      ICLR'18, https://arxiv.org/abs/1710.09412

    `aug_params` can have the following fields:
      augmix: whether or not to run AugMix.
      mixup_alpha: the alpha to use in the Beta distribution.
      aug_count: the number of augmentations to use in AugMix.
      same_mix_weight_per_batch: whether to use the same mix coef over the
        batch.
      use_truncated_beta: whether to sample from Beta_[0,1](alpha, alpha) or
        from the truncated distribution Beta_[1/2, 1](alpha, alpha).
      use_random_shuffling: Whether to pair images by random shuffling
        (default is a deterministic pairing by reversing the batch).

    Arguments:
      batch_size: The input batch size for images and labels.
      aug_params: Dict of data augmentation hyper parameters.
      images: A batch of images of shape [batch_size, ...]
      labels: A batch of labels of shape [batch_size, num_classes]

    Returns:
      A tuple of (images, labels) with the same dimensions as the input with
      Mixup regularization applied.
    """
    augmix = aug_params.get('augmix', False)
    alpha = aug_params.get('mixup_alpha', 0.)
    aug_count = aug_params.get('aug_count', 3)
    same_mix_weight_per_batch = aug_params.get('same_mix_weight_per_batch',
                                               False)
    use_truncated_beta = aug_params.get('use_truncated_beta', True)
    use_random_shuffling = aug_params.get('use_random_shuffling', False)

    if augmix and same_mix_weight_per_batch:
        raise ValueError(
            'Can only set one of `augmix` or `same_mix_weight_per_batch`.')

    # Sample the mix coefficients. AugMix carries `aug_count` extra copies of
    # each example, so it draws one coefficient per (example, copy) pair.
    if augmix:
        lam = ed.Beta(alpha, alpha,
                      sample_shape=[batch_size, aug_count + 1, 1])
    elif same_mix_weight_per_batch:
        # A single coefficient, replicated across the batch.
        lam = tf.tile(ed.Beta(alpha, alpha, sample_shape=[1, 1]),
                      [batch_size, 1])
    else:
        lam = ed.Beta(alpha, alpha, sample_shape=[batch_size, 1])

    if use_truncated_beta:
        # Fold the sample into [1/2, 1] so each example dominates its mix.
        lam = tf.maximum(lam, 1. - lam)

    # Reshape to broadcast over the (copy and) spatial/channel image axes.
    if augmix:
        pixel_lam = tf.reshape(lam, [batch_size, aug_count + 1, 1, 1, 1])
    else:
        pixel_lam = tf.reshape(lam, [batch_size, 1, 1, 1])
    pixel_lam = tf.cast(pixel_lam, images.dtype)

    if use_random_shuffling:
        partner_index = tf.random.shuffle(tf.range(batch_size))
    else:
        # Deterministic pairing: each example mixes with its mirror position
        # in the reversed batch.
        partner_index = tf.reverse(tf.range(batch_size), axis=[0])

    images_mix = (images * pixel_lam +
                  tf.gather(images, partner_index) * (1. - pixel_lam))

    label_lam = tf.cast(lam, labels.dtype)
    if augmix:
        # Repeat each label once per augmented copy, mix, then flatten the
        # copy axis to the front so copies of the batch are contiguous.
        labels = tf.reshape(tf.tile(labels, [1, aug_count + 1]),
                            [batch_size, aug_count + 1, -1])
        labels_mix = (labels * label_lam +
                      tf.gather(labels, partner_index) * (1. - label_lam))
        labels_mix = tf.reshape(tf.transpose(labels_mix, [1, 0, 2]),
                                [batch_size * (aug_count + 1), -1])
    else:
        labels_mix = (labels * label_lam +
                      tf.gather(labels, partner_index) * (1. - label_lam))
    return images_mix, labels_mix