コード例 #1
0
ファイル: run.py プロジェクト: nguyenvdat/CS224
def beam_search(model: NMT, test_data_src: List[List[str]], beam_size: int,
                max_decoding_time_step: int) -> List[List[Hypothesis]]:
    """ Run beam search to construct hypotheses for a list of src-language sentences.

    The model is switched to evaluation mode and decoding runs with gradient
    tracking disabled. If the model was in training mode on entry, training mode
    is switched back on before this function exits, including when decoding raises.

    @param model (NMT): NMT Model
    @param test_data_src (List[List[str]]): List of sentences (words) in source language, from test set.
    @param beam_size (int): beam_size (# of hypotheses to hold for a translation at every step)
    @param max_decoding_time_step (int): maximum sentence length that Beam search can produce
    @returns hypotheses (List[List[Hypothesis]]): List of Hypothesis translations for every source sentence,
        in the same order as test_data_src.
    """
    was_training = model.training
    model.eval()

    hypotheses = []
    try:
        with torch.no_grad():
            for src_sent in tqdm(test_data_src, desc='Decoding', file=sys.stdout):
                example_hyps = model.beam_search(
                    src_sent,
                    beam_size=beam_size,
                    max_decoding_time_step=max_decoding_time_step)

                hypotheses.append(example_hyps)
    finally:
        # Put the model back in training mode even if decoding failed part-way,
        # so an exception here cannot leave the caller's model stuck in eval mode.
        if was_training:
            model.train(was_training)

    return hypotheses
コード例 #2
0
def beam_search(model: NMT, test_iterator: BucketIterator, beam_size: int,
                max_decoding_time_step: int) -> List[List[Hypothesis]]:
    """ Run beam search to construct hypotheses for every src-language sentence in an iterator.

    The model is switched to evaluation mode and decoding runs with gradient
    tracking disabled. If the model was in training mode on entry, training mode
    is switched back on before this function exits, including when decoding raises.

    @param model (NMT): NMT Model
    @param test_iterator (BucketIterator): iterable of batches from the test set; each batch's
        `src` attribute is unpacked as a (sentences, lengths) pair.
    @param beam_size (int): beam_size (# of hypotheses to hold for a translation at every step)
    @param max_decoding_time_step (int): maximum sentence length that Beam search can produce
    @returns hypotheses (List[List[Hypothesis]]): List of Hypothesis translations for every source sentence,
        in the order the sentences are produced by the iterator.
    """
    was_training = model.training
    model.eval()

    hypotheses = []
    try:
        with torch.no_grad():
            for batch in test_iterator:
                src_sents, src_sents_lens = batch.src
                # Swap the two axes so that indexing the first axis selects one sentence
                # (assumes batch.src arrives as (src_len, batch_size) -- TODO confirm).
                src_sents = src_sents.permute(1, 0)
                for j, src_sent_len in enumerate(src_sents_lens):
                    example_hyps = model.beam_search(
                        src_sents[j],
                        src_sent_len,
                        beam_size=beam_size,
                        max_decoding_time_step=max_decoding_time_step)
                    hypotheses.append(example_hyps)
    finally:
        # Put the model back in training mode even if decoding failed part-way,
        # so an exception here cannot leave the caller's model stuck in eval mode.
        if was_training:
            model.train(was_training)

    return hypotheses
コード例 #3
0
def beam_search2(model1: NMT, model2: DPPNMT, test_data_src: List[List[str]],
                 beam_size: int, max_decoding_time_step: int,
                 test_data_tgt) -> None:
    """ Decode every source sentence with two models and print the cases where model2 scores higher.

    For each source sentence, both models run beam search and the top hypothesis of each
    is scored with sentence_bleu against the same reference. When model2's score is strictly
    higher, the sentence index, model1's top hypothesis, model2's top hypothesis and the
    reference are printed, one per line. Both models are left in evaluation mode.

    @param model1 (NMT): baseline NMT model
    @param model2 (DPPNMT): model whose output is compared against model1
    @param test_data_src (List[List[str]]): List of sentences (words) in source language, from test set.
    @param beam_size (int): beam_size (# of hypotheses to hold for a translation at every step)
    @param max_decoding_time_step (int): maximum sentence length that Beam search can produce
    @param test_data_tgt: reference sentences, indexed in parallel with test_data_src
    @returns None: results are only printed.
    """
    model1.eval()
    model2.eval()

    with torch.no_grad():
        for i, src_sent in enumerate(tqdm(test_data_src, desc='Decoding', file=sys.stdout)):
            hyp1 = model1.beam_search(
                src_sent,
                beam_size=beam_size,
                max_decoding_time_step=max_decoding_time_step)
            hyp2 = model2.beam_search(
                src_sent,
                beam_size=beam_size,
                max_decoding_time_step=max_decoding_time_step)
            # Drop the first and last token of the reference
            # (presumably sentence-boundary markers -- verify against the data loader).
            ref = test_data_tgt[i][1:-1]
            # Score both hypotheses against the SAME reference; using the sliced reference
            # for one model and the unsliced one for the other would bias the comparison.
            # NOTE(review): if sentence_bleu is NLTK's, its first argument is a list of
            # reference token lists, i.e. this may need to be [ref] -- confirm against the import.
            bleu_topk = sentence_bleu(ref, hyp1[0].value)
            bleu_dpp = sentence_bleu(ref, hyp2[0].value)
            if bleu_dpp > bleu_topk:
                print(i)
                print(" ".join(hyp1[0].value))
                print(" ".join(hyp2[0].value))
                print(" ".join(ref))
コード例 #4
0
def beam_search(model: NMT, test_data_src: List[List[str]], beam_size: int, max_decoding_time_step: int)\
        -> List[List[Hypothesis]]:
    """ Run beam search to construct hypotheses for a list of src-language sentences.

    The model is switched to evaluation mode and decoding runs with gradient
    tracking disabled. If the model was in training mode on entry, training mode
    is switched back on before this function exits, including when decoding raises.

    :param NMT model: NMT Model
    :param List[List[str]] test_data_src: List of sentences (words) in source language, from test set
    :param int beam_size: beam_size (number of hypotheses to keep for a translation at every step)
    :param int max_decoding_time_step: maximum sentence length that beam search can produce
    :returns List[List[Hypothesis]] hypotheses: List of Hypothesis translations for every source sentence,
        in the same order as test_data_src
    """
    was_training = model.training
    model.eval()

    hypotheses = []
    try:
        with torch.no_grad():
            for src_sent in tqdm(test_data_src, desc='Decoding', file=sys.stdout):
                hypotheses.append(
                    model.beam_search(
                        src_sent,
                        beam_size=beam_size,
                        max_decoding_time_step=max_decoding_time_step))
    finally:
        # Restore training mode on every exit path; without this an exception during
        # decoding would leave the caller's model in eval mode.
        if was_training:
            model.train(was_training)

    return hypotheses