Example #1
0
    def test_pickle(self):
        """Default-arg configs of a plain function survive a pickle round trip."""

        def func(a: int = 1, b: str = "3"):
            pass

        enable_get_default_args(func)

        original = get_default_args(func)
        restored = pickle.loads(pickle.dumps(original))
        self.assertEqual(restored.a, 1)
        self.assertEqual(restored.b, "3")

        # Regenerating the defaults must give a picklable config, and the
        # first config must still be picklable afterwards.
        for cfg in (get_default_args(func), original):
            pickle.dumps(cfg)
Example #2
0
    def test_irrelevant_bases(self):
        """Annotated non-dataclass bases are not altered by expand_args_fields."""

        class NotADataclass:
            # Like torch.nn.Module, this class carries annotations but is not
            # meant to be turned into a dataclass. Classes inheriting from it
            # must not cause expand_args_fields to affect it.
            a: int = 9
            b: int

        class LeftConfigured(Configurable, NotADataclass):
            left: int = 1

        class RightConfigured(NotADataclass, Configurable):
            right: int = 2

        class Outer(Configurable):
            left: LeftConfigured
            right: RightConfigured

            def __post_init__(self):
                run_auto_creation(self)

        outer_defaults = get_default_args(Outer)
        outer = Outer(**outer_defaults)
        self.assertEqual(outer.left.left, 1)
        self.assertEqual(outer.right.right, 2)

        # NotADataclass declares `b` without a default after the defaulted
        # `a`, so applying dataclass() to it should still fail.
        with self.assertRaisesRegex(TypeError, "non-default argument"):
            dataclass(NotADataclass)
Example #3
0
    def test_inheritance2(self):
        """A member-type subclass that also derives from the container is ignored."""
        # Derived is both a Parent implementation and a subclass of Main; it
        # is not taken into account when Main itself is processed.

        class Parent(ReplaceableBase):
            pass

        class Main(Configurable):
            parent: Parent
            # Note: no __post_init__, so members are not created automatically.

        @registry.register
        class Derived(Parent, Main):
            pass

        main_defaults = get_default_args(Main)
        # Derived has been ignored in processing Main.
        self.assertCountEqual(main_defaults.keys(), ["parent_class_type"])

        main = Main(**main_defaults)

        with self.assertRaisesRegex(ValueError,
                                    "UNDEFAULTED has not been registered."):
            run_auto_creation(main)

        main.parent_class_type = "Derived"
        # A plain dict works just as well as a DictConfig here.
        main.parent_Derived_args = {}
        with self.assertRaises(AttributeError):
            _ = main.parent
        run_auto_creation(main)
        self.assertIsInstance(main.parent, Derived)
Example #4
0
    def test_no_replacement(self):
        """Plain and Optional Configurable members work without ReplaceableBase."""

        class A(Configurable):
            n: int = 9

        class B(Configurable):
            a: A

            def __post_init__(self):
                run_auto_creation(self)

        class C(Configurable):
            b1: B
            b2: Optional[B]
            b3: Optional[B]
            b2_enabled: bool = True

            def __post_init__(self):
                run_auto_creation(self)

        c = C(**get_default_args(C))

        # Required member: created, and no *_enabled flag exists for it.
        self.assertIsInstance(c.b1.a, A)
        self.assertEqual(c.b1.a.n, 9)
        self.assertFalse(hasattr(c, "b1_enabled"))

        # Optional member switched on via b2_enabled = True.
        self.assertIsInstance(c.b2.a, A)
        self.assertEqual(c.b2.a.n, 9)
        self.assertTrue(c.b2_enabled)

        # Optional member left at its default: disabled and set to None.
        self.assertIsNone(c.b3)
        self.assertFalse(c.b3_enabled)
Example #5
0
    def test_recursion(self):
        """A replaceable member whose type can contain itself expands finitely."""

        class Shape(ReplaceableBase):
            pass

        @registry.register
        class Triangle(Shape):
            a: float = 5.0

        @registry.register
        class Square(Shape):
            a: float = 3.0

        @registry.register
        class LargeShape(Shape):
            inner: Shape

            def __post_init__(self):
                run_auto_creation(self)

        class ShapeContainer(Configurable):
            shape: Shape

        # ShapeContainer has no __post_init__, so `shape` is never created.
        container = ShapeContainer(**get_default_args(ShapeContainer))
        with self.assertRaises(AttributeError):
            _ = container.shape

        class ShapeContainer2(Configurable):
            x: Shape
            x_class_type: str = "LargeShape"

            def __post_init__(self):
                self.x_LargeShape_args.inner_class_type = "Triangle"
                run_auto_creation(self)

        defaults2 = get_default_args(ShapeContainer2)
        # OmegaConf returns nested configs by reference, so edits made
        # through this alias land in defaults2.
        large_args = defaults2.x_LargeShape_args
        large_args.inner_Triangle_args.a += 10
        self.assertIn("inner_Square_args", large_args)
        # Expanding inner_LargeShape_args would recurse infinitely, so that
        # key is intentionally absent.
        self.assertNotIn("inner_LargeShape_args", large_args)
        large_args.inner_Square_args.a += 100

        container2 = ShapeContainer2(**defaults2)
        self.assertIsInstance(container2.x, LargeShape)
        self.assertIsInstance(container2.x.inner, Triangle)
        self.assertEqual(container2.x.inner.a, 15.0)
Example #6
0
    def test_remove_unused_components_optional(self):
        """A disabled Optional member is pruned down to its *_enabled flag."""

        class MainTestWrapper(Configurable):
            mt: Optional[MainTest]

        wrapper_defaults = get_default_args(MainTestWrapper)
        self.assertEqual(list(wrapper_defaults.keys()),
                         ["mt_args", "mt_enabled"])
        remove_unused_components(wrapper_defaults)
        self.assertEqual(OmegaConf.to_yaml(wrapper_defaults),
                         "mt_enabled: false\n")
Example #7
0
    def test_raw_types(self):
        """Defaults of dataclasses, plain classes and functions nest in a Configurable."""

        @dataclass
        class MyDataclass:
            int_field: int = 0
            none_field: Optional[int] = None
            float_field: float = 9.3
            bool_field: bool = True
            tuple_field: tuple = (3, True, "j")

        class SimpleClass:
            def __init__(
                    self,
                    tuple_member_: Tuple[int, int] = (3, 4),
                    set_member_: Set[int] = {2},  # noqa
            ):
                self.tuple_member = tuple_member_
                self.set_member = set_member_

            def get_tuple(self):
                return self.tuple_member

        enable_get_default_args(SimpleClass)

        def f(*, a: int = 3, b: str = "kj"):
            # Called below with the stored defaults; they must come back as-is.
            self.assertEqual(a, 3)
            self.assertEqual(b, "kj")

        enable_get_default_args(f)

        class C(Configurable):
            simple: DictConfig = get_default_args_field(SimpleClass)
            mydata: DictConfig = get_default_args_field(MyDataclass)
            # NOTE(review): Tuple[float] nominally means a 1-tuple;
            # Tuple[float, ...] would match this value.
            a_tuple: Tuple[float] = (4.0, 3.0)
            f_args: DictConfig = get_default_args_field(f)

        c_defaults = get_default_args(C)
        c = C(**c_defaults)
        self.assertCountEqual(c_defaults.keys(),
                              ["simple", "mydata", "a_tuple", "f_args"])

        mydata = MyDataclass(**c.mydata)
        simple = SimpleClass(**c.simple)

        # OmegaConf stores tuples as ListConfigs, which compare like lists.
        self.assertEqual(simple.get_tuple(), [3, 4])
        self.assertIsInstance(simple.get_tuple(), ListConfig)
        # Sets also come back from get_default_args as ListConfigs.
        self.assertEqual(simple.set_member, [2])
        self.assertIsInstance(simple.set_member, ListConfig)
        self.assertEqual(c.a_tuple, [4.0, 3.0])
        self.assertIsInstance(c.a_tuple, ListConfig)
        self.assertEqual(mydata.tuple_field, (3, True, "j"))
        self.assertIsInstance(mydata.tuple_field, ListConfig)
        f(**c.f_args)
Example #8
0
 def test_get_default_args(self):
     """Each default from get_default_args matches the reference instance."""
     for klass in (MockDataclass, MockClassWithInit):
         defaults = get_default_args(klass)
         # Fields whose value is missing (???) are reported as `not in`.
         for missing in ("field_no_default", "field_no_nothing",
                         "field_reference_type"):
             self.assertNotIn(missing, defaults)
         if klass is MockDataclass:
             # Undefaulted fields are not removed for dataclasses; fill this
             # one in (presumably with the value the reference instance holds).
             defaults.field_no_default = 0
         for name, val in defaults.items():
             self.assertTrue(hasattr(self._instances[klass], name))
             self.assertEqual(val, getattr(self._instances[klass], name))
Example #9
0
    def test_simple_replacement(self):
        """*_class_type selects implementations and *_args override their fields."""
        cfg = get_default_args(MainTest)
        cfg.n_ids = 9780
        cfg.the_fruit_Pear_args.n_pips = 3
        cfg.the_fruit_class_type = "Pear"
        cfg.the_second_fruit_class_type = "Pear"

        main = MainTest(**cfg)
        self.assertIsInstance(main.the_fruit, Pear)
        self.assertEqual(main.n_reps, 8)
        self.assertEqual(main.n_ids, 9780)
        self.assertEqual(main.the_fruit.n_pips, 3)
        # NOTE(review): a Pineapple even though "Pear" was requested --
        # presumably MainTest's own create_the_second_fruit (listed below)
        # ignores the class type. Confirm against MainTest.
        self.assertIsInstance(main.the_second_fruit, Pineapple)

        # Editing `cfg` above must not leak into freshly generated defaults.
        fresh = get_default_args(MainTest)
        self.assertEqual(fresh.the_fruit_Pear_args.n_pips, 13)

        self.assertEqual(
            MainTest._creation_functions,
            ("create_the_fruit", "create_the_second_fruit"),
        )
Example #10
0
 def test_create_gm(self):
     """A GenericModel built from its default args has the expected parts."""
     gm = GenericModel(**get_default_args(GenericModel))
     self.assertIsInstance(gm.renderer, MultiPassEmissionAbsorptionRenderer)
     self.assertIsInstance(gm.feature_aggregator,
                           AngleWeightedReductionFeatureAggregator)
     first_fn = gm._implicit_functions[0]._fn
     self.assertIsInstance(first_fn, NeuralRadianceFieldImplicitFunction)
     self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
     # These names are not present as attributes on the built model.
     for absent in ("implicit_function", "image_feature_extractor"):
         self.assertFalse(hasattr(gm, absent))
Example #11
0
    def test_simpleclass_member(self):
        """Non-dataclass members are tolerated alongside replaceable ones."""
        # Members that are not dataclasses are tolerated, though it would be
        # nice to be able to configure them too.

        class Foo:
            def __init__(self, a: Any = 1, b: Any = 2):
                self.a, self.b = a, b

        enable_get_default_args(Foo)

        @dataclass()
        class Bar:
            aa: int = 9
            bb: int = 9

        class Container(Configurable):
            bar: Bar = Bar()
            # TODO make this work?
            # foo: Foo = Foo()
            fruit: Fruit
            fruit_class_type: str = "Orange"

            def __post_init__(self):
                run_auto_creation(self)

        self.assertEqual(get_default_args(Foo), {"a": 1, "b": 2})
        container = Container(**get_default_args(Container))
        self.assertIsInstance(container.fruit, Orange)
        # Only the replaceable member is recorded as processed, on both the
        # class and the instance.
        self.assertEqual(Container._processed_members, {"fruit": Fruit})
        self.assertEqual(container._processed_members, {"fruit": Fruit})

        # Editing a default-constructed instance's args must not change the
        # defaults that get_default_args returns afterwards.
        container_defaulted = Container()
        container_defaulted.fruit_Pear_args.n_pips += 4

        container = Container(**get_default_args(Container))
        self.assertEqual(container.fruit_Pear_args.n_pips, 13)
Example #12
0
    def test_create_gm_overrides(self):
        """Overridden class types change the components and the serialized config."""
        overrides = {
            "feature_aggregator_class_type":
                "AngleWeightedIdentityFeatureAggregator",
            "implicit_function_class_type": "IdrFeatureField",
            "renderer_class_type": "LSTMRenderer",
        }
        args = get_default_args(GenericModel)
        for key, value in overrides.items():
            setattr(args, key, value)

        gm = GenericModel(**args)
        self.assertIsInstance(gm.renderer, LSTMRenderer)
        self.assertIsInstance(gm.feature_aggregator,
                              AngleWeightedIdentityFeatureAggregator)
        self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField)
        self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
        self.assertFalse(hasattr(gm, "implicit_function"))

        # Turn the live model back into a config, prune it, and compare its
        # YAML with the checked-in expectation.
        instance_args = OmegaConf.structured(gm)
        remove_unused_components(instance_args)
        yaml_text = OmegaConf.to_yaml(instance_args, sort_keys=False)
        if DEBUG:
            # Write the actual output beside the expected file for inspection.
            (DATA_DIR / "overrides.yaml_").write_text(yaml_text)
        expected_text = (DATA_DIR / "overrides.yaml").read_text()
        self.assertEqual(yaml_text, expected_text)
Example #13
0
    def test_remove_unused_components(self):
        """Pruning keeps only the args of the selected implementations."""
        cfg = get_default_args(MainTest)
        cfg.n_ids = 32
        cfg.the_fruit_class_type = "Pear"
        cfg.the_second_fruit_class_type = "Banana"
        remove_unused_components(cfg)

        expected_keys = [
            "n_ids",
            "n_reps",
            "the_fruit_Pear_args",
            "the_fruit_class_type",
            "the_second_fruit_Banana_args",
            "the_second_fruit_class_type",
        ]
        expected_yaml = textwrap.dedent("""\
            n_ids: 32
            n_reps: 8
            the_fruit_class_type: Pear
            the_fruit_Pear_args:
              n_pips: 13
            the_second_fruit_class_type: Banana
            the_second_fruit_Banana_args:
              pips: ???
              spots: ???
              bananame: ???
            """)
        self.assertEqual(sorted(cfg.keys()), expected_keys)

        # The pruned config has the expected structure ...
        expected = OmegaConf.create(expected_yaml)
        self.assertEqual(cfg, expected)
        # ... and serializes to exactly the expected YAML, in order.
        self.assertEqual(OmegaConf.to_yaml(cfg, sort_keys=False),
                         expected_yaml)

        # A config taken from a live instance prunes to the same result.
        main = MainTest(**cfg)
        instance_cfg = OmegaConf.structured(main)
        remove_unused_components(instance_cfg)
        self.assertEqual(sorted(instance_cfg.keys()), expected_keys)
        self.assertEqual(instance_cfg, expected)
Example #14
0
    def test_inheritance(self):
        """Replaceable members are inherited; Optional members are covered too."""

        class FruitBowl(ReplaceableBase):
            main_fruit: Fruit
            main_fruit_class_type: str = "Orange"

            # Overridden by LargeFruitBowl below, so never executed here.
            def __post_init__(self):
                raise ValueError("This doesn't get called")

        class LargeFruitBowl(FruitBowl):
            extra_fruit: Optional[Fruit]
            extra_fruit_class_type: str = "Kiwi"
            no_fruit: Optional[Fruit]
            no_fruit_class_type: Optional[str] = None

            def __post_init__(self):
                run_auto_creation(self)

        bowl_args = get_default_args(LargeFruitBowl)
        # The replaceable members themselves are not config keys.
        self.assertNotIn("extra_fruit", bowl_args)
        self.assertNotIn("main_fruit", bowl_args)
        large = LargeFruitBowl(**bowl_args)
        self.assertIsInstance(large.main_fruit, Orange)
        self.assertIsInstance(large.extra_fruit, Kiwi)
        self.assertIsNone(large.no_fruit)
        # Before pruning, args exist even for the unselected optional member.
        self.assertIn("no_fruit_Kiwi_args", bowl_args)

        # After pruning, the config builds the same objects but keeps only
        # the keys it actually needs.
        remove_unused_components(bowl_args)
        large2 = LargeFruitBowl(**bowl_args)
        self.assertIsInstance(large2.main_fruit, Orange)
        self.assertIsInstance(large2.extra_fruit, Kiwi)
        self.assertIsNone(large2.no_fruit)
        self.assertEqual(
            sorted(bowl_args.keys()),
            [
                "extra_fruit_Kiwi_args",
                "extra_fruit_class_type",
                "main_fruit_Orange_args",
                "main_fruit_class_type",
                "no_fruit_class_type",
            ],
        )
Example #15
0
    def test_enum(self):
        """Enum defaults keep their type and are validated by OmegaConf."""
        # Enum values must be preserved, i.e. OmegaConf's runtime type checks
        # must be in effect.

        class A(Enum):
            B1 = "b1"
            B2 = "b2"

        # The same checks run on a Configurable class, a callable with enum
        # arguments, and a plain class.
        class C(Configurable):
            a: A = A.B1

        def C_fn(a: A = A.B1):
            pass

        enable_get_default_args(C_fn)

        class C_cl:
            def __init__(self, a: A = A.B1) -> None:
                pass

        enable_get_default_args(C_cl)

        for target in (C, C_fn, C_cl):
            base = get_default_args(target)
            self.assertEqual(base.a, A.B1)

            # Merging by member *name* selects the matching enum member.
            replaced = OmegaConf.merge(base, {"a": "B2"})
            self.assertEqual(replaced.a, A.B2)

            # A value that is not one of the choices is rejected, even if it
            # is the str representation of one of them.
            with self.assertRaises(ValidationError):
                OmegaConf.merge(base, {"a": "b2"})

            # A YAML round trip preserves the enum default.
            as_yaml = OmegaConf.to_yaml(base)
            remerged = OmegaConf.merge(base, OmegaConf.create(as_yaml))
            self.assertEqual(remerged.a, A.B1)
Example #16
0
    def test_doc(self):
        """Exercise the example case from the docstring (presumably the config module's)."""

        class A(ReplaceableBase):
            k: int = 1

        @registry.register
        class A1(A):
            m: int = 3

        @registry.register
        class A2(A):
            n: str = "2"

        class B(Configurable):
            a: A
            a_class_type: str = "A2"

            def __post_init__(self):
                run_auto_creation(self)

        defaults = get_default_args(B)
        # The member itself is not a config key; it is built in __post_init__.
        self.assertNotIn("a", defaults)
        b = B(**defaults)
        # a_class_type defaults to "A2", so an A2 instance is created.
        self.assertEqual(b.a.n, "2")
Example #17
0
 def test_get_default_args_readonly(self):
     """Mutating returned defaults leaves the reference instances untouched."""
     for klass in (MockDataclass, MockClassWithInit):
         field_list = get_default_args(klass)["field_list_type"]
         field_list.append(13)
         self.assertEqual(self._instances[klass].field_list_type, [])
Example #18
0
 def setUp(self) -> None:
     """Seed torch and call get_default_args on both SRN implicit-function classes."""
     torch.manual_seed(42)
     # NOTE(review): the results are discarded, so these calls are presumably
     # made for their side effect of preparing the classes.
     for implicit_fn_cls in (SRNHyperNetImplicitFunction, SRNImplicitFunction):
         get_default_args(implicit_fn_cls)
Example #19
0
    def test_redefine(self):
        """Re-registering a class under an existing name switches to the new one.

        Covers redefinition with and without @dataclass. It also shows that a
        class registered after defaults were first generated is not picked up.
        """
        class FruitBowl(ReplaceableBase):
            main_fruit: Fruit
            main_fruit_class_type: str = "Grape"

            def __post_init__(self):
                run_auto_creation(self)

        @registry.register
        @dataclass
        class Grape(Fruit):
            large: bool = False

            def get_color(self):
                return "red"

            def __post_init__(self):
                # Expected not to run: a replacement Grape is registered
                # before any FruitBowl is constructed.
                raise ValueError("This doesn't get called")

        # Defaults are captured while the first Grape (large=False) is registered.
        bowl_args = get_default_args(FruitBowl)

        # 1. Redefine Grape with the @dataclass decorator; constructing the
        # bowl warns and uses the new implementation.
        @registry.register
        @dataclass
        class Grape(Fruit):  # noqa: F811
            large: bool = True

            def get_color(self):
                return "green"

        with self.assertWarnsRegex(
                UserWarning, "New implementation of Grape is being chosen."):
            bowl = FruitBowl(**bowl_args)
        self.assertIsInstance(bowl.main_fruit, Grape)

        # Redefining the class does not change the defaults: `large` was
        # already captured (as False) in bowl_args.
        self.assertEqual(bowl.main_fruit.large, False)

        # But the override worked.
        self.assertEqual(bowl.main_fruit.get_color(), "green")

        # 2. Try redefining without the dataclass modifier.
        # This relies on the fact that default creation processes the class
        # (otherwise the error messages would be incomprehensible).
        @registry.register
        class Grape(Fruit):  # noqa: F811
            large: bool = True

        # Only the warning is checked for this redefinition.
        with self.assertWarnsRegex(
                UserWarning, "New implementation of Grape is being chosen."):
            bowl = FruitBowl(**bowl_args)

        # 3. Adding a new class doesn't get picked up, because the first
        # get_default_args call has frozen FruitBowl. This is intrinsic to
        # the way dataclass and expand_args_fields work in-place but
        # expand_args_fields is not pure - it depends on the registry.
        @registry.register
        class Fig(Fruit):
            pass

        bowl_args2 = get_default_args(FruitBowl)
        self.assertIn("main_fruit_Grape_args", bowl_args2)
        self.assertNotIn("main_fruit_Fig_args", bowl_args2)