示例#1
0
def main_worker(args, unknown_args):
    """Run an Optuna hyperparameter search over the configured experiment.

    The command-line arguments are parsed into ``(args, config)``, the
    Optuna sampler, pruner and study are built from
    ``config["study_params"]``, and the study optimizes an objective that
    prepares and runs one experiment per trial.

    Args:
        args: command-line arguments, passed to ``parse_args_uargs``. The
            namespace it returns is read for ``seed``, ``deterministic``,
            ``benchmark``, ``apex``, ``amp``, ``expdir``, ``configs``,
            ``storage``, ``study_name``, ``n_trials``, ``timeout``,
            ``n_jobs``, ``gc_after_trial`` and ``show_progress_bar``.
        unknown_args: extra command-line arguments, passed to
            ``parse_args_uargs`` together with ``args``.
    """
    # Local import keeps this block self-contained.
    from copy import deepcopy

    args, config = parse_args_uargs(args, unknown_args)
    set_global_seed(args.seed)
    prepare_cudnn(args.deterministic, args.benchmark)

    distributed_params = config.setdefault("distributed_params", {})
    distributed_params["apex"] = args.apex
    distributed_params["amp"] = args.amp
    expdir = Path(args.expdir)

    # optuna objective
    def objective(trial: optuna.trial.Trial):
        """Run one experiment for ``trial`` and return its main metric."""
        # A deep copy gives every trial its own nested sections, so in-place
        # edits made while preparing one trial cannot reach another trial
        # (trials may also run concurrently when ``n_jobs > 1``).
        trial, trial_config = _process_trial_config(trial, deepcopy(config))
        experiment, runner, trial_config = prepare_config_api_components(
            expdir=expdir, config=trial_config)
        # @TODO: here we need better solution.
        experiment._trial = trial  # noqa: WPS437

        # Dump the environment and code only when a logdir is set and the
        # rank is <= 0 (presumably the master / non-distributed process).
        if experiment.logdir is not None and get_rank() <= 0:
            dump_environment(trial_config, experiment.logdir, args.configs)
            dump_code(args.expdir, experiment.logdir)

        runner.run_experiment(experiment)

        return runner.best_valid_metrics[runner.main_metric]

    # optuna direction: minimize unless
    # ``stages.stage_params.minimize_metric`` is present and falsy
    stage_params = config.get("stages", {}).get("stage_params", {})
    if stage_params.get("minimize_metric", True):
        direction = "minimize"
    else:
        direction = "maximize"

    # optuna study; popped so the per-trial configs do not carry it
    study_params = config.pop("study_params", {})

    # optuna sampler: "sampler" names a class in ``optuna.samplers``, the
    # remaining keys are its constructor arguments
    sampler_params = study_params.pop("sampler_params", {})
    optuna_sampler_type = sampler_params.pop("sampler", None)
    optuna_sampler = None
    if optuna_sampler_type is not None:
        sampler_cls = optuna.samplers.__dict__[optuna_sampler_type]
        optuna_sampler = sampler_cls(**sampler_params)

    # optuna pruner: same convention, looked up in ``optuna.pruners``
    pruner_params = study_params.pop("pruner_params", {})
    optuna_pruner_type = pruner_params.pop("pruner", None)
    optuna_pruner = None
    if optuna_pruner_type is not None:
        pruner_cls = optuna.pruners.__dict__[optuna_pruner_type]
        optuna_pruner = pruner_cls(**pruner_params)

    # Truthy command-line values win over the ones from ``study_params``.
    study = optuna.create_study(
        direction=direction,
        storage=args.storage or study_params.pop("storage", None),
        study_name=args.study_name or study_params.pop("study_name", None),
        sampler=optuna_sampler,
        pruner=optuna_pruner,
    )
    study.optimize(
        objective,
        n_trials=args.n_trials,
        timeout=args.timeout,
        n_jobs=args.n_jobs or 1,  # a falsy ``n_jobs`` means a single job
        gc_after_trial=args.gc_after_trial,
        show_progress_bar=args.show_progress_bar,
    )
示例#2
0
def main_worker(args, unknown_args):
    """Parse the arguments, build the experiment and runner, and run it.

    Args:
        args: command-line arguments, passed to ``parse_args_uargs``.
        unknown_args: extra command-line arguments, passed to
            ``parse_args_uargs`` together with ``args``.
    """
    args, config = parse_args_uargs(args, unknown_args)
    set_global_seed(args.seed)
    prepare_cudnn(args.deterministic, args.benchmark)

    # Record the mixed-precision flags in the distributed section.
    distributed_params = config.setdefault("distributed_params", {})
    distributed_params["apex"] = args.apex
    distributed_params["amp"] = args.amp

    experiment, runner, config = prepare_config_api_components(
        expdir=Path(args.expdir),
        config=config,
    )

    # Dump the environment and code only when a logdir is set and the rank
    # is <= 0 (presumably the master / non-distributed process).
    should_dump = experiment.logdir is not None and get_rank() <= 0
    if should_dump:
        dump_environment(config, experiment.logdir, args.configs)
        dump_code(args.expdir, experiment.logdir)

    runner.run_experiment(experiment)