Example #1
    def test_batched_beam_from_checkpoints_vr(self):
        check_dir = (
            '/mnt/gfsdataswarm-global/namespaces/search/language-technology-mt/'
            'nnmt_tmp/tl_XX-en_XX-pytorch-256-dim-vocab-reduction'
        )
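        # ensemble members: three averaged checkpoints from the same run, sharing the dictionaries below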
        checkpoints = [
            'averaged_checkpoint_best_3.pt',
            'averaged_checkpoint_best_4.pt',
            'averaged_checkpoint_best_5.pt',
        ]
        checkpoint_filenames = [os.path.join(check_dir, f) for f in checkpoints]

        encoder_ensemble = EncoderEnsemble.build_from_checkpoints(
            checkpoint_filenames,
            os.path.join(check_dir, 'dictionary-tl.txt'),
            os.path.join(check_dir, 'dictionary-en.txt'),
        )

        decoder_step_ensemble = DecoderBatchedStepEnsemble.build_from_checkpoints(
            checkpoint_filenames,
            os.path.join(check_dir, 'dictionary-tl.txt'),
            os.path.join(check_dir, 'dictionary-en.txt'),
            beam_size=5,
        )

        self._test_full_ensemble(
            encoder_ensemble,
            decoder_step_ensemble,
            batched_beam=True,
        )
Example #2
    def _test_full_ensemble_export(self, test_args):
        samples, src_dict, tgt_dict = test_utils.prepare_inputs(test_args)

        num_models = 3
        model_list = []
        for _ in range(num_models):
            model_list.append(models.build_model(test_args, src_dict,
                                                 tgt_dict))
        encoder_ensemble = EncoderEnsemble(model_list)

        # test equivalence
        # The discrepancy in types here is a temporary expedient.
        # PyTorch indexing requires int64 while support for tracing
        # pack_padded_sequence() requires int32.
        sample = next(samples)
        src_tokens = sample["net_input"]["src_tokens"][0:1].t()
        src_lengths = sample["net_input"]["src_lengths"][0:1].int()
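        # shapes: src_tokens [seq_len, batch=1] (time-major after .t()), src_lengths [1] as int32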

        pytorch_encoder_outputs = encoder_ensemble(src_tokens, src_lengths)

        decoder_step_ensemble = DecoderStepEnsemble(model_list, beam_size=5)

        tmp_dir = tempfile.mkdtemp()
        decoder_step_pb_path = os.path.join(tmp_dir, "decoder_step.pb")
        decoder_step_ensemble.onnx_export(decoder_step_pb_path,
                                          pytorch_encoder_outputs)

        # single EOS
        input_token = torch.LongTensor(
            np.array([[model_list[0].dst_dict.eos()]]))
        timestep = torch.LongTensor(np.array([[0]]))

        pytorch_decoder_outputs = decoder_step_ensemble(
            input_token, timestep, *pytorch_encoder_outputs)

        with open(decoder_step_pb_path, "r+b") as f:
            onnx_model = onnx.load(f)
        onnx_decoder = caffe2_backend.prepare(onnx_model)

        decoder_inputs_numpy = [input_token.numpy(), timestep.numpy()]
        for tensor in pytorch_encoder_outputs:
            decoder_inputs_numpy.append(tensor.detach().numpy())

        caffe2_decoder_outputs = onnx_decoder.run(tuple(decoder_inputs_numpy))

        for i in range(len(pytorch_decoder_outputs)):
            caffe2_out_value = caffe2_decoder_outputs[i]
            pytorch_out_value = pytorch_decoder_outputs[i].detach().numpy()
            np.testing.assert_allclose(caffe2_out_value,
                                       pytorch_out_value,
                                       rtol=1e-4,
                                       atol=1e-6)

        decoder_step_ensemble.save_to_db(
            os.path.join(tmp_dir, "decoder_step.predictor_export"),
            pytorch_encoder_outputs,
        )
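The per-output comparison loop above recurs in these encoder and decoder tests. A small helper, sketched here as a hypothetical refactoring (not part of the original code), could factor it out:

import numpy as np

def assert_outputs_close(pytorch_outputs, caffe2_outputs, rtol=1e-4, atol=1e-6):
    # Compare each Caffe2 output against its PyTorch counterpart.
    for pytorch_out, caffe2_out in zip(pytorch_outputs, caffe2_outputs):
        np.testing.assert_allclose(
            caffe2_out, pytorch_out.detach().numpy(), rtol=rtol, atol=atol)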
Example #3
    def test_export_encoder_from_checkpoints_no_vr(self):
        check_dir = (
            '/mnt/gfsdataswarm-global/namespaces/search/language-technology-mt/'
            'nnmt_tmp/tl_XX-en_XX-pytorch-testing-no-vocab-reduction-2'
        )
        checkpoints = [
            'averaged_checkpoint_best_3.pt',
            'averaged_checkpoint_best_4.pt',
        ]
        checkpoint_filenames = [os.path.join(check_dir, f) for f in checkpoints]

        encoder_ensemble = EncoderEnsemble.build_from_checkpoints(
            checkpoint_filenames,
            os.path.join(check_dir, 'dictionary-tl.txt'),
            os.path.join(check_dir, 'dictionary-en.txt'),
        )

        self._test_ensemble_encoder_object_export(encoder_ensemble)
Example #4
def export(args):
    assert_required_args_are_set(args)
    checkpoint_filenames = [arg[0] for arg in args.path]

    encoder_ensemble = EncoderEnsemble.build_from_checkpoints(
        checkpoint_filenames=checkpoint_filenames,
        src_dict_filename=args.source_vocab_file,
        dst_dict_filename=args.target_vocab_file,
    )
    if args.encoder_output_file != '':
        encoder_ensemble.save_to_db(args.encoder_output_file)

    if args.decoder_output_file != '':
        if args.batched_beam:
            decoder_step_class = DecoderBatchedStepEnsemble
        else:
            decoder_step_class = DecoderStepEnsemble
        decoder_step_ensemble = decoder_step_class.build_from_checkpoints(
            checkpoint_filenames=checkpoint_filenames,
            src_dict_filename=args.source_vocab_file,
            dst_dict_filename=args.target_vocab_file,
            beam_size=args.beam_size,
            word_penalty=args.word_penalty,
            unk_penalty=args.unk_penalty,
        )

        # need example encoder outputs to pass through network
        # (source length 5 is arbitrary)
        src_dict = encoder_ensemble.models[0].src_dict
        token_list = [src_dict.unk()] * 4 + [src_dict.eos()]
        src_tokens = torch.LongTensor(
            np.array(token_list, dtype='int64').reshape(-1, 1))
        src_lengths = torch.IntTensor(
            np.array([len(token_list)], dtype='int32'))
        pytorch_encoder_outputs = encoder_ensemble(src_tokens, src_lengths)

        decoder_step_ensemble.save_to_db(
            args.decoder_output_file,
            pytorch_encoder_outputs,
        )
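export() reads only a handful of attributes off args. A minimal sketch of driving it programmatically, with placeholder filenames and under the assumption that assert_required_args_are_set() accepts a plain namespace:

from argparse import Namespace

# All filenames below are placeholders.
args = Namespace(
    path=[['checkpoint_a.pt'], ['checkpoint_b.pt']],  # one single-item list per checkpoint flag
    source_vocab_file='dictionary-src.txt',
    target_vocab_file='dictionary-tgt.txt',
    encoder_output_file='encoder.predictor_export',
    decoder_output_file='decoder.predictor_export',
    batched_beam=True,
    beam_size=6,
    word_penalty=0.0,
    unk_penalty=0.0,
)
export(args)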
Example #5
    def _test_ensemble_encoder_export(self, test_args):
        samples, src_dict, tgt_dict = test_utils.prepare_inputs(test_args)

        num_models = 3
        model_list = []
        for _ in range(num_models):
            model_list.append(models.build_model(test_args, src_dict, tgt_dict))
        encoder_ensemble = EncoderEnsemble(model_list)

        tmp_dir = tempfile.mkdtemp()
        encoder_pb_path = os.path.join(tmp_dir, 'encoder.pb')
        encoder_ensemble.onnx_export(encoder_pb_path)

        # test equivalence
        # The discrepancy in types here is a temporary expedient.
        # PyTorch indexing requires int64 while support for tracing
        # pack_padded_sequence() requires int32.
        sample = next(samples)
        src_tokens = sample['net_input']['src_tokens'][0:1].t()
        src_lengths = sample['net_input']['src_lengths'][0:1].int()

        pytorch_encoder_outputs = encoder_ensemble(src_tokens, src_lengths)

        with open(encoder_pb_path, 'r+b') as f:
            onnx_model = onnx.load(f)
        onnx_encoder = caffe2_backend.prepare(onnx_model)

        caffe2_encoder_outputs = onnx_encoder.run(
            (
                src_tokens.numpy(),
                src_lengths.numpy(),
            ),
        )

        for i in range(len(pytorch_encoder_outputs)):
            caffe2_out_value = caffe2_encoder_outputs[i]
            pytorch_out_value = pytorch_encoder_outputs[i].data.numpy()
            np.testing.assert_allclose(
                caffe2_out_value,
                pytorch_out_value,
                rtol=1e-4,
                atol=1e-6,
            )

        encoder_ensemble.save_to_db(
            os.path.join(tmp_dir, 'encoder.predictor_export'),
        )
Example #6
    def _test_ensemble_encoder_export(self, test_args):
        samples, src_dict, tgt_dict = test_utils.prepare_inputs(test_args)
        task = tasks.DictionaryHolderTask(src_dict, tgt_dict)

        num_models = 3
        model_list = []
        for _ in range(num_models):
            model_list.append(task.build_model(test_args))
        encoder_ensemble = EncoderEnsemble(model_list)

        tmp_dir = tempfile.mkdtemp()
        encoder_pb_path = os.path.join(tmp_dir, "encoder.pb")
        encoder_ensemble.onnx_export(encoder_pb_path)

        # test equivalence
        # The discrepancy in types here is a temporary expedient.
        # PyTorch indexing requires int64 while support for tracing
        # pack_padded_sequence() requires int32.
        sample = next(samples)
        src_tokens = sample["net_input"]["src_tokens"][0:1].t()
        src_lengths = sample["net_input"]["src_lengths"][0:1].int()

        pytorch_encoder_outputs = encoder_ensemble(src_tokens, src_lengths)

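        # prepare_zip_archive loads the exported archive directly,
        # replacing the open / onnx.load / prepare sequence of Example #5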
        onnx_encoder = caffe2_backend.prepare_zip_archive(encoder_pb_path)

        caffe2_encoder_outputs = onnx_encoder.run(
            (src_tokens.numpy(), src_lengths.numpy()))

        for i in range(len(pytorch_encoder_outputs)):
            caffe2_out_value = caffe2_encoder_outputs[i]
            pytorch_out_value = pytorch_encoder_outputs[i].detach().numpy()
            np.testing.assert_allclose(caffe2_out_value,
                                       pytorch_out_value,
                                       rtol=1e-4,
                                       atol=1e-6)

        encoder_ensemble.save_to_db(
            os.path.join(tmp_dir, "encoder.predictor_export"))
Example #7
    def _test_batched_beam_decoder_step(self, test_args):
        beam_size = 5
        samples, src_dict, tgt_dict = test_utils.prepare_inputs(test_args)

        num_models = 3
        model_list = []
        for _ in range(num_models):
            model_list.append(models.build_model(test_args, src_dict,
                                                 tgt_dict))
        encoder_ensemble = EncoderEnsemble(model_list)

        # test equivalence
        # The discrepancy in types here is a temporary expedient.
        # PyTorch indexing requires int64 while support for tracing
        # pack_padded_sequence() requires int32.
        sample = next(samples)
        src_tokens = sample['net_input']['src_tokens'][0:1].t()
        src_lengths = sample['net_input']['src_lengths'][0:1].int()

        pytorch_encoder_outputs = encoder_ensemble(src_tokens, src_lengths)

        decoder_step_ensemble = DecoderBatchedStepEnsemble(
            model_list,
            beam_size=beam_size,
        )

        tmp_dir = tempfile.mkdtemp()
        decoder_step_pb_path = os.path.join(tmp_dir, 'decoder_step.pb')
        decoder_step_ensemble.onnx_export(
            decoder_step_pb_path,
            pytorch_encoder_outputs,
        )

        # single EOS in flat array
        input_tokens = torch.LongTensor(
            np.array([model_list[0].dst_dict.eos()]))
        prev_scores = torch.FloatTensor(np.array([0.0]))
        timestep = torch.LongTensor(np.array([0]))

        pytorch_first_step_outputs = decoder_step_ensemble(
            input_tokens, prev_scores, timestep, *pytorch_encoder_outputs)

        # next step inputs (input_tokens shape: [beam_size])
        next_input_tokens = torch.LongTensor(np.arange(4, 9))

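        # first-step outputs: [0] tokens, [1] scores, [2] prev hypo indices,
        # [3] attention weights, [4:] model states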
        next_prev_scores = pytorch_first_step_outputs[1]
        next_timestep = timestep + 1
        next_states = pytorch_first_step_outputs[4:]

        step_inputs = []

        # encoder outputs need to be replicated for each input hypothesis
        for encoder_rep in pytorch_encoder_outputs[:len(model_list)]:
            step_inputs.append(encoder_rep.repeat(1, beam_size, 1))

        if model_list[0].decoder.vocab_reduction_module is not None:
            step_inputs.append(pytorch_encoder_outputs[len(model_list)])

        step_inputs.extend(list(next_states))

        pytorch_next_step_outputs = decoder_step_ensemble(
            next_input_tokens, next_prev_scores, next_timestep, *step_inputs)

        with open(decoder_step_pb_path, 'r+b') as f:
            onnx_model = onnx.load(f)
        onnx_decoder = caffe2_backend.prepare(onnx_model)

        decoder_inputs_numpy = [
            next_input_tokens.numpy(),
            next_prev_scores.detach().numpy(),
            next_timestep.detach().numpy(),
        ]
        for tensor in step_inputs:
            decoder_inputs_numpy.append(tensor.detach().numpy())

        caffe2_next_step_outputs = onnx_decoder.run(
            tuple(decoder_inputs_numpy))

        for i in range(len(pytorch_next_step_outputs)):
            caffe2_out_value = caffe2_next_step_outputs[i]
            pytorch_out_value = pytorch_next_step_outputs[i].data.numpy()
            np.testing.assert_allclose(
                caffe2_out_value,
                pytorch_out_value,
                rtol=1e-4,
                atol=1e-6,
            )
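The repeat(1, beam_size, 1) calls above replicate a single-hypothesis tensor along its batch dimension so that every beam hypothesis sees the same encoder representation. A standalone shape check (all sizes illustrative):

import torch

beam_size = 5
encoder_rep = torch.zeros(7, 1, 32)          # [src_len, batch=1, hidden_dim]
tiled = encoder_rep.repeat(1, beam_size, 1)  # replicate along the batch axis
assert tiled.shape == (7, beam_size, 32)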
Example #8
    def _test_batched_beam_decoder_step(self,
                                        test_args,
                                        return_caffe2_rep=False):
        beam_size = 5
        samples, src_dict, tgt_dict = test_utils.prepare_inputs(test_args)
        task = tasks.DictionaryHolderTask(src_dict, tgt_dict)

        num_models = 3
        model_list = []
        for _ in range(num_models):
            model_list.append(task.build_model(test_args))
        encoder_ensemble = EncoderEnsemble(model_list)

        # test equivalence
        # The discrepancy in types here is a temporary expedient.
        # PyTorch indexing requires int64 while support for tracing
        # pack_padded_sequence() requires int32.
        sample = next(samples)
        src_tokens = sample["net_input"]["src_tokens"][0:1].t()
        src_lengths = sample["net_input"]["src_lengths"][0:1].int()

        pytorch_encoder_outputs = encoder_ensemble(src_tokens, src_lengths)

        decoder_step_ensemble = DecoderBatchedStepEnsemble(model_list,
                                                           tgt_dict,
                                                           beam_size=beam_size)

        tmp_dir = tempfile.mkdtemp()
        decoder_step_pb_path = os.path.join(tmp_dir, "decoder_step.pb")
        decoder_step_ensemble.onnx_export(decoder_step_pb_path,
                                          pytorch_encoder_outputs)

        # single EOS in flat array
        input_tokens = torch.LongTensor(np.array([tgt_dict.eos()]))
        prev_scores = torch.FloatTensor(np.array([0.0]))
        timestep = torch.LongTensor(np.array([0]))

        pytorch_first_step_outputs = decoder_step_ensemble(
            input_tokens, prev_scores, timestep, *pytorch_encoder_outputs)

        # next step inputs (input_tokens shape: [beam_size])
        next_input_tokens = torch.LongTensor(np.arange(4, 9))

        next_prev_scores = pytorch_first_step_outputs[1]
        next_timestep = timestep + 1
        next_states = list(pytorch_first_step_outputs[4:])

        # Tile these for the next timestep
        for i in range(len(model_list)):
            next_states[i] = next_states[i].repeat(1, beam_size, 1)

        pytorch_next_step_outputs = decoder_step_ensemble(
            next_input_tokens, next_prev_scores, next_timestep, *next_states)

        onnx_decoder = caffe2_backend.prepare_zip_archive(decoder_step_pb_path)

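        # optionally hand back the prepared Caffe2 model so other tests can reuse it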
        if return_caffe2_rep:
            return onnx_decoder

        decoder_inputs_numpy = [
            next_input_tokens.numpy(),
            next_prev_scores.detach().numpy(),
            next_timestep.detach().numpy(),
        ]
        for tensor in next_states:
            decoder_inputs_numpy.append(tensor.detach().numpy())

        caffe2_next_step_outputs = onnx_decoder.run(
            tuple(decoder_inputs_numpy))

        for i in range(len(pytorch_next_step_outputs)):
            caffe2_out_value = caffe2_next_step_outputs[i]
            pytorch_out_value = pytorch_next_step_outputs[i].detach().numpy()
            np.testing.assert_allclose(caffe2_out_value,
                                       pytorch_out_value,
                                       rtol=1e-4,
                                       atol=1e-6)
        decoder_step_ensemble.save_to_db(
            output_path=os.path.join(tmp_dir, "decoder.predictor_export"),
            encoder_ensemble_outputs=pytorch_encoder_outputs,
        )
Example #9
def main():
    parser = argparse.ArgumentParser(
        description='Export PyTorch-trained FBTranslate models to Caffe2 components')
    parser.add_argument(
        '--checkpoint',
        action='append',
        nargs='+',
        help='PyTorch checkpoint file (at least one required)',
    )
    parser.add_argument(
        '--encoder_output_file',
        default='',
        help='File name to which to save encoder ensemble network',
    )
    parser.add_argument(
        '--decoder_output_file',
        default='',
        help='File name to which to save decoder step ensemble network',
    )
    parser.add_argument(
        '--src_dict',
        required=True,
        help='File encoding PyTorch dictionary for source language',
    )
    parser.add_argument(
        '--dst_dict',
        required=True,
        help='File encoding PyTorch dictionary for target language',
    )
    parser.add_argument(
        '--beam_size',
        type=int,
        default=6,
        help='Number of top candidates returned by each decoder step',
    )
    parser.add_argument(
        '--word_penalty',
        type=float,
        default=0.0,
        help='Value to add for each word (besides EOS)',
    )
    parser.add_argument(
        '--unk_penalty',
        type=float,
        default=0.0,
        help='Value to add for each UNK token',
    )
    parser.add_argument(
        '--batched_beam',
        action='store_true',
        help='Decoder step has entire beam as input/output',
    )

    args = parser.parse_args()

    if args.encoder_output_file == args.decoder_output_file == '':
        print('No action taken. Need at least one of --encoder_output_file '
              'and --decoder_output_file.')
        parser.print_help()
        return

    checkpoint_filenames = [arg[0] for arg in args.checkpoint]

    encoder_ensemble = EncoderEnsemble.build_from_checkpoints(
        checkpoint_filenames=checkpoint_filenames,
        src_dict_filename=args.src_dict,
        dst_dict_filename=args.dst_dict,
    )
    if args.encoder_output_file != '':
        encoder_ensemble.save_to_db(args.encoder_output_file)

    if args.decoder_output_file != '':
        if args.batched_beam:
            decoder_step_class = DecoderBatchedStepEnsemble
        else:
            decoder_step_class = DecoderStepEnsemble
        decoder_step_ensemble = decoder_step_class.build_from_checkpoints(
            checkpoint_filenames=checkpoint_filenames,
            src_dict_filename=args.src_dict,
            dst_dict_filename=args.dst_dict,
            beam_size=args.beam_size,
            word_penalty=args.word_penalty,
            unk_penalty=args.unk_penalty,
        )

        # need example encoder outputs to pass through network
        # (source length 5 is arbitrary)
        src_dict = encoder_ensemble.models[0].src_dict
        token_list = [src_dict.unk()] * 4 + [src_dict.eos()]
        src_tokens = torch.LongTensor(
            np.array(token_list, dtype='int64').reshape(-1, 1))
        src_lengths = torch.IntTensor(
            np.array([len(token_list)], dtype='int32'))
        pytorch_encoder_outputs = encoder_ensemble(src_tokens, src_lengths)

        decoder_step_ensemble.save_to_db(
            args.decoder_output_file,
            pytorch_encoder_outputs,
        )
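For reference, a hypothetical invocation of this script (the script name and all paths are placeholders):

python export_to_caffe2.py \
    --checkpoint averaged_checkpoint_best_3.pt \
    --checkpoint averaged_checkpoint_best_4.pt \
    --src_dict dictionary-tl.txt \
    --dst_dict dictionary-en.txt \
    --encoder_output_file encoder.predictor_export \
    --decoder_output_file decoder.predictor_export \
    --batched_beam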
Example #10
    def _test_beam_component_equivalence(self, test_args):
        beam_size = 5
        samples, src_dict, tgt_dict = test_utils.prepare_inputs(test_args)
        task = tasks.DictionaryHolderTask(src_dict, tgt_dict)

        num_models = 3
        model_list = []
        for _ in range(num_models):
            model_list.append(task.build_model(test_args))

        # to initialize BeamSearch object
        sample = next(samples)
        # [seq len, batch size=1]
        src_tokens = sample["net_input"]["src_tokens"][0:1].t()
        # [batch size=1]; holds the source sequence length
        src_lengths = sample["net_input"]["src_lengths"][0:1].long()

        full_beam_search = BeamSearch(model_list,
                                      tgt_dict,
                                      src_tokens,
                                      src_lengths,
                                      beam_size=beam_size)

        encoder_ensemble = EncoderEnsemble(model_list)

        # to initialize decoder_step_ensemble
        with torch.no_grad():
            pytorch_encoder_outputs = encoder_ensemble(src_tokens, src_lengths)

        decoder_step_ensemble = DecoderBatchedStepEnsemble(model_list,
                                                           tgt_dict,
                                                           beam_size=beam_size)

        prev_token = torch.LongTensor([tgt_dict.eos()])
        prev_scores = torch.FloatTensor([0.0])
        attn_weights = torch.zeros(src_tokens.shape[0])
        prev_hypos_indices = torch.zeros(beam_size, dtype=torch.int64)
        num_steps = torch.LongTensor([2])

        with torch.no_grad():
            (
                bs_out_tokens,
                bs_out_scores,
                bs_out_weights,
                bs_out_prev_indices,
            ) = full_beam_search(
                src_tokens,
                src_lengths,
                prev_token,
                prev_scores,
                attn_weights,
                prev_hypos_indices,
                num_steps,
            )

        comp_out_tokens = (np.ones([num_steps + 1, beam_size], dtype="int64") *
                           tgt_dict.eos())
        comp_out_scores = np.zeros([num_steps + 1, beam_size])
        comp_out_weights = np.zeros(
            [num_steps + 1, beam_size,
             src_lengths.numpy()[0]])
        comp_out_prev_indices = np.zeros([num_steps + 1, beam_size],
                                         dtype="int64")

        # single EOS in flat array
        input_tokens = torch.LongTensor(np.array([tgt_dict.eos()]))
        prev_scores = torch.FloatTensor(np.array([0.0]))
        timestep = torch.LongTensor(np.array([0]))

        with torch.no_grad():
            pytorch_first_step_outputs = decoder_step_ensemble(
                input_tokens, prev_scores, timestep, *pytorch_encoder_outputs)

        comp_out_tokens[1, :] = pytorch_first_step_outputs[0]
        comp_out_scores[1, :] = pytorch_first_step_outputs[1]
        comp_out_prev_indices[1, :] = pytorch_first_step_outputs[2]
        comp_out_weights[1, :, :] = pytorch_first_step_outputs[3]

        next_input_tokens = pytorch_first_step_outputs[0]
        next_prev_scores = pytorch_first_step_outputs[1]
        timestep += 1

        # Tile states after first timestep
        next_states = list(pytorch_first_step_outputs[4:])
        for i in range(len(model_list)):
            next_states[i] = next_states[i].repeat(1, beam_size, 1)

        with torch.no_grad():
            pytorch_next_step_outputs = decoder_step_ensemble(
                next_input_tokens, next_prev_scores, timestep, *next_states)

        comp_out_tokens[2, :] = pytorch_next_step_outputs[0]
        comp_out_scores[2, :] = pytorch_next_step_outputs[1]
        comp_out_prev_indices[2, :] = pytorch_next_step_outputs[2]
        comp_out_weights[2, :, :] = pytorch_next_step_outputs[3]

        np.testing.assert_array_equal(comp_out_tokens, bs_out_tokens.numpy())
        np.testing.assert_allclose(comp_out_scores,
                                   bs_out_scores.numpy(),
                                   rtol=1e-4,
                                   atol=1e-6)
        np.testing.assert_array_equal(comp_out_prev_indices,
                                      bs_out_prev_indices.numpy())
        np.testing.assert_allclose(comp_out_weights,
                                   bs_out_weights.numpy(),
                                   rtol=1e-4,
                                   atol=1e-6)
Example #11
    def __init__(
        self,
        models,
        tgt_dict,
        src_tokens,
        src_lengths,
        eos_token_id,
        length_penalty,
        nbest,
        beam_size,
        stop_at_eos,
        word_reward=0,
        unk_reward=0,
        quantize=False,
    ):
        super().__init__()

        self.models = models
        self.tgt_dict = tgt_dict
        self.beam_size = torch.jit.Attribute(beam_size, int)
        self.word_reward = torch.jit.Attribute(word_reward, float)
        self.unk_reward = torch.jit.Attribute(unk_reward, float)

        encoder_ens = EncoderEnsemble(self.models)
        encoder_ens.enable_precompute_reduced_weights = True

        if quantize:
            encoder_ens = torch.jit.quantized.quantize_linear_modules(
                encoder_ens)
            encoder_ens = torch.jit.quantized.quantize_rnn_cell_modules(
                encoder_ens)

        # char source model is not supported
        self.is_char_source = False
        enc_inputs = (src_tokens, src_lengths)
        example_encoder_outs = encoder_ens(*enc_inputs)
        self.encoder_ens = torch.jit.trace(encoder_ens,
                                           enc_inputs,
                                           _force_outplace=True)
        self.encoder_ens_char_source = FakeCharSourceEncoderEnsemble()

        decoder_ens = DecoderBatchedStepEnsemble2BeamWithEOS(
            self.models,
            tgt_dict,
            beam_size,
            word_reward,
            unk_reward,
            tile_internal=False,
        )
        decoder_ens.enable_precompute_reduced_weights = True
        if quantize:
            decoder_ens = torch.jit.quantized.quantize_linear_modules(
                decoder_ens)
            decoder_ens = torch.jit.quantized.quantize_rnn_cell_modules(
                decoder_ens)
            decoder_ens = torch.jit.quantized.quantize_rnn_modules(decoder_ens)
        decoder_ens_tile = DecoderBatchedStepEnsemble2BeamWithEOS(
            self.models,
            tgt_dict,
            beam_size,
            word_reward,
            unk_reward,
            tile_internal=True,
        )
        decoder_ens_tile.enable_precompute_reduced_weights = True
        if quantize:
            decoder_ens_tile = torch.jit.quantized.quantize_linear_modules(
                decoder_ens_tile)
            decoder_ens_tile = torch.jit.quantized.quantize_rnn_cell_modules(
                decoder_ens_tile)
            decoder_ens_tile = torch.jit.quantized.quantize_rnn_modules(
                decoder_ens_tile)
        prev_token = torch.LongTensor([0])
        prev_scores = torch.FloatTensor([0.0])
        ts = torch.LongTensor([0])
        final_step = torch.tensor([False], dtype=torch.bool)
        active_hypos = torch.LongTensor([0])

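        # run the tiling decoder once on a single hypothesis; its beam-tiled
        # output states become the example inputs for tracing the per-step
        # decoder below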
        _, _, _, _, _, *tiled_states = decoder_ens_tile(
            prev_token, prev_scores, active_hypos, ts, final_step,
            *example_encoder_outs)

        self.decoder_ens_tile = torch.jit.trace(
            decoder_ens_tile,
            (
                prev_token,
                prev_scores,
                active_hypos,
                ts,
                final_step,
                *example_encoder_outs,
            ),
            _force_outplace=True,
        )
        self.decoder_ens = torch.jit.trace(
            decoder_ens,
            (
                prev_token.repeat(self.beam_size),
                prev_scores.repeat(self.beam_size),
                active_hypos.repeat(self.beam_size),
                ts,
                final_step,
                *tiled_states,
            ),
            _force_outplace=True,
        )

        self.beam_decode = BeamDecodeWithEOS(eos_token_id, length_penalty,
                                             nbest, beam_size, stop_at_eos)

        self.input_names = [
            "src_tokens",
            "src_lengths",
            "prev_token",
            "prev_scores",
            "attn_weights",
            "prev_hypos_indices",
            "num_steps",
        ]
        self.output_names = [
            "beam_output",
            "hypothesis_score",
            "token_level_scores",
            "back_alignment_weights",
            "best_indices",
        ]

    def test_decoder_ensemble_with_eos(self):
        """
        Test the DecoderBatchedStepEnsembleWithEOS class: we expect it to
        generate the same outputs as DecoderBatchedStepEnsemble before the
        final step, and EOS tokens at the final step.
        """
        test_args = test_utils.ModelParamsDict(arch="rnn")
        samples, src_dict, tgt_dict = test_utils.prepare_inputs(test_args)
        task = tasks.DictionaryHolderTask(src_dict, tgt_dict)
        model = task.build_model(test_args)
        eos_token = tgt_dict.eos()

        encoder_ensemble = EncoderEnsemble([model])
        src_tokens = torch.LongTensor([4, 5, 6, 7, 8]).unsqueeze(1)
        src_lengths = torch.LongTensor([5])
        enc_inputs = (src_tokens, src_lengths)
        encoder_outputs = encoder_ensemble(*enc_inputs)

        beam_size = 8
        word_reward = 1
        unk_reward = -1
        decoder_ensemble = DecoderBatchedStepEnsemble(
            models=[model],
            tgt_dict=tgt_dict,
            beam_size=beam_size,
            word_reward=word_reward,
            unk_reward=unk_reward,
        )
        decoder_ensemble_with_eos = DecoderBatchedStepEnsembleWithEOS(
            models=[model],
            tgt_dict=tgt_dict,
            beam_size=beam_size,
            word_reward=word_reward,
            unk_reward=unk_reward,
        )

        prev_tokens = torch.LongTensor([eos_token])
        prev_scores = torch.FloatTensor([0.0])
        timestep = torch.LongTensor([0])
        final_step = torch.tensor([False], dtype=torch.bool)
        max_len = 5
        num_steps = torch.LongTensor([max_len])

        decoder_first_step_outputs = decoder_ensemble(prev_tokens, prev_scores,
                                                      timestep,
                                                      *encoder_outputs)

        decoder_with_eos_first_step_outputs = decoder_ensemble_with_eos(
            prev_tokens, prev_scores, timestep, final_step, *encoder_outputs)

        # Test results at first step
        self._test_base(decoder_first_step_outputs,
                        decoder_with_eos_first_step_outputs)

        (
            prev_tokens,
            prev_scores,
            prev_hypos_indices,
            attn_weights,
            *states,
        ) = decoder_first_step_outputs

        # Tile is needed after first step
        for i in range(len([model])):
            states[i] = states[i].repeat(1, beam_size, 1)

        (
            prev_tokens_with_eos,
            prev_scores_with_eos,
            prev_hypos_indices_with_eos,
            attn_weights_with_eos,
            *states_with_eos,
        ) = decoder_with_eos_first_step_outputs

        for i in range(len([model])):
            states_with_eos[i] = states_with_eos[i].repeat(1, beam_size, 1)

        for i in range(num_steps - 1):
            decoder_step_outputs = decoder_ensemble(prev_tokens, prev_scores,
                                                    torch.tensor([i + 1]),
                                                    *states)
            (
                prev_tokens,
                prev_scores,
                prev_hypos_indices,
                attn_weights,
                *states,
            ) = decoder_step_outputs
            decoder_step_with_eos_outputs = decoder_ensemble_with_eos(
                prev_tokens_with_eos,
                prev_scores_with_eos,
                torch.tensor([i + 1]),
                final_step,
                *states_with_eos,
            )
            (
                prev_tokens_with_eos,
                prev_scores_with_eos,
                prev_hypos_indices_with_eos,
                attn_weights_with_eos,
                *states_with_eos,
            ) = decoder_step_with_eos_outputs

            # Test results at each step
            self._test_base(decoder_step_outputs,
                            decoder_step_with_eos_outputs)

        # Test the outputs of the final step
        decoder_final_with_eos_outputs = decoder_ensemble_with_eos(
            prev_tokens_with_eos,
            prev_scores_with_eos,
            torch.tensor([num_steps]),
            torch.tensor([True]),
            *states_with_eos,
        )

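        # at the final step every hypothesis emits EOS, and the prev-hypo
        # indices are the identity [0..beam_size)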
        np.testing.assert_array_equal(
            decoder_final_with_eos_outputs[0],
            torch.LongTensor([eos_token]).repeat(beam_size),
        )
        np.testing.assert_array_equal(
            decoder_final_with_eos_outputs[2],
            torch.LongTensor(np.arange(beam_size)),
        )