def get_tuple(splits: str, bs: int, shuffle=False, drop_last=False, topk=-1) -> DataTuple:
    """Build the (dataset, torch dataset, loader, evaluator) tuple for pre-training.

    NOTE(review): this definition is shadowed by a later ``get_tuple`` in the
    same file and is therefore dead code as written — confirm before removing.
    """
    # Decide which QA datasets would be used in pre-training.
    # Options: vqa, gqa, visual7w
    # Note: visual7w is a part of vgqa, we take the name here.
    qa_sets = args.qa_sets
    if qa_sets is not None:
        qa_sets = {name.lower().strip() for name in qa_sets.split(",")}

    # Build dataset, data loader, and evaluator.
    dset = LXMERTDataset(splits, qa_sets=qa_sets)
    tset = LXMERTTorchDataset(dset, topk)
    identity_collate = lambda batch: batch  # keep raw example lists; no tensor collation
    data_loader = DataLoader(
        tset,
        batch_size=bs,
        shuffle=shuffle,
        num_workers=args.num_workers,
        collate_fn=identity_collate,
        drop_last=drop_last,
        pin_memory=True,
    )
    evaluator = LXMERTEvaluator(dset)
    print()

    return DataTuple(dataset=dset, torchdset=tset, loader=data_loader, evaluator=evaluator)
def get_tuple(splits: str, bs: int, shuffle=False, drop_last=False, topk=-1,
              num_workers=0, limit_source=None, restrict_source=None) -> DataTuple:
    """Build the (dataset, torch dataset, loader, evaluator) tuple for evaluation.

    Args:
        splits: split string handed to ``LXMERTDataset``.
        bs: batch size.
        shuffle: shuffle the DataLoader.
        drop_last: drop the final incomplete batch.
        topk: truncate the torch dataset (-1 keeps everything).
        limit_source: forwarded to ``LXMERTTorchDataset``; ``None`` is
            normalized to an empty list (fixes the mutable-default-argument
            hazard of the original ``limit_source=[]``).
        restrict_source: accepted for signature compatibility; unused here.

    Returns:
        DataTuple with ``vl_torchdset`` aliased to the torch dataset.
    """
    # Fix: mutable default argument ([]) replaced by a None sentinel.
    if limit_source is None:
        limit_source = []

    # Decide which QA datasets would be used in pre-training.
    # Options: vqa, gqa, visual7w
    # Note: visual7w is a part of vgqa, we take the name here.
    qa_sets = args.qa_sets
    if qa_sets is not None:
        qa_sets = set(qa_set.lower().strip() for qa_set in qa_sets.split(","))

    # Build dataset, data loader, and evaluator.
    dset = LXMERTDataset(splits, qa_sets=qa_sets)
    tset = LXMERTTorchDataset(
        dset, topk,
        limit_source=limit_source,
        # As this function is called for evaluation in our context.
        use_visual_tag_flag=args.get("allow_tag_for_eval", False)
    )
    data_loader = DataLoader(
        tset, batch_size=bs,
        shuffle=shuffle, num_workers=num_workers,
        # NOTE: "collact" is the project's (misspelled) attribute/config-key
        # name; kept byte-identical for compatibility.
        collate_fn=tset.custom_collact_fn if args.get('custom_collact_fn', False) else lambda x: x,
        drop_last=drop_last,
        pin_memory=args.get("pin_memory", True)
    )
    evaluator = LXMERTEvaluator(dset)
    print()

    return DataTuple(dataset=dset, torchdset=tset, loader=data_loader,
                     evaluator=evaluator, vl_torchdset=tset)
# # def get_subsets_dl(splits: str, bs: int, shuffle=False, drop_last=False, topk=-1) -> DataTuple: # # # Build dataset and data loader # tset = IXTorchDataset(splits) # data_loader = DataLoader( # tset, batch_size=bs, # shuffle=shuffle, num_workers=args.num_workers, # collate_fn=lambda x: x, # drop_last=drop_last, pin_memory=True # ) # # return data_loader # train_tuple = get_tuple(args.train, args.batch_size, shuffle=True, drop_last=True) train_dset = LXMERTDataset(args.train, qa_sets=args.qa_sets) # train_subsets_dl = get_subsets_dl(args.train, args.batch_size*220, shuffle=True, drop_last=True) valid_batch_size = 2048 if args.multiGPU else 512 #valid_tuple = get_tuple(args.valid, valid_batch_size, shuffle=False, drop_last=False) #, topk=5000) class InputFeatures(object): """A single set of features of data.""" def __init__(self, input_ids, input_mask, segment_ids, lm_label_ids, visual_feats, obj_labels, is_matched, ans): self.input_ids = input_ids self.input_mask = input_mask self.segment_ids = segment_ids self.lm_label_ids = lm_label_ids self.visual_feats = visual_feats
def get_tuple_hybrid(splits: str, bs: int, shuffle=False, drop_last=False, num_workers=0,
                     topk=-1, image_only_splits=None, text_only_splits=None,
                     limit_source=None, restrict_source=None) -> DataTuple:
    """Build one merged DataLoader over V&L, text-only and image-only datasets.

    Args:
        splits: V&L splits for ``LXMERTDataset``, or None to skip V&L data.
        bs: per-step batch size handed to ``CustomBatchSampler``.
        shuffle, drop_last: accepted for signature compatibility; batching is
            driven by the batch sampler, so they are unused here.
        num_workers: DataLoader worker count.
        topk: truncate each torch dataset (-1 keeps everything).
        image_only_splits: image-only splits; "" skips the base image-only set
            while still allowing the ad-hoc Google CC / Open Images additions.
        text_only_splits: "+"-joined text-only splits; "book_corpus"/"sbu"
            entries are loaded from raw corpora instead of LXMERT annotations.
        limit_source: forwarded to ``LXMERTTorchDataset``; ``None`` is
            normalized to an empty list (fixes the mutable-default hazard of
            the original ``limit_source=[]``).
        restrict_source: accepted for signature compatibility; unused here.

    Returns:
        DataTuple where ``dataset``/``torchdset`` are the merged
        ``ConcateDataset`` and ``vl_torchdset`` is the V&L torch dataset (or
        the last dataset when no V&L splits were given).
    """
    # Fix: mutable default argument ([]) replaced by a None sentinel.
    if limit_source is None:
        limit_source = []

    # Decide which QA datasets would be used in pre-training.
    # Options: vqa, gqa, visual7w
    # Note: visual7w is a part of vgqa, we take the name here.
    qa_sets = args.qa_sets
    if qa_sets is not None:
        qa_sets = set(qa_set.lower().strip() for qa_set in qa_sets.split(","))

    # Three types of datasets: v&l, language-only, vision-only.
    datasets_list_torch = []
    datasets_list = []

    if splits is not None:
        vl_dataset = LXMERTDataset(splits, qa_sets=qa_sets)
        vl_dataset_torch = LXMERTTorchDataset(
            vl_dataset, topk,
            limit_source=limit_source,
            randomized_pairing=args.get("randomized_pairing", False),
            use_visual_tag_flag=args.get("use_visual_tag_flag", False))
        datasets_list.append(vl_dataset)
        datasets_list_torch.append(vl_dataset_torch)

    if text_only_splits is not None:
        text_only_datasets = []  # NOTE(review): collected but never read afterwards.
        # Corpus class for raw-text splits; may be set by a book_corpus split
        # and reused by a later split in the same loop (original behavior).
        text_class = None
        for split in text_only_splits.split("+"):
            if not ("book_corpus" in split or "sbu" in split):
                # Regular LXMERT split consumed in text-only mode.
                text_only_dataset = LXMERTDataset(split, qa_sets=qa_sets)
                text_only_dataset_torch = LXMERTTorchDataset(
                    text_only_dataset, topk, text_only=True, limit_source=limit_source)
                datasets_list.append(text_only_dataset)
                datasets_list_torch.append(text_only_dataset_torch)
                text_only_datasets.append(text_only_dataset_torch)
            else:
                # Raw corpora (BookCorpus / SBU) have no LXMERTDataset; the
                # None placeholder keeps datasets_list index-aligned with
                # datasets_list_torch (downstream code checks `is not None`).
                text_only_dataset = None
                if "book_corpus" in split and args.get("text_shared_memory", False):
                    text_class = GeneralCorpusNP
                # else: the original `text_class = GeneralCorpus` assignment is
                # commented out upstream; text_class keeps its current value.
                if text_class is None:
                    # BUG(fixed): the original fell through with text_class
                    # unbound and crashed with a bare NameError at the call
                    # below. Same exception type, but with an explanation.
                    raise NameError(
                        "text_class is unset: non-shared-memory text corpus "
                        "(GeneralCorpus) support is disabled in this build; "
                        "enable args.text_shared_memory for book_corpus splits.")
                text_only_dataset_torch = text_class(
                    ann_file=args.book_corpus_path if "book_corpus" in split else args.sbu_path,
                    pretrained_model_name="bert-base-uncased",
                    tokenizer=None,
                    seq_len=args.get("text_only_max_seq_len", 64),
                    min_seq_len=args.get("text_only_min_seq_len", 64),
                    encoding="utf-8",
                    on_memory=True)
                datasets_list.append(text_only_dataset)
                datasets_list_torch.append(text_only_dataset_torch)
                text_only_datasets.append(text_only_dataset_torch)

    if image_only_splits is not None:
        if image_only_splits != "":
            image_only_dataset = LXMERTDataset(image_only_splits, qa_sets=qa_sets)
            image_only_dataset_torch = LXMERTTorchDataset(
                image_only_dataset, topk, image_only=True,
                use_visual_tag_flag=args.get("use_visual_tag_flag", False))
            datasets_list.append(image_only_dataset)
            datasets_list_torch.append(image_only_dataset_torch)
        if args.get("add_adhoc_google_cc_image_only", False):
            google_cc_dataset = LXMERTDataset("google_cc_train", qa_sets=qa_sets)
            google_cc_dataset_torch = LXMERTTorchDataset(
                google_cc_dataset, topk, image_only=True,
                use_visual_tag_flag=args.get("use_visual_tag_flag", False),
                available_split_for_cc=args.get("available_split_for_cc", [0]))
            datasets_list.append(google_cc_dataset)
            datasets_list_torch.append(google_cc_dataset_torch)
        if args.get("add_adhoc_open_image_image_only", False):
            open_image_dataset = LXMERTDataset("open_images_train", qa_sets=qa_sets)
            open_image_torch = LXMERTTorchDataset(
                open_image_dataset, topk, image_only=True,
                use_visual_tag_flag=args.get("use_visual_tag_flag", False))
            datasets_list.append(open_image_dataset)
            datasets_list_torch.append(open_image_torch)

    # Merge the different datasets into one concatenated dataset.
    merged_dataset = ConcateDataset(datasets_list_torch)
    if args.task_qa:
        merged_dataset.answer_table = datasets_list[0].answer_table if datasets_list[0] is not None else None
    batch_sampler = CustomBatchSampler(
        merged_dataset.datasets, bs,
        upsample_ratios=args.get("upsample_ratios", [1, 1, 1]))

    # NOTE: "collact" is the project's (misspelled) attribute/config-key name;
    # kept for compatibility. Fall back to the last dataset's collate fn when
    # the first entry is missing or lacks the attribute (narrowed from the
    # original bare `except:`).
    try:
        custom_collact_fn = datasets_list_torch[0].custom_collact_fn if args.get('custom_collact_fn', False) else lambda x: x
    except (IndexError, AttributeError):
        custom_collact_fn = datasets_list_torch[-1].custom_collact_fn if args.get('custom_collact_fn', False) else lambda x: x

    data_loader = DataLoader(
        merged_dataset,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=custom_collact_fn,
        pin_memory=args.get("pin_memory", True)
    )

    if args.task_qa:
        # The evaluator is for task_qa so no need to have it otherwise.
        evaluator = LXMERTEvaluator(datasets_list[0]) if datasets_list[0] is not None else None
    else:
        evaluator = None
    print()

    if splits is not None:
        vl_torchdset = vl_dataset_torch
    else:
        vl_torchdset = datasets_list_torch[-1]  # the last dataset

    return DataTuple(dataset=merged_dataset, torchdset=merged_dataset,
                     loader=data_loader, evaluator=evaluator,
                     vl_torchdset=vl_torchdset)