def fetch_multimnist_image(label):
    """Return a random image from the MultiMNIST test set with the given label.

    @param label: string
                  a string of up to 4 digits
    @return: torch.autograd.Variable
             a single MultiMNIST image tensor (batch dim of 1 prepended)
    """
    dataset = MultiMNIST('./data',
                         train=False,
                         download=True,
                         transform=transforms.ToTensor(),
                         target_transform=charlist_tensor)
    images = dataset.test_data
    labels = dataset.test_labels
    n_rows = len(images)

    # BUG FIX: the original rebound `images` to an empty list here, clobbering
    # the dataset tensors loaded above before the loop could index them.
    # Matching images are collected into a separate list instead.
    matches = []
    for i in xrange(n_rows):
        image = images[i]
        text = labels[i]
        if tensor_to_string(text.squeeze(0)) == label:
            matches.append(image)

    if len(matches) == 0:
        sys.exit('No images with label (%s) found.' % label)

    # Stack matches and pick one uniformly at random.
    matches = torch.cat(matches).cpu().numpy()
    ix = np.random.choice(np.arange(matches.shape[0]))
    image = matches[ix]
    image = torch.from_numpy(image).float()
    image = image.unsqueeze(0)
    # volatile=True: legacy (pre-0.4 PyTorch) inference-only flag.
    return Variable(image, volatile=True)
Esempio n. 2
0
def fetch_multimnist_image(_label):
    """Sample one random MultiMNIST test image whose label equals ``_label``.

    @param _label: string of digits; the EMPTY sentinel maps to ''
    @return: torch.autograd.Variable holding a single image (leading batch dim)
    """
    if _label == EMPTY:
        _label = ''

    loader = torch.utils.data.DataLoader(datasets.MultiMNIST(
        './data',
        train=False,
        download=True,
        transform=transforms.ToTensor(),
        target_transform=charlist_tensor),
                                         batch_size=1,
                                         shuffle=True)

    # Gather every test image whose decoded label string matches the target.
    candidates = [img for img, lbl in loader
                  if tensor_to_string(lbl.squeeze(0)) == _label]

    if not candidates:
        sys.exit('No images with label (%s) found.' % _label)

    # Stack the candidates and draw one uniformly at random.
    pool = torch.cat(candidates).cpu().numpy()
    pick = pool[np.random.choice(np.arange(pool.shape[0]))]
    result = torch.from_numpy(pick).float().unsqueeze(0)
    # volatile=True: legacy (pre-0.4 PyTorch) inference-only flag.
    return Variable(result, volatile=True)
    # mode 4: generate conditioned on image and text
    elif args.condition_on_text and args.condition_on_image:
        image = fetch_multimnist_image(args.condition_on_image)
        text = fetch_multimnist_text(args.condition_on_text)
        if args.cuda:
            image = image.cuda()
            text = text.cuda()
        mu, logvar = model.infer(1, image=image, text=text)
        std = logvar.mul(0.5).exp_()

    # sample from uniform gaussian
    sample = Variable(torch.randn(args.n_samples, model.n_latents))
    if args.cuda:
        sample = sample.cuda()
    # sample from particular gaussian by multiplying + adding
    mu = mu.expand_as(sample)
    std = std.expand_as(sample)
    sample = sample.mul(std).add_(mu)
    # generate image and text
    img_recon = F.sigmoid(model.image_decoder(sample)).cpu().data
    txt_recon = F.log_softmax(model.text_decoder(sample), dim=1).cpu().data
    txt_recon = torch.max(txt_recon, dim=2)[1]

    # save image samples to filesystem
    save_image(img_recon.view(args.n_samples, 1, 50, 50), './sample_image.png')
    # save text samples to filesystem
    with open('./sample_text.txt', 'w') as fp:
        for i in xrange(text_recon.size(0)):
            text_recon_str = tensor_to_string(text_recon[i])
            fp.write('Text (%d): %s\n' % (i, text_recon_str))
Esempio n. 4
0
            try:
                kl_lambda = next(schedule)
            except:
                pass

        train(epoch)
        loss = test()

        is_best = loss < best_loss
        best_loss = min(loss, best_loss)

        save_checkpoint({
            'state_dict': vae.state_dict(),
            'best_loss': best_loss,
            'n_latents': args.n_latents,
            'optimizer' : optimizer.state_dict(),
        }, is_best, folder='./trained_models/text_only')

        sample = Variable(torch.randn(64, args.n_latents))
        if args.cuda:
           sample = sample.cuda()

        sample = vae.decoder.generate(sample).cpu().data.long()            
        sample_texts = []
        for i in xrange(sample.size(0)):
            text = tensor_to_string(sample[i])
            sample_texts.append(text)
        
        with open('./results/text_only/sample_text_epoch%d.txt' % epoch, 'w') as fp:
            fp.writelines(sample_texts)