def _test_forced_decoder_export(self, test_args):
    """Build a 3-model ForcedDecoder ensemble and export it via ONNX.

    Args:
        test_args: argument namespace passed to ``test_utils.prepare_inputs``
            and ``models.build_model`` to construct the dictionaries and models.

    The exported file is written into a temporary directory that is removed
    when the export finishes, so repeated runs do not leave files behind.
    """
    _, src_dict, tgt_dict = test_utils.prepare_inputs(test_args)

    num_models = 3
    model_list = [
        models.build_model(test_args, src_dict, tgt_dict) for _ in range(num_models)
    ]
    forced_decoder_ensemble = ForcedDecoder(
        model_list, word_reward=0.25, unk_reward=-0.5
    )

    # TemporaryDirectory cleans up the exported .pb on exit; the previous
    # mkdtemp() call left the directory (and file) on disk after every run.
    with tempfile.TemporaryDirectory() as tmp_dir:
        forced_decoder_pb_path = os.path.join(tmp_dir, "forced_decoder.pb")
        forced_decoder_ensemble.onnx_export(forced_decoder_pb_path)
def main(): parser = argparse.ArgumentParser( description="Export pytorch_translate models to caffe2 forced decoder" ) parser.add_argument( "--checkpoint", action="append", nargs="+", help="PyTorch checkpoint file (at least one required)", ) parser.add_argument( "--output_file", default="", help="File name to which to save forced decoder network", ) parser.add_argument( "--src_dict", required=True, help="File encoding PyTorch dictionary for source language", ) parser.add_argument( "--dst_dict", required=True, help="File encoding PyTorch dictionary for source language", ) parser.add_argument( "--word_reward", type=float, default=0.0, help="Value to add for each word (besides EOS)", ) parser.add_argument( "--unk_reward", type=float, default=0.0, help="Value to add for each word UNK token", ) args = parser.parse_args() if args.output_file == "": print("No action taken. Need output_file to be specified.") parser.print_help() return checkpoint_filenames = [arg[0] for arg in args.checkpoint] forced_decoder = ForcedDecoder.build_from_checkpoints( checkpoint_filenames=checkpoint_filenames, src_dict_filename=args.src_dict, dst_dict_filename=args.dst_dict, word_reward=args.word_reward, unk_reward=args.unk_reward, ) forced_decoder.save_to_db(args.output_file)