Example No. 1
def generate(**kwargs):
    """
    随机生成图像,并根据netd的分数选择较好的
    """
    with t.no_grad():
        opt._parse(kwargs)
        device = t.device('cuda') if opt.gpu else t.device('cpu')

        netg, netd = NetG(opt).eval(), NetD(opt).eval()
        noises = t.randn(opt.gen_search_num, opt.nz, 1,
                         1).normal_(opt.gen_mean, opt.gen_std)
        noises = noises.to(device)

        map_location = lambda storage, loc: storage
        netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
        netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
        netd.to(device)
        netg.to(device)

        # generate images and compute each image's score under the discriminator
        fake_img = netg(noises)
        scores = netd(fake_img).detach()

        # keep the best few
        indexs = scores.topk(opt.gen_num)[1]
        result = []
        for ii in indexs:
            result.append(fake_img.data[ii])
        # save the images
        tv.utils.save_image(t.stack(result),
                            opt.gen_img,
                            normalize=True,
                            range=(-1, 1))
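
The selection step above is just torch.topk over the discriminator scores. A minimal, self-contained sketch of that step, with illustrative sizes standing in for opt.gen_search_num and opt.gen_num:

import torch

scores = torch.rand(512)              # one discriminator score per generated image
images = torch.randn(512, 3, 96, 96)  # the corresponding fake images
k = 64

top_scores, top_idx = scores.topk(k)  # values and indices of the k largest scores
best_images = images[top_idx]         # equivalent to stacking images[i] for i in top_idx
print(best_images.shape)              # torch.Size([64, 3, 96, 96])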
Example No. 2
def generate(**kwargs):
    """用训练好的数据进行生成图片"""

    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    device = t.device("cuda" if opt.gpu else "cpu")

    #  1. load the trained weights
    netg, netd = NetG(opt).eval(), NetD(opt).eval()
    map_location = lambda storage, loc: storage

    netd.load_state_dict(t.load(opt.netd_path, map_location=map_location),
                         strict=False)
    netg.load_state_dict(t.load(opt.netg_path, map_location=map_location),
                         strict=False)
    netd.to(device)
    netg.to(device)

    #  2. generate images with the trained generator
    noise = t.randn(opt.gen_search_num, opt.nz, 1,
                    1).normal_(opt.gen_mean, opt.gen_std)
    noise = noise.to(device)  # .to() is not in-place; the result must be reassigned

    fake_image = netg(noise)
    score = netd(fake_image).detach()  # TODO review the topk() function

    # pick out suitable images
    indexs = score.topk(opt.gen_num)[1]

    result = []

    for ii in indexs:
        result.append(fake_image.data[ii])

    tv.utils.save_image(t.stack(result),
                        opt.gen_img,
                        normalize=True,
                        range=(-1, 1))
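
The second argument passed to load_state_dict above is strict; with strict=False, key mismatches between the checkpoint and the model are skipped instead of raising an error. A small sketch of that behaviour on toy modules:

import torch
import torch.nn as nn

src = nn.Linear(4, 2)                 # pretend checkpoint source
dst = nn.Sequential(nn.Linear(4, 2))  # its keys are '0.weight'/'0.bias', not 'weight'/'bias'

# strict=False ignores missing and unexpected keys and reports them back
result = dst.load_state_dict(src.state_dict(), strict=False)
print(result.missing_keys, result.unexpected_keys)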
Example No. 3
def generate(**kwargs):
    opt.parse(**kwargs)
    netg, netd = NetG(opt).eval(), NetD(opt).eval()
    noises = t.randn(opt.gen_search_num, opt.nz, 1, 1).normal_(opt.gen_mean, opt.gen_std)
    
    if opt.netg_path:
        netg.load(opt.netg_path)
    if opt.netd_path:
        netd.load(opt.netd_path)
    
    if opt.use_gpu:
        netd.cuda()
        netg.cuda()
        noises = noises.cuda()
    # generate images and compute each image's score under the discriminator
    fake_img = netg(noises)
    scores = netd(fake_img).data
    # keep the good ones
    index = scores.topk(opt.gen_num)[1]
    result = []
    for ii in index:
        # slicing and merging tensors: cat concatenates along an existing dim,
        # stack adds a new dim (cat + view is equivalent to stack)
        result.append(fake_img.data[ii])
    tv.utils.save_image(t.stack(result), opt.gen_img, normalize=True, range=(-1,1))
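
The comment above contrasts cat and stack; a quick self-contained illustration of the difference:

import torch

a = torch.randn(3, 96, 96)
b = torch.randn(3, 96, 96)

stacked = torch.stack([a, b])         # adds a new leading dim -> (2, 3, 96, 96)
catted = torch.cat([a, b], dim=0)     # merges along an existing dim -> (6, 96, 96)
same = torch.cat([a.unsqueeze(0), b.unsqueeze(0)], dim=0)  # cat + view == stack
print(stacked.shape, catted.shape, torch.equal(stacked, same))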
Example No. 4
def generate(**kwargs):
    '''
    Randomly generate anime avatars and keep the better ones according to netd's scores.
    '''
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    D = NetD(opt)
    G = NetG(opt)

    noises = torch.randn(opt.gen_search_num, opt.nz, 1,
                         1).normal_(opt.gen_mean, opt.gen_std)
    noises = Variable(noises, volatile=True)

    map_location = lambda storage, loc: storage
    D.load_state_dict(torch.load(opt.netd_path, map_location=map_location))
    G.load_state_dict(torch.load(opt.netg_path, map_location=map_location))

    if torch.cuda.is_available():
        D.cuda()
        G.cuda()
        noises = noises.cuda()

    # generate images and compute each image's score under the discriminator
    fake_img = G(noises)
    scores = D(fake_img).data

    # keep the best few
    indexs = scores.topk(opt.gen_num)[1]
    result = []
    for idx in indexs:
        result.append(fake_img.data[idx])
    # save the images
    torchvision.utils.save_image(torch.stack(result),
                                 opt.gen_img,
                                 normalize=True,
                                 range=(-1, 1))
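
Variable(noises, volatile=True) is pre-0.4 PyTorch; in current versions Variable is merged into Tensor and the same inference-only behaviour comes from torch.no_grad(). A minimal sketch of the modern equivalent:

import torch

noises = torch.randn(8, 100, 1, 1)  # illustrative batch size and nz

# no_grad() disables graph recording for everything computed inside the block,
# which is what volatile=True used to do for inference
with torch.no_grad():
    out = noises * 2                # stands in for G(noises)
print(out.requires_grad)            # False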
Example No. 5
writer = SummaryWriter()

opt = Config()

transform = transforms.Compose([
    transforms.Resize(opt.image_size),  # transforms.Scale was renamed to Resize in newer torchvision
    transforms.CenterCrop(opt.image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

dataset = ImageFolder(opt.data_path, transform=transform)

dataloader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers, drop_last=True)

netd = NetD(opt)
netg = NetG(opt)

if opt.netd_path:
    netd.load_state_dict(torch.load(opt.netd_path, map_location=lambda storage, loc: storage))
if opt.netg_path:
    netg.load_state_dict(torch.load(opt.netg_path, map_location=lambda storage, loc: storage))

optimizer_g = Adam(netg.parameters(), opt.lr1, betas=(opt.beta1, 0.999))
optimizer_d = Adam(netd.parameters(), opt.lr2, betas=(opt.beta1, 0.999))

criterion = nn.BCELoss()

true_labels = Variable(torch.ones(opt.batch_size))
fake_labels = Variable(torch.zeros(opt.batch_size))
fix_noises = Variable(torch.randn(opt.batch_size, opt.nz, 1, 1))
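
Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) computes (x - 0.5) / 0.5 per channel, so the [0, 1] output of ToTensor lands in the [-1, 1] range a tanh generator produces; save_image's range=(-1, 1) later maps it back. A small check:

import torch
from torchvision import transforms

norm = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
img = torch.rand(3, 96, 96)                # stand-in for a ToTensor output in [0, 1]
out = norm(img)
print(out.min().item(), out.max().item())  # both within [-1, 1]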
Example No. 6
def train(**kwargs):
    opt._parse(kwargs)

    demoer = Evaluator(opt)

    anime_data = AnimeData(opt.data_path)
    anime_dataloader = DataLoader(anime_data,
                                  batch_size=opt.batch_size,
                                  shuffle=True)

    noise_data = NoiseData(opt.noise_size, len(anime_data))
    noise_dataloader = DataLoader(noise_data,
                                  batch_size=opt.batch_size,
                                  shuffle=True)

    net_G = NetG(opt)
    net_D = NetD(opt)

    if opt.use_gpu:
        net_G = net_G.cuda()
        net_D = net_D.cuda()

    criterion = torch.nn.BCELoss()
    optimizer_G = torch.optim.Adam(net_G.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.beta1, opt.beta2))
    optimizer_D = torch.optim.Adam(net_D.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.beta1, opt.beta2))

    loss_D_meteor = meter.AverageValueMeter()
    loss_G_meteor = meter.AverageValueMeter()

    if opt.netd_path is not None:
        net_D.load(opt.netd_path)
    if opt.netg_path is not None:
        net_G.load(opt.netg_path)

    for epoch in range(opt.max_epochs):
        loss_D_meteor.reset()
        loss_G_meteor.reset()

        num_batch = len(anime_dataloader)
        generator = enumerate(zip(anime_dataloader, noise_dataloader))
        for ii, (true_image, feature_map) in tqdm(generator,
                                                  total=num_batch,
                                                  ascii=True):
            num_data = true_image.shape[0]
            true_targets = torch.ones(num_data)
            fake_targets = torch.zeros(num_data)

            if opt.use_gpu:
                feature_map = feature_map.cuda()
                true_image = true_image.cuda()
                true_targets = true_targets.cuda()
                fake_targets = fake_targets.cuda()

            # Train discriminator
            if ii % opt.every_d == 0:
                optimizer_D.zero_grad()
                net_G.set_requires_grad(False)
                net_D.set_requires_grad(True)

                fake_image = net_G(feature_map)
                fake_score = net_D(fake_image)
                true_score = net_D(true_image)
                loss_D = criterion(fake_score, fake_targets) + \
                    criterion(true_score, true_targets)
                loss_D.backward()
                optimizer_D.step()

                loss_D_meteor.add(loss_D.detach().item())

                if os.path.exists(opt.debug_file):
                    import ipdb
                    ipdb.set_trace()

            # Train generator
            if ii % opt.every_g == 0:
                optimizer_G.zero_grad()
                net_G.set_requires_grad(True)
                net_D.set_requires_grad(False)

                fake_image = net_G(feature_map)
                fake_score = net_D(fake_image)
                loss_G = criterion(fake_score, true_targets)
                loss_G.backward()
                optimizer_G.step()

                loss_G_meteor.add(loss_G.detach().item())

                if os.path.exists(opt.debug_file):
                    import ipdb
                    ipdb.set_trace()

        gan_log = "Epoch {epoch:0>2d}: loss_D - {loss_D}, loss_G - {loss_G}".format(
            epoch=epoch + 1,
            loss_D=loss_D_meteor.value()[0],
            loss_G=loss_G_meteor.value()[0],
        )
        print(gan_log)

        if epoch % opt.save_freq == opt.save_freq - 1:
            demoer.evaluate(net_G)
            net_D.save(opt.save_model_path)
            net_G.save(opt.save_model_path)
        time.sleep(0.5)
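
set_requires_grad here is a method on NetG/NetD that is not shown; a plain-function version of the same idea (freezing one network's parameters while the other is updated) might look like this:

import torch.nn as nn

def set_requires_grad(net: nn.Module, flag: bool) -> None:
    # toggle gradient tracking for every parameter, so a frozen network
    # accumulates no gradients while its counterpart is being optimized
    for p in net.parameters():
        p.requires_grad_(flag)

# usage sketch:
# set_requires_grad(net_D, False)  # freeze D while stepping G
# set_requires_grad(net_G, True)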
Example No. 7
def train(**kwargs):
    opt._parse(kwargs)
    if opt.vis:
        from utils.visualize import Visualizer
        vis = Visualizer(opt.env)

    # data
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = tv.datasets.ImageFolder(opt.data_path, transform=transforms)
    dataloader = t.utils.data.DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.num_workers,
        drop_last=True  # the final batch is dropped if it is smaller than batch_size
    )

    # networks
    netg, netd = NetG(opt), NetD(opt)
    map_location = lambda storage, loc: storage
    if opt.netd_path:
        netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
    netd.to(opt.device)
    netg.to(opt.device)

    # define optimizers and the loss
    optimizer_g = t.optim.Adam(netg.parameters(),
                               opt.lr1,
                               betas=(opt.beta1, 0.999))
    optimizer_d = t.optim.Adam(netd.parameters(),
                               opt.lr2,
                               betas=(opt.beta1, 0.999))
    criterion = t.nn.BCELoss().to(opt.device)

    # real images are labeled 1, fake images 0
    # noises is the input to the generator
    true_labels = t.ones(opt.batch_size).to(opt.device)  # real
    fake_labels = t.zeros(opt.batch_size).to(opt.device)  # fake
    fix_noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(opt.device)
    noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(opt.device)

    errord_meter = AverageValueMeter()
    errorg_meter = AverageValueMeter()

    epochs = range(opt.max_epoch)
    for epoch in iter(epochs):
        for ii, (img, _) in tqdm.tqdm(enumerate(dataloader)):
            real_img = img.to(opt.device)

            if ii % opt.d_every == 0:
                # train the discriminator
                optimizer_d.zero_grad()
                ## try to classify real images as real
                output = netd(real_img)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward()

                ## try to classify fake images as fake
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises).detach()  # generate fake images from the noise
                output = netd(fake_img)
                error_d_fake = criterion(output, fake_labels)
                error_d_fake.backward()
                optimizer_d.step()

                error_d = error_d_fake + error_d_real

                errord_meter.add(error_d.item())

            if ii % opt.g_every == 0:
                # train the generator
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                output = netd(fake_img)
                error_g = criterion(output, true_labels)
                error_g.backward()
                optimizer_g.step()
                errorg_meter.add(error_g.item())

            if opt.vis and ii % opt.plot_every == opt.plot_every - 1:
                ## visualization
                fix_fake_imgs = netg(fix_noises)
                vis.images(fix_fake_imgs.detach().cpu().numpy()[:64] * 0.5 +
                           0.5,
                           win='fixfake')
                vis.images(real_img.data.cpu().numpy()[:64] * 0.5 + 0.5,
                           win='real')
                vis.plot('errord', errord_meter.value()[0])
                vis.plot('errorg', errorg_meter.value()[0])

        if (epoch + 1) % opt.save_every == 0:
            # save the model and images
            tv.utils.save_image(fix_fake_imgs.data[:64],
                                '%s/%s.png' % (opt.save_path, epoch),
                                normalize=True,
                                range=(-1, 1))
            tag = [
                i for i in os.listdir('./data') if os.path.isdir('./data/' + i)
            ][0]
            t.save(netd.state_dict(), 'checkpoints/%s_d_%s.pth' % (tag, epoch))
            t.save(netg.state_dict(), 'checkpoints/%s_g_%s.pth' % (tag, epoch))
            errord_meter.reset()
            errorg_meter.reset()
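
Example 7 detaches the generator output before passing it to the discriminator, so the discriminator loss never backpropagates into the generator. A toy illustration of that cut in the graph:

import torch
import torch.nn as nn

G = nn.Linear(4, 4)   # stands in for the generator
D = nn.Linear(4, 1)   # stands in for the discriminator

fake = G(torch.randn(2, 4))

# detach() cuts the graph at the generator output: the backward pass below
# fills D's gradients but leaves G's untouched
loss = D(fake.detach()).mean()
loss.backward()
print(G.weight.grad is None, D.weight.grad is not None)  # True True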
Example No. 8
def train(**kwargs):
    """training NetWork"""

    #  0. apply config overrides
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    device = t.device("cuda") if opt.gpu else t.device("cpu")

    # 1. preprocess the data
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.img_size),  # 3*96*96
        tv.transforms.CenterCrop(opt.img_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    #  1.1 load the data
    dataset = tv.datasets.ImageFolder(opt.data_path,
                                      transform=transforms)  # TODO review this wrapper
    dataloader = DataLoader(dataset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            drop_last=True)  # TODO check what drop_last does

    # 2. initialize the networks
    netg, netd = NetG(opt), NetD(opt)
    # 2.1 check whether pretrained weights are available
    map_location = lambda storage, loc: storage  # TODO review map_location

    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
    if opt.netd_path:
        netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    # 2.2 move the models to the target device
    netd.to(device)
    netg.to(device)

    # 3. define the optimization strategy
    #  TODO review the Adam algorithm
    optimize_g = t.optim.Adam(netg.parameters(),
                              lr=opt.lr1,
                              betas=(opt.beta1, 0.999))
    optimize_d = t.optim.Adam(netd.parameters(),
                              lr=opt.lr2,
                              betas=(opt.beta1, 0.999))
    criterions = nn.BCELoss().to(device)  # TODO review BCELoss

    # 4. define the labels and the generator's noise input
    true_labels = t.ones(opt.batch_size).to(device)
    fake_labels = t.zeros(opt.batch_size).to(device)  # fake labels must be 0, not 1
    noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)
    fix_noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)

    errord_meter = AverageValueMeter()  # TODO reread torchnet
    errorg_meter = AverageValueMeter()

    #  6. train the networks
    epochs = range(opt.max_epoch)
    write = SummaryWriter(log_dir=opt.virs, comment='loss')

    # 6.1 iterate over epochs
    for epoch in iter(epochs):
        #  6.2 read each batch
        for ii_, (img, _) in tqdm.tqdm(enumerate(dataloader)):
            real_img = img.to(device)

            #  6.3 train the generator and discriminator
            #  note: the generator should be updated less often than the discriminator
            if ii_ % opt.d_every == 0:
                optimize_d.zero_grad()
                # train the discriminator
                # real images
                output = netd(real_img)
                error_d_real = criterions(output, true_labels)
                error_d_real.backward()

                # randomly generated fake images
                noises = noises.detach()
                fake_image = netg(noises).detach()
                output = netd(fake_image)
                error_d_fake = criterions(output, fake_labels)
                error_d_fake.backward()
                optimize_d.step()

                # accumulate the loss
                error_d = error_d_fake + error_d_real
                errord_meter.add(error_d.item())

            # train the generator
            if ii_ % opt.g_every == 0:
                optimize_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                output = netd(fake_img)
                error_g = criterions(output, true_labels)
                error_g.backward()
                optimize_g.step()

                errorg_meter.add(error_g.item())
            # log the losses
            if ii_ % 5 == 0:
                step = epoch * len(dataloader) + ii_
                write.add_scalar("Discriminator_loss", errord_meter.value()[0], step)
                write.add_scalar("Generator_loss", errorg_meter.value()[0], step)

        #  7. save the models
        if (epoch + 1) % opt.save_every == 0:
            fix_fake_image = netg(fix_noises)
            tv.utils.save_image(fix_fake_image.data[:64],
                                "%s/%s.png" % (opt.save_path, epoch),
                                normalize=True)

            t.save(netd.state_dict(), 'checkpoints/netd_%s.pth' % epoch)
            t.save(netg.state_dict(), 'checkpoints/netg_%s.pth' % epoch)
            errord_meter.reset()
            errorg_meter.reset()

    write.close()
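
add_scalar takes an optional global_step; passing the running batch index (as done above) keeps successive points ordered along TensorBoard's x-axis instead of piling up at one step. A minimal standalone sketch, with an illustrative log directory:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/dcgan_demo")  # illustrative path

for step in range(3):
    # each call records one point at x = step for the named scalar
    writer.add_scalar("Discriminator_loss", 0.7 - 0.1 * step, global_step=step)
writer.close()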
Example No. 9
def train(**kwargs):
    # step1: configure
    opt.parse(**kwargs)
    if opt.vis:
        vis = Visualizer(opt.env)
    # step2: data
    normalize = T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    transforms = T.Compose(
    [
        T.Resize(opt.image_size),
        T.CenterCrop(opt.image_size),
        T.ToTensor(),
        normalize
    ])
    # for this model, the transform is the same for train and test
    dataset = tv.datasets.ImageFolder(opt.data_path,transform=transforms)
    dataloader = DataLoader(dataset,
                            batch_size = opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            drop_last=True)                      # load images for training NetD
    
    true_labels = Variable(t.ones(opt.batch_size))
    fake_labels = Variable(t.zeros(opt.batch_size))
    fix_noises = Variable(t.randn(opt.batch_size, opt.nz, 1, 1))  # fixed noise, used to validate NetG
    noises = Variable(t.randn(opt.batch_size, opt.nz, 1, 1))      # random noise, used to train and test NetG
    
    # step3: model
    netg, netd = NetG(opt), NetD(opt)
    map_location = lambda storage, loc: storage
    if opt.netg_path:
        netg.load(opt.netg_path)
    if opt.netd_path:
        netd.load(opt.netd_path)
        
    
    # step4: criterion and optimizer
    optimizer_g = t.optim.Adam(params=netg.parameters(), lr = opt.lrg, betas=(opt.beta1,0.999))
    optimizer_d = t.optim.Adam(params=netd.parameters(), lr = opt.lrd, betas=(opt.beta1,0.999))
    criterion = t.nn.BCELoss()
    
    # step: meters
    errord_meter = meter.AverageValueMeter()
    errorg_meter = meter.AverageValueMeter()
    
    if opt.use_gpu:
        netd.cuda()
        netg.cuda()
        criterion.cuda()
        true_labels, fake_labels = true_labels.cuda(), fake_labels.cuda()
        fix_noises, noises = fix_noises.cuda(), noises.cuda()
    
    # step5: train
    for epoch in range(opt.max_epoch):
        ## step5.1 train
        for ii,(data, _) in tqdm(enumerate(dataloader)):
            real_img = Variable(data)
            if opt.use_gpu:
                real_img = real_img.cuda()
            if (ii+1) % opt.d_every ==0:
                # discriminator
                optimizer_d.zero_grad()
                # real images
                output = netd(real_img)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward()
                # fake images
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                # (Figured out in chapter 8 why detach() is added here: it is not to keep backprop
                # from reaching netg, since the optimizer's parameter list already guarantees that,
                # but so that no gradient is computed for fake_img, which is not needed. Omitting it
                # raises no error, yet detaching saves memory by cutting backprop early; it plays the
                # same role as requires_grad=False, which only applies to leaf tensors, while detach()
                # is the way to stop gradients at a non-leaf tensor.)
                fake_img =  netg(noises).detach() 
                fake_output = netd(fake_img)
                error_d_fake = criterion(fake_output, fake_labels)
                error_d_fake.backward()
                optimizer_d.step()
                
                error_d = error_d_real+error_d_fake
                errord_meter.add(error_d.data)
                
            if (ii+1) % opt.g_every == 0:
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                output = netd(fake_img)
                error_g = criterion(output, true_labels)
                error_g.backward()
                optimizer_g.step()
                
                errorg_meter.add(error_g.data)
                
            ## step5.2 validate and visualize on batch_size  
            # note that the losses are plotted every few batches, not once per epoch
            if (ii+1) % opt.print_freq == 0 and opt.vis:
                if os.path.exists(opt.debug_file):
                    import ipdb
                    ipdb.set_trace()
                fix_fake_imgs = netg(fix_noises)  # batch_size*nz*1*1 --> batch_size(256)*3*96*96; effectively a validation pass
                vis.img('fix_fake',fix_fake_imgs.data[:64]*0.5+0.5)
                vis.img('real', real_img.data[:64]*0.5+0.5)
                vis.plot(win = 'errord',y= errord_meter.value()[0])
                vis.plot(win = 'errorg',y= errorg_meter.value()[0])
                
            
        ## step5.3 validate and save the model per epoch
        # the model is saved every few epochs;
        # strictly speaking it should also be validated every epoch or every few epochs, which differs
        # a little from this chapter's setup, but it hardly matters since there is no validation metric here
        if (epoch+1)%opt.save_freq == 0:
            netg.save(opt.model_save_path,'netg_%s' %epoch)
            netd.save(opt.model_save_path,'netd_%s' %epoch)
            fix_fake_imgs = val(netg,fix_noises)
            tv.utils.save_image(fix_fake_imgs,'%s/%s.png' % (opt.img_save_path, epoch),normalize=True, range=(-1,1))
            # after checking with the author: the dataset is small, so the meters are reset only every
            # few epochs instead of every epoch; worth trying a per-epoch reset to see how the error changes
            errord_meter.reset()
            errorg_meter.reset()
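
The long comment above distinguishes requires_grad (settable only on leaf tensors) from detach() (the way to stop gradients at a non-leaf tensor). A tiny demonstration of that distinction:

import torch

w = torch.randn(3, requires_grad=True)  # leaf tensor: requires_grad can be set directly
y = w * 2                               # non-leaf: produced by an operation

# detach() returns a view that shares storage but is cut out of the autograd graph
y_stopped = y.detach()
print(w.is_leaf, y.is_leaf, y_stopped.requires_grad)  # True False False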
            """
Example No. 10
def train(**kwargs):
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(opt.image_size),
        torchvision.transforms.CenterCrop(opt.image_size),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.5, 0.5, 0.5),
                                         std=(0.5, 0.5, 0.5))
    ])

    dataset = torchvision.datasets.ImageFolder(opt.data_path,
                                               transform=transforms)

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=opt.batch_size,
                                             shuffle=True,
                                             num_workers=opt.num_workers,
                                             drop_last=True)

    # 1. define the networks
    D = NetD(opt)
    G = NetG(opt)

    map_location = lambda storage, loc: storage
    if opt.netd_path:
        D.load_state_dict(torch.load(opt.netd_path, map_location=map_location))
    if opt.netg_path:
        G.load_state_dict(torch.load(opt.netg_path, map_location=map_location))

    # 2. define optimizers and the loss
    d_optim = torch.optim.Adam(D.parameters(),
                               opt.d_learning_rate,
                               betas=(opt.optim_beta1, 0.999))
    g_optim = torch.optim.Adam(G.parameters(),
                               opt.g_learning_rate,
                               betas=(opt.optim_beta1, 0.999))
    criterion = torch.nn.BCELoss()

    # real images are labeled 1, fake images 0
    real_labels = Variable(torch.ones(opt.batch_size))
    fake_labels = Variable(torch.zeros(opt.batch_size))

    if torch.cuda.is_available():
        D.cuda()
        G.cuda()
        criterion.cuda()
        real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

    # 3. train and visualize the process
    for epoch in range(opt.num_epochs):
        for step, (images, _) in tqdm.tqdm(enumerate(dataloader)):

            if step % opt.d_every == 0:
                # 1. train the discriminator
                d_optim.zero_grad()

                ## try to classify real images as real
                d_real_data = Variable(images)
                d_real_data = d_real_data.cuda() if torch.cuda.is_available() else d_real_data
                d_real_decision = D(d_real_data)
                d_real_error = criterion(d_real_decision, real_labels)
                d_real_error.backward()

                ## try to classify fake images as fake
                d_gen_input = Variable(torch.randn(opt.batch_size, opt.nz, 1, 1))
                d_gen_input = d_gen_input.cuda() if torch.cuda.is_available() else d_gen_input
                d_fake_data = G(d_gen_input).detach()
                d_fake_decision = D(d_fake_data)
                d_fake_error = criterion(d_fake_decision, fake_labels)
                d_fake_error.backward()
                d_optim.step()  # only optimizes D's parameters, using the gradients stored by backward()

            if step % opt.g_every == 0:
                # 2. train the generator
                g_optim.zero_grad()

                ## try to make the discriminator classify fake images as real
                g_gen_input = Variable(torch.randn(opt.batch_size, opt.nz, 1, 1))
                g_gen_input = g_gen_input.cuda() if torch.cuda.is_available() else g_gen_input
                g_fake_data = G(g_gen_input)
                g_fake_decision = D(g_fake_data)
                g_fake_error = criterion(g_fake_decision, real_labels)
                g_fake_error.backward()

                g_optim.step()

        if step % opt.epoch_every == 0:
            print("%s, %s, D: %s/%s G: %s" %
                  (step, g_fake_decision.cpu().data.numpy().mean(),
                   d_real_error.item(), d_fake_error.item(),
                   g_fake_error.item()))

            # save the model and images
            torchvision.utils.save_image(g_fake_data.data[:36],
                                         '%s/%s.png' %
                                         (opt.save_img_path, epoch),
                                         normalize=True,
                                         range=(-1, 1))
            torch.save(D.state_dict(),
                       '%s/netd_%s.pth' % (opt.checkpoints_path, epoch))
            torch.save(G.state_dict(),
                       '%s/netg_%s.pth' % (opt.checkpoints_path, epoch))
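
Example 10 repeats the ".cuda() if torch.cuda.is_available() else ..." pattern per tensor and still wraps tensors in Variable; since PyTorch 0.4 the usual idiom is to pick a device once and move tensors with .to(device). A small sketch of that style:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

real_labels = torch.ones(16, device=device)    # illustrative batch size
noise = torch.randn(16, 100, 1, 1).to(device)  # illustrative nz
print(real_labels.device, noise.device)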
Example No. 11
class Opt():
    noise_size = NOISE_SIZE
    ndf = 64
    ngf = 64


opt = Opt()

noise_data = NoiseData(NOISE_SIZE, BATCH_SIZE)
noise_dataloader = DataLoader(noise_data, batch_size=BATCH_SIZE)
noise_iter = iter(noise_dataloader)

feature_map = next(noise_iter)

net_G = NetG(opt)
net_D = NetD(opt)

criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(net_D.parameters(), lr=0.1)


# ============================= #
# Testing
# ============================= #
def test_networks():
    generated = net_G(feature_map)
    assert generated.shape == torch.Tensor(BATCH_SIZE, 3, 96, 96).shape
    assert torch.max(generated) <= 1
    assert torch.min(generated) >= 0

    res_generated = net_D(generated)
Example No. 12
transform = transforms.Compose([
    transforms.Resize(opt.image_size),  # transforms.Scale was renamed to Resize in newer torchvision
    transforms.CenterCrop(opt.image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

dataset = ImageFolder(opt.data_path, transform=transform)

dataloader = DataLoader(dataset,
                        batch_size=opt.batch_size,
                        shuffle=True,
                        num_workers=opt.num_workers,
                        drop_last=True)

netd = NetD(opt)
netg = NetG(opt)

if opt.netd_path:
    netd.load_state_dict(
        torch.load(opt.netd_path, map_location=lambda storage, loc: storage))
if opt.netg_path:
    netg.load_state_dict(
        torch.load(opt.netg_path, map_location=lambda storage, loc: storage))

optimizer_g = Adam(netg.parameters(), opt.lr1, betas=(opt.beta1, 0.999))
optimizer_d = Adam(netd.parameters(), opt.lr2, betas=(opt.beta1, 0.999))

criterion = nn.BCELoss()

true_labels = Variable(torch.ones(opt.batch_size))
Example No. 13
                               replacement=False)
    ulnodeB = np.setdiff1d(np.array([i for i in range(len(labelsB))]),
                           lnodeB.numpy())
    np.random.shuffle(ulnodeB)
    valnodeB = ulnodeB[0:int(0.2 * nB)]
    ulnodeB = ulnodeB[int(0.2 * nB) + 1:-1]

np.random.seed(args.seed)
random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.enabled = False

# Model and optimizer
mlp = MLP(in_features=args.hidden, nclass=labelsA.max().item() + 1)
Dnet = NetD(nhid=args.hidden)
model = GCN(nfeat=featuresA.shape[1],
            nhid=args.hidden,
            nclass=labelsA.max().item() + 1,
            dropout=args.dropout)
#model.load_state_dict(torch.load('init2.pkl'))
#for item in model.parameters():
#    print(item)
optimizer = optim.Adam(model.parameters(),
                       lr=args.lr,
                       weight_decay=args.weight_decay)
optimizer_mlp = optim.Adam(mlp.parameters(),
                           lr=0.01,
                           weight_decay=args.weight_decay)
dis_optimizer = optim.SGD(Dnet.parameters(),
                          lr=args.lr,