def evaluate(data_loader, context=ctx[0]):
    """Evaluate given the data loader.

    Computes the average per-token negative log-likelihood of the model on
    the data and produces translations with the beam-search translator.

    Parameters
    ----------
    data_loader : DataLoader
        Yields (src_seq, tgt_seq, src_valid_length, tgt_valid_length, inst_ids).
    context : Context, optional
        Device on which to run evaluation; defaults to the first entry of ``ctx``.

    Returns
    -------
    avg_loss : float
        Average loss per valid target token.
    real_translation_out : list of list of str
        The translation output, reordered to the original dataset order.
    """
    translation_out = []
    all_inst_ids = []
    avg_loss_denom = 0
    avg_loss = 0.0
    for src_seq, tgt_seq, src_valid_length, tgt_valid_length, inst_ids in data_loader:
        src_seq = src_seq.as_in_context(context)
        tgt_seq = tgt_seq.as_in_context(context)
        src_valid_length = src_valid_length.as_in_context(context)
        tgt_valid_length = tgt_valid_length.as_in_context(context)
        # Teacher-forced decode of tgt_seq[:-1], scored against the shifted
        # target tgt_seq[1:].
        out, _ = model(src_seq, tgt_seq[:, :-1], src_valid_length,
                       tgt_valid_length - 1)
        tgt_pred_logits = mx.nd.log_softmax(out, axis=-1)
        nll_loss = - mx.nd.pick(tgt_pred_logits, tgt_seq[:, 1:], axis=-1)
        # Zero out positions beyond each sentence's valid length so padding
        # does not contribute to the loss.
        masked_nll_loss = mx.nd.SequenceMask(nll_loss,
                                             sequence_length=tgt_valid_length - 1,
                                             use_sequence_length=True,
                                             axis=1)
        loss = masked_nll_loss.sum().asscalar()
        all_inst_ids.extend(inst_ids.asnumpy().astype(np.int32).tolist())
        # Accumulate summed token loss and the count of valid target tokens,
        # so the final average is per-token, not per-batch.
        avg_loss += loss
        avg_loss_denom += (tgt_valid_length - 1).sum().asscalar()
        # Translate; keep only the top-scored beam sample per source sentence.
        samples, _, sample_valid_length = \
            translator.translate(src_seq=src_seq, src_valid_length=src_valid_length)
        max_score_sample = samples[:, 0, :].asnumpy()
        sample_valid_length = sample_valid_length[:, 0].asnumpy()
        for i in range(max_score_sample.shape[0]):
            # Strip BOS (index 0) and EOS (last valid index) before mapping
            # token ids back to token strings.
            translation_out.append(
                [tgt_vocab.idx_to_token[ele]
                 for ele in max_score_sample[i][1:(sample_valid_length[i] - 1)]])
    avg_loss = avg_loss / avg_loss_denom
    # Reorder translations back to the original dataset order via instance ids.
    real_translation_out = [None] * len(all_inst_ids)
    for ind, sentence in zip(all_inst_ids, translation_out):
        if args.bleu == 'tweaked':
            real_translation_out[ind] = sentence
        elif args.bleu == '13a' or args.bleu == 'intl':
            real_translation_out[ind] = detokenizer(_bpe_to_words(sentence))
        else:
            raise NotImplementedError
    return avg_loss, real_translation_out
def evaluate(data_loader, context=ctx[0]):
    """Evaluate given the data loader.

    Computes an approximate average loss (mean batch loss weighted by the
    padded target length) and produces translations with the beam-search
    translator.

    Parameters
    ----------
    data_loader : DataLoader
        Yields (src_seq, tgt_seq, src_valid_length, tgt_valid_length, inst_ids).
    context : Context, optional
        Device on which to run evaluation; defaults to the first entry of ``ctx``.

    Returns
    -------
    avg_loss : float
        Average loss.
    real_translation_out : list of list of str
        The translation output, reordered to the original dataset order.
    """
    translation_out = []
    all_inst_ids = []
    avg_loss_denom = 0
    avg_loss = 0.0
    for src_seq, tgt_seq, src_valid_length, tgt_valid_length, inst_ids in data_loader:
        src_seq = src_seq.as_in_context(context)
        tgt_seq = tgt_seq.as_in_context(context)
        src_valid_length = src_valid_length.as_in_context(context)
        tgt_valid_length = tgt_valid_length.as_in_context(context)
        # Teacher-forced decode of tgt_seq[:-1], scored against the shifted
        # target tgt_seq[1:] by the shared test loss function.
        out, _ = model(src_seq, tgt_seq[:, :-1], src_valid_length,
                       tgt_valid_length - 1)
        loss = test_loss_function(out, tgt_seq[:, 1:],
                                  tgt_valid_length - 1).mean().asscalar()
        all_inst_ids.extend(inst_ids.asnumpy().astype(np.int32).tolist())
        # Weight each batch's mean loss by the padded target length so the
        # final value is a length-weighted average across batches.
        avg_loss += loss * (tgt_seq.shape[1] - 1)
        avg_loss_denom += (tgt_seq.shape[1] - 1)
        # Translate; keep only the top-scored beam sample per source sentence.
        samples, _, sample_valid_length = \
            translator.translate(src_seq=src_seq, src_valid_length=src_valid_length)
        max_score_sample = samples[:, 0, :].asnumpy()
        sample_valid_length = sample_valid_length[:, 0].asnumpy()
        for i in range(max_score_sample.shape[0]):
            # Strip BOS (index 0) and EOS (last valid index) before mapping
            # token ids back to token strings.
            translation_out.append(
                [tgt_vocab.idx_to_token[ele]
                 for ele in max_score_sample[i][1:(sample_valid_length[i] - 1)]])
    avg_loss = avg_loss / avg_loss_denom
    # Reorder translations back to the original dataset order via instance ids.
    real_translation_out = [None] * len(all_inst_ids)
    for ind, sentence in zip(all_inst_ids, translation_out):
        if args.bleu == 'tweaked':
            real_translation_out[ind] = sentence
        elif args.bleu == '13a' or args.bleu == 'intl':
            real_translation_out[ind] = detokenizer(_bpe_to_words(sentence),
                                                    return_str=True)
        else:
            raise NotImplementedError
    return avg_loss, real_translation_out
def inference():
    """Run inference on the test dataset and report BLEU and throughput.

    Translates every batch of the test set, logs a running per-10-batch BLEU,
    then reorders all translations to dataset order and logs the corpus-level
    BLEU and word throughput. Results are consumed via logging only.
    """
    logging.info('Inference on test_dataset!')
    # data prepare
    test_data_loader = dataprocessor.get_dataloader(data_test, args,
                                                    dataset_type='test',
                                                    use_average_length=True)
    # Select BLEU tokenization handling for the requested scheme.
    if args.bleu == 'tweaked':
        bpe = bool(args.dataset != 'IWSLT2015' and args.dataset != 'TOY')
        split_compound_word = bpe
        tokenized = True
    elif args.bleu == '13a' or args.bleu == 'intl':
        bpe = False
        split_compound_word = False
        tokenized = False
    else:
        raise NotImplementedError
    translation_out = []
    all_inst_ids = []
    total_wc = 0
    total_time = 0
    batch_total_bleu = 0
    for batch_id, (src_seq, tgt_seq, src_test_length, tgt_test_length, inst_ids) \
            in enumerate(test_data_loader):
        total_wc += src_test_length.sum().asscalar() + tgt_test_length.sum().asscalar()
        src_seq = src_seq.as_in_context(ctx[0])
        tgt_seq = tgt_seq.as_in_context(ctx[0])
        src_test_length = src_test_length.as_in_context(ctx[0])
        tgt_test_length = tgt_test_length.as_in_context(ctx[0])
        all_inst_ids.extend(inst_ids.asnumpy().astype(np.int32).tolist())
        # Time only the translation step for the throughput figure.
        start = time.time()
        samples, _, sample_test_length = \
            translator.translate(src_seq=src_seq, src_valid_length=src_test_length)
        total_time += (time.time() - start)
        # Decode the top-scored beam sample of each source sentence.
        max_score_sample = samples[:, 0, :].asnumpy()
        sample_test_length = sample_test_length[:, 0].asnumpy()
        translation_tmp = []
        translation_tmp_sentences = []
        for i in range(max_score_sample.shape[0]):
            # Strip BOS (index 0) and EOS (last valid index).
            translation_tmp.append(
                [tgt_vocab.idx_to_token[ele]
                 for ele in max_score_sample[i][1:(sample_test_length[i] - 1)]])
        # Detokenize each translated sentence once and reuse the result for
        # both the per-batch and the corpus-level collections.
        for sentence in translation_tmp:
            if args.bleu == 'tweaked':
                detokenized = sentence
            elif args.bleu == '13a' or args.bleu == 'intl':
                detokenized = detokenizer(_bpe_to_words(sentence))
            else:
                raise NotImplementedError
            translation_tmp_sentences.append(detokenized)
            translation_out.append(detokenized)
        # Gather the reference target sentences for this batch's instance ids.
        tgt_sen_tmp = [test_tgt_sentences[index]
                       for index in inst_ids.asnumpy().astype(np.int32).tolist()]
        batch_test_bleu_score, _, _, _, _ = compute_bleu(
            [tgt_sen_tmp], translation_tmp_sentences, tokenized=tokenized,
            tokenizer=args.bleu, split_compound_word=split_compound_word, bpe=bpe)
        batch_total_bleu += batch_test_bleu_score
        # Log the running average BLEU every ten batches.
        if batch_id % 10 == 0 and batch_id != 0:
            batch_ave_bleu = batch_total_bleu / 10
            batch_total_bleu = 0
            logging.info('batch id={:d}, batch_bleu={:.4f}'
                         .format(batch_id, batch_ave_bleu * 100))
    # Reorder translation sentences back to dataset order via inst_ids.
    real_translation_out = [None] * len(all_inst_ids)
    for ind, sentence in zip(all_inst_ids, translation_out):
        real_translation_out[ind] = sentence
    # compute_bleu returns (bleu, n-gram precisions, brevity penalty,
    # reference length, translation length); only the score is reported.
    test_bleu_score, _, _, _, _ = compute_bleu(
        [test_tgt_sentences], real_translation_out, tokenized=tokenized,
        tokenizer=args.bleu, split_compound_word=split_compound_word, bpe=bpe)
    logging.info('Inference at test dataset. '
                 'inference bleu={:.4f}, throughput={:.4f}K wps'
                 .format(test_bleu_score * 100, total_wc / total_time / 1000))