def test_batched_beam_from_checkpoints_vr(self):
    """Exercise the full ensemble export path (batched beam) using
    vocab-reduction checkpoints stored on shared storage.
    """
    base_dir = (
        '/mnt/gfsdataswarm-global/namespaces/search/language-technology-mt/'
        'nnmt_tmp/tl_XX-en_XX-pytorch-256-dim-vocab-reduction'
    )
    checkpoint_filenames = [
        os.path.join(base_dir, name)
        for name in (
            'averaged_checkpoint_best_3.pt',
            'averaged_checkpoint_best_4.pt',
            'averaged_checkpoint_best_5.pt',
        )
    ]
    src_dict_path = os.path.join(base_dir, 'dictionary-tl.txt')
    dst_dict_path = os.path.join(base_dir, 'dictionary-en.txt')

    encoder_ensemble = EncoderEnsemble.build_from_checkpoints(
        checkpoint_filenames,
        src_dict_path,
        dst_dict_path,
    )
    decoder_step_ensemble = DecoderBatchedStepEnsemble.build_from_checkpoints(
        checkpoint_filenames,
        src_dict_path,
        dst_dict_path,
        beam_size=5,
    )

    self._test_full_ensemble(
        encoder_ensemble,
        decoder_step_ensemble,
        batched_beam=True,
    )
def _test_full_ensemble_export(self, test_args):
    """Build a 3-model ensemble, export the decoder step to ONNX, and
    verify the Caffe2 backend reproduces the PyTorch outputs.
    """
    samples, src_dict, tgt_dict = test_utils.prepare_inputs(test_args)
    ensemble_size = 3
    model_list = [
        models.build_model(test_args, src_dict, tgt_dict)
        for _ in range(ensemble_size)
    ]
    encoder_ensemble = EncoderEnsemble(model_list)

    # test equivalence
    # The discrepancy in types here is a temporary expedient.
    # PyTorch indexing requires int64 while support for tracing
    # pack_padded_sequence() requires int32.
    sample = next(samples)
    src_tokens = sample["net_input"]["src_tokens"][0:1].t()
    src_lengths = sample["net_input"]["src_lengths"][0:1].int()

    pytorch_encoder_outputs = encoder_ensemble(src_tokens, src_lengths)

    decoder_step_ensemble = DecoderStepEnsemble(model_list, beam_size=5)
    work_dir = tempfile.mkdtemp()
    decoder_step_pb_path = os.path.join(work_dir, "decoder_step.pb")
    decoder_step_ensemble.onnx_export(
        decoder_step_pb_path, pytorch_encoder_outputs
    )

    # single EOS
    input_token = torch.LongTensor(
        np.array([[model_list[0].dst_dict.eos()]])
    )
    timestep = torch.LongTensor(np.array([[0]]))

    pytorch_decoder_outputs = decoder_step_ensemble(
        input_token, timestep, *pytorch_encoder_outputs
    )

    with open(decoder_step_pb_path, "r+b") as pb_file:
        onnx_model = onnx.load(pb_file)
    onnx_decoder = caffe2_backend.prepare(onnx_model)

    decoder_inputs_numpy = [input_token.numpy(), timestep.numpy()]
    decoder_inputs_numpy.extend(
        tensor.detach().numpy() for tensor in pytorch_encoder_outputs
    )

    caffe2_decoder_outputs = onnx_decoder.run(tuple(decoder_inputs_numpy))

    for idx, torch_out in enumerate(pytorch_decoder_outputs):
        np.testing.assert_allclose(
            caffe2_decoder_outputs[idx],
            torch_out.detach().numpy(),
            rtol=1e-4,
            atol=1e-6,
        )

    decoder_step_ensemble.save_to_db(
        os.path.join(work_dir, "decoder_step.predictor_export"),
        pytorch_encoder_outputs,
    )
def test_export_encoder_from_checkpoints_no_vr(self):
    """Exercise encoder-only export from checkpoints trained without
    vocab reduction.
    """
    base_dir = (
        '/mnt/gfsdataswarm-global/namespaces/search/language-technology-mt/'
        'nnmt_tmp/tl_XX-en_XX-pytorch-testing-no-vocab-reduction-2'
    )
    checkpoint_filenames = [
        os.path.join(base_dir, name)
        for name in (
            'averaged_checkpoint_best_3.pt',
            'averaged_checkpoint_best_4.pt',
        )
    ]
    encoder_ensemble = EncoderEnsemble.build_from_checkpoints(
        checkpoint_filenames,
        os.path.join(base_dir, 'dictionary-tl.txt'),
        os.path.join(base_dir, 'dictionary-en.txt'),
    )
    self._test_ensemble_encoder_object_export(encoder_ensemble)
def export(args):
    """Export encoder and/or decoder-step ensembles built from the given
    checkpoints to predictor DB files.

    Skips either export when its output file argument is the empty string.
    """
    assert_required_args_are_set(args)
    checkpoint_filenames = [arg[0] for arg in args.path]

    encoder_ensemble = EncoderEnsemble.build_from_checkpoints(
        checkpoint_filenames=checkpoint_filenames,
        src_dict_filename=args.source_vocab_file,
        dst_dict_filename=args.target_vocab_file,
    )
    if args.encoder_output_file != '':
        encoder_ensemble.save_to_db(args.encoder_output_file)

    if args.decoder_output_file != '':
        decoder_step_class = (
            DecoderBatchedStepEnsemble
            if args.batched_beam
            else DecoderStepEnsemble
        )
        decoder_step_ensemble = decoder_step_class.build_from_checkpoints(
            checkpoint_filenames=checkpoint_filenames,
            src_dict_filename=args.source_vocab_file,
            dst_dict_filename=args.target_vocab_file,
            beam_size=args.beam_size,
            word_penalty=args.word_penalty,
            unk_penalty=args.unk_penalty,
        )

        # need example encoder outputs to pass through network
        # (source length 5 is arbitrary)
        src_dict = encoder_ensemble.models[0].src_dict
        token_list = [src_dict.unk()] * 4 + [src_dict.eos()]
        src_tokens = torch.LongTensor(
            np.array(token_list, dtype='int64').reshape(-1, 1),
        )
        src_lengths = torch.IntTensor(
            np.array([len(token_list)], dtype='int32'),
        )
        pytorch_encoder_outputs = encoder_ensemble(src_tokens, src_lengths)

        decoder_step_ensemble.save_to_db(
            args.decoder_output_file,
            pytorch_encoder_outputs,
        )
def _test_ensemble_encoder_export(self, test_args):
    """Export an EncoderEnsemble to ONNX and verify the Caffe2 backend
    reproduces the PyTorch encoder outputs.
    """
    samples, src_dict, tgt_dict = test_utils.prepare_inputs(test_args)
    ensemble_size = 3
    model_list = [
        models.build_model(test_args, src_dict, tgt_dict)
        for _ in range(ensemble_size)
    ]
    encoder_ensemble = EncoderEnsemble(model_list)

    work_dir = tempfile.mkdtemp()
    encoder_pb_path = os.path.join(work_dir, 'encoder.pb')
    encoder_ensemble.onnx_export(encoder_pb_path)

    # test equivalence
    # The discrepancy in types here is a temporary expedient.
    # PyTorch indexing requires int64 while support for tracing
    # pack_padded_sequence() requires int32.
    sample = next(samples)
    src_tokens = sample['net_input']['src_tokens'][0:1].t()
    src_lengths = sample['net_input']['src_lengths'][0:1].int()

    pytorch_encoder_outputs = encoder_ensemble(src_tokens, src_lengths)

    with open(encoder_pb_path, 'r+b') as pb_file:
        onnx_model = onnx.load(pb_file)
    onnx_encoder = caffe2_backend.prepare(onnx_model)

    caffe2_encoder_outputs = onnx_encoder.run(
        (src_tokens.numpy(), src_lengths.numpy()),
    )

    for idx, torch_out in enumerate(pytorch_encoder_outputs):
        np.testing.assert_allclose(
            caffe2_encoder_outputs[idx],
            torch_out.data.numpy(),
            rtol=1e-4,
            atol=1e-6,
        )

    encoder_ensemble.save_to_db(
        os.path.join(work_dir, 'encoder.predictor_export'),
    )
def _test_ensemble_encoder_export(self, test_args):
    """Export an EncoderEnsemble (built via a DictionaryHolderTask) as a
    zip archive and verify Caffe2 output parity with PyTorch.
    """
    samples, src_dict, tgt_dict = test_utils.prepare_inputs(test_args)
    task = tasks.DictionaryHolderTask(src_dict, tgt_dict)
    ensemble_size = 3
    model_list = [task.build_model(test_args) for _ in range(ensemble_size)]
    encoder_ensemble = EncoderEnsemble(model_list)

    work_dir = tempfile.mkdtemp()
    encoder_pb_path = os.path.join(work_dir, "encoder.pb")
    encoder_ensemble.onnx_export(encoder_pb_path)

    # test equivalence
    # The discrepancy in types here is a temporary expedient.
    # PyTorch indexing requires int64 while support for tracing
    # pack_padded_sequence() requires int32.
    sample = next(samples)
    src_tokens = sample["net_input"]["src_tokens"][0:1].t()
    src_lengths = sample["net_input"]["src_lengths"][0:1].int()

    pytorch_encoder_outputs = encoder_ensemble(src_tokens, src_lengths)

    onnx_encoder = caffe2_backend.prepare_zip_archive(encoder_pb_path)

    caffe2_encoder_outputs = onnx_encoder.run(
        (src_tokens.numpy(), src_lengths.numpy())
    )

    for idx, torch_out in enumerate(pytorch_encoder_outputs):
        np.testing.assert_allclose(
            caffe2_encoder_outputs[idx],
            torch_out.detach().numpy(),
            rtol=1e-4,
            atol=1e-6,
        )

    encoder_ensemble.save_to_db(
        os.path.join(work_dir, "encoder.predictor_export")
    )
def _test_batched_beam_decoder_step(self, test_args):
    """Export a DecoderBatchedStepEnsemble to ONNX and check that the
    Caffe2 backend reproduces the PyTorch outputs for a second-step
    (post-expansion) decoder invocation.
    """
    beam_size = 5
    samples, src_dict, tgt_dict = test_utils.prepare_inputs(test_args)
    num_models = 3
    model_list = []
    for _ in range(num_models):
        model_list.append(models.build_model(test_args, src_dict, tgt_dict))
    encoder_ensemble = EncoderEnsemble(model_list)

    # test equivalence
    # The discrepancy in types here is a temporary expedient.
    # PyTorch indexing requires int64 while support for tracing
    # pack_padded_sequence() requires int32.
    sample = next(samples)
    src_tokens = sample['net_input']['src_tokens'][0:1].t()
    src_lengths = sample['net_input']['src_lengths'][0:1].int()

    pytorch_encoder_outputs = encoder_ensemble(src_tokens, src_lengths)

    decoder_step_ensemble = DecoderBatchedStepEnsemble(
        model_list,
        beam_size=beam_size,
    )
    tmp_dir = tempfile.mkdtemp()
    decoder_step_pb_path = os.path.join(tmp_dir, 'decoder_step.pb')
    decoder_step_ensemble.onnx_export(
        decoder_step_pb_path,
        pytorch_encoder_outputs,
    )

    # single EOS in flat array
    input_tokens = torch.LongTensor(
        np.array([model_list[0].dst_dict.eos()]),
    )
    prev_scores = torch.FloatTensor(np.array([0.0]))
    timestep = torch.LongTensor(np.array([0]))

    # First step runs on a single hypothesis (beam not yet expanded).
    pytorch_first_step_outputs = decoder_step_ensemble(
        input_tokens, prev_scores, timestep, *pytorch_encoder_outputs)

    # next step inputs (input_tokens shape: [beam_size])
    # Token ids 4..8 are arbitrary vocabulary entries used as fake
    # beam candidates for the second step.
    next_input_tokens = torch.LongTensor(
        np.array([i for i in range(4, 9)]),
    )
    next_prev_scores = pytorch_first_step_outputs[1]
    next_timestep = timestep + 1
    # Outputs [4:] are the recurrent decoder states carried to the next step.
    next_states = pytorch_first_step_outputs[4:]

    # NOTE: the ordering below (per-model encoder reps, then the optional
    # vocab-reduction tensor, then decoder states) must match the input
    # signature the decoder step was exported with.
    step_inputs = []

    # encoder outputs need to be replicated for each input hypothesis
    for encoder_rep in pytorch_encoder_outputs[:len(model_list)]:
        step_inputs.append(encoder_rep.repeat(1, beam_size, 1))

    # When vocab reduction is active, the extra encoder output (the
    # reduced-vocabulary indices) is passed through unreplicated.
    if model_list[0].decoder.vocab_reduction_module is not None:
        step_inputs.append(pytorch_encoder_outputs[len(model_list)])

    step_inputs.extend(list(next_states))

    pytorch_next_step_outputs = decoder_step_ensemble(
        next_input_tokens, next_prev_scores, next_timestep, *step_inputs)

    with open(decoder_step_pb_path, 'r+b') as f:
        onnx_model = onnx.load(f)
    onnx_decoder = caffe2_backend.prepare(onnx_model)

    decoder_inputs_numpy = [
        next_input_tokens.numpy(),
        next_prev_scores.detach().numpy(),
        next_timestep.detach().numpy(),
    ]
    for tensor in step_inputs:
        decoder_inputs_numpy.append(tensor.detach().numpy())

    caffe2_next_step_outputs = onnx_decoder.run(
        tuple(decoder_inputs_numpy),
    )

    # Compare every Caffe2 output tensor against its PyTorch counterpart.
    for i in range(len(pytorch_next_step_outputs)):
        caffe2_out_value = caffe2_next_step_outputs[i]
        pytorch_out_value = pytorch_next_step_outputs[i].data.numpy()
        np.testing.assert_allclose(
            caffe2_out_value,
            pytorch_out_value,
            rtol=1e-4,
            atol=1e-6,
        )
def _test_batched_beam_decoder_step(self, test_args, return_caffe2_rep=False):
    """Export a DecoderBatchedStepEnsemble (task-built models) to ONNX,
    verify Caffe2/PyTorch parity on the second decoder step, and save the
    predictor export.

    When ``return_caffe2_rep`` is True, returns the prepared Caffe2
    backend representation instead of running the comparison.
    """
    beam_size = 5
    samples, src_dict, tgt_dict = test_utils.prepare_inputs(test_args)
    task = tasks.DictionaryHolderTask(src_dict, tgt_dict)
    num_models = 3
    model_list = []
    for _ in range(num_models):
        model_list.append(task.build_model(test_args))
    encoder_ensemble = EncoderEnsemble(model_list)

    # test equivalence
    # The discrepancy in types here is a temporary expedient.
    # PyTorch indexing requires int64 while support for tracing
    # pack_padded_sequence() requires int32.
    sample = next(samples)
    src_tokens = sample["net_input"]["src_tokens"][0:1].t()
    src_lengths = sample["net_input"]["src_lengths"][0:1].int()

    pytorch_encoder_outputs = encoder_ensemble(src_tokens, src_lengths)

    decoder_step_ensemble = DecoderBatchedStepEnsemble(model_list, tgt_dict,
                                                       beam_size=beam_size)

    tmp_dir = tempfile.mkdtemp()
    decoder_step_pb_path = os.path.join(tmp_dir, "decoder_step.pb")
    decoder_step_ensemble.onnx_export(decoder_step_pb_path,
                                      pytorch_encoder_outputs)

    # single EOS in flat array
    input_tokens = torch.LongTensor(np.array([tgt_dict.eos()]))
    prev_scores = torch.FloatTensor(np.array([0.0]))
    timestep = torch.LongTensor(np.array([0]))

    # First step runs on a single hypothesis (beam not yet expanded).
    pytorch_first_step_outputs = decoder_step_ensemble(
        input_tokens, prev_scores, timestep, *pytorch_encoder_outputs)

    # next step inputs (input_tokens shape: [beam_size])
    # Token ids 4..8 are arbitrary vocabulary entries used as fake
    # beam candidates for the second step.
    next_input_tokens = torch.LongTensor(np.array([i for i in range(4, 9)]))
    next_prev_scores = pytorch_first_step_outputs[1]
    next_timestep = timestep + 1
    # Outputs [4:] are the recurrent decoder states carried to the next step.
    next_states = list(pytorch_first_step_outputs[4:])

    # Tile these for the next timestep
    # Only the first len(model_list) states are per-model encoder reps
    # that need replication across the beam; later states are left as-is.
    for i in range(len(model_list)):
        next_states[i] = next_states[i].repeat(1, beam_size, 1)

    pytorch_next_step_outputs = decoder_step_ensemble(
        next_input_tokens, next_prev_scores, next_timestep, *next_states)

    onnx_decoder = caffe2_backend.prepare_zip_archive(decoder_step_pb_path)

    if return_caffe2_rep:
        return onnx_decoder

    decoder_inputs_numpy = [
        next_input_tokens.numpy(),
        next_prev_scores.detach().numpy(),
        next_timestep.detach().numpy(),
    ]
    for tensor in next_states:
        decoder_inputs_numpy.append(tensor.detach().numpy())

    caffe2_next_step_outputs = onnx_decoder.run(
        tuple(decoder_inputs_numpy))

    # Compare every Caffe2 output tensor against its PyTorch counterpart.
    for i in range(len(pytorch_next_step_outputs)):
        caffe2_out_value = caffe2_next_step_outputs[i]
        pytorch_out_value = pytorch_next_step_outputs[i].detach().numpy()
        np.testing.assert_allclose(caffe2_out_value, pytorch_out_value,
                                   rtol=1e-4, atol=1e-6)

    decoder_step_ensemble.save_to_db(
        output_path=os.path.join(tmp_dir, "decoder.predictor_export"),
        encoder_ensemble_outputs=pytorch_encoder_outputs,
    )
def main():
    """CLI entry point: export PyTorch-trained translation checkpoints to
    Caffe2 encoder and/or decoder-step predictor networks.

    Requires at least one of --encoder_output_file / --decoder_output_file;
    otherwise prints usage help and returns without doing anything.
    """
    parser = argparse.ArgumentParser(description=(
        'Export PyTorch-trained FBTranslate models to Caffe2 components'),
    )
    parser.add_argument(
        '--checkpoint',
        action='append',
        nargs='+',
        help='PyTorch checkpoint file (at least one required)',
    )
    parser.add_argument(
        '--encoder_output_file',
        default='',
        help='File name to which to save encoder ensemble network',
    )
    parser.add_argument(
        '--decoder_output_file',
        default='',
        help='File name to which to save decoder step ensemble network',
    )
    parser.add_argument(
        '--src_dict',
        required=True,
        help='File encoding PyTorch dictionary for source language',
    )
    parser.add_argument(
        '--dst_dict',
        required=True,
        # Fixed copy-paste error: this flag is for the *target* dictionary.
        help='File encoding PyTorch dictionary for target language',
    )
    parser.add_argument(
        '--beam_size',
        type=int,
        default=6,
        help='Number of top candidates returned by each decoder step',
    )
    parser.add_argument(
        '--word_penalty',
        type=float,
        default=0.0,
        help='Value to add for each word (besides EOS)',
    )
    parser.add_argument(
        '--unk_penalty',
        type=float,
        default=0.0,
        help='Value to add for each word UNK token',
    )
    parser.add_argument(
        '--batched_beam',
        action='store_true',
        help='Decoder step has entire beam as input/output',
    )
    args = parser.parse_args()

    # Exporting nothing at all is almost certainly a usage mistake.
    if args.encoder_output_file == args.decoder_output_file == '':
        print('No action taken. Need at least one of --encoder_output_file '
              'and --decoder_output_file.')
        parser.print_help()
        return

    checkpoint_filenames = [arg[0] for arg in args.checkpoint]

    encoder_ensemble = EncoderEnsemble.build_from_checkpoints(
        checkpoint_filenames=checkpoint_filenames,
        src_dict_filename=args.src_dict,
        dst_dict_filename=args.dst_dict,
    )
    if args.encoder_output_file != '':
        encoder_ensemble.save_to_db(args.encoder_output_file)

    if args.decoder_output_file != '':
        if args.batched_beam:
            decoder_step_class = DecoderBatchedStepEnsemble
        else:
            decoder_step_class = DecoderStepEnsemble
        decoder_step_ensemble = decoder_step_class.build_from_checkpoints(
            checkpoint_filenames=checkpoint_filenames,
            src_dict_filename=args.src_dict,
            dst_dict_filename=args.dst_dict,
            beam_size=args.beam_size,
            word_penalty=args.word_penalty,
            unk_penalty=args.unk_penalty,
        )

        # need example encoder outputs to pass through network
        # (source length 5 is arbitrary)
        src_dict = encoder_ensemble.models[0].src_dict
        token_list = [src_dict.unk()] * 4 + [src_dict.eos()]
        src_tokens = torch.LongTensor(
            np.array(token_list, dtype='int64').reshape(-1, 1),
        )
        src_lengths = torch.IntTensor(
            np.array([len(token_list)], dtype='int32'),
        )
        pytorch_encoder_outputs = encoder_ensemble(src_tokens, src_lengths)

        decoder_step_ensemble.save_to_db(
            args.decoder_output_file,
            pytorch_encoder_outputs,
        )
def _test_beam_component_equivalence(self, test_args):
    """Verify that running EncoderEnsemble + DecoderBatchedStepEnsemble
    step-by-step for two timesteps reproduces the outputs of the
    monolithic BeamSearch module.
    """
    beam_size = 5
    samples, src_dict, tgt_dict = test_utils.prepare_inputs(test_args)
    task = tasks.DictionaryHolderTask(src_dict, tgt_dict)
    num_models = 3
    model_list = []
    for _ in range(num_models):
        model_list.append(task.build_model(test_args))

    # to initialize BeamSearch object
    sample = next(samples)
    # [seq len, batch size=1]
    src_tokens = sample["net_input"]["src_tokens"][0:1].t()
    # [seq len]
    src_lengths = sample["net_input"]["src_lengths"][0:1].long()

    beam_size = 5
    full_beam_search = BeamSearch(model_list, tgt_dict, src_tokens,
                                  src_lengths, beam_size=beam_size)

    encoder_ensemble = EncoderEnsemble(model_list)

    # to initialize decoder_step_ensemble
    with torch.no_grad():
        pytorch_encoder_outputs = encoder_ensemble(src_tokens, src_lengths)

    decoder_step_ensemble = DecoderBatchedStepEnsemble(model_list, tgt_dict,
                                                       beam_size=beam_size)

    # Initial beam-search inputs: single EOS hypothesis with zero score.
    prev_token = torch.LongTensor([tgt_dict.eos()])
    prev_scores = torch.FloatTensor([0.0])
    attn_weights = torch.zeros(src_tokens.shape[0])
    prev_hypos_indices = torch.zeros(beam_size, dtype=torch.int64)
    num_steps = torch.LongTensor([2])

    # Reference run: the monolithic beam search for num_steps steps.
    with torch.no_grad():
        (
            bs_out_tokens,
            bs_out_scores,
            bs_out_weights,
            bs_out_prev_indices,
        ) = full_beam_search(
            src_tokens,
            src_lengths,
            prev_token,
            prev_scores,
            attn_weights,
            prev_hypos_indices,
            num_steps,
        )

    # Component-wise output buffers; row 0 keeps the EOS/zero init values
    # that BeamSearch also emits for the initial timestep.
    comp_out_tokens = (np.ones([num_steps + 1, beam_size], dtype="int64") *
                       tgt_dict.eos())
    comp_out_scores = np.zeros([num_steps + 1, beam_size])
    comp_out_weights = np.zeros(
        [num_steps + 1, beam_size, src_lengths.numpy()[0]])
    comp_out_prev_indices = np.zeros([num_steps + 1, beam_size],
                                     dtype="int64")

    # single EOS in flat array
    input_tokens = torch.LongTensor(np.array([tgt_dict.eos()]))
    prev_scores = torch.FloatTensor(np.array([0.0]))
    timestep = torch.LongTensor(np.array([0]))

    # Step 1: decoder step on the single initial hypothesis.
    with torch.no_grad():
        pytorch_first_step_outputs = decoder_step_ensemble(
            input_tokens, prev_scores, timestep, *pytorch_encoder_outputs)

    # Outputs are (tokens, scores, prev_indices, attn_weights, *states).
    comp_out_tokens[1, :] = pytorch_first_step_outputs[0]
    comp_out_scores[1, :] = pytorch_first_step_outputs[1]
    comp_out_prev_indices[1, :] = pytorch_first_step_outputs[2]
    comp_out_weights[1, :, :] = pytorch_first_step_outputs[3]

    next_input_tokens = pytorch_first_step_outputs[0]
    next_prev_scores = pytorch_first_step_outputs[1]
    timestep += 1

    # Tile states after first timestep
    # Only the first len(model_list) states (per-model encoder reps) need
    # replication across the beam.
    next_states = list(pytorch_first_step_outputs[4:])
    for i in range(len(model_list)):
        next_states[i] = next_states[i].repeat(1, beam_size, 1)

    # Step 2: decoder step on the expanded beam.
    with torch.no_grad():
        pytorch_next_step_outputs = decoder_step_ensemble(
            next_input_tokens, next_prev_scores, timestep, *next_states)

    comp_out_tokens[2, :] = pytorch_next_step_outputs[0]
    comp_out_scores[2, :] = pytorch_next_step_outputs[1]
    comp_out_prev_indices[2, :] = pytorch_next_step_outputs[2]
    comp_out_weights[2, :, :] = pytorch_next_step_outputs[3]

    # Component-wise results must match the monolithic beam search:
    # exact for integer outputs, within tolerance for float outputs.
    np.testing.assert_array_equal(comp_out_tokens, bs_out_tokens.numpy())
    np.testing.assert_allclose(comp_out_scores, bs_out_scores.numpy(),
                               rtol=1e-4, atol=1e-6)
    np.testing.assert_array_equal(comp_out_prev_indices,
                                  bs_out_prev_indices.numpy())
    np.testing.assert_allclose(comp_out_weights, bs_out_weights.numpy(),
                               rtol=1e-4, atol=1e-6)
def __init__(
    self,
    models,
    tgt_dict,
    src_tokens,
    src_lengths,
    eos_token_id,
    length_penalty,
    nbest,
    beam_size,
    stop_at_eos,
    word_reward=0,
    unk_reward=0,
    quantize=False,
):
    """Build and trace the encoder and decoder-step sub-networks so the
    whole beam search can run under TorchScript.

    Two decoder variants are traced: one with ``tile_internal=True`` for
    the first step (single hypothesis, states tiled internally) and one
    with ``tile_internal=False`` for subsequent steps (full beam inputs).
    ``src_tokens``/``src_lengths`` are example inputs used only for
    tracing. When ``quantize`` is set, linear/RNN modules are quantized
    before tracing.
    """
    super().__init__()
    self.models = models
    self.tgt_dict = tgt_dict
    # torch.jit.Attribute marks these as typed ScriptModule attributes.
    self.beam_size = torch.jit.Attribute(beam_size, int)
    self.word_reward = torch.jit.Attribute(word_reward, float)
    self.unk_reward = torch.jit.Attribute(unk_reward, float)

    encoder_ens = EncoderEnsemble(self.models)
    encoder_ens.enable_precompute_reduced_weights = True
    if quantize:
        encoder_ens = torch.jit.quantized.quantize_linear_modules(
            encoder_ens)
        encoder_ens = torch.jit.quantized.quantize_rnn_cell_modules(
            encoder_ens)

    # char-source models are not supported here; a fake module stands in.
    self.is_char_source = False
    enc_inputs = (src_tokens, src_lengths)
    # Run the encoder eagerly once: the outputs are needed as example
    # inputs for tracing the decoder steps below.
    example_encoder_outs = encoder_ens(*enc_inputs)
    self.encoder_ens = torch.jit.trace(encoder_ens, enc_inputs,
                                       _force_outplace=True)
    self.encoder_ens_char_source = FakeCharSourceEncoderEnsemble()

    # Decoder for steps after the first (beam already expanded).
    decoder_ens = DecoderBatchedStepEnsemble2BeamWithEOS(
        self.models,
        tgt_dict,
        beam_size,
        word_reward,
        unk_reward,
        tile_internal=False,
    )
    decoder_ens.enable_precompute_reduced_weights = True
    if quantize:
        decoder_ens = torch.jit.quantized.quantize_linear_modules(
            decoder_ens)
        decoder_ens = torch.jit.quantized.quantize_rnn_cell_modules(
            decoder_ens)
        decoder_ens = torch.jit.quantized.quantize_rnn_modules(decoder_ens)

    # Decoder for the first step: tiles its internal states to beam size.
    decoder_ens_tile = DecoderBatchedStepEnsemble2BeamWithEOS(
        self.models,
        tgt_dict,
        beam_size,
        word_reward,
        unk_reward,
        tile_internal=True,
    )
    decoder_ens_tile.enable_precompute_reduced_weights = True
    if quantize:
        decoder_ens_tile = torch.jit.quantized.quantize_linear_modules(
            decoder_ens_tile)
        decoder_ens_tile = torch.jit.quantized.quantize_rnn_cell_modules(
            decoder_ens_tile)
        decoder_ens_tile = torch.jit.quantized.quantize_rnn_modules(
            decoder_ens_tile)

    # Example inputs for tracing: a single hypothesis at timestep 0.
    prev_token = torch.LongTensor([0])
    prev_scores = torch.FloatTensor([0.0])
    ts = torch.LongTensor([0])
    final_step = torch.tensor([False], dtype=torch.bool)
    active_hypos = torch.LongTensor([0])

    # Eager first-step call to obtain beam-tiled states, which serve as
    # example inputs for tracing the non-tiling decoder.
    _, _, _, _, _, *tiled_states = decoder_ens_tile(
        prev_token, prev_scores, active_hypos, ts, final_step,
        *example_encoder_outs)

    self.decoder_ens_tile = torch.jit.trace(
        decoder_ens_tile,
        (
            prev_token,
            prev_scores,
            active_hypos,
            ts,
            final_step,
            *example_encoder_outs,
        ),
        _force_outplace=True,
    )
    self.decoder_ens = torch.jit.trace(
        decoder_ens,
        (
            prev_token.repeat(self.beam_size),
            prev_scores.repeat(self.beam_size),
            active_hypos.repeat(self.beam_size),
            ts,
            final_step,
            *tiled_states,
        ),
        _force_outplace=True,
    )

    self.beam_decode = BeamDecodeWithEOS(eos_token_id, length_penalty,
                                         nbest, beam_size, stop_at_eos)

    # Names used when exporting the traced graph.
    self.input_names = [
        "src_tokens",
        "src_lengths",
        "prev_token",
        "prev_scores",
        "attn_weights",
        "prev_hypos_indices",
        "num_steps",
    ]
    self.output_names = [
        "beam_output",
        "hypothesis_score",
        "token_level_scores",
        "back_alignment_weights",
        "best_indices",
    ]
def test_decoder_ensemble_with_eos(self):
    """Test DecoderBatchedStepEnsembleWithEOS.

    It must produce the same outputs as DecoderBatchedStepEnsemble on
    every step before the final one; on the final step (final_step=True)
    it must emit EOS tokens for the whole beam.
    """
    test_args = test_utils.ModelParamsDict(arch="rnn")
    samples, src_dict, tgt_dict = test_utils.prepare_inputs(test_args)
    task = tasks.DictionaryHolderTask(src_dict, tgt_dict)
    model = task.build_model(test_args)
    eos_token = tgt_dict.eos()

    encoder_ensemble = EncoderEnsemble([model])
    # Arbitrary 5-token source sentence, shape [seq_len, batch=1].
    src_tokens = torch.LongTensor([4, 5, 6, 7, 8]).unsqueeze(1)
    src_lengths = torch.LongTensor([5])
    enc_inputs = (src_tokens, src_lengths)
    encoder_outputs = encoder_ensemble(*enc_inputs)

    beam_size = 8
    word_reward = 1
    unk_reward = -1
    # Both decoders are built with identical settings so their outputs
    # are directly comparable.
    decoder_ensemble = DecoderBatchedStepEnsemble(
        models=[model],
        tgt_dict=tgt_dict,
        beam_size=beam_size,
        word_reward=word_reward,
        unk_reward=unk_reward,
    )
    decoder_ensemble_with_eos = DecoderBatchedStepEnsembleWithEOS(
        models=[model],
        tgt_dict=tgt_dict,
        beam_size=beam_size,
        word_reward=word_reward,
        unk_reward=unk_reward,
    )

    # Initial beam state: single EOS hypothesis with zero score.
    prev_tokens = torch.LongTensor([eos_token])
    prev_scores = torch.FloatTensor([0.0])
    timestep = torch.LongTensor([0])
    final_step = torch.tensor([False], dtype=torch.bool)
    maxLen = 5
    num_steps = torch.LongTensor([maxLen])

    decoder_first_step_outputs = decoder_ensemble(
        prev_tokens, prev_scores, timestep, *encoder_outputs)

    decoder_with_eos_first_step_outputs = decoder_ensemble_with_eos(
        prev_tokens, prev_scores, timestep, final_step, *encoder_outputs)

    # Test results at first step
    self._test_base(decoder_first_step_outputs,
                    decoder_with_eos_first_step_outputs)

    # Outputs are (tokens, scores, prev_indices, attn_weights, *states).
    (
        prev_tokens,
        prev_scores,
        prev_hypos_indices,
        attn_weights,
        *states,
    ) = decoder_first_step_outputs

    # Tile is needed after first step
    # (only the per-model states — one model here — need replication
    # across the beam).
    for i in range(len([model])):
        states[i] = states[i].repeat(1, beam_size, 1)

    (
        prev_tokens_with_eos,
        prev_scores_with_eos,
        prev_hypos_indices_with_eos,
        attn_weights_with_eos,
        *states_with_eos,
    ) = decoder_with_eos_first_step_outputs

    for i in range(len([model])):
        states_with_eos[i] = states_with_eos[i].repeat(1, beam_size, 1)

    # Run both decoders side by side for the remaining non-final steps.
    for i in range(num_steps - 1):
        decoder_step_outputs = decoder_ensemble(
            prev_tokens, prev_scores, torch.tensor([i + 1]), *states)
        (
            prev_tokens,
            prev_scores,
            prev_hypos_indices,
            attn_weights,
            *states,
        ) = decoder_step_outputs

        decoder_step_with_eos_outputs = decoder_ensemble_with_eos(
            prev_tokens_with_eos,
            prev_scores_with_eos,
            torch.tensor([i + 1]),
            final_step,
            *states_with_eos,
        )
        (
            prev_tokens_with_eos,
            prev_scores_with_eos,
            prev_hypos_indices_with_eos,
            attn_weights_with_eos,
            *states_with_eos,
        ) = decoder_step_with_eos_outputs

        # Test results at each step
        self._test_base(decoder_step_outputs,
                        decoder_step_with_eos_outputs)

    # Test the outputs of the final step: with final_step=True the
    # decoder must emit EOS for every beam slot, and prev_hypos_indices
    # must be the identity ordering 0..beam_size-1.
    decoder_final_with_eos_outputs = decoder_ensemble_with_eos(
        prev_tokens_with_eos,
        prev_scores_with_eos,
        torch.tensor([num_steps]),
        torch.tensor([True]),
        *states_with_eos,
    )

    np.testing.assert_array_equal(
        decoder_final_with_eos_outputs[0],
        torch.LongTensor([eos_token]).repeat(beam_size),
    )
    np.testing.assert_array_equal(
        decoder_final_with_eos_outputs[2],
        torch.LongTensor(np.array([i for i in range(beam_size)])),
    )