def test_batched_beam_from_checkpoints_vr(self):
    check_dir = (
        '/mnt/gfsdataswarm-global/namespaces/search/language-technology-mt/'
        'nnmt_tmp/tl_XX-en_XX-pytorch-256-dim-vocab-reduction'
    )
    checkpoints = [
        'averaged_checkpoint_best_3.pt',
        'averaged_checkpoint_best_4.pt',
        'averaged_checkpoint_best_5.pt',
    ]
    checkpoint_filenames = [os.path.join(check_dir, f) for f in checkpoints]

    encoder_ensemble = EncoderEnsemble.build_from_checkpoints(
        checkpoint_filenames,
        os.path.join(check_dir, 'dictionary-tl.txt'),
        os.path.join(check_dir, 'dictionary-en.txt'),
    )

    decoder_step_ensemble = DecoderBatchedStepEnsemble.build_from_checkpoints(
        checkpoint_filenames,
        os.path.join(check_dir, 'dictionary-tl.txt'),
        os.path.join(check_dir, 'dictionary-en.txt'),
        beam_size=5,
    )

    self._test_full_ensemble(
        encoder_ensemble,
        decoder_step_ensemble,
        batched_beam=True,
    )
def export(args):
    assert_required_args_are_set(args)
    checkpoint_filenames = args.path.split(":")

    if args.char_source:
        encoder_class = CharSourceEncoderEnsemble
    else:
        encoder_class = EncoderEnsemble

    encoder_ensemble = encoder_class.build_from_checkpoints(
        checkpoint_filenames=checkpoint_filenames,
        src_dict_filename=args.source_vocab_file,
        dst_dict_filename=args.target_vocab_file,
    )
    if args.encoder_output_file != "":
        encoder_ensemble.save_to_db(args.encoder_output_file)

    if args.decoder_output_file != "":
        decoder_step_ensemble = DecoderBatchedStepEnsemble.build_from_checkpoints(
            checkpoint_filenames=checkpoint_filenames,
            src_dict_filename=args.source_vocab_file,
            dst_dict_filename=args.target_vocab_file,
            beam_size=args.beam_size,
            word_reward=args.word_reward,
            unk_reward=args.unk_reward,
        )
        # need example encoder outputs to pass through network
        # (source length 5 is arbitrary)
        src_dict = encoder_ensemble.src_dict
        token_list = [src_dict.unk()] * 4 + [src_dict.eos()]
        src_tokens = torch.LongTensor(
            np.array(token_list, dtype="int64").reshape(-1, 1)
        )
        src_lengths = torch.IntTensor(np.array([len(token_list)], dtype="int32"))
        if args.char_source:
            char_inds = torch.LongTensor(np.ones((1, 5, 3), dtype="int64"))
            word_lengths = torch.LongTensor(
                np.array([3] * 5, dtype="int64")
            ).reshape(1, 5)
            pytorch_encoder_outputs = encoder_ensemble(
                src_tokens, src_lengths, char_inds, word_lengths
            )
        else:
            pytorch_encoder_outputs = encoder_ensemble(src_tokens, src_lengths)
        decoder_step_ensemble.save_to_db(
            args.decoder_output_file, pytorch_encoder_outputs
        )
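
# Illustrative usage sketch (not part of the original file): export() only reads
# the attributes referenced above, so it can be driven by a plain argparse.Namespace.
# All paths, checkpoint names, and reward values below are hypothetical placeholders.
if __name__ == "__main__":
    from argparse import Namespace

    example_args = Namespace(
        path="model1.pt:model2.pt",  # ":"-separated checkpoint list
        char_source=False,  # True would select CharSourceEncoderEnsemble
        source_vocab_file="dictionary-src.txt",
        target_vocab_file="dictionary-tgt.txt",
        encoder_output_file="encoder.predictor",  # "" skips encoder export
        decoder_output_file="decoder.predictor",  # "" skips decoder export
        beam_size=6,
        word_reward=0.25,
        unk_reward=-0.25,
    )
    export(example_args)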