def get_with_failure_handling(self, remote_values):
    """Gets the remote values while handling for worker failures.

    This method should be called instead of ``ray.get()`` directly in
    order to handle worker failures.

    If a worker failure is identified, backend-specific failure handling
    is executed and a ``TrainingWorkerError`` is raised.

    Args:
        remote_values (list): List of object refs representing functions
            that may fail in the middle of execution. For example, running
            a Train training loop in multiple parallel actor calls.

    Returns:
        The resolved objects represented by the passed in ObjectRefs.
    """
    success, failed_worker_indexes = check_for_failure(remote_values)
    if success:
        return ray.get(remote_values)
    else:
        self._increment_failures()
        try:
            self._backend.handle_failure(
                self.worker_group, failed_worker_indexes, self._backend_config
            )
        except RayActorError as exc:
            logger.exception(str(exc))
            self._restart()
        # Surface the failure to the caller even if backend-specific handling
        # succeeded, so the training loop can be re-launched.
        raise TrainingWorkerError
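# A minimal sketch of a failure probe with the two-value return shape used
# above (a success flag plus the indexes of failed workers). This is an
# illustrative stand-in, not the actual ``check_for_failure`` helper: the
# name ``probe_for_failure`` is hypothetical, and it simply resolves each
# ObjectRef so that a dead actor surfaces as a ``RayActorError``.
import ray
from ray.exceptions import RayActorError


def probe_for_failure(remote_values):
    failed_worker_indexes = []
    for index, ref in enumerate(remote_values):
        try:
            # ``ray.get`` raises ``RayActorError`` if the owning actor died.
            ray.get(ref)
        except RayActorError:
            failed_worker_indexes.append(index)
    return len(failed_worker_indexes) == 0, failed_worker_indexes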
def get_with_failure_handling(self, remote_values):
    """Gets the remote values while handling for worker failures.

    This method should be called instead of ``ray.get()`` directly in
    order to handle worker failures.

    If a worker failure is identified, backend-specific failure handling
    is executed and a ``TrainingWorkerError`` is raised.

    Args:
        remote_values (list): List of object refs representing functions
            that may fail in the middle of execution. For example, running
            a Train training loop in multiple parallel actor calls.

    Returns:
        The resolved objects represented by the passed in ObjectRefs.
    """
    success = check_for_failure(remote_values)
    if success:
        return ray.get(remote_values)
    else:
        self._increment_failures()
        logger.warning(
            "Failure identified during training. Restarting all workers and "
            "continuing training from latest checkpoint."
        )
        self._restart()
        raise TrainingWorkerError
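# Usage sketch: per the docstring, callers invoke this method instead of
# ``ray.get()`` and treat ``TrainingWorkerError`` as a signal to re-submit
# work once the workers have been restarted. The ``executor`` object, the
# ``launch_training_step`` helper, and the retry budget below are hypothetical
# names introduced only for illustration.
MAX_RETRIES = 3

for attempt in range(MAX_RETRIES):
    remote_values = launch_training_step(executor)  # list of ObjectRefs
    try:
        results = executor.get_with_failure_handling(remote_values)
        break
    except TrainingWorkerError:
        # Workers were already restarted by the failure handler; try again.
        continue
else:
    raise RuntimeError(f"Training step failed after {MAX_RETRIES} retries.")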