コード例 #1
0
    def _test_forced_decoder_export(self, test_args):
        """Build a 3-model ForcedDecoder ensemble and call its onnx_export.

        The export is written into a temporary directory that is removed
        (together with the exported file) as soon as the export returns,
        so repeated runs do not leak directories in the system temp area.

        Args:
            test_args: arguments forwarded to ``test_utils.prepare_inputs``
                and ``models.build_model``.
        """
        _, src_dict, tgt_dict = test_utils.prepare_inputs(test_args)

        num_models = 3
        model_list = [
            models.build_model(test_args, src_dict, tgt_dict)
            for _ in range(num_models)
        ]

        forced_decoder_ensemble = ForcedDecoder(
            model_list,
            word_reward=0.25,
            unk_reward=-0.5,
        )

        # TemporaryDirectory (unlike mkdtemp) deletes the directory and its
        # contents on exit -- including when onnx_export raises.
        with tempfile.TemporaryDirectory() as tmp_dir:
            forced_decoder_pb_path = os.path.join(tmp_dir, "forced_decoder.pb")
            forced_decoder_ensemble.onnx_export(forced_decoder_pb_path)
コード例 #2
0
def main():
    """Export pytorch_translate checkpoints to a caffe2 forced decoder.

    Parses command-line arguments, builds a ``ForcedDecoder`` from the given
    checkpoints and dictionaries, and writes it with ``save_to_db`` to
    ``--output_file``. If ``--output_file`` is not given, prints usage and
    returns without doing anything.

    ``--checkpoint`` may be repeated and may take several filenames per
    occurrence (e.g. ``--checkpoint a.pt b.pt --checkpoint c.pt``); every
    filename is used. If no checkpoint is supplied, exits via
    ``parser.error`` (exit status 2).
    """
    parser = argparse.ArgumentParser(
        description="Export pytorch_translate models to caffe2 forced decoder"
    )
    parser.add_argument(
        "--checkpoint",
        action="append",
        nargs="+",
        help="PyTorch checkpoint file (at least one required)",
    )
    parser.add_argument(
        "--output_file",
        default="",
        help="File name to which to save forced decoder network",
    )
    parser.add_argument(
        "--src_dict",
        required=True,
        help="File encoding PyTorch dictionary for source language",
    )
    parser.add_argument(
        "--dst_dict",
        required=True,
        help="File encoding PyTorch dictionary for target language",
    )
    parser.add_argument(
        "--word_reward",
        type=float,
        default=0.0,
        help="Value to add for each word (besides EOS)",
    )
    parser.add_argument(
        "--unk_reward",
        type=float,
        default=0.0,
        help="Value to add for each UNK token",
    )

    args = parser.parse_args()

    if not args.output_file:
        print("No action taken. Need output_file to be specified.")
        parser.print_help()
        return

    # action="append" + nargs="+" produces a list of lists (one inner list per
    # --checkpoint occurrence), or None when the flag is absent. Flatten every
    # group so filenames after the first in a group are not silently dropped.
    checkpoint_filenames = [
        filename
        for group in (args.checkpoint or [])
        for filename in group
    ]
    if not checkpoint_filenames:
        # Enforce the "at least one required" contract with a usage error
        # instead of crashing on None further down.
        parser.error("at least one --checkpoint is required")

    forced_decoder = ForcedDecoder.build_from_checkpoints(
        checkpoint_filenames=checkpoint_filenames,
        src_dict_filename=args.src_dict,
        dst_dict_filename=args.dst_dict,
        word_reward=args.word_reward,
        unk_reward=args.unk_reward,
    )
    forced_decoder.save_to_db(args.output_file)