def main(FLAGS):
    """Project the encoder's class latent codes for the MNIST test set to 2-D.

    Builds the encoder/decoder (optionally restoring their weights from
    ``checkpoints/``), encodes every MNIST test image one at a time, groups
    the class latent codes by digit label in the order 0..9, runs t-SNE on
    them and saves the 2-D result to ``s_2d.npz`` under the key ``s_2d``.

    Args:
        FLAGS: parsed options; reads ``style_dim``, ``class_dim``,
            ``load_saved``, ``encoder_save`` and ``decoder_save``.
    """
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)
    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # Prefer the first GPU, but fall back to the CPU when CUDA is unavailable.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # load saved models if load_saved flag is true; map_location lets
    # GPU-saved checkpoints be restored on a CPU-only machine.
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save),
                       map_location=device))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save),
                       map_location=device))

    decoder.to(device)
    encoder.to(device)
    # Inference only: evaluation mode makes layers such as dropout or
    # batch-norm (if the encoder has any) use their inference behaviour.
    encoder.eval()

    tsne = TSNE(2)
    mnist = DataLoader(
        datasets.MNIST(root='mnist', download=True, train=False,
                       transform=transform_config))

    # digit label -> list of class latent codes, one entry per test image
    s_dict = {}
    with torch.no_grad():
        for i, (image, label) in enumerate(mnist):
            label = int(label)
            print(i, label)
            _, _, class_latent_space = encoder(image.to(device))
            s_dict.setdefault(label, []).append(class_latent_space)

    # Concatenate the codes grouped by label (all 0s, then all 1s, ...).
    s_all = []
    for label in range(10):
        s_all.extend(s_dict.get(label, []))
    s_all = torch.cat(s_all)
    s_all = s_all.view(s_all.shape[0], -1).cpu().numpy()
    s_2d = tsne.fit_transform(s_all)
    np.savez('s_2d.npz', s_2d=s_2d)
class LSGANs_Trainer(nn.Module):
    """Training wrapper around a shared encoder/decoder, two style interpolators and two critics.

    A single ``encoder`` maps an image to ``(content, style)`` codes and a
    single ``decoder`` maps them back. Cross-domain images are produced by
    decoding one image's content with a style mixed by ``interp_net_ab`` /
    ``interp_net_ba``; ``dis_a`` / ``dis_b`` score images of domain A / B.
    Each sub-network owns its own Adam optimizer and learning-rate scheduler.

    Args:
        hyperparameters: configuration dict; ``__init__``, ``dis_update`` and
            ``gen_update`` document the keys they read.
    """

    def __init__(self, hyperparameters):
        """Build the networks, optimizers and schedulers.

        Reads ``lr``, ``beta1``, ``beta2``, ``weight_decay``, ``input_dim_a``,
        ``gen`` (with ``style_dim``), ``init`` and, optionally, ``vgg_w`` /
        ``vgg_model_path``.
        """
        super(LSGANs_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.encoder = Encoder(hyperparameters['input_dim_a'],
                               hyperparameters['gen'])
        self.decoder = Decoder(hyperparameters['input_dim_a'],
                               hyperparameters['gen'])
        self.dis_a = Discriminator()
        self.dis_b = Discriminator()
        self.interp_net_ab = Interpolator()
        self.interp_net_ba = Interpolator()
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # Setup the optimizers: one Adam per sub-network over its trainable
        # parameters, all sharing lr / betas / weight decay.
        betas = (hyperparameters['beta1'], hyperparameters['beta2'])
        weight_decay = hyperparameters['weight_decay']

        def make_adam(module):
            return torch.optim.Adam(
                [p for p in module.parameters() if p.requires_grad],
                lr=lr, betas=betas, weight_decay=weight_decay)

        self.enc_opt = make_adam(self.encoder)
        self.dec_opt = make_adam(self.decoder)
        self.dis_a_opt = make_adam(self.dis_a)
        self.dis_b_opt = make_adam(self.dis_b)
        self.interp_ab_opt = make_adam(self.interp_net_ab)
        self.interp_ba_opt = make_adam(self.interp_net_ba)

        self.enc_scheduler = get_scheduler(self.enc_opt, hyperparameters)
        self.dec_scheduler = get_scheduler(self.dec_opt, hyperparameters)
        self.dis_a_scheduler = get_scheduler(self.dis_a_opt, hyperparameters)
        self.dis_b_scheduler = get_scheduler(self.dis_b_opt, hyperparameters)
        self.interp_ab_scheduler = get_scheduler(self.interp_ab_opt,
                                                 hyperparameters)
        self.interp_ba_scheduler = get_scheduler(self.interp_ba_opt,
                                                 hyperparameters)

        # Network weight initialization; the critics are then re-initialised
        # with the 'gaussian' scheme.
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed (frozen, evaluation mode)
        if hyperparameters.get('vgg_w', 0) > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

        self.total_loss = 0
        self.best_iter = 0
        # NOTE(review): gen_update computes perceptual terms with this module;
        # self.vgg above is loaded but not referenced elsewhere in this class.
        self.perceptural_loss = Perceptural_loss()

    def recon_criterion(self, input, target):
        """Mean absolute error (L1) between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        """Translate ``x_a`` to domain B and ``x_b`` to domain A.

        Uses the same interpolator/decoder wiring as ``dis_update`` and
        ``gen_update``. Leaves the module in training mode afterwards.

        Returns:
            ``(x_ab, x_ba)``.
        """
        self.eval()
        c_a, s_a = self.encoder(x_a)
        c_b, s_b = self.encoder(x_b)
        # Interpolation weights are built locally (matching this call's batch
        # and device), so forward() also works before any update step.
        v = torch.ones_like(s_a)
        # decode (cross domain)
        s_a_interp = self.interp_net_ba(s_b, s_a, v)
        s_b_interp = self.interp_net_ab(s_a, s_b, v)
        x_ba = self.decoder(c_b, s_a_interp)
        x_ab = self.decoder(c_a, s_b_interp)
        self.train()
        return x_ab, x_ba

    def zero_grad(self):
        """Clear gradients through every optimizer owned by the trainer."""
        self.dis_a_opt.zero_grad()
        self.dis_b_opt.zero_grad()
        self.dec_opt.zero_grad()
        self.enc_opt.zero_grad()
        self.interp_ab_opt.zero_grad()
        self.interp_ba_opt.zero_grad()

    def dis_update(self, x_a, x_b, hyperparameters):
        """Run one optimisation step of both critics.

        Loss per critic is the mean of (score on translated image - score on
        real image) plus the critic's gradient penalty; every term is
        weighted by ``hyperparameters['gan_w']``. Only ``dis_a_opt`` /
        ``dis_b_opt`` step; gradients that also reach the generator-side
        networks are cleared by the next ``zero_grad``.
        """
        self.zero_grad()
        gan_w = hyperparameters['gan_w']
        # encode
        c_a, s_a = self.encoder(x_a)
        c_b, s_b = self.encoder(x_b)
        # decode (cross domain); ones_like keeps v on the style's device/dtype
        self.v = torch.ones_like(s_a)
        s_a_interp = self.interp_net_ba(s_b, s_a, self.v)
        s_b_interp = self.interp_net_ab(s_a, s_b, self.v)
        x_ba = self.decoder(c_b, s_a_interp)
        x_ab = self.decoder(c_a, s_b_interp)
        x_a_feature = self.dis_a(x_a)
        x_ba_feature = self.dis_a(x_ba)
        x_b_feature = self.dis_b(x_b)
        x_ab_feature = self.dis_b(x_ab)
        self.loss_dis_a = (x_ba_feature - x_a_feature).mean()
        self.loss_dis_b = (x_ab_feature - x_b_feature).mean()
        # gradient penalty
        self.loss_dis_a_gp = self.dis_a.calculate_gradient_penalty(x_ba, x_a)
        self.loss_dis_b_gp = self.dis_b.calculate_gradient_penalty(x_ab, x_b)
        self.loss_dis_total = (gan_w * self.loss_dis_a +
                               gan_w * self.loss_dis_b +
                               gan_w * self.loss_dis_a_gp +
                               gan_w * self.loss_dis_b_gp)
        self.loss_dis_total.backward()
        self.total_loss += self.loss_dis_total.item()
        self.dis_a_opt.step()
        self.dis_b_opt.step()

    def gen_update(self, x_a, x_b, hyperparameters):
        """Run one optimisation step of the encoder, decoder and interpolators.

        Total loss = adversarial (``gan_w``) + image reconstruction
        (``recon_x_w``) + style / content code reconstruction (``recon_s_w``
        / ``recon_c_w``) + cycle reconstruction (``recon_x_cyc_w``) +
        perceptual terms (``vgg_w``; optional, 0 when absent, as in
        ``__init__``). Cycle-based terms are skipped when
        ``recon_x_cyc_w <= 0``.
        """
        self.zero_grad()
        gan_w = hyperparameters['gan_w']
        x_w = hyperparameters['recon_x_w']
        s_w = hyperparameters['recon_s_w']
        c_w = hyperparameters['recon_c_w']
        cyc_w = hyperparameters['recon_x_cyc_w']
        vgg_w = hyperparameters.get('vgg_w', 0)
        use_cyc = cyc_w > 0
        use_vgg = vgg_w > 0
        # encode
        c_a, s_a = self.encoder(x_a)
        c_b, s_b = self.encoder(x_b)
        # decode (within domain)
        x_a_recon = self.decoder(c_a, s_a)
        x_b_recon = self.decoder(c_b, s_b)
        # decode (cross domain)
        self.v = torch.ones_like(s_a)
        s_a_interp = self.interp_net_ba(s_b, s_a, self.v)
        s_b_interp = self.interp_net_ab(s_a, s_b, self.v)
        x_ba = self.decoder(c_b, s_a_interp)
        x_ab = self.decoder(c_a, s_b_interp)
        # encode again
        c_b_recon, s_a_recon = self.encoder(x_ba)
        c_a_recon, s_b_recon = self.encoder(x_ab)
        # decode again (only needed by the cycle terms)
        x_aa = self.decoder(c_a_recon, s_a) if use_cyc else None
        x_bb = self.decoder(c_b_recon, s_b) if use_cyc else None
        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = (self.recon_criterion(x_aa, x_a)
                                      if use_cyc else 0)
        self.loss_gen_cycrecon_x_b = (self.recon_criterion(x_bb, x_b)
                                      if use_cyc else 0)
        # perceptual loss; the cycle variants need x_aa / x_bb, which only
        # exist when the cycle term is enabled
        self.loss_gen_vgg_a = (self.perceptural_loss(x_a_recon, x_a)
                               if use_vgg else 0)
        self.loss_gen_vgg_b = (self.perceptural_loss(x_b_recon, x_b)
                               if use_vgg else 0)
        self.loss_gen_vgg_aa = (self.perceptural_loss(x_aa, x_a)
                                if use_vgg and use_cyc else 0)
        self.loss_gen_vgg_bb = (self.perceptural_loss(x_bb, x_b)
                                if use_vgg and use_cyc else 0)
        # GAN loss
        x_ba_feature = self.dis_a(x_ba)
        x_ab_feature = self.dis_b(x_ab)
        self.loss_gen_adv_a = -x_ba_feature.mean()
        self.loss_gen_adv_b = -x_ab_feature.mean()
        # total loss
        self.loss_gen_total = (gan_w * self.loss_gen_adv_a +
                               gan_w * self.loss_gen_adv_b +
                               x_w * self.loss_gen_recon_x_a +
                               s_w * self.loss_gen_recon_s_a +
                               c_w * self.loss_gen_recon_c_a +
                               x_w * self.loss_gen_recon_x_b +
                               s_w * self.loss_gen_recon_s_b +
                               c_w * self.loss_gen_recon_c_b +
                               cyc_w * self.loss_gen_cycrecon_x_a +
                               cyc_w * self.loss_gen_cycrecon_x_b +
                               vgg_w * self.loss_gen_vgg_aa +
                               vgg_w * self.loss_gen_vgg_bb +
                               vgg_w * self.loss_gen_vgg_a +
                               vgg_w * self.loss_gen_vgg_b)
        self.loss_gen_total.backward()
        self.total_loss += self.loss_gen_total.item()
        self.dec_opt.step()
        self.enc_opt.step()
        self.interp_ab_opt.step()
        self.interp_ba_opt.step()

    def sample(self, x_a, x_b):
        """Produce reconstructions and translations for a visual check.

        Processes the batch one image at a time without gradient tracking.

        Returns:
            ``(x_a, x_a_recon, x_ab, x_aa, x_b, x_b_recon, x_ba, x_bb)``.
        """
        self.eval()
        x_a_recon, x_b_recon, x_ab, x_ba, x_aa, x_bb = [], [], [], [], [], []
        with torch.no_grad():
            for i in range(x_a.size(0)):
                c_a, s_a = self.encoder(x_a[i].unsqueeze(0))
                c_b, s_b = self.encoder(x_b[i].unsqueeze(0))
                x_a_recon.append(self.decoder(c_a, s_a))
                x_b_recon.append(self.decoder(c_b, s_b))
                self.v = torch.ones_like(s_a)
                s_a_interp = self.interp_net_ba(s_b, s_a, self.v)
                s_b_interp = self.interp_net_ab(s_a, s_b, self.v)
                x_ab_i = self.decoder(c_a, s_b_interp)
                x_ba_i = self.decoder(c_b, s_a_interp)
                c_a_recon, _ = self.encoder(x_ab_i)
                c_b_recon, _ = self.encoder(x_ba_i)
                x_ab.append(x_ab_i)
                x_ba.append(x_ba_i)
                # s_a / s_b are passed to the decoder exactly as in the
                # within-domain reconstruction above (no extra unsqueeze).
                x_aa.append(self.decoder(c_a_recon, s_a))
                x_bb.append(self.decoder(c_b_recon, s_b))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ab, x_aa = torch.cat(x_ab), torch.cat(x_aa)
        x_ba, x_bb = torch.cat(x_ba), torch.cat(x_bb)
        self.train()
        return x_a, x_a_recon, x_ab, x_aa, x_b, x_b_recon, x_ba, x_bb

    def update_learning_rate(self):
        """Step every learning-rate scheduler that is configured (not None)."""
        for scheduler in (self.dis_a_scheduler, self.dis_b_scheduler,
                          self.enc_scheduler, self.dec_scheduler,
                          self.interp_ab_scheduler, self.interp_ba_scheduler):
            if scheduler is not None:
                scheduler.step()

    def _checkpoint_networks(self):
        """Return ``(file prefix, network)`` pairs used to save/load weights."""
        return [('encoder', self.encoder), ('decoder', self.decoder),
                ('dis_a', self.dis_a), ('dis_b', self.dis_b),
                ('interp_ab', self.interp_net_ab),
                ('interp_ba', self.interp_net_ba)]

    def _checkpoint_optimizers(self):
        """Return ``(state key, optimizer)`` pairs stored in the optimizer file."""
        return [('enc_opt', self.enc_opt), ('dec_opt', self.dec_opt),
                ('dis_a_opt', self.dis_a_opt), ('dis_b_opt', self.dis_b_opt),
                ('interp_ab_opt', self.interp_ab_opt),
                ('interp_ba_opt', self.interp_ba_opt)]

    def _load_optimizers(self, path, hyperparameters):
        """Restore optimizer states and counters from ``path``, then rebuild the schedulers."""
        state_dict = torch.load(path)
        for key, optimizer in self._checkpoint_optimizers():
            optimizer.load_state_dict(state_dict[key])
        self.best_iter = state_dict['best_iter']
        self.total_loss = state_dict['total_loss']
        # Reinitialize schedulers so they continue from best_iter
        self.dis_a_scheduler = get_scheduler(self.dis_a_opt, hyperparameters,
                                             self.best_iter)
        self.dis_b_scheduler = get_scheduler(self.dis_b_opt, hyperparameters,
                                             self.best_iter)
        self.enc_scheduler = get_scheduler(self.enc_opt, hyperparameters,
                                           self.best_iter)
        self.dec_scheduler = get_scheduler(self.dec_opt, hyperparameters,
                                           self.best_iter)
        self.interp_ab_scheduler = get_scheduler(self.interp_ab_opt,
                                                 hyperparameters,
                                                 self.best_iter)
        self.interp_ba_scheduler = get_scheduler(self.interp_ba_opt,
                                                 hyperparameters,
                                                 self.best_iter)

    def resume(self, checkpoint_dir, hyperparameters):
        """Load the checkpoint located by ``get_model`` plus ``optimizer.pt``.

        Returns:
            ``(best_iter, total_loss)`` read from the optimizer file.
        """
        for prefix, net in self._checkpoint_networks():
            net.load_state_dict(torch.load(get_model(checkpoint_dir, prefix)))
        self._load_optimizers(os.path.join(checkpoint_dir, 'optimizer.pt'),
                              hyperparameters)
        print('Resume from iteration %d' % self.best_iter)
        return self.best_iter, self.total_loss

    def resume_iter(self, checkpoint_dir, surfix, hyperparameters):
        """Load ``<net><surfix>.pt`` and ``optimizer<surfix>.pt`` from ``checkpoint_dir``.

        ``surfix`` is appended verbatim, e.g. ``'_00000005'`` for files
        written by ``save_at_iter(snapshot_dir, 4)``.

        Returns:
            ``(best_iter, total_loss)`` read from the optimizer file.
        """
        for prefix, net in self._checkpoint_networks():
            net.load_state_dict(
                torch.load(
                    os.path.join(checkpoint_dir, prefix + surfix + '.pt')))
        self._load_optimizers(
            os.path.join(checkpoint_dir, 'optimizer' + surfix + '.pt'),
            hyperparameters)
        print('Resume from iteration %d' % self.best_iter)
        return self.best_iter, self.total_loss

    def _save_checkpoint(self, snapshot_dir, tag, opt_file):
        """Write ``<prefix>_<tag>.pt`` for every network and the optimizer bundle ``opt_file``."""
        for prefix, net in self._checkpoint_networks():
            torch.save(net.state_dict(),
                       os.path.join(snapshot_dir, '%s_%s.pt' % (prefix, tag)))
        state = {key: opt.state_dict()
                 for key, opt in self._checkpoint_optimizers()}
        state['best_iter'] = self.best_iter
        state['total_loss'] = self.total_loss
        torch.save(state, os.path.join(snapshot_dir, opt_file))

    def save_better_model(self, snapshot_dir):
        """Replace every file in ``snapshot_dir`` with a checkpoint tagged by ``total_loss``."""
        # remove sub_optimal models (regular files only)
        for f in glob.glob(snapshot_dir + '/*'):
            if os.path.isfile(f):
                os.remove(f)
        self._save_checkpoint(snapshot_dir, '%.4f' % self.total_loss,
                              'optimizer.pt')

    def save_at_iter(self, snapshot_dir, iterations):
        """Save a checkpoint tagged with the 8-digit iteration number ``iterations + 1``."""
        tag = '%08d' % (iterations + 1)
        self._save_checkpoint(snapshot_dir, tag, 'optimizer_%s.pt' % tag)
def training_procedure(FLAGS):
    """Train the grouped-evidence encoder/decoder on MNIST.

    Each mini-batch contributes three losses, each back-propagated on its
    own before a single optimiser step: a KL term on the per-image style
    posterior, a KL term on the class posterior accumulated per label group,
    and a reconstruction MSE. Every 50 iterations the losses are printed,
    appended to ``FLAGS.log_file`` and written to TensorBoard. Weights go to
    ``checkpoints/`` every 5 epochs and after the final epoch.

    Args:
        FLAGS: parsed options (latent sizes, optimiser settings, loss
            coefficients, epoch range, file names and the ``cuda`` switch).
    """
    # model definition
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)
    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))

    # input buffer; every mini-batch is copied into it in place
    X = torch.FloatTensor(FLAGS.batch_size, 1, FLAGS.image_size,
                          FLAGS.image_size)

    # add option to run on GPU
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()
        X = X.cuda()

    # optimizer definition
    auto_encoder_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2))

    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    os.makedirs('checkpoints', exist_ok=True)

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write(
                'Epoch\tIteration\tReconstruction_loss\tStyle_KL_divergence_loss\tClass_KL_divergence_loss\n'
            )

    # load data set and create data loader instance
    print('Loading MNIST dataset...')
    mnist = datasets.MNIST(root='mnist', download=True, train=True,
                           transform=transform_config)
    loader = cycle(
        DataLoader(mnist, batch_size=FLAGS.batch_size, shuffle=True,
                   num_workers=0, drop_last=True))
    iterations_per_epoch = len(mnist) // FLAGS.batch_size
    # KL terms are divided by the number of input values in a mini-batch
    kl_normaliser = (FLAGS.batch_size * FLAGS.num_channels *
                     FLAGS.image_size * FLAGS.image_size)

    # initialize summary writer; closed in ``finally`` so events are flushed
    writer = SummaryWriter()
    try:
        for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
            print('')
            print(
                'Epoch #' + str(epoch) +
                '..........................................................................'
            )

            for iteration in range(iterations_per_epoch):
                # load a mini-batch
                image_batch, labels_batch = next(loader)

                # set zero_grad for the optimizer
                auto_encoder_optimizer.zero_grad()

                X.copy_(image_batch)

                style_mu, style_logvar, class_mu, class_logvar = encoder(X)
                # NOTE(review): ``.data`` hands detached tensors to the helper,
                # so gradients of losses built from grouped_mu / grouped_logvar
                # cannot reach the encoder through these arguments — confirm
                # this is intended.
                grouped_mu, grouped_logvar = accumulate_group_evidence(
                    class_mu.data, class_logvar.data, labels_batch,
                    FLAGS.cuda)

                # kl-divergence error for style latent space
                style_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                    -0.5 * torch.sum(1 + style_logvar - style_mu.pow(2) -
                                     style_logvar.exp())) / kl_normaliser
                style_kl_divergence_loss.backward(retain_graph=True)

                # kl-divergence error for class latent space
                class_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                    -0.5 * torch.sum(1 + grouped_logvar - grouped_mu.pow(2) -
                                     grouped_logvar.exp())) / kl_normaliser
                class_kl_divergence_loss.backward(retain_graph=True)

                # reconstruct samples
                # Sampling from the group mu/logvar separately for each image
                # in the mini-batch would make the decoder treat the class
                # latent embeddings as random noise and ignore them.
                style_latent_embeddings = reparameterize(
                    training=True, mu=style_mu, logvar=style_logvar)
                class_latent_embeddings = group_wise_reparameterize(
                    training=True, mu=grouped_mu, logvar=grouped_logvar,
                    labels_batch=labels_batch, cuda=FLAGS.cuda)

                reconstructed_images = decoder(style_latent_embeddings,
                                               class_latent_embeddings)

                reconstruction_error = FLAGS.reconstruction_coef * mse_loss(
                    reconstructed_images, X)
                reconstruction_error.backward()

                auto_encoder_optimizer.step()

                if (iteration + 1) % 50 == 0:
                    reconstruction_value = reconstruction_error.item()
                    style_kl_value = style_kl_divergence_loss.item()
                    class_kl_value = class_kl_divergence_loss.item()

                    print('')
                    print('Epoch #' + str(epoch))
                    print('Iteration #' + str(iteration))

                    print('')
                    print('Reconstruction loss: ' + str(reconstruction_value))
                    print('Style KL-Divergence loss: ' + str(style_kl_value))
                    print('Class KL-Divergence loss: ' + str(class_kl_value))

                    # write to log
                    with open(FLAGS.log_file, 'a') as log:
                        log.write('{0}\t{1}\t{2}\t{3}\t{4}\n'.format(
                            epoch, iteration, reconstruction_value,
                            style_kl_value, class_kl_value))

                    # write to tensorboard
                    step = epoch * (iterations_per_epoch + 1) + iteration
                    writer.add_scalar('Reconstruction loss',
                                      reconstruction_value, step)
                    writer.add_scalar('Style KL-Divergence loss',
                                      style_kl_value, step)
                    writer.add_scalar('Class KL-Divergence loss',
                                      class_kl_value, step)

            # save checkpoints after every 5 epochs
            if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
                torch.save(encoder.state_dict(),
                           os.path.join('checkpoints', FLAGS.encoder_save))
                torch.save(decoder.state_dict(),
                           os.path.join('checkpoints', FLAGS.decoder_save))
    finally:
        writer.close()
loss = torch.sum(l1 + l2 + torch.log(det_p) - torch.log(det_q), dim=1) return loss if (__name__ == '__main__'): # model definition encoder = Encoder() encoder.apply(weights_init) decoder = Decoder() decoder.apply(weights_init) # load saved models if load_saved flag is true if LOAD_SAVED: encoder.load_state_dict( torch.load(os.path.join('checkpoints', ENCODER_SAVE))) decoder.load_state_dict( torch.load(os.path.join('checkpoints', DECODER_SAVE))) # loss definition mse_loss = nn.MSELoss() # add option to run on gpu if (CUDA): encoder.cuda() decoder.cuda() mse_loss.cuda() # optimizer optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()),
likelihood = torch.sum(summand) / summand.size(0) FLAGS = parser.parse_args() if __name__ == '__main__': """ model definitions """ encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim) decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim) encoder.load_state_dict( torch.load(os.path.join('checkpoints', FLAGS.encoder_save), map_location=lambda storage, loc: storage)) decoder.load_state_dict( torch.load(os.path.join('checkpoints', FLAGS.decoder_save), map_location=lambda storage, loc: storage)) encoder.cuda() decoder.cuda() if not os.path.exists('reconstructed_images'): os.makedirs('reconstructed_images') # load data set and create data loader instance ''' print('Loading MNIST paired dataset...') paired_mnist = MNIST_Paired(root='mnist', download=True, train=False, transform=transform_config) loader = cycle(DataLoader(paired_mnist, batch_size=FLAGS.batch_size, shuffle=True, num_workers=0, drop_last=True)) image_array = []
def training_procedure(FLAGS):
    """Train the grouped-evidence encoder/decoder on matching time-series datasets.

    Every entry of ``./data`` whose third ``_``-separated name component is
    exactly ``theta=-1`` is loaded as a ``DoubleMulNormal`` dataset and
    trained on for epochs ``FLAGS.start_epoch`` .. ``FLAGS.end_epoch - 1``.
    The same encoder, decoder and optimiser are reused across datasets;
    weights are written to ``checkpoints/encoder_<dsname>`` and
    ``checkpoints/decoder_<dsname>``.

    Args:
        FLAGS: parsed options (latent sizes, optimiser settings, loss
            coefficients, epoch range and file names).
    """
    # model definition
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)
    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))

    # input buffer of 784 values per sample; each mini-batch is copied into it
    X = torch.FloatTensor(FLAGS.batch_size, 784)

    # run on GPU if GPU is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # The grouping helpers take a boolean CUDA switch; derive it from the
    # device actually in use so it agrees with where X and the networks live
    # (presumably the flag selects where the helpers allocate tensors —
    # confirm against accumulate_group_evidence / group_wise_reparameterize).
    use_cuda = device.type == 'cuda'
    encoder.to(device=device)
    decoder.to(device=device)
    X = X.to(device=device)

    # optimizer definition
    auto_encoder_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2))

    os.makedirs('checkpoints', exist_ok=True)

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write(
                'Epoch\tIteration\tReconstruction_loss\tStyle_KL_divergence_loss\tClass_KL_divergence_loss\n'
            )

    # NOTE(review): the KL normaliser is image-shaped while each X row holds
    # 784 values; the two agree only if num_channels * image_size**2 == 784.
    kl_normaliser = (FLAGS.batch_size * FLAGS.num_channels *
                     FLAGS.image_size * FLAGS.image_size)

    def save_checkpoint(dsname):
        # one encoder/decoder file pair per dataset name
        torch.save(encoder.state_dict(),
                   os.path.join('checkpoints', 'encoder_' + dsname))
        torch.save(decoder.state_dict(),
                   os.path.join('checkpoints', 'decoder_' + dsname))

    # load data set and create data loader instance
    dirs = os.listdir(os.path.join(os.getcwd(), 'data'))
    print('Loading double multivariate normal time series data...')
    for dsname in dirs:
        params = dsname.split('_')
        # Exact match on the third name component (names with fewer than
        # three components are skipped instead of raising IndexError).
        if len(params) < 3 or params[2] != 'theta=-1':
            continue
        print('Running dataset ', dsname)
        ds = DoubleMulNormal(dsname)
        loader = cycle(
            DataLoader(ds, batch_size=FLAGS.batch_size, shuffle=True,
                       drop_last=True))
        iterations_per_epoch = len(ds) // FLAGS.batch_size

        # initialize summary writer (one per dataset); always closed below
        writer = SummaryWriter()
        try:
            for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
                print()
                print(
                    'Epoch #' + str(epoch) +
                    '........................................................')
                # Loss summed over the epoch's iterations, kept as a Python
                # float so earlier iterations' autograd graphs can be freed.
                total_loss = 0.0

                for iteration in range(iterations_per_epoch):
                    # load a mini-batch
                    image_batch, labels_batch = next(loader)

                    # set zero_grad for the optimizer
                    auto_encoder_optimizer.zero_grad()

                    X.copy_(image_batch)

                    style_mu, style_logvar, class_mu, class_logvar = encoder(X)
                    # NOTE(review): ``.data`` hands detached tensors to the
                    # helper, so losses built from grouped_mu / grouped_logvar
                    # cannot send gradients to the encoder through these
                    # arguments — confirm this is intended.
                    grouped_mu, grouped_logvar = accumulate_group_evidence(
                        class_mu.data, class_logvar.data, labels_batch,
                        use_cuda)

                    # kl-divergence error for style latent space
                    style_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                        -0.5 * torch.sum(1 + style_logvar - style_mu.pow(2) -
                                         style_logvar.exp())) / kl_normaliser
                    style_kl_divergence_loss.backward(retain_graph=True)

                    # kl-divergence error for class latent space
                    class_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                        -0.5 * torch.sum(1 + grouped_logvar -
                                         grouped_mu.pow(2) -
                                         grouped_logvar.exp())) / kl_normaliser
                    class_kl_divergence_loss.backward(retain_graph=True)

                    # reconstruct samples
                    # Sampling from the group mu/logvar separately for each
                    # sample would make the decoder treat the class latent
                    # embeddings as random noise and ignore them.
                    style_latent_embeddings = reparameterize(
                        training=True, mu=style_mu, logvar=style_logvar)
                    class_latent_embeddings = group_wise_reparameterize(
                        training=True, mu=grouped_mu, logvar=grouped_logvar,
                        labels_batch=labels_batch, cuda=use_cuda)

                    reconstructed_images = decoder(style_latent_embeddings,
                                                   class_latent_embeddings)

                    reconstruction_error = FLAGS.reconstruction_coef * mse_loss(
                        reconstructed_images, X)
                    reconstruction_error.backward()

                    total_loss += (style_kl_divergence_loss +
                                   class_kl_divergence_loss +
                                   reconstruction_error).item()

                    auto_encoder_optimizer.step()

                    if (iteration + 1) % 50 == 0:
                        reconstruction_value = reconstruction_error.item()
                        style_kl_value = style_kl_divergence_loss.item()
                        class_kl_value = class_kl_divergence_loss.item()

                        print('\tIteration #' + str(iteration))
                        print('Reconstruction loss: ' +
                              str(reconstruction_value))
                        print('Style KL loss: ' + str(style_kl_value))
                        print('Class KL loss: ' + str(class_kl_value))

                        # write to log
                        with open(FLAGS.log_file, 'a') as log:
                            log.write('{0}\t{1}\t{2}\t{3}\t{4}\n'.format(
                                epoch, iteration, reconstruction_value,
                                style_kl_value, class_kl_value))

                        # write to tensorboard
                        step = epoch * (iterations_per_epoch + 1) + iteration
                        writer.add_scalar('Reconstruction loss',
                                          reconstruction_value, step)
                        writer.add_scalar('Style KL-Divergence loss',
                                          style_kl_value, step)
                        writer.add_scalar('Class KL-Divergence loss',
                                          class_kl_value, step)

                    # frequent snapshots during the first epoch
                    if epoch == 0 and (iteration + 1) % 50 == 0:
                        save_checkpoint(dsname)

                # save checkpoints after every 10 epochs
                if (epoch + 1) % 10 == 0 or (epoch + 1) == FLAGS.end_epoch:
                    save_checkpoint(dsname)

                print('Total loss at current epoch: ', total_loss)
        finally:
            writer.close()
def test(opt):
    """Translate the test split A->B and B->A with the newest checkpoint and save figures.

    Loads the latest ``epoch_XXX-{E_a2b,G_b,E_b2a,G_a}.pth`` weights from
    ``<opt.checkpoints>/<opt.name>``, translates every batch of the unaligned
    test data and saves one 2x2 figure (real_A, fake_B, real_B, fake_A) per
    sample as ``results/<opt.name>/<batch:06d>-<sample:02d>.jpg``.

    Args:
        opt: parsed options (name, checkpoints, gpu_id, batch/network settings).

    Raises:
        FileNotFoundError: if the checkpoint folder is missing or holds no
            ``epoch_*.pth`` file.
    """
    # mkdir (including the parent 'results' folder)
    des_pth = os.path.join('results', opt.name)
    os.makedirs(des_pth, exist_ok=True)

    # Checkpoints are named 'epoch_' + a 3-character epoch field + '-<net>.pth';
    # anything else in the folder (e.g. 'images', 'records.txt') is ignored.
    src_pth = os.path.join(opt.checkpoints, opt.name)
    models_name = [name for name in os.listdir(src_pth)
                   if name.startswith('epoch_') and name.endswith('.pth')]
    if not models_name:
        raise FileNotFoundError('no epoch_*.pth checkpoint in ' + src_pth)
    models_name.sort(key=lambda x: int(x[6:9]))
    # Reuse the on-disk 'epoch_XXX' prefix of the newest file so both zero-
    # and space-padded epoch numbers resolve to existing files.
    prefix = models_name[-1][:9]

    # device
    device = torch.device('cuda:{}'.format(opt.gpu_id)
                          if opt.gpu_id >= 0 else 'cpu')

    # data
    data_loader = UnAlignedDataLoader()
    data_loader.initialize(opt)
    data_set = data_loader.load_data()

    # networks
    E_a2b = Encoder(input_nc=opt.input_nc, ngf=opt.ngf,
                    norm_type=opt.norm_type,
                    use_dropout=not opt.no_dropout, n_blocks=9)
    G_b = Decoder(output_nc=opt.output_nc, ngf=opt.ngf,
                  norm_type=opt.norm_type)
    E_b2a = Encoder(input_nc=opt.input_nc, ngf=opt.ngf,
                    norm_type=opt.norm_type,
                    use_dropout=not opt.no_dropout, n_blocks=9)
    G_a = Decoder(output_nc=opt.output_nc, ngf=opt.ngf,
                  norm_type=opt.norm_type)

    def load(net, tag):
        # map_location lets GPU-trained weights load onto the chosen device
        path = os.path.join(src_pth, '%s-%s.pth' % (prefix, tag))
        net.load_state_dict(torch.load(path, map_location=device))
        return net.to(device)

    E_a2b = load(E_a2b, 'E_a2b')
    G_b = load(G_b, 'G_b')
    E_b2a = load(E_b2a, 'E_b2a')
    G_a = load(G_a, 'G_a')
    # NOTE(review): the networks stay in their default (training) mode, so
    # dropout enabled via use_dropout remains active here; call .eval() on
    # them if that is not intended.

    with torch.no_grad():  # inference only: no autograd graph is needed
        for i, data in enumerate(data_set):
            real_A = data['A'].to(device)
            real_B = data['B'].to(device)
            fake_B = G_b(E_a2b(real_A)).cpu()
            fake_A = G_a(E_b2a(real_B)).cpu()
            real_A = real_A.cpu()
            real_B = real_B.cpu()

            # visualize; the last batch may hold fewer than opt.batch_size
            for j in range(min(real_A.size(0), real_B.size(0))):
                fake_b = tensor2image_RGB(fake_B[j, ...])
                fake_a = tensor2image_RGB(fake_A[j, ...])
                real_a = tensor2image_RGB(real_A[j, ...])
                real_b = tensor2image_RGB(real_B[j, ...])
                plt.subplot(221), plt.title("real_A"), plt.imshow(real_a)
                plt.subplot(222), plt.title("fake_B"), plt.imshow(fake_b)
                plt.subplot(223), plt.title("real_B"), plt.imshow(real_b)
                plt.subplot(224), plt.title("fake_A"), plt.imshow(fake_a)
                plt.savefig(os.path.join(des_pth, '%06d-%02d.jpg' % (i, j)))
                # release the figure so images do not pile up across samples
                plt.close()

    print("≧◔◡◔≦ Congratulation! Successfully finishing the testing!")
class DoubleDQN:
    """Q-learning agent with an online network and a soft-updated target network.

    ``self.model`` and ``self.target_model`` are two ``DQN`` instances that
    share the same ``self.encoder`` module. An action is a pair
    ``(object_index, direction_index)`` with ``self.action_shape`` (= 8)
    directions per object. Relies on module-level ``args``, ``DEVICE``,
    ``MODEL_FILE`` and ``UPDATE_STEPS`` defined elsewhere in the file.
    """

    def __init__(self, env, tau=0.1, gamma=0.9, epsilon=1.0):
        self.env = env
        self.tau = tau          # Polyak factor for update_target_net
        self.gamma = gamma      # discount factor
        self.embedding_size = 30
        self.hidden_size = 30
        self.obs_shape = self.env.get_obs().shape
        self.action_shape = 40 // 5
        if args.encoding == "onehot":
            self.encoder = OneHot(
                args.bins,
                self.env.all_questions + self.env.held_out_questions,
                self.hidden_size).to(DEVICE)
        else:
            self.encoder = Encoder(self.embedding_size,
                                   self.hidden_size).to(DEVICE)

        self.model = DQN(self.obs_shape, self.action_shape,
                         self.encoder).to(DEVICE)
        self.target_model = DQN(self.obs_shape, self.action_shape,
                                self.encoder).to(DEVICE)
        self.optimizer = torch.optim.Adam(self.model.parameters())
        self.epsilon = epsilon

        if os.path.exists(MODEL_FILE):
            # Resume: restore online, target, encoder and optimizer state as saved.
            checkpoint = torch.load(MODEL_FILE)
            self.encoder.load_state_dict(checkpoint['encoder_state_dict'])
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.target_model.load_state_dict(
                checkpoint['target_model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.epsilon = checkpoint['epsilon']
        else:
            # Fresh start: hard copy model parameters to target model parameters.
            self._hard_update_target()

    def _hard_update_target(self):
        """Copy every online-model parameter into the matching target parameter."""
        with torch.no_grad():
            for target_param, param in zip(self.target_model.parameters(),
                                           self.model.parameters()):
                target_param.copy_(param)

    def get_action(self, state, goal):
        """Epsilon-greedy action as ``(object_index, direction_index)`` ints."""
        assert len(state.shape) == 2  # This function should not be called during update

        if np.random.rand() > self.epsilon:
            q_values = self.model.forward(state, goal)
            # Flat argmax over all (object, direction) entries.
            idx = torch.argmax(q_values).detach()
            obj_selection = idx // self.action_shape
            direction_selection = idx % self.action_shape
        else:
            action = self.env.sample_random_action()
            obj_selection = action[0]
            direction_selection = action[1]

        return int(obj_selection), int(direction_selection)

    def compute_loss(self, batch):
        """Return the MSE between Q(s, a) and the bootstrapped target for *batch*."""
        states, actions, goals, rewards, next_states, satisfied_goals, dones = batch
        rewards = torch.FloatTensor(rewards).to(DEVICE)
        dones = torch.FloatTensor(dones).to(DEVICE)

        curr_Q = self.model(states, goals)
        # Q value of the taken action for each sample (TODO: use torch.gather).
        curr_Q_prev_actions = torch.stack([
            curr_Q[i, actions[i][0], actions[i][1]] for i in range(len(states))
        ])

        next_Q = self.target_model(next_states, goals)
        # Max over the last two axes, i.e. over every (object, direction) pair.
        next_Q_max_actions = torch.max(next_Q, -1).values
        next_Q_max_actions = torch.max(next_Q_max_actions, -1).values
        # NOTE(review): this is the vanilla DQN target (target net both picks
        # and evaluates the action); a Double-DQN target would pick the
        # argmax with self.model and evaluate it with self.target_model.
        targets = rewards + (1 - dones) * self.gamma * next_Q_max_actions

        loss = F.mse_loss(curr_Q_prev_actions, targets.detach())
        return loss

    def update(self, replay_buffer, batch_size):
        """Run UPDATE_STEPS gradient steps on batches sampled from *replay_buffer*."""
        for _ in range(UPDATE_STEPS):
            batch = replay_buffer.sample(batch_size)
            loss = self.compute_loss(batch)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def update_target_net(self):
        """Soft update: target <- tau * model + (1 - tau) * target."""
        with torch.no_grad():
            for target_param, param in zip(self.target_model.parameters(),
                                           self.model.parameters()):
                target_param.copy_(self.tau * param +
                                   (1 - self.tau) * target_param)

    def save_model(self):
        """Write networks, optimizer state and epsilon to MODEL_FILE."""
        torch.save(
            {
                'model_state_dict': self.model.state_dict(),
                'target_model_state_dict': self.target_model.state_dict(),
                'encoder_state_dict': self.encoder.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'epsilon': self.epsilon
            }, MODEL_FILE)
# Output directories for reconstructions and squared errors.
if not os.path.exists('reconstructed_images'):
    os.makedirs('reconstructed_images')
if not os.path.exists('sqerrors'):
    os.makedirs('sqerrors')

cwd = os.getcwd()
dirs = os.listdir(os.path.join(cwd, 'data'))
print('Loading double univariate normal time series test data...')
for dsname in dirs:
    # Data set names are '_'-separated; the third field holds the theta setting.
    params = dsname.split('_')
    # One-element tuple: exact membership, not a substring test on a string.
    if params[2] in ('theta=-1',):
        # load saved parameters of encoder and decoder (onto CPU first)
        encoder.load_state_dict(
            torch.load(os.path.join(cwd, 'checkpoints', 'encoder_' + dsname),
                       map_location=lambda storage, loc: storage))
        decoder.load_state_dict(
            torch.load(os.path.join(cwd, 'checkpoints', 'decoder_' + dsname),
                       map_location=lambda storage, loc: storage))
        encoder = encoder.to(device=device)
        decoder = decoder.to(device=device)

        paired_mnist = DoubleMulNormal(dsname)
        loader = cycle(
            DataLoader(paired_mnist,
                       batch_size=FLAGS.batch_size,
                       shuffle=True,
                       num_workers=0,
                       drop_last=True))
# Select the checkpoint with the best monitored value and reload it for evaluation.
print(folder)
# Candidate epoch indices: arange gives [0, 5, 10, ...]; subtracting 1 from all
# but the first yields [0, 4, 9, 14, ...].
checks = np.arange(0, FLAGS.end_epoch + 5, 5)
checks[1:] -= 1
# Force entry 0 to -inf so it never wins the argmax below.
monitor[0, 0] = -np.inf
# Column 0 of `monitor` is presumably the validation ELBO — confirm against the writer of `monitor`.
best_elbo = monitor[checks, 0].argmax()
#DEBUG best_elbo = -1
print(checks[best_elbo], monitor[checks[best_elbo], 0], monitor[checks, 0].max(), best_elbo)
FLAGS.encoder_save = folder + '/encoder_e%d' % checks[best_elbo]
FLAGS.decoder_save = folder + '/decoder_e%d' % checks[best_elbo]
FLAGS.batch_size = 256
# map_location returning `storage` deserialises the tensors onto the CPU.
encoder.load_state_dict(
    torch.load(FLAGS.encoder_save, map_location=lambda storage, loc: storage))
decoder.load_state_dict(
    torch.load(FLAGS.decoder_save, map_location=lambda storage, loc: storage))
if FLAGS.cuda:
    encoder.cuda()
    decoder.cuda()
if not os.path.exists('reconstructed_images'):
    os.makedirs('reconstructed_images')
# load data set and create data loader instance
print('Loading MNIST paired dataset...')
paired_mnist = MNIST_Paired(root='mnist',
def training_procedure(FLAGS):
    """Train the style/class VAE on MNIST with a held-out validation split.

    Builds ``Encoder``/``Decoder`` (optionally restoring them from
    ``checkpoints_<batch_size>/``) and maximises the ELBO returned by
    ``process`` on all but 10000 MNIST training images. The other 10000
    images form a fixed (seed 0) validation split. Every 5 epochs and at the
    last epoch, ``eval`` is run on the validation loader, its 4 values are
    stored in ``monitor[epoch]``, and weights plus ``monitor`` are saved.

    Args:
        FLAGS: parsed options (style_dim, class_dim, load_saved, encoder_save,
            decoder_save, cuda, initial_learning_rate, beta_1, beta_2,
            batch_size, log_file, start_epoch, end_epoch, ...).
    """
    # Checkpoint directory; defined first because resuming loads from it.
    savedir = 'checkpoints_%d' % (FLAGS.batch_size)

    """ model definition """
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join(savedir, FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join(savedir, FLAGS.decoder_save)))

    ''' add option to run on GPU '''
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()

    """ optimizer definition """
    auto_encoder_optimizer = optim.Adam(list(encoder.parameters()) +
                                        list(decoder.parameters()),
                                        lr=FLAGS.initial_learning_rate,
                                        betas=(FLAGS.beta_1, FLAGS.beta_2))

    """ training """
    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    if not os.path.exists(savedir):
        os.makedirs(savedir)

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write(
                'Epoch\tIteration\tReconstruction_loss\tStyle_KL_divergence_loss\tClass_KL_divergence_loss\n'
            )

    # load data set and create data loader instance
    print('Loading MNIST dataset...')
    mnist = datasets.MNIST(root='mnist',
                           download=True,
                           train=True,
                           transform=transform_config)

    # Seeded train/validation split, so a resumed run (load_saved) sees the
    # same validation images and never trains on them.
    dataset_size = len(mnist)
    indices = list(range(dataset_size))
    split = 10000
    np.random.seed(0)
    np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]

    # Creating PT data samplers and loaders:
    train_sampler = SubsetRandomSampler(train_indices)
    valid_sampler = SubsetRandomSampler(val_indices)
    kwargs = {'num_workers': 1, 'pin_memory': True} if FLAGS.cuda else {}
    loader = DataLoader(mnist,
                        batch_size=FLAGS.batch_size,
                        sampler=train_sampler,
                        **kwargs)
    valid_loader = DataLoader(mnist,
                              batch_size=FLAGS.batch_size,
                              sampler=valid_sampler,
                              **kwargs)

    # One row per absolute epoch index (rows are written as monitor[epoch]),
    # so resuming with start_epoch > 0 stays in bounds.
    monitor = torch.zeros(FLAGS.end_epoch, 4)

    # initialize summary writer
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print(
            'Epoch #' + str(epoch) +
            '..........................................................................'
        )
        elbo_epoch = 0.
        term1_epoch = 0.
        term2_epoch = 0.
        term3_epoch = 0.
        n_batches = 0
        for image_batch, labels_batch in loader:
            # set zero_grad for the optimizer
            auto_encoder_optimizer.zero_grad()

            X = image_batch.cuda() if FLAGS.cuda else image_batch
            X = X.detach().clone()
            elbo, reconstruction_proba, style_kl_divergence_loss, class_kl_divergence_loss = process(
                FLAGS, X, labels_batch, encoder, decoder)

            (-elbo).backward()
            auto_encoder_optimizer.step()

            # Accumulate plain Python floats: summing the tensors themselves
            # would keep every iteration's autograd graph alive all epoch.
            elbo_epoch += float(elbo)
            term1_epoch += float(reconstruction_proba)
            term2_epoch += float(style_kl_divergence_loss)
            term3_epoch += float(class_kl_divergence_loss)
            n_batches += 1

        n_batches = max(n_batches, 1)  # guard against an empty loader
        print("Elbo epoch %.2f" % (elbo_epoch / n_batches))
        print("Rec. Proba %.2f" % (term1_epoch / n_batches))
        print("KL style %.2f" % (term2_epoch / n_batches))
        print("KL content %.2f" % (term3_epoch / n_batches))

        # save checkpoints after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            monitor[epoch, :] = eval(FLAGS, valid_loader, encoder, decoder)
            torch.save(
                encoder.state_dict(),
                os.path.join(savedir, FLAGS.encoder_save + '_e%d' % epoch))
            torch.save(
                decoder.state_dict(),
                os.path.join(savedir, FLAGS.decoder_save + '_e%d' % epoch))
            print("VAL elbo %.2f" % (monitor[epoch, 0]))
            print("VAL Rec. Proba %.2f" % (monitor[epoch, 1]))
            print("VAL KL style %.2f" % (monitor[epoch, 2]))
            print("VAL KL content %.2f" % (monitor[epoch, 3]))

            torch.save(monitor, os.path.join(savedir, 'monitor_e%d' % epoch))
D_opt = optim.Adam(
    D.parameters(),
    lr=args.lr_D,
    betas=(args.beta1, args.beta2),
)

# Optional resume: restore E, G, D and both optimizers from epoch
# args.load_epoch, then continue training from the epoch after it.
start_ep = 0
if args.load_epoch is not None:
    # Read from another experiment's checkpoint folder when one is named,
    # otherwise from this run's own checkpoint folder.
    if args.load_from_experiment is not None:
        load_checkpoint_path = join('results', args.load_from_experiment, 'checkpoint')
    else:
        load_checkpoint_path = checkpoint_path
    load_ep = args.load_epoch
    start_ep = load_ep + 1

    E.load_state_dict(torch.load(join(load_checkpoint_path, '{:03}.E.pth'.format(load_ep))))
    G.load_state_dict(torch.load(join(load_checkpoint_path, '{:03}.G.pth'.format(load_ep))))
    D.load_state_dict(torch.load(join(load_checkpoint_path, '{:03}.D.pth'.format(load_ep))))
    G_opt.load_state_dict(torch.load(join(load_checkpoint_path, '{:03}.G_opt.pth'.format(load_ep))))
    D_opt.load_state_dict(torch.load(join(load_checkpoint_path, '{:03}.D_opt.pth'.format(load_ep))))

# Criterion
def training_procedure(FLAGS):
    """Train encoder/decoder with reconstruction, KL and adversarial losses on paired MNIST.

    Every iteration runs three phases:
      A. auto-encoder step: KL penalty on the style code of X_1 plus the
         reconstruction of X_1 from (style_1, class_1) and from
         (style_1, class_2), where X_2 is the second image of the pair;
      B. ``FLAGS.generator_times`` generator steps that push the
         discriminator to label (X_3, decoder(style_1, class_3)) and
         (X_3, decoder(N(0, I) sample, class_3)) as real;
      C. ``FLAGS.discriminator_times`` discriminator steps on real pairs
         (X_2, X_3) vs. fakes; the optimizer step is skipped while the
         discriminator accuracy is >= ``FLAGS.discriminator_limiting_accuracy``.
    Losses are printed, appended to ``FLAGS.log_file`` and written to
    TensorBoard every 50 iterations. Weights are saved to ``checkpoints/``
    every 5 epochs and at the last epoch.
    """
    """ model definition """
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    discriminator = Discriminator()
    discriminator.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))
        discriminator.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.discriminator_save)))

    """ variable definition """
    real_domain_labels = 1
    fake_domain_labels = 0

    # Pre-allocated input buffers, refilled in place with copy_() every step.
    X_1 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    X_2 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    X_3 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)

    domain_labels = torch.LongTensor(FLAGS.batch_size)
    style_latent_space = torch.FloatTensor(FLAGS.batch_size, FLAGS.style_dim)

    """ loss definitions """
    cross_entropy_loss = nn.CrossEntropyLoss()

    ''' add option to run on GPU '''
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()
        discriminator.cuda()

        cross_entropy_loss.cuda()

        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()

        domain_labels = domain_labels.cuda()
        style_latent_space = style_latent_space.cuda()

    """ optimizer definition """
    auto_encoder_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    discriminator_optimizer = optim.Adam(
        list(discriminator.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    # Same parameters as auto_encoder_optimizer, but separate Adam state.
    generator_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    """ training """
    if torch.cuda.is_available() and not FLAGS.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write('Epoch\tIteration\tReconstruction_loss\tKL_divergence_loss\t')
            log.write('Generator_loss\tDiscriminator_loss\tDiscriminator_accuracy\n')

    # load data set and create data loader instance
    print('Loading MNIST paired dataset...')
    paired_mnist = MNIST_Paired(root='mnist', download=True, train=True, transform=transform_config)
    loader = cycle(DataLoader(paired_mnist, batch_size=FLAGS.batch_size, shuffle=True, num_workers=0, drop_last=True))
    iterations_per_epoch = int(len(paired_mnist) / FLAGS.batch_size)

    # initialise variables
    discriminator_accuracy = 0.

    # initialize summary writer
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print('Epoch #' + str(epoch) + '..........................................................................')

        for iteration in range(iterations_per_epoch):
            # A. run the auto-encoder reconstruction
            image_batch_1, image_batch_2, _ = next(loader)

            auto_encoder_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_mu_1, style_logvar_1, class_1 = encoder(Variable(X_1))
            style_1 = reparameterize(training=True, mu=style_mu_1, logvar=style_logvar_1)

            # Gaussian KL of the style code, divided by the number of input values in the batch.
            kl_divergence_loss_1 = - 0.5 * torch.sum(1 + style_logvar_1 - style_mu_1.pow(2) - style_logvar_1.exp())
            kl_divergence_loss_1 /= (FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size)
            kl_divergence_loss_1.backward(retain_graph=True)

            _, __, class_2 = encoder(Variable(X_2))

            reconstructed_X_1 = decoder(style_1, class_1)
            # Target is X_1 here too: X_2 is assumed to share X_1's class (paired data).
            reconstructed_X_2 = decoder(style_1, class_2)

            reconstruction_error_1 = mse_loss(reconstructed_X_1, Variable(X_1))
            reconstruction_error_1.backward(retain_graph=True)

            reconstruction_error_2 = mse_loss(reconstructed_X_2, Variable(X_1))
            reconstruction_error_2.backward()

            reconstruction_error = reconstruction_error_1 + reconstruction_error_2
            kl_divergence_error = kl_divergence_loss_1

            auto_encoder_optimizer.step()

            # B. run the generator
            for i in range(FLAGS.generator_times):

                generator_optimizer.zero_grad()

                image_batch_1, _, __ = next(loader)
                image_batch_3, _, __ = next(loader)

                domain_labels.fill_(real_domain_labels)
                X_1.copy_(image_batch_1)
                X_3.copy_(image_batch_3)

                style_mu_1, style_logvar_1, _ = encoder(Variable(X_1))
                style_1 = reparameterize(training=True, mu=style_mu_1, logvar=style_logvar_1)

                kl_divergence_loss_1 = - 0.5 * torch.sum(1 + style_logvar_1 - style_mu_1.pow(2) - style_logvar_1.exp())
                kl_divergence_loss_1 /= (FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size)
                kl_divergence_loss_1.backward(retain_graph=True)

                _, __, class_3 = encoder(Variable(X_3))
                reconstructed_X_1_3 = decoder(style_1, class_3)

                output_1 = discriminator(Variable(X_3), reconstructed_X_1_3)

                generator_error_1 = cross_entropy_loss(output_1, Variable(domain_labels))
                generator_error_1.backward(retain_graph=True)

                # Second fake: style drawn from the N(0, I) prior.
                style_latent_space.normal_(0., 1.)
                reconstructed_X_latent_3 = decoder(Variable(style_latent_space), class_3)

                output_2 = discriminator(Variable(X_3), reconstructed_X_latent_3)

                generator_error_2 = cross_entropy_loss(output_2, Variable(domain_labels))
                generator_error_2.backward()

                generator_error = generator_error_1 + generator_error_2
                kl_divergence_error += kl_divergence_loss_1

                generator_optimizer.step()

            # C. run the discriminator
            for i in range(FLAGS.discriminator_times):

                discriminator_optimizer.zero_grad()

                # train discriminator on real data
                domain_labels.fill_(real_domain_labels)

                image_batch_1, _, __ = next(loader)
                image_batch_2, image_batch_3, _ = next(loader)

                X_1.copy_(image_batch_1)
                X_2.copy_(image_batch_2)
                X_3.copy_(image_batch_3)

                real_output = discriminator(Variable(X_2), Variable(X_3))

                discriminator_real_error = cross_entropy_loss(real_output, Variable(domain_labels))
                discriminator_real_error.backward()

                # train discriminator on fake data
                domain_labels.fill_(fake_domain_labels)

                style_mu_1, style_logvar_1, _ = encoder(Variable(X_1))
                style_1 = reparameterize(training=False, mu=style_mu_1, logvar=style_logvar_1)

                _, __, class_3 = encoder(Variable(X_3))
                reconstructed_X_1_3 = decoder(style_1, class_3)

                fake_output = discriminator(Variable(X_3), reconstructed_X_1_3)

                discriminator_fake_error = cross_entropy_loss(fake_output, Variable(domain_labels))
                discriminator_fake_error.backward()

                # total discriminator error
                discriminator_error = discriminator_real_error + discriminator_fake_error

                # calculate discriminator accuracy for this step
                target_true_labels = torch.cat((torch.ones(FLAGS.batch_size), torch.zeros(FLAGS.batch_size)), dim=0)
                if FLAGS.cuda:
                    target_true_labels = target_true_labels.cuda()

                discriminator_predictions = torch.cat((real_output, fake_output), dim=0)
                _, discriminator_predictions = torch.max(discriminator_predictions, 1)

                discriminator_accuracy = (discriminator_predictions.data == target_true_labels.long()
                                          ).sum().item() / (FLAGS.batch_size * 2)

                # Only update while the discriminator is below the accuracy cap.
                if discriminator_accuracy < FLAGS.discriminator_limiting_accuracy:
                    discriminator_optimizer.step()

            if (iteration + 1) % 50 == 0:
                # .item() reads the scalar value directly (no storage access).
                reconstruction_value = reconstruction_error.item()
                kl_divergence_value = kl_divergence_error.item()
                generator_value = generator_error.item()
                discriminator_value = discriminator_error.item()
                global_step = epoch * (iterations_per_epoch + 1) + iteration

                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))

                print('')
                print('Reconstruction loss: ' + str(reconstruction_value))
                print('KL-Divergence loss: ' + str(kl_divergence_value))

                print('')
                print('Generator loss: ' + str(generator_value))
                print('Discriminator loss: ' + str(discriminator_value))
                print('Discriminator accuracy: ' + str(discriminator_accuracy))

                print('..........')

                # write to log
                with open(FLAGS.log_file, 'a') as log:
                    log.write('{0}\t{1}\t{2}\t{3}\t{4}\t{5}\t{6}\n'.format(
                        epoch,
                        iteration,
                        reconstruction_value,
                        kl_divergence_value,
                        generator_value,
                        discriminator_value,
                        discriminator_accuracy
                    ))

                # write to tensorboard
                writer.add_scalar('Reconstruction loss', reconstruction_value, global_step)
                writer.add_scalar('KL-Divergence loss', kl_divergence_value, global_step)
                writer.add_scalar('Generator loss', generator_value, global_step)
                writer.add_scalar('Discriminator loss', discriminator_value, global_step)
                writer.add_scalar('Discriminator accuracy', discriminator_accuracy * 100, global_step)

        # save model after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(), os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(), os.path.join('checkpoints', FLAGS.decoder_save))
            torch.save(discriminator.state_dict(), os.path.join('checkpoints', FLAGS.discriminator_save))
def training_procedure(FLAGS):
    """Train an encoder/decoder with swapped-code reconstruction and a pair discriminator on paired MNIST.

    The encoder maps an image to a pair of codes ``(nv, nc)``. Each iteration:
      A. reconstructs X_1 from (nv_1, nc_2) and X_2 from (nv_2, nc_1), where
         (X_1, X_2) is a pair from the loader; the optimizer step runs only
         if ``FLAGS.train_auto_encoder``;
      B.a runs ``FLAGS.discriminator_times`` discriminator updates on real
          pairs (X_1, X_2) vs. fakes (X_1, decoder(nv_3, nc_1)); the step
          runs only if ``FLAGS.train_discriminator`` and the accuracy is
          below ``FLAGS.discriminator_limiting_accuracy``;
      B.b runs ``FLAGS.generator_times`` generator updates that push those
          fakes towards the real label (step only if ``FLAGS.train_generator``).
    Progress is printed/logged every 10 iterations. Every 5 epochs and at the
    last epoch, weights are saved to ``checkpoints/`` and sample image grids
    are written with ``imshow_grid``.
    """
    """ model definition """
    encoder = Encoder(nv_dim=FLAGS.nv_dim, nc_dim=FLAGS.nc_dim)
    encoder.apply(weights_init)

    decoder = Decoder(nv_dim=FLAGS.nv_dim, nc_dim=FLAGS.nc_dim)
    decoder.apply(weights_init)

    discriminator = Discriminator()
    discriminator.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))
        discriminator.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.discriminator_save)))

    """ variable definition """
    real_domain_labels = 1
    fake_domain_labels = 0

    # Pre-allocated input buffers, refilled in place with copy_() every step.
    X_1 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels,
                            FLAGS.image_size, FLAGS.image_size)
    X_2 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels,
                            FLAGS.image_size, FLAGS.image_size)
    X_3 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels,
                            FLAGS.image_size, FLAGS.image_size)

    domain_labels = torch.LongTensor(FLAGS.batch_size)

    """ loss definitions """
    cross_entropy_loss = nn.CrossEntropyLoss()

    ''' add option to run on GPU '''
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()
        discriminator.cuda()

        cross_entropy_loss.cuda()

        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()

        domain_labels = domain_labels.cuda()

    """ optimizer definition """
    auto_encoder_optimizer = optim.Adam(list(encoder.parameters()) +
                                        list(decoder.parameters()),
                                        lr=FLAGS.initial_learning_rate,
                                        betas=(FLAGS.beta_1, FLAGS.beta_2))

    discriminator_optimizer = optim.Adam(list(discriminator.parameters()),
                                         lr=FLAGS.initial_learning_rate,
                                         betas=(FLAGS.beta_1, FLAGS.beta_2))

    # Same parameters as auto_encoder_optimizer, but separate Adam state.
    generator_optimizer = optim.Adam(list(encoder.parameters()) +
                                     list(decoder.parameters()),
                                     lr=FLAGS.initial_learning_rate,
                                     betas=(FLAGS.beta_1, FLAGS.beta_2))

    """ training """
    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    if not os.path.exists('reconstructed_images'):
        os.makedirs('reconstructed_images')

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write('Epoch\tIteration\tReconstruction_loss\t')
            log.write(
                'Generator_loss\tDiscriminator_loss\tDiscriminator_accuracy\n')

    # load data set and create data loader instance
    print('Loading MNIST paired dataset...')
    paired_mnist = MNIST_Paired(root='mnist',
                                download=True,
                                train=True,
                                transform=transform_config)
    loader = cycle(
        DataLoader(paired_mnist,
                   batch_size=FLAGS.batch_size,
                   shuffle=True,
                   num_workers=0,
                   drop_last=True))
    iterations_per_epoch = int(len(paired_mnist) / FLAGS.batch_size)

    # initialise variables
    discriminator_accuracy = 0.

    # initialize summary writer
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print(
            'Epoch #' + str(epoch) +
            '..........................................................................'
        )

        for iteration in range(iterations_per_epoch):
            # A. run the auto-encoder reconstruction
            image_batch_1, image_batch_2, labels_batch_1 = next(loader)

            auto_encoder_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            nv_1, nc_1 = encoder(Variable(X_1))
            nv_2, nc_2 = encoder(Variable(X_2))

            # Each image is rebuilt from its own nv and its partner's nc.
            reconstructed_X_1 = decoder(nv_1, nc_2)
            reconstructed_X_2 = decoder(nv_2, nc_1)

            reconstruction_error_1 = mse_loss(reconstructed_X_1, Variable(X_1))
            reconstruction_error_1.backward(retain_graph=True)

            reconstruction_error_2 = mse_loss(reconstructed_X_2, Variable(X_2))
            reconstruction_error_2.backward()

            reconstruction_error = reconstruction_error_1 + reconstruction_error_2

            if FLAGS.train_auto_encoder:
                auto_encoder_optimizer.step()

            # B. run the adversarial part of the architecture

            # B. a) run the discriminator
            for i in range(FLAGS.discriminator_times):

                discriminator_optimizer.zero_grad()

                # train discriminator on real data
                domain_labels.fill_(real_domain_labels)

                image_batch_1, image_batch_2, labels_batch_1 = next(loader)

                X_1.copy_(image_batch_1)
                X_2.copy_(image_batch_2)

                real_output = discriminator(Variable(X_1), Variable(X_2))

                discriminator_real_error = FLAGS.disc_coef * cross_entropy_loss(
                    real_output, Variable(domain_labels))
                discriminator_real_error.backward()

                # train discriminator on fake data
                domain_labels.fill_(fake_domain_labels)

                image_batch_3, _, labels_batch_3 = next(loader)
                X_3.copy_(image_batch_3)

                nv_3, nc_3 = encoder(Variable(X_3))

                # reconstruction is taking common factor from X_1 and varying factor from X_3
                reconstructed_X_3_1 = decoder(nv_3, encoder(Variable(X_1))[1])

                fake_output = discriminator(Variable(X_1), reconstructed_X_3_1)

                discriminator_fake_error = FLAGS.disc_coef * cross_entropy_loss(
                    fake_output, Variable(domain_labels))
                discriminator_fake_error.backward()

                # total discriminator error
                discriminator_error = discriminator_real_error + discriminator_fake_error

                # calculate discriminator accuracy for this step
                target_true_labels = torch.cat((torch.ones(
                    FLAGS.batch_size), torch.zeros(FLAGS.batch_size)), dim=0)
                if FLAGS.cuda:
                    target_true_labels = target_true_labels.cuda()

                discriminator_predictions = torch.cat(
                    (real_output, fake_output), dim=0)
                _, discriminator_predictions = torch.max(
                    discriminator_predictions, 1)

                discriminator_accuracy = (discriminator_predictions.data ==
                                          target_true_labels.long()).sum(
                                          ).item() / (FLAGS.batch_size * 2)

                if discriminator_accuracy < FLAGS.discriminator_limiting_accuracy and FLAGS.train_discriminator:
                    discriminator_optimizer.step()

            # B. b) run the generator
            for i in range(FLAGS.generator_times):

                generator_optimizer.zero_grad()

                image_batch_1, _, labels_batch_1 = next(loader)
                image_batch_3, __, labels_batch_3 = next(loader)

                domain_labels.fill_(real_domain_labels)
                X_1.copy_(image_batch_1)
                X_3.copy_(image_batch_3)

                nv_3, nc_3 = encoder(Variable(X_3))

                # reconstruction is taking common factor from X_1 and varying factor from X_3
                reconstructed_X_3_1 = decoder(nv_3, encoder(Variable(X_1))[1])

                output = discriminator(Variable(X_1), reconstructed_X_3_1)

                generator_error = FLAGS.gen_coef * cross_entropy_loss(
                    output, Variable(domain_labels))
                generator_error.backward()

                if FLAGS.train_generator:
                    generator_optimizer.step()

            # print progress after 10 iterations
            if (iteration + 1) % 10 == 0:
                # .item() reads the scalar value directly (no storage access).
                reconstruction_value = reconstruction_error.item()
                generator_value = generator_error.item()
                discriminator_value = discriminator_error.item()
                global_step = epoch * (iterations_per_epoch + 1) + iteration

                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))

                print('')
                print('Reconstruction loss: ' + str(reconstruction_value))
                print('Generator loss: ' + str(generator_value))

                print('')
                print('Discriminator loss: ' + str(discriminator_value))
                print('Discriminator accuracy: ' + str(discriminator_accuracy))

                print('..........')

                # write to log
                with open(FLAGS.log_file, 'a') as log:
                    log.write('{0}\t{1}\t{2}\t{3}\t{4}\t{5}\n'.format(
                        epoch, iteration,
                        reconstruction_value,
                        generator_value,
                        discriminator_value,
                        discriminator_accuracy))

                # write to tensorboard
                writer.add_scalar('Reconstruction loss', reconstruction_value, global_step)
                writer.add_scalar('Generator loss', generator_value, global_step)
                writer.add_scalar('Discriminator loss', discriminator_value, global_step)

        # save model after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(),
                       os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(),
                       os.path.join('checkpoints', FLAGS.decoder_save))
            torch.save(discriminator.state_dict(),
                       os.path.join('checkpoints', FLAGS.discriminator_save))
            """ save reconstructed images and style swapped image generations to check progress """
            image_batch_1, image_batch_2, labels_batch_1 = next(loader)
            image_batch_3, _, __ = next(loader)

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)
            X_3.copy_(image_batch_3)

            # Visualisation only: no autograd graph is needed.
            with torch.no_grad():
                nv_1, nc_1 = encoder(Variable(X_1))
                nv_2, nc_2 = encoder(Variable(X_2))
                nv_3, nc_3 = encoder(Variable(X_3))

                reconstructed_X_1 = decoder(nv_1, nc_2)
                reconstructed_X_3_2 = decoder(nv_3, nc_2)

            # save input image batch (grey channel repeated 3x for RGB display)
            image_batch = np.transpose(X_1.cpu().numpy(), (0, 2, 3, 1))
            image_batch = np.concatenate(
                (image_batch, image_batch, image_batch), axis=3)
            imshow_grid(image_batch, name=str(epoch) + '_original', save=True)

            # save reconstructed batch
            reconstructed_x = np.transpose(
                reconstructed_X_1.cpu().numpy(), (0, 2, 3, 1))
            reconstructed_x = np.concatenate(
                (reconstructed_x, reconstructed_x, reconstructed_x), axis=3)
            imshow_grid(reconstructed_x, name=str(epoch) + '_target', save=True)

            # save cross reconstructed batch
            style_batch = np.transpose(X_3.cpu().numpy(), (0, 2, 3, 1))
            style_batch = np.concatenate(
                (style_batch, style_batch, style_batch), axis=3)
            imshow_grid(style_batch, name=str(epoch) + '_style', save=True)

            reconstructed_style = np.transpose(
                reconstructed_X_3_2.cpu().numpy(), (0, 2, 3, 1))
            reconstructed_style = np.concatenate(
                (reconstructed_style, reconstructed_style,
                 reconstructed_style),
                axis=3)
            imshow_grid(reconstructed_style,
                        name=str(epoch) + '_style_target',
                        save=True)
# Move the encoder and a fresh Generator to `device`, restore both from a saved
# epoch and start running them over `val_data` without gradients.
E.to(device)
G = Generator(n_classes)
G.to(device)
if args.multi_gpu:
    # If trained with multi-GPU, the model needs to be loaded with multi-GPU, too
    # (the DataParallel wrapper has to exist before load_state_dict so the saved
    # parameter names match).
    E = nn.DataParallel(E)
    G = nn.DataParallel(G)
    # G = convert_model(G)

# Load from checkpoints
load_epoch = args.test_epoch
if load_epoch is None:  # Use the latest model
    # Checkpoint files are named '<NNN>.<net>.pth'; take the largest numeric prefix.
    load_epoch = max(int(path.split('.')[0]) for path in listdir(checkpoint_path) if path.split('.')[0].isdigit())
print('Loading generator from epoch {:03d}'.format(load_epoch))
# map_location returning `storage` deserialises the tensors onto the CPU first.
E.load_state_dict(torch.load(
    join(checkpoint_path, '{:03d}.E.pth'.format(load_epoch)),
    map_location=lambda storage, loc: storage
))
G.load_state_dict(torch.load(
    join(checkpoint_path, '{:03d}.G.pth'.format(load_epoch)),
    map_location=lambda storage, loc: storage
))
E.eval()
G.eval()

with torch.no_grad():
    for batch_idx, (reals, annos) in enumerate(tqdm(val_data)):
        reals, annos = reals.to(device), annos.to(device)
        # One-hot annotation maps, cast to the image tensor's dtype/device.
        annos_onehot = onehot2d(annos, n_classes).type_as(reals)
        # Encode images and sample latents
        mu, logvar = E(reals)
def training_procedure(FLAGS):
    """Train the encoder/decoder pair, with optional adversarial losses, on paired CIFAR images.

    Every iteration performs:

    A.   Forward cycle: encode the two images of a pair drawn from
         ``CIFAR_Paired`` (presumably images of the same class), decode each
         image's style latent with the *other* image's class latent, and
         minimise the weighted reconstruction (``mse_loss``) and KL losses.
    A-1. If ``FLAGS.forward_gan``: one decoder ("generator") update and one
         discriminator update on those swapped decodings.
    B.   Reverse cycle: decode one N(0, 1) style sample with the class latents
         of two images, re-encode both decodings and minimise ``l1_loss``
         between the two recovered style latents (updates the encoder only).
    B-1. If ``FLAGS.reverse_gan``: the same adversarial updates on the
         reverse-cycle decodings.

    Losses are printed every 10 iterations and appended to ``FLAGS.log_file``
    and TensorBoard on every iteration. Every 5 epochs, and after the last
    epoch, the encoder/decoder weights are written to ``checkpoints/`` and
    decodings of a fixed sample batch are saved with ``imshow_grid``.

    Args:
        FLAGS: parsed options. Attributes read include style_dim, class_dim,
            batch_size, num_channels, image_size, initial_learning_rate,
            beta_1, beta_2, kl_divergence_coef, reconstruction_coef,
            reverse_cycle_coef, start_epoch, end_epoch, cuda, load_saved,
            forward_gan, reverse_gan, log_file, encoder_save and decoder_save.

    Raises:
        NotImplementedError: if ``FLAGS.load_saved`` is set; resuming from
            saved weights is not implemented for this setup.
    """
    if FLAGS.load_saved:
        # Loading saved weights is not supported here; fail before building anything.
        raise NotImplementedError('This is not implemented')

    # ---- model definition ----
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    discriminator = Discriminator()
    discriminator.apply(weights_init)

    # ---- reusable input buffers (refilled in place with copy_) ----
    image_shape = (FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    X_1 = torch.FloatTensor(*image_shape)
    X_2 = torch.FloatTensor(*image_shape)
    X_3 = torch.FloatTensor(*image_shape)
    style_latent_space = torch.FloatTensor(FLAGS.batch_size, FLAGS.style_dim)

    # ---- loss definition ----
    adversarial_loss = nn.BCELoss()

    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()
        discriminator.cuda()
        adversarial_loss.cuda()

        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()
        style_latent_space = style_latent_space.cuda()

    # ---- optimizers and schedulers ----
    betas = (FLAGS.beta_1, FLAGS.beta_2)
    auto_encoder_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate, betas=betas)
    reverse_cycle_optimizer = optim.Adam(
        list(encoder.parameters()), lr=FLAGS.initial_learning_rate, betas=betas)
    generator_optimizer = optim.Adam(
        list(decoder.parameters()), lr=FLAGS.initial_learning_rate, betas=betas)
    discriminator_optimizer = optim.Adam(
        list(discriminator.parameters()), lr=FLAGS.initial_learning_rate, betas=betas)

    # Divide every learning rate by a factor of 10 after 80 epochs.
    schedulers = [
        optim.lr_scheduler.StepLR(optimizer, step_size=80, gamma=0.1)
        for optimizer in (auto_encoder_optimizer, reverse_cycle_optimizer,
                          generator_optimizer, discriminator_optimizer)
    ]

    # Discriminator ground truths; they never change, so build them once.
    # NOTE(review): assumes the discriminator outputs (batch_size, 1) probabilities.
    Tensor = torch.cuda.FloatTensor if FLAGS.cuda else torch.FloatTensor
    real_labels = Tensor(FLAGS.batch_size, 1).fill_(1.0)
    fake_labels = Tensor(FLAGS.batch_size, 1).fill_(0.0)

    num_pixels = FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size

    def kl_loss(mu, logvar):
        # Weighted KL(N(mu, exp(logvar)) || N(0, I)) summed over the batch,
        # normalised by the number of input pixels in the batch.
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return FLAGS.kl_divergence_coef * kl / num_pixels

    def adversarial_step(real_1, real_2, fake_1, fake_2):
        # One decoder ("generator") update followed by one discriminator update.
        # fake_1 / fake_2 must still be attached to the decoder's graph so the
        # generator loss reaches the decoder parameters; they are detached for
        # the discriminator update. Returns detached (generator, discriminator) losses.
        generator_optimizer.zero_grad()
        gen_loss = (adversarial_loss(discriminator(fake_1), real_labels)
                    + adversarial_loss(discriminator(fake_2), real_labels)) / 2.0
        gen_loss.backward()
        generator_optimizer.step()

        # Also clears the discriminator gradients left over from gen_loss.backward().
        discriminator_optimizer.zero_grad()
        dis_loss = (adversarial_loss(discriminator(real_1), real_labels)
                    + adversarial_loss(discriminator(real_2), real_labels)
                    + adversarial_loss(discriminator(fake_1.detach()), fake_labels)
                    + adversarial_loss(discriminator(fake_2.detach()), fake_labels)) / 4.0
        dis_loss.backward()
        discriminator_optimizer.step()
        return gen_loss.detach(), dis_loss.detach()

    def to_display(images):
        # (N, C, H, W) tensor -> channels-last numpy batch for imshow_grid;
        # grayscale input is replicated to three channels.
        batch = np.transpose(images.detach().cpu().numpy(), (0, 2, 3, 1))
        if FLAGS.num_channels == 1:
            batch = np.concatenate((batch, batch, batch), axis=3)
        return batch

    # ---- training setup ----
    if torch.cuda.is_available() and not FLAGS.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    os.makedirs('checkpoints', exist_ok=True)
    os.makedirs('reconstructed_images', exist_ok=True)

    # Training always starts fresh here (load_saved is rejected above), so
    # start a new log. Column order matches the rows written below.
    with open(FLAGS.log_file, 'w') as log:
        headers = ['Epoch', 'Iteration', 'Reconstruction_loss', 'KL_divergence_loss', 'Reverse_cycle_loss']
        if FLAGS.forward_gan:
            headers.extend(['Generator_forward_loss', 'Discriminator_forward_loss'])
        if FLAGS.reverse_gan:
            headers.extend(['Generator_reverse_loss', 'Discriminator_reverse_loss'])
        log.write('\t'.join(headers) + '\n')

    # load data set and create data loader instance
    print('Loading CIFAR paired dataset...')
    paired_cifar = CIFAR_Paired(root='cifar', download=True, train=True, transform=transform_config)
    loader = cycle(DataLoader(paired_cifar, batch_size=FLAGS.batch_size, shuffle=True,
                              num_workers=0, drop_last=True))
    iters_per_epoch = len(paired_cifar) // FLAGS.batch_size

    # Fixed batches for visualization, so decodings are comparable across epochs.
    image_sample_1, image_sample_2, _ = next(loader)
    image_sample_3, _, _ = next(loader)

    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print('Epoch #' + str(epoch) + '..........................................................................')

        for iteration in range(iters_per_epoch):
            # A. forward cycle: swap class latents between the two images of each pair
            image_batch_1, image_batch_2, _ = next(loader)

            auto_encoder_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_mu_1, style_logvar_1, class_latent_space_1 = encoder(X_1)
            style_latent_space_1 = reparameterize(training=True, mu=style_mu_1, logvar=style_logvar_1)
            kl_divergence_loss_1 = kl_loss(style_mu_1, style_logvar_1)

            style_mu_2, style_logvar_2, class_latent_space_2 = encoder(X_2)
            style_latent_space_2 = reparameterize(training=True, mu=style_mu_2, logvar=style_logvar_2)
            kl_divergence_loss_2 = kl_loss(style_mu_2, style_logvar_2)

            reconstructed_X_1 = decoder(style_latent_space_1, class_latent_space_2)
            reconstructed_X_2 = decoder(style_latent_space_2, class_latent_space_1)

            reconstruction_error_1 = FLAGS.reconstruction_coef * mse_loss(reconstructed_X_1, X_1)
            reconstruction_error_2 = FLAGS.reconstruction_coef * mse_loss(reconstructed_X_2, X_2)

            # One backward over the summed losses yields the same gradients as
            # separate backward() calls, without retaining the graph.
            (kl_divergence_loss_1 + kl_divergence_loss_2
             + reconstruction_error_1 + reconstruction_error_2).backward()
            auto_encoder_optimizer.step()

            # Un-weighted values, for logging only.
            reconstruction_error = (
                (reconstruction_error_1 + reconstruction_error_2).detach() / FLAGS.reconstruction_coef)
            kl_divergence_error = (
                (kl_divergence_loss_1 + kl_divergence_loss_2).detach() / FLAGS.kl_divergence_coef)

            # A-1. adversarial training on the forward-cycle decodings
            if FLAGS.forward_gan:
                # Re-decode from detached latents. Wrapping the earlier output in
                # Variable(...) detached it from the decoder, and that graph is
                # freed and stale after the auto-encoder step anyway.
                fake_X_1 = decoder(style_latent_space_1.detach(), class_latent_space_2.detach())
                fake_X_2 = decoder(style_latent_space_2.detach(), class_latent_space_1.detach())
                gen_f_loss, dis_f_loss = adversarial_step(X_1, X_2, fake_X_1, fake_X_2)

            # B. reverse cycle: one random style latent decoded with two class latents
            image_batch_1, _, __ = next(loader)
            image_batch_2, _, __ = next(loader)

            reverse_cycle_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_latent_space.normal_(0., 1.)

            # The class latents are fixed inputs here, so no graph is needed.
            with torch.no_grad():
                _, __, class_latent_space_1 = encoder(X_1)
                _, __, class_latent_space_2 = encoder(X_2)

            reconstructed_X_1 = decoder(style_latent_space, class_latent_space_1)
            reconstructed_X_2 = decoder(style_latent_space, class_latent_space_2)

            style_mu_1, style_logvar_1, _ = encoder(reconstructed_X_1)
            style_latent_space_1 = reparameterize(training=False, mu=style_mu_1, logvar=style_logvar_1)

            style_mu_2, style_logvar_2, _ = encoder(reconstructed_X_2)
            style_latent_space_2 = reparameterize(training=False, mu=style_mu_2, logvar=style_logvar_2)

            reverse_cycle_loss = FLAGS.reverse_cycle_coef * l1_loss(style_latent_space_1, style_latent_space_2)
            reverse_cycle_loss.backward()
            reverse_cycle_optimizer.step()
            reverse_cycle_error = reverse_cycle_loss.detach() / FLAGS.reverse_cycle_coef

            # B-1. adversarial training on the reverse-cycle decodings
            if FLAGS.reverse_gan:
                # Same re-decode as in A-1: the graph above was freed by backward().
                fake_X_1 = decoder(style_latent_space, class_latent_space_1)
                fake_X_2 = decoder(style_latent_space, class_latent_space_2)
                gen_r_loss, dis_r_loss = adversarial_step(X_1, X_2, fake_X_1, fake_X_2)

            # ---- reporting ----
            scalars = [
                ('Reconstruction loss', reconstruction_error.item()),
                ('KL-Divergence loss', kl_divergence_error.item()),
                ('Reverse cycle loss', reverse_cycle_error.item()),
            ]
            if FLAGS.forward_gan:
                scalars += [('Generator F loss', gen_f_loss.item()),
                            ('Discriminator F loss', dis_f_loss.item())]
            if FLAGS.reverse_gan:
                scalars += [('Generator R loss', gen_r_loss.item()),
                            ('Discriminator R loss', dis_r_loss.item())]

            if (iteration + 1) % 10 == 0:
                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))
                print('')
                for tag, value in scalars:
                    print(tag + ': ' + str(value))

            # write to log (same column order as the header)
            with open(FLAGS.log_file, 'a') as log:
                row = [epoch, iteration] + [value for _, value in scalars]
                log.write('\t'.join(str(x) for x in row) + '\n')

            # write to tensorboard
            global_step = epoch * (iters_per_epoch + 1) + iteration
            for tag, value in scalars:
                writer.add_scalar(tag, value, global_step)

        # Since PyTorch 1.1, scheduler.step() belongs after the epoch's
        # optimizer.step() calls. Stepping first skips the initial LR value and
        # decays the LR one epoch early.
        for scheduler in schedulers:
            scheduler.step()

        # save model after every 5 epochs (and after the last one)
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(), os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(), os.path.join('checkpoints', FLAGS.decoder_save))

            # Save decodings and style-swapped generations of the fixed sample batch.
            X_1.copy_(image_sample_1)
            X_2.copy_(image_sample_2)
            X_3.copy_(image_sample_3)

            with torch.no_grad():
                style_mu_1, style_logvar_1, _ = encoder(X_1)
                _, __, class_latent_space_2 = encoder(X_2)
                style_mu_3, style_logvar_3, _ = encoder(X_3)

                style_latent_space_1 = reparameterize(training=False, mu=style_mu_1, logvar=style_logvar_1)
                style_latent_space_3 = reparameterize(training=False, mu=style_mu_3, logvar=style_logvar_3)

                reconstructed_X_1_2 = decoder(style_latent_space_1, class_latent_space_2)
                reconstructed_X_3_2 = decoder(style_latent_space_3, class_latent_space_2)

            imshow_grid(to_display(X_1), name=str(epoch) + '_original', save=True)
            imshow_grid(to_display(reconstructed_X_1_2), name=str(epoch) + '_target', save=True)
            imshow_grid(to_display(X_3), name=str(epoch) + '_style', save=True)
            imshow_grid(to_display(reconstructed_X_3_2), name=str(epoch) + '_style_target', save=True)
def training_procedure(FLAGS):
    """Train the encoder/decoder pair on paired MNIST images.

    Every iteration performs:

    A. Forward cycle: encode the two images of a pair drawn from
       ``MNIST_Paired`` (presumably images of the same class), decode each
       image's style latent with the *other* image's class latent, and
       minimise the weighted reconstruction (``mse_loss``) and KL losses.
    B. Reverse cycle: decode one N(0, 1) style sample with the class latents of
       two images, re-encode both decodings and minimise ``l1_loss`` between
       the two recovered style latents (updates the encoder only).

    Losses are printed every 10 iterations and appended to ``FLAGS.log_file``
    and TensorBoard on every iteration. Every 5 epochs, and after the last
    epoch, the encoder/decoder weights are written to ``checkpoints/`` and
    decodings of a fixed sample batch are saved with ``imshow_grid``.

    Args:
        FLAGS: parsed options. Attributes read include style_dim, class_dim,
            batch_size, num_channels, image_size, initial_learning_rate,
            beta_1, beta_2, kl_divergence_coef, reconstruction_coef,
            reverse_cycle_coef, start_epoch, end_epoch, cuda, load_saved,
            log_file, encoder_save and decoder_save.
    """
    # ---- model definition ----
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true; map to CPU first so
    # GPU-saved checkpoints also load on CPU-only machines (moved below if --cuda)
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save), map_location='cpu'))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save), map_location='cpu'))

    # ---- reusable input buffers (refilled in place with copy_) ----
    image_shape = (FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    X_1 = torch.FloatTensor(*image_shape)
    X_2 = torch.FloatTensor(*image_shape)
    X_3 = torch.FloatTensor(*image_shape)
    style_latent_space = torch.FloatTensor(FLAGS.batch_size, FLAGS.style_dim)

    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()

        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()
        style_latent_space = style_latent_space.cuda()

    # ---- optimizers and schedulers ----
    betas = (FLAGS.beta_1, FLAGS.beta_2)
    auto_encoder_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate, betas=betas)
    reverse_cycle_optimizer = optim.Adam(
        list(encoder.parameters()), lr=FLAGS.initial_learning_rate, betas=betas)

    # Divide the learning rate by a factor of 10 after 80 epochs.
    auto_encoder_scheduler = optim.lr_scheduler.StepLR(auto_encoder_optimizer, step_size=80, gamma=0.1)
    reverse_cycle_scheduler = optim.lr_scheduler.StepLR(reverse_cycle_optimizer, step_size=80, gamma=0.1)

    num_pixels = FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size

    def kl_loss(mu, logvar):
        # Weighted KL(N(mu, exp(logvar)) || N(0, I)) summed over the batch,
        # normalised by the number of input pixels in the batch.
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return FLAGS.kl_divergence_coef * kl / num_pixels

    def to_display(images):
        # (N, C, H, W) tensor -> channels-last numpy batch for imshow_grid;
        # grayscale input is replicated to three channels.
        batch = np.transpose(images.detach().cpu().numpy(), (0, 2, 3, 1))
        if FLAGS.num_channels == 1:
            batch = np.concatenate((batch, batch, batch), axis=3)
        return batch

    # ---- training setup ----
    if torch.cuda.is_available() and not FLAGS.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    os.makedirs('checkpoints', exist_ok=True)
    os.makedirs('reconstructed_images', exist_ok=True)

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write('Epoch\tIteration\tReconstruction_loss\tKL_divergence_loss\tReverse_cycle_loss\n')

    # load data set and create data loader instance
    print('Loading MNIST paired dataset...')
    paired_mnist = MNIST_Paired(root='mnist', download=True, train=True, transform=transform_config)
    loader = cycle(DataLoader(paired_mnist, batch_size=FLAGS.batch_size, shuffle=True,
                              num_workers=0, drop_last=True))
    iters_per_epoch = len(paired_mnist) // FLAGS.batch_size

    # Fixed batches for visualization, so decodings are comparable across epochs.
    image_sample_1, image_sample_2, _ = next(loader)
    image_sample_3, _, _ = next(loader)

    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print('Epoch #' + str(epoch) + '..........................................................................')

        for iteration in range(iters_per_epoch):
            # A. forward cycle: swap class latents between the two images of each pair
            image_batch_1, image_batch_2, _ = next(loader)

            auto_encoder_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_mu_1, style_logvar_1, class_latent_space_1 = encoder(X_1)
            style_latent_space_1 = reparameterize(training=True, mu=style_mu_1, logvar=style_logvar_1)
            kl_divergence_loss_1 = kl_loss(style_mu_1, style_logvar_1)

            style_mu_2, style_logvar_2, class_latent_space_2 = encoder(X_2)
            style_latent_space_2 = reparameterize(training=True, mu=style_mu_2, logvar=style_logvar_2)
            kl_divergence_loss_2 = kl_loss(style_mu_2, style_logvar_2)

            reconstructed_X_1 = decoder(style_latent_space_1, class_latent_space_2)
            reconstructed_X_2 = decoder(style_latent_space_2, class_latent_space_1)

            reconstruction_error_1 = FLAGS.reconstruction_coef * mse_loss(reconstructed_X_1, X_1)
            reconstruction_error_2 = FLAGS.reconstruction_coef * mse_loss(reconstructed_X_2, X_2)

            # One backward over the summed losses yields the same gradients as
            # separate backward() calls, without retaining the graph.
            (kl_divergence_loss_1 + kl_divergence_loss_2
             + reconstruction_error_1 + reconstruction_error_2).backward()
            auto_encoder_optimizer.step()

            # Un-weighted values, for logging only.
            reconstruction_error = (
                (reconstruction_error_1 + reconstruction_error_2).detach() / FLAGS.reconstruction_coef)
            kl_divergence_error = (
                (kl_divergence_loss_1 + kl_divergence_loss_2).detach() / FLAGS.kl_divergence_coef)

            # B. reverse cycle: one random style latent decoded with two class latents
            image_batch_1, _, __ = next(loader)
            image_batch_2, _, __ = next(loader)

            reverse_cycle_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_latent_space.normal_(0., 1.)

            # The class latents are fixed inputs here, so no graph is needed.
            with torch.no_grad():
                _, __, class_latent_space_1 = encoder(X_1)
                _, __, class_latent_space_2 = encoder(X_2)

            reconstructed_X_1 = decoder(style_latent_space, class_latent_space_1)
            reconstructed_X_2 = decoder(style_latent_space, class_latent_space_2)

            style_mu_1, style_logvar_1, _ = encoder(reconstructed_X_1)
            style_latent_space_1 = reparameterize(training=False, mu=style_mu_1, logvar=style_logvar_1)

            style_mu_2, style_logvar_2, _ = encoder(reconstructed_X_2)
            style_latent_space_2 = reparameterize(training=False, mu=style_mu_2, logvar=style_logvar_2)

            reverse_cycle_loss = FLAGS.reverse_cycle_coef * l1_loss(style_latent_space_1, style_latent_space_2)
            reverse_cycle_loss.backward()
            reverse_cycle_optimizer.step()
            reverse_cycle_error = reverse_cycle_loss.detach() / FLAGS.reverse_cycle_coef

            # ---- reporting ----
            recon_value = reconstruction_error.item()
            kl_value = kl_divergence_error.item()
            reverse_value = reverse_cycle_error.item()

            if (iteration + 1) % 10 == 0:
                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))
                print('')
                print('Reconstruction loss: ' + str(recon_value))
                print('KL-Divergence loss: ' + str(kl_value))
                print('Reverse cycle loss: ' + str(reverse_value))

            # write to log
            with open(FLAGS.log_file, 'a') as log:
                log.write('{0}\t{1}\t{2}\t{3}\t{4}\n'.format(
                    epoch, iteration, recon_value, kl_value, reverse_value))

            # write to tensorboard
            global_step = epoch * (iters_per_epoch + 1) + iteration
            writer.add_scalar('Reconstruction loss', recon_value, global_step)
            writer.add_scalar('KL-Divergence loss', kl_value, global_step)
            writer.add_scalar('Reverse cycle loss', reverse_value, global_step)

        # Since PyTorch 1.1, scheduler.step() belongs after the epoch's
        # optimizer.step() calls. Stepping first skips the initial LR value and
        # decays the LR one epoch early.
        auto_encoder_scheduler.step()
        reverse_cycle_scheduler.step()

        # save model after every 5 epochs (and after the last one)
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(), os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(), os.path.join('checkpoints', FLAGS.decoder_save))

            # Save decodings and style-swapped generations of the fixed sample batch.
            X_1.copy_(image_sample_1)
            X_2.copy_(image_sample_2)
            X_3.copy_(image_sample_3)

            with torch.no_grad():
                style_mu_1, style_logvar_1, _ = encoder(X_1)
                _, __, class_latent_space_2 = encoder(X_2)
                style_mu_3, style_logvar_3, _ = encoder(X_3)

                style_latent_space_1 = reparameterize(training=False, mu=style_mu_1, logvar=style_logvar_1)
                style_latent_space_3 = reparameterize(training=False, mu=style_mu_3, logvar=style_logvar_3)

                reconstructed_X_1_2 = decoder(style_latent_space_1, class_latent_space_2)
                reconstructed_X_3_2 = decoder(style_latent_space_3, class_latent_space_2)

            imshow_grid(to_display(X_1), name=str(epoch) + '_original', save=True)
            imshow_grid(to_display(reconstructed_X_1_2), name=str(epoch) + '_target', save=True)
            imshow_grid(to_display(X_3), name=str(epoch) + '_style', save=True)
            imshow_grid(to_display(reconstructed_X_3_2), name=str(epoch) + '_style_target', save=True)
print('Pre-Trained GloVe Model') else: print('Pre-Trained Baseline Model') else: encoder_checkpoint = torch.load(f'./checkpoints/encoder_{model_tag}', map_location='cpu') decoder_checkpoint = torch.load(f'./checkpoints/decoder_{model_tag}', map_location='cpu') if bert_model: print('Pre-Trained BERT Model') elif glove_model: print('Pre-Trained GloVe Model') else: print('Pre-Trained Baseline Model') encoder.load_state_dict(encoder_checkpoint['model_state_dict']) decoder_optimizer = torch.optim.Adam(params=decoder.parameters(), lr=decoder_lr) decoder.load_state_dict(decoder_checkpoint['model_state_dict']) decoder_optimizer.load_state_dict( decoder_checkpoint['optimizer_state_dict']) else: encoder = Encoder().to(device) decoder = Decoder(vocab_size=len(vocab), use_glove=glove_model, use_bert=bert_model, device=device, tokenizer=tokenizer, vocab=vocab, bert_model=BertModel, glove_vectors=glove_vectors).to(device)