def run(
    self,
    config_name: Optional[str],
    task_function: TaskFunction,
    overrides: List[str],
    with_log_configuration: bool = True,
) -> JobReturn:
    """Compose the config in RUN mode and execute ``task_function`` as one job.

    The ``on_run_start`` / ``on_run_end`` callbacks built from the composed
    config are invoked around the call to ``run_job``.

    :param config_name: name of the config to compose (may be None)
    :param task_function: the function executed by ``run_job``
    :param overrides: override strings forwarded to ``compose_config``
    :param with_log_configuration: forwarded to ``compose_config`` and used
        as ``configure_logging`` for the job
    :return: the ``JobReturn`` produced by ``run_job``
    """
    composed = self.compose_config(
        config_name=config_name,
        overrides=overrides,
        with_log_configuration=with_log_configuration,
        run_mode=RunMode.RUN,
    )

    hooks = Callbacks(composed)
    hooks.on_run_start(config=composed, config_name=config_name)

    context = HydraContext(config_loader=self.config_loader, callbacks=hooks)
    job_return = run_job(
        hydra_context=context,
        task_function=task_function,
        config=composed,
        job_dir_key="hydra.run.dir",
        job_subdir_key=None,
        configure_logging=with_log_configuration,
    )
    hooks.on_run_end(config=composed, config_name=config_name, job_return=job_return)

    # Reading return_value is deliberate: it triggers an exception in case
    # the job failed, so the failure surfaces to the caller.
    _ = job_return.return_value

    return job_return
def multirun(
    self,
    config_name: Optional[str],
    task_function: TaskFunction,
    overrides: List[str],
    with_log_configuration: bool = True,
) -> Any:
    """Compose the config in MULTIRUN mode and hand the task overrides to the sweeper.

    The ``on_multirun_start`` / ``on_multirun_end`` callbacks built from the
    composed config are invoked around the sweep.

    :param config_name: name of the config to compose (may be None)
    :param task_function: the function passed to the instantiated sweeper
    :param overrides: override strings forwarded to ``compose_config``
    :param with_log_configuration: forwarded to ``compose_config``
    :return: whatever the sweeper's ``sweep()`` call returns
    """
    composed = self.compose_config(
        config_name=config_name,
        overrides=overrides,
        with_log_configuration=with_log_configuration,
        run_mode=RunMode.MULTIRUN,
    )

    hooks = Callbacks(composed)
    hooks.on_multirun_start(config=composed, config_name=config_name)

    plugins = Plugins.instance()
    context = HydraContext(config_loader=self.config_loader, callbacks=hooks)
    sweeper = plugins.instantiate_sweeper(
        config=composed,
        hydra_context=context,
        task_function=task_function,
    )

    # The task overrides are converted unresolved (resolve=False) and must
    # come back as a plain list before being given to the sweeper.
    sweep_arguments = OmegaConf.to_container(
        composed.hydra.overrides.task, resolve=False
    )
    assert isinstance(sweep_arguments, list)

    result = sweeper.sweep(arguments=sweep_arguments)
    hooks.on_multirun_end(config=composed, config_name=config_name)
    return result
def test_setup_plugins(
    monkeypatch: Any, plugin: Union[Launcher, Sweeper], config: DictConfig
) -> None:
    """Instantiating the given plugin must emit the setup()-signature UserWarning.

    ``Plugins.check_usage`` is disabled and ``_instantiate`` on the Plugins
    instance is patched to hand back ``plugin`` directly.
    ``plugin`` and ``config`` are presumably supplied by pytest
    parametrization/fixtures defined outside this function.
    """
    task_function = Mock(spec=TaskFunction)
    loader = ConfigLoaderImpl(config_search_path=create_config_search_path(None))
    context = HydraContext(config_loader=loader, callbacks=Callbacks())

    plugins = Plugins.instance()
    monkeypatch.setattr(Plugins, "check_usage", lambda _: None)
    monkeypatch.setattr(plugins, "_instantiate", lambda _: plugin)

    expected_warning_text = dedent(
        """ Plugin's setup() signature has changed in Hydra 1.1. Support for the old style will be removed in Hydra 1.2. For more info, check https://github.com/facebookresearch/hydra/pull/1581."""
    )

    # Keyword arguments common to both instantiate_* entry points.
    shared_kwargs = {
        "hydra_context": context,
        "task_function": task_function,
        "config": config,
    }

    with warns(expected_warning=UserWarning, match=re.escape(expected_warning_text)):
        if isinstance(plugin, Launcher):
            # The launcher path is additionally given the bare config_loader.
            Plugins.instance().instantiate_launcher(
                config_loader=loader, **shared_kwargs
            )
        else:
            Plugins.instance().instantiate_sweeper(**shared_kwargs)
def _setup_plugin(
    plugin: Any,
    task_function: TaskFunction,
    config: DictConfig,
    config_loader: Optional[ConfigLoader] = None,
    hydra_context: Optional[HydraContext] = None,
) -> Any:
    """
    Call ``plugin.setup()`` using whichever signature the plugin supports.

    HydraContext was introduced in #1581, so plugins must be set up in a way
    that works with both the Hydra 1.0 style (``config_loader``) and the
    Hydra 1.1 style (``hydra_context``) of ``setup()``.
    This method should be deleted in the next major release.

    :param plugin: a Sweeper or Launcher instance
    :param task_function: forwarded to ``plugin.setup()``
    :param config: forwarded to ``plugin.setup()``
    :param config_loader: used directly by old-style plugins, or to build a
        HydraContext when none was supplied
    :param hydra_context: used directly by new-style plugins
    :return: the same ``plugin`` object, after ``setup()`` has been called
    """
    assert isinstance(plugin, (Sweeper, Launcher))
    assert not (
        config_loader is None and hydra_context is None
    ), "config_loader and hydra_context cannot both be None"

    supports_context = "hydra_context" in signature(plugin.setup).parameters

    if supports_context:
        if hydra_context is None:
            # hydra_context could be None when an incompatible Sweeper
            # instantiates a compatible Launcher
            assert config_loader is not None
            hydra_context = HydraContext(
                config_loader=config_loader, callbacks=Callbacks()
            )
        plugin.setup(
            config=config, hydra_context=hydra_context, task_function=task_function
        )
        return plugin

    # DEPRECATED: remove in 1.2
    # hydra_context will be required in 1.2
    warnings.warn(
        message=dedent(
            """ Plugin's setup() signature has changed in Hydra 1.1. Support for the old style will be removed in Hydra 1.2. For more info, check https://github.com/facebookresearch/hydra/pull/1581."""
        ),
        category=UserWarning,
    )
    if config_loader is None:
        # Old-style setup() only takes a loader; borrow it from the context.
        config_loader = hydra_context.config_loader  # type: ignore
    plugin.setup(  # type: ignore
        config=config,
        config_loader=config_loader,
        task_function=task_function,
    )
    return plugin
def test_setup_plugins(
    monkeypatch: Any, plugin: Union[Launcher, Sweeper], config: DictConfig
) -> None:
    """Instantiating the given plugin must raise TypeError about ``hydra_context``.

    ``Plugins.check_usage`` is disabled and ``_instantiate`` on the Plugins
    instance is patched to hand back ``plugin`` directly.
    ``plugin`` and ``config`` are presumably supplied by pytest
    parametrization/fixtures defined outside this function.
    """
    task_function = Mock(spec=TaskFunction)
    loader = ConfigLoaderImpl(config_search_path=create_config_search_path(None))
    context = HydraContext(config_loader=loader, callbacks=Callbacks())

    plugins = Plugins.instance()
    monkeypatch.setattr(Plugins, "check_usage", lambda _: None)
    monkeypatch.setattr(plugins, "_instantiate", lambda _: plugin)

    expected_error_text = "setup() got an unexpected keyword argument 'hydra_context'"

    # Both instantiate_* entry points receive the same keyword arguments.
    shared_kwargs = {
        "hydra_context": context,
        "task_function": task_function,
        "config": config,
    }

    with raises(TypeError, match=re.escape(expected_error_text)):
        if isinstance(plugin, Launcher):
            Plugins.instance().instantiate_launcher(**shared_kwargs)
        else:
            Plugins.instance().instantiate_sweeper(**shared_kwargs)