Ejemplo n.º 1
0
def _get_parser():
    """Build the (unparsed) command-line parser for ``train.py``.

    A fresh ``ArgumentParser`` is created and ``opts.config_opts``,
    ``opts.model_opts`` and ``opts.train_opts`` are applied to it, in that
    order.
    """
    parser = ArgumentParser(description='train.py')
    for register in (opts.config_opts, opts.model_opts, opts.train_opts):
        register(parser)
    return parser
Ejemplo n.º 2
0
def _get_parser():
    """Create the argument parser used by ``build_copy_transformer.py``.

    The config, model and training option helpers from ``opts`` are applied
    one after another; the parser is returned without parsing ``sys.argv``.
    """
    cli = ArgumentParser(description='build_copy_transformer.py')
    option_helpers = [opts.config_opts, opts.model_opts, opts.train_opts]
    for add_options in option_helpers:
        add_options(cli)
    return cli
Ejemplo n.º 3
0
def _get_parser(data_root='F:/Project/Python/selfProject/translate_NMT/data'):
    """Build the ``train.py`` parser with local overrides for a few options.

    The standard ``opts`` config/model/train helpers are applied first.
    Then ``-data``, ``-save_model``, ``-save_checkpoint_steps``,
    ``-train_from`` and ``-train_steps`` are declared again with defaults
    that let the script start without any command-line arguments.

    Args:
        data_root: directory from which the defaults are derived.
            ``-data`` defaults to ``<data_root>/demo`` and ``-save_model``
            defaults to ``<data_root>``. The default value reproduces the
            previously hard-coded location, so existing callers are
            unaffected. Pass another path to run on a different machine.

    Returns:
        The configured, not yet parsed, parser.
    """
    # The ``opts`` helpers appear to register -data, -save_model and
    # -train_steps already (a parser built only from them accepts those
    # flags). Under argparse's default conflict_handler='error', declaring
    # the same option string twice raises ArgumentError. 'resolve' makes
    # the declarations below replace the earlier ones instead.
    parser = ArgumentParser(description='train.py', conflict_handler='resolve')

    opts.config_opts(parser)
    opts.model_opts(parser)
    opts.train_opts(parser)

    parser.add('--data', '-data', required=False,
               default=data_root + '/demo',
               help='Path prefix to the ".train.pt" and '
                    '".valid.pt" file path from preprocess.py')

    parser.add('--save_model', '-save_model', required=False,
               default=data_root,
               help="Model filename (the model will be saved as "
                    "<save_model>_N.pt where N is the number "
                    "of steps)")

    parser.add('--save_checkpoint_steps', '-save_checkpoint_steps',
               type=int, default=500,
               help="Save a checkpoint every X steps")

    parser.add('--train_from', '-train_from',
               default='',
               type=str,
               help="If training from a checkpoint then this is the "
                    "path to the pretrained model's state_dict.")

    # Help text reads "how many steps to train".
    parser.add('--train_steps', '-train_steps', type=int, default=100000,
               help='训练多少步')
    return parser
Ejemplo n.º 4
0
def _get_parser():
    parser = ArgumentParser(description='train.py')
    parser.add_argument('--teacher_model_path',
                        action='store',
                        dest='teacher_model_path',
                        help='the path direct to the teacher model path')
    parser.add_argument("--word_sampling",
                        action="store",
                        default=False,
                        help="optional arg")

    opts.config_opts(parser)
    opts.model_opts(parser)
    opts.train_opts(parser)
    return parser
Ejemplo n.º 5
0
def parse_args():
    """Parse ``train.py`` options from the command line and YAML config files.

    Builds a ``configargparse`` parser that reads YAML config files and shows
    defaults in ``--help``. It applies the general, config, markdown-help,
    model and training option helpers from ``opts``, in that order.

    Returns:
        The result of ``parser.parse_args()``.
    """
    parser = configargparse.ArgumentParser(
        description='train.py',
        formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
        config_file_parser_class=configargparse.YAMLConfigFileParser,
    )
    registrars = (
        opts.general_opts,
        opts.config_opts,
        opts.add_md_help_argument,
        opts.model_opts,
        opts.train_opts,
    )
    for registrar in registrars:
        registrar(parser)
    return parser.parse_args()
Ejemplo n.º 6
0
def _get_parser():
    """Build the ``train.py`` parser plus an "extended opts" argument group.

    After the standard ``opts`` config/model/train helpers, this adds
    ``-pretrained_encoder``: a string restricted to bert/roberta/xlnet,
    defaulting to "bert".
    """
    parser = ArgumentParser(description='train.py')
    for add_standard_opts in (opts.config_opts, opts.model_opts, opts.train_opts):
        add_standard_opts(parser)

    # Extended options for pretrained language models.
    extended = parser.add_argument_group("extended opts")
    extended.add(
        '--pretrained_encoder', '-pretrained_encoder',
        type=str,
        default="bert",
        choices=["bert", "roberta", "xlnet"],
        help="choose a pretrained language model as encoder",
    )
    return parser
            'opt': self.model_opt,
            'optim': self.optim,
        }

        logger.info("Saving checkpoint %s_step_%d.pt" % (self.base_path, step))
        checkpoint_path = '%s_step_%d.pt' % (self.base_path, step)
        torch.save(checkpoint, checkpoint_path)
        return checkpoint, checkpoint_path

    def _rm_checkpoint(self, name):
        """Delete a saved checkpoint.

        Args:
            name (str): identifies the checkpoint. It is handed directly to
                ``os.remove``, so it must be a filesystem path.
        """
        checkpoint_file = name
        os.remove(checkpoint_file)


if __name__ == "__main__":
    # Script entry point: register model/training options, parse argv, train.
    # `parser` and `opt` stay module-level globals, as before.
    parser = configargparse.ArgumentParser(
        description='train.py',
        formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
    )
    for register_options in (opts.model_opts, opts.train_opts):
        register_options(parser)
    opt = parser.parse_args()
    main(opt)
Ejemplo n.º 8
0
import argparse

import torch

import onmt.opts as opts
from train_multi import main as multi_main
from train_single import main as single_main


def main(opt):
    """Validate GPU-related options and dispatch to single/multi-GPU training.

    Args:
        opt: parsed options. Reads ``opt.rnn_type`` and ``opt.gpuid``,
            a sequence of GPU ids that may be empty. ``None`` is treated
            like an empty sequence.

    Raises:
        AssertionError: if ``opt.rnn_type`` is "SRU" and no GPU id is given.
    """
    # Normalise a missing/None gpuid to an empty list. The falsy checks
    # below already accepted it, but len(None) would raise TypeError.
    gpu_ids = opt.gpuid or []

    if opt.rnn_type == "SRU" and not gpu_ids:
        raise AssertionError("Using SRU requires -gpuid set.")

    if torch.cuda.is_available() and not gpu_ids:
        print("WARNING: You have a CUDA device, should run with -gpuid 0")

    # Several GPU ids go to train_multi's main; otherwise train_single's.
    if len(gpu_ids) > 1:
        multi_main(opt)
    else:
        single_main(opt)


if __name__ == "__main__":
    # Script entry point: build the CLI, parse argv and hand off to main().
    # PARSER and OPT remain module-level globals, as before.
    PARSER = argparse.ArgumentParser(
        description='train.py',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    for _add_options in (opts.add_md_help_argument,
                         opts.model_opts,
                         opts.train_opts):
        _add_options(PARSER)
    OPT = PARSER.parse_args()
    main(OPT)
Ejemplo n.º 9
0
def _get_parser():
    """Create the ``train.py`` argument parser with only the training options.

    ``train_opts`` is applied to a fresh ``ArgumentParser``. The parser is
    returned unparsed.
    """
    cli_parser = ArgumentParser(description='train.py')
    train_opts(cli_parser)
    return cli_parser
Ejemplo n.º 10
0
        os.path.join(temp, "data", "train_source.txt"), "-train_tgt",
        os.path.join(temp, "data", "train_target.txt"), "-valid_src",
        os.path.join(temp, "data", "dev_target.txt"), "-valid_tgt",
        os.path.join(temp, "data", "dev_target.txt"), "-save_data",
        os.path.join(temp, "data", "out")
    ])
    preproc_args.shuffle = 0
    preproc_args.src_seq_length = source_max
    preproc_args.tgt_seq_length = target_max

    train_parser = argparse.ArgumentParser(
        description='vivisect example',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    opts.add_md_help_argument(train_parser)
    opts.model_opts(train_parser)
    opts.train_opts(train_parser)
    train_args = train_parser.parse_args([
        "-data",
        os.path.join(temp, "data/out"),
        "-train_steps",
        str(args.epochs - 1),
        "-save_model",
        os.path.join(temp, "model"),
        "-enc_layers",
        "3",
        "-dec_layers",
        "3",
        "-rnn_size",
        "50",
        "-src_word_vec_size",
        "25",