def evaluate(loader, seq2seq, criterion, max_len):
    losses = utils.AverageMeter()
    ppls = utils.AverageMeter()
    seq2seq.eval()
    bleu = BLEU()
    tot_st = time.time()
    bleu_time = 0.

    with torch.no_grad():
        for i, example in enumerate(loader):
            src, src_lens, tgt, tgt_lens = parse(example)
            B = src.size(0)

            dec_outs, attn_ws = seq2seq(src, src_lens, tgt, tgt_lens,
                                        teacher_forcing=0.)
            loss, ppl = criterion(dec_outs, tgt[:, 1:])
            losses.update(loss, B)
            ppls.update(ppl, B)

            # BLEU
            bleu_st = time.time()
            # convert logits to preds
            preds = dec_outs.max(-1)[1]
            # get pred lens by finding the EOS token
            pred_lens = get_lens(preds, max_len)
            for pred, target, pred_len, target_len in zip(
                    preds, tgt, pred_lens, tgt_lens):
                # target_len includes the SOS & EOS tokens => 1:target_len-1.
                bleu.add_sentence(pred[:pred_len].cpu().numpy(),
                                  target[1:target_len - 1].cpu().numpy())
            bleu_time += time.time() - bleu_st

    total_time = time.time() - tot_st
    logger.debug("TIME: tot = {:.3f}\t bleu = {:.3f}".format(
        total_time, bleu_time))

    return losses.avg, ppls.avg, bleu.score()
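
# A minimal sketch of the `parse` helper assumed above; its definition is not shown
# in this section, so the body here is an assumption: it unpacks the batch tuple and
# moves the token tensors to the GPU, mirroring the inline .cuda() calls in the other
# version of evaluate below.
def parse(example):
    src, src_lens, tgt, tgt_lens = example
    return src.cuda(), src_lens, tgt.cuda(), tgt_lens
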
def evaluate(loader, seq2seq, criterion, max_len):
    import time
    losses = utils.AverageMeter()
    ppls = utils.AverageMeter()
    seq2seq.eval()
    bleu = BLEU()
    tot_st = time.time()
    bleu_time = 0.
    # BLEU time: about 4s for ~13k sentences; multi-CPU parallelization is possible.

    def get_lens(tensor, max_len=max_len):
        """ get first position (index) of EOS_idx in tensor = length of each sentence
            tensor: [B, T]
        """
        # Assume that earlier indices come earlier in nonzero():
        # since tensor is [B, T], nonzero() returns [i, j] pairs,
        # and we assume this result is sorted by i first, then j.
        # e.g.) nonzero() => [[1,1], [1,2], [2,1], [2,3], [2,5], ...]
        nz = (tensor == EOS_idx).nonzero()
        is_first = nz[:-1, 0] != nz[1:, 0]
        is_first = torch.cat([torch.cuda.ByteTensor([1]), is_first])  # first-occurrence mask
        # convert is_first from a mask to indices by nonzero()
        first_nz = nz[is_first.nonzero().flatten()]
        lens = torch.full([tensor.size(0)], max_len, dtype=torch.long).cuda()
        lens[first_nz[:, 0]] = first_nz[:, 1]
        return lens

    with torch.no_grad():
        for i, (src, src_lens, tgt, tgt_lens) in enumerate(loader):
            B = src.size(0)
            src = src.cuda()
            tgt = tgt.cuda()

            dec_outs, attn_ws = seq2seq(src, src_lens, tgt, tgt_lens,
                                        teacher_forcing=0.)
            loss, ppl = criterion(dec_outs, tgt)
            losses.update(loss, B)
            ppls.update(ppl, B)

            # BLEU
            bleu_st = time.time()
            # convert logits to preds
            preds = dec_outs.max(-1)[1]
            # get pred lens by finding the EOS token
            pred_lens = get_lens(preds)
            for pred, target, pred_len, target_len in zip(
                    preds, tgt, pred_lens, tgt_lens):
                # target_len includes the EOS token => -1.
                bleu.add_sentence(pred[:pred_len].cpu().numpy(),
                                  target[:target_len - 1].cpu().numpy())
            bleu_time += time.time() - bleu_st

    total_time = time.time() - tot_st
    logger.debug("TIME: tot = {:.3f}\t bleu = {:.3f}".format(
        total_time, bleu_time))

    return losses.avg, ppls.avg, bleu.score()
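
# A minimal usage sketch, assuming the surrounding training script provides
# val_loader, seq2seq, criterion, and max_len; the name run_validation is
# hypothetical and only illustrates how evaluate() is meant to be called.
def run_validation(val_loader, seq2seq, criterion, max_len):
    val_loss, val_ppl, val_bleu = evaluate(val_loader, seq2seq, criterion, max_len)
    logger.info("val: loss = {:.4f}\t ppl = {:.2f}\t BLEU = {:.2f}".format(
        val_loss, val_ppl, val_bleu))
    return val_loss, val_ppl, val_bleu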