Ejemplo n.º 1
0
class UNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        """Create the networks, optimizers, schedulers and optional VGG model.

        Args:
            hyperparameters: mapping read for 'lr', 'beta1', 'beta2',
                'weight_decay', 'input_dim_a', 'input_dim_b', 'gen', 'dis'
                and 'init'. 'vgg_w' is optional; when it is present and
                positive, 'vgg_model_path' is read as well.
        """
        super(UNIT_Trainer, self).__init__()
        hp = hyperparameters
        lr = hp['lr']

        # One auto-encoder and one discriminator per domain.
        self.gen_a = VAEGen(hp['input_dim_a'], hp['gen'])
        self.gen_b = VAEGen(hp['input_dim_b'], hp['gen'])
        self.dis_a = MsImageDis(hp['input_dim_a'], hp['dis'])
        self.dis_b = MsImageDis(hp['input_dim_b'], hp['dis'])
        # Non-affine normalisation applied to VGG features in compute_vgg_loss.
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

        # One Adam optimizer per side, over the trainable parameters only.
        betas = (hp['beta1'], hp['beta2'])
        trainable_dis = [p for net in (self.dis_a, self.dis_b)
                         for p in net.parameters() if p.requires_grad]
        trainable_gen = [p for net in (self.gen_a, self.gen_b)
                         for p in net.parameters() if p.requires_grad]
        self.dis_opt = torch.optim.Adam(trainable_dis, lr=lr, betas=betas,
                                        weight_decay=hp['weight_decay'])
        self.gen_opt = torch.optim.Adam(trainable_gen, lr=lr, betas=betas,
                                        weight_decay=hp['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hp)
        self.gen_scheduler = get_scheduler(self.gen_opt, hp)

        # Initialise everything with the configured scheme, then re-initialise
        # the discriminators with the 'gaussian' scheme.
        self.apply(weights_init(hp['init']))
        for dis in (self.dis_a, self.dis_b):
            dis.apply(weights_init('gaussian'))

        # The VGG network is only loaded when the perceptual loss is enabled,
        # and it is frozen so that it never receives parameter updates.
        if 'vgg_w' in hp and hp['vgg_w'] > 0:
            self.vgg = load_vgg16(hp['vgg_model_path'] + '/models')
            self.vgg.eval()
            for vgg_param in self.vgg.parameters():
                vgg_param.requires_grad = False

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        h_a, _ = self.gen_a.encode(x_a)
        h_b, _ = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(h_b)
        x_ab = self.gen_b.decode(h_a)
        self.train()
        return x_ab, x_ba

    def __compute_kl(self, mu):
        # def _compute_kl(self, mu, sd):
        # mu_2 = torch.pow(mu, 2)
        # sd_2 = torch.pow(sd, 2)
        # encoding_loss = (mu_2 + sd_2 - torch.log(sd_2)).sum() / mu_2.size(0)
        # return encoding_loss
        mu_2 = torch.pow(mu, 2)
        encoding_loss = torch.mean(mu_2)
        return encoding_loss

    def gen_update(self, x_a, x_b, hyperparameters):
        self.gen_opt.zero_grad()
        # encode
        h_a, n_a = self.gen_a.encode(x_a)
        h_b, n_b = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(h_a + n_a)
        x_b_recon = self.gen_b.decode(h_b + n_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(h_b + n_b)
        x_ab = self.gen_b.decode(h_a + n_a)
        # encode again
        h_b_recon, n_b_recon = self.gen_a.encode(x_ba)
        h_a_recon, n_a_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(h_a_recon + n_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(h_b_recon + n_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_kl_a = self.__compute_kl(h_a)
        self.loss_gen_recon_kl_b = self.__compute_kl(h_b)
        self.loss_gen_cyc_x_a = self.recon_criterion(x_aba, x_a)
        self.loss_gen_cyc_x_b = self.recon_criterion(x_bab, x_b)
        self.loss_gen_recon_kl_cyc_aba = self.__compute_kl(h_a_recon)
        self.loss_gen_recon_kl_cyc_bab = self.__compute_kl(h_b_recon)
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_a + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_b + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        """Return the mean squared difference of instance-normalised VGG features.

        Args:
            vgg: callable feature extractor applied to the preprocessed images.
            img: image batch to compare.
            target: reference image batch.
        """
        img_in, target_in = vgg_preprocess(img), vgg_preprocess(target)
        img_feat, target_feat = vgg(img_in), vgg(target_in)
        feat_diff = self.instancenorm(img_feat) - self.instancenorm(target_feat)
        return torch.mean(feat_diff ** 2)

    def sample(self, x_a, x_b):
        self.eval()
        x_a_recon, x_b_recon, x_ba, x_ab = [], [], [], []
        for i in range(x_a.size(0)):
            h_a, _ = self.gen_a.encode(x_a[i].unsqueeze(0))
            h_b, _ = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(h_a))
            x_b_recon.append(self.gen_b.decode(h_b))
            x_ba.append(self.gen_a.decode(h_b))
            x_ab.append(self.gen_b.decode(h_a))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba = torch.cat(x_ba)
        x_ab = torch.cat(x_ab)
        self.train()
        return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        # encode
        h_a, n_a = self.gen_a.encode(x_a)
        h_b, n_b = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(h_b + n_b)
        x_ab = self.gen_b.decode(h_a + n_a)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore networks, optimizers and schedulers from ``checkpoint_dir``.

        Args:
            checkpoint_dir: directory holding the generator and discriminator
                checkpoints and 'optimizer.pt'.
            hyperparameters: mapping forwarded to ``get_scheduler``.

        Returns:
            The iteration number parsed from the generator checkpoint file name.
        """
        # Generators
        gen_path = get_model_list(checkpoint_dir, "gen")
        gen_weights = torch.load(gen_path)
        for net, key in ((self.gen_a, 'a'), (self.gen_b, 'b')):
            net.load_state_dict(gen_weights[key])
        # save() writes 'gen_%08d.pt': the 8 characters before the 3-character
        # suffix are the iteration count.
        iterations = int(gen_path[-11:-3])

        # Discriminators
        dis_path = get_model_list(checkpoint_dir, "dis")
        dis_weights = torch.load(dis_path)
        for net, key in ((self.dis_a, 'a'), (self.dis_b, 'b')):
            net.load_state_dict(dis_weights[key])

        # Optimizers
        opt_states = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(opt_states['dis'])
        self.gen_opt.load_state_dict(opt_states['gen'])

        # Rebuild the schedulers so they continue from the restored iteration.
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
Ejemplo n.º 2
0
class Trainer(nn.Module):
    def __init__(self, hyperparameters):
        """Create the networks, fixed sampling traits, optimizers and schedulers.

        Args:
            hyperparameters: mapping read for 'lr', 'weight_decay',
                'input_dim', 'basis_encoder_dims', 'trait_encoder_dims',
                'decoder_dims', 'dis_dims', 'init' and the nested
                ``['gen']['trait_dim']``.
        """
        super(Trainer, self).__init__()
        hp = hyperparameters
        lr = hp['lr']
        self.trait_dim = hp['gen']['trait_dim']

        # Both auto-encoders and both discriminators share one configuration.
        gen_args = (hp['input_dim'], hp['basis_encoder_dims'],
                    hp['trait_encoder_dims'], hp['decoder_dims'], self.trait_dim)
        self.gen_a = VAEGen(*gen_args)  # auto-encoder for domain a
        self.gen_b = VAEGen(*gen_args)  # auto-encoder for domain b
        self.dis_a = Discriminator(hp['input_dim'], hp['dis_dims'], 1)  # domain a
        self.dis_b = Discriminator(hp['input_dim'], hp['dis_dims'], 1)  # domain b

        # Fixed noise used by forward().
        # NOTE(review): these are 4-D (8, trait_dim, 1, 1) while gen_update and
        # dis_update draw 2-D (N, trait_dim) traits -- confirm decode accepts both.
        self.trait_a = torch.randn(8, self.trait_dim, 1, 1)
        self.trait_b = torch.randn(8, self.trait_dim, 1, 1)

        # Optimizers over the trainable parameters of each side.
        all_dis_params = [*self.dis_a.parameters(), *self.dis_b.parameters()]
        all_gen_params = [*self.gen_a.parameters(), *self.gen_b.parameters()]
        # NOTE(review): prints every generator parameter shape; looks like
        # leftover debugging output.
        for _p in all_gen_params:
            print(_p.data.shape)
        self.dis_opt = torch.optim.Adam(
            [p for p in all_dis_params if p.requires_grad],
            lr=lr, weight_decay=hp['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in all_gen_params if p.requires_grad],
            lr=lr, weight_decay=hp['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hp)
        self.gen_scheduler = get_scheduler(self.gen_opt, hp)

        # Initialise with the configured scheme, then re-initialise all four
        # sub-networks with the 'gaussian' scheme.
        self.apply(weights_init(hp['init']))
        for net in (self.gen_a, self.gen_b, self.dis_a, self.dis_b):
            net.apply(weights_init('gaussian'))

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        trait_a = Variable(self.trait_a)
        trait_b = Variable(self.trait_b)
        basis_a, trait_a_fake = self.gen_a.encode(x_a)
        basis_b, trait_b_fake = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(basis_b, trait_a)
        x_ab = self.gen_b.decode(basis_a, trait_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        self.gen_opt.zero_grad()
        trait_a = Variable(torch.randn(x_a.size(0), self.trait_dim))
        trait_b = Variable(torch.randn(x_b.size(0), self.trait_dim))
        # encode
        basis_a, trait_a_prime = self.gen_a.encode(x_a)
        basis_b, trait_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(basis_a, trait_a_prime)
        x_b_recon = self.gen_b.decode(basis_b, trait_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(basis_b, trait_a)
        x_ab = self.gen_b.decode(basis_a, trait_b)
        # encode again
        basis_b_recon, trait_a_recon = self.gen_a.encode(x_ba)
        basis_a_recon, trait_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(
            basis_a_recon,
            trait_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(
            basis_b_recon,
            trait_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_trait_a = self.recon_criterion(
            trait_a_recon, trait_a)
        self.loss_gen_recon_trait_b = self.recon_criterion(
            trait_b_recon, trait_b)
        self.loss_gen_recon_basis_a = self.recon_criterion(
            basis_a_recon, basis_a)
        self.loss_gen_recon_basis_b = self.recon_criterion(
            basis_b_recon, basis_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)

        # total loss
        self.loss_gen_total = hyperparameters[
            'gan_w'] * self.loss_gen_adv_a + \
            hyperparameters['gan_w'] * self.loss_gen_adv_b + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_trait_w'] * self.loss_gen_recon_trait_a + \
            hyperparameters['recon_basis_w'] * self.loss_gen_recon_basis_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_trait_w'] * self.loss_gen_recon_trait_b + \
            hyperparameters['recon_basis_w'] * self.loss_gen_recon_basis_b + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b

        self.loss_gen_total.backward()
        self.gen_opt.step()

    # def sample(self, x_a, x_b):
    #     self.eval()
    #     s_a1 = Variable(self.s_a)
    #     s_b1 = Variable(self.s_b)
    #     s_a2 = Variable(torch.randn(x_a.size(0), self.trait_dim, 1, 1))
    #     s_b2 = Variable(torch.randn(x_b.size(0), self.trait_dim, 1, 1))
    #     x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
    #     for i in range(x_a.size(0)):
    #         c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
    #         c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
    #         x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
    #         x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
    #         x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
    #         x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
    #         x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
    #         x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
    #     x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
    #     x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
    #     x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
    #     self.train()
    #     return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        trait_a = Variable(torch.randn(x_a.size(0), self.trait_dim))
        trait_b = Variable(torch.randn(x_b.size(0), self.trait_dim))
        # encode
        basis_a, _ = self.gen_a.encode(x_a)
        basis_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(basis_b, trait_a)
        x_ab = self.gen_b.decode(basis_a, trait_b)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba, x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab, x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * \
            self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore networks, optimizers and schedulers from ``checkpoint_dir``.

        Args:
            checkpoint_dir: directory holding the generator and discriminator
                checkpoints and 'optimizer.pt'.
            hyperparameters: mapping forwarded to ``get_scheduler``.

        Returns:
            The iteration number parsed from the generator checkpoint file name.
        """
        # Generators
        gen_path = get_model_list(checkpoint_dir, "gen")
        gen_weights = torch.load(gen_path)
        for net, key in ((self.gen_a, 'a'), (self.gen_b, 'b')):
            net.load_state_dict(gen_weights[key])
        # save() writes 'gen_%08d.pt': the 8 characters before the 3-character
        # suffix are the iteration count.
        iterations = int(gen_path[-11:-3])

        # Discriminators
        dis_path = get_model_list(checkpoint_dir, "dis")
        dis_weights = torch.load(dis_path)
        for net, key in ((self.dis_a, 'a'), (self.dis_b, 'b')):
            net.load_state_dict(dis_weights[key])

        # Optimizers
        opt_states = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(opt_states['dis'])
        self.gen_opt.load_state_dict(opt_states['gen'])

        # Rebuild the schedulers so they continue from the restored iteration.
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
Ejemplo n.º 3
0
class UNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        """Create the networks, optimizers, schedulers and optional VGG model.

        Args:
            hyperparameters: mapping read for 'lr', 'beta1', 'beta2',
                'weight_decay', 'input_dim_a', 'input_dim_b', 'gen', 'dis'
                and 'init'. 'vgg_w' is optional; when it is present and
                positive, 'vgg_model_path' is read as well.
        """
        super(UNIT_Trainer, self).__init__()
        hp = hyperparameters
        lr = hp['lr']

        # One auto-encoder and one discriminator per domain.
        self.gen_a = VAEGen(hp['input_dim_a'], hp['gen'])
        self.gen_b = VAEGen(hp['input_dim_b'], hp['gen'])
        self.dis_a = MsImageDis(hp['input_dim_a'], hp['dis'])
        self.dis_b = MsImageDis(hp['input_dim_b'], hp['dis'])
        # Non-affine normalisation applied to VGG features in compute_vgg_loss.
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

        # One Adam optimizer per side, over the trainable parameters only.
        betas = (hp['beta1'], hp['beta2'])
        trainable_dis = [p for net in (self.dis_a, self.dis_b)
                         for p in net.parameters() if p.requires_grad]
        trainable_gen = [p for net in (self.gen_a, self.gen_b)
                         for p in net.parameters() if p.requires_grad]
        self.dis_opt = torch.optim.Adam(trainable_dis, lr=lr, betas=betas,
                                        weight_decay=hp['weight_decay'])
        self.gen_opt = torch.optim.Adam(trainable_gen, lr=lr, betas=betas,
                                        weight_decay=hp['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hp)
        self.gen_scheduler = get_scheduler(self.gen_opt, hp)

        # Initialise everything with the configured scheme, then re-initialise
        # the discriminators with the 'gaussian' scheme.
        self.apply(weights_init(hp['init']))
        for dis in (self.dis_a, self.dis_b):
            dis.apply(weights_init('gaussian'))

        # The VGG network is only loaded when the perceptual loss is enabled,
        # and it is frozen so that it never receives parameter updates.
        if 'vgg_w' in hp and hp['vgg_w'] > 0:
            self.vgg = load_vgg16(hp['vgg_model_path'] + '/models')
            self.vgg.eval()
            for vgg_param in self.vgg.parameters():
                vgg_param.requires_grad = False

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        x_a.volatile = True
        x_b.volatile = True
        h_a, _ = self.gen_a.encode(x_a)
        h_b, _ = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(h_b)
        x_ab = self.gen_b.decode(h_a)
        self.train()
        return x_ab, x_ba

    def __compute_kl(self, mu):
        # def _compute_kl(self, mu, sd):
        # mu_2 = torch.pow(mu, 2)
        # sd_2 = torch.pow(sd, 2)
        # encoding_loss = (mu_2 + sd_2 - torch.log(sd_2)).sum() / mu_2.size(0)
        # return encoding_loss
        mu_2 = torch.pow(mu, 2)
        encoding_loss = torch.mean(mu_2)
        return encoding_loss

    def gen_update(self, x_a, x_b, hyperparameters):
        self.gen_opt.zero_grad()
        # encode
        h_a, n_a = self.gen_a.encode(x_a)
        h_b, n_b = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(h_a + n_a)
        x_b_recon = self.gen_b.decode(h_b + n_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(h_b + n_b)
        x_ab = self.gen_b.decode(h_a + n_a)
        # encode again
        h_b_recon, n_b_recon = self.gen_a.encode(x_ba)
        h_a_recon, n_a_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(h_a_recon + n_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(h_b_recon + n_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_kl_a = self.__compute_kl(h_a)
        self.loss_gen_recon_kl_b = self.__compute_kl(h_b)
        self.loss_gen_cyc_x_a = self.recon_criterion(x_aba, x_a)
        self.loss_gen_cyc_x_b = self.recon_criterion(x_bab, x_b)
        self.loss_gen_recon_kl_cyc_aba = self.__compute_kl(h_a_recon)
        self.loss_gen_recon_kl_cyc_bab = self.__compute_kl(h_b_recon)
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_a + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_b + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        """Return the mean squared difference of instance-normalised VGG features.

        Args:
            vgg: callable feature extractor applied to the preprocessed images.
            img: image batch to compare.
            target: reference image batch.
        """
        img_in, target_in = vgg_preprocess(img), vgg_preprocess(target)
        img_feat, target_feat = vgg(img_in), vgg(target_in)
        feat_diff = self.instancenorm(img_feat) - self.instancenorm(target_feat)
        return torch.mean(feat_diff ** 2)

    def sample(self, x_a, x_b):
        self.eval()
        x_a.volatile = True
        x_b.volatile = True
        x_a_recon, x_b_recon, x_ba, x_ab = [], [], [], []
        for i in range(x_a.size(0)):
            h_a, _ = self.gen_a.encode(x_a[i].unsqueeze(0))
            h_b, _ = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(h_a))
            x_b_recon.append(self.gen_b.decode(h_b))
            x_ba.append(self.gen_a.decode(h_b))
            x_ab.append(self.gen_b.decode(h_a))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba = torch.cat(x_ba)
        x_ab = torch.cat(x_ab)
        self.train()
        return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        # encode
        h_a, n_a = self.gen_a.encode(x_a)
        h_b, n_b = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(h_b + n_b)
        x_ab = self.gen_b.decode(h_a + n_a)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore networks, optimizers and schedulers from ``checkpoint_dir``.

        Args:
            checkpoint_dir: directory holding the generator and discriminator
                checkpoints and 'optimizer.pt'.
            hyperparameters: mapping forwarded to ``get_scheduler``.

        Returns:
            The iteration number parsed from the generator checkpoint file name.
        """
        # Generators
        gen_path = get_model_list(checkpoint_dir, "gen")
        gen_weights = torch.load(gen_path)
        for net, key in ((self.gen_a, 'a'), (self.gen_b, 'b')):
            net.load_state_dict(gen_weights[key])
        # save() writes 'gen_%08d.pt': the 8 characters before the 3-character
        # suffix are the iteration count.
        iterations = int(gen_path[-11:-3])

        # Discriminators
        dis_path = get_model_list(checkpoint_dir, "dis")
        dis_weights = torch.load(dis_path)
        for net, key in ((self.dis_a, 'a'), (self.dis_b, 'b')):
            net.load_state_dict(dis_weights[key])

        # Optimizers
        opt_states = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(opt_states['dis'])
        self.gen_opt.load_state_dict(opt_states['gen'])

        # Rebuild the schedulers so they continue from the restored iteration.
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
Ejemplo n.º 4
0
class UNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        """Build the two-domain trainer: networks, optimizers and schedulers.

        Args:
            hyperparameters: mapping read for ``lr``, ``input_dim_a``,
                ``input_dim_b``, ``gen``, ``dis``, ``origin``, ``beta1``,
                ``beta2``, ``weight_decay``, ``init`` and, optionally,
                ``vgg_w`` / ``vgg_model_path``. It is also handed to
                ``get_scheduler``.
        """
        super(UNIT_Trainer, self).__init__()
        hp = hyperparameters
        lr = hp['lr']
        dim_a = hp['input_dim_a']
        dim_b = hp['input_dim_b']

        # One auto-encoder (generator) per domain.
        self.gen_a = VAEGen(dim_a, hp['gen'])
        self.gen_b = VAEGen(dim_b, hp['gen'])

        # One discriminator per domain; the 'origin' flag selects which kind.
        if hp['origin']:
            self.dis_a = MsImageDis(dim_a, hp['dis'])
            self.dis_b = MsImageDis(dim_b, hp['dis'])
        else:
            multiscale_kwargs = dict(ndf=64, n_layers=3,
                                     norm_layer=nn.InstanceNorm2d,
                                     use_sigmoid=False, num_D=2,
                                     getIntermFeat=True)
            self.dis_a = MultiscaleDiscriminator(dim_a, **multiscale_kwargs)
            self.dis_b = MultiscaleDiscriminator(dim_b, **multiscale_kwargs)
            # Only this branch defines criterionGAN; the update methods read
            # it only when 'origin' is falsy.
            self.criterionGAN = GANLoss(use_lsgan=True, tensor=torch.cuda.FloatTensor)

        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

        # Optimizers: one Adam over both discriminators, one over both generators.
        adam_kwargs = dict(lr=lr, betas=(hp['beta1'], hp['beta2']),
                           weight_decay=hp['weight_decay'])
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad], **adam_kwargs)
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad], **adam_kwargs)
        self.dis_scheduler = get_scheduler(self.dis_opt, hp)
        self.gen_scheduler = get_scheduler(self.gen_opt, hp)

        # Weight initialisation: the configured scheme for every sub-module,
        # then 'gaussian' re-applied to the discriminators.
        self.apply(weights_init(hp['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # A frozen VGG network is loaded only when the perceptual loss is on.
        if 'vgg_w' in hp and hp['vgg_w'] > 0:
            self.vgg = load_vgg16(hp['vgg_model_path'] + '/models')
            self.vgg.eval()
            for vgg_param in self.vgg.parameters():
                vgg_param.requires_grad = False

    def compute_digits_differnce(self, digits1, digits2, weight=1.0):
        """Weighted L1 distance between two nested lists of discriminator outputs.

        For each of the first two outer entries, every inner entry except the
        last one of ``digits2`` is compared with the entry at the same position
        in ``digits1``, which is detached so no gradient flows into it.

        Args:
            digits1: reference outputs, indexed as ``digits1[i][j]``.
            digits2: outputs compared against the reference, same layout.
            weight: extra multiplier applied to every term.

        Returns:
            The accumulated distance, or ``0`` if no term was added.
        """
        layer_scale = 4.0 / (3 + 1)  # three-layer discriminators
        dis_scale = 1.0 / 2.0  # two discriminators
        total = 0
        for dis_idx in (0, 1):
            n_compared = len(digits2[dis_idx]) - 1
            for layer in range(n_compared):
                distance = F.l1_loss(digits2[dis_idx][layer],
                                     digits1[dis_idx][layer].detach())
                total = total + dis_scale * layer_scale * distance * weight
        return total



    def compute_gan_loss(self, real_digits, fake_digits, gan_cri,
            loss_at='None'):
        """Compute the adversarial loss for the discriminator or the generator.

        Args:
            real_digits: discriminator output for real samples.
            fake_digits: discriminator output for generated samples.
            gan_cri: callable ``gan_cri(digits, target_is_real)``; when it is
                ``None`` nothing is computed.
            loss_at: ``'D'`` for the discriminator loss, ``'G'`` for the
                generator loss plus the feature term; any other value
                computes nothing.

        Returns:
            A tuple ``(errD, errG, errG_feat)``; entries that were not
            computed are ``None``.
        """
        if gan_cri is None:
            return None, None, None

        if loss_at == 'D':
            # Average of the real and fake halves of the discriminator loss.
            dis_loss = (gan_cri(real_digits, True)
                        + gan_cri(fake_digits, False)) * 0.5
            return dis_loss, None, None

        if loss_at == 'G':
            gen_loss = gan_cri(fake_digits, True)
            feat_loss = self.compute_digits_differnce(
                real_digits, fake_digits, weight=10.0)
            return None, gen_loss, feat_loss

        return None, None, None

    def recon_criterion(self, input, target):
        """Return the mean absolute difference between ``input`` and ``target``."""
        residual = input - target
        return residual.abs().mean()

    def forward(self, x_a, x_b):
        """Translate ``x_a`` to domain b and ``x_b`` to domain a.

        The trainer is put in evaluation mode for the pass and switched to
        training mode afterwards, whatever mode it was in before the call.

        Returns:
            ``(x_ab, x_ba)``: ``x_a`` decoded by generator b and ``x_b``
            decoded by generator a.
        """
        self.eval()
        code_a, _unused_a = self.gen_a.encode(x_a)
        code_b, _unused_b = self.gen_b.encode(x_b)
        # Cross-domain decoding: each code goes through the other generator.
        translated_ba, _z_ba = self.gen_a.decode(code_b)
        translated_ab, _z_ab = self.gen_b.decode(code_a)
        self.train()
        return translated_ab, translated_ba

    def __compute_kl(self, mu):
        # def _compute_kl(self, mu, sd):
        # mu_2 = torch.pow(mu, 2)
        # sd_2 = torch.pow(sd, 2)
        # encoding_loss = (mu_2 + sd_2 - torch.log(sd_2)).sum() / mu_2.size(0)
        # return encoding_loss
        mu_2 = torch.pow(mu, 2)
        encoding_loss = torch.mean(mu_2)
        return encoding_loss

    def gen_update(self, x_a, x_b, hyperparameters):
        """Run one generator optimisation step on a batch from each domain.

        Computes within-domain reconstruction, cross-domain translation,
        cycle reconstruction, latent penalties, adversarial and
        feature-matching terms, the optional VGG terms and the "eg" term.
        Every term is stored on ``self`` (``loss_gen_*``, ``loss_gan_feat_*``,
        ``loss_eg``), then the weighted sum is back-propagated and
        ``self.gen_opt`` is stepped.

        Args:
            x_a: batch from domain a.
            x_b: batch from domain b.
            hyperparameters: mapping read for ``zero_z``, ``batch_size``,
                ``gen['z_num']``, ``loss_eg_weight``, ``origin``, ``gan_w``,
                ``recon_x_w``, ``recon_kl_w``, ``recon_x_cyc_w``,
                ``recon_kl_cyc_w`` and ``vgg_w``.
        """
        self.gen_opt.zero_grad()

        # Encode each batch with its own generator; the two returned values
        # are summed before every decode below.
        h_a, n_a = self.gen_a.encode(x_a)
        h_b, n_b = self.gen_b.encode(x_b)

        # Optional all-zero vector handed to decode() as z_var.
        # NOTE(review): sized from hyperparameters['batch_size'], not from the
        # actual batch -- confirm that batches always have that size.
        if hyperparameters['zero_z']:
            pre_Z = Variable(torch.zeros(hyperparameters['batch_size'],
                    hyperparameters['gen']['z_num']).cuda())
        else:
            pre_Z = None

        # Within-domain reconstruction.
        x_a_recon, _ = self.gen_a.decode(h_a + n_a, z_var=pre_Z)
        x_b_recon, _ = self.gen_b.decode(h_b + n_b, z_var=pre_Z)

        # Cross-domain translation without an explicit z_var; the second
        # return value of decode() is kept for the eg loss.
        x_ba, z_var_ba_1 = self.gen_a.decode(h_b + n_b)
        x_ab, z_var_ab_1 = self.gen_b.decode(h_a + n_a)

        # Cross-domain translation with pre_Z; these feed the cycle path.
        x_ba_zero, _ = self.gen_a.decode(h_b + n_b, z_var=pre_Z)
        x_ab_zero, _ = self.gen_b.decode(h_a + n_a, z_var=pre_Z)

        # Second cross-domain translation, detached, used only by the eg loss.
        if hyperparameters['loss_eg_weight'] != 0:
            x_ba_eg, z_var_ba_2 = self.gen_a.decode(h_b + n_b)
            x_ab_eg, z_var_ab_2 = self.gen_b.decode(h_a + n_a)
            x_ba_eg = x_ba_eg.detach()
            x_ab_eg = x_ab_eg.detach()
            if not hyperparameters['origin']:
                x_ba_eg_digits = self.dis_a(x_ba_eg)
                x_ab_eg_digits = self.dis_b(x_ab_eg)

        # Encode the translations again ...
        h_b_recon, n_b_recon = self.gen_a.encode(x_ba_zero)
        h_a_recon, n_a_recon = self.gen_b.encode(x_ab_zero)
        # ... and decode back to the source domain only if the cycle loss is on.
        x_aba, _ = self.gen_a.decode(h_a_recon + n_a_recon,
                z_var=pre_Z
                ) if hyperparameters['recon_x_cyc_w'] > 0 else (None, 0)
        x_bab, _ = self.gen_b.decode(h_b_recon + n_b_recon,
                z_var=pre_Z
                ) if hyperparameters['recon_x_cyc_w'] > 0 else (None, 0)

        # Reconstruction, cycle and latent penalties.
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_kl_a = self.__compute_kl(h_a)
        self.loss_gen_recon_kl_b = self.__compute_kl(h_b)
        self.loss_gen_cyc_x_a = self.recon_criterion(x_aba, x_a) if x_aba is not None else 0
        self.loss_gen_cyc_x_b = self.recon_criterion(x_bab, x_b) if x_bab is not None else 0
        self.loss_gen_recon_kl_cyc_aba = self.__compute_kl(h_a_recon)
        self.loss_gen_recon_kl_cyc_bab = self.__compute_kl(h_b_recon)

        # Adversarial and feature-matching terms.
        # NOTE(review): these are zeroed whenever loss_eg_weight == 0, so the
        # GAN loss is tied to the eg weight -- confirm this is intended.
        if hyperparameters['loss_eg_weight'] == 0:
            self.loss_gen_adv_a, self.loss_gen_adv_b, self.loss_gan_feat_a, \
                    self.loss_gan_feat_b = 0, 0, 0, 0
        elif not hyperparameters['origin']:
            x_ba_digits = self.dis_a(x_ba)
            x_a_digits = self.dis_a(x_a)
            _, self.loss_gen_adv_a, self.loss_gan_feat_a = \
                    self.compute_gan_loss(x_a_digits, x_ba_digits,
                            self.criterionGAN, loss_at='G')

            # Domain-b images are scored by the domain-b discriminator, as in
            # dis_update and in the x_ab_eg_digits computation above.
            x_ab_digits = self.dis_b(x_ab)
            x_b_digits = self.dis_b(x_b)
            _, self.loss_gen_adv_b, self.loss_gan_feat_b = \
                    self.compute_gan_loss(x_b_digits, x_ab_digits,
                            self.criterionGAN, loss_at='G')
        else:
            self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
            self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
            self.loss_gan_feat_a, self.loss_gan_feat_b = 0, 0

        # Perceptual loss between each translation and its source image.
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0

        # eg loss: computed on discriminator outputs for the multiscale
        # discriminator, on the images themselves otherwise.
        if hyperparameters['loss_eg_weight'] == 0:
            self.loss_eg = 0.0
        elif not hyperparameters['origin']:
            self.loss_eg = compute_eg_loss(x_ba_digits, x_ba_eg_digits,
                    x_ab_digits, x_ab_eg_digits, z_var_ba_1, z_var_ba_2,
                     z_var_ab_1, z_var_ab_2, hyperparameters)
        else:
            self.loss_eg = compute_eg_loss(x_ba, x_ba_eg,
                    x_ab, x_ab_eg, z_var_ba_1, z_var_ba_2, z_var_ab_1, z_var_ab_2,
                    hyperparameters)

        # Weighted total; the feature-matching terms enter unweighted here
        # (compute_gan_loss already applies weight=10.0 to them).
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_a + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_b + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              self.loss_gan_feat_b + self.loss_gan_feat_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b + \
                              hyperparameters['loss_eg_weight'] * self.loss_eg
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        """Mean squared difference of instance-normalised VGG features.

        Args:
            vgg: feature extractor applied to the preprocessed images.
            img: image batch to compare.
            target: image batch compared against.

        Returns:
            The mean of the squared difference between the two normalised
            feature maps.
        """
        prepared_img = vgg_preprocess(img)
        prepared_target = vgg_preprocess(target)
        feat_img = vgg(prepared_img)
        feat_target = vgg(prepared_target)
        # Features are instance-normalised before they are compared.
        normed_img = self.instancenorm(feat_img)
        normed_target = self.instancenorm(feat_target)
        squared_gap = (normed_img - normed_target) ** 2
        return torch.mean(squared_gap)

    def sample(self, x_a, x_b):
        """Reconstruct and translate every image pair, one pair at a time.

        The trainer is put in evaluation mode for the pass and switched to
        training mode afterwards. The loop length is taken from ``x_a``.

        Returns:
            ``(x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba)`` where the
            generated entries are concatenated along dimension 0.
        """
        self.eval()
        recon_a, recon_b, trans_ba, trans_ab = [], [], [], []
        for idx in range(x_a.size(0)):
            # Encode a single image from each domain (batch dimension restored).
            code_a, _ = self.gen_a.encode(x_a[idx].unsqueeze(0))
            code_b, _ = self.gen_b.encode(x_b[idx].unsqueeze(0))
            # decode() returns a pair; only the image is kept.
            recon_a.append(self.gen_a.decode(code_a)[0])
            recon_b.append(self.gen_b.decode(code_b)[0])
            trans_ba.append(self.gen_a.decode(code_b)[0])
            trans_ab.append(self.gen_b.decode(code_a)[0])
        x_a_recon = torch.cat(recon_a)
        x_b_recon = torch.cat(recon_b)
        x_ba = torch.cat(trans_ba)
        x_ab = torch.cat(trans_ab)
        self.train()
        return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba

    def dis_update(self, x_a, x_b, hyperparameters):
        """Run one discriminator optimisation step.

        Generates cross-domain translations, scores real and (detached) fake
        images with each domain's discriminator, stores ``loss_dis_a``,
        ``loss_dis_b`` and ``loss_dis_total`` on ``self``, back-propagates
        the total and steps ``self.dis_opt``.

        Args:
            x_a: batch from domain a.
            x_b: batch from domain b.
            hyperparameters: mapping read for ``origin`` and ``gan_w``.
        """
        self.dis_opt.zero_grad()

        # Encode, then decode through the other domain's generator.
        h_a, n_a = self.gen_a.encode(x_a)
        h_b, n_b = self.gen_b.encode(x_b)
        x_ba, _ = self.gen_a.decode(h_b + n_b)
        x_ab, _ = self.gen_b.decode(h_a + n_a)

        gan_w = hyperparameters['gan_w']
        if hyperparameters['origin']:
            # The discriminator computes its own loss; fakes are detached so
            # the generators receive no gradient.
            self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
            self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        else:
            real_digits_a = self.dis_a(x_a)
            fake_digits_a = self.dis_a(x_ba.detach())
            real_digits_b = self.dis_b(x_b)
            fake_digits_b = self.dis_b(x_ab.detach())

            self.loss_dis_a = self.compute_gan_loss(
                real_digits_a, fake_digits_a, self.criterionGAN, loss_at='D')[0]
            self.loss_dis_b = self.compute_gan_loss(
                real_digits_b, fake_digits_b, self.criterionGAN, loss_at='D')[0]

        self.loss_dis_total = gan_w * self.loss_dis_a + gan_w * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Step the discriminator scheduler, then the generator scheduler.

        A scheduler that is ``None`` is left alone.
        """
        dis_sched = self.dis_scheduler
        if dis_sched is not None:
            dis_sched.step()

        gen_sched = self.gen_scheduler
        if gen_sched is not None:
            gen_sched.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Reload the trainer state saved in ``checkpoint_dir``.

        Args:
            checkpoint_dir: directory handed to ``get_model_list`` to find the
                generator and discriminator checkpoints; it must also hold
                ``optimizer.pt``.
            hyperparameters: configuration forwarded to ``get_scheduler``.

        Returns:
            The iteration count parsed from the generator checkpoint name.
        """
        # --- generators ('a' / 'b' state dicts in one file) ---
        gen_file = get_model_list(checkpoint_dir, "gen")
        gen_states = torch.load(gen_file)
        self.gen_a.load_state_dict(gen_states['a'])
        self.gen_b.load_state_dict(gen_states['b'])
        # Characters [-11:-3] of the path are the eight-digit counter that
        # save() embeds in 'gen_%08d.pt'.
        iterations = int(gen_file[-11:-3])

        # --- discriminators ---
        dis_file = get_model_list(checkpoint_dir, "dis")
        dis_states = torch.load(dis_file)
        self.dis_a.load_state_dict(dis_states['a'])
        self.dis_b.load_state_dict(dis_states['b'])

        # --- optimizers ---
        opt_states = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(opt_states['dis'])
        self.gen_opt.load_state_dict(opt_states['gen'])

        # --- schedulers, rebuilt from the restored iteration ---
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)

        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Save generators, discriminators and optimizers to ``snapshot_dir``.

        The network files carry ``iterations + 1`` as an eight-digit counter;
        the optimizer state always goes to ``optimizer.pt``.

        Args:
            snapshot_dir: output directory.
            iterations: zero-based index of the iteration just completed.
        """
        counter = iterations + 1
        gen_path = os.path.join(snapshot_dir, 'gen_%08d.pt' % counter)
        dis_path = os.path.join(snapshot_dir, 'dis_%08d.pt' % counter)
        opt_path = os.path.join(snapshot_dir, 'optimizer.pt')

        torch.save(
            {'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()},
            gen_path)
        torch.save(
            {'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()},
            dis_path)
        torch.save(
            {'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()},
            opt_path)
Ejemplo n.º 5
0
class UNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters, resume_epoch=-1, snapshot_dir=None):
        """Build the single-generator trainer with its supervised head.

        Args:
            hyperparameters: mapping read for ``lr``, ``input_dim``,
                ``n_datasets``, ``gen``, ``dis``, ``beta1``, ``beta2``,
                ``weight_decay``, ``init`` and ``batch_size``. It is also
                handed to ``get_scheduler``.
            resume_epoch: any value other than ``-1`` triggers
                ``self.resume(snapshot_dir, hyperparameters)`` at the end.
            snapshot_dir: checkpoint directory used when resuming.
        """
        super(UNIT_Trainer, self).__init__()

        hp = hyperparameters
        lr = hp['lr']
        n_datasets = hp['n_datasets']

        # Generator and discriminator both take the image channels plus one
        # extra channel per dataset (the one-hot maps built below).
        self.gen = VAEGen(hp['input_dim'] + n_datasets, hp['gen'], n_datasets)
        self.dis = MsImageDis(hp['input_dim'] + n_datasets, hp['dis'])

        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

        # Supervised two-class network, moved to the GPU right away.
        self.sup = UNet(input_channels=hp['input_dim'], num_classes=2).cuda()

        # Optimizers: the generator optimizer also owns the supervised
        # network's parameters.
        adam_kwargs = dict(lr=lr, betas=(hp['beta1'], hp['beta2']),
                           weight_decay=hp['weight_decay'])
        dis_params = list(self.dis.parameters())
        gen_params = list(self.gen.parameters()) + list(self.sup.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad], **adam_kwargs)
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad], **adam_kwargs)
        self.dis_scheduler = get_scheduler(self.dis_opt, hp)
        self.gen_scheduler = get_scheduler(self.gen_opt, hp)

        # Weight initialisation: configured scheme everywhere, then
        # 'gaussian' re-applied to the discriminator.
        self.apply(weights_init(hp['init']))
        self.dis.apply(weights_init('gaussian'))

        # Pre-built one-hot maps, indexed by dataset first: entry d has
        # channel d filled with ones. Spatial sizes are fixed at 256x256 for
        # images and 64x64 for hidden features.
        batch_size = hp['batch_size']
        self.one_hot_img = torch.zeros(
            n_datasets, batch_size, n_datasets, 256, 256).cuda()
        self.one_hot_h = torch.zeros(
            n_datasets, batch_size, n_datasets, 64, 64).cuda()
        for dataset in range(n_datasets):
            self.one_hot_img[dataset, :, dataset, :, :].fill_(1)
            self.one_hot_h[dataset, :, dataset, :, :].fill_(1)

        if resume_epoch != -1:
            self.resume(snapshot_dir, hp)

    def recon_criterion(self, input, target):
        """Return the mean absolute error between ``input`` and ``target``."""
        abs_error = torch.abs(input - target)
        return abs_error.mean()

    def semi_criterion(self, input, target):
        """Evaluate a freshly built ``CrossEntropyLoss2d`` on a prediction.

        The criterion is created with ``size_average=False`` and moved to the
        GPU on every call.
        """
        criterion = CrossEntropyLoss2d(size_average=False)
        criterion = criterion.cuda()
        return criterion(input, target)

    def forward(self, x_a, x_b):
        """Encode both inputs and decode each one with the other generator.

        The trainer is put in evaluation mode for the pass and switched to
        training mode afterwards.

        NOTE(review): this method reads ``self.gen_a`` / ``self.gen_b``, but
        the ``__init__`` of this class only assigns ``self.gen``. Unless those
        attributes are set elsewhere, the first ``encode`` call raises
        ``AttributeError``. It also omits the one-hot dataset channels that
        the update methods of this class concatenate before calling the
        generator. Looks like a leftover from a two-generator trainer --
        confirm whether it is still used before relying on it.

        Returns:
            ``(x_ab, x_ba)``.
        """

        self.eval()
        # presumably the legacy flag for inference without autograd tracking;
        # TODO confirm it has any effect on the PyTorch version in use.
        x_a.volatile = True
        x_b.volatile = True
        h_a, _ = self.gen_a.encode(x_a)
        h_b, _ = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(h_b)
        x_ab = self.gen_b.decode(h_a)
        self.train()
        return x_ab, x_ba

    def __compute_kl(self, mu):

        # def _compute_kl(self, mu, sd):
        # mu_2 = torch.pow(mu, 2)
        # sd_2 = torch.pow(sd, 2)
        # encoding_loss = (mu_2 + sd_2 - torch.log(sd_2)).sum() / mu_2.size(0)
        # return encoding_loss

        mu_2 = torch.pow(mu, 2)
        encoding_loss = torch.mean(mu_2)
        return encoding_loss

    def set_gen_trainable(self, train_bool):
        """Switch the generator between trainable and frozen.

        Args:
            train_bool: if truthy, put ``self.gen`` in training mode and
                enable gradients on its parameters; otherwise put it in
                evaluation mode and disable them.
        """
        if train_bool:
            self.gen.train()
        else:
            self.gen.eval()

        # requires_grad must follow train_bool: leaving it True in the frozen
        # case would let gradients keep reaching the generator.
        trainable = bool(train_bool)
        for param in self.gen.parameters():
            param.requires_grad = trainable

    def set_sup_trainable(self, train_bool):
        """Switch the supervised network between trainable and frozen.

        Args:
            train_bool: if truthy, put ``self.sup`` in training mode and
                enable gradients on its parameters; otherwise put it in
                evaluation mode and disable them.
        """
        if train_bool:
            self.sup.train()
        else:
            self.sup.eval()

        # requires_grad must follow train_bool: leaving it True in the frozen
        # case would let gradients keep reaching the supervised network.
        trainable = bool(train_bool)
        for param in self.sup.parameters():
            param.requires_grad = trainable

    def sup_update(self, x_a, x_b, y_a, y_b, d_index_a, d_index_b, use_a,
                   use_b, hyperparameters):
        """Run one supervised step on the labelled samples of two batches.

        Both batches are encoded, translated to the other dataset and
        re-encoded. The supervised network is then applied to the original
        and to the re-encoded hidden features of each batch that has labelled
        samples. The summed loss is stored in ``self.loss_gen_total`` (or
        ``None`` when neither batch has labels), back-propagated, and
        ``self.gen_opt`` is stepped.

        Args:
            x_a, x_b: image batches.
            y_a, y_b: label maps, indexed with ``use_a`` / ``use_b``.
            d_index_a, d_index_b: dataset indices selecting the one-hot maps.
            use_a, use_b: index of the labelled samples inside each batch.
            hyperparameters: mapping read for ``recon_x_cyc_w``.
        """
        self.gen_opt.zero_grad()

        # Encode (image + one-hot dataset channels).
        one_hot_x_a = torch.cat([x_a, self.one_hot_img[d_index_a]], 1)
        one_hot_x_b = torch.cat([x_b, self.one_hot_img[d_index_b]], 1)
        h_a, n_a = self.gen.encode(one_hot_x_a)
        h_b, n_b = self.gen.encode(one_hot_x_b)

        # Decode (within domain).
        # NOTE(review): x_a_recon / x_b_recon and x_aba / x_bab below are not
        # used by the loss of this method; they are kept in case decode() has
        # side effects -- confirm and drop if it has none.
        one_hot_h_a = torch.cat([h_a + n_a, self.one_hot_h[d_index_a]], 1)
        one_hot_h_b = torch.cat([h_b + n_b, self.one_hot_h[d_index_b]], 1)
        x_a_recon = self.gen.decode(one_hot_h_a)
        x_b_recon = self.gen.decode(one_hot_h_b)

        # Decode (cross domain).
        one_hot_h_ab = torch.cat([h_a + n_a, self.one_hot_h[d_index_b]], 1)
        one_hot_h_ba = torch.cat([h_b + n_b, self.one_hot_h[d_index_a]], 1)
        x_ba = self.gen.decode(one_hot_h_ba)
        x_ab = self.gen.decode(one_hot_h_ab)

        # Encode again: h_b_recon comes from x_ba, h_a_recon from x_ab.
        one_hot_x_ba = torch.cat([x_ba, self.one_hot_img[d_index_a]], 1)
        one_hot_x_ab = torch.cat([x_ab, self.one_hot_img[d_index_b]], 1)
        h_b_recon, n_b_recon = self.gen.encode(one_hot_x_ba)
        h_a_recon, n_a_recon = self.gen.encode(one_hot_x_ab)

        # Decode again (if needed).
        one_hot_h_a_recon = torch.cat(
            [h_a_recon + n_a_recon, self.one_hot_h[d_index_a]], 1)
        one_hot_h_b_recon = torch.cat(
            [h_b_recon + n_b_recon, self.one_hot_h[d_index_b]], 1)
        x_aba = self.gen.decode(
            one_hot_h_a_recon
        ) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen.decode(
            one_hot_h_b_recon
        ) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # Supervised losses, computed only for batches with labelled samples.
        loss_semi_a = None
        loss_semi_b = None

        has_a_label = (h_a[use_a, :, :, :].size(0) != 0)
        if has_a_label:
            p_a = self.sup(h_a, use_a, True)
            p_a_recon = self.sup(h_a_recon, use_a, True)
            loss_semi_a = self.semi_criterion(p_a, y_a[use_a, :, :]) + \
                          self.semi_criterion(p_a_recon, y_a[use_a, :, :])

        has_b_label = (h_b[use_b, :, :, :].size(0) != 0)
        if has_b_label:
            p_b = self.sup(h_b, use_b, True)
            # The second prediction uses the re-encoded features, mirroring
            # the domain-a branch (h_a / h_a_recon).
            p_b_recon = self.sup(h_b_recon, use_b, True)
            loss_semi_b = self.semi_criterion(p_b, y_b[use_b, :, :]) + \
                          self.semi_criterion(p_b_recon, y_b[use_b, :, :])

        # Sum whichever losses exist.
        self.loss_gen_total = None
        if loss_semi_a is not None and loss_semi_b is not None:
            self.loss_gen_total = loss_semi_a + loss_semi_b
        elif loss_semi_a is not None:
            self.loss_gen_total = loss_semi_a
        elif loss_semi_b is not None:
            self.loss_gen_total = loss_semi_b

        if self.loss_gen_total is not None:
            self.loss_gen_total.backward()
            self.gen_opt.step()

    def sup_forward(self, x, y, d_index, hyperparameters):
        """Predict a label map for one image and score it against ``y``.

        The supervised network is switched to evaluation mode and is not
        switched back by this method.

        Args:
            x: input image; the one-hot map added to it has a batch dimension
                of one.
            y: reference labels passed to ``jaccard`` after moving to the CPU.
            d_index: dataset index selecting the one-hot map.
            hyperparameters: unused here.

        Returns:
            ``(jacc, pred, hidden)``: the ``jaccard`` score, the predicted
            labels as a NumPy array and the hidden features from the encoder.
        """
        self.sup.eval()

        # Encode the image together with the one-hot map of its dataset
        # (first batch entry of that map, with the batch dimension restored).
        dataset_map = self.one_hot_img[d_index, 0].unsqueeze(0)
        hidden, _ = self.gen.encode(torch.cat([x, dataset_map], 1))

        # Supervised prediction on the hidden features.
        y_pred = self.sup(hidden, only_prediction=True)

        # Arg-max over dimension 1, singleton dimensions dropped in place.
        labels = y_pred.data.max(1)[1]
        labels = labels.squeeze_(1).squeeze_(0)
        pred = labels.cpu().numpy()

        reference = y.cpu().squeeze(0).numpy()
        jacc = jaccard(pred, reference)

        return jacc, pred, hidden

    def gen_update(self, x_a, x_b, d_index_a, d_index_b, hyperparameters):
        """Run one generator optimisation step on two batches.

        Computes within-domain reconstruction, cross-domain translation,
        optional cycle reconstruction, latent penalties and adversarial
        terms, stores each one on ``self`` as ``loss_gen_*``, then
        back-propagates the weighted sum and steps ``self.gen_opt``.

        Args:
            x_a, x_b: image batches.
            d_index_a, d_index_b: dataset indices selecting the one-hot maps.
            hyperparameters: mapping read for ``gan_w``, ``recon_x_w``,
                ``recon_kl_w``, ``recon_x_cyc_w`` and ``recon_kl_cyc_w``.
        """
        self.gen_opt.zero_grad()

        # Encode (image + one-hot dataset channels).
        one_hot_x_a = torch.cat([x_a, self.one_hot_img[d_index_a]], 1)
        one_hot_x_b = torch.cat([x_b, self.one_hot_img[d_index_b]], 1)
        h_a, n_a = self.gen.encode(one_hot_x_a)
        h_b, n_b = self.gen.encode(one_hot_x_b)

        # Decode (within domain).
        one_hot_h_a = torch.cat([h_a + n_a, self.one_hot_h[d_index_a]], 1)
        one_hot_h_b = torch.cat([h_b + n_b, self.one_hot_h[d_index_b]], 1)
        x_a_recon = self.gen.decode(one_hot_h_a)
        x_b_recon = self.gen.decode(one_hot_h_b)

        # Decode (cross domain).
        one_hot_h_ab = torch.cat([h_a + n_a, self.one_hot_h[d_index_b]], 1)
        one_hot_h_ba = torch.cat([h_b + n_b, self.one_hot_h[d_index_a]], 1)
        x_ba = self.gen.decode(one_hot_h_ba)
        x_ab = self.gen.decode(one_hot_h_ab)

        # Encode again.
        one_hot_x_ba = torch.cat([x_ba, self.one_hot_img[d_index_a]], 1)
        one_hot_x_ab = torch.cat([x_ab, self.one_hot_img[d_index_b]], 1)
        h_b_recon, n_b_recon = self.gen.encode(one_hot_x_ba)
        h_a_recon, n_a_recon = self.gen.encode(one_hot_x_ab)

        # Decode again, only when the cycle loss is enabled.
        one_hot_h_a_recon = torch.cat(
            [h_a_recon + n_a_recon, self.one_hot_h[d_index_a]], 1)
        one_hot_h_b_recon = torch.cat(
            [h_b_recon + n_b_recon, self.one_hot_h[d_index_b]], 1)
        cycle_enabled = hyperparameters['recon_x_cyc_w'] > 0
        x_aba = self.gen.decode(one_hot_h_a_recon) if cycle_enabled else None
        x_bab = self.gen.decode(one_hot_h_b_recon) if cycle_enabled else None

        # Reconstruction loss.
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_kl_a = self.__compute_kl(h_a)
        self.loss_gen_recon_kl_b = self.__compute_kl(h_b)
        # With the cycle loss disabled there is no cycle image to compare;
        # the term is 0 instead of calling recon_criterion on None.
        self.loss_gen_cyc_x_a = \
            self.recon_criterion(x_aba, x_a) if x_aba is not None else 0
        self.loss_gen_cyc_x_b = \
            self.recon_criterion(x_bab, x_b) if x_bab is not None else 0
        self.loss_gen_recon_kl_cyc_aba = self.__compute_kl(h_a_recon)
        self.loss_gen_recon_kl_cyc_bab = self.__compute_kl(h_b_recon)

        # GAN loss on the translated images with their one-hot channels.
        self.loss_gen_adv_a = self.dis.calc_gen_loss(one_hot_x_ba)
        self.loss_gen_adv_b = self.dis.calc_gen_loss(one_hot_x_ab)

        # Total loss.
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_a + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_b + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab

        self.loss_gen_total.backward()
        self.gen_opt.step()

    def sample(self, x_a, x_b):
        """Reconstruct and translate every image pair, one pair at a time.

        The trainer is put in evaluation mode for the pass and switched to
        training mode afterwards. The loop length is taken from ``x_a``.

        NOTE(review): this method reads ``self.gen_a`` / ``self.gen_b``, but
        the ``__init__`` of this class only assigns ``self.gen``. Unless those
        attributes are set elsewhere, the first ``encode`` call raises
        ``AttributeError``. It also omits the one-hot dataset channels that
        the update methods of this class concatenate before calling the
        generator. Looks like a leftover from a two-generator trainer --
        confirm whether it is still used before relying on it.

        Returns:
            ``(x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba)`` where the
            generated entries are concatenated with ``torch.cat``.
        """

        self.eval()
        # presumably the legacy flag for inference without autograd tracking;
        # TODO confirm it has any effect on the PyTorch version in use.
        x_a.volatile = True
        x_b.volatile = True
        x_a_recon, x_b_recon, x_ba, x_ab = [], [], [], []
        for i in range(x_a.size(0)):
            # Encode one image from each batch (batch dimension restored).
            h_a, _ = self.gen_a.encode(x_a[i].unsqueeze(0))
            h_b, _ = self.gen_b.encode(x_b[i].unsqueeze(0))
            # Within-domain reconstructions, then cross-domain translations.
            x_a_recon.append(self.gen_a.decode(h_a))
            x_b_recon.append(self.gen_b.decode(h_b))
            x_ba.append(self.gen_a.decode(h_b))
            x_ab.append(self.gen_b.decode(h_a))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba = torch.cat(x_ba)
        x_ab = torch.cat(x_ab)
        self.train()
        return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba

    def dis_update(self, x_a, x_b, d_index_a, d_index_b, hyperparameters):
        """Run one discriminator optimisation step.

        Translates each batch to the other dataset, scores real images
        against the detached translations, stores ``loss_dis_a``,
        ``loss_dis_b`` and ``loss_dis_total`` on ``self``, back-propagates
        the total and steps ``self.dis_opt``.

        Args:
            x_a, x_b: image batches.
            d_index_a, d_index_b: dataset indices selecting the one-hot maps.
            hyperparameters: mapping read for ``gan_w``.
        """
        self.dis_opt.zero_grad()

        img_map_a = self.one_hot_img[d_index_a]
        img_map_b = self.one_hot_img[d_index_b]

        # Encode each image together with its dataset's one-hot channels.
        real_a = torch.cat([x_a, img_map_a], 1)
        real_b = torch.cat([x_b, img_map_b], 1)
        h_a, n_a = self.gen.encode(real_a)
        h_b, n_b = self.gen.encode(real_b)

        # Cross-domain decoding: the code is tagged with the other dataset.
        code_ab = torch.cat([h_a + n_a, self.one_hot_h[d_index_b]], 1)
        code_ba = torch.cat([h_b + n_b, self.one_hot_h[d_index_a]], 1)
        x_ba = self.gen.decode(code_ba)
        x_ab = self.gen.decode(code_ab)

        # Fakes get one-hot channels too and are detached so the generator
        # receives no gradient from this step.
        fake_a = torch.cat([x_ba, img_map_a], 1)
        fake_b = torch.cat([x_ab, img_map_b], 1)
        self.loss_dis_a = self.dis.calc_dis_loss(fake_a.detach(), real_a)
        self.loss_dis_b = self.dis.calc_dis_loss(fake_b.detach(), real_b)

        gan_w = hyperparameters['gan_w']
        self.loss_dis_total = gan_w * self.loss_dis_a + gan_w * self.loss_dis_b

        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Step both learning-rate schedulers, discriminator first.

        A scheduler that is ``None`` is skipped.
        """
        for attr in ('dis_scheduler', 'gen_scheduler'):
            scheduler = getattr(self, attr)
            if scheduler is not None:
                scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Reload networks, optimizers and schedulers from ``checkpoint_dir``.

        Args:
            checkpoint_dir: directory handed to ``get_model_list`` to find the
                "gen", "dis", "sup" and "opt" checkpoints.
            hyperparameters: configuration forwarded to ``get_scheduler``.

        Returns:
            The counter parsed from the generator checkpoint file name.
        """
        # Generator; characters [-11:-3] of its path are the eight-digit
        # counter that save() embeds in 'gen_%08d.pt'.
        gen_path = get_model_list(checkpoint_dir, "gen")
        self.gen.load_state_dict(torch.load(gen_path))
        epochs = int(gen_path[-11:-3])

        # Discriminator.
        dis_path = get_model_list(checkpoint_dir, "dis")
        self.dis.load_state_dict(torch.load(dis_path))

        # Supervised model.
        sup_path = get_model_list(checkpoint_dir, "sup")
        self.sup.load_state_dict(torch.load(sup_path))

        # Optimizers share one file with the keys 'dis' and 'gen'.
        opt_path = get_model_list(checkpoint_dir, "opt")
        opt_states = torch.load(opt_path)
        self.dis_opt.load_state_dict(opt_states['dis'])
        self.gen_opt.load_state_dict(opt_states['gen'])

        # Move every tensor held in the optimizer state to the GPU.
        for optimizer in (self.dis_opt, self.gen_opt):
            for slot in optimizer.state.values():
                for key, value in slot.items():
                    if isinstance(value, torch.Tensor):
                        slot[key] = value.cuda()

        # Schedulers are rebuilt so that they start from the restored counter.
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           epochs)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           epochs)

        print('Resume from iteration %d' % epochs)
        return epochs

    def save(self, snapshot_dir, epoch):
        """Save the networks and both optimizers to ``snapshot_dir``.

        Four files are written, each tagged with ``epoch`` zero-padded to
        eight digits: ``gen_``, ``dis_``, ``sup_`` and ``opt_``. The optimizer
        file holds a dict with the keys ``'dis'`` and ``'gen'``.

        Args:
            snapshot_dir: output directory.
            epoch: counter embedded in the file names.
        """
        # One state dict per network file.
        torch.save(self.gen.state_dict(),
                   os.path.join(snapshot_dir, 'gen_%08d.pt' % epoch))
        torch.save(self.dis.state_dict(),
                   os.path.join(snapshot_dir, 'dis_%08d.pt' % epoch))
        torch.save(self.sup.state_dict(),
                   os.path.join(snapshot_dir, 'sup_%08d.pt' % epoch))

        # Both optimizers go into a single file.
        optimizers = {
            'dis': self.dis_opt.state_dict(),
            'gen': self.gen_opt.state_dict(),
        }
        torch.save(optimizers,
                   os.path.join(snapshot_dir, 'opt_%08d.pt' % epoch))