def test_placement_group(ray_start_2_cpus):
    """Smoke-test removing and re-adding a worker inside a placement group.

    A ``WorkerGroup`` is created on top of a placement group that has one
    single-CPU bundle per worker. One worker is then removed and one is
    added back. The test makes no explicit assertions; it passes as long
    as none of these calls raise.

    Args:
        ray_start_2_cpus: Pytest fixture (defined outside this block);
            presumably starts a Ray instance with 2 CPUs -- confirm
            against the fixture definition.
    """
    worker_count = 2

    # One independent {"CPU": 1} bundle per worker.
    cpu_bundles = [{"CPU": 1} for _ in range(worker_count)]
    pg = ray.util.placement_group(cpu_bundles)

    group = WorkerGroup(num_workers=worker_count, placement_group=pg)

    # Drop the worker at index 0, then add a single worker back.
    group.remove_workers([0])
    group.add_workers(1)
def handle_failure(
    self,
    worker_group: WorkerGroup,
    failed_worker_indexes: List[int],
    backend_config: BackendConfig,
):
    """Recover from worker failures for Tensorflow without a full restart.

    Rather than restarting every worker, only the failed ones are removed
    from the ``WorkerGroup``. If any workers are left, the backend and the
    session are shut down on them. The same number of workers that failed
    is then added back and the backend is started again on the group.

    Args:
        worker_group: The group containing the failed workers; it is
            modified in place.
        failed_worker_indexes: Indexes (within ``worker_group``) of the
            workers that failed.
        backend_config: Configuration forwarded to ``on_shutdown`` and
            ``on_start``.
    """
    worker_group.remove_workers(failed_worker_indexes)

    # Only tear down backend/session state if some workers remain.
    survivors_remain = len(worker_group) > 0
    if survivors_remain:
        self.on_shutdown(worker_group, backend_config)
        worker_group.execute(shutdown_session)

    # Replace exactly as many workers as were removed, then bring the
    # backend back up on the replenished group.
    replacements_needed = len(failed_worker_indexes)
    worker_group.add_workers(replacements_needed)
    self.on_start(worker_group, backend_config)