Exemplo n.º 1
0
 def test_assignment_of_subclass(self, class_type: str) -> None:
     """Assigning a structured ConcretePlugin to a Plugin-typed node keeps ConcretePlugin types."""
     mod: Any = import_module(class_type)
     cfg = OmegaConf.create({"plugin": mod.Plugin})
     concrete_node = OmegaConf.structured(mod.ConcretePlugin)
     cfg.plugin = concrete_node

     plugin_type = OmegaConf.get_type(cfg.plugin)
     assert plugin_type == mod.ConcretePlugin
     params_type = OmegaConf.get_type(cfg.plugin.params)
     assert params_type == mod.ConcretePlugin.FoobarParams
Exemplo n.º 2
0
def my_app(cfg: Config) -> None:
    """Print the selected model and dataset settings, then ``cfg`` as a flat dict.

    The concrete config classes of ``cfg.model`` and ``cfg.dataset`` are
    identified with ``OmegaConf.get_type``; unrecognised types print nothing
    type-specific.
    """
    model_type = OmegaConf.get_type(cfg.model)
    if model_type is MlpConfig:
        mlp_cfg = cast(MlpConfig, cfg.model)
        print("using MLP")
        print(f"{mlp_cfg.layers=}")
        print(f"{mlp_cfg.hidden_units=}")
    elif model_type is SVMConfig:
        svm_cfg = cast(SVMConfig, cfg.model)
        print("using SVM")
        print(f"{svm_cfg.kernel=}")
        print(f"{svm_cfg.C=}")
    print()

    # ``cfg.dataset.dir`` is instantiated; the result is annotated as a Path.
    data_dir: Path = instantiate(cfg.dataset.dir)
    print(data_dir)
    dataset_type = OmegaConf.get_type(cfg.dataset)
    if dataset_type is AdultConfig:
        adult_cfg = cast(AdultConfig, cfg.dataset)
        print("using Adult dataset")
        print(f"{adult_cfg.drop_native=}")
    elif dataset_type is CmnistConfig:
        cmnist_cfg = cast(CmnistConfig, cfg.dataset)
        print("using CMNIST dataset")
        print(f"{cmnist_cfg.padding=}")
    print()

    # NOTE: the local names above (mlp_cfg, svm_cfg, ...) and ``cfg`` appear
    # verbatim in the output because of the self-documenting ``=`` f-strings.
    print(f"{cfg.seed=}")
    print(f"{cfg.use_wandb=}")
    print(f"{cfg.data_pcnt=}")
    print()

    print("Config as flat dictionary:")
    container = OmegaConf.to_container(cfg, enum_to_str=True)
    print(flatten(container))
Exemplo n.º 3
0
def test_overlapping_schemas(hydra_restore_singletons: Any) -> None:
    """Selecting plugin=concrete replaces the base Plugin schema with ConcretePlugin."""
    store = ConfigStore.instance()
    store.store(name="config", node=Config)
    store.store(group="plugin", name="concrete", node=ConcretePlugin)

    loader = ConfigLoaderImpl(
        config_search_path=create_config_search_path(None))

    def _load_without_hydra(overrides: Any) -> Any:
        # Compose "config" with ``overrides`` and drop the "hydra" node so only
        # the user-facing keys take part in the comparisons below.
        composed = loader.load_configuration(config_name="config",
                                             overrides=overrides,
                                             run_mode=RunMode.RUN)
        with open_dict(composed):
            del composed["hydra"]
        return composed

    base_cfg = _load_without_hydra([])
    assert base_cfg == {"plugin": {"name": "???", "params": "???"}}
    assert OmegaConf.get_type(base_cfg.plugin) == Plugin

    cfg = _load_without_hydra(["+plugin=concrete"])
    assert cfg == {"plugin": {"name": "foobar_plugin", "params": {"foo": 10}}}
    assert OmegaConf.get_type(cfg.plugin) == ConcretePlugin
    assert OmegaConf.get_type(cfg.plugin.params) == ConcretePlugin.FoobarParams
    # The typed plugin node must reject a non-config value.
    with pytest.raises(ValidationError):
        cfg.plugin = 10
Exemplo n.º 4
0
    def configure_optimizers(self):
        """Build the optimizer chosen by ``self.optim_cfg`` and an exponential LR scheduler.

        An ``SGDConfig`` selects ``torch.optim.SGD`` with Nesterov momentum; an
        ``AdamConfig`` selects ``torch.optim.AdamW``.

        Returns:
            ``([optimizer], [scheduler])``; the scheduler decays the learning
            rate by ``optim_cfg.learning_anneal``.

        Raises:
            ValueError: if ``self.optim_cfg`` is neither an SGDConfig nor an AdamConfig.
        """
        optim_cfg = self.optim_cfg
        optim_type = OmegaConf.get_type(optim_cfg)

        if optim_type is SGDConfig:
            optimizer = torch.optim.SGD(
                params=self.parameters(),
                lr=optim_cfg.learning_rate,
                momentum=optim_cfg.momentum,
                nesterov=True,
                weight_decay=optim_cfg.weight_decay,
            )
        elif optim_type is AdamConfig:
            # AdamConfig maps to AdamW (decoupled weight decay), not plain Adam.
            optimizer = torch.optim.AdamW(
                params=self.parameters(),
                lr=optim_cfg.learning_rate,
                betas=optim_cfg.betas,
                eps=optim_cfg.eps,
                weight_decay=optim_cfg.weight_decay,
            )
        else:
            raise ValueError("Optimizer has not been specified correctly.")

        lr_schedule = torch.optim.lr_scheduler.ExponentialLR(
            optimizer=optimizer, gamma=optim_cfg.learning_anneal)
        return [optimizer], [lr_schedule]
Exemplo n.º 5
0
    def _validate_merge(self, value: Any) -> None:
        """Validate that ``value`` may be merged into this node.

        ``value`` is used as a config node here (its ``_is_missing``,
        ``_is_none`` and ``_metadata`` are accessed).

        Raises:
            ValidationError: if this node's type is a structured config and
                ``value`` (not None, not missing) has a non-dict type that is
                not a subclass of it. ``_validate_non_optional`` and
                ``_validate_set``, called first, may raise as well.
        """
        # Imported locally; presumably to avoid a circular import with the
        # top-level omegaconf package — TODO confirm.
        from omegaconf import OmegaConf

        dest = self
        src = value

        # key=None: the non-optional check applies to this node itself.
        self._validate_non_optional(None, src)

        dest_obj_type = OmegaConf.get_type(dest)
        src_obj_type = OmegaConf.get_type(src)

        # A missing destination is validated as if src's value were assigned
        # to it, when src has an object type recorded in its metadata.
        if dest._is_missing() and src._metadata.object_type is not None:
            self._validate_set(key=None, value=_get_value(src))

        # Stop here when the source is missing.
        if src._is_missing():
            return

        # Only a structured-config destination constrains the source type;
        # None sources and dict-typed sources are let through.
        validation_error = (dest_obj_type is not None
                            and src_obj_type is not None
                            and is_structured_config(dest_obj_type)
                            and not src._is_none()
                            and not is_dict(src_obj_type)
                            and not issubclass(src_obj_type, dest_obj_type))
        if validation_error:
            msg = (f"Merge error: {type_str(src_obj_type)} is not a "
                   f"subclass of {type_str(dest_obj_type)}. value: {src}")
            raise ValidationError(msg)
Exemplo n.º 6
0
 def test_merge(self, module: Any) -> None:
     """Merging a ConcretePlugin field over a Plugin field yields ConcretePlugin types."""
     base = OmegaConf.create({"plugin": module.Plugin})
     override = OmegaConf.create({"plugin": module.ConcretePlugin})
     assert override.plugin == module.ConcretePlugin

     merged: Any = OmegaConf.merge(base, override)
     expected_plugin = module.ConcretePlugin
     assert OmegaConf.get_type(merged.plugin) == expected_plugin
     assert OmegaConf.get_type(merged.plugin.params) == expected_plugin.FoobarParams
Exemplo n.º 7
0
    def test_promote_to_object(self, module: Any) -> None:
        """``_promote`` with a BoolConfig instance retypes the config and applies that instance's values."""
        cfg = OmegaConf.create(module.AnyTypeConfig)
        assert OmegaConf.get_type(cfg) == module.AnyTypeConfig

        bool_instance = module.BoolConfig(with_default=False)
        cfg._promote(bool_instance)

        assert OmegaConf.get_type(cfg) == module.BoolConfig
        assert cfg.with_default is False
Exemplo n.º 8
0
 def test_merge_with_subclass_into_missing(self, module: Any) -> None:
     """Merging a Plugin into a missing Plugin-typed field fills it in.

     The field's declared (ref) type must be preserved on the merged result,
     and the input config ``base`` must be left unchanged by the merge.
     """
     base = OmegaConf.structured(module.PluginHolder)
     assert _utils.get_ref_type(base, "missing") == module.Plugin
     assert OmegaConf.get_type(base, "missing") is None
     res = OmegaConf.merge(base, {"missing": module.Plugin})
     assert OmegaConf.get_type(res) == module.PluginHolder
     # Check both sides: merge must not alter ``base``, and the merged
     # result must keep the field's declared ref type.
     assert _utils.get_ref_type(base, "missing") == module.Plugin
     assert _utils.get_ref_type(res, "missing") == module.Plugin
     assert OmegaConf.get_type(res, "missing") == module.Plugin
Exemplo n.º 9
0
    def __init__(self, labels: List,
                 model_cfg: Union[UniDirectionalConfig, BiDirectionalConfig,
                                  ConvolutionConfig], precision: int,
                 optim_cfg: Union[AdamConfig,
                                  SGDConfig], spect_cfg: SpectConfig):
        """Build the DeepSpeech network and its loss, decoder and metric helpers.

        Args:
            labels: output symbols; must contain ``'_'``, used as the CTC blank.
            model_cfg: network config; a ``ConvolutionConfig`` builds a
                convolutional head (``_conv_construct``), any other type builds
                the RNN head (``_rnn_construct``). ``BiDirectionalConfig`` sets
                ``self.bidirectional``.
            precision: stored as ``self.precision``.
            optim_cfg: optimizer config, stored as ``self.optim_cfg``.
            spect_cfg: spectrogram config; ``sample_rate`` and ``window_size``
                determine the RNN input size.

        Raises:
            ValueError: if ``'_'`` is not in ``labels`` (from ``list.index``).
        """
        super().__init__()
        self.save_hyperparameters()
        self.model_cfg = model_cfg
        self.precision = precision
        self.optim_cfg = optim_cfg
        self.spect_cfg = spect_cfg
        model_type = OmegaConf.get_type(model_cfg)
        self.convolutional = model_type is ConvolutionConfig
        self.bidirectional = model_type is BiDirectionalConfig

        self.labels = labels

        self.conv = MaskConv(
            nn.Sequential(
                nn.Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2),
                          padding=(20, 5)),
                nn.BatchNorm2d(32),
                nn.Hardtanh(0, 20, inplace=True),
                nn.Conv2d(32, 32, kernel_size=(21, 11), stride=(2, 1),
                          padding=(10, 5)),
                nn.BatchNorm2d(32),
                nn.Hardtanh(0, 20, inplace=True),
            ))
        # Size of the feature axis entering the RNN. Start from
        # floor(sample_rate * window_size / 2) + 1 (presumably the STFT
        # frequency bins), then apply the conv output formula
        # floor((W - F + 2P) / S) + 1 along the kernels' first axis:
        # (F=41, P=20, S=2) then (F=21, P=10, S=2). Integer floor division
        # performs the floor the formula requires.
        rnn_input_size = int(
            math.floor((self.spect_cfg.sample_rate *
                        self.spect_cfg.window_size) / 2) + 1)
        rnn_input_size = (rnn_input_size + 2 * 20 - 41) // 2 + 1
        rnn_input_size = (rnn_input_size + 2 * 10 - 21) // 2 + 1
        rnn_input_size *= 32  # 32 output channels of the last conv layer

        if self.convolutional:
            self.deep_conv, self.fc = self._conv_construct(rnn_input_size)
        else:
            self.rnns, self.lookahead, self.fc = self._rnn_construct(
                rnn_input_size)

        self.inference_softmax = InferenceBatchSoftmax()
        # '_' in labels is the CTC blank symbol.
        self.criterion = CTCLoss(blank=self.labels.index('_'),
                                 reduction='sum',
                                 zero_infinity=True)
        # Decoder used for validation; it also serves as the target decoder
        # for the WER/CER metrics.
        self.evaluation_decoder = GreedyDecoder(self.labels)
        self.wer = WordErrorRate(decoder=self.evaluation_decoder,
                                 target_decoder=self.evaluation_decoder)
        self.cer = CharErrorRate(decoder=self.evaluation_decoder,
                                 target_decoder=self.evaluation_decoder)
Exemplo n.º 10
0
def my_app(cfg: Config) -> None:
    """Connect to the database described by ``cfg.db``.

    Remember that the actual type of Config and the ``db`` node inside it is
    DictConfig, so the underlying schema class is recovered with
    ``OmegaConf.get_type`` to choose a connector.

    Raises:
        ValueError: if ``cfg.db`` is neither a ``MySQLConfig`` nor a
            ``PostGreSQLConfig``. The message names the offending type.
    """
    db_type = OmegaConf.get_type(cfg.db)
    if db_type is MySQLConfig:
        connect_mysql(cast(MySQLConfig, cfg.db))
    elif db_type is PostGreSQLConfig:
        connect_postgresql(cast(PostGreSQLConfig, cfg.db))
    else:
        raise ValueError(f"Unsupported db config type: {db_type}")
Exemplo n.º 11
0
def test_get_type() -> None:
    """``OmegaConf.get_type`` reports User for a class, an instance and a nested field."""
    from_class = OmegaConf.structured(User)
    assert OmegaConf.get_type(from_class) == User

    from_instance = OmegaConf.structured(User(name="bond"))
    assert OmegaConf.get_type(from_instance) == User

    nested = OmegaConf.create({"user": User})
    assert OmegaConf.get_type(nested.user) == User
Exemplo n.º 12
0
    def test_promote_to_class(self, module: Any) -> None:
        """``_promote`` with the BoolConfig class retypes the config and exposes BoolConfig's fields."""
        promoted = OmegaConf.create(module.AnyTypeConfig)
        assert OmegaConf.get_type(promoted) == module.AnyTypeConfig

        promoted._promote(module.BoolConfig)

        assert OmegaConf.get_type(promoted) == module.BoolConfig
        assert promoted.with_default is True
        assert promoted.null_default is None
        assert OmegaConf.is_missing(promoted, "mandatory_missing")
Exemplo n.º 13
0
def hydra_main(cfg: ExampleEEGConfig):
    """Pick a hyperparameter grid for the configured model type and run ``main``.

    An INFO-level console handler is attached to the ``ml`` logger. The grid
    is chosen from the concrete type of ``cfg.train.model`` (CNN, CNN-RNN,
    RNN, or a fallback grid for anything else). The experiment directory is
    created under the absolute ``output`` path, and the ``mlruns`` directory
    is removed afterwards unless ``cfg.mlflow`` is set.
    """
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(logging.Formatter("[%(name)s] [%(levelname)s] %(message)s"))
    logging.getLogger("ml").addHandler(console_handler)

    model_type = OmegaConf.get_type(cfg.train.model)
    if model_type == CNNConfig:
        hyperparameters = {
            'transformer.transform': ['none'],
            'train.model.channel_list': [[4, 8, 16, 32]],
            'train.model.kernel_sizes': [[[4]] * 4],
            'train.model.stride_sizes': [[[2]] * 4],
            'train.model.padding_sizes': [[[1]] * 4],
            'train.model.optim.lr': [1e-4],
        }
    elif model_type == CNNRNNConfig:
        hyperparameters = {
            'train.model.optim.lr': [1e-3, 1e-4, 1e-5],
            'transformer.transform': ['none'],
            'train.model.channel_list': [[4, 8, 16, 32]],
            'train.model.kernel_sizes': [[[4]] * 4],
            'train.model.stride_sizes': [[[2]] * 4],
            'train.model.padding_sizes': [[[1]] * 4],
            # The RNN type is pinned to whatever the config already selects.
            'train.model.rnn_type': [cfg.train.model.rnn_type],
            'train.model.bidirectional': [True],
            'train.model.rnn_n_layers': [1, 2],
            'train.model.rnn_hidden_size': [10, 50],
        }
    elif model_type == RNNConfig:
        hyperparameters = {
            'train.model.bidirectional': [True, False],
            'train.model.rnn_type': ['lstm', 'gru'],
            'train.model.rnn_n_layers': [1, 2],
            'train.model.rnn_hidden_size': [10, 50],
            'transformer.transform': ['none'],
            'train.model.optim.lr': [1e-4],
        }
    else:
        hyperparameters = {
            'train.model.optim.lr': [1e-4],
            'batch_size': [16],
            'transformer.transform': ['logmel'],
            'loss_func': ['ce'],
            'epoch_rate': [1.0],
            'sample_balance': ['same'],
        }

    cfg.expt_id = f'{cfg.train.model_type.value}'
    output_root = Path(utils.to_absolute_path('output'))
    # NOTE(review): this EEG entry point writes under 'example_face', which
    # looks copied from the face example — confirm the intended directory.
    expt_dir = output_root / 'example_face' / f'{cfg.expt_id}'
    expt_dir.mkdir(parents=True, exist_ok=True)
    main(cfg, expt_dir, hyperparameters)

    # Without MLflow requested, remove the local 'mlruns' directory.
    if not cfg.mlflow:
        shutil.rmtree('mlruns')
Exemplo n.º 14
0
 def _set_optimizer(self):
     """Create the optimizer described by ``self.cfg.optim`` for ``self.model``.

     Returns:
         ``torch.optim.Adam`` for an ``AdamConfig``, or ``torch.optim.SGD``
         with Nesterov momentum for an ``SGDConfig``.

     Raises:
         ValueError: for any other optimizer config type. Raising here avoids
             silently returning ``None``, which would only fail later when the
             optimizer is first used.
     """
     optim_cfg = self.cfg.optim
     optim_type = OmegaConf.get_type(optim_cfg)
     if optim_type == AdamConfig:
         return torch.optim.Adam(self.model.parameters(),
                                 lr=optim_cfg.lr,
                                 weight_decay=optim_cfg.weight_decay)
     if optim_type == SGDConfig:
         return torch.optim.SGD(self.model.parameters(),
                                lr=optim_cfg.lr,
                                momentum=optim_cfg.momentum,
                                weight_decay=optim_cfg.weight_decay,
                                nesterov=True)
     raise ValueError(f"Unsupported optimizer config type: {optim_type}")
Exemplo n.º 15
0
def test_get_ref_type_with_conflict() -> None:
    """An interpolating node takes both type and ref type from its target.

    This holds even though "inter" is declared with a different ref type (Plugin).
    """
    interpolated = DictConfig(ref_type=Plugin, content="${user}")
    cfg = OmegaConf.create({"user": User, "inter": interpolated})

    # "user" is the direct node; "inter" must mirror it through interpolation.
    for key in ("user", "inter"):
        node = getattr(cfg, key)
        assert OmegaConf.get_type(node) == User
        assert _utils.get_ref_type(node) == Optional[User]
Exemplo n.º 16
0
    def test_get_type(self, class_type: str) -> None:
        """A structured LinkedList config reports its type; a missing head field reports LinkedList."""
        mod: Any = import_module(class_type)
        node_cls = mod.LinkedList

        head_cfg = OmegaConf.create(node_cls)
        assert OmegaConf.get_type(head_cfg) == node_cls
        assert head_cfg.next is None
        assert OmegaConf.is_missing(head_cfg, "value")

        holder = OmegaConf.create(mod.MissingTest.Missing1)
        assert OmegaConf.is_missing(holder, "head")
        assert OmegaConf.get_type(holder, "head") == mod.LinkedList
Exemplo n.º 17
0
        def test_plugin_merge(self, class_type: str) -> None:
            """Merging a Plugin subclass config over Plugin yields that subclass, equal to the override."""
            module: Any = import_module(class_type)
            base = OmegaConf.structured(module.Plugin)

            for subclass in (module.ConcretePlugin,
                             module.PluginWithAdditionalField):
                override = OmegaConf.structured(subclass)
                merged = OmegaConf.merge(base, override)
                assert merged == override
                assert OmegaConf.get_type(merged) == subclass
Exemplo n.º 18
0
    def test_get_type(self, module: Any) -> None:
        """Fields that are None or missing report no object type but keep their declared ref type."""
        list_cfg = OmegaConf.create(module.LinkedList)
        assert OmegaConf.get_type(list_cfg) == module.LinkedList
        assert _utils.get_ref_type(list_cfg, "next") == Optional[module.LinkedList]
        assert OmegaConf.get_type(list_cfg, "next") is None

        assert list_cfg.next is None
        assert OmegaConf.is_missing(list_cfg, "value")

        holder = OmegaConf.create(module.MissingTest.Missing1)
        assert OmegaConf.is_missing(holder, "head")
        assert _utils.get_ref_type(holder, "head") == module.LinkedList
        assert OmegaConf.get_type(holder, "head") is None
Exemplo n.º 19
0
def test_instantiated_regular_class_container_types_partial__recursive(
        instantiate_func: Any, config: Any) -> None:
    """With ConvertMode.PARTIAL, nested ``a.a`` becomes a dict while ``a.b`` stays a User-typed DictConfig."""
    result = instantiate_func(config, _convert_=ConvertMode.PARTIAL)
    inner = result.a
    assert isinstance(inner, SimpleClass)
    assert isinstance(inner.a, dict)
    assert isinstance(inner.b, DictConfig)
    assert OmegaConf.get_type(inner.b) is User
Exemplo n.º 20
0
    def _validate_set(self, key: Any, value: Any) -> None:
        """Validate assigning ``value`` to the child at ``key`` (or to this node when ``key`` is None).

        Interpolations are accepted without checks; ``"???"`` and ``None`` are
        accepted once the non-optional check passes. Beyond that, only a
        DictConfig target whose ref type is not ``Any``/``dict`` constrains
        ``value``'s type.

        Raises:
            ValidationError: if the target's ref type is a container annotation
                and ``value``'s type is not.
            Whatever ``_raise_invalid_value`` raises when ``value``'s type is not
            a subclass of the target's ref type; ``_validate_non_optional`` may
            also raise.
        """
        # Imported locally; presumably to avoid a circular import — TODO confirm.
        from omegaconf import OmegaConf

        vk = get_value_kind(value)
        # Interpolation strings are not type-checked here.
        if vk in (ValueKind.INTERPOLATION, ValueKind.STR_INTERPOLATION):
            return
        self._validate_non_optional(key, value)
        # Missing ("???") and None values need no type check.
        if value == "???" or value is None:
            return

        # The node being assigned to: the child at ``key``, or this node.
        target = self._get_node(key) if key is not None else self

        # Only a DictConfig with a specific ref type (not Any/dict) constrains
        # the value; no target, or an unconstrained one, is accepted as-is.
        target_has_ref_type = isinstance(
            target,
            DictConfig) and target._metadata.ref_type not in (Any, dict)
        is_valid_target = target is None or not target_has_ref_type

        if is_valid_target:
            return

        target_type = target._metadata.ref_type  # type: ignore
        value_type = OmegaConf.get_type(value)

        # A dict-typed value into a dict-typed target is allowed.
        if is_dict(value_type) and is_dict(target_type):
            return
        # A container-annotated target (presumably List[...]/Dict[...]) cannot
        # take a value whose type is not a container annotation.
        if is_container_annotation(
                target_type) and not is_container_annotation(value_type):
            raise ValidationError(
                f"Cannot assign {type_str(value_type)} to {type_str(target_type)}"
            )
        # Otherwise the value's type must subclass the target's ref type.
        validation_error = (target_type is not None and value_type is not None
                            and not issubclass(value_type, target_type))
        if validation_error:
            self._raise_invalid_value(value, value_type, target_type)
Exemplo n.º 21
0
    def _validate_merge(self, key: Any, value: Any) -> None:
        """Validate merging ``value`` into the child at ``key`` (or this node when ``key`` is None).

        Only a DictConfig target whose ref type is not ``Any``/``dict`` is
        type-checked, and generic-container ref types are accepted as-is.
        The check is looser than assignment: a dict-typed ``value`` is always
        let through.

        Raises:
            Whatever ``_raise_invalid_value`` raises when ``value``'s type is
            neither a subclass of the target's ref type nor a dict type;
            ``_validate_non_optional`` may also raise.
        """
        # Imported locally; presumably to avoid a circular import — TODO confirm.
        from omegaconf import OmegaConf

        self._validate_non_optional(key, value)

        # The node being merged into: the child at ``key``, or this node.
        target = self._get_node(key) if key is not None else self

        target_has_ref_type = isinstance(
            target, DictConfig
        ) and target._metadata.ref_type not in (Any, dict)
        # No target node, or one without a specific ref type: nothing to check.
        is_valid_value = target is None or not target_has_ref_type
        if is_valid_value:
            return

        target_type = target._metadata.ref_type  # type: ignore
        value_type = OmegaConf.get_type(value)
        # Generic container ref types (presumably e.g. Dict[str, X]) are not
        # checked here.
        if is_generic_container(target_type):
            return
        # Merging of a dictionary is allowed even if assignment is illegal (merge would do deeper checks)
        validation_error = (
            target_type is not None
            and value_type is not None
            and not issubclass(value_type, target_type)
            and not is_dict(value_type)
        )

        if validation_error:
            self._raise_invalid_value(value, value_type, target_type)
Exemplo n.º 22
0
def test_instantiated_regular_class_container_types_partial(
        instantiate_func: Any) -> None:
    """PARTIAL conversion turns ``{}`` into a dict but keeps a User value as a User-typed DictConfig."""
    target_path = "tests.instantiate.SimpleClass"
    cfg = {"_target_": target_path, "a": {}, "b": User()}
    obj = instantiate_func(cfg, _convert_=ConvertMode.PARTIAL)

    assert isinstance(obj.a, dict)
    assert isinstance(obj.b, DictConfig)
    assert OmegaConf.get_type(obj.b) is User
Exemplo n.º 23
0
 def test_plugin_merge_2(self, module: Any) -> None:
     """Merging a Plugin subclass with an extra field over Plugin yields that subclass."""
     extended_cls = module.PluginWithAdditionalField
     base = OmegaConf.structured(module.Plugin)
     extended = OmegaConf.structured(extended_cls)

     merged = OmegaConf.merge(base, extended)
     assert merged == extended
     assert OmegaConf.get_type(merged) == extended_cls
Exemplo n.º 24
0
 def test_merged_type1(self, class_type: str) -> None:
     """Merging a structured config into an empty config yields the structured type.

     The result takes the type of the last merged config.
     """
     mod: Any = import_module(class_type)
     schema = mod.WithDictField
     structured = OmegaConf.structured(schema)
     merged = OmegaConf.merge(OmegaConf.create(), structured)
     assert OmegaConf.get_type(merged) == schema
Exemplo n.º 25
0
 def test_merged_with_subclass(self, class_type: str) -> None:
     """Merging a subclass config over its base yields the subclass type.

     The result takes the type of the last merged config.
     """
     mod: Any = import_module(class_type)
     base_cfg = OmegaConf.structured(mod.Plugin)
     sub_cfg = OmegaConf.structured(mod.ConcretePlugin)
     merged = OmegaConf.merge(base_cfg, sub_cfg)
     assert OmegaConf.get_type(merged) == mod.ConcretePlugin
Exemplo n.º 26
0
 def test_merged_type2(self, class_type: str) -> None:
     """Merging a plain dict into a structured config keeps the structured config's type."""
     mod: Any = import_module(class_type)
     schema = mod.WithDictField
     base = OmegaConf.structured(schema)
     # The plain-dict overlay must not change the structured type of the result.
     merged = OmegaConf.merge(base, {"dict": {"foo": 99}})
     assert OmegaConf.get_type(merged) == schema
Exemplo n.º 27
0
def train(cfg: DeepSpeechConfig):
    """Train a DeepSpeech model as configured by ``cfg``.

    Steps:
    - Seed everything and load the label set from ``cfg.data.labels_path``.
    - When ``cfg.trainer.checkpoint_callback`` is set, create a checkpoint
      handler (GCS or file-based, chosen by the type of ``cfg.checkpoint``).
      With ``cfg.load_auto_checkpoint``, also resume from the latest
      checkpoint it finds.
    - Build the data module, model and trainer, then run ``trainer.fit``.
    """
    seed_everything(cfg.seed)

    labels_path = to_absolute_path(cfg.data.labels_path)
    # JSON text is UTF-8 by spec; don't depend on the platform default
    # encoding (e.g. cp1252 on Windows) when labels contain non-ASCII symbols.
    with open(labels_path, encoding="utf-8") as label_file:
        labels = json.load(label_file)

    # Trainer callbacks: None unless checkpointing is enabled.
    callbacks = None
    if cfg.trainer.checkpoint_callback:
        # The checkpoint backend follows the concrete type of cfg.checkpoint.
        if OmegaConf.get_type(cfg.checkpoint) is GCSCheckpointConfig:
            checkpoint_callback = GCSCheckpointHandler(cfg=cfg.checkpoint)
        else:
            checkpoint_callback = FileCheckpointHandler(cfg=cfg.checkpoint)
        callbacks = [checkpoint_callback]
        if cfg.load_auto_checkpoint:
            resume_from_checkpoint = checkpoint_callback.find_latest_checkpoint()
            if resume_from_checkpoint:
                cfg.trainer.resume_from_checkpoint = resume_from_checkpoint
    print(cfg.trainer.gpus)
    data_loader = DeepSpeechDataModule(labels=labels,
                                       data_cfg=cfg.data,
                                       normalize=True,
                                       is_distributed=cfg.trainer.gpus > 1)

    model = DeepSpeech(labels=labels,
                       model_cfg=cfg.model,
                       optim_cfg=cfg.optim,
                       precision=cfg.trainer.precision,
                       spect_cfg=cfg.data.spect)

    trainer = hydra.utils.instantiate(
        config=cfg.trainer,
        replace_sampler_ddp=False,
        callbacks=callbacks,
    )
    trainer.fit(model, data_loader)
Exemplo n.º 28
0
    def _validate_set(self, key: Any, value: Any) -> None:
        """Validate assigning ``value`` at index ``key`` of this ListConfig.

        Raises:
            ReadonlyConfigError: if the ``readonly`` flag is set.
            ValidationError: if ``value`` is None and the existing element at
                ``key`` is not optional, or if the list's element type is a
                structured config and ``value``'s type is not a subclass of it.
            ``_validate_get`` is called first and may raise as well.
        """
        # Imported locally; presumably to avoid a circular import — TODO confirm.
        from omegaconf import OmegaConf

        self._validate_get(key, value)

        if self._get_flag("readonly"):
            raise ReadonlyConfigError("ListConfig is read-only")

        # Only an in-range index has an existing element to check against
        # (key assumed to be an int here — TODO confirm _validate_get ensures it).
        if 0 <= key < self.__len__():
            target = self._get_node(key)
            if target is not None:
                assert isinstance(target, Node)
                if value is None and not target._is_optional():
                    # "$FULL_KEY" is presumably substituted with the full key
                    # when the error message is formatted.
                    raise ValidationError(
                        "$FULL_KEY is not optional and cannot be assigned None"
                    )

        # The type check applies only when the element type is a structured config.
        target_type = self._metadata.element_type
        value_type = OmegaConf.get_type(value)
        if is_structured_config(target_type):
            if (
                target_type is not None
                and value_type is not None
                and not issubclass(value_type, target_type)
            ):
                msg = (
                    f"Invalid type assigned : {type_str(value_type)} is not a "
                    f"subclass of {type_str(target_type)}. value: {value}"
                )
                raise ValidationError(msg)
Exemplo n.º 29
0
def hydra_main(cfg: ExampleFaceConfig):
    """Pick a hyperparameter grid for the configured model type and run ``main``.

    An INFO-level console handler is attached to the ``ml`` logger. The grid
    is chosen from the concrete type of ``cfg.train.model`` (CNN, CNN-RNN,
    RNN, or a fallback grid for anything else). Results go under
    ``<abs output>/example_face/<expt_id>``, and the ``mlruns`` directory is
    removed afterwards unless ``cfg.mlflow`` is set.
    """
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(
        logging.Formatter("[%(name)s] [%(levelname)s] %(message)s"))
    logging.getLogger("ml").addHandler(console_handler)

    model_type = OmegaConf.get_type(cfg.train.model)
    # NOTE(review): the CNN-RNN and RNN grids use bare keys ('rnn_type',
    # 'bidirectional', ...) and read ``cfg.rnn_type``, while the other grids
    # use 'train.model.*' / 'data.*' paths — confirm these keys exist in
    # ExampleFaceConfig.
    if model_type == CNNConfig:
        hyperparameters = {
            'train.model.optim.lr': [1e-4],
        }
    elif model_type == CNNRNNConfig:
        hyperparameters = {
            'train.model.optim.lr': [1e-3, 1e-4, 1e-5],
            'window_size': [0.5],
            'window_stride': [0.1],
            'transform': ['logmel'],
            'rnn_type': [cfg.rnn_type],
            'bidirectional': [True],
            'rnn_n_layers': [1],
            'rnn_hidden_size': [10],
        }
    elif model_type == RNNConfig:
        hyperparameters = {
            'bidirectional': [True, False],
            'rnn_type': ['lstm', 'gru'],
            'rnn_n_layers': [1, 2],
            'rnn_hidden_size': [10, 50],
            'transform': [None],
            'train.model.optim.lr': [1e-3, 1e-4, 1e-5],
        }
    else:
        hyperparameters = {
            'train.model.optim.lr': [1e-4, 1e-5],
            'data.batch_size': [64],
            'data.epoch_rate': [1.0],
            'data.sample_balance': ['same'],
        }

    cfg.expt_id = f'{cfg.train.model_type.value}_pretrain-{cfg.train.model.pretrained}'
    output_root = Path(utils.to_absolute_path('output'))
    expt_dir = output_root / 'example_face' / f'{cfg.expt_id}'
    expt_dir.mkdir(parents=True, exist_ok=True)
    main(cfg, expt_dir, hyperparameters)

    # Without MLflow requested, remove the local 'mlruns' directory.
    if not cfg.mlflow:
        shutil.rmtree('mlruns')
Exemplo n.º 30
0
        def test_missing1(self, module: Any) -> None:
            """A missing structured field reports no object type and rejects an int assignment."""
            missing_cfg = OmegaConf.create(module.MissingTest.Missing1)
            assert OmegaConf.is_missing(missing_cfg, "head")
            assert OmegaConf.get_type(missing_cfg, "head") is None

            with raises(ValidationError):
                missing_cfg.head = 10