def train_critic(self, labels, images):
    """Train the critic on one batch of data.

    The generator is run in inference mode to produce fake images for the
    given labels. The critic then scores the real and fake image/label
    concatenations, and one optimiser step is applied to the critic's
    trainable variables.

    Args:
        labels (tf.Tensor): One-hot encoded image labels
        images (tf.Tensor): Images

    Returns:
        tf.Tensor: Critic loss for the step
    """
    # One noise row per label row, drawn uniformly from [0, 1).
    noise = tf.random.uniform(
        (tf.shape(labels)[0], self.noise_dims), 0, 1, tf.float32)
    concat_real = data_utils.concatenate_images_labels(images, labels)

    # Only the critic is differentiated in this step. The generator runs in
    # inference mode and its weights are not updated, so its forward pass
    # (and the fake concatenation) need not be recorded on the tape.
    generated_images = self.generator([labels, noise], training=False)
    concat_fake = data_utils.concatenate_images_labels(
        generated_images, labels)

    # A single tape.gradient() call is made below, so a non-persistent tape
    # is sufficient; its resources are released right after that call.
    with tf.GradientTape() as tape:
        real_output = self.critic(concat_real, training=True)
        fake_output = self.critic(concat_fake, training=True)
        critic_loss_val = self.critic_loss(real_output, fake_output)

    critic_grads = tape.gradient(critic_loss_val,
                                 self.critic.trainable_variables)
    self.critic_optimizer.apply_gradients(
        zip(critic_grads, self.critic.trainable_variables))
    return critic_loss_val
def take_generator_step(self):
    """Take one generator training step and record its metrics.

    Overrides the parent-class step so that images are concatenated with
    their labels before being passed to the critic.

    A batch is sampled and the generator is trained on its labels. Then the
    real batch, and predictions made after that update, are scored by the
    critic in inference mode. The generator loss is appended to
    ``self.generator_losses``. The negated critic loss is appended to
    ``self.wass_estimates``.
    """
    batch_labels, batch_images = self.sample_batch_of_data()
    real_batch = data_utils.concatenate_images_labels(
        batch_images, batch_labels)

    step_loss = self.model.train_generator(batch_labels)
    fake_batch = data_utils.concatenate_images_labels(
        self.model.make_generator_predictions(batch_labels), batch_labels)
    self.generator_losses.append(step_loss)

    # Score the real batch first, then the fake batch, both without
    # training-mode behaviour in the critic.
    real_score, fake_score = (
        self.model.critic(batch, training=False)
        for batch in (real_batch, fake_batch)
    )
    self.wass_estimates.append(
        -self.model.critic_loss(real_score, fake_score))
def train_generator(self, labels):
    """Run one optimisation step on the generator.

    Noise is sampled and images are generated for ``labels`` in training
    mode. The critic scores the image/label concatenation in inference
    mode. The generator loss gradients are then applied to the generator's
    trainable variables only.

    Args:
        labels (tf.Tensor): One-hot encoded image labels

    Returns:
        tf.Tensor: Generator loss for the step
    """
    batch_size = tf.shape(labels)[0]
    # Latent input drawn uniformly from [0, 1).
    latent = tf.random.uniform(
        (batch_size, self.noise_dims), 0, 1, tf.float32)

    with tf.GradientTape() as tape:
        fake_images = self.generator([labels, latent], training=True)
        critic_scores = self.critic(
            data_utils.concatenate_images_labels(fake_images, labels),
            training=False)
        loss = self.generator_loss(critic_scores)

    # Read the variables only after the forward pass so that a lazily built
    # generator already has its weights created.
    gen_vars = self.generator.trainable_variables
    grads = tape.gradient(loss, gen_vars)
    self.generator_optimizer.apply_gradients(zip(grads, gen_vars))
    return loss