def __init__(self, yaml_file, tokenizer=None, add_od_labels=True,
             max_img_seq_length=50, max_seq_length=70, max_seq_a_length=40,
             is_train=True, mask_prob=0.15, max_masked_tokens=3,
             add_conf=False, **kwargs):
    """Build a caption dataset from a YAML description of its data files.

    Args:
        yaml_file: YAML file naming the ``feature`` and ``label`` files and,
            optionally, the ``caption`` file; each path is resolved relative
            to the YAML file's directory.
        tokenizer: tokenizer for text processing; required when training.
        add_od_labels: whether to add labels from the label TSV to BERT.
        max_img_seq_length: max image sequence length.
        max_seq_length: max text sequence length.
        max_seq_a_length: max caption sequence length.
        is_train: train or test mode.
        mask_prob: probability to mask an input token.
        max_masked_tokens: maximum number of tokens to be masked in one
            sentence.
        add_conf: stored as ``self.add_conf``; how it is used is outside
            this method.
        kwargs: other arguments, stored as ``self.kwargs``.
    """
    self.yaml_file = yaml_file
    self.cfg = load_from_yaml_file(yaml_file)
    self.root = op.dirname(yaml_file)
    cfg, root = self.cfg, self.root
    self.label_file = find_file_path_in_yaml(cfg['label'], root)
    self.feat_file = find_file_path_in_yaml(cfg['feature'], root)
    self.caption_file = find_file_path_in_yaml(cfg.get('caption'), root)

    # Fail early when a file this configuration depends on is missing.
    assert op.isfile(self.feat_file)
    if add_od_labels:
        assert op.isfile(self.label_file)
    if is_train:
        assert op.isfile(self.caption_file) and tokenizer is not None

    if self.label_file:
        self.label_tsv = TSVFile(self.label_file)
    else:
        self.label_tsv = None
    self.feat_tsv = TSVFile(self.feat_file)
    # Captions are only loaded when a caption file is configured and exists.
    if self.caption_file and op.isfile(self.caption_file):
        with open(self.caption_file, 'r') as caption_fp:
            self.captions = json.load(caption_fp)

    self.tokenizer = tokenizer
    self.tensorizer = CaptionTensorizer(
        self.tokenizer, max_img_seq_length, max_seq_length,
        max_seq_a_length, mask_prob, max_masked_tokens, is_train=is_train)
    self.add_od_labels = add_od_labels
    self.is_train = is_train
    self.kwargs = kwargs
    # Lookup tables built by the class's prepare_* helpers.
    self.image_keys = self.prepare_image_keys()
    self.key2index = self.prepare_image_key_to_index()
    self.key2captions = self.prepare_image_key_to_captions()
    self.add_conf = add_conf
def check_img_label_file(self):
    """Lazily open each dataset's label TSV and, when present, its QA TSV.

    No-op when ``self.img_label_file`` is already set. Otherwise fills two
    dicts keyed by dataset name: ``self.img_label_file`` maps to a
    ``TSVFile`` over ``predictions_gt.tsv`` and ``self.img_qa_file`` maps
    to a ``TSVFile`` over ``QA_fileB.tsv`` for datasets where that file
    exists.
    """
    if self.img_label_file is not None:
        return
    self.img_label_file = {}
    self.img_qa_file = {}
    for name in self.datasets_names:
        label_dir = self.image_label_path[name]
        label_path = os.path.join(label_dir, 'predictions_gt.tsv')
        qa_path = os.path.join(label_dir, 'QA_fileB.tsv')
        started = time.time()
        self.img_label_file[name] = TSVFile(label_path)
        # The QA file is optional; only some datasets provide one.
        if os.path.exists(qa_path):
            self.img_qa_file[name] = TSVFile(qa_path)
        finished = time.time()
        logging.info("Open image label file {}, time: {}".format(
            label_path, (finished - started)))
def check_img_feature_file(self):
    """Lazily open every dataset's image feature TSV(s) and offset maps.

    No-op when ``self.img_feature_file`` is already set. Otherwise fills
    ``self.img_feature_file`` and ``self.img_feat_offset_map``, keyed by
    dataset name:

    * datasets in ``self.datasets_with_splits`` get nested dicts keyed by
      chunk id. When ``self.chunk_list`` is set, exactly those chunk
      sub-directories are used (for ``googlecc`` their feature files must
      exist); otherwise every ``<feature path>/*/<image_file_name>`` found
      by ``glob`` is used.
    * datasets in ``self.datasets_with_onesplit`` get a single ``TSVFile``
      over ``<feature path>/<image_file_name>`` and its map.

    Each feature file needs an ``imageid2idx.json`` beside it.

    Raises:
        AssertionError: if a required chunk file or ``imageid2idx.json`` is
            missing.
        ValueError: if a dataset is in neither supported list.
    """
    if self.img_feature_file is not None:
        return
    self.img_feature_file = {}
    self.img_feat_offset_map = {}

    def _load_offset_map(feature_fp):
        # imageid2idx.json sits next to the feature file; close it after use.
        offset_map_fp = os.path.join(os.path.dirname(feature_fp),
                                     'imageid2idx.json')
        assert os.path.isfile(
            offset_map_fp
        ), "Imageid2idx file {} does not exists!".format(offset_map_fp)
        with open(offset_map_fp, 'r') as fp:
            return json.load(fp)

    for dataset_name in self.datasets_names:
        logging.info("* Loading dataset {}".format(dataset_name))
        if dataset_name in self.datasets_with_splits:
            self.img_feature_file[dataset_name] = {}
            self.img_feat_offset_map[dataset_name] = {}
            if self.chunk_list is not None:
                # Use only the requested chunks.
                chunk_list = self.chunk_list
                chunk_file_list = [
                    os.path.join(self.image_feature_path[dataset_name],
                                 chunk_fp_id, self.image_file_name)
                    for chunk_fp_id in chunk_list
                ]
                if dataset_name == 'googlecc':
                    for chunk_fp in chunk_file_list:
                        assert os.path.exists(
                            chunk_fp
                        ), "Chunk file {} does not exists!".format(chunk_fp)
            else:
                # Discover every chunk directory on disk.
                chunk_file_list = glob.glob(
                    self.image_feature_path[dataset_name] +
                    "/*/{}".format(self.image_file_name))
                chunk_list = [
                    chunk_fp.split('/')[-2] for chunk_fp in chunk_file_list
                ]
            logging.info("* Load Image Chunks {}".format(len(chunk_list)))

            t_s_total = time.time()
            for chunk_fp in chunk_file_list:
                # The chunk id is the name of the feature file's directory.
                chunk_fp_id = chunk_fp.split('/')[-2]
                t_s = time.time()
                self.img_feature_file[dataset_name][chunk_fp_id] = TSVFile(
                    chunk_fp)
                self.img_feat_offset_map[dataset_name][
                    chunk_fp_id] = _load_offset_map(chunk_fp)
                t_e = time.time()
                logging.info("Open image chunk {}, time: {}".format(
                    chunk_fp_id, (t_e - t_s)))
            t_e_total = time.time()
            logging.info("Open total {} image chunks, time: {}".format(
                len(chunk_list), (t_e_total - t_s_total)))
            logging.info("Image chunk info: {}".format(
                '\n'.join(chunk_file_list)))
        elif dataset_name in self.datasets_with_onesplit:
            t_s = time.time()
            chunk_fp = os.path.join(self.image_feature_path[dataset_name],
                                    self.image_file_name)
            self.img_feature_file[dataset_name] = TSVFile(chunk_fp)
            self.img_feat_offset_map[dataset_name] = _load_offset_map(
                chunk_fp)
            t_e = time.time()
            logging.info("Open dataset {}, time: {}".format(
                chunk_fp, (t_e - t_s)))
        else:
            raise ValueError(
                "Not supported dataset: {}".format(dataset_name))
def __init__(self, yaml_file, args=None, tokenizer=None, seq_len=35,
             encoding="utf-8", corpus_lines=None, on_memory=True, **kwargs):
    """Load the YAML config, open the label/feature TSVs and index the corpus.

    Each corpus row holds an image key (``row[0]``), a label key
    (``row[1]``) and text_a (``row[2]``). Keys are ``_``-separated with the
    dataset name first. text_b is built from the dataset's label TSV
    (object classes). For label keys containing a ``qa`` field, it comes
    from the QA TSV's ``labels`` instead.

    Args:
        yaml_file: YAML config. ``corpus_file`` and the
            ``image_label_path`` entries are resolved relative to its
            directory. ``image_feature_path`` entries are resolved against
            ``args.data_dir`` when that is not None.
        args: training arguments. Read here: ``textb_sample_mode`` (unless
            passed in ``kwargs``), ``data_dir``, ``chunk_start_id``,
            ``chunk_end_id`` and ``use_gtlabels``.
        tokenizer: accessed as ``tokenizer.vocab``, so it must not be None
            despite the default.
        seq_len: stored as ``self.seq_len``.
        encoding: stored as ``self.encoding``.
        corpus_lines: initial ``self.corpus_lines``. When ``on_memory`` is
            True it is recomputed (two lines per kept row).
        on_memory: only True is supported.
        **kwargs: may contain ``textb_sample_mode``.

    Raises:
        ValueError: if ``on_memory`` is False.
        AssertionError: on inconsistent or missing corpus/label/feature data.
    """
    self.cfg = load_from_yaml_file(yaml_file)
    self.root = os.path.dirname(yaml_file)
    self.vocab = tokenizer.vocab
    self.tokenizer = tokenizer
    self.seq_len = seq_len
    self.on_memory = on_memory
    self.corpus_lines = corpus_lines  # number of non-empty lines in input corpus
    self.corpus_tsvfile = TSVFile(
        os.path.join(self.root, self.cfg['corpus_file']))
    # kwargs takes precedence; args is only consulted when kwargs lacks it.
    if 'textb_sample_mode' in kwargs:
        self.textb_sample_mode = kwargs['textb_sample_mode']
    else:
        self.textb_sample_mode = args.textb_sample_mode

    self.datasets_names = self.cfg['corpus'].split('_')
    # Datasets whose features are stored in per-chunk sub-directories.
    self.datasets_with_splits = [
        'googlecc', 'sbu', 'oi', 'objects365', 'tagoi'
    ]
    # Datasets whose features are stored in a single features.tsv.
    self.datasets_with_onesplit = ['coco', 'flickr30k', 'gqa']
    logging.info('Datasets: {}'.format(','.join(self.datasets_names)))
    self.image_label_path = self.cfg['image_label_path']
    for key, val in self.image_label_path.items():
        # get the absolute path
        if key in self.datasets_names:
            self.image_label_path[key] = os.path.join(self.root, val)
    self.image_feature_path = self.cfg['image_feature_path']
    self.image_file_name = 'features.tsv'
    if args.data_dir is not None:
        for key, val in self.image_feature_path.items():
            # get the absolute path
            if key in self.datasets_names:
                self.image_feature_path[key] = os.path.join(
                    args.data_dir, val)
            else:
                logging.info("Data {} with path {} is not used in the "
                             "training.".format(key, val))
    self.encoding = encoding
    self.current_doc = 0  # to avoid random sentence from same doc
    self.current_img = ''  # to avoid random sentence from same image
    self.args = args

    # for loading samples directly from file
    self.sample_counter = 0  # used to keep track of full epochs on file
    self.line_buffer = None  # keep second sentence of a pair in memory and use as first sentence in next pair

    # for loading samples in memory
    self.current_random_doc = 0
    self.num_docs = 0
    self.sample_to_doc = []  # map sample index to doc and line

    # Optional chunk id range [chunk_start_id, chunk_end_id); rows of chunked
    # datasets whose chunk id is outside it are skipped below.
    self.chunk_list = None
    if 0 <= args.chunk_start_id <= args.chunk_end_id and args.chunk_end_id >= 0:
        self.chunk_list = [
            str(c_i) for c_i in range(args.chunk_start_id,
                                      args.chunk_end_id)
        ]
        logging.info('Chunk list: {}'.format(','.join(self.chunk_list)))

    # load image tags and features
    t_start = time.time()
    self.img_label_file = None
    self.img_qa_file = None
    self.img_label_offset_map = None
    self.img_qa_offset_map = None
    self.img_feature_file = None
    self.img_feat_offset_map = None
    self.load_img_labels()
    self.load_img_tsv_features()
    t_end = time.time()
    logging.info(
        'Info: loading img features using {} secs'.format(t_end - t_start))

    # load samples into memory
    if on_memory:
        self.all_docs = []
        self.all_qa_docs = []
        self.imgid2labels = {}
        self.corpus_lines = 0
        max_tokens = 0
        for line_no in tqdm(range(len(self.corpus_tsvfile))):
            # Each doc becomes [id_info, text_a, text_b].
            doc = []
            row = self.corpus_tsvfile.seek(line_no)
            img_info = row[0].split('_')
            label_info = row[1].split('_')
            assert img_info[0] == label_info[
                0], "Dataset names for image and label do not match!"
            dataset_name = label_info[0]
            # The corpus uses the short prefix 'cc' for googlecc.
            if dataset_name == 'cc':
                dataset_name = 'googlecc'

            if dataset_name not in self.datasets_names:
                continue

            if dataset_name in self.datasets_with_splits:
                # Second-to-last key field is the feature chunk id.
                chunk_id = img_info[-2]
                if self.chunk_list is not None and chunk_id not in self.chunk_list:
                    continue
                else:
                    img_feat_offset_map = self.img_feat_offset_map[
                        dataset_name][chunk_id]
            else:
                img_feat_offset_map = self.img_feat_offset_map[
                    dataset_name]
            assert img_info[
                -1] in img_feat_offset_map, "{}: Image id {} cannot be found in image feature imageid_to_index file!".format(
                    row[0], img_info[-1])

            # append id info
            doc.append('%s|%s' % (row[0], row[1]))
            # append text_a info
            self.corpus_lines = self.corpus_lines + 1
            # Recorded before text_a is appended, so "line" is text_a's index (1).
            sample = {"doc_id": len(self.all_docs), "line": len(doc)}
            self.sample_to_doc.append(sample)
            assert len(row[2]) != 0, "Text_a is empty in {} : {}"\
                .format(dataset_name, row[0])
            doc.append(row[2])
            # append text_b info
            self.corpus_lines = self.corpus_lines + 1
            label_id = label_info[-1]
            if 'qa' in label_info:
                # QA label keys carry the image id second to last.
                assert img_info[-1] == label_info[
                    -2], "Image ids for image and qa do not match!"
                label_line_no = self.img_qa_offset_map[dataset_name][
                    label_id]
                rowb = self.img_qa_file[dataset_name].seek(label_line_no)
            else:
                assert img_info[-1] == label_info[
                    -1], "Image ids for image and label do not match!"
                label_line_no = self.img_label_offset_map[dataset_name][
                    label_id]
                rowb = self.img_label_file[dataset_name].seek(
                    label_line_no)
            assert label_id == rowb[0]
            results = json.loads(rowb[1])
            if 'qa' not in label_info:
                # Tag row: record/check image size, text_b = object classes.
                objects = results['objects']
                if row[0] not in self.imgid2labels:
                    self.imgid2labels[row[0]] = {
                        "image_h": results["image_h"],
                        "image_w": results["image_w"],
                        "boxes": None
                    }
                else:
                    assert results["image_h"] == self.imgid2labels[row[0]][
                        "image_h"], "Image_h does not match in image {}!".format(
                            row[0])
                    assert results["image_w"] == self.imgid2labels[row[0]][
                        "image_w"], "Image_w does not match in image {}!".format(
                            row[0])
                if args.use_gtlabels and 'gt_objects' in results:
                    # use ground-truth tags for text_b
                    textb = ' '.join([
                        cur_d['class'] for cur_d in results["gt_objects"]
                    ])
                else:
                    textb = ' '.join([cur_d['class'] for cur_d in objects])
            else:
                # QA row: image size comes from the image's tag row, text_b
                # from the QA row's labels.
                tag_label_line_no = self.img_label_offset_map[
                    dataset_name][img_info[-1]]
                tag_rowb = self.img_label_file[dataset_name].seek(
                    tag_label_line_no)
                tag_results = json.loads(tag_rowb[1])
                if row[0] not in self.imgid2labels:
                    self.imgid2labels[row[0]] = {
                        "image_h": tag_results["image_h"],
                        "image_w": tag_results["image_w"],
                        "boxes": None
                    }
                else:
                    assert tag_results["image_h"] == self.imgid2labels[row[0]][
                        "image_h"], "Image_h does not match in image {}!".format(
                            row[0])
                    assert tag_results["image_w"] == self.imgid2labels[row[0]][
                        "image_w"], "Image_w does not match in image {}!".format(
                            row[0])
                textb = ' '.join(results['labels'])
            assert len(textb) != 0, "Text_b is empty in {} : {}".format(
                dataset_name, row[1])
            doc.append(textb)

            # add to all_docs
            max_tokens = max(
                max_tokens,
                len(doc[1].split(' ')) + len(doc[2].split(' ')))
            if 'qa' in label_info:
                self.all_qa_docs.append({
                    "doc": doc,
                    "doc_id": len(self.all_docs)
                })
            self.all_docs.append(doc)

        self.num_docs = len(self.all_docs)
        logging.info("Max_tokens: {}".format(max_tokens))
    # load samples later lazily from disk
    else:
        raise ValueError("on_memory = False Not supported yet!")

    logging.info("Total docs - Corpus_lines: {}-{}".format(
        self.num_docs, self.corpus_lines))
    logging.info("Total QA docs - Corpus_lines: {}".format(
        len(self.all_qa_docs)))
class OscarTSVDataset(Dataset):
    """Image-text pre-training dataset backed by TSV files.

    Every corpus row (``cfg['corpus_file']``) has an image key ``row[0]``,
    a label key ``row[1]`` and text_a ``row[2]``. Keys are ``_``-separated:
    the dataset name comes first and the image / label id last, and chunked
    datasets carry the chunk id second to last. text_b comes from the
    dataset's label TSV (object classes). For label keys containing a
    ``qa`` field, it comes from the QA TSV's ``labels`` instead. Region
    features are read from per-dataset ``features.tsv`` files. Only the
    in-memory mode (``on_memory=True``) is supported.
    """

    def __init__(self, yaml_file, args=None, tokenizer=None, seq_len=35,
                 encoding="utf-8", corpus_lines=None, on_memory=True,
                 **kwargs):
        """Load the YAML config, open all TSV files and index the corpus.

        Args:
            yaml_file: YAML config. ``corpus_file`` and the
                ``image_label_path`` entries are resolved relative to its
                directory. ``image_feature_path`` entries are resolved
                against ``args.data_dir`` when that is not None.
            args: training arguments. Read here: ``textb_sample_mode``
                (unless passed in ``kwargs``), ``data_dir``,
                ``chunk_start_id``, ``chunk_end_id`` and ``use_gtlabels``.
                Other methods read further fields.
            tokenizer: accessed as ``tokenizer.vocab`` and used for
                ``tokenize``; must not be None despite the default.
            seq_len: max text length passed to
                ``convert_example_to_features``.
            encoding: stored as ``self.encoding``.
            corpus_lines: initial ``self.corpus_lines``. It is recomputed
                (two lines per kept row) while loading.
            on_memory: only True is supported.
            **kwargs: may contain ``textb_sample_mode``.

        Raises:
            ValueError: if ``on_memory`` is False or a configured dataset is
                not supported.
            AssertionError: on inconsistent or missing corpus/label/feature
                data.
        """
        self.cfg = load_from_yaml_file(yaml_file)
        self.root = os.path.dirname(yaml_file)
        self.vocab = tokenizer.vocab
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.on_memory = on_memory
        # number of non-empty lines in input corpus
        self.corpus_lines = corpus_lines
        self.corpus_tsvfile = TSVFile(
            os.path.join(self.root, self.cfg['corpus_file']))
        # kwargs takes precedence; args is only consulted when kwargs lacks it.
        if 'textb_sample_mode' in kwargs:
            self.textb_sample_mode = kwargs['textb_sample_mode']
        else:
            self.textb_sample_mode = args.textb_sample_mode

        self.datasets_names = self.cfg['corpus'].split('_')
        # Datasets whose features live in per-chunk sub-directories.
        self.datasets_with_splits = [
            'googlecc', 'sbu', 'oi', 'objects365', 'tagoi'
        ]
        # Datasets whose features live in a single features.tsv.
        self.datasets_with_onesplit = ['coco', 'flickr30k', 'gqa']
        logging.info('Datasets: {}'.format(','.join(self.datasets_names)))

        # Make the label paths of the datasets in use absolute.
        self.image_label_path = self.cfg['image_label_path']
        for key, val in self.image_label_path.items():
            if key in self.datasets_names:
                self.image_label_path[key] = os.path.join(self.root, val)
        self.image_feature_path = self.cfg['image_feature_path']
        self.image_file_name = 'features.tsv'
        if args.data_dir is not None:
            for key, val in self.image_feature_path.items():
                if key in self.datasets_names:
                    self.image_feature_path[key] = os.path.join(
                        args.data_dir, val)
                else:
                    logging.info("Data {} with path {} is not used in the "
                                 "training.".format(key, val))
        self.encoding = encoding
        self.current_doc = 0  # to avoid random sentence from same doc
        self.current_img = ''  # to avoid random sentence from same image
        self.args = args

        # for loading samples directly from file
        self.sample_counter = 0  # used to keep track of full epochs on file
        # keep second sentence of a pair in memory and use as first sentence
        # in next pair
        self.line_buffer = None

        # for loading samples in memory
        self.current_random_doc = 0
        self.num_docs = 0
        self.sample_to_doc = []  # map sample index to doc and line

        # Optional chunk id range [chunk_start_id, chunk_end_id). Rows of
        # chunked datasets outside it are skipped, and only these chunks are
        # opened by check_img_feature_file.
        self.chunk_list = None
        if 0 <= args.chunk_start_id <= args.chunk_end_id:
            self.chunk_list = [
                str(c_i)
                for c_i in range(args.chunk_start_id, args.chunk_end_id)
            ]
            logging.info('Chunk list: {}'.format(','.join(self.chunk_list)))

        # load image tags and features
        t_start = time.time()
        self.img_label_file = None
        self.img_qa_file = None
        self.img_label_offset_map = None
        self.img_qa_offset_map = None
        self.img_feature_file = None
        self.img_feat_offset_map = None
        self.load_img_labels()
        self.load_img_tsv_features()
        t_end = time.time()
        logging.info(
            'Info: loading img features using {} secs'.format(t_end - t_start))

        if not on_memory:
            # load samples later lazily from disk
            raise ValueError("on_memory = False Not supported yet!")

        # load samples into memory
        self.all_docs = []
        self.all_qa_docs = []
        self.imgid2labels = {}
        self.corpus_lines = 0
        max_tokens = 0
        for line_no in tqdm(range(len(self.corpus_tsvfile))):
            # Each doc becomes [id_info, text_a, text_b].
            doc = []
            row = self.corpus_tsvfile.seek(line_no)
            img_info = row[0].split('_')
            label_info = row[1].split('_')
            assert img_info[0] == label_info[0], \
                "Dataset names for image and label do not match!"
            dataset_name = self._canonical_dataset_name(label_info[0])
            if dataset_name not in self.datasets_names:
                continue

            if dataset_name in self.datasets_with_splits:
                chunk_id = img_info[-2]
                if (self.chunk_list is not None
                        and chunk_id not in self.chunk_list):
                    continue
                img_feat_offset_map = self.img_feat_offset_map[
                    dataset_name][chunk_id]
            else:
                img_feat_offset_map = self.img_feat_offset_map[dataset_name]
            assert img_info[-1] in img_feat_offset_map, (
                "{}: Image id {} cannot be found in image feature "
                "imageid_to_index file!".format(row[0], img_info[-1]))

            # append id info
            doc.append('%s|%s' % (row[0], row[1]))
            # append text_a info; "line" is recorded before the append, so it
            # is text_a's index within the doc.
            self.corpus_lines = self.corpus_lines + 1
            sample = {"doc_id": len(self.all_docs), "line": len(doc)}
            self.sample_to_doc.append(sample)
            assert len(row[2]) != 0, "Text_a is empty in {} : {}".format(
                dataset_name, row[0])
            doc.append(row[2])

            # append text_b info
            self.corpus_lines = self.corpus_lines + 1
            label_id = label_info[-1]
            is_qa = 'qa' in label_info
            if is_qa:
                # QA label keys carry the image id second to last.
                assert img_info[-1] == label_info[-2], \
                    "Image ids for image and qa do not match!"
                label_line_no = self.img_qa_offset_map[dataset_name][
                    label_id]
                rowb = self.img_qa_file[dataset_name].seek(label_line_no)
            else:
                assert img_info[-1] == label_info[-1], \
                    "Image ids for image and label do not match!"
                label_line_no = self.img_label_offset_map[dataset_name][
                    label_id]
                rowb = self.img_label_file[dataset_name].seek(label_line_no)
            assert label_id == rowb[0]
            results = json.loads(rowb[1])
            if not is_qa:
                objects = results['objects']
                self._register_image_size(row[0], results)
                if args.use_gtlabels and 'gt_objects' in results:
                    # use ground-truth tags for text_b
                    textb = ' '.join(
                        cur_d['class'] for cur_d in results["gt_objects"])
                else:
                    textb = ' '.join(cur_d['class'] for cur_d in objects)
            else:
                # QA rows take the image size from the image's tag row and
                # text_b from the QA row's labels.
                tag_label_line_no = self.img_label_offset_map[dataset_name][
                    img_info[-1]]
                tag_rowb = self.img_label_file[dataset_name].seek(
                    tag_label_line_no)
                tag_results = json.loads(tag_rowb[1])
                self._register_image_size(row[0], tag_results)
                textb = ' '.join(results['labels'])
            assert len(textb) != 0, "Text_b is empty in {} : {}".format(
                dataset_name, row[1])
            doc.append(textb)

            # add to all_docs
            max_tokens = max(
                max_tokens,
                len(doc[1].split(' ')) + len(doc[2].split(' ')))
            if is_qa:
                self.all_qa_docs.append({
                    "doc": doc,
                    "doc_id": len(self.all_docs)
                })
            self.all_docs.append(doc)

        self.num_docs = len(self.all_docs)
        logging.info("Max_tokens: {}".format(max_tokens))
        logging.info("Total docs - Corpus_lines: {}-{}".format(
            self.num_docs, self.corpus_lines))
        logging.info("Total QA docs - Corpus_lines: {}".format(
            len(self.all_qa_docs)))

    @staticmethod
    def _canonical_dataset_name(dataset_name):
        """Map the short corpus prefix ``cc`` to the ``googlecc`` dataset."""
        return 'googlecc' if dataset_name == 'cc' else dataset_name

    @staticmethod
    def _load_json(path):
        """Read a JSON file and close its handle once parsed."""
        with open(path, 'r') as fp:
            return json.load(fp)

    def _register_image_size(self, img_key, label_results):
        """Record ``img_key``'s height/width or check it matches the record.

        ``label_results`` is the decoded label-row JSON and must contain
        ``image_h`` and ``image_w``.
        """
        if img_key not in self.imgid2labels:
            self.imgid2labels[img_key] = {
                "image_h": label_results["image_h"],
                "image_w": label_results["image_w"],
                "boxes": None
            }
            return
        known = self.imgid2labels[img_key]
        assert label_results["image_h"] == known[
            "image_h"], "Image_h does not match in image {}!".format(img_key)
        assert label_results["image_w"] == known[
            "image_w"], "Image_w does not match in image {}!".format(img_key)

    def __len__(self):
        # corpus_lines counts text_a and text_b of every doc; only text_a
        # starts a sample, so this is one sample per doc.
        return self.corpus_lines - self.num_docs

    def get_img_info(self, idx):
        """Return ``{"height", "width"}`` of the image behind sample ``idx``."""
        sample = self.sample_to_doc[idx]
        # doc[0] is "<image key>|<label key>".
        img_id = self.all_docs[sample["doc_id"]][0].strip().split('|')[0]
        imgid2labels = self.imgid2labels[img_id]
        return {
            "height": imgid2labels["image_h"],
            "width": imgid2labels["image_w"]
        }

    def __getitem__(self, item):
        """Build one training sample.

        Returns:
            ``(img_feat, text_tensors, item)``. ``img_feat`` is truncated or
            zero-padded to ``args.max_img_seq_length`` rows. ``text_tensors``
            is ``(input_ids, input_mask, segment_ids, lm_label_ids,
            is_next, is_img_match)``.
        """
        cur_id = self.sample_counter
        self.sample_counter += 1
        if not self.on_memory:
            # after one epoch we start again from beginning of file
            if cur_id != 0 and (cur_id % len(self) == 0):
                raise ValueError("on_memory = False Not supported yet!")

        img_id, t1, t2, is_next_label, is_img_match = self.random_sent(item)

        # tokenize
        tokens_a = self.tokenizer.tokenize(t1)
        if self.args.use_b:
            tokens_b = self.tokenizer.tokenize(t2)
        else:
            tokens_b = None

        # combine to one sample
        cur_example = InputExample(guid=cur_id, tokens_a=tokens_a,
                                   tokens_b=tokens_b, is_next=is_next_label,
                                   img_id=img_id, is_img_match=is_img_match)

        # get image feature, truncated or zero-padded to a fixed length
        max_img_len = self.args.max_img_seq_length
        img_feat = self.get_img_feature(img_id)[:max_img_len]
        img_feat_len = img_feat.shape[0]
        if img_feat_len < max_img_len:
            padding_matrix = torch.zeros(
                (max_img_len - img_feat_len, img_feat.shape[1]))
            img_feat = torch.cat((img_feat, padding_matrix), 0)

        # transform sample to features
        cur_features = convert_example_to_features(self.args, cur_example,
                                                   self.seq_len,
                                                   self.tokenizer,
                                                   img_feat_len)

        return img_feat, (
            torch.tensor(cur_features.input_ids, dtype=torch.long),
            torch.tensor(cur_features.input_mask, dtype=torch.long),
            torch.tensor(cur_features.segment_ids, dtype=torch.long),
            torch.tensor(cur_features.lm_label_ids, dtype=torch.long),
            torch.tensor(cur_features.is_next),
            torch.tensor(cur_features.is_img_match),
        ), item

    def random_sent(self, index):
        """Draw a (possibly corrupted) text pair for sample ``index``.

        With ``rand_dice = random.random()``:

        * ``rand_dice > 0.5``: keep the pair, label 0.
        * ``rand_dice > args.texta_false_prob`` and t2 non-empty: replace
          t2 with a random line (see ``get_random_line``), label 1.
        * otherwise: replace t1 with a random text_a, label
          ``args.num_contrast_classes - 1``.

        :param index: int, index of sample.
        :return: (img_id, t1, t2, label, img_match_label). img_match_label
            is 1 when the replacement came from a different image id.
        """
        img_id, t1, t2 = self.get_corpus_line(index)
        rand_dice = random.random()
        if rand_dice > 0.5:
            label = 0
            random_img_id = img_id
        elif rand_dice > self.args.texta_false_prob and t2 != "":
            # wrong qa triplets
            random_img_id, t2 = self.get_random_line()
            label = 1
        else:
            # wrong retrieval triplets
            random_img_id, t1 = self.get_random_texta()
            # args.num_contrast_classes = 3 if args.texta_false_prob<0.5 and (args.texta_false_prob>0 or not args.use_b) else 2
            label = self.args.num_contrast_classes - 1

        img_match_label = 0 if img_id == random_img_id else 1

        assert len(t1) > 0
        assert len(t2) > 0 or not self.args.use_b
        return img_id, t1, t2, label, img_match_label

    def get_corpus_line(self, item):
        """Return ``(img_id, text_a, text_b)`` for sample ``item``.

        Also records the doc and image in ``current_doc`` / ``current_img``
        so random replacements avoid them. text_b is returned as "" unless
        ``args.use_b`` is set or the doc is a QA doc.
        """
        assert item < self.corpus_lines
        if not self.on_memory:
            raise ValueError("on_memory = False Not supported yet!")
        sample = self.sample_to_doc[item]
        doc = self.all_docs[sample["doc_id"]]
        img_id = doc[0].strip().split('|')[0]
        t1 = doc[sample["line"]]
        t2 = doc[sample["line"] + 1]
        # used later to avoid random nextSentence from same doc
        self.current_doc = sample["doc_id"]
        self.current_img = img_id

        assert t1 != ""
        if self.args.use_b or 'qa' in doc[0].split('_'):
            assert t2 != ""
        else:
            t2 = ""
        return img_id, t1, t2

    def get_random_line(self):
        """Get a random text line from another doc for the contrastive task.

        ``textb_sample_mode`` 0/1 samples from all docs (retrying up to ten
        times to avoid the current image); other modes sample from QA docs
        (avoiding the current doc). Mode 0 returns text_a or text_b; other
        modes return text_b.

        :return: (img_id, line)
        """
        if not self.on_memory:
            raise ValueError("on_memory = False Not supported yet!")
        # Similar to original tf repo: this loop should rarely run more than
        # once for large corpora.
        if self.textb_sample_mode in [0, 1]:
            # sample from all docs
            for _ in range(10):
                rand_doc_idx = random.randrange(0, len(self.all_docs))
                img_id = self.all_docs[rand_doc_idx][0].split('|')[0]
                # make sure the random line comes from another image
                if img_id != self.current_img:
                    break
            rand_doc = self.all_docs[rand_doc_idx]
        else:
            # sample from all qa docs
            for _ in range(10):
                rand_doc_idx = random.randrange(0, len(self.all_qa_docs))
                # only the doc must differ; the image may be the same
                if self.all_qa_docs[rand_doc_idx][
                        "doc_id"] != self.current_doc:
                    break
            rand_doc = self.all_qa_docs[rand_doc_idx]["doc"]
        img_id = rand_doc[0].split('|')[0]
        if self.textb_sample_mode == 0:
            # default oscar sample mode
            line = rand_doc[random.randrange(1, len(rand_doc))]
        else:
            # only sample text_b
            line = rand_doc[2]
        return img_id, line

    def get_random_texta(self):
        """Get a random text_a, retrying up to ten times to avoid the current image.

        :return: (img_id, text_a)
        """
        if not self.on_memory:
            raise ValueError("on_memory = False Not supported yet!")
        for _ in range(10):
            rand_doc_idx = random.randrange(0, len(self.all_docs))
            img_id = self.all_docs[rand_doc_idx][0].split('|')[0]
            # make sure the random line comes from another image
            if img_id != self.current_img:
                break
        rand_doc = self.all_docs[rand_doc_idx]
        img_id = rand_doc[0].split('|')[0]
        line = rand_doc[1]  # we want the text_a
        return img_id, line

    # tsv image labels
    def load_img_labels(self):
        """Open the label/QA TSVs and load their id->row offset maps."""
        self.check_img_label_file()
        self.check_img_label_offset_map()

    def check_img_label_file(self):
        """Lazily open each dataset's label TSV and, if present, QA TSV."""
        if self.img_label_file is not None:
            return
        self.img_label_file = {}
        self.img_qa_file = {}
        for dataset_name in self.datasets_names:
            img_label_file_path = os.path.join(
                self.image_label_path[dataset_name], 'predictions_gt.tsv')
            img_qa_file_path = os.path.join(
                self.image_label_path[dataset_name], 'QA_fileB.tsv')
            t_s = time.time()
            self.img_label_file[dataset_name] = TSVFile(img_label_file_path)
            if os.path.exists(img_qa_file_path):
                self.img_qa_file[dataset_name] = TSVFile(img_qa_file_path)
            t_e = time.time()
            logging.info("Open image label file {}, time: {}".format(
                img_label_file_path, (t_e - t_s)))

    def check_img_label_offset_map(self):
        """Lazily load each dataset's label and (optional) QA offset maps."""
        if self.img_label_offset_map is not None:
            return
        self.img_label_offset_map = {}
        self.img_qa_offset_map = {}
        for dataset_name in self.datasets_names:
            img_label_offset_map_path = os.path.join(
                self.image_label_path[dataset_name], 'imageid2idx.json')
            img_qa_offset_map_path = os.path.join(
                self.image_label_path[dataset_name], 'QA_qaid2idx.json')
            t_s = time.time()
            self.img_label_offset_map[dataset_name] = self._load_json(
                img_label_offset_map_path)
            if os.path.exists(img_qa_offset_map_path):
                self.img_qa_offset_map[dataset_name] = self._load_json(
                    img_qa_offset_map_path)
            t_e = time.time()
            logging.info("Load img label offset map: {}, time: {}".format(
                img_label_offset_map_path, (t_e - t_s)))

    def get_img_labels(self, image_id):
        """Return the ``labels`` stored for ``image_id`` in its label TSV.

        ``image_id`` is a ``_``-separated key whose first field is the
        dataset name (``cc`` means ``googlecc``) and whose last field is the
        image id. Returns None when the dataset or image id is unknown or
        the stored JSON has no ``labels`` entry.
        """
        self.check_img_label_file()
        self.check_img_label_offset_map()
        img_infos = image_id.split('_')
        dataset_name = self._canonical_dataset_name(img_infos[0])
        # Both the offset maps and the label files are keyed per dataset.
        offset_map = self.img_label_offset_map.get(dataset_name, {})
        img_id = img_infos[-1]
        if img_id not in offset_map:
            return None
        row = self.img_label_file[dataset_name].seek(offset_map[img_id])
        return json.loads(row[1]).get('labels')

    # tsv feature loading
    def load_img_tsv_features(self):
        """Open the feature TSVs and load their image-id offset maps."""
        self.check_img_feature_file()
        self.check_img_feature_offset_map()

    def check_img_feature_file(self):
        """Lazily open every dataset's feature TSV(s) and offset maps.

        Chunked datasets use ``self.chunk_list`` when set (for ``googlecc``
        the chunk files must exist), otherwise every chunk directory found
        by ``glob``. Each feature file needs an ``imageid2idx.json`` beside
        it.

        Raises:
            ValueError: if a dataset is in neither supported list.
        """
        if self.img_feature_file is not None:
            return
        self.img_feature_file = {}
        self.img_feat_offset_map = {}
        for dataset_name in self.datasets_names:
            logging.info("* Loading dataset {}".format(dataset_name))
            if dataset_name in self.datasets_with_splits:
                self.img_feature_file[dataset_name] = {}
                self.img_feat_offset_map[dataset_name] = {}
                if self.chunk_list is not None:
                    chunk_list = self.chunk_list
                    chunk_file_list = [
                        os.path.join(self.image_feature_path[dataset_name],
                                     chunk_fp_id, self.image_file_name)
                        for chunk_fp_id in chunk_list
                    ]
                    if dataset_name == 'googlecc':
                        for chunk_fp in chunk_file_list:
                            assert os.path.exists(
                                chunk_fp
                            ), "Chunk file {} does not exists!".format(
                                chunk_fp)
                else:
                    chunk_file_list = glob.glob(
                        self.image_feature_path[dataset_name] +
                        "/*/{}".format(self.image_file_name))
                    chunk_list = [
                        chunk_fp.split('/')[-2]
                        for chunk_fp in chunk_file_list
                    ]
                logging.info("* Load Image Chunks {}".format(
                    len(chunk_list)))

                t_s_total = time.time()
                for chunk_fp in chunk_file_list:
                    chunk_fp_id = chunk_fp.split('/')[-2]
                    t_s = time.time()
                    self.img_feature_file[dataset_name][
                        chunk_fp_id] = TSVFile(chunk_fp)
                    chunk_offsetmap = os.path.join(
                        os.path.dirname(chunk_fp), 'imageid2idx.json')
                    assert os.path.isfile(
                        chunk_offsetmap
                    ), "Imageid2idx file {} does not exists!".format(
                        chunk_offsetmap)
                    self.img_feat_offset_map[dataset_name][
                        chunk_fp_id] = self._load_json(chunk_offsetmap)
                    t_e = time.time()
                    logging.info("Open image chunk {}, time: {}".format(
                        chunk_fp_id, (t_e - t_s)))
                t_e_total = time.time()
                logging.info("Open total {} image chunks, time: {}".format(
                    len(chunk_list), (t_e_total - t_s_total)))
                logging.info("Image chunk info: {}".format(
                    '\n'.join(chunk_file_list)))
            elif dataset_name in self.datasets_with_onesplit:
                t_s = time.time()
                chunk_fp = os.path.join(
                    self.image_feature_path[dataset_name],
                    self.image_file_name)
                self.img_feature_file[dataset_name] = TSVFile(chunk_fp)
                chunk_offsetmap = os.path.join(os.path.dirname(chunk_fp),
                                               'imageid2idx.json')
                assert os.path.isfile(
                    chunk_offsetmap
                ), "Imageid2idx file {} does not exists!".format(
                    chunk_offsetmap)
                self.img_feat_offset_map[dataset_name] = self._load_json(
                    chunk_offsetmap)
                t_e = time.time()
                logging.info("Open dataset {}, time: {}".format(
                    chunk_fp, (t_e - t_s)))
            else:
                raise ValueError(
                    "Not supported dataset: {}".format(dataset_name))

    def check_img_feature_offset_map(self):
        """Load the image feature offset maps if they are not loaded yet.

        Raises:
            ValueError: if a dataset is in neither supported list.
        """
        if self.img_feat_offset_map is not None:
            return
        self.img_feat_offset_map = {}
        for dataset_name in self.datasets_names:
            logging.info("* Loading imageid2idx_map {}".format(dataset_name))
            if dataset_name in self.datasets_with_splits:
                # Chunked datasets need a nested dict keyed by chunk id.
                self.img_feat_offset_map[dataset_name] = {}
                chunk_file_list = glob.glob(
                    self.image_feature_path[dataset_name] +
                    "/*/imageid2idx.json")
                chunk_list = [
                    chunk_fp.split('/')[-2] for chunk_fp in chunk_file_list
                ]
                logging.info("* Load Image Chunks {}".format(
                    len(chunk_list)))
                t_s_total = time.time()
                for chunk_fp in chunk_file_list:
                    chunk_fp_id = chunk_fp.split('/')[-2]
                    t_s = time.time()
                    self.img_feat_offset_map[dataset_name][
                        chunk_fp_id] = self._load_json(chunk_fp)
                    t_e = time.time()
                    logging.info("Open image chunk {}, time: {}".format(
                        chunk_fp_id, (t_e - t_s)))
                t_e_total = time.time()
                logging.info("Open total {} image chunks, time: {}".format(
                    len(chunk_list), (t_e_total - t_s_total)))
            elif dataset_name in self.datasets_with_onesplit:
                t_s = time.time()
                chunk_fp = self.image_feature_path[
                    dataset_name] + "/imageid2idx.json"
                self.img_feat_offset_map[dataset_name] = self._load_json(
                    chunk_fp)
                t_e = time.time()
                logging.info("Open dataset {}, time: {}".format(
                    chunk_fp, (t_e - t_s)))
            else:
                raise ValueError(
                    "Not supported dataset: {}".format(dataset_name))

    def get_img_feature(self, image_id):
        """Decode the region features of ``image_id`` from its chunk file.

        Returns a float32 tensor of shape
        ``(num_boxes, args.img_feature_dim)``, or None if the image id is
        not in the offset map.
        """
        self.check_img_feature_file()
        self.check_img_feature_offset_map()
        img_infos = image_id.split('_')
        dataset_name = self._canonical_dataset_name(img_infos[0])
        img_id = img_infos[-1]
        if dataset_name in self.datasets_with_splits:
            chunk_id = img_infos[-2]
            img_feat_offset_map = self.img_feat_offset_map[dataset_name][
                chunk_id]
            img_feature_file = self.img_feature_file[dataset_name][chunk_id]
        else:
            img_feat_offset_map = self.img_feat_offset_map[dataset_name]
            img_feature_file = self.img_feature_file[dataset_name]
        if img_id not in img_feat_offset_map:
            return None
        arr = img_feature_file.seek(img_feat_offset_map[img_id])
        # Row layout used here: arr[1] = number of boxes, arr[-1] = base64
        # encoded float32 buffer.
        num_boxes = int(arr[1])
        feat = np.frombuffer(base64.b64decode(arr[-1]),
                             dtype=np.float32).reshape(
                                 (num_boxes, self.args.img_feature_dim))
        return torch.from_numpy(feat)
class CaptionTSVDataset(Dataset):
    """Captioning dataset whose features, tags and captions are referenced by
    a YAML config.

    ``__getitem__`` returns ``(image_key, example)`` where ``example`` is the
    output of ``CaptionTensorizer.tensorize_example``.
    """

    def __init__(self, yaml_file, tokenizer=None, add_od_labels=True,
                 disable_img_features=False,
                 keep_top_percentage_tag_conf_threshold=0.3,
                 keep_top_percentage_tag=1, max_img_seq_length=50,
                 max_seq_length=70, max_seq_a_length=40, is_train=True,
                 mask_prob=0.15, max_masked_tokens=3, **kwargs):
        """Constructor.

        Args:
            yaml_file: yaml file with all required data (image feature,
                caption, labels, etc).
            tokenizer: tokenizer for text processing.
            add_od_labels: whether to add labels from yaml file to BERT.
            disable_img_features: if True, image features are replaced by
                zeros of the same shape.
            keep_top_percentage_tag_conf_threshold: minimum tag confidence
                kept when ``keep_top_percentage_tag`` is not 1.
            keep_top_percentage_tag: fraction of the above-threshold tags
                (highest confidence first) to keep; 1 keeps every tag
                without filtering.
            max_img_seq_length: max image sequence length.
            max_seq_length: max text sequence length.
            max_seq_a_length: max caption sequence length.
            is_train: train or test mode.
            mask_prob: probability to mask a input token.
            max_masked_tokens: maximum number of tokens to be masked in one
                sentence.
            kwargs: other arguments.
        """
        self.yaml_file = yaml_file
        self.cfg = load_from_yaml_file(yaml_file)
        self.root = op.dirname(yaml_file)
        self.label_file = find_file_path_in_yaml(self.cfg['label'], self.root)
        self.feat_file = find_file_path_in_yaml(self.cfg['feature'], self.root)
        self.caption_file = find_file_path_in_yaml(self.cfg.get('caption'),
                                                   self.root)

        assert op.isfile(self.feat_file)
        if add_od_labels:
            assert op.isfile(self.label_file)
        if is_train:
            assert op.isfile(self.caption_file) and tokenizer is not None

        self.label_tsv = None if not self.label_file else TSVFile(
            self.label_file)
        self.feat_tsv = TSVFile(self.feat_file)
        if self.caption_file and op.isfile(self.caption_file):
            with open(self.caption_file, 'r') as f:
                self.captions = json.load(f)

        self.tokenizer = tokenizer
        self.tensorizer = CaptionTensorizer(self.tokenizer, max_img_seq_length,
                                            max_seq_length, max_seq_a_length,
                                            mask_prob, max_masked_tokens,
                                            is_train=is_train)
        self.add_od_labels = add_od_labels
        self.disable_img_features = disable_img_features
        self.keep_top_percentage_tag_conf_threshold = \
            keep_top_percentage_tag_conf_threshold
        self.keep_top_percentage_tag = keep_top_percentage_tag
        self.is_train = is_train
        self.kwargs = kwargs
        self.image_keys = self.prepare_image_keys()
        self.key2index = self.prepare_image_key_to_index()
        self.key2captions = self.prepare_image_key_to_captions()

    def get_valid_tsv(self):
        """Return the TSV used to enumerate images (label TSV preferred)."""
        # based on the order of file size
        if self.label_tsv:
            return self.label_tsv
        if self.feat_tsv:
            return self.feat_tsv
        return None

    def prepare_image_keys(self):
        """Return the first column (image key) of every row of the valid TSV."""
        tsv = self.get_valid_tsv()
        return [tsv.seek(i)[0] for i in range(tsv.num_rows())]

    def prepare_image_key_to_index(self):
        """Map each image key to its row index in the valid TSV."""
        # Reuse the keys prepare_image_keys() already read instead of seeking
        # through every TSV row a second time.
        keys = getattr(self, 'image_keys', None)
        if keys is None:
            keys = self.prepare_image_keys()
        return {key: i for i, key in enumerate(keys)}

    def prepare_image_key_to_captions(self):
        """In training mode, map each image key to the list of its captions.

        Returns None in test mode. Raises KeyError if a caption refers to an
        image key that is not in ``self.image_keys``.
        """
        if not self.is_train:
            return None
        key2captions = {key: [] for key in self.image_keys}
        for cap in self.captions:
            key2captions[cap['image_id']].append(cap['caption'])
        return key2captions

    def get_image_index(self, idx):
        """Return the TSV row index for dataset item ``idx``.

        In training mode ``idx`` indexes captions; otherwise it already is
        the row index.
        """
        if self.is_train:
            img_key = self.captions[idx]['image_id']
            return self.key2index[img_key]
        return idx

    def get_image_key(self, idx):
        """Return the image key for dataset item ``idx``."""
        return self.image_keys[self.get_image_index(idx)]

    def get_image_features(self, img_idx):
        """Decode the base64 float32 features of row ``img_idx``.

        Returns a ``(num_boxes, feature_dim)`` float tensor, all zeros when
        ``disable_img_features`` is set.
        """
        feat_info = json.loads(self.feat_tsv.seek(img_idx)[1])
        num_boxes = feat_info['num_boxes']
        features = np.frombuffer(base64.b64decode(feat_info['features']),
                                 np.float32).reshape((num_boxes, -1))
        if self.disable_img_features:
            # remove image features from fine-tuning stage (keep the shape)
            features = np.zeros_like(features)
        return torch.Tensor(features)

    def get_caption(self, idx):
        """Return the caption of item ``idx`` in training mode, else ''."""
        if self.is_train:
            return self.captions[idx]['caption']
        return ""

    def get_od_labels(self, img_idx):
        """Return the space-joined detection tags of row ``img_idx``.

        Returns None when ``add_od_labels`` is False. When
        ``keep_top_percentage_tag`` is not 1, keeps only tags whose
        confidence is >= ``keep_top_percentage_tag_conf_threshold``, sorted
        by descending confidence and truncated to that fraction.
        """
        if not self.add_od_labels:
            return None
        label_info = json.loads(self.label_tsv.seek(img_idx)[1])
        if self.keep_top_percentage_tag == 1:
            kept = label_info
        else:
            threshold = self.keep_top_percentage_tag_conf_threshold
            over_threshold = [lab for lab in label_info
                              if lab['conf'] >= threshold]
            by_conf = sorted(over_threshold, key=lambda lab: lab['conf'],
                             reverse=True)
            # top keep_top_percentage_tag fraction of the surviving tags
            kept = by_conf[:int(len(by_conf) * self.keep_top_percentage_tag)]
        return " ".join([lab['class'] for lab in kept])

    def get_caption_file_in_coco_format(self):
        """Return the path of the COCO-format twin of the caption file."""
        return op.splitext(self.caption_file)[0] + '_coco_format.json'

    def get_captions_by_key(self, key):
        """Return all captions of image ``key`` (training mode only)."""
        assert self.is_train, "cannot get captions for inference"
        return self.key2captions[key]

    def __getitem__(self, idx):
        img_idx = self.get_image_index(idx)
        img_key = self.image_keys[img_idx]
        features = self.get_image_features(img_idx)
        caption = self.get_caption(idx)
        od_labels = self.get_od_labels(img_idx)
        example = self.tensorizer.tensorize_example(caption, features,
                                                    text_b=od_labels)
        return img_key, example

    def __len__(self):
        if self.is_train:
            return len(self.captions)
        return self.get_valid_tsv().num_rows()
def __init__(self, tokenizer, args, split='train', is_train=True):
    """Initialise a ``RetrievalDataset`` instance.

    tokenizer: tokenizer to process caption text.
    args: configuration parameters including max_seq_length, etc.
    split: used to infer the data used for training or testing.
        All files are in .pt format of a dictionary with image keys and
        image features (pytorch tensors), captions (list of str, support
        multiple captions per image), labels (list of dictionary or str of
        all labels),
    """
    super(RetrievalDataset, self).__init__()
    self.img_file = args.img_feat_file
    caption_file = op.join(args.data_dir, '{}_captions.pt'.format(split))
    self.img_tsv = TSVFile(self.img_file)
    self.captions = torch.load(caption_file)
    self.img_keys = list(self.captions.keys())  # img_id as int
    if not isinstance(self.captions[self.img_keys[0]], list):
        # captions were stored JSON-encoded; decode them once up front
        self.captions = {
            k: json.loads(self.captions[k]) for k in self.img_keys
        }

    # get the image image_id to index map
    imgid2idx_file = op.join(op.dirname(self.img_file), 'imageid2idx.json')
    with open(imgid2idx_file) as f:
        self.image_id2idx = json.load(f)  # img_id as string

    if args.add_od_labels:
        label_file = op.join(op.dirname(self.img_file), "predictions.tsv")
        self.label_tsv = TSVFile(label_file)
        self.labels = {}
        # Set lookup: a list membership test here would be O(#images) for
        # every label row.
        wanted_keys = set(self.img_keys)
        for line_no in range(self.label_tsv.num_rows()):
            row = self.label_tsv.seek(line_no)
            image_id = int(row[0])
            if image_id not in wanted_keys:
                continue
            results = json.loads(row[1])
            # rows are either {"objects": [...], "image_h", "image_w"} or a
            # bare object list; the bare form gets fixed 600x800 sizes
            is_dict = isinstance(results, dict)
            objects = results['objects'] if is_dict else results
            self.labels[image_id] = {
                "image_h": results["image_h"] if is_dict else 600,
                "image_w": results["image_w"] if is_dict else 800,
                "class": [cur_d['class'] for cur_d in objects],
                "boxes": np.array([cur_d['rect'] for cur_d in objects],
                                  dtype=np.float32)
            }
        self.label_tsv._fp.close()
        self.label_tsv._fp = None

    if is_train:
        self.num_captions_per_img = args.num_captions_per_img_train
    else:
        self.num_captions_per_img = args.num_captions_per_img_val
        if args.eval_img_keys_file:
            # select a subset of image keys for evaluation. eg. COCO 1k and 5k
            # eval_img_keys_file is a list of image keys saved in tsv file
            with open(op.join(args.data_dir, args.eval_img_keys_file),
                      'r') as f:
                img_keys = f.readlines()
            self.img_keys = [int(k.strip()) for k in img_keys]
            self.captions = {k: self.captions[k] for k in self.img_keys}
            if args.add_od_labels:
                self.labels = {k: self.labels[k] for k in self.img_keys}

        if args.eval_caption_index_file:
            # hard negative image/caption indexs for retrieval re-rank setting.
            # useful for mini val set to monitor the performance during
            # training. However, it cannot be used together with cross image
            # evaluation.
            self.has_caption_indexs = True
            assert not args.cross_image_eval
            caption_index_file = op.join(args.data_dir,
                                         args.eval_caption_index_file)
            self.caption_indexs = torch.load(caption_index_file)
            if not isinstance(self.caption_indexs[self.img_keys[0]], list):
                self.caption_indexs = {
                    k: json.loads(self.caption_indexs[k])
                    for k in self.img_keys
                }
        else:
            self.has_caption_indexs = False
    self.is_train = is_train
    self.output_mode = args.output_mode
    self.tokenizer = tokenizer
    self.max_seq_len = args.max_seq_length
    self.max_img_seq_len = args.max_img_seq_length
    self.args = args
class RetrievalDataset(Dataset):
    """ Image/Text Retrieval Dataset"""

    def __init__(self, tokenizer, args, split='train', is_train=True):
        """
        tokenizer: tokenizer to process caption text.
        args: configuration parameters including max_seq_length, etc.
        split: used to infer the data used for training or testing.
            All files are in .pt format of a dictionary with image keys and
            image features (pytorch tensors), captions (list of str, support
            multiple captions per image), labels (list of dictionary or str
            of all labels),
        """
        super(RetrievalDataset, self).__init__()
        self.img_file = args.img_feat_file
        caption_file = op.join(args.data_dir, '{}_captions.pt'.format(split))
        self.img_tsv = TSVFile(self.img_file)
        self.captions = torch.load(caption_file)
        self.img_keys = list(self.captions.keys())  # img_id as int
        if not isinstance(self.captions[self.img_keys[0]], list):
            # captions were stored JSON-encoded; decode them once up front
            self.captions = {
                k: json.loads(self.captions[k]) for k in self.img_keys
            }

        # get the image image_id to index map
        imgid2idx_file = op.join(op.dirname(self.img_file),
                                 'imageid2idx.json')
        with open(imgid2idx_file) as f:
            self.image_id2idx = json.load(f)  # img_id as string

        if args.add_od_labels:
            label_file = op.join(op.dirname(self.img_file), "predictions.tsv")
            self.label_tsv = TSVFile(label_file)
            self.labels = {}
            # Set lookup: a list membership test here would be O(#images)
            # for every label row.
            wanted_keys = set(self.img_keys)
            for line_no in range(self.label_tsv.num_rows()):
                row = self.label_tsv.seek(line_no)
                image_id = int(row[0])
                if image_id not in wanted_keys:
                    continue
                results = json.loads(row[1])
                # rows are either {"objects": [...], "image_h", "image_w"} or
                # a bare object list; the bare form gets fixed 600x800 sizes
                is_dict = isinstance(results, dict)
                objects = results['objects'] if is_dict else results
                self.labels[image_id] = {
                    "image_h": results["image_h"] if is_dict else 600,
                    "image_w": results["image_w"] if is_dict else 800,
                    "class": [cur_d['class'] for cur_d in objects],
                    "boxes": np.array([cur_d['rect'] for cur_d in objects],
                                      dtype=np.float32)
                }
            self.label_tsv._fp.close()
            self.label_tsv._fp = None

        if is_train:
            self.num_captions_per_img = args.num_captions_per_img_train
        else:
            self.num_captions_per_img = args.num_captions_per_img_val
            if args.eval_img_keys_file:
                # select a subset of image keys for evaluation. eg. COCO 1k
                # and 5k. eval_img_keys_file is a list of image keys saved in
                # tsv file
                with open(op.join(args.data_dir, args.eval_img_keys_file),
                          'r') as f:
                    img_keys = f.readlines()
                self.img_keys = [int(k.strip()) for k in img_keys]
                self.captions = {k: self.captions[k] for k in self.img_keys}
                if args.add_od_labels:
                    self.labels = {k: self.labels[k] for k in self.img_keys}

            if args.eval_caption_index_file:
                # hard negative image/caption indexs for retrieval re-rank
                # setting. useful for mini val set to monitor the performance
                # during training. However, it cannot be used together with
                # cross image evaluation.
                self.has_caption_indexs = True
                assert not args.cross_image_eval
                caption_index_file = op.join(args.data_dir,
                                             args.eval_caption_index_file)
                self.caption_indexs = torch.load(caption_index_file)
                if not isinstance(self.caption_indexs[self.img_keys[0]],
                                  list):
                    self.caption_indexs = {
                        k: json.loads(self.caption_indexs[k])
                        for k in self.img_keys
                    }
            else:
                self.has_caption_indexs = False
        self.is_train = is_train
        self.output_mode = args.output_mode
        self.tokenizer = tokenizer
        self.max_seq_len = args.max_seq_length
        self.max_img_seq_len = args.max_img_seq_length
        self.args = args

    def get_image_caption_index(self, index):
        """Map a flat dataset index to ``(img_idx, [img_key, cap_idx])``.

        ``img_idx`` selects the image features; ``[img_key, cap_idx]``
        selects the caption, which may belong to another image in
        cross-image or caption-index evaluation.
        """
        if not self.is_train and self.args.cross_image_eval:
            per_image = self.num_captions_per_img * len(self.img_keys)
            img_idx = index // per_image
            cap_idx = index % per_image
            img_idx1 = cap_idx // self.num_captions_per_img
            cap_idx1 = cap_idx % self.num_captions_per_img
            return img_idx, [self.img_keys[img_idx1], cap_idx1]
        if not self.is_train and self.has_caption_indexs:
            img_idx = index // self.num_captions_per_img
            cap_idx = index % self.num_captions_per_img
            img_key1, cap_idx1 = self.caption_indexs[
                self.img_keys[img_idx]][cap_idx]
            return img_idx, [img_key1, cap_idx1]
        img_idx = index // self.num_captions_per_img
        cap_idx = index % self.num_captions_per_img
        return img_idx, [self.img_keys[img_idx], cap_idx]

    def get_label(self, index):
        """Return 1 if item ``index`` pairs an image with its own caption."""
        img_idx, cap_idx = self.get_image_caption_index(index)
        return 1 if self.img_keys[img_idx] == cap_idx[0] else 0

    def get_od_labels(self, img_key):
        """Return the detection tags of ``img_key`` as one string.

        Returns None when ``args.add_od_labels`` is False.
        """
        od_labels = None  # avoids returning an unbound local when disabled
        if self.args.add_od_labels:
            if isinstance(self.labels[img_key], str):
                od_labels = self.labels[img_key]
            else:
                od_labels = ' '.join(self.labels[img_key]['class'])
        return od_labels

    def tensorize_example(self, text_a, img_feat, text_b=None,
                          cls_token_segment_id=0, pad_token_segment_id=0,
                          sequence_a_segment_id=0, sequence_b_segment_id=1):
        """Build ``(input_ids, attention_mask, segment_ids, img_feat)``.

        Text is ``[CLS] a [SEP] (b [SEP])`` padded to ``max_seq_len``; image
        features are truncated or zero-padded to ``max_img_seq_len`` rows.
        ``args.att_mask_type`` "CLR" gives a 1-D mask; "CL", "CR", "LR"
        give a 2-D mask over caption (C), labels (L) and regions (R).
        """
        tokens_a = self.tokenizer.tokenize(text_a)
        if len(tokens_a) > self.max_seq_len - 2:
            tokens_a = tokens_a[:(self.max_seq_len - 2)]

        tokens = [self.tokenizer.cls_token
                  ] + tokens_a + [self.tokenizer.sep_token]
        segment_ids = [cls_token_segment_id
                       ] + [sequence_a_segment_id] * (len(tokens_a) + 1)
        seq_a_len = len(tokens)
        if text_b:
            tokens_b = self.tokenizer.tokenize(text_b)
            if len(tokens_b) > self.max_seq_len - len(tokens) - 1:
                tokens_b = tokens_b[:(self.max_seq_len - len(tokens) - 1)]
            tokens += tokens_b + [self.tokenizer.sep_token]
            segment_ids += [sequence_b_segment_id] * (len(tokens_b) + 1)

        seq_len = len(tokens)
        seq_padding_len = self.max_seq_len - seq_len
        tokens += [self.tokenizer.pad_token] * seq_padding_len
        segment_ids += [pad_token_segment_id] * seq_padding_len
        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)

        # image features
        img_len = img_feat.shape[0]
        if img_len > self.max_img_seq_len:
            img_feat = img_feat[0:self.max_img_seq_len, :]
            img_len = img_feat.shape[0]
            img_padding_len = 0
        else:
            img_padding_len = self.max_img_seq_len - img_len
            padding_matrix = torch.zeros((img_padding_len, img_feat.shape[1]))
            img_feat = torch.cat((img_feat, padding_matrix), 0)

        # generate attention_mask
        att_mask_type = self.args.att_mask_type
        if att_mask_type == "CLR":
            attention_mask = torch.tensor(
                [1] * seq_len + [0] * seq_padding_len +
                [1] * img_len + [0] * img_padding_len, dtype=torch.long)
        else:
            # use 2D mask to represent the attention; built directly as a
            # long tensor so it is not copy-constructed again below
            max_len = self.max_seq_len + self.max_img_seq_len
            attention_mask = torch.zeros((max_len, max_len), dtype=torch.long)
            # full attention of C-C, L-L, R-R
            c_start, c_end = 0, seq_a_len
            l_start, l_end = seq_a_len, seq_len
            r_start, r_end = self.max_seq_len, self.max_seq_len + img_len
            attention_mask[c_start:c_end, c_start:c_end] = 1
            attention_mask[l_start:l_end, l_start:l_end] = 1
            attention_mask[r_start:r_end, r_start:r_end] = 1
            if att_mask_type == 'CL':
                attention_mask[c_start:c_end, l_start:l_end] = 1
                attention_mask[l_start:l_end, c_start:c_end] = 1
            elif att_mask_type == 'CR':
                attention_mask[c_start:c_end, r_start:r_end] = 1
                attention_mask[r_start:r_end, c_start:c_end] = 1
            elif att_mask_type == 'LR':
                attention_mask[l_start:l_end, r_start:r_end] = 1
                attention_mask[r_start:r_end, l_start:l_end] = 1
            else:
                raise ValueError(
                    "Unsupported attention mask type {}".format(att_mask_type))

        input_ids = torch.tensor(input_ids, dtype=torch.long)
        segment_ids = torch.tensor(segment_ids, dtype=torch.long)
        return (input_ids, attention_mask, segment_ids, img_feat)

    def _sample_negative_image_index(self, img_idx):
        """Pick a uniformly random image index different from ``img_idx``.

        Uses one ``randrange`` draw, the same draw ``random.choice`` makes
        over the explicit list of other indices, without building that
        O(#images) list for every sample.
        """
        num_imgs = len(self.img_keys)
        if num_imgs < 2:
            raise ValueError(
                "need at least two images to sample a negative pair")
        neg = random.randrange(num_imgs - 1)
        return neg if neg < img_idx else neg + 1

    def __getitem__(self, index):
        img_idx, cap_idxs = self.get_image_caption_index(index)
        img_key = self.img_keys[img_idx]
        feature = self.get_image(img_key)
        caption = self.captions[cap_idxs[0]][cap_idxs[1]]
        od_labels = self.get_od_labels(img_key)
        example = self.tensorize_example(caption, feature, text_b=od_labels)

        if not self.is_train:
            label = 1 if img_key == cap_idxs[0] else 0
            return index, tuple(list(example) + [label])

        # select a negative pair
        img_idx_neg = self._sample_negative_image_index(img_idx)
        neg_key = self.img_keys[img_idx_neg]
        if random.random() <= 0.5:
            # randomly select a negative caption from a different image.
            cap_idx_neg = random.randint(0, self.num_captions_per_img - 1)
            caption_neg = self.captions[neg_key][cap_idx_neg]
            example_neg = self.tensorize_example(caption_neg, feature,
                                                 text_b=od_labels)
        else:
            # randomly select a negative image
            feature_neg = self.get_image(neg_key)
            od_labels_neg = self.get_od_labels(neg_key)
            example_neg = self.tensorize_example(caption, feature_neg,
                                                 text_b=od_labels_neg)

        example_pair = tuple(list(example) + [1] + list(example_neg) + [0])
        return index, example_pair

    def get_image(self, image_id):
        """Decode the base64 float32 features of ``image_id`` into a
        ``(num_boxes, -1)`` tensor."""
        image_idx = self.image_id2idx[str(image_id)]
        row = self.img_tsv.seek(image_idx)
        num_boxes = int(row[1])
        features = np.frombuffer(base64.b64decode(row[-1]),
                                 dtype=np.float32).reshape((num_boxes, -1))
        return torch.from_numpy(features)

    def __len__(self):
        if not self.is_train and self.args.cross_image_eval:
            return len(self.img_keys)**2 * self.num_captions_per_img
        return len(self.img_keys) * self.num_captions_per_img
class CaptionTSVDataset(Dataset):
    """Captioning dataset whose features, tags and captions are referenced by
    a YAML config, optionally also exposing per-word tag confidences.

    ``__getitem__`` returns ``(image_key, example)`` where ``example`` is the
    output of ``CaptionTensorizer.tensorize_example``.
    """

    def __init__(self, yaml_file, tokenizer=None, add_od_labels=True,
                 max_img_seq_length=50, max_seq_length=70,
                 max_seq_a_length=40, is_train=True, mask_prob=0.15,
                 max_masked_tokens=3, add_conf=False, **kwargs):
        """Constructor.

        Args:
            yaml_file: yaml file with all required data (image feature,
                caption, labels, etc).
            tokenizer: tokenizer for text processing.
            add_od_labels: whether to add labels from yaml file to BERT.
            max_img_seq_length: max image sequence length.
            max_seq_length: max text sequence length.
            max_seq_a_length: max caption sequence length.
            is_train: train or test mode.
            mask_prob: probability to mask a input token.
            max_masked_tokens: maximum number of tokens to be masked in one
                sentence.
            add_conf: whether to also pass tag confidences (see
                ``get_od_confidence``) to the tensorizer.
            kwargs: other arguments.
        """
        self.yaml_file = yaml_file
        self.cfg = load_from_yaml_file(yaml_file)
        self.root = op.dirname(yaml_file)
        self.label_file = find_file_path_in_yaml(self.cfg['label'], self.root)
        self.feat_file = find_file_path_in_yaml(self.cfg['feature'], self.root)
        self.caption_file = find_file_path_in_yaml(self.cfg.get('caption'),
                                                   self.root)

        assert op.isfile(self.feat_file)
        if add_od_labels:
            assert op.isfile(self.label_file)
        if is_train:
            assert op.isfile(self.caption_file) and tokenizer is not None

        self.label_tsv = None if not self.label_file else TSVFile(
            self.label_file)
        self.feat_tsv = TSVFile(self.feat_file)
        if self.caption_file and op.isfile(self.caption_file):
            with open(self.caption_file, 'r') as f:
                self.captions = json.load(f)

        self.tokenizer = tokenizer
        self.tensorizer = CaptionTensorizer(self.tokenizer, max_img_seq_length,
                                            max_seq_length, max_seq_a_length,
                                            mask_prob, max_masked_tokens,
                                            is_train=is_train)
        self.add_od_labels = add_od_labels
        self.is_train = is_train
        self.kwargs = kwargs
        self.image_keys = self.prepare_image_keys()
        self.key2index = self.prepare_image_key_to_index()
        self.key2captions = self.prepare_image_key_to_captions()
        self.add_conf = add_conf

    def get_valid_tsv(self):
        """Return the TSV used to enumerate images (label TSV preferred)."""
        # based on the order of file size
        if self.label_tsv:
            return self.label_tsv
        if self.feat_tsv:
            return self.feat_tsv
        return None

    def prepare_image_keys(self):
        """Return the first column (image key) of every row of the valid TSV."""
        tsv = self.get_valid_tsv()
        return [tsv.seek(i)[0] for i in range(tsv.num_rows())]

    def prepare_image_key_to_index(self):
        """Map each image key to its row index in the valid TSV."""
        # Reuse the keys prepare_image_keys() already read instead of seeking
        # through every TSV row a second time.
        keys = getattr(self, 'image_keys', None)
        if keys is None:
            keys = self.prepare_image_keys()
        return {key: i for i, key in enumerate(keys)}

    def prepare_image_key_to_captions(self):
        """In training mode, map each image key to the list of its captions.

        Returns None in test mode. Raises KeyError if a caption refers to an
        image key that is not in ``self.image_keys``.
        """
        if not self.is_train:
            return None
        key2captions = {key: [] for key in self.image_keys}
        for cap in self.captions:
            key2captions[cap['image_id']].append(cap['caption'])
        return key2captions

    def get_image_index(self, idx):
        """Return the TSV row index for dataset item ``idx``.

        In training mode ``idx`` indexes captions; otherwise it already is
        the row index.
        """
        if self.is_train:
            img_key = self.captions[idx]['image_id']
            return self.key2index[img_key]
        return idx

    def get_image_key(self, idx):
        """Return the image key for dataset item ``idx``."""
        return self.image_keys[self.get_image_index(idx)]

    def get_image_features(self, img_idx):
        """Decode the base64 float32 features of row ``img_idx`` into a
        ``(num_boxes, -1)`` float tensor."""
        feat_info = json.loads(self.feat_tsv.seek(img_idx)[1])
        num_boxes = feat_info['num_boxes']
        features = np.frombuffer(base64.b64decode(feat_info['features']),
                                 np.float32).reshape((num_boxes, -1))
        return torch.Tensor(features)

    def get_caption(self, idx):
        """Return the caption of item ``idx`` in training mode, else ''."""
        if self.is_train:
            return self.captions[idx]['caption']
        return ""

    def _load_label_info(self, img_idx):
        """Read and JSON-decode the label list stored in row ``img_idx``."""
        return json.loads(self.label_tsv.seek(img_idx)[1])

    def get_od_labels(self, img_idx, label_info=None):
        """Return the space-joined tag classes of row ``img_idx``.

        Returns None when ``add_od_labels`` is False. ``label_info`` may be
        passed in when the caller has already decoded the row.
        """
        if not self.add_od_labels:
            return None
        if label_info is None:
            label_info = self._load_label_info(img_idx)
        return " ".join([lab['class'] for lab in label_info])

    def get_od_confidence(self, img_idx, label_info=None):
        """Return one confidence per space-separated word of each tag class.

        Returns None when ``add_conf`` is False. ``label_info`` may be passed
        in when the caller has already decoded the row.
        """
        if not self.add_conf:
            return None
        if label_info is None:
            label_info = self._load_label_info(img_idx)
        od_confs = []
        # we need to repeat conf because some labels have spaces
        for info in label_info:
            repeats = len(info["class"].strip().split(" "))
            od_confs.extend([info["conf"]] * repeats)
        return od_confs

    def get_caption_file_in_coco_format(self):
        """Return the path of the COCO-format twin of the caption file."""
        return op.splitext(self.caption_file)[0] + '_coco_format.json'

    def get_captions_by_key(self, key):
        """Return all captions of image ``key`` (training mode only)."""
        assert self.is_train, "cannot get captions for inference"
        return self.key2captions[key]

    def __getitem__(self, idx):
        img_idx = self.get_image_index(idx)
        img_key = self.image_keys[img_idx]
        features = self.get_image_features(img_idx)
        caption = self.get_caption(idx)
        # Decode the label row once and share it between tags and confidences
        # instead of seeking and parsing the same row twice.
        label_info = None
        if self.add_od_labels or self.add_conf:
            label_info = self._load_label_info(img_idx)
        od_labels = self.get_od_labels(img_idx, label_info=label_info)
        od_confs = self.get_od_confidence(img_idx, label_info=label_info)
        example = self.tensorizer.tensorize_example(caption, features,
                                                    text_b=od_labels,
                                                    confs=od_confs)
        return img_key, example

    def __len__(self):
        if self.is_train:
            return len(self.captions)
        return self.get_valid_tsv().num_rows()