    def discriminator_trainstep(self, x_real, y, x_fake0):
        toggle_grad_D(self.discriminator, True, self.D_fix_layer)
        self.generator.train()
        self.discriminator.train()
        self.d_optimizer.zero_grad()

        # On real data
        x_real.requires_grad_()  # needed so the gradient penalty below can differentiate w.r.t. the real inputs

        d_real = self.discriminator(x_real, y)
        dloss_real = self.compute_loss(d_real, 1)

        if self.reg_type == 'real' or self.reg_type == 'real_fake':
            dloss_real.backward(retain_graph=True)
            reg = self.reg_param * compute_grad2(d_real, x_real).mean()
            reg.backward()
        else:
            dloss_real.backward()

        # On fake data (x_fake0 is generated outside this step and arrives
        # detached from the generator graph; input gradients are re-enabled so
        # the fake-side gradient penalty can be computed)
        x_fake0.requires_grad_()
        d_fake = self.discriminator(x_fake0, y)
        dloss_fake = self.compute_loss(d_fake, 0)

        if self.reg_type == 'fake' or self.reg_type == 'real_fake':
            dloss_fake.backward(retain_graph=True)
            reg = self.reg_param * compute_grad2(d_fake, x_fake0).mean()
            reg.backward()
        else:
            dloss_fake.backward()

        if self.reg_type == 'wgangp':
            reg = self.reg_param * self.wgan_gp_reg(x_real, x_fake0, y)
            reg.backward()
        elif self.reg_type == 'wgangp0':
            reg = self.reg_param * self.wgan_gp_reg(x_real, x_fake0, y, center=0.)
            reg.backward()

        self.d_optimizer.step()

        toggle_grad_D(self.discriminator, False, self.D_fix_layer)

        # Output
        dloss = (dloss_real + dloss_fake)

        if self.reg_type == 'none':
            reg = torch.tensor(0.)

        return dloss.item(), reg.item()
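
(Aside: compute_grad2, used for the penalties above, is not part of this class and is not shown in the excerpt. A minimal module-level sketch, assuming it returns the per-sample squared norm of the discriminator gradient with respect to its input, as in R1/R2-style regularization, could look like the following; the actual helper may differ.)

from torch import autograd

def compute_grad2(d_out, x_in):
    # Per-sample squared L2 norm of grad(d_out) with respect to the input batch.
    batch_size = x_in.size(0)
    grad_dout = autograd.grad(
        outputs=d_out.sum(), inputs=x_in,
        create_graph=True, retain_graph=True, only_inputs=True
    )[0]
    assert grad_dout.size() == x_in.size()
    return grad_dout.pow(2).reshape(batch_size, -1).sum(1)
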
    def generator_trainstep(self, y, z, FLAG=500):
        assert y.size(0) == z.size(0)
        # toggle_grad(self.generator, True)
        toggle_grad_D(self.discriminator, False, self.D_fix_layer)
        self.generator.train()
        self.discriminator.train()
        self.g_optimizer.zero_grad()

        x_fake, loss_w = self.generator(z, y)
        d_fake = self.discriminator(x_fake, y)
        gloss = self.compute_loss(d_fake, 1)
        gloss.backward()

        self.g_optimizer.step()
        # print('loss_w:---', loss_w)
        return gloss.item(), x_fake.detach()
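
(Aside: toggle_grad_D, called at the start and end of the train steps above, is also not defined in this excerpt. A minimal sketch, assuming the fix-layer argument is an iterable of parameter-name prefixes that must stay frozen; the real helper may identify the fixed layers differently.)

def toggle_grad_D(model, requires_grad, fix_layer=()):
    # Switch requires_grad for all discriminator parameters except those whose
    # names match a "fixed" prefix; fixed layers always stay frozen.
    fix_layer = fix_layer or ()
    for name, param in model.named_parameters():
        if any(name.startswith(prefix) for prefix in fix_layer):
            param.requires_grad_(False)
        else:
            param.requires_grad_(requires_grad)

Keeping some discriminator layers permanently frozen lets the train steps fine-tune only the remaining layers while the rest keep their pre-trained weights.
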
Example #3
        discriminator = model_equal_part(discriminator, dict_D)

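        # Fine-tune only the generator's 'small' modules; every other parameter,
        # including the 'small_adafm_' ones, is frozen below.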
        for name, param in generator.named_parameters():
            if name.find('small') >= 0:
                param.requires_grad = True
            else:
                param.requires_grad = False
            if name.find('small_adafm_') >= 0:
                param.requires_grad = False
        get_parameter_number(generator)  # presumably logs total vs. trainable parameter counts

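        # Freeze the whole discriminator first; toggle_grad_D below presumably
        # re-enables gradients for the layers that are not listed in D_Layer_FIX.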
        for param in discriminator.parameters():
            param.requires_grad = False

        #toggle_grad_G(generator, True, G_Layer_FIX)
        toggle_grad_D(discriminator, True, D_Layer_FIX)

        # Put models on gpu if needed
        generator, discriminator = generator.to(device), discriminator.to(
            device)
        g_optimizer, d_optimizer = build_optimizers(generator, discriminator,
                                                    config)

        # summary(generator, input_size=[(256,), (1,)])
        # summary(discriminator, input_size=[(3, 128, 128), (1,)])

        # Register modules to checkpoint
        checkpoint_io.register_modules(
            generator=generator,
            discriminator=discriminator,
            g_optimizer=g_optimizer,