def _stop_workers(self, worker_group: WorkerGroup) -> None:
    """Simulate stopping the workers of ``worker_group``.

    The workers are fake, so there is nothing to terminate. Instead, the
    rendezvous-assigned information is wiped in place: first on the group
    itself, then on every worker in it. Finally, the call is counted.

    Args:
        worker_group: group whose rendezvous info is reset to ``None``.
    """
    # Group-level rendezvous info (group_rank is cleared before group_world_size).
    worker_group.group_rank = worker_group.group_world_size = None

    # Per-worker info (cleared in the order: id, global_rank, world_size).
    for worker in worker_group.workers:
        worker.id = worker.global_rank = worker.world_size = None

    # Record the invocation; presumably inspected by tests - verify at call sites.
    self.stop_workers_call_count += 1
def test_worker_group_constructor(self):
    """A freshly constructed WorkerGroup is in INIT state with no rdzv info.

    Local ranks are expected to be set at construction time
    (0..local_world_size-1). Everything that is assigned by rendezvous, or
    by the agent when it starts the workers, must still be ``None``.
    """
    local_world_size = 4
    spec = WorkerSpec(
        role="test_trainer",
        local_world_size=local_world_size,
        # Pass the function object itself; ``do_nothing()`` would pass the
        # value returned by calling it rather than a callable.
        fn=do_nothing,
        args=(),
        rdzv_handler=None,
        max_restarts=50,
        monitor_interval=1,
    )
    worker_group = WorkerGroup(spec)

    self.assertEqual(WorkerState.INIT, worker_group.state)

    workers = worker_group.workers
    self.assertEqual(local_world_size, len(workers))

    # validate full, consecutive local ranks
    self.assertSetEqual(
        set(range(local_world_size)), {w.local_rank for w in workers}
    )

    # global_rank and world_size are assigned after rdzv;
    # id is assigned after the agent starts the worker.
    # Validate that they are all None.
    for w in workers:
        self.assertIsNone(w.global_rank)
        self.assertIsNone(w.world_size)
        self.assertIsNone(w.id)

    # group rank, group world size and store are assigned after rdzv;
    # validate that they are None
    self.assertIsNone(worker_group.group_rank)
    self.assertIsNone(worker_group.group_world_size)
    self.assertIsNone(worker_group.store)