示例#1
0
文件: api_test.py 项目: d4l3k/elastic
 def _stop_workers(self, worker_group: WorkerGroup) -> None:
     """Simulate stopping the workers of ``worker_group``.

     The workers are fake, so there is nothing to terminate. Instead, the
     rendezvous info is wiped: the group's rank/world size and each worker's
     id, global rank and world size are reset to ``None``. Every call also
     increments ``self.stop_workers_call_count`` so tests can count calls.
     """
     # group-level rdzv info
     for group_field in ("group_rank", "group_world_size"):
         setattr(worker_group, group_field, None)
     # per-worker info; assigned left to right: id, global_rank, world_size
     for worker in worker_group.workers:
         worker.id = worker.global_rank = worker.world_size = None
     self.stop_workers_call_count += 1
示例#2
0
文件: api_test.py 项目: d4l3k/elastic
    def test_worker_group_constructor(self):
        """A freshly constructed WorkerGroup starts in INIT with nothing assigned.

        Checks the initial state, that one worker exists per local slot with
        consecutive local ranks, and that every rendezvous/agent-assigned field
        (per-worker and group-level) is still ``None``.
        """
        local_world_size = 4
        spec = WorkerSpec(
            role="test_trainer",
            local_world_size=local_world_size,
            # NOTE(review): do_nothing is *called* here, so fn receives its
            # return value; if WorkerSpec expects a callable this should
            # probably be `do_nothing` -- confirm against do_nothing's definition.
            fn=do_nothing(),
            args=(),
            rdzv_handler=None,
            max_restarts=50,
            monitor_interval=1,
        )
        group = WorkerGroup(spec)

        self.assertEqual(WorkerState.INIT, group.state)

        workers = group.workers
        self.assertEqual(local_world_size, len(workers))

        # local ranks must be exactly 0..local_world_size-1 (full and consecutive)
        expected_local_ranks = set(range(local_world_size))
        self.assertSetEqual(expected_local_ranks, {worker.local_rank for worker in workers})

        # global_rank and world_size are assigned after rdzv, and id is assigned
        # by the agent when it starts the worker -- so they are all None here
        for worker in workers:
            for unassigned_field in ("global_rank", "world_size", "id"):
                self.assertIsNone(getattr(worker, unassigned_field))

        # the group's rank and store are also assigned only after rdzv
        self.assertIsNone(group.group_rank)
        self.assertIsNone(group.store)