import torch
import torch.nn.functional as F
import transformers
from transformers import BertTokenizer, get_linear_schedule_with_warmup
from tqdm import tqdm
import pandas as pd
import streamlit as st

# Repo-local helpers (codedbert_embedder, construct_bert_input, save_json,
# adaptive_loss, adaptive_loss_4losses, get_pos_tensor, mask_input_ids, Pos,
# set_model, get_adjs, get_all_words, create_barchart_h, show_patches,
# highlight_and_bold, device) are assumed importable from this project.


def get_embeddings(dataset, save_file, pretrained_model=None, random_patches=False):
    torch.cuda.empty_cache()
    torch.manual_seed(0)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
    )
    print('dataloader len: ', len(dataloader))

    if pretrained_model is not None:
        embedder = codedbert_embedder.from_pretrained(
            pretrained_model, output_hidden_states=True, return_dict=True)
    else:
        embedder = codedbert_embedder.from_pretrained(
            'bert-base-uncased', output_hidden_states=True, return_dict=True)
    embedder.to(device)
    embedder.eval()

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    codedbert_embeds = dict()
    with torch.no_grad():
        for i, (patches, _, input_ids, attention_mask, _, _, img_name) in enumerate(tqdm(dataloader)):
            input_ids = input_ids.to(device)
            patches = patches.to(device)

            # Recover the raw text so it can be stored alongside the embeddings
            inputs = input_ids.squeeze(0).detach().tolist()
            seq = tokenizer.convert_ids_to_tokens(inputs)
            seq = tokenizer.convert_tokens_to_string(seq)

            embeds = construct_bert_input(patches, input_ids, embedder, device, random_patches)
            # Pad the attention mask with 1s so the model attends to the image patches
            attention_mask = F.pad(
                attention_mask, (0, embeds.shape[1] - input_ids.shape[1]), value=1)

            text_emb, img_emb = embedder.embed(embeds, attention_mask)
            codedbert_embeds[img_name[0]] = {
                'text': seq,
                'text_emb': text_emb.tolist(),
                'img_emb': img_emb.tolist()
            }

    save_json(save_file, codedbert_embeds)
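
# Minimal usage sketch for get_embeddings; `build_dataset` is a hypothetical
# stand-in for however this project constructs its dataset (any dataset whose
# items unpack to the 7-tuple consumed in the loop above works).
def _example_extract_embeddings():
    dataset = build_dataset()  # hypothetical helper, not part of this repo
    get_embeddings(dataset,
                   save_file='codedbert_embeds.json',
                   pretrained_model=None,  # None falls back to bert-base-uncased
                   random_patches=False)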
def train(coded_bert, dataset, params, device, random_patches=False):
    print('Random patches? ', random_patches)
    torch.manual_seed(0)

    train_size = int(len(dataset) * .8)
    test_size = len(dataset) - train_size
    train_set, test_set = torch.utils.data.random_split(
        dataset, [train_size, test_size])

    dataloader = torch.utils.data.DataLoader(
        train_set,
        batch_size=params.batch_size,
        shuffle=True,
    )

    coded_bert.to(device)
    coded_bert.train()

    opt = transformers.Adafactor(coded_bert.parameters(),
                                 lr=params.lr,
                                 beta1=params.beta1,
                                 weight_decay=params.weight_decay,
                                 clip_threshold=params.clip,
                                 relative_step=False,
                                 scale_parameter=True,
                                 warmup_init=False)
    scheduler = get_linear_schedule_with_warmup(
        opt, params.num_warmup_steps, params.num_epochs * len(dataloader))

    for ep in range(params.num_epochs):
        avg_losses = {
            "masked_lm_loss": [],
            "masked_patch_loss": [],
            "alignment_loss_text": [],
            "alignment_loss_sent": [],
            "total": []
        }
        for patches, input_ids, is_paired, attention_mask, sents_embeds in dataloader:
            # sents_embeds = [batch, 16, 768]
            opt.zero_grad()

            # mask image patches with prob 10%
            im_seq_len = patches.shape[1]
            masked_patches = patches.detach().clone()
            masked_patches = masked_patches.view(-1, patches.shape[2])
            im_mask = torch.rand((masked_patches.shape[0], 1)) >= 0.1
            masked_patches *= im_mask
            try:
                masked_patches = masked_patches.view(params.batch_size, im_seq_len, patches.shape[2])
            except Exception as e:
                # The final partial batch cannot be reshaped to params.batch_size; skip it
                print(e)
                print(f"masked_patches: {masked_patches.shape}")
                print(f"im_mask: {im_mask.shape}")
                print(f"patches: {patches.shape}")
                continue

            # mask tokens with prob 15%, note id 103 is the [MASK] token;
            # input_ids doubles as the label tensor, with -100 (ignored by the
            # LM loss) on unmasked and padded positions
            token_mask = torch.rand(input_ids.shape)
            masked_input_ids = input_ids.detach().clone()
            masked_input_ids[token_mask < 0.15] = 103
            input_ids[token_mask >= 0.15] = -100
            input_ids[attention_mask == 0] = -100

            masked_patches = masked_patches.to(device)
            embeds = construct_bert_input(masked_patches,
                                          masked_input_ids,
                                          coded_bert,
                                          sentences=sents_embeds,
                                          device=device,
                                          random_patches=random_patches)

            # pad attention mask with 1s so model pays attention to the image parts;
            # still works with sentences since we are using 1s for sent and imgs
            attention_mask = F.pad(
                attention_mask,
                (0, embeds.shape[1] - input_ids.shape[1]),
                value=1
            )

            outputs = coded_bert(embeds=embeds.to(device),
                                 attention_mask=attention_mask.to(device),
                                 labels=input_ids.to(device),
                                 unmasked_patch_features=patches.to(device),
                                 is_paired=is_paired.to(device))

            loss = adaptive_loss_4losses(outputs)
            loss.backward()
            opt.step()
            scheduler.step()

            for k, v in outputs.items():
                if k in avg_losses:
                    avg_losses[k].append(v.cpu().item())
            avg_losses["total"].append(loss.cpu().item())

        print("***************************")
        print(f"At epoch {ep+1}, losses: ")
        for k, v in avg_losses.items():
            print(f"{k}: {sum(v) / len(v)}")
        print("***************************")
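
# Self-contained illustration of the MLM masking scheme used in train() above,
# on a toy batch (a sketch only; 103 is BERT's [MASK] id and -100 is ignored by
# the loss, exactly as in the training loop).
def _demo_mlm_masking():
    torch.manual_seed(0)
    input_ids = torch.randint(1000, 2000, (2, 8))  # toy token ids
    attention_mask = torch.ones_like(input_ids)
    attention_mask[:, -2:] = 0                     # pretend the last 2 positions are padding

    token_mask = torch.rand(input_ids.shape)
    masked_input_ids = input_ids.detach().clone()
    masked_input_ids[token_mask < 0.15] = 103      # replace ~15% of tokens with [MASK]

    labels = input_ids.detach().clone()
    labels[token_mask >= 0.15] = -100              # compute loss only on masked positions
    labels[attention_mask == 0] = -100             # never predict padding
    return masked_input_ids, labels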
def train(fashion_bert, dataset, params, device):
    torch.manual_seed(0)

    train_size = int(len(dataset) * .8)
    test_size = len(dataset) - train_size
    train_set, test_set = torch.utils.data.random_split(
        dataset, [train_size, test_size])

    dataloader = torch.utils.data.DataLoader(
        train_set,
        batch_size=params.batch_size,
        shuffle=True,
    )

    fashion_bert.to(device)
    fashion_bert.train()

    opt = transformers.Adafactor(fashion_bert.parameters(),
                                 lr=params.lr,
                                 beta1=params.beta1,
                                 weight_decay=params.weight_decay,
                                 clip_threshold=params.clip,
                                 relative_step=False,
                                 scale_parameter=True,
                                 warmup_init=False)

    # Index <-> part-of-speech lookup tables; ADJ_IDX marks adjectives
    POS_MAP = Pos.POS_MAP
    ADJ_IDX = 1
    IDX2POS = dict()
    POS2IDX = dict()
    for i, pos in enumerate(POS_MAP):
        IDX2POS[i] = pos
        POS2IDX[pos] = i

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    scheduler = get_linear_schedule_with_warmup(
        opt, params.num_warmup_steps, params.num_epochs * len(dataloader))

    for ep in range(params.num_epochs):
        avg_losses = {
            "masked_lm_loss": [],
            "masked_patch_loss": [],
            "alignment_loss": [],
            "total": []
        }
        for b, (patches, input_ids, is_paired, attention_mask) in enumerate(dataloader):
            opt.zero_grad()

            # mask image patches with prob 10%
            im_seq_len = patches.shape[1]
            masked_patches = patches.detach().clone()
            masked_patches = masked_patches.view(-1, patches.shape[2])
            im_mask = torch.rand((masked_patches.shape[0], 1)) >= 0.1
            masked_patches *= im_mask
            try:
                masked_patches = masked_patches.view(params.batch_size, im_seq_len, patches.shape[2])
            except Exception as e:
                # The final partial batch cannot be reshaped to params.batch_size; skip it
                print(e)
                print(f"masked_patches: {masked_patches.shape}")
                print(f"im_mask: {im_mask.shape}")
                print(f"patches: {patches.shape}")
                continue

            # Previous scheme: mask tokens uniformly with prob 15% (id 103 is [MASK])
            # token_mask = torch.rand(input_ids.shape)
            # masked_input_ids = input_ids.detach().clone()
            # masked_input_ids[token_mask < 0.15] = 103
            # input_ids[token_mask >= 0.15] = -100
            # input_ids[attention_mask == 0] = -100

            # POS-aware masking: adjectives are masked with prob 80%, all other
            # tokens with prob 12% (see the sketch of mask_input_ids below)
            pos_tensor = get_pos_tensor(input_ids, tokenizer, POS2IDX)
            masked_input_ids, labels = mask_input_ids(
                input_ids, pos_tensor,
                adj_mask_prob=.80,
                other_mask_prob=.12,
                adj_value=ADJ_IDX,
                attention_mask=attention_mask,
                mask_value=103)

            embeds = construct_bert_input(masked_patches, masked_input_ids,
                                          fashion_bert, device=device)

            # pad attention mask with 1s so model pays attention to the image parts
            attention_mask = F.pad(attention_mask,
                                   (0, embeds.shape[1] - input_ids.shape[1]),
                                   value=1)

            # Every 100th batch, zero out attention over the 448 text positions
            # so the model must rely on the image patches alone
            if b % 100 == 0:
                attention_mask[:, :448] = 0

            outputs = fashion_bert(embeds=embeds.to(device),
                                   attention_mask=attention_mask.to(device),
                                   labels=input_ids.to(device),
                                   unmasked_patch_features=patches.to(device),
                                   is_paired=is_paired.to(device))

            loss = adaptive_loss(outputs)
            loss.backward()
            opt.step()
            scheduler.step()

            for k, v in outputs.items():
                if k in avg_losses:
                    avg_losses[k].append(v.cpu().item())
            avg_losses["total"].append(loss.cpu().item())

        print("************TRAINING*************")
        print(f"At epoch {ep+1}, losses: ")
        for k, v in avg_losses.items():
            print(f"{k}: {sum(v) / len(v)}")
        print("***************************")
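
# mask_input_ids is defined elsewhere in this repo. A minimal sketch of the
# behavior the call site above assumes (adjectives masked with adj_mask_prob,
# everything else with other_mask_prob); this is an illustration, not the
# project's actual implementation.
def _sketch_mask_input_ids(input_ids, pos_tensor, adj_mask_prob, other_mask_prob,
                           adj_value, attention_mask, mask_value):
    # Per-token masking probability, chosen by part of speech
    probs = torch.where(pos_tensor == adj_value,
                        torch.full(input_ids.shape, adj_mask_prob),
                        torch.full(input_ids.shape, other_mask_prob))
    mask = torch.rand(input_ids.shape) < probs

    masked_input_ids = input_ids.detach().clone()
    masked_input_ids[mask] = mask_value        # 103 == [MASK]

    # The training loop passes input_ids as the label tensor, so unmasked and
    # padded positions are presumably set to -100 in place (ignored by the loss)
    input_ids[~mask] = -100
    input_ids[attention_mask == 0] = -100
    return masked_input_ids, input_ids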
def text2image(patches, neg_patches, input_ids, is_paired, attention_mask,
               neg_input_ids, neg_attention_mask, sents_embeds, evaluator,
               random_patches):
    """
    text2image retrieval:
        Query = Text
        Paired with: 1 positive image, 100 negative images
    """
    im_seq_len = patches.shape[1]
    bs = input_ids.shape[0]
    len_neg_inputs = neg_input_ids.shape[1]

    # Before construct_bert_input, the attention mask is 448 long (text only)

    # POSITIVE IMAGE
    embeds = construct_bert_input(patches, input_ids, evaluator,
                                  sentences=sents_embeds, device=device,
                                  random_patches=random_patches)
    attention_mask_mm = F.pad(attention_mask,
                              (0, embeds.shape[1] - input_ids.shape[1]),
                              value=1)  # [1, 512]

    # NEGATIVE SAMPLES: pair the query text with each negative image
    all_embeds_neg = []
    all_att_mask = []
    for p in range(len_neg_inputs):
        neg_patches_sample = neg_patches[:, p, :, :]
        embeds_neg = construct_bert_input(neg_patches_sample, input_ids,
                                          evaluator, sentences=sents_embeds,
                                          device=device,
                                          random_patches=random_patches)
        attention_mask_neg = F.pad(
            attention_mask,
            (0, embeds_neg.shape[1] - input_ids.shape[1]),
            value=1)
        all_embeds_neg.append(embeds_neg)
        all_att_mask.append(attention_mask_neg)

    # Now I have all joint embeddings for 1 positive sample and 100 neg samples
    all_scores_query = evaluator.text2img_scores(
        input_ids=input_ids,
        embeds=embeds,
        att_mask=attention_mask_mm,
        embeds_n=all_embeds_neg,  # list
        att_mask_n=all_att_mask)  # list

    # Accuracy: only on the positive example
    txt_acc, alig_acc, sent_alig_acc = evaluator.get_scores_and_metrics(
        embeds,               # text + image embedded
        attention_mask_mm,    # [batch, 512]
        labels=input_ids,     # [batch, 448]
        is_paired=is_paired,  # [batch]
        only_alignment=False,
    )
    return all_scores_query, txt_acc, alig_acc, sent_alig_acc
def image2text(patches, neg_patches, input_ids, is_paired, attention_mask,
               neg_input_ids, neg_attention_mask, sents_embeds, evaluator,
               random_patches):
    """
    image2text retrieval:
        Query = Image
        Paired with: 1 positive text, 100 negative texts
    """
    im_seq_len = patches.shape[1]
    bs = input_ids.shape[0]
    len_neg_inputs = neg_input_ids.shape[1]

    # POSITIVE TEXT
    embeds = construct_bert_input(patches, input_ids, evaluator,
                                  sentences=sents_embeds, device=device,
                                  random_patches=random_patches)
    attention_mask_mm = F.pad(attention_mask,
                              (0, embeds.shape[1] - input_ids.shape[1]),
                              value=1)

    # NEGATIVE SAMPLES: pair the query image with each negative text
    # neg_input_ids: [batch, 100, 448]
    all_embeds_neg = []
    all_att_mask = []
    all_neg_inputs = []
    for j in range(len_neg_inputs):
        neg_input_id_sample = neg_input_ids[:, j, :]  # [1, 448]
        neg_attention_mask_sample = neg_attention_mask[:, j, :]
        embeds_neg = construct_bert_input(patches, neg_input_id_sample,
                                          evaluator, sentences=sents_embeds,
                                          device=device,
                                          random_patches=random_patches)
        attention_mask_neg = F.pad(
            neg_attention_mask_sample,
            (0, embeds_neg.shape[1] - neg_input_id_sample.shape[1]),
            value=1)
        all_embeds_neg.append(embeds_neg)
        all_att_mask.append(attention_mask_neg)
        all_neg_inputs.append(neg_input_id_sample.detach())

    # Now I have all joint embeddings for 1 positive sample and 100 neg samples
    all_scores_query = evaluator.img2text_scores(input_ids_p=input_ids,
                                                 embeds_p=embeds,
                                                 att_mask_p=attention_mask_mm,
                                                 input_ids_n=all_neg_inputs,
                                                 embeds_n=all_embeds_neg,
                                                 att_mask_n=all_att_mask)

    # Accuracy: only on the positive example
    txt_acc, alig_acc, sent_alig_acc = evaluator.get_scores_and_metrics(
        embeds,               # text + image embedded
        attention_mask_mm,
        labels=input_ids,     # [batch, 448]
        is_paired=is_paired,  # [batch]
        only_alignment=False,
    )
    return all_scores_query, txt_acc, alig_acc, sent_alig_acc
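
# Sketch of how text2image / image2text might be driven from an evaluation
# loop. Assumes a dataloader yielding batches in the argument order used above;
# rank-based metrics (e.g. Rank@K over the 1 positive + 100 negatives) would be
# computed from the returned score lists.
def _example_retrieval_eval(dataloader, evaluator, random_patches=False):
    with torch.no_grad():
        for (patches, neg_patches, input_ids, is_paired, attention_mask,
             neg_input_ids, neg_attention_mask, sents_embeds) in dataloader:
            scores, txt_acc, alig_acc, sent_alig_acc = text2image(
                patches, neg_patches, input_ids, is_paired, attention_mask,
                neg_input_ids, neg_attention_mask, sents_embeds,
                evaluator, random_patches)
            # ...aggregate scores and accuracies here; image2text takes the
            # same arguments for the reverse direction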
def test(path_to_dataset, sample, device, slider_att, slider_head,
         pretrained_model=None):
    im_path_fur = '/Users/manuelladron/iCloud_archive/Documents/_CMU/PHD-CD/PHD-CD_Research/ADARI/images/ADARI_v2/furniture/full'

    viz, tokenizer, patches, input_ids, attention_mask, img_name, patch_positions = set_model(
        pretrained_model, sample, device, path_to_dataset)

    # Text
    inputs = input_ids.squeeze(0).detach().tolist()
    seq = tokenizer.convert_ids_to_tokens(inputs)
    seq_st = tokenizer.convert_tokens_to_string(seq)

    # Image
    image_name = im_path_fur + '/' + img_name

    embeds = construct_bert_input(patches, input_ids, viz, device=device)
    attention_mask = F.pad(attention_mask,
                           (0, embeds.shape[1] - input_ids.shape[1]),
                           value=1)  # [1, 512]

    # print('SEQUENCE')
    # print(seq_st)
    # print('QUERY WORD')

    button = st.sidebar.radio('Limit words to adjectives?', ('Yes', 'No'))
    if button == 'Yes':
        adjs = get_adjs(seq)  # list of [ADJ, INDEX IN STRING, OCCURRENCE]
    else:
        adjs = get_all_words(seq)

    adj_idx_pair = st.sidebar.selectbox('Select adjective to see attention', adjs)
    adj = adj_idx_pair[0]
    idx_adj = adj_idx_pair[1]
    adj_occ = adj_idx_pair[2]

    # Text without padding
    new_seq = []
    for token in seq:
        if token != '[PAD]':
            new_seq.append(token)
    nopad_string = tokenizer.convert_tokens_to_string(new_seq)

    # Gets the attention visualization
    att_max_t, att_min_t, att_max_i, att_min_i, all_att_text = viz.get_att(
        embeds,          # text + image embedded
        attention_mask,
        seq_number=idx_adj,
        attention_layer=slider_att,
        attention_head=slider_head)

    att_max_t_id = att_max_t[1].tolist()
    att_max_t_v = att_max_t[0].tolist()
    att_min_t = att_min_t[1].tolist()
    att_max_i_id = att_max_i[1].tolist()
    att_min_i = att_min_i[1].tolist()

    # Get the attentions of the N words around the query word
    N = 20
    if idx_adj >= N:
        neighbors = all_att_text[idx_adj - N:idx_adj + N]
        if idx_adj + N >= 448:
            neighbors_ids = range(idx_adj - N, 448)
        else:
            neighbors_ids = range(idx_adj - N, idx_adj + N)
    else:
        neighbors = all_att_text[:idx_adj + N]
        neighbors_ids = range(0, idx_adj + N)
    neighbors = neighbors.tolist()

    labels = [seq[id] for id in neighbors_ids]
    # This line prevents going over the limit of the text sequence
    labels = labels[:len(neighbors)]

    col1, col2, col3 = st.beta_columns(3)

    # Visualize the patches that receive the most attention
    with col1:
        viz.highlight_random_patches(image_name, att_max_i_id, patch_positions)
    with col2:
        create_barchart_h(
            neighbors, labels, adj,
            'Attention weights in the 20 words around the query')
    with col3:
        labels2 = [seq[id] for id in att_max_t_id]
        create_barchart_h(att_max_t_v, labels2, adj,
                          'Attention weights of the top 10 words', topk=True)

    st.subheader('Top-10 Attention Patches')
    show_patches(image_name, att_max_i, patch_positions, None)

    t = highlight_and_bold(nopad_string, adj, adj_occ)
    st.subheader('Design Description')
    st.markdown(t, unsafe_allow_html=True)

    chart_data = pd.DataFrame(all_att_text.detach().numpy())
    st.subheader('Attention Weights of the entire sequence')
    st.area_chart(chart_data, use_container_width=True)
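
# test() above is a Streamlit app; launch it with `streamlit run <this_file>.py`.
# A minimal entry-point sketch: the slider labels, ranges, and the dataset path
# are assumptions (BERT-base has 12 layers and 12 heads, hence 0..11), not
# values taken from this repo.
if __name__ == '__main__':
    slider_att = st.sidebar.slider('Attention layer', 0, 11, 11)
    slider_head = st.sidebar.slider('Attention head', 0, 11, 0)
    test(path_to_dataset='data/ADARI',  # hypothetical path
         sample=0,
         device=device,
         slider_att=slider_att,
         slider_head=slider_head,
         pretrained_model=None)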