def validate_on_data(model: Model, data: Dataset,
                     batch_size: int,
                     use_cuda: bool, max_output_length: int,
                     level: str, eval_metric: Optional[str],
                     n_gpu: int,
                     compute_loss: bool = False,
                     beam_size: int = 1, beam_alpha: int = -1,
                     batch_type: str = "sentence",
                     postprocess: bool = True,
                     bpe_type: str = "subword-nmt",
                     sacrebleu: Optional[dict] = None) \
        -> (float, float, float, List[str], List[List[str]], List[str],
            List[str], List[List[str]], List[np.array]):
    """
    Generate translations for the given data.
    If `compute_loss` is True and references are given,
    also compute the loss.

    :param model: model module
    :param data: dataset for validation
    :param batch_size: validation batch size
    :param use_cuda: if True, use CUDA
    :param max_output_length: maximum length for generated hypotheses
    :param level: segmentation level, one of "char", "bpe", "word"
    :param eval_metric: evaluation metric, e.g. "bleu". Matching is
        case-insensitive. `None` is treated like an unrecognised metric
        name (the score stays 0 when references are available).
    :param n_gpu: number of GPUs
    :param compute_loss: whether to compute a scalar loss
        for given inputs and targets
    :param beam_size: beam size for validation.
        If <2 then greedy decoding (default).
    :param beam_alpha: beam search alpha for length penalty,
        disabled if set to -1 (default).
    :param batch_type: validation batch type (sentence or token)
    :param postprocess: if True, remove BPE segmentation from translations
    :param bpe_type: bpe type, one of {"subword-nmt", "sentencepiece"}
    :param sacrebleu: sacrebleu options ("remove_whitespace", "tokenize").
        Options that are not supplied fall back to the defaults
        `{"remove_whitespace": True, "tokenize": "13a"}`.

    :return:
        - current_valid_score: current validation score [eval_metric],
            0 if the metric is not one of the handled names,
            -1 if there are no references,
        - valid_loss: validation loss (-1 if not computed),
        - valid_ppl: validation perplexity (-1 if not computed),
        - valid_sources: validation sources,
        - valid_sources_raw: raw validation sources (before post-processing),
        - valid_references: validation references,
        - valid_hypotheses: validation_hypotheses,
        - decoded_valid: raw validation hypotheses (before post-processing),
        - valid_attention_scores: attention scores for validation hypotheses
    """
    assert batch_size >= n_gpu, "batch_size must be bigger than n_gpu."

    # Merge user options over the defaults so that a partial (or empty)
    # options dict cannot trigger a KeyError further down. A new dict is
    # built, so the caller's dict is never mutated.
    sacrebleu_defaults = {"remove_whitespace": True, "tokenize": "13a"}
    sacrebleu = {**sacrebleu_defaults, **(sacrebleu or {})}

    if batch_size > 1000 and batch_type == "sentence":
        logger.warning(
            "WARNING: Are you sure you meant to work on huge batches like "
            "this? 'batch_size' is > 1000 for sentence-batching. "
            "Consider decreasing it or switching to"
            " 'eval_batch_type: token'.")

    valid_iter = make_data_iter(
        dataset=data, batch_size=batch_size, batch_type=batch_type,
        shuffle=False, train=False)
    valid_sources_raw = data.src
    pad_index = model.src_vocab.stoi[PAD_TOKEN]

    # disable dropout
    model.eval()

    # don't track gradients during validation
    with torch.no_grad():
        all_outputs = []
        valid_attention_scores = []
        total_loss = 0
        total_ntokens = 0
        total_nseqs = 0
        for valid_batch in iter(valid_iter):
            # run as during training to get validation loss (e.g. xent)
            batch = Batch(valid_batch, pad_index, use_cuda=use_cuda)

            # sort batch now by src length and keep track of order
            sort_reverse_index = batch.sort_by_src_length()

            # run as during training with teacher forcing
            if compute_loss and batch.trg is not None:
                batch_loss, _, _, _ = model(
                    return_type="loss", src=batch.src, trg=batch.trg,
                    trg_input=batch.trg_input, trg_mask=batch.trg_mask,
                    src_mask=batch.src_mask, src_length=batch.src_length)
                if n_gpu > 1:
                    batch_loss = batch_loss.mean()  # average on multi-gpu
                total_loss += batch_loss
                total_ntokens += batch.ntokens
                total_nseqs += batch.nseqs

            # run as during inference to produce translations
            output, attention_scores = run_batch(
                model=model, batch=batch, beam_size=beam_size,
                beam_alpha=beam_alpha, max_output_length=max_output_length)

            # sort outputs back to original order
            all_outputs.extend(output[sort_reverse_index])
            valid_attention_scores.extend(
                attention_scores[sort_reverse_index]
                if attention_scores is not None else [])

        assert len(all_outputs) == len(data)

        if compute_loss and total_ntokens > 0:
            # total validation loss
            valid_loss = total_loss
            # exponent of token-level negative log prob
            valid_ppl = torch.exp(total_loss / total_ntokens)
        else:
            valid_loss = -1
            valid_ppl = -1

        # decode back to symbols
        decoded_valid = model.trg_vocab.arrays_to_sentences(
            arrays=all_outputs, cut_at_eos=True)

        # evaluate with metric on full dataset
        join_char = " " if level in ["word", "bpe"] else ""
        valid_sources = [join_char.join(s) for s in data.src]
        valid_references = [join_char.join(t) for t in data.trg]
        valid_hypotheses = [join_char.join(t) for t in decoded_valid]

        # post-process
        if level == "bpe" and postprocess:
            valid_sources = [bpe_postprocess(s, bpe_type=bpe_type)
                             for s in valid_sources]
            valid_references = [bpe_postprocess(v, bpe_type=bpe_type)
                                for v in valid_references]
            valid_hypotheses = [bpe_postprocess(v, bpe_type=bpe_type)
                                for v in valid_hypotheses]

        # if references are given, evaluate against them
        if valid_references:
            assert len(valid_hypotheses) == len(valid_references)

            # eval_metric is Optional: normalise once and guard against None
            # instead of calling .lower() on it in every branch.
            metric = eval_metric.lower() if eval_metric else ""

            current_valid_score = 0
            if metric == 'bleu':
                # tokenization is selected through the sacrebleu options
                current_valid_score = bleu(
                    valid_hypotheses, valid_references,
                    tokenize=sacrebleu["tokenize"])
            elif metric == 'chrf':
                current_valid_score = chrf(
                    valid_hypotheses, valid_references,
                    remove_whitespace=sacrebleu["remove_whitespace"])
            elif metric == 'token_accuracy':
                current_valid_score = token_accuracy(
                    # supply List[List[str]]
                    list(decoded_valid), list(data.trg))
            elif metric == 'sequence_accuracy':
                current_valid_score = sequence_accuracy(
                    valid_hypotheses, valid_references)
        else:
            current_valid_score = -1

    return current_valid_score, valid_loss, valid_ppl, valid_sources, \
        valid_sources_raw, valid_references, valid_hypotheses, \
        decoded_valid, valid_attention_scores
def testBatchDevIterator(self):
    """Check the dev iterator's settings, its first batch and batch sizes.

    Builds a non-shuffled, non-training iterator over `self.dev_data`,
    verifies its initial state, compares the first (length-sorted) batch
    against hard-coded expectations, and confirms that every batch holds at
    most `batch_size` sequences and that all samples are visited.
    """
    batch_size = 3
    self.assertEqual(len(self.dev_data), 20)

    # build the iterator under test
    dev_iter = make_data_iter(self.dev_data, train=False, shuffle=False,
                              batch_size=batch_size)

    # freshly created iterator: settings as requested, nothing consumed yet
    self.assertEqual(dev_iter.batch_size, batch_size)
    self.assertFalse(dev_iter.shuffle)
    self.assertFalse(dev_iter.train)
    self.assertEqual(dev_iter.epoch, 0)
    self.assertEqual(dev_iter.iterations, 0)

    # expected contents of the first batch after sorting by source length
    src_rows = [
        [29, 8, 5, 22, 5, 8, 16, 7, 19, 5, 22, 5, 24, 8, 7, 5, 7, 19, 16, 16,
         5, 31, 10, 19, 11, 8, 17, 15, 10, 6, 18, 5, 7, 4, 10, 6, 5, 25, 3],
        [10, 17, 11, 5, 28, 12, 4, 23, 4, 5, 0, 10, 17, 11, 5, 22, 5, 14, 8,
         7, 7, 5, 10, 17, 11, 5, 14, 8, 5, 31, 10, 6, 5, 9, 3, 1, 1, 1, 1],
        [29, 8, 5, 22, 5, 18, 23, 13, 4, 6, 5, 13, 8, 18, 5, 9, 3, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    ]
    trg_rows = [
        [13, 11, 12, 4, 22, 4, 12, 5, 4, 22, 4, 25, 7, 6, 8, 4, 14, 12, 4,
         24, 14, 5, 7, 6, 26, 17, 14, 10, 20, 4, 23, 3],
        [14, 0, 28, 4, 7, 6, 18, 18, 13, 4, 8, 5, 4, 24, 11, 4, 7, 11, 16,
         11, 4, 9, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [13, 11, 12, 4, 22, 4, 7, 11, 27, 27, 5, 4, 9, 3, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    ]
    expected_src0 = torch.tensor(src_rows, dtype=torch.long)
    expected_src0_len = torch.tensor([39, 35, 17], dtype=torch.long)
    expected_trg0 = torch.tensor(trg_rows, dtype=torch.long)
    expected_trg0_len = torch.tensor([33, 24, 15], dtype=torch.long)

    seen = 0
    for raw_batch in dev_iter:
        self.assertEqual(type(raw_batch), TorchTBatch)

        wrapped = Batch(raw_batch, pad_index=self.pad_index)
        self.assertEqual(type(wrapped), Batch)

        # sorting must leave the lengths in descending order
        lengths_before = wrapped.src_length
        wrapped.sort_by_src_length()
        lengths_after = wrapped.src_length
        sorted_lengths, _ = torch.sort(lengths_before, descending=True)
        self.assertTensorEqual(sorted_lengths, lengths_after)
        self.assertEqual(type(wrapped), Batch)

        # only the very first batch is compared against the fixtures
        if seen == 0:
            self.assertTensorEqual(wrapped.src, expected_src0)
            self.assertTensorEqual(wrapped.src_length, expected_src0_len)
            self.assertTensorEqual(wrapped.trg, expected_trg0)
            self.assertTensorEqual(wrapped.trg_length, expected_trg0_len)

        seen += wrapped.nseqs
        self.assertLessEqual(wrapped.nseqs, batch_size)

    # every dev example was yielded exactly once
    self.assertEqual(seen, len(self.dev_data))