    def __init__(self,
                 yaml_file,
                 tokenizer=None,
                 add_od_labels=True,
                 max_img_seq_length=50,
                 max_seq_length=70,
                 max_seq_a_length=40,
                 is_train=True,
                 mask_prob=0.15,
                 max_masked_tokens=3,
                 add_conf=False,
                 **kwargs):
        """Constructor.
        Args:
            yaml file with all required data (image feature, caption, labels, etc)
            tokenizer: tokenizer for text processing.
            add_od_labels: whether to add labels from yaml file to BERT. 
            max_img_seq_length: max image sequence length.
            max_seq_length: max text sequence length.
            max_seq_a_length: max caption sequence length.
            is_train: train or test mode.
            mask_prob: probability to mask a input token.
            max_masked_tokens: maximum number of tokens to be masked in one sentence.
            kwargs: other arguments.
        """
        self.yaml_file = yaml_file
        self.cfg = load_from_yaml_file(yaml_file)
        self.root = op.dirname(yaml_file)
        self.label_file = find_file_path_in_yaml(self.cfg['label'], self.root)
        self.feat_file = find_file_path_in_yaml(self.cfg['feature'], self.root)
        self.caption_file = find_file_path_in_yaml(self.cfg.get('caption'),
                                                   self.root)

        assert op.isfile(self.feat_file)
        if add_od_labels: assert op.isfile(self.label_file)
        if is_train:
            assert op.isfile(self.caption_file) and tokenizer is not None

        self.label_tsv = None if not self.label_file else TSVFile(
            self.label_file)
        self.feat_tsv = TSVFile(self.feat_file)
        if self.caption_file and op.isfile(self.caption_file):
            with open(self.caption_file, 'r') as f:
                self.captions = json.load(f)

        self.tokenizer = tokenizer
        self.tensorizer = CaptionTensorizer(self.tokenizer,
                                            max_img_seq_length,
                                            max_seq_length,
                                            max_seq_a_length,
                                            mask_prob,
                                            max_masked_tokens,
                                            is_train=is_train)
        self.add_od_labels = add_od_labels
        self.is_train = is_train
        self.kwargs = kwargs
        self.image_keys = self.prepare_image_keys()
        self.key2index = self.prepare_image_key_to_index()
        self.key2captions = self.prepare_image_key_to_captions()
        self.add_conf = add_conf
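
A minimal usage sketch for the constructor above (the class name CaptionDataset is hypothetical, standing in for whatever class this __init__ belongs to; the yaml path and tokenizer are placeholders):

from transformers import BertTokenizer

# train.yaml is hypothetical and must define at least:
#   feature: path to the image feature tsv   (always required)
#   label:   path to the od label tsv        (required when add_od_labels=True)
#   caption: path to the caption json        (required when is_train=True)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
dataset = CaptionDataset('coco_caption/train.yaml', tokenizer=tokenizer,
                         add_od_labels=True, is_train=True, add_conf=False)
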
Example 2
class OscarTSVDataset(Dataset):
    def __init__(self,
                 yaml_file,
                 args=None,
                 tokenizer=None,
                 seq_len=35,
                 encoding="utf-8",
                 corpus_lines=None,
                 on_memory=True,
                 **kwargs):
        self.cfg = load_from_yaml_file(yaml_file)
        self.root = os.path.dirname(yaml_file)
        self.vocab = tokenizer.vocab
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.on_memory = on_memory
        self.corpus_lines = corpus_lines  # number of non-empty lines in input corpus
        self.corpus_tsvfile = TSVFile(
            os.path.join(self.root, self.cfg['corpus_file']))
        if 'textb_sample_mode' in kwargs:
            self.textb_sample_mode = kwargs['textb_sample_mode']
        else:
            self.textb_sample_mode = args.textb_sample_mode

        self.datasets_names = self.cfg['corpus'].split('_')
        self.datasets_with_splits = [
            'googlecc', 'sbu', 'oi', 'objects365', 'tagoi'
        ]
        self.datasets_with_onesplit = ['coco', 'flickr30k', 'gqa']
        logging.info('Datasets: {}'.format(','.join(self.datasets_names)))
        self.image_label_path = self.cfg['image_label_path']
        for key, val in self.image_label_path.items():
            # get the absolute path
            if key in self.datasets_names:
                self.image_label_path[key] = os.path.join(self.root, val)
        self.image_feature_path = self.cfg['image_feature_path']
        self.image_file_name = 'features.tsv'
        if args.data_dir is not None:
            for key, val in self.image_feature_path.items():
                # get the absolute path
                if key in self.datasets_names:
                    self.image_feature_path[key] = os.path.join(
                        args.data_dir, val)
                else:
                    logging.info("Data {} with path {} is not used in the "
                                 "training.".format(key, val))
        self.encoding = encoding
        self.current_doc = 0  # to avoid random sentence from same doc
        self.current_img = ''  # to avoid random sentence from same image

        self.args = args

        # for loading samples directly from file
        self.sample_counter = 0  # used to keep track of full epochs on file
        self.line_buffer = None  # keep second sentence of a pair in memory and use as first sentence in next pair

        # for loading samples in memory
        self.current_random_doc = 0
        self.num_docs = 0
        self.sample_to_doc = []  # map sample index to doc and line

        self.chunk_list = None
        if 0 <= args.chunk_start_id <= args.chunk_end_id:
            self.chunk_list = [
                str(c_i)
                for c_i in range(args.chunk_start_id, args.chunk_end_id)
            ]
            logging.info('Chunk list: {}'.format(','.join(self.chunk_list)))

        # load image tags and features
        t_start = time.time()
        self.img_label_file = None
        self.img_qa_file = None
        self.img_label_offset_map = None
        self.img_qa_offset_map = None
        self.img_feature_file = None
        self.img_feat_offset_map = None
        self.load_img_labels()
        self.load_img_tsv_features()
        t_end = time.time()
        logging.info(
            'Info: loading img features using {} secs'.format(t_end - t_start))

        # load samples into memory
        if on_memory:
            self.all_docs = []
            self.all_qa_docs = []
            self.imgid2labels = {}
            self.corpus_lines = 0
            max_tokens = 0
            for line_no in tqdm(range(len(self.corpus_tsvfile))):
                doc = []
                row = self.corpus_tsvfile.seek(line_no)
                img_info = row[0].split('_')
                label_info = row[1].split('_')
                assert img_info[0] == label_info[
                    0], "Dataset names for image and label do not match!"
                dataset_name = label_info[0]
                if dataset_name == 'cc':
                    dataset_name = 'googlecc'

                if dataset_name not in self.datasets_names:
                    continue

                if dataset_name in self.datasets_with_splits:
                    chunk_id = img_info[-2]
                    if self.chunk_list is not None and chunk_id not in self.chunk_list:
                        continue
                    else:
                        img_feat_offset_map = self.img_feat_offset_map[
                            dataset_name][chunk_id]
                else:
                    img_feat_offset_map = self.img_feat_offset_map[
                        dataset_name]
                assert img_info[
                    -1] in img_feat_offset_map, "{}: Image id {} cannot be found in image feature imageid_to_index file!".format(
                        row[0], img_info[-1])

                # append id info
                doc.append('%s|%s' % (row[0], row[1]))
                # append text_a info
                self.corpus_lines = self.corpus_lines + 1
                sample = {"doc_id": len(self.all_docs), "line": len(doc)}
                self.sample_to_doc.append(sample)
                assert len(row[2]) != 0, "Text_a is empty in {} : {}"\
                    .format(dataset_name, row[0])
                doc.append(row[2])
                # append text_b info
                self.corpus_lines = self.corpus_lines + 1
                label_id = label_info[-1]
                if 'qa' in label_info:
                    assert img_info[-1] == label_info[
                        -2], "Image ids for image and qa do not match!"
                    label_line_no = self.img_qa_offset_map[dataset_name][
                        label_id]
                    rowb = self.img_qa_file[dataset_name].seek(label_line_no)
                else:
                    assert img_info[-1] == label_info[
                        -1], "Image ids for image and label do not match!"
                    label_line_no = self.img_label_offset_map[dataset_name][
                        label_id]
                    rowb = self.img_label_file[dataset_name].seek(
                        label_line_no)
                assert label_id == rowb[0]
                results = json.loads(rowb[1])
                if 'qa' not in label_info:
                    objects = results['objects']
                    if row[0] not in self.imgid2labels:
                        self.imgid2labels[row[0]] = {
                            "image_h": results["image_h"],
                            "image_w": results["image_w"],
                            "boxes": None
                        }
                    else:
                        assert results["image_h"] == self.imgid2labels[row[0]][
                            "image_h"], "Image_h does not match in image {}!".format(
                                row[0])
                        assert results["image_w"] == self.imgid2labels[row[0]][
                            "image_w"], "Image_w does not match in image {}!".format(
                                row[0])
                    if args.use_gtlabels and 'gt_objects' in results:
                        # use ground-truth tags for text_b
                        textb = ' '.join([
                            cur_d['class'] for cur_d in results["gt_objects"]
                        ])
                    else:
                        textb = ' '.join([cur_d['class'] for cur_d in objects])
                else:
                    tag_label_line_no = self.img_label_offset_map[
                        dataset_name][img_info[-1]]
                    tag_rowb = self.img_label_file[dataset_name].seek(
                        tag_label_line_no)
                    tag_results = json.loads(tag_rowb[1])
                    if row[0] not in self.imgid2labels:
                        self.imgid2labels[row[0]] = {
                            "image_h": tag_results["image_h"],
                            "image_w": tag_results["image_w"],
                            "boxes": None
                        }
                    else:
                        assert tag_results["image_h"] == self.imgid2labels[row[0]][
                            "image_h"], "Image_h does not match in image {}!".format(
                                row[0])
                        assert tag_results["image_w"] == self.imgid2labels[row[0]][
                            "image_w"], "Image_w does not match in image {}!".format(
                                row[0])
                    textb = ' '.join(results['labels'])
                assert len(textb) != 0, "Text_b is empty in {} : {}".format(
                    dataset_name, row[1])
                doc.append(textb)

                # add to all_docs
                max_tokens = max(
                    max_tokens,
                    len(doc[1].split(' ')) + len(doc[2].split(' ')))
                if 'qa' in label_info:
                    self.all_qa_docs.append({
                        "doc": doc,
                        "doc_id": len(self.all_docs)
                    })
                self.all_docs.append(doc)

            self.num_docs = len(self.all_docs)
            logging.info("Max_tokens: {}".format(max_tokens))
        # load samples later lazily from disk
        else:
            raise ValueError("on_memory = False Not supported yet!")

        logging.info("Total docs - Corpus_lines: {}-{}".format(
            self.num_docs, self.corpus_lines))
        logging.info("Total QA docs - Corpus_lines: {}".format(
            len(self.all_qa_docs)))

    def __len__(self):
        # last line of doc won't be used, because there's no "nextSentence".
        return self.corpus_lines - self.num_docs

    def get_img_info(self, idx):
        sample = self.sample_to_doc[idx]
        # img_id = self.all_docs[sample["doc_id"]][0].strip() # original
        img_id = self.all_docs[sample["doc_id"]][0].strip().split('|')[0]
        imgid2labels = self.imgid2labels[img_id]
        return {
            "height": imgid2labels["image_h"],
            "width": imgid2labels["image_w"]
        }

    def __getitem__(self, item):
        cur_id = self.sample_counter
        self.sample_counter += 1
        if not self.on_memory:
            # after one epoch we start again from beginning of file
            if cur_id != 0 and (cur_id % len(self) == 0):
                raise ValueError("on_memory = False Not supported yet!")

        img_id, t1, t2, is_next_label, is_img_match = self.random_sent(item)

        # tokenize
        tokens_a = self.tokenizer.tokenize(t1)
        if self.args.use_b:
            tokens_b = self.tokenizer.tokenize(t2)
        else:
            tokens_b = None

        # combine to one sample
        cur_example = InputExample(guid=cur_id,
                                   tokens_a=tokens_a,
                                   tokens_b=tokens_b,
                                   is_next=is_next_label,
                                   img_id=img_id,
                                   is_img_match=is_img_match)

        # get image feature
        img_feat = self.get_img_feature(img_id)
        if img_feat.shape[0] >= self.args.max_img_seq_length:
            img_feat = img_feat[0:self.args.max_img_seq_length, ]
            img_feat_len = img_feat.shape[0]
        else:
            img_feat_len = img_feat.shape[0]
            padding_matrix = torch.zeros(
                (self.args.max_img_seq_length - img_feat.shape[0],
                 img_feat.shape[1]))
            img_feat = torch.cat((img_feat, padding_matrix), 0)

        # transform sample to features
        cur_features = convert_example_to_features(self.args, cur_example,
                                                   self.seq_len,
                                                   self.tokenizer,
                                                   img_feat_len)

        return img_feat, (
            torch.tensor(cur_features.input_ids, dtype=torch.long),
            torch.tensor(cur_features.input_mask, dtype=torch.long),
            torch.tensor(cur_features.segment_ids, dtype=torch.long),
            torch.tensor(cur_features.lm_label_ids, dtype=torch.long),
            torch.tensor(cur_features.is_next),
            torch.tensor(cur_features.is_img_match),
        ), item
        # return cur_tensors

    def random_sent(self, index):
        """
        Get one sample from corpus consisting of two sentences. With prob. 50% these are two subsequent sentences
        from one doc. With 50% the second sentence will be a random one from another doc.
        :param index: int, index of sample.
        :return: (str, str, int), sentence 1, sentence 2, isNextSentence Label
        """
        img_id, t1, t2 = self.get_corpus_line(index)
        rand_dice = random.random()
        if rand_dice > 0.5:
            label = 0
            random_img_id = img_id
        elif rand_dice > self.args.texta_false_prob and t2 != "":
            # wrong qa triplets
            random_img_id, t2 = self.get_random_line()
            label = 1
        else:
            # wrong retrieval triplets
            random_img_id, t1 = self.get_random_texta()
            # args.num_contrast_classes = 3 if args.texta_false_prob<0.5 and (args.texta_false_prob>0 or not args.use_b) else 2
            label = self.args.num_contrast_classes - 1

        img_match_label = 0
        if img_id != random_img_id: img_match_label = 1

        assert len(t1) > 0
        assert len(t2) > 0 or not self.args.use_b
        return img_id, t1, t2, label, img_match_label
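    # Label semantics above (inferred from the branches): 0 = matching pair;
    # 1 = random text_b (wrong qa/tag triplet); num_contrast_classes - 1 =
    # random text_a (wrong retrieval triplet; equals 1 in the 2-class setting).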

    def get_corpus_line(self, item):
        """
        Get one sample from corpus consisting of a pair of two subsequent lines from the same doc.
        :param item: int, index of sample.
        :return: (str, str), two subsequent sentences from corpus
        """
        assert item < self.corpus_lines
        if self.on_memory:
            sample = self.sample_to_doc[item]
            # img_id = self.all_docs[sample["doc_id"]][0].strip() # original
            img_id = self.all_docs[sample["doc_id"]][0].strip().split('|')[0]
            t1 = self.all_docs[sample["doc_id"]][sample["line"]]
            t2 = self.all_docs[sample["doc_id"]][sample["line"] + 1]
            # used later to avoid random nextSentence from same doc
            self.current_doc = sample["doc_id"]
            self.current_img = img_id

            assert t1 != ""
            if self.args.use_b or 'qa' in self.all_docs[
                    sample["doc_id"]][0].split('_'):
                assert t2 != ""
            else:
                t2 = ""
            return img_id, t1, t2
        else:
            raise ValueError("on_memory = False Not supported yet!")

    def get_random_line(self):
        """
        Get random line from another document for nextSentence task.
        :return: str, content of one line
        """
        # Similar to original tf repo: This outer loop should rarely go for more than one iteration for large
        # corpora. However, just to be careful, we try to make sure that
        # the random document is not the same as the document we're processing.
        if self.on_memory:
            if self.textb_sample_mode in [0, 1]:
                # sample from all docs
                for _ in range(10):
                    rand_doc_idx = random.randrange(0, len(self.all_docs))
                    img_id = self.all_docs[rand_doc_idx][0].split('|')[0]
                    # check if our picked random line is really from another image like we want it to be
                    if img_id != self.current_img:
                        break
                rand_doc = self.all_docs[rand_doc_idx]
            else:
                # sample from all qa docs
                for _ in range(10):
                    rand_doc_idx = random.randrange(0, len(self.all_qa_docs))
                    # check if our picked random line is really from another doc like we want it to be (no need to be a different image here)
                    if self.all_qa_docs[rand_doc_idx][
                            "doc_id"] != self.current_doc:
                        break
                rand_doc = self.all_qa_docs[rand_doc_idx]["doc"]
            # img_id = rand_doc[0] # original
            img_id = rand_doc[0].split('|')[0]
            if self.textb_sample_mode == 0:
                # default oscar sample mode
                line = rand_doc[random.randrange(1, len(rand_doc))]
            else:
                # only sample text_b
                line = rand_doc[2]
            return img_id, line
        else:
            raise ValueError("on_memory = False Not supported yet!")

    def get_random_texta(self):
        """
        Get random text_a from another document for nextSentence task.
        :return: str, content of one line
        """
        # Similar to original tf repo: This outer loop should rarely go for more than one iteration for large
        # corpora. However, just to be careful, we try to make sure that
        # the random document is not the same as the document we're processing.
        if self.on_memory:
            for _ in range(10):
                rand_doc_idx = random.randrange(0, len(self.all_docs))
                img_id = self.all_docs[rand_doc_idx][0].split('|')[0]
                # check if our picked random line is really from another image like we want it to be
                if img_id != self.current_img:
                    break
            rand_doc = self.all_docs[rand_doc_idx]
            # img_id = rand_doc[0] # original
            img_id = rand_doc[0].split('|')[0]
            line = rand_doc[1]  # we want the text_a
            return img_id, line
        else:
            raise ValueError("on_memory = False Not supported yet!")

    # tsv image labels
    def load_img_labels(self):
        self.check_img_label_file()
        self.check_img_label_offset_map()

    def check_img_label_file(self):
        if self.img_label_file is None:
            self.img_label_file = {}
            self.img_qa_file = {}
            for dataset_name in self.datasets_names:
                img_label_file_path = os.path.join(
                    self.image_label_path[dataset_name], 'predictions_gt.tsv')
                img_qa_file_path = os.path.join(
                    self.image_label_path[dataset_name], 'QA_fileB.tsv')
                t_s = time.time()
                self.img_label_file[dataset_name] = TSVFile(
                    img_label_file_path)
                if os.path.exists(img_qa_file_path):
                    self.img_qa_file[dataset_name] = TSVFile(img_qa_file_path)
                t_e = time.time()
                logging.info("Open image label file {}, time: {}".format(
                    img_label_file_path, (t_e - t_s)))
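    # Assumed per-dataset label layout (inferred from the joins above):
    #   <image_label_path>/predictions_gt.tsv    object tags per image
    #   <image_label_path>/QA_fileB.tsv          optional QA text_b rows
    #   <image_label_path>/imageid2idx.json      label offset map
    #   <image_label_path>/QA_qaid2idx.json      optional QA offset map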

    def check_img_label_offset_map(self):
        if self.img_label_offset_map is None:
            self.img_label_offset_map = {}
            self.img_qa_offset_map = {}
            for dataset_name in self.datasets_names:
                img_label_offset_map_path = os.path.join(
                    self.image_label_path[dataset_name], 'imageid2idx.json')
                img_qa_offset_map_path = os.path.join(
                    self.image_label_path[dataset_name], 'QA_qaid2idx.json')
                t_s = time.time()
                self.img_label_offset_map[dataset_name] = json.load(
                    open(img_label_offset_map_path))
                if os.path.exists(img_qa_offset_map_path):
                    self.img_qa_offset_map[dataset_name] = json.load(
                        open(img_qa_offset_map_path))
                t_e = time.time()
                logging.info("Load img label offset map: {}, time: {}".format(
                    img_label_offset_map_path, (t_e - t_s)))

    def get_img_labels(self, image_id):
        """ decode the image labels: read the image labels from the label tsv of the right dataset """
        self.check_img_label_file()
        self.check_img_label_offset_map()

        # parse the dataset name the same way get_img_feature does below;
        # the per-dataset label files and offset maps are dicts keyed by dataset
        img_infos = image_id.split('_')
        dataset_name = img_infos[0]
        if dataset_name == 'cc':
            dataset_name = 'googlecc'
        img_id = img_infos[-1]
        if img_id in self.img_label_offset_map[dataset_name]:
            img_offset = self.img_label_offset_map[dataset_name][img_id]
            arr = self.img_label_file[dataset_name].seek(img_offset)
            eles = json.loads(arr[1])
            return eles['labels']

        return None

    # tsv feature loading
    def load_img_tsv_features(self):
        self.check_img_feature_file()
        self.check_img_feature_offset_map()

    def check_img_feature_file(self):
        if self.img_feature_file is None:
            # self.img_feature_file = [] # original
            self.img_feature_file = {}
            self.img_feat_offset_map = {}
            for dataset_name in self.datasets_names:
                logging.info("* Loading dataset {}".format(dataset_name))
                if dataset_name in self.datasets_with_splits:
                    self.img_feature_file[dataset_name] = {}
                    self.img_feat_offset_map[dataset_name] = {}
                    chunk_list = []
                    if self.chunk_list is not None:
                        chunk_list = self.chunk_list
                        chunk_file_list = []
                        for chunk_fp_id in chunk_list:
                            chunk_file_list.append(
                                os.path.join(
                                    self.image_feature_path[dataset_name],
                                    chunk_fp_id, self.image_file_name))
                        if dataset_name == 'googlecc':
                            for i, (chunk_fp_id, chunk_fp) in enumerate(
                                    zip(chunk_list, chunk_file_list)):
                                assert os.path.exists(
                                    chunk_file_list[i]
                                ), "Chunk file {} does not exists!".format(
                                    chunk_fp)
                    else:
                        chunk_file_list = glob.glob(
                            self.image_feature_path[dataset_name] +
                            "/*/{}".format(self.image_file_name))
                        for chunk_fp in chunk_file_list:
                            chunk_fp_id = chunk_fp.split('/')[-2]
                            chunk_list.append(chunk_fp_id)
                    logging.info("* Load Image Chunks {}".format(
                        len(chunk_list)))

                    t_s_total = time.time()
                    for chunk_fp in chunk_file_list:
                        chunk_fp_id = chunk_fp.split('/')[-2]
                        t_s = time.time()
                        self.img_feature_file[dataset_name][
                            chunk_fp_id] = TSVFile(chunk_fp)
                        chunk_offsetmap = os.path.join(
                            os.path.dirname(chunk_fp), 'imageid2idx.json')
                        assert os.path.isfile(
                            chunk_offsetmap
                        ), "Imageid2idx file {} does not exists!".format(
                            chunk_offsetmap)
                        self.img_feat_offset_map[dataset_name][
                            chunk_fp_id] = json.load(open(
                                chunk_offsetmap, 'r'))
                        t_e = time.time()
                        logging.info("Open image chunk {}, time: {}".format(
                            chunk_fp_id, (t_e - t_s)))
                    t_e_total = time.time()
                    logging.info("Open total {} image chunks, time: {}".format(
                        len(chunk_list), (t_e_total - t_s_total)))
                    logging.info("Image chunk info: {}".format(
                        '\n'.join(chunk_file_list)))
                elif dataset_name in self.datasets_with_onesplit:
                    t_s = time.time()
                    chunk_fp = os.path.join(
                        self.image_feature_path[dataset_name],
                        self.image_file_name)
                    self.img_feature_file[dataset_name] = TSVFile(chunk_fp)
                    chunk_offsetmap = os.path.join(os.path.dirname(chunk_fp),
                                                   'imageid2idx.json')
                    assert os.path.isfile(
                        chunk_offsetmap
                    ), "Imageid2idx file {} does not exists!".format(
                        chunk_offsetmap)
                    self.img_feat_offset_map[dataset_name] = json.load(
                        open(chunk_offsetmap, 'r'))
                    t_e = time.time()
                    logging.info("Open dataset {}, time: {}".format(
                        chunk_fp, (t_e - t_s)))
                else:
                    raise ValueError(
                        "Not supported dataset: {}".format(dataset_name))

    def check_img_feature_offset_map(self):
        """ load the image feature offset map """
        if self.img_feat_offset_map is None:
            self.img_feat_offset_map = {}
            for dataset_name in self.datasets_names:
                logging.info(
                    "* Loading imageid2idx_map {}".format(dataset_name))
                if dataset_name in self.datasets_with_splits:
                    # the per-dataset dict must exist before chunk entries are assigned below
                    self.img_feat_offset_map[dataset_name] = {}
                    chunk_list = []
                    chunk_file_list = glob.glob(
                        self.image_feature_path[dataset_name] +
                        "/*/imageid2idx.json")
                    for chunk_fp in chunk_file_list:
                        chunk_fp_id = chunk_fp.split('/')[-2]
                        chunk_list.append(chunk_fp_id)
                    logging.info("* Load Image Chunks {}".format(
                        len(chunk_list)))

                    t_s_total = time.time()
                    for chunk_fp in chunk_file_list:
                        chunk_fp_id = chunk_fp.split('/')[-2]
                        t_s = time.time()
                        self.img_feat_offset_map[dataset_name][
                            chunk_fp_id] = json.load(open(chunk_fp))
                        t_e = time.time()
                        logging.info("Open image chunk {}, time: {}".format(
                            chunk_fp_id, (t_e - t_s)))
                    t_e_total = time.time()
                    logging.info("Open total {} image chunks, time: {}".format(
                        len(chunk_list), (t_e_total - t_s_total)))
                elif dataset_name in self.datasets_with_onesplit:
                    t_s = time.time()
                    chunk_fp = self.image_feature_path[
                        dataset_name] + "/imageid2idx.json"
                    self.img_feat_offset_map[dataset_name] = json.load(
                        open(chunk_fp))
                    t_e = time.time()
                    logging.info("Open dataset {}, time: {}".format(
                        chunk_fp, (t_e - t_s)))
                else:
                    raise ValueError(
                        "Not supported dataset: {}".format(dataset_name))

    def get_img_feature(self, image_id):
        """ decode the image feature: read the image feature from the right chunk id """
        self.check_img_feature_file()
        self.check_img_feature_offset_map()
        img_infos = image_id.split('_')
        dataset_name = img_infos[0]
        if dataset_name == 'cc':
            dataset_name = 'googlecc'
        img_id = img_infos[-1]
        if dataset_name in self.datasets_with_splits:
            chunk_id = img_infos[-2]
            img_feat_offset_map = self.img_feat_offset_map[dataset_name][
                chunk_id]
            img_feature_file = self.img_feature_file[dataset_name][chunk_id]
        else:
            img_feat_offset_map = self.img_feat_offset_map[dataset_name]
            img_feature_file = self.img_feature_file[dataset_name]
        if img_id in img_feat_offset_map:
            img_offset = img_feat_offset_map[img_id]

            arr = img_feature_file.seek(img_offset)
            num_boxes = int(arr[1])
            feat = np.frombuffer(base64.b64decode(arr[-1]),
                                 dtype=np.float32).reshape(
                                     (num_boxes, self.args.img_feature_dim))
            feat = torch.from_numpy(feat)
            return feat

        return None
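
A minimal usage sketch for OscarTSVDataset above, assuming a hypothetical pretrain yaml (defining 'corpus', 'corpus_file', 'image_label_path' and 'image_feature_path') and an argparse-style namespace carrying the fields the class reads (plus whatever convert_example_to_features expects); all names and values below are illustrative:

from argparse import Namespace
from transformers import BertTokenizer

args = Namespace(textb_sample_mode=0, data_dir='pretrain_data',
                 chunk_start_id=-1, chunk_end_id=-1, use_gtlabels=True,
                 use_b=True, texta_false_prob=0.0, num_contrast_classes=2,
                 max_img_seq_length=50, img_feature_dim=2054)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
dataset = OscarTSVDataset('pretrain/coco.yaml', args=args,
                          tokenizer=tokenizer, seq_len=35, on_memory=True)
img_feat, text_tensors, index = dataset[0]
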
Example 3
class CaptionTSVDataset(Dataset):
    def __init__(self,
                 yaml_file,
                 tokenizer=None,
                 add_od_labels=True,
                 disable_img_features=False,
                 keep_top_percentage_tag_conf_threshold=0.3,
                 keep_top_percentage_tag=1,
                 max_img_seq_length=50,
                 max_seq_length=70,
                 max_seq_a_length=40,
                 is_train=True,
                 mask_prob=0.15,
                 max_masked_tokens=3,
                 **kwargs):
        """Constructor.
        Args:
            yaml file with all required data (image feature, caption, labels, etc)
            tokenizer: tokenizer for text processing.
            add_od_labels: whether to add labels from yaml file to BERT. 
            max_img_seq_length: max image sequence length.
            max_seq_length: max text sequence length.
            max_seq_a_length: max caption sequence length.
            is_train: train or test mode.
            mask_prob: probability to mask a input token.
            max_masked_tokens: maximum number of tokens to be masked in one sentence.
            kwargs: other arguments.
        """
        self.yaml_file = yaml_file
        self.cfg = load_from_yaml_file(yaml_file)
        self.root = op.dirname(yaml_file)
        self.label_file = find_file_path_in_yaml(self.cfg['label'], self.root)
        self.feat_file = find_file_path_in_yaml(self.cfg['feature'], self.root)
        self.caption_file = find_file_path_in_yaml(self.cfg.get('caption'),
                                                   self.root)

        assert op.isfile(self.feat_file)
        if add_od_labels: assert op.isfile(self.label_file)
        if is_train:
            assert op.isfile(self.caption_file) and tokenizer is not None

        self.label_tsv = None if not self.label_file else TSVFile(
            self.label_file)
        self.feat_tsv = TSVFile(self.feat_file)
        if self.caption_file and op.isfile(self.caption_file):
            with open(self.caption_file, 'r') as f:
                self.captions = json.load(f)

        self.tokenizer = tokenizer
        self.tensorizer = CaptionTensorizer(self.tokenizer,
                                            max_img_seq_length,
                                            max_seq_length,
                                            max_seq_a_length,
                                            mask_prob,
                                            max_masked_tokens,
                                            is_train=is_train)
        self.add_od_labels = add_od_labels
        self.disable_img_features = disable_img_features
        self.keep_top_percentage_tag_conf_threshold = keep_top_percentage_tag_conf_threshold
        self.keep_top_percentage_tag = keep_top_percentage_tag
        self.is_train = is_train
        self.kwargs = kwargs
        self.image_keys = self.prepare_image_keys()
        self.key2index = self.prepare_image_key_to_index()
        self.key2captions = self.prepare_image_key_to_captions()

    def get_valid_tsv(self):
        # prefer the smaller file: the label tsv if present, else the feature tsv
        if self.label_tsv:
            return self.label_tsv
        if self.feat_tsv:
            return self.feat_tsv

    def prepare_image_keys(self):
        tsv = self.get_valid_tsv()
        return [tsv.seek(i)[0] for i in range(tsv.num_rows())]

    def prepare_image_key_to_index(self):
        tsv = self.get_valid_tsv()
        return {tsv.seek(i)[0]: i for i in range(tsv.num_rows())}

    def prepare_image_key_to_captions(self):
        if self.is_train:
            key2captions = {key: [] for key in self.image_keys}
            for cap in self.captions:
                key2captions[cap['image_id']].append(cap['caption'])
            return key2captions

    def get_image_index(self, idx):
        if self.is_train:
            img_cap_pair = self.captions[idx]
            img_key = img_cap_pair['image_id']
            return self.key2index[img_key]
        return idx

    def get_image_key(self, idx):
        img_idx = self.get_image_index(idx)
        return self.image_keys[img_idx]

    def get_image_features(self, img_idx):
        feat_info = json.loads(self.feat_tsv.seek(img_idx)[1])
        num_boxes = feat_info['num_boxes']
        tmp = np.frombuffer(base64.b64decode(feat_info['features']),
                            np.float32).reshape((num_boxes, -1))
        # remove image features from fine-tuning stage
        if self.disable_img_features:
            features = tmp.copy()
            features.fill(0)
        else:
            features = tmp
        return torch.Tensor(features)
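    # Feature rows are assumed to be [image_key, json], where the json carries
    # 'num_boxes' and a base64-encoded float32 'features' matrix (one row per box).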

    def get_caption(self, idx):
        if self.is_train:
            img_cap_pair = self.captions[idx]
            return img_cap_pair['caption']
        return ""

    def get_od_labels(self, img_idx):
        od_labels = None
        if self.add_od_labels:
            label_info = json.loads(self.label_tsv.seek(img_idx)[1])
            if self.keep_top_percentage_tag == 1:
                od_labels = " ".join([l['class'] for l in label_info])
            else:
                # only keep labels whose confidence is >= keep_top_percentage_tag_conf_threshold
                label_over_threshold = [
                    l for l in label_info
                    if l['conf'] >= self.keep_top_percentage_tag_conf_threshold
                ]
                sorted_by_conf = sorted(label_over_threshold,
                                        key=lambda i: i['conf'],
                                        reverse=True)
                # keep the top keep_top_percentage_tag fraction of the tags above the threshold
                label_info_sort_by_conf = sorted_by_conf[:int(
                    len(sorted_by_conf) * self.keep_top_percentage_tag)]
                od_labels = " ".join(
                    [l['class'] for l in label_info_sort_by_conf])
        return od_labels
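    # Worked example for the filtering above (numbers illustrative): with tags
    # [('dog', 0.9), ('ball', 0.5), ('grass', 0.2)], threshold 0.3 and
    # keep_top_percentage_tag=0.5, 'grass' falls below the threshold, the rest
    # are sorted by confidence, and int(2 * 0.5) = 1 tag survives: "dog".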

    def get_caption_file_in_coco_format(self):
        cap_file = op.splitext(self.caption_file)[0] + '_coco_format.json'
        return cap_file

    def get_captions_by_key(self, key):
        assert self.is_train, "cannot get captions for inference"
        return self.key2captions[key]

    def __getitem__(self, idx):
        img_idx = self.get_image_index(idx)
        img_key = self.image_keys[img_idx]
        features = self.get_image_features(img_idx)
        caption = self.get_caption(idx)
        od_labels = self.get_od_labels(img_idx)
        example = self.tensorizer.tensorize_example(caption,
                                                    features,
                                                    text_b=od_labels)
        return img_key, example

    def __len__(self):
        if self.is_train:
            return len(self.captions)
        return self.get_valid_tsv().num_rows()
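
A minimal sketch of driving CaptionTSVDataset with a PyTorch DataLoader, reusing the hypothetical yaml and tokenizer from the earlier sketches:

from torch.utils.data import DataLoader

dataset = CaptionTSVDataset('coco_caption/train.yaml', tokenizer=tokenizer,
                            add_od_labels=True, is_train=True)
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)
for img_keys, batch in loader:
    # batch holds the tensorized caption, image features and od tags
    break
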
Example no. 8
class RetrievalDataset(Dataset):
    """ Image/Text Retrieval Dataset"""
    def __init__(self, tokenizer, args, split='train', is_train=True):
        """
        tokenizer: tokenizer to process caption text.
        args: configureation parameters including max_seq_length, etc.
        split: used to infer the data used for training or testing. 
             All files are in .pt format of a dictionary with image keys and 
             image features (pytorch tensors), captions (list of str, support multiple
             captions per image), labels (list of dictionary or str of all labels),

        """
        super(RetrievalDataset, self).__init__()
        self.img_file = args.img_feat_file
        caption_file = op.join(args.data_dir, '{}_captions.pt'.format(split))
        self.img_tsv = TSVFile(self.img_file)
        self.captions = torch.load(caption_file)
        self.img_keys = list(self.captions.keys())  # img_id as int
        if not type(self.captions[self.img_keys[0]]) == list:
            self.captions = {
                k: json.loads(self.captions[k])
                for k in self.img_keys
            }

        # get the image_id to index map
        imgid2idx_file = op.join(op.dirname(self.img_file), 'imageid2idx.json')
        self.image_id2idx = json.load(open(imgid2idx_file))  # img_id as string

        if args.add_od_labels:
            label_data_dir = op.dirname(self.img_file)
            label_file = os.path.join(label_data_dir, "predictions.tsv")
            self.label_tsv = TSVFile(label_file)
            self.labels = {}
            for line_no in range(self.label_tsv.num_rows()):
                row = self.label_tsv.seek(line_no)
                image_id = row[0]
                if int(image_id) in self.img_keys:
                    results = json.loads(row[1])
                    # a row is either a dict with metadata or a bare object list
                    objects = results['objects'] if type(
                        results) == dict else results
                    self.labels[int(image_id)] = {
                        "image_h":
                        results["image_h"] if type(results) == dict else 600,
                        "image_w":
                        results["image_w"] if type(results) == dict else 800,
                        "class": [cur_d['class'] for cur_d in objects],
                        "boxes":
                        np.array([cur_d['rect'] for cur_d in objects],
                                 dtype=np.float32)
                    }
            # all needed labels are cached in self.labels above,
            # so the underlying file handle can be released
            self.label_tsv._fp.close()
            self.label_tsv._fp = None

        if is_train:
            self.num_captions_per_img = args.num_captions_per_img_train
        else:
            self.num_captions_per_img = args.num_captions_per_img_val
            if args.eval_img_keys_file:
                # select a subset of image keys for evaluation, e.g. COCO 1k and 5k;
                # eval_img_keys_file is a list of image keys saved in a tsv file
                with open(op.join(args.data_dir, args.eval_img_keys_file),
                          'r') as f:
                    img_keys = f.readlines()
                self.img_keys = [int(k.strip()) for k in img_keys]
                self.captions = {k: self.captions[k] for k in self.img_keys}
                if args.add_od_labels:
                    self.labels = {k: self.labels[k] for k in self.img_keys}

            if args.eval_caption_index_file:
                # hard-negative image/caption indices for the retrieval re-rank setting;
                # useful on a mini val set to monitor performance during training.
                # It cannot be used together with cross-image evaluation.
                self.has_caption_indexs = True
                assert not args.cross_image_eval
                caption_index_file = op.join(args.data_dir,
                                             args.eval_caption_index_file)
                self.caption_indexs = torch.load(caption_index_file)
                if not type(self.caption_indexs[self.img_keys[0]]) == list:
                    self.caption_indexs = {
                        k: json.loads(self.caption_indexs[k])
                        for k in self.img_keys
                    }
            else:
                self.has_caption_indexs = False
        self.is_train = is_train
        self.output_mode = args.output_mode
        self.tokenizer = tokenizer
        self.max_seq_len = args.max_seq_length
        self.max_img_seq_len = args.max_img_seq_length
        self.args = args

    def get_image_caption_index(self, index):
        # return img_idx to access features and [img_key, cap_idx] to access caption
        if not self.is_train and self.args.cross_image_eval:
            img_idx = index // (self.num_captions_per_img * len(self.img_keys))
            cap_idx = index % (self.num_captions_per_img * len(self.img_keys))
            img_idx1 = cap_idx // self.num_captions_per_img
            cap_idx1 = cap_idx % self.num_captions_per_img
            return img_idx, [self.img_keys[img_idx1], cap_idx1]
        if not self.is_train and self.has_caption_indexs:
            img_idx = index // self.num_captions_per_img
            cap_idx = index % self.num_captions_per_img
            img_key1, cap_idx1 = self.caption_indexs[
                self.img_keys[img_idx]][cap_idx]
            return img_idx, [img_key1, cap_idx1]
        img_idx = index // self.num_captions_per_img
        cap_idx = index % self.num_captions_per_img
        return img_idx, [self.img_keys[img_idx], cap_idx]
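        # Worked example (hypothetical numbers, not from the source): with
        # 1000 images and num_captions_per_img = 5 under cross_image_eval,
        # index 12345 decodes to img_idx = 12345 // 5000 = 2 (features of
        # the 3rd image), cap_idx = 12345 % 5000 = 2345, and from that
        # img_idx1 = 2345 // 5 = 469, cap_idx1 = 2345 % 5 = 0, i.e. the
        # 1st caption of the 470th image -- a mismatched (negative) pair.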

    def get_label(self, index):
        img_idx, cap_idx = self.get_image_caption_index(index)
        return 1 if self.img_keys[img_idx] == cap_idx[0] else 0

    def get_od_labels(self, img_key):
        if self.args.add_od_labels:
            if type(self.labels[img_key]) == str:
                od_labels = self.labels[img_key]
            else:
                od_labels = ' '.join(self.labels[img_key]['class'])
            return od_labels

    def tensorize_example(self,
                          text_a,
                          img_feat,
                          text_b=None,
                          cls_token_segment_id=0,
                          pad_token_segment_id=0,
                          sequence_a_segment_id=0,
                          sequence_b_segment_id=1):
        tokens_a = self.tokenizer.tokenize(text_a)
        if len(tokens_a) > self.args.max_seq_length - 2:
            tokens_a = tokens_a[:(self.args.max_seq_length - 2)]

        tokens = [self.tokenizer.cls_token
                  ] + tokens_a + [self.tokenizer.sep_token]
        segment_ids = [cls_token_segment_id
                       ] + [sequence_a_segment_id] * (len(tokens_a) + 1)
        seq_a_len = len(tokens)
        if text_b:
            tokens_b = self.tokenizer.tokenize(text_b)
            if len(tokens_b) > self.max_seq_len - len(tokens) - 1:
                tokens_b = tokens_b[:(self.max_seq_len - len(tokens) - 1)]
            tokens += tokens_b + [self.tokenizer.sep_token]
            segment_ids += [sequence_b_segment_id] * (len(tokens_b) + 1)

        seq_len = len(tokens)
        seq_padding_len = self.max_seq_len - seq_len
        tokens += [self.tokenizer.pad_token] * seq_padding_len
        segment_ids += [pad_token_segment_id] * seq_padding_len
        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)

        # image features
        img_len = img_feat.shape[0]
        if img_len > self.max_img_seq_len:
            img_feat = img_feat[0:self.max_img_seq_len, :]
            img_len = img_feat.shape[0]
            img_padding_len = 0
        else:
            img_padding_len = self.max_img_seq_len - img_len
            padding_matrix = torch.zeros((img_padding_len, img_feat.shape[1]))
            img_feat = torch.cat((img_feat, padding_matrix), 0)

        # generate attention_mask
        att_mask_type = self.args.att_mask_type
        if att_mask_type == "CLR":
            attention_mask = [1] * seq_len + [0] * seq_padding_len + \
                             [1] * img_len + [0] * img_padding_len
        else:
            # use 2D mask to represent the attention
            max_len = self.max_seq_len + self.max_img_seq_len
            attention_mask = torch.zeros((max_len, max_len), dtype=torch.long)
            # full attention of C-C, L-L, R-R
            c_start, c_end = 0, seq_a_len
            l_start, l_end = seq_a_len, seq_len
            r_start, r_end = self.max_seq_len, self.max_seq_len + img_len
            attention_mask[c_start:c_end, c_start:c_end] = 1
            attention_mask[l_start:l_end, l_start:l_end] = 1
            attention_mask[r_start:r_end, r_start:r_end] = 1
            if att_mask_type == 'CL':
                attention_mask[c_start:c_end, l_start:l_end] = 1
                attention_mask[l_start:l_end, c_start:c_end] = 1
            elif att_mask_type == 'CR':
                attention_mask[c_start:c_end, r_start:r_end] = 1
                attention_mask[r_start:r_end, c_start:c_end] = 1
            elif att_mask_type == 'LR':
                attention_mask[l_start:l_end, r_start:r_end] = 1
                attention_mask[r_start:r_end, l_start:l_end] = 1
            else:
                raise ValueError(
                    "Unsupported attention mask type {}".format(att_mask_type))

        input_ids = torch.tensor(input_ids, dtype=torch.long)
        attention_mask = torch.tensor(attention_mask, dtype=torch.long)
        segment_ids = torch.tensor(segment_ids, dtype=torch.long)
        return (input_ids, attention_mask, segment_ids, img_feat)
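        # Shape sketch (toy numbers, not from the source): with
        # max_seq_len = 70 and max_img_seq_len = 50, the "CLR" mask is a
        # flat 120-dim vector: [1]*seq_len, [0]*text padding, [1]*img_len,
        # [0]*image padding. The 2D variants instead gate cross-attention
        # between the caption (C), object-label (L) and region (R) blocks:
        # e.g. 'CR' additionally lets caption and regions attend to each
        # other, while the label block only attends to itself.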

    def __getitem__(self, index):
        if self.is_train:
            img_idx, cap_idxs = self.get_image_caption_index(index)
            img_key = self.img_keys[img_idx]
            feature = self.get_image(img_key)
            caption = self.captions[cap_idxs[0]][cap_idxs[1]]
            od_labels = self.get_od_labels(img_key)
            example = self.tensorize_example(caption,
                                             feature,
                                             text_b=od_labels)

            # select a negative pair
            neg_img_indexs = list(range(0, img_idx)) + list(
                range(img_idx + 1, len(self.img_keys)))
            img_idx_neg = random.choice(neg_img_indexs)
            if random.random() <= 0.5:
                # randomly select a negative caption from a different image.
                cap_idx_neg = random.randint(0, self.num_captions_per_img - 1)
                caption_neg = self.captions[
                    self.img_keys[img_idx_neg]][cap_idx_neg]
                example_neg = self.tensorize_example(caption_neg,
                                                     feature,
                                                     text_b=od_labels)
            else:
                # randomly select a negative image
                feature_neg = self.get_image(self.img_keys[img_idx_neg])
                od_labels_neg = self.get_od_labels(self.img_keys[img_idx_neg])
                example_neg = self.tensorize_example(caption,
                                                     feature_neg,
                                                     text_b=od_labels_neg)

            example_pair = tuple(list(example) + [1] + list(example_neg) + [0])
            return index, example_pair
        else:
            img_idx, cap_idxs = self.get_image_caption_index(index)
            img_key = self.img_keys[img_idx]
            feature = self.get_image(img_key)
            caption = self.captions[cap_idxs[0]][cap_idxs[1]]
            od_labels = self.get_od_labels(img_key)
            example = self.tensorize_example(caption,
                                             feature,
                                             text_b=od_labels)
            label = 1 if img_key == cap_idxs[0] else 0
            return index, tuple(list(example) + [label])

    def get_image(self, image_id):
        image_idx = self.image_id2idx[str(image_id)]
        row = self.img_tsv.seek(image_idx)
        num_boxes = int(row[1])
        features = np.frombuffer(base64.b64decode(row[-1]),
                                 dtype=np.float32).reshape((num_boxes, -1))
        t_features = torch.from_numpy(features)
        return t_features
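        # Row layout assumed here: row[0] is the image id, row[1] the
        # number of boxes, and the last column a base64-encoded float32
        # blob that reshapes to a (num_boxes, feat_dim) feature matrix.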

    def __len__(self):
        if not self.is_train and self.args.cross_image_eval:
            return len(self.img_keys)**2 * self.num_captions_per_img
        return len(self.img_keys) * self.num_captions_per_img
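
A minimal usage sketch for RetrievalDataset (all paths and settings below are
hypothetical, and SimpleNamespace merely stands in for the argparse namespace
the class expects):

from types import SimpleNamespace

from torch.utils.data import DataLoader
from transformers import BertTokenizer

args = SimpleNamespace(
    img_feat_file='data/coco_ir/features.tsv',  # hypothetical path
    data_dir='data/coco_ir',                    # hypothetical path
    add_od_labels=True,
    num_captions_per_img_train=5,
    num_captions_per_img_val=5,
    eval_img_keys_file='',
    eval_caption_index_file='',
    cross_image_eval=False,
    output_mode='classification',
    max_seq_length=70,
    max_img_seq_length=50,
    att_mask_type='CLR',
)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
train_dataset = RetrievalDataset(tokenizer, args, split='train', is_train=True)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
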
class CaptionTSVDataset(Dataset):
    def __init__(self,
                 yaml_file,
                 tokenizer=None,
                 add_od_labels=True,
                 max_img_seq_length=50,
                 max_seq_length=70,
                 max_seq_a_length=40,
                 is_train=True,
                 mask_prob=0.15,
                 max_masked_tokens=3,
                 add_conf=False,
                 **kwargs):
        """Constructor.
        Args:
            yaml file with all required data (image feature, caption, labels, etc)
            tokenizer: tokenizer for text processing.
            add_od_labels: whether to add labels from yaml file to BERT. 
            max_img_seq_length: max image sequence length.
            max_seq_length: max text sequence length.
            max_seq_a_length: max caption sequence length.
            is_train: train or test mode.
            mask_prob: probability to mask a input token.
            max_masked_tokens: maximum number of tokens to be masked in one sentence.
            kwargs: other arguments.
        """
        self.yaml_file = yaml_file
        self.cfg = load_from_yaml_file(yaml_file)
        self.root = op.dirname(yaml_file)
        self.label_file = find_file_path_in_yaml(self.cfg['label'], self.root)
        self.feat_file = find_file_path_in_yaml(self.cfg['feature'], self.root)
        self.caption_file = find_file_path_in_yaml(self.cfg.get('caption'),
                                                   self.root)

        assert op.isfile(self.feat_file)
        if add_od_labels: assert op.isfile(self.label_file)
        if is_train:
            assert op.isfile(self.caption_file) and tokenizer is not None

        self.label_tsv = None if not self.label_file else TSVFile(
            self.label_file)
        self.feat_tsv = TSVFile(self.feat_file)
        if self.caption_file and op.isfile(self.caption_file):
            with open(self.caption_file, 'r') as f:
                self.captions = json.load(f)

        self.tokenizer = tokenizer
        self.tensorizer = CaptionTensorizer(self.tokenizer,
                                            max_img_seq_length,
                                            max_seq_length,
                                            max_seq_a_length,
                                            mask_prob,
                                            max_masked_tokens,
                                            is_train=is_train)
        self.add_od_labels = add_od_labels
        self.is_train = is_train
        self.kwargs = kwargs
        self.image_keys = self.prepare_image_keys()
        self.key2index = self.prepare_image_key_to_index()
        self.key2captions = self.prepare_image_key_to_captions()
        self.add_conf = add_conf

    def get_valid_tsv(self):
        # prefer the label tsv over the (larger) feature tsv
        if self.label_tsv:
            return self.label_tsv
        if self.feat_tsv:
            return self.feat_tsv

    def prepare_image_keys(self):
        tsv = self.get_valid_tsv()
        return [tsv.seek(i)[0] for i in range(tsv.num_rows())]

    def prepare_image_key_to_index(self):
        tsv = self.get_valid_tsv()
        return {tsv.seek(i)[0]: i for i in range(tsv.num_rows())}

    def prepare_image_key_to_captions(self):
        if self.is_train:
            key2captions = {key: [] for key in self.image_keys}
            for cap in self.captions:
                key2captions[cap['image_id']].append(cap['caption'])
            return key2captions
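        # Assumes COCO-style annotations: self.captions is a list of
        # {"image_id": ..., "caption": ...} dicts loaded from the caption json.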

    def get_image_index(self, idx):
        if self.is_train:
            img_cap_pair = self.captions[idx]
            img_key = img_cap_pair['image_id']
            return self.key2index[img_key]
        return idx

    def get_image_key(self, idx):
        img_idx = self.get_image_index(idx)
        return self.image_keys[img_idx]

    def get_image_features(self, img_idx):
        feat_info = json.loads(self.feat_tsv.seek(img_idx)[1])
        num_boxes = feat_info['num_boxes']
        features = np.frombuffer(base64.b64decode(feat_info['features']),
                                 np.float32).reshape((num_boxes, -1))
        return torch.Tensor(features)
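        # Unlike RetrievalDataset above, each feature row here carries a
        # JSON payload, e.g. {"num_boxes": 10, "features": "<base64 bytes>"}
        # (payload shown is illustrative), rather than separate TSV columns.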

    def get_caption(self, idx):
        if self.is_train:
            img_cap_pair = self.captions[idx]
            return img_cap_pair['caption']
        return ""

    def get_od_labels(self, img_idx):
        od_labels = None
        if self.add_od_labels:
            label_info = json.loads(self.label_tsv.seek(img_idx)[1])
            od_labels = " ".join([l['class'] for l in label_info])
        return od_labels

    def get_od_confidence(self, img_idx):
        od_confs = None
        if self.add_conf:
            label_info = json.loads(self.label_tsv.seek(img_idx)[1])
            od_confs = []
            # repeat each conf once per word, because some class labels
            # contain spaces and split into multiple words downstream
            for info in label_info:
                repeats = len(info["class"].strip().split(" "))
                od_confs.extend([info["conf"]] * repeats)
        return od_confs
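        # Worked example (hypothetical detections): for label_info =
        # [{"class": "traffic light", "conf": 0.9}, {"class": "car", "conf": 0.8}]
        # the joined label string is "traffic light car", so od_confs becomes
        # [0.9, 0.9, 0.8] -- one confidence per whitespace-separated word.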

    def get_caption_file_in_coco_format(self):
        cap_file = op.splitext(self.caption_file)[0] + '_coco_format.json'
        return cap_file

    def get_captions_by_key(self, key):
        assert self.is_train, "cannot get captions for inference"
        return self.key2captions[key]

    def __getitem__(self, idx):
        img_idx = self.get_image_index(idx)
        img_key = self.image_keys[img_idx]
        features = self.get_image_features(img_idx)
        caption = self.get_caption(idx)
        od_labels = self.get_od_labels(img_idx)
        od_confs = self.get_od_confidence(img_idx)
        example = self.tensorizer.tensorize_example(caption,
                                                    features,
                                                    text_b=od_labels,
                                                    confs=od_confs)
        return img_key, example

    def __len__(self):
        if self.is_train:
            return len(self.captions)
        return self.get_valid_tsv().num_rows()
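
A matching usage sketch for CaptionTSVDataset (the yaml path is hypothetical;
the yaml is expected to point at the 'feature', 'label' and 'caption' files as
in the constructor above):

from torch.utils.data import DataLoader
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
dataset = CaptionTSVDataset('data/coco_caption/train.yaml',  # hypothetical path
                            tokenizer=tokenizer,
                            add_od_labels=True,
                            is_train=True)
img_key, example = dataset[0]  # example comes from CaptionTensorizer
loader = DataLoader(dataset, batch_size=64, shuffle=True)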