Code example #1
0
# Detection control-sequence field backed by precomputed COCO detection
# features; sequences are padded/truncated to 20 steps, init padding off.
det_field = COCOControlSequenceField(
    detections_path=os.path.join(coco_root, 'coco_detections.hdf5'),
    classes_path=os.path.join(coco_root, 'object_class_list.txt'),
    fix_length=20,
    all_boxes=False,
    pad_init=False,
    padding_idx=-1,
)

# Caption field: lowercased, punctuation removed, wrapped in <bos>/<eos>,
# padded/truncated to 20 tokens.
text_field = TextField(
    init_token='<bos>',
    eos_token='<eos>',
    lower=True,
    remove_punctuation=True,
    fix_length=20,
)

# Full COCO-Entities dataset with tokenized captions (train/val source).
dataset = COCOEntities(
    image_field,
    det_field,
    text_field,
    img_root='',
    id_root=os.path.join(coco_root, 'annotations'),
    ann_root=os.path.join(coco_root, 'annotations'),
    entities_file=os.path.join(coco_root, 'coco_entities.json'),
)

# The vocabulary is built from the train and val splits only,
# dropping tokens that occur fewer than 5 times.
train_dataset, val_dataset, _ = dataset.splits
text_field.build_vocab(train_dataset, val_dataset, min_freq=5)

# Filtered copy of the dataset carrying raw (untokenized) captions,
# used for evaluation-style access.
test_dataset = COCOEntities(
    image_field,
    det_field,
    RawField(),
    img_root='',
    filtering=True,
    id_root=os.path.join(coco_root, 'annotations'),
    ann_root=os.path.join(coco_root, 'annotations'),
    entities_file=os.path.join(coco_root, 'coco_entities.json'),
)

# Only the validation split of the filtered dataset is kept here;
# this rebinds val_dataset from the unfiltered split above.
_, val_dataset, _ = test_dataset.splits

# When RL-style sampling is enabled, swap the text field for a raw one and
# collect the reference captions needed to score CIDEr rewards.
if opt.sample_rl or opt.sample_rl_nw:
    train_dataset.fields['text'] = RawField()
    train_dataset_raw = PairedDataset(
        train_dataset.examples,
        {'image': image_field, 'detection': det_field, 'text': RawField()},
    )
    ref_caps_train = list(train_dataset_raw.text)
    cider_train = evaluation.Cider(
        evaluation.PTBTokenizer.tokenize(ref_caps_train)
    )

# Training loader. NOTE(review): shuffle=False on a training loader looks
# deliberate (deterministic ordering) — confirm against the training loop.
dataloader_train = DataLoader(train_dataset, shuffle=False,
                              batch_size=opt.batch_size,
                              num_workers=opt.nb_workers)
Code example #2
0
    idx_vs_path=os.path.join(flickr_root, 'idx_2_vs_flickr.json'),
    cap_verb_path=os.path.join(flickr_root, 'cap_2_verb_flickr.json'),
    cap_classes_path=os.path.join(flickr_root, 'cap_2_classes_flickr.json'),
    idx_v_og_path=os.path.join(flickr_root, 'idx_2_v_og_flickr.json'),
    vocab_list_path=os.path.join(flickr_root, 'vocab_tv_flickr.json'),
    fix_length=10,
    visual=False)

# Caption field: lowercase, strip punctuation, <bos>/<eos> markers,
# fixed length of 20 tokens.
text_field = TextField(init_token='<bos>', eos_token='<eos>', fix_length=20,
                       lower=True, remove_punctuation=True)

# Flickr-Entities dataset with raw captions; verb_filter=True restricts
# the examples according to the verb filter.
test_dataset = FlickrEntities(
    image_field,
    RawField(),
    det_field,
    img_root='',
    verb_filter=True,
    entities_root=flickr_entities_root,
    ann_file=os.path.join(flickr_root, 'flickr30k_annotations.json'),
)

# Take all three splits, then group test and train examples by their
# image key for dictionary-style (per-image) access.
train_dataset, val_dataset, test_dataset = test_dataset.splits

test_dataset = DictionaryDataset(test_dataset.examples,
                                 test_dataset.fields, 'image')
dataloader_test = DataLoader(test_dataset, num_workers=opt.nb_workers,
                             batch_size=opt.batch_size)
train_dataset = DictionaryDataset(train_dataset.examples,
                                  train_dataset.fields, 'image')
Code example #3
0
                                    precomp_glove_path=os.path.join(coco_root, 'object_class_glove.pkl'),
                                    vocab_path=os.path.join(coco_root, 'vocab_tv.json'), 
                                    vlem_2_v_og_path=os.path.join(coco_root, 'vlem_2_vog_coco.json'), 
                                    cls_seq_path=os.path.join('saved_data/coco', 'img_cap_v_2_class_self.json'),
                                    fix_length=10, max_detections=20, gt_verb=opt.gt)

# Caption field: lowercased, punctuation stripped, <bos>/<eos> markers,
# padded/truncated to 20 tokens.
text_field = TextField(init_token='<bos>', eos_token='<eos>', fix_length=20,
                       lower=True, remove_punctuation=True)

# Training/validation COCO-Entities dataset with tokenized captions.
dataset = COCOEntities(
    image_field,
    det_field,
    text_field,
    img_root='',
    id_root=os.path.join(coco_root, 'annotations'),
    ann_root=os.path.join(coco_root, 'annotations'),
    entities_file=os.path.join(coco_root, 'coco_entities.json'),
)

# Evaluation dataset with raw captions; filtering is on, and
# det_filtering is controlled by the opt.det flag.
test_dataset = COCOEntities(
    image_field,
    det_field,
    RawField(),
    img_root='',
    filtering=True,
    det_filtering=opt.det,
    id_root=os.path.join(coco_root, 'annotations'),
    ann_root=os.path.join(coco_root, 'annotations'),
    entities_file=os.path.join(coco_root, 'coco_entities.json'),
)

# Build the vocabulary from the train+val splits (tokens seen < 5 times dropped).
train_dataset, val_dataset, _ = dataset.splits
text_field.build_vocab(train_dataset, val_dataset, min_freq=5)

# Group the filtered test split by image key and wrap it in a loader.
_, _, test_dataset = test_dataset.splits
test_dataset = DictionaryDataset(test_dataset.examples,
                                 test_dataset.fields, 'image')
dataloader_test = DataLoader(test_dataset, num_workers=opt.nb_workers,
                             batch_size=opt.batch_size)