def __init__(self, max_seq_length):
    super().__init__()
    self.max_seq_length = max_seq_length
    self.tokenizer = BertTokenizer.from_pretrained(
        "bert-base-uncased", do_lower_case=True)

    # Build model
    set_visual_config(args)
    self.model = LXRTPretraining.from_pretrained(
        "bert-base-uncased",
        task_mask_lm=args.task_mask_lm,
        task_obj_predict=args.task_obj_predict,
        task_matched=args.task_matched,
        task_qa=args.task_qa,
        visual_losses=args.visual_losses,
        num_answers=train_dset.answer_table.num_answers)

    # Weight initialization and loading
    if args.from_scratch:
        print("Train from Scratch: re-initialize all BERT weights.")
        self.model.apply(self.model.init_bert_weights)
    if args.load is not None:
        self.load(args.load)
    if args.load_lxmert is not None:
        # Loading LXMERT weights does not load the answer head.
        self.load_lxmert(args.load_lxmert)

    # GPU options
    self.model = self.model.cuda()
    if args.multiGPU:
        self.model = nn.DataParallel(self.model)
def __init__(self, args, max_seq_length, mode='x'):
    super().__init__()
    self.max_seq_length = max_seq_length

    from lxrt.modeling import LXRTFeatureExtraction as VisualBertForLXRFeature, VISUAL_CONFIG
    set_visual_config(args, VISUAL_CONFIG)

    # Use the BERT tokenizer
    self.tokenizer = BertTokenizer.from_pretrained(
        "bert-base-uncased", do_lower_case=True)

    # Build LXRT model
    self.model = VisualBertForLXRFeature.from_pretrained(
        "bert-base-uncased", mode=mode)

    if args.from_scratch:
        print("Re-initializing all the weights")
        self.model.apply(self.model.init_bert_weights)

    self.load_pretrain_head = args.get("load_pretrain_head", False)
    if self.load_pretrain_head:
        from lxmert.src.lxrt.modeling import BertPreTrainingHeads
        self.pretrained_head = BertPreTrainingHeads(
            self.model.config,
            self.model.bert.embeddings.word_embeddings.weight)
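# A minimal, standalone sketch of the weight-tying pattern behind
# BertPreTrainingHeads above: the masked-LM decoder reuses the word-embedding
# matrix passed into the constructor rather than learning a separate output
# projection. The sizes below are the usual bert-base-uncased ones, used here
# purely for illustration.
import torch
import torch.nn as nn

vocab_size, hid_dim = 30522, 768
embeddings = nn.Embedding(vocab_size, hid_dim)
decoder = nn.Linear(hid_dim, vocab_size, bias=False)
decoder.weight = embeddings.weight            # shared parameter, not a copy
bias = nn.Parameter(torch.zeros(vocab_size))

hidden = torch.randn(2, 5, hid_dim)           # fake encoder output
logits = decoder(hidden) + bias               # shape: (2, 5, vocab_size)
assert decoder.weight is embeddings.weight    # tied: one tensor, two views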
def __init__(self, num_answers):
    super().__init__()

    # Build the LXRT encoder
    self.lxrt_encoder = LXRTEncoder(args, max_seq_length=MAX_VQA_LENGTH)
    hid_dim = self.lxrt_encoder.dim

    # VQA answer head: a two-layer MLP over the pooled cross-modal feature
    self.logit_fc = nn.Sequential(
        nn.Linear(hid_dim, hid_dim * 2),
        GeLU(),
        BertLayerNorm(hid_dim * 2, eps=1e-12),
        nn.Linear(hid_dim * 2, num_answers))
    self.logit_fc.apply(self.lxrt_encoder.model.init_bert_weights)

    self.tokenizer = BertTokenizer.from_pretrained(
        "bert-base-uncased", do_lower_case=True)
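# For reference, a standalone sketch of the same two-layer answer head using
# plain torch.nn equivalents of the LXRT GeLU / BertLayerNorm modules. The
# hidden size and answer count are hypothetical, chosen only for illustration.
import torch
import torch.nn as nn

hid_dim, num_answers = 768, 3129
logit_fc = nn.Sequential(
    nn.Linear(hid_dim, hid_dim * 2),
    nn.GELU(),
    nn.LayerNorm(hid_dim * 2, eps=1e-12),
    nn.Linear(hid_dim * 2, num_answers),
)

pooled = torch.randn(4, hid_dim)   # fake pooled cross-modal features
logits = logit_fc(pooled)          # one logit per candidate answer
print(logits.shape)                # torch.Size([4, 3129])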
def __init__(self, args, max_seq_length, mode='x'):
    super().__init__()
    self.max_seq_length = max_seq_length
    set_visual_config(args)

    # Use the BERT tokenizer
    self.tokenizer = BertTokenizer.from_pretrained(
        "bert-base-uncased", do_lower_case=True)

    # Build LXRT model
    self.model = VisualBertForLXRFeature.from_pretrained(
        "bert-base-uncased", mode=mode)

    if args.from_scratch:
        print("initializing all the weights")
        self.model.apply(self.model.init_bert_weights)
def __init__(self, files, file_type='train-match', batch_size=256):
    self.files = files
    self.file_type = file_type
    self.batch_size = batch_size

    # Generate the dictionary of labels; punctuation in each label is
    # replaced with spaces.
    self.dict_multimodal_labels = {}
    label_file = os.path.join(KDD_DATA, "../data/multimodal_labels.txt")
    for line in open(label_file, encoding='utf-8'):
        arr = line.strip().split("\t")
        label = (arr[1].replace(",", " ").replace(".", " ")
                       .replace("(", " ").replace(")", " "))
        self.dict_multimodal_labels[arr[0]] = label.strip()

    self.tokenizer = BertTokenizer.from_pretrained(
        "../user_data", do_lower_case=True)
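# Worked example of the label-normalization step above on a fabricated input
# line (the "id<TAB>label" format is assumed from the split("\t") call). The
# chained replace() calls simply turn punctuation into spaces;
# re.sub(r"[,.()]", " ", ...) would be an equivalent one-liner.
line = "42\tdress(women's,long-sleeve).red\n"
arr = line.strip().split("\t")
label = (arr[1].replace(",", " ").replace(".", " ")
               .replace("(", " ").replace(")", " "))
print(arr[0], "->", label.strip())  # 42 -> dress women's long-sleeve  red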
def __init__(self, args, max_seq_length, mode='x'):
    super().__init__()
    self.max_seq_length = max_seq_length

    # Use the BERT tokenizer
    self.tokenizer = BertTokenizer.from_pretrained(
        "bert-base-uncased", do_lower_case=True)

    config = BertConfig.from_json_file(args.config_file)

    # Build model
    self.model = InterBertForVLTasks.from_pretrained(
        "snap/pretrained/inter_bert/pytorch_model.bin", config)

    if args.from_scratch:
        print("initializing all the weights")
        self.model.apply(self.model.init_bert_weights)
def __init__(self, max_seq_length):
    super().__init__()
    self.max_seq_length = max_seq_length
    self.tokenizer = BertTokenizer.from_pretrained(
        "bert-base-uncased", do_lower_case=True)

    # Build model
    self.model = LXRTPretraining.from_pretrained(
        "bert-base-uncased",
        args=args,
        task_mask_lm=args.task_mask_lm,
        task_obj_predict=args.task_obj_predict,
        task_matched=args.task_matched,
        task_qa=args.task_qa,
        visual_losses=args.visual_losses,
        num_answers=(args.num_answers if args.get("num_answers", None)
                     else train_tuple.dataset.answer_table.num_answers))

    # Weight initialization and loading
    if args.from_scratch:
        print("Train from Scratch: re-initialize all BERT weights.")
        self.model.apply(self.model.init_bert_weights)

    if args.get("use_tag_symbolic_embedding", False):
        self.model.bert.embeddings.initialize_symbolic_embeddings(
            symbolic_vocab.get_symbolic_list(self.tokenizer))
        self.model.special_initialize_pretraining_head()
    if args.get("hybrid_embedding", False):
        self.model.bert.embeddings.initialize_visual_position_type_embeddings()

    if args.load_lxmert is not None:
        # Loading LXMERT weights does not load the answer head.
        self.load_lxmert(args.load_lxmert)

    self.model = self.model.cuda()
    if args.multiGPU:
        self.model = nn.DataParallel(self.model)

    self.global_step = 0
def __init__(self, args, max_seq_length, mode='x', attention=False):
    super().__init__()
    print(f"Making {__name__}")
    self.max_seq_length = max_seq_length
    set_visual_config(args)

    # Use the BERT tokenizer
    self.tokenizer = BertTokenizer.from_pretrained(
        "bert-base-uncased", do_lower_case=True)
    print("Made Tokenizer")

    # Build LXRT model
    self.model = VisualBertForLXRFeature.from_pretrained(
        "bert-base-uncased", mode=mode, attention=attention)
    print("Made VisualBertForLXRFeature")

    if args.from_scratch:
        print("initializing all the weights")
        self.model.apply(self.model.init_bert_weights)
    print(f"Done {__name__}")
def __init__(self, ann_file, pretrained_model_name, tokenizer=None,
             seq_len=64, min_seq_len=64, encoding="utf-8",
             on_memory=True, **kwargs):
    assert on_memory, "only support on_memory mode!"

    self.tokenizer = (tokenizer if tokenizer is not None
                      else BertTokenizer.from_pretrained(pretrained_model_name))
    self.vocab = self.tokenizer.vocab
    self.seq_len = seq_len
    self.min_seq_len = min_seq_len
    self.on_memory = on_memory
    self.ann_file = ann_file
    self.encoding = encoding
    self.test_mode = False
    self.do_no_fill = False
    self.use_mismatch_objective = args.get("task_matched", False)

    # Load samples into memory.
    if on_memory:
        if self.use_mismatch_objective:
            self.load_corpus_with_passages_preprocess()
        else:
            self.corpus = self.load_corpus()
    if args.get("presegment_sentence", False):
        self.presegment_sentence()
    print("Using {} with {} data.\n\n".format(self.ann_file, len(self)))
def __init__(self, dataset: VQADataset, args):
    super().__init__()
    self.raw_dataset = dataset

    if args.tiny:
        topk = TINY_IMG_NUM
    elif args.fast:
        topk = FAST_IMG_NUM
    else:
        topk = None

    self.limit_to_symbolic_split = args.get("limit_to_symbolic_split", False)
    if self.limit_to_symbolic_split:
        dataDir = "/local/harold/ubert/bottom-up-attention/data/vg/"
        coco_ids = set()
        self.mapping_cocoid_to_imageid = {}
        with open(os.path.join(dataDir, 'image_data.json')) as f:
            metadata = json.load(f)
            for item in metadata:
                if item['coco_id']:
                    coco_ids.add(int(item['coco_id']))
                    self.mapping_cocoid_to_imageid[
                        int(item['coco_id'])] = item["image_id"]

        from lib.data.vg_gqa import vg_gqa
        self.vg_gqa = vg_gqa(
            None,
            split="val" if self.raw_dataset.name == "minival" else "train",
            transforms=None,
            num_im=-1)

    self.custom_coco_data = args.get("custom_coco_data", False)
    self.use_h5_file = args.get("use_h5_file", False)
    if self.use_h5_file:
        self.image_feature_dataset = ImageFeatureDataset.create(
            dataset.splits, Split2ImgFeatPath,
            on_memory=args.get("on_memory", False))
        self.ids_to_index = self.image_feature_dataset.ids_to_index

        # Screen data: keep only questions whose image has features.
        used_data = []
        for datum in self.raw_dataset.data:
            if datum['img_id'] in self.ids_to_index:
                used_data.append(datum)
    else:
        # Load detection features into img_data.
        img_data = []
        for split in dataset.splits:
            # Minival is 5K images in MS COCO, used in evaluating
            # VQA/LXMERT pre-training. It is saved as the top 5K
            # features in val2014_***.tsv.
            load_topk = 5000 if (split == 'minival' and topk is None) else topk
            img_data.extend(load_obj_tsv(
                os.path.join(MSCOCO_IMGFEAT_ROOT,
                             '%s_obj36.tsv' % (SPLIT2NAME[split])),
                topk=load_topk))

        # Convert the image list to a dict.
        self.imgid2img = {}
        for img_datum in img_data:
            self.imgid2img[img_datum['img_id']] = img_datum
        used_data = self.raw_dataset.data

    used_data = used_data[::args.get("partial_dataset", 1)]
    self.data = used_data  # Only keep data with loaded image features.
    print("Use %d data in torch dataset" % (len(self.data)))
    print()

    if args.get("add_tags", False):
        self.tokenizer = BertTokenizer.from_pretrained(
            "bert-base-uncased", do_lower_case=True)
        from lxrt.symbolic_vocabulary import SymbolicVocab
        self.symbolic_vocab = SymbolicVocab(
            args.objects_vocab, args.attributes_vocab)
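# Standalone sketch of the screening step above: the VQA entries are simply
# intersected with the image ids that have precomputed features. Both values
# below are fabricated stand-ins for ids_to_index / raw_dataset.data.
ids_to_index = {"COCO_train2014_000000318556": 0}
raw_data = [
    {"img_id": "COCO_train2014_000000318556", "sent": "Is the sink full?"},
    {"img_id": "COCO_train2014_000000999999", "sent": "What color is it?"},
]
used_data = [d for d in raw_data if d["img_id"] in ids_to_index]
print(len(used_data))  # 1: the second entry has no features and is dropped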
def predict(self, eval_tuple: DataTuple, dump=None):
    """
    Predict the answers to questions in a data split.

    :param eval_tuple: The data tuple to be evaluated.
    :param dump: The path of the saved file to dump results.
    :return: A dict of question_id to answer.
    """
    self.model.eval()
    dset, loader, evaluator = eval_tuple
    question_id2img_id = {x["question_id"]: x["img_id"] for x in dset.data}
    tokenizer = BertTokenizer.from_pretrained(
        "bert-base-uncased", do_lower_case=True)
    plt.rcParams['figure.figsize'] = (12, 10)
    num_regions = 36
    count = 0

    image_root = "/mnt/8tera/claudio.greco/mscoco_trainval_2014"
    output_root = ("/mnt/8tera/claudio.greco/guesswhat_lxmert/guesswhat/"
                   "visualization_vqa")

    def draw_image_panels(img_id, datapoint_boxes):
        """Draw the image in three panels (regions 0-7, 8-15, 16-35)
        and overlay the unnormalized region boxes."""
        for panel, title in [(1, "Image (regions 0-7)"),
                             (2, "Image (regions 8-15)"),
                             (3, "Image (regions 16-35)")]:
            plt.subplot(2, 3, panel)
            plt.gca().set_axis_off()
            plt.title(title)
            im = cv2.imread(os.path.join(image_root, img_id) + ".jpg")
            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
            plt.imshow(im)

        # Boxes are normalized to [0, 1]; scale them back to pixels.
        img_info = loader.dataset.imgid2img[img_id]
        img_h, img_w = img_info['img_h'], img_info['img_w']
        unnormalized_boxes = datapoint_boxes.clone()
        unnormalized_boxes[:, (0, 2)] *= img_w
        unnormalized_boxes[:, (1, 3)] *= img_h

        for region_idx, bbox in enumerate(unnormalized_boxes):
            if region_idx < 8:
                plt.subplot(2, 3, 1)
            elif region_idx < 16:
                plt.subplot(2, 3, 2)
            else:
                plt.subplot(2, 3, 3)
            bbox = [bbox[0].item(), bbox[1].item(),
                    bbox[2].item(), bbox[3].item()]
            if bbox[0] == 0:
                bbox[0] = 2
            if bbox[1] == 0:
                bbox[1] = 2
            plt.gca().add_patch(plt.Rectangle(
                (bbox[0], bbox[1]),
                bbox[2] - bbox[0] - 4,
                bbox[3] - bbox[1] - 4,
                fill=False, edgecolor='red', linewidth=1))
            plt.gca().text(bbox[0], bbox[1] - 2, '%s' % region_idx,
                           bbox=dict(facecolor='blue'),
                           fontsize=9, color='white')

    def draw_attention_map(attention_map, tokenized_question, title):
        """Heatmap of token-to-region attention with per-cell values."""
        ax = plt.subplot(2, 1, 2)
        plt.title(title)
        plt.imshow(attention_map, vmin=0, vmax=1)
        for tok_idx in range(len(tokenized_question)):
            for region_idx in range(num_regions):
                att_value = round(attention_map[tok_idx, region_idx], 1)
                ax.text(region_idx, tok_idx, att_value,
                        ha="center", va="center",
                        color="w" if att_value <= 0.5 else "b", fontsize=6)
        ax.set_xticks(np.arange(num_regions))
        ax.set_xticklabels(list(range(num_regions)))
        ax.set_yticks(np.arange(len(tokenized_question)))
        ax.set_yticklabels(tokenized_question)
        plt.tight_layout()

    quesid2ans = {}
    for datum_tuple in loader:
        ques_id, feats, boxes, sent = datum_tuple[:4]  # Avoid seeing ground truth
        with torch.no_grad():
            feats, boxes = feats.cuda(), boxes.cuda()
            logit = self.model(feats, boxes, sent)

            for layer in [0, 4]:
                for head in [0, 1]:
                    for datapoint in range(len(sent)):
                        print(count, len(sent))
                        count += 1

                        x_layer = self.model.lxrt_encoder.model.bert.encoder.x_layers[layer]
                        lang2vis_attention_probs = x_layer.lang_att_map[
                            datapoint][head].detach().cpu().numpy()
                        vis2lang_attention_probs = x_layer.visn_att_map[
                            datapoint][head].detach().cpu().numpy()

                        img_id = question_id2img_id[ques_id[datapoint].item()]
                        tokenized_question = (["<CLS>"]
                                              + tokenizer.tokenize(sent[datapoint])
                                              + ["<SEP>"])

                        # Language-to-vision attention.
                        plt.clf()
                        draw_image_panels(img_id, boxes[datapoint])
                        draw_attention_map(
                            lang2vis_attention_probs[:len(tokenized_question),
                                                     :num_regions],
                            tokenized_question,
                            "Cross-modal attention lang2vis")
                        plt.savefig(
                            "{}/lang2vis_question_{}_layer_{}_head_{}.png".format(
                                output_root, ques_id[datapoint].item(), layer, head),
                            bbox_inches='tight', pad_inches=0.5)
                        plt.close()

                        # Vision-to-language attention (note the transpose).
                        plt.clf()
                        draw_image_panels(img_id, boxes[datapoint])
                        draw_attention_map(
                            vis2lang_attention_probs.transpose()[
                                :len(tokenized_question), :num_regions],
                            tokenized_question,
                            "Cross-modal attention vis2lang")
                        plt.savefig(
                            "{}/vis2lang_question_{}_layer_{}_head_{}.png".format(
                                output_root, ques_id[datapoint].item(), layer, head),
                            bbox_inches='tight', pad_inches=0.5)
                        plt.close()

            score, label = logit.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid.item()] = ans

    if dump is not None:
        evaluator.dump_result(quesid2ans, dump)
    return quesid2ans
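# Standalone sketch of the heatmap half of the figure above, with fabricated
# attention weights in place of the model's lang_att_map (rows are question
# tokens, columns the 36 visual regions). Only numpy/matplotlib are assumed.
import numpy as np
import matplotlib.pyplot as plt

tokens = ["<CLS>", "is", "the", "sink", "full", "of", "water", "?", "<SEP>"]
num_regions = 36
att = np.random.dirichlet(np.ones(num_regions), size=len(tokens))  # fake data

fig, ax = plt.subplots(figsize=(12, 4))
ax.imshow(att, vmin=0, vmax=1)
ax.set_xticks(np.arange(num_regions))
ax.set_yticks(np.arange(len(tokens)))
ax.set_yticklabels(tokens)
ax.set_title("Cross-modal attention lang2vis (fake data)")
fig.tight_layout()
fig.savefig("lang2vis_example.png", bbox_inches="tight")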
def __init__(self, dataset: LXMERTDataset, topk=-1, sgg_dataset=None,
             image_only=False, text_only=False, use_visual_tag_flag=False,
             limit_source=[], available_split_for_cc=None):
    super().__init__()
    self.raw_dataset = dataset
    self.name = '_'.join(self.raw_dataset.sources)

    if args.get('disable_mismatch_for_other_dataset', False):
        # Do not resample for datasets such as BookCorpus.
        self.task_matched = (args.task_matched
                             if "book_corpus" in self.raw_dataset.sources
                             else False)
    else:
        self.task_matched = args.task_matched
    print(self.raw_dataset.sources)
    print(self.task_matched)
    print("\n\n\n")

    self.sgg_dataset = sgg_dataset
    self.image_only = image_only
    self.text_only = text_only
    self.use_visual_tag_flag = use_visual_tag_flag
    self.tokenizer = BertTokenizer.from_pretrained(
        "bert-base-uncased", do_lower_case=True)
    self.task_nlvr2 = args.get("task_nlvr2", False)

    if args.tiny:
        topk = TINY_IMG_NUM
    elif args.fast:
        topk = FAST_IMG_NUM

    self.custom_coco_data = args.get("custom_coco_data", False)
    self.use_h5_file = args.get("use_h5_file", False)
    if self.use_h5_file:
        if "google_cc_train" in dataset.sources:
            if args.get('change_split', False):
                available_split_for_cc = [39]
            else:
                available_split_for_cc = args.get("available_split_for_cc", [0])
            sources = []
            split_map = {}
            for i in available_split_for_cc:
                sources.append("google_cc_{}".format(i))
                split_map["google_cc_{}".format(i)] = (
                    "data/google_concetual/butd_feat/"
                    "train_no_features_split_{}_of_40_splits.h5".format(i))
            self.image_feature_dataset = ImageFeatureDataset.create(
                sources, split_map, load_custom_h5_version2=True,
                text_only=self.text_only, on_memory=False)
        elif "open_images_train" in dataset.sources:
            available_split_for_open_image = args.get(
                "available_split_for_open_image", [0])
            sources = []
            split_map = {}
            for split_i, split_j, total_split in available_split_for_open_image:
                sources.append("open_image_{}_{}".format(split_i, split_j))
                split_map["open_image_{}_{}".format(split_i, split_j)] = (
                    "data/open_image/butd_feat/"
                    "train_{}_no_features_split_{}_of_{}_splits.h5".format(
                        split_i, split_j, total_split))
            self.image_feature_dataset = ImageFeatureDataset.create(
                sources, split_map, load_custom_h5_version2=True,
                on_memory=False)
        else:
            self.image_feature_dataset = ImageFeatureDataset.create(
                dataset.sources, Split2ImgFeatPath_h5,
                text_only=self.text_only,
                load_custom_h5_version2="flickr_train" in dataset.sources,
                on_memory=args.get("on_memory", False))

        self.ids_to_index = self.image_feature_dataset.ids_to_index

        # Screen data: keep only entries whose image has features.
        used_data = []
        for datum in self.raw_dataset.data:
            if datum['img_id'] in self.ids_to_index:
                used_data.append(datum)
    else:
        # Original LXMERT: load the dataset from TSV feature files.
        img_data = []
        for source in self.raw_dataset.sources:
            img_data.extend(load_obj_tsv(Split2ImgFeatPath[source], topk))
        self.imgid2img = {}
        for img_datum in img_data:
            self.imgid2img[img_datum['img_id']] = img_datum

        # Filter out entries without loaded features.
        used_data = []
        for datum in self.raw_dataset.data:
            if datum['img_id'] in self.imgid2img:
                used_data.append(datum)

    used_data = used_data[::args.get("partial_dataset", 1)]

    if sgg_dataset is not None:
        used_data = [
            datum for datum in used_data
            if str(datum["img_id"]) in self.sgg_dataset.imageids_to_index
        ]

    # Flatten the dataset into (one sentence + one image) entries.
    # Example datum:
    #   {'img_id': 'COCO_train2014_000000318556',
    #    'labelf': {'vqa': [{'no': 1}, {'yes': 1}, {'no': 1},
    #                       {'blue': 1, 'blue and white': 0.3}]},
    #    'sentf': {'mscoco': ['A very clean and well decorated empty bathroom',
    #                         'A blue and white bathroom with butterfly themed wall tiles.',
    #                         ...],
    #              'vqa': ['Is the sink full of water?', ...]}}
    self.data = []
    record_img_id = set()
    remaining_set = set()
    for datum in used_data:
        sentf = datum['sentf']
        for sents_cat, sents in sentf.items():
            if sents_cat in limit_source:
                continue
            remaining_set.add(sents_cat)
            if sents_cat in datum['labelf']:
                labels = datum['labelf'][sents_cat]
            else:
                labels = None
            for sent_idx, sent in enumerate(sents):
                new_datum = {
                    'uid': (make_uid(datum['img_id'], sents_cat, sent_idx)
                            if args.task_qa else None),
                    'img_id': datum['img_id'],
                    'sent': sent
                }
                if image_only:
                    # With image-only data, keep each image at most once.
                    if datum["img_id"] in record_img_id:
                        continue
                    record_img_id.add(datum["img_id"])
                if labels is not None and args.task_qa:
                    new_datum['label'] = labels[sent_idx]
                if self.task_nlvr2:
                    new_datum['match_label'] = datum["label"]
                    new_datum['img_id_1'] = datum["img_id_1"]
                self.data.append(new_datum)

    if image_only:
        dataset_str = "image_only"
    elif text_only:
        dataset_str = "text_only"
    else:
        dataset_str = "vision and language"

    if self.image_only and args.get("screen_image", False):
        # Drop images with zero height or width.
        from tqdm import tqdm
        counter = 0
        _data = []
        for data_item in tqdm(self.data):
            img_id = data_item["img_id"]
            image_index = self.image_feature_dataset.ids_to_index[img_id]
            img_h = self.image_feature_dataset.h5_wh[image_index][1]
            img_w = self.image_feature_dataset.h5_wh[image_index][0]
            if img_h == 0 or img_w == 0:
                counter += 1
            else:
                _data.append(data_item)
        print("Screened {} images with zero heights and widths, "
              "{} in total".format(counter, len(_data)))
        self.data = _data

    print("Use {} data in {} torch dataset, {}, limit_source {}".format(
        len(self.data), dataset_str, remaining_set, limit_source))

    if text_only:
        del self.image_feature_dataset
    if text_only or image_only:
        del self.raw_dataset.data
        del self.raw_dataset

    self.compress_memory = False
    if args.get("compress_memory", False):
        # Move some data to shared memory so memory use does not explode
        # with multi-process data loading.
        self.compress()
    print("\n\n\n")
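# Standalone sketch of the flattening loop above on a toy datum: every
# (image, sentence) pair becomes its own entry, regardless of which source
# ("mscoco" captions, "vqa" questions, ...) the sentence came from.
datum = {
    "img_id": "COCO_train2014_000000318556",
    "sentf": {
        "mscoco": ["A very clean and well decorated empty bathroom"],
        "vqa": ["Is the sink full of water?"],
    },
}
flat = []
for sents_cat, sents in datum["sentf"].items():
    for sent_idx, sent in enumerate(sents):
        flat.append({"img_id": datum["img_id"], "sent": sent})
print(len(flat))  # 2 entries: one per sentence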