import datetime
import os
import time

import numpy as np
import torch
import torch.nn.functional as F
from torchvision.utils import save_image

from model import Generator, Discriminator
# LabelTransform handles the 64-d AffectNet label descriptors; the module
# path below is an assumption -- point it at wherever the class lives in
# this repository.
from label_transform import LabelTransform


class Solver(object):
    """Solver for training and testing StarGAN."""
    def __init__(self, celeba_loader, rafd_loader, affectnet_loader, config):
        """Initialize configurations."""

        # Data loader.
        self.celeba_loader = celeba_loader
        self.rafd_loader = rafd_loader
        self.affectnet_loader = affectnet_loader
        # Model configurations.
        self.c_dim = config.c_dim
        self.c2_dim = config.c2_dim
        self.image_size = config.image_size
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.g_repeat_num = config.g_repeat_num
        self.d_repeat_num = config.d_repeat_num
        self.lambda_cls = config.lambda_cls
        self.lambda_rec = config.lambda_rec
        self.lambda_gp = config.lambda_gp
        self.lambda_cls2 = config.lambda_cls2

        # Training configurations.
        self.dataset = config.dataset
        self.batch_size = config.batch_size
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.n_critic = config.n_critic
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.resume_iters = config.resume_iters
        self.selected_attrs = config.selected_attrs
        self.affectnet_emo_descr = config.affectnet_emo_descr
        self.use_ccc = config.use_ccc
        self.depth_concat = config.depth_concat
        self.d_loss_cls_type = config.d_loss_cls_type
        # Test configurations.
        self.test_iters = config.test_iters

        # Miscellaneous.
        self.use_tensorboard = config.use_tensorboard
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        if self.affectnet_emo_descr in ["64d_reg", "64d_cls"]:
            self.label_transform = LabelTransform(config)

        # Directories.
        self.log_dir = config.log_dir
        self.sample_dir = config.sample_dir
        self.model_save_dir = config.model_save_dir
        self.result_dir = config.result_dir

        # Step size.
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.lr_update_step = config.lr_update_step

        # Build the model and tensorboard.
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

    def build_model(self):
        """Create a generator and a discriminator."""
        if self.dataset in ["CelebA", "RaFD", "AffectNet"]:
            self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num,
                               self.depth_concat)
            self.D = Discriminator(
                self.image_size,
                self.d_conv_dim,
                self.c_dim,
                self.d_repeat_num,
                self.affectnet_emo_descr,
                self.d_loss_cls_type,
            )

        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr,
                                            [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr,
                                            [self.beta1, self.beta2])
        self.print_network(self.G, "G")
        self.print_network(self.D, "D")

        self.G.to(self.device)
        self.D.to(self.device)

    def print_network(self, model, name):
        """Print out the network information."""
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def restore_model(self, resume_iters):
        """Restore the trained generator and discriminator."""
        print(
            "Loading the trained models from step {}...".format(resume_iters))
        G_path = os.path.join(self.model_save_dir,
                              "{}-G.ckpt".format(resume_iters))
        D_path = os.path.join(self.model_save_dir,
                              "{}-D.ckpt".format(resume_iters))
        self.G.load_state_dict(
            torch.load(G_path, map_location=lambda storage, loc: storage))
        self.D.load_state_dict(
            torch.load(D_path, map_location=lambda storage, loc: storage))

    def build_tensorboard(self):
        """Build a tensorboard logger."""
        from logger import Logger

        self.logger = Logger(self.log_dir)

    def update_lr(self, g_lr, d_lr):
        """Decay learning rates of the generator and discriminator."""
        for param_group in self.g_optimizer.param_groups:
            param_group["lr"] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group["lr"] = d_lr

    def reset_grad(self):
        """Reset the gradient buffers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def denorm(self, x):
        """Convert the range from [-1, 1] to [0, 1]."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
        weight = torch.ones(y.size()).to(self.device)
        dydx = torch.autograd.grad(
            outputs=y,
            inputs=x,
            grad_outputs=weight,
            retain_graph=True,
            create_graph=True,
            only_inputs=True,
        )[0]

        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
        return torch.mean((dydx_l2norm - 1)**2)

    def label2onehot(self, labels, dim):
        """Convert label indices to one-hot vectors."""
        batch_size = labels.size(0)
        out = torch.zeros(batch_size, dim)
        out[np.arange(batch_size), labels.long()] = 1
        return out

    def create_labels(self,
                      c_org,
                      c_dim=5,
                      dataset="CelebA",
                      selected_attrs=None):
        """Generate target domain labels for debugging and testing."""
        if dataset in ["CelebA", "RaFD"]:
            # Get hair color indices.
            if dataset == "CelebA":
                hair_color_indices = []
                for i, attr_name in enumerate(selected_attrs):
                    if attr_name in [
                            "Black_Hair",
                            "Blond_Hair",
                            "Brown_Hair",
                            "Gray_Hair",
                    ]:
                        hair_color_indices.append(i)

            c_trg_list = []
            for i in range(c_dim):
                if dataset == "CelebA":
                    c_trg = c_org.clone()
                    # Set one hair color to 1 and the rest to 0.
                    if i in hair_color_indices:
                        c_trg[:, i] = 1
                        for j in hair_color_indices:
                            if j != i:
                                c_trg[:, j] = 0
                    else:
                        # Reverse the attribute value.
                        c_trg[:, i] = (c_trg[:, i] == 0)
                elif dataset == "RaFD":
                    c_trg = self.label2onehot(
                        torch.ones(c_org.size(0)) * i, c_dim)

                c_trg_list.append(c_trg.to(self.device))

        elif dataset == "AffectNet":
            c_trg_list = []
            if self.affectnet_emo_descr == "emotiw":
                for i in range(c_dim):
                    c_trg = self.label2onehot(
                        torch.ones(c_org.size(0)) * i, c_dim)
                    c_trg_list.append(c_trg.to(self.device))
            elif self.affectnet_emo_descr == "va":
                for v in [-1.0, -0.5, 0.0, 0.5, 1.0]:
                    for a in [-1.0, -0.5, 0.0, 0.5, 1.0]:
                        c_trg = torch.Tensor([v, a]).repeat(c_org.size(0), 1)
                        c_trg_list.append(c_trg.to(self.device))
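            # For the 64-d descriptors only the activation part of the label
            # is passed to LabelTransform: column 0 is the class index for
            # "64d_cls", columns 0-1 are valence/arousal for "64d_reg".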
            elif self.affectnet_emo_descr in ["64d_cls"]:
                c_trg_list = self.label_transform.create_labels(c_org[:, 1:])
            elif self.affectnet_emo_descr in ["64d_reg"]:
                c_trg_list = self.label_transform.create_labels(c_org[:, 2:])
        return c_trg_list

    def ccc_loss(self, x, y):
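        """Concordance correlation coefficient (CCC) loss: 1 - CCC for each
        output dimension, summed, with statistics taken over the batch."""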
        mu_x = x.mean(0)
        var_x = x.var(0)
        mu_y = y.mean(0)
        var_y = y.var(0)
        cov_xy = torch.mean((x - mu_x.unsqueeze(0)) * (y - mu_y.unsqueeze(0)),
                            0)
        ccc = 2 * cov_xy / (var_x + var_y + (mu_x - mu_y)**2)
        loss = 1 - ccc
        return loss.sum()

    def classification_loss(self, logit, target, dataset="CelebA"):
        """Compute binary or softmax cross entropy loss."""
        if dataset == "CelebA":
            return F.binary_cross_entropy_with_logits(
                logit, target, reduction="sum") / logit.size(0)
        elif dataset == "RaFD":
            return F.cross_entropy(logit, target)
        elif dataset == "AffectNet":
            if self.affectnet_emo_descr in ["emotiw"]:
                return F.cross_entropy(logit, target)
            elif self.affectnet_emo_descr in ["va"]:
                if self.use_ccc:
                    return self.ccc_loss(logit, target)
                else:
                    return F.mse_loss(logit, target)
            elif self.affectnet_emo_descr in ["64d_reg"]:
                logit_pred, logit_actv = logit
                target_pred, target_actv = target
                out = 0
                if logit_pred is not None:
                    if self.use_ccc:
                        out += (self.lambda_cls2 / self.lambda_cls *
                                self.ccc_loss(logit_pred, target_pred))
                    else:
                        out += (self.lambda_cls2 / self.lambda_cls *
                                F.mse_loss(logit_pred, target_pred))
                if logit_actv is not None:
                    if self.use_ccc:
                        out += self.ccc_loss(logit_actv, target_actv)
                    else:
                        out += F.mse_loss(logit_actv, target_actv)
                return out
            elif self.affectnet_emo_descr in ["64d_cls"]:
                logit_pred, logit_actv = logit
                target_pred, target_actv = target
                out = 0
                if logit_pred is not None:
                    out += (self.lambda_cls2 / self.lambda_cls *
                            F.cross_entropy(logit_pred, target_pred))
                if logit_actv is not None:
                    if self.use_ccc:
                        out += self.ccc_loss(logit_actv, target_actv)
                    else:
                        out += F.mse_loss(logit_actv, target_actv)
                return out

    def train(self):
        """Train StarGAN within a single dataset."""
        # Set data loader.
        if self.dataset == "CelebA":
            data_loader = self.celeba_loader
        elif self.dataset == "RaFD":
            data_loader = self.rafd_loader
        elif self.dataset == "AffectNet":
            data_loader = self.affectnet_loader

        # Fetch fixed inputs for debugging.
        data_iter = iter(data_loader)
        x_fixed, c_org = next(data_iter)
        x_fixed = x_fixed.to(self.device)
        c_fixed_list = self.create_labels(c_org, self.c_dim, self.dataset,
                                          self.selected_attrs)

        # Learning rate cache for decaying.
        g_lr = self.g_lr
        d_lr = self.d_lr

        # Start training from scratch or resume training.
        start_iters = 0
        if self.resume_iters:
            start_iters = self.resume_iters
            self.restore_model(self.resume_iters)

        # Start training.
        print("Start training...")
        start_time = time.time()
        for i in range(start_iters, self.num_iters):

            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #

            # Fetch real images and labels.
            try:
                x_real, label_org = next(data_iter)
            except StopIteration:
                data_iter = iter(data_loader)
                x_real, label_org = next(data_iter)

            # Generate target domain labels randomly.
            rand_idx = torch.randperm(label_org.size(0))
            label_trg = label_org[rand_idx]

            if self.dataset == "CelebA":
                c_org = label_org.clone().to(self.device)
                c_trg = label_trg.clone().to(self.device)
                # Labels for computing the classification loss.
                label_org = label_org.to(self.device)
                label_trg = label_trg.to(self.device)
            elif self.dataset == "RaFD":
                c_org = self.label2onehot(label_org,
                                          self.c_dim).to(self.device)
                c_trg = self.label2onehot(label_trg,
                                          self.c_dim).to(self.device)
                # Labels for computing the classification loss.
                label_org = label_org.to(self.device)
                label_trg = label_trg.to(self.device)
            elif self.dataset == "AffectNet":
                if self.affectnet_emo_descr == "emotiw":
                    c_org = self.label2onehot(label_org,
                                              self.c_dim).to(self.device)
                    c_trg = self.label2onehot(label_trg,
                                              self.c_dim).to(self.device)
                    # Labels for computing the classification loss.
                    label_org = label_org.to(self.device)
                    label_trg = label_trg.to(self.device)
                elif self.affectnet_emo_descr == "va":
                    c_org = label_org.clone().to(self.device)
                    c_trg = label_trg.clone().to(self.device)
                    # Labels for computing the classification loss.
                    label_org = label_org.to(self.device)
                    label_trg = label_trg.to(self.device)
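                # The 64-d descriptors pack labels as [class | activation]
                # ("64d_cls") or [valence, arousal | activation] ("64d_reg");
                # split them into a prediction part and an activation part.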
                elif self.affectnet_emo_descr == "64d_cls":
                    pred_org = label_org[:, 0].clone().long().to(self.device)
                    pred_trg = label_trg[:, 0].clone().long().to(self.device)
                    actv_org = label_org[:, 1:].to(self.device)
                    actv_trg = label_trg[:, 1:].to(self.device)
                    c_org = actv_org.clone().to(self.device)
                    c_trg = actv_trg.clone().to(self.device)
                elif self.affectnet_emo_descr == "64d_reg":
                    pred_org = label_org[:, :2].clone().float().to(self.device)
                    pred_trg = label_trg[:, :2].clone().float().to(self.device)
                    actv_org = label_org[:, 2:].to(self.device)
                    actv_trg = label_trg[:, 2:].to(self.device)
                    c_org = actv_org.clone().to(self.device)
                    c_trg = actv_trg.clone().to(self.device)

            x_real = x_real.to(self.device)  # Input images.

            # =================================================================================== #
            #                             2. Train the discriminator                              #
            # =================================================================================== #

            # Compute loss with real images.
            out_src, out_cls = self.D(x_real)
            d_loss_real = -torch.mean(out_src)

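            # For the 64-d descriptors, classification_loss takes
            # (prediction, activation) tuples; self.D.fc projects the 64-d
            # activation onto the expression-prediction head, and
            # d_loss_cls_type selects which parts contribute to the loss.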
            if self.affectnet_emo_descr in ["64d_reg", "64d_cls"]:
                if self.d_loss_cls_type == "actv":
                    d_loss_cls = self.classification_loss(
                        (None, out_cls), (None, actv_org), self.dataset)
                elif self.d_loss_cls_type == "pred":
                    d_loss_cls = self.classification_loss(
                        (self.D.fc(out_cls), None), (pred_org, None),
                        self.dataset)
                else:
                    d_loss_cls = self.classification_loss(
                        (self.D.fc(out_cls), out_cls),
                        (pred_org, actv_org),
                        self.dataset,
                    )
            else:
                d_loss_cls = self.classification_loss(out_cls, label_org,
                                                      self.dataset)

            # Compute loss with fake images.
            x_fake = self.G(x_real, c_trg)
            out_src, out_cls = self.D(x_fake.detach())
            d_loss_fake = torch.mean(out_src)

            # Compute loss for gradient penalty.
            alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
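            # x_hat: a random convex combination of real and fake samples,
            # the interpolation points at which WGAN-GP evaluates the penalty.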
            x_hat = (alpha * x_real.data +
                     (1 - alpha) * x_fake.data).requires_grad_(True)
            out_src, _ = self.D(x_hat)
            d_loss_gp = self.gradient_penalty(out_src, x_hat)

            # Backward and optimize.
            d_loss = (d_loss_real + d_loss_fake +
                      self.lambda_cls * d_loss_cls +
                      self.lambda_gp * d_loss_gp)
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Logging.
            loss = {}
            loss["D/loss_real"] = d_loss_real.item()
            loss["D/loss_fake"] = d_loss_fake.item()
            loss["D/loss_cls"] = d_loss_cls.item()
            loss["D/loss_gp"] = d_loss_gp.item()

            # =================================================================================== #
            #                               3. Train the generator                                #
            # =================================================================================== #

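            # Update G only once every n_critic discriminator steps
            # (WGAN-style training schedule).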
            if (i + 1) % self.n_critic == 0:
                # Original-to-target domain.
                x_fake = self.G(x_real, c_trg)
                out_src, out_cls = self.D(x_fake)
                g_loss_fake = -torch.mean(out_src)

                if self.affectnet_emo_descr in ["64d_reg", "64d_cls"]:
                    if self.d_loss_cls_type == "actv":
                        g_loss_cls = self.classification_loss(
                            (None, out_cls), (None, actv_trg), self.dataset)
                    elif self.d_loss_cls_type == "pred":
                        g_loss_cls = self.classification_loss(
                            (self.D.fc(out_cls), None), (pred_trg, None),
                            self.dataset)
                    else:
                        g_loss_cls = self.classification_loss(
                            (self.D.fc(out_cls), out_cls),
                            (pred_trg, actv_trg),
                            self.dataset,
                        )
                else:
                    g_loss_cls = self.classification_loss(
                        out_cls, label_trg, self.dataset)

                # Target-to-original domain.
                x_reconst = self.G(x_fake, c_org)
                g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))

                # Backward and optimize.
                g_loss = (g_loss_fake + self.lambda_rec * g_loss_rec +
                          self.lambda_cls * g_loss_cls)
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Logging.
                loss["G/loss_fake"] = g_loss_fake.item()
                loss["G/loss_rec"] = g_loss_rec.item()
                loss["G/loss_cls"] = g_loss_cls.item()

            # =================================================================================== #
            #                                 4. Miscellaneous                                    #
            # =================================================================================== #

            # Print out training information.
            if (i + 1) % self.log_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(
                    et, i + 1, self.num_iters)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, i + 1)

            # Translate fixed images for debugging.
            if (i + 1) % self.sample_step == 0:
                with torch.no_grad():
                    x_fake_list = [x_fixed]
                    for c_fixed in c_fixed_list:
                        x_fake_list.append(self.G(x_fixed, c_fixed))
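                    # Concatenate along the width axis so each row shows the
                    # input followed by every translated version.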
                    x_concat = torch.cat(x_fake_list, dim=3)
                    sample_path = os.path.join(self.sample_dir,
                                               "{}-images.jpg".format(i + 1))
                    save_image(self.denorm(x_concat.data.cpu()),
                               sample_path,
                               nrow=1,
                               padding=0)
                    print("Saved real and fake images into {}...".format(
                        sample_path))

            # Save model checkpoints.
            if (i + 1) % self.model_save_step == 0:
                G_path = os.path.join(self.model_save_dir,
                                      "{}-G.ckpt".format(i + 1))
                D_path = os.path.join(self.model_save_dir,
                                      "{}-D.ckpt".format(i + 1))
                torch.save(self.G.state_dict(), G_path)
                torch.save(self.D.state_dict(), D_path)
                print("Saved model checkpoints into {}...".format(
                    self.model_save_dir))

            # Decay learning rates.
            if (i + 1) % self.lr_update_step == 0 and (i + 1) > (
                    self.num_iters - self.num_iters_decay):
                g_lr -= self.g_lr / float(self.num_iters_decay)
                d_lr -= self.d_lr / float(self.num_iters_decay)
                self.update_lr(g_lr, d_lr)
                print("Decayed learning rates, g_lr: {}, d_lr: {}.".format(
                    g_lr, d_lr))

    def test(self):
        """Translate images using StarGAN trained on a single dataset."""
        # Load the trained generator.
        self.restore_model(self.test_iters)

        # Set data loader.
        if self.dataset == "CelebA":
            data_loader = self.celeba_loader
        elif self.dataset == "RaFD":
            data_loader = self.rafd_loader
        elif self.dataset == "AffectNet":
            data_loader = self.affectnet_loader

        with torch.no_grad():
            for i, (x_real, c_org) in enumerate(data_loader):

                # Prepare input images and target domain labels.
                x_real = x_real.to(self.device)
                c_trg_list = self.create_labels(c_org, self.c_dim,
                                                self.dataset,
                                                self.selected_attrs)

                # Translate images.
                x_fake_list = [x_real]
                for c_trg in c_trg_list:
                    x_fake_list.append(self.G(x_real, c_trg))

                # Save the translated images.
                x_concat = torch.cat(x_fake_list, dim=3)
                result_path = os.path.join(self.result_dir,
                                           "{}-images.jpg".format(i + 1))
                save_image(self.denorm(x_concat.data.cpu()),
                           result_path,
                           nrow=1,
                           padding=0)
                print("Saved real and fake images into {}...".format(
                    result_path))
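

# ----------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The Solver is normally driven
# by main.py, which builds the config with argparse and the loaders with
# data_loader.get_loader; the import, the get_loader signature, and the
# config values below follow the upstream StarGAN repository and are
# assumptions -- this fork's entry point may differ.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    from argparse import Namespace

    from data_loader import get_loader  # assumed, as in upstream StarGAN

    config = Namespace(
        # Model.
        c_dim=5, c2_dim=8, image_size=128, g_conv_dim=64, d_conv_dim=64,
        g_repeat_num=6, d_repeat_num=6,
        # Loss weights.
        lambda_cls=1, lambda_rec=10, lambda_gp=10, lambda_cls2=1,
        # Training.
        dataset="CelebA", batch_size=16, num_iters=200000,
        num_iters_decay=100000, g_lr=0.0001, d_lr=0.0001, n_critic=5,
        beta1=0.5, beta2=0.999, resume_iters=None,
        selected_attrs=["Black_Hair", "Blond_Hair", "Brown_Hair", "Male",
                        "Young"],
        # Fork-specific options (values here are placeholders).
        affectnet_emo_descr="emotiw", use_ccc=False, depth_concat=False,
        d_loss_cls_type="both",
        # Testing, logging, and directories.
        test_iters=200000, use_tensorboard=False,
        log_dir="stargan/logs", sample_dir="stargan/samples",
        model_save_dir="stargan/models", result_dir="stargan/results",
        log_step=10, sample_step=1000, model_save_step=10000,
        lr_update_step=1000,
    )
    celeba_loader = get_loader(
        "data/celeba/images", "data/celeba/list_attr_celeba.txt",
        config.selected_attrs, 178, config.image_size, config.batch_size,
        "CelebA", "train", 1)
    solver = Solver(celeba_loader, None, None, config)
    solver.train()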