コード例 #1
0
    def test_worker_group_constructor(self):
        """Check the state of a freshly constructed ``WorkerGroup``.

        Builds a group from a spec with four local workers and asserts on
        what is observable right after construction: the INIT state, one
        worker per local rank, placeholder rank/world-size values on each
        worker, and no rendezvous-derived info on the group.
        """
        # single source of truth for the expected number of workers
        local_world_size = 4
        spec = WorkerSpec(
            role="test_trainer",
            local_world_size=local_world_size,
            fn=do_nothing,
            args=(),
            rdzv_handler=None,
            max_restarts=50,
            monitor_interval=1,
        )
        worker_group = WorkerGroup(spec)

        self.assertEqual(WorkerState.INIT, worker_group.state)

        workers = worker_group.workers
        self.assertEqual(local_world_size, len(workers))

        # validate full, consecutive local ranks
        self.assertSetEqual(
            set(range(local_world_size)), {w.local_rank for w in workers}
        )

        # global_rank and world_size are assigned after rdzv; until then
        # they hold the placeholder -1.
        # id is assigned after starting the worker (by the agent); until
        # then it is None.
        for w in workers:
            self.assertEqual(-1, w.global_rank)
            self.assertEqual(-1, w.world_size)
            self.assertIsNone(w.id)

        # rank and store are assigned after rdzv; validate that they are None
        self.assertIsNone(worker_group.group_rank)
        self.assertIsNone(worker_group.store)
コード例 #2
0
 def _stop_workers(self, worker_group: WorkerGroup) -> None:
     """Fake "stop" for the workers of ``worker_group``.

     Nothing is actually shut down: the group's ``group_rank`` and
     ``group_world_size`` are reset to ``None`` and the call is recorded
     by incrementing ``self.stop_workers_call_count``.
     """
     # the workers are fake, so wiping the rdzv info is all there is to do
     worker_group.group_rank = worker_group.group_world_size = None
     self.stop_workers_call_count += 1