예제 #1
0
                def gradient_penalty(fake_x, origin_x, D, d_optimizer):
                    """WGAN-GP penalty step: push the critic's gradient norm on
                    random real/fake interpolates toward 1, then take one
                    optimizer step on the weighted penalty alone.

                    Returns the (unweighted) gradient-penalty scalar."""
                    # Per-sample random mixing coefficient, broadcast over CHW.
                    mix = torch.rand(origin_x.size(0), 1, 1, 1).cuda()
                    mix = mix.expand_as(origin_x)
                    x_hat = Variable(mix * origin_x.data +
                                     (1 - mix) * fake_x.data,
                                     requires_grad=True)
                    score = D(x_hat)

                    # Differentiable gradient of D's score w.r.t. the interpolates.
                    g = torch.autograd.grad(outputs=score,
                                            inputs=x_hat,
                                            grad_outputs=torch.ones(
                                                score.size()).cuda(),
                                            retain_graph=True,
                                            create_graph=True,
                                            only_inputs=True)[0]

                    # Flatten per sample; penalise deviation of the L2 norm from 1.
                    flat = g.view(g.size(0), -1)
                    norms = torch.sqrt(torch.sum(flat ** 2, dim=1))
                    penalty = torch.mean((norms - 1) ** 2)

                    # Backward + Optimize on the weighted penalty only.
                    weighted = self.lambda_gp * penalty
                    self.reset_grad()
                    weighted.backward()
                    d_optimizer.step()

                    return penalty
예제 #2
0
    def fitness_score(self, eval_fake_imgs, eval_real_imgs):
        """Evaluate a generator candidate against the current discriminator.

        Returns:
            Fq: quality fitness — mean sigmoid of D's scores on the fake
                batch (numpy scalar).
            Fd: diversity fitness — negative log of the norm of D's parameter
                gradients under the WGAN critic loss (numpy scalar).
        """
        self.set_requires_grad(self.D, True)

        eval_fake = self.D(eval_fake_imgs)
        eval_real = self.D(eval_real_imgs)

        # WGAN critic loss: fakes scored low, reals scored high.
        fake_loss = torch.mean(eval_fake)
        real_loss = -torch.mean(eval_real)

        D_loss_score = fake_loss + real_loss

        # Quality fitness score. torch.sigmoid replaces the deprecated
        # nn.functional.sigmoid.
        Fq = torch.sigmoid(eval_fake).data.mean().cpu().numpy()

        # Diversity fitness score: magnitude of D's parameter gradients.

        gradients = torch.autograd.grad(
            outputs=D_loss_score,
            inputs=self.D.parameters(),
            grad_outputs=torch.ones(D_loss_score.size()).to(self.device),
            create_graph=True,
            retain_graph=True,
            only_inputs=True)

        # Flatten and concatenate every parameter gradient in a single pass
        # (the incremental torch.cat of the original was quadratic in the
        # number of parameter tensors).
        with torch.no_grad():
            allgrad = torch.cat([g.view(-1) for g in gradients])

        Fd = -torch.log(torch.norm(allgrad)).data.cpu().numpy()

        return Fq, Fd
예제 #3
0
 def make_step(grad, attack, step_size):
     """Return the perturbation step for one adversarial-attack iteration.

     'l2' normalises the gradient per sample before scaling, 'inf' takes the
     elementwise sign (FGSM-style); any other value scales the raw gradient.
     """
     if attack == 'inf':
         return step_size * torch.sign(grad)
     if attack == 'l2':
         flat_norm = torch.norm(grad.view(grad.shape[0], -1), dim=1)
         flat_norm = flat_norm.view(-1, 1, 1, 1)
         # Small epsilon keeps the division finite for all-zero gradients.
         return step_size * (grad / (flat_norm + 1e-10))
     return step_size * grad
예제 #4
0
def compute_grad_penalty(net_D, true_data, fake_data):
    """WGAN-GP gradient penalty: E[(||grad_x D(x_hat)|| - 1)^2] over random
    interpolates x_hat between real and fake samples.

    Args:
        net_D: critic callable mapping a batch to per-sample (or scalar) scores.
        true_data: batch of real samples, shape (B, C, H, W).
        fake_data: batch of fake samples, same shape as true_data.

    Returns:
        Scalar tensor: mean squared deviation of the gradient norm from 1.
    """
    batch_size = true_data.shape[0]
    # One mixing coefficient per sample, broadcast over the remaining dims.
    epsilon = true_data.new(batch_size, 1, 1, 1)
    epsilon = epsilon.uniform_()
    # BUG FIX: the original weighted BOTH terms by (1 - epsilon); a convex
    # interpolation must weight real by epsilon and fake by (1 - epsilon).
    line_data = true_data * epsilon + fake_data * (1 - epsilon)
    # A leaf tensor with requires_grad is the right autograd probe here
    # (nn.Parameter is meant for module parameters, not interpolates).
    line_data = line_data.detach().requires_grad_(True)
    line_pred = net_D(line_data).sum()
    grad, = torch.autograd.grad(line_pred, line_data, create_graph=True)
    grad = grad.view(batch_size, -1)
    grad_norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
    return ((grad_norm - 1) ** 2).mean()
예제 #5
0
        def Fvp(v):
            """Fisher-vector product: H*v + damping*v, where H is the Hessian
            of the mean KL w.r.t. the actor parameters, computed via double
            backward instead of forming H explicitly."""
            mean_kl = self.get_kl(states).mean()

            # First backward pass: gradient of the KL, kept differentiable.
            kl_grads = torch.autograd.grad(mean_kl,
                                           self.actor.parameters(),
                                           create_graph=True)
            flat_grad_kl = torch.cat([g.view(-1) for g in kl_grads])

            # Second backward pass on (grad KL)·v yields the Hessian-vector
            # product without materialising the Hessian.
            kl_v = (flat_grad_kl * Variable(v)).sum()
            hvp_grads = torch.autograd.grad(kl_v, self.actor.parameters())
            flat_hvp = torch.cat(
                [g.contiguous().view(-1) for g in hvp_grads]).data

            return flat_hvp + v * self.damping
예제 #6
0
    def update_actor(self, states, actions, advantages):
        """One TRPO policy update: compute the natural-gradient step
        direction with conjugate gradients on the Fisher matrix, scale it to
        the KL trust region, then accept/reject it with a line search."""
        action_means, action_log_stds, action_stds = self.actor(
            Variable(states))
        # Log-probs of the taken actions under the pre-update policy,
        # detached so the line search compares against a fixed baseline.
        fixed_log_prob = normal_log_density(Variable(actions), action_means,
                                            action_log_stds,
                                            action_stds).data.clone()

        loss = self.get_loss(states, actions, advantages, fixed_log_prob)
        policy_grads = torch.autograd.grad(loss, self.actor.parameters())
        loss_grad = torch.cat([g.view(-1) for g in policy_grads]).data

        def Fvp(v):
            # Fisher-vector product via double backward on the mean KL.
            kl = self.get_kl(states).mean()

            kl_grads = torch.autograd.grad(kl,
                                           self.actor.parameters(),
                                           create_graph=True)
            flat_grad_kl = torch.cat([g.view(-1) for g in kl_grads])

            kl_v = (flat_grad_kl * Variable(v)).sum()
            hvp_grads = torch.autograd.grad(kl_v, self.actor.parameters())
            flat_hvp = torch.cat(
                [g.contiguous().view(-1) for g in hvp_grads]).data

            return flat_hvp + v * self.damping

        # Approximately solve F·x = -g for the natural gradient direction.
        step_dir = conjugate_gradients(Fvp, -loss_grad, self.nsteps)
        # Scale the step so the quadratic KL estimate equals max_kl.
        shs = 0.5 * (step_dir * Fvp(step_dir)).sum(0, keepdim=True)
        lm = torch.sqrt(shs / self.max_kl)
        fullstep = step_dir / lm[0]

        neggdotstepdir = (-loss_grad * step_dir).sum(0, keepdim=True)
        print("lagrange multiplier:", lm[0], "grad_norm:", loss_grad.norm())

        success, new_params = self.linesearch(states, actions, advantages,
                                              fixed_log_prob, fullstep,
                                              neggdotstepdir / lm[0])
        set_flat_params_to(self.actor, new_params)
예제 #7
0
    def train(self):
        """Train attribute-guided face image synthesis model.

        Alternates WGAN-GP discriminator updates (adversarial + attribute
        classification + gradient penalty) with encoder/decoder updates every
        `d_train_repeat` iterations; logs, samples and checkpoints on the
        configured step intervals, and decays learning rates near the end.
        """
        self.data_loader = self.face_data_loader
        # The number of iterations for each epoch
        iters_per_epoch = len(self.data_loader)

        sample_x = []
        sample_l = []
        real_y = []
        for i, (images, landmark) in enumerate(self.data_loader):
            labels = images[1]
            sample_x.append(images[0])
            sample_l.append(landmark[0])
            real_y.append(labels)
            if i == 2:
                break

        # Sample inputs and desired domain labels for testing
        sample_x = torch.cat(sample_x, dim=0)
        sample_x = self.to_var(sample_x, volatile=True)
        sample_l = torch.cat(sample_l, dim=0)
        sample_l = self.to_var(sample_l, volatile=True)
        real_y = torch.cat(real_y, dim=0)

        # One fixed one-hot target label per domain, for sampling.
        sample_y_list = []
        for i in range(self.y_dim):
            sample_y = self.one_hot(
                torch.ones(sample_x.size(0)) * i, self.y_dim)
            sample_y_list.append(self.to_var(sample_y, volatile=True))

        # Learning rate for decaying
        d_lr = self.d_lr
        enc_lr = self.enc_lr
        dec_lr = self.dec_lr

        # Start with trained model
        if self.trained_model:
            start = int(self.trained_model.split('_')[0])
        else:
            start = 0

        # Start training
        start_time = time.time()
        for e in range(start, self.num_epochs):
            for i, (real_image, real_landmark) in enumerate(self.data_loader):
                #real_x: real image and real_l: conditional side image (landmark heatmap)
                real_x = real_image[0]
                real_label = real_image[1]
                real_l = real_landmark[0]

                # Sample fake labels randomly
                rand_idx = torch.randperm(real_label.size(0))
                fake_label = real_label[rand_idx]

                real_y = self.one_hot(real_label, self.y_dim)
                fake_y = self.one_hot(fake_label, self.y_dim)

                # Convert tensor to variable
                real_x = self.to_var(real_x)
                real_l = self.to_var(real_l)
                real_y = self.to_var(real_y)
                fake_y = self.to_var(fake_y)
                real_label = self.to_var(real_label)
                fake_label = self.to_var(fake_label)

                #================== Train Discriminator ================== #
                # Input images (original image+side images) are concatenated
                src_output, cls_output = self.D(torch.cat([real_x, real_l], 1))
                d_loss_real = -torch.mean(src_output)
                d_loss_cls = F.cross_entropy(cls_output, real_label)

                # Compute expression recognition accuracy on synthetic images
                if (i + 1) % self.log_step == 0:
                    accuracies = self.calculate_accuracy(
                        cls_output, real_label)
                    log = [
                        "{:.2f}".format(acc)
                        for acc in accuracies.data.cpu().numpy()
                    ]
                    print('Recognition Acc: ')
                    print(log)

                # Generate outputs and compute loss with fake generated images
                enc_feat = self.Enc(torch.cat([real_x, real_l], 1))
                fake_x, fake_l = self.Dec(enc_feat, fake_y)
                # Detach the generator outputs so D's update does not
                # backprop into Enc/Dec.
                fake_x = Variable(fake_x.data)
                fake_l = Variable(fake_l.data)

                src_output, cls_output = self.D(torch.cat([fake_x, fake_l], 1))
                d_loss_fake = torch.mean(src_output)

                # Discriminator losses
                d_loss = self.lambda_cls * d_loss_cls + d_loss_real + d_loss_fake
                self.reset()
                d_loss.backward()
                self.d_optimizer.step()

                # Compute gradient penalty loss
                real = torch.cat([real_x, real_l], 1)
                fake = torch.cat([fake_x, fake_l], 1)
                alpha = torch.rand(real_x.size(0), 1, 1,
                                   1).cuda().expand_as(real)
                interpolated = Variable(alpha * real.data +
                                        (1 - alpha) * fake.data,
                                        requires_grad=True)
                output, cls_output = self.D(interpolated)

                grad = torch.autograd.grad(outputs=output,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(
                                               output.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                # Gradient penalty loss (separate backward + step)
                d_loss = self.lambda_gp * d_loss_gp
                self.reset()
                d_loss.backward()
                self.d_optimizer.step()

                # Logging. Use .item() for 0-dim losses (the old .data[0]
                # indexing is removed in modern PyTorch); this also matches
                # the anomaly-detection train() below.
                loss = {}
                loss['D/loss_real'] = d_loss_real.item()
                loss['D/loss_fake'] = d_loss_fake.item()
                loss['D/loss_cls'] = d_loss_cls.item()
                loss['D/loss_gp'] = d_loss_gp.item()

                # ================== Train Encoder-Decoder networks ================== #
                if (i + 1) % self.d_train_repeat == 0:

                    # Original-to-target and target-to-original domain
                    enc_feat = self.Enc(torch.cat([real_x, real_l], 1))
                    fake_x, fake_l = self.Dec(enc_feat, fake_y)
                    src_output, cls_output = self.D(
                        torch.cat([fake_x, fake_l], 1))
                    g_loss_fake = -torch.mean(src_output)

                    #rec_feat = self.Enc(fake_x)
                    rec_feat = self.Enc(torch.cat([fake_x, fake_l], 1))
                    rec_x, rec_l = self.Dec(rec_feat, real_y)

                    # bidirectional loss of the images
                    g_loss_rec_x = torch.mean(torch.abs(real_x - rec_x))
                    g_loss_rec_l = torch.mean(torch.abs(real_l - rec_l))

                    #bidirectional loss of the latent feature
                    g_loss_feature = torch.mean(torch.abs(enc_feat - rec_feat))

                    #identity loss of the images
                    g_loss_identity_x = torch.mean(torch.abs(real_x - fake_x))
                    g_loss_identity_l = torch.mean(torch.abs(real_l - fake_l))

                    # attribute classification loss for the fake generated images
                    g_loss_cls = F.cross_entropy(cls_output, fake_label)

                    # Backward + Optimize (generator (encoder-decoder) losses), we update decoder two times for each encoder update
                    g_loss = g_loss_fake + self.lambda_bi * g_loss_rec_x + self.lambda_bi * g_loss_rec_l + self.lambda_bi * g_loss_feature + self.lambda_id * g_loss_identity_x + self.lambda_id * g_loss_identity_l + self.lambda_cls * g_loss_cls
                    self.reset()
                    g_loss.backward()
                    self.enc_optimizer.step()
                    self.dec_optimizer.step()
                    self.dec_optimizer.step()

                    # Logging Generator losses
                    loss['G/loss_feature'] = g_loss_feature.item()
                    loss['G/loss_identity_x'] = g_loss_identity_x.item()
                    loss['G/loss_identity_l'] = g_loss_identity_l.item()
                    loss['G/loss_rec_x'] = g_loss_rec_x.item()
                    loss['G/loss_rec_l'] = g_loss_rec_l.item()
                    loss['G/loss_fake'] = g_loss_fake.item()
                    loss['G/loss_cls'] = g_loss_cls.item()

                # Print out log
                if (i + 1) % self.log_step == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))

                    log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                        elapsed, e + 1, self.num_epochs, i + 1,
                        iters_per_epoch)

                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)

                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value,
                                                   e * iters_per_epoch + i + 1)

                # Synthesize images
                if (i + 1) % self.sample_step == 0:
                    fake_image_list = [sample_x]
                    for sample_y in sample_y_list:
                        enc_feat = self.Enc(torch.cat([sample_x, sample_l], 1))
                        sample_result, sample_landmark = self.Dec(
                            enc_feat, sample_y)
                        fake_image_list.append(sample_result)
                    fake_images = torch.cat(fake_image_list, dim=3)
                    save_image(self.denorm(fake_images.data),
                               os.path.join(
                                   self.sample_path,
                                   '{}_{}_fake.png'.format(e + 1, i + 1)),
                               nrow=1,
                               padding=0)
                    print('Generated images and saved into {}..!'.format(
                        self.sample_path))

                # Save checkpoints
                if (i + 1) % self.model_save_step == 0:
                    torch.save(
                        self.Enc.state_dict(),
                        os.path.join(self.model_path,
                                     '{}_{}_Enc.pth'.format(e + 1, i + 1)))
                    torch.save(
                        self.Dec.state_dict(),
                        os.path.join(self.model_path,
                                     '{}_{}_Dec.pth'.format(e + 1, i + 1)))
                    torch.save(
                        self.D.state_dict(),
                        os.path.join(self.model_path,
                                     '{}_{}_D.pth'.format(e + 1, i + 1)))

            # Decay learning rate
            if (e + 1) > (self.num_epochs - self.num_epochs_decay):
                d_lr -= (self.d_lr / float(self.num_epochs_decay))
                enc_lr -= (self.enc_lr / float(self.num_epochs_decay))
                dec_lr -= (self.dec_lr / float(self.num_epochs_decay))
                self.update_lr(enc_lr, dec_lr, d_lr)
                print('Decay learning rate to enc_lr: {}, d_lr: {}.'.format(
                    enc_lr, d_lr))
    def train(self):
        """Train anomaly detection model.

        WGAN-GP training loop: alternates discriminator updates (adversarial
        + gradient penalty) with generator updates every `d_train_repeat`
        iterations; the generator loss combines adversarial, feature,
        reconstruction and SSIM terms. Logs, samples and checkpoints on the
        configured step intervals, and decays learning rates near the end.
        """
        self.data_loader = self.img_data_loader
        # The number of iterations per epoch
        iters_per_epoch = len(self.data_loader.train)

        fixed_x = []
        for i, (images, labels) in enumerate(self.data_loader.train):
            fixed_x.append(images)
            if i == 0:
                break

        # Fixed inputs and target domain labels for debugging
        fixed_x = torch.cat(fixed_x, dim=0)
        fixed_x = self.to_var(fixed_x, volatile=True)

        # Learning rate for decaying
        d_lr = self.d_lr
        g_lr = self.g_lr

        # Start with trained model
        if self.trained_model:
            start = int(self.trained_model.split('_')[0])
        else:
            start = 0

        # Start training
        start_time = time.time()
        for e in range(start, self.num_epochs):
            for i, (real_x, real_label) in enumerate(self.data_loader.train):
                # Convert tensor to variable
                # (an unused torch.randperm of the labels was removed here)
                real_x = self.to_var(real_x)

                #================== Train Discriminator ================== #
                # Compute loss with real images
                out_src = self.D(real_x)
                d_loss_real = -torch.mean(out_src)

                fake_x, _, _ = self.G(real_x)
                # Detach so D's update does not backprop into G.
                fake_x = Variable(fake_x.data)

                out_src = self.D(fake_x)
                d_loss_fake = torch.mean(out_src)

                # Discriminator losses
                d_loss = d_loss_real + d_loss_fake
                self.reset()
                d_loss.backward()
                self.d_optimizer.step()

                # Compute gradient penalty
                alpha = torch.rand(real_x.size(0), 1, 1,
                                   1).cuda().expand_as(real_x)
                interpolated = Variable(alpha * real_x.data +
                                        (1 - alpha) * fake_x.data,
                                        requires_grad=True)
                out = self.D(interpolated)

                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(
                                               out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                # Gradient penalty loss (separate backward + step)
                d_loss = self.lambda_gp * d_loss_gp
                self.reset()
                d_loss.backward()
                self.d_optimizer.step()

                # Logging
                loss = {}
                loss['D/loss_real'] = d_loss_real.item()
                loss['D/loss_fake'] = d_loss_fake.item()
                loss['D/loss_gp'] = d_loss_gp.item()

                # ================== Train Encoder-Decoder networks ================== #
                if (i + 1) % self.d_train_repeat == 0:

                    fake_x, enc_feat, rec_feat = self.G(real_x)
                    out_src = self.D(fake_x)
                    g_loss_fake = -torch.mean(out_src)

                    # L1 reconstruction of the input image
                    g_loss_rec_x = torch.mean(torch.abs(real_x - fake_x))

                    # SSIM loss mapped from [-1, 1] similarity to [0, 1] loss
                    g_loss_ssim = (0.5 *
                                   (1 - self.ssim_loss(real_x, fake_x))).clamp(
                                       0, 1)

                    # L2 distance between encoder and re-encoder features
                    g_loss_feature = torch.mean(
                        torch.pow((enc_feat - rec_feat), 2))

                    # FIX: removed the stray second '+' (unary plus typo)
                    # before the reconstruction term.
                    g_loss = g_loss_fake + self.lambda_f * g_loss_feature + self.lambda_bi * g_loss_rec_x + self.lambda_ssim * g_loss_ssim
                    self.reset()
                    g_loss.backward()
                    self.g_optimizer.step()

                    # Logging Generator losses
                    loss['G/loss_feature'] = g_loss_feature.item()
                    loss['G/loss_image'] = g_loss_rec_x.item()
                    loss['G/loss_ssim'] = g_loss_ssim.item()
                    loss['G/loss_fake'] = g_loss_fake.item()

                # Print out log
                if (i + 1) % self.log_step == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))

                    log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                        elapsed, e + 1, self.num_epochs, i + 1,
                        iters_per_epoch)

                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)

                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value,
                                                   e * iters_per_epoch + i + 1)

                # Reconstructed images
                if (i + 1) % self.sample_step == 0:
                    fake_image_list = [fixed_x]
                    #for fixed_c in fixed_c_list:
                    sample_result, _, _ = self.G(fixed_x)
                    fake_image_list.append(sample_result)
                    fake_images = torch.cat(fake_image_list, dim=3)
                    save_image(self.denorm(fake_images.data),
                               os.path.join(
                                   self.sample_path,
                                   '{}_{}_fake.png'.format(e + 1, i + 1)),
                               nrow=1,
                               padding=0)
                    print('Generated images and saved into {}..!'.format(
                        self.sample_path))

                # Save model checkpoints
                if (i + 1) % self.model_save_step == 0:
                    torch.save(
                        self.G.state_dict(),
                        os.path.join(self.model_save_path,
                                     '{}_{}_G.pth'.format(e + 1, i + 1)))
                    torch.save(
                        self.D.state_dict(),
                        os.path.join(self.model_save_path,
                                     '{}_{}_D.pth'.format(e + 1, i + 1)))

            # Decay learning rate
            if (e + 1) > (self.num_epochs - self.num_epochs_decay):
                d_lr -= (self.d_lr / float(self.num_epochs_decay))
                g_lr -= (self.g_lr / float(self.num_epochs_decay))
                self.update_lr(g_lr, d_lr)
                print('Decay learning rate to g_lr: {}, d_lr: {}.'.format(
                    g_lr, d_lr))
예제 #9
0
File: solver.py  Project: dodler/StarGAN
    def train_multi(self):
        """Train StarGAN with multiple datasets.
        In the code below, 1 is related to CelebA and 2 is releated to RaFD.
        """
        # Fixed imagse and labels for debugging
        fixed_x = []
        real_c = []

        for i, (images, labels) in enumerate(self.celebA_loader):
            fixed_x.append(images)
            real_c.append(labels)
            if i == 2:
                break

        fixed_x = torch.cat(fixed_x, dim=0)
        fixed_x = self.to_var(fixed_x, volatile=True)
        real_c = torch.cat(real_c, dim=0)
        fixed_c1_list = self.make_celeb_labels(real_c)

        fixed_c2_list = []
        for i in range(self.c2_dim):
            fixed_c = self.one_hot(torch.ones(fixed_x.size(0)) * i, self.c2_dim)
            fixed_c2_list.append(self.to_var(fixed_c, volatile=True))

        fixed_zero1 = self.to_var(torch.zeros(fixed_x.size(0), self.c2_dim))     # zero vector when training with CelebA
        fixed_mask1 = self.to_var(self.one_hot(torch.zeros(fixed_x.size(0)), 2)) # mask vector: [1, 0]
        fixed_zero2 = self.to_var(torch.zeros(fixed_x.size(0), self.c_dim))      # zero vector when training with RaFD
        fixed_mask2 = self.to_var(self.one_hot(torch.ones(fixed_x.size(0)), 2))  # mask vector: [0, 1]

        # lr cache for decaying
        g_lr = self.g_lr
        d_lr = self.d_lr

        # data iterator
        data_iter1 = iter(self.celebA_loader)
        data_iter2 = iter(self.rafd_loader)

        # Start with trained model
        if self.pretrained_model:
            start = int(self.pretrained_model) + 1
        else:
            start = 0

        # # Start training
        start_time = time.time()
        for i in range(start, self.num_iters):

            # Fetch mini-batch images and labels
            try:
                real_x1, real_label1 = next(data_iter1)
            except:
                data_iter1 = iter(self.celebA_loader)
                real_x1, real_label1 = next(data_iter1)

            try:
                real_x2, real_label2 = next(data_iter2)
            except:
                data_iter2 = iter(self.rafd_loader)
                real_x2, real_label2 = next(data_iter2)

            # Generate fake labels randomly (target domain labels)
            rand_idx = torch.randperm(real_label1.size(0))
            fake_label1 = real_label1[rand_idx]
            rand_idx = torch.randperm(real_label2.size(0))
            fake_label2 = real_label2[rand_idx]

            real_c1 = real_label1.clone()
            fake_c1 = fake_label1.clone()
            zero1 = torch.zeros(real_x1.size(0), self.c2_dim)
            mask1 = self.one_hot(torch.zeros(real_x1.size(0)), 2)

            real_c2 = self.one_hot(real_label2, self.c2_dim)
            fake_c2 = self.one_hot(fake_label2, self.c2_dim)
            zero2 = torch.zeros(real_x2.size(0), self.c_dim)
            mask2 = self.one_hot(torch.ones(real_x2.size(0)), 2)

            # Convert tensor to variable
            real_x1 = self.to_var(real_x1)
            real_c1 = self.to_var(real_c1)
            fake_c1 = self.to_var(fake_c1)
            mask1 = self.to_var(mask1)
            zero1 = self.to_var(zero1)

            real_x2 = self.to_var(real_x2)
            real_c2 = self.to_var(real_c2)
            fake_c2 = self.to_var(fake_c2)
            mask2 = self.to_var(mask2)
            zero2 = self.to_var(zero2)

            real_label1 = self.to_var(real_label1)
            fake_label1 = self.to_var(fake_label1)
            real_label2 = self.to_var(real_label2)
            fake_label2 = self.to_var(fake_label2)

            # ================== Train D ================== #

            # Real images (CelebA)
            out_real, out_cls = self.D(real_x1)
            out_cls1 = out_cls[:, :self.c_dim]      # celebA part
            d_loss_real = - torch.mean(out_real)
            d_loss_cls = F.binary_cross_entropy_with_logits(out_cls1, real_label1, size_average=False) / real_x1.size(0)

            # Real images (RaFD)
            out_real, out_cls = self.D(real_x2)
            out_cls2 = out_cls[:, self.c_dim:]      # rafd part
            d_loss_real += - torch.mean(out_real)
            d_loss_cls += F.cross_entropy(out_cls2, real_label2)

            # Compute classification accuracy of the discriminator
            if (i+1) % self.log_step == 0:
                accuracies = self.compute_accuracy(out_cls1, real_label1, 'CelebA')
                log = ["{:.2f}".format(acc) for acc in accuracies.data.cpu().numpy()]
                print('Classification Acc (Black/Blond/Brown/Gender/Aged): ', '')
                print(log)
                accuracies = self.compute_accuracy(out_cls2, real_label2, 'RaFD')
                log = ["{:.2f}".format(acc) for acc in accuracies.data.cpu().numpy()]
                print('Classification Acc (8 emotional expressions): ', '')
                print(log)

            # Fake images (CelebA)
            fake_c = torch.cat([fake_c1, zero1, mask1], dim=1)
            fake_x1 = self.G(real_x1, fake_c)
            fake_x1 = Variable(fake_x1.data)
            out_fake, _ = self.D(fake_x1)
            d_loss_fake = torch.mean(out_fake)

            # Fake images (RaFD)
            fake_c = torch.cat([zero2, fake_c2, mask2], dim=1)
            fake_x2 = self.G(real_x2, fake_c)
            out_fake, _ = self.D(fake_x2)
            d_loss_fake += torch.mean(out_fake)

            # Backward + Optimize
            d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Compute gradient penalty
            if (i+1) % 2 == 0:
                real_x = real_x1
                fake_x = fake_x1
            else:
                real_x = real_x2
                fake_x = fake_x2

            alpha = torch.rand(real_x.size(0), 1, 1, 1).cuda().expand_as(real_x)
            interpolated = Variable(alpha * real_x.data + (1 - alpha) * fake_x.data, requires_grad=True)
            out, out_cls = self.D(interpolated)

            if (i+1) % 2 == 0:
                out_cls = out_cls[:, :self.c_dim]  # CelebA
            else:
                out_cls = out_cls[:, self.c_dim:]  # RaFD

            grad = torch.autograd.grad(outputs=out,
                                       inputs=interpolated,
                                       grad_outputs=torch.ones(out.size()).cuda(),
                                       retain_graph=True,
                                       create_graph=True,
                                       only_inputs=True)[0]

            grad = grad.view(grad.size(0), -1)
            grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
            d_loss_gp = torch.mean((grad_l2norm - 1)**2)

            # Backward + Optimize
            d_loss = self.lambda_gp * d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Logging
            loss = {}
            loss['D/loss_real'] = d_loss_real.data[0]
            loss['D/loss_fake'] = d_loss_fake.data[0]
            loss['D/loss_cls'] = d_loss_cls.data[0]
            loss['D/loss_gp'] = d_loss_gp.data[0]

            # ================== Train G ================== #
            if (i+1) % self.d_train_repeat == 0:
                # Original-to-target and target-to-original domain (CelebA)
                fake_c = torch.cat([fake_c1, zero1, mask1], dim=1)
                real_c = torch.cat([real_c1, zero1, mask1], dim=1)
                fake_x1 = self.G(real_x1, fake_c)
                rec_x1 = self.G(fake_x1, real_c)

                # Compute losses
                out, out_cls = self.D(fake_x1)
                out_cls1 = out_cls[:, :self.c_dim]
                g_loss_fake = - torch.mean(out)
                g_loss_rec = torch.mean(torch.abs(real_x1 - rec_x1))
                g_loss_cls = F.binary_cross_entropy_with_logits(out_cls1, fake_label1, size_average=False) / fake_x1.size(0)

                # Original-to-target and target-to-original domain (RaFD)
                fake_c = torch.cat([zero2, fake_c2, mask2], dim=1)
                real_c = torch.cat([zero2, real_c2, mask2], dim=1)
                fake_x2 = self.G(real_x2, fake_c)
                rec_x2 = self.G(fake_x2, real_c)

                # Compute losses
                out, out_cls = self.D(fake_x2)
                out_cls2 = out_cls[:, self.c_dim:]
                g_loss_fake += - torch.mean(out)
                g_loss_rec += torch.mean(torch.abs(real_x2 - rec_x2))
                g_loss_cls += F.cross_entropy(out_cls2, fake_label2)

                # Backward + Optimize
                g_loss = g_loss_fake + self.lambda_cls * g_loss_cls + self.lambda_rec * g_loss_rec
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Logging
                loss['G/loss_fake'] = g_loss_fake.data[0]
                loss['G/loss_cls'] = g_loss_cls.data[0]
                loss['G/loss_rec'] = g_loss_rec.data[0]

            # Print out log info
            if (i+1) % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))

                log = "Elapsed [{}], Iter [{}/{}]".format(
                    elapsed, i+1, self.num_iters)

                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, i+1)

            # Translate the images (debugging)
            if (i+1) % self.sample_step == 0:
                fake_image_list = [fixed_x]

                # Changing hair color, gender, and age
                for j in range(self.c_dim):
                    fake_c = torch.cat([fixed_c1_list[j], fixed_zero1, fixed_mask1], dim=1)
                    fake_image_list.append(self.G(fixed_x, fake_c))
                # Changing emotional expressions
                for j in range(self.c2_dim):
                    fake_c = torch.cat([fixed_zero2, fixed_c2_list[j], fixed_mask2], dim=1)
                    fake_image_list.append(self.G(fixed_x, fake_c))
                fake = torch.cat(fake_image_list, dim=3)

                # Save the translated images
                save_image(self.denorm(fake.data.cpu()),
                    os.path.join(self.sample_path, '{}_fake.png'.format(i+1)), nrow=1, padding=0)

            # Save model checkpoints
            if (i+1) % self.model_save_step == 0:
                torch.save(self.G.state_dict(),
                    os.path.join(self.model_save_path, '{}_G.pth'.format(i+1)))
                torch.save(self.D.state_dict(),
                    os.path.join(self.model_save_path, '{}_D.pth'.format(i+1)))

            # Decay learning rate
            decay_step = 1000
            if (i+1) > (self.num_iters - self.num_iters_decay) and (i+1) % decay_step==0:
                g_lr -= (self.g_lr / float(self.num_iters_decay) * decay_step)
                d_lr -= (self.d_lr / float(self.num_iters_decay) * decay_step)
                self.update_lr(g_lr, d_lr)
                print ('Decay learning rate to g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
예제 #10
0
    def train(self):
        """Train StarGAN on a single dataset using a WGAN-GP objective.

        Each iteration trains the discriminator D with adversarial
        (Wasserstein) real/fake losses plus a multi-label domain
        classification loss, then runs a *separate* backward pass and
        optimizer step for the gradient penalty.  The generator G is
        updated once every `self.d_train_repeat` D-steps with adversarial,
        cycle-reconstruction, and classification losses.  Logging, sample
        translation, checkpointing, and learning-rate decay happen on
        their configured step/epoch intervals.
        """

        # The number of iterations per epoch
        iters_per_epoch = len(self.data_loader)

        # Collect the first four batches as fixed debug inputs that are
        # re-translated at every sample step.
        fixed_x = []
        real_c = []
        for i, (images, labels) in enumerate(self.data_loader):
            fixed_x.append(images)
            real_c.append(labels)
            if i == 3:
                break

        # Fixed inputs and target domain labels for debugging
        fixed_x = torch.cat(fixed_x, dim=0)
        # volatile=True: inference-only graph (pre-0.4 PyTorch idiom).
        fixed_x = self.to_var(fixed_x, volatile=True)
        real_c = torch.cat(real_c, dim=0)

        fixed_c_list = self.make_data_labels(real_c)

        # lr cache for decaying
        g_lr = self.g_lr
        d_lr = self.d_lr

        # Resume from a checkpoint whose name encodes the starting epoch
        # (e.g. '20_1000' -> start at epoch 20).
        if self.pretrained_model:
            start = int(self.pretrained_model.split('_')[0])
        else:
            start = 0

        # Start training
        start_time = time.time()
        for e in range(start, self.num_epochs):
            for i, (real_x, real_label) in enumerate(self.data_loader):

                # Generate fake (target-domain) labels randomly by
                # permuting the real labels within the batch.
                rand_idx = torch.randperm(real_label.size(0))
                fake_label = real_label[rand_idx]

                real_c = real_label.clone()
                fake_c = fake_label.clone()

                # Convert tensor to variable
                real_x = self.to_var(real_x)
                real_c = self.to_var(real_c)  # input for the generator
                fake_c = self.to_var(fake_c)
                real_label = self.to_var(
                    real_label
                )  # this is same as real_c if dataset == 'CelebA'
                fake_label = self.to_var(fake_label)

                # ================== Train D ================== #

                # Compute loss with real images.  WGAN critic: maximize
                # D(real), hence the negated mean.
                out_src, out_cls = self.D(real_x)
                d_loss_real = -torch.mean(out_src)

                # Multi-label domain classification loss on real images,
                # normalized by batch size.
                d_loss_cls = F.binary_cross_entropy_with_logits(
                    out_cls, real_label, size_average=False) / real_x.size(0)

                # Compute classification accuracy of the discriminator
                if (i + 1) % self.log_step == 0:
                    accuracies = self.compute_accuracy(out_cls, real_label)
                    log = [
                        "{:.2f}".format(acc)
                        for acc in accuracies.data.cpu().numpy()
                    ]
                    print('Classification Acc: ')
                    print(log)

                # Compute loss with fake images.  Re-wrapping fake_x.data
                # in a fresh Variable detaches it from G's graph so this
                # backward pass updates only D.
                fake_x = self.G(real_x, fake_c)
                fake_x = Variable(fake_x.data)
                out_src, out_cls = self.D(fake_x)
                d_loss_fake = torch.mean(out_src)

                # Backward + Optimize
                d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # Compute gradient penalty on random per-sample
                # interpolations between real and fake images (WGAN-GP).
                alpha = torch.rand(real_x.size(0), 1, 1,
                                   1).cuda().expand_as(real_x)
                interpolated = Variable(alpha * real_x.data +
                                        (1 - alpha) * fake_x.data,
                                        requires_grad=True)
                out, out_cls = self.D(interpolated)

                # create_graph=True so the penalty itself is differentiable
                # w.r.t. D's parameters.
                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(
                                               out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                # Penalize deviation of the per-sample gradient L2 norm from 1.
                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                # Backward + Optimize.  NOTE(review): the gradient penalty
                # is applied in a second, separate optimizer step instead of
                # being combined with the main D loss above — confirm this
                # two-step schedule is intentional.
                d_loss = self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # Logging (.data[0] is the pre-0.4 equivalent of .item()).
                loss = {}
                loss['D/loss_real'] = d_loss_real.data[0]
                loss['D/loss_fake'] = d_loss_fake.data[0]
                loss['D/loss_cls'] = d_loss_cls.data[0]
                loss['D/loss_gp'] = d_loss_gp.data[0]

                # ================== Train G ================== #
                if (i + 1) % self.d_train_repeat == 0:

                    # Original-to-target and target-to-original domain
                    # (cycle: real_x -> fake_x -> rec_x).
                    fake_x = self.G(real_x, fake_c)
                    rec_x = self.G(fake_x, real_c)

                    # Compute losses: fool D, reconstruct the input (L1),
                    # and be classified as the target domain.
                    out_src, out_cls = self.D(fake_x)
                    g_loss_fake = -torch.mean(out_src)
                    g_loss_rec = torch.mean(torch.abs(real_x - rec_x))

                    g_loss_cls = F.binary_cross_entropy_with_logits(
                        out_cls, fake_label,
                        size_average=False) / fake_x.size(0)

                    # Backward + Optimize
                    g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
                    self.reset_grad()
                    g_loss.backward()
                    self.g_optimizer.step()

                    # Logging
                    loss['G/loss_fake'] = g_loss_fake.data[0]
                    loss['G/loss_rec'] = g_loss_rec.data[0]
                    loss['G/loss_cls'] = g_loss_cls.data[0]

                # Print out log info
                if (i + 1) % self.log_step == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))

                    log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                        elapsed, e + 1, self.num_epochs, i + 1,
                        iters_per_epoch)

                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)

                    if self.use_tensorboard:
                        for tag, value in loss.items():
                            self.logger.scalar_summary(
                                tag, value, e * iters_per_epoch + i + 1)

                # Translate fixed images for debugging
                if (i + 1) % self.sample_step == 0:
                    fake_image_list = [fixed_x]
                    for fixed_c in fixed_c_list:
                        fake_image_list.append(self.G(fixed_x, fixed_c))
                    # Concatenate along width so each row shows the input
                    # followed by every target-domain translation.
                    fake_images = torch.cat(fake_image_list, dim=3)
                    save_image(self.denorm(fake_images.data.cpu()),
                               os.path.join(
                                   self.sample_path,
                                   '{}_{}_fake.png'.format(e + 1, i + 1)),
                               nrow=1,
                               padding=0)
                    print('Translated images and saved into {}..!'.format(
                        self.sample_path))

                # Save model checkpoints
                if (i + 1) % self.model_save_step == 0:
                    torch.save(
                        self.G.state_dict(),
                        os.path.join(self.model_save_path,
                                     '{}_{}_G.pth'.format(e + 1, i + 1)))
                    torch.save(
                        self.D.state_dict(),
                        os.path.join(self.model_save_path,
                                     '{}_{}_D.pth'.format(e + 1, i + 1)))

            # Linearly decay both learning rates to ~0 over the final
            # num_epochs_decay epochs.
            if (e + 1) > (self.num_epochs - self.num_epochs_decay):
                g_lr -= (self.g_lr / float(self.num_epochs_decay))
                d_lr -= (self.d_lr / float(self.num_epochs_decay))
                self.update_lr(g_lr, d_lr)
                print('Decay learning rate to g_lr: {}, d_lr: {}.'.format(
                    g_lr, d_lr))
예제 #11
0
    def train(self):
        """Train a StarGAN variant with a regression head and optional pair discriminator.

        The main discriminator D(x, label) returns adversarial, classification,
        and regression outputs; D is trained with a WGAN-style adversarial
        loss, cross-entropy classification, a hard-regression loss, and a
        separate gradient-penalty step.  When `self.dataset2_loader` is set,
        a second discriminator D2 is trained on positive/negative image pairs
        with an adversarial loss plus KL regularization (and its own gradient
        penalty), and G receives an extra adversarial term from D2.  G is
        updated every `self.d_train_repeat` iterations.  Periodically logs,
        saves sample grids and checkpoints, evaluates intra-FID, and decays
        the learning rates.
        """

        # Set dataloader
        self.data_loader = self.dataset1_loader

        # The number of iterations per epoch
        iters_per_epoch = len(self.data_loader)

        # Grab a single fixed batch for debug sampling.
        fixed_x = []
        real_c = []
        for i, (images, labels) in enumerate(self.data_loader):
            fixed_x.append(images)
            real_c.append(labels)
            if i == 0:
                break

        # Fixed inputs and target domain labels for debugging
        fixed_x = torch.cat(fixed_x, dim=0)
        fixed_x = self.to_var(fixed_x, requires_grad=False)
        real_c = torch.cat(real_c, dim=0)

        # One fixed one-hot target vector per class, used to translate the
        # fixed batch into every domain at sample time.
        fixed_c_list = []
        for i in range(self.c_dim):
            fixed_c = self.one_hot(torch.ones(fixed_x.size(0)) * i, self.c_dim)
            fixed_c_list.append(self.to_var(fixed_c, requires_grad=False))

        # lr cache for decaying
        g_lr = self.g_lr
        d_lr = self.d_lr

        # Resume from a checkpoint whose name encodes the starting epoch.
        if self.resume:
            start = int(self.resume.split('_')[0])
        else:
            start = 0

        # Start training
        start_time = time.time()
        for e in range(start, self.num_epochs):
            for i, (real_x, real_label) in enumerate(self.data_loader):

                # Generate fake labels randomly (target domain labels) by
                # permuting the real labels within the batch.
                rand_idx = torch.randperm(real_label.size(0))
                fake_label = real_label[rand_idx]

                real_c = self.one_hot(real_label, self.c_dim)
                fake_c = self.one_hot(fake_label, self.c_dim)

                # Convert tensor to variable
                real_x = self.to_var(real_x)
                real_c = self.to_var(real_c)           # input for the generator
                fake_c = self.to_var(fake_c)
                real_label = self.to_var(real_label, requires_grad=False)   # this is same as real_c if dataset == 'CelebA'
                fake_label = self.to_var(fake_label, requires_grad=False)

                # ================== Train D ================== #

                # Compute loss with real images (D is label-conditioned).
                out_real, out_cls, out_reg = self.D(real_x, real_label)
                # Compute loss with fake images; detach() blocks gradients
                # into G so only D is updated here.
                fake_x = self.G(real_x, fake_c).detach()
                out_fake, _, _ = self.D(fake_x, fake_label)

                # d_loss_adv = loss_hinge_dis(out_fake, out_real)
                # Wasserstein-style adversarial loss instead of hinge.
                d_loss_adv = -torch.mean(out_real) + torch.mean(out_fake)
                d_loss_cls = F.cross_entropy(out_cls, real_label)
                # todo:regression
                d_loss_reg = loss_hard_reg(out_reg, real_label, self.c_dim)

                # Compute classification accuracy of the discriminator
                if (i+1) % self.log_step == 0:
                    classification_accuracies = self.compute_accuracy(out_cls, real_label, n_classes=self.c_dim)
                    regression_accuracies = self.compute_accuracy(out_reg.squeeze(), real_label, n_classes=self.c_dim)

                    log = "{:.2f}/{:.2f}".format(classification_accuracies.data.cpu().numpy(), regression_accuracies.data.cpu().numpy())
                    print('Classification/regression Acc: ', end='')
                    print(log)


                # Backward + Optimize
                d_loss = d_loss_adv + self.lambda_cls * d_loss_cls + self.lambda_reg * d_loss_reg
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # Compute gradient penalty on real/fake interpolations
                # (WGAN-GP), conditioned on the real labels.
                alpha = torch.rand(real_x.size(0), 1, 1, 1).cuda().expand_as(real_x)
                interpolated = Variable(alpha * real_x.data + (1 - alpha) * fake_x.data, requires_grad=True)
                out, _, _ = self.D(interpolated, real_label)

                # create_graph=True makes the penalty differentiable w.r.t.
                # D's parameters.
                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                # Penalize deviation of the per-sample gradient L2 norm from 1.
                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1) ** 2)

                # Backward + Optimize.  NOTE(review): the penalty is applied
                # in a second, separate optimizer step rather than combined
                # with the main D loss — confirm this is intentional.
                d_loss = self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()


                # Logging (OrderedDict keeps a stable print order).
                loss = collections.OrderedDict()
                loss['D/loss_adv'] = d_loss_adv.data
                loss['D/loss_reg'] = d_loss_reg.data
                loss['D/loss_gp'] = d_loss_gp.data
                loss['D/loss_cls'] = d_loss_cls.data

                # todo
                if self.dataset2_loader is not None:
                    # NOTE(review): `iter(...).next()` is the Python-2 API;
                    # Python 3 needs `next(iter(self.dataset2_loader))`.
                    # Also, creating a fresh iterator each time always draws
                    # from the start of the loader — verify that is intended.
                    pExample1_img, pExample2_img, pExample1_lbl, pExample2_lbl, nExample1_img, nExample2_img, nExample1_lbl, nExample2_lbl = iter(
                        self.dataset2_loader).next()

                    # Generate fake labels randomly (target domain labels)

                    pExample1_c = self.one_hot(pExample1_lbl, self.c_dim)
                    pExample2_c = self.one_hot(pExample2_lbl, self.c_dim)
                    nExample1_c = self.one_hot(nExample1_lbl, self.c_dim)
                    nExample2_c = self.one_hot(nExample2_lbl, self.c_dim)

                    # Convert tensor to variable
                    pExample1_img = self.to_var(pExample1_img)
                    pExample2_img = self.to_var(pExample2_img)
                    nExample1_img = self.to_var(nExample1_img)
                    nExample2_img = self.to_var(nExample2_img)

                    # NOTE(review): self.c_dim is passed as to_var's second
                    # positional argument here, unlike every other call in
                    # this method which passes requires_grad/volatile —
                    # likely a copy-paste slip; confirm to_var's signature.
                    pExample1_c = self.to_var(pExample1_c, self.c_dim)
                    pExample2_c = self.to_var(pExample2_c, self.c_dim)
                    nExample1_c = self.to_var(nExample1_c, self.c_dim)
                    nExample2_c = self.to_var(nExample2_c, self.c_dim)

                    pExample1_lbl = self.to_var(pExample1_lbl, requires_grad=False)
                    pExample2_lbl = self.to_var(pExample2_lbl, requires_grad=False)
                    nExample1_lbl = self.to_var(nExample1_lbl, requires_grad=False)
                    nExample2_lbl = self.to_var(nExample2_lbl, requires_grad=False)

                    # ================== Train D2 ================== #

                    # Compute loss with real (positive) example pairs; D2
                    # also returns Gaussian posterior parameters for the KL
                    # terms below.
                    out_real, mu1_real, logvar1_real, mu2_real, logvar2_real = self.D2(pExample1_img, pExample2_img, pExample1_lbl,
                                                                   pExample2_lbl)
                    # Compute loss with negative example pairs
                    out_neg, mu1_neg, logvar1_neg, mu2_neg, logvar2_neg = self.D2(nExample1_img, nExample2_img, nExample1_lbl,
                                                                    nExample2_lbl)
                    # no projection
                    # out_neg, mu1_neg, logvar1_neg, mu2_neg, logvar2_neg = self.D2(nExample1_img, nExample2_img)

                    # # Compute loss with real example
                    # out_real = self.D2(pExample1_img, pExample2_img, pExample1_lbl, pExample2_lbl)
                    # # Compute loss with negtive example
                    # out_neg = self.D2(nExample1_img, nExample2_img, nExample1_lbl, nExample2_lbl)

                    # Compute loss with fake example
                    # fExample2_img = self.G(pExample1_img, pExample2_c).detach()
                    # out_fake, _, _, mu2_fake, logvar2_fake = self.D2(pExample1_img, fExample2_img, pExample1_lbl, pExample2_lbl)

                    # Fake pair: translate pExample1 to pExample2's class;
                    # detach so only D2 is updated here.
                    fExample2_img = self.G(pExample1_img, pExample2_c).detach()
                    out_fake, _, _, mu2_fake, logvar2_fake = self.D2(pExample1_img, fExample2_img, pExample1_lbl,
                                                                     pExample2_lbl)

                    # unilateral projection
                    # out_fake, _, _, mu2_fake, logvar2_fake = self.D2(pExample1_img, fExample2_img, None,
                    #                                                  pExample2_lbl)

                    # KL regularization of D2's latent posteriors, averaged
                    # over the five (mu, logvar) pairs via the 0.2 factor.
                    kl_real = loss_kl(mu1_real, logvar1_real) + loss_kl(mu2_real, logvar2_real)
                    kl_neg = loss_kl(mu1_neg, logvar1_neg) + loss_kl(mu2_neg, logvar2_neg)
                    kl_fake = loss_kl(mu2_fake, logvar2_fake)

                    kl = (kl_fake + kl_neg + kl_real) * 0.2
                    # d_loss_adv = loss_hinge_dis(out_fake, out_real)
                    # Fake and negative pairs share the "not a real pair"
                    # role, each weighted 0.5.
                    d2_loss_adv = -torch.mean(out_real) + (torch.mean(out_fake) + torch.mean(out_neg))*0.5

                    d2_loss = d2_loss_adv + kl*self.lambda_kl

                    # Backward + Optimize
                    self.reset_grad()
                    d2_loss.backward()
                    self.d2_optimizer.step()

                    # Compute gradient penalty for D2, interpolating between
                    # the negative and the fake second-image of the pair.
                    alpha = torch.rand(nExample2_img.size(0), 1, 1, 1).cuda().expand_as(nExample2_img)
                    interpolated = Variable(alpha * nExample2_img.data + (1 - alpha) * fExample2_img.data, requires_grad=True)
                    out, _, _, _, _ = self.D2(pExample1_img, interpolated, pExample1_lbl, pExample2_lbl)

                    grad = torch.autograd.grad(outputs=out,
                                               inputs=interpolated,
                                               grad_outputs=torch.ones(out.size()).cuda(),
                                               retain_graph=True,
                                               create_graph=True,
                                               only_inputs=True)[0]

                    grad = grad.view(grad.size(0), -1)
                    grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
                    d2_loss_gp = torch.mean((grad_l2norm - 1) ** 2)

                    # Backward + Optimize (separate step, as for D above)
                    d2_loss = self.lambda_gp * d2_loss_gp
                    self.reset_grad()
                    d2_loss.backward()
                    self.d2_optimizer.step()

                    # Logging
                    loss['D2/loss_adv'] = d2_loss_adv.data
                    loss['D2/d2_loss_gp'] = d2_loss_gp.data
                    loss['D2/loss_kl'] = kl.data

                # ================== Train G ================== #
                if (i+1) % self.d_train_repeat == 0:

                    # Original-to-target and target-to-original domain
                    fake_x = self.G(real_x, fake_c)
                    # todo
                    # NOTE(review): this is an identity-style reconstruction
                    # G(real_x, real_c), not the usual cycle G(fake_x, real_c);
                    # the '# todo' above suggests it is provisional — confirm.
                    rec_x  = self.G(real_x, real_c)

                    # Compute losses
                    out_fake, out_cls, out_reg = self.D(fake_x, fake_label)
                    g_loss_adv = -torch.mean(out_fake)
                    g_loss_rec = torch.mean(torch.abs(real_x - rec_x))
                    g_loss_cls = F.cross_entropy(out_cls, fake_label)
                    g_loss_reg = loss_hard_reg(out_reg, fake_label, self.c_dim)

                    # todo
                    if self.dataset2_loader is not None:
                        # fExample2_img = self.G(pExample1_img, pExample2_c)
                        # out_fake, _, _, _, _ = self.D2(pExample1_img, fExample2_img,
                        #                                pExample1_lbl,
                        #                                pExample2_lbl)
                        # Extra adversarial term from the pair discriminator:
                        # G must fool D2 with a translation to the negative
                        # pair's class.
                        fExample2_img = self.G(pExample1_img, nExample2_c)
                        out_fake, _, _, _, _ = self.D2(pExample1_img, fExample2_img,
                                                       pExample1_lbl,
                                                       nExample2_lbl)
                        # unilateral projection
                        # out_fake, _, _, _, _ = self.D2(pExample1_img, fExample2_img,
                        #                                None,
                        #                                nExample2_lbl)
                        g_loss_adv2 = -torch.mean(out_fake)
                    else:
                        g_loss_adv2 = 0

                    # Backward + Optimize
                    g_loss = g_loss_adv + self.lambda_cls * g_loss_cls + self.lambda_rec * g_loss_rec + self.lambda_reg * g_loss_reg + g_loss_adv2
                    self.reset_grad()
                    g_loss.backward()
                    self.g_optimizer.step()

                    # Logging
                    loss['G/loss_adv'] = g_loss_adv.data
                    loss['G/loss_rec'] = g_loss_rec.data
                    loss['G/loss_reg'] = g_loss_reg.data
                    loss['G/loss_cls'] = g_loss_cls.data
                    loss['G/loss_adv2'] = g_loss_adv2



                # Print out log info
                if (i+1) % self.log_step == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))

                    log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                        elapsed, e+1, self.num_epochs, i+1, iters_per_epoch)

                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)

                    if self.use_tensorboard:
                        for tag, value in loss.items():
                            self.logger.scalar_summary(tag, value, e * iters_per_epoch + i + 1)

                # Translate fixed images for debugging
                if (i+1) % self.sample_step == 0:
                    fake_image_list = [fixed_x]
                    for fixed_c in fixed_c_list:
                        fake_image_list.append(self.G(fixed_x, fixed_c).detach())
                    # Concatenate along width: input followed by each domain.
                    fake_images = torch.cat(fake_image_list, dim=3)
                    save_image(self.denorm(fake_images.data),
                        os.path.join(self.sample_path, '{}_{}_fake.png'.format(e+1, i+1)),nrow=1, padding=0)
                    if self.dataset2_loader is not None:
                        pair_images = torch.cat([pExample1_img, pExample2_img, fExample2_img], dim=3).detach()
                        if self.debug:
                            pair_images = torch.cat([pair_images, nExample1_img, nExample2_img], dim=3).detach()
                        save_image(self.denorm(pair_images.data),
                                   os.path.join(self.sample_path, '{}_{}_pair_images.png'.format(e + 1, i + 1)), nrow=1, padding=0)

                    print('Translated images and saved into {}..!'.format(self.sample_path))

            # Save model checkpoints and run intra-FID evaluation.
            # NOTE(review): `model_save_star` reads like a typo of
            # `model_save_start` — confirm against the attribute defined
            # in __init__.
            if (e+1) % self.model_save_step == 0 and (e+1) > self.model_save_star:
                print('Save model checkpoints')
                torch.save(self.G.state_dict(),
                    os.path.join(self.model_save_path, '{}_G.pth'.format(e+1)))
                torch.save(self.D.state_dict(),
                    os.path.join(self.model_save_path, '{}_D.pth'.format(e+1)))
                if self.dataset2_loader is not None:
                    torch.save(self.D2.state_dict(),
                               os.path.join(self.model_save_path, '{}_D2.pth'.format(e + 1)))

                # Evaluate per-class FID and append the results to test.log.
                intra_fid = calculate_intra_fid(self.eval_path, self.eval_batchsize, True, self.dims, self.eval_model,
                                                self.G, self.eval_loader)
                log = 'TEST Epoch [{}/{}]'.format(e+1, self.num_epochs)
                for tag, value in intra_fid.items():
                    log += ", {}: {:.4f}".format(tag, value)
                test_log_path = os.path.join(self.log_path, 'test.log')
                with open(test_log_path, 'a') as f:
                    f.write(log)
                    f.write('\n')
                print(log)

            # Linearly decay both learning rates over the final
            # num_epochs_decay epochs.
            if (e+1) > (self.num_epochs - self.num_epochs_decay):
                g_lr -= (self.g_lr / float(self.num_epochs_decay))
                d_lr -= (self.d_lr / float(self.num_epochs_decay))
                self.update_lr(g_lr, d_lr)
                print ('Decay learning rate to g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
예제 #12
0
    def train(self):
        """Train StarGAN within a single dataset."""

        # Set dataloader

        self.data_loader = self.celebA_loader

        # The number of iterations per epoch
        iters_per_epoch = len(self.data_loader)

        # Fixed latent vector and label for output samples
        fixed_size = 20
        fixed_z = torch.randn(fixed_size, self.z_dim)
        fixed_z = self.to_var(fixed_z, volatile=True)

        fixed_c_list = self.make_celeb_labels_test()

        fixed_z_repeat = fixed_z.repeat(len(fixed_c_list), 1)
        fixed_c_repeat_list = []
        for fixed_c in fixed_c_list:
            fixed_c_repeat_list.append(
                fixed_c.expand(fixed_size, fixed_c.size(1)))
        fixed_c_list = []
        fixed_c_repeat = torch.cat(fixed_c_repeat_list, dim=0)
        fixed_c_repeat_list = []
        # lr cache for decaying
        g_lr = self.g_lr
        d_lr = self.d_lr

        # Start with trained model if exists
        if self.pretrained_model:
            start = int(self.pretrained_model.split('_')[0]) - 1
        else:
            start = 0

        # Start training
        start_time = time.time()
        for e in range(start, self.num_epochs):
            epoch_iter = 0
            for i, (real_x, real_label) in enumerate(self.data_loader):
                epoch_iter = epoch_iter + 1
                if self.dataset == 'Fashion':
                    real_c_i = real_label_i.clone()
                real_c = real_label.clone()
                # rand_idx = torch.randperm(real_c.size(0))
                # fake_c = real_c[rand_idx]

                z = torch.randn(real_x.size(0), self.z_dim)
                z = self.to_var(z)
                # Convert tensor to variable
                real_x = self.to_var(real_x)
                real_c = self.to_var(real_c)  # input for the generator
                if self.dataset == 'Fashion':
                    real_c_i = self.to_var(real_c_i)
                # fake_c = self.to_var(fake_c, volatile=True)
                # ================== Train D ================== #

                # Compute loss with real images
                out_src, out_cls = self.D(real_x)
                d_loss_real = -torch.mean(out_src)
                # print(real_x.size())
                # print(out_src.size())
                # print(out_cls.size())
                # print(real_c.size())
                if self.dataset == 'CelebA':
                    d_loss_cls = F.binary_cross_entropy_with_logits(
                        out_cls, real_c, size_average=False) / real_x.size(0)
                elif self.dataset == 'Fashion':
                    d_loss_cls = F.cross_entropy(out_cls, real_c_i)

                # # Compute classification accuracy of the discriminator
                # if (i+1) % self.log_step == 0:
                #     accuracies = self.compute_accuracy(out_cls, real_c, self.dataset)
                #     log = ["{:.2f}".format(acc) for acc in accuracies.data.cpu().numpy()]
                #     if self.dataset == 'CelebA':
                #         print('Classification Acc (Black/Blond/Brown/Gender/Aged): ', end='')
                #     else:
                #         print('Classification Acc (8 emotional expressions): ', end='')
                #     print(log)

                # Compute loss with fake images
                fake_x = self.G(z, real_c)
                fake_x = Variable(fake_x.data)
                out_src, out_cls = self.D(fake_x)
                d_loss_fake = torch.mean(out_src)

                # Backward + Optimize
                d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # Compute gradient penalty
                alpha = torch.rand(real_x.size(0), 1, 1,
                                   1).cuda().expand_as(real_x)
                interpolated = Variable(alpha * real_x.data +
                                        (1 - alpha) * fake_x.data,
                                        requires_grad=True)
                out, out_cls = self.D(interpolated)

                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(
                                               out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                # Backward + Optimize
                d_loss = self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # Logging
                loss = {}
                loss['D/loss_real'] = d_loss_real.data[0]
                loss['D/loss_fake'] = d_loss_fake.data[0]
                loss['D/loss_cls'] = d_loss_cls.data[0]
                loss['D/loss_gp'] = d_loss_gp.data[0]

                # ================== Train G ================== #
                if (i + 1) % self.d_train_repeat == 0:

                    # Original-to-target and target-to-original domain
                    fake_x = self.G(z, real_c)
                    # fake_x2 = self.G(z, fake_c)
                    # Compute losses
                    out_src, out_cls = self.D(fake_x)
                    g_loss_fake = -torch.mean(out_src)
                    if self.dataset == 'CelebA':
                        g_loss_cls = F.binary_cross_entropy_with_logits(
                            out_cls, real_c,
                            size_average=False) / fake_x.size(0)
                    elif self.dataset == 'Fashion':
                        g_loss_cls = F.cross_entropy(out_cls, real_c_i)
                    # Backward + Optimize
                    g_loss = g_loss_fake + self.lambda_cls * g_loss_cls
                    self.reset_grad()
                    g_loss.backward()
                    self.g_optimizer.step()

                    # Logging
                    loss['G/loss_fake'] = g_loss_fake.data[0]
                    loss['G/loss_cls'] = g_loss_cls.data[0]

                if (i + 1) % self.visual_step == 0:
                    # save visuals
                    self.real_x = real_x
                    self.fake_x = fake_x
                    # self.fake_x2 = fake_x2

                    # save losses
                    self.d_real = -d_loss_real
                    self.d_fake = d_loss_fake
                    self.d_loss = d_loss
                    self.g_loss = g_loss
                    self.g_loss_fake = g_loss_fake
                    self.g_loss_cls = self.lambda_cls * g_loss_cls
                    self.d_loss_cls = self.lambda_cls * d_loss_cls
                    errors_D = self.get_current_errors('D')
                    errors_G = self.get_current_errors('G')
                    self.visualizer.display_current_results(
                        self.get_current_visuals(), e)
                    self.visualizer.plot_current_errors_D(
                        e,
                        float(epoch_iter) / float(iters_per_epoch), errors_D)
                    self.visualizer.plot_current_errors_G(
                        e,
                        float(epoch_iter) / float(iters_per_epoch), errors_G)
                # Print out log info
                if (i + 1) % self.log_step == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))

                    log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                        elapsed, e + 1, self.num_epochs, i + 1,
                        iters_per_epoch)

                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)

                    if self.use_tensorboard:
                        for tag, value in loss.items():
                            self.logger.scalar_summary(
                                tag, value, e * iters_per_epoch + i + 1)

                # Translate fixed images for debugging
#                 if (i+1) % self.sample_step == 0:
# #                     fake_image_list = []
# #                     for fixed_c in fixed_c_list:
# #                         fixed_c = fixed_c.expand(fixed_z.size(0), fixed_c.size(1))
# #                         fake_image_list.append(self.G(fixed_z, fixed_c))

# #                     fake_images = torch.cat(fake_image_list, dim=3)
# #                     save_image(self.denorm(fake_images.data),
# #                         os.path.join(self.sample_path, '{}_{}_fake.png'.format(e+1, i+1)),nrow=1, padding=0)
# #                     print('Translated images and saved into {}..!'.format(self.sample_path))

#                     fake_images_repeat = self.G(fixed_z_repeat, fixed_c_repeat)
#                     fake_image_list = []
#                     for idx in range(12):
#                         fake_image_list.append(fake_images_repeat[fixed_size*(idx):fixed_size*(idx+1)])
#                     fake_images = torch.cat(fake_image_list, dim=3)
#                     save_image(self.denorm(fake_images.data),
#                         os.path.join(self.sample_path, '{}_{}_fake.png'.format(e+1, i+1)),nrow=1, padding=0)
#                     print('Translated images and saved into {}..!'.format(self.sample_path))

# Save model checkpoints
                if (i + 1) % self.model_save_step == 0:
                    torch.save(
                        self.G.state_dict(),
                        os.path.join(self.model_save_path,
                                     '{}_{}_G.pth'.format(e + 1, i + 1)))
                    torch.save(
                        self.D.state_dict(),
                        os.path.join(self.model_save_path,
                                     '{}_{}_D.pth'.format(e + 1, i + 1)))

            # Decay learning rate
            if (e + 1) > (self.num_epochs - self.num_epochs_decay):
                g_lr -= (self.g_lr / float(self.num_epochs_decay))
                d_lr -= (self.d_lr / float(self.num_epochs_decay))
                self.update_lr(g_lr, d_lr)
                print('Decay learning rate to g_lr: {}, d_lr: {}.'.format(
                    g_lr, d_lr))
예제 #13
0
    def train(self):
        """Train StarGAN within a single dataset.

        WGAN-style training loop over ``self.Msceleb_loader`` driving three
        networks: a generator ``self.G`` (maps an augmented image ``aug_x``
        to a fake image), a critic ``self.D`` (scalar realness score), and
        a separate identity classifier ``self.C``. Each iteration:

        1. updates D (critic loss) and C (cross-entropy on real images),
        2. applies a separate gradient-penalty backward/step to D, and
        3. every ``self.d_train_repeat`` iterations updates G with
           adversarial, classification, L1-reconstruction (weight 5) and
           total-variation (weight 0.001) losses.

        Side effects: prints logs, saves sample images and per-step /
        per-epoch checkpoints for G, D and C under
        ``self.model_save_path`` / ``self.sample_path``.

        NOTE(review): written for pre-0.4 PyTorch (``Variable``,
        ``volatile=True``, ``loss.data[0]``) and assumes CUDA is available
        (bare ``.cuda()`` calls in the gradient-penalty section).
        """
        # Criteria used only by the G update: L1 reconstruction and a
        # total-variation smoothness term (TVLoss defined elsewhere).
        self.criterionL1 = torch.nn.L1Loss()
        # self.criterionL2 = torch.nn.MSELoss()
        self.criterionTV = TVLoss()

        self.data_loader = self.Msceleb_loader
        # The number of iterations per epoch
        iters_per_epoch = len(self.data_loader)

        # Collect the first 4 batches as a fixed debug/visualization set.
        # real_c is gathered alongside but is not used below.
        fixed_x = []
        real_c = []
        for i, (aug_images, aug_labels, _, _) in enumerate(self.data_loader):
            fixed_x.append(aug_images)
            real_c.append(aug_labels)
            if i == 3:
                break

        # Fixed inputs and target domain labels for debugging
        fixed_x = torch.cat(fixed_x, dim=0)
        # volatile=True: inference-only graph in old PyTorch (no autograd).
        fixed_x = self.to_var(fixed_x, volatile=True)

        # lr cache for decaying
        g_lr = self.g_lr
        d_lr = self.d_lr

        # Resume epoch index parsed from a checkpoint name of the form
        # '<epoch>_<iter>_...' if a pretrained model was configured.
        if self.pretrained_model:
            start = int(self.pretrained_model.split('_')[0])
        else:
            start = 0

        # Start training
        start_time = time.time()
        for e in range(start, self.num_epochs):
            for i, (aug_x, aug_label, origin_x, origin_label) in enumerate(self.data_loader):

                # Generat fake labels randomly (target domain labels)
                # aug_c = self.one_hot(aug_label, self.c_dim)
                # origin_c = self.one_hot(origin_label, self.c_dim)

                # Wrap identity labels as Variables for cross_entropy.
                aug_c_V = self.to_var(aug_label)
                origin_c_V = self.to_var(origin_label)

                aug_x = self.to_var(aug_x)
                origin_x = self.to_var(origin_x)

                # # ================== Train D ================== #
                # Compute loss with real images.
                # WGAN critic: maximize D(real) => minimize -mean(D(real)).
                out_src = self.D(origin_x)
                out_cls = self.C(origin_x)
                d_loss_real = - torch.mean(out_src)

                # Identity-classification loss on real images (trains C).
                c_loss_cls = F.cross_entropy(out_cls, origin_c_V)
                # Compute classification accuracy of the discriminator
                if (i+1) % self.log_step == 0:
                    accuracies = self.compute_accuracy(out_cls, origin_c_V)
                    log = ["{:.2f}".format(acc) for acc in accuracies.data.cpu().numpy()]
                    print('Classification Acc (75268 ids): ')
                    print(log)

                # Compute loss with fake images.
                # Re-wrapping fake_x.data detaches it from G's graph so the
                # D update does not backprop into the generator.
                fake_x = self.G(aug_x)
                fake_x = Variable(fake_x.data)
                out_src = self.D(fake_x)
                d_loss_fake = torch.mean(out_src)

                # Backward + Optimize
                # D and C are optimized in the same pass: one reset_grad(),
                # two backward() calls (separate graphs), two optimizer steps.
                d_loss = d_loss_real + d_loss_fake
                c_loss = self.lambda_cls * c_loss_cls


                self.reset_grad()
                d_loss.backward()
                c_loss.backward()
                self.d_optimizer.step()
                self.c_optimizer.step()

                # Compute gradient penalty (WGAN-GP): penalize the critic's
                # gradient norm on random interpolations of real and fake.
                # NOTE(review): this runs as a second, separate D step after
                # the critic step above, rather than being folded into d_loss.
                alpha = torch.rand(origin_x.size(0), 1, 1, 1).cuda().expand_as(origin_x)
                interpolated = Variable(alpha * origin_x.data + (1 - alpha) * fake_x.data, requires_grad=True)
                out = self.D(interpolated)

                # create_graph=True so the penalty itself is differentiable
                # w.r.t. D's parameters on the subsequent backward().
                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                # Per-sample L2 norm of the gradient; penalize deviation from 1.
                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                # Backward + Optimize
                d_loss = self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # Logging (.data[0]: old-PyTorch scalar extraction, .item() today)
                loss = {}
                loss['D/loss_real'] = d_loss_real.data[0]
                loss['D/loss_fake'] = d_loss_fake.data[0]
                loss['D/loss_gp'] = d_loss_gp.data[0]
                loss['C/loss_cls'] = c_loss_cls.data[0]

                # ================== Train G ================== #
                # G is updated once every d_train_repeat D updates.
                if (i+1) % self.d_train_repeat == 0:

                    # Original-to-target and target-to-original domain
                    fake_x = self.G(aug_x)

                    # Compute losses: adversarial (fool the critic) and
                    # identity classification of the generated image.
                    out_src = self.D(fake_x)
                    out_cls = self.C(fake_x)
                    g_loss_fake = - torch.mean(out_src)

                    g_loss_cls = F.cross_entropy(out_cls, aug_c_V)

                    # Backward + Optimize
                    # L1 reconstruction pulls fake_x toward its input aug_x;
                    # TV loss (hard-coded weight 0.001) discourages artifacts.
                    recon_loss = self.criterionL1(fake_x, aug_x)
                    TV_loss = self.criterionTV(fake_x) * 0.001

                    # Hard-coded weight 5 on reconstruction.
                    g_loss = g_loss_fake + self.lambda_cls * g_loss_cls + 5* recon_loss + TV_loss

                    # if self.lambda_face > 0.0:
                    #     self.criterionFace = nn.L1Loss()
                    #
                    #     real_input_x = (torch.sum(real_x, 1, keepdim=True) / 3.0 + 1) / 2.0
                    #     fake_input_x = (torch.sum(fake_x, 1, keepdim=True) / 3.0 + 1) / 2.0
                    #     rec_input_x = (torch.sum(rec_x, 1, keepdim=True) / 3.0 + 1) / 2.0
                    #
                    #     _, real_x_feature_fc, real_x_feature_conv = self.Face_recognition_network.forward(
                    #         real_input_x)
                    #     _, fake_x_feature_fc, fake_x_feature_conv = self.Face_recognition_network.forward(
                    #         fake_input_x)
                    #     _, rec_x1_feature_fc, rec_x1_feature_conv = self.Face_recognition_network.forward(rec_input_x)
                    #     # x1_loss = (self.criterionFace(fake_x1_feature_fc, Variable(real_x1_feature_fc.data,requires_grad=False)) +
                    #     #            self.criterionFace(fake_x1_feature_conv,Variable(real_x1_feature_conv.data,requires_grad=False)))\
                    #     #            * self.lambda_face
                    #     x_loss = (self.criterionFace(fake_x_feature_fc,Variable(real_x_feature_fc.data, requires_grad=False))) \
                    #               * self.lambda_face
                    #
                    #     rec_x_loss = (self.criterionFace(rec_x1_feature_fc, Variable(real_x_feature_fc.data, requires_grad=False)))
                    #
                    #     self.id_loss = x_loss + rec_x_loss
                    #     loss['G/id_loss'] = self.id_loss.data[0]
                    #     g_loss += self.id_loss

                    self.reset_grad()
                    g_loss.backward()
                    self.g_optimizer.step()

                    # Logging
                    loss['G/loss_fake'] = g_loss_fake.data[0]
                    loss['G/loss_cls'] = g_loss_cls.data[0]

                # Print out log info
                if (i+1) % self.log_step == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))

                    log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                        elapsed, e+1, self.num_epochs, i+1, iters_per_epoch)

                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)


                # Translate fixed images for debugging: save [input | G(input)]
                # side-by-side (concatenated along width, dim=3).
                if (i+1) % self.sample_step == 0:
                    fake_image_list = [fixed_x]

                    fake_image_list.append(self.G(fixed_x))

                    fake_images = torch.cat(fake_image_list, dim=3)
                    save_image(self.denorm(fake_images.data),
                        os.path.join(self.sample_path, '{}_{}_fake.png'.format(e+1, i+1)),nrow=1, padding=0)
                    print('Translated images and saved into {}..!'.format(self.sample_path))

                # Save model checkpoints (per-step, named '<epoch>_<iter>_*.pth')
                if (i+1) % self.model_save_step == 0:
                    torch.save(self.G.state_dict(),
                        os.path.join(self.model_save_path, '{}_{}_G.pth'.format(e+1, i+1)))
                    torch.save(self.D.state_dict(),
                        os.path.join(self.model_save_path, '{}_{}_D.pth'.format(e+1, i+1)))
                    torch.save(self.C.state_dict(),
                        os.path.join(self.model_save_path, '{}_{}_C.pth'.format(e+1, i+1)))


            # Decay learning rate linearly over the last num_epochs_decay epochs.
            if (e+1) > (self.num_epochs - self.num_epochs_decay):
                g_lr -= (self.g_lr / float(self.num_epochs_decay))
                d_lr -= (self.d_lr / float(self.num_epochs_decay))
                self.update_lr(g_lr, d_lr)
                print ('Decay learning rate to g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))


            # End-of-epoch checkpoints, named '<epoch>_final_*.pth' (one set
            # written per epoch; intentionally inside the epoch loop).
            torch.save(self.G.state_dict(),
                           os.path.join(self.model_save_path, '{}_final_G.pth'.format(e + 1)))
            torch.save(self.D.state_dict(),
                           os.path.join(self.model_save_path, '{}_final_D.pth'.format(e + 1)))
            torch.save(self.C.state_dict(),
                           os.path.join(self.model_save_path, '{}_final_C.pth'.format(e + 1)))
예제 #14
0
    def train(self):

        # if self.config.visualize:
        visualizer = Visualizer()
        """Train StarGAN within a single dataset."""

        # Set dataloader
        self.data_loader = self.train_loader

        # The number of iterations per epoch
        iters_per_epoch = len(self.data_loader)

        fixed_x = []
        real_c = []
        for i, (imgs, labels, _) in enumerate(self.data_loader):
            fixed_x.append(imgs[0])
            real_c.append(labels)
            if i == 0:
                break

        # Fixed inputs and target domain labels for debugging
        fixed_x = torch.cat(fixed_x, dim=0)
        fixed_x = self.to_var(fixed_x, volatile=True)
        real_c = torch.cat(real_c, dim=0)

        fixed_c_list = self.make_celeb_labels(self.config.batch_size)

        # lr cache for decaying
        g_lr = self.config.g_lr
        d_lr = self.config.d_lr

        # Start with trained model if exists
        if self.config.pretrained_model:
            start = int(self.config.pretrained_model.split('_')[0]) - 1
        else:
            start = 0

        # Start training
        self.loss = {}
        start_time = time.time()

        for e in range(start, self.config.num_epochs):
            self.test(e)
            for i, (images, real_label,
                    identity) in enumerate(self.data_loader):

                real_x = images[0]

                if self.config.use_si:
                    real_ox = self.to_var(images[1])
                    real_oo = self.to_var(images[2])

                if self.config.id_cls_loss == 'cross':
                    identity = identity.squeeze()

                # Generate fake labels randomly (target domain labels)
                rand_idx = torch.randperm(real_label.size(0))

                fake_label = real_label[rand_idx]

                real_c = real_label.clone()
                fake_c = fake_label.clone()

                # Convert tensor to variable
                real_x = self.to_var(real_x)
                real_c = self.to_var(real_c)  # input for the generator
                fake_c = self.to_var(fake_c)
                real_label = self.to_var(
                    real_label
                )  # this is same as real_c if dataset == 'CelebA'
                fake_label = self.to_var(fake_label)
                identity = self.to_var(identity)

                # ================== Train D ================== #

                # Compute loss with real images
                if self.config.loss_id_cls:
                    out_src, out_cls, out_id_real = self.D(real_x)
                else:
                    out_src, out_cls = self.D(real_x)

                d_loss_real = -torch.mean(out_src)

                d_loss_cls = F.binary_cross_entropy_with_logits(
                    out_cls, real_label, size_average=False) / real_x.size(0)

                if self.config.loss_id_cls:
                    d_loss_id_cls = self.id_cls_criterion(
                        out_id_real, identity)
                    self.loss[
                        'D/loss_id_cls'] = self.config.lambda_id_cls * d_loss_id_cls.data[
                            0]
                else:
                    d_loss_id_cls = 0.0

                # Compute classification accuracy of the discriminator
                if (i + 1) % self.config.log_step == 0:
                    accuracies = self.compute_accuracy(out_cls.detach(),
                                                       real_label,
                                                       self.config.dataset)
                    log = [
                        "{:.2f}".format(acc)
                        for acc in accuracies.data.cpu().numpy()
                    ]
                    print('Classification Acc (20 classes): ')
                    print(log)
                    print('\n')

                # Compute loss with fake images
                if self.config.use_gpb:
                    fake_x, _ = self.G(real_x, fake_c)
                else:
                    fake_x = self.G(real_x, fake_c)
                fake_x = Variable(fake_x.data)

                if self.config.loss_id_cls:
                    out_src, out_cls, _ = self.D(fake_x.detach())
                else:
                    out_src, out_cls = self.D(fake_x.detach())

                d_loss_fake = torch.mean(out_src)

                # Backward + Optimize
                d_loss = d_loss_real + d_loss_fake + self.config.lambda_cls * d_loss_cls + d_loss_id_cls * self.config.lambda_id_cls
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # Compute gradient penalty
                alpha = torch.rand(real_x.size(0), 1, 1,
                                   1).cuda().expand_as(real_x)
                interpolated = Variable(alpha * real_x.data +
                                        (1 - alpha) * fake_x.data,
                                        requires_grad=True)

                if self.config.loss_id_cls:
                    out, out_cls, _ = self.D(interpolated)
                else:
                    out, out_cls = self.D(interpolated)

                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(
                                               out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                # Backward + Optimize
                d_loss = self.config.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # Logging
                self.loss['D/loss_real'] = d_loss_real.data[0]
                self.loss['D/loss_fake'] = d_loss_fake.data[0]
                self.loss[
                    'D/loss_cls'] = self.config.lambda_cls * d_loss_cls.data[0]
                self.loss[
                    'D/loss_gp'] = self.config.lambda_gp * d_loss_gp.data[0]

                # ================== Train G ================== #
                if (i + 1) % self.config.d_train_repeat == 0:

                    self.img = {}
                    # Original-to-target and target-to-original domain
                    if self.config.use_gpb:
                        fake_x, id_vector_real_in_x = self.G(real_x, fake_c)
                        rec_x, id_vector_fake_in_x = self.G(
                            fake_x.detach(), real_c)
                    else:
                        fake_x = self.G(real_x, fake_c)
                        rec_x = self.G(fake_x.detach(), real_c)

                    # Compute losses
                    if self.config.loss_id_cls:
                        out_src, out_cls, out_id_fake = self.D(fake_x)
                    else:
                        out_src, out_cls = self.D(fake_x)

                    g_loss_fake = -torch.mean(out_src)
                    g_loss_rec = torch.mean(torch.abs(real_x - rec_x))

                    ### siamese loss
                    if self.config.use_si:
                        if self.config.use_gpb:
                            # feedforward
                            fake_ox, id_vector_ox = self.G(real_ox, fake_c)
                            fake_oo, id_vector_oo = self.G(real_oo, fake_c)

                            id_vector_ox = id_vector_ox.detach()
                            id_vector_oo = id_vector_oo.detach()

                            mdist = 1.0 - torch.mean(
                                torch.abs(id_vector_real_in_x - id_vector_oo))
                            mdist = torch.clamp(mdist, min=0.0)
                            g_loss_si = 0.5 * (torch.pow(
                                torch.mean(
                                    torch.abs(id_vector_real_in_x -
                                              id_vector_ox)), 2) +
                                               torch.pow(mdist, 2))

                            # backward
                            _, id_vector_ox = self.G(fake_ox.detach(), real_c)
                            _, id_vector_oo = self.G(fake_oo.detach(), real_c)

                            id_vector_ox = id_vector_ox.detach()
                            id_vector_oo = id_vector_oo.detach()

                            mdist = 1.0 - torch.mean(
                                torch.abs(id_vector_fake_in_x - id_vector_oo))
                            mdist = torch.clamp(mdist, min=0.0)
                            g_loss_si += 0.5 * (torch.pow(
                                torch.mean(
                                    torch.abs(id_vector_fake_in_x -
                                              id_vector_ox)), 2) +
                                                torch.pow(mdist, 2))

                            self.loss['G/g_loss_si'] = g_loss_si.data[0]
                        else:
                            fake_ox = self.G(real_ox, fake_c).detach()

                            fake_ooc = fake_c.data.cpu().numpy().copy()
                            fake_ooc = np.roll(fake_ooc,
                                               np.random.randint(
                                                   self.config.c_dim),
                                               axis=1)
                            fake_ooc = self.to_var(torch.FloatTensor(fake_ooc))

                            fake_oo = self.G(real_oo, fake_ooc).detach()
                            mdist = 1.0 - torch.mean(
                                torch.abs(fake_x - fake_oo))
                            mdist = torch.clamp(mdist, min=0.0)

                            g_loss_si = 0.5 * (torch.pow(
                                torch.mean(torch.abs(fake_x - fake_ox)), 2) +
                                               torch.pow(mdist, 2))
                            self.loss['G/g_loss_si'] = g_loss_si.data[0]
                    else:
                        g_loss_si = 0.0

                    ### id cls loss
                    if self.config.loss_id_cls:
                        g_loss_id_cls = self.id_cls_criterion(
                            out_id_fake, identity)
                        self.loss[
                            'G/g_loss_id_cls'] = self.config.lambda_id_cls * g_loss_id_cls.data[
                                0]
                    else:
                        g_loss_id_cls = 0.0

                    ### sym loss
                    if self.config.loss_symmetry:
                        g_loss_sym_fake = self.find_sym_img_and_cal_loss(
                            fake_x, fake_c,
                            True)  # cal. over samples w/ specific labels
                        g_loss_sym_rec = self.find_sym_img_and_cal_loss(
                            rec_x, real_c, True)

                        lap_fake_x = self.take_laplacian(fake_x)
                        lap_rec_x = self.take_laplacian(rec_x)
                        g_loss_sym_lap_fake = self.find_sym_img_and_cal_loss(
                            lap_fake_x, None, False)  # cal. over all samples
                        g_loss_sym_lap_rec = self.find_sym_img_and_cal_loss(
                            lap_rec_x, None, False)
                        sym_loss = (g_loss_sym_fake + g_loss_sym_rec +
                                    g_loss_sym_lap_fake + g_loss_sym_lap_rec)
                        self.loss[
                            'G/g_loss_sym'] = self.config.lambda_symmetry * sym_loss.data[
                                0]
                    else:
                        sym_loss = 0

                    ###id loss
                    if self.config.loss_id:
                        if self.config.use_gpb:
                            idx, _ = self.G(real_x, real_c)
                        else:
                            idx = self.G(real_x, real_c)
                        self.img['idx'] = idx

                        g_loss_id = torch.mean(torch.abs(real_x - idx))
                        self.loss[
                            'G/g_loss_id'] = self.config.lambda_idx * g_loss_id.data[
                                0]
                    else:
                        g_loss_id = 0

                    ###identity loss
                    if self.config.loss_identity:
                        real_x_f, real_x_p = self.get_feature(real_x)
                        fake_x_f, fake_x_p = self.get_feature(fake_x)
                        g_loss_identity = torch.mean(
                            torch.abs(real_x_f - fake_x_f))
                        g_loss_identity += torch.mean(
                            torch.abs(real_x_p - fake_x_p))

                        self.loss[
                            'G/g_loss_identity'] = self.config.lambda_identity * g_loss_identity.data[
                                0]
                    else:
                        g_loss_identity = 0

                    ###total var loss
                    if self.config.loss_tv:
                        g_tv_loss = (self.total_variation_loss(fake_x) +
                                     self.total_variation_loss(rec_x)) / 2
                        self.loss[
                            'G/tv_loss'] = self.config.lambda_tv * g_tv_loss.data[
                                0]
                    else:
                        g_tv_loss = 0

                    ### D's cls loss
                    g_loss_cls = F.binary_cross_entropy_with_logits(
                        out_cls, fake_label,
                        size_average=False) / fake_x.size(0)

                    # Backward + Optimize
                    g_loss = g_loss_fake +\
                             self.config.lambda_rec * g_loss_rec +\
                             self.config.lambda_cls * g_loss_cls+\
                             self.config.lambda_idx * g_loss_id+\
                             self.config.lambda_identity*g_loss_identity+\
                             self.config.lambda_tv*g_tv_loss+\
                             self.config.lambda_symmetry*sym_loss+\
                             self.config.lambda_id_cls * g_loss_id_cls+\
                             self.config.lambda_si * g_loss_si

                    self.reset_grad()
                    g_loss.backward()
                    self.g_optimizer.step()

                    # Logging
                    self.img['real_x'] = real_x
                    self.img['fake_x'] = fake_x
                    self.img['rec_x'] = rec_x
                    self.loss['G/loss_fake'] = g_loss_fake.data[0]
                    self.loss[
                        'G/loss_rec'] = self.config.lambda_rec * g_loss_rec.data[
                            0]
                    self.loss[
                        'G/loss_cls'] = self.config.lambda_cls * g_loss_cls.data[
                            0]
                    #

                # Print out log info
                if (i + 1) % self.config.log_step == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))

                    log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                        elapsed, e + 1, self.config.num_epochs, i + 1,
                        iters_per_epoch)

                    for tag, value in self.loss.items():
                        log += ", {}: {}".format(tag, value)
                    print(log)

                    if self.config.use_tensorboard:
                        for tag, value in self.loss.items():
                            self.logger.scalar_summary(
                                tag, value, e * iters_per_epoch + i + 1)

                # Translate fixed images for debugging
                if (i) % self.config.sample_step == 0:
                    fake_image_list = [fixed_x]
                    for fixed_c in fixed_c_list:
                        if self.config.use_gpb:
                            fake_image_list.append(self.G(fixed_x, fixed_c)[0])
                        else:
                            fake_image_list.append(self.G(fixed_x, fixed_c))
                    fake_images = torch.cat(fake_image_list, dim=3)

                    if not self.config.log_space:
                        save_image(self.denorm(fake_images.data),
                                   os.path.join(
                                       self.config.sample_path,
                                       '{}_{}_fake.png'.format(e + 1, i + 1)),
                                   nrow=1,
                                   padding=0)
                    else:
                        fake_images = self.denorm(fake_images.data) * 255.0
                        fake_images = torch.pow(
                            2.71828182846,
                            fake_images / 255.0 * np.log(256.0)) - 1.0
                        fake_images = fake_images / 255.0
                        fake_images = fake_images.clamp(0.0, 1.0)
                        save_image(fake_images,
                                   os.path.join(
                                       self.config.sample_path,
                                       '{}_{}_fake.png'.format(e + 1, i + 1)),
                                   nrow=1,
                                   padding=0)

                    print('Translated images and saved into {}..!'.format(
                        self.config.sample_path))

                # Save model checkpoints
                if (i + 1) % self.config.model_save_step == 0:
                    torch.save(
                        self.G.state_dict(),
                        os.path.join(self.config.model_save_path,
                                     '{}_{}_G.pth'.format(e + 1, i + 1)))
                    torch.save(
                        self.D.state_dict(),
                        os.path.join(self.config.model_save_path,
                                     '{}_{}_D.pth'.format(e + 1, i + 1)))
                if self.config.visualize and (i +
                                              1) % self.config.display_f == 0:
                    visualizer.display_current_results(self.img)
                    visualizer.plot_current_errors(
                        e,
                        float(i + 1) / iters_per_epoch, self.loss)

            # Decay learning rate
            if (e + 1) > (self.config.num_epochs -
                          self.config.num_epochs_decay):
                g_lr -= (self.config.g_lr /
                         float(self.config.num_epochs_decay))
                d_lr -= (self.config.d_lr /
                         float(self.config.num_epochs_decay))
                self.update_lr(g_lr, d_lr)
                print('Decay learning rate to g_lr: {}, d_lr: {}.'.format(
                    g_lr, d_lr))
예제 #15
0
    def train(self):
        """Train StarGAN (plus the auxiliary segmentation network A) on one dataset.

        Per mini-batch, the WGAN-GP loop:
          1. updates D on real/fake adversarial + domain-classification losses,
             then a second time on the gradient penalty;
          2. updates the segmentation network A on real images;
          3. every ``d_train_repeat`` iterations, updates G on adversarial,
             reconstruction, classification and segmentation losses.
        Also performs periodic logging/visualization, fixed-image sample dumps,
        checkpointing of G/D/A, and linear LR decay over the final
        ``num_epochs_decay`` epochs.

        NOTE(review): written against pre-0.4 PyTorch (``Variable``,
        ``volatile=True``, ``loss.data[0]``) and hard-codes ``.cuda()`` in the
        gradient penalty — requires a GPU and torch<=0.3.
        """

        # Set dataloader
        if self.dataset == 'CelebA':
            self.data_loader = self.celebA_loader
        else:
            self.data_loader = self.rafd_loader

        # The number of iterations per epoch
        iters_per_epoch = len(self.data_loader)

        # Collect the first 4 batches as fixed debugging inputs.
        fixed_x = []
        real_c = []
        fixed_s = []
        for i, (images, seg_i, seg, labels) in enumerate(self.data_loader):
            fixed_x.append(images)
            fixed_s.append(seg)
            real_c.append(labels)
            if i == 3:
                break

        # Fixed inputs and target domain labels for debugging
        fixed_x = torch.cat(fixed_x, dim=0)
        fixed_x = self.to_var(fixed_x, volatile=True)
        real_c = torch.cat(real_c, dim=0)

        # Fixed segmentation maps for sampling: the images' own maps first,
        # followed by fixed_s_num randomly picked maps each broadcast over
        # the whole fixed batch.
        fixed_s = torch.cat(fixed_s, dim=0)
        fixed_s_list = []
        fixed_s_list.append(self.to_var(fixed_s, volatile=True))

        rand_idx = torch.randperm(fixed_s.size(0))
        fixed_s_num = 5
        fixed_s_vec = fixed_s[rand_idx][:fixed_s_num]

        for i in range(fixed_s_num):
            # Repeat one sampled segmentation map across the batch dimension.
            fixed_s_temp = fixed_s_vec[i].unsqueeze(0).repeat(fixed_s.size(0),1,1,1)
            fixed_s_temp = self.to_var(fixed_s_temp)
            fixed_s_list.append(fixed_s_temp)

        # for i in range(4):
        #     rand_idx = torch.randperm(fixed_s.size(0))
        #     fixed_s_temp = self.to_var(fixed_s[rand_idx], volatile=True)
        #     fixed_s_list.append(fixed_s_temp)

        # Fixed target-domain label vectors for the debug translations.
        if self.dataset == 'CelebA':
            fixed_c_list = self.make_celeb_labels(real_c)
        elif self.dataset == 'RaFD':
            fixed_c_list = []
            for i in range(self.c_dim):
                fixed_c = self.one_hot(torch.ones(fixed_x.size(0)) * i, self.c_dim)
                fixed_c_list.append(self.to_var(fixed_c, volatile=True))

        # lr cache for decaying
        g_lr = self.g_lr
        d_lr = self.d_lr

        # Start with trained model if exists
        if self.pretrained_model:
            # Checkpoint names look like '{epoch}_{iter}_G.pth'; resume from
            # that epoch (0-based, hence the -1).
            start = int(self.pretrained_model.split('_')[0])-1
        else:
            start = 0

        # Start training
        start_time = time.time()
        for e in range(start, self.num_epochs):
            epoch_iter = 0
            for i, (real_x, real_s_i, real_s, real_label) in enumerate(self.data_loader):
                epoch_iter = epoch_iter + 1
                # Generate fake labels randomly (target domain labels)
                rand_idx = torch.randperm(real_label.size(0))
                fake_label = real_label[rand_idx]
                # Independently shuffle segmentation maps to form fake targets.
                rand_idx = torch.randperm(real_label.size(0))
                fake_s = real_s[rand_idx]
                fake_s_i = real_s_i[rand_idx]
                if self.dataset == 'CelebA':
                    real_c = real_label.clone()
                    fake_c = fake_label.clone()
                else:
                    real_c = self.one_hot(real_label, self.c_dim)
                    fake_c = self.one_hot(fake_label, self.c_dim)

                # Convert tensor to variable
                real_x = self.to_var(real_x)
                real_s = self.to_var(real_s)
                real_s_i = self.to_var(real_s_i)
                fake_s = self.to_var(fake_s)
                fake_s_i = self.to_var(fake_s_i)
                real_c = self.to_var(real_c)           # input for the generator
                fake_c = self.to_var(fake_c)
                real_label = self.to_var(real_label)   # this is same as real_c if dataset == 'CelebA'
                fake_label = self.to_var(fake_label)

                # ================== Train D ================== #

                # Compute loss with real images (WGAN critic: maximize real score)
                out_src, out_cls = self.D(real_x)
                d_loss_real = - torch.mean(out_src)

                if self.dataset == 'CelebA':
                    # Multi-label attributes -> BCE, normalized per sample.
                    d_loss_cls = F.binary_cross_entropy_with_logits(
                        out_cls, real_label, size_average=False) / real_x.size(0)
                else:
                    d_loss_cls = F.cross_entropy(out_cls, real_label)

                # Compute classification accuracy of the discriminator
                if (i+1) % self.log_step == 0:
                    accuracies = self.compute_accuracy(out_cls, real_label, self.dataset)
                    log = ["{:.2f}".format(acc) for acc in accuracies.data.cpu().numpy()]
                    if self.dataset == 'CelebA':
                        print('Classification Acc (Black/Blond/Brown/Gender/Aged): ', end='')
                    else:
                        print('Classification Acc (8 emotional expressions): ', end='')
                    print(log)

                # Compute loss with fake images; re-wrapping in Variable
                # detaches fake_x so G receives no gradient from the D step.
                fake_x = self.G(real_x, fake_c, fake_s)
                fake_x = Variable(fake_x.data)
                out_src, out_cls = self.D(fake_x)
                d_loss_fake = torch.mean(out_src)

                # Backward + Optimize
                d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # Compute gradient penalty (WGAN-GP) on random interpolations
                # between real and fake samples; applied as a second D update.
                alpha = torch.rand(real_x.size(0), 1, 1, 1).cuda().expand_as(real_x)
                interpolated = Variable(alpha * real_x.data + (1 - alpha) * fake_x.data, requires_grad=True)
                out, out_cls = self.D(interpolated)

                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                # Penalize deviation of the per-sample gradient L2 norm from 1.
                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                # Backward + Optimize
                d_loss = self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # ================== Train A ================== #
                # Auxiliary segmentation network: predict seg map of real images.
                self.a_optimizer.zero_grad()
                out_real_s = self.A(real_x)
                # a_loss = self.criterion_s(out_real_s, real_s_i.type(torch.cuda.LongTensor)) * self.lambda_s
                a_loss = self.criterion_s(out_real_s, real_s_i) * self.lambda_s
                # a_loss = torch.mean(torch.abs(real_s - out_real_s))
                a_loss.backward()
                self.a_optimizer.step()

                # Logging
                loss = {}
                loss['D/loss_real'] = d_loss_real.data[0]
                loss['D/loss_fake'] = d_loss_fake.data[0]
                loss['D/loss_cls'] = d_loss_cls.data[0]
                loss['D/loss_gp'] = d_loss_gp.data[0]

                # ================== Train G ================== #
                if (i+1) % self.d_train_repeat == 0:

                    # Original-to-target and target-to-original domain
                    fake_x = self.G(real_x, fake_c, fake_s)

                    rec_x = self.G(fake_x, real_c, real_s)

                    # Compute losses
                    out_src, out_cls = self.D(fake_x)
                    g_loss_fake = - torch.mean(out_src)
                    g_loss_rec = self.lambda_rec * torch.mean(torch.abs(real_x - rec_x))

                    if self.dataset == 'CelebA':
                        g_loss_cls = F.binary_cross_entropy_with_logits(
                            out_cls, fake_label, size_average=False) / fake_x.size(0)
                    else:
                        g_loss_cls = F.cross_entropy(out_cls, fake_label)

                    # Segmentation loss: fake image should match the target seg map.
                    out_fake_s = self.A(fake_x)
                    g_loss_s = self.lambda_s * self.criterion_s(out_fake_s, fake_s_i)
                    # Backward + Optimize
                    # (g_loss_rec/g_loss_s already carry their lambda weights.)
                    g_loss = g_loss_fake + g_loss_rec + g_loss_s + self.lambda_cls * g_loss_cls
                    # g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
                    self.reset_grad()
                    g_loss.backward()
                    self.g_optimizer.step()

                    # Logging
                    loss['G/loss_fake'] = g_loss_fake.data[0]
                    loss['G/loss_rec'] = g_loss_rec.data[0]
                    loss['G/loss_cls'] = g_loss_cls.data[0]

                # NOTE(review): rec_x, out_fake_s and g_loss_* only exist after
                # a G step has run — this branch raises NameError unless
                # visual_step is a multiple of d_train_repeat; verify configs.
                if (i+1) % self.visual_step == 0:
                    # save visuals
                    self.real_x = real_x
                    self.fake_x = fake_x
                    self.rec_x = rec_x
                    self.real_s = real_s
                    self.fake_s = fake_s
                    self.out_real_s = out_real_s
                    self.out_fake_s = out_fake_s
                    self.a_loss = a_loss
                    # save losses
                    self.d_real = - d_loss_real
                    self.d_fake = d_loss_fake
                    self.d_loss = d_loss
                    self.g_loss = g_loss
                    self.g_loss_fake = g_loss_fake
                    self.g_loss_rec = g_loss_rec
                    self.g_loss_s = g_loss_s
                    errors_D = self.get_current_errors('D')
                    errors_G = self.get_current_errors('G')
                    self.visualizer.display_current_results(self.get_current_visuals(), e)
                    self.visualizer.plot_current_errors_D(e, float(epoch_iter)/float(iters_per_epoch), errors_D)
                    self.visualizer.plot_current_errors_G(e, float(epoch_iter)/float(iters_per_epoch), errors_G)
                # Print out log info
                if (i+1) % self.log_step == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))

                    log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                        elapsed, e+1, self.num_epochs, i+1, iters_per_epoch)

                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)

                    if self.use_tensorboard:
                        for tag, value in loss.items():
                            self.logger.scalar_summary(tag, value, e * iters_per_epoch + i + 1)

                # Translate fixed images for debugging: every (label, seg map)
                # combination, concatenated horizontally into one image row.
                if (i+1) % self.sample_step == 0:
                    fake_image_list = [fixed_x]
                    fixed_c = fixed_c_list[0]
                    real_seg_list = []
                    for fixed_c in fixed_c_list:
                        for fixed_s in fixed_s_list:
                            fake_image_list.append(self.G(fixed_x, fixed_c, fixed_s))
                            real_seg_list.append(fixed_s)
                    fake_images = torch.cat(fake_image_list, dim=3)
                    real_seg_images = torch.cat(real_seg_list, dim=3)
                    save_image(self.denorm(fake_images.data),
                        os.path.join(self.sample_path, '{}_{}_fake.png'.format(e+1, i+1)),nrow=1, padding=0)
                    save_image(self.cat2class_tensor(real_seg_images.data),
                        os.path.join(self.sample_path, '{}_{}_seg.png'.format(e+1, i+1)),nrow=1, padding=0)
                    print('Translated images and saved into {}..!'.format(self.sample_path))

                # Save model checkpoints (G, D, and segmentation net A)
                if (i+1) % self.model_save_step == 0:
                    torch.save(self.G.state_dict(),
                        os.path.join(self.model_save_path, '{}_{}_G.pth'.format(e+1, i+1)))
                    torch.save(self.D.state_dict(),
                        os.path.join(self.model_save_path, '{}_{}_D.pth'.format(e+1, i+1)))
                    torch.save(self.A.state_dict(),
                        os.path.join(self.model_save_path, '{}_{}_A.pth'.format(e+1, i+1)))

            # Decay learning rate linearly over the last num_epochs_decay epochs
            if (e+1) > (self.num_epochs - self.num_epochs_decay):
                g_lr -= (self.g_lr / float(self.num_epochs_decay))
                d_lr -= (self.d_lr / float(self.num_epochs_decay))
                self.update_lr(g_lr, d_lr)
                print ('Decay learning rate to g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
예제 #16
0
파일: solver.py 프로젝트: rafalsc/StarGAN
    def train(self):
        """Train StarGAN within a single dataset (CelebA or RaFD).

        WGAN-GP loop per mini-batch:
          1. update D on real/fake adversarial + domain-classification losses,
             then a second time on the gradient penalty;
          2. every ``d_train_repeat`` iterations, update G on adversarial,
             cycle-reconstruction and classification losses.
        Periodically logs losses, saves fixed-image translations, checkpoints
        G/D, and linearly decays the learning rates over the final
        ``num_epochs_decay`` epochs.

        NOTE(review): uses pre-0.4 PyTorch idioms (``Variable``,
        ``volatile=True``, ``loss.data[0]``) and hard-coded ``.cuda()`` —
        requires a GPU and torch<=0.3.
        """

        # Set dataloader
        if self.dataset == 'CelebA':
            self.data_loader = self.celebA_loader
        else:
            self.data_loader = self.rafd_loader

        # The number of iterations per epoch
        iters_per_epoch = len(self.data_loader)

        # Collect the first 4 batches as fixed debugging inputs.
        fixed_x = []
        real_c = []
        for i, (images, labels) in enumerate(self.data_loader):
            fixed_x.append(images)
            real_c.append(labels)
            if i == 3:
                break

        # Fixed inputs and target domain labels for debugging
        fixed_x = torch.cat(fixed_x, dim=0)
        fixed_x = self.to_var(fixed_x, volatile=True)
        real_c = torch.cat(real_c, dim=0)

        # Fixed target-domain label vectors for the debug translations.
        if self.dataset == 'CelebA':
            fixed_c_list = self.make_celeb_labels(real_c)
        elif self.dataset == 'RaFD':
            fixed_c_list = []
            for i in range(self.c_dim):
                fixed_c = self.one_hot(torch.ones(fixed_x.size(0)) * i, self.c_dim)
                fixed_c_list.append(self.to_var(fixed_c, volatile=True))

        # lr cache for decaying
        g_lr = self.g_lr
        d_lr = self.d_lr

        # Start with trained model if exists
        # (checkpoint names look like '{epoch}_{iter}_G.pth')
        if self.pretrained_model:
            start = int(self.pretrained_model.split('_')[0])
        else:
            start = 0

        # Start training
        start_time = time.time()
        for e in range(start, self.num_epochs):
            for i, (real_x, real_label) in enumerate(self.data_loader):

                # Generate fake labels randomly (target domain labels)
                rand_idx = torch.randperm(real_label.size(0))
                fake_label = real_label[rand_idx]

                if self.dataset == 'CelebA':
                    real_c = real_label.clone()
                    fake_c = fake_label.clone()
                else:
                    real_c = self.one_hot(real_label, self.c_dim)
                    fake_c = self.one_hot(fake_label, self.c_dim)

                # Convert tensor to variable
                real_x = self.to_var(real_x)
                real_c = self.to_var(real_c)           # input for the generator
                fake_c = self.to_var(fake_c)
                real_label = self.to_var(real_label)   # this is same as real_c if dataset == 'CelebA'
                fake_label = self.to_var(fake_label)

                # ================== Train D ================== #

                # Compute loss with real images (WGAN critic: maximize real score)
                out_src, out_cls = self.D(real_x)
                d_loss_real = - torch.mean(out_src)

                if self.dataset == 'CelebA':
                    # Multi-label attributes -> BCE, normalized per sample.
                    d_loss_cls = F.binary_cross_entropy_with_logits(
                        out_cls, real_label, size_average=False) / real_x.size(0)
                else:
                    d_loss_cls = F.cross_entropy(out_cls, real_label)

                # Compute classification accuracy of the discriminator
                if (i+1) % self.log_step == 0:
                    accuracies = self.compute_accuracy(out_cls, real_label, self.dataset)
                    log = ["{:.2f}".format(acc) for acc in accuracies.data.cpu().numpy()]
                    if self.dataset == 'CelebA':
                        print('Classification Acc (Black/Blond/Brown/Gender/Aged): ', end='')
                    else:
                        print('Classification Acc (8 emotional expressions): ', end='')
                    print(log)

                # Compute loss with fake images; re-wrapping in Variable
                # detaches fake_x so G receives no gradient from the D step.
                fake_x = self.G(real_x, fake_c)
                fake_x = Variable(fake_x.data)
                out_src, out_cls = self.D(fake_x)
                d_loss_fake = torch.mean(out_src)

                # Backward + Optimize
                d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # Compute gradient penalty (WGAN-GP) on random interpolations
                # between real and fake samples; applied as a second D update.
                alpha = torch.rand(real_x.size(0), 1, 1, 1).cuda().expand_as(real_x)
                interpolated = Variable(alpha * real_x.data + (1 - alpha) * fake_x.data, requires_grad=True)
                out, out_cls = self.D(interpolated)

                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                # Penalize deviation of the per-sample gradient L2 norm from 1.
                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                # Backward + Optimize
                d_loss = self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # Logging
                loss = {}
                loss['D/loss_real'] = d_loss_real.data[0]
                loss['D/loss_fake'] = d_loss_fake.data[0]
                loss['D/loss_cls'] = d_loss_cls.data[0]
                loss['D/loss_gp'] = d_loss_gp.data[0]

                # ================== Train G ================== #
                if (i+1) % self.d_train_repeat == 0:

                    # Original-to-target and target-to-original domain
                    fake_x = self.G(real_x, fake_c)
                    rec_x = self.G(fake_x, real_c)

                    # Compute losses (cycle-consistency via L1 reconstruction)
                    out_src, out_cls = self.D(fake_x)
                    g_loss_fake = - torch.mean(out_src)
                    g_loss_rec = torch.mean(torch.abs(real_x - rec_x))

                    if self.dataset == 'CelebA':
                        g_loss_cls = F.binary_cross_entropy_with_logits(
                            out_cls, fake_label, size_average=False) / fake_x.size(0)
                    else:
                        g_loss_cls = F.cross_entropy(out_cls, fake_label)

                    # Backward + Optimize
                    g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
                    self.reset_grad()
                    g_loss.backward()
                    self.g_optimizer.step()

                    # Logging
                    loss['G/loss_fake'] = g_loss_fake.data[0]
                    loss['G/loss_rec'] = g_loss_rec.data[0]
                    loss['G/loss_cls'] = g_loss_cls.data[0]

                # Print out log info
                if (i+1) % self.log_step == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))

                    log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                        elapsed, e+1, self.num_epochs, i+1, iters_per_epoch)

                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)

                    if self.use_tensorboard:
                        for tag, value in loss.items():
                            self.logger.scalar_summary(tag, value, e * iters_per_epoch + i + 1)

                # Translate fixed images for debugging: one column per target
                # domain, concatenated horizontally.
                if (i+1) % self.sample_step == 0:
                    fake_image_list = [fixed_x]
                    for fixed_c in fixed_c_list:
                        fake_image_list.append(self.G(fixed_x, fixed_c))
                    fake_images = torch.cat(fake_image_list, dim=3)
                    save_image(self.denorm(fake_images.data),
                        os.path.join(self.sample_path, '{}_{}_fake.png'.format(e+1, i+1)),nrow=1, padding=0)
                    print('Translated images and saved into {}..!'.format(self.sample_path))

                # Save model checkpoints
                if (i+1) % self.model_save_step == 0:
                    torch.save(self.G.state_dict(),
                        os.path.join(self.model_save_path, '{}_{}_G.pth'.format(e+1, i+1)))
                    torch.save(self.D.state_dict(),
                        os.path.join(self.model_save_path, '{}_{}_D.pth'.format(e+1, i+1)))

            # Decay learning rate linearly over the last num_epochs_decay epochs
            if (e+1) > (self.num_epochs - self.num_epochs_decay):
                g_lr -= (self.g_lr / float(self.num_epochs_decay))
                d_lr -= (self.d_lr / float(self.num_epochs_decay))
                self.update_lr(g_lr, d_lr)
                print ('Decay learning rate to g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
예제 #17
0
파일: solver.py 프로젝트: rafalsc/StarGAN
    def train_multi(self):
        """Train StarGAN with multiple datasets.
        In the code below, 1 is related to CelebA and 2 is releated to RaFD.
        """
        # Fixed imagse and labels for debugging
        fixed_x = []
        real_c = []

        for i, (images, labels) in enumerate(self.celebA_loader):
            fixed_x.append(images)
            real_c.append(labels)
            if i == 2:
                break

        fixed_x = torch.cat(fixed_x, dim=0)
        fixed_x = self.to_var(fixed_x, volatile=True)
        real_c = torch.cat(real_c, dim=0)
        fixed_c1_list = self.make_celeb_labels(real_c)

        fixed_c2_list = []
        for i in range(self.c2_dim):
            fixed_c = self.one_hot(torch.ones(fixed_x.size(0)) * i, self.c2_dim)
            fixed_c2_list.append(self.to_var(fixed_c, volatile=True))

        fixed_zero1 = self.to_var(torch.zeros(fixed_x.size(0), self.c2_dim))     # zero vector when training with CelebA
        fixed_mask1 = self.to_var(self.one_hot(torch.zeros(fixed_x.size(0)), 2)) # mask vector: [1, 0]
        fixed_zero2 = self.to_var(torch.zeros(fixed_x.size(0), self.c_dim))      # zero vector when training with RaFD
        fixed_mask2 = self.to_var(self.one_hot(torch.ones(fixed_x.size(0)), 2))  # mask vector: [0, 1]

        # lr cache for decaying
        g_lr = self.g_lr
        d_lr = self.d_lr

        # data iterator
        data_iter1 = iter(self.celebA_loader)
        data_iter2 = iter(self.rafd_loader)

        # Start with trained model
        if self.pretrained_model:
            start = int(self.pretrained_model) + 1
        else:
            start = 0

        # # Start training
        start_time = time.time()
        for i in range(start, self.num_iters):

            # Fetch mini-batch images and labels
            try:
                real_x1, real_label1 = next(data_iter1)
            except:
                data_iter1 = iter(self.celebA_loader)
                real_x1, real_label1 = next(data_iter1)

            try:
                real_x2, real_label2 = next(data_iter2)
            except:
                data_iter2 = iter(self.rafd_loader)
                real_x2, real_label2 = next(data_iter2)

            # Generate fake labels randomly (target domain labels)
            rand_idx = torch.randperm(real_label1.size(0))
            fake_label1 = real_label1[rand_idx]
            rand_idx = torch.randperm(real_label2.size(0))
            fake_label2 = real_label2[rand_idx]

            real_c1 = real_label1.clone()
            fake_c1 = fake_label1.clone()
            zero1 = torch.zeros(real_x1.size(0), self.c2_dim)
            mask1 = self.one_hot(torch.zeros(real_x1.size(0)), 2)

            real_c2 = self.one_hot(real_label2, self.c2_dim)
            fake_c2 = self.one_hot(fake_label2, self.c2_dim)
            zero2 = torch.zeros(real_x2.size(0), self.c_dim)
            mask2 = self.one_hot(torch.ones(real_x2.size(0)), 2)

            # Convert tensor to variable
            real_x1 = self.to_var(real_x1)
            real_c1 = self.to_var(real_c1)
            fake_c1 = self.to_var(fake_c1)
            mask1 = self.to_var(mask1)
            zero1 = self.to_var(zero1)

            real_x2 = self.to_var(real_x2)
            real_c2 = self.to_var(real_c2)
            fake_c2 = self.to_var(fake_c2)
            mask2 = self.to_var(mask2)
            zero2 = self.to_var(zero2)

            real_label1 = self.to_var(real_label1)
            fake_label1 = self.to_var(fake_label1)
            real_label2 = self.to_var(real_label2)
            fake_label2 = self.to_var(fake_label2)

            # ================== Train D ================== #

            # Real images (CelebA)
            out_real, out_cls = self.D(real_x1)
            out_cls1 = out_cls[:, :self.c_dim]      # celebA part
            d_loss_real = - torch.mean(out_real)
            d_loss_cls = F.binary_cross_entropy_with_logits(out_cls1, real_label1, size_average=False) / real_x1.size(0)

            # Real images (RaFD)
            out_real, out_cls = self.D(real_x2)
            out_cls2 = out_cls[:, self.c_dim:]      # rafd part
            d_loss_real += - torch.mean(out_real)
            d_loss_cls += F.cross_entropy(out_cls2, real_label2)

            # Compute classification accuracy of the discriminator
            if (i+1) % self.log_step == 0:
                accuracies = self.compute_accuracy(out_cls1, real_label1, 'CelebA')
                log = ["{:.2f}".format(acc) for acc in accuracies.data.cpu().numpy()]
                print('Classification Acc (Black/Blond/Brown/Gender/Aged): ', end='')
                print(log)
                accuracies = self.compute_accuracy(out_cls2, real_label2, 'RaFD')
                log = ["{:.2f}".format(acc) for acc in accuracies.data.cpu().numpy()]
                print('Classification Acc (8 emotional expressions): ', end='')
                print(log)

            # Fake images (CelebA)
            fake_c = torch.cat([fake_c1, zero1, mask1], dim=1)
            fake_x1 = self.G(real_x1, fake_c)
            fake_x1 = Variable(fake_x1.data)
            out_fake, _ = self.D(fake_x1)
            d_loss_fake = torch.mean(out_fake)

            # Fake images (RaFD)
            fake_c = torch.cat([zero2, fake_c2, mask2], dim=1)
            fake_x2 = self.G(real_x2, fake_c)
            out_fake, _ = self.D(fake_x2)
            d_loss_fake += torch.mean(out_fake)

            # Backward + Optimize
            d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Compute gradient penalty
            if (i+1) % 2 == 0:
                real_x = real_x1
                fake_x = fake_x1
            else:
                real_x = real_x2
                fake_x = fake_x2

            alpha = torch.rand(real_x.size(0), 1, 1, 1).cuda().expand_as(real_x)
            interpolated = Variable(alpha * real_x.data + (1 - alpha) * fake_x.data, requires_grad=True)
            out, out_cls = self.D(interpolated)

            if (i+1) % 2 == 0:
                out_cls = out_cls[:, :self.c_dim]  # CelebA
            else:
                out_cls = out_cls[:, self.c_dim:]  # RaFD

            grad = torch.autograd.grad(outputs=out,
                                       inputs=interpolated,
                                       grad_outputs=torch.ones(out.size()).cuda(),
                                       retain_graph=True,
                                       create_graph=True,
                                       only_inputs=True)[0]

            grad = grad.view(grad.size(0), -1)
            grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
            d_loss_gp = torch.mean((grad_l2norm - 1)**2)

            # Backward + Optimize
            d_loss = self.lambda_gp * d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Logging
            loss = {}
            loss['D/loss_real'] = d_loss_real.data[0]
            loss['D/loss_fake'] = d_loss_fake.data[0]
            loss['D/loss_cls'] = d_loss_cls.data[0]
            loss['D/loss_gp'] = d_loss_gp.data[0]

            # ================== Train G ================== #
            if (i+1) % self.d_train_repeat == 0:
                # Original-to-target and target-to-original domain (CelebA)
                fake_c = torch.cat([fake_c1, zero1, mask1], dim=1)
                real_c = torch.cat([real_c1, zero1, mask1], dim=1)
                fake_x1 = self.G(real_x1, fake_c)
                rec_x1 = self.G(fake_x1, real_c)

                # Compute losses
                out, out_cls = self.D(fake_x1)
                out_cls1 = out_cls[:, :self.c_dim]
                g_loss_fake = - torch.mean(out)
                g_loss_rec = torch.mean(torch.abs(real_x1 - rec_x1))
                g_loss_cls = F.binary_cross_entropy_with_logits(out_cls1, fake_label1, size_average=False) / fake_x1.size(0)

                # Original-to-target and target-to-original domain (RaFD)
                fake_c = torch.cat([zero2, fake_c2, mask2], dim=1)
                real_c = torch.cat([zero2, real_c2, mask2], dim=1)
                fake_x2 = self.G(real_x2, fake_c)
                rec_x2 = self.G(fake_x2, real_c)

                # Compute losses
                out, out_cls = self.D(fake_x2)
                out_cls2 = out_cls[:, self.c_dim:]
                g_loss_fake += - torch.mean(out)
                g_loss_rec += torch.mean(torch.abs(real_x2 - rec_x2))
                g_loss_cls += F.cross_entropy(out_cls2, fake_label2)

                # Backward + Optimize
                g_loss = g_loss_fake + self.lambda_cls * g_loss_cls + self.lambda_rec * g_loss_rec
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Logging
                loss['G/loss_fake'] = g_loss_fake.data[0]
                loss['G/loss_cls'] = g_loss_cls.data[0]
                loss['G/loss_rec'] = g_loss_rec.data[0]

            # Print out log info
            if (i+1) % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))

                log = "Elapsed [{}], Iter [{}/{}]".format(
                    elapsed, i+1, self.num_iters)

                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, i+1)

            # Translate the images (debugging)
            if (i+1) % self.sample_step == 0:
                fake_image_list = [fixed_x]

                # Changing hair color, gender, and age
                for j in range(self.c_dim):
                    fake_c = torch.cat([fixed_c1_list[j], fixed_zero1, fixed_mask1], dim=1)
                    fake_image_list.append(self.G(fixed_x, fake_c))
                # Changing emotional expressions
                for j in range(self.c2_dim):
                    fake_c = torch.cat([fixed_zero2, fixed_c2_list[j], fixed_mask2], dim=1)
                    fake_image_list.append(self.G(fixed_x, fake_c))
                fake = torch.cat(fake_image_list, dim=3)

                # Save the translated images
                save_image(self.denorm(fake.data),
                    os.path.join(self.sample_path, '{}_fake.png'.format(i+1)), nrow=1, padding=0)

            # Save model checkpoints
            if (i+1) % self.model_save_step == 0:
                torch.save(self.G.state_dict(),
                    os.path.join(self.model_save_path, '{}_G.pth'.format(i+1)))
                torch.save(self.D.state_dict(),
                    os.path.join(self.model_save_path, '{}_D.pth'.format(i+1)))

            # Decay learning rate
            decay_step = 1000
            if (i+1) > (self.num_iters - self.num_iters_decay) and (i+1) % decay_step==0:
                g_lr -= (self.g_lr / float(self.num_iters_decay) * decay_step)
                d_lr -= (self.d_lr / float(self.num_iters_decay) * decay_step)
                self.update_lr(g_lr, d_lr)
                print ('Decay learning rate to g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
예제 #18
0
파일: wgan.py 프로젝트: XPping/pytorch-GAN
    def train(self):
        """Run the WGAN training loop.

        For every batch: update the critic D with the Wasserstein loss
        ``-E[D(real)] + E[D(fake)]``, then apply a WGAN-GP gradient
        penalty in a separate backward/step, and update the generator G
        every ``self.D_train_step`` iterations. Per epoch, periodically
        saves sample images and model checkpoints.
        """
        print(len(self.data_loader))
        for e in range(self.epoch):
            for i, batch_images in enumerate(self.data_loader):
                batch_size = batch_images.size(0)
                real_x = self.to_var(batch_images)
                noise_x = self.to_var(
                    torch.FloatTensor(noise_vector(batch_size, self.noise_n)))

                # ---------------- Train D ----------------
                fake_x = self.G(noise_x)
                real_out = self.D(real_x)
                # Detach so no gradient flows back into G during the D step.
                fake_out = self.D(fake_x.detach())

                # Wasserstein critic loss: maximize D(real) - D(fake).
                D_real = -torch.mean(real_out)
                D_fake = torch.mean(fake_out)
                D_loss = D_real + D_fake

                self.reset_grad()
                D_loss.backward()
                self.D_optimizer.step()
                # Log
                loss = {}
                loss['D/loss_real'] = D_real.data[0]
                loss['D/loss_fake'] = D_fake.data[0]
                loss['D/loss'] = D_loss.data[0]

                # Choose ONE of the two Lipschitz constraints below:
                # weight clipping (original WGAN) ...
                # for p in self.D.parameters():
                #     p.data.clamp_(-self.clip_value, self.clip_value)
                # ... or the gradient penalty (WGAN-GP): penalize the
                # critic's gradient norm deviating from 1 along random
                # interpolations between real and fake samples.
                alpha = torch.rand(real_x.size(0), 1, 1,
                                   1).cuda().expand_as(real_x)
                interpolated = Variable(alpha * real_x.data +
                                        (1 - alpha) * fake_x.data,
                                        requires_grad=True)
                gp_out = self.D(interpolated)
                # create_graph=True so the penalty itself is differentiable
                # w.r.t. D's parameters.
                grad = torch.autograd.grad(outputs=gp_out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(
                                               gp_out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)
                # Apply the penalty in its own backward/step.
                d_loss = self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.D_optimizer.step()

                # ---------------- Train G ----------------
                if (i + 1) % self.D_train_step == 0:
                    fake_out = self.D(self.G(noise_x))
                    G_loss = -torch.mean(fake_out)
                    self.reset_grad()
                    G_loss.backward()
                    self.G_optimizer.step()
                    loss['G/loss'] = G_loss.data[0]

                # Print log
                if (i + 1) % self.log_step == 0:
                    log = "Epoch: {}/{}, Iter: {}/{}".format(
                        e + 1, self.epoch, i + 1, len(self.data_loader))
                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)
                # NOTE(review): tensorboard logging is not gated by
                # log_step, so scalars are written every iteration —
                # confirm this is intentional.
                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(
                            tag, value,
                            e * len(self.data_loader) + i + 1)

            # Save a grid of generated samples.
            if (e + 1) % self.save_image_step == 0:
                noise_x = self.to_var(
                    torch.FloatTensor(noise_vector(16, self.noise_n)))
                fake_image = self.G(noise_x)
                save_image(
                    self.denorm(fake_image.data),
                    os.path.join(self.image_save_path,
                                 "{}_fake.png".format(e + 1)))
            # Save model checkpoints.
            if (e + 1) % self.model_save_step == 0:
                torch.save(
                    self.G.state_dict(),
                    os.path.join(self.model_save_path,
                                 "{}_G.pth".format(e + 1)))
                torch.save(
                    self.D.state_dict(),
                    os.path.join(self.model_save_path,
                                 "{}_D.pth".format(e + 1)))
예제 #19
0
    def train(self):
        """Train StarGAN within a single dataset.

        Alternates discriminator and generator updates over
        ``self.num_epochs`` epochs: D is trained every iteration with an
        adversarial (Wasserstein-style) loss, a domain-classification
        loss and a gradient penalty; G is trained every
        ``self.d_train_repeat`` iterations with adversarial,
        classification and cycle-reconstruction losses. Periodically
        logs losses, saves translated sample images, and saves model
        checkpoints. Resumes from ``self.pretrained_model`` if set.
        """

        # Pick the dataloader matching the configured dataset.
        if self.dataset == 'CelebA':
            self.data_loader = self.celebA_loader
        elif self.dataset == 'RaFD':
            self.data_loader = self.rafd_loader
        elif self.dataset == 'fer2013':
            self.data_loader = self.fer2013_loader
        elif self.dataset == 'ferg_db':
            self.data_loader = self.ferg_db_loader

        # The number of iterations per epoch
        iters_per_epoch = len(self.data_loader)

        # Collect the first 4 batches as fixed debug inputs.
        fixed_x = []
        real_c = []
        for i, (images, labels) in enumerate(self.data_loader):
            fixed_x.append(images)
            real_c.append(labels)
            if i == 3:
                break

        # Fixed inputs and target domain labels for debugging
        fixed_x = torch.cat(fixed_x, dim=0)
        # volatile=True: inference-only Variable (pre-0.4 PyTorch idiom).
        fixed_x = self.to_var(fixed_x, volatile=True)
        real_c = torch.cat(real_c, dim=0)

        # Build one fixed target-label tensor per domain for sample grids.
        if self.dataset in ['CelebA']:
            fixed_c_list = self.make_celeb_labels(real_c)
        elif self.dataset in ['RaFD', 'fer2013', 'ferg_db']:
            fixed_c_list = []
            for i in range(self.c_dim):
                fixed_c = self.one_hot(
                    torch.ones(fixed_x.size(0)) * i, self.c_dim)
                fixed_c_list.append(self.to_var(fixed_c, volatile=True))

        # lr cache for decaying
        g_lr = self.g_lr
        d_lr = self.d_lr

        # Start with trained model if exists
        # (checkpoint name is expected to begin with the epoch number).
        if self.pretrained_model:
            start = int(self.pretrained_model.split('_')[0])
        else:
            start = 0

        # Start training
        start_time = time.time()
        for e in range(start, self.num_epochs):
            for i, (real_x, real_label) in enumerate(self.data_loader):

                # Generat fake labels randomly (target domain labels)
                # by permuting the real labels within the batch.
                rand_idx = torch.randperm(real_label.size(0))
                fake_label = real_label[rand_idx]

                if self.dataset == 'CelebA':
                    # Multi-label attribute vectors are used as-is.
                    real_c = real_label.clone()
                    fake_c = fake_label.clone()
                else:
                    # Single-label datasets: one-hot encode the class index.
                    real_c = self.one_hot(real_label, self.c_dim)
                    fake_c = self.one_hot(fake_label, self.c_dim)

                # Convert tensor to variable
                real_x = self.to_var(real_x)
                real_c = self.to_var(real_c)  # input for the generator
                fake_c = self.to_var(fake_c)
                real_label = self.to_var(
                    real_label
                )  # this is same as real_c if dataset == 'CelebA'
                fake_label = self.to_var(fake_label)

                # ================== Train D ================== #

                # Compute loss with real images.
                # NOTE(review): D appears to return (adversarial scores,
                # class logits, intermediate feature maps) — confirm
                # against the discriminator definition.
                out_src, out_cls, out_feats_real = self.D(real_x)
                d_loss_real = -torch.mean(out_src)

                if self.dataset == 'CelebA':
                    # Multi-label BCE, summed then averaged over the batch.
                    d_loss_cls = F.binary_cross_entropy_with_logits(
                        out_cls, real_label,
                        size_average=False) / real_x.size(0)
                else:
                    d_loss_cls = F.cross_entropy(out_cls, real_label)

                # Compute classification accuracy of the discriminator
                # (printed only every log_step iterations).
                if (i + 1) % self.log_step == 0:
                    accuracies = self.compute_accuracy(out_cls, real_label,
                                                       self.dataset)
                    log = [
                        "{:.2f}".format(acc)
                        for acc in accuracies.data.cpu().numpy()
                    ]
                    if self.dataset == 'CelebA':
                        print(
                            'Classification Acc (Black/Blond/Brown/Gender/Aged): ',
                            end='')
                    elif self.dataset in ['fer2013', 'ferg_db']:
                        print('Classification Acc (7 emotional expressions): ',
                              end='')
                    else:
                        print('Classification Acc (8 emotional expressions): ',
                              end='')
                    print(log)

                # Compute loss with fake images
                fake_x = self.G(real_x, fake_c)
                # Re-wrap .data to detach fake_x from G's graph
                # (pre-0.4 equivalent of .detach()).
                fake_x = Variable(fake_x.data)
                out_src, out_cls, out_feats_fake = self.D(fake_x)
                d_loss_fake = torch.mean(out_src)

                # Backward + Optimize
                d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
                self.reset_grad()
                # NOTE(review): retain_graph=True keeps this graph alive for
                # the later backward passes — confirm it is actually needed.
                d_loss.backward(retain_graph=True)
                self.d_optimizer.step()

                # Compute gradient penalty (WGAN-GP): penalize the critic's
                # gradient norm deviating from 1 on real/fake interpolations.
                alpha = torch.rand(real_x.size(0), 1, 1,
                                   1).cuda().expand_as(real_x)
                interpolated = Variable(alpha * real_x.data +
                                        (1 - alpha) * fake_x.data,
                                        requires_grad=True)
                out, out_cls, out_feats = self.D(interpolated)

                # create_graph=True so the penalty is differentiable
                # w.r.t. D's parameters.
                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(
                                               out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                # Backward + Optimize (penalty applied in its own step)
                d_loss = self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward(retain_graph=True)
                self.d_optimizer.step()

                # Logging (.data[0] is the pre-0.4 scalar accessor)
                loss = {}
                loss['D/loss_real'] = d_loss_real.data[0]
                loss['D/loss_fake'] = d_loss_fake.data[0]
                loss['D/loss_cls'] = d_loss_cls.data[0]
                loss['D/loss_gp'] = d_loss_gp.data[0]

                # ================== Train G ================== #
                if (i + 1) % self.d_train_repeat == 0:

                    # Original-to-target and target-to-original domain
                    fake_x = self.G(real_x, fake_c)
                    rec_x = self.G(fake_x, real_c)

                    # Compute losses
                    out_src, out_cls, out_feats_fake = self.D(fake_x)
                    g_loss_fake = -torch.mean(out_src)

                    if self.dataset == 'CelebA':
                        g_loss_cls = F.binary_cross_entropy_with_logits(
                            out_cls, fake_label,
                            size_average=False) / fake_x.size(0)
                    else:
                        g_loss_cls = F.cross_entropy(out_cls, fake_label)

                    ### Discriminate for rec_x
                    # Cycle-consistency: L1 between input and reconstruction.
                    out_src, out_cls, out_feats_rec = self.D(rec_x)
                    g_loss_rec = torch.mean(torch.abs(real_x - rec_x))
                    '''
                    ### Replace pixel-wise reconstruction error between real_x / rec_x
                    ### with feature-wise reconstruction error (multi layers) between real_feat / rec_feat (L1 norm)
                    g_loss_feat_rec = 0
                    for real_feat, rec_feat in zip(out_feats_real, out_feats_rec):
                        g_loss_feat_rec += torch.mean(torch.abs(real_feat - rec_feat))
                    '''
                    '''
                    ### Feature matching (distribution) loss (multi layers) between real_feat / rec_feat (L2 norm, from DiscoGAN)
                    feat_criterion = nn.HingeEmbeddingLoss()
                    g_loss_feat_match = 0
                    for real_feat, rec_feat in zip(out_feats_real, out_feats_rec):
                        l2 = (torch.mean(real_feat, 0) - torch.mean(rec_feat, 0)) ** 2
                        g_loss_feat_match += feat_criterion( l2, Variable( torch.ones( l2.size() ) ).cuda() )
                    '''

                    # Backward + Optimize
                    g_loss = g_loss_fake + self.lambda_cls * g_loss_cls + self.lambda_rec * g_loss_rec
                    '''
                    if e < self.num_epochs // 5: # early phase
                        g_loss = g_loss_fake + self.lambda_cls * g_loss_cls + self.lambda_rec * g_loss_rec
                    else:
                        g_loss = g_loss_fake + self.lambda_cls * g_loss_cls + self.lambda_feat_rec * g_loss_feat_rec
                    '''
                    self.reset_grad()
                    g_loss.backward(retain_graph=True)
                    self.g_optimizer.step()

                    # Logging
                    loss['G/loss_fake'] = g_loss_fake.data[0]
                    loss['G/loss_rec'] = g_loss_rec.data[0]
                    loss['G/loss_cls'] = g_loss_cls.data[0]
                    ###loss['G/loss_feat_rec'] = g_loss_feat_rec.data[0]

                # Print out log info
                if (i + 1) % self.log_step == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))

                    log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                        elapsed, e + 1, self.num_epochs, i + 1,
                        iters_per_epoch)

                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)

                    if self.use_tensorboard:
                        for tag, value in loss.items():
                            self.logger.scalar_summary(
                                tag, value, e * iters_per_epoch + i + 1)

                # Translate fixed images for debugging
                # (one column per target domain, concatenated along width).
                if (i + 1) % self.sample_step == 0:
                    fake_image_list = [fixed_x]
                    for fixed_c in fixed_c_list:
                        fake_image_list.append(self.G(fixed_x, fixed_c))
                    fake_images = torch.cat(fake_image_list, dim=3)
                    save_image(self.denorm(fake_images.data),
                               os.path.join(
                                   self.sample_path,
                                   '{}_{}_fake.png'.format(e + 1, i + 1)),
                               nrow=1,
                               padding=0)
                    print('Translated images and saved into {}..!'.format(
                        self.sample_path))

                # Save model checkpoints
                if (i + 1) % self.model_save_step == 0:
                    torch.save(
                        self.G.state_dict(),
                        os.path.join(self.model_save_path,
                                     '{}_{}_G.pth'.format(e + 1, i + 1)))
                    torch.save(
                        self.D.state_dict(),
                        os.path.join(self.model_save_path,
                                     '{}_{}_D.pth'.format(e + 1, i + 1)))

            # Decay learning rate linearly over the last
            # num_epochs_decay epochs.
            if (e + 1) > (self.num_epochs - self.num_epochs_decay):
                g_lr -= (self.g_lr / float(self.num_epochs_decay))
                d_lr -= (self.d_lr / float(self.num_epochs_decay))
                self.update_lr(g_lr, d_lr)
                print('Decay learning rate to g_lr: {}, d_lr: {}.'.format(
                    g_lr, d_lr))