Example 1

# NOTE: the opening of this fragment (imports and the start of the dataset
# constructor) is truncated in the source; `FlickrEntities` and its positional
# arguments below are a plausible reconstruction, not original lines.
dataset = FlickrEntities(image_field, text_field, det_field,
                         entities_root=flickr_entities_root,
                         det_filter=True)

train_dataset, val_dataset, _ = dataset.splits
text_field.build_vocab(train_dataset, val_dataset, min_freq=5)

# define the dataloader
_, _, test_dataset = dataset.splits
test_dataset = DictionaryDataset(test_dataset.examples, test_dataset.fields,
                                 'image')
dataloader_test = DataLoader(test_dataset,
                             batch_size=opt.batch_size,
                             num_workers=opt.nb_workers)

# captioning model
model = ControllableCaptioningModel(20, len(text_field.vocab), text_field.vocab.stoi['<bos>'],
                                    h2_first_lstm=opt_cap.h2_first_lstm,
                                    img_second_lstm=opt_cap.img_second_lstm,
                                    dataset='flickr').to(device)
# `saved_data` is assumed to hold a checkpoint loaded in the truncated preamble
# (cf. the torch.load on 'saved_models/%s_best.pth' in Example 3).
model.load_state_dict(saved_data['state_dict'])
model.eval()

# region sort model
re_sort_net = S_SSP(dataset='flickr').cuda()
re_sort_net.load_state_dict(
    torch.load(os.path.join('saved_model/flickr_npos_fc_v', 'model-tr.pth')))
re_sort_net.eval()

sinkhorn_len = 10
sinkhorn_net = SinkhornNet(sinkhorn_len, 20, 0.1).cuda()
sinkhorn_net.load_state_dict(
    torch.load(os.path.join('saved_model/flickr_sinkhorn', 'model-sh.pth')))
sinkhorn_net.eval()
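
# For reference: SinkhornNet relies on Sinkhorn normalization, which relaxes a
# score matrix into a (near) doubly-stochastic soft permutation by alternating
# row and column normalizations. A minimal log-domain sketch (generic, not the
# project's exact implementation; `temperature` mirrors the 0.1 passed above):
def sinkhorn(scores, n_iters=20, temperature=0.1):
    """scores: (n, n) tensor of unnormalized assignment scores."""
    log_alpha = scores / temperature
    for _ in range(n_iters):
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=1, keepdim=True)  # rows
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=0, keepdim=True)  # cols
    return torch.exp(log_alpha)  # rows and columns each sum to ~1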

Example 2

# NOTE: this fragment's opening is truncated in the source; the `if` header and
# dataset constructor below are reconstructed to match the surviving
# indentation and the `else` branch, and are not original lines.
if opt_test.dataset == 'flickr':
    dataset = FlickrEntities(image_field, text_field, det_field,
                             entities_root=flickr_entities_root)

    nw_aligner = NWNounAligner(
        pre_comp_file=os.path.join(flickr_root, '%s_noun_glove.pkl' % opt_test.dataset),
        normalized=True)

else:
    raise NotImplementedError
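
# NWNounAligner presumably scores noun sequences with Needleman-Wunsch global
# alignment over precomputed GloVe similarities (the *_noun_glove.pkl file
# above). A generic sketch of the dynamic program, with a hypothetical
# similarity function `sim` and a linear gap penalty:
def needleman_wunsch(seq_a, seq_b, sim, gap=-1.0):
    """Global alignment score between two token sequences (generic sketch)."""
    n, m = len(seq_a), len(seq_b)
    dp = [[0.0] * (m + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        dp[i][0] = dp[i - 1][0] + gap  # leading gaps in seq_b
    for j in range(1, m + 1):
        dp[0][j] = dp[0][j - 1] + gap  # leading gaps in seq_a
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            dp[i][j] = max(dp[i - 1][j - 1] + sim(seq_a[i - 1], seq_b[j - 1]),  # match
                           dp[i - 1][j] + gap,   # gap in seq_b
                           dp[i][j - 1] + gap)   # gap in seq_a
    return dp[n][m]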

train_dataset, val_dataset, _ = dataset.splits
text_field.build_vocab(train_dataset, val_dataset, min_freq=5)

if opt_test.exp_name == 'ours':
    model = ControllableCaptioningModel(
        20,
        len(text_field.vocab),
        text_field.vocab.stoi['<bos>'],
        h2_first_lstm=opt.h2_first_lstm,
        img_second_lstm=opt.img_second_lstm).to(device)
elif opt_test.exp_name == 'ours_without_visual_sentinel':
    model = ControllableCaptioningModel_NoVisualSentinel(
        20,
        len(text_field.vocab),
        text_field.vocab.stoi['<bos>'],
        h2_first_lstm=opt.h2_first_lstm,
        img_second_lstm=opt.img_second_lstm).to(device)
elif opt_test.exp_name == 'ours_with_single_sentinel':
    model = ControllableCaptioningModel_SingleSentinel(
        20,
        len(text_field.vocab),
        text_field.vocab.stoi['<bos>'],
        h2_first_lstm=opt.h2_first_lstm,
        img_second_lstm=opt.img_second_lstm).to(device)

Example 3

# NOTE: the opening of this fragment (imports and the start of the dataset
# constructor) is truncated in the source; `COCOEntities` and its positional
# arguments below are a plausible reconstruction, with only the final keyword
# argument original.
dataset = COCOEntities(image_field, det_field, text_field,
                       filtering=True)

train_dataset, val_dataset, _ = dataset.splits
# vocabulary build assumed here, as in the other fragments (text_field.vocab is
# used below when constructing the model)
text_field.build_vocab(train_dataset, val_dataset, min_freq=5)

if opt.sample_rl or opt.sample_rl_nw:
    train_dataset.fields['text'] = RawField()
    train_dataset_raw = PairedDataset(train_dataset.examples,
                                      {'image': image_field, 'detection': det_field,
                                       'text': RawField()})
    ref_caps_train = list(train_dataset_raw.text)
    cider_train = evaluation.Cider(evaluation.PTBTokenizer.tokenize(ref_caps_train))
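
# With opt.sample_rl, cider_train is typically used for self-critical sequence
# training (SCST): the reward of a sampled caption is its CIDEr minus the CIDEr
# of a greedy-decoded baseline. A generic sketch of the loss (variable names
# and per-image score helpers are hypothetical, not this repo's API):
def scst_loss(log_probs, sampled_cider, greedy_cider):
    """log_probs: (batch, seq_len) log-probs of the sampled tokens;
    sampled_cider / greedy_cider: (batch,) per-image CIDEr scores."""
    reward = sampled_cider - greedy_cider            # baseline-subtracted reward
    return -(reward.unsqueeze(1) * log_probs).mean()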

dataloader_train = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=opt.nb_workers)

val_dataset.fields['text'] = RawField()
dataloader_val = DataLoader(val_dataset, batch_size=16, num_workers=opt.nb_workers)

model = ControllableCaptioningModel(20, len(text_field.vocab), text_field.vocab.stoi['<bos>'],
                                    h2_first_lstm=opt.h2_first_lstm, img_second_lstm=opt.img_second_lstm).to(device)

optim = Adam(model.parameters(), lr=opt.lr)
scheduler = StepLR(optim, step_size=opt.step_size, gamma=opt.gamma)
loss_fn = NLLLoss()
loss_fn_gate = NLLLoss(ignore_index=-1)
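
# NLLLoss expects log-probabilities, so the model is assumed to end in a
# log_softmax; loss_fn_gate supervises the gate/sentinel predictions, with -1
# marking unsupervised positions. A sketch of one teacher-forcing step under
# these assumptions (the model's exact outputs are hypothetical):
def xe_step(images, detections, captions, gate_targets):
    optim.zero_grad()
    log_probs, gate_log_probs = model(images, detections, captions)
    loss_cap = loss_fn(log_probs.view(-1, log_probs.size(-1)), captions.view(-1))
    loss_gate = loss_fn_gate(gate_log_probs.view(-1, gate_log_probs.size(-1)),
                             gate_targets.view(-1))  # -1 entries are ignored
    loss = loss_cap + loss_gate
    loss.backward()
    optim.step()
    return loss.item()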

start_epoch = 0
best_cider = 0.0
patience = 0
if opt.sample_rl or opt.sample_rl_nw:
    saved_data = torch.load('saved_models/%s_best.pth' % opt.exp_name)
    print("Loading from epoch %d, with validation CIDER %.02f" % (saved_data['epoch'], saved_data['val_cider']))
    start_epoch = saved_data['epoch'] + 1
    model.load_state_dict(saved_data['state_dict'])
    best_cider = saved_data['best_cider']
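
# For reference, a checkpoint with the keys read above would be written by the
# counterpart saving code, e.g. (sketch mirroring the loaded keys):
# torch.save({'epoch': epoch, 'val_cider': val_cider, 'best_cider': best_cider,
#             'state_dict': model.state_dict()},
#            'saved_models/%s_best.pth' % opt.exp_name)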