Example #1
def validate_on_data(model: Model, data: Dataset,
                     batch_size: int,
                     use_cuda: bool, max_output_length: int,
                     src_level: str,
                     trg_level: str,
                     eval_metrics: Optional[Sequence[str]],
                     attn_metrics: Optional[Sequence[str]],
                     loss_function: torch.nn.Module = None,
                     beam_size: int = 0, beam_alpha: int = 0,
                     batch_type: str = "sentence",
                     save_attention: bool = False,
                     log_sparsity: bool = False,
                     apply_mask: bool = True
                     ) \
        -> (dict, List[str], List[List[str]], List[str], List[str],
            List[List[str]], dict, dict, dict):
    """
    Generate translations for the given data.
    If `loss_function` is not None and references are given,
    also compute the loss.

    :param model: model module
    :param data: dataset for validation
    :param batch_size: validation batch size
    :param use_cuda: if True, use CUDA
    :param max_output_length: maximum length for generated hypotheses
    :param src_level: source segmentation level, one of "char", "bpe", "word"
    :param trg_level: target segmentation level, one of "char", "bpe", "word"
    :param eval_metrics: evaluation metrics, e.g. ["bleu"]
    :param attn_metrics: attention metrics to compute
        (currently only "support" is implemented)
    :param loss_function: loss function that computes a scalar loss
        for given inputs and targets
    :param beam_size: beam size for validation.
        If 0 then greedy decoding (default).
    :param beam_alpha: beam search alpha for length penalty,
        disabled if set to 0 (default).
    :param batch_type: validation batch type (sentence or token)
    :param save_attention: if True, keep attention scores for the hypotheses
    :param log_sparsity: if True, log the average support size of the
        output distributions under greedy decoding
    :param apply_mask: forwarded to model.run_batch

    :return:
        - valid_scores: dict of validation scores ("loss", "ppl", and one
          entry per selected evaluation metric),
        - valid_sources: validation sources,
        - valid_sources_raw: raw validation sources (before post-processing),
        - valid_references: validation references,
        - valid_hypotheses: validation hypotheses,
        - decoded_valid: raw validation hypotheses (before post-processing),
        - valid_attention_scores: attention scores for validation hypotheses,
        - scores_by_lang: per-language scores for each evaluation metric
          (None if the data carries no language annotation),
        - by_language: (reference, hypothesis) pairs grouped by language
    """
    eval_funcs = {
        "bleu": bleu,
        "chrf": chrf,
        "token_accuracy": partial(token_accuracy, level=trg_level),
        "sequence_accuracy": sequence_accuracy,
        "wer": wer,
        "cer": partial(character_error_rate, level=trg_level)
    }
    selected_eval_metrics = {name: eval_funcs[name] for name in eval_metrics}

    valid_iter = make_data_iter(
        dataset=data, batch_size=batch_size, batch_type=batch_type,
        shuffle=False, train=False)
    valid_sources_raw = [s for s in data.src]
    pad_index = model.src_vocab.stoi[PAD_TOKEN]
    # disable dropout
    model.eval()
    # don't track gradients during validation
    scorer = partial(len_penalty, alpha=beam_alpha) if beam_alpha > 0 else None
    with torch.no_grad():
        all_outputs = []
        valid_attention_scores = defaultdict(list)
        total_loss = 0
        total_ntokens = 0
        total_nseqs = 0
        total_attended = defaultdict(int)
        greedy_steps = 0
        greedy_supported = 0
        for valid_batch in iter(valid_iter):
            # run as during training to get validation loss (e.g. xent)

            batch = Batch(valid_batch, pad_index, use_cuda=use_cuda)
            # sort batch now by src length and keep track of order
            sort_reverse_index = batch.sort_by_src_lengths()

            # run as during training with teacher forcing
            if loss_function is not None and batch.trg is not None:
                batch_loss = model.get_loss_for_batch(
                    batch, loss_function=loss_function)
                total_loss += batch_loss
                total_ntokens += batch.ntokens
                total_nseqs += batch.nseqs

            # run as during inference to produce translations
            output, attention_scores, probs = model.run_batch(
                batch=batch, beam_size=beam_size, scorer=scorer,
                max_output_length=max_output_length, log_sparsity=log_sparsity,
                apply_mask=apply_mask)
            if log_sparsity:
                eos_index = model.trg_vocab.stoi[EOS_TOKEN]
                # index of the first EOS token in each hypothesis
                # (its greedy decoding length)
                lengths = torch.LongTensor(
                    (output == eos_index).argmax(axis=1)).unsqueeze(1)
                greedy_steps += lengths.sum().item()

                # mask out positions beyond each hypothesis length
                ix = torch.arange(output.shape[1]).unsqueeze(0).expand(
                    output.shape[0], -1)
                mask = ix <= lengths
                # support size: number of tokens with nonzero probability
                supp = probs.exp().gt(0).sum(dim=-1).cpu()  # batch x len
                supp = torch.where(mask, supp, torch.tensor(0)).sum()
                greedy_supported += supp.float().item()

            # sort outputs back to original order
            all_outputs.extend(output[sort_reverse_index])

            if attention_scores is not None:
                # is attention_scores ever None?
                if save_attention:
                    # beam search currently does not support attention logging
                    for k, v in attention_scores.items():
                        valid_attention_scores[k].extend(v[sort_reverse_index])
                if attn_metrics:
                    # add to total_attended
                    for k, v in attention_scores.items():
                        total_attended[k] += (v > 0).sum()

        assert len(all_outputs) == len(data)

        if log_sparsity:
            # average support size per greedy decoding step
            print("avg support size:", greedy_supported / greedy_steps)

        valid_scores = dict()
        if loss_function is not None and total_ntokens > 0:
            # total validation loss and perplexity
            # (ppl = exp of the average per-token negative log-likelihood)
            valid_scores["loss"] = total_loss
            valid_scores["ppl"] = torch.exp(total_loss / total_ntokens)

        # decode back to symbols
        decoded_valid = model.trg_vocab.arrays_to_sentences(arrays=all_outputs,
                                                            cut_at_eos=True)

        # evaluate with metric on full dataset
        src_join_char = " " if src_level in ["word", "bpe"] else ""
        trg_join_char = " " if trg_level in ["word", "bpe"] else ""
        valid_sources = [src_join_char.join(s) for s in data.src]
        valid_references = [trg_join_char.join(t) for t in data.trg]
        valid_hypotheses = [trg_join_char.join(t) for t in decoded_valid]

        if attn_metrics:
            decoded_ntokens = sum(len(t) for t in decoded_valid)
            for attn_metric in attn_metrics:
                assert attn_metric == "support"
                for attn_name, tot_attended in total_attended.items():
                    score_name = attn_name + "_" + attn_metric
                    # this is not the right denominator
                    valid_scores[score_name] = tot_attended / decoded_ntokens

        # post-process
        if src_level == "bpe":
            valid_sources = [bpe_postprocess(s) for s in valid_sources]
        if trg_level == "bpe":
            valid_references = [bpe_postprocess(v) for v in valid_references]
            valid_hypotheses = [bpe_postprocess(v) for v in valid_hypotheses]

        languages = list(data.language)
        by_language = defaultdict(list)
        seqs = zip(valid_references, valid_hypotheses) \
            if valid_references else valid_hypotheses
        if languages:
            examples = zip(languages, seqs)
            for lang, seq in examples:
                by_language[lang].append(seq)
        else:
            by_language[None].extend(seqs)

        # if references are given, evaluate against them
        # incorrect if-condition?
        # scores_by_lang = {name: dict() for name in selected_eval_metrics}
        scores_by_lang = dict()
        if valid_references and eval_metrics is not None:
            assert len(valid_hypotheses) == len(valid_references)

            for eval_metric, eval_func in selected_eval_metrics.items():
                score_by_lang = dict()
                for lang, pairs in by_language.items():
                    # each pair is (reference, hypothesis)
                    lang_refs, lang_hyps = zip(*pairs)
                    lang_score = eval_func(lang_hyps, lang_refs)
                    score_by_lang[lang] = lang_score

                score = sum(score_by_lang.values()) / len(score_by_lang)
                valid_scores[eval_metric] = score
                scores_by_lang[eval_metric] = score_by_lang

    if not languages:
        scores_by_lang = None
    return valid_scores, valid_sources, \
        valid_sources_raw, valid_references, valid_hypotheses, \
        decoded_valid, valid_attention_scores, scores_by_lang, by_language
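
A minimal call-site sketch for this variant is shown below. All names outside the signature (`model`, `dev_data`, and the `torch.nn.NLLLoss` stand-in for whatever scalar loss `model.get_loss_for_batch` expects) are illustrative assumptions, not part of the original code.

# Hypothetical usage sketch; `model` and `dev_data` are assumed to come from
# the surrounding training script, and NLLLoss is only a placeholder loss.
pad_index = model.trg_vocab.stoi[PAD_TOKEN]
xent = torch.nn.NLLLoss(ignore_index=pad_index, reduction="sum")

(valid_scores, valid_sources, valid_sources_raw, valid_references,
 valid_hypotheses, decoded_valid, valid_attention_scores,
 scores_by_lang, by_language) = validate_on_data(
    model=model,
    data=dev_data,
    batch_size=32,
    use_cuda=torch.cuda.is_available(),
    max_output_length=100,
    src_level="bpe",
    trg_level="bpe",
    eval_metrics=["bleu", "chrf"],
    attn_metrics=None,
    loss_function=xent,
    beam_size=5,
    beam_alpha=1)

# "loss" and "ppl" are present because a loss function was passed;
# "bleu" and "chrf" are present because dev_data has references.
print("bleu:", valid_scores.get("bleu"), "ppl:", valid_scores.get("ppl"))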
Example #2
def validate_on_data(model: Model, data: Dataset,
                     logger: Logger,
                     batch_size: int,
                     use_cuda: bool, max_output_length: int,
                     level: str, eval_metric: Optional[str],
                     loss_function: torch.nn.Module = None,
                     beam_size: int = 1, beam_alpha: int = -1,
                     batch_type: str = "sentence",
                     postprocess: bool = True
                     ) \
        -> (float, float, float, List[str], List[List[str]], List[str],
            List[str], List[List[str]], List[np.array]):
    """
    Generate translations for the given data.
    If `loss_function` is not None and references are given,
    also compute the loss.

    :param model: model module
    :param logger: logger
    :param data: dataset for validation
    :param batch_size: validation batch size
    :param use_cuda: if True, use CUDA
    :param max_output_length: maximum length for generated hypotheses
    :param level: segmentation level, one of "char", "bpe", "word"
    :param eval_metric: evaluation metric, e.g. "bleu"
    :param loss_function: loss function that computes a scalar loss
        for given inputs and targets
    :param beam_size: beam size for validation.
        If <2 then greedy decoding (default).
    :param beam_alpha: beam search alpha for length penalty,
        disabled if set to -1 (default).
    :param batch_type: validation batch type (sentence or token)
    :param postprocess: if True, remove BPE segmentation from translations

    :return:
        - current_valid_score: current validation score [eval_metric],
        - valid_loss: validation loss,
        - valid_ppl: validation perplexity,
        - valid_sources: validation sources,
        - valid_sources_raw: raw validation sources (before post-processing),
        - valid_references: validation references,
        - valid_hypotheses: validation hypotheses,
        - decoded_valid: raw validation hypotheses (before post-processing),
        - valid_attention_scores: attention scores for validation hypotheses
    """
    if batch_size > 1000 and batch_type == "sentence":
        logger.warning(
            "WARNING: Are you sure you meant to work on huge batches like "
            "this? 'batch_size' is > 1000 for sentence-batching. "
            "Consider decreasing it or switching to"
            " 'eval_batch_type: token'.")
    valid_iter = make_data_iter(dataset=data,
                                batch_size=batch_size,
                                batch_type=batch_type,
                                shuffle=False,
                                train=False)
    valid_sources_raw = data.src
    pad_index = model.src_vocab.stoi[PAD_TOKEN]
    # disable dropout
    model.eval()
    # don't track gradients during validation
    with torch.no_grad():
        all_outputs = []
        valid_attention_scores = []
        total_loss = 0
        total_ntokens = 0
        total_nseqs = 0
        for valid_batch in iter(valid_iter):
            # run as during training to get validation loss (e.g. xent)

            batch = Batch(valid_batch, pad_index, use_cuda=use_cuda)
            # sort batch now by src length and keep track of order
            sort_reverse_index = batch.sort_by_src_lengths()

            # run as during training with teacher forcing
            if loss_function is not None and batch.trg is not None:
                batch_loss = model.get_loss_for_batch(
                    batch, loss_function=loss_function)
                total_loss += batch_loss
                total_ntokens += batch.ntokens
                total_nseqs += batch.nseqs

            # run as during inference to produce translations
            output, attention_scores = model.run_batch(
                batch=batch,
                beam_size=beam_size,
                beam_alpha=beam_alpha,
                max_output_length=max_output_length)

            # sort outputs back to original order
            all_outputs.extend(output[sort_reverse_index])
            valid_attention_scores.extend(
                attention_scores[sort_reverse_index]
                if attention_scores is not None else [])

        assert len(all_outputs) == len(data)

        if loss_function is not None and total_ntokens > 0:
            # total validation loss
            valid_loss = total_loss
            # perplexity: exp of the average token-level negative log-likelihood
            valid_ppl = torch.exp(total_loss / total_ntokens)
        else:
            valid_loss = -1
            valid_ppl = -1

        # decode back to symbols
        decoded_valid = model.trg_vocab.arrays_to_sentences(arrays=all_outputs,
                                                            cut_at_eos=True)

        # evaluate with metric on full dataset
        join_char = " " if level in ["word", "bpe"] else ""
        valid_sources = [join_char.join(s) for s in data.src]
        valid_references = [join_char.join(t) for t in data.trg]
        valid_hypotheses = [join_char.join(t) for t in decoded_valid]

        # post-process
        if level == "bpe" and postprocess:
            valid_sources = [bpe_postprocess(s) for s in valid_sources]
            valid_references = [bpe_postprocess(v) for v in valid_references]
            valid_hypotheses = [bpe_postprocess(v) for v in valid_hypotheses]

        # if references are given, evaluate against them
        if valid_references:
            assert len(valid_hypotheses) == len(valid_references)

            current_valid_score = 0
            if eval_metric.lower() == 'bleu':
                # this version does not use any tokenization
                current_valid_score = bleu(valid_hypotheses, valid_references)
            elif eval_metric.lower() == 'chrf':
                current_valid_score = chrf(valid_hypotheses, valid_references)
            elif eval_metric.lower() == 'token_accuracy':
                current_valid_score = token_accuracy(valid_hypotheses,
                                                     valid_references,
                                                     level=level)
            elif eval_metric.lower() == 'sequence_accuracy':
                current_valid_score = sequence_accuracy(
                    valid_hypotheses, valid_references)
        else:
            current_valid_score = -1

    return current_valid_score, valid_loss, valid_ppl, valid_sources, \
        valid_sources_raw, valid_references, valid_hypotheses, \
        decoded_valid, valid_attention_scores
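
For comparison, a hypothetical call to this plainer signature might look like the following sketch; `model`, `dev_data`, and `logger` are assumed to be created elsewhere.

# Hypothetical usage sketch; greedy decoding, no loss reporting.
(score, loss, ppl, sources, sources_raw, references, hypotheses,
 decoded, attention) = validate_on_data(
    model=model,
    data=dev_data,
    logger=logger,
    batch_size=64,
    use_cuda=False,
    max_output_length=80,
    level="word",
    eval_metric="bleu",
    loss_function=None,   # skip loss/perplexity computation
    beam_size=1,          # <2 means greedy decoding
    beam_alpha=-1,
    batch_type="sentence",
    postprocess=True)

# With loss_function=None, loss and ppl are returned as -1.
print(f"bleu={score:.2f}, loss={loss}, ppl={ppl}")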
Example #3
def validate_on_data(model: Model,
                     data: Dataset,
                     batch_size: int,
                     use_cuda: bool,
                     max_output_length: int,
                     level: str,
                     eval_metric: Optional[str],
                     loss_function: torch.nn.Module = None,
                     beam_size: int = 0,
                     beam_alpha: int = -1,
                     batch_type: str = "sentence",
                     kb_task=None,
                     valid_kb: Dataset = None,
                     valid_kb_lkp: list = [],
                     valid_kb_lens: list = [],
                     valid_kb_truvals: Dataset = None,
                     valid_data_canon: Dataset = None,
                     report_on_canonicals: bool = False,
                     ) \
        -> (float, float, float, List[str], List[List[str]], List[str],
            List[str], List[List[str]], List[np.array], List[np.array],
            float, float):
    """
    Generate translations for the given data.
    If `loss_function` is not None and references are given,
    also compute the loss.

    :param model: model module
    :param data: dataset for validation
    :param batch_size: validation batch size
    :param use_cuda: if True, use CUDA
    :param max_output_length: maximum length for generated hypotheses
    :param level: segmentation level, one of "char", "bpe", "word"
    :param eval_metric: evaluation metric, e.g. "bleu"
    :param loss_function: loss function that computes a scalar loss
        for given inputs and targets
    :param beam_size: beam size for validation.
        If 0 then greedy decoding (default).
    :param beam_alpha: beam search alpha for length penalty,
        disabled if set to -1 (default).
    :param batch_type: validation batch type (sentence or token)
    :param kb_task: is not None if the kb task should be executed
    :param valid_kb: MonoDataset holding the loaded valid kb data
    :param valid_kb_lkp: list mapping each valid example index to its kb indices
    :param valid_kb_lens: list with the number of triples per kb
    :param valid_kb_truvals: Dataset holding the true (non-canonical) kb values
    :param valid_data_canon: TranslationDataset of the valid data but with
        canonized target data (for loss reporting)
    :param report_on_canonicals: if True, report entity scores on the
        canonical forms

    :return:
        - current_valid_score: current validation score [eval_metric],
        - valid_loss: validation loss,
        - valid_ppl: validation perplexity,
        - valid_sources: validation sources,
        - valid_sources_raw: raw validation sources (before post-processing),
        - valid_references: validation references,
        - valid_hypotheses: validation hypotheses,
        - decoded_valid: raw validation hypotheses (before post-processing),
        - valid_attention_scores: attention scores for validation hypotheses,
        - valid_kb_att_scores: kb attention scores for validation hypotheses,
        - valid_ent_f1: entity F1 score (-1 if kb_task is not set),
        - valid_ent_mcc: entity MCC score (-1 if kb_task is not set)
    """

    print(f"\n{'-'*10} ENTER VALIDATION {'-'*10}\n")

    print(f"\n{'-'*10}  VALIDATION DEBUG {'-'*10}\n")

    print("---data---")
    print(dir(data[0]))
    print([[
        getattr(example, attr) for attr in dir(example)
        if hasattr(getattr(example, attr), "__iter__") and "kb" in attr
        or "src" in attr or "trg" in attr
    ] for example in data[:3]])
    print(batch_size)
    print(use_cuda)
    print(max_output_length)
    print(level)
    print(eval_metric)
    print(loss_function)
    print(beam_size)
    print(beam_alpha)
    print(batch_type)
    print(kb_task)
    print("---valid_kb---")
    print(dir(valid_kb[0]))
    print([[
        getattr(example, attr) for attr in dir(example)
        if hasattr(getattr(example, attr), "__iter__") and "kb" in attr
        or "src" in attr or "trg" in attr
    ] for example in valid_kb[:3]])
    print(len(valid_kb_lkp), valid_kb_lkp[-5:])
    print(len(valid_kb_lens), valid_kb_lens[-5:])
    print("---valid_kb_truvals---")
    print(len(valid_kb_truvals), valid_kb_truvals[-5:])
    print([[
        getattr(example, attr) for attr in dir(example)
        if hasattr(getattr(example, attr), "__iter__") and "kb" in attr
        or "src" in attr or "trg" in attr or "trv" in attr
    ] for example in valid_kb_truvals[:3]])
    print("---valid_data_canon---")
    print(len(valid_data_canon), valid_data_canon[-5:])
    print([[
        getattr(example, attr) for attr in dir(example)
        if hasattr(getattr(example, attr), "__iter__") and "kb" in attr
        or "src" in attr or "trg" in attr or "trv" or "can" in attr
    ] for example in valid_data_canon[:3]])
    print(report_on_canonicals)

    print(f"\n{'-'*10} END VALIDATION DEBUG {'-'*10}\n")

    if not kb_task:
        valid_iter = make_data_iter(dataset=data,
                                    batch_size=batch_size,
                                    batch_type=batch_type,
                                    shuffle=False,
                                    train=False)
    else:
        # knowledgebase version of make data iter and also provide canonized target data
        # data: for bleu/ent f1
        # canon_data: for loss
        valid_iter = make_data_iter_kb(data,
                                       valid_kb,
                                       valid_kb_lkp,
                                       valid_kb_lens,
                                       valid_kb_truvals,
                                       batch_size=batch_size,
                                       batch_type=batch_type,
                                       shuffle=False,
                                       train=False,
                                       canonize=model.canonize,
                                       canon_data=valid_data_canon)

    valid_sources_raw = data.src
    pad_index = model.src_vocab.stoi[PAD_TOKEN]

    # disable dropout
    model.eval()
    # don't track gradients during validation
    with torch.no_grad():
        all_outputs = []
        valid_attention_scores = []
        valid_kb_att_scores = []
        total_loss = 0
        total_ntokens = 0
        total_nseqs = 0
        for valid_batch in iter(valid_iter):
            # run as during training to get validation loss (e.g. xent)

            batch = Batch(valid_batch, pad_index, use_cuda=use_cuda) \
                                if not kb_task else \
                Batch_with_KB(valid_batch, pad_index, use_cuda=use_cuda)

            assert hasattr(batch, "kbsrc") == bool(kb_task)

            # sort batch now by src length and keep track of order
            if not kb_task:
                sort_reverse_index = batch.sort_by_src_lengths()
            else:
                sort_reverse_index = list(range(batch.src.shape[0]))

            # run as during training with teacher forcing
            if loss_function is not None and batch.trg is not None:

                ntokens = batch.ntokens
                if hasattr(batch, "trgcanon") and batch.trgcanon is not None:
                    ntokens = batch.ntokenscanon  # normalize loss with num canonical tokens for perplexity
                # do a loss calculation without grad updates just to report valid loss
                # we can only do this when batch.trg exists, so not during actual translation/deployment
                batch_loss = model.get_loss_for_batch(
                    batch, loss_function=loss_function)
                # keep track of metrics for reporting
                total_loss += batch_loss
                total_ntokens += ntokens  # gold target tokens
                total_nseqs += batch.nseqs

            # run as during inference to produce translations
            output, attention_scores, kb_att_scores = model.run_batch(
                batch=batch,
                beam_size=beam_size,
                beam_alpha=beam_alpha,
                max_output_length=max_output_length)

            # sort outputs back to original order
            all_outputs.extend(output[sort_reverse_index])
            valid_attention_scores.extend(
                attention_scores[sort_reverse_index]
                if attention_scores is not None else [])
            valid_kb_att_scores.extend(kb_att_scores[sort_reverse_index]
                                       if kb_att_scores is not None else [])

        assert len(all_outputs) == len(data)

        if loss_function is not None and total_ntokens > 0:
            # total validation loss
            valid_loss = total_loss
            # exponential of the token-level negative log-likelihood,
            # i.e. exp(valid cross-entropy normalized by the number of tokens);
            # see https://en.wikipedia.org/wiki/Perplexity#Perplexity_per_word
            valid_ppl = torch.exp(valid_loss / total_ntokens)
        else:
            valid_loss = -1
            valid_ppl = -1

        # decode back to symbols
        decoding_vocab = model.trg_vocab if not kb_task else model.trv_vocab

        decoded_valid = decoding_vocab.arrays_to_sentences(arrays=all_outputs,
                                                           cut_at_eos=True)

        print(f"decoding_vocab.itos: {decoding_vocab.itos}")
        print(decoded_valid)

        # evaluate with metric on full dataset
        join_char = " " if level in ["word", "bpe"] else ""
        valid_sources = [join_char.join(s) for s in data.src]
        # TODO replace valid_references with uncanonicalized dev.car data ... requires writing new Dataset in data.py
        valid_references = [join_char.join(t) for t in data.trg]
        valid_hypotheses = [join_char.join(t) for t in decoded_valid]

        # post-process
        if level == "bpe":
            valid_sources = [bpe_postprocess(s) for s in valid_sources]
            valid_references = [bpe_postprocess(v) for v in valid_references]
            valid_hypotheses = [bpe_postprocess(v) for v in valid_hypotheses]

        # if references are given, evaluate against them
        if valid_references:
            assert len(valid_hypotheses) == len(valid_references)

            print(list(zip(valid_sources, valid_references, valid_hypotheses)))

            current_valid_score = 0
            if eval_metric.lower() == 'bleu':
                # this version does not use any tokenization
                current_valid_score = bleu(valid_hypotheses, valid_references)
            elif eval_metric.lower() == 'chrf':
                current_valid_score = chrf(valid_hypotheses, valid_references)
            elif eval_metric.lower() == 'token_accuracy':
                current_valid_score = token_accuracy(valid_hypotheses,
                                                     valid_references,
                                                     level=level)
            elif eval_metric.lower() == 'sequence_accuracy':
                current_valid_score = sequence_accuracy(
                    valid_hypotheses, valid_references)

            if kb_task:
                valid_ent_f1, valid_ent_mcc = calc_ent_f1_and_ent_mcc(
                    valid_hypotheses,
                    valid_references,
                    vocab=model.trv_vocab,
                    c_fun=model.canonize,
                    report_on_canonicals=report_on_canonicals)

            else:
                valid_ent_f1, valid_ent_mcc = -1, -1
        else:
            current_valid_score = -1
            # keep the entity metrics defined when no references are given
            valid_ent_f1, valid_ent_mcc = -1, -1

    print(f"\n{'-'*10} EXIT VALIDATION {'-'*10}\n")
    return current_valid_score, valid_loss, valid_ppl, valid_sources, \
        valid_sources_raw, valid_references, valid_hypotheses, \
        decoded_valid, valid_attention_scores, valid_kb_att_scores, \
        valid_ent_f1, valid_ent_mcc
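
A hedged sketch of how the KB-augmented variant might be called is given below; every `dev_*` name and `loss_fn` is assumed to be produced by the surrounding data-loading and training code, and `kb_task=True` is only a truthy placeholder for whatever the config actually passes.

# Hypothetical usage sketch for the KB variant; all dev_* names are assumed.
(score, loss, ppl, sources, sources_raw, references, hypotheses, decoded,
 attention, kb_attention, ent_f1, ent_mcc) = validate_on_data(
    model=model,
    data=dev_data,
    batch_size=32,
    use_cuda=True,
    max_output_length=60,
    level="word",
    eval_metric="bleu",
    loss_function=loss_fn,
    beam_size=0,           # greedy decoding
    beam_alpha=-1,
    batch_type="sentence",
    kb_task=True,
    valid_kb=dev_kb,
    valid_kb_lkp=dev_kb_lkp,
    valid_kb_lens=dev_kb_lens,
    valid_kb_truvals=dev_kb_truvals,
    valid_data_canon=dev_data_canon,
    report_on_canonicals=False)

# ent_f1 / ent_mcc are -1 when kb_task is falsy or no references are given.
print(f"bleu={score:.2f}, ent_f1={ent_f1:.2f}, ent_mcc={ent_mcc:.2f}")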