Example #1
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initialize the networks
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # fix the noise used in sampling
        self.s_a = torch.randn(8, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(8, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        # `volatile` was removed in PyTorch 0.4+; run inference under no_grad instead
        with torch.no_grad():
            s_a = self.s_a
            s_b = self.s_b
            c_a, s_a_fake = self.gen_a.encode(x_a)
            c_b, s_b_fake = self.gen_b.encode(x_b)
            x_ba = self.gen_a.decode(c_b, s_a)
            x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def sample(self, x_a, x_b):
        self.eval()
        # `volatile` was removed in PyTorch 0.4+; sample under no_grad instead
        with torch.no_grad():
            s_a1, s_b1 = self.s_a, self.s_b
            s_a2 = torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()
            s_b2 = torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()
            x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
            for i in range(x_a.size(0)):
                c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
                c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
                x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
                x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
                x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
                x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
                x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
                x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
            x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
            x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
            x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
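A minimal training-loop sketch for the MUNIT_Trainer above. The method calls and hyperparameter reads mirror the class as listed; the get_config helper, the loader_a/loader_b image loaders, the YAML path, and the 'snapshot_save_iter' key are assumptions for illustration only.

# Hypothetical usage sketch -- get_config, loader_a/loader_b and any config keys
# not read inside the class are assumed, not taken from the listing above.
config = get_config('configs/example.yaml')   # dict with the keys used in __init__/gen_update
trainer = MUNIT_Trainer(config)
trainer.cuda()
for iterations, (x_a, x_b) in enumerate(zip(loader_a, loader_b), 1):
    x_a, x_b = x_a.cuda(), x_b.cuda()
    trainer.dis_update(x_a, x_b, config)      # discriminator step
    trainer.gen_update(x_a, x_b, config)      # generator step
    trainer.update_learning_rate()            # step both LR schedulers
    if iterations % config['snapshot_save_iter'] == 0:  # assumed key
        trainer.save('outputs', iterations)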
Example #2
class UnsupIntrinsicTrainer(nn.Module):
    def __init__(self, param):
        super(UnsupIntrinsicTrainer, self).__init__()
        lr = param['lr']
        # Initialize the networks
        self.gen_i = AdaINGen(param['input_dim_a'], param['input_dim_a'],
                              param['gen'])  # auto-encoder for domain I
        self.gen_r = AdaINGen(param['input_dim_b'], param['input_dim_b'],
                              param['gen'])  # auto-encoder for domain R
        self.gen_s = AdaINGen(param['input_dim_c'], param['input_dim_c'],
                              param['gen'])  # auto-encoder for domain S
        self.dis_r = MsImageDis(param['input_dim_b'],
                                param['dis'])  # discriminator for domain R
        self.dis_s = MsImageDis(param['input_dim_c'],
                                param['dis'])  # discriminator for domain S
        gp = param['gen']
        self.with_mapping = True
        self.use_phy_loss = True
        self.use_content_loss = True
        if 'ablation_study' in param:
            if 'with_mapping' in param['ablation_study']:
                self.with_mapping = param['ablation_study']['with_mapping'] != 0
            if 'wo_phy_loss' in param['ablation_study']:
                self.use_phy_loss = param['ablation_study']['wo_phy_loss'] == 0
            if 'wo_content_loss' in param['ablation_study']:
                self.use_content_loss = param['ablation_study']['wo_content_loss'] == 0

        if self.with_mapping:
            self.fea_s = IntrinsicSplitor(gp['style_dim'], gp['mlp_dim'],
                                          gp['n_layer'],
                                          gp['activ'])  # split style for I
            self.fea_m = IntrinsicMerger(gp['style_dim'], gp['mlp_dim'],
                                         gp['n_layer'],
                                         gp['activ'])  # merge style for R, S
        self.bias_shift = param['bias_shift']
        self.instance_norm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = param['gen']['style_dim']

        # fix the noise used in sampling
        display_size = int(param['display_size'])
        self.s_r = torch.randn(display_size, self.style_dim, 1, 1).cuda() + 1.
        self.s_s = torch.randn(display_size, self.style_dim, 1, 1).cuda() - 1.

        # Setup the optimizers
        beta1 = param['beta1']
        beta2 = param['beta2']
        dis_params = list(self.dis_r.parameters()) + list(
            self.dis_s.parameters())
        if self.with_mapping:
            gen_params = list(self.gen_i.parameters()) + list(self.gen_r.parameters()) + \
                         list(self.gen_s.parameters()) + \
                         list(self.fea_s.parameters()) + list(self.fea_m.parameters())
        else:
            gen_params = list(self.gen_i.parameters()) + list(
                self.gen_r.parameters()) + list(self.gen_s.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=param['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=param['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, param)
        self.gen_scheduler = get_scheduler(self.gen_opt, param)

        # Network weight initialization
        self.apply(weights_init(param['init']))
        self.dis_r.apply(weights_init('gaussian'))
        self.dis_s.apply(weights_init('gaussian'))
        self.best_result = float('inf')
        self.reflectance_loss = LocalAlbedoSmoothnessLoss(param)

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def physical_criterion(self, x_i, x_r, x_s):
        return torch.mean(torch.abs(x_i - x_r * x_s))

    def forward(self, x_i):
        c_i, s_i_fake = self.gen_i.encode(x_i)
        if self.with_mapping:
            s_r, s_s = self.fea_s(s_i_fake)
        else:
            s_r = Variable(
                torch.randn(x_i.size(0), self.style_dim, 1,
                            1).cuda()) + self.bias_shift
            s_s = Variable(
                torch.randn(x_i.size(0), self.style_dim, 1,
                            1).cuda()) - self.bias_shift
        x_ri = self.gen_r.decode(c_i, s_r)
        x_si = self.gen_s.decode(c_i, s_s)
        return x_ri, x_si

    def inference(self, x_i, use_rand_fea=False):
        with torch.no_grad():
            c_i, s_i_fake = self.gen_i.encode(x_i)
            if self.with_mapping:
                s_r, s_s = self.fea_s(s_i_fake)
            else:
                s_r = Variable(
                    torch.randn(x_i.size(0), self.style_dim, 1,
                                1).cuda()) + self.bias_shift
                s_s = Variable(
                    torch.randn(x_i.size(0), self.style_dim, 1,
                                1).cuda()) - self.bias_shift
            if use_rand_fea:
                s_r = Variable(
                    torch.randn(x_i.size(0), self.style_dim, 1,
                                1).cuda()) + self.bias_shift
                s_s = Variable(
                    torch.randn(x_i.size(0), self.style_dim, 1,
                                1).cuda()) - self.bias_shift
            x_ri = self.gen_r.decode(c_i, s_r)
            x_si = self.gen_s.decode(c_i, s_s)
        return x_ri, x_si

    # noinspection PyAttributeOutsideInit
    def gen_update(self, x_i, x_r, x_s, targets=None, param=None):
        self.gen_opt.zero_grad()
        # ============= Domain Translations =============
        # encode
        c_i, s_i_prime = self.gen_i.encode(x_i)
        c_r, s_r_prime = self.gen_r.encode(x_r)
        c_s, s_s_prime = self.gen_s.encode(x_s)

        if self.with_mapping:
            s_ri, s_si = self.fea_s(s_i_prime)
        else:
            s_ri = Variable(
                torch.randn(x_i.size(0), self.style_dim, 1,
                            1).cuda()) + self.bias_shift
            s_si = Variable(
                torch.randn(x_i.size(0), self.style_dim, 1,
                            1).cuda()) - self.bias_shift
        s_r_rand = Variable(
            torch.randn(x_r.size(0), self.style_dim, 1,
                        1).cuda()) + self.bias_shift
        s_s_rand = Variable(
            torch.randn(x_s.size(0), self.style_dim, 1,
                        1).cuda()) - self.bias_shift
        if self.with_mapping:
            s_i_recon = self.fea_m(s_ri, s_si)
        else:
            s_i_recon = Variable(
                torch.randn(x_i.size(0), self.style_dim, 1, 1).cuda())
        # decode (within domain)
        x_i_recon = self.gen_i.decode(c_i, s_i_prime)
        x_r_recon = self.gen_s.decode(c_r, s_r_prime)
        x_s_recon = self.gen_r.decode(c_s, s_s_prime)
        # decode (cross domain)
        x_rs = self.gen_r.decode(c_s, s_r_rand)
        x_ri = self.gen_r.decode(c_i, s_ri)
        x_ri_rand = self.gen_r.decode(c_i, s_r_rand)
        x_sr = self.gen_s.decode(c_r, s_s_rand)
        x_si = self.gen_s.decode(c_i, s_si)
        x_si_rand = self.gen_s.decode(c_i, s_r_rand)
        # encode again, for feature domain consistency constraints
        c_rs_recon, s_rs_recon = self.gen_r.encode(x_rs)
        c_ri_recon, s_ri_recon = self.gen_r.encode(x_ri)
        c_ri_rand_recon, s_ri_rand_recon = self.gen_r.encode(x_ri_rand)
        c_sr_recon, s_sr_recon = self.gen_s.encode(x_sr)
        c_si_recon, s_si_recon = self.gen_s.encode(x_si)
        c_si_rand_recon, s_si_rand_recon = self.gen_s.encode(x_si_rand)
        # decode again, for image domain cycle consistency
        x_rsr = self.gen_r.decode(c_sr_recon, s_r_prime)
        x_iri = self.gen_i.decode(c_ri_recon, s_i_prime)
        x_iri_rand = self.gen_i.decode(c_ri_rand_recon, s_i_prime)
        x_srs = self.gen_s.decode(c_rs_recon, s_s_prime)
        x_isi = self.gen_i.decode(c_si_recon, s_i_prime)
        x_isi_rand = self.gen_i.decode(c_si_rand_recon, s_i_prime)

        # ============= Loss Functions =============
        # Encoder-decoder reconstruction loss for the three domains
        self.loss_gen_recon_x_i = self.recon_criterion(x_i_recon, x_i)
        self.loss_gen_recon_x_r = self.recon_criterion(x_r_recon, x_r)
        self.loss_gen_recon_x_s = self.recon_criterion(x_s_recon, x_s)
        # Style-level reconstruction loss for cross domain
        if self.with_mapping:
            self.loss_gen_recon_s_ii = self.recon_criterion(
                s_i_recon, s_i_prime)
        else:
            self.loss_gen_recon_s_ii = 0
        self.loss_gen_recon_s_ri = self.recon_criterion(s_ri_recon, s_ri)
        self.loss_gen_recon_s_ri_rand = self.recon_criterion(
            s_ri_rand_recon, s_ri)
        self.loss_gen_recon_s_rs = self.recon_criterion(s_rs_recon, s_r_rand)
        self.loss_gen_recon_s_sr = self.recon_criterion(s_sr_recon, s_s_rand)
        self.loss_gen_recon_s_si = self.recon_criterion(s_si_recon, s_si)
        self.loss_gen_recon_s_si_rand = self.recon_criterion(
            s_si_rand_recon, s_si)
        # Content-level reconstruction loss for cross domain
        self.loss_gen_recon_c_rs = self.recon_criterion(c_rs_recon, c_s)
        self.loss_gen_recon_c_ri = self.recon_criterion(
            c_ri_recon, c_i) if self.use_content_loss is True else 0
        self.loss_gen_recon_c_ri_rand = self.recon_criterion(
            c_ri_rand_recon, c_i) if self.use_content_loss is True else 0
        self.loss_gen_recon_c_sr = self.recon_criterion(c_sr_recon, c_r)
        self.loss_gen_recon_c_si = self.recon_criterion(
            c_si_recon, c_i) if self.use_content_loss is True else 0
        self.loss_gen_recon_c_si_rand = self.recon_criterion(
            c_si_rand_recon, c_i) if self.use_content_loss is True else 0
        # Cycle-consistency loss for the three image domains
        self.loss_gen_cyc_recon_x_rs = self.recon_criterion(x_rsr, x_r)
        self.loss_gen_cyc_recon_x_ir = self.recon_criterion(x_iri, x_i)
        self.loss_gen_cyc_recon_x_ir_rand = self.recon_criterion(
            x_iri_rand, x_i)
        self.loss_gen_cyc_recon_x_sr = self.recon_criterion(x_srs, x_s)
        self.loss_gen_cyc_recon_x_is = self.recon_criterion(x_isi, x_i)
        self.loss_gen_cyc_recon_x_is_rand = self.recon_criterion(
            x_isi_rand, x_i)
        # GAN loss
        self.loss_gen_adv_rs = self.dis_r.calc_gen_loss(x_rs)
        self.loss_gen_adv_ri = self.dis_r.calc_gen_loss(x_ri)
        self.loss_gen_adv_ri_rand = self.dis_r.calc_gen_loss(x_ri_rand)
        self.loss_gen_adv_sr = self.dis_s.calc_gen_loss(x_sr)
        self.loss_gen_adv_si = self.dis_s.calc_gen_loss(x_si)
        self.loss_gen_adv_si_rand = self.dis_s.calc_gen_loss(x_si_rand)
        # Physical loss
        self.loss_gen_phy_i = self.physical_criterion(
            x_i, x_ri, x_si) if self.use_phy_loss is True else 0
        self.loss_gen_phy_i_rand = self.physical_criterion(
            x_i, x_ri_rand, x_si_rand) if self.use_phy_loss is True else 0

        # Reflectance smoothness loss
        self.loss_refl_ri = self.reflectance_loss(
            x_ri, targets) if targets is not None else 0
        self.loss_refl_ri_rand = self.reflectance_loss(
            x_ri_rand, targets) if targets is not None else 0

        # total loss
        self.loss_gen_total = param['gan_w'] * self.loss_gen_adv_rs + \
                              param['gan_w'] * self.loss_gen_adv_ri + \
                              param['gan_w'] * self.loss_gen_adv_ri_rand + \
                              param['gan_w'] * self.loss_gen_adv_sr + \
                              param['gan_w'] * self.loss_gen_adv_si + \
                              param['gan_w'] * self.loss_gen_adv_si_rand + \
                              param['recon_x_w'] * self.loss_gen_recon_x_i + \
                              param['recon_x_w'] * self.loss_gen_recon_x_r + \
                              param['recon_x_w'] * self.loss_gen_recon_x_s + \
                              param['recon_s_w'] * self.loss_gen_recon_s_ii + \
                              param['recon_s_w'] * self.loss_gen_recon_s_ri + \
                              param['recon_s_w'] * self.loss_gen_recon_s_ri_rand + \
                              param['recon_s_w'] * self.loss_gen_recon_s_rs + \
                              param['recon_s_w'] * self.loss_gen_recon_s_si + \
                              param['recon_s_w'] * self.loss_gen_recon_s_si_rand + \
                              param['recon_s_w'] * self.loss_gen_recon_s_sr + \
                              param['recon_c_w'] * self.loss_gen_recon_c_ri + \
                              param['recon_c_w'] * self.loss_gen_recon_c_rs + \
                              param['recon_c_w'] * self.loss_gen_recon_c_ri_rand + \
                              param['recon_c_w'] * self.loss_gen_recon_c_si + \
                              param['recon_c_w'] * self.loss_gen_recon_c_sr + \
                              param['recon_c_w'] * self.loss_gen_recon_c_si_rand + \
                              param['recon_x_cyc_w'] * self.loss_gen_cyc_recon_x_ir + \
                              param['recon_x_cyc_w'] * self.loss_gen_cyc_recon_x_ir_rand + \
                              param['recon_x_cyc_w'] * self.loss_gen_cyc_recon_x_is + \
                              param['recon_x_cyc_w'] * self.loss_gen_cyc_recon_x_is_rand + \
                              param['recon_x_cyc_w'] * self.loss_gen_cyc_recon_x_rs + \
                              param['recon_x_cyc_w'] * self.loss_gen_cyc_recon_x_sr + \
                              param['phy_x_w'] * self.loss_gen_phy_i + \
                              param['phy_x_w'] * self.loss_gen_phy_i_rand + \
                              param['refl_smooth_w'] * self.loss_refl_ri + \
                              param['refl_smooth_w'] * self.loss_refl_ri_rand

        self.loss_gen_total.backward()
        self.gen_opt.step()

    def sample(self, x_i, x_r, x_s):
        self.eval()
        s_r = Variable(self.s_r)
        s_s = Variable(self.s_s)
        x_i_recon, x_r_recon, x_s_recon, x_rs, x_ri, x_sr, x_si = [], [], [], [], [], [], []
        for i in range(x_i.size(0)):
            c_i, s_i_fake = self.gen_i.encode(x_i[i].unsqueeze(0))
            c_r, s_r_fake = self.gen_r.encode(x_r[i].unsqueeze(0))
            c_s, s_s_fake = self.gen_s.encode(x_s[i].unsqueeze(0))
            if self.with_mapping:
                s_ri, s_si = self.fea_s(s_i_fake)
            else:
                s_ri = Variable(torch.randn(1, self.style_dim, 1,
                                            1).cuda()) + self.bias_shift
                s_si = Variable(torch.randn(1, self.style_dim, 1,
                                            1).cuda()) - self.bias_shift
            x_i_recon.append(self.gen_i.decode(c_i, s_i_fake))
            x_r_recon.append(self.gen_r.decode(c_r, s_r_fake))
            x_s_recon.append(self.gen_s.decode(c_s, s_s_fake))
            x_rs.append(self.gen_r.decode(c_s, s_r[i].unsqueeze(0)))
            x_ri.append(self.gen_r.decode(c_i, s_ri.unsqueeze(0)))
            x_sr.append(self.gen_s.decode(c_s, s_s[i].unsqueeze(0)))
            x_si.append(self.gen_s.decode(c_i, s_si.unsqueeze(0)))
        x_i_recon, x_r_recon, x_s_recon = torch.cat(x_i_recon), torch.cat(
            x_r_recon), torch.cat(x_s_recon)
        x_rs, x_ri = torch.cat(x_rs), torch.cat(x_ri)
        x_sr, x_si = torch.cat(x_sr), torch.cat(x_si)
        self.train()
        return x_i, x_i_recon, x_r, x_r_recon, x_rs, x_ri, x_s, x_s_recon, x_sr, x_si

    # noinspection PyAttributeOutsideInit
    def dis_update(self, x_i, x_r, x_s, params):
        self.dis_opt.zero_grad()
        s_r = Variable(torch.randn(x_r.size(0), self.style_dim, 1,
                                   1).cuda()) - self.bias_shift
        s_s = Variable(torch.randn(x_s.size(0), self.style_dim, 1,
                                   1).cuda()) + self.bias_shift
        # encode
        c_r, _ = self.gen_r.encode(x_r)
        c_s, _ = self.gen_s.encode(x_s)
        c_i, s_i = self.gen_i.encode(x_i)
        if self.with_mapping:
            s_ri, s_si = self.fea_s(s_i)
        else:
            s_ri = Variable(
                torch.randn(x_i.size(0), self.style_dim, 1,
                            1).cuda()) + self.bias_shift
            s_si = Variable(
                torch.randn(x_i.size(0), self.style_dim, 1,
                            1).cuda()) - self.bias_shift
        # decode (cross domain)
        x_rs = self.gen_r.decode(c_s, s_r)
        x_ri = self.gen_r.decode(c_i, s_ri)
        x_sr = self.gen_s.decode(c_r, s_s)
        x_si = self.gen_s.decode(c_i, s_si)
        # D loss
        self.loss_dis_rs = self.dis_r.calc_dis_loss(x_rs.detach(), x_r)
        self.loss_dis_ri = self.dis_r.calc_dis_loss(x_ri.detach(), x_r)
        self.loss_dis_sr = self.dis_s.calc_dis_loss(x_sr.detach(), x_s)
        self.loss_dis_si = self.dis_s.calc_dis_loss(x_si.detach(), x_s)

        self.loss_dis_total = params['gan_w'] * self.loss_dis_rs +\
                              params['gan_w'] * self.loss_dis_ri +\
                              params['gan_w'] * self.loss_dis_sr +\
                              params['gan_w'] * self.loss_dis_si

        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, param):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_i.load_state_dict(state_dict['i'])
        self.gen_r.load_state_dict(state_dict['r'])
        self.gen_s.load_state_dict(state_dict['s'])
        if self.with_mapping:
            self.fea_m.load_state_dict(state_dict['fm'])
            self.fea_s.load_state_dict(state_dict['fs'])
        self.best_result = state_dict['best_result']
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_r.load_state_dict(state_dict['r'])
        self.dis_s.load_state_dict(state_dict['s'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, param, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, param, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        if self.with_mapping:
            torch.save(
                {
                    'i': self.gen_i.state_dict(),
                    'r': self.gen_r.state_dict(),
                    's': self.gen_s.state_dict(),
                    'fs': self.fea_s.state_dict(),
                    'fm': self.fea_m.state_dict(),
                    'best_result': self.best_result
                }, gen_name)
        else:
            torch.save(
                {
                    'i': self.gen_i.state_dict(),
                    'r': self.gen_r.state_dict(),
                    's': self.gen_s.state_dict(),
                    'best_result': self.best_result
                }, gen_name)
        torch.save({
            'r': self.dis_r.state_dict(),
            's': self.dis_s.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
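A minimal driver sketch for the UnsupIntrinsicTrainer above, assuming a YAML config that supplies the keys read by the class and a loader that yields (image, reflectance, shading) batches; the file path, the loader, and passing targets=None are illustrative assumptions, not part of the listing.

import yaml

# Hypothetical usage sketch -- the config file and `loader` are assumed.
with open('configs/intrinsic.yaml') as f:
    param = yaml.safe_load(f)
trainer = UnsupIntrinsicTrainer(param).cuda()
for x_i, x_r, x_s in loader:                  # natural image, reflectance, shading
    x_i, x_r, x_s = x_i.cuda(), x_r.cuda(), x_s.cuda()
    trainer.dis_update(x_i, x_r, x_s, param)                      # discriminator step
    trainer.gen_update(x_i, x_r, x_s, targets=None, param=param)  # generator step
    trainer.update_learning_rate()
# Decompose a single image into predicted reflectance and shading.
x_ri, x_si = trainer.inference(x_i)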
Example #3
class DGNet_Trainer(nn.Module):
    def __init__(self, hyperparameters, gpu_ids=[0]):
        super(DGNet_Trainer, self).__init__()
        lr_g = hyperparameters['lr_g']
        lr_d = hyperparameters['lr_d']
        ID_class = hyperparameters['ID_class']
        if 'apex' not in hyperparameters:
            hyperparameters['apex'] = False
        self.fp16 = hyperparameters['apex']
        # Initialize the networks
        # With the new apex API there is no need to set fp16 manually inside the networks, so fp16=False here.
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'],
                              hyperparameters['gen'],
                              fp16=False)  # auto-encoder for domain a
        self.gen_b = self.gen_a  # auto-encoder for domain b

        if 'ID_stride' not in hyperparameters:
            hyperparameters['ID_stride'] = 2

        if hyperparameters['ID_style'] == 'PCB':
            self.id_a = PCB(ID_class)
        elif hyperparameters['ID_style'] == 'AB':
            self.id_a = ft_netAB(ID_class,
                                 stride=hyperparameters['ID_stride'],
                                 norm=hyperparameters['norm_id'],
                                 pool=hyperparameters['pool'])
        else:
            self.id_a = ft_net(ID_class,
                               norm=hyperparameters['norm_id'],
                               pool=hyperparameters['pool'])  # return 2048 now

        self.id_b = self.id_a
        self.dis_a = MsImageDis(3, hyperparameters['dis'],
                                fp16=False)  # discriminator for domain a
        self.dis_b = self.dis_a  # discriminator for domain b

        # load teachers
        if hyperparameters['teacher'] != "":
            teacher_name = hyperparameters['teacher']
            print(teacher_name)
            teacher_names = teacher_name.split(',')
            teacher_model = nn.ModuleList()
            teacher_count = 0
            for teacher_name in teacher_names:
                config_tmp = load_config(teacher_name)
                if 'stride' in config_tmp:
                    stride = config_tmp['stride']
                else:
                    stride = 2
                model_tmp = ft_net(ID_class, stride=stride)
                teacher_model_tmp = load_network(model_tmp, teacher_name)
                teacher_model_tmp.model.fc = nn.Sequential()  # remove the original ImageNet fc layer
                teacher_model_tmp = teacher_model_tmp.cuda()
                if self.fp16:
                    teacher_model_tmp = amp.initialize(teacher_model_tmp,
                                                       opt_level="O1")
                teacher_model.append(teacher_model_tmp.cuda().eval())
                teacher_count += 1
            self.teacher_model = teacher_model
            if hyperparameters['train_bn']:
                self.teacher_model = self.teacher_model.apply(train_bn)

        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

        # RGB to one channel
        if hyperparameters['single'] == 'edge':
            self.single = to_edge
        else:
            self.single = to_gray(False)

        # Random Erasing when training
        if 'erasing_p' not in hyperparameters:
            hyperparameters['erasing_p'] = 0
        self.single_re = RandomErasing(
            probability=hyperparameters['erasing_p'], mean=[0.0, 0.0, 0.0])

        if 'T_w' not in hyperparameters:
            hyperparameters['T_w'] = 1
        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(
            self.dis_a.parameters())  #+ list(self.dis_b.parameters())
        gen_params = list(
            self.gen_a.parameters())  #+ list(self.gen_b.parameters())

        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr_d,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr_g,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        # id params
        if hyperparameters['ID_style'] == 'PCB':
            ignored_params = (
                list(map(id, self.id_a.classifier0.parameters())) +
                list(map(id, self.id_a.classifier1.parameters())) +
                list(map(id, self.id_a.classifier2.parameters())) +
                list(map(id, self.id_a.classifier3.parameters())))
            base_params = filter(lambda p: id(p) not in ignored_params,
                                 self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD(
                [{
                    'params': base_params,
                    'lr': lr2
                }, {
                    'params': self.id_a.classifier0.parameters(),
                    'lr': lr2 * 10
                }, {
                    'params': self.id_a.classifier1.parameters(),
                    'lr': lr2 * 10
                }, {
                    'params': self.id_a.classifier2.parameters(),
                    'lr': lr2 * 10
                }, {
                    'params': self.id_a.classifier3.parameters(),
                    'lr': lr2 * 10
                }],
                weight_decay=hyperparameters['weight_decay'],
                momentum=0.9,
                nesterov=True)
        elif hyperparameters['ID_style'] == 'AB':
            ignored_params = (
                list(map(id, self.id_a.classifier1.parameters())) +
                list(map(id, self.id_a.classifier2.parameters())))
            base_params = filter(lambda p: id(p) not in ignored_params,
                                 self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD(
                [{
                    'params': base_params,
                    'lr': lr2
                }, {
                    'params': self.id_a.classifier1.parameters(),
                    'lr': lr2 * 10
                }, {
                    'params': self.id_a.classifier2.parameters(),
                    'lr': lr2 * 10
                }],
                weight_decay=hyperparameters['weight_decay'],
                momentum=0.9,
                nesterov=True)
        else:
            ignored_params = list(map(id, self.id_a.classifier.parameters()))
            base_params = filter(lambda p: id(p) not in ignored_params,
                                 self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD(
                [{
                    'params': base_params,
                    'lr': lr2
                }, {
                    'params': self.id_a.classifier.parameters(),
                    'lr': lr2 * 10
                }],
                weight_decay=hyperparameters['weight_decay'],
                momentum=0.9,
                nesterov=True)

        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        self.id_scheduler = get_scheduler(self.id_opt, hyperparameters)
        self.id_scheduler.gamma = hyperparameters['gamma2']

        #ID Loss
        self.id_criterion = nn.CrossEntropyLoss()
        self.criterion_teacher = nn.KLDivLoss(reduction='sum')  # summed KL divergence; call sites divide by batch size
        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

        # save memory
        if self.fp16:
            # Wrap the models and optimizers with apex amp (opt_level O1)
            assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
            self.gen_a = self.gen_a.cuda()
            self.dis_a = self.dis_a.cuda()
            self.id_a = self.id_a.cuda()

            self.gen_b = self.gen_a
            self.dis_b = self.dis_a
            self.id_b = self.id_a

            self.gen_a, self.gen_opt = amp.initialize(self.gen_a,
                                                      self.gen_opt,
                                                      opt_level="O1")
            self.dis_a, self.dis_opt = amp.initialize(self.dis_a,
                                                      self.dis_opt,
                                                      opt_level="O1")
            self.id_a, self.id_opt = amp.initialize(self.id_a,
                                                    self.id_opt,
                                                    opt_level="O1")

    def to_re(self, x):
        out = torch.FloatTensor(x.size(0), x.size(1), x.size(2), x.size(3))
        out = out.cuda()
        for i in range(x.size(0)):
            out[i, :, :, :] = self.single_re(x[i, :, :, :])
        return out

    def recon_criterion(self, input, target):
        diff = input - target.detach()
        return torch.mean(torch.abs(diff[:]))

    def recon_criterion_sqrt(self, input, target):
        diff = input - target
        return torch.mean(torch.sqrt(torch.abs(diff[:]) + 1e-8))

    def recon_criterion2(self, input, target):
        diff = input - target
        return torch.mean(diff[:]**2)

    def recon_cos(self, input, target):
        cos = torch.nn.CosineSimilarity()
        cos_dis = 1 - cos(input, target)
        return torch.mean(cos_dis[:])

    def forward(self, x_a, x_b):
        self.eval()
        s_a = self.gen_a.encode(self.single(x_a))
        s_b = self.gen_b.encode(self.single(x_b))
        f_a, _ = self.id_a(scale2(x_a))
        f_b, _ = self.id_b(scale2(x_b))
        x_ba = self.gen_a.decode(s_b, f_a)
        x_ab = self.gen_b.decode(s_a, f_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, l_a, xp_a, x_b, l_b, xp_b, hyperparameters,
                   iteration):
        # xp_a / xp_b are different photos of the same identity as x_a / x_b
        self.gen_opt.zero_grad()
        self.id_opt.zero_grad()
        # encode
        s_a = self.gen_a.encode(self.single(x_a))
        s_b = self.gen_b.encode(self.single(x_b))
        f_a, p_a = self.id_a(scale2(x_a))
        f_b, p_b = self.id_b(scale2(x_b))
        # autodecode
        x_a_recon = self.gen_a.decode(s_a, f_a)
        x_b_recon = self.gen_b.decode(s_b, f_b)

        # encode the same ID different photo
        fp_a, pp_a = self.id_a(scale2(xp_a))
        fp_b, pp_b = self.id_b(scale2(xp_b))

        # decode the same person
        x_a_recon_p = self.gen_a.decode(s_a, fp_a)
        x_b_recon_p = self.gen_b.decode(s_b, fp_b)

        # has gradient
        x_ba = self.gen_a.decode(s_b, f_a)
        x_ab = self.gen_b.decode(s_a, f_b)
        # no gradient
        x_ba_copy = Variable(x_ba.data, requires_grad=False)
        x_ab_copy = Variable(x_ab.data, requires_grad=False)

        rand_num = random.uniform(0, 1)
        #################################
        # encode structure
        if hyperparameters['use_encoder_again'] >= rand_num:
            # encode again (encoder is tuned, input is fixed)
            s_a_recon = self.gen_b.enc_content(self.single(x_ab_copy))
            s_b_recon = self.gen_a.enc_content(self.single(x_ba_copy))
        else:
            # copy the encoder
            self.enc_content_copy = copy.deepcopy(self.gen_a.enc_content)
            self.enc_content_copy = self.enc_content_copy.eval()
            # encode again (encoder is fixed, input is tuned)
            s_a_recon = self.enc_content_copy(self.single(x_ab))
            s_b_recon = self.enc_content_copy(self.single(x_ba))

        #################################
        # encode appearance
        self.id_a_copy = copy.deepcopy(self.id_a)
        self.id_a_copy = self.id_a_copy.eval()
        if hyperparameters['train_bn']:
            self.id_a_copy = self.id_a_copy.apply(train_bn)
        self.id_b_copy = self.id_a_copy
        # encode again (encoder is fixed, input is tuned)
        f_a_recon, p_a_recon = self.id_a_copy(scale2(x_ba))
        f_b_recon, p_b_recon = self.id_b_copy(scale2(x_ab))

        # teacher Loss
        #  Tune the ID model
        log_sm = nn.LogSoftmax(dim=1)
        if hyperparameters['teacher_w'] > 0 and hyperparameters[
                'teacher'] != "":
            if hyperparameters['ID_style'] == 'normal':
                _, p_a_student = self.id_a(scale2(x_ba_copy))
                p_a_student = log_sm(p_a_student)
                p_a_teacher = predict_label(
                    self.teacher_model,
                    scale2(x_ba_copy),
                    num_class=hyperparameters['ID_class'],
                    alabel=l_a,
                    slabel=l_b,
                    teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher = self.criterion_teacher(
                    p_a_student, p_a_teacher) / p_a_student.size(0)

                _, p_b_student = self.id_b(scale2(x_ab_copy))
                p_b_student = log_sm(p_b_student)
                p_b_teacher = predict_label(
                    self.teacher_model,
                    scale2(x_ab_copy),
                    num_class=hyperparameters['ID_class'],
                    alabel=l_b,
                    slabel=l_a,
                    teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher += self.criterion_teacher(
                    p_b_student, p_b_teacher) / p_b_student.size(0)
            elif hyperparameters['ID_style'] == 'AB':
                # normal teacher-student loss
                # BA -> LabelA(smooth) + LabelB(batchB)
                _, p_ba_student = self.id_a(scale2(x_ba_copy))  # f_a, s_b
                p_a_student = log_sm(p_ba_student[0])
                with torch.no_grad():
                    p_a_teacher = predict_label(
                        self.teacher_model,
                        scale2(x_ba_copy),
                        num_class=hyperparameters['ID_class'],
                        alabel=l_a,
                        slabel=l_b,
                        teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher = self.criterion_teacher(
                    p_a_student, p_a_teacher) / p_a_student.size(0)

                _, p_ab_student = self.id_b(scale2(x_ab_copy))  # f_b, s_a
                p_b_student = log_sm(p_ab_student[0])
                with torch.no_grad():
                    p_b_teacher = predict_label(
                        self.teacher_model,
                        scale2(x_ab_copy),
                        num_class=hyperparameters['ID_class'],
                        alabel=l_b,
                        slabel=l_a,
                        teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher += self.criterion_teacher(
                    p_b_student, p_b_teacher) / p_b_student.size(0)

                # branch b loss
                # here we give different label
                loss_B = self.id_criterion(p_ba_student[1],
                                           l_b) + self.id_criterion(
                                               p_ab_student[1], l_a)
                self.loss_teacher = hyperparameters[
                    'T_w'] * self.loss_teacher + hyperparameters['B_w'] * loss_B
        else:
            self.loss_teacher = 0.0

        # decode again (if needed)
        if hyperparameters['use_decoder_again']:
            x_aba = self.gen_a.decode(
                s_a_recon,
                f_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
            x_bab = self.gen_b.decode(
                s_b_recon,
                f_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
        else:
            self.mlp_w_copy = copy.deepcopy(self.gen_a.mlp_w)
            self.mlp_b_copy = copy.deepcopy(self.gen_a.mlp_b)
            self.dec_copy = copy.deepcopy(self.gen_a.dec)  # Error
            ID = f_a_recon
            ID_Style = ID.view(ID.shape[0], ID.shape[1], 1, 1)
            adain_params_w = self.mlp_w_copy(ID_Style)
            adain_params_b = self.mlp_b_copy(ID_Style)
            self.gen_a.assign_adain_params(adain_params_w, adain_params_b,
                                           self.dec_copy)
            x_aba = self.dec_copy(
                s_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None

            ID = f_b_recon
            ID_Style = ID.view(ID.shape[0], ID.shape[1], 1, 1)
            adain_params_w = self.mlp_w_copy(ID_Style)
            adain_params_b = self.mlp_b_copy(ID_Style)
            self.gen_a.assign_adain_params(adain_params_w, adain_params_b,
                                           self.dec_copy)
            x_bab = self.dec_copy(
                s_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # auto-encoder image reconstruction
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_xp_a = self.recon_criterion(x_a_recon_p, x_a)
        self.loss_gen_recon_xp_b = self.recon_criterion(x_b_recon_p, x_b)

        # feature reconstruction
        self.loss_gen_recon_s_a = self.recon_criterion(
            s_a_recon, s_a) if hyperparameters['recon_s_w'] > 0 else 0
        self.loss_gen_recon_s_b = self.recon_criterion(
            s_b_recon, s_b) if hyperparameters['recon_s_w'] > 0 else 0
        self.loss_gen_recon_f_a = self.recon_criterion(
            f_a_recon, f_a) if hyperparameters['recon_f_w'] > 0 else 0
        self.loss_gen_recon_f_b = self.recon_criterion(
            f_b_recon, f_b) if hyperparameters['recon_f_w'] > 0 else 0

        # Random Erasing only affects the ID and PID losses.
        if hyperparameters['erasing_p'] > 0:
            x_a_re = self.to_re(scale2(x_a.clone()))
            x_b_re = self.to_re(scale2(x_b.clone()))
            xp_a_re = self.to_re(scale2(xp_a.clone()))
            xp_b_re = self.to_re(scale2(xp_b.clone()))
            _, p_a = self.id_a(x_a_re)
            _, p_b = self.id_b(x_b_re)
            # encode the same ID different photo
            _, pp_a = self.id_a(xp_a_re)
            _, pp_b = self.id_b(xp_b_re)

        # ID losses on the real, positive-pair, and generated images (the last also tunes the generator)
        if hyperparameters['ID_style'] == 'PCB':
            self.loss_id = self.PCB_loss(p_a, l_a) + self.PCB_loss(p_b, l_b)
            self.loss_pid = self.PCB_loss(pp_a, l_a) + self.PCB_loss(pp_b, l_b)
            self.loss_gen_recon_id = self.PCB_loss(
                p_a_recon, l_a) + self.PCB_loss(p_b_recon, l_b)
        elif hyperparameters['ID_style'] == 'AB':
            weight_B = hyperparameters['teacher_w'] * hyperparameters['B_w']
            self.loss_id = self.id_criterion(p_a[0], l_a) + self.id_criterion(p_b[0], l_b) \
                         + weight_B * ( self.id_criterion(p_a[1], l_a) + self.id_criterion(p_b[1], l_b) )
            self.loss_pid = self.id_criterion(pp_a[0], l_a) + self.id_criterion(
                pp_b[0], l_b
            )  #+ weight_B * ( self.id_criterion(pp_a[1], l_a) + self.id_criterion(pp_b[1], l_b) )
            self.loss_gen_recon_id = self.id_criterion(
                p_a_recon[0], l_a) + self.id_criterion(p_b_recon[0], l_b)
        else:
            self.loss_id = self.id_criterion(p_a, l_a) + self.id_criterion(
                p_b, l_b)
            self.loss_pid = self.id_criterion(pp_a, l_a) + self.id_criterion(
                pp_b, l_b)
            self.loss_gen_recon_id = self.id_criterion(
                p_a_recon, l_a) + self.id_criterion(p_b_recon, l_b)

        #print(f_a_recon, f_a)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(
            self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(
            self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0

        if iteration > hyperparameters['warm_iter']:
            hyperparameters['recon_f_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_f_w'] = min(hyperparameters['recon_f_w'],
                                               hyperparameters['max_w'])
            hyperparameters['recon_s_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_s_w'] = min(hyperparameters['recon_s_w'],
                                               hyperparameters['max_w'])
            hyperparameters['recon_x_cyc_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_x_cyc_w'] = min(
                hyperparameters['recon_x_cyc_w'], hyperparameters['max_cyc_w'])

        if iteration > hyperparameters['warm_teacher_iter']:
            hyperparameters['teacher_w'] += hyperparameters['warm_scale']
            hyperparameters['teacher_w'] = min(
                hyperparameters['teacher_w'], hyperparameters['max_teacher_w'])
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_a + \
                              hyperparameters['recon_f_w'] * self.loss_gen_recon_f_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_b + \
                              hyperparameters['recon_f_w'] * self.loss_gen_recon_f_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['id_w'] * self.loss_id + \
                              hyperparameters['pid_w'] * self.loss_pid + \
                              hyperparameters['recon_id_w'] * self.loss_gen_recon_id + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b + \
                              hyperparameters['teacher_w'] * self.loss_teacher
        if self.fp16:
            with amp.scale_loss(self.loss_gen_total,
                                [self.gen_opt, self.id_opt]) as scaled_loss:
                scaled_loss.backward()
            self.gen_opt.step()
            self.id_opt.step()
        else:
            self.loss_gen_total.backward()
            self.gen_opt.step()
            self.id_opt.step()
        print("L_total: %.4f, L_gan: %.4f,  Lx: %.4f, Lxp: %.4f, Lrecycle:%.4f, Lf: %.4f, Ls: %.4f, Recon-id: %.4f, id: %.4f, pid:%.4f, teacher: %.4f"%( self.loss_gen_total, \
                                                        hyperparameters['gan_w'] * (self.loss_gen_adv_a + self.loss_gen_adv_b), \
                                                        hyperparameters['recon_x_w'] * (self.loss_gen_recon_x_a + self.loss_gen_recon_x_b), \
                                                        hyperparameters['recon_xp_w'] * (self.loss_gen_recon_xp_a + self.loss_gen_recon_xp_b), \
                                                        hyperparameters['recon_x_cyc_w'] * (self.loss_gen_cycrecon_x_a + self.loss_gen_cycrecon_x_b), \
                                                        hyperparameters['recon_f_w'] * (self.loss_gen_recon_f_a + self.loss_gen_recon_f_b), \
                                                        hyperparameters['recon_s_w'] * (self.loss_gen_recon_s_a + self.loss_gen_recon_s_b), \
                                                        hyperparameters['recon_id_w'] * self.loss_gen_recon_id, \
                                                        hyperparameters['id_w'] * self.loss_id,\
                                                        hyperparameters['pid_w'] * self.loss_pid,\
hyperparameters['teacher_w'] * self.loss_teacher )  )

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def PCB_loss(self, inputs, labels):
        loss = 0.0
        for part in inputs:
            loss += self.id_criterion(part, labels)
        return loss / len(inputs)

    def sample(self, x_a, x_b):
        self.eval()
        x_a_recon, x_b_recon, x_ba1, x_ab1, x_aba, x_bab = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            s_a = self.gen_a.encode(self.single(x_a[i].unsqueeze(0)))
            s_b = self.gen_b.encode(self.single(x_b[i].unsqueeze(0)))
            f_a, _ = self.id_a(scale2(x_a[i].unsqueeze(0)))
            f_b, _ = self.id_b(scale2(x_b[i].unsqueeze(0)))
            x_a_recon.append(self.gen_a.decode(s_a, f_a))
            x_b_recon.append(self.gen_b.decode(s_b, f_b))
            x_ba = self.gen_a.decode(s_b, f_a)
            x_ab = self.gen_b.decode(s_a, f_b)
            x_ba1.append(x_ba)
            x_ab1.append(x_ab)
            #cycle
            s_b_recon = self.gen_a.enc_content(self.single(x_ba))
            s_a_recon = self.gen_b.enc_content(self.single(x_ab))
            f_a_recon, _ = self.id_a(scale2(x_ba))
            f_b_recon, _ = self.id_b(scale2(x_ab))
            x_aba.append(self.gen_a.decode(s_a_recon, f_a_recon))
            x_bab.append(self.gen_b.decode(s_b_recon, f_b_recon))

        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_aba, x_bab = torch.cat(x_aba), torch.cat(x_bab)
        x_ba1, x_ab1 = torch.cat(x_ba1), torch.cat(x_ab1)
        self.train()

        return x_a, x_a_recon, x_aba, x_ab1, x_b, x_b_recon, x_bab, x_ba1

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        # encode
        s_a = self.gen_a.encode(self.single(x_a))
        s_b = self.gen_b.encode(self.single(x_b))
        f_a, _ = self.id_a(scale2(x_a))
        f_b, _ = self.id_b(scale2(x_b))
        # decode (cross domain)
        x_ba = self.gen_a.decode(s_b, f_a)
        x_ab = self.gen_b.decode(s_a, f_b)
        # D loss
        self.loss_dis_a, reg_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b, reg_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters[
            'gan_w'] * self.loss_dis_a + hyperparameters[
                'gan_w'] * self.loss_dis_b
        print("DLoss: %.4f" % self.loss_dis_total,
              "Reg: %.4f" % (reg_a + reg_b))
        if self.fp16:
            with amp.scale_loss(self.loss_dis_total,
                                self.dis_opt) as scaled_loss:
                scaled_loss.backward()
        else:
            self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()
        if self.id_scheduler is not None:
            self.id_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b = self.gen_a
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b = self.dis_a
        # Load ID dis
        last_model_name = get_model_list(checkpoint_dir, "id")
        state_dict = torch.load(last_model_name)
        self.id_a.load_state_dict(state_dict['a'])
        self.id_b = self.id_a
        # Load optimizers
        try:
            state_dict = torch.load(
                os.path.join(checkpoint_dir, 'optimizer.pt'))
            self.dis_opt.load_state_dict(state_dict['dis'])
            self.gen_opt.load_state_dict(state_dict['gen'])
            self.id_opt.load_state_dict(state_dict['id'])
        except Exception:
            # optimizer checkpoint missing or incompatible; keep the freshly created optimizers
            pass
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        id_name = os.path.join(snapshot_dir, 'id_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict()}, dis_name)
        torch.save({'a': self.id_a.state_dict()}, id_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'id': self.id_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
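
The fp16 branches in gen_update() and dis_update() above only work if NVIDIA's apex amp has already wrapped the networks and optimizers (either inside the trainer's __init__, which is not shown here, or externally). A minimal setup sketch, assuming `trainer` is an instance of the class above already moved to the GPU and `config` is its hyperparameter dictionary:

from apex import amp  # requires the NVIDIA apex package

trainer.fp16 = config.get('fp16', False)
if trainer.fp16:
    # Register the networks and optimizers with amp so that
    # amp.scale_loss(loss, [trainer.gen_opt, trainer.id_opt]) in gen_update() can rescale gradients.
    nets = [trainer.gen_a, trainer.dis_a, trainer.id_a]
    opts = [trainer.gen_opt, trainer.dis_opt, trainer.id_opt]
    nets, opts = amp.initialize(nets, opts, opt_level='O1')
    trainer.gen_a, trainer.dis_a, trainer.id_a = nets
    trainer.gen_opt, trainer.dis_opt, trainer.id_opt = opts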
Example #4
0
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, s_a_fake = self.gen_a.encode(x_a)
        c_b, s_b_fake = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(
            c_a_recon,
            s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(
            c_b_recon,
            s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(
            self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(
            self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def sample(self, x_a, x_b):
        self.eval()
        s_a1 = Variable(self.s_a)
        s_b1 = Variable(self.s_b)
        s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
            c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
            x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
            x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
            x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
            x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
            x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters[
            'gan_w'] * self.loss_dis_a + hyperparameters[
                'gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def snap_clean(self, snap_dir, iterations, save_last=10000, period=20000):
        # Remove stale snapshot files: keep everything from the last save_last iterations
        # and one checkpoint per period before that
        if not os.path.exists(snap_dir):
            return None

        gen_models = [
            os.path.join(snap_dir, f) for f in os.listdir(snap_dir)
            if "gen" in f and ".pt" in f
        ]
        dis_models = [
            os.path.join(snap_dir, f) for f in os.listdir(snap_dir)
            if "dis" in f and ".pt" in f
        ]

        gen_models.sort()
        dis_models.sort()
        marked_clean = []
        for i, model in enumerate(gen_models):
            m_iter = int(model[-11:-3])
            if i == 0:
                m_prev = 0
                continue
            if m_iter > iterations - save_last:
                break
            if m_iter - m_prev < period:
                marked_clean.append(model)
            while m_iter - m_prev >= period:
                m_prev += period

        for i, model in enumerate(dis_models):
            m_iter = int(model[-11:-3])
            if i == 0:
                m_prev = 0
                continue
            if m_iter > iterations - save_last:
                break
            if m_iter - m_prev < period:
                marked_clean.append(model)
            while m_iter - m_prev >= period:
                m_prev += period

        print(f'Cleaning snapshots: {marked_clean}')
        for f in marked_clean:
            os.remove(f)

    def save(self, snapshot_dir, iterations, smart_override):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
        if smart_override:
            self.snap_clean(snapshot_dir, iterations + 1)
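
For reference, a minimal training-loop sketch driving the MUNIT_Trainer above. The data loaders (loader_a, loader_b), the config dictionary, the checkpoint directory and the resume flag are placeholders, not part of the class:

trainer = MUNIT_Trainer(config)
trainer.cuda()
iterations = trainer.resume(checkpoint_dir, config) if resume else 0

while iterations < config['max_iter']:
    for x_a, x_b in zip(loader_a, loader_b):
        x_a, x_b = x_a.cuda(), x_b.cuda()
        trainer.dis_update(x_a, x_b, config)   # discriminator step on real vs. translated images
        trainer.gen_update(x_a, x_b, config)   # generator step with reconstruction + GAN losses
        trainer.update_learning_rate()
        iterations += 1
        if iterations % config['snapshot_save_iter'] == 0:
            # smart_override=True additionally prunes old checkpoints via snap_clean()
            trainer.save(checkpoint_dir, iterations, smart_override=True)
        if iterations >= config['max_iter']:
            break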
Example #5
0
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, s_a_fake = self.gen_a.encode(x_a)
        c_b, s_b_fake = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(
            c_a_recon,
            s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(
            c_b_recon,
            s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def sample(self, x_a, x_b):
        self.eval()
        s_a1 = Variable(self.s_a)
        s_b1 = Variable(self.s_b)
        s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
            c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
            x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
            x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
            x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
            x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
            x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters[
            'gan_w'] * self.loss_dis_a + hyperparameters[
                'gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
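
resume() recovers the iteration count directly from the checkpoint file name written by save(): the iteration is zero-padded to eight digits, so a fixed slice before the '.pt' suffix is enough. A small illustration with a made-up path:

name = 'outputs/snapshots/gen_00012000.pt'  # hypothetical file produced by save()
assert name[-11:-3] == '00012000'           # eight-digit iteration, then '.pt'
iterations = int(name[-11:-3])              # -> 12000, exactly what resume() does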
Example #6
0
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, s_a_fake = self.gen_a.encode(x_a)
        c_b, s_b_fake = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, m_A, x_b, m_B, hyperparameters):
        self.gen_opt.zero_grad()

        im_A = 1 - m_A
        im_B = 1 - m_B
        # encode

        c_a, s_bA = self.gen_a.encode(x_a, im_A)
        c_b, s_fB = self.gen_b.encode(x_b, m_B)

        _, s_fA = self.gen_a.encode(x_a, m_A)
        _, s_bB = self.gen_b.encode(x_b, im_B)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_fA, m_B, s_bB)
        x_ab = self.gen_b.decode(c_a, s_fB, m_A, s_bA)

        # decode (within domain)
        x_aa = self.gen_a.decode(c_a, s_fA, m_A, s_bA)
        x_bb = self.gen_b.decode(c_b, s_fB, m_B, s_bB)

        # encode again
        c_ba, s_fBA = self.gen_a.encode(x_ba, m_B)
        c_ab, s_fAB = self.gen_a.encode(x_ab, m_A)

        _, s_bBA = self.gen_a.encode(x_ba, im_B)
        _, s_bAB = self.gen_a.encode(x_ab, im_A)

        # decode again (if needed)
        x_aba = self.gen_a.decode(
            c_ab, s_fBA, m_A,
            s_bAB) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(
            c_ba, s_fAB, m_B,
            s_bBA) if hyperparameters['recon_x_cyc_w'] > 0 else None

        self.loss_gen_recon_c_a = self.recon_criterion(c_ab, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_ba, c_b)

        self.loss_gen_recon_s_a = self.recon_criterion(s_bAB, s_bA)
        self.loss_gen_recon_s_b = self.recon_criterion(s_bBA, s_bB)
        self.loss_gen_recon_s_af = self.recon_criterion(s_fAB, s_fB)
        self.loss_gen_recon_s_bf = self.recon_criterion(s_fBA, s_fA)

        self.loss_gen_recon_x_a = self.recon_criterion(x_aa, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_bb, x_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            im_A * x_aba, im_A *
            x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            m_B * x_bab, m_B *
            x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(
            self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(
            self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_af + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_bf + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b

        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def sample(self, loader_a, loader_b, size):
        self.eval()
        im_a = torch.stack([loader_a.dataset[i][0]
                            for i in range(size)]).cuda()
        seg_a = torch.stack([loader_a.dataset[i][1]
                             for i in range(size)]).cuda()
        im_b = torch.stack([loader_b.dataset[i][0]
                            for i in range(size)]).cuda()
        seg_b = torch.stack([loader_b.dataset[i][1]
                             for i in range(size)]).cuda()
        x_a_recon, x_b_recon, x_ba1, x_bm, x_ab1, x_am = [], [], [], [], [], []
        for i in range(im_a.size(0)):
            mask_a = seg_a[i].unsqueeze(0)
            mask_b = seg_b[i].unsqueeze(0)
            x_a = im_a[i].unsqueeze(0)
            x_b = im_b[i].unsqueeze(0)

            masked_a = mask_a * x_a
            masked_b = mask_b * x_b

            c_a, s_bA = self.gen_a.encode(x_a, 1 - mask_a)
            c_b, s_fB = self.gen_b.encode(x_b, mask_b)

            c_a, s_fA = self.gen_a.encode(x_a, mask_a)
            c_b, s_bB = self.gen_b.encode(x_b, 1 - mask_b)
            # decode (cross domain)
            x_BA = self.gen_a.decode(c_b, s_fA, mask_b, s_bB)
            x_AB = self.gen_b.decode(c_a, s_fB, mask_a, s_bA)
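            # For even-indexed samples, the untranslated background is pasted back in below so that
            # only the masked foreground region is replaced; each expression reduces to
            # (1 - mask) * x + mask * x_translated.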

            if 0 == i % 2:
                x_AB = (1 * (1 - mask_a) * x_a +
                        (0 * (1 - mask_a) * x_AB)) + mask_a * x_AB
                x_BA = (1 * (1 - mask_b) * x_b +
                        (0 * (1 - mask_b) * x_BA)) + mask_b * x_BA

            x_ba1.append(x_BA)
            x_ab1.append(x_AB)
            x_am.append(masked_a)
            x_bm.append(masked_b)

            # decode (within domain)
            x_A_recon = self.gen_a.decode(c_a, s_fA, mask_a, s_bA)
            x_B_recon = self.gen_b.decode(c_b, s_fB, mask_b, s_bB)
            x_a_recon.append(x_A_recon)
            x_b_recon.append(x_B_recon)

        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1 = torch.cat(x_ba1)
        x_ab1 = torch.cat(x_ab1)
        x_bm = torch.cat(x_bm)
        x_am = torch.cat(x_am)

        self.train()
        return im_a, x_a_recon, x_ab1, x_am, im_b, x_b_recon, x_ba1, x_bm

    def dis_update(self, x_a, m_a, x_b, m_b, hyperparameters):
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        up_im_A = 1 - m_a  #F.interpolate(1-m_a, None,1, 'bilinear', align_corners=False)
        up_m_B = m_b  #F.interpolate(m_b, None, 1, 'bilinear', align_corners=False)
        up_m_A = m_a  #F.interpolate(m_a, None, 1, 'bilinear', align_corners=False)
        up_im_B = 1 - m_b  #.interpolate(1-m_b, None, 1, 'bilinear', align_corners=False)

        c_a, s_bA = self.gen_a.encode(x_a, up_im_A)
        c_b, s_fB = self.gen_b.encode(x_b, up_m_B)

        _, s_fA = self.gen_a.encode(x_a, up_m_A)
        _, s_bB = self.gen_b.encode(x_b, up_im_B)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_fA, m_b, s_bB)
        x_ab = self.gen_b.decode(c_a, s_fB, m_a, s_bA)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters[
            'gan_w'] * self.loss_dis_a + hyperparameters[
                'gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
Example #7
0
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler, self.lr_policy = get_scheduler(
            self.dis_opt, hyperparameters)
        self.gen_scheduler, self.lr_policy = get_scheduler(
            self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))
        self.metric = 0
        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, s_a_fake = self.gen_a.encode(x_a)
        c_b, s_b_fake = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab, x_ba

    # gen_update takes one image from each domain, a and b
    def gen_update(self, x_a, x_b, hyperparameters, mask):
        self.guided = 0
        print(type(mask))
        self.gen_opt.zero_grad()
        # randomly sample style codes for domains a and b
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        # run each image through its generator's encoder to obtain a content code and a style code
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain): the encoder/decoder pair should be able to reconstruct the input
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain): combining one domain's content with the other domain's style gives the translated result
        if self.guided == 0:
            x_ba = self.gen_a.decode(c_b, s_a)
            x_ab = self.gen_b.decode(c_a, s_b)
        elif self.guided == 1:
            x_ba = self.gen_a.decode(c_b, s_a_prime)
            x_ab = self.gen_b.decode(c_a, s_b_prime)
        # encode again: split the translated images back into content and style codes
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(
            c_a_recon,
            s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(
            c_b_recon,
            s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        print(x_a_recon.size(), x_b_recon.size())

        # mask loss
        mask = torch.cat([mask, mask, mask], 1)
        self.loss_attentive = self.recon_criterion(x_a[mask == 1],
                                                   x_a_recon[mask == 1])

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(
            self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(
            self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b + \
                              hyperparameters['att_w'] * self.loss_attentive
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def sample(self, x_a, x_b):
        self.eval()
        s_a1 = Variable(self.s_a)
        s_b1 = Variable(self.s_b)
        s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
            c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
            x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
            x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
            x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
            x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
            x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    # Train the discriminator: take one image from each domain and translate it into the other
    # domain, using its own content code but a randomly sampled style code.
    # The translated images are expected to fool the discriminator.
    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters[
            'gan_w'] * self.loss_dis_a + hyperparameters[
                'gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            if self.lr_policy == 'plateau':
                self.dis_scheduler.step(self.metric)
            else:
                self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            if self.lr_policy == 'plateau':
                self.gen_scheduler.step(self.metric)
            else:
                self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler, self.lr_policy = get_scheduler(
            self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler, self.lr_policy = get_scheduler(
            self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
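
Unlike the earlier examples, this variant expects get_scheduler() to return both the scheduler and the learning-rate policy name, because update_learning_rate() has to pass a metric when the policy is 'plateau'. A minimal sketch of such a helper (the keys 'lr_policy', 'step_size', 'gamma' and 'patience' are assumptions about the hyperparameter dictionary):

from torch.optim import lr_scheduler

def get_scheduler(optimizer, hyperparameters, iterations=-1):
    policy = hyperparameters.get('lr_policy', 'constant')
    if iterations != -1:
        # when resuming, record the initial lr so step-based schedulers accept last_epoch=iterations
        for group in optimizer.param_groups:
            group.setdefault('initial_lr', group['lr'])
    if policy == 'constant':
        scheduler = None
    elif policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer,
                                        step_size=hyperparameters['step_size'],
                                        gamma=hyperparameters['gamma'],
                                        last_epoch=iterations)
    elif policy == 'plateau':
        # ReduceLROnPlateau is driven by a metric rather than an iteration count,
        # which is why update_learning_rate() above calls step(self.metric) for it
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                   mode='min',
                                                   factor=hyperparameters['gamma'],
                                                   patience=hyperparameters.get('patience', 10))
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % policy)
    return scheduler, policy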
Example #8
0
class DSMAP_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(DSMAP_Trainer, self).__init__()
        # Initiate the networks
        mid_downsample = hyperparameters['gen'].get('mid_downsample', 1)
        self.content_enc = ContentEncoder_share(
            hyperparameters['gen']['n_downsample'],
            mid_downsample,
            hyperparameters['gen']['n_res'],
            hyperparameters['input_dim_a'],
            hyperparameters['gen']['dim'],
            'in',
            hyperparameters['gen']['activ'],
            pad_type=hyperparameters['gen']['pad_type'])

        self.style_dim = hyperparameters['gen']['style_dim']
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'], self.content_enc, 'a',
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'], self.content_enc, 'b',
            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'], self.content_enc.output_dim,
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'], self.content_enc.output_dim,
            hyperparameters['dis'])  # discriminator for domain b

    def build_optimizer(self, hyperparameters):
        # Setup the optimizers
        lr = hyperparameters['lr']
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            import torchvision.models as models
            self.vgg = models.vgg16(pretrained=True)
            # If you cannot download pretrained model automatically, you can download it from
            # https://download.pytorch.org/models/vgg16-397923af.pth and load it manually
            # state_dict = torch.load('vgg16-397923af.pth')
            # self.vgg.load_state_dict(state_dict)
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def __compute_kl(self, mu, logvar):
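        # KL divergence between N(mu, exp(logvar)) and the standard normal prior N(0, I)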
        encoding_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return encoding_loss

    def gen_update(self, x_a, x_b, hyperparameters, iterations):
        self.gen_opt.zero_grad()
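        # The cross-consistency pass and the latent-regression pass below share a single optimizer
        # step; the commented-out lines show the alternative of stepping after each backward.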
        self.gen_backward_cc(x_a, x_b, hyperparameters)
        #self.gen_opt.step()

        #self.gen_opt.zero_grad()
        self.gen_backward_latent(x_a, x_b, hyperparameters)
        self.gen_opt.step()

    def gen_backward_latent(self, x_a, x_b, hyperparameters):
        # random sample style vector and multimodal training
        s_a_random = Variable(
            torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b_random = Variable(
            torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # decode
        x_ba_random = self.gen_a.decode(self.c_b, self.da_b, s_a_random)
        x_ab_random = self.gen_b.decode(self.c_a, self.db_a, s_b_random)

        c_b_random_recon, _, _, s_a_random_recon, _, _ = self.gen_a.encode(
            x_ba_random)
        c_a_random_recon, _, _, s_b_random_recon, _, _ = self.gen_b.encode(
            x_ab_random)

        # style reconstruction loss
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_random,
                                                       s_a_random_recon)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_random,
                                                       s_b_random_recon)
        loss_gen_recon_c_a = self.recon_criterion(self.c_a, c_a_random_recon)
        loss_gen_recon_c_b = self.recon_criterion(self.c_b, c_b_random_recon)
        loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba_random, x_a)
        loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab_random, x_b)
        loss_gen_vgg_a = self.compute_vgg_loss(
            self.vgg, x_ba_random, x_a) if hyperparameters['vgg_w'] > 0 else 0
        loss_gen_vgg_b = self.compute_vgg_loss(
            self.vgg, x_ab_random, x_b) if hyperparameters['vgg_w'] > 0 else 0

        loss_gen_total = hyperparameters['gan_w'] * loss_gen_adv_a + \
                          hyperparameters['gan_w'] * loss_gen_adv_b + \
                          hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                          hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                          hyperparameters['recon_c_w'] * loss_gen_recon_c_a + \
                          hyperparameters['recon_c_w'] * loss_gen_recon_c_b + \
                          hyperparameters['vgg_w'] * loss_gen_vgg_a + \
                          hyperparameters['vgg_w'] * loss_gen_vgg_b

        # add the random-style terms on top of the cycle-consistency loss
        # computed in gen_backward_cc, then back-propagate the combined loss
        self.loss_gen_total += loss_gen_total
        self.loss_gen_total.backward()

        # accumulate the individual terms for logging
        self.loss_gen_adv_a += loss_gen_adv_a
        self.loss_gen_adv_b += loss_gen_adv_b
        self.loss_gen_recon_c_a += loss_gen_recon_c_a
        self.loss_gen_recon_c_b += loss_gen_recon_c_b
        self.loss_gen_vgg_a += loss_gen_vgg_a
        self.loss_gen_vgg_b += loss_gen_vgg_b

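    # Cycle-consistency branch: encode both domains, reconstruct within each
    # domain, translate across domains with the encoded styles, then re-encode
    # and decode once more for the cycle loss. All terms are accumulated into
    # self.loss_gen_total; backward() is called later in gen_backward_latent.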
    def gen_backward_cc(self, x_a, x_b, hyperparameters):
        pre_c_a, self.c_a, c_domain_a, self.db_a, self.s_a_prime, mu_a, logvar_a = self.gen_a.encode(
            x_a, training=True, flag=True)
        pre_c_b, self.c_b, c_domain_b, self.da_b, self.s_b_prime, mu_b, logvar_b = self.gen_b.encode(
            x_b, training=True, flag=True)

        self.da_a = self.gen_b.domain_mapping(self.c_a, pre_c_a)
        self.db_b = self.gen_a.domain_mapping(self.c_b, pre_c_b)

        # decode (within domain)
        x_a_recon = self.gen_a.decode(self.c_a, self.da_a, self.s_a_prime)
        x_b_recon = self.gen_b.decode(self.c_b, self.db_b, self.s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(self.c_b, self.da_b, self.s_a_prime)
        x_ab = self.gen_b.decode(self.c_a, self.db_a, self.s_b_prime)

        c_b_recon, _, self.db_b_recon, s_a_recon, _, _ = self.gen_a.encode(
            x_ba, training=True)
        c_a_recon, _, self.da_a_recon, s_b_recon, _, _ = self.gen_b.encode(
            x_ab, training=True)
        # decode again (cycle-consistency loss)
        x_aba = self.gen_a.decode(c_a_recon, self.da_a_recon, s_a_recon)
        x_bab = self.gen_b.decode(c_b_recon, self.db_b_recon, s_b_recon)

        # domain-specific content reconstruction loss
        self.loss_gen_recon_d_a = self.recon_criterion(
            c_domain_a, self.da_a) if hyperparameters['recon_d_w'] > 0 else 0
        self.loss_gen_recon_d_b = self.recon_criterion(
            c_domain_b, self.db_b) if hyperparameters['recon_d_w'] > 0 else 0

        # image reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        # domain-invariant content reconstruction loss
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, self.c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, self.c_b)
        # cyc loss
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a)
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b)
        # KL loss (if needed)
        self.loss_gen_recon_kl_a = self.__compute_kl(
            mu_a, logvar_a) if hyperparameters['recon_kl_w'] > 0 else 0
        self.loss_gen_recon_kl_b = self.__compute_kl(
            mu_b, logvar_b) if hyperparameters['recon_kl_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba, x_a)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab, x_b)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(
            self.vgg, x_ba, x_a) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(
            self.vgg, x_ab, x_b) if hyperparameters['vgg_w'] > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_d_w'] * self.loss_gen_recon_d_a + \
                              hyperparameters['recon_d_w'] * self.loss_gen_recon_d_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b

        # backward() is deferred: gen_backward_latent adds its terms to
        # self.loss_gen_total and back-propagates the combined objective there.
        # self.loss_gen_total.backward(retain_graph=True)

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg.features(img_vgg)
        target_fea = vgg.features(target_vgg)
        return contextual_loss(img_fea, target_fea)
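
    # `vgg_preprocess` and `contextual_loss` are helpers defined elsewhere in
    # the project. A minimal sketch of the preprocessing assumed here --
    # mapping [-1, 1] generator outputs to ImageNet-normalised inputs for
    # torchvision's VGG16 -- could look like this (the names and constants are
    # illustrative, not the project's actual helper):
    #
    #   def vgg_preprocess(batch):
    #       batch = (batch + 1) / 2  # [-1, 1] -> [0, 1]
    #       mean = batch.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    #       std = batch.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    #       return (batch - mean) / std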

    def dis_update(self, x_a, x_b, hyperparameters, iterations):
        self.dis_opt.zero_grad()
        s_a_random = Variable(
            torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b_random = Variable(
            torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())

        # encode
        pre_c_a, c_a, c_domain_a, db_a, s_a, _, _ = self.gen_a.encode(
            x_a, training=True, flag=True)
        pre_c_b, c_b, c_domain_b, da_b, s_b, _, _ = self.gen_b.encode(
            x_b, training=True, flag=True)
        da_a = self.gen_b.domain_mapping(c_a, pre_c_a)
        db_b = self.gen_a.domain_mapping(c_b, pre_c_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, da_b, s_a)
        x_ab = self.gen_b.decode(c_a, db_a, s_b)
        # decode (cross domain, randomly sampled style)
        x_ba_random = self.gen_a.decode(c_b, da_b, s_a_random)
        x_ab_random = self.gen_b.decode(c_a, db_a, s_b_random)

        c_b_recon, _, db_b_recon, s_a_recon, _, _ = self.gen_a.encode(x_ba)
        c_a_recon, _, da_a_recon, s_b_recon, _, _ = self.gen_b.encode(x_ab)
        _, _, db_b_random_recon, _, _, _ = self.gen_a.encode(x_ba_random)
        _, _, da_a_random_recon, _, _, _ = self.gen_b.encode(x_ab_random)

        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(
            x_ba.detach(), x_a) + self.dis_a.calc_dis_loss(
                x_ba_random.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(
            x_ab.detach(), x_b) + self.dis_b.calc_dis_loss(
                x_ab_random.detach(), x_b)

        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + \
                              hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

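    # Qualitative sampling: encode each image, reconstruct it within its own
    # domain, and translate it to the other domain using the style code encoded
    # from the paired image. Runs one image at a time in eval mode.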
    def sample(self, x_a, x_b):
        self.eval()
        x_a_recon, x_b_recon, x_ab, x_ba = [], [], [], []
        for i in range(x_a.size(0)):
            pre_c_a, c_a, _, db_a, s_a_fake, _, _ = self.gen_a.encode(
                x_a[i].unsqueeze(0), flag=True)
            pre_c_b, c_b, _, da_b, s_b_fake, _, _ = self.gen_b.encode(
                x_b[i].unsqueeze(0), flag=True)
            da_a = self.gen_b.domain_mapping(c_a, pre_c_a)
            db_b = self.gen_a.domain_mapping(c_b, pre_c_b)
            x_a_recon.append(self.gen_a.decode(c_a, da_a, s_a_fake))
            x_b_recon.append(self.gen_b.decode(c_b, db_b, s_b_fake))
            x_ba.append(self.gen_a.decode(c_b, da_b, s_a_fake))
            x_ab.append(self.gen_b.decode(c_a, db_a, s_b_fake))

        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ab, x_ba = torch.cat(x_ab), torch.cat(x_ba)
        self.train()

        return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

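    # Checkpoints follow the naming scheme used by save() below
    # ('gen_%08d.pt' / 'dis_%08d.pt'), so the iteration counter is recovered
    # from the last eight digits of the file name.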
    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
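
# A minimal training-loop sketch (illustrative only): `trainer` is assumed to
# be an instance of the trainer class above, `config` the hyperparameter dict
# it was built with, and `loader_a` / `loader_b` data loaders yielding image
# batches for the two domains; the snapshot path and interval are placeholders.
#
#   for it, (x_a, x_b) in enumerate(zip(loader_a, loader_b)):
#       x_a, x_b = x_a.cuda(), x_b.cuda()
#       trainer.dis_update(x_a, x_b, config, it)
#       trainer.gen_update(x_a, x_b, config, it)
#       trainer.update_learning_rate()
#       if (it + 1) % 10000 == 0:
#           trainer.save('outputs/checkpoints', it)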