예제 #1
0
def _assert_user_val_matches(a, b, k, opt):
    """Assert that a parsed option value ``a`` matches the expected value ``b``.

    Identical objects always match. If ``b`` is a Keras regularizer or
    initializer, ``a`` only has to be an instance of the same Keras base class
    (see TODOs). Any other value must compare equal.

    Args:
        a: value taken from the parsed config.
        b: value taken from the expected config.
        k: name of the layer being checked (used in failure messages).
        opt: name of the option being checked (used in failure messages).
    """
    if a is b:
        return
    if isinstance(b, tf.keras.regularizers.Regularizer):
        # TODO: implement more in depth check
        assert isinstance(a, tf.keras.regularizers.Regularizer)
    elif isinstance(b, tf.keras.initializers.Initializer):
        # TODO: implement more in depth check
        assert isinstance(a, tf.keras.initializers.Initializer)
    else:
        assert a == b, f"{opt} in layer {k} does not match: {a} != {b}"


def test_default(config, expected):
    """test parsing of model

    ``expected`` is one of the following:

    * the expected parsed model dict, or
    * a ``ValueError`` / ``TypeError`` instance, meaning parsing must raise
      that exception type.

    If the parsed dict is not exactly equal to ``expected``, the layers are
    compared field by field. In that comparison, Keras regularizer and
    initializer objects inside ``user_vals`` only need to match by base class.
    """
    if isinstance(expected, dict):
        formatted_config = parse_default(config["model"],
                                         DEFAULT_CONFIG["model"])
        if expected == formatted_config:
            return

        # NOTE(review): only the "layers" section is re-checked below; keys
        # outside "layers" are not compared in this fallback path.
        got_layers = formatted_config["layers"]
        exp_layers = expected["layers"]
        # Compare layer names up front. Iterating one side only would raise
        # KeyError on an extra parsed layer and silently skip a missing one.
        assert set(got_layers) == set(exp_layers), (
            f"parsed layers {set(got_layers)} != "
            f"expected layers {set(exp_layers)}")

        for k, d in got_layers.items():
            exp_layer = exp_layers[k]
            for opt in ("user_vals",):
                got_vals = d["options"][opt]
                exp_vals = exp_layer["options"][opt]
                if got_vals is exp_vals:
                    continue
                # Check lengths explicitly. An element-wise comparison alone
                # would ignore trailing expected values or raise IndexError.
                assert len(got_vals) == len(exp_vals), (
                    f"layer {k} does not have matching {opt}: "
                    f"{len(got_vals)} values != {len(exp_vals)} values")
                for a, b in zip(got_vals, exp_vals):
                    _assert_user_val_matches(a, b, k, opt)
            for opt in ("func", "func_args", "func_defaults"):
                assert (d["layer_base"][opt] == exp_layer["layer_base"][opt]
                        ), f"layer {k} does not have matching {opt}"
            for opt in ("layer_in_name",):
                assert (d[opt] == exp_layer[opt]
                        ), f"layer {k} does not have matching {opt}"

    elif isinstance(expected, ValueError):
        with pytest.raises(ValueError):
            parse_default(config["model"], DEFAULT_CONFIG["model"])
    elif isinstance(expected, TypeError):
        with pytest.raises(TypeError):
            parse_default(config["model"], DEFAULT_CONFIG["model"])
예제 #2
0
def test_default(config, expected):
    """Check that the "data" section of ``config`` parses as expected.

    If ``expected`` is a dict, the parsed section must equal it. If it is a
    ``ValueError`` or ``TypeError`` instance, parsing must raise that
    exception type.
    """

    def _parse_data():
        # Parse the data section against its defaults.
        return parse_default(config["data"], DEFAULT_CONFIG["data"])

    if isinstance(expected, dict):
        assert expected == _parse_data()
    elif isinstance(expected, ValueError):
        with pytest.raises(ValueError):
            _parse_data()
    elif isinstance(expected, TypeError):
        with pytest.raises(TypeError):
            _parse_data()
예제 #3
0
def _primary_config(main_path: str) -> dict:
    """Build the parsed config for every section listed in ``CONFIG_KEYS``.

    The function performs these steps:

    * Reads the raw main config via ``get_raw_dict_from_string(main_path)``.
    * Requires that every ``CONFIG_KEYS`` entry is present.
    * Parses each section against the matching ``DEFAULT_CONFIG`` entry and
      attaches a ``model_hash`` to the "model" section.
    * Configures the "config" logger and reports any unused keys.
    * Prepares the experiment directories on disk.

    Args:
        main_path: passed to ``get_raw_dict_from_string`` to obtain the raw
            config mapping.

    Returns:
        dict mapping each ``CONFIG_KEYS`` entry to its parsed section.

    Raises:
        ValueError: if the raw config lacks any ``CONFIG_KEYS`` entry.

    Side effects:
        If ``meta["start_fresh"]`` is truthy, the experiment root directory
        (``yeahml_dir/data_name/experiment_name``) is removed with
        ``shutil.rmtree`` and then recreated. The model directory is passed
        to ``_create_exp_dir`` with ``wipe_dirs`` taken from
        ``model["start_fresh"]`` (False when absent).
    """
    main_config_raw = get_raw_dict_from_string(main_path)
    present = main_config_raw.keys()

    # Not every section strictly *needs* to be present, but for now all of
    # CONFIG_KEYS is enforced.
    missing = [section for section in CONFIG_KEYS if section not in present]
    if missing:
        raise ValueError(
            f"The main config does not contain the key(s) {missing}:"
            f" current keys: {present}")

    config_dict = {}
    for section in CONFIG_KEYS:
        # _maybe_extract_from_path presumably resolves a section that is
        # given as a path; verify against its definition.
        parsed = parse_default(
            _maybe_extract_from_path(main_config_raw[section]),
            DEFAULT_CONFIG[f"{section}"])
        if section == "model":
            parsed["model_hash"] = make_hash(parsed, IGNORE_HASH_KEYS)
        config_dict[section] = parsed

    def _exp_root():
        # TODO: this should probably be made once and stored? in the :meta?
        meta = config_dict["meta"]
        return (Path(meta["yeahml_dir"]) / meta["data_name"] /
                meta["experiment_name"])

    def _start_fresh(section):
        # Missing flag -> leave existing model information in place.
        try:
            return config_dict[section]["start_fresh"]
        except KeyError:
            return False

    full_exp_path = _exp_root() / config_dict["model"]["name"]
    # NOTE(review): the logger is set up under the model directory before the
    # directory handling below, which may delete that directory (rmtree of
    # the experiment root / wipe_dirs). If config_logger opens files there,
    # confirm this ordering is intended.
    logger = config_logger(full_exp_path, config_dict["logging"], "config")

    unused_keys = check_for_unused_keys(config_dict, main_config_raw, [], [])
    if unused_keys:
        _maybe_message(unused_keys, main_config_raw, logger)

    exp_root_dir = _exp_root()
    override_yml_dir = _start_fresh("meta")
    if os.path.exists(exp_root_dir) and override_yml_dir:
        shutil.rmtree(exp_root_dir)
    if not os.path.exists(exp_root_dir):
        exp_root_dir.mkdir(parents=True, exist_ok=True)

    model_root_dir = exp_root_dir / config_dict["model"]["name"]
    _create_exp_dir(model_root_dir, wipe_dirs=_start_fresh("model"))

    return config_dict