Example no. 1
    def __init__(self, max_seq_length):
        super().__init__()
        self.max_seq_length = max_seq_length

        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased",
                                                       do_lower_case=True)

        # Build model
        set_visual_config(args)
        self.model = LXRTPretraining.from_pretrained(
            "bert-base-uncased",
            task_mask_lm=args.task_mask_lm,
            task_obj_predict=args.task_obj_predict,
            task_matched=args.task_matched,
            task_qa=args.task_qa,
            visual_losses=args.visual_losses,
            num_answers=train_dset.answer_table.num_answers)

        # Weight initialization and loading
        if args.from_scratch:
            print("Train from Scratch: re-initialize all BERT weights.")
            self.model.apply(self.model.init_bert_weights)
        if args.load is not None:
            self.load(args.load)
        if args.load_lxmert is not None:
            # Loading LXMERT weights does not load the answer head.
            self.load_lxmert(args.load_lxmert)

        # GPU Options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model = nn.DataParallel(self.model)
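
A note on the load_lxmert step above: the comment says that loading LXMERT weights skips the answer head. As a minimal sketch only (the checkpoint format and the "answer_head" key prefix are assumptions, not the repository's actual layout), such a loader could filter the state dict and load it non-strictly:

    import torch

    def load_lxmert_sketch(model, path):
        # Hypothetical loader: restore encoder weights, skip the answer head.
        state_dict = torch.load(path, map_location="cpu")
        # Drop parameters belonging to the task-specific answer head
        # (the key prefix here is illustrative).
        filtered = {k: v for k, v in state_dict.items()
                    if not k.startswith("answer_head")}
        # strict=False tolerates the keys that were filtered out.
        missing, unexpected = model.load_state_dict(filtered, strict=False)
        print("Missing keys:", missing)
        print("Unexpected keys:", unexpected)
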
Example no. 2
    def __init__(self, args, max_seq_length, mode='x'):
        super().__init__()
        self.max_seq_length = max_seq_length

        from lxrt.modeling import LXRTFeatureExtraction as VisualBertForLXRFeature, VISUAL_CONFIG

        set_visual_config(args, VISUAL_CONFIG)

        # Using the bert tokenizer
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased",
                                                       do_lower_case=True)

        # Build LXRT Model
        self.model = VisualBertForLXRFeature.from_pretrained(
            "bert-base-uncased", mode=mode)

        if args.from_scratch:
            print("Re-initializing all the weights")
            self.model.apply(self.model.init_bert_weights)

        self.load_pretrain_head = args.get("load_pretrain_head", False)
        if self.load_pretrain_head:
            from lxmert.src.lxrt.modeling import BertPreTrainingHeads
            self.pretrained_head = BertPreTrainingHeads(
                self.model.config,
                self.model.bert.embeddings.word_embeddings.weight)
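
The pretraining head above receives the word-embedding matrix, i.e. its masked-language-model decoder is tied to the input embeddings. A generic sketch of that weight-tying pattern with plain torch.nn modules (illustration only, not the repository's BertPreTrainingHeads):

    import torch.nn as nn

    vocab_size, hidden = 30522, 768           # bert-base-uncased sizes
    embeddings = nn.Embedding(vocab_size, hidden)

    # The decoder maps hidden states back to vocabulary logits and shares
    # its weight matrix with the input embedding table.
    decoder = nn.Linear(hidden, vocab_size, bias=False)
    decoder.weight = embeddings.weight        # weight tying

    assert decoder.weight.data_ptr() == embeddings.weight.data_ptr()
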
Example no. 3
    def __init__(self, num_answers):
        super().__init__()

        # Build LXRT encoder
        self.lxrt_encoder = LXRTEncoder(args, max_seq_length=MAX_VQA_LENGTH)
        hid_dim = self.lxrt_encoder.dim

        # VQA Answer heads
        self.logit_fc = nn.Sequential(nn.Linear(hid_dim, hid_dim * 2), GeLU(),
                                      BertLayerNorm(hid_dim * 2, eps=1e-12),
                                      nn.Linear(hid_dim * 2, num_answers))
        self.logit_fc.apply(self.lxrt_encoder.model.init_bert_weights)

        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased",
                                                       do_lower_case=True)
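
The answer head above widens the pooled cross-modal feature from hid_dim to 2 * hid_dim before projecting to num_answers logits. A standalone sketch of the same shape flow, substituting nn.GELU and nn.LayerNorm for the repository's GeLU and BertLayerNorm and using illustrative sizes:

    import torch
    import torch.nn as nn

    hid_dim, num_answers = 768, 3129          # illustrative sizes
    logit_fc = nn.Sequential(
        nn.Linear(hid_dim, hid_dim * 2),      # widen the fused feature
        nn.GELU(),
        nn.LayerNorm(hid_dim * 2, eps=1e-12),
        nn.Linear(hid_dim * 2, num_answers))  # one logit per candidate answer

    x = torch.randn(4, hid_dim)               # batch of pooled features
    print(logit_fc(x).shape)                  # torch.Size([4, 3129])
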
Example no. 4
    def __init__(self, args, max_seq_length, mode='x'):
        super().__init__()
        self.max_seq_length = max_seq_length
        set_visual_config(args)

        # Using the bert tokenizer
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased",
                                                       do_lower_case=True)

        # Build LXRT Model
        self.model = VisualBertForLXRFeature.from_pretrained(
            "bert-base-uncased", mode=mode)

        if args.from_scratch:
            print("initializing all the weights")
            self.model.apply(self.model.init_bert_weights)
Example no. 5
    def __init__(self, files, file_type='train-match', batch_size=256):
        self.files = files
        self.file_type = file_type
        self.batch_size = batch_size

        # generate dictionary of labels
        self.dict_multimodal_labels = {}
        with open(os.path.join(KDD_DATA, "../data/multimodal_labels.txt"),
                  encoding='utf-8') as f:
            for line in f:
                arr = line.strip().split("\t")
                label = arr[1].replace(",", " ").replace(".", " ").replace(
                    "(", " ").replace(")", " ")
                self.dict_multimodal_labels[arr[0]] = label.strip()

        self.tokenizer = BertTokenizer.from_pretrained("../user_data",
                                                       do_lower_case=True)
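
The label file above is read as tab-separated lines whose second field is lightly normalized: commas, periods, and parentheses are replaced by spaces. The same cleaning step in isolation, with a made-up line for illustration:

    def clean_label(raw):
        # Treat separating punctuation as whitespace, then trim the ends.
        for ch in (",", ".", "(", ")"):
            raw = raw.replace(ch, " ")
        return raw.strip()

    line = "123\tShoes, Sneakers (Running)"    # made-up example line
    key, raw_label = line.strip().split("\t")
    print(key, "->", clean_label(raw_label))   # 123 -> Shoes  Sneakers  Running
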
Example no. 6
    def __init__(self, args, max_seq_length, mode='x'):
        super().__init__()
        self.max_seq_length = max_seq_length

        # Using the bert tokenizer
        self.tokenizer = BertTokenizer.from_pretrained(
            "bert-base-uncased",
            do_lower_case=True
        )

        config = BertConfig.from_json_file(args.config_file)
        # Build Model
        self.model = InterBertForVLTasks.from_pretrained(
            "snap/pretrained/inter_bert/pytorch_model.bin", config
        )

        if args.from_scratch:
            print("initializing all the weights")
            self.model.apply(self.model.init_bert_weights)
Example no. 7
    def __init__(self, max_seq_length):
        super().__init__()
        self.max_seq_length = max_seq_length

        self.tokenizer = BertTokenizer.from_pretrained(
            "bert-base-uncased",
            do_lower_case=True
        )

        # Build model
        self.model = LXRTPretraining.from_pretrained(
            "bert-base-uncased",
            args=args,
            task_mask_lm=args.task_mask_lm,
            task_obj_predict=args.task_obj_predict,
            task_matched=args.task_matched,
            task_qa=args.task_qa,
            visual_losses=args.visual_losses,
            num_answers=(args.num_answers if args.get("num_answers", None)
                         else train_tuple.dataset.answer_table.num_answers)
        )

        # Weight initialization and loading
        if args.from_scratch:
            print("Train from Scratch: re-initialize all BERT weights.")
            self.model.apply(self.model.init_bert_weights)

        if args.get("use_tag_symbolic_embedding", False):
            self.model.bert.embeddings.initialize_symbolic_embeddings(symbolic_vocab.get_symbolic_list(self.tokenizer))
            self.model.special_initialize_pretraining_head()
        
        if args.get("hybrid_embedding", False):
            self.model.bert.embeddings.initialize_visual_position_type_embeddings()
        
        if args.load_lxmert is not None:
            # Loading LXMERT weights does not load the answer head.
            self.load_lxmert(args.load_lxmert)
        
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model = nn.DataParallel(self.model)
        
        self.global_step = 0
Example no. 8
    def __init__(self, args, max_seq_length, mode='x', attention=False):
        super().__init__()
        print(f"Making {__name__}")
        self.max_seq_length = max_seq_length
        set_visual_config(args)

        # Using the bert tokenizer
        self.tokenizer = BertTokenizer.from_pretrained(
            "bert-base-uncased",
            do_lower_case=True
        )
        print("Made Tokenizer")
        # Build LXRT Model
        self.model = VisualBertForLXRFeature.from_pretrained(
            "bert-base-uncased",
            mode=mode, 
            attention=attention
        )
        print("Made VisualBertForLXRFeature")
        if args.from_scratch:
            print("initializing all the weights")
            self.model.apply(self.model.init_bert_weights)
        print(f"Done {__name__}")
Example no. 9
    def __init__(self,
                 ann_file,
                 pretrained_model_name,
                 tokenizer=None,
                 seq_len=64,
                 min_seq_len=64,
                 encoding="utf-8",
                 on_memory=True,
                 **kwargs):
        assert on_memory, "only support on_memory mode!"

        self.tokenizer = tokenizer if tokenizer is not None else BertTokenizer.from_pretrained(
            pretrained_model_name)
        self.vocab = self.tokenizer.vocab
        self.seq_len = seq_len
        self.min_seq_len = min_seq_len
        self.on_memory = on_memory
        self.ann_file = ann_file
        self.encoding = encoding
        self.test_mode = False

        self.do_no_fill = False

        self.use_mismatch_objective = args.get("task_matched", False)
        #self.load_corpus_with_passages()

        # load samples into memory
        if on_memory:
            if self.use_mismatch_objective:
                #self.corpus = self.load_corpus_with_passages_preprocess()
                self.load_corpus_with_passages_preprocess()
            else:
                self.corpus = self.load_corpus()
        if args.get("presegment_sentence", False):
            self.presegment_sentence()
        print("Using {} with {} data.\n\n".format(self.ann_file, len(self)))
Example no. 10
    def __init__(self, dataset: VQADataset, args):
        super().__init__()
        self.raw_dataset = dataset

        if args.tiny:
            topk = TINY_IMG_NUM
        elif args.fast:
            topk = FAST_IMG_NUM
        else:
            topk = None

        self.limit_to_symbolic_split = args.get("limit_to_symbolic_split",
                                                False)
        if self.limit_to_symbolic_split:
            dataDir = "/local/harold/ubert/bottom-up-attention/data/vg/"
            coco_ids = set()
            self.mapping_cocoid_to_imageid = {}
            with open(os.path.join(dataDir, 'image_data.json')) as f:
                metadata = json.load(f)
                for item in metadata:
                    if item['coco_id']:
                        coco_ids.add(int(item['coco_id']))
                        self.mapping_cocoid_to_imageid[int(
                            item['coco_id'])] = item["image_id"]

            from lib.data.vg_gqa import vg_gqa
            self.vg_gqa = vg_gqa(
                None,
                split="val" if self.raw_dataset.name == "minival" else "train",
                transforms=None,
                num_im=-1)

        self.custom_coco_data = args.get("custom_coco_data", False)
        self.use_h5_file = args.get("use_h5_file", False)
        if self.use_h5_file:
            self.image_feature_dataset = ImageFeatureDataset.create(
                dataset.splits,
                Split2ImgFeatPath,
                on_memory=args.get("on_memory", False))
            self.ids_to_index = self.image_feature_dataset.ids_to_index

            # Screen data
            used_data = []
            for datum in self.raw_dataset.data:
                if datum['img_id'] in self.ids_to_index:
                    used_data.append(datum)
        else:
            # Loading detection features to img_data
            img_data = []
            for split in dataset.splits:
                # Minival is the 5K-image MS COCO subset used for evaluating VQA / LXMERT pre-training.
                # It is saved as the top 5K features in val2014_***.tsv
                load_topk = 5000 if (split == 'minival'
                                     and topk is None) else topk
                img_data.extend(
                    load_obj_tsv(os.path.join(
                        MSCOCO_IMGFEAT_ROOT,
                        '%s_obj36.tsv' % (SPLIT2NAME[split])),
                                 topk=load_topk))

            # Convert img list to dict
            self.imgid2img = {}
            for img_datum in img_data:
                self.imgid2img[img_datum['img_id']] = img_datum

            used_data = self.raw_dataset.data

        used_data = used_data[::args.get("partial_dataset", 1)]
        self.data = used_data

        # Only data with loaded image features is kept
        print("Use %d data in torch dataset" % (len(self.data)))
        print()

        if args.get("add_tags", False):
            self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased",
                                                           do_lower_case=True)
            from lxrt.symbolic_vocabulary import SymbolicVocab
            self.symbolic_vocab = SymbolicVocab(args.objects_vocab,
                                                args.attributes_vocab)
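
Both branches above end by keeping only the questions whose img_id has loaded features, after which the list is optionally strided by partial_dataset. A compact sketch of that screening and subsampling, with made-up ids:

    def screen_and_subsample(data, available_ids, stride=1):
        # Keep entries whose image features are available, then take every stride-th one.
        kept = [d for d in data if d['img_id'] in available_ids]
        return kept[::stride]

    data = [{'img_id': 'img_0'}, {'img_id': 'img_1'}, {'img_id': 'img_2'}]
    print(screen_and_subsample(data, {'img_0', 'img_2'}))   # keeps img_0 and img_2
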
Example no. 11
    def predict(self, eval_tuple: DataTuple, dump=None):
        """
        Predict the answers to questions in a data split.

        :param eval_tuple: The data tuple to be evaluated.
        :param dump: The path of saved file to dump results.
        :return: A dict of question_id to answer.
        """
        self.model.eval()
        dset, loader, evaluator = eval_tuple

        question_id2img_id = {x["question_id"]: x["img_id"] for x in dset.data}
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased",
                                                  do_lower_case=True)
        plt.rcParams['figure.figsize'] = (12, 10)
        num_regions = 36

        count = 0

        quesid2ans = {}
        for datum_tuple in loader:
            ques_id, feats, boxes, sent = datum_tuple[:4]  # Avoid seeing ground truth
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                logit = self.model(feats, boxes, sent)

                for layer in [0, 4]:
                    for head in [0, 1]:
                        for datapoint in range(len(sent)):
                            print(count, len(sent))
                            count += 1
                            lang2vis_attention_probs = self.model.lxrt_encoder.model.bert.encoder.x_layers[
                                layer].lang_att_map[datapoint][head].detach(
                                ).cpu().numpy()

                            vis2lang_attention_probs = self.model.lxrt_encoder.model.bert.encoder.x_layers[
                                layer].visn_att_map[datapoint][head].detach(
                                ).cpu().numpy()

                            plt.clf()

                            plt.subplot(2, 3, 1)
                            plt.gca().set_axis_off()
                            plt.title("Image (regions 0-7)")
                            im = cv2.imread(
                                os.path.join(
                                    "/mnt/8tera/claudio.greco/mscoco_trainval_2014",
                                    question_id2img_id[
                                        ques_id[datapoint].item()]) + ".jpg")
                            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
                            plt.imshow(im)

                            plt.subplot(2, 3, 2)
                            plt.gca().set_axis_off()
                            plt.title("Image (regions 8-15)")
                            im = cv2.imread(
                                os.path.join(
                                    "/mnt/8tera/claudio.greco/mscoco_trainval_2014",
                                    question_id2img_id[
                                        ques_id[datapoint].item()]) + ".jpg")
                            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
                            plt.imshow(im)

                            plt.subplot(2, 3, 3)
                            plt.gca().set_axis_off()
                            plt.title("Image (regions 16-35)")
                            im = cv2.imread(
                                os.path.join(
                                    "/mnt/8tera/claudio.greco/mscoco_trainval_2014",
                                    question_id2img_id[
                                        ques_id[datapoint].item()]) + ".jpg")
                            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
                            plt.imshow(im)

                            img_info = loader.dataset.imgid2img[
                                question_id2img_id[ques_id[datapoint].item()]]
                            img_h, img_w = img_info['img_h'], img_info['img_w']
                            unnormalized_boxes = boxes[datapoint].clone()
                            unnormalized_boxes[:, (0, 2)] *= img_w
                            unnormalized_boxes[:, (1, 3)] *= img_h

                            for i, bbox in enumerate(unnormalized_boxes):
                                if i < 8:
                                    plt.subplot(2, 3, 1)
                                elif i < 16:
                                    plt.subplot(2, 3, 2)
                                else:
                                    plt.subplot(2, 3, 3)

                                bbox = [
                                    bbox[0].item(), bbox[1].item(),
                                    bbox[2].item(), bbox[3].item()
                                ]

                                if bbox[0] == 0:
                                    bbox[0] = 2
                                if bbox[1] == 0:
                                    bbox[1] = 2

                                plt.gca().add_patch(
                                    plt.Rectangle((bbox[0], bbox[1]),
                                                  bbox[2] - bbox[0] - 4,
                                                  bbox[3] - bbox[1] - 4,
                                                  fill=False,
                                                  edgecolor='red',
                                                  linewidth=1))

                                plt.gca().text(bbox[0],
                                               bbox[1] - 2,
                                               '%s' % i,
                                               bbox=dict(facecolor='blue'),
                                               fontsize=9,
                                               color='white')

                            ax = plt.subplot(2, 1, 2)
                            plt.title("Cross-modal attention lang2vis")

                            tokenized_question = tokenizer.tokenize(
                                sent[datapoint])
                            tokenized_question = [
                                "<CLS>"
                            ] + tokenized_question + ["<SEP>"]

                            transposed_attention_map = lang2vis_attention_probs[:len(
                                tokenized_question), :num_regions]
                            im = plt.imshow(transposed_attention_map,
                                            vmin=0,
                                            vmax=1)

                            for i in range(len(tokenized_question)):
                                for j in range(num_regions):
                                    att_value = round(
                                        transposed_attention_map[i, j], 1)
                                    text = ax.text(
                                        j,
                                        i,
                                        att_value,
                                        ha="center",
                                        va="center",
                                        color="w" if att_value <= 0.5 else "b",
                                        fontsize=6)

                            ax.set_xticks(np.arange(num_regions))
                            ax.set_xticklabels(list(range(num_regions)))

                            ax.set_yticks(np.arange(len(tokenized_question)))
                            ax.set_yticklabels(tokenized_question)

                            plt.tight_layout()
                            # plt.gca().set_axis_off()
                            plt.savefig(
                                "/mnt/8tera/claudio.greco/guesswhat_lxmert/guesswhat/visualization_vqa/lang2vis_question_{}_layer_{}_head_{}.png"
                                .format(ques_id[datapoint].item(), layer,
                                        head),
                                bbox_inches='tight',
                                pad_inches=0.5)

                            plt.close()

                            ## vis2lang

                            plt.clf()

                            plt.subplot(2, 3, 1)
                            plt.gca().set_axis_off()
                            plt.title("Image (regions 0-7)")
                            im = cv2.imread(
                                os.path.join(
                                    "/mnt/8tera/claudio.greco/mscoco_trainval_2014",
                                    question_id2img_id[
                                        ques_id[datapoint].item()]) + ".jpg")
                            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
                            plt.imshow(im)

                            plt.subplot(2, 3, 2)
                            plt.gca().set_axis_off()
                            plt.title("Image (regions 8-15)")
                            im = cv2.imread(
                                os.path.join(
                                    "/mnt/8tera/claudio.greco/mscoco_trainval_2014",
                                    question_id2img_id[
                                        ques_id[datapoint].item()]) + ".jpg")
                            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
                            plt.imshow(im)

                            plt.subplot(2, 3, 3)
                            plt.gca().set_axis_off()
                            plt.title("Image (regions 16-35)")
                            im = cv2.imread(
                                os.path.join(
                                    "/mnt/8tera/claudio.greco/mscoco_trainval_2014",
                                    question_id2img_id[
                                        ques_id[datapoint].item()]) + ".jpg")
                            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
                            plt.imshow(im)

                            img_info = loader.dataset.imgid2img[
                                question_id2img_id[ques_id[datapoint].item()]]
                            img_h, img_w = img_info['img_h'], img_info['img_w']
                            unnormalized_boxes = boxes[datapoint].clone()
                            unnormalized_boxes[:, (0, 2)] *= img_w
                            unnormalized_boxes[:, (1, 3)] *= img_h

                            for i, bbox in enumerate(unnormalized_boxes):
                                if i < 8:
                                    plt.subplot(2, 3, 1)
                                elif i < 16:
                                    plt.subplot(2, 3, 2)
                                else:
                                    plt.subplot(2, 3, 3)

                                bbox = [
                                    bbox[0].item(), bbox[1].item(),
                                    bbox[2].item(), bbox[3].item()
                                ]

                                if bbox[0] == 0:
                                    bbox[0] = 2
                                if bbox[1] == 0:
                                    bbox[1] = 2

                                plt.gca().add_patch(
                                    plt.Rectangle((bbox[0], bbox[1]),
                                                  bbox[2] - bbox[0] - 4,
                                                  bbox[3] - bbox[1] - 4,
                                                  fill=False,
                                                  edgecolor='red',
                                                  linewidth=1))

                                plt.gca().text(bbox[0],
                                               bbox[1] - 2,
                                               '%s' % i,
                                               bbox=dict(facecolor='blue'),
                                               fontsize=9,
                                               color='white')

                            ax = plt.subplot(2, 1, 2)
                            plt.title("Cross-modal attention vis2lang")

                            tokenized_question = tokenizer.tokenize(
                                sent[datapoint])
                            tokenized_question = [
                                "<CLS>"
                            ] + tokenized_question + ["<SEP>"]

                            transposed_attention_map = vis2lang_attention_probs.transpose(
                            )[:len(tokenized_question), :num_regions]
                            im = plt.imshow(transposed_attention_map,
                                            vmin=0,
                                            vmax=1)

                            for i in range(len(tokenized_question)):
                                for j in range(num_regions):
                                    att_value = round(
                                        transposed_attention_map[i, j], 1)
                                    text = ax.text(
                                        j,
                                        i,
                                        att_value,
                                        ha="center",
                                        va="center",
                                        color="w" if att_value <= 0.5 else "b",
                                        fontsize=6)

                            ax.set_xticks(np.arange(num_regions))
                            ax.set_xticklabels(list(range(num_regions)))

                            ax.set_yticks(np.arange(len(tokenized_question)))
                            ax.set_yticklabels(tokenized_question)

                            plt.tight_layout()
                            # plt.gca().set_axis_off()
                            plt.savefig(
                                "/mnt/8tera/claudio.greco/guesswhat_lxmert/guesswhat/visualization_vqa/vis2lang_question_{}_layer_{}_head_{}.png"
                                .format(ques_id[datapoint].item(), layer,
                                        head),
                                bbox_inches='tight',
                                pad_inches=0.5)

                            plt.close()

                            # print(datapoint, len(sent))
                    #
                    #         print(datapoint)
                    #         if datapoint > 20:
                    #             break
                    #     if datapoint > 20:
                    #         break
                    # if datapoint > 20:
                    #     break

                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans
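
The visualization above rescales the normalized (x1, y1, x2, y2) boxes back to pixel space before drawing. That step on its own, as a small sketch:

    import torch

    def unnormalize_boxes(boxes, img_w, img_h):
        # Scale normalized (x1, y1, x2, y2) boxes to pixel coordinates.
        out = boxes.clone()
        out[:, (0, 2)] *= img_w   # x coordinates
        out[:, (1, 3)] *= img_h   # y coordinates
        return out

    boxes = torch.tensor([[0.1, 0.2, 0.5, 0.8]])
    print(unnormalize_boxes(boxes, img_w=640, img_h=480))
    # roughly tensor([[ 64.,  96., 320., 384.]])
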
Example no. 12
    def __init__(self,
                 dataset: LXMERTDataset,
                 topk=-1,
                 sgg_dataset=None,
                 image_only=False,
                 text_only=False,
                 use_visual_tag_flag=False,
                 limit_source=[],
                 available_split_for_cc=None):
        super().__init__()
        self.raw_dataset = dataset
        self.name = '_'.join(self.raw_dataset.sources)
        if args.get('disable_mismatch_for_other_dataset', False):
            # Do not resample for datasets such as BookCorpus
            self.task_matched = args.task_matched if "book_corpus" in self.raw_dataset.sources else False
        else:
            self.task_matched = args.task_matched

        print(self.raw_dataset.sources)
        print(self.task_matched)
        print("\n\n\n")
        self.sgg_dataset = sgg_dataset
        self.image_only = image_only
        self.text_only = text_only
        self.use_visual_tag_flag = use_visual_tag_flag
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased",
                                                       do_lower_case=True)
        self.task_nlvr2 = args.get("task_nlvr2", False)

        if args.tiny:
            topk = TINY_IMG_NUM
        elif args.fast:
            topk = FAST_IMG_NUM

        #self.fake_data = args.get("fake_data", False)
        self.custom_coco_data = args.get("custom_coco_data", False)
        self.use_h5_file = args.get("use_h5_file", False)
        if self.use_h5_file:
            if "google_cc_train" in dataset.sources:
                if args.get('change_split', False):
                    available_split_for_cc = [39]
                else:
                    available_split_for_cc = args.get("available_split_for_cc",
                                                      [0])
                sources = []
                split_map = {}
                for i in available_split_for_cc:
                    sources.append("google_cc_{}".format(i))
                    split_map["google_cc_{}".format(
                        i
                    )] = "data/google_concetual/butd_feat/train_no_features_split_{}_of_40_splits.h5".format(
                        i)
                self.image_feature_dataset = ImageFeatureDataset.create(
                    sources,
                    split_map,
                    load_custom_h5_version2=True,
                    text_only=self.text_only,
                    on_memory=False)
            elif "open_images_train" in dataset.sources:
                available_split_for_open_image = args.get(
                    "available_split_for_open_image", [0])
                sources = []
                split_map = {}
                for split_i, split_j, total_split in available_split_for_open_image:
                    sources.append("open_image_{}_{}".format(split_i, split_j))
                    split_map["open_image_{}_{}".format(
                        split_i, split_j
                    )] = "data/open_image/butd_feat/train_{}_no_features_split_{}_of_{}_splits.h5".format(
                        split_i, split_j, total_split)
                self.image_feature_dataset = ImageFeatureDataset.create(
                    sources,
                    split_map,
                    load_custom_h5_version2=True,
                    on_memory=False)
            else:
                self.image_feature_dataset = ImageFeatureDataset.create(
                    dataset.sources,
                    Split2ImgFeatPath_h5,
                    text_only=self.text_only,
                    load_custom_h5_version2=True
                    if "flickr_train" in dataset.sources else False,
                    on_memory=args.get("on_memory", False))
            self.ids_to_index = self.image_feature_dataset.ids_to_index

            # Screen data
            used_data = []
            for datum in self.raw_dataset.data:
                if datum['img_id'] in self.ids_to_index:
                    used_data.append(datum)
        else:
            # Original LXMERT. Load the dataset
            img_data = []
            for source in self.raw_dataset.sources:
                img_data.extend(load_obj_tsv(Split2ImgFeatPath[source], topk))

            self.imgid2img = {}
            for img_datum in img_data:
                self.imgid2img[img_datum['img_id']] = img_datum

            # Filter out the dataset
            used_data = []
            for datum in self.raw_dataset.data:
                if datum['img_id'] in self.imgid2img:
                    used_data.append(datum)

        used_data = used_data[::args.get("partial_dataset", 1)]

        if sgg_dataset is not None:
            used_data = [
                datum for datum in used_data
                if str(datum["img_id"]) in self.sgg_dataset.imageids_to_index
            ]

        # Flatten the dataset (one sentence + one image per entry)
        self.data = []

        record_img_id = set()

        remaining_set = set()
        for datum in used_data:
            # datum: {'img_id': 'COCO_train2014_000000318556', 'labelf': {'vqa': [{'no': 1}, {'yes': 1}, {'no': 1}, {'blue': 1, 'blue and white': 0.3}]}, 'sentf': {'mscoco': ['A very clean and well decorated empty bathroom', 'A blue and white bathroom with butterfly themed wall tiles.', 'A bathroom with a border of butterflies and blue paint on the walls above it.', 'An angled view of a beautifully decorated bathroom.', 'A clock that blends in with the wall hangs in a bathroom. '], 'vqa': ['Is the sink full of water?', 'Are there any butterflies on the tiles?', 'Is this bathroom in a hotel?', 'What color are the walls?']}}

            sentf = datum['sentf']
            for sents_cat, sents in sentf.items():
                if sents_cat in limit_source:
                    continue

                remaining_set.add(sents_cat)

                if sents_cat in datum['labelf']:
                    labels = datum['labelf'][sents_cat]
                else:
                    labels = None
                for sent_idx, sent in enumerate(sents):
                    new_datum = {
                        'uid':
                        make_uid(datum['img_id'], sents_cat, sent_idx)
                        if args.task_qa else None,
                        'img_id':
                        datum['img_id'],  # if not self.text_only else "",
                        'sent':
                        sent  #if not self.image_only else ""
                    }
                    if image_only:  # If we only use images, make sure each image appears only once
                        if datum["img_id"] in record_img_id:
                            continue
                        record_img_id.add(datum["img_id"])

                    if labels is not None and args.task_qa:
                        new_datum['label'] = labels[sent_idx]

                    if self.task_nlvr2:
                        new_datum['match_label'] = datum["label"]
                        new_datum['img_id_1'] = datum["img_id_1"]

                    self.data.append(new_datum)

        if image_only:
            dataset_str = "image_only"
        elif text_only:
            dataset_str = "text_only"
        else:
            dataset_str = "vision and language"

        if self.image_only and args.get("screen_image", False):
            counter = 0
            from tqdm import tqdm
            _data = []
            for data_item in tqdm(self.data):
                img_id = data_item["img_id"]
                image_index = self.image_feature_dataset.ids_to_index[img_id]
                img_h = self.image_feature_dataset.h5_wh[image_index][1]
                img_w = self.image_feature_dataset.h5_wh[image_index][0]
                if img_h == 0 or img_w == 0:
                    counter += 1
                else:
                    _data.append(data_item)

            print(
                "Screened out {} images with zero height or width, {} remaining"
                .format(counter, len(_data)))
            self.data = _data

        print("Use {} data in {} torch dataset, {}, limit_source {}".format(
            len(self.data), dataset_str, remaining_set, limit_source))

        if text_only:
            del self.image_feature_dataset

        if text_only or image_only:
            del self.raw_dataset.data
            del self.raw_dataset

        self.compress_memory = False
        if args.get("compress_memory", False):
            # Move some data to shared memory so memory does not explode when using multiple worker processes for data loading
            self.compress()
        print("\n\n\n")
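
The flattening loop above turns each image record, which carries several sentence sources, into one (image, sentence) entry per sentence. A stripped-down sketch of that transformation, using the datum layout shown in the inline comment (QA labels and NLVR2 fields omitted):

    def flatten(used_data, limit_source=()):
        # One output entry per (image, sentence) pair, skipping limited sources.
        flat = []
        for datum in used_data:
            for sents_cat, sents in datum['sentf'].items():
                if sents_cat in limit_source:
                    continue
                for sent_idx, sent in enumerate(sents):
                    flat.append({'img_id': datum['img_id'],
                                 'source': sents_cat,
                                 'sent_idx': sent_idx,
                                 'sent': sent})
        return flat

    datum = {'img_id': 'COCO_train2014_000000318556',
             'sentf': {'mscoco': ['A very clean and well decorated empty bathroom'],
                       'vqa': ['Is the sink full of water?']}}
    print(len(flatten([datum])))   # 2
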