Exemplo n.º 1
0
def compare_visualize(split='test',
                      html_path='visualizations/caption.html',
                      visualize_count=100):
    """Write an HTML table comparing three captioning models on one split.

    For each image of ``dataset.img_splits[split]`` one table row is
    emitted containing the image thumbnail, the first prediction entry of
    each of the three models (classifier, triplet, show-attend-tell) and the
    numbered ground-truth descriptions. A predicted token is wrapped in
    ``<span class="correct">`` when it is longer than one character and
    occurs as a substring of the ``'|'``-joined ground-truth descriptions.

    The predictions are read from three pre-computed JSON files whose paths
    are hard-coded below and parameterized only by ``split``. The generated
    page links to ``caption_style.css`` for its styling.

    Args:
        split: Key into ``dataset.img_splits``; also substituted into the
            prediction file names.
        html_path: Destination path of the generated HTML file.
        visualize_count: Maximum number of images (table rows) to render.
    """
    dataset = TextureDescriptionData()
    word_encoder = WordEncoder()

    # NOTE: the classifier and triplet files were previously regenerated in
    # place with top_k_caption(top_k=5, model_type='cls' / 'tri',
    # dataset=dataset, split=split) followed by json.dump to the same paths.
    pred_path_templates = [
        'output/naive_classify/v1_35_ft2,4_fc512_tuneTrue/caption_top5_%s.json',
        'output/triplet_match/c34_bert_l2_s_lr0.00001/caption_top5_%s.json',
        'output/show_attend_tell/results/pred_v2_last_beam5_%s.json',
    ]
    # Order matters: it must match {pred0}, {pred1}, {pred2} in the row
    # template (classifier, triplet, show-attend-tell).
    pred_dicts = []
    for path_template in pred_path_templates:
        # Context manager so the handle is closed as soon as it is parsed.
        with open(path_template % split) as f:
            pred_dicts.append(json.load(f))
    img_pref = 'https://www.robots.ox.ac.uk/~vgg/data/dtd/thumbs/'

    row_template = '''
<tr>
    <td>
        <img src={img_pref}{img_name} alt="{img_name}">
    </td>
    <td>
        <span class="pred_name">Classifier top 5:</span><br>
        {pred0}<br>
        <span class="pred_name">Triplet top 5:</span><br>
        {pred1}<br>
        <span class="pred_name">Show-attend-tell:</span><br>
        {pred2}<br>
    </td>
    <td>
        {gt}
    </td>
</tr>
'''

    # Collect the page in pieces and join once at the end.
    html_parts = ['''<!DOCTYPE html>
<html lang="en">
<head>
    <title>Caption visualize</title>
    <link rel="stylesheet" href="caption_style.css">
</head>
<body>
<table>
    <col class="column-one">
    <col class="column-two">
    <col class="column-three">
    <tr>
        <th style="text-align: center">Image</th>
        <th>Predicted captions</th>
        <th>Ground-truth descriptions</th>
    </tr>
''']

    for img_i, img_name in enumerate(dataset.img_splits[split]):
        # Check the limit *before* rendering so that exactly
        # ``visualize_count`` rows are produced (checking afterwards would
        # render one image too many).
        if img_i >= visualize_count:
            break

        gt_descs = dataset.img_data_dict[img_name]['descriptions']
        gt_desc_str = '|'.join(gt_descs)
        gt_html_str = ''.join('[%d] %s<br>\n' % (num, desc)
                              for num, desc in enumerate(gt_descs, start=1))

        pred_caps = []
        for pred_dict in pred_dicts:
            # Substring match against all ground-truth text; single-character
            # tokens are skipped so punctuation and the like are not marked.
            tokens = [
                '<span class="correct">%s</span>' % t
                if len(t) > 1 and t in gt_desc_str else t
                for t in WordEncoder.tokenize(pred_dict[img_name][0])
            ]
            pred_caps.append(word_encoder.detokenize(tokens))

        html_parts.append(row_template.format(img_pref=img_pref,
                                              img_name=img_name,
                                              pred0=pred_caps[0],
                                              pred1=pred_caps[1],
                                              pred2=pred_caps[2],
                                              gt=gt_html_str))

    html_parts.append('</table>\n</body\n></html>')
    with open(html_path, 'w') as f:
        f.write(''.join(html_parts))
Exemplo n.º 2
0
def caption_visualize(pred_dicts=None, filter=False):
    """Render predicted captions for the synthetic images as an HTML table.

    Every index of ``SyntheticData`` yields one table row holding the
    modified image, its synthetic ground-truth description and the first
    prediction entry of each of three models (in the order classifier,
    triplet, show-attend-tell). Predicted tokens longer than one character
    that occur as a substring of the ground-truth description are wrapped in
    ``<span class="correct">``.

    Args:
        pred_dicts: Sequence of per-model predictions, indexed as
            ``pred_dicts[model][idx][0]``. When ``None`` it is loaded from
            ``applications/synthetic_imgs/visualizations/results/caption.npy``.
        filter: When true, keep only the images for which at least one of
            the shown captions contains both color names, and write to
            ``caption_filtered.html`` instead of ``caption.html``.
    """
    if pred_dicts is None:
        pred_dicts = np.load('applications/synthetic_imgs/visualizations/results/caption.npy')
    syn_dataset = SyntheticData()
    word_encoder = WordEncoder()
    img_pref = '../modified_imgs/'

    page_header = '''<!DOCTYPE html>
    <html lang="en">
    <head>
        <title>Caption visualize</title>
        <style>
        .correct {
            font-weight: bold;
        }
        .pred_name {
            color: ROYALBLUE;
            font-weight: bold;
        }
        img {
           width: 3cm
        }
        table {
            border-collapse: collapse;
        }
        tr {
            border-bottom: 1px solid lightgray;
        }

        </style>
    </head>
    <body>
    <table>
        <col class="column-one">
        <col class="column-two">
        <tr>
            <th style="text-align: center">Image</th>
            <th>Predicted captions</th>
        </tr>
    '''
    row_template = '''
    <tr>
        <td>
            <img src={img_pref}{img_name} alt="{img_name}">
        </td>
        <td>
            <span class="pred_name">Synthetic Ground-truth Description:</span><br>
            {gt}<br>
            <span class="pred_name">Classifier top 5:</span><br>
            {pred0}<br>
            <span class="pred_name">Triplet top 5:</span><br>
            {pred1}<br>
            <span class="pred_name">Show-attend-tell:</span><br>
            {pred2}<br>
        </td>
    </tr>
    '''
    page_footer = '</table>\n</body\n></html>'

    pieces = [page_header]
    for sample_idx in range(len(syn_dataset)):
        raw_caps = [model_preds[sample_idx][0] for model_preds in pred_dicts]
        img_i, c1_i, c2_i = syn_dataset.unravel_index(sample_idx)
        first_color = syn_dataset.color_names[c1_i]
        second_color = syn_dataset.color_names[c2_i]

        # In filter mode, drop samples where no caption names both colors.
        if filter and not any(first_color in caption and second_color in caption
                              for caption in raw_caps):
            continue

        base_name = syn_dataset.img_names[img_i].split('.')[0]
        img_name = '%s_%s_%s.jpg' % (base_name, first_color, second_color)
        gt_desc = syn_dataset.get_desc(img_i, c1_i, c2_i)

        # Mark every multi-character token that appears in the ground truth.
        marked_caps = []
        for caption in raw_caps:
            marked_tokens = [
                '<span class="correct">%s</span>' % tok
                if tok in gt_desc and len(tok) > 1 else tok
                for tok in WordEncoder.tokenize(caption)
            ]
            marked_caps.append(word_encoder.detokenize(marked_tokens))

        pieces.append(row_template.format(img_pref=img_pref,
                                          img_name=img_name,
                                          pred0=marked_caps[0],
                                          pred1=marked_caps[1],
                                          pred2=marked_caps[2],
                                          gt=gt_desc))
    pieces.append(page_footer)

    html_name = 'caption_filtered.html' if filter else 'caption.html'
    out_path = 'applications/synthetic_imgs/visualizations/results/' + html_name
    with open(out_path, 'w') as out_file:
        out_file.write(''.join(pieces))