Example #1
0
])

# Image dataset and a shuffled loader; drop_last=True keeps every batch at
# exactly opt.batch_size, matching the fixed-size label/noise tensors below.
dataset = ImageFolder(opt.data_path, transform=transform)
dataloader = DataLoader(
    dataset,
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.num_workers,
    drop_last=True,
)

# Discriminator (netd) is built before the generator (netg).
netd, netg = NetD(opt), NetG(opt)

# Optional warm start. The identity map_location leaves each storage where
# torch.load deserialised it (the CPU), whatever device it was saved from.
if opt.netd_path:
    netd.load_state_dict(
        torch.load(opt.netd_path, map_location=lambda storage, _loc: storage))
if opt.netg_path:
    netg.load_state_dict(
        torch.load(opt.netg_path, map_location=lambda storage, _loc: storage))

# Separate Adam optimisers: lr1 drives the generator, lr2 the discriminator.
optimizer_g = Adam(netg.parameters(), lr=opt.lr1, betas=(opt.beta1, 0.999))
optimizer_d = Adam(netd.parameters(), lr=opt.lr2, betas=(opt.beta1, 0.999))

criterion = nn.BCELoss()

# BCE targets (ones for real, zeros for fake) and two latent batches shaped
# (batch_size, nz, 1, 1). Variable() is a no-op wrapper on PyTorch >= 0.4.
true_labels = Variable(torch.ones(opt.batch_size))
fake_labels = Variable(torch.zeros(opt.batch_size))
fix_noises = Variable(torch.randn(opt.batch_size, opt.nz, 1, 1))
noises = Variable(torch.randn(opt.batch_size, opt.nz, 1, 1))

if opt.use_gpu:
    # Modules are moved by calling .cuda(); tensors are rebound to the
    # GPU copies that .cuda() returns.
    netd.cuda()
    netg.cuda()
    criterion.cuda()
    true_labels = true_labels.cuda()
    fake_labels = fake_labels.cuda()
    fix_noises = fix_noises.cuda()
    noises = noises.cuda()
Example #2
0
def train(**kwargs):
    """Run the GAN training loop configured by the module-level ``opt``.

    ``kwargs`` are forwarded to ``opt._parse`` to override config fields.
    Each epoch zips a shuffled real-image loader with a shuffled noise loader.
    The discriminator is updated on batches whose index is a multiple of
    ``opt.every_d``, the generator on multiples of ``opt.every_g``. After each
    epoch the mean losses are printed. Every ``opt.save_freq``-th epoch the
    generator is evaluated and both networks are saved to
    ``opt.save_model_path``.

    Debugging hook: if the file ``opt.debug_file`` exists after an update,
    execution stops in ``ipdb``.
    """
    opt._parse(kwargs)
    demoer = Evaluator(opt)

    # Real images plus a noise dataset sized to len(anime_data); both are
    # batched with opt.batch_size and shuffled independently.
    anime_data = AnimeData(opt.data_path)
    anime_dataloader = DataLoader(anime_data, batch_size=opt.batch_size, shuffle=True)
    noise_data = NoiseData(opt.noise_size, len(anime_data))
    noise_dataloader = DataLoader(noise_data, batch_size=opt.batch_size, shuffle=True)

    net_G = NetG(opt)
    net_D = NetD(opt)
    if opt.use_gpu:
        net_G, net_D = net_G.cuda(), net_D.cuda()

    criterion = torch.nn.BCELoss()
    betas = (opt.beta1, opt.beta2)
    optimizer_G = torch.optim.Adam(net_G.parameters(), lr=opt.lr_g, betas=betas)
    optimizer_D = torch.optim.Adam(net_D.parameters(), lr=opt.lr_d, betas=betas)

    # Running averages of the per-update losses, reset at every epoch.
    d_loss_meter = meter.AverageValueMeter()
    g_loss_meter = meter.AverageValueMeter()

    # Optional warm start (happens after the GPU move above).
    if opt.netd_path is not None:
        net_D.load(opt.netd_path)
    if opt.netg_path is not None:
        net_G.load(opt.netg_path)

    for epoch in range(opt.max_epochs):
        d_loss_meter.reset()
        g_loss_meter.reset()

        total_batches = len(anime_dataloader)
        paired_batches = enumerate(zip(anime_dataloader, noise_dataloader))
        for step, (real_images, noise) in tqdm(paired_batches,
                                               total=total_batches,
                                               ascii=True):
            n_samples = real_images.shape[0]
            true_targets = torch.ones(n_samples)
            fake_targets = torch.zeros(n_samples)

            if opt.use_gpu:
                noise = noise.cuda()
                real_images = real_images.cuda()
                true_targets = true_targets.cuda()
                fake_targets = fake_targets.cuda()

            if step % opt.every_d == 0:
                # Discriminator update: G's requires_grad off, D's on.
                optimizer_D.zero_grad()
                net_G.set_requires_grad(False)
                net_D.set_requires_grad(True)

                fake_score = net_D(net_G(noise))
                true_score = net_D(real_images)
                loss_D = (criterion(fake_score, fake_targets)
                          + criterion(true_score, true_targets))
                loss_D.backward()
                optimizer_D.step()
                d_loss_meter.add(loss_D.item())

                if os.path.exists(opt.debug_file):
                    import ipdb
                    ipdb.set_trace()

            if step % opt.every_g == 0:
                # Generator update: D's requires_grad off; G's loss is low
                # when D scores its output as real.
                optimizer_G.zero_grad()
                net_G.set_requires_grad(True)
                net_D.set_requires_grad(False)

                loss_G = criterion(net_D(net_G(noise)), true_targets)
                loss_G.backward()
                optimizer_G.step()
                g_loss_meter.add(loss_G.item())

                if os.path.exists(opt.debug_file):
                    import ipdb
                    ipdb.set_trace()

        mean_loss_D = d_loss_meter.value()[0]
        mean_loss_G = g_loss_meter.value()[0]
        print("Epoch {epoch:0>2d}: loss_D - {loss_D}, loss_G - {loss_G}".format(
            epoch=epoch + 1,
            loss_D=mean_loss_D,
            loss_G=mean_loss_G,
        ))

        if epoch % opt.save_freq == opt.save_freq - 1:
            demoer.evaluate(net_G)
            net_D.save(opt.save_model_path)
            net_G.save(opt.save_model_path)
        # Short pause between epochs (purpose not visible in this file).
        time.sleep(0.5)
Example #3
0
def train(**kwargs):
    """Train a DCGAN-style generator ``G`` and discriminator ``D``.

    Keyword arguments are copied onto the module-level ``opt`` config before
    anything else runs. Training images are resized and center-cropped to
    ``opt.image_size`` and normalized from [0, 1] to [-1, 1]. Within each
    epoch, ``D`` is updated every ``opt.d_every`` steps and ``G`` every
    ``opt.g_every`` steps. Every ``opt.epoch_every`` epochs:

    - the latest losses are printed;
    - a grid of up to 36 generated images is written to
      ``opt.save_img_path``;
    - both state dicts are saved to ``opt.checkpoints_path``.

    Raises:
        ValueError: if the dataset cannot fill a single batch of
            ``opt.batch_size`` images. With ``drop_last=True`` the loader
            would then be empty and there would be nothing to report.
    """
    for key, value in kwargs.items():
        setattr(opt, key, value)

    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(opt.image_size),
        torchvision.transforms.CenterCrop(opt.image_size),
        torchvision.transforms.ToTensor(),
        # (x - 0.5) / 0.5 maps ToTensor's [0, 1] range to [-1, 1].
        torchvision.transforms.Normalize(mean=(0.5, 0.5, 0.5),
                                         std=(0.5, 0.5, 0.5))
    ])

    dataset = torchvision.datasets.ImageFolder(opt.data_path,
                                               transform=transforms)

    # drop_last=True keeps every batch at exactly opt.batch_size images,
    # which the fixed-size label tensors below rely on.
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=opt.batch_size,
                                             shuffle=True,
                                             num_workers=opt.num_workers,
                                             drop_last=True)
    if len(dataloader) == 0:
        raise ValueError(
            "no full batch of %d images could be formed from %r"
            % (opt.batch_size, opt.data_path))

    # 1. Networks, optionally resumed from checkpoints. The identity
    #    map_location keeps loaded storages on the CPU.
    D = NetD(opt)
    G = NetG(opt)

    map_location = lambda storage, loc: storage
    if opt.netd_path:
        D.load_state_dict(torch.load(opt.netd_path, map_location=map_location))
    if opt.netg_path:
        G.load_state_dict(torch.load(opt.netg_path, map_location=map_location))

    # 2. Optimizers and loss.
    d_optim = torch.optim.Adam(D.parameters(),
                               opt.d_learning_rate,
                               betas=(opt.optim_beta1, 0.999))
    g_optim = torch.optim.Adam(G.parameters(),
                               opt.g_learning_rate,
                               betas=(opt.optim_beta1, 0.999))
    criterion = torch.nn.BCELoss()

    # Real images are labelled 1, generated images 0.
    real_labels = Variable(torch.ones(opt.batch_size))
    fake_labels = Variable(torch.zeros(opt.batch_size))

    # Query CUDA availability once instead of on every batch.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        D.cuda()
        G.cuda()
        criterion.cuda()
        real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

    def to_device(tensor):
        """Return ``tensor`` on the GPU when CUDA is available, else as-is."""
        return tensor.cuda() if use_cuda else tensor

    # 3. Training loop.
    for epoch in range(opt.num_epochs):
        for step, (images, _) in tqdm.tqdm(enumerate(dataloader)):

            if step % opt.d_every == 0:
                # 3a. Discriminator update.
                d_optim.zero_grad()

                # Push D to score real images as real.
                d_real_data = to_device(Variable(images))
                d_real_decision = D(d_real_data)
                d_real_error = criterion(d_real_decision, real_labels)
                d_real_error.backward()

                # Push D to score generated images as fake; detach() keeps
                # this loss from back-propagating into G.
                d_gen_input = to_device(
                    Variable(torch.randn(opt.batch_size, opt.nz, 1, 1)))
                d_fake_data = G(d_gen_input).detach()
                d_fake_decision = D(d_fake_data)
                d_fake_error = criterion(d_fake_decision, fake_labels)
                d_fake_error.backward()

                # Steps only D's parameters, using the gradients accumulated
                # by both backward() calls above.
                d_optim.step()

            if step % opt.g_every == 0:
                # 3b. Generator update: push D to score fakes as real.
                # This backward() also writes gradients into D's parameters;
                # d_optim.zero_grad() clears them before D's next step.
                g_optim.zero_grad()

                g_gen_input = to_device(
                    Variable(torch.randn(opt.batch_size, opt.nz, 1, 1)))
                g_fake_data = G(g_gen_input)
                g_fake_decision = D(g_fake_data)
                g_fake_error = criterion(g_fake_decision, real_labels)
                g_fake_error.backward()

                g_optim.step()

        # Gate on `epoch`, not `step`. After the inner loop `step` is always
        # the last batch index, so testing it fired every epoch or never.
        if epoch % opt.epoch_every == 0:
            # .item() reads a 0-dim loss tensor; indexing it with .data[0]
            # raises IndexError on PyTorch >= 1.0.
            print("%s, %s, D: %s/%s G: %s" %
                  (step, g_fake_decision.cpu().data.numpy().mean(),
                   d_real_error.item(), d_fake_error.item(),
                   g_fake_error.item()))

            # Save a sample image grid and both checkpoints for this epoch.
            # NOTE(review): newer torchvision renamed save_image's `range`
            # keyword to `value_range`; confirm against the installed version.
            torchvision.utils.save_image(g_fake_data.data[:36],
                                         '%s/%s.png' %
                                         (opt.save_img_path, epoch),
                                         normalize=True,
                                         range=(-1, 1))
            torch.save(D.state_dict(),
                       '%s/netd_%s.pth' % (opt.checkpoints_path, epoch))
            torch.save(G.state_dict(),
                       '%s/netg_%s.pth' % (opt.checkpoints_path, epoch))
Example #4
0
    ngf = 64


# Module-level fixtures shared by the smoke test further down the file.
opt = Opt()

# NOTE(review): presumably NoiseData(size, length), i.e. BATCH_SIZE noise
# samples, so the first loader batch holds all of them; TODO confirm.
noise_data = NoiseData(NOISE_SIZE, BATCH_SIZE)
noise_dataloader = DataLoader(
    noise_data,
    batch_size=BATCH_SIZE,
)
noise_iter = iter(noise_dataloader)
feature_map = next(noise_iter)  # one batch of generator input

net_G, net_D = NetG(opt), NetD(opt)

# NOTE(review): `criterion` and `optimzer` (sic) are not used by the visible
# test; they keep their existing names in case other code imports them.
criterion = torch.nn.MSELoss()
optimzer = torch.optim.SGD(params=net_D.parameters(), lr=0.1)


# ============================= #
# Testing
# ============================= #
def test_networks():
    """Smoke-test output shapes and value ranges of ``net_G`` and ``net_D``.

    Generator check: the module-level ``feature_map`` noise batch goes
    through the generator, which must yield ``(BATCH_SIZE, 3, 96, 96)``
    images with values in [0, 1].

    Discriminator check: those images must map to ``BATCH_SIZE`` scores
    that also lie in [0, 1].
    """
    generated = net_G(feature_map)
    # torch.Size is a tuple subclass, so compare against a plain tuple
    # instead of allocating a throwaway uninitialised tensor.
    assert generated.shape == (BATCH_SIZE, 3, 96, 96)
    assert torch.max(generated) <= 1
    assert torch.min(generated) >= 0

    res_generated = net_D(generated)
    assert res_generated.shape == (BATCH_SIZE,)
    assert torch.max(res_generated) <= 1
    # The lower bound must use min(); checking max() >= 0 let negative
    # scores slip through.
    assert torch.min(res_generated) >= 0
Example #5
0
                        shuffle=True,
                        num_workers=opt.num_workers,
                        drop_last=True)

# Discriminator (netd) is built before the generator (netg).
netd, netg = NetD(opt), NetG(opt)

# Optional warm start. The identity map_location leaves each storage where
# torch.load deserialised it (the CPU), whatever device it was saved from.
if opt.netd_path:
    netd.load_state_dict(
        torch.load(opt.netd_path, map_location=lambda storage, _loc: storage))
if opt.netg_path:
    netg.load_state_dict(
        torch.load(opt.netg_path, map_location=lambda storage, _loc: storage))

# Separate Adam optimisers: lr1 drives the generator, lr2 the discriminator.
optimizer_g = Adam(netg.parameters(), lr=opt.lr1, betas=(opt.beta1, 0.999))
optimizer_d = Adam(netd.parameters(), lr=opt.lr2, betas=(opt.beta1, 0.999))

criterion = nn.BCELoss()

# BCE targets (ones for real, zeros for fake) and two latent batches shaped
# (batch_size, nz, 1, 1). Variable() is a no-op wrapper on PyTorch >= 0.4.
true_labels = Variable(torch.ones(opt.batch_size))
fake_labels = Variable(torch.zeros(opt.batch_size))
fix_noises = Variable(torch.randn(opt.batch_size, opt.nz, 1, 1))
noises = Variable(torch.randn(opt.batch_size, opt.nz, 1, 1))

if opt.use_gpu:
    # Modules are moved by calling .cuda(); tensors are rebound to the
    # GPU copies that .cuda() returns.
    netd.cuda()
    netg.cuda()
    criterion.cuda()
    true_labels = true_labels.cuda()
    fake_labels = fake_labels.cuda()
    fix_noises = fix_noises.cuda()
    noises = noises.cuda()
Example #6
0
mlp = MLP(in_features=args.hidden, nclass=labelsA.max().item() + 1)
Dnet = NetD(nhid=args.hidden)
model = GCN(nfeat=featuresA.shape[1],
            nhid=args.hidden,
            nclass=labelsA.max().item() + 1,
            dropout=args.dropout)
# Optional warm start / parameter dump, currently disabled:
#model.load_state_dict(torch.load('init2.pkl'))
#for item in model.parameters():
#    print(item)
optimizer = optim.Adam(model.parameters(),
                       lr=args.lr,
                       weight_decay=args.weight_decay)
# NOTE(review): the MLP's learning rate is hard-coded to 0.01 rather than
# taken from args.lr; confirm this is intentional.
optimizer_mlp = optim.Adam(mlp.parameters(),
                           lr=0.01,
                           weight_decay=args.weight_decay)
dis_optimizer = optim.SGD(Dnet.parameters(),
                          lr=args.lr,
                          weight_decay=args.weight_decay)

if args.cuda:
    model.cuda()
    # mlp was previously left on the CPU while model, Dnet and every data
    # tensor moved to the GPU. It takes args.hidden-sized inputs (presumably
    # the GCN's hidden features), which would then live on the GPU.
    mlp.cuda()
    Dnet.cuda()
    featuresA = featuresA.cuda()
    featuresB = featuresB.cuda()
    adjA = adjA.cuda()
    adjB = adjB.cuda()
    labelsA = labelsA.cuda()
    labelsB = labelsB.cuda()

# Vector of 8 * rbatch_size ones in the default float dtype.
# NOTE(review): unlike the tensors moved above, this stays on the CPU even
# when args.cuda is set; confirm it is moved before use.
ones = torch.ones(8 * rbatch_size)