예제 #1
0
def MarginAbs(em, ofp, params, args, stats):
    """Write the neighbours of each query that lie within a distance threshold.

    For every row of the search result, the ``args.kmax`` candidates returned
    by ``params.idx.search`` are examined. Each candidate whose distance is at
    most the active threshold is written to ``ofp`` as one tab-separated line:
    query number, distance, retrieved text.

    Args:
        em: query embeddings, handed unchanged to ``params.idx.search`` (and
            to ``IndexDistL2`` when ``args.embed`` is set).
        ofp: writable text stream receiving the output lines.
        params: provides ``idx`` (object with a ``search`` method), ``T`` and
            ``R`` (handed to ``IndexTextQuery``) and ``E`` (handed to
            ``IndexDistL2``).
        args: options read here: ``kmax``, ``threshold_faiss``,
            ``threshold_L2``, ``embed``, ``dedup`` and ``include_source``
            (``'matches'`` and ``'always'`` are the values acted upon).
        stats: counters updated in place; ``nbs`` grows by one per query and
            ``nbp`` by one per written neighbour.

    Note:
        The source sentences come from a ``sentences`` sequence that is not a
        parameter; it has to be resolvable from module scope and is indexed
        with the query row number (assumed to be aligned with ``em`` --
        confirm against the caller).
    """
    D, I = params.idx.search(em, args.kmax)
    thresh = args.threshold_faiss
    if args.embed:
        # Candidates are re-scored by IndexDistL2, so the L2 threshold applies.
        D, I = IndexDistL2(em, params.E, D, I, args.threshold_faiss)
        thresh = args.threshold_L2

    for n in range(D.shape[0]):
        seen = set()  # texts already written for this query
        nb_matches = 0
        for i in range(args.kmax):
            # Test the distance first so that rejected candidates never
            # trigger a text lookup.
            if not D[n, i] <= thresh:
                continue
            txt = IndexTextQuery(params.T, params.R, I[n, i])
            # Deduplication only filters repeated texts; with dedup disabled
            # every candidate within the threshold is written.
            if args.dedup and txt in seen:
                continue
            seen.add(txt)
            ofp.write('{:d}\t{:7.5f}\t{}\n'.format(stats.nbs, D[n, i], txt))
            stats.nbp += 1
            nb_matches += 1

        # display source sentence if requested
        mode = args.include_source
        if mode == 'always' or (mode == 'matches' and nb_matches > 0):
            ofp.write('{:d}\t{:6.1f}\t{}\n'.format(
                stats.nbs, 0.0, sentences[n].replace('@@ ', '')))
        stats.nbs += 1
예제 #2
0
def MarginRatio(em, ofp, params, args, stats):
    """Write the nearest neighbour of each query if it passes a ratio test.

    ``params.idx.search`` is asked for ``args.margin_k`` neighbours per query.
    A query matches when the distance of its first neighbour divided by the
    mean distance over all returned neighbours is at most ``args.threshold``.
    A match is written to ``ofp`` as one tab-separated line: query number,
    distance, retrieved text.

    Args:
        em: query embeddings, handed unchanged to ``params.idx.search``.
        ofp: writable text stream receiving the output lines.
        params: provides ``idx`` (object with a ``search`` method) and ``T``
            and ``R`` (handed to ``IndexTextQuery``).
        args: options read here: ``margin_k``, ``threshold`` and
            ``include_source`` (``'matches'`` and ``'always'`` are the values
            acted upon).
        stats: counters updated in place; ``nbs`` grows by one per query and
            ``nbp`` by one per written neighbour.

    Note:
        The source sentences come from a ``sentences`` sequence that is not a
        parameter; it has to be resolvable from module scope and is indexed
        with the query row number (assumed to be aligned with ``em`` --
        confirm against the caller).
    """
    D, I = params.idx.search(em, args.margin_k)
    Mean = D.mean(axis=1)
    mode = args.include_source
    for n in range(D.shape[0]):
        is_match = D[n, 0] / Mean[n] <= args.threshold

        # The source sentence belongs to one query, so it has to be written
        # inside the loop where the row number is known; it precedes the
        # match line.
        if mode == 'always' or (is_match and mode == 'matches'):
            ofp.write('{:d}\t{:6.1f}\t{}\n'
                      .format(stats.nbs, 0.0, sentences[n].replace('@@ ', '')))

        if is_match:
            txt = IndexTextQuery(params.T, params.R, I[n, 0])
            ofp.write('{:d}\t{:7.5f}\t{}\n'.format(stats.nbs, D[n, 0], txt))
            stats.nbp += 1

        stats.nbs += 1
예제 #3
0
def MarginAbs(em, ofp, params, args, stats):
    """Write the neighbours of each query that lie within a distance threshold.

    For every row of the search result, the ``args.kmax`` candidates returned
    by ``params.idx.search`` are examined. Each candidate whose distance is at
    most ``args.threshold`` is written to ``ofp`` as one tab-separated line:
    query number, distance, retrieved text.

    Args:
        em: query embeddings, handed unchanged to ``params.idx.search``.
        ofp: writable text stream receiving the output lines.
        params: provides ``idx`` (object with a ``search`` method) and ``T``
            and ``R`` (handed to ``IndexTextQuery``).
        args: options read here: ``kmax``, ``threshold``, ``dedup`` and
            ``include_source``. With ``'always'`` the source sentence is
            written before the neighbours, with ``'matches'`` it is written
            after them and only if at least one neighbour was written.
        stats: counters updated in place; ``nbs`` grows by one per query and
            ``nbp`` by one per written neighbour.

    Note:
        The source sentences come from a ``sentences`` sequence that is not a
        parameter; it has to be resolvable from module scope and is indexed
        with the query row number (assumed to be aligned with ``em`` --
        confirm against the caller).
    """
    D, I = params.idx.search(em, args.kmax)
    mode = args.include_source
    for n in range(D.shape[0]):

        if mode == 'always':
            ofp.write('{:d}\t{:6.1f}\t{}\n'.format(
                stats.nbs, 0.0, sentences[n].replace('@@ ', '')))

        seen = set()  # texts already written for this query
        nb_matches = 0
        for i in range(args.kmax):
            # Test the distance first so that rejected candidates never
            # trigger a text lookup.
            if not D[n, i] <= args.threshold:
                continue
            txt = IndexTextQuery(params.T, params.R, I[n, i])
            # Deduplication only filters repeated texts; with dedup disabled
            # every candidate within the threshold is written.
            if args.dedup and txt in seen:
                continue
            seen.add(txt)
            ofp.write('{:d}\t{:7.5f}\t{}\n'.format(stats.nbs, D[n, i], txt))
            stats.nbp += 1
            nb_matches += 1

        # A single written neighbour already counts as a match.
        if mode == 'matches' and nb_matches > 0:
            ofp.write('{:d}\t{:6.1f}\t{}\n'.format(
                stats.nbs, 0.0, sentences[n].replace('@@ ', '')))
        stats.nbs += 1
예제 #4
0
                    required=True,
                    help="compressed text file")
parser.add_argument("--emb",
                    type=str,
                    required=True,
                    help="pytorch embeddings of text bank")
parser.add_argument("--K",
                    type=int,
                    default=100,
                    help="number of nearest neighbors per sentence")

args = parser.parse_args()

# load query embedding and bank embedding
# NOTE(review): torch.load unpickles the file content; only point --input and
# --emb at files from a trusted source.
query_emb = torch.load(args.input)
bank_emb = torch.load(args.emb)

# normalize every row to unit L2 norm (in place), so that the dot products
# computed below are cosine similarities
query_emb.div_(query_emb.norm(2, 1, keepdim=True).expand_as(query_emb))
bank_emb.div_(bank_emb.norm(2, 1, keepdim=True).expand_as(bank_emb))

# score and rank
scores = bank_emb.mm(query_emb.t())  # B x Q
# torch.topk raises when k exceeds the size of the ranked dimension, so never
# ask for more neighbours than the bank holds.
num_neighbors = min(args.K, scores.size(0))
_, indices = torch.topk(scores, num_neighbors, dim=0)  # K x Q

# fetch and print retrieved text, query by query, best neighbour first
txt_mmap, ref_mmap = IndexTextOpen(args.bank)
for query_idx in range(indices.size(1)):
    for rank in range(num_neighbors):
        print(IndexTextQuery(txt_mmap, ref_mmap, indices[rank][query_idx]))