Example #1
0
def train_model_distributed(config):
    """Train a model, spawning one worker process per requested GPU.

    When CUDA is disabled in ``config`` or not available, training runs in
    this process via ``run_single`` with rank 0, world size 1 and no init
    method.  Otherwise ``config.distributed_world_size`` processes are
    launched with ``spawn``; they share a ``file://`` init method that points
    at a freshly created temporary file.

    Args:
        config: PyText config; serialized with
            ``config_to_json(PyTextConfig, config)`` before being passed to
            ``run_single``.

    Raises:
        AssertionError: if more than one process is requested without CUDA,
            for a ``DisjointMultitask`` task, or for more processes than
            ``torch.cuda.device_count()`` reports.
    """
    world_size = config.distributed_world_size
    use_cuda = config.use_cuda_if_available and torch.cuda.is_available()

    assert use_cuda or world_size == 1, (
        "distributed training is only available for GPU training"
    )
    # A nested class's ``__name__`` is only "Config"; ``__qualname__`` includes
    # the enclosing class ("DisjointMultitask.Config"), so compare that.
    assert (
        world_size == 1
        or config.task.__class__.__qualname__ != "DisjointMultitask.Config"
    ), "Distributed training currently not supported for DisjointMultitask"
    assert world_size == 1 or world_size <= torch.cuda.device_count(), (
        f"Only {torch.cuda.device_count()} GPUs are available, "
        f"{world_size} GPUs were requested"
    )

    print(f"\n=== Starting training, World size is {world_size}")
    if not use_cuda:
        run_single(0, config_to_json(PyTextConfig, config), 1, None)
    else:
        # delete=False keeps the file on disk once its handle is closed.
        # NOTE(review): the file is never removed afterwards -- confirm
        # whether cleanup after ``spawn`` returns is safe.
        with tempfile.NamedTemporaryFile(
            delete=False, suffix=".dist_sync"
        ) as sync_file:
            dist_init_method = "file://" + sync_file.name
            spawn(
                run_single,
                (
                    config_to_json(PyTextConfig, config),
                    world_size,
                    dist_init_method,
                ),
                world_size,
            )
Example #2
0
    def test_config_to_json_for_dict(self):
        """Round-trip a config whose only field is a ``Dict[str, Any]``.

        ``config_from_json`` must read the nested dict verbatim, and a
        ``config_to_json`` -> ``config_from_json`` round trip must produce an
        equal config.
        """

        class TestConfigContainer(ConfigBase):
            class TestConfig(ConfigBase):
                a_dict: Dict[str, Any]

            test_config: TestConfig

        expected_dict = {
            "param2": {"nested_param1": 10, "nested_param2": 2},
            "param1": "val1",
        }
        raw = {"test_config": {"a_dict": expected_dict}}

        loaded = serialize.config_from_json(TestConfigContainer, raw)
        # The dict field has to come through unchanged.
        self.assertEqual(loaded.test_config.a_dict, expected_dict)

        # Serializing and reading back must not alter anything.
        round_tripped = serialize.config_from_json(
            TestConfigContainer,
            serialize.config_to_json(TestConfigContainer, loaded),
        )
        self.assertEqual(loaded, round_tripped)
Example #3
0
def update_config(context):
    """
        Print the config obtained from ``context.obj.load_config()`` as JSON.

        The original docstring says the config is updated to the latest
        version; presumably ``load_config`` does that (not visible here).
        Output uses sorted keys and a two-space indent.
    """
    loaded = context.obj.load_config()
    print(json.dumps(config_to_json(PyTextConfig, loaded), sort_keys=True, indent=2))
    def test_config_to_json_for_union(self):
        """Round-trip a config whose field is ``Union[Boo, List[Boo]]``.

        The input JSON wraps the chosen member under the key ``"boo"``.  The
        ``Boo`` value must be read correctly, and a serialize/deserialize
        round trip must produce an equal config.
        """

        class Boo(ConfigBase):
            aye: str

        class TestConfigContainer(ConfigBase):
            class TestConfig(ConfigBase):
                a_union: Union[Boo, List[Boo]]

            test_config: TestConfig

        expected_aye = "abc"
        raw = {"test_config": {"a_union": {"boo": {"aye": expected_aye}}}}

        loaded = serialize.config_from_json(TestConfigContainer, raw)
        # The union field must have been populated as a Boo.
        self.assertEqual(loaded.test_config.a_union.aye, expected_aye)

        # Serializing and reading back must not alter anything.
        as_json = serialize.config_to_json(TestConfigContainer, loaded)
        self.assertEqual(
            loaded, serialize.config_from_json(TestConfigContainer, as_json)
        )
    def test_config_to_json_for_nested_objectsboo2(self):
        """Round-trip a config holding a ``List`` of nested config objects.

        Each list element is a ``Boo2`` whose ``aye`` field is itself a
        ``Boo3``.  The innermost string must be read correctly, and a
        serialize/deserialize round trip must produce an equal config.
        """

        class Boo3(ConfigBase):
            ayeaye: str

        class Boo2(ConfigBase):
            aye: Boo3

        class TestConfigContainer(ConfigBase):
            class TestConfig(ConfigBase):
                a_list: List[Boo2]

            test_config: TestConfig

        leaf_value = "abc"
        element = {"aye": {"ayeaye": leaf_value}}
        raw = {"test_config": {"a_list": [element]}}

        loaded = serialize.config_from_json(TestConfigContainer, raw)
        # Walk list -> Boo2 -> Boo3 down to the leaf string.
        self.assertEqual(loaded.test_config.a_list[0].aye.ayeaye, leaf_value)

        # Serializing and reading back must not alter anything.
        as_json = serialize.config_to_json(TestConfigContainer, loaded)
        self.assertEqual(
            loaded, serialize.config_from_json(TestConfigContainer, as_json)
        )
    def test_config_to_json_for_union_with_baa(self):
        """Round-trip a ``Union[Boo, Baa]`` field populated with a ``Baa``.

        The input JSON wraps the chosen member under the key ``"baa"``.  The
        integer field must be read correctly, and a serialize/deserialize
        round trip must produce an equal config.
        """

        class Boo(ConfigBase):
            aye: str

        class Baa(ConfigBase):
            oye: int

        class TestConfigContainer(ConfigBase):
            class TestConfig(ConfigBase):
                a_union: Union[Boo, Baa]

            test_config: TestConfig

        expected_oye = 1
        raw = {"test_config": {"a_union": {"baa": {"oye": expected_oye}}}}

        loaded = serialize.config_from_json(TestConfigContainer, raw)
        # The second union member (Baa) must have been selected.
        self.assertEqual(loaded.test_config.a_union.oye, expected_oye)

        # Serializing and reading back must not alter anything.
        as_json = serialize.config_to_json(TestConfigContainer, loaded)
        self.assertEqual(
            loaded, serialize.config_from_json(TestConfigContainer, as_json)
        )
Example #7
0
def gen_default_config(context, task_name, options):
    """
        Print, as JSON, a config for `task_name` filled with default values.

        Components named in `options` replace the defaults.  Every entry of
        `context.obj.include` is appended to the config's `include_dirs`
        (created if it is None) with trailing slashes stripped.  If building
        the config raises TypeError, an error is printed via `eprint` and
        the process exits with `sys.exit(-1)`.
    """
    # NOTE(review): the gen_config_impl definitions visible alongside this
    # function take `options` as one argument; this call unpacks it. Confirm
    # which gen_config_impl signature this version is paired with.
    try:
        cfg = gen_config_impl(task_name, *options)
    except TypeError as err:
        eprint(
            "ERROR - Cannot create this config",
            "because some fields don't have a default value:",
            err,
        )
        sys.exit(-1)

    extra_dirs = context.obj.include
    if extra_dirs:
        if cfg.include_dirs is None:
            cfg.include_dirs = []
        cfg.include_dirs.extend(d.rstrip("/") for d in extra_dirs)

    print(json.dumps(config_to_json(PyTextConfig, cfg), sort_keys=True, indent=2))
Example #8
0
def gen_config_impl(task_name, options):
    """Build a default ``PyTextConfig`` for `task_name` and return its JSON.

    Args:
        task_name: name resolved with ``find_config_class``; it must match
            exactly one task class.
        options: iterable of component names.  Each must match exactly one
            class.  That class replaces the component slot found by
            ``replace_components`` (searched by the class's direct bases).

    Returns:
        ``config_to_json(PyTextConfig, root)`` for the assembled config.

    Raises:
        Exception: for an unknown or ambiguous task or component name, or
            when an option matches no component slot.
    """
    task_classes = find_config_class(task_name)
    if not task_classes:
        raise Exception(f"Unknown task class: {task_name}")
    if len(task_classes) > 1:
        raise Exception(f"Multiple tasks named {task_name}: {task_classes}")
    (task_class,) = task_classes
    root = PyTextConfig(task=task_class.Config())

    # Swap in the requested components in place of the defaults.
    for opt in options:
        candidates = find_config_class(opt)
        if not candidates:
            raise Exception(f"Not a component class: {opt}")
        if len(candidates) > 1:
            raise Exception(f"Multiple component named {opt}: {candidates}")
        (component_cls,) = candidates

        path = replace_components(root, opt, set(component_cls.__bases__))
        if not path:
            raise Exception(f"Unknown option: {opt}")

        eprint("INFO - Applying option:", "->".join(reversed(path)), "=", opt)
        # path[0] names the attribute to overwrite; path[1:] lists its
        # ancestors innermost-first, so descend from root in reverse.
        parent = root
        for attr_name in reversed(path[1:]):
            parent = getattr(parent, attr_name)
        if hasattr(component_cls, "Config"):
            replacement = component_cls.Config()
        else:
            replacement = component_cls()
        setattr(parent, path[0], replacement)

    return config_to_json(PyTextConfig, root)
Example #9
0
def format_config_rst(config):
    """Render the reStructuredText page that documents one config class.

    Args:
        config: object whose ``.config`` attribute is the config class being
            documented; it is also passed to ``config_attrs`` and
            ``subclass_configs``.

    Returns:
        str: the non-empty sections joined by two blank lines.  In order:
        the title, an optional component link, the class description with
        one entry per attribute, an optional subclass list, and either the
        default JSON or a warning when it could not be generated.
    """
    attrs = config_attrs(config)
    cls = config.config

    def join(elements):
        # Skip empty sections; separate the rest with two blank lines.
        return "\n\n\n".join(e for e in elements if e)

    bases = [Config.from_config(base) for base in cls.__bases__]

    config_doc = "\n".join((
        f".. py:currentmodule:: {cls.__module__}",
        f".. py:class:: {cls.__name__}",
        I(1) + ":noindex:",
        "",
        I(1) + "**Bases:** "
        + ", ".join(f":class:`{base.name} <{base.path}>`\\ " for base in bases),
        "",
        *(I(1) + line for line in GoogleDocstring(cls.__doc__ or "").lines()),
        "",
        *itertools.chain.from_iterable(
            (
                I(1) + f"**{attr.name}**: {marked_up_type_name(attr.type)}"
                + (
                    f" = {marked_up_default_value(attr.default)}"
                    if attr.default is not NO_DEFAULT
                    else ""
                ),
                # Attributes without a docstring get a placeholder line.
                *(I(2) + line for line in attr.docstring or ("\\ ",)),
                "",
            )
            for attr in attrs
        ),
    ))

    try:
        config_json = json.dumps(config_to_json(cls, cls()), indent=4)
    except Exception as e:
        # Best effort: an empty string selects the "no default values"
        # warning section below instead of the JSON block.
        print(e)
        config_json = ""

    subclasses = sorted(subclass_configs(cls), key=lambda c: c.path)

    return join((
        rst_big_header(cls.__name__),
        # The trailing backslash-space must be written as a doubled backslash:
        # a bare backslash-space is an invalid Python escape sequence
        # (SyntaxWarning on 3.12+). The emitted text is unchanged.
        *(
            (
                f"**Component:** :class:`{cls.__COMPONENT__.__name__} "
                + f" <{canonical_path(cls.__COMPONENT__)}>`\\ ",
            )
            if hasattr(cls, "__COMPONENT__")
            else ()
        ),
        config_doc,
        "\n".join((
            "**Subclasses**",
            *(
                I(1) + f"- :class:`{child.name} <{child.path}>`\\ "
                for child in subclasses
            ),
        ))
        if hasattr(cls, "__EXPANSIBLE__") and subclasses
        else "",
        *(
            (
                "**Default JSON**",
                ".. code-block:: json",
                "\n".join(I(1) + line for line in config_json.split("\n")),
            )
            if config_json
            else (
                "\n".join((
                    ".. warning::",
                    I(1) + "This config has parameters with no default values.",
                    I(1) + "We aren't yet able to generate functional JSON for it.",
                )),
            )
        ),
    ))
Example #10
0
def train_model_distributed(config, metric_channels: Optional[List[Channel]]):
    """Train a model, spawning one worker process per requested GPU.

    When CUDA is disabled in ``config`` or not available, training runs in
    this process via ``run_single`` with rank 0, world size 1 and the given
    ``metric_channels``.  Otherwise task metadata is prepared once with
    ``prepare_task_metadata`` and ``config.distributed_world_size``
    processes are launched with ``spawn``.  They share a ``file://`` init
    method backed by a temporary file.

    Args:
        config: PyText config; serialized with
            ``config_to_json(PyTextConfig, config)`` for ``run_single``.
        metric_channels: channels forwarded to ``run_single`` in the
            single-process case only.

    Raises:
        AssertionError: if more than one process is requested without CUDA,
            or more processes than ``torch.cuda.device_count()`` reports.
    """
    world_size = config.distributed_world_size
    use_cuda = config.use_cuda_if_available and torch.cuda.is_available()

    assert use_cuda or world_size == 1, (
        "distributed training is only available for GPU training"
    )
    assert world_size == 1 or world_size <= torch.cuda.device_count(), (
        f"Only {torch.cuda.device_count()} GPUs are available, "
        f"{world_size} GPUs were requested"
    )

    print(f"\n=== Starting training, World size is {world_size}")
    if not use_cuda:
        run_single(
            rank=0,
            config_json=config_to_json(PyTextConfig, config),
            world_size=1,
            dist_init_method=None,
            metadata=None,
            metric_channels=metric_channels,
        )
    else:
        # delete=False keeps the file on disk once its handle is closed.
        # NOTE(review): the file is never removed afterwards.
        with tempfile.NamedTemporaryFile(
            delete=False, suffix=".dist_sync"
        ) as sync_file:
            dist_init_method = "file://" + sync_file.name
            metadata = prepare_task_metadata(config)
            # NOTE(review): workers get an empty channel list instead of
            # ``metric_channels`` -- presumably because channels cannot be
            # passed to child processes; confirm.
            spawn(
                run_single,
                (
                    config_to_json(PyTextConfig, config),
                    world_size,
                    dist_init_method,
                    metadata,
                    [],
                ),
                world_size,
            )
Example #11
0
 def test_component_subconfig_serialize(self):
     """Round-trip a JointModel config whose ``models`` list holds sub-configs.

     The second model omits ``bar``; after serialize/deserialize it is
     expected to read back as 3 (presumably that field's default).
     """
     raw_json = json.loads("""{
         "foo": 5,
         "models": [{
             "bar": 12,
             "m2s1s1": "thing"
         }, {
             "m2s1s1": "thing2"
         }]
     }""")
     parsed = config_from_json(JointModel.Config, raw_json)
     reparsed = config_from_json(
         JointModel.Config, config_to_json(JointModel.Config, parsed)
     )
     self.assertEqual(reparsed.foo, 5)
     self.assertEqual(reparsed.models[0].m2s1s1, "thing")
     self.assertEqual(reparsed.models[1].bar, 3)
Example #12
0
def gen_default_config(context, task_name, options):
    """
        Print, as JSON, a config for `task_name` filled with default values.

        Components named in `options` replace the defaults.  If building the
        config raises TypeError, an error is printed via `eprint` and the
        process exits with `sys.exit(-1)`.  `context` is not used here.
    """
    # NOTE(review): the gen_config_impl definitions visible alongside this
    # function take `options` as one argument; this call unpacks it. Confirm
    # which gen_config_impl signature this version is paired with.
    try:
        cfg = gen_config_impl(task_name, *options)
    except TypeError as err:
        eprint(
            "ERROR - Cannot create this config",
            "because some fields don't have a default value:",
            err,
        )
        sys.exit(-1)
    print(json.dumps(config_to_json(PyTextConfig, cfg), sort_keys=True, indent=2))
Example #13
0
 def test_serialize_union_with_expansible_component(self):
     """Serialize a TestConfig with component/union fields and read it back.

     In the JSON, each component value sits under its concrete class name
     (e.g. ``serialized["model"]["ModelFoo"]``).  Deserializing must restore
     the same field values.
     """
     config = TestConfig(
         model=ModelFoo.Config(),
         model_bar_x=ModelBar.Config(),
         model_union=ModelBar1.Config(),
         datahandler=SubDataHandler.Config(),
     )
     # Named ``serialized`` rather than ``json`` so the json module is not
     # shadowed inside this test.
     serialized = config_to_json(TestConfig, config)
     print(serialized)
     self.assertEqual(serialized["model"]["ModelFoo"]["foo"], 2)
     self.assertEqual(serialized["model_bar_x"]["ModelBar"]["bar"], "bar")
     self.assertEqual(serialized["model_union"]["ModelBar1"]["bar"], "bar1")
     self.assertEqual(serialized["datahandler"]["SubDataHandler"]["foo"], 3)

     restored = config_from_json(TestConfig, serialized)
     self.assertEqual(restored.model.foo, 2)
     self.assertEqual(restored.model_bar_x.bar, "bar")
     self.assertEqual(restored.model_union.bar, "bar1")
     self.assertEqual(restored.datahandler.foo, 3)
Example #14
0
def gen_config_impl(task_name, options):
    """Build a default ``PyTextConfig`` for `task_name` and return its JSON.

    Args:
        task_name: task class name; a fully qualified name lets ``locate``
            import it.  It must then match exactly one class through
            ``find_config_class``.
        options: iterable of component names.  Each must match exactly one
            class.  That class replaces the component slot found by
            ``replace_components`` (searched by the class's direct bases).

    Returns:
        ``config_to_json(PyTextConfig, root)``.  ``root`` uses the task's
        ``example_config`` factory when it has one (else ``Config``) and
        ``version=LATEST_VERSION``.

    Raises:
        Exception: for an unknown or ambiguous task or component name, or
            when an option matches no component slot.
    """
    # Import the option and task classes by name and register them so the
    # lookups below can find them.  NOTE(review): ``locate`` presumably
    # returns None for names it cannot resolve; confirm register_tasks
    # tolerates that.
    register_tasks([locate(name) for name in (*options, task_name)])

    task_classes = find_config_class(task_name)
    if not task_classes:
        raise Exception(
            f"Unknown task class: {task_name} "
            "(try fully qualified class name?)"
        )
    if len(task_classes) > 1:
        raise Exception(f"Multiple tasks named {task_name}: {task_classes}")
    (task_class,) = task_classes

    make_task_config = getattr(task_class, "example_config", task_class.Config)
    root = PyTextConfig(task=make_task_config(), version=LATEST_VERSION)

    # Swap in the requested components in place of the defaults.
    for opt in options:
        candidates = find_config_class(opt)
        if not candidates:
            raise Exception(f"Not a component class: {opt}")
        if len(candidates) > 1:
            raise Exception(f"Multiple component named {opt}: {candidates}")
        (component_cls,) = candidates

        path = replace_components(root, opt, set(component_cls.__bases__))
        if not path:
            raise Exception(f"Unknown option: {opt}")

        eprint("INFO - Applying option:", "->".join(reversed(path)), "=", opt)
        # path[0] names the attribute to overwrite; path[1:] lists its
        # ancestors innermost-first, so descend from root in reverse.
        parent = root
        for attr_name in reversed(path[1:]):
            parent = getattr(parent, attr_name)
        if hasattr(component_cls, "Config"):
            replacement = component_cls.Config()
        else:
            replacement = component_cls()
        setattr(parent, path[0], replacement)

    return config_to_json(PyTextConfig, root)