def main():
    """Build the GAN described by ``config`` and either train or test it.

    Ensures the per-run TensorBoard and checkpoint directories
    (``<tensorboard_dir>/<name>`` and ``<checkpoint_dir>/<name>``) exist, moves
    the model to the configured device, optionally restores weights via
    ``load_checkpoints`` when ``config.load_epoch`` is non-zero, and then
    dispatches to ``train`` (with a TensorBoard ``SummaryWriter``) or ``test``
    according to ``config.is_train``.
    """
    log_dir = os.path.join(config.tensorboard_dir, config.name)
    ckpt_dir = os.path.join(config.checkpoint_dir, config.name)
    # exist_ok=True replaces the exists()/makedirs() pair, which could race
    # with another process creating the same directory in between.
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(ckpt_dir, exist_ok=True)

    device = torch.device('cuda:0' if config.use_cuda else 'cpu')
    models = GAN(config).to(device)
    # load_epoch == 0 skips restoring; any other value is passed through to
    # load_checkpoints (presumably selects which saved epoch to load).
    if config.load_epoch != 0:
        load_checkpoints(models, config.checkpoint_dir, config.name,
                         config.load_epoch)

    if config.is_train:
        models.train()
        writer = SummaryWriter(log_dir=log_dir)
        try:
            train(models, writer, device)
        finally:
            # Flush buffered events to disk even if training raises.
            writer.close()
    else:
        models.eval()
        test(models, device)
# Example #2
# 0
if __name__ == '__main__':
    imageDir = '...celeb_images/Part 1'
    # Number of images used for training and their square side length.
    max_images = 5000
    image_size = 64

    # Sort so the chosen subset is reproducible (os.listdir order is
    # arbitrary), and truncate up front so we never read or allocate space
    # for files that would be discarded anyway.
    image_path_list = [
        os.path.join(imageDir, file_name)
        for file_name in sorted(os.listdir(imageDir))
    ][:max_images]

    image = np.empty([len(image_path_list), image_size, image_size, 3])
    for indx, imagePath in enumerate(image_path_list):
        # The context manager closes the file handle; convert()/resize()
        # return new in-memory images, so they remain valid afterwards.
        with Image.open(imagePath) as pil_image:
            im = pil_image.convert('RGB').resize((image_size, image_size))
        # Scale to [0, 1], then normalize with mean 0.5 / std 0.5.
        im = np.array(im, dtype=np.float32) / 255
        im = Normalize(im, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        image[indx, :, :, :] = im

    gan = GAN()
    gen, _ = gan.train(image, 200)

    # Inference only: no autograd graph is needed for sampling.
    with torch.no_grad():
        noise = Variable(torch.randn(1, 100, 1, 1)).float().cuda()
        im = gen.feed_forward(noise)
    # Assumes generator output is (N, C, H, W); matplotlib wants (H, W, C).
    im = im.permute(0, 2, 3, 1)
    im = im.squeeze(0).cpu().detach().numpy()
    # Invert the preprocessing in reverse order: undo Normalize first, which
    # should map back to the [0, 1] range produced by the /255 step.
    # NOTE(review): assumes DeNormalize computes x * std + mean — confirm.
    im = DeNormalize(im, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    # Float RGB input to imshow must lie in [0, 1].
    im = np.clip(im, 0, 1)

    plt.imshow(im)
    plt.show()