def test_assignment_of_subclass(self, class_type: str) -> None:
    """Assigning a ConcretePlugin config to a field created from Plugin keeps ConcretePlugin.

    The nested ``params`` node must also report ConcretePlugin.FoobarParams.
    """
    mod: Any = import_module(class_type)
    concrete_cls = mod.ConcretePlugin
    conf = OmegaConf.create({"plugin": mod.Plugin})
    conf.plugin = OmegaConf.structured(concrete_cls)
    assert OmegaConf.get_type(conf.plugin) == concrete_cls
    params_type = OmegaConf.get_type(conf.plugin.params)
    assert params_type == concrete_cls.FoobarParams
def my_app(cfg: Config) -> None:
    """Print the model, dataset and run settings held in ``cfg``.

    The model and dataset sections are identified by their structured-config
    type, and only the fields of the matching type are printed. The function
    finishes by printing the whole config flattened into one dictionary.
    """
    # NOTE: the `{name=}` f-strings below print the expression text itself,
    # so the local names mlp_cfg / svm_cfg / adult_cfg / cmnist_cfg are
    # part of the output and must not be renamed.
    model_type = OmegaConf.get_type(cfg.model)
    if model_type is MlpConfig:
        mlp_cfg = cast(MlpConfig, cfg.model)
        print("using MLP")
        print(f"{mlp_cfg.layers=}")
        print(f"{mlp_cfg.hidden_units=}")
    elif model_type is SVMConfig:
        svm_cfg = cast(SVMConfig, cfg.model)
        print("using SVM")
        print(f"{svm_cfg.kernel=}")
        print(f"{svm_cfg.C=}")
    print()

    data_dir: Path = instantiate(cfg.dataset.dir)
    print(data_dir)

    dataset_type = OmegaConf.get_type(cfg.dataset)
    if dataset_type is AdultConfig:
        adult_cfg = cast(AdultConfig, cfg.dataset)
        print("using Adult dataset")
        print(f"{adult_cfg.drop_native=}")
    elif dataset_type is CmnistConfig:
        cmnist_cfg = cast(CmnistConfig, cfg.dataset)
        print("using CMNIST dataset")
        print(f"{cmnist_cfg.padding=}")
    print()

    print(f"{cfg.seed=}")
    print(f"{cfg.use_wandb=}")
    print(f"{cfg.data_pcnt=}")
    print()

    print("Config as flat dictionary:")
    flat = flatten(OmegaConf.to_container(cfg, enum_to_str=True))
    print(flat)
def test_overlapping_schemas(hydra_restore_singletons: Any) -> None:
    """A ``plugin`` group override replaces the schema's Plugin with ConcretePlugin.

    Without overrides the plugin fields are MISSING and typed as Plugin. With
    ``+plugin=concrete`` they hold ConcretePlugin's values and types. Assigning
    an int to the typed field must then fail validation.
    """
    store = ConfigStore.instance()
    store.store(name="config", node=Config)
    store.store(group="plugin", name="concrete", node=ConcretePlugin)
    loader = ConfigLoaderImpl(config_search_path=create_config_search_path(None))

    def load_without_hydra(overrides: Any) -> Any:
        # Load the primary config and strip the injected `hydra` node so the
        # comparison below only sees user-level keys.
        loaded = loader.load_configuration(
            config_name="config", overrides=overrides, run_mode=RunMode.RUN
        )
        with open_dict(loaded):
            del loaded["hydra"]
        return loaded

    base = load_without_hydra([])
    assert base == {"plugin": {"name": "???", "params": "???"}}
    assert OmegaConf.get_type(base.plugin) == Plugin

    concrete = load_without_hydra(["+plugin=concrete"])
    assert concrete == {"plugin": {"name": "foobar_plugin", "params": {"foo": 10}}}
    assert OmegaConf.get_type(concrete.plugin) == ConcretePlugin
    assert OmegaConf.get_type(concrete.plugin.params) == ConcretePlugin.FoobarParams
    with pytest.raises(ValidationError):
        concrete.plugin = 10
def configure_optimizers(self):
    """Build the optimizer and ExponentialLR scheduler described by ``self.optim_cfg``.

    SGDConfig selects SGD with Nesterov momentum, and AdamConfig selects AdamW.
    Both take the config's learning rate and weight decay. The scheduler uses
    ``learning_anneal`` as its gamma.

    Returns:
        ``[optimizer], [scheduler]`` — one-element lists of each.

    Raises:
        ValueError: if ``self.optim_cfg`` is neither an SGDConfig nor an AdamConfig.
    """
    optim = self.optim_cfg
    optim_type = OmegaConf.get_type(optim)
    if optim_type is SGDConfig:
        optimizer = torch.optim.SGD(
            params=self.parameters(),
            lr=optim.learning_rate,
            momentum=optim.momentum,
            nesterov=True,
            weight_decay=optim.weight_decay,
        )
    elif optim_type is AdamConfig:
        optimizer = torch.optim.AdamW(
            params=self.parameters(),
            lr=optim.learning_rate,
            betas=optim.betas,
            eps=optim.eps,
            weight_decay=optim.weight_decay,
        )
    else:
        raise ValueError("Optimizer has not been specified correctly.")

    scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer=optimizer, gamma=optim.learning_anneal
    )
    return [optimizer], [scheduler]
def _validate_merge(self, value: Any) -> None:
    """Check that ``value`` may be merged into this node.

    Args:
        value: the source of the merge. It is accessed via ``_metadata``,
            ``_is_missing()`` and ``_is_none()``, so it is presumably an
            OmegaConf node/container — TODO confirm against callers.

    Raises:
        ValidationError: when both object types are known, this node's type
            is a structured config, ``value`` is not None, ``value``'s type is
            not a dict type, and it is not a subclass of this node's type.
            The delegated ``_validate_non_optional`` / ``_validate_set`` calls
            may raise as well.
    """
    from omegaconf import OmegaConf

    dest = self
    src = value

    # presumably rejects a None source when this node is non-optional
    self._validate_non_optional(None, src)

    dest_obj_type = OmegaConf.get_type(dest)
    src_obj_type = OmegaConf.get_type(src)

    # A MISSING destination receiving a typed source is validated as if the
    # source's value were being assigned to it.
    if dest._is_missing() and src._metadata.object_type is not None:
        self._validate_set(key=None, value=_get_value(src))

    # Nothing to type-check when the source itself is MISSING.
    if src._is_missing():
        return

    # Only structured-config destinations are type-checked; None sources and
    # plain dict sources are let through here.
    validation_error = (dest_obj_type is not None
                        and src_obj_type is not None
                        and is_structured_config(dest_obj_type)
                        and not src._is_none()
                        and not is_dict(src_obj_type)
                        and not issubclass(src_obj_type, dest_obj_type))
    if validation_error:
        msg = (f"Merge error: {type_str(src_obj_type)} is not a "
               f"subclass of {type_str(dest_obj_type)}. value: {src}")
        raise ValidationError(msg)
def test_merge(self, module: Any) -> None:
    """Merging a ConcretePlugin field over a Plugin field yields ConcretePlugin.

    The nested ``params`` node of the result must report ConcretePlugin.FoobarParams.
    """
    concrete_cls = module.ConcretePlugin
    base_conf = OmegaConf.create({"plugin": module.Plugin})
    override_conf = OmegaConf.create({"plugin": concrete_cls})
    assert override_conf.plugin == concrete_cls

    merged: Any = OmegaConf.merge(base_conf, override_conf)
    assert OmegaConf.get_type(merged.plugin) == concrete_cls
    assert OmegaConf.get_type(merged.plugin.params) == concrete_cls.FoobarParams
def test_promote_to_object(self, module: Any) -> None:
    """Promoting with a BoolConfig instance switches the type and takes its field values."""
    conf = OmegaConf.create(module.AnyTypeConfig)
    assert OmegaConf.get_type(conf) == module.AnyTypeConfig

    replacement = module.BoolConfig(with_default=False)
    conf._promote(replacement)

    assert OmegaConf.get_type(conf) == module.BoolConfig
    assert conf.with_default is False
def test_merge_with_subclass_into_missing(self, module: Any) -> None:
    """Merging a class into a MISSING typed field gives the result an object type.

    Before the merge, ``base.missing`` has ref type Plugin and no object type.
    After the merge:
      * ``base`` itself is unchanged (merge must not mutate its inputs);
      * the result keeps the PluginHolder type and Plugin as the field's ref
        type, and now also has Plugin as the field's object type.
    """
    base = OmegaConf.structured(module.PluginHolder)
    assert _utils.get_ref_type(base, "missing") == module.Plugin
    assert OmegaConf.get_type(base, "missing") is None

    res = OmegaConf.merge(base, {"missing": module.Plugin})
    assert OmegaConf.get_type(res) == module.PluginHolder

    # The input config must be left untouched by the merge.
    assert _utils.get_ref_type(base, "missing") == module.Plugin
    assert OmegaConf.get_type(base, "missing") is None

    # The merged result keeps the declared ref type and gains the object type.
    # (The original only re-checked `base` here, never `res`'s ref type.)
    assert _utils.get_ref_type(res, "missing") == module.Plugin
    assert OmegaConf.get_type(res, "missing") == module.Plugin
def __init__(self, labels: List,
             model_cfg: Union[UniDirectionalConfig, BiDirectionalConfig, ConvolutionConfig],
             precision: int,
             optim_cfg: Union[AdamConfig, SGDConfig],
             spect_cfg: SpectConfig):
    """Build the DeepSpeech network, CTC loss, greedy decoder and WER/CER metrics.

    Args:
        labels: output label vocabulary. The index of ``'_'`` is used as the
            CTC blank, so ``labels.index('_')`` raises ValueError if it is absent.
        model_cfg: its structured-config type selects the variant.
            ConvolutionConfig builds the convolutional head
            (``_conv_construct``); any other type builds the RNN stack
            (``_rnn_construct``). BiDirectionalConfig sets ``self.bidirectional``.
        precision: stored as ``self.precision``.
        optim_cfg: stored as ``self.optim_cfg``.
        spect_cfg: ``sample_rate * window_size`` determines the input size fed
            through the convolution-size formula below.
    """
    super().__init__()
    self.save_hyperparameters()
    self.model_cfg = model_cfg
    self.precision = precision
    self.optim_cfg = optim_cfg
    self.spect_cfg = spect_cfg

    # The concrete config class (not its fields) selects the architecture.
    model_type = OmegaConf.get_type(model_cfg)
    self.convolutional = model_type is ConvolutionConfig
    self.bidirectional = model_type is BiDirectionalConfig

    self.labels = labels

    self.conv = MaskConv(
        nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2), padding=(20, 5)),
            nn.BatchNorm2d(32),
            nn.Hardtanh(0, 20, inplace=True),
            nn.Conv2d(32, 32, kernel_size=(21, 11), stride=(2, 1), padding=(10, 5)),
            nn.BatchNorm2d(32),
            nn.Hardtanh(0, 20, inplace=True),
        )
    )
    # Based on above convolutions and spectrogram size using conv formula (W - F + 2P)/ S+1
    # Each step applies the first-dimension kernel/padding/stride of one Conv2d
    # above (41/20/2, then 21/10/2). int() truncation equals floor for the
    # non-negative sizes involved.
    rnn_input_size = int(
        math.floor((self.spect_cfg.sample_rate * self.spect_cfg.window_size) / 2) + 1)
    rnn_input_size = int(math.floor(rnn_input_size + 2 * 20 - 41) / 2 + 1)
    rnn_input_size = int(math.floor(rnn_input_size + 2 * 10 - 21) / 2 + 1)
    rnn_input_size *= 32  # 32 output channels of the last Conv2d

    if not self.convolutional:
        self.rnns, self.lookahead, self.fc = self._rnn_construct(rnn_input_size)
    else:
        self.deep_conv, self.fc = self._conv_construct(rnn_input_size)

    self.inference_softmax = InferenceBatchSoftmax()
    self.criterion = CTCLoss(blank=self.labels.index('_'), reduction='sum', zero_infinity=True)
    self.evaluation_decoder = GreedyDecoder(self.labels)  # Decoder used for validation
    self.wer = WordErrorRate(decoder=self.evaluation_decoder,
                             target_decoder=self.evaluation_decoder)
    self.cer = CharErrorRate(decoder=self.evaluation_decoder,
                             target_decoder=self.evaluation_decoder)
def my_app(cfg: Config) -> None:
    """Connect to the database described by ``cfg.db``.

    Raises:
        ValueError: if the structured type of ``cfg.db`` is neither MySQLConfig
            nor PostGreSQLConfig. The message names the type that was found.
    """
    # Remember that the actual type of Config and db inside it is DictConfig.
    # If you need to get the underlying type of a config object use OmegaConf.get_type:
    db_type = OmegaConf.get_type(cfg.db)
    if db_type is MySQLConfig:
        connect_mysql(cast(MySQLConfig, cfg.db))
    elif db_type is PostGreSQLConfig:
        connect_postgresql(cast(PostGreSQLConfig, cfg.db))
    else:
        raise ValueError(f"Unsupported database config type: {db_type}")
def test_get_type() -> None:
    """get_type reports User for configs built from the class, an instance, or a nested field."""
    from_class = OmegaConf.structured(User)
    assert OmegaConf.get_type(from_class) == User

    from_instance = OmegaConf.structured(User(name="bond"))
    assert OmegaConf.get_type(from_instance) == User

    nested = OmegaConf.create({"user": User})
    assert OmegaConf.get_type(nested.user) == User
def test_promote_to_class(self, module: Any) -> None:
    """Promoting with the BoolConfig class switches the type and applies its defaults.

    After promotion, ``with_default`` is True, ``null_default`` is None and
    ``mandatory_missing`` is MISSING.
    """
    conf = OmegaConf.create(module.AnyTypeConfig)
    assert OmegaConf.get_type(conf) == module.AnyTypeConfig

    target_cls = module.BoolConfig
    conf._promote(target_cls)

    assert OmegaConf.get_type(conf) == target_cls
    assert conf.with_default is True
    assert conf.null_default is None
    assert OmegaConf.is_missing(conf, "mandatory_missing")
def hydra_main(cfg: ExampleEEGConfig):
    """Entry point: pick a hyper-parameter grid for the configured model and run it.

    Enables console output for the ``ml`` logger at INFO level. The grid is
    chosen by the structured-config type of ``cfg.train.model`` (CNNConfig,
    CNNRNNConfig, RNNConfig, or anything else). ``main`` is then called with
    the config, the experiment directory and the grid. When ``cfg.mlflow`` is
    falsy, the ``mlruns`` directory is deleted afterwards.
    """
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("[%(name)s] [%(levelname)s] %(message)s"))
    handler.setLevel(logging.INFO)
    logging.getLogger("ml").addHandler(handler)

    model_type = OmegaConf.get_type(cfg.train.model)
    if model_type == CNNConfig:
        hyperparameters = {
            'transformer.transform': ['none'],
            'train.model.channel_list': [[4, 8, 16, 32]],
            'train.model.kernel_sizes': [[[4]] * 4],
            'train.model.stride_sizes': [[[2]] * 4],
            'train.model.padding_sizes': [[[1]] * 4],
            'train.model.optim.lr': [1e-4],
        }
    elif model_type == CNNRNNConfig:
        hyperparameters = {
            'train.model.optim.lr': [1e-3, 1e-4, 1e-5],
            'transformer.transform': ['none'],
            'train.model.channel_list': [[4, 8, 16, 32]],
            'train.model.kernel_sizes': [[[4]] * 4],
            'train.model.stride_sizes': [[[2]] * 4],
            'train.model.padding_sizes': [[[1]] * 4],
            'train.model.rnn_type': [cfg.train.model.rnn_type],
            'train.model.bidirectional': [True],
            'train.model.rnn_n_layers': [1, 2],
            'train.model.rnn_hidden_size': [10, 50],
        }
    elif model_type == RNNConfig:
        hyperparameters = {
            'train.model.bidirectional': [True, False],
            'train.model.rnn_type': ['lstm', 'gru'],
            'train.model.rnn_n_layers': [1, 2],
            'train.model.rnn_hidden_size': [10, 50],
            'transformer.transform': ['none'],
            'train.model.optim.lr': [1e-4],
        }
    else:
        hyperparameters = {
            'train.model.optim.lr': [1e-4],
            'batch_size': [16],
            'transformer.transform': ['logmel'],
            'loss_func': ['ce'],
            'epoch_rate': [1.0],
            'sample_balance': ['same'],
        }

    cfg.expt_id = f'{cfg.train.model_type.value}'
    # NOTE(review): this EEG entry point writes under 'example_face', which
    # looks copied from the face example — confirm the intended directory.
    expt_dir = Path(utils.to_absolute_path('output')).joinpath('example_face', f'{cfg.expt_id}')
    expt_dir.mkdir(exist_ok=True, parents=True)
    main(cfg, expt_dir, hyperparameters)

    if not cfg.mlflow:
        # NOTE(review): rmtree raises FileNotFoundError if 'mlruns' does not exist.
        shutil.rmtree('mlruns')
def _set_optimizer(self):
    """Create the torch optimizer selected by ``self.cfg.optim``.

    Returns:
        ``torch.optim.Adam`` for an AdamConfig, or ``torch.optim.SGD`` with
        Nesterov momentum for an SGDConfig. Both are built over
        ``self.model.parameters()``.

    Raises:
        ValueError: if ``self.cfg.optim`` is neither an AdamConfig nor an
            SGDConfig. This fails loudly instead of returning None.
    """
    optim_cfg = self.cfg.optim
    optim_type = OmegaConf.get_type(optim_cfg)
    if optim_type == AdamConfig:
        return torch.optim.Adam(
            self.model.parameters(),
            lr=optim_cfg.lr,
            weight_decay=optim_cfg.weight_decay,
        )
    if optim_type == SGDConfig:
        return torch.optim.SGD(
            self.model.parameters(),
            lr=optim_cfg.lr,
            momentum=optim_cfg.momentum,
            weight_decay=optim_cfg.weight_decay,
            nesterov=True,
        )
    raise ValueError(f"Unsupported optimizer config type: {optim_type}")
def test_get_ref_type_with_conflict() -> None:
    """An interpolation takes both type and ref type from its target.

    ``inter`` is declared with ``ref_type=Plugin`` but points at ``${user}``,
    so it reports User / Optional[User], exactly like ``user``.
    """
    conf = OmegaConf.create(
        {"user": User, "inter": DictConfig(ref_type=Plugin, content="${user}")}
    )
    expected_ref = Optional[User]

    assert OmegaConf.get_type(conf.user) == User
    assert _utils.get_ref_type(conf.user) == expected_ref

    # Interpolation inherits both type and ref type from the target
    assert OmegaConf.get_type(conf.inter) == User
    assert _utils.get_ref_type(conf.inter) == expected_ref
def test_get_type(self, class_type: str) -> None:
    """Configs built from LinkedList report its type, and a MISSING ``head`` is typed as LinkedList."""
    mod: Any = import_module(class_type)
    node_cls = mod.LinkedList

    list_conf = OmegaConf.create(node_cls)
    assert OmegaConf.get_type(list_conf) == node_cls
    assert list_conf.next is None
    assert OmegaConf.is_missing(list_conf, "value")

    holder = OmegaConf.create(mod.MissingTest.Missing1)
    assert OmegaConf.is_missing(holder, "head")
    assert OmegaConf.get_type(holder, "head") == node_cls
def test_plugin_merge(self, class_type: str) -> None:
    """Merging a Plugin base with each derived config yields that derived config.

    Checked for ConcretePlugin and PluginWithAdditionalField: the result
    equals the derived config and reports its type.
    """
    mod: Any = import_module(class_type)
    base = OmegaConf.structured(mod.Plugin)

    for derived_cls in (mod.ConcretePlugin, mod.PluginWithAdditionalField):
        derived = OmegaConf.structured(derived_cls)
        merged = OmegaConf.merge(base, derived)
        assert merged == derived
        assert OmegaConf.get_type(merged) == derived_cls
def test_get_type(self, module: Any) -> None:
    """A None ``next`` and a MISSING ``head`` carry a ref type but no object type."""
    node_cls = module.LinkedList

    list_conf = OmegaConf.create(node_cls)
    assert OmegaConf.get_type(list_conf) == node_cls
    assert _utils.get_ref_type(list_conf, "next") == Optional[node_cls]
    assert OmegaConf.get_type(list_conf, "next") is None
    assert list_conf.next is None
    assert OmegaConf.is_missing(list_conf, "value")

    holder = OmegaConf.create(module.MissingTest.Missing1)
    assert OmegaConf.is_missing(holder, "head")
    assert _utils.get_ref_type(holder, "head") == node_cls
    assert OmegaConf.get_type(holder, "head") is None
def test_instantiated_regular_class_container_types_partial__recursive(
    instantiate_func: Any, config: Any
) -> None:
    """With ConvertMode.PARTIAL, the nested SimpleClass gets a plain dict and a User DictConfig.

    ``ret.a.a`` is a plain dict, while ``ret.a.b`` stays a DictConfig whose
    object type is User.
    """
    result = instantiate_func(config, _convert_=ConvertMode.PARTIAL)
    inner = result.a
    assert isinstance(inner, SimpleClass)
    assert isinstance(inner.a, dict)
    assert isinstance(inner.b, DictConfig)
    assert OmegaConf.get_type(inner.b) is User
def _validate_set(self, key: Any, value: Any) -> None:
    """Validate assigning ``value`` to ``key`` of this DictConfig.

    When ``key`` is None, this node itself is the target. Interpolation values
    are accepted without checks. ``"???"`` and ``None`` are accepted after the
    optionality check. Type checks apply only when the target is a DictConfig
    whose ref type is something other than ``Any`` or ``dict``.

    Raises:
        ValidationError: when the target ref type is a container annotation
            but the value's type is not.
        Whatever ``_raise_invalid_value`` raises: when the value's type is not
            a subclass of the target ref type.
    """
    from omegaconf import OmegaConf

    vk = get_value_kind(value)
    # Interpolations are resolved later, so they cannot be type-checked here.
    if vk in (ValueKind.INTERPOLATION, ValueKind.STR_INTERPOLATION):
        return
    self._validate_non_optional(key, value)
    # MISSING ("???") and None need no type check beyond optionality.
    if value == "???" or value is None:
        return

    target = self._get_node(key) if key is not None else self

    # Only DictConfig targets with a concrete ref type constrain the value.
    target_has_ref_type = isinstance(
        target, DictConfig) and target._metadata.ref_type not in (Any, dict)
    is_valid_target = target is None or not target_has_ref_type

    if is_valid_target:
        return

    target_type = target._metadata.ref_type  # type: ignore
    value_type = OmegaConf.get_type(value)

    # Dict-to-dict assignment is always allowed.
    if is_dict(value_type) and is_dict(target_type):
        return
    if is_container_annotation(
            target_type) and not is_container_annotation(value_type):
        raise ValidationError(
            f"Cannot assign {type_str(value_type)} to {type_str(target_type)}"
        )
    validation_error = (target_type is not None
                        and value_type is not None
                        and not issubclass(value_type, target_type))
    if validation_error:
        self._raise_invalid_value(value, value_type, target_type)
def _validate_merge(self, key: Any, value: Any) -> None:
    """Validate merging ``value`` into ``key`` of this DictConfig.

    When ``key`` is None, this node itself is the target. This check is looser
    than assignment: generic-container targets and dict-typed values are let
    through, so a deeper merge can validate them.

    Raises:
        Whatever ``_validate_non_optional`` or ``_raise_invalid_value`` raise;
        the latter is called when the value's type is not a subclass of the
        target ref type (and is not a dict type).
    """
    from omegaconf import OmegaConf

    self._validate_non_optional(key, value)

    target = self._get_node(key) if key is not None else self

    # Only DictConfig targets with a concrete ref type constrain the value.
    target_has_ref_type = isinstance(
        target, DictConfig
    ) and target._metadata.ref_type not in (Any, dict)
    is_valid_value = target is None or not target_has_ref_type

    if is_valid_value:
        return

    target_type = target._metadata.ref_type  # type: ignore
    value_type = OmegaConf.get_type(value)

    if is_generic_container(target_type):
        return

    # Merging of a dictionary is allowed even if assignment is illegal (merge would do deeper checks)
    validation_error = (
        target_type is not None
        and value_type is not None
        and not issubclass(value_type, target_type)
        and not is_dict(value_type)
    )
    if validation_error:
        self._raise_invalid_value(value, value_type, target_type)
def test_instantiated_regular_class_container_types_partial(
    instantiate_func: Any,
) -> None:
    """With ConvertMode.PARTIAL, ``{}`` becomes a dict while a User instance stays a DictConfig."""
    target_spec = {"_target_": "tests.instantiate.SimpleClass", "a": {}, "b": User()}
    obj = instantiate_func(target_spec, _convert_=ConvertMode.PARTIAL)

    assert isinstance(obj.a, dict)
    structured_arg = obj.b
    assert isinstance(structured_arg, DictConfig)
    assert OmegaConf.get_type(structured_arg) is User
def test_plugin_merge_2(self, module: Any) -> None:
    """Merging PluginWithAdditionalField over Plugin yields a config equal to, and typed as, the former."""
    extended_cls = module.PluginWithAdditionalField
    base_conf = OmegaConf.structured(module.Plugin)
    extended_conf = OmegaConf.structured(extended_cls)

    merged = OmegaConf.merge(base_conf, extended_conf)
    assert merged == extended_conf
    assert OmegaConf.get_type(merged) == extended_cls
def test_merged_type1(self, class_type: str) -> None:
    """Merging a structured config into an empty config yields the structured type.

    Here the structured config is the last one merged, so its type wins.
    """
    mod: Any = import_module(class_type)
    schema = mod.WithDictField
    structured = OmegaConf.structured(schema)
    merged = OmegaConf.merge(OmegaConf.create(), structured)
    assert OmegaConf.get_type(merged) == schema
def test_merged_with_subclass(self, class_type: str) -> None:
    """Merging ConcretePlugin over Plugin gives the result the type of the last merged config."""
    mod: Any = import_module(class_type)
    base_conf = OmegaConf.structured(mod.Plugin)
    subclass_conf = OmegaConf.structured(mod.ConcretePlugin)
    merged = OmegaConf.merge(base_conf, subclass_conf)
    assert OmegaConf.get_type(merged) == mod.ConcretePlugin
def test_merged_type2(self, class_type: str) -> None:
    """Merging a plain dict into a structured config keeps the structured type.

    The last merged item here is an untyped dict, so it does not replace the
    WithDictField type of the base config.
    """
    mod: Any = import_module(class_type)
    schema = mod.WithDictField
    structured = OmegaConf.structured(schema)
    patch = {"dict": {"foo": 99}}
    merged = OmegaConf.merge(structured, patch)
    assert OmegaConf.get_type(merged) == schema
def train(cfg: DeepSpeechConfig):
    """Train a DeepSpeech model as described by ``cfg``.

    Steps:
      1. Seed RNGs and load the label list (JSON) from ``cfg.data.labels_path``.
      2. If ``cfg.trainer.checkpoint_callback`` is set, create a GCS or
         file-based checkpoint handler, chosen by the type of
         ``cfg.checkpoint``. When ``cfg.load_auto_checkpoint`` is set, resume
         from the latest checkpoint it finds.
      3. Build the data module, the model and the trainer, then run ``fit``.
    """
    seed_everything(cfg.seed)

    # JSON text is UTF-8 (RFC 8259); don't rely on the platform's default
    # encoding, which would mis-decode non-ASCII labels on some systems.
    with open(to_absolute_path(cfg.data.labels_path), encoding="utf-8") as label_file:
        labels = json.load(label_file)

    # Always bound, so the callbacks list below cannot hit an undefined name.
    checkpoint_callback = None
    if cfg.trainer.checkpoint_callback:
        if OmegaConf.get_type(cfg.checkpoint) is GCSCheckpointConfig:
            checkpoint_callback = GCSCheckpointHandler(cfg=cfg.checkpoint)
        else:
            checkpoint_callback = FileCheckpointHandler(cfg=cfg.checkpoint)
        if cfg.load_auto_checkpoint:
            resume_from_checkpoint = checkpoint_callback.find_latest_checkpoint()
            if resume_from_checkpoint:
                cfg.trainer.resume_from_checkpoint = resume_from_checkpoint

    print(cfg.trainer.gpus)
    data_loader = DeepSpeechDataModule(labels=labels,
                                       data_cfg=cfg.data,
                                       normalize=True,
                                       is_distributed=cfg.trainer.gpus > 1)

    model = DeepSpeech(labels=labels,
                       model_cfg=cfg.model,
                       optim_cfg=cfg.optim,
                       precision=cfg.trainer.precision,
                       spect_cfg=cfg.data.spect)

    trainer = hydra.utils.instantiate(
        config=cfg.trainer,
        replace_sampler_ddp=False,
        callbacks=[checkpoint_callback] if checkpoint_callback is not None else None,
    )
    trainer.fit(model, data_loader)
def _validate_set(self, key: Any, value: Any) -> None:
    """Validate assigning ``value`` at index ``key`` of this ListConfig.

    Runs ``self._validate_get(key, value)`` first.

    Raises:
        ReadonlyConfigError: if the list has the ``readonly`` flag.
        ValidationError: if ``value`` is None and the existing node at ``key``
            is not optional, or if the list's element type is a structured
            config and ``value``'s type is not a subclass of it.
    """
    from omegaconf import OmegaConf

    self._validate_get(key, value)

    if self._get_flag("readonly"):
        raise ReadonlyConfigError("ListConfig is read-only")

    # Only an existing (non-negative, in-range) index has a node whose
    # optionality can be checked.
    # NOTE(review): negative indices skip this None/optional check.
    if 0 <= key < self.__len__():
        target = self._get_node(key)
        if target is not None:
            assert isinstance(target, Node)
            if value is None and not target._is_optional():
                # "$FULL_KEY" is presumably substituted with the node's key
                # path when the error is formatted — confirm.
                raise ValidationError(
                    "$FULL_KEY is not optional and cannot be assigned None"
                )

    target_type = self._metadata.element_type
    value_type = OmegaConf.get_type(value)
    # Subclass checks apply only to lists of structured configs.
    if is_structured_config(target_type):
        if (
            target_type is not None
            and value_type is not None
            and not issubclass(value_type, target_type)
        ):
            msg = (
                f"Invalid type assigned : {type_str(value_type)} is not a "
                f"subclass of {type_str(target_type)}. value: {value}"
            )
            raise ValidationError(msg)
def hydra_main(cfg: ExampleFaceConfig):
    """Entry point: pick a hyper-parameter grid for the configured model and run it.

    Enables console output for the ``ml`` logger at INFO level. The grid is
    chosen by the structured-config type of ``cfg.train.model`` (CNNConfig,
    CNNRNNConfig, RNNConfig, or anything else). The experiment id combines
    the model type with the ``pretrained`` flag. When ``cfg.mlflow`` is falsy,
    the ``mlruns`` directory is deleted afterwards.
    """
    handler = logging.StreamHandler()
    handler.setFormatter(
        logging.Formatter("[%(name)s] [%(levelname)s] %(message)s"))
    handler.setLevel(logging.INFO)
    logging.getLogger("ml").addHandler(handler)

    model_type = OmegaConf.get_type(cfg.train.model)
    if model_type == CNNConfig:
        hyperparameters = {
            'train.model.optim.lr': [1e-4],
        }
    elif model_type == CNNRNNConfig:
        # NOTE(review): reads top-level cfg.rnn_type (not cfg.train.model.rnn_type)
        # and uses un-prefixed grid keys — confirm these match the config schema.
        hyperparameters = {
            'train.model.optim.lr': [1e-3, 1e-4, 1e-5],
            'window_size': [0.5],
            'window_stride': [0.1],
            'transform': ['logmel'],
            'rnn_type': [cfg.rnn_type],
            'bidirectional': [True],
            'rnn_n_layers': [1],
            'rnn_hidden_size': [10],
        }
    elif model_type == RNNConfig:
        hyperparameters = {
            'bidirectional': [True, False],
            'rnn_type': ['lstm', 'gru'],
            'rnn_n_layers': [1, 2],
            'rnn_hidden_size': [10, 50],
            'transform': [None],
            'train.model.optim.lr': [1e-3, 1e-4, 1e-5],
        }
    else:
        hyperparameters = {
            'train.model.optim.lr': [1e-4, 1e-5],
            'data.batch_size': [64],
            'data.epoch_rate': [1.0],
            'data.sample_balance': ['same'],
        }

    cfg.expt_id = f'{cfg.train.model_type.value}_pretrain-{cfg.train.model.pretrained}'
    expt_dir = Path(utils.to_absolute_path('output')).joinpath('example_face', f'{cfg.expt_id}')
    expt_dir.mkdir(exist_ok=True, parents=True)
    main(cfg, expt_dir, hyperparameters)

    if not cfg.mlflow:
        # NOTE(review): rmtree raises FileNotFoundError if 'mlruns' does not exist.
        shutil.rmtree('mlruns')
def test_missing1(self, module: Any) -> None:
    """A MISSING ``head`` has no object type yet, and assigning an int to it fails validation."""
    conf = OmegaConf.create(module.MissingTest.Missing1)
    assert OmegaConf.is_missing(conf, "head")
    assert OmegaConf.get_type(conf, "head") is None

    with raises(ValidationError):
        conf.head = 10