import json
import math
import pickle
import random
from multiprocessing import Pool
from pathlib import Path

import h5py
import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import LxmertTokenizer

# NOTE: AnswerTable, QAEvaluator, get_datum, text_process, and the BASEDIR
# constant are defined elsewhere in this repo (not shown in this excerpt).


class LXMERTDataset:
    def __init__(self, splits: str, qa_sets=None):
        """
        :param splits: comma-separated list of data sources to load
        :param qa_sets: if None, do nothing; otherwise, keep only the answers
            appearing in these datasets and remove all unlabeled data
            (MSCOCO captions).
        """
        self.name = splits
        self.sources = splits.split(',')

        # Load every source's caption file into a single list.
        self.data = []
        for source in self.sources:
            path = BASEDIR + "lxmert/caption_%s.json" % source.split('_')[1]
            with open(path) as f:
                self.data.extend(json.load(f))
        print("Load %d data from %s" % (len(self.data), self.name))

        # Create the answer table according to the qa_sets.
        self.answer_table = AnswerTable(qa_sets)
        print("Load an answer table of size %d." % (len(self.answer_table.ans2id_map())))

        # Normalize answers: remap each answer to its canonical form and drop
        # answers that are not used by the answer table.
        for datum in self.data:
            labelf = datum['labelf']
            for cat, labels in labelf.items():
                for label in labels:
                    for ans in list(label.keys()):
                        new_ans = self.answer_table.convert_ans(ans)
                        if self.answer_table.used(new_ans):
                            if ans != new_ans:
                                label[new_ans] = label.pop(ans)
                        else:
                            label.pop(ans)

    def __len__(self):
        return len(self.data)
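# box_position() is used by PretrainingDataset below but defined elsewhere in
# the repo. A minimal sketch of a plausible implementation, assuming it returns
# normalized (x1, y1, x2, y2) coordinates for every cell of a
# grid_size x grid_size grid, ordered row-major to match the flattened
# (n_grids, feat_dim) features read in __getitem__:
def box_position(grid_size):
    n_grids = grid_size ** 2
    boxes = np.zeros(shape=(n_grids, 4), dtype=np.float32)
    for i in range(grid_size):
        for j in range(grid_size):
            # Cell (i, j) spans [j/g, (j+1)/g] in x and [i/g, (i+1)/g] in y.
            x0, y0 = j / grid_size, i / grid_size
            x1, y1 = (j + 1) / grid_size, (i + 1) / grid_size
            boxes[i * grid_size + j] = (x0, y0, x1, y1)
    return boxes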
class PretrainingDataset(Dataset):
    def __init__(self, split='mscoco_minival', topk=-1, data_out=['img'],
                 verbose=True, args=None):
        self.data_out = data_out
        self.topk = topk
        self.verbose = verbose
        self.args = args
        self.datasets_dir = Path(self.args.datasets_dir)

        # Loading datasets to data
        self.sources = split.split(',')
        if self.verbose:
            print('Data sources: ', self.sources)

        self.answer_table = AnswerTable()
        # if self.verbose:
        #     print("Load an answer table of size %d." % (len(self.answer_table.ans2id_map())))

        self.img_ids_to_source = {}

        data = []
        for img_source in self.sources:
            with open(self.datasets_dir.joinpath(f'data/lxmert/{img_source}.json')) as f:
                _data = json.load(f)
                if self.verbose:
                    print(f"Loaded {len(_data)} data from", img_source)
            # source_img_ids.append([d['img_id'] for d in _data])
            for datum in _data:
                self.img_ids_to_source[datum['img_id']] = img_source
                datum['img_source'] = img_source
                datum['caption_only'] = args.caption_only
                datum['clustering'] = args.clustering
                datum['max_text_length'] = args.max_text_length
                datum['qa'] = args.task_qa
            data.extend(_data)

        # Normalize answers, as in LXMERTDataset above.
        if args.task_qa:
            for datum in data:
                labelf = datum['labelf']
                for _qa_source, labels in labelf.items():
                    for label in labels:
                        for ans in list(label.keys()):
                            new_ans = self.answer_table.convert_ans(ans)
                            if self.answer_table.used(new_ans):
                                if ans != new_ans:
                                    label[new_ans] = label.pop(ans)
                            else:
                                label.pop(ans)

        if self.topk > 0:
            data = data[:self.topk]
            if self.verbose:
                print(f"Use only {self.topk} data")

        if args.task_qa:
            self.evaluator = QAEvaluator(data)

        if args.clustering:
            # Map each image id to its precomputed cluster ids (one per grid cell).
            clustering_dir = self.datasets_dir.joinpath('clustering')
            with open(clustering_dir.joinpath(
                    f'{args.encoder}_{args.cluster_src}_mscoco_train_img_id_to_cluster_id_'
                    f'{args.n_centroids}_iter{args.n_iter}_d{args.feat_dim}_grid{args.grid_size}.pkl'), 'rb') as f:
                mscoco_train_img_id_to_cluster_id = pickle.load(f)
            with open(clustering_dir.joinpath(
                    f'{args.encoder}_{args.cluster_src}_mscoco_valid_img_id_to_cluster_id_'
                    f'{args.n_centroids}_iter{args.n_iter}_d{args.feat_dim}_grid{args.grid_size}.pkl'), 'rb') as f:
                mscoco_valid_img_id_to_cluster_id = pickle.load(f)
            with open(clustering_dir.joinpath(
                    f'{args.encoder}_{args.cluster_src}_vg_img_id_to_cluster_id_'
                    f'{args.n_centroids}_iter{args.n_iter}_d{args.feat_dim}_grid{args.grid_size}.pkl'), 'rb') as f:
                vg_img_id_to_cluster_id = pickle.load(f)

            self.data_source_to_cluster_data = {
                'mscoco_train': mscoco_train_img_id_to_cluster_id,
                'mscoco_minival': mscoco_valid_img_id_to_cluster_id,
                'mscoco_nominival': mscoco_valid_img_id_to_cluster_id,
                'vgnococo': vg_img_id_to_cluster_id,
            }

        # Flatten image-level records into per-sentence entries in parallel.
        with Pool(8) as pool:
            if self.verbose:
                data = [datum
                        for _data in tqdm(pool.imap(get_datum, data),
                                          total=len(data), ncols=100)
                        for datum in _data]
            else:
                data = [datum
                        for _data in pool.imap(get_datum, data)
                        for datum in _data]
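        # get_datum (defined elsewhere in the repo) is the worker applied
        # above. A rough sketch of its assumed behavior -- the real
        # implementation may differ; field names are taken from the keys
        # consumed in __getitem__, and 'sentf' (sentences per text source,
        # the companion of 'labelf') plus make_uid are assumptions:
        #
        # def get_datum(datum):
        #     data = []
        #     for text_source, sents in datum['sentf'].items():
        #         for sent_idx, sent in enumerate(sents):
        #             input_ids, n_tokens = text_process(sent)  # tokenize once, offline
        #             data.append({
        #                 'uid': make_uid(datum['img_id'], text_source, sent_idx),
        #                 'img_id': datum['img_id'],
        #                 'img_source': datum['img_source'],
        #                 'text_source': text_source,
        #                 'sent': sent,
        #                 'input_ids': input_ids,
        #                 'n_tokens': n_tokens,
        #             })
        #     return data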
        if self.args.target_exact_feat or self.args.feed_exact_feat or self.args.target_obj_id:
            if args.grid_model:
                self.data_source_to_h5_path = {
                    'mscoco_train': self.datasets_dir.joinpath(
                        f'COCO/features/{args.encoder}_train_grid{args.grid_size}.h5'),
                    'mscoco_minival': self.datasets_dir.joinpath(
                        f'COCO/features/{args.encoder}_valid_grid{args.grid_size}.h5'),
                    'mscoco_nominival': self.datasets_dir.joinpath(
                        f'COCO/features/{args.encoder}_valid_grid{args.grid_size}.h5'),
                    'vgnococo': self.datasets_dir.joinpath(
                        f'VG/features/{args.encoder}_grid{args.grid_size}.h5'),
                }
            else:
                self.data_source_to_h5_path = {
                    'mscoco_train': self.datasets_dir.joinpath(
                        'COCO/features/maskrcnn_train_boxes36.h5'),
                    'mscoco_minival': self.datasets_dir.joinpath(
                        'COCO/features/maskrcnn_valid_boxes36.h5'),
                    'mscoco_nominival': self.datasets_dir.joinpath(
                        'COCO/features/maskrcnn_valid_boxes36.h5'),
                    'vgnococo': self.datasets_dir.joinpath(
                        'VG/features/maskrcnn_boxes36.h5'),
                }

            for source, path in self.data_source_to_h5_path.items():
                assert path.is_file(), (source, path)

        # H5 files are opened lazily in __getitem__ (h5py handles cannot be
        # shared across DataLoader worker processes).
        self.source_to_h5 = None

        self.data = data

        if args.vis_mask_COCO_only:
            # Keep the subset whose text and image both come from MSCOCO.
            COCO_data = []
            for datum in self.data:
                if datum['text_source'] == 'mscoco' and 'mscoco' in datum['img_source']:
                    COCO_data.append(datum)
            self.COCO_data = COCO_data
            if self.verbose:
                print('# COCO captions:', len(self.COCO_data))

        if self.verbose:
            if 'sent' not in self.data_out:
                print("# all images:", len(self.data))
            else:
                print("# all sentences:", len(self.data))

        self.grid_size = args.grid_size
        self.n_grids = args.n_grids
        if self.args.grid_model:
            self.boxes = box_position(args.grid_size)
        else:
            self.n_boxes = args.n_boxes
            self.boxes = None

        self.tokenizer = LxmertTokenizer.from_pretrained(
            "bert-base-uncased", do_lower_case=True)

        self.max_text_length = args.max_text_length

        ###### Pretraining Objectives ######
        tasks = []
        if self.args.task_mask_lm:
            tasks.append('word_mask')
        if self.args.task_obj_predict:
            tasks.append('vis_mask')
        if self.args.task_matched:
            tasks.append('matched')
        if self.args.task_qa:
            tasks.append('qa')
        self.tasks = tasks

        if self.verbose:
            print('data_out:', self.data_out)

    def __len__(self):
        return len(self.data)
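    # Each entry of self.data is a per-sentence dict. Keys consumed below
    # (others may exist): 'uid', 'img_id', 'img_source', 'text_source',
    # 'sent', 'input_ids', 'n_tokens', and -- when task_qa is on -- 'label',
    # a dict mapping answer string to score.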
    def __getitem__(self, idx):
        out_dict = {}

        datum = self.data[idx]
        uid = datum['uid']
        out_dict['uid'] = uid
        out_dict['args'] = self.args

        ###### Image ######
        img_id = datum['img_id']

        if 'cluster_id' in self.data_out:
            # cluster_id = datum['cluster_id']
            img_id_to_cluster_id = self.data_source_to_cluster_data[datum['img_source']]
            cluster_id = img_id_to_cluster_id[img_id]
            assert cluster_id is not None, datum
            cluster_id = torch.from_numpy(cluster_id)
            out_dict['cluster_id'] = cluster_id

        # Open each source's H5 file lazily, once per worker process.
        if self.source_to_h5 is None:
            self.source_to_h5 = {}
            for source, path in self.data_source_to_h5_path.items():
                self.source_to_h5[source] = None
        source = self.img_ids_to_source[img_id]
        f = self.source_to_h5[source]
        if f is None:
            path = self.data_source_to_h5_path[source]
            f = h5py.File(path, 'r')
            self.source_to_h5[source] = f

        if 'feat' in self.data_out:
            if self.args.grid_model:
                feats = np.zeros(shape=(self.grid_size, self.grid_size, self.args.feat_dim),
                                 dtype=np.float32)
                f[f'{img_id}/features'].read_direct(feats)
                feats = np.reshape(feats, (self.n_grids, self.args.feat_dim))
                feats = torch.from_numpy(feats)
            else:
                feats = np.zeros(shape=(self.n_boxes, self.args.feat_dim), dtype=np.float32)
                f[f'{img_id}/features'].read_direct(feats)
                feats = torch.from_numpy(feats)
            out_dict['vis_feats'] = feats

        if 'obj_id' in self.data_out:
            obj_id = np.zeros(shape=(self.n_boxes,), dtype=int)
            f[f'{img_id}/obj_id'].read_direct(obj_id)
            obj_id = torch.from_numpy(obj_id)
            out_dict['obj_id'] = obj_id

        if self.args.grid_model:
            boxes = self.boxes
            boxes = torch.from_numpy(boxes)
        else:
            # Normalize the boxes (to 0 ~ 1)
            img_h = f[f'{img_id}/img_h'][()]
            img_w = f[f'{img_id}/img_w'][()]
            boxes = f[f'{img_id}/boxes'][()]
            boxes[:, (0, 2)] /= img_w
            boxes[:, (1, 3)] /= img_h
            # np.testing.assert_array_less(boxes, 1+1e-5)
            # np.testing.assert_array_less(boxes, 1+5e-2)
            np.testing.assert_array_less(-boxes, 0 + 1e-5)
            boxes = torch.from_numpy(boxes)
            boxes.clamp_(min=0.0, max=1.0)
        out_dict['boxes'] = boxes

        # if self.args.vis_sampling:
        #     sampled_idx = np.random.choice(self.n_grids, self.args.n_vis_sampling, replace=False)
        #     out_dict['boxes'] = boxes[sampled_idx]
        #     if 'cluster_id' in self.data_out:
        #         out_dict['cluster_id'] = cluster_id[sampled_idx]
        #     if 'feat' in self.data_out:
        #         out_dict['vis_feats'] = feats[sampled_idx]

        ###### Text ######
        sent = datum['sent']
        # input_ids, n_tokens = text_process(sent)
        input_ids, n_tokens = datum['input_ids'], datum['n_tokens']
        input_ids = torch.LongTensor(input_ids)

        out_dict['sent'] = sent
        out_dict['input_ids'] = input_ids
        out_dict['n_tokens'] = n_tokens

        # Flip -> image-text pair not matched
        if 'matched' in self.data_out and random.random() < 0.5:
            # Sample a replacement sentence from a different image.
            other_datum = self.data[random.randint(0, len(self.data) - 1)]
            while img_id == other_datum['img_id']:
                other_datum = self.data[random.randint(0, len(self.data) - 1)]
            other_sent = other_datum['sent']
            # other_input_ids, other_n_tokens = text_process(other_sent)
            other_input_ids, other_n_tokens = other_datum['input_ids'], other_datum['n_tokens']
            other_input_ids = torch.LongTensor(other_input_ids)

            out_dict['matched_label'] = 0
            out_dict['other_sent'] = other_sent
            out_dict['other_input_ids'] = other_input_ids
            out_dict['other_n_tokens'] = other_n_tokens
        else:
            out_dict['matched_label'] = 1
            # out_dict['other_sent'] = sent
            # out_dict['other_input_ids'] = input_ids
            out_dict['other_n_tokens'] = n_tokens

        if self.args.task_qa:
            # Convert the answer to an id; with multiple answers, sample one
            # proportionally to its score.
            if 'label' in datum:
                label = datum['label'].copy()
                if len(label) > 0:
                    for ans in list(label.keys()):
                        label[self.answer_table.ans2id(ans)] = label.pop(ans)
                    keys, values = zip(*label.items())
                    # single answer
                    if len(keys) == 1:
                        ans = keys[0]
                    # multiple answers -> sample one answer
                    else:
                        value_sum = sum(values)
                        prob = [value / value_sum for value in values]
                        choice = np.random.multinomial(1, prob).argmax()
                        ans = keys[choice]
                else:
                    ans = -1
            else:
                ans = -1
            out_dict['ans'] = ans

        if self.args.vis_mask_predict:
            if self.args.square_mask:
                # Mask a random square region of the grid.
                if self.args.vis_sampling:
                    grid_size = int(math.sqrt(self.args.n_vis_sampling))
                else:
                    grid_size = self.args.grid_size
                mask_size = random.randint(1, grid_size)
                vis_mask = torch.zeros(grid_size, grid_size)
                mask_position_h = random.randint(0, grid_size - mask_size)
                mask_position_w = random.randint(0, grid_size - mask_size)
                vis_mask[mask_position_h:mask_position_h + mask_size,
                         mask_position_w:mask_position_w + mask_size] = 1
                out_dict['vis_mask'] = vis_mask.flatten()
            else:
                # Mask a uniformly sampled number of random positions.
                if self.args.vis_sampling:
                    total_idx = list(range(self.args.n_vis_sampling))
                    n_max_mask = self.args.n_vis_sampling
                else:
                    if self.args.grid_model:
                        total_idx = list(range(self.n_grids))
                        n_max_mask = self.n_grids
                    else:
                        total_idx = list(range(self.args.n_boxes))
                        n_max_mask = self.n_boxes
                n_masks = random.randint(1, n_max_mask)
                vis_mask = torch.zeros(n_max_mask)
                vis_mask_idx = np.random.choice(total_idx, n_masks, replace=False)
                vis_mask_idx = torch.from_numpy(vis_mask_idx)
                vis_mask[vis_mask_idx] = 1
                out_dict['vis_mask'] = vis_mask
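            # Illustration (not executed): with grid_size=4 and mask_size=2 at
            # position (h=1, w=2), the square mask above is
            #   [[0, 0, 0, 0],
            #    [0, 0, 1, 1],
            #    [0, 0, 1, 1],
            #    [0, 0, 0, 0]]
            # flattened to a length-16 0/1 vector; the random branch instead
            # sets a uniformly sampled number of arbitrary positions to 1.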
            # if self.args.VMP_smart:
            #     if self.args.square_mask:
            #         if self.args.vis_sampling:
            #             grid_size = int(math.sqrt(self.args.n_vis_sampling))
            #         else:
            #             grid_size = self.args.grid_size
            #         mask_size = random.randint(1, grid_size)
            #         vis_mask = torch.zeros(grid_size, grid_size)
            #         mask_position_h = random.randint(0, grid_size - mask_size)
            #         mask_position_w = random.randint(0, grid_size - mask_size)
            #         vis_mask[mask_position_h:mask_position_h + mask_size,
            #                  mask_position_w:mask_position_w + mask_size] = 1
            #         out_dict['vis_mask_2'] = vis_mask.flatten()
            #     else:
            #         if self.args.vis_sampling:
            #             total_idx = list(range(self.args.n_vis_sampling))
            #             n_max_mask = self.args.n_vis_sampling
            #         else:
            #             if self.args.grid_model:
            #                 total_idx = list(range(self.n_grids))
            #                 n_max_mask = self.n_grids
            #             else:
            #                 total_idx = list(range(self.args.n_boxes))
            #                 n_max_mask = self.n_boxes
            #         n_masks = random.randint(1, n_max_mask)
            #         vis_mask = torch.zeros(n_max_mask)
            #         vis_mask_idx = np.random.choice(total_idx, n_masks, replace=False)
            #         vis_mask_idx = torch.from_numpy(vis_mask_idx)
            #         vis_mask[vis_mask_idx] = 1
            #         out_dict['vis_mask_2'] = vis_mask
        else:
            # Bernoulli masking at a fixed rate.
            if self.args.grid_model:
                if self.args.vis_sampling:
                    vis_mask = torch.bernoulli(
                        torch.full((self.args.n_vis_sampling,), self.args.obj_mask_rate)).bool()
                else:
                    vis_mask = torch.bernoulli(
                        torch.full((self.n_grids,), self.args.obj_mask_rate)).bool()
                out_dict['vis_mask'] = vis_mask
            else:
                vis_mask = torch.bernoulli(
                    torch.full((self.n_boxes,), self.args.obj_mask_rate)).bool()
                out_dict['vis_mask'] = vis_mask

        if self.args.vis_mask_COCO_only:
            # Pair this example with a COCO caption. Walk through COCO_data in
            # order, pass by pass; for the final partial pass, sample
            # uniformly instead.
            quotient = idx // len(self.COCO_data)
            if len(self.data) - quotient * len(self.COCO_data) < len(self.COCO_data):
                coco_idx = random.randint(0, len(self.COCO_data) - 1)
            else:
                coco_idx = idx % len(self.COCO_data)
            coco_datum = self.COCO_data[coco_idx]
            assert coco_datum['text_source'] == 'mscoco'
            assert 'mscoco' in coco_datum['img_source']

            coco_input_ids, coco_n_tokens = coco_datum['input_ids'], coco_datum['n_tokens']
            coco_input_ids = torch.LongTensor(coco_input_ids)
            out_dict['COCO_input_ids'] = coco_input_ids
            out_dict['COCO_n_tokens'] = coco_n_tokens

            if 'cluster_id' in self.data_out:
                img_id = coco_datum['img_id']
                # cluster_id = datum['cluster_id']
                img_id_to_cluster_id = self.data_source_to_cluster_data[coco_datum['img_source']]
                cluster_id = img_id_to_cluster_id[img_id]
                assert cluster_id is not None, coco_datum
                cluster_id = torch.from_numpy(cluster_id)
                out_dict['COCO_cluster_id'] = cluster_id

        return out_dict
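# Usage sketch (hypothetical, not executed): the field names on `args` follow
# those consumed above; a real run needs the preprocessed LXMERT json/h5 files,
# and a user-defined collate_fn, since out_dict mixes strings, ints, and
# variable-length tensors.
#
# from argparse import Namespace
# from torch.utils.data import DataLoader
#
# args = Namespace(
#     datasets_dir='/path/to/datasets', caption_only=False, clustering=False,
#     max_text_length=20, task_qa=True, task_mask_lm=True,
#     task_obj_predict=True, task_matched=True, grid_model=True, grid_size=8,
#     n_grids=64, n_boxes=36, feat_dim=2048, encoder='maskrcnn',
#     target_exact_feat=True, feed_exact_feat=False, target_obj_id=False,
#     vis_mask_predict=True, square_mask=True, vis_sampling=False,
#     obj_mask_rate=0.15, vis_mask_COCO_only=False)
#
# dataset = PretrainingDataset(split='mscoco_minival', topk=100,
#                              data_out=['img', 'feat', 'sent', 'matched'],
#                              verbose=True, args=args)
# loader = DataLoader(dataset, batch_size=32, num_workers=4,
#                     collate_fn=my_collate_fn)  # my_collate_fn: user-defined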