def test_batched_beam_from_checkpoints_vr(self):
        """Full-ensemble check for a batched-beam decoder step.

        Builds an encoder ensemble and a batched decoder-step ensemble
        (beam size 5) from three averaged checkpoints. The checkpoints live
        in the ``...-vocab-reduction`` model directory. The pair is then
        handed to ``_test_full_ensemble`` with ``batched_beam=True``.

        Requires the hard-coded /mnt/gfsdataswarm-global directory to be
        readable.
        """
        model_dir = (
            '/mnt/gfsdataswarm-global/namespaces/search/language-technology-mt/'
            'nnmt_tmp/tl_XX-en_XX-pytorch-256-dim-vocab-reduction'
        )
        src_dict_path = os.path.join(model_dir, 'dictionary-tl.txt')
        dst_dict_path = os.path.join(model_dir, 'dictionary-en.txt')
        model_paths = [
            os.path.join(model_dir, name)
            for name in (
                'averaged_checkpoint_best_3.pt',
                'averaged_checkpoint_best_4.pt',
                'averaged_checkpoint_best_5.pt',
            )
        ]

        encoder = EncoderEnsemble.build_from_checkpoints(
            checkpoint_filenames=model_paths,
            src_dict_filename=src_dict_path,
            dst_dict_filename=dst_dict_path,
        )
        decoder_step = DecoderBatchedStepEnsemble.build_from_checkpoints(
            checkpoint_filenames=model_paths,
            src_dict_filename=src_dict_path,
            dst_dict_filename=dst_dict_path,
            beam_size=5,
        )

        self._test_full_ensemble(encoder, decoder_step, batched_beam=True)

    def test_export_encoder_from_checkpoints_no_vr(self):
        """Export check for an encoder-only ensemble.

        Builds an encoder ensemble from two averaged checkpoints in the
        ``...-no-vocab-reduction-2`` model directory. It then runs
        ``_test_ensemble_encoder_object_export`` on that ensemble.

        Requires the hard-coded /mnt/gfsdataswarm-global directory to be
        readable.
        """
        model_dir = (
            '/mnt/gfsdataswarm-global/namespaces/search/language-technology-mt/'
            'nnmt_tmp/tl_XX-en_XX-pytorch-testing-no-vocab-reduction-2'
        )
        model_paths = [
            os.path.join(model_dir, name)
            for name in (
                'averaged_checkpoint_best_3.pt',
                'averaged_checkpoint_best_4.pt',
            )
        ]

        encoder = EncoderEnsemble.build_from_checkpoints(
            checkpoint_filenames=model_paths,
            src_dict_filename=os.path.join(model_dir, 'dictionary-tl.txt'),
            dst_dict_filename=os.path.join(model_dir, 'dictionary-en.txt'),
        )
        self._test_ensemble_encoder_object_export(encoder)
def export(args):
    """Export an ensemble of PyTorch checkpoints as Caffe2 networks.

    Reads from ``args``:

    - ``path``: a sequence of entries; only the first element of each entry
      is used as a checkpoint filename.
    - ``source_vocab_file`` / ``target_vocab_file``: dictionary files.
    - ``encoder_output_file``: when not ``''``, the encoder ensemble is
      saved there.
    - ``decoder_output_file``: when not ``''``, a decoder-step ensemble is
      built and saved there.
    - ``batched_beam``: selects ``DecoderBatchedStepEnsemble`` instead of
      ``DecoderStepEnsemble``.
    - ``beam_size``, ``word_penalty``, ``unk_penalty``: forwarded to the
      decoder-step builder.

    NOTE(review): if ``args.path`` is parsed with ``nargs='+'``, any extra
    filenames inside one entry are ignored here; confirm that is intended.
    """
    assert_required_args_are_set(args)
    model_paths = [entry[0] for entry in args.path]
    src_vocab = args.source_vocab_file
    dst_vocab = args.target_vocab_file

    encoder = EncoderEnsemble.build_from_checkpoints(
        checkpoint_filenames=model_paths,
        src_dict_filename=src_vocab,
        dst_dict_filename=dst_vocab,
    )
    if args.encoder_output_file != '':
        encoder.save_to_db(args.encoder_output_file)

    # Nothing more to do unless a decoder export was requested.
    if args.decoder_output_file == '':
        return

    step_cls = (
        DecoderBatchedStepEnsemble if args.batched_beam else DecoderStepEnsemble
    )
    decoder_step = step_cls.build_from_checkpoints(
        checkpoint_filenames=model_paths,
        src_dict_filename=src_vocab,
        dst_dict_filename=dst_vocab,
        beam_size=args.beam_size,
        word_penalty=args.word_penalty,
        unk_penalty=args.unk_penalty,
    )

    # save_to_db needs sample encoder outputs to trace the decoder step.
    # The dummy source is four UNK tokens followed by EOS; the length of 5
    # is arbitrary.
    src_dict = encoder.models[0].src_dict
    dummy_tokens = [src_dict.unk()] * 4
    dummy_tokens.append(src_dict.eos())
    token_column = np.array(dummy_tokens, dtype='int64').reshape(-1, 1)
    src_tokens = torch.LongTensor(token_column)
    src_lengths = torch.IntTensor(
        np.array([len(dummy_tokens)], dtype='int32'))
    sample_encoder_outputs = encoder(src_tokens, src_lengths)

    decoder_step.save_to_db(args.decoder_output_file, sample_encoder_outputs)
def main():
    """Command-line entry point: export PyTorch checkpoints to Caffe2 nets.

    Parses command-line flags and builds an encoder ensemble from every
    ``--checkpoint`` file. It then writes the encoder and/or decoder-step
    ensemble to the requested output files.

    If neither ``--encoder_output_file`` nor ``--decoder_output_file`` is
    given, it prints a message plus the usage help and returns without
    loading anything. Otherwise, when ``--checkpoint`` was never supplied,
    it exits with status 2 via ``parser.error``.

    ``--checkpoint`` may be repeated and may take several filenames at once.
    ``--checkpoint a b`` and ``--checkpoint a --checkpoint b`` are
    equivalent.
    """
    parser = argparse.ArgumentParser(description=(
        'Export PyTorch-trained FBTranslate models to Caffe2 components'), )
    parser.add_argument(
        '--checkpoint',
        action='append',
        nargs='+',
        help='PyTorch checkpoint file (at least one required)',
    )
    parser.add_argument(
        '--encoder_output_file',
        default='',
        help='File name to which to save encoder ensemble network',
    )
    parser.add_argument(
        '--decoder_output_file',
        default='',
        help='File name to which to save decoder step ensemble network',
    )
    parser.add_argument(
        '--src_dict',
        required=True,
        help='File encoding PyTorch dictionary for source language',
    )
    parser.add_argument(
        '--dst_dict',
        required=True,
        help='File encoding PyTorch dictionary for target language',
    )
    parser.add_argument(
        '--beam_size',
        type=int,
        default=6,
        help='Number of top candidates returned by each decoder step',
    )
    parser.add_argument(
        '--word_penalty',
        type=float,
        default=0.0,
        help='Value to add for each word (besides EOS)',
    )
    parser.add_argument(
        '--unk_penalty',
        type=float,
        default=0.0,
        help='Value to add for each UNK token',
    )
    parser.add_argument(
        '--batched_beam',
        action='store_true',
        help='Decoder step has entire beam as input/output',
    )

    args = parser.parse_args()

    if args.encoder_output_file == args.decoder_output_file == '':
        print('No action taken. Need at least one of --encoder_output_file '
              'and --decoder_output_file.')
        parser.print_help()
        return

    # action='append' leaves args.checkpoint as None when the flag is
    # absent; fail with a usage error rather than a TypeError below.
    if not args.checkpoint:
        parser.error('at least one --checkpoint is required')
    # action='append' + nargs='+' produces a list of lists. Flatten it so
    # that `--checkpoint a b` uses both files instead of dropping `b`.
    checkpoint_filenames = [
        filename for group in args.checkpoint for filename in group
    ]

    # The encoder ensemble is always built: the decoder export below needs
    # sample encoder outputs even when no encoder file is requested.
    encoder_ensemble = EncoderEnsemble.build_from_checkpoints(
        checkpoint_filenames=checkpoint_filenames,
        src_dict_filename=args.src_dict,
        dst_dict_filename=args.dst_dict,
    )
    if args.encoder_output_file != '':
        encoder_ensemble.save_to_db(args.encoder_output_file)

    if args.decoder_output_file != '':
        if args.batched_beam:
            decoder_step_class = DecoderBatchedStepEnsemble
        else:
            decoder_step_class = DecoderStepEnsemble
        decoder_step_ensemble = decoder_step_class.build_from_checkpoints(
            checkpoint_filenames=checkpoint_filenames,
            src_dict_filename=args.src_dict,
            dst_dict_filename=args.dst_dict,
            beam_size=args.beam_size,
            word_penalty=args.word_penalty,
            unk_penalty=args.unk_penalty,
        )

        # need example encoder outputs to pass through network
        # (source length 5 is arbitrary: four UNKs followed by EOS,
        # laid out as a (5, 1) column)
        src_dict = encoder_ensemble.models[0].src_dict
        token_list = [src_dict.unk()] * 4 + [src_dict.eos()]
        src_tokens = torch.LongTensor(
            np.array(token_list, dtype='int64').reshape(-1, 1), )
        src_lengths = torch.IntTensor(
            np.array([len(token_list)], dtype='int32'), )
        pytorch_encoder_outputs = encoder_ensemble(src_tokens, src_lengths)

        decoder_step_ensemble.save_to_db(
            args.decoder_output_file,
            pytorch_encoder_outputs,
        )