def train_model_distributed(config):
    """Validate the distributed settings in ``config`` and start training.

    When CUDA is disabled in the config or not available, ``run_single`` is
    called directly in this process with rank 0 and a world size of 1.
    Otherwise a ``.dist_sync`` temporary file is created, its path is turned
    into a ``file://`` init method, and ``run_single`` is handed to ``spawn``
    together with the serialized config and the requested world size.

    Args:
        config: object exposing ``use_cuda_if_available``,
            ``distributed_world_size`` and ``task``.

    Raises:
        AssertionError: if more than one worker is requested without usable
            CUDA, if the DisjointMultitask check below trips, or if more
            workers are requested than ``torch.cuda.device_count()`` reports.
    """
    world_size = config.distributed_world_size

    assert (
        config.use_cuda_if_available and torch.cuda.is_available()
    ) or world_size == 1, "distributed training is only available for GPU training"

    # NOTE(review): ``__name__`` of a nested class carries no dotted prefix, so
    # this comparison against "DisjointMultitask.Config" may never match --
    # confirm whether ``__qualname__`` was intended.
    assert (
        world_size == 1
        or not config.task.__class__.__name__ == "DisjointMultitask.Config"
    ), "Distributed training currently not supported for DisjointMultitask"

    # Both halves of the message must be f-strings; previously the second half
    # was a plain literal and printed the placeholder text verbatim.
    assert world_size == 1 or world_size <= torch.cuda.device_count(), (
        f"Only {torch.cuda.device_count()} GPUs are available, "
        f"{world_size} GPUs were requested"
    )

    print(f"\n=== Starting training, World size is {world_size}")

    if not config.use_cuda_if_available or not torch.cuda.is_available():
        run_single(0, config_to_json(PyTextConfig, config), 1, None)
    else:
        # delete=False: the file is left on disk after the block exits; only
        # its name is used, as the file:// init method given to the workers.
        with tempfile.NamedTemporaryFile(
            delete=False, suffix=".dist_sync"
        ) as sync_file:
            dist_init_method = "file://" + sync_file.name
            spawn(
                run_single,
                (
                    config_to_json(PyTextConfig, config),
                    world_size,
                    dist_init_method,
                ),
                world_size,
            )
def test_config_to_json_for_dict(self):
    """Round-trip a config whose nested config carries a free-form dict.

    The config is built from a plain dict, the dict field is checked against
    the input, and then the config is serialized to json and loaded again;
    the reloaded config must compare equal to the first one.
    """

    class TestConfigContainer(ConfigBase):
        class TestConfig(ConfigBase):
            a_dict: Dict[str, Any]

        test_config: TestConfig

    expected_dict = {
        "param2": {"nested_param1": 10, "nested_param2": 2},
        "param1": "val1",
    }
    raw_config = {"test_config": {"a_dict": expected_dict}}

    loaded = serialize.config_from_json(TestConfigContainer, raw_config)
    # The dict field must come through loading unchanged.
    self.assertEqual(loaded.test_config.a_dict, expected_dict)

    # Serialize and load again; nothing may change on the way.
    as_json = serialize.config_to_json(TestConfigContainer, loaded)
    reloaded = serialize.config_from_json(TestConfigContainer, as_json)
    self.assertEqual(loaded, reloaded)
def update_config(context):
    """
    Load the config through ``context.obj.load_config()``, serialize it with
    ``config_to_json`` and print it as indented json with sorted keys.
    """
    loaded = context.obj.load_config()
    rendered = json.dumps(
        config_to_json(PyTextConfig, loaded), sort_keys=True, indent=2
    )
    print(rendered)
def test_config_to_json_for_union(self):
    """Round-trip a config whose nested config holds a Union-typed field.

    The field is declared as ``Union[Boo, List[Boo]]`` and the input supplies
    it under the "boo" key. After loading, the ``aye`` value is checked, then
    the config is serialized to json and loaded again; the reloaded config
    must compare equal to the first one.
    """

    class Boo(ConfigBase):
        aye: str

    class TestConfigContainer(ConfigBase):
        class TestConfig(ConfigBase):
            a_union: Union[Boo, List[Boo]]

        test_config: TestConfig

    expected_aye = "abc"
    raw_config = {"test_config": {"a_union": {"boo": {"aye": expected_aye}}}}

    loaded = serialize.config_from_json(TestConfigContainer, raw_config)
    # The union member must have been read with the supplied value.
    self.assertEqual(loaded.test_config.a_union.aye, expected_aye)

    # Serialize and load again; nothing may change on the way.
    as_json = serialize.config_to_json(TestConfigContainer, loaded)
    reloaded = serialize.config_from_json(TestConfigContainer, as_json)
    self.assertEqual(loaded, reloaded)
def test_config_to_json_for_nested_objectsboo2(self):
    """Round-trip a config holding a list of configs that nest another config.

    ``a_list`` is a ``List[Boo2]`` and every ``Boo2`` wraps a ``Boo3``. After
    loading, the innermost string is checked, then the config is serialized
    to json and loaded again; the reloaded config must compare equal to the
    first one.
    """

    class Boo3(ConfigBase):
        ayeaye: str

    class Boo2(ConfigBase):
        aye: Boo3

    class TestConfigContainer(ConfigBase):
        class TestConfig(ConfigBase):
            a_list: List[Boo2]

        test_config: TestConfig

    expected_value = "abc"
    inner = {"ayeaye": expected_value}
    outer = {"aye": inner}
    raw_config = {"test_config": {"a_list": [outer]}}

    loaded = serialize.config_from_json(TestConfigContainer, raw_config)
    # The innermost field of the first list item must carry the input value.
    self.assertEqual(loaded.test_config.a_list[0].aye.ayeaye, expected_value)

    # Serialize and load again; nothing may change on the way.
    as_json = serialize.config_to_json(TestConfigContainer, loaded)
    reloaded = serialize.config_from_json(TestConfigContainer, as_json)
    self.assertEqual(loaded, reloaded)
def test_config_to_json_for_union_with_baa(self):
    """Round-trip a config whose Union field is given its second member.

    The field is declared as ``Union[Boo, Baa]`` and the input supplies it
    under the "baa" key. After loading, the ``oye`` value is checked, then the
    config is serialized to json and loaded again; the reloaded config must
    compare equal to the first one.
    """

    class Boo(ConfigBase):
        aye: str

    class Baa(ConfigBase):
        oye: int

    class TestConfigContainer(ConfigBase):
        class TestConfig(ConfigBase):
            a_union: Union[Boo, Baa]

        test_config: TestConfig

    expected_oye = 1
    raw_config = {"test_config": {"a_union": {"baa": {"oye": expected_oye}}}}

    loaded = serialize.config_from_json(TestConfigContainer, raw_config)
    # The union must have been read as the member carrying ``oye``.
    self.assertEqual(loaded.test_config.a_union.oye, expected_oye)

    # Serialize and load again; nothing may change on the way.
    as_json = serialize.config_to_json(TestConfigContainer, loaded)
    reloaded = serialize.config_from_json(TestConfigContainer, as_json)
    self.assertEqual(loaded, reloaded)
def gen_default_config(context, task_name, options):
    """
    Build a config for `task_name` via `gen_config_impl`, append any include
    directories found on `context.obj.include`, and print the result as
    indented json with sorted keys.

    Entries in `options` are forwarded to `gen_config_impl` as extra
    positional arguments. A TypeError from that call is reported through
    `eprint` and the process exits with status -1.
    """
    try:
        generated = gen_config_impl(task_name, *options)
    except TypeError as err:
        eprint(
            "ERROR - Cannot create this config",
            "because some fields don't have a default value:",
            err,
        )
        sys.exit(-1)

    # Carry the requested include paths over, without trailing slashes.
    extra_dirs = context.obj.include
    if extra_dirs:
        if generated.include_dirs is None:
            generated.include_dirs = []
        generated.include_dirs.extend(entry.rstrip("/") for entry in extra_dirs)

    serialized = config_to_json(PyTextConfig, generated)
    print(json.dumps(serialized, sort_keys=True, indent=2))
def gen_config_impl(task_name, options):
    """Build the json form of a default config for ``task_name``.

    The task class is looked up with ``find_config_class`` and must resolve
    to exactly one class. Each entry of ``options`` is resolved the same way
    and swapped into the config at the location ``replace_components``
    reports for it.

    Raises:
        Exception: when the task or an option resolves to zero or several
            classes, or when no place is found for an option.
    """
    task_candidates = find_config_class(task_name)
    if not task_candidates:
        raise Exception(f"Unknown task class: {task_name}")
    if len(task_candidates) > 1:
        raise Exception(f"Multiple tasks named {task_name}: {task_candidates}")
    chosen_task = next(iter(task_candidates))
    root = PyTextConfig(task=chosen_task.Config())

    # Swap in every component named in options in place of the default.
    for option in options:
        candidates = find_config_class(option)
        if not candidates:
            raise Exception(f"Not a component class: {option}")
        if len(candidates) > 1:
            raise Exception(f"Multiple component named {option}: {candidates}")
        replacement = next(iter(candidates))

        path = replace_components(root, option, set(replacement.__bases__))
        if not path:
            raise Exception(f"Unknown option: {option}")
        eprint("INFO - Applying option:", "->".join(reversed(path)), "=", option)

        # path[0] is the attribute to overwrite; the remaining entries, read
        # back to front, lead from root to the object that owns it.
        owner = root
        for attr_name in reversed(path[1:]):
            owner = getattr(owner, attr_name)

        # Instantiate the nested Config when the class has one, otherwise the
        # class itself.
        if hasattr(replacement, "Config"):
            factory = replacement.Config
        else:
            factory = replacement
        setattr(owner, path[0], factory())

    return config_to_json(PyTextConfig, root)
def format_config_rst(config):
    """Render the reStructuredText documentation for one config class.

    ``config.config`` is the class being documented. The output is made of
    these sections, with empty ones dropped and the rest separated by blank
    lines:

    * a big header with the class name,
    * a "Component" line, only when the class has ``__COMPONENT__``,
    * the class directive with its bases, docstring lines and one entry per
      attribute returned by ``config_attrs``,
    * a "Subclasses" list, only when the class has ``__EXPANSIBLE__`` and
      ``subclass_configs`` returns at least one entry,
    * a "Default JSON" code block, or a warning when no JSON was produced.

    Args:
        config: wrapper object whose ``config`` attribute is the config class.

    Returns:
        str: the assembled reStructuredText.
    """
    attrs = config_attrs(config)

    def join(elements):
        # Sections are separated by two blank lines; empty sections vanish.
        return "\n\n\n".join(e for e in elements if e)

    bases = [Config.from_config(base) for base in config.config.__bases__]
    config_doc = "\n".join((
        f".. py:currentmodule:: {config.config.__module__}",
        f".. py:class:: {config.config.__name__}",
        I(1) + ":noindex:",
        "",
        I(1) + "**Bases:** " + ", ".join(
            f":class:`{base.name} <{base.path}>`\\ " for base in bases
        ),
        "",
        *(
            I(1) + line
            for line in GoogleDocstring(config.config.__doc__ or "").lines()
        ),
        "",
        # One group per attribute: a "name: type [= default]" line, the
        # attribute's docstring lines (or a lone backslash-space placeholder
        # when it has none) and a blank separator.
        *itertools.chain.from_iterable((
            I(1) + f"**{attr.name}**: {marked_up_type_name(attr.type)}" + (
                f" = {marked_up_default_value(attr.default)}"
                if attr.default is not NO_DEFAULT else ""
            ),
            *(I(2) + line for line in attr.docstring or ("\\ ", )),
            "",
        ) for attr in attrs),
    ))

    # Best effort: if the class cannot be instantiated or serialized, the
    # error is printed and the warning section is emitted instead of JSON.
    try:
        config_json = json.dumps(
            config_to_json(config.config, config.config()), indent=4
        )
    except Exception as e:
        print(e)
        config_json = ""

    subclasses = sorted(subclass_configs(config.config), key=lambda c: c.path)
    return join((
        rst_big_header(config.config.__name__),
        # The trailing backslash is written as an escaped "\\ " like the
        # other role suffixes in this function; a bare backslash-space is an
        # invalid escape sequence in a non-raw string.
        *((
            f"**Component:** :class:`{config.config.__COMPONENT__.__name__} "
            + f" <{canonical_path(config.config.__COMPONENT__)}>`\\ ",
        ) if hasattr(config.config, "__COMPONENT__") else ()),
        config_doc,
        "\n".join((
            "**Subclasses**",
            *(
                I(1) + f"- :class:`{child.name} <{child.path}>`\\ "
                for child in subclasses
            ),
        )) if hasattr(config.config, "__EXPANSIBLE__") and subclasses else "",
        *((
            "**Default JSON**",
            ".. code-block:: json",
            "\n".join(I(1) + line for line in config_json.split("\n")),
        ) if config_json else ("\n".join((
            ".. warning::",
            I(1) + "This config has parameters with no default values.",
            I(1) + "We aren't yet able to generate functional JSON for it.",
        )), )),
    ))
def train_model_distributed(config, metric_channels: Optional[List[Channel]]):
    """Validate the distributed settings in ``config`` and start training.

    When CUDA is disabled in the config or not available, ``run_single`` is
    called directly in this process with rank 0, a world size of 1 and the
    given ``metric_channels``. Otherwise a ``.dist_sync`` temporary file is
    created, its path is turned into a ``file://`` init method, task metadata
    is prepared once with ``prepare_task_metadata``, and ``run_single`` is
    handed to ``spawn`` together with the requested world size.

    Args:
        config: object exposing ``use_cuda_if_available`` and
            ``distributed_world_size``.
        metric_channels: forwarded to ``run_single`` on the single-process
            path.

    Raises:
        AssertionError: if more than one worker is requested without usable
            CUDA, or if more workers are requested than
            ``torch.cuda.device_count()`` reports.
    """
    world_size = config.distributed_world_size

    assert (
        config.use_cuda_if_available and torch.cuda.is_available()
    ) or world_size == 1, "distributed training is only available for GPU training"

    # Both halves of the message must be f-strings; previously the second half
    # was a plain literal and printed the placeholder text verbatim.
    assert world_size == 1 or world_size <= torch.cuda.device_count(), (
        f"Only {torch.cuda.device_count()} GPUs are available, "
        f"{world_size} GPUs were requested"
    )

    print(f"\n=== Starting training, World size is {world_size}")

    if not config.use_cuda_if_available or not torch.cuda.is_available():
        run_single(
            rank=0,
            config_json=config_to_json(PyTextConfig, config),
            world_size=1,
            dist_init_method=None,
            metadata=None,
            metric_channels=metric_channels,
        )
    else:
        # delete=False: the file is left on disk after the block exits; only
        # its name is used, as the file:// init method given to the workers.
        with tempfile.NamedTemporaryFile(
            delete=False, suffix=".dist_sync"
        ) as sync_file:
            dist_init_method = "file://" + sync_file.name
            metadata = prepare_task_metadata(config)
            # NOTE(review): the last slot receives an empty list here rather
            # than ``metric_channels`` -- confirm this is intentional.
            spawn(
                run_single,
                (
                    config_to_json(PyTextConfig, config),
                    world_size,
                    dist_init_method,
                    metadata,
                    [],
                ),
                world_size,
            )
def test_component_subconfig_serialize(self):
    """Round-trip a JointModel config through json and inspect the result.

    The input sets ``foo`` and two models; the second model leaves ``bar``
    out. After load, serialize and load again, ``foo`` and the first model's
    ``m2s1s1`` must match the input and the second model's ``bar`` must be 3.
    """
    raw_text = """{ "foo": 5, "models": [{ "bar": 12, "m2s1s1": "thing" }, { "m2s1s1": "thing2" }] }"""
    parsed = json.loads(raw_text)

    first_pass = config_from_json(JointModel.Config, parsed)
    as_json = config_to_json(JointModel.Config, first_pass)
    second_pass = config_from_json(JointModel.Config, as_json)

    self.assertEqual(second_pass.foo, 5)
    first_model, second_model = second_pass.models[0], second_pass.models[1]
    self.assertEqual(first_model.m2s1s1, "thing")
    self.assertEqual(second_model.bar, 3)
def gen_default_config(context, task_name, options):
    """
    Build a config for `task_name` via `gen_config_impl` and print it as
    indented json with sorted keys.

    Entries in `options` are forwarded to `gen_config_impl` as extra
    positional arguments. A TypeError from that call is reported through
    `eprint` and the process exits with status -1.
    """
    try:
        generated = gen_config_impl(task_name, *options)
    except TypeError as err:
        eprint(
            "ERROR - Cannot create this config",
            "because some fields don't have a default value:",
            err,
        )
        sys.exit(-1)

    rendered = json.dumps(
        config_to_json(PyTextConfig, generated), sort_keys=True, indent=2
    )
    print(rendered)
def test_serialize_union_with_expansible_component(self):
    """Serialize a config built from component configs and load it back.

    In the serialized form each of the four fields is expected to hold its
    values under a key named after the component class. After loading that
    form again, the same values must be reachable directly on the fields.
    """
    original = TestConfig(
        model=ModelFoo.Config(),
        model_bar_x=ModelBar.Config(),
        model_union=ModelBar1.Config(),
        datahandler=SubDataHandler.Config(),
    )

    serialized = config_to_json(TestConfig, original)
    print(serialized)
    self.assertEqual(serialized["model"]["ModelFoo"]["foo"], 2)
    self.assertEqual(serialized["model_bar_x"]["ModelBar"]["bar"], "bar")
    self.assertEqual(serialized["model_union"]["ModelBar1"]["bar"], "bar1")
    self.assertEqual(serialized["datahandler"]["SubDataHandler"]["foo"], 3)

    restored = config_from_json(TestConfig, serialized)
    self.assertEqual(restored.model.foo, 2)
    self.assertEqual(restored.model_bar_x.bar, "bar")
    self.assertEqual(restored.model_union.bar, "bar1")
    self.assertEqual(restored.datahandler.foo, 3)
def gen_config_impl(task_name, options):
    """Build the json form of a default config for ``task_name``.

    Every option and the task name are first passed through ``locate`` and
    the results given to ``register_tasks``. The task class is then looked up
    with ``find_config_class`` and must resolve to exactly one class; its
    ``example_config`` is used when present, otherwise its ``Config``. Each
    entry of ``options`` is resolved the same way and swapped into the config
    at the location ``replace_components`` reports for it.

    Raises:
        Exception: when the task or an option resolves to zero or several
            classes, or when no place is found for an option.
    """
    # Import the classes the parameters refer to: options first, task last.
    located = [locate(option) for option in options]
    located.append(locate(task_name))
    register_tasks(located)

    task_candidates = find_config_class(task_name)
    if not task_candidates:
        raise Exception(f"Unknown task class: {task_name} "
                        "(try fully qualified class name?)")
    if len(task_candidates) > 1:
        raise Exception(f"Multiple tasks named {task_name}: {task_candidates}")
    chosen_task = next(iter(task_candidates))
    make_task_config = getattr(chosen_task, "example_config", chosen_task.Config)
    root = PyTextConfig(task=make_task_config(), version=LATEST_VERSION)

    # Swap in every component named in options in place of the default.
    for option in options:
        candidates = find_config_class(option)
        if not candidates:
            raise Exception(f"Not a component class: {option}")
        if len(candidates) > 1:
            raise Exception(f"Multiple component named {option}: {candidates}")
        replacement = next(iter(candidates))

        path = replace_components(root, option, set(replacement.__bases__))
        if not path:
            raise Exception(f"Unknown option: {option}")
        eprint("INFO - Applying option:", "->".join(reversed(path)), "=", option)

        # path[0] is the attribute to overwrite; the remaining entries, read
        # back to front, lead from root to the object that owns it.
        owner = root
        for attr_name in reversed(path[1:]):
            owner = getattr(owner, attr_name)

        # Instantiate the nested Config when the class has one, otherwise the
        # class itself.
        if hasattr(replacement, "Config"):
            factory = replacement.Config
        else:
            factory = replacement
        setattr(owner, path[0], factory())

    return config_to_json(PyTextConfig, root)