class Trainer(nn.Module):
    """Unpaired domain-translation trainer (MUNIT-style).

    Holds one VAE generator and one discriminator per domain (a and b),
    plus their shared Adam optimizers and LR schedulers. `gen_update` and
    `dis_update` each perform a single optimization step; `forward`
    translates a batch in both directions using fixed sampling noise.
    """

    def __init__(self, hyperparameters):
        super(Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks.
        self.trait_dim = hyperparameters['gen']['trait_dim']
        # auto-encoder for domain a
        self.gen_a = VAEGen(hyperparameters['input_dim'],
                            hyperparameters['basis_encoder_dims'],
                            hyperparameters['trait_encoder_dims'],
                            hyperparameters['decoder_dims'], self.trait_dim)
        # auto-encoder for domain b (same architecture, separate weights)
        self.gen_b = VAEGen(hyperparameters['input_dim'],
                            hyperparameters['basis_encoder_dims'],
                            hyperparameters['trait_encoder_dims'],
                            hyperparameters['decoder_dims'], self.trait_dim)
        # discriminator for domain a
        self.dis_a = Discriminator(hyperparameters['input_dim'],
                                   hyperparameters['dis_dims'], 1)
        # discriminator for domain b
        self.dis_b = Discriminator(hyperparameters['input_dim'],
                                   hyperparameters['dis_dims'], 1)
        # Fix the noise used in sampling (8 fixed samples for visualization).
        # NOTE(review): gen_update/dis_update feed decode() noise of shape
        # (batch, trait_dim); confirm VAEGen.decode also accepts the trailing
        # (1, 1) dims used here.
        self.trait_a = torch.randn(8, self.trait_dim, 1, 1)
        self.trait_b = torch.randn(8, self.trait_dim, 1, 1)
        # Setup the optimizers: one over both discriminators, one over both
        # generators. Frozen parameters are excluded.
        dis_params = list(self.dis_a.parameters()) + \
            list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + \
            list(self.gen_b.parameters())
        # (Removed a leftover debug loop that printed every generator
        # parameter shape at construction time.)
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr, weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr, weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        # Network weight initialization: generic init for the whole module,
        # then Gaussian init for the four sub-networks.
        self.apply(weights_init(hyperparameters['init']))
        self.gen_a.apply(weights_init('gaussian'))
        self.gen_b.apply(weights_init('gaussian'))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

    def recon_criterion(self, input, target):
        """L1 reconstruction loss (mean absolute error)."""
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        """Translate x_a -> domain b and x_b -> domain a with fixed noise.

        Returns:
            (x_ab, x_ba): translations of x_a into b and x_b into a.
        """
        self.eval()
        trait_a = Variable(self.trait_a)
        trait_b = Variable(self.trait_b)
        basis_a, trait_a_fake = self.gen_a.encode(x_a)
        basis_b, trait_b_fake = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(basis_b, trait_a)
        x_ab = self.gen_b.decode(basis_a, trait_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        """One generator step.

        Combines within-domain reconstruction, latent (basis/trait)
        reconstruction after cross-domain translation, optional image
        cycle-consistency, and adversarial losses, weighted by the
        corresponding entries of `hyperparameters`.
        """
        self.gen_opt.zero_grad()
        trait_a = Variable(torch.randn(x_a.size(0), self.trait_dim))
        trait_b = Variable(torch.randn(x_b.size(0), self.trait_dim))
        # encode
        basis_a, trait_a_prime = self.gen_a.encode(x_a)
        basis_b, trait_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(basis_a, trait_a_prime)
        x_b_recon = self.gen_b.decode(basis_b, trait_b_prime)
        # decode (cross domain) with randomly drawn traits
        x_ba = self.gen_a.decode(basis_b, trait_a)
        x_ab = self.gen_b.decode(basis_a, trait_b)
        # encode again: translated images should recover basis and trait
        basis_b_recon, trait_a_recon = self.gen_a.encode(x_ba)
        basis_a_recon, trait_b_recon = self.gen_b.encode(x_ab)
        # decode again, only when the cycle-consistency weight is active
        x_aba = self.gen_a.decode(
            basis_a_recon,
            trait_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(
            basis_b_recon,
            trait_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        # reconstruction losses
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_trait_a = self.recon_criterion(
            trait_a_recon, trait_a)
        self.loss_gen_recon_trait_b = self.recon_criterion(
            trait_b_recon, trait_b)
        self.loss_gen_recon_basis_a = self.recon_criterion(
            basis_a_recon, basis_a)
        self.loss_gen_recon_basis_b = self.recon_criterion(
            basis_b_recon, basis_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss: generators try to fool the target-domain discriminator
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # total weighted loss
        self.loss_gen_total = hyperparameters[
            'gan_w'] * self.loss_gen_adv_a + \
            hyperparameters['gan_w'] * self.loss_gen_adv_b + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_trait_w'] * self.loss_gen_recon_trait_a + \
            hyperparameters['recon_basis_w'] * self.loss_gen_recon_basis_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_trait_w'] * self.loss_gen_recon_trait_b + \
            hyperparameters['recon_basis_w'] * self.loss_gen_recon_basis_b + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def dis_update(self, x_a, x_b, hyperparameters):
        """One discriminator step on real batches vs. cross-domain fakes."""
        self.dis_opt.zero_grad()
        trait_a = Variable(torch.randn(x_a.size(0), self.trait_dim))
        trait_b = Variable(torch.randn(x_b.size(0), self.trait_dim))
        # encode (traits are discarded; fresh random traits are used)
        basis_a, _ = self.gen_a.encode(x_a)
        basis_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(basis_b, trait_a)
        x_ab = self.gen_b.decode(basis_a, trait_b)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba, x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab, x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * \
            self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Advance both LR schedulers, if configured."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore networks, optimizers and schedulers from `checkpoint_dir`.

        Returns:
            The iteration number parsed from the generator checkpoint name.
        """
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        # Iteration count is encoded in the filename ('gen_%08d.pt').
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers at the restored iteration
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Save generators, discriminators, and optimizers to `snapshot_dir`."""
        gen_name = os.path.join(snapshot_dir,
                                'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir,
                                'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
def main(args, dataloader):
    """Train a DCGAN on `dataloader` and periodically save sample grids.

    Alternates one discriminator step (real + fake batches, with decaying
    instance noise) and one generator step per mini-batch. Every 1000
    iterations, images for a fixed noise batch and a latent interpolation
    are written under `args.log_dir/args.run_name`.
    """
    # define the G and D
    netG = DCGenerator(nz=args.nz, ngf=args.ngf, nc=args.nc).cuda()
    netG.apply(weight_init)
    print(netG)
    netD = Discriminator(nc=args.nc, ndf=args.ndf).cuda()
    netD.apply(weight_init)
    print(netD)
    # define the loss criterion
    criterion = nn.BCELoss()
    # sample a fixed noise vector that will be used to visualize the
    # training progress
    fixed_noise = torch.randn(64, args.nz, 1, 1).cuda()
    # define the ground truth labels
    real_labels = 1  # for the real images
    fake_labels = 0  # for the fake images
    # define the optimizers, one for each network
    netD_optimizer = optim.Adam(params=netD.parameters(),
                                lr=args.lr,
                                betas=(0.5, 0.999))
    netG_optimizer = optim.Adam(params=netG.parameters(),
                                lr=args.lr,
                                betas=(0.5, 0.999))
    # sample two fixed noise vectors and linearly interpolate between them;
    # samples for the interpolated vectors visualize the latent space
    z_1 = torch.randn(1, args.nz, 1, 1)
    z_2 = torch.randn(1, args.nz, 1, 1)
    fixed_interpolate = []
    for i in range(64):
        lambda_interp = i / 63
        z_interp = z_1 * (1 - lambda_interp) + lambda_interp * z_2
        fixed_interpolate.append(z_interp)
    fixed_interpolate = torch.cat(fixed_interpolate, dim=0).cuda()

    # Training loop
    iters = 0
    for epoch in range(args.num_epochs):
        for i, data in enumerate(dataloader, 0):
            ## Discriminator training ##
            # maximize log(D(x)) + log(1 - D(G(z))): gradients are
            # accumulated from the real batch and then the fake batch
            # before a single optimizer step. G stays frozen here.
            netD.train()
            netD.zero_grad()
            real_images = data[0].cuda()
            bs = real_images.shape[0]
            # BCELoss requires float targets; an integer fill value would
            # otherwise produce an integer tensor (bug fixed here).
            label = torch.full((bs, ), real_labels,
                               dtype=torch.float).cuda()
            # instance noise whose std decays linearly over the epochs
            noise_1 = torch.Tensor(real_images.shape).normal_(
                0, 0.1 * (args.num_epochs - epoch) / args.num_epochs).cuda()
            output = netD(real_images + noise_1).view(-1)
            # push D's output for real images towards 1
            errD_real = criterion(output, label)
            errD_real.backward()
            # track D outputs for real images
            D_x = output.mean().item()

            # train D with fake images
            noise = torch.randn(bs, args.nz, 1, 1).cuda()
            fake_images = netG(noise)
            label.fill_(fake_labels)
            # detach: no gradients through G during the D update
            noise_2 = torch.Tensor(real_images.shape).normal_(
                0, 0.1 * (args.num_epochs - epoch) / args.num_epochs).cuda()
            output = netD(fake_images.detach() + noise_2).view(-1)
            # push D's output for fake images towards 0
            errD_fake = criterion(output, label)
            errD_fake.backward()
            errD = (errD_real + errD_fake)
            # track D outputs for fake images
            D_G_x_1 = output.mean().item()
            # update the D weights with the accumulated gradients
            netD_optimizer.step()

            ## Generator training ##
            # The saturating loss log(1 - D(G(z))) gives no gradient early
            # in training, so we maximize log(D(G(z))) instead; D is fixed.
            netG.train()
            netG.zero_grad()
            # real_labels: G wants the fakes to be classified as real
            label.fill_(real_labels)
            output = netD(fake_images + noise_2).view(-1)
            # push D's output for fake images towards 1 (through G)
            errG = criterion(output, label)
            errG.backward()
            # track the outputs for fake images
            D_G_x_2 = output.mean().item()
            netG_optimizer.step()

            # print the training losses
            if iters % 50 == 0:
                print(
                    '[%3d/%d][%3d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                    % (epoch, args.num_epochs, i, len(dataloader),
                       errD.item(), errG.item(), D_x, D_G_x_1, D_G_x_2))

            # visualize the samples generated by the G
            if (iters % 1000 == 0):
                out_dir = os.path.join(args.log_dir, args.run_name, 'out/')
                os.makedirs(out_dir, exist_ok=True)
                interp_dir = os.path.join(args.log_dir, args.run_name,
                                          'interpolate/')
                os.makedirs(interp_dir, exist_ok=True)
                netG.eval()
                with torch.no_grad():
                    fake_fixed = netG(fixed_noise).cpu()
                    save_image(fake_fixed,
                               os.path.join(out_dir,
                                            str(iters).zfill(7) + '.png'),
                               normalize=True)
                    interp_fixed = netG(fixed_interpolate).cpu()
                    save_image(interp_fixed,
                               os.path.join(interp_dir,
                                            str(iters).zfill(7) + '.png'),
                               normalize=True)
            iters += 1
class BiGAN(object):
    """Bidirectional GAN: a Generator G(z), an Encoder E(x), and a joint
    Discriminator D(x, z).

    D learns to tell real pairs (x, E(x)) from fake pairs (G(z), z); the
    Generator and Encoder are trained jointly against it.
    """

    def __init__(self, args):
        self.z_dim = args.z_dim
        self.decay_rate = args.decay_rate
        self.learning_rate = args.learning_rate
        self.model_name = args.model_name
        self.batch_size = args.batch_size
        # initialize networks
        self.Generator = Generator(self.z_dim).cuda()
        self.Encoder = Encoder(self.z_dim).cuda()
        self.Discriminator = Discriminator().cuda()
        # one optimizer over G and E jointly, one for D
        self.optimizer_G_E = torch.optim.Adam(
            list(self.Generator.parameters()) +
            list(self.Encoder.parameters()),
            lr=self.learning_rate,
            betas=(0.5, 0.999))
        self.optimizer_D = torch.optim.Adam(self.Discriminator.parameters(),
                                            lr=self.learning_rate,
                                            betas=(0.5, 0.999))
        # initialize network weights
        self.Generator.apply(weights_init)
        self.Encoder.apply(weights_init)
        self.Discriminator.apply(weights_init)

    def train(self, data):
        """One training step on a batch `data` of real samples.

        Fixes from the previous version: `z_real`/`x_fake`/`z_fake` were
        referenced without `self.` (NameError); the BCE losses were never
        computed — the loss *module* itself was assigned and `.backward()`
        called on it; and the generator noise was never moved to CUDA.
        """
        self.Generator.train()
        self.Encoder.train()
        self.Discriminator.train()
        self.optimizer_G_E.zero_grad()
        self.optimizer_D.zero_grad()
        criterion = nn.BCELoss()

        # fake z -> fake x through the generator
        self.z_fake = torch.randn((self.batch_size, self.z_dim)).cuda()
        self.x_fake = self.Generator(self.z_fake)
        # real x -> real z through the encoder
        self.z_real = self.Encoder(data)

        # ---- Discriminator update (G and E frozen via detach) ----
        self.out_real = self.Discriminator(data, self.z_real.detach())
        self.out_fake = self.Discriminator(self.x_fake.detach(), self.z_fake)
        self.D_loss = criterion(self.out_real,
                                torch.ones_like(self.out_real)) + \
            criterion(self.out_fake, torch.zeros_like(self.out_fake))
        self.D_loss.backward()
        self.optimizer_D.step()

        # ---- Generator/Encoder update (fresh D pass, gradients flow
        # through G and E; targets are flipped to fool D) ----
        out_real_ge = self.Discriminator(data, self.z_real)
        out_fake_ge = self.Discriminator(self.x_fake, self.z_fake)
        self.G_E_loss = criterion(out_real_ge,
                                  torch.zeros_like(out_real_ge)) + \
            criterion(out_fake_ge, torch.ones_like(out_fake_ge))
        self.G_E_loss.backward()
        self.optimizer_G_E.step()
class LSGANs_Trainer(nn.Module):
    """Trainer for an interpolation-based unpaired translation model.

    One shared encoder/decoder pair, two discriminators (one per domain),
    and two style interpolators (a->b and b->a), each with its own Adam
    optimizer and LR scheduler.
    """

    def __init__(self, hyperparameters):
        super(LSGANs_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.encoder = Encoder(hyperparameters['input_dim_a'],
                               hyperparameters['gen'])
        self.decoder = Decoder(hyperparameters['input_dim_a'],
                               hyperparameters['gen'])
        self.dis_a = Discriminator()
        self.dis_b = Discriminator()
        self.interp_net_ab = Interpolator()
        self.interp_net_ba = Interpolator()
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # Setup the optimizers (one per sub-network)
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        enc_params = list(self.encoder.parameters())
        dec_params = list(self.decoder.parameters())
        dis_a_params = list(self.dis_a.parameters())
        dis_b_params = list(self.dis_b.parameters())
        interpolator_ab_params = list(self.interp_net_ab.parameters())
        interpolator_ba_params = list(self.interp_net_ba.parameters())
        self.enc_opt = torch.optim.Adam(
            [p for p in enc_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dec_opt = torch.optim.Adam(
            [p for p in dec_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_a_opt = torch.optim.Adam(
            [p for p in dis_a_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_b_opt = torch.optim.Adam(
            [p for p in dis_b_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.interp_ab_opt = torch.optim.Adam(
            [p for p in interpolator_ab_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.interp_ba_opt = torch.optim.Adam(
            [p for p in interpolator_ba_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.enc_scheduler = get_scheduler(self.enc_opt, hyperparameters)
        self.dec_scheduler = get_scheduler(self.dec_opt, hyperparameters)
        self.dis_a_scheduler = get_scheduler(self.dis_a_opt, hyperparameters)
        self.dis_b_scheduler = get_scheduler(self.dis_b_opt, hyperparameters)
        self.interp_ab_scheduler = get_scheduler(self.interp_ab_opt,
                                                 hyperparameters)
        self.interp_ba_scheduler = get_scheduler(self.interp_ba_opt,
                                                 hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed (frozen, used for perceptual loss)
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False
        self.total_loss = 0
        self.best_iter = 0
        self.perceptural_loss = Perceptural_loss()

    def recon_criterion(self, input, target):
        """L1 reconstruction loss (mean absolute error)."""
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        """Translate x_a -> domain b and x_b -> domain a.

        Fixes from the previous version: the interpolated styles were bound
        to `s_ab_interp`/`s_ba_interp` but decoded via undefined
        `s_a_interp`/`s_b_interp` (NameError), `self.decoder` was misspelled
        `selfdecoder`, and `self.v` was read before ever being assigned.
        """
        self.eval()
        c_a, s_a_fake = self.encoder(x_a)
        c_b, s_b_fake = self.encoder(x_b)
        # decode (cross domain); an all-ones interpolation weight matches
        # what dis_update/gen_update use for a full style swap
        self.v = torch.ones(s_a_fake.size())
        s_b_interp = self.interp_net_ab(s_a_fake, s_b_fake, self.v)
        s_a_interp = self.interp_net_ba(s_b_fake, s_a_fake, self.v)
        x_ba = self.decoder(c_b, s_a_interp)
        x_ab = self.decoder(c_a, s_b_interp)
        self.train()
        return x_ab, x_ba

    def zero_grad(self):
        """Clear gradients on every optimizer before a new step."""
        self.dis_a_opt.zero_grad()
        self.dis_b_opt.zero_grad()
        self.dec_opt.zero_grad()
        self.enc_opt.zero_grad()
        self.interp_ab_opt.zero_grad()
        self.interp_ba_opt.zero_grad()

    def dis_update(self, x_a, x_b, hyperparameters):
        """One discriminator step: critic-style feature-mean loss plus
        gradient penalty, for both domains."""
        self.zero_grad()
        # encode
        c_a, s_a = self.encoder(x_a)
        c_b, s_b = self.encoder(x_b)
        # decode (cross domain) with a full style swap (v = 1)
        self.v = torch.ones(s_a.size())
        s_a_interp = self.interp_net_ba(s_b, s_a, self.v)
        s_b_interp = self.interp_net_ab(s_a, s_b, self.v)
        x_ba = self.decoder(c_b, s_a_interp)
        x_ab = self.decoder(c_a, s_b_interp)
        x_a_feature = self.dis_a(x_a)
        x_ba_feature = self.dis_a(x_ba)
        x_b_feature = self.dis_b(x_b)
        x_ab_feature = self.dis_b(x_ab)
        # critic loss: fake score minus real score
        self.loss_dis_a = (x_ba_feature - x_a_feature).mean()
        self.loss_dis_b = (x_ab_feature - x_b_feature).mean()
        # gradient penalty
        self.loss_dis_a_gp = self.dis_a.calculate_gradient_penalty(x_ba, x_a)
        self.loss_dis_b_gp = self.dis_b.calculate_gradient_penalty(x_ab, x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + \
            hyperparameters['gan_w'] * self.loss_dis_b + \
            hyperparameters['gan_w'] * self.loss_dis_a_gp + \
            hyperparameters['gan_w'] * self.loss_dis_b_gp
        self.loss_dis_total.backward()
        self.total_loss += self.loss_dis_total.item()
        self.dis_a_opt.step()
        self.dis_b_opt.step()

    def gen_update(self, x_a, x_b, hyperparameters):
        """One generator-side step: reconstruction, latent reconstruction,
        optional cycle and perceptual losses, plus adversarial losses."""
        self.zero_grad()
        # encode
        c_a, s_a = self.encoder(x_a)
        c_b, s_b = self.encoder(x_b)
        # decode (within domain)
        x_a_recon = self.decoder(c_a, s_a)
        x_b_recon = self.decoder(c_b, s_b)
        # decode (cross domain) with a full style swap (v = 1)
        self.v = torch.ones(s_a.size())
        s_a_interp = self.interp_net_ba(s_b, s_a, self.v)
        s_b_interp = self.interp_net_ab(s_a, s_b, self.v)
        x_ba = self.decoder(c_b, s_a_interp)
        x_ab = self.decoder(c_a, s_b_interp)
        # encode again
        c_b_recon, s_a_recon = self.encoder(x_ba)
        c_a_recon, s_b_recon = self.encoder(x_ab)
        # decode again, only when cycle-consistency is active
        x_aa = self.decoder(
            c_a_recon, s_a) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bb = self.decoder(
            c_b_recon, s_b) if hyperparameters['recon_x_cyc_w'] > 0 else None
        # reconstruction losses
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aa, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bb, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # perceptual loss
        self.loss_gen_vgg_a = self.perceptural_loss(
            x_a_recon, x_a) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.perceptural_loss(
            x_b_recon, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_aa = self.perceptural_loss(
            x_aa, x_a) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_bb = self.perceptural_loss(
            x_bb, x_b) if hyperparameters['vgg_w'] > 0 else 0
        # GAN loss: maximize critic scores on translated images
        x_ba_feature = self.dis_a(x_ba)
        x_ab_feature = self.dis_b(x_ab)
        self.loss_gen_adv_a = -x_ba_feature.mean()
        self.loss_gen_adv_b = -x_ab_feature.mean()
        # total weighted loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
            hyperparameters['gan_w'] * self.loss_gen_adv_b + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
            hyperparameters['vgg_w'] * self.loss_gen_vgg_aa + \
            hyperparameters['vgg_w'] * self.loss_gen_vgg_bb + \
            hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
            hyperparameters['vgg_w'] * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.total_loss += self.loss_gen_total.item()
        self.dec_opt.step()
        self.enc_opt.step()
        self.interp_ab_opt.step()
        self.interp_ba_opt.step()

    def sample(self, x_a, x_b):
        """Produce per-sample reconstructions, translations and cycles for
        visualization. Kept behavior-identical.

        NOTE(review): the extra `.unsqueeze(0)` on already-batched styles
        and the double decode of x_ab/x_ba look suspicious — confirm
        against Decoder's expected input rank.
        """
        self.eval()
        x_a_recon, x_b_recon, x_ab, x_ba, x_aa, x_bb = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            c_a, s_a = self.encoder(x_a[i].unsqueeze(0))
            c_b, s_b = self.encoder(x_b[i].unsqueeze(0))
            x_a_recon.append(self.decoder(c_a, s_a))
            x_b_recon.append(self.decoder(c_b, s_b))
            self.v = torch.ones(s_a.size())
            s_a_interp = self.interp_net_ba(s_b, s_a, self.v)
            s_b_interp = self.interp_net_ab(s_a, s_b, self.v)
            x_ab_i = self.decoder(c_a, s_b_interp)
            x_ba_i = self.decoder(c_b, s_a_interp)
            c_a_recon, s_b_recon = self.encoder(x_ab_i)
            c_b_recon, s_a_recon = self.encoder(x_ba_i)
            x_ab.append(self.decoder(c_a, s_b_interp.unsqueeze(0)))
            x_ba.append(self.decoder(c_b, s_a_interp.unsqueeze(0)))
            x_aa.append(self.decoder(c_a_recon, s_a.unsqueeze(0)))
            x_bb.append(self.decoder(c_b_recon, s_b.unsqueeze(0)))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ab, x_aa = torch.cat(x_ab), torch.cat(x_aa)
        x_ba, x_bb = torch.cat(x_ba), torch.cat(x_bb)
        self.train()
        return x_a, x_a_recon, x_ab, x_aa, x_b, x_b_recon, x_ba, x_bb

    def update_learning_rate(self):
        """Advance every configured scheduler.

        Fixes from the previous version: stepped a nonexistent
        `self.gen_scheduler` and the misspelled
        `self.interpo_ab_scheduler`/`interpo_ba_scheduler`
        (both raised AttributeError).
        """
        for scheduler in (self.dis_a_scheduler, self.dis_b_scheduler,
                          self.enc_scheduler, self.dec_scheduler,
                          self.interp_ab_scheduler, self.interp_ba_scheduler):
            if scheduler is not None:
                scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore all networks, optimizers and schedulers from the best
        checkpoints in `checkpoint_dir`.

        Returns:
            (best_iter, total_loss) restored from the optimizer checkpoint.
        """
        # Load encoder
        model_name = get_model(checkpoint_dir, "encoder")
        state_dict = torch.load(model_name)
        self.encoder.load_state_dict(state_dict)
        # Load decoder
        model_name = get_model(checkpoint_dir, "decoder")
        state_dict = torch.load(model_name)
        self.decoder.load_state_dict(state_dict)
        # Load discriminator a
        model_name = get_model(checkpoint_dir, "dis_a")
        state_dict = torch.load(model_name)
        self.dis_a.load_state_dict(state_dict)
        # Load discriminator b
        model_name = get_model(checkpoint_dir, "dis_b")
        state_dict = torch.load(model_name)
        self.dis_b.load_state_dict(state_dict)
        # Load interpolator ab
        model_name = get_model(checkpoint_dir, "interp_ab")
        state_dict = torch.load(model_name)
        self.interp_net_ab.load_state_dict(state_dict)
        # Load interpolator ba
        model_name = get_model(checkpoint_dir, "interp_ba")
        state_dict = torch.load(model_name)
        self.interp_net_ba.load_state_dict(state_dict)
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.enc_opt.load_state_dict(state_dict['enc_opt'])
        self.dec_opt.load_state_dict(state_dict['dec_opt'])
        self.dis_a_opt.load_state_dict(state_dict['dis_a_opt'])
        self.dis_b_opt.load_state_dict(state_dict['dis_b_opt'])
        self.interp_ab_opt.load_state_dict(state_dict['interp_ab_opt'])
        self.interp_ba_opt.load_state_dict(state_dict['interp_ba_opt'])
        self.best_iter = state_dict['best_iter']
        self.total_loss = state_dict['total_loss']
        # Reinitialize schedulers at the restored iteration (attribute names
        # now match the ones used in __init__/update_learning_rate)
        self.dis_a_scheduler = get_scheduler(self.dis_a_opt, hyperparameters,
                                             self.best_iter)
        self.dis_b_scheduler = get_scheduler(self.dis_b_opt, hyperparameters,
                                             self.best_iter)
        self.enc_scheduler = get_scheduler(self.enc_opt, hyperparameters,
                                           self.best_iter)
        self.dec_scheduler = get_scheduler(self.dec_opt, hyperparameters,
                                           self.best_iter)
        self.interp_ab_scheduler = get_scheduler(self.interp_ab_opt,
                                                 hyperparameters,
                                                 self.best_iter)
        self.interp_ba_scheduler = get_scheduler(self.interp_ba_opt,
                                                 hyperparameters,
                                                 self.best_iter)
        print('Resume from iteration %d' % self.best_iter)
        return self.best_iter, self.total_loss

    def resume_iter(self, checkpoint_dir, surfix, hyperparameters):
        """Restore all state from checkpoints with the given iteration
        suffix (e.g. '_00010000').

        Returns:
            (best_iter, total_loss) restored from the optimizer checkpoint.
        """
        # Load encoder
        state_dict = torch.load(
            os.path.join(checkpoint_dir, 'encoder' + surfix + '.pt'))
        self.encoder.load_state_dict(state_dict)
        # Load decoder
        state_dict = torch.load(
            os.path.join(checkpoint_dir, 'decoder' + surfix + '.pt'))
        self.decoder.load_state_dict(state_dict)
        # Load discriminator a
        state_dict = torch.load(
            os.path.join(checkpoint_dir, 'dis_a' + surfix + '.pt'))
        self.dis_a.load_state_dict(state_dict)
        # Load discriminator b
        state_dict = torch.load(
            os.path.join(checkpoint_dir, 'dis_b' + surfix + '.pt'))
        self.dis_b.load_state_dict(state_dict)
        # Combined interpolator checkpoint ('ab'/'ba' keys)
        state_dict = torch.load(
            os.path.join(checkpoint_dir, 'interp' + surfix + '.pt'))
        self.interp_net_ab.load_state_dict(state_dict['ab'])
        self.interp_net_ba.load_state_dict(state_dict['ba'])
        # Load interpolator ab
        state_dict = torch.load(
            os.path.join(checkpoint_dir, 'interp_ab' + surfix + '.pt'))
        self.interp_net_ab.load_state_dict(state_dict)
        # Load interpolator ba
        state_dict = torch.load(
            os.path.join(checkpoint_dir, 'interp_ba' + surfix + '.pt'))
        self.interp_net_ba.load_state_dict(state_dict)
        # Load optimizers
        state_dict = torch.load(
            os.path.join(checkpoint_dir, 'optimizer' + surfix + '.pt'))
        self.enc_opt.load_state_dict(state_dict['enc_opt'])
        self.dec_opt.load_state_dict(state_dict['dec_opt'])
        self.dis_a_opt.load_state_dict(state_dict['dis_a_opt'])
        self.dis_b_opt.load_state_dict(state_dict['dis_b_opt'])
        self.interp_ab_opt.load_state_dict(state_dict['interp_ab_opt'])
        self.interp_ba_opt.load_state_dict(state_dict['interp_ba_opt'])
        self.best_iter = state_dict['best_iter']
        self.total_loss = state_dict['total_loss']
        # Reinitialize schedulers at the restored iteration (attribute names
        # now match the ones used in __init__/update_learning_rate)
        self.dis_a_scheduler = get_scheduler(self.dis_a_opt, hyperparameters,
                                             self.best_iter)
        self.dis_b_scheduler = get_scheduler(self.dis_b_opt, hyperparameters,
                                             self.best_iter)
        self.enc_scheduler = get_scheduler(self.enc_opt, hyperparameters,
                                           self.best_iter)
        self.dec_scheduler = get_scheduler(self.dec_opt, hyperparameters,
                                           self.best_iter)
        self.interp_ab_scheduler = get_scheduler(self.interp_ab_opt,
                                                 hyperparameters,
                                                 self.best_iter)
        self.interp_ba_scheduler = get_scheduler(self.interp_ba_opt,
                                                 hyperparameters,
                                                 self.best_iter)
        print('Resume from iteration %d' % self.best_iter)
        return self.best_iter, self.total_loss

    def save_better_model(self, snapshot_dir):
        """Replace the contents of `snapshot_dir` with the current (best)
        model and optimizer state.

        Fix from the previous version: the discriminator files stored the
        *optimizer* state dicts (`dis_a_opt`/`dis_b_opt`) instead of the
        discriminator weights, so discriminators could never be restored.
        """
        # remove sub-optimal models
        files = glob.glob(snapshot_dir + '/*')
        for f in files:
            os.remove(f)
        # Save encoder, decoder, interpolators, discriminators, optimizers
        encoder_name = os.path.join(snapshot_dir,
                                    'encoder_%.4f.pt' % (self.total_loss))
        decoder_name = os.path.join(snapshot_dir,
                                    'decoder_%.4f.pt' % (self.total_loss))
        interp_ab_name = os.path.join(snapshot_dir,
                                      'interp_ab_%.4f.pt' % (self.total_loss))
        interp_ba_name = os.path.join(snapshot_dir,
                                      'interp_ba_%.4f.pt' % (self.total_loss))
        dis_a_name = os.path.join(snapshot_dir,
                                  'dis_a_%.4f.pt' % (self.total_loss))
        dis_b_name = os.path.join(snapshot_dir,
                                  'dis_b_%.4f.pt' % (self.total_loss))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save(self.encoder.state_dict(), encoder_name)
        torch.save(self.decoder.state_dict(), decoder_name)
        torch.save(self.interp_net_ab.state_dict(), interp_ab_name)
        torch.save(self.interp_net_ba.state_dict(), interp_ba_name)
        torch.save(self.dis_a.state_dict(), dis_a_name)
        torch.save(self.dis_b.state_dict(), dis_b_name)
        torch.save(
            {
                'enc_opt': self.enc_opt.state_dict(),
                'dec_opt': self.dec_opt.state_dict(),
                'dis_a_opt': self.dis_a_opt.state_dict(),
                'dis_b_opt': self.dis_b_opt.state_dict(),
                'interp_ab_opt': self.interp_ab_opt.state_dict(),
                'interp_ba_opt': self.interp_ba_opt.state_dict(),
                'best_iter': self.best_iter,
                'total_loss': self.total_loss
            }, opt_name)

    def save_at_iter(self, snapshot_dir, iterations):
        """Save all networks and optimizer state tagged with the iteration.

        Fix from the previous version: the discriminator files stored the
        *optimizer* state dicts instead of the discriminator weights.
        """
        encoder_name = os.path.join(snapshot_dir,
                                    'encoder_%08d.pt' % (iterations + 1))
        decoder_name = os.path.join(snapshot_dir,
                                    'decoder_%08d.pt' % (iterations + 1))
        interp_ab_name = os.path.join(snapshot_dir,
                                      'interp_ab_%08d.pt' % (iterations + 1))
        interp_ba_name = os.path.join(snapshot_dir,
                                      'interp_ba_%08d.pt' % (iterations + 1))
        dis_a_name = os.path.join(snapshot_dir,
                                  'dis_a_%08d.pt' % (iterations + 1))
        dis_b_name = os.path.join(snapshot_dir,
                                  'dis_b_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir,
                                'optimizer_%08d.pt' % (iterations + 1))
        torch.save(self.encoder.state_dict(), encoder_name)
        torch.save(self.decoder.state_dict(), decoder_name)
        torch.save(self.interp_net_ab.state_dict(), interp_ab_name)
        torch.save(self.interp_net_ba.state_dict(), interp_ba_name)
        torch.save(self.dis_a.state_dict(), dis_a_name)
        torch.save(self.dis_b.state_dict(), dis_b_name)
        torch.save(
            {
                'enc_opt': self.enc_opt.state_dict(),
                'dec_opt': self.dec_opt.state_dict(),
                'dis_a_opt': self.dis_a_opt.state_dict(),
                'dis_b_opt': self.dis_b_opt.state_dict(),
                'interp_ab_opt': self.interp_ab_opt.state_dict(),
                'interp_ba_opt': self.interp_ba_opt.state_dict(),
                'best_iter': self.best_iter,
                'total_loss': self.total_loss
            }, opt_name)
class fgan(object):
    """Huber-contamination data generation + F-GAN estimation of the center.

    This class ensembles the data generating process of Huber's contamination
    model and the training process for estimating the center parameter via
    F-GAN:

        X i.i.d ~ (1 - eps) P(mu, Sigma) + eps Q

    Usage:
    >> f = fgan(p=100, eps=0.2, device=device, tol=1e-5)
    >> f.dist_init(true_type='Gaussian', cont_type='Gaussian',
                   cont_mean=5.0, cont_var=1.)
    >> f.data_init(train_size=50000, batch_size=500)
    >> f.net_init(d_hidden_units=[20], elliptical=False,
                  activation_D1='LeakyReLU')
    >> f.optimizer_init(lr_d=0.2, lr_g=0.02, d_steps=5, g_steps=1)
    >> f.fit(floss='js', epochs=150, avg_epochs=25, verbose=50)
    >> f.report_results()

    Please refer to the Demo.ipynb for more examples.
    """

    def __init__(self, p, eps, device=None, tol=1e-5):
        """Set parameters for Huber's model.

        X i.i.d ~ (1-eps) P(mu, Sigma) + eps Q, where P is the real
        distribution, mu is the center parameter we want to estimate, Q is the
        contamination distribution and eps is the contamination ratio.

        Args:
            p: dimension.
            eps: contamination ratio.
            tol: make sure the denominator is not zero.
            device: If no device is provided, it will automatically choose
                cpu or cuda.
        """
        self.p = p
        self.eps = eps
        self.tol = tol
        self.device = device if device is not None \
            else torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    def dist_init(self, true_type='Gaussian', cont_type='Gaussian',
                  true_mean=0.0, cont_mean=0.0, cont_var=1, cont_covmat=None):
        """Set parameters for distributions under Huber contamination models.

        We assume the center parameter of the true distribution mu is 0 and
        the covariance is the identity matrix.

        Args:
            true_type: type of real distribution P: 'Gaussian', 'Cauchy'.
            cont_type: type of contamination distribution Q: 'Gaussian',
                'Cauchy'.
            true_mean: center parameter for P (all coordinates equal).
            cont_mean: center parameter for Q (all coordinates equal).
            cont_var: if the scatter (covariance) matrix of Q is diagonal,
                cont_var gives the diagonal element.
            cont_covmat: other scatter matrix can be provided (as torch.tensor
                format). If cont_covmat is not None, cont_var will be ignored.

        Raises:
            NameError: if true_type or cont_type is not one of the supported
                distribution names.
        """
        self.true_type = true_type
        self.cont_type = cont_type

        ## settings for true distribution sampler
        self.true_mean = torch.ones(self.p) * true_mean
        if true_type == 'Gaussian':
            self.t_d = MultivariateNormal(self.true_mean,
                                          covariance_matrix=torch.eye(self.p))
        elif true_type == 'Cauchy':
            # Cauchy is sampled as Normal / sqrt(Chi2(1)) in _sampler.
            self.t_normal_d = MultivariateNormal(
                torch.zeros(self.p), covariance_matrix=torch.eye(self.p))
            self.t_chi2_d = Chi2(df=1)
        else:
            raise NameError('True type must be Gaussian or Cauchy!')

        ## settings for contamination distribution sampler
        if cont_covmat is not None:
            self.cont_covmat = cont_covmat
        else:
            self.cont_covmat = torch.eye(self.p) * cont_var
        self.cont_mean = torch.ones(self.p) * cont_mean
        if cont_type == 'Gaussian':
            self.c_d = MultivariateNormal(torch.zeros(self.p),
                                          covariance_matrix=self.cont_covmat)
        elif cont_type == 'Cauchy':
            self.c_normal_d = MultivariateNormal(
                torch.zeros(self.p), covariance_matrix=self.cont_covmat)
            self.c_chi2_d = Chi2(df=1)
        else:
            raise NameError('Cont type must be Gaussian or Cauchy!')

    def _sampler(self, n):
        """Draw n contaminated samples; returns a [n, p] torch tensor."""
        if self.true_type == 'Gaussian':
            t_x = self.t_d.sample((n, ))
        elif self.true_type == 'Cauchy':
            t_normal_x = self.t_normal_d.sample((n, ))
            t_chi2_x = self.t_chi2_d.sample((n, ))
            t_x = t_normal_x / (torch.sqrt(t_chi2_x.view(-1, 1)) + self.tol)
        if self.cont_type == 'Gaussian':
            c_x = self.c_d.sample((n, )) + self.cont_mean.view(1, -1)
        elif self.cont_type == 'Cauchy':
            c_normal_x = self.c_normal_d.sample((n, ))
            c_chi2_x = self.c_chi2_d.sample((n, ))
            c_x = c_normal_x / (torch.sqrt(c_chi2_x.view(-1, 1)) + self.tol) +\
                self.cont_mean.view(1, -1)
        # Mix: each row comes from Q with probability eps, from P otherwise.
        s = (torch.rand(n) < self.eps).float()
        x = (t_x.transpose(1, 0) * (1 - s) +
             c_x.transpose(1, 0) * s).transpose(1, 0)
        return x

    def data_init(self, train_size=50000, batch_size=100):
        """Sample the training set and build the shuffled DataLoader."""
        self.Xtr = self._sampler(train_size)
        self.batch_size = batch_size
        self.poolset = PoolSet(self.Xtr)
        self.dataloader = DataLoader(self.poolset,
                                     batch_size=self.batch_size,
                                     shuffle=True)

    def net_init(self, d_hidden_units, use_logistic_regression=False,
                 init_weights=None, init_eta=0.0, use_median_init_G=True,
                 elliptical=False, g_input_dim=10, g_hidden_units=[10, 10],
                 activation_D1='Sigmoid', verbose=True):
        """Settings for Discriminator and Generator.

        Args:
            d_hidden_units: a list of hidden units for Discriminator, e.g.
                d_hidden_units=[10, 5] gives structure
                p (input) - 10 - 5 - 1 (output).
            use_logistic_regression: use a logistic-regression discriminator
                instead of the MLP Discriminator.
            init_weights: value forwarded to weights_init for the MLP
                discriminator initialization.
            init_eta: constant initial center estimate when
                use_median_init_G is False.
            use_median_init_G: initialize the center estimate with the
                coordinate-wise sample median.
            elliptical: Boolean. If False, G_1(x|b) = x + b where b is
                learned and x ~ Gaussian/Cauchy(0, I_p) according to the true
                distribution. If True, G_2(t, u|b) = g_2(t)u + b generates the
                family of elliptical distributions, with t ~ Normal(0, I) and
                u ~ Uniform(\\|u\\|_2 = 1).
            g_input_dim: (even) number; dimension of the input of g_2(t) when
                elliptical is True.
            g_hidden_units: a list of hidden units for g_2(t) when elliptical
                is True, e.g. [24, 12, 8] gives
                g_input_dim - 24 - 12 - 8 - p.
            activation_D1: 'Sigmoid', 'ReLU' or 'LeakyReLU'; first activation
                after the input layer. When true_type == 'Cauchy', Sigmoid is
                preferred.
            verbose: Boolean. If True, the initial error
                \\|\\hat{mu}_0 - mu\\|_2 is printed.
        """
        self.elliptical = elliptical
        self.g_input_dim = g_input_dim
        if self.elliptical:
            # xi inputs are split in half (Normal / inverse-Normal), hence even.
            assert (g_input_dim % 2 == 0), 'g_input_dim should be an even number'
            self.netGXi = GeneratorXi(input_dim=g_input_dim,
                                      hidden_units=g_hidden_units).to(self.device)
        self.netG = Generator(p=self.p, elliptical=self.elliptical).to(self.device)

        # Initialize center parameter with sample median.
        if use_median_init_G:
            self.netG.bias.data = torch.median(self.Xtr, dim=0)[0].to(self.device)
        else:
            self.netG.bias.data = (torch.ones(self.p) * init_eta).to(self.device)
        self.mean_err_init = np.linalg.norm(self.netG.bias.data.cpu().numpy() -
                                            self.true_mean.numpy())
        if verbose:
            print('Initialize Mean Error: %.4f' % self.mean_err_init)

        ## Initialize discriminator and g_2(t) when elliptical == True
        if use_logistic_regression:
            self.netD = LogisticRegression(p=self.p).to(self.device)
        else:
            self.netD = Discriminator(p=self.p, hidden_units=d_hidden_units,
                                      activation_1=activation_D1).to(self.device)
            # NOTE(review): custom weight init is applied to the MLP
            # discriminator only; the logistic-regression branch keeps its
            # default initialization.
            weights_init_netD = partial(weights_init, value=init_weights)
            self.netD.apply(weights_init_netD)
        if (self.elliptical):
            self.netGXi.apply(weights_init_xavier)

    def optimizer_init(self, lr_d, lr_g, d_steps, g_steps, type_opt='SGD'):
        """Settings for optimizers.

        Args:
            lr_d: learning rate for discriminator.
            lr_g: learning rate for generator.
            d_steps: number of discriminator steps per iteration.
            g_steps: number of generator steps per generator iteration.
            type_opt: 'SGD' for plain SGD, anything else selects Adam.
        """
        if type_opt == 'SGD':
            self.optG = optim.SGD(self.netG.parameters(), lr=lr_g)
            if self.elliptical:
                self.optGXi = optim.SGD(self.netGXi.parameters(), lr=lr_g)
            self.optD = optim.SGD(self.netD.parameters(), lr=lr_d)
        else:
            self.optG = optim.Adam(self.netG.parameters(), lr=lr_g)
            if self.elliptical:
                self.optGXi = optim.Adam(self.netGXi.parameters(), lr=lr_g)
            self.optD = optim.Adam(self.netD.parameters(), lr=lr_d)
        self.g_steps = g_steps
        self.d_steps = d_steps

    def fit(self, floss='js', epochs=20, avg_epochs=10,
            use_inverse_gaussian=True, verbose=25):
        """Training process.

        Args:
            floss: 'js' or 'tv'. For JS-GAN we use the original GAN with
                Jensen-Shannon divergence; for TV-GAN, total variation.
            epochs: number of epochs for training.
            avg_epochs: the final estimate averages the last avg_epochs epochs.
            use_inverse_gaussian: Boolean. If elliptical == True, the xi
                generator g_2(t) takes random vector t as input. If True,
                t = (t1, t2) with t1 ~ Normal(0, I_(d/2)) and
                t2 ~ 1/Normal(0, I_(d/2)); otherwise t ~ Normal(0, I_d).
            verbose: print intermediate results every `verbose` epochs.
        """
        assert floss in ['js', 'tv'], 'floss must be \'js\' or \'tv\''
        if floss == 'js':
            criterion = nn.BCEWithLogitsLoss()
        self.floss = floss
        # FIX: remember the xi-input convention so report_results() can replay
        # it; previously report_results referenced the undefined local name
        # `use_inverse_gaussian` and raised NameError for elliptical models.
        self.use_inverse_gaussian = use_inverse_gaussian
        self.loss_D = []
        self.loss_G = []
        self.mean_err_record = []
        self.mean_est_record = []
        current_d_step = 1
        for ep in range(epochs):
            loss_D_ep = []
            loss_G_ep = []
            for _, data in enumerate(self.dataloader):
                ## update D
                self.netD.train()
                self.netD.zero_grad()
                ## discriminator loss on real data
                x_real = data.to(self.device)
                feat_real, d_real_score = self.netD(x_real)
                if (floss == 'js'):
                    one_b = torch.ones_like(d_real_score).to(self.device)
                    d_real_loss = criterion(d_real_score, one_b)
                elif floss == 'tv':
                    d_real_loss = -torch.sigmoid(d_real_score).mean()
                    #d_real_loss = criterion(d_real_score, one_b)
                ## buffers for generator input
                # NOTE(review): .view(self.batch_size, -1) below assumes every
                # batch has exactly batch_size rows (true when train_size is a
                # multiple of batch_size, as with the defaults).
                z_b = torch.zeros(data.shape[0], self.p).to(self.device)
                if self.elliptical:
                    if use_inverse_gaussian:
                        xi_b1 = torch.zeros(data.shape[0],
                                            self.g_input_dim // 2).to(self.device)
                        xi_b2 = torch.zeros(data.shape[0],
                                            self.g_input_dim // 2).to(self.device)
                    else:
                        xi_b = torch.zeros(data.shape[0],
                                           self.g_input_dim).to(self.device)
                ## fake samples (detached: D update must not touch G)
                if self.elliptical:
                    z_b.normal_()
                    z_b.div_(z_b.norm(2, dim=1).view(-1, 1) + self.tol)
                    if use_inverse_gaussian:
                        xi_b1.normal_()
                        xi_b2.normal_()
                        xi_b2.data = 1 / (torch.abs(xi_b2.data) + self.tol)
                        xi = self.netGXi(torch.cat([xi_b1, xi_b2],
                                                   dim=1)).view(self.batch_size, -1)
                    else:
                        xi_b.normal_()
                        xi = self.netGXi(xi_b).view(self.batch_size, -1)
                    x_fake = self.netG(z_b, xi).detach()
                elif (self.true_type == 'Cauchy'):
                    z_b.normal_()
                    z_b.data.div_(
                        torch.sqrt(self.t_chi2_d.sample(
                            (self.batch_size, 1))).to(self.device) + self.tol)
                    x_fake = self.netG(z_b).detach()
                elif self.true_type == 'Gaussian':
                    x_fake = self.netG(z_b.normal_()).detach()
                feat_fake, d_fake_score = self.netD(x_fake)
                if floss == 'js':
                    one_b = torch.ones_like(d_fake_score).to(self.device)
                    d_fake_loss = criterion(d_fake_score, 1 - one_b)
                elif floss == 'tv':
                    d_fake_loss = torch.sigmoid(d_fake_score).mean()
                d_loss = d_real_loss + d_fake_loss
                d_loss.backward()
                loss_D_ep.append(d_loss.cpu().item())
                self.optD.step()
                # Run d_steps discriminator updates per generator update.
                if current_d_step < self.d_steps:
                    current_d_step += 1
                    continue
                else:
                    current_d_step = 1
                ## update G
                self.netD.eval()
                for _ in range(self.g_steps):
                    self.netG.zero_grad()
                    if self.elliptical:
                        self.netGXi.zero_grad()
                        z_b.normal_()
                        z_b.div_(z_b.norm(2, dim=1).view(-1, 1) + self.tol)
                        if use_inverse_gaussian:
                            xi_b1.normal_()
                            xi_b2.normal_()
                            xi_b2.data = 1 / (torch.abs(xi_b2.data) + self.tol)
                            xi = self.netGXi(torch.cat([xi_b1, xi_b2],
                                                       dim=1)).view(self.batch_size, -1)
                        else:
                            xi_b.normal_()
                            xi = self.netGXi(xi_b).view(self.batch_size, -1)
                        x_fake = self.netG(z_b, xi)
                    elif self.true_type == 'Gaussian':
                        x_fake = self.netG(z_b.normal_())
                    elif (self.true_type == 'Cauchy'):
                        z_b.normal_()
                        z_b.data.div_(
                            torch.sqrt(self.t_chi2_d.sample(
                                (self.batch_size, 1))).to(self.device) + self.tol)
                        x_fake = self.netG(z_b)
                    feat_fake, g_fake_score = self.netD(x_fake)
                    if (floss == 'js'):
                        one_b = torch.ones_like(g_fake_score).to(self.device)
                        g_fake_loss = -criterion(g_fake_score, 1 - one_b)
                        g_fake_loss.backward()
                        loss_G_ep.append(-g_fake_loss.cpu().item())
                    elif floss == 'tv':
                        g_fake_loss = -torch.sigmoid(g_fake_score).mean()
                        g_fake_loss.backward()
                        loss_G_ep.append(g_fake_loss.cpu().item())
                    self.optG.step()
                    if self.elliptical:
                        self.optGXi.step()
            ## Record intermediate error during training for monitoring.
            self.mean_err_record.append(
                (self.netG.bias.data -
                 self.true_mean.to(self.device)).norm(2).item())
            ## Record intermediate estimation during training for averaging.
            if (ep >= (epochs - avg_epochs)):
                self.mean_est_record.append(self.netG.bias.data.clone().cpu())
            self.loss_D.append(np.mean(loss_D_ep))
            self.loss_G.append(np.mean(loss_G_ep))
            ## Print intermediate result every verbose epochs.
            if ((ep + 1) % verbose == 0):
                print('Epoch:%d, LossD/G:%.4f/%.4f, Error(Mean):%.4f' %
                      (ep + 1, self.loss_D[-1], self.loss_G[-1],
                       self.mean_err_record[-1]))
        ## Final results
        self.mean_avg = sum(self.mean_est_record[-avg_epochs:])/\
            len(self.mean_est_record[-avg_epochs:])
        self.mean_err_avg = (self.mean_avg -
                             self.true_mean.cpu()).norm(2).item()
        self.mean_err_last = (self.netG.bias.data -
                              self.true_mean.to(self.device)).norm(2).item()

    def report_results(self, figsize=(6, 4), show_plots=True,
                       save_g_loss=None, save_d_loss=None, save_error=None,
                       save_distribution=None):
        """Print final errors and plot losses / discriminator distributions.

        Must be called after fit(). Optional save_* paths save each figure.
        """
        self.netD.eval()
        ## Scores of true distribution from 10,000 samples.
        if self.true_type == 'Gaussian':
            t_x = self.t_d.sample((10000, ))
        elif self.true_type == 'Cauchy':
            t_normal_x = self.t_normal_d.sample((10000, ))
            t_chi2_x = self.t_chi2_d.sample((10000, ))
            t_x = t_normal_x / (torch.sqrt(t_chi2_x.view(-1, 1)) + self.tol)
        self.true_D = self.netD(t_x.to(self.device))[1].detach().cpu().numpy()
        ## Scores of contamination distribution from 10,000 samples.
        if self.cont_type == 'Gaussian':
            c_x = self.c_d.sample((10000, )) + self.cont_mean.view(1, -1)
        elif self.cont_type == 'Cauchy':
            c_normal_x = self.c_normal_d.sample((10000, ))
            c_chi2_x = self.c_chi2_d.sample((10000, ))
            c_x = c_normal_x / (torch.sqrt(c_chi2_x.view(-1, 1)) + self.tol) +\
                self.cont_mean.view(1, -1)
        self.cont_D = self.netD(c_x.to(self.device))[1].detach().cpu().numpy()
        ## Scores of 10,000 generated samples.
        if self.elliptical:
            t_z = torch.randn(10000, self.p).to(self.device)
            t_z.div_(t_z.norm(2, dim=1).view(-1, 1) + self.tol)
            # FIX: use the convention recorded by fit(); the local name
            # `use_inverse_gaussian` did not exist in this scope before.
            if self.use_inverse_gaussian:
                t_xi1 = torch.randn(10000, self.g_input_dim // 2).to(self.device)
                t_xi2 = torch.randn(10000, self.g_input_dim // 2).to(self.device)
                t_xi2 = 1 / (torch.abs(t_xi2.data) + self.tol)
                xi = self.netGXi(torch.cat([t_xi1, t_xi2], dim=1)).view(10000, -1)
            else:
                t_xi = torch.randn(10000, self.g_input_dim).to(self.device)
                xi = self.netGXi(t_xi).view(10000, -1)
            g_x = self.netG(t_z, xi).detach()
        elif self.true_type == 'Gaussian':
            g_x = self.netG(torch.randn(10000, self.p).to(self.device))
        elif (self.true_type == 'Cauchy'):
            g_z = torch.randn(10000, self.p).to(self.device)
            g_z.data.div_(
                torch.sqrt(self.t_chi2_d.sample((10000, 1))).to(self.device) +
                self.tol)
            g_x = self.netG(g_z)
        self.gene_D = self.netD(g_x)[1].detach().cpu().numpy()

        ## Some useful prints and plots
        print('Avg error: %.4f, Last error: %.4f' %
              (self.mean_err_avg, self.mean_err_last))
        grand_mean = (1 - self.eps) * self.true_mean + self.eps * self.cont_mean
        grand_mean_err = (grand_mean.to(self.device) -
                          self.true_mean.to(self.device)).norm(2).item()
        grand_mean_err_record = [
            grand_mean_err for i in range(len(self.mean_err_record))
        ]
        if self.p == 1:
            print("True mean = %.4f" % (self.true_mean.item()))
            print("Contamination mean = %.4f" % (self.cont_mean.item()))
            print("Result mean = %.4f" % (self.netG.bias.data.item()))
            print("Grand mean = %.4f" % (grand_mean.item()))

        loss_type = 'Total Variation' if self.floss == 'tv' else 'Jensen-Shannon'

        # Discriminator loss curve.
        fig, ax = plt.subplots(figsize=figsize)
        ax.plot(self.loss_D)
        ax.grid(True)
        ax.set_title(f'Discriminator loss, type = {loss_type}')
        ax.set_xlabel("epoch num")
        ax.set_ylabel("Loss")
        if save_d_loss is not None:
            plt.savefig(save_d_loss)
        if show_plots:
            plt.show()
        else:
            plt.close(fig)

        # Generator loss curve.
        fig, ax = plt.subplots(figsize=figsize)
        ax.plot(self.loss_G)
        ax.grid(True)
        ax.set_title(f'Generator loss, type = {loss_type}')
        ax.set_xlabel("epoch num")
        ax.set_ylabel("Loss")
        if save_g_loss is not None:
            plt.savefig(save_g_loss)
        if show_plots:
            plt.show()
        else:
            plt.close(fig)

        # Estimation error vs. the (constant) grand-mean baseline.
        fig, ax = plt.subplots(figsize=figsize)
        ax.plot(self.mean_err_record, label='mean error process')
        ax.plot(grand_mean_err_record, label='grand mean error')
        ax.legend()
        ax.grid(True)
        ax.set_title(
            r'$\ell_{2}$ error in prediction of mean for true distribution')
        ax.set_xlabel("epoch num")
        ax.set_ylabel(r"$\|\eta_{est} - \eta_{true}\|_{2}$")
        if save_error is not None:
            plt.savefig(save_error)
        if show_plots:
            plt.show()
        else:
            plt.close(fig)

        # Densities of discriminator scores, clipped to (-25, 25) for display.
        fig, ax = plt.subplots(figsize=figsize)
        d_distributions = {}
        d_distributions['true distribution'] = self.true_D[(self.true_D < 25)
                                                           & (self.true_D > -25)]
        d_distributions['generated distribution'] = self.gene_D[
            (self.gene_D < 25) & (self.gene_D > -25)]
        d_distributions['contamination distribution'] = self.cont_D[
            (self.cont_D < 25) & (self.cont_D > -25)]
        g = sns.kdeplot(ax=ax, data=d_distributions)
        ax.set_xlabel(r"$D(x)$")
        ax.set_ylabel("Density")
        ax.grid(True)
        ax.set_title(r'Discriminator distribution, $D(x)$')
        if save_distribution is not None:
            plt.savefig(save_distribution)
        if show_plots:
            plt.show()
        else:
            plt.close(fig)
discriminator = Discriminator().to(device) generator = torch.nn.DataParallel(generator, list(range(torch.cuda.device_count()))) discriminator = torch.nn.DataParallel(discriminator, list(range(torch.cuda.device_count()))) if opt['load_model']: if os.path.isfile("saved_models/generator.pth"): generator.load_state_dict(torch.load("saved_models/generator.pth")) if os.path.isfile("saved_models/discriminator.pth"): discriminator.load_state_dict( torch.load("saved_models/discriminator.pth")) else: generator.apply(weights_init_normal) discriminator.apply(weights_init_normal) optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt["lr"], betas=(opt["b1"], opt["b2"])) optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt["lr"], betas=(opt["b1"], opt["b2"])) for epoch in range(opt['n_epochs']): for i in range(25000 // opt['batch_size']): y, x = next(data.data_generator()) real_A = Variable(x.type(Tensor)) real_B = Variable(y.type(Tensor))
training_log = { 'time': time.time(), 'rounds': [], } def weights_init(m): # pass classname = m.__class__.__name__ if 'Linear' in classname: nn.init.normal_(m.weight.data, 0.048, 0.48) generator.apply(weights_init) discriminator.apply(weights_init) color_file = open('data/color.txt') color_file.seek(0, os.SEEK_END) color_file_size = color_file.tell() def random_color_file_seek(): color_file.seek(random.randint(0, color_file_size)) color_file.readline() def get_real_color_tensor(): line = color_file.readline() if line == '': color_file.seek(0)
train_dataset = datasets.CIFAR10('./data/', train=True, download=True, transform=train_transforms) train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True) # 3. Networks G = Generator(image_channels).to(device) D = Discriminator(image_channels).to(device) G.apply(initialize_weights) D.apply(initialize_weights) loss_fn = nn.BCELoss().to(device) # 4. Optimizers optimizer_for_G = torch.optim.Adam(G.parameters(), lr=learning_rate, betas=(beta1, beta2)) optimizer_for_D = torch.optim.Adam(D.parameters(), lr=learning_rate, betas=(beta1, beta2)) # 5. Training fake_gt = np.zeros((batch_size, 1, 1, 1), dtype=np.float32) fake_gt = torch.FloatTensor(fake_gt).to(device) fake_gt = torch.autograd.Variable(fake_gt, requires_grad=False)
def training_procedure(FLAGS):
    """Adversarial disentanglement training loop (style/class VAE + GAN).

    Per iteration: (A) auto-encoder reconstruction with a KL penalty on the
    style code, (B) generator updates fooling the discriminator, and (C)
    discriminator updates gated by an accuracy ceiling. Checkpoints every 5
    epochs; logs to FLAGS.log_file and TensorBoard.
    """
    """ model definition """
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    discriminator = Discriminator()
    discriminator.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))
        discriminator.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.discriminator_save)))

    """ variable definition """
    real_domain_labels = 1
    fake_domain_labels = 0

    # Pre-allocated batch buffers, filled in-place via copy_ each iteration.
    X_1 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    X_2 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    X_3 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)

    domain_labels = torch.LongTensor(FLAGS.batch_size)
    style_latent_space = torch.FloatTensor(FLAGS.batch_size, FLAGS.style_dim)

    """ loss definitions """
    cross_entropy_loss = nn.CrossEntropyLoss()

    ''' add option to run on GPU '''
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()
        discriminator.cuda()
        cross_entropy_loss.cuda()
        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()
        domain_labels = domain_labels.cuda()
        style_latent_space = style_latent_space.cuda()

    """ optimizer definition """
    auto_encoder_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    discriminator_optimizer = optim.Adam(
        list(discriminator.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    # NOTE(review): shares encoder/decoder parameters with
    # auto_encoder_optimizer — both Adam states update the same tensors.
    generator_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    """ training """
    if torch.cuda.is_available() and not FLAGS.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write('Epoch\tIteration\tReconstruction_loss\tKL_divergence_loss\t')
            log.write('Generator_loss\tDiscriminator_loss\tDiscriminator_accuracy\n')

    # load data set and create data loader instance
    print('Loading MNIST paired dataset...')
    paired_mnist = MNIST_Paired(root='mnist', download=True, train=True, transform=transform_config)
    loader = cycle(DataLoader(paired_mnist, batch_size=FLAGS.batch_size, shuffle=True, num_workers=0, drop_last=True))

    # initialise variables
    discriminator_accuracy = 0.

    # initialize summary writer
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print('Epoch #' + str(epoch) + '..........................................................................')

        for iteration in range(int(len(paired_mnist) / FLAGS.batch_size)):
            # A. run the auto-encoder reconstruction
            image_batch_1, image_batch_2, _ = next(loader)

            auto_encoder_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_mu_1, style_logvar_1, class_1 = encoder(Variable(X_1))
            style_1 = reparameterize(training=True, mu=style_mu_1, logvar=style_logvar_1)

            # KL(q(style|x) || N(0, I)), normalized by the pixel count.
            kl_divergence_loss_1 = - 0.5 * torch.sum(1 + style_logvar_1 - style_mu_1.pow(2) - style_logvar_1.exp())
            kl_divergence_loss_1 /= (FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size)
            kl_divergence_loss_1.backward(retain_graph=True)

            _, __, class_2 = encoder(Variable(X_2))

            reconstructed_X_1 = decoder(style_1, class_1)
            reconstructed_X_2 = decoder(style_1, class_2)

            reconstruction_error_1 = mse_loss(reconstructed_X_1, Variable(X_1))
            reconstruction_error_1.backward(retain_graph=True)

            # NOTE(review): reconstructed_X_2 (style of X_1, class of X_2) is
            # compared against X_1 — presumably the paired images share a
            # class so this is intended cross-reconstruction; confirm against
            # MNIST_Paired.
            reconstruction_error_2 = mse_loss(reconstructed_X_2, Variable(X_1))
            reconstruction_error_2.backward()

            reconstruction_error = reconstruction_error_1 + reconstruction_error_2
            kl_divergence_error = kl_divergence_loss_1

            auto_encoder_optimizer.step()

            # B. run the generator
            for i in range(FLAGS.generator_times):
                generator_optimizer.zero_grad()

                image_batch_1, _, __ = next(loader)
                image_batch_3, _, __ = next(loader)

                domain_labels.fill_(real_domain_labels)
                X_1.copy_(image_batch_1)
                X_3.copy_(image_batch_3)

                style_mu_1, style_logvar_1, _ = encoder(Variable(X_1))
                style_1 = reparameterize(training=True, mu=style_mu_1, logvar=style_logvar_1)

                kl_divergence_loss_1 = - 0.5 * torch.sum(1 + style_logvar_1 - style_mu_1.pow(2) - style_logvar_1.exp())
                kl_divergence_loss_1 /= (FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size)
                kl_divergence_loss_1.backward(retain_graph=True)

                _, __, class_3 = encoder(Variable(X_3))

                # Generator tries to make D label fakes as "real".
                reconstructed_X_1_3 = decoder(style_1, class_3)
                output_1 = discriminator(Variable(X_3), reconstructed_X_1_3)
                generator_error_1 = cross_entropy_loss(output_1, Variable(domain_labels))
                generator_error_1.backward(retain_graph=True)

                # Same objective with style sampled from the N(0, 1) prior.
                style_latent_space.normal_(0., 1.)
                reconstructed_X_latent_3 = decoder(Variable(style_latent_space), class_3)
                output_2 = discriminator(Variable(X_3), reconstructed_X_latent_3)
                generator_error_2 = cross_entropy_loss(output_2, Variable(domain_labels))
                generator_error_2.backward()

                generator_error = generator_error_1 + generator_error_2
                kl_divergence_error += kl_divergence_loss_1

                generator_optimizer.step()

            # C. run the discriminator
            for i in range(FLAGS.discriminator_times):
                discriminator_optimizer.zero_grad()

                # train discriminator on real data
                domain_labels.fill_(real_domain_labels)

                image_batch_1, _, __ = next(loader)
                image_batch_2, image_batch_3, _ = next(loader)

                X_1.copy_(image_batch_1)
                X_2.copy_(image_batch_2)
                X_3.copy_(image_batch_3)

                real_output = discriminator(Variable(X_2), Variable(X_3))
                discriminator_real_error = cross_entropy_loss(real_output, Variable(domain_labels))
                discriminator_real_error.backward()

                # train discriminator on fake data
                domain_labels.fill_(fake_domain_labels)

                style_mu_1, style_logvar_1, _ = encoder(Variable(X_1))
                style_1 = reparameterize(training=False, mu=style_mu_1, logvar=style_logvar_1)

                _, __, class_3 = encoder(Variable(X_3))

                reconstructed_X_1_3 = decoder(style_1, class_3)
                fake_output = discriminator(Variable(X_3), reconstructed_X_1_3)
                discriminator_fake_error = cross_entropy_loss(fake_output, Variable(domain_labels))
                discriminator_fake_error.backward()

                # total discriminator error
                discriminator_error = discriminator_real_error + discriminator_fake_error

                # calculate discriminator accuracy for this step
                target_true_labels = torch.cat((torch.ones(FLAGS.batch_size), torch.zeros(FLAGS.batch_size)), dim=0)
                if FLAGS.cuda:
                    target_true_labels = target_true_labels.cuda()

                discriminator_predictions = torch.cat((real_output, fake_output), dim=0)
                _, discriminator_predictions = torch.max(discriminator_predictions, 1)

                discriminator_accuracy = (discriminator_predictions.data == target_true_labels.long()
                                          ).sum().item() / (FLAGS.batch_size * 2)

                # Only step D while it is weaker than the limiting accuracy,
                # to keep the adversarial game balanced.
                if discriminator_accuracy < FLAGS.discriminator_limiting_accuracy:
                    discriminator_optimizer.step()

            if (iteration + 1) % 50 == 0:
                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))

                print('')
                print('Reconstruction loss: ' + str(reconstruction_error.data.storage().tolist()[0]))
                print('KL-Divergence loss: ' + str(kl_divergence_error.data.storage().tolist()[0]))

                print('')
                print('Generator loss: ' + str(generator_error.data.storage().tolist()[0]))
                print('Discriminator loss: ' + str(discriminator_error.data.storage().tolist()[0]))
                print('Discriminator accuracy: ' + str(discriminator_accuracy))

                print('..........')

            # write to log
            with open(FLAGS.log_file, 'a') as log:
                log.write('{0}\t{1}\t{2}\t{3}\t{4}\t{5}\t{6}\n'.format(
                    epoch,
                    iteration,
                    reconstruction_error.data.storage().tolist()[0],
                    kl_divergence_error.data.storage().tolist()[0],
                    generator_error.data.storage().tolist()[0],
                    discriminator_error.data.storage().tolist()[0],
                    discriminator_accuracy
                ))

            # write to tensorboard
            writer.add_scalar('Reconstruction loss', reconstruction_error.data.storage().tolist()[0],
                              epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) + iteration)
            writer.add_scalar('KL-Divergence loss', kl_divergence_error.data.storage().tolist()[0],
                              epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) + iteration)
            writer.add_scalar('Generator loss', generator_error.data.storage().tolist()[0],
                              epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) + iteration)
            writer.add_scalar('Discriminator loss', discriminator_error.data.storage().tolist()[0],
                              epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) + iteration)
            writer.add_scalar('Discriminator accuracy', discriminator_accuracy * 100,
                              epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) + iteration)

        # save model after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(), os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(), os.path.join('checkpoints', FLAGS.decoder_save))
            torch.save(discriminator.state_dict(), os.path.join('checkpoints', FLAGS.discriminator_save))
normalize=True, range=(-1., 1.)) vutils.save_image(fixed_annos.float() / n_classes, join(sample_path, '{:03d}_anno.jpg'.format(0)), nrow=4, padding=0) # Models E = Encoder().to(device) E.apply(init_weights) # summary(E, (3, 256, 256), device=device) G = Generator(n_classes).to(device) G.apply(init_weights) # summary(G, [(256,), (10, 256, 256)], device=device) D = Discriminator(n_classes).to(device) D.apply(init_weights) # summary(D, (13, 256, 256), device=device) vgg = VGG().to(device) if args.multi_gpu: E = nn.DataParallel(E) G = nn.DataParallel(G) # G = convert_model(G) D = nn.DataParallel(D) VGG = nn.DataParallel(VGG) # Optimizers G_opt = optim.Adam(itertools.chain(G.parameters(), E.parameters()), lr=args.lr_G, betas=(args.beta1, args.beta2)) D_opt = optim.Adam(D.parameters(),
def init_training(args): """Initialize the data loader, the networks, the optimizers and the loss functions.""" datasets = Cifar10Dataset.get_datasets_from_scratch(args.data_path) for phase in ['train', 'test']: print('{} dataset len: {}'.format(phase, len(datasets[phase]))) # define loaders data_loaders = { 'train': DataLoader(datasets['train'], batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers), 'test': DataLoader(datasets['test'], batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers) } # check CUDA availability and set device device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') print('Use GPU: {}'.format(str(device) != 'cpu')) # set up models generator = Generator(args.gen_norm).to(device) discriminator = Discriminator(args.disc_norm).to(device) # initialize weights if args.apply_weight_init: generator.apply(weights_init_normal) discriminator.apply(weights_init_normal) # adam optimizer with reduced momentum optimizers = { 'gen': torch.optim.Adam(generator.parameters(), lr=args.base_lr_gen, betas=(0.5, 0.999)), 'disc': torch.optim.Adam(discriminator.parameters(), lr=args.base_lr_disc, betas=(0.5, 0.999)) } # losses losses = { 'l1': torch.nn.L1Loss(reduction='mean'), 'disc': torch.nn.BCELoss(reduction='mean') } # make save dir, if it does not exists if not os.path.exists(args.save_path): os.makedirs(args.save_path) # load weights if the training is not starting from the beginning global_step = args.start_epoch * len( data_loaders['train']) if args.start_epoch > 0 else 0 if args.start_epoch > 0: generator.load_state_dict( torch.load(os.path.join( args.save_path, 'checkpoint_ep{}_gen.pt'.format(args.start_epoch - 1)), map_location=device)) discriminator.load_state_dict( torch.load(os.path.join( args.save_path, 'checkpoint_ep{}_disc.pt'.format(args.start_epoch - 1)), map_location=device)) return global_step, device, data_loaders, generator, discriminator, optimizers, losses
else: USE_CUDA = False dataset = CustomDataset(opt) data_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=opt.batch_size, shuffle=opt.shuffle, num_workers=opt.n_workers) print(len(data_loader)) G = Generator(opt) D = Discriminator(opt) G.apply(weight_init) D.apply(weight_init) print(G) print(D) if USE_CUDA: G = G.cuda() D = D.cuda() G_optim = torch.optim.Adam(G.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2)) D_optim = torch.optim.Adam(D.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
def training_procedure(FLAGS):
    """Adversarially train a factorised auto-encoder on paired MNIST.

    The encoder splits each image into a varying factor (nv) and a common
    factor (nc). Each iteration alternates:
      A. auto-encoder reconstruction with the nc factors swapped inside a pair,
      B. a) a pairing discriminator trained on real pairs vs.
            (real image, cross-reconstruction) pairs,
         b) the generator (encoder + decoder) trained to fool it.
    Checkpoints and sample image grids are written every 5 epochs (and at the
    final epoch).
    """
    # model definition
    encoder = Encoder(nv_dim=FLAGS.nv_dim, nc_dim=FLAGS.nc_dim)
    encoder.apply(weights_init)

    decoder = Decoder(nv_dim=FLAGS.nv_dim, nc_dim=FLAGS.nc_dim)
    decoder.apply(weights_init)

    discriminator = Discriminator()
    discriminator.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))
        discriminator.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.discriminator_save)))

    """ variable definition """
    real_domain_labels = 1
    fake_domain_labels = 0

    # pre-allocated input buffers; refilled in-place (copy_) each iteration
    X_1 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels,
                            FLAGS.image_size, FLAGS.image_size)
    X_2 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels,
                            FLAGS.image_size, FLAGS.image_size)
    X_3 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels,
                            FLAGS.image_size, FLAGS.image_size)
    domain_labels = torch.LongTensor(FLAGS.batch_size)

    """ loss definitions """
    cross_entropy_loss = nn.CrossEntropyLoss()

    ''' add option to run on GPU '''
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()
        discriminator.cuda()
        cross_entropy_loss.cuda()
        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()
        domain_labels = domain_labels.cuda()

    """ optimizer definition """
    # the encoder+decoder parameter set is shared by two optimizers: one for
    # the reconstruction objective, one for the adversarial (generator) one
    auto_encoder_optimizer = optim.Adam(list(encoder.parameters()) +
                                        list(decoder.parameters()),
                                        lr=FLAGS.initial_learning_rate,
                                        betas=(FLAGS.beta_1, FLAGS.beta_2))

    discriminator_optimizer = optim.Adam(list(discriminator.parameters()),
                                         lr=FLAGS.initial_learning_rate,
                                         betas=(FLAGS.beta_1, FLAGS.beta_2))

    generator_optimizer = optim.Adam(list(encoder.parameters()) +
                                     list(decoder.parameters()),
                                     lr=FLAGS.initial_learning_rate,
                                     betas=(FLAGS.beta_1, FLAGS.beta_2))

    """ training """
    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    if not os.path.exists('reconstructed_images'):
        os.makedirs('reconstructed_images')

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write('Epoch\tIteration\tReconstruction_loss\t')
            log.write(
                'Generator_loss\tDiscriminator_loss\tDiscriminator_accuracy\n')

    # load data set and create data loader instance
    print('Loading MNIST paired dataset...')
    paired_mnist = MNIST_Paired(root='mnist',
                                download=True,
                                train=True,
                                transform=transform_config)
    # cycle() lets us draw batches indefinitely without tracking epochs
    loader = cycle(
        DataLoader(paired_mnist,
                   batch_size=FLAGS.batch_size,
                   shuffle=True,
                   num_workers=0,
                   drop_last=True))

    # initialise variables
    discriminator_accuracy = 0.

    # initialize summary writer
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print(
            'Epoch #' + str(epoch) +
            '..........................................................................'
        )

        for iteration in range(int(len(paired_mnist) / FLAGS.batch_size)):
            # A. run the auto-encoder reconstruction
            image_batch_1, image_batch_2, labels_batch_1 = next(loader)

            auto_encoder_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            nv_1, nc_1 = encoder(Variable(X_1))
            nv_2, nc_2 = encoder(Variable(X_2))

            # swap the common factors: each image is rebuilt from its own
            # varying factor and the *other* image's common factor
            reconstructed_X_1 = decoder(nv_1, nc_2)
            reconstructed_X_2 = decoder(nv_2, nc_1)

            # retain_graph so the shared encoder graph survives the 2nd backward
            reconstruction_error_1 = mse_loss(reconstructed_X_1, Variable(X_1))
            reconstruction_error_1.backward(retain_graph=True)

            reconstruction_error_2 = mse_loss(reconstructed_X_2, Variable(X_2))
            reconstruction_error_2.backward()

            reconstruction_error = reconstruction_error_1 + reconstruction_error_2

            if FLAGS.train_auto_encoder:
                auto_encoder_optimizer.step()

            # B. run the adversarial part of the architecture
            # B. a) run the discriminator
            for i in range(FLAGS.discriminator_times):
                discriminator_optimizer.zero_grad()

                # train discriminator on real data
                domain_labels.fill_(real_domain_labels)
                image_batch_1, image_batch_2, labels_batch_1 = next(loader)
                X_1.copy_(image_batch_1)
                X_2.copy_(image_batch_2)
                real_output = discriminator(Variable(X_1), Variable(X_2))
                discriminator_real_error = FLAGS.disc_coef * cross_entropy_loss(
                    real_output, Variable(domain_labels))
                discriminator_real_error.backward()

                # train discriminator on fake data
                domain_labels.fill_(fake_domain_labels)
                image_batch_3, _, labels_batch_3 = next(loader)
                X_3.copy_(image_batch_3)
                nv_3, nc_3 = encoder(Variable(X_3))
                # reconstruction is taking common factor from X_1 and varying factor from X_3
                reconstructed_X_3_1 = decoder(nv_3, encoder(Variable(X_1))[1])
                fake_output = discriminator(Variable(X_1), reconstructed_X_3_1)
                discriminator_fake_error = FLAGS.disc_coef * cross_entropy_loss(
                    fake_output, Variable(domain_labels))
                discriminator_fake_error.backward()

                # total discriminator error
                discriminator_error = discriminator_real_error + discriminator_fake_error

                # calculate discriminator accuracy for this step
                target_true_labels = torch.cat((torch.ones(
                    FLAGS.batch_size), torch.zeros(FLAGS.batch_size)), dim=0)
                if FLAGS.cuda:
                    target_true_labels = target_true_labels.cuda()

                discriminator_predictions = torch.cat(
                    (real_output, fake_output), dim=0)
                _, discriminator_predictions = torch.max(
                    discriminator_predictions, 1)

                discriminator_accuracy = (discriminator_predictions.data ==
                                          target_true_labels.long()).sum(
                                          ).item() / (FLAGS.batch_size * 2)

                # only update D while it is below the limiting accuracy, to
                # keep the adversarial game balanced
                if discriminator_accuracy < FLAGS.discriminator_limiting_accuracy and FLAGS.train_discriminator:
                    discriminator_optimizer.step()

            # B. b) run the generator
            for i in range(FLAGS.generator_times):
                generator_optimizer.zero_grad()

                image_batch_1, _, labels_batch_1 = next(loader)
                image_batch_3, __, labels_batch_3 = next(loader)
                # generator wants the fake pair to be labelled as real
                domain_labels.fill_(real_domain_labels)

                X_1.copy_(image_batch_1)
                X_3.copy_(image_batch_3)

                nv_3, nc_3 = encoder(Variable(X_3))
                # reconstruction is taking common factor from X_1 and varying factor from X_3
                reconstructed_X_3_1 = decoder(nv_3, encoder(Variable(X_1))[1])
                output = discriminator(Variable(X_1), reconstructed_X_3_1)

                generator_error = FLAGS.gen_coef * cross_entropy_loss(
                    output, Variable(domain_labels))
                generator_error.backward()

                if FLAGS.train_generator:
                    generator_optimizer.step()

            # print progress after 10 iterations
            # NOTE(review): assumes FLAGS.discriminator_times and
            # FLAGS.generator_times are >= 1; otherwise discriminator_error /
            # generator_error are unbound here — confirm the flag defaults
            if (iteration + 1) % 10 == 0:
                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))

                print('')
                print('Reconstruction loss: ' +
                      str(reconstruction_error.data.storage().tolist()[0]))
                print('Generator loss: ' +
                      str(generator_error.data.storage().tolist()[0]))

                print('')
                print('Discriminator loss: ' +
                      str(discriminator_error.data.storage().tolist()[0]))
                print('Discriminator accuracy: ' + str(discriminator_accuracy))

                print('..........')

                # write to log
                with open(FLAGS.log_file, 'a') as log:
                    log.write('{0}\t{1}\t{2}\t{3}\t{4}\t{5}\n'.format(
                        epoch, iteration,
                        reconstruction_error.data.storage().tolist()[0],
                        generator_error.data.storage().tolist()[0],
                        discriminator_error.data.storage().tolist()[0],
                        discriminator_accuracy))

                # write to tensorboard
                writer.add_scalar(
                    'Reconstruction loss',
                    reconstruction_error.data.storage().tolist()[0],
                    epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) +
                    iteration)
                writer.add_scalar(
                    'Generator loss',
                    generator_error.data.storage().tolist()[0],
                    epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) +
                    iteration)
                writer.add_scalar(
                    'Discriminator loss',
                    discriminator_error.data.storage().tolist()[0],
                    epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) +
                    iteration)

        # save model after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(),
                       os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(),
                       os.path.join('checkpoints', FLAGS.decoder_save))
            torch.save(discriminator.state_dict(),
                       os.path.join('checkpoints', FLAGS.discriminator_save))

            """
            save reconstructed images and style swapped image generations to check progress
            """
            image_batch_1, image_batch_2, labels_batch_1 = next(loader)
            image_batch_3, _, __ = next(loader)

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)
            X_3.copy_(image_batch_3)

            nv_1, nc_1 = encoder(Variable(X_1))
            nv_2, nc_2 = encoder(Variable(X_2))
            nv_3, nc_3 = encoder(Variable(X_3))

            reconstructed_X_1 = decoder(nv_1, nc_2)
            reconstructed_X_3_2 = decoder(nv_3, nc_2)

            # save input image batch (single channel replicated to 3 for display)
            image_batch = np.transpose(X_1.cpu().numpy(), (0, 2, 3, 1))
            image_batch = np.concatenate(
                (image_batch, image_batch, image_batch), axis=3)
            imshow_grid(image_batch, name=str(epoch) + '_original', save=True)

            # save reconstructed batch
            reconstructed_x = np.transpose(
                reconstructed_X_1.cpu().data.numpy(), (0, 2, 3, 1))
            reconstructed_x = np.concatenate(
                (reconstructed_x, reconstructed_x, reconstructed_x), axis=3)
            imshow_grid(reconstructed_x, name=str(epoch) + '_target', save=True)

            # save cross reconstructed batch
            style_batch = np.transpose(X_3.cpu().numpy(), (0, 2, 3, 1))
            style_batch = np.concatenate(
                (style_batch, style_batch, style_batch), axis=3)
            imshow_grid(style_batch, name=str(epoch) + '_style', save=True)

            reconstructed_style = np.transpose(
                reconstructed_X_3_2.cpu().data.numpy(), (0, 2, 3, 1))
            reconstructed_style = np.concatenate(
                (reconstructed_style, reconstructed_style, reconstructed_style),
                axis=3)
            imshow_grid(reconstructed_style,
                        name=str(epoch) + '_style_target',
                        save=True)
def training_procedure(FLAGS):
    """Train the style/class factorised VAE on paired CIFAR, with optional GANs.

    Per iteration:
      A. forward cycle — KL on the style posteriors plus reconstruction with
         the class factors swapped inside a pair (optionally followed by a
         forward GAN step when FLAGS.forward_gan is set),
      B. reverse cycle — decode two different class codes with one sampled
         style vector and pull their re-encoded style codes together with an
         L1 loss (optionally followed by a reverse GAN step when
         FLAGS.reverse_gan is set).
    Checkpoints and sample image grids are written every 5 epochs (and at the
    final epoch).
    """
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    discriminator = Discriminator()
    discriminator.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        # resume is intentionally disabled; the load calls below are dead code
        raise Exception('This is not implemented')
        encoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))

    """ variable definition """
    # pre-allocated input buffers; refilled in-place (copy_) each iteration
    X_1 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    X_2 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    X_3 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)

    style_latent_space = torch.FloatTensor(FLAGS.batch_size, FLAGS.style_dim)

    """ loss definitions """
    cross_entropy_loss = nn.CrossEntropyLoss()
    adversarial_loss = nn.BCELoss()

    ''' add option to run on GPU '''
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()
        discriminator.cuda()
        cross_entropy_loss.cuda()
        adversarial_loss.cuda()
        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()
        style_latent_space = style_latent_space.cuda()

    """ optimizer and scheduler definition """
    auto_encoder_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    # reverse cycle only updates the encoder
    reverse_cycle_optimizer = optim.Adam(
        list(encoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    # GAN generator role is played by the decoder alone
    generator_optimizer = optim.Adam(
        list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    discriminator_optimizer = optim.Adam(
        list(discriminator.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    # divide the learning rate by a factor of 10 after 80 epochs
    auto_encoder_scheduler = optim.lr_scheduler.StepLR(auto_encoder_optimizer, step_size=80, gamma=0.1)
    reverse_cycle_scheduler = optim.lr_scheduler.StepLR(reverse_cycle_optimizer, step_size=80, gamma=0.1)
    generator_scheduler = optim.lr_scheduler.StepLR(generator_optimizer, step_size=80, gamma=0.1)
    discriminator_scheduler = optim.lr_scheduler.StepLR(discriminator_optimizer, step_size=80, gamma=0.1)

    # Used later to define discriminator ground truths
    Tensor = torch.cuda.FloatTensor if FLAGS.cuda else torch.FloatTensor

    """ training """
    if torch.cuda.is_available() and not FLAGS.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    if not os.path.exists('reconstructed_images'):
        os.makedirs('reconstructed_images')

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            headers = ['Epoch', 'Iteration', 'Reconstruction_loss', 'KL_divergence_loss', 'Reverse_cycle_loss']
            if FLAGS.forward_gan:
                headers.extend(['Generator_forward_loss', 'Discriminator_forward_loss'])
            if FLAGS.reverse_gan:
                headers.extend(['Generator_reverse_loss', 'Discriminator_reverse_loss'])
            log.write('\t'.join(headers) + '\n')

    # load data set and create data loader instance
    print('Loading CIFAR paired dataset...')
    paired_cifar = CIFAR_Paired(root='cifar', download=True, train=True, transform=transform_config)
    # cycle() lets us draw batches indefinitely without tracking epochs
    loader = cycle(DataLoader(paired_cifar, batch_size=FLAGS.batch_size, shuffle=True, num_workers=0, drop_last=True))

    # Save a batch of images to use for visualization
    image_sample_1, image_sample_2, _ = next(loader)
    image_sample_3, _, _ = next(loader)

    # initialize summary writer
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print('Epoch #' + str(epoch) + '..........................................................................')

        # update the learning rate scheduler
        # NOTE(review): on PyTorch >= 1.1, calling scheduler.step() before any
        # optimizer.step() (as happens on the first epoch here) is warned
        # against — confirm the intended LR schedule offset
        auto_encoder_scheduler.step()
        reverse_cycle_scheduler.step()
        generator_scheduler.step()
        discriminator_scheduler.step()

        for iteration in range(int(len(paired_cifar) / FLAGS.batch_size)):
            # Adversarial ground truths
            valid = Variable(Tensor(FLAGS.batch_size, 1).fill_(1.0), requires_grad=False)
            fake = Variable(Tensor(FLAGS.batch_size, 1).fill_(0.0), requires_grad=False)

            # A. run the auto-encoder reconstruction
            image_batch_1, image_batch_2, _ = next(loader)

            auto_encoder_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_mu_1, style_logvar_1, class_latent_space_1 = encoder(Variable(X_1))
            style_latent_space_1 = reparameterize(training=True, mu=style_mu_1, logvar=style_logvar_1)

            # KL of the style posterior against a unit Gaussian, normalised by
            # the number of pixels in the batch
            kl_divergence_loss_1 = FLAGS.kl_divergence_coef * (
                - 0.5 * torch.sum(1 + style_logvar_1 - style_mu_1.pow(2) - style_logvar_1.exp())
            )
            kl_divergence_loss_1 /= (FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size)
            kl_divergence_loss_1.backward(retain_graph=True)

            style_mu_2, style_logvar_2, class_latent_space_2 = encoder(Variable(X_2))
            style_latent_space_2 = reparameterize(training=True, mu=style_mu_2, logvar=style_logvar_2)

            kl_divergence_loss_2 = FLAGS.kl_divergence_coef * (
                - 0.5 * torch.sum(1 + style_logvar_2 - style_mu_2.pow(2) - style_logvar_2.exp())
            )
            kl_divergence_loss_2 /= (FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size)
            kl_divergence_loss_2.backward(retain_graph=True)

            # swap the class factors between the pair before decoding
            reconstructed_X_1 = decoder(style_latent_space_1, class_latent_space_2)
            reconstructed_X_2 = decoder(style_latent_space_2, class_latent_space_1)

            reconstruction_error_1 = FLAGS.reconstruction_coef * mse_loss(reconstructed_X_1, Variable(X_1))
            reconstruction_error_1.backward(retain_graph=True)

            reconstruction_error_2 = FLAGS.reconstruction_coef * mse_loss(reconstructed_X_2, Variable(X_2))
            reconstruction_error_2.backward()

            # un-scaled values kept for logging only (gradients already applied)
            reconstruction_error = (reconstruction_error_1 + reconstruction_error_2) / FLAGS.reconstruction_coef
            kl_divergence_error = (kl_divergence_loss_1 + kl_divergence_loss_2) / FLAGS.kl_divergence_coef

            auto_encoder_optimizer.step()

            # A-1. Discriminator training during forward cycle
            if FLAGS.forward_gan:
                # Training generator
                generator_optimizer.zero_grad()

                g_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), valid)
                g_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), valid)

                gen_f_loss = (g_loss_1 + g_loss_2) / 2.0
                gen_f_loss.backward()

                generator_optimizer.step()

                # Training discriminator
                discriminator_optimizer.zero_grad()

                real_loss_1 = adversarial_loss(discriminator(Variable(X_1)), valid)
                real_loss_2 = adversarial_loss(discriminator(Variable(X_2)), valid)
                fake_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), fake)
                fake_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), fake)

                dis_f_loss = (real_loss_1 + real_loss_2 + fake_loss_1 + fake_loss_2) / 4.0
                dis_f_loss.backward()

                discriminator_optimizer.step()

            # B. reverse cycle
            image_batch_1, _, __ = next(loader)
            image_batch_2, _, __ = next(loader)

            reverse_cycle_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            # one style sample shared by both decodings
            style_latent_space.normal_(0., 1.)

            _, __, class_latent_space_1 = encoder(Variable(X_1))
            _, __, class_latent_space_2 = encoder(Variable(X_2))

            # class codes are detached so the first encoder pass receives no
            # gradient; only the re-encoding of the decoded images is trained
            reconstructed_X_1 = decoder(Variable(style_latent_space), class_latent_space_1.detach())
            reconstructed_X_2 = decoder(Variable(style_latent_space), class_latent_space_2.detach())

            style_mu_1, style_logvar_1, _ = encoder(reconstructed_X_1)
            style_latent_space_1 = reparameterize(training=False, mu=style_mu_1, logvar=style_logvar_1)

            style_mu_2, style_logvar_2, _ = encoder(reconstructed_X_2)
            style_latent_space_2 = reparameterize(training=False, mu=style_mu_2, logvar=style_logvar_2)

            reverse_cycle_loss = FLAGS.reverse_cycle_coef * l1_loss(style_latent_space_1, style_latent_space_2)
            reverse_cycle_loss.backward()
            # rescale after backward: logged value is the raw L1 term
            reverse_cycle_loss /= FLAGS.reverse_cycle_coef

            reverse_cycle_optimizer.step()

            # B-1. Discriminator training during reverse cycle
            if FLAGS.reverse_gan:
                # Training generator
                generator_optimizer.zero_grad()

                g_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), valid)
                g_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), valid)

                gen_r_loss = (g_loss_1 + g_loss_2) / 2.0
                gen_r_loss.backward()

                generator_optimizer.step()

                # Training discriminator
                discriminator_optimizer.zero_grad()

                real_loss_1 = adversarial_loss(discriminator(Variable(X_1)), valid)
                real_loss_2 = adversarial_loss(discriminator(Variable(X_2)), valid)
                fake_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), fake)
                fake_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), fake)

                dis_r_loss = (real_loss_1 + real_loss_2 + fake_loss_1 + fake_loss_2) / 4.0
                dis_r_loss.backward()

                discriminator_optimizer.step()

            # print / log progress every 10 iterations
            if (iteration + 1) % 10 == 0:
                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))

                print('')
                print('Reconstruction loss: ' + str(reconstruction_error.data.storage().tolist()[0]))
                print('KL-Divergence loss: ' + str(kl_divergence_error.data.storage().tolist()[0]))
                print('Reverse cycle loss: ' + str(reverse_cycle_loss.data.storage().tolist()[0]))

                if FLAGS.forward_gan:
                    print('Generator F loss: ' + str(gen_f_loss.data.storage().tolist()[0]))
                    print('Discriminator F loss: ' + str(dis_f_loss.data.storage().tolist()[0]))

                if FLAGS.reverse_gan:
                    print('Generator R loss: ' + str(gen_r_loss.data.storage().tolist()[0]))
                    print('Discriminator R loss: ' + str(dis_r_loss.data.storage().tolist()[0]))

                # write to log
                with open(FLAGS.log_file, 'a') as log:
                    row = []
                    row.append(epoch)
                    row.append(iteration)
                    row.append(reconstruction_error.data.storage().tolist()[0])
                    row.append(kl_divergence_error.data.storage().tolist()[0])
                    row.append(reverse_cycle_loss.data.storage().tolist()[0])
                    if FLAGS.forward_gan:
                        row.append(gen_f_loss.data.storage().tolist()[0])
                        row.append(dis_f_loss.data.storage().tolist()[0])
                    if FLAGS.reverse_gan:
                        row.append(gen_r_loss.data.storage().tolist()[0])
                        row.append(dis_r_loss.data.storage().tolist()[0])
                    row = [str(x) for x in row]
                    log.write('\t'.join(row) + '\n')

                # write to tensorboard
                writer.add_scalar('Reconstruction loss',
                                  reconstruction_error.data.storage().tolist()[0],
                                  epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)
                writer.add_scalar('KL-Divergence loss',
                                  kl_divergence_error.data.storage().tolist()[0],
                                  epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)
                writer.add_scalar('Reverse cycle loss',
                                  reverse_cycle_loss.data.storage().tolist()[0],
                                  epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)

                if FLAGS.forward_gan:
                    writer.add_scalar('Generator F loss',
                                      gen_f_loss.data.storage().tolist()[0],
                                      epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)
                    writer.add_scalar('Discriminator F loss',
                                      dis_f_loss.data.storage().tolist()[0],
                                      epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)

                if FLAGS.reverse_gan:
                    writer.add_scalar('Generator R loss',
                                      gen_r_loss.data.storage().tolist()[0],
                                      epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)
                    writer.add_scalar('Discriminator R loss',
                                      dis_r_loss.data.storage().tolist()[0],
                                      epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)

        # save model after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(), os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(), os.path.join('checkpoints', FLAGS.decoder_save))

            """
            save reconstructed images and style swapped image generations to check progress
            """
            # reuse the fixed sample batches captured before training started
            X_1.copy_(image_sample_1)
            X_2.copy_(image_sample_2)
            X_3.copy_(image_sample_3)

            style_mu_1, style_logvar_1, _ = encoder(Variable(X_1))
            _, __, class_latent_space_2 = encoder(Variable(X_2))
            style_mu_3, style_logvar_3, _ = encoder(Variable(X_3))

            style_latent_space_1 = reparameterize(training=False, mu=style_mu_1, logvar=style_logvar_1)
            style_latent_space_3 = reparameterize(training=False, mu=style_mu_3, logvar=style_logvar_3)

            reconstructed_X_1_2 = decoder(style_latent_space_1, class_latent_space_2)
            reconstructed_X_3_2 = decoder(style_latent_space_3, class_latent_space_2)

            # save input image batch (grayscale replicated to 3 channels)
            image_batch = np.transpose(X_1.cpu().numpy(), (0, 2, 3, 1))
            if FLAGS.num_channels == 1:
                image_batch = np.concatenate((image_batch, image_batch, image_batch), axis=3)
            imshow_grid(image_batch, name=str(epoch) + '_original', save=True)

            # save reconstructed batch
            reconstructed_x = np.transpose(reconstructed_X_1_2.cpu().data.numpy(), (0, 2, 3, 1))
            if FLAGS.num_channels == 1:
                reconstructed_x = np.concatenate((reconstructed_x, reconstructed_x, reconstructed_x), axis=3)
            imshow_grid(reconstructed_x, name=str(epoch) + '_target', save=True)

            style_batch = np.transpose(X_3.cpu().numpy(), (0, 2, 3, 1))
            if FLAGS.num_channels == 1:
                style_batch = np.concatenate((style_batch, style_batch, style_batch), axis=3)
            imshow_grid(style_batch, name=str(epoch) + '_style', save=True)

            # save style swapped reconstructed batch
            reconstructed_style = np.transpose(reconstructed_X_3_2.cpu().data.numpy(), (0, 2, 3, 1))
            if FLAGS.num_channels == 1:
                reconstructed_style = np.concatenate((reconstructed_style, reconstructed_style, reconstructed_style), axis=3)
            imshow_grid(reconstructed_style, name=str(epoch) + '_style_target', save=True)
def init_training(args):
    """Initialize the data loaders, the networks, the optimizers and the loss functions.

    Builds train/val loaders over `customed_dataset`, a Generator/Discriminator
    pair (plus a Memory_Network and a Feature_Integrator when args.use_memory
    is set), their Adam optimizers and the loss dictionary. When resuming
    (args.start_epoch > 0), every network is restored from the checkpoints of
    the previous epoch.

    Returns:
        use_memory=False: (global_step, device, data_loaders, generator,
                           discriminator, optimizers, losses)
        use_memory=True:  (global_step, device, data_loaders, mem,
                           feature_integrator, generator, discriminator,
                           optimizers, losses)
    """
    datasets = dict()
    datasets['train'] = customed_dataset(img_path=args.train_data_path,
                                         img_size=args.img_size,
                                         km_file_path=args.km_file_path)
    datasets['val'] = customed_dataset(img_path=args.val_data_path,
                                       img_size=args.img_size,
                                       km_file_path=args.km_file_path)
    for phase in ['train', 'val']:
        print('{} dataset len: {}'.format(phase, len(datasets[phase])))

    # define loaders; only the training split is shuffled
    data_loaders = {
        'train': DataLoader(datasets['train'], batch_size=args.batch_size,
                            shuffle=True, num_workers=args.num_workers),
        'val': DataLoader(datasets['val'], batch_size=args.batch_size,
                          shuffle=False, num_workers=args.num_workers)
    }

    # check CUDA availability and set device
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('Use GPU: {}'.format(str(device) != 'cpu'))

    # set up models (memory network + feature integrator only when requested)
    if args.use_memory:
        mem = Memory_Network(mem_size=args.mem_size,
                             color_feat_dim=args.color_feat_dim,
                             spatial_feat_dim=args.spatial_feat_dim,
                             top_k=args.top_k,
                             alpha=args.alpha).to(device)
        feature_integrator = Feature_Integrator(3, 1, 200).to(device)

    generator = Generator(args.color_feat_dim, args.img_size, args.gen_norm).to(device)
    discriminator = Discriminator(args.color_feat_dim, args.img_size, args.dis_norm).to(device)

    # initialize weights (optional, flag-controlled)
    if args.apply_weight_init:
        generator.apply(weights_init_normal)
        discriminator.apply(weights_init_normal)

    # set networks as training mode
    generator = generator.train()
    discriminator = discriminator.train()
    if args.use_memory:
        mem = mem.train()
        feature_integrator = feature_integrator.train()

    # adam optimizer
    if args.use_memory:
        optimizers = {
            'gen': torch.optim.Adam(generator.parameters(), lr=args.base_lr_gen, betas=(0.5, 0.999)),
            'disc': torch.optim.Adam(discriminator.parameters(), lr=args.base_lr_disc, betas=(0.5, 0.999)),
            'mem': torch.optim.Adam(mem.parameters(), lr=args.base_lr_mem),
            'feat': torch.optim.Adam(feature_integrator.parameters(), lr=args.base_lr_feat)
        }
    else:
        optimizers = {
            'gen': torch.optim.Adam(generator.parameters(), lr=args.base_lr_gen),
            'disc': torch.optim.Adam(discriminator.parameters(), lr=args.base_lr_disc),
        }

    # losses
    losses = {
        'l1': torch.nn.L1Loss(reduction='mean'),
        'disc': torch.nn.BCEWithLogitsLoss(reduction='mean'),
        'smoothl1': torch.nn.SmoothL1Loss(reduction='mean'),
        'KLD': torch.nn.KLDivLoss(reduction='batchmean')
    }

    # make save dir, if it does not exist
    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    # load weights if the training is not starting from the beginning;
    # global_step resumes as if all previous epochs had already been counted
    global_step = args.start_epoch * len(data_loaders['train']) if args.start_epoch > 0 else 0
    if args.start_epoch > 0:
        generator.load_state_dict(torch.load(
            os.path.join(args.save_path, 'checkpoint_ep{}_gen.pt'.format(args.start_epoch - 1)),
            map_location=device
        ))
        discriminator.load_state_dict(torch.load(
            os.path.join(args.save_path, 'checkpoint_ep{}_disc.pt'.format(args.start_epoch - 1)),
            map_location=device
        ))
        # BUG FIX: the memory checkpoints were previously loaded
        # unconditionally, raising a NameError on resume when
        # args.use_memory is False — they must be guarded.
        if args.use_memory:
            mem_checkpoint = torch.load(
                os.path.join(args.save_path, 'checkpoint_ep{}_mem.pt'.format(args.start_epoch - 1)),
                map_location=device)
            mem.load_state_dict(mem_checkpoint['mem_model'])
            # NOTE: 'sptial_key' matches the (misspelled) attribute name on
            # Memory_Network — do not "correct" the spelling here alone.
            mem.sptial_key = mem_checkpoint['mem_key']
            mem.color_value = mem_checkpoint['mem_value']
            mem.age = mem_checkpoint['mem_age']
            mem.img_id = mem_checkpoint['img_id']

            feature_integrator.load_state_dict(torch.load(
                os.path.join(args.save_path, 'checkpoint_ep{}_feat.pt'.format(args.start_epoch - 1)),
                map_location=device
            ))

    if args.use_memory:
        return global_step, device, data_loaders, mem, feature_integrator, generator, discriminator, optimizers, losses
    else:
        return global_step, device, data_loaders, generator, discriminator, optimizers, losses
def get_models(self):
    """Build a freshly weight-initialized generator/discriminator pair.

    Both networks are constructed, moved to ``self.device``, and have the
    ``weights_init`` initializer applied before being returned as ``(G, D)``.
    """
    nets = [Generator().to(self.device), Discriminator().to(self.device)]
    for net in nets:
        net.apply(weights_init)
    gen, disc = nets
    return gen, disc
def main(args, dataloader):
    """Train a BiGAN-style model: the discriminator judges (image, latent) pairs.

    netE encodes real images to a latent posterior (reparameterized sample),
    netG decodes random latents to images; netD receives (image, latent)
    pairs and is trained to tell (real image, inferred z) from (generated
    image, prior z). Decaying Gaussian instance noise is added to both image
    inputs of the discriminator. Sample grids and weights are saved under
    args.log_dir/args.run_name.
    """
    # define the networks
    netG = Generator(ngf=args.ngf, nz=args.nz, nc=args.nc).cuda()
    netG.apply(weight_init)
    print(netG)
    netD = Discriminator(ndf=args.ndf, nc=args.nc, nz=args.nz).cuda()
    netD.apply(weight_init)
    print(netD)
    netE = Encoder(nc=args.nc, ngf=args.ngf, nz=args.nz).cuda()
    netE.apply(weight_init)
    print(netE)

    # define the loss criterion
    criterion = nn.BCELoss()

    # define the ground truth labels.
    real_label = 1  # for the real pair
    fake_label = 0  # for the fake pair

    # define the optimizers, one for each network
    # (generator and encoder are updated jointly by netG_optimizer)
    netD_optimizer = optim.Adam(netD.parameters(), lr=args.lr, betas=(0.5, 0.999))
    netG_optimizer = optim.Adam([{
        'params': netG.parameters()
    }, {
        'params': netE.parameters()
    }], lr=args.lr, betas=(0.5, 0.999))

    # Training loop
    iters = 0
    for epoch in range(args.num_epochs):
        # iterate through the dataloader
        for i, data in enumerate(dataloader, 0):
            real_images = data[0].cuda()
            bs = real_images.shape[0]

            # instance noise whose std decays linearly from 0.1 to 0 over training
            noise1 = torch.Tensor(real_images.size()).normal_(
                0, 0.1 * (args.num_epochs - epoch) / args.num_epochs).cuda()
            noise2 = torch.Tensor(real_images.size()).normal_(
                0, 0.1 * (args.num_epochs - epoch) / args.num_epochs).cuda()

            # get the output from the encoder: first nz channels are the mean,
            # the rest parameterize the spread
            z_real = netE(real_images).view(bs, -1)
            mu, sigma = z_real[:, :args.nz], z_real[:, args.nz:]
            # NOTE(review): despite the name, log_sigma = exp(sigma) is the
            # standard deviation (sigma is treated as a log-std) — confirm
            # against the Encoder's output convention
            log_sigma = torch.exp(sigma)
            epsilon = torch.randn(bs, args.nz).cuda()

            # reparameterization trick
            output_z = mu + epsilon * log_sigma
            output_z = output_z.view(bs, -1, 1, 1)

            # get the output from the generator for a prior sample
            z_fake = torch.randn(bs, args.nz, 1, 1).cuda()
            d_fake = netG(z_fake)

            # get the output from the discriminator for the real pair
            out_real_pair = netD(real_images + noise1, output_z)
            # get the output from the discriminator for the fake pair
            out_fake_pair = netD(d_fake + noise2, z_fake)

            # NOTE(review): torch.full((bs,), 1) yields an integer tensor on
            # torch >= 1.5, while BCELoss requires float targets — confirm the
            # torch version this was written for (or that .cuda() casts here)
            real_labels = torch.full((bs, ), real_label).cuda()
            fake_labels = torch.full((bs, ), fake_label).cuda()

            # compute the losses; the generator/encoder objective uses swapped
            # labels (non-saturating form over both pairs)
            d_loss = criterion(out_real_pair, real_labels) + criterion(
                out_fake_pair, fake_labels)
            g_loss = criterion(out_real_pair, fake_labels) + criterion(
                out_fake_pair, real_labels)

            # update weights: D is only updated while G is not losing too badly
            # (g_loss < 3.5), to keep the adversarial game balanced;
            # retain_graph so g_loss.backward() can reuse the same graph
            if g_loss.item() < 3.5:
                netD_optimizer.zero_grad()
                d_loss.backward(retain_graph=True)
                netD_optimizer.step()
            netG_optimizer.zero_grad()
            g_loss.backward()
            netG_optimizer.step()

            # print the training losses
            if iters % 10 == 0:
                print(
                    '[%3d/%d][%3d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x, z): %.4f\tD(G(z), z): %.4f'
                    % (epoch, args.num_epochs, i, len(dataloader),
                       d_loss.item(), g_loss.item(),
                       out_real_pair.mean().item(),
                       out_fake_pair.mean().item()))

            # visualize the samples generated by the G.
            if iters % 500 == 0:
                out_dir = os.path.join(args.log_dir, args.run_name, 'out/')
                os.makedirs(out_dir, exist_ok=True)
                save_image(d_fake.cpu()[:64, ],
                           os.path.join(out_dir, str(iters).zfill(7) + '.png'),
                           nrow=8,
                           normalize=True)
                # save reconstructions (real images side by side with samples)
                recons_dir = os.path.join(args.log_dir, args.run_name, 'recons/')
                os.makedirs(recons_dir, exist_ok=True)
                save_image(torch.cat(
                    [real_images.cpu()[:8], d_fake.cpu()[:8, ]], dim=3),
                    os.path.join(recons_dir, str(iters).zfill(7) + '.png'),
                    nrow=1,
                    normalize=True)
            iters += 1

    # save weights
    save_dir = os.path.join(args.log_dir, args.run_name, 'weights')
    os.makedirs(save_dir, exist_ok=True)
    save_weights(netG, './%s/netG.pth' % (save_dir))
    save_weights(netE, './%s/netE.pth' % (save_dir))