def training_procedure(FLAGS):
    """Train the style/class encoder-decoder on the paired MNIST dataset.

    Every iteration runs two phases:

    A. Forward cycle: encode a pair of images, add a KL penalty on each style
       posterior, then decode each image from its own style code and the
       *other* image's class code, with an MSE reconstruction loss.
       Encoder and decoder are updated.
    B. Reverse cycle: decode one sampled style code with the (detached) class
       codes of two other images, re-encode both outputs and penalise the L1
       distance between the recovered style codes. Only the encoder is
       updated.

    Every 10 iterations the losses are printed, appended to FLAGS.log_file
    and written to TensorBoard. Every 5 epochs (and after the last epoch)
    encoder/decoder weights are saved under ./checkpoints and image grids
    are written via imshow_grid.

    Args:
        FLAGS: parsed options. Attributes read: style_dim, class_dim,
            load_saved, encoder_save, decoder_save, batch_size, num_channels,
            image_size, cuda, initial_learning_rate, beta_1, beta_2,
            log_file, start_epoch, end_epoch, kl_divergence_coef,
            reconstruction_coef, reverse_cycle_coef.
    """
    # ---- model definition ----
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))

    # ---- variable definition ----
    # Fixed-size input buffers; every batch is copied into them in place.
    buffer_shape = (FLAGS.batch_size, FLAGS.num_channels,
                    FLAGS.image_size, FLAGS.image_size)
    X_1 = torch.FloatTensor(*buffer_shape)
    X_2 = torch.FloatTensor(*buffer_shape)
    X_3 = torch.FloatTensor(*buffer_shape)

    # Buffer for the style codes drawn from N(0, 1) in the reverse cycle.
    style_latent_space = torch.FloatTensor(FLAGS.batch_size, FLAGS.style_dim)

    # ---- add option to run on GPU ----
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()

        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()

        style_latent_space = style_latent_space.cuda()

    # ---- optimizer and scheduler definition ----
    auto_encoder_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2))

    # The reverse cycle only updates the encoder.
    reverse_cycle_optimizer = optim.Adam(
        list(encoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2))

    # divide the learning rate by a factor of 10 after 80 epochs
    auto_encoder_scheduler = optim.lr_scheduler.StepLR(
        auto_encoder_optimizer, step_size=80, gamma=0.1)
    reverse_cycle_scheduler = optim.lr_scheduler.StepLR(
        reverse_cycle_optimizer, step_size=80, gamma=0.1)

    # ---- training ----
    if torch.cuda.is_available() and not FLAGS.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')
    if not os.path.exists('reconstructed_images'):
        os.makedirs('reconstructed_images')

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write('Epoch\tIteration\tReconstruction_loss\tKL_divergence_loss\tReverse_cycle_loss\n')

    # load data set and create data loader instance
    print('Loading MNIST paired dataset...')
    paired_mnist = MNIST_Paired(root='mnist', download=True, train=True,
                                transform=transform_config)
    loader = cycle(DataLoader(paired_mnist, batch_size=FLAGS.batch_size,
                              shuffle=True, num_workers=0, drop_last=True))

    iterations_per_epoch = int(len(paired_mnist) / FLAGS.batch_size)
    # TensorBoard x-axis stride per epoch. The "+ 1" is kept so logged step
    # values stay comparable with event files from earlier runs.
    tb_epoch_stride = iterations_per_epoch + 1
    # Normaliser applied to each KL term: number of pixel values per batch.
    pixels_per_batch = (FLAGS.batch_size * FLAGS.num_channels *
                        FLAGS.image_size * FLAGS.image_size)

    def _save_image_grid(images, name):
        # NCHW numpy batch -> NHWC, channel axis replicated 3x for display.
        grid = np.transpose(images, (0, 2, 3, 1))
        grid = np.concatenate((grid, grid, grid), axis=3)
        imshow_grid(grid, name=name, save=True)

    # initialize summary writer
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print('Epoch #' + str(epoch) + '..........................................................................')

        for iteration in range(iterations_per_epoch):
            # A. run the auto-encoder reconstruction
            image_batch_1, image_batch_2, _ = next(loader)

            auto_encoder_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_mu_1, style_logvar_1, class_latent_space_1 = encoder(Variable(X_1))
            style_latent_space_1 = reparameterize(training=True, mu=style_mu_1, logvar=style_logvar_1)

            # Gaussian KL to N(0, I) (assuming logvar is a log-variance),
            # scaled by the coefficient and normalised per pixel value.
            kl_divergence_loss_1 = FLAGS.kl_divergence_coef * (
                -0.5 * torch.sum(1 + style_logvar_1 - style_mu_1.pow(2) - style_logvar_1.exp()))
            kl_divergence_loss_1 /= pixels_per_batch
            kl_divergence_loss_1.backward(retain_graph=True)

            style_mu_2, style_logvar_2, class_latent_space_2 = encoder(Variable(X_2))
            style_latent_space_2 = reparameterize(training=True, mu=style_mu_2, logvar=style_logvar_2)

            kl_divergence_loss_2 = FLAGS.kl_divergence_coef * (
                -0.5 * torch.sum(1 + style_logvar_2 - style_mu_2.pow(2) - style_logvar_2.exp()))
            kl_divergence_loss_2 /= pixels_per_batch
            kl_divergence_loss_2.backward(retain_graph=True)

            # Decode each image with its own style but its partner's class code.
            reconstructed_X_1 = decoder(style_latent_space_1, class_latent_space_2)
            reconstructed_X_2 = decoder(style_latent_space_2, class_latent_space_1)

            reconstruction_error_1 = FLAGS.reconstruction_coef * mse_loss(reconstructed_X_1, Variable(X_1))
            reconstruction_error_1.backward(retain_graph=True)

            reconstruction_error_2 = FLAGS.reconstruction_coef * mse_loss(reconstructed_X_2, Variable(X_2))
            reconstruction_error_2.backward()

            # Coefficient-free values, used for logging only.
            reconstruction_error = (reconstruction_error_1 + reconstruction_error_2) / FLAGS.reconstruction_coef
            kl_divergence_error = (kl_divergence_loss_1 + kl_divergence_loss_2) / FLAGS.kl_divergence_coef

            auto_encoder_optimizer.step()

            # B. reverse cycle
            image_batch_1, _, __ = next(loader)
            image_batch_2, _, __ = next(loader)

            reverse_cycle_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_latent_space.normal_(0., 1.)

            _, __, class_latent_space_1 = encoder(Variable(X_1))
            _, __, class_latent_space_2 = encoder(Variable(X_2))

            # One shared style sample, two different (detached) class codes.
            reconstructed_X_1 = decoder(Variable(style_latent_space), class_latent_space_1.detach())
            reconstructed_X_2 = decoder(Variable(style_latent_space), class_latent_space_2.detach())

            style_mu_1, style_logvar_1, _ = encoder(reconstructed_X_1)
            style_latent_space_1 = reparameterize(training=False, mu=style_mu_1, logvar=style_logvar_1)

            style_mu_2, style_logvar_2, _ = encoder(reconstructed_X_2)
            style_latent_space_2 = reparameterize(training=False, mu=style_mu_2, logvar=style_logvar_2)

            reverse_cycle_loss = FLAGS.reverse_cycle_coef * l1_loss(style_latent_space_1, style_latent_space_2)
            reverse_cycle_loss.backward()
            reverse_cycle_loss /= FLAGS.reverse_cycle_coef

            reverse_cycle_optimizer.step()

            if (iteration + 1) % 10 == 0:
                # .item() replaces the deprecated .data.storage().tolist()[0].
                reconstruction_value = reconstruction_error.item()
                kl_divergence_value = kl_divergence_error.item()
                reverse_cycle_value = reverse_cycle_loss.item()

                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))
                print('')

                print('Reconstruction loss: ' + str(reconstruction_value))
                print('KL-Divergence loss: ' + str(kl_divergence_value))
                print('Reverse cycle loss: ' + str(reverse_cycle_value))

                # write to log
                with open(FLAGS.log_file, 'a') as log:
                    log.write('{0}\t{1}\t{2}\t{3}\t{4}\n'.format(
                        epoch, iteration, reconstruction_value,
                        kl_divergence_value, reverse_cycle_value))

                # write to tensorboard
                global_step = epoch * tb_epoch_stride + iteration
                writer.add_scalar('Reconstruction loss', reconstruction_value, global_step)
                writer.add_scalar('KL-Divergence loss', kl_divergence_value, global_step)
                writer.add_scalar('Reverse cycle loss', reverse_cycle_value, global_step)

        # Step the LR schedulers once per epoch, *after* the optimizers have
        # stepped. Stepping before any optimizer.step() (PyTorch >= 1.1)
        # skips the first schedule value and made the decay fire one epoch early.
        auto_encoder_scheduler.step()
        reverse_cycle_scheduler.step()

        # save model after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(), os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(), os.path.join('checkpoints', FLAGS.decoder_save))

            # save reconstructed images and style swapped image generations to check progress
            image_batch_1, image_batch_2, _ = next(loader)
            image_batch_3, _, __ = next(loader)

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)
            X_3.copy_(image_batch_3)

            # Visualisation only: no autograd graph is needed here.
            with torch.no_grad():
                style_mu_1, style_logvar_1, _ = encoder(Variable(X_1))
                _, __, class_latent_space_2 = encoder(Variable(X_2))
                style_mu_3, style_logvar_3, _ = encoder(Variable(X_3))

                style_latent_space_1 = reparameterize(training=False, mu=style_mu_1, logvar=style_logvar_1)
                style_latent_space_3 = reparameterize(training=False, mu=style_mu_3, logvar=style_logvar_3)

                reconstructed_X_1_2 = decoder(style_latent_space_1, class_latent_space_2)
                reconstructed_X_3_2 = decoder(style_latent_space_3, class_latent_space_2)

            # save input image batch
            _save_image_grid(X_1.cpu().numpy(), str(epoch) + '_original')
            # save reconstructed batch
            _save_image_grid(reconstructed_X_1_2.cpu().numpy(), str(epoch) + '_target')
            # save style source batch
            _save_image_grid(X_3.cpu().numpy(), str(epoch) + '_style')
            # save style swapped reconstructed batch
            _save_image_grid(reconstructed_X_3_2.cpu().numpy(), str(epoch) + '_style_target')
class LSGANs_Trainer(nn.Module):
    """Trainer for an encoder/decoder with style interpolators and two critics.

    Images are encoded into a (content, style) pair. Cross-domain images are
    produced by decoding one image's content with a style coming from
    ``interp_net_ab`` / ``interp_net_ba``. ``dis_update`` and ``gen_update``
    each perform one optimisation step. ``save_*`` / ``resume*`` persist and
    restore networks, optimizers and the ``best_iter`` / ``total_loss``
    bookkeeping.
    """

    def __init__(self, hyperparameters):
        super(LSGANs_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.encoder = Encoder(hyperparameters['input_dim_a'], hyperparameters['gen'])
        self.decoder = Decoder(hyperparameters['input_dim_a'], hyperparameters['gen'])
        self.dis_a = Discriminator()
        self.dis_b = Discriminator()
        self.interp_net_ab = Interpolator()
        self.interp_net_ba = Interpolator()
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # Setup the optimizers: one Adam per network, all sharing lr, betas
        # and weight decay, over that network's trainable parameters.
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        weight_decay = hyperparameters['weight_decay']

        def make_adam(net):
            return torch.optim.Adam(
                [p for p in net.parameters() if p.requires_grad],
                lr=lr, betas=(beta1, beta2), weight_decay=weight_decay)

        self.enc_opt = make_adam(self.encoder)
        self.dec_opt = make_adam(self.decoder)
        self.dis_a_opt = make_adam(self.dis_a)
        self.dis_b_opt = make_adam(self.dis_b)
        self.interp_ab_opt = make_adam(self.interp_net_ab)
        self.interp_ba_opt = make_adam(self.interp_net_ba)
        self._make_schedulers(hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if hyperparameters.get('vgg_w', 0) > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

        self.total_loss = 0
        self.best_iter = 0
        self.perceptural_loss = Perceptural_loss()

    def _make_schedulers(self, hyperparameters, *iterations):
        """(Re)create one scheduler per optimizer via ``get_scheduler``.

        ``iterations`` is forwarded verbatim, so __init__ calls get_scheduler
        with two arguments and the resume paths with three.
        """
        self.enc_scheduler = get_scheduler(self.enc_opt, hyperparameters, *iterations)
        self.dec_scheduler = get_scheduler(self.dec_opt, hyperparameters, *iterations)
        self.dis_a_scheduler = get_scheduler(self.dis_a_opt, hyperparameters, *iterations)
        self.dis_b_scheduler = get_scheduler(self.dis_b_opt, hyperparameters, *iterations)
        self.interp_ab_scheduler = get_scheduler(self.interp_ab_opt, hyperparameters, *iterations)
        self.interp_ba_scheduler = get_scheduler(self.interp_ba_opt, hyperparameters, *iterations)

    def recon_criterion(self, input, target):
        """Mean absolute error between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        """Translate x_a towards domain b and x_b towards domain a.

        Uses the same pairing as dis_update/gen_update: x_ab is x_a's content
        decoded with ``interp_net_ab(s_a, s_b, v)``, and x_ba is x_b's content
        decoded with ``interp_net_ba(s_b, s_a, v)``.

        Returns:
            (x_ab, x_ba)
        """
        self.eval()
        c_a, s_a_fake = self.encoder(x_a)
        c_b, s_b_fake = self.encoder(x_b)
        # Interpolation weights; match dtype/device of the style codes.
        self.v = torch.ones_like(s_a_fake)
        # decode (cross domain)
        s_ab_interp = self.interp_net_ab(s_a_fake, s_b_fake, self.v)
        s_ba_interp = self.interp_net_ba(s_b_fake, s_a_fake, self.v)
        x_ba = self.decoder(c_b, s_ba_interp)
        x_ab = self.decoder(c_a, s_ab_interp)
        self.train()
        return x_ab, x_ba

    def zero_grad(self):
        """Clear the gradients held by every optimizer."""
        for opt in (self.dis_a_opt, self.dis_b_opt, self.dec_opt,
                    self.enc_opt, self.interp_ab_opt, self.interp_ba_opt):
            opt.zero_grad()

    def dis_update(self, x_a, x_b, hyperparameters):
        """One critic step: critic-score gap plus gradient penalty, all scaled by gan_w."""
        self.zero_grad()
        # encode
        c_a, s_a = self.encoder(x_a)
        c_b, s_b = self.encoder(x_b)
        # decode (cross domain)
        self.v = torch.ones_like(s_a)
        s_a_interp = self.interp_net_ba(s_b, s_a, self.v)
        s_b_interp = self.interp_net_ab(s_a, s_b, self.v)
        x_ba = self.decoder(c_b, s_a_interp)
        x_ab = self.decoder(c_a, s_b_interp)

        x_a_feature = self.dis_a(x_a)
        x_ba_feature = self.dis_a(x_ba)
        x_b_feature = self.dis_b(x_b)
        x_ab_feature = self.dis_b(x_ab)
        # fake score minus real score
        self.loss_dis_a = (x_ba_feature - x_a_feature).mean()
        self.loss_dis_b = (x_ab_feature - x_b_feature).mean()

        # gradient penality
        self.loss_dis_a_gp = self.dis_a.calculate_gradient_penalty(x_ba, x_a)
        self.loss_dis_b_gp = self.dis_b.calculate_gradient_penalty(x_ab, x_b)
        gan_w = hyperparameters['gan_w']
        self.loss_dis_total = gan_w * self.loss_dis_a + \
            gan_w * self.loss_dis_b + \
            gan_w * self.loss_dis_a_gp + \
            gan_w * self.loss_dis_b_gp
        self.loss_dis_total.backward()
        self.total_loss += self.loss_dis_total.item()
        self.dis_a_opt.step()
        self.dis_b_opt.step()

    def gen_update(self, x_a, x_b, hyperparameters):
        """One generator step: update encoder, decoder and both interpolators."""
        self.zero_grad()
        recon_x_cyc_w = hyperparameters['recon_x_cyc_w']
        # 'vgg_w' is optional (see __init__); a missing key means "disabled".
        vgg_w = hyperparameters.get('vgg_w', 0)
        # encode
        c_a, s_a = self.encoder(x_a)
        c_b, s_b = self.encoder(x_b)
        # decode (within domain)
        x_a_recon = self.decoder(c_a, s_a)
        x_b_recon = self.decoder(c_b, s_b)
        # decode (cross domain)
        self.v = torch.ones_like(s_a)
        s_a_interp = self.interp_net_ba(s_b, s_a, self.v)
        s_b_interp = self.interp_net_ab(s_a, s_b, self.v)
        x_ba = self.decoder(c_b, s_a_interp)
        x_ab = self.decoder(c_a, s_b_interp)
        # encode again
        c_b_recon, s_a_recon = self.encoder(x_ba)
        c_a_recon, s_b_recon = self.encoder(x_ab)
        # decode again (only needed for the cycle reconstruction loss)
        x_aa = self.decoder(c_a_recon, s_a) if recon_x_cyc_w > 0 else None
        x_bb = self.decoder(c_b_recon, s_b) if recon_x_cyc_w > 0 else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aa, x_a) if recon_x_cyc_w > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bb, x_b) if recon_x_cyc_w > 0 else 0

        # perceptual loss
        self.loss_gen_vgg_a = self.perceptural_loss(x_a_recon, x_a) if vgg_w > 0 else 0
        self.loss_gen_vgg_b = self.perceptural_loss(x_b_recon, x_b) if vgg_w > 0 else 0
        # x_aa / x_bb only exist when the cycle loss is enabled.
        use_cycle_vgg = vgg_w > 0 and x_aa is not None
        self.loss_gen_vgg_aa = self.perceptural_loss(x_aa, x_a) if use_cycle_vgg else 0
        self.loss_gen_vgg_bb = self.perceptural_loss(x_bb, x_b) if use_cycle_vgg else 0

        # GAN loss
        x_ba_feature = self.dis_a(x_ba)
        x_ab_feature = self.dis_b(x_ab)
        self.loss_gen_adv_a = -x_ba_feature.mean()
        self.loss_gen_adv_b = -x_ab_feature.mean()

        # total loss (same terms and summation order as before)
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
            hyperparameters['gan_w'] * self.loss_gen_adv_b + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
            recon_x_cyc_w * self.loss_gen_cycrecon_x_a + \
            recon_x_cyc_w * self.loss_gen_cycrecon_x_b + \
            vgg_w * self.loss_gen_vgg_aa + \
            vgg_w * self.loss_gen_vgg_bb + \
            vgg_w * self.loss_gen_vgg_a + \
            vgg_w * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.total_loss += self.loss_gen_total.item()
        self.dec_opt.step()
        self.enc_opt.step()
        self.interp_ab_opt.step()
        self.interp_ba_opt.step()

    def sample(self, x_a, x_b):
        """Produce reconstructions, translations and cycle images, one pair at a time.

        Returns:
            (x_a, x_a_recon, x_ab, x_aa, x_b, x_b_recon, x_ba, x_bb), each
            generated tensor concatenated along dim 0 in input order.
        """
        self.eval()
        x_a_recon, x_b_recon, x_ab, x_ba, x_aa, x_bb = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            # Encode single examples; outputs keep a leading batch dim of 1.
            c_a, s_a = self.encoder(x_a[i].unsqueeze(0))
            c_b, s_b = self.encoder(x_b[i].unsqueeze(0))
            x_a_recon.append(self.decoder(c_a, s_a))
            x_b_recon.append(self.decoder(c_b, s_b))
            self.v = torch.ones_like(s_a)
            s_a_interp = self.interp_net_ba(s_b, s_a, self.v)
            s_b_interp = self.interp_net_ab(s_a, s_b, self.v)
            x_ab_i = self.decoder(c_a, s_b_interp)
            x_ba_i = self.decoder(c_b, s_a_interp)
            c_a_recon, _ = self.encoder(x_ab_i)
            c_b_recon, _ = self.encoder(x_ba_i)
            # Style codes are already batched, so they are used as-is (no
            # extra unsqueeze), exactly as for x_a_recon above.
            x_ab.append(x_ab_i)
            x_ba.append(x_ba_i)
            x_aa.append(self.decoder(c_a_recon, s_a))
            x_bb.append(self.decoder(c_b_recon, s_b))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ab, x_aa = torch.cat(x_ab), torch.cat(x_aa)
        x_ba, x_bb = torch.cat(x_ba), torch.cat(x_bb)
        self.train()
        return x_a, x_a_recon, x_ab, x_aa, x_b, x_b_recon, x_ba, x_bb

    def update_learning_rate(self):
        """Advance every existing learning-rate scheduler by one step."""
        # Uses the scheduler attribute names assigned in _make_schedulers;
        # there is no separate generator scheduler.
        for scheduler in (self.dis_a_scheduler, self.dis_b_scheduler,
                          self.enc_scheduler, self.dec_scheduler,
                          self.interp_ab_scheduler, self.interp_ba_scheduler):
            if scheduler is not None:
                scheduler.step()

    def _checkpointed_networks(self):
        """(file prefix, module) pairs for every network stored in a checkpoint."""
        return (('encoder', self.encoder), ('decoder', self.decoder),
                ('dis_a', self.dis_a), ('dis_b', self.dis_b),
                ('interp_ab', self.interp_net_ab), ('interp_ba', self.interp_net_ba))

    def _restore_training_state(self, state_dict, hyperparameters):
        """Load optimizer states and counters from ``state_dict``, then rebuild schedulers."""
        self.enc_opt.load_state_dict(state_dict['enc_opt'])
        self.dec_opt.load_state_dict(state_dict['dec_opt'])
        self.dis_a_opt.load_state_dict(state_dict['dis_a_opt'])
        self.dis_b_opt.load_state_dict(state_dict['dis_b_opt'])
        self.interp_ab_opt.load_state_dict(state_dict['interp_ab_opt'])
        self.interp_ba_opt.load_state_dict(state_dict['interp_ba_opt'])
        self.best_iter = state_dict['best_iter']
        self.total_loss = state_dict['total_loss']
        # Reinitilize schedulers (same attribute names as in __init__)
        self._make_schedulers(hyperparameters, self.best_iter)
        print('Resume from iteration %d' % self.best_iter)
        return self.best_iter, self.total_loss

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore networks (found via ``get_model``) and ``optimizer.pt`` state.

        Returns:
            (best_iter, total_loss) stored in the optimizer checkpoint.
        """
        for prefix, net in self._checkpointed_networks():
            net.load_state_dict(torch.load(get_model(checkpoint_dir, prefix)))
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        return self._restore_training_state(state_dict, hyperparameters)

    def resume_iter(self, checkpoint_dir, surfix, hyperparameters):
        """Restore the snapshot whose files are named ``<net><surfix>.pt``.

        Returns:
            (best_iter, total_loss) stored in ``optimizer<surfix>.pt``.
        """
        for prefix, net in (('encoder', self.encoder), ('decoder', self.decoder),
                            ('dis_a', self.dis_a), ('dis_b', self.dis_b)):
            net.load_state_dict(
                torch.load(os.path.join(checkpoint_dir, prefix + surfix + '.pt')))

        # Interpolators are written as separate interp_ab / interp_ba files
        # (see save_at_iter). A combined 'interp<surfix>.pt' holding
        # {'ab': ..., 'ba': ...} is only used when those files are absent.
        interp_ab_path = os.path.join(checkpoint_dir, 'interp_ab' + surfix + '.pt')
        interp_ba_path = os.path.join(checkpoint_dir, 'interp_ba' + surfix + '.pt')
        if os.path.exists(interp_ab_path) and os.path.exists(interp_ba_path):
            self.interp_net_ab.load_state_dict(torch.load(interp_ab_path))
            self.interp_net_ba.load_state_dict(torch.load(interp_ba_path))
        else:
            state_dict = torch.load(os.path.join(checkpoint_dir, 'interp' + surfix + '.pt'))
            self.interp_net_ab.load_state_dict(state_dict['ab'])
            self.interp_net_ba.load_state_dict(state_dict['ba'])

        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer' + surfix + '.pt'))
        return self._restore_training_state(state_dict, hyperparameters)

    def _save_snapshot(self, snapshot_dir, tag, opt_filename):
        """Save each network as ``<prefix>_<tag>.pt`` and optimizer state to ``opt_filename``."""
        # Discriminators are saved as *network* weights; resume() loads these
        # files into dis_a / dis_b.
        for prefix, net in self._checkpointed_networks():
            torch.save(net.state_dict(),
                       os.path.join(snapshot_dir, prefix + '_' + tag + '.pt'))
        torch.save(
            {
                'enc_opt': self.enc_opt.state_dict(),
                'dec_opt': self.dec_opt.state_dict(),
                'dis_a_opt': self.dis_a_opt.state_dict(),
                'dis_b_opt': self.dis_b_opt.state_dict(),
                'interp_ab_opt': self.interp_ab_opt.state_dict(),
                'interp_ba_opt': self.interp_ba_opt.state_dict(),
                'best_iter': self.best_iter,
                'total_loss': self.total_loss
            }, os.path.join(snapshot_dir, opt_filename))

    def save_better_model(self, snapshot_dir):
        """Replace everything in snapshot_dir with a snapshot tagged by total_loss."""
        # remove sub_optimal models
        for f in glob.glob(snapshot_dir + '/*'):
            os.remove(f)
        self._save_snapshot(snapshot_dir, '%.4f' % (self.total_loss), 'optimizer.pt')

    def save_at_iter(self, snapshot_dir, iterations):
        """Write a snapshot tagged with the zero-padded iteration count ``iterations + 1``."""
        tag = '%08d' % (iterations + 1)
        self._save_snapshot(snapshot_dir, tag, 'optimizer_' + tag + '.pt')
def training_procedure(FLAGS): """ model definition """ encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim) encoder.apply(weights_init) decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim) decoder.apply(weights_init) discriminator = Discriminator() discriminator.apply(weights_init) # load saved models if load_saved flag is true if FLAGS.load_saved: raise Exception('This is not implemented') encoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.encoder_save))) decoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.decoder_save))) """ variable definition """ X_1 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size) X_2 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size) X_3 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size) style_latent_space = torch.FloatTensor(FLAGS.batch_size, FLAGS.style_dim) """ loss definitions """ cross_entropy_loss = nn.CrossEntropyLoss() adversarial_loss = nn.BCELoss() ''' add option to run on GPU ''' if FLAGS.cuda: encoder.cuda() decoder.cuda() discriminator.cuda() cross_entropy_loss.cuda() adversarial_loss.cuda() X_1 = X_1.cuda() X_2 = X_2.cuda() X_3 = X_3.cuda() style_latent_space = style_latent_space.cuda() """ optimizer and scheduler definition """ auto_encoder_optimizer = optim.Adam( list(encoder.parameters()) + list(decoder.parameters()), lr=FLAGS.initial_learning_rate, betas=(FLAGS.beta_1, FLAGS.beta_2) ) reverse_cycle_optimizer = optim.Adam( list(encoder.parameters()), lr=FLAGS.initial_learning_rate, betas=(FLAGS.beta_1, FLAGS.beta_2) ) generator_optimizer = optim.Adam( list(decoder.parameters()), lr=FLAGS.initial_learning_rate, betas=(FLAGS.beta_1, FLAGS.beta_2) ) discriminator_optimizer = optim.Adam( list(discriminator.parameters()), lr=FLAGS.initial_learning_rate, betas=(FLAGS.beta_1, FLAGS.beta_2) ) # divide the learning rate by a factor 
of 10 after 80 epochs auto_encoder_scheduler = optim.lr_scheduler.StepLR(auto_encoder_optimizer, step_size=80, gamma=0.1) reverse_cycle_scheduler = optim.lr_scheduler.StepLR(reverse_cycle_optimizer, step_size=80, gamma=0.1) generator_scheduler = optim.lr_scheduler.StepLR(generator_optimizer, step_size=80, gamma=0.1) discriminator_scheduler = optim.lr_scheduler.StepLR(discriminator_optimizer, step_size=80, gamma=0.1) # Used later to define discriminator ground truths Tensor = torch.cuda.FloatTensor if FLAGS.cuda else torch.FloatTensor """ training """ if torch.cuda.is_available() and not FLAGS.cuda: print("WARNING: You have a CUDA device, so you should probably run with --cuda") if not os.path.exists('checkpoints'): os.makedirs('checkpoints') if not os.path.exists('reconstructed_images'): os.makedirs('reconstructed_images') # load_saved is false when training is started from 0th iteration if not FLAGS.load_saved: with open(FLAGS.log_file, 'w') as log: headers = ['Epoch', 'Iteration', 'Reconstruction_loss', 'KL_divergence_loss', 'Reverse_cycle_loss'] if FLAGS.forward_gan: headers.extend(['Generator_forward_loss', 'Discriminator_forward_loss']) if FLAGS.reverse_gan: headers.extend(['Generator_reverse_loss', 'Discriminator_reverse_loss']) log.write('\t'.join(headers) + '\n') # load data set and create data loader instance print('Loading CIFAR paired dataset...') paired_cifar = CIFAR_Paired(root='cifar', download=True, train=True, transform=transform_config) loader = cycle(DataLoader(paired_cifar, batch_size=FLAGS.batch_size, shuffle=True, num_workers=0, drop_last=True)) # Save a batch of images to use for visualization image_sample_1, image_sample_2, _ = next(loader) image_sample_3, _, _ = next(loader) # initialize summary writer writer = SummaryWriter() for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch): print('') print('Epoch #' + str(epoch) + '..........................................................................') # update the learning rate scheduler 
auto_encoder_scheduler.step() reverse_cycle_scheduler.step() generator_scheduler.step() discriminator_scheduler.step() for iteration in range(int(len(paired_cifar) / FLAGS.batch_size)): # Adversarial ground truths valid = Variable(Tensor(FLAGS.batch_size, 1).fill_(1.0), requires_grad=False) fake = Variable(Tensor(FLAGS.batch_size, 1).fill_(0.0), requires_grad=False) # A. run the auto-encoder reconstruction image_batch_1, image_batch_2, _ = next(loader) auto_encoder_optimizer.zero_grad() X_1.copy_(image_batch_1) X_2.copy_(image_batch_2) style_mu_1, style_logvar_1, class_latent_space_1 = encoder(Variable(X_1)) style_latent_space_1 = reparameterize(training=True, mu=style_mu_1, logvar=style_logvar_1) kl_divergence_loss_1 = FLAGS.kl_divergence_coef * ( - 0.5 * torch.sum(1 + style_logvar_1 - style_mu_1.pow(2) - style_logvar_1.exp()) ) kl_divergence_loss_1 /= (FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size) kl_divergence_loss_1.backward(retain_graph=True) style_mu_2, style_logvar_2, class_latent_space_2 = encoder(Variable(X_2)) style_latent_space_2 = reparameterize(training=True, mu=style_mu_2, logvar=style_logvar_2) kl_divergence_loss_2 = FLAGS.kl_divergence_coef * ( - 0.5 * torch.sum(1 + style_logvar_2 - style_mu_2.pow(2) - style_logvar_2.exp()) ) kl_divergence_loss_2 /= (FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size) kl_divergence_loss_2.backward(retain_graph=True) reconstructed_X_1 = decoder(style_latent_space_1, class_latent_space_2) reconstructed_X_2 = decoder(style_latent_space_2, class_latent_space_1) reconstruction_error_1 = FLAGS.reconstruction_coef * mse_loss(reconstructed_X_1, Variable(X_1)) reconstruction_error_1.backward(retain_graph=True) reconstruction_error_2 = FLAGS.reconstruction_coef * mse_loss(reconstructed_X_2, Variable(X_2)) reconstruction_error_2.backward() reconstruction_error = (reconstruction_error_1 + reconstruction_error_2) / FLAGS.reconstruction_coef kl_divergence_error = 
(kl_divergence_loss_1 + kl_divergence_loss_2) / FLAGS.kl_divergence_coef auto_encoder_optimizer.step() # A-1. Discriminator training during forward cycle if FLAGS.forward_gan: # Training generator generator_optimizer.zero_grad() g_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), valid) g_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), valid) gen_f_loss = (g_loss_1 + g_loss_2) / 2.0 gen_f_loss.backward() generator_optimizer.step() # Training discriminator discriminator_optimizer.zero_grad() real_loss_1 = adversarial_loss(discriminator(Variable(X_1)), valid) real_loss_2 = adversarial_loss(discriminator(Variable(X_2)), valid) fake_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), fake) fake_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), fake) dis_f_loss = (real_loss_1 + real_loss_2 + fake_loss_1 + fake_loss_2) / 4.0 dis_f_loss.backward() discriminator_optimizer.step() # B. reverse cycle image_batch_1, _, __ = next(loader) image_batch_2, _, __ = next(loader) reverse_cycle_optimizer.zero_grad() X_1.copy_(image_batch_1) X_2.copy_(image_batch_2) style_latent_space.normal_(0., 1.) _, __, class_latent_space_1 = encoder(Variable(X_1)) _, __, class_latent_space_2 = encoder(Variable(X_2)) reconstructed_X_1 = decoder(Variable(style_latent_space), class_latent_space_1.detach()) reconstructed_X_2 = decoder(Variable(style_latent_space), class_latent_space_2.detach()) style_mu_1, style_logvar_1, _ = encoder(reconstructed_X_1) style_latent_space_1 = reparameterize(training=False, mu=style_mu_1, logvar=style_logvar_1) style_mu_2, style_logvar_2, _ = encoder(reconstructed_X_2) style_latent_space_2 = reparameterize(training=False, mu=style_mu_2, logvar=style_logvar_2) reverse_cycle_loss = FLAGS.reverse_cycle_coef * l1_loss(style_latent_space_1, style_latent_space_2) reverse_cycle_loss.backward() reverse_cycle_loss /= FLAGS.reverse_cycle_coef reverse_cycle_optimizer.step() # B-1. 
Discriminator training during reverse cycle if FLAGS.reverse_gan: # Training generator generator_optimizer.zero_grad() g_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), valid) g_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), valid) gen_r_loss = (g_loss_1 + g_loss_2) / 2.0 gen_r_loss.backward() generator_optimizer.step() # Training discriminator discriminator_optimizer.zero_grad() real_loss_1 = adversarial_loss(discriminator(Variable(X_1)), valid) real_loss_2 = adversarial_loss(discriminator(Variable(X_2)), valid) fake_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), fake) fake_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), fake) dis_r_loss = (real_loss_1 + real_loss_2 + fake_loss_1 + fake_loss_2) / 4.0 dis_r_loss.backward() discriminator_optimizer.step() if (iteration + 1) % 10 == 0: print('') print('Epoch #' + str(epoch)) print('Iteration #' + str(iteration)) print('') print('Reconstruction loss: ' + str(reconstruction_error.data.storage().tolist()[0])) print('KL-Divergence loss: ' + str(kl_divergence_error.data.storage().tolist()[0])) print('Reverse cycle loss: ' + str(reverse_cycle_loss.data.storage().tolist()[0])) if FLAGS.forward_gan: print('Generator F loss: ' + str(gen_f_loss.data.storage().tolist()[0])) print('Discriminator F loss: ' + str(dis_f_loss.data.storage().tolist()[0])) if FLAGS.reverse_gan: print('Generator R loss: ' + str(gen_r_loss.data.storage().tolist()[0])) print('Discriminator R loss: ' + str(dis_r_loss.data.storage().tolist()[0])) # write to log with open(FLAGS.log_file, 'a') as log: row = [] row.append(epoch) row.append(iteration) row.append(reconstruction_error.data.storage().tolist()[0]) row.append(kl_divergence_error.data.storage().tolist()[0]) row.append(reverse_cycle_loss.data.storage().tolist()[0]) if FLAGS.forward_gan: row.append(gen_f_loss.data.storage().tolist()[0]) row.append(dis_f_loss.data.storage().tolist()[0]) if FLAGS.reverse_gan: 
row.append(gen_r_loss.data.storage().tolist()[0]) row.append(dis_r_loss.data.storage().tolist()[0]) row = [str(x) for x in row] log.write('\t'.join(row) + '\n') # write to tensorboard writer.add_scalar('Reconstruction loss', reconstruction_error.data.storage().tolist()[0], epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration) writer.add_scalar('KL-Divergence loss', kl_divergence_error.data.storage().tolist()[0], epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration) writer.add_scalar('Reverse cycle loss', reverse_cycle_loss.data.storage().tolist()[0], epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration) if FLAGS.forward_gan: writer.add_scalar('Generator F loss', gen_f_loss.data.storage().tolist()[0], epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration) writer.add_scalar('Discriminator F loss', dis_f_loss.data.storage().tolist()[0], epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration) if FLAGS.reverse_gan: writer.add_scalar('Generator R loss', gen_r_loss.data.storage().tolist()[0], epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration) writer.add_scalar('Discriminator R loss', dis_r_loss.data.storage().tolist()[0], epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration) # save model after every 5 epochs if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch: torch.save(encoder.state_dict(), os.path.join('checkpoints', FLAGS.encoder_save)) torch.save(decoder.state_dict(), os.path.join('checkpoints', FLAGS.decoder_save)) """ save reconstructed images and style swapped image generations to check progress """ X_1.copy_(image_sample_1) X_2.copy_(image_sample_2) X_3.copy_(image_sample_3) style_mu_1, style_logvar_1, _ = encoder(Variable(X_1)) _, __, class_latent_space_2 = encoder(Variable(X_2)) style_mu_3, style_logvar_3, _ = encoder(Variable(X_3)) style_latent_space_1 = reparameterize(training=False, mu=style_mu_1, logvar=style_logvar_1) style_latent_space_3 = 
reparameterize(training=False, mu=style_mu_3, logvar=style_logvar_3) reconstructed_X_1_2 = decoder(style_latent_space_1, class_latent_space_2) reconstructed_X_3_2 = decoder(style_latent_space_3, class_latent_space_2) # save input image batch image_batch = np.transpose(X_1.cpu().numpy(), (0, 2, 3, 1)) if FLAGS.num_channels == 1: image_batch = np.concatenate((image_batch, image_batch, image_batch), axis=3) imshow_grid(image_batch, name=str(epoch) + '_original', save=True) # save reconstructed batch reconstructed_x = np.transpose(reconstructed_X_1_2.cpu().data.numpy(), (0, 2, 3, 1)) if FLAGS.num_channels == 1: reconstructed_x = np.concatenate((reconstructed_x, reconstructed_x, reconstructed_x), axis=3) imshow_grid(reconstructed_x, name=str(epoch) + '_target', save=True) style_batch = np.transpose(X_3.cpu().numpy(), (0, 2, 3, 1)) if FLAGS.num_channels == 1: style_batch = np.concatenate((style_batch, style_batch, style_batch), axis=3) imshow_grid(style_batch, name=str(epoch) + '_style', save=True) # save style swapped reconstructed batch reconstructed_style = np.transpose(reconstructed_X_3_2.cpu().data.numpy(), (0, 2, 3, 1)) if FLAGS.num_channels == 1: reconstructed_style = np.concatenate((reconstructed_style, reconstructed_style, reconstructed_style), axis=3) imshow_grid(reconstructed_style, name=str(epoch) + '_style_target', save=True)
save_image(original_sample, OUTPUT_PATH + '/epoch={}_original.png'.format(str(epoch)), nrow=NUM_FRAMES, normalize=True) save_image(decoded_sample, OUTPUT_PATH + '/epoch={}_recon.png'.format(str(epoch)), nrow=NUM_FRAMES, normalize=True) epoch_loss /= 3 if epoch_loss < lowest_loss: lowest_loss = epoch_loss # save checkpoints torch.save(encoder.state_dict(), os.path.join('checkpoints', ENCODER_SAVE)) torch.save(decoder.state_dict(), os.path.join('checkpoints', DECODER_SAVE)) print('Model Saved! Epoch loss at {}'.format(lowest_loss)) encoder_test.load_state_dict( torch.load(os.path.join('checkpoints', ENCODER_SAVE))) decoder_test.load_state_dict( torch.load(os.path.join('checkpoints', DECODER_SAVE))) video1 = next(loader).float().cuda()[0].unsqueeze(0) video2 = next(loader).float().cuda()[0].unsqueeze(0) X1_v1, KL1_v1, muL1_v1, det_q1_v1 = encoder_test(video1, BATCH_SIZE=1)
def training_procedure(FLAGS):
    """Train the encoder/decoder pair on MNIST with style and grouped class latents.

    Each iteration back-propagates three terms in turn: the KL divergence of
    the style posterior, the KL divergence of the label-grouped class evidence
    returned by ``accumulate_group_evidence``, and the MSE reconstruction
    error. Every 50 iterations the three losses are printed, appended to
    ``FLAGS.log_file`` and written to TensorBoard. The encoder and decoder
    are saved to ``checkpoints/`` every 5 epochs and after the last epoch.

    Args:
        FLAGS: parsed command-line options. The fields read here are model
            sizes (``style_dim``, ``class_dim``), input geometry
            (``batch_size``, ``num_channels``, ``image_size``), optimiser
            settings (``initial_learning_rate``, ``beta_1``, ``beta_2``),
            loss coefficients (``kl_divergence_coef``,
            ``reconstruction_coef``), the epoch range (``start_epoch``,
            ``end_epoch``), file names (``encoder_save``, ``decoder_save``,
            ``log_file``) and the ``cuda`` / ``load_saved`` switches.
    """
    # model definition
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)
    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))

    # single-channel input buffer that every mini-batch is copied into
    X = torch.FloatTensor(FLAGS.batch_size, 1, FLAGS.image_size,
                          FLAGS.image_size)

    # add option to run on GPU
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()
        X = X.cuda()

    # one optimiser updates encoder and decoder jointly
    auto_encoder_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2))

    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    # load_saved is false when training is started from 0th iteration, so the
    # log is only truncated (and given a header) for a fresh run
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write(
                'Epoch\tIteration\tReconstruction_loss\tStyle_KL_divergence_loss\tClass_KL_divergence_loss\n'
            )

    # load data set and create an endlessly cycling data loader
    print('Loading MNIST dataset...')
    mnist = datasets.MNIST(root='mnist',
                           download=True,
                           train=True,
                           transform=transform_config)
    loader = cycle(
        DataLoader(mnist,
                   batch_size=FLAGS.batch_size,
                   shuffle=True,
                   num_workers=0,
                   drop_last=True))

    iterations_per_epoch = int(len(mnist) / FLAGS.batch_size)
    # both KL terms are divided by the number of input values in a mini-batch
    kl_normalizer = (FLAGS.batch_size * FLAGS.num_channels *
                     FLAGS.image_size * FLAGS.image_size)

    writer = SummaryWriter()
    try:
        for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
            print('')
            print(
                'Epoch #' + str(epoch) +
                '..........................................................................'
            )

            for iteration in range(iterations_per_epoch):
                # load a mini-batch
                image_batch, labels_batch = next(loader)

                auto_encoder_optimizer.zero_grad()
                X.copy_(image_batch)

                style_mu, style_logvar, class_mu, class_logvar = encoder(X)
                # NOTE(review): the class statistics are passed as `.data`, i.e.
                # detached from autograd; confirm accumulate_group_evidence
                # still lets gradients reach the class branch of the encoder.
                grouped_mu, grouped_logvar = accumulate_group_evidence(
                    class_mu.data, class_logvar.data, labels_batch, FLAGS.cuda)

                # kl-divergence error for style latent space
                style_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                    -0.5 * torch.sum(1 + style_logvar - style_mu.pow(2) -
                                     style_logvar.exp()))
                style_kl_divergence_loss /= kl_normalizer
                # keep the graph: the later backward passes reuse the encoder
                style_kl_divergence_loss.backward(retain_graph=True)

                # kl-divergence error for class latent space
                class_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                    -0.5 * torch.sum(1 + grouped_logvar - grouped_mu.pow(2) -
                                     grouped_logvar.exp()))
                class_kl_divergence_loss /= kl_normalizer
                class_kl_divergence_loss.backward(retain_graph=True)

                # Reconstruct samples. One class embedding is drawn per group:
                # sampling from the group mu/logvar independently for every
                # image would make the decoder treat the class latent as noise
                # and ignore it.
                style_latent_embeddings = reparameterize(training=True,
                                                         mu=style_mu,
                                                         logvar=style_logvar)
                class_latent_embeddings = group_wise_reparameterize(
                    training=True,
                    mu=grouped_mu,
                    logvar=grouped_logvar,
                    labels_batch=labels_batch,
                    cuda=FLAGS.cuda)

                reconstructed_images = decoder(style_latent_embeddings,
                                               class_latent_embeddings)
                reconstruction_error = FLAGS.reconstruction_coef * mse_loss(
                    reconstructed_images, X)
                reconstruction_error.backward()

                auto_encoder_optimizer.step()

                if (iteration + 1) % 50 == 0:
                    reconstruction_value = reconstruction_error.item()
                    style_kl_value = style_kl_divergence_loss.item()
                    class_kl_value = class_kl_divergence_loss.item()

                    print('')
                    print('Epoch #' + str(epoch))
                    print('Iteration #' + str(iteration))
                    print('')
                    print('Reconstruction loss: ' + str(reconstruction_value))
                    print('Style KL-Divergence loss: ' + str(style_kl_value))
                    print('Class KL-Divergence loss: ' + str(class_kl_value))

                    # write to log
                    with open(FLAGS.log_file, 'a') as log:
                        log.write('{0}\t{1}\t{2}\t{3}\t{4}\n'.format(
                            epoch, iteration, reconstruction_value,
                            style_kl_value, class_kl_value))

                    # write to tensorboard
                    global_step = (epoch * (iterations_per_epoch + 1) +
                                   iteration)
                    writer.add_scalar('Reconstruction loss',
                                      reconstruction_value, global_step)
                    writer.add_scalar('Style KL-Divergence loss',
                                      style_kl_value, global_step)
                    writer.add_scalar('Class KL-Divergence loss',
                                      class_kl_value, global_step)

            # save checkpoints after every 5 epochs and after the last one
            if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
                torch.save(encoder.state_dict(),
                           os.path.join('checkpoints', FLAGS.encoder_save))
                torch.save(decoder.state_dict(),
                           os.path.join('checkpoints', FLAGS.decoder_save))
    finally:
        # flush pending TensorBoard events even if training is interrupted
        writer.close()
def _train_print_loss_tables(loss_names, losses):
    """Print *losses* as two PrettyTables: first half of *loss_names*, then second half."""
    half = len(loss_names) // 2
    for names, values in ((loss_names[:half], losses[:half]),
                          (loss_names[half:], losses[half:])):
        table = PrettyTable(field_names=names)
        for name in names:
            # PrettyTable's horizontal `align` takes 'l', 'c' or 'r'
            # ('m' is a vertical-alignment code).
            table.align[name] = 'c'
        table.add_row(values)
        print(table.get_string(reversesort=True))


def _train_save_visualization(panels, path):
    """Draw sample 0 of each (title, batch) panel on a 2x4 grid and save to *path*.

    Batches are detached and copied to the CPU regardless of their device.
    The figure is closed afterwards so figures do not accumulate across calls.
    """
    for position, (title, batch) in enumerate(panels, start=1):
        plt.subplot(2, 4, position)
        plt.title(title)
        plt.imshow(tensor2image_RGB(batch.detach().cpu()[0, ...]))
    plt.savefig(path)
    plt.close()


def _train_save_networks(networks, checkpoints_pth, epoch, device, on_gpu):
    """Save each network's state_dict as ``epoch_%3d-<tag>.pth`` in *checkpoints_pth*.

    On a GPU run the networks are moved to the CPU while saving, so the files
    hold CPU tensors, and then moved back to *device*.
    ``Module.cpu()``/``Module.to()`` move a module in place, so no rebinding
    is needed.
    """
    if on_gpu:
        for net in networks.values():
            net.cpu()
    for tag, net in networks.items():
        torch.save(net.state_dict(),
                   os.path.join(checkpoints_pth,
                                'epoch_%3d-%s.pth' % (epoch, tag)))
    if on_gpu:
        for net in networks.values():
            net.to(device)


def train(opt):
    """Train the two-domain (A <-> B) latent-translation GAN configured by *opt*.

    Each domain has four networks: an encoder into the other domain's latent
    space, a decoder, a latent translator and a discriminator.

    The generators are optimised on a weighted sum of five losses: GAN,
    identity, latent cross-identity, latent cross-translation and latent
    cycle. The discriminators use the real/fake objective, with image pools
    of previously generated fakes.

    Everything is written under ``<opt.checkpoints>/<opt.name>``:
    losses are appended to ``records.txt`` every ``opt.log_freq`` iterations,
    sample grids go to ``images/`` every ``opt.vis_freq`` iterations, and all
    eight networks are saved every ``opt.ckp_freq`` epochs.
    """
    #### device
    on_gpu = opt.gpu_id >= 0
    device = torch.device('cuda:{}'.format(opt.gpu_id) if on_gpu else 'cpu')

    #### dataset
    data_loader = UnAlignedDataLoader()
    data_loader.initialize(opt)
    data_set = data_loader.load_data()
    print("The number of training images = %d." % len(data_set))

    #### models: the a -> Zb -> b path first, then the b -> Za -> a path
    E_a2Zb = Encoder(input_nc=opt.input_nc,
                     ngf=opt.ngf,
                     norm_type=opt.norm_type,
                     use_dropout=not opt.no_dropout,
                     n_blocks=9)
    G_Zb2b = Decoder(output_nc=opt.output_nc,
                     ngf=opt.ngf,
                     norm_type=opt.norm_type)
    T_Zb2Za = LatentTranslator(n_channels=256,
                               norm_type=opt.norm_type,
                               use_dropout=not opt.no_dropout)
    D_b = Discriminator(input_nc=opt.input_nc,
                        ndf=opt.ndf,
                        n_layers=opt.n_layers,
                        norm_type=opt.norm_type)
    E_b2Za = Encoder(input_nc=opt.input_nc,
                     ngf=opt.ngf,
                     norm_type=opt.norm_type,
                     use_dropout=not opt.no_dropout,
                     n_blocks=9)
    G_Za2a = Decoder(output_nc=opt.output_nc,
                     ngf=opt.ngf,
                     norm_type=opt.norm_type)
    T_Za2Zb = LatentTranslator(n_channels=256,
                               norm_type=opt.norm_type,
                               use_dropout=not opt.no_dropout)
    D_a = Discriminator(input_nc=opt.input_nc,
                        ndf=opt.ndf,
                        n_layers=opt.n_layers,
                        norm_type=opt.norm_type)

    # initialization, in declaration order
    E_a2Zb = init_net(E_a2Zb, init_type=opt.init_type).to(device)
    G_Zb2b = init_net(G_Zb2b, init_type=opt.init_type).to(device)
    T_Zb2Za = init_net(T_Zb2Za, init_type=opt.init_type).to(device)
    D_b = init_net(D_b, init_type=opt.init_type).to(device)
    E_b2Za = init_net(E_b2Za, init_type=opt.init_type).to(device)
    G_Za2a = init_net(G_Za2a, init_type=opt.init_type).to(device)
    T_Za2Zb = init_net(T_Za2Zb, init_type=opt.init_type).to(device)
    D_a = init_net(D_a, init_type=opt.init_type).to(device)
    print(
        "+------------------------------------------------------+\nFinish initializing networks."
    )

    #### criteria, optimizers and schedulers
    criterionGAN = GANLoss(opt.gan_mode).to(device)
    criterionZId = nn.L1Loss()
    criterionIdt = nn.L1Loss()
    criterionCTC = nn.L1Loss()
    criterionZCyc = nn.L1Loss()

    optimizer_G = torch.optim.Adam(itertools.chain(E_a2Zb.parameters(),
                                                   G_Zb2b.parameters(),
                                                   T_Zb2Za.parameters(),
                                                   E_b2Za.parameters(),
                                                   G_Za2a.parameters(),
                                                   T_Za2Zb.parameters()),
                                   lr=opt.lr,
                                   betas=(opt.beta1, opt.beta2))
    optimizer_D = torch.optim.Adam(itertools.chain(D_a.parameters(),
                                                   D_b.parameters()),
                                   lr=opt.lr,
                                   betas=(opt.beta1, opt.beta2))
    scheduler = [
        get_scheduler(optimizer_G, opt),
        get_scheduler(optimizer_D, opt)
    ]
    print(
        "+------------------------------------------------------+\nFinish initializing the optimizers and criterions."
    )

    #### output locations and bookkeeping
    checkpoints_pth = os.path.join(opt.checkpoints, opt.name)
    # creates missing parents too, and images/ even if the run dir exists
    os.makedirs(os.path.join(checkpoints_pth, 'images'), exist_ok=True)
    loss_names = [
        'GAN_A', 'Adv_A', 'Idt_A', 'CTC_A', 'ZId_A', 'ZCyc_A', 'GAN_B', 'Adv_B',
        'Idt_B', 'CTC_B', 'ZId_B', 'ZCyc_B'
    ]
    # image buffers that store previously generated images
    fake_A_pool = ImagePool(opt.pool_size)
    fake_B_pool = ImagePool(opt.pool_size)
    print(
        "+------------------------------------------------------+\nFinish preparing the other works."
    )
    print(
        "+------------------------------------------------------+\nNow training is beginning .."
    )

    #### training
    # last epoch index run by the loop below (used for progress output)
    total_epochs = opt.niter + opt.niter_decay
    cur_iter = 0
    with open(os.path.join(checkpoints_pth, 'records.txt'), 'w',
              encoding='utf-8') as record_fh:
        for epoch in range(opt.epoch_count, total_epochs + 1):
            for i, data in enumerate(data_set):
                ## setup inputs
                real_A = data['A'].to(device)
                real_B = data['B'].to(device)

                ## forward
                # translation: a -> Zb -> b', b -> Za -> a'
                latent_B = E_a2Zb(real_A)
                fake_B = G_Zb2b(latent_B)
                latent_A = E_b2Za(real_B)
                fake_A = G_Za2a(latent_A)

                # identity: a -> Za -> idt_a, b -> Zb -> idt_b
                idt_latent_A = E_b2Za(real_A)
                idt_A = G_Za2a(idt_latent_A)
                idt_latent_B = E_a2Zb(real_B)
                idt_B = G_Zb2b(idt_latent_B)

                # latent cross-identity: a -> Zb -> Za'' -> a'', b -> Za -> Zb'' -> b''
                T_latent_A = T_Zb2Za(latent_B)
                T_rec_A = G_Za2a(T_latent_A)
                T_latent_B = T_Za2Zb(latent_A)
                T_rec_B = G_Zb2b(T_latent_B)

                # cross-translation consistency: T_a2b(E_b2a(a)), T_b2a(E_a2b(b))
                T_idt_latent_B = T_Za2Zb(idt_latent_A)
                T_idt_latent_A = T_Zb2Za(idt_latent_B)

                # latent cycle: T_a2b(T_b2a(E_a2b(a))), T_b2a(T_a2b(E_b2a(b)))
                TT_latent_B = T_Za2Zb(T_latent_A)
                TT_latent_A = T_Zb2Za(T_latent_B)

                ## generator update (discriminators frozen)
                set_requires_grad([D_b, D_a], False)
                optimizer_G.zero_grad()

                loss_G_A = criterionGAN(D_b(fake_B), True)
                loss_G_B = criterionGAN(D_a(fake_A), True)
                loss_GAN = loss_G_A + loss_G_B

                loss_idt_A = criterionIdt(idt_A, real_A)
                loss_idt_B = criterionIdt(idt_B, real_B)
                loss_Idt = loss_idt_A + loss_idt_B

                loss_Zid_A = criterionZId(T_rec_A, real_A)
                loss_Zid_B = criterionZId(T_rec_B, real_B)
                loss_Zid = loss_Zid_A + loss_Zid_B

                loss_CTC_A = criterionCTC(T_idt_latent_A, latent_A)
                loss_CTC_B = criterionCTC(T_idt_latent_B, latent_B)
                loss_CTC = loss_CTC_B + loss_CTC_A

                loss_ZCyc_A = criterionZCyc(TT_latent_A, latent_A)
                loss_ZCyc_B = criterionZCyc(TT_latent_B, latent_B)
                loss_ZCyc = loss_ZCyc_B + loss_ZCyc_A

                loss_G = (opt.lambda_gan * loss_GAN + opt.lambda_idt * loss_Idt +
                          opt.lambda_zid * loss_Zid + opt.lambda_ctc * loss_CTC +
                          opt.lambda_zcyc * loss_ZCyc)
                loss_G.backward()
                optimizer_G.step()

                ## discriminator update
                set_requires_grad([D_b, D_a], True)
                optimizer_D.zero_grad()

                # loss_G.backward() above freed the generator graph, so the
                # fakes are detached before reaching the discriminators
                fake_B_ = fake_B_pool.query(fake_B.detach())
                loss_D_real_B = criterionGAN(D_b(real_B), True)
                loss_D_fake_B = criterionGAN(D_b(fake_B_), False)
                loss_D_B = (loss_D_real_B + loss_D_fake_B) * 0.5
                loss_D_B.backward()

                fake_A_ = fake_A_pool.query(fake_A.detach())
                loss_D_real_A = criterionGAN(D_a(real_A), True)
                loss_D_fake_A = criterionGAN(D_a(fake_A_), False)
                loss_D_A = (loss_D_real_A + loss_D_fake_A) * 0.5
                loss_D_A.backward()

                optimizer_D.step()

                ## record the losses (same order as loss_names)
                if cur_iter % opt.log_freq == 0:
                    losses = [
                        loss_G_A.item(), loss_D_A.item(), loss_idt_A.item(),
                        loss_CTC_A.item(), loss_Zid_A.item(), loss_ZCyc_A.item(),
                        loss_G_B.item(), loss_D_B.item(), loss_idt_B.item(),
                        loss_CTC_B.item(), loss_Zid_B.item(), loss_ZCyc_B.item()
                    ]
                    record_fh.write(
                        ' '.join('{}'.format(loss) for loss in losses) + '\n')
                    print('Epoch: %3d/%3d Iter: %9d--------------------------+' %
                          (epoch, total_epochs, i))
                    _train_print_loss_tables(loss_names, losses)

                ## visualize
                if cur_iter % opt.vis_freq == 0:
                    _train_save_visualization(
                        [('real_A', real_A), ('fake_B', fake_B),
                         ('idt_A', idt_A), ('L_idt_A', T_rec_A),
                         ('real_B', real_B), ('fake_A', fake_A),
                         ('idt_B', idt_B), ('L_idt_B', T_rec_B)],
                        os.path.join(checkpoints_pth, 'images',
                                     '%03d_%09d.jpg' % (epoch, i)))

                cur_iter += 1

            ## one epoch finished: update the learning rate
            update_learning_rate(schedulers=scheduler,
                                 opt=opt,
                                 optimizer=optimizer_D)

            ## save the models
            if epoch % opt.ckp_freq == 0:
                _train_save_networks(
                    {
                        'E_a2b': E_a2Zb,
                        'G_b': G_Zb2b,
                        'T_b2a': T_Zb2Za,
                        'D_b': D_b,
                        'E_b2a': E_b2Za,
                        'G_a': G_Za2a,
                        'T_a2b': T_Za2Zb,
                        'D_a': D_a,
                    }, checkpoints_pth, epoch, device, on_gpu)
                print("+Successfully saving models in epoch: %3d.-------------+" %
                      epoch)

    print("≧◔◡◔≦ Congratulation! Finishing the training!")
def _is_selected_dataset(dsname, wanted=('theta=-1',)):
    """Return True when the third ``_``-separated field of *dsname* is in *wanted*.

    *wanted* is a collection of whole field values, so membership is exact
    rather than a substring test. Names with fewer than three fields are
    never selected.
    """
    fields = dsname.split('_')
    return len(fields) > 2 and fields[2] in wanted


def training_procedure(FLAGS):
    """Train the encoder/decoder on each selected double-multivariate-normal dataset.

    Every entry of ``./data`` accepted by ``_is_selected_dataset`` is loaded
    with ``DoubleMulNormal`` and trained for epochs ``FLAGS.start_epoch`` to
    ``FLAGS.end_epoch - 1``. The same encoder, decoder and optimiser carry
    over from one dataset to the next.

    Each iteration back-propagates three terms: the style KL, the grouped
    class KL and the MSE reconstruction error. Every 50 iterations the losses
    are printed, logged and sent to TensorBoard. Checkpoints
    ``encoder_<dataset>`` / ``decoder_<dataset>`` are written every 50
    iterations of epoch 0, every 10 epochs and after the last epoch.

    Args:
        FLAGS: parsed command-line options (model sizes, optimiser settings,
            loss coefficients, epoch range, file names, ``load_saved``).
    """
    # model definition
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)
    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))

    # flat 784-value input buffer that every mini-batch is copied into
    X = torch.FloatTensor(FLAGS.batch_size, 784)

    # run on GPU if GPU is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # The group helpers must allocate on the same device as X and the models,
    # so their cuda flag follows the device actually chosen, not FLAGS.cuda.
    use_cuda = device.type == 'cuda'
    encoder.to(device=device)
    decoder.to(device=device)
    X = X.to(device=device)

    # one optimiser updates encoder and decoder jointly
    auto_encoder_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2))

    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write(
                'Epoch\tIteration\tReconstruction_loss\tStyle_KL_divergence_loss\tClass_KL_divergence_loss\n'
            )

    # KL terms are divided by the number of input values in a mini-batch
    kl_normalizer = (FLAGS.batch_size * FLAGS.num_channels *
                     FLAGS.image_size * FLAGS.image_size)

    # load data sets and create data loader instances
    dirs = os.listdir(os.path.join(os.getcwd(), 'data'))
    print('Loading double multivariate normal time series data...')
    for dsname in dirs:
        if not _is_selected_dataset(dsname):
            continue

        print('Running dataset ', dsname)
        ds = DoubleMulNormal(dsname)
        loader = cycle(
            DataLoader(ds,
                       batch_size=FLAGS.batch_size,
                       shuffle=True,
                       drop_last=True))
        iterations_per_epoch = int(len(ds) / FLAGS.batch_size)
        encoder_path = os.path.join('checkpoints', 'encoder_' + dsname)
        decoder_path = os.path.join('checkpoints', 'decoder_' + dsname)

        writer = SummaryWriter()
        try:
            for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
                print()
                print('Epoch #' + str(epoch) +
                      '........................................................')

                # Epoch total as a plain float: accumulating the loss tensors
                # would keep every iteration's autograd graph alive.
                total_loss = 0.0
                for iteration in range(iterations_per_epoch):
                    # load a mini-batch
                    image_batch, labels_batch = next(loader)

                    auto_encoder_optimizer.zero_grad()
                    X.copy_(image_batch)

                    style_mu, style_logvar, class_mu, class_logvar = encoder(X)
                    # NOTE(review): class statistics are passed detached
                    # (`.data`); confirm the class branch still gets gradients.
                    grouped_mu, grouped_logvar = accumulate_group_evidence(
                        class_mu.data, class_logvar.data, labels_batch,
                        use_cuda)

                    # kl-divergence error for style latent space
                    style_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                        -0.5 * torch.sum(1 + style_logvar - style_mu.pow(2) -
                                         style_logvar.exp()))
                    style_kl_divergence_loss /= kl_normalizer
                    style_kl_divergence_loss.backward(retain_graph=True)

                    # kl-divergence error for class latent space
                    class_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                        -0.5 * torch.sum(1 + grouped_logvar -
                                         grouped_mu.pow(2) -
                                         grouped_logvar.exp()))
                    class_kl_divergence_loss /= kl_normalizer
                    class_kl_divergence_loss.backward(retain_graph=True)

                    # Reconstruct samples. One class embedding per group:
                    # sampling each image independently would make the
                    # decoder treat the class latent as noise and ignore it.
                    style_latent_embeddings = reparameterize(
                        training=True, mu=style_mu, logvar=style_logvar)
                    class_latent_embeddings = group_wise_reparameterize(
                        training=True,
                        mu=grouped_mu,
                        logvar=grouped_logvar,
                        labels_batch=labels_batch,
                        cuda=use_cuda)

                    reconstructed_images = decoder(style_latent_embeddings,
                                                   class_latent_embeddings)
                    reconstruction_error = FLAGS.reconstruction_coef * mse_loss(
                        reconstructed_images, X)
                    reconstruction_error.backward()

                    reconstruction_value = reconstruction_error.item()
                    style_kl_value = style_kl_divergence_loss.item()
                    class_kl_value = class_kl_divergence_loss.item()
                    total_loss += (style_kl_value + class_kl_value +
                                   reconstruction_value)

                    auto_encoder_optimizer.step()

                    if (iteration + 1) % 50 == 0:
                        print('\tIteration #' + str(iteration))
                        print('Reconstruction loss: ' +
                              str(reconstruction_value))
                        print('Style KL loss: ' + str(style_kl_value))
                        print('Class KL loss: ' + str(class_kl_value))

                        # write to log
                        with open(FLAGS.log_file, 'a') as log:
                            log.write('{0}\t{1}\t{2}\t{3}\t{4}\n'.format(
                                epoch, iteration, reconstruction_value,
                                style_kl_value, class_kl_value))

                        # write to tensorboard
                        global_step = (epoch * (iterations_per_epoch + 1) +
                                       iteration)
                        writer.add_scalar('Reconstruction loss',
                                          reconstruction_value, global_step)
                        writer.add_scalar('Style KL-Divergence loss',
                                          style_kl_value, global_step)
                        writer.add_scalar('Class KL-Divergence loss',
                                          class_kl_value, global_step)

                    # frequent early checkpoints during the first epoch
                    if epoch == 0 and (iteration + 1) % 50 == 0:
                        torch.save(encoder.state_dict(), encoder_path)
                        torch.save(decoder.state_dict(), decoder_path)

                # save checkpoints after every 10 epochs and after the last one
                if (epoch + 1) % 10 == 0 or (epoch + 1) == FLAGS.end_epoch:
                    torch.save(encoder.state_dict(), encoder_path)
                    torch.save(decoder.state_dict(), decoder_path)
                print('Total loss at current epoch: ', total_loss)
        finally:
            writer.close()
def training_procedure(FLAGS):
    """Train the encoder/decoder pair on MNIST by maximising the ELBO.

    A fixed, seeded set of 10000 MNIST training images is held out for
    validation. Each training mini-batch goes through ``process``, which
    returns the ELBO and its three terms; ``-elbo`` is minimised with Adam.

    Epoch averages of all four quantities are printed. Every 5 epochs (and
    after the last one) three more things happen:

    * the model is evaluated on the held-out images with ``eval``;
    * checkpoints are written to ``checkpoints_<batch_size>/``;
    * the validation history tensor is saved.

    Args:
        FLAGS: parsed command-line options (model sizes, optimiser settings,
            epoch range, checkpoint/log file names, ``cuda`` and
            ``load_saved``).
    """
    # model definition
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)
    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # checkpoint directory; must be defined before the optional restore below
    savedir = 'checkpoints_%d' % (FLAGS.batch_size)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join(savedir, FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join(savedir, FLAGS.decoder_save)))

    # add option to run on GPU
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()

    # one optimiser updates encoder and decoder jointly
    auto_encoder_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2))

    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    if not os.path.exists(savedir):
        os.makedirs(savedir)

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write(
                'Epoch\tIteration\tReconstruction_loss\tStyle_KL_divergence_loss\tClass_KL_divergence_loss\n'
            )

    # load data set
    print('Loading MNIST dataset...')
    mnist = datasets.MNIST(root='mnist',
                           download=True,
                           train=True,
                           transform=transform_config)

    # The first `split` indices of a seeded shuffle are the validation set.
    # The same images are therefore held out on every run, including
    # resumed ones.
    dataset_size = len(mnist)
    split = 10000
    indices = list(range(dataset_size))
    np.random.seed(0)
    np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]

    train_sampler = SubsetRandomSampler(train_indices)
    valid_sampler = SubsetRandomSampler(val_indices)

    kwargs = {'num_workers': 1, 'pin_memory': True} if FLAGS.cuda else {}
    loader = DataLoader(mnist,
                        batch_size=FLAGS.batch_size,
                        sampler=train_sampler,
                        **kwargs)
    valid_loader = DataLoader(mnist,
                              batch_size=FLAGS.batch_size,
                              sampler=valid_sampler,
                              **kwargs)

    # one row per epoch of this run: [elbo, rec. proba, KL style, KL content]
    monitor = torch.zeros(FLAGS.end_epoch - FLAGS.start_epoch, 4)

    writer = SummaryWriter()
    try:
        for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
            print('')
            print(
                'Epoch #' + str(epoch) +
                '..........................................................................'
            )

            # Accumulate plain floats: summing the returned tensors would keep
            # every mini-batch's autograd graph alive for the whole epoch.
            elbo_epoch = 0.0
            term1_epoch = 0.0
            term2_epoch = 0.0
            term3_epoch = 0.0
            num_batches = 0
            for image_batch, labels_batch in loader:
                auto_encoder_optimizer.zero_grad()

                # place the batch on the same device as the models
                X = image_batch.cuda() if FLAGS.cuda else image_batch
                X = X.detach().clone()
                elbo, reconstruction_proba, style_kl_divergence_loss, class_kl_divergence_loss = process(
                    FLAGS, X, labels_batch, encoder, decoder)

                (-elbo).backward()
                auto_encoder_optimizer.step()

                elbo_epoch += float(elbo)
                term1_epoch += float(reconstruction_proba)
                term2_epoch += float(style_kl_divergence_loss)
                term3_epoch += float(class_kl_divergence_loss)
                num_batches += 1

            # guard against an empty loader
            num_batches = max(num_batches, 1)
            print("Elbo epoch %.2f" % (elbo_epoch / num_batches))
            print("Rec. Proba %.2f" % (term1_epoch / num_batches))
            print("KL style %.2f" % (term2_epoch / num_batches))
            print("KL content %.2f" % (term3_epoch / num_batches))

            # validate and save checkpoints every 5 epochs and after the last one
            if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
                # monitor rows are relative to the first epoch of this run
                row = epoch - FLAGS.start_epoch
                monitor[row, :] = eval(FLAGS, valid_loader, encoder, decoder)
                torch.save(
                    encoder.state_dict(),
                    os.path.join(savedir, FLAGS.encoder_save + '_e%d' % epoch))
                torch.save(
                    decoder.state_dict(),
                    os.path.join(savedir, FLAGS.decoder_save + '_e%d' % epoch))
                print("VAL elbo %.2f" % (monitor[row, 0]))
                print("VAL Rec. Proba %.2f" % (monitor[row, 1]))
                print("VAL KL style %.2f" % (monitor[row, 2]))
                print("VAL KL content %.2f" % (monitor[row, 3]))
                torch.save(monitor, os.path.join(savedir, 'monitor_e%d' % epoch))
    finally:
        writer.close()
range=(-1., 1.)) it += 1 # Sample images if (ep + 1) % args.sample_epochs == 0: E.eval() G.eval() with torch.no_grad(): mu, logvar = E(fixed_reals) latents = sample_latent(mu, logvar) samples = G(latents, fixed_annos_onehot) vutils.save_image(samples, join(sample_path, '{:03d}_fake.jpg'.format(ep)), nrow=4, padding=0, normalize=True, range=(-1., 1.)) # Checkpoints if (ep + 1) % args.save_epochs == 0: torch.save(E.state_dict(), join(checkpoint_path, '{:03}.E.pth'.format(ep))) torch.save(G.state_dict(), join(checkpoint_path, '{:03}.G.pth'.format(ep))) torch.save(D.state_dict(), join(checkpoint_path, '{:03}.D.pth'.format(ep))) torch.save(G_opt.state_dict(), join(checkpoint_path, '{:03}.G_opt.pth'.format(ep))) torch.save(D_opt.state_dict(), join(checkpoint_path, '{:03}.D_opt.pth'.format(ep)))
def training_procedure(FLAGS):
    """Train a style/class VAE jointly with a pair discriminator on paired MNIST.

    Every iteration runs three phases:

    A. Auto-encoder step: KL divergence of the style posterior of X_1, plus the
       reconstruction of X_1 both from (style_1, class_1) and from
       (style_1, class_2). X_2 is the second image of the pair returned by
       ``MNIST_Paired`` (presumably sharing X_1's class -- confirm).
    B. ``FLAGS.generator_times`` generator steps. Each trains encoder/decoder so
       that the discriminator labels (X_3, decoder(style, class_3)) as "real",
       using the encoder style of X_1 and a style drawn from N(0, I).
    C. ``FLAGS.discriminator_times`` discriminator steps on real pairs versus
       generated pairs. The discriminator optimizer only steps while its batch
       accuracy is below ``FLAGS.discriminator_limiting_accuracy``.

    Every 50 iterations the losses are printed, appended to ``FLAGS.log_file``
    and written to TensorBoard. Every 5 epochs, and after the last epoch, the
    three networks are saved under ``checkpoints/``.

    Args:
        FLAGS: parsed options. Read here: style_dim, class_dim, load_saved,
            encoder_save, decoder_save, discriminator_save, batch_size,
            num_channels, image_size, cuda, initial_learning_rate, beta_1,
            beta_2, log_file, start_epoch, end_epoch, generator_times,
            discriminator_times, discriminator_limiting_accuracy.
    """
    # ---- model definition ----
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    discriminator = Discriminator()
    discriminator.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))
        discriminator.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.discriminator_save)))

    # ---- variable definition ----
    real_domain_labels = 1
    fake_domain_labels = 0

    def new_image_buffer():
        # Uninitialised (batch, channels, size, size) buffer; always filled via copy_ before use.
        return torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels,
                                 FLAGS.image_size, FLAGS.image_size)

    X_1 = new_image_buffer()
    X_2 = new_image_buffer()
    X_3 = new_image_buffer()

    domain_labels = torch.LongTensor(FLAGS.batch_size)
    style_latent_space = torch.FloatTensor(FLAGS.batch_size, FLAGS.style_dim)

    # Ground truth for the discriminator accuracy in phase C: the real batch is
    # concatenated before the fake batch. It is constant, so it is built once.
    target_true_labels = torch.cat(
        (torch.ones(FLAGS.batch_size), torch.zeros(FLAGS.batch_size)), dim=0).long()

    # ---- loss definitions ----
    cross_entropy_loss = nn.CrossEntropyLoss()

    # ---- option to run on GPU ----
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()
        discriminator.cuda()

        cross_entropy_loss.cuda()

        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()

        domain_labels = domain_labels.cuda()
        style_latent_space = style_latent_space.cuda()
        target_true_labels = target_true_labels.cuda()

    # ---- optimizer definition ----
    adam_kwargs = dict(lr=FLAGS.initial_learning_rate, betas=(FLAGS.beta_1, FLAGS.beta_2))
    auto_encoder_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()), **adam_kwargs)
    discriminator_optimizer = optim.Adam(list(discriminator.parameters()), **adam_kwargs)
    generator_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()), **adam_kwargs)

    # KL is summed over the batch, then divided by the pixel count of a whole batch.
    kl_normaliser = FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size

    def style_kl_divergence(mu, logvar):
        # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), normalised by kl_normaliser.
        return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / kl_normaliser

    # ---- training ----
    if torch.cuda.is_available() and not FLAGS.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    os.makedirs('checkpoints', exist_ok=True)

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write('Epoch\tIteration\tReconstruction_loss\tKL_divergence_loss\t')
            log.write('Generator_loss\tDiscriminator_loss\tDiscriminator_accuracy\n')

    # load data set and create data loader instance
    print('Loading MNIST paired dataset...')
    paired_mnist = MNIST_Paired(root='mnist', download=True, train=True, transform=transform_config)
    loader = cycle(DataLoader(paired_mnist, batch_size=FLAGS.batch_size, shuffle=True,
                              num_workers=0, drop_last=True))
    iterations_per_epoch = int(len(paired_mnist) / FLAGS.batch_size)

    # Detached values kept only for reporting. The generator/discriminator
    # values start as NaN, so a run with generator_times == 0 or
    # discriminator_times == 0 reports NaN instead of raising NameError.
    generator_error = torch.tensor(float('nan'))
    discriminator_error = torch.tensor(float('nan'))
    discriminator_accuracy = 0.

    # initialize summary writer
    writer = SummaryWriter()
    try:
        for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
            print('')
            print('Epoch #' + str(epoch) + '..........................................................................')

            for iteration in range(iterations_per_epoch):
                # A. run the auto-encoder reconstruction
                image_batch_1, image_batch_2, _ = next(loader)

                auto_encoder_optimizer.zero_grad()

                X_1.copy_(image_batch_1)
                X_2.copy_(image_batch_2)

                style_mu_1, style_logvar_1, class_1 = encoder(X_1)
                style_1 = reparameterize(training=True, mu=style_mu_1, logvar=style_logvar_1)

                kl_divergence_loss_1 = style_kl_divergence(style_mu_1, style_logvar_1)
                kl_divergence_loss_1.backward(retain_graph=True)

                _, _, class_2 = encoder(X_2)

                reconstructed_X_1 = decoder(style_1, class_1)
                reconstructed_X_2 = decoder(style_1, class_2)

                reconstruction_error_1 = mse_loss(reconstructed_X_1, X_1)
                reconstruction_error_1.backward(retain_graph=True)

                # The class-swapped reconstruction also targets X_1 (style comes from X_1).
                reconstruction_error_2 = mse_loss(reconstructed_X_2, X_1)
                reconstruction_error_2.backward()

                # Detached: these are only reported, so they must not pin autograd graphs.
                reconstruction_error = (reconstruction_error_1 + reconstruction_error_2).detach()
                kl_divergence_error = kl_divergence_loss_1.detach()

                auto_encoder_optimizer.step()

                # B. run the generator
                for _ in range(FLAGS.generator_times):
                    generator_optimizer.zero_grad()

                    image_batch_1, _, _ = next(loader)
                    image_batch_3, _, _ = next(loader)

                    domain_labels.fill_(real_domain_labels)
                    X_1.copy_(image_batch_1)
                    X_3.copy_(image_batch_3)

                    style_mu_1, style_logvar_1, _ = encoder(X_1)
                    style_1 = reparameterize(training=True, mu=style_mu_1, logvar=style_logvar_1)

                    kl_divergence_loss_1 = style_kl_divergence(style_mu_1, style_logvar_1)
                    kl_divergence_loss_1.backward(retain_graph=True)

                    _, _, class_3 = encoder(X_3)
                    reconstructed_X_1_3 = decoder(style_1, class_3)

                    output_1 = discriminator(X_3, reconstructed_X_1_3)
                    generator_error_1 = cross_entropy_loss(output_1, domain_labels)
                    generator_error_1.backward(retain_graph=True)

                    # Same class code, but the style is sampled from N(0, I).
                    style_latent_space.normal_(0., 1.)
                    reconstructed_X_latent_3 = decoder(style_latent_space, class_3)

                    output_2 = discriminator(X_3, reconstructed_X_latent_3)
                    generator_error_2 = cross_entropy_loss(output_2, domain_labels)
                    generator_error_2.backward()

                    generator_error = (generator_error_1 + generator_error_2).detach()
                    kl_divergence_error = kl_divergence_error + kl_divergence_loss_1.detach()

                    generator_optimizer.step()

                # C. run the discriminator
                for _ in range(FLAGS.discriminator_times):
                    discriminator_optimizer.zero_grad()

                    # train discriminator on real data
                    domain_labels.fill_(real_domain_labels)

                    image_batch_1, _, _ = next(loader)
                    image_batch_2, image_batch_3, _ = next(loader)

                    X_1.copy_(image_batch_1)
                    X_2.copy_(image_batch_2)
                    X_3.copy_(image_batch_3)

                    real_output = discriminator(X_2, X_3)

                    discriminator_real_error = cross_entropy_loss(real_output, domain_labels)
                    discriminator_real_error.backward()

                    # train discriminator on fake data
                    domain_labels.fill_(fake_domain_labels)

                    style_mu_1, style_logvar_1, _ = encoder(X_1)
                    style_1 = reparameterize(training=False, mu=style_mu_1, logvar=style_logvar_1)

                    _, _, class_3 = encoder(X_3)
                    reconstructed_X_1_3 = decoder(style_1, class_3)

                    fake_output = discriminator(X_3, reconstructed_X_1_3)

                    discriminator_fake_error = cross_entropy_loss(fake_output, domain_labels)
                    discriminator_fake_error.backward()

                    # total discriminator error
                    discriminator_error = (discriminator_real_error + discriminator_fake_error).detach()

                    # calculate discriminator accuracy for this step
                    _, discriminator_predictions = torch.max(
                        torch.cat((real_output, fake_output), dim=0), 1)
                    discriminator_accuracy = (
                        (discriminator_predictions == target_true_labels).sum().item()
                        / (FLAGS.batch_size * 2))

                    # Hold the discriminator back once it is already accurate enough.
                    if discriminator_accuracy < FLAGS.discriminator_limiting_accuracy:
                        discriminator_optimizer.step()

                if (iteration + 1) % 50 == 0:
                    reconstruction_value = reconstruction_error.item()
                    kl_divergence_value = kl_divergence_error.item()
                    generator_value = generator_error.item()
                    discriminator_value = discriminator_error.item()
                    global_step = epoch * (iterations_per_epoch + 1) + iteration

                    print('')
                    print('Epoch #' + str(epoch))
                    print('Iteration #' + str(iteration))

                    print('')
                    print('Reconstruction loss: ' + str(reconstruction_value))
                    print('KL-Divergence loss: ' + str(kl_divergence_value))

                    print('')
                    print('Generator loss: ' + str(generator_value))
                    print('Discriminator loss: ' + str(discriminator_value))
                    print('Discriminator accuracy: ' + str(discriminator_accuracy))

                    print('..........')

                    # write to log
                    with open(FLAGS.log_file, 'a') as log:
                        log.write('{0}\t{1}\t{2}\t{3}\t{4}\t{5}\t{6}\n'.format(
                            epoch,
                            iteration,
                            reconstruction_value,
                            kl_divergence_value,
                            generator_value,
                            discriminator_value,
                            discriminator_accuracy
                        ))

                    # write to tensorboard
                    writer.add_scalar('Reconstruction loss', reconstruction_value, global_step)
                    writer.add_scalar('KL-Divergence loss', kl_divergence_value, global_step)
                    writer.add_scalar('Generator loss', generator_value, global_step)
                    writer.add_scalar('Discriminator loss', discriminator_value, global_step)
                    writer.add_scalar('Discriminator accuracy', discriminator_accuracy * 100,
                                      global_step)

            # save model after every 5 epochs
            if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
                torch.save(encoder.state_dict(), os.path.join('checkpoints', FLAGS.encoder_save))
                torch.save(decoder.state_dict(), os.path.join('checkpoints', FLAGS.decoder_save))
                torch.save(discriminator.state_dict(),
                           os.path.join('checkpoints', FLAGS.discriminator_save))
    finally:
        # Flush pending TensorBoard events even if training stops early.
        writer.close()
def training_procedure(FLAGS):
    """Adversarially separate a varying (nv) and a common (nc) factor on paired MNIST.

    Every iteration runs:

    A. Auto-encoder step: X_1 is reconstructed from (nv_1, nc_2) and X_2 from
       (nv_2, nc_1), i.e. the common codes of the pair are swapped. The
       optimizer only steps when ``FLAGS.train_auto_encoder`` is set.
    B.a ``FLAGS.discriminator_times`` discriminator steps: real pairs
       (X_1, X_2) versus fake pairs (X_1, decoder(nv_3, nc_1)). The loss is
       scaled by ``FLAGS.disc_coef``. The optimizer steps only while accuracy is
       below ``FLAGS.discriminator_limiting_accuracy`` and
       ``FLAGS.train_discriminator`` is set.
    B.b ``FLAGS.generator_times`` generator steps: encoder/decoder are trained
       so that (X_1, decoder(nv_3, nc_1)) is labelled "real". The loss is scaled
       by ``FLAGS.gen_coef``. The optimizer steps only when
       ``FLAGS.train_generator`` is set.

    Every 10 iterations the losses are printed, appended to ``FLAGS.log_file``
    and written to TensorBoard. Every 5 epochs, and after the last epoch, the
    networks are saved under ``checkpoints/`` and image grids are written via
    ``imshow_grid``.

    Args:
        FLAGS: parsed options (latent sizes nv_dim/nc_dim, batch/image sizes,
            optimizer hyper-parameters, epoch range, step counts, loss
            coefficients, train_* switches, file names, cuda, load_saved).
    """
    # ---- model definition ----
    encoder = Encoder(nv_dim=FLAGS.nv_dim, nc_dim=FLAGS.nc_dim)
    encoder.apply(weights_init)

    decoder = Decoder(nv_dim=FLAGS.nv_dim, nc_dim=FLAGS.nc_dim)
    decoder.apply(weights_init)

    discriminator = Discriminator()
    discriminator.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))
        discriminator.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.discriminator_save)))

    # ---- variable definition ----
    real_domain_labels = 1
    fake_domain_labels = 0

    def new_image_buffer():
        # Uninitialised (batch, channels, size, size) buffer; always filled via copy_ before use.
        return torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels,
                                 FLAGS.image_size, FLAGS.image_size)

    X_1 = new_image_buffer()
    X_2 = new_image_buffer()
    X_3 = new_image_buffer()

    domain_labels = torch.LongTensor(FLAGS.batch_size)

    # Ground truth for the discriminator accuracy: the real batch is
    # concatenated before the fake batch. It is constant, so it is built once.
    target_true_labels = torch.cat(
        (torch.ones(FLAGS.batch_size), torch.zeros(FLAGS.batch_size)), dim=0).long()

    # ---- loss definitions ----
    cross_entropy_loss = nn.CrossEntropyLoss()

    # ---- option to run on GPU ----
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()
        discriminator.cuda()

        cross_entropy_loss.cuda()

        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()

        domain_labels = domain_labels.cuda()
        target_true_labels = target_true_labels.cuda()

    # ---- optimizer definition ----
    adam_kwargs = dict(lr=FLAGS.initial_learning_rate, betas=(FLAGS.beta_1, FLAGS.beta_2))
    auto_encoder_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()), **adam_kwargs)
    discriminator_optimizer = optim.Adam(list(discriminator.parameters()), **adam_kwargs)
    generator_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()), **adam_kwargs)

    # ---- training ----
    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    os.makedirs('checkpoints', exist_ok=True)
    os.makedirs('reconstructed_images', exist_ok=True)

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write('Epoch\tIteration\tReconstruction_loss\t')
            log.write(
                'Generator_loss\tDiscriminator_loss\tDiscriminator_accuracy\n')

    # load data set and create data loader instance
    print('Loading MNIST paired dataset...')
    paired_mnist = MNIST_Paired(root='mnist', download=True, train=True, transform=transform_config)
    loader = cycle(
        DataLoader(paired_mnist, batch_size=FLAGS.batch_size, shuffle=True,
                   num_workers=0, drop_last=True))
    iterations_per_epoch = int(len(paired_mnist) / FLAGS.batch_size)

    def save_grid(batch, name):
        # Save an (N, C, H, W) tensor as an image grid. The batch is moved to
        # NHWC and the channel axis is repeated three times (grayscale -> RGB when C == 1).
        images = np.transpose(batch.detach().cpu().numpy(), (0, 2, 3, 1))
        images = np.concatenate((images, images, images), axis=3)
        imshow_grid(images, name=name, save=True)

    # Detached values kept only for reporting. The generator/discriminator
    # values start as NaN, so a run with generator_times == 0 or
    # discriminator_times == 0 reports NaN instead of raising NameError.
    generator_error = torch.tensor(float('nan'))
    discriminator_error = torch.tensor(float('nan'))
    discriminator_accuracy = 0.

    # initialize summary writer
    writer = SummaryWriter()
    try:
        for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
            print('')
            print(
                'Epoch #' + str(epoch) +
                '..........................................................................'
            )

            for iteration in range(iterations_per_epoch):
                # A. run the auto-encoder reconstruction
                image_batch_1, image_batch_2, _ = next(loader)

                auto_encoder_optimizer.zero_grad()

                X_1.copy_(image_batch_1)
                X_2.copy_(image_batch_2)

                nv_1, nc_1 = encoder(X_1)
                nv_2, nc_2 = encoder(X_2)

                # swap the common factor between the two images of the pair
                reconstructed_X_1 = decoder(nv_1, nc_2)
                reconstructed_X_2 = decoder(nv_2, nc_1)

                reconstruction_error_1 = mse_loss(reconstructed_X_1, X_1)
                reconstruction_error_1.backward(retain_graph=True)

                reconstruction_error_2 = mse_loss(reconstructed_X_2, X_2)
                reconstruction_error_2.backward()

                # Detached: only reported, so it must not pin the autograd graph.
                reconstruction_error = (reconstruction_error_1 + reconstruction_error_2).detach()

                if FLAGS.train_auto_encoder:
                    auto_encoder_optimizer.step()

                # B. run the adversarial part of the architecture

                # B. a) run the discriminator
                for _ in range(FLAGS.discriminator_times):
                    discriminator_optimizer.zero_grad()

                    # train discriminator on real data
                    domain_labels.fill_(real_domain_labels)

                    image_batch_1, image_batch_2, _ = next(loader)

                    X_1.copy_(image_batch_1)
                    X_2.copy_(image_batch_2)

                    real_output = discriminator(X_1, X_2)

                    discriminator_real_error = FLAGS.disc_coef * cross_entropy_loss(
                        real_output, domain_labels)
                    discriminator_real_error.backward()

                    # train discriminator on fake data
                    domain_labels.fill_(fake_domain_labels)

                    image_batch_3, _, _ = next(loader)
                    X_3.copy_(image_batch_3)

                    nv_3, _ = encoder(X_3)

                    # reconstruction is taking common factor from X_1 and varying factor from X_3
                    reconstructed_X_3_1 = decoder(nv_3, encoder(X_1)[1])

                    fake_output = discriminator(X_1, reconstructed_X_3_1)

                    discriminator_fake_error = FLAGS.disc_coef * cross_entropy_loss(
                        fake_output, domain_labels)
                    discriminator_fake_error.backward()

                    # total discriminator error
                    discriminator_error = (discriminator_real_error + discriminator_fake_error).detach()

                    # calculate discriminator accuracy for this step
                    _, discriminator_predictions = torch.max(
                        torch.cat((real_output, fake_output), dim=0), 1)
                    discriminator_accuracy = (
                        (discriminator_predictions == target_true_labels).sum().item()
                        / (FLAGS.batch_size * 2))

                    # Hold the discriminator back once it is already accurate enough.
                    if (discriminator_accuracy < FLAGS.discriminator_limiting_accuracy
                            and FLAGS.train_discriminator):
                        discriminator_optimizer.step()

                # B. b) run the generator
                for _ in range(FLAGS.generator_times):
                    generator_optimizer.zero_grad()

                    image_batch_1, _, _ = next(loader)
                    image_batch_3, _, _ = next(loader)

                    domain_labels.fill_(real_domain_labels)
                    X_1.copy_(image_batch_1)
                    X_3.copy_(image_batch_3)

                    nv_3, _ = encoder(X_3)

                    # reconstruction is taking common factor from X_1 and varying factor from X_3
                    reconstructed_X_3_1 = decoder(nv_3, encoder(X_1)[1])

                    output = discriminator(X_1, reconstructed_X_3_1)

                    generator_loss = FLAGS.gen_coef * cross_entropy_loss(output, domain_labels)
                    generator_loss.backward()
                    generator_error = generator_loss.detach()

                    if FLAGS.train_generator:
                        generator_optimizer.step()

                # print progress every 10 iterations
                if (iteration + 1) % 10 == 0:
                    reconstruction_value = reconstruction_error.item()
                    generator_value = generator_error.item()
                    discriminator_value = discriminator_error.item()
                    global_step = epoch * (iterations_per_epoch + 1) + iteration

                    print('')
                    print('Epoch #' + str(epoch))
                    print('Iteration #' + str(iteration))

                    print('')
                    print('Reconstruction loss: ' + str(reconstruction_value))
                    print('Generator loss: ' + str(generator_value))

                    print('')
                    print('Discriminator loss: ' + str(discriminator_value))
                    print('Discriminator accuracy: ' + str(discriminator_accuracy))

                    print('..........')

                    # write to log
                    with open(FLAGS.log_file, 'a') as log:
                        log.write('{0}\t{1}\t{2}\t{3}\t{4}\t{5}\n'.format(
                            epoch, iteration, reconstruction_value, generator_value,
                            discriminator_value, discriminator_accuracy))

                    # write to tensorboard
                    writer.add_scalar('Reconstruction loss', reconstruction_value, global_step)
                    writer.add_scalar('Generator loss', generator_value, global_step)
                    writer.add_scalar('Discriminator loss', discriminator_value, global_step)

            # save model after every 5 epochs
            if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
                torch.save(encoder.state_dict(), os.path.join('checkpoints', FLAGS.encoder_save))
                torch.save(decoder.state_dict(), os.path.join('checkpoints', FLAGS.decoder_save))
                torch.save(discriminator.state_dict(),
                           os.path.join('checkpoints', FLAGS.discriminator_save))

                # Save reconstructed images and style-swapped generations to check progress.
                # no_grad: these forward passes are for visualisation only.
                with torch.no_grad():
                    image_batch_1, image_batch_2, _ = next(loader)
                    image_batch_3, _, _ = next(loader)

                    X_1.copy_(image_batch_1)
                    X_2.copy_(image_batch_2)
                    X_3.copy_(image_batch_3)

                    nv_1, _ = encoder(X_1)
                    _, nc_2 = encoder(X_2)
                    nv_3, _ = encoder(X_3)

                    reconstructed_X_1 = decoder(nv_1, nc_2)
                    reconstructed_X_3_2 = decoder(nv_3, nc_2)

                # input batch, its reconstruction, the style source, and the cross reconstruction
                save_grid(X_1, str(epoch) + '_original')
                save_grid(reconstructed_X_1, str(epoch) + '_target')
                save_grid(X_3, str(epoch) + '_style')
                save_grid(reconstructed_X_3_2, str(epoch) + '_style_target')
    finally:
        # Flush pending TensorBoard events even if training stops early.
        writer.close()
class DoubleDQN:
    """Epsilon-greedy DQN agent with an online network and a target network.

    The Q-network scores ``(object, direction)`` action pairs. ``get_action``
    decodes the flat argmax of that score grid back into the pair.

    NOTE(review): despite the class name, ``compute_loss`` bootstraps from the
    target network's own maximum (target-network DQN). It does not use the
    Double-DQN "online argmax, target evaluation" split -- confirm this is
    intended.

    Relies on module-level ``args``, ``DEVICE``, ``MODEL_FILE``,
    ``UPDATE_STEPS``, ``F``, ``OneHot``, ``Encoder`` and ``DQN``.
    """

    def __init__(self, env, tau=0.1, gamma=0.9, epsilon=1.0):
        """Build the networks and optimizer, resuming from MODEL_FILE if it exists.

        Args:
            env: environment providing ``get_obs``, ``sample_random_action``,
                ``all_questions`` and ``held_out_questions``.
            tau: Polyak coefficient used by ``update_target_net``.
            gamma: discount factor of the TD target.
            epsilon: initial exploration probability (replaced by the saved
                value when a checkpoint is loaded).
        """
        self.env = env
        self.tau = tau
        self.gamma = gamma
        self.embedding_size = 30
        self.hidden_size = 30
        self.obs_shape = self.env.get_obs().shape
        # Size of the inner (direction) axis of the Q-value grid; get_action
        # decodes flat argmax indices with it.
        self.action_shape = 40 // 5

        if args.encoding == "onehot":
            self.encoder = OneHot(
                args.bins, self.env.all_questions + self.env.held_out_questions,
                self.hidden_size).to(DEVICE)
        else:
            self.encoder = Encoder(self.embedding_size, self.hidden_size).to(DEVICE)

        # Both networks receive the SAME encoder instance (presumably registered
        # as a submodule, so its parameters are shared -- confirm in DQN).
        self.model = DQN(self.obs_shape, self.action_shape, self.encoder).to(DEVICE)
        self.target_model = DQN(self.obs_shape, self.action_shape, self.encoder).to(DEVICE)
        self.optimizer = torch.optim.Adam(self.model.parameters())
        self.epsilon = epsilon

        if os.path.exists(MODEL_FILE):
            # Resume: restore online and target networks exactly as they were saved.
            checkpoint = torch.load(MODEL_FILE)
            self.encoder.load_state_dict(checkpoint['encoder_state_dict'])
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.target_model.load_state_dict(checkpoint['target_model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.epsilon = checkpoint['epsilon']
        else:
            # Fresh start: hard-copy online parameters into the target network.
            # The destination (target) must be the first element of each pair.
            with torch.no_grad():
                for target_param, param in zip(self.target_model.parameters(),
                                               self.model.parameters()):
                    target_param.copy_(param)

    def get_action(self, state, goal):
        """Return an epsilon-greedy ``(object_index, direction_index)`` pair of ints.

        With probability ``1 - epsilon`` the flat argmax of the online Q-values
        is split into (object, direction); otherwise a random environment
        action is used.
        """
        assert len(state.shape) == 2  # This function should not be called during update
        if np.random.rand() > self.epsilon:
            # Acting only: no autograd graph is needed.
            with torch.no_grad():
                q_values = self.model(state, goal)
            idx = int(torch.argmax(q_values))
            obj_selection, direction_selection = divmod(idx, self.action_shape)
        else:
            action = self.env.sample_random_action()
            obj_selection = action[0]
            direction_selection = action[1]
        return int(obj_selection), int(direction_selection)

    def compute_loss(self, batch):
        """Return the MSE between Q(s, a) of the taken actions and the TD target.

        The target is ``r + gamma * (1 - done) * max Q_target(s', .)``, with the
        max taken over both action axes. It is detached so that only the online
        network receives gradients.
        """
        states, actions, goals, rewards, next_states, _satisfied_goals, dones = batch
        rewards = torch.FloatTensor(rewards).to(DEVICE)
        dones = torch.FloatTensor(dones).to(DEVICE)

        curr_Q = self.model(states, goals)
        # Q-value of the (object, direction) action taken in each transition.
        # TODO: Use pytorch gather
        curr_Q_prev_actions = torch.stack([
            curr_Q[i, actions[i][0], actions[i][1]] for i in range(len(states))
        ])

        next_Q = self.target_model(next_states, goals)
        # maximum over the direction axis, then over the object axis
        next_Q_max_actions = torch.max(next_Q, -1).values
        next_Q_max_actions = torch.max(next_Q_max_actions, -1).values
        td_target = rewards + (1 - dones) * self.gamma * next_Q_max_actions

        return F.mse_loss(curr_Q_prev_actions, td_target.detach())

    def update(self, replay_buffer, batch_size):
        """Take UPDATE_STEPS optimizer steps, each on a freshly sampled minibatch."""
        for _ in range(UPDATE_STEPS):
            batch = replay_buffer.sample(batch_size)
            loss = self.compute_loss(batch)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def update_target_net(self):
        """Soft update: ``target <- tau * online + (1 - tau) * target`` per parameter.

        Parameters that are shared between the two networks (e.g. the common
        encoder, if registered in both) are averaged with themselves and so stay
        unchanged, up to float rounding.
        """
        with torch.no_grad():
            for target_param, param in zip(self.target_model.parameters(),
                                           self.model.parameters()):
                target_param.copy_(self.tau * param + (1 - self.tau) * target_param)

    def save_model(self):
        """Write both networks, the encoder, optimizer state and epsilon to MODEL_FILE."""
        torch.save(
            {
                'model_state_dict': self.model.state_dict(),
                'target_model_state_dict': self.target_model.state_dict(),
                'encoder_state_dict': self.encoder.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'epsilon': self.epsilon
            }, MODEL_FILE)