Example #1
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initialize the networks
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b

        self.content_classifier = ContentClassifier(
            hyperparameters['gen']['dim'], hyperparameters)

        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']
        self.num_con_c = hyperparameters['dis']['num_con_c']

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        # dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())

        dis_named_params = list(self.dis_a.named_parameters()) + list(
            self.dis_b.named_parameters())
        # gen_named_params = list(self.gen_a.named_parameters()) + list(self.gen_b.named_parameters())

        # Route the discriminator parameters whose names contain "_Q" (the
        # InfoGAN Q heads) to the generator optimizer; all remaining
        # discriminator parameters stay with the discriminator optimizer.
        dis_params = list()
        for name, param in dis_named_params:
            if "_Q" in name:
                gen_params.append(param)
            else:
                dis_params.append(param)

        content_classifier_params = list(self.content_classifier.parameters())

        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.cla_opt = torch.optim.Adam(
            [p for p in content_classifier_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        self.cla_scheduler = get_scheduler(self.cla_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        self.content_classifier.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

        self.gan_type = hyperparameters['dis']['gan_type']
        self.criterionQ_con = NormalNLLLoss()

        self.criterion_content_classifier = nn.CrossEntropyLoss()

        self.batch_size_val = hyperparameters['batch_size_val']

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, s_a_fake = self.gen_a.encode(x_a)
        c_b, s_b_fake = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, sample_b, hyperparameters, sample_a_limited):
        x_b, label_b = sample_b
        x_a_limited, label_a_limited = sample_a_limited

        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        # x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        # x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)

        # GAN loss
        # self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        x_ba_dis_out = self.dis_a(x_ba)
        self.loss_gen_adv_a = self.compute_gen_adv_loss(x_ba_dis_out)

        # self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        x_ab_dis_out = self.dis_b(x_ab)
        self.loss_gen_adv_b = self.compute_gen_adv_loss(x_ab_dis_out)

        # InfoGAN continuous-code (mutual information) loss
        self.info_cont_loss_a = self.compute_info_cont_loss(s_a, x_ba_dis_out)
        self.info_cont_loss_b = self.compute_info_cont_loss(s_b, x_ab_dis_out)

        # domain-b predictions use the full labeled batch
        label_predict_c_b = self.content_classifier(c_b)
        label_predict_c_b_recon = self.content_classifier(c_b_recon)

        # compute the domain-a classifier losses on the limited labeled samples only
        c_a_limited, _ = self.gen_a.encode(x_a_limited)
        label_predict_c_a_limited = self.content_classifier(c_a_limited)

        x_ab_limited = self.gen_b.decode(c_a_limited, s_b)
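        # (note: s_b above was sampled with x_b's batch size; reusing it for
        # the limited domain-a batch assumes both batches are the same size)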
        c_a_recon_limited, _ = self.gen_b.encode(x_ab_limited)
        label_predict_c_a_recon_limited = self.content_classifier(
            c_a_recon_limited)

        loss_content_classifier_c_a = self.compute_content_classifier_loss(
            label_predict_c_a_limited, label_a_limited)
        loss_content_classifier_c_a_recon = self.compute_content_classifier_loss(
            label_predict_c_a_recon_limited, label_a_limited)

        loss_content_classifier_b = self.compute_content_classifier_loss(
            label_predict_c_b, label_b)
        loss_content_classifier_c_b_recon = self.compute_content_classifier_loss(
            label_predict_c_b_recon, label_b)

        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              self.info_cont_loss_a + \
                              self.info_cont_loss_b +\
                              loss_content_classifier_c_a + \
                              loss_content_classifier_c_a_recon + \
                              loss_content_classifier_b + \
                              loss_content_classifier_c_b_recon

        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_info_cont_loss(self, style_code, outs_fake):
        # InfoGAN-style mutual-information term: the last num_con_c entries of
        # the style code act as continuous latent codes, and each scale of the
        # discriminator predicts a Gaussian (mu, var) over them via its Q head.
        loss = 0
        num_cont_code = self.num_con_c
        for out_fake in outs_fake:
            q_mu = out_fake['mu']
            q_var = out_fake['var']
            info_noise = style_code[:, -num_cont_code:].view(
                -1, num_cont_code).squeeze()
            loss += self.criterionQ_con(info_noise, q_mu, q_var) * 0.1
        return loss

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def sample(self, x_a, x_b):
        self.eval()
        s_a1 = Variable(self.s_a)
        s_b1 = Variable(self.s_b)
        s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
            c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
            x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
            x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
            x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
            x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
            x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def cla_update(self, sample_a, sample_b):
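        # Update only the content classifier: the generators provide content
        # codes (and their cross-domain reconstructions), but only cla_opt steps.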
        x_a, label_a = sample_a
        x_b, label_b = sample_b
        self.cla_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        # x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        # x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)

        label_predict_c_a = self.content_classifier(c_a)
        label_predict_c_a_recon = self.content_classifier(c_a_recon)
        label_predict_c_b = self.content_classifier(c_b)
        label_predict_c_b_recon = self.content_classifier(c_b_recon)

        self.loss_content_classifier_c_a = self.compute_content_classifier_loss(
            label_predict_c_a, label_a)
        self.loss_content_classifier_c_a_recon = self.compute_content_classifier_loss(
            label_predict_c_a_recon, label_a)
        # self.loss_content_classifier_c_a_and_c_a_recon = self.compute_content_classifier_two_predictions_loss(label_predict_c_a_recon, label_predict_c_a)

        self.loss_content_classifier_b = self.compute_content_classifier_loss(
            label_predict_c_b, label_b)
        self.loss_content_classifier_c_b_recon = self.compute_content_classifier_loss(
            label_predict_c_b_recon, label_b)
        # self.loss_content_classifier_c_b_and_c_b_recon = self.compute_content_classifier_two_predictions_loss(label_predict_c_b_recon, label_predict_c_b)


        self.loss_cla_total = self.loss_content_classifier_c_a + self.loss_content_classifier_c_a_recon + \
                              self.loss_content_classifier_b + self.loss_content_classifier_c_b_recon
        self.loss_cla_total.backward()
        self.cla_opt.step()

    def cla_inference(self, test_loader_a, test_loader_b):
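        # Evaluate content-classifier accuracy over the test loaders; this
        # pass only records running accuracies and takes no optimizer step.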
        accu_content_classifier_c_a = []
        accu_content_classifier_c_a_recon = []
        accu_content_classifier_c_b = []
        accu_content_classifier_c_b_recon = []
        for it_inf, (samples_a_test, samples_b_test) in enumerate(
                zip(test_loader_a, test_loader_b)):
            x_a, label_a = samples_a_test[0].cuda().detach(
            ), samples_a_test[1].cuda().detach()
            x_b, label_b = samples_b_test[0].cuda().detach(
            ), samples_b_test[1].cuda().detach()

            s_a = Variable(
                torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
            s_b = Variable(
                torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())

            # encode
            c_a, s_a_prime = self.gen_a.encode(x_a)
            c_b, s_b_prime = self.gen_b.encode(x_b)
            # decode (within domain)
            # x_a_recon = self.gen_a.decode(c_a, s_a_prime)
            # x_b_recon = self.gen_b.decode(c_b, s_b_prime)
            # decode (cross domain)
            x_ba = self.gen_a.decode(c_b, s_a)
            x_ab = self.gen_b.decode(c_a, s_b)
            # encode again
            c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
            c_a_recon, s_b_recon = self.gen_b.encode(x_ab)

            label_predict_c_a = self.content_classifier(c_a)
            label_predict_c_a_recon = self.content_classifier(c_a_recon)
            label_predict_c_b = self.content_classifier(c_b)
            label_predict_c_b_recon = self.content_classifier(c_b_recon)


            accu_content_classifier_c_a.append(
                self.compute_content_classifier_accuracy(
                    label_predict_c_a, label_a))
            accu_content_classifier_c_a_recon.append(
                self.compute_content_classifier_accuracy(
                    label_predict_c_a_recon, label_a))
            accu_content_classifier_c_b.append(
                self.compute_content_classifier_accuracy(
                    label_predict_c_b, label_b))
            accu_content_classifier_c_b_recon.append(
                self.compute_content_classifier_accuracy(
                    label_predict_c_b_recon, label_b))

        self.accu_content_classifier_c_a = self.mean_list(
            accu_content_classifier_c_a)
        self.accu_content_classifier_c_a_recon = self.mean_list(
            accu_content_classifier_c_a_recon)
        self.accu_content_classifier_c_b = self.mean_list(
            accu_content_classifier_c_b)
        self.accu_content_classifier_c_b_recon = self.mean_list(
            accu_content_classifier_c_b_recon)

        self.accu_CC_all = self.mean_list([
            self.accu_content_classifier_c_a,
            self.accu_content_classifier_c_a_recon,
            self.accu_content_classifier_c_b,
            self.accu_content_classifier_c_b_recon
        ])

    @staticmethod
    def mean_list(lst):
        return sum(lst) / len(lst)

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)

        # D loss
        # self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        x_ba_dis_out = self.dis_a(x_ba.detach())
        x_a_dis_out = self.dis_a(x_a)
        self.loss_dis_a = self.compute_dis_loss(x_ba_dis_out, x_a_dis_out)
        # self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        x_ab_dis_out = self.dis_b(x_ab.detach())
        x_b_dis_out = self.dis_b(x_b)
        self.loss_dis_b = self.compute_dis_loss(x_ab_dis_out, x_b_dis_out)

        self.loss_dis_total = hyperparameters[
            'gan_w'] * self.loss_dis_a + hyperparameters[
                'gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def compute_content_classifier_loss(self, label_predict, label_true):
        loss = self.criterion_content_classifier(label_predict, label_true)
        return loss

    def compute_content_classifier_two_predictions_loss(
            self, label_predict_1, label_predict_2):
        # L1 distance between two prediction vectors (consistency term)
        loss = torch.mean(torch.abs(label_predict_1 - label_predict_2))
        return loss

    def compute_content_classifier_accuracy(self, label_predict, label_true):
        # top-1 accuracy of the content classifier
        _, indices = label_predict.max(1)
        results = (label_true == indices)
        total_correct = results.sum().cpu().numpy()
        # use the actual batch size rather than self.batch_size_val so that a
        # smaller final batch is scored correctly
        accuracy = float(total_correct) / float(label_true.size(0))
        return accuracy

    def compute_dis_loss(self, outs_fake, outs_real):
        # discriminator loss, accumulated over the multi-scale outputs
        loss = 0
        for out_fake, out_real in zip(outs_fake, outs_real):
            out_fake = out_fake['output_d']
            out_real = out_real['output_d']
            if self.gan_type == 'lsgan':
                loss += torch.mean((out_fake - 0)**2) + torch.mean(
                    (out_real - 1)**2)
            elif self.gan_type == 'nsgan':
                all0 = Variable(torch.zeros_like(out_fake.data).cuda(),
                                requires_grad=False)
                all1 = Variable(torch.ones_like(out_real.data).cuda(),
                                requires_grad=False)
                loss += torch.mean(
                    F.binary_cross_entropy(torch.sigmoid(out_fake), all0) +
                    F.binary_cross_entropy(torch.sigmoid(out_real), all1))
            else:
                assert 0, "Unsupported GAN type: {}".format(self.gan_type)
        return loss

    def compute_gen_adv_loss(self, outs_fake):
        # generator adversarial loss, accumulated over the multi-scale outputs
        loss = 0
        for out_fake in outs_fake:
            out_fake = out_fake['output_d']
            if self.gan_type == 'lsgan':
                loss += torch.mean((out_fake - 1)**2)  # LSGAN
            elif self.gan_type == 'nsgan':
                all1 = Variable(torch.ones_like(out_fake.data).cuda(),
                                requires_grad=False)
                loss += torch.mean(
                    F.binary_cross_entropy(torch.sigmoid(out_fake), all1))
            else:
                assert 0, "Unsupported GAN type: {}".format(self.gan_type)
        return loss

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()
        if self.cla_scheduler is not None:
            self.cla_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load content classifier
        last_model_name = get_model_list(checkpoint_dir, "con_cla")
        state_dict = torch.load(last_model_name)
        self.content_classifier.load_state_dict(state_dict['con_cla'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        self.cla_opt.load_state_dict(state_dict['con_cla'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        self.cla_scheduler = get_scheduler(self.cla_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        con_cla_name = os.path.join(snapshot_dir,
                                    'con_cla_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')

        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save({'con_cla': self.content_classifier.state_dict()},
                   con_cla_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict(),
                'con_cla': self.cla_opt.state_dict()
            }, opt_name)
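
A note on NormalNLLLoss: it is assigned to self.criterionQ_con above but not defined in this listing. The following is a minimal sketch, assuming the standard InfoGAN continuous-code objective (the Gaussian negative log-likelihood of the codes under the (mu, var) predicted by the discriminator's Q head); the class name matches the usage above, but the body is an assumption rather than the repo's actual code.

import math
import torch

class NormalNLLLoss:
    # Hypothetical stand-in: Gaussian NLL for InfoGAN continuous codes.
    def __call__(self, x, mu, var):
        eps = 1e-6  # numerical floor for the predicted variance
        logli = -0.5 * (torch.log(var + eps) + math.log(2 * math.pi)) \
                - (x - mu) ** 2 / (2 * var + eps)
        # sum over code dimensions, average over the batch
        return -logli.view(logli.size(0), -1).sum(1).mean()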
Example #2
class UNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(UNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initialize the networks
        self.gen_a = VAEGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = VAEGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        h_a, _ = self.gen_a.encode(x_a)
        h_b, _ = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(h_b)
        x_ab = self.gen_b.decode(h_a)
        self.train()
        return x_ab, x_ba

    def __compute_kl(self, mu):
        # Simplified KL term: with the encoder noise fixed at unit variance,
        # the KL divergence to the standard normal prior reduces (up to a
        # constant) to the mean of mu^2.
        mu_2 = torch.pow(mu, 2)
        encoding_loss = torch.mean(mu_2)
        return encoding_loss

    def gen_update(self, x_a, x_b, hyperparameters):
        self.gen_opt.zero_grad()
        # encode
        h_a, n_a = self.gen_a.encode(x_a)
        h_b, n_b = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(h_a + n_a)
        x_b_recon = self.gen_b.decode(h_b + n_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(h_b + n_b)
        x_ab = self.gen_b.decode(h_a + n_a)
        # encode again
        h_b_recon, n_b_recon = self.gen_a.encode(x_ba)
        h_a_recon, n_a_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(h_a_recon + n_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(h_b_recon + n_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_kl_a = self.__compute_kl(h_a)
        self.loss_gen_recon_kl_b = self.__compute_kl(h_b)
        self.loss_gen_cyc_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cyc_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_recon_kl_cyc_aba = self.__compute_kl(h_a_recon)
        self.loss_gen_recon_kl_cyc_bab = self.__compute_kl(h_b_recon)
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_a + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_b + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def sample(self, x_a, x_b):
        self.eval()
        x_a_recon, x_b_recon, x_ba, x_ab = [], [], [], []
        for i in range(x_a.size(0)):
            h_a, _ = self.gen_a.encode(x_a[i].unsqueeze(0))
            h_b, _ = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(h_a))
            x_b_recon.append(self.gen_b.decode(h_b))
            x_ba.append(self.gen_a.decode(h_b))
            x_ab.append(self.gen_b.decode(h_a))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba = torch.cat(x_ba)
        x_ab = torch.cat(x_ab)
        self.train()
        return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        # encode
        h_a, n_a = self.gen_a.encode(x_a)
        h_b, n_b = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(h_b + n_b)
        x_ab = self.gen_b.decode(h_a + n_a)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
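
A note on get_scheduler: it is imported by these trainers but not shown here. The following is a minimal sketch, assuming a MUNIT-style configuration in which a 'step' learning-rate policy is driven by step_size and gamma entries in the hyperparameters (the iterations argument lets resume() fast-forward the schedule); the hyperparameter keys are assumptions.

from torch.optim import lr_scheduler

def get_scheduler(optimizer, hyperparameters, iterations=-1):
    # Hypothetical sketch: constant LR unless a 'step' policy is configured.
    if 'lr_policy' not in hyperparameters or hyperparameters['lr_policy'] == 'constant':
        return None  # update_learning_rate() checks for None
    if hyperparameters['lr_policy'] == 'step':
        return lr_scheduler.StepLR(optimizer,
                                   step_size=hyperparameters['step_size'],
                                   gamma=hyperparameters['gamma'],
                                   last_epoch=iterations)
    raise NotImplementedError(
        'learning rate policy [%s] is not implemented' % hyperparameters['lr_policy'])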
Example #3
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initialize the networks
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']
        '''
            input_dim_a and input_dim_b are the channel counts of the input
            images (3 for RGB); gen and dis are the architecture-related
            settings defined in the YAML config.
        '''

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        '''
            Assign a random style code (dimension 8) to each of the displayed
            images (16 in total).
        '''

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        '''
            A concise idiom worth learning: first concatenate the parameters()
            lists, then filter with [p for p in params if p.requires_grad].
            Separate optimizers are created for the discriminator and the
            generator parameters. Adam is used with betas 0.5 and 0.999, and
            weight decay (1e-4 here) is configured on the optimizer itself.
            By default the scheduler halves the learning rate every 100000 steps.
        '''

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))
        '''
            Note: apply() recursively applies the given function to every submodule.
        '''

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False
        '''The VGG network is not used in the default configuration.'''

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    #   Note: evaluation mode is enabled only inside forward; I'm not entirely sure where this method is actually used.
    def forward(self, x_a, x_b):
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, s_a_fake = self.gen_a.encode(x_a)
        c_b, s_b_fake = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1,
                                   1).cuda())  #   two random style codes
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)  #   "prime" marks a style code encoded from a real image
        c_b, s_b_prime = self.gen_b.encode(
            x_b)  #   content code: (1, 256, 64, 64); style code: (1, 8, 1, 1)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)  #   (a) rebuild the original image from its content and style codes
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)  #   (b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(
            c_a_recon,
            s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(
            c_b_recon,
            s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)  #   (a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)  #   (b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)  #   (c)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)  #   (d)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)  #   (e)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)  #   (f)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0  #   (g)
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0  #   (h)
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)  #   (i)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)  #   (j)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(
            self.vgg, x_ba,
            x_b) if hyperparameters['vgg_w'] > 0 else 0  #   (k)
        self.loss_gen_vgg_b = self.compute_vgg_loss(
            self.vgg, x_ab,
            x_a) if hyperparameters['vgg_w'] > 0 else 0  #   (l)
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    #   Feed in two image batches and swap their styles.
    def sample(self, x_a, x_b):
        self.eval()
        s_a1 = Variable(self.s_a)  #   a fixed style code
        s_b1 = Variable(self.s_b)
        s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1,
                                    1).cuda())  #   a freshly sampled random style
        s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            c_a, s_a_fake = self.gen_a.encode(
                x_a[i].unsqueeze(0))  #   encode the input image
            c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(c_a,
                                               s_a_fake))  #   reconstruct the input image
            x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
            x_ba1.append(self.gen_a.decode(
                c_b, s_a1[i].unsqueeze(0)))  #   apply the fixed style to b's content, producing an a-style image
            x_ba2.append(self.gen_a.decode(
                c_b, s_a2[i].unsqueeze(0)))  #   apply the random style to b's content, producing an a-style image
            x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
            x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1,
                                   1).cuda())  #   random style codes for the generated images, shape (1, 8, 1, 1)
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, _ = self.gen_a.encode(x_a)  #   encode the images with the encoders
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters[
            'gan_w'] * self.loss_dis_a + hyperparameters[
                'gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
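
A note on weights_init: it is applied throughout these trainers via module.apply(...) but not defined here. A minimal sketch follows, assuming the string argument selects the initialization scheme for Conv/Linear weights ('gaussian' appears in the trainers above; 'kaiming' is a plausible additional scheme, and the exact behavior is an assumption):

from torch.nn import init

def weights_init(init_type='gaussian'):
    # Hypothetical sketch of the initializer factory passed to apply().
    def init_fun(m):
        classname = m.__class__.__name__
        if (classname.startswith('Conv') or classname.startswith('Linear')) \
                and hasattr(m, 'weight'):
            if init_type == 'gaussian':
                init.normal_(m.weight.data, 0.0, 0.02)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            else:
                raise NotImplementedError(
                    'Unsupported initialization: {}'.format(init_type))
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
    return init_fun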
Example #4
class IPMNet_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(IPMNet_Trainer, self).__init__()
        lr = hyperparameters['lr']
        vgg_weight_file = hyperparameters['vgg_weight_file']
        # Initialize the networks
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = self.gen_a  # AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init(hyperparameters['init']))
        self.dis_b.apply(weights_init(hyperparameters['init']))

        # Load VGGFace model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_resnet50(vgg_weight_file)
            self.vgg.eval()
            self.vgg.fc.reset_parameters()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def gen_update(self, x_a, x_b, mask_a, mask_b, texture_a, texture_b,
                   hyperparameters):
        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, s_a_prime, x_a_gray_facial = self.gen_a.encode(
            x_a, mask_a, texture_a)
        c_b, s_b_prime, x_b_gray_facial = self.gen_b.encode(
            x_b, mask_b, texture_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime, x_a_gray_facial)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime, x_b_gray_facial)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a, x_b_gray_facial)
        x_ab = self.gen_b.decode(c_a, s_b, x_a_gray_facial)
        # encode again
        c_a_recon, s_b_recon, x_a_recon_gray_facial = self.gen_b.encode(
            x_ab, mask_a, texture_a)
        c_b_recon, s_a_recon, x_b_recon_gray_facial = self.gen_a.encode(
            x_ba, mask_b, texture_b)
        # decode again (if needed)
        x_aba = self.gen_a.decode(
            c_a_recon, s_a_prime, x_a_recon_gray_facial
        ) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(
            c_b_recon, s_b_prime, x_b_recon_gray_facial
        ) if hyperparameters['recon_x_cyc_w'] > 0 else None
        # background
        x_a_back = x_a * mask_a.repeat(1, 3, 1, 1)
        x_b_back = x_b * mask_b.repeat(1, 3, 1, 1)
        x_ab_back = x_ab * mask_a.repeat(1, 3, 1, 1)
        x_ba_back = x_ba * mask_b.repeat(1, 3, 1, 1)
        # foreground
        x_a_fore = x_a * (1 - mask_a).repeat(1, 3, 1, 1)
        x_b_fore = x_b * (1 - mask_b).repeat(1, 3, 1, 1)
        x_a_recon_fore = x_a_recon * (1 - mask_a).repeat(1, 3, 1, 1)
        x_b_recon_fore = x_b_recon * (1 - mask_b).repeat(1, 3, 1, 1)

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(
            self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(
            self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # background loss
        self.loss_back_x_a = self.recon_criterion(
            x_ab_back, x_a_back) if hyperparameters['back_w'] > 0 else 0
        self.loss_back_x_b = self.recon_criterion(
            x_ba_back, x_b_back) if hyperparameters['back_w'] > 0 else 0
        # foreground loss
        self.loss_fore_x_a = self.recon_criterion(
            x_a_recon_fore, x_a_fore) if hyperparameters['fore_w'] > 0 else 0
        self.loss_fore_x_b = self.recon_criterion(
            x_b_recon_fore, x_b_fore) if hyperparameters['fore_w'] > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b +\
                              hyperparameters['back_w'] * self.loss_back_x_a +\
                              hyperparameters['back_w'] * self.loss_back_x_b +\
                              hyperparameters['fore_w'] * self.loss_fore_x_a +\
                              hyperparameters['fore_w'] * self.loss_fore_x_b

        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(torch.abs(img_fea - target_fea))

    def sample(self,
               x_a,
               x_b,
               mask_a,
               mask_b,
               texture_a,
               texture_b,
               hyperparameters,
               train=True):
        self.eval()
        s_a1 = Variable(self.s_a)
        s_b1 = Variable(self.s_b)
        s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        x_a_recon, x_b_recon, x_a_facial_mask, x_b_facial_mask, x_ba, x_ab, x_aba, x_bab = [], [], [], [], [], [], [], []
        x_ab1, x_ab2, x_ba1, x_ba2 = [], [], [], []
        for i in range(x_a.size(0)):
            c_a, s_a, x_a_gray_facial = self.gen_a.encode(
                x_a[i].unsqueeze(0), mask_a[i].unsqueeze(0),
                texture_a[i].unsqueeze(0))
            c_b, s_b, x_b_gray_facial = self.gen_b.encode(
                x_b[i].unsqueeze(0), mask_b[i].unsqueeze(0),
                texture_b[i].unsqueeze(0))
            if train and i == 0:
                print(s_a.squeeze())
                print(s_b.squeeze())
            x_a_recon.append(self.gen_a.decode(c_a, s_a, x_a_gray_facial))
            x_b_recon.append(self.gen_b.decode(c_b, s_b, x_b_gray_facial))
            x_a_facial_mask.append(x_a_gray_facial)
            x_b_facial_mask.append(x_b_gray_facial)
            x_ba.append(self.gen_a.decode(c_b, s_a, x_b_gray_facial))
            x_ab.append(self.gen_b.decode(c_a, s_b, x_a_gray_facial))
            # randn style
            x_ba1.append(
                self.gen_a.decode(c_b, s_a1[i].unsqueeze(0), x_b_gray_facial))
            x_ab1.append(
                self.gen_b.decode(c_a, s_b1[i].unsqueeze(0), x_a_gray_facial))
            x_ba2.append(
                self.gen_a.decode(c_b, s_a2[i].unsqueeze(0), x_b_gray_facial))
            x_ab2.append(
                self.gen_b.decode(c_a, s_b2[i].unsqueeze(0), x_a_gray_facial))
            # encode again
            c_a_recon, _, x_a_recon_gray_facial = self.gen_a.encode(
                x_ab[i], mask_a[i].unsqueeze(0), texture_a[i].unsqueeze(0))
            c_b_recon, _, x_b_recon_gray_facial = self.gen_b.encode(
                x_ba[i], mask_b[i].unsqueeze(0), texture_b[i].unsqueeze(0))
            # decode again (if needed)
            x_aba_recon = self.gen_a.decode(
                c_a_recon, s_a, x_a_recon_gray_facial
            ) if hyperparameters['recon_x_cyc_w'] > 0 else None
            x_bab_recon = self.gen_b.decode(
                c_b_recon, s_b, x_b_recon_gray_facial
            ) if hyperparameters['recon_x_cyc_w'] > 0 else None
            x_aba.append(x_aba_recon)
            x_bab.append(x_bab_recon)

        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_a_facial_mask, x_b_facial_mask = torch.cat(
            x_a_facial_mask), torch.cat(x_b_facial_mask)
        x_ab, x_ba = torch.cat(x_ab), torch.cat(x_ba)
        x_aba, x_bab = torch.cat(x_aba), torch.cat(x_bab)
        x_ab1, x_ab2, x_ba1, x_ba2 = torch.cat(x_ab1), torch.cat(
            x_ab2), torch.cat(x_ba1), torch.cat(x_ba2)
        self.train()
        return x_a, x_b, x_a_recon, x_a_facial_mask, x_ab, x_aba, \
               x_b, x_a, x_b_recon, x_b_facial_mask, x_ba, x_bab

    def dis_update(self, x_a, x_b, mask_a, mask_b, texture_a, texture_b,
                   hyperparameters):
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, _, x_a_gray_facial = self.gen_a.encode(x_a, mask_a, texture_a)
        c_b, _, x_b_gray_facial = self.gen_b.encode(x_b, mask_b, texture_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a, x_b_gray_facial)
        x_ab = self.gen_b.decode(c_a, s_b, x_a_gray_facial)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters[
            'gan_w'] * self.loss_dis_a + hyperparameters[
                'gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
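        # e.g. 'gen_00010000.pt' -> 10000: the slice grabs the 8-digit counter before '.pt'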
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
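
A minimal training-step sketch for the trainer above, assuming loaders that yield (image, mask, texture) triples, a config dict holding the hyperparameter keys used throughout this class, and that gen_update takes the same tensor arguments as dis_update (its signature is not shown here); all names below are illustrative:

for it, (batch_a, batch_b) in enumerate(zip(loader_a, loader_b)):
    x_a, mask_a, tex_a = [t.cuda() for t in batch_a]
    x_b, mask_b, tex_b = [t.cuda() for t in batch_b]
    trainer.dis_update(x_a, x_b, mask_a, mask_b, tex_a, tex_b, config)
    trainer.gen_update(x_a, x_b, mask_a, mask_b, tex_a, tex_b, config)  # assumed signature
    trainer.update_learning_rate()
    if (it + 1) % config['snapshot_save_iter'] == 0:
        trainer.save(checkpoint_dir, it)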
Example #5
class UnsupIntrinsicTrainer(nn.Module):
    def __init__(self, param):
        super(UnsupIntrinsicTrainer, self).__init__()
        lr = param['lr']
        # Initiate the networks
        self.gen_i = AdaINGen(param['input_dim_a'], param['input_dim_a'],
                              param['gen'])  # auto-encoder for domain I
        self.gen_r = AdaINGen(param['input_dim_b'], param['input_dim_b'],
                              param['gen'])  # auto-encoder for domain R
        self.gen_s = AdaINGen(param['input_dim_c'], param['input_dim_c'],
                              param['gen'])  # auto-encoder for domain S
        self.dis_r = MsImageDis(param['input_dim_b'],
                                param['dis'])  # discriminator for domain R
        self.dis_s = MsImageDis(param['input_dim_c'],
                                param['dis'])  # discriminator for domain S
        gp = param['gen']
        self.with_mapping = True
        self.use_phy_loss = True
        self.use_content_loss = True
        if 'ablation_study' in param:
            ablation = param['ablation_study']
            if 'with_mapping' in ablation:
                self.with_mapping = ablation['with_mapping'] != 0
            if 'wo_phy_loss' in ablation:
                self.use_phy_loss = ablation['wo_phy_loss'] == 0
            if 'wo_content_loss' in ablation:
                self.use_content_loss = ablation['wo_content_loss'] == 0

        if self.with_mapping:
            self.fea_s = IntrinsicSplitor(gp['style_dim'], gp['mlp_dim'],
                                          gp['n_layer'],
                                          gp['activ'])  # split style for I
            self.fea_m = IntrinsicMerger(gp['style_dim'], gp['mlp_dim'],
                                         gp['n_layer'],
                                         gp['activ'])  # merge style for R, S
        self.bias_shift = param['bias_shift']
        self.instance_norm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = param['gen']['style_dim']

        # fix the noise used in sampling
        display_size = int(param['display_size'])
        self.s_r = torch.randn(display_size, self.style_dim, 1, 1).cuda() + 1.
        self.s_s = torch.randn(display_size, self.style_dim, 1, 1).cuda() - 1.

        # Setup the optimizers
        beta1 = param['beta1']
        beta2 = param['beta2']
        dis_params = list(self.dis_r.parameters()) + list(
            self.dis_s.parameters())
        if self.with_mapping:
            gen_params = list(self.gen_i.parameters()) + list(self.gen_r.parameters()) + \
                         list(self.gen_s.parameters()) + \
                         list(self.fea_s.parameters()) + list(self.fea_m.parameters())
        else:
            gen_params = list(self.gen_i.parameters()) + list(
                self.gen_r.parameters()) + list(self.gen_s.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=param['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=param['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, param)
        self.gen_scheduler = get_scheduler(self.gen_opt, param)

        # Network weight initialization
        self.apply(weights_init(param['init']))
        self.dis_r.apply(weights_init('gaussian'))
        self.dis_s.apply(weights_init('gaussian'))
        self.best_result = float('inf')
        self.reflectance_loss = LocalAlbedoSmoothnessLoss(param)

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def physical_criterion(self, x_i, x_r, x_s):
        return torch.mean(torch.abs(x_i - x_r * x_s))
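    # physical_criterion enforces the intrinsic-image model I = R * S pixel-wise:
    # the loss is zero exactly when x_i == x_r * x_s, and it decreases as the
    # predicted reflectance/shading pair better explains the input image.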

    def forward(self, x_i):
        c_i, s_i_fake = self.gen_i.encode(x_i)
        if self.with_mapping:
            s_r, s_s = self.fea_s(s_i_fake)
        else:
            s_r = Variable(
                torch.randn(x_i.size(0), self.style_dim, 1,
                            1).cuda()) + self.bias_shift
            s_s = Variable(
                torch.randn(x_i.size(0), self.style_dim, 1,
                            1).cuda()) - self.bias_shift
        x_ri = self.gen_r.decode(c_i, s_r)
        x_si = self.gen_s.decode(c_i, s_s)
        return x_ri, x_si

    def inference(self, x_i, use_rand_fea=False):
        with torch.no_grad():
            c_i, s_i_fake = self.gen_i.encode(x_i)
            if self.with_mapping:
                s_r, s_s = self.fea_s(s_i_fake)
            else:
                s_r = Variable(
                    torch.randn(x_i.size(0), self.style_dim, 1,
                                1).cuda()) + self.bias_shift
                s_s = Variable(
                    torch.randn(x_i.size(0), self.style_dim, 1,
                                1).cuda()) - self.bias_shift
            if use_rand_fea:
                s_r = Variable(
                    torch.randn(x_i.size(0), self.style_dim, 1,
                                1).cuda()) + self.bias_shift
                s_s = Variable(
                    torch.randn(x_i.size(0), self.style_dim, 1,
                                1).cuda()) - self.bias_shift
            x_ri = self.gen_r.decode(c_i, s_r)
            x_si = self.gen_s.decode(c_i, s_s)
        return x_ri, x_si

    # noinspection PyAttributeOutsideInit
    def gen_update(self, x_i, x_r, x_s, targets=None, param=None):
        self.gen_opt.zero_grad()
        # ============= Domain Translations =============
        # encode
        c_i, s_i_prime = self.gen_i.encode(x_i)
        c_r, s_r_prime = self.gen_r.encode(x_r)
        c_s, s_s_prime = self.gen_s.encode(x_s)

        if self.with_mapping:
            s_ri, s_si = self.fea_s(s_i_prime)
        else:
            s_ri = Variable(
                torch.randn(x_i.size(0), self.style_dim, 1,
                            1).cuda()) + self.bias_shift
            s_si = Variable(
                torch.randn(x_i.size(0), self.style_dim, 1,
                            1).cuda()) - self.bias_shift
        s_r_rand = Variable(
            torch.randn(x_r.size(0), self.style_dim, 1,
                        1).cuda()) + self.bias_shift
        s_s_rand = Variable(
            torch.randn(x_s.size(0), self.style_dim, 1,
                        1).cuda()) - self.bias_shift
        if self.with_mapping:
            s_i_recon = self.fea_m(s_ri, s_si)
        else:
            s_i_recon = Variable(
                torch.randn(x_i.size(0), self.style_dim, 1, 1).cuda())
        # decode (within domain)
        x_i_recon = self.gen_i.decode(c_i, s_i_prime)
        x_r_recon = self.gen_r.decode(c_r, s_r_prime)
        x_s_recon = self.gen_s.decode(c_s, s_s_prime)
        # decode (cross domain)
        x_rs = self.gen_r.decode(c_s, s_r_rand)
        x_ri = self.gen_r.decode(c_i, s_ri)
        x_ri_rand = self.gen_r.decode(c_i, s_r_rand)
        x_sr = self.gen_s.decode(c_r, s_s_rand)
        x_si = self.gen_s.decode(c_i, s_si)
        x_si_rand = self.gen_s.decode(c_i, s_s_rand)
        # encode again, for feature domain consistency constraints
        c_rs_recon, s_rs_recon = self.gen_r.encode(x_rs)
        c_ri_recon, s_ri_recon = self.gen_r.encode(x_ri)
        c_ri_rand_recon, s_ri_rand_recon = self.gen_r.encode(x_ri_rand)
        c_sr_recon, s_sr_recon = self.gen_s.encode(x_sr)
        c_si_recon, s_si_recon = self.gen_s.encode(x_si)
        c_si_rand_recon, s_si_rand_recon = self.gen_s.encode(x_si_rand)
        # decode again, for image domain cycle consistency
        x_rsr = self.gen_r.decode(c_sr_recon, s_r_prime)
        x_iri = self.gen_i.decode(c_ri_recon, s_i_prime)
        x_iri_rand = self.gen_i.decode(c_ri_rand_recon, s_i_prime)
        x_srs = self.gen_s.decode(c_rs_recon, s_s_prime)
        x_isi = self.gen_i.decode(c_si_recon, s_i_prime)
        x_isi_rand = self.gen_i.decode(c_si_rand_recon, s_i_prime)

        # ============= Loss Functions =============
        # Encoder-decoder reconstruction loss for the three domains
        self.loss_gen_recon_x_i = self.recon_criterion(x_i_recon, x_i)
        self.loss_gen_recon_x_r = self.recon_criterion(x_r_recon, x_r)
        self.loss_gen_recon_x_s = self.recon_criterion(x_s_recon, x_s)
        # Style-level reconstruction loss for cross-domain translations
        if self.with_mapping:
            self.loss_gen_recon_s_ii = self.recon_criterion(
                s_i_recon, s_i_prime)
        else:
            self.loss_gen_recon_s_ii = 0
        self.loss_gen_recon_s_ri = self.recon_criterion(s_ri_recon, s_ri)
        self.loss_gen_recon_s_ri_rand = self.recon_criterion(
            s_ri_rand_recon, s_ri)
        self.loss_gen_recon_s_rs = self.recon_criterion(s_rs_recon, s_r_rand)
        self.loss_gen_recon_s_sr = self.recon_criterion(s_sr_recon, s_s_rand)
        self.loss_gen_recon_s_si = self.recon_criterion(s_si_recon, s_si)
        self.loss_gen_recon_s_si_rand = self.recon_criterion(
            s_si_rand_recon, s_si)
        # Content-level reconstruction loss for cross-domain translations
        self.loss_gen_recon_c_rs = self.recon_criterion(c_rs_recon, c_s)
        self.loss_gen_recon_c_ri = self.recon_criterion(
            c_ri_recon, c_i) if self.use_content_loss else 0
        self.loss_gen_recon_c_ri_rand = self.recon_criterion(
            c_ri_rand_recon, c_i) if self.use_content_loss else 0
        self.loss_gen_recon_c_sr = self.recon_criterion(c_sr_recon, c_r)
        self.loss_gen_recon_c_si = self.recon_criterion(
            c_si_recon, c_i) if self.use_content_loss else 0
        self.loss_gen_recon_c_si_rand = self.recon_criterion(
            c_si_rand_recon, c_i) if self.use_content_loss else 0
        # Cycle consistency loss for the three image domains
        self.loss_gen_cyc_recon_x_rs = self.recon_criterion(x_rsr, x_r)
        self.loss_gen_cyc_recon_x_ir = self.recon_criterion(x_iri, x_i)
        self.loss_gen_cyc_recon_x_ir_rand = self.recon_criterion(
            x_iri_rand, x_i)
        self.loss_gen_cyc_recon_x_sr = self.recon_criterion(x_srs, x_s)
        self.loss_gen_cyc_recon_x_is = self.recon_criterion(x_isi, x_i)
        self.loss_gen_cyc_recon_x_is_rand = self.recon_criterion(
            x_isi_rand, x_i)
        # GAN loss
        self.loss_gen_adv_rs = self.dis_r.calc_gen_loss(x_rs)
        self.loss_gen_adv_ri = self.dis_r.calc_gen_loss(x_ri)
        self.loss_gen_adv_ri_rand = self.dis_r.calc_gen_loss(x_ri_rand)
        self.loss_gen_adv_sr = self.dis_s.calc_gen_loss(x_sr)
        self.loss_gen_adv_si = self.dis_s.calc_gen_loss(x_si)
        self.loss_gen_adv_si_rand = self.dis_s.calc_gen_loss(x_si_rand)
        # Physical loss
        self.loss_gen_phy_i = self.physical_criterion(
            x_i, x_ri, x_si) if self.use_phy_loss else 0
        self.loss_gen_phy_i_rand = self.physical_criterion(
            x_i, x_ri_rand, x_si_rand) if self.use_phy_loss else 0

        # Reflectance smoothness loss
        self.loss_refl_ri = self.reflectance_loss(
            x_ri, targets) if targets is not None else 0
        self.loss_refl_ri_rand = self.reflectance_loss(
            x_ri_rand, targets) if targets is not None else 0

        # total loss
        self.loss_gen_total = param['gan_w'] * self.loss_gen_adv_rs + \
                              param['gan_w'] * self.loss_gen_adv_ri + \
                              param['gan_w'] * self.loss_gen_adv_ri_rand + \
                              param['gan_w'] * self.loss_gen_adv_sr + \
                              param['gan_w'] * self.loss_gen_adv_si + \
                              param['gan_w'] * self.loss_gen_adv_si_rand + \
                              param['recon_x_w'] * self.loss_gen_recon_x_i + \
                              param['recon_x_w'] * self.loss_gen_recon_x_r + \
                              param['recon_x_w'] * self.loss_gen_recon_x_s + \
                              param['recon_s_w'] * self.loss_gen_recon_s_ii + \
                              param['recon_s_w'] * self.loss_gen_recon_s_ri + \
                              param['recon_s_w'] * self.loss_gen_recon_s_ri_rand + \
                              param['recon_s_w'] * self.loss_gen_recon_s_rs + \
                              param['recon_s_w'] * self.loss_gen_recon_s_si + \
                              param['recon_s_w'] * self.loss_gen_recon_s_si_rand + \
                              param['recon_s_w'] * self.loss_gen_recon_s_sr + \
                              param['recon_c_w'] * self.loss_gen_recon_c_ri + \
                              param['recon_c_w'] * self.loss_gen_recon_c_rs + \
                              param['recon_c_w'] * self.loss_gen_recon_c_ri_rand + \
                              param['recon_c_w'] * self.loss_gen_recon_c_si + \
                              param['recon_c_w'] * self.loss_gen_recon_c_sr + \
                              param['recon_c_w'] * self.loss_gen_recon_c_si_rand + \
                              param['recon_x_cyc_w'] * self.loss_gen_cyc_recon_x_ir + \
                              param['recon_x_cyc_w'] * self.loss_gen_cyc_recon_x_ir_rand + \
                              param['recon_x_cyc_w'] * self.loss_gen_cyc_recon_x_is + \
                              param['recon_x_cyc_w'] * self.loss_gen_cyc_recon_x_is_rand + \
                              param['recon_x_cyc_w'] * self.loss_gen_cyc_recon_x_rs + \
                              param['recon_x_cyc_w'] * self.loss_gen_cyc_recon_x_sr + \
                              param['phy_x_w'] * self.loss_gen_phy_i + \
                              param['phy_x_w'] * self.loss_gen_phy_i_rand + \
                              param['refl_smooth_w'] * self.loss_refl_ri + \
                              param['refl_smooth_w'] * self.loss_refl_ri_rand

        self.loss_gen_total.backward()
        self.gen_opt.step()

    def sample(self, x_i, x_r, x_s):
        self.eval()
        s_r = Variable(self.s_r)
        s_s = Variable(self.s_s)
        x_i_recon, x_r_recon, x_s_recon, x_rs, x_ri, x_sr, x_si = [], [], [], [], [], [], []
        for i in range(x_i.size(0)):
            c_i, s_i_fake = self.gen_i.encode(x_i[i].unsqueeze(0))
            c_r, s_r_fake = self.gen_r.encode(x_r[i].unsqueeze(0))
            c_s, s_s_fake = self.gen_s.encode(x_s[i].unsqueeze(0))
            if self.with_mapping:
                s_ri, s_si = self.fea_s(s_i_fake)
            else:
                s_ri = Variable(torch.randn(1, self.style_dim, 1,
                                            1).cuda()) + self.bias_shift
                s_si = Variable(torch.randn(1, self.style_dim, 1,
                                            1).cuda()) - self.bias_shift
            x_i_recon.append(self.gen_i.decode(c_i, s_i_fake))
            x_r_recon.append(self.gen_r.decode(c_r, s_r_fake))
            x_s_recon.append(self.gen_s.decode(c_s, s_s_fake))
            x_rs.append(self.gen_r.decode(c_s, s_r[i].unsqueeze(0)))
            x_ri.append(self.gen_r.decode(c_i, s_ri))
            x_sr.append(self.gen_s.decode(c_r, s_s[i].unsqueeze(0)))
            x_si.append(self.gen_s.decode(c_i, s_si))
        x_i_recon, x_r_recon, x_s_recon = torch.cat(x_i_recon), torch.cat(
            x_r_recon), torch.cat(x_s_recon)
        x_rs, x_ri = torch.cat(x_rs), torch.cat(x_ri)
        x_sr, x_si = torch.cat(x_sr), torch.cat(x_si)
        self.train()
        return x_i, x_i_recon, x_r, x_r_recon, x_rs, x_ri, x_s, x_s_recon, x_sr, x_si

    # noinspection PyAttributeOutsideInit
    def dis_update(self, x_i, x_r, x_s, params):
        self.dis_opt.zero_grad()
        s_r = Variable(torch.randn(x_r.size(0), self.style_dim, 1,
                                   1).cuda()) + self.bias_shift
        s_s = Variable(torch.randn(x_s.size(0), self.style_dim, 1,
                                   1).cuda()) - self.bias_shift
        # encode
        c_r, _ = self.gen_r.encode(x_r)
        c_s, _ = self.gen_s.encode(x_s)
        c_i, s_i = self.gen_i.encode(x_i)
        if self.with_mapping:
            s_ri, s_si = self.fea_s(s_i)
        else:
            s_ri = Variable(
                torch.randn(x_i.size(0), self.style_dim, 1,
                            1).cuda()) + self.bias_shift
            s_si = Variable(
                torch.randn(x_i.size(0), self.style_dim, 1,
                            1).cuda()) - self.bias_shift
        # decode (cross domain)
        x_rs = self.gen_r.decode(c_s, s_r)
        x_ri = self.gen_r.decode(c_i, s_ri)
        x_sr = self.gen_s.decode(c_r, s_s)
        x_si = self.gen_s.decode(c_i, s_si)
        # D loss
        self.loss_dis_rs = self.dis_r.calc_dis_loss(x_rs.detach(), x_r)
        self.loss_dis_ri = self.dis_r.calc_dis_loss(x_ri.detach(), x_r)
        self.loss_dis_sr = self.dis_s.calc_dis_loss(x_sr.detach(), x_s)
        self.loss_dis_si = self.dis_s.calc_dis_loss(x_si.detach(), x_s)

        self.loss_dis_total = params['gan_w'] * self.loss_dis_rs +\
                              params['gan_w'] * self.loss_dis_ri +\
                              params['gan_w'] * self.loss_dis_sr +\
                              params['gan_w'] * self.loss_dis_si

        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, param):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_i.load_state_dict(state_dict['i'])
        self.gen_r.load_state_dict(state_dict['r'])
        self.gen_s.load_state_dict(state_dict['s'])
        if self.with_mapping:
            self.fea_m.load_state_dict(state_dict['fm'])
            self.fea_s.load_state_dict(state_dict['fs'])
        self.best_result = state_dict['best_result']
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_r.load_state_dict(state_dict['r'])
        self.dis_s.load_state_dict(state_dict['s'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, param, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, param, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        if self.with_mapping:
            torch.save(
                {
                    'i': self.gen_i.state_dict(),
                    'r': self.gen_r.state_dict(),
                    's': self.gen_s.state_dict(),
                    'fs': self.fea_s.state_dict(),
                    'fm': self.fea_m.state_dict(),
                    'best_result': self.best_result
                }, gen_name)
        else:
            torch.save(
                {
                    'i': self.gen_i.state_dict(),
                    'r': self.gen_r.state_dict(),
                    's': self.gen_s.state_dict(),
                    'best_result': self.best_result
                }, gen_name)
        torch.save({
            'r': self.dis_r.state_dict(),
            's': self.dis_s.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
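
A note on the checkpoint convention shared by these trainers: save() writes gen_%08d.pt / dis_%08d.pt plus a single optimizer.pt, and resume() recovers the iteration count by slicing the checkpoint file name. A small self-contained sketch of that round trip (the directory name is illustrative):

import os

iterations = 9999
gen_name = os.path.join('outputs', 'gen_%08d.pt' % (iterations + 1))
# gen_name == 'outputs/gen_00010000.pt'
parsed = int(gen_name[-11:-3])  # the 8-digit counter just before '.pt'
assert parsed == iterations + 1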
Example #6
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']
        self.a_attibute = hyperparameters['label_a']
        self.b_attibute = hyperparameters['label_b']

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        if self.a_attibute == 0:
            self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        else:
            self.s_a = torch.randn(display_size,
                                   self.style_dim - self.a_attibute, 1,
                                   1).cuda()
            s_attribute = [i % self.a_attibute for i in range(display_size)]
            s_attribute = torch.tensor(s_attribute, dtype=torch.long).reshape(
                (display_size, 1))
            label_a = torch.zeros(display_size,
                                  self.a_attibute,
                                  dtype=torch.float32).scatter_(
                                      1, s_attribute, 1)
            label_a = label_a.reshape(display_size, self.a_attibute, 1,
                                      1).cuda()
            self.s_a = torch.cat([self.s_a, label_a], 1)
        if self.b_attibute == 0:
            self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        else:
            self.s_b = torch.randn(display_size,
                                   self.style_dim - self.b_attibute, 1,
                                   1).cuda()
            s_attribute = [i % self.b_attibute for i in range(display_size)]
            s_attribute = torch.tensor(s_attribute, dtype=torch.long).reshape(
                (display_size, 1))
            label_b = torch.zeros(display_size,
                                  self.b_attibute,
                                  dtype=torch.float32).scatter_(
                                      1, s_attribute, 1)
            label_b = label_b.reshape(display_size, self.b_attibute, 1,
                                      1).cuda()
            self.s_b = torch.cat([self.s_b, label_b], 1)

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, s_a_fake = self.gen_a.encode(x_a)
        c_b, s_b_fake = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self,
                   x_a,
                   x_b,
                   hyperparameters,
                   label_a=None,
                   label_b=None):
        self.gen_opt.zero_grad()
        if label_a is None:
            s_a = Variable(
                torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        else:
            style_num = label_a.size(1)
            s_a = Variable(
                torch.randn(x_a.size(0), self.style_dim - style_num, 1,
                            1).cuda())
            label_a = label_a.repeat(x_a.size(0), 1)
            label_a = label_a.reshape(x_a.size(0), style_num, 1, 1)
            s_a = torch.cat([s_a, label_a], 1)
        if label_b is None:
            s_b = Variable(
                torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        else:
            style_num = label_b.size(1)
            s_b = Variable(
                torch.randn(x_b.size(0), self.style_dim - style_num, 1,
                            1).cuda())
            label_b = label_b.repeat(x_b.size(0), 1)
            label_b = label_b.reshape(x_b.size(0), style_num, 1, 1)
            s_b = torch.cat([s_b, label_b], 1)
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(
            c_a_recon,
            s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(
            c_b_recon,
            s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a, self.loss_gen_class_a = self.dis_a.calc_gen_loss(
            x_ba, label_a)
        self.loss_gen_adv_b, self.loss_gen_class_b = self.dis_b.calc_gen_loss(
            x_ab, label_b)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(
            self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(
            self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['gan_w'] * self.loss_gen_class_a + \
                              hyperparameters['gan_w'] * self.loss_gen_class_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def sample(self, x_a, x_b):
        self.eval()
        s_a1 = Variable(self.s_a)
        s_b1 = Variable(self.s_b)
        s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
            c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
            x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
            x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
            x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
            x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
            x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_update(
        self,
        x_a,
        x_b,
        hyperparameters,
        label_a=None,
        label_b=None,
    ):
        self.dis_opt.zero_grad()
        if label_a is None:
            s_a = Variable(
                torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        else:  # utilize label in the style code
            style_num = label_a.size(1)
            s_a = Variable(
                torch.randn(x_a.size(0), self.style_dim - style_num, 1,
                            1).cuda())
            label_a = label_a.repeat(x_a.size(0), 1)
            label_a = label_a.reshape(x_a.size(0), style_num, 1, 1)
            s_a = torch.cat([s_a, label_a], 1)
        if label_b is None:
            s_b = Variable(
                torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        else:  # utilize label in the style code
            style_num = label_b.size(1)
            s_b = Variable(
                torch.randn(x_b.size(0), self.style_dim - style_num, 1,
                            1).cuda())
            label_b = label_b.repeat(x_b.size(0), 1)
            label_b = label_b.reshape(x_b.size(0), style_num, 1, 1)
            s_b = torch.cat([s_b, label_b], 1)

        # encode
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ab = self.gen_b.decode(c_a, s_b)
        x_ba = self.gen_a.decode(c_b, s_a)
        # D loss
        self.loss_dis_a, self.loss_class_a = self.dis_a.calc_dis_loss(
            x_ba.detach(), x_a, label_a)
        self.loss_dis_b, self.loss_class_b = self.dis_b.calc_dis_loss(
            x_ab.detach(), x_b, label_b)

        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b + \
                              hyperparameters['gan_w'] * self.loss_class_a + hyperparameters['gan_w'] * self.loss_class_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
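
The trainer above (Example #6) conditions part of the style code on a class label by concatenating a one-hot block onto a random style vector (see its __init__ and gen_update). A standalone sketch of that construction, using CPU tensors and made-up sizes:

import torch

display_size, style_dim, n_attr = 4, 8, 3
s = torch.randn(display_size, style_dim - n_attr, 1, 1)
idx = torch.tensor([i % n_attr for i in range(display_size)]).reshape(display_size, 1)
one_hot = torch.zeros(display_size, n_attr).scatter_(1, idx, 1)
one_hot = one_hot.reshape(display_size, n_attr, 1, 1)
s = torch.cat([s, one_hot], 1)  # final shape: (display_size, style_dim, 1, 1)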
Example #7
class DGNet_Trainer(nn.Module):
    # Constructor
    def __init__(self, hyperparameters, gpu_ids=[0]):
        super(DGNet_Trainer, self).__init__()
        # Get the generator and discriminator learning rates from the config file
        lr_g = hyperparameters['lr_g']
        lr_d = hyperparameters['lr_d']

        # Number of ID classes; note this differs per dataset: it is the number of IDs in the training set, not the test set.
        ID_class = hyperparameters['ID_class']

        # Check whether float16 (apex) is enabled; fp16 mainly saves memory and speeds up training.
        if 'apex' not in hyperparameters:
            hyperparameters['apex'] = False
        self.fp16 = hyperparameters['apex']

        # Initiate the networks
        # We do not need to manually set fp16 in the network for the new apex. So here I set fp16=False.
        ################################################################################################################
        ## Define Es and G.
        # Note this covers two steps: Es encoding plus decoding (the yellow
        # trapezoid G in Figure 2 of the paper); since decoding lives here,
        # the Ea branch below will not include a decoder.
        # Being a class, it exposes gen_a.encode() for encoding and
        # gen_a.decode() for decoding.
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'],
                              hyperparameters['gen'],
                              fp16=False)  # auto-encoder for domain a
        self.gen_b = self.gen_a  # auto-encoder for domain b
        ############################################################################################################################################

        ############################################################################################################################################
        ## Define Ea.
        # ID_stride: stride of the pooling layer in the appearance encoder
        if 'ID_stride' not in hyperparameters:
            hyperparameters['ID_stride'] = 2

        # hyperparameters['ID_style'] defaults to 'AB', the Ea encoder in the paper.
        # Set up Ea; three models are available:
        # PCB, ft_netAB (a modified ResNet-50), and ft_net (a plain ResNet-50).
        if hyperparameters['ID_style'] == 'PCB':
            self.id_a = PCB(ID_class)
        elif hyperparameters['ID_style'] == 'AB':
            # This is the model we actually run. Note that id_a returns two
            # identity predictions plus the feature f; see the function body for details.
            # ft_netAB implements the Ea encoding step that yields the ap code;
            # besides the ap code it also returns two classification results,
            # presumably the person re-identification predictions.
            # ID_class is the number of distinct person IDs.
            self.id_a = ft_netAB(ID_class,
                                 stride=hyperparameters['ID_stride'],
                                 norm=hyperparameters['norm_id'],
                                 pool=hyperparameters['pool'])
        else:
            self.id_a = ft_net(ID_class,
                               norm=hyperparameters['norm_id'],
                               pool=hyperparameters['pool'])  # return 2048 now

        # This is a shallow copy (an alias), so the weights are shared; the two can be treated as one network.
        self.id_b = self.id_a
        ############################################################################################################################################################

        ############################################################################################################################################################
        ## Define D.
        # A multi-scale discriminator for re-ID: the image is rescaled several
        # times, a prediction is made at each scale, and the losses are summed.
        # The network outputs 3 maps: [batch_size,1,64,32], [batch_size,1,32,16], [batch_size,1,16,8]
        self.dis_a = MsImageDis(3, hyperparameters['dis'],
                                fp16=False)  # discriminator for domain a
        self.dis_b = self.dis_a  # discriminator for domain b
        ############################################################################################################################################################

        ############################################################################################################################################################
        # Load the teacher models.
        # teacher: the teacher model name; for DukeMTMC you can set "best-duke".
        if hyperparameters['teacher'] != "":
            #teacher_name=best
            teacher_name = hyperparameters['teacher']
            print(teacher_name)
            # Splitting on ',' suggests multiple teacher models can be loaded.
            teacher_names = teacher_name.split(',')
            # Build the teacher model list
            teacher_model = nn.ModuleList()
            teacher_count = 0

            # By default there is a single teacher name, so the config loaded is models/best/opts.yaml under the project root.
            for teacher_name in teacher_names:
                # Load the config file models/best/opts.yaml
                config_tmp = load_config(teacher_name)
                if 'stride' in config_tmp:
                    #stride=1
                    stride = config_tmp['stride']
                else:
                    stride = 2

                # Load the teacher model; the teacher is an ft_net (ResNet-50).
                model_tmp = ft_net(ID_class, stride=stride)
                teacher_model_tmp = load_network(model_tmp, teacher_name)
                # Remove the original fully connected layer
                teacher_model_tmp.model.fc = nn.Sequential(
                )  # remove the original fc layer in ImageNet
                teacher_model_tmp = teacher_model_tmp.cuda()
                # summary(teacher_model_tmp, (3, 224, 224))
                # Convert the teacher to fp16 if enabled
                if self.fp16:
                    teacher_model_tmp = amp.initialize(teacher_model_tmp,
                                                       opt_level="O1")
                teacher_model.append(teacher_model_tmp.cuda().eval())
                teacher_count += 1
            self.teacher_model = teacher_model
            # Optionally keep the BN layers in training mode
            if hyperparameters['train_bn']:
                self.teacher_model = self.teacher_model.apply(train_bn)
############################################################################################################################################################

        # Instance normalization
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

        # RGB to one channel
        # By default single='gray', so the input to Es is a grayscale image.
        if hyperparameters['single'] == 'edge':
            self.single = to_edge
        else:
            self.single = to_gray(False)

        # Random Erasing when training
        # erasing_p is the probability of random erasing
        if 'erasing_p' not in hyperparameters:
            self.erasing_p = 0
        else:
            self.erasing_p = hyperparameters['erasing_p']
        # Randomly erase some pixels in a rectangular region, a form of data augmentation
        self.single_re = RandomErasing(probability=self.erasing_p,
                                       mean=[0.0, 0.0, 0.0])
        # Default T_w to 1; T_w weights the primary feature learning loss.
        if 'T_w' not in hyperparameters:
            hyperparameters['T_w'] = 1

        ################################################################################################
        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(
            self.dis_a.parameters())  #+ list(self.dis_b.parameters())
        gen_params = list(
            self.gen_a.parameters())  #+ list(self.gen_b.parameters())
        # Use the Adam optimizer to train Es, G, and D.
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr_d,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr_g,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        # id params
        # ID_style defaults to 'AB', so this branch is not executed.
        if hyperparameters['ID_style'] == 'PCB':
            ignored_params = (
                list(map(id, self.id_a.classifier0.parameters())) +
                list(map(id, self.id_a.classifier1.parameters())) +
                list(map(id, self.id_a.classifier2.parameters())) +
                list(map(id, self.id_a.classifier3.parameters())))
            base_params = filter(lambda p: id(p) not in ignored_params,
                                 self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            # Optimizer for Ea
            self.id_opt = torch.optim.SGD(
                [{
                    'params': base_params,
                    'lr': lr2
                }, {
                    'params': self.id_a.classifier0.parameters(),
                    'lr': lr2 * 10
                }, {
                    'params': self.id_a.classifier1.parameters(),
                    'lr': lr2 * 10
                }, {
                    'params': self.id_a.classifier2.parameters(),
                    'lr': lr2 * 10
                }, {
                    'params': self.id_a.classifier3.parameters(),
                    'lr': lr2 * 10
                }],
                weight_decay=hyperparameters['weight_decay'],
                momentum=0.9,
                nesterov=True)

        # This is the branch we actually execute.
        elif hyperparameters['ID_style'] == 'AB':
            # Parameters excluded from the base group; the classifier heads below get a 10x learning rate.
            ignored_params = (
                list(map(id, self.id_a.classifier1.parameters())) +
                list(map(id, self.id_a.classifier2.parameters())))
            # Gather the base parameters and the base learning rate (lr2)
            base_params = filter(lambda p: id(p) not in ignored_params,
                                 self.id_a.parameters())
            lr2 = hyperparameters['lr2']

            # Use SGD for Ea
            self.id_opt = torch.optim.SGD(
                [{
                    'params': base_params,
                    'lr': lr2
                }, {
                    'params': self.id_a.classifier1.parameters(),
                    'lr': lr2 * 10
                }, {
                    'params': self.id_a.classifier2.parameters(),
                    'lr': lr2 * 10
                }],
                weight_decay=hyperparameters['weight_decay'],
                momentum=0.9,
                nesterov=True)
        else:
            ignored_params = list(map(id, self.id_a.classifier.parameters()))
            base_params = filter(lambda p: id(p) not in ignored_params,
                                 self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD(
                [{
                    'params': base_params,
                    'lr': lr2
                }, {
                    'params': self.id_a.classifier.parameters(),
                    'lr': lr2 * 10
                }],
                weight_decay=hyperparameters['weight_decay'],
                momentum=0.9,
                nesterov=True)

        # Set up the learning-rate schedulers for each network
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        self.id_scheduler = get_scheduler(self.id_opt, hyperparameters)
        self.id_scheduler.gamma = hyperparameters['gamma2']

        # ID loss: cross-entropy
        self.id_criterion = nn.CrossEntropyLoss()
        # Teacher loss: KL divergence
        self.criterion_teacher = nn.KLDivLoss(size_average=False)
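        # Note: size_average=False is deprecated in recent PyTorch;
        # nn.KLDivLoss(reduction='sum') is the equivalent.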

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

        # save memory
        if self.fp16:
            # Name the FP16_Optimizer instance to replace the existing optimizer
            assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
            self.gen_a = self.gen_a.cuda()
            self.dis_a = self.dis_a.cuda()
            self.id_a = self.id_a.cuda()

            self.gen_b = self.gen_a
            self.dis_b = self.dis_a
            self.id_b = self.id_a

            self.gen_a, self.gen_opt = amp.initialize(self.gen_a,
                                                      self.gen_opt,
                                                      opt_level="O1")
            self.dis_a, self.dis_opt = amp.initialize(self.dis_a,
                                                      self.dis_opt,
                                                      opt_level="O1")
            self.id_a, self.id_opt = amp.initialize(self.id_a,
                                                    self.id_opt,
                                                    opt_level="O1")

    # Apply RandomErasing to every image in the batch
    def to_re(self, x):
        out = torch.FloatTensor(x.size(0), x.size(1), x.size(2), x.size(3))
        out = out.cuda()
        for i in range(x.size(0)):
            out[i, :, :, :] = self.single_re(x[i, :, :, :])
        return out

    # L1 loss (mean absolute difference)
    def recon_criterion(self, input, target):
        diff = input - target.detach()
        return torch.mean(torch.abs(diff[:]))

    # Square-root L1 loss (mean of the square root of the absolute difference)
    def recon_criterion_sqrt(self, input, target):
        diff = input - target
        return torch.mean(torch.sqrt(torch.abs(diff[:]) + 1e-8))

    # L2 loss
    def recon_criterion2(self, input, target):
        diff = input - target
        return torch.mean(diff[:]**2)

    # cos loss
    def recon_cos(self, input, target):
        cos = torch.nn.CosineSimilarity()
        cos_dis = 1 - cos(input, target)
        return torch.mean(cos_dis[:])

    # x_a, x_b, xp_a, xp_b: [4, 3, 256, 128]
    # (batch size, input channels, image height, image width)
    def forward(self, x_a, x_b, xp_a, xp_b):
        # Feed in x_a and x_b (two images with different IDs from the training
        # set) and encode each into a structure (st) code.
        # s_a: [batch, 128, 64, 32]
        # s_b: [batch, 128, 64, 32]
        # single converts to grayscale (or edges) depending on the config.
        s_a = self.gen_a.encode(self.single(x_a))
        s_b = self.gen_b.encode(self.single(x_b))

        # The image is downsampled first; as the figure shows, the ap code is
        # smaller than the st code, so the two cannot be fused directly and a
        # fully connected layer later brings them to a common size.
        # f_a: [batch_size, 2048*4=8192],  p_a[0]: [batch_size, class_num=751], p_a[1]: [batch_size, class_num=751]
        # f_b: [batch_size, 2048*4=8192],  p_b[0]: [batch_size, class_num=751], p_b[1]: [batch_size, class_num=751]
        # f is the ap code produced by the appearance encoder;
        # p holds the identity predictions (two of them, hence two elements).
        # As noted above, the ap encoder both encodes appearance and predicts
        # identity (re-ID), which is the key to this project.
        # This is the first major point; see the paper for a detailed account.
        f_a, p_a = self.id_a(scale2(x_a))
        f_b, p_b = self.id_b(scale2(x_b))

        # Decode (the yellow trapezoid G in Figure 2): swap clothes between
        # x_a and x_b, which have different IDs.
        # s_b: [batch,128,64,32], f_a: [batch_size,2048,4,1] --> x_ba: [batch_size,3,256,128]
        x_ba = self.gen_a.decode(s_b, f_a)
        x_ab = self.gen_b.decode(s_a, f_b)

        # Reconstruct each image from its own codes, like an autoencoder
        x_a_recon = self.gen_a.decode(s_a, f_a)
        x_b_recon = self.gen_b.decode(s_b, f_b)

        fp_a, pp_a = self.id_a(scale2(xp_a))
        fp_b, pp_b = self.id_b(scale2(xp_b))

        # decode the same person
        # x_a and xp_a are different images of the same ID; below, reconstruct across same-ID images.
        x_a_recon_p = self.gen_a.decode(s_a, fp_a)
        x_b_recon_p = self.gen_b.decode(s_b, fp_b)

        # Random Erasing only affects the ID and PID losses.
        # Erase some pixels from the images, then run the ap encoder.
        if self.erasing_p > 0:
            # First erase some pixels from every image
            x_a_re = self.to_re(scale2(x_a.clone()))
            x_b_re = self.to_re(scale2(x_b.clone()))
            xp_a_re = self.to_re(scale2(xp_a.clone()))
            xp_b_re = self.to_re(scale2(xp_b.clone()))

            # Then encode into ap codes; this likely acts as data augmentation:
            # even with some pixels erased, the model should still identify who is in the image.
            _, p_a = self.id_a(x_a_re)
            _, p_b = self.id_b(x_b_re)
            # encode the same ID different photo
            _, pp_a = self.id_a(xp_a_re)
            _, pp_b = self.id_b(xp_b_re)

        # Synthesized images: x_ab [st of images_a, ap of images_b];  x_ba [st of images_b, ap of images_a]
        # s_a [st code of images_a from Es]             s_b [st code of images_b from Es]
        # f_a [ap code of images_a from Ea]             f_b [ap code of images_b from Ea]
        # p_a [ID prediction for images_a from Ea]      p_b [ID prediction for images_b from Ea]
        # pp_a [ID prediction for pos_a from Ea]        pp_b [ID prediction for pos_b from Ea]
        # x_a_recon [images_a (s_a) recombined with its own f_a; looks like images_a]
        # x_b_recon [images_b (s_b) recombined with its own f_b; looks like images_b]
        # x_a_recon_p [images_a (s_a) combined with pos_a's fp_a; still looks like images_a]
        # x_b_recon_p [images_b (s_b) combined with pos_b's fp_b; still looks like images_b]

        return x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b, x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p

    def gen_update(self, x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b,
                   x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p, x_a, x_b,
                   xp_a, xp_b, l_a, l_b, hyperparameters, iteration, num_gpu):
        """

        :param x_ab: [st of images_a, ap of images_b]
        :param x_ba: [st of images_b, ap of images_a]
        :param s_a: st code of images_a from Es
        :param s_b: st code of images_b from Es
        :param f_a: ap code of images_a from Ea
        :param f_b: ap code of images_b from Ea
        :param p_a: ID prediction for images_a from Ea
        :param p_b: ID prediction for images_b from Ea
        :param pp_a: ID prediction for pos_a from Ea
        :param pp_b: ID prediction for pos_b from Ea
        :param x_a_recon: images_a (s_a) recombined with its own f_a; looks like images_a
        :param x_b_recon: images_b (s_b) recombined with its own f_b; looks like images_b
        :param x_a_recon_p: images_a (s_a) combined with pos_a's fp_a; looks like images_a
        :param x_b_recon_p: images_b (s_b) combined with pos_b's fp_b; looks like images_b
        :param x_a: images_a
        :param x_b: images_b
        :param xp_a: pos_a
        :param xp_b: pos_b
        :param l_a: labels_a
        :param l_b: labels_b
        :param hyperparameters:
        :param iteration:
        :param num_gpu:
        :return:
        """
        # pp_a, pp_b: are they the same person?
        self.gen_opt.zero_grad()  # clear gradients
        self.id_opt.zero_grad()

        # no gradient
        # Make detached copies of the synthesized x_ba and x_ab.
        x_ba_copy = Variable(x_ba.data, requires_grad=False)
        x_ab_copy = Variable(x_ab.data, requires_grad=False)

        rand_num = random.uniform(0, 1)
        #################################
        # encode structure
        # enc_content是类ContentEncoder对象
        if hyperparameters['use_encoder_again'] >= rand_num:
            # encode again (encoder is tuned, input is fixed)
            # Es编码得到s_a_recon与s_b_recon即st code
            # 如果是理想模型,s_a_recon=s_a, s_b_recon=s_b
            s_a_recon = self.gen_b.enc_content(self.single(x_ab_copy))
            s_b_recon = self.gen_a.enc_content(self.single(x_ba_copy))
        else:
            # copy the encoder (a deep copy)
            # enc_content_copy = gen_a.enc_content
            self.enc_content_copy = copy.deepcopy(self.gen_a.enc_content)
            self.enc_content_copy = self.enc_content_copy.eval()
            # encode again (encoder is fixed, input is tuned)
            s_a_recon = self.enc_content_copy(self.single(x_ab))
            s_b_recon = self.enc_content_copy(self.single(x_ba))

        #################################
        # encode appearance
        #id_a_copy=id_a=Ea
        self.id_a_copy = copy.deepcopy(self.id_a)
        self.id_a_copy = self.id_a_copy.eval()
        if hyperparameters['train_bn']:
            self.id_a_copy = self.id_a_copy.apply(train_bn)
        self.id_b_copy = self.id_a_copy
        # encode again (encoder is fixed, input is tuned)
        # Run the (copied) appearance encoder Ea on the synthesized images x_ba
        # and x_ab, which also predicts their identities.
        # f_a_recon, f_b_recon are the ap codes; p_a_recon, p_b_recon are the
        # identity predictions.
        f_a_recon, p_a_recon = self.id_a_copy(scale2(x_ba))
        f_b_recon, p_b_recon = self.id_b_copy(scale2(x_ab))

        # teacher Loss
        #  Tune the ID model
        log_sm = nn.LogSoftmax(dim=1)
        # If a teacher network is used.
        # ID_style defaults to 'AB'.
        if hyperparameters['teacher_w'] > 0 and hyperparameters[
                'teacher'] != "":
            if hyperparameters['ID_style'] == 'normal':
                # p_a_student is the identity prediction for x_ba_copy, produced
                # by Ea, i.e. by the student model.
                _, p_a_student = self.id_a(scale2(x_ba_copy))
                # Apply log-softmax: the result is the (log-)probability that
                # x_ba_copy depicts each identity, i.e. a distribution.
                p_a_student = log_sm(p_a_student)
                # Classify the generated image x_ba_copy with the teacher model;
                # the output is the probability distribution over identities.
                p_a_teacher = predict_label(
                    self.teacher_model,
                    scale2(x_ba_copy),
                    num_class=hyperparameters['ID_class'],
                    alabel=l_a,
                    slabel=l_b,
                    teacher_style=hyperparameters['teacher_style'])
                # Minimizing the KL-divergence loss pushes the distribution
                # p_a_student to match p_a_teacher as closely as possible.
                self.loss_teacher = self.criterion_teacher(
                    p_a_student, p_a_teacher) / p_a_student.size(0)

                # Do the same for x_ab_copy.
                _, p_b_student = self.id_b(scale2(x_ab_copy))
                p_b_student = log_sm(p_b_student)
                p_b_teacher = predict_label(
                    self.teacher_model,
                    scale2(x_ab_copy),
                    num_class=hyperparameters['ID_class'],
                    alabel=l_b,
                    slabel=l_a,
                    teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher += self.criterion_teacher(
                    p_b_student, p_b_teacher) / p_b_student.size(0)
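
                # Note (standard PyTorch convention, not stated in the original
                # code): nn.KLDivLoss expects log-probabilities as input and
                # plain probabilities as target, which is why only the student
                # logits go through LogSoftmax. Dividing by p_a_student.size(0)
                # averages the summed divergence over the batch.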

            ####################################################################
            # primary feature learning loss
            ####################################################################
            # ID_style is 'AB'.
            elif hyperparameters['ID_style'] == 'AB':
                # normal teacher-student loss
                # BA -> LabelA(smooth) + LabelB(batchB)
                # The synthesized image goes through the identity discriminator,
                # which outputs a probability for every ID. Note that only
                # p_ba_student[0] is used here: there are two identity
                # predictions, and this one is assigned to p_a_student to
                # compute the loss together with the teacher model.
                # p_ba_student has two parts: p_ba_student[0] is L_prim and
                # p_ba_student[1] is L_fine.
                _, p_ba_student = self.id_a(scale2(x_ba_copy))  # f_a, s_b
                p_a_student = log_sm(p_ba_student[0])

                with torch.no_grad():
                    # Classify the generated image x_ba_copy with the teacher
                    # model; the output is the probability that x_ba_copy
                    # depicts each identity (x_a / x_b), i.e. a distribution.
                    p_a_teacher = predict_label(
                        self.teacher_model,
                        scale2(x_ba_copy),
                        num_class=hyperparameters['ID_class'],
                        alabel=l_a,
                        slabel=l_b,
                        teacher_style=hyperparameters['teacher_style'])

                # criterion_teacher = nn.KLDivLoss(size_average=False)
                # Computes a discrete divergence: roughly the summed element-wise
                # distance between p_a_student and p_a_teacher, divided by
                # p_a_student.size(0) for the batch mean. The closer the student
                # network (Ea) matches the teacher's predictions, the better.
                self.loss_teacher = self.criterion_teacher(
                    p_a_student, p_a_teacher) / p_a_student.size(0)

                # Do the same for the other synthesized image.
                _, p_ab_student = self.id_b(scale2(x_ab_copy))  # f_b, s_a
                p_b_student = log_sm(p_ab_student[0])
                with torch.no_grad():
                    p_b_teacher = predict_label(
                        self.teacher_model,
                        scale2(x_ab_copy),
                        num_class=hyperparameters['ID_class'],
                        alabel=l_b,
                        slabel=l_a,
                        teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher += self.criterion_teacher(
                    p_b_student, p_b_teacher) / p_b_student.size(0)
                ################################################################

                ################################################################
                # fine-grained feature mining loss
                ################################################################
                # branch b loss
                # here we give different labels
                # p_ba_student[1] is the f_fine feature; l_b is the label of
                # images_b, the image that supplied the st code for the
                # generated image.
                loss_B = self.id_criterion(p_ba_student[1],
                                           l_b) + self.id_criterion(
                                               p_ab_student[1], l_a)
                ################################################################

                # Weight the two loss terms.
                self.loss_teacher = hyperparameters[
                    'T_w'] * self.loss_teacher + hyperparameters['B_w'] * loss_B
        else:
            self.loss_teacher = 0.0

        ## What remains are the losses between reconstructed images.
        # As mentioned earlier, reconstruction differs from synthesis:
        # a reconstruction should match the original image, so the pixel-wise
        # difference between the two can be computed directly. Synthesized
        # images have no ground truth in the training set, so no pixel-wise
        # loss can be computed for them.
        ####################################################################
        # auto-encoder image reconstruction
        # Losses for reconstructing images of the same ID.
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_xp_a = self.recon_criterion(x_a_recon_p, x_a)
        self.loss_gen_recon_xp_b = self.recon_criterion(x_b_recon_p, x_b)
        ####################################################################

        ####################################################################
        # feature reconstruction
        # For cross-ID synthesis, these losses keep the st and ap codes of the
        # synthesized image consistent with the codes that produced it.
        self.loss_gen_recon_s_a = self.recon_criterion(
            s_a_recon, s_a) if hyperparameters['recon_s_w'] > 0 else 0
        self.loss_gen_recon_s_b = self.recon_criterion(
            s_b_recon, s_b) if hyperparameters['recon_s_w'] > 0 else 0
        self.loss_gen_recon_f_a = self.recon_criterion(
            f_a_recon, f_a) if hyperparameters['recon_f_w'] > 0 else 0
        self.loss_gen_recon_f_b = self.recon_criterion(
            f_b_recon, f_b) if hyperparameters['recon_f_w'] > 0 else 0
        ####################################################################

        # Synthesize images once more (cycle reconstruction).
        x_aba = self.gen_a.decode(
            s_a_recon,
            f_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(
            s_b_recon,
            f_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # ID loss AND Tune the Generated image
        if hyperparameters['ID_style'] == 'PCB':
            self.loss_id = self.PCB_loss(p_a, l_a) + self.PCB_loss(p_b, l_b)
            self.loss_pid = self.PCB_loss(pp_a, l_a) + self.PCB_loss(pp_b, l_b)
            self.loss_gen_recon_id = self.PCB_loss(
                p_a_recon, l_a) + self.PCB_loss(p_b_recon, l_b)

        ####################################################################
        # The configuration used here is ID_style == 'AB'.
        elif hyperparameters['ID_style'] == 'AB':
            weight_B = hyperparameters['teacher_w'] * hyperparameters['B_w']
            # Computes L^s_id.
            self.loss_id = self.id_criterion(p_a[0], l_a) + self.id_criterion(p_b[0], l_b) \
                         + weight_B * ( self.id_criterion(p_a[1], l_a) + self.id_criterion(p_b[1], l_b) )

            # Compute L^s_id for a different image of the same ID.
            self.loss_pid = self.id_criterion(pp_a[0], l_a) + self.id_criterion(
                pp_b[0], l_b
            )  #+ weight_B * ( self.id_criterion(pp_a[1], l_a) + self.id_criterion(pp_b[1], l_b) )

            # Compute L^C_id for the generated images.
            self.loss_gen_recon_id = self.id_criterion(
                p_a_recon[0], l_a) + self.id_criterion(p_b_recon[0], l_b)
        ####################################################################

        else:
            self.loss_id = self.id_criterion(p_a, l_a) + self.id_criterion(
                p_b, l_b)
            self.loss_pid = self.id_criterion(pp_a, l_a) + self.id_criterion(
                pp_b, l_b)
            self.loss_gen_recon_id = self.id_criterion(
                p_a_recon, l_a) + self.id_criterion(p_b_recon, l_b)

        #print(f_a_recon, f_a)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0

        ####################################################################
        # GAN loss
        # Adversarial loss for the generator G.
        ####################################################################
        if num_gpu > 1:
            self.loss_gen_adv_a = self.dis_a.module.calc_gen_loss(
                self.dis_a, x_ba)
            self.loss_gen_adv_b = self.dis_b.module.calc_gen_loss(
                self.dis_b, x_ab)
        else:
            self.loss_gen_adv_a = self.dis_a.calc_gen_loss(self.dis_a, x_ba)
            self.loss_gen_adv_b = self.dis_b.calc_gen_loss(self.dis_b, x_ab)
        ####################################################################

        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(
            self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(
            self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0

        if iteration > hyperparameters['warm_iter']:
            hyperparameters['recon_f_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_f_w'] = min(hyperparameters['recon_f_w'],
                                               hyperparameters['max_w'])
            hyperparameters['recon_s_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_s_w'] = min(hyperparameters['recon_s_w'],
                                               hyperparameters['max_w'])
            hyperparameters['recon_x_cyc_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_x_cyc_w'] = min(
                hyperparameters['recon_x_cyc_w'], hyperparameters['max_cyc_w'])

        if iteration > hyperparameters['warm_teacher_iter']:
            hyperparameters['teacher_w'] += hyperparameters['warm_scale']
            hyperparameters['teacher_w'] = min(
                hyperparameters['teacher_w'], hyperparameters['max_teacher_w'])
        # Total loss:
        # 1 teacher loss + 4 same-ID reconstruction losses + 4 cross-ID code
        # (feature) losses + 3 ID losses + 2 generator adversarial losses,
        # plus the cycle and perceptual terms when their weights are non-zero.
        # The teacher loss includes the primary feature learning loss and the
        # fine-grained feature mining loss.
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_a + \
                              hyperparameters['recon_f_w'] * self.loss_gen_recon_f_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_b + \
                              hyperparameters['recon_f_w'] * self.loss_gen_recon_f_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['id_w'] * self.loss_id + \
                              hyperparameters['pid_w'] * self.loss_pid + \
                              hyperparameters['recon_id_w'] * self.loss_gen_recon_id + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b + \
                              hyperparameters['teacher_w'] * self.loss_teacher

        if self.fp16:
            with amp.scale_loss(self.loss_gen_total,
                                [self.gen_opt, self.id_opt]) as scaled_loss:
                scaled_loss.backward()
            self.gen_opt.step()
            self.id_opt.step()
        else:
            self.loss_gen_total.backward()  # compute gradients
            self.gen_opt.step()  # update generator weights
            self.id_opt.step()  # update ID-model weights
        print("L_total: %.4f, L_gan: %.4f,  Lx: %.4f, Lxp: %.4f, Lrecycle:%.4f, Lf: %.4f, Ls: %.4f, Recon-id: %.4f, id: %.4f, pid:%.4f, teacher: %.4f"%( self.loss_gen_total, \
                                                        hyperparameters['gan_w'] * (self.loss_gen_adv_a + self.loss_gen_adv_b), \
                                                        hyperparameters['recon_x_w'] * (self.loss_gen_recon_x_a + self.loss_gen_recon_x_b), \
                                                        hyperparameters['recon_xp_w'] * (self.loss_gen_recon_xp_a + self.loss_gen_recon_xp_b), \
                                                        hyperparameters['recon_x_cyc_w'] * (self.loss_gen_cycrecon_x_a + self.loss_gen_cycrecon_x_b), \
                                                        hyperparameters['recon_f_w'] * (self.loss_gen_recon_f_a + self.loss_gen_recon_f_b), \
                                                        hyperparameters['recon_s_w'] * (self.loss_gen_recon_s_a + self.loss_gen_recon_s_b), \
                                                        hyperparameters['recon_id_w'] * self.loss_gen_recon_id, \
                                                        hyperparameters['id_w'] * self.loss_id,\
                                                        hyperparameters['pid_w'] * self.loss_pid,\
hyperparameters['teacher_w'] * self.loss_teacher )  )

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)
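
    # Instance-normalizing both VGG feature maps strips much of their style
    # statistics, so this perceptual loss compares structure rather than
    # appearance; that is what makes it "domain-invariant", as it is labeled
    # where gen_update uses it.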

    def PCB_loss(self, inputs, labels):
        loss = 0.0
        for part in inputs:
            loss += self.id_criterion(part, labels)
        return loss / len(inputs)
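
    # PCB_loss averages the ID criterion over part-level predictions: `inputs`
    # is expected to be a sequence with one logit tensor per body part, all
    # scored against the same labels.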

    def sample(self, x_a, x_b):
        self.eval()
        x_a_recon, x_b_recon, x_ba1, x_ab1, x_aba, x_bab = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            s_a = self.gen_a.encode(self.single(x_a[i].unsqueeze(0)))
            s_b = self.gen_b.encode(self.single(x_b[i].unsqueeze(0)))
            f_a, _ = self.id_a(scale2(x_a[i].unsqueeze(0)))
            f_b, _ = self.id_b(scale2(x_b[i].unsqueeze(0)))
            x_a_recon.append(self.gen_a.decode(s_a, f_a))
            x_b_recon.append(self.gen_b.decode(s_b, f_b))
            x_ba = self.gen_a.decode(s_b, f_a)
            x_ab = self.gen_b.decode(s_a, f_b)
            x_ba1.append(x_ba)
            x_ab1.append(x_ab)
            #cycle
            s_b_recon = self.gen_a.enc_content(self.single(x_ba))
            s_a_recon = self.gen_b.enc_content(self.single(x_ab))
            f_a_recon, _ = self.id_a(scale2(x_ba))
            f_b_recon, _ = self.id_b(scale2(x_ab))
            x_aba.append(self.gen_a.decode(s_a_recon, f_a_recon))
            x_bab.append(self.gen_b.decode(s_b_recon, f_b_recon))

        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_aba, x_bab = torch.cat(x_aba), torch.cat(x_bab)
        x_ba1, x_ab1 = torch.cat(x_ba1), torch.cat(x_ab1)
        self.train()

        return x_a, x_a_recon, x_aba, x_ab1, x_b, x_b_recon, x_bab, x_ba1

    def dis_update(self, x_ab, x_ba, x_a, x_b, hyperparameters, num_gpu):
        self.dis_opt.zero_grad()  # clear gradients
        # D loss
        # Compute the discriminator loss, then backpropagate and update.
        # The inputs are the image pairs (x_ba, x_a) and (x_ab, x_b); the loss
        # is summed over both pairs.
        if num_gpu > 1:
            self.loss_dis_a, reg_a = self.dis_a.module.calc_dis_loss(
                self.dis_a, x_ba.detach(), x_a)
            self.loss_dis_b, reg_b = self.dis_b.module.calc_dis_loss(
                self.dis_b, x_ab.detach(), x_b)
        else:
            # Compute the discriminator loss.
            self.loss_dis_a, reg_a = self.dis_a.calc_dis_loss(
                self.dis_a, x_ba.detach(), x_a)
            self.loss_dis_b, reg_b = self.dis_b.calc_dis_loss(
                self.dis_b, x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters[
            'gan_w'] * self.loss_dis_a + hyperparameters[
                'gan_w'] * self.loss_dis_b
        print("DLoss: %.4f" % self.loss_dis_total,
              "Reg: %.4f" % (reg_a + reg_b))
        if self.fp16:
            with amp.scale_loss(self.loss_dis_total,
                                self.dis_opt) as scaled_loss:
                scaled_loss.backward()
        else:
            self.loss_dis_total.backward()  # compute gradients
        self.dis_opt.step()  # update discriminator weights

    def update_learning_rate(self):
        # Step the learning-rate schedulers.
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()
        if self.id_scheduler is not None:
            self.id_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b = self.gen_a
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b = self.dis_a
        # Load ID dis
        last_model_name = get_model_list(checkpoint_dir, "id")
        state_dict = torch.load(last_model_name)
        self.id_a.load_state_dict(state_dict['a'])
        self.id_b = self.id_a
        # Load optimizers
        try:
            state_dict = torch.load(
                os.path.join(checkpoint_dir, 'optimizer.pt'))
            self.dis_opt.load_state_dict(state_dict['dis'])
            self.gen_opt.load_state_dict(state_dict['gen'])
            self.id_opt.load_state_dict(state_dict['id'])
        except Exception:
            # The optimizer checkpoint may be missing or incompatible; start
            # the optimizers fresh in that case.
            pass
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations, num_gpu=1):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        id_name = os.path.join(snapshot_dir, 'id_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict()}, gen_name)
        if num_gpu > 1:
            torch.save({'a': self.dis_a.module.state_dict()}, dis_name)
        else:
            torch.save({'a': self.dis_a.state_dict()}, dis_name)
        torch.save({'a': self.id_a.state_dict()}, id_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'id': self.id_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
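
The warm-up branch in gen_update above ramps several loss weights linearly each
iteration and clamps them at a ceiling. A standalone sketch of that schedule
(the keys mirror the code above; the concrete values are made up for
illustration):

def warm_up_weights(hyperparameters, iteration):
    """Linearly ramp selected loss weights once warm-up has passed."""
    if iteration > hyperparameters['warm_iter']:
        for key in ('recon_f_w', 'recon_s_w'):
            hyperparameters[key] = min(
                hyperparameters[key] + hyperparameters['warm_scale'],
                hyperparameters['max_w'])
        hyperparameters['recon_x_cyc_w'] = min(
            hyperparameters['recon_x_cyc_w'] + hyperparameters['warm_scale'],
            hyperparameters['max_cyc_w'])
    return hyperparameters

# Example: with warm_iter=0, warm_scale=0.01 and max_w=1.0, each call grows
# recon_f_w and recon_s_w by 0.01 until they saturate at 1.0.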
Example #8
class MUSIC_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(MUSIC_Trainer, self).__init__()
        lr = hyperparameters['lr']
        old_flag = hyperparameters['old_flag']
        # Initiate the networks
        if old_flag == 1:
            self.gen = MaskGenOld(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
            self.style_dim = hyperparameters['gen']['style_dim']
            self.s_a = torch.randn(8, self.style_dim, 1, 1).cuda()
            self.s_b = torch.randn(8, self.style_dim, 1, 1).cuda()
        else:
            self.gen = MaskGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

        try:
            enhance = hyperparameters['enhance']
        except KeyError:
            enhance = None

        self.enhance = enhance

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen.parameters())

        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def masking(self, mask, img):
        if self.enhance:
            mask = mask ** self.enhance
        img = img * 0.5 + 0.5
        masked_image = mask * img
        masked_image = (masked_image - 0.5) * 2
        return masked_image

    def scaled_sum(self, input_1, input_2):
        input_1 = input_1 * 0.5 + 0.5
        input_2 = input_2 * 0.5 + 0.5
        sum_output = input_1 + input_2
        sum_output = torch.clamp(sum_output, 0, 1) # added at 3.yaml
        sum_output = (sum_output - 0.5) * 2
        return sum_output

    def scaled_sub(self, input_1, input_2):
        input_1 = input_1 * 0.5 + 0.5
        input_2 = input_2 * 0.5 + 0.5
        sub_output = input_1 - input_2
        sub_output = torch.clamp(sub_output, 0, 1) # added at 3.yaml
        sub_output = (sub_output - 0.5) * 2
        return sub_output
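
    # Both helpers assume inputs in [-1, 1]: `x * 0.5 + 0.5` maps to [0, 1],
    # the sum/difference is clamped so it stays a valid image, and
    # `(x - 0.5) * 2` maps back to [-1, 1]. For example, adding pixels at -1
    # and 1 gives 0 + 1 = 1 in [0, 1], which maps back to 1.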

    def forward(self, x_b):
        self.eval()
        x_ba_mask = self.gen.decode(self.gen.encode(x_b))
        x_ba = self.masking(x_ba_mask, x_b)
        self.train()
        return x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        self.gen_opt.zero_grad()

        # encode-decode
        x_ba_mask = self.gen.decode(self.gen.encode(x_b))
        x_aa_mask = self.gen.decode(self.gen.encode(x_a))
        x_ba = self.masking(x_ba_mask, x_b)
        x_aa = self.masking(x_aa_mask, x_a)
        # encode again
        x_t = self.scaled_sub(x_b, x_ba)
        x_b_new = self.scaled_sum(x_t, x_a)
        x_b_new_mask = self.gen.decode(self.gen.encode(x_b_new))
        x_ba_new = self.masking(x_b_new_mask, x_b_new)
        x_t_new = self.scaled_sub(x_b_new, x_ba_new)
        # decode twice
        x_baa_mask = self.gen.decode(self.gen.encode(x_ba))
        x_baa = self.masking(x_baa_mask, x_ba)

        # reconstruction loss
        self.loss_gen_recon_x_aa = self.recon_criterion(x_aa, x_a)
        self.loss_gen_recon_x_t = self.recon_criterion(x_t_new, x_t)
        self.loss_gen_recon_x_ba_new = self.recon_criterion(x_ba_new, x_a)
        self.loss_gen_recon_x_baa = self.recon_criterion(x_baa, x_ba)

        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_b_new)

        # total loss
        self.loss_gen_total = hyperparameters['gan_w_a'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w_b'] * self.loss_gen_adv_b + \
                              hyperparameters['a2a_w'] * self.loss_gen_recon_x_aa + \
                              hyperparameters['x_t_w'] * self.loss_gen_recon_x_t + \
                              hyperparameters['recon_w'] * self.loss_gen_recon_x_ba_new + \
                              hyperparameters['DTN_w'] * self.loss_gen_recon_x_baa
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def sample(self, x_a, x_b):
        self.eval()

        x_ba, x_aa, x_ba_new = [], [], []
        x_t, x_t_new = [], []
        x_baa, x_b_new = [], []
        for i in range(x_a.size(0)):
            x_ba.append(self.masking(self.gen.decode(self.gen.encode(x_b[i].unsqueeze(0))), x_b[i].unsqueeze(0)))
            x_aa.append(self.masking(self.gen.decode(self.gen.encode(x_a[i].unsqueeze(0))), x_a[i].unsqueeze(0)))
            x_t.append(self.scaled_sub(x_b[i], x_ba[i]))
            x_b_new.append(self.scaled_sum(x_t[i], x_a[i].unsqueeze(0)))
            x_ba_new.append(self.masking(self.gen.decode(self.gen.encode(x_b_new[i])), x_b_new[i]))
            x_t_new.append(self.scaled_sub(x_b_new[i], x_ba_new[i]))
            x_baa.append(self.masking(self.gen.decode(self.gen.encode(x_ba[i])), x_ba[i]))

        x_ba, x_aa, x_ba_new = torch.cat(x_ba), torch.cat(x_aa), torch.cat(x_ba_new)
        x_t, x_t_new = torch.cat(x_t), torch.cat(x_t_new)
        x_baa, x_b_new = torch.cat(x_baa), torch.cat(x_b_new)

        self.train()
        return x_b, x_ba, x_baa, x_a, x_ba_new, x_aa, x_t, x_t_new, x_b_new

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()

        x_ba = self.masking(self.gen.decode(self.gen.encode(x_b)), x_b)
        x_t = self.scaled_sub(x_b, x_ba)
        x_b_new = self.scaled_sum(x_t, x_a)

        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_b_new.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w_a'] * self.loss_dis_a + hyperparameters['gan_w_b'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen.load_state_dict(state_dict['gen'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'gen': self.gen.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
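
The masking step in MUSIC_Trainer is element-wise gating performed in [0, 1]
space. A self-contained sketch of the same arithmetic (dummy tensors; no model
required):

import torch

def apply_mask(mask, img, enhance=None):
    """Gate an image in [-1, 1] with a mask in [0, 1], as masking() does."""
    if enhance:
        mask = mask ** enhance       # sharpen the mask, like self.enhance
    img01 = img * 0.5 + 0.5          # [-1, 1] -> [0, 1]
    out01 = mask * img01             # element-wise gating
    return (out01 - 0.5) * 2         # back to [-1, 1]

mask = torch.rand(1, 3, 8, 8)          # dummy soft mask
img = torch.rand(1, 3, 8, 8) * 2 - 1   # dummy image in [-1, 1]
print(apply_mask(mask, img).shape)     # torch.Size([1, 3, 8, 8])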
Example #9
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters, resume_epoch=-1, snapshot_dir=None):

        super(MUNIT_Trainer, self).__init__()

        lr = hyperparameters['lr']

        # Initiate the networks.
        self.gen = AdaINGen(
            hyperparameters['input_dim'] + hyperparameters['n_datasets'],
            hyperparameters['gen'],
            hyperparameters['n_datasets'])  # Auto-encoder for domain a.
        self.dis = MsImageDis(
            hyperparameters['input_dim'] + hyperparameters['n_datasets'],
            hyperparameters['dis'])  # Discriminator for domain a.

        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']
        self.beta1 = hyperparameters['beta1']
        self.beta2 = hyperparameters['beta2']
        self.weight_decay = hyperparameters['weight_decay']

        # Initiate and load the pretrained UNet.
        self.sup = UNet(input_channels=hyperparameters['input_dim'],
                        num_classes=2).cuda()

        # Fix the noise used in sampling.
        self.s_a = torch.randn(8, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(8, self.style_dim, 1, 1).cuda()

        # Setup the optimizers.
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']

        dis_params = list(self.dis.parameters())
        gen_params = list(self.gen.parameters()) + list(self.sup.parameters())

        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(self.beta1, self.beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(self.beta1, self.beta2),
            weight_decay=hyperparameters['weight_decay'])

        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization.
        self.apply(weights_init(hyperparameters['init']))
        self.dis.apply(weights_init('gaussian'))

        # Presetting one hot encoding vectors.
        self.one_hot_img = torch.zeros(hyperparameters['n_datasets'],
                                       hyperparameters['batch_size'],
                                       hyperparameters['n_datasets'], 256,
                                       256).cuda()
        self.one_hot_c = torch.zeros(hyperparameters['n_datasets'],
                                     hyperparameters['batch_size'],
                                     hyperparameters['n_datasets'], 64,
                                     64).cuda()

        for i in range(hyperparameters['n_datasets']):
            self.one_hot_img[i, :, i, :, :].fill_(1)
            self.one_hot_c[i, :, i, :, :].fill_(1)
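
        # one_hot_img[d] is a batch of one-hot "dataset label" maps: channel d
        # is all ones and the remaining n_datasets-1 channels are zero.
        # Concatenated along dim=1, it tells the encoder/decoder which dataset
        # an image comes from; one_hot_c plays the same role at the 64x64
        # content-code resolution.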

        if resume_epoch != -1:

            self.resume(snapshot_dir, hyperparameters)

    def recon_criterion(self, input, target):

        return torch.mean(torch.abs(input - target))

    def semi_criterion(self, input, target):

        loss = CrossEntropyLoss2d(size_average=False).cuda()
        return loss(input, target)

    def forward(self, x_a, x_b, d_index_a, d_index_b):
        # d_index_a / d_index_b were missing from the original signature even
        # though the body uses them; they select the one-hot dataset labels.

        self.eval()

        # `volatile` is the legacy (pre-0.4 PyTorch) way to disable autograd.
        x_a.volatile = True
        x_b.volatile = True

        s_a = Variable(self.s_a, volatile=True)
        s_b = Variable(self.s_b, volatile=True)

        one_hot_x_a = torch.cat([x_a, self.one_hot_img[d_index_a]], 1)
        one_hot_x_b = torch.cat([x_b, self.one_hot_img[d_index_b]], 1)
        c_a, s_a_fake = self.gen.encode(one_hot_x_a)
        c_b, s_b_fake = self.gen.encode(one_hot_x_b)

        one_hot_c_b = torch.cat([c_b, self.one_hot_c[d_index_a]], 1)
        one_hot_c_a = torch.cat([c_a, self.one_hot_c[d_index_b]], 1)
        x_ba = self.gen.decode(one_hot_c_b, s_a)
        x_ab = self.gen.decode(one_hot_c_a, s_b)

        self.train()

        return x_ab, x_ba

    def set_gen_trainable(self, train_bool):

        if train_bool:
            self.gen.train()
            for param in self.gen.parameters():
                param.requires_grad = True

        else:
            self.gen.eval()
            for param in self.gen.parameters():
                # The original set requires_grad = True here too, which never
                # actually froze the generator; False is the intended value.
                param.requires_grad = False

    def set_sup_trainable(self, train_bool):

        if train_bool:
            self.sup.train()
            for param in self.sup.parameters():
                param.requires_grad = True
        else:
            self.sup.eval()
            for param in self.sup.parameters():
                param.requires_grad = False  # same fix as in set_gen_trainable

    def sup_update(self, x_a, x_b, y_a, y_b, d_index_a, d_index_b, use_a,
                   use_b, hyperparameters):

        self.gen_opt.zero_grad()

        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())

        one_hot_x_a = torch.cat([x_a, self.one_hot_img[d_index_a]], 1)
        one_hot_x_b = torch.cat([x_b, self.one_hot_img[d_index_b]], 1)

        # Encode.
        c_a, s_a_prime = self.gen.encode(one_hot_x_a)
        c_b, s_b_prime = self.gen.encode(one_hot_x_b)

        # Decode (within domain).
        one_hot_c_a = torch.cat([c_a, self.one_hot_c[d_index_a]], 1)
        one_hot_c_b = torch.cat([c_b, self.one_hot_c[d_index_b]], 1)
        x_a_recon = self.gen.decode(one_hot_c_a, s_a_prime)
        x_b_recon = self.gen.decode(one_hot_c_b, s_b_prime)

        # Decode (cross domain).
        one_hot_c_ab = torch.cat([c_a, self.one_hot_c[d_index_b]], 1)
        one_hot_c_ba = torch.cat([c_b, self.one_hot_c[d_index_a]], 1)
        x_ba = self.gen.decode(one_hot_c_ba, s_a)
        x_ab = self.gen.decode(one_hot_c_ab, s_b)

        # Encode again.
        one_hot_x_ba = torch.cat([x_ba, self.one_hot_img[d_index_a]], 1)
        one_hot_x_ab = torch.cat([x_ab, self.one_hot_img[d_index_b]], 1)
        c_b_recon, s_a_recon = self.gen.encode(one_hot_x_ba)
        c_a_recon, s_b_recon = self.gen.encode(one_hot_x_ab)

        # Forwarding through supervised model.
        p_a = None
        p_b = None
        loss_semi_a = None
        loss_semi_b = None

        has_a_label = (c_a[use_a, :, :, :].size(0) != 0)
        if has_a_label:
            p_a = self.sup(c_a, use_a, True)
            p_a_recon = self.sup(c_a_recon, use_a, True)
            loss_semi_a = self.semi_criterion(p_a, y_a[use_a, :, :]) + \
                          self.semi_criterion(p_a_recon, y_a[use_a, :, :])

        has_b_label = (c_b[use_b, :, :, :].size(0) != 0)
        if has_b_label:
            p_b = self.sup(c_b, use_b, True)
            p_b_recon = self.sup(c_b_recon, use_b, True)  # the original passed c_b again, apparently a copy-paste slip
            loss_semi_b = self.semi_criterion(p_b, y_b[use_b, :, :]) + \
                          self.semi_criterion(p_b_recon, y_b[use_b, :, :])

        self.loss_gen_total = None
        if loss_semi_a is not None and loss_semi_b is not None:
            self.loss_gen_total = loss_semi_a + loss_semi_b
        elif loss_semi_a is not None:
            self.loss_gen_total = loss_semi_a
        elif loss_semi_b is not None:
            self.loss_gen_total = loss_semi_b

        if self.loss_gen_total is not None:
            self.loss_gen_total.backward()
            self.gen_opt.step()

    def sup_forward(self, x, y, d_index, hyperparameters):

        self.sup.eval()

        # Encoding content image.
        one_hot_x = torch.cat([x, self.one_hot_img[d_index, 0].unsqueeze(0)],
                              1)
        content, _ = self.gen.encode(one_hot_x)

        # Forwarding on supervised model.
        y_pred = self.sup(content, only_prediction=True)

        # Computing metrics.
        pred = y_pred.data.max(1)[1].squeeze_(1).squeeze_(0).cpu().numpy()

        jacc = jaccard(pred, y.cpu().squeeze(0).numpy())

        return jacc, pred, content

    def gen_update(self, x_a, x_b, d_index_a, d_index_b, hyperparameters):

        self.gen_opt.zero_grad()

        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())

        # Encode.
        one_hot_x_a = torch.cat([x_a, self.one_hot_img[d_index_a]], 1)
        one_hot_x_b = torch.cat([x_b, self.one_hot_img[d_index_b]], 1)

        c_a, s_a_prime = self.gen.encode(one_hot_x_a)
        c_b, s_b_prime = self.gen.encode(one_hot_x_b)

        # Decode (within domain).
        one_hot_c_a = torch.cat([c_a, self.one_hot_c[d_index_a]], 1)
        one_hot_c_b = torch.cat([c_b, self.one_hot_c[d_index_b]], 1)
        x_a_recon = self.gen.decode(one_hot_c_a, s_a_prime)
        x_b_recon = self.gen.decode(one_hot_c_b, s_b_prime)

        # Decode (cross domain).
        one_hot_c_ab = torch.cat([c_a, self.one_hot_c[d_index_b]], 1)
        one_hot_c_ba = torch.cat([c_b, self.one_hot_c[d_index_a]], 1)
        x_ba = self.gen.decode(one_hot_c_ba, s_a)
        x_ab = self.gen.decode(one_hot_c_ab, s_b)

        # Encode again.
        one_hot_x_ba = torch.cat([x_ba, self.one_hot_img[d_index_a]], 1)
        one_hot_x_ab = torch.cat([x_ab, self.one_hot_img[d_index_b]], 1)
        c_b_recon, s_a_recon = self.gen.encode(one_hot_x_ba)
        c_a_recon, s_b_recon = self.gen.encode(one_hot_x_ab)

        # Decode again (if needed).
        one_hot_c_aba_recon = torch.cat([c_a_recon, self.one_hot_c[d_index_a]],
                                        1)
        one_hot_c_bab_recon = torch.cat([c_b_recon, self.one_hot_c[d_index_b]],
                                        1)
        x_aba = self.gen.decode(one_hot_c_aba_recon, s_a_prime)
        x_bab = self.gen.decode(one_hot_c_bab_recon, s_b_prime)

        # Reconstruction loss.
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)

        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a)
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b)

        # GAN loss.
        self.loss_gen_adv_a = self.dis.calc_gen_loss(one_hot_x_ba)
        self.loss_gen_adv_b = self.dis.calc_gen_loss(one_hot_x_ab)

        # Total loss.
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b

        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):

        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def sample(self, x_a, x_b):

        self.eval()
        x_a.volatile = True
        x_b.volatile = True
        s_a1 = Variable(self.s_a, volatile=True)
        s_b1 = Variable(self.s_b, volatile=True)
        s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda(),
                        volatile=True)
        s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda(),
                        volatile=True)
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        for i in range(x_a.size(0)):
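            # NOTE: self.one_hot_img_a / self.one_hot_img_b below are never
            # defined in __init__; the intended tensors are presumably slices
            # of self.one_hot_img selected by dataset index, as in forward().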

            one_hot_x_a = torch.cat(
                [x_a[i].unsqueeze(0), self.one_hot_img_a[i].unsqueeze(0)], 1)
            one_hot_x_b = torch.cat(
                [x_b[i].unsqueeze(0), self.one_hot_img_b[i].unsqueeze(0)], 1)

            c_a, s_a_fake = self.gen.encode(one_hot_x_a)
            c_b, s_b_fake = self.gen.encode(one_hot_x_b)
            x_a_recon.append(self.gen.decode(c_a, s_a_fake))
            x_b_recon.append(self.gen.decode(c_b, s_b_fake))
            x_ba1.append(self.gen.decode(c_b, s_a1[i].unsqueeze(0)))
            x_ba2.append(self.gen.decode(c_b, s_a2[i].unsqueeze(0)))
            x_ab1.append(self.gen.decode(c_a, s_b1[i].unsqueeze(0)))
            x_ab2.append(self.gen.decode(c_a, s_b2[i].unsqueeze(0)))

        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_update(self, x_a, x_b, d_index_a, d_index_b, hyperparameters):

        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())

        # Encode.
        one_hot_x_a = torch.cat([x_a, self.one_hot_img[d_index_a]], 1)
        one_hot_x_b = torch.cat([x_b, self.one_hot_img[d_index_b]], 1)

        c_a, _ = self.gen.encode(one_hot_x_a)
        c_b, _ = self.gen.encode(one_hot_x_b)

        one_hot_c_ba = torch.cat([c_b, self.one_hot_c[d_index_a]], 1)
        one_hot_c_ab = torch.cat([c_a, self.one_hot_c[d_index_b]], 1)

        # Decode (cross domain).
        x_ba = self.gen.decode(one_hot_c_ba, s_a)
        x_ab = self.gen.decode(one_hot_c_ab, s_b)

        # D loss.
        one_hot_x_ba = torch.cat([x_ba, self.one_hot_img[d_index_a]], 1)
        one_hot_x_ab = torch.cat([x_ab, self.one_hot_img[d_index_b]], 1)

        self.loss_dis_a = self.dis.calc_dis_loss(one_hot_x_ba, one_hot_x_a)
        self.loss_dis_b = self.dis.calc_dis_loss(one_hot_x_ab, one_hot_x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + \
                              hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):

        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):

        print("--> " + checkpoint_dir)

        # Load generator.
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen.load_state_dict(state_dict)
        epochs = int(last_model_name[-11:-3])

        # Load supervised model.
        last_model_name = get_model_list(checkpoint_dir, "sup")
        state_dict = torch.load(last_model_name)
        self.sup.load_state_dict(state_dict)

        # Load discriminator.
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis.load_state_dict(state_dict)

        # Load optimizers.
        last_model_name = get_model_list(checkpoint_dir, "opt")
        state_dict = torch.load(last_model_name)
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])

        for state in self.dis_opt.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()

        for state in self.gen_opt.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()

        # Reinitialize schedulers.
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           epochs)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           epochs)

        print('Resume from epoch %d' % epochs)
        return epochs

    def save(self, snapshot_dir, epoch):

        # Save generators, discriminators, and optimizers.
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % epoch)
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % epoch)
        sup_name = os.path.join(snapshot_dir, 'sup_%08d.pt' % epoch)
        opt_name = os.path.join(snapshot_dir, 'opt_%08d.pt' % epoch)

        torch.save(self.gen.state_dict(), gen_name)
        torch.save(self.dis.state_dict(), dis_name)
        torch.save(self.sup.state_dict(), sup_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
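
resume() above moves the restored optimizer state tensors back onto the GPU,
since torch.load can leave them on the CPU. A standalone sketch of that pattern
(the optimizer name and device here are placeholders):

import torch

def optimizer_state_to(optimizer, device='cuda'):
    """Move every tensor held in an optimizer's state to `device`."""
    for state in optimizer.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.to(device)

# Usage: after opt.load_state_dict(...), optimizer_state_to(opt) keeps Adam's
# moment buffers on the same device as the model parameters.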
Example #10
class UNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters, resume_epoch=-1, snapshot_dir=None):

        super(UNIT_Trainer, self).__init__()

        lr = hyperparameters['lr']

        # Initiate the networks.
        self.gen = VAEGen(
            hyperparameters['input_dim'] + hyperparameters['n_datasets'],
            hyperparameters['gen'],
            hyperparameters['n_datasets'])  # Auto-encoder for domain a.
        self.dis = MsImageDis(
            hyperparameters['input_dim'] + hyperparameters['n_datasets'],
            hyperparameters['dis'])  # Discriminator for domain a.

        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

        self.sup = UNet(input_channels=hyperparameters['input_dim'],
                        num_classes=2).cuda()

        # Setup the optimizers.
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis.parameters())
        gen_params = list(self.gen.parameters()) + list(self.sup.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization.
        self.apply(weights_init(hyperparameters['init']))
        self.dis.apply(weights_init('gaussian'))

        # Presetting one hot encoding vectors.
        self.one_hot_img = torch.zeros(hyperparameters['n_datasets'],
                                       hyperparameters['batch_size'],
                                       hyperparameters['n_datasets'], 256,
                                       256).cuda()
        self.one_hot_h = torch.zeros(hyperparameters['n_datasets'],
                                     hyperparameters['batch_size'],
                                     hyperparameters['n_datasets'], 64,
                                     64).cuda()

        for i in range(hyperparameters['n_datasets']):
            self.one_hot_img[i, :, i, :, :].fill_(1)
            self.one_hot_h[i, :, i, :, :].fill_(1)

        if resume_epoch != -1:

            self.resume(snapshot_dir, hyperparameters)

    def recon_criterion(self, input, target):

        return torch.mean(torch.abs(input - target))

    def semi_criterion(self, input, target):

        loss = CrossEntropyLoss2d(size_average=False).cuda()
        return loss(input, target)

    def forward(self, x_a, x_b, d_index_a, d_index_b):

        # The original body referenced self.gen_a / self.gen_b, which this
        # single-generator class never defines; rewritten to use self.gen with
        # the one-hot dataset labels, mirroring gen_update() below.
        self.eval()
        x_a.volatile = True
        x_b.volatile = True
        one_hot_x_a = torch.cat([x_a, self.one_hot_img[d_index_a]], 1)
        one_hot_x_b = torch.cat([x_b, self.one_hot_img[d_index_b]], 1)
        h_a, _ = self.gen.encode(one_hot_x_a)
        h_b, _ = self.gen.encode(one_hot_x_b)
        x_ba = self.gen.decode(torch.cat([h_b, self.one_hot_h[d_index_a]], 1))
        x_ab = self.gen.decode(torch.cat([h_a, self.one_hot_h[d_index_b]], 1))
        self.train()
        return x_ab, x_ba

    def __compute_kl(self, mu):

        # def _compute_kl(self, mu, sd):
        # mu_2 = torch.pow(mu, 2)
        # sd_2 = torch.pow(sd, 2)
        # encoding_loss = (mu_2 + sd_2 - torch.log(sd_2)).sum() / mu_2.size(0)
        # return encoding_loss

        mu_2 = torch.pow(mu, 2)
        encoding_loss = torch.mean(mu_2)
        return encoding_loss
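
    # With the encoder noise fixed at unit variance (sd = 1), the full Gaussian
    # KL term mu^2 + sd^2 - log(sd^2) in the commented-out version reduces, up
    # to constants, to mu^2; the simplified criterion therefore just penalizes
    # the mean of mu^2.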

    def set_gen_trainable(self, train_bool):

        if train_bool:
            self.gen.train()
            for param in self.gen.parameters():
                param.requires_grad = True

        else:
            self.gen.eval()
            for param in self.gen.parameters():
                param.requires_grad = False  # the original left True here, which never froze the generator

    def set_sup_trainable(self, train_bool):

        if train_bool:
            self.sup.train()
            for param in self.sup.parameters():
                param.requires_grad = True
        else:
            self.sup.eval()
            for param in self.sup.parameters():
                param.requires_grad = False  # same fix as in set_gen_trainable

    def sup_update(self, x_a, x_b, y_a, y_b, d_index_a, d_index_b, use_a,
                   use_b, hyperparameters):

        self.gen_opt.zero_grad()

        # Encode.
        one_hot_x_a = torch.cat([x_a, self.one_hot_img[d_index_a]], 1)
        one_hot_x_b = torch.cat([x_b, self.one_hot_img[d_index_b]], 1)
        h_a, n_a = self.gen.encode(one_hot_x_a)
        h_b, n_b = self.gen.encode(one_hot_x_b)

        # Decode (within domain).
        one_hot_h_a = torch.cat([h_a + n_a, self.one_hot_h[d_index_a]], 1)
        one_hot_h_b = torch.cat([h_b + n_b, self.one_hot_h[d_index_b]], 1)
        x_a_recon = self.gen.decode(one_hot_h_a)
        x_b_recon = self.gen.decode(one_hot_h_b)

        # Decode (cross domain).
        one_hot_h_ab = torch.cat([h_a + n_a, self.one_hot_h[d_index_b]], 1)
        one_hot_h_ba = torch.cat([h_b + n_b, self.one_hot_h[d_index_a]], 1)
        x_ba = self.gen.decode(one_hot_h_ba)
        x_ab = self.gen.decode(one_hot_h_ab)

        # Encode again.
        one_hot_x_ba = torch.cat([x_ba, self.one_hot_img[d_index_a]], 1)
        one_hot_x_ab = torch.cat([x_ab, self.one_hot_img[d_index_b]], 1)
        h_b_recon, n_b_recon = self.gen.encode(one_hot_x_ba)
        h_a_recon, n_a_recon = self.gen.encode(one_hot_x_ab)

        # Decode again (if needed).
        one_hot_h_a_recon = torch.cat(
            [h_a_recon + n_a_recon, self.one_hot_h[d_index_a]], 1)
        one_hot_h_b_recon = torch.cat(
            [h_b_recon + n_b_recon, self.one_hot_h[d_index_b]], 1)
        x_aba = self.gen.decode(
            one_hot_h_a_recon
        ) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen.decode(
            one_hot_h_b_recon
        ) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # Forwarding through supervised model.
        p_a = None
        p_b = None
        loss_semi_a = None
        loss_semi_b = None

        has_a_label = (h_a[use_a, :, :, :].size(0) != 0)
        if has_a_label:
            p_a = self.sup(h_a, use_a, True)
            p_a_recon = self.sup(h_a_recon, use_a, True)
            loss_semi_a = self.semi_criterion(p_a, y_a[use_a, :, :]) + \
                          self.semi_criterion(p_a_recon, y_a[use_a, :, :])

        has_b_label = (h_b[use_b, :, :, :].size(0) != 0)
        if has_b_label:
            p_b = self.sup(h_b, use_b, True)
            p_b_recon = self.sup(h_b_recon, use_b, True)
            loss_semi_b = self.semi_criterion(p_b, y_b[use_b, :, :]) + \
                          self.semi_criterion(p_b_recon, y_b[use_b, :, :])

        self.loss_gen_total = None
        if loss_semi_a is not None and loss_semi_b is not None:
            self.loss_gen_total = loss_semi_a + loss_semi_b
        elif loss_semi_a is not None:
            self.loss_gen_total = loss_semi_a
        elif loss_semi_b is not None:
            self.loss_gen_total = loss_semi_b

        if self.loss_gen_total is not None:
            self.loss_gen_total.backward()
            self.gen_opt.step()
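    # Note (added): only samples flagged by use_a / use_b contribute to the
    # semi-supervised loss; if a batch carries labels for neither domain,
    # loss_gen_total stays None and no generator step is taken.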

    def sup_forward(self, x, y, d_index, hyperparameters):

        self.sup.eval()

        # Encoding content image.
        one_hot_x = torch.cat([x, self.one_hot_img[d_index, 0].unsqueeze(0)],
                              1)
        hidden, _ = self.gen.encode(one_hot_x)

        # Forwarding on supervised model.
        y_pred = self.sup(hidden, only_prediction=True)

        # Computing metrics.
        pred = y_pred.data.max(1)[1].squeeze_(1).squeeze_(0).cpu().numpy()

        jacc = jaccard(pred, y.cpu().squeeze(0).numpy())

        return jacc, pred, hidden

    def gen_update(self, x_a, x_b, d_index_a, d_index_b, hyperparameters):

        self.gen_opt.zero_grad()

        # Encode.
        one_hot_x_a = torch.cat([x_a, self.one_hot_img[d_index_a]], 1)
        one_hot_x_b = torch.cat([x_b, self.one_hot_img[d_index_b]], 1)
        h_a, n_a = self.gen.encode(one_hot_x_a)
        h_b, n_b = self.gen.encode(one_hot_x_b)

        # Decode (within domain).
        one_hot_h_a = torch.cat([h_a + n_a, self.one_hot_h[d_index_a]], 1)
        one_hot_h_b = torch.cat([h_b + n_b, self.one_hot_h[d_index_b]], 1)
        x_a_recon = self.gen.decode(one_hot_h_a)
        x_b_recon = self.gen.decode(one_hot_h_b)

        # Decode (cross domain).
        one_hot_h_ab = torch.cat([h_a + n_a, self.one_hot_h[d_index_b]], 1)
        one_hot_h_ba = torch.cat([h_b + n_b, self.one_hot_h[d_index_a]], 1)
        x_ba = self.gen.decode(one_hot_h_ba)
        x_ab = self.gen.decode(one_hot_h_ab)

        # Encode again.
        one_hot_x_ba = torch.cat([x_ba, self.one_hot_img[d_index_a]], 1)
        one_hot_x_ab = torch.cat([x_ab, self.one_hot_img[d_index_b]], 1)
        h_b_recon, n_b_recon = self.gen.encode(one_hot_x_ba)
        h_a_recon, n_a_recon = self.gen.encode(one_hot_x_ab)

        # Decode again (if needed).
        one_hot_h_a_recon = torch.cat(
            [h_a_recon + n_a_recon, self.one_hot_h[d_index_a]], 1)
        one_hot_h_b_recon = torch.cat(
            [h_b_recon + n_b_recon, self.one_hot_h[d_index_b]], 1)
        x_aba = self.gen.decode(
            one_hot_h_a_recon
        ) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen.decode(
            one_hot_h_b_recon
        ) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # Reconstruction loss.
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_kl_a = self.__compute_kl(h_a)
        self.loss_gen_recon_kl_b = self.__compute_kl(h_b)
        self.loss_gen_cyc_x_a = self.recon_criterion(
            x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cyc_x_b = self.recon_criterion(
            x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_recon_kl_cyc_aba = self.__compute_kl(h_a_recon)
        self.loss_gen_recon_kl_cyc_bab = self.__compute_kl(h_b_recon)

        # GAN loss.
        self.loss_gen_adv_a = self.dis.calc_gen_loss(one_hot_x_ba)
        self.loss_gen_adv_b = self.dis.calc_gen_loss(one_hot_x_ab)

        # Total loss.
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_a + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_b + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab

        self.loss_gen_total.backward()
        self.gen_opt.step()
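    # Note (added): the total objective is a weighted sum of adversarial terms
    # (gan_w), within-domain reconstructions (recon_x_w), KL penalties on the
    # shared latent (recon_kl_w), and cycle terms (recon_x_cyc_w,
    # recon_kl_cyc_w); setting a weight to 0 disables its term.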

    def sample(self, x_a, x_b):

        self.eval()
        # volatile was removed in PyTorch 0.4; no_grad() is the replacement.
        with torch.no_grad():
            x_a_recon, x_b_recon, x_ba, x_ab = [], [], [], []
            for i in range(x_a.size(0)):
                h_a, _ = self.gen_a.encode(x_a[i].unsqueeze(0))
                h_b, _ = self.gen_b.encode(x_b[i].unsqueeze(0))
                x_a_recon.append(self.gen_a.decode(h_a))
                x_b_recon.append(self.gen_b.decode(h_b))
                x_ba.append(self.gen_a.decode(h_b))
                x_ab.append(self.gen_b.decode(h_a))
            x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
            x_ba = torch.cat(x_ba)
            x_ab = torch.cat(x_ab)
        self.train()
        return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba

    def dis_update(self, x_a, x_b, d_index_a, d_index_b, hyperparameters):

        self.dis_opt.zero_grad()

        # Encode.
        one_hot_x_a = torch.cat([x_a, self.one_hot_img[d_index_a]], 1)
        one_hot_x_b = torch.cat([x_b, self.one_hot_img[d_index_b]], 1)
        h_a, n_a = self.gen.encode(one_hot_x_a)
        h_b, n_b = self.gen.encode(one_hot_x_b)

        # Decode (cross domain).
        one_hot_h_ab = torch.cat([h_a + n_a, self.one_hot_h[d_index_b]], 1)
        one_hot_h_ba = torch.cat([h_b + n_b, self.one_hot_h[d_index_a]], 1)
        x_ba = self.gen.decode(one_hot_h_ba)
        x_ab = self.gen.decode(one_hot_h_ab)

        # D loss.
        one_hot_x_ba = torch.cat([x_ba, self.one_hot_img[d_index_a]], 1)
        one_hot_x_ab = torch.cat([x_ab, self.one_hot_img[d_index_b]], 1)
        self.loss_dis_a = self.dis.calc_dis_loss(one_hot_x_ba.detach(),
                                                 one_hot_x_a)
        self.loss_dis_b = self.dis.calc_dis_loss(one_hot_x_ab.detach(),
                                                 one_hot_x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + \
                              hyperparameters['gan_w'] * self.loss_dis_b

        self.loss_dis_total.backward()
        self.dis_opt.step()
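    # Note (added): real and translated images are concatenated with the same
    # one-hot dataset code before scoring, so the discriminator always compares
    # real/fake pairs under a matching domain tag.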

    def update_learning_rate(self):

        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):

        # Load generators.
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen.load_state_dict(state_dict)
        epochs = int(last_model_name[-11:-3])

        # Load discriminators.
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis.load_state_dict(state_dict)

        # Load supervised model.
        last_model_name = get_model_list(checkpoint_dir, "sup")
        state_dict = torch.load(last_model_name)
        self.sup.load_state_dict(state_dict)

        # Load optimizers.
        last_model_name = get_model_list(checkpoint_dir, "opt")
        state_dict = torch.load(last_model_name)
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])

        for state in self.dis_opt.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()

        for state in self.gen_opt.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()

        # Reinitialize schedulers.
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           epochs)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           epochs)

        print('Resume from iteration %d' % epochs)
        return epochs

    def save(self, snapshot_dir, epoch):

        # Save generators, discriminators, and optimizers.
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % epoch)
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % epoch)
        sup_name = os.path.join(snapshot_dir, 'sup_%08d.pt' % epoch)
        opt_name = os.path.join(snapshot_dir, 'opt_%08d.pt' % epoch)

        torch.save(self.gen.state_dict(), gen_name)
        torch.save(self.dis.state_dict(), dis_name)
        torch.save(self.sup.state_dict(), sup_name)
        torch.save(
            {
                'dis': self.dis_opt.state_dict(),
                'gen': self.gen_opt.state_dict()
            }, opt_name)
Example #11
0
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters["lr"]
        self.gen_state = hyperparameters["gen_state"]
        self.guided = hyperparameters["guided"]
        self.newsize = hyperparameters["crop_image_height"]
        self.semantic_w = hyperparameters["semantic_w"] > 0

        self.recon_mask = hyperparameters["recon_mask"] == 1
        self.check_alignment = hyperparameters["check_alignment"] == 1

        self.full_adaptation = hyperparameters["full_adaptation"] == 1
        self.dann_scheduler = None

        if "domain_adv_w" in hyperparameters.keys():
            self.domain_classif = hyperparameters["domain_adv_w"] > 0
        else:
            self.domain_classif = False
        if self.gen_state == 0:
            # Initiate the networks
            self.gen_a = AdaINGen(
                hyperparameters["input_dim_a"],
                hyperparameters["gen"])  # auto-encoder for domain a
            self.gen_b = AdaINGen(
                hyperparameters["input_dim_b"],
                hyperparameters["gen"])  # auto-encoder for domain b

        elif self.gen_state == 1:
            self.gen = AdaINGen_double(hyperparameters["input_dim_a"],
                                       hyperparameters["gen"])
        else:
            print("self.gen_state unknown value:", self.gen_state)

        self.dis_a = MsImageDis(
            hyperparameters["input_dim_a"],
            hyperparameters["dis"])  # discriminator for domain a

        self.dis_b = MsImageDis(
            hyperparameters["input_dim_b"],
            hyperparameters["dis"])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters["gen"]["style_dim"]

        # fix the noise used in sampling
        display_size = int(hyperparameters["display_size"])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        # Setup the optimizers
        beta1 = hyperparameters["beta1"]
        beta2 = hyperparameters["beta2"]

        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())

        if self.gen_state == 0:
            gen_params = list(self.gen_a.parameters()) + list(
                self.gen_b.parameters())
        elif self.gen_state == 1:
            gen_params = list(self.gen.parameters())
        else:
            print("self.gen_state unknown value:", self.gen_state)

        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters["weight_decay"],
        )
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters["weight_decay"],
        )
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters["init"]))
        self.dis_a.apply(weights_init("gaussian"))
        self.dis_b.apply(weights_init("gaussian"))

        # Load VGG model if needed
        if "vgg_w" in hyperparameters.keys() and hyperparameters["vgg_w"] > 0:
            self.vgg = load_vgg16(hyperparameters["vgg_model_path"] +
                                  "/models")
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

        # Load semantic segmentation model if needed
        if "semantic_w" in hyperparameters.keys(
        ) and hyperparameters["semantic_w"] > 0:
            self.segmentation_model = load_segmentation_model(
                hyperparameters["semantic_ckpt_path"])
            self.segmentation_model.eval()
            for param in self.segmentation_model.parameters():
                param.requires_grad = False

        # Load domain classifier if needed
        if ("domain_adv_w" in hyperparameters.keys()
                and hyperparameters["domain_adv_w"] > 0):
            self.domain_classifier = domainClassifier(256)
            dann_params = list(self.domain_classifier.parameters())
            self.dann_opt = torch.optim.Adam(
                [p for p in dann_params if p.requires_grad],
                lr=lr,
                betas=(beta1, beta2),
                weight_decay=hyperparameters["weight_decay"],
            )
            self.domain_classifier.apply(weights_init("gaussian"))
            self.dann_scheduler = get_scheduler(self.dann_opt, hyperparameters)

    def recon_criterion(self, input, target):
        """
        Compute pixelwise L1 loss between two images input and target 
        
        Arguments:
            input {torch.Tensor} -- Image tensor (original image such as x_a)
            target {torch.Tensor} -- Image tensor (after cycle-translation image x_aba)
        
        Returns:
            torch.Float -- pixelwise L1 loss
        """
        return torch.mean(torch.abs(input - target))

    def recon_criterion_mask(self, input, target, mask):
        """
        Compute a weaker version of the recon_criterion between two images input and target 
        where the L1 is only computed on the unmasked region
        
        Arguments:
            input {torch.Tensor} -- Image (original image such as x_a)
            target {torch.Tensor} -- Image (after cycle-translation image x_aba)
            mask {torch.Tensor} -- binary mask of size HxW (input.shape ~ CxHxW)
        
        Returns:
            torch.Float -- L1 loss over input.(1-mask) and target.(1-mask)
        """
        return torch.mean(torch.abs(torch.mul((input - target), 1 - mask)))
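    # Illustrative sketch (added, hypothetical tensors): pixels where mask == 1
    # contribute nothing, e.g.
    # >>> inp, tgt = torch.ones(1, 1, 2, 2), torch.zeros(1, 1, 2, 2)
    # >>> mask = torch.tensor([[[[1., 1.], [0., 0.]]]])
    # >>> trainer.recon_criterion_mask(inp, tgt, mask)  # -> tensor(0.5000)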

    def forward(self, x_a, x_b):
        """
        Perform the translation from domain A (resp B) to domain B (resp A): x_a to x_ab (resp: x_b to x_ba).
        
        Arguments:
            x_a {torch.Tensor} -- Image from domain A after transform in tensor format
            x_b {torch.Tensor} -- Image from domain B after transform in tensor format
        
        Returns:
            torch.Tensor, torch.Tensor -- Translated version of x_a in domain B, Translated version of x_b in domain A
        """
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        if self.gen_state == 0:
            c_a, s_a_fake = self.gen_a.encode(x_a)
            c_b, s_b_fake = self.gen_b.encode(x_b)
            x_ba = self.gen_a.decode(c_b, s_a)
            x_ab = self.gen_b.decode(c_a, s_b)
        elif self.gen_state == 1:
            c_a, s_a_fake = self.gen.encode(x_a, 1)
            c_b, s_b_fake = self.gen.encode(x_b, 2)
            x_ba = self.gen.decode(c_b, s_a, 1)
            x_ab = self.gen.decode(c_a, s_b, 2)
        else:
            print("self.gen_state unknown value:", self.gen_state)
        self.train()
        return x_ab, x_ba

    def gen_update(self,
                   x_a,
                   x_b,
                   hyperparameters,
                   mask_a=None,
                   mask_b=None,
                   comet_exp=None,
                   synth=0):
        """
        Update the generator parameters

        Arguments:
            x_a {torch.Tensor} -- Image from domain A after transform in tensor format
            x_b {torch.Tensor} -- Image from domain B after transform in tensor format
            hyperparameters {dict} -- dictionary with all hyperparameters

        Keyword Arguments:
            mask_a {torch.Tensor} -- binary mask (0,1) corresponding to the ground in x_a (default: {None})
            mask_b {torch.Tensor} -- binary mask (0,1) corresponding to the water in x_b (default: {None})
            comet_exp {comet_ml.Experiment} -- Comet ML experiment used to log all the losses and images (default: {None})
            synth {int} -- whether the batch is a synthetic aligned pair (default: {0})
        """
        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())

        if self.gen_state == 0:
            # encode
            c_a, s_a_prime = self.gen_a.encode(x_a)
            c_b, s_b_prime = self.gen_b.encode(x_b)
            # decode (within domain)
            x_a_recon = self.gen_a.decode(c_a, s_a_prime)
            x_b_recon = self.gen_b.decode(c_b, s_b_prime)
            # decode (cross domain)
            if self.guided == 0:
                x_ba = self.gen_a.decode(c_b, s_a)
                x_ab = self.gen_b.decode(c_a, s_b)
            elif self.guided == 1:
                x_ba = self.gen_a.decode(c_b, s_a_prime)
                x_ab = self.gen_b.decode(c_a, s_b_prime)
            else:
                print("self.guided unknown value:", self.guided)
            # encode again
            c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
            c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
            # decode again (if needed)
            x_aba = (self.gen_a.decode(c_a_recon, s_a_prime)
                     if hyperparameters["recon_x_cyc_w"] > 0 else None)
            x_bab = (self.gen_b.decode(c_b_recon, s_b_prime)
                     if hyperparameters["recon_x_cyc_w"] > 0 else None)
        elif self.gen_state == 1:
            # encode
            c_a, s_a_prime = self.gen.encode(x_a, 1)
            c_b, s_b_prime = self.gen.encode(x_b, 2)
            # decode (within domain)
            x_a_recon = self.gen.decode(c_a, s_a_prime, 1)
            x_b_recon = self.gen.decode(c_b, s_b_prime, 2)
            # decode (cross domain)
            if self.guided == 0:
                x_ba = self.gen.decode(c_b, s_a, 1)
                x_ab = self.gen.decode(c_a, s_b, 2)
            elif self.guided == 1:
                x_ba = self.gen.decode(c_b, s_a_prime, 1)
                x_ab = self.gen.decode(c_a, s_b_prime, 2)
            else:
                print("self.guided unknown value:", self.guided)

            # encode again
            c_b_recon, s_a_recon = self.gen.encode(x_ba, 1)
            c_a_recon, s_b_recon = self.gen.encode(x_ab, 2)
            # decode again (if needed)
            x_aba = (self.gen.decode(c_a_recon, s_a_prime, 1)
                     if hyperparameters["recon_x_cyc_w"] > 0 else None)
            x_bab = (self.gen.decode(c_b_recon, s_b_prime, 2)
                     if hyperparameters["recon_x_cyc_w"] > 0 else None)
        else:
            print("self.gen_state unknown value:", self.gen_state)

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)

        if self.guided == 0:
            self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
            self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)

        elif self.guided == 1:
            self.loss_gen_recon_s_a = self.recon_criterion(
                s_a_recon, s_a_prime)
            self.loss_gen_recon_s_b = self.recon_criterion(
                s_b_recon, s_b_prime)
        else:
            print("self.guided unknown value:", self.guided)

        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)

        # Synthetic reconstruction loss

        if self.check_alignment:
            # Mask of pixels that are exactly identical across the pair.
            mask_alignment = (torch.sum(torch.abs(x_a - x_b),
                                        1) == 0).unsqueeze(1)
            mask_alignment = mask_alignment.type(torch.cuda.FloatTensor)

        self.loss_gen_recon_synth = (
            self.recon_criterion_mask(x_ab, x_b, 1 - mask_alignment) +
            self.recon_criterion_mask(x_ba, x_a, 1 - mask_alignment)
        ) if self.check_alignment else 0
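        # Note (added): recon_criterion_mask averages |x - y| over 1 - mask, so
        # passing 1 - mask_alignment restricts this L1 term to exactly the
        # pixels that are identical between x_a and x_b.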

        if self.recon_mask:
            self.loss_gen_cycrecon_x_a = (self.recon_criterion_mask(
                x_aba, x_a, mask_a) if hyperparameters["recon_x_cyc_w"] > 0
                                          else 0)
            self.loss_gen_cycrecon_x_b = (self.recon_criterion_mask(
                x_bab, x_b, mask_b) if hyperparameters["recon_x_cyc_w"] > 0
                                          else 0)
        else:
            self.loss_gen_cycrecon_x_a = (self.recon_criterion(
                x_aba, x_a) if hyperparameters["recon_x_cyc_w"] > 0 else 0)
            self.loss_gen_cycrecon_x_b = (self.recon_criterion(
                x_bab, x_b) if hyperparameters["recon_x_cyc_w"] > 0 else 0)

        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = (self.compute_vgg_loss(self.vgg, x_ba, x_b)
                               if hyperparameters["vgg_w"] > 0 else 0)
        self.loss_gen_vgg_b = (self.compute_vgg_loss(self.vgg, x_ab, x_a)
                               if hyperparameters["vgg_w"] > 0 else 0)

        # semantic-segmentation loss
        self.loss_sem_seg = (self.compute_semantic_seg_loss(
            x_a.squeeze(), x_ab.squeeze(), mask_a) +
                             self.compute_semantic_seg_loss(
                                 x_b.squeeze(), x_ba.squeeze(), mask_b)
                             if hyperparameters["semantic_w"] > 0 else 0)
        # Domain adversarial loss: minimize=False asks the classifier output to
        # move towards 0.5/0.5, i.e. we want the content features to be
        # uninformative about the domain (minmax against the classifier).
        self.domain_adv_loss = (self.compute_domain_adv_loss(
            c_a, c_b, compute_accuracy=False, minimize=False)
                                if hyperparameters["domain_adv_w"] > 0 else 0)

        # total loss
        self.loss_gen_total = (
            hyperparameters["gan_w"] * self.loss_gen_adv_a +
            hyperparameters["gan_w"] * self.loss_gen_adv_b +
            hyperparameters["recon_x_w"] * self.loss_gen_recon_x_a +
            hyperparameters["recon_s_w"] * self.loss_gen_recon_s_a +
            hyperparameters["recon_c_w"] * self.loss_gen_recon_c_a +
            hyperparameters["recon_x_w"] * self.loss_gen_recon_x_b +
            hyperparameters["recon_s_w"] * self.loss_gen_recon_s_b +
            hyperparameters["recon_c_w"] * self.loss_gen_recon_c_b +
            hyperparameters["recon_x_cyc_w"] * self.loss_gen_cycrecon_x_a +
            hyperparameters["recon_x_cyc_w"] * self.loss_gen_cycrecon_x_b +
            hyperparameters["vgg_w"] * self.loss_gen_vgg_a +
            hyperparameters["vgg_w"] * self.loss_gen_vgg_b +
            hyperparameters["semantic_w"] * self.loss_sem_seg +
            hyperparameters["domain_adv_w"] * self.domain_adv_loss +
            hyperparameters["recon_synth_w"] * self.loss_gen_recon_synth)

        self.loss_gen_total.backward()
        self.gen_opt.step()

        if comet_exp is not None:
            comet_exp.log_metric("loss_gen_adv_a",
                                 self.loss_gen_adv_a.cpu().detach())
            comet_exp.log_metric("loss_gen_adv_b",
                                 self.loss_gen_adv_b.cpu().detach())
            comet_exp.log_metric("loss_gen_recon_x_a",
                                 self.loss_gen_recon_x_a.cpu().detach())
            comet_exp.log_metric("loss_gen_recon_s_a",
                                 self.loss_gen_recon_s_a.cpu().detach())
            comet_exp.log_metric("loss_gen_recon_c_a",
                                 self.loss_gen_recon_c_a.cpu().detach())
            comet_exp.log_metric("loss_gen_recon_x_b",
                                 self.loss_gen_recon_x_b.cpu().detach())
            comet_exp.log_metric("loss_gen_recon_s_b",
                                 self.loss_gen_recon_s_b.cpu().detach())
            comet_exp.log_metric("loss_gen_recon_c_b",
                                 self.loss_gen_recon_c_b.cpu().detach())
            comet_exp.log_metric("loss_gen_cycrecon_x_a",
                                 self.loss_gen_cycrecon_x_a.cpu().detach())
            comet_exp.log_metric("loss_gen_cycrecon_x_b",
                                 self.loss_gen_cycrecon_x_b.cpu().detach())
            comet_exp.log_metric("loss_gen_total",
                                 self.loss_gen_total.cpu().detach())
            if hyperparameters["vgg_w"] > 0:
                comet_exp.log_metric("loss_gen_vgg_a",
                                     self.loss_gen_vgg_a.cpu().detach())
                comet_exp.log_metric("loss_gen_vgg_b",
                                     self.loss_gen_vgg_b.cpu().detach())
            if hyperparameters["semantic_w"] > 0:
                comet_exp.log_metric("loss_sem_seg",
                                     self.loss_sem_seg.cpu().detach())
            if hyperparameters["domain_adv_w"] > 0:
                comet_exp.log_metric("domain_adv_loss_gen",
                                     self.domain_adv_loss.cpu().detach())
            if self.check_alignment:
                comet_exp.log_metric("loss_gen_recon_synth",
                                     self.loss_gen_recon_synth.cpu().detach())

    def compute_vgg_loss(self, vgg, img, target):
        """
        Compute the domain-invariant perceptual loss

        Arguments:
            vgg {nn.Module} -- pretrained VGG16 used as a fixed feature extractor
            img {torch.Tensor} -- image before translation
            target {torch.Tensor} -- image after translation

        Returns:
            torch.Float -- domain-invariant perceptual loss
        """
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)
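    # Note (added): instance-normalizing both VGG feature maps before the MSE
    # discards per-image mean/variance statistics, which is what makes this a
    # "domain-invariant" perceptual loss rather than a plain VGG feature loss.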

    def compute_domain_adv_loss(self,
                                c_a,
                                c_b,
                                compute_accuracy=False,
                                minimize=True):
        """ 
        Compute a domain adversarial loss on the embedding of the classifier:
        we are trying to learn an anonymized representation of the content. 
        
        Arguments:
            c_a {torch.tensor} -- content extracted from an image of domain A with encoder A
            c_b {torch.tensor} -- content extracted from an image of domain B with encoder B
        
        Keyword Arguments:
            compute_accuracy {bool} -- either return only the loss or the loss and softmax probs
            (default: {False})
            minimize {bool} -- optimize classification accuracy (True) or anonymize the representation (False)

        Returns:
            torch.Float -- loss (optional softmax P(classifier(c_a)=a) and P(classifier(c_b)=b))
        """
        # Infer domain classifier on content extracted from an image of domain A
        output_a = self.domain_classifier(c_a)

        # Infer domain classifier on content extracted from an image of domain B
        output_b = self.domain_classifier(c_b)

        # Concatenate the output in a single vector
        output = torch.cat((output_a, output_b))

        if minimize:
            target = torch.tensor([1., 0., 0., 1.], device='cuda')
        else:
            target = torch.tensor([0.5, 0.5, 0.5, 0.5], device='cuda')
        # mean square error loss
        loss = torch.nn.MSELoss()(output, target)
        if compute_accuracy:
            return loss, output_a[0], output_b[1]
        else:
            return loss
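    # Note (added): with a 2-way output per content code, the concatenated
    # vector has 4 entries; target [1, 0, 0, 1] trains the classifier to
    # identify the domains, while [0.5, 0.5, 0.5, 0.5] (minimize=False) pushes
    # it towards maximal confusion, the generator-side objective.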

    def compute_semantic_seg_loss(self, img1, img2, mask=None):
        """ 
        Compute semantic segmentation loss between two images on the unmasked region or in the entire image

        Arguments:
            img1 {torch.Tensor} -- Image from domain A after transform in tensor format
            img2 {torch.Tensor} -- Image transformed
            mask {torch.Tensor} -- Binary mask where we force the loss to be zero
        Returns:
            torch.float -- Cross entropy loss on the unmasked region
        """
        # denorm
        img1_denorm = (img1 + 1) / 2.0
        img2_denorm = (img2 + 1) / 2.0

        # norm for semantic seg network
        input_transformed1 = seg_batch_transform(img1_denorm)
        input_transformed2 = seg_batch_transform(img2_denorm)

        # compute labels from original image and logits from translated version
        target = (self.segmentation_model(input_transformed1).max(1)[1])
        output = self.segmentation_model(input_transformed2)

        if not self.full_adaptation and mask is not None:
            # Resize mask to the size of the image
            mask1 = torch.nn.functional.interpolate(mask,
                                                    size=(self.newsize,
                                                          self.newsize))
            mask1_tensor = mask1.long().cuda()
            mask1_tensor = mask1_tensor.squeeze(1)
            # we want the masked region to be labeled as unknown (19 is not an existing label)
            target_with_mask = torch.mul(1 - mask1_tensor,
                                         target) + mask1_tensor * 19

            mask_tensor = mask1.float().cuda()
            output_with_mask = torch.mul(1 - mask_tensor, output)

            # concatenate the mask to the logits (loss=0 over the masked region)
            output_with_mask_cat = torch.cat((output_with_mask, mask_tensor),
                                             dim=1)
            loss = nn.CrossEntropyLoss()(output_with_mask_cat,
                                         target_with_mask)
        else:
            loss = nn.CrossEntropyLoss()(output, target)
        return loss
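    # Note (added): on masked pixels the real logits are zeroed and the mask is
    # appended as the logit of class 19 ("unknown"), which is also the label the
    # target is forced to there; the cross-entropy on those pixels is then
    # constant with zero gradient w.r.t. the translated image, so the loss
    # effectively covers only the unmasked region.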

    def sample(self, x_a, x_b):
        """
        Infer the model on a batch of images

        Arguments:
            x_a {torch.Tensor} -- batch of images from domain A
            x_b {torch.Tensor} -- batch of images from domain B

        Returns:
            A list of torch images -- columnwise: x_a, autoencode(x_a), x_ab_1, x_ab_2
            Or, if self.semantic_w is true: x_a, autoencode(x_a), semantic segmentation of x_a,
            x_ab_1, semantic segmentation of x_ab_1, x_ab_2
        """
        self.eval()
        s_a1 = Variable(self.s_a)
        s_b1 = Variable(self.s_b)
        s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []

        if self.gen_state == 0:
            for i in range(x_a.size(0)):
                c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
                c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
                x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
                x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
                if self.guided == 0:
                    x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
                    x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
                    x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
                    x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
                elif self.guided == 1:
                    x_ba1.append(self.gen_a.decode(
                        c_b, s_a_fake))  # s_a1[i].unsqueeze(0)))
                    x_ba2.append(self.gen_a.decode(
                        c_b, s_a_fake))  # s_a2[i].unsqueeze(0)))
                    x_ab1.append(self.gen_b.decode(
                        c_a, s_b_fake))  # s_b1[i].unsqueeze(0)))
                    x_ab2.append(self.gen_b.decode(
                        c_a, s_b_fake))  # s_b2[i].unsqueeze(0)))
                else:
                    print("self.guided unknown value:", self.guided)

        elif self.gen_state == 1:
            for i in range(x_a.size(0)):
                c_a, s_a_fake = self.gen.encode(x_a[i].unsqueeze(0), 1)
                c_b, s_b_fake = self.gen.encode(x_b[i].unsqueeze(0), 2)
                x_a_recon.append(self.gen.decode(c_a, s_a_fake, 1))
                x_b_recon.append(self.gen.decode(c_b, s_b_fake, 2))
                if self.guided == 0:
                    x_ba1.append(self.gen.decode(c_b, s_a1[i].unsqueeze(0), 1))
                    x_ba2.append(self.gen.decode(c_b, s_a2[i].unsqueeze(0), 1))
                    x_ab1.append(self.gen.decode(c_a, s_b1[i].unsqueeze(0), 2))
                    x_ab2.append(self.gen.decode(c_a, s_b2[i].unsqueeze(0), 2))
                elif self.guided == 1:
                    x_ba1.append(self.gen.decode(c_b, s_a_fake,
                                                 1))  # s_a1[i].unsqueeze(0)))
                    x_ba2.append(self.gen.decode(c_b, s_a_fake,
                                                 1))  # s_a2[i].unsqueeze(0)))
                    x_ab1.append(self.gen.decode(c_a, s_b_fake,
                                                 2))  # s_b1[i].unsqueeze(0)))
                    x_ab2.append(self.gen.decode(c_a, s_b_fake,
                                                 2))  # s_b2[i].unsqueeze(0)))
                else:
                    print("self.guided unknown value:", self.guided)

        else:
            print("self.gen_state unknown value:", self.gen_state)

        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)

        if self.semantic_w:
            rgb_a_list, rgb_b_list, rgb_ab_list, rgb_ba_list = [], [], [], []

            for i in range(x_a.size(0)):

                # Inference semantic segmentation on original images
                im_a = (x_a[i].squeeze() + 1) / 2.0
                im_b = (x_b[i].squeeze() + 1) / 2.0

                input_transformed_a = seg_transform()(im_a).unsqueeze(0)
                input_transformed_b = seg_transform()(im_b).unsqueeze(0)
                output_a = (self.segmentation_model(
                    input_transformed_a).squeeze().max(0)[1])
                output_b = (self.segmentation_model(
                    input_transformed_b).squeeze().max(0)[1])

                rgb_a = decode_segmap(output_a.cpu().numpy())
                rgb_b = decode_segmap(output_b.cpu().numpy())
                rgb_a = Image.fromarray(rgb_a).resize(
                    (x_a.size(3), x_a.size(3)))
                rgb_b = Image.fromarray(rgb_b).resize(
                    (x_a.size(3), x_a.size(3)))

                rgb_a_list.append(transforms.ToTensor()(rgb_a).unsqueeze(0))
                rgb_b_list.append(transforms.ToTensor()(rgb_b).unsqueeze(0))

                # Inference semantic segmentation on fake images
                image_ab = (x_ab1[i].squeeze() + 1) / 2.0
                image_ba = (x_ba1[i].squeeze() + 1) / 2.0

                input_transformed_ab = seg_transform()(image_ab).unsqueeze(
                    0).to("cuda")
                input_transformed_ba = seg_transform()(image_ba).unsqueeze(
                    0).to("cuda")

                output_ab = (self.segmentation_model(
                    input_transformed_ab).squeeze().max(0)[1])
                output_ba = (self.segmentation_model(
                    input_transformed_ba).squeeze().max(0)[1])

                rgb_ab = decode_segmap(output_ab.cpu().numpy())
                rgb_ba = decode_segmap(output_ba.cpu().numpy())

                rgb_ab = Image.fromarray(rgb_ab).resize(
                    (x_a.size(3), x_a.size(3)))
                rgb_ba = Image.fromarray(rgb_ba).resize(
                    (x_a.size(3), x_a.size(3)))

                rgb_ab_list.append(transforms.ToTensor()(rgb_ab).unsqueeze(0))
                rgb_ba_list.append(transforms.ToTensor()(rgb_ba).unsqueeze(0))

            rgb1_a, rgb1_b, rgb1_ab, rgb1_ba = (
                torch.cat(rgb_a_list).cuda(),
                torch.cat(rgb_b_list).cuda(),
                torch.cat(rgb_ab_list).cuda(),
                torch.cat(rgb_ba_list).cuda(),
            )

        self.train()
        if self.semantic_w:
            self.segmentation_model.eval()
            return (
                x_a,
                x_a_recon,
                rgb1_a,
                x_ab1,
                rgb1_ab,
                x_ab2,
                x_b,
                x_b_recon,
                rgb1_b,
                x_ba1,
                rgb1_ba,
                x_ba2,
            )
        else:
            return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def sample_fid(self, x_a, x_b):
        """
        Infer the model on a batch of images for FID computation

        Arguments:
            x_a {torch.Tensor} -- batch of images from domain A
            x_b {torch.Tensor} -- batch of images from domain B

        Returns:
            torch.Tensor -- batch of guided translations x_ab1 (A to B)
        """
        self.eval()
        x_ab1 = []

        if self.gen_state == 0:
            for i in range(x_a.size(0)):
                c_a, _ = self.gen_a.encode(x_a[i].unsqueeze(0))
                _, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))

                if self.guided == 1:
                    x_ab1.append(self.gen_b.decode(c_a, s_b_fake))

                else:
                    print("self.guided unknown value:", self.guided)

        elif self.gen_state == 1:
            for i in range(x_a.size(0)):
                c_a, _ = self.gen.encode(x_a[i].unsqueeze(0), 1)
                _, s_b_fake = self.gen.encode(x_b[i].unsqueeze(0), 2)
                if self.guided == 1:
                    x_ab1.append(self.gen.decode(c_a, s_b_fake, 2))
                else:
                    print("self.guided unknown value:", self.guided)

        else:
            print("self.gen_state unknown value:", self.gen_state)

        x_ab1 = torch.cat(x_ab1)
        self.train()
        if self.semantic_w:
            self.segmentation_model.eval()

        return x_ab1

    def dis_update(self, x_a, x_b, hyperparameters, comet_exp=None):
        """
        Update the weights of the discriminator
        
        Arguments:
            x_a {torch.Tensor} -- Image from domain A after transform in tensor format
            x_b {torch.Tensor} -- Image from domain B after transform in tensor format
            hyperparameters {dict} -- dictionary with all hyperparameters
        
        Keyword Arguments:
            comet_exp {comet_ml.Experiment} -- Comet ML experiment used to log all the losses and images (default: {None})
        """
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        if self.gen_state == 0:
            # encode
            c_a, s_a_prime = self.gen_a.encode(x_a)
            c_b, s_b_prime = self.gen_b.encode(x_b)
            # decode (cross domain)
            if self.guided == 0:
                x_ba = self.gen_a.decode(c_b, s_a)
                x_ab = self.gen_b.decode(c_a, s_b)
            elif self.guided == 1:
                x_ba = self.gen_a.decode(c_b, s_a_prime)
                x_ab = self.gen_b.decode(c_a, s_b_prime)
            else:
                print("self.guided unknown value:", self.guided)
        elif self.gen_state == 1:
            # encode
            c_a, s_a_prime = self.gen.encode(x_a, 1)
            c_b, s_b_prime = self.gen.encode(x_b, 2)
            # decode (cross domain)
            if self.guided == 0:
                x_ba = self.gen.decode(c_b, s_a, 1)
                x_ab = self.gen.decode(c_a, s_b, 2)
            elif self.guided == 1:
                x_ba = self.gen.decode(c_b, s_a_prime, 1)
                x_ab = self.gen.decode(c_a, s_b_prime, 2)
            else:
                print("self.guided unknown value:", self.guided)
        else:
            print("self.gen_state unknown value:", self.gen_state)

        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)

        self.loss_dis_total = (hyperparameters["gan_w"] * self.loss_dis_a +
                               hyperparameters["gan_w"] * self.loss_dis_b)
        self.loss_dis_total.backward()
        self.dis_opt.step()

        if comet_exp is not None:
            comet_exp.log_metric("loss_dis_b", self.loss_dis_b.cpu().detach())
            comet_exp.log_metric("loss_dis_a", self.loss_dis_a.cpu().detach())

    def domain_classifier_update(self,
                                 x_a,
                                 x_b,
                                 hyperparameters,
                                 comet_exp=None):
        """
        Update the weights of the domain classifier
        
        Arguments:
            x_a {torch.Tensor} -- Image from domain A after transform in tensor format
            x_b {torch.Tensor} -- Image from domain B after transform in tensor format
            hyperparameters {dict} -- dictionary with all hyperparameters
        
        Keyword Arguments:
            comet_exp {comet_ml.Experiment} -- Comet ML experiment used to log all the losses and images (default: {None})
        """
        self.dann_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        if self.gen_state == 0:
            # encode
            c_a, _ = self.gen_a.encode(x_a)
            c_b, _ = self.gen_b.encode(x_b)
        elif self.gen_state == 1:
            # encode
            c_a, _ = self.gen.encode(x_a, 1)
            c_b, _ = self.gen.encode(x_b, 2)
        else:
            print("self.gen_state unknown value:", self.gen_state)

        # domain classifier loss
        self.domain_class_loss, out_a, out_b = self.compute_domain_adv_loss(
            c_a, c_b, compute_accuracy=True, minimize=True)

        self.domain_class_loss.backward()
        self.dann_opt.step()

        if comet_exp is not None:
            comet_exp.log_metric("domain_class_loss",
                                 self.domain_class_loss.cpu().detach())
            comet_exp.log_metric("probability A being identified as A",
                                 out_a.cpu().detach())
            comet_exp.log_metric("probability B being identified as B",
                                 out_b.cpu().detach())

    def update_learning_rate(self):
        """ 
        Update the learning rate
        """
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()
        if self.dann_scheduler is not None:
            self.dann_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """
        Resume the training loading the network parameters
        
        Arguments:
            checkpoint_dir {string} -- path to the directory where the checkpoints are saved
            hyperparameters {dict} -- dictionary with all hyperparameters
        
        Returns:
            int -- number of iterations (used by the optimizer)
        """
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        if self.gen_state == 0:
            self.gen_a.load_state_dict(state_dict["a"])
            self.gen_b.load_state_dict(state_dict["b"])
        elif self.gen_state == 1:
            self.gen.load_state_dict(state_dict["2"])
        else:
            print("self.gen_state unknown value:", self.gen_state)

        # Load domain classifier
        if self.domain_classif:
            last_model_name = get_model_list(checkpoint_dir, "domain_classif")
            state_dict = torch.load(last_model_name)
            self.domain_classifier.load_state_dict(state_dict["d"])

        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict["a"])
        self.dis_b.load_state_dict(state_dict["b"])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, "optimizer.pt"))
        self.dis_opt.load_state_dict(state_dict["dis"])
        self.gen_opt.load_state_dict(state_dict["gen"])

        if self.domain_classif:
            self.dann_opt.load_state_dict(state_dict["dann"])
            self.dann_scheduler = get_scheduler(self.dann_opt, hyperparameters,
                                                iterations)
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print("Resume from iteration %d" % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """
        Save generators, discriminators, and optimizers
        
        Arguments:
            snapshot_dir {string} -- directory path where to save the networks weights
            iterations {int} -- number of training iterations
        """
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, "gen_%08d.pt" % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, "dis_%08d.pt" % (iterations + 1))
        domain_classifier_name = os.path.join(
            snapshot_dir, "domain_classifier_%08d.pt" % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, "optimizer.pt")
        if self.gen_state == 0:
            torch.save(
                {
                    "a": self.gen_a.state_dict(),
                    "b": self.gen_b.state_dict()
                }, gen_name)
        elif self.gen_state == 1:
            torch.save({"2": self.gen.state_dict()}, gen_name)
        else:
            print("self.gen_state unknown value:", self.gen_state)
        torch.save({
            "a": self.dis_a.state_dict(),
            "b": self.dis_b.state_dict()
        }, dis_name)
        if self.domain_classif:
            torch.save({"d": self.domain_classifier.state_dict()},
                       domain_classifier_name)
            torch.save(
                {
                    "gen": self.gen_opt.state_dict(),
                    "dis": self.dis_opt.state_dict(),
                    "dann": self.dann_opt.state_dict(),
                },
                opt_name,
            )
        else:
            torch.save(
                {
                    "gen": self.gen_opt.state_dict(),
                    "dis": self.dis_opt.state_dict()
                },
                opt_name,
            )
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']

        # Initiate the networks
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

        self.style_dim = hyperparameters['gen']['style_dim']

        # fix the noise used in sampling
        self.s_a = torch.randn(8, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(8, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            eps=1e-8,
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            eps=1e-8,
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        self.iter = 0

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def intrinsic_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def volumeloss_criterion(self, input, target):
        idx_select = torch.tensor([0]).cuda()
        input, target = input.index_select(1, idx_select), target.index_select(
            1, idx_select)
        input, target = torch.mean(input, 3), torch.mean(target, 3)
        input, target = torch.mean(input, 2), torch.mean(target, 2)
        return torch.mean(torch.abs(input - target))
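    # Note (added): channel 0 holds the spectrogram; averaging it over the
    # frequency (dim 2) and time (dim 3) axes leaves one overall energy value
    # per sample, so this term matches the global "volume" of input and target.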

    def forward(self, x_a, x_b):
        with torch.no_grad():
            self.eval()
            s_a = Variable(self.s_a)
            s_b = Variable(self.s_b)
            c_a, s_a_fake = self.gen_a.encode(x_a)
            c_b, s_b_fake = self.gen_b.encode(x_b)
            x_ba = self.gen_a.decode(c_b, s_a)
            x_ab = self.gen_b.decode(c_a, s_b)
            self.train()
            return x_ab, x_ba

    def update_iter(self):
        self.iter += 1

    def gen_update(self,
                   x_a,
                   x_b,
                   hyperparameters,
                   x_a_rand=None,
                   x_b_rand=None):
        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(
            c_a_recon,
            s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(
            c_b_recon,
            s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba, x_a_rand)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab, x_b_rand)
        # ceps loss
        self.loss_gen_ceps_a = self.calc_cepstrum_loss(
            x_ba) if hyperparameters['ceps_w'] > 0 else 0
        self.loss_gen_ceps_b = self.calc_cepstrum_loss(
            x_ab) if hyperparameters['ceps_w'] > 0 else 0
        # flux loss
        self.loss_gen_flux_a2b = self.calc_spectral_flux_loss(
            x_ab) if hyperparameters['flux_w'] > 0 else 0
        self.loss_gen_flux_b2a = self.calc_spectral_flux_loss(
            x_ba) if hyperparameters['flux_w'] > 0 else 0
        # enve loss
        self.loss_gen_enve_a2b = self.calc_spectral_enve15_loss(
            x_ab) if hyperparameters['enve_w'] > 0 else 0
        self.loss_gen_enve_b2a = self.calc_spectral_enve15_loss(
            x_ba) if hyperparameters['enve_w'] > 0 else 0
        # volume loss
        self.loss_gen_vol_a = self.volumeloss_criterion(
            x_a, x_ab) if hyperparameters['vol_w'] > 0 else 0
        self.loss_gen_vol_b = self.volumeloss_criterion(
            x_b, x_ba) if hyperparameters['vol_w'] > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['ceps_w'] * self.loss_gen_ceps_a + \
                              hyperparameters['ceps_w'] * self.loss_gen_ceps_b + \
                              hyperparameters['flux_w'] * self.loss_gen_flux_a2b + \
                              hyperparameters['flux_w'] * self.loss_gen_flux_b2a + \
                              hyperparameters['enve_w'] * self.loss_gen_enve_a2b + \
                              hyperparameters['enve_w'] * self.loss_gen_enve_b2a + \
                              hyperparameters['vol_w'] * self.loss_gen_vol_a + \
                              hyperparameters['vol_w'] * self.loss_gen_vol_b
        self.loss_gen_total.backward()
        if hyperparameters['clip_grad'] == 'value':
            torch.nn.utils.clip_grad_value_(
                list(self.gen_a.parameters()) + list(self.gen_b.parameters()),
                1)
        elif hyperparameters['clip_grad'] == 'norm':
            torch.nn.utils.clip_grad_norm_(
                list(self.gen_a.parameters()) + list(self.gen_b.parameters()),
                0.5)
        self.gen_opt.step()

    def calc_cepstrum_loss(self, x_fake):
        idx_select_spec = torch.tensor([0]).cuda()
        idx_select_ceps = torch.tensor([1]).cuda()

        fake_spec = x_fake.index_select(
            1, idx_select_spec).detach().cpu().numpy()
        ceps = scipy.fftpack.dct(fake_spec, axis=2, type=2, norm='ortho')
        ceps = np.maximum(ceps, 0)
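        # The DCT-II cepstrum of the detached spectrogram channel is a fixed
        # target, so gradients only flow through the predicted cepstrum channel.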
        return self.intrinsic_criterion(
            x_fake.index_select(1, idx_select_ceps),
            torch.from_numpy(ceps).cuda())

    def calc_spectral_flux_loss(self, x_fake):
        idx_select_spec = torch.tensor([0]).cuda()
        idx_select_flux = torch.tensor([2]).cuda()

        fake_spec = x_fake.index_select(
            1, idx_select_spec).detach().cpu().numpy()
        spec_flux = np.zeros_like(fake_spec)
        hei, wid = 256, 256
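        # Spectral-flux target: half-wave-rectified central difference along
        # the last axis, with the edge columns copied as padding.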
        for i in range(1, wid - 1):
            spec_flux[:, :, :, i] = np.maximum(
                fake_spec[:, :, :, i + 1] - fake_spec[:, :, :, i - 1], 0.0)
        spec_flux[:, :, :, 0] = spec_flux[:, :, :, 1]
        spec_flux[:, :, :, -1] = spec_flux[:, :, :, -2]
        return self.intrinsic_criterion(
            x_fake.index_select(1, idx_select_flux),
            torch.from_numpy(spec_flux).cuda())

    def calc_spectral_enve15_loss(self, x_fake):
        idx_select_spec = torch.tensor([0]).cuda()
        idx_select_enve = torch.tensor([3]).cuda()

        fake_spec = x_fake.index_select(
            1, idx_select_spec).detach().cpu().numpy()
        MFCC = scipy.fftpack.dct(fake_spec, axis=2, type=2, norm='ortho')
        MFCC[:, :, 15:, :] = 0.0
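        # Keeping only the first 15 cepstral coefficients and inverting the
        # DCT gives a smoothed spectral-envelope target (hence "enve15").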
        spec_enve = scipy.fftpack.idct(MFCC, axis=2, type=2, norm='ortho')
        spec_enve = np.maximum(spec_enve, 0.0)
        return self.intrinsic_criterion(
            x_fake.index_select(1, idx_select_enve),
            torch.from_numpy(spec_enve).cuda())

    def sample(self, x_a, x_b):
        with torch.no_grad():
            self.eval()
            s_a1 = Variable(self.s_a)
            s_b1 = Variable(self.s_b)
            s_a2 = Variable(
                torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
            s_b2 = Variable(
                torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
            x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
            for i in range(x_a.size(0)):
                c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
                c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
                x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
                x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
                x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
                x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
                x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
                x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
            x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
            x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
            x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
            self.train()
            return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters[
            'gan_w'] * self.loss_dis_a + hyperparameters[
                'gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        #self.save_grad(list(self.dis_a.named_parameters()) + list(self.dis_b.named_parameters()))
        #torch.nn.utils.clip_grad_norm_(list(self.dis_a.parameters()) + list(self.dis_b.parameters()), 0.5)
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
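
These trainers are driven by an outer loop that alternates discriminator and generator updates. A minimal sketch of such a loop for the trainer above, assuming a hypothetical config dict holding the hyperparameter keys used here and placeholder data loaders loader_a / loader_b (none of these names come from the source):

trainer = MUNIT_Trainer(config)
trainer.cuda()
for it, (x_a, x_b) in enumerate(zip(loader_a, loader_b)):
    x_a, x_b = x_a.cuda(), x_b.cuda()
    trainer.dis_update(x_a, x_b, config)  # discriminator step
    trainer.gen_update(x_a, x_b, config)  # generator step; x_a_rand/x_b_rand are optional
    trainer.update_learning_rate()        # step the LR schedulers
    if (it + 1) % 10000 == 0:             # snapshot interval is arbitrary here
        trainer.save('outputs/snapshots', it)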
Example #13
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b
        self.dis_sa = MsImageDis(
            hyperparameters['input_dim_a'] * 2,
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_sb = MsImageDis(
            hyperparameters['input_dim_b'] * 2,
            hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        dis_style_params = list(self.dis_sa.parameters()) + list(
            self.dis_sb.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_style_opt = torch.optim.Adam(
            [p for p in dis_style_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.dis_style_scheduler = get_scheduler(self.dis_style_opt,
                                                 hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_sa.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))
        self.dis_sb.apply(weights_init('gaussian'))
        if hyperparameters['gen']['CE_method'] == 'vgg':
            self.gen_a.content_init()
            self.gen_b.content_init()
        self.criterion = nn.L1Loss().cuda()
        self.triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2).cuda()
        self.kld = nn.KLDivLoss()
        self.contextual_loss = ContextualLoss()

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def kl_loss(self, input, target):
        #return torch.mean(torch.abs(self.kld(input, target)))
        return torch.mean(self.kld(input, target))

    def normalize_feat(self, feat):
        bs, c, H, W = feat.shape
        feat = feat.view(bs, c, -1)
        feat_norm = torch.norm(feat, 2, 1,
                               keepdim=True) + sys.float_info.epsilon
        feat = torch.div(feat, feat_norm)
        #print(max(feat))
        return feat

    def norm_two_domain(self, feat_c, feat_s):
        feat = torch.cat((feat_c, feat_s), 1)
        bs, c, H, W = feat.shape
        feat_norm = torch.norm(feat, 2, 1, keepdim=True)
        feat = torch.div(feat, feat_norm)
        feat_c = feat[:, 0:256, :, :].view(bs, 256, -1)
        feat_s = feat[:, 256:512, :, :].view(bs, 256, -1)
        return feat_c, feat_s

    def generate_map(self, corr_index, h, w):
        coor = []
        corr_map = []
        for i in range(len(corr_index)):
            x = corr_index[i] // h
            y = corr_index[i] % w
            coor.append(x)
            coor.append(y)
            corr_map.append(list(np.asarray(coor)))
            coor.clear()
        corr_map_final = np.reshape(np.asarray(corr_map), (h, w, 2))
        return corr_map_final

    def warp_img(self, corr_map, ref_img):
        bs, c, h_img, w_img = ref_img.shape
        h, w, _ = corr_map.shape
        scale = h_img // h
        warped_img = torch.zeros_like(ref_img)  # allocate on ref_img's device
        for i in range(h):
            for j in range(w):
                nnx = corr_map[i][j][0]
                nny = corr_map[i][j][1]
                warped_img[:, :, i * scale:(i + 1) * scale, j * scale:(j + 1) *
                           scale] = ref_img[:, :,
                                            nnx * scale:(nnx + 1) * scale,
                                            nny * scale:(nny + 1) * scale]
        return warped_img.cuda()

    def warp_style(self, cur_content, ref_content, ref_style):
        # normalize feature
        cur_content = self.normalize_feat(cur_content)
        ref_content = self.normalize_feat(ref_content)
        #cur_content, ref_content = self.norm_two_domain(cur_content, ref_content)
        cur_content = cur_content.permute(0, 2, 1)

        # calculate similarity
        f = torch.matmul(cur_content, ref_content)  # 1 x (H x W) x (H x W)
        f_corr = F.softmax(f / 0.005, dim=-1)  # 1 x (H x W) x (H x W)
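        # 0.005 acts as a softmax temperature: dividing by it sharpens the
        # correspondence so each position attends to its closest matches.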
        #f_corr = F.softmax(f, dim=-1) # 1 x (H x W) x (H x W)

        # get corr index replace softmax
        bs, HW, WH = f_corr.shape
        corr_index = torch.argmax(f_corr, dim=-1).squeeze(0)

        # collect ref style
        bs, c, H, W = ref_style.shape
        ref_style = ref_style.view(bs, c, -1)
        ref_style = ref_style.permute(0, 2, 1)  # 1 x (H x W) x c

        # warp ref style
        warped_style = torch.matmul(f_corr, ref_style)  # 1 x (H x W) x c
        warped_style = warped_style.permute(0, 2, 1).contiguous()
        warped_style = warped_style.view(bs, c, H, W)

        return corr_index, warped_style

    def forward(self, x_a, x_b):
        self.eval()
        c_a, s_a_fake = self.gen_a.encode(x_a)
        c_b, s_b_fake = self.gen_b.encode(x_b)

        # warp the ref_style to the content_style
        _, s_ab_warp = self.warp_style(c_a, c_b, s_b_fake)
        _, s_ba_warp = self.warp_style(c_b, c_a, s_a_fake)

        x_ba = self.gen_a.decode(s_ba_warp, c_b)
        x_ab = self.gen_b.decode(s_ab_warp, c_a)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, x_adf, x_bdf, hyperparameters):
        self.gen_opt.zero_grad()
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)

        # add style warp here
        _, s_ab = self.warp_style(c_a, c_b, s_b_prime)
        _, s_ba = self.warp_style(c_b, c_a, s_a_prime)

        # decode (within domain)
        x_a_recon = self.gen_a.decode(s_a_prime, c_a)
        x_b_recon = self.gen_b.decode(s_b_prime, c_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(s_ba, c_b)
        x_ab = self.gen_b.decode(s_ab, c_a)
        # encode again
        c_b_recon, s_ba_recon = self.gen_a.encode(
            x_ba)  # s_ba_recon now matches the structure of B
        c_a_recon, s_ab_recon = self.gen_b.encode(x_ab)

        # decode again (if needed)
        # to warp style first
        _, s_aba = self.warp_style(c_a, c_b, s_ba_recon)
        _, s_bab = self.warp_style(c_b, c_a, s_ab_recon)
        # to reconstruct then
        x_aba = self.gen_a.decode(
            s_a_prime,
            c_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(
            s_b_prime,
            c_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # prepare paired data for adv generator
        #pair_a_ffake = torch.cat((x_ba, x_a), 1)
        #pair_b_ffake = torch.cat((x_ab, x_b), 1)

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_aba, s_a_prime)
        self.loss_gen_recon_s_b = self.recon_criterion(
            s_bab, s_b_prime)  # default is s_bab, need to test s_b_recon
        #self.loss_gen_recon_s_a += self.triplet_loss(s_a_prime, s_aba, s_b_prime)
        #self.loss_gen_recon_s_b += self.triplet_loss(s_b_prime, s_bab, s_a_prime)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        #self.loss_gen_kl_ab = self.kl_loss(x_ab, x_b)
        #self.loss_gen_kl_ba = self.kl_loss(x_ba, x_a)
        self.loss_gen_cx_a = self.contextual_loss(s_ba, s_a_prime)
        self.loss_gen_cx_b = self.contextual_loss(s_ab, s_b_prime)
        self.loss_gen_cycrecon_x_a = self.criterion(
            x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.criterion(
            x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0

        # GAN loss
        self.loss_gen_adv_xa = self.gen_a.calc_gen_loss(
            self.dis_a.forward(x_ba))
        self.loss_gen_adv_xb = self.gen_b.calc_gen_loss(
            self.dis_b.forward(x_ab))
        #self.loss_gen_adv_sxa = self.gen_a.calc_gen_loss(self.dis_sa.forward(pair_a_ffake))
        #self.loss_gen_adv_sxb = self.gen_b.calc_gen_loss(self.dis_sb.forward(pair_b_ffake))

        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss_new(
            self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss_new(
            self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_xa + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_xb + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b
        #hyperparameters['recon_kl_w'] * self.loss_gen_kl_ab + \
        #hyperparameters['recon_kl_w'] * self.loss_gen_kl_ba + \
        #hyperparameters['recon_cx_w'] * self.loss_gen_cx_a + \
        #hyperparameters['recon_cx_w'] * self.loss_gen_cx_b + \
        #hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_s_a + \
        #hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_s_b + \
        #hyperparameters['gan_wp'] * self.loss_gen_adv_sxa + \
        #hyperparameters['gan_wp'] * self.loss_gen_adv_sxb + \
        self.loss_gen_total.backward()

        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def compute_vgg_loss_new(self, vgg, img, target):
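        # Unlike compute_vgg_loss above, this variant compares raw VGG
        # features with an L1 criterion instead of instance-normalized MSE.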
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_feat = vgg(img_vgg)
        target_feat = vgg(target_vgg)
        return self.recon_criterion(img_feat, target_feat)

    def sample(self, x_a, x_b, x_adf, x_bdf):
        self.eval()
        #s_a1 = Variable(self.s_a)
        #s_b1 = Variable(self.s_b)
        #s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        #s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        h = w = 64
        for i in range(x_a.size(0)):
            img_a = x_a[i].unsqueeze(0)
            img_b = x_b[i].unsqueeze(0)
            c_a, s_a_fake = self.gen_a.encode(img_a)
            c_b, s_b_fake = self.gen_b.encode(img_b)
            # reconstruction
            x_a_recon.append(self.gen_a.decode(s_a_fake, c_a))
            x_b_recon.append(self.gen_b.decode(s_b_fake, c_b))
            # print(x_a_recon[0].shape)  # debug output
            # warp style
            corr_index_ab, s_ab = self.warp_style(c_a, c_b, s_b_fake)
            corr_index_ba, s_ba = self.warp_style(c_b, c_a, s_a_fake)
            # cross domain construction
            x_ba1.append(self.gen_a.decode(s_ba, c_b))
            ## output warped results x_ba2
            corr_map_ba = self.generate_map(corr_index_ba, h, w)
            x_ba2.append(self.warp_img(corr_map_ba, img_a))

            x_ab1.append(self.gen_b.decode(s_ab, c_a))
            ## output warped results x_ab2
            corr_map_ab = self.generate_map(corr_index_ab, h, w)
            x_ab2.append(self.warp_img(corr_map_ab, img_b))

        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_adf, x_a_recon, x_ab1, x_ab2, x_b, x_bdf, x_b_recon, x_ba1, x_ba2

    def dis_update(self, x_a, x_b, x_adf, x_bdf, hyperparameters):

        self.dis_opt.zero_grad()
        self.dis_style_opt.zero_grad()
        #s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        #s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, s_a = self.gen_a.encode(x_a)
        c_b, s_b = self.gen_b.encode(x_b)

        # warp the style here
        _, s_ab_warp = self.warp_style(c_a, c_b, s_b)
        _, s_ba_warp = self.warp_style(c_b, c_a, s_a)

        # decode (cross domain)
        x_ba = self.gen_a.decode(s_ba_warp, c_b)
        x_ab = self.gen_b.decode(s_ab_warp, c_a)

        # prepare data for the paired discriminator
        # real fake data -> 0
        if (len(self.dis_sa.pool_) == 0):
            print(len(self.dis_sa.pool_))
            pair_a_rfake = torch.cat((x_b, x_a), 1)
        else:
            pair_a_rfake = torch.cat((self.dis_sa.pool('fetch'), x_a), 1)
        self.dis_sa.pool('push', x_a)

        if (len(self.dis_sb.pool_) == 0):
            print(len(self.dis_sb.pool_))
            pair_b_rfake = torch.cat((x_a, x_b), 1)
        else:
            pair_b_rfake = torch.cat((self.dis_sb.pool('fetch'), x_b), 1)
        self.dis_sb.pool('push', x_b)

        # real real data -> 1
        pair_a_rreal = torch.cat((x_a, x_adf), 1)
        pair_b_rreal = torch.cat((x_b, x_bdf), 1)
        # fake fake data -> 0
        #pair_a_ffake = torch.cat((x_ba.detach(), x_a), 1)
        #pair_b_ffake = torch.cat((x_ab.detach(), x_b), 1)

        # D loss
        self.loss_dis_xa = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_xb = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        #self.loss_dis_xa = self.dis_a.calc_dis_loss(x_ba.detach(), self.dis_sa.pool('fetch'))
        #self.loss_dis_xb = self.dis_b.calc_dis_loss(x_ab.detach(), self.dis_sb.pool('fetch'))
        #self.loss_dis_sxa = (self.dis_sa.calc_dis_loss(pair_a_rfake, pair_a_rreal) + self.dis_sa.calc_dis_loss(pair_a_ffake, pair_a_rreal)) / 2
        #self.loss_dis_sxb = (self.dis_sb.calc_dis_loss(pair_b_rfake, pair_b_rreal) + self.dis_sb.calc_dis_loss(pair_b_ffake, pair_b_rreal)) / 2
        #self.loss_dis_sxa = self.dis_sa.calc_dis_loss(pair_a_ffake, pair_a_rreal)
        #self.loss_dis_sxb = self.dis_sb.calc_dis_loss(pair_b_ffake, pair_b_rreal)

        #self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_xa + hyperparameters['gan_w'] * self.loss_dis_xb + hyperparameters['gan_wp'] * self.loss_dis_sxa + hyperparameters['gan_wp'] * self.loss_dis_sxb
        self.loss_dis_total = hyperparameters[
            'gan_w'] * self.loss_dis_xa + hyperparameters[
                'gan_w'] * self.loss_dis_xb
        self.loss_dis_total.backward()
        self.dis_opt.step()
        self.dis_style_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.dis_style_scheduler is not None:
            self.dis_style_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.dis_style_scheduler = get_scheduler(self.dis_style_opt,
                                                 hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
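
The distinctive piece of Example #13 is warp_style, which transfers reference style features to the current content through a soft correspondence (an attention map over spatial positions). A self-contained sketch of the same pattern, with illustrative shapes and the same 0.005 temperature; inputs are assumed to be channel-wise L2-normalized as in normalize_feat:

import torch
import torch.nn.functional as F

def warp_reference(cur_feat, ref_feat, ref_style, tau=0.005):
    # cur_feat, ref_feat: (1, C, H, W) normalized content features
    # ref_style:          (1, C, H, W) style features to be warped
    bs, c, h, w = cur_feat.shape
    q = cur_feat.flatten(2).permute(0, 2, 1)         # (1, HW, C)
    k = ref_feat.flatten(2)                          # (1, C, HW)
    attn = F.softmax(torch.bmm(q, k) / tau, dim=-1)  # (1, HW, HW)
    v = ref_style.flatten(2).permute(0, 2, 1)        # (1, HW, C)
    out = torch.bmm(attn, v).permute(0, 2, 1)        # (1, C, HW)
    return out.contiguous().view(bs, c, h, w)

feat = torch.randn(1, 256, 64, 64)
print(warp_reference(feat, torch.randn_like(feat), torch.randn_like(feat)).shape)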
Example #14
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, s_a_fake = self.gen_a.encode(x_a)
        c_b, s_b_fake = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(
            c_a_recon,
            s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(
            c_b_recon,
            s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(
            self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(
            self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def sample(self, x_a, x_b):
        self.eval()
        s_a1 = Variable(self.s_a)
        s_b1 = Variable(self.s_b)
        s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
            c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
            x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
            x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
            x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
            x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
            x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters[
            'gan_w'] * self.loss_dis_a + hyperparameters[
                'gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitilize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def snap_clean(self, snap_dir, iterations, save_last=10000, period=20000):
        # Cleaning snapshot directory from old files
        if not os.path.exists(snap_dir):
            return None

        gen_models = [
            os.path.join(snap_dir, f) for f in os.listdir(snap_dir)
            if "gen" in f and ".pt" in f
        ]
        dis_models = [
            os.path.join(snap_dir, f) for f in os.listdir(snap_dir)
            if "dis" in f and ".pt" in f
        ]

        gen_models.sort()
        dis_models.sort()
        marked_clean = []
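        # Thin out old snapshots: keep roughly one per `period` iterations,
        # mark anything closer than `period` to the last kept one, and never
        # touch the most recent `save_last` iterations.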
        for i, model in enumerate(gen_models):
            m_iter = int(model[-11:-3])
            if i == 0:
                m_prev = 0
                continue
            if m_iter > iterations - save_last:
                break
            if m_iter - m_prev < period:
                marked_clean.append(model)
            while m_iter - m_prev >= period:
                m_prev += period

        for i, model in enumerate(dis_models):
            m_iter = int(model[-11:-3])
            if i == 0:
                m_prev = 0
                continue
            if m_iter > iterations - save_last:
                break
            if m_iter - m_prev < period:
                marked_clean.append(model)
            while m_iter - m_prev >= period:
                m_prev += period

        print(f'Cleaning snapshots: {marked_clean}')
        for f in marked_clean:
            os.remove(f)

    def save(self, snapshot_dir, iterations, smart_override):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
        if smart_override:
            self.snap_clean(snapshot_dir, iterations + 1)
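
Example #14 differs from the stock trainer mainly in checkpoint housekeeping: save(..., smart_override=True) calls snap_clean, which keeps the most recent save_last iterations densely and thins older snapshots to roughly one per period iterations. A hedged usage sketch (directory name and iteration count are placeholders):

trainer = MUNIT_Trainer(config)
trainer.cuda()
# ... training loop ...
# Keep the last 10000 iterations densely, about one snapshot per 20000 before.
trainer.save('outputs/snapshots', iterations, smart_override=True)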
Example #15
class AGUIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(AGUIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.noise_dim = hyperparameters['gen']['noise_dim']
        self.attr_dim = len(hyperparameters['gen']['selected_attrs'])
        self.gen = AdaINGen(hyperparameters['input_dim'],
                            hyperparameters['gen'])
        self.dis = MsImageDis(hyperparameters['input_dim'], self.attr_dim,
                              hyperparameters['dis'])
        self.dis_content = ContentDis(
            hyperparameters['gen']['dim'] *
            (2**hyperparameters['gen']['n_downsample']), self.attr_dim)

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis.parameters()) + list(
            self.dis_content.parameters())
        gen_params = list(self.gen.parameters())

        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis.apply(weights_init('gaussian'))
        self.dis_content.apply(weights_init('gaussian'))

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def gen_update(self, x_l, x_u, l, hyperparameters):
        self.gen_opt.zero_grad()
        # l_s_rand = torch.randn_like(l_s)
        # l_s = torch.where(l_s == 0, l_s_rand, l_s)
        s_r = torch.cat([torch.randn(x_u.size(0), self.noise_dim).cuda(), l],
                        1)

        # encode
        c_l, s_l = self.gen.encode(x_l)
        c_u, s_u = self.gen.encode(x_u)

        # decode (within domain)
        x_u_recon = self.gen.decode(c_u, s_u)

        # decode (cross domain)
        x_ur = self.gen.decode(c_u, s_r)

        # encode again
        c_u_recon, s_r_recon = self.gen.encode(x_ur)

        x_u_cycle = self.gen.decode(c_u_recon, s_u)

        # additional KL-loss (optional)
        s_mean = s_l[:, 0:self.noise_dim].mean()
        s_std = s_l[:, 0:self.noise_dim].std()

        self.loss_gen_kld = (s_mean**2 + s_std.pow(2) - s_std.pow(2).log() -
                             1).mean() / 2

        self.loss_gen_adv_content = self.dis_content.calc_gen_loss(c_l, c_u, l)
        # reconstruction loss
        self.loss_gen_rec = self.recon_criterion(x_u_recon, x_u)
        self.loss_gen_rec_s = self.recon_criterion(s_r_recon, s_r)
        self.loss_gen_rec_c = self.recon_criterion(c_u_recon, c_u)

        self.loss_gen_cyc = self.recon_criterion(x_u_cycle, x_u)

        # GAN loss
        self.loss_gen_adv = self.dis.calc_gen_loss(x_ur, l)

        # label part loss
        self.loss_gen_cla = (
            s_l[:, self.noise_dim:self.noise_dim + self.attr_dim] -
            l).pow(2).mean()

        self.loss_gen_total = hyperparameters['adv_w'] * self.loss_gen_adv + \
                              hyperparameters['adv_c_w'] * self.loss_gen_adv_content + \
                              hyperparameters['rec_w'] * self.loss_gen_rec + \
                              hyperparameters['rec_s_w'] * self.loss_gen_rec_s + \
                              hyperparameters['rec_c_w'] * self.loss_gen_rec_c + \
                              hyperparameters['cla_w'] * self.loss_gen_cla + \
                              hyperparameters['kld_w'] * self.loss_gen_kld + \
                              hyperparameters['cyc_w'] * self.loss_gen_cyc

        self.loss_gen_total.backward()

        self.gen_opt.step()

        return self.loss_gen_total.detach()

    def sample(self, x_l, l):

        c_l, s_l = self.gen.encode(x_l)

        # decode (within domain)
        x_l_recon = self.gen.decode(c_l, s_l)

        out = [x_l, x_l_recon]
        for i in range(self.attr_dim):
            s_changed = s_l.clone()
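            # Negate attribute i's slot in the style code to flip it.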
            s_changed[:, self.noise_dim + i] = -l[:, i]
            out += [self.gen.decode(c_l, s_changed)]

        return out

    def dis_update(self, x_l, x_u, l, hyperparameters):
        self.dis_opt.zero_grad()

        s_r = torch.cat([torch.randn(x_u.size(0), self.noise_dim).cuda(), l],
                        1)

        # encode
        c_l, s_l = self.gen.encode(x_l)
        c_u, s_u = self.gen.encode(x_u)

        # decode (cross domain)
        x_ur = self.gen.decode(c_u, s_r)

        # D loss
        self.loss_dis_adv = self.dis.calc_dis_loss(x_ur.detach(), x_l, x_u, l)
        self.loss_dis_adv_content = self.dis_content.calc_dis_loss(c_l, c_u, l)
        self.loss_dis_total = hyperparameters['adv_w'] * self.loss_dis_adv + \
                              hyperparameters['adv_c_w'] * self.loss_dis_adv_content
        self.loss_dis_total.backward()
        self.dis_opt.step()

        return self.loss_dis_total.detach()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen.load_state_dict(state_dict['gen'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis.load_state_dict(state_dict['dis'])
        self.dis_content.load_state_dict(state_dict['dis_content'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'gen': self.gen.state_dict()}, gen_name)
        torch.save(
            {
                'dis': self.dis.state_dict(),  # this trainer defines self.dis
                'dis_content': self.dis_content.state_dict()
            }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
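
The loss_gen_kld term in Example #15 is the closed-form KL divergence between a Gaussian N(mu, sigma^2) and the standard normal, applied to the batch statistics of the noise part of the style code. A minimal sketch of the formula it implements (shapes are illustrative only):

import torch

def kl_to_standard_normal(mu, std):
    # KL( N(mu, std^2) || N(0, 1) ) = (mu^2 + std^2 - log(std^2) - 1) / 2
    return (mu ** 2 + std.pow(2) - std.pow(2).log() - 1) / 2

s_noise = torch.randn(8, 16)  # hypothetical noise slice of a style code
print(kl_to_standard_normal(s_noise.mean(), s_noise.std()))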
Example #16
class UNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(UNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = VAEGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = VAEGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        # Variable.volatile was removed in modern PyTorch; use no_grad instead
        with torch.no_grad():
            h_a, _ = self.gen_a.encode(x_a)
            h_b, _ = self.gen_b.encode(x_b)
            x_ba = self.gen_a.decode(h_b)
            x_ab = self.gen_b.decode(h_a)
        self.train()
        return x_ab, x_ba

    def __compute_kl(self, mu):
        # def _compute_kl(self, mu, sd):
        # mu_2 = torch.pow(mu, 2)
        # sd_2 = torch.pow(sd, 2)
        # encoding_loss = (mu_2 + sd_2 - torch.log(sd_2)).sum() / mu_2.size(0)
        # return encoding_loss
        mu_2 = torch.pow(mu, 2)
        encoding_loss = torch.mean(mu_2)
        return encoding_loss

    def gen_update(self, x_a, x_b, hyperparameters):
        self.gen_opt.zero_grad()
        # encode
        h_a, n_a = self.gen_a.encode(x_a)
        h_b, n_b = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(h_a + n_a)
        x_b_recon = self.gen_b.decode(h_b + n_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(h_b + n_b)
        x_ab = self.gen_b.decode(h_a + n_a)
        # encode again
        h_b_recon, n_b_recon = self.gen_a.encode(x_ba)
        h_a_recon, n_a_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(h_a_recon + n_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(h_b_recon + n_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_kl_a = self.__compute_kl(h_a)
        self.loss_gen_recon_kl_b = self.__compute_kl(h_b)
        self.loss_gen_cyc_x_a = self.recon_criterion(x_aba, x_a)
        self.loss_gen_cyc_x_b = self.recon_criterion(x_bab, x_b)
        self.loss_gen_recon_kl_cyc_aba = self.__compute_kl(h_a_recon)
        self.loss_gen_recon_kl_cyc_bab = self.__compute_kl(h_b_recon)
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_a + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_b + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def sample(self, x_a, x_b):
        self.eval()
        # Variable.volatile was removed in modern PyTorch; use no_grad instead
        with torch.no_grad():
            x_a_recon, x_b_recon, x_ba, x_ab = [], [], [], []
            for i in range(x_a.size(0)):
                h_a, _ = self.gen_a.encode(x_a[i].unsqueeze(0))
                h_b, _ = self.gen_b.encode(x_b[i].unsqueeze(0))
                x_a_recon.append(self.gen_a.decode(h_a))
                x_b_recon.append(self.gen_b.decode(h_b))
                x_ba.append(self.gen_a.decode(h_b))
                x_ab.append(self.gen_b.decode(h_a))
            x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
            x_ba = torch.cat(x_ba)
            x_ab = torch.cat(x_ab)
        self.train()
        return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        # encode
        h_a, n_a = self.gen_a.encode(x_a)
        h_b, n_b = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(h_b + n_b)
        x_ab = self.gen_b.decode(h_a + n_a)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
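# --- Editor's check (not in the original example): save() writes 'gen_%08d.pt'
# --- and resume() parses the iteration back out of the filename; '%08d' always
# --- renders eight digits, so name[-11:-3] strips the '.pt' suffix exactly.
if __name__ == '__main__':
    _name = 'gen_%08d.pt' % (12345 + 1)   # as written by save()
    assert _name == 'gen_00012346.pt'
    assert int(_name[-11:-3]) == 12346    # as parsed by resume()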
Example #17
0
class DGNet_Trainer(nn.Module):
    def __init__(self, hyperparameters, gpu_ids=[0]):
        super(DGNet_Trainer, self).__init__()
        # Learning rates for the generator and the discriminator, read from the config
        lr_g = hyperparameters['lr_g']
        lr_d = hyperparameters['lr_d']

        # number of identity classes
        ID_class = hyperparameters['ID_class']

        # whether to use fp16 (apex)
        if not 'apex' in hyperparameters.keys():
            hyperparameters['apex'] = False
        self.fp16 = hyperparameters['apex']
        # Initiate the networks
        # We do not need to manually set fp16 in the network for the new apex. So here I set fp16=False.
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'],
                              hyperparameters['gen'],
                              fp16=False)  # auto-encoder for domain a
        self.gen_b = self.gen_a  # auto-encoder for domain b
        '''
        ft_netAB :   Ea
        '''
        # ID_stride: stride of the pooling layer in the appearance encoder
        if not 'ID_stride' in hyperparameters.keys():
            hyperparameters['ID_stride'] = 2
        # id_a: appearance encoder -> Ea
        if hyperparameters['ID_style'] == 'PCB':
            self.id_a = PCB(ID_class)
        elif hyperparameters['ID_style'] == 'AB':
            self.id_a = ft_netAB(ID_class,
                                 stride=hyperparameters['ID_stride'],
                                 norm=hyperparameters['norm_id'],
                                 pool=hyperparameters['pool'])
        else:
            self.id_a = ft_net(ID_class,
                               norm=hyperparameters['norm_id'],
                               pool=hyperparameters['pool'])  # return 2048 now

        self.id_b = self.id_a  # domain b shares the appearance encoder of domain a

        # Multi-scale discriminator: the input is downscaled several times and a
        # prediction is made at every scale; the per-scale losses are summed.
        # With 3 scales the outputs are [batch_size, 1, 64, 32], [batch_size, 1, 32, 16] and [batch_size, 1, 16, 8].
        self.dis_a = MsImageDis(3, hyperparameters['dis'],
                                fp16=False)  # discriminator for domain a
        self.dis_b = self.dis_a  # discriminator for domain b

        # load teachers
        if hyperparameters['teacher'] != "":
            teacher_name = hyperparameters['teacher']
            print(teacher_name)

            # load one or more teacher models
            teacher_names = teacher_name.split(',')
            # build the teacher models
            teacher_model = nn.ModuleList()  # start empty and fill below
            teacher_count = 0
            for teacher_name in teacher_names:
                config_tmp = load_config(teacher_name)

                # stride of the pooling layer
                if 'stride' in config_tmp:
                    stride = config_tmp['stride']
                else:
                    stride = 2

                # build the teacher network
                model_tmp = ft_net(ID_class, stride=stride)
                teacher_model_tmp = load_network(model_tmp, teacher_name)
                teacher_model_tmp.model.fc = nn.Sequential(
                )  # remove the original fc layer in ImageNet
                teacher_model_tmp = teacher_model_tmp.cuda()
                # teacher_model_tmp,[3, 224, 224]

                # wrap with apex AMP if fp16 is enabled
                if self.fp16:
                    teacher_model_tmp = amp.initialize(teacher_model_tmp,
                                                       opt_level="O1")
                teacher_model.append(
                    teacher_model_tmp.cuda().eval())  # append each teacher in eval mode
                teacher_count += 1
            self.teacher_model = teacher_model

            # optionally keep BatchNorm layers in training mode
            if hyperparameters['train_bn']:
                self.teacher_model = self.teacher_model.apply(train_bn)
        # instance normalization
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

        # RGB to one channel
        # Es takes single-channel input, so `single` converts images to grayscale (or edge maps)
        if hyperparameters['single'] == 'edge':
            self.single = to_edge
        else:
            self.single = to_gray(False)

        # Random Erasing when training
        # erasing_p: probability of random erasing
        if not 'erasing_p' in hyperparameters.keys():
            self.erasing_p = 0
        else:
            self.erasing_p = hyperparameters['erasing_p']
        # erase a random region of the image by filling it with the given mean value
        self.single_re = RandomErasing(probability=self.erasing_p,
                                       mean=[0.0, 0.0, 0.0])

        if not 'T_w' in hyperparameters.keys():
            hyperparameters['T_w'] = 1
        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(
            self.dis_a.parameters())  #+ list(self.dis_b.parameters())
        gen_params = list(
            self.gen_a.parameters())  #+ list(self.gen_b.parameters())

        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr_d,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr_g,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        # id params
        # use a larger learning rate for the classifier heads of id_a
        if hyperparameters['ID_style'] == 'PCB':
            ignored_params = (
                list(map(id, self.id_a.classifier0.parameters())) +
                list(map(id, self.id_a.classifier1.parameters())) +
                list(map(id, self.id_a.classifier2.parameters())) +
                list(map(id, self.id_a.classifier3.parameters())))
            base_params = filter(lambda p: id(p) not in ignored_params,
                                 self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD(
                [{
                    'params': base_params,
                    'lr': lr2
                }, {
                    'params': self.id_a.classifier0.parameters(),
                    'lr': lr2 * 10
                }, {
                    'params': self.id_a.classifier1.parameters(),
                    'lr': lr2 * 10
                }, {
                    'params': self.id_a.classifier2.parameters(),
                    'lr': lr2 * 10
                }, {
                    'params': self.id_a.classifier3.parameters(),
                    'lr': lr2 * 10
                }],
                weight_decay=hyperparameters['weight_decay'],
                momentum=0.9,
                nesterov=True)
        elif hyperparameters['ID_style'] == 'AB':
            ignored_params = (
                list(map(id, self.id_a.classifier1.parameters())) +
                list(map(id, self.id_a.classifier2.parameters())))
            base_params = filter(lambda p: id(p) not in ignored_params,
                                 self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD(
                [{
                    'params': base_params,
                    'lr': lr2
                }, {
                    'params': self.id_a.classifier1.parameters(),
                    'lr': lr2 * 10
                }, {
                    'params': self.id_a.classifier2.parameters(),
                    'lr': lr2 * 10
                }],
                weight_decay=hyperparameters['weight_decay'],
                momentum=0.9,
                nesterov=True)
        else:
            ignored_params = list(map(id, self.id_a.classifier.parameters()))
            base_params = filter(lambda p: id(p) not in ignored_params,
                                 self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD(
                [{
                    'params': base_params,
                    'lr': lr2
                }, {
                    'params': self.id_a.classifier.parameters(),
                    'lr': lr2 * 10
                }],
                weight_decay=hyperparameters['weight_decay'],
                momentum=0.9,
                nesterov=True)
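        # Editor's note: the 10x learning rate on the (freshly initialized)
        # classifier heads versus the ImageNet-pretrained backbone is a common
        # re-ID fine-tuning recipe, applied uniformly to all three ID_style
        # branches above.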
        # learning-rate schedulers for the generator, discriminator and ID encoder
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        self.id_scheduler = get_scheduler(self.id_opt, hyperparameters)
        self.id_scheduler.gamma = hyperparameters['gamma2']

        #ID Loss
        self.id_criterion = nn.CrossEntropyLoss()
        self.criterion_teacher = nn.KLDivLoss(
            reduction='sum')  # teacher (primary-feature) loss Lprim; size_average=False is deprecated
        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

        # save memory
        # move the models to the GPU and wrap them with apex AMP
        if self.fp16:
            # Name the FP16_Optimizer instance to replace the existing optimizer
            assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
            self.gen_a = self.gen_a.cuda()
            self.dis_a = self.dis_a.cuda()
            self.id_a = self.id_a.cuda()

            self.gen_b = self.gen_a
            self.dis_b = self.dis_a
            self.id_b = self.id_a

            self.gen_a, self.gen_opt = amp.initialize(self.gen_a,
                                                      self.gen_opt,
                                                      opt_level="O1")
            self.dis_a, self.dis_opt = amp.initialize(self.dis_a,
                                                      self.dis_opt,
                                                      opt_level="O1")
            self.id_a, self.id_opt = amp.initialize(self.id_a,
                                                    self.id_opt,
                                                    opt_level="O1")

    def to_re(self, x):
        out = torch.FloatTensor(x.size(0), x.size(1), x.size(2), x.size(3))
        out = out.cuda()
        for i in range(x.size(0)):
            out[i, :, :, :] = self.single_re(x[i, :, :, :])  # random-erase each image in the batch
        return out

    def recon_criterion(self, input, target):  # L1 reconstruction loss
        diff = input - target.detach()  # pixel-wise difference; no gradient flows to the target
        return torch.mean(torch.abs(diff[:]))

    def recon_criterion_sqrt(self, input, target):  # square-root reconstruction loss
        diff = input - target
        return torch.mean(torch.sqrt(torch.abs(diff[:]) + 1e-8))

    def recon_criterion2(self, input, target):  # mean squared reconstruction loss
        diff = input - target
        return torch.mean(diff[:]**2)

    def recon_cos(self, input, target):  # mean cosine-distance reconstruction loss
        cos = torch.nn.CosineSimilarity()
        cos_dis = 1 - cos(input, target)
        return torch.mean(cos_dis[:])
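    # Editor's note: the four criteria above trade off outlier sensitivity:
    # recon_criterion is a plain L1, recon_criterion_sqrt a square-root penalty
    # that is even more tolerant of large errors than L1, recon_criterion2 an
    # L2, and recon_cos penalizes angular rather than magnitude differences.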

    def forward(self, x_a, x_b, xp_a, xp_b):
        '''
        Four images are fed in at once:
        :param x_a: / :param xp_a:    same ID
        :param x_b: / :param xp_b:    same ID

        Why four images?
        One full DG-Net pass needs three images: id1, id2 and a positive sample
        of id1, so two training groups would normally cost six images. Feeding
        four images (id1, id1-positive, id2, id2-positive) instead yields the
        two groups (id1, id2, id1-positive) and (id2, id1, id2-positive),
        saving two images.
        '''
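        # Editor's illustration of the pairing described above: with inputs
        # (x_a, xp_a) sharing ID-1 and (x_b, xp_b) sharing ID-2, the two
        # training groups are (x_a, x_b, xp_a) and (x_b, x_a, xp_b), so four
        # images cover what would otherwise take six.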

        # self.gen_a.encode: -> Es
        # single: convert to grayscale first
        s_a = self.gen_a.encode(self.single(
            x_a))  # shape: [batch_size, 128, 64, 32]    -> structure code of a
        s_b = self.gen_b.encode(self.single(
            x_b))  # shape: [batch_size, 128, 64, 32]    -> structure code of b
        # self.id_a: -> Ea
        f_a, p_a = self.id_a(
            scale2(x_a))  # -> appearance code of a
        f_b, p_b = self.id_b(scale2(x_b))  # -> appearance code of b
        # f shape: [batch_size, 2048*4=8192]
        # p[0] shape: [batch_size, class_num=751], p[1] shape: [batch_size, class_num=751]  -> probability distributions

        # self.gen_a.decode -> D
        x_ba = self.gen_a.decode(
            s_b, f_a)  # shape: [batch_size, 3, 256, 128]     -> a-ap + b-st
        x_ab = self.gen_b.decode(
            s_a, f_b)  # shape: [batch_size, 3, 256, 128]     -> a-st + b-ap

        x_a_recon = self.gen_a.decode(
            s_a, f_a)  # shape: [batch_size, 3, 256, 128]     -> a-ap + a-st
        x_b_recon = self.gen_b.decode(
            s_b, f_b)  # shape: [batch_size, 3, 256, 128]     -> b-ap + b-st
        fp_a, pp_a = self.id_a(
            scale2(xp_a))  # -> appearance code and ID prediction for xp_a
        fp_b, pp_b = self.id_b(
            scale2(xp_b))  # -> appearance code and ID prediction for xp_b
        # decode the same person
        x_a_recon_p = self.gen_a.decode(
            s_a, fp_a)  # shape: [batch_size, 3, 256, 128]     -> a-st + x_a-ap
        x_b_recon_p = self.gen_b.decode(
            s_b, fp_b)  # shape: [batch_size, 3, 256, 128]     -> b-st + x_b-ap

        # Random Erasing only effect the ID and PID loss.
        if self.erasing_p > 0:
            x_a_re = self.to_re(scale2(x_a.clone()))
            x_b_re = self.to_re(scale2(x_b.clone()))
            xp_a_re = self.to_re(scale2(xp_a.clone()))
            xp_b_re = self.to_re(scale2(xp_b.clone()))
            _, p_a = self.id_a(x_a_re)  # re-predict the ID distribution after random erasing
            _, p_b = self.id_b(x_b_re)
            # encode the same ID different photo
            _, pp_a = self.id_a(xp_a_re)
            _, pp_b = self.id_b(xp_b_re)

        return x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b, x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p
        '''
        Alternative kept for reference (note: unreachable, it sits after the return):
        train once with 3 input images
        s_a = self.gen_a.encode(self.single(x_a))
        f_a, p_a = self.id_a(scale2(x_a))
        f_b, p_b = self.id_b(scale2(x_b))
        fp_a, pp_a = self.id_a(scale2(xp_a))
        x_a_recon = self.gen_a.decode(s_a, f_a)
        x_ab = self.gen_b.decode(s_a, f_b)
        x_a_recon_p = self.gen_a.decode(s_a, fp_a)

        train once with 3 input images
        s_b = self.gen_b.encode(self.single(x_b))
        f_a, p_a = self.id_a(scale2(x_a))
        f_b, p_b = self.id_b(scale2(x_b))
        fp_b, pp_b = self.id_b(scale2(xp_b))
        x_ba = self.gen_a.decode(s_b, f_a)
        x_b_recon_p = self.gen_b.decode(s_b, fp_b)
        '''

    def gen_update(self, x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b,
                   x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p, x_a, x_b,
                   xp_a, xp_b, l_a, l_b, hyperparameters, iteration, num_gpu):
        # pp_a, pp_b correspond to the same persons as p_a, p_b
        # pp_a / pp_b: identity predictions (via Ea) for the positive images xp_a / xp_b
        self.gen_opt.zero_grad()
        self.id_opt.zero_grad()

        # no gradient
        x_ba_copy = Variable(x_ba.data, requires_grad=False)
        x_ab_copy = Variable(x_ab.data, requires_grad=False)

        rand_num = random.uniform(0, 1)
        #################################
        # encode structure
        if hyperparameters['use_encoder_again'] >= rand_num:
            # encode again (encoder is tuned, input is fixed)
            s_a_recon = self.gen_b.enc_content(
                self.single(x_ab_copy))  # re-encode x_ab with Es to get its structure code
            s_b_recon = self.gen_a.enc_content(
                self.single(x_ba_copy))  # re-encode x_ba with Es to get its structure code
        else:
            # copy the encoder
            # deep copy, so the original encoder stays untouched
            self.enc_content_copy = copy.deepcopy(self.gen_a.enc_content)
            self.enc_content_copy = self.enc_content_copy.eval()
            # encode again (encoder is fixed, input is tuned)
            s_a_recon = self.enc_content_copy(self.single(x_ab))
            s_b_recon = self.enc_content_copy(self.single(x_ba))
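        # Editor's note: the two branches differ in where the gradient flows.
        # With probability `use_encoder_again` the live encoder re-encodes a
        # detached image (tuning Es); otherwise a frozen deep copy re-encodes
        # the live image, so the gradient flows through the image into the
        # generator instead.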

        #################################
        # encode appearance
        self.id_a_copy = copy.deepcopy(self.id_a)
        self.id_a_copy = self.id_a_copy.eval()
        if hyperparameters['train_bn']:
            self.id_a_copy = self.id_a_copy.apply(train_bn)
        self.id_b_copy = self.id_a_copy
        # encode again (encoder is fixed, input is tuned)
        f_a_recon, p_a_recon = self.id_a_copy(
            scale2(x_ba))  # Ea encoding + ID prediction for the synthesized x_ba
        f_b_recon, p_b_recon = self.id_b_copy(
            scale2(x_ab))  # Ea encoding + ID prediction for the synthesized x_ab

        # teacher Loss
        #  Tune the ID model
        log_sm = nn.LogSoftmax(dim=1)
        if hyperparameters['teacher_w'] > 0 and hyperparameters[
                'teacher'] != "":
            if hyperparameters['ID_style'] == 'normal':
                _, p_a_student = self.id_a(scale2(x_ba_copy))
                p_a_student = log_sm(p_a_student)
                p_a_teacher = predict_label(
                    self.teacher_model,
                    scale2(x_ba_copy),
                    num_class=hyperparameters['ID_class'],
                    alabel=l_a,
                    slabel=l_b,
                    teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher = self.criterion_teacher(
                    p_a_student, p_a_teacher) / p_a_student.size(0)

                _, p_b_student = self.id_b(scale2(x_ab_copy))
                p_b_student = log_sm(p_b_student)
                p_b_teacher = predict_label(
                    self.teacher_model,
                    scale2(x_ab_copy),
                    num_class=hyperparameters['ID_class'],
                    alabel=l_b,
                    slabel=l_a,
                    teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher += self.criterion_teacher(
                    p_b_student, p_b_teacher) / p_b_student.size(0)
            elif hyperparameters['ID_style'] == 'AB':
                # normal teacher-student loss
                # BA -> LabelA(smooth) + LabelB(batchB)
                # the synthesized image goes through the ID discriminator, giving per-ID probabilities
                _, p_ba_student = self.id_a(scale2(x_ba_copy))  # f_a, s_b
                p_a_student = log_sm(p_ba_student[0])  # first of the two identity predictions
                with torch.no_grad():
                    p_a_teacher = predict_label(
                        self.teacher_model,
                        scale2(x_ba_copy),
                        num_class=hyperparameters['ID_class'],
                        alabel=l_a,
                        slabel=l_b,
                        teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher = self.criterion_teacher(
                    p_a_student, p_a_teacher) / p_a_student.size(
                        0)  # teacher-supervised ID loss for x_ba  # Eq. (8)

                _, p_ab_student = self.id_b(scale2(x_ab_copy))  # f_b, s_a
                p_b_student = log_sm(p_ab_student[0])
                with torch.no_grad():
                    p_b_teacher = predict_label(
                        self.teacher_model,
                        scale2(x_ab_copy),
                        num_class=hyperparameters['ID_class'],
                        alabel=l_b,
                        slabel=l_a,
                        teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher += self.criterion_teacher(
                    p_b_student, p_b_teacher) / p_b_student.size(
                        0)  # teacher-supervised ID loss for x_ab  # Eq. (8)

                # branch b loss
                # here we give different label
                # Use Ea's second identity prediction for a ground-truth ID loss:
                # Ea effectively outputs two predictions, one supervised by the
                # teacher (above) and one supervised by the real labels (here).
                loss_B = self.id_criterion(
                    p_ba_student[1], l_b) + self.id_criterion(
                        p_ab_student[1], l_a)  # l_b is the label of b  # Eq. (9)
                self.loss_teacher = hyperparameters[
                    'T_w'] * self.loss_teacher + hyperparameters['B_w'] * loss_B
        else:
            self.loss_teacher = 0.0
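        # Editor's note (general PyTorch convention, not repo-specific):
        # nn.KLDivLoss expects log-probabilities as its first argument (hence
        # log_sm above) and plain probabilities as the target; with a summed
        # reduction the result is normalized by the batch size by hand.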

        # auto-encoder image reconstruction
        self.loss_gen_recon_x_a = self.recon_criterion(
            x_a_recon, x_a)  # x_a_recon: a's appearance + a's structure   # Eq. (1)
        self.loss_gen_recon_x_b = self.recon_criterion(
            x_b_recon, x_b)  # x_b_recon: b's appearance + b's structure   # Eq. (1)
        self.loss_gen_recon_xp_a = self.recon_criterion(
            x_a_recon_p, x_a)  # x_a_recon_p: a's structure + positive-a's appearance  # Eq. (2)
        self.loss_gen_recon_xp_b = self.recon_criterion(
            x_b_recon_p, x_b)  # x_b_recon_p: b's structure + positive-b's appearance  # Eq. (2)

        # feature reconstruction
        self.loss_gen_recon_s_a = self.recon_criterion(
            s_a_recon, s_a) if hyperparameters[
                'recon_s_w'] > 0 else 0  # s_a_recon: structure code of the synthesized x_ab  # Eq. (5)
        self.loss_gen_recon_s_b = self.recon_criterion(
            s_b_recon, s_b) if hyperparameters[
                'recon_s_w'] > 0 else 0  # s_b_recon: structure code of the synthesized x_ba  # Eq. (5)
        self.loss_gen_recon_f_a = self.recon_criterion(
            f_a_recon, f_a) if hyperparameters[
                'recon_f_w'] > 0 else 0  # f_a_recon: appearance code of the synthesized x_ba  # Eq. (4)
        self.loss_gen_recon_f_b = self.recon_criterion(
            f_b_recon, f_b) if hyperparameters[
                'recon_f_w'] > 0 else 0  # f_b_recon: appearance code of the synthesized x_ab  # Eq. (4)

        x_aba = self.gen_a.decode(s_a_recon, f_a_recon) if hyperparameters[
            'recon_x_cyc_w'] > 0 else None  # x_aba: x_ab's structure + x_ba's appearance
        x_bab = self.gen_b.decode(s_b_recon, f_b_recon) if hyperparameters[
            'recon_x_cyc_w'] > 0 else None  # x_bab: x_ba's structure + x_ab's appearance

        # ID loss AND Tune the Generated image
        if hyperparameters['ID_style'] == 'PCB':
            self.loss_id = self.PCB_loss(p_a, l_a) + self.PCB_loss(p_b, l_b)
            self.loss_pid = self.PCB_loss(pp_a, l_a) + self.PCB_loss(pp_b, l_b)
            self.loss_gen_recon_id = self.PCB_loss(
                p_a_recon, l_a) + self.PCB_loss(
                    p_b_recon, l_b)  # ID loss of x_ba w.r.t. l_a and of x_ab w.r.t. l_b
        elif hyperparameters['ID_style'] == 'AB':
            weight_B = hyperparameters['teacher_w'] * hyperparameters[
                'B_w']  # teacher_w = 1.0, B_w = 0.2
            self.loss_id = self.id_criterion(p_a[0], l_a) + self.id_criterion(p_b[0], l_b) \
                         + weight_B * ( self.id_criterion(p_a[1], l_a) + self.id_criterion(p_b[1], l_b) )  # ID losses for a and b  # Eq. (3)
            self.loss_pid = self.id_criterion(
                pp_a[0], l_a) + self.id_criterion(
                    pp_b[0], l_b)  # ID losses for the positive samples  # Eq. (3)
            self.loss_gen_recon_id = self.id_criterion(
                p_a_recon[0], l_a) + self.id_criterion(
                    p_b_recon[0], l_b)  # x_ba keeps a's appearance code, and identity follows appearance, hence label l_a  # Eq. (7)
        else:
            self.loss_id = self.id_criterion(p_a, l_a) + self.id_criterion(
                p_b, l_b)
            self.loss_pid = self.id_criterion(pp_a, l_a) + self.id_criterion(
                pp_b, l_b)
            self.loss_gen_recon_id = self.id_criterion(
                p_a_recon, l_a) + self.id_criterion(p_b_recon, l_b)

        #print(f_a_recon, f_a)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if hyperparameters[
                'recon_x_cyc_w'] > 0 else 0  # cycle reconstruction of x_a via x_aba
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if hyperparameters[
                'recon_x_cyc_w'] > 0 else 0  # cycle reconstruction of x_b via x_bab
        # GAN loss
        if num_gpu > 1:
            self.loss_gen_adv_a = self.dis_a.module.calc_gen_loss(
                self.dis_a, x_ba)  # Eq. (6)
            self.loss_gen_adv_b = self.dis_b.module.calc_gen_loss(
                self.dis_b, x_ab)  # Eq. (6)
        else:
            self.loss_gen_adv_a = self.dis_a.calc_gen_loss(self.dis_a, x_ba)
            self.loss_gen_adv_b = self.dis_b.calc_gen_loss(self.dis_b, x_ab)

        # domain-invariant perceptual loss
        # extract VGG features from the synthesized and real images and penalize their distance
        self.loss_gen_vgg_a = self.compute_vgg_loss(
            self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(
            self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0

        # warm up the loss weights
        if iteration > hyperparameters['warm_iter']:
            hyperparameters['recon_f_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_f_w'] = min(hyperparameters['recon_f_w'],
                                               hyperparameters['max_w'])
            hyperparameters['recon_s_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_s_w'] = min(hyperparameters['recon_s_w'],
                                               hyperparameters['max_w'])
            hyperparameters['recon_x_cyc_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_x_cyc_w'] = min(
                hyperparameters['recon_x_cyc_w'], hyperparameters['max_cyc_w'])

        if iteration > hyperparameters['warm_teacher_iter']:
            hyperparameters['teacher_w'] += hyperparameters['warm_scale']
            hyperparameters['teacher_w'] = min(
                hyperparameters['teacher_w'], hyperparameters['max_teacher_w'])
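        # Editor's note: since gen_update runs once per iteration, each weight
        # above follows a linear ramp, e.g.
        #     recon_f_w(t) = min(recon_f_w0 + warm_scale * (t - warm_iter), max_w)
        # for t > warm_iter, and analogously for teacher_w with max_teacher_w.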
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_a + \
                              hyperparameters['recon_f_w'] * self.loss_gen_recon_f_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_b + \
                              hyperparameters['recon_f_w'] * self.loss_gen_recon_f_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['id_w'] * self.loss_id + \
                              hyperparameters['pid_w'] * self.loss_pid + \
                              hyperparameters['recon_id_w'] * self.loss_gen_recon_id + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b + \
                              hyperparameters['teacher_w'] * self.loss_teacher
        # mixed-precision backward pass for speed and memory
        if self.fp16:
            with amp.scale_loss(self.loss_gen_total,
                                [self.gen_opt, self.id_opt]) as scaled_loss:
                scaled_loss.backward()
            self.gen_opt.step()
            self.id_opt.step()
        else:
            self.loss_gen_total.backward()  # backpropagation
            self.gen_opt.step()
            self.id_opt.step()
        print("L_total: %.4f, L_gan: %.4f,  Lx: %.4f, Lxp: %.4f, Lrecycle:%.4f, Lf: %.4f, Ls: %.4f, Recon-id: %.4f, id: %.4f, pid:%.4f, teacher: %.4f"%( self.loss_gen_total, \
                                                        hyperparameters['gan_w'] * (self.loss_gen_adv_a + self.loss_gen_adv_b), \
                                                        hyperparameters['recon_x_w'] * (self.loss_gen_recon_x_a + self.loss_gen_recon_x_b), \
                                                        hyperparameters['recon_xp_w'] * (self.loss_gen_recon_xp_a + self.loss_gen_recon_xp_b), \
                                                        hyperparameters['recon_x_cyc_w'] * (self.loss_gen_cycrecon_x_a + self.loss_gen_cycrecon_x_b), \
                                                        hyperparameters['recon_f_w'] * (self.loss_gen_recon_f_a + self.loss_gen_recon_f_b), \
                                                        hyperparameters['recon_s_w'] * (self.loss_gen_recon_s_a + self.loss_gen_recon_s_b), \
                                                        hyperparameters['recon_id_w'] * self.loss_gen_recon_id, \
                                                        hyperparameters['id_w'] * self.loss_id,\
                                                        hyperparameters['pid_w'] * self.loss_pid,\
hyperparameters['teacher_w'] * self.loss_teacher )  )

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def PCB_loss(self, inputs, labels):
        loss = 0.0
        for part in inputs:
            loss += self.id_criterion(part, labels)
        return loss / len(inputs)

    def sample(self, x_a, x_b):
        self.eval()
        x_a_recon, x_b_recon, x_ba1, x_ab1, x_aba, x_bab = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            s_a = self.gen_a.encode(self.single(x_a[i].unsqueeze(0)))
            s_b = self.gen_b.encode(self.single(x_b[i].unsqueeze(0)))
            f_a, _ = self.id_a(scale2(x_a[i].unsqueeze(0)))
            f_b, _ = self.id_b(scale2(x_b[i].unsqueeze(0)))
            x_a_recon.append(self.gen_a.decode(s_a, f_a))
            x_b_recon.append(self.gen_b.decode(s_b, f_b))
            x_ba = self.gen_a.decode(s_b, f_a)
            x_ab = self.gen_b.decode(s_a, f_b)
            x_ba1.append(x_ba)
            x_ab1.append(x_ab)
            #cycle
            s_b_recon = self.gen_a.enc_content(self.single(x_ba))
            s_a_recon = self.gen_b.enc_content(self.single(x_ab))
            f_a_recon, _ = self.id_a(scale2(x_ba))
            f_b_recon, _ = self.id_b(scale2(x_ab))
            x_aba.append(self.gen_a.decode(s_a_recon, f_a_recon))
            x_bab.append(self.gen_b.decode(s_b_recon, f_b_recon))

        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_aba, x_bab = torch.cat(x_aba), torch.cat(x_bab)
        x_ba1, x_ab1 = torch.cat(x_ba1), torch.cat(x_ab1)
        self.train()

        return x_a, x_a_recon, x_aba, x_ab1, x_b, x_b_recon, x_bab, x_ba1

    def dis_update(self, x_ab, x_ba, x_a, x_b, hyperparameters,
                   num_gpu):  # update the discriminator
        self.dis_opt.zero_grad()
        # D loss
        if num_gpu > 1:
            self.loss_dis_a, reg_a = self.dis_a.module.calc_dis_loss(
                self.dis_a, x_ba.detach(), x_a)  # LSGAN loss
            self.loss_dis_b, reg_b = self.dis_b.module.calc_dis_loss(
                self.dis_b, x_ab.detach(), x_b)
        else:
            self.loss_dis_a, reg_a = self.dis_a.calc_dis_loss(
                self.dis_a, x_ba.detach(), x_a)
            self.loss_dis_b, reg_b = self.dis_b.calc_dis_loss(
                self.dis_b, x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters[
            'gan_w'] * self.loss_dis_a + hyperparameters[
                'gan_w'] * self.loss_dis_b
        print("DLoss: %.4f" % self.loss_dis_total,
              "Reg: %.4f" % (reg_a + reg_b))
        if self.fp16:
            with amp.scale_loss(self.loss_dis_total,
                                self.dis_opt) as scaled_loss:
                scaled_loss.backward()
        else:
            self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):  # step the learning-rate schedulers
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()
        if self.id_scheduler is not None:
            self.id_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):  # load the networks
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b = self.gen_a
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b = self.dis_a
        # Load ID dis
        last_model_name = get_model_list(checkpoint_dir, "id")
        state_dict = torch.load(last_model_name)
        self.id_a.load_state_dict(state_dict['a'])
        self.id_b = self.id_a
        # Load optimizers (older checkpoints may not contain optimizer state)
        try:
            state_dict = torch.load(
                os.path.join(checkpoint_dir, 'optimizer.pt'))
            self.dis_opt.load_state_dict(state_dict['dis'])
            self.gen_opt.load_state_dict(state_dict['gen'])
            self.id_opt.load_state_dict(state_dict['id'])
        except Exception:
            pass
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations, num_gpu=1):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        id_name = os.path.join(snapshot_dir, 'id_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict()}, gen_name)
        if num_gpu > 1:
            torch.save({'a': self.dis_a.module.state_dict()}, dis_name)
        else:
            torch.save({'a': self.dis_a.state_dict()}, dis_name)
        torch.save({'a': self.id_a.state_dict()}, id_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'id': self.id_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
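# --- Editor's sketch (hypothetical wiring, not from the original repo): one
# --- DGNet training step chaining forward() into the update methods. `config`,
# --- `loader` and the ID labels l_a / l_b are assumed to exist.
#
# trainer = DGNet_Trainer(config).cuda()
# for it, (x_a, l_a, xp_a, x_b, l_b, xp_b) in enumerate(loader):
#     out = trainer.forward(x_a, x_b, xp_a, xp_b)   # 14 outputs; out[0], out[1] are x_ab, x_ba
#     trainer.dis_update(out[0], out[1], x_a, x_b, config, num_gpu=1)
#     trainer.gen_update(*out, x_a, x_b, xp_a, xp_b, l_a, l_b, config, it, num_gpu=1)
#     trainer.update_learning_rate()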
Example #18
0
class UNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(UNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = VAEGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = VAEGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        if not hyperparameters['origin']:
            self.dis_a = MultiscaleDiscriminator(hyperparameters['input_dim_a'],        # discriminator for a
                    ndf=64, n_layers=3, norm_layer=nn.InstanceNorm2d, use_sigmoid=False,
                    num_D=2, getIntermFeat=True
                    )
            self.dis_b = MultiscaleDiscriminator(hyperparameters['input_dim_b'],        # discriminator for b
                    ndf=64, n_layers=3, norm_layer=nn.InstanceNorm2d, use_sigmoid=False,
                    num_D=2, getIntermFeat=True
                    )
            self.criterionGAN = GANLoss(use_lsgan=True, tensor=torch.cuda.FloatTensor)

        else:
            self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])
            self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])
            
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)


        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def compute_digits_difference(self, digits1, digits2, weight=1.0):
        feat_diff = 0
        feat_weights = 4.0 / (3 + 1)  # the discriminator has 3 layers
        D_weights = 1.0 / 2.0  # number of discriminator scales
        for i in range(2):
            for j in range(len(digits2[i])-1):
                feat_diff += D_weights * feat_weights * \
                    F.l1_loss(digits2[i][j],
                            digits1[i][j].detach()) * weight
        return feat_diff



    def compute_gan_loss(self, real_digits, fake_digits, gan_cri,
            loss_at='None'):
        errD = None
        errG = None
        errG_feat = None
        if gan_cri is not None:
            if loss_at == 'D':
                errD = (gan_cri(real_digits, True) \
                        + gan_cri(fake_digits, False)) * 0.5
            elif loss_at == 'G':
                errG = gan_cri(fake_digits, True)
                errG_feat = self.compute_digits_difference(real_digits, fake_digits,
                      weight=10.0)

        return errD, errG, errG_feat

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        h_a, _ = self.gen_a.encode(x_a)
        h_b, _ = self.gen_b.encode(x_b)
        x_ba, _ = self.gen_a.decode(h_b)
        x_ab, _ = self.gen_b.decode(h_a)
        self.train()
        return x_ab, x_ba

    def __compute_kl(self, mu):
        # def _compute_kl(self, mu, sd):
        # mu_2 = torch.pow(mu, 2)
        # sd_2 = torch.pow(sd, 2)
        # encoding_loss = (mu_2 + sd_2 - torch.log(sd_2)).sum() / mu_2.size(0)
        # return encoding_loss
        mu_2 = torch.pow(mu, 2)
        encoding_loss = torch.mean(mu_2)
        return encoding_loss
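    # Editor's note: with a fixed unit-variance Gaussian posterior, the KL term
    # KL(N(mu, I) || N(0, I)) reduces to 0.5 * ||mu||^2 up to constants, which
    # the mean over mu^2 above implements up to a scale factor.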

    def gen_update(self, x_a, x_b, hyperparameters):
        self.gen_opt.zero_grad()
        ############ Encode #########################################$
        h_a, n_a = self.gen_a.encode(x_a)
        h_b, n_b = self.gen_b.encode(x_b)
        # decode (within domain)
        if hyperparameters['zero_z']:
            pre_Z = Variable(torch.zeros(hyperparameters['batch_size'], 
                    hyperparameters['gen']['z_num']).cuda())
        else:
            pre_Z = None

        ########### Reconstruction ###################################
        x_a_recon, _ = self.gen_a.decode(h_a + n_a, z_var=pre_Z)
        x_b_recon, _ = self.gen_b.decode(h_b + n_b, z_var=pre_Z)

        ##############################################################
        ########### Decode (Cross Domain) ############################
        ##############################################################

        ########## with random vector ################
        x_ba, z_var_ba_1 = self.gen_a.decode(h_b + n_b)
        x_ab, z_var_ab_1 = self.gen_b.decode(h_a + n_a)

        ########## with zero latent vector ###########
        x_ba_zero, _ = self.gen_a.decode(h_b + n_b, z_var=pre_Z)
        x_ab_zero, _ = self.gen_b.decode(h_a + n_a, z_var=pre_Z)

        ######## decode (cross domain the second time) ################
        if hyperparameters['loss_eg_weight'] != 0:
            x_ba_eg, z_var_ba_2 = self.gen_a.decode(h_b + n_b)
            x_ab_eg, z_var_ab_2 = self.gen_b.decode(h_a + n_a)
            x_ba_eg = x_ba_eg.detach()
            x_ab_eg = x_ab_eg.detach()
            if not hyperparameters['origin']:
                x_ba_eg_digits = self.dis_a(x_ba_eg)
                x_ab_eg_digits = self.dis_b(x_ab_eg)

        # encode again
        h_b_recon, n_b_recon = self.gen_a.encode(x_ba_zero)
        h_a_recon, n_a_recon = self.gen_b.encode(x_ab_zero)
        # decode again (if needed)
        x_aba, _ = self.gen_a.decode(h_a_recon + n_a_recon, 
                z_var=pre_Z
                ) if hyperparameters['recon_x_cyc_w'] > 0 else (None, 0)
        x_bab, _ = self.gen_b.decode(h_b_recon + n_b_recon, 
                z_var=pre_Z
                ) if hyperparameters['recon_x_cyc_w'] > 0 else (None, 0)

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_kl_a = self.__compute_kl(h_a)
        self.loss_gen_recon_kl_b = self.__compute_kl(h_b)
        self.loss_gen_cyc_x_a = self.recon_criterion(x_aba, x_a) if x_aba is not None else 0
        self.loss_gen_cyc_x_b = self.recon_criterion(x_bab, x_b) if x_bab is not None else 0
        self.loss_gen_recon_kl_cyc_aba = self.__compute_kl(h_a_recon)
        self.loss_gen_recon_kl_cyc_bab = self.__compute_kl(h_b_recon)

        if hyperparameters['loss_eg_weight'] == 0:
            self.loss_gen_adv_a, self.loss_gen_adv_b, self.loss_gan_feat_a, \
                    self.loss_gan_feat_b = 0, 0, 0, 0
        elif not hyperparameters['origin']:
            x_ba_digits = self.dis_a(x_ba)
            x_a_digits = self.dis_a(x_a)
            _, self.loss_gen_adv_a, self.loss_gan_feat_a = \
                    self.compute_gan_loss(x_a_digits, x_ba_digits, 
                            self.criterionGAN, loss_at='G')

            x_ab_digits = self.dis_a(x_ab)
            x_b_digits = self.dis_a(x_b)
            _, self.loss_gen_adv_b, self.loss_gan_feat_b = \
                    self.compute_gan_loss(x_b_digits, x_ab_digits, 
                            self.criterionGAN, loss_at='G')
        else:
            self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
            self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
            self.loss_gan_feat_a, self.loss_gan_feat_b = 0, 0

        # GAN loss
        # self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        # self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        if hyperparameters['loss_eg_weight'] == 0:
            self.loss_eg = 0.0
        elif not hyperparameters['origin']:
            self.loss_eg = compute_eg_loss(x_ba_digits, x_ba_eg_digits, 
                    x_ab_digits, x_ab_eg_digits, z_var_ba_1, z_var_ba_2,
                     z_var_ab_1, z_var_ab_2, hyperparameters)
        else:
            self.loss_eg = compute_eg_loss(x_ba, x_ba_eg, 
                    x_ab, x_ab_eg, z_var_ba_1, z_var_ba_2, z_var_ab_1, z_var_ab_2, 
                    hyperparameters)
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_a + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_b + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              self.loss_gan_feat_b + self.loss_gan_feat_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b + \
                              hyperparameters['loss_eg_weight'] * self.loss_eg
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def sample(self, x_a, x_b):
        self.eval()
        x_a_recon, x_b_recon, x_ba, x_ab = [], [], [], []
        for i in range(x_a.size(0)):
            h_a, _ = self.gen_a.encode(x_a[i].unsqueeze(0))
            h_b, _ = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(h_a)[0])
            x_b_recon.append(self.gen_b.decode(h_b)[0])
            x_ba.append(self.gen_a.decode(h_b)[0])
            x_ab.append(self.gen_b.decode(h_a)[0])
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba = torch.cat(x_ba)
        x_ab = torch.cat(x_ab)
        self.train()
        return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        # encode
        h_a, n_a = self.gen_a.encode(x_a)
        h_b, n_b = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba, _ = self.gen_a.decode(h_b + n_b)
        x_ab, _ = self.gen_b.decode(h_a + n_a)
        # D loss
        if not hyperparameters['origin']:
            real_digits_a = self.dis_a(x_a)
            fake_digits_a = self.dis_a(x_ba.detach())
            real_digits_b = self.dis_b(x_b)
            fake_digits_b = self.dis_b(x_ab.detach())

            self.loss_dis_a, _, _ = self.compute_gan_loss(real_digits_a, fake_digits_a, 
                    self.criterionGAN, loss_at='D')
            self.loss_dis_b, _, _ = self.compute_gan_loss(real_digits_b, fake_digits_b, 
                    self.criterionGAN, loss_at='D')
        else:
            self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
            self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)


        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
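# --- Editor's sketch (hypothetical, self-contained): the feature-matching term
# --- from compute_digits_difference above, demonstrated on random stand-ins for
# --- the num_D=2 multiscale discriminator outputs (3 feature maps + 1 logit each).
if __name__ == '__main__':
    import torch
    import torch.nn.functional as F
    real = [[torch.randn(1, 8, 16, 16) for _ in range(4)] for _ in range(2)]
    fake = [[torch.randn(1, 8, 16, 16) for _ in range(4)] for _ in range(2)]
    feat_diff = 0
    for i in range(2):                         # two discriminator scales
        for j in range(len(fake[i]) - 1):      # skip the final prediction map
            feat_diff += (1.0 / 2.0) * (4.0 / 4.0) * F.l1_loss(
                fake[i][j], real[i][j].detach())
    print(feat_diff.item())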
Example #19
0
class DGNet_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(DGNet_Trainer, self).__init__()
        lr_g = hyperparameters['lr_g']
        lr_d = hyperparameters['lr_d']
        ID_class = hyperparameters['ID_class']
        if not 'apex' in hyperparameters.keys():
            hyperparameters['apex'] = False
        self.fp16 = hyperparameters['apex']
        # Initiate the networks
        # We do not need to manually set fp16 in the network for the new apex. So here I set fp16=False.
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'],
                              hyperparameters['gen'],
                              fp16=False)  # auto-encoder for domain a
        self.gen_b = self.gen_a  # auto-encoder for domain b

        if not 'ID_stride' in hyperparameters.keys():
            hyperparameters['ID_stride'] = 2

        if hyperparameters['ID_style'] == 'PCB':
            self.id_a = PCB(ID_class)
        elif hyperparameters['ID_style'] == 'AB':
            self.id_a = ft_netAB(ID_class,
                                 stride=hyperparameters['ID_stride'],
                                 norm=hyperparameters['norm_id'],
                                 pool=hyperparameters['pool'])
        else:
            self.id_a = ft_net(ID_class,
                               norm=hyperparameters['norm_id'],
                               pool=hyperparameters['pool'])  # return 2048 now

        self.id_b = self.id_a
        self.dis_a = MsImageDis(3, hyperparameters['dis'],
                                fp16=False)  # discriminator for domain a
        self.dis_b = self.dis_a  # discriminator for domain b

        # load teachers
        if hyperparameters['teacher'] != "":
            teacher_name = hyperparameters['teacher']
            print(teacher_name)
            teacher_names = teacher_name.split(',')
            teacher_model = nn.ModuleList()
            teacher_count = 0
            for teacher_name in teacher_names:
                config_tmp = load_config(teacher_name)
                if 'stride' in config_tmp:
                    stride = config_tmp['stride']
                else:
                    stride = 2
                model_tmp = ft_net(ID_class, stride=stride)
                teacher_model_tmp = load_network(model_tmp, teacher_name)
                teacher_model_tmp.model.fc = nn.Sequential(
                )  # remove the original fc layer in ImageNet
                teacher_model_tmp = teacher_model_tmp.cuda()
                if self.fp16:
                    teacher_model_tmp = amp.initialize(teacher_model_tmp,
                                                       opt_level="O1")
                teacher_model.append(teacher_model_tmp.cuda().eval())
                teacher_count += 1
            self.teacher_model = teacher_model
            if hyperparameters['train_bn']:
                self.teacher_model = self.teacher_model.apply(train_bn)

        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        display_size = int(hyperparameters['display_size'])

        # RGB to one channel
        if hyperparameters['single'] == 'edge':
            self.single = to_edge
        else:
            self.single = to_gray(False)

        # Random Erasing when training
        if not 'erasing_p' in hyperparameters.keys():
            hyperparameters['erasing_p'] = 0
        self.single_re = RandomErasing(
            probability=hyperparameters['erasing_p'], mean=[0.0, 0.0, 0.0])

        if not 'T_w' in hyperparameters.keys():
            hyperparameters['T_w'] = 1
        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(
            self.dis_a.parameters())  #+ list(self.dis_b.parameters())
        gen_params = list(
            self.gen_a.parameters())  #+ list(self.gen_b.parameters())

        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr_d,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr_g,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        # id params
        if hyperparameters['ID_style'] == 'PCB':
            ignored_params = (
                list(map(id, self.id_a.classifier0.parameters())) +
                list(map(id, self.id_a.classifier1.parameters())) +
                list(map(id, self.id_a.classifier2.parameters())) +
                list(map(id, self.id_a.classifier3.parameters())))
            base_params = filter(lambda p: id(p) not in ignored_params,
                                 self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD(
                [{
                    'params': base_params,
                    'lr': lr2
                }, {
                    'params': self.id_a.classifier0.parameters(),
                    'lr': lr2 * 10
                }, {
                    'params': self.id_a.classifier1.parameters(),
                    'lr': lr2 * 10
                }, {
                    'params': self.id_a.classifier2.parameters(),
                    'lr': lr2 * 10
                }, {
                    'params': self.id_a.classifier3.parameters(),
                    'lr': lr2 * 10
                }],
                weight_decay=hyperparameters['weight_decay'],
                momentum=0.9,
                nesterov=True)
        elif hyperparameters['ID_style'] == 'AB':
            ignored_params = (
                list(map(id, self.id_a.classifier1.parameters())) +
                list(map(id, self.id_a.classifier2.parameters())))
            base_params = filter(lambda p: id(p) not in ignored_params,
                                 self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD(
                [{
                    'params': base_params,
                    'lr': lr2
                }, {
                    'params': self.id_a.classifier1.parameters(),
                    'lr': lr2 * 10
                }, {
                    'params': self.id_a.classifier2.parameters(),
                    'lr': lr2 * 10
                }],
                weight_decay=hyperparameters['weight_decay'],
                momentum=0.9,
                nesterov=True)
        else:
            ignored_params = list(map(id, self.id_a.classifier.parameters()))
            base_params = filter(lambda p: id(p) not in ignored_params,
                                 self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD(
                [{
                    'params': base_params,
                    'lr': lr2
                }, {
                    'params': self.id_a.classifier.parameters(),
                    'lr': lr2 * 10
                }],
                weight_decay=hyperparameters['weight_decay'],
                momentum=0.9,
                nesterov=True)

        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        self.id_scheduler = get_scheduler(self.id_opt, hyperparameters)
        self.id_scheduler.gamma = hyperparameters['gamma2']

        #ID Loss
        self.id_criterion = nn.CrossEntropyLoss()
        self.criterion_teacher = nn.KLDivLoss(reduction='sum')  # size_average=False is deprecated
        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

        # save memory
        if self.fp16:
            # Wrap the networks and optimizers with apex amp for mixed precision
            assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
            self.gen_a = self.gen_a.cuda()
            self.dis_a = self.dis_a.cuda()
            self.id_a = self.id_a.cuda()

            self.gen_b = self.gen_a
            self.dis_b = self.dis_a
            self.id_b = self.id_a

            self.gen_a, self.gen_opt = amp.initialize(self.gen_a,
                                                      self.gen_opt,
                                                      opt_level="O1")
            self.dis_a, self.dis_opt = amp.initialize(self.dis_a,
                                                      self.dis_opt,
                                                      opt_level="O1")
            self.id_a, self.id_opt = amp.initialize(self.id_a,
                                                    self.id_opt,
                                                    opt_level="O1")

    def to_re(self, x):
        # Apply Random Erasing (self.single_re) to every image in the batch.
        out = torch.FloatTensor(x.size(0), x.size(1), x.size(2), x.size(3))
        out = out.cuda()
        for i in range(x.size(0)):
            out[i, :, :, :] = self.single_re(x[i, :, :, :])
        return out

    def recon_criterion(self, input, target):
        # L1 reconstruction loss; the target is detached so gradients only
        # flow through `input`.
        diff = input - target.detach()
        return torch.mean(torch.abs(diff))

    def recon_criterion_sqrt(self, input, target):
        # Square-root (Charbonnier-like) penalty; the epsilon keeps the
        # gradient finite at zero.
        diff = input - target
        return torch.mean(torch.sqrt(torch.abs(diff) + 1e-8))

    def recon_criterion2(self, input, target):
        # Mean squared error.
        diff = input - target
        return torch.mean(diff**2)

    def recon_cos(self, input, target):
        # Cosine distance (1 - cosine similarity), averaged over the batch.
        cos = torch.nn.CosineSimilarity()
        cos_dis = 1 - cos(input, target)
        return torch.mean(cos_dis)
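
    # All four criteria reduce to a scalar; recon_criterion additionally
    # detaches the target so gradients flow only through `input`. A toy
    # check, assuming a constructed trainer instance (illustrative tensors):
    #
    #     x = torch.zeros(2, 3, 4, 4, requires_grad=True)
    #     y = torch.ones(2, 3, 4, 4)
    #     trainer.recon_criterion(x, y)    # mean |x - y| = 1.0
    #     trainer.recon_criterion2(x, y)   # mean (x - y)^2 = 1.0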

    def forward(self, x_a, x_b):
        self.eval()
        s_a = self.gen_a.encode(self.single(x_a))
        s_b = self.gen_b.encode(self.single(x_b))
        f_a, _ = self.id_a(scale2(x_a))
        f_b, _ = self.id_b(scale2(x_b))
        x_ba = self.gen_a.decode(s_b, f_a)
        x_ab = self.gen_b.decode(s_a, f_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, l_a, xp_a, x_b, l_b, xp_b, hyperparameters,
                   iteration):
        # xp_a / xp_b are different photos of the same identities as x_a / x_b
        self.gen_opt.zero_grad()
        self.id_opt.zero_grad()
        # encode
        s_a = self.gen_a.encode(self.single(x_a))
        s_b = self.gen_b.encode(self.single(x_b))
        f_a, p_a = self.id_a(scale2(x_a))
        f_b, p_b = self.id_b(scale2(x_b))
        # autodecode
        x_a_recon = self.gen_a.decode(s_a, f_a)
        x_b_recon = self.gen_b.decode(s_b, f_b)

        # encode the same ID different photo
        fp_a, pp_a = self.id_a(scale2(xp_a))
        fp_b, pp_b = self.id_b(scale2(xp_b))

        # decode the same person
        x_a_recon_p = self.gen_a.decode(s_a, fp_a)
        x_b_recon_p = self.gen_b.decode(s_b, fp_b)

        # has gradient
        x_ba = self.gen_a.decode(s_b, f_a)
        x_ab = self.gen_b.decode(s_a, f_b)
        # no gradient
        x_ba_copy = x_ba.detach()
        x_ab_copy = x_ab.detach()

        rand_num = random.uniform(0, 1)
        #################################
        # encode structure
        if hyperparameters['use_encoder_again'] >= rand_num:
            # encode again (encoder is tuned, input is fixed)
            s_a_recon = self.gen_b.enc_content(self.single(x_ab_copy))
            s_b_recon = self.gen_a.enc_content(self.single(x_ba_copy))
        else:
            # copy the encoder
            self.enc_content_copy = copy.deepcopy(self.gen_a.enc_content)
            self.enc_content_copy = self.enc_content_copy.eval()
            # encode again (encoder is fixed, input is tuned)
            s_a_recon = self.enc_content_copy(self.single(x_ab))
            s_b_recon = self.enc_content_copy(self.single(x_ba))

        #################################
        # encode appearance
        self.id_a_copy = copy.deepcopy(self.id_a)
        self.id_a_copy = self.id_a_copy.eval()
        if hyperparameters['train_bn']:
            self.id_a_copy = self.id_a_copy.apply(train_bn)
        self.id_b_copy = self.id_a_copy
        # encode again (encoder is fixed, input is tuned)
        f_a_recon, p_a_recon = self.id_a_copy(scale2(x_ba))
        f_b_recon, p_b_recon = self.id_b_copy(scale2(x_ab))

        # teacher Loss
        #  Tune the ID model
        log_sm = nn.LogSoftmax(dim=1)
        if hyperparameters['teacher_w'] > 0 and hyperparameters[
                'teacher'] != "":
            if hyperparameters['ID_style'] == 'normal':
                _, p_a_student = self.id_a(scale2(x_ba_copy))
                p_a_student = log_sm(p_a_student)
                p_a_teacher = predict_label(self.teacher_model,
                                            scale2(x_ba_copy))
                self.loss_teacher = self.criterion_teacher(
                    p_a_student, p_a_teacher) / p_a_student.size(0)

                _, p_b_student = self.id_b(scale2(x_ab_copy))
                p_b_student = log_sm(p_b_student)
                p_b_teacher = predict_label(self.teacher_model,
                                            scale2(x_ab_copy))
                self.loss_teacher += self.criterion_teacher(
                    p_b_student, p_b_teacher) / p_b_student.size(0)
            elif hyperparameters['ID_style'] == 'AB':
                # normal teacher-student loss
                # BA -> LabelA(smooth) + LabelB(batchB)
                _, p_ba_student = self.id_a(scale2(x_ba_copy))  # f_a, s_b
                p_a_student = log_sm(p_ba_student[0])
                with torch.no_grad():
                    p_a_teacher = predict_label(
                        self.teacher_model,
                        scale2(x_ba_copy),
                        num_class=hyperparameters['ID_class'],
                        alabel=l_a,
                        slabel=l_b,
                        teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher = self.criterion_teacher(
                    p_a_student, p_a_teacher) / p_a_student.size(0)

                _, p_ab_student = self.id_b(scale2(x_ab_copy))  # f_b, s_a
                p_b_student = log_sm(p_ab_student[0])
                with torch.no_grad():
                    p_b_teacher = predict_label(
                        self.teacher_model,
                        scale2(x_ab_copy),
                        num_class=hyperparameters['ID_class'],
                        alabel=l_b,
                        slabel=l_a,
                        teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher += self.criterion_teacher(
                    p_b_student, p_b_teacher) / p_b_student.size(0)

                # branch b loss
                # here we give different label
                loss_B = self.id_criterion(p_ba_student[1],
                                           l_b) + self.id_criterion(
                                               p_ab_student[1], l_a)
                self.loss_teacher = hyperparameters[
                    'T_w'] * self.loss_teacher + hyperparameters['B_w'] * loss_B
        else:
            self.loss_teacher = 0.0
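
        # In every branch the distillation term is a batch-normalized KL:
        #     loss = sum_i p_teacher(i) * (log p_teacher(i) - log p_student(i)) / B,
        # which criterion_teacher (KLDivLoss with a summed reduction) computes
        # from the student's log-probabilities and the teacher's probabilities.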

        # decode again (if needed)
        if hyperparameters['use_decoder_again']:
            x_aba = self.gen_a.decode(
                s_a_recon,
                f_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
            x_bab = self.gen_b.decode(
                s_b_recon,
                f_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
        else:
            self.mlp_w_copy = copy.deepcopy(self.gen_a.mlp_w)
            self.mlp_b_copy = copy.deepcopy(self.gen_a.mlp_b)
            self.dec_copy = copy.deepcopy(
                self.gen_a.dec)  # decoder copy is kept fixed for this pass
            ID = f_a_recon
            ID_Style = ID.view(ID.shape[0], ID.shape[1], 1, 1)
            adain_params_w = self.mlp_w_copy(ID_Style)
            adain_params_b = self.mlp_b_copy(ID_Style)
            self.gen_a.assign_adain_params(adain_params_w, adain_params_b,
                                           self.dec_copy)
            x_aba = self.dec_copy(
                s_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None

            ID = f_b_recon
            ID_Style = ID.view(ID.shape[0], ID.shape[1], 1, 1)
            adain_params_w = self.mlp_w_copy(ID_Style)
            adain_params_b = self.mlp_b_copy(ID_Style)
            self.gen_a.assign_adain_params(adain_params_w, adain_params_b,
                                           self.dec_copy)
            x_bab = self.dec_copy(
                s_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # auto-encoder image reconstruction
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_xp_a = self.recon_criterion(x_a_recon_p, x_a)
        self.loss_gen_recon_xp_b = self.recon_criterion(x_b_recon_p, x_b)

        # feature reconstruction
        self.loss_gen_recon_s_a = self.recon_criterion(
            s_a_recon, s_a) if hyperparameters['recon_s_w'] > 0 else 0
        self.loss_gen_recon_s_b = self.recon_criterion(
            s_b_recon, s_b) if hyperparameters['recon_s_w'] > 0 else 0
        self.loss_gen_recon_f_a = self.recon_criterion(
            f_a_recon, f_a) if hyperparameters['recon_f_w'] > 0 else 0
        self.loss_gen_recon_f_b = self.recon_criterion(
            f_b_recon, f_b) if hyperparameters['recon_f_w'] > 0 else 0

        # Random Erasing only affects the ID and PID losses.
        if hyperparameters['erasing_p'] > 0:
            x_a_re = self.to_re(scale2(x_a.clone()))
            x_b_re = self.to_re(scale2(x_b.clone()))
            xp_a_re = self.to_re(scale2(xp_a.clone()))
            xp_b_re = self.to_re(scale2(xp_b.clone()))
            _, p_a = self.id_a(x_a_re)
            _, p_b = self.id_b(x_b_re)
            # encode the same ID different photo
            _, pp_a = self.id_a(xp_a_re)
            _, pp_b = self.id_b(xp_b_re)

        # ID loss AND Tune the Generated image
        if hyperparameters['ID_style'] == 'PCB':
            self.loss_id = self.PCB_loss(p_a, l_a) + self.PCB_loss(p_b, l_b)
            self.loss_pid = self.PCB_loss(pp_a, l_a) + self.PCB_loss(pp_b, l_b)
            self.loss_gen_recon_id = self.PCB_loss(
                p_a_recon, l_a) + self.PCB_loss(p_b_recon, l_b)
        elif hyperparameters['ID_style'] == 'AB':
            weight_B = hyperparameters['teacher_w'] * hyperparameters['B_w']
            self.loss_id = self.id_criterion(p_a[0], l_a) + self.id_criterion(p_b[0], l_b) \
                         + weight_B * ( self.id_criterion(p_a[1], l_a) + self.id_criterion(p_b[1], l_b) )
            self.loss_pid = self.id_criterion(pp_a[0], l_a) + self.id_criterion(
                pp_b[0], l_b
            )  #+ weight_B * ( self.id_criterion(pp_a[1], l_a) + self.id_criterion(pp_b[1], l_b) )
            self.loss_gen_recon_id = self.id_criterion(
                p_a_recon[0], l_a) + self.id_criterion(p_b_recon[0], l_b)
        else:
            self.loss_id = self.id_criterion(p_a, l_a) + self.id_criterion(
                p_b, l_b)
            self.loss_pid = self.id_criterion(pp_a, l_a) + self.id_criterion(
                pp_b, l_b)
            self.loss_gen_recon_id = self.id_criterion(
                p_a_recon, l_a) + self.id_criterion(p_b_recon, l_b)

        #print(f_a_recon, f_a)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(
            self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(
            self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0

        if iteration > hyperparameters['warm_iter']:
            hyperparameters['recon_f_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_f_w'] = min(hyperparameters['recon_f_w'],
                                               hyperparameters['max_w'])
            hyperparameters['recon_s_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_s_w'] = min(hyperparameters['recon_s_w'],
                                               hyperparameters['max_w'])
            hyperparameters['recon_x_cyc_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_x_cyc_w'] = min(
                hyperparameters['recon_x_cyc_w'], hyperparameters['max_cyc_w'])

        if iteration > hyperparameters['warm_teacher_iter']:
            hyperparameters['teacher_w'] += hyperparameters['warm_scale']
            hyperparameters['teacher_w'] = min(
                hyperparameters['teacher_w'], hyperparameters['max_teacher_w'])
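
        # Both warm-ups are linear ramps with a hard cap: once past the warm-up
        # iteration, each weight follows
        #     w(t) = min(w_start + warm_scale * (t - warm_iter), max_w),
        # applied incrementally one iteration at a time.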
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_a + \
                              hyperparameters['recon_f_w'] * self.loss_gen_recon_f_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_b + \
                              hyperparameters['recon_f_w'] * self.loss_gen_recon_f_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['id_w'] * self.loss_id + \
                              hyperparameters['pid_w'] * self.loss_pid + \
                              hyperparameters['recon_id_w'] * self.loss_gen_recon_id + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b + \
                              hyperparameters['teacher_w'] * self.loss_teacher
        if self.fp16:
            with amp.scale_loss(self.loss_gen_total,
                                [self.gen_opt, self.id_opt]) as scaled_loss:
                scaled_loss.backward()
            self.gen_opt.step()
            self.id_opt.step()
        else:
            self.loss_gen_total.backward()
            self.gen_opt.step()
            self.id_opt.step()
        print("L_total: %.4f, L_gan: %.4f,  Lx: %.4f, Lxp: %.4f, Lrecycle:%.4f, Lf: %.4f, Ls: %.4f, Recon-id: %.4f, id: %.4f, pid:%.4f, teacher: %.4f"%( self.loss_gen_total, \
                                                        hyperparameters['gan_w'] * (self.loss_gen_adv_a + self.loss_gen_adv_b), \
                                                        hyperparameters['recon_x_w'] * (self.loss_gen_recon_x_a + self.loss_gen_recon_x_b), \
                                                        hyperparameters['recon_xp_w'] * (self.loss_gen_recon_xp_a + self.loss_gen_recon_xp_b), \
                                                        hyperparameters['recon_x_cyc_w'] * (self.loss_gen_cycrecon_x_a + self.loss_gen_cycrecon_x_b), \
                                                        hyperparameters['recon_f_w'] * (self.loss_gen_recon_f_a + self.loss_gen_recon_f_b), \
                                                        hyperparameters['recon_s_w'] * (self.loss_gen_recon_s_a + self.loss_gen_recon_s_b), \
                                                        hyperparameters['recon_id_w'] * self.loss_gen_recon_id, \
                                                        hyperparameters['id_w'] * self.loss_id,\
                                                        hyperparameters['pid_w'] * self.loss_pid,\
hyperparameters['teacher_w'] * self.loss_teacher )  )

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)
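
    # The perceptual term compares instance-normalized VGG features, which
    # discards per-channel mean/variance (style) while keeping spatial
    # structure. A toy version with made-up feature maps:
    #
    #     f_img = torch.randn(1, 512, 8, 8)
    #     f_tgt = torch.randn(1, 512, 8, 8)
    #     inorm = nn.InstanceNorm2d(512, affine=False)
    #     loss = torch.mean((inorm(f_img) - inorm(f_tgt)) ** 2)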

    def PCB_loss(self, inputs, labels):
        loss = 0.0
        for part in inputs:
            loss += self.id_criterion(part, labels)
        return loss / len(inputs)
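
    # PCB heads return one logit tensor per body part; the ID loss is the
    # average cross-entropy over the parts. Toy shapes (4 parts to match the
    # classifier0..3 heads above; 751 classes is an illustrative choice):
    #
    #     parts = [torch.randn(8, 751) for _ in range(4)]
    #     labels = torch.randint(0, 751, (8,))
    #     ce = nn.CrossEntropyLoss()
    #     loss = sum(ce(p, labels) for p in parts) / len(parts)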

    def sample(self, x_a, x_b):
        self.eval()
        x_a_recon, x_b_recon, x_ba1, x_ab1, x_aba, x_bab = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            s_a = self.gen_a.encode(self.single(x_a[i].unsqueeze(0)))
            s_b = self.gen_b.encode(self.single(x_b[i].unsqueeze(0)))
            f_a, _ = self.id_a(scale2(x_a[i].unsqueeze(0)))
            f_b, _ = self.id_b(scale2(x_b[i].unsqueeze(0)))
            x_a_recon.append(self.gen_a.decode(s_a, f_a))
            x_b_recon.append(self.gen_b.decode(s_b, f_b))
            x_ba = self.gen_a.decode(s_b, f_a)
            x_ab = self.gen_b.decode(s_a, f_b)
            x_ba1.append(x_ba)
            x_ab1.append(x_ab)
            #cycle
            s_b_recon = self.gen_a.enc_content(self.single(x_ba))
            s_a_recon = self.gen_b.enc_content(self.single(x_ab))
            f_a_recon, _ = self.id_a(scale2(x_ba))
            f_b_recon, _ = self.id_b(scale2(x_ab))
            x_aba.append(self.gen_a.decode(s_a_recon, f_a_recon))
            x_bab.append(self.gen_b.decode(s_b_recon, f_b_recon))

        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_aba, x_bab = torch.cat(x_aba), torch.cat(x_bab)
        x_ba1, x_ab1 = torch.cat(x_ba1), torch.cat(x_ab1)
        self.train()

        return x_a, x_a_recon, x_aba, x_ab1, x_b, x_b_recon, x_bab, x_ba1

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        # encode
        s_a = self.gen_a.encode(self.single(x_a))
        s_b = self.gen_b.encode(self.single(x_b))
        f_a, _ = self.id_a(scale2(x_a))
        f_b, _ = self.id_b(scale2(x_b))
        # decode (cross domain)
        x_ba = self.gen_a.decode(s_b, f_a)
        x_ab = self.gen_b.decode(s_a, f_b)
        # D loss
        self.loss_dis_a, reg_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b, reg_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters[
            'gan_w'] * self.loss_dis_a + hyperparameters[
                'gan_w'] * self.loss_dis_b
        print("DLoss: %.4f" % self.loss_dis_total,
              "Reg: %.4f" % (reg_a + reg_b))
        if self.fp16:
            with amp.scale_loss(self.loss_dis_total,
                                self.dis_opt) as scaled_loss:
                scaled_loss.backward()
        else:
            self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()
        if self.id_scheduler is not None:
            self.id_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b = self.gen_a
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b = self.dis_a
        # Load ID dis
        last_model_name = get_model_list(checkpoint_dir, "id")
        state_dict = torch.load(last_model_name)
        self.id_a.load_state_dict(state_dict['a'])
        self.id_b = self.id_a
        # Load optimizers
        try:
            state_dict = torch.load(
                os.path.join(checkpoint_dir, 'optimizer.pt'))
            self.dis_opt.load_state_dict(state_dict['dis'])
            self.gen_opt.load_state_dict(state_dict['gen'])
            self.id_opt.load_state_dict(state_dict['id'])
        except Exception:
            # optimizer state may be missing or incompatible; start fresh
            pass
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        id_name = os.path.join(snapshot_dir, 'id_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict()}, dis_name)
        torch.save({'a': self.id_a.state_dict()}, id_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'id': self.id_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
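
# The checkpoint layout written by save() is one file per network plus a
# shared optimizer file, e.g. for a hypothetical snapshot at iteration 100000:
#     gen_00100000.pt  -> {'a': gen_a.state_dict()}
#     dis_00100000.pt  -> {'a': dis_a.state_dict()}
#     id_00100000.pt   -> {'a': id_a.state_dict()}
#     optimizer.pt     -> {'gen': ..., 'id': ..., 'dis': ...}
# resume() recovers the iteration count from the 8-digit suffix via
# int(last_model_name[-11:-3]); only the 'a' networks are stored because
# gen_b / dis_b / id_b share their weights.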

class UNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(UNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = VAEGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = VAEGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b

        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b
        self.dis_content = Dis_content()
        self.gpuid = hyperparameters['gpuID']
        # @ add background discriminator for each domain
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])

        self.content_opt = torch.optim.Adam(
            self.dis_content.parameters(),
            lr=lr / 2.,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        self.content_scheduler = get_scheduler(self.content_opt,
                                               hyperparameters)

        # Network weight initialization
        self.gen_a.apply(weights_init(hyperparameters['init']))
        self.gen_b.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))
        self.dis_content.apply(weights_init('gaussian'))

        # initialize the blur network
        self.BGBlur_kernel = [5, 9, 15]
        self.BlurNet = [
            GaussionSmoothLayer(3, k_size, 25).cuda(self.gpuid)
            for k_size in self.BGBlur_kernel
        ]
        self.BlurWeight = [0.25, 0.5, 1.]
        self.Gradient = GradientLoss(3, 3)

        # # Load VGG model if needed for test
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg19()
            if torch.cuda.is_available():
                self.vgg.cuda(self.gpuid)
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        h_a = self.gen_a.encode_cont(x_a)
        # h_a_sty = self.gen_a.encode_sty(x_a)
        # h_b = self.gen_b.encode_cont(x_b)

        x_ab = self.gen_b.decode_cont(h_a)
        # h_c = torch.cat((h_b, h_a_sty), 1)
        # x_ba = self.gen_a.decode_recs(h_c)
        # self.train()
        return x_ab  #, x_ba

    def __compute_kl(self, mu):
        # With a fixed unit variance, only the posterior mean is penalized.
        mu_2 = torch.pow(mu, 2)
        encoding_loss = torch.mean(mu_2)
        return encoding_loss
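
    # For a unit-variance Gaussian posterior N(mu, I),
    #     KL(N(mu, I) || N(0, I)) = 0.5 * ||mu||^2,
    # so penalizing mean(mu^2) matches this KL up to the constant factor 0.5.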

    def content_update(self, x_a, x_b, hyperparameters):  #
        # encode
        self.content_opt.zero_grad()
        enc_a = self.gen_a.encode_cont(x_a)
        enc_b = self.gen_b.encode_cont(x_b)
        pred_fake = self.dis_content.forward(enc_a)
        pred_real = self.dis_content.forward(enc_b)
        loss_D = 0
        if hyperparameters['gan_type'] == 'lsgan':
            loss_D += torch.mean((pred_fake - 0)**2) + torch.mean(
                (pred_real - 1)**2)
        elif hyperparameters['gan_type'] == 'nsgan':
            all0 = torch.zeros_like(pred_fake).cuda(self.gpuid)
            all1 = torch.ones_like(pred_real).cuda(self.gpuid)
            loss_D += torch.mean(
                F.binary_cross_entropy(torch.sigmoid(pred_fake), all0) +
                F.binary_cross_entropy(torch.sigmoid(pred_real), all1))
        else:
            assert 0, "Unsupported GAN type: {}".format(
                hyperparameters['gan_type'])
        loss_D.backward()
        nn.utils.clip_grad_norm_(self.dis_content.parameters(), 5)
        self.content_opt.step()

    def gen_update(self, x_a, x_b, hyperparameters):
        self.gen_opt.zero_grad()
        self.content_opt.zero_grad()
        # encode
        h_a = self.gen_a.encode_cont(x_a)
        h_b = self.gen_b.encode_cont(x_b)
        h_a_sty = self.gen_a.encode_sty(x_a)

        # add domain adversarial loss for the generator
        out_a = self.dis_content(h_a)
        out_b = self.dis_content(h_b)
        self.loss_ContentD = 0
        if hyperparameters['gan_type'] == 'lsgan':
            self.loss_ContentD += torch.mean((out_a - 0.5)**2) + torch.mean(
                (out_b - 0.5)**2)
        elif hyperparameters['gan_type'] == 'nsgan':
            all1 = 0.5 * torch.ones_like(out_b).cuda(self.gpuid)
            self.loss_ContentD += torch.mean(
                F.binary_cross_entropy(torch.sigmoid(out_a), all1) +
                F.binary_cross_entropy(torch.sigmoid(out_b), all1))
        else:
            assert 0, "Unsupported GAN type: {}".format(
                hyperparameters['gan_type'])
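
        # Generator-side content-adversarial targets: both encodings are pushed
        # toward 0.5 (maximum confusion), while content_update above trains the
        # content discriminator with hard 0/1 targets.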

        # decode (within domain)
        h_a_cont = torch.cat((h_a, h_a_sty), 1)
        noise_a = torch.randn(h_a_cont.size()).cuda(h_a_cont.data.get_device())
        x_a_recon = self.gen_a.decode_recs(h_a_cont + noise_a)
        noise_b = torch.randn(h_b.size()).cuda(h_b.data.get_device())
        x_b_recon = self.gen_b.decode_cont(h_b + noise_b)

        # decode (cross domain)
        h_ba_cont = torch.cat((h_b, h_a_sty), 1)
        x_ba = self.gen_a.decode_recs(h_ba_cont + noise_a)
        x_ab = self.gen_b.decode_cont(h_a + noise_b)
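
        # The additive unit-Gaussian noise implements VAE-style sampling:
        # the encoders predict posterior means and the decoders see
        # h + eps with eps ~ N(0, I), matching __compute_kl's penalty on mu.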

        # encode again
        h_b_recon = self.gen_a.encode_cont(x_ba)
        h_b_sty_recon = self.gen_a.encode_sty(x_ba)

        h_a_recon = self.gen_b.encode_cont(x_ab)

        # decode again (if needed)
        h_a_cat_recs = torch.cat((h_a_recon, h_b_sty_recon), 1)

        x_aba = self.gen_a.decode_recs(
            h_a_cat_recs +
            noise_a) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode_cont(
            h_b_recon +
            noise_b) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_kl_a = self.__compute_kl(h_a)
        self.loss_gen_recon_kl_b = self.__compute_kl(h_b)
        self.loss_gen_recon_kl_sty = self.__compute_kl(h_a_sty)

        self.loss_gen_cyc_x_a = self.recon_criterion(
            x_aba, x_a) if x_aba is not None else 0
        self.loss_gen_cyc_x_b = self.recon_criterion(
            x_bab, x_b) if x_bab is not None else 0
        self.loss_gen_recon_kl_cyc_aba = self.__compute_kl(h_a_recon)
        self.loss_gen_recon_kl_cyc_bab = self.__compute_kl(h_b_recon)
        self.loss_gen_recon_kl_cyc_sty = self.__compute_kl(h_b_sty_recon)

        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)

        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(
            self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(
            self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0

        # add background guide loss
        self.loss_bgm = 0
        if hyperparameters['BGM'] != 0:
            for index, weight in enumerate(self.BlurWeight):
                out_b = self.BlurNet[index](x_ba)
                out_real_b = self.BlurNet[index](x_b)
                out_a = self.BlurNet[index](x_ab)
                out_real_a = self.BlurNet[index](x_a)
                grad_loss_b = self.recon_criterion(out_b, out_real_b)
                grad_loss_a = self.recon_criterion(out_a, out_real_a)
                self.loss_bgm += weight * (grad_loss_a + grad_loss_b)
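
        # The background-guide term compares Gaussian-blurred versions of the
        # translation and the real target at three kernel sizes (5/9/15),
        # weighted 0.25/0.5/1.0, so low-frequency (background) content must
        # match while high-frequency foreground detail is free to change.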
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_sty + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_a + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_b + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_sty + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b + \
                              hyperparameters['BGM'] * self.loss_bgm + \
                              hyperparameters['gan_w'] * self.loss_ContentD
        self.loss_gen_total.backward()
        self.gen_opt.step()
        self.content_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def sample(self, x_a, x_b):
        if x_a is None or x_b is None:
            return None
        self.eval()
        x_a_recon, x_b_recon, x_ba, x_ab = [], [], [], []
        for i in range(x_a.size(0)):
            h_a = self.gen_a.encode_cont(x_a[i].unsqueeze(0))
            h_a_sty = self.gen_a.encode_sty(x_a[i].unsqueeze(0))
            h_b = self.gen_b.encode_cont(x_b[i].unsqueeze(0))

            h_ba_cont = torch.cat((h_b, h_a_sty), 1)

            h_aa_cont = torch.cat((h_a, h_a_sty), 1)

            x_a_recon.append(self.gen_a.decode_recs(h_aa_cont))
            x_b_recon.append(self.gen_b.decode_cont(h_b))

            x_ba.append(self.gen_a.decode_recs(h_ba_cont))
            x_ab.append(self.gen_b.decode_cont(h_a))

        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba = torch.cat(x_ba)
        x_ab = torch.cat(x_ab)
        self.train()
        return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        self.content_opt.zero_grad()
        # encode
        h_a = self.gen_a.encode_cont(x_a)
        h_a_sty = self.gen_a.encode_sty(x_a)
        h_b = self.gen_b.encode_cont(x_b)

        # @ add content adversarial loss
        out_a = self.dis_content(h_a)
        out_b = self.dis_content(h_b)
        self.loss_ContentD = 0
        if hyperparameters['gan_type'] == 'lsgan':
            self.loss_ContentD += torch.mean((out_a - 0)**2) + torch.mean(
                (out_b - 1)**2)
        elif hyperparameters['gan_type'] == 'nsgan':
            all0 = torch.zeros_like(out_a).cuda(self.gpuid)
            all1 = torch.ones_like(out_b).cuda(self.gpuid)
            self.loss_ContentD += torch.mean(
                F.binary_cross_entropy(torch.sigmoid(out_a), all0) +
                F.binary_cross_entropy(torch.sigmoid(out_b), all1))
        else:
            assert 0, "Unsupported GAN type: {}".format(
                hyperparameters['gan_type'])

        # decode (cross domain)
        h_cat = torch.cat((h_b, h_a_sty), 1)
        noise_b = torch.randn(h_cat.size()).cuda(h_cat.data.get_device())
        x_ba = self.gen_a.decode_recs(h_cat + noise_b)
        noise_a = torch.randn(h_a.size()).cuda(h_a.data.get_device())
        x_ab = self.gen_b.decode_cont(h_a + noise_a)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)

        self.loss_dis_total = hyperparameters['gan_w'] * (
            self.loss_dis_a + self.loss_dis_b + self.loss_ContentD)
        self.loss_dis_total.backward()
        nn.utils.clip_grad_norm_(self.dis_content.parameters(),
                                 5)  # dis_content update
        self.dis_opt.step()
        self.content_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()
        if self.content_scheduler is not None:
            self.content_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis_00188000")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])

        # load discontent discriminator
        last_model_name = get_model_list(checkpoint_dir, "dis_Content")
        state_dict = torch.load(last_model_name)
        self.dis_content.load_state_dict(state_dict['dis_c'])

        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        self.content_opt.load_state_dict(state_dict['dis_content'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        self.content_scheduler = get_scheduler(self.content_opt,
                                               hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        dis_Con_name = os.path.join(snapshot_dir,
                                    'dis_Content_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save({'dis_c': self.dis_content.state_dict()}, dis_Con_name)

        #  opt state
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict(),
                'dis_content': self.content_opt.state_dict()
            }, opt_name)
Example #21
0
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters["lr"]
        self.newsize = hyperparameters["crop_image_height"]
        self.semantic_w = hyperparameters["semantic_w"] > 0
        self.recon_mask = hyperparameters["recon_mask"] == 1
        self.dann_scheduler = None
        self.full_adaptation = hyperparameters["adaptation"][
            "full_adaptation"] == 1
        dim = hyperparameters["gen"]["dim"]
        n_downsample = hyperparameters["gen"]["n_downsample"]
        latent_dim = dim * (2**n_downsample)

        if "domain_adv_w" in hyperparameters.keys():
            self.domain_classif_ab = hyperparameters["domain_adv_w"] > 0
        else:
            self.domain_classif_ab = False

        if hyperparameters["adaptation"]["dfeat_lambda"] > 0:
            self.use_classifier_sr = True
        else:
            self.use_classifier_sr = False

        if hyperparameters["adaptation"]["sem_seg_lambda"] > 0:
            self.train_seg = True
        else:
            self.train_seg = False

        if hyperparameters["adaptation"]["output_classifier_lambda"] > 0:
            self.use_output_classifier_sr = True
        else:
            self.use_output_classifier_sr = False

        self.gen = SpadeGen(hyperparameters["input_dim_a"],
                            hyperparameters["gen"])

        # Note: the "+1" is for the masks
        if hyperparameters["dis"]["type"] == "patchgan":
            print("Using patchgan discrminator...")
            self.dis_a = MultiscaleDiscriminator(
                hyperparameters["input_dim_a"],
                hyperparameters["dis"])  # discriminator for domain a
            self.dis_b = MultiscaleDiscriminator(
                hyperparameters["input_dim_b"],
                hyperparameters["dis"])  # discriminator for domain b
            self.instancenorm = nn.InstanceNorm2d(512, affine=False)

            self.dis_a_masked = MultiscaleDiscriminator(
                hyperparameters["input_dim_a"],
                hyperparameters["dis"])  # discriminator for domain a
            self.dis_b_masked = MultiscaleDiscriminator(
                hyperparameters["input_dim_b"],
                hyperparameters["dis"])  # discriminator for domain b

        else:
            self.dis_a = MsImageDis(
                hyperparameters["input_dim_a"],
                hyperparameters["dis"])  # discriminator for domain a
            self.dis_b = MsImageDis(
                hyperparameters["input_dim_b"],
                hyperparameters["dis"])  # discriminator for domain b
            self.instancenorm = nn.InstanceNorm2d(512, affine=False)
            self.dis_a_masked = MsImageDis(
                hyperparameters["input_dim_a"],
                hyperparameters["dis"])  # discriminator for domain a
            self.dis_b_masked = MsImageDis(
                hyperparameters["input_dim_b"],
                hyperparameters["dis"])  # discriminator for domain b

        # fix the noise used in sampling
        display_size = int(hyperparameters["display_size"])
        # Setup the optimizers
        beta1 = hyperparameters["beta1"]
        beta2 = hyperparameters["beta2"]
        dis_params = (list(self.dis_a.parameters()) +
                      list(self.dis_b.parameters()) +
                      list(self.dis_a_masked.parameters()) +
                      list(self.dis_b_masked.parameters()))

        gen_params = list(self.gen.parameters())

        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters["weight_decay"],
        )
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters["weight_decay"],
        )
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters["init"]))
        self.dis_a.apply(weights_init("gaussian"))
        self.dis_b.apply(weights_init("gaussian"))
        self.dis_a_masked.apply(weights_init("gaussian"))
        self.dis_b_masked.apply(weights_init("gaussian"))

        # Load VGG model if needed
        if hyperparameters["vgg_w"] > 0:
            self.criterionVGG = VGGLoss()

        # Load semantic segmentation model if needed
        if "semantic_w" in hyperparameters.keys(
        ) and hyperparameters["semantic_w"] > 0:
            self.segmentation_model = load_segmentation_model(
                hyperparameters["semantic_ckpt_path"], 19)
            self.segmentation_model.eval()
            for param in self.segmentation_model.parameters():
                param.requires_grad = False

        # Load domain classifier if needed
        if "domain_adv_w" in hyperparameters.keys(
        ) and hyperparameters["domain_adv_w"] > 0:
            self.domain_classifier_ab = domainClassifier(input_dim=latent_dim,
                                                         dim=256)
            dann_params = list(self.domain_classifier_ab.parameters())
            self.dann_opt = torch.optim.Adam(
                [p for p in dann_params if p.requires_grad],
                lr=lr,
                betas=(beta1, beta2),
                weight_decay=hyperparameters["weight_decay"],
            )
            self.domain_classifier_ab.apply(weights_init("gaussian"))
            self.dann_scheduler = get_scheduler(self.dann_opt, hyperparameters)

        # Load classifier on features for syn, real adaptation
        if self.use_classifier_sr:
            #! Hardcoded
            self.domain_classifier_sr_b = domainClassifier(
                input_dim=latent_dim, dim=256)
            self.domain_classifier_sr_a = domainClassifier(
                input_dim=latent_dim, dim=256)

            dann_params = list(
                self.domain_classifier_sr_a.parameters()) + list(
                    self.domain_classifier_sr_b.parameters())
            self.classif_opt_sr = torch.optim.Adam(
                [p for p in dann_params if p.requires_grad],
                lr=lr,
                betas=(beta1, beta2),
                weight_decay=hyperparameters["weight_decay"],
            )
            self.domain_classifier_sr_a.apply(weights_init("gaussian"))
            self.domain_classifier_sr_b.apply(weights_init("gaussian"))
            self.classif_sr_scheduler = get_scheduler(self.classif_opt_sr,
                                                      hyperparameters)

        if self.use_output_classifier_sr:
            if self.hyperparameters["dis"]["type"] == "patchgan":
                self.output_classifier_sr_a = MultiscaleDiscriminator(
                    hyperparameters["input_dim_a"],
                    hyperparameters["dis"])  # discriminator for domain a,sr
                self.output_classifier_sr_b = MultiscaleDiscriminator(
                    hyperparameters["input_dim_a"],
                    hyperparameters["dis"])  # discriminator for domain b,sr

            else:
                self.output_classifier_sr_a = MsImageDis(
                    hyperparameters["input_dim_a"],
                    hyperparameters["dis"])  # discriminator for domain a,sr
                self.output_classifier_sr_b = MsImageDis(
                    hyperparameters["input_dim_a"],
                    hyperparameters["dis"])  # discriminator for domain b,sr

            dann_params = list(
                self.output_classifier_sr_a.parameters()) + list(
                    self.output_classifier_sr_b.parameters())
            self.output_classif_opt_sr = torch.optim.Adam(
                [p for p in dann_params if p.requires_grad],
                lr=lr,
                betas=(beta1, beta2),
                weight_decay=hyperparameters["weight_decay"],
            )
            self.output_classifier_sr_b.apply(weights_init("gaussian"))
            self.output_classifier_sr_a.apply(weights_init("gaussian"))
            self.output_scheduler_sr = get_scheduler(
                self.output_classif_opt_sr, hyperparameters)

        if self.train_seg:
            pretrained = load_segmentation_model(
                hyperparameters["semantic_ckpt_path"], 19)
            last_layer = nn.Conv2d(512, 10, kernel_size=1)
            model = torch.nn.Sequential(
                *list(pretrained.resnet34_8s.children())[7:-1],
                last_layer.cuda())
            self.segmentation_head = model

            for param in self.segmentation_head.parameters():
                param.requires_grad = True

            dann_params = list(self.segmentation_head.parameters())
            self.segmentation_opt = torch.optim.Adam(
                [p for p in dann_params if p.requires_grad],
                lr=lr,
                betas=(beta1, beta2),
                weight_decay=hyperparameters["weight_decay"],
            )
            self.scheduler_seg = get_scheduler(self.segmentation_opt,
                                               hyperparameters)

    def recon_criterion(self, input, target):
        """
        Compute the pixelwise L1 loss between two images, input and target.

        Arguments:
            input {torch.Tensor} -- Image tensor
            target {torch.Tensor} -- Image tensor

        Returns:
            torch.Tensor -- scalar pixelwise L1 loss
        """
        return torch.mean(torch.abs(input - target))

    def recon_criterion_mask(self, input, target, mask):
        """
        Compute a weaker version of recon_criterion between two images, input
        and target, where the L1 is only computed on the unmasked region.

        Arguments:
            input {torch.Tensor} -- Image (original image such as x_a)
            target {torch.Tensor} -- Image (after cycle-translation, such as x_aba)
            mask {torch.Tensor} -- binary mask of size HxW (input.shape ~ CxHxW)

        Returns:
            torch.Tensor -- scalar L1 loss over input*(1-mask) and target*(1-mask)
        """
        return torch.mean(torch.abs(torch.mul((input - target), 1 - mask)))
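
    # Toy check of the masked criterion (illustrative tensors; mask value 1
    # marks the region excluded from the L1):
    #
    #     x = torch.ones(1, 3, 2, 2)
    #     y = torch.zeros(1, 3, 2, 2)
    #     m = torch.tensor([[0., 1.], [1., 0.]])
    #     torch.mean(torch.abs((x - y) * (1 - m)))   # -> 0.5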

    def forward(self, x_a, x_b, m_a, m_b):
        """
        Perform the translation from domain A (resp. B) to domain B (resp. A):
        x_a to x_ab (resp. x_b to x_ba).

        Arguments:
            x_a {torch.Tensor} -- Image from domain A after transform, in tensor format
            x_b {torch.Tensor} -- Image from domain B after transform, in tensor format
            m_a {torch.Tensor} -- binary mask associated with x_a
            m_b {torch.Tensor} -- binary mask associated with x_b

        Returns:
            torch.Tensor, torch.Tensor -- Translated version of x_a in domain B,
            translated version of x_b in domain A
        """
        self.eval()
        x_a_augment = torch.cat([x_a, m_a], dim=1)
        x_b_augment = torch.cat([x_b, m_b], dim=1)

        c_a = self.gen.encode(x_a, 1)
        c_b = self.gen.encode(x_b, 2)

        # decode with the masks, as in gen_update
        x_ba = self.gen.decode(c_b, m_b, 1)
        x_ab = self.gen.decode(c_a, m_a, 2)

        self.train()
        return x_ab, x_ba

    def gen_update(
        self,
        x_a,
        x_b,
        hyperparameters,
        mask_a,
        mask_b,
        comet_exp=None,
        synth=False,
        semantic_gt_a=None,
        semantic_gt_b=None,
    ):
        """
        Update the generator parameters

        Arguments:
            x_a {torch.Tensor} -- Image from domain A after transform in tensor format
            x_b {torch.Tensor} -- Image from domain B after transform in tensor format
            hyperparameters {dictionnary} -- dictionnary with all hyperparameters 

        Keyword Arguments:
            mask_a {torch.Tensor} -- binary mask (0,1) corresponding to the ground in x_a (default: {None})
            mask_b {torch.Tensor} -- binary mask (0,1) corresponding to the water in x_b (default: {None})
            comet_exp {cometExperience} -- CometML object use to log all the loss and images (default: {None})
            synth {boolean}  -- binary True or False stating if we have a synthetic pair or not 

        Returns:
            [type] -- [description]
        """
        self.gen_opt.zero_grad()

        # encode
        x_a_augment = torch.cat([x_a, mask_a], dim=1)
        x_b_augment = torch.cat([x_b, mask_b], dim=1)

        c_a = self.gen.encode(x_a, 1)
        c_b = self.gen.encode(x_b, 2)

        # decode (within domain)
        x_a_recon = self.gen.decode(c_a, mask_a, 1)
        x_b_recon = self.gen.decode(c_b, mask_b, 2)

        x_ba = self.gen.decode(c_b, mask_b, 1)
        x_ab = self.gen.decode(c_a, mask_a, 2)

        x_ba_augment = torch.cat([x_ba, mask_b], dim=1)
        x_ab_augment = torch.cat([x_ab, mask_a], dim=1)
        # encode again
        c_b_recon = self.gen.encode(x_ba, 1)
        c_a_recon = self.gen.encode(x_ab, 2)

        # decode again (if needed)
        x_aba = (self.gen.decode(c_a_recon, mask_a, 1)
                 if hyperparameters["recon_x_cyc_w"] > 0 else None)
        x_bab = (self.gen.decode(c_b_recon, mask_b, 2)
                 if hyperparameters["recon_x_cyc_w"] > 0 else None)

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)

        # Context-preserving loss
        self.context_loss = self.recon_criterion_mask(
            x_ab, x_a, mask_a) + self.recon_criterion_mask(x_ba, x_b, mask_b)

        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)

        # Synthetic reconstruction loss
        if synth:
            # Mask of pixels that are identical across the synthetic pair
            mask_alignment = (torch.sum(torch.abs(x_a - x_b),
                                        1) == 0).unsqueeze(1)
            mask_alignment = mask_alignment.type(torch.cuda.FloatTensor)

        self.loss_gen_recon_synth = (
            self.recon_criterion_mask(x_ab, x_b, 1 - mask_alignment) +
            self.recon_criterion_mask(x_ba, x_a, 1 - mask_alignment)
            if synth else 0)
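
        # Passing (1 - mask_alignment) as the mask means the L1 is evaluated
        # only where the synthetic pair is pixel-identical, so the translation
        # is anchored on the region the pair shares.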

        if self.recon_mask:
            self.loss_gen_cycrecon_x_a = (self.recon_criterion_mask(
                x_aba, x_a, mask_a) if hyperparameters["recon_x_cyc_w"] > 0
                                          else 0)
            self.loss_gen_cycrecon_x_b = (self.recon_criterion_mask(
                x_bab, x_b, mask_b) if hyperparameters["recon_x_cyc_w"] > 0
                                          else 0)
        else:
            self.loss_gen_cycrecon_x_a = (self.recon_criterion(
                x_aba, x_a) if hyperparameters["recon_x_cyc_w"] > 0 else 0)
            self.loss_gen_cycrecon_x_b = (self.recon_criterion(
                x_bab, x_b) if hyperparameters["recon_x_cyc_w"] > 0 else 0)

        # GAN loss
        # Concat masks before feeding to loss

        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba,
                                                       x_a,
                                                       comet_exp,
                                                       mode="a")

        self.loss_gen_adv_a += self.dis_a_masked.calc_gen_loss(x_ba * mask_b,
                                                               x_a * mask_a,
                                                               comet_exp,
                                                               mode="a")

        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab,
                                                       x_b,
                                                       comet_exp,
                                                       mode="b")

        self.loss_gen_adv_b += self.dis_b_masked.calc_gen_loss(x_ab * mask_a,
                                                               x_b * mask_b,
                                                               comet_exp,
                                                               mode="b")

        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = (self.compute_vgg_loss(x_ba, x_b, mask_b)
                               if hyperparameters["vgg_w"] > 0 else 0)
        self.loss_gen_vgg_b = (self.compute_vgg_loss(x_ab, x_a, mask_a)
                               if hyperparameters["vgg_w"] > 0 else 0)

        # semantic-segmentation loss
        self.loss_sem_seg = (
            self.compute_semantic_seg_loss(x_a, x_ab, mask_a, semantic_gt_a) +
            self.compute_semantic_seg_loss(x_b, x_ba, mask_b, semantic_gt_b)
            if hyperparameters["semantic_w"] > 0 else 0)
        # Domain adversarial loss: c_a and c_b are swapped and the loss is not
        # minimized, so the generator maximizes the classifier loss (minimizes
        # its accuracy) and the content features become domain-uninformative.
        self.domain_adv_loss = (self.compute_domain_adv_loss(
            c_a, c_b, compute_accuracy=False, minimize=False)
                                if hyperparameters["domain_adv_w"] > 0 else 0)

        self.loss_classifier_sr = (self.compute_classifier_sr_loss(
            c_a, c_b, domain_synth=synth,
            fool=True) if hyperparameters["adaptation"]["adv_lambda"] > 0 else
                                   0)

        if hyperparameters["adaptation"]["output_adv_lambda"] > 0:
            self.loss_output_classifier_sr = self.output_classifier_sr_a.calc_gen_loss_sr(
                x_ba) + self.output_classifier_sr_b.calc_gen_loss_sr(x_ab)

        else:

            self.loss_output_classifier_sr = 0

        # total loss
        self.loss_gen_total = (
            hyperparameters["gan_w"] * self.loss_gen_adv_a +
            hyperparameters["gan_w"] * self.loss_gen_adv_b +
            hyperparameters["recon_x_w"] * self.loss_gen_recon_x_a +
            hyperparameters["recon_c_w"] * self.loss_gen_recon_c_a +
            hyperparameters["recon_x_w"] * self.loss_gen_recon_x_b +
            hyperparameters["recon_c_w"] * self.loss_gen_recon_c_b +
            hyperparameters["recon_x_cyc_w"] * self.loss_gen_cycrecon_x_a +
            hyperparameters["recon_x_cyc_w"] * self.loss_gen_cycrecon_x_b +
            hyperparameters["vgg_w"] * self.loss_gen_vgg_a +
            hyperparameters["vgg_w"] * self.loss_gen_vgg_b +
            hyperparameters["context_w"] * self.context_loss +
            hyperparameters["semantic_w"] * self.loss_sem_seg +
            hyperparameters["domain_adv_w"] * self.domain_adv_loss +
            hyperparameters["recon_synth_w"] * self.loss_gen_recon_synth +
            hyperparameters["adaptation"]["adv_lambda"] *
            self.loss_classifier_sr +
            hyperparameters["adaptation"]["output_adv_lambda"] *
            self.loss_output_classifier_sr)

        self.loss_gen_total.backward()
        self.gen_opt.step()

        if comet_exp is not None:

            comet_exp.log_metric("loss_gen_adv_a",
                                 self.loss_gen_adv_a.cpu().detach())
            comet_exp.log_metric("loss_gen_adv_b",
                                 self.loss_gen_adv_b.cpu().detach())
            comet_exp.log_metric("loss_gen_recon_x_a",
                                 self.loss_gen_recon_x_a.cpu().detach())
            comet_exp.log_metric("loss_gen_recon_x_b",
                                 self.loss_gen_recon_x_b.cpu().detach())

            if hyperparameters["recon_c_w"] > 0:
                comet_exp.log_metric("loss_gen_recon_c_a",
                                     self.loss_gen_recon_c_a.cpu().detach())
                comet_exp.log_metric("loss_gen_recon_c_b",
                                     self.loss_gen_recon_c_b.cpu().detach())

            if hyperparameters["recon_x_cyc_w"] > 0:
                comet_exp.log_metric("loss_gen_cycrecon_x_a",
                                     self.loss_gen_cycrecon_x_a.cpu().detach())
                comet_exp.log_metric("loss_gen_cycrecon_x_b",
                                     self.loss_gen_cycrecon_x_b.cpu().detach())
            comet_exp.log_metric("loss_gen_total",
                                 self.loss_gen_total.cpu().detach())

            if hyperparameters["vgg_w"] > 0:
                comet_exp.log_metric("loss_gen_vgg_a",
                                     self.loss_gen_vgg_a.cpu().detach())
                comet_exp.log_metric("loss_gen_vgg_b",
                                     self.loss_gen_vgg_b.cpu().detach())
            if hyperparameters["semantic_w"] > 0:
                comet_exp.log_metric("loss_sem_seg",
                                     self.loss_sem_seg.cpu().detach())
            if hyperparameters["context_w"] > 0:
                comet_exp.log_metric("context_preserve_loss",
                                     self.context_loss.cpu().detach())
            if hyperparameters["domain_adv_w"] > 0:
                comet_exp.log_metric("domain_adv_loss_gen",
                                     self.domain_adv_loss.cpu().detach())
            if synth:
                comet_exp.log_metric("loss_gen_recon_synth",
                                     self.loss_gen_recon_synth.cpu().detach())
            if self.use_classifier_sr:
                comet_exp.log_metric("loss_classifier_adv_sr",
                                     self.loss_classifier_sr.cpu().detach())
            if self.use_output_classifier_sr:
                comet_exp.log_metric(
                    "loss_output_classifier_adv_sr",
                    self.loss_output_classifier_sr.cpu().detach())

    def compute_vgg_loss(self, img, target, mask):
        """ 
        Compute the domain-invariant perceptual loss
        
        Arguments:
            vgg {model} -- popular Convolutional Network for Classification and Detection
            img {torch.Tensor} -- image before translation
            target {torch.Tensor} -- image after translation
        
        Returns:
            torch.Float -- domain invariant perceptual loss
        """
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)

        # Mask input to VGG:
        img_vgg = img_vgg * (1.0 - mask)
        target_vgg = target_vgg * (1.0 - mask)

        loss_G_VGG = self.criterionVGG(img_vgg, target_vgg)

        return loss_G_VGG
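
    # `criterionVGG` is defined elsewhere in the repository. A minimal sketch
    # of a compatible perceptual criterion, assuming (not verified here) that
    # it compares frozen VGG feature maps with an L1 distance:
    #
    #   class VGGFeatureLoss(nn.Module):
    #       def __init__(self, vgg):
    #           super().__init__()
    #           self.vgg = vgg  # frozen feature extractor
    #
    #       def forward(self, img, target):
    #           return torch.mean(torch.abs(self.vgg(img) - self.vgg(target)))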

    def compute_classifier_sr_loss(self,
                                   c_a,
                                   c_b,
                                   domain_synth=False,
                                   fool=False):
        """ 
        Compute classifier loss for the adaptation s/r 
        
        Arguments:
            c_a {torch.Tensor} -- content of x_a
            c_b {torch.Tensor} -- content of x_b
            domain_synth {Boolean} -- Whether if the content is from s or r
            fool {Boolean} -- Wheter we want to fool the classifier or not 
        
        Returns:
            torch.Float -- domain invariant perceptual loss
        """
        # Infer domain classifier on content extracted from an image of domainA
        output_a = self.domain_classifier_sr_a(c_a)

        # Infer domain classifier on content extracted from an image of domainB
        output_b = self.domain_classifier_sr_b(c_b)

        if fool:
            # Push predictions towards 0.5: the encoders want the classifier undecided.
            loss = torch.mean((output_a - 0.5)**2) + torch.mean(
                (output_b - 0.5)**2)
        else:
            if domain_synth:
                # Synthetic content should be classified as 0.
                loss = torch.mean((output_a - 0)**2) + torch.mean(
                    (output_b - 0)**2)
            else:
                # Real content should be classified as 1.
                loss = torch.mean((output_a - 1)**2) + torch.mean(
                    (output_b - 1)**2)

        return loss
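
    # Usage in this trainer: gen_update calls this with fool=True so the
    # content encoders drive the classifier towards 0.5 (undecided), while
    # domain_classifier_sr_update calls it with fool=False on detached codes
    # so the classifier itself learns the true synthetic/real targets:
    #
    #   adv_loss = self.compute_classifier_sr_loss(c_a, c_b, fool=True)
    #   cls_loss = self.compute_classifier_sr_loss(c_a.detach(), c_b.detach(),
    #                                              domain_synth, fool=False)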

    def compute_domain_adv_loss(self,
                                c_a,
                                c_b,
                                compute_accuracy=False,
                                minimize=True):
        """ 
        Compute a domain adversarial loss on the embedding of the classifier:
        we are trying to learn an anonymized representation of the content. 
        
        Arguments:
            c_a {torch.tensor} -- content extracted from an image of domain A with encoder A
            c_b {torch.tensor} -- content extracted from an image of domain B with encoder B
        
        Keyword Arguments:
            compute_accuracy {bool} -- either return only the loss or loss and softmax probs
            (default: {False})
            minimize {bool} -- optimize classification accuracy(True) or anonymized the representation(False)
        
        Returns:
            torch.Float -- loss (optionnal softmax P(classifier(c_a)=a) and P(classifier(c_b)=b)) 
        """
        # Infer domain classifier on content extracted from an image of domainA
        output_a = self.domain_classifier_ab(c_a)

        # Infer domain classifier on content extracted from an image of domainB
        output_b = self.domain_classifier_ab(c_b)

        # Concatenate the output in a single vector
        output = torch.cat((output_a, output_b))

        if minimize:
            # One-hot targets: classify c_a as domain A and c_b as domain B
            # (assumes two-class outputs concatenated into a 4-vector).
            target = torch.tensor([1.0, 0.0, 0.0, 1.0], device="cuda")
        else:
            # Uniform targets: make the classifier maximally uncertain.
            target = torch.tensor([0.5, 0.5, 0.5, 0.5], device="cuda")
        # mean squared error loss
        loss = torch.nn.MSELoss()(output, target)
        if compute_accuracy:
            return loss, output_a[0], output_b[1]
        else:
            return loss
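
    # Min-max pattern in this trainer: domain_classifier_update calls this with
    # minimize=True to train the classifier on one-hot domain targets, while
    # gen_update uses minimize=False so the encoders push the classifier's
    # output towards the uninformative 0.5 targets:
    #
    #   cls_loss, p_a, p_b = self.compute_domain_adv_loss(
    #       c_a, c_b, compute_accuracy=True, minimize=True)
    #   adv_loss = self.compute_domain_adv_loss(c_a, c_b, minimize=False)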

    def compute_semantic_seg_loss(self,
                                  img1,
                                  img2,
                                  mask=None,
                                  ground_truth=None):
        """
        Compute semantic segmentation loss between two images on the unmasked region or in the entire image
        Arguments:
            img1 {torch.Tensor} -- Image from domain A after transform in tensor format
            img2 {torch.Tensor} -- Image transformed
            mask {torch.Tensor} -- Binary mask where we force the loss to be zero
            ground_truth {torch.Tensor} -- If available palletized image of size (batch, h, w) 
        Returns:
            torch.float -- Cross entropy loss on the unmasked region
        """
        new_class = 19
        # denorm
        img1_denorm = (img1 + 1) / 2.0
        img2_denorm = (img2 + 1) / 2.0
        # norm for semantic seg network
        input_transformed1 = seg_batch_transform(img1_denorm)
        input_transformed2 = seg_batch_transform(img2_denorm)

        # compute labels from original image and logits from translated version
        # target = (
        #   self.segmentation_model(input_transformed1).max(1)[1]
        # )
        # Infer x_ab or x_ba
        output = self.segmentation_model(input_transformed2)

        # If we have a ground truth (simulated data), merge classes to fit the ground truth of our simulated world (19 to 10)
        if ground_truth is not None:
            target = ground_truth.type(torch.long).cuda()
            target = target.squeeze(1)
            output = merge_classes(output).cuda()
            new_class = 10

        else:
            # Else use the pretrained model
            target = self.segmentation_model(input_transformed1).max(1)[1]

        # If we don't want to compute the loss on the masked region
        if not self.full_adaptation and mask is not None:
            # Resize the mask to the segmentation output size.
            # NOTE: this interpolation is coupled to the segmentation model's
            # input size; change with care.
            mask1 = torch.nn.functional.interpolate(mask,
                                                    size=(self.newsize,
                                                          self.newsize))
            mask1_tensor = mask1.to(dtype=torch.long).cuda()
            mask1_tensor = mask1_tensor.squeeze(1)

            # Label the masked region as an extra "unknown" class
            # (new_class is one past the last real label).
            target_with_mask = (torch.mul(1 - mask1_tensor, target) +
                                mask1_tensor * new_class
                                )  # categorical tensor (B, H, W)

            mask2 = torch.nn.functional.interpolate(mask,
                                                    size=(self.newsize,
                                                          self.newsize))
            mask_tensor = mask2.to(dtype=torch.float).cuda()
            output_with_mask = torch.mul(1 - mask_tensor, output)

            # Concatenate the mask to the logits as the "unknown" channel so the
            # loss is constant (zero-gradient) over the masked region.
            output_with_mask_cat = torch.cat((output_with_mask, mask_tensor),
                                             dim=1)
            loss = nn.CrossEntropyLoss()(output_with_mask_cat,
                                         target_with_mask)

        else:
            loss = nn.CrossEntropyLoss()(output, target)
        return loss
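
    # Worked example of the mask-as-extra-class trick above, for a single
    # masked pixel: its label becomes `new_class`, its real logits are zeroed,
    # and the concatenated mask channel (value 1) becomes the only non-zero
    # logit, so the cross-entropy at that pixel is a constant that sends no
    # gradient back through the translated image:
    #
    #   logits = torch.zeros(1, new_class)            # real-class logits, zeroed out
    #   logits = torch.cat([logits, torch.ones(1, 1)], dim=1)
    #   target = torch.tensor([new_class])            # the "unknown" label
    #   nn.CrossEntropyLoss()(logits, target)         # constant, independent of x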

    def sample(self, x_a, x_b, m_a, m_b):
        """ 
        Infer the model on a batch of image
        
        Arguments:
            x_a {torch.Tensor} -- batch of image from domain A
            x_b {[type]} -- batch of image from domain B
        
        Returns:
            A list of torch images -- columnwise :x_a, autoencode(x_a), x_ab_1, x_ab_2
            Or if self.semantic_w is true: x_a, autoencode(x_a), Semantic segmentation x_a, 
            x_ab_1,semantic segmentation x_ab_1, x_ab_2
        """
        self.eval()

        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []

        for i in range(x_a.size(0)):
            c_a = self.gen.encode(x_a[i].unsqueeze(0), 1)
            c_b = self.gen.encode(x_b[i].unsqueeze(0), 2)

            x_a_recon.append(self.gen.decode(c_a, m_a[i].unsqueeze(0), 1))
            x_b_recon.append(self.gen.decode(c_b, m_b[i].unsqueeze(0), 2))

            x_ba1.append(self.gen.decode(c_b, m_b[i].unsqueeze(0),
                                         1))  # s_a1[i].unsqueeze(0)))
            x_ba2.append(self.gen.decode(c_b, m_b[i].unsqueeze(0),
                                         1))  # s_a2[i].unsqueeze(0)))
            x_ab1.append(self.gen.decode(c_a, m_a[i].unsqueeze(0),
                                         2))  # s_b1[i].unsqueeze(0)))
            x_ab2.append(self.gen.decode(c_a, m_a[i].unsqueeze(0),
                                         2))  # s_b2[i].unsqueeze(0)))

        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)

        if self.semantic_w:
            rgb_a_list, rgb_b_list, rgb_ab_list, rgb_ba_list = [], [], [], []

            for i in range(x_a.size(0)):

                # Inference semantic segmentation on original images
                im_a = (x_a[i].squeeze() + 1) / 2.0
                im_b = (x_b[i].squeeze() + 1) / 2.0

                input_transformed_a = seg_transform()(im_a).unsqueeze(0)
                input_transformed_b = seg_transform()(im_b).unsqueeze(0)
                output_a = self.segmentation_model(
                    input_transformed_a).squeeze().max(0)[1]
                output_b = self.segmentation_model(
                    input_transformed_b).squeeze().max(0)[1]

                rgb_a = decode_segmap(output_a.cpu().numpy())
                rgb_b = decode_segmap(output_b.cpu().numpy())
                rgb_a = Image.fromarray(rgb_a).resize(
                    (x_a.size(3), x_a.size(3)))
                rgb_b = Image.fromarray(rgb_b).resize(
                    (x_a.size(3), x_a.size(3)))

                rgb_a_list.append(transforms.ToTensor()(rgb_a).unsqueeze(0))
                rgb_b_list.append(transforms.ToTensor()(rgb_b).unsqueeze(0))

                # Inference semantic segmentation on fake images
                image_ab = (x_ab1[i].squeeze() + 1) / 2.0
                image_ba = (x_ba1[i].squeeze() + 1) / 2.0

                input_transformed_ab = seg_transform()(image_ab).unsqueeze(
                    0).to("cuda")
                input_transformed_ba = seg_transform()(image_ba).unsqueeze(
                    0).to("cuda")

                output_ab = self.segmentation_model(
                    input_transformed_ab).squeeze().max(0)[1]
                output_ba = self.segmentation_model(
                    input_transformed_ba).squeeze().max(0)[1]

                rgb_ab = decode_segmap(output_ab.cpu().numpy())
                rgb_ba = decode_segmap(output_ba.cpu().numpy())

                rgb_ab = Image.fromarray(rgb_ab).resize(
                    (x_a.size(3), x_a.size(3)))
                rgb_ba = Image.fromarray(rgb_ba).resize(
                    (x_a.size(3), x_a.size(3)))

                rgb_ab_list.append(transforms.ToTensor()(rgb_ab).unsqueeze(0))
                rgb_ba_list.append(transforms.ToTensor()(rgb_ba).unsqueeze(0))

            rgb1_a, rgb1_b, rgb1_ab, rgb1_ba = (
                torch.cat(rgb_a_list).cuda(),
                torch.cat(rgb_b_list).cuda(),
                torch.cat(rgb_ab_list).cuda(),
                torch.cat(rgb_ba_list).cuda(),
            )

        self.train()
        # Overlay mask onto image:
        save_m_a = x_a - (x_a * m_a.repeat(1, 3, 1, 1)) + m_a.repeat(
            1, 3, 1, 1)
        save_m_b = x_b - (x_b * m_b.repeat(1, 3, 1, 1)) + m_b.repeat(
            1, 3, 1, 1)

        if self.semantic_w:
            self.segmentation_model.eval()
            return (
                x_a,
                x_a_recon,
                rgb1_a,
                x_ab1,
                rgb1_ab,
                x_ab1 * m_a,
                save_m_a,
                x_b,
                x_b_recon,
                rgb1_b,
                x_ba1,
                rgb1_ba,
                x_ba2 * m_b,
                save_m_b,
            )
        else:
            return (
                x_a,
                x_a_recon,
                x_ab1,
                x_ab1 * m_a,
                save_m_a,
                x_b,
                x_b_recon,
                x_ba1,
                x_ba2 * m_b,
                save_m_b,
            )

    def sample_syn(self, x_a, x_b, m_a, m_b):
        """ 
        Infer the model on a batch of image
        
        Arguments:
            x_a {torch.Tensor} -- batch of image from domain A
            x_b {[type]} -- batch of image from domain B
        
        Returns:
            A list of torch images -- columnwise :x_a, autoencode(x_a), x_ab_1, x_ab_2
            Or if self.semantic_w is true: x_a, autoencode(x_a), Semantic segmentation x_a, 
            x_ab_1,semantic segmentation x_ab_1, x_ab_2
        """
        self.eval()

        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []

        for i in range(x_a.size(0)):
            c_a = self.gen.encode(x_a[i].unsqueeze(0), 1)
            c_b = self.gen.encode(x_b[i].unsqueeze(0), 2)
            x_a_recon.append(self.gen.decode(c_a, m_a[i].unsqueeze(0), 1))
            x_b_recon.append(self.gen.decode(c_b, m_b[i].unsqueeze(0), 2))

            x_ba1.append(self.gen.decode(c_b, m_b[i].unsqueeze(0), 1))
            x_ba2.append(self.gen.decode(c_b, m_b[i].unsqueeze(0), 1))
            x_ab1.append(self.gen.decode(c_a, m_a[i].unsqueeze(0), 2))
            x_ab2.append(self.gen.decode(c_a, m_a[i].unsqueeze(0), 2))

        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)

        if self.semantic_w:
            rgb_a_list, rgb_b_list, rgb_ab_list, rgb_ba_list = [], [], [], []

            for i in range(x_a.size(0)):

                # Inference semantic segmentation on original images
                im_a = (x_a[i].squeeze() + 1) / 2.0
                im_b = (x_b[i].squeeze() + 1) / 2.0

                input_transformed_a = seg_transform()(im_a).unsqueeze(0)
                input_transformed_b = seg_transform()(im_b).unsqueeze(0)
                output_a = self.segmentation_model(
                    input_transformed_a).squeeze().max(0)[1]
                output_b = self.segmentation_model(
                    input_transformed_b).squeeze().max(0)[1]

                rgb_a = decode_segmap(output_a.cpu().numpy())
                rgb_b = decode_segmap(output_b.cpu().numpy())
                rgb_a = Image.fromarray(rgb_a).resize(
                    (x_a.size(3), x_a.size(3)))
                rgb_b = Image.fromarray(rgb_b).resize(
                    (x_a.size(3), x_a.size(3)))

                rgb_a_list.append(transforms.ToTensor()(rgb_a).unsqueeze(0))
                rgb_b_list.append(transforms.ToTensor()(rgb_b).unsqueeze(0))

                # Inference semantic segmentation on fake images
                image_ab = (x_ab1[i].squeeze() + 1) / 2.0
                image_ba = (x_ba1[i].squeeze() + 1) / 2.0

                input_transformed_ab = seg_transform()(image_ab).unsqueeze(
                    0).to("cuda")
                input_transformed_ba = seg_transform()(image_ba).unsqueeze(
                    0).to("cuda")

                output_ab = self.segmentation_model(
                    input_transformed_ab).squeeze().max(0)[1]
                output_ba = self.segmentation_model(
                    input_transformed_ba).squeeze().max(0)[1]

                rgb_ab = decode_segmap(output_ab.cpu().numpy())
                rgb_ba = decode_segmap(output_ba.cpu().numpy())

                rgb_ab = Image.fromarray(rgb_ab).resize(
                    (x_a.size(3), x_a.size(3)))
                rgb_ba = Image.fromarray(rgb_ba).resize(
                    (x_a.size(3), x_a.size(3)))

                rgb_ab_list.append(transforms.ToTensor()(rgb_ab).unsqueeze(0))
                rgb_ba_list.append(transforms.ToTensor()(rgb_ba).unsqueeze(0))

            rgb1_a, rgb1_b, rgb1_ab, rgb1_ba = (
                torch.cat(rgb_a_list).cuda(),
                torch.cat(rgb_b_list).cuda(),
                torch.cat(rgb_ab_list).cuda(),
                torch.cat(rgb_ba_list).cuda(),
            )

        self.train()
        if self.semantic_w:
            self.segmentation_model.eval()
            return (
                x_a,
                x_a_recon,
                rgb1_a,
                x_ab1,
                rgb1_ab,
                x_ab2,
                x_b,
                x_b_recon,
                rgb1_b,
                x_ba1,
                rgb1_ba,
                x_ba2,
            )
        else:
            return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_update(self, x_a, x_b, m_a, m_b, hyperparameters, comet_exp=None):
        """
        Update the weights of the discriminator
        
        Arguments:
            x_a {torch.Tensor} -- Image from domain A after transform in tensor format
            x_b {torch.Tensor} -- Image from domain B after transform in tensor format
            hyperparameters {dictionnary} -- dictionnary with all hyperparameters 
        
        Keyword Arguments:
            comet_exp {cometExperience} -- CometML object use to log all the loss and images (default: {None})        
        """
        self.dis_opt.zero_grad()

        # encode
        c_a = self.gen.encode(x_a, 1)
        c_b = self.gen.encode(x_b, 2)
        # decode (cross domain)
        x_ba = self.gen.decode(c_b, m_b, 1)
        x_ab = self.gen.decode(c_a, m_a, 2)


        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(),
                                                   x_a,
                                                   comet_exp,
                                                   mode="a")

        # Detach the whole fake input, not just the mask, so no gradient
        # reaches the generator during the discriminator update.
        self.loss_dis_a += self.dis_a_masked.calc_dis_loss(
            (x_ba * m_b).detach(), x_a * m_a, comet_exp, mode="a")

        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(),
                                                   x_b,
                                                   comet_exp,
                                                   mode="b")

        self.loss_dis_b += self.dis_b_masked.calc_dis_loss(
            (x_ab * m_a).detach(), x_b * m_b, comet_exp, mode="b")

        self.loss_dis_total = (hyperparameters["gan_w"] * self.loss_dis_a +
                               hyperparameters["gan_w"] * self.loss_dis_b)
        self.loss_dis_total.backward()
        self.dis_opt.step()

        if comet_exp is not None:
            comet_exp.log_metric("loss_dis_b", self.loss_dis_b.cpu().detach())
            comet_exp.log_metric("loss_dis_a", self.loss_dis_a.cpu().detach())

    def domain_classifier_update(self,
                                 x_a,
                                 x_b,
                                 hyperparameters,
                                 comet_exp=None):
        """
        Update the weights of the domain classifier
        
        Arguments:
            x_a {torch.Tensor} -- Image from domain A after transform in tensor format
            x_b {torch.Tensor} -- Image from domain B after transform in tensor format
            hyperparameters {dictionnary} -- dictionnary with all hyperparameters 
        
        Keyword Arguments:
            comet_exp {cometExperience} -- CometML object use to log all the loss and images (default: {None})        
        """
        self.dann_opt.zero_grad()

        # encode
        c_a = self.gen.encode(x_a, 1)
        c_b = self.gen.encode(x_b, 2)

        # domain classifier loss
        self.domain_class_loss, out_a, out_b = self.compute_domain_adv_loss(
            c_a, c_b, compute_accuracy=True, minimize=True)

        self.domain_class_loss.backward()
        self.dann_opt.step()

        if comet_exp is not None:
            comet_exp.log_metric("domain_class_loss",
                                 self.domain_class_loss.cpu().detach())
            comet_exp.log_metric("probability A being identified as A",
                                 out_a.cpu().detach())
            comet_exp.log_metric("probability B being identified as B",
                                 out_b.cpu().detach())

    def domain_classifier_sr_update(self,
                                    x_a,
                                    x_b,
                                    m_a,
                                    m_b,
                                    domain_synth,
                                    lambda_classifier,
                                    step,
                                    comet_exp=None):
        """
        Update the synthetic/real domain classifier on detached content codes.
        """
        self.classif_opt_sr.zero_grad()

        # encode
        c_a = self.gen.encode(x_a, 1)
        c_b = self.gen.encode(x_b, 2)

        # noise = c_a.data.new(c_a.size()).normal_(0, 1)
        loss = self.compute_classifier_sr_loss(c_a.detach(),
                                               c_b.detach(),
                                               domain_synth,
                                               fool=False)
        loss = lambda_classifier * loss
        loss.backward()
        self.classif_opt_sr.step()

        if comet_exp is not None:
            comet_exp.log_metric("loss_classifier_sr",
                                 loss.cpu().detach(),
                                 step=step)

    def output_domain_classifier_sr_update(self,
                                           x_ar,
                                           x_as,
                                           x_br,
                                           x_bs,
                                           hyperparameters,
                                           step,
                                           comet_exp=None):
        """
        Update the output-space synthetic/real classifiers on translated images.
        """
        self.output_classif_opt_sr.zero_grad()

        loss = self.output_classifier_sr_b.calc_dis_loss_sr(
            x_bs, x_br) + self.output_classifier_sr_a.calc_dis_loss_sr(
                x_as, x_ar)
        loss = hyperparameters["adaptation"]["output_classifier_lambda"] * loss
        loss.backward()

        self.output_classif_opt_sr.step()

        if comet_exp is not None:
            comet_exp.log_metric("loss_output_classifier_sr",
                                 loss.cpu().detach(),
                                 step=step)

    def segmentation_head_update(self,
                                 x_a,
                                 x_b,
                                 target_a,
                                 target_b,
                                 lamb,
                                 comet_exp=None):
        """
        Update the segmentation head on the content codes of both domains.
        """
        self.segmentation_opt.zero_grad()

        # encode
        c_a = self.gen.encode(x_a, 1)
        c_b = self.gen.encode(x_b, 2)

        output_a = self.segmentation_head(c_a)
        output_b = self.segmentation_head(c_b)
        output_a = nn.functional.interpolate(input=output_a,
                                             size=(self.newsize, self.newsize),
                                             mode="bilinear")
        output_b = nn.functional.interpolate(input=output_b,
                                             size=(self.newsize, self.newsize),
                                             mode="bilinear")

        loss1 = nn.CrossEntropyLoss()(output_a, target_a.type(
            torch.long).squeeze(1).cuda())
        loss2 = nn.CrossEntropyLoss()(output_b, target_b.type(
            torch.long).squeeze(1).cuda())
        loss = (loss1 + loss2) * lamb

        loss.backward()
        self.segmentation_opt.step()

        if comet_exp is not None:
            comet_exp.log_metric("loss_semantic_head", loss.cpu().detach())

    def update_learning_rate(self):
        """ 
        Update the learning rate
        """
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()
        if self.dann_scheduler is not None:
            self.dann_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """
        Resume the training loading the network parameters
        
        Arguments:
            checkpoint_dir {string} -- path to the directory where the checkpoints are saved
            hyperparameters {dictionnary} -- dictionnary with all hyperparameters 
        
        Returns:
            int -- number of iterations (used by the optimizer)
        """
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)

        self.gen.load_state_dict(state_dict["2"])

        # Load domain classifier
        if self.domain_classif_ab == 1:
            last_model_name = get_model_list(checkpoint_dir, "domain_classif")
            state_dict = torch.load(last_model_name)
            self.domain_classifier.load_state_dict(state_dict["d"])

        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict["a"])
        self.dis_b.load_state_dict(state_dict["b"])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, "optimizer.pt"))
        self.dis_opt.load_state_dict(state_dict["dis"])
        self.gen_opt.load_state_dict(state_dict["gen"])

        if self.domain_classif_ab == 1:
            self.dann_opt.load_state_dict(state_dict["dann"])
            self.dann_scheduler = get_scheduler(self.dann_opt, hyperparameters,
                                                iterations)
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print("Resume from iteration %d" % iterations)
        return iterations
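
    # Worked example of the iteration parsing above: checkpoints are named
    # "gen_%08d.pt" (see save() below), so for
    #   last_model_name = ".../gen_00010000.pt"
    # the slice last_model_name[-11:-3] is "00010000" and int(...) yields 10000.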

    def save(self, snapshot_dir, iterations):
        """
        Save generators, discriminators, and optimizers
        
        Arguments:
            snapshot_dir {string} -- directory path where to save the networks weights
            iterations {int} -- number of training iterations
        """
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, "gen_%08d.pt" % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, "dis_%08d.pt" % (iterations + 1))
        domain_classifier_name = os.path.join(
            snapshot_dir, "domain_classifier_%08d.pt" % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, "optimizer.pt")

        torch.save({"2": self.gen.state_dict()}, gen_name)
        torch.save({
            "a": self.dis_a.state_dict(),
            "b": self.dis_b.state_dict()
        }, dis_name)
        if self.domain_classif_ab:
            torch.save({"d": self.domain_classifier.state_dict()},
                       domain_classifier_name)
            torch.save(
                {
                    "gen": self.gen_opt.state_dict(),
                    "dis": self.dis_opt.state_dict(),
                    "dann": self.dann_opt.state_dict(),
                },
                opt_name,
            )
        else:
            torch.save(
                {
                    "gen": self.gen_opt.state_dict(),
                    "dis": self.dis_opt.state_dict()
                },
                opt_name,
            )
Example #22
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters, opts):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        self.opts = opts

        # Initiate the networks
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        self.seg = segmentor(num_classes=2, channels=hyperparameters['input_dim_b'], hyperpars=hyperparameters['seg'])

        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.seg_opt = torch.optim.SGD(self.seg.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)

        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters['lr_policy'], hyperparameters=hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters['lr_policy'], hyperparameters=hyperparameters)
        self.seg_scheduler = get_scheduler(self.seg_opt, 'constant', hyperparameters=None)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        self.criterion_seg = DiceLoss(ignore_index=hyperparameters['seg']['ignore_index'])

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))
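
    # Note: this is the mean absolute error, equivalent to nn.L1Loss()(input, target).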

    def forward(self, x_a, x_b):
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, s_a_fake = self.gen_a.encode(x_a)
        c_b, s_b_fake = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters, target_a, iters):
        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None

        if iters >= hyperparameters['guide_gen_iters']:
            config.task = 0
            self.seg.eval()
            self.pred_x_ab = self.seg(x_ab)
            self.seg.train()

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0

        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)

        # semantic loss ab
        if iters >= hyperparameters['guide_gen_iters']:
            self.loss_sem_ab, _ = self.criterion_seg(self.pred_x_ab, target_a)
        else:
            self.loss_sem_ab = 0

        # Only use the semantic loss once the segmentor itself has a reasonably
        # low loss (the Dice criterion used here appears to be negative-valued,
        # so a value above -0.3 means the segmentor is not yet reliable).
        if not hasattr(self, 'loss_seg_ab') or self.loss_seg_ab.detach().item() > -0.3:
            self.loss_sem_ab = 0

        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['sem_w'] * self.loss_sem_ab

        self.loss_gen_total.backward()
        self.gen_opt.step()
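
    # Note: the semantic loss above is gated twice: it is only computed after
    # `guide_gen_iters` warm-up iterations, and it is zeroed again unless the
    # segmentor's own loss (loss_seg_ab, updated in seg_update below) is
    # already low enough.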

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def seg_update(self, x_a, x_b, target_a, target_b):
        self.seg.train()
        self.seg_opt.zero_grad()
        s_b = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        with torch.no_grad():
            # encode
            c_a, _ = self.gen_a.encode(x_a)
            # decode (cross domain)
            x_ab = self.gen_b.decode(c_a, s_b)

        config.task = 0
        self.pred_x_ab = self.seg(x_ab.detach())

        config.task = 1
        self.pred_x_b = self.seg(x_b)

        self.loss_seg_ab, _ = self.criterion_seg(self.pred_x_ab, target_a)
        self.loss_seg_b, _ = self.criterion_seg(self.pred_x_b, target_b)

        self.loss_seg_total = self.loss_seg_ab + self.loss_seg_b
        self.loss_seg_total.backward()
        self.seg_opt.step()
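
    # Design note: x_ab is produced under torch.no_grad() and detached, so
    # seg_update trains only the segmentor on frozen translations; the
    # generator receives semantic gradients solely through gen_update.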

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()
        if self.seg_scheduler is not None:
            self.seg_scheduler.step()

    def sample(self, x_a, x_b):
        self.eval()
        s_a1 = Variable(self.s_a)
        s_b1 = Variable(self.s_b)
        s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        x_a_recon, x_b_recon, x_aba, x_bab, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], [], [], []
        for i in range(x_b.size(0)):
            # encode
            c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
            c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
            # decode (within domain)
            x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
            x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
            # decode (cross domain)
            x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
            x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
            x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
            x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
            # encode again
            c_b_recon, s_a_recon = self.gen_a.encode(x_ba1[-1])
            c_a_recon, s_b_recon = self.gen_b.encode(x_ab1[-1])
            x_aba.append(self.gen_a.decode(c_a_recon, s_a_fake))
            x_bab.append(self.gen_b.decode(c_b_recon, s_b_fake))

        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        x_aba, x_bab = torch.cat(x_aba), torch.cat(x_bab)

        self.train()
        return x_a, x_a_recon, x_aba, x_ab1, x_ab2, x_b, x_b_recon, x_bab, x_ba1, x_ba2

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, 'gen')
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, 'dis')
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load segmentor
        last_model_name = get_model_list(checkpoint_dir, 'seg')
        state_dict = torch.load(last_model_name)
        self.seg.load_state_dict(state_dict)

        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'opt.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])

        state_dict = torch.load(os.path.join(checkpoint_dir, 'opt_seg.pt'))
        self.seg_opt.load_state_dict(state_dict)

        # Reinitilize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters['lr_policy'], hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters['lr_policy'], hyperparameters, iterations)
        self.seg_scheduler = get_scheduler(self.seg_opt, 'constant', None, iterations)

        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        seg_name = os.path.join(snapshot_dir, 'seg_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'opt.pt')
        opt_seg_name = os.path.join(snapshot_dir, 'opt_seg.pt')

        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save(self.seg.state_dict(), seg_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
        torch.save(self.seg_opt.state_dict(), opt_seg_name)
Example #23
class ERGAN_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(ERGAN_Trainer, self).__init__()
        lr_G = hyperparameters['lr_G']
        lr_D = hyperparameters['lr_D']
        print(lr_D, lr_G)
        self.fp16 = hyperparameters['fp16']
        # Initiate the networks
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.gen_b.enc_content = self.gen_a.enc_content  # content share weight
        #self.gen_b.enc_style = self.gen_a.enc_style
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']
        self.a = hyperparameters['gen']['new_size'] / 224  # scale factor: the crop coordinates below are defined at 224 px
        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())

        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr_D,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr_G,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])

        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        if self.fp16:
            self.gen_a = self.gen_a.cuda()
            self.dis_a = self.dis_a.cuda()
            self.gen_b = self.gen_b.cuda()
            self.dis_b = self.dis_b.cuda()
            self.gen_a, self.gen_opt = amp.initialize(self.gen_a,
                                                      self.gen_opt,
                                                      opt_level="O1")
            self.dis_a, self.dis_opt = amp.initialize(self.dis_a,
                                                      self.dis_opt,
                                                      opt_level="O1")

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        # Cast to the target's dtype so fp16 (amp O1) and fp32 tensors compare safely.
        input = input.type_as(target)
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, s_a_fake = self.gen_a.encode(x_a)
        c_b, s_b_fake = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a, x_b)
        x_ab = self.gen_b.decode(c_a, s_b, x_a)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        self.gen_opt.zero_grad()
        #mask = torch.ones(x_a.shape).cuda()
        block_a = x_a.clone()
        block_b = x_b.clone()
        block_a[:, :,
                round(self.a * 92):round(self.a * 144),
                round(self.a * 48):round(self.a * 172)] = 0
        block_b[:, :,
                round(self.a * 92):round(self.a * 144),
                round(self.a * 48):round(self.a * 172)] = 0

        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime, x_a)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime, x_b)

        # decode (cross domain)
        # decode random
        # x_ba_randn = self.gen_a.decode(c_b, s_a, x_b)
        # x_ab_randn = self.gen_b.decode(c_a, s_b, x_a)
        # decode real
        x_ba_real = self.gen_a.decode(c_b, s_a_prime, x_b)
        x_ab_real = self.gen_b.decode(c_a, s_b_prime, x_a)

        block_ba_real = x_ba_real.clone()
        block_ab_real = x_ab_real.clone()
        block_ba_real[:, :,
                      round(self.a * 92):round(self.a * 144),
                      round(self.a * 48):round(self.a * 172)] = 0
        block_ab_real[:, :,
                      round(self.a * 92):round(self.a * 144),
                      round(self.a * 48):round(self.a * 172)] = 0
        # encode again
        # c_b_recon, s_a_recon = self.gen_a.encode(x_ba_randn)
        # c_a_recon, s_b_recon = self.gen_b.encode(x_ab_randn)

        c_b_real_recon, s_a_prime_recon = self.gen_a.encode(x_ba_real)
        c_a_real_recon, s_b_prime_recon = self.gen_b.encode(x_ab_real)
        # decode again (if needed)
        x_aba = self.gen_a.decode(
            c_a_real_recon, s_a_prime,
            x_a) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(
            c_b_real_recon, s_b_prime,
            x_b) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction loss
        self.loss_gen_recon_res_a = self.recon_criterion(
            block_ab_real, block_a)
        self.loss_gen_recon_res_b = self.recon_criterion(
            block_ba_real, block_b)
        self.loss_gen_recon_x_a_re = self.recon_criterion(
            x_a_recon[:, :,
                      round(self.a * 92):round(self.a * 144),
                      round(self.a * 48):round(self.a * 172)],
            x_a[:, :,
                round(self.a * 92):round(self.a * 144),
                round(self.a * 48):round(self.a * 172)])
        self.loss_gen_recon_x_b_re = self.recon_criterion(
            x_b_recon[:, :,
                      round(self.a * 92):round(self.a * 144),
                      round(self.a * 48):round(self.a * 172)],
            x_b[:, :,
                round(self.a * 92):round(self.a * 144),
                round(self.a * 48):round(self.a * 172)]
        )  # both celebA and MeGlass: [92:144, 48:172]
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)

        self.loss_gen_recon_s_a_prime = self.recon_criterion(
            s_a_prime_recon, s_a_prime)
        self.loss_gen_recon_s_b_prime = self.recon_criterion(
            s_b_prime_recon, s_b_prime)

        self.loss_gen_recon_c_a_real = self.recon_criterion(
            c_a_real_recon, c_a)
        self.loss_gen_recon_c_b_real = self.recon_criterion(
            c_b_real_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a_real = self.dis_a.calc_gen_loss(x_ba_real)
        self.loss_gen_adv_b_real = self.dis_b.calc_gen_loss(x_ab_real)
        # domain-invariant perceptual loss
        # self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba_randn, x_b) if hyperparameters['vgg_w'] > 0 else 0
        # self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab_randn, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a_real + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_x_w_re'] * self.loss_gen_recon_x_b_re + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b_prime + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b_real + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_x_w_re'] * self.loss_gen_recon_x_a_re + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a_prime + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a_real + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b_real + \
                              hyperparameters['recon_x_w_res'] * self.loss_gen_recon_res_a + \
                              hyperparameters['recon_x_w_res'] * self.loss_gen_recon_res_b

        if self.fp16:
            with amp.scale_loss(self.loss_gen_total,
                                self.gen_opt) as scaled_loss:
                scaled_loss.backward()
            self.gen_opt.step()

        else:
            self.loss_gen_total.backward()
            self.gen_opt.step()

        #loss_gan = hyperparameters['gan_w'] * (self.loss_gen_adv_a + self.loss_gen_adv_b)
        loss_gan_real = hyperparameters['gan_w'] * (self.loss_gen_adv_a_real +
                                                    self.loss_gen_adv_b_real)
        loss_x = hyperparameters['recon_x_w'] * (self.loss_gen_recon_x_a +
                                                 self.loss_gen_recon_x_b)
        loss_x_re = hyperparameters['recon_x_w_re'] * (
            self.loss_gen_recon_x_a_re + self.loss_gen_recon_x_b_re)
        #loss_x_res = hyperparameters['recon_x_w_res'] * (self.loss_gen_recon_res_a + self.loss_gen_recon_res_b)
        #loss_s = hyperparameters['recon_s_w'] * (self.loss_gen_recon_s_a + self.loss_gen_recon_s_b)
        #loss_c = hyperparameters['recon_c_w'] * (self.loss_gen_recon_c_a + self.loss_gen_recon_c_b)
        loss_s_prime = hyperparameters['recon_s_w'] * (
            self.loss_gen_recon_s_a_prime + self.loss_gen_recon_s_b_prime)
        loss_c_real = hyperparameters['recon_c_w'] * (
            self.loss_gen_recon_c_a_real + self.loss_gen_recon_c_b_real)
        loss_x_cyc = hyperparameters['recon_x_cyc_w'] * (
            self.loss_gen_cycrecon_x_a + self.loss_gen_cycrecon_x_b)

        #loss_vgg = hyperparameters['vgg_w'] * (self.loss_gen_vgg_a + self.loss_gen_vgg_b)
        print(
            '||total:%.2f||gan_real:%.2f||x:%.2f||x_re:%.2f||s_prime:%.4f||c_real:%.2f||x_cyc:%.4f||'
            % (self.loss_gen_total, loss_gan_real, loss_x, loss_x_re,
               loss_s_prime, loss_c_real, loss_x_cyc))

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)
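
    # Note: instance-normalizing both feature maps before the squared
    # difference is the domain-invariant perceptual loss from MUNIT: it
    # discards per-channel mean and variance (style statistics) and compares
    # only the normalized content features.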

    def sample(self, x_a, x_b):
        self.eval()
        # s_a1 = Variable(self.s_a)
        # s_b1 = Variable(self.s_b)
        # s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        # s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())

        #x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        x_a_recon, x_b_recon, x_bab, x_ab, x_ba, x_aba = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
            c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))

            x_a_recon.append(
                self.gen_a.decode(c_a, s_a_fake, x_a[i].unsqueeze(0)))
            x_b_recon.append(
                self.gen_b.decode(c_b, s_b_fake, x_b[i].unsqueeze(0)))

            x_ba_tmp = self.gen_a.decode(c_b, s_a_fake, x_b[i].unsqueeze(0))
            x_ab_tmp = self.gen_b.decode(c_a, s_b_fake, x_a[i].unsqueeze(0))
            x_ba.append(x_ba_tmp)
            x_ab.append(x_ab_tmp)

            c_b_recon, _ = self.gen_a.encode(x_ba_tmp)
            c_a_recon, _ = self.gen_b.encode(x_ab_tmp)

            x_aba.append(
                self.gen_a.decode(c_a_recon, s_a_fake, x_a[i].unsqueeze(0)))
            x_bab.append(
                self.gen_b.decode(c_b_recon, s_b_fake, x_b[i].unsqueeze(0)))

        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)

        x_aba, x_bab = torch.cat(x_aba), torch.cat(x_bab)
        x_ab, x_ba = torch.cat(x_ab), torch.cat(x_ba)
        self.train()

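        # Output ordering matches a two-row display grid: the first six
        # tensors are the A-side row (a, a_recon, ab, ba, b, aba) and the
        # last six the B-side row (b, b_recon, ba, ab, a, bab).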
        return x_a, x_a_recon, x_ab, x_ba, x_b, x_aba, x_b, x_b_recon, x_ba, x_ab, x_a, x_bab

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        # s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        # s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (cross domain)

        x_ba_real = self.gen_a.decode(c_b, s_a_prime, x_b)
        x_ab_real = self.gen_b.decode(c_a, s_b_prime, x_a)
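        # The translated images are detached below so that only the
        # discriminators receive gradients from this update.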
        # D loss
        # self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba_randn.detach(), x_a)
        # self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab_randn.detach(), x_b)

        self.loss_dis_a_real = self.dis_a.calc_dis_loss(
            x_ba_real.detach(), x_a)
        self.loss_dis_b_real = self.dis_b.calc_dis_loss(
            x_ab_real.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * (
            self.loss_dis_a_real + self.loss_dis_b_real)

        if self.fp16:
            with amp.scale_loss(self.loss_dis_total,
                                self.dis_opt) as scaled_loss:
                scaled_loss.backward()
            self.dis_opt.step()
        else:
            self.loss_dis_total.backward()
            self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
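        # Checkpoint files are named like gen_00375000.pt, so the 8 digits
        # before ".pt" encode the iteration count.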
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations -
                                           375000)  #fine_tune -370000
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations - 375000)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
Example #24
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # fix the noise used in sampling
        self.display_size = int(hyperparameters['display_size'])
        self.s_a = self.random_style()
        self.s_b = self.random_style()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
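        # L1 (mean absolute error) loss, shared by the image, style, and
        # content reconstruction terms.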
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, s_a_fake = self.gen_a.encode(x_a)
        c_b, s_b_fake = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        """
        Args:
            x_a: Image domain A
            x_b: Image domain B
            hyperparameters:

        Returns:

        """

        self.gen_opt.zero_grad()
        s_a = self.random_style(x_a)
        s_b = self.random_style(x_b)

        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a) # c_a - content encoding, s_a_prime - style encoding
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime) # x_a_recon - reconstruction from content and style vectors
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a) # content b, style a
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba) # encode to get content_b and style_a from cross domain image
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def random_style(self, x=None, factor=1):
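        # Draw style codes from N(0, I); factor rescales them (sample() passes
        # factor=5 to exaggerate style variation in the visualizations).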
        dim = self.display_size if x is None else x.size(0)
        return Variable(torch.randn(dim, self.style_dim, 1, 1).cuda()) * factor

    def sample(self, x_a, x_b):
        """

        Args:
            x_a:
            x_b:

        Returns:
            (tuple): domainA: original
                              reconstruction
                              A to B - fixed sample noise
                              A to B - random noise

        """

        self.eval()
        s_a1 = Variable(self.s_a)
        s_b1 = Variable(self.s_b)
        s_a2 = self.random_style(x_a, factor=5)
        s_b2 = self.random_style(x_b, factor=5)
        #print(s_a1, s_a2)
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
            c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
            x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
            x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
            x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
            x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
            x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def sampleB_toA(self, x_b):
        """
        Args:
            x_b:

        Returns:
            (tuple, length=batch): INPUT IMAGES to domain A
        """

        self.eval()
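        # Note: unlike sample(), this helper does not switch back to train
        # mode, so it is intended for inference only.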
        s_a2 = self.random_style(x_b)

        x_ba1 = []

        for i in range(x_b.size(0)):  # loop over images in the batch
            c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_ba1.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
        return x_ba1

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        s_a = self.random_style(x_a)
        s_b = self.random_style(x_b)
        # encode
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters, gen_model=None, dis_model=None):
        # Load generators
        if gen_model is None:
            gen_model = get_model_list(checkpoint_dir, "gen") # last gen model
        gen_state_dict = torch.load(gen_model)
        self.gen_a.load_state_dict(gen_state_dict['a'])
        self.gen_b.load_state_dict(gen_state_dict['b'])
        iterations = int(gen_model[-11:-3])

        # Load discriminators
        if dis_model is None:
            dis_model = get_model_list(checkpoint_dir, "dis")
        dis_state_dict = torch.load(dis_model)
        self.dis_a.load_state_dict(dis_state_dict['a'])
        self.dis_b.load_state_dict(dis_state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def load_model(self, checkpoint_dir, hyperparameters, iteration):
        from pathlib import Path
        gen_model = Path(checkpoint_dir) / f"gen_{iteration:08d}.pt"
        dis_model = Path(checkpoint_dir) / f"dis_{iteration:08d}.pt"
        return self.resume(checkpoint_dir, hyperparameters, gen_model.as_posix(), dis_model.as_posix())

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
Example #25
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'], hyperparameters['new_size'],
            hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        self.reg_param = hyperparameters['reg_param']
        self.beta_step = hyperparameters['beta_step']
        self.target_kl = hyperparameters['target_kl']
        self.gan_type = hyperparameters['gan_type']

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_b.parameters())
        gen_params = list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.gen_b.apply(weights_init(hyperparameters['init']))
        self.dis_b.apply(weights_init('gaussian'))

        # SSIM Loss
        self.ssim_loss = pytorch_ssim.SSIM()

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def recon_criterion_l1(self, input, target, mask):
        return torch.sum(torch.abs(input - target)) / torch.sum(mask)

    def forward(self, x_a, x_b):
        self.eval()
        s_b = self.gen_b.enc_style(x_b)
        c_a = self.gen_b.enc_content(x_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab

    def gen_update(self, x_a, x_b, hyperparameters):
        toogle_grad(self.dis_b, False)
        toogle_grad(self.gen_b, True)
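        # Freeze the discriminator and unfreeze the generator for this update
        # (toogle_grad presumably toggles requires_grad on all parameters).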
        self.dis_b.train()
        self.gen_b.train()
        self.gen_opt.zero_grad()
        s_b = torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()
        # encode
        c_a = self.gen_b.enc_content(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)

        x_ab.requires_grad_()
        # reconstruction loss
        self.loss_gen_recon_x_ab_ssim = -self.ssim_loss.forward(x_a, x_ab)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        # GAN loss
        _, _, d_fake = self.dis_b(x_ab)
        # d_fake = d_fake['out']
        self.loss_gen_adv_b = self.compute_loss(d_fake, 1)
        # total loss
        self.loss_gen_total = self.loss_gen_adv_b + \
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
            hyperparameters['recon_x_ab'] * self.loss_gen_recon_x_ab_ssim
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def sample(self, x_a, x_b):
        self.eval()
        x_ab = []
        s_b = self.gen_b.enc_style(x_b)
        for i in range(x_a.size(0)):
            c_a = self.gen_b.enc_content(x_a[i].unsqueeze(0))
            x_ab.append(self.gen_b.decode(c_a, s_b))
        x_ab = torch.cat(x_ab)
        self.train()
        return x_a, x_ab

    def dis_update(self, x_a, x_b, hyperparameters):
        toogle_grad(self.gen_b, False)
        toogle_grad(self.dis_b, True)
        self.gen_b.train()
        self.dis_b.train()
        self.dis_opt.zero_grad()

        # On real data
        x_b.requires_grad_()
        d_real_dict = self.dis_b(x_b)
        d_real = d_real_dict[2]
        dloss_real = self.compute_loss(d_real, 1)
        reg = 0.
        # Both grad penal and vgan!
        dloss_real.backward(retain_graph=True)
        # hard coded 10 weight for grad penal.
        reg += 10. * compute_grad2(d_real, x_b).mean()
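        # compute_grad2 is the squared gradient norm of the real logits w.r.t.
        # the real inputs, i.e. an R1-style gradient penalty.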
        mu = d_real_dict[0]
        logstd = d_real_dict[1]
        kl_real = kl_loss(mu, logstd).mean()

        # On fake data
        with torch.no_grad():
            s_b = torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()
            c_a = self.gen_b.enc_content(x_a)
            x_ab = self.gen_b.decode(c_a, s_b)
        x_ab.requires_grad_()
        d_fake_dict = self.dis_b(x_ab)
        d_fake = d_fake_dict[2]
        dloss_fake = self.compute_loss(d_fake, 0)
        dloss_fake.backward(retain_graph=True)
        mu_fake = d_fake_dict[0]
        logstd_fake = d_fake_dict[1]
        kl_fake = kl_loss(mu_fake, logstd_fake).mean()
        avg_kl = 0.5 * (kl_real + kl_fake)
        reg += self.reg_param * avg_kl
        reg.backward()

        self.update_beta(avg_kl)
        self.dis_opt.step()

        self.loss_dis_total = (dloss_real + dloss_fake)
        return self.loss_dis_total.item()

    def compute_loss(self, d_out, target):
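        # target == 1 marks "real", target == 0 marks "fake"; the wgan branch
        # turns the 0/1 target into a -/+ sign on the mean critic output.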
        targets = d_out.new_full(size=d_out.size(), fill_value=target)

        if self.gan_type == 'standard':
            loss = F.binary_cross_entropy_with_logits(d_out, targets)
        elif self.gan_type == 'wgan':
            loss = (2 * target - 1) * d_out.mean()
        else:
            raise NotImplementedError

        return loss

    def update_beta(self, avg_kl):
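        # Dual ascent on the KL constraint: reg_param (beta) grows when the
        # average KL exceeds target_kl and shrinks otherwise.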
        with torch.no_grad():
            new_beta = self.reg_param - self.beta_step * \
                (self.target_kl - avg_kl)  # self.target_kl is the constraint I_c
            new_beta = max(new_beta, 0)
            # print('setting beta from %.2f to %.2f' % (self.reg_param, new_beta))
            self.reg_param = new_beta

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % iterations)
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % iterations)
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'b': self.dis_b.state_dict()}, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
Example #26
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.is_ganilla_gen = hyperparameters['gen']['ganilla_gen']
        if not self.is_ganilla_gen:
            self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
            self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        else:
            self.gen_a = AdaINGanilla(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a with ganilla architecture
            self.gen_b = AdaINGanilla(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b with ganilla architecture
            print(self.gen_a)
        if hyperparameters['dis']['dis_type'] == 'patch':
            if hyperparameters['dis']['use_patch_gan']:
                self.dis_a = PatchDis(hyperparameters['input_dim_a'], hyperparameters['dis'])
                self.dis_b = PatchDis(hyperparameters['input_dim_b'], hyperparameters['dis'])
            else:
                self.dis_a = MsImageDis(hyperparameters['input_dim_a'],
                                        hyperparameters['dis'])  # discriminator for domain a
                self.dis_b = MsImageDis(hyperparameters['input_dim_b'],
                                        hyperparameters['dis'])  # discriminator for domain b
            print(self.dis_a)
        else:
            self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
            self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            self.VggExtract = VggExtract(self.vgg)
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, s_a_fake = self.gen_a.encode(x_a)
        c_b, s_b_fake = self.gen_b.encode(x_b)
        if self.is_ganilla_gen:
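            # The GANILLA encoder appears to return a feature pyramid; only
            # the deepest map is used for decoding.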
            c_a = c_a[-1]
            c_b = c_b[-1]
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab, x_ba

    def gen_updateN(self, x_a, x_b, hyperparameters):
        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        _, s_a_prime = self.gen_a.encode(x_a)
        c, s_b_prime = self.gen_b.encode(x_b)
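        # "N" variant: a single content code (from the domain-B encoder)
        # drives both the within-domain and cross-domain decoders.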

        # decode (within domain)
        x_a_recon = self.gen_a.decode(c, s_a_prime)
        x_b_recon = self.gen_b.decode(c, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c, s_a)
        x_ab = self.gen_b.decode(c, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)

        # decode again (if needed)
        x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None

        if self.is_ganilla_gen:
            c = c[-1]
            c_b_recon = c_b_recon[-1]
            c_a_recon = c_a_recon[-1]
        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_c_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_c_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_s_a = self.compute_vgg_loss(self.vgg, x_ba, x_a, all=1) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_s_b = self.compute_vgg_loss(self.vgg, x_ab, x_b, all=1) if hyperparameters['vgg_w'] > 0 else 0

        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_c_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_c_b
                              # hyperparameters['vgg_w'] * self.loss_gen_vgg_s_a + \
                              # hyperparameters['vgg_w'] * self.loss_gen_vgg_s_b


        self.loss_gen_total.backward()
        self.gen_opt.step()

    def gen_update(self, x_a, x_b, hyperparameters):
        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)

        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)

        # decode again (if needed)
        x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None

        if self.is_ganilla_gen:
            c_a = c_a[-1]
            c_b = c_b[-1]
            c_b_recon = c_b_recon[-1]
            c_a_recon = c_a_recon[-1]
        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_c_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_c_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_s_a = self.compute_vgg_loss(self.vgg, x_ba, x_a, all=1) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_s_b = self.compute_vgg_loss(self.vgg, x_ab, x_b, all=1) if hyperparameters['vgg_w'] > 0 else 0

        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_c_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_c_b
                              # hyperparameters['vgg_w'] * self.loss_gen_vgg_s_a + \
                              # hyperparameters['vgg_w'] * self.loss_gen_vgg_s_b


        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target, all=0):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        # img_fea = vgg(img_vgg)
        # target_fea = vgg(target_vgg)
        img_fea_dict = self.VggExtract(img_vgg)
        target_fea_dict = self.VggExtract(target_vgg)
        loss = 0
        if all:
            # for feature in img_fea_dict:
            #     loss+= torch.mean((img_fea_dict[feature] - (target_fea_dict[feature])) ** 2)
            loss += torch.mean((img_fea_dict['relu4_3'] - (target_fea_dict['relu4_3'])) ** 2)
        else:
            loss += torch.mean(
                (self.instancenorm(img_fea_dict['relu4_3']) - self.instancenorm(target_fea_dict['relu4_3'])) ** 2)

        return loss

    def sample(self, x_a, x_b):
        self.eval()
        s_a1 = Variable(self.s_a)
        s_b1 = Variable(self.s_b)
        s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
            c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
            x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
            x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
            x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
            x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
            x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_updateN(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        # c, _ = self.gen_a.encode(x_a)
        c, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c, s_a)
        x_ab = self.gen_b.decode(c, s_b)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
Example #27
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters, opts):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']

        # Initiate the networks
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        self.loss = {}
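        # All generator and discriminator losses are collected in this dict
        # for logging.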

        # fix the noise used in sampling
        self.s_a = torch.randn(8, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(8, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(opts.output_base + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def print_network(self, model, name):
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        # Log once, after the full parameter count has been accumulated.
        logger.info(
            '{} - {} - Number of parameters: {}'.format(name, model, num_params))

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, s_a_fake = self.gen_a.encode(x_a)
        c_b, s_b_fake = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction loss
        self.loss['G/rec_x_A'] = self.recon_criterion(x_a_recon, x_a)
        self.loss['G/rec_x_B'] = self.recon_criterion(x_b_recon, x_b)
        self.loss['G/rec_s_A'] = self.recon_criterion(s_a_recon, s_a)
        self.loss['G/rec_s_B'] = self.recon_criterion(s_b_recon, s_b)
        self.loss['G/rec_c_A'] = self.recon_criterion(c_a_recon, c_a)
        self.loss['G/rec_c_B'] = self.recon_criterion(c_b_recon, c_b)
        self.loss['G/cycrec_x_A'] = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss['G/cycrec_x_B'] = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0

        # GAN loss
        self.loss['G/adv_A'] = self.dis_a.calc_gen_loss(x_ba)
        self.loss['G/adv_B'] = self.dis_b.calc_gen_loss(x_ab)

        # domain-invariant perceptual loss
        self.loss['G/vgg_A'] = self.compute_vgg_loss(self.vgg, x_ba.cuda(), x_b.cuda()) if hyperparameters['vgg_w'] > 0 else 0
        self.loss['G/vgg_B'] = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0

        # total loss
        self.loss['G/total'] = hyperparameters['gan_w'] * self.loss['G/adv_A'] + \
                              hyperparameters['gan_w'] * self.loss['G/adv_B'] + \
                              hyperparameters['recon_x_w'] * self.loss['G/rec_x_A'] + \
                              hyperparameters['recon_s_w'] * self.loss['G/rec_s_A'] + \
                              hyperparameters['recon_c_w'] * self.loss['G/rec_c_A'] + \
                              hyperparameters['recon_x_w'] * self.loss['G/rec_x_B'] + \
                              hyperparameters['recon_s_w'] * self.loss['G/rec_s_B'] + \
                              hyperparameters['recon_c_w'] * self.loss['G/rec_c_B'] + \
                              hyperparameters['recon_x_cyc_w'] * self.loss['G/cycrec_x_A'] + \
                              hyperparameters['recon_x_cyc_w'] * self.loss['G/cycrec_x_B'] + \
                              hyperparameters['vgg_w'] * self.loss['G/vgg_A'] + \
                              hyperparameters['vgg_w'] * self.loss['G/vgg_B']
        self.loss['G/total'].backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def sample(self, x_a, x_b):
        self.eval()
        s_a1 = Variable(self.s_a)
        s_b1 = Variable(self.s_b)
        s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
            c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
            x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
            x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
            x_ba2.append(self.gen_a.decode(c_b, s_a_fake))  # s_a_fake already has a batch dim
            x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
            x_ab2.append(self.gen_b.decode(c_a, s_b_fake))  # s_b_fake already has a batch dim

        outputs = {}
        outputs['A/real'] = x_a
        outputs['B/real'] = x_b

        outputs['A/rec'] = torch.cat(x_a_recon)
        outputs['B/rec'] = torch.cat(x_b_recon)

        outputs['A/B_random_style'] = torch.cat(x_ab1)
        outputs['A/B'] = torch.cat(x_ab2)
        outputs['B/A_random_style'] = torch.cat(x_ba1)
        outputs['B/A'] = torch.cat(x_ba2)

        self.train()

        return outputs

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # D loss
        self.loss['D/A'] = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss['D/B'] = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss['D/total'] = hyperparameters['gan_w'] * self.loss['D/A'] + hyperparameters['gan_w'] * self.loss['D/B']
        self.loss['D/total'].backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            old_lr = self.dis_opt.param_groups[0]['lr']

            self.dis_scheduler.step()

            new_lr = self.dis_opt.param_groups[0]['lr']
            if old_lr != new_lr:
                logger.info('Updated D learning rate: {}'.format(new_lr))
        if self.gen_scheduler is not None:
            old_lr = self.gen_opt.param_groups[0]['lr']
            self.gen_scheduler.step()
            new_lr = self.gen_opt.param_groups[0]['lr']
            if old_lr != new_lr:
                logger.info('Updated G learning rate: {}'.format(new_lr))

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        logger.info('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)

        logger.info('Saving snapshots to: {}'.format(snapshot_dir))
Example #28
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()   # super() invokes a method of the parent (super) class
        lr = hyperparameters['lr']
        # Initiate the networks; it is worth studying how the generators and discriminators are constructed
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        # https://blog.csdn.net/liuxiao214/article/details/81037416
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']
        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        # s_a and s_b are two different style codes
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()  # 16*8*1*1
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        # parameters of the two discriminators
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        # parameters of the two generators
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        # optimizers
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        # learning-rate schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        # on apply: e.g. apply(lambda x, y: x + y, (1), {'y': 2})   https://zhuanlan.zhihu.com/p/42756654
        self.apply(weights_init(hyperparameters['init']))  # initialize this module
        self.dis_a.apply(weights_init('gaussian'))   # initialize dis_a, a module instance
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        s_a = Variable(self.s_a)     # here the self.s_a is random style
        s_b = Variable(self.s_b)     # here the self.s_b is random style
        # the two auto-encoders
        c_a, s_a_fake = self.gen_a.encode(x_a)  # c_a, s_a_fake is the content and style of input x_a
        c_b, s_b_fake = self.gen_b.encode(x_b)
        # x_ba denotes the translation img_b -> img_a
        x_ba = self.gen_a.decode(c_b, s_a)  # combine(c_b, s_a) to generate the x_ba
        x_ab = self.gen_b.decode(c_a, s_b)  # combine(c_a, s_b) to generate the x_ab
        # switch back to training mode
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        # This method only performs the parameter updates; the real model logic lives in AdaINGen and MsImageDis
        # self.gen_opt is the generators' optimizer
        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)

        # print('content shape:', c_b.shape)
        # print('style shape:', s_b_prime.shape)
        # decode (within domain); this part is necessary to guarantee the decoders reconstruct faithfully
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)   # essentially this still carries c_a; we want c_a_recon to be as close to c_a as possible
        # decode again (if needed)
        x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss: why use VGG? it acts as a content-invariance constraint
        # an L2 loss on instance-normalized feature maps; this is the domain-invariant part worth understanding
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
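        # Domain-invariant perceptual loss: the VGG features are compared only
        # after instance normalization, which removes per-channel mean/variance
        # (domain-specific appearance), so the L2 penalty acts on structure.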
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def sample(self, x_a, x_b):
        # switch the model to eval mode
        self.eval()
        # fixed (s_a1/s_b1) and freshly sampled (s_a2/s_b2) style codes
        s_a1 = Variable(self.s_a)
        s_b1 = Variable(self.s_b)
        s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        # in practice x_a.size(0) == display_size, i.e. 16
        for i in range(x_a.size(0)):
            # https://blog.csdn.net/xiexu911/article/details/80820028
            # unsqueeze expands the tensor with a new dimension at the specified position
            c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
            c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
            x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
            x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
            x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
            x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
            x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))

        # torch.cat joins the per-image results along the batch dimension, so each output holds the whole display batch for visualization
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_update(self, x_a, x_b, hyperparameters):
        # self.dis_opt is the discriminators' optimizer
        self.dis_opt.zero_grad()
        # s_a: a random style code for domain a
        # s_b: a random style code for domain b
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        # content encode: encoding only extracts each image's content and style codes, with no translation yet; each generator has one encoder and one decoder
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        # cross-domain decode: mix one domain's content with the other domain's style; the key is understanding how decode blends content and style
        x_ba = self.gen_a.decode(c_b, s_a)   # x_ba combines a random domain-a style with the original domain-b content
        x_ab = self.gen_b.decode(c_a, s_b)
        # D loss  鉴别器loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)  # look at x_ba and x_a independently: the discriminator's job
        # is to tell real from fake, and self.loss_dis_a is the sum of the separate losses on x_ba and x_a; the two images are
        # never compared against each other. Interestingly, the discriminator loss is the sum of two independent binary
        # classification losses; note that the two images fed to the discriminator are expected to share the same style,
        # which is what distinguishes this from the reconstruction loss.
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)  # the analogous pairing: fake x_ab against real x_b
        # hyperparameters['gan_w'] is the weight of the adversarial loss
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        # https://blog.csdn.net/jacke121/article/details/82995740  an in-depth look at backward() and step()
        self.loss_dis_total.backward()  # compute gradients
        self.dis_opt.step()         # update parameters

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step() # update the learning rate
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        # always resume from the most recent checkpoint
        last_model_name = get_model_list(checkpoint_dir, "gen")
        print('resume model: ', last_model_name)
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations, gpuids):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        if len(gpuids) > 1:
            torch.save({'a': self.gen_a.module.state_dict(), 'b': self.gen_b.module.state_dict()}, gen_name)
            torch.save({'a': self.dis_a.module.state_dict(), 'b': self.dis_b.module.state_dict()}, dis_name)
            # optimizers are plain torch.optim objects, never wrapped in DataParallel, so no .module here
            torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
        else:
            torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
            torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
            torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
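
A minimal sketch of how a trainer like the one above is typically driven. The loader names, `snapshot_dir`, and the `snapshot_save_iter` key are illustrative assumptions, not part of the example itself:

# Hypothetical driver loop for the MUNIT_Trainer above; everything outside the
# trainer's own API (loaders, snapshot_dir, snapshot_save_iter) is assumed.
trainer = MUNIT_Trainer(hyperparameters)
trainer.cuda()
iterations = 0
for x_a, x_b in zip(train_loader_a, train_loader_b):
    x_a, x_b = x_a.cuda(), x_b.cuda()
    trainer.dis_update(x_a, x_b, hyperparameters)  # discriminator step
    trainer.gen_update(x_a, x_b, hyperparameters)  # generator step
    trainer.update_learning_rate()                 # step both schedulers
    iterations += 1
    if iterations % hyperparameters['snapshot_save_iter'] == 0:
        trainer.save(snapshot_dir, iterations, gpuids=[0])
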
Example #29
class aclgan_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(aclgan_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_AB = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain A
        self.gen_BA = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain B
        self.dis_A = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain A
        self.dis_B = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain B
        self.dis_2 = MsImageDis(hyperparameters['input_dim_b'],
                                hyperparameters['dis'])  # discriminator 2
        #        self.dis_2B = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator 2 for domain B
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.z_1 = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.z_2 = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.z_3 = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_A.parameters()) + list(
            self.dis_B.parameters()) + list(self.dis_2.parameters())
        gen_params = list(self.gen_AB.parameters()) + list(
            self.gen_BA.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        self.alpha = hyperparameters['alpha']
        self.focus_lam = hyperparameters['focus_loss']

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_A.apply(weights_init('gaussian'))
        self.dis_B.apply(weights_init('gaussian'))
        self.dis_2.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        z_1 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        z_2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        z_3 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_1, _ = self.gen_AB.encode(x_a)
        c_2, s_2 = self.gen_BA.encode(x_a)
        c_4, s_4 = self.gen_AB.encode(x_b)
        # decode
        self.x_B_fake = self.gen_AB.decode(c_1, z_1)
        self.x_A_fake = self.gen_BA.decode(c_2, z_2)
        # recon
        self.x_A_recon = self.gen_BA.decode(c_2, s_2)
        self.x_B_recon = self.gen_AB.decode(c_4, s_4)
        #encode 2
        c_3, _ = self.gen_BA.encode(self.x_B_fake)
        self.x_A2_fake = self.gen_BA.decode(c_3, z_3)

        self.X_A_A1_pair = torch.cat((x_a, self.x_A_fake), -3)
        self.X_A_A2_pair = torch.cat((x_a, self.x_A2_fake), -3)

    def focus_translation(self, x_fg, x_bg, x_focus):
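        # x_focus is a single-channel mask produced by the decoder in [-1, 1];
        # map it to [0, 1], broadcast it to 3 channels, then alpha-blend the
        # translated foreground x_fg with the untouched background x_bg.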
        x_map = (x_focus + 1) / 2
        x_map = x_map.repeat(1, 3, 1, 1)
        return torch.mul(x_fg, x_map) + torch.mul(x_bg, 1 - x_map)

    def gen_update(self, x_a, x_b, hyperparameters):
        self.gen_opt.zero_grad()

        focus_delta = hyperparameters['focus_delta']
        focus_lambda = hyperparameters['focus_loss']
        focus_lower = hyperparameters['focus_lower']
        focus_upper = hyperparameters['focus_upper']
        focus_epsilon = hyperparameters['focus_epsilon']
        #forward
        z_1 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        z_2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        z_3 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_1, _ = self.gen_AB.encode(x_a)
        c_2, s_2 = self.gen_BA.encode(x_a)
        c_4, s_4 = self.gen_AB.encode(x_b)
        # decode
        if focus_lambda > 0:
            x_B_fake, x_B_focus = self.gen_AB.decode(c_1, z_1).split(3, 1)
            x_A_fake, x_A_focus = self.gen_BA.decode(c_2,
                                                     self.alpha * z_2).split(
                                                         3, 1)
            x_B_fake = self.focus_translation(x_B_fake, x_a, x_B_focus)
            x_A_fake = self.focus_translation(x_A_fake, x_a, x_A_focus)
            # recon
            x_A_recon, x_A_recon_focus = self.gen_BA.decode(c_2,
                                                            s_2).split(3, 1)
            x_B_recon, x_B_recon_focus = self.gen_AB.decode(c_4,
                                                            s_4).split(3, 1)
#            x_A_recon = self.focus_translation(x_A_recon, x_a, x_A_recon_focus)
#            x_B_recon = self.focus_translation(x_B_recon, x_b, x_B_recon_focus)
        else:
            x_B_fake = self.gen_AB.decode(c_1, z_1)
            x_A_fake = self.gen_BA.decode(c_2, self.alpha * z_2)
            # recon
            x_A_recon = self.gen_BA.decode(c_2, s_2)
            x_B_recon = self.gen_AB.decode(c_4, s_4)

        #encode 2
        c_3, _ = self.gen_BA.encode(x_B_fake)
        if focus_lambda > 0:
            x_A2_fake, x_A2_focus = self.gen_BA.decode(c_3, z_3).split(3, 1)
            x_A2_fake = self.focus_translation(x_A2_fake, x_B_fake, x_A2_focus)
        else:
            x_A2_fake = self.gen_BA.decode(c_3, z_3)

        x_A_A1_pair = torch.cat((x_a, x_A_fake), -3)
        x_A_A2_pair = torch.cat((x_a, x_A2_fake), -3)

        # GAN loss
        self.loss_gen_adv_A = (self.dis_A.calc_gen_loss(x_A_fake) + \
                              self.dis_A.calc_gen_loss(x_A2_fake)) * 0.5
        self.loss_gen_adv_B = self.dis_B.calc_gen_loss(x_B_fake)
        self.loss_gen_adv_2 = self.dis_2.calc_gen_d2_loss(
            x_A_A1_pair, x_A_A2_pair)

        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_A + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_B + \
                              hyperparameters['gan_cw'] * self.loss_gen_adv_2
        if focus_lambda > 0:
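            # Focus-mask regularizers: the *_size terms penalize masks whose
            # summed activation leaves [focus_lower, focus_upper]; the *_digit
            # terms penalize values near 0.5 (1 / |x - 0.5| peaks there),
            # pushing every mask pixel toward a binary 0/1 decision.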
            x_B_focus = (x_B_focus + 1) / 2
            x_A_focus = (x_A_focus + 1) / 2
            x_A2_focus = (x_A2_focus + 1) / 2
            self.loss_gen_focus_B_size = (F.relu(torch.sum(x_B_focus - focus_upper), inplace=True) ** 2) * focus_delta + \
                (F.relu(torch.sum(focus_lower - x_B_focus), inplace=True) ** 2) * focus_delta
            self.loss_gen_focus_B_digit = torch.sum(
                1 / (torch.abs(x_B_focus - 0.5) + focus_epsilon))
            self.loss_gen_focus_A_size = (F.relu(torch.sum(x_A_focus - focus_upper), inplace=True) ** 2) * focus_delta + \
                (F.relu(torch.sum(focus_lower - x_A_focus), inplace=True) ** 2) * focus_delta
            self.loss_gen_focus_A_digit = torch.sum(
                1 / (torch.abs(x_A_focus - 0.5) + focus_epsilon))
            #            self.loss_gen_focus_A = torch.sum(1 / (torch.abs(x_A_focus - 0.5) + focus_epsilon))
            self.loss_gen_focus_A2_size = (F.relu(torch.sum(x_A2_focus - focus_upper), inplace=True) ** 2) * focus_delta + \
                (F.relu(torch.sum(focus_lower - x_A2_focus), inplace=True) ** 2) * focus_delta
            self.loss_gen_focus_A2_digit = torch.sum(
                1 / (torch.abs(x_A2_focus - 0.5) + focus_epsilon))
            self.loss_gen_total += focus_lambda * (self.loss_gen_focus_B_size + self.loss_gen_focus_B_digit + \
                            self.loss_gen_focus_A_size + self.loss_gen_focus_A_digit +\
                            self.loss_gen_focus_A2_size + self.loss_gen_focus_A2_digit)/ x_a.size(2) / x_a.size(3) / x_a.size(0) / 3
        self.loss_idt_A = self.recon_criterion(x_A_recon, x_a)
        self.loss_idt_B = self.recon_criterion(x_B_recon, x_b)
        self.loss_gen_total += hyperparameters['recon_x_w'] * self.loss_idt_A + \
                              hyperparameters['recon_x_w'] * self.loss_idt_B

        #        print(self.loss_gen_focus_B, self.loss_gen_total)
        #        print(self.loss_idt_A)
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def sample(self, x_a, x_b):
        self.eval()
        z_1 = Variable(self.z_1)
        z_2 = Variable(self.z_2)
        z_3 = Variable(self.z_3)
        x_A, x_B, x_A_fake, x_B_fake, x_A2_fake = [], [], [], [], []
        if self.focus_lam > 0:
            mask_A, mask_B, mask_A2, mask_recon = [], [], [], []
            x_A_recon = []
        else:
            x_A_recon, x_B_recon = [], []
        for i in range(x_a.size(0)):
            x_A.append(x_a[i].unsqueeze(0))
            x_B.append(x_b[i].unsqueeze(0))
            if self.focus_lam > 0:
                c_1, s_1 = self.gen_BA.encode(x_a[i].unsqueeze(0))
                img, mask = self.gen_BA.decode(c_1, z_1[i].unsqueeze(0)).split(
                    3, 1)
                x_A_fake.append(
                    self.focus_translation(img, x_a[i].unsqueeze(0), mask))
                mask_A.append(mask)

                img, mask = self.gen_BA.decode(c_1, s_1).split(3, 1)
                #                x_A_recon.append(self.focus_translation(img, x_a[i].unsqueeze(0), mask))
                x_A_recon.append(img)
                mask_recon.append(mask)

                c_2, _ = self.gen_AB.encode(x_a[i].unsqueeze(0))
                x_b_img, mask = self.gen_AB.decode(c_2,
                                                   z_2[i].unsqueeze(0)).split(
                                                       3, 1)
                x_b_img = self.focus_translation(x_b_img, x_a[i].unsqueeze(0),
                                                 mask)
                x_B_fake.append(x_b_img)
                mask_B.append(mask)

                c_3, _ = self.gen_BA.encode(x_b_img)
                img, mask = self.gen_BA.decode(c_3, z_3[i].unsqueeze(0)).split(
                    3, 1)
                x_A2_fake.append(self.focus_translation(img, x_b_img, mask))
                mask_A2.append(mask)

            else:
                c_1, s_1 = self.gen_BA.encode(x_a[i].unsqueeze(0))
                x_A_fake.append(self.gen_BA.decode(c_1, z_1[i].unsqueeze(0)))
                x_A_recon.append(self.gen_BA.decode(c_1, s_1))

                c_2, _ = self.gen_AB.encode(x_a[i].unsqueeze(0))
                x_B1 = self.gen_AB.decode(c_2, z_2[i].unsqueeze(0))
                x_B_fake.append(x_B1)

                c_3, _ = self.gen_BA.encode(x_B1)
                x_A2_fake.append(self.gen_BA.decode(c_3, z_3[i].unsqueeze(0)))

                c_4, s_4 = self.gen_AB.encode(x_b[i].unsqueeze(0))
                x_B_recon.append(self.gen_AB.decode(c_4, s_4))

        if self.focus_lam > 0:
            x_A, x_B = torch.cat(x_A), torch.cat(x_B)
            x_A_fake, x_B_fake = torch.cat(x_A_fake), torch.cat(x_B_fake)
            mask_A, x_A2_fake = torch.cat(mask_A), torch.cat(x_A2_fake)
            mask_B, mask_recon = torch.cat(mask_B), torch.cat(mask_recon)
            mask_A2, x_A_recon = torch.cat(mask_A2), torch.cat(x_A_recon)
            self.train()
            return x_A, x_A_fake, mask_A, x_B_fake, mask_B, x_A2_fake, mask_A2, x_A_recon, mask_recon

        else:
            x_A, x_B = torch.cat(x_A), torch.cat(x_B)
            x_A_fake, x_B_fake = torch.cat(x_A_fake), torch.cat(x_B_fake)
            x_A_recon, x_A2_fake = torch.cat(x_A_recon), torch.cat(x_A2_fake)
            x_B_recon = torch.cat(x_B_recon)
            self.train()
            return x_A, x_A_fake, x_B_fake, x_A2_fake, x_A_recon, x_B, x_B_recon

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()

        focus_delta = hyperparameters['focus_delta']
        focus_lambda = hyperparameters['focus_loss']
        focus_epsilon = hyperparameters['focus_epsilon']
        #forward
        z_1 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        z_2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        z_3 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_1, _ = self.gen_AB.encode(x_a)
        c_2, s_2 = self.gen_BA.encode(x_a)
        c_4, s_4 = self.gen_AB.encode(x_b)
        # decode
        if focus_lambda > 0:
            x_B_fake, x_B_focus = self.gen_AB.decode(c_1, z_1).split(3, 1)
            x_A_fake, x_A_focus = self.gen_BA.decode(c_2,
                                                     self.alpha * z_2).split(
                                                         3, 1)
            x_B_fake = self.focus_translation(x_B_fake, x_a, x_B_focus)
            x_A_fake = self.focus_translation(x_A_fake, x_a, x_A_focus)
        else:
            x_B_fake = self.gen_AB.decode(c_1, z_1)
            x_A_fake = self.gen_BA.decode(c_2, self.alpha * z_2)

        #encode 2
        c_3, _ = self.gen_BA.encode(x_B_fake)
        if focus_lambda > 0:
            x_A2_fake, x_A2_focus = self.gen_BA.decode(c_3, z_3).split(3, 1)
            x_A2_fake = self.focus_translation(x_A2_fake, x_B_fake, x_A2_focus)
        else:
            x_A2_fake = self.gen_BA.decode(c_3, z_3)

        x_A_A1_pair = torch.cat((x_a, x_A_fake), -3)
        x_A_A2_pair = torch.cat((x_a, x_A2_fake), -3)

        # D loss
        self.loss_dis_A = (self.dis_A.calc_dis_loss(x_A_fake, x_a) + \
                           self.dis_A.calc_dis_loss(x_A2_fake, x_a)) * 0.5
        self.loss_dis_B = self.dis_B.calc_dis_loss(x_B_fake, x_b)
        self.loss_dis_2 = self.dis_2.calc_dis_loss(x_A_A1_pair, x_A_A2_pair)

        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_A + \
                                hyperparameters['gan_w'] * self.loss_dis_B + \
                                hyperparameters['gan_cw'] * self.loss_dis_2

        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_AB.load_state_dict(state_dict['AB'])
        self.gen_BA.load_state_dict(state_dict['BA'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_A.load_state_dict(state_dict['A'])
        self.dis_B.load_state_dict(state_dict['B'])
        self.dis_2.load_state_dict(state_dict['2'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save(
            {
                'AB': self.gen_AB.state_dict(),
                'BA': self.gen_BA.state_dict()
            }, gen_name)
        torch.save(
            {
                'A': self.dis_A.state_dict(),
                'B': self.dis_B.state_dict(),
                '2': self.dis_2.state_dict()
            }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
Example #30
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        # self.gen_aT = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, s_a_fake = self.gen_a.encode(x_a)
        c_aT, s_a_fake_T = self.gen_a.encodeT(x_a)
        
        # self.gen_a.ptl()
        c_b, s_b_fake = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a_fake)  # changed: decode with the encoded style s_a_fake instead of the random s_a
        x_ab = self.gen_b.decode(c_a, s_b_fake)
        x_aT = self.gen_a.decodeT(c_a, s_a_fake_T)  # T-branch reconstruction (not returned)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        self.gen_opt.zero_grad()
        # s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda(cun))#torch.randn(*sizes)
        # s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda(cun))
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        c_aT, s_aT_prime = self.gen_a.encodeT(x_a)

        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_aT_recon = self.gen_a.decodeT(c_a, s_aT_prime)

        # print("style code size:",s_a_prime.size())
        
        # print("recon img size:",x_a_recon.size())
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a_prime)
        x_ab = self.gen_b.decode(c_a, s_b_prime)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # print("content code size:",c_a_recon.size())
        # decode again (if needed)
        x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction loss
        self.loss_gen_styleT = self.recon_criterion(x_a, x_aT_recon)
        self.loss_gen_content = self.recon_criterion(c_a, c_aT)
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a_prime)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b_prime)
        # self.loss_gen_geo = self.recon_criterion(s_a_prime,s_b_prime)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b
        self.loss_gen_total += self.loss_gen_content * 5  # content-consistency term, weight hard-coded to 5
        self.loss_gen_total += self.loss_gen_styleT * 5   # T-branch reconstruction term, weight hard-coded to 5
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def sample(self, x_a, x_b):
        self.eval()
        # s_a1 = Variable(self.s_a)
        # s_b1 = Variable(self.s_b)
        # s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda(cun))
        # s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda(cun))
        x_a_recon, x_b_recon, x_ba, x_ab = [], [], [], []
        # for i in range(x_a.size(0)):
        c_a, s_a_fake = self.gen_a.encode(x_a)
        c_b, s_b_fake = self.gen_b.encode(x_b)
        x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
        x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
        x_ba.append(self.gen_a.decode(c_b, s_a_fake))
        # x_ba2.append(self.gen_a.decode(c_b, s_a_fake))
        x_ab.append(self.gen_b.decode(c_a, s_b_fake))
        # x_ab2.append(self.gen_b.decode(c_a, s_b_fake))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba = torch.cat(x_ba)
        x_ab = torch.cat(x_ab)
 
        # x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)  # decode with the random style codes, not the discarded `_` from encode
        x_ab = self.gen_b.decode(c_a, s_b)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)


# class UNIT_Trainer(nn.Module):
#     def __init__(self, hyperparameters):
#         super(UNIT_Trainer, self).__init__()
#         lr = hyperparameters['lr']
#         # Initiate the networks
#         self.gen_a = VAEGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
#         self.gen_b = VAEGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
#         self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
#         self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
#         self.instancenorm = nn.InstanceNorm2d(512, affine=False)

#         # Setup the optimizers
#         beta1 = hyperparameters['beta1']
#         beta2 = hyperparameters['beta2']
#         dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
#         gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
#         self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
#                                         lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
#         self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
#                                         lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
#         self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
#         self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

#         # Network weight initialization
#         self.apply(weights_init(hyperparameters['init']))
#         self.dis_a.apply(weights_init('gaussian'))
#         self.dis_b.apply(weights_init('gaussian'))

#         # Load VGG model if needed
#         if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
#             self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
#             self.vgg.eval()
#             for param in self.vgg.parameters():
#                 param.requires_grad = False

#     def recon_criterion(self, input, target):
#         return torch.mean(torch.abs(input - target))

#     def forward(self, x_a, x_b):
#         self.eval()
#         h_a, _ = self.gen_a.encode(x_a)
#         h_b, _ = self.gen_b.encode(x_b)
#         x_ba = self.gen_a.decode(h_b)
#         x_ab = self.gen_b.decode(h_a)
#         self.train()
#         return x_ab, x_ba

#     def __compute_kl(self, mu):
#         # def _compute_kl(self, mu, sd):
#         # mu_2 = torch.pow(mu, 2)
#         # sd_2 = torch.pow(sd, 2)
#         # encoding_loss = (mu_2 + sd_2 - torch.log(sd_2)).sum() / mu_2.size(0)
#         # return encoding_loss
#         mu_2 = torch.pow(mu, 2)
#         encoding_loss = torch.mean(mu_2)
#         return encoding_loss

#     def gen_update(self, x_a, x_b, hyperparameters):
#         self.gen_opt.zero_grad()
#         # encode
#         h_a, n_a = self.gen_a.encode(x_a)
#         h_b, n_b = self.gen_b.encode(x_b)
#         # decode (within domain)
#         x_a_recon = self.gen_a.decode(h_a + n_a)
#         x_b_recon = self.gen_b.decode(h_b + n_b)
#         # decode (cross domain)
#         x_ba = self.gen_a.decode(h_b + n_b)
#         x_ab = self.gen_b.decode(h_a + n_a)
#         # encode again
#         h_b_recon, n_b_recon = self.gen_a.encode(x_ba)
#         h_a_recon, n_a_recon = self.gen_b.encode(x_ab)
#         # decode again (if needed)
#         x_aba = self.gen_a.decode(h_a_recon + n_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
#         x_bab = self.gen_b.decode(h_b_recon + n_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None

#         # reconstruction loss
#         self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
#         self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
#         self.loss_gen_recon_kl_a = self.__compute_kl(h_a)
#         self.loss_gen_recon_kl_b = self.__compute_kl(h_b)
#         self.loss_gen_cyc_x_a = self.recon_criterion(x_aba, x_a)
#         self.loss_gen_cyc_x_b = self.recon_criterion(x_bab, x_b)
#         self.loss_gen_recon_kl_cyc_aba = self.__compute_kl(h_a_recon)
#         self.loss_gen_recon_kl_cyc_bab = self.__compute_kl(h_b_recon)
#         # GAN loss
#         self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
#         self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
#         # domain-invariant perceptual loss
#         self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
#         self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
#         # total loss
#         # self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \                           
#         #                       hyperparameters['gan_w'] * self.loss_gen_adv_b + \
#         #                       hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
#         #                       hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \
#         #                       hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
#         #                       hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \
#         #                       hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_a + \
#         #                       hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \
#         #                       hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_b + \
#         #                       hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab + \
#         #                       hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
#         #                       hyperparameters['vgg_w'] * self.loss_gen_vgg_b
#         # self.loss_gen_total.backward()
#         # self.gen_opt.step()
#         self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \                           
#                               hyperparameters['gan_w'] * self.loss_gen_adv_b + \
#                               hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
#                               hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \
#                               hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
#                               hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \
#                               hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_a + \
#                               hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \
#                               hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_b + \
#                               hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab + \
#                               hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
#                               hyperparameters['vgg_w'] * self.loss_gen_vgg_b
#         self.loss_gen_total.backward()
#         self.gen_opt.step()

#     def compute_vgg_loss(self, vgg, img, target):
#         img_vgg = vgg_preprocess(img)
#         target_vgg = vgg_preprocess(target)
#         img_fea = vgg(img_vgg)
#         target_fea = vgg(target_vgg)
#         return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

#     def sample(self, x_a, x_b):
#         self.eval()
#         x_a_recon, x_b_recon, x_ba, x_ab = [], [], [], []
#         for i in range(x_a.size(0)):
#             h_a, _ = self.gen_a.encode(x_a[i].unsqueeze(0))
#             h_b, _ = self.gen_b.encode(x_b[i].unsqueeze(0))
#             x_a_recon.append(self.gen_a.decode(h_a))
#             x_b_recon.append(self.gen_b.decode(h_b))
#             x_ba.append(self.gen_a.decode(h_b))
#             x_ab.append(self.gen_b.decode(h_a))
#         x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
#         x_ba = torch.cat(x_ba)
#         x_ab = torch.cat(x_ab)
#         self.train()
#         return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba

#     def dis_update(self, x_a, x_b, hyperparameters):
#         self.dis_opt.zero_grad()
#         # encode
#         h_a, n_a = self.gen_a.encode(x_a)
#         h_b, n_b = self.gen_b.encode(x_b)
#         # decode (cross domain)
#         x_ba = self.gen_a.decode(h_b + n_b)
#         x_ab = self.gen_b.decode(h_a + n_a)
#         # D loss
#         self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
#         self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
#         self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
#         self.loss_dis_total.backward()
#         self.dis_opt.step()

#     def update_learning_rate(self):
#         if self.dis_scheduler is not None:
#             self.dis_scheduler.step()
#         if self.gen_scheduler is not None:
#             self.gen_scheduler.step()

#     def resume(self, checkpoint_dir, hyperparameters):
#         # Load generators
#         last_model_name = get_model_list(checkpoint_dir, "gen")
#         state_dict = torch.load(last_model_name)
#         self.gen_a.load_state_dict(state_dict['a'])
#         self.gen_b.load_state_dict(state_dict['b'])
#         iterations = int(last_model_name[-11:-3])
#         # Load discriminators
#         last_model_name = get_model_list(checkpoint_dir, "dis")
#         state_dict = torch.load(last_model_name)
#         self.dis_a.load_state_dict(state_dict['a'])
#         self.dis_b.load_state_dict(state_dict['b'])
#         # Load optimizers
#         state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
#         self.dis_opt.load_state_dict(state_dict['dis'])
#         self.gen_opt.load_state_dict(state_dict['gen'])
#         # Reinitialize schedulers
#         self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
#         self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
#         print('Resume from iteration %d' % iterations)
#         return iterations

#     def save(self, snapshot_dir, iterations):
#         # Save generators, discriminators, and optimizers
#         gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
#         dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
#         opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
#         torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
#         torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
#         torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
Example #31
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, s_a_fake = self.gen_a.encode(x_a)
        c_b, s_b_fake = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, m_A, x_b, m_B, hyperparameters):
        self.gen_opt.zero_grad()

        im_A = 1 - m_A
        im_B = 1 - m_B
        # encode
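        # Naming convention in this example: s_f* are foreground style codes
        # (encoded under the mask m_*), s_b* are background style codes
        # (encoded under the inverted mask 1 - m_*).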

        c_a, s_bA = self.gen_a.encode(x_a, im_A)
        c_b, s_fB = self.gen_b.encode(x_b, m_B)

        _, s_fA = self.gen_a.encode(x_a, m_A)
        _, s_bB = self.gen_b.encode(x_b, im_B)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_fA, m_B, s_bB)
        x_ab = self.gen_b.decode(c_a, s_fB, m_A, s_bA)

        # decode (within domain)
        x_aa = self.gen_a.decode(c_a, s_fA, m_A, s_bA)
        x_bb = self.gen_b.decode(c_b, s_fB, m_B, s_bB)

        # encode again
        c_ba, s_fBA = self.gen_a.encode(x_ba, m_B)
        c_ab, s_fAB = self.gen_a.encode(x_ab, m_A)

        _, s_bBA = self.gen_a.encode(x_ba, im_B)
        _, s_bAB = self.gen_a.encode(x_ab, im_A)

        # decode again (if needed)
        x_aba = self.gen_a.decode(
            c_ab, s_fBA, m_A,
            s_bAB) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(
            c_ba, s_fAB, m_B,
            s_bBA) if hyperparameters['recon_x_cyc_w'] > 0 else None

        self.loss_gen_recon_c_a = self.recon_criterion(c_ab, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_ba, c_b)

        self.loss_gen_recon_s_a = self.recon_criterion(s_bAB, s_bA)
        self.loss_gen_recon_s_b = self.recon_criterion(s_bBA, s_bB)
        self.loss_gen_recon_s_af = self.recon_criterion(s_fAB, s_fB)
        self.loss_gen_recon_s_bf = self.recon_criterion(s_fBA, s_fA)

        self.loss_gen_recon_x_a = self.recon_criterion(x_aa, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_bb, x_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            im_A * x_aba, im_A *
            x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            m_B * x_bab, m_B *
            x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(
            self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(
            self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_af + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_bf + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b

        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def sample(self, loader_a, loader_b, size):
        self.eval()
        im_a = torch.stack([loader_a.dataset[i][0]
                            for i in range(size)]).cuda()
        seg_a = torch.stack([loader_a.dataset[i][1]
                             for i in range(size)]).cuda()
        im_b = torch.stack([loader_b.dataset[i][0]
                            for i in range(size)]).cuda()
        seg_b = torch.stack([loader_b.dataset[i][1]
                             for i in range(size)]).cuda()
        x_a_recon, x_b_recon, x_ba1, x_bm, x_ab1, x_am = [], [], [], [], [], []
        for i in range(im_a.size(0)):
            mask_a = seg_a[i].unsqueeze(0)
            mask_b = seg_b[i].unsqueeze(0)
            x_a = im_a[i].unsqueeze(0)
            x_b = im_b[i].unsqueeze(0)

            masked_a = mask_a * x_a
            masked_b = mask_b * x_b

            # background style comes from the complement mask, foreground
            # style from the mask itself; the second pass also overwrites
            # the content codes
            c_a, s_bA = self.gen_a.encode(x_a, 1 - mask_a)  # background style of A
            c_b, s_fB = self.gen_b.encode(x_b, mask_b)      # foreground style of B

            c_a, s_fA = self.gen_a.encode(x_a, mask_a)      # foreground style of A
            c_b, s_bB = self.gen_b.encode(x_b, 1 - mask_b)  # background style of B
            # decode (cross domain)
            x_BA = self.gen_a.decode(c_b, s_fA, mask_b, s_bB)
            x_AB = self.gen_b.decode(c_a, s_fB, mask_a, s_bA)

            if i % 2 == 0:
                # even-indexed samples: keep the original background and
                # paste in only the translated foreground
                x_AB = (1 - mask_a) * x_a + mask_a * x_AB
                x_BA = (1 - mask_b) * x_b + mask_b * x_BA

            x_ba1.append(x_BA)
            x_ab1.append(x_AB)
            x_am.append(masked_a)
            x_bm.append(masked_b)

            # decode (within domain)
            x_A_recon = self.gen_a.decode(c_a, s_fA, mask_a, s_bA)
            x_B_recon = self.gen_b.decode(c_b, s_fB, mask_b, s_bB)
            x_a_recon.append(x_A_recon)
            x_b_recon.append(x_B_recon)

        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1 = torch.cat(x_ba1)
        x_ab1 = torch.cat(x_ab1)
        x_bm = torch.cat(x_bm)
        x_am = torch.cat(x_am)

        self.train()
        return im_a, x_a_recon, x_ab1, x_am, im_b, x_b_recon, x_ba1, x_bm
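
    # Minimal usage sketch for sample() (hypothetical helper, not part of the
    # original trainer): concatenate the returned batches and write a single
    # image grid with torchvision, assuming outputs live in [-1, 1].
    def write_sample_grid(self, loader_a, loader_b, size, path):
        import torchvision.utils as vutils  # assumed available
        with torch.no_grad():
            outputs = self.sample(loader_a, loader_b, size)
        grid = torch.cat([o.cpu() for o in outputs], dim=0)
        # normalize=True rescales the [-1, 1] tensors into a displayable range
        vutils.save_image(grid, path, nrow=size, normalize=True)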

    def dis_update(self, x_a, m_a, x_b, m_b, hyperparameters):
        self.dis_opt.zero_grad()
        # masks are used at full resolution; background = 1 - foreground
        up_im_A = 1 - m_a  # background mask, domain A
        up_m_B = m_b       # foreground mask, domain B
        up_m_A = m_a       # foreground mask, domain A
        up_im_B = 1 - m_b  # background mask, domain B

        # encode content plus background/foreground style codes
        c_a, s_bA = self.gen_a.encode(x_a, up_im_A)  # background style of A
        c_b, s_fB = self.gen_b.encode(x_b, up_m_B)   # foreground style of B

        _, s_fA = self.gen_a.encode(x_a, up_m_A)     # foreground style of A
        _, s_bB = self.gen_b.encode(x_b, up_im_B)    # background style of B
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_fA, m_b, s_bB)
        x_ab = self.gen_b.decode(c_a, s_fB, m_a, s_bA)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters[
            'gan_w'] * self.loss_dis_a + hyperparameters[
                'gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()
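
    # The multiscale discriminator losses above (calc_dis_loss / calc_gen_loss
    # on MsImageDis) are assumed to follow the original MUNIT LSGAN objective;
    # a single-scale sketch under that assumption (hypothetical helpers):
    @staticmethod
    def _lsgan_dis_loss_sketch(out_fake, out_real):
        # D pushes fake outputs toward 0 and real outputs toward 1
        return torch.mean(out_fake ** 2) + torch.mean((out_real - 1) ** 2)

    @staticmethod
    def _lsgan_gen_loss_sketch(out_fake):
        # G pushes fake outputs toward 1
        return torch.mean((out_fake - 1) ** 2)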

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])  # parse '%08d' from 'gen_%08d.pt'
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations
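
    # `get_scheduler` (defined elsewhere in this file) is assumed to follow
    # the original MUNIT helper: it returns None for a constant lr policy,
    # otherwise roughly
    #     torch.optim.lr_scheduler.StepLR(
    #         optimizer, step_size=hyperparameters['step_size'],
    #         gamma=hyperparameters['gamma'], last_epoch=iterations)
    # so resume() continues the decay schedule from the restored iteration.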

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
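
    # Usage sketch (hypothetical call sites): save() keys checkpoints on an
    # 8-digit iteration counter, which resume() parses back out of the file
    # name, so the two round-trip:
    #     trainer.save('outputs/checkpoints', iterations=9999)
    #     it = trainer.resume('outputs/checkpoints', hyperparameters)
    #     # it == 10000, and the lr schedulers resume from that iteration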