예제 #1
0
def train(**kwargs):
    """Train the GAN pair ``NetD``/``NetG`` on an ``ImageFolder`` dataset.

    Each keyword argument overrides the attribute of the same name on the
    module-level ``opt`` object before anything else runs.

    Within an epoch the discriminator is updated on every ``opt.d_every``-th
    batch and the generator on every ``opt.g_every``-th batch (batch 0 always
    updates both). Every ``opt.epoch_every`` epochs the latest losses are
    printed, up to 36 images from the most recent generated batch are written
    to ``<opt.save_img_path>/<epoch>.png``, and both state dicts are written
    to ``<opt.checkpoints_path>/net{d,g}_<epoch>.pth``.

    Raises:
        ValueError: if the data loader yields no batches, i.e. the dataset
            holds fewer than ``opt.batch_size`` images (``drop_last=True``
            discards the only, partial batch).
    """
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(opt.image_size),
        torchvision.transforms.CenterCrop(opt.image_size),
        torchvision.transforms.ToTensor(),
        # ToTensor yields [0, 1]; this maps every channel to [-1, 1].
        torchvision.transforms.Normalize(mean=(0.5, 0.5, 0.5),
                                         std=(0.5, 0.5, 0.5))
    ])

    dataset = torchvision.datasets.ImageFolder(opt.data_path,
                                               transform=transforms)

    # drop_last=True keeps every batch at exactly opt.batch_size, which the
    # fixed-size label tensors below rely on.
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=opt.batch_size,
                                             shuffle=True,
                                             num_workers=opt.num_workers,
                                             drop_last=True)
    if len(dataloader) == 0:
        # Without any batch, the per-epoch report below would read losses
        # that were never computed.
        raise ValueError(
            'dataset at %r has fewer than batch_size=%d images; '
            'no batches to train on' % (opt.data_path, opt.batch_size))

    use_cuda = torch.cuda.is_available()

    # 1. Build the networks, optionally resuming from saved weights.
    D = NetD(opt)
    G = NetG(opt)

    # Load checkpoint tensors onto the CPU regardless of where they were saved.
    map_location = lambda storage, loc: storage
    if opt.netd_path:
        D.load_state_dict(torch.load(opt.netd_path, map_location=map_location))
    if opt.netg_path:
        G.load_state_dict(torch.load(opt.netg_path, map_location=map_location))

    # 2. Optimizers and loss.
    d_optim = torch.optim.Adam(D.parameters(),
                               opt.d_learning_rate,
                               betas=(opt.optim_beta1, 0.999))
    g_optim = torch.optim.Adam(G.parameters(),
                               opt.g_learning_rate,
                               betas=(opt.optim_beta1, 0.999))
    criterion = torch.nn.BCELoss()

    # Real images are labelled 1, generated images 0.
    # NOTE(review): assumes D returns one score per image, shape
    # (batch_size,), to match these targets — confirm against NetD.
    real_labels = Variable(torch.ones(opt.batch_size))
    fake_labels = Variable(torch.zeros(opt.batch_size))

    if use_cuda:
        D.cuda()
        G.cuda()
        criterion.cuda()
        real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

    # 3. Training loop.
    for epoch in range(opt.num_epochs):
        for step, (images, _) in tqdm.tqdm(enumerate(dataloader),
                                           total=len(dataloader)):

            if step % opt.d_every == 0:
                # Discriminator update.
                d_optim.zero_grad()

                # Push D towards classifying real images as real ...
                d_real_data = Variable(images)
                if use_cuda:
                    d_real_data = d_real_data.cuda()
                d_real_decision = D(d_real_data)
                d_real_error = criterion(d_real_decision, real_labels)
                d_real_error.backward()

                # ... and generated images as fake. detach() keeps this loss
                # from back-propagating into G.
                d_gen_input = Variable(
                    torch.randn(opt.batch_size, opt.nz, 1, 1))
                if use_cuda:
                    d_gen_input = d_gen_input.cuda()
                d_fake_data = G(d_gen_input).detach()
                d_fake_decision = D(d_fake_data)
                d_fake_error = criterion(d_fake_decision, fake_labels)
                d_fake_error.backward()
                # Gradients from both backward() calls have accumulated;
                # this step only updates D's parameters.
                d_optim.step()

            if step % opt.g_every == 0:
                # Generator update: push D towards classifying fakes as real.
                g_optim.zero_grad()

                g_gen_input = Variable(
                    torch.randn(opt.batch_size, opt.nz, 1, 1))
                if use_cuda:
                    g_gen_input = g_gen_input.cuda()
                g_fake_data = G(g_gen_input)
                g_fake_decision = D(g_fake_data)
                g_fake_error = criterion(g_fake_decision, real_labels)
                # This also fills D's .grad buffers; only G is stepped here,
                # and d_optim.zero_grad() clears them before D's next update.
                g_fake_error.backward()

                g_optim.step()

        # Per-epoch report and checkpoint. The cadence is keyed on the epoch
        # counter so it matches the epoch-numbered files written below.
        if epoch % opt.epoch_every == 0:
            # .item() converts the scalar losses; indexing them with [0]
            # raises on PyTorch >= 0.5.
            print("%s, %s, D: %s/%s G: %s" %
                  (step, g_fake_decision.cpu().data.numpy().mean(),
                   d_real_error.item(), d_fake_error.item(),
                   g_fake_error.item()))

            # Save model weights and a preview image. Clamp to [-1, 1] and
            # rescale to [0, 1] here, equivalent to save_image's
            # normalize=True with a (-1, 1) range. torchvision renamed that
            # keyword (range -> value_range) and later dropped the old name,
            # so doing it explicitly works on every version.
            preview = g_fake_data.data[:36].clamp(-1, 1) * 0.5 + 0.5
            torchvision.utils.save_image(preview,
                                         '%s/%s.png' %
                                         (opt.save_img_path, epoch))
            torch.save(D.state_dict(),
                       '%s/netd_%s.pth' % (opt.checkpoints_path, epoch))
            torch.save(G.state_dict(),
                       '%s/netg_%s.pth' % (opt.checkpoints_path, epoch))
예제 #2
0
            # Tail of a GAN training loop; the enclosing function and loop
            # headers are not part of this excerpt.
            # Generator step: refill `noises` in place with fresh Gaussian
            # samples, score the generated images with `netd`, and compute the
            # generator loss against `true_labels`.
            noises.data.copy_(torch.randn(opt.batch_size, opt.nz, 1, 1))
            fake_img = netg(noises)
            fake_output = netd(fake_img)
            error_g = criterion(fake_output, true_labels)

            # NOTE(review): if `error_g` is a 0-dim tensor (the default
            # reduction of the built-in losses on PyTorch >= 0.4), indexing it
            # with `.data[0]` raises on PyTorch >= 0.5; `error_g.item()` is
            # the replacement — confirm the targeted torch version.
            print('error_g:,', error_g.data[0])
            writer.add_scalar('data/error_g', error_g.data[0], ii)

            error_g.backward()
            optimizer_g.step()

        # Every `opt.plot_every` iterations: run the generator on
        # `fix_noises` and log up to 64 fake and 64 real images via `writer`.
        # `* 0.5 + 0.5` presumably maps [-1, 1] pixels to [0, 1] — confirm
        # the input normalization.
        if (ii + 1) % opt.plot_every == 0:
            fix_fake_imgs = netg(fix_noises)

            fake = fix_fake_imgs[:64] * 0.5 + 0.5
            real = real_img[:64] * 0.5 + 0.5

            writer.add_image('image/fake_Image', fake, ii)
            writer.add_image('image/real_Image', real, ii)

            print('epoch[{}:{}],ii[{}:{}]'.format(epoch, opt.max_epoch, ii, len(dataloader)))

        # Every `opt.decay_every` epochs: save the latest fixed-noise samples
        # and both networks' weights, then rebuild both Adam optimizers.
        # NOTE(review): `fix_fake_imgs` is only assigned in the plot branch
        # above; if this branch runs before it ever has, the name is unbound.
        # NOTE(review): the optimizers are rebuilt with the same opt.lr1 /
        # opt.lr2, which resets Adam's running moment estimates rather than
        # decaying the learning rate — confirm that is intended.
        # NOTE(review): this check sits at the same level as the
        # per-iteration plot check; if both are inside the batch loop it
        # fires on every iteration of a matching epoch — confirm placement.
        # NOTE(review): newer torchvision names save_image's `range` keyword
        # `value_range` and later releases dropped `range` — check the
        # installed version.
        if (epoch + 1) % opt.decay_every == 0:
            utils.save_image(fix_fake_imgs.data[:64], '%s/%s.png' % (opt.save_path, epoch), normalize=True,
                             range=(-1, 1))
            torch.save(netd.state_dict(), 'checkpoints/netd_%s.pth' % epoch)
            torch.save(netg.state_dict(), 'checkpoints/netg_%s.pth' % epoch)
            optimizer_g = torch.optim.Adam(netg.parameters(), opt.lr1, betas=(opt.beta1, 0.999))
            optimizer_d = torch.optim.Adam(netd.parameters(), opt.lr2, betas=(opt.beta1, 0.999))
예제 #3
0
            # Tail of a GAN training loop; the enclosing function and loop
            # headers are not part of this excerpt.
            # Log the generator loss, then back-propagate it and step the
            # generator's optimizer.
            # NOTE(review): if `error_g` is a 0-dim tensor (the default
            # reduction of the built-in losses on PyTorch >= 0.4), indexing it
            # with `.data[0]` raises on PyTorch >= 0.5; `error_g.item()` is
            # the replacement — confirm the targeted torch version.
            writer.add_scalar('data/error_g', error_g.data[0], ii)

            error_g.backward()
            optimizer_g.step()

        # Every `opt.plot_every` iterations: run the generator on
        # `fix_noises` and log up to 64 fake and 64 real images via `writer`.
        # `* 0.5 + 0.5` presumably maps [-1, 1] pixels to [0, 1] — confirm
        # the input normalization.
        if (ii + 1) % opt.plot_every == 0:
            fix_fake_imgs = netg(fix_noises)

            fake = fix_fake_imgs[:64] * 0.5 + 0.5
            real = real_img[:64] * 0.5 + 0.5

            writer.add_image('image/fake_Image', fake, ii)
            writer.add_image('image/real_Image', real, ii)

            print('epoch[{}:{}],ii[{}:{}]'.format(epoch, opt.max_epoch, ii,
                                                  len(dataloader)))

        # Every `opt.decay_every` epochs: save the latest fixed-noise samples
        # and both networks' weights, then rebuild both Adam optimizers.
        # NOTE(review): `fix_fake_imgs` is only assigned in the plot branch
        # above; if this branch runs before it ever has, the name is unbound.
        # NOTE(review): the optimizers are rebuilt with the same opt.lr1 /
        # opt.lr2, which resets Adam's running moment estimates rather than
        # decaying the learning rate — confirm that is intended.
        # NOTE(review): this check sits at the same level as the
        # per-iteration plot check; if both are inside the batch loop it
        # fires on every iteration of a matching epoch — confirm placement.
        # NOTE(review): newer torchvision names save_image's `range` keyword
        # `value_range` and later releases dropped `range` — check the
        # installed version.
        if (epoch + 1) % opt.decay_every == 0:
            utils.save_image(fix_fake_imgs.data[:64],
                             '%s/%s.png' % (opt.save_path, epoch),
                             normalize=True,
                             range=(-1, 1))
            torch.save(netd.state_dict(), 'checkpoints/netd_%s.pth' % epoch)
            torch.save(netg.state_dict(), 'checkpoints/netg_%s.pth' % epoch)
            optimizer_g = torch.optim.Adam(netg.parameters(),
                                           opt.lr1,
                                           betas=(opt.beta1, 0.999))
            optimizer_d = torch.optim.Adam(netd.parameters(),
                                           opt.lr2,
                                           betas=(opt.beta1, 0.999))