Example #1
0
def test_transducer_beam_search(rnn_type, search_params):
    """Smoke-test BeamSearchTransducer decoding on random encoder output.

    Args:
        rnn_type: RNN cell type forwarded to the transducer decoder.
        search_params: Search configuration; must contain "search_type"
            and may optionally carry an "lm" override.
    """
    token_list = ["<blank>", "a", "b", "c", "<sos>"]
    vocab_size = len(token_list)
    # Greedy search is equivalent to beam size 1; otherwise use a small beam.
    beam_size = 1 if search_params["search_type"] == "greedy" else 2

    encoder_output_size = 4
    decoder_output_size = 4

    decoder = TransducerDecoder(vocab_size,
                                hidden_size=decoder_output_size,
                                rnn_type=rnn_type)
    joint_net = JointNetwork(vocab_size,
                             encoder_output_size,
                             decoder_output_size,
                             joint_space_size=2)

    # Build the default LM lazily: the original code eagerly constructed a
    # SequentialRNNLM as the `pop` default even when search_params already
    # supplied its own "lm", wasting a model construction.
    if "lm" in search_params:
        lm = search_params.pop("lm")
    else:
        lm = SequentialRNNLM(vocab_size, rnn_type="lstm")
    # The string sentinel "TransformerLM" requests a Transformer LM instead.
    if isinstance(lm, str) and lm == "TransformerLM":
        lm = TransformerLM(vocab_size, pos_enc=None, unit=10, layer=2)

    beam = BeamSearchTransducer(
        decoder,
        joint_net,
        beam_size=beam_size,
        lm=lm,
        token_list=token_list,
        **search_params,
    )

    # 30 random encoder frames; only checking that decoding runs cleanly.
    enc_out = torch.randn(30, encoder_output_size)

    with torch.no_grad():
        _ = beam(enc_out)
Example #2
0
def test_TransformerLM_backward(pos_enc):
    """Check that TransformerLM supports a two-step forward and backward pass.

    Args:
        pos_enc: Positional encoding type forwarded to TransformerLM.
    """
    model = TransformerLM(10, pos_enc=pos_enc, unit=10)
    # Renamed from `input`, which shadowed the builtin of the same name.
    tokens = torch.randint(0, 9, [2, 5])

    # Run twice, threading the hidden state through, then backprop the sum.
    out, hidden = model(tokens, None)
    out, hidden = model(tokens, hidden)
    out.sum().backward()
Example #3
0
def test_TransformerLM_batch_beam_search(pos_enc, dtype):
    """Run BatchBeamSearch with a TransformerLM scorer and check it completes.

    Args:
        pos_enc: Positional encoding type forwarded to TransformerLM.
        dtype: Torch dtype used for the searcher and the encoder features.
    """
    token_list = ["<blank>", "a", "b", "c", "unk", "<eos>"]
    n_vocab = len(token_list)
    # The last token doubles as both start- and end-of-sequence symbol.
    eos_id = n_vocab - 1

    lm = TransformerLM(n_vocab, pos_enc=pos_enc, unit=10)
    searcher = BatchBeamSearch(
        beam_size=3,
        vocab_size=n_vocab,
        weights={"test": 1.0},
        scorers={"test": lm},
        token_list=token_list,
        sos=eos_id,
        eos=eos_id,
        pre_beam_score_key=None,
    )
    searcher.to(dtype=dtype)

    # Random 10-frame, 20-dim encoder output in the requested dtype.
    feats = torch.randn(10, 20).type(dtype)
    with torch.no_grad():
        searcher(
            x=feats,
            maxlenratio=0.0,
            minlenratio=0.0,
        )
Example #4
0
def test_TransformerLM_invalid_type():
    """TransformerLM must reject an unknown positional-encoding name."""
    # Callable form of pytest.raises — equivalent to the context-manager form.
    pytest.raises(ValueError, TransformerLM, 10, pos_enc="fooo")
Example #5
0
def test_TransformerLM_score(pos_enc):
    """Check that TransformerLM.score runs on a fresh initial state.

    Args:
        pos_enc: Positional encoding type forwarded to TransformerLM.
    """
    model = TransformerLM(10, pos_enc=pos_enc, unit=10)
    # Renamed from `input`, which shadowed the builtin of the same name.
    tokens = torch.randint(0, 9, (12, ))
    state = model.init_state(None)
    model.score(tokens, state, None)