コード例 #1
0
ファイル: compose.py プロジェクト: ryantm/hydra
def compose(
    config_file: Optional[str] = None,
    overrides: Optional[List[str]] = None,
    strict: Optional[bool] = None,
) -> DictConfig:
    """
    Compose a config through the globally initialized Hydra instance.

    :param config_file: optional config file to load
    :param overrides: list of overrides for config file; ``None`` (the
        default) is treated as "no overrides"
    :param strict: optionally override the default strict mode
    :return: the composed config, with the top-level ``hydra`` node removed
    :raises AssertionError: if GlobalHydra is not initialized
    """
    # Build a fresh list per call. A ``[]`` default is created once at
    # definition time and shared by every call, so any downstream mutation
    # would leak overrides from one compose() call into the next.
    if overrides is None:
        overrides = []

    assert (
        GlobalHydra().is_initialized()
    ), "GlobalHydra is not initialized, use @hydra.main() or call hydra.experimental.initialize() first"

    gh = GlobalHydra.instance()
    assert gh.hydra is not None
    cfg = gh.hydra.compose_config(
        config_file=config_file, overrides=overrides, strict=strict
    )
    assert isinstance(cfg, DictConfig)

    # The "hydra" node is dropped from the config handed back to the caller.
    if "hydra" in cfg:
        del cfg["hydra"]
    return cfg
コード例 #2
0
def test_initialize() -> None:
    """Verify that hydra.experimental.initialize() marks GlobalHydra as initialized."""
    try:
        initialized_before = GlobalHydra().is_initialized()
        assert not initialized_before

        hydra.experimental.initialize(config_dir=None, strict=True)

        initialized_after = GlobalHydra().is_initialized()
        assert initialized_after
    finally:
        # Always reset the global state, whether or not the asserts passed.
        GlobalHydra().clear()
コード例 #3
0
ファイル: test_compose.py プロジェクト: aixioma/hydra
def test_initialize_hydra():
    """Verify that hydra.experimental.initialize_hydra() marks GlobalHydra as initialized."""
    try:
        initialized_before = GlobalHydra().is_initialized()
        assert not initialized_before

        hydra.experimental.initialize_hydra(
            task_name="task", search_path_dir=None, strict=True
        )

        initialized_after = GlobalHydra().is_initialized()
        assert initialized_after
    finally:
        # Always reset the global state, whether or not the asserts passed.
        GlobalHydra().clear()
コード例 #4
0
def test_config_installed(hydra_global_context):  # noqa: F811
    """
    Tests that color options are available for both hydra/hydra_logging and hydra/job_logging
    """
    colorlog_conf_dir = "../hydra_plugins/hydra_colorlog/conf"

    with hydra_global_context(config_dir=colorlog_conf_dir):
        config_loader = GlobalHydra().hydra.config_loader
        # Same check for each logging group, in the original order.
        for group in ("hydra/job_logging", "hydra/hydra_logging"):
            assert "colorlog" in config_loader.get_group_options(group)
コード例 #5
0
def get_hydra():
    """Return the hydra held by GlobalHydra, or create one when it is not initialized."""
    gh = GlobalHydra()

    # Already initialized: hand back the hydra object it holds.
    if gh.is_initialized():
        return gh.hydra

    # Otherwise create one relative to this file, using the "configs" directory.
    return Hydra.create_main_hydra_file_or_module(
        calling_file=__file__,
        calling_module=None,
        config_dir="configs",
        strict=False,
    )
コード例 #6
0
def test_initialize_with_config_dir() -> None:
    """Verify that initialize() with a config_dir yields a "main" search path entry."""
    try:
        initialized_before = GlobalHydra().is_initialized()
        assert not initialized_before

        hydra.experimental.initialize(
            config_dir="../hydra/test_utils/configs", strict=True
        )

        initialized_after = GlobalHydra().is_initialized()
        assert initialized_after

        loader = GlobalHydra.instance().hydra.config_loader
        config_search_path = loader.config_search_path
        # Match on the provider only; search_path is left as None in the query.
        query = SearchPath(provider="main", search_path=None)
        idx = config_search_path.find_first_match(query)
        assert idx != -1
    finally:
        # Always reset the global state, whether or not the asserts passed.
        GlobalHydra().clear()
コード例 #7
0
def compose(config_file=None, overrides=None, strict=None):
    """
    Compose a config through the globally initialized Hydra instance.

    :param config_file: optional config file to load
    :param overrides: list of overrides for config file; ``None`` (the
        default) is treated as "no overrides"
    :param strict: optionally override the default strict mode
    :return: the composed config, with the top-level ``hydra`` node removed
    :raises AssertionError: if GlobalHydra is not initialized
    """
    # Build a fresh list per call. A ``[]`` default is created once at
    # definition time and shared by every call, so any downstream mutation
    # would leak overrides from one compose() call into the next.
    if overrides is None:
        overrides = []

    assert (
        GlobalHydra().is_initialized()
    ), "GlobalHydra is not initialized, use @hydra.main() or call hydra.experimental.initialize() first"

    cfg = GlobalHydra().hydra.compose_config(
        config_file=config_file, overrides=overrides, strict=strict
    )

    # The "hydra" node is dropped from the config handed back to the caller.
    if "hydra" in cfg:
        del cfg["hydra"]
    return cfg
コード例 #8
0
def test_config_installed(
        hydra_global_context: TGlobalHydraContext,  # noqa: F811
) -> None:
    """
    Tests that color options are available for both hydra/hydra_logging and hydra/job_logging
    """
    colorlog_conf_dir = "../hydra_plugins/hydra_colorlog/conf"

    with hydra_global_context(config_dir=colorlog_conf_dir):
        gh = GlobalHydra.instance()
        assert gh.hydra is not None
        config_loader = gh.hydra.config_loader
        # Same check for each logging group, in the original order.
        for group in ("hydra/job_logging", "hydra/hydra_logging"):
            assert "colorlog" in config_loader.get_group_options(group)
コード例 #9
0
        def __enter__(self):
            """Create a Hydra, run the task once in a fresh temp dir, and return self.

            The job result is stored on ``self.job_ret`` with the
            ``hydra.run.dir`` node stripped from its config. GlobalHydra is
            cleared on the way out, whether or not the run succeeded.
            """
            try:
                config_dir, config_file = split_config_path(self.config_path)
                self.hydra = Hydra.create_main_hydra_file_or_module(
                    calling_file=self.calling_file,
                    calling_module=self.calling_module,
                    config_dir=config_dir,
                    strict=self.strict,
                )

                self.temp_dir = tempfile.mkdtemp()
                # Work on a copy so the caller-supplied overrides are not mutated.
                overrides = copy.deepcopy(self.overrides)
                run_dir_override = "hydra.run.dir={}".format(self.temp_dir)
                overrides.append(run_dir_override)

                self.job_ret = self.hydra.run(
                    config_file=config_file,
                    task_function=self,
                    overrides=overrides,
                )
                strip_node(self.job_ret.cfg, "hydra.run.dir")
                return self
            finally:
                GlobalHydra().clear()
コード例 #10
0
    def __enter__(self) -> "SweepTaskFunction":
        """Run a multirun sweep in a fresh temp dir and return self.

        The sweep results are stored on ``self.returns``. GlobalHydra is
        cleared after the sweep attempt, whether or not it succeeded.
        """
        self.temp_dir = tempfile.mkdtemp()
        # Work on a copy so the caller-supplied overrides are not mutated.
        overrides = copy.deepcopy(self.overrides)
        assert overrides is not None
        sweep_dir_override = "hydra.sweep.dir={}".format(self.temp_dir)
        overrides.append(sweep_dir_override)

        try:
            config_dir, config_file = split_config_path(self.config_path)
            sweeper_hydra = Hydra.create_main_hydra_file_or_module(
                calling_file=self.calling_file,
                calling_module=self.calling_module,
                config_dir=config_dir,
                strict=self.strict,
            )
            self.returns = sweeper_hydra.multirun(
                config_file=config_file,
                task_function=self,
                overrides=overrides,
            )
        finally:
            GlobalHydra().clear()

        return self
コード例 #11
0
        def __enter__(self):
            """Run a multirun sweep in a fresh temp dir and return self.

            The sweep results are stored on ``self.returns``; the
            ``hydra.sweep.dir`` node is stripped from every returned job
            config. GlobalHydra is cleared after the sweep attempt, whether
            or not it succeeded.
            """
            self.temp_dir = tempfile.mkdtemp()
            # Work on a copy so the caller-supplied overrides are not mutated.
            overrides = copy.deepcopy(self.overrides)
            sweep_dir_override = "hydra.sweep.dir={}".format(self.temp_dir)
            overrides.append(sweep_dir_override)

            try:
                config_dir, config_file = split_config_path(self.config_path)
                sweeper_hydra = Hydra.create_main_hydra_file_or_module(
                    calling_file=self.calling_file,
                    calling_module=self.calling_module,
                    config_dir=config_dir,
                    strict=self.strict,
                )
                self.returns = sweeper_hydra.multirun(
                    config_file=config_file,
                    task_function=self,
                    overrides=overrides,
                )

                # Flatten the nested results first, then strip each job config.
                flat = []
                for sublist in self.returns:
                    flat.extend(sublist)
                for ret in flat:
                    strip_node(ret.cfg, "hydra.sweep.dir")
            finally:
                GlobalHydra().clear()

            return self
コード例 #12
0
from hydra._internal.hydra import GlobalHydra
from hydra.experimental import compose as hydra_compose
from hydra.experimental import initialize as hydra_init
from PIL import Image

from meta_blocks.experiment.eval import evaluate
from meta_blocks.experiment.train import train

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Values used to parametrize the integration test below: every adaptation
# method is combined with every experiment setting.
AVAILABLE_METHODS = ("maml", "fomaml", "reptile", "proto")
AVAILABLE_SETTINGS = ("classic_supervised", "self_supervised")

# Initialize hydra.
# TODO: is there a way to check if hydra is initialized using public API?
# Runs at import time; the guard skips hydra_init() when GlobalHydra already
# reports itself as initialized.
if not GlobalHydra().is_initialized():
    hydra_init(config_dir="conf", strict=False)


@pytest.mark.parametrize("adaptation_method", AVAILABLE_METHODS)
@pytest.mark.parametrize("experiment_setting", AVAILABLE_SETTINGS)
def test_omniglot_integration(adaptation_method, experiment_setting):
    def generate_dummy_miniimagenet_data(dir_path):
        """Generates dummy data that imitates mini-ImageNet structure.

        Mini-ImageNet is too heavy for integration testing, so we generate
        synthetic data (dummy images) that satisfy the mini-ImageNet spec.
        """
        num_dummy_categories = 20
        num_dummy_img_per_category = 10
        img_height, img_width = 84, 84
コード例 #13
0
 def __exit__(self, exc_type, exc_val, exc_tb):
     """Clear GlobalHydra when the context exits; the exception details are ignored."""
     global_hydra = GlobalHydra()
     global_hydra.clear()
コード例 #14
0
 def __exit__(self, exc_type, exc_val, exc_tb) -> None:  # type: ignore
     """Clear GlobalHydra when the context exits; the exception details are ignored."""
     global_hydra = GlobalHydra()
     global_hydra.clear()