def test_beam_search_and_decode_generate(self):
    """
    A basic test that the BeamSearchAndDecode class produces the same
    output as SequenceGenerator.
    """
    test_args = test_utils.ModelParamsDict(arch="rnn")
    test_args.sequence_lstm = True

    BEAM_SIZE = 1
    WORD_REWARD = 1
    UNK_REWARD = -1
    LENGTH_PENALTY = 0

    PLACEHOLDER_SEQ_LENGTH = 5
    NBEST = 2
    MAX_SEQ_LEN = 7

    src_tokens = torch.LongTensor([[0, 0, 0]])
    src_lengths = torch.LongTensor([3])

    # Build the model list
    samples, src_dict, tgt_dict = test_utils.prepare_inputs(test_args)
    task = tasks.DictionaryHolderTask(src_dict, tgt_dict)
    models = task.build_model(test_args)

    # Placeholder inputs used to construct BeamSearchAndDecode
    placeholder_src_tokens = torch.LongTensor(
        np.ones((PLACEHOLDER_SEQ_LENGTH, 1), dtype="int64")
    )
    placeholder_src_lengths = torch.IntTensor(
        np.array([PLACEHOLDER_SEQ_LENGTH], dtype="int32")
    )
    prev_token = torch.LongTensor([tgt_dict.eos()])
    prev_scores = torch.FloatTensor([0.0])
    attn_weights = torch.zeros(src_lengths[0].item())
    prev_hypos_indices = torch.zeros(BEAM_SIZE, dtype=torch.int64)
    num_steps = torch.LongTensor([MAX_SEQ_LEN])

    # Generate output using SequenceGenerator
    translator = SequenceGenerator(
        [models],
        task.target_dictionary,
        beam_size=BEAM_SIZE,
        word_reward=WORD_REWARD,
        unk_reward=UNK_REWARD,
    )
    encoder_input = {"src_tokens": src_tokens, "src_lengths": src_lengths}
    top_seq_gen_hypothesis = translator.generate(
        encoder_input, beam_size=BEAM_SIZE, maxlen=MAX_SEQ_LEN
    )[0]

    # Generate output using the BeamSearchAndDecode class
    beam_search_and_decode = BeamSearchAndDecode(
        [models],
        tgt_dict=tgt_dict,
        src_tokens=placeholder_src_tokens,
        src_lengths=placeholder_src_lengths,
        eos_token_id=tgt_dict.eos(),
        length_penalty=LENGTH_PENALTY,
        nbest=NBEST,
        beam_size=BEAM_SIZE,
        stop_at_eos=True,
        word_reward=WORD_REWARD,
        unk_reward=UNK_REWARD,
        quantize=True,
    )
    beam_search_and_decode_output = beam_search_and_decode(
        src_tokens.transpose(0, 1),
        src_lengths,
        prev_token,
        prev_scores,
        attn_weights,
        prev_hypos_indices,
        num_steps[0],
    )

    for hyp_index in range(
        min(len(beam_search_and_decode_output), len(top_seq_gen_hypothesis))
    ):
        beam_search_and_decode_hypothesis = beam_search_and_decode_output[hyp_index]

        # Compare the two outputs. We only look at positions 0 to MAX_SEQ_LEN,
        # because the sequence generator appends an EOS after MAX_SEQ_LEN.

        # Compare the two hypotheses' tokens
        np.testing.assert_array_equal(
            top_seq_gen_hypothesis[hyp_index]["tokens"].tolist()[0:MAX_SEQ_LEN],
            beam_search_and_decode_hypothesis[0].tolist()[0:MAX_SEQ_LEN],
        )
        # Compare token-level scores
        np.testing.assert_array_almost_equal(
            top_seq_gen_hypothesis[hyp_index]["positional_scores"].tolist()[
                0:MAX_SEQ_LEN
            ],
            beam_search_and_decode_hypothesis[2][0:MAX_SEQ_LEN],
            decimal=1,
        )
        # Compare attention weights
        np.testing.assert_array_almost_equal(
            top_seq_gen_hypothesis[hyp_index]["attention"].numpy()[:, 0:MAX_SEQ_LEN],
            beam_search_and_decode_hypothesis[3].numpy()[:, 0:MAX_SEQ_LEN],
            decimal=1,
        )
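# A minimal sketch (illustrative, not part of the test suite) of the
# per-hypothesis layout the assertions above rely on: position 0 holds the
# output tokens, position 2 the token-level scores, and position 3 the
# attention weights. The field names and position 1 (presumably the
# aggregate hypothesis score) are assumptions not checked by this test.
from typing import NamedTuple

class DecodedHypothesisSketch(NamedTuple):  # hypothetical helper
    tokens: torch.Tensor        # hypothesis[0]: generated token ids
    score: torch.Tensor         # hypothesis[1]: assumed aggregate score (unchecked here)
    token_scores: torch.Tensor  # hypothesis[2]: per-token scores
    attention: torch.Tensor     # hypothesis[3]: attention weights over the source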
def _test_full_beam_search_decoder(self, test_args, quantize=False):
    """
    Build a BeamSearchAndDecode module from an ensemble of models, serialize
    it, reload it, and check that the original and deserialized modules
    produce the same hypotheses on an input whose sequence length differs
    from the one used at construction time.
    """
    samples, src_dict, tgt_dict = test_utils.prepare_inputs(test_args)
    task = tasks.DictionaryHolderTask(src_dict, tgt_dict)
    sample = next(samples)
    # [seq len, batch size=1]
    src_tokens = sample["net_input"]["src_tokens"][0:1].t()
    # [seq len]
    src_lengths = sample["net_input"]["src_lengths"][0:1].long()

    # Build an ensemble of models
    num_models = 3
    model_list = []
    for _ in range(num_models):
        model_list.append(task.build_model(test_args))

    eos_token_id = 8
    length_penalty = 0.25
    nbest = 3
    stop_at_eos = True
    num_steps = torch.LongTensor([20])
    beam_size = 6

    bsd = BeamSearchAndDecode(
        model_list,
        tgt_dict,
        src_tokens,
        src_lengths,
        eos_token_id=eos_token_id,
        length_penalty=length_penalty,
        nbest=nbest,
        beam_size=beam_size,
        stop_at_eos=stop_at_eos,
        quantize=quantize,
    )
    f = io.BytesIO()
    bsd.save_to_pytorch(f)

    # Test generalization with a different sequence length
    src_tokens = torch.LongTensor([1, 2, 3, 4, 5, 6, 7, 9, 9, 10, 11]).unsqueeze(1)
    src_lengths = torch.LongTensor([11])
    prev_token = torch.LongTensor([0])
    prev_scores = torch.FloatTensor([0.0])
    attn_weights = torch.zeros(src_tokens.shape[0])
    prev_hypos_indices = torch.zeros(beam_size, dtype=torch.int64)

    outs = bsd(
        src_tokens,
        src_lengths,
        prev_token,
        prev_scores,
        attn_weights,
        prev_hypos_indices,
        num_steps[0],
    )

    # Reload the serialized module and unpack any packed (quantized) weights
    f.seek(0)
    deserialized_bsd = torch.jit.load(f)
    deserialized_bsd.apply(lambda s: s._unpack() if hasattr(s, "_unpack") else None)
    outs_deserialized = deserialized_bsd(
        src_tokens,
        src_lengths,
        prev_token,
        prev_scores,
        attn_weights,
        prev_hypos_indices,
        num_steps[0],
    )

    # The original and deserialized modules should agree on tokens,
    # token-level scores, and attention weights for every hypothesis
    for hypo, hypo_deserialized in zip(outs, outs_deserialized):
        np.testing.assert_array_equal(
            hypo[0].tolist(), hypo_deserialized[0].tolist()
        )
        np.testing.assert_array_almost_equal(
            hypo[2], hypo_deserialized[2], decimal=1
        )
        np.testing.assert_array_almost_equal(
            hypo[3].numpy(), hypo_deserialized[3].numpy(), decimal=1
        )
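# Hypothetical callers for the helper above, sketched to show the intended
# usage; the concrete method names and arch settings are assumptions and may
# differ from the rest of the test file.
def test_full_beam_search_decoder(self):
    test_args = test_utils.ModelParamsDict(arch="rnn")
    test_args.sequence_lstm = True
    self._test_full_beam_search_decoder(test_args)

def test_full_beam_search_decoder_quantized(self):
    test_args = test_utils.ModelParamsDict(arch="rnn")
    test_args.sequence_lstm = True
    # quantize=True exercises the weight packing/unpacking path
    self._test_full_beam_search_decoder(test_args, quantize=True)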