Example #1

# Shared imports assumed by the examples below; project-local helpers
# (construct_bert_input, save_json, codedbert_embedder, adaptive_loss,
# adaptive_loss_4losses, get_pos_tensor, mask_input_ids, Pos, set_model, ...)
# are assumed to be in scope, as is a module-level `device`.
import pandas as pd
import streamlit as st
import torch
import torch.nn.functional as F
import transformers
from tqdm import tqdm
from transformers import BertTokenizer, get_linear_schedule_with_warmup
def get_embeddings(dataset,
                   save_file,
                   pretrained_model=None,
                   random_patches=False):
    torch.cuda.empty_cache()
    torch.manual_seed(0)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
    )
    print('dataloader len: ', len(dataloader))

    if pretrained_model is not None:
        embedder = codedbert_embedder.from_pretrained(
            pretrained_model, output_hidden_states=True, return_dict=True)
    else:
        embedder = codedbert_embedder.from_pretrained(
            'bert-base-uncased', output_hidden_states=True, return_dict=True)

    embedder.to(device)
    embedder.eval()
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    codedbert_embeds = dict()
    with torch.no_grad():
        for i, (patches, _, input_ids, attention_mask, _, _,
                img_name) in enumerate(tqdm(dataloader)):
            input_ids = input_ids.to(device)
            patches = patches.to(device)

            inputs = input_ids.squeeze(0).detach().tolist()
            seq = tokenizer.convert_ids_to_tokens(inputs)
            seq = tokenizer.convert_tokens_to_string(seq)
            embeds = construct_bert_input(patches, input_ids, embedder, device,
                                          random_patches)
            # pad the text attention mask with 1s to cover the appended image
            # patches, and move it to the same device as the embeddings
            attention_mask = F.pad(attention_mask,
                                   (0, embeds.shape[1] - input_ids.shape[1]),
                                   value=1).to(device)
            text_emb, img_emb = embedder.embed(embeds, attention_mask)

            codedbert_embeds[img_name[0]] = {
                'text': seq,
                'text_emb': text_emb.tolist(),
                'img_emb': img_emb.tolist()
            }

    save_json(save_file, codedbert_embeds)
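
A minimal, self-contained sketch of the attention-mask padding step used above, assuming the 448-token text / 512-position multimodal layout implied by the comments in these examples:

import torch
import torch.nn.functional as F

attention_mask = torch.ones(1, 448)                # text-only mask
input_ids = torch.zeros(1, 448, dtype=torch.long)  # tokenized caption
embeds = torch.zeros(1, 512, 768)                  # text + image patch embeddings
# pad with 1s so the model also attends to the appended image positions
attention_mask = F.pad(attention_mask,
                       (0, embeds.shape[1] - input_ids.shape[1]),
                       value=1)
assert attention_mask.shape == (1, 512)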
Example #2
def train(coded_bert, dataset, params, device, random_patches=False):
    print('Random patches? ', random_patches)
    torch.manual_seed(0)
    train_size = int(len(dataset) * .8)
    test_size = len(dataset) - train_size
    train_set, test_set = torch.utils.data.random_split(
        dataset, [train_size, test_size])
    dataloader = torch.utils.data.DataLoader(
        train_set,
        batch_size=params.batch_size,
        shuffle=True,
    )

    coded_bert.to(device)
    coded_bert.train()
    opt = transformers.Adafactor(coded_bert.parameters(),
                                 lr=params.lr,
                                 beta1=params.beta1,
                                 weight_decay=params.weight_decay,
                                 clip_threshold=params.clip,
                                 relative_step=False,
                                 scale_parameter=True,
                                 warmup_init=False)

    scheduler = get_linear_schedule_with_warmup(
        opt, params.num_warmup_steps, params.num_epochs * len(dataloader))

    for ep in range(params.num_epochs):
        avg_losses = {
            "masked_lm_loss": [],
            "masked_patch_loss": [],
            "alignment_loss_text": [],
            "alignment_loss_sent": [],
            "total": []
        }
        for patches, input_ids, is_paired, attention_mask, sents_embeds in dataloader:
            # sents_embeds = [batch, 16, 768]
            opt.zero_grad()

            # mask image patches with prob 10%
            im_seq_len = patches.shape[1]
            masked_patches = patches.detach().clone()
            masked_patches = masked_patches.view(-1, patches.shape[2])
            im_mask = torch.rand((masked_patches.shape[0], 1)) >= 0.1
            masked_patches *= im_mask

            # reshape with the actual batch size so the final, possibly
            # smaller batch still works (a hardcoded params.batch_size view
            # fails on partial batches)
            masked_patches = masked_patches.view(patches.shape[0], im_seq_len,
                                                 patches.shape[2])

            # mask tokens with prob 15%; id 103 is the [MASK] token in
            # bert-base-uncased
            token_mask = torch.rand(input_ids.shape)
            masked_input_ids = input_ids.detach().clone()
            masked_input_ids[token_mask < 0.15] = 103

            # input_ids now doubles as the MLM labels: -100 marks unmasked
            # and padded positions so the LM loss ignores them
            input_ids[token_mask >= 0.15] = -100
            input_ids[attention_mask == 0] = -100

            masked_patches = masked_patches.to(device)
            embeds = construct_bert_input(masked_patches,
                                          masked_input_ids,
                                          coded_bert,
                                          sentences=sents_embeds,
                                          device=device,
                                          random_patches=random_patches)
            # pad attention mask with 1s so model pays attention to the image parts
            attention_mask = F.pad(
                attention_mask, (0, embeds.shape[1] - input_ids.shape[1]),
                value=1
            )  # still works with sentences since we are using 1s for sent and imgs

            outputs = coded_bert(embeds=embeds.to(device),
                                 attention_mask=attention_mask.to(device),
                                 labels=input_ids.to(device),
                                 unmasked_patch_features=patches.to(device),
                                 is_paired=is_paired.to(device))

            loss = adaptive_loss_4losses(outputs)

            loss.backward()
            opt.step()
            scheduler.step()

            for k, v in outputs.items():
                if k in avg_losses:
                    avg_losses[k].append(v.cpu().item())
            avg_losses["total"].append(loss.cpu().item())

        print("***************************")
        print(f"At epoch {ep+1}, losses: ")
        for k, v in avg_losses.items():
            print(f"{k}: {sum(v) / len(v)}")
        print("***************************")
Example #3
def train(fashion_bert, dataset, params, device):
    torch.manual_seed(0)
    train_size = int(len(dataset) * .8)
    test_size = len(dataset) - train_size
    train_set, test_set = torch.utils.data.random_split(
        dataset, [train_size, test_size])
    dataloader = torch.utils.data.DataLoader(
        train_set,
        batch_size=params.batch_size,
        shuffle=True,
    )

    fashion_bert.to(device)
    fashion_bert.train()
    opt = transformers.Adafactor(fashion_bert.parameters(),
                                 lr=params.lr,
                                 beta1=params.beta1,
                                 weight_decay=params.weight_decay,
                                 clip_threshold=params.clip,
                                 relative_step=False,
                                 scale_parameter=True,
                                 warmup_init=False)

    POS_MAP = Pos.POS_MAP
    ADJ_IDX = 1

    IDX2POS = dict(enumerate(POS_MAP))
    POS2IDX = {pos: i for i, pos in enumerate(POS_MAP)}

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    scheduler = get_linear_schedule_with_warmup(
        opt, params.num_warmup_steps, params.num_epochs * len(dataloader))

    for ep in range(params.num_epochs):

        avg_losses = {
            "masked_lm_loss": [],
            "masked_patch_loss": [],
            "alignment_loss": [],
            "total": []
        }
        for b, (patches, input_ids, is_paired,
                attention_mask) in enumerate(dataloader):
            opt.zero_grad()

            # mask image patches with prob 10%
            im_seq_len = patches.shape[1]
            masked_patches = patches.detach().clone()
            masked_patches = masked_patches.view(-1, patches.shape[2])
            im_mask = torch.rand((masked_patches.shape[0], 1)) >= 0.1
            masked_patches *= im_mask
            # reshape with the actual batch size so the final, possibly
            # smaller batch still works (a hardcoded params.batch_size view
            # fails on partial batches)
            masked_patches = masked_patches.view(patches.shape[0], im_seq_len,
                                                 patches.shape[2])

            # POS-aware masking replaces the uniform 15% token masking of
            # example #2: adjectives are masked aggressively, other tokens
            # lightly; id 103 is still the [MASK] token
            pos_tensor = get_pos_tensor(input_ids, tokenizer, POS2IDX)
            masked_input_ids, labels = mask_input_ids(
                input_ids,
                pos_tensor,
                adj_mask_prob=.80,
                other_mask_prob=.12,
                adj_value=ADJ_IDX,
                attention_mask=attention_mask,
                mask_value=103)

            embeds = construct_bert_input(masked_patches,
                                          masked_input_ids,
                                          fashion_bert,
                                          device=device)
            # pad attention mask with 1s so model pays attention to the image parts
            attention_mask = F.pad(attention_mask,
                                   (0, embeds.shape[1] - input_ids.shape[1]),
                                   value=1)

            # every 100th batch, zero out the text positions (first 448) so
            # the model must rely on the image patches alone
            if b % 100 == 0:
                attention_mask[:, :448] = 0

            # use the labels returned by mask_input_ids (-100 on ignored
            # positions); passing raw input_ids would score every token
            outputs = fashion_bert(embeds=embeds.to(device),
                                   attention_mask=attention_mask.to(device),
                                   labels=labels.to(device),
                                   unmasked_patch_features=patches.to(device),
                                   is_paired=is_paired.to(device))

            loss = adaptive_loss(outputs)

            loss.backward()
            opt.step()
            scheduler.step()

            for k, v in outputs.items():
                if k in avg_losses:
                    avg_losses[k].append(v.cpu().item())
            avg_losses["total"].append(loss.cpu().item())

        print("************TRAINING*************")
        print(f"At epoch {ep+1}, losses: ")
        for k, v in avg_losses.items():
            print(f"{k}: {sum(v) / len(v)}")
        print("***************************")
Example #4
def text2image(patches, neg_patches, input_ids, is_paired, attention_mask,
               neg_input_ids, neg_attention_mask, sents_embeds, evaluator,
               random_patches):
    """
    text2image retrieval:
        Query = Text
        Paired with: 1 positive image, 100 negative images
    """
    len_neg_inputs = neg_input_ids.shape[1]  # number of negative images (100)

    # before construct_bert_input, the attention mask covers only the 448 text tokens
    # POSITIVE IMAGE
    embeds = construct_bert_input(patches,
                                  input_ids,
                                  evaluator,
                                  sentences=sents_embeds,
                                  device=device,
                                  random_patches=random_patches)
    attention_mask_mm = F.pad(attention_mask,
                              (0, embeds.shape[1] - input_ids.shape[1]),
                              value=1)  # [1, 512]

    # NEGATIVE SAMPLES
    all_embeds_neg = []
    all_att_mask = []

    for p in range(len_neg_inputs):
        neg_patches_sample = neg_patches[:, p, :, :]
        embeds_neg = construct_bert_input(neg_patches_sample,
                                          input_ids,
                                          evaluator,
                                          sentences=sents_embeds,
                                          device=device,
                                          random_patches=random_patches)
        attention_mask_neg = F.pad(
            attention_mask, (0, embeds_neg.shape[1] - input_ids.shape[1]),
            value=1)

        all_embeds_neg.append(embeds_neg)
        all_att_mask.append(attention_mask_neg)

    # Now I have all joint embeddings for 1 positive sample and 100 neg samples
    all_scores_query = evaluator.text2img_scores(
        input_ids=input_ids,
        embeds=embeds,
        att_mask=attention_mask_mm,
        embeds_n=all_embeds_neg,  # list
        att_mask_n=all_att_mask)  # list

    # Accuracy: only in positive example
    txt_acc, alig_acc, sent_alig_acc = evaluator.get_scores_and_metrics(
        embeds,  # text + image embedded
        attention_mask_mm,  # [batch,
        labels=input_ids,  # [batch, 448]
        is_paired=is_paired,  # [batch]
        only_alignment=False,
    )

    return all_scores_query, txt_acc, alig_acc, sent_alig_acc
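
A hypothetical helper showing how the per-query scores gathered above could become a retrieval metric; rank_at_k and the higher-is-better score convention are assumptions, not part of the evaluator API:

def rank_at_k(pos_score: float, neg_scores: list, k: int = 10) -> bool:
    """True if the positive candidate ranks in the top k among itself and the negatives."""
    return sum(s > pos_score for s in neg_scores) < k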
Example #5
def image2text(patches, neg_patches, input_ids, is_paired, attention_mask,
               neg_input_ids, neg_attention_mask, sents_embeds, evaluator,
               random_patches):
    """
    image2text retrieval:
        Query = Image
        Paired with: 1 positive text, 100 negative texts
    """
    len_neg_inputs = neg_input_ids.shape[1]  # number of negative texts (100)

    embeds = construct_bert_input(patches,
                                  input_ids,
                                  evaluator,
                                  sentences=sents_embeds,
                                  device=device,
                                  random_patches=random_patches)
    attention_mask_mm = F.pad(attention_mask,
                              (0, embeds.shape[1] - input_ids.shape[1]),
                              value=1)

    # NEGATIVE SAMPLES: neg_input_ids is [batch, 100, 448]
    all_embeds_neg = []
    all_att_mask = []
    all_neg_inputs = []

    for j in range(len_neg_inputs):
        neg_input_id_sample = neg_input_ids[:, j, :]  # [1, 448]
        neg_attention_mask_sample = neg_attention_mask[:, j, :]

        embeds_neg = construct_bert_input(patches,
                                          neg_input_id_sample,
                                          evaluator,
                                          sentences=sents_embeds,
                                          device=device,
                                          random_patches=random_patches)
        attention_mask_neg = F.pad(
            neg_attention_mask_sample,
            (0, embeds_neg.shape[1] - neg_input_id_sample.shape[1]),
            value=1)

        all_embeds_neg.append(embeds_neg)
        all_att_mask.append(attention_mask_neg)
        all_neg_inputs.append(neg_input_id_sample.detach())

    # Now I have all joint embeddings for 1 positive sample and 100 neg samples
    all_scores_query = evaluator.img2text_scores(input_ids_p=input_ids,
                                                 embeds_p=embeds,
                                                 att_mask_p=attention_mask_mm,
                                                 input_ids_n=all_neg_inputs,
                                                 embeds_n=all_embeds_neg,
                                                 att_mask_n=all_att_mask)

    # Accuracy: only in positive example
    txt_acc, alig_acc, sent_alig_acc = evaluator.get_scores_and_metrics(
        embeds,  # text + image embedded
        attention_mask_mm,
        labels=input_ids,  # [batch, 448]
        is_paired=is_paired,  # [batch]
        only_alignment=False,
    )

    return all_scores_query, txt_acc, alig_acc, sent_alig_acc
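
Building on the hypothetical rank_at_k above, Recall@K over an evaluation run is just the mean hit rate; the (pos_score, neg_scores) layout per query is likewise an assumption:

def recall_at_k(all_scores, k: int = 10) -> float:
    """all_scores: list of (pos_score, neg_scores) pairs, one per query."""
    hits = [rank_at_k(pos, negs, k) for pos, negs in all_scores]
    return sum(hits) / len(hits)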
Example #6
def test(path_to_dataset,
         sample,
         device,
         slider_att,
         slider_head,
         pretrained_model=None):
    im_path_fur = '/Users/manuelladron/iCloud_archive/Documents/_CMU/PHD-CD/PHD-CD_Research/ADARI/images/ADARI_v2/furniture/full'
    viz, tokenizer, patches, input_ids, attention_mask, img_name, patch_positions = set_model(
        pretrained_model, sample, device, path_to_dataset)

    # Text
    inputs = input_ids.squeeze(0).detach().tolist()
    seq = tokenizer.convert_ids_to_tokens(inputs)
    seq_st = tokenizer.convert_tokens_to_string(seq)

    # Image
    image_name = im_path_fur + '/' + img_name

    embeds = construct_bert_input(patches, input_ids, viz, device=device)
    attention_mask = F.pad(attention_mask,
                           (0, embeds.shape[1] - input_ids.shape[1]),
                           value=1)  # [1, 512]

    button = st.sidebar.radio('Limit words to adjectives?', ('Yes', 'No'))
    if button == 'Yes':
        adjs = get_adjs(seq)  # list with [ADJ, INDEX IN STRING, OCCURRENCE]
    else:
        adjs = get_all_words(seq)
    adj_idx_pair = st.sidebar.selectbox('Select adjective to see attention',
                                        adjs)
    adj, idx_adj, adj_occ = adj_idx_pair

    # Text without pad
    new_seq = [token for token in seq if token != '[PAD]']
    nopad_string = tokenizer.convert_tokens_to_string(new_seq)

    # Get the attention visualization
    att_max_t, att_min_t, att_max_i, att_min_i, all_att_text = viz.get_att(
        embeds,  # text + image embedded
        attention_mask,
        seq_number=idx_adj,
        attention_layer=slider_att,
        attention_head=slider_head)

    att_max_t_id = att_max_t[1].tolist()  # indices of top-attended tokens
    att_max_t_v = att_max_t[0].tolist()   # their attention weights
    att_max_i_id = att_max_i[1].tolist()  # indices of top-attended patches

    N = 20
    # window of N tokens on each side of the query word, clamped to the
    # 448-token text sequence
    lo = max(0, idx_adj - N)
    hi = min(448, idx_adj + N)
    neighbors = all_att_text[lo:hi].tolist()
    labels = [seq[i] for i in range(lo, hi)]

    col1, col2, col3 = st.columns(3)  # st.beta_columns was renamed to st.columns
    # Visualize patches that receive most attention
    with col1:
        viz.highlight_random_patches(image_name, att_max_i_id, patch_positions)
    with col2:
        create_barchart_h(
            neighbors, labels, adj,
            'Attention weights in the 20 words around the query')
    with col3:
        labels2 = [seq[i] for i in att_max_t_id]
        create_barchart_h(att_max_t_v,
                          labels2,
                          adj,
                          'Attention weights of the top 10 words',
                          topk=True)

    st.subheader('Top-10 Attention Patches')
    show_patches(image_name, att_max_i, patch_positions, None)
    t = highlight_and_bold(nopad_string, adj, adj_occ)
    st.subheader('Design Description')
    st.markdown(t, unsafe_allow_html=True)
    chart_data = pd.DataFrame(all_att_text.detach().numpy())
    st.subheader('Attention Weights of the entire sequence')
    st.area_chart(chart_data, use_container_width=True)
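
viz.get_att() is project-local; below is a minimal sketch of the standard Hugging Face mechanism it presumably wraps, where output_attentions=True exposes per-layer attention tensors of shape [batch, heads, seq, seq]. The model choice, layer/head, and query position are illustrative:

import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
model.eval()

enc = tokenizer('a sleek wooden chair', return_tensors='pt')
with torch.no_grad():
    out = model(**enc)

layer, head, query_pos = 3, 0, 2
att_row = out.attentions[layer][0, head, query_pos]  # weights over all positions
top_vals, top_ids = att_row.topk(5)                  # top-5 attended tokens
print(tokenizer.convert_ids_to_tokens(enc['input_ids'][0][top_ids].tolist()))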