class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']
        # fix the noise used in sampling
        self.s_a = torch.randn(8, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(8, self.style_dim, 1, 1).cuda()
        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2),
                                        weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2),
                                        weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))
        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        # translation pass with the fixed sampling styles
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, s_a_fake = self.gen_a.encode(x_a)
        c_b, s_b_fake = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def sample(self, x_a, x_b):
        self.eval()
        s_a1 = Variable(self.s_a)
        s_b1 = Variable(self.s_b)
        s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
            c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
            x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
            x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
            x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
            x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
            x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
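
# The trainer above is driven from an external training script. A minimal sketch of
# that loop, assuming this file is importable as `trainer.py` and that a YAML config
# provides the keys read in __init__; the config path and the dummy loaders below are
# placeholders, not part of this repository.
def _example_munit_training_loop():
    import yaml
    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from trainer import MUNIT_Trainer  # assumed module name for this file

    config = yaml.safe_load(open('configs/example.yaml'))  # placeholder config path
    trainer = MUNIT_Trainer(config).cuda()

    # Stand-in loaders; real training uses one image loader per domain.
    loader_a = DataLoader(TensorDataset(torch.rand(8, config['input_dim_a'], 256, 256)), batch_size=1)
    loader_b = DataLoader(TensorDataset(torch.rand(8, config['input_dim_b'], 256, 256)), batch_size=1)

    for it, ((x_a,), (x_b,)) in enumerate(zip(loader_a, loader_b)):
        x_a, x_b = x_a.cuda(), x_b.cuda()
        trainer.dis_update(x_a, x_b, config)   # discriminator step on cross-domain fakes
        trainer.gen_update(x_a, x_b, config)   # generator / reconstruction step
        trainer.update_learning_rate()         # step both LR schedulers once per iteration
    trainer.save('outputs/snapshots', it)      # writes gen_*.pt, dis_*.pt, optimizer.pt
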
class UnsupIntrinsicTrainer(nn.Module):
    def __init__(self, param):
        super(UnsupIntrinsicTrainer, self).__init__()
        lr = param['lr']
        # Initiate the networks
        self.gen_i = AdaINGen(param['input_dim_a'], param['input_dim_a'], param['gen'])  # auto-encoder for domain I
        self.gen_r = AdaINGen(param['input_dim_b'], param['input_dim_b'], param['gen'])  # auto-encoder for domain R
        self.gen_s = AdaINGen(param['input_dim_c'], param['input_dim_c'], param['gen'])  # auto-encoder for domain S
        self.dis_r = MsImageDis(param['input_dim_b'], param['dis'])  # discriminator for domain R
        self.dis_s = MsImageDis(param['input_dim_c'], param['dis'])  # discriminator for domain S
        gp = param['gen']
        self.with_mapping = True
        self.use_phy_loss = True
        self.use_content_loss = True
        if 'ablation_study' in param:
            if 'with_mapping' in param['ablation_study']:
                wm = param['ablation_study']['with_mapping']
                self.with_mapping = True if wm != 0 else False
            if 'wo_phy_loss' in param['ablation_study']:
                wpl = param['ablation_study']['wo_phy_loss']
                self.use_phy_loss = True if wpl == 0 else False
            if 'wo_content_loss' in param['ablation_study']:
                wcl = param['ablation_study']['wo_content_loss']
                self.use_content_loss = True if wcl == 0 else False
        if self.with_mapping:
            self.fea_s = IntrinsicSplitor(gp['style_dim'], gp['mlp_dim'], gp['n_layer'], gp['activ'])  # split style for I
            self.fea_m = IntrinsicMerger(gp['style_dim'], gp['mlp_dim'], gp['n_layer'], gp['activ'])   # merge style for R, S
        self.bias_shift = param['bias_shift']
        self.instance_norm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = param['gen']['style_dim']
        # fix the noise used in sampling
        display_size = int(param['display_size'])
        self.s_r = torch.randn(display_size, self.style_dim, 1, 1).cuda() + 1.
        self.s_s = torch.randn(display_size, self.style_dim, 1, 1).cuda() - 1.
        # Setup the optimizers
        beta1 = param['beta1']
        beta2 = param['beta2']
        dis_params = list(self.dis_r.parameters()) + list(self.dis_s.parameters())
        if self.with_mapping:
            gen_params = list(self.gen_i.parameters()) + list(self.gen_r.parameters()) + \
                         list(self.gen_s.parameters()) + \
                         list(self.fea_s.parameters()) + list(self.fea_m.parameters())
        else:
            gen_params = list(self.gen_i.parameters()) + list(self.gen_r.parameters()) + list(self.gen_s.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=param['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=param['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, param)
        self.gen_scheduler = get_scheduler(self.gen_opt, param)
        # Network weight initialization
        self.apply(weights_init(param['init']))
        self.dis_r.apply(weights_init('gaussian'))
        self.dis_s.apply(weights_init('gaussian'))
        self.best_result = float('inf')
        self.reflectance_loss = LocalAlbedoSmoothnessLoss(param)

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def physical_criterion(self, x_i, x_r, x_s):
        return torch.mean(torch.abs(x_i - x_r * x_s))

    def forward(self, x_i):
        c_i, s_i_fake = self.gen_i.encode(x_i)
        if self.with_mapping:
            s_r, s_s = self.fea_s(s_i_fake)
        else:
            s_r = Variable(torch.randn(x_i.size(0), self.style_dim, 1, 1).cuda()) + self.bias_shift
            s_s = Variable(torch.randn(x_i.size(0), self.style_dim, 1, 1).cuda()) - self.bias_shift
        x_ri = self.gen_r.decode(c_i, s_r)
        x_si = self.gen_s.decode(c_i, s_s)
        return x_ri, x_si

    def inference(self, x_i, use_rand_fea=False):
        with torch.no_grad():
            c_i, s_i_fake = self.gen_i.encode(x_i)
            if self.with_mapping:
                s_r, s_s = self.fea_s(s_i_fake)
            else:
                s_r = Variable(torch.randn(x_i.size(0), self.style_dim, 1, 1).cuda()) + self.bias_shift
                s_s = Variable(torch.randn(x_i.size(0), self.style_dim, 1, 1).cuda()) - self.bias_shift
            if use_rand_fea:
                s_r = Variable(torch.randn(x_i.size(0), self.style_dim, 1, 1).cuda()) + self.bias_shift
                s_s = Variable(torch.randn(x_i.size(0), self.style_dim, 1, 1).cuda()) - self.bias_shift
            x_ri = self.gen_r.decode(c_i, s_r)
            x_si = self.gen_s.decode(c_i, s_s)
        return x_ri, x_si

    # noinspection PyAttributeOutsideInit
    def gen_update(self, x_i, x_r, x_s, targets=None, param=None):
        self.gen_opt.zero_grad()
        # ============= Domain Translations =============
        # encode
        c_i, s_i_prime = self.gen_i.encode(x_i)
        c_r, s_r_prime = self.gen_r.encode(x_r)
        c_s, s_s_prime = self.gen_s.encode(x_s)
        if self.with_mapping:
            s_ri, s_si = self.fea_s(s_i_prime)
        else:
            s_ri = Variable(torch.randn(x_i.size(0), self.style_dim, 1, 1).cuda()) + self.bias_shift
            s_si = Variable(torch.randn(x_i.size(0), self.style_dim, 1, 1).cuda()) - self.bias_shift
        s_r_rand = Variable(torch.randn(x_r.size(0), self.style_dim, 1, 1).cuda()) + self.bias_shift
        s_s_rand = Variable(torch.randn(x_s.size(0), self.style_dim, 1, 1).cuda()) - self.bias_shift
        if self.with_mapping:
            s_i_recon = self.fea_m(s_ri, s_si)
        else:
            s_i_recon = Variable(torch.randn(x_i.size(0), self.style_dim, 1, 1).cuda())
        # decode (within domain): each image is reconstructed by its own decoder
        x_i_recon = self.gen_i.decode(c_i, s_i_prime)
        x_r_recon = self.gen_r.decode(c_r, s_r_prime)
        x_s_recon = self.gen_s.decode(c_s, s_s_prime)
        # decode (cross domain)
        x_rs = self.gen_r.decode(c_s, s_r_rand)
        x_ri = self.gen_r.decode(c_i, s_ri)
        x_ri_rand = self.gen_r.decode(c_i, s_r_rand)
        x_sr = self.gen_s.decode(c_r, s_s_rand)
        x_si = self.gen_s.decode(c_i, s_si)
        x_si_rand = self.gen_s.decode(c_i, s_s_rand)
        # encode again, for feature domain consistency constraints
        c_rs_recon, s_rs_recon = self.gen_r.encode(x_rs)
        c_ri_recon, s_ri_recon = self.gen_r.encode(x_ri)
        c_ri_rand_recon, s_ri_rand_recon = self.gen_r.encode(x_ri_rand)
        c_sr_recon, s_sr_recon = self.gen_s.encode(x_sr)
        c_si_recon, s_si_recon = self.gen_s.encode(x_si)
        c_si_rand_recon, s_si_rand_recon = self.gen_s.encode(x_si_rand)
        # decode again, for image domain cycle consistency
        x_rsr = self.gen_r.decode(c_sr_recon, s_r_prime)
        x_iri = self.gen_i.decode(c_ri_recon, s_i_prime)
        x_iri_rand = self.gen_i.decode(c_ri_rand_recon, s_i_prime)
        x_srs = self.gen_s.decode(c_rs_recon, s_s_prime)
        x_isi = self.gen_i.decode(c_si_recon, s_i_prime)
        x_isi_rand = self.gen_i.decode(c_si_rand_recon, s_i_prime)
        # ============= Loss Functions =============
        # Encoder-decoder reconstruction loss for the three domains
        self.loss_gen_recon_x_i = self.recon_criterion(x_i_recon, x_i)
        self.loss_gen_recon_x_r = self.recon_criterion(x_r_recon, x_r)
        self.loss_gen_recon_x_s = self.recon_criterion(x_s_recon, x_s)
        # Style-level reconstruction loss for cross domain
        if self.with_mapping:
            self.loss_gen_recon_s_ii = self.recon_criterion(s_i_recon, s_i_prime)
        else:
            self.loss_gen_recon_s_ii = 0
        self.loss_gen_recon_s_ri = self.recon_criterion(s_ri_recon, s_ri)
        self.loss_gen_recon_s_ri_rand = self.recon_criterion(s_ri_rand_recon, s_ri)
        self.loss_gen_recon_s_rs = self.recon_criterion(s_rs_recon, s_r_rand)
        self.loss_gen_recon_s_sr = self.recon_criterion(s_sr_recon, s_s_rand)
        self.loss_gen_recon_s_si = self.recon_criterion(s_si_recon, s_si)
        self.loss_gen_recon_s_si_rand = self.recon_criterion(s_si_rand_recon, s_si)
        # Content-level reconstruction loss for cross domain
        self.loss_gen_recon_c_rs = self.recon_criterion(c_rs_recon, c_s)
        self.loss_gen_recon_c_ri = self.recon_criterion(c_ri_recon, c_i) if self.use_content_loss else 0
        self.loss_gen_recon_c_ri_rand = self.recon_criterion(c_ri_rand_recon, c_i) if self.use_content_loss else 0
        self.loss_gen_recon_c_sr = self.recon_criterion(c_sr_recon, c_r)
        self.loss_gen_recon_c_si = self.recon_criterion(c_si_recon, c_i) if self.use_content_loss else 0
        self.loss_gen_recon_c_si_rand = self.recon_criterion(c_si_rand_recon, c_i) if self.use_content_loss else 0
        # Cycle consistency loss for the three image domains
        self.loss_gen_cyc_recon_x_rs = self.recon_criterion(x_rsr, x_r)
        self.loss_gen_cyc_recon_x_ir = self.recon_criterion(x_iri, x_i)
        self.loss_gen_cyc_recon_x_ir_rand = self.recon_criterion(x_iri_rand, x_i)
        self.loss_gen_cyc_recon_x_sr = self.recon_criterion(x_srs, x_s)
        self.loss_gen_cyc_recon_x_is = self.recon_criterion(x_isi, x_i)
        self.loss_gen_cyc_recon_x_is_rand = self.recon_criterion(x_isi_rand, x_i)
        # GAN loss
        self.loss_gen_adv_rs = self.dis_r.calc_gen_loss(x_rs)
        self.loss_gen_adv_ri = self.dis_r.calc_gen_loss(x_ri)
        self.loss_gen_adv_ri_rand = self.dis_r.calc_gen_loss(x_ri_rand)
        self.loss_gen_adv_sr = self.dis_s.calc_gen_loss(x_sr)
        self.loss_gen_adv_si = self.dis_s.calc_gen_loss(x_si)
        self.loss_gen_adv_si_rand = self.dis_s.calc_gen_loss(x_si_rand)
        # Physical loss: the input image should equal reflectance * shading
        self.loss_gen_phy_i = self.physical_criterion(x_i, x_ri, x_si) if self.use_phy_loss else 0
        self.loss_gen_phy_i_rand = self.physical_criterion(x_i, x_ri_rand, x_si_rand) if self.use_phy_loss else 0
        # Reflectance smoothness loss
        self.loss_refl_ri = self.reflectance_loss(x_ri, targets) if targets is not None else 0
        self.loss_refl_ri_rand = self.reflectance_loss(x_ri_rand, targets) if targets is not None else 0
        # total loss
        self.loss_gen_total = param['gan_w'] * self.loss_gen_adv_rs + \
                              param['gan_w'] * self.loss_gen_adv_ri + \
                              param['gan_w'] * self.loss_gen_adv_ri_rand + \
                              param['gan_w'] * self.loss_gen_adv_sr + \
                              param['gan_w'] * self.loss_gen_adv_si + \
                              param['gan_w'] * self.loss_gen_adv_si_rand + \
                              param['recon_x_w'] * self.loss_gen_recon_x_i + \
                              param['recon_x_w'] * self.loss_gen_recon_x_r + \
                              param['recon_x_w'] * self.loss_gen_recon_x_s + \
                              param['recon_s_w'] * self.loss_gen_recon_s_ii + \
                              param['recon_s_w'] * self.loss_gen_recon_s_ri + \
                              param['recon_s_w'] * self.loss_gen_recon_s_ri_rand + \
                              param['recon_s_w'] * self.loss_gen_recon_s_rs + \
                              param['recon_s_w'] * self.loss_gen_recon_s_si + \
                              param['recon_s_w'] * self.loss_gen_recon_s_si_rand + \
                              param['recon_s_w'] * self.loss_gen_recon_s_sr + \
                              param['recon_c_w'] * self.loss_gen_recon_c_ri + \
                              param['recon_c_w'] * self.loss_gen_recon_c_rs + \
                              param['recon_c_w'] * self.loss_gen_recon_c_ri_rand + \
                              param['recon_c_w'] * self.loss_gen_recon_c_si + \
                              param['recon_c_w'] * self.loss_gen_recon_c_sr + \
                              param['recon_c_w'] * self.loss_gen_recon_c_si_rand + \
                              param['recon_x_cyc_w'] * self.loss_gen_cyc_recon_x_ir + \
                              param['recon_x_cyc_w'] * self.loss_gen_cyc_recon_x_ir_rand + \
                              param['recon_x_cyc_w'] * self.loss_gen_cyc_recon_x_is + \
                              param['recon_x_cyc_w'] * self.loss_gen_cyc_recon_x_is_rand + \
                              param['recon_x_cyc_w'] * self.loss_gen_cyc_recon_x_rs + \
                              param['recon_x_cyc_w'] * self.loss_gen_cyc_recon_x_sr + \
                              param['phy_x_w'] * self.loss_gen_phy_i + \
                              param['phy_x_w'] * self.loss_gen_phy_i_rand + \
                              param['refl_smooth_w'] * self.loss_refl_ri + \
                              param['refl_smooth_w'] * self.loss_refl_ri_rand
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def sample(self, x_i, x_r, x_s):
        self.eval()
        s_r = Variable(self.s_r)
        s_s = Variable(self.s_s)
        x_i_recon, x_r_recon, x_s_recon, x_rs, x_ri, x_sr, x_si = [], [], [], [], [], [], []
        for i in range(x_i.size(0)):
            c_i, s_i_fake = self.gen_i.encode(x_i[i].unsqueeze(0))
            c_r, s_r_fake = self.gen_r.encode(x_r[i].unsqueeze(0))
            c_s, s_s_fake = self.gen_s.encode(x_s[i].unsqueeze(0))
            if self.with_mapping:
                s_ri, s_si = self.fea_s(s_i_fake)
            else:
                s_ri = Variable(torch.randn(1, self.style_dim, 1, 1).cuda()) + self.bias_shift
                s_si = Variable(torch.randn(1, self.style_dim, 1, 1).cuda()) - self.bias_shift
            x_i_recon.append(self.gen_i.decode(c_i, s_i_fake))
            x_r_recon.append(self.gen_r.decode(c_r, s_r_fake))
            x_s_recon.append(self.gen_s.decode(c_s, s_s_fake))
            x_rs.append(self.gen_r.decode(c_s, s_r[i].unsqueeze(0)))
            x_ri.append(self.gen_r.decode(c_i, s_ri))
            x_sr.append(self.gen_s.decode(c_r, s_s[i].unsqueeze(0)))
            x_si.append(self.gen_s.decode(c_i, s_si))
        x_i_recon, x_r_recon, x_s_recon = torch.cat(x_i_recon), torch.cat(x_r_recon), torch.cat(x_s_recon)
        x_rs, x_ri = torch.cat(x_rs), torch.cat(x_ri)
        x_sr, x_si = torch.cat(x_sr), torch.cat(x_si)
        self.train()
        return x_i, x_i_recon, x_r, x_r_recon, x_rs, x_ri, x_s, x_s_recon, x_sr, x_si

    # noinspection PyAttributeOutsideInit
    def dis_update(self, x_i, x_r, x_s, params):
        self.dis_opt.zero_grad()
        s_r = Variable(torch.randn(x_r.size(0), self.style_dim, 1, 1).cuda()) + self.bias_shift
        s_s = Variable(torch.randn(x_s.size(0), self.style_dim, 1, 1).cuda()) - self.bias_shift
        # encode
        c_r, _ = self.gen_r.encode(x_r)
        c_s, _ = self.gen_s.encode(x_s)
        c_i, s_i = self.gen_i.encode(x_i)
        if self.with_mapping:
            s_ri, s_si = self.fea_s(s_i)
        else:
            s_ri = Variable(torch.randn(x_i.size(0), self.style_dim, 1, 1).cuda()) + self.bias_shift
            s_si = Variable(torch.randn(x_i.size(0), self.style_dim, 1, 1).cuda()) - self.bias_shift
        # decode (cross domain)
        x_rs = self.gen_r.decode(c_s, s_r)
        x_ri = self.gen_r.decode(c_i, s_ri)
        x_sr = self.gen_s.decode(c_r, s_s)
        x_si = self.gen_s.decode(c_i, s_si)
        # D loss
        self.loss_dis_rs = self.dis_r.calc_dis_loss(x_rs.detach(), x_r)
        self.loss_dis_ri = self.dis_r.calc_dis_loss(x_ri.detach(), x_r)
        self.loss_dis_sr = self.dis_s.calc_dis_loss(x_sr.detach(), x_s)
        self.loss_dis_si = self.dis_s.calc_dis_loss(x_si.detach(), x_s)
        self.loss_dis_total = params['gan_w'] * self.loss_dis_rs + \
                              params['gan_w'] * self.loss_dis_ri + \
                              params['gan_w'] * self.loss_dis_sr + \
                              params['gan_w'] * self.loss_dis_si
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, param):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_i.load_state_dict(state_dict['i'])
        self.gen_r.load_state_dict(state_dict['r'])
        self.gen_s.load_state_dict(state_dict['s'])
        if self.with_mapping:
            self.fea_m.load_state_dict(state_dict['fm'])
            self.fea_s.load_state_dict(state_dict['fs'])
        self.best_result = state_dict['best_result']
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_r.load_state_dict(state_dict['r'])
        self.dis_s.load_state_dict(state_dict['s'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, param, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, param, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        if self.with_mapping:
            torch.save({'i': self.gen_i.state_dict(), 'r': self.gen_r.state_dict(), 's': self.gen_s.state_dict(),
                        'fs': self.fea_s.state_dict(), 'fm': self.fea_m.state_dict(),
                        'best_result': self.best_result}, gen_name)
        else:
            torch.save({'i': self.gen_i.state_dict(), 'r': self.gen_r.state_dict(), 's': self.gen_s.state_dict(),
                        'best_result': self.best_result}, gen_name)
        torch.save({'r': self.dis_r.state_dict(), 's': self.dis_s.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
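
# The physical loss in UnsupIntrinsicTrainer penalises deviation from the Lambertian
# image-formation model I = R * S (element-wise). A self-contained sketch of the same
# computation with dummy tensors standing in for the trainer's predictions (not repo code):
def _physical_consistency_demo():
    import torch
    x_i = torch.rand(2, 3, 64, 64)   # input image I
    x_r = torch.rand(2, 3, 64, 64)   # predicted reflectance R
    x_s = torch.rand(2, 3, 64, 64)   # predicted shading S
    # same form as UnsupIntrinsicTrainer.physical_criterion: L1 distance between I and R * S
    loss = torch.mean(torch.abs(x_i - x_r * x_s))
    return loss
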
class DGNet_Trainer(nn.Module): def __init__(self, hyperparameters, gpu_ids=[0]): super(DGNet_Trainer, self).__init__() lr_g = hyperparameters['lr_g'] lr_d = hyperparameters['lr_d'] ID_class = hyperparameters['ID_class'] if not 'apex' in hyperparameters.keys(): hyperparameters['apex'] = False self.fp16 = hyperparameters['apex'] # Initiate the networks # We do not need to manually set fp16 in the network for the new apex. So here I set fp16=False. self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'], fp16=False) # auto-encoder for domain a self.gen_b = self.gen_a # auto-encoder for domain b if not 'ID_stride' in hyperparameters.keys(): hyperparameters['ID_stride'] = 2 if hyperparameters['ID_style'] == 'PCB': self.id_a = PCB(ID_class) elif hyperparameters['ID_style'] == 'AB': self.id_a = ft_netAB(ID_class, stride=hyperparameters['ID_stride'], norm=hyperparameters['norm_id'], pool=hyperparameters['pool']) else: self.id_a = ft_net(ID_class, norm=hyperparameters['norm_id'], pool=hyperparameters['pool']) # return 2048 now self.id_b = self.id_a self.dis_a = MsImageDis(3, hyperparameters['dis'], fp16=False) # discriminator for domain a self.dis_b = self.dis_a # discriminator for domain b # load teachers if hyperparameters['teacher'] != "": teacher_name = hyperparameters['teacher'] print(teacher_name) teacher_names = teacher_name.split(',') teacher_model = nn.ModuleList() teacher_count = 0 for teacher_name in teacher_names: config_tmp = load_config(teacher_name) if 'stride' in config_tmp: stride = config_tmp['stride'] else: stride = 2 model_tmp = ft_net(ID_class, stride=stride) teacher_model_tmp = load_network(model_tmp, teacher_name) teacher_model_tmp.model.fc = nn.Sequential( ) # remove the original fc layer in ImageNet teacher_model_tmp = teacher_model_tmp.cuda() if self.fp16: teacher_model_tmp = amp.initialize(teacher_model_tmp, opt_level="O1") teacher_model.append(teacher_model_tmp.cuda().eval()) teacher_count += 1 self.teacher_model = teacher_model if hyperparameters['train_bn']: self.teacher_model = self.teacher_model.apply(train_bn) self.instancenorm = nn.InstanceNorm2d(512, affine=False) # RGB to one channel if hyperparameters['single'] == 'edge': self.single = to_edge else: self.single = to_gray(False) # Random Erasing when training if not 'erasing_p' in hyperparameters.keys(): hyperparameters['erasing_p'] = 0 self.single_re = RandomErasing( probability=hyperparameters['erasing_p'], mean=[0.0, 0.0, 0.0]) if not 'T_w' in hyperparameters.keys(): hyperparameters['T_w'] = 1 # Setup the optimizers beta1 = hyperparameters['beta1'] beta2 = hyperparameters['beta2'] dis_params = list( self.dis_a.parameters()) #+ list(self.dis_b.parameters()) gen_params = list( self.gen_a.parameters()) #+ list(self.gen_b.parameters()) self.dis_opt = torch.optim.Adam( [p for p in dis_params if p.requires_grad], lr=lr_d, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) self.gen_opt = torch.optim.Adam( [p for p in gen_params if p.requires_grad], lr=lr_g, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) # id params if hyperparameters['ID_style'] == 'PCB': ignored_params = ( list(map(id, self.id_a.classifier0.parameters())) + list(map(id, self.id_a.classifier1.parameters())) + list(map(id, self.id_a.classifier2.parameters())) + list(map(id, self.id_a.classifier3.parameters()))) base_params = filter(lambda p: id(p) not in ignored_params, self.id_a.parameters()) lr2 = hyperparameters['lr2'] self.id_opt = torch.optim.SGD( [{ 'params': base_params, 'lr': lr2 }, { 
'params': self.id_a.classifier0.parameters(), 'lr': lr2 * 10 }, { 'params': self.id_a.classifier1.parameters(), 'lr': lr2 * 10 }, { 'params': self.id_a.classifier2.parameters(), 'lr': lr2 * 10 }, { 'params': self.id_a.classifier3.parameters(), 'lr': lr2 * 10 }], weight_decay=hyperparameters['weight_decay'], momentum=0.9, nesterov=True) elif hyperparameters['ID_style'] == 'AB': ignored_params = ( list(map(id, self.id_a.classifier1.parameters())) + list(map(id, self.id_a.classifier2.parameters()))) base_params = filter(lambda p: id(p) not in ignored_params, self.id_a.parameters()) lr2 = hyperparameters['lr2'] self.id_opt = torch.optim.SGD( [{ 'params': base_params, 'lr': lr2 }, { 'params': self.id_a.classifier1.parameters(), 'lr': lr2 * 10 }, { 'params': self.id_a.classifier2.parameters(), 'lr': lr2 * 10 }], weight_decay=hyperparameters['weight_decay'], momentum=0.9, nesterov=True) else: ignored_params = list(map(id, self.id_a.classifier.parameters())) base_params = filter(lambda p: id(p) not in ignored_params, self.id_a.parameters()) lr2 = hyperparameters['lr2'] self.id_opt = torch.optim.SGD( [{ 'params': base_params, 'lr': lr2 }, { 'params': self.id_a.classifier.parameters(), 'lr': lr2 * 10 }], weight_decay=hyperparameters['weight_decay'], momentum=0.9, nesterov=True) self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters) self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters) self.id_scheduler = get_scheduler(self.id_opt, hyperparameters) self.id_scheduler.gamma = hyperparameters['gamma2'] #ID Loss self.id_criterion = nn.CrossEntropyLoss() self.criterion_teacher = nn.KLDivLoss(size_average=False) # Load VGG model if needed if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0: self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models') self.vgg.eval() for param in self.vgg.parameters(): param.requires_grad = False # save memory if self.fp16: # Name the FP16_Optimizer instance to replace the existing optimizer assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled." 
self.gen_a = self.gen_a.cuda() self.dis_a = self.dis_a.cuda() self.id_a = self.id_a.cuda() self.gen_b = self.gen_a self.dis_b = self.dis_a self.id_b = self.id_a self.gen_a, self.gen_opt = amp.initialize(self.gen_a, self.gen_opt, opt_level="O1") self.dis_a, self.dis_opt = amp.initialize(self.dis_a, self.dis_opt, opt_level="O1") self.id_a, self.id_opt = amp.initialize(self.id_a, self.id_opt, opt_level="O1") def to_re(self, x): out = torch.FloatTensor(x.size(0), x.size(1), x.size(2), x.size(3)) out = out.cuda() for i in range(x.size(0)): out[i, :, :, :] = self.single_re(x[i, :, :, :]) return out def recon_criterion(self, input, target): diff = input - target.detach() return torch.mean(torch.abs(diff[:])) def recon_criterion_sqrt(self, input, target): diff = input - target return torch.mean(torch.sqrt(torch.abs(diff[:]) + 1e-8)) def recon_criterion2(self, input, target): diff = input - target return torch.mean(diff[:]**2) def recon_cos(self, input, target): cos = torch.nn.CosineSimilarity() cos_dis = 1 - cos(input, target) return torch.mean(cos_dis[:]) def forward(self, x_a, x_b): self.eval() s_a = self.gen_a.encode(self.single(x_a)) s_b = self.gen_b.encode(self.single(x_b)) f_a, _ = self.id_a(scale2(x_a)) f_b, _ = self.id_b(scale2(x_b)) x_ba = self.gen_a.decode(s_b, f_a) x_ab = self.gen_b.decode(s_a, f_b) self.train() return x_ab, x_ba def gen_update(self, x_a, l_a, xp_a, x_b, l_b, xp_b, hyperparameters, iteration): # ppa, ppb is the same person self.gen_opt.zero_grad() self.id_opt.zero_grad() # encode s_a = self.gen_a.encode(self.single(x_a)) s_b = self.gen_b.encode(self.single(x_b)) f_a, p_a = self.id_a(scale2(x_a)) f_b, p_b = self.id_b(scale2(x_b)) # autodecode x_a_recon = self.gen_a.decode(s_a, f_a) x_b_recon = self.gen_b.decode(s_b, f_b) # encode the same ID different photo fp_a, pp_a = self.id_a(scale2(xp_a)) fp_b, pp_b = self.id_b(scale2(xp_b)) # decode the same person x_a_recon_p = self.gen_a.decode(s_a, fp_a) x_b_recon_p = self.gen_b.decode(s_b, fp_b) # has gradient x_ba = self.gen_a.decode(s_b, f_a) x_ab = self.gen_b.decode(s_a, f_b) # no gradient x_ba_copy = Variable(x_ba.data, requires_grad=False) x_ab_copy = Variable(x_ab.data, requires_grad=False) rand_num = random.uniform(0, 1) ################################# # encode structure if hyperparameters['use_encoder_again'] >= rand_num: # encode again (encoder is tuned, input is fixed) s_a_recon = self.gen_b.enc_content(self.single(x_ab_copy)) s_b_recon = self.gen_a.enc_content(self.single(x_ba_copy)) else: # copy the encoder self.enc_content_copy = copy.deepcopy(self.gen_a.enc_content) self.enc_content_copy = self.enc_content_copy.eval() # encode again (encoder is fixed, input is tuned) s_a_recon = self.enc_content_copy(self.single(x_ab)) s_b_recon = self.enc_content_copy(self.single(x_ba)) ################################# # encode appearance self.id_a_copy = copy.deepcopy(self.id_a) self.id_a_copy = self.id_a_copy.eval() if hyperparameters['train_bn']: self.id_a_copy = self.id_a_copy.apply(train_bn) self.id_b_copy = self.id_a_copy # encode again (encoder is fixed, input is tuned) f_a_recon, p_a_recon = self.id_a_copy(scale2(x_ba)) f_b_recon, p_b_recon = self.id_b_copy(scale2(x_ab)) # teacher Loss # Tune the ID model log_sm = nn.LogSoftmax(dim=1) if hyperparameters['teacher_w'] > 0 and hyperparameters[ 'teacher'] != "": if hyperparameters['ID_style'] == 'normal': _, p_a_student = self.id_a(scale2(x_ba_copy)) p_a_student = log_sm(p_a_student) p_a_teacher = predict_label( self.teacher_model, scale2(x_ba_copy), 
num_class=hyperparameters['ID_class'], alabel=l_a, slabel=l_b, teacher_style=hyperparameters['teacher_style']) self.loss_teacher = self.criterion_teacher( p_a_student, p_a_teacher) / p_a_student.size(0) _, p_b_student = self.id_b(scale2(x_ab_copy)) p_b_student = log_sm(p_b_student) p_b_teacher = predict_label( self.teacher_model, scale2(x_ab_copy), num_class=hyperparameters['ID_class'], alabel=l_b, slabel=l_a, teacher_style=hyperparameters['teacher_style']) self.loss_teacher += self.criterion_teacher( p_b_student, p_b_teacher) / p_b_student.size(0) elif hyperparameters['ID_style'] == 'AB': # normal teacher-student loss # BA -> LabelA(smooth) + LabelB(batchB) _, p_ba_student = self.id_a(scale2(x_ba_copy)) # f_a, s_b p_a_student = log_sm(p_ba_student[0]) with torch.no_grad(): p_a_teacher = predict_label( self.teacher_model, scale2(x_ba_copy), num_class=hyperparameters['ID_class'], alabel=l_a, slabel=l_b, teacher_style=hyperparameters['teacher_style']) self.loss_teacher = self.criterion_teacher( p_a_student, p_a_teacher) / p_a_student.size(0) _, p_ab_student = self.id_b(scale2(x_ab_copy)) # f_b, s_a p_b_student = log_sm(p_ab_student[0]) with torch.no_grad(): p_b_teacher = predict_label( self.teacher_model, scale2(x_ab_copy), num_class=hyperparameters['ID_class'], alabel=l_b, slabel=l_a, teacher_style=hyperparameters['teacher_style']) self.loss_teacher += self.criterion_teacher( p_b_student, p_b_teacher) / p_b_student.size(0) # branch b loss # here we give different label loss_B = self.id_criterion(p_ba_student[1], l_b) + self.id_criterion( p_ab_student[1], l_a) self.loss_teacher = hyperparameters[ 'T_w'] * self.loss_teacher + hyperparameters['B_w'] * loss_B else: self.loss_teacher = 0.0 # decode again (if needed) if hyperparameters['use_decoder_again']: x_aba = self.gen_a.decode( s_a_recon, f_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None x_bab = self.gen_b.decode( s_b_recon, f_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None else: self.mlp_w_copy = copy.deepcopy(self.gen_a.mlp_w) self.mlp_b_copy = copy.deepcopy(self.gen_a.mlp_b) self.dec_copy = copy.deepcopy(self.gen_a.dec) # Error ID = f_a_recon ID_Style = ID.view(ID.shape[0], ID.shape[1], 1, 1) adain_params_w = self.mlp_w_copy(ID_Style) adain_params_b = self.mlp_b_copy(ID_Style) self.gen_a.assign_adain_params(adain_params_w, adain_params_b, self.dec_copy) x_aba = self.dec_copy( s_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None ID = f_b_recon ID_Style = ID.view(ID.shape[0], ID.shape[1], 1, 1) adain_params_w = self.mlp_w_copy(ID_Style) adain_params_b = self.mlp_b_copy(ID_Style) self.gen_a.assign_adain_params(adain_params_w, adain_params_b, self.dec_copy) x_bab = self.dec_copy( s_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None # auto-encoder image reconstruction self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a) self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b) self.loss_gen_recon_xp_a = self.recon_criterion(x_a_recon_p, x_a) self.loss_gen_recon_xp_b = self.recon_criterion(x_b_recon_p, x_b) # feature reconstruction self.loss_gen_recon_s_a = self.recon_criterion( s_a_recon, s_a) if hyperparameters['recon_s_w'] > 0 else 0 self.loss_gen_recon_s_b = self.recon_criterion( s_b_recon, s_b) if hyperparameters['recon_s_w'] > 0 else 0 self.loss_gen_recon_f_a = self.recon_criterion( f_a_recon, f_a) if hyperparameters['recon_f_w'] > 0 else 0 self.loss_gen_recon_f_b = self.recon_criterion( f_b_recon, f_b) if hyperparameters['recon_f_w'] > 0 else 0 # Random Erasing only effect the ID and 
PID loss. if hyperparameters['erasing_p'] > 0: x_a_re = self.to_re(scale2(x_a.clone())) x_b_re = self.to_re(scale2(x_b.clone())) xp_a_re = self.to_re(scale2(xp_a.clone())) xp_b_re = self.to_re(scale2(xp_b.clone())) _, p_a = self.id_a(x_a_re) _, p_b = self.id_b(x_b_re) # encode the same ID different photo _, pp_a = self.id_a(xp_a_re) _, pp_b = self.id_b(xp_b_re) # ID loss AND Tune the Generated image if hyperparameters['ID_style'] == 'PCB': self.loss_id = self.PCB_loss(p_a, l_a) + self.PCB_loss(p_b, l_b) self.loss_pid = self.PCB_loss(pp_a, l_a) + self.PCB_loss(pp_b, l_b) self.loss_gen_recon_id = self.PCB_loss( p_a_recon, l_a) + self.PCB_loss(p_b_recon, l_b) elif hyperparameters['ID_style'] == 'AB': weight_B = hyperparameters['teacher_w'] * hyperparameters['B_w'] self.loss_id = self.id_criterion(p_a[0], l_a) + self.id_criterion(p_b[0], l_b) \ + weight_B * ( self.id_criterion(p_a[1], l_a) + self.id_criterion(p_b[1], l_b) ) self.loss_pid = self.id_criterion(pp_a[0], l_a) + self.id_criterion( pp_b[0], l_b ) #+ weight_B * ( self.id_criterion(pp_a[1], l_a) + self.id_criterion(pp_b[1], l_b) ) self.loss_gen_recon_id = self.id_criterion( p_a_recon[0], l_a) + self.id_criterion(p_b_recon[0], l_b) else: self.loss_id = self.id_criterion(p_a, l_a) + self.id_criterion( p_b, l_b) self.loss_pid = self.id_criterion(pp_a, l_a) + self.id_criterion( pp_b, l_b) self.loss_gen_recon_id = self.id_criterion( p_a_recon, l_a) + self.id_criterion(p_b_recon, l_b) #print(f_a_recon, f_a) self.loss_gen_cycrecon_x_a = self.recon_criterion( x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0 self.loss_gen_cycrecon_x_b = self.recon_criterion( x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0 # GAN loss self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba) self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab) # domain-invariant perceptual loss self.loss_gen_vgg_a = self.compute_vgg_loss( self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0 self.loss_gen_vgg_b = self.compute_vgg_loss( self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0 if iteration > hyperparameters['warm_iter']: hyperparameters['recon_f_w'] += hyperparameters['warm_scale'] hyperparameters['recon_f_w'] = min(hyperparameters['recon_f_w'], hyperparameters['max_w']) hyperparameters['recon_s_w'] += hyperparameters['warm_scale'] hyperparameters['recon_s_w'] = min(hyperparameters['recon_s_w'], hyperparameters['max_w']) hyperparameters['recon_x_cyc_w'] += hyperparameters['warm_scale'] hyperparameters['recon_x_cyc_w'] = min( hyperparameters['recon_x_cyc_w'], hyperparameters['max_cyc_w']) if iteration > hyperparameters['warm_teacher_iter']: hyperparameters['teacher_w'] += hyperparameters['warm_scale'] hyperparameters['teacher_w'] = min( hyperparameters['teacher_w'], hyperparameters['max_teacher_w']) # total loss self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \ hyperparameters['gan_w'] * self.loss_gen_adv_b + \ hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \ hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_a + \ hyperparameters['recon_f_w'] * self.loss_gen_recon_f_a + \ hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \ hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \ hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_b + \ hyperparameters['recon_f_w'] * self.loss_gen_recon_f_b + \ hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \ hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \ hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \ 
hyperparameters['id_w'] * self.loss_id + \ hyperparameters['pid_w'] * self.loss_pid + \ hyperparameters['recon_id_w'] * self.loss_gen_recon_id + \ hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \ hyperparameters['vgg_w'] * self.loss_gen_vgg_b + \ hyperparameters['teacher_w'] * self.loss_teacher if self.fp16: with amp.scale_loss(self.loss_gen_total, [self.gen_opt, self.id_opt]) as scaled_loss: scaled_loss.backward() self.gen_opt.step() self.id_opt.step() else: self.loss_gen_total.backward() self.gen_opt.step() self.id_opt.step() print("L_total: %.4f, L_gan: %.4f, Lx: %.4f, Lxp: %.4f, Lrecycle:%.4f, Lf: %.4f, Ls: %.4f, Recon-id: %.4f, id: %.4f, pid:%.4f, teacher: %.4f"%( self.loss_gen_total, \ hyperparameters['gan_w'] * (self.loss_gen_adv_a + self.loss_gen_adv_b), \ hyperparameters['recon_x_w'] * (self.loss_gen_recon_x_a + self.loss_gen_recon_x_b), \ hyperparameters['recon_xp_w'] * (self.loss_gen_recon_xp_a + self.loss_gen_recon_xp_b), \ hyperparameters['recon_x_cyc_w'] * (self.loss_gen_cycrecon_x_a + self.loss_gen_cycrecon_x_b), \ hyperparameters['recon_f_w'] * (self.loss_gen_recon_f_a + self.loss_gen_recon_f_b), \ hyperparameters['recon_s_w'] * (self.loss_gen_recon_s_a + self.loss_gen_recon_s_b), \ hyperparameters['recon_id_w'] * self.loss_gen_recon_id, \ hyperparameters['id_w'] * self.loss_id,\ hyperparameters['pid_w'] * self.loss_pid,\ hyperparameters['teacher_w'] * self.loss_teacher ) ) def compute_vgg_loss(self, vgg, img, target): img_vgg = vgg_preprocess(img) target_vgg = vgg_preprocess(target) img_fea = vgg(img_vgg) target_fea = vgg(target_vgg) return torch.mean( (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2) def PCB_loss(self, inputs, labels): loss = 0.0 for part in inputs: loss += self.id_criterion(part, labels) return loss / len(inputs) def sample(self, x_a, x_b): self.eval() x_a_recon, x_b_recon, x_ba1, x_ab1, x_aba, x_bab = [], [], [], [], [], [] for i in range(x_a.size(0)): s_a = self.gen_a.encode(self.single(x_a[i].unsqueeze(0))) s_b = self.gen_b.encode(self.single(x_b[i].unsqueeze(0))) f_a, _ = self.id_a(scale2(x_a[i].unsqueeze(0))) f_b, _ = self.id_b(scale2(x_b[i].unsqueeze(0))) x_a_recon.append(self.gen_a.decode(s_a, f_a)) x_b_recon.append(self.gen_b.decode(s_b, f_b)) x_ba = self.gen_a.decode(s_b, f_a) x_ab = self.gen_b.decode(s_a, f_b) x_ba1.append(x_ba) x_ab1.append(x_ab) #cycle s_b_recon = self.gen_a.enc_content(self.single(x_ba)) s_a_recon = self.gen_b.enc_content(self.single(x_ab)) f_a_recon, _ = self.id_a(scale2(x_ba)) f_b_recon, _ = self.id_b(scale2(x_ab)) x_aba.append(self.gen_a.decode(s_a_recon, f_a_recon)) x_bab.append(self.gen_b.decode(s_b_recon, f_b_recon)) x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon) x_aba, x_bab = torch.cat(x_aba), torch.cat(x_bab) x_ba1, x_ab1 = torch.cat(x_ba1), torch.cat(x_ab1) self.train() return x_a, x_a_recon, x_aba, x_ab1, x_b, x_b_recon, x_bab, x_ba1 def dis_update(self, x_a, x_b, hyperparameters): self.dis_opt.zero_grad() # encode s_a = self.gen_a.encode(self.single(x_a)) s_b = self.gen_b.encode(self.single(x_b)) f_a, _ = self.id_a(scale2(x_a)) f_b, _ = self.id_b(scale2(x_b)) # decode (cross domain) x_ba = self.gen_a.decode(s_b, f_a) x_ab = self.gen_b.decode(s_a, f_b) # D loss self.loss_dis_a, reg_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a) self.loss_dis_b, reg_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b) self.loss_dis_total = hyperparameters[ 'gan_w'] * self.loss_dis_a + hyperparameters[ 'gan_w'] * self.loss_dis_b print("DLoss: %.4f" % self.loss_dis_total, "Reg: %.4f" % (reg_a + 
reg_b)) if self.fp16: with amp.scale_loss(self.loss_dis_total, self.dis_opt) as scaled_loss: scaled_loss.backward() else: self.loss_dis_total.backward() self.dis_opt.step() def update_learning_rate(self): if self.dis_scheduler is not None: self.dis_scheduler.step() if self.gen_scheduler is not None: self.gen_scheduler.step() if self.id_scheduler is not None: self.id_scheduler.step() def resume(self, checkpoint_dir, hyperparameters): # Load generators last_model_name = get_model_list(checkpoint_dir, "gen") state_dict = torch.load(last_model_name) self.gen_a.load_state_dict(state_dict['a']) self.gen_b = self.gen_a iterations = int(last_model_name[-11:-3]) # Load discriminators last_model_name = get_model_list(checkpoint_dir, "dis") state_dict = torch.load(last_model_name) self.dis_a.load_state_dict(state_dict['a']) self.dis_b = self.dis_a # Load ID dis last_model_name = get_model_list(checkpoint_dir, "id") state_dict = torch.load(last_model_name) self.id_a.load_state_dict(state_dict['a']) self.id_b = self.id_a # Load optimizers try: state_dict = torch.load( os.path.join(checkpoint_dir, 'optimizer.pt')) self.dis_opt.load_state_dict(state_dict['dis']) self.gen_opt.load_state_dict(state_dict['gen']) self.id_opt.load_state_dict(state_dict['id']) except: pass # Reinitilize schedulers self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations) self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations) print('Resume from iteration %d' % iterations) return iterations def save(self, snapshot_dir, iterations): # Save generators, discriminators, and optimizers gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1)) dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1)) id_name = os.path.join(snapshot_dir, 'id_%08d.pt' % (iterations + 1)) opt_name = os.path.join(snapshot_dir, 'optimizer.pt') torch.save({'a': self.gen_a.state_dict()}, gen_name) torch.save({'a': self.dis_a.state_dict()}, dis_name) torch.save({'a': self.id_a.state_dict()}, id_name) torch.save( { 'gen': self.gen_opt.state_dict(), 'id': self.id_opt.state_dict(), 'dis': self.dis_opt.state_dict() }, opt_name)
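
# Sketch of the teacher-student (distillation) term used in DGNet_Trainer.gen_update:
# KL divergence between the student's log-probabilities on a generated image and the
# teacher's probabilities, summed over classes and averaged over the batch. Dummy logits
# stand in for the id_a / predict_label outputs; this is an illustration, not repo code.
def _teacher_loss_demo():
    import torch
    import torch.nn as nn
    batch, num_class = 4, 751                           # num_class is an arbitrary example value
    student_logits = torch.randn(batch, num_class)      # stands in for p_*_student before log-softmax
    teacher_probs = torch.softmax(torch.randn(batch, num_class), dim=1)  # stands in for predict_label output
    log_sm = nn.LogSoftmax(dim=1)
    criterion_teacher = nn.KLDivLoss(reduction='sum')   # equivalent to size_average=False above
    loss = criterion_teacher(log_sm(student_logits), teacher_probs) / batch
    return loss
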
class MUNIT_Trainer(nn.Module): def __init__(self, hyperparameters): super(MUNIT_Trainer, self).__init__() lr = hyperparameters['lr'] # Initiate the networks self.gen_a = AdaINGen( hyperparameters['input_dim_a'], hyperparameters['gen']) # auto-encoder for domain a self.gen_b = AdaINGen( hyperparameters['input_dim_b'], hyperparameters['gen']) # auto-encoder for domain b self.dis_a = MsImageDis( hyperparameters['input_dim_a'], hyperparameters['dis']) # discriminator for domain a self.dis_b = MsImageDis( hyperparameters['input_dim_b'], hyperparameters['dis']) # discriminator for domain b self.instancenorm = nn.InstanceNorm2d(512, affine=False) self.style_dim = hyperparameters['gen']['style_dim'] # fix the noise used in sampling display_size = int(hyperparameters['display_size']) self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda() self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda() # Setup the optimizers beta1 = hyperparameters['beta1'] beta2 = hyperparameters['beta2'] dis_params = list(self.dis_a.parameters()) + list( self.dis_b.parameters()) gen_params = list(self.gen_a.parameters()) + list( self.gen_b.parameters()) self.dis_opt = torch.optim.Adam( [p for p in dis_params if p.requires_grad], lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) self.gen_opt = torch.optim.Adam( [p for p in gen_params if p.requires_grad], lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters) self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters) # Network weight initialization self.apply(weights_init(hyperparameters['init'])) self.dis_a.apply(weights_init('gaussian')) self.dis_b.apply(weights_init('gaussian')) # Load VGG model if needed if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0: self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models') self.vgg.eval() for param in self.vgg.parameters(): param.requires_grad = False def recon_criterion(self, input, target): return torch.mean(torch.abs(input - target)) def forward(self, x_a, x_b): self.eval() s_a = Variable(self.s_a) s_b = Variable(self.s_b) c_a, s_a_fake = self.gen_a.encode(x_a) c_b, s_b_fake = self.gen_b.encode(x_b) x_ba = self.gen_a.decode(c_b, s_a) x_ab = self.gen_b.decode(c_a, s_b) self.train() return x_ab, x_ba def gen_update(self, x_a, x_b, hyperparameters): self.gen_opt.zero_grad() s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()) s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()) # encode c_a, s_a_prime = self.gen_a.encode(x_a) c_b, s_b_prime = self.gen_b.encode(x_b) # decode (within domain) x_a_recon = self.gen_a.decode(c_a, s_a_prime) x_b_recon = self.gen_b.decode(c_b, s_b_prime) # decode (cross domain) x_ba = self.gen_a.decode(c_b, s_a) x_ab = self.gen_b.decode(c_a, s_b) # encode again c_b_recon, s_a_recon = self.gen_a.encode(x_ba) c_a_recon, s_b_recon = self.gen_b.encode(x_ab) # decode again (if needed) x_aba = self.gen_a.decode( c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None x_bab = self.gen_b.decode( c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None # reconstruction loss self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a) self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b) self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a) self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b) self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a) 
self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b) self.loss_gen_cycrecon_x_a = self.recon_criterion( x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0 self.loss_gen_cycrecon_x_b = self.recon_criterion( x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0 # GAN loss self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba) self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab) # domain-invariant perceptual loss self.loss_gen_vgg_a = self.compute_vgg_loss( self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0 self.loss_gen_vgg_b = self.compute_vgg_loss( self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0 # total loss self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \ hyperparameters['gan_w'] * self.loss_gen_adv_b + \ hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \ hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \ hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \ hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \ hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \ hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \ hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \ hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \ hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \ hyperparameters['vgg_w'] * self.loss_gen_vgg_b self.loss_gen_total.backward() self.gen_opt.step() def compute_vgg_loss(self, vgg, img, target): img_vgg = vgg_preprocess(img) target_vgg = vgg_preprocess(target) img_fea = vgg(img_vgg) target_fea = vgg(target_vgg) return torch.mean( (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2) def sample(self, x_a, x_b): self.eval() s_a1 = Variable(self.s_a) s_b1 = Variable(self.s_b) s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()) s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()) x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], [] for i in range(x_a.size(0)): c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0)) c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0)) x_a_recon.append(self.gen_a.decode(c_a, s_a_fake)) x_b_recon.append(self.gen_b.decode(c_b, s_b_fake)) x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0))) x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0))) x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0))) x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0))) x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon) x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2) x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2) self.train() return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2 def dis_update(self, x_a, x_b, hyperparameters): self.dis_opt.zero_grad() s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()) s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()) # encode c_a, _ = self.gen_a.encode(x_a) c_b, _ = self.gen_b.encode(x_b) # decode (cross domain) x_ba = self.gen_a.decode(c_b, s_a) x_ab = self.gen_b.decode(c_a, s_b) # D loss self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a) self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b) self.loss_dis_total = hyperparameters[ 'gan_w'] * self.loss_dis_a + hyperparameters[ 'gan_w'] * self.loss_dis_b self.loss_dis_total.backward() self.dis_opt.step() def update_learning_rate(self): if self.dis_scheduler is not None: self.dis_scheduler.step() if self.gen_scheduler is not None: self.gen_scheduler.step() def resume(self, 
checkpoint_dir, hyperparameters): # Load generators last_model_name = get_model_list(checkpoint_dir, "gen") state_dict = torch.load(last_model_name) self.gen_a.load_state_dict(state_dict['a']) self.gen_b.load_state_dict(state_dict['b']) iterations = int(last_model_name[-11:-3]) # Load discriminators last_model_name = get_model_list(checkpoint_dir, "dis") state_dict = torch.load(last_model_name) self.dis_a.load_state_dict(state_dict['a']) self.dis_b.load_state_dict(state_dict['b']) # Load optimizers state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt')) self.dis_opt.load_state_dict(state_dict['dis']) self.gen_opt.load_state_dict(state_dict['gen']) # Reinitilize schedulers self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations) self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations) print('Resume from iteration %d' % iterations) return iterations def snap_clean(self, snap_dir, iterations, save_last=10000, period=20000): # Cleaning snapshot directory from old files if not os.path.exists(snap_dir): return None gen_models = [ os.path.join(snap_dir, f) for f in os.listdir(snap_dir) if "gen" in f and ".pt" in f ] dis_models = [ os.path.join(snap_dir, f) for f in os.listdir(snap_dir) if "dis" in f and ".pt" in f ] gen_models.sort() dis_models.sort() marked_clean = [] for i, model in enumerate(gen_models): m_iter = int(model[-11:-3]) if i == 0: m_prev = 0 continue if m_iter > iterations - save_last: break if m_iter - m_prev < period: marked_clean.append(model) while m_iter - m_prev >= period: m_prev += period for i, model in enumerate(dis_models): m_iter = int(model[-11:-3]) if i == 0: m_prev = 0 continue if m_iter > iterations - save_last: break if m_iter - m_prev < period: marked_clean.append(model) while m_iter - m_prev >= period: m_prev += period print(f'Cleaning snapshots: {marked_clean}') for f in marked_clean: os.remove(f) def save(self, snapshot_dir, iterations, smart_override): # Save generators, discriminators, and optimizers gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1)) dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1)) opt_name = os.path.join(snapshot_dir, 'optimizer.pt') torch.save({ 'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict() }, gen_name) torch.save({ 'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict() }, dis_name) torch.save( { 'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict() }, opt_name) if smart_override: self.snap_clean(snapshot_dir, iterations + 1)
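
# The resume logic above recovers the iteration count from the checkpoint file name:
# the eight zero-padded digits before ".pt" written by save() as 'gen_%08d.pt'.
# A tiny standalone illustration with a placeholder path (not repo code):
def _iteration_from_checkpoint_name():
    name = 'outputs/snapshots/gen_00012000.pt'
    iterations = int(name[-11:-3])   # slices the eight digits before '.pt' -> 12000
    assert iterations == 12000
    return iterations
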
class MUNIT_Trainer(nn.Module): def __init__(self, hyperparameters): super(MUNIT_Trainer, self).__init__() lr = hyperparameters['lr'] # Initiate the networks self.gen_a = AdaINGen( hyperparameters['input_dim_a'], hyperparameters['gen']) # auto-encoder for domain a self.gen_b = AdaINGen( hyperparameters['input_dim_b'], hyperparameters['gen']) # auto-encoder for domain b self.dis_a = MsImageDis( hyperparameters['input_dim_a'], hyperparameters['dis']) # discriminator for domain a self.dis_b = MsImageDis( hyperparameters['input_dim_b'], hyperparameters['dis']) # discriminator for domain b self.instancenorm = nn.InstanceNorm2d(512, affine=False) self.style_dim = hyperparameters['gen']['style_dim'] # fix the noise used in sampling display_size = int(hyperparameters['display_size']) self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda() self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda() # Setup the optimizers beta1 = hyperparameters['beta1'] beta2 = hyperparameters['beta2'] dis_params = list(self.dis_a.parameters()) + list( self.dis_b.parameters()) gen_params = list(self.gen_a.parameters()) + list( self.gen_b.parameters()) self.dis_opt = torch.optim.Adam( [p for p in dis_params if p.requires_grad], lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) self.gen_opt = torch.optim.Adam( [p for p in gen_params if p.requires_grad], lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters) self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters) # Network weight initialization self.apply(weights_init(hyperparameters['init'])) self.dis_a.apply(weights_init('gaussian')) self.dis_b.apply(weights_init('gaussian')) def recon_criterion(self, input, target): return torch.mean(torch.abs(input - target)) def forward(self, x_a, x_b): self.eval() s_a = Variable(self.s_a) s_b = Variable(self.s_b) c_a, s_a_fake = self.gen_a.encode(x_a) c_b, s_b_fake = self.gen_b.encode(x_b) x_ba = self.gen_a.decode(c_b, s_a) x_ab = self.gen_b.decode(c_a, s_b) self.train() return x_ab, x_ba def gen_update(self, x_a, x_b, hyperparameters): self.gen_opt.zero_grad() s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()) s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()) # encode c_a, s_a_prime = self.gen_a.encode(x_a) c_b, s_b_prime = self.gen_b.encode(x_b) # decode (within domain) x_a_recon = self.gen_a.decode(c_a, s_a_prime) x_b_recon = self.gen_b.decode(c_b, s_b_prime) # decode (cross domain) x_ba = self.gen_a.decode(c_b, s_a) x_ab = self.gen_b.decode(c_a, s_b) # encode again c_b_recon, s_a_recon = self.gen_a.encode(x_ba) c_a_recon, s_b_recon = self.gen_b.encode(x_ab) # decode again (if needed) x_aba = self.gen_a.decode( c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None x_bab = self.gen_b.decode( c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None # reconstruction loss self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a) self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b) self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a) self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b) self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a) self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b) self.loss_gen_cycrecon_x_a = self.recon_criterion( x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0 self.loss_gen_cycrecon_x_b = self.recon_criterion( x_bab, x_b) if 
hyperparameters['recon_x_cyc_w'] > 0 else 0 # GAN loss self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba) self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab) # total loss self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \ hyperparameters['gan_w'] * self.loss_gen_adv_b + \ hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \ hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \ hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \ hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \ hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \ hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \ hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \ hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b self.loss_gen_total.backward() self.gen_opt.step() def sample(self, x_a, x_b): self.eval() s_a1 = Variable(self.s_a) s_b1 = Variable(self.s_b) s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()) s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()) x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], [] for i in range(x_a.size(0)): c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0)) c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0)) x_a_recon.append(self.gen_a.decode(c_a, s_a_fake)) x_b_recon.append(self.gen_b.decode(c_b, s_b_fake)) x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0))) x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0))) x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0))) x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0))) x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon) x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2) x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2) self.train() return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2 def dis_update(self, x_a, x_b, hyperparameters): self.dis_opt.zero_grad() s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()) s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()) # encode c_a, _ = self.gen_a.encode(x_a) c_b, _ = self.gen_b.encode(x_b) # decode (cross domain) x_ba = self.gen_a.decode(c_b, s_a) x_ab = self.gen_b.decode(c_a, s_b) # D loss self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a) self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b) self.loss_dis_total = hyperparameters[ 'gan_w'] * self.loss_dis_a + hyperparameters[ 'gan_w'] * self.loss_dis_b self.loss_dis_total.backward() self.dis_opt.step() def update_learning_rate(self): if self.dis_scheduler is not None: self.dis_scheduler.step() if self.gen_scheduler is not None: self.gen_scheduler.step() def resume(self, checkpoint_dir, hyperparameters): # Load generators last_model_name = get_model_list(checkpoint_dir, "gen") state_dict = torch.load(last_model_name) self.gen_a.load_state_dict(state_dict['a']) self.gen_b.load_state_dict(state_dict['b']) iterations = int(last_model_name[-11:-3]) # Load discriminators last_model_name = get_model_list(checkpoint_dir, "dis") state_dict = torch.load(last_model_name) self.dis_a.load_state_dict(state_dict['a']) self.dis_b.load_state_dict(state_dict['b']) # Load optimizers state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt')) self.dis_opt.load_state_dict(state_dict['dis']) self.gen_opt.load_state_dict(state_dict['gen']) # Reinitilize schedulers self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations) self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, 
iterations) print('Resume from iteration %d' % iterations) return iterations def save(self, snapshot_dir, iterations): # Save generators, discriminators, and optimizers gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1)) dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1)) opt_name = os.path.join(snapshot_dir, 'optimizer.pt') torch.save({ 'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict() }, gen_name) torch.save({ 'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict() }, dis_name) torch.save( { 'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict() }, opt_name)
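# A minimal training-loop sketch for the MUNIT_Trainer above, assuming `config` is the same
# hyperparameter dict read in __init__, 'snapshot_save_iter' is an additional (assumed) config
# key, and `loader_a` / `loader_b` are hypothetical DataLoaders yielding image batches.
def train_sketch(trainer, loader_a, loader_b, config, snapshot_dir, max_iter):
    import itertools
    trainer.cuda()
    batches = zip(itertools.cycle(loader_a), itertools.cycle(loader_b))
    for it, (x_a, x_b) in enumerate(batches):
        x_a, x_b = x_a.cuda(), x_b.cuda()
        trainer.dis_update(x_a, x_b, config)  # discriminator step
        trainer.gen_update(x_a, x_b, config)  # generator step
        trainer.update_learning_rate()        # step both LR schedulers
        if (it + 1) % config['snapshot_save_iter'] == 0:
            trainer.save(snapshot_dir, it)
        if it + 1 >= max_iter:
            break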
class MUNIT_Trainer(nn.Module): def __init__(self, hyperparameters): super(MUNIT_Trainer, self).__init__() lr = hyperparameters['lr'] # Initiate the networks self.gen_a = AdaINGen( hyperparameters['input_dim_a'], hyperparameters['gen']) # auto-encoder for domain a self.gen_b = AdaINGen( hyperparameters['input_dim_b'], hyperparameters['gen']) # auto-encoder for domain b self.dis_a = MsImageDis( hyperparameters['input_dim_a'], hyperparameters['dis']) # discriminator for domain a self.dis_b = MsImageDis( hyperparameters['input_dim_b'], hyperparameters['dis']) # discriminator for domain b self.instancenorm = nn.InstanceNorm2d(512, affine=False) self.style_dim = hyperparameters['gen']['style_dim'] # fix the noise used in sampling display_size = int(hyperparameters['display_size']) self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda() self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda() # Setup the optimizers beta1 = hyperparameters['beta1'] beta2 = hyperparameters['beta2'] dis_params = list(self.dis_a.parameters()) + list( self.dis_b.parameters()) gen_params = list(self.gen_a.parameters()) + list( self.gen_b.parameters()) self.dis_opt = torch.optim.Adam( [p for p in dis_params if p.requires_grad], lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) self.gen_opt = torch.optim.Adam( [p for p in gen_params if p.requires_grad], lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters) self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters) # Network weight initialization self.apply(weights_init(hyperparameters['init'])) self.dis_a.apply(weights_init('gaussian')) self.dis_b.apply(weights_init('gaussian')) # Load VGG model if needed if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0: self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models') self.vgg.eval() for param in self.vgg.parameters(): param.requires_grad = False def recon_criterion(self, input, target): return torch.mean(torch.abs(input - target)) def forward(self, x_a, x_b): self.eval() s_a = Variable(self.s_a) s_b = Variable(self.s_b) c_a, s_a_fake = self.gen_a.encode(x_a) c_b, s_b_fake = self.gen_b.encode(x_b) x_ba = self.gen_a.decode(c_b, s_a) x_ab = self.gen_b.decode(c_a, s_b) self.train() return x_ab, x_ba def gen_update(self, x_a, m_A, x_b, m_B, hyperparameters): self.gen_opt.zero_grad() im_A = 1 - m_A im_B = 1 - m_B # encode c_a, s_bA = self.gen_a.encode(x_a, im_A) c_b, s_fB = self.gen_b.encode(x_b, m_B) _, s_fA = self.gen_a.encode(x_a, m_A) _, s_bB = self.gen_b.encode(x_b, im_B) # decode (cross domain) x_ba = self.gen_a.decode(c_b, s_fA, m_B, s_bB) x_ab = self.gen_b.decode(c_a, s_fB, m_A, s_bA) # decode (within domain) x_aa = self.gen_a.decode(c_a, s_fA, m_A, s_bA) x_bb = self.gen_b.decode(c_b, s_fB, m_B, s_bB) # encode again c_ba, s_fBA = self.gen_a.encode(x_ba, m_B) c_ab, s_fAB = self.gen_a.encode(x_ab, m_A) _, s_bBA = self.gen_a.encode(x_ba, im_B) _, s_bAB = self.gen_a.encode(x_ab, im_A) # decode again (if needed) x_aba = self.gen_a.decode( c_ab, s_fBA, m_A, s_bAB) if hyperparameters['recon_x_cyc_w'] > 0 else None x_bab = self.gen_b.decode( c_ba, s_fAB, m_B, s_bBA) if hyperparameters['recon_x_cyc_w'] > 0 else None self.loss_gen_recon_c_a = self.recon_criterion(c_ab, c_a) self.loss_gen_recon_c_b = self.recon_criterion(c_ba, c_b) self.loss_gen_recon_s_a = self.recon_criterion(s_bAB, s_bA) self.loss_gen_recon_s_b = self.recon_criterion(s_bBA, s_bB) 
self.loss_gen_recon_s_af = self.recon_criterion(s_fAB, s_fB) self.loss_gen_recon_s_bf = self.recon_criterion(s_fBA, s_fA) self.loss_gen_recon_x_a = self.recon_criterion(x_aa, x_a) self.loss_gen_recon_x_b = self.recon_criterion(x_bb, x_b) self.loss_gen_cycrecon_x_a = self.recon_criterion( im_A * x_aba, im_A * x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0 self.loss_gen_cycrecon_x_b = self.recon_criterion( m_B * x_bab, m_B * x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0 # GAN loss self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba) self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab) # domain-invariant perceptual loss self.loss_gen_vgg_a = self.compute_vgg_loss( self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0 self.loss_gen_vgg_b = self.compute_vgg_loss( self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0 # total loss self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \ hyperparameters['gan_w'] * self.loss_gen_adv_b + \ hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \ hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \ hyperparameters['recon_s_w'] * self.loss_gen_recon_s_af + \ hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \ hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \ hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \ hyperparameters['recon_s_w'] * self.loss_gen_recon_s_bf + \ hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \ hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \ hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \ hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \ hyperparameters['vgg_w'] * self.loss_gen_vgg_b self.loss_gen_total.backward() self.gen_opt.step() def compute_vgg_loss(self, vgg, img, target): img_vgg = vgg_preprocess(img) target_vgg = vgg_preprocess(target) img_fea = vgg(img_vgg) target_fea = vgg(target_vgg) return torch.mean( (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2) def sample(self, loader_a, loader_b, size): self.eval() im_a = torch.stack([loader_a.dataset[i][0] for i in range(size)]).cuda() seg_a = torch.stack([loader_a.dataset[i][1] for i in range(size)]).cuda() im_b = torch.stack([loader_b.dataset[i][0] for i in range(size)]).cuda() seg_b = torch.stack([loader_b.dataset[i][1] for i in range(size)]).cuda() x_a_recon, x_b_recon, x_ba1, x_bm, x_ab1, x_am = [], [], [], [], [], [] for i in range(im_a.size(0)): mask_a = seg_a[i].unsqueeze(0) mask_b = seg_b[i].unsqueeze(0) x_a = im_a[i].unsqueeze(0) x_b = im_b[i].unsqueeze(0) masked_a = mask_a * x_a masked_b = mask_b * x_b c_a, s_bA = self.gen_a.encode(x_a, 1 - mask_a) c_b, s_fB = self.gen_b.encode(x_b, mask_b) c_a, s_fA = self.gen_a.encode(x_a, mask_a) c_b, s_bB = self.gen_b.encode(x_b, 1 - mask_b) # decode (cross domain) x_BA = self.gen_a.decode(c_b, s_fA, mask_b, s_bB) x_AB = self.gen_b.decode(c_a, s_fB, mask_a, s_bA) if 0 == i % 2: x_AB = (1 * (1 - mask_a) * x_a + (0 * (1 - mask_a) * x_AB)) + mask_a * x_AB x_BA = (1 * (1 - mask_b) * x_b + (0 * (1 - mask_b) * x_BA)) + mask_b * x_BA x_ba1.append(x_BA) x_ab1.append(x_AB) x_am.append(masked_a) x_bm.append(masked_b) # decode (within domain) x_A_recon = self.gen_a.decode(c_a, s_fA, mask_a, s_bA) x_B_recon = self.gen_b.decode(c_b, s_fB, mask_b, s_bB) x_a_recon.append(x_A_recon) x_b_recon.append(x_B_recon) x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon) x_ba1 = torch.cat(x_ba1) x_ab1 = torch.cat(x_ab1) x_bm = torch.cat(x_bm) x_am = torch.cat(x_am) self.train() return im_a, x_a_recon, x_ab1, 
x_am, im_b, x_b_recon, x_ba1, x_bm def dis_update(self, x_a, m_a, x_b, m_b, hyperparameters): self.dis_opt.zero_grad() s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()) s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()) # encode up_im_A = 1 - m_a #F.interpolate(1-m_a, None,1, 'bilinear', align_corners=False) up_m_B = m_b #F.interpolate(m_b, None, 1, 'bilinear', align_corners=False) up_m_A = m_a #F.interpolate(m_a, None, 1, 'bilinear', align_corners=False) up_im_B = 1 - m_b #.interpolate(1-m_b, None, 1, 'bilinear', align_corners=False) c_a, s_bA = self.gen_a.encode(x_a, up_im_A) c_b, s_fB = self.gen_b.encode(x_b, up_m_B) _, s_fA = self.gen_a.encode(x_a, up_m_A) _, s_bB = self.gen_b.encode(x_b, up_im_B) # decode (cross domain) x_ba = self.gen_a.decode(c_b, s_fA, m_b, s_bB) x_ab = self.gen_b.decode(c_a, s_fB, m_a, s_bA) # D loss self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a) self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b) self.loss_dis_total = hyperparameters[ 'gan_w'] * self.loss_dis_a + hyperparameters[ 'gan_w'] * self.loss_dis_b self.loss_dis_total.backward() self.dis_opt.step() def update_learning_rate(self): if self.dis_scheduler is not None: self.dis_scheduler.step() if self.gen_scheduler is not None: self.gen_scheduler.step() def resume(self, checkpoint_dir, hyperparameters): # Load generators last_model_name = get_model_list(checkpoint_dir, "gen") state_dict = torch.load(last_model_name) self.gen_a.load_state_dict(state_dict['a']) self.gen_b.load_state_dict(state_dict['b']) iterations = int(last_model_name[-11:-3]) # Load discriminators last_model_name = get_model_list(checkpoint_dir, "dis") state_dict = torch.load(last_model_name) self.dis_a.load_state_dict(state_dict['a']) self.dis_b.load_state_dict(state_dict['b']) # Load optimizers state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt')) self.dis_opt.load_state_dict(state_dict['dis']) self.gen_opt.load_state_dict(state_dict['gen']) # Reinitilize schedulers self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations) self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations) print('Resume from iteration %d' % iterations) return iterations def save(self, snapshot_dir, iterations): # Save generators, discriminators, and optimizers gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1)) dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1)) opt_name = os.path.join(snapshot_dir, 'optimizer.pt') torch.save({ 'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict() }, gen_name) torch.save({ 'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict() }, dis_name) torch.save( { 'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict() }, opt_name)
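# A one-step sketch for the mask-conditioned variant above, assuming each batch is an
# (image, mask) pair with a binary foreground mask; `trainer` and `config` are hypothetical
# placeholders. Foreground style is encoded under the mask m and background style under
# 1 - m, matching gen_update / dis_update above.
def masked_step_sketch(trainer, batch_a, batch_b, config):
    x_a, m_a = batch_a  # image and foreground mask, domain A
    x_b, m_b = batch_b  # image and foreground mask, domain B
    x_a, m_a, x_b, m_b = x_a.cuda(), m_a.cuda(), x_b.cuda(), m_b.cuda()
    trainer.dis_update(x_a, m_a, x_b, m_b, config)
    trainer.gen_update(x_a, m_a, x_b, m_b, config)
    trainer.update_learning_rate()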
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']
        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler, self.lr_policy = get_scheduler(
            self.dis_opt, hyperparameters)
        self.gen_scheduler, self.lr_policy = get_scheduler(
            self.gen_opt, hyperparameters)
        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))
        self.metric = 0
        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, s_a_fake = self.gen_a.encode(x_a)
        c_b, s_b_fake = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab, x_ba

    # takes two images, one from domain a and one from domain b
    def gen_update(self, x_a, x_b, hyperparameters, mask):
        self.guided = 0
        print(type(mask))
        self.gen_opt.zero_grad()
        # randomly sample style codes for a and b
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode: the generators split each image into content and style'
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain): the encoder followed by the decoder should
        # reconstruct the input
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain): combining content with the other domain's
        # style should give the translated result
        if self.guided == 0:
            x_ba = self.gen_a.decode(c_b, s_a)
            x_ab = self.gen_b.decode(c_a, s_b)
        elif self.guided == 1:
            x_ba = self.gen_a.decode(c_b, s_a_prime)
            x_ab = self.gen_b.decode(c_a, s_b_prime)
        # encode again: split into content and style once more
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(
            c_a_recon,
            s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(
            c_b_recon,
            s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        print(x_a_recon.size(), x_b_recon.size())
        # mask loss
        mask = torch.cat([mask, mask, mask], 1)
        self.loss_attentive = self.recon_criterion(x_a[mask == 1],
                                                   x_a_recon[mask == 1])
        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(
            self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(
            self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
            hyperparameters['gan_w'] * self.loss_gen_adv_b + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
            hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
            hyperparameters['vgg_w'] * self.loss_gen_vgg_b + \
            hyperparameters['att_w'] * self.loss_attentive
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def sample(self, x_a, x_b):
        self.eval()
        s_a1 = Variable(self.s_a)
        s_b1 = Variable(self.s_b)
        s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
            c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
            x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
            x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
            x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
            x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
            x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    # To train the discriminator, take the two input images and translate each
    # into the other domain: keep each image's own content code but decode it
    # with a randomly sampled style code, and try to fool the discriminator
    # with the result.
    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters[
            'gan_w'] * self.loss_dis_a + hyperparameters[
                'gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            if self.lr_policy == 'plateau':
                self.dis_scheduler.step(self.metric)
            else:
                self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            if self.lr_policy == 'plateau':
                self.gen_scheduler.step(self.metric)
            else:
                self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler, self.lr_policy = get_scheduler(
            self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler, self.lr_policy = get_scheduler(
            self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
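# A one-step sketch for the attention-mask variant above, assuming `mask` is a single-channel
# binary tensor aligned with x_a (gen_update repeats it to 3 channels) and that `config`
# additionally carries the 'att_w' weight for loss_attentive; `trainer` and `config` are
# hypothetical placeholders.
def attentive_step_sketch(trainer, x_a, x_b, mask, config):
    x_a, x_b, mask = x_a.cuda(), x_b.cuda(), mask.cuda()
    trainer.dis_update(x_a, x_b, config)
    trainer.gen_update(x_a, x_b, config, mask)
    trainer.update_learning_rate()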
class DSMAP_Trainer(nn.Module): def __init__(self, hyperparameters): super(DSMAP_Trainer, self).__init__() # Initiate the networks mid_downsample = hyperparameters['gen'].get('mid_downsample', 1) self.content_enc = ContentEncoder_share( hyperparameters['gen']['n_downsample'], mid_downsample, hyperparameters['gen']['n_res'], hyperparameters['input_dim_a'], hyperparameters['gen']['dim'], 'in', hyperparameters['gen']['activ'], pad_type=hyperparameters['gen']['pad_type']) self.style_dim = hyperparameters['gen']['style_dim'] self.gen_a = AdaINGen( hyperparameters['input_dim_a'], self.content_enc, 'a', hyperparameters['gen']) # auto-encoder for domain a self.gen_b = AdaINGen( hyperparameters['input_dim_b'], self.content_enc, 'b', hyperparameters['gen']) # auto-encoder for domain b self.dis_a = MsImageDis( hyperparameters['input_dim_a'], self.content_enc.output_dim, hyperparameters['dis']) # discriminator for domain a self.dis_b = MsImageDis( hyperparameters['input_dim_b'], self.content_enc.output_dim, hyperparameters['dis']) # discriminator for domain b def build_optimizer(self, hyperparameters): # Setup the optimizers lr = hyperparameters['lr'] beta1 = hyperparameters['beta1'] beta2 = hyperparameters['beta2'] dis_params = list(self.dis_a.parameters()) + list( self.dis_b.parameters()) gen_params = list(self.gen_a.parameters()) + list( self.gen_b.parameters()) self.dis_opt = torch.optim.Adam( [p for p in dis_params if p.requires_grad], lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) self.gen_opt = torch.optim.Adam( [p for p in gen_params if p.requires_grad], lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters) self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters) # Network weight initialization self.apply(weights_init(hyperparameters['init'])) self.dis_a.apply(weights_init('gaussian')) self.dis_b.apply(weights_init('gaussian')) # Load VGG model if needed if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0: import torchvision.models as models self.vgg = models.vgg16(pretrained=True) # If you cannot download pretrained model automatically, you can download it from # https://download.pytorch.org/models/vgg16-397923af.pth and load it manually # state_dict = torch.load('vgg16-397923af.pth') # self.vgg.load_state_dict(state_dict) self.vgg.eval() for param in self.vgg.parameters(): param.requires_grad = False def recon_criterion(self, input, target): return torch.mean(torch.abs(input - target)) def __compute_kl(self, mu, logvar): encoding_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) return encoding_loss def gen_update(self, x_a, x_b, hyperparameters, iterations): self.gen_opt.zero_grad() self.gen_backward_cc(x_a, x_b, hyperparameters) #self.gen_opt.step() #self.gen_opt.zero_grad() self.gen_backward_latent(x_a, x_b, hyperparameters) self.gen_opt.step() def gen_backward_latent(self, x_a, x_b, hyperparameters): # random sample style vector and multimodal training s_a_random = Variable( torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()) s_b_random = Variable( torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()) # decode x_ba_random = self.gen_a.decode(self.c_b, self.da_b, s_a_random) x_ab_random = self.gen_b.decode(self.c_a, self.db_a, s_b_random) c_b_random_recon, _, _, s_a_random_recon, _, _ = self.gen_a.encode( x_ba_random) c_a_random_recon, _, _, s_b_random_recon, _, _ = self.gen_b.encode( x_ab_random) # style reconstruction loss 
self.loss_gen_recon_s_a = self.recon_criterion(s_a_random, s_a_random_recon) self.loss_gen_recon_s_b = self.recon_criterion(s_b_random, s_b_random_recon) loss_gen_recon_c_a = self.recon_criterion(self.c_a, c_a_random_recon) loss_gen_recon_c_b = self.recon_criterion(self.c_b, c_b_random_recon) loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba_random, x_a) loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab_random, x_b) loss_gen_vgg_a = self.compute_vgg_loss( self.vgg, x_ba_random, x_a) if hyperparameters['vgg_w'] > 0 else 0 loss_gen_vgg_b = self.compute_vgg_loss( self.vgg, x_ab_random, x_b) if hyperparameters['vgg_w'] > 0 else 0 loss_gen_total = hyperparameters['gan_w'] * loss_gen_adv_a + \ hyperparameters['gan_w'] * loss_gen_adv_b + \ hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \ hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \ hyperparameters['recon_c_w'] * loss_gen_recon_c_a + \ hyperparameters['recon_c_w'] * loss_gen_recon_c_b + \ hyperparameters['vgg_w'] * loss_gen_vgg_a + \ hyperparameters['vgg_w'] * loss_gen_vgg_b self.loss_gen_total += loss_gen_total self.loss_gen_total.backward() self.loss_gen_total += loss_gen_total self.loss_gen_adv_a += loss_gen_adv_a self.loss_gen_adv_b += loss_gen_adv_b self.loss_gen_recon_c_a += loss_gen_recon_c_a self.loss_gen_recon_c_b += loss_gen_recon_c_b self.loss_gen_vgg_a += loss_gen_vgg_a self.loss_gen_vgg_b += loss_gen_vgg_b def gen_backward_cc(self, x_a, x_b, hyperparameters): pre_c_a, self.c_a, c_domain_a, self.db_a, self.s_a_prime, mu_a, logvar_a = self.gen_a.encode( x_a, training=True, flag=True) pre_c_b, self.c_b, c_domain_b, self.da_b, self.s_b_prime, mu_b, logvar_b = self.gen_b.encode( x_b, training=True, flag=True) self.da_a = self.gen_b.domain_mapping(self.c_a, pre_c_a) self.db_b = self.gen_a.domain_mapping(self.c_b, pre_c_b) # decode (within domain) x_a_recon = self.gen_a.decode(self.c_a, self.da_a, self.s_a_prime) x_b_recon = self.gen_b.decode(self.c_b, self.db_b, self.s_b_prime) # decode (cross domain) x_ba = self.gen_a.decode(self.c_b, self.da_b, self.s_a_prime) x_ab = self.gen_b.decode(self.c_a, self.db_a, self.s_b_prime) c_b_recon, _, self.db_b_recon, s_a_recon, _, _ = self.gen_a.encode( x_ba, training=True) c_a_recon, _, self.da_a_recon, s_b_recon, _, _ = self.gen_b.encode( x_ab, training=True) # decode again (cycle consistance loss) x_aba = self.gen_a.decode(c_a_recon, self.da_a_recon, s_a_recon) x_bab = self.gen_b.decode(c_b_recon, self.db_b_recon, s_b_recon) # domain-specific content reconstruction loss self.loss_gen_recon_d_a = self.recon_criterion( c_domain_a, self.da_a) if hyperparameters['recon_d_w'] > 0 else 0 self.loss_gen_recon_d_b = self.recon_criterion( c_domain_b, self.db_b) if hyperparameters['recon_d_w'] > 0 else 0 # image reconstruction loss self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a) self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b) # domain-invariant content reconstruction loss self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, self.c_a) self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, self.c_b) # cyc loss self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) # kl loss (if needed) self.loss_gen_recon_kl_a = self.__compute_kl( mu_a, logvar_a) if hyperparameters['recon_kl_w'] > 0 else 0 self.loss_gen_recon_kl_b = self.__compute_kl( mu_b, logvar_b) if hyperparameters['recon_kl_w'] > 0 else 0 # GAN loss self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba, x_a) self.loss_gen_adv_b = 
self.dis_b.calc_gen_loss(x_ab, x_b) # domain-invariant perceptual loss self.loss_gen_vgg_a = self.compute_vgg_loss( self.vgg, x_ba, x_a) if hyperparameters['vgg_w'] > 0 else 0 self.loss_gen_vgg_b = self.compute_vgg_loss( self.vgg, x_ab, x_b) if hyperparameters['vgg_w'] > 0 else 0 # total loss self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \ hyperparameters['gan_w'] * self.loss_gen_adv_b + \ hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \ hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \ hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \ hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \ hyperparameters['recon_d_w'] * self.loss_gen_recon_d_a + \ hyperparameters['recon_d_w'] * self.loss_gen_recon_d_b + \ hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \ hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \ hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \ hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \ hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \ hyperparameters['vgg_w'] * self.loss_gen_vgg_b # self.loss_gen_total.backward(retain_graph=True) def compute_vgg_loss(self, vgg, img, target): img_vgg = vgg_preprocess(img) target_vgg = vgg_preprocess(target) img_fea = vgg.features(img_vgg) target_fea = vgg.features(target_vgg) return contextual_loss(img_fea, target_fea) def dis_update(self, x_a, x_b, hyperparameters, iterations): self.dis_opt.zero_grad() s_a_random = Variable( torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()) s_b_random = Variable( torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()) # encode pre_c_a, c_a, c_domain_a, db_a, s_a, _, _ = self.gen_a.encode( x_a, training=True, flag=True) pre_c_b, c_b, c_domain_b, da_b, s_b, _, _ = self.gen_b.encode( x_b, training=True, flag=True) da_a = self.gen_b.domain_mapping(c_a, pre_c_a) db_b = self.gen_a.domain_mapping(c_b, pre_c_b) # decode (cross domain) x_ba = self.gen_a.decode(c_b, da_b, s_a) x_ab = self.gen_b.decode(c_a, db_a, s_b) # decode (cross domain) x_ba_random = self.gen_a.decode(c_b, da_b, s_a_random) x_ab_random = self.gen_b.decode(c_a, db_a, s_b_random) c_b_recon, _, db_b_recon, s_a_recon, _, _ = self.gen_a.encode(x_ba) c_a_recon, _, da_a_recon, s_b_recon, _, _ = self.gen_b.encode(x_ab) _, _, db_b_random_recon, _, _, _ = self.gen_a.encode(x_ba_random) _, _, da_a_random_recon, _, _, _ = self.gen_b.encode(x_ab_random) # D loss self.loss_dis_a = self.dis_a.calc_dis_loss( x_ba.detach(), x_a) + self.dis_a.calc_dis_loss( x_ba_random.detach(), x_a) self.loss_dis_b = self.dis_b.calc_dis_loss( x_ab.detach(), x_b) + self.dis_b.calc_dis_loss( x_ab_random.detach(), x_b) self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + \ hyperparameters['gan_w'] * self.loss_dis_b self.loss_dis_total.backward() self.dis_opt.step() def sample(self, x_a, x_b): self.eval() x_a_recon, x_b_recon, x_ab, x_ba = [], [], [], [] for i in range(x_a.size(0)): pre_c_a, c_a, _, db_a, s_a_fake, _, _ = self.gen_a.encode( x_a[i].unsqueeze(0), flag=True) pre_c_b, c_b, _, da_b, s_b_fake, _, _ = self.gen_b.encode( x_b[i].unsqueeze(0), flag=True) da_a = self.gen_b.domain_mapping(c_a, pre_c_a) db_b = self.gen_a.domain_mapping(c_b, pre_c_b) x_a_recon.append(self.gen_a.decode(c_a, da_a, s_a_fake)) x_b_recon.append(self.gen_b.decode(c_b, db_b, s_b_fake)) x_ba.append(self.gen_a.decode(c_b, da_b, s_a_fake)) x_ab.append(self.gen_b.decode(c_a, db_a, s_b_fake)) x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon) x_ab, x_ba = 
torch.cat(x_ab), torch.cat(x_ba) self.train() return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba def update_learning_rate(self): if self.dis_scheduler is not None: self.dis_scheduler.step() if self.gen_scheduler is not None: self.gen_scheduler.step() def resume(self, checkpoint_dir, hyperparameters): # Load generators last_model_name = get_model_list(checkpoint_dir, "gen") state_dict = torch.load(last_model_name) self.gen_a.load_state_dict(state_dict['a']) self.gen_b.load_state_dict(state_dict['b']) iterations = int(last_model_name[-11:-3]) # Load discriminators last_model_name = get_model_list(checkpoint_dir, "dis") state_dict = torch.load(last_model_name) self.dis_a.load_state_dict(state_dict['a']) self.dis_b.load_state_dict(state_dict['b']) # Load optimizers state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt')) self.dis_opt.load_state_dict(state_dict['dis']) self.gen_opt.load_state_dict(state_dict['gen']) # Reinitilize schedulers self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations) self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations) print('Resume from iteration %d' % iterations) return iterations def save(self, snapshot_dir, iterations): # Save generators, discriminators, and optimizers gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1)) dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1)) opt_name = os.path.join(snapshot_dir, 'optimizer.pt') torch.save({ 'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict() }, gen_name) torch.save({ 'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict() }, dis_name) torch.save( { 'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict() }, opt_name)
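# A minimal training sketch for the DSMAP_Trainer above. Unlike the MUNIT trainers, the
# optimizers, schedulers, weight init, and (when vgg_w > 0) the VGG network are created in
# build_optimizer(), so it must be called once before the first update; gen_update and
# dis_update also take the current iteration count. `config`, `loader_a`, and `loader_b`
# are hypothetical placeholders.
def dsmap_train_sketch(config, loader_a, loader_b, snapshot_dir, max_iter):
    trainer = DSMAP_Trainer(config).cuda()
    trainer.build_optimizer(config)  # must run before any update step
    it = 0
    while it < max_iter:
        for x_a, x_b in zip(loader_a, loader_b):
            x_a, x_b = x_a.cuda(), x_b.cuda()
            trainer.dis_update(x_a, x_b, config, it)
            trainer.gen_update(x_a, x_b, config, it)
            trainer.update_learning_rate()
            it += 1
            if it >= max_iter:
                break
    trainer.save(snapshot_dir, it - 1)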