示例#1
0
    # Register the shared, task-specific, trainer (inference) and logging
    # command-line options. `parser` is created outside this visible fragment.
    add_general_options(parser)
    task_class = add_task_option(parser)
    task_class.add_options(parser)
    trainer_class = add_trainer_option(parser)
    trainer_class.add_inference_options(parser)
    add_log_options(parser)

    # -load_from takes exactly one path (no nargs='+'): its value is handed
    # directly to load_checkpoint() below, so the help text says "a" model.
    parser.add_argument('-load_from',
                        type=str,
                        required=True,
                        help='Path to a pretrained model.')
    parser.add_argument('-output', help="Path to output the predictions")

    args = parser.parse_args()

    logger = setup_logging_from_args(args, 'evaluate')

    logger.debug('Torch version: {}'.format(torch.__version__))
    logger.debug(args)

    # Seed both the torch and numpy RNGs from -seed (presumably for
    # reproducible runs).
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    task = task_class.setup_task(args)  # type: Task

    logger.info('Loading checkpoint {}'.format(args.load_from))
    checkpoint = load_checkpoint(args.load_from)

    # Construct the trainer with for_training=False, passing in the checkpoint
    # that was just loaded, then let it solve the task.
    trainer = trainer_class(args, for_training=False, checkpoint=checkpoint)

    results = trainer.solve(task)
示例#2
0
    # Build the command-line interface from the general, task-specific,
    # trainer evaluation and logging option groups.
    # NOTE(review): description says "train.py" although this fragment sets up
    # an 'evaluate' run; it looks copy-pasted. Confirm the intended script name.
    parser = argparse.ArgumentParser(description="train.py")
    add_general_options(parser)
    task_class = add_task_option(parser)
    task_class.add_options(parser)
    trainer_class = add_trainer_option(parser)
    trainer_class.add_eval_options(parser)
    add_log_options(parser)

    # nargs='+' lets the user pass several checkpoint files; each is loaded below.
    parser.add_argument('-load_from', type=str, nargs='+', required=True,
                        help='Path to one or more pretrained models.')
    # -output is parsed here but not read within the visible lines; presumably
    # it is used after this fragment. Verify.
    parser.add_argument('-output',
                        help="Path to output the predictions")

    args = parser.parse_args()

    logger = custom_logging.setup_logging_from_args(args, 'evaluate')

    logger.debug('Torch version: {}'.format(torch.__version__))
    logger.debug(args)

    # Only torch's RNG is seeded here.
    # NOTE(review): sibling scripts also call np.random.seed(args.seed). Confirm
    # whether numpy randomness matters for this path.
    torch.manual_seed(args.seed)

    task = task_class.setup_task(args)  # type: Task

    trainer = trainer_class(args)  # type: Trainer

    # Load every requested checkpoint into a model via the trainer and collect
    # the models in command-line order.
    models = []
    for filename in args.load_from:
        logger.info('Loading checkpoint {}'.format(filename))
        models.append(trainer.load_checkpoint(convert.load_checkpoint(filename)))
示例#3
0
文件: train.py 项目: zqs01/NMTGMinor
    # Register the shared, task-specific, trainer (training) and logging
    # command-line options. `parser` is created outside this visible fragment.
    add_general_options(parser)
    task_class = add_task_option(parser)
    task_class.add_options(parser)
    trainer_class = add_trainer_option(parser)
    trainer_class.add_training_options(parser)
    add_log_options(parser)

    # Adjacent string literals are concatenated with no separator, so the
    # first piece must end in a space or the help reads "...thepath...".
    parser.add_argument('-load_from',
                        type=str,
                        help='If training from a checkpoint then this is the '
                        'path to the pretrained model.')
    # NOTE(review): -reset_optim has no help text; by its name it presumably
    # discards saved optimizer state when resuming. Confirm, then document it.
    parser.add_argument('-reset_optim', action='store_true')

    args = parser.parse_args()

    logger = setup_logging_from_args(args, 'train')

    logger.debug('Torch version: {}'.format(torch.__version__))
    logger.debug(args)

    # Seed both the torch and numpy RNGs from -seed.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    task = task_class.setup_task(args)  # type: Task

    # Resume from a checkpoint only when -load_from was given; otherwise
    # checkpoint is None (training starts without one).
    if args.load_from is not None:
        logger.info("Loading checkpoint {}".format(args.load_from))
        checkpoint = load_checkpoint(args.load_from)
    else:
        checkpoint = None
示例#4
0
    # Build the command-line interface from the general, task-specific,
    # trainer evaluation and logging option groups.
    # NOTE(review): description says "train.py" although this fragment sets up
    # a 'validate' run; it looks copy-pasted. Confirm the intended script name.
    parser = argparse.ArgumentParser(description="train.py")
    add_general_options(parser)
    task_class = add_task_option(parser)
    task_class.add_options(parser)
    trainer_class = add_trainer_option(parser)
    trainer_class.add_eval_options(parser)
    add_log_options(parser)

    # Exactly one checkpoint path (no nargs).
    parser.add_argument('-load_from',
                        type=str,
                        required=True,
                        help='Path to a pretrained model.')
    # -output is parsed but not read within the visible lines. Verify its use.
    parser.add_argument('-output', help="Path to output the predictions")

    args = parser.parse_args()

    logger = custom_logging.setup_logging_from_args(args, 'validate')

    logger.debug('Torch version: {}'.format(torch.__version__))
    logger.debug(args)

    # Only torch's RNG is seeded here (sibling scripts also seed numpy).
    torch.manual_seed(args.seed)

    task = task_class.setup_task(args)  # type: Task

    trainer = trainer_class(args)  # type: Trainer

    # Read the checkpoint file, turn it into a model via the trainer, then
    # evaluate that model on the task.
    model = trainer.load_checkpoint(convert.load_checkpoint(args.load_from))

    val_loss = trainer.evaluate(model, task)
示例#5
0
import argparse
import torch

from nmtg import custom_logging
from nmtg.average_checkpoints import average_checkpoints

def _make_parser():
    """Build the command-line interface for checkpoint averaging.

    Positional ``checkpoints`` lists the files to combine, ``-output`` names
    the file the combined checkpoint is saved to, and ``-method`` selects
    'mean' or 'gmean'. Logging flags are added by ``custom_logging``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('checkpoints', nargs='+',
                     help='Which checkpoints to average')
    cli.add_argument('-output', type=str, required=True,
                     help='Output filename')
    cli.add_argument('-method', choices=['mean', 'gmean'], default='mean',
                     help='Method of averaging')
    custom_logging.add_log_options(cli)
    return cli


def _main():
    """Parse arguments, average the checkpoints and torch.save the result."""
    options = _make_parser().parse_args()
    log = custom_logging.setup_logging_from_args(options,
                                                 'average_checkpoints.py')

    averaged = average_checkpoints(options.checkpoints, options.method)

    log.info('Saving checkpoint to {}'.format(options.output))
    torch.save(averaged, options.output)


if __name__ == '__main__':
    _main()