示例#1
0
def evaluate(loader, seq2seq, criterion, max_len):
    """Run one full evaluation pass over `loader` with teacher forcing off.

    Returns a tuple (avg_loss, avg_ppl, bleu_score).
    """
    loss_meter = utils.AverageMeter()
    ppl_meter = utils.AverageMeter()
    seq2seq.eval()
    bleu = BLEU()

    eval_start = time.time()
    bleu_elapsed = 0.

    with torch.no_grad():
        for batch_idx, batch in enumerate(loader):
            src, src_lens, tgt, tgt_lens = parse(batch)
            batch_size = src.size(0)

            dec_outs, attn_ws = seq2seq(src,
                                        src_lens,
                                        tgt,
                                        tgt_lens,
                                        teacher_forcing=0.)
            loss, ppl = criterion(dec_outs, tgt[:, 1:])
            loss_meter.update(loss, batch_size)
            ppl_meter.update(ppl, batch_size)

            # --- BLEU accumulation (timed separately) ---
            bleu_start = time.time()
            # argmax over the vocabulary dim turns logits into token ids
            preds = dec_outs.max(-1)[1]
            # prediction lengths = position of first EOS token (or max_len)
            pred_lens = get_lens(preds, max_len)

            for pred, target, pred_len, target_len in zip(
                    preds, tgt, pred_lens, tgt_lens):
                # target_len include SOS & EOS token => 1:target_len-1.
                bleu.add_sentence(pred[:pred_len].cpu().numpy(),
                                  target[1:target_len - 1].cpu().numpy())

            bleu_elapsed += time.time() - bleu_start
    total_time = time.time() - eval_start

    logger.debug("TIME: tot = {:.3f}\t bleu = {:.3f}".format(
        total_time, bleu_elapsed))

    return loss_meter.avg, ppl_meter.avg, bleu.score()
示例#2
0
        org_stats = bleu_stats(hyp, refs)
        assert (mine_stats.flatten().astype(np.int) == org_stats).all()

        # bleu
        mine_bleu = BLEU.compute_bleu(mine_stats)
        org_bleu = bleu(org_stats)

        #print(mine_bleu, org_bleu)
        assert mine_bleu == org_bleu

    # total bleu score
    org = get_bleu(hyps, refses)
    bleu = BLEU()
    bleu.add_corpus(hyps, refses)
    print("org:", org)
    print("mine:", bleu.score())
    assert org == bleu.score()

    m_hyps = [" ".join(hyp) for hyp in hyps]
    m_refs = [" ".join(ref) for ref in refses]
    bleu_score = get_moses_multi_bleu(m_hyps, m_refs)
    print("moses:", bleu_score)
    assert round(float(bleu_score), 2) == round(bleu.score(), 2)

    print("All test passed !")

    print("Multi BLEU:")
    print("BLEU-1:", bleu.score(1))
    print("BLEU-2:", bleu.score(2))
    print("BLEU-3:", bleu.score(3))
    print("BLEU-4:", bleu.score(4))
示例#3
0
def evaluate(loader, seq2seq, criterion, max_len):
    """Evaluate `seq2seq` on `loader` without teacher forcing.

    Args:
        loader: iterable yielding (src, src_lens, tgt, tgt_lens) batches.
        seq2seq: model; called as seq2seq(src, src_lens, tgt, tgt_lens,
            teacher_forcing=0.) and returning (dec_outs, attn_ws).
        criterion: callable returning (loss, ppl) for (dec_outs, tgt).
        max_len: fallback length for predictions with no EOS token.

    Returns:
        (avg_loss, avg_ppl, bleu_score).
    """
    import time
    losses = utils.AverageMeter()
    ppls = utils.AverageMeter()
    seq2seq.eval()
    bleu = BLEU()

    tot_st = time.time()
    bleu_time = 0.

    # BLEU time: ~4s for ~13k sentences; multi-cpu parallelization is possible.

    def get_lens(tensor, max_len=max_len):
        """ get first position (index) of EOS_idx in tensor
            = length of each sentence
        tensor: [B, T]
        """
        lens = torch.full([tensor.size(0)], max_len, dtype=torch.long).cuda()
        nz = (tensor == EOS_idx).nonzero()
        if nz.numel() == 0:
            # No sequence contains EOS: every length stays max_len.
            # Without this guard the first-occurrence masking below would
            # index into an empty tensor and raise.
            return lens
        # assume that former idx coming earlier in nonzero().
        # Since `tensor` is [B, T], nonzero() yields [i, j] pairs, and we
        # assume the result is sorted by i first, then j.
        # e.g) nonzero() => [[1,1], [1,2], [2,1], [2,3], [2,5], ...]
        # A row of nz is "first in its group" where the batch index i
        # differs from the previous row's batch index.
        is_first = nz[:-1, 0] != nz[1:, 0]
        is_first = torch.cat([torch.cuda.ByteTensor([1]),
                              is_first])  # first mask
        # convert is_first from mask to indice by nonzero()
        first_nz = nz[is_first.nonzero().flatten()]
        lens[first_nz[:, 0]] = first_nz[:, 1]
        return lens

    with torch.no_grad():
        for i, (src, src_lens, tgt, tgt_lens) in enumerate(loader):
            B = src.size(0)
            src = src.cuda()
            tgt = tgt.cuda()

            dec_outs, attn_ws = seq2seq(src,
                                        src_lens,
                                        tgt,
                                        tgt_lens,
                                        teacher_forcing=0.)
            loss, ppl = criterion(dec_outs, tgt)
            losses.update(loss, B)
            ppls.update(ppl, B)

            # BLEU
            bleu_st = time.time()
            # convert logits to preds
            preds = dec_outs.max(-1)[1]
            # get pred lens by finding EOS token
            pred_lens = get_lens(preds)

            for pred, target, pred_len, target_len in zip(
                    preds, tgt, pred_lens, tgt_lens):
                # target_len include EOS token => -1.
                bleu.add_sentence(pred[:pred_len].cpu().numpy(),
                                  target[:target_len - 1].cpu().numpy())

            bleu_time += time.time() - bleu_st
    total_time = time.time() - tot_st

    logger.debug("TIME: tot = {:.3f}\t bleu = {:.3f}".format(
        total_time, bleu_time))

    return losses.avg, ppls.avg, bleu.score()