def mixup(batch_size, alpha, same_mix_weight_per_batch, use_truncated_beta, use_random_shuffling, images, labels):
  """Mixes each example in a batch with a partner example.

  Implements Mixup regularization as proposed in:
  [1] Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz
    Mixup: Beyond Empirical Risk Minimization.
    ICLR'18, https://arxiv.org/abs/1710.09412

  Arguments:
    batch_size: The input batch size for images and labels.
    alpha: Float that controls the strength of Mixup regularization.
    same_mix_weight_per_batch: whether to use the same mix coef over the batch.
    use_truncated_beta: whether to sample from Beta_[0,1](alpha, alpha) or from
      the truncated distribution Beta_[1/2, 1](alpha, alpha).
    use_random_shuffling: Whether to pair images by random shuffling (default
      is a deterministic pairing by reversing the batch).
    images: A batch of images of shape [batch_size, ...]
    labels: A batch of labels of shape [batch_size, num_classes]

  Returns:
    A tuple of (images, labels) with the same dimensions as the input with
    Mixup regularization applied.
  """
  # Draw one convex-combination weight per example, or a single shared weight
  # that is tiled across the batch.
  if same_mix_weight_per_batch:
    lam = tf.tile(ed.Beta(alpha, alpha, sample_shape=[1, 1]), [batch_size, 1])
  else:
    lam = ed.Beta(alpha, alpha, sample_shape=[batch_size, 1])

  # Folding the weight onto [1/2, 1] keeps each example dominant in its mix.
  if use_truncated_beta:
    lam = tf.maximum(lam, 1. - lam)

  # Choose each example's mixing partner: a random permutation of the batch,
  # or (deterministically) the example at the mirrored batch position.
  if use_random_shuffling:
    partner_idx = tf.random.shuffle(tf.range(batch_size))
  else:
    partner_idx = tf.reverse(tf.range(batch_size), axis=[0])

  image_lam = tf.cast(tf.reshape(lam, [batch_size, 1, 1, 1]), images.dtype)
  partner_images = tf.gather(images, partner_idx)
  images_mix = images * image_lam + partner_images * (1. - image_lam)

  label_lam = tf.cast(lam, labels.dtype)
  partner_labels = tf.gather(labels, partner_idx)
  labels_mix = labels * label_lam + partner_labels * (1. - label_lam)

  return images_mix, labels_mix
def mixup(self, batch_size, alpha, images, labels):
  """Mixes each example with its mirror-image partner in the batch.

  Implements Mixup regularization as proposed in:
  [1] Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz
    Mixup: Beyond Empirical Risk Minimization.
    ICLR'18, https://arxiv.org/abs/1710.09412

  Each example is paired with the example at the opposite end of the batch,
  and the mix weight is folded onto [1/2, 1] so the original example always
  dominates its mix.

  Arguments:
    batch_size: The input batch size for images and labels.
    alpha: Float that controls the strength of Mixup regularization.
    images: A batch of images of shape [batch_size, ...]
    labels: A batch of labels of shape [batch_size, num_classes]

  Returns:
    A tuple of (images, labels) with the same dimensions as the input with
    Mixup regularization applied.
  """
  lam = ed.Beta(alpha, alpha, sample_shape=[batch_size, 1])
  # Truncate the Beta sample onto [1/2, 1].
  lam = tf.maximum(lam, 1. - lam)

  # Pairing is the reversed batch: images[::-1] lines example i up with
  # example batch_size - 1 - i.
  image_lam = tf.cast(tf.reshape(lam, [batch_size, 1, 1, 1]), images.dtype)
  images_mix = images * image_lam + images[::-1] * (1. - image_lam)

  label_lam = tf.cast(lam, labels.dtype)
  labels_mix = labels * label_lam + labels[::-1] * (1. - label_lam)

  return images_mix, labels_mix
def mixup(batch_size, aug_params, images, labels):
  """Applies Mixup regularization to a batch of images and labels.

  [1] Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz
    Mixup: Beyond Empirical Risk Minimization.
    ICLR'18, https://arxiv.org/abs/1710.09412

  `aug_params` can have the following fields:
    augmix: whether or not to run AugMix.
    mixup_alpha: the alpha to use in the Beta distribution.
    aug_count: the number of augmentations to use in AugMix.
    same_mix_weight_per_batch: whether to use the same mix coef over the batch.
    use_truncated_beta: whether to sample from Beta_[0,1](alpha, alpha) or from
      the truncated distribution Beta_[1/2, 1](alpha, alpha).
    use_random_shuffling: Whether to pair images by random shuffling (default
      is a deterministic pairing by reversing the batch).

  Arguments:
    batch_size: The input batch size for images and labels.
    aug_params: Dict of data augmentation hyper parameters.
    images: A batch of images of shape [batch_size, ...]; when `augmix` is
      set, presumably [batch_size, aug_count + 1, ...] — confirm with caller.
    labels: A batch of labels of shape [batch_size, num_classes]

  Returns:
    A tuple of (images, labels) with the same dimensions as the input with
    Mixup regularization applied.

  Raises:
    ValueError: If both `augmix` and `same_mix_weight_per_batch` are set.
  """
  augmix = aug_params.get('augmix', False)
  alpha = aug_params.get('mixup_alpha', 0.)
  aug_count = aug_params.get('aug_count', 3)
  same_mix_weight_per_batch = aug_params.get('same_mix_weight_per_batch', False)
  use_truncated_beta = aug_params.get('use_truncated_beta', True)
  use_random_shuffling = aug_params.get('use_random_shuffling', False)

  if augmix and same_mix_weight_per_batch:
    raise ValueError(
        'Can only set one of `augmix` or `same_mix_weight_per_batch`.')

  # Sample a mix weight per example; with AugMix there is one weight per
  # augmented copy (the original plus `aug_count` augmentations).
  if augmix:
    mix_weight = ed.Beta(
        alpha, alpha, sample_shape=[batch_size, aug_count + 1, 1])
  elif same_mix_weight_per_batch:
    mix_weight = ed.Beta(alpha, alpha, sample_shape=[1, 1])
    mix_weight = tf.tile(mix_weight, [batch_size, 1])
  else:
    mix_weight = ed.Beta(alpha, alpha, sample_shape=[batch_size, 1])

  if use_truncated_beta:
    # Folding onto [1/2, 1] keeps each example dominant in its own mix.
    mix_weight = tf.maximum(mix_weight, 1. - mix_weight)

  if augmix:
    images_mix_weight = tf.reshape(mix_weight,
                                   [batch_size, aug_count + 1, 1, 1, 1])
  else:
    images_mix_weight = tf.reshape(mix_weight, [batch_size, 1, 1, 1])
  images_mix_weight = tf.cast(images_mix_weight, images.dtype)

  if use_random_shuffling:
    mixup_index = tf.random.shuffle(tf.range(batch_size))
  else:
    # Mixup on a single batch is implemented by taking a weighted sum with the
    # same batch in reverse.
    mixup_index = tf.reverse(tf.range(batch_size), axis=[0])

  images_mix = (
      images * images_mix_weight +
      tf.gather(images, mixup_index) * (1. - images_mix_weight))

  mix_weight = tf.cast(mix_weight, labels.dtype)
  if augmix:
    # Replicate the labels across the aug_count + 1 augmented copies so they
    # can be mixed with the same per-copy weights as the images.
    labels = tf.reshape(
        tf.tile(labels, [1, aug_count + 1]), [batch_size, aug_count + 1, -1])
  labels_mix = (
      labels * mix_weight + tf.gather(labels, mixup_index) * (1. - mix_weight))
  if augmix:
    # Flatten [batch, copy, classes] -> [copy * batch, classes], grouping all
    # examples of the same augmented copy together.
    labels_mix = tf.reshape(
        tf.transpose(labels_mix, [1, 0, 2]),
        [batch_size * (aug_count + 1), -1])

  return images_mix, labels_mix