# `rnn_type` and `search_params` are supplied via pytest parametrization;
# the original decorators are not shown in this excerpt.
def test_transducer_beam_search(rnn_type, search_params):
    token_list = ["<blank>", "a", "b", "c", "<sos>"]
    vocab_size = len(token_list)
    beam_size = 1 if search_params["search_type"] == "greedy" else 2

    encoder_output_size = 4
    decoder_output_size = 4

    decoder = TransducerDecoder(
        vocab_size, hidden_size=decoder_output_size, rnn_type=rnn_type
    )
    joint_net = JointNetwork(
        vocab_size, encoder_output_size, decoder_output_size, joint_space_size=2
    )

    # Default to an RNN LM; a "TransformerLM" marker in search_params
    # swaps in a Transformer LM instead.
    lm = search_params.pop("lm", SequentialRNNLM(vocab_size, rnn_type="lstm"))
    if isinstance(lm, str) and lm == "TransformerLM":
        lm = TransformerLM(vocab_size, pos_enc=None, unit=10, layer=2)

    beam = BeamSearchTransducer(
        decoder,
        joint_net,
        beam_size=beam_size,
        lm=lm,
        token_list=token_list,
        **search_params,
    )

    enc_out = torch.randn(30, encoder_output_size)
    with torch.no_grad():
        _ = beam(enc_out)
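# Assumed parametrization for the test below (the original decorator is not
# in this excerpt). TransformerLM accepts "sinusoidal" or None for
# `pos_enc`; the concrete sweep values here are illustrative assumptions.
@pytest.mark.parametrize("pos_enc", ["sinusoidal", None])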
def test_TransformerLM_backward(pos_enc):
    model = TransformerLM(10, pos_enc=pos_enc, unit=10)
    input = torch.randint(0, 9, [2, 5])
    # Run two forward steps, carrying the hidden state, then backprop.
    out, h = model(input, None)
    out, h = model(input, h)
    out.sum().backward()
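# Assumed parametrization (not in the excerpt): sweep the positional-encoding
# type and the compute dtype. The dtype values are illustrative assumptions,
# not the original ones.
@pytest.mark.parametrize("pos_enc", ["sinusoidal", None])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])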
def test_TransformerLM_batch_beam_search(pos_enc, dtype):
    token_list = ["<blank>", "a", "b", "c", "unk", "<eos>"]
    vocab_size = len(token_list)
    model = TransformerLM(vocab_size, pos_enc=pos_enc, unit=10)
    beam = BatchBeamSearch(
        beam_size=3,
        vocab_size=vocab_size,
        weights={"test": 1.0},
        scorers={"test": model},
        token_list=token_list,
        sos=vocab_size - 1,
        eos=vocab_size - 1,
        pre_beam_score_key=None,
    )
    beam.to(dtype=dtype)
    enc = torch.randn(10, 20).type(dtype)
    with torch.no_grad():
        beam(
            x=enc,
            maxlenratio=0.0,
            minlenratio=0.0,
        )
def test_TransformerLM_invalid_type():
    # An unknown positional-encoding name should be rejected.
    with pytest.raises(ValueError):
        TransformerLM(10, pos_enc="fooo")
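# Assumed parametrization (not in the excerpt); values are illustrative.
@pytest.mark.parametrize("pos_enc", ["sinusoidal", None])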
def test_TransformerLM_score(pos_enc):
    model = TransformerLM(10, pos_enc=pos_enc, unit=10)
    input = torch.randint(0, 9, (12,))
    state = model.init_state(None)
    model.score(input, state, None)