def test_expand_reverse_index(self):
    """Check `expand_reverse_index` against hand-written expansions.

    The expected lists below show every entry ``i`` of the base index
    replaced by the consecutive run ``i * n_best ... i * n_best + n_best - 1``.
    """
    base_index = [1, 0, 2]

    # (n_best, expected expanded index), checked in increasing n_best order
    expectations = (
        (1, [1, 0, 2]),
        (2, [2, 3, 0, 1, 4, 5]),
        (3, [3, 4, 5, 0, 1, 2, 6, 7, 8]),
    )

    for n_best, expected in expectations:
        expanded = expand_reverse_index(base_index, n_best)
        self.assertEqual(expanded, expected)
def validate_on_data(model: Model, data: Dataset,
                     batch_size: int,
                     use_cuda: bool, max_output_length: int,
                     level: str, eval_metric: Optional[str],
                     n_gpu: int,
                     batch_class: Batch = Batch,
                     compute_loss: bool = False,
                     beam_size: int = 1, beam_alpha: int = -1,
                     batch_type: str = "sentence",
                     postprocess: bool = True,
                     bpe_type: str = "subword-nmt",
                     sacrebleu: dict = None,
                     n_best: int = 1) \
        -> (float, float, float, List[str], List[List[str]], List[str],
            List[str], List[List[str]], List[np.array]):
    """
    Generate translations for the given data.
    If `compute_loss` is True and references are given,
    also compute the loss.

    :param model: model module
    :param data: dataset for validation
    :param batch_size: validation batch size
    :param batch_class: class type of batch
    :param use_cuda: if True, use CUDA
    :param max_output_length: maximum length for generated hypotheses
    :param level: segmentation level, one of "char", "bpe", "word"
    :param eval_metric: evaluation metric, one of "bleu", "chrf",
        "token_accuracy", "sequence_accuracy" (case-insensitive).
        `None` or an empty string selects no metric.
    :param n_gpu: number of GPUs
    :param compute_loss: whether to compute a scalar loss
        for given inputs and targets
    :param beam_size: beam size for validation.
        If <2 then greedy decoding (default).
    :param beam_alpha: beam search alpha for length penalty,
        disabled if set to -1 (default).
    :param batch_type: validation batch type (sentence or token)
    :param postprocess: if True, remove BPE segmentation from translations
    :param bpe_type: bpe type, one of {"subword-nmt", "sentencepiece"}
    :param sacrebleu: sacrebleu options; the keys "tokenize" and
        "remove_whitespace" are read. Defaults to
        {"remove_whitespace": True, "tokenize": "13a"} when None.
    :param n_best: Amount of candidates to return

    :return:
        - current_valid_score: current validation score [eval_metric];
          -1 if there are no references, 0 if references exist but
          `eval_metric` matches none of the known metrics,
        - valid_loss: validation loss, -1 if it was not computed,
        - valid_ppl: validation perplexity, -1 if it was not computed,
        - valid_sources: validation sources,
        - valid_sources_raw: raw validation sources
            (before post-processing),
        - valid_references: validation references,
        - valid_hypotheses: validation_hypotheses,
        - decoded_valid: raw validation hypotheses
            (before post-processing),
        - valid_attention_scores: attention scores for validation hypotheses
    """
    assert batch_size >= n_gpu, "batch_size must be bigger than n_gpu."
    if sacrebleu is None:   # assign default value
        sacrebleu = {"remove_whitespace": True, "tokenize": "13a"}
    if batch_size > 1000 and batch_type == "sentence":
        logger.warning(
            "WARNING: Are you sure you meant to work on huge batches like "
            "this? 'batch_size' is > 1000 for sentence-batching. "
            "Consider decreasing it or switching to"
            " 'eval_batch_type: token'.")
    valid_iter = make_data_iter(
        dataset=data, batch_size=batch_size, batch_type=batch_type,
        shuffle=False, train=False)
    valid_sources_raw = data.src
    pad_index = model.src_vocab.stoi[PAD_TOKEN]
    # switch the model to evaluation mode (disable dropout)
    model.eval()
    # don't track gradients during validation
    with torch.no_grad():
        all_outputs = []
        valid_attention_scores = []
        total_loss = 0
        total_ntokens = 0
        for valid_batch in valid_iter:
            # run as during training to get validation loss (e.g. xent)

            batch = batch_class(valid_batch, pad_index, use_cuda=use_cuda)
            # sort batch now by src length and keep track of order
            reverse_index = batch.sort_by_src_length()
            # every source yields n_best outputs, so the index that restores
            # the original order has to be expanded accordingly
            sort_reverse_index = expand_reverse_index(reverse_index, n_best)

            # run as during training with teacher forcing
            if compute_loss and batch.trg is not None:
                batch_loss, _, _, _ = model(return_type="loss", **vars(batch))
                if n_gpu > 1:
                    batch_loss = batch_loss.mean()  # average on multi-gpu
                total_loss += batch_loss
                total_ntokens += batch.ntokens

            # run as during inference to produce translations
            output, attention_scores = run_batch(
                model=model, batch=batch, beam_size=beam_size,
                beam_alpha=beam_alpha, max_output_length=max_output_length,
                n_best=n_best)

            # sort outputs back to original order
            all_outputs.extend(output[sort_reverse_index])
            if attention_scores is not None:
                valid_attention_scores.extend(
                    attention_scores[sort_reverse_index])

        assert len(all_outputs) == len(data) * n_best

        if compute_loss and total_ntokens > 0:
            # total validation loss
            valid_loss = total_loss
            # exponent of token-level negative log prob
            valid_ppl = torch.exp(total_loss / total_ntokens)
        else:
            valid_loss = -1
            valid_ppl = -1

        # decode back to symbols
        decoded_valid = model.trg_vocab.arrays_to_sentences(
            arrays=all_outputs, cut_at_eos=True)

        # evaluate with metric on full dataset
        join_char = " " if level in ["word", "bpe"] else ""
        valid_sources = [join_char.join(s) for s in data.src]
        valid_references = [join_char.join(t) for t in data.trg]
        valid_hypotheses = [join_char.join(t) for t in decoded_valid]

        # post-process
        if level == "bpe" and postprocess:
            valid_sources = [
                bpe_postprocess(s, bpe_type=bpe_type) for s in valid_sources
            ]
            valid_references = [
                bpe_postprocess(v, bpe_type=bpe_type)
                for v in valid_references
            ]
            valid_hypotheses = [
                bpe_postprocess(v, bpe_type=bpe_type)
                for v in valid_hypotheses
            ]

        # if references are given, evaluate against them
        if valid_references:
            # NOTE(review): with n_best > 1 there appear to be n_best
            # hypotheses per source but only one reference each, so this
            # assertion would presumably fail - confirm the intended usage.
            assert len(valid_hypotheses) == len(valid_references)

            # `eval_metric` is Optional: treat None like the empty string
            # (no metric selected) instead of failing on `None.lower()`.
            metric = eval_metric.lower() if eval_metric else ""

            current_valid_score = 0
            if metric == 'bleu':
                # tokenization is selected by the sacrebleu options
                current_valid_score = bleu(
                    valid_hypotheses, valid_references,
                    tokenize=sacrebleu["tokenize"])
            elif metric == 'chrf':
                current_valid_score = chrf(
                    valid_hypotheses, valid_references,
                    remove_whitespace=sacrebleu["remove_whitespace"])
            elif metric == 'token_accuracy':
                current_valid_score = token_accuracy(
                    # supply List[List[str]]
                    list(decoded_valid), list(data.trg))
            elif metric == 'sequence_accuracy':
                current_valid_score = sequence_accuracy(
                    valid_hypotheses, valid_references)
        else:
            current_valid_score = -1

    return current_valid_score, valid_loss, valid_ppl, valid_sources, \
        valid_sources_raw, valid_references, valid_hypotheses, \
        decoded_valid, valid_attention_scores