Exemplo n.º 1
0
    def build_model(self):
        """Create the generator, discriminator, mapping network and style encoder.

        Builds ``self.G``, ``self.D``, ``self.F`` (mapping network) and
        ``self.E`` (style encoder), moves them to ``self.device``, creates one
        Adam optimizer per network with betas ``(self.beta1, self.beta2)``,
        creates the L1 and BCE-with-logits criteria, and prints each network.

        Raises:
            ValueError: if ``self.dataset`` is not 'CelebA' or 'RaFD'. The
                former 'Both' branch never built E or F, so the optimizer
                setup that follows always failed for it.
        """
        if self.dataset not in ('CelebA', 'RaFD'):
            raise ValueError(
                "build_model: unsupported dataset {!r}; expected 'CelebA' or "
                "'RaFD'.".format(self.dataset))

        self.G = Generator(self.g_conv_dim)
        self.D = Discriminator(repeat_num=5,
                               channel_multiplier=32,
                               dimension=1)
        self.F = Mapping(image_size=128, repeat_num=6)
        self.E = StyleEncoder(repeat_num=5,
                              channel_multiplier=16,
                              dimension=64)
        # NOTE: custom weight initialisation (init_weights) stays disabled,
        # as in the previous version.

        # Put parameters on their final device before building optimizers
        # over them, as the torch.optim documentation recommends.
        for net in (self.G, self.D, self.E, self.F):
            net.to(self.device)

        betas = (self.beta1, self.beta2)
        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, betas)
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, betas)
        self.e_optimizer = torch.optim.Adam(self.E.parameters(), self.e_lr, betas)
        self.f_optimizer = torch.optim.Adam(self.F.parameters(), self.f_lr, betas)

        self.l1_loss = nn.L1Loss()
        self.bce_loss = nn.BCEWithLogitsLoss()

        for net, tag in ((self.G, 'G'), (self.D, 'D'), (self.E, 'E'),
                         (self.F, 'F')):
            self.print_network(net, tag)
Exemplo n.º 2
0
def _test_checkpoint_path(hparams, tag):
    """Return the checkpoint path read by ``test`` for network ``tag``.

    Layout: ``./models/<task_name>/<spec_opt>/model_<tag>_<iteration_num>.pth``.
    """
    return './models/{}/{}/model_{}_{}.pth'.format(
        hparams.task_name, hparams.spec_opt, tag, hparams.iteration_num)


def _test_decode_feature(normalizer, output):
    """Undo ``normalizer`` on a network output and return it as a NumPy array."""
    feature = normalizer.backward_process(output.squeeze().data)
    return feature.squeeze().data.cpu().numpy()


def _test_save_feature_image(output_dir, prefix, audio_name, feature):
    """Save ``feature`` transposed as ``<output_dir>/<prefix><stem>.png``.

    ``stem`` is ``audio_name`` up to its first '.'.
    """
    stem = audio_name.split('.')[0]
    imsave(os.path.join(output_dir, prefix + stem + '.png'),
           feature.transpose(),
           origin='lower')


def test(num_samples=10):
    """Render feature images produced by the A-content / B-generator path.

    Restores all eight networks saved at ``hparams.iteration_num``. For up to
    ``num_samples`` utterances of each test list, it encodes the features
    with ContEncoder_A and decodes them with generator_B. Images go to
    ``./output/<task_name>/<iteration_num>_img``:

    * list A: 'z-img-<stem>.png' (style from get_z_random_sparse) and
      '0-img-<stem>.png' (constant style tensor of 2s);
    * list B: 'z-img-<stem>.png' only.

    Args:
        num_samples: maximum number of utterances rendered per list; lists
            shorter than this are rendered completely instead of raising
            IndexError.
    """
    hparams = get_hparams()
    print(hparams.task_name)
    # NOTE(review): checkpoints are read from './models/...'; hparams.model_path
    # is not consulted (the old code computed it but never used it) — confirm.

    root = '../dataset/feat/test'
    list_dir_A = './etc/Test_dt05_real_isolated_1ch_track_list.csv'
    list_dir_B = './etc/Test_dt05_simu_isolated_1ch_track_list.csv'

    output_dir = './output/{}/{}_img'.format(hparams.task_name,
                                             hparams.iteration_num)
    os.makedirs(output_dir, exist_ok=True)

    normalizer_noisy = Tanhize('noisy')
    test_list_A, speaker_A = testset_list_classifier(root, list_dir_A)
    test_list_B, speaker_B = testset_list_classifier(root, list_dir_B)

    # Every checkpointed network is restored (as before), although only
    # generator_B and ContEncoder_A are used for inference below.
    factories = (
        ('gen_A', Generator), ('gen_B', Generator),
        ('dis_A', Discriminator), ('dis_B', Discriminator),
        ('ContEnc_A', ContentEncoder), ('ContEnc_B', ContentEncoder),
        ('StEnc_A', StyleEncoder), ('StEnc_B', StyleEncoder),
    )
    map_location = lambda storage, loc: storage
    models = {}
    for tag, factory in factories:
        model = nn.DataParallel(factory()).cuda()
        model.load_state_dict(
            torch.load(_test_checkpoint_path(hparams, tag),
                       map_location=map_location))
        model.eval()  # inference only: set once, not on every iteration
        models[tag] = model

    generator_B = models['gen_B']
    ContEncoder_A = models['ContEnc_A']

    # Rendering needs no gradients; skip building autograd graphs.
    with torch.no_grad():
        for i in range(min(num_samples, len(test_list_A))):
            feat = testset_loader(root,
                                  test_list_A[i],
                                  speaker_A,
                                  normalizer=normalizer_noisy)
            print(feat['audio_name'])

            A_cont = ContEncoder_A(
                torch.FloatTensor(feat['sp']).unsqueeze(0).cuda())

            z_st = get_z_random_sparse(1, 8, 1)
            st_0 = torch.ones((1, 8, 1)) * 2

            feature_z = _test_decode_feature(normalizer_noisy,
                                             generator_B(A_cont, z_st))
            feature_0 = _test_decode_feature(normalizer_noisy,
                                             generator_B(A_cont, st_0))

            _test_save_feature_image(output_dir, 'z-img-',
                                     feat['audio_name'], feature_z)
            _test_save_feature_image(output_dir, '0-img-',
                                     feat['audio_name'], feature_0)

        # NOTE(review): list B is also encoded with ContEncoder_A — confirm.
        for i in range(min(num_samples, len(test_list_B))):
            feat = testset_loader(root,
                                  test_list_B[i],
                                  speaker_B,
                                  normalizer=normalizer_noisy)
            print(feat['audio_name'])

            A_cont = ContEncoder_A(
                torch.FloatTensor(feat['sp']).unsqueeze(0).cuda())

            z_st = get_z_random_sparse(1, 8, 1)
            feature_z = _test_decode_feature(normalizer_noisy,
                                             generator_B(A_cont, z_st))
            _test_save_feature_image(output_dir, 'z-img-',
                                     feat['audio_name'], feature_z)
Exemplo n.º 3
0
class Solver(object):
    """
    Solver for training and testing StarGAN.
    """
    def __init__(self, data_loader, config):
        # def __init__(self, celeba_loader, rafd_loader, config):
        """Initialize configurations."""

        # Data loader.
        # self.celeba_loader = celeba_loader
        # self.rafd_loader = rafd_loader
        self.data_loader = data_loader

        # Model configurations.
        # self.c_dim = config.c_dim
        # self.c2_dim = config.c2_dim
        self.num_domains = config.num_domains
        self.image_size = config.image_size
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.g_repeat_num = config.g_repeat_num
        self.d_repeat_num = config.d_repeat_num
        self.lambda_cls = config.lambda_cls
        self.lambda_rec = config.lambda_rec
        self.lambda_gp = config.lambda_gp
        # self.lambda_sty = config.lambda_rec
        self.lambda_sty = 1
        self.lambda_ds = 1
        self.lambda_cyc = 1

        # Training configurations.
        self.dataset = config.dataset
        self.batch_size = config.batch_size
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.e_lr = config.e_lr
        self.f_lr = config.f_lr
        self.n_critic = config.n_critic
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.resume_iters = config.resume_iters
        self.selected_attrs = config.selected_attrs
        self.reg_param = 1

        # Test configurations.
        self.test_iters = config.test_iters

        # Miscellaneous.
        self.use_tensorboard = config.use_tensorboard
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        # Directories.
        self.log_dir = config.log_dir
        self.sample_dir = config.sample_dir
        self.model_save_dir = config.model_save_dir
        self.result_dir = config.result_dir

        # Step size.
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.lr_update_step = config.lr_update_step

        # Build the model and tensorboard.
        self.build_model()

        # if self.use_tensorboard:
        #     self.build_tensorboard()

    def ones_target(self, size, device="cuda"):
        """Tensor containing ones, with shape = size"""
        # data = Variable(torch.ones(size, 1))
        if device == "cuda":
            data = torch.ones((size, 1)).to(device)
        else:
            data = torch.ones((size, 1))
        return data

    def zeros_target(self, size, device="cuda"):
        """
        Tensor containing zeros, with shape = size
        """
        # data = Variable(torch.zeros(size, 1))
        if device == "cuda":
            data = torch.zeros((size, 1)).to(device)
        else:
            data = torch.zeros((size, 1))
        return data

    def build_model(self):
        """Create a generator and a discriminator."""
        if self.dataset in ['CelebA', 'RaFD']:
            # self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num)
            self.G = Generator(self.g_conv_dim)
            # self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num)
            self.D = Discriminator(repeat_num=5,
                                   channel_multiplier=32,
                                   dimension=1)

            self.F = Mapping(image_size=128, repeat_num=6)

            self.E = StyleEncoder(repeat_num=5,
                                  channel_multiplier=16,
                                  dimension=64)

            # # initialize the weights  of all modules using he init and set all biases to 0
            # self.G.apply(init_weights)
            # self.D.apply(init_weights)
            # self.E.apply(init_weights)
            # self.F.apply(init_weights)

        elif self.dataset in ['Both']:
            self.G = Generator(self.g_conv_dim, self.c_dim + self.c2_dim + 2,
                               self.g_repeat_num)  # 2 for mask vector.
            self.D = Discriminator(self.image_size, self.d_conv_dim,
                                   self.c_dim + self.c2_dim, self.d_repeat_num)

        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr,
                                            [self.beta1, self.beta2])

        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr,
                                            [self.beta1, self.beta2])
        # self.d_optimizer = torch.optim.SGD(self.D.parameters(), lr=0.0001, momentum=0.9)

        self.e_optimizer = torch.optim.Adam(self.E.parameters(), self.e_lr,
                                            [self.beta1, self.beta2])

        self.f_optimizer = torch.optim.Adam(self.F.parameters(), self.f_lr,
                                            [self.beta1, self.beta2])

        self.l1_loss = nn.L1Loss()
        self.bce_loss = nn.BCEWithLogitsLoss()
        # self.ce_loss = nn.CrossEntropyLoss()

        self.print_network(self.G, 'G')
        self.print_network(self.D, 'D')
        self.print_network(self.E, 'E')
        self.print_network(self.F, 'F')

        self.G.to(self.device)
        self.D.to(self.device)
        self.E.to(self.device)
        self.F.to(self.device)

    def print_network(self, model, name):
        """Print out the network information."""
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def restore_model(self, resume_iters):
        """
        Restore the trained generator and discriminator.
        """
        print(
            'Loading the trained models from step {}...'.format(resume_iters))
        G_path = os.path.join(self.model_save_dir,
                              '{}-G.ckpt'.format(resume_iters))
        D_path = os.path.join(self.model_save_dir,
                              '{}-D.ckpt'.format(resume_iters))
        E_path = os.path.join(self.model_save_dir,
                              '{}-E.ckpt'.format(resume_iters))
        F_path = os.path.join(self.model_save_dir,
                              '{}-F.ckpt'.format(resume_iters))

        self.G.load_state_dict(
            torch.load(G_path, map_location=lambda storage, loc: storage))
        self.D.load_state_dict(
            torch.load(D_path, map_location=lambda storage, loc: storage))
        self.E.load_state_dict(
            torch.load(E_path, map_location=lambda storage, loc: storage))
        self.F.load_state_dict(
            torch.load(F_path, map_location=lambda storage, loc: storage))

    def build_tensorboard(self):
        """Attach ``self.logger``, a project ``logger.Logger`` on ``self.log_dir``."""
        # Function-scope import: the ``logger`` module is only needed when
        # this method is actually called.
        from logger import Logger as TensorboardLogger
        self.logger = TensorboardLogger(self.log_dir)

    def update_lr(self, g_lr, d_lr):
        """
        Decay learning rates of the generator and discriminator.
        """
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    def noise(self, size, dimension, device="cuda"):
        """
        Generates a 1-d vector of gaussian sampled random values
        """
        # n = Variable(torch.randn(size, 100))
        # n = torch.randn((size, 100), requires_grad=True).to(device)
        if device == "cuda":
            n = torch.randn((size, dimension), requires_grad=True)
        else:
            n = torch.randn((size, dimension), requires_grad=True).to(device)
        return n

    def compute_grad2(self, d_out, x_in):
        """
        https://github.com/LMescheder/GAN_stability/blob/master/gan_training/train.py
        :param d_out: discriminator output
        :param x_in: x_real or x_fake
        :return:
        """
        batch_size = x_in.size(0)
        grad_dout = autograd.grad(outputs=d_out.sum(),
                                  inputs=x_in,
                                  create_graph=True,
                                  retain_graph=True,
                                  only_inputs=True)[0]
        grad_dout2 = grad_dout.pow(2)
        assert (grad_dout2.size() == x_in.size())
        reg = grad_dout2.view(batch_size, -1).sum(1)
        return reg

    def train_discriminator(self, x_real, x_fake, label_org, label_trg):
        """
        train discriminator
        @:param is_d is d or g
        :return:
        """
        self.d_optimizer.zero_grad()

        x_real.requires_grad = True

        # compute adversarial loss on real images
        # with torch.no_grad():
        out_real = torch.gather(self.D(x_real, num_domains=2), 1,
                                label_org.long())
        loss_real = self.bce_loss(out_real, self.ones_target(self.batch_size))
        # loss_real.backward()

        # R1 regularization only on real data
        # out_real.requires_grad = True
        # loss_real.backward(retain_graph=True)

        reg = self.reg_param * self.compute_grad2(out_real, x_real).mean()
        # reg.backward()

        # target style code s_tilde

        # Compute adversarial loss on fake images.
        # x_fake = self.G(x_real, s_tilde[0])
        # x_fake = self.G(x_real, s_tilde_trg)
        x_fake.detach()
        d = self.D(x_fake)

        out_fake = torch.gather(d, 1, label_trg.long())
        # d_loss_fake = torch.mean(out_fake)
        loss_fake = self.bce_loss(out_fake, self.zeros_target(self.batch_size))
        # loss_fake.backward()

        # self.d_optimizer.step()

        # Backward and optimize.
        # d_loss = loss_real + loss_fake
        d_loss = loss_real + loss_fake + reg
        d_loss.backward(retain_graph=True)
        self.d_optimizer.step()

        # l1_norm = torch.norm(self.D.weight, p=1)
        # d_loss += l1_norm

        return d_loss, loss_real, loss_fake

    def train_generator(self, x_real, g_x_fake, g_s_tilde_trg, label_org,
                        label_trg):
        """
        train generator
        :param x_real:
        :param label_org:
        :param label_trg:
        :return:
        """
        # clear cached gradients for optimizer
        self.g_optimizer.zero_grad()
        self.e_optimizer.zero_grad()
        self.f_optimizer.zero_grad()

        # style reconstruction
        # g_s_tilde_trg = self.generate_style_code(label_trg)
        # g_x_fake = self.G(x_real, g_s_tilde_trg)

        # self.G(x_real, g_s_tilde_trg)
        # s_hat = self.E(d_x_fake)

        # s_hat: estimated style code of source image
        # loss style reconstruction:style reconstruction
        s_hat_sty = self.E(g_x_fake, num_domains=self.num_domains)
        # s_hat_trg = torch.index_select(torch.stack(s_hat_sty, 1), 1, label_trg.squeeze().long())[:, 0, :]
        s_hat_trg = torch.squeeze(
            torch.stack([
                torch.index_select(x, 0, i) for x, i in zip(
                    torch.chunk(torch.stack(s_hat_sty, 1),
                                chunks=self.num_domains,
                                dim=1),
                    label_trg.squeeze().long())
            ]))

        g_loss_sty = self.l1_loss(g_s_tilde_trg, s_hat_trg)

        # loss cycle: preserving source characteristics
        s_hat_cyc = self.E(x_real, num_domains=self.num_domains)
        # s_hat_org = torch.index_select(torch.stack(s_hat_cyc, 1), 1, label_org.squeeze().long())[:, 0, :]
        s_hat_org = torch.squeeze(
            torch.stack([
                torch.index_select(x, 0, i) for x, i in zip(
                    torch.chunk(torch.stack(s_hat_cyc, 1),
                                chunks=self.num_domains,
                                dim=1),
                    label_org.squeeze().long())
            ]))

        x_fake_cyc = self.G(g_x_fake, s_hat_org)
        g_loss_cyc = self.l1_loss(x_real, x_fake_cyc)

        # loss style diversification:style diversification
        # z1 = self.noise(size=self.batch_size, dimension=16)
        # s1_tilde = self.F(z1, num_domains=2)
        # s1_tilde_trg = torch.index_select(torch.stack(s1_tilde, 1), 1, label_trg.squeeze().long())[:, 0, :]
        s1_tilde_trg = self.generate_style_code(label_trg)
        # z2 = self.noise(size=self.batch_size, dimension=16)
        # s2_tilde = self.F(z2, num_domains=2)
        # s2_tilde_trg = torch.index_select(torch.stack(s2_tilde, 1), 1, label_trg.squeeze().long())[:, 0, :]
        s2_tilde_trg = self.generate_style_code(label_trg)
        g_loss_ds = self.l1_loss(self.G(x_real, s1_tilde_trg),
                                 self.G(x_real, s2_tilde_trg))
        #
        # out_real = torch.gather(self.D(x_real, num_domains=2), 1, label_org.long())
        # loss_real = self.bce_loss(self.ones_target(self.batch_size), out_real)
        # target style code s_tilde

        # Compute loss with fake images.
        # x_fake = self.G(x_real, s_tilde[0])

        # compute adversarial loss on fake images
        # x_fake = self.G(x_real, g_s_tilde_trg)
        d = self.D(g_x_fake)
        out_fake = torch.gather(d, 1, label_trg.long())
        # d_loss_fake = torch.mean(out_fake)
        g_adv_loss = self.bce_loss(out_fake, self.ones_target(self.batch_size))

        g_loss = g_adv_loss + self.lambda_sty * g_loss_sty + self.lambda_cyc * g_loss_cyc - self.lambda_ds * g_loss_ds
        # g_loss = g_adv_loss

        # g_loss.backward(retain_graph=True)
        g_loss.backward()

        self.g_optimizer.step()
        self.e_optimizer.step()
        self.f_optimizer.step()

        return g_loss, g_adv_loss, g_loss_sty, g_loss_cyc, g_loss_ds
        # return g_loss

    def compute_adversarial_loss(self, is_d, x_real, label_org, s_tilde_trg,
                                 label_trg):
        """
        compute non-saturating adversarial loss
        @:param is_d is d or g
        :return:
        """
        out_real = torch.gather(self.D(x_real, num_domains=2), 1,
                                label_org.long())
        loss_real = self.bce_loss(self.ones_target(self.batch_size), out_real)
        # target style code s_tilde

        # Compute loss with fake images.
        # x_fake = self.G(x_real, s_tilde[0])
        x_fake = self.G(x_real, s_tilde_trg)
        if is_d:
            # d = self.D(x_fake).detach()
            x_fake.detach()
        d = self.D(x_fake)

        out_fake = torch.gather(d, 1, label_trg.long())
        # d_loss_fake = torch.mean(out_fake)
        loss_fake = self.bce_loss(self.zeros_target(self.batch_size), out_fake)

        # Backward and optimize.
        d_loss = loss_real + loss_fake

        # # R1 regularization
        # l1_norm = torch.norm(self.D.weight, p=1)
        # d_loss += l1_norm

        return d_loss, loss_real, loss_fake

    def reset_grad(self):
        """
        Reset the gradient buffers.
        """
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()
        self.e_optimizer.zero_grad()
        self.f_optimizer.zero_grad()

    def denorm(self, x):
        """
        Convert the range from [-1, 1] to [0, 1].
        """
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def gradient_penalty(self, y, x):
        """
        Compute gradient penalty: (L2_norm(dy/dx) - 1)**2.
        """
        weight = torch.ones(y.size()).to(self.device)
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]

        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
        return torch.mean((dydx_l2norm - 1)**2)

    def label2onehot(self, labels, dim):
        """
        Convert label indices to one-hot vectors.
        """
        batch_size = labels.size(0)
        out = torch.zeros(batch_size, dim)
        out[np.arange(batch_size), labels.long()] = 1
        return out

    def create_labels(self,
                      c_org,
                      c_dim=5,
                      dataset='CelebA',
                      selected_attrs=None):
        """
        Generate target domain labels for debugging and testing.
        """
        # Get hair color indices.
        if dataset == 'CelebA':
            hair_color_indices = []
            for i, attr_name in enumerate(selected_attrs):
                if attr_name in [
                        'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair'
                ]:
                    hair_color_indices.append(i)

        c_trg_list = []
        for i in range(c_dim):
            if dataset == 'CelebA':
                c_trg = c_org.clone()
                if i in hair_color_indices:  # Set one hair color to 1 and the rest to 0.
                    c_trg[:, i] = 1
                    for j in hair_color_indices:
                        if j != i:
                            c_trg[:, j] = 0
                else:
                    c_trg[:, i] = (c_trg[:,
                                         i] == 0)  # Reverse attribute value.
            elif dataset == 'RaFD':
                c_trg = self.label2onehot(torch.ones(c_org.size(0)) * i, c_dim)

            c_trg_list.append(c_trg.to(self.device))
        return c_trg_list

    def generate_style_code(self, label_trg, num_domains=2):
        """

        :param label_trg:
        :return:
        """
        z = self.noise(size=self.batch_size, dimension=16, device=self.device)
        s_tilde = self.F(z, num_domains=num_domains)

        # Compute loss with fake images.
        # s_tilde_trg = torch.index_select(torch.stack(s_tilde, 1), 1, label_trg.squeeze().long())[:, 0, :]
        s_tilde_trg = torch.squeeze(
            torch.stack([
                torch.index_select(x, 1, i) for x, i in zip(
                    torch.chunk(torch.stack(s_tilde, 1),
                                chunks=self.batch_size,
                                dim=0),
                    label_trg.squeeze().long())
            ]))
        return s_tilde_trg

    def get_reference_style(self, x_reference, label_trg):
        """
        get reference guided style code
        :param x_reference: reference image
        :param label_trg: target domain labels
        :return: s_hat_trg: style code in target domain
        """
        s_hat_sty = self.E(x_reference, num_domains=self.num_domains)
        # s_hat_trg = torch.index_select(torch.stack(s_hat_sty, 1), 1, label_trg.squeeze().long())[:, 0, :]
        s_hat_trg = torch.squeeze(
            torch.stack([
                torch.index_select(x, 1, i) for x, i in zip(
                    torch.chunk(torch.stack(s_hat_sty, 1),
                                chunks=self.batch_size,
                                dim=0),
                    label_trg.squeeze().long())
            ]))
        return s_hat_trg

    def classification_loss(self, logit, target, dataset='CelebA'):
        """
        Compute binary or softmax cross entropy loss.
        """
        if dataset == 'CelebA':
            return F.binary_cross_entropy_with_logits(
                logit, target, size_average=False) / logit.size(0)
        elif dataset == 'RaFD':
            return F.cross_entropy(logit, target)

    def train(self):
        """
        Train StarGAN within a single dataset.
        """
        # # Set data loader.
        # if self.dataset == 'CelebA':
        #     data_loader = self.celeba_loader
        # elif self.dataset == 'RaFD':
        #     data_loader = self.rafd_loader

        # Fetch fixed inputs for debugging.
        data_iter = iter(self.data_loader)
        x_fixed, c_org = next(data_iter)
        x_fixed = x_fixed.to(self.device)
        # c_fixed_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs)

        # Learning rate cache for decaying.
        # g_lr = self.g_lr
        # d_lr = self.d_lr

        # Start training from scratch or resume training.
        start_iters = 0
        if self.resume_iters:
            start_iters = self.resume_iters
            self.restore_model(self.resume_iters)

        # Start training.
        print('Start training...')
        start_time = time.time()
        for i in range(start_iters, self.num_iters):

            try:
                # =================================================================================== #
                #                             1. Preprocess input data                                #
                # =================================================================================== #

                # Fetch real images and labels.
                try:
                    x_real, label_org = next(data_iter)
                    label_org = torch.unsqueeze(label_org, dim=1)
                except Exception as e:
                    print(str(e))
                    data_iter = iter(self.data_loader)
                    x_real, label_org = next(data_iter)

                # Generate target domain labels randomly.
                rand_idx = torch.randperm(label_org.size(0))
                label_trg = label_org[rand_idx]

                # if self.dataset == 'CelebA':
                #     c_org = label_org.clone()
                #     c_trg = label_trg.clone()
                # elif self.dataset == 'RaFD':
                #     c_org = self.label2onehot(label_org, self.c_dim)
                #     c_trg = self.label2onehot(label_trg, self.c_dim)

                x_real = x_real.to(self.device)  # Input images.
                # c_org = c_org.to(self.device)  # Original domain labels.
                # c_trg = c_trg.to(self.device)  # Target domain labels.
                label_org = label_org.to(
                    self.device)  # Labels for computing classification loss.
                label_trg = label_trg.to(
                    self.device)  # Labels for computing classification loss.

                # =================================================================================== #
                #                             2. Train the discriminator                              #
                # =================================================================================== #
                # self.d_optimizer.zero_grad()
                # Compute loss with real images.
                # out_src, out_cls = self.D(x_real)
                # out_real = self.D(x_real, num_domains=2)

                # d_out_real = torch.gather(out_real, 1, label_org.long())
                # d_loss_real = torch.mean(torch.log(d_out_src))

                # d_loss_real = self.bce_loss(self.zeros_target(self.batch_size), d_out_real)
                # d_loss_cls = self.classification_loss(out_cls, label_org, self.dataset)
                #
                # # z:latent code s_tilde:target style code
                # z = self.noise(size=self.batch_size, dimension=16, device=self.device)
                # s_tilde = self.F(z, num_domains=2)
                #
                # # target style code s_tilde
                #
                # # Compute loss with fake images.
                # # s_tilde_tensor = torch.stack(s_tilde, 1)
                # s_tilde_trg = torch.index_select(torch.stack(s_tilde, 1), 1, label_trg.squeeze().long())[:, 0, :]
                # # s_tilde_trg = torch.gather(s_tilde_tensor, 1, label_trg.expand(s_tilde_tensor.size()).long())
                # # s_tilde_trg = torch.gather(torch.stack(s_tilde, 1), 1, torch.unsqueeze(label_trg, 2).long())
                #
                # d_x_fake = self.G(x_real, s_tilde_trg)
                # out_fake = self.D(d_x_fake.detach())
                # d_out_fake = torch.gather(out_fake, 1, label_trg.long())
                # # d_loss_fake = torch.mean(out_fake)
                # d_loss_fake = self.bce_loss(self.ones_target(self.batch_size), d_out_fake)

                # # Compute loss for gradient penalty.
                # alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
                # x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)
                # out_src, _ = self.D(x_hat)
                # d_loss_gp = self.gradient_penalty(out_src, x_hat)

                # # Backward and optimize.
                # d_loss = -(d_loss_real + d_loss_fake)
                #
                # # R1 regularization
                # l1_norm = torch.norm(self.D.weight, p=1)
                # d_loss += l1_norm

                s_tilde_trg = self.generate_style_code(label_trg)
                x_fake = self.G(x_real, s_tilde_trg)
                # fake_logits = self.D(x_fake)

                # d_loss, d_loss_real, d_loss_fake = self.compute_adversarial_loss(True, x_real, label_org, d_x_fake, label_trg)
                d_loss, d_loss_real, d_loss_fake = self.train_discriminator(
                    x_real, x_fake, label_org, label_trg)

                # d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls + self.lambda_gp * d_loss_gp
                # d_loss = -d_loss
                # self.reset_grad()
                # d_loss.backward()
                # self.d_optimizer.step()

                # Logging.
                loss = {
                    'D/loss': d_loss.item(),
                    'D/loss_real': d_loss_real.item(),
                    'D/loss_fake': d_loss_fake.item()
                }

                writer.add_scalar('D/loss', d_loss.item(), i)
                writer.add_scalar('D/loss_real', d_loss_real.item(), i)
                writer.add_scalar('D/loss_fake', d_loss_fake.item(), i)

                # =================================================================================== #
                #                               3. Train the generator                                #
                # =================================================================================== #

                if (i + 1) % self.n_critic == 0:
                    # # style reconstruction
                    # g_s_tilde_trg = self.generate_style_code(label_trg)
                    # # g_x_fake = self.G(x_real, g_s_tilde_trg)
                    # # s_hat = self.E(d_x_fake)
                    #
                    # # s_hat: estimated style code of source image
                    # # loss style reconstruction:style reconstruction
                    # s_hat = self.E(self.G(x_real, g_s_tilde_trg), num_domains=2)
                    # s_hat_trg = torch.index_select(torch.stack(s_hat, 1), 1, label_trg.squeeze().long())[:, 0, :]
                    # g_loss_sty = self.l1_loss(g_s_tilde_trg, s_hat_trg)
                    #
                    # # loss cycle: preserving source characteristics
                    # s_hat_org = torch.index_select(torch.stack(s_hat, 1), 1, label_org.squeeze().long())[:, 0, :]
                    # x_fake_cyc = self.G(self.G(x_real, g_s_tilde_trg), s_hat_org)
                    # g_loss_cyc = self.l1_loss(x_real, x_fake_cyc)
                    #
                    # # loss style diversification:style diversification
                    # # z1 = self.noise(size=self.batch_size, dimension=16)
                    # # s1_tilde = self.F(z1, num_domains=2)
                    # # s1_tilde_trg = torch.index_select(torch.stack(s1_tilde, 1), 1, label_trg.squeeze().long())[:, 0, :]
                    # s1_tilde_trg = self.generate_style_code(label_trg)
                    # # z2 = self.noise(size=self.batch_size, dimension=16)
                    # # s2_tilde = self.F(z2, num_domains=2)
                    # # s2_tilde_trg = torch.index_select(torch.stack(s2_tilde, 1), 1, label_trg.squeeze().long())[:, 0, :]
                    # s2_tilde_trg = self.generate_style_code(label_trg)
                    # g_loss_ds = self.l1_loss(self.G(x_real, s1_tilde_trg), self.G(x_real, s2_tilde_trg))

                    # # Original-to-target domain.
                    # x_fake = self.G(x_real, c_trg)
                    # out_src, out_cls = self.D(x_fake)
                    # g_loss_fake = - torch.mean(out_src)
                    # g_loss_cls = self.classification_loss(out_cls, label_trg, self.dataset)
                    #
                    # # Target-to-original domain.
                    # x_reconst = self.G(x_fake, c_org)
                    # g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))

                    # out_real = self.D(x_real, num_domains=2)
                    # g_out_real = torch.gather(out_real, 1, label_org.long())
                    # g_loss_real = self.bce_loss(self.zeros_target(self.batch_size), g_out_real)
                    # # target style code s_tilde
                    #
                    # # Compute loss with fake images.
                    # # x_fake = self.G(x_real, s_tilde[0])
                    # out_fake = self.D(g_x_fake)
                    # g_out_fake = torch.gather(out_fake, 1, label_org.long())
                    # # d_loss_fake = torch.mean(out_fake)
                    # g_loss_fake = self.bce_loss(self.ones_target(self.batch_size), g_out_fake)

                    g_loss, g_adv_loss, g_loss_sty, g_loss_cyc, g_loss_ds = self.train_generator(
                        x_real, x_fake, s_tilde_trg, label_org, label_trg)
                    # g_loss = self.train_generator(x_real, label_org, label_trg)

                    # g_adv_loss = self.compute_adversarial_loss(False, x_real, label_org, g_s_tilde_trg, label_trg)[0]

                    # Backward and optimize.
                    # g_loss = g_adv_loss + self.lambda_sty * g_loss_sty + self.lambda_cyc * g_loss_cyc + self.lambda_ds * g_loss_ds
                    # self.reset_grad()
                    # g_loss.backward()
                    # self.g_optimizer.step()
                    # self.e_optimizer.step()
                    # self.f_optimizer.step()

                    # Logging.
                    # loss['G/loss_fake'] = g_loss_fake.item()
                    # loss['G/loss_sty'] = g_loss_sty.item()
                    # loss['G/loss_cyc'] = g_loss_cyc.item()
                    # loss['G/loss_ds'] = g_loss_ds.item()
                    loss['G/loss'] = g_loss.item()
                    loss['G/loss_adv'] = g_adv_loss.item()
                    loss['G/loss_sty'] = g_loss_sty.item()
                    loss['G/loss_cyc'] = g_loss_cyc.item()
                    loss['G/loss_ds'] = g_loss_ds.item()

                    # writer.add_scalar('G/loss_cyc', g_loss_cyc.item(), i)
                    # writer.add_scalar('G/loss_ds', g_loss_ds.item(), i)
                    writer.add_scalar('G/loss', g_loss.item(), i)
                    writer.add_scalar('G/loss_adv', g_adv_loss.item(), i)
                    writer.add_scalar('G/loss_sty', g_loss_sty.item(), i)
                    writer.add_scalar('G/loss_cyc', g_loss_cyc.item(), i)
                    writer.add_scalar('G/loss_ds', g_loss_ds.item(), i)

                # =================================================================================== #
                #                                 4. Miscellaneous                                    #
                # =================================================================================== #

                # Print out training information.
                if (i + 1) % self.log_step == 0:
                    et = time.time() - start_time
                    et = str(datetime.timedelta(seconds=et))[:-7]
                    log = "Elapsed [{}], Iteration [{}/{}]".format(
                        et, i + 1, self.num_iters)
                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)

                    # if self.use_tensorboard:
                    #     for tag, value in loss.items():
                    #         self.logger.scalar_summary(tag, value, i + 1)

                # Translate fixed images for debugging.
                if (i + 1) % self.sample_step == 0:
                    with torch.no_grad():
                        # random style: source images + generated images
                        g_s_tilde_trg = self.generate_style_code(label_trg)

                        # reference guided style: x_real as the reference image
                        ref_style = self.get_reference_style(x_real, label_org)

                        x_fake_list = [
                            x_fixed,
                            self.G(x_fixed, g_s_tilde_trg), x_real,
                            self.G(x_fixed, ref_style)
                        ]

                        # for c_fixed in label_org:
                        x_concat = torch.cat(x_fake_list, dim=3)
                        sample_path = os.path.join(
                            self.sample_dir, '{}-images.jpg'.format(i + 1))
                        save_image(self.denorm(x_concat.data.cpu()),
                                   sample_path,
                                   nrow=1,
                                   padding=0)
                        print('Saved real and fake images into {}...'.format(
                            sample_path))

                        grid = torchvision.utils.make_grid(x_concat)
                        writer.add_image('images', grid, 0)
                        # writer.add_graph(model, images)

                # Save model checkpoints.
                if (i + 1) % self.model_save_step == 0:
                    G_path = os.path.join(self.model_save_dir,
                                          '{}-G.ckpt'.format(i + 1))
                    D_path = os.path.join(self.model_save_dir,
                                          '{}-D.ckpt'.format(i + 1))
                    E_path = os.path.join(self.model_save_dir,
                                          '{}-E.ckpt'.format(i + 1))
                    F_path = os.path.join(self.model_save_dir,
                                          '{}-F.ckpt'.format(i + 1))
                    torch.save(self.G.state_dict(), G_path)
                    torch.save(self.D.state_dict(), D_path)
                    torch.save(self.E.state_dict(), E_path)
                    torch.save(self.F.state_dict(), F_path)
                    print('Saved model checkpoints into {}...'.format(
                        self.model_save_dir))
                #
                # # Decay learning rates.
                # if (i + 1) % self.lr_update_step == 0 and (i + 1) > (self.num_iters - self.num_iters_decay):
                #     g_lr -= (self.g_lr / float(self.num_iters_decay))
                #     d_lr -= (self.d_lr / float(self.num_iters_decay))
                #     self.update_lr(g_lr, d_lr)
                #     print('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))

                # Decay weight lambda ds
                if (i + 1) < self.num_iters_decay:
                    self.lambda_ds = 1 - 0.00002 * (i + 1)
                    # print('Decayed weight lambda ds , lambda_ds: {}'.format(self.lambda_ds))
            except Exception as e:
                print(str(e))

        # close the tensorboard summary writter
        writer.close()

    def train_multi(self):
        """Train StarGAN with multiple datasets (CelebA and RaFD).

        Every iteration runs one discriminator step per dataset (CelebA first,
        then RaFD) and, on every ``self.n_critic``-th iteration, one generator
        step per dataset. The generator is conditioned on a vector laid out as
        ``[CelebA labels (c_dim) | RaFD one-hot (c2_dim) | 2-way mask]`` so a
        single model serves both datasets; the mask marks which part is valid.

        Discriminator loss:
            -mean(D_src(x_real)) + mean(D_src(G(x_real, c_trg)))
            + lambda_cls * cls_loss(real) + lambda_gp * gradient_penalty
        Generator loss:
            -mean(D_src(G(x_real, c_trg))) + lambda_rec * L1(x_real, x_reconst)
            + lambda_cls * cls_loss(fake)

        After both datasets have been processed for an iteration, the method
        optionally saves CelebA sample translations, model checkpoints, and
        linearly decays the learning rates via ``self.update_lr``.
        """
        # Data iterators; each one is restarted when its loader is exhausted.
        celeba_iter = iter(self.celeba_loader)
        rafd_iter = iter(self.rafd_loader)

        # Fetch fixed inputs for debugging (sample images use CelebA targets only).
        x_fixed, c_org = next(celeba_iter)
        x_fixed = x_fixed.to(self.device)
        c_celeba_list = self.create_labels(c_org, self.c_dim, 'CelebA',
                                           self.selected_attrs)
        zero_rafd = torch.zeros(x_fixed.size(0), self.c2_dim).to(
            self.device)  # Zero vector for RaFD.
        mask_celeba = self.label2onehot(torch.zeros(x_fixed.size(0)), 2).to(
            self.device)  # Mask vector: [1, 0].

        # Learning rate cache for decaying.
        g_lr = self.g_lr
        d_lr = self.d_lr

        # Start training from scratch or resume training.
        start_iters = 0
        if self.resume_iters:
            start_iters = self.resume_iters
            self.restore_model(self.resume_iters)

        # Start training.
        print('Start training...')
        start_time = time.time()
        for i in range(start_iters, self.num_iters):
            for dataset in ['CelebA', 'RaFD']:

                # =================================================================================== #
                #                             1. Preprocess input data                                #
                # =================================================================================== #

                # Fetch real images and labels. Only end-of-epoch is handled
                # here; genuine loader errors must propagate instead of being
                # masked by a silent iterator restart.
                data_iter = celeba_iter if dataset == 'CelebA' else rafd_iter
                try:
                    x_real, label_org = next(data_iter)
                except StopIteration:
                    if dataset == 'CelebA':
                        celeba_iter = iter(self.celeba_loader)
                        data_iter = celeba_iter
                    else:
                        rafd_iter = iter(self.rafd_loader)
                        data_iter = rafd_iter
                    x_real, label_org = next(data_iter)

                # Generate target domain labels randomly by shuffling the
                # batch's own labels.
                rand_idx = torch.randperm(label_org.size(0))
                label_trg = label_org[rand_idx]

                # Build the joint conditioning vectors; the other dataset's
                # slot is zero-filled and the mask identifies the source.
                if dataset == 'CelebA':
                    c_org = label_org.clone()
                    c_trg = label_trg.clone()
                    zero = torch.zeros(x_real.size(0), self.c2_dim)
                    mask = self.label2onehot(torch.zeros(x_real.size(0)), 2)
                    c_org = torch.cat([c_org, zero, mask], dim=1)
                    c_trg = torch.cat([c_trg, zero, mask], dim=1)
                elif dataset == 'RaFD':
                    c_org = self.label2onehot(label_org, self.c2_dim)
                    c_trg = self.label2onehot(label_trg, self.c2_dim)
                    zero = torch.zeros(x_real.size(0), self.c_dim)
                    mask = self.label2onehot(torch.ones(x_real.size(0)), 2)
                    c_org = torch.cat([zero, c_org, mask], dim=1)
                    c_trg = torch.cat([zero, c_trg, mask], dim=1)

                x_real = x_real.to(self.device)  # Input images.
                c_org = c_org.to(self.device)  # Original domain labels.
                c_trg = c_trg.to(self.device)  # Target domain labels.
                label_org = label_org.to(
                    self.device)  # Labels for computing classification loss.
                label_trg = label_trg.to(
                    self.device)  # Labels for computing classification loss.

                # =================================================================================== #
                #                             2. Train the discriminator                              #
                # =================================================================================== #

                # Compute loss with real images. D's classification output
                # covers both datasets; keep only the current dataset's logits.
                out_src, out_cls = self.D(x_real)
                out_cls = (out_cls[:, :self.c_dim] if dataset == 'CelebA'
                           else out_cls[:, self.c_dim:])
                d_loss_real = -torch.mean(out_src)
                d_loss_cls = self.classification_loss(out_cls, label_org,
                                                      dataset)

                # Compute loss with fake images (detached: no G update here).
                x_fake = self.G(x_real, c_trg)
                out_src, _ = self.D(x_fake.detach())
                d_loss_fake = torch.mean(out_src)

                # Compute loss for gradient penalty on random interpolations
                # between real and fake images.
                alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
                x_hat = (alpha * x_real.detach() +
                         (1 - alpha) * x_fake.detach()).requires_grad_(True)
                out_src, _ = self.D(x_hat)
                d_loss_gp = self.gradient_penalty(out_src, x_hat)

                # Backward and optimize.
                d_loss = (d_loss_real + d_loss_fake +
                          self.lambda_cls * d_loss_cls +
                          self.lambda_gp * d_loss_gp)
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # Logging.
                loss = {}
                loss['D/loss_real'] = d_loss_real.item()
                loss['D/loss_fake'] = d_loss_fake.item()
                loss['D/loss_cls'] = d_loss_cls.item()
                loss['D/loss_gp'] = d_loss_gp.item()

                # =================================================================================== #
                #                               3. Train the generator                                #
                # =================================================================================== #

                if (i + 1) % self.n_critic == 0:
                    # Original-to-target domain.
                    x_fake = self.G(x_real, c_trg)
                    out_src, out_cls = self.D(x_fake)
                    out_cls = (out_cls[:, :self.c_dim] if dataset == 'CelebA'
                               else out_cls[:, self.c_dim:])
                    g_loss_fake = -torch.mean(out_src)
                    g_loss_cls = self.classification_loss(
                        out_cls, label_trg, dataset)

                    # Target-to-original domain (cycle reconstruction).
                    x_reconst = self.G(x_fake, c_org)
                    g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))

                    # Backward and optimize.
                    g_loss = (g_loss_fake + self.lambda_rec * g_loss_rec +
                              self.lambda_cls * g_loss_cls)
                    self.reset_grad()
                    g_loss.backward()
                    self.g_optimizer.step()

                    # Logging.
                    loss['G/loss_fake'] = g_loss_fake.item()
                    loss['G/loss_rec'] = g_loss_rec.item()
                    loss['G/loss_cls'] = g_loss_cls.item()

                # =================================================================================== #
                #                                 4. Miscellaneous                                    #
                # =================================================================================== #

                # Print out training info.
                if (i + 1) % self.log_step == 0:
                    et = time.time() - start_time
                    et = str(datetime.timedelta(seconds=et))[:-7]
                    log = "Elapsed [{}], Iteration [{}/{}], Dataset [{}]".format(
                        et, i + 1, self.num_iters, dataset)
                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)

                    if self.use_tensorboard:
                        for tag, value in loss.items():
                            self.logger.scalar_summary(tag, value, i + 1)

            # Translate fixed images for debugging (CelebA targets only).
            if (i + 1) % self.sample_step == 0:
                with torch.no_grad():
                    x_fake_list = [x_fixed]
                    for c_fixed in c_celeba_list:
                        c_trg = torch.cat([c_fixed, zero_rafd, mask_celeba],
                                          dim=1)
                        x_fake_list.append(self.G(x_fixed, c_trg))
                    x_concat = torch.cat(x_fake_list, dim=3)
                    sample_path = os.path.join(self.sample_dir,
                                               '{}-images.jpg'.format(i + 1))
                    save_image(self.denorm(x_concat.data.cpu()),
                               sample_path,
                               nrow=1,
                               padding=0)
                    print('Saved real and fake images into {}...'.format(
                        sample_path))

            # Save model checkpoints.
            if (i + 1) % self.model_save_step == 0:
                G_path = os.path.join(self.model_save_dir,
                                      '{}-G.ckpt'.format(i + 1))
                D_path = os.path.join(self.model_save_dir,
                                      '{}-D.ckpt'.format(i + 1))
                torch.save(self.G.state_dict(), G_path)
                torch.save(self.D.state_dict(), D_path)
                print('Saved model checkpoints into {}...'.format(
                    self.model_save_dir))

            # Linearly decay learning rates over the last num_iters_decay
            # iterations.
            if (i + 1) % self.lr_update_step == 0 and (i + 1) > (
                    self.num_iters - self.num_iters_decay):
                g_lr -= (self.g_lr / float(self.num_iters_decay))
                d_lr -= (self.d_lr / float(self.num_iters_decay))
                self.update_lr(g_lr, d_lr)
                print('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(
                    g_lr, d_lr))

    def test(self):
        """Translate images using StarGAN v2 trained on a single dataset.

        Restores the checkpoint for ``self.test_iters`` and, for every batch of
        the loader that matches ``self.dataset``, saves one image file in
        ``self.result_dir``. Each file holds the real images followed by one
        translation per target label returned by ``self.create_labels``.

        Raises:
            ValueError: if ``self.dataset`` is neither ``'CelebA'`` nor
                ``'RaFD'`` (e.g. ``'Both'``, which has no single loader).
        """
        # Load the trained generator.
        self.restore_model(self.test_iters)

        # Set data loader; reject unsupported datasets explicitly instead of
        # falling through with `data_loader` unbound.
        if self.dataset == 'CelebA':
            data_loader = self.celeba_loader
        elif self.dataset == 'RaFD':
            data_loader = self.rafd_loader
        else:
            raise ValueError(
                "test() supports only the 'CelebA' or 'RaFD' datasets, "
                "got {!r}".format(self.dataset))

        with torch.no_grad():
            for i, (x_real, c_org) in enumerate(data_loader):

                # Prepare input images and target domain labels.
                x_real = x_real.to(self.device)
                c_trg_list = self.create_labels(c_org, self.c_dim,
                                                self.dataset,
                                                self.selected_attrs)

                # Translate images.
                # NOTE(review): the single-dataset training loop in this file
                # calls self.G with a style code from generate_style_code,
                # whereas here G receives label vectors from create_labels —
                # confirm this matches the generator's expected input.
                x_fake_list = [x_real]
                for c_trg in c_trg_list:
                    x_fake_list.append(self.G(x_real, c_trg))

                # Save the translated images side by side (concatenated along width).
                x_concat = torch.cat(x_fake_list, dim=3)
                result_path = os.path.join(self.result_dir,
                                           '{}-images.jpg'.format(i + 1))
                save_image(self.denorm(x_concat.data.cpu()),
                           result_path,
                           nrow=1,
                           padding=0)
                print('Saved real and fake images into {}...'.format(
                    result_path))

    def test_multi(self):
        """Render every CelebA test batch under all CelebA and RaFD targets.

        After restoring the generator checkpoint for ``self.test_iters``, each
        CelebA batch produces one image in ``self.result_dir``: the real images
        followed by one translation per CelebA target label, then one per RaFD
        target label. Generator conditions are laid out as
        ``[CelebA labels | RaFD labels | 2-way dataset mask]``, with the slot
        of the inactive dataset zero-filled.
        """
        self.restore_model(self.test_iters)

        with torch.no_grad():
            for batch_idx, (images, labels) in enumerate(self.celeba_loader):
                images = images.to(self.device)
                n = images.size(0)

                celeba_targets = self.create_labels(labels, self.c_dim,
                                                    'CelebA',
                                                    self.selected_attrs)
                rafd_targets = self.create_labels(labels, self.c2_dim, 'RaFD')

                # Per-batch filler / mask tensors (the last batch may be smaller).
                blank_celeba = torch.zeros(n, self.c_dim).to(self.device)
                blank_rafd = torch.zeros(n, self.c2_dim).to(self.device)
                celeba_flag = self.label2onehot(torch.zeros(n),
                                                2).to(self.device)  # [1, 0]
                rafd_flag = self.label2onehot(torch.ones(n),
                                              2).to(self.device)  # [0, 1]

                conditions = [
                    torch.cat([target, blank_rafd, celeba_flag], dim=1)
                    for target in celeba_targets
                ]
                conditions.extend(
                    torch.cat([blank_celeba, target, rafd_flag], dim=1)
                    for target in rafd_targets)

                panels = [images]
                panels.extend(self.G(images, cond) for cond in conditions)

                # Lay the panels out horizontally and write them to disk.
                strip = torch.cat(panels, dim=3)
                out_file = os.path.join(self.result_dir,
                                        '{}-images.jpg'.format(batch_idx + 1))
                save_image(self.denorm(strip.data.cpu()),
                           out_file,
                           nrow=1,
                           padding=0)
                print('Saved real and fake images into {}...'.format(out_file))
Exemplo n.º 4
0
def train():

    hparams = get_hparams()
    model_path = os.path.join(hparams.model_path, hparams.task_name,
                              hparams.spec_opt)
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    # Load Dataset Loader

    normalizer_clean = Tanhize('clean')
    normalizer_noisy = Tanhize('noisy')

    print('Load dataset2d loader')
    dataset_A_2d = npyDataset2d(hparams.dataset_root,
                                hparams.list_dir_train_A_2d,
                                hparams.frame_len,
                                normalizer=normalizer_noisy)
    dataset_B_2d = npyDataset2d(hparams.dataset_root,
                                hparams.list_dir_train_B_2d,
                                hparams.frame_len,
                                normalizer=normalizer_clean)

    dataloader_A = DataLoader(
        dataset_A_2d,
        batch_size=hparams.batch_size,
        shuffle=True,
        drop_last=True,
    )
    dataloader_B = DataLoader(
        dataset_B_2d,
        batch_size=hparams.batch_size,
        shuffle=True,
        drop_last=True,
    )

    # Load Generator / Disciminator model
    generator_A = Generator()
    generator_B = Generator()

    discriminator_A = Discriminator()
    discriminator_B = Discriminator()

    ContEncoder_A = ContentEncoder()
    ContEncoder_B = ContentEncoder()

    StEncoder_A = StyleEncoder()
    StEncoder_B = StyleEncoder()

    generator_A.apply(weights_init)
    generator_B.apply(weights_init)

    discriminator_A.apply(weights_init)
    discriminator_B.apply(weights_init)

    ContEncoder_A.apply(weights_init)
    ContEncoder_B.apply(weights_init)

    StEncoder_A.apply(weights_init)
    StEncoder_B.apply(weights_init)

    real_label = 1
    fake_label = 0
    real_tensor = Variable(torch.FloatTensor(hparams.batch_size))
    _ = real_tensor.data.fill_(real_label)

    fake_tensor = Variable(torch.FloatTensor(hparams.batch_size))
    _ = fake_tensor.data.fill_(fake_label)

    # Define Loss function
    d = nn.MSELoss()
    bce = nn.BCELoss()

    # Cuda Process
    if hparams.cuda == True:
        print('-- Activate with CUDA --')

        generator_A = nn.DataParallel(generator_A).cuda()
        generator_B = nn.DataParallel(generator_B).cuda()
        discriminator_A = nn.DataParallel(discriminator_A).cuda()
        discriminator_B = nn.DataParallel(discriminator_B).cuda()
        ContEncoder_A = nn.DataParallel(ContEncoder_A).cuda()
        ContEncoder_B = nn.DataParallel(ContEncoder_B).cuda()
        StEncoder_A = nn.DataParallel(StEncoder_A).cuda()
        StEncoder_B = nn.DataParallel(StEncoder_B).cuda()

        d.cuda()
        bce.cuda()
        real_tensor = real_tensor.cuda()
        fake_tensor = fake_tensor.cuda()

    else:
        print('-- Activate without CUDA --')

    gen_params = chain(
        generator_A.parameters(),
        generator_B.parameters(),
        ContEncoder_A.parameters(),
        ContEncoder_B.parameters(),
        StEncoder_A.parameters(),
        StEncoder_B.parameters(),
    )

    dis_params = chain(
        discriminator_A.parameters(),
        discriminator_B.parameters(),
    )

    optimizer_g = optim.Adam(gen_params, lr=hparams.learning_rate)
    optimizer_d = optim.Adam(dis_params, lr=hparams.learning_rate)

    iters = 0
    for e in range(hparams.epoch_size):

        # input Tensor

        A_loader, B_loader = iter(dataloader_A), iter(dataloader_B)

        for i in range(len(A_loader) - 1):

            batch_A = A_loader.next()
            batch_B = B_loader.next()

            A_indx = torch.LongTensor(list(range(hparams.batch_size)))
            B_indx = torch.LongTensor(list(range(hparams.batch_size)))

            A_ = torch.FloatTensor(batch_A)
            B_ = torch.FloatTensor(batch_B)

            if hparams.cuda == True:

                x_A = Variable(A_.cuda())
                x_B = Variable(B_.cuda())

            else:
                x_A = Variable(A_)
                x_B = Variable(B_)

            real_tensor.data.resize_(hparams.batch_size).fill_(real_label)
            fake_tensor.data.resize_(hparams.batch_size).fill_(fake_label)

            ## Discrominator Update Steps

            discriminator_A.zero_grad()
            discriminator_B.zero_grad()

            # x_A, x_B, x_AB, x_BA
            # [#_batch, max_time_len, dim]

            A_c = ContEncoder_A(x_A).detach()
            B_c = ContEncoder_B(x_B).detach()

            # A,B :  N ~ (0,1)
            A_s = Variable(get_z_random(hparams.batch_size, 8))
            B_s = Variable(get_z_random(hparams.batch_size, 8))

            x_AB = generator_B(A_c, B_s).detach()
            x_BA = generator_A(B_c, A_s).detach()

            # We recommend LSGAN-loss for adversarial loss

            l_d_A_real = 0.5 * torch.mean(
                (discriminator_A(x_A) - real_tensor)**2)
            l_d_A_fake = 0.5 * torch.mean(
                (discriminator_A(x_BA) - fake_tensor)**2)

            l_d_B_real = 0.5 * torch.mean(
                (discriminator_B(x_B) - real_tensor)**2)
            l_d_B_fake = 0.5 * torch.mean(
                (discriminator_B(x_AB) - fake_tensor)**2)

            l_d_A = l_d_A_real + l_d_A_fake
            l_d_B = l_d_B_real + l_d_B_fake

            l_d = l_d_A + l_d_B

            l_d.backward()
            optimizer_d.step()

            ## Generator Update Steps

            generator_A.zero_grad()
            generator_B.zero_grad()
            ContEncoder_A.zero_grad()
            ContEncoder_B.zero_grad()
            StEncoder_A.zero_grad()
            StEncoder_B.zero_grad()

            A_c = ContEncoder_A(x_A)
            B_c = ContEncoder_B(x_B)

            A_s_prime = StEncoder_A(x_A)
            B_s_prime = StEncoder_B(x_B)

            # A,B : N ~ (0,1)
            A_s = Variable(get_z_random(hparams.batch_size, 8))
            B_s = Variable(get_z_random(hparams.batch_size, 8))

            x_BA = generator_A(B_c, A_s)
            x_AB = generator_B(A_c, B_s)

            x_A_recon = generator_A(A_c, A_s_prime)
            x_B_recon = generator_B(B_c, B_s_prime)

            B_c_recon = ContEncoder_A(x_BA)
            A_s_recon = StEncoder_A(x_BA)

            A_c_recon = ContEncoder_B(x_AB)
            B_s_recon = StEncoder_B(x_AB)

            x_ABA = generator_A(A_c_recon, A_s_prime)
            x_BAB = generator_B(B_c_recon, B_s_prime)

            l_cy_A = recon_criterion(x_ABA, x_A)
            l_cy_B = recon_criterion(x_BAB, x_B)

            l_f_A = recon_criterion(x_A_recon, x_A)
            l_f_B = recon_criterion(x_B_recon, x_B)

            l_c_A = recon_criterion(A_c_recon, A_c)
            l_c_B = recon_criterion(B_c_recon, B_c)

            l_s_A = recon_criterion(A_s_recon, A_s)
            l_s_B = recon_criterion(B_s_recon, B_s)

            # We recommend LSGAN-loss for adversarial loss

            l_gan_A = 0.5 * torch.mean(
                (discriminator_A(x_BA) - real_tensor)**2)
            l_gan_B = 0.5 * torch.mean(
                (discriminator_B(x_AB) - real_tensor)**2)

            l_g = l_gan_A + l_gan_B + lambda_f * (l_f_A + l_f_B) + lambda_s * (
                l_s_A + l_s_B) + lambda_c * (l_c_A + l_c_B) + lambda_cy * (
                    l_cy_A + l_cy_B)

            l_g.backward()
            optimizer_g.step()

            if iters % hparams.log_interval == 0:
                print("---------------------")

                print("Gen Loss :{} disc loss :{}".format(
                    l_g / hparams.batch_size, l_d / hparams.batch_size))
                print("epoch :", e, " ", "total ", hparams.epoch_size)
                print("iteration :", iters)

            if iters % hparams.model_save_interval == 0:
                torch.save(
                    generator_A.state_dict(),
                    os.path.join(model_path,
                                 'model_gen_A_{}.pth'.format(iters)))
                torch.save(
                    generator_B.state_dict(),
                    os.path.join(model_path,
                                 'model_gen_B_{}.pth'.format(iters)))
                torch.save(
                    discriminator_A.state_dict(),
                    os.path.join(model_path,
                                 'model_dis_A_{}.pth'.format(iters)))
                torch.save(
                    discriminator_B.state_dict(),
                    os.path.join(model_path,
                                 'model_dis_B_{}.pth'.format(iters)))

                torch.save(
                    ContEncoder_A.state_dict(),
                    os.path.join(model_path,
                                 'model_ContEnc_A_{}.pth'.format(iters)))
                torch.save(
                    ContEncoder_B.state_dict(),
                    os.path.join(model_path,
                                 'model_ContEnc_B_{}.pth'.format(iters)))
                torch.save(
                    StEncoder_A.state_dict(),
                    os.path.join(model_path,
                                 'model_StEnc_A_{}.pth'.format(iters)))
                torch.save(
                    StEncoder_B.state_dict(),
                    os.path.join(model_path,
                                 'model_StEnc_B_{}.pth'.format(iters)))

            iters += 1
Exemplo n.º 5
0
    def __init__(self, config, train_loader, test_loader):
        """Build all sub-networks, their optimisers and the loss criteria.

        Args:
            config: hyper-parameter dict (learning rates, Adam betas, loss
                weights, sub-network configs, output naming, ...).
            train_loader: training batches; ``len()`` gives the iterations per
                epoch and ``dataset.domains`` the number of domains.
            test_loader: its first batch is cached as the fixed test sources.
        """
        super(StarGAN, self).__init__()

        self.config, self.train_loader, self.test_loader = (
            config, train_loader, test_loader)

        use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda' if use_cuda else 'cpu')

        # Cache one fixed test batch; it is reused for every sample grid.
        first_source, first_domain, _ = next(iter(self.test_loader))
        self.test_source = first_source.to(self.device)
        self.test_domain = first_domain.view(-1, 1, 1).to(self.device)
        (self.test_batch_size, _, self.height,
         self.width) = self.test_source.size()

        self.save_img_cnt = 0
        self.loss, self.items = {}, {}

        self.iter_size = len(self.train_loader)
        self.epoch_size = config['max_iter'] // self.iter_size + 1

        lr, lr_F = config['lr'], config['lr_F']
        betas = (config['beta1'], config['beta2'])
        init = config['init']

        # Scalar settings copied straight from the config (attr <- key).
        for attr, key in (('batch_size', 'batch_size'),
                          ('gan_type', 'gan_type'),
                          ('max_iter', 'max_iter'),
                          ('img_size', 'crop_size')):
            setattr(self, attr, config[key])

        results_root = os.path.join('./results/', config['save_name'])
        self.path_sample = os.path.join(results_root, "samples")
        self.path_model = os.path.join(results_root, "models")

        # Loss weights share their names with the config keys.
        for weight_name in ('w_style', 'w_ds', 'w_cyc', 'w_regul'):
            setattr(self, weight_name, config[weight_name])

        self.num_domain = len(train_loader.dataset.domains)
        self.dim_style = config['dim_style']
        self.dim_latent = config['mapping_network']['dim_latent']

        # Construction order is kept as-is: it fixes RNG consumption and the
        # sub-module registration (state_dict / apply) order.
        self.generator = Generator(config['gen'])
        self.style_encoder = StyleEncoder(config['style_encoder'],
                                          self.num_domain, self.img_size)
        self.mapping_network = MappingNetwork(config['mapping_network'],
                                              self.num_domain, self.dim_style)
        self.discriminator = Discriminator(config['dis'], self.num_domain,
                                           self.img_size)

        self.optimizer_d = torch.optim.Adam(self.discriminator.parameters(),
                                            lr, betas)
        generator_side = [*self.generator.parameters(),
                          *self.style_encoder.parameters()]
        self.optimizer_g = torch.optim.Adam(generator_side, lr, betas)
        # The mapping network is optimised jointly but with its own lr.
        self.optimizer_g.add_param_group(
            dict(params=self.mapping_network.parameters(),
                 lr=lr_F,
                 betas=betas))

        self.apply(weights_init(init))

        self.criterion_l1 = nn.L1Loss()
        self.criterion_l2 = nn.MSELoss()
        self.criterion_bce = nn.BCEWithLogitsLoss()

        self.to(self.device)
Exemplo n.º 6
0
class StarGAN(nn.Module):
    """Trainer bundling generator, style encoder, mapping network and
    discriminator together with their optimisers and the training loop.

    Args:
        config: hyper-parameter dict (learning rates, Adam betas, loss
            weights, sub-network configs, logging intervals, ...).
        train_loader: iterable of ``(image, domain, _)`` batches whose
            ``dataset.domains`` determines the number of domains.
        test_loader: iterable of ``(image, domain, _)`` batches used for
            sample visualisation.
    """

    def __init__(self, config, train_loader, test_loader):
        super(StarGAN, self).__init__()

        self.config = config
        self.train_loader = train_loader
        self.test_loader = test_loader

        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        # A fixed batch of test sources reused for every sample grid.
        self.test_source, self.test_domain, _ = next(iter(self.test_loader))
        self.test_source = self.test_source.to(self.device)
        self.test_domain = self.test_domain.view(-1, 1, 1).to(self.device)
        self.test_batch_size, _, self.height, self.width = self.test_source.size(
        )
        self.save_img_cnt = 0
        # Overwritten by train_starGAN / load_models; initialised here so that
        # generate_test_samples also works before training has started.
        self.current_epoch = 0
        self.loss = {}
        self.items = {}

        self.iter_size = len(self.train_loader)
        self.epoch_size = config['max_iter'] // self.iter_size + 1

        lr = config['lr']
        lr_F = config['lr_F']
        beta1 = config['beta1']
        beta2 = config['beta2']
        init = config['init']

        self.batch_size = config['batch_size']
        self.gan_type = config['gan_type']
        self.max_iter = config['max_iter']
        self.img_size = config['crop_size']

        self.path_sample = os.path.join('./results/', config['save_name'],
                                        "samples")
        self.path_model = os.path.join('./results/', config['save_name'],
                                       "models")

        self.w_style = config['w_style']
        self.w_ds = config['w_ds']
        self.w_cyc = config['w_cyc']
        self.w_regul = config['w_regul']

        self.num_domain = len(train_loader.dataset.domains)
        self.dim_style = config['dim_style']
        self.dim_latent = config['mapping_network']['dim_latent']

        self.generator = Generator(config['gen'])
        self.style_encoder = StyleEncoder(config['style_encoder'],
                                          self.num_domain, self.img_size)
        self.mapping_network = MappingNetwork(config['mapping_network'],
                                              self.num_domain, self.dim_style)
        self.discriminator = Discriminator(config['dis'], self.num_domain,
                                           self.img_size)

        self.optimizer_d = torch.optim.Adam(self.discriminator.parameters(),
                                            lr, (beta1, beta2))
        params_g = list(self.generator.parameters()) + list(
            self.style_encoder.parameters())
        self.optimizer_g = torch.optim.Adam(params_g, lr, (beta1, beta2))
        # The mapping network is optimised jointly but with its own lr.
        self.optimizer_g.add_param_group({
            'params': self.mapping_network.parameters(),
            'lr': lr_F,
            'betas': (beta1, beta2),
        })

        self.apply(weights_init(init))

        self.criterion_l1 = nn.L1Loss()
        self.criterion_l2 = nn.MSELoss()
        self.criterion_bce = nn.BCEWithLogitsLoss()

        self.to(self.device)

    def calc_adversarial_loss(self, logit, is_real):
        """Adversarial loss of ``logit`` against the real/fake target.

        Supports ``'bce'`` (BCE-with-logits), ``'lsgan'`` (MSE to 0/1) and
        ``'wgan'`` (signed mean of the logits).

        Raises:
            NotImplementedError: for any other ``self.gan_type``.
        """
        if self.gan_type == 'bce':
            target_fn = torch.ones_like if is_real else torch.zeros_like
            loss = self.criterion_bce(logit, target_fn(logit))

        elif self.gan_type == 'lsgan':
            target_fn = torch.ones_like if is_real else torch.zeros_like
            loss = self.criterion_l2(logit, target_fn(logit))

        elif self.gan_type == 'wgan':
            loss = -torch.mean(logit) if is_real else torch.mean(logit)
        else:
            raise NotImplementedError("Unsupported gan type: {}".format(
                self.gan_type))

        return loss

    def calc_r1(self, real_images, logit_real):
        """R1 penalty: 0.5 * mean squared L2 norm of d(sum logits)/d(real).

        ``real_images`` must require grad and ``logit_real`` must have been
        computed from it.
        """
        grad_real = autograd.grad(outputs=logit_real.sum(),
                                  inputs=real_images,
                                  create_graph=True)[0]
        grad_penalty = (grad_real.view(grad_real.size(0),
                                       -1).norm(2, dim=1)**2).mean()
        return 0.5 * grad_penalty

    def calc_gp(self, real_images, fake_images):
        """WGAN-GP gradient penalty -- not implemented yet.

        Raises:
            NotImplementedError: always.
        """
        # NOTE(review): the draft below is unreachable; it also calls the
        # discriminator without the domain argument every other call passes.
        raise NotImplementedError(
            "Gradient penalty for gan_type='wgan' is not implemented")
        alpha = torch.rand(real_images.size(0), 1, 1, 1).to(self.device)
        interpolated = (alpha * real_images +
                        ((1 - alpha) * fake_images)).requires_grad_(True)
        prob_interpolated, _ = self.discriminator(interpolated)

        grad_outputs = torch.ones(prob_interpolated.size()).to(self.device)
        gradients = torch.autograd.grad(outputs=prob_interpolated,
                                        inputs=interpolated,
                                        grad_outputs=grad_outputs,
                                        create_graph=True,
                                        retain_graph=True)[0]

        gradients = gradients.reshape(gradients.size(0), -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
        return gradient_penalty

    def generate_random_nosie(self, batch_size=None):
        """Sample one latent code and a random target domain per sample.

        Args:
            batch_size: number of target domains to draw. Defaults to
                ``self.batch_size``; pass the actual batch size so a short
                final batch does not produce mismatched shapes.

        Returns:
            ``(random_noise, random_domain)`` of shapes ``(1, dim_latent)``
            and ``(batch_size, 1, 1)``; the single latent code is shared by
            the whole batch.
        """
        if batch_size is None:
            batch_size = self.batch_size
        random_noise = torch.randn(1, self.dim_latent).to(self.device)
        random_domain = torch.randint(self.num_domain,
                                      (batch_size, 1, 1)).to(self.device)
        return random_noise, random_domain

    def eval_mode_all(self):
        """Put every sub-network (including encoder and mapping) in eval mode."""
        for net in (self.discriminator, self.generator, self.style_encoder,
                    self.mapping_network):
            net.eval()

    def update_d(self, real, real_domain, random_noise, random_domain):
        """One discriminator step: adversarial loss plus regulariser.

        The regulariser is R1 for ``'bce'``, the (unimplemented) gradient
        penalty for ``'wgan'`` and zero otherwise.
        """
        reset_gradients([self.optimizer_g, self.optimizer_d])
        # R1 needs d(logits)/d(real); use a detached copy so the caller's
        # tensor (reused by update_g) is not switched to requires_grad.
        real = real.detach().requires_grad_(True)

        # The fake image is detached for D, so no graph is needed for G/F.
        with torch.no_grad():
            style_mapped = self.mapping_network(random_noise, random_domain)
            fake = self.generator(real, style_mapped)

        logit_real = self.discriminator(real, real_domain)
        logit_fake = self.discriminator(fake.detach(), random_domain)

        adv_d_real = self.calc_adversarial_loss(logit_real, is_real=True)
        adv_d_fake = self.calc_adversarial_loss(logit_fake, is_real=False)

        if self.gan_type == 'bce':
            regul = self.calc_r1(real, logit_real) * self.w_regul
        elif self.gan_type == 'wgan':
            regul = self.calc_gp(real, fake) * self.w_regul
        else:
            # e.g. 'lsgan': no regulariser (previously a NameError).
            regul = torch.zeros((), device=logit_real.device)

        self.adv_d_fake = adv_d_fake
        self.adv_d_real = adv_d_real
        loss_d = adv_d_fake + adv_d_real + regul
        loss_d.backward()
        self.optimizer_d.step()

        self.loss['adv_d_fake'] = adv_d_fake.item()
        self.loss['adv_d_real'] = adv_d_real.item()
        self.loss['regul'] = regul.item()

        self.items["logit_real"] = logit_real
        self.items["logit_fake_d"] = logit_fake

    def update_g(self, real, real_domain, random_noise, random_domain):
        """One step for generator, style encoder and mapping network.

        Loss = adversarial + style reconstruction + (negative) style
        diversification + cycle consistency, weighted by the ``w_*`` settings.
        """
        reset_gradients([self.optimizer_g, self.optimizer_d])

        style_fake = self.mapping_network(random_noise, random_domain)
        style_real = self.style_encoder(real, real_domain)
        fake = self.generator(real, style_fake)
        style_recon = self.style_encoder(fake, random_domain)
        image_recon = self.generator(fake, style_real)

        # Adversarial
        logit_fake = self.discriminator(fake, random_domain)
        adv_g = self.calc_adversarial_loss(logit_fake, is_real=True)

        # Style reconstruction
        style_recon_loss = self.criterion_l1(style_fake,
                                             style_recon) * self.w_style

        # Style diversification: two latent codes, same target domains.
        random_noise1 = torch.randn(1, self.dim_latent).to(self.device)
        random_noise2 = torch.randn(1, self.dim_latent).to(self.device)
        random_domain1 = torch.randint(self.num_domain,
                                       (real.size(0), 1, 1)).to(self.device)

        s1 = self.mapping_network(random_noise1, random_domain1)
        s2 = self.mapping_network(random_noise2, random_domain1)
        fake1 = self.generator(real, s1)
        fake2 = self.generator(real, s2)

        # Negative sign: maximise the distance between the two outputs.
        ds_loss = -self.criterion_l1(fake1, fake2) * self.w_ds

        # Cycle consistency
        cyc_loss = self.criterion_l1(real, image_recon) * self.w_cyc

        loss_g = adv_g + cyc_loss + style_recon_loss + ds_loss
        loss_g.backward()
        self.optimizer_g.step()

        self.loss['adv_g'] = adv_g.item()
        self.loss['style_recon_loss'] = style_recon_loss.item()
        self.loss['ds_loss'] = ds_loss.item()
        self.loss['cyc_loss'] = cyc_loss.item()

        self.items["real"] = real
        self.items["real_domain"] = real_domain
        self.items["random_noise"] = random_noise
        self.items["random_domain"] = random_domain
        self.items["random_noise1"] = random_noise1
        self.items["random_noise2"] = random_noise2
        self.items["random_domain1"] = random_domain1
        self.items["logit_fake"] = logit_fake
        self.items["style_fake"] = style_fake
        self.items["style_real"] = style_real
        self.items["fake"] = fake
        self.items["recon"] = image_recon
        self.items["style_recon"] = style_recon

    def train_starGAN(self, init_epoch):
        """Train from ``init_epoch`` up to ``self.epoch_size``.

        D is updated every ``d_step`` and G every ``g_step`` iterations; logs,
        sample grids and ``w_ds`` decay follow the config intervals. A
        checkpoint is saved at the end of every epoch.
        """
        d_step, g_step = self.config['d_step'], self.config['g_step']
        log_iter = self.config['log_iter']
        image_display_iter = self.config['image_display_iter']
        image_save_iter = self.config['image_save_iter']

        for epoch in range(init_epoch, self.epoch_size):
            self.current_epoch = epoch
            self.save_img_cnt = 0
            for iters, (real, real_domain, _) in enumerate(self.train_loader):
                real, real_domain = real.to(self.device), real_domain.to(
                    self.device)
                # Size by the actual batch (the last batch may be shorter).
                random_noise, random_domain = self.generate_random_nosie(
                    real.size(0))

                # Modulo, not bitwise AND: update every d_step iterations.
                if not iters % d_step:
                    self.update_d(real, real_domain, random_noise,
                                  random_domain)

                if not iters % g_step:
                    self.update_g(real, real_domain, random_noise,
                                  random_domain)

                if self.device.type == 'cuda':
                    torch.cuda.synchronize()

                if not (iters + 1) % log_iter:
                    self.print_log(epoch, iters)

                    if not (iters + 1) % image_display_iter:
                        show_batch_torch(torch.cat([
                            real, self.items['fake'].clamp(-1, 1),
                            self.items['recon'].clamp(-1, 1)
                        ]),
                                         n_rows=3,
                                         n_cols=-1)

                        if not (iters + 1) % image_save_iter:
                            self.test_sample = self.generate_test_samples(
                                save=True)
                            clear_jupyter_console()

                # TODO : arbitrary decay schedule for the diversity weight.
                if epoch >= 10 and not (iters + 1) % 1000:
                    print("w_ds decayed:", self.w_ds, " -> ", self.w_ds * 0.9)
                    self.w_ds *= 0.9

            self.save_models(epoch)

    def print_log(self, epoch, iters):
        """Print the most recent loss values recorded in ``self.loss``."""
        adv_d_real = self.loss['adv_d_real']
        adv_d_fake = self.loss['adv_d_fake']
        regul = self.loss['regul']
        adv_g = self.loss['adv_g']
        style_recon_loss = self.loss['style_recon_loss']
        ds_loss = self.loss['ds_loss']
        cyc_loss = self.loss['cyc_loss']

        print(
            "[Epoch {}/{}, iters: {}/{}] " \
            "- Adv: {:5.4} {:5.4} / {:5.4}, Style recon: {:5.4}, DS: {:5.4}, Cyc : {:5.4}, Regul : {:5.4}".format(
                epoch, self.epoch_size, iters + 1, self.iter_size,
                adv_d_real, adv_d_fake, adv_g, style_recon_loss, ds_loss, cyc_loss, regul
            )
        )

    def save_models(self, epoch):
        """Save all four networks, both optimisers, ``w_ds`` and the epoch to
        ``<path_model>/epoch_<epoch:02>``."""
        os.makedirs(self.path_model, exist_ok=True)

        state = {
            'generator': self.generator.state_dict(),
            'discriminator': self.discriminator.state_dict(),
            # optimizer_g also covers these two, so they must be saved too.
            'style_encoder': self.style_encoder.state_dict(),
            'mapping_network': self.mapping_network.state_dict(),
            'optimizer_d': self.optimizer_d.state_dict(),
            'optimizer_g': self.optimizer_g.state_dict(),
            'w_ds': self.w_ds,
            'current_epoch': epoch,
        }

        save_name = os.path.join(self.path_model, "epoch_{:02}".format(epoch))
        torch.save(state, save_name)

    @staticmethod
    def _find_latest_epoch(path_model):
        """Return the largest epoch number among ``epoch_<N>`` files in
        *path_model* (numeric, so ``epoch_100`` beats ``epoch_99``).

        Raises:
            FileNotFoundError: if no such checkpoint exists.
        """
        epochs = []
        for path in glob.glob(os.path.join(path_model, 'epoch_*')):
            suffix = os.path.basename(path).split('_', 1)[1]
            if suffix.isdigit():
                epochs.append(int(suffix))
        if not epochs:
            raise FileNotFoundError(
                "No epoch_* checkpoints found in {}".format(path_model))
        return max(epochs)

    def load_models(self, epoch=False):
        """Restore a checkpoint written by :meth:`save_models`.

        Args:
            epoch: epoch to load; ``False``/``None`` selects the newest one.
                ``0`` now loads epoch 0 instead of meaning "latest".

        Returns:
            The epoch number that was loaded.
        """
        if epoch is False or epoch is None:
            epoch = self._find_latest_epoch(self.path_model)

        save_name = os.path.join(self.path_model, "epoch_{:02}".format(epoch))
        # map_location lets GPU checkpoints be resumed on CPU and vice versa.
        checkpoint = torch.load(save_name, map_location=self.device)

        self.discriminator.load_state_dict(checkpoint['discriminator'])
        self.generator.load_state_dict(checkpoint['generator'])
        # Older checkpoints did not contain these two networks.
        if 'style_encoder' in checkpoint:
            self.style_encoder.load_state_dict(checkpoint['style_encoder'])
        if 'mapping_network' in checkpoint:
            self.mapping_network.load_state_dict(
                checkpoint['mapping_network'])
        self.optimizer_d.load_state_dict(checkpoint['optimizer_d'])
        self.optimizer_g.load_state_dict(checkpoint['optimizer_g'])
        self.w_ds = checkpoint['w_ds']
        self.current_epoch = checkpoint['current_epoch']
        return epoch

    def resume_train(self, restart_epoch=False):
        """Load a checkpoint (newest by default) and continue with the next epoch."""
        restart_epoch = self.load_models(restart_epoch)
        print("Resume Training - Epoch: ", restart_epoch)
        self.train_starGAN(restart_epoch + 1)

    def generate_test_samples(self, save):
        """Render the cached test sources with styles from a reference batch.

        Every reference style is applied to every test source. The grid
        returned (and optionally saved as ``<epoch>_<count>.jpg``) is built
        from the sources, the generated images and a left column of
        references; layout presumably follows ``reshape_batch_torch``.
        """
        os.makedirs(self.path_sample, exist_ok=True)

        with torch.no_grad():
            reference, reference_domain, _ = next(iter(self.test_loader))
            reference, reference_domain = reference.to(
                self.device), reference_domain.to(self.device)
            num_ref = reference.size(0)

            # Each reference style repeated once per test source, so it pairs
            # element-wise with the source batch tiled num_ref times.
            style_reference = self.style_encoder(reference, reference_domain)
            style_reference = style_reference.repeat(
                1, self.test_batch_size, 1).view(-1, 1, self.dim_style)
            source = self.test_source.repeat(num_ref, 1, 1, 1).view(
                -1, 3, self.height, self.width)
            generated = self.generator(source, style_reference).clamp(-1, 1)

            right_concat, _, _ = reshape_batch_torch(
                torch.cat([self.test_source, generated]),
                n_cols=self.test_batch_size,
                n_rows=-1)

            left_concat = torch.cat(
                [torch.zeros_like(reference[:1]), reference])
            left_concat, _, _ = reshape_batch_torch(left_concat,
                                                    n_cols=1,
                                                    n_rows=-1)

            save_image = preprocess(
                np.concatenate([left_concat, right_concat], axis=1))

            if save:
                save_name = os.path.join(
                    self.path_sample,
                    "{:02}_{:02}.jpg".format(self.current_epoch,
                                             self.save_img_cnt))
                self.save_img_cnt += 1
                plt.imsave(save_name, save_image)
                print("Test samples Saved:" + save_name)
        return save_image
Exemplo n.º 7
0
def _translate_dt_set(file_list, speaker, root, normalizer, content_encoder,
                      generator, output_dir):
    """Encode each file's content, decode it with a random style, save it.

    The de-normalised output is written with ``np.save`` as
    ``<output_dir>/z-<audio_name>``. Callers should wrap this in
    ``torch.no_grad()`` and put the networks in eval mode.
    """
    for file_name in file_list:
        feat = testset_loader(root, file_name, speaker, normalizer=normalizer)
        print(feat['audio_name'])

        content = torch.FloatTensor(feat['sp']).unsqueeze(0).cuda()
        content_code = content_encoder(content)

        z_st = get_z_random(1, 8)  # random style code, same size as training
        feature_z = generator(content_code, z_st)
        feature_z = normalizer.backward_process(feature_z.squeeze().detach())
        feature_z = feature_z.squeeze().detach().cpu().numpy()

        np.save(os.path.join(output_dir, 'z-' + feat['audio_name']),
                feature_z)


def test():
    """Translate the dt05 real and simu test lists with the A->B path.

    Loads all networks for ``hparams.iteration_num`` from
    ``<hparams.model_path>/<task_name>/<spec_opt>`` and runs, under
    ``torch.no_grad()``, ``ContEncoder_A`` followed by ``generator_B`` with a
    random style code on every utterance of both lists. Outputs go to
    ``./output/<task_name>/<iteration_num>_AB_dt/{dt_real,dt_simu}``.
    """
    hparams = get_hparams()
    print(hparams.task_name)
    # Same directory the training loop writes its checkpoints into.
    model_path = os.path.join(hparams.model_path, hparams.task_name,
                              hparams.spec_opt)

    root = '../dataset/feat/test'
    list_dir_A = './etc/Test_dt05_real_isolated_1ch_track_list.csv'
    list_dir_B = './etc/Test_dt05_simu_isolated_1ch_track_list.csv'

    output_dir = './output/{}/{}_AB_dt'.format(hparams.task_name,
                                               hparams.iteration_num)
    output_dir_real = os.path.join(output_dir, 'dt_real')
    output_dir_simu = os.path.join(output_dir, 'dt_simu')

    # makedirs also creates output_dir itself as the common parent.
    os.makedirs(output_dir_real, exist_ok=True)
    os.makedirs(output_dir_simu, exist_ok=True)

    normalizer_clean = Tanhize('clean')
    normalizer_noisy = Tanhize('noisy')
    test_list_A, speaker_A = testset_list_classifier(root, list_dir_A)
    test_list_B, speaker_B = testset_list_classifier(root, list_dir_B)

    # Checkpoint tag -> network, matching the training save names.
    networks = {
        'gen_A': Generator(),
        'gen_B': Generator(),
        'dis_A': Discriminator(),
        'dis_B': Discriminator(),
        'ContEnc_A': ContentEncoder(),
        'ContEnc_B': ContentEncoder(),
        'StEnc_A': StyleEncoder(),
        'StEnc_B': StyleEncoder(),
    }
    for tag in networks:
        net = nn.DataParallel(networks[tag]).cuda()
        checkpoint = os.path.join(
            model_path, 'model_{}_{}.pth'.format(tag, hparams.iteration_num))
        # Load onto CPU storage first; the module is already on the GPU.
        net.load_state_dict(torch.load(checkpoint, map_location='cpu'))
        net.eval()
        networks[tag] = net

    # Both dt sets are translated with the same A->B path.
    with torch.no_grad():
        _translate_dt_set(test_list_A, speaker_A, root, normalizer_noisy,
                          networks['ContEnc_A'], networks['gen_B'],
                          output_dir_real)
        _translate_dt_set(test_list_B, speaker_B, root, normalizer_noisy,
                          networks['ContEnc_A'], networks['gen_B'],
                          output_dir_simu)