Exemple #1
0
def train(epochs, batchsize, interval, c_path, s_path, modeldir):
    """Adversarial training loop for the CartoonRenderer generator.

    Alternates one discriminator step and one generator step per batch,
    periodically checkpointing the generator.

    Args:
        epochs: number of passes over the dataset.
        batchsize: mini-batch size for the DataLoader.
        interval: checkpoint period in iterations (saves when
            ``iterations % interval == 1``, so the first save is at iter 1).
        c_path: content-image dataset location, forwarded to CRDataset.
        s_path: style-image dataset location, forwarded to CRDataset.
        modeldir: directory that receives ``model_<iter>.pt`` checkpoints.
    """
    # Dataset definition
    dataset = CRDataset(c_path, s_path)
    collator = CollateFn()

    # Model definition
    generator = CartoonRenderer()
    generator.cuda()
    generator.train()
    gen_opt = torch.optim.Adam(generator.parameters(), lr=0.0001)

    discriminator = Discriminator()
    discriminator.cuda()
    discriminator.train()
    dis_opt = torch.optim.Adam(discriminator.parameters(), lr=0.0001)

    iterations = 0

    for epoch in range(epochs):
        # A fresh DataLoader (and tqdm progress bar) is built every epoch.
        dataloader = DataLoader(dataset,
                                batch_size=batchsize,
                                shuffle=True,
                                drop_last=True,
                                collate_fn=collator)
        dataloader = tqdm(dataloader)

        for i, data in enumerate(dataloader):
            iterations += 1
            c, s = data

            # --- Discriminator step ---
            # NOTE(review): y is not detached, so dis_loss.backward() also
            # writes gradients into the generator; they are discarded by
            # gen_opt.zero_grad() below, but detaching would save compute.
            y, _, _, _ = generator(c, s)
            dis_loss = adversarial_loss_dis(discriminator, y, s)

            dis_opt.zero_grad()
            dis_loss.backward()
            dis_opt.step()

            # --- Generator step ---
            # Re-run the generator to get a fresh graph after the D update:
            # stylized output plus self-reconstructions of content and style.
            y, c_feat, sa_list, y_feat = generator(c, s)
            y_c, _, _, _ = generator(c, c)
            y_s, _, _, _ = generator(s, s)

            gen_loss = adversarial_loss_gen(discriminator, y)
            gen_loss += reconstruction_loss(y_c, c)
            gen_loss += reconstruction_loss(y_s, s)
            gen_loss += content_loss(sa_list, y_feat)
            gen_loss += style_loss(c_feat, y_feat)

            gen_opt.zero_grad()
            gen_loss.backward()
            gen_opt.step()

            if iterations % interval == 1:
                torch.save(generator.state_dict(),
                           f"{modeldir}/model_{iterations}.pt")

            # NOTE(review): `.data` is a legacy (pre-0.4) tensor accessor;
            # `.item()` is the modern equivalent for scalar logging.
            print(
                f"iter: {iterations} dis loss: {dis_loss.data} gen loss: {gen_loss.data}"
            )
Exemple #2
0
def display_network(opt):
    """Build the generator/discriminator pair and print torchsummary reports.

    Args:
        opt: parsed options exposing ``channels``, ``residual_blocks``,
            ``img_height``, ``img_width`` and ``selected_attrs``.
    """
    cuda = True if torch.cuda.is_available() else False
    # One conditioning channel per selected attribute.
    c_dim = len(opt.selected_attrs)

    generator = Generator(opt.channels, opt.residual_blocks, c_dim)
    discriminator = Discriminator(opt.channels, opt.img_height, c_dim)

    if cuda:
        generator.cuda()
        discriminator.cuda()

    # BUG FIX: `(c_dim)` is just an int in parentheses, not a tuple;
    # torchsummary expects each input size to be a tuple, so use `(c_dim,)`.
    summary(generator, [(opt.channels, opt.img_height, opt.img_width), (c_dim,)])
    summary(discriminator, (opt.channels, opt.img_height, opt.img_width))
Exemple #3
0
def display_network(opt):
    """Instantiate the GAN pair and print a torchsummary report for the
    discriminator (the generator summary is currently disabled)."""
    use_gpu = torch.cuda.is_available()

    net_g = Generator(opt.channels)
    # net_g.load_state_dict(torch.load(opt.load_model))
    net_d = Discriminator(opt.channels)

    if use_gpu:
        net_g.cuda()
        net_d.cuda()

    # summary(net_g, (opt.channels, opt.img_size, opt.img_size))
    summary(net_d, (opt.channels, opt.mask_size, opt.mask_size))
def main():
    """Parse CLI arguments, assemble the GAN components and launch training."""
    args = parse_arguments()
    loader = download_dataset(args.download_path, args.batch_size)

    shape = (args.img_channels, args.img_size, args.img_size)
    criterion = nn.BCELoss()
    gen = Generator(args.latent_dim, shape)
    disc = Discriminator(shape)

    # Move models and loss onto the GPU when one is available.
    if torch.cuda.is_available():
        for module in (gen, disc, criterion):
            module.cuda()

    # Define optimizer
    train(loader, disc, gen, criterion, args)
Exemple #5
0
def display_network(opt):
    """Build generator/discriminator from ``opt`` and print their
    torchsummary reports over an image-shaped input."""
    on_gpu = torch.cuda.is_available()

    net_g = Generator(opt)
    # net_g.load_state_dict(torch.load(opt.load_model))
    net_d = Discriminator(opt)

    if on_gpu:
        net_g.cuda()
        net_d.cuda()

    shape = (opt.channels, opt.img_height, opt.img_width)
    # print(*net_d.output_shape)
    summary(net_g, shape)
    summary(net_d, shape)
Exemple #6
0
def main(args):
    """Entry point for CycleGAN training and/or evaluation.

    Builds the two generators (A→B, B→A) and two discriminators, either
    resumes from a checkpoint or freshly initializes the weights, then
    trains and/or evaluates according to ``args``.

    Args:
        args: parsed options exposing ``cuda``, ``lr``, ``training``,
            ``evaluation`` and ``epochs``.
    """
    train_loader, test_loader = load_data(args)

    GeneratorA2B = CycleGAN()
    GeneratorB2A = CycleGAN()

    DiscriminatorA = Discriminator()
    DiscriminatorB = Discriminator()

    if args.cuda:
        GeneratorA2B = GeneratorA2B.cuda()
        GeneratorB2A = GeneratorB2A.cuda()

        DiscriminatorA = DiscriminatorA.cuda()
        DiscriminatorB = DiscriminatorB.cuda()

    # One optimizer over both generators and one over both discriminators.
    optimizerG = optim.Adam(itertools.chain(GeneratorA2B.parameters(), GeneratorB2A.parameters()), lr=args.lr, betas=(0.5, 0.999))
    optimizerD = optim.Adam(itertools.chain(DiscriminatorA.parameters(), DiscriminatorB.parameters()), lr=args.lr, betas=(0.5, 0.999))

    # NOTE(review): despite its name, args.training==True means *resume*
    # training from a hard-coded checkpoint (path, epoch 285) — confirm the
    # flag semantics and consider making the path/epoch configurable.
    if args.training:
        path = 'E:/cyclegan/checkpoints/model_{}_{}.pth'.format(285, 200)

        checkpoint = torch.load(path)
        GeneratorA2B.load_state_dict(checkpoint['generatorA'])
        GeneratorB2A.load_state_dict(checkpoint['generatorB'])
        DiscriminatorA.load_state_dict(checkpoint['discriminatorA'])
        DiscriminatorB.load_state_dict(checkpoint['discriminatorB'])
        optimizerG.load_state_dict(checkpoint['optimizerG'])
        optimizerD.load_state_dict(checkpoint['optimizerD'])

        start_epoch = 285
    else:
        # Fresh run: normal-init all four networks on GPU 0.
        init_net(GeneratorA2B, init_type='normal', init_gain=0.02, gpu_ids=[0])
        init_net(GeneratorB2A, init_type='normal', init_gain=0.02, gpu_ids=[0])

        init_net(DiscriminatorA, init_type='normal', init_gain=0.02, gpu_ids=[0])
        init_net(DiscriminatorB, init_type='normal', init_gain=0.02, gpu_ids=[0])
        start_epoch = 1

    if args.evaluation:
        evaluation(test_loader, GeneratorA2B, GeneratorB2A, args)
    else:
        # Standard CycleGAN losses: cycle-consistency (L1), adversarial
        # (BCE-with-logits) and identity (L1).
        cycle = nn.L1Loss()
        gan = nn.BCEWithLogitsLoss()
        identity = nn.L1Loss()

        for epoch in range(start_epoch, args.epochs):
            train(train_loader, GeneratorA2B, GeneratorB2A, DiscriminatorA, DiscriminatorB, optimizerG, optimizerD, cycle, gan, identity, args, epoch)
        evaluation(test_loader, GeneratorA2B, GeneratorB2A, args)
Exemple #7
0
def main(args):
    """Set up StyleGAN training: models, EMA copy and optimizers, then train.

    Args:
        args: parsed options exposing ``cuda``, ``lr`` and
            ``no_from_rgb_activate``.
    """
    train_loader, test_loader = load_data(args)

    generator = StyleGAN()
    discriminator = Discriminator(
        from_rgb_activate=not args.no_from_rgb_activate)
    # g_running holds an exponential moving average of the generator weights
    # and is used for evaluation only, hence train(False).
    g_running = StyleGAN()
    if args.cuda:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        g_running = g_running.cuda()
    g_running.train(False)

    g_optimizer = optim.Adam(generator.generator.parameters(),
                             lr=args.lr,
                             betas=(0., 0.99))
    # The style/mapping network trains 100x slower than the synthesis
    # network, as in the StyleGAN training scheme.
    g_optimizer.add_param_group({
        'params': generator.style.parameters(),
        'lr': args.lr * 0.01,
        'mult': 0.01
    })

    d_optimizer = optim.Adam(discriminator.parameters(),
                             lr=args.lr,
                             betas=(0., 0.99))

    # Initialize the running copy to the current generator weights
    # (decay 0 means a full copy).
    accumulate(g_running, generator, 0)

    train(args, train_loader, generator, g_running, discriminator, g_optimizer,
          d_optimizer)
def load_networks(isTraining=False):
    """Build the depth/color/discriminator networks and their optimizers,
    restoring weights from a checkpoint when one is available.

    Args:
        isTraining: when True, optionally resume from the newest ``.tar``
            checkpoint in ``param.trainNet``; when False, always load the
            fixed test checkpoint ``param.testNet + '/Net_GAN.tar'``.

    Returns:
        (depth_net, color_net, d_net,
         depth_optimizer, color_optimizer, d_optimizer)
    """
    depth_net = DepthNetModel()
    color_net = ColorNetModel()
    d_net = Discriminator()
    if param.useGPU:
        depth_net.cuda()
        color_net.cuda()
        d_net.cuda()

    depth_optimizer = optim.Adam(depth_net.parameters(),
                                 lr=param.alpha,
                                 betas=(param.beta1, param.beta2),
                                 eps=param.eps)
    color_optimizer = optim.Adam(color_net.parameters(),
                                 lr=param.alpha,
                                 betas=(param.beta1, param.beta2),
                                 eps=param.eps)
    d_optimizer = optim.Adam(d_net.parameters())

    if isTraining:
        netFolder = param.trainNet
        netName, _, _ = get_folder_content(netFolder, '.tar')

        if param.isContinue and netName:
            # Recover the iteration number from the checkpoint filename;
            # assumes names like '<prefix>-<iter>.tar' — TODO confirm.
            tokens = netName[0].split('-')[1].split('.')[0]
            param.startIter = int(tokens)
            checkpoint = torch.load(netFolder + '/' + netName[0])
            depth_net.load_state_dict(checkpoint['depth_net'])
            color_net.load_state_dict(checkpoint['color_net'])
            d_net.load_state_dict(checkpoint['d_net'])
            depth_optimizer.load_state_dict(checkpoint['depth_optimizer'])
            color_optimizer.load_state_dict(checkpoint['color_optimizer'])
            d_optimizer.load_state_dict(checkpoint['d_optimizer'])
        else:
            # No checkpoint found: mutate the global flag so callers know
            # this is a fresh run rather than a resume.
            param.isContinue = False

    else:
        netFolder = param.testNet
        checkpoint = torch.load(netFolder + '/Net_GAN.tar')
        depth_net.load_state_dict(checkpoint['depth_net'])
        color_net.load_state_dict(checkpoint['color_net'])
        d_net.load_state_dict(checkpoint['d_net'])
        depth_optimizer.load_state_dict(checkpoint['depth_optimizer'])
        color_optimizer.load_state_dict(checkpoint['color_optimizer'])
        d_optimizer.load_state_dict(checkpoint['d_optimizer'])

    return depth_net, color_net, d_net, depth_optimizer, color_optimizer, d_optimizer
Exemple #9
0
def main():
    """Set up and launch tag-conditioned GAN training from global ``opt``.

    Loads the tag/image dataset, builds the G/D pair with BCE-with-logits
    loss, optionally restores an optimizer-level checkpoint (``opt.resume``)
    or pre-trained weights (``opt.pre_trained``), then calls ``train``.

    Raises:
        Exception: if ``opt.cuda`` is set but no usable GPU is found.
    """
    if opt.cuda:
        os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus
        if not torch.cuda.is_available():
            raise Exception(
                'No GPU found or Wrong gpu id, please run without --cuda')

    logger.info('[INFO] Loading datasets')
    train_set = TagImageDataset(tag_path=opt.tag, img_path=opt.image)
    train_loader = DataLoader(train_set,
                              num_workers=opt.threads,
                              batch_size=opt.batch,
                              shuffle=True,
                              drop_last=True)

    logger.info('[INFO] Building model')
    G = Generator(opt.features)
    D = Discriminator(opt.features)
    criterion = nn.BCEWithLogitsLoss()

    logger.info('[INFO] Setting Optimizer')
    G_optim = optim.Adam(G.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
    D_optim = optim.Adam(D.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

    logger.info('[INFO] Setting GPU')
    if opt.cuda:
        G = G.cuda()
        D = D.cuda()
        criterion = criterion.cuda()

    # Resume: checkpoint stores raw state dicts plus the epoch counter.
    if opt.resume:
        if os.path.isfile(opt.resume):
            logger.info('[LOAD] Loading checkpoint {}'.format(opt.resume))
            checkpoint = torch.load(opt.resume)
            opt.start_epoch = checkpoint['epoch'] + 1
            G.load_state_dict(checkpoint['g'])
            D.load_state_dict(checkpoint['d'])
            G_optim.load_state_dict(checkpoint['g_optim'])
            D_optim.load_state_dict(checkpoint['d_optim'])
        else:
            logger.warning('[ERROR] No checkpoint found at {}'.format(
                opt.resume))

    # Pre-trained: this file stores whole pickled objects, hence the extra
    # .state_dict() calls — a different format from the resume checkpoint.
    if opt.pre_trained:
        if os.path.isfile(opt.pre_trained):
            logger.info('[LOAD] Loading model {}'.format(opt.pre_trained))
            weights = torch.load(opt.pre_trained)
            G.load_state_dict(weights['g'].state_dict())
            D.load_state_dict(weights['d'].state_dict())
            G_optim.load_state_dict(weights['g_optim'].state_dict())
            D_optim.load_state_dict(weights['d_optim'].state_dict())
        else:
            logger.warning('[ERROR] No model found at {}'.format(
                opt.pre_trained))

    logger.info('[INFO] Start Training')
    train(train_loader, G, D, G_optim, D_optim, criterion)
Exemple #10
0
def load_network(gpus):
    """Build DataParallel-wrapped G/D plus an Inception model for scoring.

    Args:
        gpus: list of device ids passed to ``torch.nn.DataParallel``.

    Returns:
        (netG, netD, inception_model, training_iter) where training_iter is
        the iteration to resume from (0 for a fresh run).
    """
    # Generator
    netG = Generator()
    netG.apply(weights_init)
    netG = torch.nn.DataParallel(netG, device_ids=gpus)
    print(netG)

    # Discriminator
    netD = Discriminator()
    netD.apply(weights_init)
    netD = torch.nn.DataParallel(netD, device_ids=gpus)
    print(netD)

    # Loading pretrained weights, if exists.
    # NOTE(review): the networks are already DataParallel-wrapped here, so
    # the checkpoints must have been saved with 'module.'-prefixed keys —
    # confirm against the saving code.
    training_iter = 0
    if cfg.TRAIN.NET_G != '':
        state_dict = torch.load(cfg.TRAIN.NET_G)
        netG.load_state_dict(state_dict)
        print('Loaded Generator from saved model.', cfg.TRAIN.NET_G)

        # Parse the iteration number out of a filename like '..._<iter>.pth'
        # and resume from the following iteration.
        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        training_iter = cfg.TRAIN.NET_G[istart:iend]
        training_iter = int(training_iter) + 1

    if cfg.TRAIN.NET_D != '':
        print('Loading Discriminator from %s.pth' % (cfg.TRAIN.NET_D))
        state_dict = torch.load('%s.pth' % (cfg.TRAIN.NET_D))
        netD.load_state_dict(state_dict)

    # Inception network used for evaluation metrics (kept in eval mode).
    inception_model = INCEPTION_V3()

    # Moving to GPU
    if cfg.CUDA:
        netG.cuda()
        netD.cuda()
        inception_model = inception_model.cuda()

    inception_model.eval()

    return netG, netD, inception_model, training_iter
Exemple #11
0
def display_network(opt):
    """Build a UNIT-style model (shared-latent encoders/generators plus two
    discriminators) and print torchsummary reports for every sub-network.

    Args:
        opt: parsed options exposing ``channels``, ``img_height``,
            ``img_width``, ``dim``, ``n_downsample`` and ``n_upsample``.
    """
    cuda = True if torch.cuda.is_available() else False

    # Dimensionality
    input_shape = (opt.channels, opt.img_height, opt.img_width)
    # Channel count of the shared latent space after n_downsample halvings.
    shared_dim = opt.dim * (2**opt.n_downsample)

    # Initialize generator and discriminator.
    # Both encoders (and both generators) share one residual block, which is
    # what couples the two domains' latent spaces.
    shared_E = ResidualBlock(in_channels=shared_dim)
    E1 = Encoder(dim=opt.dim,
                 n_downsample=opt.n_downsample,
                 shared_block=shared_E)
    E2 = Encoder(dim=opt.dim,
                 n_downsample=opt.n_downsample,
                 shared_block=shared_E)

    shared_G = ResidualBlock(in_channels=shared_dim)
    G1 = Generator(dim=opt.dim,
                   n_upsample=opt.n_upsample,
                   shared_block=shared_G)
    G2 = Generator(dim=opt.dim,
                   n_upsample=opt.n_upsample,
                   shared_block=shared_G)

    D1 = Discriminator(input_shape)
    D2 = Discriminator(input_shape)

    if cuda:
        E1 = E1.cuda()
        E2 = E2.cuda()
        G1 = G1.cuda()
        G2 = G2.cuda()
        D1 = D1.cuda()
        D2 = D2.cuda()

    summary(E1, (opt.channels, opt.img_height, opt.img_width))
    summary(E2, (opt.channels, opt.img_height, opt.img_width))
    # NOTE(review): the generator input shape below mixes img_height into
    # the channel slot; the encoders' latent output is shared_dim channels,
    # so this looks inconsistent — confirm the intended summary shape.
    summary(G1, (opt.img_height, opt.dim, opt.dim))
    summary(G2, (opt.img_height, opt.dim, opt.dim))
    summary(D1, (opt.channels, opt.img_height, opt.img_width))
    summary(D2, (opt.channels, opt.img_height, opt.img_width))
Exemple #12
0
def build_model(model_type):
    """Create the GAN pair plus BCE loss and one Adam optimizer per network.

    Returns:
        (generator, discriminator, loss, optimizer_g, optimizer_d)
    """
    gen = Generator(model_name=model_type, batch_size=128)
    disc = Discriminator()

    if cuda_available:
        gen, disc = gen.cuda(), disc.cuda()

    criterion = nn.BCELoss()
    opt_g = torch.optim.Adam(gen.parameters(), lr=2e-4)
    opt_d = torch.optim.Adam(disc.parameters(), lr=2e-4)

    return gen, disc, criterion, opt_g, opt_d
Exemple #13
0
def choose_model(model_options):
    """Instantiate and weight-initialise the generator/discriminator pair,
    moving both to the GPU when CUDA is available."""
    gen = Generator(model_options)
    disc = Discriminator(model_options)

    if torch.cuda.is_available():
        print("CUDA is available")
        gen = gen.cuda()
        disc = disc.cuda()
        print("Moved models to GPU")

    # Apply the project-wide weight initialisation scheme to both nets.
    for net in (gen, disc):
        net.apply(weights_init)

    return gen, disc
Exemple #14
0
def test_discriminator(use_cuda):
    """Smoke-test the Discriminator's forward and backward passes.

    Builds a discriminator over a 2000-token vocabulary, feeds it a random
    batch of 64 sentences of length 20, and backpropagates a cross-entropy
    plus L2-regularization loss.

    Args:
        use_cuda: when True, run the whole test on the GPU.
    """
    net = Discriminator(20, 2, 2000, 64, [1, 2, 4, 6, 8, 10, 20],
                        [100, 100, 100, 100, 100, 160, 160], 0, 820, 4, 0.75,
                        0.2)
    print(net)
    # Random token ids in [0, 2000) and binary class targets for batch 64.
    sentence = np.random.randint(2000, size=(64, 20))
    sentence = Variable(torch.from_numpy(sentence).long())
    target = np.random.randint(2, size=(64))
    target = Variable(torch.from_numpy(target).long())
    loss_function = nn.CrossEntropyLoss()
    if use_cuda:
        net = net.cuda()
        target = target.cuda()
        sentence = sentence.cuda()
        # BUG FIX: the loss was previously moved to the GPU unconditionally,
        # which crashed on CPU-only machines and mismatched devices whenever
        # use_cuda was False; move it only together with the other tensors.
        loss_function = loss_function.cuda()
    out_dict = net(sentence)
    # (Also fixed the "Disciminator" typo in both status messages.)
    print("Discriminator forward test passed.")
    loss = loss_function(out_dict["score"], target) + net.l2_loss()
    loss.backward()
    print("Discriminator backward test passed.")
Exemple #15
0
def prepare_model_dict(use_cuda=False,
                       params_path="./params/leak_gan_params.json"):
    """Build the LeakGAN generator/discriminator pair from a JSON config.

    Args:
        use_cuda: move both models to the GPU when True.
        params_path: path to the hyper-parameter JSON file. New optional
            parameter; defaults to the previously hard-coded location, so
            existing callers are unaffected.

    Returns:
        dict with keys "generator" and "discriminator".
    """
    # Use a context manager so the file is closed even if json.load raises
    # (the original open/close pair leaked the handle on error).
    with open(params_path) as f:
        params = json.load(f)
    discriminator_params = params["discriminator_params"]
    generator_params = params["generator_params"]
    worker_params = generator_params["worker_params"]
    manager_params = generator_params["manager_params"]
    # The goal-vector size is tied to the discriminator's total filter count
    # and must be shared by the worker and manager sub-modules.
    goal_out_size = sum(discriminator_params["num_filters"])
    discriminator_params["goal_out_size"] = goal_out_size
    worker_params["goal_out_size"] = goal_out_size
    manager_params["goal_out_size"] = goal_out_size
    discriminator = Discriminator(**discriminator_params)
    generator = Generator(worker_params, manager_params,
                          generator_params["step_size"])
    if use_cuda:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
    model_dict = {"generator": generator, "discriminator": discriminator}
    return model_dict
Exemple #16
0
def main():
    """Load the AnimeFace dataset, build and initialise the GAN, then train."""
    # load data
    annotations = image_dir + 'edited_annotations.csv'
    face_data = AnimeFaceDataset(annotations, image_dir)
    loader = DataLoader(face_data,
                        batch_size=batch_size,
                        shuffle=True,
                        collate_fn=my_collate,
                        drop_last=True)
    print("Data loaded : %d" % (len(face_data)))

    gen = Generator()
    disc = Discriminator()
    for net in (gen, disc):
        net.apply(init_weight)

    if args.cuda:
        gen = gen.cuda()
        disc = disc.cuda()

    criterion = nn.BCELoss()
    print("Start Training")
    train(gen, disc, loader, criterion)
    print("Finished training!")
Exemple #17
0
class Solver(object):
    """Trainer/tester for a StarGAN-style model on a single dataset.

    Builds the generator ``G`` and discriminator ``D``, then either runs
    WGAN-GP adversarial training with an auxiliary domain-classification
    loss (``train``) or translates a test set with a saved checkpoint
    (``test``).

    NOTE(review): written against a pre-0.4 PyTorch API (``Variable``,
    ``volatile=True``, ``loss.data[0]``) — confirm the pinned torch version
    before modernizing.
    """

    def __init__(self, data_loader, config):
        """Copy hyper-parameters, paths and step sizes from ``config``,
        build the models, and optionally restore pretrained weights."""
        # Data loader
        self.data_loader = data_loader

        # Model hyper-parameters
        self.c_dim = config.c_dim
        self.c2_dim = config.c2_dim
        self.image_size = config.image_size
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.g_repeat_num = config.g_repeat_num
        self.d_repeat_num = config.d_repeat_num
        self.d_train_repeat = config.d_train_repeat

        # Loss weights and optimizer hyper-parameters
        self.lambda_cls = config.lambda_cls
        self.lambda_rec = config.lambda_rec
        self.lambda_gp = config.lambda_gp
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2

        # Training settings
        self.num_epochs = config.num_epochs
        self.num_epochs_decay = config.num_epochs_decay
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.batch_size = config.batch_size
        self.use_tensorboard = config.use_tensorboard
        self.pretrained_model = config.pretrained_model
        self.pretrained_model_path = config.pretrained_model_path

        # Test settings
        self.test_model = config.test_model

        # Path
        self.log_path = config.log_path
        self.sample_path = config.sample_path
        self.model_save_path = config.model_save_path
        self.result_path = config.result_path

        # Step size
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step

        # Build models, and tensorboard if used
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

        # Start with trained model
        if self.pretrained_model:
            self.load_pretrained_model()

    def build_model(self):
        """Create G and D with their Adam optimizers and move them to GPU."""

        self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num,
                           self.image_size)
        self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim,
                               self.d_repeat_num)

        # Optimizers
        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr,
                                            [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr,
                                            [self.beta1, self.beta2])

        # Print networks
        self.print_network(self.G, 'G')
        self.print_network(self.D, 'D')

        if torch.cuda.is_available():
            self.G.cuda()
            self.D.cuda()

    def print_network(self, model, name):
        """Print a model's structure and its total parameter count."""
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(name)
        print(model)
        print("The number of parameters: {}".format(num_params))

    def load_pretrained_model(self):
        """Load G/D weights saved at step ``self.pretrained_model``."""
        self.G.load_state_dict(
            torch.load(
                os.path.join(self.pretrained_model_path,
                             '{}_G.pth'.format(self.pretrained_model))))
        self.D.load_state_dict(
            torch.load(
                os.path.join(self.pretrained_model_path,
                             '{}_D.pth'.format(self.pretrained_model))))
        print('loaded trained models (step: {})..!'.format(
            self.pretrained_model))

    def build_tensorboard(self):
        """Create the tensorboard logger writing to ``self.log_path``."""
        from logger import Logger
        self.logger = Logger(self.log_path)

    def update_lr(self, g_lr, d_lr):
        """Set new learning rates on both optimizers (used for decay)."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    def reset_grad(self):
        """Zero the gradients of both optimizers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def to_var(self, x, volatile=False):
        """Wrap a tensor in a Variable, moving it to GPU when available."""
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x, volatile=volatile)

    def denorm(self, x):
        """Map images from [-1, 1] back to [0, 1] for saving."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def threshold(self, x):
        """Binarize probabilities at 0.5 (returns a new float tensor)."""
        x = x.clone()
        x = (x >= 0.5).float()
        return x

    def compute_accuracy(self, x, y):
        """Per-attribute classification accuracy (%) of logits x vs labels y."""
        x = F.sigmoid(x)
        predicted = self.threshold(x)
        correct = (predicted == y).float()
        accuracy = torch.mean(correct, dim=0) * 100.0
        return accuracy

    def one_hot(self, labels, dim):
        """Convert label indices to one-hot vector"""
        batch_size = labels.size(0)
        out = torch.zeros(batch_size, dim)
        out[np.arange(batch_size), labels.long()] = 1
        return out

    def make_data_labels(self, real_c):
        """Generate domain labels for dataset for debugging/testing.

        Builds one target-label tensor per attribute dimension by
        overwriting either the first three or the remaining entries of a
        copy of ``real_c``.

        NOTE(review): the hard-coded 3-element one-hot patterns assume
        self.c_dim <= 6 and a 3+3 attribute layout — confirm.
        """

        y = [
            torch.FloatTensor([1, 0, 0]),
            torch.FloatTensor([0, 1, 0]),
            torch.FloatTensor([0, 0, 1]),
            torch.FloatTensor([1, 0, 0]),
            torch.FloatTensor([0, 1, 0]),
            torch.FloatTensor([0, 0, 1])
        ]

        fixed_c_list = []

        for i in range(self.c_dim):
            fixed_c = real_c.clone()
            for c in fixed_c:
                if i < 3:
                    c[:3] = y[i]
                else:
                    c[3:] = y[i]

            fixed_c_list.append(self.to_var(fixed_c, volatile=True))

        return fixed_c_list

    def train(self):
        """Train StarGAN within a single dataset."""

        # The number of iterations per epoch
        iters_per_epoch = len(self.data_loader)

        # Grab the first four batches as fixed debugging inputs.
        fixed_x = []
        real_c = []
        for i, (images, labels) in enumerate(self.data_loader):
            fixed_x.append(images)
            real_c.append(labels)
            if i == 3:
                break

        # Fixed inputs and target domain labels for debugging
        fixed_x = torch.cat(fixed_x, dim=0)
        fixed_x = self.to_var(fixed_x, volatile=True)
        real_c = torch.cat(real_c, dim=0)

        fixed_c_list = self.make_data_labels(real_c)

        # lr cache for decaying
        g_lr = self.g_lr
        d_lr = self.d_lr

        # Start with trained model if exists
        if self.pretrained_model:
            start = int(self.pretrained_model.split('_')[0])
        else:
            start = 0

        # Start training
        start_time = time.time()
        for e in range(start, self.num_epochs):
            for i, (real_x, real_label) in enumerate(self.data_loader):

                # Generate fake labels randomly (target domain labels)
                rand_idx = torch.randperm(real_label.size(0))
                fake_label = real_label[rand_idx]

                real_c = real_label.clone()
                fake_c = fake_label.clone()

                # Convert tensor to variable
                real_x = self.to_var(real_x)
                real_c = self.to_var(real_c)  # input for the generator
                fake_c = self.to_var(fake_c)
                real_label = self.to_var(
                    real_label
                )  # this is same as real_c if dataset == 'CelebA'
                fake_label = self.to_var(fake_label)

                # ================== Train D ================== #

                # Compute loss with real images (WGAN critic: maximize
                # D(real), hence the negated mean).
                out_src, out_cls = self.D(real_x)
                d_loss_real = -torch.mean(out_src)

                d_loss_cls = F.binary_cross_entropy_with_logits(
                    out_cls, real_label, size_average=False) / real_x.size(0)

                # Compute classification accuracy of the discriminator
                if (i + 1) % self.log_step == 0:
                    accuracies = self.compute_accuracy(out_cls, real_label)
                    log = [
                        "{:.2f}".format(acc)
                        for acc in accuracies.data.cpu().numpy()
                    ]
                    print('Classification Acc: ')
                    print(log)

                # Compute loss with fake images; re-wrapping in Variable
                # detaches the generator from the D update.
                fake_x = self.G(real_x, fake_c)
                fake_x = Variable(fake_x.data)
                out_src, out_cls = self.D(fake_x)
                d_loss_fake = torch.mean(out_src)

                # Backward + Optimize
                d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # Compute gradient penalty on random real/fake interpolations
                # (applied as a second, separate D update).
                # NOTE(review): .cuda() here is unconditional — this loop
                # assumes a GPU is present.
                alpha = torch.rand(real_x.size(0), 1, 1,
                                   1).cuda().expand_as(real_x)
                interpolated = Variable(alpha * real_x.data +
                                        (1 - alpha) * fake_x.data,
                                        requires_grad=True)
                out, out_cls = self.D(interpolated)

                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(
                                               out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                # Backward + Optimize
                d_loss = self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # Logging
                loss = {}
                loss['D/loss_real'] = d_loss_real.data[0]
                loss['D/loss_fake'] = d_loss_fake.data[0]
                loss['D/loss_cls'] = d_loss_cls.data[0]
                loss['D/loss_gp'] = d_loss_gp.data[0]

                # ================== Train G ================== #
                # G is updated once every d_train_repeat D updates.
                if (i + 1) % self.d_train_repeat == 0:

                    # Original-to-target and target-to-original domain
                    fake_x = self.G(real_x, fake_c)
                    rec_x = self.G(fake_x, real_c)

                    # Compute losses
                    out_src, out_cls = self.D(fake_x)
                    g_loss_fake = -torch.mean(out_src)
                    g_loss_rec = torch.mean(torch.abs(real_x - rec_x))

                    g_loss_cls = F.binary_cross_entropy_with_logits(
                        out_cls, fake_label,
                        size_average=False) / fake_x.size(0)

                    # Backward + Optimize
                    g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
                    self.reset_grad()
                    g_loss.backward()
                    self.g_optimizer.step()

                    # Logging
                    loss['G/loss_fake'] = g_loss_fake.data[0]
                    loss['G/loss_rec'] = g_loss_rec.data[0]
                    loss['G/loss_cls'] = g_loss_cls.data[0]

                # Print out log info
                if (i + 1) % self.log_step == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))

                    log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                        elapsed, e + 1, self.num_epochs, i + 1,
                        iters_per_epoch)

                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)

                    if self.use_tensorboard:
                        for tag, value in loss.items():
                            self.logger.scalar_summary(
                                tag, value, e * iters_per_epoch + i + 1)

                # Translate fixed images for debugging
                if (i + 1) % self.sample_step == 0:
                    fake_image_list = [fixed_x]
                    for fixed_c in fixed_c_list:
                        fake_image_list.append(self.G(fixed_x, fixed_c))
                    fake_images = torch.cat(fake_image_list, dim=3)
                    save_image(self.denorm(fake_images.data.cpu()),
                               os.path.join(
                                   self.sample_path,
                                   '{}_{}_fake.png'.format(e + 1, i + 1)),
                               nrow=1,
                               padding=0)
                    print('Translated images and saved into {}..!'.format(
                        self.sample_path))

                # Save model checkpoints
                if (i + 1) % self.model_save_step == 0:
                    torch.save(
                        self.G.state_dict(),
                        os.path.join(self.model_save_path,
                                     '{}_{}_G.pth'.format(e + 1, i + 1)))
                    torch.save(
                        self.D.state_dict(),
                        os.path.join(self.model_save_path,
                                     '{}_{}_D.pth'.format(e + 1, i + 1)))

            # Decay learning rate linearly over the last num_epochs_decay epochs
            if (e + 1) > (self.num_epochs - self.num_epochs_decay):
                g_lr -= (self.g_lr / float(self.num_epochs_decay))
                d_lr -= (self.d_lr / float(self.num_epochs_decay))
                self.update_lr(g_lr, d_lr)
                print('Decay learning rate to g_lr: {}, d_lr: {}.'.format(
                    g_lr, d_lr))

    def test(self):
        """Facial attribute transfer on CelebA or facial expression synthesis on RaFD.

        NOTE(review): this method reads ``self.dataset``, which is never set
        in __init__ — calling test() as-is raises AttributeError; confirm
        where the attribute is supposed to come from.
        """
        # Load trained parameters
        G_path = os.path.join(self.model_save_path,
                              '{}_G.pth'.format(self.test_model))
        self.G.load_state_dict(torch.load(G_path))
        self.G.eval()

        data_loader = self.data_loader

        for i, (real_x, org_c) in enumerate(data_loader):
            real_x = self.to_var(real_x, volatile=True)

            if self.dataset == 'CelebA':
                target_c_list = self.make_data_labels(org_c)
            else:
                # One one-hot target label per class for non-CelebA datasets.
                target_c_list = []
                for j in range(self.c_dim):
                    target_c = self.one_hot(
                        torch.ones(real_x.size(0)) * j, self.c_dim)
                    target_c_list.append(self.to_var(target_c, volatile=True))

            # Start translations
            fake_image_list = [real_x]
            for target_c in target_c_list:
                fake_image_list.append(self.G(real_x, target_c))
            fake_images = torch.cat(fake_image_list, dim=3)
            save_path = os.path.join(self.result_path,
                                     '{}_fake.png'.format(i + 1))
            save_image(self.denorm(fake_images.data),
                       save_path,
                       nrow=1,
                       padding=0)
            print('Translated test images and saved into "{}"..!'.format(
                save_path))
Exemple #18
0
        if isinstance(m, nn.Conv2d):
            m.weight.data.normal_(0, 0.02)
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.02)


# In[3]:

# Build the GAN pair, apply the custom weight init, and create one Adam
# optimizer per network (classic DCGAN settings: lr=2e-4, betas=(0.5, 0.999)).
netG = Generator()
netD = Discriminator()

netG.apply(weights_init)
netD.apply(weights_init)
if use_cuda:
    netG = netG.cuda()
    netD = netD.cuda()
optimizerD = optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.999))

# Fixed latent vector for monitoring sample quality across iterations.
# NOTE(review): .cuda() here is unconditional, so this line crashes on a
# CPU-only machine even when use_cuda is False — confirm intended behavior.
# (`volatile=True` is a pre-0.4 PyTorch no-grad marker.)
fixed_noise = Variable(torch.rand(1, 100), volatile=True).cuda()

# In[4]:

from torch.utils.data import Dataset, DataLoader


class FaceImgDataset(Dataset):
    """Dataset of pre-tensorized face images plus real and fake tag tensors.

    All three inputs are files previously written with ``torch.save``; they
    are loaded eagerly into memory at construction time.
    """

    def __init__(self, img_file, tag_file, fake_tag_file):
        # Load each serialized tensor file onto the matching attribute.
        for attr, path in (("training_imgs", img_file),
                           ("training_tags", tag_file),
                           ("fake_tags", fake_tag_file)):
            setattr(self, attr, torch.load(path))
Exemple #19
0
def train(batchsize, epochs):
    """Train a WGAN generator/critic pair on 64x64 images found under ./data/.

    The critic is updated on every batch (Wasserstein loss plus gradient
    penalty); the generator is updated once every 5 batches. Every 3 epochs a
    5x5 sample grid drawn from a per-epoch noise batch is written to
    ./output_grad/, and the final weights are saved to "model_best".

    Args:
        batchsize: minibatch size for the image dataloader.
        epochs: number of passes over the dataset.
    """
    dataset = dset.ImageFolder(root="./data/",
                               transform=transforms.Compose([
                                   transforms.Resize(64),
                                   transforms.CenterCrop(64),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                               ]))
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batchsize, shuffle=True, num_workers=1)

    nz = 150  # latent vector dimensionality
    netG = Generator(nz, (64, 64, 3))
    netG = netG.cuda()
    netG.apply(weights_init)
    netD = Discriminator((64, 64, 3))
    netD = netD.cuda()
    netD.apply(weights_init)

    optimizerD = optim.RMSprop(netD.parameters(), lr=0.00005, alpha=0.9)
    optimizerG = optim.RMSprop(netG.parameters(), lr=0.00005, alpha=0.9)

    G_losses = []  # per-update generator losses
    D_losses = []  # per-batch critic losses

    netG.train()
    netD.train()
    for epoch in range(epochs):
        d_loss = 0
        g_loss = 0
        # NOTE(review): fixed_noise is re-drawn every epoch, so the sample
        # grids do not track the same latents across epochs.
        fixed_noise = torch.randn(25, nz, 1, 1, device="cuda")
        print("Epoch: "+str(epoch)+"/"+str(epochs))
        is_d = 0
        for data in tqdm(dataloader):
            optimizerD.zero_grad()
            # Format batch
            real_cpu = data[0].to(device)
            b_size = real_cpu.size(0)

            # Critic score on real images
            D_real = netD(real_cpu).view(-1)
            D_real_loss = torch.mean(D_real)
            # Generate a fake batch; detached so only D is updated here
            noise = torch.randn(b_size, nz, 1, 1, device=device)
            fake = netG(noise).detach()
            # Critic score on fake images
            D_fake = netD(fake).view(-1)
            D_fake_loss = torch.mean(D_fake)

            gradient_penalty = calc_gradient_penalty(netD, real_cpu, fake, b_size)
            # Wasserstein critic loss with gradient penalty
            D_loss = D_fake_loss - D_real_loss + gradient_penalty
            D_loss.backward()
            # Update D
            optimizerD.step()
            d_loss += D_loss.item()
            D_losses.append(D_loss.item())
            is_d += 1
            # NOTE(review): weight clipping is redundant alongside a gradient
            # penalty (plain WGAN vs WGAN-GP); kept for behavioral parity.
            for p in netD.parameters():
                p.data.clamp_(-0.01, 0.01)

            # Update the generator once every 5 batches
            if is_d % 5 == 0:
                # BUG FIX: was reset to 1, which produced a 4-batch cadence
                # after the first generator update instead of the stated 5.
                is_d = 0
                # Freeze the critic for the generator step
                for p in netD.parameters():
                    p.requires_grad = False
                optimizerG.zero_grad()
                # Generate a fresh fake batch (graph kept for G's backward)
                noise = torch.randn(b_size, nz, 1, 1, device=device)
                fake = netG(noise)
                G_fake = netD(fake).view(-1)
                # Generator loss: maximize the critic score on fakes
                G_loss = -torch.mean(G_fake)
                G_loss.backward()
                # Update G
                optimizerG.step()
                g_loss += G_loss.item()
                G_losses.append(G_loss.item())
                for p in netD.parameters():
                    p.requires_grad = True
        print("D_real_loss:%.6f, D_fake_loss:%.6f"%(D_real_loss,D_fake_loss))

        # Dump a 5x5 sample grid every 3 epochs
        if epoch%3 == 0:
            with torch.no_grad():
                test_img = netG(fixed_noise).detach().cpu()
            test_img = test_img.numpy()
            test_img = np.transpose(test_img,(0,2,3,1))
            fig, axs = plt.subplots(5, 5)
            cnt = 0
            for i in range(5):
                for j in range(5):
                    axs[i,j].imshow(test_img[cnt, :,:,:])
                    axs[i,j].axis('off')
                    cnt += 1
            fig.savefig("./output_grad/"+str(epoch)+".png")
            plt.close()
        print("d loss: "+str(d_loss)+", g loss: "+str(g_loss))
    torch.save({'g': netG.state_dict(), 'd': netD.state_dict()},"model_best")
Exemple #20
0
                        num_workers=4,
                        batch_size=1,
                        shuffle=False)

# Build the SRGAN generator/discriminator and report their parameter counts.
netG = Generator(UPSCALE_FACTOR)
print('# generator parameters:',
      sum(param.numel() for param in netG.parameters()))
netD = Discriminator()
print('# discriminator parameters:',
      sum(param.numel() for param in netD.parameters()))

# Composite loss for the generator (project-defined GeneratorLoss).
generator_criterion = GeneratorLoss()

if torch.cuda.is_available():
    netG.cuda()
    netD.cuda()
    generator_criterion.cuda()

# Adam with default hyper-parameters for both networks.
optimizerG = optim.Adam(netG.parameters())
optimizerD = optim.Adam(netD.parameters())

# Per-epoch metric history accumulated during training.
results = {
    'd_loss': [],
    'g_loss': [],
    'd_score': [],
    'g_score': [],
    'psnr': [],
    'ssim': []
}

for epoch in range(1, NUM_EPOCHS + 1):
Exemple #21
0
def train(epochs, batchsize, interval, c_path, s_path):
    """Train MUNIT on two unpaired image domains A and B.

    Alternates a discriminator step and a generator (MUNIT) step per batch,
    periodically checkpointing weights to ./modeldir/ and writing an A->B
    visualization grid to outdir/.

    Args:
        epochs: number of passes over the dataset.
        batchsize: minibatch size (drop_last=True skips partial batches).
        interval: checkpoint/visualization period, in iterations.
        c_path, s_path: paths handed to HairDataset for the two domains.
    """
    # Dataset definition
    dataset = HairDataset(c_path, s_path)
    collator = CollateFn()

    # Model & optimizer definition (shared Adam hyper-parameters)
    munit = MUNIT()
    munit.cuda()
    munit.train()
    m_opt = torch.optim.Adam(munit.parameters(),
                             lr=0.0001,
                             betas=(0.5, 0.999),
                             weight_decay=0.0001)

    discriminator_a = Discriminator()
    discriminator_a.cuda()
    discriminator_a.train()
    da_opt = torch.optim.Adam(discriminator_a.parameters(),
                              lr=0.0001,
                              betas=(0.5, 0.999),
                              weight_decay=0.0001)

    discriminator_b = Discriminator()
    discriminator_b.cuda()
    discriminator_b.train()
    db_opt = torch.optim.Adam(discriminator_b.parameters(),
                              lr=0.0001,
                              betas=(0.5, 0.999),
                              weight_decay=0.0001)

    # VGG used as the perceptual-loss feature extractor
    vgg = Vgg19Norm()
    vgg.cuda()
    vgg.train()

    iterations = 0

    for epoch in range(epochs):
        dataloader = DataLoader(dataset,
                                batch_size=batchsize,
                                shuffle=True,
                                drop_last=True,
                                collate_fn=collator)
        dataloader = tqdm(dataloader)

        for i, data in enumerate(dataloader):
            iterations += 1
            a, b = data

            # ---- Discriminator step ----
            _, _, _, _, _, _, ba, ab, _, _, _, _, _, _ = munit(a, b)

            loss = adversarial_dis_loss(discriminator_a, ba, a)
            loss += adversarial_dis_loss(discriminator_b, ab, b)

            da_opt.zero_grad()
            db_opt.zero_grad()
            loss.backward()
            da_opt.step()
            db_opt.step()

            # ---- Generator (MUNIT) step: fresh forward pass ----
            c_a, s_a, c_b, s_b, a_recon, \
                b_recon, ba, ab, c_b_recon, s_a_recon, c_a_recon, s_b_recon, aba, bab = munit(a, b)

            loss = adversarial_gen_loss(discriminator_a, ba)
            loss += adversarial_gen_loss(discriminator_b, ab)
            loss += 10 * reconstruction_loss(a_recon, a)   # image reconstruction
            loss += 10 * reconstruction_loss(b_recon, b)
            loss += reconstruction_loss(c_a, c_a_recon)    # content-code reconstruction
            loss += reconstruction_loss(c_b, c_b_recon)
            loss += reconstruction_loss(s_a, s_a_recon)    # style-code reconstruction
            loss += reconstruction_loss(s_b, s_b_recon)
            loss += 10 * reconstruction_loss(aba, a)       # cycle consistency
            loss += 10 * reconstruction_loss(bab, b)
            loss += perceptual_loss(vgg, ba, b)
            loss += perceptual_loss(vgg, ab, a)

            m_opt.zero_grad()
            loss.backward()
            m_opt.step()

            if iterations % interval == 1:
                # BUG FIX: the original passed the bound method
                # `munit.load_state_dict` to torch.save, serializing a method
                # object instead of the model weights.
                torch.save(munit.state_dict(),
                           f"./modeldir/model_{iterations}.pt")

                pylab.rcParams['figure.figsize'] = (16.0, 16.0)
                pylab.clf()

                munit.eval()

                with torch.no_grad():
                    _, _, _, _, _, _, _, ab, _, _, _, _, _, _ = munit(a, b)
                    fake = ab.detach().cpu().numpy()
                    real = a.detach().cpu().numpy()

                    # Use `j` so the dataloader loop index `i` is not clobbered.
                    for j in range(batchsize):
                        # De-normalize from [-1, 1] to [0, 255] uint8, CHW -> HWC
                        tmp = (np.clip(real[j] * 127.5 + 127.5, 0,
                                       255)).transpose(1, 2, 0).astype(np.uint8)
                        pylab.subplot(4, 4, 2 * j + 1)
                        pylab.imshow(tmp)
                        pylab.axis("off")
                        tmp = (np.clip(fake[j] * 127.5 + 127.5, 0,
                                       255)).transpose(1, 2, 0).astype(np.uint8)
                        pylab.subplot(4, 4, 2 * j + 2)
                        pylab.imshow(tmp)
                        pylab.axis("off")
                    # Write the grid once, after all subplots are drawn
                    # (the original redundantly saved after every subplot).
                    pylab.savefig("outdir/visualize_{}.png".format(iterations))

                munit.train()

            print(f"iter: {iterations} loss: {loss.data}")
Exemple #22
0
class Solver(object):

    def __init__(self, celebA_loader, rafd_loader, config):
        """Store loaders and config settings, then build networks.

        Args:
            celebA_loader: dataloader for the CelebA dataset.
            rafd_loader: dataloader for the RaFD dataset.
            config: namespace carrying all hyper-parameters, paths and
                training/test settings (see attribute list below).
        """
        # Data loaders
        self.celebA_loader = celebA_loader
        self.rafd_loader = rafd_loader

        # Copy every scalar setting straight off the config object:
        # model sizes, loss weights, optimizer settings, dataset/training
        # options, test settings, output paths and step sizes.
        for name in ('c_dim', 'c2_dim', 'image_size', 'g_conv_dim',
                     'd_conv_dim', 'g_repeat_num', 'd_repeat_num',
                     'd_train_repeat',
                     'lambda_cls', 'lambda_rec', 'lambda_gp',
                     'g_lr', 'd_lr', 'beta1', 'beta2',
                     'dataset', 'num_epochs', 'num_epochs_decay',
                     'num_iters', 'num_iters_decay', 'batch_size',
                     'use_tensorboard', 'pretrained_model',
                     'test_model',
                     'log_path', 'sample_path', 'model_save_path',
                     'result_path',
                     'log_step', 'sample_step', 'model_save_step'):
            setattr(self, name, getattr(config, name))

        # Build G/D (+ tensorboard logger when enabled)
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

        # Resume from a saved checkpoint when one is specified
        if self.pretrained_model:
            self.load_pretrained_model()

    def build_model(self):
        """Instantiate G and D plus their Adam optimizers.

        For joint ('Both') training the label dimension is widened with the
        second dataset's classes and a 2-dim dataset mask vector.
        """
        joint = self.dataset == 'Both'
        g_c_dim = self.c_dim + self.c2_dim + 2 if joint else self.c_dim  # +2 for mask vector
        d_c_dim = self.c_dim + self.c2_dim if joint else self.c_dim
        self.G = Generator(self.g_conv_dim, g_c_dim, self.g_repeat_num)
        self.D = Discriminator(self.image_size, self.d_conv_dim, d_c_dim, self.d_repeat_num)

        # Adam with the configured (beta1, beta2) pair for both networks
        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])

        # Report both architectures and their parameter counts
        self.print_network(self.G, 'G')
        self.print_network(self.D, 'D')

        if torch.cuda.is_available():
            self.G.cuda()
            self.D.cuda()

    def print_network(self, model, name):
        """Print a network's name, structure and total parameter count."""
        total = sum(p.numel() for p in model.parameters())
        print(name)
        print(model)
        print("The number of parameters: {}".format(total))

    def load_pretrained_model(self):
        self.G.load_state_dict(torch.load(os.path.join(
            self.model_save_path, '{}_G.pth'.format(self.pretrained_model))))
        self.D.load_state_dict(torch.load(os.path.join(
            self.model_save_path, '{}_D.pth'.format(self.pretrained_model))))
        print('loaded trained models (step: {})..!'.format(self.pretrained_model))

    def build_tensorboard(self):
        """Attach the project's tensorboard Logger, writing under self.log_path."""
        from logger import Logger
        self.logger = Logger(self.log_path)

    def update_lr(self, g_lr, d_lr):
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    def reset_grad(self):
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def to_var(self, x, volatile=False):
        """Move `x` to GPU when available and wrap it in a (legacy) Variable.

        `volatile` is the pre-0.4 PyTorch inference flag; it is forwarded
        unchanged for compatibility with the rest of this code base.
        """
        target = x.cuda() if torch.cuda.is_available() else x
        return Variable(target, volatile=volatile)

    def denorm(self, x):
        """Map an image tensor from [-1, 1] back to [0, 1], clamping outliers."""
        return ((x + 1) / 2).clamp_(0, 1)

    def threshold(self, x):
        """Binarize at 0.5 (>= 0.5 -> 1, else 0) without mutating the input."""
        out = x.clone()
        mask = out >= 0.5
        out[mask] = 1
        out[~mask] = 0
        return out

    def compute_accuracy(self, x, y, dataset):
        if dataset == 'CelebA':
            x = F.sigmoid(x)
            predicted = self.threshold(x)
            correct = (predicted == y).float()
            accuracy = torch.mean(correct, dim=0) * 100.0
        else:
            _, predicted = torch.max(x, dim=1)
            correct = (predicted == y).float()
            accuracy = torch.mean(correct) * 100.0
        return accuracy

    def one_hot(self, labels, dim):
        """Convert a 1-D tensor of label indices to one-hot rows of width `dim`."""
        n = labels.size(0)
        out = torch.zeros(n, dim)
        out[torch.arange(n), labels.long()] = 1
        return out

    def make_celeb_labels(self, real_c):
        """Generate domain labels for CelebA for debugging/testing.

        if dataset == 'CelebA':
            return single and multiple attribute changes
        elif dataset == 'Both':
            return single attribute changes
        """
        # One-hot hair-color targets; slots 0-2 of the label vector are
        # hair color, slot 3 gender, slot 4 age (see comments below).
        y = [torch.FloatTensor([1, 0, 0]),  # black hair
             torch.FloatTensor([0, 1, 0]),  # blond hair
             torch.FloatTensor([0, 0, 1])]  # brown hair

        fixed_c_list = []

        # single attribute transfer: one list entry per attribute, mutating
        # each sample's label copy in place (rows mutate through `c`).
        for i in range(self.c_dim):
            fixed_c = real_c.clone()
            for c in fixed_c:
                if i < 3:
                    # first three attributes are the mutually-exclusive hair colors
                    c[:3] = y[i]
                else:
                    c[i] = 0 if c[i] == 1 else 1   # opposite value
            fixed_c_list.append(self.to_var(fixed_c, volatile=True))

        # multi-attribute transfer (H+G, H+A, G+A, H+G+A)
        if self.dataset == 'CelebA':
            for i in range(4):
                fixed_c = real_c.clone()
                for c in fixed_c:
                    if i in [0, 1, 3]:   # Hair color to brown
                        c[:3] = y[2] 
                    if i in [0, 2, 3]:   # Gender
                        c[3] = 0 if c[3] == 1 else 1
                    if i in [1, 2, 3]:   # Aged
                        c[4] = 0 if c[4] == 1 else 1
                fixed_c_list.append(self.to_var(fixed_c, volatile=True))
        return fixed_c_list

    def train(self):
        """Train StarGAN within a single dataset (CelebA or RaFD).

        Runs WGAN-GP critic updates every iteration and generator updates
        every `d_train_repeat` iterations; logs losses, periodically
        translates a fixed debug batch, and checkpoints G and D.
        """

        # Set dataloader
        if self.dataset == 'CelebA':
            self.data_loader = self.celebA_loader
        else:
            self.data_loader = self.rafd_loader

        # The number of iterations per epoch
        iters_per_epoch = len(self.data_loader)

        # Collect the first 4 batches as a fixed debug set.
        fixed_x = []
        real_c = []
        for i, (images, labels) in enumerate(self.data_loader):
            fixed_x.append(images)
            real_c.append(labels)
            if i == 3:
                break

        # Fixed inputs and target domain labels for debugging
        fixed_x = torch.cat(fixed_x, dim=0)
        fixed_x = self.to_var(fixed_x, volatile=True)
        real_c = torch.cat(real_c, dim=0)

        if self.dataset == 'CelebA':
            fixed_c_list = self.make_celeb_labels(real_c)
        elif self.dataset == 'RaFD':
            # One one-hot target label per expression class.
            fixed_c_list = []
            for i in range(self.c_dim):
                fixed_c = self.one_hot(torch.ones(fixed_x.size(0)) * i, self.c_dim)
                fixed_c_list.append(self.to_var(fixed_c, volatile=True))

        # lr cache for decaying
        g_lr = self.g_lr
        d_lr = self.d_lr

        # Start with trained model if exists (checkpoint name starts with the epoch number)
        if self.pretrained_model:
            start = int(self.pretrained_model.split('_')[0])
        else:
            start = 0

        # Start training
        start_time = time.time()
        for e in range(start, self.num_epochs):
            for i, (real_x, real_label) in enumerate(self.data_loader):

                # Generate fake labels randomly (target domain labels)
                rand_idx = torch.randperm(real_label.size(0))
                fake_label = real_label[rand_idx]

                if self.dataset == 'CelebA':
                    # Multi-label attribute vectors are used as-is.
                    real_c = real_label.clone()
                    fake_c = fake_label.clone()
                else:
                    # Categorical labels become one-hot generator inputs.
                    real_c = self.one_hot(real_label, self.c_dim)
                    fake_c = self.one_hot(fake_label, self.c_dim)

                # Convert tensor to variable
                real_x = self.to_var(real_x)
                real_c = self.to_var(real_c)           # input for the generator
                fake_c = self.to_var(fake_c)
                real_label = self.to_var(real_label)   # this is same as real_c if dataset == 'CelebA'
                fake_label = self.to_var(fake_label)

                # ================== Train D ================== #

                # Compute loss with real images (WGAN critic: high score on real)
                out_src, out_cls = self.D(real_x)
                d_loss_real = - torch.mean(out_src)

                if self.dataset == 'CelebA':
                    # Multi-label BCE, summed then normalized by batch size.
                    d_loss_cls = F.binary_cross_entropy_with_logits(
                        out_cls, real_label, size_average=False) / real_x.size(0)
                else:
                    d_loss_cls = F.cross_entropy(out_cls, real_label)

                # Compute classification accuracy of the discriminator
                if (i+1) % self.log_step == 0:
                    accuracies = self.compute_accuracy(out_cls, real_label, self.dataset)
                    log = ["{:.2f}".format(acc) for acc in accuracies.data.cpu().numpy()]
                    if self.dataset == 'CelebA':
                        print('Classification Acc (Black/Blond/Brown/Gender/Aged): ', end='')
                    else:
                        print('Classification Acc (8 emotional expressions): ', end='')
                    print(log)

                # Compute loss with fake images (re-wrapped to cut G out of the graph)
                fake_x = self.G(real_x, fake_c)
                fake_x = Variable(fake_x.data)
                out_src, out_cls = self.D(fake_x)
                d_loss_fake = torch.mean(out_src)

                # Backward + Optimize
                d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # Compute gradient penalty on random real/fake interpolations (WGAN-GP)
                alpha = torch.rand(real_x.size(0), 1, 1, 1).cuda().expand_as(real_x)
                interpolated = Variable(alpha * real_x.data + (1 - alpha) * fake_x.data, requires_grad=True)
                out, out_cls = self.D(interpolated)

                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                # Penalize deviation of the per-sample gradient norm from 1.
                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                # Backward + Optimize (separate step for the penalty term)
                d_loss = self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # Logging
                # NOTE(review): `.data[0]` is pre-0.4 PyTorch; use `.item()` on modern versions.
                loss = {}
                loss['D/loss_real'] = d_loss_real.data[0]
                loss['D/loss_fake'] = d_loss_fake.data[0]
                loss['D/loss_cls'] = d_loss_cls.data[0]
                loss['D/loss_gp'] = d_loss_gp.data[0]

                # ================== Train G ================== #
                if (i+1) % self.d_train_repeat == 0:

                    # Original-to-target and target-to-original domain
                    fake_x = self.G(real_x, fake_c)
                    rec_x = self.G(fake_x, real_c)

                    # Compute losses (adversarial + cycle reconstruction)
                    out_src, out_cls = self.D(fake_x)
                    g_loss_fake = - torch.mean(out_src)
                    g_loss_rec = torch.mean(torch.abs(real_x - rec_x))

                    if self.dataset == 'CelebA':
                        g_loss_cls = F.binary_cross_entropy_with_logits(
                            out_cls, fake_label, size_average=False) / fake_x.size(0)
                    else:
                        g_loss_cls = F.cross_entropy(out_cls, fake_label)

                    # Backward + Optimize
                    g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
                    self.reset_grad()
                    g_loss.backward()
                    self.g_optimizer.step()

                    # Logging
                    loss['G/loss_fake'] = g_loss_fake.data[0]
                    loss['G/loss_rec'] = g_loss_rec.data[0]
                    loss['G/loss_cls'] = g_loss_cls.data[0]

                # Print out log info
                if (i+1) % self.log_step == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))

                    log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                        elapsed, e+1, self.num_epochs, i+1, iters_per_epoch)

                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)

                    if self.use_tensorboard:
                        for tag, value in loss.items():
                            self.logger.scalar_summary(tag, value, e * iters_per_epoch + i + 1)

                # Translate fixed images for debugging (inputs + one column per target label)
                if (i+1) % self.sample_step == 0:
                    fake_image_list = [fixed_x]
                    for fixed_c in fixed_c_list:
                        fake_image_list.append(self.G(fixed_x, fixed_c))
                    fake_images = torch.cat(fake_image_list, dim=3)
                    save_image(self.denorm(fake_images.data),
                        os.path.join(self.sample_path, '{}_{}_fake.png'.format(e+1, i+1)),nrow=1, padding=0)
                    print('Translated images and saved into {}..!'.format(self.sample_path))

                # Save model checkpoints
                if (i+1) % self.model_save_step == 0:
                    torch.save(self.G.state_dict(),
                        os.path.join(self.model_save_path, '{}_{}_G.pth'.format(e+1, i+1)))
                    torch.save(self.D.state_dict(),
                        os.path.join(self.model_save_path, '{}_{}_D.pth'.format(e+1, i+1)))

            # Decay learning rate linearly over the last `num_epochs_decay` epochs
            if (e+1) > (self.num_epochs - self.num_epochs_decay):
                g_lr -= (self.g_lr / float(self.num_epochs_decay))
                d_lr -= (self.d_lr / float(self.num_epochs_decay))
                self.update_lr(g_lr, d_lr)
                print ('Decay learning rate to g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))

    def train_multi(self):
        """Train StarGAN with multiple datasets.
        In the code below, 1 is related to CelebA and 2 is releated to RaFD.
        """
        # Fixed imagse and labels for debugging
        fixed_x = []
        real_c = []

        for i, (images, labels) in enumerate(self.celebA_loader):
            fixed_x.append(images)
            real_c.append(labels)
            if i == 2:
                break

        fixed_x = torch.cat(fixed_x, dim=0)
        fixed_x = self.to_var(fixed_x, volatile=True)
        real_c = torch.cat(real_c, dim=0)
        fixed_c1_list = self.make_celeb_labels(real_c)

        fixed_c2_list = []
        for i in range(self.c2_dim):
            fixed_c = self.one_hot(torch.ones(fixed_x.size(0)) * i, self.c2_dim)
            fixed_c2_list.append(self.to_var(fixed_c, volatile=True))

        fixed_zero1 = self.to_var(torch.zeros(fixed_x.size(0), self.c2_dim))     # zero vector when training with CelebA
        fixed_mask1 = self.to_var(self.one_hot(torch.zeros(fixed_x.size(0)), 2)) # mask vector: [1, 0]
        fixed_zero2 = self.to_var(torch.zeros(fixed_x.size(0), self.c_dim))      # zero vector when training with RaFD
        fixed_mask2 = self.to_var(self.one_hot(torch.ones(fixed_x.size(0)), 2))  # mask vector: [0, 1]

        # lr cache for decaying
        g_lr = self.g_lr
        d_lr = self.d_lr

        # data iterator
        data_iter1 = iter(self.celebA_loader)
        data_iter2 = iter(self.rafd_loader)

        # Start with trained model
        if self.pretrained_model:
            start = int(self.pretrained_model) + 1
        else:
            start = 0

        # # Start training
        start_time = time.time()
        for i in range(start, self.num_iters):

            # Fetch mini-batch images and labels
            try:
                real_x1, real_label1 = next(data_iter1)
            except:
                data_iter1 = iter(self.celebA_loader)
                real_x1, real_label1 = next(data_iter1)

            try:
                real_x2, real_label2 = next(data_iter2)
            except:
                data_iter2 = iter(self.rafd_loader)
                real_x2, real_label2 = next(data_iter2)

            # Generate fake labels randomly (target domain labels)
            rand_idx = torch.randperm(real_label1.size(0))
            fake_label1 = real_label1[rand_idx]
            rand_idx = torch.randperm(real_label2.size(0))
            fake_label2 = real_label2[rand_idx]

            real_c1 = real_label1.clone()
            fake_c1 = fake_label1.clone()
            zero1 = torch.zeros(real_x1.size(0), self.c2_dim)
            mask1 = self.one_hot(torch.zeros(real_x1.size(0)), 2)

            real_c2 = self.one_hot(real_label2, self.c2_dim)
            fake_c2 = self.one_hot(fake_label2, self.c2_dim)
            zero2 = torch.zeros(real_x2.size(0), self.c_dim)
            mask2 = self.one_hot(torch.ones(real_x2.size(0)), 2)

            # Convert tensor to variable
            real_x1 = self.to_var(real_x1)
            real_c1 = self.to_var(real_c1)
            fake_c1 = self.to_var(fake_c1)
            mask1 = self.to_var(mask1)
            zero1 = self.to_var(zero1)

            real_x2 = self.to_var(real_x2)
            real_c2 = self.to_var(real_c2)
            fake_c2 = self.to_var(fake_c2)
            mask2 = self.to_var(mask2)
            zero2 = self.to_var(zero2)

            real_label1 = self.to_var(real_label1)
            fake_label1 = self.to_var(fake_label1)
            real_label2 = self.to_var(real_label2)
            fake_label2 = self.to_var(fake_label2)

            # ================== Train D ================== #

            # Real images (CelebA)
            out_real, out_cls = self.D(real_x1)
            out_cls1 = out_cls[:, :self.c_dim]      # celebA part
            d_loss_real = - torch.mean(out_real)
            d_loss_cls = F.binary_cross_entropy_with_logits(out_cls1, real_label1, size_average=False) / real_x1.size(0)

            # Real images (RaFD)
            out_real, out_cls = self.D(real_x2)
            out_cls2 = out_cls[:, self.c_dim:]      # rafd part
            d_loss_real += - torch.mean(out_real)
            d_loss_cls += F.cross_entropy(out_cls2, real_label2)

            # Compute classification accuracy of the discriminator
            if (i+1) % self.log_step == 0:
                accuracies = self.compute_accuracy(out_cls1, real_label1, 'CelebA')
                log = ["{:.2f}".format(acc) for acc in accuracies.data.cpu().numpy()]
                print('Classification Acc (Black/Blond/Brown/Gender/Aged): ', end='')
                print(log)
                accuracies = self.compute_accuracy(out_cls2, real_label2, 'RaFD')
                log = ["{:.2f}".format(acc) for acc in accuracies.data.cpu().numpy()]
                print('Classification Acc (8 emotional expressions): ', end='')
                print(log)

            # Fake images (CelebA)
            fake_c = torch.cat([fake_c1, zero1, mask1], dim=1)
            fake_x1 = self.G(real_x1, fake_c)
            fake_x1 = Variable(fake_x1.data)
            out_fake, _ = self.D(fake_x1)
            d_loss_fake = torch.mean(out_fake)

            # Fake images (RaFD)
            fake_c = torch.cat([zero2, fake_c2, mask2], dim=1)
            fake_x2 = self.G(real_x2, fake_c)
            out_fake, _ = self.D(fake_x2)
            d_loss_fake += torch.mean(out_fake)

            # Backward + Optimize
            d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Compute gradient penalty
            if (i+1) % 2 == 0:
                real_x = real_x1
                fake_x = fake_x1
            else:
                real_x = real_x2
                fake_x = fake_x2

            alpha = torch.rand(real_x.size(0), 1, 1, 1).cuda().expand_as(real_x)
            interpolated = Variable(alpha * real_x.data + (1 - alpha) * fake_x.data, requires_grad=True)
            out, out_cls = self.D(interpolated)

            if (i+1) % 2 == 0:
                out_cls = out_cls[:, :self.c_dim]  # CelebA
            else:
                out_cls = out_cls[:, self.c_dim:]  # RaFD

            grad = torch.autograd.grad(outputs=out,
                                       inputs=interpolated,
                                       grad_outputs=torch.ones(out.size()).cuda(),
                                       retain_graph=True,
                                       create_graph=True,
                                       only_inputs=True)[0]

            grad = grad.view(grad.size(0), -1)
            grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
            d_loss_gp = torch.mean((grad_l2norm - 1)**2)

            # Backward + Optimize
            d_loss = self.lambda_gp * d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Logging
            loss = {}
            loss['D/loss_real'] = d_loss_real.data[0]
            loss['D/loss_fake'] = d_loss_fake.data[0]
            loss['D/loss_cls'] = d_loss_cls.data[0]
            loss['D/loss_gp'] = d_loss_gp.data[0]

            # ================== Train G ================== #
            if (i+1) % self.d_train_repeat == 0:
                # Original-to-target and target-to-original domain (CelebA)
                fake_c = torch.cat([fake_c1, zero1, mask1], dim=1)
                real_c = torch.cat([real_c1, zero1, mask1], dim=1)
                fake_x1 = self.G(real_x1, fake_c)
                rec_x1 = self.G(fake_x1, real_c)

                # Compute losses
                out, out_cls = self.D(fake_x1)
                out_cls1 = out_cls[:, :self.c_dim]
                g_loss_fake = - torch.mean(out)
                g_loss_rec = torch.mean(torch.abs(real_x1 - rec_x1))
                g_loss_cls = F.binary_cross_entropy_with_logits(out_cls1, fake_label1, size_average=False) / fake_x1.size(0)

                # Original-to-target and target-to-original domain (RaFD)
                fake_c = torch.cat([zero2, fake_c2, mask2], dim=1)
                real_c = torch.cat([zero2, real_c2, mask2], dim=1)
                fake_x2 = self.G(real_x2, fake_c)
                rec_x2 = self.G(fake_x2, real_c)

                # Compute losses
                out, out_cls = self.D(fake_x2)
                out_cls2 = out_cls[:, self.c_dim:]
                g_loss_fake += - torch.mean(out)
                g_loss_rec += torch.mean(torch.abs(real_x2 - rec_x2))
                g_loss_cls += F.cross_entropy(out_cls2, fake_label2)

                # Backward + Optimize
                g_loss = g_loss_fake + self.lambda_cls * g_loss_cls + self.lambda_rec * g_loss_rec
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Logging
                loss['G/loss_fake'] = g_loss_fake.data[0]
                loss['G/loss_cls'] = g_loss_cls.data[0]
                loss['G/loss_rec'] = g_loss_rec.data[0]

            # Print out log info
            if (i+1) % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))

                log = "Elapsed [{}], Iter [{}/{}]".format(
                    elapsed, i+1, self.num_iters)

                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, i+1)

            # Translate the images (debugging)
            if (i+1) % self.sample_step == 0:
                fake_image_list = [fixed_x]

                # Changing hair color, gender, and age
                for j in range(self.c_dim):
                    fake_c = torch.cat([fixed_c1_list[j], fixed_zero1, fixed_mask1], dim=1)
                    fake_image_list.append(self.G(fixed_x, fake_c))
                # Changing emotional expressions
                for j in range(self.c2_dim):
                    fake_c = torch.cat([fixed_zero2, fixed_c2_list[j], fixed_mask2], dim=1)
                    fake_image_list.append(self.G(fixed_x, fake_c))
                fake = torch.cat(fake_image_list, dim=3)

                # Save the translated images
                save_image(self.denorm(fake.data),
                    os.path.join(self.sample_path, '{}_fake.png'.format(i+1)), nrow=1, padding=0)

            # Save model checkpoints
            if (i+1) % self.model_save_step == 0:
                torch.save(self.G.state_dict(),
                    os.path.join(self.model_save_path, '{}_G.pth'.format(i+1)))
                torch.save(self.D.state_dict(),
                    os.path.join(self.model_save_path, '{}_D.pth'.format(i+1)))

            # Decay learning rate
            decay_step = 1000
            if (i+1) > (self.num_iters - self.num_iters_decay) and (i+1) % decay_step==0:
                g_lr -= (self.g_lr / float(self.num_iters_decay) * decay_step)
                d_lr -= (self.d_lr / float(self.num_iters_decay) * decay_step)
                self.update_lr(g_lr, d_lr)
                print ('Decay learning rate to g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))

    def test(self):
        """Facial attribute transfer on CelebA or facial expression synthesis on RaFD."""
        # Restore the generator weights from the checkpoint chosen via config.
        G_path = os.path.join(self.model_save_path, '{}_G.pth'.format(self.test_model))
        self.G.load_state_dict(torch.load(G_path))
        self.G.eval()

        # Pick the evaluation loader matching the configured dataset.
        data_loader = self.celebA_loader if self.dataset == 'CelebA' else self.rafd_loader

        for idx, (real_x, org_c) in enumerate(data_loader):
            real_x = self.to_var(real_x, volatile=True)

            # Build one target label per attribute (CelebA) or expression (RaFD).
            if self.dataset == 'CelebA':
                target_c_list = self.make_celeb_labels(org_c)
            else:
                target_c_list = [
                    self.to_var(self.one_hot(torch.ones(real_x.size(0)) * j, self.c_dim),
                                volatile=True)
                    for j in range(self.c_dim)
                ]

            # Translate once per target and tile input + translations side by side.
            outputs = [real_x]
            outputs.extend(self.G(real_x, target_c) for target_c in target_c_list)
            fake_images = torch.cat(outputs, dim=3)
            save_path = os.path.join(self.result_path, '{}_fake.png'.format(idx+1))
            save_image(self.denorm(fake_images.data), save_path, nrow=1, padding=0)
            print('Translated test images and saved into "{}"..!'.format(save_path))

    def test_multi(self):
        """Facial attribute transfer and expression synthesis on CelebA."""
        # Restore the generator weights from the checkpoint chosen via config.
        G_path = os.path.join(self.model_save_path, '{}_G.pth'.format(self.test_model))
        self.G.load_state_dict(torch.load(G_path))
        self.G.eval()

        for idx, (real_x, org_c) in enumerate(self.celebA_loader):

            # Prepare the input batch and both families of target labels.
            real_x = self.to_var(real_x, volatile=True)
            batch = real_x.size(0)
            target_c1_list = self.make_celeb_labels(org_c)
            target_c2_list = [
                self.to_var(self.one_hot(torch.ones(batch) * j, self.c2_dim), volatile=True)
                for j in range(self.c2_dim)
            ]

            # Zero vectors and mask vectors
            zero1 = self.to_var(torch.zeros(batch, self.c2_dim))     # zero vector for rafd expressions
            mask1 = self.to_var(self.one_hot(torch.zeros(batch), 2)) # mask vector: [1, 0]
            zero2 = self.to_var(torch.zeros(batch, self.c_dim))      # zero vector for celebA attributes
            mask2 = self.to_var(self.one_hot(torch.ones(batch), 2))  # mask vector: [0, 1]

            fake_image_list = [real_x]
            # Changing hair color, gender, and age
            for c1 in target_c1_list:
                fake_image_list.append(self.G(real_x, torch.cat([c1, zero1, mask1], dim=1)))
            # Changing emotional expressions
            for c2 in target_c2_list:
                fake_image_list.append(self.G(real_x, torch.cat([zero2, c2, mask2], dim=1)))
            fake_images = torch.cat(fake_image_list, dim=3)

            # Save the translated images
            save_path = os.path.join(self.result_path, '{}_fake.png'.format(idx+1))
            save_image(self.denorm(fake_images.data), save_path, nrow=1, padding=0)
            print('Translated test images and saved into "{}"..!'.format(save_path))
Exemple #23
0

# Hyperparameters for the GAN training run below.
num_epoch = 5
batchSize = 64
lr = 0.0002
l1_lambda = 10  # weight of an L1 term — presumably used in the loss later; TODO confirm

# Text logger for console output and a tensorboard-style Logger for scalars.
text_logger = setup_logger('Train')
logger = Logger('./logs')

# Build both networks and apply the custom weight initialisation.
discriminator = Discriminator()
generator = Generator()
discriminator.apply(weights_init)
generator.apply(weights_init)
if torch.cuda.is_available():
    discriminator.cuda()
    generator.cuda()

loss_function = nn.CrossEntropyLoss()
# Adam with beta1=0.5 (the common DCGAN setting) for both optimizers.
d_optim = torch.optim.Adam(discriminator.parameters(), lr, [0.5, 0.999])
g_optim = torch.optim.Adam(generator.parameters(), lr, [0.5, 0.999])

# NOTE(review): DataLoader here is a project class constructed from a batch
# size, not torch.utils.data.DataLoader — verify its interface.
dataloader = DataLoader(batchSize)
data_size = len(dataloader.train_index)
num_batch = data_size // batchSize
#text_logger.info('Total number of videos for train = ' + str(data_size))
#text_logger.info('Total number of batches per echo = ' + str(num_batch))

start_time = time.time()
counter = 0
DIR_TO_SAVE = "./gen_videos/"
Exemple #24
0
def train():
    """Train a WGAN-GP generator/discriminator pair (MNIST).

    Side effects: creates ``images/`` and ``checkpoints/`` directories,
    writes sample image grids every ``opt.sample_interval`` batches,
    saves generator checkpoints, and prints per-batch losses.
    """
    os.makedirs("images", exist_ok=True)
    os.makedirs("checkpoints", exist_ok=True)

    cuda = True if torch.cuda.is_available() else False
    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    # get configs and dataloader
    opt = parse_args()
    data_loader = mnist_loader(opt)

    # Initialize generator and discriminator
    generator = Generator(opt)
    discriminator = Discriminator(opt)

    if cuda:
        generator.cuda()
        discriminator.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))

    for epoch in range(opt.epochs):
        for i, (imgs, _) in enumerate(data_loader):

            # Sample latent noise and generate a fake batch.
            z = Variable(
                Tensor(np.random.normal(0, 1,
                                        (imgs.shape[0], opt.latent_dim))))
            gen_imgs = generator(z)
            real_imgs = Variable(imgs.type(Tensor))

            # ------------------
            # Train Discriminator
            # ------------------

            optimizer_D.zero_grad()
            gradient_penalty = compute_gradient_penalty(
                discriminator, real_imgs.data, gen_imgs.data, Tensor)
            # FIX: detach() the fake batch so the critic update does not
            # backpropagate into the generator. The original built and
            # traversed the generator graph here only to discard those
            # gradients later at optimizer_G.zero_grad() — wasted work.
            d_loss = torch.mean(discriminator(gen_imgs.detach())) - torch.mean(
                discriminator(real_imgs)) + opt.lambda_gp * gradient_penalty

            d_loss.backward()
            optimizer_D.step()

            # ------------------
            # Train Generator
            # ------------------

            # WGAN-GP schedule: one generator step per n_critic critic steps.
            # Runs at i == 0, so g_loss is always defined for the log below.
            if i % opt.n_critic == 0:
                optimizer_G.zero_grad()
                g_loss = -torch.mean(discriminator(generator(z)))

                g_loss.backward()
                optimizer_G.step()

            # ------------------
            # Log Information
            # ------------------

            print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" %
                  (epoch, opt.epochs, i, len(data_loader), d_loss.item(),
                   g_loss.item()))

            batches_done = epoch * len(data_loader) + i
            if batches_done % opt.sample_interval == 0:
                save_image(gen_imgs.data[:25],
                           "images/%d.png" % batches_done,
                           nrow=5,
                           normalize=True)

            if batches_done % opt.checkpoint_interval == 0:
                torch.save(generator.state_dict(),
                           "checkpoints/generator_%d.pth" % epoch)
                # torch.save(discriminator.state_dict(), "checkpoints/discriminator_%d.pth" % epoch)

    torch.save(generator.state_dict(), "checkpoints/generator_done.pth")
    print("Training Process has been Done!")
Exemple #25
0
class Solver(object):
    """GAN training driver supporting the original (BCE) loss and WGAN-GP.

    Builds a generator/discriminator pair from layer specs in ``config``,
    trains on ``data_loader``, logs scalars via ``Logger``, periodically
    saves image samples, and checkpoints the models with the best
    Inception Score and FID observed so far.
    """
    def __init__(self, config, data_loader):
        """Copy hyperparameters from ``config`` and build the networks."""
        # Networks and optimizers are created in build_model().
        self.generator = None
        self.discriminator = None
        self.g_optimizer = None
        self.d_optimizer = None
        self.pc_name = config.pc_name
        self.base_path = config.base_path
        self.time_now = config.time_now
        self.inject_z = config.inject_z
        self.data_loader = data_loader
        self.num_epochs = config.num_epochs
        self.sample_size = config.sample_size
        self.logs_path = config.logs_path
        self.save_every = config.save_every
        self.activation_fn = config.activation_fn
        self.max_score = config.max_score
        self.lr = config.lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.validation_step = config.validation_step
        self.sample_path = config.sample_path
        self.model_path = config.model_path
        self.g_layers = config.g_layers
        self.d_layers = config.d_layers
        # Latent dimension is taken from the first generator layer spec.
        self.z_dim = self.g_layers[0]
        self.num_imgs_val = config.num_imgs_val
        self.criterion = nn.BCEWithLogitsLoss()
        self.ckpt_gen_path = config.ckpt_gen_path
        self.gp_weight = config.gp_weight
        self.loss = config.loss
        self.seed = config.seed
        self.validation_path = config.validation_path
        self.FID_images = config.FID_images
        self.transform_rep = config.transform_rep
        self.transform_z = config.transform_z
        self.spectral_norm = config.spectral_norm
        self.cifar10_path = config.cifar10_path
        # Lower is better for FID; start from a large sentinel value.
        self.fid_score = 100000
        self.concat_injection = config.concat_injection
        self.norm = config.norm
        self.build_model()

    def build_model(self):
        """Instantiate networks, optimizers and logger; print parameter counts."""
        # Fix the seed before weight init so runs are reproducible.
        torch.manual_seed(self.seed)
        self.generator = Generator(g_layers=self.g_layers,
                                   activation_fn=self.activation_fn,
                                   inject_z=self.inject_z,
                                   transform_rep=self.transform_rep,
                                   transform_z=self.transform_z,
                                   concat_injection=self.concat_injection,
                                   norm=self.norm)
        self.discriminator = Discriminator(d_layers=self.d_layers,
                                           activation_fn=self.activation_fn,
                                           spectral_norm=self.spectral_norm)
        self.generator.apply(self.weights_init)
        self.discriminator.apply(self.weights_init)
        self.g_optimizer = optim.Adam(self.generator.parameters(),
                                      self.lr,
                                      betas=(self.beta1, self.beta2))
        self.d_optimizer = optim.Adam(self.discriminator.parameters(),
                                      self.lr,
                                      betas=(self.beta1, self.beta2))
        self.logger = Logger(self.logs_path)

        # Count trainable parameters for the summary printout below.
        self.gen_params = sum(p.numel() for p in self.generator.parameters()
                              if p.requires_grad)
        self.disc_params = sum(p.numel()
                               for p in self.discriminator.parameters()
                               if p.requires_grad)
        self.total_params = self.gen_params + self.disc_params

        print("Generator params: {}".format(self.gen_params))
        print("Discrimintor params: {}".format(self.disc_params))
        print("Total params: {}".format(self.total_params))

        if torch.cuda.is_available():
            self.generator.cuda()
            self.discriminator.cuda()

    def reset_grad(self):
        """Zero the gradient buffers of both networks."""
        self.discriminator.zero_grad()
        self.generator.zero_grad()

    # custom weights initialization called on netG and netD
    def weights_init(self, m):
        """DCGAN-style init: N(0, 0.02) convs, N(1, 0.02) batch-norm scales."""
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            m.weight.data.normal_(0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)

    def gradient_penalty(self, real_data, generated_data):
        """Return the WGAN-GP penalty, already scaled by ``self.gp_weight``."""
        batch_size = real_data.size()[0]

        # Calculate interpolation
        alpha = torch.rand(batch_size, 1)
        alpha = alpha.expand_as(real_data)
        alpha = alpha.cuda()
        interpolated = alpha * real_data.data + (1 -
                                                 alpha) * generated_data.data
        interpolated = Variable(interpolated, requires_grad=True)
        interpolated = interpolated.cuda()

        # Calculate probability of interpolated examples
        prob_interpolated = self.discriminator(interpolated)

        # Calculate gradients of probabilities with respect to examples
        gradients = torch_grad(outputs=prob_interpolated,
                               inputs=interpolated,
                               grad_outputs=torch.ones(
                                   prob_interpolated.size()).cuda(),
                               create_graph=True,
                               retain_graph=True)[0]

        # Gradients have shape (batch_size, num_channels, height, width),
        # so flatten to easily take norm per example in batch
        gradients = gradients.view(batch_size, -1)

        # Derivatives of the gradient close to 0 can cause problems because of
        # the square root, so manually calculate norm and add epsilon
        gradients_norm = torch.sqrt(torch.sum(gradients**2, dim=1) + 1e-12)

        # Return gradient penalty
        return self.gp_weight * ((gradients_norm - 1)**2).mean()

    def train(self):
        """Run the full training loop with periodic logging, sampling and
        IS/FID validation; checkpoints the best-scoring models."""
        total_step = len(self.data_loader)
        for epoch in range(self.num_epochs):
            for i, (data, _) in enumerate(self.data_loader):

                batch_size = data.size(0)
                # train Discriminator
                data = data.type(torch.FloatTensor)
                data = to_cuda(data)

                real_labels = to_cuda(torch.ones(batch_size,
                                                 self.d_layers[-1]))
                fake_labels = to_cuda(
                    torch.zeros(batch_size, self.d_layers[-1]))

                outputs_real = self.discriminator(data)
                z = to_cuda(torch.randn(batch_size, self.z_dim, 1, 1))
                fake_data = self.generator(z)
                outputs_fake = self.discriminator(fake_data)

                if self.loss == 'original':
                    d_loss_real = self.criterion(outputs_real.squeeze(),
                                                 real_labels.squeeze())
                    d_loss_fake = self.criterion(outputs_fake.squeeze(),
                                                 fake_labels.squeeze())
                    d_loss = d_loss_real + d_loss_fake

                elif self.loss == 'wgan-gp':
                    gradient_penalty = self.gradient_penalty(data, fake_data)
                    d_loss = -outputs_real.mean() + outputs_fake.mean(
                    ) + gradient_penalty

                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # train Generator
                # Fresh noise and a fresh forward pass (the previous graph
                # was consumed by d_loss.backward()).
                z = to_cuda(torch.randn(batch_size, self.z_dim, 1, 1))
                fake_data = self.generator(z)
                outputs_fake = self.discriminator(fake_data)

                if self.loss == 'original':
                    g_loss = self.criterion(outputs_fake.squeeze(),
                                            real_labels.squeeze())
                elif self.loss == 'wgan-gp':
                    g_loss = -outputs_fake.mean()

                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                if (i + 1) % self.log_step == 0:
                    print(
                        'Epoch [{0:d}/{1:d}], Step [{2:d}/{3:d}], d_real_loss: {4:.4f}, '
                        ' g_loss: {5:.4f}'.format(epoch + 1, self.num_epochs,
                                                  i + 1, total_step,
                                                  d_loss.item(),
                                                  g_loss.item()))

                    # log scalars in tensorboard
                    # NOTE(review): the 'd_real_loss' tag actually carries the
                    # combined discriminator loss, not just the real term.
                    info = {
                        'd_real_loss': d_loss.item(),
                        'g_loss': g_loss.item(),
                        'inception_score': self.max_score
                    }

                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value,
                                                   epoch * total_step + i + 1)

                if (i + 1) % self.sample_step == 0:
                    save_image(denorm(fake_data).cpu(),
                               self.sample_path +
                               "/epoch_{}_{}.png".format(i + 1, epoch + 1),
                               nrow=8)

                if (i + 1) % self.validation_step == 0:
                    # Fill a fixed-size buffer with copies of the latest fake
                    # batch for Inception Score evaluation.
                    fake_data_all = np.zeros(
                        (self.num_imgs_val, fake_data.size(1),
                         fake_data.size(2), fake_data.size(3)))
                    for j in range(self.num_imgs_val // batch_size):
                        fake_data_all[j * batch_size:(j + 1) *
                                      batch_size] = to_numpy(fake_data)
                    # NOTE(review): np.save appends '.npy', so this lands on
                    # disk as '..._val_data.pkl.npy'.
                    npy_path = os.path.join(
                        self.model_path,
                        '{}_{}_val_data.pkl'.format(epoch + 1, i + 1))
                    np.save(npy_path, fake_data_all)
                    score, _ = IS(fake_data_all,
                                  cuda=True,
                                  batch_size=batch_size)
                    if score > self.max_score:
                        print("Found new best IS score: {}".format(score))
                        self.max_score = score
                        data = "IS " + str(self.seed) + " " + str(
                            epoch + 1) + " " + str(i + 1) + " " + str(
                                self.max_score)
                        save_is(self.base_path, data)
                        g_path = os.path.join(self.model_path,
                                              'generator-best.pkl')
                        d_path = os.path.join(self.model_path,
                                              'discriminator-best.pkl')
                        torch.save(self.generator.state_dict(), g_path)
                        torch.save(self.discriminator.state_dict(), d_path)
                    # Generate one image at a time for FID computation against
                    # the real CIFAR-10 directory.
                    for j in range(self.FID_images):
                        z = to_cuda(torch.randn(1, self.z_dim, 1, 1))
                        fake_datum = self.generator(z)
                        save_image(
                            denorm(fake_datum.squeeze()).cpu(),
                            self.validation_path + "/" + str(j) + ".png")
                    fid_value = FID([self.validation_path, self.cifar10_path],
                                    64, True, 2048)
                    if fid_value < self.fid_score:
                        self.fid_score = fid_value
                        print("Found new best FID score: {}".format(
                            self.fid_score))
                        data = "FID " + str(self.seed) + " " + str(
                            epoch + 1) + " " + str(i + 1) + " " + str(
                                self.fid_score)
                        save_is(self.base_path, data)
                        g_path = os.path.join(self.model_path,
                                              'generator-best-fid.pkl')
                        d_path = os.path.join(self.model_path,
                                              'discriminator-best-fid.pkl')
                        torch.save(self.generator.state_dict(), g_path)
                        torch.save(self.discriminator.state_dict(), d_path)

            if (epoch + 1) % self.save_every == 0:
                g_path = os.path.join(self.model_path,
                                      'generator-{}.pkl'.format(epoch + 1))
                d_path = os.path.join(self.model_path,
                                      'discriminator-{}.pkl'.format(epoch + 1))
                torch.save(self.generator.state_dict(), g_path)
                torch.save(self.discriminator.state_dict(), d_path)

    def sample(self, n_samples):
        """Load the generator checkpoint and dump ``n_samples`` generated
        images plus their latent codes as .npy files under ./saved/."""
        self.n_samples = n_samples
        # Rebuild the generator with the architecture flags only; other
        # build options are assumed to match the checkpoint — TODO confirm.
        self.generator = Generator(g_layers=self.g_layers,
                                   inject_z=self.inject_z)
        self.generator.load_state_dict(torch.load(self.ckpt_gen_path))
        if torch.cuda.is_available():
            self.generator.cuda()
        self.generator.eval()

        z_samples = to_cuda(torch.randn(n_samples, self.z_dim, 1, 1))
        generated_samples = self.generator(z_samples)
        generated_samples = to_numpy(generated_samples)
        np.save('./saved/generated_samples.npy', generated_samples)
        z_samples = to_numpy(z_samples)
        np.save('./saved/z_samples.npy', z_samples)
    #PredRNN(64,32, 4,4, num_layers=1),
    PredRNN(32, 16, 4, 4, num_layers=1),
    PredRNN(16, 1, 4, 4, num_layers=1),
)
att = SelfAttention(4)
disc = Discriminator()

# Resume from an existing checkpoint if present; the file bundles the
# lstm / attention / discriminator state dicts under separate keys.
if os.path.isfile(weight_path_lstm):
    w = torch.load(weight_path_lstm)
    lstm.load_state_dict(w['lstm'])
    att.load_state_dict(w['att'])
    disc.load_state_dict(w['disc'])
    del w  # free the checkpoint dict before moving models to the GPU
lstm = lstm.cuda()
att = att.cuda()
disc = disc.cuda()

### OPTIM ###
loss_func = AdaptiveWingLoss()
# One optimizer over the generator-side modules (lstm + attention),
# a separate one for the discriminator.
optimizer = optim.Adam(
    list(lstm.parameters()) + list(att.parameters()),
    lr=lr,
    #weight_decay=1e-5
)
optimizerDisc = optim.Adam(
    list(disc.parameters()),
    lr=lr,
    #weight_decay=1e-5
)
# NOTE(review): the checkpoint is loaded a second time here — presumably to
# restore optimizer state below; confirm, and consider reusing the first load.
if os.path.isfile(weight_path_lstm):
    w = torch.load(weight_path_lstm)
Exemple #27
0
class Solver(object):
    """Trainer for a DCGAN-style generator/discriminator pair using
    least-squares (L2) GAN losses instead of binary cross entropy."""
    def __init__(self, config, data_loader):
        """Store hyperparameters from ``config`` and build the networks."""
        self.generator = None
        self.discriminator = None
        self.g_optimizer = None
        self.d_optimizer = None
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.z_dim = config.z_dim
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.image_size = config.image_size
        self.data_loader = data_loader
        self.num_epochs = config.num_epochs
        self.batch_size = config.batch_size
        self.sample_size = config.sample_size
        self.lr = config.lr
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.sample_path = config.sample_path
        self.model_path = config.model_path
        self.build_model()

    def build_model(self):
        """Build generator and discriminator."""
        self.generator = Generator(z_dim=self.z_dim,
                                   image_size=self.image_size,
                                   conv_dim=self.g_conv_dim)
        self.discriminator = Discriminator(image_size=self.image_size,
                                           conv_dim=self.d_conv_dim)
        self.g_optimizer = optim.Adam(self.generator.parameters(), self.lr,
                                      [self.beta1, self.beta2])
        self.d_optimizer = optim.Adam(self.discriminator.parameters(), self.lr,
                                      [self.beta1, self.beta2])

        if torch.cuda.is_available():
            self.generator.cuda()
            self.discriminator.cuda()

    def to_variable(self, x):
        """Convert tensor to variable."""
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x)

    def to_data(self, x):
        """Convert variable to tensor."""
        if torch.cuda.is_available():
            x = x.cpu()
        return x.data

    def reset_grad(self):
        """Zero the gradient buffers."""
        self.discriminator.zero_grad()
        self.generator.zero_grad()

    def denorm(self, x):
        """Convert range (-1, 1) to (0, 1)"""
        out = (x + 1) / 2
        return out.clamp(0, 1)

    def train(self):
        """Train generator and discriminator."""
        # Fixed noise lets the periodic samples show progress on the same
        # latent codes across the whole run.
        fixed_noise = self.to_variable(torch.randn(self.batch_size,
                                                   self.z_dim))
        total_step = len(self.data_loader)
        for epoch in range(self.num_epochs):
            for i, images in enumerate(self.data_loader):

                # ===================== Train D =====================#
                images = self.to_variable(images)
                batch_size = images.size(0)
                noise = self.to_variable(torch.randn(batch_size, self.z_dim))

                # Train D to recognize real images as real.
                outputs = self.discriminator(images)
                real_loss = torch.mean(
                    (outputs - 1)**2
                )  # L2 loss instead of Binary cross entropy loss (this is optional for stable training)

                # Train D to recognize fake images as fake.
                fake_images = self.generator(noise)
                outputs = self.discriminator(fake_images)
                fake_loss = torch.mean(outputs**2)

                # Backprop + optimize
                d_loss = real_loss + fake_loss
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # ===================== Train G =====================#
                noise = self.to_variable(torch.randn(batch_size, self.z_dim))

                # Train G so that D recognizes G(z) as real.
                fake_images = self.generator(noise)
                outputs = self.discriminator(fake_images)
                g_loss = torch.mean((outputs - 1)**2)

                # Backprop + optimize
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # print the log info
                # NOTE: `.data[0]` indexing is legacy PyTorch (<0.4) — the
                # modern equivalent is `.item()`.
                if (i + 1) % self.log_step == 0:
                    print(
                        'Epoch [%d/%d], Step[%d/%d], d_real_loss: %.4f, '
                        'd_fake_loss: %.4f, g_loss: %.4f' %
                        (epoch + 1, self.num_epochs, i + 1, total_step,
                         real_loss.data[0], fake_loss.data[0], g_loss.data[0]))

                # save the sampled images
                if (i + 1) % self.sample_step == 0:
                    fake_images = self.generator(fixed_noise)
                    torchvision.utils.save_image(
                        self.denorm(fake_images.data),
                        os.path.join(
                            self.sample_path,
                            'fake_samples-%d-%d.png' % (epoch + 1, i + 1)))

            # save the model parameters for each epoch
            g_path = os.path.join(self.model_path,
                                  'generator-%d.pkl' % (epoch + 1))
            d_path = os.path.join(self.model_path,
                                  'discriminator-%d.pkl' % (epoch + 1))
            torch.save(self.generator.state_dict(), g_path)
            torch.save(self.discriminator.state_dict(), d_path)

    def sample(self):
        """Load the final-epoch checkpoints and save one grid of samples."""
        # Load trained parameters
        g_path = os.path.join(self.model_path,
                              'generator-%d.pkl' % (self.num_epochs))
        d_path = os.path.join(self.model_path,
                              'discriminator-%d.pkl' % (self.num_epochs))
        self.generator.load_state_dict(torch.load(g_path))
        self.discriminator.load_state_dict(torch.load(d_path))
        self.generator.eval()
        self.discriminator.eval()

        # Sample the images
        noise = self.to_variable(torch.randn(self.sample_size, self.z_dim))
        fake_images = self.generator(noise)
        sample_path = os.path.join(self.sample_path, 'fake_samples-final.png')
        torchvision.utils.save_image(self.denorm(fake_images.data),
                                     sample_path,
                                     nrow=12)

        print("Saved sampled images to '%s'" % sample_path)
Exemple #28
0
def train(epochs, s_path, t_path, batchsize, interval, modeldir="./model"):
    """Train a UGATIT-style bidirectional image translation model.

    Trains two generators (source->target and target->source) against one
    discriminator per domain, with cycle and identity terms supplied by
    ``generator_loss``.

    Args:
        epochs (int): number of passes over the dataset.
        s_path: source-domain data location handed to ``UGATITDataset``.
        t_path: target-domain data location handed to ``UGATITDataset``.
        batchsize (int): mini-batch size.
        interval (int): save generator checkpoints every ``interval`` iterations.
        modeldir (str): directory checkpoints are written to
            (default ``"./model"``, matching the previous hard-coded path).
    """
    dataset = UGATITDataset(s_path, t_path)
    print(dataset)
    collator = ImageCollate()

    def _build(module_cls):
        # Instantiate a network on the GPU in train mode with its Adam optimizer.
        net = module_cls()
        net.cuda()
        net.train()
        opt = torch.optim.Adam(net.parameters(), lr=0.0001, betas=(0.5, 0.999))
        return net, opt

    generator_st, optim_gen_st = _build(Generator)          # source -> target
    generator_ts, optim_gen_ts = _build(Generator)          # target -> source
    discriminator_gt, optim_dis_gt = _build(Discriminator)  # judges target domain
    discriminator_gs, optim_dis_gs = _build(Discriminator)  # judges source domain

    # Clamps rho parameters to [0, 1] after every generator update.
    clipper = RhoClipper(0, 1)

    iteration = 0

    for epoch in range(epochs):
        dataloader = DataLoader(dataset,
                                batch_size=batchsize,
                                shuffle=True,
                                collate_fn=collator,
                                drop_last=True)
        progress_bar = tqdm(dataloader)

        for i, data in enumerate(progress_bar):
            iteration += 1
            s, t = data

            # ---- Discriminator update ----
            # Detach the fakes: the discriminator loss must not backprop into
            # (nor retain the autograd graph of) the generators. Previously
            # the graph was kept alive and generator grads were computed here
            # only to be discarded by zero_grad() below.
            fake_t, _, _ = generator_st(s)
            fake_s, _, _ = generator_ts(t)
            fake_t = fake_t.detach()
            fake_s = fake_s.detach()

            real_gs, real_gs_logit, _ = discriminator_gs(s)
            real_gt, real_gt_logit, _ = discriminator_gt(t)
            fake_gs, fake_gs_logit, _ = discriminator_gs(fake_s)
            fake_gt, fake_gt_logit, _ = discriminator_gt(fake_t)

            loss = discriminator_loss(fake_gt, real_gt, fake_gt_logit,
                                      real_gt_logit)
            loss += discriminator_loss(fake_gs, real_gs, fake_gs_logit,
                                       real_gs_logit)

            optim_dis_gs.zero_grad()
            optim_dis_gt.zero_grad()
            loss.backward()
            optim_dis_gs.step()
            optim_dis_gt.step()

            # ---- Generator update ----
            fake_t, fake_gen_t_logit, _ = generator_st(s)
            fake_s, fake_gen_s_logit, _ = generator_ts(t)

            # Cycle reconstructions: s -> t -> s and t -> s -> t.
            fake_sts, _, _ = generator_ts(fake_t)
            fake_tst, _, _ = generator_st(fake_s)

            # Identity mappings: each generator fed its own output domain.
            fake_t_id, fake_t_id_logit, _ = generator_st(t)
            fake_s_id, fake_s_id_logit, _ = generator_ts(s)

            fake_gs, fake_gs_logit, _ = discriminator_gs(fake_s)
            fake_gt, fake_gt_logit, _ = discriminator_gt(fake_t)

            loss = generator_loss(fake_gs, s, fake_gs_logit, fake_gen_s_logit,
                                  fake_sts, fake_s_id, fake_s_id_logit)
            loss += generator_loss(fake_gt, t, fake_gt_logit, fake_gen_t_logit,
                                   fake_tst, fake_t_id, fake_t_id_logit)

            optim_gen_st.zero_grad()
            optim_gen_ts.zero_grad()
            loss.backward()
            optim_gen_st.step()
            optim_gen_ts.step()

            # Re-clip rho after the optimizer step moved it.
            generator_st.apply(clipper)
            generator_ts.apply(clipper)

            if iteration % interval == 0:
                torch.save(generator_st.state_dict(),
                           f"{modeldir}/model_st_{iteration}.pt")
                torch.save(generator_ts.state_dict(),
                           f"{modeldir}/model_ts_{iteration}.pt")

            print(f"iteration: {iteration} Loss: {loss.data}")
Exemple #29
0
class Solver(object):
    """Solver for training and testing StarGAN.

    NOTE(review): written against a legacy pre-0.4 PyTorch API
    (`Variable`, `volatile=True`, `tensor.data[0]`, `size_average=`) --
    confirm the installed torch version before modernizing any of it.
    """

    def __init__(self, celeba_loader, rafd_loader, config):
        """Initialize configurations.

        Args:
            celeba_loader: loader yielding (image, label) batches for CelebA.
            rafd_loader: loader yielding (image, label) batches for RaFD.
            config: attribute bag of hyperparameters; every field read
                below must be present.
        """

        # Data loader.
        self.celeba_loader = celeba_loader
        self.rafd_loader = rafd_loader

        # Model configurations.
        self.c_dim = config.c_dim        # label dimension of the primary dataset
        self.c2_dim = config.c2_dim      # label dimension of the second dataset ('Both' mode)
        self.image_size = config.image_size
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.g_repeat_num = config.g_repeat_num
        self.d_repeat_num = config.d_repeat_num
        self.lambda_cls = config.lambda_cls  # weight of the classification loss
        self.lambda_rec = config.lambda_rec  # weight of the cycle-reconstruction loss
        self.lambda_gp = config.lambda_gp    # weight of the gradient penalty

        # Training configurations.
        self.dataset = config.dataset
        self.batch_size = config.batch_size
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.n_critic = config.n_critic      # discriminator updates per generator update
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.resume_iters = config.resume_iters
        self.selected_attrs = config.selected_attrs

        # Test configurations.
        self.test_iters = config.test_iters

        # Miscellaneous.
        self.use_tensorboard = config.use_tensorboard
        # Tensor type used for the gradient-penalty interpolation weights.
        self.dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

        # Directories.
        self.log_dir = config.log_dir
        self.sample_dir = config.sample_dir
        self.model_save_dir = config.model_save_dir
        self.result_dir = config.result_dir

        # Step size.
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.lr_update_step = config.lr_update_step

        # Build the model and tensorboard.
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

    def build_model(self):
        """Create a generator and a discriminator.

        In 'Both' mode the generator's label vector is the two datasets'
        label dims concatenated plus a 2-dim dataset mask.
        """
        if self.dataset in ['CelebA', 'RaFD']:
            self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num)
            self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num) 
        elif self.dataset in ['Both']:
            self.G = Generator(self.g_conv_dim, self.c_dim+self.c2_dim+2, self.g_repeat_num)   # 2 for mask vector.
            self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim+self.c2_dim, self.d_repeat_num)

        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
        self.print_network(self.G, 'G')
        self.print_network(self.D, 'D')

        if torch.cuda.is_available():
            self.G.cuda()
            self.D.cuda()

    def print_network(self, model, name):
        """Print out the network information (structure and parameter count)."""
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def restore_model(self, resume_iters):
        """Restore the trained generator and discriminator.

        Args:
            resume_iters (int): checkpoint step to load, matching the
                '{step}-G.ckpt' / '{step}-D.ckpt' naming used when saving.
        """
        print('Loading the trained models from step {}...'.format(resume_iters))
        G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(resume_iters))
        D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(resume_iters))
        self.G.load_state_dict(torch.load(G_path))
        self.D.load_state_dict(torch.load(D_path))

    def build_tensorboard(self):
        """Build a tensorboard logger (project-local `logger.Logger`)."""
        from logger import Logger
        self.logger = Logger(self.log_dir)

    def update_lr(self, g_lr, d_lr):
        """Decay learning rates of the generator and discriminator."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    def reset_grad(self):
        """Reset the gradient buffers of both optimizers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def tensor2var(self, x, volatile=False):
        """Convert torch tensor to variable, moving it to GPU when available.

        NOTE(review): `volatile=True` is the legacy pre-0.4 no-grad flag --
        verify against the torch version in use.
        """
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x, volatile=volatile)

    def denorm(self, x):
        """Convert the range from [-1, 1] to [0, 1]."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def gradient_penalty(self, y, x, dtype):
        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2.

        Args:
            y: discriminator output at the interpolated points.
            x: interpolated input (must have requires_grad=True).
            dtype: tensor type used for the all-ones grad_outputs weight.
        """
        weight = torch.ones(y.size()).type(dtype)
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]

        # Per-sample L2 norm over all non-batch dimensions.
        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
        return torch.mean((dydx_l2norm-1)**2)

    def label2onehot(self, labels, dim):
        """Convert label indices to one-hot vectors of width `dim`."""
        batch_size = labels.size(0)
        out = torch.zeros(batch_size, dim)
        out[np.arange(batch_size), labels.long()] = 1
        return out

    def create_labels(self, c_org, c_dim=5, dataset='CelebA', selected_attrs=None):
        """Generate target domain labels for debugging and testing.

        For CelebA, flips one attribute at a time (hair colors are treated
        as mutually exclusive); for RaFD, produces each class one-hot in turn.
        """
        # Get hair color indices.
        # hair_color_indices is only defined/used on the CelebA path.
        if dataset == 'CelebA':
            hair_color_indices = []
            for i, attr_name in enumerate(selected_attrs):
                if attr_name in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']:
                    hair_color_indices.append(i)

        c_trg_list = []
        for i in range(c_dim):
            if dataset == 'CelebA':
                c_trg = c_org.clone()
                if i in hair_color_indices:  # Set one hair color to 1 and the rest to 0.
                    c_trg[:, i] = 1
                    for j in hair_color_indices:
                        if j != i:
                            c_trg[:, j] = 0
                else:
                    c_trg[:, i] = (c_trg[:, i] == 0)  # Reverse attribute value.
            elif dataset == 'RaFD':
                c_trg = self.label2onehot(torch.ones(c_org.size(0))*i, c_dim)

            c_trg_list.append(self.tensor2var(c_trg, volatile=True))
        return c_trg_list

    def classification_loss(self, logit, target, dataset='CelebA'):
        """Compute binary or softmax cross entropy loss.

        For CelebA (multi-label attributes): BCE summed over attributes,
        averaged over the batch. For RaFD (single label): softmax CE.
        """
        if dataset == 'CelebA':
            return F.binary_cross_entropy_with_logits(logit, target, size_average=False) / logit.size(0)
        elif dataset == 'RaFD':
            return F.cross_entropy(logit, target)

    def train(self):
        """Train StarGAN within a single dataset."""
        # Set data loader.
        if self.dataset == 'CelebA':
            data_loader = self.celeba_loader
        elif self.dataset == 'RaFD':
            data_loader = self.rafd_loader

        # Fetch fixed inputs for debugging.
        data_iter = iter(data_loader)
        x_fixed, c_org = next(data_iter)
        x_fixed = self.tensor2var(x_fixed, volatile=True)
        c_fixed_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs)

        # Learning rate cache for decaying.
        g_lr = self.g_lr
        d_lr = self.d_lr

        # Start training from scratch or resume training.
        start_iters = 0
        if self.resume_iters:
            start_iters = self.resume_iters
            self.restore_model(self.resume_iters)

        # Start training.
        print('Start training...')
        start_time = time.time()
        for i in range(start_iters, self.num_iters):

            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #

            # Fetch real images and labels.
            # NOTE(review): bare except -- intended to restart the iterator
            # at epoch end (StopIteration) but also hides other errors.
            try:
                x_real, label_org = next(data_iter)
            except:
                data_iter = iter(data_loader)
                x_real, label_org = next(data_iter)

            # Generate target domain labels randomly by permuting the batch.
            rand_idx = torch.randperm(label_org.size(0))
            label_trg = label_org[rand_idx]

            if self.dataset == 'CelebA':
                c_org = label_org.clone()
                c_trg = label_trg.clone()
            elif self.dataset == 'RaFD':
                c_org = self.label2onehot(label_org, self.c_dim)
                c_trg = self.label2onehot(label_trg, self.c_dim)

            x_real = self.tensor2var(x_real)           # Input images.
            c_org = self.tensor2var(c_org)             # Original domain labels.
            c_trg = self.tensor2var(c_trg)             # Target domain labels.
            label_org = self.tensor2var(label_org)     # Labels for computing classification loss.
            label_trg = self.tensor2var(label_trg)     # Labels for computing classification loss.

            # =================================================================================== #
            #                             2. Train the discriminator                              #
            # =================================================================================== #

            # Compute loss with real images (negated critic score, Wasserstein-style).
            out_src, out_cls = self.D(x_real)
            d_loss_real = - torch.mean(out_src)
            d_loss_cls = self.classification_loss(out_cls, label_org, self.dataset)

            # Compute loss with fake images (detached: no grad into G here).
            x_fake = self.G(x_real, c_trg)
            out_src, out_cls = self.D(x_fake.detach())
            d_loss_fake = torch.mean(out_src)

            # Compute loss for gradient penalty on random real/fake interpolations.
            alpha = torch.rand(x_real.size(0), 1, 1, 1).type(self.dtype)
            x_hat = Variable(alpha * x_real.data + (1 - alpha) * x_fake.data, requires_grad=True)
            out_src, _ = self.D(x_hat)
            d_loss_gp = self.gradient_penalty(out_src, x_hat, self.dtype)

            # Backward and optimize.
            d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls + self.lambda_gp * d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Logging.
            loss = {}
            loss['D/loss_real'] = d_loss_real.data[0]
            loss['D/loss_fake'] = d_loss_fake.data[0]
            loss['D/loss_cls'] = d_loss_cls.data[0]
            loss['D/loss_gp'] = d_loss_gp.data[0]

            # =================================================================================== #
            #                               3. Train the generator                                #
            # =================================================================================== #

            # G is updated once every n_critic discriminator updates.
            if (i+1) % self.n_critic == 0:
                # Original-to-target domain.
                x_fake = self.G(x_real, c_trg)
                out_src, out_cls = self.D(x_fake)
                g_loss_fake = - torch.mean(out_src)
                g_loss_cls = self.classification_loss(out_cls, label_trg, self.dataset)

                # Target-to-original domain (cycle reconstruction, L1).
                x_reconst = self.G(x_fake, c_org)
                g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))

                # Backward and optimize.
                g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Logging.
                loss['G/loss_fake'] = g_loss_fake.data[0]
                loss['G/loss_rec'] = g_loss_rec.data[0]
                loss['G/loss_cls'] = g_loss_cls.data[0]

            # =================================================================================== #
            #                                 4. Miscellaneous                                    #
            # =================================================================================== #

            # Print out training information.
            if (i+1) % self.log_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, self.num_iters)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, i+1)

            # Translate fixed images for debugging.
            if (i+1) % self.sample_step == 0:
                x_fake_list = [x_fixed]
                for c_fixed in c_fixed_list:
                    x_fake_list.append(self.G(x_fixed, c_fixed))
                x_concat = torch.cat(x_fake_list, dim=3)
                sample_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(i+1))
                save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
                print('Saved real and fake images into {}...'.format(sample_path))

            # Save model checkpoints.
            if (i+1) % self.model_save_step == 0:
                G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i+1))
                D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i+1))
                torch.save(self.G.state_dict(), G_path)
                torch.save(self.D.state_dict(), D_path)
                print('Saved model checkpoints into {}...'.format(self.model_save_dir))

            # Decay learning rates linearly over the last num_iters_decay steps.
            if (i+1) % self.lr_update_step == 0 and (i+1) > (self.num_iters - self.num_iters_decay):
                g_lr -= (self.g_lr / float(self.num_iters_decay))
                d_lr -= (self.d_lr / float(self.num_iters_decay))
                self.update_lr(g_lr, d_lr)
                print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))

    def train_multi(self):
        """Train StarGAN with multiple datasets (alternating CelebA / RaFD)."""
        # Data iterators.
        celeba_iter = iter(self.celeba_loader)
        rafd_iter = iter(self.rafd_loader)

        # Fetch fixed inputs for debugging.
        x_fixed, c_org = next(celeba_iter)
        x_fixed = self.tensor2var(x_fixed, volatile=True)
        c_celeba_list = self.create_labels(c_org, self.c_dim, 'CelebA', self.selected_attrs)
        c_rafd_list = self.create_labels(c_org, self.c2_dim, 'RaFD')
        zero_celeba = self.tensor2var(torch.zeros(x_fixed.size(0), self.c_dim))            # Zero vector for CelebA.
        zero_rafd = self.tensor2var(torch.zeros(x_fixed.size(0), self.c2_dim))             # Zero vector for RaFD.
        mask_celeba = self.tensor2var(self.label2onehot(torch.zeros(x_fixed.size(0)), 2))  # Mask vector: [1, 0].
        mask_rafd = self.tensor2var(self.label2onehot(torch.ones(x_fixed.size(0)), 2))     # Mask vector: [0, 1].

        # Learning rate cache for decaying.
        g_lr = self.g_lr
        d_lr = self.d_lr

        # Start training from scratch or resume training.
        start_iters = 0
        if self.resume_iters:
            start_iters = self.resume_iters
            self.restore_model(self.resume_iters)

        # Start training.
        print('Start training...')
        start_time = time.time()
        for i in range(start_iters, self.num_iters):
            # Each outer step trains once on each dataset in turn.
            for dataset in ['CelebA', 'RaFD']:

                # =================================================================================== #
                #                             1. Preprocess input data                                #
                # =================================================================================== #

                # Fetch real images and labels.
                data_iter = celeba_iter if dataset == 'CelebA' else rafd_iter

                # NOTE(review): bare except -- restarts the exhausted iterator
                # (StopIteration) but also hides other errors.
                try:
                    x_real, label_org = next(data_iter)
                except:
                    if dataset == 'CelebA':
                        celeba_iter = iter(self.celeba_loader)
                        x_real, label_org = next(celeba_iter)
                    elif dataset == 'RaFD':
                        rafd_iter = iter(self.rafd_loader)
                        x_real, label_org = next(rafd_iter)

                # Generate target domain labels randomly by permuting the batch.
                rand_idx = torch.randperm(label_org.size(0))
                label_trg = label_org[rand_idx]

                # Build the full conditioning vector: [celeba | rafd | mask].
                if dataset == 'CelebA':
                    c_org = label_org.clone()
                    c_trg = label_trg.clone()
                    zero = torch.zeros(x_real.size(0), self.c2_dim)
                    mask = self.label2onehot(torch.zeros(x_real.size(0)), 2)
                    c_org = torch.cat([c_org, zero, mask], dim=1)
                    c_trg = torch.cat([c_trg, zero, mask], dim=1)
                elif dataset == 'RaFD':
                    c_org = self.label2onehot(label_org, self.c2_dim)
                    c_trg = self.label2onehot(label_trg, self.c2_dim)
                    zero = torch.zeros(x_real.size(0), self.c_dim)
                    mask = self.label2onehot(torch.ones(x_real.size(0)), 2)
                    c_org = torch.cat([zero, c_org, mask], dim=1)
                    c_trg = torch.cat([zero, c_trg, mask], dim=1)

                x_real = self.tensor2var(x_real)             # Input images.
                c_org = self.tensor2var(c_org)               # Original domain labels.
                c_trg = self.tensor2var(c_trg)               # Target domain labels.
                label_org = self.tensor2var(label_org)       # Labels for computing classification loss.
                label_trg = self.tensor2var(label_trg)       # Labels for computing classification loss.

                # =================================================================================== #
                #                             2. Train the discriminator                              #
                # =================================================================================== #

                # Compute loss with real images; only this dataset's slice of
                # the classification head is used.
                out_src, out_cls = self.D(x_real)
                out_cls = out_cls[:, :self.c_dim] if dataset == 'CelebA' else out_cls[:, self.c_dim:]
                d_loss_real = - torch.mean(out_src)
                d_loss_cls = self.classification_loss(out_cls, label_org, dataset)

                # Compute loss with fake images (detached: no grad into G here).
                x_fake = self.G(x_real, c_trg)
                out_src, _ = self.D(x_fake.detach())
                d_loss_fake = torch.mean(out_src)

                # Compute loss for gradient penalty on random interpolations.
                alpha = torch.rand(x_real.size(0), 1, 1, 1).type(self.dtype)
                x_hat = Variable(alpha * x_real.data + (1 - alpha) * x_fake.data, requires_grad=True)
                out_src, _ = self.D(x_hat)
                d_loss_gp = self.gradient_penalty(out_src, x_hat, self.dtype)

                # Backward and optimize.
                d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls + self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # Logging.
                loss = {}
                loss['D/loss_real'] = d_loss_real.data[0]
                loss['D/loss_fake'] = d_loss_fake.data[0]
                loss['D/loss_cls'] = d_loss_cls.data[0]
                loss['D/loss_gp'] = d_loss_gp.data[0]

                # =================================================================================== #
                #                               3. Train the generator                                #
                # =================================================================================== #

                # G is updated once every n_critic discriminator updates.
                if (i+1) % self.n_critic == 0:
                    # Original-to-target domain.
                    x_fake = self.G(x_real, c_trg)
                    out_src, out_cls = self.D(x_fake)
                    out_cls = out_cls[:, :self.c_dim] if dataset == 'CelebA' else out_cls[:, self.c_dim:]
                    g_loss_fake = - torch.mean(out_src)
                    g_loss_cls = self.classification_loss(out_cls, label_trg, dataset)

                    # Target-to-original domain (cycle reconstruction, L1).
                    x_reconst = self.G(x_fake, c_org)
                    g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))

                    # Backward and optimize.
                    g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
                    self.reset_grad()
                    g_loss.backward()
                    self.g_optimizer.step()

                    # Logging.
                    loss['G/loss_fake'] = g_loss_fake.data[0]
                    loss['G/loss_rec'] = g_loss_rec.data[0]
                    loss['G/loss_cls'] = g_loss_cls.data[0]

                # =================================================================================== #
                #                                 4. Miscellaneous                                    #
                # =================================================================================== #

                # Print out training info.
                if (i+1) % self.log_step == 0:
                    et = time.time() - start_time
                    et = str(datetime.timedelta(seconds=et))[:-7]
                    log = "Elapsed [{}], Iteration [{}/{}], Dataset [{}]".format(et, i+1, self.num_iters, dataset)
                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)

                    if self.use_tensorboard:
                        for tag, value in loss.items():
                            self.logger.scalar_summary(tag, value, i+1)

            # Translate fixed images for debugging.
            if (i+1) % self.sample_step == 0:
                x_fake_list = [x_fixed]
                for c_fixed in c_celeba_list:
                    c_trg = torch.cat([c_fixed, zero_rafd, mask_celeba], dim=1)
                    x_fake_list.append(self.G(x_fixed, c_trg))
                for c_fixed in c_rafd_list:
                    c_trg = torch.cat([zero_celeba, c_fixed, mask_rafd], dim=1)
                    x_fake_list.append(self.G(x_fixed, c_trg))
                x_concat = torch.cat(x_fake_list, dim=3)
                sample_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(i+1))
                save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
                print('Saved real and fake images into {}...'.format(sample_path))

            # Save model checkpoints.
            if (i+1) % self.model_save_step == 0:
                G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i+1))
                D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i+1))
                torch.save(self.G.state_dict(), G_path)
                torch.save(self.D.state_dict(), D_path)
                print('Saved model checkpoints into {}...'.format(self.model_save_dir))

            # Decay learning rates linearly over the last num_iters_decay steps.
            if (i+1) % self.lr_update_step == 0 and (i+1) > (self.num_iters - self.num_iters_decay):
                g_lr -= (self.g_lr / float(self.num_iters_decay))
                d_lr -= (self.d_lr / float(self.num_iters_decay))
                self.update_lr(g_lr, d_lr)
                print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))

    def test(self):
        """Translate images using StarGAN trained on a single dataset."""
        # Load the trained generator.
        self.restore_model(self.test_iters)

        # Set data loader.
        if self.dataset == 'CelebA':
            data_loader = self.celeba_loader
        elif self.dataset == 'RaFD':
            data_loader = self.rafd_loader

        for i, (x_real, c_org) in enumerate(data_loader):

            # Prepare input images and target domain labels.
            x_real = self.tensor2var(x_real, volatile=True)
            c_trg_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs)

            # Translate images once per target label.
            x_fake_list = [x_real]
            for c_trg in c_trg_list:
                x_fake_list.append(self.G(x_real, c_trg))

            # Save the translated images side by side (concatenated on width).
            x_concat = torch.cat(x_fake_list, dim=3)
            result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1))
            save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
            print('Saved real and fake images into {}...'.format(result_path))

    def test_multi(self):
        """Translate images using StarGAN trained on multiple datasets."""
        # Load the trained generator.
        self.restore_model(self.test_iters)

        for i, (x_real, c_org) in enumerate(self.celeba_loader):

            # Prepare input images and target domain labels.
            x_real = self.tensor2var(x_real, volatile=True)
            c_celeba_list = self.create_labels(c_org, self.c_dim, 'CelebA', self.selected_attrs)
            c_rafd_list = self.create_labels(c_org, self.c2_dim, 'RaFD')
            zero_celeba = self.tensor2var(torch.zeros(x_real.size(0), self.c_dim))            # Zero vector for CelebA.
            zero_rafd = self.tensor2var(torch.zeros(x_real.size(0), self.c2_dim))             # Zero vector for RaFD.
            mask_celeba = self.tensor2var(self.label2onehot(torch.zeros(x_real.size(0)), 2))  # Mask vector: [1, 0].
            mask_rafd = self.tensor2var(self.label2onehot(torch.ones(x_real.size(0)), 2))     # Mask vector: [0, 1].

            # Translate images for every CelebA and every RaFD target label.
            x_fake_list = [x_real]
            for c_celeba in c_celeba_list:
                c_trg = torch.cat([c_celeba, zero_rafd, mask_celeba], dim=1)
                x_fake_list.append(self.G(x_real, c_trg))
            for c_rafd in c_rafd_list:
                c_trg = torch.cat([zero_celeba, c_rafd, mask_rafd], dim=1)
                x_fake_list.append(self.G(x_real, c_trg))

            # Save the translated images side by side (concatenated on width).
            x_concat = torch.cat(x_fake_list, dim=3)
            result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1))
            save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
            print('Saved real and fake images into {}...'.format(result_path))
Exemple #30
0
def train(args):
    """Train a cross-lingual/cross-domain sentiment classifier.

    Jointly optimizes (per step):
      * a multi-lingual, multi-domain language-modeling loss on unlabeled text,
      * a language-adversarial loss through a gradient-reversed discriminator,
      * a supervised sentiment-classification loss on the source language.

    Periodically evaluates on target-language val/test sets plus a bilingual
    dictionary induction (BDI) nearest-neighbor metric, and checkpoints the
    best model per target language by validation accuracy.
    """
    # set random seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available() and args.cuda:
        torch.cuda.manual_seed(args.seed)

    print('Configuration:')
    print('\n'.join('\t{:15} {}'.format(k + ':', str(v)) for k, v in sorted(dict(vars(args)).items())))
    print()

    model_path = os.path.join(args.export, 'model.pt')
    config_path = os.path.join(args.export, 'config.json')
    export_config(args, config_path)
    check_path(model_path)

    ###############################################################################
    # Load data
    ###############################################################################

    n_trg = len(args.trg)
    # Map language/domain names to integer ids used throughout the model.
    lang2id = {lang: i for i, lang in enumerate(args.lang)}
    dom2id = {dom: i for i, dom in enumerate(args.dom)}
    src_id, dom_id = lang2id[args.src], dom2id[args.sup_dom]
    trg_ids = [lang2id[t] for t in args.trg]

    # Pre-serialized datasets; presumably dicts keyed by language then domain
    # — TODO confirm against the preprocessing script.
    unlabeled_set = torch.load(args.unlabeled)
    train_set = torch.load(args.train)
    val_set = torch.load(args.val)
    test_set = torch.load(args.test)
    vocabs = [train_set[lang]['vocab'] for lang in args.lang]
    # unlabeled: nested (n_lang x n_dom) list of batchified LM corpora.
    unlabeled = to_device([[batchify(unlabeled_set[lang][dom], args.batch_size) for dom in args.dom] for lang in args.lang], args.cuda)
    train_x, train_y, train_l = to_device(train_set[args.src][args.sup_dom], args.cuda)
    val_ds = [to_device(val_set[t][args.sup_dom], args.cuda) for t in args.trg]
    test_ds = [to_device(test_set[t][args.sup_dom], args.cuda) for t in args.trg]

    if args.sample_unlabeled > 0:
        print('Downsampling unlabeled set...')
        print()
        # Keep only enough LM batches to cover ~sample_unlabeled tokens/rows.
        unlabeled = [[x[:(args.sample_unlabeled // args.batch_size)] for x in t] for t in unlabeled]
    if args.sample_train > 0:
        print('Downsampling training set...')
        print()
        train_x, train_y, train_l = sample([train_x, train_y, train_l], args.sample_train, True)

    # senti_train feeds the classification loss; train_ds re-wraps the same
    # data at test batch size for evaluation.
    senti_train = DataLoader(SentiDataset(train_x, train_y, train_l), batch_size=args.clf_batch_size)
    train_iter = iter(senti_train)
    train_ds = DataLoader(SentiDataset(train_x, train_y, train_l), batch_size=args.test_batch_size)
    val_ds = [DataLoader(SentiDataset(*ds), batch_size=args.test_batch_size) for ds in val_ds]
    test_ds = [DataLoader(SentiDataset(*ds), batch_size=args.test_batch_size) for ds in test_ds]

    # MUSE bilingual lexicons (src -> each target) for the BDI evaluation.
    lexicons = []
    for tid, tlang in zip(trg_ids, args.trg):
        sv, tv = vocabs[src_id], vocabs[tid]
        lex, lexsz = load_lexicon('data/muse/{}-{}.0-5000.txt'.format(args.src, tlang), sv, tv)
        lexicons.append((lex, lexsz, tid))

    ###############################################################################
    # Build the model
    ###############################################################################
    if args.resume:
        model, dis, lm_opt, dis_opt = model_load(args.resume)

    else:
        model = XLXDClassifier(n_classes=2, clf_p=args.dropoutc, n_langs=len(args.lang), n_doms=len(args.dom),
                               vocab_sizes=list(map(len, vocabs)), emb_size=args.emb_dim, hidden_size=args.hid_dim,
                               num_layers=args.nlayers, num_share=args.nshare, tie_weights=args.tie_softmax,
                               output_p=args.dropouto, hidden_p=args.dropouth, input_p=args.dropouti, embed_p=args.dropoute,
                               weight_p=args.dropoutw, alpha=2, beta=1)
        dis = Discriminator(args.emb_dim, args.dis_hid_dim, len(args.lang), args.dis_nlayers, args.dropoutd)

        if args.mwe:
            # Initialize embeddings from pretrained multilingual word vectors
            # and freeze them (LM losses are skipped entirely in mwe mode).
            mwe = []
            for lid, (v, lang) in enumerate(zip(vocabs, args.lang)):
                x, count = load_vectors_with_vocab(args.mwe_path.format(lang), v, -1)
                model.encoders[lid].weight.data.copy_(torch.from_numpy(x))
                freeze_net(model.encoders[lid])

        params = [{'params': model.models.parameters(),  'lr': args.lr},
                  {'params': model.clfs.parameters(), 'lr': args.lr}]
        if args.optimizer == 'sgd':
            lm_opt = torch.optim.SGD(params, lr=args.lr, weight_decay=args.wdecay)
            dis_opt = torch.optim.SGD(dis.parameters(), lr=args.dis_lr, weight_decay=args.wdecay)
        if args.optimizer == 'adam':
            lm_opt = torch.optim.Adam(params, lr=args.lr, weight_decay=args.wdecay, betas=(args.beta1, 0.999))
            dis_opt = torch.optim.Adam(dis.parameters(), lr=args.dis_lr, weight_decay=args.wdecay, betas=(args.beta1, 0.999))

    crit = nn.CrossEntropyLoss()

    bs = args.batch_size
    n_doms = len(args.dom)
    n_langs = len(args.lang)
    # Discriminator targets: language id repeated once per row in the
    # concatenated (n_langs * bs) hidden-state batch.
    dis_y = to_device(torch.arange(n_langs).unsqueeze(-1).expand(n_langs, bs).contiguous().view(-1), args.cuda)

    # NOTE(review): device placement keys off args.cuda alone, while seeding
    # above also checked torch.cuda.is_available() — confirm args.cuda is
    # sanitized upstream.
    if args.cuda:
        model.cuda(), dis.cuda(), crit.cuda()
    else:
        model.cpu(), dis.cpu(), crit.cpu()

    print('Parameters:')
    total_params = sum([np.prod(x.size()) for x in model.parameters()])
    print('\ttotal params:   {}'.format(total_params))
    print('\tparam list:     {}'.format(len(list(model.parameters()))))
    for name, x in model.named_parameters():
        print('\t' + name + '\t', tuple(x.size()))
    for name, x in dis.named_parameters():
        print('\t' + name + '\t', tuple(x.size()))
    print()

    ###############################################################################
    # Training code
    ###############################################################################

    bptt = args.bptt
    best_accs = {tlang: 0. for tlang in args.trg}
    final_test_accs = {tlang: 0. for tlang in args.trg}
    print('Traning:')
    print_line()
    ptrs = np.zeros((len(args.lang), len(args.dom)), dtype=np.int64)  # pointers for reading unlabeled data, of shape (n_lang, n_dom)
    total_loss = np.zeros((len(args.lang), len(args.dom)))  # shape (n_lang, n_dom)
    total_clf_loss = 0
    total_dis_loss = 0
    start_time = time.time()
    model.train()
    model.reset()
    for step in range(args.max_steps):
        loss = 0
        lm_opt.zero_grad()
        dis_opt.zero_grad()

        if not args.mwe:
            # Randomized BPTT length (AWD-LSTM style); the LM learning rate is
            # scaled proportionally so shorter sequences take smaller steps.
            seq_len = max(5, int(np.random.normal(bptt if np.random.random() < 0.95 else bptt / 2., 5)))
            lr0 = lm_opt.param_groups[0]['lr']
            lm_opt.param_groups[0]['lr'] = lr0 * seq_len / args.bptt

            # language modeling loss
            dis_x = []
            for lid, t in enumerate(unlabeled):
                for did, lm_x in enumerate(t):
                    # Wrap the corpus pointer and reset the matching hidden
                    # state when a full pass over this (lang, dom) stream ends.
                    if ptrs[lid, did] + bptt + 1 > lm_x.size(0):
                        ptrs[lid, did] = 0
                        model.reset(lid=lid, did=did)
                    p = ptrs[lid, did]
                    # Input/target are offset by one token, as usual for LM.
                    xs = lm_x[p: p + bptt].t().contiguous()
                    ys = lm_x[p + 1: p + 1 + bptt].t().contiguous()
                    lm_raw_loss, lm_loss, hid = model.lm_loss(xs, ys, lid=lid, did=did, return_h=True)
                    loss = loss + lm_loss * args.lambd_lm
                    total_loss[lid, did] += lm_raw_loss.item()
                    ptrs[lid, did] += bptt
                    if did == dom_id:
                        # Only the supervised domain's hidden states (mean-
                        # pooled over time) feed the language discriminator.
                        dis_x.append(hid[-1].mean(1))

            # language adversarial loss
            # GradReverse flips gradients so the encoder is trained to make
            # languages indistinguishable while dis learns to tell them apart.
            dis_x_rev = GradReverse.apply(torch.cat(dis_x, 0))
            dis_loss = crit(dis(dis_x_rev), dis_y)
            loss = loss + args.lambd_dis * dis_loss
            total_dis_loss += dis_loss.item()
            loss.backward()

        # sentiment classification loss
        try:
            xs, ys, ls = next(train_iter)
        except StopIteration:
            # Labeled loader is much smaller than max_steps; cycle it.
            train_iter = iter(senti_train)
            xs, ys, ls = next(train_iter)
        clf_loss = crit(model(xs, ls, src_id, dom_id), ys)
        total_clf_loss += clf_loss.item()
        # Second backward this step; grads accumulate onto the LM/adv grads.
        (args.lambd_clf * clf_loss).backward()

        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        if args.dis_clip > 0:
            # WGAN-style hard weight clipping on the discriminator.
            for x in dis.parameters():
                x.data.clamp_(-args.dis_clip, args.dis_clip)
        lm_opt.step()
        dis_opt.step()
        if not args.mwe:
            # Restore the unscaled learning rate for the next step.
            lm_opt.param_groups[0]['lr'] = lr0

        if (step + 1) % args.log_interval == 0:
            total_loss /= args.log_interval
            total_clf_loss /= args.log_interval
            total_dis_loss /= args.log_interval
            elapsed = time.time() - start_time
            print('| step {:5d} | lr {:05.5f} | ms/batch {:7.2f} | lm_loss {:7.4f} | avg_ppl {:7.2f} | clf {:7.4f} | dis {:7.4f} |'.format(
                step, lm_opt.param_groups[0]['lr'], elapsed * 1000 / args.log_interval,
                total_loss.mean(), np.exp(total_loss).mean(), total_clf_loss, total_dis_loss))
            # In-place slice assignment keeps the ndarray; scalars just rebind.
            total_loss[:, :], total_clf_loss, total_dis_loss = 0, 0, 0
            start_time = time.time()

        if (step + 1) % args.val_interval == 0:
            model.eval()
            with torch.no_grad():
                train_acc = evaluate(model, train_ds, src_id, dom_id)
                val_accs = [evaluate(model, ds, tid, dom_id) for tid, ds in zip(trg_ids, val_ds)]
                test_accs = [evaluate(model, ds, tid, dom_id) for tid, ds in zip(trg_ids, test_ds)]
                # Bilingual dictionary induction: nearest-neighbor word
                # translation accuracy between embedding matrices.
                bdi_accs = [compute_nn_accuracy(model.encoder_weight(src_id),
                                                model.encoder_weight(tid),
                                                lexicon, 10000, lexicon_size=lexsz) for lexicon, lexsz, tid in lexicons]
                print_line()
                print(('| step {:5d} | train {:.4f} |' +
                       ' val' + ' {} {:.4f}' * n_trg + ' |' +
                       ' test' + ' {} {:.4f}' * n_trg + ' |' +
                       ' bdi' + ' {} {:.4f}' * n_trg + ' |').format(step, train_acc,
                                                                    *sum([[tlang, acc] for tlang, acc in zip(args.trg, val_accs)], []),
                                                                    *sum([[tlang, acc] for tlang, acc in zip(args.trg, test_accs)], []),
                                                                    *sum([[tlang, acc] for tlang, acc in zip(args.trg, bdi_accs)], [])))
                print_line()
                # Always snapshot the latest model; additionally keep a
                # per-target-language best checkpoint by validation accuracy.
                print('saving model to {}'.format(model_path.replace('.pt', '_final.pt')))
                model_save(model, dis, lm_opt, dis_opt, model_path.replace('.pt', '_final.pt'))
                for tlang, val_acc, test_acc in zip(args.trg, val_accs, test_accs):
                    if val_acc > best_accs[tlang]:
                        save_path = model_path.replace('.pt', '_{}.pt'.format(tlang))
                        print('saving {} model to {}'.format(tlang, save_path))
                        model_save(model, dis, lm_opt, dis_opt, save_path)
                        best_accs[tlang] = val_acc
                        final_test_accs[tlang] = test_acc
                print_line()
            model.train()
            start_time = time.time()

    print_line()
    print('Training ended with {} steps'.format(step + 1))
    print(('Best val acc:             ' + ' {} {:.4f}' * n_trg).format(*sum([[tlang, best_accs[tlang]] for tlang in args.trg], [])))
    print(('Test acc (w/ early stop): ' + ' {} {:.4f}' * n_trg).format(*sum([[tlang, final_test_accs[tlang]] for tlang in args.trg], [])))
    # NOTE(review): test_accs here is the last val-interval evaluation; this
    # line raises NameError if max_steps < val_interval — confirm intended.
    print(('Test acc (w/o early stop):' + ' {} {:.4f}' * n_trg).format(*sum([[tlang, acc] for tlang, acc in zip(args.trg, test_accs)], [])))
Exemple #31
0
    #test_dataset = AudioDataset(data_type='test')#获取路径名文件夹内每个音频文件
    train_data_loader = DataLoader(dataset=train_dataset,
                                   batch_size=BATCH_SIZE,
                                   shuffle=True,
                                   num_workers=4)
    #test_data_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)#加载文件
    '''
    # generate reference batch
    ref_batch = train_dataset.reference_batch(BATCH_SIZE) #获得随机选取的参考 批次大小的 张量
    '''

    # create D and G instances
    discriminator = Discriminator()
    generator = Generator()  #模型初始化
    if torch.cuda.is_available():  #是否有GPU
        discriminator.cuda()  #.cuda()转GPU
        generator.cuda()
        # ref_batch = ref_batch.cuda()
    # ref_batch = Variable(ref_batch)
    # Variable就是 变量 的意思。实质上也就是可以变化的量,区别于int变量,它是一种可以变化的变量,这正好就符合了反向传播,参数更新的属性。
#具体来说,在pytorch中的Variable就是一个存放会变化值的地理位置,里面的值会不停发生片花,就像一个装鸡蛋的篮子,鸡蛋数会不断发生变化。那谁是里面的鸡蛋呢,自然就是pytorch中的tensor了

    print("# generator parameters:",
          sum(param.numel() for param in
              generator.parameters()))  #通过Module.parameters()获取网络的参数
    print("# discriminator parameters:",
          sum(param.numel()
              for param in discriminator.parameters()))  #numel()函数:返回数组中元素的个数
    # optimizers
    g_optimizer = optim.RMSprop(
        generator.parameters(),