class UNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(UNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = VAEGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = VAEGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        h_a, _ = self.gen_a.encode(x_a)
        h_b, _ = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(h_b)
        x_ab = self.gen_b.decode(h_a)
        self.train()
        return x_ab, x_ba

    def __compute_kl(self, mu):
        # def _compute_kl(self, mu, sd):
        #     mu_2 = torch.pow(mu, 2)
        #     sd_2 = torch.pow(sd, 2)
        #     encoding_loss = (mu_2 + sd_2 - torch.log(sd_2)).sum() / mu_2.size(0)
        #     return encoding_loss
        mu_2 = torch.pow(mu, 2)
        encoding_loss = torch.mean(mu_2)
        return encoding_loss

    def gen_update(self, x_a, x_b, hyperparameters):
        self.gen_opt.zero_grad()
        # encode
        h_a, n_a = self.gen_a.encode(x_a)
        h_b, n_b = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(h_a + n_a)
        x_b_recon = self.gen_b.decode(h_b + n_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(h_b + n_b)
        x_ab = self.gen_b.decode(h_a + n_a)
        # encode again
        h_b_recon, n_b_recon = self.gen_a.encode(x_ba)
        h_a_recon, n_a_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(h_a_recon + n_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(h_b_recon + n_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
        # reconstruction loss (cycle terms are skipped when the cycle outputs were not decoded)
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_kl_a = self.__compute_kl(h_a)
        self.loss_gen_recon_kl_b = self.__compute_kl(h_b)
        self.loss_gen_cyc_x_a = self.recon_criterion(x_aba, x_a) if x_aba is not None else 0
        self.loss_gen_cyc_x_b = self.recon_criterion(x_bab, x_b) if x_bab is not None else 0
        self.loss_gen_recon_kl_cyc_aba = self.__compute_kl(h_a_recon)
        self.loss_gen_recon_kl_cyc_bab = self.__compute_kl(h_b_recon)
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_a + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_b + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def sample(self, x_a, x_b):
        self.eval()
        x_a_recon, x_b_recon, x_ba, x_ab = [], [], [], []
        for i in range(x_a.size(0)):
            h_a, _ = self.gen_a.encode(x_a[i].unsqueeze(0))
            h_b, _ = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(h_a))
            x_b_recon.append(self.gen_b.decode(h_b))
            x_ba.append(self.gen_a.decode(h_b))
            x_ab.append(self.gen_b.decode(h_a))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba = torch.cat(x_ba)
        x_ab = torch.cat(x_ab)
        self.train()
        return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        # encode
        h_a, n_a = self.gen_a.encode(x_a)
        h_b, n_b = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(h_b + n_b)
        x_ab = self.gen_b.decode(h_a + n_a)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
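# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original trainer). It shows one
# way the update methods above are typically driven from a training loop.
# `loader_a`, `loader_b`, `snapshot_dir`, and `max_iter` are hypothetical
# placeholders; only dis_update, gen_update, update_learning_rate, and save
# come from the class above, and `hyperparameters` is assumed to carry the
# same keys the class reads (lr, beta1, beta2, gan_w, recon_*_w, ...).
# ---------------------------------------------------------------------------
def example_training_loop(trainer, loader_a, loader_b, hyperparameters, snapshot_dir, max_iter=10000):
    trainer.cuda()
    iterations = 0
    for x_a, x_b in zip(loader_a, loader_b):
        x_a, x_b = x_a.cuda(), x_b.cuda()
        trainer.dis_update(x_a, x_b, hyperparameters)  # discriminator step on detached translations
        trainer.gen_update(x_a, x_b, hyperparameters)  # generator / auto-encoder step
        trainer.update_learning_rate()                 # step both LR schedulers
        iterations += 1
        if iterations % 1000 == 0:                     # arbitrary snapshot interval for this sketch
            trainer.save(snapshot_dir, iterations)
        if iterations >= max_iter:
            break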
class Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.trait_dim = hyperparameters['gen']['trait_dim']
        # auto-encoder for domain a
        self.gen_a = VAEGen(hyperparameters['input_dim'], hyperparameters['basis_encoder_dims'],
                            hyperparameters['trait_encoder_dims'], hyperparameters['decoder_dims'], self.trait_dim)
        # auto-encoder for domain b
        self.gen_b = VAEGen(hyperparameters['input_dim'], hyperparameters['basis_encoder_dims'],
                            hyperparameters['trait_encoder_dims'], hyperparameters['decoder_dims'], self.trait_dim)
        # discriminator for domain a
        self.dis_a = Discriminator(hyperparameters['input_dim'], hyperparameters['dis_dims'], 1)
        # discriminator for domain b
        self.dis_b = Discriminator(hyperparameters['input_dim'], hyperparameters['dis_dims'], 1)
        # fix the noise used in sampling
        self.trait_a = torch.randn(8, self.trait_dim, 1, 1)
        self.trait_b = torch.randn(8, self.trait_dim, 1, 1)

        # Setup the optimizers
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        for _p in gen_params:
            print(_p.data.shape)
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.gen_a.apply(weights_init('gaussian'))
        self.gen_b.apply(weights_init('gaussian'))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        trait_a = Variable(self.trait_a)
        trait_b = Variable(self.trait_b)
        basis_a, trait_a_fake = self.gen_a.encode(x_a)
        basis_b, trait_b_fake = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(basis_b, trait_a)
        x_ab = self.gen_b.decode(basis_a, trait_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        self.gen_opt.zero_grad()
        trait_a = Variable(torch.randn(x_a.size(0), self.trait_dim))
        trait_b = Variable(torch.randn(x_b.size(0), self.trait_dim))
        # encode
        basis_a, trait_a_prime = self.gen_a.encode(x_a)
        basis_b, trait_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(basis_a, trait_a_prime)
        x_b_recon = self.gen_b.decode(basis_b, trait_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(basis_b, trait_a)
        x_ab = self.gen_b.decode(basis_a, trait_b)
        # encode again
        basis_b_recon, trait_a_recon = self.gen_a.encode(x_ba)
        basis_a_recon, trait_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(basis_a_recon, trait_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(basis_b_recon, trait_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_trait_a = self.recon_criterion(trait_a_recon, trait_a)
        self.loss_gen_recon_trait_b = self.recon_criterion(trait_b_recon, trait_b)
        self.loss_gen_recon_basis_a = self.recon_criterion(basis_a_recon, basis_a)
        self.loss_gen_recon_basis_b = self.recon_criterion(basis_b_recon, basis_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_trait_w'] * self.loss_gen_recon_trait_a + \
                              hyperparameters['recon_basis_w'] * self.loss_gen_recon_basis_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_trait_w'] * self.loss_gen_recon_trait_b + \
                              hyperparameters['recon_basis_w'] * self.loss_gen_recon_basis_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    # def sample(self, x_a, x_b):
    #     self.eval()
    #     s_a1 = Variable(self.s_a)
    #     s_b1 = Variable(self.s_b)
    #     s_a2 = Variable(torch.randn(x_a.size(0), self.trait_dim, 1, 1))
    #     s_b2 = Variable(torch.randn(x_b.size(0), self.trait_dim, 1, 1))
    #     x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
    #     for i in range(x_a.size(0)):
    #         c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
    #         c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
    #         x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
    #         x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
    #         x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
    #         x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
    #         x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
    #         x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
    #     x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
    #     x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
    #     x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
    #     self.train()
    #     return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        trait_a = Variable(torch.randn(x_a.size(0), self.trait_dim))
        trait_b = Variable(torch.randn(x_b.size(0), self.trait_dim))
        # encode
        basis_a, _ = self.gen_a.encode(x_a)
        basis_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(basis_b, trait_a)
        x_ab = self.gen_b.decode(basis_a, trait_b)
        # D loss (translations are detached so the discriminator step does not backprop into the generators)
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
class UNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(UNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = VAEGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = VAEGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        self.dis_content = Dis_content()
        self.gpuid = hyperparameters['gpuID']
        # @ add background discriminator for each domain
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.content_opt = torch.optim.Adam(self.dis_content.parameters(),
                                            lr=lr / 2., betas=(beta1, beta2),
                                            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        self.content_scheduler = get_scheduler(self.content_opt, hyperparameters)

        # Network weight initialization
        self.gen_a.apply(weights_init(hyperparameters['init']))
        self.gen_b.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))
        self.dis_content.apply(weights_init('gaussian'))

        # initialize the blur network
        self.BGBlur_kernel = [5, 9, 15]
        self.BlurNet = [GaussionSmoothLayer(3, k_size, 25).cuda(self.gpuid) for k_size in self.BGBlur_kernel]
        self.BlurWeight = [0.25, 0.5, 1.]
        self.Gradient = GradientLoss(3, 3)

        # Load VGG model if needed for test
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg19()
            if torch.cuda.is_available():
                self.vgg.cuda(self.gpuid)
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        h_a = self.gen_a.encode_cont(x_a)
        # h_a_sty = self.gen_a.encode_sty(x_a)
        # h_b = self.gen_b.encode_cont(x_b)
        x_ab = self.gen_b.decode_cont(h_a)
        # h_c = torch.cat((h_b, h_a_sty), 1)
        # x_ba = self.gen_a.decode_recs(h_c)
        # self.train()
        return x_ab  # , x_ba

    def __compute_kl(self, mu):
        # def _compute_kl(self, mu, sd):
        mu_2 = torch.pow(mu, 2)
        encoding_loss = torch.mean(mu_2)
        return encoding_loss

    def content_update(self, x_a, x_b, hyperparameters):
        # encode
        self.content_opt.zero_grad()
        enc_a = self.gen_a.encode_cont(x_a)
        enc_b = self.gen_b.encode_cont(x_b)
        pred_fake = self.dis_content.forward(enc_a)
        pred_real = self.dis_content.forward(enc_b)
        loss_D = 0
        if hyperparameters['gan_type'] == 'lsgan':
            loss_D += torch.mean((pred_fake - 0) ** 2) + torch.mean((pred_real - 1) ** 2)
        elif hyperparameters['gan_type'] == 'nsgan':
            all0 = Variable(torch.zeros_like(pred_fake.data).cuda(self.gpuid), requires_grad=False)
            all1 = Variable(torch.ones_like(pred_real.data).cuda(self.gpuid), requires_grad=False)
            loss_D += torch.mean(F.binary_cross_entropy(F.sigmoid(pred_fake), all0) +
                                 F.binary_cross_entropy(F.sigmoid(pred_real), all1))
        else:
            assert 0, "Unsupported GAN type: {}".format(hyperparameters['gan_type'])
        loss_D.backward()
        nn.utils.clip_grad_norm_(self.dis_content.parameters(), 5)
        self.content_opt.step()

    def gen_update(self, x_a, x_b, hyperparameters):
        self.gen_opt.zero_grad()
        self.content_opt.zero_grad()
        # encode
        h_a = self.gen_a.encode_cont(x_a)
        h_b = self.gen_b.encode_cont(x_b)
        h_a_sty = self.gen_a.encode_sty(x_a)
        # add domain adversarial loss for generator
        out_a = self.dis_content(h_a)
        out_b = self.dis_content(h_b)
        self.loss_ContentD = 0
        if hyperparameters['gan_type'] == 'lsgan':
            self.loss_ContentD += torch.mean((out_a - 0.5) ** 2) + torch.mean((out_b - 0.5) ** 2)
        elif hyperparameters['gan_type'] == 'nsgan':
            all1 = Variable(0.5 * torch.ones_like(out_b.data).cuda(self.gpuid), requires_grad=False)
            self.loss_ContentD += torch.mean(F.binary_cross_entropy(F.sigmoid(out_a), all1) +
                                             F.binary_cross_entropy(F.sigmoid(out_b), all1))
        else:
            assert 0, "Unsupported GAN type: {}".format(hyperparameters['gan_type'])
        # decode (within domain)
        h_a_cont = torch.cat((h_a, h_a_sty), 1)
        noise_a = torch.randn(h_a_cont.size()).cuda(h_a_cont.data.get_device())
        x_a_recon = self.gen_a.decode_recs(h_a_cont + noise_a)
        noise_b = torch.randn(h_b.size()).cuda(h_b.data.get_device())
        x_b_recon = self.gen_b.decode_cont(h_b + noise_b)
        # decode (cross domain)
        h_ba_cont = torch.cat((h_b, h_a_sty), 1)
        x_ba = self.gen_a.decode_recs(h_ba_cont + noise_a)
        x_ab = self.gen_b.decode_cont(h_a + noise_b)
        # encode again
        h_b_recon = self.gen_a.encode_cont(x_ba)
        h_b_sty_recon = self.gen_a.encode_sty(x_ba)
        h_a_recon = self.gen_b.encode_cont(x_ab)
        # decode again (if needed)
        h_a_cat_recs = torch.cat((h_a_recon, h_b_sty_recon), 1)
        x_aba = self.gen_a.decode_recs(h_a_cat_recs + noise_a) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode_cont(h_b_recon + noise_b) if hyperparameters['recon_x_cyc_w'] > 0 else None
        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_kl_a = self.__compute_kl(h_a)
        self.loss_gen_recon_kl_b = self.__compute_kl(h_b)
        self.loss_gen_recon_kl_sty = self.__compute_kl(h_a_sty)
        self.loss_gen_cyc_x_a = self.recon_criterion(x_aba, x_a) if x_aba is not None else 0
        self.loss_gen_cyc_x_b = self.recon_criterion(x_bab, x_b) if x_bab is not None else 0
        self.loss_gen_recon_kl_cyc_aba = self.__compute_kl(h_a_recon)
        self.loss_gen_recon_kl_cyc_bab = self.__compute_kl(h_b_recon)
        self.loss_gen_recon_kl_cyc_sty = self.__compute_kl(h_b_sty_recon)
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # add background guide loss
        self.loss_bgm = 0
        if hyperparameters['BGM'] != 0:
            for index, weight in enumerate(self.BlurWeight):
                out_b = self.BlurNet[index](x_ba)
                out_real_b = self.BlurNet[index](x_b)
                out_a = self.BlurNet[index](x_ab)
                out_real_a = self.BlurNet[index](x_a)
                grad_loss_b = self.recon_criterion(out_b, out_real_b)
                grad_loss_a = self.recon_criterion(out_a, out_real_a)
                self.loss_bgm += weight * (grad_loss_a + grad_loss_b)
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_sty + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_a + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_b + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_sty + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b + \
                              hyperparameters['BGM'] * self.loss_bgm + \
                              hyperparameters['gan_w'] * self.loss_ContentD
        self.loss_gen_total.backward()
        self.gen_opt.step()
        self.content_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def sample(self, x_a, x_b):
        if x_a is None or x_b is None:
            return None
        self.eval()
        x_a_recon, x_b_recon, x_ba, x_ab = [], [], [], []
        for i in range(x_a.size(0)):
            h_a = self.gen_a.encode_cont(x_a[i].unsqueeze(0))
            h_a_sty = self.gen_a.encode_sty(x_a[i].unsqueeze(0))
            h_b = self.gen_b.encode_cont(x_b[i].unsqueeze(0))
            h_ba_cont = torch.cat((h_b, h_a_sty), 1)
            h_aa_cont = torch.cat((h_a, h_a_sty), 1)
            x_a_recon.append(self.gen_a.decode_recs(h_aa_cont))
            x_b_recon.append(self.gen_b.decode_cont(h_b))
            x_ba.append(self.gen_a.decode_recs(h_ba_cont))
            x_ab.append(self.gen_b.decode_cont(h_a))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba = torch.cat(x_ba)
        x_ab = torch.cat(x_ab)
        self.train()
        return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        self.content_opt.zero_grad()
        # encode
        h_a = self.gen_a.encode_cont(x_a)
        h_a_sty = self.gen_a.encode_sty(x_a)
        h_b = self.gen_b.encode_cont(x_b)
        # @ add content adversarial loss
        out_a = self.dis_content(h_a)
        out_b = self.dis_content(h_b)
        self.loss_ContentD = 0
        if hyperparameters['gan_type'] == 'lsgan':
            self.loss_ContentD += torch.mean((out_a - 0) ** 2) + torch.mean((out_b - 1) ** 2)
        elif hyperparameters['gan_type'] == 'nsgan':
            all0 = Variable(torch.zeros_like(out_a.data).cuda(self.gpuid), requires_grad=False)
            all1 = Variable(torch.ones_like(out_b.data).cuda(self.gpuid), requires_grad=False)
            self.loss_ContentD += torch.mean(F.binary_cross_entropy(F.sigmoid(out_a), all0) +
                                             F.binary_cross_entropy(F.sigmoid(out_b), all1))
        else:
            assert 0, "Unsupported GAN type: {}".format(hyperparameters['gan_type'])
        # decode (cross domain)
        h_cat = torch.cat((h_b, h_a_sty), 1)
        noise_b = torch.randn(h_cat.size()).cuda(h_cat.data.get_device())
        x_ba = self.gen_a.decode_recs(h_cat + noise_b)
        noise_a = torch.randn(h_a.size()).cuda(h_a.data.get_device())
        x_ab = self.gen_b.decode_cont(h_a + noise_a)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * (self.loss_dis_a + self.loss_dis_b + self.loss_ContentD)
        self.loss_dis_total.backward()
        nn.utils.clip_grad_norm_(self.dis_content.parameters(), 5)
        # dis_content update
        self.dis_opt.step()
        self.content_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()
        if self.content_scheduler is not None:
            self.content_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis_00188000")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # load content discriminator
        last_model_name = get_model_list(checkpoint_dir, "dis_Content")
        state_dict = torch.load(last_model_name)
        self.dis_content.load_state_dict(state_dict['dis_c'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        self.content_opt.load_state_dict(state_dict['dis_content'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        self.content_scheduler = get_scheduler(self.content_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        dis_Con_name = os.path.join(snapshot_dir, 'dis_Content_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'dis_c': self.dis_content.state_dict()}, dis_Con_name)
        # opt state
        torch.save({'gen': self.gen_opt.state_dict(),
                    'dis': self.dis_opt.state_dict(),
                    'dis_content': self.content_opt.state_dict()}, opt_name)
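# ---------------------------------------------------------------------------
# Illustrative sketch (an assumption, not the original GaussionSmoothLayer
# implementation): the background-guide loss above only needs a fixed,
# non-learned Gaussian blur of each image. A minimal depthwise-convolution
# version, taking the same (channels, kernel_size, sigma) arguments that the
# constructor above passes, could look like this.
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

class GaussianSmoothingSketch(nn.Module):
    def __init__(self, channels, kernel_size, sigma):
        super(GaussianSmoothingSketch, self).__init__()
        # Build a normalized 2D Gaussian kernel and repeat it once per channel.
        coords = torch.arange(kernel_size, dtype=torch.float32) - (kernel_size - 1) / 2.0
        g1d = torch.exp(-coords ** 2 / (2.0 * sigma ** 2))
        g1d = g1d / g1d.sum()
        g2d = g1d[:, None] * g1d[None, :]
        kernel = g2d.expand(channels, 1, kernel_size, kernel_size).contiguous()
        self.register_buffer('weight', kernel)
        self.groups = channels
        self.padding = kernel_size // 2

    def forward(self, x):
        # Depthwise convolution: every channel is blurred independently with the same fixed kernel.
        return F.conv2d(x, self.weight, padding=self.padding, groups=self.groups)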
class UNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(UNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = VAEGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = VAEGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        if not hyperparameters['origin']:
            self.dis_a = MultiscaleDiscriminator(hyperparameters['input_dim_a'],  # discriminator for a
                                                 ndf=64, n_layers=3, norm_layer=nn.InstanceNorm2d,
                                                 use_sigmoid=False, num_D=2, getIntermFeat=True)
            self.dis_b = MultiscaleDiscriminator(hyperparameters['input_dim_b'],  # discriminator for b
                                                 ndf=64, n_layers=3, norm_layer=nn.InstanceNorm2d,
                                                 use_sigmoid=False, num_D=2, getIntermFeat=True)
            self.criterionGAN = GANLoss(use_lsgan=True, tensor=torch.cuda.FloatTensor)
        else:
            self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])
            self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def compute_digits_differnce(self, digits1, digits2, weight=1.0):
        feat_diff = 0
        feat_weights = 4.0 / (3 + 1)  # 3-layer discriminator
        D_weights = 1.0 / 2.0  # number of discriminators
        for i in range(2):
            for j in range(len(digits2[i]) - 1):
                feat_diff += D_weights * feat_weights * \
                    F.l1_loss(digits2[i][j], digits1[i][j].detach()) * weight
        return feat_diff

    def compute_gan_loss(self, real_digits, fake_digits, gan_cri, loss_at='None'):
        errD = None
        errG = None
        errG_feat = None
        if gan_cri is not None:
            if loss_at == 'D':
                errD = (gan_cri(real_digits, True) + gan_cri(fake_digits, False)) * 0.5
            elif loss_at == 'G':
                errG = gan_cri(fake_digits, True)
                errG_feat = self.compute_digits_differnce(real_digits, fake_digits, weight=10.0)
        return errD, errG, errG_feat

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        h_a, _ = self.gen_a.encode(x_a)
        h_b, _ = self.gen_b.encode(x_b)
        x_ba, _ = self.gen_a.decode(h_b)
        x_ab, _ = self.gen_b.decode(h_a)
        self.train()
        return x_ab, x_ba

    def __compute_kl(self, mu):
        mu_2 = torch.pow(mu, 2)
        encoding_loss = torch.mean(mu_2)
        return encoding_loss

    def gen_update(self, x_a, x_b, hyperparameters):
        self.gen_opt.zero_grad()
        ############ Encode ############
        h_a, n_a = self.gen_a.encode(x_a)
        h_b, n_b = self.gen_b.encode(x_b)
        # decode (within domain)
        if hyperparameters['zero_z']:
            pre_Z = Variable(torch.zeros(hyperparameters['batch_size'], hyperparameters['gen']['z_num']).cuda())
        else:
            pre_Z = None
        ########### Reconstruction ############
        x_a_recon, _ = self.gen_a.decode(h_a + n_a, z_var=pre_Z)
        x_b_recon, _ = self.gen_b.decode(h_b + n_b, z_var=pre_Z)
        ########### Decode (cross domain) ############
        # with random vector
        x_ba, z_var_ba_1 = self.gen_a.decode(h_b + n_b)
        x_ab, z_var_ab_1 = self.gen_b.decode(h_a + n_a)
        # with zero latent vector
        x_ba_zero, _ = self.gen_a.decode(h_b + n_b, z_var=pre_Z)
        x_ab_zero, _ = self.gen_b.decode(h_a + n_a, z_var=pre_Z)
        ######## decode (cross domain the second time) ############
        if hyperparameters['loss_eg_weight'] != 0:
            x_ba_eg, z_var_ba_2 = self.gen_a.decode(h_b + n_b)
            x_ab_eg, z_var_ab_2 = self.gen_b.decode(h_a + n_a)
            x_ba_eg = x_ba_eg.detach()
            x_ab_eg = x_ab_eg.detach()
            if not hyperparameters['origin']:
                x_ba_eg_digits = self.dis_a(x_ba_eg)
                x_ab_eg_digits = self.dis_b(x_ab_eg)
        # encode again
        h_b_recon, n_b_recon = self.gen_a.encode(x_ba_zero)
        h_a_recon, n_a_recon = self.gen_b.encode(x_ab_zero)
        # decode again (if needed)
        x_aba, _ = self.gen_a.decode(h_a_recon + n_a_recon, z_var=pre_Z) if hyperparameters['recon_x_cyc_w'] > 0 else (None, 0)
        x_bab, _ = self.gen_b.decode(h_b_recon + n_b_recon, z_var=pre_Z) if hyperparameters['recon_x_cyc_w'] > 0 else (None, 0)
        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_kl_a = self.__compute_kl(h_a)
        self.loss_gen_recon_kl_b = self.__compute_kl(h_b)
        self.loss_gen_cyc_x_a = self.recon_criterion(x_aba, x_a) if x_aba is not None else 0
        self.loss_gen_cyc_x_b = self.recon_criterion(x_bab, x_b) if x_bab is not None else 0
        self.loss_gen_recon_kl_cyc_aba = self.__compute_kl(h_a_recon)
        self.loss_gen_recon_kl_cyc_bab = self.__compute_kl(h_b_recon)
        # GAN loss (the feature-matching terms only exist with the multiscale discriminators)
        if hyperparameters['loss_eg_weight'] == 0:
            self.loss_gen_adv_a, self.loss_gen_adv_b, self.loss_gan_feat_a, self.loss_gan_feat_b = 0, 0, 0, 0
        elif not hyperparameters['origin']:
            x_ba_digits = self.dis_a(x_ba)
            x_a_digits = self.dis_a(x_a)
            _, self.loss_gen_adv_a, self.loss_gan_feat_a = \
                self.compute_gan_loss(x_a_digits, x_ba_digits, self.criterionGAN, loss_at='G')
            x_ab_digits = self.dis_b(x_ab)
            x_b_digits = self.dis_b(x_b)
            _, self.loss_gen_adv_b, self.loss_gan_feat_b = \
                self.compute_gan_loss(x_b_digits, x_ab_digits, self.criterionGAN, loss_at='G')
        else:
            self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
            self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
            self.loss_gan_feat_a, self.loss_gan_feat_b = 0, 0
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        if hyperparameters['loss_eg_weight'] == 0:
            self.loss_eg = 0.0
        elif not hyperparameters['origin']:
            self.loss_eg = compute_eg_loss(x_ba_digits, x_ba_eg_digits, x_ab_digits, x_ab_eg_digits,
                                           z_var_ba_1, z_var_ba_2, z_var_ab_1, z_var_ab_2, hyperparameters)
        else:
            self.loss_eg = compute_eg_loss(x_ba, x_ba_eg, x_ab, x_ab_eg,
                                           z_var_ba_1, z_var_ba_2, z_var_ab_1, z_var_ab_2, hyperparameters)
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_a + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_b + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              self.loss_gan_feat_b + self.loss_gan_feat_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b + \
                              hyperparameters['loss_eg_weight'] * self.loss_eg
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def sample(self, x_a, x_b):
        self.eval()
        x_a_recon, x_b_recon, x_ba, x_ab = [], [], [], []
        for i in range(x_a.size(0)):
            h_a, _ = self.gen_a.encode(x_a[i].unsqueeze(0))
            h_b, _ = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(h_a)[0])
            x_b_recon.append(self.gen_b.decode(h_b)[0])
            x_ba.append(self.gen_a.decode(h_b)[0])
            x_ab.append(self.gen_b.decode(h_a)[0])
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba = torch.cat(x_ba)
        x_ab = torch.cat(x_ab)
        self.train()
        return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        # encode
        h_a, n_a = self.gen_a.encode(x_a)
        h_b, n_b = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba, _ = self.gen_a.decode(h_b + n_b)
        x_ab, _ = self.gen_b.decode(h_a + n_a)
        # D loss
        if not hyperparameters['origin']:
            real_digits_a = self.dis_a(x_a)
            fake_digits_a = self.dis_a(x_ba.detach())
            real_digits_b = self.dis_b(x_b)
            fake_digits_b = self.dis_b(x_ab.detach())
            self.loss_dis_a, _, _ = self.compute_gan_loss(real_digits_a, fake_digits_a, self.criterionGAN, loss_at='D')
            self.loss_dis_b, _, _ = self.compute_gan_loss(real_digits_b, fake_digits_b, self.criterionGAN, loss_at='D')
        else:
            self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
            self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
class UNIT_Trainer(nn.Module): def __init__(self, hyperparameters): super(UNIT_Trainer, self).__init__() lr = hyperparameters['lr'] # Initiate the networks self.gen_a = VAEGen(hyperparameters['input_dim_a'], hyperparameters['gen']) # auto-encoder for domain a self.gen_b = VAEGen(hyperparameters['input_dim_b'], hyperparameters['gen']) # auto-encoder for domain b self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis']) # discriminator for domain a self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis']) # discriminator for domain b self.instancenorm = nn.InstanceNorm2d(512, affine=False) # Setup the optimizers beta1 = hyperparameters['beta1'] beta2 = hyperparameters['beta2'] dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters()) gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters()) self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad], lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad], lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters) self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters) # Network weight initialization self.apply(weights_init(hyperparameters['init'])) self.dis_a.apply(weights_init('gaussian')) self.dis_b.apply(weights_init('gaussian')) # Load VGG model if needed if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0: self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models') self.vgg.eval() for param in self.vgg.parameters(): param.requires_grad = False def recon_criterion(self, input, target): return torch.mean(torch.abs(input - target)) def forward(self, x_a, x_b): self.eval() x_a.volatile = True x_b.volatile = True h_a, _ = self.gen_a.encode(x_a) h_b, _ = self.gen_b.encode(x_b) x_ba = self.gen_a.decode(h_b) x_ab = self.gen_b.decode(h_a) self.train() return x_ab, x_ba def __compute_kl(self, mu): # def _compute_kl(self, mu, sd): # mu_2 = torch.pow(mu, 2) # sd_2 = torch.pow(sd, 2) # encoding_loss = (mu_2 + sd_2 - torch.log(sd_2)).sum() / mu_2.size(0) # return encoding_loss mu_2 = torch.pow(mu, 2) encoding_loss = torch.mean(mu_2) return encoding_loss def gen_update(self, x_a, x_b, hyperparameters): self.gen_opt.zero_grad() # encode h_a, n_a = self.gen_a.encode(x_a) h_b, n_b = self.gen_b.encode(x_b) # decode (within domain) x_a_recon = self.gen_a.decode(h_a + n_a) x_b_recon = self.gen_b.decode(h_b + n_b) # decode (cross domain) x_ba = self.gen_a.decode(h_b + n_b) x_ab = self.gen_b.decode(h_a + n_a) # encode again h_b_recon, n_b_recon = self.gen_a.encode(x_ba) h_a_recon, n_a_recon = self.gen_b.encode(x_ab) # decode again (if needed) x_aba = self.gen_a.decode(h_a_recon + n_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None x_bab = self.gen_b.decode(h_b_recon + n_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None # reconstruction loss self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a) self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b) self.loss_gen_recon_kl_a = self.__compute_kl(h_a) self.loss_gen_recon_kl_b = self.__compute_kl(h_b) self.loss_gen_cyc_x_a = self.recon_criterion(x_aba, x_a) self.loss_gen_cyc_x_b = self.recon_criterion(x_bab, x_b) self.loss_gen_recon_kl_cyc_aba = self.__compute_kl(h_a_recon) self.loss_gen_recon_kl_cyc_bab = self.__compute_kl(h_b_recon) # GAN loss self.loss_gen_adv_a = 
self.dis_a.calc_gen_loss(x_ba) self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab) # domain-invariant perceptual loss self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0 self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0 # total loss self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \ hyperparameters['gan_w'] * self.loss_gen_adv_b + \ hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \ hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \ hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \ hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \ hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_a + \ hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \ hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_b + \ hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab + \ hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \ hyperparameters['vgg_w'] * self.loss_gen_vgg_b self.loss_gen_total.backward() self.gen_opt.step() def compute_vgg_loss(self, vgg, img, target): img_vgg = vgg_preprocess(img) target_vgg = vgg_preprocess(target) img_fea = vgg(img_vgg) target_fea = vgg(target_vgg) return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2) def sample(self, x_a, x_b): self.eval() x_a.volatile = True x_b.volatile = True x_a_recon, x_b_recon, x_ba, x_ab = [], [], [], [] for i in range(x_a.size(0)): h_a, _ = self.gen_a.encode(x_a[i].unsqueeze(0)) h_b, _ = self.gen_b.encode(x_b[i].unsqueeze(0)) x_a_recon.append(self.gen_a.decode(h_a)) x_b_recon.append(self.gen_b.decode(h_b)) x_ba.append(self.gen_a.decode(h_b)) x_ab.append(self.gen_b.decode(h_a)) x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon) x_ba = torch.cat(x_ba) x_ab = torch.cat(x_ab) self.train() return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba def dis_update(self, x_a, x_b, hyperparameters): self.dis_opt.zero_grad() # encode h_a, n_a = self.gen_a.encode(x_a) h_b, n_b = self.gen_b.encode(x_b) # decode (cross domain) x_ba = self.gen_a.decode(h_b + n_b) x_ab = self.gen_b.decode(h_a + n_a) # D loss self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a) self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b) self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b self.loss_dis_total.backward() self.dis_opt.step() def update_learning_rate(self): if self.dis_scheduler is not None: self.dis_scheduler.step() if self.gen_scheduler is not None: self.gen_scheduler.step() def resume(self, checkpoint_dir, hyperparameters): # Load generators last_model_name = get_model_list(checkpoint_dir, "gen") state_dict = torch.load(last_model_name) self.gen_a.load_state_dict(state_dict['a']) self.gen_b.load_state_dict(state_dict['b']) iterations = int(last_model_name[-11:-3]) # Load discriminators last_model_name = get_model_list(checkpoint_dir, "dis") state_dict = torch.load(last_model_name) self.dis_a.load_state_dict(state_dict['a']) self.dis_b.load_state_dict(state_dict['b']) # Load optimizers state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt')) self.dis_opt.load_state_dict(state_dict['dis']) self.gen_opt.load_state_dict(state_dict['gen']) # Reinitilize schedulers self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations) self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations) print('Resume from iteration %d' % 
              iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
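
# Example (not part of the original trainer): a minimal training-loop sketch for
# the two-domain UNIT_Trainer above. The loader objects, the tensors they yield
# and the `save_every`/`max_iter` arguments are illustrative assumptions; only
# the trainer methods called here (dis_update, gen_update, update_learning_rate,
# save) come from the class itself.
def _example_unit_train_loop(trainer, loader_a, loader_b, hyperparameters,
                             snapshot_dir, save_every=1000, max_iter=10000):
    # Alternate discriminator and generator updates, one batch per domain.
    iterations = 0
    for x_a, x_b in zip(loader_a, loader_b):
        x_a, x_b = x_a.cuda(), x_b.cuda()
        trainer.dis_update(x_a, x_b, hyperparameters)
        trainer.gen_update(x_a, x_b, hyperparameters)
        trainer.update_learning_rate()
        iterations += 1
        if iterations % save_every == 0:
            trainer.save(snapshot_dir, iterations)
        if iterations >= max_iter:
            break
    return iterations
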
class UNIT_Trainer(nn.Module):

    def __init__(self, hyperparameters, resume_epoch=-1, snapshot_dir=None):
        super(UNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']

        # Initiate the networks.
        self.gen = VAEGen(
            hyperparameters['input_dim'] + hyperparameters['n_datasets'],
            hyperparameters['gen'],
            hyperparameters['n_datasets'])  # Shared auto-encoder, conditioned on the dataset one-hot code.
        self.dis = MsImageDis(
            hyperparameters['input_dim'] + hyperparameters['n_datasets'],
            hyperparameters['dis'])  # Shared discriminator, conditioned on the dataset one-hot code.
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.sup = UNet(input_channels=hyperparameters['input_dim'],
                        num_classes=2).cuda()

        # Setup the optimizers.
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis.parameters())
        gen_params = list(self.gen.parameters()) + list(self.sup.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr, betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr, betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization.
        self.apply(weights_init(hyperparameters['init']))
        self.dis.apply(weights_init('gaussian'))

        # Presetting one hot encoding vectors.
        self.one_hot_img = torch.zeros(hyperparameters['n_datasets'],
                                       hyperparameters['batch_size'],
                                       hyperparameters['n_datasets'],
                                       256, 256).cuda()
        self.one_hot_h = torch.zeros(hyperparameters['n_datasets'],
                                     hyperparameters['batch_size'],
                                     hyperparameters['n_datasets'],
                                     64, 64).cuda()
        for i in range(hyperparameters['n_datasets']):
            self.one_hot_img[i, :, i, :, :].fill_(1)
            self.one_hot_h[i, :, i, :, :].fill_(1)

        if resume_epoch != -1:
            self.resume(snapshot_dir, hyperparameters)

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def semi_criterion(self, input, target):
        loss = CrossEntropyLoss2d(size_average=False).cuda()
        return loss(input, target)

    def forward(self, x_a, x_b, d_index_a=0, d_index_b=1):
        # The shared generator expects the dataset one-hot code concatenated to
        # its input, so this method takes dataset indices; the defaults (0, 1)
        # are placeholders, not values from the original code.
        self.eval()
        x_a.volatile = True
        x_b.volatile = True
        one_hot_x_a = torch.cat([x_a, self.one_hot_img[d_index_a]], 1)
        one_hot_x_b = torch.cat([x_b, self.one_hot_img[d_index_b]], 1)
        h_a, _ = self.gen.encode(one_hot_x_a)
        h_b, _ = self.gen.encode(one_hot_x_b)
        x_ba = self.gen.decode(torch.cat([h_b, self.one_hot_h[d_index_a]], 1))
        x_ab = self.gen.decode(torch.cat([h_a, self.one_hot_h[d_index_b]], 1))
        self.train()
        return x_ab, x_ba

    def __compute_kl(self, mu):
        # def _compute_kl(self, mu, sd):
        #     mu_2 = torch.pow(mu, 2)
        #     sd_2 = torch.pow(sd, 2)
        #     encoding_loss = (mu_2 + sd_2 - torch.log(sd_2)).sum() / mu_2.size(0)
        #     return encoding_loss
        mu_2 = torch.pow(mu, 2)
        encoding_loss = torch.mean(mu_2)
        return encoding_loss

    def set_gen_trainable(self, train_bool):
        if train_bool:
            self.gen.train()
            for param in self.gen.parameters():
                param.requires_grad = True
        else:
            self.gen.eval()
            for param in self.gen.parameters():
                # Freeze the generator; requires_grad must be False here.
                param.requires_grad = False

    def set_sup_trainable(self, train_bool):
        if train_bool:
            self.sup.train()
            for param in self.sup.parameters():
                param.requires_grad = True
        else:
            self.sup.eval()
            for param in self.sup.parameters():
                # Freeze the supervised model; requires_grad must be False here.
                param.requires_grad = False

    def sup_update(self, x_a, x_b, y_a, y_b, d_index_a, d_index_b, use_a, use_b,
                   hyperparameters):
        self.gen_opt.zero_grad()

        # Encode.
        one_hot_x_a = torch.cat([x_a, self.one_hot_img[d_index_a]], 1)
        one_hot_x_b = torch.cat([x_b, self.one_hot_img[d_index_b]], 1)
        h_a, n_a = self.gen.encode(one_hot_x_a)
        h_b, n_b = self.gen.encode(one_hot_x_b)

        # Decode (within domain).
        one_hot_h_a = torch.cat([h_a + n_a, self.one_hot_h[d_index_a]], 1)
        one_hot_h_b = torch.cat([h_b + n_b, self.one_hot_h[d_index_b]], 1)
        x_a_recon = self.gen.decode(one_hot_h_a)
        x_b_recon = self.gen.decode(one_hot_h_b)

        # Decode (cross domain).
        one_hot_h_ab = torch.cat([h_a + n_a, self.one_hot_h[d_index_b]], 1)
        one_hot_h_ba = torch.cat([h_b + n_b, self.one_hot_h[d_index_a]], 1)
        x_ba = self.gen.decode(one_hot_h_ba)
        x_ab = self.gen.decode(one_hot_h_ab)

        # Encode again.
        one_hot_x_ba = torch.cat([x_ba, self.one_hot_img[d_index_a]], 1)
        one_hot_x_ab = torch.cat([x_ab, self.one_hot_img[d_index_b]], 1)
        h_b_recon, n_b_recon = self.gen.encode(one_hot_x_ba)
        h_a_recon, n_a_recon = self.gen.encode(one_hot_x_ab)

        # Decode again (if needed).
        one_hot_h_a_recon = torch.cat(
            [h_a_recon + n_a_recon, self.one_hot_h[d_index_a]], 1)
        one_hot_h_b_recon = torch.cat(
            [h_b_recon + n_b_recon, self.one_hot_h[d_index_b]], 1)
        x_aba = self.gen.decode(
            one_hot_h_a_recon
        ) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen.decode(
            one_hot_h_b_recon
        ) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # Forwarding through supervised model.
        p_a = None
        p_b = None
        loss_semi_a = None
        loss_semi_b = None

        has_a_label = (h_a[use_a, :, :, :].size(0) != 0)
        if has_a_label:
            p_a = self.sup(h_a, use_a, True)
            p_a_recon = self.sup(h_a_recon, use_a, True)
            loss_semi_a = self.semi_criterion(p_a, y_a[use_a, :, :]) + \
                          self.semi_criterion(p_a_recon, y_a[use_a, :, :])

        has_b_label = (h_b[use_b, :, :, :].size(0) != 0)
        if has_b_label:
            p_b = self.sup(h_b, use_b, True)
            # Use the re-encoded latent here, mirroring the a-branch above.
            p_b_recon = self.sup(h_b_recon, use_b, True)
            loss_semi_b = self.semi_criterion(p_b, y_b[use_b, :, :]) + \
                          self.semi_criterion(p_b_recon, y_b[use_b, :, :])

        self.loss_gen_total = None
        if loss_semi_a is not None and loss_semi_b is not None:
            self.loss_gen_total = loss_semi_a + loss_semi_b
        elif loss_semi_a is not None:
            self.loss_gen_total = loss_semi_a
        elif loss_semi_b is not None:
            self.loss_gen_total = loss_semi_b

        if self.loss_gen_total is not None:
            self.loss_gen_total.backward()
            self.gen_opt.step()

    def sup_forward(self, x, y, d_index, hyperparameters):
        self.sup.eval()
        # Encoding content image.
        one_hot_x = torch.cat([x, self.one_hot_img[d_index, 0].unsqueeze(0)], 1)
        hidden, _ = self.gen.encode(one_hot_x)
        # Forwarding on supervised model.
        y_pred = self.sup(hidden, only_prediction=True)
        # Computing metrics.
        pred = y_pred.data.max(1)[1].squeeze_(1).squeeze_(0).cpu().numpy()
        jacc = jaccard(pred, y.cpu().squeeze(0).numpy())
        return jacc, pred, hidden

    def gen_update(self, x_a, x_b, d_index_a, d_index_b, hyperparameters):
        self.gen_opt.zero_grad()

        # Encode.
        one_hot_x_a = torch.cat([x_a, self.one_hot_img[d_index_a]], 1)
        one_hot_x_b = torch.cat([x_b, self.one_hot_img[d_index_b]], 1)
        h_a, n_a = self.gen.encode(one_hot_x_a)
        h_b, n_b = self.gen.encode(one_hot_x_b)

        # Decode (within domain).
        one_hot_h_a = torch.cat([h_a + n_a, self.one_hot_h[d_index_a]], 1)
        one_hot_h_b = torch.cat([h_b + n_b, self.one_hot_h[d_index_b]], 1)
        x_a_recon = self.gen.decode(one_hot_h_a)
        x_b_recon = self.gen.decode(one_hot_h_b)

        # Decode (cross domain).
        one_hot_h_ab = torch.cat([h_a + n_a, self.one_hot_h[d_index_b]], 1)
        one_hot_h_ba = torch.cat([h_b + n_b, self.one_hot_h[d_index_a]], 1)
        x_ba = self.gen.decode(one_hot_h_ba)
        x_ab = self.gen.decode(one_hot_h_ab)

        # Encode again.
        one_hot_x_ba = torch.cat([x_ba, self.one_hot_img[d_index_a]], 1)
        one_hot_x_ab = torch.cat([x_ab, self.one_hot_img[d_index_b]], 1)
        h_b_recon, n_b_recon = self.gen.encode(one_hot_x_ba)
        h_a_recon, n_a_recon = self.gen.encode(one_hot_x_ab)

        # Decode again (if needed).
        one_hot_h_a_recon = torch.cat(
            [h_a_recon + n_a_recon, self.one_hot_h[d_index_a]], 1)
        one_hot_h_b_recon = torch.cat(
            [h_b_recon + n_b_recon, self.one_hot_h[d_index_b]], 1)
        x_aba = self.gen.decode(
            one_hot_h_a_recon
        ) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen.decode(
            one_hot_h_b_recon
        ) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # Reconstruction loss.
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_kl_a = self.__compute_kl(h_a)
        self.loss_gen_recon_kl_b = self.__compute_kl(h_b)
        self.loss_gen_cyc_x_a = self.recon_criterion(x_aba, x_a)
        self.loss_gen_cyc_x_b = self.recon_criterion(x_bab, x_b)
        self.loss_gen_recon_kl_cyc_aba = self.__compute_kl(h_a_recon)
        self.loss_gen_recon_kl_cyc_bab = self.__compute_kl(h_b_recon)

        # GAN loss.
        self.loss_gen_adv_a = self.dis.calc_gen_loss(one_hot_x_ba)
        self.loss_gen_adv_b = self.dis.calc_gen_loss(one_hot_x_ab)

        # Total loss.
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
            hyperparameters['gan_w'] * self.loss_gen_adv_b + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_a + \
            hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_b + \
            hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def sample(self, x_a, x_b, d_index_a=0, d_index_b=1):
        # Adapted to the single shared generator: each image is paired with its
        # dataset one-hot code before encoding/decoding. The dataset index
        # arguments and their defaults (0, 1) are placeholders, not values from
        # the original code.
        self.eval()
        x_a.volatile = True
        x_b.volatile = True
        x_a_recon, x_b_recon, x_ba, x_ab = [], [], [], []
        for i in range(x_a.size(0)):
            one_hot_x_a = torch.cat(
                [x_a[i].unsqueeze(0), self.one_hot_img[d_index_a, 0].unsqueeze(0)], 1)
            one_hot_x_b = torch.cat(
                [x_b[i].unsqueeze(0), self.one_hot_img[d_index_b, 0].unsqueeze(0)], 1)
            h_a, _ = self.gen.encode(one_hot_x_a)
            h_b, _ = self.gen.encode(one_hot_x_b)
            one_hot_h_a = self.one_hot_h[d_index_a, 0].unsqueeze(0)
            one_hot_h_b = self.one_hot_h[d_index_b, 0].unsqueeze(0)
            x_a_recon.append(self.gen.decode(torch.cat([h_a, one_hot_h_a], 1)))
            x_b_recon.append(self.gen.decode(torch.cat([h_b, one_hot_h_b], 1)))
            x_ba.append(self.gen.decode(torch.cat([h_b, one_hot_h_a], 1)))
            x_ab.append(self.gen.decode(torch.cat([h_a, one_hot_h_b], 1)))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba = torch.cat(x_ba)
        x_ab = torch.cat(x_ab)
        self.train()
        return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba

    def dis_update(self, x_a, x_b, d_index_a, d_index_b, hyperparameters):
        self.dis_opt.zero_grad()

        # Encode.
        one_hot_x_a = torch.cat([x_a, self.one_hot_img[d_index_a]], 1)
        one_hot_x_b = torch.cat([x_b, self.one_hot_img[d_index_b]], 1)
        h_a, n_a = self.gen.encode(one_hot_x_a)
        h_b, n_b = self.gen.encode(one_hot_x_b)

        # Decode (cross domain).
        one_hot_h_ab = torch.cat([h_a + n_a, self.one_hot_h[d_index_b]], 1)
        one_hot_h_ba = torch.cat([h_b + n_b, self.one_hot_h[d_index_a]], 1)
        x_ba = self.gen.decode(one_hot_h_ba)
        x_ab = self.gen.decode(one_hot_h_ab)

        # D loss.
        one_hot_x_ba = torch.cat([x_ba, self.one_hot_img[d_index_a]], 1)
        one_hot_x_ab = torch.cat([x_ab, self.one_hot_img[d_index_b]], 1)
        self.loss_dis_a = self.dis.calc_dis_loss(one_hot_x_ba.detach(), one_hot_x_a)
        self.loss_dis_b = self.dis.calc_dis_loss(one_hot_x_ab.detach(), one_hot_x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + \
            hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators.
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen.load_state_dict(state_dict)
        epochs = int(last_model_name[-11:-3])
        # Load discriminators.
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis.load_state_dict(state_dict)
        # Load supervised model.
        last_model_name = get_model_list(checkpoint_dir, "sup")
        state_dict = torch.load(last_model_name)
        self.sup.load_state_dict(state_dict)
        # Load optimizers.
        last_model_name = get_model_list(checkpoint_dir, "opt")
        state_dict = torch.load(last_model_name)
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Move any loaded optimizer state tensors to the GPU.
        for state in self.dis_opt.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()
        for state in self.gen_opt.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()
        # Reinitialize schedulers.
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, epochs)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, epochs)
        print('Resume from epoch %d' % epochs)
        return epochs

    def save(self, snapshot_dir, epoch):
        # Save generators, discriminators, and optimizers.
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % epoch)
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % epoch)
        sup_name = os.path.join(snapshot_dir, 'sup_%08d.pt' % epoch)
        opt_name = os.path.join(snapshot_dir, 'opt_%08d.pt' % epoch)
        torch.save(self.gen.state_dict(), gen_name)
        torch.save(self.dis.state_dict(), dis_name)
        torch.save(self.sup.state_dict(), sup_name)
        torch.save({'dis': self.dis_opt.state_dict(),
                    'gen': self.gen_opt.state_dict()}, opt_name)
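
# Example (not part of the original trainer): a minimal per-epoch sketch for the
# multi-dataset UNIT_Trainer above. The batch structure (images, labels, dataset
# indices and label masks for two datasets) and the `save_every` argument are
# illustrative assumptions; the trainer methods called here (dis_update,
# gen_update, sup_update, update_learning_rate, save) come from the class itself.
def _example_multidataset_epoch(trainer, paired_loader, hyperparameters,
                                snapshot_dir, epoch, save_every=10):
    for (x_a, y_a, d_index_a, use_a), (x_b, y_b, d_index_b, use_b) in paired_loader:
        x_a, x_b = x_a.cuda(), x_b.cuda()
        y_a, y_b = y_a.cuda(), y_b.cuda()

        # Unsupervised translation updates: discriminator step, then generator
        # step, both conditioned on the dataset one-hot codes picked by d_index_*.
        trainer.dis_update(x_a, x_b, d_index_a, d_index_b, hyperparameters)
        trainer.gen_update(x_a, x_b, d_index_a, d_index_b, hyperparameters)

        # Semi-supervised segmentation update on the shared latent codes, applied
        # only to the samples flagged as labeled by use_a / use_b.
        trainer.sup_update(x_a, x_b, y_a, y_b, d_index_a, d_index_b,
                           use_a, use_b, hyperparameters)

    trainer.update_learning_rate()
    if (epoch + 1) % save_every == 0:
        trainer.save(snapshot_dir, epoch + 1)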