def _optimize_generator(self, fake_images, prev_image, condition, objects,
                            aux_reg, mask, mu, logvar, seg_reg=0,
                            seg_fake=None, seg_gt=None):
        """Run one generator update and report the loss and gradient norm.

        The fake images are scored by the discriminator (together with
        ``condition`` and ``prev_image``), the scores are turned into a loss
        by ``_generator_masked_loss``, and the generator optimizer takes one
        step on the resulting gradients.

        Args:
            fake_images: Passed to the discriminator as the images to score.
            prev_image: Passed to the discriminator alongside ``fake_images``.
            condition: Passed to the discriminator alongside ``fake_images``.
            objects, aux_reg, mask, mu, logvar, seg_reg, seg_fake, seg_gt:
                Forwarded unchanged to ``_generator_masked_loss``.

        Returns:
            A ``(loss, grad_norm)`` tuple, each obtained via ``.item()``.
        """
        self.generator.zero_grad()

        fake_score, fake_aux, _ = self.discriminator(
            fake_images, condition, prev_image)
        loss = self._generator_masked_loss(
            fake_score, fake_aux, objects, aux_reg, mu, logvar, mask,
            seg_reg, seg_fake, seg_gt)

        # The autograd graph is kept alive after this backward pass.
        # NOTE(review): presumably a later backward pass reuses it -- confirm.
        loss.backward(retain_graph=True)
        grad_norm = _recurrent_gan.get_grad_norm(self.generator.parameters())
        self.generator_optimizer.step()

        # Extract plain values first, then drop the tensor references so the
        # explicit collection below can reclaim them.
        summary = (loss.item(), grad_norm.item())
        del loss, grad_norm
        gc.collect()

        return summary
Exemple #2
0
    def _optimize_rnn(self):
        """Step the RNN, sentence-encoder and feature-encoder optimizers.

        No backward pass is run here; the method only clips, measures,
        applies and then zeroes the gradients already stored on the modules.

        Returns:
            Tuple ``(rnn_grad_norm, gru_grad_norm, ce_grad_norm,
            ie_grad_norm)`` holding the values returned by
            ``_recurrent_gan.get_grad_norm`` for, respectively, ``self.rnn``,
            ``self.sentence_encoder``, ``self.condition_encoder`` and
            ``self.image_encoder``.
        """
        # RNN: clip first, so the norm reported below is measured after
        # clipping has been applied.
        torch.nn.utils.clip_grad_norm_(self.rnn.parameters(),
                                       self.cfg.grad_clip)
        rnn_grad_norm = _recurrent_gan.get_grad_norm(self.rnn.parameters())
        self.rnn_optimizer.step()
        self.rnn.zero_grad()

        # Sentence encoder: same clip -> measure -> step -> zero sequence.
        torch.nn.utils.clip_grad_norm_(self.sentence_encoder.parameters(),
                                       self.cfg.grad_clip)
        gru_grad_norm = _recurrent_gan.get_grad_norm(
            self.sentence_encoder.parameters())
        self.sentence_encoder_optimizer.step()
        self.sentence_encoder.zero_grad()

        # Condition and image encoders are stepped by one shared optimizer.
        # NOTE(review): unlike the two modules above, no gradient clipping is
        # applied here -- confirm this is intentional.
        ce_grad_norm = _recurrent_gan.get_grad_norm(
            self.condition_encoder.parameters())
        ie_grad_norm = _recurrent_gan.get_grad_norm(
            self.image_encoder.parameters())
        self.feature_encoders_optimizer.step()
        self.condition_encoder.zero_grad()
        self.image_encoder.zero_grad()

        return rnn_grad_norm, gru_grad_norm, ce_grad_norm, ie_grad_norm
    def _optimize_discriminator(self,
                                real_images,
                                fake_images,
                                prev_image,
                                condition,
                                mask,
                                objects,
                                gp_reg=0,
                                aux_reg=0):
        """Run one discriminator update and report its statistics.

        Per the original note, the discriminator is updated on every step,
        independently of the RNN and the generator.

        Args:
            real_images: Real batch. ``requires_grad_()`` is called on it in
                place, so the caller's tensor is modified.
            fake_images: Generated batch scored against ``condition``.
            prev_image: Passed to the discriminator with the real and fake
                batches.
            condition: Conditioning input passed to every discriminator call.
            mask: Forwarded to the masked loss and gradient-penalty helpers.
            objects: Forwarded to ``_discriminator_masked_loss``.
            gp_reg: Gradient-penalty weight; the penalty is skipped when falsy.
            aux_reg: Forwarded to ``_discriminator_masked_loss``.

        Returns:
            Tuple ``(d_loss, d_real, d_fake, aux_loss, grad_norm)`` where
            ``d_loss`` and ``grad_norm`` come from ``.item()``, ``d_real`` and
            ``d_fake`` are NumPy arrays, and ``aux_loss`` comes from
            ``.item()`` when it is a tensor and is returned as-is otherwise.
        """
        # Rotate the batch by one along dim 0 so every condition is scored
        # against an image (and previous image) taken from another sample.
        wrong_images = torch.cat((real_images[1:], real_images[0:1]), dim=0)
        wrong_prev = torch.cat((prev_image[1:], prev_image[0:1]), dim=0)

        self.discriminator.zero_grad()
        # real_images is later handed to _masked_gradient_penalty together
        # with d_real; presumably the penalty differentiates w.r.t. the input.
        real_images.requires_grad_()

        d_real, aux_real, _ = self.discriminator(real_images, condition,
                                                 prev_image)
        d_fake, aux_fake, _ = self.discriminator(fake_images, condition,
                                                 prev_image)
        d_wrong, _, _ = self.discriminator(wrong_images, condition, wrong_prev)

        d_loss, aux_loss = self._discriminator_masked_loss(
            d_real, d_fake, d_wrong, aux_real, aux_fake, objects, aux_reg,
            mask)

        # The graph is retained because the optional penalty below is built
        # from the same d_real output.
        d_loss.backward(retain_graph=True)
        if gp_reg:
            reg = gp_reg * self._masked_gradient_penalty(
                d_real, real_images, mask)
            reg.backward(retain_graph=True)

        grad_norm = _recurrent_gan.get_grad_norm(
            self.discriminator.parameters())
        self.discriminator_optimizer.step()

        d_loss_scalar = d_loss.item()
        # detach() rather than the legacy .data attribute: same values, but
        # detaching before the device copy keeps autograd out of the export.
        d_real_np = d_real.detach().cpu().numpy()
        d_fake_np = d_fake.detach().cpu().numpy()
        aux_loss_scalar = aux_loss.item() if isinstance(
            aux_loss, torch.Tensor) else aux_loss
        grad_norm_scalar = grad_norm.item()

        # Drop tensor references before forcing a collection.
        del d_loss
        del d_real
        del d_fake
        del aux_loss
        del grad_norm
        gc.collect()

        return (d_loss_scalar, d_real_np, d_fake_np, aux_loss_scalar,
                grad_norm_scalar)