(SubscriptedList, "list", int, Any, False, List[int]), ( DictConfig( content={"a": "foo"}, ref_type=Dict[str, str], element_type=str, key_type=str, ), None, str, str, True, Optional[Dict[str, str]], ), ( ListConfig(content=[1, 2], ref_type=List[int], element_type=int), None, int, Any, True, Optional[List[int]], ), ], ) def test_pickle_untyped( input_: Any, node: str, optional: bool, element_type: Any, key_type: Any, ref_type: Any,
# testing invalid conversions @mark.parametrize( "type_,input_", [ (IntegerNode, "abc"), (IntegerNode, 10.1), (IntegerNode, "-1132c"), (IntegerNode, Color.RED), (FloatNode, "abc"), (FloatNode, Color.RED), (IntegerNode, "-abc"), (BooleanNode, "Nope"), (BooleanNode, "Yup"), (BooleanNode, Color.RED), (StringNode, [1, 2]), (StringNode, ListConfig([1, 2])), (StringNode, { "foo": "var" }), (FloatNode, DictConfig({"foo": "var"})), (IntegerNode, [1, 2]), (IntegerNode, ListConfig([1, 2])), (IntegerNode, { "foo": "var" }), (IntegerNode, DictConfig({"foo": "var"})), (BooleanNode, [1, 2]), (BooleanNode, ListConfig([1, 2])), (BooleanNode, { "foo": "var" }),
def change_vocabulary(self, new_tokenizer_dir: str, new_tokenizer_type: str):
    """
    Changes vocabulary of the tokenizer used during CTC decoding process.
    Use this method when fine-tuning from a pre-trained model.
    This method changes only decoder and leaves encoder and pre-processing modules unchanged. For example, you would
    use it if you want to use pretrained encoder when fine-tuning on data in another language, or when you'd need
    model to learn capitalization, punctuation and/or special characters.

    Args:
        new_tokenizer_dir: Path to the new tokenizer directory.
        new_tokenizer_type: Either `bpe` or `wpe` (case-insensitive). `bpe` is used for SentencePiece tokenizers,
            whereas `wpe` is used for `BertTokenizer`.

    Raises:
        NotADirectoryError: If `new_tokenizer_dir` is not an existing directory.
        ValueError: If `new_tokenizer_type` is neither `bpe` nor `wpe`.

    Returns: None
    """
    if not os.path.isdir(new_tokenizer_dir):
        raise NotADirectoryError(
            f'New tokenizer dir must be non-empty path to a directory. But I got: {new_tokenizer_dir}'
        )

    # Validation is case-insensitive, so forward the normalized value as well; otherwise
    # e.g. "BPE" would pass this check but reach the tokenizer config un-normalized.
    tokenizer_type = new_tokenizer_type.lower()
    if tokenizer_type not in ('bpe', 'wpe'):
        raise ValueError(f'New tokenizer type must be either `bpe` or `wpe`. But I got: {new_tokenizer_type}')

    tokenizer_cfg = OmegaConf.create({'dir': new_tokenizer_dir, 'type': tokenizer_type})

    # Set up the new tokenizer. `self.tokenizer` is read right below, so this call
    # presumably (re)assigns it -- TODO confirm against `_setup_tokenizer`.
    self._setup_tokenizer(tokenizer_cfg)

    # Vocabulary of the new tokenizer; its keys become the decoder's token list.
    vocabulary = self.tokenizer.tokenizer.get_vocab()
    num_classes = len(vocabulary)

    # Rebuild the decoder config with the new token list and class count.
    decoder_config = copy.deepcopy(self.decoder.to_config_dict())
    decoder_config.vocabulary = ListConfig(list(vocabulary.keys()))

    decoder_num_classes = decoder_config['num_classes']
    logging.info(
        "\nReplacing old number of classes ({}) with new number of classes - {}".format(
            decoder_num_classes, num_classes
        )
    )
    decoder_config['num_classes'] = num_classes

    # Replace the decoder, then everything sized by its number of classes.
    del self.decoder
    self.decoder = EncDecCTCModelBPE.from_config_dict(decoder_config)

    del self.loss
    self.loss = CTCLoss(
        num_classes=self.decoder.num_classes_with_blank - 1,
        zero_infinity=True,
        reduction=self._cfg.get("ctc_reduction", "mean_batch"),
    )
    self._wer = WERBPE(
        tokenizer=self.tokenizer,
        batch_dim_index=0,
        use_cer=self._cfg.get('use_cer', False),
        ctc_decode=True,
        log_prediction=self._cfg.get("log_prediction", False),
    )

    # Persist the new decoder config: struct mode is switched off on the old node
    # around the assignment and switched back on for the new node.
    OmegaConf.set_struct(self._cfg.decoder, False)
    self._cfg.decoder = decoder_config
    OmegaConf.set_struct(self._cfg.decoder, True)

    logging.info(f"Changed tokenizer to {self.decoder.vocabulary} vocabulary.")
class TestConfigs:
    """Tests for structured configs built with ``OmegaConf.structured``.

    Every test receives ``module``, an object exposing the structured-config
    classes under test (``module.Nested``, ``module.Plugin``, ...). The
    parametrization supplying ``module`` is not part of this class body;
    presumably it is applied to the class by a decorator -- TODO confirm.
    """

    def test_nested_config_is_none(self, module: Any) -> None:
        """A nested field defaulting to None has no object type and an Optional ref type."""
        cfg = OmegaConf.structured(module.NestedWithNone)
        assert cfg == {"plugin": None}
        assert OmegaConf.get_type(cfg, "plugin") is None
        assert _utils.get_ref_type(cfg, "plugin") == Optional[module.Plugin]

    def test_nested_config(self, module: Any) -> None:
        """Nested fields reject ints, accept a subclass and can take the original type back."""

        def validate(cfg: DictConfig) -> None:
            assert cfg == {
                "default_value": {
                    "with_default": 10,
                    "null_default": None,
                    "mandatory_missing": "???",
                    "interpolation": "${value_at_root}",
                },
                "user_provided_default": {
                    "with_default": 42,
                    "null_default": None,
                    "mandatory_missing": "???",
                    "interpolation": "${value_at_root}",
                },
                "value_at_root": 1000,
            }

            with raises(ValidationError):
                cfg.user_provided_default = 10

            with raises(ValidationError):
                cfg.default_value = 10

            # assign subclass
            cfg.default_value = module.NestedSubclass()
            assert cfg.default_value == {
                "additional": 20,
                "with_default": 10,
                "null_default": None,
                "mandatory_missing": "???",
                "interpolation": "${value_at_root}",
            }

            # assign original ref type back
            cfg.default_value = module.Nested()
            assert cfg.default_value == module.Nested()

        conf1 = OmegaConf.structured(module.NestedConfig(default_value=module.Nested()))
        validate(conf1)

    def test_nested_config2(self, module: Any) -> None:
        """Like test_nested_config, but with ``default_value`` left MISSING."""

        def validate(cfg: DictConfig) -> None:
            assert cfg == {
                "default_value": "???",
                "user_provided_default": {
                    "with_default": 42,
                    "null_default": None,
                    "mandatory_missing": "???",
                    "interpolation": "${value_at_root}",
                },
                "value_at_root": 1000,
            }

            with raises(ValidationError):
                cfg.user_provided_default = 10

            with raises(ValidationError):
                cfg.default_value = 10

            # assign subclass
            cfg.default_value = module.NestedSubclass()
            assert cfg.default_value == {
                "additional": 20,
                "with_default": 10,
                "null_default": None,
                "mandatory_missing": "???",
                "interpolation": "${value_at_root}",
            }

            # assign original ref type back
            cfg.default_value = module.Nested()
            assert cfg.default_value == module.Nested()

        conf1 = OmegaConf.structured(module.NestedConfig)
        validate(conf1)

    def test_value_without_a_default(self, module: Any) -> None:
        """A field without a default is MISSING unless a value is passed to the constructor."""
        cfg = OmegaConf.structured(module.NoDefaultValue)
        assert OmegaConf.is_missing(cfg, "no_default")

        # A value supplied at construction time must end up in the config.
        assert OmegaConf.structured(module.NoDefaultValue(no_default=10)) == {
            "no_default": 10
        }

    def test_union_errors(self, module: Any) -> None:
        """Creating a config from ``module.UnionError`` raises ValueError."""
        with raises(ValueError):
            OmegaConf.structured(module.UnionError)

    def test_config_with_list(self, module: Any) -> None:
        """Typed list fields validate item assignment; MISSING fields stay MISSING."""

        def validate(cfg: DictConfig) -> None:
            assert cfg == {"list1": [1, 2, 3], "list2": [1, 2, 3], "missing": MISSING}
            with raises(ValidationError):
                cfg.list1[1] = "foo"

            assert OmegaConf.is_missing(cfg, "missing")

        conf1 = OmegaConf.structured(module.ConfigWithList)
        validate(conf1)

        conf1 = OmegaConf.structured(module.ConfigWithList())
        validate(conf1)

    def test_assignment_to_nested_structured_config(self, module: Any) -> None:
        """A nested structured field rejects an int but accepts an instance of its type."""
        conf = OmegaConf.structured(module.NestedConfig)
        with raises(ValidationError):
            conf.default_value = 10

        conf.default_value = module.Nested()

    def test_assignment_to_structured_inside_dict_config(self, module: Any) -> None:
        """A structured-typed DictConfig stored inside a plain config rejects an int."""
        conf = OmegaConf.create(
            {"val": DictConfig(module.Nested, ref_type=module.Nested)}
        )
        with raises(ValidationError):
            conf.val = 10

    def test_config_with_dict(self, module: Any) -> None:
        """Dict and MISSING fields are created identically from the class and an instance."""

        def validate(cfg: DictConfig) -> None:
            assert cfg == {"dict1": {"foo": "bar"}, "missing": MISSING}
            assert OmegaConf.is_missing(cfg, "missing")

        conf1 = OmegaConf.structured(module.ConfigWithDict)
        validate(conf1)
        conf1 = OmegaConf.structured(module.ConfigWithDict())
        validate(conf1)

    def test_structured_config_struct_behavior(self, module: Any) -> None:
        """Unknown keys are rejected until ``set_struct(cfg, False)``, though is_struct is False."""

        def validate(cfg: DictConfig) -> None:
            assert not OmegaConf.is_struct(cfg)
            with raises(AttributeError):
                # noinspection PyStatementEffect
                cfg.foo

            cfg.dict1.foo = 10
            assert cfg.dict1.foo == 10

            # setting struct False on a specific typed node opens it up even though it's
            # still typed
            OmegaConf.set_struct(cfg, False)
            cfg.foo = 20
            assert cfg.foo == 20

        conf = OmegaConf.structured(module.ConfigWithDict)
        validate(conf)
        conf = OmegaConf.structured(module.ConfigWithDict())
        validate(conf)

    @mark.parametrize(
        "tested_type,assignment_data, init_dict",
        [
            # Use class to build config
            ("BoolConfig", BoolConfigAssignments, {}),
            ("IntegersConfig", IntegersConfigAssignments, {}),
            ("FloatConfig", FloatConfigAssignments, {}),
            ("StringConfig", StringConfigAssignments, {}),
            ("EnumConfig", EnumConfigAssignments, {}),
            # Use instance to build config
            ("BoolConfig", BoolConfigAssignments, {"with_default": False}),
            ("IntegersConfig", IntegersConfigAssignments, {"with_default": 42}),
            ("FloatConfig", FloatConfigAssignments, {"with_default": 42.0}),
            ("StringConfig", StringConfigAssignments, {"with_default": "fooooooo"}),
            ("EnumConfig", EnumConfigAssignments, {"with_default": Color.BLUE}),
            ("AnyTypeConfig", AnyTypeConfigAssignments, {}),
        ],
    )
    def test_field_with_default_value(
        self,
        module: Any,
        tested_type: str,
        init_dict: Dict[str, Any],
        assignment_data: Any,
    ) -> None:
        """Typed fields honor defaults, MISSING, interpolation and value validation."""
        input_class = getattr(module, tested_type)

        def validate(input_: Any, expected: Any) -> None:
            conf = OmegaConf.structured(input_)
            # Test access
            assert conf.with_default == expected.with_default
            assert conf.null_default is None
            # Test that accessing a variable without a default value
            # results in a MissingMandatoryValue exception
            with raises(MissingMandatoryValue):
                # noinspection PyStatementEffect
                conf.mandatory_missing

            # Test interpolation preserves type and value
            assert type(conf.with_default) == type(conf.interpolation)  # noqa E721
            assert conf.with_default == conf.interpolation

            # Test that assigning illegal values raises ValidationError
            for illegal_value in assignment_data.illegal:
                with raises(ValidationError):
                    conf.with_default = illegal_value

                with raises(ValidationError):
                    conf.null_default = illegal_value

                with raises(ValidationError):
                    conf.mandatory_missing = illegal_value

            # Test assignment of legal values; a tuple entry is (assigned, expected)
            for legal_value in assignment_data.legal:
                expected_data = legal_value
                if isinstance(legal_value, tuple):
                    expected_data = legal_value[1]
                    legal_value = legal_value[0]
                conf.with_default = legal_value
                conf.null_default = legal_value
                conf.mandatory_missing = legal_value

                msg = "Error: {} : {}".format(input_class.__name__, legal_value)
                assert conf.with_default == expected_data, msg
                assert conf.null_default == expected_data, msg
                assert conf.mandatory_missing == expected_data, msg

        validate(input_class, input_class())
        validate(input_class(**init_dict), input_class(**init_dict))

    @mark.parametrize(
        "input_init, expected_init",
        [
            # attr class as class
            (None, {}),
            # attr class object with custom values
            ({"int_default": 30}, {"int_default": 30}),
            # dataclass as class
            (None, {}),
            # dataclass as object with custom values
            ({"int_default": 30}, {"int_default": 30}),
        ],
    )
    def test_untyped(self, module: Any, input_init: Any, expected_init: Any) -> None:
        """Untyped fields are AnyNodes that accept values of any type."""
        input_ = module.AnyTypeConfig
        expected = input_(**expected_init)
        if input_init is not None:
            input_ = input_(**input_init)

        conf = OmegaConf.structured(input_)
        assert conf.null_default == expected.null_default
        assert conf.int_default == expected.int_default
        assert conf.float_default == expected.float_default
        assert conf.str_default == expected.str_default
        assert conf.bool_default == expected.bool_default
        # A MISSING key is listed by keys() but is not reported by `in`.
        assert "mandatory_missing" in conf.keys() and "mandatory_missing" not in conf

        with raises(MissingMandatoryValue):
            # noinspection PyStatementEffect
            conf.mandatory_missing

        assert type(conf._get_node("null_default")) == AnyNode
        assert type(conf._get_node("int_default")) == AnyNode
        assert type(conf._get_node("float_default")) == AnyNode
        assert type(conf._get_node("str_default")) == AnyNode
        assert type(conf._get_node("bool_default")) == AnyNode
        assert type(conf._get_node("mandatory_missing")) == AnyNode

        assert conf.int_default == expected.int_default
        with raises(ValidationError):
            conf.typed_int_default = "foo"

        values = [10, True, False, None, 1.0, -1.0, "10", float("inf")]
        for val in values:
            conf.null_default = val
            conf.int_default = val
            conf.float_default = val
            conf.str_default = val
            conf.bool_default = val

            assert conf.null_default == val
            assert conf.int_default == val
            assert conf.float_default == val
            assert conf.str_default == val
            assert conf.bool_default == val

    def test_interpolation(self, module: Any) -> None:
        """Interpolations resolve to the referenced value; string interpolation yields str."""
        input_ = module.Interpolation()
        conf = OmegaConf.structured(input_)
        assert conf.x == input_.x
        assert conf.z1 == conf.x
        assert conf.z2 == f"{conf.x}_{conf.y}"
        assert type(conf.x) == int
        assert type(conf.y) == int
        assert type(conf.z1) == int
        assert type(conf.z2) == str

    @mark.parametrize(
        "tested_type",
        [
            "BoolOptional",
            "IntegerOptional",
            "FloatOptional",
            "StringOptional",
            "ListOptional",
            "TupleOptional",
            "EnumOptional",
            "StructuredOptional",
            "DictOptional",
        ],
    )
    def test_optional(self, module: Any, tested_type: str) -> None:
        """Optional fields accept None; ``not_optional`` rejects it."""
        input_ = getattr(module, tested_type)
        obj = input_()
        conf = OmegaConf.structured(input_)

        # verify non-optional fields are rejecting None
        with raises(ValidationError):
            conf.not_optional = None

        assert conf.as_none is None
        assert conf.with_default == obj.with_default
        # assign None to an optional field
        conf.with_default = None
        assert conf.with_default is None

    def test_list_field(self, module: Any) -> None:
        """A typed list field rejects bad elements via item set, append and merge."""
        input_ = module.WithListField
        conf = OmegaConf.structured(input_)
        with raises(ValidationError):
            conf.list[0] = "fail"

        with raises(ValidationError):
            conf.list.append("fail")

        with raises(ValidationError):
            cfg2 = OmegaConf.create({"list": ["fail"]})
            OmegaConf.merge(conf, cfg2)

    def test_dict_field(self, module: Any) -> None:
        """A typed dict field rejects bad values via item set and merge."""
        input_ = module.WithDictField
        conf = OmegaConf.structured(input_)
        with raises(ValidationError):
            conf.dict["foo"] = "fail"

        with raises(ValidationError):
            OmegaConf.merge(conf, OmegaConf.create({"dict": {"foo": "fail"}}))

    @mark.skipif(sys.version_info < (3, 8), reason="requires Python 3.8 or newer")
    def test_typed_dict_field(self, module: Any) -> None:
        """A typed-dict field is built from an instance; merged values are not type-checked."""
        input_ = module.WithTypedDictField
        conf = OmegaConf.structured(input_(dict={"foo": 10}))
        assert conf.dict["foo"] == 10

        # Typed dicts do not currently get runtime type checking.
        conf = OmegaConf.merge(conf, {"dict": {"foo": "not_failed"}})
        assert conf.dict["foo"] == "not_failed"

    def test_merged_type1(self, module: Any) -> None:
        """Merging a structured config onto an empty config yields the structured type."""
        # Test that the merged type is that of the last merged config
        input_ = module.WithDictField
        conf = OmegaConf.structured(input_)
        res = OmegaConf.merge(OmegaConf.create(), conf)
        assert OmegaConf.get_type(res) == input_

    def test_merged_type2(self, module: Any) -> None:
        """Merging a plain dict onto a structured config keeps the structured type."""
        input_ = module.WithDictField
        conf = OmegaConf.structured(input_)
        res = OmegaConf.merge(conf, {"dict": {"foo": 99}})
        assert OmegaConf.get_type(res) == input_

    def test_merged_with_subclass(self, module: Any) -> None:
        """Merging a subclass config onto its base yields the subclass type."""
        # Test that the merged type is that of the last merged config
        c1 = OmegaConf.structured(module.Plugin)
        c2 = OmegaConf.structured(module.ConcretePlugin)
        res = OmegaConf.merge(c1, c2)
        assert OmegaConf.get_type(res) == module.ConcretePlugin

    def test_merge_missing_structured_on_self(self, module: Any) -> None:
        """Merging a config with a MISSING field onto itself keeps the field MISSING."""
        c1 = OmegaConf.structured(module.MissingStructuredConfigField)
        assert OmegaConf.is_missing(c1, "plugin")
        c2 = OmegaConf.merge(c1, module.MissingStructuredConfigField)
        assert OmegaConf.is_missing(c2, "plugin")

    def test_merge_missing_structured_config_is_missing(self, module: Any) -> None:
        """The ``plugin`` field of MissingStructuredConfigField is MISSING after creation."""
        c1 = OmegaConf.structured(module.MissingStructuredConfigField)
        assert OmegaConf.is_missing(c1, "plugin")

    def test_merge_missing_structured(self, module: Any) -> None:
        """Merging a MISSING structured field onto a plain "???" keeps it MISSING."""
        c1 = OmegaConf.create({"plugin": "???"})
        c2 = OmegaConf.merge(c1, module.MissingStructuredConfigField)
        assert OmegaConf.is_missing(c2, "plugin")

    def test_merge_none_is_none(self, module: Any) -> None:
        """Merging None into an optional nested config sets it to None."""
        c1 = OmegaConf.structured(module.StructuredOptional)
        assert c1.with_default == module.Nested()
        c2 = OmegaConf.merge(c1, {"with_default": None})
        assert c2.with_default is None

    def test_merge_with_subclass_into_missing(self, module: Any) -> None:
        """Merging a class into a MISSING field sets its object type and keeps its ref type."""
        base = OmegaConf.structured(module.PluginHolder)
        assert _utils.get_ref_type(base, "missing") == module.Plugin
        assert OmegaConf.get_type(base, "missing") is None
        res = OmegaConf.merge(base, {"missing": module.Plugin})
        assert OmegaConf.get_type(res) == module.PluginHolder
        assert _utils.get_ref_type(base, "missing") == module.Plugin
        # The merged result must keep the declared ref type of the field too.
        assert _utils.get_ref_type(res, "missing") == module.Plugin
        assert OmegaConf.get_type(res, "missing") == module.Plugin

    def test_merged_with_nons_subclass(self, module: Any) -> None:
        """Merging FaultyPlugin (a non-subclass, per the test name) onto Plugin raises."""
        c1 = OmegaConf.structured(module.Plugin)
        c2 = OmegaConf.structured(module.FaultyPlugin)
        with raises(ValidationError):
            OmegaConf.merge(c1, c2)

    def test_merge_into_Dict(self, module: Any) -> None:
        """Merging a new key into a typed dict field extends it."""
        cfg = OmegaConf.structured(module.DictExamples)
        res = OmegaConf.merge(cfg, {"strings": {"x": "abc"}})
        assert res.strings == {"a": "foo", "b": "bar", "x": "abc"}

    def test_merge_user_list_with_wrong_key(self, module: Any) -> None:
        """Merging a dict with an unknown key into the User list raises ConfigKeyError."""
        cfg = OmegaConf.structured(module.UserList)
        with raises(ConfigKeyError):
            OmegaConf.merge(cfg, {"list": [{"foo": "var"}]})

    def test_merge_list_with_correct_type(self, module: Any) -> None:
        """Merging a list of User instances into the User list field succeeds."""
        cfg = OmegaConf.structured(module.UserList)
        user = module.User(name="John", age=21)
        res = OmegaConf.merge(cfg, {"list": [user]})
        assert res.list == [user]

    def test_merge_dict_with_wrong_type(self, module: Any) -> None:
        """Merging a string value into the User dict field raises ValidationError."""
        cfg = OmegaConf.structured(module.UserDict)
        with raises(ValidationError):
            OmegaConf.merge(cfg, {"dict": {"foo": "var"}})

    def test_merge_dict_with_correct_type(self, module: Any) -> None:
        """Merging a User instance into the User dict field succeeds."""
        cfg = OmegaConf.structured(module.UserDict)
        user = module.User(name="John", age=21)
        res = OmegaConf.merge(cfg, {"dict": {"foo": user}})
        assert res.dict == {"foo": user}

    def test_dict_field_key_type_error(self, module: Any) -> None:
        """ErrorDictObjectKey is rejected at creation with KeyValidationError."""
        input_ = module.ErrorDictObjectKey
        with raises(KeyValidationError):
            OmegaConf.structured(input_)

    def test_dict_field_value_type_error(self, module: Any) -> None:
        """ErrorDictUnsupportedValue is rejected at creation with ValidationError."""
        input_ = module.ErrorDictUnsupportedValue
        with raises(ValidationError):
            OmegaConf.structured(input_)

    def test_list_field_value_type_error(self, module: Any) -> None:
        """ErrorListUnsupportedValue is rejected at creation with ValidationError."""
        input_ = module.ErrorListUnsupportedValue
        with raises(ValidationError):
            OmegaConf.structured(input_)

    @mark.parametrize("example", ["ListExamples", "TupleExamples"])
    def test_list_examples(self, module: Any, example: str) -> None:
        """Typed list/tuple fields convert or reject appended and assigned values."""
        input_ = getattr(module, example)
        conf = OmegaConf.structured(input_)

        def check_any(name: str) -> None:
            conf[name].append(True)
            conf[name].extend([Color.RED, 3.1415])
            conf[name][2] = False
            assert conf[name] == [1, "foo", False, Color.RED, 3.1415]

        # list field that accepts mixed element types
        check_any("any")

        # test ints
        with raises(ValidationError):
            conf.ints[0] = "foo"
        conf.ints.append(10)
        assert conf.ints == [1, 2, 10]

        # test strings
        conf.strings.append(Color.BLUE)
        assert conf.strings == ["foo", "bar", "Color.BLUE"]

        # test booleans
        with raises(ValidationError):
            conf.booleans[0] = "foo"
        conf.booleans.append(True)
        conf.booleans.append("off")
        conf.booleans.append(1)
        assert conf.booleans == [True, False, True, False, True]

        # test colors
        with raises(ValidationError):
            conf.colors[0] = "foo"
        conf.colors.append(Color.BLUE)
        conf.colors.append("RED")
        conf.colors.append("Color.GREEN")
        conf.colors.append(3)
        assert conf.colors == [
            Color.RED,
            Color.GREEN,
            Color.BLUE,
            Color.RED,
            Color.GREEN,
            Color.BLUE,
        ]

    def test_dict_examples_any(self, module: Any) -> None:
        """The ``any`` dict accepts values of mixed types."""
        conf = OmegaConf.structured(module.DictExamples)

        dct = conf.any
        dct.c = True
        dct.d = Color.RED
        dct.e = 3.1415
        assert dct == {"a": 1, "b": "foo", "c": True, "d": Color.RED, "e": 3.1415}

    def test_dict_examples_int(self, module: Any) -> None:
        """The ``ints`` dict rejects a string value and accepts an int."""
        conf = OmegaConf.structured(module.DictExamples)
        dct = conf.ints

        # test ints
        with raises(ValidationError):
            dct.a = "foo"
        dct.c = 10
        assert dct == {"a": 10, "b": 20, "c": 10}

    def test_dict_examples_strings(self, module: Any) -> None:
        """The ``strings`` dict stores an assigned Enum as its string form."""
        conf = OmegaConf.structured(module.DictExamples)

        # test strings
        conf.strings.c = Color.BLUE
        assert conf.strings == {"a": "foo", "b": "bar", "c": "Color.BLUE"}

    def test_dict_examples_bool(self, module: Any) -> None:
        """The ``booleans`` dict converts "off" and 1 and rejects "foo"."""
        conf = OmegaConf.structured(module.DictExamples)
        dct = conf.booleans

        # test bool
        with raises(ValidationError):
            dct.a = "foo"
        dct.c = True
        dct.d = "off"
        dct.e = 1
        assert dct == {
            "a": True,
            "b": False,
            "c": True,
            "d": False,
            "e": True,
        }

    class TestDictExamples:
        """Key- and value-type checks on the dict fields of ``module.DictExamples``."""

        @fixture
        def conf(self, module: Any) -> DictConfig:
            """Structured config built from ``module.DictExamples``."""
            conf: DictConfig = OmegaConf.structured(module.DictExamples)
            return conf

        def test_dict_examples_colors(self, conf: DictConfig) -> None:
            """Color values are converted from names, qualified names and ints."""
            dct = conf.colors

            # test colors
            with raises(ValidationError):
                dct.foo = "foo"
            dct.c = Color.BLUE
            dct.d = "RED"
            dct.e = "Color.GREEN"
            dct.f = 3
            assert dct == {
                "red": Color.RED,
                "green": Color.GREEN,
                "blue": Color.BLUE,
                "c": Color.BLUE,
                "d": Color.RED,
                "e": Color.GREEN,
                "f": Color.BLUE,
            }

        def test_dict_examples_str_keys(self, conf: DictConfig) -> None:
            """The ``any`` dict rejects an int key and accepts a string key."""
            dct = conf.any

            with raises(KeyValidationError):
                dct[123] = "bad key type"
            dct["c"] = "three"
            assert dct == {
                "a": 1,
                "b": "foo",
                "c": "three",
            }

        def test_dict_examples_int_keys(self, conf: DictConfig) -> None:
            """The ``int_keys`` dict rejects a string key and accepts an int key."""
            dct = conf.int_keys

            # test int keys
            with raises(KeyValidationError):
                dct.foo_key = "foo_value"
            dct[3] = "three"
            assert dct == {
                1: "one",
                2: "two",
                3: "three",
            }

        def test_dict_examples_float_keys(self, conf: DictConfig) -> None:
            """The ``float_keys`` dict rejects a string key and accepts a float key."""
            dct = conf.float_keys

            # test float keys
            with raises(KeyValidationError):
                dct.foo_key = "foo_value"
            dct[3.3] = "three"
            assert dct == {
                1.1: "one",
                2.2: "two",
                3.3: "three",
            }

        def test_dict_examples_bool_keys(self, conf: DictConfig) -> None:
            """The ``bool_keys`` dict rejects a string key and accepts a bool key."""
            dct = conf.bool_keys

            # test bool_keys
            with raises(KeyValidationError):
                dct.foo_key = "foo_value"
            dct[True] = "new value"
            assert dct == {
                True: "new value",
                False: "F",
            }

        def test_dict_examples_enum_key(self, conf: DictConfig) -> None:
            """An Enum-keyed dict is addressable by member or member name; unknown names fail."""
            dct = conf.enum_key

            # When an Enum is a dictionary key the name of the Enum is actually used
            # as the key
            assert dct.RED == "red"
            assert dct["GREEN"] == "green"
            assert dct[Color.GREEN] == "green"

            dct["BLUE"] = "Blue too"
            assert dct[Color.BLUE] == "Blue too"
            with raises(KeyValidationError):
                dct["error"] = "error"

    def test_dict_of_objects(self, module: Any) -> None:
        """A dict of User objects exposes nested fields and rejects non-User values."""
        conf = OmegaConf.structured(module.DictOfObjects)
        dct = conf.users
        assert dct.joe.age == 18
        assert dct.joe.name == "Joe"

        dct.bond = module.User(name="James Bond", age=7)
        assert dct.bond.name == "James Bond"
        assert dct.bond.age == 7

        with raises(ValidationError):
            dct.fail = "fail"

    def test_list_of_objects(self, module: Any) -> None:
        """A list of User objects exposes nested fields and rejects non-User values."""
        conf = OmegaConf.structured(module.ListOfObjects)
        assert conf.users[0].age == 18
        assert conf.users[0].name == "Joe"

        conf.users.append(module.User(name="James Bond", age=7))
        assert conf.users[1].name == "James Bond"
        assert conf.users[1].age == 7

        with raises(ValidationError):
            conf.users.append("fail")

    def test_promote_api(self, module: Any) -> None:
        """``_promote(None)`` is a no-op; ``_promote(42)`` raises and leaves the config intact."""
        conf = OmegaConf.create(module.AnyTypeConfig)
        conf._promote(None)
        assert conf == OmegaConf.create(module.AnyTypeConfig)
        with raises(ValueError):
            conf._promote(42)
        assert conf == OmegaConf.create(module.AnyTypeConfig)

    def test_promote_to_class(self, module: Any) -> None:
        """Promoting to a config class adopts its type and its defaults."""
        conf = OmegaConf.create(module.AnyTypeConfig)
        assert OmegaConf.get_type(conf) == module.AnyTypeConfig

        conf._promote(module.BoolConfig)

        assert OmegaConf.get_type(conf) == module.BoolConfig
        assert conf.with_default is True
        assert conf.null_default is None
        assert OmegaConf.is_missing(conf, "mandatory_missing")

    def test_promote_to_object(self, module: Any) -> None:
        """Promoting to a config instance adopts its type and its values."""
        conf = OmegaConf.create(module.AnyTypeConfig)
        assert OmegaConf.get_type(conf) == module.AnyTypeConfig

        conf._promote(module.BoolConfig(with_default=False))
        assert OmegaConf.get_type(conf) == module.BoolConfig
        assert conf.with_default is False

    def test_set_key_with_with_dataclass(self, module: Any) -> None:
        """A plain list value can be replaced by a structured config instance."""
        cfg = OmegaConf.create({"foo": [1, 2]})
        cfg.foo = module.ListClass()

    def test_set_list_correct_type(self, module: Any) -> None:
        """List and tuple fields accept a list of ints."""
        cfg = OmegaConf.structured(module.ListClass)
        value = [1, 2, 3]
        cfg.list = value
        cfg.tuple = value
        assert cfg.list == value
        assert cfg.tuple == value

    @mark.parametrize("value", [1, True, "str", 3.1415, ["foo", True, 1.2], User()])
    def test_assign_wrong_type_to_list(self, module: Any, value: Any) -> None:
        """Invalid assignments to list/tuple fields raise and leave the config unchanged."""
        cfg = OmegaConf.structured(module.ListClass)
        with raises(ValidationError):
            cfg.list = value
        with raises(ValidationError):
            cfg.tuple = value
        assert cfg == OmegaConf.structured(module.ListClass)

    @mark.parametrize(
        "value",
        [
            1,
            True,
            "str",
            3.1415,
            ["foo", True, 1.2],
            {"foo": True},
            User(age=1, name="foo"),
            {"user": User(age=1, name="foo")},
            ListConfig(content=[1, 2], ref_type=List[int], element_type=int),
        ],
    )
    def test_assign_wrong_type_to_dict(self, module: Any, value: Any) -> None:
        """Invalid assignments to a typed dict field raise and leave the config unchanged."""
        cfg = OmegaConf.structured(module.ConfigWithDict2)
        with raises(ValidationError):
            cfg.dict1 = value
        assert cfg == OmegaConf.structured(module.ConfigWithDict2)

    def test_recursive_dict(self, module: Any) -> None:
        """A self-referential dict field builds nested configs with MISSING leaves."""
        rd = module.RecursiveDict
        o = rd(d={"a": rd(), "b": rd()})
        cfg = OmegaConf.structured(o)
        assert cfg == {
            "d": {
                "a": {"d": "???"},
                "b": {"d": "???"},
            }
        }

    def test_recursive_list(self, module: Any) -> None:
        """A self-referential list field builds nested configs with MISSING leaves."""
        rl = module.RecursiveList
        o = rl(d=[rl(), rl()])
        cfg = OmegaConf.structured(o)
        assert cfg == {"d": [{"d": "???"}, {"d": "???"}]}

    def test_create_untyped_dict(self, module: Any) -> None:
        """Unparameterized dict fields get ``Dict[Any, Any]`` (optionally Optional) ref types."""
        cfg = OmegaConf.structured(module.UntypedDict)
        assert _utils.get_ref_type(cfg, "dict") == Dict[Any, Any]
        assert _utils.get_ref_type(cfg, "opt_dict") == Optional[Dict[Any, Any]]
        assert cfg.dict == {"foo": "var"}
        assert cfg.opt_dict is None

    def test_create_untyped_list(self, module: Any) -> None:
        """Unparameterized list fields get ``List[Any]`` (optionally Optional) ref types."""
        cfg = OmegaConf.structured(module.UntypedList)
        assert _utils.get_ref_type(cfg, "list") == List[Any]
        assert _utils.get_ref_type(cfg, "opt_list") == Optional[List[Any]]
        assert cfg.list == [1, 2]
        assert cfg.opt_list is None
id="list:get_nox_ex:invalid_index_type", ), pytest.param( Expected( create=lambda: OmegaConf.create([1, 2, 3]), op=lambda cfg: cfg._get_node(20), exception_type=IndexError, msg="list index out of range", key=20, full_key="[20]", ), id="list:get_node_ex:index_out_of_range", ), pytest.param( Expected( create=lambda: ListConfig(content=None), op=lambda cfg: cfg._get_node(20), exception_type=TypeError, msg="Cannot get_node from a ListConfig object representing None", key=20, full_key="[20]", ref_type=Optional[List[Any]], ), id="list:get_node_none", ), pytest.param( Expected( create=lambda: ListConfig(content="???"), op=lambda cfg: cfg._get_node(20), exception_type=MissingMandatoryValue, msg="Cannot get_node from a missing ListConfig",
def test_append_convert(lc: ListConfig, element: Any, expected: Any) -> None:
    """Appending ``element`` stores a value equal to, and of the same type as, ``expected``."""
    lc.append(element)
    appended = lc[-1]
    # equality alone would accept e.g. 1 for True; the exact type must match too
    assert appended == expected
    assert type(appended) == type(expected)
lambda value, is_optional, key=None: EnumNode( enum_type=Color, value=value, is_optional=is_optional, key=key ), [Color.RED], ), # DictConfig ( lambda value, is_optional, key=None: DictConfig( is_optional=is_optional, content=value, key=key), [{}, { "foo": "bar" }], ), # ListConfig ( lambda value, is_optional, key=None: ListConfig( is_optional=is_optional, content=value, key=key), [[], [1, 2, 3]], ), # dataclass ( lambda value, is_optional, key=None: DictConfig(ref_type=Group, is_optional= is_optional, content=value, key=key), [Group, Group()], ), ], ids=( "BooleanNode", "FloatNode",
def test_append_invalid_element_type(lc: ListConfig, element: Any, expected: Any) -> None:
    """Append ``element`` inside the ``expected`` context manager.

    ``expected`` comes from the parametrization; presumably it is a
    ``raises(...)`` guard asserting the append fails -- confirm at call site.
    """
    guard = expected
    with guard:
        lc.append(element)
def test_listconfig_creation_with_parent_flag(flag: str) -> None:
    """A ListConfig can be built under a parent that has ``flag`` set to True."""
    items = [1, 2, 3]
    owner = OmegaConf.create([])
    owner._set_flag(flag, True)
    assert ListConfig(items, parent=owner) == items
def test_shallow_copy_none() -> None:
    """Setting a value on a copy of a None ListConfig leaves the original None."""
    source = ListConfig(content=None)
    duplicate = source.copy()
    duplicate._set_value([1])
    assert duplicate[0] == 1
    # the original must be unaffected by changes made through the copy
    assert source._is_none()
def test_shallow_copy_missing() -> None:
    """Setting a value on a copy of a MISSING ListConfig leaves the original MISSING."""
    source = ListConfig(content=MISSING)
    duplicate = source.copy()
    duplicate._set_value([1])
    assert duplicate[0] == 1
    # the original must be unaffected by changes made through the copy
    assert source._is_missing()
def test_list_of_dicts() -> None:
    """Dicts inside a list passed to ``OmegaConf.create`` are readable via attribute access."""
    v = [dict(key1="value1"), dict(key2="value2")]
    c = OmegaConf.create(v)
    assert c[0].key1 == "value1"
    assert c[1].key2 == "value2"


@mark.parametrize("default", [None, 0, "default"])
@mark.parametrize(
    ("cfg", "key"),
    [
        (["???"], 0),
        ([DictConfig(content="???")], 0),
        ([ListConfig(content="???")], 0),
    ],
)
def test_list_get_return_default(cfg: List[Any], key: int, default: Any) -> None:
    """``get`` returns ``default_value`` itself when the element at ``key`` is missing.

    Covered: a literal ``"???"`` element, and DictConfig / ListConfig elements
    whose content is ``"???"``. The check is by identity (``is``).
    """
    c = OmegaConf.create(cfg)
    val = c.get(key, default_value=default)
    assert val is default


@mark.parametrize("default", [None, 0, "default"])
@mark.parametrize(
    ("cfg", "key", "expected"),
    [
        (["found"], 0, "found"),
        ([None], 0, None),