Exemple #1
0
class Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        # auto-encoder for domain a
        self.trait_dim = hyperparameters['gen']['trait_dim']

        self.gen_a = VAEGen(hyperparameters['input_dim'],
                            hyperparameters['basis_encoder_dims'],
                            hyperparameters['trait_encoder_dims'],
                            hyperparameters['decoder_dims'], self.trait_dim)
        # auto-encoder for domain b
        self.gen_b = VAEGen(hyperparameters['input_dim'],
                            hyperparameters['basis_encoder_dims'],
                            hyperparameters['trait_encoder_dims'],
                            hyperparameters['decoder_dims'], self.trait_dim)
        # discriminator for domain a
        self.dis_a = Discriminator(hyperparameters['input_dim'],
                                   hyperparameters['dis_dims'], 1)
        # discriminator for domain b
        self.dis_b = Discriminator(hyperparameters['input_dim'],
                                   hyperparameters['dis_dims'], 1)

        # fix the noise used in sampling
        self.trait_a = torch.randn(8, self.trait_dim, 1, 1)
        self.trait_b = torch.randn(8, self.trait_dim, 1, 1)

        # Setup the optimizers
        dis_params = list(self.dis_a.parameters()) + \
            list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + \
            list(self.gen_b.parameters())
        for _p in gen_params:
            print(_p.data.shape)
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.gen_a.apply(weights_init('gaussian'))
        self.gen_b.apply(weights_init('gaussian'))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        trait_a = Variable(self.trait_a)
        trait_b = Variable(self.trait_b)
        basis_a, trait_a_fake = self.gen_a.encode(x_a)
        basis_b, trait_b_fake = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(basis_b, trait_a)
        x_ab = self.gen_b.decode(basis_a, trait_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        self.gen_opt.zero_grad()
        trait_a = Variable(torch.randn(x_a.size(0), self.trait_dim))
        trait_b = Variable(torch.randn(x_b.size(0), self.trait_dim))
        # encode
        basis_a, trait_a_prime = self.gen_a.encode(x_a)
        basis_b, trait_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(basis_a, trait_a_prime)
        x_b_recon = self.gen_b.decode(basis_b, trait_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(basis_b, trait_a)
        x_ab = self.gen_b.decode(basis_a, trait_b)
        # encode again
        basis_b_recon, trait_a_recon = self.gen_a.encode(x_ba)
        basis_a_recon, trait_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(
            basis_a_recon,
            trait_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(
            basis_b_recon,
            trait_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_trait_a = self.recon_criterion(
            trait_a_recon, trait_a)
        self.loss_gen_recon_trait_b = self.recon_criterion(
            trait_b_recon, trait_b)
        self.loss_gen_recon_basis_a = self.recon_criterion(
            basis_a_recon, basis_a)
        self.loss_gen_recon_basis_b = self.recon_criterion(
            basis_b_recon, basis_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)

        # total loss
        self.loss_gen_total = hyperparameters[
            'gan_w'] * self.loss_gen_adv_a + \
            hyperparameters['gan_w'] * self.loss_gen_adv_b + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_trait_w'] * self.loss_gen_recon_trait_a + \
            hyperparameters['recon_basis_w'] * self.loss_gen_recon_basis_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_trait_w'] * self.loss_gen_recon_trait_b + \
            hyperparameters['recon_basis_w'] * self.loss_gen_recon_basis_b + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b

        self.loss_gen_total.backward()
        self.gen_opt.step()

    # def sample(self, x_a, x_b):
    #     self.eval()
    #     s_a1 = Variable(self.s_a)
    #     s_b1 = Variable(self.s_b)
    #     s_a2 = Variable(torch.randn(x_a.size(0), self.trait_dim, 1, 1))
    #     s_b2 = Variable(torch.randn(x_b.size(0), self.trait_dim, 1, 1))
    #     x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
    #     for i in range(x_a.size(0)):
    #         c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
    #         c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
    #         x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
    #         x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
    #         x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
    #         x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
    #         x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
    #         x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
    #     x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
    #     x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
    #     x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
    #     self.train()
    #     return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        trait_a = Variable(torch.randn(x_a.size(0), self.trait_dim))
        trait_b = Variable(torch.randn(x_b.size(0), self.trait_dim))
        # encode
        basis_a, _ = self.gen_a.encode(x_a)
        basis_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(basis_b, trait_a)
        x_ab = self.gen_b.decode(basis_a, trait_b)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba, x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab, x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * \
            self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitilize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
class UNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(UNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = VAEGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = VAEGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b

        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b
        self.dis_content = Dis_content()
        self.gpuid = hyperparameters['gpuID']
        # @ add backgound discriminator for each domain
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])

        self.content_opt = torch.optim.Adam(
            self.dis_content.parameters(),
            lr=lr / 2.,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        self.content_scheduler = get_scheduler(self.content_opt,
                                               hyperparameters)

        # Network weight initialization
        self.gen_a.apply(weights_init(hyperparameters['init']))
        self.gen_b.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))
        self.dis_content.apply(weights_init('gaussian'))

        # initialize the blur network
        self.BGBlur_kernel = [5, 9, 15]
        self.BlurNet = [
            GaussionSmoothLayer(3, k_size, 25).cuda(self.gpuid)
            for k_size in self.BGBlur_kernel
        ]
        self.BlurWeight = [0.25, 0.5, 1.]
        self.Gradient = GradientLoss(3, 3)

        # # Load VGG model if needed for test
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg19()
            if torch.cuda.is_available():
                self.vgg.cuda(self.gpuid)
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        h_a = self.gen_a.encode_cont(x_a)
        # h_a_sty = self.gen_a.encode_sty(x_a)
        # h_b = self.gen_b.encode_cont(x_b)

        x_ab = self.gen_b.decode_cont(h_a)
        # h_c = torch.cat((h_b, h_a_sty), 1)
        # x_ba = self.gen_a.decode_recs(h_c)
        # self.train()
        return x_ab  #, x_ba

    def __compute_kl(self, mu):
        # def _compute_kl(self, mu, sd):
        mu_2 = torch.pow(mu, 2)
        encoding_loss = torch.mean(mu_2)
        return encoding_loss

    def content_update(self, x_a, x_b, hyperparameters):  #
        # encode
        self.content_opt.zero_grad()
        enc_a = self.gen_a.encode_cont(x_a)
        enc_b = self.gen_b.encode_cont(x_b)
        pred_fake = self.dis_content.forward(enc_a)
        pred_real = self.dis_content.forward(enc_b)
        loss_D = 0
        if hyperparameters['gan_type'] == 'lsgan':
            loss_D += torch.mean((pred_fake - 0)**2) + torch.mean(
                (pred_real - 1)**2)
        elif hyperparameters['gan_type'] == 'nsgan':
            all0 = Variable(torch.zeros_like(pred_fake.data).cuda(self.gpuid),
                            requires_grad=False)
            all1 = Variable(torch.ones_like(pred_real.data).cuda(self.gpuid),
                            requires_grad=False)
            loss_D += torch.mean(
                F.binary_cross_entropy(F.sigmoid(pred_fake), all0) +
                F.binary_cross_entropy(F.sigmoid(pred_real), all1))
        else:
            assert 0, "Unsupported GAN type: {}".format(
                hyperparameters['gan_type'])
        loss_D.backward()
        nn.utils.clip_grad_norm_(self.dis_content.parameters(), 5)
        self.content_opt.step()

    def gen_update(self, x_a, x_b, hyperparameters):
        self.gen_opt.zero_grad()
        self.content_opt.zero_grad()
        # encode
        h_a = self.gen_a.encode_cont(x_a)
        h_b = self.gen_b.encode_cont(x_b)
        h_a_sty = self.gen_a.encode_sty(x_a)

        # add domain adverisal loss for generator
        out_a = self.dis_content(h_a)
        out_b = self.dis_content(h_b)
        self.loss_ContentD = 0
        if hyperparameters['gan_type'] == 'lsgan':
            self.loss_ContentD += torch.mean((out_a - 0.5)**2) + torch.mean(
                (out_b - 0.5)**2)
        elif hyperparameters['gan_type'] == 'nsgan':
            all1 = Variable(0.5 * torch.ones_like(out_b.data).cuda(self.gpuid),
                            requires_grad=False)
            self.loss_ContentD += torch.mean(
                F.binary_cross_entropy(F.sigmoid(out_a), all1) +
                F.binary_cross_entropy(F.sigmoid(out_b), all1))
        else:
            assert 0, "Unsupported GAN type: {}".format(
                hyperparameters['gan_type'])

        # decode (within domain)
        h_a_cont = torch.cat((h_a, h_a_sty), 1)
        noise_a = torch.randn(h_a_cont.size()).cuda(h_a_cont.data.get_device())
        x_a_recon = self.gen_a.decode_recs(h_a_cont + noise_a)
        noise_b = torch.randn(h_b.size()).cuda(h_b.data.get_device())
        x_b_recon = self.gen_b.decode_cont(h_b + noise_b)

        # decode (cross domain)
        h_ba_cont = torch.cat((h_b, h_a_sty), 1)
        x_ba = self.gen_a.decode_recs(h_ba_cont + noise_a)
        x_ab = self.gen_b.decode_cont(h_a + noise_b)

        # encode again
        h_b_recon = self.gen_a.encode_cont(x_ba)
        h_b_sty_recon = self.gen_a.encode_sty(x_ba)

        h_a_recon = self.gen_b.encode_cont(x_ab)

        # decode again (if needed)
        h_a_cat_recs = torch.cat((h_a_recon, h_b_sty_recon), 1)

        x_aba = self.gen_a.decode_recs(
            h_a_cat_recs +
            noise_a) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode_cont(
            h_b_recon +
            noise_b) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_kl_a = self.__compute_kl(h_a)
        self.loss_gen_recon_kl_b = self.__compute_kl(h_b)
        self.loss_gen_recon_kl_sty = self.__compute_kl(h_a_sty)

        self.loss_gen_cyc_x_a = self.recon_criterion(
            x_aba, x_a) if x_aba is not None else 0
        self.loss_gen_cyc_x_b = self.recon_criterion(
            x_bab, x_b) if x_aba is not None else 0
        self.loss_gen_recon_kl_cyc_aba = self.__compute_kl(h_a_recon)
        self.loss_gen_recon_kl_cyc_bab = self.__compute_kl(h_b_recon)
        self.loss_gen_recon_kl_cyc_sty = self.__compute_kl(h_b_sty_recon)

        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)

        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(
            self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(
            self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0

        # add background guide loss
        self.loss_bgm = 0
        if hyperparameters['BGM'] != 0:
            for index, weight in enumerate(self.BlurWeight):
                out_b = self.BlurNet[index](x_ba)
                out_real_b = self.BlurNet[index](x_b)
                out_a = self.BlurNet[index](x_ab)
                out_real_a = self.BlurNet[index](x_a)
                grad_loss_b = self.recon_criterion(out_b, out_real_b)
                grad_loss_a = self.recon_criterion(out_a, out_real_a)
                self.loss_bgm += weight * (grad_loss_a + grad_loss_b)
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_sty + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_a + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_b + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_sty + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b + \
                              hyperparameters['BGM'] * self.loss_bgm + \
                              hyperparameters['gan_w'] * self.loss_ContentD
        self.loss_gen_total.backward()
        self.gen_opt.step()
        self.content_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def sample(self, x_a, x_b):
        if x_a is None or x_b is None:
            return None
        self.eval()
        x_a_recon, x_b_recon, x_ba, x_ab = [], [], [], []
        for i in range(x_a.size(0)):
            h_a = self.gen_a.encode_cont(x_a[i].unsqueeze(0))
            h_a_sty = self.gen_a.encode_sty(x_a[i].unsqueeze(0))
            h_b = self.gen_b.encode_cont(x_b[i].unsqueeze(0))

            h_ba_cont = torch.cat((h_b, h_a_sty), 1)

            h_aa_cont = torch.cat((h_a, h_a_sty), 1)

            x_a_recon.append(self.gen_a.decode_recs(h_aa_cont))
            x_b_recon.append(self.gen_b.decode_cont(h_b))

            x_ba.append(self.gen_a.decode_recs(h_ba_cont))
            x_ab.append(self.gen_b.decode_cont(h_a))

        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba = torch.cat(x_ba)
        x_ab = torch.cat(x_ab)
        self.train()
        return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        self.content_opt.zero_grad()
        # encode
        h_a = self.gen_a.encode_cont(x_a)
        h_a_sty = self.gen_a.encode_sty(x_a)
        h_b = self.gen_b.encode_cont(x_b)

        # # @ add content adversial
        out_a = self.dis_content(h_a)
        out_b = self.dis_content(h_b)
        self.loss_ContentD = 0
        if hyperparameters['gan_type'] == 'lsgan':
            self.loss_ContentD += torch.mean((out_a - 0)**2) + torch.mean(
                (out_b - 1)**2)
        elif hyperparameters['gan_type'] == 'nsgan':
            all0 = Variable(torch.zeros_like(out_a.data).cuda(self.gpuid),
                            requires_grad=False)
            all1 = Variable(torch.ones_like(out_b.data).cuda(self.gpuid),
                            requires_grad=False)
            self.loss_ContentD += torch.mean(
                F.binary_cross_entropy(F.sigmoid(out_a), all0) +
                F.binary_cross_entropy(F.sigmoid(out_b), all1))
        else:
            assert 0, "Unsupported GAN type: {}".format(
                hyperparameters['gan_type'])

        # decode (cross domain)
        h_cat = torch.cat((h_b, h_a_sty), 1)
        noise_b = torch.randn(h_cat.size()).cuda(h_cat.data.get_device())
        x_ba = self.gen_a.decode_recs(h_cat + noise_b)
        noise_a = torch.randn(h_a.size()).cuda(h_a.data.get_device())
        x_ab = self.gen_b.decode_cont(h_a + noise_a)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)

        self.loss_dis_total = hyperparameters['gan_w'] * (
            self.loss_dis_a + self.loss_dis_b + self.loss_ContentD)
        self.loss_dis_total.backward()
        nn.utils.clip_grad_norm_(self.dis_content.parameters(),
                                 5)  # dis_content update
        self.dis_opt.step()
        self.content_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()
        if self.content_scheduler is not None:
            self.content_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis_00188000")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])

        # load discontent discriminator
        last_model_name = get_model_list(checkpoint_dir, "dis_Content")
        state_dict = torch.load(last_model_name)
        self.dis_content.load_state_dict(state_dict['dis_c'])

        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        self.content_opt.load_state_dict(state_dict['dis_content'])
        # Reinitilize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        self.content_scheduler = get_scheduler(self.content_opt,
                                               hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        dis_Con_name = os.path.join(snapshot_dir,
                                    'dis_Content_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save({'dis_c': self.dis_content.state_dict()}, dis_Con_name)

        #  opt state
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict(), \
                                                    'dis_content':self.content_opt.state_dict()}, opt_name)