Example #1
    _, argmax = torch.max(outputs, 1)
    accuracy = (labels == argmax.squeeze()).float().mean()

    if (step + 1) % 100 == 0:
        print('Step [{}/{}], Loss: {:.4f}, Acc: {:.2f}'.format(
            step + 1, total_step, loss.item(), accuracy.item()))

        # ================================================================== #
        #                        Tensorboard Logging                         #
        # ================================================================== #

        # 1. Log scalar values (scalar summary)
        info = {'loss': loss.item(), 'accuracy': accuracy.item()}

        for tag, value in info.items():
            logger.scalar_summary(tag, value, step + 1)

        # 2. Log values and gradients of the parameters (histogram summary)
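        #    (value.grad is None for any parameter that has not received a
        #    gradient yet, e.g. frozen layers; guard against that if needed)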
        for tag, value in model.named_parameters():
            tag = tag.replace('.', '/')
            logger.histo_summary(tag, value.data.cpu().numpy(), step + 1)
            logger.histo_summary(tag + '/grad',
                                 value.grad.data.cpu().numpy(), step + 1)

        # 3. Log training images (image summary)
        info = {'images': images.view(-1, 28, 28)[:10].cpu().numpy()}
        # [:10]: log only the first 10 images

        for tag, images in info.items():
            logger.image_summary(tag, images, step + 1)
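
The logger object above is a custom helper: its scalar_summary, histo_summary,
and image_summary methods are not part of PyTorch. A minimal sketch of the
same three summaries using the built-in torch.utils.tensorboard.SummaryWriter,
assuming the same loop variables (loss, accuracy, model, images, step):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='./logs')

# 1. Scalar summaries
writer.add_scalar('loss', loss.item(), step + 1)
writer.add_scalar('accuracy', accuracy.item(), step + 1)

# 2. Parameter and gradient histograms
for tag, value in model.named_parameters():
    tag = tag.replace('.', '/')
    writer.add_histogram(tag, value.data.cpu().numpy(), step + 1)
    writer.add_histogram(tag + '/grad', value.grad.data.cpu().numpy(), step + 1)

# 3. Image summaries (add_images expects NCHW, so restore the channel dim)
writer.add_images('images', images.view(-1, 1, 28, 28)[:10].cpu(), step + 1)
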
Example #2
import os
import time
import datetime

import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import save_image

# Project-local imports; the module paths below are assumptions made to keep
# the snippet self-contained, not part of the original listing.
from model import Generator, Discriminator
from utils import print_logger


class Solver(object):
    def __init__(self, data_loaders, config):
        # Data loader
        self.data_loaders = data_loaders
        self.attrs = config.attrs

        # Model hyper-parameters
        self.c_dim = len(data_loaders['train'].dataset.class_names)
        self.c2_dim = config.c2_dim
        self.image_size = config.image_size
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.g_repeat_num = config.g_repeat_num
        self.d_repeat_num = config.d_repeat_num
        self.d_train_repeat = config.d_train_repeat

        # Loss weights and optimizer hyper-parameters
        self.lambda_cls = config.lambda_cls
        self.lambda_rec = config.lambda_rec
        self.lambda_gp = config.lambda_gp
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2

        # Training settings
        self.num_epochs = config.num_epochs
        self.num_epochs_decay = config.num_epochs_decay
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.batch_size = config.batch_size
        self.use_tensorboard = config.use_tensorboard
        self.pretrained_model = config.pretrained_model
        self.pretrained_model_path = config.pretrained_model_path

        # Test settings
        self.test_model = config.test_model

        # Path
        self.log_path = os.path.join(config.output_path, 'logs',
                                     config.output_name)
        self.sample_path = os.path.join(config.output_path, 'samples',
                                        config.output_name)
        self.model_save_path = os.path.join(config.output_path, 'models',
                                            config.output_name)
        self.result_path = os.path.join(config.output_path, 'results',
                                        config.output_name)

        # Step size
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step

        self.num_val_imgs = config.num_val_imgs

        # Build the model, and the TensorBoard logger if enabled
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

        # Start with trained model
        if self.pretrained_model:
            self.load_pretrained_model()

    def build_model(self):
        # self.G = UnetGenerator(3 + self.c_dim)
        self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num,
                           self.image_size)
        self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim,
                               self.d_repeat_num)

        # Optimizers
        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr,
                                            [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr,
                                            [self.beta1, self.beta2])

        # Print networks
        self.print_network(self.G, 'G')
        self.print_network(self.D, 'D')

        if torch.cuda.is_available():
            self.G.cuda()
            self.D.cuda()

    def print_network(self, model, name):
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print_logger.info('{} - {} - Number of parameters: {}'.format(
            name, model, num_params))

    def load_pretrained_model(self):
        self.G.load_state_dict(
            torch.load(
                os.path.join(self.pretrained_model_path,
                             '{}_G.pth'.format(self.pretrained_model))))
        self.D.load_state_dict(
            torch.load(
                os.path.join(self.pretrained_model_path,
                             '{}_D.pth'.format(self.pretrained_model))))
        print_logger.info('loaded trained models (step: {})..!'.format(
            self.pretrained_model))

    def build_tensorboard(self):
        from tensorboard_logger import Logger
        self.logger = Logger(self.log_path)

    def update_lr(self, g_lr, d_lr):
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    def reset_grad(self):
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def to_var(self, x, grad=True):
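        # Variable is deprecated since PyTorch 0.4 (tensors now carry
        # requires_grad themselves); kept here to match the original style.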
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x, requires_grad=grad)

    def denorm(self, x):
        """Map images from the [-1, 1] range back to [0, 1]."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def threshold(self, x):
        # the comparison already allocates a new tensor, so no clone is needed
        return (x >= 0.5).float()

    def compute_accuracy(self, x, y):
        x = torch.sigmoid(x)  # torch.sigmoid replaces the deprecated F.sigmoid
        predicted = self.threshold(x)
        correct = (predicted == y).float()
        accuracy = torch.mean(correct, dim=0) * 100.0
        return accuracy

    def one_hot(self, labels, dim):
        """Convert label indices to one-hot vector"""
        batch_size = labels.size(0)
        out = torch.zeros(batch_size, dim)
        out[np.arange(batch_size), labels.long()] = 1
        return out

    def make_data_labels(self, real_c):
        """Generate domain labels for dataset for debugging/testing.
        """

        y = []
        for dim in range(self.c_dim):
            t = [0] * self.c_dim
            t[dim] = 1
            y.append(torch.FloatTensor(t))

        fixed_c_list = []

        for i in range(self.c_dim):
            fixed_c = real_c.clone()
            for c in fixed_c:
                c[:self.c_dim] = y[i]

            fixed_c_list.append(self.to_var(fixed_c, grad=False))

        return fixed_c_list

    def train(self):
        """Train StarGAN within a single dataset."""

        # The number of iterations per epoch
        data_loader = self.data_loaders['train']

        iters_per_epoch = len(data_loader)

        fixed_x = []
        real_c = []

        num_fixed_imgs = self.num_val_imgs
        for i in range(num_fixed_imgs):
            images, labels = self.data_loaders['val'].dataset[i]
            fixed_x.append(images.unsqueeze(0))
            real_c.append(labels.unsqueeze(0))

        # Fixed inputs and target domain labels for debugging
        fixed_x = torch.cat(fixed_x, dim=0)
        fixed_x = self.to_var(fixed_x, grad=False)
        real_c = torch.cat(real_c, dim=0)

        fixed_c_list = self.make_data_labels(real_c)

        # lr cache for decaying
        g_lr = self.g_lr
        d_lr = self.d_lr

        # Resume from a trained model if one is specified
        if self.pretrained_model:
            start = int(self.pretrained_model.split('_')[0])
        else:
            start = 0

        # Start training
        start_time = time.time()
        for e in range(start, self.num_epochs):
            for i, (real_x, real_label) in enumerate(data_loader):

                # Generate fake labels randomly (target domain labels)
                rand_idx = torch.randperm(real_label.size(0))
                fake_label = real_label[rand_idx]
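                # Permuting the batch's own labels keeps the target-domain
                # distribution identical to the data's label distribution.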

                real_c = real_label.clone()
                fake_c = fake_label.clone()

                # Convert tensor to variable
                real_x = self.to_var(real_x)
                real_c = self.to_var(real_c)  # input for the generator
                fake_c = self.to_var(fake_c)
                # real_label equals real_c when training on a single dataset
                # such as CelebA
                real_label = self.to_var(real_label)
                fake_label = self.to_var(fake_label)

                # ================== Train D ================== #

                # Compute loss with real images
                out_src, out_cls = self.D(real_x)
                d_loss_real = -torch.mean(out_src)

                # reduction='sum' replaces the deprecated size_average=False
                d_loss_cls = F.binary_cross_entropy_with_logits(
                    out_cls, real_label, reduction='sum') / real_x.size(0)

                # Compute classification accuracy of the discriminator
                if (i + 1) % (self.log_step * 10) == 0:
                    accuracies = self.compute_accuracy(out_cls, real_label)
                    log = [
                        "{}: {:.2f}".format(attr, acc)
                        for (attr, acc) in zip(data_loader.dataset.class_names,
                                               accuracies.data.cpu().numpy())
                    ]
                    print_logger.info('Discriminator Accuracy: {}'.format(log))

                # Compute loss with fake images (detached, so this step does
                # not backpropagate into G)
                fake_x = self.G(real_x, fake_c)
                fake_x = Variable(fake_x.data)
                out_src, out_cls = self.D(fake_x)
                d_loss_fake = torch.mean(out_src)

                # Backward + Optimize
                d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                loss = {}
                loss['D/loss_real'] = d_loss_real.data.item()
                loss['D/loss_fake'] = d_loss_fake.data.item()
                loss['D/loss_cls'] = d_loss_cls.data.item()
                loss['D/loss'] = d_loss.data.item()

                # Compute the gradient penalty on random interpolations
                # between real and fake images
                alpha = torch.rand(real_x.size(0), 1, 1, 1).cuda().expand_as(real_x)
                interpolated = Variable(alpha * real_x.data + (1 - alpha) * fake_x.data,
                                        requires_grad=True)
                out, out_cls = self.D(interpolated)

                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(
                                               out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)
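                # (This is the WGAN-GP penalty of Gulrajani et al., 2017:
                # driving ||grad_x D(x_hat)||_2 towards 1 on the interpolates
                # softly enforces the critic's 1-Lipschitz constraint.)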

                # Backward + Optimize
                d_loss = self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()
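                # Note that D takes two optimizer steps per iteration: one for
                # the adversarial + classification losses above, and this
                # separate one for the gradient penalty.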

                # Logging
                loss['D/loss_gp'] = d_loss_gp.data.item()

                # ================== Train G ================== #
                if (i + 1) % self.d_train_repeat == 0:

                    # Original-to-target and target-to-original domain
                    fake_x = self.G(real_x, fake_c)
                    rec_x = self.G(fake_x, real_c)

                    # Compute losses
                    out_src, out_cls = self.D(fake_x)
                    g_loss_fake = -torch.mean(out_src)
                    g_loss_rec = torch.mean(torch.abs(real_x - rec_x))

                    # L1 loss between input and output (currently unused; see
                    # the commented-out term below)
                    g_loss_l1 = torch.mean(torch.abs(real_x - fake_x))

                    g_loss_cls = F.binary_cross_entropy_with_logits(
                        out_cls, fake_label,
                        reduction='sum') / fake_x.size(0)

                    # Backward + Optimize
                    g_loss = (g_loss_fake + self.lambda_rec * g_loss_rec +
                              self.lambda_cls * g_loss_cls)
                    # optionally: g_loss += self.lambda_rec * g_loss_l1
                    self.reset_grad()
                    g_loss.backward()
                    self.g_optimizer.step()

                    # Logging
                    loss['G/loss_fake'] = g_loss_fake.data.item()
                    loss['G/loss_rec'] = g_loss_rec.data.item()
                    loss['G/loss_cls'] = g_loss_cls.data.item()
                    # loss['G/loss_l1'] = g_loss_l1.data.item()
                    loss['G/loss'] = g_loss.data.item()

                # Print out log info
                if (i + 1) % self.log_step == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))

                    log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                        elapsed, e + 1, self.num_epochs, i + 1,
                        iters_per_epoch)

                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print_logger.info(log)

                    if self.use_tensorboard:
                        for tag, value in loss.items():
                            self.logger.scalar_summary(
                                tag, value, e * iters_per_epoch + i + 1)

                # Translate fixed images for debugging
                if (i + 1) % self.sample_step == 0:
                    fake_image_list = [fixed_x]

                    for fixed_c in fixed_c_list:
                        gen_imgs = self.G(fixed_x, fixed_c)
                        fake_image_list.append(gen_imgs)

                    # fake_images = torch.cat(fake_image_list, dim=3)
                    # save_image(self.denorm(fake_images.data.cpu()),
                    #            os.path.join(self.sample_path,
                    #                         '{}_{}_fake.png'.format(e + 1, i + 1)),
                    #            nrow=1, padding=0)
                    # print_logger.info('Translated images and saved into '
                    #                   '{}..!'.format(self.sample_path))

                    if self.use_tensorboard:
                        # Stack to (num_domains + 1, N, C, H, W), then swap to
                        # (N, num_domains + 1, C, H, W)
                        tb_imgs = torch.cat([t.unsqueeze(0) for t in fake_image_list])
                        tb_imgs = tb_imgs.permute(1, 0, 2, 3, 4)
                        # For each sample, lay the original and all of its
                        # translations side by side along the width
                        tb_imgs_list = [
                            torch.cat(torch.unbind(t, dim=0), dim=2)
                            for t in torch.unbind(tb_imgs, dim=0)
                        ]

                        self.logger.image_summary('fixed_imgs', tb_imgs_list,
                                                  e * iters_per_epoch + i + 1)

                # Save model checkpoints
                if (i + 1) % self.model_save_step == 0:
                    torch.save(
                        self.G.state_dict(),
                        os.path.join(self.model_save_path,
                                     '{}_{}_G.pth'.format(e + 1, i + 1)))
                    torch.save(
                        self.D.state_dict(),
                        os.path.join(self.model_save_path,
                                     '{}_{}_D.pth'.format(e + 1, i + 1)))

            # Decay learning rate
            if (e + 1) > (self.num_epochs - self.num_epochs_decay):
                g_lr -= (self.g_lr / float(self.num_epochs_decay))
                d_lr -= (self.d_lr / float(self.num_epochs_decay))
                self.update_lr(g_lr, d_lr)
                print_logger.info(
                    'Decay learning rate to g_lr: {}, d_lr: {}.'.format(
                        g_lr, d_lr))

    def test(self):
        """Facial attribute transfer on CelebA or facial expression synthesis on RaFD."""
        # Load trained parameters
        G_path = os.path.join(self.model_save_path,
                              '{}_G.pth'.format(self.test_model))
        self.G.load_state_dict(torch.load(G_path))
        self.G.eval()

        data_loader = self.data_loaders['test']

        for i, (real_x, org_c) in enumerate(data_loader):
            real_x = self.to_var(real_x, grad=False)
            target_c_list = self.make_data_labels(org_c)

            # Start translations
            fake_image_list = [real_x]
            for target_c in target_c_list:
                fake_image_list.append(self.G(real_x, target_c))
            fake_images = torch.cat(fake_image_list, dim=3)
            save_path = os.path.join(self.result_path,
                                     '{}_fake.png'.format(i + 1))
            save_image(self.denorm(fake_images.data),
                       save_path,
                       nrow=1,
                       padding=0)
            print_logger.info(
                'Translated test images and saved into "{}"..!'.format(
                    save_path))
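
For context, a minimal driver for this Solver might look like the sketch
below; get_loaders and the argparse flags are assumptions for illustration
(Solver.__init__ reads many more config attributes than shown here):

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str, default='train',
                        choices=['train', 'test'])
    # ... one flag per attribute read in Solver.__init__ (g_lr, d_lr,
    # batch_size, output_path, output_name, ...)
    config = parser.parse_args()

    # hypothetical helper returning {'train': ..., 'val': ..., 'test': ...}
    data_loaders = get_loaders(config)

    solver = Solver(data_loaders, config)
    if config.mode == 'train':
        solver.train()
    else:
        solver.test()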