Example #1
def forward(self, sentences, device='cuda'):
    """
    sentences: list[str], len of list: B
    output sent_embs: Tensor B x OUT
    """
    sentences = [WordEncoder.tokenize(s) for s in sentences]
    # e.g. sentences = [['First', 'sentence', '.'], ['Another', '.']]
    # use batch_to_ids to convert the tokenized sentences to character ids
    character_ids = batch_to_ids(sentences).to(device)
    embeddings = self.elmo(character_ids)
    # embeddings['elmo_representations'] is a list of two tensors, one per
    # requested ELMo layer, each with shape (B, max_len, 1024):
    #   B       - the batch size
    #   max_len - the length of the longest sentence in the batch
    #   1024    - the length of each ELMo vector
    sent_embeds = embeddings['elmo_representations'][1]  # B x max_len x 1024
    sent_emb_list = list()
    for si in range(len(sentences)):
        # average the ELMo vectors over the real (non-padded) tokens only
        sent_len = len(sentences[si])
        sent_embed = torch.mean(sent_embeds[si, :sent_len, :], dim=0)  # 1024
        sent_emb_list.append(sent_embed)
    sent_embs = torch.stack(sent_emb_list, dim=0)  # B x 1024
    return sent_embs
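For context, a minimal sketch of how this forward method could sit inside a module. The class name ElmoSentenceEncoder and the option/weight file arguments are assumptions; Elmo and batch_to_ids come from allennlp.modules.elmo, and num_output_representations=2 matches the two-layer list indexed above.

import torch
import torch.nn as nn
from allennlp.modules.elmo import Elmo, batch_to_ids

class ElmoSentenceEncoder(nn.Module):  # hypothetical wrapper
    def __init__(self, options_file, weight_file):
        super().__init__()
        # request two output representations so elmo_representations is a two-layer list
        self.elmo = Elmo(options_file, weight_file, num_output_representations=2)

    # forward(self, sentences, device='cuda') as defined in Example #1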
Example #2
def add_space_to_cap_dict(cap_dict):
    """Return a copy of cap_dict where every caption is re-tokenized with
    WordEncoder.tokenize and the tokens are re-joined with single spaces;
    captions that tokenize to nothing are kept unchanged."""
    new_dict = dict()
    for img_name, caps in cap_dict.items():
        new_dict[img_name] = list()
        for cap in caps:
            tokens = WordEncoder.tokenize(cap)
            if len(tokens) > 0:
                new_cap = ' '.join(tokens)
            else:
                new_cap = cap
            new_dict[img_name].append(new_cap)
    return new_dict
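A quick usage sketch; the image names and captions below are made up, and the exact tokens depend on how WordEncoder.tokenize splits and normalizes text.

caps = {'banded_0002.jpg': ['Thin, vertical dark stripes.'],
        'dotted_0101.jpg': ['']}
spaced = add_space_to_cap_dict(caps)
# roughly: {'banded_0002.jpg': ['thin , vertical dark stripes .'],  # tokens re-joined with spaces
#           'dotted_0101.jpg': ['']}                                # empty caption kept unchanged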
Example #3
def compare_visualize(split='test',
                      html_path='visualizations/caption.html',
                      visualize_count=100):
    dataset = TextureDescriptionData()
    word_encoder = WordEncoder()
    # cls_predictions = top_k_caption(top_k=5, model_type='cls', dataset=dataset, split=split)
    # with open('output/naive_classify/v1_35_ft2,4_fc512_tuneTrue/caption_top5_%s.json' % split, 'w') as f:
    #     json.dump(cls_predictions, f)
    # tri_predictions = top_k_caption(top_k=5, model_type='tri', dataset=dataset, split=split)
    # with open('output/triplet_match/c34_bert_l2_s_lr0.00001/caption_top5_%s.json' % split, 'w') as f:
    #     json.dump(tri_predictions, f)
    cls_predictions = json.load(
        open(
            'output/naive_classify/v1_35_ft2,4_fc512_tuneTrue/caption_top5_%s.json'
            % split))
    tri_predictions = json.load(
        open(
            'output/triplet_match/c34_bert_l2_s_lr0.00001/caption_top5_%s.json'
            % split))
    sat_predictions = json.load(
        open('output/show_attend_tell/results/pred_v2_last_beam5_%s.json' %
             split))
    pred_dicts = [cls_predictions, tri_predictions, sat_predictions]
    img_pref = 'https://www.robots.ox.ac.uk/~vgg/data/dtd/thumbs/'

    html_str = '''<!DOCTYPE html>
<html lang="en">
<head>
    <title>Caption visualize</title>
    <link rel="stylesheet" href="caption_style.css">
</head>
<body>
<table>
    <col class="column-one">
    <col class="column-two">
    <col class="column-three">
    <tr>
        <th style="text-align: center">Image</th>
        <th>Predicted captions</th>
        <th>Ground-truth descriptions</th>
    </tr>
'''

    for img_i, img_name in enumerate(dataset.img_splits[split]):
        gt_descs = dataset.img_data_dict[img_name]['descriptions']
        gt_desc_str = '|'.join(gt_descs)
        gt_html_str = ''
        for ci, cap in enumerate(gt_descs):
            gt_html_str += '[%d] %s<br>\n' % (ci + 1, cap)

        pred_caps = [pred_dict[img_name][0] for pred_dict in pred_dicts]
        for ci, cap in enumerate(pred_caps):
            tokens = WordEncoder.tokenize(cap)
            for ti, t in enumerate(tokens):
                if t in gt_desc_str and len(t) > 1:
                    tokens[ti] = '<span class="correct">%s</span>' % t
            pred_caps[ci] = word_encoder.detokenize(tokens)
        html_str += '''
<tr>
    <td>
        <img src="{img_pref}{img_name}" alt="{img_name}">
    </td>
    <td>
        <span class="pred_name">Classifier top 5:</span><br>
        {pred0}<br>
        <span class="pred_name">Triplet top 5:</span><br>
        {pred1}<br>
        <span class="pred_name">Show-attend-tell:</span><br>
        {pred2}<br>
    </td>
    <td>
        {gt}
    </td>
</tr>
'''.format(img_pref=img_pref,
           img_name=img_name,
           pred0=pred_caps[0],
           pred1=pred_caps[1],
           pred2=pred_caps[2],
           gt=gt_html_str)

        if img_i + 1 >= visualize_count:
            break

    html_str += '</table>\n</body>\n</html>'
    with open(html_path, 'w') as f:
        f.write(html_str)

    return
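A minimal driver sketch, assuming the three prediction JSON files referenced above already exist on disk; the 'val' split name and the output paths are assumptions.

if __name__ == '__main__':
    for split in ['val', 'test']:
        compare_visualize(split=split,
                          html_path='visualizations/caption_%s.html' % split,
                          visualize_count=100)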
Example #4
def caption_visualize(pred_dicts=None, filter=False):
    if pred_dicts is None:
        # the .npy stores an object array of prediction dicts, so allow_pickle is needed
        pred_dicts = np.load('applications/synthetic_imgs/visualizations/results/caption.npy',
                             allow_pickle=True)
    syn_dataset = SyntheticData()
    word_encoder = WordEncoder()
    img_pref = '../modified_imgs/'
    html_str = '''<!DOCTYPE html>
    <html lang="en">
    <head>
        <title>Caption visualize</title>
        <style>
        .correct {
            font-weight: bold;
        }
        .pred_name {
            color: ROYALBLUE;
            font-weight: bold;
        }
        img {
           width: 3cm
        }
        table {
            border-collapse: collapse;
        }
        tr {
            border-bottom: 1px solid lightgray;
        }

        </style>
    </head>
    <body>
    <table>
        <col class="column-one">
        <col class="column-two">
        <tr>
            <th style="text-align: center">Image</th>
            <th>Predicted captions</th>
        </tr>
    '''
    for idx in range(len(syn_dataset)):
        pred_caps = [pred_dict[idx][0] for pred_dict in pred_dicts]
        img_i, c1_i, c2_i = syn_dataset.unravel_index(idx)

        if filter:
            good_cap = False
            c1 = syn_dataset.color_names[c1_i]
            c2 = syn_dataset.color_names[c2_i]
            for ci, cap in enumerate(pred_caps):
                if c1 in cap and c2 in cap:
                    good_cap = True
                    break
            if not good_cap:
                continue

        img_name = '%s_%s_%s.jpg' % (syn_dataset.img_names[img_i].split('.')[0],
                                     syn_dataset.color_names[c1_i],
                                     syn_dataset.color_names[c2_i])
        gt_desc = syn_dataset.get_desc(img_i, c1_i, c2_i)

        for ci, cap in enumerate(pred_caps):
            tokens = WordEncoder.tokenize(cap)
            for ti, t in enumerate(tokens):
                if t in gt_desc and len(t) > 1:
                    tokens[ti] = '<span class="correct">%s</span>' % t
            pred_caps[ci] = word_encoder.detokenize(tokens)
        html_str += '''
    <tr>
        <td>
            <img src="{img_pref}{img_name}" alt="{img_name}">
        </td>
        <td>
            <span class="pred_name">Synthetic Ground-truth Description:</span><br>
            {gt}<br>
            <span class="pred_name">Classifier top 5:</span><br>
            {pred0}<br>
            <span class="pred_name">Triplet top 5:</span><br>
            {pred1}<br>
            <span class="pred_name">Show-attend-tell:</span><br>
            {pred2}<br>
        </td>
    </tr>
    '''.format(img_pref=img_pref, img_name=img_name, pred0=pred_caps[0], pred1=pred_caps[1], pred2=pred_caps[2],
               gt=gt_desc)
    html_str += '</table>\n</body>\n</html>'

    html_name = 'caption.html'
    if filter:
        html_name = 'caption_filtered.html'
    with open('applications/synthetic_imgs/visualizations/results/' + html_name, 'w') as f:
        f.write(html_str)
    return
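A hedged driver sketch; it assumes the .npy file holds an object array of prediction dicts (hence allow_pickle=True) and that the output directory already exists.

if __name__ == '__main__':
    preds = np.load('applications/synthetic_imgs/visualizations/results/caption.npy',
                    allow_pickle=True)
    caption_visualize(pred_dicts=preds, filter=False)  # writes caption.html
    caption_visualize(pred_dicts=preds, filter=True)   # writes caption_filtered.html,
                                                       # keeping only images where a
                                                       # prediction names both colors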