Example #1
# Assumed imports for these ESPnet test snippets; the function arguments are
# supplied by @pytest.mark.parametrize in the original test file.
import torch

from espnet.nets.batch_beam_search import BatchBeamSearch


def test_TransformerDecoder_batch_beam_search(input_layer, normalize_before,
                                              use_output_layer, dtype,
                                              decoder_class):
    token_list = ["<blank>", "a", "b", "c", "unk", "<eos>"]
    vocab_size = len(token_list)
    encoder_output_size = 4

    decoder = decoder_class(
        vocab_size=vocab_size,
        encoder_output_size=encoder_output_size,
        input_layer=input_layer,
        normalize_before=normalize_before,
        use_output_layer=use_output_layer,
        linear_units=10,
    )
    beam = BatchBeamSearch(
        beam_size=3,
        vocab_size=vocab_size,
        weights={"test": 1.0},
        scorers={"test": decoder},
        token_list=token_list,
        sos=vocab_size - 1,
        eos=vocab_size - 1,
        pre_beam_score_key=None,
    )
    beam.to(dtype=dtype)

    enc = torch.randn(10, encoder_output_size).type(dtype)
    with torch.no_grad():
        beam(
            x=enc,
            maxlenratio=0.0,
            minlenratio=0.0,
        )
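
For reference, the parametrized test above can be driven directly. Every
argument value below is an illustrative assumption; the original suite sweeps
several combinations via @pytest.mark.parametrize:

import torch
from espnet2.asr.decoder.transformer_decoder import TransformerDecoder

# A minimal standalone sketch; all fixture values here are assumptions.
test_TransformerDecoder_batch_beam_search(
    input_layer="embed",               # assumed fixture value
    normalize_before=True,             # assumed fixture value
    use_output_layer=True,             # assumed fixture value
    dtype=torch.float32,               # assumed fixture value
    decoder_class=TransformerDecoder,  # one decoder class the suite covers
)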
Example #2
# Assumed imports; SequentialRNNLM lives in espnet2.lm.seq_rnn_lm, and the
# arguments are pytest parameters.
import torch

from espnet.nets.batch_beam_search import BatchBeamSearch
from espnet2.lm.seq_rnn_lm import SequentialRNNLM


def test_SequentialRNNLM_batch_beam_search(rnn_type, tie_weights, dtype):
    token_list = ["<blank>", "a", "b", "c", "unk", "<eos>"]
    vocab_size = len(token_list)

    model = SequentialRNNLM(
        vocab_size, nlayers=2, rnn_type=rnn_type, tie_weights=tie_weights
    )
    beam = BatchBeamSearch(
        beam_size=3,
        vocab_size=vocab_size,
        weights={"test": 1.0},
        scorers={"test": model},
        token_list=token_list,
        sos=vocab_size - 1,
        eos=vocab_size - 1,
        pre_beam_score_key=None,
    )
    beam.to(dtype=dtype)

    enc = torch.randn(10, 20).type(dtype)
    with torch.no_grad():
        beam(
            x=enc,
            maxlenratio=0.0,
            minlenratio=0.0,
        )
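
Note that the encoder output here is a dummy: the language model does not read
the feature contents, and with maxlenratio=0.0 the search takes its maximum
hypothesis length from the number of encoder frames (10 here). A standalone
invocation might look like this, with assumed fixture values:

import torch

test_SequentialRNNLM_batch_beam_search(
    rnn_type="lstm",      # assumed fixture value
    tie_weights=False,    # assumed fixture value
    dtype=torch.float32,  # assumed fixture value
)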
Example #3
# Assumed imports; Hypothesis is defined in espnet.nets.beam_search and
# LengthBonus in espnet.nets.scorers.length_bonus.
import torch

from espnet.nets.batch_beam_search import BatchBeamSearch
from espnet.nets.beam_search import Hypothesis
from espnet.nets.scorers.length_bonus import LengthBonus


def test_batchfy_hyp():
    vocab_size = 5
    eos = -1
    # simplest beam search
    beam = BatchBeamSearch(
        beam_size=3,
        vocab_size=vocab_size,
        weights={
            "a": 0.5,
            "b": 0.5
        },
        scorers={
            "a": LengthBonus(vocab_size),
            "b": LengthBonus(vocab_size)
        },
        pre_beam_score_key="a",
        sos=eos,
        eos=eos,
    )
    hs = [
        Hypothesis(
            yseq=torch.tensor([0, 1, 2]),
            score=torch.tensor(0.15),
            scores={
                "a": torch.tensor(0.1),
                "b": torch.tensor(0.2)
            },
            states={
                "a": 1,
                "b": 2
            },
        ),
        Hypothesis(
            yseq=torch.tensor([0, 1]),
            score=torch.tensor(0.1),
            scores={
                "a": torch.tensor(0.0),
                "b": torch.tensor(0.2)
            },
            states={
                "a": 3,
                "b": 4
            },
        ),
    ]
    bs = beam.batchfy(hs)
    assert torch.all(bs.yseq == torch.tensor([[0, 1, 2], [0, 1, eos]]))
    assert torch.all(bs.score == torch.tensor([0.15, 0.1]))
    assert torch.all(bs.scores["a"] == torch.tensor([0.1, 0.0]))
    assert torch.all(bs.scores["b"] == torch.tensor([0.2, 0.2]))
    assert bs.states["a"] == [1, 3]
    assert bs.states["b"] == [2, 4]

    us = beam.unbatchfy(bs)
    for i in range(len(hs)):
        assert us[i].yseq.tolist() == hs[i].yseq.tolist()
        assert us[i].score == hs[i].score
        assert us[i].scores == hs[i].scores
        assert us[i].states == hs[i].states
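
The assertions above pin down the batching contract: yseq tensors of unequal
length are padded with eos (hence [0, 1, eos] for the shorter hypothesis),
scalar scores are stacked into batch tensors, and per-scorer states are
collected into lists; unbatchfy must invert all of this. A minimal sketch of
just the padding step in plain torch (mirroring the asserted behavior, not
BatchBeamSearch's actual implementation):

import torch
from torch.nn.utils.rnn import pad_sequence

eos = -1
yseqs = [torch.tensor([0, 1, 2]), torch.tensor([0, 1])]
padded = pad_sequence(yseqs, batch_first=True, padding_value=eos)
print(padded.tolist())  # [[0, 1, 2], [0, 1, -1]]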
Example #4
# Assumed imports; TransformerLM lives in espnet2.lm.transformer_lm, and the
# arguments are pytest parameters.
import torch

from espnet.nets.batch_beam_search import BatchBeamSearch
from espnet2.lm.transformer_lm import TransformerLM


def test_TransformerLM_batch_beam_search(pos_enc, dtype):
    token_list = ["<blank>", "a", "b", "c", "unk", "<eos>"]
    vocab_size = len(token_list)

    model = TransformerLM(vocab_size, pos_enc=pos_enc, unit=10)
    beam = BatchBeamSearch(
        beam_size=3,
        vocab_size=vocab_size,
        weights={"test": 1.0},
        scorers={"test": model},
        token_list=token_list,
        sos=vocab_size - 1,
        eos=vocab_size - 1,
        pre_beam_score_key=None,
    )
    beam.to(dtype=dtype)

    enc = torch.randn(10, 20).type(dtype)
    with torch.no_grad():
        beam(
            x=enc,
            maxlenratio=0.0,
            minlenratio=0.0,
        )
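
Both LM examples work because the models implement ESPnet's batch scorer
interface, so any object with a compatible batch_score can be passed in
scorers. The class below is a purely illustrative sketch of that interface,
not part of the test suite:

import math

import torch

from espnet.nets.scorer_interface import BatchScorerInterface


class UniformScorer(BatchScorerInterface):
    """Illustrative scorer that gives every token equal log-probability."""

    def __init__(self, vocab_size):
        self.vocab_size = vocab_size

    def batch_score(self, ys, states, xs):
        # Return (n_batch, vocab_size) log-probabilities; states pass through.
        scores = torch.full((ys.size(0), self.vocab_size),
                            -math.log(self.vocab_size))
        return scores, states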
Example #5
# Assumed imports; `prepare` is a helper defined in the original test module,
# and the arguments are pytest parameters.
import os
from argparse import Namespace

import numpy
import pytest
import torch

from espnet.nets.batch_beam_search import BatchBeamSearch
from espnet.nets.beam_search import BeamSearch
from espnet.nets.lm_interface import dynamic_import_lm
from espnet.nets.scorers.length_bonus import LengthBonus
from espnet.nets.scorers.ngram import NgramFullScorer


def test_batch_beam_search_equal(
    model_class,
    args,
    ctc_weight,
    lm_nn,
    lm_args,
    lm_weight,
    ngram_weight,
    bonus,
    device,
    dtype,
):
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("no cuda device is available")
    if device == "cpu" and dtype == "float16":
        pytest.skip(
            "cpu float16 implementation is not available in pytorch yet")

    # seed setting
    torch.manual_seed(123)
    torch.backends.cudnn.deterministic = True
    # https://github.com/pytorch/pytorch/issues/6351
    torch.backends.cudnn.benchmark = False

    dtype = getattr(torch, dtype)
    model, x, ilens, y, data, train_args = prepare(model_class,
                                                   args,
                                                   mtlalpha=ctc_weight)
    model.eval()
    char_list = train_args.char_list
    lm = dynamic_import_lm(lm_nn, backend="pytorch")(len(char_list), lm_args)
    lm.eval()
    root = os.path.dirname(os.path.abspath(__file__))
    ngram = NgramFullScorer(os.path.join(root, "beam_search_test.arpa"),
                            args.char_list)

    # decoding options (shared by the legacy and the batched search)
    args = Namespace(
        beam_size=3,
        penalty=bonus,
        ctc_weight=ctc_weight,
        maxlenratio=0,
        lm_weight=lm_weight,
        ngram_weight=ngram_weight,
        minlenratio=0,
        nbest=5,
    )

    # scorers and weights used by both searches
    scorers = model.scorers()
    if lm_weight != 0:
        scorers["lm"] = lm
    if ngram_weight != 0:
        scorers["ngram"] = ngram
    scorers["length_bonus"] = LengthBonus(len(char_list))
    weights = dict(
        decoder=1.0 - ctc_weight,
        ctc=ctc_weight,
        lm=args.lm_weight,
        ngram=args.ngram_weight,
        length_bonus=args.penalty,
    )
    model.to(device, dtype=dtype)
    model.eval()
    with torch.no_grad():
        enc = model.encode(x[0, :ilens[0]].to(device, dtype=dtype))

    legacy_beam = BeamSearch(
        beam_size=args.beam_size,
        vocab_size=len(char_list),
        weights=weights,
        scorers=scorers,
        token_list=train_args.char_list,
        sos=model.sos,
        eos=model.eos,
        pre_beam_score_key=None if ctc_weight == 1.0 else "decoder",
    )
    legacy_beam.to(device, dtype=dtype)
    legacy_beam.eval()

    beam = BatchBeamSearch(
        beam_size=args.beam_size,
        vocab_size=len(char_list),
        weights=weights,
        scorers=scorers,
        token_list=train_args.char_list,
        sos=model.sos,
        eos=model.eos,
    )
    beam.to(device, dtype=dtype)
    beam.eval()
    with torch.no_grad():
        legacy_nbest_bs = legacy_beam(x=enc,
                                      maxlenratio=args.maxlenratio,
                                      minlenratio=args.minlenratio)
        nbest_bs = beam(x=enc,
                        maxlenratio=args.maxlenratio,
                        minlenratio=args.minlenratio)

    for i, (expected, actual) in enumerate(zip(legacy_nbest_bs, nbest_bs)):
        assert expected.yseq.tolist() == actual.yseq.tolist()
        numpy.testing.assert_allclose(expected.score.cpu(),
                                      actual.score.cpu(),
                                      rtol=1e-6)
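
BatchBeamSearch is intended as a drop-in, batched replacement for BeamSearch,
so the test requires the n-best token sequences to match exactly and the
accumulated scores to agree within rtol=1e-6 for every tested combination of
device, dtype, and scorer weights.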