예제 #1
0
 def test_build_elastic_role_flag_args(self):
     """Check launcher args for a role built with ``no_python=False``.

     The expected list holds the launcher module, the etcd rendezvous
     backend, the app-id macro as rendezvous id, and the script -- with no
     ``--no_python`` flag.
     """
     builder = ElasticRole("test_role", no_python=False)
     role = builder.runs("user_script.py")

     launcher = ["-m", "torchelastic.distributed.launch"]
     rendezvous = ["--rdzv_backend", "etcd", "--rdzv_id", macros.app_id]
     expected_args = launcher + rendezvous + ["user_script.py"]

     self.assertEqual(expected_args, role.args)
예제 #2
0
 def test_build_elastic_role_img_root_already_in_entrypoint(self):
     """A script path already under ``macros.img_root`` must come back unchanged.

     The joined path is expected verbatim as the final launcher argument,
     after the rendezvous settings and the ``--role`` pair.
     """
     script_path = os.path.join(macros.img_root, "user_script.py")
     role = ElasticRole("test_role", no_python=False).runs(script_path)

     expected_args = [
         "-m", "torchelastic.distributed.launch",
         "--rdzv_backend", "etcd",
         "--rdzv_id", macros.app_id,
         "--role", "test_role",
         script_path,
     ]
     self.assertEqual(expected_args, role.args)
예제 #3
0
 def test_build_elastic_role(self):
     """Build a fully configured ElasticRole and check every resulting field.

     Covers the launcher args (including the bare ``--no_python`` flag and the
     ``--role`` pair), the ``python`` entrypoint, env vars, container and
     replica count.
     """
     # runs: python -m torchelastic.distributed.launch
     #                    --nnodes 2:4
     #                    --max_restarts 3
     #                    --no_python
     #                    --rdzv_backend etcd
     #                    --rdzv_id ${app_id}
     #                    --role elastic_trainer
     #                    /bin/echo hello world
     container = Container(image="test_image")
     container.ports(foo=8080)
     elastic_trainer = (
         ElasticRole(
             "elastic_trainer", nnodes="2:4", max_restarts=3, no_python=True
         )
         .runs("/bin/echo", "hello", "world", ENV_VAR_1="FOOBAR")
         .on(container)
         .replicas(2)
     )
     self.assertEqual("elastic_trainer", elastic_trainer.name)
     self.assertEqual("python", elastic_trainer.entrypoint)
     self.assertEqual(
         [
             "-m",
             "torchelastic.distributed.launch",
             "--nnodes",
             "2:4",
             "--max_restarts",
             "3",
             # no_python=True is emitted as a bare flag, not "--no_python True"
             "--no_python",
             "--rdzv_backend",
             "etcd",
             "--rdzv_id",
             macros.app_id,
             "--role",
             "elastic_trainer",
             "/bin/echo",
             "hello",
             "world",
         ],
         elastic_trainer.args,
     )
     # keyword args passed to runs() end up as the role's environment
     self.assertEqual({"ENV_VAR_1": "FOOBAR"}, elastic_trainer.env)
     self.assertEqual(container, elastic_trainer.container)
     self.assertEqual(2, elastic_trainer.num_replicas)
예제 #4
0
 def test_build_elastic_role_override_rdzv_params(self):
     """Explicit rendezvous settings must appear verbatim in the launcher args.

     ``rdzv_backend="zeus"`` and ``rdzv_id="foobar"`` are expected in place,
     followed by the script and its own arguments.
     """
     script_and_args = ("user_script.py", "--script_arg", "foo")
     role = ElasticRole(
         "test_role", nnodes="2:4", rdzv_backend="zeus", rdzv_id="foobar"
     ).runs(*script_and_args)

     expected_args = [
         "-m", "torchelastic.distributed.launch",
         "--nnodes", "2:4",
         "--rdzv_backend", "zeus",
         "--rdzv_id", "foobar",
         *script_and_args,
     ]
     self.assertEqual(expected_args, role.args)
예제 #5
0
    def test_json_serialization(self):
        """
        Round-trip an ElasticRole through ``dataclasses.asdict`` and back.

        The dict form is rebuilt into a plain ``Role``, with nested
        ``Container`` and ``Resource`` objects reconstructed by hand. The
        result must match the original field for field. ElasticRole is just a
        builder for a Role that runs ``torchelastic.distributed.launch``.
        """
        container = Container(
            image="user_image",
            resources={"default": Resource(cpu=1, gpu=0, memMB=512)},
        ).ports(tensorboard=8080)
        elastic_role = (
            ElasticRole(
                "test_role", nnodes="2:4", rdzv_backend="etcd", rdzv_id="foobar"
            )
            .runs("user_script.py", "--script_arg", "foo")
            .on(container)
            .replicas(3)
        )

        # asdict() flattens everything into plain dicts -- effectively JSON
        role_dict = dataclasses.asdict(elastic_role)
        container_dict = role_dict.pop("container")
        # nested dataclasses are not rebuilt automatically, so restore them
        container_dict["resources"] = {
            scheduler: Resource(**resource_dict)
            for scheduler, resource_dict in container_dict["resources"].items()
        }

        role = Role(**role_dict, container=Container(**container_dict))

        self.assertEqual(container, role.container)
        self.assertEqual(elastic_role.name, role.name)
        self.assertEqual(elastic_role.entrypoint, role.entrypoint)
        self.assertEqual(elastic_role.args, role.args)
        self.assertEqual(
            dataclasses.asdict(elastic_role), dataclasses.asdict(role)
        )