class MUNIT_Trainer(nn.Module):
    """Trainer for MUNIT: two AdaIN auto-encoders plus two multi-scale discriminators.

    ``gen_a``/``dis_a`` serve domain a and ``gen_b``/``dis_b`` serve domain b.
    Each generator's ``encode`` returns a ``(content, style)`` pair and
    ``decode(content, style)`` renders an image for that generator's domain.
    """

    def __init__(self, hyperparameters):
        # super() runs the parent-class (nn.Module) initializer.
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']

        # Initiate the networks (architecture lives in AdaINGen / MsImageDis).
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        # Non-affine instance norm over 512 channels, applied to VGG features in
        # compute_vgg_loss (background: https://blog.csdn.net/liuxiao214/article/details/81037416).
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # Fix the noise used in sampling: s_a / s_b are two fixed random style
        # codes of shape (display_size, style_dim, 1, 1).
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers: one Adam for both discriminators, one for both generators.
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2),
                                        weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2),
                                        weight_decay=hyperparameters['weight_decay'])
        # Learning-rate schedules for both optimizers.
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization. nn.Module.apply calls the function on
        # every submodule and on self; the discriminators are then re-initialised
        # with the 'gaussian' scheme.
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed; it is frozen and only used by compute_vgg_loss.
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        """Mean absolute (L1) error between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        """Translate x_a -> domain b and x_b -> domain a with the fixed sampling styles.

        Runs in eval mode and switches the module back to train mode afterwards.

        Returns:
            tuple: ``(x_ab, x_ba)``.
        """
        self.eval()
        s_a = Variable(self.s_a)  # fixed random style for domain a
        s_b = Variable(self.s_b)  # fixed random style for domain b
        # Only the content codes are needed; the encoded styles are discarded.
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a)  # content of b + style for a -> domain a
        x_ab = self.gen_b.decode(c_a, s_b)  # content of a + style for b -> domain b
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        """Run one optimisation step for both generators.

        Every loss term is stored on ``self`` (for logging), weighted by the
        matching hyperparameter and summed into ``self.loss_gen_total``.
        The cycle terms are skipped when ``recon_x_cyc_w`` <= 0, and the VGG
        perceptual terms when ``vgg_w`` is missing or <= 0.
        """
        self.gen_opt.zero_grad()
        # Fresh random style codes for the cross-domain translations.
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain): plain auto-encoder reconstruction
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again: c_a_recon should stay close to c_a, s_b_recon to s_b, etc.
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed) for the cycle-consistency loss
        use_cyc = hyperparameters['recon_x_cyc_w'] > 0
        x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if use_cyc else None
        x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if use_cyc else None
        # 'vgg_w' is optional, mirroring the guard in __init__.
        vgg_w = hyperparameters.get('vgg_w', 0)

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if use_cyc else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if use_cyc else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss: distance between instance-normalised
        # VGG features of each translation and of its source image
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if vgg_w > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if vgg_w > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
            hyperparameters['gan_w'] * self.loss_gen_adv_b + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
            vgg_w * self.loss_gen_vgg_a + \
            vgg_w * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        """Mean squared distance between instance-normalised VGG features of img and target."""
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def sample(self, x_a, x_b):
        """Produce reconstructions and translations of each image for display.

        ``x_a.size(0)`` must not exceed ``display_size`` because the fixed
        styles ``self.s_a`` / ``self.s_b`` are indexed per sample.

        Returns:
            tuple: ``(x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2)``;
            ``*1`` use the fixed sampling styles, ``*2`` fresh random styles.
        """
        self.eval()
        # fixed random styles
        s_a1 = Variable(self.s_a)
        s_b1 = Variable(self.s_b)
        # fresh random styles
        s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            # unsqueeze(0) re-adds a leading batch dimension of size 1.
            c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
            c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
            x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
            x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
            x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
            x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
            x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
        # torch.cat joins the per-sample outputs along dim 0 (the batch dimension).
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_update(self, x_a, x_b, hyperparameters):
        """Run one optimisation step for both discriminators.

        Each discriminator loss is computed by ``calc_dis_loss`` from a
        translated image and a real image of its domain; the two are weighted
        by ``gan_w`` and summed into ``self.loss_dis_total``.
        """
        self.dis_opt.zero_grad()
        # Random style codes for domains a and b.
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode: only the content codes are needed here
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain): x_ba combines b's content with a random a-style
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # D loss. detach() stops this loss from backpropagating into the generators.
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        # gan_w is the weight of the adversarial loss.
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()  # compute gradients
        self.dis_opt.step()  # apply the parameter update

    def update_learning_rate(self):
        """Advance both learning-rate schedulers, when present."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Load the latest generator/discriminator/optimizer checkpoints.

        Returns:
            int: the iteration parsed from the generator checkpoint name.
        """
        # Load generators (get_model_list picks the checkpoint to resume from).
        last_model_name = get_model_list(checkpoint_dir, "gen")
        print('resume model: ', last_model_name)
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        # Characters [-11:-3] are the 8 digits of a 'gen_%08d.pt' name written by save().
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations, gpuids=None):
        """Save generators, discriminators and optimizers under ``snapshot_dir``.

        Args:
            snapshot_dir: output directory.
            iterations: zero-based iteration; file names use ``iterations + 1``.
            gpuids: GPU ids in use. When more than one id is given, the networks
                are expected to be wrapped (e.g. ``nn.DataParallel``), and their
                ``.module`` is saved so the keys carry no wrapper prefix. May be
                omitted for single-device training.
        """
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        multi_gpu = gpuids is not None and len(gpuids) > 1
        if multi_gpu:
            gen_a, gen_b = self.gen_a.module, self.gen_b.module
            dis_a, dis_b = self.dis_a.module, self.dis_b.module
        else:
            gen_a, gen_b = self.gen_a, self.gen_b
            dis_a, dis_b = self.dis_a, self.dis_b
        torch.save({'a': gen_a.state_dict(), 'b': gen_b.state_dict()}, gen_name)
        torch.save({'a': dis_a.state_dict(), 'b': dis_b.state_dict()}, dis_name)
        # Optimizers are never wrapped: torch.optim.Optimizer has no `.module`.
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
class MUNIT_Trainer(nn.Module):
    """MUNIT-style trainer variant in which the generators exchange images, not content codes.

    Judging by the call sites below, ``gen_x.encode(img)`` returns only a style
    code and ``gen_x.decode(img, style)`` takes an image (used as the content
    source) plus a style code. NOTE(review): confirm against the AdaINGen used
    with this version.
    """

    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # Fix the noise used in sampling: fixed style codes of shape (display_size, style_dim, 1, 1).
        self.display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(self.display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(self.display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers. Adam betas come from the config when present;
        # otherwise (0.5, 0.999) is used.
        beta1 = hyperparameters.get('beta1', 0.5)
        beta2 = hyperparameters.get('beta2', 0.999)
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2),
                                        weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2),
                                        weight_decay=hyperparameters['weight_decay'])
        # Step-decay schedules: lr is multiplied by `gamma` every `step_size` scheduler steps.
        self.dis_scheduler = lr_scheduler.StepLR(self.dis_opt, step_size=hyperparameters['step_size'],
                                                 gamma=hyperparameters['gamma'], last_epoch=-1)
        self.gen_scheduler = lr_scheduler.StepLR(self.gen_opt, step_size=hyperparameters['step_size'],
                                                 gamma=hyperparameters['gamma'], last_epoch=-1)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

    def recon_criterion(self, input, target):
        """Mean absolute (L1) error between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    def random_style(self, x=None, factor=1):
        """Sample a random style code of shape (N, style_dim, 1, 1) on the GPU.

        N is ``x.size(0)`` when ``x`` is given, otherwise ``self.display_size``;
        the sample is multiplied by ``factor``.
        """
        dim = self.display_size if x is None else x.size(0)
        return Variable(torch.randn(dim, self.style_dim, 1, 1).cuda()) * factor

    def forward(self, x_a, x_b):
        """Translate x_a to domain b and x_b to domain a using the fixed sampling styles.

        Runs in eval mode and switches back to train mode afterwards.

        Returns:
            tuple: ``(x_ab, x_ba)``.
        """
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        # The encoded styles of x_a / x_b are not needed for this translation.
        x_ba = self.gen_a.decode(x_b, s_a)
        x_ab = self.gen_b.decode(x_a, s_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        """Run one optimisation step for both generators.

        Loss terms are stored on ``self``, weighted by their hyperparameters
        and summed into ``self.loss_gen_total``.
        """
        self.gen_opt.zero_grad()
        s_a = self.random_style(x_a)
        s_b = self.random_style(x_b)
        # encode (style only)
        s_a_prime = self.gen_a.encode(x_a)
        s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(x_a, s_a_prime)
        x_b_recon = self.gen_b.decode(x_b, s_b_prime)
        # decode (cross domain): x_ba carries b's content, x_ab carries a's content
        x_ba = self.gen_a.decode(x_b, s_a)
        x_ab = self.gen_b.decode(x_a, s_b)
        # encode again
        s_a_recon = self.gen_a.encode(x_ba)
        s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed). A->B->A must start from x_ab (a's content) and
        # B->A->B from x_ba. Decoding x_ba here would compare b's content with x_a.
        use_cyc = hyperparameters['recon_x_cyc_w'] > 0
        x_aba = self.gen_a.decode(x_ab, s_a_prime) if use_cyc else None
        x_bab = self.gen_b.decode(x_ba, s_b_prime) if use_cyc else None
        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if use_cyc else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if use_cyc else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
            hyperparameters['gan_w'] * self.loss_gen_adv_b + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def sample(self, x_a, x_b):
        """Produce reconstructions and translations of each image for display.

        Returns:
            tuple: ``(x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2)``;
            ``*1`` use the fixed sampling styles, ``*2`` fresh random styles.
        """
        self.eval()
        s_a1 = Variable(self.s_a)
        s_b1 = Variable(self.s_b)
        s_a2 = self.random_style(x_a)
        s_b2 = self.random_style(x_b)
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            img_a = x_a[i].unsqueeze(0)
            img_b = x_b[i].unsqueeze(0)
            s_a_fake = self.gen_a.encode(img_a)
            s_b_fake = self.gen_b.encode(img_b)
            x_a_recon.append(self.gen_a.decode(img_a, s_a_fake))
            x_b_recon.append(self.gen_b.decode(img_b, s_b_fake))
            x_ba1.append(self.gen_a.decode(img_b, s_a1[i].unsqueeze(0)))
            x_ba2.append(self.gen_a.decode(img_b, s_a2[i].unsqueeze(0)))
            x_ab1.append(self.gen_b.decode(img_a, s_b1[i].unsqueeze(0)))
            x_ab2.append(self.gen_b.decode(img_a, s_b2[i].unsqueeze(0)))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_update(self, x_a, x_b, hyperparameters):
        """Run one optimisation step for both discriminators."""
        self.dis_opt.zero_grad()
        s_a = self.random_style(x_a)
        s_b = self.random_style(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(x_b, s_a)
        x_ab = self.gen_b.decode(x_a, s_b)
        # D loss; detach() keeps gradients out of the generators.
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Advance both learning-rate schedulers, when present."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    # NOTE(review): resume() is disabled in this variant. The version below calls
    # get_scheduler(), whereas __init__ here builds StepLR schedulers directly;
    # that mismatch may be why it was commented out. Confirm before re-enabling.
    # def resume(self, checkpoint_dir, hyperparameters):
    #     # Load generators
    #     last_model_name = get_model_list(checkpoint_dir, "gen")
    #     state_dict = torch.load(last_model_name)
    #     self.gen_a.load_state_dict(state_dict['a'])
    #     self.gen_b.load_state_dict(state_dict['b'])
    #     iterations = int(last_model_name[-11:-3])
    #     # Load discriminators
    #     last_model_name = get_model_list(checkpoint_dir, "dis")
    #     state_dict = torch.load(last_model_name)
    #     self.dis_a.load_state_dict(state_dict['a'])
    #     self.dis_b.load_state_dict(state_dict['b'])
    #     # Load optimizers
    #     state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
    #     self.dis_opt.load_state_dict(state_dict['dis'])
    #     self.gen_opt.load_state_dict(state_dict['gen'])
    #     # Reinitilize schedulers
    #     self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
    #     self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
    #     print('Resume from iteration %d' % iterations)
    #     return iterations

    def save(self, snapshot_dir, iterations):
        """Save generators, discriminators and optimizers; file names use ``iterations + 1``."""
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
class MUNIT_Trainer(nn.Module):
    """Trainer for MUNIT: two AdaIN auto-encoders plus two multi-scale discriminators.

    Each generator's ``encode`` returns a ``(content, style)`` pair and
    ``decode(content, style)`` renders an image for that generator's domain.
    """

    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # fix the noise used in sampling
        self.display_size = int(hyperparameters['display_size'])
        self.s_a = self.random_style()
        self.s_b = self.random_style()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2),
                                        weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2),
                                        weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed (frozen; only used by compute_vgg_loss)
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        """Mean absolute (L1) error between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        """Translate x_a -> domain b and x_b -> domain a with the fixed sampling styles.

        Returns:
            tuple: ``(x_ab, x_ba)``. The module is left in train mode.
        """
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        """Run one optimisation step for both generators.

        Args:
            x_a: batch of images from domain A.
            x_b: batch of images from domain B.
            hyperparameters: config mapping with the loss weights. ``vgg_w`` is
                optional; the VGG terms are skipped when it is missing or <= 0.
        """
        self.gen_opt.zero_grad()
        s_a = self.random_style(x_a)
        s_b = self.random_style(x_b)
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)  # c_a: content code, s_a_prime: style code
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain): reconstruction from own content and style
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)  # content b, style a
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again: recover content_b and style_a from the cross-domain image
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        use_cyc = hyperparameters['recon_x_cyc_w'] > 0
        x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if use_cyc else None
        x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if use_cyc else None
        vgg_w = hyperparameters.get('vgg_w', 0)
        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if use_cyc else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if use_cyc else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if vgg_w > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if vgg_w > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
            hyperparameters['gan_w'] * self.loss_gen_adv_b + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
            vgg_w * self.loss_gen_vgg_a + \
            vgg_w * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        """Mean squared distance between instance-normalised VGG features of img and target."""
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def random_style(self, x=None, factor=1):
        """Sample a random style code of shape (N, style_dim, 1, 1) on the GPU.

        N is ``x.size(0)`` when ``x`` is given, otherwise ``self.display_size``;
        the sample is multiplied by ``factor``.
        """
        dim = self.display_size if x is None else x.size(0)
        return Variable(torch.randn(dim, self.style_dim, 1, 1).cuda()) * factor

    def sample(self, x_a, x_b):
        """Produce reconstructions and translations of each image for display.

        Args:
            x_a: domain-A images; the batch must not exceed ``display_size``
                because the fixed styles are indexed per sample.
            x_b: domain-B images, at least as many as ``x_a``.

        Returns:
            tuple: ``(x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2)``;
            ``*_recon`` are within-domain reconstructions, ``*1`` translations use
            the fixed sampling styles and ``*2`` fresh random styles scaled by 5.
        """
        self.eval()
        s_a1 = Variable(self.s_a)
        s_b1 = Variable(self.s_b)
        s_a2 = self.random_style(x_a, factor=5)
        s_b2 = self.random_style(x_b, factor=5)
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
            c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
            x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
            x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
            x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
            x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
            x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def sampleB_toA(self, x_b):
        """Translate each domain-B image to domain A with a fresh random style.

        The module runs in eval mode during translation and is then restored
        to whichever mode (train or eval) it was in before the call.

        Args:
            x_b: batch of domain-B images.

        Returns:
            list: one decoded tensor per input image.
        """
        was_training = self.training
        self.eval()
        try:
            s_a2 = self.random_style(x_b)
            x_ba1 = []
            for i in range(x_b.size(0)):  # one image at a time
                c_b, _ = self.gen_b.encode(x_b[i].unsqueeze(0))
                x_ba1.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
        finally:
            self.train(was_training)
        return x_ba1

    def dis_update(self, x_a, x_b, hyperparameters):
        """Run one optimisation step for both discriminators."""
        self.dis_opt.zero_grad()
        s_a = self.random_style(x_a)
        s_b = self.random_style(x_b)
        # encode
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # D loss; detach() keeps gradients out of the generators.
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Advance both learning-rate schedulers, when present."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters, gen_model=None, dis_model=None):
        """Load generator/discriminator/optimizer checkpoints and rebuild schedulers.

        ``gen_model`` / ``dis_model`` default to ``get_model_list(checkpoint_dir, ...)``.

        Returns:
            int: the iteration parsed from the generator checkpoint name.
        """
        # Load generators
        if gen_model is None:
            gen_model = get_model_list(checkpoint_dir, "gen")  # last gen model
        gen_state_dict = torch.load(gen_model)
        self.gen_a.load_state_dict(gen_state_dict['a'])
        self.gen_b.load_state_dict(gen_state_dict['b'])
        # Characters [-11:-3] are the 8 digits of a 'gen_%08d.pt' name written by save().
        iterations = int(gen_model[-11:-3])
        # Load discriminators
        if dis_model is None:
            dis_model = get_model_list(checkpoint_dir, "dis")
        dis_state_dict = torch.load(dis_model)
        self.dis_a.load_state_dict(dis_state_dict['a'])
        self.dis_b.load_state_dict(dis_state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def load_model(self, checkpoint_dir, hyperparameters, iteration):
        """Resume from the checkpoints saved for a specific ``iteration``."""
        from pathlib import Path
        gen_model = Path(checkpoint_dir) / f"gen_{iteration:08d}.pt"
        dis_model = Path(checkpoint_dir) / f"dis_{iteration:08d}.pt"
        return self.resume(checkpoint_dir, hyperparameters, gen_model.as_posix(), dis_model.as_posix())

    def save(self, snapshot_dir, iterations):
        """Save generators, discriminators and optimizers; file names use ``iterations + 1``."""
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
class MUNIT_Trainer(nn.Module):
    """MUNIT trainer with optional GANILLA generators and patch discriminators.

    Holds two generators (``gen_a``, ``gen_b``), two discriminators
    (``dis_a``, ``dis_b``), their Adam optimizers and LR schedulers and, when
    ``hyperparameters['vgg_w'] > 0``, a frozen VGG16 used for a perceptual loss.

    When ``hyperparameters['gen']['ganilla_gen']`` is truthy the generators are
    ``AdaINGanilla``; their content code is indexed with ``[-1]`` before the
    content-reconstruction loss (presumably a list of feature maps whose last
    entry is the deepest one -- confirm against AdaINGanilla.encode).
    """

    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        gen_cfg = hyperparameters['gen']
        dis_cfg = hyperparameters['dis']

        # Generators. A missing or falsy ``ganilla_gen`` selects AdaINGen; the
        # same truthiness test is used by every method that checks the flag.
        self.is_ganilla_gen = gen_cfg.get('ganilla_gen', False)
        gen_cls = AdaINGanilla if self.is_ganilla_gen else AdaINGen
        self.gen_a = gen_cls(hyperparameters['input_dim_a'], gen_cfg)  # auto-encoder for domain a
        self.gen_b = gen_cls(hyperparameters['input_dim_b'], gen_cfg)  # auto-encoder for domain b
        print(self.gen_a)

        # Discriminators: PatchDis only for dis_type == 'patch' with
        # use_patch_gan set; the multi-scale MsImageDis otherwise.
        is_patch_type = dis_cfg['dis_type'] == 'patch'
        dis_cls = PatchDis if is_patch_type and dis_cfg['use_patch_gan'] else MsImageDis
        self.dis_a = dis_cls(hyperparameters['input_dim_a'], dis_cfg)  # discriminator for domain a
        self.dis_b = dis_cls(hyperparameters['input_dim_b'], dis_cfg)  # discriminator for domain b
        if is_patch_type:
            print(self.dis_a)

        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = gen_cfg['style_dim']

        # Fixed style codes reused by forward() and sample() so that displayed
        # results are comparable across iterations.
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2),
                                        weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2),
                                        weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization (discriminators are re-initialised
        # with the gaussian scheme after the global init).
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load a frozen VGG model only when the perceptual loss is enabled.
        if 'vgg_w' in hyperparameters and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            self.VggExtract = VggExtract(self.vgg)
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        """Return the mean absolute (L1) error between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    def _random_style(self, x):
        """Sample one N(0, 1) style code (style_dim x 1 x 1) per batch item of ``x``, on the GPU."""
        return Variable(torch.randn(x.size(0), self.style_dim, 1, 1).cuda())

    def forward(self, x_a, x_b):
        """Translate ``x_a`` to domain b and ``x_b`` to domain a with the fixed style codes.

        Runs in eval mode and switches back to train mode before returning.

        Returns:
            ``(x_ab, x_ba)``.
        """
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # Decode the encoder output unchanged (for GANILLA the whole indexable
        # content code), exactly like gen_update, dis_update and sample do;
        # ``[-1]`` is only taken for the content-reconstruction loss.
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab, x_ba

    def gen_updateN(self, x_a, x_b, hyperparameters):
        """Generator step that decodes both domains from the content of ``x_b``.

        The content code of ``x_a`` is discarded: ``c`` from ``gen_b.encode(x_b)``
        feeds the within-domain, cross-domain and content-reconstruction terms.
        """
        self.gen_opt.zero_grad()
        s_a = self._random_style(x_a)
        s_b = self._random_style(x_b)
        # encode
        _, s_a_prime = self.gen_a.encode(x_a)
        c, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c, s_a_prime)
        x_b_recon = self.gen_b.decode(c, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c, s_a)
        x_ab = self.gen_b.decode(c, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        use_cyc = hyperparameters['recon_x_cyc_w'] > 0
        x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if use_cyc else None
        x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if use_cyc else None
        if self.is_ganilla_gen:
            # Only the last content entry enters the content loss.
            c = c[-1]
            c_b_recon = c_b_recon[-1]
            c_a_recon = c_a_recon[-1]
        self._backward_gen_losses(hyperparameters, x_a, x_b,
                                  x_a_recon, x_b_recon, x_ba, x_ab, x_aba, x_bab,
                                  s_a, s_b, s_a_recon, s_b_recon,
                                  c, c, c_a_recon, c_b_recon)

    def gen_update(self, x_a, x_b, hyperparameters):
        """Standard MUNIT generator step: recon, cycle, adversarial and VGG losses."""
        self.gen_opt.zero_grad()
        s_a = self._random_style(x_a)
        s_b = self._random_style(x_b)
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        use_cyc = hyperparameters['recon_x_cyc_w'] > 0
        x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if use_cyc else None
        x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if use_cyc else None
        if self.is_ganilla_gen:
            # Only the last content entry enters the content loss.
            c_a = c_a[-1]
            c_b = c_b[-1]
            c_b_recon = c_b_recon[-1]
            c_a_recon = c_a_recon[-1]
        self._backward_gen_losses(hyperparameters, x_a, x_b,
                                  x_a_recon, x_b_recon, x_ba, x_ab, x_aba, x_bab,
                                  s_a, s_b, s_a_recon, s_b_recon,
                                  c_a, c_b, c_a_recon, c_b_recon)

    def _backward_gen_losses(self, hyperparameters, x_a, x_b,
                             x_a_recon, x_b_recon, x_ba, x_ab, x_aba, x_bab,
                             s_a, s_b, s_a_recon, s_b_recon,
                             c_a, c_b, c_a_recon, c_b_recon):
        """Compute every generator loss term, store it on ``self``, backprop and step ``gen_opt``."""
        use_cyc = hyperparameters['recon_x_cyc_w'] > 0
        use_vgg = hyperparameters['vgg_w'] > 0
        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if use_cyc else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if use_cyc else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_c_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if use_vgg else 0
        self.loss_gen_vgg_c_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if use_vgg else 0
        # The style perceptual terms are stored on self but are NOT part of
        # loss_gen_total below.
        self.loss_gen_vgg_s_a = self.compute_vgg_loss(self.vgg, x_ba, x_a, all=1) if use_vgg else 0
        self.loss_gen_vgg_s_b = self.compute_vgg_loss(self.vgg, x_ab, x_b, all=1) if use_vgg else 0
        # total loss
        self.loss_gen_total = (
            hyperparameters['gan_w'] * self.loss_gen_adv_a
            + hyperparameters['gan_w'] * self.loss_gen_adv_b
            + hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a
            + hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a
            + hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a
            + hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b
            + hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b
            + hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b
            + hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a
            + hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b
            + hyperparameters['vgg_w'] * self.loss_gen_vgg_c_a
            + hyperparameters['vgg_w'] * self.loss_gen_vgg_c_b
        )
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target, all=0):
        """Mean squared difference of ``relu4_3`` features of ``img`` and ``target``.

        Args:
            vgg: accepted for interface compatibility but unused; features come
                from ``self.VggExtract``.
            img, target: images passed through ``vgg_preprocess`` first.
            all: when truthy, compare raw features; otherwise compare
                instance-normalised features.
        """
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = self.VggExtract(img_vgg)['relu4_3']
        target_fea = self.VggExtract(target_vgg)['relu4_3']
        if all:
            return torch.mean((img_fea - target_fea) ** 2)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def sample(self, x_a, x_b):
        """Per-item reconstructions and translations with fixed and fresh random styles.

        Returns:
            ``(x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2)``
            where the ``*1`` outputs use the fixed codes ``self.s_a``/``self.s_b``
            and the ``*2`` outputs use newly sampled ones.
        """
        self.eval()
        s_a1 = Variable(self.s_a)
        s_b1 = Variable(self.s_b)
        s_a2 = self._random_style(x_a)
        s_b2 = self._random_style(x_b)
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
            c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
            x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
            x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
            x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
            x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
            x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_updateN(self, x_a, x_b, hyperparameters):
        """Discriminator step matching gen_updateN: both fakes use the content of ``x_b``."""
        self.dis_opt.zero_grad()
        s_a = self._random_style(x_a)
        s_b = self._random_style(x_b)
        # encode
        c, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c, s_a)
        x_ab = self.gen_b.decode(c, s_b)
        self._backward_dis_losses(x_a, x_b, x_ba, x_ab, hyperparameters)

    def dis_update(self, x_a, x_b, hyperparameters):
        """Standard discriminator step on cross-domain translations."""
        self.dis_opt.zero_grad()
        s_a = self._random_style(x_a)
        s_b = self._random_style(x_b)
        # encode
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self._backward_dis_losses(x_a, x_b, x_ba, x_ab, hyperparameters)

    def _backward_dis_losses(self, x_a, x_b, x_ba, x_ab, hyperparameters):
        """Compute both discriminator losses on detached fakes, backprop and step ``dis_opt``."""
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = (hyperparameters['gan_w'] * self.loss_dis_a
                               + hyperparameters['gan_w'] * self.loss_dis_b)
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Advance the discriminator and generator schedulers, when present."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Load the latest gen/dis checkpoints and optimizer state; return the iteration.

        The iteration is parsed from the 8-digit number in the generator file
        name (``gen_XXXXXXXX.pt``) and used to rebuild the schedulers.
        """
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers so they continue from the resumed iteration
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Save generators, discriminators (numbered ``iterations + 1``) and optimizers."""
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
class UNIT_Trainer(nn.Module):
    """UNIT-style trainer with a content discriminator and a blur-based background loss.

    ``gen_a`` (a VAEGen with separate content/style encoders) and ``gen_b``
    (content only) are trained against image discriminators ``dis_a``/``dis_b``
    and a content discriminator ``dis_content`` that tries to tell domain-a
    content codes (label 0) from domain-b codes (label 1).
    """

    def __init__(self, hyperparameters):
        super(UNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = VAEGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = VAEGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        self.dis_content = Dis_content()
        self.gpuid = hyperparameters['gpuID']
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        # The content discriminator uses half the base learning rate.
        self.content_opt = torch.optim.Adam(
            self.dis_content.parameters(),
            lr=lr / 2., betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        self.content_scheduler = get_scheduler(self.content_opt, hyperparameters)

        # Network weight initialization
        self.gen_a.apply(weights_init(hyperparameters['init']))
        self.gen_b.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))
        self.dis_content.apply(weights_init('gaussian'))

        # Blur networks for the background-guide loss: one per kernel size,
        # weighted by BlurWeight (the two lists are paired by position).
        self.BGBlur_kernel = [5, 9, 15]
        self.BlurNet = [GaussionSmoothLayer(3, k_size, 25).cuda(self.gpuid)
                        for k_size in self.BGBlur_kernel]
        self.BlurWeight = [0.25, 0.5, 1.]
        self.Gradient = GradientLoss(3, 3)

        # Load a frozen VGG model only when the perceptual loss is enabled.
        if 'vgg_w' in hyperparameters and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg19()
            if torch.cuda.is_available():
                self.vgg.cuda(self.gpuid)
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        """Return the mean absolute (L1) error between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        """Translate ``x_a`` into domain b through its content code; ``x_b`` is unused.

        NOTE(review): unlike sample(), this leaves the module in eval mode;
        call ``.train()`` before any further training update.
        """
        self.eval()
        h_a = self.gen_a.encode_cont(x_a)
        x_ab = self.gen_b.decode_cont(h_a)
        return x_ab

    def __compute_kl(self, mu):
        """Simplified KL term: the mean of ``mu ** 2``."""
        mu_2 = torch.pow(mu, 2)
        encoding_loss = torch.mean(mu_2)
        return encoding_loss

    @staticmethod
    def _content_adv_loss(pred_a, pred_b, target_a, target_b, gan_type):
        """Adversarial loss of two content-discriminator outputs against constant targets.

        ``lsgan``: sum of the mean squared errors to ``target_a``/``target_b``.
        ``nsgan``: sum of the binary cross-entropies of ``sigmoid(pred)``.

        Raises:
            ValueError: for any other ``gan_type``.
        """
        if gan_type == 'lsgan':
            return torch.mean((pred_a - target_a) ** 2) + torch.mean((pred_b - target_b) ** 2)
        if gan_type == 'nsgan':
            # Targets match each prediction's own shape and device.
            label_a = torch.full_like(pred_a, target_a)
            label_b = torch.full_like(pred_b, target_b)
            return (F.binary_cross_entropy(torch.sigmoid(pred_a), label_a)
                    + F.binary_cross_entropy(torch.sigmoid(pred_b), label_b))
        raise ValueError("Unsupported GAN type: {}".format(gan_type))

    def content_update(self, x_a, x_b, hyperparameters):
        """Train the content discriminator: domain-a codes -> 0, domain-b codes -> 1."""
        self.content_opt.zero_grad()
        enc_a = self.gen_a.encode_cont(x_a)
        enc_b = self.gen_b.encode_cont(x_b)
        pred_fake = self.dis_content(enc_a)
        pred_real = self.dis_content(enc_b)
        loss_D = self._content_adv_loss(pred_fake, pred_real, 0., 1.,
                                        hyperparameters['gan_type'])
        loss_D.backward()
        nn.utils.clip_grad_norm_(self.dis_content.parameters(), 5)
        self.content_opt.step()

    def gen_update(self, x_a, x_b, hyperparameters):
        """Generator step: reconstruction, KL, cycle, adversarial, VGG, background and content losses."""
        self.gen_opt.zero_grad()
        self.content_opt.zero_grad()
        # encode
        h_a = self.gen_a.encode_cont(x_a)
        h_b = self.gen_b.encode_cont(x_b)
        h_a_sty = self.gen_a.encode_sty(x_a)
        # Content adversarial loss for the generator: push the content
        # discriminator's output on both domains toward 0.5.
        out_a = self.dis_content(h_a)
        out_b = self.dis_content(h_b)
        self.loss_ContentD = self._content_adv_loss(out_a, out_b, 0.5, 0.5,
                                                    hyperparameters['gan_type'])
        # decode (within domain); noise is drawn on the CPU RNG and moved to
        # the code's device.
        h_a_cont = torch.cat((h_a, h_a_sty), 1)
        noise_a = torch.randn(h_a_cont.size()).cuda(h_a_cont.data.get_device())
        x_a_recon = self.gen_a.decode_recs(h_a_cont + noise_a)
        noise_b = torch.randn(h_b.size()).cuda(h_b.data.get_device())
        x_b_recon = self.gen_b.decode_cont(h_b + noise_b)
        # decode (cross domain); noise_a/noise_b are reused, so h_b must be
        # shaped like h_a.
        h_ba_cont = torch.cat((h_b, h_a_sty), 1)
        x_ba = self.gen_a.decode_recs(h_ba_cont + noise_a)
        x_ab = self.gen_b.decode_cont(h_a + noise_b)
        # encode again
        h_b_recon = self.gen_a.encode_cont(x_ba)
        h_b_sty_recon = self.gen_a.encode_sty(x_ba)
        h_a_recon = self.gen_b.encode_cont(x_ab)
        # decode again (if needed)
        use_cyc = hyperparameters['recon_x_cyc_w'] > 0
        h_a_cat_recs = torch.cat((h_a_recon, h_b_sty_recon), 1)
        x_aba = self.gen_a.decode_recs(h_a_cat_recs + noise_a) if use_cyc else None
        x_bab = self.gen_b.decode_cont(h_b_recon + noise_b) if use_cyc else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_kl_a = self.__compute_kl(h_a)
        self.loss_gen_recon_kl_b = self.__compute_kl(h_b)
        self.loss_gen_recon_kl_sty = self.__compute_kl(h_a_sty)
        self.loss_gen_cyc_x_a = self.recon_criterion(x_aba, x_a) if x_aba is not None else 0
        self.loss_gen_cyc_x_b = self.recon_criterion(x_bab, x_b) if x_bab is not None else 0
        self.loss_gen_recon_kl_cyc_aba = self.__compute_kl(h_a_recon)
        self.loss_gen_recon_kl_cyc_bab = self.__compute_kl(h_b_recon)
        self.loss_gen_recon_kl_cyc_sty = self.__compute_kl(h_b_sty_recon)
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        use_vgg = hyperparameters['vgg_w'] > 0
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if use_vgg else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if use_vgg else 0
        # Background guide loss: L1 between blurred translations and the
        # blurred source images, at every blur scale.
        self.loss_bgm = 0
        if hyperparameters['BGM'] != 0:
            for blur, weight in zip(self.BlurNet, self.BlurWeight):
                blur_ba = blur(x_ba)
                blur_b = blur(x_b)
                blur_ab = blur(x_ab)
                blur_a = blur(x_a)
                grad_loss_b = self.recon_criterion(blur_ba, blur_b)
                grad_loss_a = self.recon_criterion(blur_ab, blur_a)
                self.loss_bgm += weight * (grad_loss_a + grad_loss_b)
        # total loss
        self.loss_gen_total = (
            hyperparameters['gan_w'] * self.loss_gen_adv_a
            + hyperparameters['gan_w'] * self.loss_gen_adv_b
            + hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a
            + hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a
            + hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b
            + hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b
            + hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_sty
            + hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_a
            + hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba
            + hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_b
            + hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab
            + hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_sty
            + hyperparameters['vgg_w'] * self.loss_gen_vgg_a
            + hyperparameters['vgg_w'] * self.loss_gen_vgg_b
            + hyperparameters['BGM'] * self.loss_bgm
            + hyperparameters['gan_w'] * self.loss_ContentD
        )
        self.loss_gen_total.backward()
        self.gen_opt.step()
        # NOTE(review): this also steps the content discriminator with the
        # generator's 0.5-target loss; confirm that is intended.
        self.content_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        """Mean squared difference of instance-normalised ``vgg`` features of ``img`` and ``target``."""
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def sample(self, x_a, x_b):
        """Per-item reconstructions and translations; ``None`` if either input is ``None``.

        Returns:
            ``(x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba)``.
        """
        if x_a is None or x_b is None:
            return None
        self.eval()
        x_a_recon, x_b_recon, x_ba, x_ab = [], [], [], []
        for i in range(x_a.size(0)):
            h_a = self.gen_a.encode_cont(x_a[i].unsqueeze(0))
            h_a_sty = self.gen_a.encode_sty(x_a[i].unsqueeze(0))
            h_b = self.gen_b.encode_cont(x_b[i].unsqueeze(0))
            h_ba_cont = torch.cat((h_b, h_a_sty), 1)
            h_aa_cont = torch.cat((h_a, h_a_sty), 1)
            x_a_recon.append(self.gen_a.decode_recs(h_aa_cont))
            x_b_recon.append(self.gen_b.decode_cont(h_b))
            x_ba.append(self.gen_a.decode_recs(h_ba_cont))
            x_ab.append(self.gen_b.decode_cont(h_a))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba = torch.cat(x_ba)
        x_ab = torch.cat(x_ab)
        self.train()
        return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba

    def dis_update(self, x_a, x_b, hyperparameters):
        """Update the image discriminators and the content discriminator in one step."""
        self.dis_opt.zero_grad()
        self.content_opt.zero_grad()
        # encode
        h_a = self.gen_a.encode_cont(x_a)
        h_a_sty = self.gen_a.encode_sty(x_a)
        h_b = self.gen_b.encode_cont(x_b)
        # content adversarial loss: domain-a codes -> 0, domain-b codes -> 1
        out_a = self.dis_content(h_a)
        out_b = self.dis_content(h_b)
        self.loss_ContentD = self._content_adv_loss(out_a, out_b, 0., 1.,
                                                    hyperparameters['gan_type'])
        # decode (cross domain)
        h_cat = torch.cat((h_b, h_a_sty), 1)
        noise_b = torch.randn(h_cat.size()).cuda(h_cat.data.get_device())
        x_ba = self.gen_a.decode_recs(h_cat + noise_b)
        noise_a = torch.randn(h_a.size()).cuda(h_a.data.get_device())
        x_ab = self.gen_b.decode_cont(h_a + noise_a)
        # D loss (fakes detached so the generators get no gradient from here)
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * (
            self.loss_dis_a + self.loss_dis_b + self.loss_ContentD)
        self.loss_dis_total.backward()
        nn.utils.clip_grad_norm_(self.dis_content.parameters(), 5)
        self.dis_opt.step()
        self.content_opt.step()

    def update_learning_rate(self):
        """Advance every learning-rate scheduler that exists."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()
        if self.content_scheduler is not None:
            self.content_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Load the latest generator checkpoint plus the matching dis/content/optimizer state.

        The iteration is parsed from the 8-digit number in the newest
        ``gen_XXXXXXXX.pt``. The discriminator files saved alongside it are
        opened by that same number, mirroring ``save``'s naming. A plain
        ``"dis"`` search key would also match ``dis_Content_*.pt``.

        Returns:
            The resumed iteration number.
        """
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators saved with the same iteration number
        dis_name = os.path.join(checkpoint_dir, 'dis_%08d.pt' % iterations)
        state_dict = torch.load(dis_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load the content discriminator
        dis_con_name = os.path.join(checkpoint_dir, 'dis_Content_%08d.pt' % iterations)
        state_dict = torch.load(dis_con_name)
        self.dis_content.load_state_dict(state_dict['dis_c'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        self.content_opt.load_state_dict(state_dict['dis_content'])
        # Reinitialize schedulers so they continue from the resumed iteration
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        self.content_scheduler = get_scheduler(self.content_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Save generators, discriminators, content discriminator (numbered ``iterations + 1``) and optimizers."""
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        dis_Con_name = os.path.join(snapshot_dir, 'dis_Content_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'dis_c': self.dis_content.state_dict()}, dis_Con_name)
        # optimizer state
        torch.save({'gen': self.gen_opt.state_dict(),
                    'dis': self.dis_opt.state_dict(),
                    'dis_content': self.content_opt.state_dict()}, opt_name)
class MUNIT_Trainer(nn.Module):
    """MUNIT trainer whose a->b translation is guided by a segmentation network.

    Holds two AdaIN auto-encoders (``gen_a``/``gen_b``), two discriminators
    (``dis_a``/``dis_b``) and a segmentor (``seg``). The segmentor is trained
    on translated a->b images (with ``config.task = 0``) and on real b images
    (with ``config.task = 1``). Once ``iters >= guide_gen_iters``, its Dice
    loss on translated images is added to the generator objective, but only
    when the last recorded ``loss_seg_ab`` is not above -0.3.
    """

    def __init__(self, hyperparameters, opts):
        super(MUNIT_Trainer, self).__init__()
        hp = hyperparameters
        lr = hp['lr']
        self.opts = opts

        # Networks: per-domain auto-encoders and discriminators, plus the segmentor.
        self.gen_a = AdaINGen(hp['input_dim_a'], hp['gen'])
        self.gen_b = AdaINGen(hp['input_dim_b'], hp['gen'])
        self.dis_a = MsImageDis(hp['input_dim_a'], hp['dis'])
        self.dis_b = MsImageDis(hp['input_dim_b'], hp['dis'])
        self.seg = segmentor(num_classes=2, channels=hp['input_dim_b'],
                             hyperpars=hp['seg'])
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hp['gen']['style_dim']

        # Fixed style codes reused by forward()/sample() so previews are comparable.
        n_display = int(hp['display_size'])
        self.s_a = torch.randn(n_display, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(n_display, self.style_dim, 1, 1).cuda()

        # Optimizers: Adam for the GAN networks, SGD for the segmentor.
        adam_betas = (hp['beta1'], hp['beta2'])
        trainable_dis = [p for net in (self.dis_a, self.dis_b)
                         for p in net.parameters() if p.requires_grad]
        trainable_gen = [p for net in (self.gen_a, self.gen_b)
                         for p in net.parameters() if p.requires_grad]
        self.dis_opt = torch.optim.Adam(trainable_dis, lr=lr, betas=adam_betas,
                                        weight_decay=hp['weight_decay'])
        self.gen_opt = torch.optim.Adam(trainable_gen, lr=lr, betas=adam_betas,
                                        weight_decay=hp['weight_decay'])
        self.seg_opt = torch.optim.SGD(self.seg.parameters(), lr=0.01,
                                       momentum=0.9, weight_decay=1e-4)
        self.dis_scheduler = get_scheduler(self.dis_opt, hp['lr_policy'], hyperparameters=hp)
        self.gen_scheduler = get_scheduler(self.gen_opt, hp['lr_policy'], hyperparameters=hp)
        self.seg_scheduler = get_scheduler(self.seg_opt, 'constant', hyperparameters=None)

        # Weight init: self.apply reaches every submodule (the segmentor too);
        # the discriminators are then re-initialized with weights_init('gaussian').
        self.apply(weights_init(hp['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        self.criterion_seg = DiceLoss(ignore_index=hp['seg']['ignore_index'])

    def _random_style(self, batch_size):
        """Draw a fresh standard-normal style code of shape (batch_size, style_dim, 1, 1) on the GPU."""
        return Variable(torch.randn(batch_size, self.style_dim, 1, 1).cuda())

    def recon_criterion(self, input, target):
        """Mean absolute (L1) error between ``input`` and ``target``."""
        return (input - target).abs().mean()

    def forward(self, x_a, x_b):
        """Translate a->b and b->a using the fixed sampling styles; returns (x_ab, x_ba)."""
        self.eval()
        fixed_a = Variable(self.s_a)
        fixed_b = Variable(self.s_b)
        content_a, _ = self.gen_a.encode(x_a)
        content_b, _ = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(content_b, fixed_a)
        x_ab = self.gen_b.decode(content_a, fixed_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters, target_a, iters):
        """One generator optimization step (reconstruction, GAN and optional semantic loss).

        Args:
            x_a, x_b: image batches from domains a and b.
            hyperparameters: loss weights and ``guide_gen_iters``.
            target_a: segmentation target for ``x_a``, reused for the a->b translation.
            iters: current iteration; enables segmentor guidance once it reaches
                ``guide_gen_iters``.
        """
        hp = hyperparameters
        self.gen_opt.zero_grad()
        style_a = self._random_style(x_a.size(0))
        style_b = self._random_style(x_b.size(0))

        # Split each image into a content code and its own style code.
        content_a, own_style_a = self.gen_a.encode(x_a)
        content_b, own_style_b = self.gen_b.encode(x_b)
        # Within-domain reconstruction.
        rec_a = self.gen_a.decode(content_a, own_style_a)
        rec_b = self.gen_b.decode(content_b, own_style_b)
        # Cross-domain translation with random styles.
        x_ba = self.gen_a.decode(content_b, style_a)
        x_ab = self.gen_b.decode(content_a, style_b)
        # Re-encode the translations.
        content_b_rec, style_a_rec = self.gen_a.encode(x_ba)
        content_a_rec, style_b_rec = self.gen_b.encode(x_ab)
        # Cycle reconstruction, only computed when its weight is active.
        cyc_w = hp['recon_x_cyc_w']
        x_aba = self.gen_a.decode(content_a_rec, own_style_a) if cyc_w > 0 else None
        x_bab = self.gen_b.decode(content_b_rec, own_style_b) if cyc_w > 0 else None

        guided = iters >= hp['guide_gen_iters']
        if guided:
            # Segment the translated image with the segmentor in eval mode.
            config.task = 0
            self.seg.eval()
            self.pred_x_ab = self.seg(x_ab)
            self.seg.train()

        # Reconstruction losses.
        self.loss_gen_recon_x_a = self.recon_criterion(rec_a, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(rec_b, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(style_a_rec, style_a)
        self.loss_gen_recon_s_b = self.recon_criterion(style_b_rec, style_b)
        self.loss_gen_recon_c_a = self.recon_criterion(content_a_rec, content_a)
        self.loss_gen_recon_c_b = self.recon_criterion(content_b_rec, content_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if cyc_w > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if cyc_w > 0 else 0
        # Adversarial losses.
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)

        # Semantic loss on the a->b translation.
        self.loss_sem_ab = 0
        if guided:
            self.loss_sem_ab, _ = self.criterion_seg(self.pred_x_ab, target_a)
        # Only trust the segmentor once seg_update has recorded a loss that is
        # not above -0.3.
        seg_trusted = (hasattr(self, 'loss_seg_ab')
                       and not self.loss_seg_ab.detach().item() > -0.3)
        if not seg_trusted:
            self.loss_sem_ab = 0

        w_gan, w_x, w_s, w_c = hp['gan_w'], hp['recon_x_w'], hp['recon_s_w'], hp['recon_c_w']
        self.loss_gen_total = (w_gan * self.loss_gen_adv_a
                               + w_gan * self.loss_gen_adv_b
                               + w_x * self.loss_gen_recon_x_a
                               + w_s * self.loss_gen_recon_s_a
                               + w_c * self.loss_gen_recon_c_a
                               + w_x * self.loss_gen_recon_x_b
                               + w_s * self.loss_gen_recon_s_b
                               + w_c * self.loss_gen_recon_c_b
                               + cyc_w * self.loss_gen_cycrecon_x_a
                               + cyc_w * self.loss_gen_cycrecon_x_b
                               + hp['sem_w'] * self.loss_sem_ab)
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def dis_update(self, x_a, x_b, hyperparameters):
        """One discriminator step on real images vs. detached cross-domain translations."""
        self.dis_opt.zero_grad()
        style_a = self._random_style(x_a.size(0))
        style_b = self._random_style(x_b.size(0))
        content_a, _ = self.gen_a.encode(x_a)
        content_b, _ = self.gen_b.encode(x_b)
        fake_a = self.gen_a.decode(content_b, style_a)
        fake_b = self.gen_b.decode(content_a, style_b)
        self.loss_dis_a = self.dis_a.calc_dis_loss(fake_a.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(fake_b.detach(), x_b)
        w_gan = hyperparameters['gan_w']
        self.loss_dis_total = w_gan * self.loss_dis_a + w_gan * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def seg_update(self, x_a, x_b, target_a, target_b):
        """One segmentor step on a->b translations (task 0) and real b images (task 1)."""
        self.seg.train()
        self.seg_opt.zero_grad()
        style_b = self._random_style(x_a.size(0))
        # The generators are frozen here: translate without building a graph.
        with torch.no_grad():
            content_a, _ = self.gen_a.encode(x_a)
            x_ab = self.gen_b.decode(content_a, style_b)
        config.task = 0
        self.pred_x_ab = self.seg(x_ab.detach())
        config.task = 1
        self.pred_x_b = self.seg(x_b)
        self.loss_seg_ab, _ = self.criterion_seg(self.pred_x_ab, target_a)
        self.loss_seg_b, _ = self.criterion_seg(self.pred_x_b, target_b)
        self.loss_seg_total = self.loss_seg_ab + self.loss_seg_b
        self.loss_seg_total.backward()
        self.seg_opt.step()

    def update_learning_rate(self):
        """Advance every configured LR scheduler (dis, gen, seg, in that order)."""
        for scheduler in (self.dis_scheduler, self.gen_scheduler, self.seg_scheduler):
            if scheduler is not None:
                scheduler.step()

    def sample(self, x_a, x_b):
        """Produce per-image reconstructions, translations and cycle images for display."""
        self.eval()
        fixed_a, fixed_b = Variable(self.s_a), Variable(self.s_b)
        rand_a = self._random_style(x_a.size(0))
        rand_b = self._random_style(x_b.size(0))
        outs = {key: [] for key in ('a_recon', 'b_recon', 'aba', 'bab',
                                    'ba1', 'ba2', 'ab1', 'ab2')}
        for i in range(x_b.size(0)):
            content_a, own_a = self.gen_a.encode(x_a[i].unsqueeze(0))
            content_b, own_b = self.gen_b.encode(x_b[i].unsqueeze(0))
            outs['a_recon'].append(self.gen_a.decode(content_a, own_a))
            outs['b_recon'].append(self.gen_b.decode(content_b, own_b))
            outs['ba1'].append(self.gen_a.decode(content_b, fixed_a[i].unsqueeze(0)))
            outs['ba2'].append(self.gen_a.decode(content_b, rand_a[i].unsqueeze(0)))
            outs['ab1'].append(self.gen_b.decode(content_a, fixed_b[i].unsqueeze(0)))
            outs['ab2'].append(self.gen_b.decode(content_a, rand_b[i].unsqueeze(0)))
            # Cycle back through the fixed-style translations.
            content_b_rec, _ = self.gen_a.encode(outs['ba1'][-1])
            content_a_rec, _ = self.gen_b.encode(outs['ab1'][-1])
            outs['aba'].append(self.gen_a.decode(content_a_rec, own_a))
            outs['bab'].append(self.gen_b.decode(content_b_rec, own_b))
        merged = {key: torch.cat(chunks) for key, chunks in outs.items()}
        self.train()
        return (x_a, merged['a_recon'], merged['aba'], merged['ab1'], merged['ab2'],
                x_b, merged['b_recon'], merged['bab'], merged['ba1'], merged['ba2'])

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore networks, optimizers and schedulers from ``checkpoint_dir``; return the iteration."""
        # Generators; the 8 characters before '.pt' of the file name hold the iteration.
        gen_ckpt = get_model_list(checkpoint_dir, 'gen')
        gen_state = torch.load(gen_ckpt)
        self.gen_a.load_state_dict(gen_state['a'])
        self.gen_b.load_state_dict(gen_state['b'])
        iterations = int(gen_ckpt[-11:-3])
        # Discriminators.
        dis_state = torch.load(get_model_list(checkpoint_dir, 'dis'))
        self.dis_a.load_state_dict(dis_state['a'])
        self.dis_b.load_state_dict(dis_state['b'])
        # Segmentor.
        self.seg.load_state_dict(torch.load(get_model_list(checkpoint_dir, 'seg')))
        # Optimizers.
        opt_state = torch.load(os.path.join(checkpoint_dir, 'opt.pt'))
        self.dis_opt.load_state_dict(opt_state['dis'])
        self.gen_opt.load_state_dict(opt_state['gen'])
        self.seg_opt.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'opt_seg.pt')))
        # Schedulers are re-created with the resumed iteration count.
        policy = hyperparameters['lr_policy']
        self.dis_scheduler = get_scheduler(self.dis_opt, policy, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, policy, hyperparameters, iterations)
        self.seg_scheduler = get_scheduler(self.seg_opt, 'constant', None, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Write generator, discriminator, segmentor and optimizer checkpoints to ``snapshot_dir``."""
        tag = iterations + 1

        def _path(file_name):
            return os.path.join(snapshot_dir, file_name)

        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()},
                   _path('gen_%08d.pt' % tag))
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()},
                   _path('dis_%08d.pt' % tag))
        torch.save(self.seg.state_dict(), _path('seg_%08d.pt' % tag))
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()},
                   _path('opt.pt'))
        torch.save(self.seg_opt.state_dict(), _path('opt_seg.pt'))
class MUNIT_Trainer(nn.Module):
    """Multi-domain MUNIT trainer with one shared generator and discriminator.

    Domains are distinguished by concatenating a one-hot domain code (one
    channel per dataset) to images before encoding/discriminating and to
    content codes before decoding. A UNet (``self.sup``) is trained on content
    codes and shares ``gen_opt`` with the generator.
    """

    def __init__(self, hyperparameters, resume_epoch=-1, snapshot_dir=None):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        n_datasets = hyperparameters['n_datasets']

        # Initiate the networks; input channels include the one-hot domain code.
        self.gen = AdaINGen(
            hyperparameters['input_dim'] + n_datasets,
            hyperparameters['gen'],
            n_datasets)
        self.dis = MsImageDis(
            hyperparameters['input_dim'] + n_datasets,
            hyperparameters['dis'])
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']
        self.beta1 = hyperparameters['beta1']
        self.beta2 = hyperparameters['beta2']
        self.weight_decay = hyperparameters['weight_decay']

        # Supervised model operating on content codes (not loaded from disk here).
        self.sup = UNet(input_channels=hyperparameters['input_dim'], num_classes=2).cuda()

        # Fix the noise used in sampling.
        self.s_a = torch.randn(8, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(8, self.style_dim, 1, 1).cuda()

        # Setup the optimizers; the UNet is optimized together with the generator.
        dis_params = list(self.dis.parameters())
        gen_params = list(self.gen.parameters()) + list(self.sup.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr, betas=(self.beta1, self.beta2), weight_decay=self.weight_decay)
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr, betas=(self.beta1, self.beta2), weight_decay=self.weight_decay)
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization (self.apply also reaches self.sup).
        self.apply(weights_init(hyperparameters['init']))
        self.dis.apply(weights_init('gaussian'))

        # Preset one-hot domain codes: one_hot_img[i] / one_hot_c[i] hold a full
        # batch whose channel i is all ones (image-sized 256x256, content 64x64).
        batch_size = hyperparameters['batch_size']
        self.one_hot_img = torch.zeros(n_datasets, batch_size, n_datasets, 256, 256).cuda()
        self.one_hot_c = torch.zeros(n_datasets, batch_size, n_datasets, 64, 64).cuda()
        for i in range(n_datasets):
            self.one_hot_img[i, :, i, :, :].fill_(1)
            self.one_hot_c[i, :, i, :, :].fill_(1)

        if resume_epoch != -1:
            self.resume(snapshot_dir, hyperparameters)

    def _img_code(self, d_index, batch):
        """Image-level one-hot code of domain ``d_index`` for ``batch`` samples."""
        return self.one_hot_img[d_index, :batch]

    def _content_code(self, d_index, batch):
        """Content-level one-hot code of domain ``d_index`` for ``batch`` samples."""
        return self.one_hot_c[d_index, :batch]

    def recon_criterion(self, input, target):
        """Mean absolute (L1) error."""
        return torch.mean(torch.abs(input - target))

    def semi_criterion(self, input, target):
        """Summed 2-D cross-entropy of ``input`` logits against ``target`` labels."""
        loss = CrossEntropyLoss2d(size_average=False).cuda()
        return loss(input, target)

    def forward(self, x_a, x_b, d_index_a=0, d_index_b=1):
        """Translate x_a into domain ``d_index_b`` and x_b into ``d_index_a``.

        Uses the fixed sampling styles; returns ``(x_ab, x_ba)``. The domain
        indices were previously referenced without being defined (NameError).
        """
        self.eval()
        with torch.no_grad():
            one_hot_x_a = torch.cat([x_a, self._img_code(d_index_a, x_a.size(0))], 1)
            one_hot_x_b = torch.cat([x_b, self._img_code(d_index_b, x_b.size(0))], 1)
            c_a, _ = self.gen.encode(one_hot_x_a)
            c_b, _ = self.gen.encode(one_hot_x_b)
            one_hot_c_b = torch.cat([c_b, self._content_code(d_index_a, c_b.size(0))], 1)
            one_hot_c_a = torch.cat([c_a, self._content_code(d_index_b, c_a.size(0))], 1)
            x_ba = self.gen.decode(one_hot_c_b, self.s_a)
            x_ab = self.gen.decode(one_hot_c_a, self.s_b)
        self.train()
        return x_ab, x_ba

    def set_gen_trainable(self, train_bool):
        """Put the generator in train/eval mode and (un)freeze its parameters."""
        trainable = bool(train_bool)
        self.gen.train(trainable)
        for param in self.gen.parameters():
            # Previously always True, so False never froze the generator.
            param.requires_grad = trainable

    def set_sup_trainable(self, train_bool):
        """Put the supervised model in train/eval mode and (un)freeze its parameters."""
        trainable = bool(train_bool)
        self.sup.train(trainable)
        for param in self.sup.parameters():
            param.requires_grad = trainable

    def sup_update(self, x_a, x_b, y_a, y_b, d_index_a, d_index_b, use_a, use_b, hyperparameters):
        """Supervised step on content codes of real and cycle-reconstructed images.

        ``use_a``/``use_b`` index the labeled samples of each batch. If neither
        batch has labels, no optimizer step is taken.
        """
        self.gen_opt.zero_grad()
        s_a = torch.randn(x_a.size(0), self.style_dim, 1, 1).to(x_a.device)
        s_b = torch.randn(x_b.size(0), self.style_dim, 1, 1).to(x_b.device)
        n_a, n_b = x_a.size(0), x_b.size(0)

        # Encode.
        c_a, _ = self.gen.encode(torch.cat([x_a, self._img_code(d_index_a, n_a)], 1))
        c_b, _ = self.gen.encode(torch.cat([x_b, self._img_code(d_index_b, n_b)], 1))
        # Decode (cross domain).
        x_ba = self.gen.decode(torch.cat([c_b, self._content_code(d_index_a, n_b)], 1), s_a)
        x_ab = self.gen.decode(torch.cat([c_a, self._content_code(d_index_b, n_a)], 1), s_b)
        # Encode again to obtain reconstructed content codes.
        c_b_recon, _ = self.gen.encode(torch.cat([x_ba, self._img_code(d_index_a, n_b)], 1))
        c_a_recon, _ = self.gen.encode(torch.cat([x_ab, self._img_code(d_index_b, n_a)], 1))

        # Forwarding through supervised model.
        loss_semi_a = None
        loss_semi_b = None
        if c_a[use_a, :, :, :].size(0) != 0:
            p_a = self.sup(c_a, use_a, True)
            p_a_recon = self.sup(c_a_recon, use_a, True)
            loss_semi_a = self.semi_criterion(p_a, y_a[use_a, :, :]) + \
                self.semi_criterion(p_a_recon, y_a[use_a, :, :])
        if c_b[use_b, :, :, :].size(0) != 0:
            p_b = self.sup(c_b, use_b, True)
            # Was self.sup(c_b, ...): the recon branch must use c_b_recon.
            p_b_recon = self.sup(c_b_recon, use_b, True)
            loss_semi_b = self.semi_criterion(p_b, y_b[use_b, :, :]) + \
                self.semi_criterion(p_b_recon, y_b[use_b, :, :])

        self.loss_gen_total = None
        for loss in (loss_semi_a, loss_semi_b):
            if loss is not None:
                self.loss_gen_total = loss if self.loss_gen_total is None \
                    else self.loss_gen_total + loss
        if self.loss_gen_total is not None:
            self.loss_gen_total.backward()
            self.gen_opt.step()

    def sup_forward(self, x, y, d_index, hyperparameters):
        """Predict a segmentation for a single image; returns (jaccard, pred, content)."""
        self.sup.eval()
        # Encoding content image (batch of one).
        one_hot_x = torch.cat([x, self._img_code(d_index, 1)], 1)
        content, _ = self.gen.encode(one_hot_x)
        # Forwarding on supervised model.
        y_pred = self.sup(content, only_prediction=True)
        # Computing metrics.
        pred = y_pred.data.max(1)[1].squeeze_(1).squeeze_(0).cpu().numpy()
        jacc = jaccard(pred, y.cpu().squeeze(0).numpy())
        return jacc, pred, content

    def gen_update(self, x_a, x_b, d_index_a, d_index_b, hyperparameters):
        """One generator step with reconstruction, cycle and GAN losses."""
        self.gen_opt.zero_grad()
        s_a = torch.randn(x_a.size(0), self.style_dim, 1, 1).to(x_a.device)
        s_b = torch.randn(x_b.size(0), self.style_dim, 1, 1).to(x_b.device)
        n_a, n_b = x_a.size(0), x_b.size(0)
        # Encode.
        one_hot_x_a = torch.cat([x_a, self._img_code(d_index_a, n_a)], 1)
        one_hot_x_b = torch.cat([x_b, self._img_code(d_index_b, n_b)], 1)
        c_a, s_a_prime = self.gen.encode(one_hot_x_a)
        c_b, s_b_prime = self.gen.encode(one_hot_x_b)
        # Decode (within domain).
        one_hot_c_a = torch.cat([c_a, self._content_code(d_index_a, n_a)], 1)
        one_hot_c_b = torch.cat([c_b, self._content_code(d_index_b, n_b)], 1)
        x_a_recon = self.gen.decode(one_hot_c_a, s_a_prime)
        x_b_recon = self.gen.decode(one_hot_c_b, s_b_prime)
        # Decode (cross domain).
        one_hot_c_ab = torch.cat([c_a, self._content_code(d_index_b, n_a)], 1)
        one_hot_c_ba = torch.cat([c_b, self._content_code(d_index_a, n_b)], 1)
        x_ba = self.gen.decode(one_hot_c_ba, s_a)
        x_ab = self.gen.decode(one_hot_c_ab, s_b)
        # Encode again.
        one_hot_x_ba = torch.cat([x_ba, self._img_code(d_index_a, n_b)], 1)
        one_hot_x_ab = torch.cat([x_ab, self._img_code(d_index_b, n_a)], 1)
        c_b_recon, s_a_recon = self.gen.encode(one_hot_x_ba)
        c_a_recon, s_b_recon = self.gen.encode(one_hot_x_ab)
        # Decode again.
        one_hot_c_aba_recon = torch.cat([c_a_recon, self._content_code(d_index_a, n_a)], 1)
        one_hot_c_bab_recon = torch.cat([c_b_recon, self._content_code(d_index_b, n_b)], 1)
        x_aba = self.gen.decode(one_hot_c_aba_recon, s_a_prime)
        x_bab = self.gen.decode(one_hot_c_bab_recon, s_b_prime)
        # Reconstruction loss.
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a)
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b)
        # GAN loss.
        self.loss_gen_adv_a = self.dis.calc_gen_loss(one_hot_x_ba)
        self.loss_gen_adv_b = self.dis.calc_gen_loss(one_hot_x_ab)
        # Total loss.
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
            hyperparameters['gan_w'] * self.loss_gen_adv_b + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        """Mean squared distance between instance-normalized VGG features."""
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def sample(self, x_a, x_b, d_index_a=0, d_index_b=1):
        """Per-image reconstructions and translations for display.

        Previously referenced undefined ``one_hot_img_a``/``one_hot_img_b`` and
        decoded content without its domain code. Returns
        ``(x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2)``.
        """
        self.eval()
        with torch.no_grad():
            s_a1, s_b1 = self.s_a, self.s_b
            s_a2 = torch.randn(x_a.size(0), self.style_dim, 1, 1).to(x_a.device)
            s_b2 = torch.randn(x_b.size(0), self.style_dim, 1, 1).to(x_b.device)
            img_code_a = self._img_code(d_index_a, 1)
            img_code_b = self._img_code(d_index_b, 1)
            c_code_a = self._content_code(d_index_a, 1)
            c_code_b = self._content_code(d_index_b, 1)
            x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
            for i in range(x_a.size(0)):
                c_a, s_a_fake = self.gen.encode(torch.cat([x_a[i].unsqueeze(0), img_code_a], 1))
                c_b, s_b_fake = self.gen.encode(torch.cat([x_b[i].unsqueeze(0), img_code_b], 1))
                # Content codes tagged with the domain to decode into.
                c_a_to_a = torch.cat([c_a, c_code_a], 1)
                c_a_to_b = torch.cat([c_a, c_code_b], 1)
                c_b_to_a = torch.cat([c_b, c_code_a], 1)
                c_b_to_b = torch.cat([c_b, c_code_b], 1)
                x_a_recon.append(self.gen.decode(c_a_to_a, s_a_fake))
                x_b_recon.append(self.gen.decode(c_b_to_b, s_b_fake))
                x_ba1.append(self.gen.decode(c_b_to_a, s_a1[i].unsqueeze(0)))
                x_ba2.append(self.gen.decode(c_b_to_a, s_a2[i].unsqueeze(0)))
                x_ab1.append(self.gen.decode(c_a_to_b, s_b1[i].unsqueeze(0)))
                x_ab2.append(self.gen.decode(c_a_to_b, s_b2[i].unsqueeze(0)))
            x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
            x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
            x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_update(self, x_a, x_b, d_index_a, d_index_b, hyperparameters):
        """One discriminator step on real images vs. detached translations."""
        self.dis_opt.zero_grad()
        s_a = torch.randn(x_a.size(0), self.style_dim, 1, 1).to(x_a.device)
        s_b = torch.randn(x_b.size(0), self.style_dim, 1, 1).to(x_b.device)
        n_a, n_b = x_a.size(0), x_b.size(0)
        # Encode.
        one_hot_x_a = torch.cat([x_a, self._img_code(d_index_a, n_a)], 1)
        one_hot_x_b = torch.cat([x_b, self._img_code(d_index_b, n_b)], 1)
        c_a, _ = self.gen.encode(one_hot_x_a)
        c_b, _ = self.gen.encode(one_hot_x_b)
        one_hot_c_ba = torch.cat([c_b, self._content_code(d_index_a, n_b)], 1)
        one_hot_c_ab = torch.cat([c_a, self._content_code(d_index_b, n_a)], 1)
        # Decode (cross domain).
        x_ba = self.gen.decode(one_hot_c_ba, s_a)
        x_ab = self.gen.decode(one_hot_c_ab, s_b)
        # D loss. Fakes are detached so this backward does not run through
        # the generator (the gradients were discarded by gen_opt.zero_grad anyway).
        one_hot_x_ba = torch.cat([x_ba.detach(), self._img_code(d_index_a, n_b)], 1)
        one_hot_x_ab = torch.cat([x_ab.detach(), self._img_code(d_index_b, n_a)], 1)
        self.loss_dis_a = self.dis.calc_dis_loss(one_hot_x_ba, one_hot_x_a)
        self.loss_dis_b = self.dis.calc_dis_loss(one_hot_x_ab, one_hot_x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + \
            hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Advance the discriminator and generator LR schedulers, if any."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore networks, optimizers and schedulers; returns the resumed epoch."""
        print("--> " + checkpoint_dir)
        # Load generator; the 8 characters before '.pt' hold the epoch.
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen.load_state_dict(state_dict)
        epochs = int(last_model_name[-11:-3])
        # Load supervised model.
        last_model_name = get_model_list(checkpoint_dir, "sup")
        state_dict = torch.load(last_model_name)
        self.sup.load_state_dict(state_dict)
        # Load discriminator.
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis.load_state_dict(state_dict)
        # Load optimizers and move their state tensors to the GPU.
        last_model_name = get_model_list(checkpoint_dir, "opt")
        state_dict = torch.load(last_model_name)
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        for opt in (self.dis_opt, self.gen_opt):
            for state in opt.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.cuda()
        # Reinitialize schedulers.
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, epochs)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, epochs)
        print('Resume from epoch %d' % epochs)
        return epochs

    def save(self, snapshot_dir, epoch):
        """Save generator, discriminator, supervised model and optimizers for ``epoch``."""
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % epoch)
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % epoch)
        sup_name = os.path.join(snapshot_dir, 'sup_%08d.pt' % epoch)
        opt_name = os.path.join(snapshot_dir, 'opt_%08d.pt' % epoch)
        torch.save(self.gen.state_dict(), gen_name)
        torch.save(self.dis.state_dict(), dis_name)
        torch.save(self.sup.state_dict(), sup_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
class UNIT_Trainer(nn.Module):
    """UNIT trainer with VAE generators and switchable discriminators.

    When ``hyperparameters['origin']`` is falsy, each domain gets a
    ``MultiscaleDiscriminator`` that returns intermediate features. The
    generator is then trained with ``GANLoss(use_lsgan=True)`` plus a
    feature-matching term. Otherwise the ``MsImageDis`` discriminators and
    their own ``calc_gen_loss``/``calc_dis_loss`` are used.
    """

    def __init__(self, hyperparameters):
        super(UNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = VAEGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = VAEGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        if not hyperparameters['origin']:
            self.dis_a = MultiscaleDiscriminator(hyperparameters['input_dim_a'],  # discriminator for a
                                                 ndf=64, n_layers=3, norm_layer=nn.InstanceNorm2d,
                                                 use_sigmoid=False, num_D=2, getIntermFeat=True)
            self.dis_b = MultiscaleDiscriminator(hyperparameters['input_dim_b'],  # discriminator for b
                                                 ndf=64, n_layers=3, norm_layer=nn.InstanceNorm2d,
                                                 use_sigmoid=False, num_D=2, getIntermFeat=True)
            self.criterionGAN = GANLoss(use_lsgan=True, tensor=torch.cuda.FloatTensor)
        else:
            self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])
            self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2),
                                        weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2),
                                        weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load a frozen VGG model if the perceptual loss is enabled
        if hyperparameters.get('vgg_w', 0) > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def compute_digits_differnce(self, digits1, digits2, weight=1.0):
        """Feature-matching L1 loss between two multiscale discriminator outputs.

        Assumes ``digitsX[i][j]`` is the j-th intermediate output of the i-th
        (of 2) scale discriminators; the last entry of each scale is skipped.
        ``digits1`` is the detached reference.
        """
        feat_diff = 0
        feat_weights = 4.0 / (3 + 1)  # 3 layers' discriminator
        D_weights = 1.0 / 2.0  # number of discriminators
        for i in range(2):
            for j in range(len(digits2[i]) - 1):
                feat_diff += D_weights * feat_weights * \
                    F.l1_loss(digits2[i][j], digits1[i][j].detach()) * weight
        return feat_diff

    def compute_gan_loss(self, real_digits, fake_digits, gan_cri, loss_at='None'):
        """Return ``(errD, errG, errG_feat)``; only the entries for ``loss_at`` ('D' or 'G') are set."""
        errD = None
        errG = None
        errG_feat = None
        if gan_cri is not None:
            if loss_at == 'D':
                errD = (gan_cri(real_digits, True) + gan_cri(fake_digits, False)) * 0.5
            elif loss_at == 'G':
                errG = gan_cri(fake_digits, True)
                errG_feat = self.compute_digits_differnce(real_digits, fake_digits, weight=10.0)
        return errD, errG, errG_feat

    def recon_criterion(self, input, target):
        """Mean absolute (L1) error."""
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        """Translate a->b and b->a from the encoder means; returns (x_ab, x_ba)."""
        self.eval()
        h_a, _ = self.gen_a.encode(x_a)
        h_b, _ = self.gen_b.encode(x_b)
        x_ba, _ = self.gen_a.decode(h_b)
        x_ab, _ = self.gen_b.decode(h_a)
        self.train()
        return x_ab, x_ba

    def __compute_kl(self, mu):
        """KL-style penalty used by UNIT: the mean of ``mu ** 2``."""
        mu_2 = torch.pow(mu, 2)
        encoding_loss = torch.mean(mu_2)
        return encoding_loss

    def gen_update(self, x_a, x_b, hyperparameters):
        """One generator step (reconstruction, KL, cycle, GAN, VGG and optional eg losses)."""
        self.gen_opt.zero_grad()
        use_orig_dis = hyperparameters['origin']
        eg_weight = hyperparameters['loss_eg_weight']
        vgg_w = hyperparameters.get('vgg_w', 0)
        cyc_w = hyperparameters['recon_x_cyc_w']

        # Encode
        h_a, n_a = self.gen_a.encode(x_a)
        h_b, n_b = self.gen_b.encode(x_b)
        # Optional all-zero latent vector, sized from the actual batch so a short
        # final batch does not mismatch the configured batch_size.
        if hyperparameters['zero_z']:
            pre_Z = torch.zeros(x_a.size(0), hyperparameters['gen']['z_num'], device=x_a.device)
        else:
            pre_Z = None

        # Reconstruction (within domain)
        x_a_recon, _ = self.gen_a.decode(h_a + n_a, z_var=pre_Z)
        x_b_recon, _ = self.gen_b.decode(h_b + n_b, z_var=pre_Z)

        # Decode (cross domain) with the decoder's default latent vector
        x_ba, z_var_ba_1 = self.gen_a.decode(h_b + n_b)
        x_ab, z_var_ab_1 = self.gen_b.decode(h_a + n_a)
        # ... and with the zero latent vector
        x_ba_zero, _ = self.gen_a.decode(h_b + n_b, z_var=pre_Z)
        x_ab_zero, _ = self.gen_b.decode(h_a + n_a, z_var=pre_Z)

        # Second cross-domain decode, only needed for the eg loss
        if eg_weight != 0:
            x_ba_eg, z_var_ba_2 = self.gen_a.decode(h_b + n_b)
            x_ab_eg, z_var_ab_2 = self.gen_b.decode(h_a + n_a)
            x_ba_eg = x_ba_eg.detach()
            x_ab_eg = x_ab_eg.detach()
            if not use_orig_dis:
                x_ba_eg_digits = self.dis_a(x_ba_eg)
                x_ab_eg_digits = self.dis_b(x_ab_eg)

        # Encode again
        h_b_recon, n_b_recon = self.gen_a.encode(x_ba_zero)
        h_a_recon, n_a_recon = self.gen_b.encode(x_ab_zero)
        # Decode again (if needed)
        if cyc_w > 0:
            x_aba, _ = self.gen_a.decode(h_a_recon + n_a_recon, z_var=pre_Z)
            x_bab, _ = self.gen_b.decode(h_b_recon + n_b_recon, z_var=pre_Z)
        else:
            x_aba, x_bab = None, None

        # Reconstruction and KL losses
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_kl_a = self.__compute_kl(h_a)
        self.loss_gen_recon_kl_b = self.__compute_kl(h_b)
        self.loss_gen_cyc_x_a = self.recon_criterion(x_aba, x_a) if x_aba is not None else 0
        self.loss_gen_cyc_x_b = self.recon_criterion(x_bab, x_b) if x_bab is not None else 0
        self.loss_gen_recon_kl_cyc_aba = self.__compute_kl(h_a_recon)
        self.loss_gen_recon_kl_cyc_bab = self.__compute_kl(h_b_recon)

        # GAN loss. Previously zeroed whenever loss_eg_weight == 0, which silently
        # removed all adversarial signal; it is now independent of the eg loss.
        if not use_orig_dis:
            x_ba_digits = self.dis_a(x_ba)
            x_a_digits = self.dis_a(x_a)
            _, self.loss_gen_adv_a, self.loss_gan_feat_a = \
                self.compute_gan_loss(x_a_digits, x_ba_digits, self.criterionGAN, loss_at='G')
            # Domain-b images are judged by dis_b, the discriminator dis_update
            # trains on domain b (this previously used dis_a).
            x_ab_digits = self.dis_b(x_ab)
            x_b_digits = self.dis_b(x_b)
            _, self.loss_gen_adv_b, self.loss_gan_feat_b = \
                self.compute_gan_loss(x_b_digits, x_ab_digits, self.criterionGAN, loss_at='G')
        else:
            self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
            self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
            self.loss_gan_feat_a, self.loss_gan_feat_b = 0, 0

        # Domain-invariant perceptual loss
        if vgg_w > 0:
            self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b)
            self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a)
        else:
            self.loss_gen_vgg_a, self.loss_gen_vgg_b = 0, 0

        if eg_weight == 0:
            self.loss_eg = 0.0
        elif not use_orig_dis:
            self.loss_eg = compute_eg_loss(x_ba_digits, x_ba_eg_digits, x_ab_digits, x_ab_eg_digits,
                                           z_var_ba_1, z_var_ba_2, z_var_ab_1, z_var_ab_2,
                                           hyperparameters)
        else:
            self.loss_eg = compute_eg_loss(x_ba, x_ba_eg, x_ab, x_ab_eg,
                                           z_var_ba_1, z_var_ba_2, z_var_ab_1, z_var_ab_2,
                                           hyperparameters)

        # Total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
            hyperparameters['gan_w'] * self.loss_gen_adv_b + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \
            cyc_w * self.loss_gen_cyc_x_a + \
            hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \
            cyc_w * self.loss_gen_cyc_x_b + \
            hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab + \
            vgg_w * self.loss_gen_vgg_a + \
            self.loss_gan_feat_b + self.loss_gan_feat_a + \
            vgg_w * self.loss_gen_vgg_b + \
            eg_weight * self.loss_eg
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        """Mean squared distance between instance-normalized VGG features."""
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def sample(self, x_a, x_b):
        """Per-image reconstructions and translations for display."""
        self.eval()
        x_a_recon, x_b_recon, x_ba, x_ab = [], [], [], []
        for i in range(x_a.size(0)):
            h_a, _ = self.gen_a.encode(x_a[i].unsqueeze(0))
            h_b, _ = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(h_a)[0])
            x_b_recon.append(self.gen_b.decode(h_b)[0])
            x_ba.append(self.gen_a.decode(h_b)[0])
            x_ab.append(self.gen_b.decode(h_a)[0])
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba = torch.cat(x_ba)
        x_ab = torch.cat(x_ab)
        self.train()
        return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba

    def dis_update(self, x_a, x_b, hyperparameters):
        """One discriminator step on real images vs. detached translations."""
        self.dis_opt.zero_grad()
        # encode
        h_a, n_a = self.gen_a.encode(x_a)
        h_b, n_b = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba, _ = self.gen_a.decode(h_b + n_b)
        x_ab, _ = self.gen_b.decode(h_a + n_a)
        # D loss
        if not hyperparameters['origin']:
            real_digits_a = self.dis_a(x_a)
            fake_digits_a = self.dis_a(x_ba.detach())
            real_digits_b = self.dis_b(x_b)
            fake_digits_b = self.dis_b(x_ab.detach())
            self.loss_dis_a, _, _ = self.compute_gan_loss(real_digits_a, fake_digits_a,
                                                          self.criterionGAN, loss_at='D')
            self.loss_dis_b, _, _ = self.compute_gan_loss(real_digits_b, fake_digits_b,
                                                          self.criterionGAN, loss_at='D')
        else:
            self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
            self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + \
            hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Advance the discriminator and generator LR schedulers, if any."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore networks, optimizers and schedulers; returns the resumed iteration."""
        # Load generators; the 8 characters before '.pt' hold the iteration
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Save generators, discriminators and optimizers."""
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
class MUNIT_Trainer(nn.Module):
    """MUNIT trainer with auxiliary per-channel feature losses.

    Two AdaIN auto-encoders (``gen_a``/``gen_b``) and two multi-scale
    discriminators (``dis_a``/``dis_b``) are trained with the usual MUNIT
    reconstruction / cycle / adversarial losses. On top of that, channel 0
    of a generated image is treated as a spectrum (per the helper names) and
    channels 1, 2 and 3 are pulled towards a cepstrum, a spectral flux and a
    low-order spectral envelope recomputed from channel 0 with NumPy/SciPy.
    """

    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # Fixed noise used in sampling. sample() indexes these per image,
        # so it can handle at most 8 images per call.
        self.s_a = torch.randn(8, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(8, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            eps=1e-8,
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            eps=1e-8,
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))
        self.iter = 0

    @staticmethod
    def _channel(x, idx):
        """Return channel ``idx`` of ``x`` (dim 1), keeping the channel dim.

        Equivalent to ``x.index_select(1, torch.tensor([idx]))`` but works
        on whatever device ``x`` lives on.
        """
        return x[:, idx:idx + 1]

    def recon_criterion(self, input, target):
        """Mean absolute (L1) difference between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    def intrinsic_criterion(self, input, target):
        """Mean absolute (L1) difference used by the per-channel feature losses."""
        return torch.mean(torch.abs(input - target))

    def volumeloss_criterion(self, input, target):
        """L1 distance between the spatial means of channel 0 of two images.

        Channel 0 is averaged over dim 3 and then dim 2, and the resulting
        per-sample means are compared with a mean absolute difference.
        """
        input, target = self._channel(input, 0), self._channel(target, 0)
        input, target = torch.mean(input, 3), torch.mean(target, 3)
        input, target = torch.mean(input, 2), torch.mean(target, 2)
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        """Translate x_a -> domain b and x_b -> domain a with the fixed styles.

        Runs without gradients in eval mode, then switches back to train mode.
        NOTE(review): the fixed styles have 8 rows, so the batch size
        presumably has to be compatible with 8; confirm against
        AdaINGen.decode.

        Returns:
            (x_ab, x_ba)
        """
        with torch.no_grad():
            self.eval()
            s_a = Variable(self.s_a)
            s_b = Variable(self.s_b)
            c_a, s_a_fake = self.gen_a.encode(x_a)
            c_b, s_b_fake = self.gen_b.encode(x_b)
            x_ba = self.gen_a.decode(c_b, s_a)
            x_ab = self.gen_b.decode(c_a, s_b)
            self.train()
            return x_ab, x_ba

    def update_iter(self):
        """Increment the trainer's iteration counter."""
        self.iter += 1

    def gen_update(self, x_a, x_b, hyperparameters, x_a_rand=None,
                   x_b_rand=None):
        """Run one optimisation step of both generators.

        Arguments:
            x_a, x_b: batches from domain a and domain b.
            hyperparameters: dict of loss weights (``gan_w``, ``recon_*_w``,
                ``ceps_w``, ``flux_w``, ``enve_w``, ``vol_w``) and
                ``clip_grad`` ('value', 'norm' or anything else for none).
            x_a_rand, x_b_rand: forwarded as the second argument of each
                discriminator's ``calc_gen_loss``.
        """
        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(
            c_a_recon,
            s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(
            c_b_recon,
            s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba, x_a_rand)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab, x_b_rand)
        # ceps loss
        self.loss_gen_ceps_a = self.calc_cepstrum_loss(
            x_ba) if hyperparameters['ceps_w'] > 0 else 0
        self.loss_gen_ceps_b = self.calc_cepstrum_loss(
            x_ab) if hyperparameters['ceps_w'] > 0 else 0
        # flux loss
        self.loss_gen_flux_a2b = self.calc_spectral_flux_loss(
            x_ab) if hyperparameters['flux_w'] > 0 else 0
        self.loss_gen_flux_b2a = self.calc_spectral_flux_loss(
            x_ba) if hyperparameters['flux_w'] > 0 else 0
        # enve loss
        self.loss_gen_enve_a2b = self.calc_spectral_enve15_loss(
            x_ab) if hyperparameters['enve_w'] > 0 else 0
        self.loss_gen_enve_b2a = self.calc_spectral_enve15_loss(
            x_ba) if hyperparameters['enve_w'] > 0 else 0
        # volume loss
        self.loss_gen_vol_a = self.volumeloss_criterion(
            x_a, x_ab) if hyperparameters['vol_w'] > 0 else 0
        self.loss_gen_vol_b = self.volumeloss_criterion(
            x_b, x_ba) if hyperparameters['vol_w'] > 0 else 0
        # total loss (summation order kept as-is)
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
            hyperparameters['gan_w'] * self.loss_gen_adv_b + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
            hyperparameters['ceps_w'] * self.loss_gen_ceps_a + \
            hyperparameters['ceps_w'] * self.loss_gen_ceps_b + \
            hyperparameters['flux_w'] * self.loss_gen_flux_a2b + \
            hyperparameters['flux_w'] * self.loss_gen_flux_b2a + \
            hyperparameters['enve_w'] * self.loss_gen_enve_a2b + \
            hyperparameters['enve_w'] * self.loss_gen_enve_b2a + \
            hyperparameters['vol_w'] * self.loss_gen_vol_a + \
            hyperparameters['vol_w'] * self.loss_gen_vol_b
        self.loss_gen_total.backward()
        if hyperparameters['clip_grad'] == 'value':
            torch.nn.utils.clip_grad_value_(
                list(self.gen_a.parameters()) +
                list(self.gen_b.parameters()), 1)
        elif hyperparameters['clip_grad'] == 'norm':
            torch.nn.utils.clip_grad_norm_(
                list(self.gen_a.parameters()) +
                list(self.gen_b.parameters()), 0.5)
        self.gen_opt.step()

    def calc_cepstrum_loss(self, x_fake):
        """L1 loss between channel 1 and the rectified DCT of channel 0.

        The target is computed from a detached NumPy copy of channel 0
        (orthonormal type-II DCT along axis 2, clipped at 0), so gradients
        only flow through channel 1.
        """
        fake_spec = self._channel(x_fake, 0).detach().cpu().numpy()
        ceps = scipy.fftpack.dct(fake_spec, axis=2, type=2, norm='ortho')
        ceps = np.maximum(ceps, 0)
        target = torch.from_numpy(ceps).to(x_fake.device)
        return self.intrinsic_criterion(self._channel(x_fake, 1), target)

    def calc_spectral_flux_loss(self, x_fake):
        """L1 loss between channel 2 and the positive flux of channel 0.

        For every interior column ``i`` of channel 0 the flux is
        ``max(spec[..., i + 1] - spec[..., i - 1], 0)``. The first and last
        columns copy their neighbours. The width comes from the input
        itself; it was previously hard-coded to 256.
        """
        fake_spec = self._channel(x_fake, 0).detach().cpu().numpy()
        spec_flux = np.zeros_like(fake_spec)
        if fake_spec.shape[-1] > 1:
            # Central difference over the last axis, rectified at zero.
            spec_flux[..., 1:-1] = np.maximum(
                fake_spec[..., 2:] - fake_spec[..., :-2], 0.0)
            spec_flux[..., 0] = spec_flux[..., 1]
            spec_flux[..., -1] = spec_flux[..., -2]
        target = torch.from_numpy(spec_flux).to(x_fake.device)
        return self.intrinsic_criterion(self._channel(x_fake, 2), target)

    def calc_spectral_enve15_loss(self, x_fake):
        """L1 loss between channel 3 and a 15-coefficient envelope of channel 0.

        Channel 0 is DCT-transformed along axis 2, coefficients from index
        15 onwards are zeroed, and the inverse DCT (clipped at 0) is the
        detached target.
        """
        fake_spec = self._channel(x_fake, 0).detach().cpu().numpy()
        mfcc = scipy.fftpack.dct(fake_spec, axis=2, type=2, norm='ortho')
        mfcc[:, :, 15:, :] = 0.0
        spec_enve = scipy.fftpack.idct(mfcc, axis=2, type=2, norm='ortho')
        spec_enve = np.maximum(spec_enve, 0.0)
        target = torch.from_numpy(spec_enve).to(x_fake.device)
        return self.intrinsic_criterion(self._channel(x_fake, 3), target)

    def sample(self, x_a, x_b):
        """Build a visualisation batch without gradients.

        Each image is encoded on its own. The reconstruction and two
        translations are produced per domain: one with the fixed styles
        ``self.s_a``/``self.s_b`` and one with fresh random styles. Because
        the fixed styles have 8 rows, at most 8 images can be sampled.

        Returns:
            (x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2)
        """
        with torch.no_grad():
            self.eval()
            s_a1 = Variable(self.s_a)
            s_b1 = Variable(self.s_b)
            s_a2 = Variable(
                torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
            s_b2 = Variable(
                torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
            x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
            for i in range(x_a.size(0)):
                c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
                c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
                x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
                x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
                x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
                x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
                x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
                x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
            x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
            x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
            x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
            self.train()
            return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_update(self, x_a, x_b, hyperparameters):
        """Run one optimisation step of both discriminators.

        Cross-domain translations use random styles and are detached, so
        only the discriminators receive gradients.
        """
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters[
            'gan_w'] * self.loss_dis_a + hyperparameters[
                'gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Step the discriminator and generator LR schedulers, if present."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore networks, optimizers and schedulers from ``checkpoint_dir``.

        Returns:
            int -- the iteration parsed from the generator checkpoint name
            (the 8 digits before ``.pt``, matching ``save``'s format).
        """
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers at the resumed iteration
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Save generators, discriminators and optimizers to ``snapshot_dir``.

        Network files are named with ``iterations + 1``; the optimizer file
        is always ``optimizer.pt`` and is overwritten on each call.
        """
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
class MUNIT_Trainer(nn.Module): def __init__(self, hyperparameters): super(MUNIT_Trainer, self).__init__() lr = hyperparameters["lr"] self.gen_state = hyperparameters["gen_state"] self.guided = hyperparameters["guided"] self.newsize = hyperparameters["crop_image_height"] self.semantic_w = hyperparameters["semantic_w"] > 0 self.recon_mask = hyperparameters["recon_mask"] == 1 self.check_alignment = hyperparameters["check_alignment"] == 1 self.full_adaptation = hyperparameters["full_adaptation"] == 1 self.dann_scheduler = None self.full_adaptation = hyperparameters["full_adaptation"] == 1 if "domain_adv_w" in hyperparameters.keys(): self.domain_classif = hyperparameters["domain_adv_w"] > 0 else: self.domain_classif = False if self.gen_state == 0: # Initiate the networks self.gen_a = AdaINGen( hyperparameters["input_dim_a"], hyperparameters["gen"]) # auto-encoder for domain a self.gen_b = AdaINGen( hyperparameters["input_dim_b"], hyperparameters["gen"]) # auto-encoder for domain b elif self.gen_state == 1: self.gen = AdaINGen_double(hyperparameters["input_dim_a"], hyperparameters["gen"]) else: print("self.gen_state unknown value:", self.gen_state) self.dis_a = MsImageDis( hyperparameters["input_dim_a"], hyperparameters["dis"]) # discriminator for domain a self.dis_b = MsImageDis( hyperparameters["input_dim_b"], hyperparameters["dis"]) # discriminator for domain b self.instancenorm = nn.InstanceNorm2d(512, affine=False) self.style_dim = hyperparameters["gen"]["style_dim"] # fix the noise used in sampling display_size = int(hyperparameters["display_size"]) print(self.style_dim) print(display_size) self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda() self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda() # Setup the optimizers beta1 = hyperparameters["beta1"] beta2 = hyperparameters["beta2"] dis_params = list(self.dis_a.parameters()) + list( self.dis_b.parameters()) if self.gen_state == 0: gen_params = list(self.gen_a.parameters()) + list( 
self.gen_b.parameters()) elif self.gen_state == 1: gen_params = list(self.gen.parameters()) else: print("self.gen_state unknown value:", self.gen_state) self.dis_opt = torch.optim.Adam( [p for p in dis_params if p.requires_grad], lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters["weight_decay"], ) self.gen_opt = torch.optim.Adam( [p for p in gen_params if p.requires_grad], lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters["weight_decay"], ) self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters) self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters) # Network weight initialization self.apply(weights_init(hyperparameters["init"])) self.dis_a.apply(weights_init("gaussian")) self.dis_b.apply(weights_init("gaussian")) # Load VGG model if needed if "vgg_w" in hyperparameters.keys() and hyperparameters["vgg_w"] > 0: self.vgg = load_vgg16(hyperparameters["vgg_model_path"] + "/models") self.vgg.eval() for param in self.vgg.parameters(): param.requires_grad = False # Load semantic segmentation model if needed if "semantic_w" in hyperparameters.keys( ) and hyperparameters["semantic_w"] > 0: self.segmentation_model = load_segmentation_model( hyperparameters["semantic_ckpt_path"]) self.segmentation_model.eval() for param in self.segmentation_model.parameters(): param.requires_grad = False # Load domain classifier if needed if ("domain_adv_w" in hyperparameters.keys() and hyperparameters["domain_adv_w"] > 0): self.domain_classifier = domainClassifier(256) dann_params = list(self.domain_classifier.parameters()) self.dann_opt = torch.optim.Adam( [p for p in dann_params if p.requires_grad], lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters["weight_decay"], ) self.domain_classifier.apply(weights_init("gaussian")) self.dann_scheduler = get_scheduler(self.dann_opt, hyperparameters) def recon_criterion(self, input, target): """ Compute pixelwise L1 loss between two images input and target Arguments: input {torch.Tensor} -- Image tensor 
(original image such as x_a) target {torch.Tensor} -- Image tensor (after cycle-translation image x_aba) Returns: torch.Float -- pixelwise L1 loss """ return torch.mean(torch.abs(input - target)) def recon_criterion_mask(self, input, target, mask): """ Compute a weaker version of the recon_criterion between two images input and target where the L1 is only computed on the unmasked region Arguments: input {torch.Tensor} -- Image (original image such as x_a) target {torch.Tensor} -- Image (after cycle-translation image x_aba) mask {} -- binary Mask of size HxW (input.shape ~ CxHxW) Returns: torch.Float -- L1 loss over input.(1-mask) and target.(1-mask) """ return torch.mean(torch.abs(torch.mul((input - target), 1 - mask))) def forward(self, x_a, x_b): """ Perform the translation from domain A (resp B) to domain B (resp A): x_a to x_ab (resp: x_b to x_ba). Arguments: x_a {torch.Tensor} -- Image from domain A after transform in tensor format x_b {torch.Tensor} -- Image from domain B after transform in tensor format Returns: torch.Tensor, torch.Tensor -- Translated version of x_a in domain B, Translated version of x_b in domain A """ self.eval() s_a = Variable(self.s_a) s_b = Variable(self.s_b) if self.gen_state == 0: c_a, s_a_fake = self.gen_a.encode(x_a) c_b, s_b_fake = self.gen_b.encode(x_b) x_ba = self.gen_a.decode(c_b, s_a) x_ab = self.gen_b.decode(c_a, s_b) elif self.gen_state == 1: c_a, s_a_fake = self.gen.encode(x_a, 1) c_b, s_b_fake = self.gen.encode(x_b, 2) x_ba = self.gen.decode(c_b, s_a, 1) x_ab = self.gen.decode(c_a, s_b, 2) else: print("self.gen_state unknown value:", self.gen_state) self.train() return x_ab, x_ba def gen_update(self, x_a, x_b, hyperparameters, mask_a=None, mask_b=None, comet_exp=None, synth=0): """ Update the generator parameters Arguments: x_a {torch.Tensor} -- Image from domain A after transform in tensor format x_b {torch.Tensor} -- Image from domain B after transform in tensor format hyperparameters {dictionnary} -- dictionnary with 
all hyperparameters Keyword Arguments: mask_a {torch.Tensor} -- binary mask (0,1) corresponding to the ground in x_a (default: {None}) mask_b {torch.Tensor} -- binary mask (0,1) corresponding to the water in x_b (default: {None}) comet_exp {cometExperience} -- CometML object use to log all the loss and images (default: {None}) synth {boolean} -- binary True or False stating if we have a synthetic pair or not Returns: [type] -- [description] """ self.gen_opt.zero_grad() s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()) s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()) if self.gen_state == 0: # encode c_a, s_a_prime = self.gen_a.encode(x_a) c_b, s_b_prime = self.gen_b.encode(x_b) # decode (within domain) x_a_recon = self.gen_a.decode(c_a, s_a_prime) x_b_recon = self.gen_b.decode(c_b, s_b_prime) # decode (cross domain) if self.guided == 0: x_ba = self.gen_a.decode(c_b, s_a) x_ab = self.gen_b.decode(c_a, s_b) elif self.guided == 1: x_ba = self.gen_a.decode(c_b, s_a_prime) x_ab = self.gen_b.decode(c_a, s_b_prime) else: print("self.guided unknown value:", self.guided) # encode again c_b_recon, s_a_recon = self.gen_a.encode(x_ba) c_a_recon, s_b_recon = self.gen_b.encode(x_ab) # decode again (if needed) x_aba = (self.gen_a.decode(c_a_recon, s_a_prime) if hyperparameters["recon_x_cyc_w"] > 0 else None) x_bab = (self.gen_b.decode(c_b_recon, s_b_prime) if hyperparameters["recon_x_cyc_w"] > 0 else None) elif self.gen_state == 1: # encode c_a, s_a_prime = self.gen.encode(x_a, 1) print(c_a.shape) c_b, s_b_prime = self.gen.encode(x_b, 2) # decode (within domain) x_a_recon = self.gen.decode(c_a, s_a_prime, 1) x_b_recon = self.gen.decode(c_b, s_b_prime, 2) # decode (cross domain) if self.guided == 0: x_ba = self.gen.decode(c_b, s_a, 1) x_ab = self.gen.decode(c_a, s_b, 2) elif self.guided == 1: x_ba = self.gen.decode(c_b, s_a_prime, 1) x_ab = self.gen.decode(c_a, s_b_prime, 2) else: print("self.guided unknown value:", self.guided) # encode 
again c_b_recon, s_a_recon = self.gen.encode(x_ba, 1) c_a_recon, s_b_recon = self.gen.encode(x_ab, 2) # decode again (if needed) x_aba = (self.gen.decode(c_a_recon, s_a_prime, 1) if hyperparameters["recon_x_cyc_w"] > 0 else None) x_bab = (self.gen.decode(c_b_recon, s_b_prime, 2) if hyperparameters["recon_x_cyc_w"] > 0 else None) else: print("self.gen_state unknown value:", self.gen_state) # reconstruction loss self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a) self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b) if self.guided == 0: self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a) self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b) elif self.guided == 1: self.loss_gen_recon_s_a = self.recon_criterion( s_a_recon, s_a_prime) self.loss_gen_recon_s_b = self.recon_criterion( s_b_recon, s_b_prime) else: print("self.guided unknown value:", self.guided) self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a) self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b) # Synthetic reconstruction loss if self.check_alignment: print('mask_b.shape', mask_b.shape) # Define the mask of exact same pixel among a pair mask_alignment = (torch.sum(torch.abs(x_a - x_b), 1) == 0).unsqueeze(1) mask_alignment = mask_alignment.type(torch.cuda.FloatTensor) #print('mask_alignment.shape', mask_alignment.shape) self.loss_gen_recon_synth = self.recon_criterion_mask(x_ab, x_b, 1-mask_alignment) + \ self.recon_criterion_mask(x_ba, x_a, 1-mask_alignment) if self.check_alignment else 0 if self.recon_mask: self.loss_gen_cycrecon_x_a = (self.recon_criterion_mask( x_aba, x_a, mask_a) if hyperparameters["recon_x_cyc_w"] > 0 else 0) self.loss_gen_cycrecon_x_b = (self.recon_criterion_mask( x_bab, x_b, mask_b) if hyperparameters["recon_x_cyc_w"] > 0 else 0) else: self.loss_gen_cycrecon_x_a = (self.recon_criterion( x_aba, x_a) if hyperparameters["recon_x_cyc_w"] > 0 else 0) self.loss_gen_cycrecon_x_b = (self.recon_criterion( x_bab, x_b) if 
hyperparameters["recon_x_cyc_w"] > 0 else 0) # GAN loss self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba) self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab) # domain-invariant perceptual loss self.loss_gen_vgg_a = (self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters["vgg_w"] > 0 else 0) self.loss_gen_vgg_b = (self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters["vgg_w"] > 0 else 0) # semantic-segmentation loss self.loss_sem_seg = (self.compute_semantic_seg_loss( x_a.squeeze(), x_ab.squeeze(), mask_a) + self.compute_semantic_seg_loss( x_b.squeeze(), x_ba.squeeze(), mask_b) if hyperparameters["semantic_w"] > 0 else 0) # Domain adversarial loss (c_a and c_b are swapped because we want the feature to be less informative # minmax (accuracy but max min loss) self.domain_adv_loss = (self.compute_domain_adv_loss( c_a, c_b, compute_accuracy=False, minimize=False) if hyperparameters["domain_adv_w"] > 0 else 0) # total loss self.loss_gen_total = ( hyperparameters["gan_w"] * self.loss_gen_adv_a + hyperparameters["gan_w"] * self.loss_gen_adv_b + hyperparameters["recon_x_w"] * self.loss_gen_recon_x_a + hyperparameters["recon_s_w"] * self.loss_gen_recon_s_a + hyperparameters["recon_c_w"] * self.loss_gen_recon_c_a + hyperparameters["recon_x_w"] * self.loss_gen_recon_x_b + hyperparameters["recon_s_w"] * self.loss_gen_recon_s_b + hyperparameters["recon_c_w"] * self.loss_gen_recon_c_b + hyperparameters["recon_x_cyc_w"] * self.loss_gen_cycrecon_x_a + hyperparameters["recon_x_cyc_w"] * self.loss_gen_cycrecon_x_b + hyperparameters["vgg_w"] * self.loss_gen_vgg_a + hyperparameters["vgg_w"] * self.loss_gen_vgg_b + hyperparameters["semantic_w"] * self.loss_sem_seg + hyperparameters["domain_adv_w"] * self.domain_adv_loss + hyperparameters["recon_synth_w"] * self.loss_gen_recon_synth) self.loss_gen_total.backward() self.gen_opt.step() if comet_exp is not None: comet_exp.log_metric("loss_gen_adv_a", self.loss_gen_adv_a.cpu().detach()) 
comet_exp.log_metric("loss_gen_adv_b", self.loss_gen_adv_b.cpu().detach()) comet_exp.log_metric("loss_gen_recon_x_a", self.loss_gen_recon_x_a.cpu().detach()) comet_exp.log_metric("loss_gen_recon_s_a", self.loss_gen_recon_s_a.cpu().detach()) comet_exp.log_metric("loss_gen_recon_c_a", self.loss_gen_recon_c_a.cpu().detach()) comet_exp.log_metric("loss_gen_recon_x_b", self.loss_gen_recon_x_b.cpu().detach()) comet_exp.log_metric("loss_gen_recon_s_b", self.loss_gen_recon_s_b.cpu().detach()) comet_exp.log_metric("loss_gen_recon_c_b", self.loss_gen_recon_c_b.cpu().detach()) comet_exp.log_metric("loss_gen_cycrecon_x_a", self.loss_gen_cycrecon_x_a.cpu().detach()) comet_exp.log_metric("loss_gen_cycrecon_x_b", self.loss_gen_cycrecon_x_b.cpu().detach()) comet_exp.log_metric("loss_gen_total", self.loss_gen_total.cpu().detach()) if hyperparameters["vgg_w"] > 0: comet_exp.log_metric("loss_gen_vgg_a", self.loss_gen_vgg_a.cpu().detach()) comet_exp.log_metric("loss_gen_vgg_b", self.loss_gen_vgg_b.cpu().detach()) if hyperparameters["semantic_w"] > 0: comet_exp.log_metric("loss_sem_seg", self.loss_sem_seg.cpu().detach()) if hyperparameters["domain_adv_w"] > 0: comet_exp.log_metric("domain_adv_loss_gen", self.domain_adv_loss.cpu().detach()) if synth == 0: comet_exp.log_metric("loss_gen_recon_synth", self.loss_gen_recon_synth.cpu().detach()) def compute_vgg_loss(self, vgg, img, target): """ Compute the domain-invariant perceptual loss Arguments: vgg {model} -- popular Convolutional Network for Classification and Detection img {torch.Tensor} -- image before translation target {torch.Tensor} -- image after translation Returns: torch.Float -- domain invariant perceptual loss """ img_vgg = vgg_preprocess(img) target_vgg = vgg_preprocess(target) img_fea = vgg(img_vgg) target_fea = vgg(target_vgg) return torch.mean( (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2) def compute_domain_adv_loss(self, c_a, c_b, compute_accuracy=False, minimize=True): """ Compute a domain 
adversarial loss on the embedding of the classifier: we are trying to learn an anonymized representation of the content. Arguments: c_a {torch.tensor} -- content extracted from an image of domain A with encoder A c_b {torch.tensor} -- content extracted from an image of domain B with encoder B Keyword Arguments: compute_accuracy {bool} -- either return only the loss or loss and softmax probs (default: {False}) minimize {bool} -- optimize classification accuracy(True) or anonymized the representation(False) Returns: torch.Float -- loss (optionnal softmax P(classifier(c_a)=a) and P(classifier(c_b)=b)) """ # Infer domain classifier on content extracted from an image of domainA output_a = self.domain_classifier(c_a) # Infer domain classifier on content extracted from an image of domainB output_b = self.domain_classifier(c_b) # Concatenate the output in a single vector output = torch.cat((output_a, output_b)) if minimize: target = torch.tensor([1., 0., 0., 1.], device='cuda') else: target = torch.tensor([0.5, 0.5, 0.5, 0.5], device='cuda') # mean square error loss loss = torch.nn.MSELoss()(output, target) if compute_accuracy: return loss, output_a[0], output_b[1] else: return loss def compute_semantic_seg_loss(self, img1, img2, mask=None): """ Compute semantic segmentation loss between two images on the unmasked region or in the entire image Arguments: img1 {torch.Tensor} -- Image from domain A after transform in tensor format img2 {torch.Tensor} -- Image transformed mask {torch.Tensor} -- Binary mask where we force the loss to be zero Returns: torch.float -- Cross entropy loss on the unmasked region """ # denorm img1_denorm = (img1 + 1) / 2.0 img2_denorm = (img2 + 1) / 2.0 # norm for semantic seg network input_transformed1 = seg_batch_transform(img1_denorm) input_transformed2 = seg_batch_transform(img2_denorm) # compute labels from original image and logits from translated version target = (self.segmentation_model(input_transformed1).max(1)[1]) output = 
self.segmentation_model(input_transformed2) if not self.full_adaptation and mask is not None: # Resize mask to the size of the image mask1 = torch.nn.functional.interpolate(mask, size=(self.newsize, self.newsize)) mask1_tensor = torch.tensor(mask1, dtype=torch.long).cuda() mask1_tensor = mask1_tensor.squeeze(1) # we want the masked region to be labeled as unknown (19 is not an existing label) target_with_mask = torch.mul(1 - mask1_tensor, target) + mask1_tensor * 19 mask2 = torch.nn.functional.interpolate(mask, size=(self.newsize, self.newsize)) mask_tensor = torch.tensor(mask2, dtype=torch.float).cuda() output_with_mask = torch.mul(1 - mask_tensor, output) # cat the mask as to the logits (loss=0 over the masked region) output_with_mask_cat = torch.cat((output_with_mask, mask_tensor), dim=1) loss = nn.CrossEntropyLoss()(output_with_mask_cat, target_with_mask) else: loss = nn.CrossEntropyLoss()(output, target) return loss def sample(self, x_a, x_b): """ Infer the model on a batch of image Arguments: x_a {torch.Tensor} -- batch of image from domain A x_b {[type]} -- batch of image from domain B Returns: A list of torch images -- columnwise :x_a, autoencode(x_a), x_ab_1, x_ab_2 Or if self.semantic_w is true: x_a, autoencode(x_a), Semantic segmentation x_a, x_ab_1,semantic segmentation x_ab_1, x_ab_2 """ self.eval() s_a1 = Variable(self.s_a) s_b1 = Variable(self.s_b) s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()) s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()) x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], [] if self.gen_state == 0: for i in range(x_a.size(0)): c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0)) c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0)) x_a_recon.append(self.gen_a.decode(c_a, s_a_fake)) x_b_recon.append(self.gen_b.decode(c_b, s_b_fake)) if self.guided == 0: x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0))) x_ba2.append(self.gen_a.decode(c_b, 
s_a2[i].unsqueeze(0))) x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0))) x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0))) elif self.guided == 1: x_ba1.append(self.gen_a.decode( c_b, s_a_fake)) # s_a1[i].unsqueeze(0))) x_ba2.append(self.gen_a.decode( c_b, s_a_fake)) # s_a2[i].unsqueeze(0))) x_ab1.append(self.gen_b.decode( c_a, s_b_fake)) # s_b1[i].unsqueeze(0))) x_ab2.append(self.gen_b.decode( c_a, s_b_fake)) # s_b2[i].unsqueeze(0))) else: print("self.guided unknown value:", self.guided) elif self.gen_state == 1: for i in range(x_a.size(0)): c_a, s_a_fake = self.gen.encode(x_a[i].unsqueeze(0), 1) c_b, s_b_fake = self.gen.encode(x_b[i].unsqueeze(0), 2) x_a_recon.append(self.gen.decode(c_a, s_a_fake, 1)) x_b_recon.append(self.gen.decode(c_b, s_b_fake, 2)) if self.guided == 0: x_ba1.append(self.gen.decode(c_b, s_a1[i].unsqueeze(0), 1)) x_ba2.append(self.gen.decode(c_b, s_a2[i].unsqueeze(0), 1)) x_ab1.append(self.gen.decode(c_a, s_b1[i].unsqueeze(0), 2)) x_ab2.append(self.gen.decode(c_a, s_b2[i].unsqueeze(0), 2)) elif self.guided == 1: x_ba1.append(self.gen.decode(c_b, s_a_fake, 1)) # s_a1[i].unsqueeze(0))) x_ba2.append(self.gen.decode(c_b, s_a_fake, 1)) # s_a2[i].unsqueeze(0))) x_ab1.append(self.gen.decode(c_a, s_b_fake, 2)) # s_b1[i].unsqueeze(0))) x_ab2.append(self.gen.decode(c_a, s_b_fake, 2)) # s_b2[i].unsqueeze(0))) else: print("self.guided unknown value:", self.guided) else: print("self.gen_state unknown value:", self.gen_state) x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon) x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2) x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2) if self.semantic_w: rgb_a_list, rgb_b_list, rgb_ab_list, rgb_ba_list = [], [], [], [] for i in range(x_a.size(0)): # Inference semantic segmentation on original images im_a = (x_a[i].squeeze() + 1) / 2.0 im_b = (x_b[i].squeeze() + 1) / 2.0 input_transformed_a = seg_transform()(im_a).unsqueeze(0) input_transformed_b = seg_transform()(im_b).unsqueeze(0) 
output_a = (self.segmentation_model( input_transformed_a).squeeze().max(0)[1]) output_b = (self.segmentation_model( input_transformed_b).squeeze().max(0)[1]) rgb_a = decode_segmap(output_a.cpu().numpy()) rgb_b = decode_segmap(output_b.cpu().numpy()) rgb_a = Image.fromarray(rgb_a).resize( (x_a.size(3), x_a.size(3))) rgb_b = Image.fromarray(rgb_b).resize( (x_a.size(3), x_a.size(3))) rgb_a_list.append(transforms.ToTensor()(rgb_a).unsqueeze(0)) rgb_b_list.append(transforms.ToTensor()(rgb_b).unsqueeze(0)) # Inference semantic segmentation on fake images image_ab = (x_ab1[i].squeeze() + 1) / 2.0 image_ba = (x_ba1[i].squeeze() + 1) / 2.0 input_transformed_ab = seg_transform()(image_ab).unsqueeze( 0).to("cuda") input_transformed_ba = seg_transform()(image_ba).unsqueeze( 0).to("cuda") output_ab = (self.segmentation_model( input_transformed_ab).squeeze().max(0)[1]) output_ba = (self.segmentation_model( input_transformed_ba).squeeze().max(0)[1]) rgb_ab = decode_segmap(output_ab.cpu().numpy()) rgb_ba = decode_segmap(output_ba.cpu().numpy()) rgb_ab = Image.fromarray(rgb_ab).resize( (x_a.size(3), x_a.size(3))) rgb_ba = Image.fromarray(rgb_ba).resize( (x_a.size(3), x_a.size(3))) rgb_ab_list.append(transforms.ToTensor()(rgb_ab).unsqueeze(0)) rgb_ba_list.append(transforms.ToTensor()(rgb_ba).unsqueeze(0)) rgb1_a, rgb1_b, rgb1_ab, rgb1_ba = ( torch.cat(rgb_a_list).cuda(), torch.cat(rgb_b_list).cuda(), torch.cat(rgb_ab_list).cuda(), torch.cat(rgb_ba_list).cuda(), ) self.train() if self.semantic_w: self.segmentation_model.eval() return ( x_a, x_a_recon, rgb1_a, x_ab1, rgb1_ab, x_ab2, x_b, x_b_recon, rgb1_b, x_ba1, rgb1_ba, x_ba2, ) else: return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2 def sample_fid(self, x_a, x_b): """ Infer the model on a batch of image Arguments: x_a {torch.Tensor} -- batch of image from domain A x_b {[type]} -- batch of image from domain B Returns: A list of torch images -- columnwise :x_a, autoencode(x_a), x_ab_1, x_ab_2 Or if self.semantic_w is 
true: x_a, autoencode(x_a), Semantic segmentation x_a, x_ab_1,semantic segmentation x_ab_1, x_ab_2 """ self.eval() x_ab1 = [] if self.gen_state == 0: for i in range(x_a.size(0)): c_a, _ = self.gen_a.encode(x_a[i].unsqueeze(0)) _, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0)) if self.guided == 1: x_ab1.append(self.gen_b.decode(c_a, s_b_fake)) else: print("self.guided unknown value:", self.guided) elif self.gen_state == 1: for i in range(x_a.size(0)): c_a, _ = self.gen.encode(x_a[i].unsqueeze(0), 1) _, s_b_fake = self.gen.encode(x_b[i].unsqueeze(0), 2) if self.guided == 1: x_ab1.append(self.gen.decode(c_a, s_b_fake, 2)) else: print("self.guided unknown value:", self.guided) else: print("self.gen_state unknown value:", self.gen_state) x_ab1 = torch.cat(x_ab1) self.train() if self.semantic_w: self.segmentation_model.eval() return x_ab1 def dis_update(self, x_a, x_b, hyperparameters, comet_exp=None): """ Update the weights of the discriminator Arguments: x_a {torch.Tensor} -- Image from domain A after transform in tensor format x_b {torch.Tensor} -- Image from domain B after transform in tensor format hyperparameters {dictionnary} -- dictionnary with all hyperparameters Keyword Arguments: comet_exp {cometExperience} -- CometML object use to log all the loss and images (default: {None}) """ self.dis_opt.zero_grad() s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()) s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()) if self.gen_state == 0: # encode c_a, s_a_prime = self.gen_a.encode(x_a) c_b, s_b_prime = self.gen_b.encode(x_b) # decode (cross domain) if self.guided == 0: x_ba = self.gen_a.decode(c_b, s_a) x_ab = self.gen_b.decode(c_a, s_b) elif self.guided == 1: x_ba = self.gen_a.decode(c_b, s_a_prime) x_ab = self.gen_b.decode(c_a, s_b_prime) else: print("self.guided unknown value:", self.guided) elif self.gen_state == 1: # encode c_a, s_a_prime = self.gen.encode(x_a, 1) c_b, s_b_prime = self.gen.encode(x_b, 2) # decode (cross 
domain) if self.guided == 0: x_ba = self.gen.decode(c_b, s_a, 1) x_ab = self.gen.decode(c_a, s_b, 2) elif self.guided == 1: x_ba = self.gen.decode(c_b, s_a_prime, 1) x_ab = self.gen.decode(c_a, s_b_prime, 2) else: print("self.guided unknown value:", self.guided) else: print("self.gen_state unknown value:", self.gen_state) # D loss self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a) self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b) self.loss_dis_total = (hyperparameters["gan_w"] * self.loss_dis_a + hyperparameters["gan_w"] * self.loss_dis_b) self.loss_dis_total.backward() self.dis_opt.step() if comet_exp is not None: comet_exp.log_metric("loss_dis_b", self.loss_dis_b.cpu().detach()) comet_exp.log_metric("loss_dis_a", self.loss_dis_a.cpu().detach()) def domain_classifier_update(self, x_a, x_b, hyperparameters, comet_exp=None): """ Update the weights of the domain classifier Arguments: x_a {torch.Tensor} -- Image from domain A after transform in tensor format x_b {torch.Tensor} -- Image from domain B after transform in tensor format hyperparameters {dictionnary} -- dictionnary with all hyperparameters Keyword Arguments: comet_exp {cometExperience} -- CometML object use to log all the loss and images (default: {None}) """ self.dann_opt.zero_grad() s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()) s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()) if self.gen_state == 0: # encode c_a, _ = self.gen_a.encode(x_a) c_b, _ = self.gen_b.encode(x_b) elif self.gen_state == 1: # encode c_a, _ = self.gen.encode(x_a, 1) c_b, _ = self.gen.encode(x_b, 2) else: print("self.gen_state unknown value:", self.gen_state) # domain classifier loss self.domain_class_loss, out_a, out_b = self.compute_domain_adv_loss( c_a, c_b, compute_accuracy=True, minimize=True) self.domain_class_loss.backward() self.dann_opt.step() if comet_exp is not None: comet_exp.log_metric("domain_class_loss", self.domain_class_loss.cpu().detach()) 
comet_exp.log_metric("probability A being identified as A", out_a.cpu().detach()) comet_exp.log_metric("probability B being identified as B", out_b.cpu().detach()) def update_learning_rate(self): """ Update the learning rate """ if self.dis_scheduler is not None: self.dis_scheduler.step() if self.gen_scheduler is not None: self.gen_scheduler.step() if self.dann_scheduler is not None: self.dann_scheduler.step() def resume(self, checkpoint_dir, hyperparameters): """ Resume the training loading the network parameters Arguments: checkpoint_dir {string} -- path to the directory where the checkpoints are saved hyperparameters {dictionnary} -- dictionnary with all hyperparameters Returns: int -- number of iterations (used by the optimizer) """ # Load generators last_model_name = get_model_list(checkpoint_dir, "gen") state_dict = torch.load(last_model_name) if self.gen_state == 0: self.gen_a.load_state_dict(state_dict["a"]) self.gen_b.load_state_dict(state_dict["b"]) elif self.gen_state == 1: self.gen.load_state_dict(state_dict["2"]) else: print("self.gen_state unknown value:", self.gen_state) # Load domain classifier if self.domain_classif == 1: last_model_name = get_model_list(checkpoint_dir, "domain_classif") state_dict = torch.load(last_model_name) self.domain_classifier.load_state_dict(state_dict["d"]) iterations = int(last_model_name[-11:-3]) # Load discriminators last_model_name = get_model_list(checkpoint_dir, "dis") state_dict = torch.load(last_model_name) self.dis_a.load_state_dict(state_dict["a"]) self.dis_b.load_state_dict(state_dict["b"]) # Load optimizers state_dict = torch.load(os.path.join(checkpoint_dir, "optimizer.pt")) self.dis_opt.load_state_dict(state_dict["dis"]) self.gen_opt.load_state_dict(state_dict["gen"]) if self.domain_classif == 1: self.dann_opt.load_state_dict(state_dict["dann"]) self.dann_scheduler = get_scheduler(self.dann_opt, hyperparameters, iterations) # Reinitilize schedulers self.dis_scheduler = get_scheduler(self.dis_opt, 
hyperparameters, iterations) self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations) print("Resume from iteration %d" % iterations) return iterations def save(self, snapshot_dir, iterations): """ Save generators, discriminators, and optimizers Arguments: snapshot_dir {string} -- directory path where to save the networks weights iterations {int} -- number of training iterations """ # Save generators, discriminators, and optimizers gen_name = os.path.join(snapshot_dir, "gen_%08d.pt" % (iterations + 1)) dis_name = os.path.join(snapshot_dir, "dis_%08d.pt" % (iterations + 1)) domain_classifier_name = os.path.join( snapshot_dir, "domain_classifier_%08d.pt" % (iterations + 1)) opt_name = os.path.join(snapshot_dir, "optimizer.pt") if self.gen_state == 0: torch.save( { "a": self.gen_a.state_dict(), "b": self.gen_b.state_dict() }, gen_name) elif self.gen_state == 1: torch.save({"2": self.gen.state_dict()}, gen_name) else: print("self.gen_state unknown value:", self.gen_state) torch.save({ "a": self.dis_a.state_dict(), "b": self.dis_b.state_dict() }, dis_name) if self.domain_classif: torch.save({"d": self.domain_classifier.state_dict()}, domain_classifier_name) torch.save( { "gen": self.gen_opt.state_dict(), "dis": self.dis_opt.state_dict(), "dann": self.dann_opt.state_dict(), }, opt_name, ) else: torch.save( { "gen": self.gen_opt.state_dict(), "dis": self.dis_opt.state_dict() }, opt_name, )
class MUNIT_Trainer(nn.Module):
    """MUNIT variant whose cross-domain style is warped from an exemplar image.

    Instead of sampling a random style code, the style map of the reference
    image is re-arranged to the spatial layout of the content image using a
    soft correspondence computed between the two content maps (see
    ``warp_style``). Besides the per-domain discriminators ``dis_a``/``dis_b``
    it owns paired discriminators ``dis_sa``/``dis_sb`` whose inputs are two
    images concatenated along the channel axis; their losses are currently
    not part of either training objective.
    """

    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b
        # Paired discriminators: input is two images stacked on the channel axis.
        self.dis_sa = MsImageDis(
            hyperparameters['input_dim_a'] * 2,
            hyperparameters['dis'])  # paired discriminator for domain a
        self.dis_sb = MsImageDis(
            hyperparameters['input_dim_b'] * 2,
            hyperparameters['dis'])  # paired discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        dis_style_params = list(self.dis_sa.parameters()) + list(
            self.dis_sb.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_style_opt = torch.optim.Adam(
            [p for p in dis_style_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.dis_style_scheduler = get_scheduler(self.dis_style_opt,
                                                 hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_sa.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))
        self.dis_sb.apply(weights_init('gaussian'))

        # Runs after the generic init above so it is not overwritten.
        if hyperparameters['gen']['CE_method'] == 'vgg':
            self.gen_a.content_init()
            self.gen_b.content_init()

        self.criterion = nn.L1Loss().cuda()
        self.triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2).cuda()
        self.kld = nn.KLDivLoss()
        self.contextual_loss = ContextualLoss()

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        """Mean absolute error between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    def kl_loss(self, input, target):
        """Mean of ``nn.KLDivLoss`` applied to ``input`` and ``target``."""
        return torch.mean(self.kld(input, target))

    def normalize_feat(self, feat):
        """L2-normalise a (bs, c, H, W) map over channels.

        Returns the result flattened to (bs, c, H*W). A machine epsilon is added
        to the norm to avoid division by zero.
        """
        bs, c, H, W = feat.shape
        feat = feat.view(bs, c, -1)
        feat_norm = torch.norm(feat, 2, 1,
                               keepdim=True) + sys.float_info.epsilon
        return torch.div(feat, feat_norm)

    def norm_two_domain(self, feat_c, feat_s):
        """L2-normalise two (bs, c_i, H, W) maps jointly over their concatenated channels.

        Returns ``(feat_c, feat_s)``, each flattened to (bs, c_i, H*W). The
        split point is ``feat_c``'s channel count.
        """
        n_c = feat_c.shape[1]
        feat = torch.cat((feat_c, feat_s), 1)
        bs, c, H, W = feat.shape
        feat_norm = torch.norm(feat, 2, 1, keepdim=True)
        feat = torch.div(feat, feat_norm)
        feat_c = feat[:, :n_c, :, :].reshape(bs, n_c, -1)
        feat_s = feat[:, n_c:, :, :].reshape(bs, c - n_c, -1)
        return feat_c, feat_s

    def generate_map(self, corr_index, h, w):
        """Convert flat correspondence indices into an (h, w, 2) coordinate map.

        Args:
            corr_index: 1-D tensor or sequence of length ``h * w``. Entry k is
                the row-major flat index of the reference position matched to
                position k (as returned by ``warp_style``).
            h, w: spatial size of the map. Row and column are decoded with
                width ``w``.

        Returns:
            numpy integer array of shape (h, w, 2). ``[..., 0]`` is the row and
            ``[..., 1]`` is the column.
        """
        if torch.is_tensor(corr_index):
            # One host copy instead of converting each 0-d (possibly CUDA) tensor.
            corr_index = corr_index.detach().cpu().numpy()
        flat = np.asarray(corr_index).reshape(-1)
        # Row-major flattening: row = idx // width, col = idx % width.
        rows = flat // w
        cols = flat % w
        return np.stack((rows, cols), axis=-1).reshape(h, w, 2)

    def warp_img(self, corr_map, ref_img):
        """Rearrange blocks of ``ref_img`` according to an (h, w, 2) correspondence map.

        Each output cell (i, j) of size (H_img // h, W_img // w) is copied from
        the reference cell at ``corr_map[i][j]``. Pixels outside the covered
        area stay zero. The result lives on ``ref_img``'s device.
        """
        bs, c, h_img, w_img = ref_img.shape
        h, w, _ = corr_map.shape
        scale_h = h_img // h
        scale_w = w_img // w
        warped_img = torch.zeros_like(ref_img)
        for i in range(h):
            for j in range(w):
                nnx = int(corr_map[i][j][0])
                nny = int(corr_map[i][j][1])
                warped_img[:, :, i * scale_h:(i + 1) * scale_h,
                           j * scale_w:(j + 1) * scale_w] = \
                    ref_img[:, :, nnx * scale_h:(nnx + 1) * scale_h,
                            nny * scale_w:(nny + 1) * scale_w]
        return warped_img

    def warp_style(self, cur_content, ref_content, ref_style,
                   temperature=0.005):
        """Warp ``ref_style`` to the spatial layout of ``cur_content``.

        Similarity is the dot product of channel-normalised content features.
        A softmax with ``temperature`` over reference positions gives the soft
        correspondence used to mix reference style vectors.

        Returns:
            ``(corr_index, warped_style)``. ``corr_index`` is the argmax
            reference position per current position (batch dim squeezed).
            ``warped_style`` has ``ref_style``'s shape (bs, c, H, W).
        """
        # normalize feature
        cur_content = self.normalize_feat(cur_content)
        ref_content = self.normalize_feat(ref_content)
        cur_content = cur_content.permute(0, 2, 1)
        # calculate similarity: bs x (H x W) x (H x W)
        f = torch.matmul(cur_content, ref_content)
        f_corr = F.softmax(f / temperature, dim=-1)
        # hard correspondence, used for visualisation
        corr_index = torch.argmax(f_corr, dim=-1).squeeze(0)
        # collect ref style as bs x (H x W) x c
        bs, c, H, W = ref_style.shape
        ref_style = ref_style.view(bs, c, -1).permute(0, 2, 1)
        # warp ref style
        warped_style = torch.matmul(f_corr, ref_style)
        warped_style = warped_style.permute(0, 2, 1).contiguous()
        warped_style = warped_style.view(bs, c, H, W)
        return corr_index, warped_style

    def forward(self, x_a, x_b):
        """Translate a->b and b->a with exemplar-warped styles (in eval mode).

        Returns ``(x_ab, x_ba)`` and switches the module back to train mode.
        """
        self.eval()
        c_a, s_a_fake = self.gen_a.encode(x_a)
        c_b, s_b_fake = self.gen_b.encode(x_b)
        # warp the ref_style to the content layout
        _, s_ab_warp = self.warp_style(c_a, c_b, s_b_fake)
        _, s_ba_warp = self.warp_style(c_b, c_a, s_a_fake)
        x_ba = self.gen_a.decode(s_ba_warp, c_b)
        x_ab = self.gen_b.decode(s_ab_warp, c_a)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, x_adf, x_bdf, hyperparameters):
        """Run one generator optimisation step.

        ``x_adf``/``x_bdf`` are accepted for interface compatibility and are
        not used here. Individual losses are stored as ``self.loss_gen_*``.
        """
        self.gen_opt.zero_grad()
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # warp each style to the other image's content layout
        _, s_ab = self.warp_style(c_a, c_b, s_b_prime)
        _, s_ba = self.warp_style(c_b, c_a, s_a_prime)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(s_a_prime, c_a)
        x_b_recon = self.gen_b.decode(s_b_prime, c_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(s_ba, c_b)
        x_ab = self.gen_b.decode(s_ab, c_a)
        # encode again; s_ba_recon is laid out like x_ba, i.e. like B's content
        c_b_recon, s_ba_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_ab_recon = self.gen_b.encode(x_ab)
        # warp the re-encoded styles back to the original layouts
        _, s_aba = self.warp_style(c_a, c_b, s_ba_recon)
        _, s_bab = self.warp_style(c_b, c_a, s_ab_recon)
        # decode again (if needed)
        recon_x_cyc_w = hyperparameters['recon_x_cyc_w']
        x_aba = self.gen_a.decode(
            s_a_prime, c_a_recon) if recon_x_cyc_w > 0 else None
        x_bab = self.gen_b.decode(
            s_b_prime, c_b_recon) if recon_x_cyc_w > 0 else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_aba, s_a_prime)
        self.loss_gen_recon_s_b = self.recon_criterion(s_bab, s_b_prime)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        # contextual losses: stored on self but not added to loss_gen_total
        self.loss_gen_cx_a = self.contextual_loss(s_ba, s_a_prime)
        self.loss_gen_cx_b = self.contextual_loss(s_ab, s_b_prime)
        self.loss_gen_cycrecon_x_a = self.criterion(
            x_aba, x_a) if recon_x_cyc_w > 0 else 0
        self.loss_gen_cycrecon_x_b = self.criterion(
            x_bab, x_b) if recon_x_cyc_w > 0 else 0
        # GAN loss
        self.loss_gen_adv_xa = self.gen_a.calc_gen_loss(
            self.dis_a.forward(x_ba))
        self.loss_gen_adv_xb = self.gen_b.calc_gen_loss(
            self.dis_b.forward(x_ab))
        # domain-invariant perceptual loss; vgg_w is optional (as in __init__)
        vgg_w = hyperparameters.get('vgg_w', 0)
        self.loss_gen_vgg_a = self.compute_vgg_loss_new(
            self.vgg, x_ba, x_b) if vgg_w > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss_new(
            self.vgg, x_ab, x_a) if vgg_w > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_xa + \
            hyperparameters['gan_w'] * self.loss_gen_adv_xb + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
            recon_x_cyc_w * self.loss_gen_cycrecon_x_a + \
            recon_x_cyc_w * self.loss_gen_cycrecon_x_b + \
            vgg_w * self.loss_gen_vgg_a + \
            vgg_w * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        """Mean squared distance between instance-normalised VGG features."""
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def compute_vgg_loss_new(self, vgg, img, target):
        """Mean absolute distance between raw VGG features."""
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_feat = vgg(img_vgg)
        target_feat = vgg(target_vgg)
        return self.recon_criterion(img_feat, target_feat)

    def sample(self, x_a, x_b, x_adf, x_bdf):
        """Produce visualisation outputs for each image pair (in eval mode).

        Returns ``(x_a, x_adf, x_a_recon, x_ab1, x_ab2, x_b, x_bdf, x_b_recon,
        x_ba1, x_ba2)``. ``*1`` are decoder translations with warped styles.
        ``*2`` are the reference images block-warped by the hard
        correspondence.
        """
        self.eval()
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            img_a = x_a[i].unsqueeze(0)
            img_b = x_b[i].unsqueeze(0)
            c_a, s_a_fake = self.gen_a.encode(img_a)
            c_b, s_b_fake = self.gen_b.encode(img_b)
            # correspondence map size follows the content features
            # (assumes both domains' content maps share this size)
            h, w = int(c_a.shape[2]), int(c_a.shape[3])
            # reconstruction
            x_a_recon.append(self.gen_a.decode(s_a_fake, c_a))
            x_b_recon.append(self.gen_b.decode(s_b_fake, c_b))
            # warp style
            corr_index_ab, s_ab = self.warp_style(c_a, c_b, s_b_fake)
            corr_index_ba, s_ba = self.warp_style(c_b, c_a, s_a_fake)
            # cross domain construction, plus warped reference images
            x_ba1.append(self.gen_a.decode(s_ba, c_b))
            corr_map_ba = self.generate_map(corr_index_ba, h, w)
            x_ba2.append(self.warp_img(corr_map_ba, img_a))
            x_ab1.append(self.gen_b.decode(s_ab, c_a))
            corr_map_ab = self.generate_map(corr_index_ab, h, w)
            x_ab2.append(self.warp_img(corr_map_ab, img_b))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_adf, x_a_recon, x_ab1, x_ab2, x_b, x_bdf, x_b_recon, x_ba1, x_ba2

    def dis_update(self, x_a, x_b, x_adf, x_bdf, hyperparameters):
        """Run one discriminator optimisation step.

        Only ``dis_a``/``dis_b`` contribute to ``loss_dis_total``. The paired
        discriminators' image pools are still fed with the real images each
        step.
        """
        self.dis_opt.zero_grad()
        self.dis_style_opt.zero_grad()
        # encode
        c_a, s_a = self.gen_a.encode(x_a)
        c_b, s_b = self.gen_b.encode(x_b)
        # warp the style here
        _, s_ab_warp = self.warp_style(c_a, c_b, s_b)
        _, s_ba_warp = self.warp_style(c_b, c_a, s_a)
        # decode (cross domain)
        x_ba = self.gen_a.decode(s_ba_warp, c_b)
        x_ab = self.gen_b.decode(s_ab_warp, c_a)
        # Paired-discriminator inputs (real/fake pairs from the pool, real/real
        # pairs from the extra images). They are not used by the active loss,
        # but the pool fetch/push sequence is kept as-is.
        if len(self.dis_sa.pool_) == 0:
            pair_a_rfake = torch.cat((x_b, x_a), 1)
        else:
            pair_a_rfake = torch.cat((self.dis_sa.pool('fetch'), x_a), 1)
        self.dis_sa.pool('push', x_a)
        if len(self.dis_sb.pool_) == 0:
            pair_b_rfake = torch.cat((x_a, x_b), 1)
        else:
            pair_b_rfake = torch.cat((self.dis_sb.pool('fetch'), x_b), 1)
        self.dis_sb.pool('push', x_b)
        pair_a_rreal = torch.cat((x_a, x_adf), 1)
        pair_b_rreal = torch.cat((x_b, x_bdf), 1)
        # D loss
        self.loss_dis_xa = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_xb = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_xa + \
            hyperparameters['gan_w'] * self.loss_dis_xb
        self.loss_dis_total.backward()
        self.dis_opt.step()
        self.dis_style_opt.step()

    def update_learning_rate(self):
        """Advance every learning-rate scheduler that exists."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.dis_style_scheduler is not None:
            self.dis_style_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Load the latest networks and optimizers from ``checkpoint_dir``.

        The paired discriminators and their optimizer are restored only when
        the checkpoint contains them; older checkpoints lack those keys.

        Returns:
            int: the iteration number parsed from the generator file name.
        """
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        if 'sa' in state_dict and 'sb' in state_dict:
            self.dis_sa.load_state_dict(state_dict['sa'])
            self.dis_sb.load_state_dict(state_dict['sb'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        if 'dis_style' in state_dict:
            self.dis_style_opt.load_state_dict(state_dict['dis_style'])
        # Reinitilize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.dis_style_scheduler = get_scheduler(self.dis_style_opt,
                                                 hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Save generators, all discriminators and optimizers to ``snapshot_dir``.

        File names embed ``iterations + 1``. The 'a'/'b' keys keep their
        previous meaning. Paired discriminators are stored under 'sa'/'sb' and
        their optimizer under 'dis_style'.
        """
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save(
            {
                'a': self.dis_a.state_dict(),
                'b': self.dis_b.state_dict(),
                'sa': self.dis_sa.state_dict(),
                'sb': self.dis_sb.state_dict(),
            }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict(),
                'dis_style': self.dis_style_opt.state_dict(),
            }, opt_name)
class MUNIT_Trainer(nn.Module):
    """Baseline MUNIT trainer: two AdaIN auto-encoders and two multi-scale discriminators.

    Cross-domain translations decode one domain's content code with a style
    code drawn from a standard normal. Snapshots can optionally be thinned
    after saving (see ``snap_clean``).
    """

    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        """Mean absolute error between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        """Translate with the fixed display style codes (in eval mode).

        The batch size must not exceed ``display_size``. Returns
        ``(x_ab, x_ba)`` and switches the module back to train mode.
        """
        self.eval()
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, self.s_a)
        x_ab = self.gen_b.decode(c_a, self.s_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        """Run one generator optimisation step; losses are stored as ``self.loss_gen_*``."""
        self.gen_opt.zero_grad()
        s_a = torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()
        s_b = torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        recon_x_cyc_w = hyperparameters['recon_x_cyc_w']
        x_aba = self.gen_a.decode(
            c_a_recon, s_a_prime) if recon_x_cyc_w > 0 else None
        x_bab = self.gen_b.decode(
            c_b_recon, s_b_prime) if recon_x_cyc_w > 0 else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if recon_x_cyc_w > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if recon_x_cyc_w > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss; vgg_w is optional (as in __init__)
        vgg_w = hyperparameters.get('vgg_w', 0)
        self.loss_gen_vgg_a = self.compute_vgg_loss(
            self.vgg, x_ba, x_b) if vgg_w > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(
            self.vgg, x_ab, x_a) if vgg_w > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
            hyperparameters['gan_w'] * self.loss_gen_adv_b + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
            recon_x_cyc_w * self.loss_gen_cycrecon_x_a + \
            recon_x_cyc_w * self.loss_gen_cycrecon_x_b + \
            vgg_w * self.loss_gen_vgg_a + \
            vgg_w * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        """Mean squared distance between instance-normalised VGG features."""
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def sample(self, x_a, x_b):
        """Produce visualisation outputs per image (in eval mode).

        Returns ``(x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2)``.
        ``*1`` use the fixed display styles; ``*2`` use freshly sampled ones.
        """
        self.eval()
        s_a1 = self.s_a
        s_b1 = self.s_b
        s_a2 = torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()
        s_b2 = torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
            c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
            x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
            x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
            x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
            x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
            x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_update(self, x_a, x_b, hyperparameters):
        """Run one discriminator optimisation step on detached translations."""
        self.dis_opt.zero_grad()
        s_a = torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()
        s_b = torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()
        # encode
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + \
            hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Advance every learning-rate scheduler that exists."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Load the latest networks and optimizers from ``checkpoint_dir``.

        Returns:
            int: the iteration number parsed from the generator file name.
        """
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitilize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    @staticmethod
    def _stale_snapshots(models, iterations, save_last, period):
        """Return the paths in sorted ``models`` that are safe to delete.

        The oldest snapshot is always kept, as is everything newer than
        ``iterations - save_last``. Among the rest, one snapshot is kept each
        time the iteration count crosses into a new ``period`` bucket; the
        others are returned.
        """
        stale = []
        m_prev = 0
        for i, model in enumerate(models):
            if i == 0:
                continue
            # File names end in '_%08d.pt'.
            m_iter = int(model[-11:-3])
            if m_iter > iterations - save_last:
                break
            if m_iter - m_prev < period:
                stale.append(model)
            while m_iter - m_prev >= period:
                m_prev += period
        return stale

    def snap_clean(self, snap_dir, iterations, save_last=10000, period=20000):
        """Delete older generator/discriminator snapshots from ``snap_dir``.

        Does nothing if the directory does not exist. Raises ValueError when
        ``period`` is not positive, because the bucket search could not
        advance.
        """
        if not os.path.exists(snap_dir):
            return None
        if period <= 0:
            raise ValueError('period must be positive, got %r' % (period,))
        files = os.listdir(snap_dir)

        def snapshots(prefix):
            # Only 'gen_XXXXXXXX.pt' / 'dis_XXXXXXXX.pt' as written by save().
            return sorted(
                os.path.join(snap_dir, f) for f in files
                if f.startswith(prefix) and f.endswith('.pt'))

        marked_clean = []
        for prefix in ('gen_', 'dis_'):
            marked_clean.extend(
                self._stale_snapshots(snapshots(prefix), iterations,
                                      save_last, period))
        print(f'Cleaning snapshots: {marked_clean}')
        for f in marked_clean:
            os.remove(f)

    def save(self, snapshot_dir, iterations, smart_override=False):
        """Save generators, discriminators and optimizers to ``snapshot_dir``.

        File names embed ``iterations + 1``. When ``smart_override`` is true,
        older snapshots are thinned with ``snap_clean`` afterwards.
        """
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
        if smart_override:
            self.snap_clean(snapshot_dir, iterations + 1)
class MUNIT_Trainer(nn.Module):
    """MUNIT trainer variant whose encoders/decoders are conditioned on masks.

    Each domain owns an ``AdaINGen`` auto-encoder and an ``MsImageDis``
    discriminator.  Every image is encoded twice: once with its mask ``m``
    (yielding the ``s_f*`` style) and once with ``1 - m`` (yielding the
    ``s_b*`` style).  Decoding combines a content code, both styles and a mask.
    """

    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']

        # Initiate the networks.
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b
        # Normalises the 512-channel VGG features in compute_vgg_loss.
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # Fixed style noise of shape (display_size, style_dim, 1, 1), used by forward().
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Set up the optimizers: one Adam over both discriminators and one
        # over both generators (frozen parameters are skipped).
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization; the discriminators are then
        # re-initialised with the 'gaussian' scheme.
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load a frozen VGG model only when the perceptual loss is enabled.
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        """Return the mean absolute (L1) error between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        """Translate ``x_a`` to domain b and ``x_b`` to domain a with the fixed styles.

        Returns ``(x_ab, x_ba)``.  The module is put in eval mode for the pass
        and switched back to train mode afterwards.

        NOTE(review): unlike gen_update/sample/dis_update, this calls
        ``encode``/``decode`` without masks; confirm the masked ``AdaINGen``
        still accepts these signatures.
        """
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, m_A, x_b, m_B, hyperparameters):
        """Run one generator optimisation step.

        Args:
            x_a, x_b: image batches from domains a and b.
            m_A, m_B: masks for ``x_a`` / ``x_b``; ``1 - m`` is used as the
                complementary region (assumes values in [0, 1] -- TODO confirm).
            hyperparameters: dict providing the loss weights ``gan_w``,
                ``recon_x_w``, ``recon_s_w``, ``recon_c_w``,
                ``recon_x_cyc_w`` and ``vgg_w``.

        Every loss term is stored on ``self`` as ``loss_gen_*``.
        """
        self.gen_opt.zero_grad()
        im_A = 1 - m_A
        im_B = 1 - m_B
        use_cycle = hyperparameters['recon_x_cyc_w'] > 0

        # Encode: content plus the style of the masked region (s_f*) and of
        # its complement (s_b*).
        c_a, s_bA = self.gen_a.encode(x_a, im_A)
        c_b, s_fB = self.gen_b.encode(x_b, m_B)
        _, s_fA = self.gen_a.encode(x_a, m_A)
        _, s_bB = self.gen_b.encode(x_b, im_B)

        # Decode (cross domain).
        x_ba = self.gen_a.decode(c_b, s_fA, m_B, s_bB)
        x_ab = self.gen_b.decode(c_a, s_fB, m_A, s_bA)

        # Decode (within domain).
        x_aa = self.gen_a.decode(c_a, s_fA, m_A, s_bA)
        x_bb = self.gen_b.decode(c_b, s_fB, m_B, s_bB)

        # Encode again.  Each translation is re-encoded by the generator that
        # produced it: x_ba (from gen_a) by gen_a, x_ab (from gen_b) by gen_b.
        c_ba, s_fBA = self.gen_a.encode(x_ba, m_B)
        c_ab, s_fAB = self.gen_b.encode(x_ab, m_A)
        _, s_bBA = self.gen_a.encode(x_ba, im_B)
        _, s_bAB = self.gen_b.encode(x_ab, im_A)

        # Decode again (only needed for the cycle loss).
        x_aba = self.gen_a.decode(c_ab, s_fBA, m_A, s_bAB) if use_cycle else None
        x_bab = self.gen_b.decode(c_ba, s_fAB, m_B, s_bBA) if use_cycle else None

        # Content / style / image reconstruction losses.
        self.loss_gen_recon_c_a = self.recon_criterion(c_ab, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_ba, c_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_bAB, s_bA)
        self.loss_gen_recon_s_b = self.recon_criterion(s_bBA, s_bB)
        self.loss_gen_recon_s_af = self.recon_criterion(s_fAB, s_fB)
        self.loss_gen_recon_s_bf = self.recon_criterion(s_fBA, s_fA)
        self.loss_gen_recon_x_a = self.recon_criterion(x_aa, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_bb, x_b)
        # Cycle losses.  NOTE(review): domain a compares only the 1 - m_A
        # region while domain b compares only the m_B region -- confirm that
        # this asymmetry is intended.
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            im_A * x_aba, im_A * x_a) if use_cycle else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            m_B * x_bab, m_B * x_b) if use_cycle else 0

        # GAN loss.
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)

        # Domain-invariant perceptual loss.
        self.loss_gen_vgg_a = self.compute_vgg_loss(
            self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(
            self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0

        # Total loss.
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
            hyperparameters['gan_w'] * self.loss_gen_adv_b + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_af + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_bf + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
            hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
            hyperparameters['vgg_w'] * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        """Return the MSE between instance-normalised VGG features of ``img`` and ``target``."""
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def sample(self, loader_a, loader_b, size):
        """Build visualisation tensors from the first ``size`` items of each dataset.

        Each dataset item is read as ``item[0]`` (image) and ``item[1]``
        (mask).  Returns ``(im_a, x_a_recon, x_ab, masked_a, im_b, x_b_recon,
        x_ba, masked_b)``, each concatenated along dim 0.  For even indices
        the translated image is kept only inside the mask and the original
        image is used outside it.
        """
        self.eval()
        im_a = torch.stack([loader_a.dataset[i][0] for i in range(size)]).cuda()
        seg_a = torch.stack([loader_a.dataset[i][1] for i in range(size)]).cuda()
        im_b = torch.stack([loader_b.dataset[i][0] for i in range(size)]).cuda()
        seg_b = torch.stack([loader_b.dataset[i][1] for i in range(size)]).cuda()
        x_a_recon, x_b_recon, x_ba1, x_bm, x_ab1, x_am = [], [], [], [], [], []
        for i in range(im_a.size(0)):
            mask_a = seg_a[i].unsqueeze(0)
            mask_b = seg_b[i].unsqueeze(0)
            x_a = im_a[i].unsqueeze(0)
            x_b = im_b[i].unsqueeze(0)
            masked_a = mask_a * x_a
            masked_b = mask_b * x_b
            # The content codes kept are those from the later encodes
            # (x_a with mask_a, x_b with 1 - mask_b).
            _, s_bA = self.gen_a.encode(x_a, 1 - mask_a)
            _, s_fB = self.gen_b.encode(x_b, mask_b)
            c_a, s_fA = self.gen_a.encode(x_a, mask_a)
            c_b, s_bB = self.gen_b.encode(x_b, 1 - mask_b)
            # Decode (cross domain).
            x_BA = self.gen_a.decode(c_b, s_fA, mask_b, s_bB)
            x_AB = self.gen_b.decode(c_a, s_fB, mask_a, s_bA)
            if i % 2 == 0:
                # Paste the translated masked region onto the original image.
                x_AB = (1 - mask_a) * x_a + mask_a * x_AB
                x_BA = (1 - mask_b) * x_b + mask_b * x_BA
            x_ba1.append(x_BA)
            x_ab1.append(x_AB)
            x_am.append(masked_a)
            x_bm.append(masked_b)
            # Decode (within domain).
            x_a_recon.append(self.gen_a.decode(c_a, s_fA, mask_a, s_bA))
            x_b_recon.append(self.gen_b.decode(c_b, s_fB, mask_b, s_bB))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1 = torch.cat(x_ba1)
        x_ab1 = torch.cat(x_ab1)
        x_bm = torch.cat(x_bm)
        x_am = torch.cat(x_am)
        self.train()
        return im_a, x_a_recon, x_ab1, x_am, im_b, x_b_recon, x_ba1, x_bm

    def dis_update(self, x_a, m_a, x_b, m_b, hyperparameters):
        """Run one discriminator optimisation step on masked cross-domain translations.

        Losses are stored as ``loss_dis_a``, ``loss_dis_b`` and
        ``loss_dis_total``.
        """
        self.dis_opt.zero_grad()
        # Encode: style of the masked region (s_f*) and of its complement (s_b*).
        c_a, s_bA = self.gen_a.encode(x_a, 1 - m_a)
        c_b, s_fB = self.gen_b.encode(x_b, m_b)
        _, s_fA = self.gen_a.encode(x_a, m_a)
        _, s_bB = self.gen_b.encode(x_b, 1 - m_b)
        # Decode (cross domain).
        x_ba = self.gen_a.decode(c_b, s_fA, m_b, s_bB)
        x_ab = self.gen_b.decode(c_a, s_fB, m_a, s_bA)
        # D loss; fakes are detached so only the discriminators get gradients.
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + \
            hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Step whichever LR schedulers are configured (``None`` means disabled)."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore networks, optimizers and schedulers from ``checkpoint_dir``.

        The iteration count is parsed from the 8 digits before ``.pt`` in the
        latest ``gen`` checkpoint name; it is returned and used to rebuild the
        schedulers.
        """
        # Load generators.
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators.
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers.
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers.
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Save generators, discriminators (tagged ``iterations + 1``) and optimizers."""
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
class MUNIT_Trainer(nn.Module):
    """MUNIT trainer with an additional mask-restricted ("attentive") reconstruction loss.

    Uses one ``AdaINGen`` auto-encoder and one ``MsImageDis`` discriminator
    per domain.  In this variant ``get_scheduler`` returns
    ``(scheduler, lr_policy)``; for the ``'plateau'`` policy both schedulers
    are stepped with ``self.metric``.
    """

    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']

        # Initiate the networks.
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b
        # Normalises the 512-channel VGG features in compute_vgg_loss.
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # Fixed style noise of shape (display_size, style_dim, 1, 1), used by forward()/sample().
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Set up the optimizers (frozen parameters are skipped).
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler, self.lr_policy = get_scheduler(
            self.dis_opt, hyperparameters)
        self.gen_scheduler, self.lr_policy = get_scheduler(
            self.gen_opt, hyperparameters)

        # Network weight initialization; the discriminators are then
        # re-initialised with the 'gaussian' scheme.
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Value handed to plateau schedulers in update_learning_rate(); only
        # initialised here, so callers are expected to update it.
        self.metric = 0
        # Cross-domain style source in gen_update: 1 uses the other image's
        # encoded style (example-guided); any other value uses random styles.
        self.guided = hyperparameters.get('guided', 0)

        # Load a frozen VGG model only when the perceptual loss is enabled.
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        """Return the mean absolute (L1) error between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    def _masked_recon(self, input, target, mask):
        """L1 error between ``input`` and ``target`` restricted to where ``mask == 1``.

        ``mask`` is broadcast over the channels of ``input`` (assumes an
        ``(N, 1, H, W)`` or already ``input``-shaped mask -- TODO confirm).
        If the mask selects nothing, a zero tensor is returned rather than
        the NaN that ``torch.mean`` over an empty selection produces.
        """
        selected = (mask == 1).expand_as(input)
        picked_input = input[selected]
        picked_target = target[selected]
        if picked_input.numel() == 0:
            # Zero on the right device, still connected to the graph.
            return (picked_input - picked_target).sum()
        return self.recon_criterion(picked_input, picked_target)

    def forward(self, x_a, x_b):
        """Translate ``x_a`` to domain b and ``x_b`` to domain a with the fixed styles.

        Returns ``(x_ab, x_ba)``; the module is left in train mode.
        """
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters, mask):
        """Run one generator optimisation step on one image batch per domain.

        ``mask`` belongs to ``x_a``: the attentive loss compares ``x_a`` with
        its within-domain reconstruction only where ``mask == 1``.  Loss
        weights come from ``hyperparameters`` (``gan_w``, ``recon_x_w``,
        ``recon_s_w``, ``recon_c_w``, ``recon_x_cyc_w``, ``vgg_w``,
        ``att_w``); each term is stored on ``self``.
        """
        self.gen_opt.zero_grad()
        # Random styles for the cross-domain translations.
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        use_cycle = hyperparameters['recon_x_cyc_w'] > 0

        # Encode: content and style code of each input.
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)

        # Decode (within domain): encoder + decoder should reproduce the input.
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)

        # Decode (cross domain): content of one image with a style of the
        # other domain.  The chosen style is also the target of the style
        # reconstruction loss below.
        if self.guided == 1:
            style_a, style_b = s_a_prime, s_b_prime
        else:
            style_a, style_b = s_a, s_b
        x_ba = self.gen_a.decode(c_b, style_a)
        x_ab = self.gen_b.decode(c_a, style_b)

        # Encode again to recover content and style from the translations.
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)

        # Decode again (only needed for the cycle loss).
        x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if use_cycle else None
        x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if use_cycle else None

        # Attentive loss: L1 between x_a and its reconstruction inside the mask.
        self.loss_attentive = self._masked_recon(x_a, x_a_recon, mask)

        # Reconstruction losses.
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, style_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, style_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if use_cycle else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if use_cycle else 0

        # GAN loss.
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)

        # Domain-invariant perceptual loss.
        self.loss_gen_vgg_a = self.compute_vgg_loss(
            self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(
            self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0

        # Total loss.
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
            hyperparameters['gan_w'] * self.loss_gen_adv_b + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
            hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
            hyperparameters['vgg_w'] * self.loss_gen_vgg_b + \
            hyperparameters['att_w'] * self.loss_attentive
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        """Return the MSE between instance-normalised VGG features of ``img`` and ``target``."""
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def sample(self, x_a, x_b):
        """Produce reconstructions and translations for visualisation.

        Translations are made with both the fixed styles (``*1``) and fresh
        random styles (``*2``).  Returns ``(x_a, x_a_recon, x_ab1, x_ab2,
        x_b, x_b_recon, x_ba1, x_ba2)``.
        """
        self.eval()
        s_a1 = Variable(self.s_a)
        s_b1 = Variable(self.s_b)
        s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
            c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
            x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
            x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
            x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
            x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
            x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_update(self, x_a, x_b, hyperparameters):
        """Run one discriminator optimisation step.

        Each input is translated into the other domain using its own content
        code and a random style code; the discriminators learn to separate
        these (detached) fakes from the real images.
        """
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # Encode.
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # Decode (cross domain).
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # D loss.
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + \
            hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Step both LR schedulers; plateau schedulers receive ``self.metric``."""
        for scheduler in (self.dis_scheduler, self.gen_scheduler):
            if scheduler is None:
                continue
            if self.lr_policy == 'plateau':
                scheduler.step(self.metric)
            else:
                scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore networks, optimizers and schedulers from ``checkpoint_dir``.

        The iteration count is parsed from the 8 digits before ``.pt`` in the
        latest ``gen`` checkpoint name and returned.
        """
        # Load generators.
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators.
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers.
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers.
        self.dis_scheduler, self.lr_policy = get_scheduler(
            self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler, self.lr_policy = get_scheduler(
            self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Save generators, discriminators (tagged ``iterations + 1``) and optimizers."""
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
class UNIT_Trainer(nn.Module):
    """UNIT-style trainer with one shared, domain-conditioned VAE generator.

    A single ``VAEGen`` and a single ``MsImageDis`` serve all
    ``n_datasets`` domains.  The domain is selected by concatenating one-hot
    planes to the image (``one_hot_img``) or to the latent code
    (``one_hot_h``) along dim 1.  A ``UNet`` (``self.sup``, 2 classes) is fed
    the encoded latent and is optimised together with the generator.
    """

    def __init__(self, hyperparameters, resume_epoch=-1, snapshot_dir=None):
        super(UNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']

        # Initiate the networks (shared across domains).
        self.gen = VAEGen(
            hyperparameters['input_dim'] + hyperparameters['n_datasets'],
            hyperparameters['gen'],
            hyperparameters['n_datasets'])  # Auto-encoder.
        self.dis = MsImageDis(
            hyperparameters['input_dim'] + hyperparameters['n_datasets'],
            hyperparameters['dis'])  # Discriminator.
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.sup = UNet(input_channels=hyperparameters['input_dim'],
                        num_classes=2).cuda()

        # Set up the optimizers; the supervised UNet shares the generator optimizer.
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis.parameters())
        gen_params = list(self.gen.parameters()) + list(self.sup.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization.
        self.apply(weights_init(hyperparameters['init']))
        self.dis.apply(weights_init('gaussian'))

        # One-hot domain planes, shaped (n_datasets, batch_size, n_datasets, H, W):
        # entry [i] has channel i filled with ones.  Images use 256x256
        # planes, latents 64x64.
        self.one_hot_img = torch.zeros(hyperparameters['n_datasets'],
                                       hyperparameters['batch_size'],
                                       hyperparameters['n_datasets'], 256,
                                       256).cuda()
        self.one_hot_h = torch.zeros(hyperparameters['n_datasets'],
                                     hyperparameters['batch_size'],
                                     hyperparameters['n_datasets'], 64,
                                     64).cuda()
        for i in range(hyperparameters['n_datasets']):
            self.one_hot_img[i, :, i, :, :].fill_(1)
            self.one_hot_h[i, :, i, :, :].fill_(1)

        if resume_epoch != -1:
            self.resume(snapshot_dir, hyperparameters)

    @staticmethod
    def _with_domain_code(tensor, codes, d_index):
        """Concatenate the one-hot planes of domain ``d_index`` to ``tensor`` along dim 1.

        ``codes`` is ``one_hot_img`` or ``one_hot_h``.  The planes are
        expanded to ``tensor``'s batch size, so batches of any size (not only
        ``batch_size``) are supported.
        """
        code = codes[d_index, :1].expand(tensor.size(0), -1, -1, -1)
        return torch.cat([tensor, code], 1)

    @staticmethod
    def _set_trainable(module, train_bool):
        """Put ``module`` in train/eval mode and (un)freeze all its parameters."""
        module.train(bool(train_bool))
        for param in module.parameters():
            param.requires_grad = bool(train_bool)

    def recon_criterion(self, input, target):
        """Return the mean absolute (L1) error between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    def semi_criterion(self, input, target):
        """Return the summed 2-D cross-entropy between predictions and labels."""
        loss = CrossEntropyLoss2d(size_average=False).cuda()
        return loss(input, target)

    def forward(self, x_a, x_b, d_index_a=0, d_index_b=1):
        """Translate ``x_a`` into domain ``d_index_b`` and ``x_b`` into domain ``d_index_a``.

        Runs the shared generator without latent noise and without tracking
        gradients.  The default domain indices (0 and 1) are an assumption;
        pass them explicitly for other domain pairs.  Returns ``(x_ab, x_ba)``.
        """
        self.eval()
        with torch.no_grad():
            h_a, _ = self.gen.encode(
                self._with_domain_code(x_a, self.one_hot_img, d_index_a))
            h_b, _ = self.gen.encode(
                self._with_domain_code(x_b, self.one_hot_img, d_index_b))
            x_ba = self.gen.decode(
                self._with_domain_code(h_b, self.one_hot_h, d_index_a))
            x_ab = self.gen.decode(
                self._with_domain_code(h_a, self.one_hot_h, d_index_b))
        self.train()
        return x_ab, x_ba

    def __compute_kl(self, mu):
        """Simplified KL term: the mean of ``mu ** 2`` (unit-variance assumption)."""
        mu_2 = torch.pow(mu, 2)
        encoding_loss = torch.mean(mu_2)
        return encoding_loss

    def set_gen_trainable(self, train_bool):
        """Train (and unfreeze) the generator if ``train_bool``, else eval (and freeze) it."""
        self._set_trainable(self.gen, train_bool)

    def set_sup_trainable(self, train_bool):
        """Train (and unfreeze) the UNet if ``train_bool``, else eval (and freeze) it."""
        self._set_trainable(self.sup, train_bool)

    def _cycle_pass(self, x_a, x_b, d_index_a, d_index_b, hyperparameters):
        """Shared forward pass of sup_update/gen_update.

        Runs encode -> within-domain decode -> cross-domain decode ->
        re-encode -> (optional) cycle decode, in that order.  Returns a dict
        with the latents ``h_a``/``h_b``, ``x_a_recon``/``x_b_recon``, the
        translations and their one-hot inputs (``one_hot_x_ba``/``_ab``), the
        re-encoded latents ``h_a_recon``/``h_b_recon``, and
        ``x_aba``/``x_bab`` (``None`` when ``recon_x_cyc_w <= 0``).
        """
        img, hid = self.one_hot_img, self.one_hot_h
        # Encode.
        h_a, n_a = self.gen.encode(self._with_domain_code(x_a, img, d_index_a))
        h_b, n_b = self.gen.encode(self._with_domain_code(x_b, img, d_index_b))
        # Decode (within domain), using the latent plus the encoder's n_* term.
        x_a_recon = self.gen.decode(self._with_domain_code(h_a + n_a, hid, d_index_a))
        x_b_recon = self.gen.decode(self._with_domain_code(h_b + n_b, hid, d_index_b))
        # Decode (cross domain).
        x_ba = self.gen.decode(self._with_domain_code(h_b + n_b, hid, d_index_a))
        x_ab = self.gen.decode(self._with_domain_code(h_a + n_a, hid, d_index_b))
        # Encode again.
        one_hot_x_ba = self._with_domain_code(x_ba, img, d_index_a)
        one_hot_x_ab = self._with_domain_code(x_ab, img, d_index_b)
        h_b_recon, n_b_recon = self.gen.encode(one_hot_x_ba)
        h_a_recon, n_a_recon = self.gen.encode(one_hot_x_ab)
        # Decode again (only needed for the cycle loss).
        x_aba = x_bab = None
        if hyperparameters['recon_x_cyc_w'] > 0:
            x_aba = self.gen.decode(
                self._with_domain_code(h_a_recon + n_a_recon, hid, d_index_a))
            x_bab = self.gen.decode(
                self._with_domain_code(h_b_recon + n_b_recon, hid, d_index_b))
        return {
            'h_a': h_a, 'h_b': h_b,
            'x_a_recon': x_a_recon, 'x_b_recon': x_b_recon,
            'x_ba': x_ba, 'x_ab': x_ab,
            'one_hot_x_ba': one_hot_x_ba, 'one_hot_x_ab': one_hot_x_ab,
            'h_a_recon': h_a_recon, 'h_b_recon': h_b_recon,
            'x_aba': x_aba, 'x_bab': x_bab,
        }

    def sup_update(self, x_a, x_b, y_a, y_b, d_index_a, d_index_b, use_a,
                   use_b, hyperparameters):
        """Run one supervised segmentation step through the shared generator.

        ``use_a`` / ``use_b`` index the labelled samples of each batch.  The
        UNet is applied to both the original and the re-encoded latents.  No
        step is taken when neither batch has labels (``loss_gen_total`` is
        then ``None``).
        """
        self.gen_opt.zero_grad()
        out = self._cycle_pass(x_a, x_b, d_index_a, d_index_b, hyperparameters)
        h_a, h_b = out['h_a'], out['h_b']

        # Forward through the supervised model for each labelled subset.
        loss_semi_a = None
        loss_semi_b = None
        if h_a[use_a, :, :, :].size(0) != 0:
            p_a = self.sup(h_a, use_a, True)
            p_a_recon = self.sup(out['h_a_recon'], use_a, True)
            loss_semi_a = self.semi_criterion(p_a, y_a[use_a, :, :]) + \
                self.semi_criterion(p_a_recon, y_a[use_a, :, :])
        if h_b[use_b, :, :, :].size(0) != 0:
            p_b = self.sup(h_b, use_b, True)
            p_b_recon = self.sup(out['h_b_recon'], use_b, True)
            loss_semi_b = self.semi_criterion(p_b, y_b[use_b, :, :]) + \
                self.semi_criterion(p_b_recon, y_b[use_b, :, :])

        self.loss_gen_total = None
        if loss_semi_a is not None and loss_semi_b is not None:
            self.loss_gen_total = loss_semi_a + loss_semi_b
        elif loss_semi_a is not None:
            self.loss_gen_total = loss_semi_a
        elif loss_semi_b is not None:
            self.loss_gen_total = loss_semi_b

        if self.loss_gen_total is not None:
            self.loss_gen_total.backward()
            self.gen_opt.step()

    def sup_forward(self, x, y, d_index, hyperparameters):
        """Segment ``x`` (from domain ``d_index``) and score it against ``y``.

        Returns ``(jacc, pred, hidden)``, where ``pred`` is the per-pixel
        argmax class map as a NumPy array.  Leaves ``self.sup`` in eval mode.
        """
        self.sup.eval()
        # Encode the content image.
        hidden, _ = self.gen.encode(
            self._with_domain_code(x, self.one_hot_img, d_index))
        # Forward through the supervised model.
        y_pred = self.sup(hidden, only_prediction=True)
        # Compute metrics.
        pred = y_pred.data.max(1)[1].squeeze_(1).squeeze_(0).cpu().numpy()
        jacc = jaccard(pred, y.cpu().squeeze(0).numpy())
        return jacc, pred, hidden

    def gen_update(self, x_a, x_b, d_index_a, d_index_b, hyperparameters):
        """Run one generator optimisation step (reconstruction, KL, cycle and GAN losses)."""
        self.gen_opt.zero_grad()
        out = self._cycle_pass(x_a, x_b, d_index_a, d_index_b, hyperparameters)
        use_cycle = hyperparameters['recon_x_cyc_w'] > 0

        # Reconstruction losses.  Cycle terms are 0 when the cycle decode was skipped.
        self.loss_gen_recon_x_a = self.recon_criterion(out['x_a_recon'], x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(out['x_b_recon'], x_b)
        self.loss_gen_recon_kl_a = self.__compute_kl(out['h_a'])
        self.loss_gen_recon_kl_b = self.__compute_kl(out['h_b'])
        self.loss_gen_cyc_x_a = self.recon_criterion(
            out['x_aba'], x_a) if use_cycle else 0
        self.loss_gen_cyc_x_b = self.recon_criterion(
            out['x_bab'], x_b) if use_cycle else 0
        self.loss_gen_recon_kl_cyc_aba = self.__compute_kl(out['h_a_recon'])
        self.loss_gen_recon_kl_cyc_bab = self.__compute_kl(out['h_b_recon'])

        # GAN loss (the discriminator also sees the domain code).
        self.loss_gen_adv_a = self.dis.calc_gen_loss(out['one_hot_x_ba'])
        self.loss_gen_adv_b = self.dis.calc_gen_loss(out['one_hot_x_ab'])

        # Total loss.
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
            hyperparameters['gan_w'] * self.loss_gen_adv_b + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_a + \
            hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_b + \
            hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def sample(self, x_a, x_b, d_index_a=0, d_index_b=1):
        """Produce per-image reconstructions and translations for visualisation.

        Runs without latent noise and without tracking gradients; default
        domain indices are an assumption, as in ``forward``.  Returns
        ``(x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba)``.
        """
        self.eval()
        img, hid = self.one_hot_img, self.one_hot_h
        x_a_recon, x_b_recon, x_ba, x_ab = [], [], [], []
        with torch.no_grad():
            for i in range(x_a.size(0)):
                h_a, _ = self.gen.encode(
                    self._with_domain_code(x_a[i].unsqueeze(0), img, d_index_a))
                h_b, _ = self.gen.encode(
                    self._with_domain_code(x_b[i].unsqueeze(0), img, d_index_b))
                x_a_recon.append(self.gen.decode(self._with_domain_code(h_a, hid, d_index_a)))
                x_b_recon.append(self.gen.decode(self._with_domain_code(h_b, hid, d_index_b)))
                x_ba.append(self.gen.decode(self._with_domain_code(h_b, hid, d_index_a)))
                x_ab.append(self.gen.decode(self._with_domain_code(h_a, hid, d_index_b)))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba = torch.cat(x_ba)
        x_ab = torch.cat(x_ab)
        self.train()
        return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba

    def dis_update(self, x_a, x_b, d_index_a, d_index_b, hyperparameters):
        """Run one discriminator optimisation step on (detached) cross-domain translations."""
        self.dis_opt.zero_grad()
        img, hid = self.one_hot_img, self.one_hot_h
        # Encode.
        one_hot_x_a = self._with_domain_code(x_a, img, d_index_a)
        one_hot_x_b = self._with_domain_code(x_b, img, d_index_b)
        h_a, n_a = self.gen.encode(one_hot_x_a)
        h_b, n_b = self.gen.encode(one_hot_x_b)
        # Decode (cross domain).
        x_ba = self.gen.decode(self._with_domain_code(h_b + n_b, hid, d_index_a))
        x_ab = self.gen.decode(self._with_domain_code(h_a + n_a, hid, d_index_b))
        # D loss.
        one_hot_x_ba = self._with_domain_code(x_ba, img, d_index_a)
        one_hot_x_ab = self._with_domain_code(x_ab, img, d_index_b)
        self.loss_dis_a = self.dis.calc_dis_loss(one_hot_x_ba.detach(), one_hot_x_a)
        self.loss_dis_b = self.dis.calc_dis_loss(one_hot_x_ab.detach(), one_hot_x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + \
            hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Step whichever LR schedulers are configured (``None`` means disabled)."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore networks, optimizers and schedulers from the latest checkpoints.

        The epoch is parsed from the 8 digits before ``.pt`` in the latest
        ``gen`` checkpoint name.  Optimizer state tensors are moved to CUDA.
        Returns the epoch.
        """
        # Load generators.
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen.load_state_dict(state_dict)
        epochs = int(last_model_name[-11:-3])
        # Load discriminators.
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis.load_state_dict(state_dict)
        # Load supervised model.
        last_model_name = get_model_list(checkpoint_dir, "sup")
        state_dict = torch.load(last_model_name)
        self.sup.load_state_dict(state_dict)
        # Load optimizers and move their state to the GPU.
        last_model_name = get_model_list(checkpoint_dir, "opt")
        state_dict = torch.load(last_model_name)
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        for optimizer in (self.dis_opt, self.gen_opt):
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.cuda()
        # Reinitialize schedulers.
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, epochs)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, epochs)
        print('Resume from iteration %d' % epochs)
        return epochs

    def save(self, snapshot_dir, epoch):
        """Save generator, discriminator, UNet and optimizers tagged with ``epoch``."""
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % epoch)
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % epoch)
        sup_name = os.path.join(snapshot_dir, 'sup_%08d.pt' % epoch)
        opt_name = os.path.join(snapshot_dir, 'opt_%08d.pt' % epoch)
        torch.save(self.gen.state_dict(), gen_name)
        torch.save(self.dis.state_dict(), dis_name)
        torch.save(self.sup.state_dict(), sup_name)
        torch.save(
            {
                'dis': self.dis_opt.state_dict(),
                'gen': self.gen_opt.state_dict()
            }, opt_name)
class MUSIC_Trainer(nn.Module):
    """Mask-based trainer in which ``gen`` predicts a mask that extracts an A-like layer from an image.

    Images are mapped with ``x * 0.5 + 0.5`` before mask / sum / difference
    arithmetic and mapped back with ``(x - 0.5) * 2``, i.e. they are treated as
    values in [-1, 1].
    """

    def __init__(self, hyperparameters):
        super(MUSIC_Trainer, self).__init__()
        hp = hyperparameters
        lr = hp['lr']
        use_legacy_gen = hp['old_flag'] == 1

        # Mask generator. The legacy architecture also keeps fixed style noise.
        # NOTE: creation order below fixes module registration order, which
        # determines state_dict layout and the RNG use of weight init.
        if use_legacy_gen:
            self.gen = MaskGenOld(hp['input_dim_a'], hp['gen'])
            self.style_dim = hp['gen']['style_dim']
            self.s_a = torch.randn(8, self.style_dim, 1, 1).cuda()
            self.s_b = torch.randn(8, self.style_dim, 1, 1).cuda()
        else:
            self.gen = MaskGen(hp['input_dim_a'], hp['gen'])
        # One discriminator per domain.
        self.dis_a = MsImageDis(hp['input_dim_a'], hp['dis'])
        self.dis_b = MsImageDis(hp['input_dim_b'], hp['dis'])
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

        # Optional exponent applied to masks in ``masking`` (missing -> None).
        self.enhance = hp.get('enhance')

        # Adam for both players; frozen parameters are left out.
        adam_kwargs = dict(lr=lr,
                           betas=(hp['beta1'], hp['beta2']),
                           weight_decay=hp['weight_decay'])
        trainable_dis = [p for p in (*self.dis_a.parameters(), *self.dis_b.parameters())
                         if p.requires_grad]
        trainable_gen = [p for p in self.gen.parameters() if p.requires_grad]
        self.dis_opt = torch.optim.Adam(trainable_dis, **adam_kwargs)
        self.gen_opt = torch.optim.Adam(trainable_gen, **adam_kwargs)
        self.dis_scheduler = get_scheduler(self.dis_opt, hp)
        self.gen_scheduler = get_scheduler(self.gen_opt, hp)

        # Network weight initialization: global scheme first, then Gaussian
        # for both discriminators.
        self.apply(weights_init(hp['init']))
        for discriminator in (self.dis_a, self.dis_b):
            discriminator.apply(weights_init('gaussian'))

    def recon_criterion(self, input, target):
        """Mean absolute error between ``input`` and ``target``."""
        return (input - target).abs().mean()

    def masking(self, mask, img):
        """Multiply ``img`` (in [-1, 1]) by ``mask`` in [0, 1] space and return it in [-1, 1] space.

        When ``self.enhance`` is truthy, the mask is first raised to that power.
        """
        if self.enhance:
            mask = mask.pow(self.enhance)
        unit_img = img * 0.5 + 0.5
        return (mask * unit_img - 0.5) * 2

    def scaled_sum(self, input_1, input_2):
        """Add two [-1, 1] images in [0, 1] space, clamp the sum to [0, 1], and map back."""
        total = (input_1 * 0.5 + 0.5) + (input_2 * 0.5 + 0.5)
        total = total.clamp(0, 1)  # added at 3.yaml
        return (total - 0.5) * 2

    def scaled_sub(self, input_1, input_2):
        """Subtract two [-1, 1] images in [0, 1] space, clamp to [0, 1], and map back."""
        diff = (input_1 * 0.5 + 0.5) - (input_2 * 0.5 + 0.5)
        diff = diff.clamp(0, 1)  # added at 3.yaml
        return (diff - 0.5) * 2

    def forward(self, x_b):
        """Extract the A-like layer from ``x_b`` in eval mode, then switch back to train mode."""
        self.eval()
        predicted_mask = self.gen.decode(self.gen.encode(x_b))
        extracted = self.masking(predicted_mask, x_b)
        self.train()
        return extracted

    def gen_update(self, x_a, x_b, hyperparameters):
        """Run one generator step and store the individual losses on ``self``."""
        hp = hyperparameters
        gen = self.gen
        self.gen_opt.zero_grad()

        # First pass: extract from the mixed image and from a clean A image.
        x_ba = self.masking(gen.decode(gen.encode(x_b)), x_b)
        x_aa = self.masking(gen.decode(gen.encode(x_a)), x_a)

        # Replace the extracted layer with the real x_a and extract again.
        x_t = self.scaled_sub(x_b, x_ba)
        x_b_new = self.scaled_sum(x_t, x_a)
        x_ba_new = self.masking(gen.decode(gen.encode(x_b_new)), x_b_new)
        x_t_new = self.scaled_sub(x_b_new, x_ba_new)

        # Extract from the already-extracted layer.
        x_baa = self.masking(gen.decode(gen.encode(x_ba)), x_ba)

        # Reconstruction losses.
        self.loss_gen_recon_x_aa = self.recon_criterion(x_aa, x_a)
        self.loss_gen_recon_x_t = self.recon_criterion(x_t_new, x_t)
        self.loss_gen_recon_x_ba_new = self.recon_criterion(x_ba_new, x_a)
        self.loss_gen_recon_x_baa = self.recon_criterion(x_baa, x_ba)
        # Adversarial losses.
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_b_new)

        weighted_terms = [
            (hp['gan_w_a'], self.loss_gen_adv_a),
            (hp['gan_w_b'], self.loss_gen_adv_b),
            (hp['a2a_w'], self.loss_gen_recon_x_aa),
            (hp['x_t_w'], self.loss_gen_recon_x_t),
            (hp['recon_w'], self.loss_gen_recon_x_ba_new),
            (hp['DTN_w'], self.loss_gen_recon_x_baa),
        ]
        self.loss_gen_total = sum(weight * term for weight, term in weighted_terms)
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def sample(self, x_a, x_b):
        """Run the ``gen_update`` pipeline per sample in eval mode for visualisation.

        Returns (x_b, x_ba, x_baa, x_a, x_ba_new, x_aa, x_t, x_t_new, x_b_new).
        """
        self.eval()
        gen = self.gen
        collected = {name: [] for name in ('ba', 'aa', 'ba_new', 't', 't_new', 'baa', 'b_new')}
        for idx in range(x_a.size(0)):
            a_one = x_a[idx].unsqueeze(0)
            b_one = x_b[idx].unsqueeze(0)
            ba = self.masking(gen.decode(gen.encode(b_one)), b_one)
            aa = self.masking(gen.decode(gen.encode(a_one)), a_one)
            t = self.scaled_sub(b_one, ba)
            b_new = self.scaled_sum(t, a_one)
            ba_new = self.masking(gen.decode(gen.encode(b_new)), b_new)
            t_new = self.scaled_sub(b_new, ba_new)
            baa = self.masking(gen.decode(gen.encode(ba)), ba)
            for name, value in (('ba', ba), ('aa', aa), ('t', t), ('b_new', b_new),
                                ('ba_new', ba_new), ('t_new', t_new), ('baa', baa)):
                collected[name].append(value)
        stacked = {name: torch.cat(values) for name, values in collected.items()}
        self.train()
        return (x_b, stacked['ba'], stacked['baa'], x_a, stacked['ba_new'],
                stacked['aa'], stacked['t'], stacked['t_new'], stacked['b_new'])

    def dis_update(self, x_a, x_b, hyperparameters):
        """Run one discriminator step; fakes are detached so only the discriminators learn."""
        self.dis_opt.zero_grad()
        fake_a = self.masking(self.gen.decode(self.gen.encode(x_b)), x_b)
        fake_b = self.scaled_sum(self.scaled_sub(x_b, fake_a), x_a)
        self.loss_dis_a = self.dis_a.calc_dis_loss(fake_a.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(fake_b.detach(), x_b)
        self.loss_dis_total = (hyperparameters['gan_w_a'] * self.loss_dis_a
                               + hyperparameters['gan_w_b'] * self.loss_dis_b)
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Advance the discriminator scheduler, then the generator scheduler (skipping None)."""
        for scheduler in (self.dis_scheduler, self.gen_scheduler):
            if scheduler is not None:
                scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Load the newest gen/dis snapshots and the optimizers from ``checkpoint_dir``.

        The iteration count is parsed from the generator file name
        (``gen_%08d.pt``), and the schedulers are rebuilt from that point.
        """
        gen_path = get_model_list(checkpoint_dir, "gen")
        self.gen.load_state_dict(torch.load(gen_path)['gen'])
        iterations = int(gen_path[-11:-3])

        dis_state = torch.load(get_model_list(checkpoint_dir, "dis"))
        self.dis_a.load_state_dict(dis_state['a'])
        self.dis_b.load_state_dict(dis_state['b'])

        opt_state = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(opt_state['dis'])
        self.gen_opt.load_state_dict(opt_state['gen'])

        # Reinitilize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Write gen/dis snapshots tagged ``iterations + 1`` and overwrite ``optimizer.pt``."""
        tag = iterations + 1
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % tag)
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % tag)
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'gen': self.gen.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
class AGUIT_Trainer(nn.Module):
    """Attribute-guided translation trainer using labelled (x_l, l) and unlabelled (x_u) images.

    The generator's style code is laid out as ``[noise | attributes]``: the
    first ``noise_dim`` entries are free noise and the next ``attr_dim``
    entries are regressed onto the labels ``l``.
    """

    def __init__(self, hyperparameters):
        super(AGUIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.noise_dim = hyperparameters['gen']['noise_dim']
        self.attr_dim = len(hyperparameters['gen']['selected_attrs'])
        self.gen = AdaINGen(hyperparameters['input_dim'], hyperparameters['gen'])
        self.dis = MsImageDis(hyperparameters['input_dim'], self.attr_dim, hyperparameters['dis'])
        # Discriminator on content codes; its input channel count is
        # gen.dim * 2**n_downsample.
        self.dis_content = ContentDis(
            hyperparameters['gen']['dim'] * (2**hyperparameters['gen']['n_downsample']),
            self.attr_dim)

        # Setup the optimizers (frozen parameters are skipped).
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis.parameters()) + list(self.dis_content.parameters())
        gen_params = list(self.gen.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr, betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr, betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis.apply(weights_init('gaussian'))
        self.dis_content.apply(weights_init('gaussian'))

    def recon_criterion(self, input, target):
        """Mean absolute error between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    def gen_update(self, x_l, x_u, l, hyperparameters):
        """Run one generator step and return the detached total loss.

        Args:
            x_l: labelled images.
            x_u: unlabelled images.
            l: labels; concatenated after the noise to form target style codes
                and compared with the attribute part of ``x_l``'s style.
            hyperparameters: dict with the loss weights ``adv_w``, ``adv_c_w``,
                ``rec_w``, ``rec_s_w``, ``rec_c_w``, ``cla_w``, ``kld_w``, ``cyc_w``.
        """
        self.gen_opt.zero_grad()
        # Target style: fresh noise followed by the given labels.
        s_r = torch.cat([torch.randn(x_u.size(0), self.noise_dim).cuda(), l], 1)
        # encode
        c_l, s_l = self.gen.encode(x_l)
        c_u, s_u = self.gen.encode(x_u)
        # decode (within domain)
        x_u_recon = self.gen.decode(c_u, s_u)
        # decode (cross domain)
        x_ur = self.gen.decode(c_u, s_r)
        # encode again, then cycle back with the original style
        c_u_recon, s_r_recon = self.gen.encode(x_ur)
        x_u_cycle = self.gen.decode(c_u_recon, s_u)

        # KL-style penalty on the noise part of s_l. It uses the scalar mean
        # and std over the whole noise slice, not per-dimension statistics.
        s_mean = s_l[:, 0:self.noise_dim].mean()
        s_std = s_l[:, 0:self.noise_dim].std()
        self.loss_gen_kld = (s_mean**2 + s_std.pow(2) - s_std.pow(2).log() - 1).mean() / 2
        self.loss_gen_adv_content = self.dis_content.calc_gen_loss(c_l, c_u, l)
        # reconstruction loss
        self.loss_gen_rec = self.recon_criterion(x_u_recon, x_u)
        self.loss_gen_rec_s = self.recon_criterion(s_r_recon, s_r)
        self.loss_gen_rec_c = self.recon_criterion(c_u_recon, c_u)
        self.loss_gen_cyc = self.recon_criterion(x_u_cycle, x_u)
        # GAN loss
        self.loss_gen_adv = self.dis.calc_gen_loss(x_ur, l)
        # label part loss: attribute slice of the style code should match l
        self.loss_gen_cla = (
            s_l[:, self.noise_dim:self.noise_dim + self.attr_dim] - l).pow(2).mean()

        self.loss_gen_total = hyperparameters['adv_w'] * self.loss_gen_adv + \
            hyperparameters['adv_c_w'] * self.loss_gen_adv_content + \
            hyperparameters['rec_w'] * self.loss_gen_rec + \
            hyperparameters['rec_s_w'] * self.loss_gen_rec_s + \
            hyperparameters['rec_c_w'] * self.loss_gen_rec_c + \
            hyperparameters['cla_w'] * self.loss_gen_cla + \
            hyperparameters['kld_w'] * self.loss_gen_kld + \
            hyperparameters['cyc_w'] * self.loss_gen_cyc
        self.loss_gen_total.backward()
        self.gen_opt.step()
        return self.loss_gen_total.detach()

    def sample(self, x_l, l):
        """Return ``[x_l, reconstruction, one edit per attribute]`` for visualisation.

        Each edit replaces attribute ``i`` of the encoded style with ``-l[:, i]``.
        """
        c_l, s_l = self.gen.encode(x_l)
        # decode (within domain)
        x_l_recon = self.gen.decode(c_l, s_l)
        out = [x_l, x_l_recon]
        for i in range(self.attr_dim):
            s_changed = s_l.clone()
            s_changed[:, self.noise_dim + i] = -l[:, i]
            out.append(self.gen.decode(c_l, s_changed))
        return out

    def dis_update(self, x_l, x_u, l, hyperparameters):
        """Run one discriminator step and return the detached total loss."""
        self.dis_opt.zero_grad()
        s_r = torch.cat([torch.randn(x_u.size(0), self.noise_dim).cuda(), l], 1)
        # encode
        c_l, s_l = self.gen.encode(x_l)
        c_u, s_u = self.gen.encode(x_u)
        # decode (cross domain)
        x_ur = self.gen.decode(c_u, s_r)
        # D loss
        self.loss_dis_adv = self.dis.calc_dis_loss(x_ur.detach(), x_l, x_u, l)
        self.loss_dis_adv_content = self.dis_content.calc_dis_loss(c_l, c_u, l)
        self.loss_dis_total = hyperparameters['adv_w'] * self.loss_dis_adv + \
            hyperparameters['adv_c_w'] * self.loss_dis_adv_content
        self.loss_dis_total.backward()
        self.dis_opt.step()
        return self.loss_dis_total.detach()

    def update_learning_rate(self):
        """Step both LR schedulers when present."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Load the newest gen/dis snapshots and the optimizers; return the iteration count."""
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen.load_state_dict(state_dict['gen'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis.load_state_dict(state_dict['dis'])
        self.dis_content.load_state_dict(state_dict['dis_content'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitilize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Write gen/dis snapshots tagged ``iterations + 1`` and overwrite ``optimizer.pt``.

        The key layout matches what ``resume`` reads.
        """
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'gen': self.gen.state_dict()}, gen_name)
        # This trainer's image discriminator is ``self.dis`` (there is no dis_a).
        torch.save(
            {
                'dis': self.dis.state_dict(),
                'dis_content': self.dis_content.state_dict()
            }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
class UnsupIntrinsicTrainer(nn.Module):
    """Unsupervised intrinsic-image trainer over domains I (input), R (reflectance) and S (shading).

    Each domain has an AdaIN auto-encoder; R and S also have image
    discriminators. With ``with_mapping`` enabled, ``fea_s`` splits an I style
    code into (R, S) style codes and ``fea_m`` merges them back. Otherwise, and
    for all "random" styles, R styles are Gaussian noise shifted by
    ``+bias_shift`` and S styles are noise shifted by ``-bias_shift``.
    """

    def __init__(self, param):
        super(UnsupIntrinsicTrainer, self).__init__()
        lr = param['lr']
        # Initiate the networks. Creation order fixes module registration
        # order, hence state_dict layout and the RNG use of weight init.
        self.gen_i = AdaINGen(param['input_dim_a'], param['input_dim_a'], param['gen'])  # auto-encoder for domain I
        self.gen_r = AdaINGen(param['input_dim_b'], param['input_dim_b'], param['gen'])  # auto-encoder for domain R
        self.gen_s = AdaINGen(param['input_dim_c'], param['input_dim_c'], param['gen'])  # auto-encoder for domain S
        self.dis_r = MsImageDis(param['input_dim_b'], param['dis'])  # discriminator for domain R
        self.dis_s = MsImageDis(param['input_dim_c'], param['dis'])  # discriminator for domain S
        gp = param['gen']

        # Ablation switches; every component is enabled unless disabled here.
        self.with_mapping = True
        self.use_phy_loss = True
        self.use_content_loss = True
        if 'ablation_study' in param:
            ablation = param['ablation_study']
            if 'with_mapping' in ablation:
                self.with_mapping = bool(ablation['with_mapping'] != 0)
            if 'wo_phy_loss' in ablation:
                self.use_phy_loss = bool(ablation['wo_phy_loss'] == 0)
            if 'wo_content_loss' in ablation:
                self.use_content_loss = bool(ablation['wo_content_loss'] == 0)

        if self.with_mapping:
            self.fea_s = IntrinsicSplitor(gp['style_dim'], gp['mlp_dim'], gp['n_layer'], gp['activ'])  # split style for I
            self.fea_m = IntrinsicMerger(gp['style_dim'], gp['mlp_dim'], gp['n_layer'], gp['activ'])  # merge style for R, S
        self.bias_shift = param['bias_shift']
        self.instance_norm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = param['gen']['style_dim']
        # fix the noise used in sampling
        display_size = int(param['display_size'])
        self.s_r = torch.randn(display_size, self.style_dim, 1, 1).cuda() + 1.
        self.s_s = torch.randn(display_size, self.style_dim, 1, 1).cuda() - 1.

        # Setup the optimizers (frozen parameters are skipped).
        beta1 = param['beta1']
        beta2 = param['beta2']
        dis_params = list(self.dis_r.parameters()) + list(self.dis_s.parameters())
        gen_params = list(self.gen_i.parameters()) + list(self.gen_r.parameters()) + \
            list(self.gen_s.parameters())
        if self.with_mapping:
            gen_params += list(self.fea_s.parameters()) + list(self.fea_m.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr, betas=(beta1, beta2), weight_decay=param['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr, betas=(beta1, beta2), weight_decay=param['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, param)
        self.gen_scheduler = get_scheduler(self.gen_opt, param)

        # Network weight initialization
        self.apply(weights_init(param['init']))
        self.dis_r.apply(weights_init('gaussian'))
        self.dis_s.apply(weights_init('gaussian'))
        self.best_result = float('inf')
        self.reflectance_loss = LocalAlbedoSmoothnessLoss(param)

    def recon_criterion(self, input, target):
        """Mean absolute error between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    def physical_criterion(self, x_i, x_r, x_s):
        """Mean absolute error between the input and the product reflectance * shading."""
        return torch.mean(torch.abs(x_i - x_r * x_s))

    def _random_style(self, batch_size, shift):
        """Draw a (batch_size, style_dim, 1, 1) Gaussian style code on the GPU, offset by ``shift``."""
        return torch.randn(batch_size, self.style_dim, 1, 1).cuda() + shift

    def _split_style(self, s_i, batch_size):
        """Return (R style, S style) for an I style: learned split, or shifted noise without mapping."""
        if self.with_mapping:
            s_r, s_s = self.fea_s(s_i)
            return s_r, s_s
        s_r = self._random_style(batch_size, self.bias_shift)
        s_s = self._random_style(batch_size, -self.bias_shift)
        return s_r, s_s

    def forward(self, x_i):
        """Decompose ``x_i`` into (reflectance, shading) images."""
        c_i, s_i_fake = self.gen_i.encode(x_i)
        s_r, s_s = self._split_style(s_i_fake, x_i.size(0))
        x_ri = self.gen_r.decode(c_i, s_r)
        x_si = self.gen_s.decode(c_i, s_s)
        return x_ri, x_si

    def inference(self, x_i, use_rand_fea=False):
        """Gradient-free decomposition of ``x_i``.

        When ``use_rand_fea`` is set, random shifted styles replace the
        split styles.
        """
        with torch.no_grad():
            c_i, s_i_fake = self.gen_i.encode(x_i)
            s_r, s_s = self._split_style(s_i_fake, x_i.size(0))
            if use_rand_fea:
                s_r = self._random_style(x_i.size(0), self.bias_shift)
                s_s = self._random_style(x_i.size(0), -self.bias_shift)
            x_ri = self.gen_r.decode(c_i, s_r)
            x_si = self.gen_s.decode(c_i, s_s)
        return x_ri, x_si

    # noinspection PyAttributeOutsideInit
    def gen_update(self, x_i, x_r, x_s, targets=None, param=None):
        """Run one generator step; individual losses are stored on ``self``.

        Args:
            x_i, x_r, x_s: batches from the I, R and S domains.
            targets: optional input for the reflectance-smoothness loss; when
                None those terms are 0.
            param: dict of loss weights (``gan_w``, ``recon_x_w``, ``recon_s_w``,
                ``recon_c_w``, ``recon_x_cyc_w``, ``phy_x_w``, ``refl_smooth_w``).
        """
        self.gen_opt.zero_grad()
        # ============= Domain Translations =============
        # encode
        c_i, s_i_prime = self.gen_i.encode(x_i)
        c_r, s_r_prime = self.gen_r.encode(x_r)
        c_s, s_s_prime = self.gen_s.encode(x_s)
        s_ri, s_si = self._split_style(s_i_prime, x_i.size(0))
        s_r_rand = self._random_style(x_r.size(0), self.bias_shift)
        s_s_rand = self._random_style(x_s.size(0), -self.bias_shift)
        # Merging the split styles should give back the I style (mapping only).
        s_i_recon = self.fea_m(s_ri, s_si) if self.with_mapping else None

        # decode (within domain): every domain goes through its own decoder
        x_i_recon = self.gen_i.decode(c_i, s_i_prime)
        x_r_recon = self.gen_r.decode(c_r, s_r_prime)
        x_s_recon = self.gen_s.decode(c_s, s_s_prime)
        # decode (cross domain); R decoders take R-domain styles, S decoders S-domain styles
        x_rs = self.gen_r.decode(c_s, s_r_rand)
        x_ri = self.gen_r.decode(c_i, s_ri)
        x_ri_rand = self.gen_r.decode(c_i, s_r_rand)
        x_sr = self.gen_s.decode(c_r, s_s_rand)
        x_si = self.gen_s.decode(c_i, s_si)
        x_si_rand = self.gen_s.decode(c_i, s_s_rand)
        # encode again, for feature domain consistency constraints
        c_rs_recon, s_rs_recon = self.gen_r.encode(x_rs)
        c_ri_recon, s_ri_recon = self.gen_r.encode(x_ri)
        c_ri_rand_recon, s_ri_rand_recon = self.gen_r.encode(x_ri_rand)
        c_sr_recon, s_sr_recon = self.gen_s.encode(x_sr)
        c_si_recon, s_si_recon = self.gen_s.encode(x_si)
        c_si_rand_recon, s_si_rand_recon = self.gen_s.encode(x_si_rand)
        # decode again, for image domain cycle consistency
        x_rsr = self.gen_r.decode(c_sr_recon, s_r_prime)
        x_iri = self.gen_i.decode(c_ri_recon, s_i_prime)
        x_iri_rand = self.gen_i.decode(c_ri_rand_recon, s_i_prime)
        x_srs = self.gen_s.decode(c_rs_recon, s_s_prime)
        x_isi = self.gen_i.decode(c_si_recon, s_i_prime)
        x_isi_rand = self.gen_i.decode(c_si_rand_recon, s_i_prime)

        # ============= Loss Functions =============
        # Encoder decoder reconstruction loss for three domain
        self.loss_gen_recon_x_i = self.recon_criterion(x_i_recon, x_i)
        self.loss_gen_recon_x_r = self.recon_criterion(x_r_recon, x_r)
        self.loss_gen_recon_x_s = self.recon_criterion(x_s_recon, x_s)
        # Style-level reconstruction loss for cross domain
        if self.with_mapping:
            self.loss_gen_recon_s_ii = self.recon_criterion(s_i_recon, s_i_prime)
        else:
            self.loss_gen_recon_s_ii = 0
        self.loss_gen_recon_s_ri = self.recon_criterion(s_ri_recon, s_ri)
        # NOTE(review): the *_rand targets below are s_ri / s_si, not the random
        # styles (s_r_rand / s_s_rand) those images were decoded with; kept
        # as-is — confirm intent.
        self.loss_gen_recon_s_ri_rand = self.recon_criterion(s_ri_rand_recon, s_ri)
        self.loss_gen_recon_s_rs = self.recon_criterion(s_rs_recon, s_r_rand)
        self.loss_gen_recon_s_sr = self.recon_criterion(s_sr_recon, s_s_rand)
        self.loss_gen_recon_s_si = self.recon_criterion(s_si_recon, s_si)
        self.loss_gen_recon_s_si_rand = self.recon_criterion(s_si_rand_recon, s_si)
        # Content-level reconstruction loss for cross domain (I-content terms
        # can be ablated away)
        use_c = self.use_content_loss is True
        self.loss_gen_recon_c_rs = self.recon_criterion(c_rs_recon, c_s)
        self.loss_gen_recon_c_ri = self.recon_criterion(c_ri_recon, c_i) if use_c else 0
        self.loss_gen_recon_c_ri_rand = self.recon_criterion(c_ri_rand_recon, c_i) if use_c else 0
        self.loss_gen_recon_c_sr = self.recon_criterion(c_sr_recon, c_r)
        self.loss_gen_recon_c_si = self.recon_criterion(c_si_recon, c_i) if use_c else 0
        self.loss_gen_recon_c_si_rand = self.recon_criterion(c_si_rand_recon, c_i) if use_c else 0
        # Cycle consistency loss for three image domain
        self.loss_gen_cyc_recon_x_rs = self.recon_criterion(x_rsr, x_r)
        self.loss_gen_cyc_recon_x_ir = self.recon_criterion(x_iri, x_i)
        self.loss_gen_cyc_recon_x_ir_rand = self.recon_criterion(x_iri_rand, x_i)
        self.loss_gen_cyc_recon_x_sr = self.recon_criterion(x_srs, x_s)
        self.loss_gen_cyc_recon_x_is = self.recon_criterion(x_isi, x_i)
        self.loss_gen_cyc_recon_x_is_rand = self.recon_criterion(x_isi_rand, x_i)
        # GAN loss
        self.loss_gen_adv_rs = self.dis_r.calc_gen_loss(x_rs)
        self.loss_gen_adv_ri = self.dis_r.calc_gen_loss(x_ri)
        self.loss_gen_adv_ri_rand = self.dis_r.calc_gen_loss(x_ri_rand)
        self.loss_gen_adv_sr = self.dis_s.calc_gen_loss(x_sr)
        self.loss_gen_adv_si = self.dis_s.calc_gen_loss(x_si)
        self.loss_gen_adv_si_rand = self.dis_s.calc_gen_loss(x_si_rand)
        # Physical loss: I should equal R * S
        use_phy = self.use_phy_loss is True
        self.loss_gen_phy_i = self.physical_criterion(x_i, x_ri, x_si) if use_phy else 0
        self.loss_gen_phy_i_rand = self.physical_criterion(x_i, x_ri_rand, x_si_rand) if use_phy else 0
        # Reflectance smoothness loss
        self.loss_refl_ri = self.reflectance_loss(x_ri, targets) if targets is not None else 0
        self.loss_refl_ri_rand = self.reflectance_loss(x_ri_rand, targets) if targets is not None else 0

        # total loss, grouped by weight
        adv = (self.loss_gen_adv_rs + self.loss_gen_adv_ri + self.loss_gen_adv_ri_rand
               + self.loss_gen_adv_sr + self.loss_gen_adv_si + self.loss_gen_adv_si_rand)
        recon_x = self.loss_gen_recon_x_i + self.loss_gen_recon_x_r + self.loss_gen_recon_x_s
        recon_s = (self.loss_gen_recon_s_ii + self.loss_gen_recon_s_ri
                   + self.loss_gen_recon_s_ri_rand + self.loss_gen_recon_s_rs
                   + self.loss_gen_recon_s_si + self.loss_gen_recon_s_si_rand
                   + self.loss_gen_recon_s_sr)
        recon_c = (self.loss_gen_recon_c_ri + self.loss_gen_recon_c_rs
                   + self.loss_gen_recon_c_ri_rand + self.loss_gen_recon_c_si
                   + self.loss_gen_recon_c_sr + self.loss_gen_recon_c_si_rand)
        cyc = (self.loss_gen_cyc_recon_x_ir + self.loss_gen_cyc_recon_x_ir_rand
               + self.loss_gen_cyc_recon_x_is + self.loss_gen_cyc_recon_x_is_rand
               + self.loss_gen_cyc_recon_x_rs + self.loss_gen_cyc_recon_x_sr)
        phy = self.loss_gen_phy_i + self.loss_gen_phy_i_rand
        refl = self.loss_refl_ri + self.loss_refl_ri_rand
        self.loss_gen_total = (param['gan_w'] * adv
                               + param['recon_x_w'] * recon_x
                               + param['recon_s_w'] * recon_s
                               + param['recon_c_w'] * recon_c
                               + param['recon_x_cyc_w'] * cyc
                               + param['phy_x_w'] * phy
                               + param['refl_smooth_w'] * refl)
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def sample(self, x_i, x_r, x_s):
        """Produce per-sample reconstructions and translations for visualisation.

        Uses the fixed sampling styles ``self.s_r`` / ``self.s_s`` for R<-S and
        S<-R translations.
        """
        self.eval()
        s_r = self.s_r
        s_s = self.s_s
        x_i_recon, x_r_recon, x_s_recon, x_rs, x_ri, x_sr, x_si = [], [], [], [], [], [], []
        for i in range(x_i.size(0)):
            c_i, s_i_fake = self.gen_i.encode(x_i[i].unsqueeze(0))
            c_r, s_r_fake = self.gen_r.encode(x_r[i].unsqueeze(0))
            c_s, s_s_fake = self.gen_s.encode(x_s[i].unsqueeze(0))
            s_ri, s_si = self._split_style(s_i_fake, 1)
            x_i_recon.append(self.gen_i.decode(c_i, s_i_fake))
            x_r_recon.append(self.gen_r.decode(c_r, s_r_fake))
            x_s_recon.append(self.gen_s.decode(c_s, s_s_fake))
            x_rs.append(self.gen_r.decode(c_s, s_r[i].unsqueeze(0)))
            # NOTE(review): s_ri / s_si presumably already have batch dim 1 here,
            # so unsqueeze(0) adds an extra leading dim; kept as-is since it
            # depends on how AdaINGen.decode reshapes the style — confirm.
            x_ri.append(self.gen_r.decode(c_i, s_ri.unsqueeze(0)))
            # S translated from R content, matching x_sr in gen_update/dis_update.
            x_sr.append(self.gen_s.decode(c_r, s_s[i].unsqueeze(0)))
            x_si.append(self.gen_s.decode(c_i, s_si.unsqueeze(0)))
        x_i_recon, x_r_recon, x_s_recon = torch.cat(x_i_recon), torch.cat(x_r_recon), torch.cat(x_s_recon)
        x_rs, x_ri = torch.cat(x_rs), torch.cat(x_ri)
        x_sr, x_si = torch.cat(x_sr), torch.cat(x_si)
        self.train()
        return x_i, x_i_recon, x_r, x_r_recon, x_rs, x_ri, x_s, x_s_recon, x_sr, x_si

    # noinspection PyAttributeOutsideInit
    def dis_update(self, x_i, x_r, x_s, params):
        """Run one discriminator step on detached translations into R and S."""
        self.dis_opt.zero_grad()
        # Random styles use the same sign convention as gen_update:
        # R styles are shifted by +bias_shift, S styles by -bias_shift.
        s_r = self._random_style(x_r.size(0), self.bias_shift)
        s_s = self._random_style(x_s.size(0), -self.bias_shift)
        # encode
        c_r, _ = self.gen_r.encode(x_r)
        c_s, _ = self.gen_s.encode(x_s)
        c_i, s_i = self.gen_i.encode(x_i)
        s_ri, s_si = self._split_style(s_i, x_i.size(0))
        # decode (cross domain)
        x_rs = self.gen_r.decode(c_s, s_r)
        x_ri = self.gen_r.decode(c_i, s_ri)
        x_sr = self.gen_s.decode(c_r, s_s)
        x_si = self.gen_s.decode(c_i, s_si)
        # D loss
        self.loss_dis_rs = self.dis_r.calc_dis_loss(x_rs.detach(), x_r)
        self.loss_dis_ri = self.dis_r.calc_dis_loss(x_ri.detach(), x_r)
        self.loss_dis_sr = self.dis_s.calc_dis_loss(x_sr.detach(), x_s)
        self.loss_dis_si = self.dis_s.calc_dis_loss(x_si.detach(), x_s)
        self.loss_dis_total = params['gan_w'] * self.loss_dis_rs + \
            params['gan_w'] * self.loss_dis_ri + \
            params['gan_w'] * self.loss_dis_sr + \
            params['gan_w'] * self.loss_dis_si
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Step both LR schedulers when present."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, param):
        """Load the newest snapshots (including mapping nets and best_result); return the iteration count."""
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_i.load_state_dict(state_dict['i'])
        self.gen_r.load_state_dict(state_dict['r'])
        self.gen_s.load_state_dict(state_dict['s'])
        if self.with_mapping:
            self.fea_m.load_state_dict(state_dict['fm'])
            self.fea_s.load_state_dict(state_dict['fs'])
        self.best_result = state_dict['best_result']
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_r.load_state_dict(state_dict['r'])
        self.dis_s.load_state_dict(state_dict['s'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitilize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, param, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, param, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Write gen/dis snapshots tagged ``iterations + 1`` and overwrite ``optimizer.pt``."""
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        gen_state = {
            'i': self.gen_i.state_dict(),
            'r': self.gen_r.state_dict(),
            's': self.gen_s.state_dict(),
        }
        if self.with_mapping:
            gen_state['fs'] = self.fea_s.state_dict()
            gen_state['fm'] = self.fea_m.state_dict()
        gen_state['best_result'] = self.best_result
        torch.save(gen_state, gen_name)
        torch.save({
            'r': self.dis_r.state_dict(),
            's': self.dis_s.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
class MUNIT_Trainer(nn.Module):
    """Trainer for a mask-conditioned, MUNIT-style image translation model.

    A single ``SpadeGen`` generator (``self.gen``) serves both domains; the
    integer passed to ``encode``/``decode`` selects the domain (1 for domain A,
    2 for domain B). Each domain has a full-image discriminator and a
    masked-image discriminator. Optional components are created according to
    ``hyperparameters``:

    * ``domain_adv_w > 0``: content-space domain classifier ``domain_classifier_ab``
    * ``adaptation.dfeat_lambda > 0``: synthetic/real content classifiers
    * ``adaptation.output_classifier_lambda > 0``: synthetic/real output classifiers
    * ``adaptation.sem_seg_lambda > 0``: segmentation head trained on content codes
    * ``semantic_w > 0``: frozen segmentation model used by the semantic loss
    * ``vgg_w > 0``: VGG perceptual loss
    """

    def __init__(self, hyperparameters):
        """Build networks, optimizers and learning-rate schedulers.

        Arguments:
            hyperparameters {dict} -- configuration dictionary (keys read here
                include "lr", "beta1", "beta2", "weight_decay", "init", "gen",
                "dis", "adaptation", "semantic_w", "vgg_w", "recon_mask",
                "crop_image_height" and the optional "domain_adv_w").
        """
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters["lr"]
        adaptation = hyperparameters["adaptation"]
        self.newsize = hyperparameters["crop_image_height"]
        self.semantic_w = hyperparameters["semantic_w"] > 0
        self.recon_mask = hyperparameters["recon_mask"] == 1
        self.dann_scheduler = None
        self.full_adaptation = adaptation["full_adaptation"] == 1
        dim = hyperparameters["gen"]["dim"]
        n_downsample = hyperparameters["gen"]["n_downsample"]
        # Channel count of the content code (presumably matches the encoder
        # output: `dim` doubled once per downsampling step).
        latent_dim = dim * (2**n_downsample)

        # "domain_adv_w" is optional in the configuration.
        self.domain_classif_ab = hyperparameters.get("domain_adv_w", 0) > 0
        self.use_classifier_sr = adaptation["dfeat_lambda"] > 0
        self.train_seg = adaptation["sem_seg_lambda"] > 0
        self.use_output_classifier_sr = adaptation["output_classifier_lambda"] > 0

        self.gen = SpadeGen(hyperparameters["input_dim_a"], hyperparameters["gen"])

        # One full-image and one masked-image discriminator per domain.
        # Construction order is kept stable so random initialization is reproducible.
        if hyperparameters["dis"]["type"] == "patchgan":
            print("Using patchgan discrminator...")
            dis_cls = MultiscaleDiscriminator
        else:
            dis_cls = MsImageDis
        self.dis_a = dis_cls(hyperparameters["input_dim_a"], hyperparameters["dis"])
        self.dis_b = dis_cls(hyperparameters["input_dim_b"], hyperparameters["dis"])
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.dis_a_masked = dis_cls(hyperparameters["input_dim_a"], hyperparameters["dis"])
        self.dis_b_masked = dis_cls(hyperparameters["input_dim_b"], hyperparameters["dis"])

        # Setup the optimizers
        betas = (hyperparameters["beta1"], hyperparameters["beta2"])
        weight_decay = hyperparameters["weight_decay"]
        dis_params = (list(self.dis_a.parameters()) +
                      list(self.dis_b.parameters()) +
                      list(self.dis_a_masked.parameters()) +
                      list(self.dis_b_masked.parameters()))
        self.dis_opt = self._make_adam(dis_params, lr, betas, weight_decay)
        self.gen_opt = self._make_adam(self.gen.parameters(), lr, betas, weight_decay)
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization. Modules created below this call
        # (e.g. the segmentation model loaded from a checkpoint) are not
        # re-initialized by it.
        self.apply(weights_init(hyperparameters["init"]))
        for dis in (self.dis_a, self.dis_b, self.dis_a_masked, self.dis_b_masked):
            dis.apply(weights_init("gaussian"))

        # Load VGG model if needed
        if hyperparameters["vgg_w"] > 0:
            self.criterionVGG = VGGLoss()

        # Frozen semantic segmentation model for the semantic loss
        if self.semantic_w:
            self.segmentation_model = load_segmentation_model(
                hyperparameters["semantic_ckpt_path"], 19)
            self.segmentation_model.eval()
            for param in self.segmentation_model.parameters():
                param.requires_grad = False

        # Domain classifier on content codes (domain A vs domain B)
        if self.domain_classif_ab:
            self.domain_classifier_ab = domainClassifier(input_dim=latent_dim, dim=256)
            self.dann_opt = self._make_adam(self.domain_classifier_ab.parameters(),
                                            lr, betas, weight_decay)
            self.domain_classifier_ab.apply(weights_init("gaussian"))
            self.dann_scheduler = get_scheduler(self.dann_opt, hyperparameters)

        # Classifiers on content codes for the synthetic/real adaptation
        if self.use_classifier_sr:
            #! Hardcoded
            self.domain_classifier_sr_b = domainClassifier(input_dim=latent_dim, dim=256)
            self.domain_classifier_sr_a = domainClassifier(input_dim=latent_dim, dim=256)
            classif_params = (list(self.domain_classifier_sr_a.parameters()) +
                              list(self.domain_classifier_sr_b.parameters()))
            self.classif_opt_sr = self._make_adam(classif_params, lr, betas, weight_decay)
            self.domain_classifier_sr_a.apply(weights_init("gaussian"))
            self.domain_classifier_sr_b.apply(weights_init("gaussian"))
            self.classif_sr_scheduler = get_scheduler(self.classif_opt_sr, hyperparameters)

        # Classifiers on translated images for the synthetic/real adaptation
        if self.use_output_classifier_sr:
            # discriminators for domain a,sr (on x_ba) and domain b,sr (on x_ab)
            self.output_classifier_sr_a = dis_cls(hyperparameters["input_dim_a"],
                                                  hyperparameters["dis"])
            self.output_classifier_sr_b = dis_cls(hyperparameters["input_dim_b"],
                                                  hyperparameters["dis"])
            output_params = (list(self.output_classifier_sr_a.parameters()) +
                             list(self.output_classifier_sr_b.parameters()))
            self.output_classif_opt_sr = self._make_adam(output_params, lr, betas,
                                                         weight_decay)
            self.output_classifier_sr_b.apply(weights_init("gaussian"))
            self.output_classifier_sr_a.apply(weights_init("gaussian"))
            self.output_scheduler_sr = get_scheduler(self.output_classif_opt_sr,
                                                     hyperparameters)

        # Segmentation head on content codes: reuses part of a pretrained
        # backbone and adds a 10-class 1x1 convolution.
        if self.train_seg:
            pretrained = load_segmentation_model(hyperparameters["semantic_ckpt_path"], 19)
            last_layer = nn.Conv2d(512, 10, kernel_size=1)
            self.segmentation_head = torch.nn.Sequential(
                *list(pretrained.resnet34_8s.children())[7:-1], last_layer.cuda())
            for param in self.segmentation_head.parameters():
                param.requires_grad = True
            self.segmentation_opt = self._make_adam(self.segmentation_head.parameters(),
                                                    lr, betas, weight_decay)
            self.scheduler_seg = get_scheduler(self.segmentation_opt, hyperparameters)

    @staticmethod
    def _make_adam(params, lr, betas, weight_decay):
        """Return an Adam optimizer over the trainable subset of ``params``."""
        return torch.optim.Adam([p for p in params if p.requires_grad],
                                lr=lr,
                                betas=betas,
                                weight_decay=weight_decay)

    @staticmethod
    def _as_metric(value):
        """Detach a loss for logging; plain numbers (disabled losses) pass through."""
        return value.cpu().detach() if torch.is_tensor(value) else value

    def recon_criterion(self, input, target):
        """Compute the pixelwise L1 loss between two images.

        Arguments:
            input {torch.Tensor} -- Image tensor
            target {torch.Tensor} -- Image tensor

        Returns:
            torch.Float -- mean absolute difference
        """
        return torch.mean(torch.abs(input - target))

    def recon_criterion_mask(self, input, target, mask):
        """L1 loss restricted to the region where ``mask`` is 0.

        Arguments:
            input {torch.Tensor} -- Image (e.g. the original x_a)
            target {torch.Tensor} -- Image (e.g. the cycle-translated x_aba)
            mask {torch.Tensor} -- mask broadcastable to the image shape;
                pixels where it is 1 are excluded

        Returns:
            torch.Float -- mean of |(input - target) * (1 - mask)|
        """
        return torch.mean(torch.abs(torch.mul((input - target), 1 - mask)))

    def forward(self, x_a, x_b, m_a, m_b):
        """Translate x_a to domain B and x_b to domain A.

        Arguments:
            x_a {torch.Tensor} -- Image batch from domain A
            x_b {torch.Tensor} -- Image batch from domain B
            m_a {torch.Tensor} -- mask associated with x_a
            m_b {torch.Tensor} -- mask associated with x_b

        Returns:
            torch.Tensor, torch.Tensor -- x_ab (x_a in domain B), x_ba (x_b in domain A)
        """
        self.eval()
        c_a = self.gen.encode(x_a, 1)
        c_b = self.gen.encode(x_b, 2)
        # Decode with the source image's mask, like every other decode call.
        x_ba = self.gen.decode(c_b, m_b, 1)
        x_ab = self.gen.decode(c_a, m_a, 2)
        self.train()
        return x_ab, x_ba

    def gen_update(
            self,
            x_a,
            x_b,
            hyperparameters,
            mask_a,
            mask_b,
            comet_exp=None,
            synth=False,
            semantic_gt_a=None,
            semantic_gt_b=None,
    ):
        """Perform one optimization step of the generator.

        Arguments:
            x_a {torch.Tensor} -- Image batch from domain A
            x_b {torch.Tensor} -- Image batch from domain B
            hyperparameters {dict} -- dictionary with all hyperparameters
            mask_a {torch.Tensor} -- binary mask (0,1) of the ground in x_a
            mask_b {torch.Tensor} -- binary mask (0,1) of the water in x_b

        Keyword Arguments:
            comet_exp {cometExperience} -- CometML object used to log losses (default: {None})
            synth {bool} -- whether (x_a, x_b) is a synthetic aligned pair (default: {False})
            semantic_gt_a {torch.Tensor} -- optional segmentation ground truth for x_a
            semantic_gt_b {torch.Tensor} -- optional segmentation ground truth for x_b

        Returns:
            None -- the individual losses are stored as attributes of self.
        """
        self.gen_opt.zero_grad()
        cyc_w = hyperparameters["recon_x_cyc_w"]
        domain_adv_w = hyperparameters.get("domain_adv_w", 0)

        # encode
        c_a = self.gen.encode(x_a, 1)
        c_b = self.gen.encode(x_b, 2)
        # decode (within domain)
        x_a_recon = self.gen.decode(c_a, mask_a, 1)
        x_b_recon = self.gen.decode(c_b, mask_b, 2)
        # decode (cross domain)
        x_ba = self.gen.decode(c_b, mask_b, 1)
        x_ab = self.gen.decode(c_a, mask_a, 2)
        # encode again
        c_b_recon = self.gen.encode(x_ba, 1)
        c_a_recon = self.gen.encode(x_ab, 2)
        # decode again (if needed)
        x_aba = self.gen.decode(c_a_recon, mask_a, 1) if cyc_w > 0 else None
        x_bab = self.gen.decode(c_b_recon, mask_b, 2) if cyc_w > 0 else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        # Context-preserving loss: translation must not change the unmasked region
        self.context_loss = (self.recon_criterion_mask(x_ab, x_a, mask_a) +
                             self.recon_criterion_mask(x_ba, x_b, mask_b))
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)

        # Synthetic reconstruction loss. Always (re)assigned so the total loss
        # never reads a missing attribute or a stale graph from a previous step.
        if synth:
            # 1 where the pair is pixel-identical across all channels
            mask_alignment = (torch.sum(torch.abs(x_a - x_b), 1) == 0).unsqueeze(1).float()
            self.loss_gen_recon_synth = (
                self.recon_criterion_mask(x_ab, x_b, 1 - mask_alignment) +
                self.recon_criterion_mask(x_ba, x_a, 1 - mask_alignment))
        else:
            self.loss_gen_recon_synth = 0

        # cycle reconstruction loss
        if cyc_w > 0:
            if self.recon_mask:
                self.loss_gen_cycrecon_x_a = self.recon_criterion_mask(x_aba, x_a, mask_a)
                self.loss_gen_cycrecon_x_b = self.recon_criterion_mask(x_bab, x_b, mask_b)
            else:
                self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a)
                self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b)
        else:
            self.loss_gen_cycrecon_x_a = 0
            self.loss_gen_cycrecon_x_b = 0

        # GAN loss: full-image and masked-image discriminators
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba, x_a, comet_exp, mode="a")
        self.loss_gen_adv_a += self.dis_a_masked.calc_gen_loss(x_ba * mask_b,
                                                               x_a * mask_a,
                                                               comet_exp,
                                                               mode="a")
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab, x_b, comet_exp, mode="b")
        self.loss_gen_adv_b += self.dis_b_masked.calc_gen_loss(x_ab * mask_a,
                                                               x_b * mask_b,
                                                               comet_exp,
                                                               mode="b")

        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = (self.compute_vgg_loss(x_ba, x_b, mask_b)
                               if hyperparameters["vgg_w"] > 0 else 0)
        self.loss_gen_vgg_b = (self.compute_vgg_loss(x_ab, x_a, mask_a)
                               if hyperparameters["vgg_w"] > 0 else 0)

        # semantic-segmentation loss
        self.loss_sem_seg = (
            self.compute_semantic_seg_loss(x_a, x_ab, mask_a, semantic_gt_a) +
            self.compute_semantic_seg_loss(x_b, x_ba, mask_b, semantic_gt_b)
            if hyperparameters["semantic_w"] > 0 else 0)

        # Domain adversarial loss: the generator pushes the content classifier
        # toward an uninformative output (minimize=False targets 0.5).
        self.domain_adv_loss = (self.compute_domain_adv_loss(
            c_a, c_b, compute_accuracy=False, minimize=False)
                                if domain_adv_w > 0 else 0)
        self.loss_classifier_sr = (self.compute_classifier_sr_loss(
            c_a, c_b, domain_synth=synth, fool=True)
                                   if hyperparameters["adaptation"]["adv_lambda"] > 0 else 0)
        if hyperparameters["adaptation"]["output_adv_lambda"] > 0:
            self.loss_output_classifier_sr = (
                self.output_classifier_sr_a.calc_gen_loss_sr(x_ba) +
                self.output_classifier_sr_b.calc_gen_loss_sr(x_ab))
        else:
            self.loss_output_classifier_sr = 0

        # total loss
        self.loss_gen_total = (
            hyperparameters["gan_w"] * self.loss_gen_adv_a +
            hyperparameters["gan_w"] * self.loss_gen_adv_b +
            hyperparameters["recon_x_w"] * self.loss_gen_recon_x_a +
            hyperparameters["recon_c_w"] * self.loss_gen_recon_c_a +
            hyperparameters["recon_x_w"] * self.loss_gen_recon_x_b +
            hyperparameters["recon_c_w"] * self.loss_gen_recon_c_b +
            cyc_w * self.loss_gen_cycrecon_x_a +
            cyc_w * self.loss_gen_cycrecon_x_b +
            hyperparameters["vgg_w"] * self.loss_gen_vgg_a +
            hyperparameters["vgg_w"] * self.loss_gen_vgg_b +
            hyperparameters["context_w"] * self.context_loss +
            hyperparameters["semantic_w"] * self.loss_sem_seg +
            domain_adv_w * self.domain_adv_loss +
            hyperparameters["recon_synth_w"] * self.loss_gen_recon_synth +
            hyperparameters["adaptation"]["adv_lambda"] * self.loss_classifier_sr +
            hyperparameters["adaptation"]["output_adv_lambda"] *
            self.loss_output_classifier_sr)

        self.loss_gen_total.backward()
        self.gen_opt.step()

        if comet_exp is None:
            return
        metric = self._as_metric
        comet_exp.log_metric("loss_gen_adv_a", metric(self.loss_gen_adv_a))
        comet_exp.log_metric("loss_gen_adv_b", metric(self.loss_gen_adv_b))
        comet_exp.log_metric("loss_gen_recon_x_a", metric(self.loss_gen_recon_x_a))
        comet_exp.log_metric("loss_gen_recon_x_b", metric(self.loss_gen_recon_x_b))
        if hyperparameters["recon_c_w"] > 0:
            comet_exp.log_metric("loss_gen_recon_c_a", metric(self.loss_gen_recon_c_a))
            comet_exp.log_metric("loss_gen_recon_c_b", metric(self.loss_gen_recon_c_b))
        if cyc_w > 0:
            comet_exp.log_metric("loss_gen_cycrecon_x_a", metric(self.loss_gen_cycrecon_x_a))
            comet_exp.log_metric("loss_gen_cycrecon_x_b", metric(self.loss_gen_cycrecon_x_b))
        comet_exp.log_metric("loss_gen_total", metric(self.loss_gen_total))
        if hyperparameters["vgg_w"] > 0:
            comet_exp.log_metric("loss_gen_vgg_a", metric(self.loss_gen_vgg_a))
            comet_exp.log_metric("loss_gen_vgg_b", metric(self.loss_gen_vgg_b))
        if hyperparameters["semantic_w"] > 0:
            comet_exp.log_metric("loss_sem_seg", metric(self.loss_sem_seg))
        if hyperparameters["context_w"] > 0:
            comet_exp.log_metric("context_preserve_loss", metric(self.context_loss))
        if domain_adv_w > 0:
            comet_exp.log_metric("domain_adv_loss_gen", metric(self.domain_adv_loss))
        if synth:
            comet_exp.log_metric("loss_gen_recon_synth", metric(self.loss_gen_recon_synth))
        if self.use_classifier_sr:
            comet_exp.log_metric("loss_classifier_adv_sr", metric(self.loss_classifier_sr))
        if self.use_output_classifier_sr:
            comet_exp.log_metric("loss_output_classifier_adv_sr",
                                 metric(self.loss_output_classifier_sr))

    def compute_vgg_loss(self, img, target, mask):
        """Compute the domain-invariant perceptual loss on the unmasked region.

        Arguments:
            img {torch.Tensor} -- translated image
            target {torch.Tensor} -- image it was translated from
            mask {torch.Tensor} -- mask; pixels where it is 1 are zeroed out
                in both inputs after VGG preprocessing

        Returns:
            torch.Float -- output of ``self.criterionVGG`` on the masked inputs
        """
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        # Mask input to VGG:
        img_vgg = img_vgg * (1.0 - mask)
        target_vgg = target_vgg * (1.0 - mask)
        return self.criterionVGG(img_vgg, target_vgg)

    def compute_classifier_sr_loss(self, c_a, c_b, domain_synth=False, fool=False):
        """Squared-error loss of the synthetic/real content classifiers.

        Arguments:
            c_a {torch.Tensor} -- content of x_a
            c_b {torch.Tensor} -- content of x_b
            domain_synth {bool} -- True if the contents come from synthetic
                images (target 0), False for real images (target 1)
            fool {bool} -- if True, target 0.5 for both (generator objective)

        Returns:
            torch.Float -- mean((out_a - t)^2) + mean((out_b - t)^2)
        """
        # Infer domain classifiers on the contents of each domain
        output_a = self.domain_classifier_sr_a(c_a)
        output_b = self.domain_classifier_sr_b(c_b)
        if fool:
            target = 0.5
        elif domain_synth:
            target = 0.0
        else:
            target = 1.0
        return torch.mean((output_a - target)**2) + torch.mean((output_b - target)**2)

    def compute_domain_adv_loss(self, c_a, c_b, compute_accuracy=False, minimize=True):
        """Domain classification loss on content codes (A vs B).

        Arguments:
            c_a {torch.tensor} -- content extracted from an image of domain A
            c_b {torch.tensor} -- content extracted from an image of domain B

        Keyword Arguments:
            compute_accuracy {bool} -- also return the classifier outputs (default: {False})
            minimize {bool} -- True: train the classifier to separate domains;
                False: target 0.5 everywhere to anonymize the content (default: {True})

        Returns:
            torch.Float -- MSE loss, optionally followed by output_a[0] and output_b[1]
        """
        output_a = self.domain_classifier_ab(c_a)
        output_b = self.domain_classifier_ab(c_b)
        output = torch.cat((output_a, output_b))
        # NOTE(review): the 4-element target assumes the concatenated output
        # has exactly 4 values (e.g. batch size 1 with 2 scores per sample)
        # -- TODO confirm against domainClassifier.
        if minimize:
            target = torch.tensor([1.0, 0.0, 0.0, 1.0], device=output.device)
        else:
            target = torch.tensor([0.5, 0.5, 0.5, 0.5], device=output.device)
        loss = torch.nn.MSELoss()(output, target)
        if compute_accuracy:
            return loss, output_a[0], output_b[1]
        return loss

    def compute_semantic_seg_loss(self, img1, img2, mask=None, ground_truth=None):
        """Semantic consistency loss between an image and its translation.

        Arguments:
            img1 {torch.Tensor} -- original image (values assumed in [-1, 1])
            img2 {torch.Tensor} -- translated image
            mask {torch.Tensor} -- binary mask of the region to ignore
            ground_truth {torch.Tensor} -- optional label map of size (batch, h, w)

        Returns:
            torch.float -- cross entropy between the segmentation of img2 and
            either the ground truth or the segmentation of img1
        """
        new_class = 19
        # Infer x_ab or x_ba
        output = self.segmentation_model(seg_batch_transform((img2 + 1) / 2.0))
        if ground_truth is not None:
            # Merge classes to match the simulated world's labels (19 -> 10)
            target = ground_truth.type(torch.long).cuda().squeeze(1)
            output = merge_classes(output).cuda()
            new_class = 10
        else:
            # Use the pretrained model's prediction on the original image
            target = self.segmentation_model(seg_batch_transform((img1 + 1) / 2.0)).max(1)[1]

        if self.full_adaptation or mask is None:
            return nn.CrossEntropyLoss()(output, target)

        # Resize the mask to the image size (nearest-neighbour by default)
        mask_resized = torch.nn.functional.interpolate(
            mask, size=(self.newsize, self.newsize)).detach()
        mask_long = mask_resized.long().cuda().squeeze(1)
        # Masked pixels get the extra label `new_class` (not a real class)
        target_with_mask = torch.mul(1 - mask_long, target) + mask_long * new_class
        mask_float = mask_resized.float().cuda()
        # Zero the logits in the masked region and append the mask as the
        # logit of the extra class: masked pixels contribute no gradient.
        output_with_mask = torch.mul(1 - mask_float, output)
        output_with_mask_cat = torch.cat((output_with_mask, mask_float), dim=1)
        return nn.CrossEntropyLoss()(output_with_mask_cat, target_with_mask)

    def _translate_each(self, x_a, x_b, m_a, m_b):
        """Translate a batch one image at a time.

        Returns:
            tuple of torch.Tensor -- (x_a_recon, x_b_recon, x_ba1, x_ba2,
            x_ab1, x_ab2). The "2" variants repeat the same decode call as
            the "1" variants.
        """
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            m_a_i = m_a[i].unsqueeze(0)
            m_b_i = m_b[i].unsqueeze(0)
            c_a = self.gen.encode(x_a[i].unsqueeze(0), 1)
            c_b = self.gen.encode(x_b[i].unsqueeze(0), 2)
            x_a_recon.append(self.gen.decode(c_a, m_a_i, 1))
            x_b_recon.append(self.gen.decode(c_b, m_b_i, 2))
            x_ba1.append(self.gen.decode(c_b, m_b_i, 1))
            x_ba2.append(self.gen.decode(c_b, m_b_i, 1))
            x_ab1.append(self.gen.decode(c_a, m_a_i, 2))
            x_ab2.append(self.gen.decode(c_a, m_a_i, 2))
        return (torch.cat(x_a_recon), torch.cat(x_b_recon), torch.cat(x_ba1),
                torch.cat(x_ba2), torch.cat(x_ab1), torch.cat(x_ab2))

    def _segmentation_previews(self, x_a, x_b, x_ab, x_ba):
        """Colorized segmentation maps for four aligned image batches.

        Each image (values assumed in [-1, 1]) is rescaled to [0, 1], passed
        through ``seg_transform`` and the segmentation model, colorized with
        ``decode_segmap`` and resized to a square of side ``x_a.size(3)``.

        Returns:
            tuple of torch.Tensor -- (rgb_a, rgb_b, rgb_ab, rgb_ba) on the GPU
        """
        side = x_a.size(3)

        def colorize(img, to_cuda):
            inp = seg_transform()((img.squeeze() + 1) / 2.0).unsqueeze(0)
            if to_cuda:
                inp = inp.to("cuda")
            labels = self.segmentation_model(inp).squeeze().max(0)[1]
            rgb = Image.fromarray(decode_segmap(labels.cpu().numpy())).resize((side, side))
            return transforms.ToTensor()(rgb).unsqueeze(0)

        rgb_a, rgb_b, rgb_ab, rgb_ba = [], [], [], []
        for i in range(x_a.size(0)):
            # Original images
            rgb_a.append(colorize(x_a[i], False))
            rgb_b.append(colorize(x_b[i], False))
            # Translated images
            rgb_ab.append(colorize(x_ab[i], True))
            rgb_ba.append(colorize(x_ba[i], True))
        return (torch.cat(rgb_a).cuda(), torch.cat(rgb_b).cuda(),
                torch.cat(rgb_ab).cuda(), torch.cat(rgb_ba).cuda())

    def sample(self, x_a, x_b, m_a, m_b):
        """Infer the model on a batch of images for visualization.

        Arguments:
            x_a {torch.Tensor} -- batch of images from domain A
            x_b {torch.Tensor} -- batch of images from domain B
            m_a {torch.Tensor} -- masks of x_a (single channel)
            m_b {torch.Tensor} -- masks of x_b (single channel)

        Returns:
            tuple of torch.Tensor -- x_a, x_a_recon, [seg(x_a)], x_ab,
            [seg(x_ab)], x_ab * m_a, x_a with mask overlay, then the same
            for domain B. Segmentation entries appear only if self.semantic_w.
        """
        self.eval()
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, _ = self._translate_each(x_a, x_b, m_a, m_b)
        if self.semantic_w:
            rgb1_a, rgb1_b, rgb1_ab, rgb1_ba = self._segmentation_previews(
                x_a, x_b, x_ab1, x_ba1)
        self.train()

        # Overlay mask onto image: pixels where the mask is 1 become 1.
        m_a3 = m_a.repeat(1, 3, 1, 1)
        m_b3 = m_b.repeat(1, 3, 1, 1)
        save_m_a = x_a - (x_a * m_a3) + m_a3
        save_m_b = x_b - (x_b * m_b3) + m_b3

        if self.semantic_w:
            # self.train() above also switched the frozen model to train mode
            self.segmentation_model.eval()
            return (x_a, x_a_recon, rgb1_a, x_ab1, rgb1_ab, x_ab1 * m_a, save_m_a,
                    x_b, x_b_recon, rgb1_b, x_ba1, rgb1_ba, x_ba2 * m_b, save_m_b)
        return (x_a, x_a_recon, x_ab1, x_ab1 * m_a, save_m_a,
                x_b, x_b_recon, x_ba1, x_ba2 * m_b, save_m_b)

    def sample_syn(self, x_a, x_b, m_a, m_b):
        """Infer the model on a batch of (synthetic) images for visualization.

        Arguments:
            x_a {torch.Tensor} -- batch of images from domain A
            x_b {torch.Tensor} -- batch of images from domain B
            m_a {torch.Tensor} -- masks of x_a
            m_b {torch.Tensor} -- masks of x_b

        Returns:
            tuple of torch.Tensor -- x_a, x_a_recon, [seg(x_a)], x_ab1,
            [seg(x_ab1)], x_ab2, then the same for domain B. Segmentation
            entries appear only if self.semantic_w.
        """
        self.eval()
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = self._translate_each(
            x_a, x_b, m_a, m_b)
        if self.semantic_w:
            rgb1_a, rgb1_b, rgb1_ab, rgb1_ba = self._segmentation_previews(
                x_a, x_b, x_ab1, x_ba1)
        self.train()
        if self.semantic_w:
            self.segmentation_model.eval()
            return (x_a, x_a_recon, rgb1_a, x_ab1, rgb1_ab, x_ab2,
                    x_b, x_b_recon, rgb1_b, x_ba1, rgb1_ba, x_ba2)
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_update(self, x_a, x_b, m_a, m_b, hyperparameters, comet_exp=None):
        """Perform one optimization step of the discriminators.

        Arguments:
            x_a {torch.Tensor} -- Image batch from domain A
            x_b {torch.Tensor} -- Image batch from domain B
            m_a {torch.Tensor} -- masks of x_a
            m_b {torch.Tensor} -- masks of x_b
            hyperparameters {dict} -- dictionary with all hyperparameters

        Keyword Arguments:
            comet_exp {cometExperience} -- CometML object used to log losses (default: {None})
        """
        self.dis_opt.zero_grad()
        # encode
        c_a = self.gen.encode(x_a, 1)
        c_b = self.gen.encode(x_b, 2)
        # decode (cross domain)
        x_ba = self.gen.decode(c_b, m_b, 1)
        x_ab = self.gen.decode(c_a, m_a, 2)

        # D loss. Every fake input is detached so no gradient reaches the generator.
        self.loss_dis_a = (
            self.dis_a.calc_dis_loss(x_ba.detach(), x_a, comet_exp, mode="a") +
            self.dis_a_masked.calc_dis_loss((x_ba * m_b).detach(), x_a * m_a,
                                            comet_exp, mode="a"))
        self.loss_dis_b = (
            self.dis_b.calc_dis_loss(x_ab.detach(), x_b, comet_exp, mode="b") +
            self.dis_b_masked.calc_dis_loss((x_ab * m_a).detach(), x_b * m_b,
                                            comet_exp, mode="b"))
        self.loss_dis_total = (hyperparameters["gan_w"] * self.loss_dis_a +
                               hyperparameters["gan_w"] * self.loss_dis_b)
        self.loss_dis_total.backward()
        self.dis_opt.step()
        if comet_exp is not None:
            comet_exp.log_metric("loss_dis_b", self.loss_dis_b.cpu().detach())
            comet_exp.log_metric("loss_dis_a", self.loss_dis_a.cpu().detach())

    def domain_classifier_update(self, x_a, x_b, hyperparameters, comet_exp=None):
        """Perform one optimization step of the A/B content domain classifier.

        Arguments:
            x_a {torch.Tensor} -- Image batch from domain A
            x_b {torch.Tensor} -- Image batch from domain B
            hyperparameters {dict} -- dictionary with all hyperparameters

        Keyword Arguments:
            comet_exp {cometExperience} -- CometML object used to log losses (default: {None})
        """
        self.dann_opt.zero_grad()
        c_a = self.gen.encode(x_a, 1)
        c_b = self.gen.encode(x_b, 2)
        self.domain_class_loss, out_a, out_b = self.compute_domain_adv_loss(
            c_a, c_b, compute_accuracy=True, minimize=True)
        self.domain_class_loss.backward()
        self.dann_opt.step()
        if comet_exp is not None:
            comet_exp.log_metric("domain_class_loss", self.domain_class_loss.cpu().detach())
            comet_exp.log_metric("probability A being identified as A", out_a.cpu().detach())
            comet_exp.log_metric("probability B being identified as B", out_b.cpu().detach())

    def domain_classifier_sr_update(self,
                                    x_a,
                                    x_b,
                                    m_a,
                                    m_b,
                                    domain_synth,
                                    lambda_classifier,
                                    step,
                                    comet_exp=None):
        """Perform one step of the synthetic/real content classifiers.

        The content codes are detached, so only the classifiers are updated.
        ``m_a`` and ``m_b`` are accepted for interface compatibility but unused.
        """
        self.classif_opt_sr.zero_grad()
        c_a = self.gen.encode(x_a, 1)
        c_b = self.gen.encode(x_b, 2)
        loss = self.compute_classifier_sr_loss(c_a.detach(), c_b.detach(), domain_synth,
                                               fool=False)
        loss = lambda_classifier * loss
        loss.backward()
        self.classif_opt_sr.step()
        if comet_exp is not None:
            comet_exp.log_metric("loss_classifier_sr", loss.cpu().detach(), step=step)

    def output_domain_classifier_sr_update(self,
                                           x_ar,
                                           x_as,
                                           x_br,
                                           x_bs,
                                           hyperparameters,
                                           step,
                                           comet_exp=None):
        """Perform one step of the synthetic/real output classifiers."""
        self.output_classif_opt_sr.zero_grad()
        loss = (self.output_classifier_sr_b.calc_dis_loss_sr(x_bs, x_br) +
                self.output_classifier_sr_a.calc_dis_loss_sr(x_as, x_ar))
        loss = hyperparameters["adaptation"]["output_classifier_lambda"] * loss
        loss.backward()
        self.output_classif_opt_sr.step()
        if comet_exp is not None:
            comet_exp.log_metric("loss_output_classifier_sr", loss.cpu().detach(), step=step)

    def segmentation_head_update(self, x_a, x_b, target_a, target_b, lamb, comet_exp=None):
        """Perform one step of the segmentation head trained on content codes.

        Logits are bilinearly upsampled to (newsize, newsize) and compared
        with the label maps ``target_a``/``target_b`` via cross entropy.
        """
        self.segmentation_opt.zero_grad()
        c_a = self.gen.encode(x_a, 1)
        c_b = self.gen.encode(x_b, 2)
        size = (self.newsize, self.newsize)
        # align_corners=False is PyTorch's default; stated explicitly to avoid the warning.
        output_a = nn.functional.interpolate(input=self.segmentation_head(c_a), size=size,
                                             mode="bilinear", align_corners=False)
        output_b = nn.functional.interpolate(input=self.segmentation_head(c_b), size=size,
                                             mode="bilinear", align_corners=False)
        loss1 = nn.CrossEntropyLoss()(output_a, target_a.type(torch.long).squeeze(1).cuda())
        loss2 = nn.CrossEntropyLoss()(output_b, target_b.type(torch.long).squeeze(1).cuda())
        loss = (loss1 + loss2) * lamb
        loss.backward()
        self.segmentation_opt.step()
        if comet_exp is not None:
            comet_exp.log_metric("loss_semantic_head", loss.cpu().detach())

    def update_learning_rate(self):
        """Step every learning-rate scheduler that exists on this trainer.

        Besides the generator/discriminator/DANN schedulers, this also steps
        the schedulers of the optional synthetic/real classifiers and of the
        segmentation head, which are created in __init__ when enabled.
        """
        for name in ("dis_scheduler", "gen_scheduler", "dann_scheduler",
                     "classif_sr_scheduler", "output_scheduler_sr", "scheduler_seg"):
            scheduler = getattr(self, name, None)
            if scheduler is not None:
                scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Resume training by loading the latest checkpoints.

        Arguments:
            checkpoint_dir {string} -- directory where the checkpoints are saved
            hyperparameters {dict} -- dictionary with all hyperparameters

        Returns:
            int -- iteration number parsed from the checkpoint file name
        """
        # Load generator
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen.load_state_dict(state_dict["2"])
        # Load domain classifier
        if self.domain_classif_ab:
            last_model_name = get_model_list(checkpoint_dir, "domain_classif")
            state_dict = torch.load(last_model_name)
            self.domain_classifier_ab.load_state_dict(state_dict["d"])
        # File names end with "_%08d.pt"
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict["a"])
        self.dis_b.load_state_dict(state_dict["b"])
        # Older checkpoints do not contain the masked discriminators.
        if "a_masked" in state_dict:
            self.dis_a_masked.load_state_dict(state_dict["a_masked"])
        if "b_masked" in state_dict:
            self.dis_b_masked.load_state_dict(state_dict["b_masked"])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, "optimizer.pt"))
        self.dis_opt.load_state_dict(state_dict["dis"])
        self.gen_opt.load_state_dict(state_dict["gen"])
        if self.domain_classif_ab:
            self.dann_opt.load_state_dict(state_dict["dann"])
            self.dann_scheduler = get_scheduler(self.dann_opt, hyperparameters, iterations)
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print("Resume from iteration %d" % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Save generator, discriminators, domain classifier and optimizers.

        Arguments:
            snapshot_dir {string} -- directory where the weights are written
            iterations {int} -- number of training iterations (files are
                suffixed with iterations + 1)
        """
        gen_name = os.path.join(snapshot_dir, "gen_%08d.pt" % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, "dis_%08d.pt" % (iterations + 1))
        domain_classifier_name = os.path.join(snapshot_dir,
                                              "domain_classifier_%08d.pt" % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, "optimizer.pt")
        torch.save({"2": self.gen.state_dict()}, gen_name)
        torch.save(
            {
                "a": self.dis_a.state_dict(),
                "b": self.dis_b.state_dict(),
                "a_masked": self.dis_a_masked.state_dict(),
                "b_masked": self.dis_b_masked.state_dict(),
            }, dis_name)
        opt_state = {"gen": self.gen_opt.state_dict(), "dis": self.dis_opt.state_dict()}
        if self.domain_classif_ab:
            torch.save({"d": self.domain_classifier_ab.state_dict()}, domain_classifier_name)
            opt_state["dann"] = self.dann_opt.state_dict()
        torch.save(opt_state, opt_name)
class MUNIT_Trainer(nn.Module):
    """MUNIT trainer whose sampled style codes can embed a one-hot attribute label.

    ``hyperparameters['label_a']`` / ``hyperparameters['label_b']`` give the
    number of attribute classes for each domain (0 disables labels). With
    labels, a style code is ``[Gaussian noise (style_dim - n_attr channels),
    one-hot label (n_attr channels)]`` concatenated along dim 1, and the label
    is also passed to the discriminator losses.
    """

    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']
        # Number of attribute classes per domain (0 = unconditioned). The
        # attribute names keep their original spelling in case other code
        # reads them.
        self.a_attibute = hyperparameters['label_a']
        self.b_attibute = hyperparameters['label_b']

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.s_a = self._fixed_display_style(display_size, self.a_attibute)
        self.s_b = self._fixed_display_style(display_size, self.b_attibute)

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    @staticmethod
    def _cyclic_one_hot(num_rows, num_classes):
        """Return a float32 ``(num_rows, num_classes, 1, 1)`` one-hot tensor.

        Row ``i`` selects class ``i % num_classes``. The tensor is created on
        the CPU; callers move it to the GPU themselves.
        """
        indices = torch.tensor([i % num_classes for i in range(num_rows)],
                               dtype=torch.long).reshape((num_rows, 1))
        one_hot = torch.zeros(num_rows, num_classes,
                              dtype=torch.float32).scatter_(1, indices, 1)
        return one_hot.reshape(num_rows, num_classes, 1, 1)

    def _fixed_display_style(self, display_size, num_attr):
        """Build the fixed style codes used when rendering display samples.

        Pure Gaussian noise when ``num_attr == 0``; otherwise noise for the
        first ``style_dim - num_attr`` channels followed by a cyclic one-hot
        attribute label, so the displayed samples cycle through every class.
        """
        if num_attr == 0:
            return torch.randn(display_size, self.style_dim, 1, 1).cuda()
        noise = torch.randn(display_size, self.style_dim - num_attr, 1,
                            1).cuda()
        label = self._cyclic_one_hot(display_size, num_attr).cuda()
        return torch.cat([noise, label], 1)

    def _random_style(self, x, label):
        """Sample a random style code for batch ``x``, optionally label-conditioned.

        Arguments:
            x -- input batch; only ``x.size(0)`` is used.
            label -- ``None`` or a tensor whose ``size(1)`` is the number of
                attribute channels; it is repeated once per sample.
                NOTE(review): assumes a single-row ``(1, n_attr)`` label; a
                batch-sized label would not fit the repeat/reshape below.

        Returns:
            (style, label) -- the style code and the label reshaped to
            ``(batch, n_attr, 1, 1)`` (``None`` when no label was given).
        """
        batch = x.size(0)
        if label is None:
            return Variable(
                torch.randn(batch, self.style_dim, 1, 1).cuda()), None
        style_num = label.size(1)
        noise = Variable(
            torch.randn(batch, self.style_dim - style_num, 1, 1).cuda())
        label = label.repeat(batch, 1).reshape(batch, style_num, 1, 1)
        return torch.cat([noise, label], 1), label

    def recon_criterion(self, input, target):
        """Mean absolute error (L1) between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        """Translate ``x_a``->b and ``x_b``->a with the fixed display styles.

        Runs in eval mode and switches the module back to train mode.
        """
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, s_a_fake = self.gen_a.encode(x_a)
        c_b, s_b_fake = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters, label_a=None,
                   label_b=None):
        """Run one generator optimisation step.

        Arguments:
            x_a, x_b -- image batches from domain a and b.
            hyperparameters -- config dict with the loss weights.
            label_a, label_b -- optional attribute labels used to condition
                the random target styles (see ``_random_style``).
        """
        self.gen_opt.zero_grad()
        s_a, label_a = self._random_style(x_a, label_a)
        s_b, label_b = self._random_style(x_b, label_b)
        cyc_w = hyperparameters['recon_x_cyc_w']
        # 'vgg_w' is optional (see __init__); a missing key means "no VGG
        # loss" instead of a KeyError here.
        vgg_w = hyperparameters.get('vgg_w', 0)
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if cyc_w > 0 else None
        x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if cyc_w > 0 else None
        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if cyc_w > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if cyc_w > 0 else 0
        # GAN loss (the discriminator also scores the attribute label)
        self.loss_gen_adv_a, self.loss_gen_class_a = self.dis_a.calc_gen_loss(
            x_ba, label_a)
        self.loss_gen_adv_b, self.loss_gen_class_b = self.dis_b.calc_gen_loss(
            x_ab, label_b)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(
            self.vgg, x_ba, x_b) if vgg_w > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(
            self.vgg, x_ab, x_a) if vgg_w > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
            hyperparameters['gan_w'] * self.loss_gen_adv_b + \
            hyperparameters['gan_w'] * self.loss_gen_class_a + \
            hyperparameters['gan_w'] * self.loss_gen_class_b + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
            cyc_w * self.loss_gen_cycrecon_x_a + \
            cyc_w * self.loss_gen_cycrecon_x_b + \
            vgg_w * self.loss_gen_vgg_a + \
            vgg_w * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        """MSE between instance-normalised VGG features of ``img`` and ``target``."""
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def sample(self, x_a, x_b):
        """Produce reconstructions and translations for display.

        Each image is translated twice: once with the fixed display style
        (``self.s_a`` / ``self.s_b``, indexed per image) and once with a fresh
        unconditioned random style.
        """
        self.eval()
        s_a1 = Variable(self.s_a)
        s_b1 = Variable(self.s_b)
        s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
            c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
            x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
            x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
            x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
            x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
            x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_update(
            self,
            x_a,
            x_b,
            hyperparameters,
            label_a=None,
            label_b=None,
    ):
        """Run one discriminator optimisation step (adversarial + class losses)."""
        self.dis_opt.zero_grad()
        # utilize label in the style code when one is given
        s_a, label_a = self._random_style(x_a, label_a)
        s_b, label_b = self._random_style(x_b, label_b)
        # encode
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ab = self.gen_b.decode(c_a, s_b)
        x_ba = self.gen_a.decode(c_b, s_a)
        # D loss
        self.loss_dis_a, self.loss_class_a = self.dis_a.calc_dis_loss(
            x_ba.detach(), x_a, label_a)
        self.loss_dis_b, self.loss_class_b = self.dis_b.calc_dis_loss(
            x_ab.detach(), x_b, label_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + \
            hyperparameters['gan_w'] * self.loss_dis_b + \
            hyperparameters['gan_w'] * self.loss_class_a + \
            hyperparameters['gan_w'] * self.loss_class_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Advance both learning-rate schedulers by one step."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Load the latest gen/dis/optimizer checkpoints; return the iteration."""
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        # The iteration number is the 8 digits before the '.pt' extension.
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitilize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Save generators, discriminators, and optimizers to ``snapshot_dir``."""
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
class ERGAN_Trainer(nn.Module):
    """Two-domain trainer with a shared content encoder and region-aware losses.

    ``gen_b`` reuses ``gen_a``'s content encoder, and every decoder call also
    receives the source image. A fixed rectangle (rows 92:144, cols 48:172 on
    a 224x224 grid, rescaled to ``hyperparameters['gen']['new_size']``) is
    zeroed for the residual losses and cropped for the ``*_re`` reconstruction
    losses. Optional mixed precision (``fp16``) goes through ``amp``
    (presumably NVIDIA apex, judging by ``amp.initialize``/``amp.scale_loss``).
    """

    def __init__(self, hyperparameters):
        super(ERGAN_Trainer, self).__init__()
        lr_G = hyperparameters['lr_G']
        lr_D = hyperparameters['lr_D']
        print(lr_D, lr_G)
        self.fp16 = hyperparameters['fp16']
        # Initiate the networks
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.gen_b.enc_content = self.gen_a.enc_content  # content share weight
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']
        # Scale from the 224x224 grid the crop box is defined on to the
        # training resolution.
        self.a = hyperparameters['gen']['new_size'] / 224

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        # NOTE(review): gen_b.enc_content *is* gen_a.enc_content, so the shared
        # encoder's parameters appear twice in this list. PyTorch optimizers
        # warn about duplicate parameters, and Adam may update them twice per
        # step. Left unchanged because de-duplicating would change the
        # optimizer state layout and break loading existing optimizer.pt files.
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr_D,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr_G,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        if self.fp16:
            self.gen_a = self.gen_a.cuda()
            self.dis_a = self.dis_a.cuda()
            self.gen_b = self.gen_b.cuda()
            self.dis_b = self.dis_b.cuda()
            # amp.initialize is meant to be called once with every model and
            # optimizer. The previous code called it twice and left gen_b and
            # dis_b out.
            models, optimizers = amp.initialize(
                [self.gen_a, self.gen_b, self.dis_a, self.dis_b],
                [self.gen_opt, self.dis_opt],
                opt_level="O1")
            self.gen_a, self.gen_b, self.dis_a, self.dis_b = models
            self.gen_opt, self.dis_opt = optimizers

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    @staticmethod
    def _crop_slices(scale):
        """Return ``(rows, cols)`` slices of the fixed region, scaled from 224x224.

        The region is rows 92:144, cols 48:172 on the 224x224 grid (an original
        comment says this box is used for both CelebA and MeGlass).
        """
        return (slice(round(scale * 92), round(scale * 144)),
                slice(round(scale * 48), round(scale * 172)))

    def recon_criterion(self, input, target):
        """Mean absolute error (L1); ``input`` is cast to ``target``'s type first."""
        input = input.type_as(target)
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        """Translate ``x_a``->b and ``x_b``->a with the fixed display styles.

        Runs in eval mode and switches the module back to train mode.
        """
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, s_a_fake = self.gen_a.encode(x_a)
        c_b, s_b_fake = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a, x_b)
        x_ab = self.gen_b.decode(c_a, s_b, x_a)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        """Run one generator optimisation step and print a loss summary.

        Cross-domain translations use the style code encoded from the real
        image of the target domain (``s_a_prime`` / ``s_b_prime``). Each loss
        term is stored on ``self``.
        """
        self.gen_opt.zero_grad()
        rows, cols = self._crop_slices(self.a)
        cyc_w = hyperparameters['recon_x_cyc_w']
        # Copies of the inputs with the fixed region zeroed out.
        block_a = x_a.clone()
        block_b = x_b.clone()
        block_a[:, :, rows, cols] = 0
        block_b[:, :, rows, cols] = 0
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime, x_a)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime, x_b)
        # decode (cross domain) with the styles of the real images
        x_ba_real = self.gen_a.decode(c_b, s_a_prime, x_b)
        x_ab_real = self.gen_b.decode(c_a, s_b_prime, x_a)
        block_ba_real = x_ba_real.clone()
        block_ab_real = x_ab_real.clone()
        block_ba_real[:, :, rows, cols] = 0
        block_ab_real[:, :, rows, cols] = 0
        # encode again
        c_b_real_recon, s_a_prime_recon = self.gen_a.encode(x_ba_real)
        c_a_real_recon, s_b_prime_recon = self.gen_b.encode(x_ab_real)
        # decode again (if needed)
        x_aba = self.gen_a.decode(c_a_real_recon, s_a_prime,
                                  x_a) if cyc_w > 0 else None
        x_bab = self.gen_b.decode(c_b_real_recon, s_b_prime,
                                  x_b) if cyc_w > 0 else None
        # reconstruction loss
        # Residual losses: translation vs. source outside the fixed region.
        self.loss_gen_recon_res_a = self.recon_criterion(block_ab_real, block_a)
        self.loss_gen_recon_res_b = self.recon_criterion(block_ba_real, block_b)
        # Within-domain reconstruction restricted to the fixed region.
        self.loss_gen_recon_x_a_re = self.recon_criterion(
            x_a_recon[:, :, rows, cols], x_a[:, :, rows, cols])
        self.loss_gen_recon_x_b_re = self.recon_criterion(
            x_b_recon[:, :, rows, cols], x_b[:, :, rows, cols])
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a_prime = self.recon_criterion(
            s_a_prime_recon, s_a_prime)
        self.loss_gen_recon_s_b_prime = self.recon_criterion(
            s_b_prime_recon, s_b_prime)
        self.loss_gen_recon_c_a_real = self.recon_criterion(c_a_real_recon, c_a)
        self.loss_gen_recon_c_b_real = self.recon_criterion(c_b_real_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if cyc_w > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if cyc_w > 0 else 0
        # GAN loss
        self.loss_gen_adv_a_real = self.dis_a.calc_gen_loss(x_ba_real)
        self.loss_gen_adv_b_real = self.dis_b.calc_gen_loss(x_ab_real)
        # total loss (the residual term previously added res_b twice and
        # never res_a)
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a_real + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_x_w_re'] * self.loss_gen_recon_x_b_re + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b_prime + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b_real + \
            cyc_w * self.loss_gen_cycrecon_x_b + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_x_w_re'] * self.loss_gen_recon_x_a_re + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a_prime + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a_real + \
            cyc_w * self.loss_gen_cycrecon_x_a + \
            hyperparameters['gan_w'] * self.loss_gen_adv_b_real + \
            hyperparameters['recon_x_w_res'] * self.loss_gen_recon_res_a + \
            hyperparameters['recon_x_w_res'] * self.loss_gen_recon_res_b

        if self.fp16:
            with amp.scale_loss(self.loss_gen_total,
                                self.gen_opt) as scaled_loss:
                scaled_loss.backward()
            self.gen_opt.step()
        else:
            self.loss_gen_total.backward()
            self.gen_opt.step()

        # Weighted loss groups, computed only for the progress line below.
        loss_gan_real = hyperparameters['gan_w'] * (
            self.loss_gen_adv_a_real + self.loss_gen_adv_b_real)
        loss_x = hyperparameters['recon_x_w'] * (self.loss_gen_recon_x_a +
                                                 self.loss_gen_recon_x_b)
        loss_x_re = hyperparameters['recon_x_w_re'] * (
            self.loss_gen_recon_x_a_re + self.loss_gen_recon_x_b_re)
        loss_s_prime = hyperparameters['recon_s_w'] * (
            self.loss_gen_recon_s_a_prime + self.loss_gen_recon_s_b_prime)
        loss_c_real = hyperparameters['recon_c_w'] * (
            self.loss_gen_recon_c_a_real + self.loss_gen_recon_c_b_real)
        loss_x_cyc = cyc_w * (self.loss_gen_cycrecon_x_a +
                              self.loss_gen_cycrecon_x_b)
        print(
            '||total:%.2f||gan_real:%.2f||x:%.2f||x_re:%.2f||s_prime:%.4f||c_real:%.2f||x_cyc:%.4f||'
            % (self.loss_gen_total, loss_gan_real, loss_x, loss_x_re,
               loss_s_prime, loss_c_real, loss_x_cyc))

    def compute_vgg_loss(self, vgg, img, target):
        """MSE between instance-normalised VGG features of ``img`` and ``target``."""
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def sample(self, x_a, x_b):
        """Produce reconstructions, translations and cycle results for display.

        Each image is processed on its own, and translations use the style
        encoded from the paired image of the other domain.
        """
        self.eval()
        x_a_recon, x_b_recon, x_bab, x_ab, x_ba, x_aba = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
            c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(
                self.gen_a.decode(c_a, s_a_fake, x_a[i].unsqueeze(0)))
            x_b_recon.append(
                self.gen_b.decode(c_b, s_b_fake, x_b[i].unsqueeze(0)))
            x_ba_tmp = self.gen_a.decode(c_b, s_a_fake, x_b[i].unsqueeze(0))
            x_ab_tmp = self.gen_b.decode(c_a, s_b_fake, x_a[i].unsqueeze(0))
            x_ba.append(x_ba_tmp)
            x_ab.append(x_ab_tmp)
            c_b_recon, _ = self.gen_a.encode(x_ba_tmp)
            c_a_recon, _ = self.gen_b.encode(x_ab_tmp)
            x_aba.append(
                self.gen_a.decode(c_a_recon, s_a_fake, x_a[i].unsqueeze(0)))
            x_bab.append(
                self.gen_b.decode(c_b_recon, s_b_fake, x_b[i].unsqueeze(0)))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_aba, x_bab = torch.cat(x_aba), torch.cat(x_bab)
        x_ab, x_ba = torch.cat(x_ab), torch.cat(x_ba)
        self.train()
        return x_a, x_a_recon, x_ab, x_ba, x_b, x_aba, x_b, x_b_recon, x_ba, x_ab, x_a, x_bab

    def dis_update(self, x_a, x_b, hyperparameters):
        """Run one discriminator optimisation step on real-style translations."""
        self.dis_opt.zero_grad()
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba_real = self.gen_a.decode(c_b, s_a_prime, x_b)
        x_ab_real = self.gen_b.decode(c_a, s_b_prime, x_a)
        # D loss
        self.loss_dis_a_real = self.dis_a.calc_dis_loss(
            x_ba_real.detach(), x_a)
        self.loss_dis_b_real = self.dis_b.calc_dis_loss(
            x_ab_real.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * (
            self.loss_dis_a_real + self.loss_dis_b_real)
        if self.fp16:
            with amp.scale_loss(self.loss_dis_total,
                                self.dis_opt) as scaled_loss:
                scaled_loss.backward()
            self.dis_opt.step()
        else:
            self.loss_dis_total.backward()
            self.dis_opt.step()

    def update_learning_rate(self):
        """Advance both learning-rate schedulers by one step."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Load the latest gen/dis/optimizer checkpoints; return the iteration.

        The schedulers are rebuilt as if ``iterations - offset`` steps had
        elapsed. ``offset`` is the optional ``'resume_iter_offset'``
        hyperparameter and defaults to the previously hard-coded 375000
        (presumably for fine-tuning from a pre-trained checkpoint; an old
        comment mentions 370000).
        """
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        # The iteration number is the 8 digits before the '.pt' extension.
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitilize schedulers
        offset = hyperparameters.get('resume_iter_offset', 375000)
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations - offset)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations - offset)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Save generators, discriminators, and optimizers to ``snapshot_dir``."""
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
class MUNIT_Trainer(nn.Module):
    """Standard two-domain MUNIT trainer (two AdaIN auto-encoders, two discriminators)."""

    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks.
        # input_dim_a / input_dim_b are the numbers of image channels (3 for
        # RGB); 'gen' and 'dis' are the architecture sub-configs.
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # fix the noise used in sampling: one fixed random style code per
        # displayed image, shape (display_size, style_dim, 1, 1)
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers: concatenate each pair's parameter lists, keep
        # only trainable parameters, and build one Adam optimizer (plus one
        # scheduler) for the discriminators and one for the generators. The
        # betas, weight decay and schedule come from the config.
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization. nn.Module.apply calls the function
        # on every submodule recursively; the discriminators are then
        # re-initialised with the 'gaussian' scheme.
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed (only when 'vgg_w' is present and > 0)
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        """Mean absolute error (L1) between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        """Translate ``x_a``->b and ``x_b``->a with the fixed display styles.

        Runs in eval mode and switches the module back to train mode
        afterwards. (The original notes say it was unclear where this is
        called from.)
        """
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, s_a_fake = self.gen_a.encode(x_a)
        c_b, s_b_fake = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        """Run one generator optimisation step; loss terms are stored on ``self``."""
        self.gen_opt.zero_grad()
        # two random style codes, one per target domain
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        cyc_w = hyperparameters['recon_x_cyc_w']
        # 'vgg_w' is optional (see __init__); a missing key means "no VGG
        # loss" instead of a KeyError here.
        vgg_w = hyperparameters.get('vgg_w', 0)
        # encode; "prime" marks style codes encoded from real images.
        # NOTE(review): original notes give c as (N, 256, 64, 64) and s as
        # (N, 8, 1, 1) for the default config; this depends on the config.
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain): rebuild each input from its own codes
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)  # (a)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)  # (b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if cyc_w > 0 else None
        x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if cyc_w > 0 else None
        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)  # (a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)  # (b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)  # (c)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)  # (d)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)  # (e)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)  # (f)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if cyc_w > 0 else 0  # (g)
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if cyc_w > 0 else 0  # (h)
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)  # (i)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)  # (j)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(
            self.vgg, x_ba, x_b) if vgg_w > 0 else 0  # (k)
        self.loss_gen_vgg_b = self.compute_vgg_loss(
            self.vgg, x_ab, x_a) if vgg_w > 0 else 0  # (l)
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
            hyperparameters['gan_w'] * self.loss_gen_adv_b + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
            cyc_w * self.loss_gen_cycrecon_x_a + \
            cyc_w * self.loss_gen_cycrecon_x_b + \
            vgg_w * self.loss_gen_vgg_a + \
            vgg_w * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        """MSE between instance-normalised VGG features of ``img`` and ``target``."""
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def sample(self, x_a, x_b):
        """Reconstruct and cross-translate two batches for display.

        Each image is translated twice: with the fixed display style of the
        target domain and with a freshly sampled random style.
        """
        self.eval()
        s_a1 = Variable(self.s_a)  # fixed styles
        s_b1 = Variable(self.s_b)
        # freshly sampled random styles
        s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            # encode the input images
            c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
            c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
            # reconstruct the input images
            x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
            x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
            # b's content with the fixed / random a-style -> domain-a image
            x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
            x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
            # a's content with the fixed / random b-style -> domain-b image
            x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
            x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_update(self, x_a, x_b, hyperparameters):
        """Run one discriminator optimisation step."""
        self.dis_opt.zero_grad()
        # random style codes for the generated images
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode the images (only the content codes are needed)
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # D loss (generated images are detached so only D is updated)
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters[
            'gan_w'] * self.loss_dis_a + hyperparameters[
                'gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Advance both learning-rate schedulers by one step."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Load the latest gen/dis/optimizer checkpoints; return the iteration."""
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        # The iteration number is the 8 digits before the '.pt' extension.
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitilize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Save generators, discriminators, and optimizers to ``snapshot_dir``."""
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
class MUNIT_Trainer(nn.Module):
    """One-directional (A -> B) trainer built around a single AdaIN generator.

    ``gen_b`` encodes content from domain-A images and decodes it with
    domain-B styles; ``dis_b`` scores domain-B images.  The discriminator
    returns a 3-tuple ``(mu, logstd, score)``: the score drives the GAN loss,
    while ``(mu, logstd)`` feed a KL regulariser whose weight
    ``self.reg_param`` is adapted towards ``target_kl`` by ``update_beta``.
    """

    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'], hyperparameters['new_size'],
            hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']
        # Weight of the KL term in the discriminator loss; adapted by update_beta().
        self.reg_param = hyperparameters['reg_param']
        self.beta_step = hyperparameters['beta_step']
        self.target_kl = hyperparameters['target_kl']
        self.gan_type = hyperparameters['gan_type']

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_b.parameters())
        gen_params = list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.gen_b.apply(weights_init(hyperparameters['init']))
        self.dis_b.apply(weights_init('gaussian'))

        # SSIM Loss
        self.ssim_loss = pytorch_ssim.SSIM()

    def recon_criterion(self, input, target):
        """Return the mean absolute error between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    def recon_criterion_l1(self, input, target, mask):
        """Return the summed absolute error normalised by ``mask.sum()``."""
        return torch.sum(torch.abs(input - target)) / torch.sum(mask)

    def forward(self, x_a, x_b):
        """Render the content of ``x_a`` with the style encoded from ``x_b``.

        Runs in eval mode and switches back to train mode before returning.
        """
        self.eval()
        s_b = self.gen_b.enc_style(x_b)
        c_a = self.gen_b.enc_content(x_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab

    def gen_update(self, x_a, x_b, hyperparameters):
        """Run one generator optimisation step.

        The total loss is the adversarial loss on ``x_ab`` plus weighted
        reconstruction losses for the A content, the B image, the sampled B
        style, and a negative-SSIM term between ``x_a`` and ``x_ab``.
        """
        toogle_grad(self.dis_b, False)
        toogle_grad(self.gen_b, True)
        self.dis_b.train()
        self.gen_b.train()
        self.gen_opt.zero_grad()
        s_b = torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()
        # encode
        c_a = self.gen_b.enc_content(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # reconstruction loss; SSIM is negated so higher similarity lowers the loss
        self.loss_gen_recon_x_ab_ssim = -self.ssim_loss(x_a, x_ab)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        # GAN loss: only the score (third output) of the discriminator is used
        _, _, d_fake = self.dis_b(x_ab)
        self.loss_gen_adv_b = self.compute_loss(d_fake, 1)
        # total loss
        self.loss_gen_total = self.loss_gen_adv_b + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
            hyperparameters['recon_x_ab'] * self.loss_gen_recon_x_ab_ssim
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        """Return the MSE between instance-normalised ``vgg`` features of both images."""
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def sample(self, x_a, x_b):
        """Translate each image of ``x_a`` to domain B for visualisation.

        Styles are encoded from ``x_b``.  If ``x_b`` holds a single image its
        style is applied to every ``x_a[i]``; otherwise ``x_a[i]`` is paired
        with the style of ``x_b[i]``.  Previously the whole style batch was fed
        alongside a single content code, mismatching batch sizes for ``x_b``
        batches larger than one.

        Returns:
            ``(x_a, x_ab)``.
        """
        self.eval()
        x_ab = []
        s_b = self.gen_b.enc_style(x_b)
        share_style = s_b.size(0) == 1
        for i in range(x_a.size(0)):
            c_a = self.gen_b.enc_content(x_a[i].unsqueeze(0))
            s_b_i = s_b if share_style else s_b[i].unsqueeze(0)
            x_ab.append(self.gen_b.decode(c_a, s_b_i))
        x_ab = torch.cat(x_ab)
        self.train()
        return x_a, x_ab

    def dis_update(self, x_a, x_b, hyperparameters):
        """Run one discriminator step and return ``dloss_real + dloss_fake`` as a float.

        Besides the GAN loss, the step backpropagates a gradient penalty on
        real data (fixed weight 10) and ``self.reg_param`` times the mean KL
        of the real and fake ``(mu, logstd)`` outputs, then adapts
        ``self.reg_param`` via ``update_beta``.  Note that ``x_b`` is switched
        to ``requires_grad=True`` in place so the penalty can differentiate
        w.r.t. the input.
        """
        toogle_grad(self.gen_b, False)
        toogle_grad(self.dis_b, True)
        self.gen_b.train()
        self.dis_b.train()
        self.dis_opt.zero_grad()

        # On real data
        x_b.requires_grad_()
        d_real_dict = self.dis_b(x_b)
        d_real = d_real_dict[2]
        dloss_real = self.compute_loss(d_real, 1)
        reg = 0.
        # Both the gradient penalty and the KL regulariser are applied.
        dloss_real.backward(retain_graph=True)
        # hard coded 10 weight for grad penal.
        reg += 10. * compute_grad2(d_real, x_b).mean()
        mu = d_real_dict[0]
        logstd = d_real_dict[1]
        kl_real = kl_loss(mu, logstd).mean()

        # On fake data (the generator is not trained here)
        with torch.no_grad():
            s_b = torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()
            c_a = self.gen_b.enc_content(x_a)
            x_ab = self.gen_b.decode(c_a, s_b)
        x_ab.requires_grad_()
        d_fake_dict = self.dis_b(x_ab)
        d_fake = d_fake_dict[2]
        dloss_fake = self.compute_loss(d_fake, 0)
        dloss_fake.backward(retain_graph=True)
        mu_fake = d_fake_dict[0]
        logstd_fake = d_fake_dict[1]
        kl_fake = kl_loss(mu_fake, logstd_fake).mean()

        avg_kl = 0.5 * (kl_real + kl_fake)
        reg += self.reg_param * avg_kl
        reg.backward()

        self.update_beta(avg_kl)
        self.dis_opt.step()

        self.loss_dis_total = (dloss_real + dloss_fake)
        return self.loss_dis_total.item()

    def compute_loss(self, d_out, target):
        """Return the GAN loss of scores ``d_out`` against label ``target`` (0 or 1).

        Raises:
            NotImplementedError: if ``self.gan_type`` is not 'standard' or 'wgan'.
        """
        targets = d_out.new_full(size=d_out.size(), fill_value=target)

        if self.gan_type == 'standard':
            loss = F.binary_cross_entropy_with_logits(d_out, targets)
        elif self.gan_type == 'wgan':
            loss = (2 * target - 1) * d_out.mean()
        else:
            raise NotImplementedError

        return loss

    def update_beta(self, avg_kl):
        """Move ``self.reg_param`` by ``beta_step * (avg_kl - target_kl)``, clamped at 0."""
        with torch.no_grad():
            new_beta = self.reg_param - self.beta_step * \
                (self.target_kl - avg_kl)  # self.target_kl is the KL constraint
            new_beta = max(new_beta, 0)
            self.reg_param = new_beta

    def update_learning_rate(self):
        """Step both learning-rate schedulers, if they exist."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Load the latest generator/discriminator/optimizer checkpoints.

        Returns:
            The iteration number parsed from the generator file name.
        """
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_b.load_state_dict(state_dict['b'])
        # file names end in '%08d.pt'; slice out the 8 digits
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Save generator, discriminator and optimizer state to ``snapshot_dir``.

        Checkpoint file names embed ``iterations`` exactly as given.
        """
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % iterations)
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % iterations)
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'b': self.dis_b.state_dict()}, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
class IPMNet_Trainer(nn.Module):
    """Trainer for a mask/texture-conditioned two-domain translation network.

    A single ``AdaINGen`` is shared by both domains (``gen_b`` is the same
    object as ``gen_a``); each domain has its own ``MsImageDis``.  The
    generator's ``encode`` takes an image, its mask and its texture and returns
    ``(content, style, gray_facial)``; ``decode`` consumes all three.
    """

    def __init__(self, hyperparameters):
        super(IPMNet_Trainer, self).__init__()
        lr = hyperparameters['lr']
        vgg_weight_file = hyperparameters['vgg_weight_file']
        # Initiate the networks
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        # Both domains share one generator (no separate AdaINGen for domain b).
        self.gen_b = self.gen_a
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        # NOTE(review): gen_b is gen_a, so gen_params lists every generator
        # parameter twice.  PyTorch warns about duplicate parameters in a
        # param group, and Adam then updates each of them twice per step.
        # Left as-is on purpose: de-duplicating would change the tuned
        # training dynamics and make existing optimizer.pt checkpoints fail
        # to load (the saved param-group size would no longer match).
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init(hyperparameters['init']))
        self.dis_b.apply(weights_init(hyperparameters['init']))

        # Load the perceptual network if needed.  Despite the 'vgg' naming it
        # is built by load_resnet50(); its fc layer is re-initialised and all
        # of its weights are frozen.
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_resnet50(vgg_weight_file)
            self.vgg.eval()
            self.vgg.fc.reset_parameters()
            for param in self.vgg.parameters():
                param.requires_grad = False

    @staticmethod
    def _cat_or_none(tensors):
        """Concatenate ``tensors`` along dim 0, or return ``None`` if the list is empty."""
        return torch.cat(tensors) if tensors else None

    def recon_criterion(self, input, target):
        """Return the mean absolute error between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    def gen_update(self, x_a, x_b, mask_a, mask_b, texture_a, texture_b,
                   hyperparameters):
        """Run one generator optimisation step.

        Losses: adversarial, within-domain image/style/content reconstruction,
        optional cycle reconstruction (``recon_x_cyc_w``), optional perceptual
        loss (``vgg_w``), and optional background (``back_w``) and foreground
        (``fore_w``) losses computed through the masks.
        """
        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, s_a_prime, x_a_gray_facial = self.gen_a.encode(
            x_a, mask_a, texture_a)
        c_b, s_b_prime, x_b_gray_facial = self.gen_b.encode(
            x_b, mask_b, texture_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime, x_a_gray_facial)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime, x_b_gray_facial)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a, x_b_gray_facial)
        x_ab = self.gen_b.decode(c_a, s_b, x_a_gray_facial)
        # encode again
        c_a_recon, s_b_recon, x_a_recon_gray_facial = self.gen_b.encode(
            x_ab, mask_a, texture_a)
        c_b_recon, s_a_recon, x_b_recon_gray_facial = self.gen_a.encode(
            x_ba, mask_b, texture_b)
        # decode again (if needed)
        x_aba = self.gen_a.decode(
            c_a_recon, s_a_prime, x_a_recon_gray_facial
        ) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(
            c_b_recon, s_b_prime, x_b_recon_gray_facial
        ) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # background: regions where the mask is set (single-channel mask
        # repeated over the 3 image channels)
        x_a_back = x_a * mask_a.repeat(1, 3, 1, 1)
        x_b_back = x_b * mask_b.repeat(1, 3, 1, 1)
        x_ab_back = x_ab * mask_a.repeat(1, 3, 1, 1)
        x_ba_back = x_ba * mask_b.repeat(1, 3, 1, 1)
        # foreground: the complement of the mask
        x_a_fore = x_a * (1 - mask_a).repeat(1, 3, 1, 1)
        x_b_fore = x_b * (1 - mask_b).repeat(1, 3, 1, 1)
        x_a_recon_fore = x_a_recon * (1 - mask_a).repeat(1, 3, 1, 1)
        x_b_recon_fore = x_b_recon * (1 - mask_b).repeat(1, 3, 1, 1)

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(
            self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(
            self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # background loss
        self.loss_back_x_a = self.recon_criterion(
            x_ab_back, x_a_back) if hyperparameters['back_w'] > 0 else 0
        self.loss_back_x_b = self.recon_criterion(
            x_ba_back, x_b_back) if hyperparameters['back_w'] > 0 else 0
        # foreground loss
        self.loss_fore_x_a = self.recon_criterion(
            x_a_recon_fore, x_a_fore) if hyperparameters['fore_w'] > 0 else 0
        self.loss_fore_x_b = self.recon_criterion(
            x_b_recon_fore, x_b_fore) if hyperparameters['fore_w'] > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
            hyperparameters['gan_w'] * self.loss_gen_adv_b + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
            hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
            hyperparameters['vgg_w'] * self.loss_gen_vgg_b + \
            hyperparameters['back_w'] * self.loss_back_x_a + \
            hyperparameters['back_w'] * self.loss_back_x_b + \
            hyperparameters['fore_w'] * self.loss_fore_x_a + \
            hyperparameters['fore_w'] * self.loss_fore_x_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        """Return the mean absolute difference between ``vgg`` features of both images."""
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(torch.abs(img_fea - target_fea))

    def sample(self, x_a, x_b, mask_a, mask_b, texture_a, texture_b,
               hyperparameters, train=True):
        """Translate each (x_a[i], x_b[i]) pair one at a time for visualisation.

        Cross-domain outputs swap the styles encoded from the paired images.
        When ``hyperparameters['recon_x_cyc_w'] > 0`` the translations are
        encoded again and decoded with the original styles (``x_aba`` /
        ``x_bab``); otherwise those two outputs are ``None`` (previously
        ``torch.cat`` was called on a list of ``None`` and raised).
        If ``train`` is true, the styles of the first pair are printed.

        Returns:
            ``(x_a, x_b, x_a_recon, x_a_facial_mask, x_ab, x_aba,
            x_b, x_a, x_b_recon, x_b_facial_mask, x_ba, x_bab)``.
        """
        self.eval()
        use_cyc = hyperparameters['recon_x_cyc_w'] > 0
        x_a_recon, x_b_recon = [], []
        x_a_facial_mask, x_b_facial_mask = [], []
        x_ab, x_ba, x_aba, x_bab = [], [], [], []
        for i in range(x_a.size(0)):
            m_a, t_a = mask_a[i].unsqueeze(0), texture_a[i].unsqueeze(0)
            m_b, t_b = mask_b[i].unsqueeze(0), texture_b[i].unsqueeze(0)
            c_a, s_a, x_a_gray_facial = self.gen_a.encode(
                x_a[i].unsqueeze(0), m_a, t_a)
            c_b, s_b, x_b_gray_facial = self.gen_b.encode(
                x_b[i].unsqueeze(0), m_b, t_b)
            if train and i == 0:
                print(s_a.squeeze())
                print(s_b.squeeze())
            x_a_recon.append(self.gen_a.decode(c_a, s_a, x_a_gray_facial))
            x_b_recon.append(self.gen_b.decode(c_b, s_b, x_b_gray_facial))
            x_a_facial_mask.append(x_a_gray_facial)
            x_b_facial_mask.append(x_b_gray_facial)
            x_ba_i = self.gen_a.decode(c_b, s_a, x_b_gray_facial)
            x_ab_i = self.gen_b.decode(c_a, s_b, x_a_gray_facial)
            x_ba.append(x_ba_i)
            x_ab.append(x_ab_i)
            if use_cyc:
                # encode again, then decode with the original styles
                c_a_recon, _, x_a_recon_gray_facial = self.gen_a.encode(
                    x_ab_i, m_a, t_a)
                c_b_recon, _, x_b_recon_gray_facial = self.gen_b.encode(
                    x_ba_i, m_b, t_b)
                x_aba.append(
                    self.gen_a.decode(c_a_recon, s_a, x_a_recon_gray_facial))
                x_bab.append(
                    self.gen_b.decode(c_b_recon, s_b, x_b_recon_gray_facial))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_a_facial_mask = torch.cat(x_a_facial_mask)
        x_b_facial_mask = torch.cat(x_b_facial_mask)
        x_ab, x_ba = torch.cat(x_ab), torch.cat(x_ba)
        x_aba, x_bab = self._cat_or_none(x_aba), self._cat_or_none(x_bab)
        self.train()
        return x_a, x_b, x_a_recon, x_a_facial_mask, x_ab, x_aba, \
            x_b, x_a, x_b_recon, x_b_facial_mask, x_ba, x_bab

    def dis_update(self, x_a, x_b, mask_a, mask_b, texture_a, texture_b,
                   hyperparameters):
        """Run one discriminator step on detached cross-domain translations."""
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, _, x_a_gray_facial = self.gen_a.encode(x_a, mask_a, texture_a)
        c_b, _, x_b_gray_facial = self.gen_b.encode(x_b, mask_b, texture_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a, x_b_gray_facial)
        x_ab = self.gen_b.decode(c_a, s_b, x_a_gray_facial)
        # D loss (fakes are detached so no gradient reaches the generator)
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + \
            hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Step both learning-rate schedulers, if they exist."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Load the latest checkpoints and return the parsed iteration number."""
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        # file names end in '%08d.pt'; slice out the 8 digits
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Save generators, discriminators and optimizers (file index ``iterations + 1``)."""
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
class MUNIT_Trainer(nn.Module):
    """Two-domain MUNIT trainer that records losses in ``self.loss`` and logs via ``logger``.

    ``opts.output_base`` locates the VGG weights, which are only loaded when
    ``hyperparameters['vgg_w'] > 0``.  ``hyperparameters['display_size']``
    (default 8) sets how many fixed style codes are drawn for ``forward`` and
    ``sample``.  ``sample`` indexes them per image, so ``display_size`` must be
    at least the sample batch size.
    """

    def __init__(self, hyperparameters, opts):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']
        # Latest loss values, keyed like 'G/total' and 'D/total'.
        self.loss = {}

        # Fix the noise used in sampling.  The count was hard-coded to 8,
        # which made sample() raise IndexError for batches larger than 8.
        display_size = int(hyperparameters.get('display_size', 8))
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed (frozen, eval mode)
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(opts.output_base + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def print_network(self, model, name):
        """Log ``model`` together with its total parameter count."""
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        logger.info(
            '{} - {} - Number of parameters: {}'.format(name, model, num_params))

    def recon_criterion(self, input, target):
        """Return the mean absolute error between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        """Translate both directions using the fixed style codes.

        NOTE(review): the full ``self.s_a``/``self.s_b`` batches are decoded with
        the content codes, so the input batch size presumably has to match
        ``display_size`` — confirm against the generator.
        """
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, s_a_fake = self.gen_a.encode(x_a)
        c_b, s_b_fake = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        """Run one generator step; individual losses are stored under ``self.loss['G/...']``."""
        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction loss
        self.loss['G/rec_x_A'] = self.recon_criterion(x_a_recon, x_a)
        self.loss['G/rec_x_B'] = self.recon_criterion(x_b_recon, x_b)
        self.loss['G/rec_s_A'] = self.recon_criterion(s_a_recon, s_a)
        self.loss['G/rec_s_B'] = self.recon_criterion(s_b_recon, s_b)
        self.loss['G/rec_c_A'] = self.recon_criterion(c_a_recon, c_a)
        self.loss['G/rec_c_B'] = self.recon_criterion(c_b_recon, c_b)
        self.loss['G/cycrec_x_A'] = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss['G/cycrec_x_B'] = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss['G/adv_A'] = self.dis_a.calc_gen_loss(x_ba)
        self.loss['G/adv_B'] = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss (all tensors are already on the GPU)
        self.loss['G/vgg_A'] = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss['G/vgg_B'] = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # total loss
        self.loss['G/total'] = hyperparameters['gan_w'] * self.loss['G/adv_A'] + \
            hyperparameters['gan_w'] * self.loss['G/adv_B'] + \
            hyperparameters['recon_x_w'] * self.loss['G/rec_x_A'] + \
            hyperparameters['recon_s_w'] * self.loss['G/rec_s_A'] + \
            hyperparameters['recon_c_w'] * self.loss['G/rec_c_A'] + \
            hyperparameters['recon_x_w'] * self.loss['G/rec_x_B'] + \
            hyperparameters['recon_s_w'] * self.loss['G/rec_s_B'] + \
            hyperparameters['recon_c_w'] * self.loss['G/rec_c_B'] + \
            hyperparameters['recon_x_cyc_w'] * self.loss['G/cycrec_x_A'] + \
            hyperparameters['recon_x_cyc_w'] * self.loss['G/cycrec_x_B'] + \
            hyperparameters['vgg_w'] * self.loss['G/vgg_A'] + \
            hyperparameters['vgg_w'] * self.loss['G/vgg_B']
        self.loss['G/total'].backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        """Return the MSE between instance-normalised ``vgg`` features of both images."""
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def sample(self, x_a, x_b):
        """Build a dict of visualisation tensors, translating one image pair at a time.

        Keys: 'A/real', 'B/real', 'A/rec', 'B/rec'; 'A/B_random_style' and
        'B/A_random_style' use the fixed codes ``self.s_b`` / ``self.s_a``;
        'A/B' and 'B/A' use the style encoded from the paired image of the
        other domain.
        """
        self.eval()
        s_a1 = Variable(self.s_a)
        s_b1 = Variable(self.s_b)
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
            c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
            x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
            x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
            # s_a_fake / s_b_fake were encoded from a single image, so they are
            # passed with the same batch-of-one layout as the fixed codes above.
            x_ba2.append(self.gen_a.decode(c_b, s_a_fake))
            x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
            x_ab2.append(self.gen_b.decode(c_a, s_b_fake))
        outputs = {}
        outputs['A/real'] = x_a
        outputs['B/real'] = x_b
        outputs['A/rec'] = torch.cat(x_a_recon)
        outputs['B/rec'] = torch.cat(x_b_recon)
        outputs['A/B_random_style'] = torch.cat(x_ab1)
        outputs['A/B'] = torch.cat(x_ab2)
        outputs['B/A_random_style'] = torch.cat(x_ba1)
        outputs['B/A'] = torch.cat(x_ba2)
        self.train()
        return outputs

    def dis_update(self, x_a, x_b, hyperparameters):
        """Run one discriminator step; losses are stored under ``self.loss['D/...']``."""
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # D loss (fakes are detached so no gradient reaches the generators)
        self.loss['D/A'] = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss['D/B'] = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss['D/total'] = hyperparameters['gan_w'] * self.loss['D/A'] + hyperparameters['gan_w'] * self.loss['D/B']
        self.loss['D/total'].backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Step both schedulers (if present) and log any learning-rate change."""
        if self.dis_scheduler is not None:
            old_lr = self.dis_opt.param_groups[0]['lr']
            self.dis_scheduler.step()
            new_lr = self.dis_opt.param_groups[0]['lr']
            if old_lr != new_lr:
                logger.info('Updated D learning rate: {}'.format(new_lr))
        if self.gen_scheduler is not None:
            old_lr = self.gen_opt.param_groups[0]['lr']
            self.gen_scheduler.step()
            new_lr = self.gen_opt.param_groups[0]['lr']
            if old_lr != new_lr:
                logger.info('Updated G learning rate: {}'.format(new_lr))

    def resume(self, checkpoint_dir, hyperparameters):
        """Load the latest checkpoints and return the parsed iteration number."""
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        # file names end in '%08d.pt'; slice out the 8 digits
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        logger.info('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Save generators, discriminators and optimizers (file index ``iterations + 1``)."""
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
        logger.info('Saving snapshots to: {}'.format(snapshot_dir))
class MUNIT_Trainer(nn.Module):
    """MUNIT trainer extended with a continuous info-code loss and a content classifier.

    Holds two AdaIN auto-encoders (``gen_a`` / ``gen_b``), two multi-scale
    discriminators (``dis_a`` / ``dis_b``) and a ``ContentClassifier`` that is
    trained on the content codes produced by the encoders.  Discriminator
    parameters whose name contains ``"_Q"`` are optimized by the generator
    optimizer rather than the discriminator optimizer (presumably an auxiliary
    head whose ``'mu'`` / ``'var'`` outputs feed ``compute_info_cont_loss`` --
    confirm against ``MsImageDis``).
    """

    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b
        self.content_classifier = ContentClassifier(
            hyperparameters['gen']['dim'], hyperparameters)
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']
        # Number of trailing style dimensions used as continuous info codes.
        self.num_con_c = hyperparameters['dis']['num_con_c']

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        dis_named_params = list(self.dis_a.named_parameters()) + list(
            self.dis_b.named_parameters())
        # Route discriminator parameters whose name contains "_Q" to the
        # generator optimizer; everything else stays with the discriminator.
        dis_params = list()
        for name, param in dis_named_params:
            if "_Q" in name:
                gen_params.append(param)
            else:
                dis_params.append(param)
        content_classifier_params = list(self.content_classifier.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.cla_opt = torch.optim.Adam(
            [p for p in content_classifier_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        self.cla_scheduler = get_scheduler(self.cla_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))
        self.content_classifier.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False
        self.gan_type = hyperparameters['dis']['gan_type']
        self.criterionQ_con = NormalNLLLoss()
        self.criterion_content_classifier = nn.CrossEntropyLoss()
        # Kept for callers that read it; accuracy no longer depends on it.
        self.batch_size_val = hyperparameters['batch_size_val']

    def recon_criterion(self, input, target):
        """Mean absolute (L1) difference between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        """Translate ``x_a`` -> B and ``x_b`` -> A using the fixed sampling styles.

        Runs in eval mode and switches back to train mode before returning
        ``(x_ab, x_ba)``.
        """
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, sample_b, hyperparameters, sample_a_limited):
        """Run one generator optimization step.

        Args:
            x_a: domain-A images (unlabelled).
            sample_b: ``(x_b, label_b)`` labelled domain-B batch.
            hyperparameters: config dict providing the loss weights.
            sample_a_limited: ``(x_a_limited, label_a_limited)`` labelled
                domain-A batch used for the content-classifier terms.

        The total loss combines adversarial, image/style/content
        reconstruction, continuous info-code and content-classifier losses;
        individual terms are stored on ``self`` for logging.
        """
        x_b, label_b = sample_b
        x_a_limited, label_a_limited = sample_a_limited
        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        # GAN loss (the discriminator outputs are reused by the info loss)
        x_ba_dis_out = self.dis_a(x_ba)
        self.loss_gen_adv_a = self.compute_gen_adv_loss(x_ba_dis_out)
        x_ab_dis_out = self.dis_b(x_ab)
        self.loss_gen_adv_b = self.compute_gen_adv_loss(x_ab_dis_out)
        # continuous info-code loss
        self.info_cont_loss_a = self.compute_info_cont_loss(s_a, x_ba_dis_out)
        self.info_cont_loss_b = self.compute_info_cont_loss(s_b, x_ab_dis_out)
        # content classifier on domain B (labelled)
        label_predict_c_b = self.content_classifier(c_b)
        label_predict_c_b_recon = self.content_classifier(c_b_recon)
        # content classifier on domain A, using the limited labelled samples
        c_a_limited, _ = self.gen_a.encode(x_a_limited)
        label_predict_c_a_limited = self.content_classifier(c_a_limited)
        # NOTE(review): s_b has x_b's batch size; this assumes x_a_limited has
        # a compatible batch size -- confirm against the data loaders.
        x_ab_limited = self.gen_b.decode(c_a_limited, s_b)
        c_a_recon_limited, _ = self.gen_b.encode(x_ab_limited)
        label_predict_c_a_recon_limited = self.content_classifier(
            c_a_recon_limited)
        loss_content_classifier_c_a = self.compute_content_classifier_loss(
            label_predict_c_a_limited, label_a_limited)
        loss_content_classifier_c_a_recon = self.compute_content_classifier_loss(
            label_predict_c_a_recon_limited, label_a_limited)
        loss_content_classifier_b = self.compute_content_classifier_loss(
            label_predict_c_b, label_b)
        loss_content_classifier_c_b_recon = self.compute_content_classifier_loss(
            label_predict_c_b_recon, label_b)
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
            hyperparameters['gan_w'] * self.loss_gen_adv_b + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
            self.info_cont_loss_a + \
            self.info_cont_loss_b + \
            loss_content_classifier_c_a + \
            loss_content_classifier_c_a_recon + \
            loss_content_classifier_b + \
            loss_content_classifier_c_b_recon
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_info_cont_loss(self, style_code, outs_fake):
        """Sum ``criterionQ_con`` over the discriminator outputs, scaled by 0.1.

        ``style_code`` is sliced on its last ``self.num_con_c`` entries of
        dim 1 (the continuous codes); each element of ``outs_fake`` must be a
        dict with ``'mu'`` and ``'var'`` entries.
        """
        loss = 0
        num_cont_code = self.num_con_c
        # The target codes do not depend on the discriminator output, so
        # slice them once instead of once per output.
        info_noise = style_code[:, -num_cont_code:].view(
            -1, num_cont_code).squeeze().squeeze()
        for out_fake in outs_fake:
            q_mu = out_fake['mu']
            q_var = out_fake['var']
            loss += self.criterionQ_con(info_noise, q_mu, q_var) * 0.1
        return loss

    def compute_vgg_loss(self, vgg, img, target):
        """Mean squared distance between instance-normalized VGG features."""
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def sample(self, x_a, x_b):
        """Produce per-sample reconstructions and translations for display.

        Uses the fixed styles ``self.s_a`` / ``self.s_b`` and one fresh random
        style per sample.  Runs in eval mode and restores train mode.
        """
        self.eval()
        s_a1 = Variable(self.s_a)
        s_b1 = Variable(self.s_b)
        s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
            c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
            x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
            x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
            x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
            x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
            x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def cla_update(self, sample_a, sample_b):
        """Run one optimization step of the content classifier only.

        The classifier is trained on the content codes of both domains and on
        the content codes recovered after a cross-domain translation.
        """
        x_a, label_a = sample_a
        x_b, label_b = sample_b
        self.cla_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # Only cla_opt steps here, so the generator outputs are constants for
        # this update; skipping their autograd graph yields identical
        # classifier gradients without back-propagating into the generators.
        with torch.no_grad():
            # encode
            c_a, _ = self.gen_a.encode(x_a)
            c_b, _ = self.gen_b.encode(x_b)
            # decode (cross domain)
            x_ba = self.gen_a.decode(c_b, s_a)
            x_ab = self.gen_b.decode(c_a, s_b)
            # encode again
            c_b_recon, _ = self.gen_a.encode(x_ba)
            c_a_recon, _ = self.gen_b.encode(x_ab)
        label_predict_c_a = self.content_classifier(c_a)
        label_predict_c_a_recon = self.content_classifier(c_a_recon)
        label_predict_c_b = self.content_classifier(c_b)
        label_predict_c_b_recon = self.content_classifier(c_b_recon)
        self.loss_content_classifier_c_a = self.compute_content_classifier_loss(
            label_predict_c_a, label_a)
        self.loss_content_classifier_c_a_recon = self.compute_content_classifier_loss(
            label_predict_c_a_recon, label_a)
        self.loss_content_classifier_b = self.compute_content_classifier_loss(
            label_predict_c_b, label_b)
        self.loss_content_classifier_c_b_recon = self.compute_content_classifier_loss(
            label_predict_c_b_recon, label_b)
        self.loss_cla_total = self.loss_content_classifier_c_a + self.loss_content_classifier_c_a_recon + \
            self.loss_content_classifier_b + self.loss_content_classifier_c_b_recon
        self.loss_cla_total.backward()
        self.cla_opt.step()

    def cla_inference(self, test_loader_a, test_loader_b):
        """Evaluate content-classifier accuracy over paired test loaders.

        Each loader item is indexed as ``(images, labels)``.  Mean accuracies
        are stored on ``self.accu_content_classifier_*`` and their overall
        mean on ``self.accu_CC_all``.
        """
        accu_content_classifier_c_a = []
        accu_content_classifier_c_a_recon = []
        accu_content_classifier_c_b = []
        accu_content_classifier_c_b_recon = []
        # Pure evaluation: nothing is back-propagated, so skip autograd.
        with torch.no_grad():
            for samples_a_test, samples_b_test in zip(test_loader_a,
                                                      test_loader_b):
                x_a = samples_a_test[0].cuda().detach()
                label_a = samples_a_test[1].cuda().detach()
                x_b = samples_b_test[0].cuda().detach()
                label_b = samples_b_test[1].cuda().detach()
                s_a = Variable(
                    torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
                s_b = Variable(
                    torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
                # encode
                c_a, _ = self.gen_a.encode(x_a)
                c_b, _ = self.gen_b.encode(x_b)
                # decode (cross domain)
                x_ba = self.gen_a.decode(c_b, s_a)
                x_ab = self.gen_b.decode(c_a, s_b)
                # encode again
                c_b_recon, _ = self.gen_a.encode(x_ba)
                c_a_recon, _ = self.gen_b.encode(x_ab)
                label_predict_c_a = self.content_classifier(c_a)
                label_predict_c_a_recon = self.content_classifier(c_a_recon)
                label_predict_c_b = self.content_classifier(c_b)
                label_predict_c_b_recon = self.content_classifier(c_b_recon)
                accu_content_classifier_c_a.append(
                    self.compute_content_classifier_accuracy(
                        label_predict_c_a, label_a))
                accu_content_classifier_c_a_recon.append(
                    self.compute_content_classifier_accuracy(
                        label_predict_c_a_recon, label_a))
                accu_content_classifier_c_b.append(
                    self.compute_content_classifier_accuracy(
                        label_predict_c_b, label_b))
                accu_content_classifier_c_b_recon.append(
                    self.compute_content_classifier_accuracy(
                        label_predict_c_b_recon, label_b))
        self.accu_content_classifier_c_a = self.mean_list(
            accu_content_classifier_c_a)
        self.accu_content_classifier_c_a_recon = self.mean_list(
            accu_content_classifier_c_a_recon)
        self.accu_content_classifier_c_b = self.mean_list(
            accu_content_classifier_c_b)
        self.accu_content_classifier_c_b_recon = self.mean_list(
            accu_content_classifier_c_b_recon)
        self.accu_CC_all = self.mean_list([
            self.accu_content_classifier_c_a,
            self.accu_content_classifier_c_a_recon,
            self.accu_content_classifier_c_b,
            self.accu_content_classifier_c_b_recon
        ])

    @staticmethod
    def mean_list(lst):
        """Arithmetic mean of ``lst`` (raises ZeroDivisionError when empty)."""
        return sum(lst) / len(lst)

    def dis_update(self, x_a, x_b, hyperparameters):
        """Run one discriminator optimization step on real vs. translated images."""
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # D loss (fakes are detached so only the discriminators get gradients)
        x_ba_dis_out = self.dis_a(x_ba.detach())
        x_a_dis_out = self.dis_a(x_a)
        self.loss_dis_a = self.compute_dis_loss(x_ba_dis_out, x_a_dis_out)
        x_ab_dis_out = self.dis_b(x_ab.detach())
        x_b_dis_out = self.dis_b(x_b)
        self.loss_dis_b = self.compute_dis_loss(x_ab_dis_out, x_b_dis_out)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + \
            hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def compute_content_classifier_loss(self, label_predict, label_true):
        """Cross-entropy between classifier logits and integer labels."""
        loss = self.criterion_content_classifier(label_predict, label_true)
        return loss

    def compute_content_classifier_two_predictions_loss(
            self, label_predict_1, label_predict_2):
        """Mean absolute difference between two prediction tensors."""
        loss = torch.mean(torch.abs(label_predict_1 - label_predict_2))
        return loss

    def compute_content_classifier_accuracy(self, label_predict, label_true):
        """Return the fraction of rows whose arg-max over dim 1 equals the label.

        The denominator is the number of labels actually supplied, so a short
        final batch is scored correctly instead of being divided by the
        configured ``batch_size_val``.  Returns 0.0 for an empty batch.
        """
        _, indices = label_predict.max(1)
        total_correct = int((label_true == indices).sum().item())
        num_samples = label_true.numel()
        if num_samples == 0:
            return 0.0
        return float(total_correct) / float(num_samples)

    def compute_dis_loss(self, outs_fake, outs_real):
        """Discriminator loss summed over paired outputs' ``'output_d'`` entries.

        Supports ``'lsgan'`` (fake -> 0, real -> 1 squared error) and
        ``'nsgan'`` (binary cross-entropy on sigmoid outputs).
        """
        loss = 0
        for out_fake, out_real in zip(outs_fake, outs_real):
            out_fake = out_fake['output_d']
            out_real = out_real['output_d']
            if self.gan_type == 'lsgan':
                loss += torch.mean((out_fake - 0)**2) + torch.mean(
                    (out_real - 1)**2)
            elif self.gan_type == 'nsgan':
                all0 = Variable(torch.zeros_like(out_fake.data).cuda(),
                                requires_grad=False)
                all1 = Variable(torch.ones_like(out_real.data).cuda(),
                                requires_grad=False)
                loss += torch.mean(
                    F.binary_cross_entropy(torch.sigmoid(out_fake), all0) +
                    F.binary_cross_entropy(torch.sigmoid(out_real), all1))
            else:
                # Explicit raise (same exception type as the old `assert 0`)
                # so the check survives `python -O`.
                raise AssertionError(
                    "Unsupported GAN type: {}".format(self.gan_type))
        return loss

    def compute_gen_adv_loss(self, outs_fake):
        """Generator adversarial loss summed over ``'output_d'`` entries."""
        loss = 0
        for out_fake in outs_fake:
            out_fake = out_fake['output_d']
            if self.gan_type == 'lsgan':
                loss += torch.mean((out_fake - 1)**2)  # LSGAN
            elif self.gan_type == 'nsgan':
                all1 = Variable(torch.ones_like(out_fake.data).cuda(),
                                requires_grad=False)
                loss += torch.mean(
                    F.binary_cross_entropy(torch.sigmoid(out_fake), all1))
            else:
                raise AssertionError(
                    "Unsupported GAN type: {}".format(self.gan_type))
        return loss

    def update_learning_rate(self):
        """Step every configured LR scheduler, including the classifier's.

        ``cla_scheduler`` is built in ``__init__`` and rebuilt in ``resume``;
        it was previously never stepped, so its policy never took effect.
        """
        for scheduler in (self.dis_scheduler, self.gen_scheduler,
                          self.cla_scheduler):
            if scheduler is not None:
                scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore networks, optimizers and schedulers from ``checkpoint_dir``.

        Returns the iteration count parsed from the latest generator file name.
        """
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        # save() writes 'gen_%08d.pt', so the 8 digits precede '.pt'.
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load content classifier
        last_model_name = get_model_list(checkpoint_dir, "con_cla")
        state_dict = torch.load(last_model_name)
        self.content_classifier.load_state_dict(state_dict['con_cla'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        self.cla_opt.load_state_dict(state_dict['con_cla'])
        # Reinitialize schedulers at the restored iteration
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        self.cla_scheduler = get_scheduler(self.cla_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Save generators, discriminators, classifier and optimizers."""
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        con_cla_name = os.path.join(snapshot_dir,
                                    'con_cla_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save({'con_cla': self.content_classifier.state_dict()},
                   con_cla_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict(),
                'con_cla': self.cla_opt.state_dict()
            }, opt_name)
class aclgan_Trainer(nn.Module):
    """ACL-GAN style trainer.

    Holds two AdaIN generators (``gen_AB`` and ``gen_BA``, both built with
    ``input_dim_a``), discriminators ``dis_A`` / ``dis_B`` for single images,
    and ``dis_2``, which scores image pairs concatenated along dim -3
    (``input_dim_b`` is presumably the channel count of such a pair --
    confirm against the config).  When ``hyperparameters['focus_loss'] > 0``,
    generator outputs are split into 3 image channels plus a focus mask that
    blends the translation with its source image.
    """

    def __init__(self, hyperparameters):
        super(aclgan_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_AB = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain A
        self.gen_BA = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain B
        self.dis_A = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain A
        self.dis_B = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain B
        self.dis_2 = MsImageDis(hyperparameters['input_dim_b'],
                                hyperparameters['dis'])  # discriminator 2
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.z_1 = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.z_2 = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.z_3 = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_A.parameters()) + list(
            self.dis_B.parameters()) + list(self.dis_2.parameters())
        gen_params = list(self.gen_AB.parameters()) + list(
            self.gen_BA.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        # Scale applied to the random style fed to gen_BA for x_A_fake.
        self.alpha = hyperparameters['alpha']
        self.focus_lam = hyperparameters['focus_loss']

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_A.apply(weights_init('gaussian'))
        self.dis_B.apply(weights_init('gaussian'))
        self.dis_2.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        """Mean absolute (L1) difference between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        """Run one translation cycle and cache the results on ``self``.

        Stores ``x_B_fake``, ``x_A_fake``, ``x_A_recon``, ``x_B_recon``,
        ``x_A2_fake`` and the two channel-concatenated pairs; returns None.
        """
        z_1 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        z_2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        z_3 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_1, _ = self.gen_AB.encode(x_a)
        c_2, s_2 = self.gen_BA.encode(x_a)
        c_4, s_4 = self.gen_AB.encode(x_b)
        # decode
        self.x_B_fake = self.gen_AB.decode(c_1, z_1)
        self.x_A_fake = self.gen_BA.decode(c_2, z_2)
        # recon
        self.x_A_recon = self.gen_BA.decode(c_2, s_2)
        self.x_B_recon = self.gen_AB.decode(c_4, s_4)
        # encode 2
        c_3, _ = self.gen_BA.encode(self.x_B_fake)
        self.x_A2_fake = self.gen_BA.decode(c_3, z_3)
        self.X_A_A1_pair = torch.cat((x_a, self.x_A_fake), -3)
        self.X_A_A2_pair = torch.cat((x_a, self.x_A2_fake), -3)

    def focus_translation(self, x_fg, x_bg, x_focus):
        """Blend ``x_fg`` over ``x_bg`` using the focus mask ``x_focus``.

        The mask is rescaled with ``(m + 1) / 2`` (assumes values in [-1, 1])
        and repeated to 3 channels (assumes a single-channel mask).
        """
        x_map = (x_focus + 1) / 2
        x_map = x_map.repeat(1, 3, 1, 1)
        return torch.mul(x_fg, x_map) + torch.mul(x_bg, 1 - x_map)

    @staticmethod
    def _focus_mask_losses(mask, lower, upper, delta, epsilon):
        """Return ``(size_loss, binarization_loss)`` for a rescaled focus mask.

        ``size_loss`` is a squared hinge penalty (weighted by ``delta``) on
        the summed mask leaving ``[lower, upper]``; ``binarization_loss``
        sums ``1 / (|mask - 0.5| + epsilon)``, penalizing values near 0.5.
        """
        size_loss = (F.relu(torch.sum(mask - upper), inplace=True) ** 2) * delta + \
            (F.relu(torch.sum(lower - mask), inplace=True) ** 2) * delta
        digit_loss = torch.sum(1 / (torch.abs(mask - 0.5) + epsilon))
        return size_loss, digit_loss

    def gen_update(self, x_a, x_b, hyperparameters):
        """Run one generator optimization step.

        Combines the adversarial losses from ``dis_A``, ``dis_B`` and the pair
        discriminator ``dis_2``, the optional focus-mask losses, and the
        identity reconstruction losses; terms are stored on ``self``.
        """
        self.gen_opt.zero_grad()
        focus_delta = hyperparameters['focus_delta']
        focus_lambda = hyperparameters['focus_loss']
        focus_lower = hyperparameters['focus_lower']
        focus_upper = hyperparameters['focus_upper']
        focus_epsilon = hyperparameters['focus_epsilon']
        # forward
        z_1 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        z_2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        z_3 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_1, _ = self.gen_AB.encode(x_a)
        c_2, s_2 = self.gen_BA.encode(x_a)
        c_4, s_4 = self.gen_AB.encode(x_b)
        # decode
        if focus_lambda > 0:
            x_B_fake, x_B_focus = self.gen_AB.decode(c_1, z_1).split(3, 1)
            x_A_fake, x_A_focus = self.gen_BA.decode(
                c_2, self.alpha * z_2).split(3, 1)
            x_B_fake = self.focus_translation(x_B_fake, x_a, x_B_focus)
            x_A_fake = self.focus_translation(x_A_fake, x_a, x_A_focus)
            # recon (masks are produced but not applied to reconstructions)
            x_A_recon, _ = self.gen_BA.decode(c_2, s_2).split(3, 1)
            x_B_recon, _ = self.gen_AB.decode(c_4, s_4).split(3, 1)
        else:
            x_B_fake = self.gen_AB.decode(c_1, z_1)
            x_A_fake = self.gen_BA.decode(c_2, self.alpha * z_2)
            # recon
            x_A_recon = self.gen_BA.decode(c_2, s_2)
            x_B_recon = self.gen_AB.decode(c_4, s_4)
        # encode 2
        c_3, _ = self.gen_BA.encode(x_B_fake)
        if focus_lambda > 0:
            x_A2_fake, x_A2_focus = self.gen_BA.decode(c_3, z_3).split(3, 1)
            x_A2_fake = self.focus_translation(x_A2_fake, x_B_fake, x_A2_focus)
        else:
            x_A2_fake = self.gen_BA.decode(c_3, z_3)

        x_A_A1_pair = torch.cat((x_a, x_A_fake), -3)
        x_A_A2_pair = torch.cat((x_a, x_A2_fake), -3)

        # GAN loss
        self.loss_gen_adv_A = (self.dis_A.calc_gen_loss(x_A_fake) +
                               self.dis_A.calc_gen_loss(x_A2_fake)) * 0.5
        self.loss_gen_adv_B = self.dis_B.calc_gen_loss(x_B_fake)
        self.loss_gen_adv_2 = self.dis_2.calc_gen_d2_loss(
            x_A_A1_pair, x_A_A2_pair)

        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_A + \
            hyperparameters['gan_w'] * self.loss_gen_adv_B + \
            hyperparameters['gan_cw'] * self.loss_gen_adv_2

        if focus_lambda > 0:
            mask_B = (x_B_focus + 1) / 2
            mask_A = (x_A_focus + 1) / 2
            mask_A2 = (x_A2_focus + 1) / 2
            self.loss_gen_focus_B_size, self.loss_gen_focus_B_digit = \
                self._focus_mask_losses(mask_B, focus_lower, focus_upper,
                                        focus_delta, focus_epsilon)
            self.loss_gen_focus_A_size, self.loss_gen_focus_A_digit = \
                self._focus_mask_losses(mask_A, focus_lower, focus_upper,
                                        focus_delta, focus_epsilon)
            self.loss_gen_focus_A2_size, self.loss_gen_focus_A2_digit = \
                self._focus_mask_losses(mask_A2, focus_lower, focus_upper,
                                        focus_delta, focus_epsilon)
            # Normalize by height * width * batch * 3 masks.
            self.loss_gen_total += focus_lambda * (
                self.loss_gen_focus_B_size + self.loss_gen_focus_B_digit +
                self.loss_gen_focus_A_size + self.loss_gen_focus_A_digit +
                self.loss_gen_focus_A2_size + self.loss_gen_focus_A2_digit
            ) / x_a.size(2) / x_a.size(3) / x_a.size(0) / 3

        self.loss_idt_A = self.recon_criterion(x_A_recon, x_a)
        self.loss_idt_B = self.recon_criterion(x_B_recon, x_b)
        self.loss_gen_total += hyperparameters['recon_x_w'] * self.loss_idt_A + \
            hyperparameters['recon_x_w'] * self.loss_idt_B

        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        """Mean squared distance between instance-normalized VGG features."""
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def sample(self, x_a, x_b):
        """Produce per-sample translations (and masks, in focus mode) for display.

        Uses the fixed noise ``self.z_1/z_2/z_3`` indexed per sample, so the
        batch size must not exceed ``display_size``.  Runs in eval mode and
        restores train mode before returning.
        """
        self.eval()
        z_1 = Variable(self.z_1)
        z_2 = Variable(self.z_2)
        z_3 = Variable(self.z_3)
        x_A, x_B, x_A_fake, x_B_fake, x_A2_fake = [], [], [], [], []
        if self.focus_lam > 0:
            mask_A, mask_B, mask_A2, mask_recon = [], [], [], []
            x_A_recon = []
        else:
            x_A_recon, x_B_recon = [], []
        for i in range(x_a.size(0)):
            x_a_i = x_a[i].unsqueeze(0)
            x_b_i = x_b[i].unsqueeze(0)
            x_A.append(x_a_i)
            x_B.append(x_b_i)
            if self.focus_lam > 0:
                c_1, s_1 = self.gen_BA.encode(x_a_i)
                img, mask = self.gen_BA.decode(c_1, z_1[i].unsqueeze(0)).split(
                    3, 1)
                x_A_fake.append(self.focus_translation(img, x_a_i, mask))
                mask_A.append(mask)
                img, mask = self.gen_BA.decode(c_1, s_1).split(3, 1)
                x_A_recon.append(img)
                mask_recon.append(mask)
                c_2, _ = self.gen_AB.encode(x_a_i)
                x_b_img, mask = self.gen_AB.decode(
                    c_2, z_2[i].unsqueeze(0)).split(3, 1)
                x_b_img = self.focus_translation(x_b_img, x_a_i, mask)
                x_B_fake.append(x_b_img)
                mask_B.append(mask)
                c_3, _ = self.gen_BA.encode(x_b_img)
                img, mask = self.gen_BA.decode(c_3, z_3[i].unsqueeze(0)).split(
                    3, 1)
                x_A2_fake.append(self.focus_translation(img, x_b_img, mask))
                mask_A2.append(mask)
            else:
                c_1, s_1 = self.gen_BA.encode(x_a_i)
                x_A_fake.append(self.gen_BA.decode(c_1, z_1[i].unsqueeze(0)))
                x_A_recon.append(self.gen_BA.decode(c_1, s_1))
                c_2, _ = self.gen_AB.encode(x_a_i)
                x_B1 = self.gen_AB.decode(c_2, z_2[i].unsqueeze(0))
                x_B_fake.append(x_B1)
                c_3, _ = self.gen_BA.encode(x_B1)
                x_A2_fake.append(self.gen_BA.decode(c_3, z_3[i].unsqueeze(0)))
                # Reconstruct only the i-th B sample; encoding the whole x_b
                # batch here produced batch_size**2 reconstructions.
                c_4, s_4 = self.gen_AB.encode(x_b_i)
                x_B_recon.append(self.gen_AB.decode(c_4, s_4))

        if self.focus_lam > 0:
            x_A, x_B = torch.cat(x_A), torch.cat(x_B)
            x_A_fake, x_B_fake = torch.cat(x_A_fake), torch.cat(x_B_fake)
            mask_A, x_A2_fake = torch.cat(mask_A), torch.cat(x_A2_fake)
            mask_B, mask_recon = torch.cat(mask_B), torch.cat(mask_recon)
            mask_A2, x_A_recon = torch.cat(mask_A2), torch.cat(x_A_recon)
            self.train()
            return x_A, x_A_fake, mask_A, x_B_fake, mask_B, x_A2_fake, mask_A2, x_A_recon, mask_recon
        else:
            x_A, x_B = torch.cat(x_A), torch.cat(x_B)
            x_A_fake, x_B_fake = torch.cat(x_A_fake), torch.cat(x_B_fake)
            x_A_recon, x_A2_fake = torch.cat(x_A_recon), torch.cat(x_A2_fake)
            x_B_recon = torch.cat(x_B_recon)
            self.train()
            return x_A, x_A_fake, x_B_fake, x_A2_fake, x_A_recon, x_B, x_B_recon

    def dis_update(self, x_a, x_b, hyperparameters):
        """Run one optimization step for ``dis_A``, ``dis_B`` and ``dis_2``."""
        self.dis_opt.zero_grad()
        focus_lambda = hyperparameters['focus_loss']
        # forward
        z_1 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        z_2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        z_3 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        # Only the discriminators are optimized here.  Generating the fakes
        # without an autograd graph is equivalent to detaching them and stops
        # the discriminator loss from back-propagating into the generators.
        with torch.no_grad():
            # encode
            c_1, _ = self.gen_AB.encode(x_a)
            c_2, _ = self.gen_BA.encode(x_a)
            # decode
            if focus_lambda > 0:
                x_B_fake, x_B_focus = self.gen_AB.decode(c_1, z_1).split(3, 1)
                x_A_fake, x_A_focus = self.gen_BA.decode(
                    c_2, self.alpha * z_2).split(3, 1)
                x_B_fake = self.focus_translation(x_B_fake, x_a, x_B_focus)
                x_A_fake = self.focus_translation(x_A_fake, x_a, x_A_focus)
            else:
                x_B_fake = self.gen_AB.decode(c_1, z_1)
                x_A_fake = self.gen_BA.decode(c_2, self.alpha * z_2)
            # encode 2
            c_3, _ = self.gen_BA.encode(x_B_fake)
            if focus_lambda > 0:
                x_A2_fake, x_A2_focus = self.gen_BA.decode(c_3, z_3).split(3, 1)
                x_A2_fake = self.focus_translation(x_A2_fake, x_B_fake,
                                                   x_A2_focus)
            else:
                x_A2_fake = self.gen_BA.decode(c_3, z_3)

            x_A_A1_pair = torch.cat((x_a, x_A_fake), -3)
            x_A_A2_pair = torch.cat((x_a, x_A2_fake), -3)

        # D loss
        self.loss_dis_A = (self.dis_A.calc_dis_loss(x_A_fake, x_a) +
                           self.dis_A.calc_dis_loss(x_A2_fake, x_a)) * 0.5
        self.loss_dis_B = self.dis_B.calc_dis_loss(x_B_fake, x_b)
        self.loss_dis_2 = self.dis_2.calc_dis_loss(x_A_A1_pair, x_A_A2_pair)

        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_A + \
            hyperparameters['gan_w'] * self.loss_dis_B + \
            hyperparameters['gan_cw'] * self.loss_dis_2

        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Step the discriminator and generator LR schedulers, if configured."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore networks, optimizers and schedulers from ``checkpoint_dir``.

        Returns the iteration count parsed from the latest generator file name.
        """
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_AB.load_state_dict(state_dict['AB'])
        self.gen_BA.load_state_dict(state_dict['BA'])
        # save() writes 'gen_%08d.pt', so the 8 digits precede '.pt'.
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_A.load_state_dict(state_dict['A'])
        self.dis_B.load_state_dict(state_dict['B'])
        self.dis_2.load_state_dict(state_dict['2'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers at the restored iteration
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Save generators, discriminators and optimizers to ``snapshot_dir``."""
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save(
            {
                'AB': self.gen_AB.state_dict(),
                'BA': self.gen_BA.state_dict()
            }, gen_name)
        torch.save(
            {
                'A': self.dis_A.state_dict(),
                'B': self.dis_B.state_dict(),
                '2': self.dis_2.state_dict()
            }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
class UNIT_Trainer(nn.Module):
    """Trainer for UNIT: two ``VAEGen`` auto-encoders and two ``MsImageDis`` discriminators.

    Owns the networks for domains a and b, their Adam optimizers and LR
    schedulers, and the generator / discriminator update steps. Loss values of
    the last update are stored on ``self`` (``loss_gen_*`` / ``loss_dis_*``).

    NOTE(review): ``UNIT_Trainer`` is defined more than once in this module;
    the last definition is the one bound at import time.
    """

    def __init__(self, hyperparameters):
        super(UNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = VAEGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = VAEGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        # Applied to VGG feature maps in compute_vgg_loss.
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed; it is frozen and only used for the perceptual loss.
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        """Mean absolute error (L1) between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        """Translate ``x_a`` to domain b and ``x_b`` to domain a without tracking gradients.

        The module is put in eval mode for the call and its previous
        train/eval mode is restored afterwards, even if decoding raises.

        Returns:
            ``(x_ab, x_ba)``.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                h_a, _ = self.gen_a.encode(x_a)
                h_b, _ = self.gen_b.encode(x_b)
                x_ba = self.gen_a.decode(h_b)
                x_ab = self.gen_b.decode(h_a)
        finally:
            self.train(was_training)
        return x_ab, x_ba

    def __compute_kl(self, mu):
        """Simplified KL penalty on the latent code: the mean of ``mu ** 2``."""
        mu_2 = torch.pow(mu, 2)
        encoding_loss = torch.mean(mu_2)
        return encoding_loss

    def gen_update(self, x_a, x_b, hyperparameters):
        """Run one generator optimisation step on a batch from each domain.

        Args:
            x_a: batch of images from domain a.
            x_b: batch of images from domain b.
            hyperparameters: loss weights ``gan_w``, ``recon_x_w``, ``recon_kl_w``,
                ``recon_x_cyc_w``, ``recon_kl_cyc_w`` and optionally ``vgg_w``
                (treated as 0 when absent, matching ``__init__``).
        """
        self.gen_opt.zero_grad()
        cyc_w = hyperparameters['recon_x_cyc_w']
        vgg_w = hyperparameters.get('vgg_w', 0)
        # encode: each encoder returns a latent code h and a term n added before decoding
        h_a, n_a = self.gen_a.encode(x_a)
        h_b, n_b = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(h_a + n_a)
        x_b_recon = self.gen_b.decode(h_b + n_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(h_b + n_b)
        x_ab = self.gen_b.decode(h_a + n_a)
        # encode again
        h_b_recon, n_b_recon = self.gen_a.encode(x_ba)
        h_a_recon, n_a_recon = self.gen_b.encode(x_ab)
        # decode again (only needed for the cycle loss)
        x_aba = self.gen_a.decode(h_a_recon + n_a_recon) if cyc_w > 0 else None
        x_bab = self.gen_b.decode(h_b_recon + n_b_recon) if cyc_w > 0 else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_kl_a = self.__compute_kl(h_a)
        self.loss_gen_recon_kl_b = self.__compute_kl(h_b)
        # x_aba / x_bab are None when the cycle weight is 0, so the loss is 0 then
        # (previously this called recon_criterion(None, ...) and crashed).
        self.loss_gen_cyc_x_a = self.recon_criterion(x_aba, x_a) if cyc_w > 0 else 0
        self.loss_gen_cyc_x_b = self.recon_criterion(x_bab, x_b) if cyc_w > 0 else 0
        self.loss_gen_recon_kl_cyc_aba = self.__compute_kl(h_a_recon)
        self.loss_gen_recon_kl_cyc_bab = self.__compute_kl(h_b_recon)
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if vgg_w > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if vgg_w > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \
                              cyc_w * self.loss_gen_cyc_x_a + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \
                              cyc_w * self.loss_gen_cyc_x_b + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab + \
                              vgg_w * self.loss_gen_vgg_a + \
                              vgg_w * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        """Mean squared difference of instance-normalised VGG features of ``img`` and ``target``."""
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def sample(self, x_a, x_b):
        """Reconstruct and translate each image of the batches, one image at a time.

        Runs in eval mode without gradient tracking and restores the previous
        train/eval mode afterwards.

        Returns:
            ``(x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba)``.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                x_a_recon, x_b_recon, x_ba, x_ab = [], [], [], []
                for i in range(x_a.size(0)):
                    h_a, _ = self.gen_a.encode(x_a[i].unsqueeze(0))
                    h_b, _ = self.gen_b.encode(x_b[i].unsqueeze(0))
                    x_a_recon.append(self.gen_a.decode(h_a))
                    x_b_recon.append(self.gen_b.decode(h_b))
                    x_ba.append(self.gen_a.decode(h_b))
                    x_ab.append(self.gen_b.decode(h_a))
                x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
                x_ba = torch.cat(x_ba)
                x_ab = torch.cat(x_ab)
        finally:
            self.train(was_training)
        return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba

    def dis_update(self, x_a, x_b, hyperparameters):
        """Run one discriminator optimisation step on cross-domain fakes vs. real images."""
        self.dis_opt.zero_grad()
        # The fakes only feed the discriminator, so no generator graph is needed.
        with torch.no_grad():
            # encode
            h_a, n_a = self.gen_a.encode(x_a)
            h_b, n_b = self.gen_b.encode(x_b)
            # decode (cross domain)
            x_ba = self.gen_a.decode(h_b + n_b)
            x_ab = self.gen_b.decode(h_a + n_a)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba, x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab, x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Advance both LR schedulers, if configured."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Load the latest generator/discriminator/optimizer checkpoints from ``checkpoint_dir``.

        The iteration count is parsed from the generator file name
        (``gen_%08d.pt``, as written by ``save``), and the schedulers are
        rebuilt at that iteration.

        Returns:
            The resumed iteration number.
        """
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Save generators, discriminators and optimizers into ``snapshot_dir``."""
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
class UNIT_Trainer(nn.Module):
    """Trainer for UNIT: two ``VAEGen`` auto-encoders and two ``MsImageDis`` discriminators.

    Owns the networks for domains a and b, their Adam optimizers and LR
    schedulers, and the generator / discriminator update steps. Loss values of
    the last update are stored on ``self`` (``loss_gen_*`` / ``loss_dis_*``).

    NOTE(review): an earlier ``UNIT_Trainer`` definition exists in this module;
    this later one shadows it at import time.
    """

    def __init__(self, hyperparameters):
        super(UNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = VAEGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = VAEGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        # Applied to VGG feature maps in compute_vgg_loss.
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed; it is frozen and only used for the perceptual loss.
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        """Mean absolute error (L1) between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        """Translate ``x_a`` to domain b and ``x_b`` to domain a without tracking gradients.

        ``torch.no_grad()`` replaces the old ``x.volatile = True`` idiom, which
        is a no-op in current PyTorch and used to mutate the caller's inputs.
        The previous train/eval mode is restored afterwards.

        Returns:
            ``(x_ab, x_ba)``.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                h_a, _ = self.gen_a.encode(x_a)
                h_b, _ = self.gen_b.encode(x_b)
                x_ba = self.gen_a.decode(h_b)
                x_ab = self.gen_b.decode(h_a)
        finally:
            self.train(was_training)
        return x_ab, x_ba

    def __compute_kl(self, mu):
        """Simplified KL penalty on the latent code: the mean of ``mu ** 2``."""
        mu_2 = torch.pow(mu, 2)
        encoding_loss = torch.mean(mu_2)
        return encoding_loss

    def gen_update(self, x_a, x_b, hyperparameters):
        """Run one generator optimisation step on a batch from each domain.

        Args:
            x_a: batch of images from domain a.
            x_b: batch of images from domain b.
            hyperparameters: loss weights ``gan_w``, ``recon_x_w``, ``recon_kl_w``,
                ``recon_x_cyc_w``, ``recon_kl_cyc_w`` and optionally ``vgg_w``
                (treated as 0 when absent, matching ``__init__``).
        """
        self.gen_opt.zero_grad()
        cyc_w = hyperparameters['recon_x_cyc_w']
        vgg_w = hyperparameters.get('vgg_w', 0)
        # encode: each encoder returns a latent code h and a term n added before decoding
        h_a, n_a = self.gen_a.encode(x_a)
        h_b, n_b = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(h_a + n_a)
        x_b_recon = self.gen_b.decode(h_b + n_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(h_b + n_b)
        x_ab = self.gen_b.decode(h_a + n_a)
        # encode again
        h_b_recon, n_b_recon = self.gen_a.encode(x_ba)
        h_a_recon, n_a_recon = self.gen_b.encode(x_ab)
        # decode again (only needed for the cycle loss)
        x_aba = self.gen_a.decode(h_a_recon + n_a_recon) if cyc_w > 0 else None
        x_bab = self.gen_b.decode(h_b_recon + n_b_recon) if cyc_w > 0 else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_kl_a = self.__compute_kl(h_a)
        self.loss_gen_recon_kl_b = self.__compute_kl(h_b)
        # x_aba / x_bab are None when the cycle weight is 0, so the loss is 0 then
        # (previously this called recon_criterion(None, ...) and crashed).
        self.loss_gen_cyc_x_a = self.recon_criterion(x_aba, x_a) if cyc_w > 0 else 0
        self.loss_gen_cyc_x_b = self.recon_criterion(x_bab, x_b) if cyc_w > 0 else 0
        self.loss_gen_recon_kl_cyc_aba = self.__compute_kl(h_a_recon)
        self.loss_gen_recon_kl_cyc_bab = self.__compute_kl(h_b_recon)
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if vgg_w > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if vgg_w > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \
                              cyc_w * self.loss_gen_cyc_x_a + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \
                              cyc_w * self.loss_gen_cyc_x_b + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab + \
                              vgg_w * self.loss_gen_vgg_a + \
                              vgg_w * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        """Mean squared difference of instance-normalised VGG features of ``img`` and ``target``."""
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def sample(self, x_a, x_b):
        """Reconstruct and translate each image of the batches, one image at a time.

        Runs in eval mode without gradient tracking (instead of setting the
        removed ``volatile`` flag on the inputs) and restores the previous
        train/eval mode afterwards.

        Returns:
            ``(x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba)``.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                x_a_recon, x_b_recon, x_ba, x_ab = [], [], [], []
                for i in range(x_a.size(0)):
                    h_a, _ = self.gen_a.encode(x_a[i].unsqueeze(0))
                    h_b, _ = self.gen_b.encode(x_b[i].unsqueeze(0))
                    x_a_recon.append(self.gen_a.decode(h_a))
                    x_b_recon.append(self.gen_b.decode(h_b))
                    x_ba.append(self.gen_a.decode(h_b))
                    x_ab.append(self.gen_b.decode(h_a))
                x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
                x_ba = torch.cat(x_ba)
                x_ab = torch.cat(x_ab)
        finally:
            self.train(was_training)
        return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba

    def dis_update(self, x_a, x_b, hyperparameters):
        """Run one discriminator optimisation step on cross-domain fakes vs. real images."""
        self.dis_opt.zero_grad()
        # The fakes only feed the discriminator, so no generator graph is needed.
        with torch.no_grad():
            # encode
            h_a, n_a = self.gen_a.encode(x_a)
            h_b, n_b = self.gen_b.encode(x_b)
            # decode (cross domain)
            x_ba = self.gen_a.decode(h_b + n_b)
            x_ab = self.gen_b.decode(h_a + n_a)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba, x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab, x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Advance both LR schedulers, if configured."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Load the latest generator/discriminator/optimizer checkpoints from ``checkpoint_dir``.

        The iteration count is parsed from the generator file name
        (``gen_%08d.pt``, as written by ``save``), and the schedulers are
        rebuilt at that iteration.

        Returns:
            The resumed iteration number.
        """
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Save generators, discriminators and optimizers into ``snapshot_dir``."""
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
class MUNIT_Trainer(nn.Module):
    """MUNIT trainer variant with an extra "T" encode/decode path on generator a.

    Holds two ``AdaINGen`` generators (content + style auto-encoders), two
    ``MsImageDis`` discriminators, their Adam optimizers and LR schedulers.
    Cross-domain translations use the style code encoded from the real batch of
    the target domain rather than random style noise. Loss values of the last
    update are stored on ``self`` (``loss_gen_*`` / ``loss_dis_*``).

    NOTE(review): this redefines the ``MUNIT_Trainer`` declared earlier in the
    module; this later definition is the one bound at import time.
    """

    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        # Applied to VGG feature maps in compute_vgg_loss.
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # fix the noise used in sampling
        # NOTE(review): ``cun`` is not defined in this class; presumably a
        # module-level CUDA device index — confirm.
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda(cun)
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda(cun)

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed; it is frozen and only used for the perceptual loss.
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        """Mean absolute error (L1) between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        """Swap styles between the batches without tracking gradients.

        ``x_ba`` combines the content of ``x_b`` with the style encoded from
        ``x_a``, and ``x_ab`` combines the content of ``x_a`` with the style
        encoded from ``x_b``. The two batches therefore presumably need the same
        batch size. The previous train/eval mode is restored afterwards.

        Returns:
            ``(x_ab, x_ba)``.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                c_a, s_a_fake = self.gen_a.encode(x_a)
                c_b, s_b_fake = self.gen_b.encode(x_b)
                x_ba = self.gen_a.decode(c_b, s_a_fake)
                x_ab = self.gen_b.decode(c_a, s_b_fake)
        finally:
            self.train(was_training)
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        """Run one generator optimisation step on a batch from each domain.

        Args:
            x_a: batch of images from domain a.
            x_b: batch of images from domain b.
            hyperparameters: loss weights ``gan_w``, ``recon_x_w``, ``recon_s_w``,
                ``recon_c_w``, ``recon_x_cyc_w``. Optional keys:

                * ``vgg_w`` (default 0);
                * ``recon_cT_w`` — weight of the content loss between ``encode``
                  and ``encodeT`` (default 5, the previously hard-coded value);
                * ``recon_xT_w`` — weight of the ``decodeT`` reconstruction loss
                  (default 5, the previously hard-coded value).
        """
        self.gen_opt.zero_grad()
        cyc_w = hyperparameters['recon_x_cyc_w']
        vgg_w = hyperparameters.get('vgg_w', 0)
        content_t_w = hyperparameters.get('recon_cT_w', 5)
        style_t_w = hyperparameters.get('recon_xT_w', 5)
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        c_aT, s_aT_prime = self.gen_a.encodeT(x_a)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_aT_recon = self.gen_a.decodeT(c_a, s_aT_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain) with the style encoded from the target-domain batch
        x_ba = self.gen_a.decode(c_b, s_a_prime)
        x_ab = self.gen_b.decode(c_a, s_b_prime)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (only needed for the cycle loss)
        x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if cyc_w > 0 else None
        x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if cyc_w > 0 else None

        # "T"-path losses: decodeT must reproduce x_a, and encodeT's content code
        # must match encode's.
        self.loss_gen_styleT = self.recon_criterion(x_a, x_aT_recon)
        self.loss_gen_content = self.recon_criterion(c_a, c_aT)
        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a_prime)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b_prime)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if cyc_w > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if cyc_w > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if vgg_w > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if vgg_w > 0 else 0
        # total loss (the two "T" terms are added last, as before)
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              cyc_w * self.loss_gen_cycrecon_x_a + \
                              cyc_w * self.loss_gen_cycrecon_x_b + \
                              vgg_w * self.loss_gen_vgg_a + \
                              vgg_w * self.loss_gen_vgg_b + \
                              content_t_w * self.loss_gen_content + \
                              style_t_w * self.loss_gen_styleT
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        """Mean squared difference of instance-normalised VGG features of ``img`` and ``target``."""
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def sample(self, x_a, x_b):
        """Reconstruct both batches and swap their styles, for visualisation.

        Runs in eval mode without gradient tracking and restores the previous
        train/eval mode afterwards.

        Returns:
            ``(x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba)``.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                c_a, s_a_fake = self.gen_a.encode(x_a)
                c_b, s_b_fake = self.gen_b.encode(x_b)
                x_a_recon = self.gen_a.decode(c_a, s_a_fake)
                x_b_recon = self.gen_b.decode(c_b, s_b_fake)
                x_ba = self.gen_a.decode(c_b, s_a_fake)
                x_ab = self.gen_b.decode(c_a, s_b_fake)
        finally:
            self.train(was_training)
        return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba

    def dis_update(self, x_a, x_b, hyperparameters):
        """Run one discriminator optimisation step on cross-domain fakes vs. real images.

        The fakes are built exactly as in ``gen_update``: ``x_ba`` uses the style
        encoded from ``x_a`` and ``x_ab`` the style encoded from ``x_b``.
        Previously both decodes reused the throwaway ``_`` from ``gen_b.encode``,
        so ``x_ba`` was rendered with domain-b style.
        """
        self.dis_opt.zero_grad()
        # The fakes only feed the discriminator, so no generator graph is needed.
        with torch.no_grad():
            # encode
            c_a, s_a_prime = self.gen_a.encode(x_a)
            c_b, s_b_prime = self.gen_b.encode(x_b)
            # decode (cross domain)
            x_ba = self.gen_a.decode(c_b, s_a_prime)
            x_ab = self.gen_b.decode(c_a, s_b_prime)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba, x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab, x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Advance both LR schedulers, if configured."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Load the latest generator/discriminator/optimizer checkpoints from ``checkpoint_dir``.

        The iteration count is parsed from the generator file name
        (``gen_%08d.pt``, as written by ``save``), and the schedulers are
        rebuilt at that iteration.

        Returns:
            The resumed iteration number.
        """
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Save generators, discriminators and optimizers into ``snapshot_dir`` (each file once)."""
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
class DSMAP_Trainer(nn.Module):
    """Two-domain image-translation trainer with a shared content encoder.

    Holds one ``AdaINGen`` per domain, both constructed around the same
    ``ContentEncoder_share`` instance, and one ``MsImageDis`` per domain.
    ``build_optimizer`` must be called after construction: it creates the
    optimizers and schedulers, initialises weights and, when
    ``hyperparameters['vgg_w'] > 0``, attaches a frozen VGG16 used by the
    perceptual loss.
    """

    def __init__(self, hyperparameters):
        super(DSMAP_Trainer, self).__init__()
        gen_hp = hyperparameters['gen']
        # Initiate the networks
        mid_downsample = gen_hp.get('mid_downsample', 1)
        # A single content encoder instance, handed to both generators below.
        self.content_enc = ContentEncoder_share(
            gen_hp['n_downsample'], mid_downsample, gen_hp['n_res'],
            hyperparameters['input_dim_a'], gen_hp['dim'], 'in',
            gen_hp['activ'], pad_type=gen_hp['pad_type'])
        self.style_dim = gen_hp['style_dim']
        # auto-encoders for domain a and domain b
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'], self.content_enc, 'a', gen_hp)
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'], self.content_enc, 'b', gen_hp)
        # discriminators for domain a and domain b
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'], self.content_enc.output_dim,
            hyperparameters['dis'])
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'], self.content_enc.output_dim,
            hyperparameters['dis'])

    def build_optimizer(self, hyperparameters):
        """Create Adam optimizers/schedulers, initialise weights, load VGG.

        The order matters: VGG16 is attached only after
        ``self.apply(weights_init(...))`` so its pretrained weights are not
        re-initialised, and after the optimizers are built so it is not
        optimised.
        """
        lr = hyperparameters['lr']
        betas = (hyperparameters['beta1'], hyperparameters['beta2'])
        weight_decay = hyperparameters['weight_decay']
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        # NOTE(review): if AdaINGen registers the shared content encoder as a
        # submodule, its parameters appear in both lists below and Adam gets
        # them twice (PyTorch warns about duplicate parameters in a group).
        # Deduplicating would change the optimizer.pt layout and break
        # resuming existing checkpoints, so it is left unchanged here.
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr, betas=betas, weight_decay=weight_decay)
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr, betas=betas, weight_decay=weight_decay)
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters and hyperparameters['vgg_w'] > 0:
            import torchvision.models as models
            self.vgg = models.vgg16(pretrained=True)
            # If you cannot download the pretrained model automatically, get it
            # from https://download.pytorch.org/models/vgg16-397923af.pth and
            # load it manually:
            # state_dict = torch.load('vgg16-397923af.pth')
            # self.vgg.load_state_dict(state_dict)
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        """Mean absolute error (L1) between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    def __compute_kl(self, mu, logvar):
        """Summed KL term ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))``."""
        encoding_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return encoding_loss

    def gen_update(self, x_a, x_b, hyperparameters, iterations):
        """Run one generator optimisation step.

        ``gen_backward_cc`` only builds its loss; ``gen_backward_latent`` adds
        its own loss and calls ``backward()`` on the combined total, so a
        single ``step()`` applies both. ``iterations`` is currently unused.
        """
        self.gen_opt.zero_grad()
        self.gen_backward_cc(x_a, x_b, hyperparameters)
        self.gen_backward_latent(x_a, x_b, hyperparameters)
        self.gen_opt.step()

    def gen_backward_latent(self, x_a, x_b, hyperparameters):
        """Add the random-style (latent) loss and back-propagate everything.

        Must run after ``gen_backward_cc``: it reads ``self.c_a``,
        ``self.c_b``, ``self.da_b``, ``self.db_a`` and the ``self.loss_gen_*``
        values that method stored. After this call ``self.loss_gen_total`` is
        exactly the tensor that was back-propagated (cc loss + latent loss).
        """
        # Random style codes. Sampled on the CPU generator and then moved to
        # the input's device (same RNG draw as the former ``.cuda()`` path).
        s_a_random = torch.randn(
            x_a.size(0), self.style_dim, 1, 1).to(x_a.device)
        s_b_random = torch.randn(
            x_b.size(0), self.style_dim, 1, 1).to(x_b.device)
        # decode with random styles
        x_ba_random = self.gen_a.decode(self.c_b, self.da_b, s_a_random)
        x_ab_random = self.gen_b.decode(self.c_a, self.db_a, s_b_random)
        c_b_random_recon, _, _, s_a_random_recon, _, _ = self.gen_a.encode(
            x_ba_random)
        c_a_random_recon, _, _, s_b_random_recon, _, _ = self.gen_b.encode(
            x_ab_random)
        # style reconstruction loss
        self.loss_gen_recon_s_a = self.recon_criterion(
            s_a_random, s_a_random_recon)
        self.loss_gen_recon_s_b = self.recon_criterion(
            s_b_random, s_b_random_recon)
        loss_gen_recon_c_a = self.recon_criterion(self.c_a, c_a_random_recon)
        loss_gen_recon_c_b = self.recon_criterion(self.c_b, c_b_random_recon)
        loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba_random, x_a)
        loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab_random, x_b)
        loss_gen_vgg_a = self.compute_vgg_loss(
            self.vgg, x_ba_random, x_a) if hyperparameters['vgg_w'] > 0 else 0
        loss_gen_vgg_b = self.compute_vgg_loss(
            self.vgg, x_ab_random, x_b) if hyperparameters['vgg_w'] > 0 else 0
        loss_gen_total = hyperparameters['gan_w'] * loss_gen_adv_a + \
            hyperparameters['gan_w'] * loss_gen_adv_b + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
            hyperparameters['recon_c_w'] * loss_gen_recon_c_a + \
            hyperparameters['recon_c_w'] * loss_gen_recon_c_b + \
            hyperparameters['vgg_w'] * loss_gen_vgg_a + \
            hyperparameters['vgg_w'] * loss_gen_vgg_b

        # Combine with the cc loss exactly once and back-propagate the sum.
        # (Previously the latent loss was added a second time after
        # backward(), so the logged total over-counted it.)
        self.loss_gen_total = self.loss_gen_total + loss_gen_total
        self.loss_gen_total.backward()

        # Accumulate the per-term values for logging only.
        self.loss_gen_adv_a = self.loss_gen_adv_a + loss_gen_adv_a
        self.loss_gen_adv_b = self.loss_gen_adv_b + loss_gen_adv_b
        self.loss_gen_recon_c_a = self.loss_gen_recon_c_a + loss_gen_recon_c_a
        self.loss_gen_recon_c_b = self.loss_gen_recon_c_b + loss_gen_recon_c_b
        self.loss_gen_vgg_a = self.loss_gen_vgg_a + loss_gen_vgg_a
        self.loss_gen_vgg_b = self.loss_gen_vgg_b + loss_gen_vgg_b

    def gen_backward_cc(self, x_a, x_b, hyperparameters):
        """Build the reconstruction / cycle / adversarial generator loss.

        Stores the encodings and every loss term on ``self`` and leaves their
        weighted sum in ``self.loss_gen_total``. It does NOT call
        ``backward()``; ``gen_backward_latent`` does that on the combined
        total.
        """
        pre_c_a, self.c_a, c_domain_a, self.db_a, self.s_a_prime, mu_a, logvar_a = \
            self.gen_a.encode(x_a, training=True, flag=True)
        pre_c_b, self.c_b, c_domain_b, self.da_b, self.s_b_prime, mu_b, logvar_b = \
            self.gen_b.encode(x_b, training=True, flag=True)
        # Own-domain domain-specific content, used for within-domain decoding
        # and compared against c_domain_* in the domain-specific loss below.
        self.da_a = self.gen_b.domain_mapping(self.c_a, pre_c_a)
        self.db_b = self.gen_a.domain_mapping(self.c_b, pre_c_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(self.c_a, self.da_a, self.s_a_prime)
        x_b_recon = self.gen_b.decode(self.c_b, self.db_b, self.s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(self.c_b, self.da_b, self.s_a_prime)
        x_ab = self.gen_b.decode(self.c_a, self.db_a, self.s_b_prime)
        # encode the translations again
        c_b_recon, _, self.db_b_recon, s_a_recon, _, _ = self.gen_a.encode(
            x_ba, training=True)
        c_a_recon, _, self.da_a_recon, s_b_recon, _, _ = self.gen_b.encode(
            x_ab, training=True)
        # decode again (cycle consistency)
        x_aba = self.gen_a.decode(c_a_recon, self.da_a_recon, s_a_recon)
        x_bab = self.gen_b.decode(c_b_recon, self.db_b_recon, s_b_recon)

        # domain-specific content reconstruction loss
        self.loss_gen_recon_d_a = self.recon_criterion(
            c_domain_a, self.da_a) if hyperparameters['recon_d_w'] > 0 else 0
        self.loss_gen_recon_d_b = self.recon_criterion(
            c_domain_b, self.db_b) if hyperparameters['recon_d_w'] > 0 else 0
        # image reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        # domain-invariant content reconstruction loss
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, self.c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, self.c_b)
        # cycle reconstruction loss
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a)
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b)
        # kl loss (if needed)
        self.loss_gen_recon_kl_a = self.__compute_kl(
            mu_a, logvar_a) if hyperparameters['recon_kl_w'] > 0 else 0
        self.loss_gen_recon_kl_b = self.__compute_kl(
            mu_b, logvar_b) if hyperparameters['recon_kl_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba, x_a)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab, x_b)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(
            self.vgg, x_ba, x_a) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(
            self.vgg, x_ab, x_b) if hyperparameters['vgg_w'] > 0 else 0
        # total loss (back-propagated later by gen_backward_latent)
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
            hyperparameters['gan_w'] * self.loss_gen_adv_b + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_d_w'] * self.loss_gen_recon_d_a + \
            hyperparameters['recon_d_w'] * self.loss_gen_recon_d_b + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
            hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \
            hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \
            hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
            hyperparameters['vgg_w'] * self.loss_gen_vgg_b

    def compute_vgg_loss(self, vgg, img, target):
        """Contextual loss between VGG ``features`` of ``img`` and ``target``."""
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg.features(img_vgg)
        target_fea = vgg.features(target_vgg)
        return contextual_loss(img_fea, target_fea)

    def dis_update(self, x_a, x_b, hyperparameters, iterations):
        """Run one discriminator optimisation step.

        Each discriminator sees two kinds of fake: translations using the
        encoded style of the other image, and translations using a random
        style. ``iterations`` is currently unused.
        """
        self.dis_opt.zero_grad()
        s_a_random = torch.randn(
            x_a.size(0), self.style_dim, 1, 1).to(x_a.device)
        s_b_random = torch.randn(
            x_b.size(0), self.style_dim, 1, 1).to(x_b.device)
        # The fakes are detached before reaching the discriminators, so no
        # generator autograd graph is needed; building it only wasted memory.
        with torch.no_grad():
            # encode
            _, c_a, _, db_a, s_a, _, _ = self.gen_a.encode(
                x_a, training=True, flag=True)
            _, c_b, _, da_b, s_b, _, _ = self.gen_b.encode(
                x_b, training=True, flag=True)
            # decode (cross domain) with encoded styles
            x_ba = self.gen_a.decode(c_b, da_b, s_a)
            x_ab = self.gen_b.decode(c_a, db_a, s_b)
            # decode (cross domain) with random styles
            x_ba_random = self.gen_a.decode(c_b, da_b, s_a_random)
            x_ab_random = self.gen_b.decode(c_a, db_a, s_b_random)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(
            x_ba.detach(), x_a) + self.dis_a.calc_dis_loss(
            x_ba_random.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(
            x_ab.detach(), x_b) + self.dis_b.calc_dis_loss(
            x_ab_random.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + \
            hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def sample(self, x_a, x_b):
        """Reconstruct and translate each image of the two batches.

        Runs one image at a time in eval mode without autograd, then restores
        the module's previous train/eval mode even if an error occurs.

        Returns:
            ``(x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba)``
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                x_a_recon, x_b_recon, x_ab, x_ba = [], [], [], []
                for i in range(x_a.size(0)):
                    pre_c_a, c_a, _, db_a, s_a_fake, _, _ = self.gen_a.encode(
                        x_a[i].unsqueeze(0), flag=True)
                    pre_c_b, c_b, _, da_b, s_b_fake, _, _ = self.gen_b.encode(
                        x_b[i].unsqueeze(0), flag=True)
                    da_a = self.gen_b.domain_mapping(c_a, pre_c_a)
                    db_b = self.gen_a.domain_mapping(c_b, pre_c_b)
                    x_a_recon.append(self.gen_a.decode(c_a, da_a, s_a_fake))
                    x_b_recon.append(self.gen_b.decode(c_b, db_b, s_b_fake))
                    x_ba.append(self.gen_a.decode(c_b, da_b, s_a_fake))
                    x_ab.append(self.gen_b.decode(c_a, db_a, s_b_fake))
                x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
                x_ab, x_ba = torch.cat(x_ab), torch.cat(x_ba)
        finally:
            self.train(was_training)
        return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba

    def update_learning_rate(self):
        """Step both LR schedulers (``get_scheduler`` may return ``None``)."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore networks, optimizers and schedulers from ``checkpoint_dir``.

        Returns the iteration count parsed from the generator checkpoint name.
        """
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        # The 8 digits before '.pt' in names written by save ('gen_%08d.pt').
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(
            self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(
            self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Save generators, discriminators and optimizers to ``snapshot_dir``.

        Checkpoint file names carry ``iterations + 1``; the optimizer file is
        overwritten on every call.
        """
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)