Exemplo n.º 1
0
class LocalWorkerGroup(WorkerGroupInterface):
    """A group of TorchRunner workers.
    1 worker runs locally, and all the other workers are remote actors.

    Args:
        Same as RemoteWorkerGroup.
    """
    def __init__(
        self,
        max_workers,
        params,
        dist_params,
        initialization_hook,
        timeout_s,
        num_cpus_per_worker,
        use_gpu,
    ):

        # Invariant: These variables should never change state!
        self._max_workers = max_workers
        self._params = params
        self._dist_params = dist_params
        self._initialization_hook = initialization_hook
        self._timeout_s = timeout_s
        self._num_cpus_per_worker = num_cpus_per_worker
        self._use_gpu = use_gpu

        self.local_worker = None
        self.remote_worker_group = RemoteWorkerGroup(
            max_workers=max_workers - 1,
            params=params,
            dist_params=dist_params,
            initialization_hook=initialization_hook,
            timeout_s=timeout_s,
            num_cpus_per_worker=num_cpus_per_worker,
            use_gpu=use_gpu,
        )

    def start_workers(self, num_workers):
        logger.debug(f"start_workers: Setting {num_workers} workers.")

        if num_workers == 1:
            self.local_worker = TorchRunner(**self._params)
            if self._initialization_hook:
                self.apply_all_workers(self._initialization_hook)
            self.local_worker.setup_operator()
            return True
        else:
            try:
                # Start local worker
                self.local_worker = LocalDistributedRunner(
                    num_cpus=self._num_cpus_per_worker,
                    num_gpus=int(self._use_gpu),
                    **{
                        **self._params,
                        **self._dist_params
                    },
                )
                self.remote_worker_group._init_dist_workers(num_workers - 1)
                if self._initialization_hook:
                    self.apply_all_workers(self._initialization_hook)

                # Compute URL for initializing distributed PyTorch.
                address = setup_address()

                remote_pgs = self.remote_worker_group._setup_process_group(
                    address=address, world_size=num_workers, starting_rank=1)
                # Use the local worker as rank 0. Helps with debugging.
                self.local_worker.setup_process_group(
                    url=address,
                    world_rank=0,
                    world_size=num_workers,
                    timeout=timedelta(seconds=self._timeout_s),
                )
                ray.get(remote_pgs)

                local_node_ip = ray.util.get_node_ip_address()
                rank_dict = defaultdict(int)
                self.local_worker.set_local_rank(local_rank=0)
                rank_dict[local_node_ip] += 1
                self.remote_worker_group._setup_local_rank(rank_dict)

                remote_operators = self.remote_worker_group._setup_operator()
                self.local_worker.setup_operator()
                ray.get(remote_operators)
                return True
            except RayActorError:
                return False

    def apply_all_operators(self, fn):
        remote_calls = self.remote_worker_group._apply_all_operators(fn)
        local_call = self.local_worker.apply_operator(fn)
        return [local_call] + ray.get(remote_calls)

    def apply_all_workers(self, fn):
        remote_calls = self.remote_worker_group._apply_all_workers(fn)
        local_call = self.local_worker.apply(fn)
        return [local_call] + ray.get(remote_calls)

    def get_local_operator(self):
        return self.local_worker.training_operator

    def get_model(self, to_cpu=False):
        models = self.local_worker.models
        if to_cpu:
            models = [m.cpu() for m in models]
        return models

    def load_state_dict(self, state_dict, blocking=False):
        # This is not the most efficient because you have to wait for
        # the local worker to save then dump to buffer.
        self.local_worker.load_state_dict(state_dict)
        state_id = ray.put(self.local_worker.state_stream())

        remote_calls = self.remote_worker_group._load_state_id(state_id)
        if blocking:
            ray.get(remote_calls)

    def state_dict(self):
        return self.local_worker.state_dict()

    def should_scale_up(self):
        return self.remote_worker_group.should_scale_up()

    def reset(self):
        """Terminates models without giving up local resource reservation."""
        if not isinstance(self.local_worker, LocalDistributedRunner):
            self.local_worker.shutdown()
        else:
            self.local_worker.shutdown(cleanup=False)
        self.remote_worker_group.reset()

        self.local_worker = None
        self.remote_worker_group = RemoteWorkerGroup(
            max_workers=self._max_workers - 1,
            params=self._params,
            dist_params=self._dist_params,
            initialization_hook=self._initialization_hook,
            num_cpus_per_worker=self._num_cpus_per_worker,
            use_gpu=self._use_gpu,
            timeout_s=self._timeout_s,
        )

    def new_workers_size(self):
        return self.remote_worker_group.new_workers_size() + 1

    def train(self, num_steps=None, profile=False, info=None, dataset=None):
        params = dict(num_steps=num_steps, profile=profile, info=info)
        if dataset:
            dataset.set_num_shards(self.num_workers)

        remote_worker_stats = self.remote_worker_group._train(
            num_steps, profile, info, dataset)
        try:
            if dataset:
                params["iterator"] = dataset.get_shard(self.num_workers - 1)
            local_worker_stats = self.local_worker.train_epoch(**params)
        except RuntimeError as err:
            if "gloo" in err.args[0] and "Timed out" in err.args[0]:
                logger.warning(err)
                return False, None
            if "NCCL" in err.args[0]:  # there is no specific error message
                logger.warning(err)
                return False, None
            if "Connection closed by peer" in err.args[0]:
                logger.warning(err)
                return False, None

            raise err

        success = check_for_failure(remote_worker_stats)
        if success:
            return success, [local_worker_stats] + ray.get(remote_worker_stats)

        return success, None

    def validate(self, num_steps=None, profile=False, info=None):
        params = dict(num_steps=num_steps, profile=profile, info=info)

        remote_worker_stats = self.remote_worker_group._validate(params)
        local_worker_stats = self.local_worker.validate(**params)
        worker_stats = [local_worker_stats] + ray.get(remote_worker_stats)
        return worker_stats

    def shutdown(self, force=False):
        self.local_worker.shutdown()
        self.remote_worker_group.shutdown(force=force)

        self.local_worker = None
        self.remote_worker_group = DeactivatedWorkerGroup()

    @property
    def num_workers(self):
        return self.remote_worker_group.num_workers + 1

    @property
    def remote_workers(self):
        return self.remote_worker_group.remote_workers
Exemplo n.º 2
0
class TorchTrainer:
    """Train a PyTorch model using distributed PyTorch.

    Launches a set of actors which connect via distributed PyTorch and
    coordinate gradient updates to train the provided model.

    .. code-block:: python

        ray.init()

        def model_creator(config):
            return nn.Linear(1, 1)


        def optimizer_creator(model, config):
            return torch.optim.SGD(
                model.parameters(), lr=config.get("lr", 1e-4))


        def data_creator(config):
            batch_size = config["batch_size"]
            train_data, val_data = LinearDataset(2, 5), LinearDataset(2, 5)
            train_loader = DataLoader(train_data, batch_size=batch_size)
            val_loader = DataLoader(val_data, batch_size=batch_size)
            return train_loader, val_loader


        trainer = TorchTrainer(
            model_creator=model_creator,
            data_creator=data_creator,
            optimizer_creator=optimizer_creator,
            loss_creator=nn.MSELoss,
            config={"batch_size": 32},
            use_gpu=True
        )
        for i in range(4):
            trainer.train()


    Args:
        model_creator (dict -> Model(s)): Constructor function that takes in
            config and returns the model(s) to be optimized. These must be
            ``torch.nn.Module`` objects. If multiple models are returned,
            a ``training_operator_cls`` must be specified. You do not need to
            handle GPU/devices in this function; RaySGD will do that under
            the hood.
        data_creator (dict -> Iterable(s)): Constructor function
            that takes in the passed config and returns one or
            two Iterable objects. Note that even though two Iterable objects
            can be returned, only one will be used for training, and the
            other will be used for validation. If not provided, you must
            provide a custom TrainingOperator.
        optimizer_creator ((models, dict) -> optimizers): Constructor
            function that takes in the return values from
            ``model_creator`` and the passed config and returns One or
            more Torch optimizer objects. You do not need to handle
            GPU/devices in this function; ``RaySGD`` will do that for you.
        loss_creator (torch.nn.*Loss class | dict -> loss): A constructor
            function for the training loss. This can be either a function that
            takes in the provided config for customization or a subclass
            of ``torch.nn.modules.loss._Loss``, which is most Pytorch
            loss classes. For example, ``loss_creator=torch.nn.BCELoss``.
            If not provided, you must provide a custom TrainingOperator.
        scheduler_creator ((optimizers, dict) -> scheduler):
            A constructor function for the torch scheduler. This is
            a function that takes in the generated optimizers (from
            ``optimizer_creator``) provided config for customization.
            Be sure to set ``scheduler_step_freq`` to increment the
            scheduler correctly.
        training_operator_cls (type): Custom training operator class
            that subclasses the TrainingOperator class. This class
            will be copied onto all remote workers and used to specify
            custom training and validation operations. Defaults to
            TrainingOperator.
        config (dict): Custom configuration value to be passed to
            all creator and operator constructors.
        num_workers (int): the number of workers used in distributed
            training. If 1, the worker will not be wrapped with
            DistributedDataParallel.
        use_gpu (bool): Sets resource allocation for workers to 1 GPU
            if true, and automatically moves both the model and optimizer
            to the available CUDA device.
        backend (string): backend used by distributed PyTorch. Currently
            support "nccl", "gloo", and "auto". If "auto", RaySGD will
            automatically use "nccl" if `use_gpu` is True, and "gloo"
            otherwise.
        use_fp16 (bool): Enables mixed precision training via apex if apex
            is installed. This is automatically done after the model and
            optimizers are constructed and will work for multi-model training.
            Please see https://github.com/NVIDIA/apex for more details.
        apex_args (dict|None): Dict containing keyword args for amp.initialize.
            See https://nvidia.github.io/apex/amp.html#module-apex.amp. By
            default, the models and optimizers are passed in. Consider using
            "num_losses" if operating over multiple models and optimizers.
        scheduler_step_freq: "batch", "epoch", or None. This will
            determine when ``scheduler.step`` is called. If "batch",
            ``step`` will be called after every optimizer step. If "epoch",
            ``step`` will be called after one pass of the DataLoader.

    """

    def __init__(
            self,
            *,
            model_creator,
            data_creator,
            optimizer_creator,
            loss_creator=None,
            scheduler_creator=None,
            training_operator_cls=None,
            initialization_hook=None,
            config=None,
            num_workers=1,
            use_gpu=False,
            backend="auto",
            use_fp16=False,
            use_tqdm=False,
            apex_args=None,
            scheduler_step_freq="batch",
            num_replicas=None,
            batch_size=None,
            data_loader_args=None,
    ):
        if num_workers > 1 and not dist.is_available():
            raise ValueError(
                ("Distributed PyTorch is not supported on macOS. "
                 "To run without distributed PyTorch, set 'num_workers=1'. "
                 "For more information, see "
                 "https://github.com/pytorch/examples/issues/467."))

        if not (callable(model_creator) and callable(optimizer_creator)
                and callable(data_creator)):
            raise ValueError(
                "Must provide a callable model_creator, optimizer_creator, "
                "and data_creator.")

        if num_replicas is not None:
            raise DeprecationWarning(
                "num_replicas is deprecated. Use num_workers instead.")

        if batch_size is not None:
            raise DeprecationWarning(
                "batch_size is deprecated. Use config={'batch_size': N} "
                "specify a batch size for each worker or "
                "config={ray.util.sgd.utils.BATCH_SIZE: N} to specify a "
                "batch size to be used across all workers.")

        if data_loader_args:
            raise ValueError(
                "data_loader_args is deprecated. You can return a "
                "torch.utils.data.DataLoader in data_creator. Ray will "
                "automatically set a DistributedSampler if a DataLoader is "
                "returned and num_workers > 1.")

        self.model_creator = model_creator
        self.optimizer_creator = optimizer_creator
        self.loss_creator = loss_creator
        self.data_creator = data_creator
        self.scheduler_creator = scheduler_creator
        self.training_operator_cls = training_operator_cls

        if not training_operator_cls and not loss_creator:
            raise ValueError("If a loss_creator is not provided, you must "
                             "provide a custom training operator.")

        self.initialization_hook = initialization_hook
        self.config = {} if config is None else config

        if backend == "auto":
            backend = "nccl" if use_gpu else "gloo"

        logger.debug("Using {} as backend.".format(backend))
        self.backend = backend

        # TODO: Have an auto "use_gpu" option to detect and use GPUs.
        self.use_gpu = use_gpu
        self.max_replicas = num_workers

        self.use_fp16 = use_fp16
        self.use_tqdm = use_tqdm

        if apex_args and not isinstance(apex_args, dict):
            raise ValueError("apex_args needs to be a dict object.")

        self.apex_args = apex_args
        self.temp_dir = tempfile.mkdtemp(prefix="raysgd")
        self._num_failures = 0
        self._last_resize = float("-inf")

        _validate_scheduler_step_freq(scheduler_step_freq)
        self.scheduler_step_freq = scheduler_step_freq

        self._start_workers(self.max_replicas)

    def _configure_and_split_batch(self, num_workers):
        """If sgd.utils.BATCH_SIZE is provided, split among workers."""
        if BATCH_SIZE not in self.config:
            return
        # Compute batch size per worker
        logger.debug("BATCH_SIZE parameter detected. Splitting among workers.")
        batch_size = self.config[BATCH_SIZE]
        batch_size_per_worker = batch_size // num_workers
        if batch_size % num_workers > 0:
            new_batch_size = batch_size_per_worker * num_workers
            logger.warning(
                ("Changing batch size from {old_batch_size} to "
                 "{new_batch_size} to evenly distribute batches across "
                 "{num_workers} workers.").format(
                     old_batch_size=batch_size,
                     new_batch_size=new_batch_size,
                     num_workers=num_workers))
            self.config[BATCH_SIZE] = new_batch_size
        return batch_size_per_worker

    def _start_workers(self, num_workers):
        logger.debug(f"start_workers: Setting %d workers." % num_workers)
        worker_config = self.config.copy()
        batch_size_per_worker = self._configure_and_split_batch(num_workers)
        if batch_size_per_worker:
            worker_config[BATCH_SIZE] = batch_size_per_worker

        self.local_worker = None
        self.remote_workers = []

        if num_workers == 1:
            # Start local worker
            self.local_worker = TorchRunner(
                model_creator=self.model_creator,
                data_creator=self.data_creator,
                optimizer_creator=self.optimizer_creator,
                loss_creator=self.loss_creator,
                scheduler_creator=self.scheduler_creator,
                training_operator_cls=self.training_operator_cls,
                config=worker_config,
                use_fp16=self.use_fp16,
                use_tqdm=self.use_tqdm,
                apex_args=self.apex_args,
                scheduler_step_freq=self.scheduler_step_freq)

            if self.initialization_hook:
                self.apply_all_workers(self.initialization_hook)

            self.local_worker.setup()
        else:
            params = dict(
                model_creator=self.model_creator,
                data_creator=self.data_creator,
                optimizer_creator=self.optimizer_creator,
                loss_creator=self.loss_creator,
                scheduler_creator=self.scheduler_creator,
                backend=self.backend,
                training_operator_cls=self.training_operator_cls,
                config=worker_config,
                use_fp16=self.use_fp16,
                use_tqdm=self.use_tqdm,
                apex_args=self.apex_args,
                scheduler_step_freq=self.scheduler_step_freq)

            # Start local worker
            self.local_worker = LocalDistributedRunner(
                num_cpus=1, num_gpus=int(self.use_gpu), **params)

            # Generate actor class
            RemoteRunner = ray.remote(
                num_cpus=1, num_gpus=int(self.use_gpu))(DistributedTorchRunner)
            # Start workers
            self.remote_workers = [
                RemoteRunner.remote(**params) for i in range(num_workers - 1)
            ]
            if self.initialization_hook:
                self.apply_all_workers(self.initialization_hook)

            # Compute URL for initializing distributed PyTorch
            ip = ray.services.get_node_ip_address()
            port = self.local_worker.find_free_port()

            address = "tcp://{ip}:{port}".format(ip=ip, port=port)

            remote_setups = [
                worker.setup.remote(address, i + 1, num_workers)
                for i, worker in enumerate(self.remote_workers)
            ]
            self.local_worker.setup(address, 0, num_workers)
            # Get setup tasks in order to throw errors on failure
            ray.get(remote_setups)

    def train(self,
              num_steps=None,
              profile=False,
              reduce_results=True,
              max_retries=0,
              checkpoint="auto",
              info=None):
        """Runs a training epoch.

        Calls `operator.train_epoch()` on N parallel workers simultaneously
        underneath the hood.

        Set `max_retries` to enable fault handling in case of
        instance preemption.

        Args:
            num_steps (int): Number of batches to compute update steps on.
                This corresponds also to the number of times
                ``TrainingOperator.train_batch`` is called.
            profile (bool): Returns time stats for the training procedure.
            reduce_results (bool): Whether to average all metrics across
                all workers into one dict. If a metric is a non-numerical
                value (or nested dictionaries), one value will be randomly
                selected among the workers. If False, returns a list of dicts.
            max_retries (int): Must be non-negative. If set to N, will
                kill all current workers, query the Ray global state for
                total available resources, and re-launch up to the
                available resources. Behavior is not well-defined
                in case of shared cluster usage.
            checkpoint (str): Path to checkpoint to restore from if retrying.
                If max_retries is set and ``checkpoint == "auto"``,
                TorchTrainer will save a checkpoint before starting to train.
            info (dict): Optional dictionary passed to the training
                operator for ``train_epoch`` and ``train_batch``.

        Returns:
            (dict | list) A dictionary of metrics for training.
                You can provide custom metrics by passing in a custom
                ``training_operator_cls``. If ``reduce_results=False``,
                this will return a list of metric dictionaries whose
                length will be equal to ``num_workers``.
        """
        assert max_retries >= 0, "`max_retries` must be non-negative."
        if max_retries:
            if checkpoint == "auto":
                logger.debug("Retrying detected. Automatically checkpointing.")
                checkpoint = self.save(
                    os.path.join(self.temp_dir, "tmp_checkpoint"))
            elif not checkpoint:
                raise ValueError("Cannot retry from empty checkpoint.")

        if checkpoint and self._should_resize():
            logger.info("Resize opportunity detected. Attempting to scale up.")
            self._resize_workers(checkpoint=checkpoint)

        success, worker_stats = self._train_epoch(
            num_steps=num_steps, profile=profile, info=info)
        # Fault handling
        for i in range(max_retries):
            if success:
                break
            else:
                self._num_failures += 1
            self._resize_workers(checkpoint=checkpoint)
            logger.info("Retrying training step with %d workers." %
                        (len(self.remote_workers) + 1))
            success, worker_stats = self._train_epoch(
                num_steps=num_steps, profile=profile, info=info)
        if not success:
            raise RuntimeError("Training run failed.")

        if reduce_results:
            return self._process_stats(worker_stats)
        else:
            return worker_stats

    def _process_stats(self, worker_stats):
        stats = {
            NUM_SAMPLES: sum(
                stats.pop(NUM_SAMPLES, np.nan) for stats in worker_stats)
        }

        for stat_key in worker_stats[0]:
            if isinstance(worker_stats[0], numbers.Number):
                stats[stat_key] = np.nanmean(
                    [s.get(stat_key, np.nan) for s in worker_stats])
            else:
                stats[stat_key] = worker_stats[0][stat_key]
        return stats

    def _train_epoch(self, num_steps=None, profile=False, info=None):
        params = dict(num_steps=num_steps, profile=profile, info=info)

        remote_worker_stats = [
            w.train_epoch.remote(**params) for w in self.remote_workers
        ]

        try:
            local_worker_stats = self.local_worker.train_epoch(**params)
        except RuntimeError as err:
            if "gloo" in err.args[0] and "Timed out" in err.args[0]:
                logger.warning(err)
                return False, None
            if "NCCL" in err.args[0]:  # there is no specific error message
                logger.warning(err)
                return False, None

            raise err

        success = check_for_failure(remote_worker_stats)
        if success:
            return success, [local_worker_stats] + ray.get(remote_worker_stats)

        return success, None

    def apply_all_workers(self, fn):
        """Run a function on all operators on the workers.

        Args:
            fn (Callable): A function that takes in no arguments.

        Returns:
            A list of objects returned by ``fn`` on each worker.

        """
        remote_calls = [w.apply.remote(fn) for w in self.remote_workers]
        local_call = self.local_worker.apply(fn)
        return [local_call] + ray.get(remote_calls)

    def apply_all_operators(self, fn):
        """Run a function on all operators on the workers.

        Args:
            fn (Callable[TrainingOperator]): A function that takes in a
                TrainingOperator.

        Returns:
            A list of objects returned by ``fn`` on each operator.

        """
        remote_calls = [
            w.apply_operator.remote(fn) for w in self.remote_workers
        ]
        local_call = self.local_worker.apply_operator(fn)
        return [local_call] + ray.get(remote_calls)

    def validate(self, num_steps=None, profile=False, info=None):
        """Evaluates the model on the validation data set.

        Args:
            num_steps (int): Number of batches to compute update steps on.
                This corresponds also to the number of times
                ``TrainingOperator.validate_batch`` is called.
            profile (bool): Returns time stats for the evaluation procedure.
            info (dict): Optional dictionary passed to the training
                operator for `validate` and `validate_batch`.

        Returns:
            A dictionary of metrics for validation.
                You can provide custom metrics by passing in a custom
                ``training_operator_cls``.
        """
        params = dict(num_steps=num_steps, profile=profile, info=info)

        remote_worker_stats = [
            w.validate.remote(**params) for w in self.remote_workers
        ]
        local_worker_stats = self.local_worker.validate(**params)

        return self._process_stats([local_worker_stats] +
                                   ray.get(remote_worker_stats))

    def update_scheduler(self, metric):
        """Calls ``scheduler.step(metric)`` on all schedulers.

        This is useful for lr_schedulers such as ``ReduceLROnPlateau``.
        """
        self.apply_all_operators(
            lambda op: [sched.step(metric) for sched in op.schedulers])

    def get_model(self):
        """Returns the learned model(s)."""
        models = self.model_creator(self.config)
        state = self.local_worker.get_state()
        if len(state["models"]) == 1:
            models.load_state_dict(state["models"][0])
        else:
            for model, state_dict in zip(models, state["models"]):
                model.load_state_dict(state_dict)
        return models

    def state_dict(self):
        return self.local_worker.get_state()

    def load_state_dict(self, state):
        state_id = ray.put(state)

        remote_calls = [
            worker.set_state.remote(state_id) for worker in self.remote_workers
        ]
        self.local_worker.set_state(state)
        ray.get(remote_calls)

    def save(self, checkpoint):
        """Saves the model(s) to the provided checkpoint.

        Args:
            checkpoint (str): Path to target checkpoint file.

        Returns:
            checkpoint (str): Path to target checkpoint file.
        """
        torch.save(self.state_dict(), checkpoint)
        return checkpoint

    def restore(self, checkpoint):
        """Restores the Trainer and all workers from the provided checkpoint.

        Args:
            checkpoint (str): Path to target checkpoint file.
        """
        state = torch.load(checkpoint)
        self.load_state_dict(state)

    def shutdown(self, force=False):
        """Shuts down workers and releases resources."""
        if not force:
            cleanup = [
                worker.shutdown.remote() for worker in self.remote_workers
            ]
            self.local_worker.shutdown()
            try:
                ray.get(cleanup)
                [
                    worker.__ray_terminate__.remote()
                    for worker in self.remote_workers
                ]
            except RayActorError:
                logger.warning(
                    "Failed to shutdown gracefully, forcing a shutdown.")

                for worker in self.remote_workers:
                    logger.warning("Killing worker {}.".format(worker))
                    ray.kill(worker)
        else:
            self.local_worker.shutdown()
            for worker in self.remote_workers:
                logger.warning("Killing worker {}.".format(worker))
                ray.kill(worker)

        self.local_worker = None
        self.remote_workers = []

    def _reset(self):
        """Terminates models without giving up local resource reservation."""
        self.local_worker.shutdown(cleanup=False)
        for worker in self.remote_workers:
            logger.warning("Killing worker {}.".format(worker))
            ray.kill(worker)
        self.local_worker = None
        self.remote_workers = []

    def _check_potential_remote_workers_size(self):
        # ASSUME 1 GPU + 1 CPU is already reserved for the local worker
        remote_resources = ray.available_resources()
        max_remote_workers = self.max_replicas - 1
        new_remote_workers = min(
            remote_resources.get("CPU", 0), max_remote_workers)
        if self.use_gpu:
            new_remote_workers = min(
                remote_resources.get("GPU", 0), new_remote_workers)
        return new_remote_workers

    def _resize_workers(self, checkpoint, max_retries=10):
        self._reset()
        assert checkpoint, "Cannot restore without checkpoint."

        time.sleep(1)
        for i in range(max_retries):
            new_remote_workers = self._check_potential_remote_workers_size()
            if new_remote_workers:
                self._last_resize = time.time()
                self._start_workers(int(new_remote_workers) + 1)
                self.restore(checkpoint)
                return
            else:
                delay = 2**i
                logger.warning(
                    "No new workers found. Retrying in %d sec." % delay)
                time.sleep(delay)
        raise RuntimeError("Exceeded max_retries for relaunching workers.")

    def _should_resize(self):
        """Returns True if past cooldown and exists resources to scale up."""
        worker_gap = self.max_replicas - 1 - len(self.remote_workers)
        past_cooldown = (time.time() - self._last_resize) > RESIZE_COOLDOWN_S
        if past_cooldown and worker_gap:
            # Assume 1 resource is already reserved for local worker.
            potential_remote_size = self._check_potential_remote_workers_size()
            return potential_remote_size > 0
        return False