def shard_xattn_i2t(images, captions, caplens, opt, tag_masks, shard_size=128):
    """
    Compute the pairwise i2t image-caption similarity matrix with sharding.

    Args:
        images: numpy array of image features, indexed along axis 0.
        captions: numpy array of caption features, indexed along axis 0.
        caplens: per-caption lengths, sliced in step with ``captions``.
        opt: options namespace; ``opt.pos`` selects whether the tag masks
            are forwarded to the scorer.
        tag_masks: per-caption masks, sliced in step with ``captions``.
        shard_size: number of images/captions scored per shard, keeping
            GPU memory bounded.

    Returns:
        (n_images, n_captions) numpy array of similarity scores.
    """
    # Ceil-divide so a final partial shard is still processed.
    n_im_shard = (len(images) - 1) // shard_size + 1
    n_cap_shard = (len(captions) - 1) // shard_size + 1

    d = np.zeros((len(images), len(captions)))
    for i in range(n_im_shard):
        im_start, im_end = shard_size * i, min(shard_size * (i + 1),
                                               len(images))
        for j in range(n_cap_shard):
            # '\r' keeps the progress report on a single console line.
            sys.stdout.write('\r>> shard_xattn_i2t batch (%d,%d)' % (i, j))
            cap_start, cap_end = shard_size * j, min(shard_size * (j + 1),
                                                     len(captions))
            im = torch.from_numpy(images[im_start:im_end]).cuda()
            s = torch.from_numpy(captions[cap_start:cap_end]).cuda()
            lens = caplens[cap_start:cap_end]
            batch_tag_masks = tag_masks[cap_start:cap_end]
            # Only forward the tag masks when POS information is enabled.
            if opt.pos:
                sim = xattn_score_i2t(im, s, lens, opt, batch_tag_masks)
            else:
                sim = xattn_score_i2t(im, s, lens, opt)
            d[im_start:im_end, cap_start:cap_end] = sim.data.cpu().numpy()
    sys.stdout.write('\n')
    return d
Esempio n. 2
0
def shard_xattn_i2t(images, captions, caplens, freqs, opt, shard_size=128):
    """
    Compute the pairwise i2t image-caption similarity matrix with sharding.

    Args:
        images: numpy array of image features, indexed along axis 0.
        captions: numpy array of caption features, indexed along axis 0.
        caplens: per-caption lengths, sliced in step with ``captions``.
        freqs: frequency information forwarded to the scorer.
        opt: options namespace forwarded to the scorer.
        shard_size: number of images/captions scored per shard.

    Returns:
        Tuple of the (n_images, n_captions) similarity matrix and the
        concatenated list of per-shard attention outputs.
    """
    # Floor division keeps the shard counts integral; the original used
    # true division `/` and re-cast with int() at every range() call.
    n_im_shard = (len(images) - 1) // shard_size + 1
    n_cap_shard = (len(captions) - 1) // shard_size + 1

    attention = []
    d = np.zeros((len(images), len(captions)))
    for i in range(n_im_shard):
        im_start, im_end = shard_size * i, min(shard_size * (i + 1),
                                               len(images))
        for j in range(n_cap_shard):
            sys.stdout.write('\r>> shard_xattn_i2t batch (%d,%d)' % (i, j))
            cap_start, cap_end = shard_size * j, min(shard_size * (j + 1),
                                                     len(captions))
            # torch.no_grad() replaces the removed Variable(volatile=True)
            # API (deprecated since torch 0.4); same inference semantics.
            with torch.no_grad():
                im = torch.from_numpy(images[im_start:im_end])
                s = torch.from_numpy(captions[cap_start:cap_end])
                if torch.cuda.is_available():
                    im = im.cuda()
                    s = s.cuda()
                lens = caplens[cap_start:cap_end]
                sim, attn = xattn_score_i2t(im, s, lens, freqs, opt)
            d[im_start:im_end, cap_start:cap_end] = sim.data.cpu().numpy()
            attention = attention + attn
    sys.stdout.write('\n')
    return d, attention
Esempio n. 3
0
def create_attn(data_loader, positions, opt, model):
    """
    Collect one similarity row per batch and stack them into a tensor.

    For batch ``i``, the similarity matrix is computed with the scorer
    selected by ``opt.cross_attn`` and the row at ``positions[i]`` is kept.

    Args:
        data_loader: iterable yielding
            (images, captions, lengths, ids, freq_score, freqs) batches.
        positions: per-batch index of the row to extract.
        opt: options namespace; ``opt.cross_attn`` selects the scorer.
        model: model exposing ``forward_emb``.

    Returns:
        Tensor of the selected rows stacked along dim 0.
    """
    img_features = []

    with torch.no_grad():

        for i, (images, captions, lengths, ids, freq_score,
                freqs) in enumerate(data_loader):
            # compute the embeddings
            img_emb, cap_emb, cap_len = model.forward_emb(images,
                                                          captions,
                                                          lengths,
                                                          volatile=True)

            if opt.cross_attn == "i2t":
                # BUG FIX: the original never assigned row_sim on this
                # branch, so the append below raised NameError whenever
                # opt.cross_attn == "i2t".  Use the i2t similarity and
                # squeeze it like the t2i branch does.
                # NOTE(review): assumes xattn_score_i2t's sim has the same
                # squeezed shape as xattn_score_t2i_cosine's — confirm.
                sim, attn = xattn_score_i2t(img_emb, cap_emb, cap_len, freqs,
                                            opt)
                row_sim = sim.squeeze()
            else:
                row_sim = xattn_score_t2i_cosine(img_emb, cap_emb, cap_len,
                                                 freqs, opt).squeeze()

            img_features.append(row_sim[positions[i]])
    features = torch.stack(img_features, dim=0)

    return features
Esempio n. 4
0
def shard_xattn_i2t(images, captions, caplens, opt, shard_size=128):
    """
    Compute the pairwise i2t image-caption similarity matrix with sharding.

    Args:
        images: numpy array of image features, indexed along axis 0.
        captions: numpy array of caption features, indexed along axis 0.
        caplens: per-caption lengths, sliced in step with ``captions``.
        opt: options namespace forwarded to the scorer.
        shard_size: number of images/captions scored per shard.

    Returns:
        (n_images, n_captions) numpy array of similarity scores.
    """
    # Floor division yields ints directly; the original used `/` and had
    # to re-cast both counts with int() before range().
    n_im_shard = (len(images) - 1) // shard_size + 1
    n_cap_shard = (len(captions) - 1) // shard_size + 1

    d = np.zeros((len(images), len(captions)))
    for i in range(n_im_shard):
        im_start, im_end = shard_size * i, min(shard_size * (i + 1),
                                               len(images))
        for j in range(n_cap_shard):
            # '\r' keeps the progress report on a single console line.
            sys.stdout.write('\r>> shard_xattn_i2t batch (%d,%d)' % (i, j))
            cap_start, cap_end = shard_size * j, min(shard_size * (j + 1),
                                                     len(captions))
            im = torch.from_numpy(images[im_start:im_end]).cuda()
            s = torch.from_numpy(captions[cap_start:cap_end]).cuda()
            lens = caplens[cap_start:cap_end]
            sim = xattn_score_i2t(im, s, lens, opt)
            d[im_start:im_end, cap_start:cap_end] = sim.data.cpu().numpy()
    sys.stdout.write('\n')
    return d
Esempio n. 5
0
def create_attn(data_loader, positions, opt, model):
    """
    Gather one attention slice per batch from the cross-attention scorer.

    For every batch ``b`` the scorer selected by ``opt.cross_attn``
    ("i2t" or otherwise t2i) is applied to the batch embeddings, and the
    slice ``attn[0][0, :, positions[b]]`` of its attention output is kept.

    Args:
        data_loader: iterable yielding
            (images, captions, lengths, ids, freq_score, freqs) batches.
        positions: per-batch index into the attention's last dimension.
        opt: options namespace; ``opt.cross_attn`` selects the scorer.
        model: model exposing ``forward_emb``.

    Returns:
        List with one attention slice per batch.
    """
    total_attn = []

    with torch.no_grad():
        for batch_idx, batch in enumerate(data_loader):
            images, captions, lengths, ids, freq_score, freqs = batch

            # Embed the batch (inference only).
            img_emb, cap_emb, cap_len = model.forward_emb(images,
                                                          captions,
                                                          lengths,
                                                          volatile=True)

            # Pick the scoring direction, then score once.
            scorer = (xattn_score_i2t if opt.cross_attn == "i2t"
                      else xattn_score_t2i)
            sim, attn = scorer(img_emb, cap_emb, cap_len, freqs, opt)

            total_attn.append(attn[0][0, :, positions[batch_idx]])

    return total_attn
Esempio n. 6
0
def xattn_sim(images, captions, caplens, opt, shard_size=64):
    """
    Compute pairwise image-caption similarity with locality sharding.

    Both the t2i and i2t scorers are evaluated per shard and blended as
    ``opt.alpha * t2i + (1 - opt.alpha) * i2t``.

    Args:
        images: numpy array of image features, indexed along axis 0.
        captions: numpy array of caption features, indexed along axis 0.
        caplens: per-caption lengths, sliced in step with ``captions``.
        opt: options namespace providing lambda_softmax, norm_func,
            agg_func, lambda_lse and the blending weight alpha.
        shard_size: number of images/captions scored per shard.

    Returns:
        (n_images, n_captions) numpy array of blended similarity scores.
    """
    num_images = len(images)
    num_caps = len(captions)
    # Ceil-divide so a final partial shard is still processed.
    n_im_shard = (num_images - 1) // shard_size + 1
    n_cap_shard = (num_caps - 1) // shard_size + 1

    d = np.zeros((num_images, num_caps))
    for i in range(n_im_shard):
        im_lo = shard_size * i
        im_hi = min(im_lo + shard_size, num_images)
        for j in range(n_cap_shard):
            # '\r' keeps the progress report on a single console line.
            sys.stdout.write('\r>> xattn_sim batch (%d,%d)' % (i, j))
            cap_lo = shard_size * j
            cap_hi = min(cap_lo + shard_size, num_caps)

            im = torch.from_numpy(images[im_lo:im_hi]).cuda()
            s = torch.from_numpy(captions[cap_lo:cap_hi]).cuda()
            lens = caplens[cap_lo:cap_hi]

            t2i_scores = xattn_score_t2i(im, s, lens, opt.lambda_softmax,
                                         opt.norm_func, opt.agg_func,
                                         opt.lambda_lse)
            i2t_scores = xattn_score_i2t(im, s, lens, opt.lambda_softmax,
                                         opt.norm_func, opt.agg_func,
                                         opt.lambda_lse)
            blended = opt.alpha * t2i_scores + (1 - opt.alpha) * i2t_scores
            d[im_lo:im_hi, cap_lo:cap_hi] = blended.data.cpu().numpy()
    sys.stdout.write('\n')
    return d