Example #1
0
def test_resolve_str_interpolation(query: str, result: Any) -> None:
    """Resolve a StringNode holding ``query`` and compare it with ``result``.

    The config exposes ``foo`` (the integer 10) and ``bar`` (an interpolation
    of ``foo``), so ``query`` may reference either key. ``query``/``result``
    are presumably supplied by a parametrize decorator outside this view.
    """
    cfg = OmegaConf.create({"foo": 10, "bar": "${foo}"})
    node = StringNode(value=query)
    # Missing values are tolerated, but a failed resolution must raise.
    resolved = cfg._resolve_interpolation(
        key=None,
        value=node,
        throw_on_missing=False,
        throw_on_resolution_failure=True,
    )
    assert resolved == result
Example #2
0
    read_write,
)
from omegaconf.errors import ConfigAttributeError, ConfigKeyError

from . import Color, StructuredWithMissing, User, does_not_raise


@pytest.mark.parametrize(
    "input_, key, value, expected",
    [
        # dict
        ({}, "foo", 10, {"foo": 10}),
        ({}, "foo", IntegerNode(10), {"foo": 10}),
        ({"foo": 5}, "foo", IntegerNode(10), {"foo": 10}),
        # replacing an existing node with a node of a different type
        ({"foo": StringNode("str")}, "foo", IntegerNode(10), {"foo": 10}),
        # list
        ([0], 0, 10, [10]),
        (["a", "b", "c"], 1, 10, ["a", 10, "c"]),
        ([1, 2], 1, IntegerNode(10), [1, 10]),
        ([1, IntegerNode(2)], 1, IntegerNode(10), [1, 10]),
        # replacing an existing node with a node of a different type
        ([1, StringNode("str")], 1, IntegerNode(10), [1, 10]),
    ],
)
def test_set_value(
    input_: Any, key: Union[str, int], value: Any, expected: Any
) -> None:
    """Assign ``value`` at ``key`` in a config built from ``input_``.

    The resulting config must compare equal to the plain ``expected`` value.
    """
    cfg = OmegaConf.create(input_)
    cfg[key] = value
    assert cfg == expected
Example #3
0
    def test_interpolation(self, node_type: Any, values: Any,
                           restore_resolvers: Any) -> None:
        """Check flags and resolved values across plain keys and interpolations.

        For every sample in ``values``, build a config that mixes constants,
        ``node_type`` nodes (required/optional, set, ``None`` and ``???``),
        resolver calls and interpolations pointing at all of them. Then run
        ``verify`` on each key with the expected none/opt/missing/inter flags
        and, where applicable, the expected resolved value.
        """
        resolver_output = 9999
        OmegaConf.register_resolver("func", lambda: resolver_output)
        samples = copy.deepcopy(values)
        for sample in samples:
            nested = {
                "reg": node_type(value=sample, is_optional=False),
                "opt": node_type(value=sample, is_optional=True),
            }
            cfg = OmegaConf.create({
                "const": 10,
                "primitive_missing": "???",
                "resolver": StringNode(value="${func:}", is_optional=False),
                "opt_resolver": StringNode(value="${func:}", is_optional=True),
                "node": DictConfig(content=nested, is_optional=False),
                "opt_node": DictConfig(content=nested, is_optional=True),
                "reg": node_type(value=sample, is_optional=False),
                "opt": node_type(value=sample, is_optional=True),
                "opt_none": node_type(value=None, is_optional=True),
                "missing": node_type(value="???", is_optional=False),
                "opt_missing": node_type(value="???", is_optional=True),
                # Interpolations
                "int_reg": "${reg}",
                "int_opt": "${opt}",
                "int_opt_none": "${opt_none}",
                "int_missing": "${missing}",
                "int_opt_missing": "${opt_missing}",
                "str_int_const": StringNode(value="foo_${const}",
                                            is_optional=False),
                "opt_str_int_const": StringNode(value="foo_${const}",
                                                is_optional=True),
                "str_int_with_primitive_missing": StringNode(
                    value="foo_${primitive_missing}", is_optional=False),
                "opt_str_int_with_primitive_missing": StringNode(
                    value="foo_${primitive_missing}", is_optional=True),
                "int_node": "${node}",
                "int_opt_node": "${opt_node}",
                "int_resolver": "${resolver}",
                "int_opt_resolver": "${opt_resolver}",
            })

            # Each row: (key, none, opt, missing, inter, extra kwargs).
            # An empty ``extra`` means no expected value is passed to verify.
            checks = [
                ("const", False, True, False, False, {"exp": 10}),
                # Note, resolvers are always optional because the underlying
                # function may return None.
                ("resolver", False, True, False, True,
                 {"exp": resolver_output}),
                ("opt_resolver", False, True, False, True,
                 {"exp": resolver_output}),
                ("reg", False, False, False, False, {"exp": sample}),
                ("opt", False, True, False, False, {"exp": sample}),
                ("opt_none", True, True, False, False, {"exp": None}),
                ("missing", False, False, True, False, {}),
                ("opt_missing", False, True, True, False, {}),
                ("int_reg", False, False, False, True, {"exp": sample}),
                ("int_opt", False, True, False, True, {"exp": sample}),
                ("int_opt_none", True, True, False, True, {"exp": None}),
                ("int_missing", False, False, True, True, {}),
                ("int_opt_missing", False, True, True, True, {}),
                ("str_int_const", False, False, False, True,
                 {"exp": "foo_10"}),
                ("opt_str_int_const", False, True, False, True,
                 {"exp": "foo_10"}),
                ("int_node", False, False, False, True, {"exp": nested}),
                ("int_opt_node", False, True, False, True, {"exp": nested}),
                ("int_resolver", False, True, False, True,
                 {"exp": resolver_output}),
                ("int_opt_resolver", False, True, False, True,
                 {"exp": resolver_output}),
                ("str_int_with_primitive_missing", False, False, False, True,
                 {}),
                ("opt_str_int_with_primitive_missing", False, True, False,
                 True, {}),
            ]
            for key, none, opt, missing, inter, extra in checks:
                verify(cfg,
                       key,
                       none=none,
                       opt=opt,
                       missing=missing,
                       inter=inter,
                       **extra)
Example #4
0
    c = OmegaConf.create([])
    assert isinstance(c, ListConfig)
    c.append(IntegerNode(10))
    assert c.get(0) == 10
    with pytest.raises(ValidationError):
        c[0] = "string"
    assert c[0] == 10
    assert type(c._get_node(0)) == IntegerNode


# Test merge raises validation error
@pytest.mark.parametrize(
    "c1, c2",
    [
        ({"a": IntegerNode(10)}, {"a": "str"}),
        ({"a": IntegerNode(10)}, {"a": StringNode("str")}),
        ({"a": 10, "b": IntegerNode(10)}, {"a": 20, "b": "str"}),
        ({"foo": {"bar": IntegerNode(10)}}, {"foo": {"bar": "str"}}),
    ],
)
def test_merge_validation_error(c1: Dict[str, Any], c2: Dict[str, Any]) -> None:
    """Merging a value that fails type validation raises ValidationError.

    Both operands must compare equal to freshly created copies afterwards.
    """
    left = OmegaConf.create(c1)
    right = OmegaConf.create(c2)
    with pytest.raises(ValidationError):
        OmegaConf.merge(left, right)
    # A failed merge must not leave either operand modified.
    assert left == OmegaConf.create(c1)
    assert right == OmegaConf.create(c2)


@pytest.mark.parametrize(
Example #5
0
    OmegaConfUserResolver,
)
from tests import Color


@fixture
def resolver() -> Any:
    """Yield a newly constructed ``OmegaConfUserResolver``."""
    instance = OmegaConfUserResolver()
    yield instance


@mark.parametrize(
    ("obj", "expected"),
    [
        # nodes
        param(AnyNode(10), {}, id="any:10"),
        param(StringNode("foo"), {}, id="str:foo"),
        param(IntegerNode(10), {}, id="int:10"),
        param(FloatNode(3.14), {}, id="float:3.14"),
        param(BooleanNode(True), {}, id="bool:True"),
        param(BytesNode(b"binary"), {}, id="bytes:binary"),
        param(EnumNode(enum_type=Color, value=Color.RED), {},
              id="Color:Color.RED"),
        # nodes are never returning a dictionary
        param(AnyNode("${foo}", parent=DictConfig({"foo": 10})), {},
              id="any:inter_10"),
        # DictConfig
        param(DictConfig({"a": 10}), {"a": AnyNode(10)}, id="dict"),
        param(
            DictConfig({
                "a": 10,
                "b": "${a}"
Example #6
0
         "bar": MISSING
     },
     "foo",
     True,
     raises(MissingMandatoryValue),
 ),
 (
     {
         "foo": "${unknown_resolver:foo}"
     },
     "foo",
     False,
     raises(UnsupportedInterpolationType),
 ),
 ({
     "foo": StringNode(value="???")
 }, "foo", True, raises(MissingMandatoryValue)),
 (
     {
         "foo": StringNode(value="???"),
         "inter": "${foo}"
     },
     "inter",
     True,
     raises(MissingMandatoryValue),
 ),
 (StructuredWithMissing, "num", True, raises(MissingMandatoryValue)),
 (StructuredWithMissing, "opt_num", True,
  raises(MissingMandatoryValue)),
 (StructuredWithMissing, "dict", True, raises(MissingMandatoryValue)),
 (StructuredWithMissing, "opt_dict", True,
Example #7
0
     id="missing_dict",
 ),
 ({"foo": "${bar}", "bar": MISSING}, "foo", True, raises(MissingMandatoryValue)),
 (
     {"foo": "foo_${bar}", "bar": MISSING},
     "foo",
     False,
     raises(MissingMandatoryValue),
 ),
 (
     {"foo": "${unknown_resolver:foo}"},
     "foo",
     False,
     raises(UnsupportedInterpolationType),
 ),
 ({"foo": StringNode(value="???")}, "foo", True, raises(MissingMandatoryValue)),
 (
     {"foo": StringNode(value="???"), "inter": "${foo}"},
     "inter",
     True,
     raises(MissingMandatoryValue),
 ),
 (StructuredWithMissing, "num", True, raises(MissingMandatoryValue)),
 (StructuredWithMissing, "opt_num", True, raises(MissingMandatoryValue)),
 (StructuredWithMissing, "dict", True, raises(MissingMandatoryValue)),
 (StructuredWithMissing, "opt_dict", True, raises(MissingMandatoryValue)),
 (StructuredWithMissing, "list", True, raises(MissingMandatoryValue)),
 (StructuredWithMissing, "opt_list", True, raises(MissingMandatoryValue)),
 (StructuredWithMissing, "user", True, raises(MissingMandatoryValue)),
 (StructuredWithMissing, "opt_user", True, raises(MissingMandatoryValue)),
 (StructuredWithMissing, "inter_user", True, raises(MissingMandatoryValue)),
Example #8
0
    c = OmegaConf.create([])
    assert isinstance(c, ListConfig)
    c.append(IntegerNode(10))
    assert c.get(0) == 10
    with raises(ValidationError):
        c[0] = "string"
    assert c[0] == 10
    assert type(c._get_node(0)) == IntegerNode


# Test merge raises validation error
@mark.parametrize(
    "c1, c2",
    [
        ({"a": IntegerNode(10)}, {"a": "str"}),
        ({"a": IntegerNode(10)}, {"a": StringNode("str")}),
        ({"a": 10, "b": IntegerNode(10)}, {"a": 20, "b": "str"}),
        ({"foo": {"bar": IntegerNode(10)}}, {"foo": {"bar": "str"}}),
    ],
)
def test_merge_validation_error(
    c1: Dict[str, Any], c2: Dict[str, Any]
) -> None:
    """Merging a value that fails type validation raises ValidationError.

    Both operands must compare equal to freshly created copies afterwards.
    """
    left = OmegaConf.create(c1)
    right = OmegaConf.create(c2)
    with raises(ValidationError):
        OmegaConf.merge(left, right)
    # A failed merge must not leave either operand modified.
    assert left == OmegaConf.create(c1)
    assert right == OmegaConf.create(c2)

Example #9
0

def test_list_integer_rejects_string() -> None:
    """Assigning a string over an IntegerNode list element raises ValidationError.

    After the failed assignment, the element keeps its value and node type.
    """
    c = OmegaConf.create([])
    c.append(IntegerNode(10))
    assert c.get(0) == 10
    with pytest.raises(ValidationError):
        c[0] = "string"
    # The rejected write must not have replaced the value or the node.
    assert c[0] == 10
    # Exact-type check: compare class identity rather than using == (E721).
    assert type(c.get_node(0)) is IntegerNode


# Test merge raises validation error
@pytest.mark.parametrize(
    "c1, c2",
    [
        ({"a": IntegerNode(10)}, {"a": "str"}),
        ({"a": IntegerNode(10)}, {"a": StringNode("str")}),
        ({"a": 10, "b": IntegerNode(10)}, {"a": 20, "b": "str"}),
        ({"foo": {"bar": IntegerNode(10)}}, {"foo": {"bar": "str"}}),
    ],
)
def test_merge_validation_error(c1, c2):
    """Merging a value that fails type validation raises ValidationError.

    Both operands must compare equal to freshly created copies afterwards.
    """
    left = OmegaConf.create(c1)
    right = OmegaConf.create(c2)
    with pytest.raises(ValidationError):
        OmegaConf.merge(left, right)
    # A failed merge must not leave either operand modified.
    assert left == OmegaConf.create(c1)
    assert right == OmegaConf.create(c2)