def main(args):
    """Caption a single image with a trained encoder/decoder pair.

    Loads the vocabulary and model weights named in ``args``, encodes
    ``args.image``, decodes a word-id sequence with ``decoder.sample`` and
    prints the caption (the ``<end>`` token is kept if it was produced).
    The raw image is then passed to ``plt.imshow``.

    Args:
        args: namespace providing ``vocab_path``, ``embed_size``,
            ``hidden_size``, ``num_layers``, ``encoder_path``,
            ``decoder_path`` and ``image``.
    """
    # Image preprocessing: tensor conversion + per-channel normalization
    # (the commonly used ImageNet mean/std values).
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

    # Load vocabulary.
    # NOTE(review): pickle.load can execute arbitrary code from the file;
    # only point vocab_path at trusted files.
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build models; both are switched to eval mode for inference.
    encoder = Encoder(args.embed_size).eval()
    decoder = Decoder(args.embed_size, args.hidden_size, len(vocab), args.num_layers)
    decoder.eval()
    encoder = encoder.to(device)
    decoder = decoder.to(device)

    # Load the trained parameters, mapping storages onto ``device`` so that
    # checkpoints saved on a GPU can also be loaded on a CPU-only machine.
    encoder.load_state_dict(torch.load(args.encoder_path, map_location=device))
    decoder.load_state_dict(torch.load(args.decoder_path, map_location=device))

    # Prepare an image.
    image = load_image(args.image, transform)
    image_tensor = image.to(device)

    # Generate a caption from the image; no autograd graph is needed here.
    with torch.no_grad():
        feature = encoder(image_tensor)
        sampled_ids = decoder.sample(feature)
    sampled_ids = sampled_ids[0].cpu().numpy()  # (1, max_seq_length) -> (max_seq_length)

    # Convert word ids to words, stopping after the first '<end>' token.
    sampled_caption = []
    for word_id in sampled_ids:
        word = vocab.idx2word[word_id]
        sampled_caption.append(word)
        if word == '<end>':
            break
    sentence = ' '.join(sampled_caption)

    # Print the generated caption and hand the raw image to matplotlib.
    # NOTE(review): there is no plt.show() here, so nothing may be displayed
    # outside an interactive backend -- confirm this is intended.
    print(sentence)
    image = Image.open(args.image)
    plt.imshow(np.asarray(image))
# Example #2
# 0
def infer(args):
    """Caption the image at ``args['val_img_path']`` and print the caption.

    The leading ``<start>`` and trailing ``<end>`` tokens are removed from
    the printed sentence when present. The raw image is then passed to
    ``plt.imshow``.

    Args:
        args: mapping providing ``vocab_path``, ``embed_size``,
            ``pooling_kernel``, ``hidden_size``, ``num_layers``,
            ``encoder_path``, ``decoder_path``, ``val_img_path`` and
            ``resize``.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # Load the vocabulary.
    # NOTE(review): pickle.load can execute arbitrary code from the file;
    # only point vocab_path at trusted files.
    with open(args['vocab_path'], 'rb') as f:
        vocab = pickle.load(f)

    # Use model.eval() at test time so BatchNorm and Dropout use their
    # trained behaviour -- for both the encoder and the decoder.
    encoder = Encoder(args['embed_size'], args['pooling_kernel']).eval().cuda()
    decoder = Decoder(args['embed_size'], args['hidden_size'], len(vocab),
                      args['num_layers']).cuda()
    decoder.eval()

    # Load the parameters saved during training.
    encoder.load_state_dict(torch.load(args['encoder_path']))
    decoder.load_state_dict(torch.load(args['decoder_path']))

    # Load the image.
    image = load_image(args['val_img_path'], transform,
                       (args['resize'], args['resize']))
    image_tensor = image.cuda()

    # Feed the image through the models to get a sequence of word indices;
    # no autograd graph is needed at inference time.
    with torch.no_grad():
        feature = encoder(image_tensor)
        index = decoder.sample(feature)
    index = index[0].cpu().numpy()

    # Convert indices to words, stopping after the first '<end>' token.
    words = []
    for ind in index:
        word = vocab.idx2word[ind]
        words.append(word)
        if word == '<end>':
            break

    # Remove the special tokens <start> / <end>. Only strip them when they
    # are actually present, so a caption that hit the length limit without
    # '<end>' keeps its final word.
    if words and words[0] == '<start>':
        words = words[1:]
    if words and words[-1] == '<end>':
        words = words[:-1]
    sentence = ' '.join(words)
    print(sentence)

    # ``args`` is a mapping, so the image path is looked up by key.
    image = Image.open(args['val_img_path'])
    plt.imshow(np.asarray(image))
# Example #3
# 0

# Build the decoder (generator), discriminator and encoder on ``device``.
NetG = Decoder(nc, ngf, nz).to(device)
NetD = Discriminator(imageSize, nc, ndf, nz).to(device)
NetE = Encoder(imageSize, nc, ngf, nz).to(device)
# NOTE(review): this rebinds the name ``Sampler`` from the class to an
# instance, so no further Sampler objects can be constructed after this line.
Sampler = Sampler().to(device)

# Apply ``weights_init`` recursively to every submodule of each network.
NetE.apply(weights_init)
NetG.apply(weights_init)
NetD.apply(weights_init)

# Optionally load weights. An empty (or None) path skips loading. Storages are
# mapped onto ``device`` so GPU-saved checkpoints also load on CPU-only hosts.
if opt.netE:
    NetE.load_state_dict(torch.load(opt.netE, map_location=device))
if opt.netG:
    NetG.load_state_dict(torch.load(opt.netG, map_location=device))
if opt.netD:
    NetD.load_state_dict(torch.load(opt.netD, map_location=device))

# RMSprop optimizer for the encoder. The variable name's spelling is kept
# since later code presumably refers to it.
optimizer_encorder = optim.RMSprop(params=NetE.parameters(),
                                   lr=lr,
                                   alpha=0.9,
                                   eps=1e-8,
                                   weight_decay=0,
                                   momentum=0,
                                   centered=False)
optimizer_decoder = optim.RMSprop(params=NetG.parameters(),
                                  lr=lr,
                                  alpha=0.9,
                                  eps=1e-8,
                                  weight_decay=0,
# Example #4
# 0
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


# Build the encoder and decoder on ``device``.
NetE = Encoder(imageSize, nc, ngf, nz).to(device)
NetG = Decoder(nc, ngf, nz).to(device)

# NOTE(review): this rebinds the name ``Sampler`` from the class to an
# instance, so no further Sampler objects can be constructed after this line.
Sampler = Sampler().to(device)

NetE.apply(weights_init)
NetG.apply(weights_init)

# Load trained weights, mapping storages straight onto ``device`` (the device
# the models were moved to above). This works on CPU-only hosts too and does
# not depend on ``opt.cuda`` being a valid ``map_location`` value.
NetE.load_state_dict(torch.load(opt.netE, map_location=device))
NetG.load_state_dict(torch.load(opt.netG, map_location=device))

# Inference only: put both networks in eval mode.
NetE.eval()
NetG.eval()

# 21 attributes
attributes = [
    '5_o_Clock_Shadow', 'Bald', 'Bangs', 'Black_Hair', 'Blond_Hair',
    'Brown_Hair', 'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'Male',
    'Mustache', 'No_Beard', 'Receding_Hairline', 'Sideburns', 'Smiling',
    'Straight_Hair', 'Wavy_Hair', 'Wearing_Hat', 'Wearing_Lipstick', 'Young'
]

# Autograd is disabled globally for the rest of the script.
torch.set_grad_enabled(False)

# One zero tensor per attribute. Each entry must be a distinct tensor:
# ``[t] * n`` would make every entry alias the same object, so an in-place
# update (e.g. ``+=``) on one attribute would silently change all of them.
# NOTE(review): the 100 is hard-coded; presumably it should match nz -- confirm.
attributes_z = [torch.zeros([100, 1, 1]) for _ in attributes]
# Example #5
# 0
def main(args):
    """Caption every file in ``args.image_dir`` and score the captions.

    Each file is captioned with the trained encoder/decoder. The results are
    written to ``captions_res.json`` as a list of
    ``{'image_id': int, 'caption': str}`` dicts and evaluated against the
    annotations in ``args.caption_path`` with the coco-caption tools.

    Args:
        args: namespace providing ``vocab_path``, ``embed_size``,
            ``hidden_size``, ``num_layers``, ``encoder_path``,
            ``decoder_path``, ``image_dir`` and ``caption_path``.
    """
    # Image preprocessing: tensor conversion + per-channel normalization.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

    # Load vocabulary.
    # NOTE(review): pickle.load can execute arbitrary code from the file;
    # only point vocab_path at trusted files.
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build models; both are switched to eval mode for inference.
    encoder = Encoder(args.embed_size).eval()
    decoder = Decoder(args.embed_size, args.hidden_size, len(vocab), args.num_layers)
    decoder.eval()
    encoder = encoder.to(device)
    decoder = decoder.to(device)

    # Load the trained parameters, mapping storages onto ``device`` so that
    # GPU-saved checkpoints can also be loaded on a CPU-only machine.
    encoder.load_state_dict(torch.load(args.encoder_path, map_location=device))
    decoder.load_state_dict(torch.load(args.decoder_path, map_location=device))

    # Caption every file in the validation image directory. No autograd
    # graph is needed, which saves memory over the whole directory.
    captions = []
    with torch.no_grad():
        for file_name in os.listdir(args.image_dir):
            image_path = os.path.join(args.image_dir, file_name)
            image_tensor = load_image(image_path, transform).to(device)

            # Generate a caption from the image.
            feature = encoder(image_tensor)
            sampled_ids = decoder.sample(feature)
            sampled_ids = sampled_ids[0].cpu().numpy()  # (1, max_seq_length) -> (max_seq_length)

            # Convert word ids to words, skipping '<start>' and stopping at '<end>'.
            sampled_caption = []
            for word_id in sampled_ids:
                word = vocab.idx2word[word_id]
                if word == '<start>':
                    continue
                if word == '<end>':
                    break
                sampled_caption.append(word)

            # Extract the numeric image id from the file name.
            # NOTE(review): the slice assumes COCO-style names such as
            # 'COCO_val2014_000000391895.jpg' (first 14 characters and a
            # 4-character extension are dropped) -- confirm for other data.
            image_id = int(file_name[14:-4])
            captions.append({'image_id': image_id,
                             'caption': ' '.join(sampled_caption)})

    # Save results in the file that COCO.loadRes reads below.
    with open('captions_res.json', 'w') as f:
        json.dump(captions, f)

    # Evaluate with the coco-caption evaluation tools, restricted to the
    # image ids present in the results.
    coco = COCO(args.caption_path)
    cocoRes = coco.loadRes('captions_res.json')
    cocoEval = COCOEvalCap(coco, cocoRes)
    cocoEval.params['image_id'] = cocoRes.getImgIds()
    cocoEval.evaluate()