def main_worker(args, unknown_args):
    """Runs main worker thread from model training."""
    args, config = parse_args_uargs(args, unknown_args)
    set_global_seed(args.seed)
    prepare_cudnn(args.deterministic, args.benchmark)

    # mixed-precision backends are propagated through distributed_params
    distributed_params = config.setdefault("distributed_params", {})
    distributed_params["apex"] = args.apex
    distributed_params["amp"] = args.amp

    expdir = Path(args.expdir)

    def objective(trial: optuna.trial):
        """Optuna objective: runs one trial's experiment and reports its main metric."""
        trial, trial_config = _process_trial_config(trial, config.copy())
        experiment, runner, trial_config = prepare_config_api_components(
            expdir=expdir, config=trial_config
        )
        # @TODO: here we need better solution.
        experiment._trial = trial  # noqa: WPS437

        # only the rank-0 (or non-distributed) process dumps environment/code
        if experiment.logdir is not None and get_rank() <= 0:
            dump_environment(trial_config, experiment.logdir, args.configs)
            dump_code(args.expdir, experiment.logdir)

        runner.run_experiment(experiment)
        return runner.best_valid_metrics[runner.main_metric]

    # optuna direction: minimize unless the stage config explicitly maximizes
    minimize_metric = (
        config.get("stages", {}).get("stage_params", {}).get("minimize_metric", True)
    )
    direction = "minimize" if minimize_metric else "maximize"

    # optuna study
    study_params = config.pop("study_params", {})

    # optuna sampler — looked up by name in optuna.samplers
    sampler_params = study_params.pop("sampler_params", {})
    optuna_sampler_type = sampler_params.pop("sampler", None)
    if optuna_sampler_type is not None:
        optuna_sampler = optuna.samplers.__dict__[optuna_sampler_type](**sampler_params)
    else:
        optuna_sampler = None

    # optuna pruner — looked up by name in optuna.pruners
    pruner_params = study_params.pop("pruner_params", {})
    optuna_pruner_type = pruner_params.pop("pruner", None)
    if optuna_pruner_type is not None:
        optuna_pruner = optuna.pruners.__dict__[optuna_pruner_type](**pruner_params)
    else:
        optuna_pruner = None

    # CLI flags take precedence over config-file study settings
    study = optuna.create_study(
        direction=direction,
        storage=args.storage or study_params.pop("storage", None),
        study_name=args.study_name or study_params.pop("study_name", None),
        sampler=optuna_sampler,
        pruner=optuna_pruner,
    )
    study.optimize(
        objective,
        n_trials=args.n_trials,
        timeout=args.timeout,
        n_jobs=args.n_jobs or 1,
        gc_after_trial=args.gc_after_trial,
        show_progress_bar=args.show_progress_bar,
    )
def main_worker(args, unknown_args):
    """Runs main worker thread from model training."""
    args, config = parse_args_uargs(args, unknown_args)
    set_global_seed(args.seed)
    prepare_cudnn(args.deterministic, args.benchmark)

    # mixed-precision backends are propagated through distributed_params
    distributed_params = config.setdefault("distributed_params", {})
    distributed_params["apex"] = args.apex
    distributed_params["amp"] = args.amp

    experiment, runner, config = prepare_config_api_components(
        expdir=Path(args.expdir), config=config
    )

    # only the rank-0 (or non-distributed) process dumps environment/code
    if experiment.logdir is not None and get_rank() <= 0:
        dump_environment(config, experiment.logdir, args.configs)
        dump_code(args.expdir, experiment.logdir)

    runner.run_experiment(experiment)