Example #1
    def _create_workers(self, config):
        """
        Create one ray remote process
        """
        num_gpus = config.get("num_gpus", 0)
        num_cpus = max(config.get("num_cpus", 1), config.get("workers", 0))

        experiment = ray.remote(num_cpus=num_cpus,
                                num_gpus=num_gpus)(self.experiment_class)
        for i in range(1 + self.max_retries):
            self.procs = [experiment.remote()]
            status = self.procs[0].setup_experiment.remote(config)
            self.logger.debug("_create_workers: trial=%s(%s)",
                              self._trial_info.trial_name, self.iteration)

            # Wait for remote function and check for errors
            if ray_utils.check_for_failure([status]):
                return

            # Remote function failed, kill workers and try again
            self.logger.warning(f"Failed to create workers, "
                                f"retrying {i + 1}/{self.max_retries}")
            # Restart all workers on failure
            self._kill_workers()

            # Back off a few seconds
            time.sleep(2**i)
        else:
            # Reached max failures
            err_msg = f"Failed to create workers after {self.max_retries} retries"
            self.logger.error(err_msg)
            raise RuntimeError(err_msg)
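
Every example in this listing relies on a check_for_failure helper (referenced as ray_utils.check_for_failure, utils.check_for_failure, or imported directly), whose source is not included here. The following is a minimal sketch, assuming the helper blocks on a list of Ray ObjectRefs and returns True only when every remote call completed, and False when any worker task raised or its actor died, so callers know whether ray.get on the full list is safe.

import logging

import ray
from ray.exceptions import RayActorError, RayTaskError

logger = logging.getLogger(__name__)


def check_for_failure(object_refs):
    """Wait on a list of Ray ObjectRefs and report success.

    Sketch of the helper assumed by the examples, not the original
    implementation: returns True only if every remote call finished
    without raising, False if any worker task or actor failed.
    """
    unfinished = list(object_refs)
    try:
        while unfinished:
            # Block until at least one result is ready, then surface any
            # remote exception by fetching the finished results.
            finished, unfinished = ray.wait(unfinished)
            ray.get(finished)
        return True
    except (RayActorError, RayTaskError) as err:
        # A worker died or its task raised; the caller decides how to recover.
        logger.warning("Remote worker failed: %s", err)
        return False

Returning a boolean instead of letting the exception propagate is what allows the callers in these examples to retry worker creation, kill and restart actors, or fall back to (False, None) without wrapping every ray.get in try/except.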
Example #2
 def _train_epoch(self, num_steps=None, info=None):
     worker_stats = [
         w.train_epoch.remote(num_steps=num_steps, info=info)
         for w in self.workers
     ]
     success = utils.check_for_failure(worker_stats)
     return success, worker_stats
Example #3
    def train(self, num_steps=None, profile=False, info=None, dataset=None):
        params = dict(num_steps=num_steps, profile=profile, info=info)
        if dataset:
            dataset.set_num_shards(self.num_workers)

        remote_worker_stats = self.remote_worker_group._train(
            num_steps, profile, info, dataset)
        try:
            if dataset:
                params["iterator"] = dataset.get_shard(self.num_workers - 1)
            local_worker_stats = self.local_worker.train_epoch(**params)
        except RuntimeError as err:
            if "gloo" in err.args[0] and "Timed out" in err.args[0]:
                logger.warning(err)
                return False, None
            if "NCCL" in err.args[0]:  # there is no specific error message
                logger.warning(err)
                return False, None
            if "Connection closed by peer" in err.args[0]:
                logger.warning(err)
                return False, None

            raise err

        success = check_for_failure(remote_worker_stats)
        if success:
            return success, [local_worker_stats] + ray.get(remote_worker_stats)

        return success, None
Example #4
    def _train(self):
        self.logger.debug(
            f"_train: {self._trial_info.trial_name}({self.iteration})")
        try:
            # Check if restore checkpoint file fulfills the stop criteria on first run
            pre_experiment_result = None
            if self._first_run:
                self._first_run = False
                if self._restored and self._should_stop():
                    self.logger.warning(
                        f"Restored checkpoint file '{self.restore_checkpoint_file}' "
                        f"fulfills stop criteria without additional training.")
                    return {
                        # do not train or log results, just stop
                        RESULT_DUPLICATE: True,
                        DONE: True
                    }

                # Run any pre-experiment functionality such as pre-training validation.
                # The results are aggregated here so they may be immediately logged
                # as opposed to waiting till the end of the iteration.
                if self._iteration == 0:
                    status = []
                    for w in self.procs:
                        status.append(w.pre_experiment.remote())

                    agg_pre_exp = self.experiment_class.aggregate_pre_experiment_results
                    if ray_utils.check_for_failure(status):
                        results = ray.get(status)
                        pre_experiment_result = agg_pre_exp(results)
                        self.logger.info(
                            f"Pre-Experiment Result: {pre_experiment_result}")

            results = self._run_iteration()

            # Aggregate the results from all processes
            if results is not None:

                # Aggregate results from iteration.
                ret = self.experiment_class.aggregate_results(results)

                self._process_result(ret, pre_experiment_result)
                printable_result = self.experiment_class.get_printable_result(
                    ret)
                self.logger.info(f"End Iteration Result: {printable_result}")

                # Check if we should stop the experiment
                ret[DONE] = self._should_stop()

                return ret

            err_msg = (f"{self._trial_info.trial_name}({self.iteration}): "
                       f"One of the remote workers failed during training")
            self.logger.error(err_msg)
            raise RuntimeError(err_msg)
        except Exception:
            self._kill_workers()
            raise
Example #5
 def step_with_fail(self, *args, **kwargs):
     worker_stats = [
         w.train_epoch.remote(*args, **kwargs) for w in self.workers
     ]
     if self._num_failures < 2:
         time.sleep(1)
         self.workers[0].__ray_kill__()
     success = check_for_failure(worker_stats)
     return success, worker_stats
Example #6
    def _run_iteration(self):
        """Run one iteration of the experiment"""
        status = []
        for w in self.procs:
            status.append(w.run_iteration.remote())

        # Wait for remote functions and check for errors
        if ray_utils.check_for_failure(status):
            return ray.get(status)
Example #7
    def _run_iteration(self):
        """Run one epoch of training on each process."""
        status = []
        for w in self.procs:
            status.append(w.run_epoch.remote())

        # Wait for remote functions and check for errors
        if ray_utils.check_for_failure(status):
            return ray.get(status)
Example #8
 def _should_stop(self):
     """
     Whether or not we should stop the experiment
     """
     # Check if we should stop the experiment
     stop_status = self.procs[0].should_stop.remote()
     if ray_utils.check_for_failure([stop_status]):
         return ray.get(stop_status)
     else:
         # Stop on failures
         return True
Example #9
    def _train_epoch(self, num_steps=None, profile=False, info=None):
        params = dict(num_steps=num_steps, profile=profile, info=info)

        remote_worker_stats = [
            w.train_epoch.remote(**params) for w in self.remote_workers
        ]

        success = check_for_failure(remote_worker_stats)
        if success:
            return success, ray.get(remote_worker_stats)

        return success, None
Example #10
    def train(self, num_steps=None, profile=False, info=None, dataset=None):
        """Runs 1 epoch of training on all workers.

        Has additional logic to check for worker failure.
        """
        if dataset:
            dataset.set_num_shards(self.num_workers)
        remote_worker_stats = self._train(num_steps, profile, info, dataset)
        # Check if each worker has failed before calling ray.get.
        success = check_for_failure(remote_worker_stats)
        if success:
            return success, ray.get(remote_worker_stats)
        return success, None
Example #11
    def step_with_fail(self, **params):
        remote_worker_stats = [
            w.train_epoch.remote(**params) for w in self.remote_workers
        ]

        if self._num_failures < 2:
            time.sleep(1)  # Make sure the batch fails correctly.
            self.remote_workers[0].__ray_kill__()

        try:
            local_worker_stats = self.local_worker.train_epoch(**params)
        except RuntimeError:
            return False, None

        success = check_for_failure(remote_worker_stats)
        if success:
            return success, [local_worker_stats] + ray.get(remote_worker_stats)

        return success, None
Example #12
    def _train(self):
        self.logger.debug(
            f"_train: {self._trial_info.trial_name}({self.iteration})")
        try:
            # Check if restore checkpoint file fulfills the stop criteria on first run
            if self._first_run:
                self._first_run = False
                if self._restored and self._should_stop():
                    self.logger.warning(
                        f"Restored checkpoint file '{self._checkpoint_file}' fulfills "
                        f"stop criteria without additional training.")
                    return {
                        # do not train or log results, just stop
                        RESULT_DUPLICATE: True,
                        DONE: True
                    }

            status = []
            for w in self.procs:
                status.append(w.run_epoch.remote())

            # Wait for remote functions and check for errors
            # Aggregate the results from all processes
            if ray_utils.check_for_failure(status):
                results = ray.get(status)

                ret = copy.deepcopy(results[0])
                ret.update(aggregate_eval_results(results))

                self._process_result(ret)

                # Check if we should stop the experiment
                ret[DONE] = self._should_stop()

                return ret

            err_msg = (f"{self._trial_info.trial_name}({self.iteration}): "
                       f"One of the remote workers failed during training")
            self.logger.error(err_msg)
            raise RuntimeError(err_msg)
        except Exception:
            self._kill_workers()
            raise
Example #13
    def step_with_fail(self,
                       num_steps=None,
                       profile=False,
                       info=None,
                       dataset=None):
        params = dict(num_steps=num_steps, profile=profile, info=info)
        remote_worker_stats = [
            w.train_epoch.remote(**params) for w in self.remote_workers
        ]

        if self._num_failures < num_fails:
            time.sleep(1)  # Make sure the batch fails correctly.
            ray.kill(self.remote_workers[0])

        try:
            local_worker_stats = self.local_worker.train_epoch(**params)
        except RuntimeError:
            return False, None

        success = check_for_failure(remote_worker_stats)
        if success:
            return success, [local_worker_stats] + ray.get(remote_worker_stats)

        return success, None
Example #14
    def _train_epoch(self,
                     num_steps=None,
                     profile=False,
                     info=None,
                     batch_logs_handler=None):
        worker_trains = [
            w.train_epoch.remote(num_steps=num_steps,
                                 profile=profile,
                                 info=info) for w in self.workers
        ]

        if not self.handlers:
            success = check_for_failure(worker_trains)
            return success, worker_trains

        unfinished = worker_trains
        try:
            while len(unfinished) > 0:
                finished, unfinished = ray.wait(unfinished,
                                                timeout=BATCH_LOGS_RATE_LIMIT)

                # raise errors on actor failure
                finished = ray.get(finished)

                futures = [h.update() for h in self.handlers]
                loop = asyncio.get_event_loop()
                if loop.is_closed():
                    loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(loop)
                loop.run_until_complete(asyncio.wait(futures))
                loop.close()

            return True, worker_trains
        except RayActorError as exc:
            logger.exception(str(exc))
        return False, worker_trains
Example #15
    def _train_epoch(self,
                     num_steps=None,
                     profile=False,
                     info=None,
                     dataset=None):
        params = dict(num_steps=num_steps, profile=profile, info=info)
        remote_worker_stats = []
        if dataset:
            dataset.set_num_shards(self.max_replicas)
        for i, w in enumerate(self.remote_workers):
            params = dict(num_steps=num_steps, profile=profile, info=info)
            if dataset:
                params["iterator"] = dataset.get_shard(i)
            stats = w.train_epoch.remote(**params)
            remote_worker_stats.append(stats)

        try:
            if dataset:
                params["iterator"] = dataset.get_shard(
                    len(self.remote_workers))
            local_worker_stats = self.local_worker.train_epoch(**params)
        except RuntimeError as err:
            if "gloo" in err.args[0] and "Timed out" in err.args[0]:
                logger.warning(err)
                return False, None
            if "NCCL" in err.args[0]:  # there is no specific error message
                logger.warning(err)
                return False, None

            raise err

        success = check_for_failure(remote_worker_stats)
        if success:
            return success, [local_worker_stats] + ray.get(remote_worker_stats)

        return success, None
Example #16
    def _train_epoch(self, num_steps=None, profile=False, info=None):
        params = dict(num_steps=num_steps, profile=profile, info=info)

        remote_worker_stats = [
            w.train_epoch.remote(**params) for w in self.remote_workers
        ]

        try:
            local_worker_stats = self.local_worker.train_epoch(**params)
        except RuntimeError as err:
            if "gloo" in err.args[0] and "Timed out" in err.args[0]:
                logger.warning(err)
                return False, None
            if "NCCL" in err.args[0]:  # there is no specific error message
                logger.warning(err)
                return False, None

            raise err

        success = check_for_failure(remote_worker_stats)
        if success:
            return success, [local_worker_stats] + ray.get(remote_worker_stats)

        return success, None
Example #17
 def _train_step(self):
     worker_stats = [w.step.remote() for w in self.workers]
     success = utils.check_for_failure(worker_stats)
     return success, worker_stats
Example #18
    def _train(self):
        self.logger.debug(
            f"_train: {self._trial_info.trial_name}({self.iteration})")
        try:
            # Check if restore checkpoint file fulfills the stop criteria on first run
            initial_val_result = None
            if self._first_run:
                self._first_run = False
                if self._restored and self._should_stop():
                    self.logger.warning(
                        f"Restored checkpoint file '{self.restore_checkpoint_file}' "
                        f"fulfills stop criteria without additional training.")
                    return {
                        # do not train or log results, just stop
                        RESULT_DUPLICATE: True,
                        DONE: True
                    }

                # This initial validation would be simpler if it were in the
                # ImagenetExperiment, but doing it here makes it possible to log
                # the validation results immediately rather than waiting until
                # the end of the first epoch.
                if self._iteration == 0 and self._validate_immediately:
                    self.logger.debug("Validating before any training:")
                    status = []
                    for w in self.procs:
                        status.append(w.validate.remote())

                    if ray_utils.check_for_failure(status):
                        results = ray.get(status)
                        initial_val_result = (
                            self.experiment_class.aggregate_validation_results(
                                results))
                        self.logger.info(initial_val_result)

            status = []
            for w in self.procs:
                status.append(w.run_epoch.remote())

            # Wait for remote functions and check for errors
            # Aggregate the results from all processes
            if ray_utils.check_for_failure(status):
                results = ray.get(status)

                ret = self.experiment_class.aggregate_results(results)
                if initial_val_result is not None:
                    ret["extra_val_results"].insert(0, (0, initial_val_result))
                self.logger.info(
                    self.experiment_class.get_printable_result(ret))
                self._process_result(ret)

                # Check if we should stop the experiment
                ret[DONE] = self._should_stop()

                return ret

            err_msg = (f"{self._trial_info.trial_name}({self.iteration}): "
                       f"One of the remote workers failed during training")
            self.logger.error(err_msg)
            raise RuntimeError(err_msg)
        except Exception:
            self._kill_workers()
            raise
Example #19
    def _create_workers(self, config):
        """
        Create one ray remote process for each GPU/process
        """
        num_gpus = config.get("num_gpus", 0)
        num_cpus = config.get("num_cpus", 1)

        # Determine the number of distributed processes based on the number
        # of GPUs and CPUs
        if num_gpus > 0:
            world_size = num_gpus
            # Assign one GPU per remote process
            num_gpus = 1
            # Assign extra CPUs for dataloaders
            num_cpus = config.get("workers", 0)
        else:
            world_size = num_cpus
            # Assign one CPU per remote process
            num_cpus = 1

        self._process_config(config)

        experiment_class = config["experiment_class"]

        for i in range(1 + self.max_retries):
            self.procs = []
            for _ in range(world_size):
                experiment = ray.remote(num_cpus=num_cpus,
                                        num_gpus=num_gpus)(experiment_class)
                self.procs.append(experiment.remote())

            # Use first process as head of the group
            ip = ray.get(self.procs[0].get_node_ip.remote())
            port = ray.get(self.procs[0].get_free_port.remote())
            port = config.get("dist_port", port)
            dist_url = "tcp://{}:{}".format(ip, port)

            # Configure each process in the group
            status = []
            for rank, w in enumerate(self.procs):
                worker_config = copy.deepcopy(config)
                worker_config["distributed"] = True
                worker_config["dist_url"] = dist_url
                worker_config["world_size"] = world_size
                worker_config["rank"] = rank
                status.append(w.setup_experiment.remote(worker_config))
                self.logger.debug(
                    f"_create_workers: rank={rank}, "
                    f"trial={self._trial_info.trial_name}({self.iteration})")

            # Wait for remote function and check for errors
            if ray_utils.check_for_failure(status):
                return

            # Remote function failed, kill workers and try again
            self.logger.warning(f"Failed to create workers, "
                                f"retrying {i + 1}/{self.max_retries}")
            # Restart all workers on failure
            self._kill_workers()

            # Back off a few seconds
            time.sleep(2**i)
        else:
            # Reached max failures
            err_msg = f"Failed to create workers after {self.max_retries} retries"
            self.logger.error(err_msg)
            raise RuntimeError(err_msg)