# Example 1
class Solver(object):
    def __init__(self, celebA_loader, config):
        """Copy configuration onto the solver, build the networks, and
        optionally resume from a pretrained checkpoint.

        Args:
            celebA_loader: iterable yielding (image, label) CelebA batches.
            config: argparse-style namespace; its attributes are mirrored
                onto `self` for convenient access throughout the class.
        """
        # Data loader
        self.celebA_loader = celebA_loader
        self.visualizer = Visualizer(port=config.port, web_dir=config.web_dir)
        # Model hyper-parameters
        self.z_dim = config.z_dim
        self.c_dim = config.c_dim
        self.image_size = config.image_size
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.d_train_repeat = config.d_train_repeat

        # Loss weights and optimizer hyper-parameters
        self.lambda_cls = config.lambda_cls
        self.lambda_gp = config.lambda_gp
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2

        # Training settings
        self.dataset = config.dataset
        self.num_epochs = config.num_epochs
        self.num_epochs_decay = config.num_epochs_decay
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.batch_size = config.batch_size
        self.use_tensorboard = config.use_tensorboard
        self.pretrained_model = config.pretrained_model

        # Test settings
        self.test_model = config.test_model
        self.config = config

        # Path
        self.log_path = config.log_path
        self.sample_path = config.sample_path
        self.model_save_path = config.model_save_path
        self.result_path = config.result_path

        # Step size
        self.log_step = config.log_step
        # Visualizer refresh follows the logging cadence.
        self.visual_step = self.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step

        # Build networks (and the tensorboard logger when requested)
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

        # Start with trained model
        if self.pretrained_model:
            self.load_pretrained_model()

    def build_model(self):
        """Instantiate the generator/discriminator pair and their Adam
        optimizers, report their sizes, and move them to GPU if available."""
        self.G = Generator(self.z_dim, self.c_dim)
        self.D = Discriminator_CNN(self.c_dim)

        betas = [self.beta1, self.beta2]
        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr,
                                            betas)
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr,
                                            betas)

        # Report structure and parameter counts for both networks.
        for net, tag in ((self.G, 'G'), (self.D, 'D')):
            self.print_network(net, tag)

        if torch.cuda.is_available():
            self.G.cuda()
            self.D.cuda()
    def print_network(self, model, name):
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(name)
        print(model)
        print("The number of parameters: {}".format(num_params))

    def load_pretrained_model(self):
        """Restore G and D weights saved at step `self.pretrained_model`."""
        for net, tag in ((self.G, 'G'), (self.D, 'D')):
            ckpt = os.path.join(
                self.model_save_path,
                '{}_{}.pth'.format(self.pretrained_model, tag))
            net.load_state_dict(torch.load(ckpt))
        print('loaded trained models (step: {})..!'.format(
            self.pretrained_model))

    def build_tensorboard(self):
        """Create the tensorboard-style logger writing to `self.log_path`."""
        from logger import Logger
        self.logger = Logger(self.log_path)

    def update_lr(self, g_lr, d_lr):
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    def reset_grad(self):
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def to_var(self, x, volatile=False):
        """Wrap a tensor in a (legacy) autograd Variable, on GPU if available."""
        device_x = x.cuda() if torch.cuda.is_available() else x
        return Variable(device_x, volatile=volatile)

    def denorm(self, x):
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def threshold(self, x):
        x = x.clone()
        x[x >= 0.5] = 1
        x[x < 0.5] = 0
        return x

    def compute_accuracy(self, x, y, dataset):
        if dataset == 'CelebA':
            x = F.sigmoid(x)
            predicted = self.threshold(x)
            correct = (predicted == y).float()
            accuracy = torch.mean(correct, dim=0) * 100.0
        else:
            _, predicted = torch.max(x, dim=1)
            correct = (predicted == y).float()
            accuracy = torch.mean(correct) * 100.0
        return accuracy

    def one_hot(self, labels, dim):
        """Convert label indices to one-hot vector"""
        batch_size = labels.size(0)
        out = torch.zeros(batch_size, dim)
        out[np.arange(batch_size), labels.long()] = 1
        return out

    def make_celeb_labels_test(self):
        """Build the 12 fixed CelebA target-label Variables used for sampling.

        Labels are 5-dim: a one-hot hair color (black/blond/brown) followed
        by two binary attributes (presumably gender/aged -- TODO confirm
        against the dataset attribute order).  Every (gender, aged) pair in
        (1,1), (1,0), (0,1), (0,0) is combined with each hair color, in the
        original hand-written order.

        Returns:
            list of twelve 1x5 FloatTensor Variables (volatile, inference only).
        """
        # Fix: the original repeated 12 near-identical append blocks and
        # carried a dead local `y`; both are replaced by two small loops
        # that produce the exact same list in the same order.
        hair_colors = ([1, 0, 0], [0, 1, 0], [0, 0, 1])
        attr_pairs = ([1, 1], [1, 0], [0, 1], [0, 0])

        fixed_c_list = []
        for gender_aged in attr_pairs:
            for hair in hair_colors:
                label = torch.FloatTensor(hair + gender_aged).unsqueeze(0)
                fixed_c_list.append(self.to_var(label, volatile=True))
        return fixed_c_list

    def train(self):
        """Train StarGAN within a single dataset.

        Per iteration: train D with WGAN adversarial + auxiliary
        classification losses, apply the gradient penalty as a second D
        update, and update G every `d_train_repeat` iterations.  Visuals and
        logs are emitted every `log_step`; checkpoints every
        `model_save_step`; learning rates decay linearly over the final
        `num_epochs_decay` epochs.
        """

        # Set dataloader
        self.data_loader = self.celebA_loader

        # The number of iterations per epoch
        iters_per_epoch = len(self.data_loader)

        # Fixed latent vector and label for output samples
        fixed_size = 20
        fixed_z = torch.randn(fixed_size, self.z_dim)
        fixed_z = self.to_var(fixed_z, volatile=True)

        fixed_c_list = self.make_celeb_labels_test()

        # Pre-tile z/c so a single generator call can render every fixed
        # sample; the intermediate lists are dropped to free memory.
        fixed_z_repeat = fixed_z.repeat(len(fixed_c_list), 1)
        fixed_c_repeat_list = []
        for fixed_c in fixed_c_list:
            fixed_c_repeat_list.append(
                fixed_c.expand(fixed_size, fixed_c.size(1)))
        fixed_c_list = []
        fixed_c_repeat = torch.cat(fixed_c_repeat_list, dim=0)
        fixed_c_repeat_list = []
        # lr cache for decaying
        g_lr = self.g_lr
        d_lr = self.d_lr

        # Resume epoch counter from a '<epoch>_<iter>' checkpoint name.
        if self.pretrained_model:
            start = int(self.pretrained_model.split('_')[0]) - 1
        else:
            start = 0

        # Start training
        start_time = time.time()
        for e in range(start, self.num_epochs):
            epoch_iter = 0
            for i, (real_x, real_label) in enumerate(self.data_loader):
                epoch_iter = epoch_iter + 1
                if self.dataset == 'Fashion':
                    # BUGFIX: the original read the undefined name
                    # `real_label_i` (NameError on the Fashion path); the
                    # integer class labels come from the loader as
                    # `real_label`.
                    real_c_i = real_label.clone()
                real_c = real_label.clone()

                z = torch.randn(real_x.size(0), self.z_dim)
                z = self.to_var(z)
                # Convert tensors to variables
                real_x = self.to_var(real_x)
                real_c = self.to_var(real_c)  # input for the generator
                if self.dataset == 'Fashion':
                    real_c_i = self.to_var(real_c_i)

                # ================== Train D ================== #

                # Compute loss with real images
                out_src, out_cls = self.D(real_x)
                d_loss_real = -torch.mean(out_src)
                if self.dataset == 'CelebA':
                    # Multi-label attributes: summed BCE normalized by batch.
                    d_loss_cls = F.binary_cross_entropy_with_logits(
                        out_cls, real_c, size_average=False) / real_x.size(0)
                elif self.dataset == 'Fashion':
                    d_loss_cls = F.cross_entropy(out_cls, real_c_i)

                # Compute loss with fake images; rewrapping .data detaches
                # the discriminator pass from the generator's graph.
                fake_x = self.G(z, real_c)
                fake_x = Variable(fake_x.data)
                out_src, out_cls = self.D(fake_x)
                d_loss_fake = torch.mean(out_src)

                # Backward + Optimize
                d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # Compute gradient penalty on x_hat = a*real + (1-a)*fake
                alpha = torch.rand(real_x.size(0), 1, 1,
                                   1).cuda().expand_as(real_x)
                interpolated = Variable(alpha * real_x.data +
                                        (1 - alpha) * fake_x.data,
                                        requires_grad=True)
                out, out_cls = self.D(interpolated)

                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(
                                               out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                # Second D step: penalty only.
                # NOTE(review): `d_loss` is rebound here, so the `self.d_loss`
                # saved for visualization below holds only the GP term --
                # confirm this is intended.
                d_loss = self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # Logging
                loss = {}
                loss['D/loss_real'] = d_loss_real.data[0]
                loss['D/loss_fake'] = d_loss_fake.data[0]
                loss['D/loss_cls'] = d_loss_cls.data[0]
                loss['D/loss_gp'] = d_loss_gp.data[0]

                # ================== Train G ================== #
                if (i + 1) % self.d_train_repeat == 0:

                    # Original-to-target domain
                    fake_x = self.G(z, real_c)
                    # Compute losses
                    out_src, out_cls = self.D(fake_x)
                    g_loss_fake = -torch.mean(out_src)
                    if self.dataset == 'CelebA':
                        g_loss_cls = F.binary_cross_entropy_with_logits(
                            out_cls, real_c,
                            size_average=False) / fake_x.size(0)
                    elif self.dataset == 'Fashion':
                        g_loss_cls = F.cross_entropy(out_cls, real_c_i)
                    # Backward + Optimize
                    g_loss = g_loss_fake + self.lambda_cls * g_loss_cls
                    self.reset_grad()
                    g_loss.backward()
                    self.g_optimizer.step()

                    # Logging
                    loss['G/loss_fake'] = g_loss_fake.data[0]
                    loss['G/loss_cls'] = g_loss_cls.data[0]

                # Push current images/losses to the visualizer.
                # NOTE(review): if visual_step is not a multiple of
                # d_train_repeat, `g_loss` can be referenced here before the
                # first G update -- confirm the configured step sizes.
                if (i + 1) % self.visual_step == 0:
                    # save visuals
                    self.real_x = real_x
                    self.fake_x = fake_x

                    # save losses
                    self.d_real = -d_loss_real
                    self.d_fake = d_loss_fake
                    self.d_loss = d_loss
                    self.g_loss = g_loss
                    self.g_loss_fake = g_loss_fake
                    self.g_loss_cls = self.lambda_cls * g_loss_cls
                    self.d_loss_cls = self.lambda_cls * d_loss_cls
                    errors_D = self.get_current_errors('D')
                    errors_G = self.get_current_errors('G')
                    self.visualizer.display_current_results(
                        self.get_current_visuals(), e)
                    self.visualizer.plot_current_errors_D(
                        e,
                        float(epoch_iter) / float(iters_per_epoch), errors_D)
                    self.visualizer.plot_current_errors_G(
                        e,
                        float(epoch_iter) / float(iters_per_epoch), errors_G)
                # Print out log info
                if (i + 1) % self.log_step == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))

                    log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                        elapsed, e + 1, self.num_epochs, i + 1,
                        iters_per_epoch)

                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)

                    if self.use_tensorboard:
                        for tag, value in loss.items():
                            self.logger.scalar_summary(
                                tag, value, e * iters_per_epoch + i + 1)

                # Save model checkpoints
                if (i + 1) % self.model_save_step == 0:
                    torch.save(
                        self.G.state_dict(),
                        os.path.join(self.model_save_path,
                                     '{}_{}_G.pth'.format(e + 1, i + 1)))
                    torch.save(
                        self.D.state_dict(),
                        os.path.join(self.model_save_path,
                                     '{}_{}_D.pth'.format(e + 1, i + 1)))

            # Decay learning rates linearly over the final decay epochs
            if (e + 1) > (self.num_epochs - self.num_epochs_decay):
                g_lr -= (self.g_lr / float(self.num_epochs_decay))
                d_lr -= (self.d_lr / float(self.num_epochs_decay))
                self.update_lr(g_lr, d_lr)
                print('Decay learning rate to g_lr: {}, d_lr: {}.'.format(
                    g_lr, d_lr))

    # def test(self):
    #     test_size = 30
    #     c_dim = 17
    #     test_c = self.to_var(torch.FloatTensor(np.eye(c_dim, dtype=float)), volatile=True)
    #     fake_image_list = []
    #     for i in range(test_size):
    #         test_z = self.to_var(torch.randn(c_dim, self.z_dim), volatile=True)
    #         test_z = test_z.expand(c_dim, test_z.size(1))
    #         fake_image_list.append(self.G(test_z, test_c))

    #     fake_images = torch.cat(fake_image_list, dim=3)
    #     save_image(self.denorm(fake_images.data),
    #                     os.path.join(self.result_path, 'fake.png'),nrow=1, padding=0)
    def make_celeb_labels(self):
        """Build the 12 fixed CelebA test labels as a single 12x5 tensor.

        Row order matches make_celeb_labels_test: for each (gender, aged)
        pair in (1,1), (1,0), (0,1), (0,0), the three one-hot hair colors
        black, blond, brown.

        Returns:
            FloatTensor of shape (12, 5).
        """
        # Fix: removed the dead local `y` and replaced 12 hand-written
        # appends with loops producing the identical row order.
        hair_colors = ([1, 0, 0], [0, 1, 0], [0, 0, 1])
        attr_pairs = ([1, 1], [1, 0], [0, 1], [0, 0])

        rows = []
        for gender_aged in attr_pairs:
            for hair in hair_colors:
                rows.append(torch.FloatTensor(hair + gender_aged).unsqueeze(0))
        return torch.cat(rows, dim=0)

    def make_celeb_labels_all(self):
        """Build all 28 fixed CelebA test labels (7 hair configs x 4 pairs).

        Hair configurations cover the three one-hot colors, the three
        pairwise mixes, and the all-zero vector; each is combined with
        every (gender, aged) pair in (1,1), (1,0), (0,1), (0,0).

        Returns:
            list of 28 FloatTensors, each of shape (1, 5).
        """
        # Fix: removed the dead local `y` and replaced 28 hand-written
        # appends with loops producing the identical list order.
        hair_configs = ([1, 0, 0], [0, 1, 0], [0, 0, 1],
                        [1, 1, 0], [0, 1, 1], [1, 0, 1], [0, 0, 0])
        attr_pairs = ([1, 1], [1, 0], [0, 1], [0, 0])

        fixed_c_list = []
        for gender_aged in attr_pairs:
            for hair in hair_configs:
                fixed_c_list.append(
                    torch.FloatTensor(hair + gender_aged).unsqueeze(0))
        return fixed_c_list

    def test_celeba(self):
        """Render a 16x12 sample sheet: 16 random z vectors, each under the
        12 fixed attribute-label rows, saved as `<result_path>/fake.png`."""
        test_size = 16
        # Tile the 12 fixed labels once per sample.
        test_c = self.to_var(self.make_celeb_labels().repeat(test_size, 1),
                             volatile=True)
        # One z per sample, repeated across the 12 label variants.
        z_blocks = [torch.randn(1, self.z_dim).repeat(12, 1)
                    for _ in range(16)]
        test_z = self.to_var(torch.cat(z_blocks, dim=0), volatile=True)
        fake_image_mat = self.G(test_z, test_c)
        save_image(self.denorm(fake_image_mat.data),
                   os.path.join(self.result_path, 'fake.png'),
                   nrow=1,
                   padding=0)

    def test_celeba_single(self):
        """Generate ~50k individual images, 28 fixed-label variants per
        random z batch, saved to `<result_path>/single/fake_XXXXX.png`."""
        # Fix: removed the unreachable `if i > test_size - 1: break` -- the
        # loop index of range(test_size) can never exceed test_size - 1.
        import math
        c_dim = 28
        test_size = math.ceil(50000 / c_dim)
        test_c = self.to_var(torch.cat(self.make_celeb_labels_all(), dim=0),
                             volatile=True)
        image_index = 0
        for _ in range(test_size):
            test_z = self.to_var(torch.randn(c_dim, self.z_dim),
                                 volatile=True)
            fake_images = self.G(test_z, test_c)
            for ind in range(fake_images.size(0)):
                save_image(self.denorm(fake_images[ind].data),
                           os.path.join(
                               self.result_path,
                               'single/fake_{0:05d}.png'.format(image_index)),
                           nrow=1,
                           padding=0)
                image_index += 1

    def test(self):
        """Generate ~50k images, one per class of a 17-dim one-hot label
        set, saved to `<result_path>/single/fake_XXXXX.png`."""
        # Fixes: removed a no-op `expand` (test_z already had shape
        # (c_dim, z_dim)) and a dead `fake_image_list = []` initializer.
        import math
        c_dim = 17
        test_size = math.ceil(50000 / c_dim)
        test_c = self.to_var(torch.FloatTensor(np.eye(c_dim, dtype=float)),
                             volatile=True)
        image_index = 0
        for _ in range(test_size):
            test_z = self.to_var(torch.randn(c_dim, self.z_dim),
                                 volatile=True)
            fake_images = self.G(test_z, test_c).transpose(2, 3)
            for ind in range(fake_images.size(0)):
                save_image(self.denorm(fake_images[ind].data),
                           os.path.join(
                               self.result_path,
                               'single/fake_{0:05d}.png'.format(image_index)),
                           nrow=1,
                           padding=0)
                image_index += 1

    def get_current_errors(self, label='all'):
        D_fake = self.d_fake.data[0]
        D_real = self.d_real.data[0]
        D_loss_cls = self.d_loss_cls.data[0]
        D_loss = self.d_loss.data[0]
        G_loss = self.g_loss.data[0]
        G_loss_cls = self.g_loss_cls.data[0]
        G_loss_fake = self.g_loss_fake.data[0]
        if label == 'all':
            return OrderedDict([('D_fake', D_fake), ('D_real', D_real),
                                ('D_loss', D_loss), ('G_loss', G_loss),
                                ('G_loss_fake', G_loss_fake)])
        if label == 'D':
            return OrderedDict([('D_fake', D_fake), ('D_loss_cls', D_loss_cls),
                                ('D_real', D_real), ('D_loss', D_loss)])
        if label == 'G':
            return OrderedDict([('G_loss', G_loss), ('G_loss_cls', G_loss_cls),
                                ('G_loss_fake', G_loss_fake)])

    def get_current_visuals(self):
        """Return displayable versions of the latest real/fake image batches."""
        visuals = OrderedDict()
        visuals['real_x'] = util.tensor2im(self.real_x.data)
        visuals['fake_x'] = util.tensor2im(self.fake_x.data)
        return visuals
# Example 2
class Solver(object):

    def __init__(self, celebA_loader, rafd_loader, config):
        """Copy configuration onto the solver, build G/D (and segmentor A),
        and optionally resume from a pretrained checkpoint.

        Args:
            celebA_loader: iterable yielding CelebA batches.
            rafd_loader: iterable yielding RaFD batches (used when the
                configured dataset is not CelebA).
            config: argparse-style namespace; its attributes are mirrored
                onto `self` for convenient access throughout the class.
        """
        # Data loader
        self.celebA_loader = celebA_loader
        self.rafd_loader = rafd_loader
        self.visualizer = Visualizer()
        # Model hyper-parameters
        self.c_dim = config.c_dim
        self.s_dim = config.s_dim
        self.c2_dim = config.c2_dim
        self.image_size = config.image_size
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.g_repeat_num = config.g_repeat_num
        self.d_repeat_num = config.d_repeat_num
        self.d_train_repeat = config.d_train_repeat

        # Loss weights and optimizer hyper-parameters
        self.lambda_cls = config.lambda_cls
        self.lambda_rec = config.lambda_rec
        self.lambda_gp = config.lambda_gp
        self.lambda_s = config.lambda_s
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.a_lr = config.a_lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2

        # Segmentation criterion (GPU-resident)
        self.criterion_s = CrossEntropyLoss2d(size_average=True).cuda()

        # Training settings
        self.dataset = config.dataset
        self.num_epochs = config.num_epochs
        self.num_epochs_decay = config.num_epochs_decay
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.batch_size = config.batch_size
        self.use_tensorboard = config.use_tensorboard
        self.pretrained_model = config.pretrained_model

        # Test settings
        self.test_model = config.test_model
        self.config = config

        # Path
        self.log_path = config.log_path
        self.sample_path = config.sample_path
        self.model_save_path = config.model_save_path
        self.result_path = config.result_path

        # Step size
        self.log_step = config.log_step
        # Visualizer refresh follows the logging cadence.
        self.visual_step = self.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step

        # Build networks (and the tensorboard logger when requested)
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

        # Start with trained model
        if self.pretrained_model:
            self.load_pretrained_model()

    def build_model(self):
        """Instantiate G/D (plus segmentor A outside 'Both' mode), their
        Adam optimizers, report sizes, and move everything to GPU.

        BUGFIX: in 'Both' mode no Segmentor is created, yet the original
        unconditionally built `self.a_optimizer`, printed `self.A`, and
        called `self.A.cuda()`, raising AttributeError.  A-related setup
        now runs only when A exists.
        """
        # Define a generator and a discriminator
        if self.dataset == 'Both':
            self.G = Generator(self.g_conv_dim, self.c_dim+self.c2_dim+2, self.g_repeat_num)   # 2 for mask vector
            self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim+self.c2_dim, self.d_repeat_num)
            self.A = None
        else:
            self.G = Generator(self.g_conv_dim, self.c_dim, self.s_dim, self.g_repeat_num)
            self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num)
            self.A = Segmentor()

        # Optimizers
        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
        if self.A is not None:
            self.a_optimizer = torch.optim.Adam(self.A.parameters(), self.a_lr, [self.beta1, self.beta2])

        # Print networks
        self.print_network(self.G, 'G')
        self.print_network(self.D, 'D')
        if self.A is not None:
            self.print_network(self.A, 'A')

        if torch.cuda.is_available():
            self.G.cuda()
            self.D.cuda()
            if self.A is not None:
                self.A.cuda()

    def print_network(self, model, name):
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(name)
        print(model)
        print("The number of parameters: {}".format(num_params))

    def load_pretrained_model(self):
        """Restore G, D and A weights from checkpoint `self.pretrained_model`."""
        step = self.pretrained_model
        for tag, net in (('G', self.G), ('D', self.D), ('A', self.A)):
            path = os.path.join(self.model_save_path,
                                '{}_{}.pth'.format(step, tag))
            net.load_state_dict(torch.load(path))
        print('loaded trained models (step: {})..!'.format(step))

    def build_tensorboard(self):
        """Create the tensorboard-style logger writing to `self.log_path`."""
        from logger import Logger
        self.logger = Logger(self.log_path)

    def update_lr(self, g_lr, d_lr):
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    def reset_grad(self):
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def to_var(self, x, volatile=False):
        """Move `x` to the GPU when available and wrap it in a legacy
        autograd Variable."""
        tensor = x.cuda() if torch.cuda.is_available() else x
        return Variable(tensor, volatile=volatile)

    def denorm(self, x):
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def threshold(self, x):
        x = x.clone()
        x[x >= 0.5] = 1
        x[x < 0.5] = 0
        return x

    def compute_accuracy(self, x, y, dataset):
        """Compute classification accuracy of logits `x` against targets `y`.

        For 'CelebA' (multi-label): logits are squashed with a sigmoid and
        thresholded at 0.5; returns per-attribute accuracy as a 1-D tensor
        of percentages.  For any other dataset (single-label): the argmax
        class is compared with integer targets and a scalar percentage is
        returned.
        """
        if dataset == 'CelebA':
            # torch.sigmoid replaces the deprecated F.sigmoid (same math).
            x = torch.sigmoid(x)
            predicted = self.threshold(x)
            correct = (predicted == y).float()
            accuracy = torch.mean(correct, dim=0) * 100.0
        else:
            _, predicted = torch.max(x, dim=1)
            correct = (predicted == y).float()
            accuracy = torch.mean(correct) * 100.0
        return accuracy

    def one_hot(self, labels, dim):
        """Convert a batch of label indices into one-hot rows of length `dim`."""
        n = labels.size(0)
        encoded = torch.zeros(n, dim)
        encoded[torch.arange(n), labels.long()] = 1
        return encoded

    def make_celeb_labels(self, real_c):
        """Generate target domain labels for CelebA for debugging/testing.

        `real_c` is a (batch, c_dim) tensor of binary attribute labels; the
        first three dims are the mutually exclusive hair colors and, per the
        training log, the remaining ones are gender and aged.

        if dataset == 'CelebA':
            return single and multiple attribute changes
        elif dataset == 'Both':
            return single attribute changes
        """
        # One-hot codes for the three mutually exclusive hair colors.
        y = [torch.FloatTensor([1, 0, 0]),  # black hair
             torch.FloatTensor([0, 1, 0]),  # blond hair
             torch.FloatTensor([0, 0, 1])]  # brown hair

        fixed_c_list = []

        # single attribute transfer: one list entry per attribute i
        for i in range(self.c_dim):
            fixed_c = real_c.clone()
            # Iterating a tensor yields row views, so writing through `c`
            # mutates `fixed_c` in place.
            for c in fixed_c:
                if i < 3:
                    # hair attributes: force the one-hot color i on every row
                    c[:3] = y[i]
                else:
                    c[i] = 0 if c[i] == 1 else 1   # opposite value
            fixed_c_list.append(self.to_var(fixed_c, volatile=True))

        # multi-attribute transfer (H+G, H+A, G+A, H+G+A)
        if self.dataset == 'CelebA':
            for i in range(4):
                fixed_c = real_c.clone()
                for c in fixed_c:
                    if i in [0, 1, 3]:   # Hair color to brown
                        c[:3] = y[2] 
                    if i in [0, 2, 3]:   # Gender flipped
                        c[3] = 0 if c[3] == 1 else 1
                    if i in [1, 2, 3]:   # Aged flipped
                        c[4] = 0 if c[4] == 1 else 1
                fixed_c_list.append(self.to_var(fixed_c, volatile=True))
        return fixed_c_list

    def train(self):
        """Train StarGAN within a single dataset.

        Per iteration: the discriminator D is updated twice (adversarial +
        classification loss, then the WGAN-GP gradient penalty), the
        segmentation network A once, and the generator G once every
        `d_train_repeat` iterations.  Periodically logs losses, pushes
        visuals to the visualizer, saves sample translations and model
        checkpoints, and decays the learning rates over the last
        `num_epochs_decay` epochs.
        """

        # Set dataloader
        if self.dataset == 'CelebA':
            self.data_loader = self.celebA_loader
        else:
            # NOTE(review): self.rafd_loader is not assigned anywhere in the
            # visible code -- confirm it is set elsewhere before training on
            # a non-CelebA dataset.
            self.data_loader = self.rafd_loader

        # The number of iterations per epoch
        iters_per_epoch = len(self.data_loader)

        # Collect the first 4 batches as fixed debug inputs; each batch is
        # (images, seg index map, one-hot seg, attribute labels).
        fixed_x = []
        real_c = []
        fixed_s = []
        for i, (images, seg_i, seg, labels) in enumerate(self.data_loader):
            fixed_x.append(images)
            fixed_s.append(seg)
            real_c.append(labels)
            if i == 3:
                break

        # Fixed inputs and target domain labels for debugging
        fixed_x = torch.cat(fixed_x, dim=0)
        fixed_x = self.to_var(fixed_x, volatile=True)
        real_c = torch.cat(real_c, dim=0)

        # First fixed-seg entry: every image keeps its own segmentation.
        fixed_s = torch.cat(fixed_s, dim=0)
        fixed_s_list = []
        fixed_s_list.append(self.to_var(fixed_s, volatile=True))

        # Then 5 entries where one randomly chosen segmentation is broadcast
        # to the whole fixed batch (segmentation-transfer debugging).
        rand_idx = torch.randperm(fixed_s.size(0))
        fixed_s_num = 5
        fixed_s_vec = fixed_s[rand_idx][:fixed_s_num]

        for i in range(fixed_s_num):
            fixed_s_temp = fixed_s_vec[i].unsqueeze(0).repeat(fixed_s.size(0),1,1,1)
            fixed_s_temp = self.to_var(fixed_s_temp)
            fixed_s_list.append(fixed_s_temp)

        # for i in range(4):
        #     rand_idx = torch.randperm(fixed_s.size(0))
        #     fixed_s_temp = self.to_var(fixed_s[rand_idx], volatile=True)
        #     fixed_s_list.append(fixed_s_temp)

        # Fixed target attribute vectors used when sampling.
        if self.dataset == 'CelebA':
            fixed_c_list = self.make_celeb_labels(real_c)
        elif self.dataset == 'RaFD':
            fixed_c_list = []
            for i in range(self.c_dim):
                fixed_c = self.one_hot(torch.ones(fixed_x.size(0)) * i, self.c_dim)
                fixed_c_list.append(self.to_var(fixed_c, volatile=True))

        # lr cache for decaying
        g_lr = self.g_lr
        d_lr = self.d_lr

        # Start with trained model if it exists.  Checkpoint names begin with
        # the 1-based epoch number ('<epoch>_<iter>'), hence the -1.
        if self.pretrained_model:
            start = int(self.pretrained_model.split('_')[0])-1
        else:
            start = 0

        # Start training
        start_time = time.time()
        for e in range(start, self.num_epochs):
            epoch_iter = 0
            for i, (real_x, real_s_i, real_s, real_label) in enumerate(self.data_loader):
                epoch_iter = epoch_iter + 1
                # Generate fake labels randomly (target domain labels) by
                # permuting the real batch labels.
                rand_idx = torch.randperm(real_label.size(0))
                fake_label = real_label[rand_idx]
                # Independently permuted segmentations serve as targets.
                rand_idx = torch.randperm(real_label.size(0))
                fake_s = real_s[rand_idx]
                fake_s_i = real_s_i[rand_idx]
                if self.dataset == 'CelebA':
                    # NOTE(review): this rebinds `real_c`, shadowing the fixed
                    # debug labels gathered above.
                    real_c = real_label.clone()
                    fake_c = fake_label.clone()
                else:
                    real_c = self.one_hot(real_label, self.c_dim)
                    fake_c = self.one_hot(fake_label, self.c_dim)

                # Convert tensor to variable
                real_x = self.to_var(real_x)
                real_s = self.to_var(real_s)
                real_s_i = self.to_var(real_s_i)
                fake_s = self.to_var(fake_s)
                fake_s_i = self.to_var(fake_s_i)
                real_c = self.to_var(real_c)           # input for the generator
                fake_c = self.to_var(fake_c)
                real_label = self.to_var(real_label)   # this is same as real_c if dataset == 'CelebA'
                fake_label = self.to_var(fake_label)
                
                # ================== Train D ================== #

                # Compute loss with real images (WGAN critic: maximize real score)
                out_src, out_cls = self.D(real_x)
                d_loss_real = - torch.mean(out_src)

                if self.dataset == 'CelebA':
                    # Multi-label attribute loss, summed then averaged per sample.
                    d_loss_cls = F.binary_cross_entropy_with_logits(
                        out_cls, real_label, size_average=False) / real_x.size(0)
                else:
                    d_loss_cls = F.cross_entropy(out_cls, real_label)

                # Compute classification accuracy of the discriminator
                if (i+1) % self.log_step == 0:
                    accuracies = self.compute_accuracy(out_cls, real_label, self.dataset)
                    log = ["{:.2f}".format(acc) for acc in accuracies.data.cpu().numpy()]
                    if self.dataset == 'CelebA':
                        print('Classification Acc (Black/Blond/Brown/Gender/Aged): ', end='')
                    else:
                        print('Classification Acc (8 emotional expressions): ', end='')
                    print(log)

                # Compute loss with fake images.  Re-wrapping `.data` detaches
                # the graph so this D step sends no gradient into G.
                fake_x = self.G(real_x, fake_c, fake_s)
                fake_x = Variable(fake_x.data)
                out_src, out_cls = self.D(fake_x)
                d_loss_fake = torch.mean(out_src)

                # Backward + Optimize (adversarial + classification terms)
                d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # Compute gradient penalty on random real/fake interpolations
                # (WGAN-GP).  Note D takes a SECOND optimizer step for the
                # penalty alone, rather than one combined step.
                alpha = torch.rand(real_x.size(0), 1, 1, 1).cuda().expand_as(real_x)
                interpolated = Variable(alpha * real_x.data + (1 - alpha) * fake_x.data, requires_grad=True)
                out, out_cls = self.D(interpolated)

                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                # Backward + Optimize
                d_loss = self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # ================== Train A ================== #
                # Segmentation network: predict real_s_i from the real image.
                self.a_optimizer.zero_grad()
                out_real_s = self.A(real_x)
                # a_loss = self.criterion_s(out_real_s, real_s_i.type(torch.cuda.LongTensor)) * self.lambda_s
                a_loss = self.criterion_s(out_real_s, real_s_i) * self.lambda_s
                # a_loss = torch.mean(torch.abs(real_s - out_real_s))
                a_loss.backward()
                self.a_optimizer.step()

                # Logging (.data[0] extracts the scalar, PyTorch 0.3 style)
                loss = {}
                loss['D/loss_real'] = d_loss_real.data[0]
                loss['D/loss_fake'] = d_loss_fake.data[0]
                loss['D/loss_cls'] = d_loss_cls.data[0]
                loss['D/loss_gp'] = d_loss_gp.data[0]

                # ================== Train G ================== #
                if (i+1) % self.d_train_repeat == 0:

                    # Original-to-target and target-to-original domain
                    fake_x = self.G(real_x, fake_c, fake_s)

                    rec_x = self.G(fake_x, real_c, real_s)

                    # Compute losses
                    out_src, out_cls = self.D(fake_x)
                    g_loss_fake = - torch.mean(out_src)
                    g_loss_rec = self.lambda_rec * torch.mean(torch.abs(real_x - rec_x))

                    if self.dataset == 'CelebA':
                        g_loss_cls = F.binary_cross_entropy_with_logits(
                            out_cls, fake_label, size_average=False) / fake_x.size(0)
                    else:
                        g_loss_cls = F.cross_entropy(out_cls, fake_label)

                    # segmentation loss: A must recover the target segmentation
                    # from the translated image.
                    out_fake_s = self.A(fake_x)
                    g_loss_s = self.lambda_s * self.criterion_s(out_fake_s, fake_s_i)
                    # Backward + Optimize
                    g_loss = g_loss_fake + g_loss_rec + g_loss_s + self.lambda_cls * g_loss_cls
                    # g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
                    self.reset_grad()
                    g_loss.backward()
                    self.g_optimizer.step()

                    # Logging
                    loss['G/loss_fake'] = g_loss_fake.data[0]
                    loss['G/loss_rec'] = g_loss_rec.data[0]
                    loss['G/loss_cls'] = g_loss_cls.data[0]
                
                # NOTE(review): this block reads rec_x / out_fake_s / g_loss*
                # which only exist after the first G update above, so
                # visual_step must be a multiple of d_train_repeat (and > it
                # on the first pass) or this raises NameError / shows stale
                # values -- confirm the configured step sizes.
                if (i+1) % self.visual_step == 0:
                    # save visuals
                    self.real_x = real_x
                    self.fake_x = fake_x
                    self.rec_x = rec_x
                    self.real_s = real_s
                    self.fake_s = fake_s
                    self.out_real_s = out_real_s
                    self.out_fake_s = out_fake_s
                    self.a_loss = a_loss
                    # save losses (d_real is negated back to the raw critic mean)
                    self.d_real = - d_loss_real
                    self.d_fake = d_loss_fake
                    self.d_loss = d_loss
                    self.g_loss = g_loss
                    self.g_loss_fake = g_loss_fake
                    self.g_loss_rec = g_loss_rec
                    self.g_loss_s = g_loss_s
                    errors_D = self.get_current_errors('D')
                    errors_G = self.get_current_errors('G')
                    self.visualizer.display_current_results(self.get_current_visuals(), e)
                    self.visualizer.plot_current_errors_D(e, float(epoch_iter)/float(iters_per_epoch), errors_D)
                    self.visualizer.plot_current_errors_G(e, float(epoch_iter)/float(iters_per_epoch), errors_G)
                # Print out log info
                if (i+1) % self.log_step == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))

                    log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                        elapsed, e+1, self.num_epochs, i+1, iters_per_epoch)

                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)

                    if self.use_tensorboard:
                        for tag, value in loss.items():
                            self.logger.scalar_summary(tag, value, e * iters_per_epoch + i + 1)

                # Translate fixed images for debugging
                if (i+1) % self.sample_step == 0:
                    fake_image_list = [fixed_x]
                    # NOTE(review): dead assignment -- `fixed_c` is immediately
                    # rebound by the loop below.
                    fixed_c = fixed_c_list[0]
                    real_seg_list = []
                    for fixed_c in fixed_c_list:
                        for fixed_s in fixed_s_list:
                            fake_image_list.append(self.G(fixed_x, fixed_c, fixed_s))
                            real_seg_list.append(fixed_s)
                    # Concatenate along width (dim=3): one row per sample.
                    fake_images = torch.cat(fake_image_list, dim=3)
                    real_seg_images = torch.cat(real_seg_list, dim=3)
                    save_image(self.denorm(fake_images.data),
                        os.path.join(self.sample_path, '{}_{}_fake.png'.format(e+1, i+1)),nrow=1, padding=0)
                    save_image(self.cat2class_tensor(real_seg_images.data),
                        os.path.join(self.sample_path, '{}_{}_seg.png'.format(e+1, i+1)),nrow=1, padding=0)
                    print('Translated images and saved into {}..!'.format(self.sample_path))

                # Save model checkpoints
                if (i+1) % self.model_save_step == 0:
                    torch.save(self.G.state_dict(),
                        os.path.join(self.model_save_path, '{}_{}_G.pth'.format(e+1, i+1)))
                    torch.save(self.D.state_dict(),
                        os.path.join(self.model_save_path, '{}_{}_D.pth'.format(e+1, i+1)))
                    torch.save(self.A.state_dict(),
                        os.path.join(self.model_save_path, '{}_{}_A.pth'.format(e+1, i+1)))

            # Decay learning rate linearly over the last num_epochs_decay epochs
            if (e+1) > (self.num_epochs - self.num_epochs_decay):
                g_lr -= (self.g_lr / float(self.num_epochs_decay))
                d_lr -= (self.d_lr / float(self.num_epochs_decay))
                self.update_lr(g_lr, d_lr)
                print ('Decay learning rate to g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))

    def make_celeb_labels_test(self):
        """Build the 12 fixed CelebA target label vectors used at test time.

        Returns a list of 12 (1, 5) Variables covering every combination of a
        hair-color one-hot (black/blond/brown) with the gender and aged bits.
        Order matches the original hand-written list: (gender, aged) pairs
        (1,1), (1,0), (0,1), (0,0), and within each pair the three hair
        colors black, blond, brown.
        """
        fixed_c_list = []
        for gender in (1, 0):
            for aged in (1, 0):
                for hair in range(3):
                    # [black, blond, brown, gender, aged]
                    c = torch.zeros(5)
                    c[hair] = 1
                    c[3] = gender
                    c[4] = aged
                    fixed_c_list.append(self.to_var(c.unsqueeze(0), volatile=True))
        return fixed_c_list

    def test(self):
        """Attribute + segmentation transfer over a directory of test images.

        Loads the '<test_model>_G.pth' generator, then for each image under
        `config.test_image_path` translates it with each of the 12 fixed
        attribute vectors combined with each of 11 segmentation maps
        ('1.png'..'11.png' under `config.test_seg_path`), and saves one grid
        image per input (segmentations stacked along height, attribute sets
        along width).
        """
        # Load trained parameters
        G_path = os.path.join(self.model_save_path, '{}_G.pth'.format(self.test_model))
        self.G.load_state_dict(torch.load(G_path))
        self.G.eval()
        fixed_c_list = self.make_celeb_labels_test()
        # Input preprocessing: crop, resize, normalize to [-1, 1].
        transform = transforms.Compose([
            transforms.CenterCrop(self.config.celebA_crop_size),
            transforms.Scale(self.config.image_size),
            # transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        # Segmentation preprocessing is split: PIL-side crop/resize first ...
        transform_seg1 = transforms.Compose([
            transforms.CenterCrop(self.config.celebA_crop_size),
            transforms.Scale(self.config.image_size)])
        # ... then tensor conversion after the one-hot encoding.
        transform_seg2 = transforms.Compose([
            transforms.ToTensor()])

        for root, _, fnames in sorted(os.walk(self.config.test_image_path)):
            for fname in fnames:
                path = os.path.join(root, fname)
                image = Image.open(path)
        

                image = transform(image)
                image = image.unsqueeze(0)
                x = self.to_var(image, volatile=True)
                fake_image_mat = []
                for fixed_c in fixed_c_list:
                    fake_image_list = [x]
                    for i in range(11):
                        # Segmentation maps are class-index images '1.png'..'11.png'.
                        seg = Image.open(os.path.join(self.config.test_seg_path, '{}.png'.format(i+1)))
                        seg = transform_seg1(seg)
                        num_s = 7  # number of segmentation classes
                        seg_onehot = to_categorical(seg, num_s)
                        # *255.0 undoes ToTensor's 1/255 scaling -- presumably
                        # restoring 0/1 one-hot planes; confirm against to_categorical.
                        seg_onehot = transform_seg2(seg_onehot)*255.0
                        
                        seg_onehot = seg_onehot.unsqueeze(0)

                        
                        s = self.to_var(seg_onehot, volatile=True)

                        fake_x = self.G(x,fixed_c,s)
                        fake_image_list.append(fake_x)
                        # save_path = os.path.join(self.result_path, 'fake_x_{}.png'.format(i+1))
                        # save_image(self.denorm(fake_x.data), save_path, nrow=1, padding=0)
                    # Stack original + 11 translations along height (dim=2).
                    fake_images = torch.cat(fake_image_list, dim=2)
                    fake_image_mat.append(fake_images)

                # Stack the per-attribute columns along width (dim=3).
                fake_images_save = torch.cat(fake_image_mat, dim=3)
                
                save_path = os.path.join(self.result_path, 'fake_x_sum_{}.png'.format(fname))
                print('Translated test images and saved into "{}"..!'.format(save_path))
                save_image(self.denorm(fake_images_save.data), save_path, nrow=1, padding=0)
            # # Start translations
            # fake_image_list = [real_x]
            # for target_c in target_c_list:
            #     fake_image_list.append(self.G(real_x, target_c))
            # fake_images = torch.cat(fake_image_list, dim=3)
            # save_path = os.path.join(self.result_path, '{}_fake.png'.format(i+1))
            # save_image(self.denorm(fake_images.data), save_path, nrow=1, padding=0)
            # print('Translated test images and saved into "{}"..!'.format(save_path))
    def test_with_original_seg(self):
        """Attribute transfer using each test image's OWN segmentation map.

        Like `test`, but the segmentation is read from
        `config.test_seg_path/<image name>.png` instead of a fixed set, so
        only the attributes vary.  One grid is saved per input image
        (attribute sets along width, then per-attribute rows along height).
        """
        # Load trained parameters
        G_path = os.path.join(self.model_save_path, '{}_G.pth'.format(self.test_model))
        self.G.load_state_dict(torch.load(G_path))
        self.G.eval()
        fixed_c_list = self.make_celeb_labels_test()
        # Input preprocessing: resize only (no crop here), normalize to [-1, 1].
        transform = transforms.Compose([
            # transforms.CenterCrop(self.config.celebA_crop_size),
            transforms.Scale(self.config.image_size),
            # transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        transform_seg1 = transforms.Compose([
            transforms.CenterCrop(self.config.celebA_crop_size),
            transforms.Scale(self.config.image_size)])
        transform_seg2 = transforms.Compose([
            transforms.ToTensor()])

        for root, _, fnames in sorted(os.walk(self.config.test_image_path)):
            for fname in fnames:
                path = os.path.join(root, fname)
                image = Image.open(path)
        

                image = transform(image)
                image = image.unsqueeze(0)
                x = self.to_var(image, volatile=True)
                fake_image_mat = []
                for fixed_c in fixed_c_list:
                    fake_image_list = [x]
                    # NOTE(review): the segmentation is loop-invariant but is
                    # re-loaded and re-encoded for every attribute vector --
                    # could be hoisted out of the fixed_c loop.
                    seg = Image.open(os.path.join(self.config.test_seg_path, '{}.png'.format(fname[:-4])))
                    seg = transform_seg1(seg)
                    num_s = 7  # number of segmentation classes
                    seg_onehot = to_categorical(seg, num_s)
                    # *255.0 undoes ToTensor's 1/255 scaling -- presumably
                    # restoring 0/1 one-hot planes; confirm against to_categorical.
                    seg_onehot = transform_seg2(seg_onehot)*255.0
                        
                    seg_onehot = seg_onehot.unsqueeze(0)

                    s = self.to_var(seg_onehot, volatile=True)

                    fake_x = self.G(x,fixed_c,s)
                    fake_image_list.append(fake_x)
                    # save_path = os.path.join(self.result_path, 'fake_x_{}.png'.format(i+1))
                    # save_image(self.denorm(fake_x.data), save_path, nrow=1, padding=0)
                    # Original + translation side by side along width (dim=3).
                    fake_images = torch.cat(fake_image_list, dim=3)
                    fake_image_mat.append(fake_images)

                # One row per attribute vector, stacked along height (dim=2).
                fake_images_save = torch.cat(fake_image_mat, dim=2)
                
                save_path = os.path.join(self.result_path, 'fake_x_sum_{}.png'.format(fname))
                print('Translated test images and saved into "{}"..!'.format(save_path))
                save_image(self.denorm(fake_images_save.data), save_path, nrow=1, padding=0)
            # # Start translations
            # fake_image_list = [real_x]
            # for target_c in target_c_list:
            #     fake_image_list.append(self.G(real_x, target_c))
            # fake_images = torch.cat(fake_image_list, dim=3)
            # save_path = os.path.join(self.result_path, '{}_fake.png'.format(i+1))
            # save_image(self.denorm(fake_images.data), save_path, nrow=1, padding=0)
            # print('Translated test images and saved into "{}"..!'.format(save_path))
    def test_seg(self):
        """Run the trained segmentation network A on every test image.

        Loads the '<test_model>_A.pth' checkpoint, segments each image found
        under `config.test_image_path`, and dumps the raw (H, W, C) score map
        to 'segnumpy.mat'.  Note each iteration overwrites the same file, so
        only the last image's scores survive.
        """
        import scipy.io as sio  # hoisted: import once, not per image

        # Load trained parameters
        A_path = os.path.join(self.model_save_path, '{}_A.pth'.format(self.test_model))
        self.A.load_state_dict(torch.load(A_path))
        self.A.eval()

        # Resize and normalize to the [-1, 1] range the network was trained on.
        transform = transforms.Compose([
            transforms.Scale(self.config.image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        for root, _, fnames in sorted(os.walk(self.config.test_image_path)):
            for fname in fnames:
                path = os.path.join(root, fname)
                image = Image.open(path)
                print('Read image "{}"..!'.format(fname))

                image = transform(image)
                image = image.unsqueeze(0)
                x = self.to_var(image, volatile=True)

                seg = self.A(x)
                # (C, H, W) scores -> (H, W, C) float64 array for MATLAB.
                # np.float was a deprecated alias for builtin float (float64)
                # and is removed in NumPy >= 1.24; use np.float64 explicitly.
                seg_numpy = seg.data[0].cpu().float().numpy()
                seg_numpy = np.transpose(seg_numpy, (1, 2, 0)).astype(np.float64)
                sio.savemat('segnumpy.mat', {'seg': seg_numpy})

                print('Translated seg images and saved into "{}"..!'.format('segnumpy.mat'))
                
    def get_current_errors(self, label='all'):
        """Collect the cached scalar losses for plotting.

        `label` selects the subset: 'all', 'D' (discriminator + segmenter) or
        'G' (generator + segmenter).  Raises ValueError on any other value
        instead of the previous silent `None` fall-through.
        Returns an OrderedDict mapping loss name -> scalar value.
        """
        # .data[0] extracts the python scalar (PyTorch 0.3-style losses).
        errors = OrderedDict([
            ('D_fake', self.d_fake.data[0]),
            ('D_real', self.d_real.data[0]),
            ('D', self.d_loss.data[0]),
            ('A_loss', self.a_loss.data[0]),
            ('G', self.g_loss.data[0]),
            ('G_loss_fake', self.g_loss_fake.data[0]),
            ('G_loss_s', self.g_loss_s.data[0]),
            ('G_loss_rec', self.g_loss_rec.data[0]),
        ])
        if label == 'all':
            return errors
        if label == 'D':
            keys = ('D_fake', 'D_real', 'D', 'A_loss')
        elif label == 'G':
            keys = ('A_loss', 'G', 'G_loss_fake', 'G_loss_s', 'G_loss_rec')
        else:
            raise ValueError("unknown loss group: {!r}".format(label))
        return OrderedDict((k, errors[k]) for k in keys)

    def get_current_visuals(self):
        """Turn the cached training tensors into a name -> image mapping for display."""
        to_im = util.tensor2im
        to_seg = util.tensor2im_seg
        visuals = OrderedDict()
        visuals['real_x'] = to_im(self.real_x.data)
        visuals['fake_x'] = to_im(self.fake_x.data)
        visuals['rec_x'] = to_im(self.rec_x.data)
        # Segmentation tensors are converted to images, then collapsed into
        # displayable class-index maps.
        visuals['real_s'] = self.cat2class(to_seg(self.real_s.data))
        visuals['fake_s'] = self.cat2class(to_seg(self.fake_s.data))
        visuals['out_real_s'] = self.cat2class(to_seg(self.out_real_s.data))
        visuals['out_fake_s'] = self.cat2class(to_seg(self.out_fake_s.data))
        return visuals
    def cat2class(self, m):
        """Collapse an (H, W, C) per-class map into a 3-channel grayscale class image.

        Channels are weight-summed by their index, rescaled so the maximum
        maps to 255, and replicated to 3 identical channels.
        """
        h, w, channels = np.size(m, 0), np.size(m, 1), np.size(m, 2)
        acc = np.zeros((h, w), dtype='float64')
        for idx in range(channels):
            acc = acc + m[:, :, idx] * idx
        acc = acc / float(np.max(acc)) * 255.0
        acc = acc.astype(np.uint8).reshape(h, w, 1)
        return np.repeat(acc, 3, 2)

    def cat2class_tensor(self, m):
        """Collapse an (N, C, H, W) one-hot seg batch into a 3-channel class-index batch.

        Each sample's channels are weight-summed by index into one plane,
        replicated to 3 channels, and the whole batch is normalized by its
        global maximum.  Runs on the GPU (CUDA tensors).
        """
        frames = []
        for n in range(m.size(0)):
            plane = torch.cuda.FloatTensor(m.size(2), m.size(3)).zero_()
            for c in range(m.size(1)):
                plane = plane + m[n, c, :, :] * c
            frames.append(plane.unsqueeze(0).unsqueeze(1).expand(1, 3, m.size(2), m.size(3)))
        stacked = torch.cat(frames, dim=0)
        return stacked / float(torch.max(stacked))