def __init__(self) -> None:
    self._cs = ConfigStore.instance()
@dataclass
class Package:
    path: str = MISSING
    version_type: VersionType = VersionType.setup_py
    version_file: str = MISSING


@dataclass
class Config:
    dry_run: bool = False
    action: Action = Action.check
    packages: List[Package] = MISSING
    build_targets: Tuple[str, ...] = ("sdist", "bdist_wheel")
    build_dir: str = "build"


ConfigStore.instance().store(name="config", node=Config)


@lru_cache()
def get_metadata(package_name: str) -> DictConfig:
    url = f"https://pypi.org/pypi/{package_name}/json"
    # the with-block closes the response automatically
    with requests.get(url, timeout=10) as response:
        ret = OmegaConf.create(response.content.decode("utf-8"))
    assert isinstance(ret, DictConfig)
    return ret


def get_releases(metadata: DictConfig) -> List[Version]:
    ret: List[Version] = []
    for ver, files in metadata.releases.items():
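# A hedged sketch of how the truncated get_releases() loop above might
# continue; `Version` is assumed to be packaging.version.Version, and the loop
# body is illustrative only - the original continuation is not shown.
from typing import List

from packaging.version import Version


def get_releases_sketch(metadata) -> List[Version]:
    ret: List[Version] = []
    for ver, files in metadata.releases.items():
        if len(files) > 0:  # skip releases with no distribution files (a guess)
            ret.append(Version(ver))
    return sorted(ret)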
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from omegaconf import DictConfig

import hydra

# Underlying object
from configen.samples.user import User

# Generated config dataclass
from example.gen.configen.samples.user.conf.User import UserConf

from hydra.core.config_store import ConfigStore

ConfigStore.instance().store(name="config", node={"user": UserConf})


@hydra.main(config_name="config")
def my_app(cfg: DictConfig) -> None:
    user: User = hydra.utils.instantiate(cfg.user)
    print(user)


if __name__ == "__main__":
    my_app()
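# A minimal sketch exercising the same registration without the CLI, via the
# compose API (assumes Hydra 1.1+, where initialize/compose are top-level;
# config_path=None searches no directories, so only the node stored above is used).
from hydra import compose, initialize

with initialize(config_path=None):
    cfg = compose(config_name="config")
    user = hydra.utils.instantiate(cfg.user)
    print(user)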
@dataclass
class JobLibConf:
    # number of batches to be pre-dispatched
    pre_dispatch: str = "2*n_jobs"
    # number of atomic tasks to dispatch at once to each worker
    batch_size: str = "auto"
    # folder used for memmapping large arrays for sharing memory with workers
    temp_folder: Optional[str] = None
    # threshold size of arrays that triggers automated memmapping
    max_nbytes: Optional[str] = None
    # memmapping mode for numpy arrays passed to workers
    mmap_mode: str = "r"


@dataclass
class JobLibLauncherConf(ObjectConf):
    cls: str = "hydra_plugins.hydra_joblib_launcher.joblib_launcher.JoblibLauncher"
    params: JobLibConf = JobLibConf()


ConfigStore.instance().store(
    group="hydra/launcher",
    name="joblib",
    path="hydra.launcher",
    node=JobLibLauncherConf,
    provider="joblib_launcher",
)
@dataclass
class AxConfig:
    # max_trials is application-specific. Tune it for your use case
    max_trials: int = 10
    early_stop: EarlyStopConfig = EarlyStopConfig()
    experiment: ExperimentConfig = ExperimentConfig()
    client: ClientConfig = ClientConfig()
    params: Dict[str, Any] = field(default_factory=dict)


@dataclass
class AxSweeperParams:
    # Maximum number of trials to run in parallel
    max_batch_size: Optional[int] = None
    ax_config: AxConfig = AxConfig()


@dataclass
class AxSweeperConf(ObjectConf):
    target: str = "hydra_plugins.hydra_ax_sweeper.ax_sweeper.AxSweeper"
    params: AxSweeperParams = AxSweeperParams()


ConfigStore.instance().store(
    group="hydra/sweeper",
    name="ax",
    node=AxSweeperConf,
    provider="ax_sweeper",
)
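# A quick sanity check (sketch): ConfigStore.load() returns the stored node for
# a "<group>/<name>.yaml" path - the same lookup the config loader performs
# when injecting schemas.
from hydra.core.config_store import ConfigStore

stored = ConfigStore.instance().load(config_path="hydra/sweeper/ax.yaml")
print(stored.provider)                          # "ax_sweeper"
print(stored.node.params.ax_config.max_trials)  # 10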
from dataclasses import dataclass
from typing import Any, Dict, Optional

from hydra.core.config_store import ConfigStore
from hydra_plugins.hydra_submitit_launcher.config import BaseQueueConf, SlurmQueueConf


@dataclass
class GridEngineQueueConf:
    _target_: str = (
        "hydra_plugins.hydra_gridengine_launcher.gridengine_launcher.GridEngineLauncher"
    )
    merge: bool = False
    submitit_folder: str = "${hydra.sweep.dir}/.submitit/%j"
    job_name: str = "submitit"
    num_gpus: int = 0
    shell: str = "/bin/bash"
    tc: Optional[str] = None
    tmem: str = "4G"
    h_vmem: str = "4G"
    h_rt: str = "01:00:00"
    smp: int = 1


ConfigStore.instance().store(
    group="hydra/launcher",
    name="submitit_gridengine",
    node=GridEngineQueueConf(),
    provider="hydra_gridengine_launcher",
)
from hydra.types import ObjectConf, TaskFunction


@dataclass
class BasicSweeperConf(ObjectConf):
    target: str = MISSING

    @dataclass
    class Params:
        max_batch_size: Optional[int] = None

    params: Params = Params()


ConfigStore.instance().store(
    group="hydra/sweeper",
    name="basic",
    node=BasicSweeperConf,
    provider="hydra",
)


class BasicSweeper(Sweeper):
    """
    Basic sweeper
    """

    def __init__(self, max_batch_size: Optional[int]) -> None:
        """
        Instantiates
        """
        super(BasicSweeper, self).__init__()
        self.overrides: Optional[Sequence[Sequence[Sequence[str]]]] = None
        self.batch_index = 0
def main(_app):
    cs = ConfigStore.instance()
    cs.store(name="config", node=OurConfig)
    _app()
@dataclass
class OptimConf:
    # set to true for performing maximization instead of minimization
    maximize: bool = False
    # optimization seed, for reproducibility
    seed: Optional[int] = None


@dataclass
class NevergradSweeperConf:
    _target_: str = (
        "hydra_plugins.hydra_nevergrad_sweeper.nevergrad_sweeper.NevergradSweeper"
    )
    # configuration of the optimizer
    optim: OptimConf = OptimConf()
    # default parametrization of the search space, which can be specified:
    # - as a string, like commandline arguments
    # - as a list, for categorical variables
    # - as a full scalar specification
    parametrization: Dict[str, Any] = field(default_factory=dict)


ConfigStore.instance().store(
    group="hydra/sweeper",
    name="nevergrad",
    node=NevergradSweeperConf,
    provider="nevergrad",
)
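# Illustrative values for the parametrization field documented above; the exact
# string syntax accepted by the nevergrad sweeper varies across plugin versions,
# so treat these as a sketch rather than canonical examples.
parametrization_example = {
    "batch_size": 4,              # a full scalar specification
    "db": ["mnist", "cifar10"],   # a list -> categorical variable
    "lr": "log(0.001,1.0)",       # a string, like commandline arguments
}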
@dataclass
class SlurmQueueConf(BaseQueueConf):
    # https://github.com/facebookincubator/submitit/blob/master/docs/checkpointing.md
    max_num_timeout: int = 0
    # Useful to add parameters which are not currently available in the plugin.
    # Eg: {"mail-user": "******", "mail-type": "BEGIN"}
    additional_parameters: Dict[str, Any] = field(default_factory=dict)
    # Maximum number of jobs running in parallel
    array_parallelism: int = 256


@dataclass
class LocalQueueConf(BaseQueueConf):
    _target_: str = (
        "hydra_plugins.hydra_submitit_launcher.submitit_launcher.LocalLauncher"
    )


# finally, register two different choices:
ConfigStore.instance().store(
    group="hydra/launcher",
    name="submitit_local",
    node=LocalQueueConf(),
    provider="submitit_launcher",
)
ConfigStore.instance().store(
    group="hydra/launcher",
    name="submitit_slurm",
    node=SlurmQueueConf(),
    provider="submitit_launcher",
)
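# Sketch: the two registrations above land in the ConfigStore repo under the
# "hydra/launcher" group, keyed by "<name>.yaml" (internal layout, subject to change).
from hydra.core.config_store import ConfigStore

repo = ConfigStore.instance().repo
print(sorted(repo["hydra"]["launcher"].keys()))
# expected to include 'submitit_local.yaml' and 'submitit_slurm.yaml'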
def wrapper(cfg_passthrough: Optional[DictConfig] = None) -> Any:
    # Check if a config was passed in directly.
    if cfg_passthrough is not None:
        return task_function(cfg_passthrough)
    else:
        args = get_args_parser()
        # Parse arguments in order to retrieve overrides
        parsed_args = args.parse_args()  # type: argparse.Namespace
        # Get overriding args in dot string format
        overrides = parsed_args.overrides  # type: list

        # Disable the creation of the .hydra subdir
        # https://hydra.cc/docs/tutorials/basic/running_your_app/working_directory
        overrides.append("hydra.output_subdir=null")
        # Hydra logging outputs only to stdout (no log file).
        # https://hydra.cc/docs/configure_hydra/logging
        overrides.append("hydra/job_logging=stdout")
        # Set run.dir ONLY for ExpManager "compatibility" - to be removed.
        overrides.append("hydra.run.dir=.")

        # Check if the user set a schema.
        if schema is not None:
            # Create config store.
            cs = ConfigStore.instance()
            # Get the correct ConfigStore "path name" to "inject" the schema.
            if parsed_args.config_name is not None:
                path, name = os.path.split(parsed_args.config_name)
                # Make sure the path is not set - as this would disable the validation schema.
                if path != '':
                    sys.stderr.write(
                        "ERROR Cannot set config file path using `--config-name` when "
                        "using schema. Please set path using `--config-path` and file name using "
                        "`--config-name` separately.\n"
                    )
                    sys.exit(1)
            else:
                name = config_name

            # Register the configuration as a node under the name in the group.
            cs.store(name=name, node=schema)  # group=group,

        # Wrap a callable object with name `parse_args`.
        # This is to mimic the ArgumentParser.parse_args() API.
        class _argparse_wrapper:
            def __init__(self, arg_parser):
                self.arg_parser = arg_parser
                self._actions = arg_parser._actions

            def parse_args(self, args=None, namespace=None):
                return parsed_args

        # No return value from _run_hydra() as it may sometimes actually run the
        # task_function multiple times (--multirun).
        _run_hydra(
            args_parser=_argparse_wrapper(args),
            task_function=task_function,
            config_path=config_path,
            config_name=config_name,
        )
@dataclass
class QueueParams:
    slurm: SlurmQueueConf = SlurmQueueConf()
    local: LocalQueueConf = LocalQueueConf()
    auto: AutoQueueConf = AutoQueueConf()


@dataclass
class SubmititConf:
    queue: QueueType = QueueType.slurm
    folder: str = "${hydra.sweep.dir}/.${hydra.launcher.params.queue}"
    queue_parameters: QueueParams = QueueParams()


@dataclass
class SubmititLauncherConf(ObjectConf):
    cls: str = "hydra_plugins.hydra_submitit_launcher.submitit_launcher.SubmititLauncher"
    params: SubmititConf = SubmititConf()


# memory to reserve for the job on each node, in GB
mem_limit: int = 2

ConfigStore.instance().store(
    group="hydra/launcher",
    name="submitit",
    node=SubmititLauncherConf,
    provider="submitit_launcher",
)
from hydra.core.utils import (
    run_job,
    setup_globals,
)
from hydra.plugins.launcher import Launcher
from hydra.types import HydraContext, TaskFunction

log = logging.getLogger(__name__)


@dataclass
class BasicLauncherConf:
    _target_: str = "hydra._internal.core_plugins.basic_launcher.BasicLauncher"


ConfigStore.instance().store(
    group="hydra/launcher", name="basic", node=BasicLauncherConf, provider="hydra"
)


class BasicLauncher(Launcher):
    def __init__(self) -> None:
        super().__init__()
        self.config: Optional[DictConfig] = None
        self.task_function: Optional[TaskFunction] = None
        self.hydra_context: Optional[HydraContext] = None

    def setup(
        self,
        *,
        hydra_context: HydraContext,
        task_function: TaskFunction,
@dataclass
class EnqueueConf:
    # how long successful jobs and their results are kept (e.g. "1d" for 1 day, units: d/h/m/s), default: no limit
    result_ttl: Optional[str] = None
    # specifies how long failed jobs are kept (e.g. "1d" for 1 day, units: d/h/m/s), default: no limit
    failure_ttl: Optional[str] = None
    # place job at the front of the queue, instead of the back
    at_front: bool = False
    # job id, will be overridden automatically by a uuid unless specified explicitly
    job_id: Optional[str] = None
    # description, will be overridden automatically unless specified explicitly
    description: Optional[str] = None


@dataclass
class RQLauncherConf:
    _target_: str = "hydra_plugins.hydra_rq_launcher.rq_launcher.RQLauncher"
    # enqueue configuration
    enqueue: EnqueueConf = EnqueueConf()
    # queue name
    queue: str = "default"
    # redis configuration
    redis: RedisConf = RedisConf()
    # stop after enqueueing by raising custom exception
    stop_after_enqueue: bool = False
    # wait time in seconds when polling results
    wait_polling: float = 1.0


ConfigStore.instance().store(
    group="hydra/launcher", name="rq", node=RQLauncherConf, provider="rq_launcher"
)
def populate() -> None:
    """Register configs."""
    cs = ConfigStore.instance()
    cs.store(name="MaskRCNN", node=MaskRCNNConfig, provider="paddle")
                'env': env,
                'seed': seed,
                'agent.horizon': horizon,
                'learn_temp.init_targ_entr': init_targ_entr,
                'learn_temp.final_targ_entr': final_targ_entr,
                'learn_temp.entr_decay_factor': gamma,
            }
            overrides_i = [f'{k}={v}' for k, v in overrides_i.items()]
            overrides.append(overrides_i)

        random.shuffle(overrides)
        # self.validate_batch_is_legal(overrides)  # Can take a long time
        returns = self.launcher.launch(overrides, initial_job_idx=self.job_idx)
        self.job_idx += len(returns)


@dataclass
class SVGSweeperConf:
    _target_: str = "svg.sweeper.SVGSweeper"


# Hacks for a non-standard plugin
Plugins.is_in_toplevel_plugins_module = lambda x, y: True
Plugins.instance().class_name_to_class['svg.sweeper.SVGSweeper'] = SVGSweeper

ConfigStore.instance().store(
    group="hydra/sweeper",
    name="svg",
    node=SVGSweeperConf,
    provider="svg",
)
@dataclass
class RayAWSLauncherConf:
    # Stop the Ray AWS cluster after jobs are finished
    # (if False, the cluster will remain provisioned and can be started with "ray up cluster.yaml").
    stop_cluster: bool = True
    # sync_up is executed before launching jobs on the cluster.
    # This can be used to sync up source code to the remote cluster for execution.
    # You need to sync up if your code contains multiple modules.
    # source is the local dir, target is the remote dir
    sync_up: RsyncConf = RsyncConf()
    # sync_down is executed after jobs finish on the cluster.
    # This can be used to download job outputs to the local machine,
    # avoiding the hassle of logging on to the remote machine.
    # source is the remote dir, target is the local dir
    sync_down: RsyncConf = RsyncConf()


config_store = ConfigStore.instance()
config_store.store(
    group="hydra/launcher",
    name="ray_local",
    node=RayLocalLauncherConf,
    provider="ray_launcher",
)
config_store.store(
    group="hydra/launcher",
    name="ray_aws",
    node=RayAWSLauncherConf,
    provider="ray_launcher",
)
from dataclasses import dataclass

from omegaconf import DictConfig

import hydra
from hydra.core.config_store import ConfigStore


@dataclass
class MySQLConfig:
    driver: str = "mysql"
    host: str = "localhost"
    port: int = 3306
    user: str = "omry"
    password: str = "secret"


ConfigStore.instance().store(node=MySQLConfig, name="config", path="db")


@hydra.main(config_name="config")
def my_app(cfg: DictConfig) -> None:
    # In order to get type safety you need to tell Python that the type of cfg.db is MySQLConfig:
    db: MySQLConfig = cfg.db
    print(
        f"Connecting to {db.driver} at {db.host}:{db.port}, "
        f"user={db.user}, password={db.password}"
    )


if __name__ == "__main__":
    my_app()
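# A standalone sketch of the runtime type safety the schema provides: OmegaConf
# enforces the dataclass field types on assignment (MySQLConfig is the class above).
from omegaconf import OmegaConf, ValidationError

db_cfg = OmegaConf.structured(MySQLConfig)
db_cfg.port = 3307  # fine: validated as an int
try:
    db_cfg.port = "not a port"  # rejected at runtime
except ValidationError as err:
    print(f"rejected: {err}")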
from dataclasses import dataclass
from typing import Optional

from hydra.core.config_store import ConfigStore


@dataclass
class OptunaConfig:
    direction: str = "minimize"
    storage: Optional[str] = None
    study_name: Optional[str] = None
    n_trials: Optional[int] = 20
    n_jobs: int = 1
    timeout: Optional[float] = None
    # TODO(yanase): Configure sampler and pruner.


@dataclass
class OptunaSweeperConf:
    _target_: str = "hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper"
    optuna_config: OptunaConfig = OptunaConfig()


ConfigStore.instance().store(
    group="hydra/sweeper", name="optuna", node=OptunaSweeperConf, provider="optuna_sweeper"
)
def register_params_dataclass(
    cs: ConfigStore, name: str, group: str, data_class: Type[FairseqDataclass]
) -> None:
    """register params dataclass in config store"""
    node_ = data_class(_name=data_class.name())
    cs.store(name=name, group=group, node=node_)
from hydra.core.utils import (
    run_job,
    setup_globals,
)
from hydra.plugins.launcher import Launcher
from hydra.types import TaskFunction

log = logging.getLogger(__name__)


@dataclass
class SlurmLauncherConf:
    _target_: str = "hydra._internal.core_plugins.slurm_launcher.SlurmLauncher"


ConfigStore.instance().store(
    group="hydra/launcher", name="slurm", node=SlurmLauncherConf, provider="hydra"
)


class SlurmLauncher(Launcher):
    def __init__(self) -> None:
        super().__init__()
        self.config: Optional[DictConfig] = None
        self.config_loader: Optional[ConfigLoader] = None
        self.task_function: Optional[TaskFunction] = None

    def setup(
        self,
        config: DictConfig,
        config_loader: ConfigLoader,
        task_function: TaskFunction,
def my_app(cfg: Config) -> None:
    cs = ConfigStore.instance()
    print(cs.repo)
    print(OmegaConf.to_yaml(cfg))
from dataclasses import dataclass

import hydra
from hydra.core.config_store import ConfigStore


@dataclass
class MySQLConfig:
    driver: str = "mysql"
    host: str = "localhost"
    port: int = 3306
    user: str = "omry"
    password: str = "secret"


@dataclass
class Config:
    db: MySQLConfig = MySQLConfig()


# We no longer need to use the path parameter because Config has the correct structure
ConfigStore.instance().store(node=Config, name="config")


@hydra.main(config_name="config")
def my_app(cfg: Config) -> None:
    # Python knows that the type of cfg.db is MySQLConfig without any additional hints
    print(
        f"Connecting to {cfg.db.driver} at {cfg.db.host}:{cfg.db.port}, "
        f"user={cfg.db.user}, password={cfg.db.password}"
    )


if __name__ == "__main__":
    my_app()
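# Sketch: because Config nests MySQLConfig directly, the full default tree can
# be rendered without any YAML on disk.
from omegaconf import OmegaConf

print(OmegaConf.to_yaml(OmegaConf.structured(Config)))
# db:
#   driver: mysql
#   host: localhost
#   port: 3306
#   ...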
def init_hydra(config_dir: Path) -> None:
    cs = ConfigStore.instance()
    cs.store(name='config', node=PipelineConfig)
    current_file = Path(__file__)
    relative_config_dir = os.path.relpath(config_dir, current_file.parent)
    initialize(relative_config_dir)
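# Typical follow-up to init_hydra() (sketch): once initialize() has run, configs
# can be composed programmatically against the registered "config" node.
# The "conf" directory name here is illustrative, not part of the snippet.
from hydra import compose

init_hydra(Path("conf"))
cfg = compose(config_name="config")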
@hydra.main(config_name='config_zinc', config_path='conf')
def train_with_conf(config: combo_models.ZincTrainingConfiguration):
    trainer = utils.make_trainer(config)
    torch.manual_seed(config.seed)
    path_lengths, cycle_lengths = _expand_to_default(
        config.model.path_lengths, config.model.cycle_lengths)
    transform = Compose([Pathifier(path_lengths), Cyclifier(cycle_lengths)])
    batch_split = max(config.num_gpus, 1)
    dataset = ZincDataModule(config.data,
                             transform=transform,
                             batch_size=config.batch_size // batch_split)
    dataset.prepare_data()
    dataset.setup()
    config.model.atom_feature_cardinality = dataset.atom_feature_cardinality
    fixture = combo_models.ZincPathAndCycleModel(config)
    trainer.fit(fixture, dataset)


if __name__ == '__main__':
    from hydra.core.config_store import ConfigStore

    # ConfigStore is a singleton; use the documented instance() accessor.
    cs = ConfigStore.instance()
    cs.store(name='base_config_zinc', node=combo_models.ZincTrainingConfiguration)
    train_with_conf()
def hydra_runner(
    config_path: Optional[str] = None,
    config_name: Optional[str] = None,
    schema: Optional[Any] = None,
) -> Callable[[TaskFunction], Any]:
    """
    Decorator used for passing the config paths to the main function.
    Optionally registers a schema used for validation/providing default values.

    Args:
        config_path: Path to the directory where the config exists.
        config_name: Name of the config file.
        schema: Structured config type representing the schema used for validation/providing default values.
    """
    if schema is not None:
        # Create config store.
        cs = ConfigStore.instance()
        # Register the configuration as a node under a given name.
        cs.store(name=config_name.replace(".yaml", ""), node=schema)

    def decorator(task_function: TaskFunction) -> Callable[[], None]:
        @functools.wraps(task_function)
        def wrapper(cfg_passthrough: Optional[DictConfig] = None) -> Any:
            # Check if a config was passed in directly.
            if cfg_passthrough is not None:
                return task_function(cfg_passthrough)
            else:
                args = get_args_parser()
                # Parse arguments in order to retrieve overrides
                parsed_args = args.parse_args()
                # Get overriding args in dot string format
                overrides = parsed_args.overrides  # type: list
                # Update overrides
                overrides.append("hydra.run.dir=.")
                overrides.append('hydra.job_logging.root.handlers=null')

                # Wrap a callable object with name `parse_args`.
                # This is to mimic the ArgumentParser.parse_args() API.
                class _argparse_wrapper:
                    def __init__(self, arg_parser):
                        self.arg_parser = arg_parser
                        self._actions = arg_parser._actions

                    def parse_args(self, args=None, namespace=None):
                        return parsed_args

                # No return value from _run_hydra() as it may sometimes actually
                # run the task_function multiple times (--multirun).
                _run_hydra(
                    args_parser=_argparse_wrapper(args),
                    task_function=task_function,
                    config_path=config_path,
                    config_name=config_name,
                    strict=None,
                )

        return wrapper

    return decorator
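# Hypothetical usage of the decorator above; TrainConfig and the config names
# are illustrative stand-ins, not part of the snippet.
from dataclasses import dataclass


@dataclass
class TrainConfig:
    lr: float = 1e-3
    epochs: int = 10


@hydra_runner(config_path="conf", config_name="train.yaml", schema=TrainConfig)
def train(cfg):
    print(cfg)


if __name__ == "__main__":
    train()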
def _load_config_impl(
    self, input_file: str, record_load: bool = True
) -> Tuple[Optional[DictConfig], Optional[LoadTrace]]:
    """
    :param input_file:
    :param record_load:
    :return: the loaded config or None if it was not found
    """

    def record_loading(
        name: str,
        path: Optional[str],
        provider: Optional[str],
        schema_provider: Optional[str],
    ) -> Optional[LoadTrace]:
        trace = LoadTrace(
            filename=name,
            path=path,
            provider=provider,
            schema_provider=schema_provider,
        )
        if record_load:
            self.all_config_checked.append(trace)
        return trace

    ret = self.repository.load_config(config_path=input_file)
    if ret is not None:
        if not isinstance(ret.config, DictConfig):
            raise ValueError(
                f"Config {input_file} must be a Dictionary, got {type(ret.config).__name__}"
            )
        if not ret.is_schema_source:
            try:
                schema = ConfigStore.instance().load(
                    config_path=ConfigSource._normalize_file_name(filename=input_file)
                )
                merged = OmegaConf.merge(schema.node, ret.config)
                assert isinstance(merged, DictConfig)
                return (
                    merged,
                    record_loading(
                        name=input_file,
                        path=ret.path,
                        provider=ret.provider,
                        schema_provider=schema.provider,
                    ),
                )
            except ConfigLoadError:
                # schema not found, ignore
                pass

        return (
            ret.config,
            record_loading(
                name=input_file,
                path=ret.path,
                provider=ret.provider,
                schema_provider=None,
            ),
        )
    else:
        return (
            None,
            record_loading(
                name=input_file, path=None, provider=None, schema_provider=None
            ),
        )
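# A standalone sketch of the schema merge performed above: merging a structured
# schema with a freshly loaded config validates and type-coerces its values.
# DBSchema is a hypothetical stand-in for a node stored in the ConfigStore.
from dataclasses import dataclass

from omegaconf import OmegaConf


@dataclass
class DBSchema:
    host: str = "localhost"
    port: int = 80


loaded = OmegaConf.create({"host": "db.example.com", "port": "3306"})
merged = OmegaConf.merge(OmegaConf.structured(DBSchema), loaded)
assert merged.port == 3306  # the string "3306" was coerced to int by the schema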
elif cfg.validation:
    ckpt = torch.load(
        r'C:\Users\yonio\PycharmProjects\Amygdala_new\outputs\2020-10-25\23-43-07\Untitled\AM-374\checkpoints\epoch=265.ckpt'
    )
    lut = Net.EEGNetwork(**cfg.net).regulate_enc.embedding_lut
    model = Net.IndicesNetwrork(lut, cfg.data.criteria_len, 3)
    model.load_state_dict(ckpt['state_dict'])
    data.phase = 3
    trainer = create_trainer(cfg, ['indices'])
    trainer.test(model, datamodule=data)
else:
    ckpt = torch.load(
        r'C:\Users\yonio\PycharmProjects\Amygdala_new\outputs\2020-10-25\01-03-56\Untitled\AM-328\checkpoints\epoch=499.ckpt'
    )
    eeg_model = Net.EEGNetwork(**cfg.net)
    eeg_model.load_state_dict(ckpt['state_dict'])
    model = Net.IndicesNetwrork(eeg_model.regulate_enc.embedding_lut,
                                cfg.data.criteria_len,
                                bins_num=3)
    train_third_phase(cfg, data, model)


if __name__ == '__main__':
    # ConfigStore is a singleton; use the documented instance() accessor.
    cs = ConfigStore.instance()
    cs.store(name='eeg', node=EEGLearnerConfig)
    main()
except:
    return False


def download_file(url, folder_name):
    local_filename = url.split('/')[-1]
    path = os.path.join("/{}/{}".format(folder_name, local_filename))
    with requests.get(url, stream=True) as r:
        with open(path, 'wb') as f:
            shutil.copyfileobj(r.raw, f)
    return path


ALLOWED_EXTENSIONS = set(['.wav', '.mp3', '.ogg', '.webm'])

cs = ConfigStore.instance()
cs.store(name="config", node=ServerConfig)


@app.errorhandler(404)
def page_not_found(e):
    request.path = request.path.replace("//", "/")
    if "/download/" in request.path:
        req = request.path.split("/download/")
        status = True
        try:
            if req[0] != '' and len(req) < 2:
                status = False
        except:
            status = False
        if status == False:
def test_load_config_with_schema(
    self, hydra_restore_singletons: Any, path: str
) -> None:
    ConfigStore.instance().store(
        name="config", node=TopLevelConfig, provider="this_test"
    )
    ConfigStore.instance().store(
        group="db", name="mysql", node=MySQLConfig, provider="this_test"
    )

    config_loader = ConfigLoaderImpl(
        config_search_path=create_config_search_path(path)
    )
    cfg = config_loader.load_configuration(
        config_name="config", overrides=["+db=mysql"], run_mode=RunMode.RUN
    )

    expected = deepcopy(hydra_load_list)
    expected.append(
        LoadTrace(
            config_path="config",
            package="",
            parent="<root>",
            is_self=False,
            search_path=path,
            provider="main",
        )
    )
    expected.append(
        LoadTrace(
            config_path="db/mysql",
            package="db",
            parent="<root>",
            is_self=False,
            search_path=path,
            provider="main",
        )
    )
    assert_same_composition_trace(cfg.hydra.composition_trace, expected)

    with open_dict(cfg):
        del cfg["hydra"]
    assert cfg == {
        "normal_yaml_config": True,
        "db": {
            "driver": "mysql",
            "host": "???",
            "port": "???",
            "user": "******",
            "password": "******",
        },
    }

    # verify illegal modification is rejected at runtime
    with pytest.raises(ValidationError):
        cfg.db.port = "fail"

    # verify illegal override is rejected during load
    with pytest.raises(HydraException):
        config_loader.load_configuration(
            config_name="db/mysql", overrides=["db.port=fail"], run_mode=RunMode.RUN
        )