Example #1
def validate_on_data(model: Model, data: Dataset,
                     batch_size: int,
                     use_cuda: bool, max_output_length: int,
                     level: str, eval_metric: Optional[str],
                     n_gpu: int,
                     compute_loss: bool = False,
                     beam_size: int = 1, beam_alpha: int = -1,
                     batch_type: str = "sentence",
                     postprocess: bool = True,
                     bpe_type: str = "subword-nmt",
                     sacrebleu: dict = None) \
        -> (float, float, float, List[str], List[List[str]], List[str],
            List[str], List[List[str]], List[np.ndarray]):
    """
    Generate translations for the given data.
    If `compute_loss` is True and references are given,
    also compute the loss.

    :param model: model module
    :param data: dataset for validation
    :param batch_size: validation batch size
    :param use_cuda: if True, use CUDA
    :param max_output_length: maximum length for generated hypotheses
    :param level: segmentation level, one of "char", "bpe", "word"
    :param eval_metric: evaluation metric, e.g. "bleu"
    :param n_gpu: number of GPUs
    :param compute_loss: whether to compute a scalar loss
        for given inputs and targets
    :param beam_size: beam size for validation.
        If <2 then greedy decoding (default).
    :param beam_alpha: beam search alpha for length penalty,
        disabled if set to -1 (default).
    :param batch_type: validation batch type (sentence or token)
    :param postprocess: if True, remove BPE segmentation from translations
    :param bpe_type: bpe type, one of {"subword-nmt", "sentencepiece"}
    :param sacrebleu: sacrebleu options

    :return:
        - current_valid_score: current validation score [eval_metric],
        - valid_loss: validation loss,
        - valid_ppl: validation perplexity,
        - valid_sources: validation sources,
        - valid_sources_raw: raw validation sources (before post-processing),
        - valid_references: validation references,
        - valid_hypotheses: validation hypotheses,
        - decoded_valid: raw validation hypotheses (before post-processing),
        - valid_attention_scores: attention scores for validation hypotheses
    """
    assert batch_size >= n_gpu, "batch_size must be at least n_gpu."
    if sacrebleu is None:   # assign default value
        sacrebleu = {"remove_whitespace": True, "tokenize": "13a"}
    if batch_size > 1000 and batch_type == "sentence":
        logger.warning(
            "WARNING: Are you sure you meant to work on huge batches like "
            "this? 'batch_size' is > 1000 for sentence-batching. "
            "Consider decreasing it or switching to"
            " 'eval_batch_type: token'.")
    valid_iter = make_data_iter(
        dataset=data, batch_size=batch_size, batch_type=batch_type,
        shuffle=False, train=False)
    valid_sources_raw = data.src
    pad_index = model.src_vocab.stoi[PAD_TOKEN]
    # disable dropout
    model.eval()
    # don't track gradients during validation
    with torch.no_grad():
        all_outputs = []
        valid_attention_scores = []
        total_loss = 0
        total_ntokens = 0
        total_nseqs = 0
        for valid_batch in iter(valid_iter):
            # run as during training to get validation loss (e.g. xent)

            batch = Batch(valid_batch, pad_index, use_cuda=use_cuda)
            # sort batch now by src length and keep track of order
            sort_reverse_index = batch.sort_by_src_length()

            # run as during training with teacher forcing
            if compute_loss and batch.trg is not None:
                batch_loss, _, _, _ = model(
                    return_type="loss", src=batch.src, trg=batch.trg,
                    trg_input=batch.trg_input, trg_mask=batch.trg_mask,
                    src_mask=batch.src_mask, src_length=batch.src_length)
                if n_gpu > 1:
                    batch_loss = batch_loss.mean() # average on multi-gpu
                total_loss += batch_loss
                total_ntokens += batch.ntokens
                total_nseqs += batch.nseqs

            # run as during inference to produce translations
            output, attention_scores = run_batch(
                model=model, batch=batch, beam_size=beam_size,
                beam_alpha=beam_alpha, max_output_length=max_output_length)

            # sort outputs back to original order
            all_outputs.extend(output[sort_reverse_index])
            valid_attention_scores.extend(
                attention_scores[sort_reverse_index]
                if attention_scores is not None else [])

        assert len(all_outputs) == len(data)

        if compute_loss and total_ntokens > 0:
            # total validation loss
            valid_loss = total_loss
            # exponent of token-level negative log prob
            valid_ppl = torch.exp(total_loss / total_ntokens)
        else:
            valid_loss = -1
            valid_ppl = -1

        # decode back to symbols
        decoded_valid = model.trg_vocab.arrays_to_sentences(arrays=all_outputs,
                                                            cut_at_eos=True)

        # evaluate with metric on full dataset
        join_char = " " if level in ["word", "bpe"] else ""
        valid_sources = [join_char.join(s) for s in data.src]
        valid_references = [join_char.join(t) for t in data.trg]
        valid_hypotheses = [join_char.join(t) for t in decoded_valid]

        # post-process
        if level == "bpe" and postprocess:
            valid_sources = [bpe_postprocess(s, bpe_type=bpe_type)
                             for s in valid_sources]
            valid_references = [bpe_postprocess(v, bpe_type=bpe_type)
                                for v in valid_references]
            valid_hypotheses = [bpe_postprocess(v, bpe_type=bpe_type)
                                for v in valid_hypotheses]

        # if references are given, evaluate against them
        if valid_references:
            assert len(valid_hypotheses) == len(valid_references)

            current_valid_score = 0
            if eval_metric.lower() == 'bleu':
                # corpus BLEU; tokenization is controlled by the 'tokenize' option
                current_valid_score = bleu(
                    valid_hypotheses, valid_references,
                    tokenize=sacrebleu["tokenize"])
            elif eval_metric.lower() == 'chrf':
                current_valid_score = chrf(valid_hypotheses, valid_references,
                    remove_whitespace=sacrebleu["remove_whitespace"])
            elif eval_metric.lower() == 'token_accuracy':
                current_valid_score = token_accuracy(   # supply List[List[str]]
                    list(decoded_valid), list(data.trg))
            elif eval_metric.lower() == 'sequence_accuracy':
                current_valid_score = sequence_accuracy(
                    valid_hypotheses, valid_references)
        else:
            current_valid_score = -1

    return current_valid_score, valid_loss, valid_ppl, valid_sources, \
        valid_sources_raw, valid_references, valid_hypotheses, \
        decoded_valid, valid_attention_scores
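For orientation, here is a minimal usage sketch of how a validation routine like this is typically called. The `model` and `dev_data` objects, the concrete argument values, and the surrounding setup are assumptions for illustration and are not part of the snippet above.

# Minimal usage sketch (assumes `model` is a trained Model and `dev_data` a
# validation Dataset compatible with make_data_iter; values are illustrative).
score, loss, ppl, sources, sources_raw, references, hypotheses, \
    hypotheses_raw, attention_scores = validate_on_data(
        model=model,
        data=dev_data,
        batch_size=80,
        use_cuda=False,
        max_output_length=100,
        level="bpe",
        eval_metric="bleu",
        n_gpu=0,
        compute_loss=True,   # also report validation loss and perplexity
        beam_size=5,         # beam search; beam_size=1 means greedy decoding
        beam_alpha=1,        # length penalty; -1 disables it
        batch_type="sentence")
print("bleu:", score, "loss:", float(loss), "ppl:", float(ppl))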
Example #2
    def testBatchDevIterator(self):

        batch_size = 3
        self.assertEqual(len(self.dev_data), 20)

        # make data iterator
        dev_iter = make_data_iter(self.dev_data,
                                  train=False,
                                  shuffle=False,
                                  batch_size=batch_size)
        self.assertEqual(dev_iter.batch_size, batch_size)
        self.assertFalse(dev_iter.shuffle)
        self.assertFalse(dev_iter.train)
        self.assertEqual(dev_iter.epoch, 0)
        self.assertEqual(dev_iter.iterations, 0)

        expected_src0 = torch.Tensor([[
            29, 8, 5, 22, 5, 8, 16, 7, 19, 5, 22, 5, 24, 8, 7, 5, 7, 19, 16,
            16, 5, 31, 10, 19, 11, 8, 17, 15, 10, 6, 18, 5, 7, 4, 10, 6, 5, 25,
            3
        ],
                                      [
                                          10, 17, 11, 5, 28, 12, 4, 23, 4, 5,
                                          0, 10, 17, 11, 5, 22, 5, 14, 8, 7, 7,
                                          5, 10, 17, 11, 5, 14, 8, 5, 31, 10,
                                          6, 5, 9, 3, 1, 1, 1, 1
                                      ],
                                      [
                                          29, 8, 5, 22, 5, 18, 23, 13, 4, 6, 5,
                                          13, 8, 18, 5, 9, 3, 1, 1, 1, 1, 1, 1,
                                          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                                          1, 1, 1, 1
                                      ]]).long()
        expected_src0_len = torch.Tensor([39, 35, 17]).long()
        expected_trg0 = torch.Tensor([[
            13, 11, 12, 4, 22, 4, 12, 5, 4, 22, 4, 25, 7, 6, 8, 4, 14, 12, 4,
            24, 14, 5, 7, 6, 26, 17, 14, 10, 20, 4, 23, 3
        ],
                                      [
                                          14, 0, 28, 4, 7, 6, 18, 18, 13, 4, 8,
                                          5, 4, 24, 11, 4, 7, 11, 16, 11, 4, 9,
                                          3, 1, 1, 1, 1, 1, 1, 1, 1, 1
                                      ],
                                      [
                                          13, 11, 12, 4, 22, 4, 7, 11, 27, 27,
                                          5, 4, 9, 3, 1, 1, 1, 1, 1, 1, 1, 1,
                                          1, 1, 1, 1, 1, 1, 1, 1, 1, 1
                                      ]]).long()
        expected_trg0_len = torch.Tensor([33, 24, 15]).long()

        total_samples = 0
        for b in iter(dev_iter):
            self.assertEqual(type(b), TorchTBatch)
            b = Batch(b, pad_index=self.pad_index)

            # test the sorting by src length
            self.assertEqual(type(b), Batch)
            before_sort = b.src_length
            b.sort_by_src_length()
            after_sort = b.src_length
            self.assertTensorEqual(
                torch.sort(before_sort, descending=True)[0], after_sort)
            self.assertEqual(type(b), Batch)

            if total_samples == 0:
                self.assertTensorEqual(b.src, expected_src0)
                self.assertTensorEqual(b.src_length, expected_src0_len)
                self.assertTensorEqual(b.trg, expected_trg0)
                self.assertTensorEqual(b.trg_length, expected_trg0_len)
            total_samples += b.nseqs
            self.assertLessEqual(b.nseqs, batch_size)
        self.assertEqual(total_samples, len(self.dev_data))
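The sorting check above exercises the same sort-and-restore pattern that Example #1 relies on: a batch is sorted by descending source length before decoding, and the index returned by sort_by_src_length() maps the outputs back to the original order. Below is a self-contained sketch of that idea in plain PyTorch; the lengths are made-up values.

import torch

# Toy source lengths for a batch of four sentences (illustrative values only).
src_length = torch.tensor([17, 39, 25, 31])

# Sort by descending length, as the batch sorting does.
sorted_len, perm_index = src_length.sort(descending=True)

# Invert the permutation: indexing sorted outputs with rev_index
# restores the original batch order.
rev_index = perm_index.argsort()

outputs = sorted_len * 10        # stand-in for per-item outputs in sorted order
restored = outputs[rev_index]    # back in the original, unsorted order
assert torch.equal(restored, src_length * 10)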