# Trainer variants collected below share these imports. The module paths follow
# the usual MUNIT-style repo layout (networks.py / utils.py); adjust as needed.
import logging
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from networks import AdaINGen, ContentDis, MsImageDis
from utils import get_model_list, get_scheduler, load_vgg16, vgg_preprocess, weights_init

logger = logging.getLogger(__name__)


class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters, opts):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']
        self.loss = {}

        # fix the noise used in sampling
        self.s_a = torch.randn(8, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(8, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2),
                                        weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2),
                                        weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(opts.output_base + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def print_network(self, model, name):
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        logger.info('{} - {} - Number of parameters: {}'.format(name, model, num_params))

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, s_a_fake = self.gen_a.encode(x_a)
        c_b, s_b_fake = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        # reconstruction loss
        self.loss['G/rec_x_A'] = self.recon_criterion(x_a_recon, x_a)
        self.loss['G/rec_x_B'] = self.recon_criterion(x_b_recon, x_b)
        self.loss['G/rec_s_A'] = self.recon_criterion(s_a_recon, s_a)
        self.loss['G/rec_s_B'] = self.recon_criterion(s_b_recon, s_b)
        self.loss['G/rec_c_A'] = self.recon_criterion(c_a_recon, c_a)
        self.loss['G/rec_c_B'] = self.recon_criterion(c_b_recon, c_b)
        self.loss['G/cycrec_x_A'] = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss['G/cycrec_x_B'] = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss['G/adv_A'] = self.dis_a.calc_gen_loss(x_ba)
        self.loss['G/adv_B'] = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss['G/vgg_A'] = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss['G/vgg_B'] = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # total loss
        self.loss['G/total'] = hyperparameters['gan_w'] * self.loss['G/adv_A'] + \
                               hyperparameters['gan_w'] * self.loss['G/adv_B'] + \
                               hyperparameters['recon_x_w'] * self.loss['G/rec_x_A'] + \
                               hyperparameters['recon_s_w'] * self.loss['G/rec_s_A'] + \
                               hyperparameters['recon_c_w'] * self.loss['G/rec_c_A'] + \
                               hyperparameters['recon_x_w'] * self.loss['G/rec_x_B'] + \
                               hyperparameters['recon_s_w'] * self.loss['G/rec_s_B'] + \
                               hyperparameters['recon_c_w'] * self.loss['G/rec_c_B'] + \
                               hyperparameters['recon_x_cyc_w'] * self.loss['G/cycrec_x_A'] + \
                               hyperparameters['recon_x_cyc_w'] * self.loss['G/cycrec_x_B'] + \
                               hyperparameters['vgg_w'] * self.loss['G/vgg_A'] + \
                               hyperparameters['vgg_w'] * self.loss['G/vgg_B']
        self.loss['G/total'].backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def sample(self, x_a, x_b):
        self.eval()
        s_a1 = Variable(self.s_a)
        s_b1 = Variable(self.s_b)
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
            c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
            x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
            x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
            # the encoded styles already carry a batch dimension, so they are
            # passed to decode() directly
            x_ba2.append(self.gen_a.decode(c_b, s_a_fake))
            x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
            x_ab2.append(self.gen_b.decode(c_a, s_b_fake))
        outputs = {}
        outputs['A/real'] = x_a
        outputs['B/real'] = x_b
        outputs['A/rec'] = torch.cat(x_a_recon)
        outputs['B/rec'] = torch.cat(x_b_recon)
        outputs['A/B_random_style'] = torch.cat(x_ab1)
        outputs['A/B'] = torch.cat(x_ab2)
        outputs['B/A_random_style'] = torch.cat(x_ba1)
        outputs['B/A'] = torch.cat(x_ba2)
        self.train()
        return outputs

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # D loss
        self.loss['D/A'] = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss['D/B'] = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss['D/total'] = hyperparameters['gan_w'] * self.loss['D/A'] + \
                               hyperparameters['gan_w'] * self.loss['D/B']
        self.loss['D/total'].backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            old_lr = self.dis_opt.param_groups[0]['lr']
            self.dis_scheduler.step()
            new_lr = self.dis_opt.param_groups[0]['lr']
            if old_lr != new_lr:
                logger.info('Updated D learning rate: {}'.format(new_lr))
        if self.gen_scheduler is not None:
            old_lr = self.gen_opt.param_groups[0]['lr']
            self.gen_scheduler.step()
            new_lr = self.gen_opt.param_groups[0]['lr']
            if old_lr != new_lr:
                logger.info('Updated G learning rate: {}'.format(new_lr))

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        logger.info('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
        logger.info('Saving snapshots to: {}'.format(snapshot_dir))
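
def _train_loop_sketch(config, opts, loader_a, loader_b, snapshot_dir):
    """Illustrative only (not part of the original file): a minimal loop showing
    how a trainer like the one above is typically driven. `config` is the
    hyperparameter dict, `loader_a`/`loader_b` are per-domain data loaders, and
    `snapshot_dir` is a checkpoint directory; all are hypothetical placeholders.
    """
    trainer = MUNIT_Trainer(config, opts).cuda()
    for it, (x_a, x_b) in enumerate(zip(loader_a, loader_b)):
        x_a, x_b = x_a.cuda(), x_b.cuda()
        trainer.dis_update(x_a, x_b, config)  # D step: updates dis_a / dis_b only
        trainer.gen_update(x_a, x_b, config)  # G step: adversarial + reconstruction terms
        trainer.update_learning_rate()        # step both LR schedulers
        if (it + 1) % config['snapshot_save_iter'] == 0:
            trainer.save(snapshot_dir, it)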
class aclgan_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(aclgan_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_AB = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain A
        self.gen_BA = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain B
        self.dis_A = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain A
        self.dis_B = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain B
        self.dis_2 = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator 2
        # self.dis_2B = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator 2 for domain B
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.z_1 = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.z_2 = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.z_3 = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_A.parameters()) + list(self.dis_B.parameters()) + list(self.dis_2.parameters())
        gen_params = list(self.gen_AB.parameters()) + list(self.gen_BA.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2),
                                        weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2),
                                        weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        self.alpha = hyperparameters['alpha']
        self.focus_lam = hyperparameters['focus_loss']

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_A.apply(weights_init('gaussian'))
        self.dis_B.apply(weights_init('gaussian'))
        self.dis_2.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        z_1 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        z_2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        z_3 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_1, _ = self.gen_AB.encode(x_a)
        c_2, s_2 = self.gen_BA.encode(x_a)
        c_4, s_4 = self.gen_AB.encode(x_b)
        # decode
        self.x_B_fake = self.gen_AB.decode(c_1, z_1)
        self.x_A_fake = self.gen_BA.decode(c_2, z_2)
        # recon
        self.x_A_recon = self.gen_BA.decode(c_2, s_2)
        self.x_B_recon = self.gen_AB.decode(c_4, s_4)
        # encode 2
        c_3, _ = self.gen_BA.encode(self.x_B_fake)
        self.x_A2_fake = self.gen_BA.decode(c_3, z_3)
        self.X_A_A1_pair = torch.cat((x_a, self.x_A_fake), -3)
        self.X_A_A2_pair = torch.cat((x_a, self.x_A2_fake), -3)

    def focus_translation(self, x_fg, x_bg, x_focus):
        x_map = (x_focus + 1) / 2
        x_map = x_map.repeat(1, 3, 1, 1)
        return torch.mul(x_fg, x_map) + torch.mul(x_bg, 1 - x_map)

    def gen_update(self, x_a, x_b, hyperparameters):
        self.gen_opt.zero_grad()
        focus_delta = hyperparameters['focus_delta']
        focus_lambda = hyperparameters['focus_loss']
        focus_lower = hyperparameters['focus_lower']
        focus_upper = hyperparameters['focus_upper']
        focus_epsilon = hyperparameters['focus_epsilon']
        # forward
        z_1 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        z_2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        z_3 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_1, _ = self.gen_AB.encode(x_a)
        c_2, s_2 = self.gen_BA.encode(x_a)
        c_4, s_4 = self.gen_AB.encode(x_b)
        # decode
        if focus_lambda > 0:
            x_B_fake, x_B_focus = self.gen_AB.decode(c_1, z_1).split(3, 1)
            x_A_fake, x_A_focus = self.gen_BA.decode(c_2, self.alpha * z_2).split(3, 1)
            x_B_fake = self.focus_translation(x_B_fake, x_a, x_B_focus)
            x_A_fake = self.focus_translation(x_A_fake, x_a, x_A_focus)
            # recon
            x_A_recon, x_A_recon_focus = self.gen_BA.decode(c_2, s_2).split(3, 1)
            x_B_recon, x_B_recon_focus = self.gen_AB.decode(c_4, s_4).split(3, 1)
            # x_A_recon = self.focus_translation(x_A_recon, x_a, x_A_recon_focus)
            # x_B_recon = self.focus_translation(x_B_recon, x_b, x_B_recon_focus)
        else:
            x_B_fake = self.gen_AB.decode(c_1, z_1)
            x_A_fake = self.gen_BA.decode(c_2, self.alpha * z_2)
            # recon
            x_A_recon = self.gen_BA.decode(c_2, s_2)
            x_B_recon = self.gen_AB.decode(c_4, s_4)
        # encode 2
        c_3, _ = self.gen_BA.encode(x_B_fake)
        if focus_lambda > 0:
            x_A2_fake, x_A2_focus = self.gen_BA.decode(c_3, z_3).split(3, 1)
            x_A2_fake = self.focus_translation(x_A2_fake, x_B_fake, x_A2_focus)
        else:
            x_A2_fake = self.gen_BA.decode(c_3, z_3)
        x_A_A1_pair = torch.cat((x_a, x_A_fake), -3)
        x_A_A2_pair = torch.cat((x_a, x_A2_fake), -3)
        # GAN loss
        self.loss_gen_adv_A = (self.dis_A.calc_gen_loss(x_A_fake) +
                               self.dis_A.calc_gen_loss(x_A2_fake)) * 0.5
        self.loss_gen_adv_B = self.dis_B.calc_gen_loss(x_B_fake)
        self.loss_gen_adv_2 = self.dis_2.calc_gen_d2_loss(x_A_A1_pair, x_A_A2_pair)
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_A + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_B + \
                              hyperparameters['gan_cw'] * self.loss_gen_adv_2
        if focus_lambda > 0:
            x_B_focus = (x_B_focus + 1) / 2
            x_A_focus = (x_A_focus + 1) / 2
            x_A2_focus = (x_A2_focus + 1) / 2
            self.loss_gen_focus_B_size = (F.relu(torch.sum(x_B_focus - focus_upper), inplace=True) ** 2) * focus_delta + \
                                         (F.relu(torch.sum(focus_lower - x_B_focus), inplace=True) ** 2) * focus_delta
            self.loss_gen_focus_B_digit = torch.sum(1 / (torch.abs(x_B_focus - 0.5) + focus_epsilon))
            self.loss_gen_focus_A_size = (F.relu(torch.sum(x_A_focus - focus_upper), inplace=True) ** 2) * focus_delta + \
                                         (F.relu(torch.sum(focus_lower - x_A_focus), inplace=True) ** 2) * focus_delta
            self.loss_gen_focus_A_digit = torch.sum(1 / (torch.abs(x_A_focus - 0.5) + focus_epsilon))
            # self.loss_gen_focus_A = torch.sum(1 / (torch.abs(x_A_focus - 0.5) + focus_epsilon))
            self.loss_gen_focus_A2_size = (F.relu(torch.sum(x_A2_focus - focus_upper), inplace=True) ** 2) * focus_delta + \
                                          (F.relu(torch.sum(focus_lower - x_A2_focus), inplace=True) ** 2) * focus_delta
            self.loss_gen_focus_A2_digit = torch.sum(1 / (torch.abs(x_A2_focus - 0.5) + focus_epsilon))
            self.loss_gen_total += focus_lambda * (self.loss_gen_focus_B_size + self.loss_gen_focus_B_digit +
                                                   self.loss_gen_focus_A_size + self.loss_gen_focus_A_digit +
                                                   self.loss_gen_focus_A2_size + self.loss_gen_focus_A2_digit) \
                                   / x_a.size(2) / x_a.size(3) / x_a.size(0) / 3
        self.loss_idt_A = self.recon_criterion(x_A_recon, x_a)
        self.loss_idt_B = self.recon_criterion(x_B_recon, x_b)
        self.loss_gen_total += hyperparameters['recon_x_w'] * self.loss_idt_A + \
                               hyperparameters['recon_x_w'] * self.loss_idt_B
        # print(self.loss_gen_focus_B, self.loss_gen_total)
        # print(self.loss_idt_A)
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def sample(self, x_a, x_b):
        self.eval()
        z_1 = Variable(self.z_1)
        z_2 = Variable(self.z_2)
        z_3 = Variable(self.z_3)
        x_A, x_B, x_A_fake, x_B_fake, x_A2_fake = [], [], [], [], []
        if self.focus_lam > 0:
            mask_A, mask_B, mask_A2, mask_recon = [], [], [], []
            x_A_recon = []
        else:
            x_A_recon, x_B_recon = [], []
        for i in range(x_a.size(0)):
            x_A.append(x_a[i].unsqueeze(0))
            x_B.append(x_b[i].unsqueeze(0))
            if self.focus_lam > 0:
                c_1, s_1 = self.gen_BA.encode(x_a[i].unsqueeze(0))
                img, mask = self.gen_BA.decode(c_1, z_1[i].unsqueeze(0)).split(3, 1)
                x_A_fake.append(self.focus_translation(img, x_a[i].unsqueeze(0), mask))
                mask_A.append(mask)
                img, mask = self.gen_BA.decode(c_1, s_1).split(3, 1)
                # x_A_recon.append(self.focus_translation(img, x_a[i].unsqueeze(0), mask))
                x_A_recon.append(img)
                mask_recon.append(mask)
                c_2, _ = self.gen_AB.encode(x_a[i].unsqueeze(0))
                x_b_img, mask = self.gen_AB.decode(c_2, z_2[i].unsqueeze(0)).split(3, 1)
                x_b_img = self.focus_translation(x_b_img, x_a[i].unsqueeze(0), mask)
                x_B_fake.append(x_b_img)
                mask_B.append(mask)
                c_3, _ = self.gen_BA.encode(x_b_img)
                img, mask = self.gen_BA.decode(c_3, z_3[i].unsqueeze(0)).split(3, 1)
                x_A2_fake.append(self.focus_translation(img, x_b_img, mask))
                mask_A2.append(mask)
            else:
                c_1, s_1 = self.gen_BA.encode(x_a[i].unsqueeze(0))
                x_A_fake.append(self.gen_BA.decode(c_1, z_1[i].unsqueeze(0)))
                x_A_recon.append(self.gen_BA.decode(c_1, s_1))
                c_2, _ = self.gen_AB.encode(x_a[i].unsqueeze(0))
                x_B1 = self.gen_AB.decode(c_2, z_2[i].unsqueeze(0))
                x_B_fake.append(x_B1)
                c_3, _ = self.gen_BA.encode(x_B1)
                x_A2_fake.append(self.gen_BA.decode(c_3, z_3[i].unsqueeze(0)))
                # reconstruct the i-th b image (per image, matching the other lists)
                c_4, s_4 = self.gen_AB.encode(x_b[i].unsqueeze(0))
                x_B_recon.append(self.gen_AB.decode(c_4, s_4))
        if self.focus_lam > 0:
            x_A, x_B = torch.cat(x_A), torch.cat(x_B)
            x_A_fake, x_B_fake = torch.cat(x_A_fake), torch.cat(x_B_fake)
            mask_A, x_A2_fake = torch.cat(mask_A), torch.cat(x_A2_fake)
            mask_B, mask_recon = torch.cat(mask_B), torch.cat(mask_recon)
            mask_A2, x_A_recon = torch.cat(mask_A2), torch.cat(x_A_recon)
            self.train()
            return x_A, x_A_fake, mask_A, x_B_fake, mask_B, x_A2_fake, mask_A2, x_A_recon, mask_recon
        else:
            x_A, x_B = torch.cat(x_A), torch.cat(x_B)
            x_A_fake, x_B_fake = torch.cat(x_A_fake), torch.cat(x_B_fake)
            x_A_recon, x_A2_fake = torch.cat(x_A_recon), torch.cat(x_A2_fake)
            x_B_recon = torch.cat(x_B_recon)
            self.train()
            return x_A, x_A_fake, x_B_fake, x_A2_fake, x_A_recon, x_B, x_B_recon

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        focus_delta = hyperparameters['focus_delta']
        focus_lambda = hyperparameters['focus_loss']
        focus_epsilon = hyperparameters['focus_epsilon']
        # forward
        z_1 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        z_2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        z_3 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_1, _ = self.gen_AB.encode(x_a)
        c_2, s_2 = self.gen_BA.encode(x_a)
        c_4, s_4 = self.gen_AB.encode(x_b)
        # decode
        if focus_lambda > 0:
            x_B_fake, x_B_focus = self.gen_AB.decode(c_1, z_1).split(3, 1)
            x_A_fake, x_A_focus = self.gen_BA.decode(c_2, self.alpha * z_2).split(3, 1)
            x_B_fake = self.focus_translation(x_B_fake, x_a, x_B_focus)
            x_A_fake = self.focus_translation(x_A_fake, x_a, x_A_focus)
        else:
            x_B_fake = self.gen_AB.decode(c_1, z_1)
            x_A_fake = self.gen_BA.decode(c_2, self.alpha * z_2)
        # encode 2
        c_3, _ = self.gen_BA.encode(x_B_fake)
        if focus_lambda > 0:
            x_A2_fake, x_A2_focus = self.gen_BA.decode(c_3, z_3).split(3, 1)
            x_A2_fake = self.focus_translation(x_A2_fake, x_B_fake, x_A2_focus)
        else:
            x_A2_fake = self.gen_BA.decode(c_3, z_3)
        x_A_A1_pair = torch.cat((x_a, x_A_fake), -3)
        x_A_A2_pair = torch.cat((x_a, x_A2_fake), -3)
        # D loss
        self.loss_dis_A = (self.dis_A.calc_dis_loss(x_A_fake, x_a) +
                           self.dis_A.calc_dis_loss(x_A2_fake, x_a)) * 0.5
        self.loss_dis_B = self.dis_B.calc_dis_loss(x_B_fake, x_b)
        self.loss_dis_2 = self.dis_2.calc_dis_loss(x_A_A1_pair, x_A_A2_pair)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_A + \
                              hyperparameters['gan_w'] * self.loss_dis_B + \
                              hyperparameters['gan_cw'] * self.loss_dis_2
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_AB.load_state_dict(state_dict['AB'])
        self.gen_BA.load_state_dict(state_dict['BA'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_A.load_state_dict(state_dict['A'])
        self.dis_B.load_state_dict(state_dict['B'])
        self.dis_2.load_state_dict(state_dict['2'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'AB': self.gen_AB.state_dict(), 'BA': self.gen_BA.state_dict()}, gen_name)
        torch.save({'A': self.dis_A.state_dict(), 'B': self.dis_B.state_dict(),
                    '2': self.dis_2.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
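
def _focus_translation_demo():
    """Toy check (illustrative, with made-up tensors) of the masking convention
    used by focus_translation above: the decoder emits a focus map in [-1, 1],
    which is rescaled to [0, 1] and used to blend foreground and background
    per pixel.
    """
    fg = torch.ones(1, 3, 2, 2)       # translated foreground
    bg = -torch.ones(1, 3, 2, 2)      # original image kept as background
    focus = torch.tensor([[[[1., -1.], [0., 1.]]]])  # 1-channel focus map
    x_map = (focus + 1) / 2           # [-1, 1] -> [0, 1]
    x_map = x_map.repeat(1, 3, 1, 1)  # broadcast the mask to the 3 RGB channels
    out = fg * x_map + bg * (1 - x_map)
    # focus  1 -> pure foreground (1.0); focus -1 -> pure background (-1.0);
    # focus  0 -> even blend (0.0)
    assert out[0, 0, 0, 0] == 1.0 and out[0, 0, 0, 1] == -1.0 and out[0, 0, 1, 0] == 0.0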
# Annotated variant of the MUNIT trainer (study notes kept as comments).
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()  # super() calls a method of the parent (super) class
        lr = hyperparameters['lr']
        # Initiate the networks; worth studying how the generators and
        # discriminators are actually constructed
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        # https://blog.csdn.net/liuxiao214/article/details/81037416
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        # s_a and s_b are two different fixed styles
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()  # e.g. (16, 8, 1, 1)
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        # parameters of the two discriminators
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        # parameters of the two generators
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        # optimizers
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2),
                                        weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2),
                                        weight_decay=hyperparameters['weight_decay'])
        # learning-rate schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization. nn.Module.apply recursively applies the
        # given function to every submodule; see https://zhuanlan.zhihu.com/p/42756654
        self.apply(weights_init(hyperparameters['init']))  # initialize this module recursively
        self.dis_a.apply(weights_init('gaussian'))  # re-initialize the discriminators with Gaussian weights
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        s_a = Variable(self.s_a)  # self.s_a is a fixed random style
        s_b = Variable(self.s_b)
        # the two auto-encoders
        c_a, s_a_fake = self.gen_a.encode(x_a)  # content and style of input x_a
        c_b, s_b_fake = self.gen_b.encode(x_b)
        # x_ba is the b -> a translation
        x_ba = self.gen_a.decode(c_b, s_a)  # combine (c_b, s_a) to generate x_ba
        x_ab = self.gen_b.decode(c_a, s_b)  # combine (c_a, s_b) to generate x_ab
        self.train()  # back to training mode
        return x_ab, x_ba

    # Only the parameter update happens here; the networks actually being
    # trained are AdaINGen and MsImageDis.
    def gen_update(self, x_a, x_b, hyperparameters):
        # self.gen_opt is the generators' optimizer
        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # print('content shape:', c_b.shape)
        # print('style shape:', s_b_prime.shape)
        # decode (within domain); needed so the decoders learn to reconstruct
        # their own domain
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)  # c_a_recon should match c_a as closely as possible
        # decode again (if needed)
        x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss: an L2 loss on instance-normalized
        # VGG feature maps, i.e. a content-preservation constraint
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def sample(self, x_a, x_b):
        self.eval()  # switch to eval mode
        # random styles
        s_a1 = Variable(self.s_a)
        s_b1 = Variable(self.s_b)
        s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        # x_a.size(0) equals display_size (16 by default)
        for i in range(x_a.size(0)):
            # unsqueeze inserts a new dimension at the given position; see
            # https://blog.csdn.net/xiexu911/article/details/80820028
            c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
            c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
            x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
            x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
            x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
            x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
            x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
        # torch.cat joins the per-image results along the batch dimension,
        # giving display_size images per output
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_update(self, x_a, x_b, hyperparameters):
        # self.dis_opt is the discriminators' optimizer
        self.dis_opt.zero_grad()
        # s_a, s_b: random style codes for domains a and b
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode: encoding only separates content from style, no translation yet;
        # each generator consists of an encoder and a decoder
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain): mix the content of one domain with a style code
        # of the other
        x_ba = self.gen_a.decode(c_b, s_a)  # x_ba combines a random a-style with x_b's content
        x_ab = self.gen_b.decode(c_a, s_b)
        # D loss: the discriminator scores x_ba (fake, detached) and x_a (real)
        # independently; loss_dis_a is the sum of the two binary classification
        # losses. Unlike the reconstruction loss, the two inputs are never
        # compared with each other.
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        # hyperparameters['gan_w'] is the weight of the adversarial loss
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + \
                              hyperparameters['gan_w'] * self.loss_dis_b
        # see https://blog.csdn.net/jacke121/article/details/82995740 for
        # details on backward() and step()
        self.loss_dis_total.backward()  # compute gradients
        self.dis_opt.step()  # apply the update

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()  # update the learning rate
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators; always resume from the latest checkpoint
        last_model_name = get_model_list(checkpoint_dir, "gen")
        print('resume model: ', last_model_name)
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations, gpuids):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        if len(gpuids) > 1:
            # networks are wrapped in DataParallel, so unwrap with .module;
            # optimizers are never wrapped and are saved directly
            torch.save({'a': self.gen_a.module.state_dict(), 'b': self.gen_b.module.state_dict()}, gen_name)
            torch.save({'a': self.dis_a.module.state_dict(), 'b': self.dis_b.module.state_dict()}, dis_name)
            torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
        else:
            torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
            torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
            torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
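
def _batch_cat_demo():
    """Illustrative check: sample() above decodes one image at a time and
    torch.cat reassembles the singles along the batch dimension (dim 0),
    not the channel dimension.
    """
    singles = [torch.randn(1, 3, 64, 64) for _ in range(16)]  # 16 single-image results
    batch = torch.cat(singles)  # default dim=0 -> shape (16, 3, 64, 64)
    assert batch.shape == (16, 3, 64, 64)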
# Mask-guided MUNIT variant: encode/decode take foreground masks and handle
# separate foreground/background styles.
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2),
                                        weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2),
                                        weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, s_a_fake = self.gen_a.encode(x_a)
        c_b, s_b_fake = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, m_A, x_b, m_B, hyperparameters):
        self.gen_opt.zero_grad()
        im_A = 1 - m_A  # inverted mask: background of a
        im_B = 1 - m_B
        # encode
        c_a, s_bA = self.gen_a.encode(x_a, im_A)  # background style of a
        c_b, s_fB = self.gen_b.encode(x_b, m_B)   # foreground style of b
        _, s_fA = self.gen_a.encode(x_a, m_A)     # foreground style of a
        _, s_bB = self.gen_b.encode(x_b, im_B)    # background style of b
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_fA, m_B, s_bB)
        x_ab = self.gen_b.decode(c_a, s_fB, m_A, s_bA)
        # decode (within domain)
        x_aa = self.gen_a.decode(c_a, s_fA, m_A, s_bA)
        x_bb = self.gen_b.decode(c_b, s_fB, m_B, s_bB)
        # encode again
        c_ba, s_fBA = self.gen_a.encode(x_ba, m_B)
        c_ab, s_fAB = self.gen_a.encode(x_ab, m_A)
        _, s_bBA = self.gen_a.encode(x_ba, im_B)
        _, s_bAB = self.gen_a.encode(x_ab, im_A)
        # decode again (if needed)
        x_aba = self.gen_a.decode(c_ab, s_fBA, m_A, s_bAB) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(c_ba, s_fAB, m_B, s_bBA) if hyperparameters['recon_x_cyc_w'] > 0 else None
        self.loss_gen_recon_c_a = self.recon_criterion(c_ab, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_ba, c_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_bAB, s_bA)
        self.loss_gen_recon_s_b = self.recon_criterion(s_bBA, s_bB)
        self.loss_gen_recon_s_af = self.recon_criterion(s_fAB, s_fB)
        self.loss_gen_recon_s_bf = self.recon_criterion(s_fBA, s_fA)
        self.loss_gen_recon_x_a = self.recon_criterion(x_aa, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_bb, x_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(im_A * x_aba, im_A * x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(m_B * x_bab, m_B * x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_af + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_bf + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def sample(self, loader_a, loader_b, size):
        self.eval()
        im_a = torch.stack([loader_a.dataset[i][0] for i in range(size)]).cuda()
        seg_a = torch.stack([loader_a.dataset[i][1] for i in range(size)]).cuda()
        im_b = torch.stack([loader_b.dataset[i][0] for i in range(size)]).cuda()
        seg_b = torch.stack([loader_b.dataset[i][1] for i in range(size)]).cuda()
        x_a_recon, x_b_recon, x_ba1, x_bm, x_ab1, x_am = [], [], [], [], [], []
        for i in range(im_a.size(0)):
            mask_a = seg_a[i].unsqueeze(0)
            mask_b = seg_b[i].unsqueeze(0)
            x_a = im_a[i].unsqueeze(0)
            x_b = im_b[i].unsqueeze(0)
            masked_a = mask_a * x_a
            masked_b = mask_b * x_b
            c_a, s_bA = self.gen_a.encode(x_a, 1 - mask_a)
            c_b, s_fB = self.gen_b.encode(x_b, mask_b)
            c_a, s_fA = self.gen_a.encode(x_a, mask_a)
            c_b, s_bB = self.gen_b.encode(x_b, 1 - mask_b)
            # decode (cross domain)
            x_BA = self.gen_a.decode(c_b, s_fA, mask_b, s_bB)
            x_AB = self.gen_b.decode(c_a, s_fB, mask_a, s_bA)
            if i % 2 == 0:
                # paste the generated foreground onto the untouched real background
                x_AB = (1 - mask_a) * x_a + mask_a * x_AB
                x_BA = (1 - mask_b) * x_b + mask_b * x_BA
            x_ba1.append(x_BA)
            x_ab1.append(x_AB)
            x_am.append(masked_a)
            x_bm.append(masked_b)
            # decode (within domain)
            x_A_recon = self.gen_a.decode(c_a, s_fA, mask_a, s_bA)
            x_B_recon = self.gen_b.decode(c_b, s_fB, mask_b, s_bB)
            x_a_recon.append(x_A_recon)
            x_b_recon.append(x_B_recon)
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1 = torch.cat(x_ba1)
        x_ab1 = torch.cat(x_ab1)
        x_bm = torch.cat(x_bm)
        x_am = torch.cat(x_am)
        self.train()
        return im_a, x_a_recon, x_ab1, x_am, im_b, x_b_recon, x_ba1, x_bm

    def dis_update(self, x_a, m_a, x_b, m_b, hyperparameters):
        self.dis_opt.zero_grad()
        # encode; the masks are used at full resolution (the commented
        # F.interpolate calls show the original down-sampling alternative)
        up_im_A = 1 - m_a  # F.interpolate(1 - m_a, None, 1, 'bilinear', align_corners=False)
        up_m_B = m_b       # F.interpolate(m_b, None, 1, 'bilinear', align_corners=False)
        up_m_A = m_a       # F.interpolate(m_a, None, 1, 'bilinear', align_corners=False)
        up_im_B = 1 - m_b  # F.interpolate(1 - m_b, None, 1, 'bilinear', align_corners=False)
        c_a, s_bA = self.gen_a.encode(x_a, up_im_A)
        c_b, s_fB = self.gen_b.encode(x_b, up_m_B)
        _, s_fA = self.gen_a.encode(x_a, up_m_A)
        _, s_bB = self.gen_b.encode(x_b, up_im_B)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_fA, m_b, s_bB)
        x_ab = self.gen_b.decode(c_a, s_fB, m_a, s_bA)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + \
                              hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
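
def _background_paste_demo():
    """Illustrative sketch (made-up tensors) of the even-index composite in
    sample() above: keep the real image where the mask is 0 and the
    translation where the mask is 1.
    """
    x_real = torch.randn(1, 3, 64, 64)  # real image
    x_gen = torch.randn(1, 3, 64, 64)   # cross-domain translation
    mask = torch.zeros(1, 1, 64, 64)
    mask[..., 16:48, 16:48] = 1.0       # translate only the central region
    composite = (1 - mask) * x_real + mask * x_gen
    # outside the mask the composite equals the real image exactly
    assert torch.equal(composite[..., 0, 0], x_real[..., 0, 0])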
# Another annotated copy of the MUNIT trainer, with study notes kept as
# docstring blocks and comments.
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']
        '''
        input_dim_a and input_dim_b are the input image channel counts (3 for RGB).
        'gen' and 'dis' are the architecture configs defined in the YAML file.
        '''
        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        '''
        Assign a random style code (dimension 8) to each of the displayed images
        (16 in total by default).
        '''
        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2),
                                        weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2),
                                        weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        '''
        A concise pattern worth noting: concatenate the parameter lists, then keep
        only trainable tensors with [p for p in params if p.requires_grad].
        One Adam optimizer each for the discriminators and the generators, with
        betas 0.5/0.999 and weight decay 1e-4 in the default config. The learning
        rate scheduler defaults to halving the LR every 100,000 steps.
        '''
        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))
        '''
        Note: apply() recursively applies the given function to every submodule.
        '''
        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False
        '''The default config does not use the VGG network.'''

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    # Note: forward() temporarily switches to eval mode; where this method is
    # actually called from is not obvious from this file.
    def forward(self, x_a, x_b):
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, s_a_fake = self.gen_a.encode(x_a)
        c_b, s_b_fake = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())  # two random style codes
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)  # "prime" marks style codes encoded from real images
        c_b, s_b_prime = self.gen_b.encode(x_b)  # content code: (1, 256, 64, 64); style code: (1, 8, 1, 1)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)  # (a) reconstruct the input from its own content and style
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)  # (b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)  # (a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)  # (b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)  # (c)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)  # (d)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)  # (e)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)  # (f)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0  # (g)
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0  # (h)
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)  # (i)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)  # (j)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0  # (k)
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0  # (l)
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    # Take two image batches and swap their styles.
    def sample(self, x_a, x_b):
        self.eval()
        s_a1 = Variable(self.s_a)  # the fixed styles
        s_b1 = Variable(self.s_b)
        s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())  # freshly sampled random styles
        s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))  # encode the input image
            c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))  # reconstruct the input image
            x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
            x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))  # fixed a-style on b's content
            x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))  # random a-style on b's content
            x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
            x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())  # random style codes, (N, 8, 1, 1)
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode (content only)
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + \
                              hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
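
def _get_scheduler_sketch(optimizer, hyperparameters, iterations=-1):
    """Plausible minimal implementation of the repo's get_scheduler, written
    only to illustrate the translated docstring above ("halve the LR every
    100,000 steps"). The key names 'lr_policy', 'step_size', and 'gamma' are
    assumptions, not confirmed against the actual utils module.
    """
    from torch.optim import lr_scheduler
    if 'lr_policy' not in hyperparameters or hyperparameters['lr_policy'] == 'constant':
        return None  # keep a fixed learning rate
    elif hyperparameters['lr_policy'] == 'step':
        return lr_scheduler.StepLR(optimizer,
                                   step_size=hyperparameters['step_size'],  # e.g. 100000
                                   gamma=hyperparameters['gamma'],          # e.g. 0.5
                                   last_epoch=iterations)
    raise NotImplementedError('lr policy %s is not implemented' % hyperparameters['lr_policy'])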
class MUNIT_Trainer(nn.Module): def __init__(self, hyperparameters, device): super(MUNIT_Trainer, self).__init__() lr = hyperparameters['lr'] self.device = device # Initiate the networks self.gen_a = AdaINGen( hyperparameters['input_dim_a'], hyperparameters['gen']) # auto-encoder for domain a self.gen_b = AdaINGen( hyperparameters['input_dim_b'], hyperparameters['gen']) # auto-encoder for domain b self.dis_a = MsImageDis( hyperparameters['input_dim_a'], hyperparameters['dis']) # discriminator for domain a self.dis_b = MsImageDis( hyperparameters['input_dim_b'], hyperparameters['dis']) # discriminator for domain b self.instancenorm = nn.InstanceNorm2d(512, affine=False) self.style_dim = hyperparameters['gen']['style_dim'] # fix the noise used in sampling display_size = int(hyperparameters['display_size']) self.s_a = torch.randn(display_size, self.style_dim, 1, 1).to(self.device) self.s_b = torch.randn(display_size, self.style_dim, 1, 1).to(self.device) # Setup the optimizers beta1 = hyperparameters['beta1'] beta2 = hyperparameters['beta2'] dis_params = list(self.dis_a.parameters()) + list( self.dis_b.parameters()) gen_params = list(self.gen_a.parameters()) + list( self.gen_b.parameters()) self.dis_opt = torch.optim.Adam( [p for p in dis_params if p.requires_grad], lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) self.gen_opt = torch.optim.Adam( [p for p in gen_params if p.requires_grad], lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters) self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters) # Network weight initialization self.apply(weights_init(hyperparameters['init'])) self.dis_a.apply(weights_init('gaussian')) self.dis_b.apply(weights_init('gaussian')) # Load VGG model if needed if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0: self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models') self.vgg.eval() for param in self.vgg.parameters(): param.requires_grad = False def recon_criterion(self, input, target): return torch.mean(torch.abs(input - target)) def forward(self, x_a, x_b): self.eval() c_a, s_a_fake = self.gen_a.encode(x_a) c_b, s_b_fake = self.gen_b.encode(x_b) x_ba = self.gen_a.decode(c_b, self.s_a) x_ab = self.gen_b.decode(c_a, self.s_b) self.train() return x_ab, x_ba def gen_update(self, x_a, x_b, hyperparameters): self.gen_opt.zero_grad() s_a = torch.randn(x_a.size(0), self.style_dim, 1, 1).to(self.device) s_b = torch.randn(x_b.size(0), self.style_dim, 1, 1).to(self.device) # encode c_a, s_a_prime = self.gen_a.encode(x_a) c_b, s_b_prime = self.gen_b.encode(x_b) # decode (within domain) x_a_recon = self.gen_a.decode(c_a, s_a_prime) x_b_recon = self.gen_b.decode(c_b, s_b_prime) # decode (cross domain) x_ba = self.gen_a.decode(c_b, s_a) x_ab = self.gen_b.decode(c_a, s_b) # encode again c_b_recon, s_a_recon = self.gen_a.encode(x_ba) c_a_recon, s_b_recon = self.gen_b.encode(x_ab) # decode again (if needed) x_aba = self.gen_a.decode( c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None x_bab = self.gen_b.decode( c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None # reconstruction loss self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a) self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b) self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a) self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b) self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, 
c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(
            self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(
            self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
            hyperparameters['gan_w'] * self.loss_gen_adv_b + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
            hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
            hyperparameters['vgg_w'] * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def sample(self, x_a, x_b):
        self.eval()
        s_a2 = torch.randn(x_a.size(0), self.style_dim, 1, 1).to(self.device)
        s_b2 = torch.randn(x_b.size(0), self.style_dim, 1, 1).to(self.device)
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
            c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
            x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
            x_ba1.append(self.gen_a.decode(c_b, self.s_a[i].unsqueeze(0)))
            x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
            x_ab1.append(self.gen_b.decode(c_a, self.s_b[i].unsqueeze(0)))
            x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        s_a = torch.randn(x_a.size(0), self.style_dim, 1, 1).to(self.device)
        s_b = torch.randn(x_b.size(0), self.style_dim, 1, 1).to(self.device)
        # encode
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + \
            hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
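
# Usage sketch (not part of the original sources): a minimal outer loop showing
# how a trainer like the one above is typically driven. `config`, `loader_a`,
# and `loader_b` are hypothetical stand-ins for the hyperparameter dict and the
# two domain data loaders that the surrounding training script would provide,
# as are the `snapshot_save_iter` / `snapshot_dir` keys.
def example_training_loop(trainer, loader_a, loader_b, config, max_iter=100000):
    iterations = 0
    while iterations < max_iter:
        for x_a, x_b in zip(loader_a, loader_b):
            x_a, x_b = x_a.cuda(), x_b.cuda()
            # Discriminator step first, then generator step, as in MUNIT.
            trainer.dis_update(x_a, x_b, config)
            trainer.gen_update(x_a, x_b, config)
            trainer.update_learning_rate()
            iterations += 1
            if iterations % config['snapshot_save_iter'] == 0:
                trainer.save(config['snapshot_dir'], iterations)
            if iterations >= max_iter:
                break
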
class AGUIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(AGUIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.noise_dim = hyperparameters['gen']['noise_dim']
        self.attr_dim = len(hyperparameters['gen']['selected_attrs'])
        self.gen = AdaINGen(hyperparameters['input_dim'],
                            hyperparameters['gen'])
        self.dis = MsImageDis(hyperparameters['input_dim'], self.attr_dim,
                              hyperparameters['dis'])
        self.dis_content = ContentDis(
            hyperparameters['gen']['dim'] *
            (2**hyperparameters['gen']['n_downsample']), self.attr_dim)
        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis.parameters()) + list(
            self.dis_content.parameters())
        gen_params = list(self.gen.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis.apply(weights_init('gaussian'))
        self.dis_content.apply(weights_init('gaussian'))

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def gen_update(self, x_l, x_u, l, hyperparameters):
        self.gen_opt.zero_grad()
        # l_s_rand = torch.randn_like(l_s)
        # l_s = torch.where(l_s == 0, l_s_rand, l_s)
        s_r = torch.cat([torch.randn(x_u.size(0), self.noise_dim).cuda(), l],
                        1)
        # encode
        c_l, s_l = self.gen.encode(x_l)
        c_u, s_u = self.gen.encode(x_u)
        # decode (within domain)
        x_u_recon = self.gen.decode(c_u, s_u)
        # decode (cross domain)
        x_ur = self.gen.decode(c_u, s_r)
        # encode again
        c_u_recon, s_r_recon = self.gen.encode(x_ur)
        x_u_cycle = self.gen.decode(c_u_recon, s_u)
        # additional KL-loss (optional)
        s_mean = s_l[:, 0:self.noise_dim].mean()
        s_std = s_l[:, 0:self.noise_dim].std()
        self.loss_gen_kld = (s_mean**2 + s_std.pow(2) - s_std.pow(2).log() -
                             1).mean() / 2
        self.loss_gen_adv_content = self.dis_content.calc_gen_loss(c_l, c_u, l)
        # reconstruction loss
        self.loss_gen_rec = self.recon_criterion(x_u_recon, x_u)
        self.loss_gen_rec_s = self.recon_criterion(s_r_recon, s_r)
        self.loss_gen_rec_c = self.recon_criterion(c_u_recon, c_u)
        self.loss_gen_cyc = self.recon_criterion(x_u_cycle, x_u)
        # GAN loss
        self.loss_gen_adv = self.dis.calc_gen_loss(x_ur, l)
        # label part loss
        self.loss_gen_cla = (
            s_l[:, self.noise_dim:self.noise_dim + self.attr_dim] -
            l).pow(2).mean()
        self.loss_gen_total = hyperparameters['adv_w'] * self.loss_gen_adv + \
            hyperparameters['adv_c_w'] * self.loss_gen_adv_content + \
            hyperparameters['rec_w'] * self.loss_gen_rec + \
            hyperparameters['rec_s_w'] * self.loss_gen_rec_s + \
            hyperparameters['rec_c_w'] * self.loss_gen_rec_c + \
            hyperparameters['cla_w'] * self.loss_gen_cla + \
            hyperparameters['kld_w'] * self.loss_gen_kld + \
            hyperparameters['cyc_w'] * self.loss_gen_cyc
        self.loss_gen_total.backward()
        self.gen_opt.step()
        return self.loss_gen_total.detach()

    def sample(self, x_l, l):
        c_l, s_l = self.gen.encode(x_l)
        # decode (within domain)
        x_l_recon = self.gen.decode(c_l, s_l)
        out = [x_l, x_l_recon]
        for i in range(self.attr_dim):
            s_changed = s_l.clone()
            s_changed[:, self.noise_dim + i] = -l[:, i]
            out += [self.gen.decode(c_l, s_changed)]
        return out

    def dis_update(self, x_l, x_u, l, hyperparameters):
        self.dis_opt.zero_grad()
        s_r = torch.cat([torch.randn(x_u.size(0), self.noise_dim).cuda(), l],
                        1)
        # encode
        c_l, s_l = self.gen.encode(x_l)
        c_u, s_u = self.gen.encode(x_u)
        # decode (cross domain)
        x_ur = self.gen.decode(c_u, s_r)
        # D loss
        self.loss_dis_adv = self.dis.calc_dis_loss(x_ur.detach(), x_l, x_u, l)
        self.loss_dis_adv_content = self.dis_content.calc_dis_loss(c_l, c_u, l)
        self.loss_dis_total = hyperparameters['adv_w'] * self.loss_dis_adv + \
            hyperparameters['adv_c_w'] * self.loss_dis_adv_content
        self.loss_dis_total.backward()
        self.dis_opt.step()
        return self.loss_dis_total.detach()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen.load_state_dict(state_dict['gen'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis.load_state_dict(state_dict['dis'])
        self.dis_content.load_state_dict(state_dict['dis_content'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'gen': self.gen.state_dict()}, gen_name)
        torch.save(
            {
                'dis': self.dis.state_dict(),
                'dis_content': self.dis_content.state_dict()
            }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
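
# Usage sketch (an assumption, not from the original sources): AGUIT edits one
# attribute at a time by flipping the label slot of the style code, which is
# what sample() above does with s_changed[:, noise_dim + i] = -l[:, i]. The
# names `trainer`, `x`, `labels`, and `attr_index` are hypothetical; labels are
# assumed to be encoded as +/-1, as sample() suggests.
def example_attribute_edit(trainer, x, labels, attr_index):
    with torch.no_grad():
        c, s = trainer.gen.encode(x)
        s_edit = s.clone()
        # Negating the label slot moves the attribute to its opposite value.
        s_edit[:, trainer.noise_dim + attr_index] = -labels[:, attr_index]
        return trainer.gen.decode(c, s_edit)
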
class UnsupIntrinsicTrainer(nn.Module):
    def __init__(self, param):
        super(UnsupIntrinsicTrainer, self).__init__()
        lr = param['lr']
        # Initiate the networks
        self.gen_i = AdaINGen(param['input_dim_a'], param['input_dim_a'],
                              param['gen'])  # auto-encoder for domain I
        self.gen_r = AdaINGen(param['input_dim_b'], param['input_dim_b'],
                              param['gen'])  # auto-encoder for domain R
        self.gen_s = AdaINGen(param['input_dim_c'], param['input_dim_c'],
                              param['gen'])  # auto-encoder for domain S
        self.dis_r = MsImageDis(param['input_dim_b'],
                                param['dis'])  # discriminator for domain R
        self.dis_s = MsImageDis(param['input_dim_c'],
                                param['dis'])  # discriminator for domain S
        gp = param['gen']
        self.with_mapping = True
        self.use_phy_loss = True
        self.use_content_loss = True
        if 'ablation_study' in param:
            if 'with_mapping' in param['ablation_study']:
                wm = param['ablation_study']['with_mapping']
                self.with_mapping = True if wm != 0 else False
            if 'wo_phy_loss' in param['ablation_study']:
                wpl = param['ablation_study']['wo_phy_loss']
                self.use_phy_loss = True if wpl == 0 else False
            if 'wo_content_loss' in param['ablation_study']:
                wcl = param['ablation_study']['wo_content_loss']
                self.use_content_loss = True if wcl == 0 else False
        if self.with_mapping:
            self.fea_s = IntrinsicSplitor(gp['style_dim'], gp['mlp_dim'],
                                          gp['n_layer'],
                                          gp['activ'])  # split style for I
            self.fea_m = IntrinsicMerger(gp['style_dim'], gp['mlp_dim'],
                                         gp['n_layer'],
                                         gp['activ'])  # merge style for R, S
        self.bias_shift = param['bias_shift']
        self.instance_norm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = param['gen']['style_dim']
        # fix the noise used in sampling
        display_size = int(param['display_size'])
        self.s_r = torch.randn(display_size, self.style_dim, 1, 1).cuda() + 1.
        self.s_s = torch.randn(display_size, self.style_dim, 1, 1).cuda() - 1.
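        # The +1./-1. shifts above (and `bias_shift` elsewhere in this class)
        # place the sampling noise on opposite sides of the style space, a
        # simple prior that keeps reflectance (R) and shading (S) styles apart.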
        # Setup the optimizers
        beta1 = param['beta1']
        beta2 = param['beta2']
        dis_params = list(self.dis_r.parameters()) + list(
            self.dis_s.parameters())
        if self.with_mapping:
            gen_params = list(self.gen_i.parameters()) + list(self.gen_r.parameters()) + \
                list(self.gen_s.parameters()) + \
                list(self.fea_s.parameters()) + list(self.fea_m.parameters())
        else:
            gen_params = list(self.gen_i.parameters()) + list(
                self.gen_r.parameters()) + list(self.gen_s.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=param['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=param['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, param)
        self.gen_scheduler = get_scheduler(self.gen_opt, param)
        # Network weight initialization
        self.apply(weights_init(param['init']))
        self.dis_r.apply(weights_init('gaussian'))
        self.dis_s.apply(weights_init('gaussian'))
        self.best_result = float('inf')
        self.reflectance_loss = LocalAlbedoSmoothnessLoss(param)

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def physical_criterion(self, x_i, x_r, x_s):
        return torch.mean(torch.abs(x_i - x_r * x_s))

    def forward(self, x_i):
        c_i, s_i_fake = self.gen_i.encode(x_i)
        if self.with_mapping:
            s_r, s_s = self.fea_s(s_i_fake)
        else:
            s_r = Variable(
                torch.randn(x_i.size(0), self.style_dim, 1,
                            1).cuda()) + self.bias_shift
            s_s = Variable(
                torch.randn(x_i.size(0), self.style_dim, 1,
                            1).cuda()) - self.bias_shift
        x_ri = self.gen_r.decode(c_i, s_r)
        x_si = self.gen_s.decode(c_i, s_s)
        return x_ri, x_si

    def inference(self, x_i, use_rand_fea=False):
        with torch.no_grad():
            c_i, s_i_fake = self.gen_i.encode(x_i)
            if self.with_mapping:
                s_r, s_s = self.fea_s(s_i_fake)
            else:
                s_r = Variable(
                    torch.randn(x_i.size(0), self.style_dim, 1,
                                1).cuda()) + self.bias_shift
                s_s = Variable(
                    torch.randn(x_i.size(0), self.style_dim, 1,
                                1).cuda()) - self.bias_shift
            if use_rand_fea:
                s_r = Variable(
                    torch.randn(x_i.size(0), self.style_dim, 1,
                                1).cuda()) + self.bias_shift
                s_s = Variable(
                    torch.randn(x_i.size(0), self.style_dim, 1,
                                1).cuda()) - self.bias_shift
            x_ri = self.gen_r.decode(c_i, s_r)
            x_si = self.gen_s.decode(c_i, s_s)
        return x_ri, x_si

    # noinspection PyAttributeOutsideInit
    def gen_update(self, x_i, x_r, x_s, targets=None, param=None):
        self.gen_opt.zero_grad()
        # ============= Domain Translations =============
        # encode
        c_i, s_i_prime = self.gen_i.encode(x_i)
        c_r, s_r_prime = self.gen_r.encode(x_r)
        c_s, s_s_prime = self.gen_s.encode(x_s)
        if self.with_mapping:
            s_ri, s_si = self.fea_s(s_i_prime)
        else:
            s_ri = Variable(
                torch.randn(x_i.size(0), self.style_dim, 1,
                            1).cuda()) + self.bias_shift
            s_si = Variable(
                torch.randn(x_i.size(0), self.style_dim, 1,
                            1).cuda()) - self.bias_shift
        s_r_rand = Variable(
            torch.randn(x_r.size(0), self.style_dim, 1,
                        1).cuda()) + self.bias_shift
        s_s_rand = Variable(
            torch.randn(x_s.size(0), self.style_dim, 1,
                        1).cuda()) - self.bias_shift
        if self.with_mapping:
            s_i_recon = self.fea_m(s_ri, s_si)
        else:
            s_i_recon = Variable(
                torch.randn(x_i.size(0), self.style_dim, 1, 1).cuda())
        # decode (within domain)
        x_i_recon = self.gen_i.decode(c_i, s_i_prime)
        x_r_recon = self.gen_r.decode(c_r, s_r_prime)
        x_s_recon = self.gen_s.decode(c_s, s_s_prime)
        # decode (cross domain)
        x_rs = self.gen_r.decode(c_s, s_r_rand)
        x_ri = self.gen_r.decode(c_i, s_ri)
        x_ri_rand = self.gen_r.decode(c_i, s_r_rand)
        x_sr = self.gen_s.decode(c_r, s_s_rand)
        x_si = self.gen_s.decode(c_i, s_si)
        x_si_rand = self.gen_s.decode(c_i, s_s_rand)
        # encode again, for feature domain consistency constraints
        c_rs_recon, s_rs_recon = self.gen_r.encode(x_rs)
        c_ri_recon, s_ri_recon = self.gen_r.encode(x_ri)
        c_ri_rand_recon, s_ri_rand_recon = self.gen_r.encode(x_ri_rand)
        c_sr_recon, s_sr_recon = self.gen_s.encode(x_sr)
        c_si_recon, s_si_recon = self.gen_s.encode(x_si)
        c_si_rand_recon, s_si_rand_recon = self.gen_s.encode(x_si_rand)
        # decode again, for image domain cycle consistency
        x_rsr = self.gen_r.decode(c_sr_recon, s_r_prime)
        x_iri = self.gen_i.decode(c_ri_recon, s_i_prime)
        x_iri_rand = self.gen_i.decode(c_ri_rand_recon, s_i_prime)
        x_srs = self.gen_s.decode(c_rs_recon, s_s_prime)
        x_isi = self.gen_i.decode(c_si_recon, s_i_prime)
        x_isi_rand = self.gen_i.decode(c_si_rand_recon, s_i_prime)
        # ============= Loss Functions =============
        # Encoder-decoder reconstruction loss for the three domains
        self.loss_gen_recon_x_i = self.recon_criterion(x_i_recon, x_i)
        self.loss_gen_recon_x_r = self.recon_criterion(x_r_recon, x_r)
        self.loss_gen_recon_x_s = self.recon_criterion(x_s_recon, x_s)
        # Style-level reconstruction loss for cross domain
        if self.with_mapping:
            self.loss_gen_recon_s_ii = self.recon_criterion(
                s_i_recon, s_i_prime)
        else:
            self.loss_gen_recon_s_ii = 0
        self.loss_gen_recon_s_ri = self.recon_criterion(s_ri_recon, s_ri)
        self.loss_gen_recon_s_ri_rand = self.recon_criterion(
            s_ri_rand_recon, s_ri)
        self.loss_gen_recon_s_rs = self.recon_criterion(s_rs_recon, s_r_rand)
        self.loss_gen_recon_s_sr = self.recon_criterion(s_sr_recon, s_s_rand)
        self.loss_gen_recon_s_si = self.recon_criterion(s_si_recon, s_si)
        self.loss_gen_recon_s_si_rand = self.recon_criterion(
            s_si_rand_recon, s_si)
        # Content-level reconstruction loss for cross domain
        self.loss_gen_recon_c_rs = self.recon_criterion(c_rs_recon, c_s)
        self.loss_gen_recon_c_ri = self.recon_criterion(
            c_ri_recon, c_i) if self.use_content_loss is True else 0
        self.loss_gen_recon_c_ri_rand = self.recon_criterion(
            c_ri_rand_recon, c_i) if self.use_content_loss is True else 0
        self.loss_gen_recon_c_sr = self.recon_criterion(c_sr_recon, c_r)
        self.loss_gen_recon_c_si = self.recon_criterion(
            c_si_recon, c_i) if self.use_content_loss is True else 0
        self.loss_gen_recon_c_si_rand = self.recon_criterion(
            c_si_rand_recon, c_i) if self.use_content_loss is True else 0
        # Cycle consistency loss for the three image domains
        self.loss_gen_cyc_recon_x_rs = self.recon_criterion(x_rsr, x_r)
        self.loss_gen_cyc_recon_x_ir = self.recon_criterion(x_iri, x_i)
        self.loss_gen_cyc_recon_x_ir_rand = self.recon_criterion(
            x_iri_rand, x_i)
        self.loss_gen_cyc_recon_x_sr = self.recon_criterion(x_srs, x_s)
        self.loss_gen_cyc_recon_x_is = self.recon_criterion(x_isi, x_i)
        self.loss_gen_cyc_recon_x_is_rand = self.recon_criterion(
            x_isi_rand, x_i)
        # GAN loss
        self.loss_gen_adv_rs = self.dis_r.calc_gen_loss(x_rs)
        self.loss_gen_adv_ri = self.dis_r.calc_gen_loss(x_ri)
        self.loss_gen_adv_ri_rand = self.dis_r.calc_gen_loss(x_ri_rand)
        self.loss_gen_adv_sr = self.dis_s.calc_gen_loss(x_sr)
        self.loss_gen_adv_si = self.dis_s.calc_gen_loss(x_si)
        self.loss_gen_adv_si_rand = self.dis_s.calc_gen_loss(x_si_rand)
        # Physical loss
        self.loss_gen_phy_i = self.physical_criterion(
            x_i, x_ri, x_si) if self.use_phy_loss is True else 0
        self.loss_gen_phy_i_rand = self.physical_criterion(
            x_i, x_ri_rand, x_si_rand) if self.use_phy_loss is True else 0
        # Reflectance smoothness loss
        self.loss_refl_ri = self.reflectance_loss(
            x_ri, targets) if targets is not None else 0
        self.loss_refl_ri_rand = self.reflectance_loss(
            x_ri_rand, targets) if targets is not None else 0
        # total loss
        self.loss_gen_total = param['gan_w'] * self.loss_gen_adv_rs + \
            param['gan_w'] * self.loss_gen_adv_ri + \
            param['gan_w'] * self.loss_gen_adv_ri_rand + \
            param['gan_w'] * self.loss_gen_adv_sr + \
            param['gan_w'] * self.loss_gen_adv_si + \
            param['gan_w'] * self.loss_gen_adv_si_rand + \
            param['recon_x_w'] * self.loss_gen_recon_x_i + \
            param['recon_x_w'] * self.loss_gen_recon_x_r + \
            param['recon_x_w'] * self.loss_gen_recon_x_s + \
            param['recon_s_w'] * self.loss_gen_recon_s_ii + \
            param['recon_s_w'] * self.loss_gen_recon_s_ri + \
            param['recon_s_w'] * self.loss_gen_recon_s_ri_rand + \
            param['recon_s_w'] * self.loss_gen_recon_s_rs + \
            param['recon_s_w'] * self.loss_gen_recon_s_si + \
            param['recon_s_w'] * self.loss_gen_recon_s_si_rand + \
            param['recon_s_w'] * self.loss_gen_recon_s_sr + \
            param['recon_c_w'] * self.loss_gen_recon_c_ri + \
            param['recon_c_w'] * self.loss_gen_recon_c_rs + \
            param['recon_c_w'] * self.loss_gen_recon_c_ri_rand + \
            param['recon_c_w'] * self.loss_gen_recon_c_si + \
            param['recon_c_w'] * self.loss_gen_recon_c_sr + \
            param['recon_c_w'] * self.loss_gen_recon_c_si_rand + \
            param['recon_x_cyc_w'] * self.loss_gen_cyc_recon_x_ir + \
            param['recon_x_cyc_w'] * self.loss_gen_cyc_recon_x_ir_rand + \
            param['recon_x_cyc_w'] * self.loss_gen_cyc_recon_x_is + \
            param['recon_x_cyc_w'] * self.loss_gen_cyc_recon_x_is_rand + \
            param['recon_x_cyc_w'] * self.loss_gen_cyc_recon_x_rs + \
            param['recon_x_cyc_w'] * self.loss_gen_cyc_recon_x_sr + \
            param['phy_x_w'] * self.loss_gen_phy_i + \
            param['phy_x_w'] * self.loss_gen_phy_i_rand + \
            param['refl_smooth_w'] * self.loss_refl_ri + \
            param['refl_smooth_w'] * self.loss_refl_ri_rand
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def sample(self, x_i, x_r, x_s):
        self.eval()
        s_r = Variable(self.s_r)
        s_s = Variable(self.s_s)
        x_i_recon, x_r_recon, x_s_recon, x_rs, x_ri, x_sr, x_si = [], [], [], [], [], [], []
        for i in range(x_i.size(0)):
            c_i, s_i_fake = self.gen_i.encode(x_i[i].unsqueeze(0))
            c_r, s_r_fake = self.gen_r.encode(x_r[i].unsqueeze(0))
            c_s, s_s_fake = self.gen_s.encode(x_s[i].unsqueeze(0))
            if self.with_mapping:
                s_ri, s_si = self.fea_s(s_i_fake)
            else:
                s_ri = Variable(torch.randn(1, self.style_dim, 1,
                                            1).cuda()) + self.bias_shift
                s_si = Variable(torch.randn(1, self.style_dim, 1,
                                            1).cuda()) - self.bias_shift
            x_i_recon.append(self.gen_i.decode(c_i, s_i_fake))
            x_r_recon.append(self.gen_r.decode(c_r, s_r_fake))
            x_s_recon.append(self.gen_s.decode(c_s, s_s_fake))
            x_rs.append(self.gen_r.decode(c_s, s_r[i].unsqueeze(0)))
            x_ri.append(self.gen_r.decode(c_i, s_ri.unsqueeze(0)))
            x_sr.append(self.gen_s.decode(c_r, s_s[i].unsqueeze(0)))
            x_si.append(self.gen_s.decode(c_i, s_si.unsqueeze(0)))
        x_i_recon, x_r_recon, x_s_recon = torch.cat(x_i_recon), torch.cat(
            x_r_recon), torch.cat(x_s_recon)
        x_rs, x_ri = torch.cat(x_rs), torch.cat(x_ri)
        x_sr, x_si = torch.cat(x_sr), torch.cat(x_si)
        self.train()
        return x_i, x_i_recon, x_r, x_r_recon, x_rs, x_ri, x_s, x_s_recon, x_sr, x_si

    # noinspection PyAttributeOutsideInit
    def dis_update(self, x_i, x_r, x_s, params):
        self.dis_opt.zero_grad()
        s_r = Variable(
            torch.randn(x_r.size(0), self.style_dim, 1,
                        1).cuda()) + self.bias_shift
        s_s = Variable(
            torch.randn(x_s.size(0), self.style_dim, 1,
                        1).cuda()) - self.bias_shift
        # encode
        c_r, _ = self.gen_r.encode(x_r)
        c_s, _ = self.gen_s.encode(x_s)
        c_i, s_i = self.gen_i.encode(x_i)
        if self.with_mapping:
            s_ri, s_si = self.fea_s(s_i)
        else:
            s_ri = Variable(
                torch.randn(x_i.size(0), self.style_dim, 1,
                            1).cuda()) + self.bias_shift
            s_si = Variable(
                torch.randn(x_i.size(0), self.style_dim, 1,
                            1).cuda()) - self.bias_shift
        # decode (cross domain)
        x_rs = self.gen_r.decode(c_s, s_r)
        x_ri = self.gen_r.decode(c_i, s_ri)
        x_sr = self.gen_s.decode(c_r, s_s)
        x_si = self.gen_s.decode(c_i, s_si)
        # D loss
        self.loss_dis_rs = self.dis_r.calc_dis_loss(x_rs.detach(), x_r)
        self.loss_dis_ri = self.dis_r.calc_dis_loss(x_ri.detach(), x_r)
        self.loss_dis_sr = self.dis_s.calc_dis_loss(x_sr.detach(), x_s)
        self.loss_dis_si = self.dis_s.calc_dis_loss(x_si.detach(), x_s)
        self.loss_dis_total = params['gan_w'] * self.loss_dis_rs + \
            params['gan_w'] * self.loss_dis_ri + \
            params['gan_w'] * self.loss_dis_sr + \
            params['gan_w'] * self.loss_dis_si
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, param):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_i.load_state_dict(state_dict['i'])
        self.gen_r.load_state_dict(state_dict['r'])
        self.gen_s.load_state_dict(state_dict['s'])
        if self.with_mapping:
            self.fea_m.load_state_dict(state_dict['fm'])
            self.fea_s.load_state_dict(state_dict['fs'])
        self.best_result = state_dict['best_result']
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_r.load_state_dict(state_dict['r'])
        self.dis_s.load_state_dict(state_dict['s'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, param, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, param, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        if self.with_mapping:
            torch.save(
                {
                    'i': self.gen_i.state_dict(),
                    'r': self.gen_r.state_dict(),
                    's': self.gen_s.state_dict(),
                    'fs': self.fea_s.state_dict(),
                    'fm': self.fea_m.state_dict(),
                    'best_result': self.best_result
                }, gen_name)
        else:
            torch.save(
                {
                    'i': self.gen_i.state_dict(),
                    'r': self.gen_r.state_dict(),
                    's': self.gen_s.state_dict(),
                    'best_result': self.best_result
                }, gen_name)
        torch.save({
            'r': self.dis_r.state_dict(),
            's': self.dis_s.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
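
# Usage sketch (an assumption, not from the original sources): the trainer's
# physical_criterion encodes the intrinsic-image constraint I ~ R * S. Given a
# hypothetical input batch `x_i`, inference() returns reflectance and shading,
# and the residual below measures how well their product reassembles the input.
def example_intrinsic_decomposition(trainer, x_i):
    x_r, x_s = trainer.inference(x_i)  # reflectance, shading
    # Same form as physical_criterion above.
    residual = torch.mean(torch.abs(x_i - x_r * x_s))
    return x_r, x_s, residual
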
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters, resume_epoch=-1, snapshot_dir=None):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks.
        self.gen = AdaINGen(
            hyperparameters['input_dim'] + hyperparameters['n_datasets'],
            hyperparameters['gen'],
            hyperparameters['n_datasets'])  # Auto-encoder shared across datasets.
        self.dis = MsImageDis(
            hyperparameters['input_dim'] + hyperparameters['n_datasets'],
            hyperparameters['dis'])  # Discriminator shared across datasets.
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']
        self.beta1 = hyperparameters['beta1']
        self.beta2 = hyperparameters['beta2']
        self.weight_decay = hyperparameters['weight_decay']
        # Initialize and load the pretrained UNet.
        self.sup = UNet(input_channels=hyperparameters['input_dim'],
                        num_classes=2).cuda()
        # Fix the noise used in sampling.
        self.s_a = torch.randn(8, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(8, self.style_dim, 1, 1).cuda()
        # Setup the optimizers.
        dis_params = list(self.dis.parameters())
        gen_params = list(self.gen.parameters()) + list(self.sup.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(self.beta1, self.beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(self.beta1, self.beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        # Network weight initialization.
        self.apply(weights_init(hyperparameters['init']))
        self.dis.apply(weights_init('gaussian'))
        # Presetting one hot encoding vectors.
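        # one_hot_img[d] holds a batch of one-hot dataset-code maps at image
        # resolution (256x256) and one_hot_c[d] holds the same code at content
        # resolution (64x64); they are channel-concatenated with images before
        # encoding and with content codes before decoding, respectively.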
        self.one_hot_img = torch.zeros(hyperparameters['n_datasets'],
                                       hyperparameters['batch_size'],
                                       hyperparameters['n_datasets'], 256,
                                       256).cuda()
        self.one_hot_c = torch.zeros(hyperparameters['n_datasets'],
                                     hyperparameters['batch_size'],
                                     hyperparameters['n_datasets'], 64,
                                     64).cuda()
        for i in range(hyperparameters['n_datasets']):
            self.one_hot_img[i, :, i, :, :].fill_(1)
            self.one_hot_c[i, :, i, :, :].fill_(1)
        if resume_epoch != -1:
            self.resume(snapshot_dir, hyperparameters)

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def semi_criterion(self, input, target):
        loss = CrossEntropyLoss2d(size_average=False).cuda()
        return loss(input, target)

    def forward(self, x_a, x_b, d_index_a, d_index_b):
        self.eval()
        x_a.volatile = True
        x_b.volatile = True
        s_a = Variable(self.s_a, volatile=True)
        s_b = Variable(self.s_b, volatile=True)
        one_hot_x_a = torch.cat([x_a, self.one_hot_img[d_index_a]], 1)
        one_hot_x_b = torch.cat([x_b, self.one_hot_img[d_index_b]], 1)
        c_a, s_a_fake = self.gen.encode(one_hot_x_a)
        c_b, s_b_fake = self.gen.encode(one_hot_x_b)
        one_hot_c_b = torch.cat([c_b, self.one_hot_c[d_index_a]], 1)
        one_hot_c_a = torch.cat([c_a, self.one_hot_c[d_index_b]], 1)
        x_ba = self.gen.decode(one_hot_c_b, s_a)
        x_ab = self.gen.decode(one_hot_c_a, s_b)
        self.train()
        return x_ab, x_ba

    def set_gen_trainable(self, train_bool):
        if train_bool:
            self.gen.train()
            for param in self.gen.parameters():
                param.requires_grad = True
        else:
            self.gen.eval()
            for param in self.gen.parameters():
                param.requires_grad = False

    def set_sup_trainable(self, train_bool):
        if train_bool:
            self.sup.train()
            for param in self.sup.parameters():
                param.requires_grad = True
        else:
            self.sup.eval()
            for param in self.sup.parameters():
                param.requires_grad = False

    def sup_update(self, x_a, x_b, y_a, y_b, d_index_a, d_index_b, use_a,
                   use_b, hyperparameters):
        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        one_hot_x_a = torch.cat([x_a, self.one_hot_img[d_index_a]], 1)
        one_hot_x_b = torch.cat([x_b, self.one_hot_img[d_index_b]], 1)
        # Encode.
        c_a, s_a_prime = self.gen.encode(one_hot_x_a)
        c_b, s_b_prime = self.gen.encode(one_hot_x_b)
        # Decode (within domain).
        one_hot_c_a = torch.cat([c_a, self.one_hot_c[d_index_a]], 1)
        one_hot_c_b = torch.cat([c_b, self.one_hot_c[d_index_b]], 1)
        x_a_recon = self.gen.decode(one_hot_c_a, s_a_prime)
        x_b_recon = self.gen.decode(one_hot_c_b, s_b_prime)
        # Decode (cross domain).
        one_hot_c_ab = torch.cat([c_a, self.one_hot_c[d_index_b]], 1)
        one_hot_c_ba = torch.cat([c_b, self.one_hot_c[d_index_a]], 1)
        x_ba = self.gen.decode(one_hot_c_ba, s_a)
        x_ab = self.gen.decode(one_hot_c_ab, s_b)
        # Encode again.
        one_hot_x_ba = torch.cat([x_ba, self.one_hot_img[d_index_a]], 1)
        one_hot_x_ab = torch.cat([x_ab, self.one_hot_img[d_index_b]], 1)
        c_b_recon, s_a_recon = self.gen.encode(one_hot_x_ba)
        c_a_recon, s_b_recon = self.gen.encode(one_hot_x_ab)
        # Forwarding through supervised model.
        p_a = None
        p_b = None
        loss_semi_a = None
        loss_semi_b = None
        has_a_label = (c_a[use_a, :, :, :].size(0) != 0)
        if has_a_label:
            p_a = self.sup(c_a, use_a, True)
            p_a_recon = self.sup(c_a_recon, use_a, True)
            loss_semi_a = self.semi_criterion(p_a, y_a[use_a, :, :]) + \
                self.semi_criterion(p_a_recon, y_a[use_a, :, :])
        has_b_label = (c_b[use_b, :, :, :].size(0) != 0)
        if has_b_label:
            p_b = self.sup(c_b, use_b, True)
            p_b_recon = self.sup(c_b_recon, use_b, True)
            loss_semi_b = self.semi_criterion(p_b, y_b[use_b, :, :]) + \
                self.semi_criterion(p_b_recon, y_b[use_b, :, :])
        self.loss_gen_total = None
        if loss_semi_a is not None and loss_semi_b is not None:
            self.loss_gen_total = loss_semi_a + loss_semi_b
        elif loss_semi_a is not None:
            self.loss_gen_total = loss_semi_a
        elif loss_semi_b is not None:
            self.loss_gen_total = loss_semi_b
        if self.loss_gen_total is not None:
            self.loss_gen_total.backward()
            self.gen_opt.step()

    def sup_forward(self, x, y, d_index, hyperparameters):
        self.sup.eval()
        # Encoding content image.
        one_hot_x = torch.cat([x, self.one_hot_img[d_index, 0].unsqueeze(0)],
                              1)
        content, _ = self.gen.encode(one_hot_x)
        # Forwarding on supervised model.
        y_pred = self.sup(content, only_prediction=True)
        # Computing metrics.
        pred = y_pred.data.max(1)[1].squeeze_(1).squeeze_(0).cpu().numpy()
        jacc = jaccard(pred, y.cpu().squeeze(0).numpy())
        return jacc, pred, content

    def gen_update(self, x_a, x_b, d_index_a, d_index_b, hyperparameters):
        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # Encode.
        one_hot_x_a = torch.cat([x_a, self.one_hot_img[d_index_a]], 1)
        one_hot_x_b = torch.cat([x_b, self.one_hot_img[d_index_b]], 1)
        c_a, s_a_prime = self.gen.encode(one_hot_x_a)
        c_b, s_b_prime = self.gen.encode(one_hot_x_b)
        # Decode (within domain).
        one_hot_c_a = torch.cat([c_a, self.one_hot_c[d_index_a]], 1)
        one_hot_c_b = torch.cat([c_b, self.one_hot_c[d_index_b]], 1)
        x_a_recon = self.gen.decode(one_hot_c_a, s_a_prime)
        x_b_recon = self.gen.decode(one_hot_c_b, s_b_prime)
        # Decode (cross domain).
        one_hot_c_ab = torch.cat([c_a, self.one_hot_c[d_index_b]], 1)
        one_hot_c_ba = torch.cat([c_b, self.one_hot_c[d_index_a]], 1)
        x_ba = self.gen.decode(one_hot_c_ba, s_a)
        x_ab = self.gen.decode(one_hot_c_ab, s_b)
        # Encode again.
        one_hot_x_ba = torch.cat([x_ba, self.one_hot_img[d_index_a]], 1)
        one_hot_x_ab = torch.cat([x_ab, self.one_hot_img[d_index_b]], 1)
        c_b_recon, s_a_recon = self.gen.encode(one_hot_x_ba)
        c_a_recon, s_b_recon = self.gen.encode(one_hot_x_ab)
        # Decode again (if needed).
        one_hot_c_aba_recon = torch.cat(
            [c_a_recon, self.one_hot_c[d_index_a]], 1)
        one_hot_c_bab_recon = torch.cat(
            [c_b_recon, self.one_hot_c[d_index_b]], 1)
        x_aba = self.gen.decode(one_hot_c_aba_recon, s_a_prime)
        x_bab = self.gen.decode(one_hot_c_bab_recon, s_b_prime)
        # Reconstruction loss.
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a)
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b)
        # GAN loss.
        self.loss_gen_adv_a = self.dis.calc_gen_loss(one_hot_x_ba)
        self.loss_gen_adv_b = self.dis.calc_gen_loss(one_hot_x_ab)
        # Total loss.
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
            hyperparameters['gan_w'] * self.loss_gen_adv_b + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def sample(self, x_a, x_b, d_index_a, d_index_b):
        self.eval()
        x_a.volatile = True
        x_b.volatile = True
        s_a1 = Variable(self.s_a, volatile=True)
        s_b1 = Variable(self.s_b, volatile=True)
        s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda(),
                        volatile=True)
        s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda(),
                        volatile=True)
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            one_hot_x_a = torch.cat([
                x_a[i].unsqueeze(0), self.one_hot_img[d_index_a,
                                                      i].unsqueeze(0)
            ], 1)
            one_hot_x_b = torch.cat([
                x_b[i].unsqueeze(0), self.one_hot_img[d_index_b,
                                                      i].unsqueeze(0)
            ], 1)
            c_a, s_a_fake = self.gen.encode(one_hot_x_a)
            c_b, s_b_fake = self.gen.encode(one_hot_x_b)
            x_a_recon.append(self.gen.decode(c_a, s_a_fake))
            x_b_recon.append(self.gen.decode(c_b, s_b_fake))
            x_ba1.append(self.gen.decode(c_b, s_a1[i].unsqueeze(0)))
            x_ba2.append(self.gen.decode(c_b, s_a2[i].unsqueeze(0)))
            x_ab1.append(self.gen.decode(c_a, s_b1[i].unsqueeze(0)))
            x_ab2.append(self.gen.decode(c_a, s_b2[i].unsqueeze(0)))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_update(self, x_a, x_b, d_index_a, d_index_b, hyperparameters):
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # Encode.
        one_hot_x_a = torch.cat([x_a, self.one_hot_img[d_index_a]], 1)
        one_hot_x_b = torch.cat([x_b, self.one_hot_img[d_index_b]], 1)
        c_a, _ = self.gen.encode(one_hot_x_a)
        c_b, _ = self.gen.encode(one_hot_x_b)
        one_hot_c_ba = torch.cat([c_b, self.one_hot_c[d_index_a]], 1)
        one_hot_c_ab = torch.cat([c_a, self.one_hot_c[d_index_b]], 1)
        # Decode (cross domain).
        x_ba = self.gen.decode(one_hot_c_ba, s_a)
        x_ab = self.gen.decode(one_hot_c_ab, s_b)
        # D loss.
        one_hot_x_ba = torch.cat([x_ba, self.one_hot_img[d_index_a]], 1)
        one_hot_x_ab = torch.cat([x_ab, self.one_hot_img[d_index_b]], 1)
        self.loss_dis_a = self.dis.calc_dis_loss(one_hot_x_ba.detach(),
                                                 one_hot_x_a)
        self.loss_dis_b = self.dis.calc_dis_loss(one_hot_x_ab.detach(),
                                                 one_hot_x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + \
            hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        print("--> " + checkpoint_dir)
        # Load generator.
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen.load_state_dict(state_dict)
        epochs = int(last_model_name[-11:-3])
        # Load supervised model.
        last_model_name = get_model_list(checkpoint_dir, "sup")
        state_dict = torch.load(last_model_name)
        self.sup.load_state_dict(state_dict)
        # Load discriminator.
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis.load_state_dict(state_dict)
        # Load optimizers.
        last_model_name = get_model_list(checkpoint_dir, "opt")
        state_dict = torch.load(last_model_name)
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        for state in self.dis_opt.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()
        for state in self.gen_opt.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()
        # Reinitialize schedulers.
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           epochs)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           epochs)
        print('Resume from epoch %d' % epochs)
        return epochs

    def save(self, snapshot_dir, epoch):
        # Save generators, discriminators, and optimizers.
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % epoch)
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % epoch)
        sup_name = os.path.join(snapshot_dir, 'sup_%08d.pt' % epoch)
        opt_name = os.path.join(snapshot_dir, 'opt_%08d.pt' % epoch)
        torch.save(self.gen.state_dict(), gen_name)
        torch.save(self.dis.state_dict(), dis_name)
        torch.save(self.sup.state_dict(), sup_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
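
# Usage sketch (an assumption, not from the original sources): the trainer
# above conditions a single encoder/decoder pair on the source dataset by
# channel-concatenating a one-hot dataset code, at image resolution for the
# encoder and at content resolution for the decoder. `trainer`, `x`, `d_src`,
# and `d_dst` are hypothetical.
def example_domain_encode_decode(trainer, x, d_src, d_dst):
    one_hot_x = torch.cat([x, trainer.one_hot_img[d_src]], 1)  # mark source dataset
    c, s = trainer.gen.encode(one_hot_x)
    one_hot_c = torch.cat([c, trainer.one_hot_c[d_dst]], 1)    # retarget to another dataset
    return trainer.gen.decode(one_hot_c, s)
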
class DGNet_Trainer(nn.Module):
    def __init__(self, hyperparameters, gpu_ids=[0]):
        super(DGNet_Trainer, self).__init__()
        # Generator and discriminator learning rates come from the config file.
        lr_g = hyperparameters['lr_g']
        lr_d = hyperparameters['lr_d']
        # Number of identity classes. Note this differs between datasets: it
        # should be the number of IDs in the training set, not the test set.
        ID_class = hyperparameters['ID_class']
        # Whether float16 (apex) training is enabled; fp16 mainly saves memory
        # and speeds up training.
        if not 'apex' in hyperparameters.keys():
            hyperparameters['apex'] = False
        self.fp16 = hyperparameters['apex']
        # Initiate the networks.
        # We do not need to manually set fp16 in the network for the new apex,
        # so fp16=False here.
        ########################################################################
        # Es and G: this module contains both the structure encoder Es and the
        # decoder (the yellow trapezoid G in Figure 2 of the paper), so the
        # appearance encoder Ea below holds no decoder. gen_a.encode() and
        # gen_a.decode() expose the two steps.
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'],
                              hyperparameters['gen'],
                              fp16=False)  # auto-encoder for domain a
        self.gen_b = self.gen_a  # auto-encoder for domain b (shared weights)
        ########################################################################
        # Ea: the appearance encoder. ID_stride is the stride of its pooling
        # layer.
        if not 'ID_stride' in hyperparameters.keys():
            hyperparameters['ID_stride'] = 2
        # ID_style defaults to 'AB' (the Ea encoder of the paper). Three models
        # are available: PCB, ft_netAB (a modified ResNet-50), and ft_net
        # (plain ResNet-50).
        if hyperparameters['ID_style'] == 'PCB':
            self.id_a = PCB(ID_class)
        elif hyperparameters['ID_style'] == 'AB':
            # The branch actually used: ft_netAB implements the Ea encoding
            # that yields the appearance (ap) code f, plus two identity
            # predictions that serve as the re-identification outputs.
            # ID_class is the number of distinct pedestrian identities.
            self.id_a = ft_netAB(ID_class,
                                 stride=hyperparameters['ID_stride'],
                                 norm=hyperparameters['norm_id'],
                                 pool=hyperparameters['pool'])
        else:
            self.id_a = ft_net(ID_class,
                               norm=hyperparameters['norm_id'],
                               pool=hyperparameters['pool'])  # return 2048 now
        # Shallow copy: id_a and id_b share weights, effectively one encoder.
        self.id_b = self.id_a
        ########################################################################
        # D: a multi-scale discriminator. The image is rescaled several times,
        # a prediction is made at every scale, and the losses are summed. The
        # network outputs 3 maps of sizes [batch, 1, 64, 32],
        # [batch, 1, 32, 16] and [batch, 1, 16, 8].
        self.dis_a = MsImageDis(3, hyperparameters['dis'],
                                fp16=False)  # discriminator for domain a
        self.dis_b = self.dis_a  # discriminator for domain b (shared weights)
        ########################################################################
        # Load teacher models. `teacher` is the teacher model name; for
        # DukeMTMC you can set "best-duke".
        if hyperparameters['teacher'] != "":
            # teacher_name = 'best'
            teacher_name = hyperparameters['teacher']
            print(teacher_name)
            # Splitting on ',' allows several teacher models to be loaded.
            teacher_names = teacher_name.split(',')
            # Build the teacher model list.
            teacher_model = nn.ModuleList()
            teacher_count = 0
            # By default there is a single teacher, whose config file is
            # models/best/opts.yaml under the project root.
            for teacher_name in teacher_names:
                # Load the config file models/best/opts.yaml.
                config_tmp = load_config(teacher_name)
                if 'stride' in config_tmp:
                    # stride = 1
                    stride = config_tmp['stride']
                else:
                    stride = 2
                # The teacher model is ft_net (ResNet-50).
                model_tmp = ft_net(ID_class, stride=stride)
                teacher_model_tmp = load_network(model_tmp, teacher_name)
                # Remove the original ImageNet fc layer.
                teacher_model_tmp.model.fc = nn.Sequential()
                teacher_model_tmp = teacher_model_tmp.cuda()
                # summary(teacher_model_tmp, (3, 224, 224))
                # Optionally run the teacher in fp16.
                if self.fp16:
                    teacher_model_tmp = amp.initialize(teacher_model_tmp,
                                                       opt_level="O1")
                teacher_model.append(teacher_model_tmp.cuda().eval())
                teacher_count += 1
            self.teacher_model = teacher_model
            # Optionally keep BatchNorm layers in training mode.
            if hyperparameters['train_bn']:
                self.teacher_model = self.teacher_model.apply(train_bn)
        ########################################################################
        # Instance normalization.
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        # RGB to one channel. By default single='gray': Es takes a gray-scale
        # image as input.
        if hyperparameters['single'] == 'edge':
            self.single = to_edge
        else:
            self.single = to_gray(False)
        # Random Erasing when training. erasing_p is the probability of
        # randomly erasing a rectangular region of pixels, a form of data
        # augmentation.
        if not 'erasing_p' in hyperparameters.keys():
            self.erasing_p = 0
        else:
            self.erasing_p = hyperparameters['erasing_p']
        self.single_re = RandomErasing(probability=self.erasing_p,
                                       mean=[0.0, 0.0, 0.0])
        # T_w defaults to 1: the weight of the primary feature learning loss.
        if not 'T_w' in hyperparameters.keys():
            hyperparameters['T_w'] = 1
        ########################################################################
        # Setup the optimizers.
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(
            self.dis_a.parameters())  #+ list(self.dis_b.parameters())
        gen_params = list(
            self.gen_a.parameters())  #+ list(self.gen_b.parameters())
        # Adam optimizers train Es, G and D.
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr_d,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr_g,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        # id params
        # ID_style defaults to 'AB', so the PCB branch is normally skipped.
        if hyperparameters['ID_style'] == 'PCB':
            ignored_params = (
                list(map(id, self.id_a.classifier0.parameters())) +
                list(map(id, self.id_a.classifier1.parameters())) +
                list(map(id, self.id_a.classifier2.parameters())) +
                list(map(id, self.id_a.classifier3.parameters())))
            base_params = filter(lambda p: id(p) not in ignored_params,
                                 self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            # Optimizer for Ea.
            self.id_opt = torch.optim.SGD(
                [{
                    'params': base_params,
                    'lr': lr2
                }, {
                    'params': self.id_a.classifier0.parameters(),
                    'lr': lr2 * 10
                }, {
                    'params': self.id_a.classifier1.parameters(),
                    'lr': lr2 * 10
                }, {
                    'params': self.id_a.classifier2.parameters(),
                    'lr': lr2 * 10
                }, {
                    'params': self.id_a.classifier3.parameters(),
                    'lr': lr2 * 10
                }],
                weight_decay=hyperparameters['weight_decay'],
                momentum=0.9,
                nesterov=True)
        # This is the branch actually executed.
        elif hyperparameters['ID_style'] == 'AB':
            # Parameters excluded from the base learning rate: the two
            # classifier heads, which get 10x the base rate.
            ignored_params = (
                list(map(id, self.id_a.classifier1.parameters())) +
                list(map(id, self.id_a.classifier2.parameters())))
            # The remaining parameters use the base learning rate lr2.
            base_params = filter(lambda p: id(p) not in ignored_params,
                                 self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            # SGD for Ea.
            self.id_opt = torch.optim.SGD(
                [{
                    'params': base_params,
                    'lr': lr2
                }, {
                    'params': self.id_a.classifier1.parameters(),
                    'lr': lr2 * 10
                }, {
                    'params': self.id_a.classifier2.parameters(),
                    'lr': lr2 * 10
                }],
                weight_decay=hyperparameters['weight_decay'],
                momentum=0.9,
                nesterov=True)
        else:
            ignored_params = list(map(id, self.id_a.classifier.parameters()))
            base_params = filter(lambda p: id(p) not in ignored_params,
                                 self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD(
                [{
                    'params': base_params,
                    'lr': lr2
                }, {
                    'params': self.id_a.classifier.parameters(),
                    'lr': lr2 * 10
                }],
                weight_decay=hyperparameters['weight_decay'],
                momentum=0.9,
                nesterov=True)
        # Schedulers for each optimizer.
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        self.id_scheduler = get_scheduler(self.id_opt, hyperparameters)
        self.id_scheduler.gamma = hyperparameters['gamma2']
        # ID loss: cross entropy.
        self.id_criterion = nn.CrossEntropyLoss()
        # KL divergence for teacher distillation.
        self.criterion_teacher = nn.KLDivLoss(size_average=False)
        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False
        # save memory
        if self.fp16:
            # Name the FP16_Optimizer instance to replace the existing optimizer
            assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
            self.gen_a = self.gen_a.cuda()
            self.dis_a = self.dis_a.cuda()
            self.id_a = self.id_a.cuda()
            self.gen_b = self.gen_a
            self.dis_b = self.dis_a
            self.id_b = self.id_a
            self.gen_a, self.gen_opt = amp.initialize(self.gen_a,
                                                      self.gen_opt,
                                                      opt_level="O1")
            self.dis_a, self.dis_opt = amp.initialize(self.dis_a,
                                                      self.dis_opt,
                                                      opt_level="O1")
            self.id_a, self.id_opt = amp.initialize(self.id_a,
                                                    self.id_opt,
                                                    opt_level="O1")

    def to_re(self, x):
        out = torch.FloatTensor(x.size(0), x.size(1), x.size(2), x.size(3))
        out = out.cuda()
        for i in range(x.size(0)):
            out[i, :, :, :] = self.single_re(x[i, :, :, :])
        return out

    # L1 loss (mean absolute difference).
    def recon_criterion(self, input, target):
        diff = input - target.detach()
        return torch.mean(torch.abs(diff[:]))

    # Square-root L1 loss (mean of the square root of the absolute difference).
    def recon_criterion_sqrt(self, input, target):
        diff = input - target
        return torch.mean(torch.sqrt(torch.abs(diff[:]) + 1e-8))

    # L2 loss
    def recon_criterion2(self, input, target):
        diff = input - target
        return torch.mean(diff[:]**2)

    # cos loss
    def recon_cos(self, input, target):
        cos = torch.nn.CosineSimilarity()
        cos_dis = 1 - cos(input, target)
        return torch.mean(cos_dis[:])

    # x_a, x_b, xp_a, xp_b: [4, 3, 256, 128] = [batch size, input channels,
    # image height, image width].
    def forward(self, x_a, x_b, xp_a, xp_b):
        # x_a and x_b are two images of different IDs from the training set.
        # The structure encoder Es maps each to a structure (st) code:
        # s_a, s_b: [batch, 128, 64, 32]. single() optionally converts the
        # input to gray-scale first.
        s_a = self.gen_a.encode(self.single(x_a))
        s_b = self.gen_b.encode(self.single(x_b))
        # The image is first downsampled. The ap code is smaller than the st
        # code, so the two cannot be fused directly; a fully-connected layer
        # aligns them later. f is the ap code produced by the appearance
        # encoder Ea; p holds two identity predictions, since Ea both encodes
        # appearance and predicts identity (the re-ID output).
        # f_a, f_b: [batch, 2048*4=8192]; p_a[0], p_a[1]: [batch, 751 classes].
        f_a, p_a = self.id_a(scale2(x_a))
        f_b, p_b = self.id_b(scale2(x_b))
        # Decode (the yellow trapezoid G in Figure 2): swap clothes between the
        # two different IDs. s_b [batch, 128, 64, 32] + f_a [batch, 2048, 4, 1]
        # -> x_ba [batch, 3, 256, 128].
        x_ba = self.gen_a.decode(s_b, f_a)
        x_ab = self.gen_b.decode(s_a, f_b)
        # Reconstruct each image from its own codes (auto-encoding).
        x_a_recon = self.gen_a.decode(s_a, f_a)
        x_b_recon = self.gen_b.decode(s_b, f_b)
        fp_a, pp_a = self.id_a(scale2(xp_a))
        fp_b, pp_b = self.id_b(scale2(xp_b))
        # decode the same person
        # xp_a is a different photo of the same ID as x_a, so these are
        # same-ID reconstructions.
        x_a_recon_p = self.gen_a.decode(s_a, fp_a)
        x_b_recon_p = self.gen_b.decode(s_b, fp_b)
        # Random Erasing only affects the ID and PID loss.
        # Erase some pixels from each image, then re-encode: a form of data
        # augmentation, so identity can still be recognized from partially
        # erased images.
        if self.erasing_p > 0:
            x_a_re = self.to_re(scale2(x_a.clone()))
            x_b_re = self.to_re(scale2(x_b.clone()))
            xp_a_re = self.to_re(scale2(xp_a.clone()))
            xp_b_re = self.to_re(scale2(xp_b.clone()))
            _, p_a = self.id_a(x_a_re)
            _, p_b = self.id_b(x_b_re)
            # encode the same ID different photo
            _, pp_a = self.id_a(xp_a_re)
            _, pp_b = self.id_b(xp_b_re)
        # Outputs: x_ab/x_ba are the mixed images [st of a, ap of b] and
        # [st of b, ap of a]; s_*/f_* are the st/ap codes; p_*/pp_* are the
        # identity predictions for images_*/pos_*; x_*_recon are
        # self-reconstructions and x_*_recon_p are same-ID reconstructions
        # using the positive sample's ap code.
        return x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b, \
            x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p

    def gen_update(self, x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b,
                   x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p, x_a, x_b,
                   xp_a, xp_b, l_a, l_b, hyperparameters, iteration, num_gpu):
        """
        :param x_ab: mixed image [st code of images_a, ap code of images_b]
        :param x_ba: mixed image [st code of images_b, ap code of images_a]
        :param s_a: st code of images_a from Es
        :param s_b: st code of images_b from Es
        :param f_a: ap code of images_a from Ea
        :param f_b: ap code of images_b from Ea
        :param p_a: identity prediction for images_a from Ea
        :param p_b: identity prediction for images_b from Ea
        :param pp_a: identity prediction for pos_a from Ea
        :param pp_b: identity prediction for pos_b from Ea
        :param x_a_recon: images_a reconstructed from its own codes (s_a, f_a)
        :param x_b_recon: images_b reconstructed from its own codes (s_b, f_b)
        :param x_a_recon_p: images_a reconstructed with pos_a's ap code (s_a, fp_a)
        :param x_b_recon_p: images_b reconstructed with pos_b's ap code (s_b, fp_b)
        :param x_a: images_a
        :param x_b: images_b
        :param xp_a: pos_a
        :param xp_b: pos_b
        :param l_a: labels_a
        :param l_b: labels_b
        :param hyperparameters:
        :param iteration:
        :param num_gpu:
        :return:
        """
        # ppa, ppb is the same person?
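        # This update proceeds in three stages: re-encode the synthesized
        # images with frozen copies of Es/Ea to obtain code-consistency
        # targets; distill the identity predictions on synthesized images from
        # the teacher model(s); combine image/feature reconstruction, ID,
        # adversarial and perceptual terms, then step the generator and Ea
        # optimizers.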
        self.gen_opt.zero_grad()  # Clear gradients.
        self.id_opt.zero_grad()
        # no gradient
        # Detached copies of the synthesized images x_ba and x_ab.
        x_ba_copy = Variable(x_ba.data, requires_grad=False)
        x_ab_copy = Variable(x_ab.data, requires_grad=False)
        rand_num = random.uniform(0, 1)
        #################################
        # encode structure
        # enc_content is a ContentEncoder instance.
        if hyperparameters['use_encoder_again'] >= rand_num:
            # encode again (encoder is tuned, input is fixed)
            # Es re-encodes the st codes s_a_recon and s_b_recon; for an ideal
            # model, s_a_recon == s_a and s_b_recon == s_b.
            s_a_recon = self.gen_b.enc_content(self.single(x_ab_copy))
            s_b_recon = self.gen_a.enc_content(self.single(x_ba_copy))
        else:
            # copy the encoder (deep copy: enc_content_copy = gen_a.enc_content)
            self.enc_content_copy = copy.deepcopy(self.gen_a.enc_content)
            self.enc_content_copy = self.enc_content_copy.eval()
            # encode again (encoder is fixed, input is tuned)
            s_a_recon = self.enc_content_copy(self.single(x_ab))
            s_b_recon = self.enc_content_copy(self.single(x_ba))
        #################################
        # encode appearance
        # id_a_copy is a frozen copy of Ea.
        self.id_a_copy = copy.deepcopy(self.id_a)
        self.id_a_copy = self.id_a_copy.eval()
        if hyperparameters['train_bn']:
            self.id_a_copy = self.id_a_copy.apply(train_bn)
        self.id_b_copy = self.id_a_copy
        # encode again (encoder is fixed, input is tuned)
        # Run the mixed images x_ba/x_ab through the frozen Ea:
        # f_a_recon/f_b_recon are ap codes, p_a_recon/p_b_recon are identity
        # predictions.
        f_a_recon, p_a_recon = self.id_a_copy(scale2(x_ba))
        f_b_recon, p_b_recon = self.id_b_copy(scale2(x_ab))
        # teacher Loss
        # Tune the ID model
        log_sm = nn.LogSoftmax(dim=1)
        # If a teacher network is used (ID_style defaults to 'AB').
        if hyperparameters['teacher_w'] > 0 and hyperparameters[
                'teacher'] != "":
            if hyperparameters['ID_style'] == 'normal':
                # p_a_student: identity prediction for x_ba_copy from the
                # student model Ea; log-softmax turns it into a
                # log-probability distribution over the training identities.
                _, p_a_student = self.id_a(scale2(x_ba_copy))
                p_a_student = log_sm(p_a_student)
                # The teacher classifies the generated image x_ba_copy into a
                # probability distribution as well.
                p_a_teacher = predict_label(
                    self.teacher_model,
                    scale2(x_ba_copy),
                    num_class=hyperparameters['ID_class'],
                    alabel=l_a,
                    slabel=l_b,
                    teacher_style=hyperparameters['teacher_style'])
                # Minimizing the KL divergence pulls the student distribution
                # towards the teacher's.
                self.loss_teacher = self.criterion_teacher(
                    p_a_student, p_a_teacher) / p_a_student.size(0)
                # Same for x_ab_copy.
                _, p_b_student = self.id_b(scale2(x_ab_copy))
                p_b_student = log_sm(p_b_student)
                p_b_teacher = predict_label(
                    self.teacher_model,
                    scale2(x_ab_copy),
                    num_class=hyperparameters['ID_class'],
                    alabel=l_b,
                    slabel=l_a,
                    teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher += self.criterion_teacher(
                    p_b_student, p_b_teacher) / p_b_student.size(0)
            ####################################################################
            # primary feature learning loss
            ####################################################################
            elif hyperparameters['ID_style'] == 'AB':
                # normal teacher-student loss
                # BA -> LabelA(smooth) + LabelB(batchB)
                # The synthesized image goes through the identity classifier,
                # which returns two predictions: p_ba_student[0] is the
                # primary branch (L_prim, combined with the teacher below) and
                # p_ba_student[1] is the fine-grained branch (L_fine).
                _, p_ba_student = self.id_a(scale2(x_ba_copy))  # f_a, s_b
                p_a_student = log_sm(p_ba_student[0])
                with torch.no_grad():
                    # The teacher classifies the generated image x_ba_copy
                    # into a probability distribution over x_a/x_b.
                    p_a_teacher = predict_label(
                        self.teacher_model,
                        scale2(x_ba_copy),
                        num_class=hyperparameters['ID_class'],
                        alabel=l_a,
                        slabel=l_b,
                        teacher_style=hyperparameters['teacher_style'])
                # criterion_teacher = nn.KLDivLoss(size_average=False)
                # KLDivLoss(size_average=False) sums the element-wise
                # divergences; dividing by p_a_student.size(0) averages over
                # the batch. The closer the student (Ea) prediction is to the
                # teacher's, the lower the loss.
                self.loss_teacher = self.criterion_teacher(
                    p_a_student, p_a_teacher) / p_a_student.size(0)
                # Same for the other synthesized image.
                _, p_ab_student = self.id_b(scale2(x_ab_copy))  # f_b, s_a
                p_b_student = log_sm(p_ab_student[0])
                with torch.no_grad():
                    p_b_teacher = predict_label(
                        self.teacher_model,
                        scale2(x_ab_copy),
                        num_class=hyperparameters['ID_class'],
                        alabel=l_b,
                        slabel=l_a,
                        teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher += self.criterion_teacher(
                    p_b_student, p_b_teacher) / p_b_student.size(0)
                ################################################################
                # fine-grained feature mining loss
                ################################################################
                # branch b loss
                # here we give different label
                # p_ba_student[1] is the fine-grained prediction; l_b labels
                # images_b, the image that provided the st code of the
                # generated image.
                loss_B = self.id_criterion(p_ba_student[1],
                                           l_b) + self.id_criterion(
                                               p_ab_student[1], l_a)
                # Weight the two parts of the teacher loss.
                self.loss_teacher = hyperparameters[
                    'T_w'] * self.loss_teacher + hyperparameters[
                        'B_w'] * loss_B
        else:
            self.loss_teacher = 0.0
        # The remaining terms are reconstruction losses. Unlike synthesis,
        # reconstruction has a pixel-level ground truth, so reconstructed
        # images can be compared with the originals pixel by pixel; the
        # synthesized images have no ground truth in the training set, so no
        # pixel loss can be computed for them.
        ########################################################################
        # auto-encoder image reconstruction
        # Reconstruction losses for same-ID images.
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_xp_a = self.recon_criterion(x_a_recon_p, x_a)
        self.loss_gen_recon_xp_b = self.recon_criterion(x_b_recon_p, x_b)
        ########################################################################
        # feature reconstruction
        # When synthesizing across IDs, keep the st/ap codes of the
        # synthesized image consistent with the codes that produced it.
        self.loss_gen_recon_s_a = self.recon_criterion(
            s_a_recon, s_a) if hyperparameters['recon_s_w'] > 0 else 0
        self.loss_gen_recon_s_b = self.recon_criterion(
            s_b_recon, s_b) if hyperparameters['recon_s_w'] > 0 else 0
        self.loss_gen_recon_f_a = self.recon_criterion(
            f_a_recon, f_a) if hyperparameters['recon_f_w'] > 0 else 0
        self.loss_gen_recon_f_b = self.recon_criterion(
            f_b_recon, f_b) if hyperparameters['recon_f_w'] > 0 else 0
f_b_recon, f_b) if hyperparameters['recon_f_w'] > 0 else 0 # ####################################################################################################################################################################################################### # 又一次进行图像合成 x_aba = self.gen_a.decode( s_a_recon, f_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None x_bab = self.gen_b.decode( s_b_recon, f_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None # ID loss AND Tune the Generated image if hyperparameters['ID_style'] == 'PCB': self.loss_id = self.PCB_loss(p_a, l_a) + self.PCB_loss(p_b, l_b) self.loss_pid = self.PCB_loss(pp_a, l_a) + self.PCB_loss(pp_b, l_b) self.loss_gen_recon_id = self.PCB_loss( p_a_recon, l_a) + self.PCB_loss(p_b_recon, l_b) ######################################################################################################################################################################################################## # 使用的是 ['ID_style']=='AB' elif hyperparameters['ID_style'] == 'AB': weight_B = hyperparameters['teacher_w'] * hyperparameters['B_w'] #计算的是L^s_id self.loss_id = self.id_criterion(p_a[0], l_a) + self.id_criterion(p_b[0], l_b) \ + weight_B * ( self.id_criterion(p_a[1], l_a) + self.id_criterion(p_b[1], l_b) ) #对同ID不同图片计算L^s_id self.loss_pid = self.id_criterion(pp_a[0], l_a) + self.id_criterion( pp_b[0], l_b ) #+ weight_B * ( self.id_criterion(pp_a[1], l_a) + self.id_criterion(pp_b[1], l_b) ) # 对生成图像计算L^C_id self.loss_gen_recon_id = self.id_criterion( p_a_recon[0], l_a) + self.id_criterion(p_b_recon[0], l_b) ######################################################################################################################################################################################################## else: self.loss_id = self.id_criterion(p_a, l_a) + self.id_criterion( p_b, l_b) self.loss_pid = self.id_criterion(pp_a, l_a) + self.id_criterion( pp_b, l_b) self.loss_gen_recon_id = self.id_criterion( p_a_recon, l_a) + self.id_criterion(p_b_recon, l_b) #print(f_a_recon, f_a) self.loss_gen_cycrecon_x_a = self.recon_criterion( x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0 self.loss_gen_cycrecon_x_b = self.recon_criterion( x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0 ######################################################################################################################################################################################################## # GAN loss #计算生成器G的对抗损失函数 ######################################################################################################################################################################################################## if num_gpu > 1: self.loss_gen_adv_a = self.dis_a.module.calc_gen_loss( self.dis_a, x_ba) self.loss_gen_adv_b = self.dis_b.module.calc_gen_loss( self.dis_b, x_ab) else: self.loss_gen_adv_a = self.dis_a.calc_gen_loss(self.dis_a, x_ba) self.loss_gen_adv_b = self.dis_b.calc_gen_loss(self.dis_b, x_ab) ######################################################################################################################################################################################################## # domain-invariant perceptual loss self.loss_gen_vgg_a = self.compute_vgg_loss( self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0 self.loss_gen_vgg_b = self.compute_vgg_loss( self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0 if iteration > hyperparameters['warm_iter']: hyperparameters['recon_f_w'] += 
            hyperparameters['recon_f_w'] = min(hyperparameters['recon_f_w'],
                                               hyperparameters['max_w'])
            hyperparameters['recon_s_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_s_w'] = min(hyperparameters['recon_s_w'],
                                               hyperparameters['max_w'])
            hyperparameters['recon_x_cyc_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_x_cyc_w'] = min(
                hyperparameters['recon_x_cyc_w'], hyperparameters['max_cyc_w'])

        if iteration > hyperparameters['warm_teacher_iter']:
            hyperparameters['teacher_w'] += hyperparameters['warm_scale']
            hyperparameters['teacher_w'] = min(
                hyperparameters['teacher_w'], hyperparameters['max_teacher_w'])

        # total loss:
        # 1 teacher loss + 4 same-ID image reconstruction losses + 4 cross-ID
        # synthesis losses + 3 ID losses + 2 generator adversarial losses.
        # The teacher loss covers both the primary feature learning loss and
        # the fine-grained feature mining loss.
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
            hyperparameters['gan_w'] * self.loss_gen_adv_b + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_a + \
            hyperparameters['recon_f_w'] * self.loss_gen_recon_f_a + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_b + \
            hyperparameters['recon_f_w'] * self.loss_gen_recon_f_b + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
            hyperparameters['id_w'] * self.loss_id + \
            hyperparameters['pid_w'] * self.loss_pid + \
            hyperparameters['recon_id_w'] * self.loss_gen_recon_id + \
            hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
            hyperparameters['vgg_w'] * self.loss_gen_vgg_b + \
            hyperparameters['teacher_w'] * self.loss_teacher

        if self.fp16:
            with amp.scale_loss(self.loss_gen_total,
                                [self.gen_opt, self.id_opt]) as scaled_loss:
                scaled_loss.backward()
            self.gen_opt.step()
            self.id_opt.step()
        else:
            self.loss_gen_total.backward()  # compute the gradients
            self.gen_opt.step()  # update the generator weights
            self.id_opt.step()  # update the ID-encoder weights

        print(
            "L_total: %.4f, L_gan: %.4f, Lx: %.4f, Lxp: %.4f, Lrecycle:%.4f, "
            "Lf: %.4f, Ls: %.4f, Recon-id: %.4f, id: %.4f, pid:%.4f, teacher: %.4f" % (
                self.loss_gen_total,
                hyperparameters['gan_w'] * (self.loss_gen_adv_a + self.loss_gen_adv_b),
                hyperparameters['recon_x_w'] * (self.loss_gen_recon_x_a + self.loss_gen_recon_x_b),
                hyperparameters['recon_xp_w'] * (self.loss_gen_recon_xp_a + self.loss_gen_recon_xp_b),
                hyperparameters['recon_x_cyc_w'] * (self.loss_gen_cycrecon_x_a + self.loss_gen_cycrecon_x_b),
                hyperparameters['recon_f_w'] * (self.loss_gen_recon_f_a + self.loss_gen_recon_f_b),
                hyperparameters['recon_s_w'] * (self.loss_gen_recon_s_a + self.loss_gen_recon_s_b),
                hyperparameters['recon_id_w'] * self.loss_gen_recon_id,
                hyperparameters['id_w'] * self.loss_id,
                hyperparameters['pid_w'] * self.loss_pid,
                hyperparameters['teacher_w'] * self.loss_teacher))

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def PCB_loss(self, inputs, labels):
        loss = 0.0
        for part in inputs:
            loss += self.id_criterion(part, labels)
        return loss / len(inputs)

    def sample(self, x_a, x_b):
        self.eval()
        x_a_recon, x_b_recon, x_ba1, x_ab1, x_aba, x_bab = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            s_a = self.gen_a.encode(self.single(x_a[i].unsqueeze(0)))
            s_b = self.gen_b.encode(self.single(x_b[i].unsqueeze(0)))
            f_a, _ = self.id_a(scale2(x_a[i].unsqueeze(0)))
            f_b, _ = self.id_b(scale2(x_b[i].unsqueeze(0)))
            x_a_recon.append(self.gen_a.decode(s_a, f_a))
            x_b_recon.append(self.gen_b.decode(s_b, f_b))
            x_ba = self.gen_a.decode(s_b, f_a)
            x_ab = self.gen_b.decode(s_a, f_b)
            x_ba1.append(x_ba)
            x_ab1.append(x_ab)
            # cycle
            s_b_recon = self.gen_a.enc_content(self.single(x_ba))
            s_a_recon = self.gen_b.enc_content(self.single(x_ab))
            f_a_recon, _ = self.id_a(scale2(x_ba))
            f_b_recon, _ = self.id_b(scale2(x_ab))
            x_aba.append(self.gen_a.decode(s_a_recon, f_a_recon))
            x_bab.append(self.gen_b.decode(s_b_recon, f_b_recon))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_aba, x_bab = torch.cat(x_aba), torch.cat(x_bab)
        x_ba1, x_ab1 = torch.cat(x_ba1), torch.cat(x_ab1)
        self.train()
        return x_a, x_a_recon, x_aba, x_ab1, x_b, x_b_recon, x_bab, x_ba1

    def dis_update(self, x_ab, x_ba, x_a, x_b, hyperparameters, num_gpu):
        self.dis_opt.zero_grad()  # clear the gradients
        # D loss: compute the discriminator loss, then back-propagate and
        # update. The inputs are the two image pairs (x_ba, x_a) and
        # (x_ab, x_b); the loss is summed over both pairs.
        if num_gpu > 1:
            self.loss_dis_a, reg_a = self.dis_a.module.calc_dis_loss(
                self.dis_a, x_ba.detach(), x_a)
            self.loss_dis_b, reg_b = self.dis_b.module.calc_dis_loss(
                self.dis_b, x_ab.detach(), x_b)
        else:
            self.loss_dis_a, reg_a = self.dis_a.calc_dis_loss(
                self.dis_a, x_ba.detach(), x_a)
            self.loss_dis_b, reg_b = self.dis_b.calc_dis_loss(
                self.dis_b, x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters[
            'gan_w'] * self.loss_dis_a + hyperparameters[
                'gan_w'] * self.loss_dis_b
        print("DLoss: %.4f" % self.loss_dis_total,
              "Reg: %.4f" % (reg_a + reg_b))
        if self.fp16:
            with amp.scale_loss(self.loss_dis_total,
                                self.dis_opt) as scaled_loss:
                scaled_loss.backward()
        else:
            self.loss_dis_total.backward()  # compute the gradients
        self.dis_opt.step()  # update the discriminator weights

    def update_learning_rate(self):
        # step the learning-rate schedulers
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()
        if self.id_scheduler is not None:
            self.id_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b = self.gen_a
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b = self.dis_a
        # Load ID dis
        last_model_name = get_model_list(checkpoint_dir, "id")
        state_dict = torch.load(last_model_name)
        self.id_a.load_state_dict(state_dict['a'])
        self.id_b = self.id_a
        # Load optimizers
        try:
            state_dict = torch.load(
                os.path.join(checkpoint_dir, 'optimizer.pt'))
            self.dis_opt.load_state_dict(state_dict['dis'])
            self.gen_opt.load_state_dict(state_dict['gen'])
            self.id_opt.load_state_dict(state_dict['id'])
        except:
            pass
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations, num_gpu=1):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        id_name = os.path.join(snapshot_dir, 'id_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict()}, gen_name)
        if num_gpu > 1:
            torch.save({'a': self.dis_a.module.state_dict()}, dis_name)
        else:
            torch.save({'a': self.dis_a.state_dict()}, dis_name)
        torch.save({'a': self.id_a.state_dict()}, id_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'id': self.id_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
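
# Illustrative sketch (not part of the original file): how the update methods
# above are typically driven from a training loop. The config dict, loaders,
# and checkpoint directory are hypothetical arguments; the gen_update call is
# left as a comment because its full signature is defined before this excerpt
# (images, identity labels, hyperparameters, iteration counter, num_gpu).
def _example_training_loop(trainer, loader_a, loader_b, config, checkpoint_dir):
    for iteration, ((x_a, l_a), (x_b, l_b)) in enumerate(zip(loader_a, loader_b)):
        x_a, l_a = x_a.cuda(), l_a.cuda()
        x_b, l_b = x_b.cuda(), l_b.cuda()
        x_ab, x_ba = trainer.forward(x_a, x_b)  # cross-domain synthesis
        trainer.dis_update(x_ab, x_ba, x_a, x_b, config, num_gpu=1)
        # ... call trainer.gen_update(...) here with the arguments listed above
        trainer.update_learning_rate()
        if (iteration + 1) % config['snapshot_save_iter'] == 0:
            trainer.save(checkpoint_dir, iteration, num_gpu=1)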
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters["lr"]
        self.gen_state = hyperparameters["gen_state"]
        self.guided = hyperparameters["guided"]
        self.newsize = hyperparameters["crop_image_height"]
        self.semantic_w = hyperparameters["semantic_w"] > 0
        self.recon_mask = hyperparameters["recon_mask"] == 1
        self.check_alignment = hyperparameters["check_alignment"] == 1
        self.full_adaptation = hyperparameters["full_adaptation"] == 1
        self.dann_scheduler = None
        if "domain_adv_w" in hyperparameters.keys():
            self.domain_classif = hyperparameters["domain_adv_w"] > 0
        else:
            self.domain_classif = False
        if self.gen_state == 0:
            # Initiate the networks
            self.gen_a = AdaINGen(
                hyperparameters["input_dim_a"],
                hyperparameters["gen"])  # auto-encoder for domain a
            self.gen_b = AdaINGen(
                hyperparameters["input_dim_b"],
                hyperparameters["gen"])  # auto-encoder for domain b
        elif self.gen_state == 1:
            self.gen = AdaINGen_double(hyperparameters["input_dim_a"],
                                       hyperparameters["gen"])
        else:
            print("self.gen_state unknown value:", self.gen_state)
        self.dis_a = MsImageDis(
            hyperparameters["input_dim_a"],
            hyperparameters["dis"])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters["input_dim_b"],
            hyperparameters["dis"])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters["gen"]["style_dim"]
        # fix the noise used in sampling
        display_size = int(hyperparameters["display_size"])
        print(self.style_dim)  # debug output
        print(display_size)  # debug output
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        # Setup the optimizers
        beta1 = hyperparameters["beta1"]
        beta2 = hyperparameters["beta2"]
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        if self.gen_state == 0:
            gen_params = list(self.gen_a.parameters()) + list(
                self.gen_b.parameters())
        elif self.gen_state == 1:
            gen_params = list(self.gen.parameters())
        else:
            print("self.gen_state unknown value:", self.gen_state)
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters["weight_decay"],
        )
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters["weight_decay"],
        )
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        # Network weight initialization
        self.apply(weights_init(hyperparameters["init"]))
        self.dis_a.apply(weights_init("gaussian"))
        self.dis_b.apply(weights_init("gaussian"))
        # Load VGG model if needed
        if "vgg_w" in hyperparameters.keys() and hyperparameters["vgg_w"] > 0:
            self.vgg = load_vgg16(hyperparameters["vgg_model_path"] +
                                  "/models")
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False
        # Load semantic segmentation model if needed
        if "semantic_w" in hyperparameters.keys(
        ) and hyperparameters["semantic_w"] > 0:
            self.segmentation_model = load_segmentation_model(
                hyperparameters["semantic_ckpt_path"])
            self.segmentation_model.eval()
            for param in self.segmentation_model.parameters():
                param.requires_grad = False
        # Load domain classifier if needed
        if ("domain_adv_w" in hyperparameters.keys()
                and hyperparameters["domain_adv_w"] > 0):
            self.domain_classifier = domainClassifier(256)
            dann_params = list(self.domain_classifier.parameters())
            self.dann_opt = torch.optim.Adam(
                [p for p in dann_params if p.requires_grad],
                lr=lr,
                betas=(beta1, beta2),
                weight_decay=hyperparameters["weight_decay"],
            )
            self.domain_classifier.apply(weights_init("gaussian"))
            self.dann_scheduler = get_scheduler(self.dann_opt,
                                                hyperparameters)

    def recon_criterion(self, input, target):
        """
        Compute the pixelwise L1 loss between two images input and target

        Arguments:
            input {torch.Tensor} -- Image tensor (original image such as x_a)
            target {torch.Tensor} -- Image tensor (after cycle-translation, image x_aba)

        Returns:
            torch.Float -- pixelwise L1 loss
        """
        return torch.mean(torch.abs(input - target))

    def recon_criterion_mask(self, input, target, mask):
        """
        Compute a weaker version of recon_criterion between two images input
        and target, where the L1 loss is only computed on the unmasked region

        Arguments:
            input {torch.Tensor} -- Image (original image such as x_a)
            target {torch.Tensor} -- Image (after cycle-translation, image x_aba)
            mask {torch.Tensor} -- binary mask of size HxW (input.shape ~ CxHxW)

        Returns:
            torch.Float -- L1 loss over input.(1-mask) and target.(1-mask)
        """
        return torch.mean(torch.abs(torch.mul((input - target), 1 - mask)))

    def forward(self, x_a, x_b):
        """
        Perform the translation from domain A (resp. B) to domain B (resp. A):
        x_a to x_ab (resp. x_b to x_ba).

        Arguments:
            x_a {torch.Tensor} -- Image from domain A after transform, in tensor format
            x_b {torch.Tensor} -- Image from domain B after transform, in tensor format

        Returns:
            torch.Tensor, torch.Tensor -- Translated version of x_a in domain B,
            translated version of x_b in domain A
        """
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        if self.gen_state == 0:
            c_a, s_a_fake = self.gen_a.encode(x_a)
            c_b, s_b_fake = self.gen_b.encode(x_b)
            x_ba = self.gen_a.decode(c_b, s_a)
            x_ab = self.gen_b.decode(c_a, s_b)
        elif self.gen_state == 1:
            c_a, s_a_fake = self.gen.encode(x_a, 1)
            c_b, s_b_fake = self.gen.encode(x_b, 2)
            x_ba = self.gen.decode(c_b, s_a, 1)
            x_ab = self.gen.decode(c_a, s_b, 2)
        else:
            print("self.gen_state unknown value:", self.gen_state)
        self.train()
        return x_ab, x_ba

    def gen_update(self,
                   x_a,
                   x_b,
                   hyperparameters,
                   mask_a=None,
                   mask_b=None,
                   comet_exp=None,
                   synth=0):
        """
        Update the generator parameters

        Arguments:
            x_a {torch.Tensor} -- Image from domain A after transform, in tensor format
            x_b {torch.Tensor} -- Image from domain B after transform, in tensor format
            hyperparameters {dictionary} -- dictionary with all hyperparameters

        Keyword Arguments:
            mask_a {torch.Tensor} -- binary mask (0, 1) corresponding to the ground in x_a (default: {None})
            mask_b {torch.Tensor} -- binary mask (0, 1) corresponding to the water in x_b (default: {None})
            comet_exp {comet_ml.Experiment} -- Comet ML object used to log all the losses and images (default: {None})
            synth {int} -- binary flag (0 or 1) stating whether we have a synthetic pair (default: {0})
        """
        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        if self.gen_state == 0:
            # encode
            c_a, s_a_prime = self.gen_a.encode(x_a)
            c_b, s_b_prime = self.gen_b.encode(x_b)
            # decode (within domain)
            x_a_recon = self.gen_a.decode(c_a, s_a_prime)
            x_b_recon = self.gen_b.decode(c_b, s_b_prime)
            # decode (cross domain)
            if self.guided == 0:
                x_ba = self.gen_a.decode(c_b, s_a)
                x_ab = self.gen_b.decode(c_a, s_b)
            elif self.guided == 1:
                x_ba = self.gen_a.decode(c_b, s_a_prime)
                x_ab = self.gen_b.decode(c_a, s_b_prime)
            else:
                print("self.guided unknown value:", self.guided)
            # encode again
            c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
            c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
            # decode again (if needed)
            x_aba = (self.gen_a.decode(c_a_recon, s_a_prime)
                     if hyperparameters["recon_x_cyc_w"] > 0 else None)
            x_bab = (self.gen_b.decode(c_b_recon, s_b_prime)
                     if hyperparameters["recon_x_cyc_w"] > 0 else None)
        elif self.gen_state == 1:
            # encode
            c_a, s_a_prime = self.gen.encode(x_a, 1)
            print(c_a.shape)  # debug output
            c_b, s_b_prime = self.gen.encode(x_b, 2)
            # decode (within domain)
            x_a_recon = self.gen.decode(c_a, s_a_prime, 1)
            x_b_recon = self.gen.decode(c_b, s_b_prime, 2)
            # decode (cross domain)
            if self.guided == 0:
                x_ba = self.gen.decode(c_b, s_a, 1)
                x_ab = self.gen.decode(c_a, s_b, 2)
            elif self.guided == 1:
                x_ba = self.gen.decode(c_b, s_a_prime, 1)
                x_ab = self.gen.decode(c_a, s_b_prime, 2)
            else:
                print("self.guided unknown value:", self.guided)
            # encode again
            c_b_recon, s_a_recon = self.gen.encode(x_ba, 1)
            c_a_recon, s_b_recon = self.gen.encode(x_ab, 2)
            # decode again (if needed)
            x_aba = (self.gen.decode(c_a_recon, s_a_prime, 1)
                     if hyperparameters["recon_x_cyc_w"] > 0 else None)
            x_bab = (self.gen.decode(c_b_recon, s_b_prime, 2)
                     if hyperparameters["recon_x_cyc_w"] > 0 else None)
        else:
            print("self.gen_state unknown value:", self.gen_state)

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        if self.guided == 0:
            self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
            self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        elif self.guided == 1:
            self.loss_gen_recon_s_a = self.recon_criterion(
                s_a_recon, s_a_prime)
            self.loss_gen_recon_s_b = self.recon_criterion(
                s_b_recon, s_b_prime)
        else:
            print("self.guided unknown value:", self.guided)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)

        # Synthetic reconstruction loss
        if self.check_alignment:
            print('mask_b.shape', mask_b.shape)  # debug output
            # Define the mask of pixels that are exactly identical within a pair
            mask_alignment = (torch.sum(torch.abs(x_a - x_b),
                                        1) == 0).unsqueeze(1)
            mask_alignment = mask_alignment.type(torch.cuda.FloatTensor)
            # print('mask_alignment.shape', mask_alignment.shape)
            self.loss_gen_recon_synth = \
                self.recon_criterion_mask(x_ab, x_b, 1 - mask_alignment) + \
                self.recon_criterion_mask(x_ba, x_a, 1 - mask_alignment)
        else:
            # keep the term defined so the total loss below can always use it
            self.loss_gen_recon_synth = 0

        if self.recon_mask:
            self.loss_gen_cycrecon_x_a = (self.recon_criterion_mask(
                x_aba, x_a, mask_a)
                if hyperparameters["recon_x_cyc_w"] > 0 else 0)
            self.loss_gen_cycrecon_x_b = (self.recon_criterion_mask(
                x_bab, x_b, mask_b)
                if hyperparameters["recon_x_cyc_w"] > 0 else 0)
        else:
            self.loss_gen_cycrecon_x_a = (self.recon_criterion(
                x_aba, x_a) if hyperparameters["recon_x_cyc_w"] > 0 else 0)
            self.loss_gen_cycrecon_x_b = (self.recon_criterion(
                x_bab, x_b) if hyperparameters["recon_x_cyc_w"] > 0 else 0)

        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = (self.compute_vgg_loss(self.vgg, x_ba, x_b)
                               if hyperparameters["vgg_w"] > 0 else 0)
        self.loss_gen_vgg_b = (self.compute_vgg_loss(self.vgg, x_ab, x_a)
                               if hyperparameters["vgg_w"] > 0 else 0)
        # semantic-segmentation loss
        self.loss_sem_seg = (self.compute_semantic_seg_loss(
            x_a.squeeze(), x_ab.squeeze(), mask_a) +
            self.compute_semantic_seg_loss(
                x_b.squeeze(), x_ba.squeeze(), mask_b)
            if hyperparameters["semantic_w"] > 0 else 0)
        # Domain adversarial loss: c_a and c_b are
        # swapped because we want the content feature to be uninformative
        # about the domain (min-max: this term pushes the classification loss
        # up, while the classifier update pushes it down)
        self.domain_adv_loss = (self.compute_domain_adv_loss(
            c_a, c_b, compute_accuracy=False, minimize=False)
            if hyperparameters["domain_adv_w"] > 0 else 0)

        # total loss
        self.loss_gen_total = (
            hyperparameters["gan_w"] * self.loss_gen_adv_a +
            hyperparameters["gan_w"] * self.loss_gen_adv_b +
            hyperparameters["recon_x_w"] * self.loss_gen_recon_x_a +
            hyperparameters["recon_s_w"] * self.loss_gen_recon_s_a +
            hyperparameters["recon_c_w"] * self.loss_gen_recon_c_a +
            hyperparameters["recon_x_w"] * self.loss_gen_recon_x_b +
            hyperparameters["recon_s_w"] * self.loss_gen_recon_s_b +
            hyperparameters["recon_c_w"] * self.loss_gen_recon_c_b +
            hyperparameters["recon_x_cyc_w"] * self.loss_gen_cycrecon_x_a +
            hyperparameters["recon_x_cyc_w"] * self.loss_gen_cycrecon_x_b +
            hyperparameters["vgg_w"] * self.loss_gen_vgg_a +
            hyperparameters["vgg_w"] * self.loss_gen_vgg_b +
            hyperparameters["semantic_w"] * self.loss_sem_seg +
            hyperparameters["domain_adv_w"] * self.domain_adv_loss +
            hyperparameters["recon_synth_w"] * self.loss_gen_recon_synth)
        self.loss_gen_total.backward()
        self.gen_opt.step()

        if comet_exp is not None:
            comet_exp.log_metric("loss_gen_adv_a",
                                 self.loss_gen_adv_a.cpu().detach())
            comet_exp.log_metric("loss_gen_adv_b",
                                 self.loss_gen_adv_b.cpu().detach())
            comet_exp.log_metric("loss_gen_recon_x_a",
                                 self.loss_gen_recon_x_a.cpu().detach())
            comet_exp.log_metric("loss_gen_recon_s_a",
                                 self.loss_gen_recon_s_a.cpu().detach())
            comet_exp.log_metric("loss_gen_recon_c_a",
                                 self.loss_gen_recon_c_a.cpu().detach())
            comet_exp.log_metric("loss_gen_recon_x_b",
                                 self.loss_gen_recon_x_b.cpu().detach())
            comet_exp.log_metric("loss_gen_recon_s_b",
                                 self.loss_gen_recon_s_b.cpu().detach())
            comet_exp.log_metric("loss_gen_recon_c_b",
                                 self.loss_gen_recon_c_b.cpu().detach())
            comet_exp.log_metric("loss_gen_cycrecon_x_a",
                                 self.loss_gen_cycrecon_x_a.cpu().detach())
            comet_exp.log_metric("loss_gen_cycrecon_x_b",
                                 self.loss_gen_cycrecon_x_b.cpu().detach())
            comet_exp.log_metric("loss_gen_total",
                                 self.loss_gen_total.cpu().detach())
            if hyperparameters["vgg_w"] > 0:
                comet_exp.log_metric("loss_gen_vgg_a",
                                     self.loss_gen_vgg_a.cpu().detach())
                comet_exp.log_metric("loss_gen_vgg_b",
                                     self.loss_gen_vgg_b.cpu().detach())
            if hyperparameters["semantic_w"] > 0:
                comet_exp.log_metric("loss_sem_seg",
                                     self.loss_sem_seg.cpu().detach())
            if hyperparameters["domain_adv_w"] > 0:
                comet_exp.log_metric("domain_adv_loss_gen",
                                     self.domain_adv_loss.cpu().detach())
            if synth == 0 and self.check_alignment:
                comet_exp.log_metric("loss_gen_recon_synth",
                                     self.loss_gen_recon_synth.cpu().detach())

    def compute_vgg_loss(self, vgg, img, target):
        """
        Compute the domain-invariant perceptual loss

        Arguments:
            vgg {model} -- pretrained VGG network used as a feature extractor
            img {torch.Tensor} -- image before translation
            target {torch.Tensor} -- image after translation

        Returns:
            torch.Float -- domain-invariant perceptual loss
        """
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def compute_domain_adv_loss(self,
                                c_a,
                                c_b,
                                compute_accuracy=False,
                                minimize=True):
        """
        Compute a domain adversarial loss on the embedding of the classifier:
        we are trying to learn an anonymized representation of the content.
        Arguments:
            c_a {torch.Tensor} -- content extracted from an image of domain A with encoder A
            c_b {torch.Tensor} -- content extracted from an image of domain B with encoder B

        Keyword Arguments:
            compute_accuracy {bool} -- return only the loss, or the loss and the softmax probabilities (default: {False})
            minimize {bool} -- optimize classification accuracy (True) or anonymize the representation (False)

        Returns:
            torch.Float -- loss (optionally the softmax P(classifier(c_a)=a) and P(classifier(c_b)=b))
        """
        # Infer the domain classifier on content extracted from an image of domain A
        output_a = self.domain_classifier(c_a)
        # Infer the domain classifier on content extracted from an image of domain B
        output_b = self.domain_classifier(c_b)
        # Concatenate the outputs in a single vector
        output = torch.cat((output_a, output_b))
        if minimize:
            target = torch.tensor([1., 0., 0., 1.], device='cuda')
        else:
            target = torch.tensor([0.5, 0.5, 0.5, 0.5], device='cuda')
        # mean squared error loss
        loss = torch.nn.MSELoss()(output, target)
        if compute_accuracy:
            return loss, output_a[0], output_b[1]
        else:
            return loss

    def compute_semantic_seg_loss(self, img1, img2, mask=None):
        """
        Compute the semantic segmentation loss between two images on the
        unmasked region, or over the entire image

        Arguments:
            img1 {torch.Tensor} -- Image from domain A after transform, in tensor format
            img2 {torch.Tensor} -- Transformed image
            mask {torch.Tensor} -- binary mask where we force the loss to be zero

        Returns:
            torch.Float -- cross-entropy loss on the unmasked region
        """
        # denormalize
        img1_denorm = (img1 + 1) / 2.0
        img2_denorm = (img2 + 1) / 2.0
        # normalize for the semantic segmentation network
        input_transformed1 = seg_batch_transform(img1_denorm)
        input_transformed2 = seg_batch_transform(img2_denorm)
        # compute labels from the original image and logits from the translated version
        target = self.segmentation_model(input_transformed1).max(1)[1]
        output = self.segmentation_model(input_transformed2)
        if not self.full_adaptation and mask is not None:
            # Resize the mask to the size of the image
            mask1 = torch.nn.functional.interpolate(mask,
                                                    size=(self.newsize,
                                                          self.newsize))
            mask1_tensor = torch.tensor(mask1, dtype=torch.long).cuda()
            mask1_tensor = mask1_tensor.squeeze(1)
            # we want the masked region to be labeled as unknown
            # (19 is not an existing label)
            target_with_mask = torch.mul(1 - mask1_tensor,
                                         target) + mask1_tensor * 19
            mask2 = torch.nn.functional.interpolate(mask,
                                                    size=(self.newsize,
                                                          self.newsize))
            mask_tensor = torch.tensor(mask2, dtype=torch.float).cuda()
            output_with_mask = torch.mul(1 - mask_tensor, output)
            # concatenate the mask to the logits (loss = 0 over the masked region)
            output_with_mask_cat = torch.cat((output_with_mask, mask_tensor),
                                             dim=1)
            loss = nn.CrossEntropyLoss()(output_with_mask_cat,
                                         target_with_mask)
        else:
            loss = nn.CrossEntropyLoss()(output, target)
        return loss

    def sample(self, x_a, x_b):
        """
        Run inference on a batch of images

        Arguments:
            x_a {torch.Tensor} -- batch of images from domain A
            x_b {torch.Tensor} -- batch of images from domain B

        Returns:
            A list of torch images -- columnwise: x_a, autoencode(x_a), x_ab_1, x_ab_2;
            or, if self.semantic_w is true: x_a, autoencode(x_a), semantic
            segmentation of x_a, x_ab_1, semantic segmentation of x_ab_1, x_ab_2
        """
        self.eval()
        s_a1 = Variable(self.s_a)
        s_b1 = Variable(self.s_b)
        s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        if self.gen_state == 0:
            for i in range(x_a.size(0)):
                c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
                c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
                x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
                x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
                if self.guided == 0:
                    x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
                    x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
                    x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
                    x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
                elif self.guided == 1:
                    x_ba1.append(self.gen_a.decode(
                        c_b, s_a_fake))  # s_a1[i].unsqueeze(0)))
                    x_ba2.append(self.gen_a.decode(
                        c_b, s_a_fake))  # s_a2[i].unsqueeze(0)))
                    x_ab1.append(self.gen_b.decode(
                        c_a, s_b_fake))  # s_b1[i].unsqueeze(0)))
                    x_ab2.append(self.gen_b.decode(
                        c_a, s_b_fake))  # s_b2[i].unsqueeze(0)))
                else:
                    print("self.guided unknown value:", self.guided)
        elif self.gen_state == 1:
            for i in range(x_a.size(0)):
                c_a, s_a_fake = self.gen.encode(x_a[i].unsqueeze(0), 1)
                c_b, s_b_fake = self.gen.encode(x_b[i].unsqueeze(0), 2)
                x_a_recon.append(self.gen.decode(c_a, s_a_fake, 1))
                x_b_recon.append(self.gen.decode(c_b, s_b_fake, 2))
                if self.guided == 0:
                    x_ba1.append(
                        self.gen.decode(c_b, s_a1[i].unsqueeze(0), 1))
                    x_ba2.append(
                        self.gen.decode(c_b, s_a2[i].unsqueeze(0), 1))
                    x_ab1.append(
                        self.gen.decode(c_a, s_b1[i].unsqueeze(0), 2))
                    x_ab2.append(
                        self.gen.decode(c_a, s_b2[i].unsqueeze(0), 2))
                elif self.guided == 1:
                    x_ba1.append(self.gen.decode(
                        c_b, s_a_fake, 1))  # s_a1[i].unsqueeze(0)))
                    x_ba2.append(self.gen.decode(
                        c_b, s_a_fake, 1))  # s_a2[i].unsqueeze(0)))
                    x_ab1.append(self.gen.decode(
                        c_a, s_b_fake, 2))  # s_b1[i].unsqueeze(0)))
                    x_ab2.append(self.gen.decode(
                        c_a, s_b_fake, 2))  # s_b2[i].unsqueeze(0)))
                else:
                    print("self.guided unknown value:", self.guided)
        else:
            print("self.gen_state unknown value:", self.gen_state)
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        if self.semantic_w:
            rgb_a_list, rgb_b_list, rgb_ab_list, rgb_ba_list = [], [], [], []
            for i in range(x_a.size(0)):
                # Run semantic segmentation on the original images
                im_a = (x_a[i].squeeze() + 1) / 2.0
                im_b = (x_b[i].squeeze() + 1) / 2.0
                input_transformed_a = seg_transform()(im_a).unsqueeze(0)
                input_transformed_b = seg_transform()(im_b).unsqueeze(0)
                output_a = (self.segmentation_model(
                    input_transformed_a).squeeze().max(0)[1])
                output_b = (self.segmentation_model(
                    input_transformed_b).squeeze().max(0)[1])
                rgb_a = decode_segmap(output_a.cpu().numpy())
                rgb_b = decode_segmap(output_b.cpu().numpy())
                rgb_a = Image.fromarray(rgb_a).resize(
                    (x_a.size(3), x_a.size(3)))
                rgb_b = Image.fromarray(rgb_b).resize(
                    (x_a.size(3), x_a.size(3)))
                rgb_a_list.append(transforms.ToTensor()(rgb_a).unsqueeze(0))
                rgb_b_list.append(transforms.ToTensor()(rgb_b).unsqueeze(0))
                # Run semantic segmentation on the fake images
                image_ab = (x_ab1[i].squeeze() + 1) / 2.0
                image_ba = (x_ba1[i].squeeze() + 1) / 2.0
                input_transformed_ab = seg_transform()(image_ab).unsqueeze(
                    0).to("cuda")
                input_transformed_ba = seg_transform()(image_ba).unsqueeze(
                    0).to("cuda")
                output_ab = (self.segmentation_model(
                    input_transformed_ab).squeeze().max(0)[1])
                output_ba = (self.segmentation_model(
                    input_transformed_ba).squeeze().max(0)[1])
                rgb_ab = decode_segmap(output_ab.cpu().numpy())
                rgb_ba = decode_segmap(output_ba.cpu().numpy())
                rgb_ab = Image.fromarray(rgb_ab).resize(
                    (x_a.size(3), x_a.size(3)))
                rgb_ba = Image.fromarray(rgb_ba).resize(
                    (x_a.size(3), x_a.size(3)))
                rgb_ab_list.append(transforms.ToTensor()(rgb_ab).unsqueeze(0))
                rgb_ba_list.append(transforms.ToTensor()(rgb_ba).unsqueeze(0))
            rgb1_a, rgb1_b, rgb1_ab, rgb1_ba = (
                torch.cat(rgb_a_list).cuda(),
                torch.cat(rgb_b_list).cuda(),
                torch.cat(rgb_ab_list).cuda(),
                torch.cat(rgb_ba_list).cuda(),
            )
        self.train()
        if self.semantic_w:
            self.segmentation_model.eval()
            return (
                x_a,
                x_a_recon,
                rgb1_a,
                x_ab1,
                rgb1_ab,
                x_ab2,
                x_b,
                x_b_recon,
                rgb1_b,
                x_ba1,
                rgb1_ba,
                x_ba2,
            )
        else:
            return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def sample_fid(self, x_a, x_b):
        """
        Run inference on a batch of images (guided translation only), e.g. to
        compute FID statistics

        Arguments:
            x_a {torch.Tensor} -- batch of images from domain A
            x_b {torch.Tensor} -- batch of images from domain B

        Returns:
            torch.Tensor -- guided translations x_ab_1 of the batch
        """
        self.eval()
        x_ab1 = []
        if self.gen_state == 0:
            for i in range(x_a.size(0)):
                c_a, _ = self.gen_a.encode(x_a[i].unsqueeze(0))
                _, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
                if self.guided == 1:
                    x_ab1.append(self.gen_b.decode(c_a, s_b_fake))
                else:
                    print("self.guided unknown value:", self.guided)
        elif self.gen_state == 1:
            for i in range(x_a.size(0)):
                c_a, _ = self.gen.encode(x_a[i].unsqueeze(0), 1)
                _, s_b_fake = self.gen.encode(x_b[i].unsqueeze(0), 2)
                if self.guided == 1:
                    x_ab1.append(self.gen.decode(c_a, s_b_fake, 2))
                else:
                    print("self.guided unknown value:", self.guided)
        else:
            print("self.gen_state unknown value:", self.gen_state)
        x_ab1 = torch.cat(x_ab1)
        self.train()
        if self.semantic_w:
            self.segmentation_model.eval()
        return x_ab1

    def dis_update(self, x_a, x_b, hyperparameters, comet_exp=None):
        """
        Update the weights of the discriminator

        Arguments:
            x_a {torch.Tensor} -- Image from domain A after transform, in tensor format
            x_b {torch.Tensor} -- Image from domain B after transform, in tensor format
            hyperparameters {dictionary} -- dictionary with all hyperparameters

        Keyword Arguments:
            comet_exp {comet_ml.Experiment} -- Comet ML object used to log all the losses and images (default: {None})
        """
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        if self.gen_state == 0:
            # encode
            c_a, s_a_prime = self.gen_a.encode(x_a)
            c_b, s_b_prime = self.gen_b.encode(x_b)
            # decode (cross domain)
            if self.guided == 0:
                x_ba = self.gen_a.decode(c_b, s_a)
                x_ab = self.gen_b.decode(c_a, s_b)
            elif self.guided == 1:
                x_ba = self.gen_a.decode(c_b, s_a_prime)
                x_ab = self.gen_b.decode(c_a, s_b_prime)
            else:
                print("self.guided unknown value:", self.guided)
        elif self.gen_state == 1:
            # encode
            c_a, s_a_prime = self.gen.encode(x_a, 1)
            c_b, s_b_prime = self.gen.encode(x_b, 2)
            # decode (cross domain)
            if self.guided == 0:
                x_ba = self.gen.decode(c_b, s_a, 1)
                x_ab = self.gen.decode(c_a, s_b, 2)
            elif self.guided == 1:
                x_ba = self.gen.decode(c_b, s_a_prime, 1)
                x_ab = self.gen.decode(c_a, s_b_prime, 2)
            else:
                print("self.guided unknown value:", self.guided)
        else:
            print("self.gen_state unknown value:", self.gen_state)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = (hyperparameters["gan_w"] * self.loss_dis_a +
                               hyperparameters["gan_w"] * self.loss_dis_b)
        self.loss_dis_total.backward()
        self.dis_opt.step()
        if comet_exp is not None:
            comet_exp.log_metric("loss_dis_b", self.loss_dis_b.cpu().detach())
            comet_exp.log_metric("loss_dis_a", self.loss_dis_a.cpu().detach())
    def domain_classifier_update(self, x_a, x_b, hyperparameters,
                                 comet_exp=None):
        """
        Update the weights of the domain classifier

        Arguments:
            x_a {torch.Tensor} -- Image from domain A after transform, in tensor format
            x_b {torch.Tensor} -- Image from domain B after transform, in tensor format
            hyperparameters {dictionary} -- dictionary with all hyperparameters

        Keyword Arguments:
            comet_exp {comet_ml.Experiment} -- Comet ML object used to log all the losses and images (default: {None})
        """
        self.dann_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        if self.gen_state == 0:
            # encode
            c_a, _ = self.gen_a.encode(x_a)
            c_b, _ = self.gen_b.encode(x_b)
        elif self.gen_state == 1:
            # encode
            c_a, _ = self.gen.encode(x_a, 1)
            c_b, _ = self.gen.encode(x_b, 2)
        else:
            print("self.gen_state unknown value:", self.gen_state)
        # domain classifier loss
        self.domain_class_loss, out_a, out_b = self.compute_domain_adv_loss(
            c_a, c_b, compute_accuracy=True, minimize=True)
        self.domain_class_loss.backward()
        self.dann_opt.step()
        if comet_exp is not None:
            comet_exp.log_metric("domain_class_loss",
                                 self.domain_class_loss.cpu().detach())
            comet_exp.log_metric("probability A being identified as A",
                                 out_a.cpu().detach())
            comet_exp.log_metric("probability B being identified as B",
                                 out_b.cpu().detach())

    def update_learning_rate(self):
        """
        Update the learning rate
        """
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()
        if self.dann_scheduler is not None:
            self.dann_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """
        Resume training, loading the network parameters

        Arguments:
            checkpoint_dir {string} -- path to the directory where the checkpoints are saved
            hyperparameters {dictionary} -- dictionary with all hyperparameters

        Returns:
            int -- number of iterations (used by the optimizer)
        """
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        if self.gen_state == 0:
            self.gen_a.load_state_dict(state_dict["a"])
            self.gen_b.load_state_dict(state_dict["b"])
        elif self.gen_state == 1:
            self.gen.load_state_dict(state_dict["2"])
        else:
            print("self.gen_state unknown value:", self.gen_state)
        # Load domain classifier
        if self.domain_classif == 1:
            last_model_name = get_model_list(checkpoint_dir, "domain_classif")
            state_dict = torch.load(last_model_name)
            self.domain_classifier.load_state_dict(state_dict["d"])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict["a"])
        self.dis_b.load_state_dict(state_dict["b"])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, "optimizer.pt"))
        self.dis_opt.load_state_dict(state_dict["dis"])
        self.gen_opt.load_state_dict(state_dict["gen"])
        if self.domain_classif == 1:
            self.dann_opt.load_state_dict(state_dict["dann"])
            self.dann_scheduler = get_scheduler(self.dann_opt,
                                                hyperparameters, iterations)
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print("Resume from iteration %d" % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """
        Save generators, discriminators, and optimizers

        Arguments:
            snapshot_dir {string} -- directory path where to save the network weights
            iterations {int} -- number of training iterations
        """
        gen_name = os.path.join(snapshot_dir, "gen_%08d.pt" % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, "dis_%08d.pt" % (iterations + 1))
        domain_classifier_name = os.path.join(
            snapshot_dir, "domain_classifier_%08d.pt" % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, "optimizer.pt")
        if self.gen_state == 0:
            torch.save(
                {
                    "a": self.gen_a.state_dict(),
                    "b": self.gen_b.state_dict()
                }, gen_name)
        elif self.gen_state == 1:
            torch.save({"2": self.gen.state_dict()}, gen_name)
        else:
            print("self.gen_state unknown value:", self.gen_state)
        torch.save({
            "a": self.dis_a.state_dict(),
            "b": self.dis_b.state_dict()
        }, dis_name)
        if self.domain_classif:
            torch.save({"d": self.domain_classifier.state_dict()},
                       domain_classifier_name)
            torch.save(
                {
                    "gen": self.gen_opt.state_dict(),
                    "dis": self.dis_opt.state_dict(),
                    "dann": self.dann_opt.state_dict(),
                },
                opt_name,
            )
        else:
            torch.save(
                {
                    "gen": self.gen_opt.state_dict(),
                    "dis": self.dis_opt.state_dict()
                },
                opt_name,
            )
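
# Minimal sketch (not part of the trainer above) of what recon_criterion_mask
# computes: the pixel difference is multiplied by (1 - mask), so pixels where
# mask == 1 contribute nothing to the mean L1. All tensors are toy values.
def _masked_l1_demo():
    pred = torch.ones(1, 3, 4, 4)     # stand-in for a cycle reconstruction
    target = torch.zeros(1, 3, 4, 4)  # stand-in for the original image
    mask = torch.zeros(1, 1, 4, 4)
    mask[..., :2, :] = 1              # mask out the top half of the image
    # identical formula to recon_criterion_mask above
    loss = torch.mean(torch.abs(torch.mul(pred - target, 1 - mask)))
    return loss  # 0.5: the masked half contributes 0, the visible half 1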
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # fix the noise used in sampling
        self.s_a = torch.randn(8, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(8, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            eps=1e-8,
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            eps=1e-8,
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))
        self.iter = 0

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def intrinsic_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def volumeloss_criterion(self, input, target):
        # compare the mean of channel 0 (the spectrogram) over frequency and
        # time, i.e. the overall signal volume
        idx_select = torch.tensor([0]).cuda()
        input, target = input.index_select(1, idx_select), target.index_select(
            1, idx_select)
        input, target = torch.mean(input, 3), torch.mean(target, 3)
        input, target = torch.mean(input, 2), torch.mean(target, 2)
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        with torch.no_grad():
            self.eval()
            s_a = Variable(self.s_a)
            s_b = Variable(self.s_b)
            c_a, s_a_fake = self.gen_a.encode(x_a)
            c_b, s_b_fake = self.gen_b.encode(x_b)
            x_ba = self.gen_a.decode(c_b, s_a)
            x_ab = self.gen_b.decode(c_a, s_b)
            self.train()
            return x_ab, x_ba

    def update_iter(self):
        self.iter += 1

    def gen_update(self, x_a, x_b, hyperparameters, x_a_rand=None,
                   x_b_rand=None):
        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(
            c_a_recon,
            s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(
            c_b_recon,
            s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba, x_a_rand)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab, x_b_rand)
        # cepstrum loss
        self.loss_gen_ceps_a = self.calc_cepstrum_loss(
            x_ba) if hyperparameters['ceps_w'] > 0 else 0
        self.loss_gen_ceps_b = self.calc_cepstrum_loss(
            x_ab) if hyperparameters['ceps_w'] > 0 else 0
        # spectral flux loss
        self.loss_gen_flux_a2b = self.calc_spectral_flux_loss(
            x_ab) if hyperparameters['flux_w'] > 0 else 0
        self.loss_gen_flux_b2a = self.calc_spectral_flux_loss(
            x_ba) if hyperparameters['flux_w'] > 0 else 0
        # spectral envelope loss
        self.loss_gen_enve_a2b = self.calc_spectral_enve15_loss(
            x_ab) if hyperparameters['enve_w'] > 0 else 0
        self.loss_gen_enve_b2a = self.calc_spectral_enve15_loss(
            x_ba) if hyperparameters['enve_w'] > 0 else 0
        # volume loss
        self.loss_gen_vol_a = self.volumeloss_criterion(
            x_a, x_ab) if hyperparameters['vol_w'] > 0 else 0
        self.loss_gen_vol_b = self.volumeloss_criterion(
            x_b, x_ba) if hyperparameters['vol_w'] > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
            hyperparameters['gan_w'] * self.loss_gen_adv_b + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
            hyperparameters['ceps_w'] * self.loss_gen_ceps_a + \
            hyperparameters['ceps_w'] * self.loss_gen_ceps_b + \
            hyperparameters['flux_w'] * self.loss_gen_flux_a2b + \
            hyperparameters['flux_w'] * self.loss_gen_flux_b2a + \
            hyperparameters['enve_w'] * self.loss_gen_enve_a2b + \
            hyperparameters['enve_w'] * self.loss_gen_enve_b2a + \
            hyperparameters['vol_w'] * self.loss_gen_vol_a + \
            hyperparameters['vol_w'] * self.loss_gen_vol_b
        self.loss_gen_total.backward()
        # optionally clip the generator gradients before stepping
        if hyperparameters['clip_grad'] == 'value':
            torch.nn.utils.clip_grad_value_(
                list(self.gen_a.parameters()) + list(self.gen_b.parameters()),
                1)
        elif hyperparameters['clip_grad'] == 'norm':
            torch.nn.utils.clip_grad_norm_(
                list(self.gen_a.parameters()) + list(self.gen_b.parameters()),
                0.5)
        self.gen_opt.step()

    def calc_cepstrum_loss(self, x_fake):
        # channel 0 holds the spectrogram, channel 1 the cepstrum; the target
        # cepstrum is recomputed from the generated spectrogram with a DCT
        idx_select_spec = torch.tensor([0]).cuda()
        idx_select_ceps = torch.tensor([1]).cuda()
        fake_spec = x_fake.index_select(
            1, idx_select_spec).detach().cpu().numpy()
        ceps = scipy.fftpack.dct(fake_spec, axis=2, type=2, norm='ortho')
        ceps = np.maximum(ceps, 0)
        return self.intrinsic_criterion(
            x_fake.index_select(1, idx_select_ceps),
            torch.from_numpy(ceps).cuda())

    def calc_spectral_flux_loss(self, x_fake):
        # channel 2 holds the spectral flux; the target flux is recomputed
        # from the generated spectrogram as a clipped difference along time
        idx_select_spec = torch.tensor([0]).cuda()
        idx_select_flux = torch.tensor([2]).cuda()
        fake_spec = x_fake.index_select(
            1, idx_select_spec).detach().cpu().numpy()
        spec_flux = np.zeros_like(fake_spec)
        hei, wid = 256, 256
        for i in range(1, wid - 1):
            spec_flux[:, :, :, i] = np.maximum(
                fake_spec[:, :, :, i + 1] - fake_spec[:, :, :, i - 1], 0.0)
        spec_flux[:, :, :, 0] = spec_flux[:, :, :, 1]
        spec_flux[:, :, :, -1] = spec_flux[:, :, :, -2]
        return self.intrinsic_criterion(
            x_fake.index_select(1, idx_select_flux),
            torch.from_numpy(spec_flux).cuda())

    def calc_spectral_enve15_loss(self, x_fake):
        # channel 3 holds the spectral envelope; the target envelope is
        # recomputed by keeping only the first 15 cepstral coefficients
        idx_select_spec = torch.tensor([0]).cuda()
        idx_select_enve = torch.tensor([3]).cuda()
        fake_spec = x_fake.index_select(
            1, idx_select_spec).detach().cpu().numpy()
        MFCC = scipy.fftpack.dct(fake_spec, axis=2, type=2, norm='ortho')
        MFCC[:, :, 15:, :] = 0.0
        spec_enve = scipy.fftpack.idct(MFCC, axis=2, type=2, norm='ortho')
        spec_enve = np.maximum(spec_enve, 0.0)
        return self.intrinsic_criterion(
            x_fake.index_select(1, idx_select_enve),
            torch.from_numpy(spec_enve).cuda())

    def sample(self, x_a, x_b):
        with torch.no_grad():
            self.eval()
            s_a1 = Variable(self.s_a)
            s_b1 = Variable(self.s_b)
            s_a2 = Variable(
                torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
            s_b2 = Variable(
                torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
            x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
            for i in range(x_a.size(0)):
                c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
                c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
                x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
                x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
                x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
                x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
                x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
                x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
            x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
            x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
            x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
            self.train()
            return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters[
            'gan_w'] * self.loss_dis_a + hyperparameters[
                'gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        # self.save_grad(list(self.dis_a.named_parameters()) + list(self.dis_b.named_parameters()))
        # torch.nn.utils.clip_grad_norm_(list(self.dis_a.parameters()) + list(self.dis_b.parameters()), 0.5)
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
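
# Minimal sketch (not part of the trainer above) of the cepstrum target built
# in calc_cepstrum_loss: channel 0 of a generated sample is treated as a
# spectrogram, a type-II DCT along the frequency axis yields the cepstrum, and
# negative coefficients are clipped. The random array is a toy placeholder.
def _cepstrum_target_demo():
    fake_spec = np.random.rand(1, 1, 256, 256)  # (batch, channel, freq, time)
    ceps = scipy.fftpack.dct(fake_spec, axis=2, type=2, norm='ortho')
    ceps = np.maximum(ceps, 0)  # keep only non-negative coefficients
    return ceps  # compared against channel 1 with an L1 criterion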
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b
        self.dis_sa = MsImageDis(
            hyperparameters['input_dim_a'] * 2,
            hyperparameters['dis'])  # paired discriminator for domain a
        self.dis_sb = MsImageDis(
            hyperparameters['input_dim_b'] * 2,
            hyperparameters['dis'])  # paired discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        dis_style_params = list(self.dis_sa.parameters()) + list(
            self.dis_sb.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_style_opt = torch.optim.Adam(
            [p for p in dis_style_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.dis_style_scheduler = get_scheduler(self.dis_style_opt,
                                                 hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_sa.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))
        self.dis_sb.apply(weights_init('gaussian'))
        if hyperparameters['gen']['CE_method'] == 'vgg':
            self.gen_a.content_init()
            self.gen_b.content_init()
        self.criterion = nn.L1Loss().cuda()
        self.triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2).cuda()
        self.kld = nn.KLDivLoss()
        self.contextual_loss = ContextualLoss()

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def kl_loss(self, input, target):
        # return torch.mean(torch.abs(self.kld(input, target)))
        return torch.mean(self.kld(input, target))

    def normalize_feat(self, feat):
        # L2-normalize each spatial feature vector and flatten to (bs, c, H*W)
        bs, c, H, W = feat.shape
        feat = feat.view(bs, c, -1)
        feat_norm = torch.norm(feat, 2, 1,
                               keepdim=True) + sys.float_info.epsilon
        feat = torch.div(feat, feat_norm)
        # print(max(feat))
        return feat

    def norm_two_domain(self, feat_c, feat_s):
        # jointly normalize the concatenated content/style features
        feat = torch.cat((feat_c, feat_s), 1)
        bs, c, H, W = feat.shape
        feat_norm = torch.norm(feat, 2, 1, keepdim=True)
        feat = torch.div(feat, feat_norm)
        feat_c = feat[:, 0:256, :, :].view(bs, 256, -1)
        feat_s = feat[:, 256:512, :, :].view(bs, 256, -1)
        return feat_c, feat_s

    def generate_map(self, corr_index, h, w):
        # turn the flat argmax correspondence indices into an (h, w, 2) map
        # of 2-D coordinates
        coor = []
        corr_map = []
        for i in range(len(corr_index)):
            x = corr_index[i] // h
            y = corr_index[i] % w
            coor.append(x)
            coor.append(y)
            corr_map.append(list(np.asarray(coor)))
            coor.clear()
        corr_map_final = np.reshape(np.asarray(corr_map), (h, w, 2))
        return corr_map_final

    def warp_img(self, corr_map, ref_img):
        # copy reference-image patches to the positions given by the
        # correspondence map
        bs, c, h_img, w_img = ref_img.shape
        h, w, _ = corr_map.shape
        scale = h_img // h
        warped_img = torch.zeros(ref_img.shape)
        for i in range(h):
            for j in range(w):
                nnx = corr_map[i][j][0]
                nny = corr_map[i][j][1]
                warped_img[:, :, i * scale:(i + 1) * scale,
                           j * scale:(j + 1) * scale] = \
                    ref_img[:, :, nnx * scale:(nnx + 1) * scale,
                            nny * scale:(nny + 1) * scale]
        return warped_img.cuda()

    def warp_style(self, cur_content, ref_content, ref_style):
        # normalize the features
        cur_content = self.normalize_feat(cur_content)
        ref_content = self.normalize_feat(ref_content)
        # cur_content, ref_content = self.norm_two_domain(cur_content, ref_content)
        cur_content = cur_content.permute(0, 2, 1)
        # calculate the similarity
        f = torch.matmul(cur_content, ref_content)  # 1 x (H x W) x (H x W)
        f_corr = F.softmax(f / 0.005, dim=-1)  # 1 x (H x W) x (H x W)
        # f_corr = F.softmax(f, dim=-1)  # 1 x (H x W) x (H x W)
        # get the hard correspondence index (argmax instead of softmax)
        bs, HW, WH = f_corr.shape
        corr_index = torch.argmax(f_corr, dim=-1).squeeze(0)
        # collect the reference style
        bs, c, H, W = ref_style.shape
        ref_style = ref_style.view(bs, c, -1)
        ref_style = ref_style.permute(0, 2, 1)  # 1 x (H x W) x c
        # warp the reference style
        warped_style = torch.matmul(f_corr, ref_style)  # 1 x (H x W) x c
        warped_style = warped_style.permute(0, 2, 1).contiguous()
        warped_style = warped_style.view(bs, c, H, W)
        return corr_index, warped_style

    def forward(self, x_a, x_b):
        self.eval()
        c_a, s_a_fake = self.gen_a.encode(x_a)
        c_b, s_b_fake = self.gen_b.encode(x_b)
        # warp the reference style to the current content
        _, s_ab_warp = self.warp_style(c_a, c_b, s_b_fake)
        _, s_ba_warp = self.warp_style(c_b, c_a, s_a_fake)
        x_ba = self.gen_a.decode(s_ba_warp, c_b)
        x_ab = self.gen_b.decode(s_ab_warp, c_a)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, x_adf, x_bdf, hyperparameters):
        self.gen_opt.zero_grad()
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # add the style warp here
        _, s_ab = self.warp_style(c_a, c_b, s_b_prime)
        _, s_ba = self.warp_style(c_b, c_a, s_a_prime)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(s_a_prime, c_a)
        x_b_recon = self.gen_b.decode(s_b_prime, c_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(s_ba, c_b)
        x_ab = self.gen_b.decode(s_ab, c_a)
        # encode again
        c_b_recon, s_ba_recon = self.gen_a.encode(
            x_ba)  # now s_ba_recon matches the structure of B
        c_a_recon, s_ab_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        # warp the style first
        _, s_aba = self.warp_style(c_a, c_b, s_ba_recon)
        _, s_bab = self.warp_style(c_b, c_a, s_ab_recon)
        # then reconstruct
        x_aba = self.gen_a.decode(
            s_a_prime,
            c_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(
            s_b_prime,
            c_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
        # prepare paired data for the adversarial generator loss
        # pair_a_ffake = torch.cat((x_ba, x_a), 1)
        # pair_b_ffake = torch.cat((x_ab, x_b), 1)
        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_aba, s_a_prime)
        self.loss_gen_recon_s_b = self.recon_criterion(
            s_bab, s_b_prime)  # default is s_bab; still need to test s_b_recon
        # self.loss_gen_recon_s_a += self.triplet_loss(s_a_prime, s_aba, s_b_prime)
        # self.loss_gen_recon_s_b += self.triplet_loss(s_b_prime, s_bab, s_a_prime)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        # self.loss_gen_kl_ab = self.kl_loss(x_ab, x_b)
        # self.loss_gen_kl_ba = self.kl_loss(x_ba, x_a)
        self.loss_gen_cx_a = self.contextual_loss(s_ba, s_a_prime)
        self.loss_gen_cx_b = self.contextual_loss(s_ab, s_b_prime)
        self.loss_gen_cycrecon_x_a = self.criterion(
            x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.criterion(
            x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_xa = self.gen_a.calc_gen_loss(
            self.dis_a.forward(x_ba))
        self.loss_gen_adv_xb = self.gen_b.calc_gen_loss(
            self.dis_b.forward(x_ab))
        # self.loss_gen_adv_sxa = self.gen_a.calc_gen_loss(self.dis_sa.forward(pair_a_ffake))
        # self.loss_gen_adv_sxb = self.gen_b.calc_gen_loss(self.dis_sb.forward(pair_b_ffake))
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss_new(
            self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss_new(
            self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_xa + \
            hyperparameters['gan_w'] * self.loss_gen_adv_xb + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
            hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
            hyperparameters['vgg_w'] * self.loss_gen_vgg_b
        # currently unused terms:
        # hyperparameters['recon_kl_w'] * self.loss_gen_kl_ab + \
        # hyperparameters['recon_kl_w'] * self.loss_gen_kl_ba + \
        # hyperparameters['recon_cx_w'] * self.loss_gen_cx_a + \
        # hyperparameters['recon_cx_w'] * self.loss_gen_cx_b + \
        # hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_s_a + \
        # hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_s_b + \
        # hyperparameters['gan_wp'] * self.loss_gen_adv_sxa + \
        # hyperparameters['gan_wp'] * self.loss_gen_adv_sxb + \
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def compute_vgg_loss_new(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_feat = vgg(img_vgg)
        target_feat = vgg(target_vgg)
        return self.recon_criterion(img_feat, target_feat)

    def sample(self, x_a, x_b, x_adf, x_bdf):
        self.eval()
        # s_a1 = Variable(self.s_a)
        # s_b1 = Variable(self.s_b)
        # s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        # s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        h = w = 64
        for i in range(x_a.size(0)):
            img_a = x_a[i].unsqueeze(0)
            img_b = x_b[i].unsqueeze(0)
            c_a, s_a_fake = self.gen_a.encode(img_a)
self.gen_a.encode(img_a) c_b, s_b_fake = self.gen_b.encode(img_b) # reconstruction x_a_recon.append(self.gen_a.decode(s_a_fake, c_a)) x_b_recon.append(self.gen_b.decode(s_b_fake, c_b)) print(x_a_recon[0].shape) # warp style corr_index_ab, s_ab = self.warp_style(c_a, c_b, s_b_fake) corr_index_ba, s_ba = self.warp_style(c_b, c_a, s_a_fake) # cross domain construction x_ba1.append(self.gen_a.decode(s_ba, c_b)) ## output warped results x_ba2 corr_map_ba = self.generate_map(corr_index_ba, h, w) x_ba2.append(self.warp_img(corr_map_ba, img_a)) x_ab1.append(self.gen_b.decode(s_ab, c_a)) ## output warped results x_ab2 corr_map_ab = self.generate_map(corr_index_ab, h, w) x_ab2.append(self.warp_img(corr_map_ab, img_b)) x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon) x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2) x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2) self.train() return x_a, x_adf, x_a_recon, x_ab1, x_ab2, x_b, x_bdf, x_b_recon, x_ba1, x_ba2 def dis_update(self, x_a, x_b, x_adf, x_bdf, hyperparameters): self.dis_opt.zero_grad() self.dis_style_opt.zero_grad() #s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()) #s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()) # encode c_a, s_a = self.gen_a.encode(x_a) c_b, s_b = self.gen_b.encode(x_b) # warp the style here _, s_ab_warp = self.warp_style(c_a, c_b, s_b) _, s_ba_warp = self.warp_style(c_b, c_a, s_a) # decode (cross domain) x_ba = self.gen_a.decode(s_ba_warp, c_b) x_ab = self.gen_b.decode(s_ab_warp, c_a) # prepare data for the paired discriminator # real fake data -> 0 if (len(self.dis_sa.pool_) == 0): print(len(self.dis_sa.pool_)) pair_a_rfake = torch.cat((x_b, x_a), 1) else: pair_a_rfake = torch.cat((self.dis_sa.pool('fetch'), x_a), 1) self.dis_sa.pool('push', x_a) if (len(self.dis_sb.pool_) == 0): print(len(self.dis_sb.pool_)) pair_b_rfake = torch.cat((x_a, x_b), 1) else: pair_b_rfake = torch.cat((self.dis_sb.pool('fetch'), x_b), 1) self.dis_sb.pool('push', x_b) # real real data -> 1 pair_a_rreal = torch.cat((x_a, x_adf), 1) pair_b_rreal = torch.cat((x_b, x_bdf), 1) # fake fake data -> 0 #pair_a_ffake = torch.cat((x_ba.detach(), x_a), 1) #pair_b_ffake = torch.cat((x_ab.detach(), x_b), 1) # D loss self.loss_dis_xa = self.dis_a.calc_dis_loss(x_ba.detach(), x_a) self.loss_dis_xb = self.dis_b.calc_dis_loss(x_ab.detach(), x_b) #self.loss_dis_xa = self.dis_a.calc_dis_loss(x_ba.detach(), self.dis_sa.pool('fetch')) #self.loss_dis_xb = self.dis_b.calc_dis_loss(x_ab.detach(), self.dis_sb.pool('fetch')) #self.loss_dis_sxa = (self.dis_sa.calc_dis_loss(pair_a_rfake, pair_a_rreal) + self.dis_sa.calc_dis_loss(pair_a_ffake, pair_a_rreal)) / 2 #self.loss_dis_sxb = (self.dis_sb.calc_dis_loss(pair_b_rfake, pair_b_rreal) + self.dis_sb.calc_dis_loss(pair_b_ffake, pair_b_rreal)) / 2 #self.loss_dis_sxa = self.dis_sa.calc_dis_loss(pair_a_ffake, pair_a_rreal) #self.loss_dis_sxb = self.dis_sb.calc_dis_loss(pair_b_ffake, pair_b_rreal) #self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_xa + hyperparameters['gan_w'] * self.loss_dis_xb + hyperparameters['gan_wp'] * self.loss_dis_sxa + hyperparameters['gan_wp'] * self.loss_dis_sxb self.loss_dis_total = hyperparameters[ 'gan_w'] * self.loss_dis_xa + hyperparameters[ 'gan_w'] * self.loss_dis_xb self.loss_dis_total.backward() self.dis_opt.step() self.dis_style_opt.step() def update_learning_rate(self): if self.dis_scheduler is not None: self.dis_scheduler.step() if self.dis_style_scheduler is not None: self.dis_style_scheduler.step() if 
self.gen_scheduler is not None: self.gen_scheduler.step() def resume(self, checkpoint_dir, hyperparameters): # Load generators last_model_name = get_model_list(checkpoint_dir, "gen") state_dict = torch.load(last_model_name) self.gen_a.load_state_dict(state_dict['a']) self.gen_b.load_state_dict(state_dict['b']) iterations = int(last_model_name[-11:-3]) # Load discriminators last_model_name = get_model_list(checkpoint_dir, "dis") state_dict = torch.load(last_model_name) self.dis_a.load_state_dict(state_dict['a']) self.dis_b.load_state_dict(state_dict['b']) # Load optimizers state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt')) self.dis_opt.load_state_dict(state_dict['dis']) self.gen_opt.load_state_dict(state_dict['gen']) # Reinitilize schedulers self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations) self.dis_style_scheduler = get_scheduler(self.dis_style_opt, hyperparameters, iterations) self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations) print('Resume from iteration %d' % iterations) return iterations def save(self, snapshot_dir, iterations): # Save generators, discriminators, and optimizers gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1)) dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1)) opt_name = os.path.join(snapshot_dir, 'optimizer.pt') torch.save({ 'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict() }, gen_name) torch.save({ 'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict() }, dis_name) torch.save( { 'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict() }, opt_name)
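# The trainer above warps style features with a soft correspondence computed
# from content features. The helper below is an illustrative, self-contained
# sketch of the same idea (it is not part of the original trainer): it assumes
# (1, c, H, W) inputs, mirrors the 1/0.005 softmax temperature of warp_style,
# and reuses the torch / F imports this file already depends on.
def soft_warp_example(cur_content, ref_content, ref_style, temperature=0.005):
    # cur_content, ref_content: (1, c, H, W) normalized content features
    # ref_style: (1, c, H, W) style features to be warped onto cur_content
    bs, c, H, W = ref_style.shape
    cur = cur_content.view(bs, c, -1).permute(0, 2, 1)  # 1 x (H*W) x c
    ref = ref_content.view(bs, c, -1)                   # 1 x c x (H*W)
    corr = torch.matmul(cur, ref)                       # 1 x (H*W) x (H*W) similarities
    attn = F.softmax(corr / temperature, dim=-1)        # sharp soft correspondence
    style = ref_style.view(bs, c, -1).permute(0, 2, 1)  # 1 x (H*W) x c
    warped = torch.matmul(attn, style)                  # attend into the reference style
    return warped.permute(0, 2, 1).contiguous().view(bs, c, H, W)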
class MUNIT_Trainer(nn.Module): def __init__(self, hyperparameters): super(MUNIT_Trainer, self).__init__() lr = hyperparameters['lr'] # Initiate the networks self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen']) # auto-encoder for domain a self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen']) # auto-encoder for domain b self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis']) # discriminator for domain a self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis']) # discriminator for domain b self.instancenorm = nn.InstanceNorm2d(512, affine=False) self.style_dim = hyperparameters['gen']['style_dim'] # fix the noise used in sampling self.s_a = torch.randn(8, self.style_dim, 1, 1).cuda() self.s_b = torch.randn(8, self.style_dim, 1, 1).cuda() # Setup the optimizers beta1 = hyperparameters['beta1'] beta2 = hyperparameters['beta2'] dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters()) gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters()) self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad], lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad], lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters) self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters) # Network weight initialization self.apply(weights_init(hyperparameters['init'])) self.dis_a.apply(weights_init('gaussian')) self.dis_b.apply(weights_init('gaussian')) # Load VGG model if needed if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0: self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models') self.vgg.eval() for param in self.vgg.parameters(): param.requires_grad = False def recon_criterion(self, input, target): return torch.mean(torch.abs(input - target)) def forward(self, x_a, x_b): self.eval() x_a.volatile = True x_b.volatile = True s_a = Variable(self.s_a, volatile=True) s_b = Variable(self.s_b, volatile=True) c_a, s_a_fake = self.gen_a.encode(x_a) c_b, s_b_fake = self.gen_b.encode(x_b) x_ba = self.gen_a.decode(c_b, s_a) x_ab = self.gen_b.decode(c_a, s_b) self.train() return x_ab, x_ba def gen_update(self, x_a, x_b, hyperparameters): self.gen_opt.zero_grad() s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()) s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()) # encode c_a, s_a_prime = self.gen_a.encode(x_a) c_b, s_b_prime = self.gen_b.encode(x_b) # decode (within domain) x_a_recon = self.gen_a.decode(c_a, s_a_prime) x_b_recon = self.gen_b.decode(c_b, s_b_prime) # decode (cross domain) x_ba = self.gen_a.decode(c_b, s_a) x_ab = self.gen_b.decode(c_a, s_b) # encode again c_b_recon, s_a_recon = self.gen_a.encode(x_ba) c_a_recon, s_b_recon = self.gen_b.encode(x_ab) # decode again (if needed) x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None # reconstruction loss self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a) self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b) self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a) self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b) self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a) 
self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b) self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0 self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0 # GAN loss self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba) self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab) # domain-invariant perceptual loss self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0 self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0 # total loss self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \ hyperparameters['gan_w'] * self.loss_gen_adv_b + \ hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \ hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \ hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \ hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \ hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \ hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \ hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \ hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \ hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \ hyperparameters['vgg_w'] * self.loss_gen_vgg_b self.loss_gen_total.backward() self.gen_opt.step() def compute_vgg_loss(self, vgg, img, target): img_vgg = vgg_preprocess(img) target_vgg = vgg_preprocess(target) img_fea = vgg(img_vgg) target_fea = vgg(target_vgg) return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2) def sample(self, x_a, x_b): self.eval() x_a.volatile = True x_b.volatile = True s_a1 = Variable(self.s_a, volatile=True) s_b1 = Variable(self.s_b, volatile=True) s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda(), volatile=True) s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda(), volatile=True) x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], [] for i in range(x_a.size(0)): c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0)) c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0)) x_a_recon.append(self.gen_a.decode(c_a, s_a_fake)) x_b_recon.append(self.gen_b.decode(c_b, s_b_fake)) x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0))) x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0))) x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0))) x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0))) x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon) x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2) x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2) self.train() return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2 def dis_update(self, x_a, x_b, hyperparameters): self.dis_opt.zero_grad() s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()) s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()) # encode c_a, _ = self.gen_a.encode(x_a) c_b, _ = self.gen_b.encode(x_b) # decode (cross domain) x_ba = self.gen_a.decode(c_b, s_a) x_ab = self.gen_b.decode(c_a, s_b) # D loss self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a) self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b) self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b self.loss_dis_total.backward() self.dis_opt.step() def update_learning_rate(self): if self.dis_scheduler is not None: 
self.dis_scheduler.step() if self.gen_scheduler is not None: self.gen_scheduler.step() def resume(self, checkpoint_dir, hyperparameters): # Load generators last_model_name = get_model_list(checkpoint_dir, "gen") state_dict = torch.load(last_model_name) self.gen_a.load_state_dict(state_dict['a']) self.gen_b.load_state_dict(state_dict['b']) iterations = int(last_model_name[-11:-3]) # Load discriminators last_model_name = get_model_list(checkpoint_dir, "dis") state_dict = torch.load(last_model_name) self.dis_a.load_state_dict(state_dict['a']) self.dis_b.load_state_dict(state_dict['b']) # Load optimizers state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt')) self.dis_opt.load_state_dict(state_dict['dis']) self.gen_opt.load_state_dict(state_dict['gen']) # Reinitilize schedulers self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations) self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations) print('Resume from iteration %d' % iterations) return iterations def save(self, snapshot_dir, iterations): # Save generators, discriminators, and optimizers gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1)) dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1)) opt_name = os.path.join(snapshot_dir, 'optimizer.pt') torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name) torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name) torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
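# Note: the .volatile flag and Variable(..., volatile=True) used in the trainer
# above are PyTorch <= 0.3 idioms that were removed in 0.4. A minimal sketch of
# the modern equivalent, assuming only the trainer interface defined above:
def translate_no_grad_example(trainer, x_a, x_b):
    trainer.eval()
    with torch.no_grad():  # replaces the volatile=True inference mode
        x_ab, x_ba = trainer.forward(x_a, x_b)
    trainer.train()
    return x_ab, x_ba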
class DGNet_Trainer(nn.Module):
    def __init__(self, hyperparameters, gpu_ids=[0]):
        super(DGNet_Trainer, self).__init__()
        # learning rates for the generator and discriminator, from the config
        lr_g = hyperparameters['lr_g']
        lr_d = hyperparameters['lr_d']
        # number of identity classes
        ID_class = hyperparameters['ID_class']
        # whether to use fp16 (apex)
        if not 'apex' in hyperparameters.keys():
            hyperparameters['apex'] = False
        self.fp16 = hyperparameters['apex']
        # Initiate the networks
        # We do not need to manually set fp16 in the network for the new apex. So here I set fp16=False.
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'], fp16=False)  # auto-encoder for domain a
        self.gen_b = self.gen_a  # auto-encoder for domain b
        # ID_stride: stride of the appearance encoder's pooling layer
        if not 'ID_stride' in hyperparameters.keys():
            hyperparameters['ID_stride'] = 2
        # id_a: the appearance encoder Ea (ft_netAB in the 'AB' setting)
        if hyperparameters['ID_style'] == 'PCB':
            self.id_a = PCB(ID_class)
        elif hyperparameters['ID_style'] == 'AB':
            self.id_a = ft_netAB(ID_class, stride=hyperparameters['ID_stride'],
                                 norm=hyperparameters['norm_id'], pool=hyperparameters['pool'])
        else:
            self.id_a = ft_net(ID_class, norm=hyperparameters['norm_id'],
                               pool=hyperparameters['pool'])  # return 2048 now
        self.id_b = self.id_a  # domain b shares the appearance encoder with domain a
        # Multi-scale discriminator: the image is downscaled several times, a
        # prediction is made at each scale, and the losses are summed. The
        # three scales give outputs of [batch_size, 1, 64, 32],
        # [batch_size, 1, 32, 16] and [batch_size, 1, 16, 8].
        self.dis_a = MsImageDis(3, hyperparameters['dis'], fp16=False)  # discriminator for domain a
        self.dis_b = self.dis_a  # discriminator for domain b
        # load teachers (several teacher models may be given, comma-separated)
        if hyperparameters['teacher'] != "":
            teacher_name = hyperparameters['teacher']
            print(teacher_name)
            teacher_names = teacher_name.split(',')
            # build the teacher ensemble, starting from an empty ModuleList
            teacher_model = nn.ModuleList()
            teacher_count = 0
            for teacher_name in teacher_names:
                config_tmp = load_config(teacher_name)
                # stride of the pooling layer
                if 'stride' in config_tmp:
                    stride = config_tmp['stride']
                else:
                    stride = 2
                # build the network and load its weights
                model_tmp = ft_net(ID_class, stride=stride)
                teacher_model_tmp = load_network(model_tmp, teacher_name)
                teacher_model_tmp.model.fc = nn.Sequential()  # remove the original fc layer in ImageNet
                teacher_model_tmp = teacher_model_tmp.cuda()  # input size: [3, 224, 224]
                if self.fp16:
                    teacher_model_tmp = amp.initialize(teacher_model_tmp, opt_level="O1")
                teacher_model.append(teacher_model_tmp.cuda().eval())
                teacher_count += 1
            self.teacher_model = teacher_model
            # whether to keep batchnorm layers in training mode
            if hyperparameters['train_bn']:
                self.teacher_model = self.teacher_model.apply(train_bn)
        # instance normalization
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        # RGB to one channel: Es consumes gray-scale (or edge) images, so
        # `single` converts the input accordingly
        if hyperparameters['single'] == 'edge':
            self.single = to_edge
        else:
            self.single = to_gray(False)
        # Random Erasing when training; erasing_p is the erasing probability
        if not 'erasing_p' in hyperparameters.keys():
            self.erasing_p = 0
        else:
            self.erasing_p = hyperparameters['erasing_p']
        # erase a random region of the image by setting its pixels to the mean value
        self.single_re = RandomErasing(probability=self.erasing_p, mean=[0.0, 0.0, 0.0])
        if not 'T_w' in hyperparameters.keys():
            hyperparameters['T_w'] = 1
        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters())  # + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters())  # + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr_d, betas=(beta1, beta2),
                                        weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr_g, betas=(beta1, beta2),
                                        weight_decay=hyperparameters['weight_decay'])
        # id params: the classifier heads of id_a get 10x the base learning rate
        if hyperparameters['ID_style'] == 'PCB':
            ignored_params = (list(map(id, self.id_a.classifier0.parameters())) +
                              list(map(id, self.id_a.classifier1.parameters())) +
                              list(map(id, self.id_a.classifier2.parameters())) +
                              list(map(id, self.id_a.classifier3.parameters())))
            base_params = filter(lambda p: id(p) not in ignored_params, self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD([
                {'params': base_params, 'lr': lr2},
                {'params': self.id_a.classifier0.parameters(), 'lr': lr2 * 10},
                {'params': self.id_a.classifier1.parameters(), 'lr': lr2 * 10},
                {'params': self.id_a.classifier2.parameters(), 'lr': lr2 * 10},
                {'params': self.id_a.classifier3.parameters(), 'lr': lr2 * 10},
            ], weight_decay=hyperparameters['weight_decay'], momentum=0.9, nesterov=True)
        elif hyperparameters['ID_style'] == 'AB':
            ignored_params = (list(map(id, self.id_a.classifier1.parameters())) +
                              list(map(id, self.id_a.classifier2.parameters())))
            base_params = filter(lambda p: id(p) not in ignored_params, self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD([
                {'params': base_params, 'lr': lr2},
                {'params': self.id_a.classifier1.parameters(), 'lr': lr2 * 10},
                {'params': self.id_a.classifier2.parameters(), 'lr': lr2 * 10},
            ], weight_decay=hyperparameters['weight_decay'], momentum=0.9, nesterov=True)
        else:
            ignored_params = list(map(id, self.id_a.classifier.parameters()))
            base_params = filter(lambda p: id(p) not in ignored_params, self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD([
                {'params': base_params, 'lr': lr2},
                {'params': self.id_a.classifier.parameters(), 'lr': lr2 * 10},
            ], weight_decay=hyperparameters['weight_decay'], momentum=0.9, nesterov=True)
        # learning-rate schedules for the generator, discriminator, and ID encoder
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        self.id_scheduler = get_scheduler(self.id_opt, hyperparameters)
        self.id_scheduler.gamma = hyperparameters['gamma2']
        # ID loss
        self.id_criterion = nn.CrossEntropyLoss()
        self.criterion_teacher = nn.KLDivLoss(size_average=False)  # teacher loss for the primary prediction (Lprim)
        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False  # save memory
        # fp16: wrap models and optimizers with apex for faster mixed-precision training
        if self.fp16:
            # Name the FP16_Optimizer instance to replace the existing optimizer
            assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
            self.gen_a = self.gen_a.cuda()
            self.dis_a = self.dis_a.cuda()
            self.id_a = self.id_a.cuda()
            self.gen_b = self.gen_a
            self.dis_b = self.dis_a
            self.id_b = self.id_a
            self.gen_a, self.gen_opt = amp.initialize(self.gen_a, self.gen_opt, opt_level="O1")
            self.dis_a, self.dis_opt = amp.initialize(self.dis_a, self.dis_opt, opt_level="O1")
            self.id_a, self.id_opt = amp.initialize(self.id_a, self.id_opt, opt_level="O1")

    def to_re(self, x):
        # apply random erasing to every image in the batch
        out = torch.FloatTensor(x.size(0), x.size(1), x.size(2), x.size(3))
        out = out.cuda()
        for i in range(x.size(0)):
            out[i, :, :, :] = self.single_re(x[i, :, :, :])
        return out

    def recon_criterion(self, input, target):
        # L1 reconstruction loss (pixel-wise absolute difference)
        diff = input - target.detach()
        return torch.mean(torch.abs(diff[:]))

    def recon_criterion_sqrt(self, input, target):
        # square-root reconstruction loss
        diff = input - target
        return torch.mean(torch.sqrt(torch.abs(diff[:]) + 1e-8))

    def recon_criterion2(self, input, target):
        # mean-squared reconstruction loss
        diff = input - target
        return torch.mean(diff[:] ** 2)

    def recon_cos(self, input, target):
        # mean cosine-distance reconstruction loss
        cos = torch.nn.CosineSimilarity()
        cos_dis = 1 - cos(input, target)
        return torch.mean(cos_dis[:])

    def forward(self, x_a, x_b, xp_a, xp_b):
        '''
        Four input images per step:
        :param x_a:
        :param xp_a: same identity as x_a
        :param x_b:
        :param xp_b: same identity as x_b
        Why four images? A full DG-Net step needs three images: id1, id2 and a
        positive sample of id1. Training two such triplets naively would take
        six images, but four images (id1, id1-positive, id2, id2-positive)
        already form two triplets, (id1, id2, id1-positive) and
        (id2, id1, id2-positive), saving two images.
        '''
        # self.gen_a.encode is Es; `single` converts the input to gray-scale
        s_a = self.gen_a.encode(self.single(x_a))  # shape: [batch_size, 128, 64, 32] -> structure code of a
        s_b = self.gen_b.encode(self.single(x_b))  # shape: [batch_size, 128, 64, 32] -> structure code of b
        # self.id_a is Ea
        f_a, p_a = self.id_a(scale2(x_a))  # -> appearance code of a
        f_b, p_b = self.id_b(scale2(x_b))  # f shape: [batch_size, 2048*4=8192] -> appearance code of b
        # p[0] shape: [batch_size, class_num=751], p[1] shape: [batch_size, class_num=751] -> probability distributions
        # self.gen_a.decode is the decoder D
        x_ba = self.gen_a.decode(s_b, f_a)  # shape: [batch_size, 3, 256, 128] -> a's appearance + b's structure
        x_ab = self.gen_b.decode(s_a, f_b)  # shape: [batch_size, 3, 256, 128] -> a's structure + b's appearance
        x_a_recon = self.gen_a.decode(s_a, f_a)  # shape: [batch_size, 3, 256, 128] -> a's appearance + a's structure
        x_b_recon = self.gen_b.decode(s_b, f_b)  # shape: [batch_size, 3, 256, 128] -> b's appearance + b's structure
        fp_a, pp_a = self.id_a(scale2(xp_a))  # appearance code and identity prediction for xp_a
        fp_b, pp_b = self.id_b(scale2(xp_b))  # appearance code and identity prediction for xp_b
        # decode the same person
        x_a_recon_p = self.gen_a.decode(s_a, fp_a)  # shape: [batch_size, 3, 256, 128] -> a's structure + xp_a's appearance
        x_b_recon_p = self.gen_b.decode(s_b, fp_b)  # shape: [batch_size, 3, 256, 128] -> b's structure + xp_b's appearance
        # Random Erasing only affects the ID and PID losses.
        if self.erasing_p > 0:
            x_a_re = self.to_re(scale2(x_a.clone()))
            x_b_re = self.to_re(scale2(x_b.clone()))
            xp_a_re = self.to_re(scale2(xp_a.clone()))
            xp_b_re = self.to_re(scale2(xp_b.clone()))
            _, p_a = self.id_a(x_a_re)  # re-predict the distribution after random erasing
            _, p_b = self.id_b(x_b_re)
            # encode the same ID, different photo
            _, pp_a = self.id_a(xp_a_re)
            _, pp_b = self.id_b(xp_b_re)
        return x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b, x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p

    '''
    Training one triplet at a time would look like:
        s_a = self.gen_a.encode(self.single(x_a))
        f_a, p_a = self.id_a(scale2(x_a))
        f_b, p_b = self.id_b(scale2(x_b))
        fp_a, pp_a = self.id_a(scale2(xp_a))
        x_a_recon = self.gen_a.decode(s_a, f_a)
        x_ab = self.gen_b.decode(s_a, f_b)
        x_a_recon_p = self.gen_a.decode(s_a, fp_a)
    and, for the other triplet:
        s_b = self.gen_b.encode(self.single(x_b))
        f_a, p_a = self.id_a(scale2(x_a))
        f_b, p_b = self.id_b(scale2(x_b))
        fp_b, pp_b = self.id_b(scale2(xp_b))
        x_ba = self.gen_a.decode(s_b, f_a)
        x_b_recon = self.gen_b.decode(s_b, f_b)
        x_b_recon_p = self.gen_b.decode(s_b, fp_b)
    '''

    def gen_update(self, x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b,
                   x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p,
                   x_a, x_b, xp_a, xp_b, l_a, l_b, hyperparameters, iteration, num_gpu):
        # pp_a / pp_b are the identity predictions (via Ea) for the positive
        # samples of a and b
        self.gen_opt.zero_grad()
        self.id_opt.zero_grad()
        # no gradient
        x_ba_copy = Variable(x_ba.data, requires_grad=False)
        x_ab_copy = Variable(x_ab.data, requires_grad=False)
        rand_num = random.uniform(0, 1)
        #################################
        # encode structure
        if hyperparameters['use_encoder_again'] >= rand_num:
            # encode again (encoder is tuned, input is fixed)
            s_a_recon = self.gen_b.enc_content(self.single(x_ab_copy))  # structure code of x_ab via Es
            s_b_recon = self.gen_a.enc_content(self.single(x_ba_copy))  # structure code of x_ba via Es
        else:
            # deep-copy the encoder
            self.enc_content_copy = copy.deepcopy(self.gen_a.enc_content)
            self.enc_content_copy = self.enc_content_copy.eval()
            # encode again (encoder is fixed, input is tuned)
            s_a_recon = self.enc_content_copy(self.single(x_ab))
            s_b_recon = self.enc_content_copy(self.single(x_ba))
        #################################
        # encode appearance
        self.id_a_copy = copy.deepcopy(self.id_a)
        self.id_a_copy = self.id_a_copy.eval()
        if hyperparameters['train_bn']:
            self.id_a_copy = self.id_a_copy.apply(train_bn)
        self.id_b_copy = self.id_a_copy
        # encode again (encoder is fixed, input is tuned)
        f_a_recon, p_a_recon = self.id_a_copy(scale2(x_ba))  # Ea code and identity prediction for the synthesized x_ba
        f_b_recon, p_b_recon = self.id_b_copy(scale2(x_ab))  # Ea code and identity prediction for the synthesized x_ab
        # teacher loss: tune the ID model
        log_sm = nn.LogSoftmax(dim=1)
        if hyperparameters['teacher_w'] > 0 and hyperparameters['teacher'] != "":
            if hyperparameters['ID_style'] == 'normal':
                _, p_a_student = self.id_a(scale2(x_ba_copy))
                p_a_student = log_sm(p_a_student)
                p_a_teacher = predict_label(self.teacher_model, scale2(x_ba_copy),
                                            num_class=hyperparameters['ID_class'],
                                            alabel=l_a, slabel=l_b,
                                            teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher = self.criterion_teacher(p_a_student, p_a_teacher) / p_a_student.size(0)
                _, p_b_student = self.id_b(scale2(x_ab_copy))
                p_b_student = log_sm(p_b_student)
                p_b_teacher = predict_label(self.teacher_model, scale2(x_ab_copy),
                                            num_class=hyperparameters['ID_class'],
                                            alabel=l_b, slabel=l_a,
                                            teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher += self.criterion_teacher(p_b_student, p_b_teacher) / p_b_student.size(0)
            elif hyperparameters['ID_style'] == 'AB':
                # normal teacher-student loss
                # BA -> LabelA(smooth) + LabelB(batchB)
                # the synthesized image goes through the ID model to obtain
                # per-identity probabilities
                _, p_ba_student = self.id_a(scale2(x_ba_copy))  # f_a, s_b
                p_a_student = log_sm(p_ba_student[0])  # first of the two identity predictions
                with torch.no_grad():
                    p_a_teacher = predict_label(self.teacher_model, scale2(x_ba_copy),
                                                num_class=hyperparameters['ID_class'],
                                                alabel=l_a, slabel=l_b,
                                                teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher = self.criterion_teacher(p_a_student, p_a_teacher) / p_a_student.size(0)  # teacher-supervised identity loss for x_ba; Eq. (8)
                _, p_ab_student = self.id_b(scale2(x_ab_copy))  # f_b, s_a
                p_b_student = log_sm(p_ab_student[0])
                with torch.no_grad():
                    p_b_teacher = predict_label(self.teacher_model, scale2(x_ab_copy),
                                                num_class=hyperparameters['ID_class'],
                                                alabel=l_b, slabel=l_a,
                                                teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher += self.criterion_teacher(p_b_student, p_b_teacher) / p_b_student.size(0)  # teacher-supervised identity loss for x_ab; Eq. (8)
                # branch b loss: here we give a different label.
                # Ea outputs two prediction vectors: the first feeds the
                # teacher-student loss above, the second its own identity
                # prediction loss below.
                loss_B = self.id_criterion(p_ba_student[1], l_b) + \
                    self.id_criterion(p_ab_student[1], l_a)  # l_b is the label of b; Eq. (9)
                self.loss_teacher = hyperparameters['T_w'] * self.loss_teacher + \
                    hyperparameters['B_w'] * loss_B
        else:
            self.loss_teacher = 0.0
        # auto-encoder image reconstruction
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)  # a's appearance + a's structure; Eq. (1)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)  # b's appearance + b's structure; Eq. (1)
        self.loss_gen_recon_xp_a = self.recon_criterion(x_a_recon_p, x_a)  # a's structure + pos_a's appearance; Eq. (2)
        self.loss_gen_recon_xp_b = self.recon_criterion(x_b_recon_p, x_b)  # b's structure + pos_b's appearance; Eq. (2)
        # feature reconstruction
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a) if hyperparameters['recon_s_w'] > 0 else 0  # structure code of x_ab; Eq. (5)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b) if hyperparameters['recon_s_w'] > 0 else 0  # structure code of x_ba; Eq. (5)
        self.loss_gen_recon_f_a = self.recon_criterion(f_a_recon, f_a) if hyperparameters['recon_f_w'] > 0 else 0  # appearance code of x_ba; Eq. (4)
        self.loss_gen_recon_f_b = self.recon_criterion(f_b_recon, f_b) if hyperparameters['recon_f_w'] > 0 else 0  # appearance code of x_ab; Eq. (4)
        x_aba = self.gen_a.decode(s_a_recon, f_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None  # x_ab's structure + x_ba's appearance
        x_bab = self.gen_b.decode(s_b_recon, f_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None  # x_ba's structure + x_ab's appearance
        # ID loss AND tuning on the generated images
        if hyperparameters['ID_style'] == 'PCB':
            self.loss_id = self.PCB_loss(p_a, l_a) + self.PCB_loss(p_b, l_b)
            self.loss_pid = self.PCB_loss(pp_a, l_a) + self.PCB_loss(pp_b, l_b)
            self.loss_gen_recon_id = self.PCB_loss(p_a_recon, l_a) + \
                self.PCB_loss(p_b_recon, l_b)  # identity loss of x_ba vs l_a and x_ab vs l_b
        elif hyperparameters['ID_style'] == 'AB':
            weight_B = hyperparameters['teacher_w'] * hyperparameters['B_w']  # teacher_w = 1.0, B_w = 0.2
            self.loss_id = self.id_criterion(p_a[0], l_a) + self.id_criterion(p_b[0], l_b) \
                + weight_B * (self.id_criterion(p_a[1], l_a) + self.id_criterion(p_b[1], l_b))  # identity loss for a and b; Eq. (3)
            self.loss_pid = self.id_criterion(pp_a[0], l_a) + \
                self.id_criterion(pp_b[0], l_b)  # identity loss for pos_a and pos_b; Eq. (3)
            self.loss_gen_recon_id = self.id_criterion(p_a_recon[0], l_a) + \
                self.id_criterion(p_b_recon[0], l_b)  # x_ba keeps a's appearance code, so its label stays l_a even though it has b's structure; Eq. (7)
        else:
            self.loss_id = self.id_criterion(p_a, l_a) + self.id_criterion(p_b, l_b)
            self.loss_pid = self.id_criterion(pp_a, l_a) + self.id_criterion(pp_b, l_b)
            self.loss_gen_recon_id = self.id_criterion(p_a_recon, l_a) + self.id_criterion(p_b_recon, l_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0  # cycle: x_ab's structure + x_ba's appearance
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0  # cycle: x_ba's structure + x_ab's appearance
        # GAN loss
        if num_gpu > 1:
            self.loss_gen_adv_a = self.dis_a.module.calc_gen_loss(self.dis_a, x_ba)  # Eq. (6)
            self.loss_gen_adv_b = self.dis_b.module.calc_gen_loss(self.dis_b, x_ab)  # Eq. (6)
        else:
            self.loss_gen_adv_a = self.dis_a.calc_gen_loss(self.dis_a, x_ba)
            self.loss_gen_adv_b = self.dis_b.calc_gen_loss(self.dis_b, x_ab)
        # domain-invariant perceptual loss: extract VGG features from the
        # synthesized and real images and compare them
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # warm up the per-loss weights
        if iteration > hyperparameters['warm_iter']:
            hyperparameters['recon_f_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_f_w'] = min(hyperparameters['recon_f_w'], hyperparameters['max_w'])
            hyperparameters['recon_s_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_s_w'] = min(hyperparameters['recon_s_w'], hyperparameters['max_w'])
            hyperparameters['recon_x_cyc_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_x_cyc_w'] = min(hyperparameters['recon_x_cyc_w'], hyperparameters['max_cyc_w'])
        if iteration > hyperparameters['warm_teacher_iter']:
            hyperparameters['teacher_w'] += hyperparameters['warm_scale']
            hyperparameters['teacher_w'] = min(hyperparameters['teacher_w'], hyperparameters['max_teacher_w'])
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
            hyperparameters['gan_w'] * self.loss_gen_adv_b + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_a + \
            hyperparameters['recon_f_w'] * self.loss_gen_recon_f_a + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_b + \
            hyperparameters['recon_f_w'] * self.loss_gen_recon_f_b + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
            hyperparameters['id_w'] * self.loss_id + \
            hyperparameters['pid_w'] * self.loss_pid + \
            hyperparameters['recon_id_w'] * self.loss_gen_recon_id + \
            hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
            hyperparameters['vgg_w'] * self.loss_gen_vgg_b + \
            hyperparameters['teacher_w'] * self.loss_teacher
        # mixed-precision backward for efficiency
        if self.fp16:
            with amp.scale_loss(self.loss_gen_total, [self.gen_opt, self.id_opt]) as scaled_loss:
                scaled_loss.backward()
            self.gen_opt.step()
            self.id_opt.step()
        else:
            self.loss_gen_total.backward()
            self.gen_opt.step()
            self.id_opt.step()
        print("L_total: %.4f, L_gan: %.4f, Lx: %.4f, Lxp: %.4f, Lrecycle: %.4f, Lf: %.4f, Ls: %.4f, Recon-id: %.4f, id: %.4f, pid: %.4f, teacher: %.4f" % (
            self.loss_gen_total,
            hyperparameters['gan_w'] * (self.loss_gen_adv_a + self.loss_gen_adv_b),
            hyperparameters['recon_x_w'] * (self.loss_gen_recon_x_a + self.loss_gen_recon_x_b),
            hyperparameters['recon_xp_w'] * (self.loss_gen_recon_xp_a + self.loss_gen_recon_xp_b),
            hyperparameters['recon_x_cyc_w'] * (self.loss_gen_cycrecon_x_a + self.loss_gen_cycrecon_x_b),
            hyperparameters['recon_f_w'] * (self.loss_gen_recon_f_a + self.loss_gen_recon_f_b),
            hyperparameters['recon_s_w'] * (self.loss_gen_recon_s_a + self.loss_gen_recon_s_b),
            hyperparameters['recon_id_w'] * self.loss_gen_recon_id,
            hyperparameters['id_w'] * self.loss_id,
            hyperparameters['pid_w'] * self.loss_pid,
            hyperparameters['teacher_w'] * self.loss_teacher))

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def PCB_loss(self, inputs, labels):
        loss = 0.0
        for part in inputs:
            loss += self.id_criterion(part, labels)
        return loss / len(inputs)

    def sample(self, x_a, x_b):
        self.eval()
        x_a_recon, x_b_recon, x_ba1, x_ab1, x_aba, x_bab = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            s_a = self.gen_a.encode(self.single(x_a[i].unsqueeze(0)))
            s_b = self.gen_b.encode(self.single(x_b[i].unsqueeze(0)))
            f_a, _ = self.id_a(scale2(x_a[i].unsqueeze(0)))
            f_b, _ = self.id_b(scale2(x_b[i].unsqueeze(0)))
            x_a_recon.append(self.gen_a.decode(s_a, f_a))
            x_b_recon.append(self.gen_b.decode(s_b, f_b))
            x_ba = self.gen_a.decode(s_b, f_a)
            x_ab = self.gen_b.decode(s_a, f_b)
            x_ba1.append(x_ba)
            x_ab1.append(x_ab)
            # cycle
            s_b_recon = self.gen_a.enc_content(self.single(x_ba))
            s_a_recon = self.gen_b.enc_content(self.single(x_ab))
            f_a_recon, _ = self.id_a(scale2(x_ba))
            f_b_recon, _ = self.id_b(scale2(x_ab))
            x_aba.append(self.gen_a.decode(s_a_recon, f_a_recon))
            x_bab.append(self.gen_b.decode(s_b_recon, f_b_recon))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_aba, x_bab = torch.cat(x_aba), torch.cat(x_bab)
        x_ba1, x_ab1 = torch.cat(x_ba1), torch.cat(x_ab1)
        self.train()
        return x_a, x_a_recon, x_aba, x_ab1, x_b, x_b_recon, x_bab, x_ba1

    def dis_update(self, x_ab, x_ba, x_a, x_b, hyperparameters, num_gpu):
        # update the discriminator
        self.dis_opt.zero_grad()
        # D loss (LSGAN)
        if num_gpu > 1:
            self.loss_dis_a, reg_a = self.dis_a.module.calc_dis_loss(self.dis_a, x_ba.detach(), x_a)
            self.loss_dis_b, reg_b = self.dis_b.module.calc_dis_loss(self.dis_b, x_ab.detach(), x_b)
        else:
            self.loss_dis_a, reg_a = self.dis_a.calc_dis_loss(self.dis_a, x_ba.detach(), x_a)
            self.loss_dis_b, reg_b = self.dis_b.calc_dis_loss(self.dis_b, x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + \
            hyperparameters['gan_w'] * self.loss_dis_b
        print("DLoss: %.4f" % self.loss_dis_total, "Reg: %.4f" % (reg_a + reg_b))
        if self.fp16:
            with amp.scale_loss(self.loss_dis_total, self.dis_opt) as scaled_loss:
                scaled_loss.backward()
        else:
            self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        # step the learning-rate schedulers
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()
        if self.id_scheduler is not None:
            self.id_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # load the networks
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b = self.gen_a
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b = self.dis_a
        # Load the ID model
        last_model_name = get_model_list(checkpoint_dir, "id")
        state_dict = torch.load(last_model_name)
        self.id_a.load_state_dict(state_dict['a'])
        self.id_b = self.id_a
        # Load optimizers; older checkpoints may lack optimizer state, so fall
        # back to a fresh optimizer if loading fails
        try:
            state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
            self.dis_opt.load_state_dict(state_dict['dis'])
            self.gen_opt.load_state_dict(state_dict['gen'])
            self.id_opt.load_state_dict(state_dict['id'])
        except (IOError, KeyError):
            pass
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations, num_gpu=1):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        id_name = os.path.join(snapshot_dir, 'id_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict()}, gen_name)
        if num_gpu > 1:
            torch.save({'a': self.dis_a.module.state_dict()}, dis_name)
        else:
            torch.save({'a': self.dis_a.state_dict()}, dis_name)
        torch.save({'a': self.id_a.state_dict()}, id_name)
        torch.save({'gen': self.gen_opt.state_dict(),
                    'id': self.id_opt.state_dict(),
                    'dis': self.dis_opt.state_dict()}, opt_name)
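# DGNet_Trainer distills identity knowledge from frozen teacher models with
# nn.KLDivLoss on the student's log-probabilities. A minimal, standalone sketch
# of that pattern (illustrative only; `student_logits` / `teacher_probs` are
# assumed shapes, not names from this file, and it reuses the existing
# torch / nn / F imports):
def teacher_kd_loss_example(student_logits, teacher_probs):
    # student_logits: (N, num_classes) raw scores from the ID model
    # teacher_probs:  (N, num_classes) soft labels from the teacher ensemble
    log_p_student = F.log_softmax(student_logits, dim=1)
    kd = nn.KLDivLoss(reduction='sum')  # matches size_average=False above
    # dividing the summed divergence by N reproduces the per-sample average
    # used by criterion_teacher in gen_update
    return kd(log_p_student, teacher_probs) / student_logits.size(0)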
class SupIntrinsicTrainer(nn.Module):
    def __init__(self, param):
        super(SupIntrinsicTrainer, self).__init__()
        lr = param['lr']
        # Initiate the networks
        self.model = AdaINGen(param['input_dim_a'],
                              param['input_dim_b'] + param['input_dim_b'],
                              param['gen'])  # auto-encoder with a 6-channel output
        # Setup the optimizers
        beta1 = param['beta1']
        beta2 = param['beta2']
        gen_params = list(self.model.parameters())
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2),
                                        weight_decay=param['weight_decay'])
        self.gen_scheduler = get_scheduler(self.gen_opt, param)
        # Network weight initialization
        self.apply(weights_init(param['init']))
        self.best_result = float('inf')

    def recon_criterion(self, input, target, mask=None):
        # L1 loss, optionally restricted to the valid-pixel mask
        if mask is not None:
            return torch.mean(torch.abs(input[mask] - target[mask]))
        else:
            return torch.mean(torch.abs(input - target))

    def forward(self, x):
        self.eval()
        out = self.model(x)
        # split the 6-channel output: first 3 channels reflectance, last 3 shading
        x_r, x_s = out[:, :3, :, :], out[:, 3:, :, :]
        return x_r, x_s

    def gen_update(self, x_i, x_r, x_s, x_m, param):
        self.gen_opt.zero_grad()
        out = self.model(x_i)
        pred_r, pred_s = out[:, :3, :, :], out[:, 3:, :, :]
        # reconstruction loss
        self.loss_r = self.recon_criterion(pred_r, x_r, x_m)
        self.loss_s = self.recon_criterion(pred_s, x_s, x_m)
        # total loss
        self.loss_gen_total = self.loss_r + self.loss_s
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def sample(self, x_i, x_r, x_s):
        self.eval()
        x_ri, x_si = [], []
        for i in range(x_i.size(0)):
            out = self.model(x_i[i].unsqueeze(0))
            # keep the predictions in locals so the ground-truth x_r / x_s
            # arguments are not overwritten before being returned
            pred_r, pred_s = out[:, :3, :, :], out[:, 3:, :, :]
            x_ri.append(pred_r)
            x_si.append(pred_s)
        x_ri = torch.cat(x_ri)
        x_si = torch.cat(x_si)
        self.train()
        return x_i, x_r, x_ri, x_s, x_si

    # noinspection PyAttributeOutsideInit
    def dis_update(self, x_i, x_r, x_s, param=None):
        pass

    def update_learning_rate(self):
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, param):
        # Load generators (under the 'model' key written by save())
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.model.load_state_dict(state_dict['model'])
        self.best_result = state_dict.get('best_result', self.best_result)
        iterations = int(last_model_name[-11:-3])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.gen_scheduler = get_scheduler(self.gen_opt, param, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'model': self.model.state_dict(),
                    'best_result': self.best_result}, gen_name)
        torch.save({'gen': self.gen_opt.state_dict()}, opt_name)
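# SupIntrinsicTrainer predicts reflectance and shading in one 6-channel decoder
# output and supervises both with a masked L1. A small usage sketch (the tensor
# names below are illustrative, not taken from the file; x_m is assumed to be a
# boolean mask broadcastable over the 3-channel predictions):
def intrinsic_split_loss_example(out, x_r, x_s, x_m):
    # out: (N, 6, H, W); first 3 channels reflectance, last 3 shading
    pred_r, pred_s = out[:, :3, :, :], out[:, 3:, :, :]
    loss_r = torch.mean(torch.abs(pred_r[x_m] - x_r[x_m]))  # masked L1, as in recon_criterion
    loss_s = torch.mean(torch.abs(pred_s[x_m] - x_s[x_m]))
    return loss_r + loss_s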
class Solver(object):
    """Solver for training and testing StarGAN."""

    def __init__(self, face_loader, hyperparameters, opts):
        self.model_name = os.path.splitext(os.path.basename(opts.config))[0]
        self.face_loader = face_loader
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.style_dim = hyperparameters['gen']['style_dim']
        self.lr = hyperparameters['lr']
        self.label_dim = hyperparameters['label_dim']
        self.gen_var = hyperparameters['gen']
        self.dis_var = hyperparameters['dis']
        self.max_iter = hyperparameters['max_iter']
        self.lambda_cls = hyperparameters['lambda_cls']
        self.lambda_gp = hyperparameters['lambda_gp']
        # fix the noise used in sampling; torch.randn draws from a standard
        # normal distribution with the given shape
        batch_size = int(hyperparameters['batch_size'])
        # Setup the optimizers
        self.beta1 = hyperparameters['beta1']
        self.beta2 = hyperparameters['beta2']
        self.n_critic = hyperparameters['n_critic']
        self.lambda_rec = hyperparameters['lambda_rec']
        self.recon_s_w = hyperparameters['recon_s_w']
        self.recon_c_w = hyperparameters['recon_c_w']
        self.log_iter = hyperparameters['log_iter']
        self.image_save_iter = hyperparameters['image_save_iter']
        self.snapshot_save_iter = hyperparameters['snapshot_save_iter']
        self.output_path = opts.output_path
        self.model_save_dir = opts.model_save_dir
        self.step_size = hyperparameters['step_size']
        self.lr_update_step = hyperparameters['lr_update_step']
        self.result_dir = opts.result_dir
        self.num_style = opts.num_style
        self.log_path = opts.log_path
        self.build_model()

    def build_model(self):
        self.G = AdaINGen(self.label_dim, self.gen_var)
        self.D = Discriminator(self.label_dim, self.dis_var)
        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.lr, [self.beta1, self.beta2])
        self.G.to(self.device)
        self.D.to(self.device)

    def update_lr(self, g_lr, d_lr):
        """Decay learning rates of the generator and discriminator."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    def reset_grad(self):
        """Reset the gradient buffers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def denorm(self, x):
        """Convert the range from [-1, 1] to [0, 1]."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
        weight = torch.ones(y.size()).to(self.device)
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,  # backward() frees the graph by default
                                   create_graph=True,
                                   only_inputs=True)[0]
        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx ** 2, dim=1))
        return torch.mean((dydx_l2norm - 1) ** 2)

    def label2onehot(self, labels, dim):
        """Convert label indices to one-hot vectors."""
        batch_size = labels.size(0)
        out = torch.zeros(batch_size, dim)
        out[np.arange(batch_size), labels.long()] = 1
        return out

    def create_labels(self, c_org, c_dim=8):
        """Generate target domain labels for debugging and testing."""
        c_trg_list = []
        for i in range(c_dim):
            c_trg = self.label2onehot(torch.ones(c_org.size(0)) * i, c_dim)
            c_trg_list.append(c_trg.to(self.device))
        return c_trg_list

    def classification_loss(self, logit, target):
        return F.cross_entropy(logit, target)

    def recon_criterion(self, input, target):
        # mean absolute error over all elements
        return torch.mean(torch.abs(input - target))

    def train(self):
        writer = SummaryWriter(os.path.join(self.log_path, self.model_name))
        data_loader = self.face_loader
        data_iter = iter(data_loader)
        # x_fixed holds the image pixels, c_org the ground-truth labels:
        # one fixed batch for periodic visualization
        x_fixed, c_org = next(data_iter)
        x_fixed = x_fixed.to(self.device)
        c_fixed_list = self.create_labels(c_org, self.label_dim)
        g_lr = self.lr
        d_lr = self.lr
        start_iters = 0
        print('Start training...')
        start_time = time.time()
        for i in range(start_iters, self.max_iter):
            try:
                x_real, label_org = next(data_iter)
            except StopIteration:
                data_iter = iter(data_loader)
                x_real, label_org = next(data_iter)
            # Generate target domain labels randomly: randperm(n) returns a
            # random permutation of 0..n-1, used here to shuffle the batch labels
            rand_idx = torch.randperm(label_org.size(0))
            label_trg = label_org[rand_idx]
            c_org = self.label2onehot(label_org, self.label_dim)
            c_trg = self.label2onehot(label_trg, self.label_dim)
            x_real = x_real.to(self.device)        # Input images.
            c_org = c_org.to(self.device)          # Original domain labels.
            c_trg = c_trg.to(self.device)          # Target domain labels.
            label_org = label_org.to(self.device)  # Labels for computing classification loss.
            label_trg = label_trg.to(self.device)  # Labels for computing classification loss.
            # dis_update
            out_src, out_cls = self.D(x_real)
            d_loss_real = -torch.mean(out_src)  # minimizing this drives out_src toward 1 on real images
            d_loss_cls = self.classification_loss(out_cls, label_org)  # label term
            style = Variable(torch.randn(x_real.size(0), self.style_dim, 1, 1).to(self.device))
            # encode
            content_fake, _ = self.G.encode(x_real, c_trg)
            # decode
            x_fake = self.G.decode(content_fake, style, c_trg)
            out_src, out_cls = self.D(x_fake.detach())
            d_loss_fake = torch.mean(out_src)  # fake images should score 0
            alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
            x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)
            out_src, _ = self.D(x_hat)
            d_loss_gp = self.gradient_penalty(out_src, x_hat)  # in practice hovers around 0.9954-0.9956
            # Backward and optimize.
            d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls + self.lambda_gp * d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()
            # Logging.
            loss = {}
            loss['D/loss_real'] = d_loss_real.item()
            loss['D/loss_fake'] = d_loss_fake.item()
            loss['D/loss_cls'] = d_loss_cls.item()
            loss['D/loss_gp'] = d_loss_gp.item()
            # gen_update: update the generator once every n_critic (5) discriminator updates
            if (i + 1) % self.n_critic == 0:
                # encode
                content_real, style_real = self.G.encode(x_real, c_org)
                content_fake, style_fake = self.G.encode(x_real, c_trg)
                x_fake = self.G.decode(content_fake, style, c_trg)
                out_src, out_cls = self.D(x_fake)
                g_loss_fake = -torch.mean(out_src)
                g_loss_cls = self.classification_loss(out_cls, label_trg)  # the closer the prediction to the target label, the lower the loss
                x_recon = self.G.decode(content_real, style_real, c_org)
                g_loss_rec = torch.mean(torch.abs(x_real - x_recon))
                # encode again
                content_recon, style_recon = self.G.encode(x_fake, c_trg)
                # reconstruction loss
                loss_gen_recon_style = self.recon_criterion(style_recon, style)
                loss_gen_recon_content = self.recon_criterion(content_recon, content_fake)
                # Backward and optimize the generator.
                g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + \
                    self.lambda_cls * g_loss_cls + \
                    self.recon_s_w * loss_gen_recon_style + \
                    self.recon_c_w * loss_gen_recon_content
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()
                # Logging.
loss['G/loss_fake'] = g_loss_fake.item() loss['G/loss_rec'] = g_loss_rec.item() loss['G/loss_cls'] = g_loss_cls.item() loss['G/loss_style'] = loss_gen_recon_style.item() loss['G/loss_content'] = loss_gen_recon_content.item() # Miscellaneous if (i + 1) % self.log_iter == 0: et = time.time() - start_time et = str(datetime.timedelta(seconds=et))[:-7] log = "Elapsed [{}], Iteration [{}/{}]".format( et, i + 1, self.max_iter) for tag, value in loss.items(): log += ", {}: {:.4f}".format(tag, value) print(log) if (i + 1) % self.image_save_iter == 0: with torch.no_grad(): style1 = Variable( torch.randn(x_fixed.size(0), self.style_dim, 1, 1).to(self.device)) style2 = Variable( torch.randn(x_fixed.size(0), self.style_dim, 1, 1).to(self.device)) x_fake_list = [x_fixed] for c_fixed in c_fixed_list: content_fake, style_fake = self.G.encode( x_fixed, c_fixed) x_fake_list.append( self.G.decode(content_fake, style1, c_fixed)) x_fake_list.append( self.G.decode(content_fake, style2, c_fixed)) x_concat = torch.cat(x_fake_list, dim=3) sample_path = os.path.join(self.output_path, '{}-images.jpg'.format(i + 1)) save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0) print('Saved real and fake images into {}...'.format( sample_path)) # Save network weights if (i + 1) % self.snapshot_save_iter == 0: G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i + 1)) D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i + 1)) torch.save(self.G.state_dict(), G_path) torch.save(self.D.state_dict(), D_path) print('Saved model checkpoints into {}...'.format( self.model_save_dir)) # Decay learning rates. if (i + 1) % self.lr_update_step == 0 and (i + 1) > ( self.max_iter - self.step_size): g_lr -= (self.lr / float(self.step_size)) d_lr -= (self.lr / float(self.step_size)) self.update_lr(g_lr, d_lr) print('Decayed learning rates, g_lr: {}, d_lr: {}.'.format( g_lr, d_lr)) if (i + 1) % self.log_iter == 0: writer.add_scalar('D/loss_real', d_loss_real, i) writer.add_scalar('D/loss_fake', d_loss_fake, i) writer.add_scalar('D/loss_cls', d_loss_cls, i) writer.add_scalar('G/loss_fake', g_loss_fake, i) writer.add_scalar('G/loss_rec', g_loss_rec, i) writer.add_scalar('G/loss_cls', g_loss_cls, i) writer.add_scalar('G/loss_style', loss_gen_recon_style, i) writer.add_scalar('G/loss_content', loss_gen_recon_content, i) writer.add_scalars('data/scalar_group', { 'D/loss': d_loss, 'G/loss': g_loss }, i) writer.close() def test(self): """Translate images using StarGAN trained on a single dataset.""" # Load the trained generator. self.restore_model(self.max_iter) # Set data loader. data_loader = self.face_loader with torch.no_grad(): for i, (x_real, c_org) in enumerate(data_loader): # Prepare input images and target domain labels. x_real = x_real.to(self.device) c_trg_list = self.create_labels(c_org, 8) # Translate images. x_fake_list = [x_real] for c_trg in c_trg_list: content_fake, style_fake = self.G.encode(x_real, c_trg) style_rand = Variable( torch.randn(self.num_style, self.style_dim, 1, 1).cuda()) for j in range(self.num_style): s = style_rand[j].unsqueeze(0) x_fake_list.append( self.G.decode(content_fake, s, c_trg)) # Save the translated images. 
x_concat = torch.cat(x_fake_list, dim=3) result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i + 1)) save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0) print('Saved real and fake images into {}...'.format( result_path)) def restore_model(self, resume_iters): """Restore the trained generator and discriminator.""" print( 'Loading the trained models from step {}...'.format(resume_iters)) G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(resume_iters)) D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(resume_iters)) self.G.load_state_dict( torch.load(G_path, map_location=lambda storage, loc: storage)) self.D.load_state_dict( torch.load(D_path, map_location=lambda storage, loc: storage))
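# The Solver's discriminator update uses the WGAN-GP trick: penalize the
# gradient norm of D at random interpolates between real and fake samples.
# A condensed, standalone restatement of the steps used in train() above
# (illustrative; it assumes D returns an (src, cls) pair as in this file):
def gradient_penalty_example(D, x_real, x_fake, device):
    # random per-sample interpolation coefficients, broadcast over C, H, W
    alpha = torch.rand(x_real.size(0), 1, 1, 1, device=device)
    x_hat = (alpha * x_real + (1 - alpha) * x_fake).requires_grad_(True)
    out_src, _ = D(x_hat)
    grad = torch.autograd.grad(outputs=out_src, inputs=x_hat,
                               grad_outputs=torch.ones_like(out_src),
                               retain_graph=True, create_graph=True,
                               only_inputs=True)[0]
    grad = grad.view(grad.size(0), -1)
    # push the per-sample gradient norm toward 1
    return torch.mean((grad.norm(2, dim=1) - 1) ** 2)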
class MUNIT_Trainer(nn.Module): def __init__(self, hyperparameters, opts): super(MUNIT_Trainer, self).__init__() lr = hyperparameters['lr'] self.opts = opts # Initiate the networks self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen']) # auto-encoder for domain a self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen']) # auto-encoder for domain b self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis']) # discriminator for domain a self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis']) # discriminator for domain b self.seg = segmentor(num_classes=2, channels=hyperparameters['input_dim_b'], hyperpars=hyperparameters['seg']) self.instancenorm = nn.InstanceNorm2d(512, affine=False) self.style_dim = hyperparameters['gen']['style_dim'] # fix the noise used in sampling display_size = int(hyperparameters['display_size']) self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda() self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda() # Setup the optimizers beta1 = hyperparameters['beta1'] beta2 = hyperparameters['beta2'] dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters()) gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters()) self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad], lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad], lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) self.seg_opt = torch.optim.SGD(self.seg.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4) self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters['lr_policy'], hyperparameters=hyperparameters) self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters['lr_policy'], hyperparameters=hyperparameters) self.seg_scheduler = get_scheduler(self.seg_opt, 'constant', hyperparameters=None) # Network weight initialization self.apply(weights_init(hyperparameters['init'])) self.dis_a.apply(weights_init('gaussian')) self.dis_b.apply(weights_init('gaussian')) self.criterion_seg = DiceLoss(ignore_index=hyperparameters['seg']['ignore_index']) def recon_criterion(self, input, target): return torch.mean(torch.abs(input - target)) def forward(self, x_a, x_b): self.eval() s_a = Variable(self.s_a) s_b = Variable(self.s_b) c_a, s_a_fake = self.gen_a.encode(x_a) c_b, s_b_fake = self.gen_b.encode(x_b) x_ba = self.gen_a.decode(c_b, s_a) x_ab = self.gen_b.decode(c_a, s_b) self.train() return x_ab, x_ba def gen_update(self, x_a, x_b, hyperparameters, target_a, iters): self.gen_opt.zero_grad() s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()) s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()) # encode c_a, s_a_prime = self.gen_a.encode(x_a) c_b, s_b_prime = self.gen_b.encode(x_b) # decode (within domain) x_a_recon = self.gen_a.decode(c_a, s_a_prime) x_b_recon = self.gen_b.decode(c_b, s_b_prime) # decode (cross domain) x_ba = self.gen_a.decode(c_b, s_a) x_ab = self.gen_b.decode(c_a, s_b) # encode again c_b_recon, s_a_recon = self.gen_a.encode(x_ba) c_a_recon, s_b_recon = self.gen_b.encode(x_ab) # decode again (if needed) x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None if iters >= hyperparameters['guide_gen_iters']: config.task = 0 self.seg.eval() 
self.pred_x_ab = self.seg(x_ab) self.seg.train() # reconstruction loss self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a) self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b) self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a) self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b) self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a) self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b) self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0 self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0 # GAN loss self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba) self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab) # semantic loss ab if iters >= hyperparameters['guide_gen_iters']: self.loss_sem_ab, _ = self.criterion_seg(self.pred_x_ab, target_a) else: self.loss_sem_ab = 0 # only use semantic loss when segmentor has reasonably low loss if not hasattr(self, 'loss_seg_ab') or self.loss_seg_ab.detach().item() > -0.3: self.loss_sem_ab = 0 # total loss self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \ hyperparameters['gan_w'] * self.loss_gen_adv_b + \ hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \ hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \ hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \ hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \ hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \ hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \ hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \ hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \ hyperparameters['sem_w'] * self.loss_sem_ab self.loss_gen_total.backward() self.gen_opt.step() def dis_update(self, x_a, x_b, hyperparameters): self.dis_opt.zero_grad() s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()) s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()) # encode c_a, _ = self.gen_a.encode(x_a) c_b, _ = self.gen_b.encode(x_b) # decode (cross domain) x_ba = self.gen_a.decode(c_b, s_a) x_ab = self.gen_b.decode(c_a, s_b) # D loss self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a) self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b) self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b self.loss_dis_total.backward() self.dis_opt.step() def seg_update(self, x_a, x_b, target_a, target_b): self.seg.train() self.seg_opt.zero_grad() s_b = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()) with torch.no_grad(): # encode c_a, _ = self.gen_a.encode(x_a) # decode (cross domain) x_ab = self.gen_b.decode(c_a, s_b) config.task = 0 self.pred_x_ab = self.seg(x_ab.detach()) config.task = 1 self.pred_x_b = self.seg(x_b) self.loss_seg_ab, _ = self.criterion_seg(self.pred_x_ab, target_a) self.loss_seg_b, _ = self.criterion_seg(self.pred_x_b, target_b) self.loss_seg_total = self.loss_seg_ab + self.loss_seg_b self.loss_seg_total.backward() self.seg_opt.step() def update_learning_rate(self): if self.dis_scheduler is not None: self.dis_scheduler.step() if self.gen_scheduler is not None: self.gen_scheduler.step() if self.seg_scheduler is not None: self.seg_scheduler.step() def sample(self, x_a, x_b): self.eval() s_a1 = Variable(self.s_a) s_b1 = Variable(self.s_b) s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()) s_b2 = Variable(torch.randn(x_b.size(0), 
self.style_dim, 1, 1).cuda()) x_a_recon, x_b_recon, x_aba, x_bab, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], [], [], [] for i in range(x_b.size(0)): # encode c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0)) c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0)) # decode (within domain) x_a_recon.append(self.gen_a.decode(c_a, s_a_fake)) x_b_recon.append(self.gen_b.decode(c_b, s_b_fake)) # decode (cross domain) x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0))) x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0))) x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0))) x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0))) # encode again c_b_recon, s_a_recon = self.gen_a.encode(x_ba1[-1]) c_a_recon, s_b_recon = self.gen_b.encode(x_ab1[-1]) x_aba.append(self.gen_a.decode(c_a_recon, s_a_fake)) x_bab.append(self.gen_b.decode(c_b_recon, s_b_fake)) x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon) x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2) x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2) x_aba, x_bab = torch.cat(x_aba), torch.cat(x_bab) self.train() return x_a, x_a_recon, x_aba, x_ab1, x_ab2, x_b, x_b_recon, x_bab, x_ba1, x_ba2 def resume(self, checkpoint_dir, hyperparameters): # Load generators last_model_name = get_model_list(checkpoint_dir, 'gen') state_dict = torch.load(last_model_name) self.gen_a.load_state_dict(state_dict['a']) self.gen_b.load_state_dict(state_dict['b']) iterations = int(last_model_name[-11:-3]) # Load discriminators last_model_name = get_model_list(checkpoint_dir, 'dis') state_dict = torch.load(last_model_name) self.dis_a.load_state_dict(state_dict['a']) self.dis_b.load_state_dict(state_dict['b']) # Load segmentor last_model_name = get_model_list(checkpoint_dir, 'seg') state_dict = torch.load(last_model_name) self.seg.load_state_dict(state_dict) # Load optimizers state_dict = torch.load(os.path.join(checkpoint_dir, 'opt.pt')) self.dis_opt.load_state_dict(state_dict['dis']) self.gen_opt.load_state_dict(state_dict['gen']) state_dict = torch.load(os.path.join(checkpoint_dir, 'opt_seg.pt')) self.seg_opt.load_state_dict(state_dict) # Reinitialize schedulers self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters['lr_policy'], hyperparameters, iterations) self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters['lr_policy'], hyperparameters, iterations) self.seg_scheduler = get_scheduler(self.seg_opt, 'constant', None, iterations) print('Resume from iteration %d' % iterations) return iterations def save(self, snapshot_dir, iterations): # Save generators, discriminators, and optimizers gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1)) dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1)) seg_name = os.path.join(snapshot_dir, 'seg_%08d.pt' % (iterations + 1)) opt_name = os.path.join(snapshot_dir, 'opt.pt') opt_seg_name = os.path.join(snapshot_dir, 'opt_seg.pt') torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name) torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name) torch.save(self.seg.state_dict(), seg_name) torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name) torch.save(self.seg_opt.state_dict(), opt_seg_name)
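# A minimal sketch of how this segmentation-guided trainer is typically driven.
# The loader names, the `config` dict and the CUDA placement are illustrative
# assumptions, not definitions from this file.
#
#   trainer = MUNIT_Trainer(config, opts).cuda()
#   for iters, ((x_a, target_a), (x_b, target_b)) in enumerate(zip(loader_a, loader_b)):
#       x_a, target_a = x_a.cuda(), target_a.cuda()
#       x_b, target_b = x_b.cuda(), target_b.cuda()
#       trainer.dis_update(x_a, x_b, config)
#       trainer.gen_update(x_a, x_b, config, target_a, iters)  # semantic loss kicks in after guide_gen_iters
#       trainer.seg_update(x_a, x_b, target_a, target_b)       # trains the segmentor on x_ab and x_b
#       trainer.update_learning_rate()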
class DGNet_Trainer(nn.Module): def __init__(self, hyperparameters): super(DGNet_Trainer, self).__init__() lr_g = hyperparameters['lr_g'] lr_d = hyperparameters['lr_d'] ID_class = hyperparameters['ID_class'] if not 'apex' in hyperparameters.keys(): hyperparameters['apex'] = False self.fp16 = hyperparameters['apex'] # Initiate the networks # We do not need to manually set fp16 in the network for the new apex. So here I set fp16=False. self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'], fp16=False) # auto-encoder for domain a self.gen_b = self.gen_a # auto-encoder for domain b if not 'ID_stride' in hyperparameters.keys(): hyperparameters['ID_stride'] = 2 if hyperparameters['ID_style'] == 'PCB': self.id_a = PCB(ID_class) elif hyperparameters['ID_style'] == 'AB': self.id_a = ft_netAB(ID_class, stride=hyperparameters['ID_stride'], norm=hyperparameters['norm_id'], pool=hyperparameters['pool']) else: self.id_a = ft_net(ID_class, norm=hyperparameters['norm_id'], pool=hyperparameters['pool']) # return 2048 now self.id_b = self.id_a self.dis_a = MsImageDis(3, hyperparameters['dis'], fp16=False) # discriminator for domain a self.dis_b = self.dis_a # discriminator for domain b # load teachers if hyperparameters['teacher'] != "": teacher_name = hyperparameters['teacher'] print(teacher_name) teacher_names = teacher_name.split(',') teacher_model = nn.ModuleList() teacher_count = 0 for teacher_name in teacher_names: config_tmp = load_config(teacher_name) if 'stride' in config_tmp: stride = config_tmp['stride'] else: stride = 2 model_tmp = ft_net(ID_class, stride=stride) teacher_model_tmp = load_network(model_tmp, teacher_name) teacher_model_tmp.model.fc = nn.Sequential( ) # remove the original fc layer in ImageNet teacher_model_tmp = teacher_model_tmp.cuda() if self.fp16: teacher_model_tmp = amp.initialize(teacher_model_tmp, opt_level="O1") teacher_model.append(teacher_model_tmp.cuda().eval()) teacher_count += 1 self.teacher_model = teacher_model if hyperparameters['train_bn']: self.teacher_model = self.teacher_model.apply(train_bn) self.instancenorm = nn.InstanceNorm2d(512, affine=False) display_size = int(hyperparameters['display_size']) # RGB to one channel if hyperparameters['single'] == 'edge': self.single = to_edge else: self.single = to_gray(False) # Random Erasing when training if not 'erasing_p' in hyperparameters.keys(): hyperparameters['erasing_p'] = 0 self.single_re = RandomErasing( probability=hyperparameters['erasing_p'], mean=[0.0, 0.0, 0.0]) if not 'T_w' in hyperparameters.keys(): hyperparameters['T_w'] = 1 # Setup the optimizers beta1 = hyperparameters['beta1'] beta2 = hyperparameters['beta2'] dis_params = list( self.dis_a.parameters()) #+ list(self.dis_b.parameters()) gen_params = list( self.gen_a.parameters()) #+ list(self.gen_b.parameters()) self.dis_opt = torch.optim.Adam( [p for p in dis_params if p.requires_grad], lr=lr_d, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) self.gen_opt = torch.optim.Adam( [p for p in gen_params if p.requires_grad], lr=lr_g, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) # id params if hyperparameters['ID_style'] == 'PCB': ignored_params = ( list(map(id, self.id_a.classifier0.parameters())) + list(map(id, self.id_a.classifier1.parameters())) + list(map(id, self.id_a.classifier2.parameters())) + list(map(id, self.id_a.classifier3.parameters()))) base_params = filter(lambda p: id(p) not in ignored_params, self.id_a.parameters()) lr2 = hyperparameters['lr2'] self.id_opt = torch.optim.SGD( 
[{ 'params': base_params, 'lr': lr2 }, { 'params': self.id_a.classifier0.parameters(), 'lr': lr2 * 10 }, { 'params': self.id_a.classifier1.parameters(), 'lr': lr2 * 10 }, { 'params': self.id_a.classifier2.parameters(), 'lr': lr2 * 10 }, { 'params': self.id_a.classifier3.parameters(), 'lr': lr2 * 10 }], weight_decay=hyperparameters['weight_decay'], momentum=0.9, nesterov=True) elif hyperparameters['ID_style'] == 'AB': ignored_params = ( list(map(id, self.id_a.classifier1.parameters())) + list(map(id, self.id_a.classifier2.parameters()))) base_params = filter(lambda p: id(p) not in ignored_params, self.id_a.parameters()) lr2 = hyperparameters['lr2'] self.id_opt = torch.optim.SGD( [{ 'params': base_params, 'lr': lr2 }, { 'params': self.id_a.classifier1.parameters(), 'lr': lr2 * 10 }, { 'params': self.id_a.classifier2.parameters(), 'lr': lr2 * 10 }], weight_decay=hyperparameters['weight_decay'], momentum=0.9, nesterov=True) else: ignored_params = list(map(id, self.id_a.classifier.parameters())) base_params = filter(lambda p: id(p) not in ignored_params, self.id_a.parameters()) lr2 = hyperparameters['lr2'] self.id_opt = torch.optim.SGD( [{ 'params': base_params, 'lr': lr2 }, { 'params': self.id_a.classifier.parameters(), 'lr': lr2 * 10 }], weight_decay=hyperparameters['weight_decay'], momentum=0.9, nesterov=True) self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters) self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters) self.id_scheduler = get_scheduler(self.id_opt, hyperparameters) self.id_scheduler.gamma = hyperparameters['gamma2'] #ID Loss self.id_criterion = nn.CrossEntropyLoss() self.criterion_teacher = nn.KLDivLoss(size_average=False) # Load VGG model if needed if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0: self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models') self.vgg.eval() for param in self.vgg.parameters(): param.requires_grad = False # save memory if self.fp16: # Name the FP16_Optimizer instance to replace the existing optimizer assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled." 
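# Under opt_level "O1" apex patches torch functions for mixed precision rather
# than rewriting the modules themselves, so the gen_b/dis_b/id_b aliases set
# below (before amp.initialize wraps each network together with its optimizer)
# keep pointing at the same, now amp-managed, objects.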
self.gen_a = self.gen_a.cuda() self.dis_a = self.dis_a.cuda() self.id_a = self.id_a.cuda() self.gen_b = self.gen_a self.dis_b = self.dis_a self.id_b = self.id_a self.gen_a, self.gen_opt = amp.initialize(self.gen_a, self.gen_opt, opt_level="O1") self.dis_a, self.dis_opt = amp.initialize(self.dis_a, self.dis_opt, opt_level="O1") self.id_a, self.id_opt = amp.initialize(self.id_a, self.id_opt, opt_level="O1") def to_re(self, x): out = torch.FloatTensor(x.size(0), x.size(1), x.size(2), x.size(3)) out = out.cuda() for i in range(x.size(0)): out[i, :, :, :] = self.single_re(x[i, :, :, :]) return out def recon_criterion(self, input, target): diff = input - target.detach() return torch.mean(torch.abs(diff[:])) def recon_criterion_sqrt(self, input, target): diff = input - target return torch.mean(torch.sqrt(torch.abs(diff[:]) + 1e-8)) def recon_criterion2(self, input, target): diff = input - target return torch.mean(diff[:]**2) def recon_cos(self, input, target): cos = torch.nn.CosineSimilarity() cos_dis = 1 - cos(input, target) return torch.mean(cos_dis[:]) def forward(self, x_a, x_b): self.eval() s_a = self.gen_a.encode(self.single(x_a)) s_b = self.gen_b.encode(self.single(x_b)) f_a, _ = self.id_a(scale2(x_a)) f_b, _ = self.id_b(scale2(x_b)) x_ba = self.gen_a.decode(s_b, f_a) x_ab = self.gen_b.decode(s_a, f_b) self.train() return x_ab, x_ba def gen_update(self, x_a, l_a, xp_a, x_b, l_b, xp_b, hyperparameters, iteration): # ppa, ppb is the same person self.gen_opt.zero_grad() self.id_opt.zero_grad() # encode s_a = self.gen_a.encode(self.single(x_a)) s_b = self.gen_b.encode(self.single(x_b)) f_a, p_a = self.id_a(scale2(x_a)) f_b, p_b = self.id_b(scale2(x_b)) # autodecode x_a_recon = self.gen_a.decode(s_a, f_a) x_b_recon = self.gen_b.decode(s_b, f_b) # encode the same ID different photo fp_a, pp_a = self.id_a(scale2(xp_a)) fp_b, pp_b = self.id_b(scale2(xp_b)) # decode the same person x_a_recon_p = self.gen_a.decode(s_a, fp_a) x_b_recon_p = self.gen_b.decode(s_b, fp_b) # has gradient x_ba = self.gen_a.decode(s_b, f_a) x_ab = self.gen_b.decode(s_a, f_b) # no gradient x_ba_copy = Variable(x_ba.data, requires_grad=False) x_ab_copy = Variable(x_ab.data, requires_grad=False) rand_num = random.uniform(0, 1) ################################# # encode structure if hyperparameters['use_encoder_again'] >= rand_num: # encode again (encoder is tuned, input is fixed) s_a_recon = self.gen_b.enc_content(self.single(x_ab_copy)) s_b_recon = self.gen_a.enc_content(self.single(x_ba_copy)) else: # copy the encoder self.enc_content_copy = copy.deepcopy(self.gen_a.enc_content) self.enc_content_copy = self.enc_content_copy.eval() # encode again (encoder is fixed, input is tuned) s_a_recon = self.enc_content_copy(self.single(x_ab)) s_b_recon = self.enc_content_copy(self.single(x_ba)) ################################# # encode appearance self.id_a_copy = copy.deepcopy(self.id_a) self.id_a_copy = self.id_a_copy.eval() if hyperparameters['train_bn']: self.id_a_copy = self.id_a_copy.apply(train_bn) self.id_b_copy = self.id_a_copy # encode again (encoder is fixed, input is tuned) f_a_recon, p_a_recon = self.id_a_copy(scale2(x_ba)) f_b_recon, p_b_recon = self.id_b_copy(scale2(x_ab)) # teacher Loss # Tune the ID model log_sm = nn.LogSoftmax(dim=1) if hyperparameters['teacher_w'] > 0 and hyperparameters[ 'teacher'] != "": if hyperparameters['ID_style'] == 'normal': _, p_a_student = self.id_a(scale2(x_ba_copy)) p_a_student = log_sm(p_a_student) p_a_teacher = predict_label(self.teacher_model, scale2(x_ba_copy)) self.loss_teacher = 
self.criterion_teacher( p_a_student, p_a_teacher) / p_a_student.size(0) _, p_b_student = self.id_b(scale2(x_ab_copy)) p_b_student = log_sm(p_b_student) p_b_teacher = predict_label(self.teacher_model, scale2(x_ab_copy)) self.loss_teacher += self.criterion_teacher( p_b_student, p_b_teacher) / p_b_student.size(0) elif hyperparameters['ID_style'] == 'AB': # normal teacher-student loss # BA -> LabelA(smooth) + LabelB(batchB) _, p_ba_student = self.id_a(scale2(x_ba_copy)) # f_a, s_b p_a_student = log_sm(p_ba_student[0]) with torch.no_grad(): p_a_teacher = predict_label( self.teacher_model, scale2(x_ba_copy), num_class=hyperparameters['ID_class'], alabel=l_a, slabel=l_b, teacher_style=hyperparameters['teacher_style']) self.loss_teacher = self.criterion_teacher( p_a_student, p_a_teacher) / p_a_student.size(0) _, p_ab_student = self.id_b(scale2(x_ab_copy)) # f_b, s_a p_b_student = log_sm(p_ab_student[0]) with torch.no_grad(): p_b_teacher = predict_label( self.teacher_model, scale2(x_ab_copy), num_class=hyperparameters['ID_class'], alabel=l_b, slabel=l_a, teacher_style=hyperparameters['teacher_style']) self.loss_teacher += self.criterion_teacher( p_b_student, p_b_teacher) / p_b_student.size(0) # branch b loss # here we give different label loss_B = self.id_criterion(p_ba_student[1], l_b) + self.id_criterion( p_ab_student[1], l_a) self.loss_teacher = hyperparameters[ 'T_w'] * self.loss_teacher + hyperparameters['B_w'] * loss_B else: self.loss_teacher = 0.0 # decode again (if needed) if hyperparameters['use_decoder_again']: x_aba = self.gen_a.decode( s_a_recon, f_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None x_bab = self.gen_b.decode( s_b_recon, f_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None else: self.mlp_w_copy = copy.deepcopy(self.gen_a.mlp_w) self.mlp_b_copy = copy.deepcopy(self.gen_a.mlp_b) self.dec_copy = copy.deepcopy(self.gen_a.dec) # Error ID = f_a_recon ID_Style = ID.view(ID.shape[0], ID.shape[1], 1, 1) adain_params_w = self.mlp_w_copy(ID_Style) adain_params_b = self.mlp_b_copy(ID_Style) self.gen_a.assign_adain_params(adain_params_w, adain_params_b, self.dec_copy) x_aba = self.dec_copy( s_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None ID = f_b_recon ID_Style = ID.view(ID.shape[0], ID.shape[1], 1, 1) adain_params_w = self.mlp_w_copy(ID_Style) adain_params_b = self.mlp_b_copy(ID_Style) self.gen_a.assign_adain_params(adain_params_w, adain_params_b, self.dec_copy) x_bab = self.dec_copy( s_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None # auto-encoder image reconstruction self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a) self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b) self.loss_gen_recon_xp_a = self.recon_criterion(x_a_recon_p, x_a) self.loss_gen_recon_xp_b = self.recon_criterion(x_b_recon_p, x_b) # feature reconstruction self.loss_gen_recon_s_a = self.recon_criterion( s_a_recon, s_a) if hyperparameters['recon_s_w'] > 0 else 0 self.loss_gen_recon_s_b = self.recon_criterion( s_b_recon, s_b) if hyperparameters['recon_s_w'] > 0 else 0 self.loss_gen_recon_f_a = self.recon_criterion( f_a_recon, f_a) if hyperparameters['recon_f_w'] > 0 else 0 self.loss_gen_recon_f_b = self.recon_criterion( f_b_recon, f_b) if hyperparameters['recon_f_w'] > 0 else 0 # Random Erasing only affects the ID and PID loss.
if hyperparameters['erasing_p'] > 0: x_a_re = self.to_re(scale2(x_a.clone())) x_b_re = self.to_re(scale2(x_b.clone())) xp_a_re = self.to_re(scale2(xp_a.clone())) xp_b_re = self.to_re(scale2(xp_b.clone())) _, p_a = self.id_a(x_a_re) _, p_b = self.id_b(x_b_re) # encode the same ID different photo _, pp_a = self.id_a(xp_a_re) _, pp_b = self.id_b(xp_b_re) # ID loss AND Tune the Generated image if hyperparameters['ID_style'] == 'PCB': self.loss_id = self.PCB_loss(p_a, l_a) + self.PCB_loss(p_b, l_b) self.loss_pid = self.PCB_loss(pp_a, l_a) + self.PCB_loss(pp_b, l_b) self.loss_gen_recon_id = self.PCB_loss( p_a_recon, l_a) + self.PCB_loss(p_b_recon, l_b) elif hyperparameters['ID_style'] == 'AB': weight_B = hyperparameters['teacher_w'] * hyperparameters['B_w'] self.loss_id = self.id_criterion(p_a[0], l_a) + self.id_criterion(p_b[0], l_b) \ + weight_B * ( self.id_criterion(p_a[1], l_a) + self.id_criterion(p_b[1], l_b) ) self.loss_pid = self.id_criterion(pp_a[0], l_a) + self.id_criterion( pp_b[0], l_b ) #+ weight_B * ( self.id_criterion(pp_a[1], l_a) + self.id_criterion(pp_b[1], l_b) ) self.loss_gen_recon_id = self.id_criterion( p_a_recon[0], l_a) + self.id_criterion(p_b_recon[0], l_b) else: self.loss_id = self.id_criterion(p_a, l_a) + self.id_criterion( p_b, l_b) self.loss_pid = self.id_criterion(pp_a, l_a) + self.id_criterion( pp_b, l_b) self.loss_gen_recon_id = self.id_criterion( p_a_recon, l_a) + self.id_criterion(p_b_recon, l_b) #print(f_a_recon, f_a) self.loss_gen_cycrecon_x_a = self.recon_criterion( x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0 self.loss_gen_cycrecon_x_b = self.recon_criterion( x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0 # GAN loss self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba) self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab) # domain-invariant perceptual loss self.loss_gen_vgg_a = self.compute_vgg_loss( self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0 self.loss_gen_vgg_b = self.compute_vgg_loss( self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0 if iteration > hyperparameters['warm_iter']: hyperparameters['recon_f_w'] += hyperparameters['warm_scale'] hyperparameters['recon_f_w'] = min(hyperparameters['recon_f_w'], hyperparameters['max_w']) hyperparameters['recon_s_w'] += hyperparameters['warm_scale'] hyperparameters['recon_s_w'] = min(hyperparameters['recon_s_w'], hyperparameters['max_w']) hyperparameters['recon_x_cyc_w'] += hyperparameters['warm_scale'] hyperparameters['recon_x_cyc_w'] = min( hyperparameters['recon_x_cyc_w'], hyperparameters['max_cyc_w']) if iteration > hyperparameters['warm_teacher_iter']: hyperparameters['teacher_w'] += hyperparameters['warm_scale'] hyperparameters['teacher_w'] = min( hyperparameters['teacher_w'], hyperparameters['max_teacher_w']) # total loss self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \ hyperparameters['gan_w'] * self.loss_gen_adv_b + \ hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \ hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_a + \ hyperparameters['recon_f_w'] * self.loss_gen_recon_f_a + \ hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \ hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \ hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_b + \ hyperparameters['recon_f_w'] * self.loss_gen_recon_f_b + \ hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \ hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \ hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \ 
hyperparameters['id_w'] * self.loss_id + \ hyperparameters['pid_w'] * self.loss_pid + \ hyperparameters['recon_id_w'] * self.loss_gen_recon_id + \ hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \ hyperparameters['vgg_w'] * self.loss_gen_vgg_b + \ hyperparameters['teacher_w'] * self.loss_teacher if self.fp16: with amp.scale_loss(self.loss_gen_total, [self.gen_opt, self.id_opt]) as scaled_loss: scaled_loss.backward() self.gen_opt.step() self.id_opt.step() else: self.loss_gen_total.backward() self.gen_opt.step() self.id_opt.step() print("L_total: %.4f, L_gan: %.4f, Lx: %.4f, Lxp: %.4f, Lrecycle:%.4f, Lf: %.4f, Ls: %.4f, Recon-id: %.4f, id: %.4f, pid:%.4f, teacher: %.4f"%( self.loss_gen_total, \ hyperparameters['gan_w'] * (self.loss_gen_adv_a + self.loss_gen_adv_b), \ hyperparameters['recon_x_w'] * (self.loss_gen_recon_x_a + self.loss_gen_recon_x_b), \ hyperparameters['recon_xp_w'] * (self.loss_gen_recon_xp_a + self.loss_gen_recon_xp_b), \ hyperparameters['recon_x_cyc_w'] * (self.loss_gen_cycrecon_x_a + self.loss_gen_cycrecon_x_b), \ hyperparameters['recon_f_w'] * (self.loss_gen_recon_f_a + self.loss_gen_recon_f_b), \ hyperparameters['recon_s_w'] * (self.loss_gen_recon_s_a + self.loss_gen_recon_s_b), \ hyperparameters['recon_id_w'] * self.loss_gen_recon_id, \ hyperparameters['id_w'] * self.loss_id,\ hyperparameters['pid_w'] * self.loss_pid,\ hyperparameters['teacher_w'] * self.loss_teacher ) ) def compute_vgg_loss(self, vgg, img, target): img_vgg = vgg_preprocess(img) target_vgg = vgg_preprocess(target) img_fea = vgg(img_vgg) target_fea = vgg(target_vgg) return torch.mean( (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2) def PCB_loss(self, inputs, labels): loss = 0.0 for part in inputs: loss += self.id_criterion(part, labels) return loss / len(inputs) def sample(self, x_a, x_b): self.eval() x_a_recon, x_b_recon, x_ba1, x_ab1, x_aba, x_bab = [], [], [], [], [], [] for i in range(x_a.size(0)): s_a = self.gen_a.encode(self.single(x_a[i].unsqueeze(0))) s_b = self.gen_b.encode(self.single(x_b[i].unsqueeze(0))) f_a, _ = self.id_a(scale2(x_a[i].unsqueeze(0))) f_b, _ = self.id_b(scale2(x_b[i].unsqueeze(0))) x_a_recon.append(self.gen_a.decode(s_a, f_a)) x_b_recon.append(self.gen_b.decode(s_b, f_b)) x_ba = self.gen_a.decode(s_b, f_a) x_ab = self.gen_b.decode(s_a, f_b) x_ba1.append(x_ba) x_ab1.append(x_ab) #cycle s_b_recon = self.gen_a.enc_content(self.single(x_ba)) s_a_recon = self.gen_b.enc_content(self.single(x_ab)) f_a_recon, _ = self.id_a(scale2(x_ba)) f_b_recon, _ = self.id_b(scale2(x_ab)) x_aba.append(self.gen_a.decode(s_a_recon, f_a_recon)) x_bab.append(self.gen_b.decode(s_b_recon, f_b_recon)) x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon) x_aba, x_bab = torch.cat(x_aba), torch.cat(x_bab) x_ba1, x_ab1 = torch.cat(x_ba1), torch.cat(x_ab1) self.train() return x_a, x_a_recon, x_aba, x_ab1, x_b, x_b_recon, x_bab, x_ba1 def dis_update(self, x_a, x_b, hyperparameters): self.dis_opt.zero_grad() # encode s_a = self.gen_a.encode(self.single(x_a)) s_b = self.gen_b.encode(self.single(x_b)) f_a, _ = self.id_a(scale2(x_a)) f_b, _ = self.id_b(scale2(x_b)) # decode (cross domain) x_ba = self.gen_a.decode(s_b, f_a) x_ab = self.gen_b.decode(s_a, f_b) # D loss self.loss_dis_a, reg_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a) self.loss_dis_b, reg_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b) self.loss_dis_total = hyperparameters[ 'gan_w'] * self.loss_dis_a + hyperparameters[ 'gan_w'] * self.loss_dis_b print("DLoss: %.4f" % self.loss_dis_total, "Reg: %.4f" % (reg_a + 
reg_b)) if self.fp16: with amp.scale_loss(self.loss_dis_total, self.dis_opt) as scaled_loss: scaled_loss.backward() else: self.loss_dis_total.backward() self.dis_opt.step() def update_learning_rate(self): if self.dis_scheduler is not None: self.dis_scheduler.step() if self.gen_scheduler is not None: self.gen_scheduler.step() if self.id_scheduler is not None: self.id_scheduler.step() def resume(self, checkpoint_dir, hyperparameters): # Load generators last_model_name = get_model_list(checkpoint_dir, "gen") state_dict = torch.load(last_model_name) self.gen_a.load_state_dict(state_dict['a']) self.gen_b = self.gen_a iterations = int(last_model_name[-11:-3]) # Load discriminators last_model_name = get_model_list(checkpoint_dir, "dis") state_dict = torch.load(last_model_name) self.dis_a.load_state_dict(state_dict['a']) self.dis_b = self.dis_a # Load ID dis last_model_name = get_model_list(checkpoint_dir, "id") state_dict = torch.load(last_model_name) self.id_a.load_state_dict(state_dict['a']) self.id_b = self.id_a # Load optimizers try: state_dict = torch.load( os.path.join(checkpoint_dir, 'optimizer.pt')) self.dis_opt.load_state_dict(state_dict['dis']) self.gen_opt.load_state_dict(state_dict['gen']) self.id_opt.load_state_dict(state_dict['id']) except Exception: pass # Reinitialize schedulers self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations) self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations) print('Resume from iteration %d' % iterations) return iterations def save(self, snapshot_dir, iterations): # Save generators, discriminators, and optimizers gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1)) dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1)) id_name = os.path.join(snapshot_dir, 'id_%08d.pt' % (iterations + 1)) opt_name = os.path.join(snapshot_dir, 'optimizer.pt') torch.save({'a': self.gen_a.state_dict()}, gen_name) torch.save({'a': self.dis_a.state_dict()}, dis_name) torch.save({'a': self.id_a.state_dict()}, id_name) torch.save( { 'gen': self.gen_opt.state_dict(), 'id': self.id_opt.state_dict(), 'dis': self.dis_opt.state_dict() }, opt_name)
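# A minimal sketch of one DG-Net training step. The loader is assumed to yield,
# for each domain, an image, its identity label, and a second image of the same
# identity (x, l, xp); the loader and `config` names are illustrative, not
# defined in this file.
#
#   trainer = DGNet_Trainer(config).cuda()
#   for iteration, (x_a, l_a, xp_a, x_b, l_b, xp_b) in enumerate(train_loader):
#       trainer.dis_update(x_a.cuda(), x_b.cuda(), config)
#       trainer.gen_update(x_a.cuda(), l_a.cuda(), xp_a.cuda(),
#                          x_b.cuda(), l_b.cuda(), xp_b.cuda(), config, iteration)
#       trainer.update_learning_rate()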
class MUNIT_Trainer(nn.Module): def __init__(self, hyperparameters): super(MUNIT_Trainer, self).__init__() lr = hyperparameters['lr'] # Initiate the networks self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen']) # auto-encoder for domain a self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen']) # auto-encoder for domain b self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis']) # discriminator for domain a self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis']) # discriminator for domain b self.instancenorm = nn.InstanceNorm2d(512, affine=False) self.style_dim = hyperparameters['gen']['style_dim'] # fix the noise used in sampling self.display_size = int(hyperparameters['display_size']) self.s_a = self.random_style() self.s_b = self.random_style() # Setup the optimizers beta1 = hyperparameters['beta1'] beta2 = hyperparameters['beta2'] dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters()) gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters()) self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad], lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad], lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters) self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters) # Network weight initialization self.apply(weights_init(hyperparameters['init'])) self.dis_a.apply(weights_init('gaussian')) self.dis_b.apply(weights_init('gaussian')) # Load VGG model if needed if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0: self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models') self.vgg.eval() for param in self.vgg.parameters(): param.requires_grad = False def recon_criterion(self, input, target): return torch.mean(torch.abs(input - target)) def forward(self, x_a, x_b): self.eval() s_a = Variable(self.s_a) s_b = Variable(self.s_b) c_a, s_a_fake = self.gen_a.encode(x_a) c_b, s_b_fake = self.gen_b.encode(x_b) x_ba = self.gen_a.decode(c_b, s_a) x_ab = self.gen_b.decode(c_a, s_b) self.train() return x_ab, x_ba def gen_update(self, x_a, x_b, hyperparameters): """ Args: x_a: Image domain A x_b: Image domain B hyperparameters: Returns: """ self.gen_opt.zero_grad() s_a = self.random_style(x_a) s_b = self.random_style(x_b) # encode c_a, s_a_prime = self.gen_a.encode(x_a) # c_a - content encoding, s_a_prime - style encoding c_b, s_b_prime = self.gen_b.encode(x_b) # decode (within domain) x_a_recon = self.gen_a.decode(c_a, s_a_prime) # x_a_recon - reconstruction from content and style vectors x_b_recon = self.gen_b.decode(c_b, s_b_prime) # decode (cross domain) x_ba = self.gen_a.decode(c_b, s_a) # content b, style a x_ab = self.gen_b.decode(c_a, s_b) # encode again c_b_recon, s_a_recon = self.gen_a.encode(x_ba) # encode to get content_b and style_a from cross domain image c_a_recon, s_b_recon = self.gen_b.encode(x_ab) # decode again (if needed) x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None # reconstruction loss self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a) self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b) self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a) 
self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b) self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a) self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b) self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0 self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0 # GAN loss self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba) self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab) # domain-invariant perceptual loss self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0 self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0 # total loss self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \ hyperparameters['gan_w'] * self.loss_gen_adv_b + \ hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \ hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \ hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \ hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \ hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \ hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \ hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \ hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \ hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \ hyperparameters['vgg_w'] * self.loss_gen_vgg_b self.loss_gen_total.backward() self.gen_opt.step() def compute_vgg_loss(self, vgg, img, target): img_vgg = vgg_preprocess(img) target_vgg = vgg_preprocess(target) img_fea = vgg(img_vgg) target_fea = vgg(target_vgg) return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2) def random_style(self, x=None, factor=1): dim = self.display_size if x is None else x.size(0) return Variable(torch.randn(dim, self.style_dim, 1, 1).cuda()) * factor def sample(self, x_a, x_b): """ Args: x_a: x_b: Returns: (tuple): domainA: original reconstruction A to B - fixed sample noise A to B - random noise """ self.eval() s_a1 = Variable(self.s_a) s_b1 = Variable(self.s_b) s_a2 = self.random_style(x_a, factor=5) s_b2 = self.random_style(x_b, factor=5) #print(s_a1, s_a2) x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], [] for i in range(x_a.size(0)): c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0)) c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0)) x_a_recon.append(self.gen_a.decode(c_a, s_a_fake)) x_b_recon.append(self.gen_b.decode(c_b, s_b_fake)) x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0))) x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0))) x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0))) x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0))) x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon) x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2) x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2) self.train() return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2 def sampleB_toA(self, x_b): """ Args: x_b: Returns: (tuple, length=batch): INPUT IMAGES to domain A """ self.eval() s_a2 = self.random_style(x_b) x_ba1 = [] for i in range(x_b.size(0)): # loop through batches c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0)) x_ba1.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0))) return x_ba1 def dis_update(self, x_a, x_b, hyperparameters): self.dis_opt.zero_grad() s_a = self.random_style(x_a) s_b = self.random_style(x_b) # encode c_a, _ = 
self.gen_a.encode(x_a) c_b, _ = self.gen_b.encode(x_b) # decode (cross domain) x_ba = self.gen_a.decode(c_b, s_a) x_ab = self.gen_b.decode(c_a, s_b) # D loss self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a) self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b) self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b self.loss_dis_total.backward() self.dis_opt.step() def update_learning_rate(self): if self.dis_scheduler is not None: self.dis_scheduler.step() if self.gen_scheduler is not None: self.gen_scheduler.step() def resume(self, checkpoint_dir, hyperparameters, gen_model=None, dis_model=None): # Load generators if gen_model is None: gen_model = get_model_list(checkpoint_dir, "gen") # last gen model gen_state_dict = torch.load(gen_model) self.gen_a.load_state_dict(gen_state_dict['a']) self.gen_b.load_state_dict(gen_state_dict['b']) iterations = int(gen_model[-11:-3]) # Load discriminators if dis_model is None: dis_model = get_model_list(checkpoint_dir, "dis") dis_state_dict = torch.load(dis_model) self.dis_a.load_state_dict(dis_state_dict['a']) self.dis_b.load_state_dict(dis_state_dict['b']) # Load optimizers state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt')) self.dis_opt.load_state_dict(state_dict['dis']) self.gen_opt.load_state_dict(state_dict['gen']) # Reinitialize schedulers self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations) self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations) print('Resume from iteration %d' % iterations) return iterations def load_model(self, checkpoint_dir, hyperparameters, iteration): from pathlib import Path gen_model = Path(checkpoint_dir) / f"gen_{iteration:08d}.pt" dis_model = Path(checkpoint_dir) / f"dis_{iteration:08d}.pt" return self.resume(checkpoint_dir, hyperparameters, gen_model.as_posix(), dis_model.as_posix()) def save(self, snapshot_dir, iterations): # Save generators, discriminators, and optimizers gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1)) dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1)) opt_name = os.path.join(snapshot_dir, 'optimizer.pt') torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name) torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name) torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
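# Illustrative use of the checkpoint helpers and the one-way B->A sampler added
# in this variant (the path, iteration number and tensors are assumptions):
#
#   trainer = MUNIT_Trainer(config).cuda()
#   trainer.load_model('outputs/checkpoints', config, iteration=100000)
#   with torch.no_grad():
#       fakes = trainer.sampleB_toA(x_b.cuda())  # list with one 1-image tensor per batch element
#   # note: sampleB_toA switches the trainer to eval mode and does not switch it back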
class ERGAN_Trainer(nn.Module): def __init__(self, hyperparameters): super(ERGAN_Trainer, self).__init__() lr_G = hyperparameters['lr_G'] lr_D = hyperparameters['lr_D'] print(lr_D, lr_G) self.fp16 = hyperparameters['fp16'] # Initiate the networks self.gen_a = AdaINGen( hyperparameters['input_dim_a'], hyperparameters['gen']) # auto-encoder for domain a self.gen_b = AdaINGen( hyperparameters['input_dim_b'], hyperparameters['gen']) # auto-encoder for domain b self.gen_b.enc_content = self.gen_a.enc_content # content share weight #self.gen_b.enc_style = self.gen_a.enc_style self.dis_a = MsImageDis( hyperparameters['input_dim_a'], hyperparameters['dis']) # discriminator for domain a self.dis_b = MsImageDis( hyperparameters['input_dim_b'], hyperparameters['dis']) # discriminator for domain b self.instancenorm = nn.InstanceNorm2d(512, affine=False) self.style_dim = hyperparameters['gen']['style_dim'] self.a = hyperparameters['gen']['new_size'] / 224 # fix the noise used in sampling display_size = int(hyperparameters['display_size']) self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda() self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda() # Setup the optimizers beta1 = hyperparameters['beta1'] beta2 = hyperparameters['beta2'] dis_params = list(self.dis_a.parameters()) + list( self.dis_b.parameters()) gen_params = list(self.gen_a.parameters()) + list( self.gen_b.parameters()) self.dis_opt = torch.optim.Adam( [p for p in dis_params if p.requires_grad], lr=lr_D, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) self.gen_opt = torch.optim.Adam( [p for p in gen_params if p.requires_grad], lr=lr_G, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters) self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters) # Network weight initialization self.apply(weights_init(hyperparameters['init'])) self.dis_a.apply(weights_init('gaussian')) self.dis_b.apply(weights_init('gaussian')) if self.fp16: self.gen_a = self.gen_a.cuda() self.dis_a = self.dis_a.cuda() self.gen_b = self.gen_b.cuda() self.dis_b = self.dis_b.cuda() self.gen_a, self.gen_opt = amp.initialize(self.gen_a, self.gen_opt, opt_level="O1") self.dis_a, self.dis_opt = amp.initialize(self.dis_a, self.dis_opt, opt_level="O1") # Load VGG model if needed if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0: self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models') self.vgg.eval() for param in self.vgg.parameters(): param.requires_grad = False def recon_criterion(self, input, target): input = input.type_as(target) return torch.mean(torch.abs(input - target)) def forward(self, x_a, x_b): self.eval() s_a = Variable(self.s_a) s_b = Variable(self.s_b) c_a, s_a_fake = self.gen_a.encode(x_a) c_b, s_b_fake = self.gen_b.encode(x_b) x_ba = self.gen_a.decode(c_b, s_a, x_b) x_ab = self.gen_b.decode(c_a, s_b, x_a) self.train() return x_ab, x_ba def gen_update(self, x_a, x_b, hyperparameters): self.gen_opt.zero_grad() #mask = torch.ones(x_a.shape).cuda() block_a = x_a.clone() block_b = x_b.clone() block_a[:, :, round(self.a * 92):round(self.a * 144), round(self.a * 48):round(self.a * 172)] = 0 block_b[:, :, round(self.a * 92):round(self.a * 144), round(self.a * 48):round(self.a * 172)] = 0 # encode c_a, s_a_prime = self.gen_a.encode(x_a) c_b, s_b_prime = self.gen_b.encode(x_b) # decode (within domain) x_a_recon = self.gen_a.decode(c_a, s_a_prime, x_a) x_b_recon = self.gen_b.decode(c_b, s_b_prime, x_b) # decode 
(cross domain) # decode random # x_ba_randn = self.gen_a.decode(c_b, s_a, x_b) # x_ab_randn = self.gen_b.decode(c_a, s_b, x_a) # decode real x_ba_real = self.gen_a.decode(c_b, s_a_prime, x_b) x_ab_real = self.gen_b.decode(c_a, s_b_prime, x_a) block_ba_real = x_ba_real.clone() block_ab_real = x_ab_real.clone() block_ba_real[:, :, round(self.a * 92):round(self.a * 144), round(self.a * 48):round(self.a * 172)] = 0 block_ab_real[:, :, round(self.a * 92):round(self.a * 144), round(self.a * 48):round(self.a * 172)] = 0 # encode again # c_b_recon, s_a_recon = self.gen_a.encode(x_ba_randn) # c_a_recon, s_b_recon = self.gen_b.encode(x_ab_randn) c_b_real_recon, s_a_prime_recon = self.gen_a.encode(x_ba_real) c_a_real_recon, s_b_prime_recon = self.gen_b.encode(x_ab_real) # decode again (if needed) x_aba = self.gen_a.decode( c_a_real_recon, s_a_prime, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else None x_bab = self.gen_b.decode( c_b_real_recon, s_b_prime, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else None # reconstruction loss self.loss_gen_recon_res_a = self.recon_criterion( block_ab_real, block_a) self.loss_gen_recon_res_b = self.recon_criterion( block_ba_real, block_b) self.loss_gen_recon_x_a_re = self.recon_criterion( x_a_recon[:, :, round(self.a * 92):round(self.a * 144), round(self.a * 48):round(self.a * 172)], x_a[:, :, round(self.a * 92):round(self.a * 144), round(self.a * 48):round(self.a * 172)]) self.loss_gen_recon_x_b_re = self.recon_criterion( x_b_recon[:, :, round(self.a * 92):round(self.a * 144), round(self.a * 48):round(self.a * 172)], x_b[:, :, round(self.a * 92):round(self.a * 144), round(self.a * 48):round(self.a * 172)] ) # both celebA and MeGlass: [92:144, 48:172] self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a) self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b) self.loss_gen_recon_s_a_prime = self.recon_criterion( s_a_prime_recon, s_a_prime) self.loss_gen_recon_s_b_prime = self.recon_criterion( s_b_prime_recon, s_b_prime) self.loss_gen_recon_c_a_real = self.recon_criterion( c_a_real_recon, c_a) self.loss_gen_recon_c_b_real = self.recon_criterion( c_b_real_recon, c_b) self.loss_gen_cycrecon_x_a = self.recon_criterion( x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0 self.loss_gen_cycrecon_x_b = self.recon_criterion( x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0 # GAN loss self.loss_gen_adv_a_real = self.dis_a.calc_gen_loss(x_ba_real) self.loss_gen_adv_b_real = self.dis_b.calc_gen_loss(x_ab_real) # domain-invariant perceptual loss # self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba_randn, x_b) if hyperparameters['vgg_w'] > 0 else 0 # self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab_randn, x_a) if hyperparameters['vgg_w'] > 0 else 0 # total loss self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a_real + \ hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \ hyperparameters['recon_x_w_re'] * self.loss_gen_recon_x_b_re + \ hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b_prime + \ hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b_real + \ hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b+\ hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \ hyperparameters['recon_x_w_re'] * self.loss_gen_recon_x_a_re + \ hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a_prime + \ hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a_real + \ hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \ hyperparameters['gan_w'] * self.loss_gen_adv_b_real +\ 
hyperparameters['recon_x_w_res'] * self.loss_gen_recon_res_a + \ hyperparameters['recon_x_w_res'] * self.loss_gen_recon_res_b if self.fp16: with amp.scale_loss(self.loss_gen_total, self.gen_opt) as scaled_loss: scaled_loss.backward() self.gen_opt.step() else: self.loss_gen_total.backward() self.gen_opt.step() #loss_gan = hyperparameters['gan_w'] * (self.loss_gen_adv_a + self.loss_gen_adv_b) loss_gan_real = hyperparameters['gan_w'] * (self.loss_gen_adv_a_real + self.loss_gen_adv_b_real) loss_x = hyperparameters['recon_x_w'] * (self.loss_gen_recon_x_a + self.loss_gen_recon_x_b) loss_x_re = hyperparameters['recon_x_w_re'] * ( self.loss_gen_recon_x_a_re + self.loss_gen_recon_x_b_re) #loss_x_res = hyperparameters['recon_x_w_res'] * (self.loss_gen_recon_res_a + self.loss_gen_recon_res_b) #loss_s = hyperparameters['recon_s_w'] * (self.loss_gen_recon_s_a + self.loss_gen_recon_s_b) #loss_c = hyperparameters['recon_c_w'] * (self.loss_gen_recon_c_a + self.loss_gen_recon_c_b) loss_s_prime = hyperparameters['recon_s_w'] * ( self.loss_gen_recon_s_a_prime + self.loss_gen_recon_s_b_prime) loss_c_real = hyperparameters['recon_c_w'] * ( self.loss_gen_recon_c_a_real + self.loss_gen_recon_c_b_real) loss_x_cyc = hyperparameters['recon_x_cyc_w'] * ( self.loss_gen_cycrecon_x_a + self.loss_gen_cycrecon_x_b) #loss_vgg = hyperparameters['vgg_w'] * (self.loss_gen_vgg_a + self.loss_gen_vgg_b) print( '||total:%.2f||gan_real:%.2f||x:%.2f||x_re:%.2f||s_prime:%.4f||c_real:%.2f||x_cyc:%.4f||' % (self.loss_gen_total, loss_gan_real, loss_x, loss_x_re, loss_s_prime, loss_c_real, loss_x_cyc)) def compute_vgg_loss(self, vgg, img, target): img_vgg = vgg_preprocess(img) target_vgg = vgg_preprocess(target) img_fea = vgg(img_vgg) target_fea = vgg(target_vgg) return torch.mean( (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2) def sample(self, x_a, x_b): self.eval() # s_a1 = Variable(self.s_a) # s_b1 = Variable(self.s_b) # s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()) # s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()) #x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], [] x_a_recon, x_b_recon, x_bab, x_ab, x_ba, x_aba = [], [], [], [], [], [] for i in range(x_a.size(0)): c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0)) c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0)) x_a_recon.append( self.gen_a.decode(c_a, s_a_fake, x_a[i].unsqueeze(0))) x_b_recon.append( self.gen_b.decode(c_b, s_b_fake, x_b[i].unsqueeze(0))) x_ba_tmp = self.gen_a.decode(c_b, s_a_fake, x_b[i].unsqueeze(0)) x_ab_tmp = self.gen_b.decode(c_a, s_b_fake, x_a[i].unsqueeze(0)) x_ba.append(x_ba_tmp) x_ab.append(x_ab_tmp) c_b_recon, _ = self.gen_a.encode(x_ba_tmp) c_a_recon, _ = self.gen_b.encode(x_ab_tmp) x_aba.append( self.gen_a.decode(c_a_recon, s_a_fake, x_a[i].unsqueeze(0))) x_bab.append( self.gen_b.decode(c_b_recon, s_b_fake, x_b[i].unsqueeze(0))) x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon) x_aba, x_bab = torch.cat(x_aba), torch.cat(x_bab) x_ab, x_ba = torch.cat(x_ab), torch.cat(x_ba) self.train() return x_a, x_a_recon, x_ab, x_ba, x_b, x_aba, x_b, x_b_recon, x_ba, x_ab, x_a, x_bab def dis_update(self, x_a, x_b, hyperparameters): self.dis_opt.zero_grad() # s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()) # s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()) # # encode c_a, s_a_prime = self.gen_a.encode(x_a) c_b, s_b_prime = self.gen_b.encode(x_b) # decode (cross domain) x_ba_real = self.gen_a.decode(c_b, s_a_prime,
x_b) x_ab_real = self.gen_b.decode(c_a, s_b_prime, x_a) # D loss # self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba_randn.detach(), x_a) # self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab_randn.detach(), x_b) self.loss_dis_a_real = self.dis_a.calc_dis_loss( x_ba_real.detach(), x_a) self.loss_dis_b_real = self.dis_b.calc_dis_loss( x_ab_real.detach(), x_b) self.loss_dis_total = hyperparameters['gan_w'] * ( self.loss_dis_a_real + self.loss_dis_b_real) if self.fp16: with amp.scale_loss(self.loss_dis_total, self.dis_opt) as scaled_loss: scaled_loss.backward() self.dis_opt.step() else: self.loss_dis_total.backward() self.dis_opt.step() def update_learning_rate(self): if self.dis_scheduler is not None: self.dis_scheduler.step() if self.gen_scheduler is not None: self.gen_scheduler.step() def resume(self, checkpoint_dir, hyperparameters): # Load generators last_model_name = get_model_list(checkpoint_dir, "gen") state_dict = torch.load(last_model_name) self.gen_a.load_state_dict(state_dict['a']) self.gen_b.load_state_dict(state_dict['b']) iterations = int(last_model_name[-11:-3]) # Load discriminators last_model_name = get_model_list(checkpoint_dir, "dis") state_dict = torch.load(last_model_name) self.dis_a.load_state_dict(state_dict['a']) self.dis_b.load_state_dict(state_dict['b']) # Load optimizers state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt')) self.dis_opt.load_state_dict(state_dict['dis']) self.gen_opt.load_state_dict(state_dict['gen']) # Reinitialize schedulers self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations - 375000) #fine_tune -370000 self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations - 375000) print('Resume from iteration %d' % iterations) return iterations def save(self, snapshot_dir, iterations): # Save generators, discriminators, and optimizers gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1)) dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1)) opt_name = os.path.join(snapshot_dir, 'optimizer.pt') torch.save({ 'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict() }, gen_name) torch.save({ 'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict() }, dis_name) torch.save( { 'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict() }, opt_name)
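# The fixed crop [92:144, 48:172], rescaled by self.a = new_size / 224, appears
# to cover the periocular (eyeglasses) region of aligned CelebA / MeGlass faces:
# the *_re reconstruction losses supervise the pixels inside that window, while
# the *_res losses compare images with the window zeroed out, i.e. everything
# outside it. For the default 224x224 inputs the window is simply rows 92:144,
# cols 48:172; e.g. with new_size = 112, self.a = 0.5 and the window becomes
# rows 46:72, cols 24:86.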
class MUNIT_Trainer(nn.Module): def __init__(self, hyperparameters): super(MUNIT_Trainer, self).__init__() lr = hyperparameters['lr'] # Initiate the networks self.gen_a = AdaINGen( hyperparameters['input_dim_a'], hyperparameters['gen']) # auto-encoder for domain a self.gen_b = AdaINGen( hyperparameters['input_dim_b'], hyperparameters['gen']) # auto-encoder for domain b self.dis_a = MsImageDis( hyperparameters['input_dim_a'], hyperparameters['dis']) # discriminator for domain a self.dis_b = MsImageDis( hyperparameters['input_dim_b'], hyperparameters['dis']) # discriminator for domain b self.content_classifier = ContentClassifier( hyperparameters['gen']['dim'], hyperparameters) self.instancenorm = nn.InstanceNorm2d(512, affine=False) self.style_dim = hyperparameters['gen']['style_dim'] self.num_con_c = hyperparameters['dis']['num_con_c'] # fix the noise used in sampling display_size = int(hyperparameters['display_size']) self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda() self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda() # Setup the optimizers beta1 = hyperparameters['beta1'] beta2 = hyperparameters['beta2'] # dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters()) gen_params = list(self.gen_a.parameters()) + list( self.gen_b.parameters()) dis_named_params = list(self.dis_a.named_parameters()) + list( self.dis_b.named_parameters()) # gen_named_params = list(self.gen_a.named_parameters()) + list(self.gen_b.named_parameters()) ### modifying list params dis_params = list() # gen_params = list() for name, param in dis_named_params: if "_Q" in name: # print('%s --> gen_params' % name) gen_params.append(param) else: dis_params.append(param) content_classifier_params = list(self.content_classifier.parameters()) self.dis_opt = torch.optim.Adam( [p for p in dis_params if p.requires_grad], lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) self.gen_opt = torch.optim.Adam( [p for p in gen_params if p.requires_grad], lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) self.cla_opt = torch.optim.Adam( [p for p in content_classifier_params if p.requires_grad], lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters) self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters) self.cla_scheduler = get_scheduler(self.cla_opt, hyperparameters) # Network weight initialization self.apply(weights_init(hyperparameters['init'])) self.dis_a.apply(weights_init('gaussian')) self.dis_b.apply(weights_init('gaussian')) self.content_classifier.apply(weights_init('gaussian')) # Load VGG model if needed if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0: self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models') self.vgg.eval() for param in self.vgg.parameters(): param.requires_grad = False self.gan_type = hyperparameters['dis']['gan_type'] self.criterionQ_con = NormalNLLLoss() self.criterion_content_classifier = nn.CrossEntropyLoss() # self.batch_size = hyperparameters['batch_size'] self.batch_size_val = hyperparameters['batch_size_val'] # self.accu_content_classifier_c_a = 0 # self.accu_content_classifier_c_a_recon = 0 # self.accu_content_classifier_c_b = 0 # self.accu_content_classifier_c_b_recon = 0 # self.accu_CC_all = 0 def recon_criterion(self, input, target): return torch.mean(torch.abs(input - target)) def forward(self, x_a, x_b): self.eval() s_a = Variable(self.s_a) s_b = Variable(self.s_b) c_a, 
s_a_fake = self.gen_a.encode(x_a)
        c_b, s_b_fake = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, sample_b, hyperparameters, sample_a_limited):
        x_b, label_b = sample_b
        x_a_limited, label_a_limited = sample_a_limited
        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (disabled; cycle reconstruction is not used in this variant)
        # x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        # x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        # GAN loss (the raw discriminator outputs are kept so that the
        # InfoGAN heads can be reused for the continuous-code loss below)
        x_ba_dis_out = self.dis_a(x_ba)
        self.loss_gen_adv_a = self.compute_gen_adv_loss(x_ba_dis_out)
        x_ab_dis_out = self.dis_b(x_ab)
        self.loss_gen_adv_b = self.compute_gen_adv_loss(x_ab_dis_out)
        # InfoGAN-style loss on the continuous codes
        self.info_cont_loss_a = self.compute_info_cont_loss(s_a, x_ba_dis_out)
        self.info_cont_loss_b = self.compute_info_cont_loss(s_b, x_ab_dis_out)
        label_predict_c_b = self.content_classifier(c_b)
        label_predict_c_b_recon = self.content_classifier(c_b_recon)
        # classifier loss on c_a, computed from the limited labeled samples only
        c_a_limited, _ = self.gen_a.encode(x_a_limited)
        label_predict_c_a_limited = self.content_classifier(c_a_limited)
        x_ab_limited = self.gen_b.decode(c_a_limited, s_b)
        c_a_recon_limited, _ = self.gen_b.encode(x_ab_limited)
        label_predict_c_a_recon_limited = self.content_classifier(c_a_recon_limited)
        loss_content_classifier_c_a = self.compute_content_classifier_loss(
            label_predict_c_a_limited, label_a_limited)
        loss_content_classifier_c_a_recon = self.compute_content_classifier_loss(
            label_predict_c_a_recon_limited, label_a_limited)
        loss_content_classifier_b = self.compute_content_classifier_loss(
            label_predict_c_b, label_b)
        loss_content_classifier_c_b_recon = self.compute_content_classifier_loss(
            label_predict_c_b_recon, label_b)
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
            hyperparameters['gan_w'] * self.loss_gen_adv_b + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
            self.info_cont_loss_a + \
            self.info_cont_loss_b + \
            loss_content_classifier_c_a + \
            loss_content_classifier_c_a_recon + \
            loss_content_classifier_b + \
            loss_content_classifier_c_b_recon
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_info_cont_loss(self, style_code, outs_fake):
        # Gaussian NLL between the sampled continuous code and the (mu, var)
        # predicted by each discriminator scale
        loss = 0
        num_cont_code = self.num_con_c
        for _, out_fake in enumerate(outs_fake):
            q_mu = out_fake['mu']
            q_var = out_fake['var']
            info_noise = style_code[:, -num_cont_code:].view(-1, num_cont_code).squeeze()
            loss += self.criterionQ_con(info_noise, q_mu, q_var) * 0.1
        return loss

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def sample(self, x_a, x_b):
        self.eval()
        s_a1 = Variable(self.s_a)
        s_b1 = Variable(self.s_b)
        s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
            c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
            x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
            x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
            x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
            x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
            x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def cla_update(self, sample_a, sample_b):
        x_a, label_a = sample_a
        x_b, label_b = sample_b
        self.cla_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        label_predict_c_a = self.content_classifier(c_a)
        label_predict_c_a_recon = self.content_classifier(c_a_recon)
        label_predict_c_b = self.content_classifier(c_b)
        label_predict_c_b_recon = self.content_classifier(c_b_recon)
        self.loss_content_classifier_c_a = self.compute_content_classifier_loss(
            label_predict_c_a, label_a)
        self.loss_content_classifier_c_a_recon = self.compute_content_classifier_loss(
            label_predict_c_a_recon, label_a)
        self.loss_content_classifier_b = self.compute_content_classifier_loss(
            label_predict_c_b, label_b)
        self.loss_content_classifier_c_b_recon = self.compute_content_classifier_loss(
            label_predict_c_b_recon, label_b)
        self.loss_cla_total = self.loss_content_classifier_c_a + \
            self.loss_content_classifier_c_a_recon + \
            self.loss_content_classifier_b + \
            self.loss_content_classifier_c_b_recon
        self.loss_cla_total.backward()
        self.cla_opt.step()

    def cla_inference(self, test_loader_a, test_loader_b):
        accu_content_classifier_c_a = []
        accu_content_classifier_c_a_recon = []
        accu_content_classifier_c_b = []
        accu_content_classifier_c_b_recon = []
        for it_inf, (samples_a_test, samples_b_test) in enumerate(
                zip(test_loader_a, test_loader_b)):
            x_a, label_a = samples_a_test[0].cuda().detach(), samples_a_test[1].cuda().detach()
            x_b, label_b = samples_b_test[0].cuda().detach(), samples_b_test[1].cuda().detach()
            s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
            s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
            # encode
            c_a, s_a_prime = self.gen_a.encode(x_a)
            c_b, s_b_prime = self.gen_b.encode(x_b)
            # decode (cross domain)
            x_ba = self.gen_a.decode(c_b, s_a)
            x_ab = self.gen_b.decode(c_a, s_b)
            # encode again
            c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
            c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
            label_predict_c_a = self.content_classifier(c_a)
            label_predict_c_a_recon = self.content_classifier(c_a_recon)
            label_predict_c_b = self.content_classifier(c_b)
            label_predict_c_b_recon = self.content_classifier(c_b_recon)
            accu_content_classifier_c_a.append(
                self.compute_content_classifier_accuracy(label_predict_c_a, label_a))
            accu_content_classifier_c_a_recon.append(
                self.compute_content_classifier_accuracy(label_predict_c_a_recon, label_a))
            accu_content_classifier_c_b.append(
                self.compute_content_classifier_accuracy(label_predict_c_b, label_b))
            accu_content_classifier_c_b_recon.append(
                self.compute_content_classifier_accuracy(label_predict_c_b_recon, label_b))
        self.accu_content_classifier_c_a = self.mean_list(accu_content_classifier_c_a)
        self.accu_content_classifier_c_a_recon = self.mean_list(accu_content_classifier_c_a_recon)
        self.accu_content_classifier_c_b = self.mean_list(accu_content_classifier_c_b)
        self.accu_content_classifier_c_b_recon = self.mean_list(accu_content_classifier_c_b_recon)
        self.accu_CC_all = self.mean_list([
            self.accu_content_classifier_c_a,
            self.accu_content_classifier_c_a_recon,
            self.accu_content_classifier_c_b,
            self.accu_content_classifier_c_b_recon
        ])

    @staticmethod
    def mean_list(lst):
        return sum(lst) / len(lst)

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # D loss
        x_ba_dis_out = self.dis_a(x_ba.detach())
        x_a_dis_out = self.dis_a(x_a)
        self.loss_dis_a = self.compute_dis_loss(x_ba_dis_out, x_a_dis_out)
        x_ab_dis_out = self.dis_b(x_ab.detach())
        x_b_dis_out = self.dis_b(x_b)
        self.loss_dis_b = self.compute_dis_loss(x_ab_dis_out, x_b_dis_out)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + \
            hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def compute_content_classifier_loss(self, label_predict, label_true):
        loss = self.criterion_content_classifier(label_predict, label_true)
        return loss

    def compute_content_classifier_two_predictions_loss(self, label_predict_1, label_predict_2):
        # L1 distance between two prediction tensors (used to tie the
        # predictions on a content code and on its reconstruction together)
        loss = torch.mean(torch.abs(label_predict_1 - label_predict_2))
        return loss

    def compute_content_classifier_accuracy(self, label_predict, label_true):
        values, indices = label_predict.max(1)
        results = (label_true == indices)
        total_correct = results.sum().cpu().numpy()
        # NOTE: assumes every validation batch holds exactly
        # self.batch_size_val samples
        accuracy = float(total_correct) / float(self.batch_size_val)
        return accuracy

    def compute_dis_loss(self, outs_fake, outs_real):
        # calculate the loss to train D over all discriminator scales
        loss = 0
        for _, (out_fake, out_real) in enumerate(zip(outs_fake, outs_real)):
            out_fake = out_fake['output_d']
            out_real = out_real['output_d']
            if self.gan_type == 'lsgan':
                loss += torch.mean((out_fake - 0) ** 2) + torch.mean((out_real - 1) ** 2)
            elif self.gan_type == 'nsgan':
                all0 = Variable(torch.zeros_like(out_fake.data).cuda(), requires_grad=False)
                all1 = Variable(torch.ones_like(out_real.data).cuda(), requires_grad=False)
                loss += torch.mean(
                    F.binary_cross_entropy(torch.sigmoid(out_fake), all0) +
                    F.binary_cross_entropy(torch.sigmoid(out_real), all1))
            else:
                assert 0, "Unsupported GAN type: {}".format(self.gan_type)
        return loss

    def compute_gen_adv_loss(self, outs_fake):
        # calculate the loss to train G over all discriminator scales
        loss = 0
        for _, out_fake in enumerate(outs_fake):
            out_fake = out_fake['output_d']
            if self.gan_type == 'lsgan':
                loss += torch.mean((out_fake - 1) ** 2)  # LSGAN
            elif self.gan_type == 'nsgan':
                all1 = Variable(torch.ones_like(out_fake.data).cuda(), requires_grad=False)
                loss += torch.mean(F.binary_cross_entropy(torch.sigmoid(out_fake), all1))
            else:
                assert 0, "Unsupported GAN type: {}".format(self.gan_type)
        return loss

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load content classifier
        last_model_name = get_model_list(checkpoint_dir, "con_cla")
        state_dict = torch.load(last_model_name)
        self.content_classifier.load_state_dict(state_dict['con_cla'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        self.cla_opt.load_state_dict(state_dict['con_cla'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        self.cla_scheduler = get_scheduler(self.cla_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, content classifier, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        con_cla_name = os.path.join(snapshot_dir, 'con_cla_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'con_cla': self.content_classifier.state_dict()}, con_cla_name)
        torch.save({'gen': self.gen_opt.state_dict(),
                    'dis': self.dis_opt.state_dict(),
                    'con_cla': self.cla_opt.state_dict()}, opt_name)
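
# ---------------------------------------------------------------------------
# The trainer above references several members that are not defined in this
# file: self.content_classifier, self.criterion_content_classifier,
# self.criterionQ_con, self.num_con_c, self.cla_opt and self.batch_size_val.
# Below is a minimal sketch of plausible definitions, assuming a small
# classifier head over content codes (with nn.CrossEntropyLoss() as
# criterion_content_classifier) and the InfoGAN-style Gaussian NLL for the
# continuous code. All names and shapes here are illustrative assumptions,
# not the project's actual implementation.
import math


class ContentClassifierSketch(nn.Module):
    # Hypothetical classifier head: content code (B, C, H, W) -> class logits.
    def __init__(self, in_dim=256, n_classes=10):
        super(ContentClassifierSketch, self).__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(in_dim, n_classes)

    def forward(self, c):
        h = self.pool(c)                         # (B, C, 1, 1)
        return self.fc(h.view(h.size(0), -1))   # (B, n_classes)


class NormalNLLLoss(object):
    # One common InfoGAN formulation of criterionQ_con: negative
    # log-likelihood of x under a factored Gaussian N(mu, var).
    def __call__(self, x, mu, var):
        logli = -0.5 * (var.mul(2 * math.pi) + 1e-6).log() \
            - (x - mu).pow(2).div(var.mul(2.0) + 1e-6)
        return -logli.sum(1).mean()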
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'],
                              hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'],
                                hyperparameters['new_size'],
                                hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']
        self.reg_param = hyperparameters['reg_param']
        self.beta_step = hyperparameters['beta_step']
        self.target_kl = hyperparameters['target_kl']
        self.gan_type = hyperparameters['gan_type']
        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_b.parameters())
        gen_params = list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr, betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr, betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        # Network weight initialization
        self.gen_b.apply(weights_init(hyperparameters['init']))
        self.dis_b.apply(weights_init('gaussian'))
        # SSIM loss
        self.ssim_loss = pytorch_ssim.SSIM()

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def recon_criterion_l1(self, input, target, mask):
        return torch.sum(torch.abs(input - target)) / torch.sum(mask)

    def forward(self, x_a, x_b):
        self.eval()
        s_b = self.gen_b.enc_style(x_b)
        c_a = self.gen_b.enc_content(x_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab

    def gen_update(self, x_a, x_b, hyperparameters):
        toogle_grad(self.dis_b, False)
        toogle_grad(self.gen_b, True)
        self.dis_b.train()
        self.gen_b.train()
        self.gen_opt.zero_grad()
        s_b = torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()
        # encode
        c_a = self.gen_b.enc_content(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        x_ab.requires_grad_()
        # reconstruction loss (negated SSIM, so minimizing raises similarity)
        self.loss_gen_recon_x_ab_ssim = -self.ssim_loss.forward(x_a, x_ab)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        # GAN loss (the discriminator returns (mu, logstd, out))
        _, _, d_fake = self.dis_b(x_ab)
        self.loss_gen_adv_b = self.compute_loss(d_fake, 1)
        # total loss
        self.loss_gen_total = self.loss_gen_adv_b + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
            hyperparameters['recon_x_ab'] * self.loss_gen_recon_x_ab_ssim
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def sample(self, x_a, x_b):
        self.eval()
        x_ab = []
        s_b = self.gen_b.enc_style(x_b)
        for i in range(x_a.size(0)):
            c_a = self.gen_b.enc_content(x_a[i].unsqueeze(0))
            x_ab.append(self.gen_b.decode(c_a, s_b))
        x_ab = torch.cat(x_ab)
        self.train()
        return x_a, x_ab

    def dis_update(self, x_a, x_b, hyperparameters):
        toogle_grad(self.gen_b, False)
        toogle_grad(self.dis_b, True)
        self.gen_b.train()
        self.dis_b.train()
        self.dis_opt.zero_grad()
        # On real data
        x_b.requires_grad_()
        d_real_dict = self.dis_b(x_b)
        d_real = d_real_dict[2]
        dloss_real = self.compute_loss(d_real, 1)
        reg = 0.  # accumulates both the gradient penalty and the VGAN KL term
        dloss_real.backward(retain_graph=True)
        # hard-coded weight of 10 for the gradient penalty
        reg += 10. * compute_grad2(d_real, x_b).mean()
        mu = d_real_dict[0]
        logstd = d_real_dict[1]
        kl_real = kl_loss(mu, logstd).mean()
        # On fake data
        with torch.no_grad():
            s_b = torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()
            c_a = self.gen_b.enc_content(x_a)
            x_ab = self.gen_b.decode(c_a, s_b)
        x_ab.requires_grad_()
        d_fake_dict = self.dis_b(x_ab)
        d_fake = d_fake_dict[2]
        dloss_fake = self.compute_loss(d_fake, 0)
        dloss_fake.backward(retain_graph=True)
        mu_fake = d_fake_dict[0]
        logstd_fake = d_fake_dict[1]
        kl_fake = kl_loss(mu_fake, logstd_fake).mean()
        avg_kl = 0.5 * (kl_real + kl_fake)
        reg += self.reg_param * avg_kl
        reg.backward()
        self.update_beta(avg_kl)
        self.dis_opt.step()
        self.loss_dis_total = (dloss_real + dloss_fake)
        return self.loss_dis_total.item()

    def compute_loss(self, d_out, target):
        targets = d_out.new_full(size=d_out.size(), fill_value=target)
        if self.gan_type == 'standard':
            loss = F.binary_cross_entropy_with_logits(d_out, targets)
        elif self.gan_type == 'wgan':
            loss = (2 * target - 1) * d_out.mean()
        else:
            raise NotImplementedError
        return loss

    def update_beta(self, avg_kl):
        with torch.no_grad():
            # self.target_kl is the information constraint I_c
            new_beta = self.reg_param - self.beta_step * (self.target_kl - avg_kl)
            new_beta = max(new_beta, 0)
            self.reg_param = new_beta

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % iterations)
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % iterations)
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(),
                    'dis': self.dis_opt.state_dict()}, opt_name)
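
# ---------------------------------------------------------------------------
# dis_update above relies on three helpers defined elsewhere in this code
# base: toogle_grad, compute_grad2 (an R1-style gradient penalty) and kl_loss
# (the VGAN/VDB KL term for a discriminator that returns (mu, logstd, out)).
# The following is a minimal sketch of what they are assumed to compute,
# following the common GAN-stability conventions; the real implementations
# may differ.
from torch import autograd


def toogle_grad(model, requires_grad):
    # Enable/disable gradients for all parameters of a network.
    for p in model.parameters():
        p.requires_grad_(requires_grad)


def compute_grad2(d_out, x_in):
    # Per-sample squared norm of grad(D(x)) w.r.t. the input (R1 penalty).
    batch_size = x_in.size(0)
    grad_dout = autograd.grad(outputs=d_out.sum(), inputs=x_in,
                              create_graph=True, retain_graph=True,
                              only_inputs=True)[0]
    return grad_dout.pow(2).view(batch_size, -1).sum(1)


def kl_loss(mu, logstd):
    # KL(N(mu, std) || N(0, 1)) per sample, summed over feature dimensions.
    return 0.5 * (mu.pow(2) + (2 * logstd).exp()
                  - 2 * logstd - 1).view(mu.size(0), -1).sum(1)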
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'],
                              hyperparameters['gen'])  # auto-encoder for domain a
        # self.gen_aT = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'],
                              hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'],
                                hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'],
                                hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']
        # fix the noise used in sampling
        # (`cun` is assumed to be a GPU index defined elsewhere in this project)
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda(cun)
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda(cun)
        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr, betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr, betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))
        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, s_a_fake = self.gen_a.encode(x_a)
        c_aT, s_a_fake_T = self.gen_a.encodeT(x_a)
        c_b, s_b_fake = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a_fake)  # changed to use the encoded (_fake) style
        x_ab = self.gen_b.decode(c_a, s_b_fake)
        x_aT = self.gen_a.decodeT(c_a, s_a_fake_T)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        self.gen_opt.zero_grad()
        # random style codes are not used here; the cross-domain decodes below
        # reuse the encoded styles s_a_prime / s_b_prime
        # s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda(cun))
        # s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda(cun))
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        c_aT, s_aT_prime = self.gen_a.encodeT(x_a)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_aT_recon = self.gen_a.decodeT(c_a, s_aT_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a_prime)
        x_ab = self.gen_b.decode(c_a, s_b_prime)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        # reconstruction loss
        self.loss_gen_styleT = self.recon_criterion(x_a, x_aT_recon)
        self.loss_gen_content = self.recon_criterion(c_a, c_aT)
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a_prime)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b_prime)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
            hyperparameters['gan_w'] * self.loss_gen_adv_b + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
            hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
            hyperparameters['vgg_w'] * self.loss_gen_vgg_b
        # extra T-branch terms (hard-coded weight of 5)
        self.loss_gen_total += self.loss_gen_content * 5
        self.loss_gen_total += self.loss_gen_styleT * 5
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def sample(self, x_a, x_b):
        self.eval()
        x_a_recon, x_b_recon, x_ba, x_ab = [], [], [], []
        c_a, s_a_fake = self.gen_a.encode(x_a)
        c_b, s_b_fake = self.gen_b.encode(x_b)
        x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
        x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
        x_ba.append(self.gen_a.decode(c_b, s_a_fake))
        x_ab.append(self.gen_b.decode(c_a, s_b_fake))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba = torch.cat(x_ba)
        x_ab = torch.cat(x_ab)
        self.train()
        return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda(cun))
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda(cun))
        # encode
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain) with the randomly sampled styles
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + \
            hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)


# class UNIT_Trainer(nn.Module):
#     def __init__(self, hyperparameters):
#         super(UNIT_Trainer, self).__init__()
#         lr = hyperparameters['lr']
#         # Initiate the networks
#         self.gen_a = VAEGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
#         self.gen_b = VAEGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
#         self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
#         self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
#         self.instancenorm = nn.InstanceNorm2d(512, affine=False)
#         # Setup the optimizers
#         beta1 = hyperparameters['beta1']
#         beta2 = hyperparameters['beta2']
#         dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
#         gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
#         self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
#                                         lr=lr, betas=(beta1, beta2),
#                                         weight_decay=hyperparameters['weight_decay'])
#         self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
#                                         lr=lr, betas=(beta1, beta2),
#                                         weight_decay=hyperparameters['weight_decay'])
#         self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
#         self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
#         # Network weight initialization
#         self.apply(weights_init(hyperparameters['init']))
#         self.dis_a.apply(weights_init('gaussian'))
#         self.dis_b.apply(weights_init('gaussian'))
#         # Load VGG model if needed
#         if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
#             self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
#             self.vgg.eval()
#             for param in self.vgg.parameters():
#                 param.requires_grad = False
#
#     def recon_criterion(self, input, target):
#         return torch.mean(torch.abs(input - target))
#
#     def forward(self, x_a, x_b):
#         self.eval()
#         h_a, _ = self.gen_a.encode(x_a)
#         h_b, _ = self.gen_b.encode(x_b)
#         x_ba = self.gen_a.decode(h_b)
#         x_ab = self.gen_b.decode(h_a)
#         self.train()
#         return x_ab, x_ba
#
#     def __compute_kl(self, mu):
#         mu_2 = torch.pow(mu, 2)
#         encoding_loss = torch.mean(mu_2)
#         return encoding_loss
#
#     def gen_update(self, x_a, x_b, hyperparameters):
#         self.gen_opt.zero_grad()
#         # encode
#         h_a, n_a = self.gen_a.encode(x_a)
#         h_b, n_b = self.gen_b.encode(x_b)
#         # decode (within domain)
#         x_a_recon = self.gen_a.decode(h_a + n_a)
#         x_b_recon = self.gen_b.decode(h_b + n_b)
#         # decode (cross domain)
#         x_ba = self.gen_a.decode(h_b + n_b)
#         x_ab = self.gen_b.decode(h_a + n_a)
#         # encode again
#         h_b_recon, n_b_recon = self.gen_a.encode(x_ba)
#         h_a_recon, n_a_recon = self.gen_b.encode(x_ab)
#         # decode again (if needed)
#         x_aba = self.gen_a.decode(h_a_recon + n_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
#         x_bab = self.gen_b.decode(h_b_recon + n_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
#         # reconstruction loss
#         self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
#         self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
#         self.loss_gen_recon_kl_a = self.__compute_kl(h_a)
#         self.loss_gen_recon_kl_b = self.__compute_kl(h_b)
#         self.loss_gen_cyc_x_a = self.recon_criterion(x_aba, x_a)
#         self.loss_gen_cyc_x_b = self.recon_criterion(x_bab, x_b)
#         self.loss_gen_recon_kl_cyc_aba = self.__compute_kl(h_a_recon)
#         self.loss_gen_recon_kl_cyc_bab = self.__compute_kl(h_b_recon)
#         # GAN loss
#         self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
#         self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
#         # domain-invariant perceptual loss
#         self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
#         self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
#         # total loss
#         self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
#                               hyperparameters['gan_w'] * self.loss_gen_adv_b + \
#                               hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
#                               hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \
#                               hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
#                               hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \
#                               hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_a + \
#                               hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \
#                               hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_b + \
#                               hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab + \
#                               hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
#                               hyperparameters['vgg_w'] * self.loss_gen_vgg_b
#         self.loss_gen_total.backward()
#         self.gen_opt.step()
#
#     def compute_vgg_loss(self, vgg, img, target):
#         img_vgg = vgg_preprocess(img)
#         target_vgg = vgg_preprocess(target)
#         img_fea = vgg(img_vgg)
#         target_fea = vgg(target_vgg)
#         return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)
#
#     def sample(self, x_a, x_b):
#         self.eval()
#         x_a_recon, x_b_recon, x_ba, x_ab = [], [], [], []
#         for i in range(x_a.size(0)):
#             h_a, _ = self.gen_a.encode(x_a[i].unsqueeze(0))
#             h_b, _ = self.gen_b.encode(x_b[i].unsqueeze(0))
#             x_a_recon.append(self.gen_a.decode(h_a))
#             x_b_recon.append(self.gen_b.decode(h_b))
#             x_ba.append(self.gen_a.decode(h_b))
#             x_ab.append(self.gen_b.decode(h_a))
#         x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
#         x_ba = torch.cat(x_ba)
#         x_ab = torch.cat(x_ab)
#         self.train()
#         return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba
#
#     def dis_update(self, x_a, x_b, hyperparameters):
#         self.dis_opt.zero_grad()
#         # encode
#         h_a, n_a = self.gen_a.encode(x_a)
#         h_b, n_b = self.gen_b.encode(x_b)
#         # decode (cross domain)
#         x_ba = self.gen_a.decode(h_b + n_b)
#         x_ab = self.gen_b.decode(h_a + n_a)
#         # D loss
#         self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
#         self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
#         self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + \
#                               hyperparameters['gan_w'] * self.loss_dis_b
#         self.loss_dis_total.backward()
#         self.dis_opt.step()
#
#     def update_learning_rate(self):
#         if self.dis_scheduler is not None:
#             self.dis_scheduler.step()
#         if self.gen_scheduler is not None:
#             self.gen_scheduler.step()
#
#     def resume(self, checkpoint_dir, hyperparameters):
#         # Load generators
#         last_model_name = get_model_list(checkpoint_dir, "gen")
#         state_dict = torch.load(last_model_name)
#         self.gen_a.load_state_dict(state_dict['a'])
#         self.gen_b.load_state_dict(state_dict['b'])
#         iterations = int(last_model_name[-11:-3])
#         # Load discriminators
#         last_model_name = get_model_list(checkpoint_dir, "dis")
#         state_dict = torch.load(last_model_name)
#         self.dis_a.load_state_dict(state_dict['a'])
#         self.dis_b.load_state_dict(state_dict['b'])
#         # Load optimizers
#         state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
#         self.dis_opt.load_state_dict(state_dict['dis'])
#         self.gen_opt.load_state_dict(state_dict['gen'])
#         # Reinitialize schedulers
#         self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
#         self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
#         print('Resume from iteration %d' % iterations)
#         return iterations
#
#     def save(self, snapshot_dir, iterations):
#         # Save generators, discriminators, and optimizers
#         gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
#         dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
#         opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
#         torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
#         torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
#         torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
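
# ---------------------------------------------------------------------------
# Expected call order for the trainer above (illustrative sketch; `config`,
# `loader_a` and `loader_b` are hypothetical names, and the trainer assumes
# a GPU index `cun` plus a 'display_size' entry in the config):
#
#     trainer = MUNIT_Trainer(config)
#     trainer.cuda(cun)
#     for it, (x_a, x_b) in enumerate(zip(loader_a, loader_b)):
#         x_a, x_b = x_a.cuda(cun), x_b.cuda(cun)
#         trainer.dis_update(x_a, x_b, config)
#         trainer.gen_update(x_a, x_b, config)
#         trainer.update_learning_rate()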
class DSMAP_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(DSMAP_Trainer, self).__init__()
        # Initiate the networks
        mid_downsample = hyperparameters['gen'].get('mid_downsample', 1)
        self.content_enc = ContentEncoder_share(
            hyperparameters['gen']['n_downsample'], mid_downsample,
            hyperparameters['gen']['n_res'], hyperparameters['input_dim_a'],
            hyperparameters['gen']['dim'], 'in',
            hyperparameters['gen']['activ'],
            pad_type=hyperparameters['gen']['pad_type'])
        self.style_dim = hyperparameters['gen']['style_dim']
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], self.content_enc,
                              'a', hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], self.content_enc,
                              'b', hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'],
                                self.content_enc.output_dim,
                                hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'],
                                self.content_enc.output_dim,
                                hyperparameters['dis'])  # discriminator for domain b

    def build_optimizer(self, hyperparameters):
        # Setup the optimizers
        lr = hyperparameters['lr']
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr, betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr, betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))
        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            import torchvision.models as models
            self.vgg = models.vgg16(pretrained=True)
            # If the pretrained weights cannot be downloaded automatically,
            # fetch https://download.pytorch.org/models/vgg16-397923af.pth
            # and load them manually:
            # state_dict = torch.load('vgg16-397923af.pth')
            # self.vgg.load_state_dict(state_dict)
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def __compute_kl(self, mu, logvar):
        encoding_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return encoding_loss

    def gen_update(self, x_a, x_b, hyperparameters, iterations):
        self.gen_opt.zero_grad()
        self.gen_backward_cc(x_a, x_b, hyperparameters)
        # self.gen_opt.step()
        # self.gen_opt.zero_grad()
        self.gen_backward_latent(x_a, x_b, hyperparameters)
        self.gen_opt.step()

    def gen_backward_latent(self, x_a, x_b, hyperparameters):
        # randomly sample style vectors for multimodal training
        s_a_random = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b_random = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # decode
        x_ba_random = self.gen_a.decode(self.c_b, self.da_b, s_a_random)
        x_ab_random = self.gen_b.decode(self.c_a, self.db_a, s_b_random)
        c_b_random_recon, _, _, s_a_random_recon, _, _ = self.gen_a.encode(x_ba_random)
        c_a_random_recon, _, _, s_b_random_recon, _, _ = self.gen_b.encode(x_ab_random)
        # style reconstruction loss
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_random, s_a_random_recon)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_random, s_b_random_recon)
        loss_gen_recon_c_a = self.recon_criterion(self.c_a, c_a_random_recon)
        loss_gen_recon_c_b = self.recon_criterion(self.c_b, c_b_random_recon)
        loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba_random, x_a)
        loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab_random, x_b)
        loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba_random, x_a) if hyperparameters['vgg_w'] > 0 else 0
        loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab_random, x_b) if hyperparameters['vgg_w'] > 0 else 0
        loss_gen_total = hyperparameters['gan_w'] * loss_gen_adv_a + \
            hyperparameters['gan_w'] * loss_gen_adv_b + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
            hyperparameters['recon_c_w'] * loss_gen_recon_c_a + \
            hyperparameters['recon_c_w'] * loss_gen_recon_c_b + \
            hyperparameters['vgg_w'] * loss_gen_vgg_a + \
            hyperparameters['vgg_w'] * loss_gen_vgg_b
        self.loss_gen_total += loss_gen_total
        self.loss_gen_total.backward()
        # accumulate the per-term losses for logging
        self.loss_gen_adv_a += loss_gen_adv_a
        self.loss_gen_adv_b += loss_gen_adv_b
        self.loss_gen_recon_c_a += loss_gen_recon_c_a
        self.loss_gen_recon_c_b += loss_gen_recon_c_b
        self.loss_gen_vgg_a += loss_gen_vgg_a
        self.loss_gen_vgg_b += loss_gen_vgg_b

    def gen_backward_cc(self, x_a, x_b, hyperparameters):
        pre_c_a, self.c_a, c_domain_a, self.db_a, self.s_a_prime, mu_a, logvar_a = \
            self.gen_a.encode(x_a, training=True, flag=True)
        pre_c_b, self.c_b, c_domain_b, self.da_b, self.s_b_prime, mu_b, logvar_b = \
            self.gen_b.encode(x_b, training=True, flag=True)
        self.da_a = self.gen_b.domain_mapping(self.c_a, pre_c_a)
        self.db_b = self.gen_a.domain_mapping(self.c_b, pre_c_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(self.c_a, self.da_a, self.s_a_prime)
        x_b_recon = self.gen_b.decode(self.c_b, self.db_b, self.s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(self.c_b, self.da_b, self.s_a_prime)
        x_ab = self.gen_b.decode(self.c_a, self.db_a, self.s_b_prime)
        c_b_recon, _, self.db_b_recon, s_a_recon, _, _ = self.gen_a.encode(x_ba, training=True)
        c_a_recon, _, self.da_a_recon, s_b_recon, _, _ = self.gen_b.encode(x_ab, training=True)
        # decode again (cycle-consistency loss)
        x_aba = self.gen_a.decode(c_a_recon, self.da_a_recon, s_a_recon)
        x_bab = self.gen_b.decode(c_b_recon, self.db_b_recon, s_b_recon)
        # domain-specific content reconstruction loss
        self.loss_gen_recon_d_a = self.recon_criterion(c_domain_a, self.da_a) if hyperparameters['recon_d_w'] > 0 else 0
        self.loss_gen_recon_d_b = self.recon_criterion(c_domain_b, self.db_b) if hyperparameters['recon_d_w'] > 0 else 0
        # image reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        # domain-invariant content reconstruction loss
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, self.c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, self.c_b)
        # cycle-consistency loss
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a)
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b)
        # KL loss (if needed)
        self.loss_gen_recon_kl_a = self.__compute_kl(mu_a, logvar_a) if hyperparameters['recon_kl_w'] > 0 else 0
        self.loss_gen_recon_kl_b = self.__compute_kl(mu_b, logvar_b) if hyperparameters['recon_kl_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba, x_a)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab, x_b)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_a) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_b) if hyperparameters['vgg_w'] > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
            hyperparameters['gan_w'] * self.loss_gen_adv_b + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_d_w'] * self.loss_gen_recon_d_a + \
            hyperparameters['recon_d_w'] * self.loss_gen_recon_d_b + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
            hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \
            hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \
            hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
            hyperparameters['vgg_w'] * self.loss_gen_vgg_b
        # self.loss_gen_total.backward(retain_graph=True)

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg.features(img_vgg)
        target_fea = vgg.features(target_vgg)
        return contextual_loss(img_fea, target_fea)

    def dis_update(self, x_a, x_b, hyperparameters, iterations):
        self.dis_opt.zero_grad()
        s_a_random = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b_random = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        pre_c_a, c_a, c_domain_a, db_a, s_a, _, _ = self.gen_a.encode(x_a, training=True, flag=True)
        pre_c_b, c_b, c_domain_b, da_b, s_b, _, _ = self.gen_b.encode(x_b, training=True, flag=True)
        da_a = self.gen_b.domain_mapping(c_a, pre_c_a)
        db_b = self.gen_a.domain_mapping(c_b, pre_c_b)
        # decode (cross domain, encoded style)
        x_ba = self.gen_a.decode(c_b, da_b, s_a)
        x_ab = self.gen_b.decode(c_a, db_a, s_b)
        # decode (cross domain, random style)
        x_ba_random = self.gen_a.decode(c_b, da_b, s_a_random)
        x_ab_random = self.gen_b.decode(c_a, db_a, s_b_random)
        c_b_recon, _, db_b_recon, s_a_recon, _, _ = self.gen_a.encode(x_ba)
        c_a_recon, _, da_a_recon, s_b_recon, _, _ = self.gen_b.encode(x_ab)
        _, _, db_b_random_recon, _, _, _ = self.gen_a.encode(x_ba_random)
        _, _, da_a_random_recon, _, _, _ = self.gen_b.encode(x_ab_random)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a) + \
            self.dis_a.calc_dis_loss(x_ba_random.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b) + \
            self.dis_b.calc_dis_loss(x_ab_random.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + \
            hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def sample(self, x_a, x_b):
        self.eval()
        x_a_recon, x_b_recon, x_ab, x_ba = [], [], [], []
        for i in range(x_a.size(0)):
            pre_c_a, c_a, _, db_a, s_a_fake, _, _ = self.gen_a.encode(x_a[i].unsqueeze(0), flag=True)
            pre_c_b, c_b, _, da_b, s_b_fake, _, _ = self.gen_b.encode(x_b[i].unsqueeze(0), flag=True)
            da_a = self.gen_b.domain_mapping(c_a, pre_c_a)
            db_b = self.gen_a.domain_mapping(c_b, pre_c_b)
            x_a_recon.append(self.gen_a.decode(c_a, da_a, s_a_fake))
            x_b_recon.append(self.gen_b.decode(c_b, db_b, s_b_fake))
            x_ba.append(self.gen_a.decode(c_b, da_b, s_a_fake))
            x_ab.append(self.gen_b.decode(c_a, db_a, s_b_fake))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ab, x_ba = torch.cat(x_ab), torch.cat(x_ba)
        self.train()
        return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(),
                    'dis': self.dis_opt.state_dict()}, opt_name)
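
# ---------------------------------------------------------------------------
# Sketch of the expected call order for DSMAP_Trainer (hypothetical driver;
# `config`, `loader_a` and `loader_b` are illustrative names). Note that
# build_optimizer must be called before the update steps, because __init__
# only constructs the networks:
#
#     trainer = DSMAP_Trainer(config)
#     trainer.build_optimizer(config)
#     trainer.cuda()
#     for it, (x_a, x_b) in enumerate(zip(loader_a, loader_b)):
#         trainer.dis_update(x_a.cuda(), x_b.cuda(), config, it)
#         trainer.gen_update(x_a.cuda(), x_b.cuda(), config, it)
#         trainer.update_learning_rate()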