Code Example #1
def get_tuple(splits: str,
              bs: int,
              shuffle=False,
              drop_last=False,
              topk=-1) -> DataTuple:
    # Decide which QA datasets will be used in pre-training.
    # Options: vqa, gqa, visual7w
    # Note: visual7w is part of vgqa; we just use the name here.
    qa_sets = args.qa_sets
    if qa_sets is not None:
        qa_sets = set(qa_set.lower().strip() for qa_set in qa_sets.split(","))

    # Build dataset, data loader, and evaluator.
    dset = LXMERTDataset(splits, qa_sets=qa_sets)
    tset = LXMERTTorchDataset(dset, topk)
    data_loader = DataLoader(tset,
                             batch_size=bs,
                             shuffle=shuffle,
                             num_workers=args.num_workers,
                             collate_fn=lambda x: x,
                             drop_last=drop_last,
                             pin_memory=True)
    evaluator = LXMERTEvaluator(dset)
    print()

    return DataTuple(dataset=dset,
                     torchdset=tset,
                     loader=data_loader,
                     evaluator=evaluator)
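
For reference, DataTuple is the small named-tuple container these helpers return. A minimal sketch of the definitions Example #1 assumes, with the field list inferred from the return statement above (the uclanlp/visualbert variants below additionally return a vl_torchdset field):

import collections

from torch.utils.data import DataLoader

# Field names inferred from the keyword arguments in the return above.
DataTuple = collections.namedtuple("DataTuple",
                                   "dataset torchdset loader evaluator")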
Code Example #2
File: lxmert_pretrain.py  Project: uclanlp/visualbert
def get_tuple(splits: str, bs: int, shuffle=False, drop_last=False, topk=-1,
              num_workers=0, limit_source=[], restrict_source=None) -> DataTuple:
    # Decide which QA datasets will be used in pre-training.
    # Options: vqa, gqa, visual7w
    # Note: visual7w is part of vgqa; we just use the name here.
    qa_sets = args.qa_sets
    if qa_sets is not None:
        qa_sets = set(qa_set.lower().strip() for qa_set in qa_sets.split(","))

    # Build dataset, data loader, and evaluator.
    dset = LXMERTDataset(splits, qa_sets=qa_sets)
    tset = LXMERTTorchDataset(
        dset,
        topk,
        limit_source=limit_source,
        # This function is called for evaluation in our context.
        use_visual_tag_flag=args.get("allow_tag_for_eval", False),
    )

    data_loader = DataLoader(
        tset, batch_size=bs,
        shuffle=shuffle, num_workers=num_workers,
        collate_fn=tset.custom_collact_fn if args.get('custom_collact_fn', False) else lambda x: x,
        drop_last=drop_last, pin_memory=args.get("pin_memory", True)
    )
    evaluator = LXMERTEvaluator(dset)
    print()

    return DataTuple(dataset=dset, torchdset=tset, loader=data_loader, evaluator=evaluator, vl_torchdset=tset)
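
A detail shared by both examples so far: collate_fn=lambda x: x (the default path) turns off PyTorch's automatic tensor collation, so each batch arrives as a plain Python list of dataset items and tensorization is left to the training loop. A self-contained illustration with a hypothetical toy dataset:

import torch
from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    """Hypothetical stand-in for LXMERTTorchDataset."""
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return {"uid": idx, "feats": torch.zeros(4)}

loader = DataLoader(ToyDataset(), batch_size=4, collate_fn=lambda x: x)
batch = next(iter(loader))
# With the identity collate_fn, the batch is a list of 4 dicts,
# not a dict of stacked tensors.
assert isinstance(batch, list) and len(batch) == 4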
Code Example #3
#
# def get_subsets_dl(splits: str, bs: int, shuffle=False, drop_last=False, topk=-1) -> DataTuple:
#
#     # Build dataset and data loader
#     tset = IXTorchDataset(splits)
#     data_loader = DataLoader(
#         tset, batch_size=bs,
#         shuffle=shuffle, num_workers=args.num_workers,
#         collate_fn=lambda x: x,
#         drop_last=drop_last, pin_memory=True
#     )
#
#     return data_loader

# train_tuple = get_tuple(args.train, args.batch_size, shuffle=True, drop_last=True)
train_dset = LXMERTDataset(args.train, qa_sets=args.qa_sets)
# train_subsets_dl = get_subsets_dl(args.train, args.batch_size*220, shuffle=True, drop_last=True)
valid_batch_size = 2048 if args.multiGPU else 512
#valid_tuple = get_tuple(args.valid, valid_batch_size, shuffle=False, drop_last=False) #, topk=5000)


class InputFeatures(object):
    """A single set of features of data."""
    def __init__(self, input_ids, input_mask, segment_ids, lm_label_ids,
                 visual_feats, obj_labels, is_matched, ans):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.lm_label_ids = lm_label_ids

        self.visual_feats = visual_feats
        self.obj_labels = obj_labels
        self.is_matched = is_matched
        self.ans = ans
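
A hypothetical construction of InputFeatures, just to make the field layout concrete. All values below are placeholders; the 36x2048 feature shape follows the usual Faster R-CNN setup of LXMERT-style models and is an assumption here.

import numpy as np

feats = np.zeros((36, 2048), dtype=np.float32)  # pooled object features (placeholder)
boxes = np.zeros((36, 4), dtype=np.float32)     # normalized bounding boxes (placeholder)

feature = InputFeatures(
    input_ids=[101, 2054, 2003, 102],  # WordPiece ids (placeholder)
    input_mask=[1, 1, 1, 1],
    segment_ids=[0, 0, 0, 0],
    lm_label_ids=[-1, -1, 2003, -1],   # -1 marks positions that are not masked
    visual_feats=(feats, boxes),
    obj_labels=None,                   # detector labels, omitted in this sketch
    is_matched=1,                      # 1: the sentence matches the image
    ans=None,                          # QA answer target, omitted in this sketch
)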
Code Example #4
File: lxmert_pretrain.py  Project: uclanlp/visualbert
def get_tuple_hybrid(splits: str, bs: int, shuffle=False, drop_last=False,
                     num_workers=0, topk=-1, image_only_splits=None,
                     text_only_splits=None, limit_source=[],
                     restrict_source=None) -> DataTuple:
    # Decide which QA datasets will be used in pre-training.
    # Options: vqa, gqa, visual7w
    # Note: visual7w is part of vgqa; we just use the name here.
    qa_sets = args.qa_sets
    if qa_sets is not None:
        qa_sets = set(qa_set.lower().strip() for qa_set in qa_sets.split(","))

    # Three types of datasets: vision-and-language, language-only, and vision-only.
    datasets_list_torch = []
    datasets_list = []

    if splits is not None:
        vl_dataset = LXMERTDataset(splits, qa_sets=qa_sets)
        vl_dataset_torch = LXMERTTorchDataset(
            vl_dataset, topk,
            limit_source=limit_source,
            randomized_pairing=args.get("randomized_pairing", False),
            use_visual_tag_flag=args.get("use_visual_tag_flag", False))
        datasets_list.append(vl_dataset)
        datasets_list_torch.append(vl_dataset_torch)

    if text_only_splits is not None:
        text_only_datasets = []
        for split in text_only_splits.split("+"):
            if not("book_corpus" in split or "sbu" in split):
                text_only_dataset = LXMERTDataset(split, qa_sets=qa_sets)
                text_only_dataset_torch = LXMERTTorchDataset(text_only_dataset, topk, text_only=True, limit_source=limit_source)
                
                datasets_list.append(text_only_dataset)
                datasets_list_torch.append(text_only_dataset_torch)
                text_only_datasets.append(text_only_dataset_torch)
            else:
                text_only_dataset = None
                if "book_corpus" in split and args.get("text_shared_memory", False):
                    text_class = GeneralCorpusNP  # shared-memory corpus variant
                else:
                    text_class = GeneralCorpus  # plain in-memory corpus
                text_only_dataset_torch = text_class(
                    ann_file=args.book_corpus_path if "book_corpus" in split else args.sbu_path,
                    pretrained_model_name="bert-base-uncased",
                    tokenizer=None,
                    seq_len=args.get("text_only_max_seq_len", 64),
                    min_seq_len=args.get("text_only_min_seq_len", 64),
                    encoding="utf-8",
                    on_memory=True)
                datasets_list.append(text_only_dataset)
                datasets_list_torch.append(text_only_dataset_torch)
                text_only_datasets.append(text_only_dataset_torch)

    if image_only_splits is not None:
        if image_only_splits != "":
            image_only_dataset = LXMERTDataset(image_only_splits, qa_sets=qa_sets)
            image_only_dataset_torch = LXMERTTorchDataset(
                image_only_dataset, topk, image_only=True,
                use_visual_tag_flag=args.get("use_visual_tag_flag", False))
            datasets_list.append(image_only_dataset)
            datasets_list_torch.append(image_only_dataset_torch)

        if args.get("add_adhoc_google_cc_image_only", False):
            google_cc_dataset = LXMERTDataset("google_cc_train", qa_sets=qa_sets)
            google_cc_dataset_torch = LXMERTTorchDataset(
                google_cc_dataset, topk, image_only=True,
                use_visual_tag_flag=args.get("use_visual_tag_flag", False),
                available_split_for_cc=args.get("available_split_for_cc", [0]))
            datasets_list.append(google_cc_dataset)
            datasets_list_torch.append(google_cc_dataset_torch)

        if args.get("add_adhoc_open_image_image_only", False):
            open_image_dataset = LXMERTDataset("open_images_train", qa_sets=qa_sets)
            open_image_torch = LXMERTTorchDataset(
                open_image_dataset, topk, image_only=True,
                use_visual_tag_flag=args.get("use_visual_tag_flag", False))
            datasets_list.append(open_image_dataset)
            datasets_list_torch.append(open_image_torch)

    # Merge different datasets
    merged_dataset = ConcateDataset(datasets_list_torch)

    if args.task_qa:
        merged_dataset.answer_table = datasets_list[0].answer_table if datasets_list[0] is not None else None
    
    batch_sampler = CustomBatchSampler(merged_dataset.datasets, bs,
                                       upsample_ratios=args.get("upsample_ratios", [1, 1, 1]))
    if args.get('custom_collact_fn', False):
        try:
            custom_collact_fn = datasets_list_torch[0].custom_collact_fn
        except AttributeError:
            # The first dataset may be a plain text corpus without its own
            # collate function; fall back to the last dataset's.
            custom_collact_fn = datasets_list_torch[-1].custom_collact_fn
    else:
        custom_collact_fn = lambda x: x
    data_loader = DataLoader(
        merged_dataset, num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=custom_collact_fn,
        pin_memory=args.get("pin_memory", True)
    )
    if args.task_qa:
        # The evaluator is only needed when the QA task is enabled.
        evaluator = LXMERTEvaluator(datasets_list[0]) if datasets_list[0] is not None else None
    else:
        evaluator = None
    print()

    if splits is not None:
        vl_torchdset = vl_dataset_torch
    else:
        vl_torchdset = datasets_list_torch[-1]  # the last dataset added

    return DataTuple(dataset=merged_dataset, torchdset=merged_dataset, loader=data_loader, evaluator=evaluator, vl_torchdset=vl_torchdset)
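
CustomBatchSampler itself is project-specific and not shown in these examples. The sketch below is a hypothetical stand-in, not the uclanlp/visualbert implementation: it assumes each yielded batch holds indices from exactly one sub-dataset of the concatenation, and that smaller sources are repeated by integer upsample_ratios.

import random

class GroupedBatchSamplerSketch:
    """Hypothetical sampler: yields lists of indices into a
    ConcatDataset-style dataset, one sub-dataset per batch."""

    def __init__(self, datasets, batch_size, upsample_ratios=None):
        ratios = upsample_ratios or [1] * len(datasets)
        self.batches = []
        offset = 0  # each sub-dataset's indices are offset, as in ConcatDataset
        for dset, ratio in zip(datasets, ratios):
            indices = list(range(offset, offset + len(dset))) * int(ratio)
            random.shuffle(indices)
            # Chunk into full batches; any remainder is dropped.
            for i in range(0, len(indices) - batch_size + 1, batch_size):
                self.batches.append(indices[i:i + batch_size])
            offset += len(dset)

    def __iter__(self):
        random.shuffle(self.batches)  # interleave the sources across an epoch
        return iter(self.batches)

    def __len__(self):
        return len(self.batches)

Any iterable of index lists is accepted by DataLoader's batch_sampler argument, which is why the loader above is constructed with batch_sampler=... and no batch_size, shuffle, or drop_last.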