예제 #1
0
def generate(**kwargs):
    '''
    Randomly generate cartoon images and keep the highest-scoring ones.

    Keyword arguments are copied onto the global ``opt`` object before
    anything else runs. ``opt.gen_search_num`` noise vectors are pushed
    through the generator, the discriminator scores every generated image,
    and the ``opt.gen_num`` images with the highest scores are written as a
    single image file to ``opt.gen_img``.
    '''
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    netg, netd = NetG(opt).eval(), NetD(opt).eval()
    noises = Variable(
        t.randn(opt.gen_search_num, opt.nz, 1, 1).normal_(opt.gen_mean, opt.gen_std))

    # Returning the storage unchanged keeps the loaded weights on the CPU,
    # so a checkpoint saved from a GPU run can be opened on a CPU-only box.
    map_location = lambda storage, loc: storage
    netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))

    if opt.use_gpu:
        # .cuda() on a tensor/Variable returns a copy instead of moving it in
        # place, so the result has to be rebound (modules do move in place).
        noises = noises.cuda()
        netd.cuda()
        netg.cuda()

    fake_img = netg(noises)
    scores = netd(fake_img).data

    # topk returns (values, indices); only the indices are needed.
    indexs = scores.topk(opt.gen_num)[1]
    result = [fake_img.data[ii] for ii in indexs]

    tv.utils.save_image(t.stack(result), opt.gen_img, normalize=True, range=(-1, 1))
예제 #2
0
def generate(**kwargs):
    '''
    Randomly generate anime avatars and keep the ones netd scores best.

    Keyword arguments override the matching attributes of the global ``opt``.
    The selected images are written to ``opt.gen_img``.
    '''
    for name, value in kwargs.items():
        setattr(opt, name, value)

    # Build both networks in inference mode (generator first).
    netg = NetG(opt).eval()
    netd = NetD(opt).eval()

    noise_shape = (opt.gen_search_num, opt.nz, 1, 1)
    noises = Variable(
        t.randn(*noise_shape).normal_(opt.gen_mean, opt.gen_std),
        volatile=True)

    # Identity map_location: keep every loaded storage where it was
    # deserialised (the CPU).
    map_location = lambda storage, loc: storage
    netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))

    if opt.gpu:
        netd.cuda()
        netg.cuda()
        noises = noises.cuda()

    # Generate the images and score each one with the discriminator.
    fake_img = netg(noises)
    scores = netd(fake_img).data

    # Pick the indices of the best-scoring images.
    best = scores.topk(opt.gen_num)[1]
    chosen = [fake_img.data[idx] for idx in best]

    # Write the selected images to disk.
    tv.utils.save_image(t.stack(chosen), opt.gen_img,
                        normalize=True, range=(-1, 1))
예제 #3
0
파일: main.py 프로젝트: YohLee/pytorch-book
def generate(**kwargs):
    '''
    Sample random anime avatars and save those the discriminator rates highest.

    Every keyword argument is set as an attribute on the global ``opt`` first.
    '''
    for option, setting in kwargs.items():
        setattr(opt, option, setting)

    def _keep_storage(storage, loc):
        # map_location hook: hand the storage back untouched (stays on CPU).
        return storage

    netg, netd = NetG(opt).eval(), NetD(opt).eval()
    raw_noise = t.randn(opt.gen_search_num, opt.nz, 1, 1)
    raw_noise = raw_noise.normal_(opt.gen_mean, opt.gen_std)
    noises = Variable(raw_noise, volatile=True)

    netd.load_state_dict(t.load(opt.netd_path, map_location=_keep_storage))
    netg.load_state_dict(t.load(opt.netg_path, map_location=_keep_storage))

    if opt.gpu:
        netd.cuda()
        netg.cuda()
        noises = noises.cuda()

    # Generate images, then compute their discriminator scores.
    fake_img = netg(noises)
    scores = netd(fake_img).data

    # Collect the top-scoring images.
    top_indices = scores.topk(opt.gen_num)[1]
    picked = []
    for idx in top_indices:
        picked.append(fake_img.data[idx])

    # Save the images.
    grid = t.stack(picked)
    tv.utils.save_image(grid, opt.gen_img, normalize=True, range=(-1, 1))
예제 #4
0
def generate(**kwargs):
    """
    Randomly generate anime avatars and keep the better ones by netd score.

    Keyword arguments are applied to the global ``opt`` before generation.
    """
    for attr, val in kwargs.items():
        setattr(opt, attr, val)

    # Put both networks into evaluation mode.
    netg = NetG(opt).eval()
    netd = NetD(opt).eval()

    # Draw opt.gen_search_num noise vectors, one per candidate image.
    noises = Variable(
        t.randn(opt.gen_search_num, opt.nz, 1, 1).normal_(
            opt.gen_mean, opt.gen_std),
        volatile=True,
    )

    # Load the saved parameters onto the CPU.
    map_location = lambda storage, loc: storage
    d_state = t.load(opt.netd_path, map_location=map_location)
    netd.load_state_dict(d_state)
    g_state = t.load(opt.netg_path, map_location=map_location)
    netg.load_state_dict(g_state)

    # Move the models and the input noise to the GPU when requested.
    if opt.gpu:
        netd.cuda()
        netg.cuda()
        noises = noises.cuda()

    # Generate the images and compute their discriminator scores.
    fake_img = netg(noises)
    scores = netd(fake_img).data

    # Indices of the opt.gen_num best-scoring candidates.
    _, indexs = scores.topk(opt.gen_num)
    result = [fake_img.data[pos] for pos in indexs]

    # Save the images.
    tv.utils.save_image(t.stack(result), opt.gen_img,
                        normalize=True, range=(-1, 1))
예제 #5
0
                                  betas=(0.0, 0.9))
    optimizerD_enc = torch.optim.Adam(netD.feature_encoder.parameters(),
                                      lr=0.0004,
                                      betas=(0.0, 0.9))
    optimizerD_proj = torch.optim.Adam(netD.COND_DNET.parameters(),
                                       lr=0.004,
                                       betas=(0.0, 0.9))

    if state_epoch != 0:
        netG.load_state_dict(
            torch.load('%s/models/netG_%03d.pth' % (output_dir, state_epoch),
                       map_location='cpu'))
        netD.load_state_dict(
            torch.load('%s/models/netD_%03d.pth' % (output_dir, state_epoch),
                       map_location='cpu'))
        netG = netG.cuda()
        netD = netD.cuda()
        optimizerG.load_state_dict(
            torch.load('%s/models/optimizerG.pth' % (output_dir)))
        optimizerD.load_state_dict(
            torch.load('%s/models/optimizerD.pth' % (output_dir)))

    netG.cuda()
    netD.cuda()

    if cfg.B_VALIDATION:
        sampling(text_encoder, netG, dataloader,
                 device)  # generate images for the whole valid dataset
        logger.info('state_epoch:  %d' % (state_epoch))
    else:
        train(dataloader, netG, netD, text_encoder, optimizerG, optimizerD_enc,
예제 #6
0
def train(**kwargs):
    """Train the generator/discriminator pair configured by the global ``opt``.

    Keyword arguments are copied onto ``opt`` before anything else runs.
    Images are read from ``opt.data_path``; every ``opt.decay_every`` epochs a
    sample image is written under ``opt.save_path`` and both networks are
    saved under ``checkpoints/``.
    """
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)
    if opt.vis:
        # Imported lazily so the visualizer is only required when opt.vis is set.
        from visualize import Visualizer
        vis = Visualizer(opt.env)

    # Resize/crop to opt.image_size, then normalise each channel with
    # mean 0.5 and std 0.5.
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = tv.datasets.ImageFolder(opt.data_path, transform=transforms)
    dataloader = t.utils.data.DataLoader(dataset,
                                         batch_size=opt.batch_size,
                                         shuffle=True,
                                         num_workers=opt.num_workers,
                                         drop_last=True)

    # Define the networks; optionally resume from checkpoints. The identity
    # map_location returns each storage unchanged when loading.
    netg, netd = NetG(opt), NetD(opt)
    map_location = lambda storage, loc: storage
    if opt.netd_path:
        netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))

    # Define the optimizers and the loss.
    optimizer_g = t.optim.Adam(netg.parameters(),
                               opt.lr1,
                               betas=(opt.beta1, 0.999))
    optimizer_d = t.optim.Adam(netd.parameters(),
                               opt.lr2,
                               betas=(opt.beta1, 0.999))
    criterion = t.nn.BCELoss()  # binary cross-entropy

    # Real images are labelled 1, fake images 0.
    # noises is the generator's input; fix_noises is drawn once and reused.
    true_labels = Variable(t.ones(opt.batch_size))
    fake_labels = Variable(t.zeros(opt.batch_size))
    fix_noises = Variable(t.randn(opt.batch_size, opt.nz, 1, 1))
    noises = Variable(t.randn(opt.batch_size, opt.nz, 1, 1))

    errord_meter = AverageValueMeter()
    errorg_meter = AverageValueMeter()

    if opt.gpu:
        netd.cuda()
        netg.cuda()
        criterion.cuda()
        true_labels, fake_labels = true_labels.cuda(), fake_labels.cuda()
        fix_noises, noises = fix_noises.cuda(), noises.cuda()

    epochs = range(opt.max_epoch)
    for epoch in iter(epochs):
        for ii, (img, _) in tqdm.tqdm(enumerate(dataloader)):
            real_img = Variable(img)
            if opt.gpu:
                real_img = real_img.cuda()
            if ii % opt.d_every == 0:
                # Train the discriminator.
                optimizer_d.zero_grad()
                ## Push the discriminator to classify real images as real.
                output = netd(real_img)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward()

                ## Push the discriminator to classify fake images as fake.
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                # Generate fakes from noise; detach() keeps this backward pass
                # from reaching the generator.
                fake_img = netg(noises).detach()
                output = netd(fake_img)
                error_d_fake = criterion(output, fake_labels)
                error_d_fake.backward()
                optimizer_d.step()

                error_d = error_d_fake + error_d_real

                errord_meter.add(error_d.data[0])

            if ii % opt.g_every == 0:
                # Train the generator: its fakes are scored against the
                # "real" labels.
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                output = netd(fake_img)
                error_g = criterion(output, true_labels)
                error_g.backward()
                optimizer_g.step()
                errorg_meter.add(error_g.data[0])

            if opt.vis and ii % opt.plot_every == opt.plot_every - 1:
                ## Visualisation.
                # Creating the file named by opt.debug_file drops into ipdb.
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()
                fix_fake_imgs = netg(fix_noises)
                # "* 0.5 + 0.5" undoes the (0.5, 0.5) normalisation.
                vis.images(fix_fake_imgs.data.cpu().numpy()[:64] * 0.5 + 0.5,
                           win='fixfake')
                vis.images(real_img.data.cpu().numpy()[:64] * 0.5 + 0.5,
                           win='real')
                vis.plot('errord', errord_meter.value()[0])
                vis.plot('errorg', errorg_meter.value()[0])

            # Regenerated on every batch, so fix_fake_imgs is bound for the
            # epoch-end save below even when opt.vis is off.
            fix_fake_imgs = netg(fix_noises)

        if epoch % opt.decay_every == 0:
            # Save sample images.
            tv.utils.save_image(fix_fake_imgs.data[:64],
                                '%s/%s.png' % (opt.save_path, epoch),
                                normalize=True,
                                range=(-1, 1))
            # Save the models.
            t.save(netd.state_dict(), 'checkpoints/netd_%s.pth' % epoch)
            t.save(netg.state_dict(), 'checkpoints/netg_%s.pth' % epoch)
            errord_meter.reset()
            errorg_meter.reset()
            # Fresh optimizers are created here, which discards the previous
            # Adam state.
            optimizer_g = t.optim.Adam(netg.parameters(),
                                       opt.lr1,
                                       betas=(opt.beta1, 0.999))
            optimizer_d = t.optim.Adam(netd.parameters(),
                                       opt.lr2,
                                       betas=(opt.beta1, 0.999))
예제 #7
0
def main(args):
    """Generate images and fusion masks for one hard-coded caption.

    Loads the text encoder and generator checkpoints, encodes the caption,
    samples 100 noise vectors and writes the caption image, the generated
    images and seven masks per image under
    ``result/mani/cap_10000_0_coco_ch64``.

    Args:
        args: object providing ``json_file``, ``rnn_encoder``, ``model_path``
            and ``use_gpu``.
    """
    # manualSeed to control the noise
    manualSeed = 100
    random.seed(manualSeed)
    np.random.seed(manualSeed)
    torch.manual_seed(manualSeed)

    with open(args.json_file, 'r') as f:
        dataset_json = json.load(f)

    # load rnn encoder
    text_encoder = RNN_ENCODER(dataset_json['n_words'], nhidden=dataset_json['text_embed_dim'])
    text_encoder_dir = args.rnn_encoder
    state_dict = torch.load(text_encoder_dir, map_location=lambda storage, loc: storage)
    text_encoder.load_state_dict(state_dict)

    # load netG
    state_dict = torch.load(args.model_path, map_location=torch.device('cpu'))
    # netG = NetG(int(dataset_json['n_channels']), int(dataset_json['cond_dim']))
    netG = NetG(64, int(dataset_json['cond_dim']))
    new_state_dict = OrderedDict()
    prefix = 'module.'
    for k, v in state_dict.items():
        # Strip the `module.` prefix only when it is present; slicing every
        # key blindly mangles the names of an unwrapped checkpoint, and the
        # filter below would then silently load nothing.
        name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = v
    model_dict = netG.state_dict()
    pretrained_dict = {k: v for k, v in new_state_dict.items() if k in model_dict}
    model_dict.update(pretrained_dict)
    netG.load_state_dict(model_dict)

    # use gpu or not, change model to evaluation mode.
    # Only the modules are moved here: the caption and noise tensors do not
    # exist yet and are moved to `device` right after they are created.
    device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
    text_encoder.to(device)
    netG.to(device)

    text_encoder.eval()
    netG.eval()

    # generate noise
    num_noise = 100
    noise = torch.FloatTensor(num_noise, 100)

    # cub bird captions
    # caption = 'this small bird has a light yellow breast and brown wings'
    # caption = 'this small bird has a short beak a light gray breast a darker gray crown and black wing tips'
    # caption = 'this small bird has wings that are gray and has a white belly'
    # caption = 'this bird has a yellow throat belly abdomen and sides with lots of brown streaks on them'
    # caption = 'this little bird has a yellow belly and breast with a gray wing with white wingbars'
    # caption = 'this bird has a white belly and breast wit ha blue crown and nape'
    # caption = 'a bird with brown and black wings red crown and throat and the bill is short and pointed'
    # caption = 'this small bird has a yellow crown and a white belly'
    # caption = 'this bird has a blue crown with white throat and brown secondaries'
    # caption = 'this bird has wings that are black and has a white belly'
    # caption = 'a yellow bird has wings with dark stripes and small eyes'
    # caption = 'a black bird has wings with dark stripes and small eyes'
    # caption = 'a red bird has wings with dark stripes and small eyes'
    # caption = 'a white bird has wings with dark stripes and small eyes'
    # caption = 'a blue bird has wings with dark stripes and small eyes'
    # caption = 'a pink bird has wings with dark stripes and small eyes'
    # caption = 'this is a white and grey bird with black wings and a black stripe by its eyes'
    # caption = 'a small bird with an orange bill and grey crown and breast'
    # caption = 'a small bird with black gray and white wingbars'
    # caption = 'this bird is white and light orange in color with a black beak'
    # caption = 'a small sized bird that has tones of brown and a short pointed bill' # beak?

    # MS coco captions
    # caption = 'two men skiing down a snow covered mountain in the evening'
    # caption = 'a man walking down a grass covered mountain'
    # caption = 'a close up of a boat on a field under a sunset'
    # caption = 'a close up of a boat on a field with a clear sky'
    # caption = 'a herd of black and white cattle standing on a field'
    # caption = 'a herd of black and white sheep standing on a field'
    # caption = 'a herd of black and white dogs standing on a field'
    # caption = 'a herd of brown cattle standing on a field'
    # caption = 'a herd of black and white cattle standing in a river'
    # caption = 'some horses in a field of green grass with a sky in the background'
    # caption = 'some horses in a field of yellow grass with a sky in the background'
    caption = 'some horses in a field of green grass with a sunset in the background'

    # convert caption to index
    caption_idx, caption_len = get_caption_idx(dataset_json, caption)
    caption_idx = torch.LongTensor(caption_idx)
    caption_len = torch.LongTensor([caption_len])
    caption_idx = caption_idx.view(1, -1)
    caption_len = caption_len.view(-1)

    # use rnn encoder to get caption embedding; the encoder receives copies on
    # `device`, while the CPU originals are kept for cap2img below
    hidden = text_encoder.init_hidden(1)
    words_embs, sent_emb = text_encoder(caption_idx.to(device),
                                        caption_len.to(device), hidden)

    # generate fake image
    noise.data.normal_(0, 1)
    # Tensor.to returns a new tensor when the device differs, so rebind it.
    noise = noise.to(device)
    sent_emb = sent_emb.repeat(num_noise, 1)
    words_embs = words_embs.repeat(num_noise, 1, 1)
    with torch.no_grad():
        fake_imgs, fusion_mask = netG(noise, sent_emb)

        # create path to save image, caption and mask
        cap_number = 10000
        main_path = 'result/mani/cap_%s_0_coco_ch64' % (str(cap_number))
        img_save_path = '%s/image' % main_path
        mask_save_path = '%s/mask_' % main_path
        mkdir_p(img_save_path)
        for i in range(7):
            mkdir_p(mask_save_path + str(i))

        # save caption as image
        ixtoword = {v: k for k, v in dataset_json['word2idx'].items()}
        cap_img = cap2img(ixtoword, caption_idx, caption_len)
        im = cap_img[0].data.cpu().numpy()
        # map values from [-1, 1] to [0, 255]
        im = (im + 1.0) * 127.5
        im = im.astype(np.uint8)
        im = np.transpose(im, (1, 2, 0))
        im = Image.fromarray(im)
        full_path = '%s/caption.png' % main_path
        im.save(full_path)

        # save generated images and masks
        for i in tqdm(range(num_noise)):
            full_path = '%s/image_%d.png' % (img_save_path, i)
            im = fake_imgs[i].data.cpu().numpy()
            im = (im + 1.0) * 127.5
            im = im.astype(np.uint8)
            im = np.transpose(im, (1, 2, 0))
            im = Image.fromarray(im)
            im.save(full_path)

            for j in range(7):
                full_path = '%s%1d/mask_%d.png' % (mask_save_path, j, i)
                im = fusion_mask[j][i][0].data.cpu().numpy()
                im = im * 255
                im = im.astype(np.uint8)
                im = Image.fromarray(im)
                im.save(full_path)
예제 #8
0
파일: main.py 프로젝트: YohLee/pytorch-book
def train(**kwargs):
    """Train the generator/discriminator pair configured by the global ``opt``.

    Keyword arguments are copied onto ``opt`` first. Every
    ``opt.decay_every`` epochs a sample image is written under
    ``opt.save_path`` and both networks are saved under ``checkpoints/``.
    """
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)
    if opt.vis:
        from visualize import Visualizer
        vis = Visualizer(opt.env)

    # NOTE(review): newer torchvision releases deprecate Scale in favour of
    # Resize; left unchanged to match the library version this file targets.
    transforms = tv.transforms.Compose([
        tv.transforms.Scale(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = tv.datasets.ImageFolder(opt.data_path, transform=transforms)
    dataloader = t.utils.data.DataLoader(dataset,
                                         batch_size=opt.batch_size,
                                         shuffle=True,
                                         num_workers=opt.num_workers,
                                         drop_last=True)

    # Define the networks; optionally resume from checkpoints.
    netg, netd = NetG(opt), NetD(opt)
    map_location = lambda storage, loc: storage
    if opt.netd_path:
        netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))

    # Define the optimizers and the loss.
    optimizer_g = t.optim.Adam(netg.parameters(), opt.lr1, betas=(opt.beta1, 0.999))
    optimizer_d = t.optim.Adam(netd.parameters(), opt.lr2, betas=(opt.beta1, 0.999))
    criterion = t.nn.BCELoss()

    # Real images are labelled 1, fake images 0.
    # noises is the generator's input; fix_noises is drawn once and reused.
    true_labels = Variable(t.ones(opt.batch_size))
    fake_labels = Variable(t.zeros(opt.batch_size))
    fix_noises = Variable(t.randn(opt.batch_size, opt.nz, 1, 1))
    noises = Variable(t.randn(opt.batch_size, opt.nz, 1, 1))

    errord_meter = AverageValueMeter()
    errorg_meter = AverageValueMeter()

    if opt.gpu:
        netd.cuda()
        netg.cuda()
        criterion.cuda()
        true_labels, fake_labels = true_labels.cuda(), fake_labels.cuda()
        fix_noises, noises = fix_noises.cuda(), noises.cuda()

    epochs = range(opt.max_epoch)
    for epoch in iter(epochs):
        for ii, (img, _) in tqdm.tqdm(enumerate(dataloader)):
            real_img = Variable(img)
            if opt.gpu:
                real_img = real_img.cuda()
            if ii % opt.d_every == 0:
                # Train the discriminator.
                optimizer_d.zero_grad()
                ## Push the discriminator to classify real images as real.
                output = netd(real_img)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward()

                ## Push the discriminator to classify fake images as fake.
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                # detach() keeps this backward pass from reaching the generator.
                fake_img = netg(noises).detach()
                output = netd(fake_img)
                error_d_fake = criterion(output, fake_labels)
                error_d_fake.backward()
                optimizer_d.step()

                error_d = error_d_fake + error_d_real

                errord_meter.add(error_d.data[0])

            if ii % opt.g_every == 0:
                # Train the generator: its fakes are scored against the
                # "real" labels.
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                output = netd(fake_img)
                error_g = criterion(output, true_labels)
                error_g.backward()
                optimizer_g.step()
                errorg_meter.add(error_g.data[0])

            if opt.vis and ii % opt.plot_every == opt.plot_every - 1:
                ## Visualisation.
                # Creating the file named by opt.debug_file drops into ipdb.
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()
                fix_fake_imgs = netg(fix_noises)
                vis.images(fix_fake_imgs.data.cpu().numpy()[:64] * 0.5 + 0.5, win='fixfake')
                vis.images(real_img.data.cpu().numpy()[:64] * 0.5 + 0.5, win='real')
                vis.plot('errord', errord_meter.value()[0])
                vis.plot('errorg', errorg_meter.value()[0])

        if epoch % opt.decay_every == 0:
            # Save the models and a sample image.
            # The samples are regenerated here: fix_fake_imgs is otherwise only
            # bound inside the opt.vis branch above, so saving raised NameError
            # whenever visualisation was off or had not fired yet.
            fix_fake_imgs = netg(fix_noises)
            tv.utils.save_image(fix_fake_imgs.data[:64],
                                '%s/%s.png' % (opt.save_path, epoch),
                                normalize=True,
                                range=(-1, 1))
            t.save(netd.state_dict(), 'checkpoints/netd_%s.pth' % epoch)
            t.save(netg.state_dict(), 'checkpoints/netg_%s.pth' % epoch)
            errord_meter.reset()
            errorg_meter.reset()
            # Fresh optimizers discard the previous Adam state.
            optimizer_g = t.optim.Adam(netg.parameters(), opt.lr1, betas=(opt.beta1, 0.999))
            optimizer_d = t.optim.Adam(netd.parameters(), opt.lr2, betas=(opt.beta1, 0.999))
예제 #9
0
def train():
    """Train the generator/discriminator pair configured by the global ``opt``.

    Runs on CUDA when it is available, otherwise on the CPU. Every
    ``opt.save_every`` epochs a sample image is written under
    ``opt.save_path`` and both networks are saved under ``checkpoints/``.
    """
    # change opt
    # for k_, v_ in kwargs.items():
    #     setattr(opt, k_, v_)

    # is_available must be *called*: the bare function object is always
    # truthy, which selected 'cuda' even on machines without a GPU.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device(
        'cpu')

    if opt.vis:
        from visualizer import Visualizer
        vis = Visualizer(opt.env)

    # rescale to -1~1
    transform = transforms.Compose([
        transforms.Resize(opt.image_size),
        transforms.CenterCrop(opt.image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    dataset = datasets.ImageFolder(opt.data_path, transform=transform)

    dataloader = DataLoader(dataset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            drop_last=True)

    netd = NetD(opt)
    netg = NetG(opt)
    # map_location is an argument of torch.load, not of load_state_dict;
    # passing it to load_state_dict raised TypeError.
    map_location = lambda storage, loc: storage
    if opt.netd_path:
        netd.load_state_dict(
            torch.load(opt.netd_path, map_location=map_location))
    if opt.netg_path:
        netg.load_state_dict(
            torch.load(opt.netg_path, map_location=map_location))

    # Move the networks before building the optimizers.
    netd.to(device)
    netg.to(device)

    # Define the optimizers and the loss.
    optimizer_g = torch.optim.Adam(netg.parameters(),
                                   opt.lr1,
                                   betas=(opt.beta1, 0.999))
    optimizer_d = torch.optim.Adam(netd.parameters(),
                                   opt.lr2,
                                   betas=(opt.beta1, 0.999))

    criterion = torch.nn.BCELoss().to(device)

    # Real label is 1, fake label is 0; noises is the generator input.
    true_labels = Variable(torch.ones(opt.batch_size)).to(device)
    fake_labels = Variable(torch.zeros(opt.batch_size)).to(device)

    fix_noises = Variable(torch.randn(opt.batch_size, opt.nz, 1, 1)).to(device)
    noises = Variable(torch.randn(opt.batch_size, opt.nz, 1, 1)).to(device)

    errord_meter = AverageValueMeter()
    errorg_meter = AverageValueMeter()

    for epoch in range(opt.max_epoch):
        print("epoch:", epoch, end='\r')
        # sys.stdout.flush()
        for ii, (img, _) in enumerate(dataloader):
            real_img = Variable(img).to(device)

            # Train the discriminator: real -> 1, fake -> 0
            if (ii + 1) % opt.d_every == 0:
                # real
                optimizer_d.zero_grad()
                output = netd(real_img)
                # print(output.shape, true_labels.shape)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward()
                # fake
                noises.data.copy_(torch.randn(opt.batch_size, opt.nz, 1, 1))
                # detach() keeps this backward pass from reaching the generator.
                fake_img = netg(noises).detach()
                fake_output = netd(fake_img)
                error_d_fake = criterion(fake_output, fake_labels)
                error_d_fake.backward()
                # update optimizer
                optimizer_d.step()

                error_d = error_d_fake + error_d_real

                errord_meter.add(error_d.item())

            # Train the generator so that the discriminator scores its
            # images as real.
            if (ii + 1) % opt.g_every == 0:
                optimizer_g.zero_grad()
                noises.data.copy_(torch.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                fake_output = netd(fake_img)
                error_g = criterion(fake_output, true_labels)
                error_g.backward()
                optimizer_g.step()

                errorg_meter.add(error_g.item())

            if opt.vis and ii % opt.plot_every == opt.plot_every - 1:
                # Visualisation.
                # if os.path.exists(opt.debug_file):
                #     import ipdb
                #     ipdb.set_trace()

                fix_fake_img = netg(fix_noises)
                vis.images(
                    fix_fake_img.detach().cpu().numpy()[:opt.batch_size] * 0.5
                    + 0.5,
                    win='fixfake')
                vis.images(real_img.data.cpu().numpy()[:opt.batch_size] * 0.5 +
                           0.5,
                           win='real')
                vis.plot('errord', errord_meter.value()[0])
                vis.plot('errorg', errorg_meter.value()[0])

        if (epoch + 1) % opt.save_every == 0:
            # Save the models and a sample image.
            # The samples are regenerated here: fix_fake_img is otherwise only
            # bound inside the opt.vis branch above, so saving raised NameError
            # whenever visualisation was off or had not fired yet.
            fix_fake_img = netg(fix_noises)
            tv.utils.save_image(fix_fake_img.data[:opt.batch_size],
                                '%s/%s.png' % (opt.save_path, epoch),
                                normalize=True,
                                range=(-1, 1))
            torch.save(netd.state_dict(), 'checkpoints/netd_%s.pth' % epoch)
            torch.save(netg.state_dict(), 'checkpoints/netg_%s.pth' % epoch)
            errord_meter.reset()
            errorg_meter.reset()
예제 #10
0
                                   betas=(opt.beta1, 0.999))
    optimizer_d = torch.optim.Adam(netd.parameters(),
                                   opt.lr_netd,
                                   betas=(opt.beta1, 0.999))

    criterion = torch.nn.BCELoss()

    # 真图片 label 为 1,假图片为 0
    true_labels = Variable(torch.ones(opt.batch_size))
    fake_labels = Variable(torch.zeros(opt.batch_size))

    #fix_noises = Variable(torch.randn(opt.batch_size, opt.nz, 1, 1))
    noises = Variable(torch.randn(opt.batch_size, opt.nz, 1, 1))

    if opt.gpu:
        netg.cuda()
        netd.cuda()
        criterion.cuda()
        true_labels = true_labels.cuda()
        fake_labels = fake_labels.cuda()
        #fix_noises = fix_noises.cuda()
        noises = noises.cuda()

    while True:
        action = input('Train(t) or Generate(g) or Quit(q)> ').lower()

        if action == 't':
            epochs = int(input('Epoch times > '))
            train(max_epoch=epochs)

        elif action == 'g':
예제 #11
0
def train(**kwargs):
    """Train the generator/discriminator pair configured by the global ``opt``.

    Keyword arguments are copied onto ``opt`` first. Every
    ``opt.decay_every`` epochs a sample image is written under
    ``opt.save_path`` and both networks are saved under ``checkpoints/``.
    """
    # Apply the keyword overrides.
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)
    # Visualisation.
    if opt.vis:
        from visualize import Visualizer
        vis = Visualizer(opt.env)
    # Image preprocessing.
    # NOTE(review): newer torchvision releases deprecate Scale in favour of
    # Resize; left unchanged to match the library version this file targets.
    transforms = tv.transforms.Compose([
        tv.transforms.Scale(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        # per-channel mean and std
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    # ImageFolder reads the images and applies the transforms above.
    dataset = tv.datasets.ImageFolder(opt.data_path, transform=transforms)
    # Data loader.
    dataloader = t.utils.data.DataLoader(dataset,
                                         batch_size=opt.batch_size,
                                         shuffle=True,
                                         num_workers=opt.num_workers,
                                         drop_last=True)

    # Define the networks.
    netg, netd = NetG(opt), NetD(opt)
    # Identity map_location: pretrained weights are loaded onto the CPU first.
    map_location = lambda storage, loc: storage
    if opt.netd_path:
        netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))

    # Define the optimizers and the loss.
    optimizer_g = t.optim.Adam(netg.parameters(),
                               opt.lr1,
                               betas=(opt.beta1, 0.999))
    optimizer_d = t.optim.Adam(netd.parameters(),
                               opt.lr2,
                               betas=(opt.beta1, 0.999))
    # Binary cross-entropy.
    criterion = t.nn.BCELoss()

    # Real images are labelled 1, fake images 0.
    # noises is the generator's input.
    true_labels = Variable(t.ones(opt.batch_size))
    fake_labels = Variable(t.zeros(opt.batch_size))
    # fix_noises stays constant so the samples can be compared across epochs.
    fix_noises = Variable(t.randn(opt.batch_size, opt.nz, 1, 1))
    noises = Variable(t.randn(opt.batch_size, opt.nz, 1, 1))
    # Running meters for the two losses, used for plotting.
    errord_meter = AverageValueMeter()
    errorg_meter = AverageValueMeter()

    if opt.gpu:
        # Move the networks to the GPU.
        netd.cuda()
        netg.cuda()
        # Move the loss to the GPU.
        criterion.cuda()
        # Move the labels to the GPU.
        true_labels, fake_labels = true_labels.cuda(), fake_labels.cuda()
        # Move the input noise to the GPU.
        fix_noises, noises = fix_noises.cuda(), noises.cuda()

    epochs = range(opt.max_epoch)
    for epoch in iter(epochs):

        for ii, (img, _) in tqdm.tqdm(enumerate(dataloader)):
            real_img = Variable(img)
            if opt.gpu:
                real_img = real_img.cuda()
            # Train the discriminator every d_every batches.
            if ii % opt.d_every == 0:
                optimizer_d.zero_grad()
                ## Push the discriminator to classify real images as real:
                ## score a batch of real images against label 1 and backprop.
                output = netd(real_img)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward()

                ## Push the discriminator to classify fake images as fake:
                ## score a batch of fakes against label 0 and backprop.
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                # detach() keeps this backward pass from reaching the generator.
                fake_img = netg(noises).detach()
                output = netd(fake_img)
                error_d_fake = criterion(output, fake_labels)
                error_d_fake.backward()
                # Update the discriminator parameters.
                optimizer_d.step()
                # Total loss = loss on real images + loss on fake images.
                error_d = error_d_fake + error_d_real
                # Record the total loss for plotting.
                errord_meter.add(error_d.data[0])
            # Train the generator every g_every batches.
            if ii % opt.g_every == 0:
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                # Generator: noise -> fake images.
                fake_img = netg(noises)
                # Discriminator: score the fake images.
                output = netd(fake_img)
                # Score the fakes against the "real" labels so the generator
                # learns to fool the discriminator.
                error_g = criterion(output, true_labels)
                error_g.backward()
                # Update the generator parameters.
                optimizer_g.step()
                # Record the loss for plotting.
                errorg_meter.add(error_g.data[0])

            if opt.vis and ii % opt.plot_every == opt.plot_every - 1:
                ## Visualisation.
                # Creating the file named by opt.debug_file drops into ipdb.
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()
                # Fake images from the fixed noise.
                fix_fake_imgs = netg(fix_noises)
                vis.images(fix_fake_imgs.data.cpu().numpy()[:64] * 0.5 + 0.5,
                           win='fixfake')
                # Real images for comparison.
                vis.images(real_img.data.cpu().numpy()[:64] * 0.5 + 0.5,
                           win='real')
                # Plot the discriminator and generator losses.
                vis.plot('errord', errord_meter.value()[0])
                vis.plot('errorg', errorg_meter.value()[0])
        # Save every decay_every epochs.
        if epoch % opt.decay_every == 0:
            # The samples are regenerated here: fix_fake_imgs is otherwise only
            # bound inside the opt.vis branch above, so saving raised NameError
            # whenever visualisation was off or had not fired yet.
            fix_fake_imgs = netg(fix_noises)
            tv.utils.save_image(fix_fake_imgs.data[:64],
                                '%s/%s.png' % (opt.save_path, epoch),
                                normalize=True,
                                range=(-1, 1))
            # Save the discriminator and the generator.
            t.save(netd.state_dict(), 'checkpoints/netd_%s.pth' % epoch)
            t.save(netg.state_dict(), 'checkpoints/netg_%s.pth' % epoch)
            # Clear the loss meters.
            errord_meter.reset()
            errorg_meter.reset()
            # Re-create the optimizers, resetting their state.
            optimizer_g = t.optim.Adam(netg.parameters(),
                                       opt.lr1,
                                       betas=(opt.beta1, 0.999))
            optimizer_d = t.optim.Adam(netd.parameters(),
                                       opt.lr2,
                                       betas=(opt.beta1, 0.999))
예제 #12
0
class TACGAN():
    def __init__(self, args):
        """Build the TAC-GAN networks, losses and optimizers from ``args``.

        Args:
            args: Options object. The attributes read here are ``lr``,
                ``use_cuda``, ``batch_size``, ``image_size``, ``epochs``,
                ``data_root``, ``dataset``, ``num_cls``, ``save_dir``,
                ``save_prefix``, ``continue_training``, ``netg_path``,
                ``netd_path``, ``save_after``, ``num_workers``, ``n_z``,
                ``nl_d``, ``nl_g``, ``nf_g`` and ``nf_d``.

        Side effects:
            Creates ``save_dir`` and its ``netd_checkpoints``,
            ``netg_checkpoints`` and ``generated_images`` sub-directories
            when they do not exist yet.
        """
        self.lr = args.lr
        # Keep the flag truthful: CUDA is only "on" when it was requested AND
        # a device is actually available. Otherwise the networks would stay on
        # the CPU while code that checks ``self.cuda`` alone moves its tensors
        # to the GPU.
        self.cuda = bool(args.use_cuda) and torch.cuda.is_available()
        self.batch_size = args.batch_size
        self.image_size = args.image_size
        self.epochs = args.epochs
        self.data_root = args.data_root
        self.dataset = args.dataset
        self.num_classes = args.num_cls
        self.save_dir = args.save_dir
        self.save_prefix = args.save_prefix
        self.continue_training = args.continue_training
        self.netG_path = args.netg_path
        self.netD_path = args.netd_path
        self.save_after = args.save_after
        # The loaders are not created here; they start out as None.
        self.trainset_loader = None
        self.evalset_loader = None
        self.num_workers = args.num_workers
        self.n_z = args.n_z  # length of the noise vector
        self.nl_d = args.nl_d
        self.nl_g = args.nl_g
        self.nf_g = args.nf_g
        self.nf_d = args.nf_d
        self.bce_loss = nn.BCELoss()
        self.nll_loss = nn.NLLLoss()
        self.netD = NetD(n_cls=self.num_classes, n_t=self.nl_d, n_f=self.nf_d)
        self.netG = NetG(n_z=self.n_z, n_l=self.nl_g, n_c=self.nf_g)

        # Move the networks and the loss modules to the GPU.
        if self.cuda:
            print('CUDA is enabled')
            self.netD = self.netD.cuda()
            self.netG = self.netG.cuda()
            self.bce_loss = self.bce_loss.cuda()
            self.nll_loss = self.nll_loss.cuda()

        # Optimizers for netD and netG.
        self.optimizerD = optim.Adam(params=self.netD.parameters(),
                                     lr=self.lr,
                                     betas=(0.5, 0.999))
        self.optimizerG = optim.Adam(params=self.netG.parameters(),
                                     lr=self.lr,
                                     betas=(0.5, 0.999))

        # Create the output directory tree for checkpoints and other results.
        required_dirs = [self.save_dir]
        for sub_dir in ('netd_checkpoints', 'netg_checkpoints',
                        'generated_images'):
            required_dirs.append(os.path.join(self.save_dir, sub_dir))
        for directory in required_dirs:
            if not os.path.exists(directory):
                os.makedirs(directory)

    # start training process
    def train(self):
        """Run the whole training procedure for ``self.epochs`` epochs.

        Logs the training parameters (stdout and ``training_log.txt`` in
        ``self.save_dir``), builds the training ``DataLoader``, optionally
        restores checkpoints, then for every epoch trains, redraws the loss
        graph and saves checkpoints.
        """
        # Write the parameter summary to the log file and print it.
        log_msg = '********************************************\n'
        log_msg += '            Training Parameters\n'
        log_msg += 'Dataset:%s\nImage size:%dx%d\n' % (
            self.dataset, self.image_size, self.image_size)
        log_msg += 'Batch size:%d\n' % (self.batch_size)
        log_msg += 'Number of epochs:%d\nlr:%f\n' % (self.epochs, self.lr)
        log_msg += 'nz:%d\nnl-d:%d\nnl-g:%d\n' % (self.n_z, self.nl_d,
                                                  self.nl_g)
        log_msg += 'nf-g:%d\nnf-d:%d\n' % (self.nf_g, self.nf_d)
        log_msg += '********************************************\n\n'
        print(log_msg)
        with open(os.path.join(self.save_dir, 'training_log.txt'),
                  'a') as log_file:
            log_file.write(log_msg)

        # Load the training set.
        imtext_ds = ImTextDataset(data_dir=self.data_root,
                                  dataset=self.dataset,
                                  train=True,
                                  image_size=self.image_size)
        # Honour the configured worker count instead of a hard-coded 2.
        self.trainset_loader = DataLoader(dataset=imtext_ds,
                                          batch_size=self.batch_size,
                                          shuffle=True,
                                          num_workers=self.num_workers)
        print("Dataset loaded successfuly")

        # Resume from checkpoints when requested. Use the value stored on the
        # instance: a module-level ``args`` only exists when this class is run
        # from the script that parsed the options.
        if self.continue_training:
            self.loadCheckpoints()

        # Repeat for the number of epochs, keeping the loss history so the
        # graph can be redrawn after every epoch.
        netd_losses = []
        netg_losses = []
        for epoch in range(self.epochs):
            netd_loss, netg_loss = self.trainEpoch(epoch)
            netd_losses.append(netd_loss)
            netg_losses.append(netg_loss)
            self.saveGraph(netd_losses, netg_losses)
            # Evaluation is currently disabled: self.evalEpoch(epoch)
            self.saveCheckpoints(epoch)

    # train epoch
    def trainEpoch(self, epoch):
        self.netD.train()  # set to train mode
        self.netG.train()  #! set to train mode???

        netd_loss_sum = 0
        netg_loss_sum = 0
        start_time = time()
        for i, (images, labels, captions,
                _) in enumerate(self.trainset_loader):
            batch_size = images.size(
                0
            )  # !batch size my be different (from self.batch_size) for the last batch
            images, labels, captions = Variable(images), Variable(
                labels), Variable(captions)  # !labels should be LongTensor
            labels = labels.type(
                torch.FloatTensor
            )  # convert to FloatTensor (from DoubleTensor)
            lbl_real = Variable(torch.ones(batch_size, 1))
            lbl_fake = Variable(torch.zeros(batch_size, 1))
            noise = Variable(torch.randn(batch_size,
                                         self.n_z))  # create random noise
            noise.data.normal_(0, 1)  # normalize the noise
            rnd_perm1 = torch.randperm(
                batch_size
            )  # random permutations for different sets of training tuples
            rnd_perm2 = torch.randperm(batch_size)
            rnd_perm3 = torch.randperm(batch_size)
            rnd_perm4 = torch.randperm(batch_size)
            if self.cuda:
                images, labels, captions = images.cuda(), labels.cuda(
                ), captions.cuda()
                lbl_real, lbl_fake = lbl_real.cuda(), lbl_fake.cuda()
                noise = noise.cuda()
                rnd_perm1, rnd_perm2, rnd_perm3, rnd_perm4 = rnd_perm1.cuda(
                ), rnd_perm2.cuda(), rnd_perm3.cuda(), rnd_perm4.cuda()

            ############### Update NetD ###############
            self.netD.zero_grad()
            # train with wrong image, wrong label, real caption
            outD_wrong, outC_wrong = self.netD(images[rnd_perm1],
                                               captions[rnd_perm2])
            lossD_wrong = self.bce_loss(outD_wrong, lbl_fake)
            lossC_wrong = self.bce_loss(outC_wrong, labels[rnd_perm1])

            # train with real image, real label, real caption
            outD_real, outC_real = self.netD(images, captions)
            lossD_real = self.bce_loss(outD_real, lbl_real)
            lossC_real = self.bce_loss(outC_real, labels)

            # train with fake image, real label, real caption
            fake = self.netG(noise, captions)
            outD_fake, outC_fake = self.netD(fake.detach(),
                                             captions[rnd_perm3])
            lossD_fake = self.bce_loss(outD_fake, lbl_fake)
            lossC_fake = self.bce_loss(outC_fake, labels[rnd_perm3])

            # backward and forwad for NetD
            netD_loss = lossC_wrong + lossC_real + lossC_fake + lossD_wrong + lossD_real + lossD_fake
            netD_loss.backward()
            self.optimizerD.step()

            ########## Update NetG ##########
            # train with fake data
            self.netG.zero_grad()
            noise.data.normal_(0, 1)  # normalize the noise vector
            fake = self.netG(noise, captions[rnd_perm4])
            d_fake, c_fake = self.netD(fake, captions[rnd_perm4])
            lossD_fake_G = self.bce_loss(d_fake, lbl_real)
            lossC_fake_G = self.bce_loss(c_fake, labels[rnd_perm4])
            netG_loss = lossD_fake_G + lossC_fake_G
            netG_loss.backward()
            self.optimizerG.step()

            netd_loss_sum += netD_loss.data[0]
            netg_loss_sum += netG_loss.data[0]
            ### print progress info ###
            print(
                'Epoch %d/%d, %.2f%% completed. Loss_NetD: %.4f, Loss_NetG: %.4f'
                % (epoch, self.epochs,
                   (float(i) / len(self.trainset_loader)) * 100,
                   netD_loss.data[0], netG_loss.data[0]))

        end_time = time()
        netd_avg_loss = netd_loss_sum / len(self.trainset_loader)
        netg_avg_loss = netg_loss_sum / len(self.trainset_loader)
        epoch_time = (end_time - start_time) / 60
        log_msg = '-------------------------------------------\n'
        log_msg += 'Epoch %d took %.2f minutes\n' % (epoch, epoch_time)
        log_msg += 'NetD average loss: %.4f, NetG average loss: %.4f\n\n' % (
            netd_avg_loss, netg_avg_loss)
        print(log_msg)
        with open(os.path.join(self.save_dir, 'training_log.txt'),
                  'a') as log_file:
            log_file.write(log_msg)
        return netd_avg_loss, netg_avg_loss

    # eval epoch
    def evalEpoch(self, epoch):
        #self.netD.eval()
        #self.netG.eval()
        return 0

    # draws and saves the loss graph upto the current epoch
    def saveGraph(self, netd_losses, netg_losses):
        """Plot both loss histories and save them as ``loss_graph.png``.

        Args:
            netd_losses: Discriminator loss values, plotted in red.
            netg_losses: Generator loss values, plotted in blue.

        The image is written into ``self.save_dir`` and the figure is closed
        afterwards.
        """
        curves = (
            (netd_losses, 'red', 'NetD Loss'),
            (netg_losses, 'blue', 'NetG Loss'),
        )
        for values, colour, caption in curves:
            plt.plot(values, color=colour, label=caption)
        plt.xlabel('epoch')
        plt.ylabel('error')
        plt.legend(loc='best')
        target_path = os.path.join(self.save_dir, 'loss_graph.png')
        plt.savefig(target_path)
        plt.close()

    # save after each epoch
    def saveCheckpoints(self, epoch):
        if epoch % self.save_after == 0:
            name_netD = "netd_checkpoints/netD_" + self.save_prefix + "_epoch_" + str(
                epoch) + ".pth"
            name_netG = "netg_checkpoints/netG_" + self.save_prefix + "_epoch_" + str(
                epoch) + ".pth"
            torch.save(self.netD.state_dict(),
                       os.path.join(self.save_dir, name_netD))
            torch.save(self.netG.state_dict(),
                       os.path.join(self.save_dir, name_netG))
            print("Checkpoints for epoch %d saved successfuly" % (epoch))

    # load checkpoints to continue training
    def loadCheckpoints(self):
        self.netG.load_state_dict(torch.load(self.netG_path))
        self.netD.load_state_dict(torch.load(self.netD_path))
        print("Checkpoints loaded successfuly")