import os

import numpy as np
import torch

# Assumed import path: these ensemble classes are defined in
# pytorch_translate's ensemble_export module in the surrounding project.
from pytorch_translate.ensemble_export import (
    CharSourceEncoderEnsemble,
    DecoderBatchedStepEnsemble,
    EncoderEnsemble,
)


def test_batched_beam_from_checkpoints_vr(self):
    """Build encoder and batched-beam decoder-step ensembles from averaged
    checkpoints trained with vocabulary reduction, then run the full
    ensemble test."""
    check_dir = (
        '/mnt/gfsdataswarm-global/namespaces/search/language-technology-mt/'
        'nnmt_tmp/tl_XX-en_XX-pytorch-256-dim-vocab-reduction'
    )
    checkpoints = [
        'averaged_checkpoint_best_3.pt',
        'averaged_checkpoint_best_4.pt',
        'averaged_checkpoint_best_5.pt',
    ]
    checkpoint_filenames = [os.path.join(check_dir, f) for f in checkpoints]

    encoder_ensemble = EncoderEnsemble.build_from_checkpoints(
        checkpoint_filenames,
        os.path.join(check_dir, 'dictionary-tl.txt'),
        os.path.join(check_dir, 'dictionary-en.txt'),
    )

    decoder_step_ensemble = DecoderBatchedStepEnsemble.build_from_checkpoints(
        checkpoint_filenames,
        os.path.join(check_dir, 'dictionary-tl.txt'),
        os.path.join(check_dir, 'dictionary-en.txt'),
        beam_size=5,
    )

    self._test_full_ensemble(
        encoder_ensemble,
        decoder_step_ensemble,
        batched_beam=True,
    )
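

# A hedged helper (not in the original test): exercise the encoder ensemble
# on a dummy source sentence, mirroring the dummy-input trick used by
# export() below. The name _smoke_run_encoder is hypothetical.
def _smoke_run_encoder(encoder_ensemble):
    src_dict = encoder_ensemble.src_dict
    token_list = [src_dict.unk()] * 2 + [src_dict.eos()]
    src_tokens = torch.LongTensor(
        np.array(token_list, dtype="int64").reshape(-1, 1)
    )
    src_lengths = torch.IntTensor(np.array([len(token_list)], dtype="int32"))
    return encoder_ensemble(src_tokens, src_lengths)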
Example #2
def export(args):
    """Export encoder and decoder-step ensembles built from the checkpoints
    named in args.path, saving each to its requested output file."""
    assert_required_args_are_set(args)
    checkpoint_filenames = args.path.split(":")

    if args.char_source:
        encoder_class = CharSourceEncoderEnsemble
    else:
        encoder_class = EncoderEnsemble

    encoder_ensemble = encoder_class.build_from_checkpoints(
        checkpoint_filenames=checkpoint_filenames,
        src_dict_filename=args.source_vocab_file,
        dst_dict_filename=args.target_vocab_file,
    )
    if args.encoder_output_file != "":
        encoder_ensemble.save_to_db(args.encoder_output_file)

    if args.decoder_output_file != "":
        decoder_step_ensemble = DecoderBatchedStepEnsemble.build_from_checkpoints(
            checkpoint_filenames=checkpoint_filenames,
            src_dict_filename=args.source_vocab_file,
            dst_dict_filename=args.target_vocab_file,
            beam_size=args.beam_size,
            word_reward=args.word_reward,
            unk_reward=args.unk_reward,
        )

        # The decoder step ensemble needs example encoder outputs to pass
        # through the network when it is saved (source length 5 is arbitrary).
        src_dict = encoder_ensemble.src_dict
        token_list = [src_dict.unk()] * 4 + [src_dict.eos()]
        src_tokens = torch.LongTensor(
            np.array(token_list, dtype="int64").reshape(-1, 1)
        )
        src_lengths = torch.IntTensor(np.array([len(token_list)], dtype="int32"))
        if args.char_source:
            # Dummy character inputs: batch size 1, source length 5, and a
            # word length of 3 characters for every source word.
            char_inds = torch.LongTensor(np.ones((1, 5, 3), dtype="int64"))
            word_lengths = torch.LongTensor(np.array([3] * 5, dtype="int64")).reshape(
                1, 5
            )
            pytorch_encoder_outputs = encoder_ensemble(
                src_tokens, src_lengths, char_inds, word_lengths
            )
        else:
            pytorch_encoder_outputs = encoder_ensemble(src_tokens, src_lengths)

        decoder_step_ensemble.save_to_db(
            args.decoder_output_file, pytorch_encoder_outputs
        )
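

# --- Usage sketch (an assumption, not from the original source) ---
# A minimal wiring of export() to a command line. The flag names simply
# mirror the attributes that export() reads from `args`; the real CLI in
# the surrounding project may define them differently.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Export encoder/decoder ensembles built from checkpoints"
    )
    parser.add_argument("--path", required=True,
                        help="colon-separated list of checkpoint files")
    parser.add_argument("--source-vocab-file", required=True)
    parser.add_argument("--target-vocab-file", required=True)
    parser.add_argument("--encoder-output-file", default="")
    parser.add_argument("--decoder-output-file", default="")
    parser.add_argument("--beam-size", type=int, default=5)
    parser.add_argument("--word-reward", type=float, default=0.0)
    parser.add_argument("--unk-reward", type=float, default=0.0)
    parser.add_argument("--char-source", action="store_true")
    export(parser.parse_args())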