def val_d(self, x_real):
     """Compute the discriminator loss on one real batch without training.

     Both ``self.G`` and ``self.D`` are called with ``training=False`` and no
     gradients are applied. The returned value is
     ``ops.d_loss_fn(fake_logits, real_logits)`` plus
     ``self.grad_penalty_weight`` times ``self.gradient_penalty`` evaluated on
     ``x_real`` and a freshly generated fake batch.

     Args:
         x_real: A batch of real samples fed directly to ``self.D``. The fake
             batch always has ``self.batch_size`` samples, so ``x_real`` is
             presumably the same length — TODO confirm.

     Returns:
         The discriminator loss from ``ops.d_loss_fn`` plus the weighted
         gradient penalty.
     """
     # ``random`` is presumably ``tensorflow.random`` imported at file level.
     latent = random.normal((self.batch_size, 1, 1, self.z_dim))
     generated = self.G(latent, training=False)
     # Bind D in inference mode once; the same callable scores both batches
     # and is handed to the gradient-penalty helper.
     critic = partial(self.D, training=False)

     fake_scores = critic(generated)
     real_scores = critic(x_real)
     loss = ops.d_loss_fn(fake_scores, real_scores)

     penalty = self.gradient_penalty(critic, x_real, generated)
     loss += self.grad_penalty_weight * penalty
     return loss
 def train_d(self, x_real):
     """Run one discriminator update step using DiffAugment.

     Fake images are generated from fresh latent noise. Fakes and reals are
     each passed through ``DiffAugment(..., policy=self.policy)`` before being
     scored by ``self.D`` in training mode. The loss is
     ``ops.d_loss_fn(fake_logits, real_logits)`` plus
     ``self.grad_penalty_weight`` times ``self.gradient_penalty``. The penalty
     is computed on the unaugmented real and fake batches. Gradients are taken
     only with respect to ``self.D.trainable_variables`` and applied with
     ``self.d_opt``.

     Args:
         x_real: A batch of real samples. Presumably ``self.batch_size`` long,
             so it lines up with the fake batch — TODO confirm.

     Returns:
         The total discriminator loss for this step.
     """
     latent = random.normal((self.batch_size, 1, 1, self.z_dim))
     critic = partial(self.D, training=True)
     with tf.GradientTape() as tape:
         # G runs under the tape, but only D's trainable variables are
         # differentiated and passed to the optimizer below.
         generated = self.G(latent, training=True)
         fake_scores = critic(DiffAugment(generated, policy=self.policy))
         real_scores = critic(DiffAugment(x_real, policy=self.policy))
         loss = ops.d_loss_fn(fake_scores, real_scores)
         # NOTE(review): the penalty uses the raw (unaugmented) images, unlike
         # the logits above — confirm this is intended.
         penalty = self.gradient_penalty(critic, x_real, generated)
         loss += self.grad_penalty_weight * penalty
     d_vars = self.D.trainable_variables
     grads = tape.gradient(loss, d_vars)
     self.d_opt.apply_gradients(zip(grads, d_vars))
     return loss
# Example #3
 def train_d(self, x_real, image_scale=255.0):
     """Run one discriminator update step, augmenting only the real batch.

     Fake images generated from fresh latent noise are scored by ``self.D``
     without augmentation. The real batch is first passed through
     ``self.Augment`` with ``scale=image_scale`` and
     ``batch_shape=[self.batch_size, *self.image_shape]``. The loss is
     ``ops.d_loss_fn(fake_logits, real_logits)`` plus
     ``self.grad_penalty_weight`` times ``self.gradient_penalty``, evaluated on
     the augmented real batch and the fake batch. Only
     ``self.D.trainable_variables`` are updated, via ``self.d_opt``.

     Args:
         x_real: A batch of real images. Presumably ``self.batch_size`` long,
             matching the ``batch_shape`` given to ``self.Augment`` — TODO
             confirm.
         image_scale: Forwarded to ``self.Augment`` as ``scale``. The 255.0
             default suggests pixel values in [0, 255] — TODO confirm.

     Returns:
         The total discriminator loss for this step.
     """
     latent = random.normal((self.batch_size, 1, 1, self.z_dim))
     critic = partial(self.D, training=True)
     with tf.GradientTape() as tape:
         generated = self.G(latent, training=True)
         # NOTE(review): fakes are scored without augmentation, while reals go
         # through self.Augment. An earlier variant also augmented the fakes
         # with self.Augment. Augmenting only reals may let D learn the
         # augmentations as a "real" cue — confirm this asymmetry is intended.
         fake_scores = critic(generated)
         augmented_real = self.Augment(
             images=x_real,
             scale=image_scale,
             batch_shape=[self.batch_size, *self.image_shape],
         )
         real_scores = critic(augmented_real)
         loss = ops.d_loss_fn(fake_scores, real_scores)
         # The penalty is taken on the augmented reals, i.e. exactly what D saw.
         penalty = self.gradient_penalty(critic, augmented_real, generated)
         loss += self.grad_penalty_weight * penalty
     d_vars = self.D.trainable_variables
     grads = tape.gradient(loss, d_vars)
     self.d_opt.apply_gradients(zip(grads, d_vars))
     return loss