Code example #1
def cli_main():
    parser = options.get_validation_parser()
    args = options.parse_args_and_arch(parser)
    # only override args that are explicitly given on the command line
    override_parser = options.get_validation_parser()
    override_args = options.parse_args_and_arch(override_parser,
                                                suppress_defaults=True)
    main(args, override_args)
Code example #2
File: validate.py Project: kts/fairseq-1
def cli_main():
    parser = options.get_validation_parser()
    args = options.parse_args_and_arch(parser)

    # only override args that are explicitly given on the command line
    override_parser = options.get_validation_parser()
    override_args = options.parse_args_and_arch(override_parser, suppress_defaults=True)

    distributed_utils.call_main(convert_namespace_to_omegaconf(args), main, override_args=override_args)
Code example #3
def cli_main():
    parser = options.get_validation_parser()
    add_distributed_training_args(parser)
    args = options.parse_args_and_arch(parser)

    # only override args that are explicitly given on the command line
    override_parser = options.get_validation_parser()
    add_distributed_training_args(override_parser)
    override_args = options.parse_args_and_arch(override_parser, suppress_defaults=True)

    distributed_utils.call_main(args, main, override_args=override_args)
Code example #4
File: validate.py Project: Michiel29/graphqa
def cli_main():
    parser = options.get_validation_parser()
    parser.add_argument(
        '--config',
        type=str,
        nargs='*',
        help='paths to JSON files of experiment configurations, from high to low priority')
    parser.add_argument(
        '--load-checkpoint',
        type=str,
        help='path to checkpoint to load (possibly composite) model from')

    pre_parsed_args = parser.parse_args()

    config_dict = {}
    for config_path in pre_parsed_args.config:
        config_dict = update_config(config_dict, compose_configs(config_path))

    parser_modifier = modify_factory(config_dict)

    args = options.parse_args_and_arch(parser, modify_parser=parser_modifier)

    update_namespace(args, config_dict)

    main(args)
Code example #5
def train_translation_model(data_dir, arch, extra_flags=None, task='translation'):
    train_parser = options.get_training_parser()
    train_args = options.parse_args_and_arch(
        train_parser,
        [
            '--task', task,
            data_dir,
            '--save-dir', data_dir,
            '--arch', arch,
            '--lr', '0.05',
            '--max-tokens', '500',
            '--max-epoch', '1',
            '--no-progress-bar',
            '--distributed-world-size', '1',
            '--source-lang', 'in',
            '--target-lang', 'out',
        ] + (extra_flags or []),
    )
    train.main(train_args)

    # test validation
    validate_parser = options.get_validation_parser()
    validate_args = options.parse_args_and_arch(
        validate_parser,
        [
            '--task', task,
            data_dir,
            '--path', os.path.join(data_dir, 'checkpoint_last.pt'),
            '--valid-subset', 'valid',
            '--max-tokens', '500',
            '--no-progress-bar',
        ]
    )
    validate.main(validate_args)
Code example #6
File: validate.py Project: zvict/NSVF
def cli_main():
    parser = options.get_validation_parser()
    args = options.parse_args_and_arch(parser)

    # only override args that are explicitly given on the command line
    override_parser = options.get_validation_parser()
    override_args = options.parse_args_and_arch(override_parser,
                                                suppress_defaults=True)

    # support multi-gpu validation, use all available gpus
    default_world_size = max(1, torch.cuda.device_count())
    if args.distributed_world_size < default_world_size:
        args.distributed_world_size = default_world_size
        override_args.distributed_world_size = default_world_size

    distributed_utils.call_main(args, main, override_args=override_args)
Code example #7
def train_language_model(data_dir,
                         arch,
                         extra_flags=None,
                         run_validation=False):
    train_parser = options.get_training_parser()
    train_args = options.parse_args_and_arch(
        train_parser,
        [
            "--task",
            "language_modeling",
            data_dir,
            "--arch",
            arch,
            "--optimizer",
            "adam",
            "--lr",
            "0.0001",
            "--criterion",
            "adaptive_loss",
            "--adaptive-softmax-cutoff",
            "5,10,15",
            "--max-tokens",
            "500",
            "--tokens-per-sample",
            "500",
            "--save-dir",
            data_dir,
            "--max-epoch",
            "1",
            "--no-progress-bar",
            "--distributed-world-size",
            "1",
            "--ddp-backend",
            "no_c10d",
        ] + (extra_flags or []),
    )
    train.main(train_args)

    if run_validation:
        # test validation
        validate_parser = options.get_validation_parser()
        validate_args = options.parse_args_and_arch(
            validate_parser,
            [
                "--task",
                "language_modeling",
                data_dir,
                "--path",
                os.path.join(data_dir, "checkpoint_last.pt"),
                "--valid-subset",
                "valid",
                "--max-tokens",
                "500",
                "--no-progress-bar",
            ],
        )
        validate.main(validate_args)
Code example #8
def train_language_model(data_dir,
                         arch,
                         extra_flags=None,
                         run_validation=False):
    train_parser = options.get_training_parser()
    train_args = options.parse_args_and_arch(
        train_parser,
        [
            '--task',
            'language_modeling',
            data_dir,
            '--arch',
            arch,
            '--optimizer',
            'adam',
            '--lr',
            '0.0001',
            '--criterion',
            'adaptive_loss',
            '--adaptive-softmax-cutoff',
            '5,10,15',
            '--max-tokens',
            '500',
            '--tokens-per-sample',
            '500',
            '--save-dir',
            data_dir,
            '--max-epoch',
            '1',
            '--no-progress-bar',
            '--distributed-world-size',
            '1',
            '--ddp-backend',
            'no_c10d',
        ] + (extra_flags or []),
    )
    train.main(train_args)

    if run_validation:
        # test validation
        validate_parser = options.get_validation_parser()
        validate_args = options.parse_args_and_arch(validate_parser, [
            '--task',
            'language_modeling',
            data_dir,
            '--path',
            os.path.join(data_dir, 'checkpoint_last.pt'),
            '--valid-subset',
            'valid',
            '--max-tokens',
            '500',
            '--no-progress-bar',
        ])
        validate.main(validate_args)
Code example #9
def train_translation_model(data_dir,
                            arch,
                            extra_flags=None,
                            task="translation",
                            run_validation=False):
    train_parser = options.get_training_parser()
    train_args = options.parse_args_and_arch(
        train_parser,
        [
            "--task",
            task,
            data_dir,
            "--save-dir",
            data_dir,
            "--arch",
            arch,
            "--lr",
            "0.05",
            "--max-tokens",
            "500",
            "--max-epoch",
            "1",
            "--no-progress-bar",
            "--distributed-world-size",
            "1",
            "--source-lang",
            "in",
            "--target-lang",
            "out",
        ] + (extra_flags or []),
    )
    train.main(train_args)

    if run_validation:
        # test validation
        validate_parser = options.get_validation_parser()
        validate_args = options.parse_args_and_arch(
            validate_parser,
            [
                "--task",
                task,
                data_dir,
                "--path",
                os.path.join(data_dir, "checkpoint_last.pt"),
                "--valid-subset",
                "valid",
                "--max-tokens",
                "500",
                "--no-progress-bar",
            ],
        )
        validate.main(validate_args)
Code example #10
def cli_main():
    parser = options.get_validation_mul_parser()
    args = options.parse_args_and_arch(parser)
    override_parser = options.get_validation_parser()
    override_args = options.parse_args_and_arch(override_parser,
                                                suppress_defaults=True)

    if args.distributed_init_method is None:
        distributed_utils.infer_init_method(args)

    if args.distributed_init_method is not None:
        # distributed validation
        if torch.cuda.device_count() > 1 and not args.distributed_no_spawn:
            start_rank = args.distributed_rank
            args.distributed_rank = None  # assign automatically
            torch.multiprocessing.spawn(
                fn=distributed_main,
                args=(args, start_rank),
                nprocs=torch.cuda.device_count(),
            )
        else:
            distributed_main(args.device_id, args)
    elif args.distributed_world_size > 1:
        # fallback for single node with multiple GPUs
        assert args.distributed_world_size <= torch.cuda.device_count()
        port = random.randint(10000, 20000)
        args.distributed_init_method = 'tcp://localhost:{port}'.format(
            port=port)
        args.distributed_rank = None  # set based on device id
        #if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d':
        #    print('| NOTE: you may get better performance with: --ddp-backend=no_c10d')
        torch.multiprocessing.spawn(
            fn=distributed_main,
            args=(args, override_args),
            nprocs=args.distributed_world_size,
        )
    else:
        # single-GPU validation
        main(args, override_args)
Code example #11
def train_language_model(
    data_dir,
    arch,
    extra_flags=None,
    run_validation=False,
    extra_valid_flags=None,
    task="language_modeling",
    world_size=1,
):
    train_parser = options.get_training_parser()
    train_args = options.parse_args_and_arch(
        train_parser,
        [
            "--task",
            task,
            data_dir,
            "--arch",
            arch,
            "--optimizer",
            "adam",
            "--lr",
            "0.0001",
            "--max-tokens",
            "500",
            "--tokens-per-sample",
            "500",
            "--save-dir",
            data_dir,
            "--max-epoch",
            "1",
            "--no-progress-bar",
            "--distributed-world-size",
            str(world_size),
            "--ddp-backend",
            "no_c10d",
            "--num-workers",
            "0",
        ] + (extra_flags or []),
    )
    cfg = convert_namespace_to_omegaconf(train_args)
    distributed_utils.call_main(cfg, train.main)

    if run_validation:
        # test validation
        validate_parser = options.get_validation_parser()
        validate_args = options.parse_args_and_arch(
            validate_parser,
            [
                "--task",
                task,
                data_dir,
                "--path",
                os.path.join(data_dir, "checkpoint_last.pt"),
                "--valid-subset",
                "valid",
                "--max-tokens",
                "500",
                "--no-progress-bar",
                "--num-workers",
                "0",
            ] + (extra_valid_flags or []),
        )
        validate.main(validate_args)
Code example #12
def train_translation_model(
    data_dir,
    arch,
    extra_flags=None,
    task="translation",
    run_validation=False,
    lang_flags=None,
    extra_valid_flags=None,
    world_size=1,
):
    if lang_flags is None:
        lang_flags = [
            "--source-lang",
            "in",
            "--target-lang",
            "out",
        ]
    train_parser = options.get_training_parser()
    train_args = options.parse_args_and_arch(
        train_parser,
        [
            "--task",
            task,
            data_dir,
            "--save-dir",
            data_dir,
            "--arch",
            arch,
            "--optimizer",
            "nag",
            "--lr",
            "0.05",
            "--max-tokens",
            "500",
            "--max-epoch",
            "1",
            "--no-progress-bar",
            "--distributed-world-size",
            str(world_size),
            "--num-workers",
            "0",
        ] + lang_flags + (extra_flags or []),
    )

    cfg = convert_namespace_to_omegaconf(train_args)
    distributed_utils.call_main(cfg, train.main)

    if run_validation:
        # test validation
        validate_parser = options.get_validation_parser()
        validate_args = options.parse_args_and_arch(
            validate_parser,
            [
                "--task",
                task,
                data_dir,
                "--path",
                os.path.join(data_dir, "checkpoint_last.pt"),
                "--valid-subset",
                "valid",
                "--max-tokens",
                "500",
                "--no-progress-bar",
                "--num-workers",
                "0",
            ] + lang_flags + (extra_valid_flags or []),
        )
        validate.main(validate_args)
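
All of the examples above share the same pattern: build a validation argument parser with options.get_validation_parser(), parse either the real command line or an explicit argument list with options.parse_args_and_arch(), and hand the resulting namespace to a validation entry point. Below is a minimal, self-contained sketch of just the parsing step, assuming fairseq is installed; the data directory and checkpoint path are hypothetical placeholders, and parsing alone does not read them from disk.

from fairseq import options

# Build the standard fairseq validation parser and parse a fixed argument list
# instead of sys.argv.
parser = options.get_validation_parser()
args = options.parse_args_and_arch(
    parser,
    [
        'data-bin/example',          # hypothetical preprocessed data directory
        '--task', 'translation',
        '--path', 'checkpoints/checkpoint_last.pt',  # hypothetical checkpoint
        '--valid-subset', 'valid',
        '--max-tokens', '500',
    ],
)

# Inspect a few parsed fields before handing the namespace to validate.main().
print(args.task, args.path, args.valid_subset)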