def gradient_penalty(fake_x, origin_x, D, d_optimizer):
    """Apply one WGAN-GP gradient-penalty update step to discriminator ``D``.

    Interpolates randomly between real (``origin_x``) and fake (``fake_x``)
    samples, penalises the discriminator's gradient norm for deviating from 1,
    backpropagates the scaled penalty, and steps ``d_optimizer``.

    Returns the (unscaled) gradient-penalty term.

    NOTE(review): this is a free function yet it references ``self.lambda_gp``
    and ``self.reset_grad()`` below — as written it raises NameError unless it
    is actually bound as a method (missing ``self`` parameter?).  Confirm.
    NOTE(review): ``Variable``/``.data`` usage suggests pre-0.4 PyTorch;
    ``.cuda()`` assumes inputs already live on GPU — verify against callers.
    """
    # Per-sample mixing coefficient, broadcast over the remaining dims.
    alpha = torch.rand(origin_x.size(0), 1, 1, 1).cuda().expand_as(origin_x)
    interpolated = Variable(alpha * origin_x.data + (1 - alpha) * fake_x.data,
                            requires_grad=True)
    out = D(interpolated)
    # d out / d interpolated; create_graph=True so the penalty itself is
    # differentiable w.r.t. D's parameters.
    grad = torch.autograd.grad(outputs=out,
                               inputs=interpolated,
                               grad_outputs=torch.ones(out.size()).cuda(),
                               retain_graph=True,
                               create_graph=True,
                               only_inputs=True)[0]
    grad = grad.view(grad.size(0), -1)
    grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
    # Two-sided penalty: (||grad|| - 1)^2 averaged over the batch.
    d_loss_gp = torch.mean((grad_l2norm - 1)**2)
    # Backward + Optimize
    d_loss = self.lambda_gp * d_loss_gp
    self.reset_grad()
    d_loss.backward()
    d_optimizer.step()
    return d_loss_gp
def fitness_score(self, eval_fake_imgs, eval_real_imgs):
    """Compute quality and diversity fitness scores for a candidate generator.

    Args:
        eval_fake_imgs: batch of generated images fed to the discriminator.
        eval_real_imgs: batch of real images fed to the discriminator.

    Returns:
        (Fq, Fd) as numpy scalars:
        Fq — quality score: mean sigmoid of the discriminator's output on
             the fake batch (higher = fakes look more real to D).
        Fd — diversity score: negative log of the L2 norm of the
             discriminator-loss gradient w.r.t. D's parameters.
    """
    self.set_requires_grad(self.D, True)
    eval_fake = self.D(eval_fake_imgs)
    eval_real = self.D(eval_real_imgs)
    fake_loss = torch.mean(eval_fake)
    real_loss = -torch.mean(eval_real)
    D_loss_score = fake_loss + real_loss
    # Quality fitness score.
    # Fix: torch.sigmoid replaces nn.functional.sigmoid, which is deprecated
    # and emits a warning on every call.
    Fq = torch.sigmoid(eval_fake).data.mean().cpu().numpy()
    # Diversity fitness score — gradient of the critic loss w.r.t. all of
    # D's parameters.  Materialize the parameter generator as a list so the
    # same sequence could be reused safely.
    gradients = torch.autograd.grad(
        outputs=D_loss_score,
        inputs=list(self.D.parameters()),
        grad_outputs=torch.ones(D_loss_score.size()).to(self.device),
        create_graph=True,
        retain_graph=True,
        only_inputs=True)
    with torch.no_grad():
        # Fix: single torch.cat over flattened grads replaces the incremental
        # cat loop, which re-allocated on every step and left `allgrad`
        # unbound if `gradients` were ever empty.
        allgrad = torch.cat([g.view(-1) for g in gradients])
    Fd = -torch.log(torch.norm(allgrad)).data.cpu().numpy()
    return Fq, Fd
def make_step(grad, attack, step_size):
    """Turn a raw gradient into an attack/update step of size ``step_size``.

    attack == 'l2'  : step along the per-sample L2-normalised gradient.
    attack == 'inf' : step along the sign of the gradient (L-inf style).
    otherwise       : plain gradient scaled by ``step_size``.
    """
    if attack == 'l2':
        # Per-sample norm over flattened (C, H, W); reshape so it broadcasts
        # back over the image dims.  Tiny epsilon guards against divide-by-zero.
        per_sample_norm = torch.norm(
            grad.view(grad.shape[0], -1), dim=1).view(-1, 1, 1, 1)
        return step_size * (grad / (per_sample_norm + 1e-10))
    if attack == 'inf':
        return step_size * torch.sign(grad)
    return step_size * grad
def compute_grad_penalty(net_D, true_data, fake_data):
    """WGAN-GP gradient penalty for discriminator ``net_D``.

    Samples points uniformly on the line segment between each real and fake
    sample and returns mean((||grad_D|| - 1)^2) at those points.

    Args:
        net_D: discriminator/critic; called on a batch, output is summed.
        true_data: real batch, shape (B, C, H, W).
        fake_data: fake batch, same shape as ``true_data``.

    Returns:
        Scalar tensor — the mean squared deviation of the gradient norm from 1.
    """
    batch_size = true_data.shape[0]
    # One mixing coefficient per sample, broadcast over (C, H, W).
    epsilon = true_data.new(batch_size, 1, 1, 1)
    epsilon = epsilon.uniform_()
    # BUG FIX: the original weighted BOTH endpoints by (1 - epsilon), so the
    # coefficients did not sum to 1 and the "interpolation" was not on the
    # segment between real and fake.  Proper convex combination:
    line_data = true_data * epsilon + fake_data * (1 - epsilon)
    line_data = Parameter(line_data)
    line_pred = net_D(line_data).sum()
    # create_graph=True keeps the penalty differentiable w.r.t. D's weights.
    grad, = torch.autograd.grad(line_pred, line_data, create_graph=True)
    grad = grad.view(batch_size, -1)
    grad_norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
    return ((grad_norm - 1) ** 2).mean()
def Fvp(v):
    """Fisher-vector product: return (H_kl @ v) + damping * v, where H_kl is
    the Hessian of the mean policy KL — used by conjugate gradient in TRPO.

    NOTE(review): this free-standing copy references ``self`` and ``states``,
    which are neither parameters nor visible globals here — as a module-level
    function it raises NameError.  It appears to be a duplicate of the closure
    defined inside ``update_actor`` below, where both names are captured from
    the enclosing scope.  Confirm whether this copy is dead code.
    """
    kl = self.get_kl(states)
    kl = kl.mean()
    # First derivative of KL w.r.t. policy parameters (graph kept for the
    # second differentiation below).
    grads = torch.autograd.grad(kl, self.actor.parameters(), create_graph=True)
    flat_grad_kl = torch.cat([grad.view(-1) for grad in grads])
    # Directional derivative along v, then differentiate again: Hessian-vector
    # product without forming the Hessian.
    kl_v = (flat_grad_kl * Variable(v)).sum()
    grads = torch.autograd.grad(kl_v, self.actor.parameters())
    flat_grad_grad_kl = torch.cat(
        [grad.contiguous().view(-1) for grad in grads]).data
    # Damping keeps the product positive-definite / numerically stable.
    return flat_grad_grad_kl + v * self.damping
def update_actor(self, states, actions, advantages):
    """One TRPO policy update: natural-gradient step direction via conjugate
    gradient on the Fisher matrix, scaled to the KL trust region
    (``self.max_kl``), then accepted/shrunk by a backtracking line search.

    Args:
        states: batch of observed states.
        actions: actions taken in those states.
        advantages: advantage estimates for the surrogate loss.
    """
    action_means, action_log_stds, action_stds = self.actor(
        Variable(states))
    # Log-probabilities under the CURRENT policy, detached — the surrogate
    # loss is measured against these fixed values.
    fixed_log_prob = normal_log_density(Variable(actions), action_means,
                                        action_log_stds,
                                        action_stds).data.clone()
    loss = self.get_loss(states, actions, advantages, fixed_log_prob)
    grads = torch.autograd.grad(loss, self.actor.parameters())
    loss_grad = torch.cat([grad.view(-1) for grad in grads]).data

    def Fvp(v):
        """Fisher-vector product (KL Hessian @ v + damping * v) computed via
        double backward — avoids materialising the Fisher matrix."""
        kl = self.get_kl(states)
        kl = kl.mean()
        grads = torch.autograd.grad(kl, self.actor.parameters(),
                                    create_graph=True)
        flat_grad_kl = torch.cat([grad.view(-1) for grad in grads])
        kl_v = (flat_grad_kl * Variable(v)).sum()
        grads = torch.autograd.grad(kl_v, self.actor.parameters())
        flat_grad_grad_kl = torch.cat(
            [grad.contiguous().view(-1) for grad in grads]).data
        return flat_grad_grad_kl + v * self.damping

    # Approximately solve F @ step = -g with conjugate gradient.
    step_dir = conjugate_gradients(Fvp, -loss_grad, self.nsteps)
    # shs = 1/2 * step^T F step; lm scales the step so the quadratic KL
    # estimate equals max_kl at the full step.
    shs = 0.5 * (step_dir * Fvp(step_dir)).sum(0, keepdim=True)
    lm = torch.sqrt(shs / self.max_kl)
    fullstep = step_dir / lm[0]
    # Expected improvement rate used by the line search's acceptance test.
    neggdotstepdir = (-loss_grad * step_dir).sum(0, keepdim=True)
    print("lagrange multiplier:", lm[0], "grad_norm:", loss_grad.norm())
    success, new_params = self.linesearch(states, actions, advantages,
                                          fixed_log_prob, fullstep,
                                          neggdotstepdir / lm[0])
    set_flat_params_to(self.actor, new_params)
def train(self):
    """Train attribute-guided face image synthesis model.

    Alternates WGAN-GP discriminator updates (adversarial + attribute
    classification + gradient penalty) with encoder/decoder updates
    (adversarial, bidirectional reconstruction, latent-feature,
    identity and classification losses).  Periodically logs losses,
    synthesizes sample images, and checkpoints Enc/Dec/D.

    NOTE(review): ``Variable``/``volatile=True``/``.data[0]`` are pre-0.4
    PyTorch idioms; on modern PyTorch ``.data[0]`` must become ``.item()``
    (as the anomaly-detection trainer in this file already does) — confirm
    target torch version.
    """
    self.data_loader = self.face_data_loader
    # The number of iterations for each epoch
    iters_per_epoch = len(self.data_loader)
    sample_x = []
    sample_l = []
    real_y = []
    # Grab the first 3 batches as fixed debug/sample inputs.
    for i, (images, landmark) in enumerate(self.data_loader):
        labels = images[1]
        sample_x.append(images[0])
        sample_l.append(landmark[0])
        real_y.append(labels)
        if i == 2:
            break
    # Sample inputs and desired domain labels for testing
    sample_x = torch.cat(sample_x, dim=0)
    sample_x = self.to_var(sample_x, volatile=True)
    sample_l = torch.cat(sample_l, dim=0)
    sample_l = self.to_var(sample_l, volatile=True)
    real_y = torch.cat(real_y, dim=0)
    # One fixed one-hot target label per attribute class.
    sample_y_list = []
    for i in range(self.y_dim):
        sample_y = self.one_hot(
            torch.ones(sample_x.size(0)) * i, self.y_dim)
        sample_y_list.append(self.to_var(sample_y, volatile=True))
    # Learning rate for decaying
    d_lr = self.d_lr
    enc_lr = self.enc_lr
    dec_lr = self.dec_lr
    # Start with trained model (checkpoint name encodes the start epoch)
    if self.trained_model:
        start = int(self.trained_model.split('_')[0])
    else:
        start = 0
    # Start training
    start_time = time.time()
    for e in range(start, self.num_epochs):
        for i, (real_image, real_landmark) in enumerate(self.data_loader):
            #real_x: real image and real_l: conditional side image (landmark heatmap)
            real_x = real_image[0]
            real_label = real_image[1]
            real_l = real_landmark[0]
            # Sample fake (target-domain) labels randomly by permuting the batch
            rand_idx = torch.randperm(real_label.size(0))
            fake_label = real_label[rand_idx]
            real_y = self.one_hot(real_label, self.y_dim)
            fake_y = self.one_hot(fake_label, self.y_dim)
            # Convert tensor to variable
            real_x = self.to_var(real_x)
            real_l = self.to_var(real_l)
            real_y = self.to_var(real_y)
            fake_y = self.to_var(fake_y)
            real_label = self.to_var(real_label)
            fake_label = self.to_var(fake_label)
            #================== Train Discriminator ================== #
            # Input images (original image+side images) are concatenated
            src_output, cls_output = self.D(torch.cat([real_x, real_l], 1))
            d_loss_real = -torch.mean(src_output)
            d_loss_cls = F.cross_entropy(cls_output, real_label)
            # Compute expression recognition accuracy on synthetic images
            if (i + 1) % self.log_step == 0:
                accuracies = self.calculate_accuracy(
                    cls_output, real_label)
                log = [
                    "{:.2f}".format(acc)
                    for acc in accuracies.data.cpu().numpy()
                ]
                print('Recognition Acc: ')
                print(log)
            # Generate outputs and compute loss with fake generated images
            enc_feat = self.Enc(torch.cat([real_x, real_l], 1))
            fake_x, fake_l = self.Dec(enc_feat, fake_y)
            # Detach the fakes so D's update does not backprop into Enc/Dec.
            fake_x = Variable(fake_x.data)
            fake_l = Variable(fake_l.data)
            src_output, cls_output = self.D(torch.cat([fake_x, fake_l], 1))
            d_loss_fake = torch.mean(src_output)
            # Discriminator losses
            d_loss = self.lambda_cls * d_loss_cls + d_loss_real + d_loss_fake
            self.reset()
            d_loss.backward()
            self.d_optimizer.step()
            # Compute gradient penalty loss (separate backward/step pass)
            real = torch.cat([real_x, real_l], 1)
            fake = torch.cat([fake_x, fake_l], 1)
            alpha = torch.rand(real_x.size(0), 1, 1, 1).cuda().expand_as(real)
            interpolated = Variable(alpha * real.data + (1 - alpha) * fake.data,
                                    requires_grad=True)
            output, cls_output = self.D(interpolated)
            grad = torch.autograd.grad(outputs=output,
                                       inputs=interpolated,
                                       grad_outputs=torch.ones(
                                           output.size()).cuda(),
                                       retain_graph=True,
                                       create_graph=True,
                                       only_inputs=True)[0]
            grad = grad.view(grad.size(0), -1)
            grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
            d_loss_gp = torch.mean((grad_l2norm - 1)**2)
            # Gradient penalty loss
            d_loss = self.lambda_gp * d_loss_gp
            self.reset()
            d_loss.backward()
            self.d_optimizer.step()
            # Logging
            loss = {}
            loss['D/loss_real'] = d_loss_real.data[0]
            loss['D/loss_fake'] = d_loss_fake.data[0]
            loss['D/loss_cls'] = d_loss_cls.data[0]
            loss['D/loss_gp'] = d_loss_gp.data[0]
            # ================== Train Encoder-Decoder networks ================== #
            if (i + 1) % self.d_train_repeat == 0:
                # Original-to-target and target-to-original domain
                enc_feat = self.Enc(torch.cat([real_x, real_l], 1))
                fake_x, fake_l = self.Dec(enc_feat, fake_y)
                src_output, cls_output = self.D(
                    torch.cat([fake_x, fake_l], 1))
                g_loss_fake = -torch.mean(src_output)
                #rec_feat = self.Enc(fake_x)
                rec_feat = self.Enc(torch.cat([fake_x, fake_l], 1))
                rec_x, rec_l = self.Dec(rec_feat, real_y)
                # bidirectional loss of the images
                g_loss_rec_x = torch.mean(torch.abs(real_x - rec_x))
                g_loss_rec_l = torch.mean(torch.abs(real_l - rec_l))
                #bidirectional loss of the latent feature
                g_loss_feature = torch.mean(torch.abs(enc_feat - rec_feat))
                #identity loss of the images
                g_loss_identity_x = torch.mean(torch.abs(real_x - fake_x))
                g_loss_identity_l = torch.mean(torch.abs(real_l - fake_l))
                # attribute classification loss for the fake generated images
                g_loss_cls = F.cross_entropy(cls_output, fake_label)
                # Backward + Optimize (generator (encoder-decoder) losses), we update decoder two times for each encoder update
                g_loss = g_loss_fake + self.lambda_bi * g_loss_rec_x + self.lambda_bi * g_loss_rec_l + self.lambda_bi * g_loss_feature + self.lambda_id * g_loss_identity_x + self.lambda_id * g_loss_identity_l + self.lambda_cls * g_loss_cls
                self.reset()
                g_loss.backward()
                self.enc_optimizer.step()
                self.dec_optimizer.step()
                self.dec_optimizer.step()
                # Logging Generator losses
                loss['G/loss_feature'] = g_loss_feature.data[0]
                loss['G/loss_identity_x'] = g_loss_identity_x.data[0]
                loss['G/loss_identity_l'] = g_loss_identity_l.data[0]
                loss['G/loss_rec_x'] = g_loss_rec_x.data[0]
                loss['G/loss_rec_l'] = g_loss_rec_l.data[0]
                loss['G/loss_fake'] = g_loss_fake.data[0]
                loss['G/loss_cls'] = g_loss_cls.data[0]
            # Print out log
            if (i + 1) % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                    elapsed, e + 1, self.num_epochs, i + 1, iters_per_epoch)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)
                for tag, value in loss.items():
                    self.logger.scalar_summary(tag, value,
                                               e * iters_per_epoch + i + 1)
            # Synthesize images for each target attribute and save a grid
            if (i + 1) % self.sample_step == 0:
                fake_image_list = [sample_x]
                for sample_y in sample_y_list:
                    enc_feat = self.Enc(torch.cat([sample_x, sample_l], 1))
                    sample_result, sample_landmark = self.Dec(
                        enc_feat, sample_y)
                    fake_image_list.append(sample_result)
                fake_images = torch.cat(fake_image_list, dim=3)
                save_image(self.denorm(fake_images.data),
                           os.path.join(
                               self.sample_path,
                               '{}_{}_fake.png'.format(e + 1, i + 1)),
                           nrow=1,
                           padding=0)
                print('Generated images and saved into {}..!'.format(
                    self.sample_path))
            # Save checkpoints
            if (i + 1) % self.model_save_step == 0:
                torch.save(
                    self.Enc.state_dict(),
                    os.path.join(self.model_path,
                                 '{}_{}_Enc.pth'.format(e + 1, i + 1)))
                torch.save(
                    self.Dec.state_dict(),
                    os.path.join(self.model_path,
                                 '{}_{}_Dec.pth'.format(e + 1, i + 1)))
                torch.save(
                    self.D.state_dict(),
                    os.path.join(self.model_path,
                                 '{}_{}_D.pth'.format(e + 1, i + 1)))
        # Decay learning rate linearly over the last num_epochs_decay epochs
        if (e + 1) > (self.num_epochs - self.num_epochs_decay):
            d_lr -= (self.d_lr / float(self.num_epochs_decay))
            enc_lr -= (self.enc_lr / float(self.num_epochs_decay))
            dec_lr -= (self.dec_lr / float(self.num_epochs_decay))
            self.update_lr(enc_lr, dec_lr, d_lr)
            print('Decay learning rate to enc_lr: {}, d_lr: {}.'.format(
                enc_lr, d_lr))
def train(self):
    """Train anomaly detection model.

    WGAN-style training of a reconstruction generator G and critic D:
    D is updated with adversarial loss plus a separate gradient-penalty
    pass; G is updated every ``d_train_repeat`` steps with adversarial,
    L1 reconstruction, SSIM, and latent-feature-consistency losses.
    Periodically logs, saves reconstructions, and checkpoints G/D.
    """
    self.data_loader = self.img_data_loader
    # The number of iterations per epoch
    iters_per_epoch = len(self.data_loader.train)
    fixed_x = []
    # A single fixed batch used for visual debugging of reconstructions.
    for i, (images, labels) in enumerate(self.data_loader.train):
        fixed_x.append(images)
        if i == 0:
            break
    # Fixed inputs and target domain labels for debugging
    fixed_x = torch.cat(fixed_x, dim=0)
    fixed_x = self.to_var(fixed_x, volatile=True)
    # Learning rate for decaying
    d_lr = self.d_lr
    g_lr = self.g_lr
    # Start with trained model (checkpoint name encodes the start epoch)
    if self.trained_model:
        start = int(self.trained_model.split('_')[0])
    else:
        start = 0
    # Start training
    start_time = time.time()
    for e in range(start, self.num_epochs):
        for i, (real_x, real_label) in enumerate(self.data_loader.train):
            # NOTE(review): rand_idx is computed but never used here.
            rand_idx = torch.randperm(real_label.size(0))
            # Convert tensor to variable
            real_x = self.to_var(real_x)
            #================== Train Discriminator ================== #
            # Compute loss with real images
            out_src = self.D(real_x)
            d_loss_real = -torch.mean(out_src)
            fake_x, _, _ = self.G(real_x)
            # Detach fakes so D's update does not backprop into G.
            fake_x = Variable(fake_x.data)
            out_src = self.D(fake_x)
            d_loss_fake = torch.mean(out_src)
            # Discriminator losses
            d_loss = d_loss_real + d_loss_fake
            self.reset()
            d_loss.backward()
            self.d_optimizer.step()
            # Compute gradient penalty (WGAN-GP, separate backward pass)
            alpha = torch.rand(real_x.size(0), 1, 1,
                               1).cuda().expand_as(real_x)
            interpolated = Variable(alpha * real_x.data +
                                    (1 - alpha) * fake_x.data,
                                    requires_grad=True)
            out = self.D(interpolated)
            grad = torch.autograd.grad(outputs=out,
                                       inputs=interpolated,
                                       grad_outputs=torch.ones(
                                           out.size()).cuda(),
                                       retain_graph=True,
                                       create_graph=True,
                                       only_inputs=True)[0]
            grad = grad.view(grad.size(0), -1)
            grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
            d_loss_gp = torch.mean((grad_l2norm - 1)**2)
            # Gradient penalty loss
            d_loss = self.lambda_gp * d_loss_gp
            self.reset()
            d_loss.backward()
            self.d_optimizer.step()
            # Logging
            loss = {}
            loss['D/loss_real'] = d_loss_real.item()
            loss['D/loss_fake'] = d_loss_fake.item()
            loss['D/loss_gp'] = d_loss_gp.item()
            # ================== Train Encoder-Decoder networks ================== #
            if (i + 1) % self.d_train_repeat == 0:
                fake_x, enc_feat, rec_feat = self.G(real_x)
                out_src = self.D(fake_x)
                g_loss_fake = -torch.mean(out_src)
                g_loss_rec_x = torch.mean(torch.abs(real_x - fake_x))
                # SSIM in [-1, 1] mapped to a [0, 1] loss.
                g_loss_ssim = (0.5 *
                               (1 - self.ssim_loss(real_x, fake_x))).clamp(
                                   0, 1)
                g_loss_feature = torch.mean(
                    torch.pow((enc_feat - rec_feat), 2))
                # NOTE(review): "+ +self.lambda_bi" has a stray unary plus —
                # harmless numerically, but likely a typo.
                g_loss = g_loss_fake + self.lambda_f * g_loss_feature + +self.lambda_bi * g_loss_rec_x + self.lambda_ssim * g_loss_ssim
                self.reset()
                g_loss.backward()
                self.g_optimizer.step()
                # Logging Generator losses
                loss['G/loss_feature'] = g_loss_feature.item()
                loss['G/loss_image'] = g_loss_rec_x.item()
                loss['G/loss_ssim'] = g_loss_ssim.item()
                loss['G/loss_fake'] = g_loss_fake.item()
            # Print out log
            if (i + 1) % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                    elapsed, e + 1, self.num_epochs, i + 1, iters_per_epoch)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)
                for tag, value in loss.items():
                    self.logger.scalar_summary(tag, value,
                                               e * iters_per_epoch + i + 1)
            # Reconstructed images: original batch next to its reconstruction
            if (i + 1) % self.sample_step == 0:
                fake_image_list = [fixed_x]
                #for fixed_c in fixed_c_list:
                sample_result, _, _ = self.G(fixed_x)
                fake_image_list.append(sample_result)
                fake_images = torch.cat(fake_image_list, dim=3)
                save_image(self.denorm(fake_images.data),
                           os.path.join(
                               self.sample_path,
                               '{}_{}_fake.png'.format(e + 1, i + 1)),
                           nrow=1,
                           padding=0)
                print('Generated images and saved into {}..!'.format(
                    self.sample_path))
            # Save model checkpoints
            if (i + 1) % self.model_save_step == 0:
                torch.save(
                    self.G.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_{}_G.pth'.format(e + 1, i + 1)))
                torch.save(
                    self.D.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_{}_D.pth'.format(e + 1, i + 1)))
        # Decay learning rate linearly over the last num_epochs_decay epochs
        if (e + 1) > (self.num_epochs - self.num_epochs_decay):
            d_lr -= (self.d_lr / float(self.num_epochs_decay))
            g_lr -= (self.g_lr / float(self.num_epochs_decay))
            self.update_lr(g_lr, d_lr)
            print('Decay learning rate to g_lr: {}, d_lr: {}.'.format(
                g_lr, d_lr))
def train_multi(self):
    """Train StarGAN with multiple datasets.
    In the code below, 1 is related to CelebA and 2 is releated to RaFD.

    The full domain label fed to G is [celebA_part, rafd_part, mask]:
    the part for the inactive dataset is zeroed and a 2-dim one-hot mask
    tells G which part is valid.  D's classifier head is split the same
    way: columns [:c_dim] are CelebA attributes, [c_dim:] RaFD expressions.

    NOTE(review): ``Variable``/``volatile``/``.data[0]``/``size_average``
    are pre-0.4 PyTorch idioms — confirm target torch version.
    """
    # Fixed imagse and labels for debugging
    fixed_x = []
    real_c = []
    for i, (images, labels) in enumerate(self.celebA_loader):
        fixed_x.append(images)
        real_c.append(labels)
        if i == 2:
            break
    fixed_x = torch.cat(fixed_x, dim=0)
    fixed_x = self.to_var(fixed_x, volatile=True)
    real_c = torch.cat(real_c, dim=0)
    fixed_c1_list = self.make_celeb_labels(real_c)
    fixed_c2_list = []
    for i in range(self.c2_dim):
        fixed_c = self.one_hot(
            torch.ones(fixed_x.size(0)) * i, self.c2_dim)
        fixed_c2_list.append(self.to_var(fixed_c, volatile=True))
    fixed_zero1 = self.to_var(torch.zeros(fixed_x.size(0), self.c2_dim))  # zero vector when training with CelebA
    fixed_mask1 = self.to_var(self.one_hot(torch.zeros(fixed_x.size(0)), 2))  # mask vector: [1, 0]
    fixed_zero2 = self.to_var(torch.zeros(fixed_x.size(0), self.c_dim))  # zero vector when training with RaFD
    fixed_mask2 = self.to_var(self.one_hot(torch.ones(fixed_x.size(0)), 2))  # mask vector: [0, 1]
    # lr cache for decaying
    g_lr = self.g_lr
    d_lr = self.d_lr
    # data iterator
    data_iter1 = iter(self.celebA_loader)
    data_iter2 = iter(self.rafd_loader)
    # Start with trained model
    if self.pretrained_model:
        start = int(self.pretrained_model) + 1
    else:
        start = 0
    # # Start training
    start_time = time.time()
    for i in range(start, self.num_iters):
        # Fetch mini-batch images and labels; restart an exhausted iterator.
        try:
            real_x1, real_label1 = next(data_iter1)
        except:
            data_iter1 = iter(self.celebA_loader)
            real_x1, real_label1 = next(data_iter1)
        try:
            real_x2, real_label2 = next(data_iter2)
        except:
            data_iter2 = iter(self.rafd_loader)
            real_x2, real_label2 = next(data_iter2)
        # Generate fake labels randomly (target domain labels)
        rand_idx = torch.randperm(real_label1.size(0))
        fake_label1 = real_label1[rand_idx]
        rand_idx = torch.randperm(real_label2.size(0))
        fake_label2 = real_label2[rand_idx]
        # CelebA labels are already multi-hot; RaFD labels are class indices
        # that must be one-hot encoded.
        real_c1 = real_label1.clone()
        fake_c1 = fake_label1.clone()
        zero1 = torch.zeros(real_x1.size(0), self.c2_dim)
        mask1 = self.one_hot(torch.zeros(real_x1.size(0)), 2)
        real_c2 = self.one_hot(real_label2, self.c2_dim)
        fake_c2 = self.one_hot(fake_label2, self.c2_dim)
        zero2 = torch.zeros(real_x2.size(0), self.c_dim)
        mask2 = self.one_hot(torch.ones(real_x2.size(0)), 2)
        # Convert tensor to variable
        real_x1 = self.to_var(real_x1)
        real_c1 = self.to_var(real_c1)
        fake_c1 = self.to_var(fake_c1)
        mask1 = self.to_var(mask1)
        zero1 = self.to_var(zero1)
        real_x2 = self.to_var(real_x2)
        real_c2 = self.to_var(real_c2)
        fake_c2 = self.to_var(fake_c2)
        mask2 = self.to_var(mask2)
        zero2 = self.to_var(zero2)
        real_label1 = self.to_var(real_label1)
        fake_label1 = self.to_var(fake_label1)
        real_label2 = self.to_var(real_label2)
        fake_label2 = self.to_var(fake_label2)
        # ================== Train D ================== #
        # Real images (CelebA)
        out_real, out_cls = self.D(real_x1)
        out_cls1 = out_cls[:, :self.c_dim]  # celebA part
        d_loss_real = - torch.mean(out_real)
        d_loss_cls = F.binary_cross_entropy_with_logits(
            out_cls1, real_label1, size_average=False) / real_x1.size(0)
        # Real images (RaFD)
        out_real, out_cls = self.D(real_x2)
        out_cls2 = out_cls[:, self.c_dim:]  # rafd part
        d_loss_real += - torch.mean(out_real)
        d_loss_cls += F.cross_entropy(out_cls2, real_label2)
        # Compute classification accuracy of the discriminator
        if (i+1) % self.log_step == 0:
            accuracies = self.compute_accuracy(out_cls1, real_label1, 'CelebA')
            log = ["{:.2f}".format(acc) for acc in accuracies.data.cpu().numpy()]
            print('Classification Acc (Black/Blond/Brown/Gender/Aged): ', '')
            print(log)
            accuracies = self.compute_accuracy(out_cls2, real_label2, 'RaFD')
            log = ["{:.2f}".format(acc) for acc in accuracies.data.cpu().numpy()]
            print('Classification Acc (8 emotional expressions): ', '')
            print(log)
        # Fake images (CelebA)
        fake_c = torch.cat([fake_c1, zero1, mask1], dim=1)
        fake_x1 = self.G(real_x1, fake_c)
        # Detach so D's update does not backprop into G.
        fake_x1 = Variable(fake_x1.data)
        out_fake, _ = self.D(fake_x1)
        d_loss_fake = torch.mean(out_fake)
        # Fake images (RaFD)
        fake_c = torch.cat([zero2, fake_c2, mask2], dim=1)
        fake_x2 = self.G(real_x2, fake_c)
        out_fake, _ = self.D(fake_x2)
        d_loss_fake += torch.mean(out_fake)
        # Backward + Optimize
        d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
        self.reset_grad()
        d_loss.backward()
        self.d_optimizer.step()
        # Compute gradient penalty — alternate datasets across iterations.
        if (i+1) % 2 == 0:
            real_x = real_x1
            fake_x = fake_x1
        else:
            real_x = real_x2
            fake_x = fake_x2
        alpha = torch.rand(real_x.size(0), 1, 1, 1).cuda().expand_as(real_x)
        interpolated = Variable(alpha * real_x.data + (1 - alpha) * fake_x.data,
                                requires_grad=True)
        out, out_cls = self.D(interpolated)
        if (i+1) % 2 == 0:
            out_cls = out_cls[:, :self.c_dim]  # CelebA
        else:
            out_cls = out_cls[:, self.c_dim:]  # RaFD
        grad = torch.autograd.grad(outputs=out,
                                   inputs=interpolated,
                                   grad_outputs=torch.ones(out.size()).cuda(),
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]
        grad = grad.view(grad.size(0), -1)
        grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
        d_loss_gp = torch.mean((grad_l2norm - 1)**2)
        # Backward + Optimize
        d_loss = self.lambda_gp * d_loss_gp
        self.reset_grad()
        d_loss.backward()
        self.d_optimizer.step()
        # Logging
        loss = {}
        loss['D/loss_real'] = d_loss_real.data[0]
        loss['D/loss_fake'] = d_loss_fake.data[0]
        loss['D/loss_cls'] = d_loss_cls.data[0]
        loss['D/loss_gp'] = d_loss_gp.data[0]
        # ================== Train G ================== #
        if (i+1) % self.d_train_repeat == 0:
            # Original-to-target and target-to-original domain (CelebA)
            fake_c = torch.cat([fake_c1, zero1, mask1], dim=1)
            real_c = torch.cat([real_c1, zero1, mask1], dim=1)
            fake_x1 = self.G(real_x1, fake_c)
            rec_x1 = self.G(fake_x1, real_c)
            # Compute losses
            out, out_cls = self.D(fake_x1)
            out_cls1 = out_cls[:, :self.c_dim]
            g_loss_fake = - torch.mean(out)
            g_loss_rec = torch.mean(torch.abs(real_x1 - rec_x1))
            g_loss_cls = F.binary_cross_entropy_with_logits(
                out_cls1, fake_label1, size_average=False) / fake_x1.size(0)
            # Original-to-target and target-to-original domain (RaFD)
            fake_c = torch.cat([zero2, fake_c2, mask2], dim=1)
            real_c = torch.cat([zero2, real_c2, mask2], dim=1)
            fake_x2 = self.G(real_x2, fake_c)
            rec_x2 = self.G(fake_x2, real_c)
            # Compute losses
            out, out_cls = self.D(fake_x2)
            out_cls2 = out_cls[:, self.c_dim:]
            g_loss_fake += - torch.mean(out)
            g_loss_rec += torch.mean(torch.abs(real_x2 - rec_x2))
            g_loss_cls += F.cross_entropy(out_cls2, fake_label2)
            # Backward + Optimize
            g_loss = g_loss_fake + self.lambda_cls * g_loss_cls + self.lambda_rec * g_loss_rec
            self.reset_grad()
            g_loss.backward()
            self.g_optimizer.step()
            # Logging
            loss['G/loss_fake'] = g_loss_fake.data[0]
            loss['G/loss_cls'] = g_loss_cls.data[0]
            loss['G/loss_rec'] = g_loss_rec.data[0]
        # Print out log info
        if (i+1) % self.log_step == 0:
            elapsed = time.time() - start_time
            elapsed = str(datetime.timedelta(seconds=elapsed))
            log = "Elapsed [{}], Iter [{}/{}]".format(
                elapsed, i+1, self.num_iters)
            for tag, value in loss.items():
                log += ", {}: {:.4f}".format(tag, value)
            print(log)
            if self.use_tensorboard:
                for tag, value in loss.items():
                    self.logger.scalar_summary(tag, value, i+1)
        # Translate the images (debugging)
        if (i+1) % self.sample_step == 0:
            fake_image_list = [fixed_x]
            # Changing hair color, gender, and age
            for j in range(self.c_dim):
                fake_c = torch.cat([fixed_c1_list[j], fixed_zero1, fixed_mask1], dim=1)
                fake_image_list.append(self.G(fixed_x, fake_c))
            # Changing emotional expressions
            for j in range(self.c2_dim):
                fake_c = torch.cat([fixed_zero2, fixed_c2_list[j], fixed_mask2], dim=1)
                fake_image_list.append(self.G(fixed_x, fake_c))
            fake = torch.cat(fake_image_list, dim=3)
            # Save the translated images
            save_image(self.denorm(fake.data.cpu()),
                       os.path.join(self.sample_path, '{}_fake.png'.format(i+1)),
                       nrow=1, padding=0)
        # Save model checkpoints
        if (i+1) % self.model_save_step == 0:
            torch.save(self.G.state_dict(),
                       os.path.join(self.model_save_path, '{}_G.pth'.format(i+1)))
            torch.save(self.D.state_dict(),
                       os.path.join(self.model_save_path, '{}_D.pth'.format(i+1)))
        # Decay learning rate every decay_step iters near the end of training
        decay_step = 1000
        if (i+1) > (self.num_iters - self.num_iters_decay) and (i+1) % decay_step==0:
            g_lr -= (self.g_lr / float(self.num_iters_decay) * decay_step)
            d_lr -= (self.d_lr / float(self.num_iters_decay) * decay_step)
            self.update_lr(g_lr, d_lr)
            print ('Decay learning rate to g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
def train(self):
    """Train StarGAN within a single dataset.

    Per iteration: update D with adversarial + domain-classification
    losses, then (in a separate pass) with the WGAN-GP gradient penalty;
    every ``d_train_repeat`` iterations update G with adversarial,
    cycle-reconstruction, and classification losses.  Periodically logs,
    saves translated samples, checkpoints, and decays learning rates.

    NOTE(review): ``Variable``/``volatile``/``.data[0]``/``size_average``
    are pre-0.4 PyTorch idioms — confirm target torch version.
    """
    # The number of iterations per epoch
    iters_per_epoch = len(self.data_loader)
    fixed_x = []
    real_c = []
    # First 4 batches become the fixed debug inputs.
    for i, (images, labels) in enumerate(self.data_loader):
        fixed_x.append(images)
        real_c.append(labels)
        if i == 3:
            break
    # Fixed inputs and target domain labels for debugging
    fixed_x = torch.cat(fixed_x, dim=0)
    fixed_x = self.to_var(fixed_x, volatile=True)
    real_c = torch.cat(real_c, dim=0)
    fixed_c_list = self.make_data_labels(real_c)
    # lr cache for decaying
    g_lr = self.g_lr
    d_lr = self.d_lr
    # Start with trained model if exists (name encodes the start epoch)
    if self.pretrained_model:
        start = int(self.pretrained_model.split('_')[0])
    else:
        start = 0
    # Start training
    start_time = time.time()
    for e in range(start, self.num_epochs):
        for i, (real_x, real_label) in enumerate(self.data_loader):
            # Generat fake labels randomly (target domain labels)
            rand_idx = torch.randperm(real_label.size(0))
            fake_label = real_label[rand_idx]
            real_c = real_label.clone()
            fake_c = fake_label.clone()
            # Convert tensor to variable
            real_x = self.to_var(real_x)
            real_c = self.to_var(real_c)  # input for the generator
            fake_c = self.to_var(fake_c)
            real_label = self.to_var(
                real_label
            )  # this is same as real_c if dataset == 'CelebA'
            fake_label = self.to_var(fake_label)
            # ================== Train D ================== #
            # Compute loss with real images
            out_src, out_cls = self.D(real_x)
            d_loss_real = -torch.mean(out_src)
            d_loss_cls = F.binary_cross_entropy_with_logits(
                out_cls, real_label, size_average=False) / real_x.size(0)
            # Compute classification accuracy of the discriminator
            if (i + 1) % self.log_step == 0:
                accuracies = self.compute_accuracy(out_cls, real_label)
                log = [
                    "{:.2f}".format(acc)
                    for acc in accuracies.data.cpu().numpy()
                ]
                print('Classification Acc: ')
                print(log)
            # Compute loss with fake images
            fake_x = self.G(real_x, fake_c)
            # Detach so D's update does not backprop into G.
            fake_x = Variable(fake_x.data)
            out_src, out_cls = self.D(fake_x)
            d_loss_fake = torch.mean(out_src)
            # Backward + Optimize
            d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()
            # Compute gradient penalty (separate backward/step pass)
            alpha = torch.rand(real_x.size(0), 1, 1,
                               1).cuda().expand_as(real_x)
            interpolated = Variable(alpha * real_x.data +
                                    (1 - alpha) * fake_x.data,
                                    requires_grad=True)
            out, out_cls = self.D(interpolated)
            grad = torch.autograd.grad(outputs=out,
                                       inputs=interpolated,
                                       grad_outputs=torch.ones(
                                           out.size()).cuda(),
                                       retain_graph=True,
                                       create_graph=True,
                                       only_inputs=True)[0]
            grad = grad.view(grad.size(0), -1)
            grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
            d_loss_gp = torch.mean((grad_l2norm - 1)**2)
            # Backward + Optimize
            d_loss = self.lambda_gp * d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()
            # Logging
            loss = {}
            loss['D/loss_real'] = d_loss_real.data[0]
            loss['D/loss_fake'] = d_loss_fake.data[0]
            loss['D/loss_cls'] = d_loss_cls.data[0]
            loss['D/loss_gp'] = d_loss_gp.data[0]
            # ================== Train G ================== #
            if (i + 1) % self.d_train_repeat == 0:
                # Original-to-target and target-to-original domain
                fake_x = self.G(real_x, fake_c)
                rec_x = self.G(fake_x, real_c)
                # Compute losses
                out_src, out_cls = self.D(fake_x)
                g_loss_fake = -torch.mean(out_src)
                g_loss_rec = torch.mean(torch.abs(real_x - rec_x))
                g_loss_cls = F.binary_cross_entropy_with_logits(
                    out_cls, fake_label, size_average=False) / fake_x.size(0)
                # Backward + Optimize
                g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()
                # Logging
                loss['G/loss_fake'] = g_loss_fake.data[0]
                loss['G/loss_rec'] = g_loss_rec.data[0]
                loss['G/loss_cls'] = g_loss_cls.data[0]
            # Print out log info
            if (i + 1) % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                    elapsed, e + 1, self.num_epochs, i + 1, iters_per_epoch)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)
                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(
                            tag, value, e * iters_per_epoch + i + 1)
            # Translate fixed images for debugging
            if (i + 1) % self.sample_step == 0:
                fake_image_list = [fixed_x]
                for fixed_c in fixed_c_list:
                    fake_image_list.append(self.G(fixed_x, fixed_c))
                fake_images = torch.cat(fake_image_list, dim=3)
                save_image(self.denorm(fake_images.data.cpu()),
                           os.path.join(
                               self.sample_path,
                               '{}_{}_fake.png'.format(e + 1, i + 1)),
                           nrow=1,
                           padding=0)
                print('Translated images and saved into {}..!'.format(
                    self.sample_path))
            # Save model checkpoints
            if (i + 1) % self.model_save_step == 0:
                torch.save(
                    self.G.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_{}_G.pth'.format(e + 1, i + 1)))
                torch.save(
                    self.D.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_{}_D.pth'.format(e + 1, i + 1)))
        # Decay learning rate linearly over the last num_epochs_decay epochs
        if (e + 1) > (self.num_epochs - self.num_epochs_decay):
            g_lr -= (self.g_lr / float(self.num_epochs_decay))
            d_lr -= (self.d_lr / float(self.num_epochs_decay))
            self.update_lr(g_lr, d_lr)
            print('Decay learning rate to g_lr: {}, d_lr: {}.'.format(
                g_lr, d_lr))
def train(self):
    """Train StarGAN-style G/D (and optional pairwise D2) on dataset1.

    One optimizer step on D (adversarial + classification + regression +
    gradient penalty) per batch; G is updated every ``d_train_repeat``
    batches. When ``self.dataset2_loader`` is set, an auxiliary pairwise
    discriminator D2 with KL regularization is trained as well.
    Uses pre-0.4 PyTorch idioms (Variable, ``.data``).
    """
    # Set dataloader
    self.data_loader = self.dataset1_loader
    # The number of iterations per epoch
    iters_per_epoch = len(self.data_loader)
    fixed_x = []
    real_c = []
    # Grab exactly one batch for fixed debug samples (breaks at i == 0).
    for i, (images, labels) in enumerate(self.data_loader):
        fixed_x.append(images)
        real_c.append(labels)
        if i == 0:
            break
    # Fixed inputs and target domain labels for debugging
    fixed_x = torch.cat(fixed_x, dim=0)
    fixed_x = self.to_var(fixed_x, requires_grad=False)
    real_c = torch.cat(real_c, dim=0)
    # One fixed one-hot target per class, to translate the debug batch to
    # every domain at sample time.
    fixed_c_list = []
    for i in range(self.c_dim):
        fixed_c = self.one_hot(torch.ones(fixed_x.size(0)) * i, self.c_dim)
        fixed_c_list.append(self.to_var(fixed_c, requires_grad=False))
    # lr cache for decaying
    g_lr = self.g_lr
    d_lr = self.d_lr
    # Start with trained model if exists (epoch parsed from checkpoint name)
    if self.resume:
        start = int(self.resume.split('_')[0])
    else:
        start = 0
    # Start training
    start_time = time.time()
    for e in range(start, self.num_epochs):
        for i, (real_x, real_label) in enumerate(self.data_loader):
            # Generate fake labels randomly (target domain labels)
            rand_idx = torch.randperm(real_label.size(0))
            fake_label = real_label[rand_idx]
            real_c = self.one_hot(real_label, self.c_dim)
            fake_c = self.one_hot(fake_label, self.c_dim)
            # Convert tensor to variable
            real_x = self.to_var(real_x)
            real_c = self.to_var(real_c)  # input for the generator
            fake_c = self.to_var(fake_c)
            real_label = self.to_var(real_label, requires_grad=False)  # this is same as real_c if dataset == 'CelebA'
            fake_label = self.to_var(fake_label, requires_grad=False)

            # ================== Train D ================== #
            # Compute loss with real images (D returns src score, class
            # logits, and a regression head output)
            out_real, out_cls, out_reg = self.D(real_x, real_label)
            # Compute loss with fake images (detach: no grad into G here)
            fake_x = self.G(real_x, fake_c).detach()
            out_fake, _, _ = self.D(fake_x, fake_label)
            # d_loss_adv = loss_hinge_dis(out_fake, out_real)
            d_loss_adv = -torch.mean(out_real) + torch.mean(out_fake)  # WGAN critic loss
            d_loss_cls = F.cross_entropy(out_cls, real_label)
            # todo: regression
            d_loss_reg = loss_hard_reg(out_reg, real_label, self.c_dim)
            # Compute classification accuracy of the discriminator
            if (i+1) % self.log_step == 0:
                classification_accuracies = self.compute_accuracy(out_cls, real_label, n_classes=self.c_dim)
                regression_accuracies = self.compute_accuracy(out_reg.squeeze(), real_label, n_classes=self.c_dim)
                # NOTE(review): assumes compute_accuracy returns a scalar here
                # ({:.2f} on a numpy array fails otherwise) — confirm.
                log = "{:.2f}/{:.2f}".format(classification_accuracies.data.cpu().numpy(), regression_accuracies.data.cpu().numpy())
                print('Classification/regression Acc: ', end='')
                print(log)
            # Backward + Optimize
            d_loss = d_loss_adv + self.lambda_cls * d_loss_cls + self.lambda_reg * d_loss_reg
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Compute gradient penalty (WGAN-GP): penalize critic gradient
            # norm on random interpolations between real and fake images.
            alpha = torch.rand(real_x.size(0), 1, 1, 1).cuda().expand_as(real_x)
            interpolated = Variable(alpha * real_x.data + (1 - alpha) * fake_x.data, requires_grad=True)
            out, _, _ = self.D(interpolated, real_label)
            grad = torch.autograd.grad(outputs=out, inputs=interpolated,
                                       grad_outputs=torch.ones(out.size()).cuda(),
                                       retain_graph=True, create_graph=True, only_inputs=True)[0]
            grad = grad.view(grad.size(0), -1)
            grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
            d_loss_gp = torch.mean((grad_l2norm - 1) ** 2)
            # Backward + Optimize (second, separate D step for the penalty)
            d_loss = self.lambda_gp * d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Logging
            loss = collections.OrderedDict()
            loss['D/loss_adv'] = d_loss_adv.data
            loss['D/loss_reg'] = d_loss_reg.data
            loss['D/loss_gp'] = d_loss_gp.data
            loss['D/loss_cls'] = d_loss_cls.data

            # todo
            if self.dataset2_loader is not None:
                # Paired batch: positive pair (p*) and negative pair (n*).
                # NOTE(review): iter(...).next() is Python 2 only; use
                # next(iter(...)) on Python 3 — confirm target runtime.
                pExample1_img, pExample2_img, pExample1_lbl, pExample2_lbl, nExample1_img, nExample2_img, nExample1_lbl, nExample2_lbl = iter(
                    self.dataset2_loader).next()
                # Generate fake labels randomly (target domain labels)
                pExample1_c = self.one_hot(pExample1_lbl, self.c_dim)
                pExample2_c = self.one_hot(pExample2_lbl, self.c_dim)
                nExample1_c = self.one_hot(nExample1_lbl, self.c_dim)
                nExample2_c = self.one_hot(nExample2_lbl, self.c_dim)
                # Convert tensor to variable
                pExample1_img = self.to_var(pExample1_img)
                pExample2_img = self.to_var(pExample2_img)
                nExample1_img = self.to_var(nExample1_img)
                nExample2_img = self.to_var(nExample2_img)
                # NOTE(review): self.c_dim is passed positionally here while
                # other call sites use requires_grad=False — likely a slip;
                # verify to_var's signature.
                pExample1_c = self.to_var(pExample1_c, self.c_dim)
                pExample2_c = self.to_var(pExample2_c, self.c_dim)
                nExample1_c = self.to_var(nExample1_c, self.c_dim)
                nExample2_c = self.to_var(nExample2_c, self.c_dim)
                pExample1_lbl = self.to_var(pExample1_lbl, requires_grad=False)
                pExample2_lbl = self.to_var(pExample2_lbl, requires_grad=False)
                nExample1_lbl = self.to_var(nExample1_lbl, requires_grad=False)
                nExample2_lbl = self.to_var(nExample2_lbl, requires_grad=False)

                # ================== Train D2 ================== #
                # D2 scores an image pair and also returns per-image latent
                # (mu, logvar) pairs for KL regularization.
                # Compute loss with real example
                out_real, mu1_real, logvar1_real, mu2_real, logvar2_real = self.D2(pExample1_img, pExample2_img,
                                                                                   pExample1_lbl, pExample2_lbl)
                # Compute loss with negative example
                out_neg, mu1_neg, logvar1_neg, mu2_neg, logvar2_neg = self.D2(nExample1_img, nExample2_img,
                                                                              nExample1_lbl, nExample2_lbl)
                # (commented-out "no projection" / projection-free D2 variants
                # removed for clarity — see VCS history)
                # Compute loss with fake example
                fExample2_img = self.G(pExample1_img, pExample2_c).detach()
                out_fake, _, _, mu2_fake, logvar2_fake = self.D2(pExample1_img, fExample2_img,
                                                                 pExample1_lbl, pExample2_lbl)
                # (commented-out unilateral-projection variant removed)
                kl_real = loss_kl(mu1_real, logvar1_real) + loss_kl(mu2_real, logvar2_real)
                kl_neg = loss_kl(mu1_neg, logvar1_neg) + loss_kl(mu2_neg, logvar2_neg)
                kl_fake = loss_kl(mu2_fake, logvar2_fake)
                # Average of the 5 KL terms above (hence * 0.2).
                kl = (kl_fake + kl_neg + kl_real) * 0.2
                # d_loss_adv = loss_hinge_dis(out_fake, out_real)
                d2_loss_adv = -torch.mean(out_real) + (torch.mean(out_fake) + torch.mean(out_neg))*0.5
                d2_loss = d2_loss_adv + kl*self.lambda_kl
                # Backward + Optimize
                self.reset_grad()
                d2_loss.backward()
                self.d2_optimizer.step()

                # Compute gradient penalty for D2.
                # NOTE(review): interpolates nExample2_img with fExample2_img
                # (which was generated from pExample1_img) but scores it with
                # the positive pair's labels — confirm this mix is intended.
                alpha = torch.rand(nExample2_img.size(0), 1, 1, 1).cuda().expand_as(nExample2_img)
                interpolated = Variable(alpha * nExample2_img.data + (1 - alpha) * fExample2_img.data, requires_grad=True)
                out, _, _, _, _ = self.D2(pExample1_img, interpolated, pExample1_lbl, pExample2_lbl)
                grad = torch.autograd.grad(outputs=out, inputs=interpolated,
                                           grad_outputs=torch.ones(out.size()).cuda(),
                                           retain_graph=True, create_graph=True, only_inputs=True)[0]
                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
                d2_loss_gp = torch.mean((grad_l2norm - 1) ** 2)
                # Backward + Optimize
                d2_loss = self.lambda_gp * d2_loss_gp
                self.reset_grad()
                d2_loss.backward()
                self.d2_optimizer.step()
                # Logging
                loss['D2/loss_adv'] = d2_loss_adv.data
                loss['D2/d2_loss_gp'] = d2_loss_gp.data
                loss['D2/loss_kl'] = kl.data

            # ================== Train G ================== #
            if (i+1) % self.d_train_repeat == 0:
                # Original-to-target and target-to-original domain
                fake_x = self.G(real_x, fake_c)
                # todo
                # NOTE(review): reconstruction is usually G(fake_x, real_c);
                # reconstructing from real_x makes g_loss_rec a near-identity
                # term — confirm this is deliberate (author marked it "todo").
                rec_x = self.G(real_x, real_c)
                # Compute losses
                out_fake, out_cls, out_reg = self.D(fake_x, fake_label)
                g_loss_adv = -torch.mean(out_fake)
                g_loss_rec = torch.mean(torch.abs(real_x - rec_x))
                g_loss_cls = F.cross_entropy(out_cls, fake_label)
                g_loss_reg = loss_hard_reg(out_reg, fake_label, self.c_dim)
                # todo
                if self.dataset2_loader is not None:
                    # (commented-out positive-pair / unilateral-projection
                    # variants removed for clarity)
                    fExample2_img = self.G(pExample1_img, nExample2_c)
                    out_fake, _, _, _, _ = self.D2(pExample1_img, fExample2_img,
                                                   pExample1_lbl,
                                                   nExample2_lbl)
                    g_loss_adv2 = -torch.mean(out_fake)
                else:
                    g_loss_adv2 = 0
                # Backward + Optimize
                g_loss = g_loss_adv + self.lambda_cls * g_loss_cls + self.lambda_rec * g_loss_rec + self.lambda_reg * g_loss_reg + g_loss_adv2
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()
                # Logging
                loss['G/loss_adv'] = g_loss_adv.data
                loss['G/loss_rec'] = g_loss_rec.data
                loss['G/loss_reg'] = g_loss_reg.data
                loss['G/loss_cls'] = g_loss_cls.data
                # NOTE(review): stored without .data (tensor or int 0);
                # "{:.4f}".format(...) below may fail on a tensor — confirm.
                loss['G/loss_adv2'] = g_loss_adv2

            # Print out log info
            if (i+1) % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                    elapsed, e+1, self.num_epochs, i+1, iters_per_epoch)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)
                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, e * iters_per_epoch + i + 1)

            # Translate fixed images for debugging
            if (i+1) % self.sample_step == 0:
                fake_image_list = [fixed_x]
                for fixed_c in fixed_c_list:
                    fake_image_list.append(self.G(fixed_x, fixed_c).detach())
                # dim=3: tile translations horizontally (width axis).
                fake_images = torch.cat(fake_image_list, dim=3)
                save_image(self.denorm(fake_images.data),
                           os.path.join(self.sample_path, '{}_{}_fake.png'.format(e+1, i+1)), nrow=1, padding=0)
                if self.dataset2_loader is not None:
                    pair_images = torch.cat([pExample1_img, pExample2_img, fExample2_img], dim=3).detach()
                    if self.debug:
                        pair_images = torch.cat([pair_images, nExample1_img, nExample2_img], dim=3).detach()
                    save_image(self.denorm(pair_images.data),
                               os.path.join(self.sample_path, '{}_{}_pair_images.png'.format(e + 1, i + 1)),
                               nrow=1, padding=0)
                print('Translated images and saved into {}..!'.format(self.sample_path))

        # Save model checkpoints (epoch-level, gated on model_save_star too)
        if (e+1) % self.model_save_step == 0 and (e+1) > self.model_save_star:
            print('Save model checkpoints')
            torch.save(self.G.state_dict(),
                       os.path.join(self.model_save_path, '{}_G.pth'.format(e+1)))
            torch.save(self.D.state_dict(),
                       os.path.join(self.model_save_path, '{}_D.pth'.format(e+1)))
            if self.dataset2_loader is not None:
                torch.save(self.D2.state_dict(),
                           os.path.join(self.model_save_path, '{}_D2.pth'.format(e + 1)))
            # Evaluate per-class FID alongside checkpointing and append the
            # result to a persistent test log.
            intra_fid = calculate_intra_fid(self.eval_path, self.eval_batchsize, True, self.dims,
                                            self.eval_model, self.G, self.eval_loader)
            log = 'TEST Epoch [{}/{}]'.format(e+1, self.num_epochs)
            for tag, value in intra_fid.items():
                log += ", {}: {:.4f}".format(tag, value)
            test_log_path = os.path.join(self.log_path, 'test.log')
            with open(test_log_path, 'a') as f:
                f.write(log)
                f.write('\n')
            print(log)

        # Decay learning rate (linear decay over the last num_epochs_decay epochs)
        if (e+1) > (self.num_epochs - self.num_epochs_decay):
            g_lr -= (self.g_lr / float(self.num_epochs_decay))
            d_lr -= (self.d_lr / float(self.num_epochs_decay))
            self.update_lr(g_lr, d_lr)
            print('Decay learning rate to g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
def train(self):
    """Train a latent-code conditional GAN (StarGAN-style D) on CelebA/Fashion.

    G maps a random latent vector ``z`` plus the real label to an image;
    D is trained with WGAN loss + classification loss + gradient penalty.
    Uses pre-0.4 PyTorch idioms (Variable, volatile, ``.data[0]``).
    """
    # Set dataloader
    self.data_loader = self.celebA_loader
    # The number of iterations per epoch
    iters_per_epoch = len(self.data_loader)
    # Fixed latent vector and label for output samples
    fixed_size = 20
    fixed_z = torch.randn(fixed_size, self.z_dim)
    fixed_z = self.to_var(fixed_z, volatile=True)
    fixed_c_list = self.make_celeb_labels_test()
    # Repeat the same latent codes once per test label, then expand each
    # label across the batch; the per-item lists are cleared afterwards to
    # free references.
    fixed_z_repeat = fixed_z.repeat(len(fixed_c_list), 1)
    fixed_c_repeat_list = []
    for fixed_c in fixed_c_list:
        fixed_c_repeat_list.append(
            fixed_c.expand(fixed_size, fixed_c.size(1)))
    fixed_c_list = []
    fixed_c_repeat = torch.cat(fixed_c_repeat_list, dim=0)
    fixed_c_repeat_list = []
    # lr cache for decaying
    g_lr = self.g_lr
    d_lr = self.d_lr
    # Start with trained model if exists
    if self.pretrained_model:
        start = int(self.pretrained_model.split('_')[0]) - 1
    else:
        start = 0
    # Start training
    start_time = time.time()
    for e in range(start, self.num_epochs):
        epoch_iter = 0
        for i, (real_x, real_label) in enumerate(self.data_loader):
            epoch_iter = epoch_iter + 1
            if self.dataset == 'Fashion':
                # NOTE(review): real_label_i is never defined in this scope —
                # this raises NameError for the Fashion dataset; probably
                # meant real_label.clone(). Confirm before running Fashion.
                real_c_i = real_label_i.clone()
            real_c = real_label.clone()
            # rand_idx = torch.randperm(real_c.size(0))
            # fake_c = real_c[rand_idx]
            # Fresh latent code per batch.
            z = torch.randn(real_x.size(0), self.z_dim)
            z = self.to_var(z)
            # Convert tensor to variable
            real_x = self.to_var(real_x)
            real_c = self.to_var(real_c)  # input for the generator
            if self.dataset == 'Fashion':
                real_c_i = self.to_var(real_c_i)
            # fake_c = self.to_var(fake_c, volatile=True)

            # ================== Train D ================== #
            # Compute loss with real images
            out_src, out_cls = self.D(real_x)
            d_loss_real = -torch.mean(out_src)
            # Multi-label BCE for CelebA attributes, single-label CE for Fashion.
            if self.dataset == 'CelebA':
                d_loss_cls = F.binary_cross_entropy_with_logits(
                    out_cls, real_c, size_average=False) / real_x.size(0)
            elif self.dataset == 'Fashion':
                d_loss_cls = F.cross_entropy(out_cls, real_c_i)
            # (commented-out discriminator-accuracy logging removed for
            # clarity — see VCS history)

            # Compute loss with fake images (re-wrap .data so no grad
            # flows back into G during the D step)
            fake_x = self.G(z, real_c)
            fake_x = Variable(fake_x.data)
            out_src, out_cls = self.D(fake_x)
            d_loss_fake = torch.mean(out_src)

            # Backward + Optimize
            d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Compute gradient penalty (WGAN-GP) on real/fake interpolations.
            alpha = torch.rand(real_x.size(0), 1, 1, 1).cuda().expand_as(real_x)
            interpolated = Variable(alpha * real_x.data + (1 - alpha) * fake_x.data, requires_grad=True)
            out, out_cls = self.D(interpolated)
            grad = torch.autograd.grad(outputs=out, inputs=interpolated,
                                       grad_outputs=torch.ones(out.size()).cuda(),
                                       retain_graph=True, create_graph=True, only_inputs=True)[0]
            grad = grad.view(grad.size(0), -1)
            grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
            d_loss_gp = torch.mean((grad_l2norm - 1)**2)
            # Backward + Optimize (separate step for the penalty)
            d_loss = self.lambda_gp * d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Logging
            loss = {}
            loss['D/loss_real'] = d_loss_real.data[0]
            loss['D/loss_fake'] = d_loss_fake.data[0]
            loss['D/loss_cls'] = d_loss_cls.data[0]
            loss['D/loss_gp'] = d_loss_gp.data[0]

            # ================== Train G ================== #
            if (i + 1) % self.d_train_repeat == 0:
                # Original-to-target and target-to-original domain
                fake_x = self.G(z, real_c)
                # fake_x2 = self.G(z, fake_c)
                # Compute losses
                out_src, out_cls = self.D(fake_x)
                g_loss_fake = -torch.mean(out_src)
                if self.dataset == 'CelebA':
                    g_loss_cls = F.binary_cross_entropy_with_logits(
                        out_cls, real_c, size_average=False) / fake_x.size(0)
                elif self.dataset == 'Fashion':
                    g_loss_cls = F.cross_entropy(out_cls, real_c_i)
                # Backward + Optimize
                g_loss = g_loss_fake + self.lambda_cls * g_loss_cls
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()
                # Logging
                loss['G/loss_fake'] = g_loss_fake.data[0]
                loss['G/loss_cls'] = g_loss_cls.data[0]

            if (i + 1) % self.visual_step == 0:
                # save visuals
                self.real_x = real_x
                self.fake_x = fake_x
                # self.fake_x2 = fake_x2
                # save losses
                # NOTE(review): g_loss/g_loss_fake/g_loss_cls only exist after
                # the first G update — if visual_step can trigger before
                # d_train_repeat divides (i+1), this raises NameError.
                self.d_real = -d_loss_real
                self.d_fake = d_loss_fake
                self.d_loss = d_loss
                self.g_loss = g_loss
                self.g_loss_fake = g_loss_fake
                self.g_loss_cls = self.lambda_cls * g_loss_cls
                self.d_loss_cls = self.lambda_cls * d_loss_cls
                errors_D = self.get_current_errors('D')
                errors_G = self.get_current_errors('G')
                self.visualizer.display_current_results(
                    self.get_current_visuals(), e)
                self.visualizer.plot_current_errors_D(
                    e, float(epoch_iter) / float(iters_per_epoch), errors_D)
                self.visualizer.plot_current_errors_G(
                    e, float(epoch_iter) / float(iters_per_epoch), errors_G)

            # Print out log info
            if (i + 1) % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                    elapsed, e + 1, self.num_epochs, i + 1, iters_per_epoch)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)
                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(
                            tag, value, e * iters_per_epoch + i + 1)

            # Translate fixed images for debugging
            # (entire fixed-sample generation block was commented out by the
            # author — kept disabled; see VCS history for the original code
            # using fixed_z_repeat / fixed_c_repeat)

            # Save model checkpoints
            if (i + 1) % self.model_save_step == 0:
                torch.save(
                    self.G.state_dict(),
                    os.path.join(self.model_save_path, '{}_{}_G.pth'.format(e + 1, i + 1)))
                torch.save(
                    self.D.state_dict(),
                    os.path.join(self.model_save_path, '{}_{}_D.pth'.format(e + 1, i + 1)))

        # Decay learning rate (linear decay over the final decay epochs)
        if (e + 1) > (self.num_epochs - self.num_epochs_decay):
            g_lr -= (self.g_lr / float(self.num_epochs_decay))
            d_lr -= (self.d_lr / float(self.num_epochs_decay))
            self.update_lr(g_lr, d_lr)
            print('Decay learning rate to g_lr: {}, d_lr: {}.'.format(
                g_lr, d_lr))
def train(self):
    """Train G/D plus a separate identity classifier C on MS-Celeb data.

    The loader yields (augmented image, augmented label, original image,
    original label). D scores realism of original vs. G(aug); C classifies
    identity; G is trained with adversarial + identity-classification +
    L1 reconstruction + total-variation losses.
    Uses pre-0.4 PyTorch idioms (Variable, volatile, ``.data[0]``).
    """
    self.criterionL1 = torch.nn.L1Loss()
    # self.criterionL2 = torch.nn.MSELoss()
    self.criterionTV = TVLoss()
    self.data_loader = self.Msceleb_loader
    # The number of iterations per epoch
    iters_per_epoch = len(self.data_loader)
    fixed_x = []
    real_c = []
    # Grab 4 batches (i == 3 inclusive) for fixed debug samples.
    for i, (aug_images, aug_labels, _, _) in enumerate(self.data_loader):
        fixed_x.append(aug_images)
        real_c.append(aug_labels)
        if i == 3:
            break
    # Fixed inputs and target domain labels for debugging
    fixed_x = torch.cat(fixed_x, dim=0)
    fixed_x = self.to_var(fixed_x, volatile=True)
    # lr cache for decaying
    g_lr = self.g_lr
    d_lr = self.d_lr
    # Start with trained model if exists
    if self.pretrained_model:
        start = int(self.pretrained_model.split('_')[0])
    else:
        start = 0
    # Start training
    start_time = time.time()
    for e in range(start, self.num_epochs):
        for i, (aug_x, aug_label, origin_x, origin_label) in enumerate(self.data_loader):
            # Generate fake labels randomly (target domain labels)
            # aug_c = self.one_hot(aug_label, self.c_dim)
            # origin_c = self.one_hot(origin_label, self.c_dim)
            aug_c_V = self.to_var(aug_label)
            origin_c_V = self.to_var(origin_label)
            aug_x = self.to_var(aug_x)
            origin_x = self.to_var(origin_x)

            # ================== Train D ================== #
            # Compute loss with real images; C is a separate identity
            # classifier trained alongside D.
            out_src = self.D(origin_x)
            out_cls = self.C(origin_x)
            d_loss_real = - torch.mean(out_src)
            c_loss_cls = F.cross_entropy(out_cls, origin_c_V)
            # Compute classification accuracy of the discriminator
            if (i+1) % self.log_step == 0:
                accuracies = self.compute_accuracy(out_cls, origin_c_V)
                log = ["{:.2f}".format(acc) for acc in accuracies.data.cpu().numpy()]
                print('Classification Acc (75268 ids): ')
                print(log)
            # Compute loss with fake images (re-wrap .data so no grad
            # reaches G during the D/C step)
            fake_x = self.G(aug_x)
            fake_x = Variable(fake_x.data)
            out_src = self.D(fake_x)
            d_loss_fake = torch.mean(out_src)

            # Backward + Optimize — D and C have disjoint parameters, so a
            # single reset_grad followed by both backwards is safe here.
            d_loss = d_loss_real + d_loss_fake
            c_loss = self.lambda_cls * c_loss_cls
            self.reset_grad()
            d_loss.backward()
            c_loss.backward()
            self.d_optimizer.step()
            self.c_optimizer.step()

            # Compute gradient penalty (WGAN-GP) between originals and fakes.
            alpha = torch.rand(origin_x.size(0), 1, 1, 1).cuda().expand_as(origin_x)
            interpolated = Variable(alpha * origin_x.data + (1 - alpha) * fake_x.data, requires_grad=True)
            out = self.D(interpolated)
            grad = torch.autograd.grad(outputs=out, inputs=interpolated,
                                       grad_outputs=torch.ones(out.size()).cuda(),
                                       retain_graph=True, create_graph=True, only_inputs=True)[0]
            grad = grad.view(grad.size(0), -1)
            grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
            d_loss_gp = torch.mean((grad_l2norm - 1)**2)
            # Backward + Optimize (separate penalty step)
            d_loss = self.lambda_gp * d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Logging
            loss = {}
            loss['D/loss_real'] = d_loss_real.data[0]
            loss['D/loss_fake'] = d_loss_fake.data[0]
            loss['D/loss_gp'] = d_loss_gp.data[0]
            loss['C/loss_cls'] = c_loss_cls.data[0]

            # ================== Train G ================== #
            if (i+1) % self.d_train_repeat == 0:
                # Original-to-target and target-to-original domain
                fake_x = self.G(aug_x)
                # Compute losses
                out_src = self.D(fake_x)
                out_cls = self.C(fake_x)
                g_loss_fake = - torch.mean(out_src)
                g_loss_cls = F.cross_entropy(out_cls, aug_c_V)
                # Backward + Optimize
                recon_loss = self.criterionL1(fake_x, aug_x)
                TV_loss = self.criterionTV(fake_x) * 0.001
                # NOTE(review): the L1 weight 5 and TV weight 0.001 are
                # hard-coded magic numbers — consider config params.
                g_loss = g_loss_fake + self.lambda_cls * g_loss_cls + 5* recon_loss + TV_loss
                # (large commented-out face-recognition feature (identity)
                # loss block, gated on lambda_face, removed for clarity —
                # see VCS history)
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()
                # Logging
                loss['G/loss_fake'] = g_loss_fake.data[0]
                loss['G/loss_cls'] = g_loss_cls.data[0]

            # Print out log info
            if (i+1) % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                    elapsed, e+1, self.num_epochs, i+1, iters_per_epoch)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

            # Translate fixed images for debugging
            if (i+1) % self.sample_step == 0:
                fake_image_list = [fixed_x]
                fake_image_list.append(self.G(fixed_x))
                # dim=3: input and translation side by side (width axis).
                fake_images = torch.cat(fake_image_list, dim=3)
                save_image(self.denorm(fake_images.data),
                           os.path.join(self.sample_path, '{}_{}_fake.png'.format(e+1, i+1)),nrow=1, padding=0)
                print('Translated images and saved into {}..!'.format(self.sample_path))

            # Save model checkpoints
            if (i+1) % self.model_save_step == 0:
                torch.save(self.G.state_dict(),
                           os.path.join(self.model_save_path, '{}_{}_G.pth'.format(e+1, i+1)))
                torch.save(self.D.state_dict(),
                           os.path.join(self.model_save_path, '{}_{}_D.pth'.format(e+1, i+1)))
                torch.save(self.C.state_dict(),
                           os.path.join(self.model_save_path, '{}_{}_C.pth'.format(e+1, i+1)))

        # Decay learning rate (linear decay over the final decay epochs)
        if (e+1) > (self.num_epochs - self.num_epochs_decay):
            g_lr -= (self.g_lr / float(self.num_epochs_decay))
            d_lr -= (self.d_lr / float(self.num_epochs_decay))
            self.update_lr(g_lr, d_lr)
            print ('Decay learning rate to g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))

    # Final snapshot after all epochs ('e' holds the last epoch index).
    torch.save(self.G.state_dict(),
               os.path.join(self.model_save_path, '{}_final_G.pth'.format(e + 1)))
    torch.save(self.D.state_dict(),
               os.path.join(self.model_save_path, '{}_final_D.pth'.format(e + 1)))
    torch.save(self.C.state_dict(),
               os.path.join(self.model_save_path, '{}_final_C.pth'.format(e + 1)))
def train(self):
    """Train a StarGAN variant with many optional auxiliary losses.

    Config flags gate: identity-classification head on D (loss_id_cls),
    a generator pose branch (use_gpb), a siamese identity loss (use_si),
    symmetry losses (loss_symmetry), input-identity loss (loss_id),
    feature-matching identity loss (loss_identity) and TV loss (loss_tv).
    Uses pre-0.4 PyTorch idioms (Variable, volatile, ``.data[0]``).
    """
    # if self.config.visualize:
    visualizer = Visualizer()
    # Set dataloader
    self.data_loader = self.train_loader
    # The number of iterations per epoch
    iters_per_epoch = len(self.data_loader)
    fixed_x = []
    real_c = []
    # Grab one batch for fixed debug samples (imgs[0] is the anchor image).
    for i, (imgs, labels, _) in enumerate(self.data_loader):
        fixed_x.append(imgs[0])
        real_c.append(labels)
        if i == 0:
            break
    # Fixed inputs and target domain labels for debugging
    fixed_x = torch.cat(fixed_x, dim=0)
    fixed_x = self.to_var(fixed_x, volatile=True)
    real_c = torch.cat(real_c, dim=0)
    fixed_c_list = self.make_celeb_labels(self.config.batch_size)
    # lr cache for decaying
    g_lr = self.config.g_lr
    d_lr = self.config.d_lr
    # Start with trained model if exists
    if self.config.pretrained_model:
        start = int(self.config.pretrained_model.split('_')[0]) - 1
    else:
        start = 0
    # Start training
    self.loss = {}
    start_time = time.time()
    for e in range(start, self.config.num_epochs):
        # Evaluate before each training epoch.
        self.test(e)
        for i, (images, real_label, identity) in enumerate(self.data_loader):
            # images is a tuple: [0] anchor, [1]/[2] siamese pair (same/other id).
            real_x = images[0]
            if self.config.use_si:
                real_ox = self.to_var(images[1])
                real_oo = self.to_var(images[2])
            if self.config.id_cls_loss == 'cross':
                identity = identity.squeeze()
            # Generate fake labels randomly (target domain labels)
            rand_idx = torch.randperm(real_label.size(0))
            fake_label = real_label[rand_idx]
            real_c = real_label.clone()
            fake_c = fake_label.clone()
            # Convert tensor to variable
            real_x = self.to_var(real_x)
            real_c = self.to_var(real_c)  # input for the generator
            fake_c = self.to_var(fake_c)
            real_label = self.to_var(
                real_label
            )  # this is same as real_c if dataset == 'CelebA'
            fake_label = self.to_var(fake_label)
            identity = self.to_var(identity)

            # ================== Train D ================== #
            # Compute loss with real images; the extra head returns identity
            # logits when loss_id_cls is enabled.
            if self.config.loss_id_cls:
                out_src, out_cls, out_id_real = self.D(real_x)
            else:
                out_src, out_cls = self.D(real_x)
            d_loss_real = -torch.mean(out_src)
            d_loss_cls = F.binary_cross_entropy_with_logits(
                out_cls, real_label,
                size_average=False) / real_x.size(0)
            if self.config.loss_id_cls:
                d_loss_id_cls = self.id_cls_criterion(
                    out_id_real, identity)
                self.loss[
                    'D/loss_id_cls'] = self.config.lambda_id_cls * d_loss_id_cls.data[
                        0]
            else:
                d_loss_id_cls = 0.0
            # Compute classification accuracy of the discriminator
            if (i + 1) % self.config.log_step == 0:
                accuracies = self.compute_accuracy(out_cls.detach(),
                                                   real_label,
                                                   self.config.dataset)
                log = [
                    "{:.2f}".format(acc)
                    for acc in accuracies.data.cpu().numpy()
                ]
                print('Classification Acc (20 classes): ')
                print(log)
                print('\n')
            # Compute loss with fake images; with use_gpb the generator also
            # returns an identity vector (ignored here).
            if self.config.use_gpb:
                fake_x, _ = self.G(real_x, fake_c)
            else:
                fake_x = self.G(real_x, fake_c)
            fake_x = Variable(fake_x.data)
            if self.config.loss_id_cls:
                out_src, out_cls, _ = self.D(fake_x.detach())
            else:
                out_src, out_cls = self.D(fake_x.detach())
            d_loss_fake = torch.mean(out_src)

            # Backward + Optimize
            d_loss = d_loss_real + d_loss_fake + self.config.lambda_cls * d_loss_cls + d_loss_id_cls * self.config.lambda_id_cls
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Compute gradient penalty (WGAN-GP) on real/fake interpolations.
            alpha = torch.rand(real_x.size(0), 1, 1,
                               1).cuda().expand_as(real_x)
            interpolated = Variable(alpha * real_x.data +
                                    (1 - alpha) * fake_x.data,
                                    requires_grad=True)
            if self.config.loss_id_cls:
                out, out_cls, _ = self.D(interpolated)
            else:
                out, out_cls = self.D(interpolated)
            grad = torch.autograd.grad(outputs=out,
                                       inputs=interpolated,
                                       grad_outputs=torch.ones(
                                           out.size()).cuda(),
                                       retain_graph=True,
                                       create_graph=True,
                                       only_inputs=True)[0]
            grad = grad.view(grad.size(0), -1)
            grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
            d_loss_gp = torch.mean((grad_l2norm - 1)**2)
            # Backward + Optimize (separate penalty step)
            d_loss = self.config.lambda_gp * d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Logging
            self.loss['D/loss_real'] = d_loss_real.data[0]
            self.loss['D/loss_fake'] = d_loss_fake.data[0]
            self.loss[
                'D/loss_cls'] = self.config.lambda_cls * d_loss_cls.data[0]
            self.loss[
                'D/loss_gp'] = self.config.lambda_gp * d_loss_gp.data[0]

            # ================== Train G ================== #
            if (i + 1) % self.config.d_train_repeat == 0:
                self.img = {}
                # Original-to-target and target-to-original domain
                if self.config.use_gpb:
                    fake_x, id_vector_real_in_x = self.G(real_x, fake_c)
                    rec_x, id_vector_fake_in_x = self.G(
                        fake_x.detach(), real_c)
                else:
                    fake_x = self.G(real_x, fake_c)
                    rec_x = self.G(fake_x.detach(), real_c)
                # Compute losses
                if self.config.loss_id_cls:
                    out_src, out_cls, out_id_fake = self.D(fake_x)
                else:
                    out_src, out_cls = self.D(fake_x)
                g_loss_fake = -torch.mean(out_src)
                g_loss_rec = torch.mean(torch.abs(real_x - rec_x))

                ### siamese loss: contrastive-style term pulling same-identity
                ### pairs together and pushing other-identity pairs apart.
                if self.config.use_si:
                    if self.config.use_gpb:
                        # feedforward
                        fake_ox, id_vector_ox = self.G(real_ox, fake_c)
                        fake_oo, id_vector_oo = self.G(real_oo, fake_c)
                        id_vector_ox = id_vector_ox.detach()
                        id_vector_oo = id_vector_oo.detach()
                        # Margin term (margin = 1.0), clamped at zero.
                        mdist = 1.0 - torch.mean(
                            torch.abs(id_vector_real_in_x - id_vector_oo))
                        mdist = torch.clamp(mdist, min=0.0)
                        g_loss_si = 0.5 * (torch.pow(
                            torch.mean(
                                torch.abs(id_vector_real_in_x -
                                          id_vector_ox)), 2) +
                                           torch.pow(mdist, 2))
                        # backward
                        _, id_vector_ox = self.G(fake_ox.detach(), real_c)
                        _, id_vector_oo = self.G(fake_oo.detach(), real_c)
                        id_vector_ox = id_vector_ox.detach()
                        id_vector_oo = id_vector_oo.detach()
                        mdist = 1.0 - torch.mean(
                            torch.abs(id_vector_fake_in_x - id_vector_oo))
                        mdist = torch.clamp(mdist, min=0.0)
                        g_loss_si += 0.5 * (torch.pow(
                            torch.mean(
                                torch.abs(id_vector_fake_in_x -
                                          id_vector_ox)), 2) +
                                            torch.pow(mdist, 2))
                        self.loss['G/g_loss_si'] = g_loss_si.data[0]
                    else:
                        # Image-space variant: compare generated images
                        # instead of identity vectors.
                        fake_ox = self.G(real_ox, fake_c).detach()
                        fake_ooc = fake_c.data.cpu().numpy().copy()
                        # Random label shift so the "other" sample gets a
                        # different target domain.
                        fake_ooc = np.roll(fake_ooc,
                                           np.random.randint(
                                               self.config.c_dim),
                                           axis=1)
                        fake_ooc = self.to_var(torch.FloatTensor(fake_ooc))
                        fake_oo = self.G(real_oo, fake_ooc).detach()
                        mdist = 1.0 - torch.mean(
                            torch.abs(fake_x - fake_oo))
                        mdist = torch.clamp(mdist, min=0.0)
                        g_loss_si = 0.5 * (torch.pow(
                            torch.mean(torch.abs(fake_x - fake_ox)), 2) +
                                           torch.pow(mdist, 2))
                        self.loss['G/g_loss_si'] = g_loss_si.data[0]
                else:
                    g_loss_si = 0.0

                ### id cls loss (identity classification on the fake image)
                if self.config.loss_id_cls:
                    g_loss_id_cls = self.id_cls_criterion(
                        out_id_fake, identity)
                    self.loss[
                        'G/g_loss_id_cls'] = self.config.lambda_id_cls * g_loss_id_cls.data[
                            0]
                else:
                    g_loss_id_cls = 0.0

                ### sym loss (left/right symmetry on images and Laplacians)
                if self.config.loss_symmetry:
                    g_loss_sym_fake = self.find_sym_img_and_cal_loss(
                        fake_x, fake_c,
                        True)  # cal. over samples w/ specific labels
                    g_loss_sym_rec = self.find_sym_img_and_cal_loss(
                        rec_x, real_c, True)
                    lap_fake_x = self.take_laplacian(fake_x)
                    lap_rec_x = self.take_laplacian(rec_x)
                    g_loss_sym_lap_fake = self.find_sym_img_and_cal_loss(
                        lap_fake_x, None, False)  # cal. over all samples
                    g_loss_sym_lap_rec = self.find_sym_img_and_cal_loss(
                        lap_rec_x, None, False)
                    sym_loss = (g_loss_sym_fake + g_loss_sym_rec +
                                g_loss_sym_lap_fake + g_loss_sym_lap_rec)
                    self.loss[
                        'G/g_loss_sym'] = self.config.lambda_symmetry * sym_loss.data[
                            0]
                else:
                    sym_loss = 0

                ### id loss (translating to the source label should be identity)
                if self.config.loss_id:
                    if self.config.use_gpb:
                        idx, _ = self.G(real_x, real_c)
                    else:
                        idx = self.G(real_x, real_c)
                    self.img['idx'] = idx
                    g_loss_id = torch.mean(torch.abs(real_x - idx))
                    self.loss[
                        'G/g_loss_id'] = self.config.lambda_idx * g_loss_id.data[
                            0]
                else:
                    g_loss_id = 0

                ### identity loss (feature matching through a feature extractor)
                if self.config.loss_identity:
                    real_x_f, real_x_p = self.get_feature(real_x)
                    fake_x_f, fake_x_p = self.get_feature(fake_x)
                    g_loss_identity = torch.mean(
                        torch.abs(real_x_f - fake_x_f))
                    g_loss_identity += torch.mean(
                        torch.abs(real_x_p - fake_x_p))
                    self.loss[
                        'G/g_loss_identity'] = self.config.lambda_identity * g_loss_identity.data[
                            0]
                else:
                    g_loss_identity = 0

                ### total var loss (smoothness on fake and reconstruction)
                if self.config.loss_tv:
                    g_tv_loss = (self.total_variation_loss(fake_x) +
                                 self.total_variation_loss(rec_x)) / 2
                    self.loss[
                        'G/tv_loss'] = self.config.lambda_tv * g_tv_loss.data[
                            0]
                else:
                    g_tv_loss = 0

                ### D's cls loss
                g_loss_cls = F.binary_cross_entropy_with_logits(
                    out_cls, fake_label,
                    size_average=False) / fake_x.size(0)

                # Backward + Optimize: weighted sum of all enabled terms
                # (disabled terms contribute 0).
                g_loss = g_loss_fake +\
                    self.config.lambda_rec * g_loss_rec +\
                    self.config.lambda_cls * g_loss_cls+\
                    self.config.lambda_idx * g_loss_id+\
                    self.config.lambda_identity*g_loss_identity+\
                    self.config.lambda_tv*g_tv_loss+\
                    self.config.lambda_symmetry*sym_loss+\
                    self.config.lambda_id_cls * g_loss_id_cls+\
                    self.config.lambda_si * g_loss_si
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Logging
                self.img['real_x'] = real_x
                self.img['fake_x'] = fake_x
                self.img['rec_x'] = rec_x
                self.loss['G/loss_fake'] = g_loss_fake.data[0]
                self.loss[
                    'G/loss_rec'] = self.config.lambda_rec * g_loss_rec.data[
                        0]
                self.loss[
                    'G/loss_cls'] = self.config.lambda_cls * g_loss_cls.data[
                        0]

            # Print out log info
            if (i + 1) % self.config.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                    elapsed, e + 1, self.config.num_epochs, i + 1,
                    iters_per_epoch)
                for tag, value in self.loss.items():
                    log += ", {}: {}".format(tag, value)
                print(log)
                if self.config.use_tensorboard:
                    for tag, value in self.loss.items():
                        self.logger.scalar_summary(
                            tag, value, e * iters_per_epoch + i + 1)

            # Translate fixed images for debugging
            # NOTE(review): uses (i) not (i + 1) — unlike every other
            # periodic check here, this fires at i == 0 of each epoch;
            # confirm whether that is intended.
            if (i) % self.config.sample_step == 0:
                fake_image_list = [fixed_x]
                for fixed_c in fixed_c_list:
                    if self.config.use_gpb:
                        fake_image_list.append(self.G(fixed_x, fixed_c)[0])
                    else:
                        fake_image_list.append(self.G(fixed_x, fixed_c))
                fake_images = torch.cat(fake_image_list, dim=3)
                if not self.config.log_space:
                    save_image(self.denorm(fake_images.data),
                               os.path.join(
                                   self.config.sample_path,
                                   '{}_{}_fake.png'.format(e + 1, i + 1)),
                               nrow=1,
                               padding=0)
                else:
                    # Map log-space images back to linear space:
                    # x -> exp(x * ln 256) - 1, rescaled into [0, 1].
                    fake_images = self.denorm(fake_images.data) * 255.0
                    fake_images = torch.pow(
                        2.71828182846,
                        fake_images / 255.0 * np.log(256.0)) - 1.0
                    fake_images = fake_images / 255.0
                    fake_images = fake_images.clamp(0.0, 1.0)
                    save_image(fake_images,
                               os.path.join(
                                   self.config.sample_path,
                                   '{}_{}_fake.png'.format(e + 1, i + 1)),
                               nrow=1,
                               padding=0)
                print('Translated images and saved into {}..!'.format(
                    self.config.sample_path))

            # Save model checkpoints
            if (i + 1) % self.config.model_save_step == 0:
                torch.save(
                    self.G.state_dict(),
                    os.path.join(self.config.model_save_path,
                                 '{}_{}_G.pth'.format(e + 1, i + 1)))
                torch.save(
                    self.D.state_dict(),
                    os.path.join(self.config.model_save_path,
                                 '{}_{}_D.pth'.format(e + 1, i + 1)))

            # NOTE(review): self.img is only (re)built inside the G branch —
            # if display_f triggers before the first G update, this may read
            # a stale or missing self.img; confirm.
            if self.config.visualize and (i + 1) % self.config.display_f == 0:
                visualizer.display_current_results(self.img)
                visualizer.plot_current_errors(
                    e,
                    float(i + 1) / iters_per_epoch, self.loss)

        # Decay learning rate (linear decay over the final decay epochs)
        if (e + 1) > (self.config.num_epochs - self.config.num_epochs_decay):
            g_lr -= (self.config.g_lr / float(self.config.num_epochs_decay))
            d_lr -= (self.config.d_lr / float(self.config.num_epochs_decay))
            self.update_lr(g_lr, d_lr)
            print('Decay learning rate to g_lr: {}, d_lr: {}.'.format(
                g_lr, d_lr))
def train(self):
    """Train StarGAN within a single dataset.

    Variant with an auxiliary segmentation network ``self.A``: the data
    loader yields (image, seg_index, seg_map, label) tuples, the generator
    is conditioned on both a domain label and a segmentation map, and a
    segmentation-consistency loss (``g_loss_s``) is added to G's objective.
    Side effects: updates self.G / self.D / self.A in place, writes sample
    images and checkpoints to disk, and logs to tensorboard / a visualizer.
    """
    # Set dataloader
    if self.dataset == 'CelebA':
        self.data_loader = self.celebA_loader
    else:
        self.data_loader = self.rafd_loader

    # The number of iterations per epoch
    iters_per_epoch = len(self.data_loader)

    # Collect the first 4 mini-batches as fixed debug inputs.
    fixed_x = []
    real_c = []
    fixed_s = []
    for i, (images, seg_i, seg, labels) in enumerate(self.data_loader):
        fixed_x.append(images)
        fixed_s.append(seg)
        real_c.append(labels)
        if i == 3:
            break

    # Fixed inputs and target domain labels for debugging
    fixed_x = torch.cat(fixed_x, dim=0)
    fixed_x = self.to_var(fixed_x, volatile=True)
    real_c = torch.cat(real_c, dim=0)
    fixed_s = torch.cat(fixed_s, dim=0)
    # fixed_s_list[0] is each image's own seg map; the next fixed_s_num
    # entries broadcast one randomly chosen seg map to the whole batch.
    fixed_s_list = []
    fixed_s_list.append(self.to_var(fixed_s, volatile=True))
    rand_idx = torch.randperm(fixed_s.size(0))
    fixed_s_num = 5
    fixed_s_vec = fixed_s[rand_idx][:fixed_s_num]
    for i in range(fixed_s_num):
        fixed_s_temp = fixed_s_vec[i].unsqueeze(0).repeat(fixed_s.size(0),1,1,1)
        fixed_s_temp = self.to_var(fixed_s_temp)
        fixed_s_list.append(fixed_s_temp)
    # for i in range(4):
    #     rand_idx = torch.randperm(fixed_s.size(0))
    #     fixed_s_temp = self.to_var(fixed_s[rand_idx], volatile=True)
    #     fixed_s_list.append(fixed_s_temp)

    if self.dataset == 'CelebA':
        fixed_c_list = self.make_celeb_labels(real_c)
    elif self.dataset == 'RaFD':
        fixed_c_list = []
        for i in range(self.c_dim):
            fixed_c = self.one_hot(torch.ones(fixed_x.size(0)) * i, self.c_dim)
            fixed_c_list.append(self.to_var(fixed_c, volatile=True))

    # lr cache for decaying
    g_lr = self.g_lr
    d_lr = self.d_lr

    # Start with trained model if exists
    # NOTE(review): here the resume epoch is parsed as `- 1`, while the
    # sibling train() below uses the parsed value directly — confirm which
    # checkpoint-naming convention is intended.
    if self.pretrained_model:
        start = int(self.pretrained_model.split('_')[0])-1
    else:
        start = 0

    # Start training
    start_time = time.time()
    for e in range(start, self.num_epochs):
        epoch_iter = 0
        for i, (real_x, real_s_i, real_s, real_label) in enumerate(self.data_loader):
            epoch_iter = epoch_iter + 1
            # Generat fake labels randomly (target domain labels)
            rand_idx = torch.randperm(real_label.size(0))
            fake_label = real_label[rand_idx]
            # NOTE(review): a *second* permutation is drawn for the fake
            # seg maps, so fake_label and fake_s come from different source
            # samples — confirm this decoupling is intentional.
            rand_idx = torch.randperm(real_label.size(0))
            fake_s = real_s[rand_idx]
            fake_s_i = real_s_i[rand_idx]

            if self.dataset == 'CelebA':
                real_c = real_label.clone()
                fake_c = fake_label.clone()
            else:
                real_c = self.one_hot(real_label, self.c_dim)
                fake_c = self.one_hot(fake_label, self.c_dim)

            # Convert tensor to variable
            real_x = self.to_var(real_x)
            real_s = self.to_var(real_s)
            real_s_i = self.to_var(real_s_i)
            fake_s = self.to_var(fake_s)
            fake_s_i = self.to_var(fake_s_i)
            real_c = self.to_var(real_c)  # input for the generator
            fake_c = self.to_var(fake_c)
            real_label = self.to_var(real_label)  # this is same as real_c if dataset == 'CelebA'
            fake_label = self.to_var(fake_label)

            # ================== Train D ================== #

            # Compute loss with real images
            out_src, out_cls = self.D(real_x)
            d_loss_real = - torch.mean(out_src)

            if self.dataset == 'CelebA':
                d_loss_cls = F.binary_cross_entropy_with_logits(
                    out_cls, real_label, size_average=False) / real_x.size(0)
            else:
                d_loss_cls = F.cross_entropy(out_cls, real_label)

            # Compute classification accuracy of the discriminator
            if (i+1) % self.log_step == 0:
                accuracies = self.compute_accuracy(out_cls, real_label, self.dataset)
                log = ["{:.2f}".format(acc) for acc in accuracies.data.cpu().numpy()]
                if self.dataset == 'CelebA':
                    print('Classification Acc (Black/Blond/Brown/Gender/Aged): ', end='')
                else:
                    print('Classification Acc (8 emotional expressions): ', end='')
                print(log)

            # Compute loss with fake images
            fake_x = self.G(real_x, fake_c, fake_s)
            # Re-wrapping detaches fake_x from G's graph so D's update
            # does not backprop into the generator.
            fake_x = Variable(fake_x.data)
            out_src, out_cls = self.D(fake_x)
            d_loss_fake = torch.mean(out_src)

            # Backward + Optimize
            d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Compute gradient penalty (WGAN-GP): penalize the gradient
            # norm of D at random interpolations of real and fake images.
            # Applied as a separate optimizer step after the main D step.
            alpha = torch.rand(real_x.size(0), 1, 1, 1).cuda().expand_as(real_x)
            interpolated = Variable(alpha * real_x.data + (1 - alpha) * fake_x.data, requires_grad=True)
            out, out_cls = self.D(interpolated)

            grad = torch.autograd.grad(outputs=out,
                                       inputs=interpolated,
                                       grad_outputs=torch.ones(out.size()).cuda(),
                                       retain_graph=True,
                                       create_graph=True,
                                       only_inputs=True)[0]

            grad = grad.view(grad.size(0), -1)
            grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
            d_loss_gp = torch.mean((grad_l2norm - 1)**2)

            # Backward + Optimize
            d_loss = self.lambda_gp * d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # ================== Train A ================== #
            # Auxiliary segmentation network trained on real images.
            self.a_optimizer.zero_grad()
            out_real_s = self.A(real_x)
            # a_loss = self.criterion_s(out_real_s, real_s_i.type(torch.cuda.LongTensor)) * self.lambda_s
            a_loss = self.criterion_s(out_real_s, real_s_i) * self.lambda_s
            # a_loss = torch.mean(torch.abs(real_s - out_real_s))
            a_loss.backward()
            self.a_optimizer.step()

            # Logging
            loss = {}
            loss['D/loss_real'] = d_loss_real.data[0]
            loss['D/loss_fake'] = d_loss_fake.data[0]
            loss['D/loss_cls'] = d_loss_cls.data[0]
            loss['D/loss_gp'] = d_loss_gp.data[0]

            # ================== Train G ================== #
            if (i+1) % self.d_train_repeat == 0:

                # Original-to-target and target-to-original domain
                fake_x = self.G(real_x, fake_c, fake_s)
                rec_x = self.G(fake_x, real_c, real_s)

                # Compute losses
                out_src, out_cls = self.D(fake_x)
                g_loss_fake = - torch.mean(out_src)
                # NOTE(review): lambda_rec is applied here at computation
                # time (unlike the sibling train() which weights it in the
                # g_loss sum), so the logged G/loss_rec is pre-weighted.
                g_loss_rec = self.lambda_rec * torch.mean(torch.abs(real_x - rec_x))

                if self.dataset == 'CelebA':
                    g_loss_cls = F.binary_cross_entropy_with_logits(
                        out_cls, fake_label, size_average=False) / fake_x.size(0)
                else:
                    g_loss_cls = F.cross_entropy(out_cls, fake_label)

                # segmentation loss
                out_fake_s = self.A(fake_x)
                g_loss_s = self.lambda_s * self.criterion_s(out_fake_s, fake_s_i)

                # Backward + Optimize
                g_loss = g_loss_fake + g_loss_rec + g_loss_s + self.lambda_cls * g_loss_cls
                # g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Logging
                loss['G/loss_fake'] = g_loss_fake.data[0]
                loss['G/loss_rec'] = g_loss_rec.data[0]
                loss['G/loss_cls'] = g_loss_cls.data[0]

            # NOTE(review): rec_x / g_loss / out_fake_s only exist after a
            # G step has run — visual_step should be a multiple of
            # d_train_repeat or this branch can raise NameError early on.
            if (i+1) % self.visual_step == 0:
                # save visuals
                self.real_x = real_x
                self.fake_x = fake_x
                self.rec_x = rec_x
                self.real_s = real_s
                self.fake_s = fake_s
                self.out_real_s = out_real_s
                self.out_fake_s = out_fake_s
                self.a_loss = a_loss

                # save losses
                self.d_real = - d_loss_real
                self.d_fake = d_loss_fake
                self.d_loss = d_loss
                self.g_loss = g_loss
                self.g_loss_fake = g_loss_fake
                self.g_loss_rec = g_loss_rec
                self.g_loss_s = g_loss_s

                errors_D = self.get_current_errors('D')
                errors_G = self.get_current_errors('G')
                self.visualizer.display_current_results(self.get_current_visuals(), e)
                self.visualizer.plot_current_errors_D(e, float(epoch_iter)/float(iters_per_epoch), errors_D)
                self.visualizer.plot_current_errors_G(e, float(epoch_iter)/float(iters_per_epoch), errors_G)

            # Print out log info
            if (i+1) % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                    elapsed, e+1, self.num_epochs, i+1, iters_per_epoch)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, e * iters_per_epoch + i + 1)

            # Translate fixed images for debugging
            if (i+1) % self.sample_step == 0:
                fake_image_list = [fixed_x]
                fixed_c = fixed_c_list[0]
                real_seg_list = []
                # One column per (domain label, seg map) combination.
                for fixed_c in fixed_c_list:
                    for fixed_s in fixed_s_list:
                        fake_image_list.append(self.G(fixed_x, fixed_c, fixed_s))
                        real_seg_list.append(fixed_s)
                fake_images = torch.cat(fake_image_list, dim=3)
                real_seg_images = torch.cat(real_seg_list, dim=3)
                save_image(self.denorm(fake_images.data),
                    os.path.join(self.sample_path, '{}_{}_fake.png'.format(e+1, i+1)),nrow=1, padding=0)
                save_image(self.cat2class_tensor(real_seg_images.data),
                    os.path.join(self.sample_path, '{}_{}_seg.png'.format(e+1, i+1)),nrow=1, padding=0)
                print('Translated images and saved into {}..!'.format(self.sample_path))

            # Save model checkpoints
            if (i+1) % self.model_save_step == 0:
                torch.save(self.G.state_dict(),
                    os.path.join(self.model_save_path, '{}_{}_G.pth'.format(e+1, i+1)))
                torch.save(self.D.state_dict(),
                    os.path.join(self.model_save_path, '{}_{}_D.pth'.format(e+1, i+1)))
                torch.save(self.A.state_dict(),
                    os.path.join(self.model_save_path, '{}_{}_A.pth'.format(e+1, i+1)))

        # Decay learning rate (linear decay over the last num_epochs_decay epochs)
        if (e+1) > (self.num_epochs - self.num_epochs_decay):
            g_lr -= (self.g_lr / float(self.num_epochs_decay))
            d_lr -= (self.d_lr / float(self.num_epochs_decay))
            self.update_lr(g_lr, d_lr)
            print ('Decay learning rate to g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
def train(self):
    """Train StarGAN within a single dataset.

    Standard single-dataset training loop: alternates WGAN-GP discriminator
    updates (adversarial + domain-classification + gradient-penalty losses)
    with generator updates every ``self.d_train_repeat`` iterations
    (adversarial + cycle-reconstruction + classification losses).
    Side effects: updates self.G / self.D in place, saves sample images and
    checkpoints, and logs to stdout / tensorboard.
    """
    # Set dataloader
    if self.dataset == 'CelebA':
        self.data_loader = self.celebA_loader
    else:
        self.data_loader = self.rafd_loader

    # The number of iterations per epoch
    iters_per_epoch = len(self.data_loader)

    # Collect the first 4 mini-batches as fixed debug inputs.
    fixed_x = []
    real_c = []
    for i, (images, labels) in enumerate(self.data_loader):
        fixed_x.append(images)
        real_c.append(labels)
        if i == 3:
            break

    # Fixed inputs and target domain labels for debugging
    fixed_x = torch.cat(fixed_x, dim=0)
    fixed_x = self.to_var(fixed_x, volatile=True)
    real_c = torch.cat(real_c, dim=0)

    if self.dataset == 'CelebA':
        fixed_c_list = self.make_celeb_labels(real_c)
    elif self.dataset == 'RaFD':
        fixed_c_list = []
        for i in range(self.c_dim):
            fixed_c = self.one_hot(torch.ones(fixed_x.size(0)) * i, self.c_dim)
            fixed_c_list.append(self.to_var(fixed_c, volatile=True))

    # lr cache for decaying
    g_lr = self.g_lr
    d_lr = self.d_lr

    # Start with trained model if exists
    # (checkpoint names look like '<epoch>_<iter>_G.pth'; resume from that epoch)
    if self.pretrained_model:
        start = int(self.pretrained_model.split('_')[0])
    else:
        start = 0

    # Start training
    start_time = time.time()
    for e in range(start, self.num_epochs):
        for i, (real_x, real_label) in enumerate(self.data_loader):
            # Generat fake labels randomly (target domain labels)
            rand_idx = torch.randperm(real_label.size(0))
            fake_label = real_label[rand_idx]

            if self.dataset == 'CelebA':
                real_c = real_label.clone()
                fake_c = fake_label.clone()
            else:
                real_c = self.one_hot(real_label, self.c_dim)
                fake_c = self.one_hot(fake_label, self.c_dim)

            # Convert tensor to variable
            real_x = self.to_var(real_x)
            real_c = self.to_var(real_c)  # input for the generator
            fake_c = self.to_var(fake_c)
            real_label = self.to_var(real_label)  # this is same as real_c if dataset == 'CelebA'
            fake_label = self.to_var(fake_label)

            # ================== Train D ================== #

            # Compute loss with real images
            out_src, out_cls = self.D(real_x)
            d_loss_real = - torch.mean(out_src)

            if self.dataset == 'CelebA':
                d_loss_cls = F.binary_cross_entropy_with_logits(
                    out_cls, real_label, size_average=False) / real_x.size(0)
            else:
                d_loss_cls = F.cross_entropy(out_cls, real_label)

            # Compute classification accuracy of the discriminator
            if (i+1) % self.log_step == 0:
                accuracies = self.compute_accuracy(out_cls, real_label, self.dataset)
                log = ["{:.2f}".format(acc) for acc in accuracies.data.cpu().numpy()]
                if self.dataset == 'CelebA':
                    print('Classification Acc (Black/Blond/Brown/Gender/Aged): ', end='')
                else:
                    print('Classification Acc (8 emotional expressions): ', end='')
                print(log)

            # Compute loss with fake images
            fake_x = self.G(real_x, fake_c)
            # Re-wrapping detaches fake_x so D's update cannot reach G.
            fake_x = Variable(fake_x.data)
            out_src, out_cls = self.D(fake_x)
            d_loss_fake = torch.mean(out_src)

            # Backward + Optimize
            d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Compute gradient penalty (WGAN-GP), applied as a second,
            # separate optimizer step on D.
            alpha = torch.rand(real_x.size(0), 1, 1, 1).cuda().expand_as(real_x)
            interpolated = Variable(alpha * real_x.data + (1 - alpha) * fake_x.data, requires_grad=True)
            out, out_cls = self.D(interpolated)

            grad = torch.autograd.grad(outputs=out,
                                       inputs=interpolated,
                                       grad_outputs=torch.ones(out.size()).cuda(),
                                       retain_graph=True,
                                       create_graph=True,
                                       only_inputs=True)[0]

            grad = grad.view(grad.size(0), -1)
            grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
            d_loss_gp = torch.mean((grad_l2norm - 1)**2)

            # Backward + Optimize
            d_loss = self.lambda_gp * d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Logging
            loss = {}
            loss['D/loss_real'] = d_loss_real.data[0]
            loss['D/loss_fake'] = d_loss_fake.data[0]
            loss['D/loss_cls'] = d_loss_cls.data[0]
            loss['D/loss_gp'] = d_loss_gp.data[0]

            # ================== Train G ================== #
            if (i+1) % self.d_train_repeat == 0:

                # Original-to-target and target-to-original domain
                fake_x = self.G(real_x, fake_c)
                rec_x = self.G(fake_x, real_c)

                # Compute losses
                out_src, out_cls = self.D(fake_x)
                g_loss_fake = - torch.mean(out_src)
                g_loss_rec = torch.mean(torch.abs(real_x - rec_x))

                if self.dataset == 'CelebA':
                    g_loss_cls = F.binary_cross_entropy_with_logits(
                        out_cls, fake_label, size_average=False) / fake_x.size(0)
                else:
                    g_loss_cls = F.cross_entropy(out_cls, fake_label)

                # Backward + Optimize
                g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Logging (G/loss_rec is the unweighted L1 reconstruction)
                loss['G/loss_fake'] = g_loss_fake.data[0]
                loss['G/loss_rec'] = g_loss_rec.data[0]
                loss['G/loss_cls'] = g_loss_cls.data[0]

            # Print out log info
            if (i+1) % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                    elapsed, e+1, self.num_epochs, i+1, iters_per_epoch)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, e * iters_per_epoch + i + 1)

            # Translate fixed images for debugging
            if (i+1) % self.sample_step == 0:
                fake_image_list = [fixed_x]
                for fixed_c in fixed_c_list:
                    fake_image_list.append(self.G(fixed_x, fixed_c))
                fake_images = torch.cat(fake_image_list, dim=3)
                save_image(self.denorm(fake_images.data),
                    os.path.join(self.sample_path, '{}_{}_fake.png'.format(e+1, i+1)),nrow=1, padding=0)
                print('Translated images and saved into {}..!'.format(self.sample_path))

            # Save model checkpoints
            if (i+1) % self.model_save_step == 0:
                torch.save(self.G.state_dict(),
                    os.path.join(self.model_save_path, '{}_{}_G.pth'.format(e+1, i+1)))
                torch.save(self.D.state_dict(),
                    os.path.join(self.model_save_path, '{}_{}_D.pth'.format(e+1, i+1)))

        # Decay learning rate (linear decay over the last num_epochs_decay epochs)
        if (e+1) > (self.num_epochs - self.num_epochs_decay):
            g_lr -= (self.g_lr / float(self.num_epochs_decay))
            d_lr -= (self.d_lr / float(self.num_epochs_decay))
            self.update_lr(g_lr, d_lr)
            print ('Decay learning rate to g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
def train_multi(self):
    """Train StarGAN with multiple datasets.
    In the code below, 1 is related to CelebA and 2 is releated to RaFD.

    The generator's condition vector is [celebA labels, rafd labels, mask],
    where the mask one-hot selects which label block is valid; the unused
    block is zero-filled. Iteration is by global step (self.num_iters), not
    by epoch. Side effects: updates self.G / self.D, writes samples and
    checkpoints, logs to stdout / tensorboard.
    """
    # Fixed imagse and labels for debugging
    fixed_x = []
    real_c = []
    for i, (images, labels) in enumerate(self.celebA_loader):
        fixed_x.append(images)
        real_c.append(labels)
        if i == 2:
            break

    fixed_x = torch.cat(fixed_x, dim=0)
    fixed_x = self.to_var(fixed_x, volatile=True)
    real_c = torch.cat(real_c, dim=0)
    fixed_c1_list = self.make_celeb_labels(real_c)

    fixed_c2_list = []
    for i in range(self.c2_dim):
        fixed_c = self.one_hot(torch.ones(fixed_x.size(0)) * i, self.c2_dim)
        fixed_c2_list.append(self.to_var(fixed_c, volatile=True))

    fixed_zero1 = self.to_var(torch.zeros(fixed_x.size(0), self.c2_dim))  # zero vector when training with CelebA
    fixed_mask1 = self.to_var(self.one_hot(torch.zeros(fixed_x.size(0)), 2))  # mask vector: [1, 0]
    fixed_zero2 = self.to_var(torch.zeros(fixed_x.size(0), self.c_dim))  # zero vector when training with RaFD
    fixed_mask2 = self.to_var(self.one_hot(torch.ones(fixed_x.size(0)), 2))  # mask vector: [0, 1]

    # lr cache for decaying
    g_lr = self.g_lr
    d_lr = self.d_lr

    # data iterator
    data_iter1 = iter(self.celebA_loader)
    data_iter2 = iter(self.rafd_loader)

    # Start with trained model
    if self.pretrained_model:
        start = int(self.pretrained_model) + 1
    else:
        start = 0

    # # Start training
    start_time = time.time()
    for i in range(start, self.num_iters):

        # Fetch mini-batch images and labels; restart an iterator when
        # its dataset is exhausted (the two loaders differ in length).
        try:
            real_x1, real_label1 = next(data_iter1)
        except:
            data_iter1 = iter(self.celebA_loader)
            real_x1, real_label1 = next(data_iter1)

        try:
            real_x2, real_label2 = next(data_iter2)
        except:
            data_iter2 = iter(self.rafd_loader)
            real_x2, real_label2 = next(data_iter2)

        # Generate fake labels randomly (target domain labels)
        rand_idx = torch.randperm(real_label1.size(0))
        fake_label1 = real_label1[rand_idx]
        rand_idx = torch.randperm(real_label2.size(0))
        fake_label2 = real_label2[rand_idx]

        real_c1 = real_label1.clone()
        fake_c1 = fake_label1.clone()
        zero1 = torch.zeros(real_x1.size(0), self.c2_dim)
        mask1 = self.one_hot(torch.zeros(real_x1.size(0)), 2)

        real_c2 = self.one_hot(real_label2, self.c2_dim)
        fake_c2 = self.one_hot(fake_label2, self.c2_dim)
        zero2 = torch.zeros(real_x2.size(0), self.c_dim)
        mask2 = self.one_hot(torch.ones(real_x2.size(0)), 2)

        # Convert tensor to variable
        real_x1 = self.to_var(real_x1)
        real_c1 = self.to_var(real_c1)
        fake_c1 = self.to_var(fake_c1)
        mask1 = self.to_var(mask1)
        zero1 = self.to_var(zero1)

        real_x2 = self.to_var(real_x2)
        real_c2 = self.to_var(real_c2)
        fake_c2 = self.to_var(fake_c2)
        mask2 = self.to_var(mask2)
        zero2 = self.to_var(zero2)

        real_label1 = self.to_var(real_label1)
        fake_label1 = self.to_var(fake_label1)
        real_label2 = self.to_var(real_label2)
        fake_label2 = self.to_var(fake_label2)

        # ================== Train D ================== #

        # Real images (CelebA)
        out_real, out_cls = self.D(real_x1)
        out_cls1 = out_cls[:, :self.c_dim]  # celebA part
        d_loss_real = - torch.mean(out_real)
        d_loss_cls = F.binary_cross_entropy_with_logits(out_cls1, real_label1, size_average=False) / real_x1.size(0)

        # Real images (RaFD)
        out_real, out_cls = self.D(real_x2)
        out_cls2 = out_cls[:, self.c_dim:]  # rafd part
        d_loss_real += - torch.mean(out_real)
        d_loss_cls += F.cross_entropy(out_cls2, real_label2)

        # Compute classification accuracy of the discriminator
        if (i+1) % self.log_step == 0:
            accuracies = self.compute_accuracy(out_cls1, real_label1, 'CelebA')
            log = ["{:.2f}".format(acc) for acc in accuracies.data.cpu().numpy()]
            print('Classification Acc (Black/Blond/Brown/Gender/Aged): ', end='')
            print(log)
            accuracies = self.compute_accuracy(out_cls2, real_label2, 'RaFD')
            log = ["{:.2f}".format(acc) for acc in accuracies.data.cpu().numpy()]
            print('Classification Acc (8 emotional expressions): ', end='')
            print(log)

        # Fake images (CelebA)
        fake_c = torch.cat([fake_c1, zero1, mask1], dim=1)
        fake_x1 = self.G(real_x1, fake_c)
        # Detach so D's update does not backprop into G.
        fake_x1 = Variable(fake_x1.data)
        out_fake, _ = self.D(fake_x1)
        d_loss_fake = torch.mean(out_fake)

        # Fake images (RaFD)
        # NOTE(review): fake_x2 is *not* re-wrapped/detached here, unlike
        # fake_x1 above — confirm whether that asymmetry is intended.
        fake_c = torch.cat([zero2, fake_c2, mask2], dim=1)
        fake_x2 = self.G(real_x2, fake_c)
        out_fake, _ = self.D(fake_x2)
        d_loss_fake += torch.mean(out_fake)

        # Backward + Optimize
        d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
        self.reset_grad()
        d_loss.backward()
        self.d_optimizer.step()

        # Compute gradient penalty, alternating between the two datasets
        # by step parity; applied as a second, separate D optimizer step.
        if (i+1) % 2 == 0:
            real_x = real_x1
            fake_x = fake_x1
        else:
            real_x = real_x2
            fake_x = fake_x2

        alpha = torch.rand(real_x.size(0), 1, 1, 1).cuda().expand_as(real_x)
        interpolated = Variable(alpha * real_x.data + (1 - alpha) * fake_x.data, requires_grad=True)
        out, out_cls = self.D(interpolated)

        if (i+1) % 2 == 0:
            out_cls = out_cls[:, :self.c_dim]  # CelebA
        else:
            out_cls = out_cls[:, self.c_dim:]  # RaFD

        grad = torch.autograd.grad(outputs=out,
                                   inputs=interpolated,
                                   grad_outputs=torch.ones(out.size()).cuda(),
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]

        grad = grad.view(grad.size(0), -1)
        grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
        d_loss_gp = torch.mean((grad_l2norm - 1)**2)

        # Backward + Optimize
        d_loss = self.lambda_gp * d_loss_gp
        self.reset_grad()
        d_loss.backward()
        self.d_optimizer.step()

        # Logging
        loss = {}
        loss['D/loss_real'] = d_loss_real.data[0]
        loss['D/loss_fake'] = d_loss_fake.data[0]
        loss['D/loss_cls'] = d_loss_cls.data[0]
        loss['D/loss_gp'] = d_loss_gp.data[0]

        # ================== Train G ================== #
        if (i+1) % self.d_train_repeat == 0:

            # Original-to-target and target-to-original domain (CelebA)
            fake_c = torch.cat([fake_c1, zero1, mask1], dim=1)
            real_c = torch.cat([real_c1, zero1, mask1], dim=1)
            fake_x1 = self.G(real_x1, fake_c)
            rec_x1 = self.G(fake_x1, real_c)

            # Compute losses
            out, out_cls = self.D(fake_x1)
            out_cls1 = out_cls[:, :self.c_dim]
            g_loss_fake = - torch.mean(out)
            g_loss_rec = torch.mean(torch.abs(real_x1 - rec_x1))
            g_loss_cls = F.binary_cross_entropy_with_logits(out_cls1, fake_label1, size_average=False) / fake_x1.size(0)

            # Original-to-target and target-to-original domain (RaFD)
            fake_c = torch.cat([zero2, fake_c2, mask2], dim=1)
            real_c = torch.cat([zero2, real_c2, mask2], dim=1)
            fake_x2 = self.G(real_x2, fake_c)
            rec_x2 = self.G(fake_x2, real_c)

            # Compute losses
            out, out_cls = self.D(fake_x2)
            out_cls2 = out_cls[:, self.c_dim:]
            g_loss_fake += - torch.mean(out)
            g_loss_rec += torch.mean(torch.abs(real_x2 - rec_x2))
            g_loss_cls += F.cross_entropy(out_cls2, fake_label2)

            # Backward + Optimize
            g_loss = g_loss_fake + self.lambda_cls * g_loss_cls + self.lambda_rec * g_loss_rec
            self.reset_grad()
            g_loss.backward()
            self.g_optimizer.step()

            # Logging
            loss['G/loss_fake'] = g_loss_fake.data[0]
            loss['G/loss_cls'] = g_loss_cls.data[0]
            loss['G/loss_rec'] = g_loss_rec.data[0]

        # Print out log info
        if (i+1) % self.log_step == 0:
            elapsed = time.time() - start_time
            elapsed = str(datetime.timedelta(seconds=elapsed))
            log = "Elapsed [{}], Iter [{}/{}]".format(
                elapsed, i+1, self.num_iters)
            for tag, value in loss.items():
                log += ", {}: {:.4f}".format(tag, value)
            print(log)

            if self.use_tensorboard:
                for tag, value in loss.items():
                    self.logger.scalar_summary(tag, value, i+1)

        # Translate the images (debugging)
        if (i+1) % self.sample_step == 0:
            fake_image_list = [fixed_x]

            # Changing hair color, gender, and age
            for j in range(self.c_dim):
                fake_c = torch.cat([fixed_c1_list[j], fixed_zero1, fixed_mask1], dim=1)
                fake_image_list.append(self.G(fixed_x, fake_c))
            # Changing emotional expressions
            for j in range(self.c2_dim):
                fake_c = torch.cat([fixed_zero2, fixed_c2_list[j], fixed_mask2], dim=1)
                fake_image_list.append(self.G(fixed_x, fake_c))
            fake = torch.cat(fake_image_list, dim=3)

            # Save the translated images
            save_image(self.denorm(fake.data),
                os.path.join(self.sample_path, '{}_fake.png'.format(i+1)), nrow=1, padding=0)

        # Save model checkpoints
        if (i+1) % self.model_save_step == 0:
            torch.save(self.G.state_dict(),
                os.path.join(self.model_save_path, '{}_G.pth'.format(i+1)))
            torch.save(self.D.state_dict(),
                os.path.join(self.model_save_path, '{}_D.pth'.format(i+1)))

        # Decay learning rate every decay_step iterations near the end.
        decay_step = 1000
        if (i+1) > (self.num_iters - self.num_iters_decay) and (i+1) % decay_step==0:
            g_lr -= (self.g_lr / float(self.num_iters_decay) * decay_step)
            d_lr -= (self.d_lr / float(self.num_iters_decay) * decay_step)
            self.update_lr(g_lr, d_lr)
            print ('Decay learning rate to g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
def train(self):
    """Train an unconditional WGAN-GP.

    Each iteration: one discriminator (critic) update with the Wasserstein
    loss, a second discriminator update with the gradient penalty, and a
    generator update every ``self.D_train_step`` iterations. Samples and
    checkpoints are written per epoch. Side effects: updates self.G /
    self.D in place, writes images/checkpoints, logs to stdout/tensorboard.
    """
    print(len(self.data_loader))
    for e in range(self.epoch):
        for i, batch_images in enumerate(self.data_loader):
            batch_size = batch_images.size(0)
            # NOTE(review): `label` is created but never used below.
            label = torch.FloatTensor(batch_size)
            real_x = self.to_var(batch_images)
            noise_x = self.to_var(
                torch.FloatTensor(noise_vector(batch_size, self.noise_n)))

            # train D
            fake_x = self.G(noise_x)
            real_out = self.D(real_x)
            # detach() blocks gradients from flowing into G during D's step.
            fake_out = self.D(fake_x.detach())
            D_real = -torch.mean(real_out)
            D_fake = torch.mean(fake_out)
            D_loss = D_real + D_fake
            self.reset_grad()
            D_loss.backward()
            self.D_optimizer.step()

            # Log
            loss = {}
            loss['D/loss_real'] = D_real.data[0]
            loss['D/loss_fake'] = D_fake.data[0]
            loss['D/loss'] = D_loss.data[0]

            # choose one in below two

            # Clip weights of D
            # for p in self.D.parameters():
            #     p.data.clamp_(-self.clip_value, clip_value)

            # Gradients penalty, WGAP-GP
            alpha = torch.rand(real_x.size(0), 1, 1, 1).cuda().expand_as(real_x)
            # print(alpha.shape, real_x.shape, fake_x.shape)
            interpolated = Variable(alpha * real_x.data + (1 - alpha) * fake_x.data, requires_grad=True)
            gp_out = self.D(interpolated)

            grad = torch.autograd.grad(outputs=gp_out,
                                       inputs=interpolated,
                                       grad_outputs=torch.ones(gp_out.size()).cuda(),
                                       retain_graph=True,
                                       create_graph=True,
                                       only_inputs=True)[0]

            grad = grad.view(grad.size(0), -1)
            grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
            d_loss_gp = torch.mean((grad_l2norm - 1)**2)

            # Backward + Optimize (second D step, penalty only)
            d_loss = self.lambda_gp * d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.D_optimizer.step()

            # Train G
            if (i + 1) % self.D_train_step == 0:
                fake_out = self.D(self.G(noise_x))
                G_loss = -torch.mean(fake_out)
                self.reset_grad()
                G_loss.backward()
                self.G_optimizer.step()
                loss['G/loss'] = G_loss.data[0]

            # Print log
            if (i + 1) % self.log_step == 0:
                log = "Epoch: {}/{}, Iter: {}/{}".format(
                    e + 1, self.epoch, i + 1, len(self.data_loader))
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)
                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(
                            tag, value, e * len(self.data_loader) + i + 1)

        # Save images (a fixed grid of 16 samples, once per epoch)
        if (e + 1) % self.save_image_step == 0:
            noise_x = self.to_var(
                torch.FloatTensor(noise_vector(16, self.noise_n)))
            fake_image = self.G(noise_x)
            save_image(
                self.denorm(fake_image.data),
                os.path.join(self.image_save_path,
                             "{}_fake.png".format(e + 1)))

        if (e + 1) % self.model_save_step == 0:
            torch.save(
                self.G.state_dict(),
                os.path.join(self.model_save_path,
                             "{}_G.pth".format(e + 1)))
            torch.save(
                self.D.state_dict(),
                os.path.join(self.model_save_path,
                             "{}_D.pth".format(e + 1)))
def train(self):
    """Train StarGAN within a single dataset.

    Variant whose discriminator returns a third value (intermediate feature
    maps) alongside the source score and class logits; the feature maps feed
    the commented-out feature-reconstruction / feature-matching experiments
    kept below as string literals. Supports CelebA / RaFD / fer2013 /
    ferg_db. Side effects: updates self.G / self.D in place, writes samples
    and checkpoints, logs to stdout / tensorboard.
    """
    # Set dataloader
    if self.dataset == 'CelebA':
        self.data_loader = self.celebA_loader
    elif self.dataset == 'RaFD':
        self.data_loader = self.rafd_loader
    elif self.dataset == 'fer2013':
        self.data_loader = self.fer2013_loader
    elif self.dataset == 'ferg_db':
        self.data_loader = self.ferg_db_loader

    # The number of iterations per epoch
    iters_per_epoch = len(self.data_loader)

    # Collect the first 4 mini-batches as fixed debug inputs.
    fixed_x = []
    real_c = []
    for i, (images, labels) in enumerate(self.data_loader):
        fixed_x.append(images)
        real_c.append(labels)
        if i == 3:
            break

    # Fixed inputs and target domain labels for debugging
    fixed_x = torch.cat(fixed_x, dim=0)
    fixed_x = self.to_var(fixed_x, volatile=True)
    real_c = torch.cat(real_c, dim=0)

    if self.dataset in ['CelebA']:
        fixed_c_list = self.make_celeb_labels(real_c)
    elif self.dataset in ['RaFD', 'fer2013', 'ferg_db']:
        fixed_c_list = []
        for i in range(self.c_dim):
            fixed_c = self.one_hot(
                torch.ones(fixed_x.size(0)) * i, self.c_dim)
            fixed_c_list.append(self.to_var(fixed_c, volatile=True))

    # lr cache for decaying
    g_lr = self.g_lr
    d_lr = self.d_lr

    # Start with trained model if exists
    if self.pretrained_model:
        start = int(self.pretrained_model.split('_')[0])
    else:
        start = 0

    # Start training
    start_time = time.time()
    for e in range(start, self.num_epochs):
        for i, (real_x, real_label) in enumerate(self.data_loader):
            # Generat fake labels randomly (target domain labels)
            rand_idx = torch.randperm(real_label.size(0))
            fake_label = real_label[rand_idx]

            if self.dataset == 'CelebA':
                real_c = real_label.clone()
                fake_c = fake_label.clone()
            else:
                real_c = self.one_hot(real_label, self.c_dim)
                fake_c = self.one_hot(fake_label, self.c_dim)

            # Convert tensor to variable
            real_x = self.to_var(real_x)
            real_c = self.to_var(real_c)  # input for the generator
            fake_c = self.to_var(fake_c)
            real_label = self.to_var(
                real_label
            )  # this is same as real_c if dataset == 'CelebA'
            fake_label = self.to_var(fake_label)

            # ================== Train D ================== #

            # Compute loss with real images (D also returns feature maps)
            out_src, out_cls, out_feats_real = self.D(real_x)
            d_loss_real = -torch.mean(out_src)

            if self.dataset == 'CelebA':
                d_loss_cls = F.binary_cross_entropy_with_logits(
                    out_cls, real_label, size_average=False) / real_x.size(0)
            else:
                d_loss_cls = F.cross_entropy(out_cls, real_label)

            # Compute classification accuracy of the discriminator
            if (i + 1) % self.log_step == 0:
                accuracies = self.compute_accuracy(out_cls, real_label,
                                                   self.dataset)
                log = [
                    "{:.2f}".format(acc)
                    for acc in accuracies.data.cpu().numpy()
                ]
                if self.dataset == 'CelebA':
                    print(
                        'Classification Acc (Black/Blond/Brown/Gender/Aged): ',
                        end='')
                elif self.dataset in ['fer2013', 'ferg_db']:
                    print('Classification Acc (7 emotional expressions): ',
                          end='')
                else:
                    print('Classification Acc (8 emotional expressions): ',
                          end='')
                print(log)

            # Compute loss with fake images
            fake_x = self.G(real_x, fake_c)
            # Re-wrapping detaches fake_x so D's update cannot reach G.
            fake_x = Variable(fake_x.data)
            out_src, out_cls, out_feats_fake = self.D(fake_x)
            d_loss_fake = torch.mean(out_src)

            # Backward + Optimize
            # retain_graph=True keeps the graph alive so out_feats_real can
            # still participate in later backward passes this iteration.
            d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
            self.reset_grad()
            d_loss.backward(retain_graph=True)
            self.d_optimizer.step()

            # Compute gradient penalty (WGAN-GP), as a second D step.
            alpha = torch.rand(real_x.size(0), 1, 1,
                               1).cuda().expand_as(real_x)
            interpolated = Variable(alpha * real_x.data +
                                    (1 - alpha) * fake_x.data,
                                    requires_grad=True)
            out, out_cls, out_feats = self.D(interpolated)

            grad = torch.autograd.grad(outputs=out,
                                       inputs=interpolated,
                                       grad_outputs=torch.ones(
                                           out.size()).cuda(),
                                       retain_graph=True,
                                       create_graph=True,
                                       only_inputs=True)[0]

            grad = grad.view(grad.size(0), -1)
            grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
            d_loss_gp = torch.mean((grad_l2norm - 1)**2)

            # Backward + Optimize
            d_loss = self.lambda_gp * d_loss_gp
            self.reset_grad()
            d_loss.backward(retain_graph=True)
            self.d_optimizer.step()

            # Logging
            loss = {}
            loss['D/loss_real'] = d_loss_real.data[0]
            loss['D/loss_fake'] = d_loss_fake.data[0]
            loss['D/loss_cls'] = d_loss_cls.data[0]
            loss['D/loss_gp'] = d_loss_gp.data[0]

            # ================== Train G ================== #
            if (i + 1) % self.d_train_repeat == 0:

                # Original-to-target and target-to-original domain
                fake_x = self.G(real_x, fake_c)
                rec_x = self.G(fake_x, real_c)

                # Compute losses
                out_src, out_cls, out_feats_fake = self.D(fake_x)
                g_loss_fake = -torch.mean(out_src)

                if self.dataset == 'CelebA':
                    g_loss_cls = F.binary_cross_entropy_with_logits(
                        out_cls, fake_label,
                        size_average=False) / fake_x.size(0)
                else:
                    g_loss_cls = F.cross_entropy(out_cls, fake_label)

                ### Discriminate for rec_x
                out_src, out_cls, out_feats_rec = self.D(rec_x)

                g_loss_rec = torch.mean(torch.abs(real_x - rec_x))
                '''
                ### Replace pixel-wise reconstruction error between real_x / rec_x
                ### with feature-wise reconstruction error (multi layers) between real_feat / rec_feat (L1 norm)
                g_loss_feat_rec = 0
                for real_feat, rec_feat in zip(out_feats_real, out_feats_rec):
                    g_loss_feat_rec += torch.mean(torch.abs(real_feat - rec_feat))
                '''
                '''
                ### Feature matching (distribution) loss (multi layers) between real_feat / rec_feat (L2 norm, from DiscoGAN)
                feat_criterion = nn.HingeEmbeddingLoss()
                g_loss_feat_match = 0
                for real_feat, rec_feat in zip(out_feats_real, out_feats_rec):
                    l2 = (torch.mean(real_feat, 0) - torch.mean(rec_feat, 0)) ** 2
                    g_loss_feat_match += feat_criterion( l2, Variable( torch.ones( l2.size() ) ).cuda() )
                '''

                # Backward + Optimize
                g_loss = g_loss_fake + self.lambda_cls * g_loss_cls + self.lambda_rec * g_loss_rec
                '''
                if e < self.num_epochs // 5: # early phase
                    g_loss = g_loss_fake + self.lambda_cls * g_loss_cls + self.lambda_rec * g_loss_rec
                else:
                    g_loss = g_loss_fake + self.lambda_cls * g_loss_cls + self.lambda_feat_rec * g_loss_feat_rec
                '''
                self.reset_grad()
                g_loss.backward(retain_graph=True)
                self.g_optimizer.step()

                # Logging
                loss['G/loss_fake'] = g_loss_fake.data[0]
                loss['G/loss_rec'] = g_loss_rec.data[0]
                loss['G/loss_cls'] = g_loss_cls.data[0]
                ###loss['G/loss_feat_rec'] = g_loss_feat_rec.data[0]

            # Print out log info
            if (i + 1) % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                    elapsed, e + 1, self.num_epochs, i + 1, iters_per_epoch)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(
                            tag, value, e * iters_per_epoch + i + 1)

            # Translate fixed images for debugging
            if (i + 1) % self.sample_step == 0:
                fake_image_list = [fixed_x]
                for fixed_c in fixed_c_list:
                    fake_image_list.append(self.G(fixed_x, fixed_c))
                fake_images = torch.cat(fake_image_list, dim=3)
                save_image(self.denorm(fake_images.data),
                           os.path.join(
                               self.sample_path,
                               '{}_{}_fake.png'.format(e + 1, i + 1)),
                           nrow=1,
                           padding=0)
                print('Translated images and saved into {}..!'.format(
                    self.sample_path))

            # Save model checkpoints
            if (i + 1) % self.model_save_step == 0:
                torch.save(
                    self.G.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_{}_G.pth'.format(e + 1, i + 1)))
                torch.save(
                    self.D.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_{}_D.pth'.format(e + 1, i + 1)))

        # Decay learning rate (linear decay over the last num_epochs_decay epochs)
        if (e + 1) > (self.num_epochs - self.num_epochs_decay):
            g_lr -= (self.g_lr / float(self.num_epochs_decay))
            d_lr -= (self.d_lr / float(self.num_epochs_decay))
            self.update_lr(g_lr, d_lr)
            print('Decay learning rate to g_lr: {}, d_lr: {}.'.format(
                g_lr, d_lr))