def merge_with_rotation_data(self, real, fake, real_labels, fake_labels,
                             num_rot_examples):
  """Returns the original data concatenated with the rotated version."""
  # Put all rotation angles in a single batch: the first batch_size entries
  # are the original up-right images, followed by num_rot_examples * 3
  # rotated images with 3 different angles. For NUM_ROTATIONS=4 and
  # num_rot_examples=2 we have labels_rotated [0, 0, 1, 1, 2, 2, 3, 3].
  real_to_rot, fake_to_rot = (
      real[-num_rot_examples:], fake[-num_rot_examples:])
  real_rotated = utils.rotate_images(real_to_rot, rot90_scalars=(1, 2, 3))
  fake_rotated = utils.rotate_images(fake_to_rot, rot90_scalars=(1, 2, 3))
  all_features = tf.concat([real, real_rotated, fake, fake_rotated], 0)
  all_labels = None
  if self.conditional:
    real_rotated_labels = tf.tile(real_labels[-num_rot_examples:], [3, 1])
    fake_rotated_labels = tf.tile(fake_labels[-num_rot_examples:], [3, 1])
    all_labels = tf.concat([
        real_labels, real_rotated_labels, fake_labels, fake_rotated_labels
    ], 0)
  return all_features, all_labels
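# For reference, a minimal sketch of the behavior this code assumes from
# utils.rotate_images (not the actual utils implementation): one rotated copy
# of the batch per entry in rot90_scalars (multiples of 90 degrees),
# concatenated along the batch axis so all examples for one angle are
# contiguous. The function name below is illustrative.
def _rotate_images_sketch(images, rot90_scalars=(0, 1, 2, 3)):
  """Rotates `images` by each multiple of 90 degrees in `rot90_scalars`."""
  # tf.image.rot90 rotates a [batch, h, w, c] tensor counter-clockwise by
  # k * 90 degrees.
  rotated = [tf.image.rot90(images, k=k) for k in rot90_scalars]
  # Batch layout: [angle_0 examples, angle_1 examples, ...], which matches
  # the rotation labels produced via np.repeat in create_loss below.
  return tf.concat(rotated, axis=0)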
def create_loss(self, features, labels, params, is_training=True):
  """Build the loss tensors for discriminator and generator.

  This method will set self.d_loss and self.g_loss.

  Args:
    features: Optional dictionary with inputs to the model ("images" should
        contain the real images and "z" the noise for the generator).
    labels: Tensor with labels. These are class indices. Use
        self._get_one_hot_labels(labels) to get a one-hot encoded tensor.
    params: Dictionary with hyperparameters passed to TPUEstimator.
        Additionally, TPUEstimator will set 3 keys: `batch_size`, `use_tpu`,
        `tpu_context`. `batch_size` is the batch size for this core.
    is_training: If True build the model in training mode. If False build the
        model for inference mode (e.g. use trained averages for batch norm).

  Raises:
    ValueError: If set of meta/hyper parameters is not supported.
  """
  images = features["images"]  # Input images.
  generated = features["generated"]  # Fake images.
  if self.conditional:
    y = self._get_one_hot_labels(labels)
    sampled_y = self._get_one_hot_labels(features["sampled_labels"])
  else:
    y = None
    sampled_y = None
  all_y = None

  # Batch size per core.
  bs = images.shape[0].value
  num_replicas = params["context"].num_replicas if "context" in params else 1
  assert self._rotated_batch_size % num_replicas == 0
  # Rotated batch size per core.
  rotated_bs = self._rotated_batch_size // num_replicas
  assert rotated_bs % 4 == 0
  # Number of images to rotate. Each image gets rotated 3 times.
  num_rotated_examples = rotated_bs // 4
  logging.info(
      "num_replicas=%s, bs=%s, rotated_bs=%s, "
      "num_rotated_examples=%s, params=%s",
      num_replicas, bs, rotated_bs, num_rotated_examples, params)

  # Augment the images with rotation.
  if "rotation" in self._self_supervision:
    # Put all rotation angles in a single batch: the first batch_size entries
    # are the original up-right images, followed by num_rotated_examples * 3
    # rotated images with 3 different angles.
    assert num_rotated_examples <= bs, (num_rotated_examples, bs)
    images_rotated = utils.rotate_images(
        images[-num_rotated_examples:], rot90_scalars=(1, 2, 3))
    generated_rotated = utils.rotate_images(
        generated[-num_rotated_examples:], rot90_scalars=(1, 2, 3))
    # Labels for the rotation loss (unrotated and 3 rotated versions). For
    # NUM_ROTATIONS=4 and num_rotated_examples=2 this is:
    # [0, 0, 1, 1, 2, 2, 3, 3].
    rotate_labels = tf.constant(
        np.repeat(np.arange(NUM_ROTATIONS, dtype=np.int32),
                  num_rotated_examples))
    rotate_labels_onehot = tf.one_hot(rotate_labels, NUM_ROTATIONS)
    all_images = tf.concat(
        [images, images_rotated, generated, generated_rotated], 0)
    if self.conditional:
      y_rotated = tf.tile(y[-num_rotated_examples:], [3, 1])
      # Tile the labels of the generated images here (tiling `y` instead
      # would mix real labels into the fake half of the batch).
      sampled_y_rotated = tf.tile(sampled_y[-num_rotated_examples:], [3, 1])
      all_y = tf.concat([y, y_rotated, sampled_y, sampled_y_rotated], 0)
  else:
    all_images = tf.concat([images, generated], 0)
    if self.conditional:
      all_y = tf.concat([y, sampled_y], axis=0)

  # Compute discriminator output for real and fake images in one batch.
  d_all, d_all_logits, c_all_logits = self.discriminator_with_rotation_head(
      all_images, y=all_y, is_training=is_training)
  d_real, d_fake = tf.split(d_all, 2)
  d_real_logits, d_fake_logits = tf.split(d_all_logits, 2)
  c_real_logits, c_fake_logits = tf.split(c_all_logits, 2)

  # Separate the true/fake scores from the whole rotation batch.
  d_real_logits = d_real_logits[:bs]
  d_fake_logits = d_fake_logits[:bs]
  d_real = d_real[:bs]
  d_fake = d_fake[:bs]

  self.d_loss, _, _, self.g_loss = loss_lib.get_losses(
      d_real=d_real, d_fake=d_fake, d_real_logits=d_real_logits,
      d_fake_logits=d_fake_logits)

  penalty_loss = penalty_lib.get_penalty_loss(
      x=images, x_fake=generated, y=y, is_training=is_training,
      discriminator=self.discriminator, architecture=self._architecture)
  self.d_loss += self._lambda * penalty_loss

  # Add the rotation-augmented loss.
  if "rotation" in self._self_supervision:
    # Take an equal number of examples for each rotation angle.
    assert len(c_real_logits.shape.as_list()) == 2, c_real_logits.shape
    assert len(c_fake_logits.shape.as_list()) == 2, c_fake_logits.shape
    c_real_logits = c_real_logits[-rotated_bs:]
    c_fake_logits = c_fake_logits[-rotated_bs:]
    preds_onreal = tf.cast(tf.argmax(c_real_logits, -1), rotate_labels.dtype)
    accuracy = tf.reduce_mean(
        tf.cast(tf.equal(rotate_labels, preds_onreal), tf.float32))
    c_real_probs = tf.nn.softmax(c_real_logits)
    c_fake_probs = tf.nn.softmax(c_fake_logits)
    c_real_loss = -tf.reduce_mean(
        tf.reduce_sum(
            rotate_labels_onehot * tf.log(c_real_probs + 1e-10), 1))
    c_fake_loss = -tf.reduce_mean(
        tf.reduce_sum(
            rotate_labels_onehot * tf.log(c_fake_probs + 1e-10), 1))
    if self._self_supervision == "rotation_only":
      self.d_loss *= 0.0
      self.g_loss *= 0.0
    self.d_loss += c_real_loss * self._weight_rotation_loss_d
    self.g_loss += c_fake_loss * self._weight_rotation_loss_g
  else:
    c_real_loss = 0.0
    c_fake_loss = 0.0
    accuracy = tf.zeros([])

  self._tpu_summary.scalar("loss/c_real_loss", c_real_loss)
  self._tpu_summary.scalar("loss/c_fake_loss", c_fake_loss)
  self._tpu_summary.scalar("accuracy/d_rotation", accuracy)
  self._tpu_summary.scalar("loss/penalty", penalty_loss)
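# discriminator_with_rotation_head is not shown in this section. Below is a
# minimal sketch of the interface create_loss relies on, assuming the base
# discriminator returns (predictions, logits, final feature representation)
# and the rotation head is a NUM_ROTATIONS-way linear layer on those
# features. The function name, the layer choice, and the scope name are
# illustrative assumptions, not the actual implementation.
def _discriminator_with_rotation_head_sketch(self, x, y, is_training):
  """Returns GAN predictions/logits plus rotation-classification logits."""
  d_probs, d_logits, x_rep = self.discriminator(
      x, y=y, is_training=is_training)
  # Hypothetical linear head predicting which of the NUM_ROTATIONS angles
  # was applied to each input image.
  rotation_logits = tf.layers.dense(
      x_rep, NUM_ROTATIONS, name="rotation_head")
  return d_probs, d_logits, rotation_logits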