コード例 #1
0
def main(args, unknown_args):
    """Assemble environment, algorithm, database and trainer, then train.

    Args:
        args: parsed command-line namespace. This function reads ``seed``,
            ``deterministic``, ``benchmark``, ``logdir``, ``expdir``,
            ``configs`` and ``resume`` from the namespace returned by
            ``parse_args_uargs``.
        unknown_args: extra command-line arguments, forwarded to
            ``parse_args_uargs``.

    Raises:
        NotImplementedError: if the configured algorithm name is listed in
            neither ``OFFPOLICY_ALGORITHMS_NAMES`` nor
            ``ONPOLICY_ALGORITHMS_NAMES``.
        KeyError: if ``config`` lacks the ``"environment"``, ``"algorithm"``
            or ``"trainer"`` section (plain subscript access is used).
    """
    args, config = parse_args_uargs(args, unknown_args)
    set_global_seed(args.seed)
    prepare_cudnn(args.deterministic, args.benchmark)

    has_logdir = args.logdir is not None
    if has_logdir:
        os.makedirs(args.logdir, exist_ok=True)
        dump_environment(config, args.logdir, args.configs)

    if args.expdir is not None:
        # The return value is not needed; the import is done for its side
        # effects (presumably registering user-defined components -- confirm).
        import_module(expdir=args.expdir)
        if has_logdir:
            dump_code(args.expdir, args.logdir)

    env = ENVIRONMENTS.get_from_params(**config["environment"])

    # NOTE: ``pop`` removes the name from ``config["algorithm"]`` in place, so
    # the ``config`` handed to ``prepare_for_trainer`` below no longer has it.
    algorithm_name = config["algorithm"].pop("algorithm")
    if algorithm_name in OFFPOLICY_ALGORITHMS_NAMES:
        algorithms_registry = OFFPOLICY_ALGORITHMS
        trainer_fn = OffpolicyTrainer
        sync_epoch = False
    elif algorithm_name in ONPOLICY_ALGORITHMS_NAMES:
        algorithms_registry = ONPOLICY_ALGORITHMS
        trainer_fn = OnpolicyTrainer
        sync_epoch = True
    else:
        # @TODO: add registry for algorithms, trainers, samplers
        raise NotImplementedError(
            f"Unknown algorithm {algorithm_name!r}: it is listed in neither "
            "OFFPOLICY_ALGORITHMS_NAMES nor ONPOLICY_ALGORITHMS_NAMES"
        )

    db_server = DATABASES.get_from_params(
        **config.get("db", {}), sync_epoch=sync_epoch
    )

    algorithm_fn = algorithms_registry.get(algorithm_name)
    algorithm = algorithm_fn.prepare_for_trainer(env_spec=env, config=config)

    if args.resume is not None:
        # Load the checkpoint, move it to the current device, and unpack it
        # into the algorithm with ``with_optimizer=False``.
        checkpoint = utils.load_checkpoint(filepath=args.resume)
        checkpoint = utils.any2device(checkpoint, utils.get_device())
        algorithm.unpack_checkpoint(
            checkpoint=checkpoint,
            with_optimizer=False
        )

    monitoring_params = config.get("monitoring_params", None)

    trainer = trainer_fn(
        algorithm=algorithm,
        env_spec=env,
        db_server=db_server,
        logdir=args.logdir,
        monitoring_params=monitoring_params,
        **config["trainer"],
    )

    trainer.run()
コード例 #2
0
    def objective(trial: optuna.trial.Trial):
        """Run one experiment for a single Optuna trial and return its score.

        Relies on ``config``, ``expdir`` and ``args`` from the enclosing scope.

        Args:
            trial: the Optuna trial being evaluated.

        Returns:
            The entry of ``runner.best_valid_metrics`` keyed by
            ``runner.main_metric``.
        """
        # A shallow copy is passed, so top-level keys rebound by
        # ``_process_trial_config`` do not leak into the shared ``config``.
        trial, trial_config = _process_trial_config(trial, config.copy())
        experiment, runner, trial_config = prepare_config_api_components(
            expdir=expdir, config=trial_config
        )
        # @TODO: here we need better solution.
        experiment._trial = trial  # noqa: WPS437

        # Dump only when a logdir exists and ``get_rank() <= 0``
        # (presumably the main or non-distributed process -- confirm).
        if experiment.logdir is not None and get_rank() <= 0:
            dump_environment(trial_config, experiment.logdir, args.configs)
            dump_code(args.expdir, experiment.logdir)

        runner.run_experiment(experiment)

        return runner.best_valid_metrics[runner.main_metric]
コード例 #3
0
ファイル: run.py プロジェクト: ram-iyer/catalyst
def main(args, unknown_args):
    """Build an experiment and a runner from ``args.expdir`` and run them.

    Args:
        args: parsed command-line namespace; ``seed``, ``expdir``,
            ``configs`` and ``check`` are read from the namespace returned
            by ``parse_args_uargs``.
        unknown_args: extra command-line arguments, forwarded to
            ``parse_args_uargs``.
    """
    args, config = parse_args_uargs(args, unknown_args)
    set_global_seed(args.seed)

    experiment_cls, runner_cls = import_experiment_and_runner(
        Path(args.expdir)
    )
    experiment = experiment_cls(config)
    runner = runner_cls()

    # Snapshot the config and the experiment code only when there is a logdir.
    logdir = experiment.logdir
    if logdir is not None:
        dump_config(config, logdir, args.configs)
        dump_code(args.expdir, logdir)

    runner.run_experiment(experiment, check=args.check)
コード例 #4
0
ファイル: run.py プロジェクト: yuv4r4j/catalyst
def main(args, unknown_args):
    """Build an experiment and a runner from ``args.expdir`` and run them.

    Args:
        args: parsed command-line namespace; ``seed``, ``deterministic``,
            ``benchmark``, ``expdir``, ``configs`` and ``check`` are read
            from the namespace returned by ``parse_args_uargs``.
        unknown_args: extra command-line arguments, forwarded to
            ``parse_args_uargs``.
    """
    args, config = parse_args_uargs(args, unknown_args)
    set_global_seed(args.seed)
    prepare_cudnn(args.deterministic, args.benchmark)

    experiment_cls, runner_cls = import_experiment_and_runner(
        Path(args.expdir)
    )
    experiment = experiment_cls(config)
    runner = runner_cls()

    # Dump the environment and the experiment code only when a logdir is set.
    logdir = experiment.logdir
    if logdir is not None:
        dump_environment(config, logdir, args.configs)
        dump_code(args.expdir, logdir)

    runner.run_experiment(experiment, check=args.check)
コード例 #5
0
def main_worker(cfg: DictConfig):
    """Instantiate the experiment and runner from a hydra config and run them.

    Args:
        cfg: hydra configuration; ``cfg.args`` (``seed``, ``deterministic``,
            ``benchmark``, ``expdir``), ``cfg.experiment`` and ``cfg.runner``
            are read here.
    """
    run_args = cfg.args
    set_global_seed(run_args.seed)
    prepare_cudnn(run_args.deterministic, run_args.benchmark)

    # Resolve the experiment directory once; it is used for both the import
    # and the code dump.
    expdir = hydra.utils.to_absolute_path(run_args.expdir)
    import_module(expdir)

    experiment = hydra.utils.instantiate(cfg.experiment, cfg=cfg)
    runner = hydra.utils.instantiate(cfg.runner)

    logdir = experiment.logdir
    if logdir is not None:
        # The rank is only queried when there is somewhere to dump to.
        if get_rank() <= 0:
            dump_environment(cfg, logdir)
            dump_code(expdir, logdir)

    runner.run_experiment(experiment)
コード例 #6
0
def main_worker(args, unknown_args):
    """Prepare the config API components for one worker and run the experiment.

    Args:
        args: parsed command-line namespace; ``seed``, ``deterministic``,
            ``benchmark``, ``apex``, ``amp``, ``expdir`` and ``configs`` are
            read from the namespace returned by ``parse_args_uargs``.
        unknown_args: extra command-line arguments, forwarded to
            ``parse_args_uargs``.
    """
    args, config = parse_args_uargs(args, unknown_args)
    set_global_seed(args.seed)
    prepare_cudnn(args.deterministic, args.benchmark)

    # Copy the CLI ``apex`` / ``amp`` values into the "distributed_params"
    # section, creating that section if it is absent.
    distributed_params = config.setdefault("distributed_params", {})
    distributed_params["apex"] = args.apex
    distributed_params["amp"] = args.amp

    experiment, runner, config = prepare_config_api_components(
        expdir=Path(args.expdir),
        config=config,
    )

    logdir = experiment.logdir
    if logdir is not None:
        # The rank is only queried when there is somewhere to dump to.
        if get_rank() <= 0:
            dump_environment(config, logdir, args.configs)
            dump_code(args.expdir, logdir)

    runner.run_experiment(experiment)
コード例 #7
0
def main(args, unknown_args):
    """Entry point of the ``catalyst-dl run`` script.

    Args:
        args: parsed command-line namespace; ``seed``, ``deterministic``,
            ``benchmark``, ``expdir`` and ``configs`` are read from the
            namespace returned by ``parse_args_uargs``.
        unknown_args: extra command-line arguments, forwarded to
            ``parse_args_uargs``.
    """
    args, config = parse_args_uargs(args, unknown_args)
    set_global_seed(args.seed)
    prepare_cudnn(args.deterministic, args.benchmark)

    experiment_cls, runner_cls = import_experiment_and_runner(
        Path(args.expdir)
    )

    # "runner_params" is removed from the config before the experiment sees
    # it; a missing or falsy value is treated as "no runner arguments".
    runner_kwargs = config.pop("runner_params", None)
    if not runner_kwargs:
        runner_kwargs = {}

    experiment = experiment_cls(config)
    runner = runner_cls(**runner_kwargs)

    logdir = experiment.logdir
    if logdir is not None:
        dump_environment(config, logdir, args.configs)
        dump_code(args.expdir, logdir)

    # The check flag is read from config["args"]["check"], defaulting to False.
    is_check_run = safitty.get(config, "args", "check", default=False)
    runner.run_experiment(experiment, check=is_check_run)