Example #1
def evaluate(data_loader, context=ctx[0]):
    """Evaluate given the data loader

    Parameters
    ----------
    data_loader : DataLoader

    Returns
    -------
    avg_loss : float
        Average loss
    real_translation_out : list of list of str
        The translation output
    """
    translation_out = []
    all_inst_ids = []
    avg_loss_denom = 0
    avg_loss = 0.0
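    # Each batch provides padded source/target sequences, their valid lengths,
    # and instance ids that record the original sentence order.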
    for _, (src_seq, tgt_seq, src_valid_length, tgt_valid_length, inst_ids) \
            in enumerate(data_loader):
        src_seq = src_seq.as_in_context(context)
        tgt_seq = tgt_seq.as_in_context(context)
        src_valid_length = src_valid_length.as_in_context(context)
        tgt_valid_length = tgt_valid_length.as_in_context(context)
        # Calculating Loss
        out, _ = model(src_seq, tgt_seq[:, :-1], src_valid_length, tgt_valid_length - 1)
        tgt_pred_logits = mx.nd.log_softmax(out, axis=-1)
        nll_loss = - mx.nd.pick(tgt_pred_logits, tgt_seq[:, 1:], axis=-1)
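        # Zero out the loss at padded positions so only valid target tokens count.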
        masked_nll_loss = mx.nd.SequenceMask(nll_loss,
                                             sequence_length=tgt_valid_length - 1,
                                             use_sequence_length=True,
                                             axis=1)
        loss = masked_nll_loss.sum().asscalar()
        # loss = test_loss_function(out, tgt_seq[:, 1:], tgt_valid_length - 1).mean().asscalar()
        all_inst_ids.extend(inst_ids.asnumpy().astype(np.int32).tolist())
        # avg_loss += loss * (tgt_seq.shape[1] - 1)
        # avg_loss_denom += (tgt_seq.shape[1] - 1)
        avg_loss += loss
        avg_loss_denom += (tgt_valid_length - 1).sum().asscalar()
        # Translate
        samples, _, sample_valid_length = \
            translator.translate(src_seq=src_seq, src_valid_length=src_valid_length)
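        # samples has shape (batch_size, beam_size, length); keep only the
        # highest-scoring beam for each sentence.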
        max_score_sample = samples[:, 0, :].asnumpy()
        sample_valid_length = sample_valid_length[:, 0].asnumpy()
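        # Map token ids back to strings, dropping the leading BOS and trailing EOS.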
        for i in range(max_score_sample.shape[0]):
            translation_out.append(
                [tgt_vocab.idx_to_token[ele] for ele in
                 max_score_sample[i][1:(sample_valid_length[i] - 1)]])
    avg_loss = avg_loss / avg_loss_denom
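    # The data loader may shuffle or bucket sentences, so restore the original
    # order using the recorded instance ids.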
    real_translation_out = [None for _ in range(len(all_inst_ids))]
    for ind, sentence in zip(all_inst_ids, translation_out):
        if args.bleu == 'tweaked':
            real_translation_out[ind] = sentence
        elif args.bleu == '13a' or args.bleu == 'intl':
            real_translation_out[ind] = detokenizer(_bpe_to_words(sentence))
        else:
            raise NotImplementedError
    return avg_loss, real_translation_out
Example #2
def evaluate(data_loader, context=ctx[0]):
    """Evaluate given the data loader

    Parameters
    ----------
    data_loader : DataLoader

    Returns
    -------
    avg_loss : float
        Average loss
    real_translation_out : list of list of str
        The translation output
    """
    translation_out = []
    all_inst_ids = []
    avg_loss_denom = 0
    avg_loss = 0.0
    for _, (src_seq, tgt_seq, src_valid_length, tgt_valid_length, inst_ids) \
            in enumerate(data_loader):
        src_seq = src_seq.as_in_context(context)
        tgt_seq = tgt_seq.as_in_context(context)
        src_valid_length = src_valid_length.as_in_context(context)
        tgt_valid_length = tgt_valid_length.as_in_context(context)
        # Calculating Loss
        out, _ = model(src_seq, tgt_seq[:, :-1], src_valid_length, tgt_valid_length - 1)
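        # test_loss_function is assumed to be a masked softmax cross-entropy loss
        # defined elsewhere in the script; .mean() gives the batch's average loss.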
        loss = test_loss_function(out, tgt_seq[:, 1:], tgt_valid_length - 1).mean().asscalar()
        all_inst_ids.extend(inst_ids.asnumpy().astype(np.int32).tolist())
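        # Accumulate a length-weighted average using the padded target length,
        # unlike the first evaluate above, which normalizes by summed valid lengths.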
        avg_loss += loss * (tgt_seq.shape[1] - 1)
        avg_loss_denom += (tgt_seq.shape[1] - 1)
        # Translate
        samples, _, sample_valid_length = \
            translator.translate(src_seq=src_seq, src_valid_length=src_valid_length)
        max_score_sample = samples[:, 0, :].asnumpy()
        sample_valid_length = sample_valid_length[:, 0].asnumpy()
        for i in range(max_score_sample.shape[0]):
            translation_out.append(
                [tgt_vocab.idx_to_token[ele] for ele in
                 max_score_sample[i][1:(sample_valid_length[i] - 1)]])
    avg_loss = avg_loss / avg_loss_denom
    real_translation_out = [None for _ in range(len(all_inst_ids))]
    for ind, sentence in zip(all_inst_ids, translation_out):
        if args.bleu == 'tweaked':
            real_translation_out[ind] = sentence
        elif args.bleu == '13a' or args.bleu == 'intl':
            real_translation_out[ind] = detokenizer(_bpe_to_words(sentence),
                                                    return_str=True)
        else:
            raise NotImplementedError
    return avg_loss, real_translation_out
def inference():
    """inference function."""
    logging.info('Inference on test_dataset!')

    # Prepare the test data loader
    test_data_loader = dataprocessor.get_dataloader(data_test, args,
                                                    dataset_type='test',
                                                    use_average_length=True)

    if args.bleu == 'tweaked':
        bpe = bool(args.dataset != 'IWSLT2015' and args.dataset != 'TOY')
        split_compound_word = bpe
        tokenized = True
    elif args.bleu == '13a' or args.bleu == 'intl':
        bpe = False
        split_compound_word = False
        tokenized = False
    else:
        raise NotImplementedError

    translation_out = []
    all_inst_ids = []
    total_wc = 0
    total_time = 0
    batch_total_bleu = 0

    for batch_id, (src_seq, tgt_seq, src_test_length, tgt_test_length, inst_ids) \
            in enumerate(test_data_loader):

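        # Count source and target words so throughput (words per second) can be reported.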
        total_wc += src_test_length.sum().asscalar() + tgt_test_length.sum().asscalar()

        src_seq = src_seq.as_in_context(ctx[0])
        tgt_seq = tgt_seq.as_in_context(ctx[0])
        src_test_length = src_test_length.as_in_context(ctx[0])
        tgt_test_length = tgt_test_length.as_in_context(ctx[0])
        all_inst_ids.extend(inst_ids.asnumpy().astype(np.int32).tolist())

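        # Time only the beam-search translation so the reported throughput
        # reflects decoding speed.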
        start = time.time()
        # Translate to get a bleu score
        samples, _, sample_test_length = \
            translator.translate(src_seq=src_seq, src_valid_length=src_test_length)
        total_time += (time.time() - start)

        # Generate the translation result for each batch
        max_score_sample = samples[:, 0, :].asnumpy()
        sample_test_length = sample_test_length[:, 0].asnumpy()
        translation_tmp = []
        translation_tmp_sentences = []
        for i in range(max_score_sample.shape[0]):
            translation_tmp.append([tgt_vocab.idx_to_token[ele] for ele in \
                                    max_score_sample[i][1:(sample_test_length[i] - 1)]])

        # Detokenize each translation result
        for sentence in translation_tmp:
            if args.bleu == 'tweaked':
                translation_tmp_sentences.append(sentence)
                translation_out.append(sentence)
            elif args.bleu == '13a' or args.bleu == 'intl':
                detokenized = detokenizer(_bpe_to_words(sentence))
                translation_tmp_sentences.append(detokenized)
                translation_out.append(detokenized)
            else:
                raise NotImplementedError

        # Gather the reference target sentences for this batch's BLEU calculation
        tgt_sen_tmp = [test_tgt_sentences[index]
                       for index in inst_ids.asnumpy().astype(np.int32).tolist()]
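        # The per-batch BLEU below is only a coarse progress signal; the
        # corpus-level score computed after the loop is the authoritative one.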
        batch_test_bleu_score, _, _, _, _ = compute_bleu([tgt_sen_tmp], translation_tmp_sentences,
                                                         tokenized=tokenized, tokenizer=args.bleu,
                                                         split_compound_word=split_compound_word,
                                                         bpe=bpe)
        batch_total_bleu += batch_test_bleu_score

        # Log every ten batches
        if batch_id % 10 == 0 and batch_id != 0:
            batch_ave_bleu = batch_total_bleu / 10
            batch_total_bleu = 0
            logging.info('batch id={:d}, batch_bleu={:.4f}'
                         .format(batch_id, batch_ave_bleu * 100))

    # Reorder translation sentences by inst_ids
    real_translation_out = [None for _ in range(len(all_inst_ids))]
    for ind, sentence in zip(all_inst_ids, translation_out):
        real_translation_out[ind] = sentence

    # Get BLEU score, n-gram precisions, brevity penalty, reference length, and translation length
    test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences], real_translation_out,
                                               tokenized=tokenized, tokenizer=args.bleu,
                                               split_compound_word=split_compound_word,
                                               bpe=bpe)

    logging.info('Inference on test dataset. inference bleu={:.4f}, '
                 'throughput={:.4f}K wps'
                 .format(test_bleu_score * 100, total_wc / total_time / 1000))
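
Example usage (a minimal sketch): both snippets above rely on module-level objects
(model, translator, tgt_vocab, detokenizer, ctx, args, dataprocessor, test_tgt_sentences
and the data loaders) that are built earlier in the script. val_data_loader below is a
hypothetical name for a loader constructed the same way as test_data_loader is inside
inference().

if __name__ == '__main__':
    # evaluate() returns the average per-token loss and the reordered translations.
    val_loss, val_translations = evaluate(val_data_loader, ctx[0])
    logging.info('Validation loss={:.4f}, ppl={:.4f}'.format(val_loss, np.exp(val_loss)))
    # inference() logs per-batch BLEU, the corpus BLEU and the decoding throughput.
    inference()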