Beispiel #1
0
def test_divide_nested_grid_search_options_nested_options(make_classes):
    """Divide a config whose grid option itself contains grid options.

    ``akw2`` is a ``!g`` list of two ``!B`` entries, each carrying its own
    inner ``!g`` option. The divided result is compared (by ``repr``) with
    two configs, one per outer entry, that still contain their inner ``!g``.
    """
    A, B = make_classes
    nested_yaml = """
!A
akw1: 8
akw2: !g
  - !B
    bkw1: !g [4, 5]
    bkw2: 'first'
  - !B
    bkw1: 3
    bkw2: !g ['second', 'third']
"""
    first_yaml = """
!A
akw1: 8
akw2: !B
  bkw1: !g [4, 5]
  bkw2: 'first'
"""
    second_yaml = """
!A
akw1: 8
akw2: !B
  bkw1: 3
  bkw2: !g ['second', 'third']
"""
    # The documents are loaded in the order they are declared above.
    nested, first, second = (
        yaml.load(document)
        for document in (nested_yaml, first_yaml, second_yaml)
    )
    divided = list(divide_nested_grid_search_options(nested))
    assert repr(divided) == repr([first, second])
Beispiel #2
0
def test_registrable_factory_roundtrip_alias(make_aliased_classes):
    """Load an aliased factory tag, inspect the result and dump it back.

    The document uses the ``!a_class.some_factory`` and ``!b_`` tags. After
    the loaded attributes are checked, the dumped text must equal the
    expected document exactly.
    """
    A, B = make_aliased_classes

    source = """a: !a_class.some_factory
  akw1: 8
  akw2: !b_
    bkw1: 2
    bkw2: hello world
"""
    expected_dump = """a: !a_class.some_factory
  akw1: 8
  akw2: !b_
    bkw1: 2
    bkw2: hello world
"""
    loaded = yaml.load(source)
    obj = loaded['a']

    # Attribute checks on the loaded object and its nested value.
    assert obj.akw1 == 8
    assert obj.akw2 is not None
    assert hasattr(obj.akw2, "bkw1")
    assert obj.akw2.bkw1 == 2
    assert isinstance(obj, A)

    # Round trip: dumping the loaded config must reproduce the text.
    with StringIO() as buffer:
        yaml.dump(loaded, buffer)
        dumped = buffer.getvalue()
    assert dumped == expected_dump
def complex_builder(from_config):
    """Build three nested ``ComposableTorchStateful`` objects.

    Parameters
    ----------
    from_config:
        If truthy, the object is obtained by loading the YAML document
        below and calling the loaded result; otherwise the same nesting
        is constructed directly.

    Returns
    -------
    The outermost object of the nesting.

    """
    if not from_config:
        # Wrap from the inside out: every level receives the previous
        # level, an integer (2021, 2022, 2023) and a fresh Linear(2, 2).
        node = BasicStateful()
        for number in (2021, 2022, 2023):
            node = ComposableTorchStateful(node, number, torch.nn.Linear(2, 2))
        return node

    config = """
!ComposableTorchStateful
a: !ComposableTorchStateful
  a: !ComposableTorchStateful
    a: !BasicStateful {}
    b: 2021
    c: !torch.Linear
      in_features: 2
      out_features: 2
  b: 2022
  c: !torch.Linear
    in_features: 2
    out_features: 2
b: 2023
c: !torch.Linear
  in_features: 2
  out_features: 2
"""
    loaded = yaml.load(config)
    return loaded()
Beispiel #4
0
    def test_add_extensions_metadata_2(self):
        """Test that add_extensions_metadata doesn't add extensions that are not used.

        In this case we will use a config containing torch, but we will make_component
        on torch so that it can be compiled. After that, we add_extensions_metadata with
        torch, which is a valid extension for the config (redundant, but valid).

        """
        torch_tag_prefix = "torch"
        make_component(torch.nn.Module, torch_tag_prefix, only_module='torch.nn')

        config = """
        !torch.Linear
          in_features: 2
          out_features: 2
        """
        schema = yaml.load(config)

        # Passing only the torch extension: exactly that mapping is expected.
        schema.add_extensions_metadata({"torch": "torch"})
        assert schema._extensions == {"torch": "torch"}

        # Passing the class-level extensions plus torch: the stored mapping
        # is still expected to contain torch only.
        extensions_with_torch = TestSerializationExtensions.EXTENSIONS.copy()
        extensions_with_torch.update({"torch": "torch"})
        schema.add_extensions_metadata(extensions_with_torch)
        assert schema._extensions == {"torch": "torch"}
Beispiel #5
0
 def _factory(from_config, x=-1):
     """Return a ``BasicStateful`` built with ``x``.

     When ``from_config`` is truthy the object comes from loading a
     ``!BasicStateful`` YAML document and calling the result; otherwise the
     class is instantiated directly.
     """
     if not from_config:
         return BasicStateful(x=x)
     loaded = yaml.load(f"!BasicStateful\nx: {x}\n")
     return loaded()
Beispiel #6
0
 def _factory(from_config, x=-1):
     """Return a ``RootTorch`` built with ``x``.

     When ``from_config`` is truthy the object comes from loading a
     ``!RootTorch`` YAML document and calling the result; otherwise the
     class is instantiated directly.
     """
     if not from_config:
         return RootTorch(x=x)
     loaded = yaml.load(f"!RootTorch\nx: {x}\n")
     return loaded()
Beispiel #7
0
 def _factory(from_config):
     """Return an instance of the enclosing scope's ``class_``.

     When ``from_config`` is truthy the instance comes from loading a YAML
     document tagged with the class name (with an empty mapping) and calling
     the result; otherwise ``class_`` is called with no arguments.
     """
     if not from_config:
         return class_()
     loaded = yaml.load(f"!{class_.__name__} {{}}\n")
     return loaded()
Beispiel #8
0
def test_divide_nested_grid_search_options_non_nested_options(make_classes):
    """A config with one non-nested ``!g`` option divides into itself.

    The divided list is compared (by ``repr``) with a one-element list
    holding the loaded config.
    """
    A, B = make_classes
    source = """
!A
akw1: 8
akw2: !B
  bkw1: !g [1, 3, 5]
  bkw2: 'hello'
"""
    loaded = yaml.load(source)
    pieces = list(divide_nested_grid_search_options(loaded))
    assert repr(pieces) == repr([loaded])
Beispiel #9
0
def test_component_override(make_classes):
    """A keyword passed when calling the schema replaces the YAML value.

    The document sets ``akw1: 8``; calling the loaded schema with
    ``akw1=9`` must yield an object whose ``akw1`` is 9.
    """
    A, B = make_classes

    source = """
!A
akw1: 8
akw2: !B
  bkw1: 1
  bkw2: 'test'
"""
    schema = yaml.load(source)
    instance = schema(akw1=9)
    assert instance.akw1 == 9
Beispiel #10
0
def test_registrable_roundtrip(make_classes):
    """Loading a document and dumping it again must reproduce the text."""
    A, B = make_classes

    source = """a: !A
  akw1: 8
  akw2: !B
    bkw1: 2
    bkw2: hello world
"""
    loaded = yaml.load(source)
    with StringIO() as buffer:
        yaml.dump(loaded, buffer)
        dumped = buffer.getvalue()
    assert dumped == source
Beispiel #11
0
def test_registrable_factory(make_classes):
    """Load a document tagged ``!A.some_factory`` and check its attributes."""
    A, B = make_classes

    source = """a: !A.some_factory
  akw1: 8
  akw2: !B
    bkw1: 2
    bkw2: hello world
"""
    loaded = yaml.load(source)
    obj = loaded['a']

    # The top-level keyword and the nested ``!B`` value must be present.
    assert obj.akw1 == 8
    assert obj.akw2 is not None
    assert hasattr(obj.akw2, "bkw1")
    assert obj.akw2.bkw1 == 2
Beispiel #12
0
def test_registrable_load_new_class(make_new_classes):
    """Load a document using the ``!a_class`` / ``!b_class`` tags."""
    A, B = make_new_classes

    source = """a: !a_class
  akw1: 8
  akw2: !b_class
    bkw1: 2
    bkw2: hello world
"""
    loaded = yaml.load(source)
    obj = loaded['a']

    # The top-level keyword and the nested value must be present.
    assert obj.akw1 == 8
    assert obj.akw2 is not None
    assert hasattr(obj.akw2, "bkw1")
    assert obj.akw2.bkw1 == 2
Beispiel #13
0
def test_component_basic_top_level(make_classes):
    """Call a schema loaded from a top-level ``!A`` document.

    The resulting object must expose the keyword values from the document,
    including the nested ``!B`` entry.
    """
    A, B = make_classes

    source = """
!A
akw1: 8
akw2: !B
  bkw1: 1
  bkw2: 'test'
"""
    schema = yaml.load(source)
    instance = schema()
    assert instance.akw1 == 8
    assert instance.akw2 is not None
    assert instance.akw2.bkw1 == 1
Beispiel #14
0
def test_registrable_load_context(make_namespace_classes):
    """Load a document using the namespaced ``!ns.A`` / ``!ns.B`` tags."""
    A, B = make_namespace_classes

    source = """a: !ns.A
  akw1: 8
  akw2: !ns.B
    bkw1: 2
    bkw2: hello world
"""
    loaded = yaml.load(source)
    obj = loaded['a']

    # The top-level keyword and the nested value must be present.
    assert obj.akw1 == 8
    assert obj.akw2 is not None
    assert hasattr(obj.akw2, "bkw1")
    assert obj.akw2.bkw1 == 2
Beispiel #15
0
def test_component_basic(make_classes):
    """Call a schema stored under the ``top`` key of a loaded document.

    The resulting object must expose the keyword values from the document,
    including the nested ``!B`` entry.
    """
    A, B = make_classes

    source = """
top: !A
  akw1: 8
  akw2: !B
    bkw1: 1
    bkw2: 'test'
"""

    loaded = yaml.load(source)
    schema = loaded['top']
    instance = schema()
    assert instance.akw1 == 8
    assert instance.akw2 is not None
    assert instance.akw2.bkw1 == 1
Beispiel #16
0
    def test_contains(self, make_classes):
        A, B = make_classes

        txt = """
top: !A
  akw1: 8
  akw2: !B
    bkw1: 1
    bkw2: 'test'
"""
        schema_a = yaml.load(txt)['top']
        assert schema_a.contains(schema_a, original_link=None)[0]
        present, updated_path = schema_a.contains(schema_a.akw2, original_link=None)
        print(f"present: {present}, updated_path: {updated_path}")
        assert present
        assert updated_path == ['akw2']
        assert not schema_a.akw2.contains(schema_a, original_link=None)[0]
Beispiel #17
0
def test_component_schema_dict_access(make_classes):
    """Read and write schema keywords via item access and ``keywords``.

    Values are read with ``[]``, overwritten with ``[]``, overwritten again
    through the ``keywords`` attribute, and the last value written must be
    the one seen on the object produced by calling the schema.
    """
    A, B = make_classes

    source = """
!A
akw1: 8
akw2: !B
  bkw1: 1
  bkw2: 'test'
"""
    schema = yaml.load(source)

    # Reading through item access.
    assert schema['akw1'] == 8
    assert schema['akw2']['bkw2'] == 'test'

    # Writing through item access.
    schema['akw2']['bkw1'] = 13
    assert schema['akw2']['bkw1'] == 13

    # Writing through the ``keywords`` attribute, then compiling.
    schema.keywords['akw2'].keywords['bkw1'] = 14
    instance = schema()
    assert instance.akw2.bkw1 == 14
Beispiel #18
0
def test_component_anchors_compile_to_same_instance(make_classes_2):
    """A YAML anchor and its alias must yield the very same object.

    ``one.akw2`` defines the anchor ``&theb`` and ``two.akw2`` references it
    with ``*theb``. After both entries are called, the two ``akw2``
    attributes must be identical, including a value set in between.
    """

    source = """
one: !A
  akw2: &theb !B
    bkw2: test
    bkw1: 1
  akw1: 8
two: !A
  akw1: 8
  # Comment Here
  akw2: *theb
"""
    loaded = yaml.load(source)

    # Build the first object and mutate its nested value before the second
    # object is built.
    first = loaded["one"]()
    first.akw2.bkw1 = 6
    second = loaded["two"]()

    assert first.akw2 is second.akw2
    assert first.akw2.bkw1 == second.akw2.bkw1
Beispiel #19
0
def test_registrable_roundtrip_new_default(make_new_classes):
    """Round trip a document using the ``!a_class`` and ``!b_`` tags.

    Dumping the loaded config must produce the expected document exactly.
    """
    A, B = make_new_classes

    source = """a: !a_class
  akw1: 8
  akw2: !b_
    bkw1: 2
    bkw2: hello world
"""
    expected_dump = """a: !a_class
  akw1: 8
  akw2: !b_
    bkw1: 2
    bkw2: hello world
"""
    loaded = yaml.load(source)
    with StringIO() as buffer:
        yaml.dump(loaded, buffer)
        dumped = buffer.getvalue()
    assert dumped == expected_dump
Beispiel #20
0
def complex_builder_nontorch_root(from_config, schema=False, x=-1):
    """Build a ``ComposableContainer`` wrapping three nested stateful objects.

    Parameters
    ----------
    from_config:
        If truthy, the result comes from loading the YAML document below;
        otherwise the objects are constructed directly.
    schema:
        Only consulted when ``from_config`` is truthy. If truthy, the value
        returned by ``yaml.load`` is returned as is; otherwise that value is
        called and the outcome is returned.
    x:
        Value handed to the innermost ``BasicStateful`` as ``x``.

    Returns
    -------
    The loaded value, the result of calling it, or the directly
    constructed ``ComposableContainer``, depending on the flags above.

    """
    if from_config:
        config = """
!ComposableContainer
item:
  !ComposableTorchStatefulPrime
  a: !ComposableTorchStateful
    a: !ComposableTorchStateful
      a: !BasicStateful
        x: {}
      b: 2021
      c: !torch.Linear
        in_features: 2
        out_features: 2
    b: 2022
    c: !torch.Linear
      in_features: 2
      out_features: 2
  b: 2023
  c: !torch.Linear
    in_features: 2
    out_features: 2
"""
        # str.format returns a new string; the result has to be kept,
        # otherwise the ``{}`` placeholder stays in the document and ``x``
        # is loaded as an empty mapping instead of the requested value.
        config = config.format(x)
        obj = yaml.load(config)
        if not schema:
            obj = obj()
        return obj
    else:
        a1 = BasicStateful(x=x)
        b1 = 2021
        c1 = torch.nn.Linear(2, 2)
        a2 = ComposableTorchStateful(a1, b1, c1)
        b2 = 2022
        c2 = torch.nn.Linear(2, 2)
        a3 = ComposableTorchStateful(a2, b2, c2)
        b3 = 2023
        c3 = torch.nn.Linear(2, 2)
        # NOTE(review): the YAML above tags this level as
        # ComposableTorchStatefulPrime while this branch uses
        # ComposableTorchStateful; confirm the difference is intended.
        item = ComposableTorchStateful(a3, b3, c3)
        obj = ComposableContainer(item)
        return obj
Beispiel #21
0
    def compile_runnable(self, content: str) -> Runnable:
        """Compiles and returns the Runnable.

        IMPORTANT: This method should run after all
        extensions were registered.

        Parameters
        ----------
        content: str
            The runnable, as a YAML string

        Returns
        -------
        Runnable
            The object loaded from ``content``.

        Raises
        ------
        ValueError
            If the loaded object is not a ``Runnable``.

        """
        ret: Any = yaml.load(content)
        if not isinstance(ret, Runnable):
            raise ValueError("Tried to run a non-Runnable")
        # cast() has no runtime effect and only informs the type checker
        # through its return value, so the result must be used rather than
        # discarded in a bare expression statement.
        return cast(Runnable, ret)
Beispiel #22
0
def test_component_dumping_with_defaults_and_comments(make_classes_2):
    """Dump a compiled object and compare it with the expected document.

    The input carries a comment and a ``!!str`` tag; the expected output
    has neither and additionally lists ``bkw3: 99``.
    """
    A, B = make_classes_2

    source = """!A
akw1: 8
# Comment Here
akw2: !B
  bkw1: 1
  bkw2: !!str test
"""
    expected_dump = """!A
akw1: 8
akw2: !B
  bkw1: 1
  bkw2: test
  bkw3: 99
"""
    schema = yaml.load(source)
    instance = schema()
    with StringIO() as buffer:
        yaml.dump(instance, buffer)
        dumped = buffer.getvalue()
    assert expected_dump == dumped
    def test_module_save_and_load_single_instance_appears_twice(self, make_classes_2):
        txt = """
!C
one: !A
  akw2: &theb !B
    bkw2: test
    bkw1: 1
  akw1: 8
two: !A
  akw1: 8
  # Comment Here
  akw2: *theb
"""
        c = yaml.load(txt)()
        c.one.akw2.bkw1 = 6
        assert c.one.akw2 is c.two.akw2
        assert c.one.akw2.bkw1 == c.two.akw2.bkw1
        with tempfile.TemporaryDirectory() as path:
            save(c, path)
            state = load_state_from_file(path)
            loaded_c = load(path)
        assert loaded_c.one.akw2 is loaded_c.two.akw2
        assert loaded_c.one.akw2.bkw1 == loaded_c.two.akw2.bkw1
Beispiel #24
0
def schema_builder():
    """Load the ``!Basic`` YAML document and return the loaded value uncalled."""
    return yaml.load("\n!Basic\n")
Beispiel #25
0
 def _factory():
     """Load the ``!A {}`` document and return the result of calling it."""
     loaded = yaml.load("!A {}\n")
     return loaded()