Example #1
0
class Trainer(nn.Module):
    """Adversarial trainer for a pair of VAE generators (domains a and b).

    Owns the two auto-encoders (``gen_a``/``gen_b``), the two
    discriminators (``dis_a``/``dis_b``), their Adam optimizers and LR
    schedulers, and implements the alternating generator / discriminator
    update steps for cross-domain translation.
    """

    def __init__(self, hyperparameters):
        """Build networks, optimizers and schedulers from `hyperparameters`.

        Expected keys: 'lr', 'weight_decay', 'init', 'input_dim',
        'basis_encoder_dims', 'trait_encoder_dims', 'decoder_dims',
        'dis_dims' and 'gen' (with sub-key 'trait_dim').
        """
        super(Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.trait_dim = hyperparameters['gen']['trait_dim']

        # auto-encoder for domain a
        self.gen_a = VAEGen(hyperparameters['input_dim'],
                            hyperparameters['basis_encoder_dims'],
                            hyperparameters['trait_encoder_dims'],
                            hyperparameters['decoder_dims'], self.trait_dim)
        # auto-encoder for domain b
        self.gen_b = VAEGen(hyperparameters['input_dim'],
                            hyperparameters['basis_encoder_dims'],
                            hyperparameters['trait_encoder_dims'],
                            hyperparameters['decoder_dims'], self.trait_dim)
        # discriminator for domain a
        self.dis_a = Discriminator(hyperparameters['input_dim'],
                                   hyperparameters['dis_dims'], 1)
        # discriminator for domain b
        self.dis_b = Discriminator(hyperparameters['input_dim'],
                                   hyperparameters['dis_dims'], 1)

        # fix the noise used in sampling: 8 fixed trait codes per domain
        # NOTE(review): these are 4-D (8, trait_dim, 1, 1) while
        # gen_update/dis_update sample 2-D (batch, trait_dim) codes --
        # confirm the decoder accepts both shapes.
        self.trait_a = torch.randn(8, self.trait_dim, 1, 1)
        self.trait_b = torch.randn(8, self.trait_dim, 1, 1)

        # Setup the optimizers (only trainable parameters).
        # Removed a leftover debug loop that printed every generator
        # parameter shape at construction time.
        dis_params = list(self.dis_a.parameters()) + \
            list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + \
            list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.gen_a.apply(weights_init('gaussian'))
        self.gen_b.apply(weights_init('gaussian'))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

    def recon_criterion(self, input, target):
        """Mean absolute (L1) reconstruction error between two tensors."""
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        """Translate a batch across domains using the fixed trait codes.

        Returns (x_ab, x_b a-translated) i.e. ``(x_ab, x_ba)``.
        Temporarily switches to eval mode, restores train mode on exit.
        """
        self.eval()
        trait_a = Variable(self.trait_a)
        trait_b = Variable(self.trait_b)
        basis_a, trait_a_fake = self.gen_a.encode(x_a)
        basis_b, trait_b_fake = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(basis_b, trait_a)
        x_ab = self.gen_b.decode(basis_a, trait_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        """One generator step: within-domain reconstruction, latent
        reconstruction, optional image cycle-consistency, and adversarial
        losses, each weighted by the corresponding ``*_w`` hyperparameter.
        Stores the individual loss terms on ``self`` for logging.
        """
        self.gen_opt.zero_grad()
        trait_a = Variable(torch.randn(x_a.size(0), self.trait_dim))
        trait_b = Variable(torch.randn(x_b.size(0), self.trait_dim))
        # encode
        basis_a, trait_a_prime = self.gen_a.encode(x_a)
        basis_b, trait_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(basis_a, trait_a_prime)
        x_b_recon = self.gen_b.decode(basis_b, trait_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(basis_b, trait_a)
        x_ab = self.gen_b.decode(basis_a, trait_b)
        # encode again
        basis_b_recon, trait_a_recon = self.gen_a.encode(x_ba)
        basis_a_recon, trait_b_recon = self.gen_b.encode(x_ab)
        # decode again (only when the cycle loss is enabled)
        x_aba = self.gen_a.decode(
            basis_a_recon,
            trait_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(
            basis_b_recon,
            trait_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_trait_a = self.recon_criterion(
            trait_a_recon, trait_a)
        self.loss_gen_recon_trait_b = self.recon_criterion(
            trait_b_recon, trait_b)
        self.loss_gen_recon_basis_a = self.recon_criterion(
            basis_a_recon, basis_a)
        self.loss_gen_recon_basis_b = self.recon_criterion(
            basis_b_recon, basis_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss: each generator tries to fool its domain discriminator
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)

        # total loss
        self.loss_gen_total = hyperparameters[
            'gan_w'] * self.loss_gen_adv_a + \
            hyperparameters['gan_w'] * self.loss_gen_adv_b + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_trait_w'] * self.loss_gen_recon_trait_a + \
            hyperparameters['recon_basis_w'] * self.loss_gen_recon_basis_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_trait_w'] * self.loss_gen_recon_trait_b + \
            hyperparameters['recon_basis_w'] * self.loss_gen_recon_basis_b + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b

        self.loss_gen_total.backward()
        self.gen_opt.step()

    def dis_update(self, x_a, x_b, hyperparameters):
        """One discriminator step on real batches vs. cross-domain fakes.

        The trait codes used by the fakes are not returned by the
        encoder here — the generators are treated as fixed samplers.
        """
        self.dis_opt.zero_grad()
        trait_a = Variable(torch.randn(x_a.size(0), self.trait_dim))
        trait_b = Variable(torch.randn(x_b.size(0), self.trait_dim))
        # encode
        basis_a, _ = self.gen_a.encode(x_a)
        basis_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(basis_b, trait_a)
        x_ab = self.gen_b.decode(basis_a, trait_b)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba, x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab, x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * \
            self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Advance both LR schedulers (no-op if a scheduler is None)."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore networks and optimizers from the latest checkpoint.

        The iteration count is parsed from the generator checkpoint's
        filename (``gen_%08d.pt``). Returns the resumed iteration.
        """
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        # filename ends in the zero-padded iteration number + '.pt'
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitilize schedulers at the resumed iteration
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Save generators, discriminators, and optimizers.

        Networks are written to iteration-stamped files; the optimizer
        state always overwrites the single 'optimizer.pt'.
        """
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
Example #2
0
def main(args, dataloader):
    """Train a DCGAN on `dataloader` and periodically dump sample grids.

    Args are expected to carry: nz, ngf, ndf, nc, lr, num_epochs,
    log_dir, run_name. Requires a CUDA device (all tensors are moved
    with .cuda()).
    """
    # define the G and D
    netG = DCGenerator(nz=args.nz, ngf=args.ngf, nc=args.nc).cuda()
    netG.apply(weight_init)
    print(netG)

    netD = Discriminator(nc=args.nc, ndf=args.ndf).cuda()
    netD.apply(weight_init)
    print(netD)

    # define the loss criterion
    criterion = nn.BCELoss()

    # sample a fixed noise vector that will be used to visualize the training
    # progress
    fixed_noise = torch.randn(64, args.nz, 1, 1).cuda()

    # define the ground truth labels.
    real_labels = 1  # for the real images
    fake_labels = 0  # for the fake images

    # define the optimizers, one for each network
    netD_optimizer = optim.Adam(params=netD.parameters(), lr=args.lr, betas=(0.5, 0.999))
    netG_optimizer = optim.Adam(params=netG.parameters(), lr=args.lr, betas=(0.5, 0.999))

    # sample two fixed noise vectors and do a linear interpolation between them
    # to get the intermediate noise vectors. We will generate samples for the interpolated
    # noise vectors to see effect of interpolation in the latent space. (See later!)
    z_1 = torch.randn(1, args.nz, 1, 1)
    z_2 = torch.randn(1, args.nz, 1, 1)
    fixed_interpolate = []
    for i in range(64):
        lambda_interp = i / 63
        z_interp = z_1 * (1 - lambda_interp) + lambda_interp * z_2
        fixed_interpolate.append(z_interp)
    fixed_interpolate = torch.cat(fixed_interpolate, dim=0).cuda()


    # Training loop
    iters = 0

    # for each epoch
    for epoch in range(args.num_epochs):
        # iterate through the data loader
        for i, data in enumerate(dataloader, 0):

            ## Discriminator training ##
            # maximize log(D(x)) + log(1 - D(G(x)))

            # The discriminator will be updated once with the real images
            # and once with the fake images. This is achieved by first computing
            # the gradients with the real images (the first term in the D loss function),
            # and then with the fake images generated by the G (second loss term).
            # Only after that the optimizer.step() will be done, which will update the
            # weights of the D.
            # IMPORTANT to note that when the D is updated, the G is kept frozen.
            # Gradients are calculated with loss.backward().

            # train D with real images
            netD.train()
            netD.zero_grad()
            real_images = data[0].cuda()
            bs = real_images.shape[0]
            # bug fix: BCELoss needs float targets, and torch.full with an
            # integer fill value raises on modern PyTorch without an
            # explicit dtype
            label = torch.full((bs,), real_labels, dtype=torch.float).cuda()
            # instance noise added to D inputs, annealed linearly to 0 over
            # training to stabilize the early phase
            noise_1 = torch.Tensor(real_images.shape).normal_(0, 0.1 * (args.num_epochs - epoch) / args.num_epochs).cuda()
            output = netD(real_images + noise_1).view(-1)
            # calculate loss on real images. It pushes the D's output for real images
            # close to 1
            errD_real = criterion(output, label)
            # calculate gradients for D
            errD_real.backward()
            # track D outputs for real images
            D_x = output.mean().item()

            # train D with fake images
            # sample a batch of noise vectors
            noise = torch.randn(bs, args.nz, 1, 1).cuda()
            # generate fake data
            fake_images = netG(noise)
            label.fill_(fake_labels)
            # run the fake images through the discriminator.
            # IMPORTANT to detach the fake_images because we do not need gradients
            # of the G activations wrt to the G weights.
            noise_2 = torch.Tensor(real_images.shape).normal_(0, 0.1 * (args.num_epochs - epoch) / args.num_epochs).cuda()
            output = netD(fake_images.detach() + noise_2).view(-1)
            # calculate loss on the fake images. It pushes the D's output for fake
            # images close to 0
            errD_fake = criterion(output, label)
            # calculate the gradients for D
            errD_fake.backward()
            errD = (errD_real + errD_fake)
            # track D outputs for fake images
            D_G_x_1 = output.mean().item()

            # update the D weights with the gradients accumulated
            netD_optimizer.step()

            ## Generator training ##
            # minimize log(1 - D(G(x)))
            # But such a formulation provides no gradient during the early stages of
            # training and hence its is reformulated as:
            # maximize log(D(G(x)))

            # during the G training the D is kept fixed
            netG.train()
            netG.zero_grad()
            # real_labels because the G wants to make the fake images look as real as
            # possible
            label.fill_(real_labels)
            output = netD(fake_images + noise_2).view(-1)
            # calculate loss for G based on the fake images. It pushes the D's output
            # for fake images close to 1
            errG = criterion(output, label)
            # calculate the gradients for G
            errG.backward()
            # track the outputs for fake images
            D_G_x_2 = output.mean().item()

            # update the G weights with the gradients accumulated
            netG_optimizer.step()

            # print the training losses
            if iters % 50 == 0:
                print('[%3d/%d][%3d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                      % (epoch, args.num_epochs, i, len(dataloader), errD.item(), errG.item(), D_x, D_G_x_1, D_G_x_2))

            # visualize the samples generated by the G.
            if (iters % 1000 == 0):
                out_dir = os.path.join(args.log_dir, args.run_name, 'out/')
                os.makedirs(out_dir, exist_ok=True)
                interp_dir = os.path.join(args.log_dir, args.run_name, 'interpolate/')
                os.makedirs(interp_dir, exist_ok=True)
                netG.eval()
                with torch.no_grad():
                    fake_fixed = netG(fixed_noise).cpu()
                    save_image(fake_fixed, os.path.join(out_dir, str(iters).zfill(7) + '.png'),
                               normalize=True)

                    interp_fixed = netG(fixed_interpolate).cpu()
                    save_image(interp_fixed, os.path.join(interp_dir, str(iters).zfill(7) + '.png'),
                               normalize=True)

            iters += 1
Example #3
0
class BiGAN(object):
    """Bidirectional GAN: generator G(z)->x, encoder E(x)->z, and a joint
    discriminator D(x, z) that separates (x, E(x)) from (G(z), z) pairs.
    """

    def __init__(self, args):
        """Build networks and optimizers; requires a CUDA device."""
        self.z_dim = args.z_dim
        self.decay_rate = args.decay_rate
        self.learning_rate = args.learning_rate
        self.model_name = args.model_name
        self.batch_size = args.batch_size

        # initialize networks
        self.Generator = Generator(self.z_dim).cuda()
        self.Encoder = Encoder(self.z_dim).cuda()
        self.Discriminator = Discriminator().cuda()

        # set optimizers for all networks; G and E share one optimizer
        # because they are trained with a common loss
        self.optimizer_G_E = torch.optim.Adam(
            list(self.Generator.parameters()) +
            list(self.Encoder.parameters()),
            lr=self.learning_rate,
            betas=(0.5, 0.999))

        self.optimizer_D = torch.optim.Adam(self.Discriminator.parameters(),
                                            lr=self.learning_rate,
                                            betas=(0.5, 0.999))

        # initialize network weights
        self.Generator.apply(weights_init)
        self.Encoder.apply(weights_init)
        self.Discriminator.apply(weights_init)

    def train(self, data):
        """Run one BiGAN training step on a batch of real data.

        Bug fixes vs. the original: the method referenced undefined
        names (`z_real`, `x_fake`, `z_fake` without `self.`), assigned
        `nn.BCELoss()` modules to the loss attributes and called
        .backward() on them (a TypeError), and fed a CPU noise tensor to
        the .cuda() generator. The D step now uses detached G/E outputs
        and the G/E step re-scores the joint pairs with gradients
        enabled.
        """
        self.Generator.train()
        self.Encoder.train()
        self.Discriminator.train()

        criterion = nn.BCELoss()

        # get fake z_data for the generator (on the GPU, matching the
        # .cuda() networks)
        self.z_fake = torch.randn(self.batch_size, self.z_dim).cuda()

        # send fake z_data through the generator to get fake x_data
        self.x_fake = self.Generator(self.z_fake)

        # send real data through the encoder to get real z_data
        self.z_real = self.Encoder(data)

        # ---- discriminator update (G and E frozen via detach) ----
        self.optimizer_D.zero_grad()
        self.out_real = self.Discriminator(data, self.z_real.detach())
        self.out_fake = self.Discriminator(self.x_fake.detach(), self.z_fake)
        # D should label joint (x, E(x)) pairs real and (G(z), z) fake
        self.D_loss = criterion(self.out_real,
                                torch.ones_like(self.out_real)) + \
            criterion(self.out_fake, torch.zeros_like(self.out_fake))
        self.D_loss.backward()
        self.optimizer_D.step()

        # ---- generator/encoder update (D held fixed) ----
        self.optimizer_G_E.zero_grad()
        out_real = self.Discriminator(data, self.z_real)
        out_fake = self.Discriminator(self.x_fake, self.z_fake)
        # targets swapped: G wants fakes judged real, E wants reals
        # judged fake
        self.G_E_loss = criterion(out_fake, torch.ones_like(out_fake)) + \
            criterion(out_real, torch.zeros_like(out_real))
        self.G_E_loss.backward()
        self.optimizer_G_E.step()
class LSGANs_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        """Build encoder/decoder, discriminators, style interpolators,
        their optimizers and LR schedulers from `hyperparameters`.

        Expected keys: 'lr', 'beta1', 'beta2', 'weight_decay', 'init',
        'input_dim_a', 'gen' (with sub-key 'style_dim'), and optionally
        'vgg_w' / 'vgg_model_path' to enable the VGG perceptual loss.
        """
        super(LSGANs_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks: one shared encoder/decoder pair plus a
        # per-domain discriminator and a per-direction style interpolator
        self.encoder = Encoder(hyperparameters['input_dim_a'],
                               hyperparameters['gen'])
        self.decoder = Decoder(hyperparameters['input_dim_a'],
                               hyperparameters['gen'])
        self.dis_a = Discriminator()
        self.dis_b = Discriminator()
        self.interp_net_ab = Interpolator()
        self.interp_net_ba = Interpolator()
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # Setup the optimizers: one Adam per sub-network so each can be
        # stepped independently (see dis_update / gen_update)
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        enc_params = list(self.encoder.parameters())
        dec_params = list(self.decoder.parameters())
        dis_a_params = list(self.dis_a.parameters())
        dis_b_params = list(self.dis_b.parameters())
        interperlator_ab_params = list(self.interp_net_ab.parameters())
        interperlator_ba_params = list(self.interp_net_ba.parameters())

        self.enc_opt = torch.optim.Adam(
            [p for p in enc_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dec_opt = torch.optim.Adam(
            [p for p in dec_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_a_opt = torch.optim.Adam(
            [p for p in dis_a_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_b_opt = torch.optim.Adam(
            [p for p in dis_b_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.interp_ab_opt = torch.optim.Adam(
            [p for p in interperlator_ab_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.interp_ba_opt = torch.optim.Adam(
            [p for p in interperlator_ba_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])

        self.enc_scheduler = get_scheduler(self.enc_opt, hyperparameters)
        self.dec_scheduler = get_scheduler(self.dec_opt, hyperparameters)
        self.dis_a_scheduler = get_scheduler(self.dis_a_opt, hyperparameters)
        self.dis_b_scheduler = get_scheduler(self.dis_b_opt, hyperparameters)
        self.interp_ab_scheduler = get_scheduler(self.interp_ab_opt,
                                                 hyperparameters)
        self.interp_ba_scheduler = get_scheduler(self.interp_ba_opt,
                                                 hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed (frozen; used only as a feature
        # extractor for the perceptual loss)
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

        # running loss / best-checkpoint bookkeeping used by resume()
        self.total_loss = 0
        self.best_iter = 0
        self.perceptural_loss = Perceptural_loss()

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        """Translate a batch across domains; returns (x_ab, x_ba).

        Bug fixes: the decoder was called with undefined names
        (`s_a_interp`/`s_b_interp` instead of the computed
        `s_ba_interp`/`s_ab_interp`) and `selfdecoder` was a typo for
        `self.decoder` — the method crashed on every call.

        NOTE(review): `self.v` is only created by dis_update/gen_update;
        calling forward() before either of them will raise — confirm the
        intended call order.
        """
        self.eval()

        c_a, s_a_fake = self.encoder(x_a)
        c_b, s_b_fake = self.encoder(x_b)

        # decode (cross domain) using the interpolated style codes
        s_ab_interp = self.interp_net_ab(s_a_fake, s_b_fake, self.v)
        s_ba_interp = self.interp_net_ba(s_b_fake, s_a_fake, self.v)
        x_ba = self.decoder(c_b, s_ba_interp)
        x_ab = self.decoder(c_a, s_ab_interp)
        self.train()
        return x_ab, x_ba

    def zero_grad(self):
        self.dis_a_opt.zero_grad()
        self.dis_b_opt.zero_grad()
        self.dec_opt.zero_grad()
        self.enc_opt.zero_grad()
        self.interp_ab_opt.zero_grad()
        self.interp_ba_opt.zero_grad()

    def dis_update(self, x_a, x_b, hyperparameters):
        """One discriminator (critic) update on a batch from each domain.

        Builds cross-domain translations x_ba / x_ab through the shared
        encoder/decoder and the style interpolators, then minimizes
        mean(D(fake)) - mean(D(real)) plus a gradient penalty for each
        domain critic, all weighted by hyperparameters['gan_w'].
        Side effects: stores per-term losses on `self`, accumulates
        `self.total_loss`, and sets `self.v` (reused by forward()).
        """
        self.zero_grad()

        # encode
        c_a, s_a = self.encoder(x_a)
        c_b, s_b = self.encoder(x_b)

        # decode (cross domain)
        # interpolation weights of all ones
        # NOTE(review): created on the CPU — confirm device placement
        # when the inputs live on the GPU
        self.v = torch.ones(s_a.size())
        s_a_interp = self.interp_net_ba(s_b, s_a, self.v)
        s_b_interp = self.interp_net_ab(s_a, s_b, self.v)
        x_ba = self.decoder(c_b, s_a_interp)
        x_ab = self.decoder(c_a, s_b_interp)

        x_a_feature = self.dis_a(x_a)
        x_ba_feature = self.dis_a(x_ba)
        x_b_feature = self.dis_b(x_b)
        x_ab_feature = self.dis_b(x_ab)
        # critic losses: push real scores up and fake scores down
        self.loss_dis_a = (x_ba_feature - x_a_feature).mean()
        self.loss_dis_b = (x_ab_feature - x_b_feature).mean()

        # gradient penality
        self.loss_dis_a_gp = self.dis_a.calculate_gradient_penalty(x_ba, x_a)
        self.loss_dis_b_gp = self.dis_b.calculate_gradient_penalty(x_ab, x_b)


        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + \
                              hyperparameters['gan_w'] * self.loss_dis_b + \
                              hyperparameters['gan_w'] * self.loss_dis_a_gp + \
                              hyperparameters['gan_w'] * self.loss_dis_b_gp

        self.loss_dis_total.backward()
        self.total_loss += self.loss_dis_total.item()
        self.dis_a_opt.step()
        self.dis_b_opt.step()

    def gen_update(self, x_a, x_b, hyperparameters):
        """One generator-side update (encoder, decoder, interpolators).

        Combines within-domain reconstruction, latent (style/content)
        reconstruction after a cross-domain round trip, optional image
        cycle reconstruction, optional VGG perceptual losses, and the
        adversarial terms -mean(D(fake)); each term is weighted by its
        `*_w` hyperparameter. Side effects: stores per-term losses on
        `self`, accumulates `self.total_loss`, sets `self.v`, and steps
        the enc/dec/interpolator optimizers.
        """
        self.zero_grad()

        # encode
        c_a, s_a = self.encoder(x_a)
        c_b, s_b = self.encoder(x_b)

        # decode (within domain)
        x_a_recon = self.decoder(c_a, s_a)
        x_b_recon = self.decoder(c_b, s_b)

        # decode (cross domain)
        # NOTE(review): `self.v` is built on the CPU — confirm device
        # placement when inputs live on the GPU
        self.v = torch.ones(s_a.size())
        s_a_interp = self.interp_net_ba(s_b, s_a, self.v)
        s_b_interp = self.interp_net_ab(s_a, s_b, self.v)
        x_ba = self.decoder(c_b, s_a_interp)
        x_ab = self.decoder(c_a, s_b_interp)

        # encode again (round trip through the other domain)
        c_b_recon, s_a_recon = self.encoder(x_ba)
        c_a_recon, s_b_recon = self.encoder(x_ab)

        # decode again (only when the cycle loss is enabled)
        x_aa = self.decoder(
            c_a_recon, s_a) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bb = self.decoder(
            c_b_recon, s_b) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aa, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bb, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0

        # perceptual loss (VGG features), only when enabled
        self.loss_gen_vgg_a = self.perceptural_loss(
            x_a_recon, x_a) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.perceptural_loss(
            x_b_recon, x_b) if hyperparameters['vgg_w'] > 0 else 0

        self.loss_gen_vgg_aa = self.perceptural_loss(
            x_aa, x_a) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_bb = self.perceptural_loss(
            x_bb, x_b) if hyperparameters['vgg_w'] > 0 else 0

        # GAN loss: generator wants the critics to score fakes high
        x_ba_feature = self.dis_a(x_ba)
        x_ab_feature = self.dis_b(x_ab)
        self.loss_gen_adv_a = -x_ba_feature.mean()
        self.loss_gen_adv_b = -x_ab_feature.mean()

        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_aa + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_bb + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b

        self.loss_gen_total.backward()
        self.total_loss += self.loss_gen_total.item()
        self.dec_opt.step()
        self.enc_opt.step()
        self.interp_ab_opt.step()
        self.interp_ba_opt.step()

    def sample(self, x_a, x_b):
        """Produce visualization batches, one input image at a time.

        Returns (x_a, x_a_recon, x_ab, x_aa, x_b, x_b_recon, x_ba, x_bb):
        originals, within-domain reconstructions, cross-domain
        translations, and cycle reconstructions.
        Switches to eval mode for generation, back to train on exit.
        """
        self.eval()
        x_a_recon, x_b_recon, x_ab, x_ba, x_aa, x_bb = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            # per-image batches of size 1
            c_a, s_a = self.encoder(x_a[i].unsqueeze(0))
            c_b, s_b = self.encoder(x_b[i].unsqueeze(0))
            x_a_recon.append(self.decoder(c_a, s_a))
            x_b_recon.append(self.decoder(c_b, s_b))

            self.v = torch.ones(s_a.size())
            s_a_interp = self.interp_net_ba(s_b, s_a, self.v)
            s_b_interp = self.interp_net_ab(s_a, s_b, self.v)

            x_ab_i = self.decoder(c_a, s_b_interp)
            x_ba_i = self.decoder(c_b, s_a_interp)

            c_a_recon, s_b_recon = self.encoder(x_ab_i)
            c_b_recon, s_a_recon = self.encoder(x_ba_i)

            # NOTE(review): the extra .unsqueeze(0) on the style codes
            # below looks inconsistent — s_a/s_b and the interpolated
            # codes already carry the batch dim from the unsqueezed
            # inputs above (x_a_recon uses s_a without it). Confirm the
            # decoder's expected style shape.
            x_ab.append(self.decoder(c_a, s_b_interp.unsqueeze(0)))
            x_ba.append(self.decoder(c_b, s_a_interp.unsqueeze(0)))
            x_aa.append(self.decoder(c_a_recon, s_a.unsqueeze(0)))
            x_bb.append(self.decoder(c_b_recon, s_b.unsqueeze(0)))

        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ab, x_aa = torch.cat(x_ab), torch.cat(x_aa)
        x_ba, x_bb = torch.cat(x_ba), torch.cat(x_bb)

        self.train()

        return x_a, x_a_recon, x_ab, x_aa, x_b, x_b_recon, x_ba, x_bb

    def update_learning_rate(self):
        if self.dis_a_scheduler is not None:
            self.dis_a_scheduler.step()
        if self.dis_b_scheduler is not None:
            self.dis_b_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()
        if self.enc_scheduler is not None:
            self.enc_scheduler.step()
        if self.dec_scheduler is not None:
            self.dec_scheduler.step()
        if self.interpo_ab_scheduler is not None:
            self.interpo_ab_scheduler.step()
        if self.interpo_ba_scheduler is not None:
            self.interpo_ba_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore all networks, optimizers and schedulers from
        `checkpoint_dir`; returns (best_iter, total_loss).

        NOTE(review): the rebuilt interpolator schedulers are stored
        under `interpo_*` names while __init__ uses `interp_*` — confirm
        which spelling the rest of the code reads after a resume.
        """
        # Load encode
        model_name = get_model(checkpoint_dir, "encoder")
        state_dict = torch.load(model_name)
        self.encoder.load_state_dict(state_dict)

        # Load decode
        model_name = get_model(checkpoint_dir, "decoder")
        state_dict = torch.load(model_name)
        self.decoder.load_state_dict(state_dict)

        # Load discriminator a
        model_name = get_model(checkpoint_dir, "dis_a")
        state_dict = torch.load(model_name)
        self.dis_a.load_state_dict(state_dict)

        # Load discriminator b
        model_name = get_model(checkpoint_dir, "dis_b")
        state_dict = torch.load(model_name)
        self.dis_b.load_state_dict(state_dict)

        # Load interperlator ab
        model_name = get_model(checkpoint_dir, "interp_ab")
        state_dict = torch.load(model_name)
        self.interp_net_ab.load_state_dict(state_dict)

        # Load interperlator ba
        model_name = get_model(checkpoint_dir, "interp_ba")
        state_dict = torch.load(model_name)
        self.interp_net_ba.load_state_dict(state_dict)

        # Load optimizers (single file holding all six states plus the
        # bookkeeping counters)
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.enc_opt.load_state_dict(state_dict['enc_opt'])
        self.dec_opt.load_state_dict(state_dict['dec_opt'])
        self.dis_a_opt.load_state_dict(state_dict['dis_a_opt'])
        self.dis_b_opt.load_state_dict(state_dict['dis_b_opt'])
        self.interp_ab_opt.load_state_dict(state_dict['interp_ab_opt'])
        self.interp_ba_opt.load_state_dict(state_dict['interp_ba_opt'])

        self.best_iter = state_dict['best_iter']
        self.total_loss = state_dict['total_loss']

        # Reinitilize schedulers at the resumed iteration
        self.dis_a_scheduler = get_scheduler(self.dis_a_opt, hyperparameters,
                                             self.best_iter)
        self.dis_b_scheduler = get_scheduler(self.dis_b_opt, hyperparameters,
                                             self.best_iter)
        self.enc_scheduler = get_scheduler(self.enc_opt, hyperparameters,
                                           self.best_iter)
        self.dec_scheduler = get_scheduler(self.dec_opt, hyperparameters,
                                           self.best_iter)
        self.interpo_ab_scheduler = get_scheduler(self.interp_ab_opt,
                                                  hyperparameters,
                                                  self.best_iter)
        self.interpo_ba_scheduler = get_scheduler(self.interp_ba_opt,
                                                  hyperparameters,
                                                  self.best_iter)
        print('Resume from iteration %d' % self.best_iter)
        return self.best_iter, self.total_loss

    def resume_iter(self, checkpoint_dir, surfix, hyperparameters):
        """Restore a snapshot written by ``save_at_iter``.

        Args:
            checkpoint_dir: directory containing the snapshot files.
            surfix: filename suffix identifying the snapshot
                (e.g. '_00001000' for files like 'encoder_00001000.pt').
            hyperparameters: config dict used to rebuild the LR schedulers.

        Returns:
            Tuple ``(best_iter, total_loss)`` recovered from the optimizer file.
        """
        # Load encoder
        state_dict = torch.load(
            os.path.join(checkpoint_dir, 'encoder' + surfix + '.pt'))
        self.encoder.load_state_dict(state_dict)

        # Load decoder
        state_dict = torch.load(
            os.path.join(checkpoint_dir, 'decoder' + surfix + '.pt'))
        self.decoder.load_state_dict(state_dict)

        # Load discriminator a
        state_dict = torch.load(
            os.path.join(checkpoint_dir, 'dis_a' + surfix + '.pt'))
        self.dis_a.load_state_dict(state_dict)

        # Load discriminator b
        state_dict = torch.load(
            os.path.join(checkpoint_dir, 'dis_b' + surfix + '.pt'))
        self.dis_b.load_state_dict(state_dict)

        # Load interpolators ab / ba.
        # BUGFIX: this method used to read a combined 'interp<surfix>.pt'
        # file first, but `save_at_iter` only writes the per-direction
        # 'interp_ab<surfix>.pt' / 'interp_ba<surfix>.pt' files, so that
        # extra load failed (and its result was overwritten anyway).
        state_dict = torch.load(
            os.path.join(checkpoint_dir, 'interp_ab' + surfix + '.pt'))
        self.interp_net_ab.load_state_dict(state_dict)

        state_dict = torch.load(
            os.path.join(checkpoint_dir, 'interp_ba' + surfix + '.pt'))
        self.interp_net_ba.load_state_dict(state_dict)

        # Load optimizers (all bundled in one file, plus bookkeeping values).
        state_dict = torch.load(
            os.path.join(checkpoint_dir, 'optimizer' + surfix + '.pt'))
        self.enc_opt.load_state_dict(state_dict['enc_opt'])
        self.dec_opt.load_state_dict(state_dict['dec_opt'])
        self.dis_a_opt.load_state_dict(state_dict['dis_a_opt'])
        self.dis_b_opt.load_state_dict(state_dict['dis_b_opt'])
        self.interp_ab_opt.load_state_dict(state_dict['interp_ab_opt'])
        self.interp_ba_opt.load_state_dict(state_dict['interp_ba_opt'])

        self.best_iter = state_dict['best_iter']
        self.total_loss = state_dict['total_loss']

        # Reinitialize schedulers so LR decay resumes from best_iter.
        self.dis_a_scheduler = get_scheduler(self.dis_a_opt, hyperparameters,
                                             self.best_iter)
        self.dis_b_scheduler = get_scheduler(self.dis_b_opt, hyperparameters,
                                             self.best_iter)
        self.enc_scheduler = get_scheduler(self.enc_opt, hyperparameters,
                                           self.best_iter)
        self.dec_scheduler = get_scheduler(self.dec_opt, hyperparameters,
                                           self.best_iter)
        self.interpo_ab_scheduler = get_scheduler(self.interp_ab_opt,
                                                  hyperparameters,
                                                  self.best_iter)
        self.interpo_ba_scheduler = get_scheduler(self.interp_ba_opt,
                                                  hyperparameters,
                                                  self.best_iter)
        print('Resume from iteration %d' % self.best_iter)
        return self.best_iter, self.total_loss

    def save_better_model(self, snapshot_dir):
        """Persist the current best model, clearing older snapshots first.

        Saves encoder, decoder, both interpolators, both discriminators and
        all optimizer states (plus best_iter / total_loss bookkeeping).
        Model filenames embed the current total loss, e.g. 'encoder_0.1234.pt'.

        Args:
            snapshot_dir: directory that keeps exactly one (the best) snapshot.
        """
        # Remove sub-optimal models so the directory only holds the best one.
        files = glob.glob(snapshot_dir + '/*')
        for f in files:
            os.remove(f)
        # Save encoder, decoder, interpolator, discriminators, and optimizers
        encoder_name = os.path.join(snapshot_dir,
                                    'encoder_%.4f.pt' % (self.total_loss))
        decoder_name = os.path.join(snapshot_dir,
                                    'decoder_%.4f.pt' % (self.total_loss))
        interp_ab_name = os.path.join(snapshot_dir,
                                      'interp_ab_%.4f.pt' % (self.total_loss))
        interp_ba_name = os.path.join(snapshot_dir,
                                      'interp_ba_%.4f.pt' % (self.total_loss))
        dis_a_name = os.path.join(snapshot_dir,
                                  'dis_a_%.4f.pt' % (self.total_loss))
        dis_b_name = os.path.join(snapshot_dir,
                                  'dis_b_%.4f.pt' % (self.total_loss))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')

        torch.save(self.encoder.state_dict(), encoder_name)
        torch.save(self.decoder.state_dict(), decoder_name)
        torch.save(self.interp_net_ab.state_dict(), interp_ab_name)
        torch.save(self.interp_net_ba.state_dict(), interp_ba_name)
        # BUGFIX: the discriminator files previously stored the *optimizer*
        # state dicts (self.dis_a_opt / self.dis_b_opt), which `resume` then
        # tried to load into the discriminator networks. Save the network
        # weights instead.
        torch.save(self.dis_a.state_dict(), dis_a_name)
        torch.save(self.dis_b.state_dict(), dis_b_name)
        torch.save(
            {
                'enc_opt': self.enc_opt.state_dict(),
                'dec_opt': self.dec_opt.state_dict(),
                'dis_a_opt': self.dis_a_opt.state_dict(),
                'dis_b_opt': self.dis_b_opt.state_dict(),
                'interp_ab_opt': self.interp_ab_opt.state_dict(),
                'interp_ba_opt': self.interp_ba_opt.state_dict(),
                'best_iter': self.best_iter,
                'total_loss': self.total_loss
            }, opt_name)

    def save_at_iter(self, snapshot_dir, iterations):
        """Write a full training snapshot tagged with the iteration number.

        Filenames embed the 1-based iteration, e.g. 'encoder_00001000.pt',
        matching the suffix convention expected by ``resume_iter``.

        Args:
            snapshot_dir: directory to write the snapshot files into.
            iterations: 0-based iteration counter (files use iterations + 1).
        """
        encoder_name = os.path.join(snapshot_dir,
                                    'encoder_%08d.pt' % (iterations + 1))
        decoder_name = os.path.join(snapshot_dir,
                                    'decoder_%08d.pt' % (iterations + 1))
        interp_ab_name = os.path.join(snapshot_dir,
                                      'interp_ab_%08d.pt' % (iterations + 1))
        interp_ba_name = os.path.join(snapshot_dir,
                                      'interp_ba_%08d.pt' % (iterations + 1))
        dis_a_name = os.path.join(snapshot_dir,
                                  'dis_a_%08d.pt' % (iterations + 1))
        dis_b_name = os.path.join(snapshot_dir,
                                  'dis_b_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir,
                                'optimizer_%08d.pt' % (iterations + 1))

        torch.save(self.encoder.state_dict(), encoder_name)
        torch.save(self.decoder.state_dict(), decoder_name)
        torch.save(self.interp_net_ab.state_dict(), interp_ab_name)
        torch.save(self.interp_net_ba.state_dict(), interp_ba_name)
        # BUGFIX: as in save_better_model, the discriminator files previously
        # stored optimizer state dicts; save the network weights so
        # `resume_iter` can load them into self.dis_a / self.dis_b.
        torch.save(self.dis_a.state_dict(), dis_a_name)
        torch.save(self.dis_b.state_dict(), dis_b_name)
        torch.save(
            {
                'enc_opt': self.enc_opt.state_dict(),
                'dec_opt': self.dec_opt.state_dict(),
                'dis_a_opt': self.dis_a_opt.state_dict(),
                'dis_b_opt': self.dis_b_opt.state_dict(),
                'interp_ab_opt': self.interp_ab_opt.state_dict(),
                'interp_ba_opt': self.interp_ba_opt.state_dict(),
                'best_iter': self.best_iter,
                'total_loss': self.total_loss
            }, opt_name)
# Exemple #5
# 0
class fgan(object):
    """
    This class ensembles data generating process of Huber's contamination model and training process
    for estimating center parameter via F-GAN.

    Usage:
        >> f = fgan(p=100, eps=0.2, device=device, tol=1e-5)
        >> f.dist_init(true_type='Gaussian', cont_type='Gaussian', 
            cont_mean=5.0, cont_var=1.)
        >> f.data_init(train_size=50000, batch_size=500)
        >> f.net_init(d_hidden_units=[20], elliptical=False, activation_D1='LeakyReLU')
        >> f.optimizer_init(lr_d=0.2, lr_g=0.02, d_steps=5, g_steps=1)
        >> f.fit(floss='js', epochs=150, avg_epochs=25, verbose=50)

    Please refer to the Demo.ipynb for more examples.
    """
    def __init__(self, p, eps, device=None, tol=1e-5):
        """Set parameters for Huber's model epsilon
                X i.i.d ~ (1-eps) P(mu, Sigma) + eps Q, 
            where P is the real distribution, mu is the center parameter we want to 
            estimate, Q is the contamination distribution and eps is the contamination
            ratio.

        Args:
            p: dimension.
            eps: contamination ratio.
            tol: make sure the denominator is not zero.
            device: If no device is provided, it will automatically choose cpu or cuda.
        """

        self.p = p
        self.eps = eps
        self.tol = tol
        self.device = device if device is not None \
                      else torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    def dist_init(self,
                  true_type='Gaussian',
                  cont_type='Gaussian',
                  true_mean=0.0,
                  cont_mean=0.0,
                  cont_var=1,
                  cont_covmat=None):
        """
        Set parameters for distribution under Huber contaminaton models. We assume
        the center parameter of the true distribution mu is 0 and the covariance
        is indentity martix. 

        Args:
            true_type : Type of real distribution P. 'Gaussian', 'Cauchy'.
            cont_type : Type of contamination distribution Q, 'Gaussian', 'Cauchy'.
            cont_mean: center parameter for Q
            cont_var: If scatter (covariance) matrix of Q is diagonal, cont_var gives 
                      the diagonal element.
            cont_covmat: Other scatter matrix can be provided (as torch.tensor format).
                         If cont_covmat is not None, cont_var will be ignored. 

        Raises:
            NameError: if true_type or cont_type is not 'Gaussian' or 'Cauchy'.
        """

        self.true_type = true_type
        self.cont_type = cont_type

        ## settings for true distribution sampler
        self.true_mean = torch.ones(self.p) * true_mean

        if true_type == 'Gaussian':
            self.t_d = MultivariateNormal(self.true_mean,
                                          covariance_matrix=torch.eye(self.p))
        elif true_type == 'Cauchy':
            # Cauchy is sampled as Normal / sqrt(Chi2(df=1)).
            self.t_normal_d = MultivariateNormal(torch.zeros(self.p),
                                                 covariance_matrix=torch.eye(
                                                     self.p))
            self.t_chi2_d = Chi2(df=1)
        else:
            raise NameError('True type must be Gaussian or Cauchy!')

        ## settings for contamination distribution sampler
        if cont_covmat is not None:
            self.cont_covmat = cont_covmat
        else:
            self.cont_covmat = torch.eye(self.p) * cont_var
        self.cont_mean = torch.ones(self.p) * cont_mean
        if cont_type == 'Gaussian':
            self.c_d = MultivariateNormal(torch.zeros(self.p),
                                          covariance_matrix=self.cont_covmat)
        elif cont_type == 'Cauchy':
            self.c_normal_d = MultivariateNormal(
                torch.zeros(self.p), covariance_matrix=self.cont_covmat)
            self.c_chi2_d = Chi2(df=1)
        else:
            raise NameError('Cont type must be Gaussian or Cauchy!')

    def _sampler(self, n):
        """ Sampler and it will return a [n, p] torch tensor. """

        if self.true_type == 'Gaussian':
            t_x = self.t_d.sample((n, ))
        elif self.true_type == 'Cauchy':
            t_normal_x = self.t_normal_d.sample((n, ))
            t_chi2_x = self.t_chi2_d.sample((n, ))
            t_x = t_normal_x / (torch.sqrt(t_chi2_x.view(-1, 1)) + self.tol)

        if self.cont_type == 'Gaussian':
            c_x = self.c_d.sample((n, )) + self.cont_mean.view(1, -1)
        elif self.cont_type == 'Cauchy':
            c_normal_x = self.c_normal_d.sample((n, ))
            c_chi2_x = self.c_chi2_d.sample((n, ))
            c_x = c_normal_x / (torch.sqrt(c_chi2_x.view(-1, 1)) + self.tol) +\
                  self.cont_mean.view(1, -1)

        # Each row is contaminated independently with probability eps.
        s = (torch.rand(n) < self.eps).float()
        x = (t_x.transpose(1, 0) * (1 - s) +
             c_x.transpose(1, 0) * s).transpose(1, 0)

        return x

    def data_init(self, train_size=50000, batch_size=100):
        """Draw the training sample and wrap it in a shuffled DataLoader."""
        self.Xtr = self._sampler(train_size)
        self.batch_size = batch_size
        self.poolset = PoolSet(self.Xtr)
        self.dataloader = DataLoader(self.poolset,
                                     batch_size=self.batch_size,
                                     shuffle=True)

    def net_init(self,
                 d_hidden_units,
                 use_logistic_regression=False,
                 init_weights=None,
                 init_eta=0.0,
                 use_median_init_G=True,
                 elliptical=False,
                 g_input_dim=10,
                 g_hidden_units=[10, 10],
                 activation_D1='Sigmoid',
                 verbose=True):
        """
        Settings for Discriminator and Generator.

        Args:
            d_hidden_units: a list of hidden units for Discriminator, 
                            e.g. d_hidden_units=[10, 5], then the discrimintor has
                            structure p (input) - 10 - 5 - 1 (output).
            elliptical: Boolean. If elliptical == False, 
                            G_1(x|b) = x + b,
                        where b will be learned and x ~ Gaussian/Cauchy(0, I_p) 
                        according to the true distribution.
                        If elliptical = True,
                            G_2(t, u|b) = g_2(t)u + b,
                        where G_2(t, x|b) generates the family of elliptical 
                        distribution, t ~ Normal(0, I) and u ~ Uniform(\\|u\\|_2 = 1)
            g_input_dim: (Even) number. When elliptical == True, the dimension of input for 
                         g_2(t) need to be provided. 
            g_hidden_units: A list of hidden units for g_2(t). When elliptical == True, 
                            structure of g_2(t) need to be provided. 
                            e.g. g_hidden_units = [24, 12, 8], then g_2(t) has structure
                            g_input_dim - 24 - 12 - 8 - p.
            activation_D1: 'Sigmoid', 'ReLU' or 'LeakyReLU'. The first activation 
                            function after the input layer. Especially when 
                            true_type == 'Cauchy', Sigmoid activation is preferred.
            verbose: Boolean. If verbose == True, initial error 
                        \\|\\hat{\\mu}_0 - \\mu\\|_2
                     will be printed.
        """
        self.elliptical = elliptical
        self.g_input_dim = g_input_dim

        if self.elliptical:
            assert (g_input_dim %
                    2 == 0), 'g_input_dim should be an even number'
            self.netGXi = GeneratorXi(input_dim=g_input_dim,
                                      hidden_units=g_hidden_units).to(
                                          self.device)

        self.netG = Generator(p=self.p,
                              elliptical=self.elliptical).to(self.device)

        # Initialize center parameter with sample median (robust start),
        # or with a constant vector init_eta otherwise.
        if use_median_init_G:
            self.netG.bias.data = torch.median(self.Xtr,
                                               dim=0)[0].to(self.device)
        else:
            self.netG.bias.data = (torch.ones(self.p) * init_eta).to(
                self.device)

        self.mean_err_init = np.linalg.norm(self.netG.bias.data.cpu().numpy() -\
                                            self.true_mean.numpy())
        if verbose:
            print('Initialize Mean Error: %.4f' % self.mean_err_init)

        ## Initialize discrminator and g_2(t) when ellpitical == True
        if use_logistic_regression:
            self.netD = LogisticRegression(p=self.p).to(self.device)

        else:
            self.netD = Discriminator(p=self.p,
                                      hidden_units=d_hidden_units,
                                      activation_1=activation_D1).to(
                                          self.device)

        weights_init_netD = partial(weights_init, value=init_weights)
        self.netD.apply(weights_init_netD)

        if (self.elliptical):
            self.netGXi.apply(weights_init_xavier)

    def optimizer_init(self, lr_d, lr_g, d_steps, g_steps, type_opt='SGD'):
        """
        Settings for optimizer.

        Args:
            lr_d: learning rate for discrimintaor.
            lr_g: learning rate for generator.
            d_steps: number of steps of discriminator per discriminator iteration.
            g_steps: number of steps of generator per generator iteration.
            type_opt: 'SGD' for plain SGD, anything else selects Adam.
        """
        if type_opt == 'SGD':
            self.optG = optim.SGD(self.netG.parameters(), lr=lr_g)
            if self.elliptical:
                self.optGXi = optim.SGD(self.netGXi.parameters(), lr=lr_g)
            self.optD = optim.SGD(self.netD.parameters(), lr=lr_d)
        else:
            self.optG = optim.Adam(self.netG.parameters(), lr=lr_g)
            if self.elliptical:
                self.optGXi = optim.Adam(self.netGXi.parameters(), lr=lr_g)
            self.optD = optim.Adam(self.netD.parameters(), lr=lr_d)
        self.g_steps = g_steps
        self.d_steps = d_steps

    def fit(self,
            floss='js',
            epochs=20,
            avg_epochs=10,
            use_inverse_gaussian=True,
            verbose=25):
        """
        Training process.
        
        Args:
            floss: 'js' or 'tv'. For JS-GAN, we consider the original GAN with 
                   Jensen-Shannon divergence and for TV-GAN, total variation will be
                   used.
            epochs: Number. Number of epochs for training.
            avg_epochs: Number. An average estimation using the last certain epochs.
            use_inverse_gaussian: Boolean. If elliptical == True, the \\xi generator
                                  g_2(t) takes random vector t as input and outputs
                                  \\xi samples. If use_inverse_gaussian == True, we take
                                  t = (t1, t2), where t1 ~ Normal(0, I_(d/2)) and
                                  t2 ~ 1/Normal(0, I_(d/2)), 
                                  otherwise, t ~ Normal(0, I_d).
            verbose: Number. Print intermediate result every certain epochs.
        """
        assert floss in ['js', 'tv'], 'floss must be \'js\' or \'tv\''
        if floss == 'js':
            criterion = nn.BCEWithLogitsLoss()
        self.floss = floss
        # BUGFIX: remember the xi-input convention — report_results() draws
        # fresh samples from the elliptical generator and previously
        # referenced `use_inverse_gaussian` as an undefined name there.
        self.use_inverse_gaussian = use_inverse_gaussian
        self.loss_D = []
        self.loss_G = []
        self.mean_err_record = []
        self.mean_est_record = []
        current_d_step = 1

        for ep in range(epochs):
            loss_D_ep = []
            loss_G_ep = []
            for _, data in enumerate(self.dataloader):
                # The DataLoader is created without drop_last, so the final
                # batch may be smaller than self.batch_size — size every
                # noise tensor from the actual batch (BUGFIX: the old code
                # used self.batch_size in some .view()/.sample() calls and
                # crashed on a partial last batch).
                n_b = data.shape[0]
                ## update D
                self.netD.train()
                self.netD.zero_grad()
                ## discriminator loss on real samples
                x_real = data.to(self.device)
                feat_real, d_real_score = self.netD(x_real)
                if (floss == 'js'):
                    one_b = torch.ones_like(d_real_score).to(self.device)
                    d_real_loss = criterion(d_real_score, one_b)
                elif floss == 'tv':
                    d_real_loss = -torch.sigmoid(d_real_score).mean()
                ## discriminator loss on generated samples
                z_b = torch.zeros(n_b, self.p).to(self.device)
                if self.elliptical:
                    if use_inverse_gaussian:
                        xi_b1 = torch.zeros(n_b, self.g_input_dim //
                                            2).to(self.device)
                        xi_b2 = torch.zeros(n_b, self.g_input_dim //
                                            2).to(self.device)
                    else:
                        xi_b = torch.zeros(n_b,
                                           self.g_input_dim).to(self.device)

                if self.elliptical:
                    # u ~ Uniform on the unit sphere: normalized Gaussian.
                    z_b.normal_()
                    z_b.div_(z_b.norm(2, dim=1).view(-1, 1) + self.tol)
                    if use_inverse_gaussian:
                        xi_b1.normal_()
                        xi_b2.normal_()
                        xi_b2.data = 1 / (torch.abs(xi_b2.data) + self.tol)
                        xi = self.netGXi(torch.cat([xi_b1, xi_b2],
                                                   dim=1)).view(n_b, -1)
                    else:
                        xi_b.normal_()
                        xi = self.netGXi(xi_b).view(n_b, -1)
                    x_fake = self.netG(z_b, xi).detach()
                elif (self.true_type == 'Cauchy'):
                    z_b.normal_()
                    z_b.data.div_(
                        torch.sqrt(self.t_chi2_d.sample(
                            (n_b, 1))).to(self.device) + self.tol)
                    x_fake = self.netG(z_b).detach()
                elif self.true_type == 'Gaussian':
                    x_fake = self.netG(z_b.normal_()).detach()
                feat_fake, d_fake_score = self.netD(x_fake)
                if floss == 'js':
                    one_b = torch.ones_like(d_fake_score).to(self.device)
                    d_fake_loss = criterion(d_fake_score, 1 - one_b)
                elif floss == 'tv':
                    d_fake_loss = torch.sigmoid(d_fake_score).mean()
                d_loss = d_real_loss + d_fake_loss
                d_loss.backward()
                loss_D_ep.append(d_loss.cpu().item())
                self.optD.step()
                # Only update G once every d_steps discriminator updates.
                if current_d_step < self.d_steps:
                    current_d_step += 1
                    continue
                else:
                    current_d_step = 1

                ## update G
                self.netD.eval()
                for _ in range(self.g_steps):
                    self.netG.zero_grad()
                    if self.elliptical:
                        self.netGXi.zero_grad()
                        z_b.normal_()
                        z_b.div_(z_b.norm(2, dim=1).view(-1, 1) + self.tol)
                        if use_inverse_gaussian:
                            xi_b1.normal_()
                            xi_b2.normal_()
                            xi_b2.data = 1 / (torch.abs(xi_b2.data) + self.tol)
                            xi = self.netGXi(torch.cat([xi_b1, xi_b2],
                                                       dim=1)).view(n_b, -1)
                        else:
                            xi_b.normal_()
                            xi = self.netGXi(xi_b).view(n_b, -1)
                        x_fake = self.netG(z_b, xi)
                    elif self.true_type == 'Gaussian':
                        x_fake = self.netG(z_b.normal_())
                    elif (self.true_type == 'Cauchy'):
                        z_b.normal_()
                        z_b.data.div_(
                            torch.sqrt(self.t_chi2_d.sample(
                                (n_b, 1))).to(self.device) + self.tol)
                        x_fake = self.netG(z_b)
                    feat_fake, g_fake_score = self.netD(x_fake)
                    if (floss == 'js'):
                        one_b = torch.ones_like(g_fake_score).to(self.device)
                        g_fake_loss = -criterion(g_fake_score, 1 - one_b)
                        g_fake_loss.backward()
                        loss_G_ep.append(-g_fake_loss.cpu().item())
                    elif floss == 'tv':
                        g_fake_loss = -torch.sigmoid(g_fake_score).mean()
                        g_fake_loss.backward()
                        loss_G_ep.append(g_fake_loss.cpu().item())
                    self.optG.step()
                    if self.elliptical:
                        self.optGXi.step()
            ## Record intermediate error during training for monitoring.
            self.mean_err_record.append(
                (self.netG.bias.data -
                 self.true_mean.to(self.device)).norm(2).item())
            ## Record intermediate estimation during training for averaging.
            if (ep >= (epochs - avg_epochs)):
                self.mean_est_record.append(self.netG.bias.data.clone().cpu())
            self.loss_D.append(np.mean(loss_D_ep))
            self.loss_G.append(np.mean(loss_G_ep))
            ## Print intermediate result every verbose epochs.
            if ((ep + 1) % verbose == 0):
                print('Epoch:%d, LossD/G:%.4f/%.4f, Error(Mean):%.4f' %
                      (ep + 1, self.loss_D[-1], self.loss_G[-1],
                       self.mean_err_record[-1]))
        ## Final results
        self.mean_avg = sum(self.mean_est_record[-avg_epochs:])/\
                            len(self.mean_est_record[-avg_epochs:])
        self.mean_err_avg = (self.mean_avg -
                             self.true_mean.cpu()).norm(2).item()
        self.mean_err_last = (self.netG.bias.data -
                              self.true_mean.to(self.device)).norm(2).item()

    def report_results(self,
                       figsize=(6, 4),
                       show_plots=True,
                       save_g_loss=None,
                       save_d_loss=None,
                       save_error=None,
                       save_distribution=None):
        """Print final errors and plot losses, error curves and the
        discriminator score distributions. Must be called after fit().

        Args:
            figsize: matplotlib figure size for every plot.
            show_plots: if True, display each figure; otherwise close it.
            save_g_loss / save_d_loss / save_error / save_distribution:
                optional file paths; when given, the corresponding figure is
                saved there.
        """
        ## Print the final results.
        self.netD.eval()
        ## Scores of true distribution from 10,000 samples.
        if self.true_type == 'Gaussian':
            t_x = self.t_d.sample((10000, ))
        elif self.true_type == 'Cauchy':
            t_normal_x = self.t_normal_d.sample((10000, ))
            t_chi2_x = self.t_chi2_d.sample((10000, ))
            t_x = t_normal_x / (torch.sqrt(t_chi2_x.view(-1, 1)) + self.tol)
        self.true_D = self.netD(t_x.to(self.device))[1].detach().cpu().numpy()
        ## Scores of contamination distribution from 10,000 samples.
        if self.cont_type == 'Gaussian':
            c_x = self.c_d.sample((10000, )) + self.cont_mean.view(1, -1)
        elif self.cont_type == 'Cauchy':
            c_normal_x = self.c_normal_d.sample((10000, ))
            c_chi2_x = self.c_chi2_d.sample((10000, ))
            c_x = c_normal_x / (torch.sqrt(c_chi2_x.view(-1, 1)) + self.tol) +\
                      self.cont_mean.view(1, -1)
        self.cont_D = self.netD(c_x.to(self.device))[1].detach().cpu().numpy()
        ## Scores of 10,000 generating samples.
        if self.elliptical:
            t_z = torch.randn(10000, self.p).to(self.device)
            t_z.div_(t_z.norm(2, dim=1).view(-1, 1) + self.tol)
            # BUGFIX: use the flag stored by fit(); this method previously
            # referenced `use_inverse_gaussian` as an undefined local name.
            if self.use_inverse_gaussian:
                t_xi1 = torch.randn(10000,
                                    self.g_input_dim // 2).to(self.device)
                t_xi2 = torch.randn(10000,
                                    self.g_input_dim // 2).to(self.device)
                t_xi2 = 1 / (torch.abs(t_xi2.data) + self.tol)
                xi = self.netGXi(torch.cat([t_xi1, t_xi2],
                                           dim=1)).view(10000, -1)
            else:
                t_xi = torch.randn(10000, self.g_input_dim).to(self.device)
                xi = self.netGXi(t_xi).view(10000, -1)
            g_x = self.netG(t_z, xi).detach()
        elif self.true_type == 'Gaussian':
            g_x = self.netG(torch.randn(10000, self.p).to(self.device))
        elif (self.true_type == 'Cauchy'):
            g_z = torch.randn(10000, self.p).to(self.device)
            g_z.data.div_(
                torch.sqrt(self.t_chi2_d.sample((10000, 1))).to(self.device) +
                self.tol)
            g_x = self.netG(g_z)
        self.gene_D = self.netD(g_x)[1].detach().cpu().numpy()
        ## Some useful prints and plots

        print('Avg error: %.4f, Last error: %.4f' %
              (self.mean_err_avg, self.mean_err_last))
        grand_mean = (1 -
                      self.eps) * self.true_mean + self.eps * self.cont_mean
        grand_mean_err = (grand_mean.to(self.device) -
                          self.true_mean.to(self.device)).norm(2).item()
        grand_mean_err_record = [
            grand_mean_err for i in range(len(self.mean_err_record))
        ]

        if self.p == 1:
            print("True mean = %.4f" % (self.true_mean.item()))
            print("Contamination mean = %.4f" % (self.cont_mean.item()))
            print("Result mean = %.4f" % (self.netG.bias.data.item()))
            print("Grand mean = %.4f" % (grand_mean.item()))

        loss_type = 'Total Variation' if self.floss == 'tv' else 'Jensen-Shannon'

        fig, ax = plt.subplots(figsize=figsize)
        ax.plot(self.loss_D)
        ax.grid(True)
        ax.set_title(f'Discriminator loss, type = {loss_type}')
        ax.set_xlabel("epoch num")
        ax.set_ylabel("Loss")
        if save_d_loss is not None:
            plt.savefig(save_d_loss)

        if show_plots:
            plt.show()
        else:
            plt.close(fig)

        fig, ax = plt.subplots(figsize=figsize)
        ax.plot(self.loss_G)
        ax.grid(True)
        ax.set_title(f'Generator loss, type = {loss_type}')
        ax.set_xlabel("epoch num")
        ax.set_ylabel("Loss")
        if save_g_loss is not None:
            plt.savefig(save_g_loss)

        if show_plots:
            plt.show()
        else:
            plt.close(fig)

        fig, ax = plt.subplots(figsize=figsize)
        ax.plot(self.mean_err_record, label='mean error process')
        ax.plot(grand_mean_err_record, label='grand mean error')
        ax.legend()
        ax.grid(True)
        ax.set_title(
            r'$\ell_{2}$ error in prediction of mean for true distribution')
        ax.set_xlabel("epoch num")
        ax.set_ylabel(r"$\|\eta_{est} - \eta_{true}\|_{2}$")
        if save_error is not None:
            plt.savefig(save_error)

        if show_plots:
            plt.show()
        else:
            plt.close(fig)

        fig, ax = plt.subplots(figsize=figsize)
        # Clip extreme scores to (-25, 25) so the KDE plot stays readable.
        d_distributions = {}
        d_distributions['true distribution'] = self.true_D[(self.true_D < 25) &
                                                           (self.true_D > -25)]
        d_distributions['generated distribution'] = self.gene_D[
            (self.gene_D < 25) & (self.gene_D > -25)]
        d_distributions['contamination distribution'] = self.cont_D[
            (self.cont_D < 25) & (self.cont_D > -25)]

        g = sns.kdeplot(ax=ax, data=d_distributions)
        ax.set_xlabel(r"$D(x)$")
        ax.set_ylabel("Density")
        ax.grid(True)

        ax.set_title(r'Discriminator distribution, $D(x)$')
        if save_distribution is not None:
            plt.savefig(save_distribution)

        if show_plots:
            plt.show()
        else:
            plt.close(fig)
discriminator = Discriminator().to(device)

# Wrap both networks for multi-GPU execution across all visible devices.
# NOTE(review): `generator` is not defined in this fragment — presumably it
# is constructed just above, alongside `discriminator`; confirm upstream.
generator = torch.nn.DataParallel(generator,
                                  list(range(torch.cuda.device_count())))
discriminator = torch.nn.DataParallel(discriminator,
                                      list(range(torch.cuda.device_count())))

# Either resume from previously saved weights (if the files exist) or start
# from a fresh random initialisation.
if opt['load_model']:
    if os.path.isfile("saved_models/generator.pth"):
        generator.load_state_dict(torch.load("saved_models/generator.pth"))
    if os.path.isfile("saved_models/discriminator.pth"):
        discriminator.load_state_dict(
            torch.load("saved_models/discriminator.pth"))
else:
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

# Both optimizers share the learning rate and Adam betas from the config dict.
optimizer_G = torch.optim.Adam(generator.parameters(),
                               lr=opt["lr"],
                               betas=(opt["b1"], opt["b2"]))
optimizer_D = torch.optim.Adam(discriminator.parameters(),
                               lr=opt["lr"],
                               betas=(opt["b1"], opt["b2"]))

# Training loop: 25000 // batch_size iterations per epoch, batches pulled
# from the project's data generator.
for epoch in range(opt['n_epochs']):
    for i in range(25000 // opt['batch_size']):

        y, x = next(data.data_generator())

        # real_A is built from x and real_B from y.
        # NOTE(review): fragment is truncated here — the loss computation and
        # optimizer steps continue beyond this excerpt.
        real_A = Variable(x.type(Tensor))
        real_B = Variable(y.type(Tensor))
# Exemple #7
# 0
# Run-level bookkeeping: wall-clock start time plus a record for each
# training round, filled in as training progresses.
training_log = {
    'rounds': [],
    'time': time.time(),
}


def weights_init(m):
    """Initialise the weights of Linear-like layers in place.

    Intended for use with ``Module.apply``: any submodule whose class name
    contains ``'Linear'`` has its weight tensor redrawn from a normal
    distribution with mean 0.048 and std 0.48; every other module (and the
    bias of Linear layers) is left untouched.
    """
    if 'Linear' in type(m).__name__:
        nn.init.normal_(m.weight.data, 0.048, 0.48)


# Apply the custom normal-distribution initialisation to every submodule of
# both networks.
generator.apply(weights_init)
discriminator.apply(weights_init)

# Open the colour dataset once and record its size in bytes so random seeks
# can pick a uniform position anywhere in the file.
color_file = open('data/color.txt')
color_file.seek(0, os.SEEK_END)
color_file_size = color_file.tell()


def random_color_file_seek():
    """Move the shared ``color_file`` handle to a random line boundary.

    Seeks to a uniformly random byte offset within the file, then discards
    the (most likely partial) line so the next read starts on a full line.
    """
    offset = random.randint(0, color_file_size)
    color_file.seek(offset)
    color_file.readline()


def get_real_color_tensor():
    line = color_file.readline()
    if line == '':
        color_file.seek(0)
# Exemple #8
# 0
# Training split of CIFAR-10, downloaded on first use and run through the
# externally defined `train_transforms` pipeline.
train_dataset = datasets.CIFAR10('./data/',
                                 train=True,
                                 download=True,
                                 transform=train_transforms)
# drop_last keeps every batch exactly `batch_size` large, which the fixed
# (batch_size, 1, 1, 1) label tensor below relies on.
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True)

# 3. Networks
G = Generator(image_channels).to(device)
D = Discriminator(image_channels).to(device)

G.apply(initialize_weights)
D.apply(initialize_weights)

# Binary cross-entropy for the discriminator's sigmoid output.
loss_fn = nn.BCELoss().to(device)

# 4. Optimizers
optimizer_for_G = torch.optim.Adam(G.parameters(),
                                   lr=learning_rate,
                                   betas=(beta1, beta2))
optimizer_for_D = torch.optim.Adam(D.parameters(),
                                   lr=learning_rate,
                                   betas=(beta1, beta2))

# 5. Training
# Constant all-zero target for "fake" samples, shaped (N, 1, 1, 1) to match
# the discriminator's output; no gradients are ever needed for it.
fake_gt = np.zeros((batch_size, 1, 1, 1), dtype=np.float32)
fake_gt = torch.FloatTensor(fake_gt).to(device)
fake_gt = torch.autograd.Variable(fake_gt, requires_grad=False)
def training_procedure(FLAGS):
    """Adversarial training for a style/class-disentangled auto-encoder.

    Builds an Encoder/Decoder pair and a Discriminator, then alternates
    three phases per iteration over paired MNIST batches:
      A. auto-encoder reconstruction (MSE plus a KL penalty on the style code),
      B. generator updates that try to make cross-reconstructions fool the
         discriminator,
      C. discriminator updates, applied only while its accuracy stays below
         ``FLAGS.discriminator_limiting_accuracy``.
    Losses are appended to ``FLAGS.log_file`` and TensorBoard; model weights
    are checkpointed to ``checkpoints/`` every 5 epochs.
    """
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    discriminator = Discriminator()
    discriminator.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))
        discriminator.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.discriminator_save)))

    """
    variable definition
    """
    # CrossEntropyLoss targets: 1 for real image pairs, 0 for generated ones.
    real_domain_labels = 1
    fake_domain_labels = 0

    # Pre-allocated batch buffers; each step copies the next mini-batch into
    # them in place (copy_), so no per-iteration tensor allocation happens.
    X_1 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    X_2 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    X_3 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)

    domain_labels = torch.LongTensor(FLAGS.batch_size)
    # Buffer for style codes drawn from the N(0, 1) prior in phase B.
    style_latent_space = torch.FloatTensor(FLAGS.batch_size, FLAGS.style_dim)

    """
    loss definitions
    """
    cross_entropy_loss = nn.CrossEntropyLoss()

    '''
    add option to run on GPU
    '''
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()
        discriminator.cuda()

        cross_entropy_loss.cuda()

        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()

        domain_labels = domain_labels.cuda()
        style_latent_space = style_latent_space.cuda()

    """
    optimizer definition
    """
    # NOTE(review): auto_encoder_optimizer and generator_optimizer cover the
    # same encoder+decoder parameters; they are stepped in different phases.
    auto_encoder_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    discriminator_optimizer = optim.Adam(
        list(discriminator.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    generator_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    """
    training
    """
    if torch.cuda.is_available() and not FLAGS.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write('Epoch\tIteration\tReconstruction_loss\tKL_divergence_loss\t')
            log.write('Generator_loss\tDiscriminator_loss\tDiscriminator_accuracy\n')

    # load data set and create data loader instance
    print('Loading MNIST paired dataset...')
    paired_mnist = MNIST_Paired(root='mnist', download=True, train=True, transform=transform_config)
    # cycle() makes the loader endless so next() can be called freely below.
    loader = cycle(DataLoader(paired_mnist, batch_size=FLAGS.batch_size, shuffle=True, num_workers=0, drop_last=True))

    # initialise variables
    discriminator_accuracy = 0.

    # initialize summary writer
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print('Epoch #' + str(epoch) + '..........................................................................')

        for iteration in range(int(len(paired_mnist) / FLAGS.batch_size)):
            # A. run the auto-encoder reconstruction
            image_batch_1, image_batch_2, _ = next(loader)

            auto_encoder_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            # Encoder returns (style mu, style logvar, class code); the style
            # code is sampled with the reparameterization trick.
            style_mu_1, style_logvar_1, class_1 = encoder(Variable(X_1))
            style_1 = reparameterize(training=True, mu=style_mu_1, logvar=style_logvar_1)

            # Standard VAE KL term, normalised by the total pixel count.
            kl_divergence_loss_1 = - 0.5 * torch.sum(1 + style_logvar_1 - style_mu_1.pow(2) - style_logvar_1.exp())
            kl_divergence_loss_1 /= (FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size)
            kl_divergence_loss_1.backward(retain_graph=True)

            _, __, class_2 = encoder(Variable(X_2))

            reconstructed_X_1 = decoder(style_1, class_1)
            reconstructed_X_2 = decoder(style_1, class_2)

            reconstruction_error_1 = mse_loss(reconstructed_X_1, Variable(X_1))
            reconstruction_error_1.backward(retain_graph=True)

            # NOTE(review): the cross-reconstruction (style of X_1, class of
            # X_2) is also compared against X_1 — presumably the paired
            # batches share the same digit class; confirm with MNIST_Paired.
            reconstruction_error_2 = mse_loss(reconstructed_X_2, Variable(X_1))
            reconstruction_error_2.backward()

            reconstruction_error = reconstruction_error_1 + reconstruction_error_2
            kl_divergence_error = kl_divergence_loss_1

            auto_encoder_optimizer.step()

            # B. run the generator
            for i in range(FLAGS.generator_times):

                generator_optimizer.zero_grad()

                image_batch_1, _, __ = next(loader)
                image_batch_3, _, __ = next(loader)

                # Generator is trained against "real" labels so that fooling
                # the discriminator lowers the loss.
                domain_labels.fill_(real_domain_labels)
                X_1.copy_(image_batch_1)
                X_3.copy_(image_batch_3)

                style_mu_1, style_logvar_1, _ = encoder(Variable(X_1))
                style_1 = reparameterize(training=True, mu=style_mu_1, logvar=style_logvar_1)

                kl_divergence_loss_1 = - 0.5 * torch.sum(1 + style_logvar_1 - style_mu_1.pow(2) - style_logvar_1.exp())
                kl_divergence_loss_1 /= (FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size)
                kl_divergence_loss_1.backward(retain_graph=True)

                _, __, class_3 = encoder(Variable(X_3))
                reconstructed_X_1_3 = decoder(style_1, class_3)

                # Discriminator judges a (reference image, candidate) pair.
                output_1 = discriminator(Variable(X_3), reconstructed_X_1_3)

                generator_error_1 = cross_entropy_loss(output_1, Variable(domain_labels))
                generator_error_1.backward(retain_graph=True)

                # Second generator pass with a style sampled from the prior.
                style_latent_space.normal_(0., 1.)
                reconstructed_X_latent_3 = decoder(Variable(style_latent_space), class_3)

                output_2 = discriminator(Variable(X_3), reconstructed_X_latent_3)

                generator_error_2 = cross_entropy_loss(output_2, Variable(domain_labels))
                generator_error_2.backward()

                generator_error = generator_error_1 + generator_error_2
                kl_divergence_error += kl_divergence_loss_1

                generator_optimizer.step()

            # C. run the discriminator
            for i in range(FLAGS.discriminator_times):

                discriminator_optimizer.zero_grad()

                # train discriminator on real data
                domain_labels.fill_(real_domain_labels)

                image_batch_1, _, __ = next(loader)
                image_batch_2, image_batch_3, _ = next(loader)

                X_1.copy_(image_batch_1)
                X_2.copy_(image_batch_2)
                X_3.copy_(image_batch_3)

                real_output = discriminator(Variable(X_2), Variable(X_3))

                discriminator_real_error = cross_entropy_loss(real_output, Variable(domain_labels))
                discriminator_real_error.backward()

                # train discriminator on fake data
                domain_labels.fill_(fake_domain_labels)

                # training=False: use the style mean instead of sampling.
                style_mu_1, style_logvar_1, _ = encoder(Variable(X_1))
                style_1 = reparameterize(training=False, mu=style_mu_1, logvar=style_logvar_1)

                _, __, class_3 = encoder(Variable(X_3))
                reconstructed_X_1_3 = decoder(style_1, class_3)

                fake_output = discriminator(Variable(X_3), reconstructed_X_1_3)

                discriminator_fake_error = cross_entropy_loss(fake_output, Variable(domain_labels))
                discriminator_fake_error.backward()

                # total discriminator error
                discriminator_error = discriminator_real_error + discriminator_fake_error

                # calculate discriminator accuracy for this step
                target_true_labels = torch.cat((torch.ones(FLAGS.batch_size), torch.zeros(FLAGS.batch_size)), dim=0)
                if FLAGS.cuda:
                    target_true_labels = target_true_labels.cuda()

                discriminator_predictions = torch.cat((real_output, fake_output), dim=0)
                _, discriminator_predictions = torch.max(discriminator_predictions, 1)

                discriminator_accuracy = (discriminator_predictions.data == target_true_labels.long()
                                          ).sum().item() / (FLAGS.batch_size * 2)

                # Only update the discriminator while it is not yet too
                # strong, to keep the adversarial game balanced.
                if discriminator_accuracy < FLAGS.discriminator_limiting_accuracy:
                    discriminator_optimizer.step()

            # Console progress summary every 50 iterations.
            if (iteration + 1) % 50 == 0:
                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))

                print('')
                print('Reconstruction loss: ' + str(reconstruction_error.data.storage().tolist()[0]))
                print('KL-Divergence loss: ' + str(kl_divergence_error.data.storage().tolist()[0]))

                print('')
                print('Generator loss: ' + str(generator_error.data.storage().tolist()[0]))
                print('Discriminator loss: ' + str(discriminator_error.data.storage().tolist()[0]))
                print('Discriminator accuracy: ' + str(discriminator_accuracy))

                print('..........')

            # write to log
            with open(FLAGS.log_file, 'a') as log:
                log.write('{0}\t{1}\t{2}\t{3}\t{4}\t{5}\t{6}\n'.format(
                    epoch,
                    iteration,
                    reconstruction_error.data.storage().tolist()[0],
                    kl_divergence_error.data.storage().tolist()[0],
                    generator_error.data.storage().tolist()[0],
                    discriminator_error.data.storage().tolist()[0],
                    discriminator_accuracy
                ))

            # write to tensorboard
            writer.add_scalar('Reconstruction loss', reconstruction_error.data.storage().tolist()[0],
                              epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) + iteration)
            writer.add_scalar('KL-Divergence loss', kl_divergence_error.data.storage().tolist()[0],
                              epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) + iteration)
            writer.add_scalar('Generator loss', generator_error.data.storage().tolist()[0],
                              epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) + iteration)
            writer.add_scalar('Discriminator loss', discriminator_error.data.storage().tolist()[0],
                              epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) + iteration)
            writer.add_scalar('Discriminator accuracy', discriminator_accuracy * 100,
                              epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) + iteration)

        # save model after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(), os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(), os.path.join('checkpoints', FLAGS.decoder_save))
            torch.save(discriminator.state_dict(), os.path.join('checkpoints', FLAGS.discriminator_save))
# Exemple #10
# 0
                      normalize=True,
                      range=(-1., 1.))
    vutils.save_image(fixed_annos.float() / n_classes,
                      join(sample_path, '{:03d}_anno.jpg'.format(0)),
                      nrow=4,
                      padding=0)

    # Models
    E = Encoder().to(device)
    E.apply(init_weights)
    # summary(E, (3, 256, 256), device=device)
    G = Generator(n_classes).to(device)
    G.apply(init_weights)
    # summary(G, [(256,), (10, 256, 256)], device=device)
    D = Discriminator(n_classes).to(device)
    D.apply(init_weights)
    # summary(D, (13, 256, 256), device=device)
    vgg = VGG().to(device)

    if args.multi_gpu:
        E = nn.DataParallel(E)
        G = nn.DataParallel(G)
        # G = convert_model(G)
        D = nn.DataParallel(D)
        VGG = nn.DataParallel(VGG)

    # Optimizers
    G_opt = optim.Adam(itertools.chain(G.parameters(), E.parameters()),
                       lr=args.lr_G,
                       betas=(args.beta1, args.beta2))
    D_opt = optim.Adam(D.parameters(),
# Exemple #11
# 0
def init_training(args):
    """Set up everything needed for one colorization-GAN training run.

    Builds the CIFAR-10 train/test loaders, instantiates the generator and
    discriminator on the available device, creates their Adam optimizers and
    the loss functions, ensures the checkpoint directory exists, and — when
    ``args.start_epoch > 0`` — restores both networks from the checkpoints
    of the previous epoch.

    Returns:
        (global_step, device, data_loaders, generator, discriminator,
         optimizers, losses)
    """
    datasets = Cifar10Dataset.get_datasets_from_scratch(args.data_path)
    for phase in ['train', 'test']:
        print('{} dataset len: {}'.format(phase, len(datasets[phase])))

    # One loader per phase; only the training split is shuffled.
    data_loaders = {
        phase: DataLoader(datasets[phase],
                          batch_size=args.batch_size,
                          shuffle=(phase == 'train'),
                          num_workers=args.num_workers)
        for phase in ['train', 'test']
    }

    # Prefer the first GPU when CUDA is available.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('Use GPU: {}'.format(str(device) != 'cpu'))

    generator = Generator(args.gen_norm).to(device)
    discriminator = Discriminator(args.disc_norm).to(device)

    if args.apply_weight_init:
        generator.apply(weights_init_normal)
        discriminator.apply(weights_init_normal)

    # Adam with reduced first-moment momentum (beta1 = 0.5), as is common
    # for GAN training.
    optimizers = {
        'gen': torch.optim.Adam(generator.parameters(),
                                lr=args.base_lr_gen,
                                betas=(0.5, 0.999)),
        'disc': torch.optim.Adam(discriminator.parameters(),
                                 lr=args.base_lr_disc,
                                 betas=(0.5, 0.999)),
    }

    losses = {
        'l1': torch.nn.L1Loss(reduction='mean'),
        'disc': torch.nn.BCELoss(reduction='mean'),
    }

    # Make the save dir if it does not exist yet.
    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    # Resuming: one completed epoch equals len(train loader) global steps,
    # and the weights come from the previous epoch's checkpoints.
    global_step = 0
    if args.start_epoch > 0:
        global_step = args.start_epoch * len(data_loaders['train'])
        for net, tag in ((generator, 'gen'), (discriminator, 'disc')):
            checkpoint = os.path.join(
                args.save_path,
                'checkpoint_ep{}_{}.pt'.format(args.start_epoch - 1, tag))
            net.load_state_dict(torch.load(checkpoint, map_location=device))

    return global_step, device, data_loaders, generator, discriminator, optimizers, losses
# Exemple #12
# 0
    else:
        USE_CUDA = False

    dataset = CustomDataset(opt)
    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=opt.batch_size,
                                              shuffle=opt.shuffle,
                                              num_workers=opt.n_workers)

    print(len(data_loader))

    G = Generator(opt)
    D = Discriminator(opt)

    G.apply(weight_init)
    D.apply(weight_init)

    print(G)
    print(D)

    if USE_CUDA:
        G = G.cuda()
        D = D.cuda()

    G_optim = torch.optim.Adam(G.parameters(),
                               lr=opt.lr,
                               betas=(opt.beta1, opt.beta2))
    D_optim = torch.optim.Adam(D.parameters(),
                               lr=opt.lr,
                               betas=(opt.beta1, opt.beta2))
def training_procedure(FLAGS):
    """Adversarial training for an encoder/decoder that factors images into a
    varying code (nv) and a common code (nc).

    Each iteration over paired MNIST batches runs:
      A.   auto-encoder reconstruction with the common factors swapped
           between the two paired images,
      B.a) discriminator updates on real pairs vs. generated hybrids, gated
           by ``FLAGS.discriminator_limiting_accuracy`` and the
           ``train_discriminator`` flag,
      B.b) generator updates that try to fool the discriminator.
    Progress goes to ``FLAGS.log_file`` and TensorBoard; every 5 epochs the
    models are checkpointed and sample image grids are written out.
    """
    encoder = Encoder(nv_dim=FLAGS.nv_dim, nc_dim=FLAGS.nc_dim)
    encoder.apply(weights_init)

    decoder = Decoder(nv_dim=FLAGS.nv_dim, nc_dim=FLAGS.nc_dim)
    decoder.apply(weights_init)

    discriminator = Discriminator()
    discriminator.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))
        discriminator.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.discriminator_save)))
    """
    variable definition
    """
    # CrossEntropyLoss targets: 1 for real image pairs, 0 for generated ones.
    real_domain_labels = 1
    fake_domain_labels = 0

    # Pre-allocated batch buffers; each step copies the next mini-batch into
    # them in place (copy_), avoiding per-iteration allocation.
    X_1 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels,
                            FLAGS.image_size, FLAGS.image_size)
    X_2 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels,
                            FLAGS.image_size, FLAGS.image_size)
    X_3 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels,
                            FLAGS.image_size, FLAGS.image_size)

    domain_labels = torch.LongTensor(FLAGS.batch_size)
    """
    loss definitions
    """
    cross_entropy_loss = nn.CrossEntropyLoss()
    '''
    add option to run on GPU
    '''
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()
        discriminator.cuda()

        cross_entropy_loss.cuda()

        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()

        domain_labels = domain_labels.cuda()
    """
    optimizer definition
    """
    # NOTE(review): auto_encoder_optimizer and generator_optimizer cover the
    # same encoder+decoder parameters; they are stepped in different phases.
    auto_encoder_optimizer = optim.Adam(list(encoder.parameters()) +
                                        list(decoder.parameters()),
                                        lr=FLAGS.initial_learning_rate,
                                        betas=(FLAGS.beta_1, FLAGS.beta_2))

    discriminator_optimizer = optim.Adam(list(discriminator.parameters()),
                                         lr=FLAGS.initial_learning_rate,
                                         betas=(FLAGS.beta_1, FLAGS.beta_2))

    generator_optimizer = optim.Adam(list(encoder.parameters()) +
                                     list(decoder.parameters()),
                                     lr=FLAGS.initial_learning_rate,
                                     betas=(FLAGS.beta_1, FLAGS.beta_2))
    """
    training
    """
    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    if not os.path.exists('reconstructed_images'):
        os.makedirs('reconstructed_images')

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write('Epoch\tIteration\tReconstruction_loss\t')
            log.write(
                'Generator_loss\tDiscriminator_loss\tDiscriminator_accuracy\n')

    # load data set and create data loader instance
    print('Loading MNIST paired dataset...')
    paired_mnist = MNIST_Paired(root='mnist',
                                download=True,
                                train=True,
                                transform=transform_config)
    # cycle() makes the loader endless so next() can be called freely below.
    loader = cycle(
        DataLoader(paired_mnist,
                   batch_size=FLAGS.batch_size,
                   shuffle=True,
                   num_workers=0,
                   drop_last=True))

    # initialise variables
    discriminator_accuracy = 0.

    # initialize summary writer
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print(
            'Epoch #' + str(epoch) +
            '..........................................................................'
        )

        for iteration in range(int(len(paired_mnist) / FLAGS.batch_size)):
            # A. run the auto-encoder reconstruction
            image_batch_1, image_batch_2, labels_batch_1 = next(loader)

            auto_encoder_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            # Encoder returns (varying factor nv, common factor nc); the
            # common factors are swapped between the paired images before
            # decoding, forcing nc to carry only shared content.
            nv_1, nc_1 = encoder(Variable(X_1))
            nv_2, nc_2 = encoder(Variable(X_2))

            reconstructed_X_1 = decoder(nv_1, nc_2)
            reconstructed_X_2 = decoder(nv_2, nc_1)

            reconstruction_error_1 = mse_loss(reconstructed_X_1, Variable(X_1))
            reconstruction_error_1.backward(retain_graph=True)

            reconstruction_error_2 = mse_loss(reconstructed_X_2, Variable(X_2))
            reconstruction_error_2.backward()

            reconstruction_error = reconstruction_error_1 + reconstruction_error_2

            # Optimizer step is gated, so ablation runs can freeze parts.
            if FLAGS.train_auto_encoder:
                auto_encoder_optimizer.step()

            # B. run the adversarial part of the architecture

            # B. a) run the discriminator
            for i in range(FLAGS.discriminator_times):
                discriminator_optimizer.zero_grad()

                # train discriminator on real data
                domain_labels.fill_(real_domain_labels)

                image_batch_1, image_batch_2, labels_batch_1 = next(loader)

                X_1.copy_(image_batch_1)
                X_2.copy_(image_batch_2)

                real_output = discriminator(Variable(X_1), Variable(X_2))

                # disc_coef scales the discriminator's loss contribution.
                discriminator_real_error = FLAGS.disc_coef * cross_entropy_loss(
                    real_output, Variable(domain_labels))
                discriminator_real_error.backward()

                # train discriminator on fake data
                domain_labels.fill_(fake_domain_labels)

                image_batch_3, _, labels_batch_3 = next(loader)
                X_3.copy_(image_batch_3)

                nv_3, nc_3 = encoder(Variable(X_3))

                # reconstruction is taking common factor from X_1 and varying factor from X_3
                reconstructed_X_3_1 = decoder(nv_3, encoder(Variable(X_1))[1])

                fake_output = discriminator(Variable(X_1), reconstructed_X_3_1)

                discriminator_fake_error = FLAGS.disc_coef * cross_entropy_loss(
                    fake_output, Variable(domain_labels))
                discriminator_fake_error.backward()

                # total discriminator error
                discriminator_error = discriminator_real_error + discriminator_fake_error

                # calculate discriminator accuracy for this step
                target_true_labels = torch.cat((torch.ones(
                    FLAGS.batch_size), torch.zeros(FLAGS.batch_size)),
                                               dim=0)
                if FLAGS.cuda:
                    target_true_labels = target_true_labels.cuda()

                discriminator_predictions = torch.cat(
                    (real_output, fake_output), dim=0)
                _, discriminator_predictions = torch.max(
                    discriminator_predictions, 1)

                discriminator_accuracy = (discriminator_predictions.data
                                          == target_true_labels.long()).sum(
                                          ).item() / (FLAGS.batch_size * 2)

                # Only step while the discriminator is not yet too strong,
                # keeping the adversarial game balanced.
                if discriminator_accuracy < FLAGS.discriminator_limiting_accuracy and FLAGS.train_discriminator:
                    discriminator_optimizer.step()

            # B. b) run the generator
            for i in range(FLAGS.generator_times):

                generator_optimizer.zero_grad()

                image_batch_1, _, labels_batch_1 = next(loader)
                image_batch_3, __, labels_batch_3 = next(loader)

                # Generator is trained against "real" labels so that fooling
                # the discriminator lowers the loss.
                domain_labels.fill_(real_domain_labels)
                X_1.copy_(image_batch_1)
                X_3.copy_(image_batch_3)

                nv_3, nc_3 = encoder(Variable(X_3))

                # reconstruction is taking common factor from X_1 and varying factor from X_3
                reconstructed_X_3_1 = decoder(nv_3, encoder(Variable(X_1))[1])

                output = discriminator(Variable(X_1), reconstructed_X_3_1)

                generator_error = FLAGS.gen_coef * cross_entropy_loss(
                    output, Variable(domain_labels))
                generator_error.backward()

                if FLAGS.train_generator:
                    generator_optimizer.step()

            # print progress after 10 iterations
            if (iteration + 1) % 10 == 0:
                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))

                print('')
                print('Reconstruction loss: ' +
                      str(reconstruction_error.data.storage().tolist()[0]))
                print('Generator loss: ' +
                      str(generator_error.data.storage().tolist()[0]))

                print('')
                print('Discriminator loss: ' +
                      str(discriminator_error.data.storage().tolist()[0]))
                print('Discriminator accuracy: ' + str(discriminator_accuracy))

                print('..........')

            # write to log
            with open(FLAGS.log_file, 'a') as log:
                log.write('{0}\t{1}\t{2}\t{3}\t{4}\t{5}\n'.format(
                    epoch, iteration,
                    reconstruction_error.data.storage().tolist()[0],
                    generator_error.data.storage().tolist()[0],
                    discriminator_error.data.storage().tolist()[0],
                    discriminator_accuracy))

            # write to tensorboard
            writer.add_scalar(
                'Reconstruction loss',
                reconstruction_error.data.storage().tolist()[0],
                epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) +
                iteration)
            writer.add_scalar(
                'Generator loss',
                generator_error.data.storage().tolist()[0],
                epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) +
                iteration)
            writer.add_scalar(
                'Discriminator loss',
                discriminator_error.data.storage().tolist()[0],
                epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) +
                iteration)

        # save model after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(),
                       os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(),
                       os.path.join('checkpoints', FLAGS.decoder_save))
            torch.save(discriminator.state_dict(),
                       os.path.join('checkpoints', FLAGS.discriminator_save))
            """
            save reconstructed images and style swapped image generations to check progress
            """
            image_batch_1, image_batch_2, labels_batch_1 = next(loader)
            image_batch_3, _, __ = next(loader)

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)
            X_3.copy_(image_batch_3)

            nv_1, nc_1 = encoder(Variable(X_1))
            nv_2, nc_2 = encoder(Variable(X_2))
            nv_3, nc_3 = encoder(Variable(X_3))

            reconstructed_X_1 = decoder(nv_1, nc_2)
            reconstructed_X_3_2 = decoder(nv_3, nc_2)

            # save input image batch
            # NOTE(review): the channel triplication below suggests the data
            # is single-channel (MNIST) displayed as RGB — confirm.
            image_batch = np.transpose(X_1.cpu().numpy(), (0, 2, 3, 1))
            image_batch = np.concatenate(
                (image_batch, image_batch, image_batch), axis=3)
            imshow_grid(image_batch, name=str(epoch) + '_original', save=True)

            # save reconstructed batch
            reconstructed_x = np.transpose(
                reconstructed_X_1.cpu().data.numpy(), (0, 2, 3, 1))
            reconstructed_x = np.concatenate(
                (reconstructed_x, reconstructed_x, reconstructed_x), axis=3)
            imshow_grid(reconstructed_x,
                        name=str(epoch) + '_target',
                        save=True)

            # save cross reconstructed batch
            style_batch = np.transpose(X_3.cpu().numpy(), (0, 2, 3, 1))
            style_batch = np.concatenate(
                (style_batch, style_batch, style_batch), axis=3)
            imshow_grid(style_batch, name=str(epoch) + '_style', save=True)

            reconstructed_style = np.transpose(
                reconstructed_X_3_2.cpu().data.numpy(), (0, 2, 3, 1))
            reconstructed_style = np.concatenate(
                (reconstructed_style, reconstructed_style,
                 reconstructed_style),
                axis=3)
            imshow_grid(reconstructed_style,
                        name=str(epoch) + '_style_target',
                        save=True)
# Exemple #14
# 0
def training_procedure(FLAGS) -> None:
    """Train a style/class-disentangled VAE on paired CIFAR images.

    Every iteration runs:
      A. a forward (auto-encoder) cycle that swaps class codes between the
         two images of a pair, optimizing KL + reconstruction losses; and
      B. a reverse cycle that decodes one style sample (drawn from the prior)
         with two different class codes and pulls the re-encoded styles
         together with an L1 loss.
    Either cycle can additionally be trained adversarially via
    ``FLAGS.forward_gan`` / ``FLAGS.reverse_gan``.  Progress is printed,
    appended as TSV to ``FLAGS.log_file`` and written to TensorBoard;
    checkpoints and sample reconstructions are saved every 5 epochs.
    """
    # ---- model definition ----
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    discriminator = Discriminator()
    discriminator.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        # NOTE(review): checkpoint resumption is deliberately disabled — the
        # two load_state_dict() calls below are unreachable dead code.
        raise Exception('This is not implemented')
        encoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))

    """
    variable definition
    """

    # Pre-allocated input buffers; each mini-batch is copy_()'d into them.
    X_1 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    X_2 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    X_3 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)

    # Refilled from N(0, 1) in every reverse cycle (sample from the style prior).
    style_latent_space = torch.FloatTensor(FLAGS.batch_size, FLAGS.style_dim)

    """
    loss definitions
    """
    cross_entropy_loss = nn.CrossEntropyLoss()  # NOTE(review): constructed but never used below
    adversarial_loss = nn.BCELoss()

    '''
    add option to run on GPU
    '''
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()
        discriminator.cuda()

        cross_entropy_loss.cuda()
        adversarial_loss.cuda()

        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()

        style_latent_space = style_latent_space.cuda()

    """
    optimizer and scheduler definition
    """
    # Note: the encoder's parameters are shared by both the auto-encoder and
    # the reverse-cycle optimizers; the decoder's by the auto-encoder and
    # "generator" optimizers.
    auto_encoder_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    reverse_cycle_optimizer = optim.Adam(
        list(encoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    generator_optimizer = optim.Adam(
        list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    discriminator_optimizer = optim.Adam(
        list(discriminator.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    # divide the learning rate by a factor of 10 after 80 epochs
    auto_encoder_scheduler = optim.lr_scheduler.StepLR(auto_encoder_optimizer, step_size=80, gamma=0.1)
    reverse_cycle_scheduler = optim.lr_scheduler.StepLR(reverse_cycle_optimizer, step_size=80, gamma=0.1)
    generator_scheduler = optim.lr_scheduler.StepLR(generator_optimizer, step_size=80, gamma=0.1)
    discriminator_scheduler = optim.lr_scheduler.StepLR(discriminator_optimizer, step_size=80, gamma=0.1)

    # Used later to define discriminator ground truths
    Tensor = torch.cuda.FloatTensor if FLAGS.cuda else torch.FloatTensor

    """
    training
    """
    if torch.cuda.is_available() and not FLAGS.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    if not os.path.exists('reconstructed_images'):
        os.makedirs('reconstructed_images')

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        # Fresh run: write the TSV header once (column set depends on GAN flags).
        with open(FLAGS.log_file, 'w') as log:
            headers = ['Epoch', 'Iteration', 'Reconstruction_loss', 'KL_divergence_loss', 'Reverse_cycle_loss']

            if FLAGS.forward_gan:
              headers.extend(['Generator_forward_loss', 'Discriminator_forward_loss'])

            if FLAGS.reverse_gan:
              headers.extend(['Generator_reverse_loss', 'Discriminator_reverse_loss'])

            log.write('\t'.join(headers) + '\n')

    # load data set and create data loader instance
    print('Loading CIFAR paired dataset...')
    paired_cifar = CIFAR_Paired(root='cifar', download=True, train=True, transform=transform_config)
    # cycle() makes the loader endless, so next(loader) never raises StopIteration.
    loader = cycle(DataLoader(paired_cifar, batch_size=FLAGS.batch_size, shuffle=True, num_workers=0, drop_last=True))

    # Save a batch of images to use for visualization
    image_sample_1, image_sample_2, _ = next(loader)
    image_sample_3, _, _ = next(loader)

    # initialize summary writer
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print('Epoch #' + str(epoch) + '..........................................................................')

        # update the learning rate scheduler
        # NOTE(review): calling scheduler.step() before the optimizers step is
        # the pre-PyTorch-1.1 ordering; newer PyTorch warns and would skip the
        # first LR value. Kept as-is to preserve the original schedule.
        auto_encoder_scheduler.step()
        reverse_cycle_scheduler.step()
        generator_scheduler.step()
        discriminator_scheduler.step()

        for iteration in range(int(len(paired_cifar) / FLAGS.batch_size)):
            # Adversarial ground truths
            valid = Variable(Tensor(FLAGS.batch_size, 1).fill_(1.0), requires_grad=False)
            fake = Variable(Tensor(FLAGS.batch_size, 1).fill_(0.0), requires_grad=False)

            # A. run the auto-encoder reconstruction
            image_batch_1, image_batch_2, _ = next(loader)

            auto_encoder_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_mu_1, style_logvar_1, class_latent_space_1 = encoder(Variable(X_1))
            style_latent_space_1 = reparameterize(training=True, mu=style_mu_1, logvar=style_logvar_1)

            # KL(q(style|x) || N(0, I)), normalized per pixel.
            kl_divergence_loss_1 = FLAGS.kl_divergence_coef * (
                - 0.5 * torch.sum(1 + style_logvar_1 - style_mu_1.pow(2) - style_logvar_1.exp())
            )
            kl_divergence_loss_1 /= (FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size)
            # retain_graph: the encoder's graph is reused by the reconstruction
            # losses below, so it must survive this backward pass.
            kl_divergence_loss_1.backward(retain_graph=True)

            style_mu_2, style_logvar_2, class_latent_space_2 = encoder(Variable(X_2))
            style_latent_space_2 = reparameterize(training=True, mu=style_mu_2, logvar=style_logvar_2)

            kl_divergence_loss_2 = FLAGS.kl_divergence_coef * (
                - 0.5 * torch.sum(1 + style_logvar_2 - style_mu_2.pow(2) - style_logvar_2.exp())
            )
            kl_divergence_loss_2 /= (FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size)
            kl_divergence_loss_2.backward(retain_graph=True)

            # Decode each style with the *other* image's class code; the pair
            # presumably shares class content, so each output is compared
            # against its style source image.
            reconstructed_X_1 = decoder(style_latent_space_1, class_latent_space_2)
            reconstructed_X_2 = decoder(style_latent_space_2, class_latent_space_1)

            reconstruction_error_1 = FLAGS.reconstruction_coef * mse_loss(reconstructed_X_1, Variable(X_1))
            reconstruction_error_1.backward(retain_graph=True)

            reconstruction_error_2 = FLAGS.reconstruction_coef * mse_loss(reconstructed_X_2, Variable(X_2))
            reconstruction_error_2.backward()

            # Divide the coefficients back out so the logged values are the
            # raw (unweighted) losses.
            reconstruction_error = (reconstruction_error_1 + reconstruction_error_2) / FLAGS.reconstruction_coef
            kl_divergence_error = (kl_divergence_loss_1 + kl_divergence_loss_2) / FLAGS.kl_divergence_coef

            auto_encoder_optimizer.step()

            # A-1. Discriminator training during forward cycle
            if FLAGS.forward_gan:
              # Training generator
              generator_optimizer.zero_grad()

              g_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), valid)
              g_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), valid)

              gen_f_loss = (g_loss_1 + g_loss_2) / 2.0
              gen_f_loss.backward()

              generator_optimizer.step()

              # Training discriminator
              discriminator_optimizer.zero_grad()

              real_loss_1 = adversarial_loss(discriminator(Variable(X_1)), valid)
              real_loss_2 = adversarial_loss(discriminator(Variable(X_2)), valid)
              fake_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), fake)
              fake_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), fake)

              dis_f_loss = (real_loss_1 + real_loss_2 + fake_loss_1 + fake_loss_2) / 4.0
              dis_f_loss.backward()

              discriminator_optimizer.step()

            # B. reverse cycle
            image_batch_1, _, __ = next(loader)
            image_batch_2, _, __ = next(loader)

            reverse_cycle_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            # One style sample from the prior, decoded with two different
            # (detached) class codes.
            style_latent_space.normal_(0., 1.)

            _, __, class_latent_space_1 = encoder(Variable(X_1))
            _, __, class_latent_space_2 = encoder(Variable(X_2))

            # detach(): gradients must not flow into the class-code branch here.
            reconstructed_X_1 = decoder(Variable(style_latent_space), class_latent_space_1.detach())
            reconstructed_X_2 = decoder(Variable(style_latent_space), class_latent_space_2.detach())

            # Re-encode the two decodings; their styles should coincide since
            # both came from the same style sample.
            style_mu_1, style_logvar_1, _ = encoder(reconstructed_X_1)
            style_latent_space_1 = reparameterize(training=False, mu=style_mu_1, logvar=style_logvar_1)

            style_mu_2, style_logvar_2, _ = encoder(reconstructed_X_2)
            style_latent_space_2 = reparameterize(training=False, mu=style_mu_2, logvar=style_logvar_2)

            reverse_cycle_loss = FLAGS.reverse_cycle_coef * l1_loss(style_latent_space_1, style_latent_space_2)
            reverse_cycle_loss.backward()
            # Unscale after backward so the logged value is the raw loss.
            reverse_cycle_loss /= FLAGS.reverse_cycle_coef

            reverse_cycle_optimizer.step()

            # B-1. Discriminator training during reverse cycle
            if FLAGS.reverse_gan:
              # Training generator
              generator_optimizer.zero_grad()

              g_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), valid)
              g_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), valid)

              gen_r_loss = (g_loss_1 + g_loss_2) / 2.0
              gen_r_loss.backward()

              generator_optimizer.step()

              # Training discriminator
              discriminator_optimizer.zero_grad()

              real_loss_1 = adversarial_loss(discriminator(Variable(X_1)), valid)
              real_loss_2 = adversarial_loss(discriminator(Variable(X_2)), valid)
              fake_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), fake)
              fake_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), fake)

              dis_r_loss = (real_loss_1 + real_loss_2 + fake_loss_1 + fake_loss_2) / 4.0
              dis_r_loss.backward()

              discriminator_optimizer.step()

            # .data.storage().tolist()[0] below is a legacy scalar read
            # (the modern equivalent is tensor.item()).
            if (iteration + 1) % 10 == 0:
                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))

                print('')
                print('Reconstruction loss: ' + str(reconstruction_error.data.storage().tolist()[0]))
                print('KL-Divergence loss: ' + str(kl_divergence_error.data.storage().tolist()[0]))
                print('Reverse cycle loss: ' + str(reverse_cycle_loss.data.storage().tolist()[0]))

                if FLAGS.forward_gan:
                  print('Generator F loss: ' + str(gen_f_loss.data.storage().tolist()[0]))
                  print('Discriminator F loss: ' + str(dis_f_loss.data.storage().tolist()[0]))

                if FLAGS.reverse_gan:
                  print('Generator R loss: ' + str(gen_r_loss.data.storage().tolist()[0]))
                  print('Discriminator R loss: ' + str(dis_r_loss.data.storage().tolist()[0]))

            # write to log
            with open(FLAGS.log_file, 'a') as log:
                row = []

                row.append(epoch)
                row.append(iteration)
                row.append(reconstruction_error.data.storage().tolist()[0])
                row.append(kl_divergence_error.data.storage().tolist()[0])
                row.append(reverse_cycle_loss.data.storage().tolist()[0])

                if FLAGS.forward_gan:
                  row.append(gen_f_loss.data.storage().tolist()[0])
                  row.append(dis_f_loss.data.storage().tolist()[0])

                if FLAGS.reverse_gan:
                  row.append(gen_r_loss.data.storage().tolist()[0])
                  row.append(dis_r_loss.data.storage().tolist()[0])

                row = [str(x) for x in row]
                log.write('\t'.join(row) + '\n')

            # write to tensorboard
            writer.add_scalar('Reconstruction loss', reconstruction_error.data.storage().tolist()[0],
                              epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)
            writer.add_scalar('KL-Divergence loss', kl_divergence_error.data.storage().tolist()[0],
                              epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)
            writer.add_scalar('Reverse cycle loss', reverse_cycle_loss.data.storage().tolist()[0],
                              epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)

            if FLAGS.forward_gan:
              writer.add_scalar('Generator F loss', gen_f_loss.data.storage().tolist()[0],
                                epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)
              writer.add_scalar('Discriminator F loss', dis_f_loss.data.storage().tolist()[0],
                                epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)

            if FLAGS.reverse_gan:
              writer.add_scalar('Generator R loss', gen_r_loss.data.storage().tolist()[0],
                                epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)
              writer.add_scalar('Discriminator R loss', dis_r_loss.data.storage().tolist()[0],
                                epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)

        # save model after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(), os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(), os.path.join('checkpoints', FLAGS.decoder_save))

            """
            save reconstructed images and style swapped image generations to check progress
            """

            X_1.copy_(image_sample_1)
            X_2.copy_(image_sample_2)
            X_3.copy_(image_sample_3)

            style_mu_1, style_logvar_1, _ = encoder(Variable(X_1))
            _, __, class_latent_space_2 = encoder(Variable(X_2))
            style_mu_3, style_logvar_3, _ = encoder(Variable(X_3))

            style_latent_space_1 = reparameterize(training=False, mu=style_mu_1, logvar=style_logvar_1)
            style_latent_space_3 = reparameterize(training=False, mu=style_mu_3, logvar=style_logvar_3)

            reconstructed_X_1_2 = decoder(style_latent_space_1, class_latent_space_2)
            reconstructed_X_3_2 = decoder(style_latent_space_3, class_latent_space_2)

            # save input image batch
            # (NHWC for plotting; grayscale images are tripled to fake RGB)
            image_batch = np.transpose(X_1.cpu().numpy(), (0, 2, 3, 1))
            if FLAGS.num_channels == 1:
              image_batch = np.concatenate((image_batch, image_batch, image_batch), axis=3)
            imshow_grid(image_batch, name=str(epoch) + '_original', save=True)

            # save reconstructed batch
            reconstructed_x = np.transpose(reconstructed_X_1_2.cpu().data.numpy(), (0, 2, 3, 1))
            if FLAGS.num_channels == 1:
              reconstructed_x = np.concatenate((reconstructed_x, reconstructed_x, reconstructed_x), axis=3)
            imshow_grid(reconstructed_x, name=str(epoch) + '_target', save=True)

            style_batch = np.transpose(X_3.cpu().numpy(), (0, 2, 3, 1))
            if FLAGS.num_channels == 1:
              style_batch = np.concatenate((style_batch, style_batch, style_batch), axis=3)
            imshow_grid(style_batch, name=str(epoch) + '_style', save=True)

            # save style swapped reconstructed batch
            reconstructed_style = np.transpose(reconstructed_X_3_2.cpu().data.numpy(), (0, 2, 3, 1))
            if FLAGS.num_channels == 1:
              reconstructed_style = np.concatenate((reconstructed_style, reconstructed_style, reconstructed_style), axis=3)
            imshow_grid(reconstructed_style, name=str(epoch) + '_style_target', save=True)
def init_training(args):
    """Initialize the data loaders, the networks, the optimizers and the loss functions.

    Args:
        args: parsed configuration namespace providing dataset paths, model
            and optimizer hyper-parameters, and the ``use_memory`` /
            ``start_epoch`` switches.

    Returns:
        When ``args.use_memory`` is truthy:
            (global_step, device, data_loaders, mem, feature_integrator,
             generator, discriminator, optimizers, losses)
        Otherwise:
            (global_step, device, data_loaders, generator, discriminator,
             optimizers, losses)
    """
    datasets = dict()
    datasets['train'] = customed_dataset(img_path = args.train_data_path, img_size = args.img_size, km_file_path = args.km_file_path)
    datasets['val'] = customed_dataset(img_path = args.val_data_path, img_size = args.img_size,km_file_path = args.km_file_path)
    for phase in ['train', 'val']:
        print('{} dataset len: {}'.format(phase, len(datasets[phase])))

    # define loaders
    data_loaders = {
        'train': DataLoader(datasets['train'], batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers),
        'val': DataLoader(datasets['val'], batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
    }

    # check CUDA availability and set device
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('Use GPU: {}'.format(str(device) != 'cpu'))

    # set up models (the external memory network and its feature integrator
    # exist only when use_memory is enabled)
    if args.use_memory:
        mem = Memory_Network(mem_size = args.mem_size, color_feat_dim = args.color_feat_dim, spatial_feat_dim = args.spatial_feat_dim, top_k = args.top_k, alpha = args.alpha).to(device)
        feature_integrator = Feature_Integrator(3, 1, 200).to(device)
    generator = Generator(args.color_feat_dim, args.img_size, args.gen_norm).to(device)
    discriminator = Discriminator(args.color_feat_dim, args.img_size, args.dis_norm).to(device)

    # initialize weights
    if args.apply_weight_init:
        generator.apply(weights_init_normal)
        discriminator.apply(weights_init_normal)

    # set networks as training mode
    generator = generator.train()
    discriminator = discriminator.train()
    if args.use_memory:
        mem = mem.train()
        feature_integrator = feature_integrator.train()

    # adam optimizer
    if args.use_memory:
        optimizers = {
            'gen': torch.optim.Adam(generator.parameters(), lr=args.base_lr_gen, betas=(0.5, 0.999)),
            'disc': torch.optim.Adam(discriminator.parameters(), lr=args.base_lr_disc, betas=(0.5, 0.999)),
            'mem': torch.optim.Adam(mem.parameters(), lr = args.base_lr_mem),
            'feat': torch.optim.Adam(feature_integrator.parameters(), lr = args.base_lr_feat)
        }
    else:
        optimizers = {
            'gen': torch.optim.Adam(generator.parameters(), lr=args.base_lr_gen),
            'disc': torch.optim.Adam(discriminator.parameters(), lr=args.base_lr_disc),
        }

    # losses
    losses = {
        'l1': torch.nn.L1Loss(reduction='mean'),
        'disc': torch.nn.BCEWithLogitsLoss(reduction='mean'),
        'smoothl1': torch.nn.SmoothL1Loss(reduction='mean'),
        'KLD': torch.nn.KLDivLoss(reduction='batchmean')
    }

    # make save dir, if it does not exists
    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    # load weights if the training is not starting from the beginning
    global_step = args.start_epoch * len(data_loaders['train']) if args.start_epoch > 0 else 0
    if args.start_epoch > 0:

        generator.load_state_dict(torch.load(
            os.path.join(args.save_path, 'checkpoint_ep{}_gen.pt'.format(args.start_epoch - 1)),
            map_location=device
        ))
        discriminator.load_state_dict(torch.load(
            os.path.join(args.save_path, 'checkpoint_ep{}_disc.pt'.format(args.start_epoch - 1)),
            map_location=device
        ))

        # Bug fix: only restore the memory network / feature integrator when
        # they were created above. Previously this ran unconditionally and
        # raised a NameError when resuming with use_memory disabled.
        if args.use_memory:
            mem_checkpoint = torch.load(os.path.join(args.save_path, 'checkpoint_ep{}_mem.pt'.format(args.start_epoch - 1)), map_location=device)
            mem.load_state_dict(mem_checkpoint['mem_model'])
            # NOTE(review): 'sptial_key' (sic) matches the attribute name this
            # checkpoint's Memory_Network presumably uses — confirm before renaming.
            mem.sptial_key = mem_checkpoint['mem_key']
            mem.color_value = mem_checkpoint['mem_value']
            mem.age = mem_checkpoint['mem_age']
            mem.img_id = mem_checkpoint['img_id']

            feature_integrator.load_state_dict(torch.load(
                os.path.join(args.save_path, 'checkpoint_ep{}_feat.pt'.format(args.start_epoch - 1)),
                map_location=device
            ))

    if args.use_memory:
        return global_step, device, data_loaders, mem, feature_integrator, generator, discriminator, optimizers, losses
    else:
        return global_step, device, data_loaders, generator, discriminator, optimizers, losses
Exemple #16
0
 def get_models(self):
     """Build a weight-initialized generator/discriminator pair on self.device."""
     nets = [net.to(self.device) for net in (Generator(), Discriminator())]
     for net in nets:
         net.apply(weights_init)
     gen, disc = nets
     return gen, disc
Exemple #17
0
def main(args, dataloader):
    # define the networks
    netG = Generator(ngf=args.ngf, nz=args.nz, nc=args.nc).cuda()
    netG.apply(weight_init)
    print(netG)

    netD = Discriminator(ndf=args.ndf, nc=args.nc, nz=args.nz).cuda()
    netD.apply(weight_init)
    print(netD)

    netE = Encoder(nc=args.nc, ngf=args.ngf, nz=args.nz).cuda()
    netE.apply(weight_init)
    print(netE)

    # define the loss criterion
    criterion = nn.BCELoss()

    # define the ground truth labels.
    real_label = 1  # for the real pair
    fake_label = 0  # for the fake pair

    # define the optimizers, one for each network
    netD_optimizer = optim.Adam(netD.parameters(),
                                lr=args.lr,
                                betas=(0.5, 0.999))
    netG_optimizer = optim.Adam([{
        'params': netG.parameters()
    }, {
        'params': netE.parameters()
    }],
                                lr=args.lr,
                                betas=(0.5, 0.999))

    # Training loop
    iters = 0

    for epoch in range(args.num_epochs):
        # iterate through the dataloader
        for i, data in enumerate(dataloader, 0):
            real_images = data[0].cuda()
            bs = real_images.shape[0]

            noise1 = torch.Tensor(real_images.size()).normal_(
                0, 0.1 * (args.num_epochs - epoch) / args.num_epochs).cuda()
            noise2 = torch.Tensor(real_images.size()).normal_(
                0, 0.1 * (args.num_epochs - epoch) / args.num_epochs).cuda()

            # get the output from the encoder
            z_real = netE(real_images).view(bs, -1)
            mu, sigma = z_real[:, :args.nz], z_real[:, args.nz:]
            log_sigma = torch.exp(sigma)
            epsilon = torch.randn(bs, args.nz).cuda()
            # reparameterization trick
            output_z = mu + epsilon * log_sigma
            output_z = output_z.view(bs, -1, 1, 1)

            # get the output from the generator
            z_fake = torch.randn(bs, args.nz, 1, 1).cuda()
            d_fake = netG(z_fake)

            # get the output from the discriminator for the real pair
            out_real_pair = netD(real_images + noise1, output_z)

            # get the output from the discriminator for the fake pair
            out_fake_pair = netD(d_fake + noise2, z_fake)

            real_labels = torch.full((bs, ), real_label).cuda()
            fake_labels = torch.full((bs, ), fake_label).cuda()

            # compute the losses
            d_loss = criterion(out_real_pair, real_labels) + criterion(
                out_fake_pair, fake_labels)
            g_loss = criterion(out_real_pair, fake_labels) + criterion(
                out_fake_pair, real_labels)

            # update weights
            if g_loss.item() < 3.5:
                netD_optimizer.zero_grad()
                d_loss.backward(retain_graph=True)
                netD_optimizer.step()

            netG_optimizer.zero_grad()
            g_loss.backward()
            netG_optimizer.step()

            # print the training losses
            if iters % 10 == 0:
                print(
                    '[%3d/%d][%3d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x, z): %.4f\tD(G(z), z): %.4f'
                    %
                    (epoch, args.num_epochs, i, len(dataloader), d_loss.item(),
                     g_loss.item(), out_real_pair.mean().item(),
                     out_fake_pair.mean().item()))

            # visualize the samples generated by the G.
            if iters % 500 == 0:
                out_dir = os.path.join(args.log_dir, args.run_name, 'out/')
                os.makedirs(out_dir, exist_ok=True)
                save_image(d_fake.cpu()[:64, ],
                           os.path.join(out_dir,
                                        str(iters).zfill(7) + '.png'),
                           nrow=8,
                           normalize=True)
                # save reconstructions
                recons_dir = os.path.join(args.log_dir, args.run_name,
                                          'recons/')
                os.makedirs(recons_dir, exist_ok=True)
                save_image(torch.cat(
                    [real_images.cpu()[:8],
                     d_fake.cpu()[:8, ]], dim=3),
                           os.path.join(recons_dir,
                                        str(iters).zfill(7) + '.png'),
                           nrow=1,
                           normalize=True)

            iters += 1

        # save weights
        save_dir = os.path.join(args.log_dir, args.run_name, 'weights')
        os.makedirs(save_dir, exist_ok=True)
        save_weights(netG, './%s/netG.pth' % (save_dir))
        save_weights(netE, './%s/netE.pth' % (save_dir))