def test_divide_nested_grid_search_options_nested_options(make_classes):
    """An outer ``!g`` over a list of components is split into one config per item.

    The grid-search options nested inside each item (``bkw1`` in the first,
    ``bkw2`` in the second) are expected to remain in the divided configs.
    """
    _cls_a, _cls_b = make_classes
    txt = """
!A
akw1: 8
akw2: !g
  - !B
    bkw1: !g [4, 5]
    bkw2: 'first'
  - !B
    bkw1: 3
    bkw2: !g ['second', 'third']
"""
    txt_1 = """
!A
akw1: 8
akw2: !B
  bkw1: !g [4, 5]
  bkw2: 'first'
"""
    txt_2 = """
!A
akw1: 8
akw2: !B
  bkw1: 3
  bkw2: !g ['second', 'third']
"""
    whole = yaml.load(txt)
    expected = [yaml.load(txt_1), yaml.load(txt_2)]
    divided = list(divide_nested_grid_search_options(whole))
    assert repr(divided) == repr(expected)
def test_registrable_factory_roundtrip_alias(make_aliased_classes):
    """A factory tag on an aliased class loads correctly and dumps back with the alias."""
    A, _cls_b = make_aliased_classes
    txt = """a: !a_class.some_factory
  akw1: 8
  akw2: !b_
    bkw1: 2
    bkw2: hello world
"""
    # NOTE(review): as written this is identical to ``txt``; confirm whether
    # the input was meant to use a different (non-default) tag for B.
    txt_default_alias = """a: !a_class.some_factory
  akw1: 8
  akw2: !b_
    bkw1: 2
    bkw2: hello world
"""
    config = yaml.load(txt)
    built = config['a']
    assert built.akw1 == 8
    inner = built.akw2
    assert inner is not None
    assert hasattr(inner, "bkw1")
    assert inner.bkw1 == 2
    assert isinstance(built, A)
    with StringIO() as out:
        yaml.dump(config, out)
        assert out.getvalue() == txt_default_alias
def complex_builder(from_config):
    """Build three nested ``ComposableTorchStateful`` layers around a ``BasicStateful``.

    Parameters
    ----------
    from_config: bool
        If true, compile the object from a YAML config; otherwise construct
        the same structure directly in Python.

    Returns
    -------
    ComposableTorchStateful
        The outermost layer (``b=2023``), which wraps the ``b=2022`` layer,
        which in turn wraps the ``b=2021`` layer.
    """
    if not from_config:
        obj = BasicStateful()
        # Innermost layer first; each new layer wraps the previous object.
        for b_value in (2021, 2022, 2023):
            obj = ComposableTorchStateful(obj, b_value, torch.nn.Linear(2, 2))
        return obj
    config = """
!ComposableTorchStateful
a: !ComposableTorchStateful
  a: !ComposableTorchStateful
    a: !BasicStateful {}
    b: 2021
    c: !torch.Linear
      in_features: 2
      out_features: 2
  b: 2022
  c: !torch.Linear
    in_features: 2
    out_features: 2
b: 2023
c: !torch.Linear
  in_features: 2
  out_features: 2
"""
    return yaml.load(config)()
def test_add_extensions_metadata_2(self):
    """Check that add_extensions_metadata keeps only the extensions the config uses.

    The config below uses the ``torch`` tag prefix, which is registered with
    ``make_component`` so that the config can be loaded. Adding ``torch`` as
    extension metadata is redundant but valid. Adding a larger mapping that
    also contains the entries of ``EXTENSIONS`` must still leave only
    ``torch`` recorded.
    """
    torch_prefix = "torch"
    make_component(torch.nn.Module, torch_prefix, only_module='torch.nn')
    config = """
!torch.Linear
in_features: 2
out_features: 2
"""
    schema = yaml.load(config)

    schema.add_extensions_metadata({"torch": "torch"})
    assert schema._extensions == {"torch": "torch"}

    mixed_ext = TestSerializationExtensions.EXTENSIONS.copy()
    mixed_ext["torch"] = "torch"
    schema.add_extensions_metadata(mixed_ext)
    assert schema._extensions == {"torch": "torch"}
def _factory(from_config, x=-1):
    """Return a ``BasicStateful`` with the given ``x``.

    When ``from_config`` is true, the object is compiled from a one-line YAML
    config; otherwise it is constructed directly.
    """
    if not from_config:
        return BasicStateful(x=x)
    schema = yaml.load(f"!BasicStateful\nx: {x}\n")
    return schema()
def _factory(from_config, x=-1):
    """Return a ``RootTorch`` with the given ``x``.

    When ``from_config`` is true, the object is compiled from a one-line YAML
    config; otherwise it is constructed directly.
    """
    if not from_config:
        return RootTorch(x=x)
    schema = yaml.load(f"!RootTorch\nx: {x}\n")
    return schema()
def _factory(from_config):
    """Return a default-constructed instance of the enclosing ``class_``.

    When ``from_config`` is true, the instance is compiled from a YAML tag
    built from the class name with an empty mapping (``!Name {}``).
    """
    if not from_config:
        return class_()
    schema = yaml.load(f"!{class_.__name__} {{}}\n")
    return schema()
def test_divide_nested_grid_search_options_non_nested_options(make_classes):
    """A ``!g`` that is not nested under another ``!g`` leaves a single, unchanged config."""
    _cls_a, _cls_b = make_classes
    txt = """
!A
akw1: 8
akw2: !B
  bkw1: !g [1, 3, 5]
  bkw2: 'hello'
"""
    original = yaml.load(txt)
    divided = list(divide_nested_grid_search_options(original))
    assert repr(divided) == repr([original])
def test_component_override(make_classes):
    """Keyword arguments passed when calling a schema override values from the YAML."""
    _cls_a, _cls_b = make_classes
    txt = """
!A
akw1: 8
akw2: !B
  bkw1: 1
  bkw2: 'test'
"""
    overridden = yaml.load(txt)(akw1=9)
    assert overridden.akw1 == 9
def test_registrable_roundtrip(make_classes):
    """Loading and then dumping a registrable config reproduces the input text exactly."""
    _cls_a, _cls_b = make_classes
    txt = """a: !A
  akw1: 8
  akw2: !B
    bkw1: 2
    bkw2: hello world
"""
    loaded = yaml.load(txt)
    with StringIO() as buffer:
        yaml.dump(loaded, buffer)
        dumped = buffer.getvalue()
    assert dumped == txt
def test_registrable_factory(make_classes):
    """A ``!A.some_factory`` tag builds an object with the configured nested values."""
    _cls_a, _cls_b = make_classes
    txt = """a: !A.some_factory
  akw1: 8
  akw2: !B
    bkw1: 2
    bkw2: hello world
"""
    built = yaml.load(txt)['a']
    assert built.akw1 == 8
    inner = built.akw2
    assert inner is not None
    assert hasattr(inner, "bkw1")
    assert inner.bkw1 == 2
def test_registrable_load_new_class(make_new_classes):
    """Classes registered under custom tags (``!a_class``, ``!b_class``) load correctly."""
    _cls_a, _cls_b = make_new_classes
    txt = """a: !a_class
  akw1: 8
  akw2: !b_class
    bkw1: 2
    bkw2: hello world
"""
    built = yaml.load(txt)['a']
    assert built.akw1 == 8
    inner = built.akw2
    assert inner is not None
    assert hasattr(inner, "bkw1")
    assert inner.bkw1 == 2
def test_component_basic_top_level(make_classes):
    """A component tag at the document root loads as a schema that compiles when called."""
    _cls_a, _cls_b = make_classes
    txt = """
!A
akw1: 8
akw2: !B
  bkw1: 1
  bkw2: 'test'
"""
    compiled = yaml.load(txt)()
    assert compiled.akw1 == 8
    assert compiled.akw2 is not None
    assert compiled.akw2.bkw1 == 1
def test_registrable_load_context(make_namespace_classes):
    """Namespaced tags (``!ns.A``, ``!ns.B``) resolve to the registered classes."""
    _cls_a, _cls_b = make_namespace_classes
    txt = """a: !ns.A
  akw1: 8
  akw2: !ns.B
    bkw1: 2
    bkw2: hello world
"""
    built = yaml.load(txt)['a']
    assert built.akw1 == 8
    inner = built.akw2
    assert inner is not None
    assert hasattr(inner, "bkw1")
    assert inner.bkw1 == 2
def test_component_basic(make_classes):
    """A component stored under a mapping key compiles when its schema is called."""
    _cls_a, _cls_b = make_classes
    txt = """
top: !A
  akw1: 8
  akw2: !B
    bkw1: 1
    bkw2: 'test'
"""
    compiled = yaml.load(txt)['top']()
    assert compiled.akw1 == 8
    assert compiled.akw2 is not None
    assert compiled.akw2.bkw1 == 1
def test_contains(self, make_classes):
    """Check ``contains`` on a schema and on its nested child schema.

    A schema contains itself; the parent contains its ``akw2`` child and
    reports the path ``['akw2']``; the child does not contain the parent.
    """
    A, B = make_classes
    txt = """
top: !A
  akw1: 8
  akw2: !B
    bkw1: 1
    bkw2: 'test'
"""
    schema_a = yaml.load(txt)['top']
    assert schema_a.contains(schema_a, original_link=None)[0]
    present, updated_path = schema_a.contains(schema_a.akw2, original_link=None)
    # Surface diagnostics on failure instead of printing on every run.
    assert present, f"akw2 not found; updated_path: {updated_path}"
    assert updated_path == ['akw2']
    assert not schema_a.akw2.contains(schema_a, original_link=None)[0]
def test_component_schema_dict_access(make_classes):
    """Schemas support item access and assignment, and edits flow into compilation."""
    _cls_a, _cls_b = make_classes
    txt = """
!A
akw1: 8
akw2: !B
  bkw1: 1
  bkw2: 'test'
"""
    schema = yaml.load(txt)

    # Read access, including through a nested schema.
    assert schema['akw1'] == 8
    assert schema['akw2']['bkw2'] == 'test'

    # Write access through __setitem__.
    schema['akw2']['bkw1'] = 13
    assert schema['akw2']['bkw1'] == 13

    # Writing through ``keywords`` is picked up when the schema is compiled.
    schema.keywords['akw2'].keywords['bkw1'] = 14
    assert schema().akw2.bkw1 == 14
def test_component_anchors_compile_to_same_instance(make_classes_2):
    """Two schemas that share a YAML anchor compile to one shared child instance."""
    txt = """
one: !A
  akw2: &theb !B
    bkw2: test
    bkw1: 1
  akw1: 8
two: !A
  akw1: 8
  # Comment Here
  akw2: *theb
"""
    config = yaml.load(txt)
    first = config["one"]()
    # Mutate before compiling "two"; a shared child must reflect the change.
    first.akw2.bkw1 = 6
    second = config["two"]()
    assert first.akw2 is second.akw2
    assert first.akw2.bkw1 == second.akw2.bkw1
def test_registrable_roundtrip_new_default(make_new_classes):
    """Dumping a config with custom-registered classes emits the expected tags."""
    _cls_a, _cls_b = make_new_classes
    txt = """a: !a_class
  akw1: 8
  akw2: !b_
    bkw1: 2
    bkw2: hello world
"""
    # NOTE(review): as written this is identical to ``txt``; confirm whether
    # the input was meant to use a different tag for B (e.g. ``!b_class``).
    txt_default_alias = """a: !a_class
  akw1: 8
  akw2: !b_
    bkw1: 2
    bkw2: hello world
"""
    loaded = yaml.load(txt)
    with StringIO() as buffer:
        yaml.dump(loaded, buffer)
        dumped = buffer.getvalue()
    assert dumped == txt_default_alias
def complex_builder_nontorch_root(from_config, schema=False, x=-1):
    """Build a ``ComposableContainer`` wrapping three nested stateful layers.

    Parameters
    ----------
    from_config: bool
        If true, load the structure from YAML; otherwise construct it
        directly in Python.
    schema: bool
        Only used when ``from_config`` is true. If true, return the loaded
        schema without compiling it.
    x: int
        Value passed as ``x`` to the innermost ``BasicStateful``.

    Returns
    -------
    ComposableContainer, or its uncompiled schema when ``from_config`` and
    ``schema`` are both true.
    """
    if from_config:
        config = """
!ComposableContainer
item: !ComposableTorchStatefulPrime
  a: !ComposableTorchStateful
    a: !ComposableTorchStateful
      a: !BasicStateful
        x: {}
      b: 2021
      c: !torch.Linear
        in_features: 2
        out_features: 2
    b: 2022
    c: !torch.Linear
      in_features: 2
      out_features: 2
  b: 2023
  c: !torch.Linear
    in_features: 2
    out_features: 2
"""
        # str.format returns a new string; keep it. Otherwise the ``{}``
        # placeholder survives and YAML reads ``x`` as an empty mapping.
        config = config.format(x)
        # NOTE(review): the YAML uses ``!ComposableTorchStatefulPrime`` for the
        # outer item, while the Python branch below uses
        # ``ComposableTorchStateful``. Confirm this asymmetry is intentional.
        obj = yaml.load(config)
        if not schema:
            obj = obj()
        return obj
    else:
        a1 = BasicStateful(x=x)
        b1 = 2021
        c1 = torch.nn.Linear(2, 2)
        a2 = ComposableTorchStateful(a1, b1, c1)
        b2 = 2022
        c2 = torch.nn.Linear(2, 2)
        a3 = ComposableTorchStateful(a2, b2, c2)
        b3 = 2023
        c3 = torch.nn.Linear(2, 2)
        item = ComposableTorchStateful(a3, b3, c3)
        obj = ComposableContainer(item)
        return obj
def compile_runnable(self, content: str) -> Runnable:
    """Compiles and returns the Runnable.

    IMPORTANT: This method should run after all extensions were registered.

    Parameters
    ----------
    content: str
        The runnable, as a YAML string

    Returns
    -------
    Runnable
        The compiled experiment.

    Raises
    ------
    ValueError
        If the loaded YAML object is not a ``Runnable``.

    """
    ret: Any = yaml.load(content)
    if not isinstance(ret, Runnable):
        raise ValueError("Tried to run a non-Runnable")
    # ``cast`` does nothing at runtime; returning its result is what lets
    # type checkers treat the value as a Runnable.
    return cast(Runnable, ret)
def test_component_dumping_with_defaults_and_comments(make_classes_2):
    """Dumping a compiled object drops comments and explicit tags and includes defaults.

    ``bkw3`` is not given in the input but appears in the expected dump, and
    the ``!!str`` tag and the comment are absent from it.
    """
    _cls_a, _cls_b = make_classes_2
    txt = """!A
akw1: 8
# Comment Here
akw2: !B
  bkw1: 1
  bkw2: !!str test
"""
    txt_expected = """!A
akw1: 8
akw2: !B
  bkw1: 1
  bkw2: test
  bkw3: 99
"""
    compiled = yaml.load(txt)()
    with StringIO() as stream:
        yaml.dump(compiled, stream)
        assert stream.getvalue() == txt_expected
def test_module_save_and_load_single_instance_appears_twice(self, make_classes_2):
    """An object referenced twice via a YAML anchor is still one instance after save/load."""
    txt = """
!C
one: !A
  akw2: &theb !B
    bkw2: test
    bkw1: 1
  akw1: 8
two: !A
  akw1: 8
  # Comment Here
  akw2: *theb
"""
    c = yaml.load(txt)()
    c.one.akw2.bkw1 = 6
    assert c.one.akw2 is c.two.akw2
    assert c.one.akw2.bkw1 == c.two.akw2.bkw1
    with tempfile.TemporaryDirectory() as path:
        save(c, path)
        # Smoke check that the saved state can be read on its own; the
        # returned value is not inspected.
        load_state_from_file(path)
        loaded_c = load(path)
        assert loaded_c.one.akw2 is loaded_c.two.akw2
        assert loaded_c.one.akw2.bkw1 == loaded_c.two.akw2.bkw1
def schema_builder():
    """Return the uncompiled schema loaded from a bare ``!Basic`` tag."""
    return yaml.load("\n!Basic\n")
def _factory():
    """Compile and return an ``A`` built from an empty-mapping ``!A`` config."""
    schema = yaml.load("!A {}\n")
    return schema()