Example No. 1
    def test_token_acc_level_char(self):
        # if len(hyp) > len(ref)
        hyp = [list("tests")]
        ref = [list("tezt")]
        #level = "char"
        score = token_accuracy(hyp, ref)
        self.assertEqual(score, 60.0)

        # if len(hyp) < len(ref)
        hyp = [list("test")]
        ref = [list("tezts")]
        #level = "char"
        score = token_accuracy(hyp, ref)
        self.assertEqual(score, 75.0)
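The expected values above follow from a simple positional count. Below is a minimal sketch of a character-level token accuracy that reproduces them; it is only an illustration under that assumption, and the actual joeynmt token_accuracy may differ in details such as argument handling:

from typing import List


def token_accuracy_sketch(hypotheses: List[List[str]],
                          references: List[List[str]]) -> float:
    # percentage of hypothesis tokens that match the reference at the same position
    correct, total = 0, 0
    for hyp, ref in zip(hypotheses, references):
        total += len(hyp)
        correct += sum(h == r for h, r in zip(hyp, ref))
    return 100.0 * correct / total if total > 0 else 0.0


# "tests" vs "tezt": 3 of 5 positions match -> 60.0
print(token_accuracy_sketch([list("tests")], [list("tezt")]))
# "test" vs "tezts": 3 of 4 positions match -> 75.0
print(token_accuracy_sketch([list("test")], [list("tezts")]))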
Example No. 2
    def Reward_bleu_fin(self, trg, hyp, show=False):
        """
        To use as self.Reward funtion.
        Return an array of rewards, based on the differences
        of current Blue Score. As proposed on paper.

        :param trg: target.
        :param hyp: the predicted sequence.
        :param show: Boolean, display the computation of the rewards
        :return: current Bleu score
        """
        rew = np.zeros(len(hyp[0]))

        decoded_valid_tar = self.model.trg_vocab.arrays_to_sentences(
            arrays=trg, cut_at_eos=True)
        decoded_valid_hyp = self.model.trg_vocab.arrays_to_sentences(
            arrays=hyp, cut_at_eos=True)

        # evaluate with the metric on each target and hypothesis
        join_char = " " if self.level in ["word", "bpe"] else ""
        valid_references = [join_char.join(t) for t in decoded_valid_tar]
        valid_hypotheses = [join_char.join(t) for t in decoded_valid_hyp]

        # post-process
        if self.level == "bpe":
            valid_references = [bpe_postprocess(v) for v in valid_references]
            valid_hypotheses = [bpe_postprocess(v) for v in valid_hypotheses]
        # if references are given, evaluate against them
        if valid_references:
            assert len(valid_hypotheses) == len(valid_references)

            current_valid_score = 0
            if self.eval_metric.lower() == 'bleu':
                # this version does not use any tokenization
                current_valid_score = bleu(valid_hypotheses, valid_references)
            elif self.eval_metric.lower() == 'chrf':
                current_valid_score = chrf(valid_hypotheses, valid_references)
            elif self.eval_metric.lower() == 'token_accuracy':
                current_valid_score = token_accuracy(valid_hypotheses,
                                                     valid_references,
                                                     level=self.level)
            elif self.eval_metric.lower() == 'sequence_accuracy':
                current_valid_score = sequence_accuracy(
                    valid_hypotheses, valid_references)
        else:
            current_valid_score = -1

        rew[-1] = current_valid_score

        final_rew = rew[1:]
        if show:
            print(
                "\n Sample-------------Target vs Eval_net prediction:--Raw---and---Decoded-----"
            )
            print("Target: ", trg, decoded_valid_tar)
            print("Eval  : ", hyp, decoded_valid_hyp)
            print("Reward: ", final_rew, "\n")

        return final_rew
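The reward layout built above is sparse: every step gets zero except the last one, which carries the whole sentence-level score. A small numeric illustration with a hypothetical BLEU of 32.0:

import numpy as np

hyp_len = 5                     # hypothetical hypothesis length
rew = np.zeros(hyp_len)
rew[-1] = 32.0                  # sentence-level score placed on the final step
final_rew = rew[1:]             # drop the first position, as in Reward_bleu_fin
print(final_rew)                # [ 0.  0.  0. 32.]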
Example No. 3
    def test_token_acc_level_char(self):
        hyp = ["test"]
        ref = ["tezt"]
        level = "char"
        acc = token_accuracy(hyp, ref, level)
        self.assertEqual(acc, 75)
Example No. 4
def validate_on_data(model: Model, data: Dataset,
                     batch_size: int,
                     use_cuda: bool, max_output_length: int,
                     level: str, eval_metric: Optional[str],
                     n_gpu: int,
                     batch_class: Batch = Batch,
                     compute_loss: bool = False,
                     beam_size: int = 1, beam_alpha: int = -1,
                     batch_type: str = "sentence",
                     postprocess: bool = True,
                     bpe_type: str = "subword-nmt",
                     sacrebleu: dict = None,
                     n_best: int = 1) \
        -> (float, float, float, List[str], List[List[str]], List[str],
            List[str], List[List[str]], List[np.array]):
    """
    Generate translations for the given data.
    If `compute_loss` is True and references are given,
    also compute the loss.

    :param model: model module
    :param data: dataset for validation
    :param batch_size: validation batch size
    :param batch_class: class type of batch
    :param use_cuda: if True, use CUDA
    :param max_output_length: maximum length for generated hypotheses
    :param level: segmentation level, one of "char", "bpe", "word"
    :param eval_metric: evaluation metric, e.g. "bleu"
    :param n_gpu: number of GPUs
    :param compute_loss: whether to compute a scalar loss
        for given inputs and targets
    :param beam_size: beam size for validation.
        If <2 then greedy decoding (default).
    :param beam_alpha: beam search alpha for length penalty,
        disabled if set to -1 (default).
    :param batch_type: validation batch type (sentence or token)
    :param postprocess: if True, remove BPE segmentation from translations
    :param bpe_type: bpe type, one of {"subword-nmt", "sentencepiece"}
    :param sacrebleu: sacrebleu options
    :param n_best: number of candidates to return

    :return:
        - current_valid_score: current validation score [eval_metric],
        - valid_loss: validation loss,
        - valid_ppl: validation perplexity,
        - valid_sources: validation sources,
        - valid_sources_raw: raw validation sources (before post-processing),
        - valid_references: validation references,
        - valid_hypotheses: validation_hypotheses,
        - decoded_valid: raw validation hypotheses (before post-processing),
        - valid_attention_scores: attention scores for validation hypotheses
    """
    assert batch_size >= n_gpu, "batch_size must be at least n_gpu."
    if sacrebleu is None:  # assign default value
        sacrebleu = {"remove_whitespace": True, "tokenize": "13a"}
    if batch_size > 1000 and batch_type == "sentence":
        logger.warning(
            "WARNING: Are you sure you meant to work on huge batches like "
            "this? 'batch_size' is > 1000 for sentence-batching. "
            "Consider decreasing it or switching to"
            " 'eval_batch_type: token'.")
    valid_iter = make_data_iter(dataset=data,
                                batch_size=batch_size,
                                batch_type=batch_type,
                                shuffle=False,
                                train=False)
    valid_sources_raw = data.src
    pad_index = model.src_vocab.stoi[PAD_TOKEN]
    # disable dropout
    model.eval()
    # don't track gradients during validation
    with torch.no_grad():
        all_outputs = []
        valid_attention_scores = []
        total_loss = 0
        total_ntokens = 0
        total_nseqs = 0
        for valid_batch in iter(valid_iter):
            # run as during training to get validation loss (e.g. xent)

            batch = batch_class(valid_batch, pad_index, use_cuda=use_cuda)
            # sort batch now by src length and keep track of order
            reverse_index = batch.sort_by_src_length()
            sort_reverse_index = expand_reverse_index(reverse_index, n_best)

            # run as during training with teacher forcing
            if compute_loss and batch.trg is not None:
                batch_loss, _, _, _ = model(return_type="loss", **vars(batch))
                if n_gpu > 1:
                    batch_loss = batch_loss.mean()  # average on multi-gpu
                total_loss += batch_loss
                total_ntokens += batch.ntokens
                total_nseqs += batch.nseqs

            # run as during inference to produce translations
            output, attention_scores = run_batch(
                model=model,
                batch=batch,
                beam_size=beam_size,
                beam_alpha=beam_alpha,
                max_output_length=max_output_length,
                n_best=n_best)

            # sort outputs back to original order
            all_outputs.extend(output[sort_reverse_index])
            valid_attention_scores.extend(
                attention_scores[sort_reverse_index]
                if attention_scores is not None else [])

        assert len(all_outputs) == len(data) * n_best

        if compute_loss and total_ntokens > 0:
            # total validation loss
            valid_loss = total_loss
            # exponent of token-level negative log prob
            valid_ppl = torch.exp(total_loss / total_ntokens)
        else:
            valid_loss = -1
            valid_ppl = -1

        # decode back to symbols
        decoded_valid = model.trg_vocab.arrays_to_sentences(arrays=all_outputs,
                                                            cut_at_eos=True)

        # evaluate with metric on full dataset
        join_char = " " if level in ["word", "bpe"] else ""
        valid_sources = [join_char.join(s) for s in data.src]
        valid_references = [join_char.join(t) for t in data.trg]
        valid_hypotheses = [join_char.join(t) for t in decoded_valid]

        # post-process
        if level == "bpe" and postprocess:
            valid_sources = [
                bpe_postprocess(s, bpe_type=bpe_type) for s in valid_sources
            ]
            valid_references = [
                bpe_postprocess(v, bpe_type=bpe_type) for v in valid_references
            ]
            valid_hypotheses = [
                bpe_postprocess(v, bpe_type=bpe_type) for v in valid_hypotheses
            ]

        # if references are given, evaluate against them
        if valid_references:
            assert len(valid_hypotheses) == len(valid_references)

            current_valid_score = 0
            if eval_metric.lower() == 'bleu':
                # this version does not use any tokenization
                current_valid_score = bleu(valid_hypotheses,
                                           valid_references,
                                           tokenize=sacrebleu["tokenize"])
            elif eval_metric.lower() == 'chrf':
                current_valid_score = chrf(
                    valid_hypotheses,
                    valid_references,
                    remove_whitespace=sacrebleu["remove_whitespace"])
            elif eval_metric.lower() == 'token_accuracy':
                current_valid_score = token_accuracy(  # supply List[List[str]]
                    list(decoded_valid), list(data.trg))
            elif eval_metric.lower() == 'sequence_accuracy':
                current_valid_score = sequence_accuracy(
                    valid_hypotheses, valid_references)
        else:
            current_valid_score = -1

    return current_valid_score, valid_loss, valid_ppl, valid_sources, \
        valid_sources_raw, valid_references, valid_hypotheses, \
        decoded_valid, valid_attention_scores
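The validation perplexity computed above is just the exponentiated, token-normalized loss. A self-contained sketch with hypothetical numbers (the summed loss and token count here are made up for illustration):

import torch

total_loss = torch.tensor(230.2)   # hypothetical summed token-level NLL over the dev set
total_ntokens = 100                # hypothetical number of gold target tokens
valid_ppl = torch.exp(total_loss / total_ntokens)
print(valid_ppl.item())            # ~10.0, i.e. exp(2.302)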
Example No. 5
def validate_on_data(model: Model, data: Dataset,
                     logger: Logger,
                     batch_size: int,
                     use_cuda: bool, max_output_length: int,
                     level: str, eval_metric: Optional[str],
                     loss_function: torch.nn.Module = None,
                     beam_size: int = 1, beam_alpha: int = -1,
                     batch_type: str = "sentence",
                     postprocess: bool = True
                     ) \
        -> (float, float, float, List[str], List[List[str]], List[str],
            List[str], List[List[str]], List[np.array]):
    """
    Generate translations for the given data.
    If `loss_function` is not None and references are given,
    also compute the loss.

    :param model: model module
    :param logger: logger
    :param data: dataset for validation
    :param batch_size: validation batch size
    :param use_cuda: if True, use CUDA
    :param max_output_length: maximum length for generated hypotheses
    :param level: segmentation level, one of "char", "bpe", "word"
    :param eval_metric: evaluation metric, e.g. "bleu"
    :param loss_function: loss function that computes a scalar loss
        for given inputs and targets
    :param beam_size: beam size for validation.
        If <2 then greedy decoding (default).
    :param beam_alpha: beam search alpha for length penalty,
        disabled if set to -1 (default).
    :param batch_type: validation batch type (sentence or token)
    :param postprocess: if True, remove BPE segmentation from translations

    :return:
        - current_valid_score: current validation score [eval_metric],
        - valid_loss: validation loss,
        - valid_ppl: validation perplexity,
        - valid_sources: validation sources,
        - valid_sources_raw: raw validation sources (before post-processing),
        - valid_references: validation references,
        - valid_hypotheses: validation_hypotheses,
        - decoded_valid: raw validation hypotheses (before post-processing),
        - valid_attention_scores: attention scores for validation hypotheses
    """
    if batch_size > 1000 and batch_type == "sentence":
        logger.warning(
            "WARNING: Are you sure you meant to work on huge batches like "
            "this? 'batch_size' is > 1000 for sentence-batching. "
            "Consider decreasing it or switching to"
            " 'eval_batch_type: token'.")
    valid_iter = make_data_iter(dataset=data,
                                batch_size=batch_size,
                                batch_type=batch_type,
                                shuffle=False,
                                train=False)
    valid_sources_raw = data.src
    pad_index = model.src_vocab.stoi[PAD_TOKEN]
    # disable dropout
    model.eval()
    # don't track gradients during validation
    with torch.no_grad():
        all_outputs = []
        valid_attention_scores = []
        total_loss = 0
        total_ntokens = 0
        total_nseqs = 0
        for valid_batch in iter(valid_iter):
            # run as during training to get validation loss (e.g. xent)

            batch = Batch(valid_batch, pad_index, use_cuda=use_cuda)
            # sort batch now by src length and keep track of order
            sort_reverse_index = batch.sort_by_src_lengths()

            # run as during training with teacher forcing
            if loss_function is not None and batch.trg is not None:
                batch_loss = model.get_loss_for_batch(
                    batch, loss_function=loss_function)
                total_loss += batch_loss
                total_ntokens += batch.ntokens
                total_nseqs += batch.nseqs

            # run as during inference to produce translations
            output, attention_scores = model.run_batch(
                batch=batch,
                beam_size=beam_size,
                beam_alpha=beam_alpha,
                max_output_length=max_output_length)

            # sort outputs back to original order
            all_outputs.extend(output[sort_reverse_index])
            valid_attention_scores.extend(
                attention_scores[sort_reverse_index]
                if attention_scores is not None else [])

        assert len(all_outputs) == len(data)

        if loss_function is not None and total_ntokens > 0:
            # total validation loss
            valid_loss = total_loss
            # exponent of token-level negative log prob
            valid_ppl = torch.exp(total_loss / total_ntokens)
        else:
            valid_loss = -1
            valid_ppl = -1

        # decode back to symbols
        decoded_valid = model.trg_vocab.arrays_to_sentences(arrays=all_outputs,
                                                            cut_at_eos=True)

        # evaluate with metric on full dataset
        join_char = " " if level in ["word", "bpe"] else ""
        valid_sources = [join_char.join(s) for s in data.src]
        valid_references = [join_char.join(t) for t in data.trg]
        valid_hypotheses = [join_char.join(t) for t in decoded_valid]

        # post-process
        if level == "bpe" and postprocess:
            valid_sources = [bpe_postprocess(s) for s in valid_sources]
            valid_references = [bpe_postprocess(v) for v in valid_references]
            valid_hypotheses = [bpe_postprocess(v) for v in valid_hypotheses]

        # if references are given, evaluate against them
        if valid_references:
            assert len(valid_hypotheses) == len(valid_references)

            current_valid_score = 0
            if eval_metric.lower() == 'bleu':
                # this version does not use any tokenization
                current_valid_score = bleu(valid_hypotheses, valid_references)
            elif eval_metric.lower() == 'chrf':
                current_valid_score = chrf(valid_hypotheses, valid_references)
            elif eval_metric.lower() == 'token_accuracy':
                current_valid_score = token_accuracy(valid_hypotheses,
                                                     valid_references,
                                                     level=level)
            elif eval_metric.lower() == 'sequence_accuracy':
                current_valid_score = sequence_accuracy(
                    valid_hypotheses, valid_references)
        else:
            current_valid_score = -1

    return current_valid_score, valid_loss, valid_ppl, valid_sources, \
        valid_sources_raw, valid_references, valid_hypotheses, \
        decoded_valid, valid_attention_scores
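For reference, the BPE post-processing step above undoes the "@@ " subword segmentation used by the subword-nmt convention. A minimal sketch under that assumption (the real bpe_postprocess may handle further cases, such as sentencepiece markers in newer versions):

def bpe_postprocess_sketch(s: str) -> str:
    # remove subword-nmt style "@@ " segmentation markers
    return s.replace("@@ ", "")


print(bpe_postprocess_sketch("the high@@ est buil@@ ding"))  # "the highest building"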
Example No. 6
def validate_on_data(model,
                     data,
                     batch_size,
                     use_cuda,
                     max_output_length,
                     level,
                     eval_metric,
                     criterion,
                     beam_size=0,
                     beam_alpha=-1):
    """
    Generate translations for the given data.
    If `criterion` is not None and references are given, also compute the loss.

    :param model: model module
    :param data: dataset for validation
    :param batch_size: validation batch size
    :param use_cuda: if True, use CUDA
    :param max_output_length: maximum length for generated hypotheses
    :param level: segmentation level, one of "char", "bpe", "word"
    :param eval_metric: evaluation metric, e.g. "bleu"
    :param criterion: loss criterion that computes a scalar loss
        for given inputs and targets
    :param beam_size: beam size for validation; if 0, use greedy decoding (default)
    :param beam_alpha: beam search alpha for length penalty,
        disabled if set to -1 (default)
    :return: current validation score, validation loss, validation perplexity,
        validation sources, raw validation sources, validation references,
        validation hypotheses, raw decoded validation hypotheses,
        and attention scores
    """
    valid_iter = make_data_iter(dataset=data,
                                batch_size=batch_size,
                                shuffle=False,
                                train=False)
    valid_sources_raw = [s for s in data.src]
    pad_index = model.src_vocab.stoi[PAD_TOKEN]
    # disable dropout
    model.eval()
    # don't track gradients during validation
    with torch.no_grad():
        all_outputs = []
        valid_attention_scores = []
        total_loss = 0
        total_ntokens = 0
        for valid_i, valid_batch in enumerate(iter(valid_iter), 1):
            # run as during training to get validation loss (e.g. xent)

            batch = Batch(valid_batch, pad_index, use_cuda=use_cuda)
            # sort batch now by src length and keep track of order
            sort_reverse_index = batch.sort_by_src_lengths()

            # TODO save computation: forward pass is computed twice
            # run as during training with teacher forcing
            if criterion is not None and batch.trg is not None:
                batch_loss = model.get_loss_for_batch(batch,
                                                      criterion=criterion)
                total_loss += batch_loss
                total_ntokens += batch.ntokens

            # run as during inference to produce translations
            output, attention_scores = model.run_batch(
                batch=batch,
                beam_size=beam_size,
                beam_alpha=beam_alpha,
                max_output_length=max_output_length)

            # sort outputs back to original order
            all_outputs.extend(output[sort_reverse_index])
            valid_attention_scores.extend(
                attention_scores[sort_reverse_index]
                if attention_scores is not None else [])

        assert len(all_outputs) == len(data)

        if criterion is not None and total_ntokens > 0:
            # total validation loss
            valid_loss = total_loss
            # exponent of token-level negative log prob
            valid_ppl = torch.exp(total_loss / total_ntokens)
        else:
            valid_loss = -1
            valid_ppl = -1

        # decode back to symbols
        decoded_valid = arrays_to_sentences(arrays=all_outputs,
                                            vocabulary=model.trg_vocab,
                                            cut_at_eos=True)

        # evaluate with metric on full dataset
        join_char = " " if level in ["word", "bpe"] else ""
        valid_sources = [join_char.join(s) for s in data.src]
        valid_references = [join_char.join(t) for t in data.trg]
        valid_hypotheses = [join_char.join(t) for t in decoded_valid]

        # post-process
        if level == "bpe":
            valid_sources = [bpe_postprocess(s) for s in valid_sources]
            valid_references = [bpe_postprocess(v) for v in valid_references]
            valid_hypotheses = [bpe_postprocess(v) for v in valid_hypotheses]

        # if references are given, evaluate against them
        if len(valid_references) > 0:
            assert len(valid_hypotheses) == len(valid_references)

            current_valid_score = 0
            if eval_metric.lower() == 'bleu':
                # this version does not use any tokenization
                current_valid_score = bleu(valid_hypotheses, valid_references)
            elif eval_metric.lower() == 'chrf':
                current_valid_score = chrf(valid_hypotheses, valid_references)
            elif eval_metric.lower() == 'token_accuracy':
                current_valid_score = token_accuracy(valid_hypotheses,
                                                     valid_references,
                                                     level=level)
            elif eval_metric.lower() == 'sequence_accuracy':
                current_valid_score = sequence_accuracy(
                    valid_hypotheses, valid_references)
        else:
            current_valid_score = -1

    return current_valid_score, valid_loss, valid_ppl, valid_sources, \
           valid_sources_raw, valid_references, valid_hypotheses, \
           decoded_valid, \
           valid_attention_scores
Example No. 7
    def dev_network(self):
        """
        Show how is the current performace over the dev dataset, by mean of the
        total reward and the belu score.
        
        :return: current Bleu score
        """
        freeze_model(self.eval_net)
        for data_set_name, data_set in self.data_to_dev.items():
            valid_iter = make_data_iter(dataset=data_set,
                                        batch_size=1,
                                        batch_type=self.batch_type,
                                        shuffle=False,
                                        train=False)
            valid_sources_raw = data_set.src

            # accumulate rewards and outputs over the dev set
            r_total = 0
            roptimal_total = 0
            all_outputs = []
            i_sample = 0

            for valid_batch in iter(valid_iter):
                # run as during training to get validation loss (e.g. xent)

                batch = Batch(valid_batch,
                              self.pad_index,
                              use_cuda=self.use_cuda)

                encoder_output, encoder_hidden = self.model.encode(
                    batch.src, batch.src_lengths, batch.src_mask)

                # if maximum output length is
                # not globally specified, adapt to src len
                if self.max_output_length is None:
                    self.max_output_length = int(
                        max(batch.src_lengths.cpu().numpy()) * 1.5)

                batch_size = batch.src_mask.size(0)
                prev_y = batch.src_mask.new_full(size=[batch_size, 1],
                                                 fill_value=self.bos_index,
                                                 dtype=torch.long)
                output = []
                hidden = self.model.decoder._init_hidden(encoder_hidden)
                prev_att_vector = None
                finished = batch.src_mask.new_zeros((batch_size, 1)).byte()

                # pylint: disable=unused-variable
                for t in range(self.max_output_length):

                    # decode one single step
                    logits, hidden, att_probs, prev_att_vector = self.model.decoder(
                        encoder_output=encoder_output,
                        encoder_hidden=encoder_hidden,
                        src_mask=batch.src_mask,
                        trg_embed=self.model.trg_embed(prev_y),
                        hidden=hidden,
                        prev_att_vector=prev_att_vector,
                        unroll_steps=1)
                    # greedy decoding: choose the arg max over the vocabulary at each step

                    if self.state_type == 'hidden':
                        state = torch.cat(hidden,
                                          dim=2).squeeze(1).detach().cpu()[0]
                    else:
                        state = torch.FloatTensor(
                            prev_att_vector.squeeze(1).detach().cpu().numpy()
                            [0])

                    logits = self.eval_net(state)
                    logits = logits.reshape([1, 1, -1])
                    next_word = torch.argmax(logits, dim=-1)
                    prev_y = next_word

                    output.append(next_word.squeeze(1).detach().cpu().numpy())

                    # check if previous symbol was <eos>
                    is_eos = torch.eq(next_word, self.eos_index)
                    finished += is_eos
                    # stop predicting if <eos> reached for all elements in batch
                    if (finished >= 1).sum() == batch_size:
                        break
                stacked_output = np.stack(output, axis=1)  # batch, time

                # decode back to symbols
                decoded_valid_in = self.model.trg_vocab.arrays_to_sentences(
                    arrays=batch.src, cut_at_eos=True)
                decoded_valid_out_trg = self.model.trg_vocab.arrays_to_sentences(
                    arrays=batch.trg, cut_at_eos=True)
                decoded_valid_out = self.model.trg_vocab.arrays_to_sentences(
                    arrays=stacked_output, cut_at_eos=True)

                hyp = stacked_output

                r = self.Reward(batch.trg, hyp, show=False)

                if i_sample == 0 or i_sample == 3 or i_sample == 6:
                    print(
                        "\n Sample ", i_sample,
                        "-------------Target vs Eval_net prediction:--Raw---and---Decoded-----"
                    )
                    print("Target: ", batch.trg, decoded_valid_out_trg)
                    print("Eval  : ", stacked_output, decoded_valid_out, "\n")
                    print("Reward: ", r)

                #r = self.Reward1(batch.trg, hyp , show = False)
                r_total += sum(r[np.where(r > 0)])
                if i_sample == 0:
                    roptimal = self.Reward(batch.trg, batch.trg, show=False)
                    roptimal_total += sum(roptimal[np.where(roptimal > 0)])

                all_outputs.extend(stacked_output)
                i_sample += 1

            assert len(all_outputs) == len(data_set)

            # decode back to symbols
            decoded_valid = self.model.trg_vocab.arrays_to_sentences(
                arrays=all_outputs, cut_at_eos=True)

            # evaluate with metric on full dataset
            join_char = " " if self.level in ["word", "bpe"] else ""
            valid_sources = [join_char.join(s) for s in data_set.src]
            valid_references = [join_char.join(t) for t in data_set.trg]
            valid_hypotheses = [join_char.join(t) for t in decoded_valid]

            # post-process
            if self.level == "bpe":
                valid_sources = [bpe_postprocess(s) for s in valid_sources]
                valid_references = [
                    bpe_postprocess(v) for v in valid_references
                ]
                valid_hypotheses = [
                    bpe_postprocess(v) for v in valid_hypotheses
                ]

            # if references are given, evaluate against them
            if valid_references:
                assert len(valid_hypotheses) == len(valid_references)

                current_valid_score = 0
                if self.eval_metric.lower() == 'bleu':
                    # this version does not use any tokenization
                    current_valid_score = bleu(valid_hypotheses,
                                               valid_references)
                elif self.eval_metric.lower() == 'chrf':
                    current_valid_score = chrf(valid_hypotheses,
                                               valid_references)
                elif self.eval_metric.lower() == 'token_accuracy':
                    current_valid_score = token_accuracy(valid_hypotheses,
                                                         valid_references,
                                                         level=self.level)
                elif self.eval_metric.lower() == 'sequence_accuracy':
                    current_valid_score = sequence_accuracy(
                        valid_hypotheses, valid_references)
            else:
                current_valid_score = -1

            self.dev_network_count += 1
            self.tb_writer.add_scalar("dev/dev_reward", r_total,
                                      self.dev_network_count)
            self.tb_writer.add_scalar("dev/dev_bleu", current_valid_score,
                                      self.dev_network_count)

            print(self.dev_network_count, ' r_total and score: ', r_total,
                  current_valid_score)

            unfreeze_model(self.eval_net)
        return current_valid_score
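The decoding loop above stops early once every sequence in the batch has emitted <eos>. A small standalone illustration of that check (indices and predictions here are hypothetical):

import torch

batch_size, eos_index = 3, 3
finished = torch.zeros(batch_size, 1).byte()
next_word = torch.tensor([[3], [5], [3]])           # hypothetical predictions for one step
finished += torch.eq(next_word, eos_index)          # accumulate <eos> hits
print((finished >= 1).sum().item() == batch_size)   # False: one sequence is not done yet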
Example No. 8
    def Reward_lin(self, trg, hyp, show=False):
        """
        To use as self.Reward funtion. 
        Return an array of rewards, based on the current Score.
        From a T predicted sequence. Gives a reward per each T steps.
        Just when the predicted word is on the right place.

        :param trg: target.
        :param hyp: the predicted sequence.
        :param show: Boolean, display the computation of the rewards
        :return: current Bleu score
        """

        tar_len = trg.shape[1]
        hyp_len = hyp.shape[1]

        final_rew = -1 * np.ones(hyp_len - 1)

        len_temp = 0
        if tar_len > hyp_len:
            len_temp = hyp_len
        else:
            len_temp = tar_len
        hyp2com = np.zeros([1, tar_len])
        hyp2com[0, :len_temp] = hyp[0, :len_temp]

        equal = (trg.numpy() == hyp2com)

        #equal = np.invert(equal)*np.ones(equal.size)*0.2
        # ind1, ind2 = np.where(equal == False)

        # if len(ind1) != 0:
        #     equal[ind1[0]:, ind2[0]:] = False

        decoded_valid_tar = self.model.trg_vocab.arrays_to_sentences(
            arrays=trg, cut_at_eos=True)
        decoded_valid_hyp = self.model.trg_vocab.arrays_to_sentences(
            arrays=hyp, cut_at_eos=True)

        if show:
            print('decoded target list: ', decoded_valid_tar)
            print('decoded hypothesis list: ', decoded_valid_hyp)

        # evaluate with the metric on each target and hypothesis
        join_char = " " if self.level in ["word", "bpe"] else ""
        valid_references = [join_char.join(t) for t in decoded_valid_tar]
        valid_hypotheses = [join_char.join(t) for t in decoded_valid_hyp]

        # post-process
        if self.level == "bpe":
            valid_references = [bpe_postprocess(v) for v in valid_references]
            valid_hypotheses = [bpe_postprocess(v) for v in valid_hypotheses]
        # if references are given, evaluate against them
        if valid_references:
            assert len(valid_hypotheses) == len(valid_references)

            current_valid_score = 0
            if self.eval_metric.lower() == 'bleu':
                # this version does not use any tokenization
                current_valid_score = bleu(valid_hypotheses, valid_references)
            elif self.eval_metric.lower() == 'chrf':
                current_valid_score = chrf(valid_hypotheses, valid_references)
            elif self.eval_metric.lower() == 'token_accuracy':
                current_valid_score = token_accuracy(valid_hypotheses,
                                                     valid_references,
                                                     level=self.level)
            elif self.eval_metric.lower() == 'sequence_accuracy':
                current_valid_score = sequence_accuracy(
                    valid_hypotheses, valid_references)
        else:
            current_valid_score = -1

        k = sum(np.arange(tar_len))
        a_i = np.arange(1, tar_len) / k
        VSa_i = [sum(a_i[:i]) for i in np.arange(1, tar_len, dtype='int')]
        VSa_i = np.multiply(
            np.asanyarray(VSa_i).reshape([1, tar_len - 1]),
            equal).reshape([tar_len - 1])

        final_rew[:len_temp - 1] = np.multiply(VSa_i,
                                               current_valid_score)[:len_temp]

        if show:
            print('Reward is: ', final_rew)
            print('sum: ', sum(final_rew))
        return final_rew
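The per-step weights built above grow linearly toward the end of the target, and the a_i sum to one. A worked example for a hypothetical target length of 5, using the same expressions as Reward_lin:

import numpy as np

tar_len = 5
k = sum(np.arange(tar_len))                       # 0 + 1 + 2 + 3 + 4 = 10
a_i = np.arange(1, tar_len) / k                   # [0.1, 0.2, 0.3, 0.4]
VSa_i = [sum(a_i[:i]) for i in np.arange(1, tar_len, dtype='int')]
print(a_i)                                        # [0.1 0.2 0.3 0.4]
print(VSa_i)                                      # cumulative weights, approximately [0.1, 0.3, 0.6, 1.0]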
Example No. 9
def validate_on_data(model: Model,
                     data: Dataset,
                     batch_size: int,
                     use_cuda: bool,
                     max_output_length: int,
                     level: str,
                     eval_metric: Optional[str],
                     loss_function: torch.nn.Module = None,
                     beam_size: int = 0,
                     beam_alpha: int = -1,
                     batch_type: str = "sentence",
                     kb_task=None,
                     valid_kb: Dataset = None,
                     valid_kb_lkp: list = [],
                     valid_kb_lens: list = [],
                     valid_kb_truvals: Dataset = None,
                     valid_data_canon: Dataset = None,
                     report_on_canonicals: bool = False,
                     ) \
        -> (float, float, float, List[str], List[List[str]], List[str],
            List[str], List[List[str]], List[np.array]):
    """
    Generate translations for the given data.
    If `loss_function` is not None and references are given,
    also compute the loss.

    :param model: model module
    :param data: dataset for validation
    :param batch_size: validation batch size
    :param use_cuda: if True, use CUDA
    :param max_output_length: maximum length for generated hypotheses
    :param level: segmentation level, one of "char", "bpe", "word"
    :param eval_metric: evaluation metric, e.g. "bleu"
    :param loss_function: loss function that computes a scalar loss
        for given inputs and targets
    :param beam_size: beam size for validation.
        If 0 then greedy decoding (default).
    :param beam_alpha: beam search alpha for length penalty,
        disabled if set to -1 (default).
    :param batch_type: validation batch type (sentence or token)
    :param kb_task: is not None if kb_task should be executed
    :param valid_kb: MonoDataset holding the loaded valid kb data
    :param valid_kb_lkp: list mapping valid example indices to corresponding kb indices
    :param valid_kb_lens: list with the number of triples per kb
    :param valid_data_canon: TranslationDataset of valid data but with canonized target data (for loss reporting)

    :return:
        - current_valid_score: current validation score [eval_metric],
        - valid_loss: validation loss,
        - valid_ppl: validation perplexity,
        - valid_sources: validation sources,
        - valid_sources_raw: raw validation sources (before post-processing),
        - valid_references: validation references,
        - valid_hypotheses: validation_hypotheses,
        - decoded_valid: raw validation hypotheses (before post-processing),
        - valid_attention_scores: attention scores for validation hypotheses
        - valid_ent_f1: TODO FIXME
    """

    print(f"\n{'-'*10} ENTER VALIDATION {'-'*10}\n")

    print(f"\n{'-'*10}  VALIDATION DEBUG {'-'*10}\n")

    print("---data---")
    print(dir(data[0]))
    print([[
        getattr(example, attr) for attr in dir(example)
        if hasattr(getattr(example, attr), "__iter__") and "kb" in attr
        or "src" in attr or "trg" in attr
    ] for example in data[:3]])
    print(batch_size)
    print(use_cuda)
    print(max_output_length)
    print(level)
    print(eval_metric)
    print(loss_function)
    print(beam_size)
    print(beam_alpha)
    print(batch_type)
    print(kb_task)
    print("---valid_kb---")
    print(dir(valid_kb[0]))
    print([[
        getattr(example, attr) for attr in dir(example)
        if hasattr(getattr(example, attr), "__iter__") and "kb" in attr
        or "src" in attr or "trg" in attr
    ] for example in valid_kb[:3]])
    print(len(valid_kb_lkp), valid_kb_lkp[-5:])
    print(len(valid_kb_lens), valid_kb_lens[-5:])
    print("---valid_kb_truvals---")
    print(len(valid_kb_truvals), valid_kb_lens[-5:])
    print([[
        getattr(example, attr) for attr in dir(example)
        if hasattr(getattr(example, attr), "__iter__") and "kb" in attr
        or "src" in attr or "trg" in attr or "trv" in attr
    ] for example in valid_kb_truvals[:3]])
    print("---valid_data_canon---")
    print(len(valid_data_canon), valid_data_canon[-5:])
    print([[
        getattr(example, attr) for attr in dir(example)
        if hasattr(getattr(example, attr), "__iter__") and "kb" in attr
        or "src" in attr or "trg" in attr or "trv" or "can" in attr
    ] for example in valid_data_canon[:3]])
    print(report_on_canonicals)

    print(f"\n{'-'*10} END VALIDATION DEBUG {'-'*10}\n")

    if not kb_task:
        valid_iter = make_data_iter(dataset=data,
                                    batch_size=batch_size,
                                    batch_type=batch_type,
                                    shuffle=False,
                                    train=False)
    else:
        # knowledgebase version of make data iter and also provide canonized target data
        # data: for bleu/ent f1
        # canon_data: for loss
        valid_iter = make_data_iter_kb(data,
                                       valid_kb,
                                       valid_kb_lkp,
                                       valid_kb_lens,
                                       valid_kb_truvals,
                                       batch_size=batch_size,
                                       batch_type=batch_type,
                                       shuffle=False,
                                       train=False,
                                       canonize=model.canonize,
                                       canon_data=valid_data_canon)

    valid_sources_raw = data.src
    pad_index = model.src_vocab.stoi[PAD_TOKEN]

    # disable dropout
    model.eval()
    # don't track gradients during validation
    with torch.no_grad():
        all_outputs = []
        valid_attention_scores = []
        valid_kb_att_scores = []
        total_loss = 0
        total_ntokens = 0
        total_nseqs = 0
        for valid_batch in iter(valid_iter):
            # run as during training to get validation loss (e.g. xent)

            batch = Batch(valid_batch, pad_index, use_cuda=use_cuda) \
                                if not kb_task else \
                Batch_with_KB(valid_batch, pad_index, use_cuda=use_cuda)

            assert hasattr(batch, "kbsrc") == bool(kb_task)

            # sort batch now by src length and keep track of order
            if not kb_task:
                sort_reverse_index = batch.sort_by_src_lengths()
            else:
                sort_reverse_index = list(range(batch.src.shape[0]))

            # run as during training with teacher forcing
            if loss_function is not None and batch.trg is not None:

                ntokens = batch.ntokens
                if hasattr(batch, "trgcanon") and batch.trgcanon is not None:
                    ntokens = batch.ntokenscanon  # normalize loss with num canonical tokens for perplexity
                # do a loss calculation without grad updates just to report valid loss
                # we can only do this when batch.trg exists, so not during actual translation/deployment
                batch_loss = model.get_loss_for_batch(
                    batch, loss_function=loss_function)
                # keep track of metrics for reporting
                total_loss += batch_loss
                total_ntokens += ntokens  # gold target tokens
                total_nseqs += batch.nseqs

            # run as during inference to produce translations
            output, attention_scores, kb_att_scores = model.run_batch(
                batch=batch,
                beam_size=beam_size,
                beam_alpha=beam_alpha,
                max_output_length=max_output_length)

            # sort outputs back to original order
            all_outputs.extend(output[sort_reverse_index])
            valid_attention_scores.extend(
                attention_scores[sort_reverse_index]
                if attention_scores is not None else [])
            valid_kb_att_scores.extend(kb_att_scores[sort_reverse_index]
                                       if kb_att_scores is not None else [])

        assert len(all_outputs) == len(data)

        if loss_function is not None and total_ntokens > 0:
            # total validation loss
            valid_loss = total_loss
            # exponent of token-level negative log likelihood
            # can be seen as exp(cross_entropy of the model on the valid set), normalized by num tokens;
            # see https://en.wikipedia.org/wiki/Perplexity#Perplexity_per_word
            valid_ppl = torch.exp(valid_loss / total_ntokens)
        else:
            valid_loss = -1
            valid_ppl = -1

        # decode back to symbols
        decoding_vocab = model.trg_vocab if not kb_task else model.trv_vocab

        decoded_valid = decoding_vocab.arrays_to_sentences(arrays=all_outputs,
                                                           cut_at_eos=True)

        print(f"decoding_vocab.itos: {decoding_vocab.itos}")
        print(decoded_valid)

        # evaluate with metric on full dataset
        join_char = " " if level in ["word", "bpe"] else ""
        valid_sources = [join_char.join(s) for s in data.src]
        # TODO replace valid_references with uncanonicalized dev.car data ... requires writing new Dataset in data.py
        valid_references = [join_char.join(t) for t in data.trg]
        valid_hypotheses = [join_char.join(t) for t in decoded_valid]

        # post-process
        if level == "bpe":
            valid_sources = [bpe_postprocess(s) for s in valid_sources]
            valid_references = [bpe_postprocess(v) for v in valid_references]
            valid_hypotheses = [bpe_postprocess(v) for v in valid_hypotheses]

        # if references are given, evaluate against them
        if valid_references:
            assert len(valid_hypotheses) == len(valid_references)

            print(list(zip(valid_sources, valid_references, valid_hypotheses)))

            current_valid_score = 0
            if eval_metric.lower() == 'bleu':
                # this version does not use any tokenization
                current_valid_score = bleu(valid_hypotheses, valid_references)
            elif eval_metric.lower() == 'chrf':
                current_valid_score = chrf(valid_hypotheses, valid_references)
            elif eval_metric.lower() == 'token_accuracy':
                current_valid_score = token_accuracy(valid_hypotheses,
                                                     valid_references,
                                                     level=level)
            elif eval_metric.lower() == 'sequence_accuracy':
                current_valid_score = sequence_accuracy(
                    valid_hypotheses, valid_references)

            if kb_task:
                valid_ent_f1, valid_ent_mcc = calc_ent_f1_and_ent_mcc(
                    valid_hypotheses,
                    valid_references,
                    vocab=model.trv_vocab,
                    c_fun=model.canonize,
                    report_on_canonicals=report_on_canonicals)

            else:
                valid_ent_f1, valid_ent_mcc = -1, -1
        else:
            current_valid_score = -1
            valid_ent_f1, valid_ent_mcc = -1, -1

    print(f"\n{'-'*10} EXIT VALIDATION {'-'*10}\n")
    return current_valid_score, valid_loss, valid_ppl, valid_sources, \
        valid_sources_raw, valid_references, valid_hypotheses, \
        decoded_valid, valid_attention_scores, valid_kb_att_scores, \
        valid_ent_f1, valid_ent_mcc