def create_loss(self, features, labels, params, is_training=True):
  """Assemble the discriminator and generator loss tensors.

  The results are stored in `self.d_loss` and `self.g_loss`; nothing is
  returned.

  Args:
    features: Dictionary with inputs to the model. This method reads
      "images" (real images), "generated" (fake images) and, for conditional
      models, "sampled_labels" (labels fed to the discriminator together
      with the fake images).
    labels: Tensor with labels for the real images. It is converted with
      self._get_one_hot_labels(labels) when the model is conditional.
    params: Dictionary with hyperparameters passed to TPUEstimator. It is
      not read by this method.
    is_training: Forwarded to the discriminator and to the penalty loss to
      select training or inference mode.

  Raises:
    ValueError: Not raised directly here; presumably propagated from the
      loss/penalty helpers for unsupported hyperparameters (unconfirmed).
  """
  real_images = features["images"]
  fake_images = features["generated"]

  # Labels only exist for conditional models; default everything to None.
  y = sampled_y = all_y = None
  if self.conditional:
    y = self._get_one_hot_labels(labels)
    sampled_y = self._get_one_hot_labels(features["sampled_labels"])
    all_y = tf.concat([y, sampled_y], axis=0)

  if not self._deprecated_split_disc_calls:
    # One discriminator pass over the stacked real + fake batch; the outputs
    # are split back into their real and fake halves afterwards.
    stacked_images = tf.concat([real_images, fake_images], axis=0)
    d_all, d_all_logits, _ = self.discriminator(
        stacked_images, y=all_y, is_training=is_training)
    d_real, d_fake = tf.split(d_all, 2)
    d_real_logits, d_fake_logits = tf.split(d_all_logits, 2)
  else:
    # Deprecated path: two separate discriminator calls.
    with tf.name_scope("disc_for_real"):
      d_real, d_real_logits, _ = self.discriminator(
          real_images, y=y, is_training=is_training)
    with tf.name_scope("disc_for_fake"):
      d_fake, d_fake_logits, _ = self.discriminator(
          fake_images, y=sampled_y, is_training=is_training)

  d_loss, _, _, g_loss = loss_lib.get_losses(
      d_real=d_real,
      d_fake=d_fake,
      d_real_logits=d_real_logits,
      d_fake_logits=d_fake_logits)
  self.d_loss = d_loss
  self.g_loss = g_loss

  # The penalty term is added to the discriminator loss only.
  penalty_loss = penalty_lib.get_penalty_loss(
      x=real_images,
      x_fake=fake_images,
      y=y,
      is_training=is_training,
      discriminator=self.discriminator)
  self.d_loss = self.d_loss + self._lambda * penalty_loss
def create_loss(self, features, labels, params, is_training=True):
  """Assemble GAN losses plus a contrastive loss on augmented images.

  The results are stored in `self.d_loss` and `self.g_loss`; nothing is
  returned. The discriminator loss is the sum of the adversarial loss, the
  weighted penalty and a weighted contrastive term that matches each image's
  latent projection to the projection of its augmented copy.

  Args:
    features: Dictionary with inputs to the model. This method reads
      "images" (input images), "generated" (fake images) and, for
      conditional models, "sampled_labels".
    labels: Tensor with labels (class indices). It is converted with
      self._get_one_hot_labels(labels) when the model is conditional.
    params: Dictionary with hyperparameters passed to TPUEstimator. It is
      not read by this method.
    is_training: Forwarded to the discriminator and to the penalty loss to
      select training or inference mode.

  Raises:
    ValueError: Not raised directly here; presumably propagated from the
      loss/penalty helpers for unsupported hyperparameters (unconfirmed).
  """
  real_images = features["images"]
  fake_images = features["generated"]

  y = sampled_y = all_y = None
  if self.conditional:
    y = self._get_one_hot_labels(labels)
    sampled_y = self._get_one_hot_labels(features["sampled_labels"])

  # Static batch size (per core) taken from the image tensor's first dim.
  batch_size = real_images.shape[0].value

  def _augmented(batch):
    """Apply crop/resize and random colour ops; block gradients."""
    out = random_crop_and_resize(batch)
    out = random_apply(color_distortion, out, self._aug_color_jitter_prob)
    out = random_apply(color_drop, out, self._aug_color_drop_prob)
    return tf.stop_gradient(out)

  aug_real_images = _augmented(real_images)
  aug_fake_images = _augmented(fake_images)

  # Batch layout: [real, fake, augmented real, augmented fake].
  all_images = tf.concat(
      [real_images, fake_images, aug_real_images, aug_fake_images], 0)
  if self.conditional:
    all_y = tf.concat([y, sampled_y, y, sampled_y], axis=0)

  # Single discriminator pass over all four quarters of the batch.
  d_all, d_all_logits, d_latents = self.discriminator(
      x=all_images, y=all_y, is_training=is_training)
  projections = self._latent_projections(d_latents)

  # Only the un-augmented quarters feed the adversarial loss.
  d_real, d_fake, _, _ = tf.split(d_all, 4)
  d_real_logits, d_fake_logits, _, _ = tf.split(d_all_logits, 4)
  proj_real, proj_fake, proj_aug_real, proj_aug_fake = tf.split(
      projections, 4)

  d_loss, _, _, g_loss = loss_lib.get_losses(
      d_real=d_real,
      d_fake=d_fake,
      d_real_logits=d_real_logits,
      d_fake_logits=d_fake_logits)
  self.d_loss = d_loss
  self.g_loss = g_loss

  penalty_loss = penalty_lib.get_penalty_loss(
      x=real_images,
      x_fake=fake_images,
      y=y,
      is_training=is_training,
      discriminator=self.discriminator,
      architecture=self._architecture)
  self.d_loss = self.d_loss + self._lambda * penalty_loss

  # Contrastive term: row i of the similarity matrix should peak at column i
  # (an image versus its own augmentation).
  anchors = tf.concat([proj_real, proj_fake], 0)
  positives = tf.concat([proj_aug_real, proj_aug_fake], 0)
  similarity = tf.matmul(anchors, positives, transpose_b=True)
  row_max = tf.reduce_max(similarity, 1)
  # Subtract the per-row maximum before the softmax.
  similarity = similarity - tf.reshape(row_max, [-1, 1])
  probabilities = tf.nn.softmax(similarity)
  num_pairs = batch_size * 2
  targets = tf.constant(np.arange(num_pairs, dtype=np.int32))
  targets_onehot = tf.one_hot(targets, num_pairs)
  log_probabilities = tf.log(probabilities + 1e-10)
  per_example = tf.reduce_sum(targets_onehot * log_probabilities, 1)
  c_real_loss = -tf.reduce_mean(per_example)

  self.d_loss = self.d_loss + c_real_loss * self._weight_contrastive_loss_d

  self._tpu_summary.scalar("loss/c_real_loss", c_real_loss)
  self._tpu_summary.scalar("loss/penalty", penalty_loss)
def create_loss(self, features, labels, params, is_training=True):
  """Build the loss tensors for discriminator and generator.

  This method will set self.d_loss and self.g_loss. When "rotation" is part
  of `self._self_supervision`, a rotation-classification loss computed from
  the discriminator's rotation head is added to both losses.

  Args:
    features: Dictionary with inputs to the model. This method reads
      "images" (input images), "generated" (fake images) and, for
      conditional models, "sampled_labels".
    labels: Tensor with labels. These are class indices. Converted with
      self._get_one_hot_labels(labels) when the model is conditional.
    params: Dictionary with hyperparameters passed to TPUEstimator. If it
      has a "context" key, `params["context"].num_replicas` is used as the
      number of replicas; otherwise one replica is assumed.
    is_training: If True build the model in training mode. If False build the
      model for inference mode (e.g. use trained averages for batch norm).

  Raises:
    ValueError: Not raised directly here; presumably propagated from the
      loss/penalty helpers for unsupported hyperparameters (unconfirmed).
    AssertionError: If the rotated batch size is not divisible by the number
      of replicas, the per-core rotated batch size is not divisible by 4, or
      more examples should be rotated than the per-core batch contains.
  """
  images = features["images"]  # Input images.
  generated = features["generated"]  # Fake images.
  if self.conditional:
    y = self._get_one_hot_labels(labels)
    sampled_y = self._get_one_hot_labels(features["sampled_labels"])
  else:
    y = None
    sampled_y = None
  all_y = None

  # Batch size per core.
  bs = images.shape[0].value
  num_replicas = params["context"].num_replicas if "context" in params else 1
  assert self._rotated_batch_size % num_replicas == 0
  # Rotated batch size per core.
  rotated_bs = self._rotated_batch_size // num_replicas
  assert rotated_bs % 4 == 0
  # Number of images to rotate. Each image gets rotated 3 times.
  num_rotated_examples = rotated_bs // 4
  logging.info(
      "num_replicas=%s, bs=%s, rotated_bs=%s, "
      "num_rotated_examples=%s, params=%s",
      num_replicas, bs, rotated_bs, num_rotated_examples, params)

  # Augment the images with rotation.
  if "rotation" in self._self_supervision:
    # Put all rotation angles in a single batch, the first batch_size are
    # the original up-right images, followed by rotated_batch_size * 3
    # rotated images with 3 different angles.
    assert num_rotated_examples <= bs, (num_rotated_examples, bs)
    images_rotated = utils.rotate_images(
        images[-num_rotated_examples:], rot90_scalars=(1, 2, 3))
    generated_rotated = utils.rotate_images(
        generated[-num_rotated_examples:], rot90_scalars=(1, 2, 3))
    # Labels for rotation loss (unrotated and 3 rotated versions). For
    # NUM_ROTATIONS=4 and num_rotated_examples=2 this is:
    # [0, 0, 1, 1, 2, 2, 3, 3]
    rotate_labels = tf.constant(
        np.repeat(np.arange(NUM_ROTATIONS, dtype=np.int32),
                  num_rotated_examples))
    rotate_labels_onehot = tf.one_hot(rotate_labels, NUM_ROTATIONS)
    all_images = tf.concat(
        [images, images_rotated, generated, generated_rotated], 0)
    if self.conditional:
      y_rotated = tf.tile(y[-num_rotated_examples:], [3, 1])
      # Rotated fake images must keep the labels of the fake images they
      # were rotated from, i.e. the sampled labels and not the real labels.
      sampled_y_rotated = tf.tile(
          sampled_y[-num_rotated_examples:], [3, 1])
      all_y = tf.concat([y, y_rotated, sampled_y, sampled_y_rotated], 0)
  else:
    all_images = tf.concat([images, generated], 0)
    if self.conditional:
      all_y = tf.concat([y, sampled_y], axis=0)

  # Compute discriminator output for real and fake images in one batch.
  d_all, d_all_logits, c_all_logits = self.discriminator_with_rotation_head(
      all_images, y=all_y, is_training=is_training)
  d_real, d_fake = tf.split(d_all, 2)
  d_real_logits, d_fake_logits = tf.split(d_all_logits, 2)
  c_real_logits, c_fake_logits = tf.split(c_all_logits, 2)

  # Separate the true/fake scores from whole rotation batch: only the first
  # `bs` entries of each half belong to the un-rotated images.
  d_real_logits = d_real_logits[:bs]
  d_fake_logits = d_fake_logits[:bs]
  d_real = d_real[:bs]
  d_fake = d_fake[:bs]

  self.d_loss, _, _, self.g_loss = loss_lib.get_losses(
      d_real=d_real, d_fake=d_fake, d_real_logits=d_real_logits,
      d_fake_logits=d_fake_logits)

  penalty_loss = penalty_lib.get_penalty_loss(
      x=images, x_fake=generated, y=y, is_training=is_training,
      discriminator=self.discriminator, architecture=self._architecture)
  self.d_loss += self._lambda * penalty_loss

  # Add rotation augmented loss.
  if "rotation" in self._self_supervision:
    # We take an even piece for every rotation angle: the last `rotated_bs`
    # rows are the to-be-rotated originals followed by their 3 rotations,
    # which lines up with `rotate_labels`.
    assert len(c_real_logits.shape.as_list()) == 2, c_real_logits.shape
    assert len(c_fake_logits.shape.as_list()) == 2, c_fake_logits.shape
    c_real_logits = c_real_logits[-rotated_bs:]
    c_fake_logits = c_fake_logits[-rotated_bs:]
    preds_onreal = tf.cast(tf.argmax(c_real_logits, -1), rotate_labels.dtype)
    accuracy = tf.reduce_mean(
        tf.cast(tf.equal(rotate_labels, preds_onreal), tf.float32))
    c_real_probs = tf.nn.softmax(c_real_logits)
    c_fake_probs = tf.nn.softmax(c_fake_logits)
    c_real_loss = -tf.reduce_mean(
        tf.reduce_sum(
            rotate_labels_onehot * tf.log(c_real_probs + 1e-10), 1))
    c_fake_loss = -tf.reduce_mean(
        tf.reduce_sum(
            rotate_labels_onehot * tf.log(c_fake_probs + 1e-10), 1))
    if self._self_supervision == "rotation_only":
      # Drop the adversarial and penalty terms; train on rotation alone.
      self.d_loss *= 0.0
      self.g_loss *= 0.0
    self.d_loss += c_real_loss * self._weight_rotation_loss_d
    self.g_loss += c_fake_loss * self._weight_rotation_loss_g
  else:
    c_real_loss = 0.0
    c_fake_loss = 0.0
    accuracy = tf.zeros([])

  self._tpu_summary.scalar("loss/c_real_loss", c_real_loss)
  self._tpu_summary.scalar("loss/c_fake_loss", c_fake_loss)
  self._tpu_summary.scalar("accuracy/d_rotation", accuracy)
  self._tpu_summary.scalar("loss/penalty", penalty_loss)
def create_loss(self, features, labels, params, is_training=True):
  """Build the loss tensors for discriminator, generator and aux network.

  This method will set self.d_loss, self.g_loss, self.aux_loss and
  self.aux_acc. It also writes the float-cast "random_mask" back into
  `features` under the key "labels_aux".

  Args:
    features: Dictionary with inputs to the model. This method reads
      "images" (real images), "generated" (fake images), "eps_generated"
      (the second batch of fake images), "random_mask" (targets for the
      auxiliary network) and, for conditional models, "sampled_labels".
    labels: Tensor with labels. Converted with
      self._get_one_hot_labels(labels) when the model is conditional.
    params: Dictionary with hyperparameters passed to TPUEstimator. It is
      not read by this method.
    is_training: If True build the model in training mode. If False build the
      model for inference mode (e.g. use trained averages for batch norm).

  Raises:
    ValueError: If `self._choice_of_f` is neither 'subtract' nor 'concat'.
  """
  images = features["images"]  # Real images.
  generated = features["generated"]  # Fake images.
  eps_generated = features["eps_generated"]  # Second batch of fake images.
  eps_bs = generated.get_shape().as_list()[0]
  half_bs = eps_bs // 2
  if self.conditional:
    y = self._get_one_hot_labels(labels)
    sampled_y = self._get_one_hot_labels(features["sampled_labels"])
    all_y = tf.concat([y, sampled_y, sampled_y], axis=0)
  else:
    y = None
    sampled_y = None
    all_y = None

  if self._deprecated_split_disc_calls:
    with tf.name_scope("disc_for_real"):
      d_real, d_real_logits, _ = self.discriminator(
          images, y=y, is_training=is_training)
    with tf.name_scope("disc_for_fake"):
      d_fake, d_fake_logits, D_in_op_g = self.discriminator(
          generated, y=sampled_y, is_training=is_training)
    with tf.name_scope("disc_eps_for_fake"):
      d_fake_eps, d_fake_logits_eps, D_in_op_eg = self.discriminator(
          eps_generated, y=sampled_y, is_training=is_training)
  else:
    # Compute discriminator output for real and both fake batches at once.
    all_images = tf.concat([images, generated, eps_generated], axis=0)
    d_all, d_all_logits, D_in_op_all = self.discriminator(
        all_images, y=all_y, is_training=is_training)
    d_real, d_fake, d_fake_eps = tf.split(d_all, 3)
    d_real_logits, d_fake_logits, d_fake_logits_eps = tf.split(
        d_all_logits, 3)
    _, D_in_op_g, D_in_op_eg = tf.split(D_in_op_all, 3)

  # The discriminator sees half a batch from each of the two fake sets.
  self.d_loss, _, _, _ = loss_lib.get_losses(
      d_real=d_real,
      d_fake=tf.concat([d_fake[:half_bs], d_fake_eps[:half_bs]], axis=0),
      d_real_logits=d_real_logits,
      d_fake_logits=tf.concat(
          [d_fake_logits[:half_bs], d_fake_logits_eps[:half_bs]], axis=0))

  # The generator loss is the sum of the losses on both full fake batches.
  _, _, _, g_loss = loss_lib.get_losses(
      d_real=None, d_fake=d_fake, d_real_logits=None,
      d_fake_logits=d_fake_logits)
  _, _, _, g_loss_eps = loss_lib.get_losses(
      d_real=None, d_fake=d_fake_eps, d_real_logits=None,
      d_fake_logits=d_fake_logits_eps)
  self.g_loss = g_loss + g_loss_eps

  penalty_loss = penalty_lib.get_penalty_loss(
      x=images, x_fake=generated, y=y, is_training=is_training,
      discriminator=self.discriminator)
  self.d_loss += self._lambda * penalty_loss

  # Flatten the third discriminator output to one row per example.
  D_in_op_g = tf.reshape(D_in_op_g, [eps_bs, -1], name="reshape_g")
  D_in_op_eg = tf.reshape(D_in_op_eg, [eps_bs, -1], name="reshape_eg")

  if self._choice_of_f == 'subtract':
    f = tf.subtract(D_in_op_g, D_in_op_eg)
  elif self._choice_of_f == 'concat':
    f = tf.concat([D_in_op_g, D_in_op_eg], axis=1)
  else:
    # Without this branch `f` would be undefined and fail with a NameError.
    raise ValueError(
        "Unsupported choice_of_f: {!r}; expected 'subtract' or "
        "'concat'.".format(self._choice_of_f))

  pred_eps = self.aux_network(f)
  features["labels_aux"] = tf.cast(
      features["random_mask"], tf.float32, name="labels")
  self.aux_loss = tf.losses.sigmoid_cross_entropy(
      features["labels_aux"], pred_eps)
  # Accuracy of the thresholded (rounded sigmoid) auxiliary predictions.
  self.aux_acc = tf.reduce_mean(
      tf.cast(
          tf.equal(features["labels_aux"], tf.round(tf.sigmoid(pred_eps))),
          dtype=tf.float32))
  self.g_loss += self._lambda_bce_loss * self.aux_loss
def create_loss(self, features, labels, params, is_training=True,
                reuse=False):
  """Assemble the discriminator and generator loss tensors.

  The results are stored in `self.d_loss` and `self.g_loss`; nothing is
  returned. The penalty loss is also reported as the "loss/penalty" summary.

  Args:
    features: Dictionary with inputs to the model. This method reads
      "images" (input images), "sampled_labels" for conditional models, and
      either "generated" (precomputed fake images) or "z" (generator noise),
      depending on `self._experimental_joint_gen_for_disc`.
    labels: Tensor with labels. It is converted with
      self._get_one_hot_labels(labels) when the model is conditional.
    params: Dictionary with hyperparameters passed to TPUEstimator. It is
      not read by this method.
    is_training: Forwarded to the generator, the discriminator and the
      penalty loss to select training or inference mode.
    reuse: Bool, whether to reuse existing variables for the models. This is
      only used for unrolling discriminator iterations when training on TPU.

  Raises:
    ValueError: Not raised directly here; presumably propagated from the
      loss/penalty helpers for unsupported hyperparameters (unconfirmed).
  """
  real_images = features["images"]

  y = sampled_y = all_y = None
  if self.conditional:
    y = self._get_one_hot_labels(labels)
    sampled_y = self._get_one_hot_labels(features["sampled_labels"])
    all_y = tf.concat([y, sampled_y], axis=0)

  if not self._experimental_joint_gen_for_disc:
    # Fake images were not provided; run the generator on the noise here.
    logging.warning("Computing fake images for sub step separately.")
    z = features["z"]
    fake_images = self.generator(
        z, y=sampled_y, is_training=is_training, reuse=reuse)
  else:
    assert "generated" in features
    fake_images = features["generated"]

  if not self._deprecated_split_disc_calls:
    # One discriminator pass over the stacked real + fake batch.
    stacked_images = tf.concat([real_images, fake_images], axis=0)
    d_all, d_all_logits, _ = self.discriminator(
        stacked_images, y=all_y, is_training=is_training, reuse=reuse)
    d_real, d_fake = tf.split(d_all, 2)
    d_real_logits, d_fake_logits = tf.split(d_all_logits, 2)
  else:
    with tf.name_scope("disc_for_real"):
      d_real, d_real_logits, _ = self.discriminator(
          real_images, y=y, is_training=is_training, reuse=reuse)
    # The second call always reuses the variables of the first call.
    with tf.name_scope("disc_for_fake"):
      d_fake, d_fake_logits, _ = self.discriminator(
          fake_images, y=sampled_y, is_training=is_training, reuse=True)

  d_loss, _, _, g_loss = loss_lib.get_losses(
      d_real=d_real,
      d_fake=d_fake,
      d_real_logits=d_real_logits,
      d_fake_logits=d_fake_logits)
  self.d_loss = d_loss
  self.g_loss = g_loss

  # The penalty helper receives a discriminator with the real labels bound.
  labelled_discriminator = functools.partial(self.discriminator, y=y)
  penalty_loss = penalty_lib.get_penalty_loss(
      x=real_images,
      x_fake=fake_images,
      is_training=is_training,
      discriminator=labelled_discriminator,
      architecture=self._architecture)
  self.d_loss = self.d_loss + self._lambda * penalty_loss
  self._tpu_summary.scalar("loss/penalty", penalty_loss)