Example #1
File: trainer.py  Project: hngenc/ray
    def __init__(self,
                 backend: Union[str, BackendConfig],
                 num_workers: int = 1,
                 use_gpu: bool = False,
                 resources_per_worker: Optional[Dict[str, float]] = None):
        """A class for distributed training.

        Args:
            backend (Union[str, BackendConfig]): The backend used for
                distributed communication. If configurations are needed,
                a subclass of ``BackendConfig`` can be passed in.
                Supported ``str`` values: {"torch"}.
            num_workers (int): The number of workers (Ray actors) to launch.
                Defaults to 1. Each worker will reserve 1 CPU by default.
            use_gpu (bool): If True, training will be done on GPUs (1 per
                worker). Defaults to False.
            resources_per_worker (Optional[Dict]): If specified, the resources
                defined in this Dict will be reserved for each worker.
        """
        # Setup executor.
        backend_config = self._get_backend_config(backend)

        if resources_per_worker:
            raise NotImplementedError("`resources_per_worker` argument is not "
                                      "supported yet.")

        self._executor = BackendExecutor(backend_config, num_workers, 1,
                                         int(use_gpu))
Example #2
File: trainer.py  Project: haochihlin/ray
    def __init__(
        self,
        backend: Union[str, BackendConfig],
        num_workers: int = 1,
        use_gpu: bool = False,
        resources_per_worker: Optional[Dict[str, float]] = None,
        logdir: Optional[str] = None,
        max_retries: int = 3,
    ):

        self._backend = backend
        self._num_workers = num_workers
        self._use_gpu = use_gpu
        self._resources_per_worker = resources_per_worker

        # Incremental unique run ID.
        self._run_id = 0

        self.logdir = self.create_logdir(logdir)

        # Setup executor.
        backend_config = self._get_backend_config(backend)

        num_cpus = 1
        num_gpus = int(use_gpu)

        if resources_per_worker:
            # Override CPU and GPU resources and remove from dict.
            num_cpus = resources_per_worker.pop("CPU", num_cpus)
            num_gpus = resources_per_worker.pop("GPU", num_gpus)
            if not use_gpu and num_gpus > 0:
                raise ValueError(
                    "`use_gpu` is False but `GPU` was found in "
                    "`resources_per_worker`. Either set `use_gpu` to True or "
                    "remove `GPU` from `resources_per_worker.")
            if use_gpu and num_gpus == 0:
                raise ValueError(
                    "`use_gpu` is True but `GPU` is set to 0 in "
                    "`resources_per_worker`. Either set `use_gpu` to False or "
                    "request a positive number of `GPU` in "
                    "`resources_per_worker.")

        self._executor = BackendExecutor(
            backend_config=backend_config,
            num_workers=num_workers,
            num_cpus_per_worker=num_cpus,
            num_gpus_per_worker=num_gpus,
            additional_resources_per_worker=resources_per_worker,
            max_retries=max_retries)
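As a rough sketch of how the CPU/GPU override in this constructor behaves: explicit `CPU`/`GPU` keys replace the defaults, and any remaining keys are forwarded as additional resources. The values and the `custom_resource` name below are hypothetical:

# Illustrative values only.
trainer = Trainer(
    backend="torch",
    num_workers=4,
    use_gpu=True,
    resources_per_worker={"CPU": 2, "GPU": 1, "custom_resource": 0.5},
)
# After __init__, each worker reserves 2 CPUs and 1 GPU, and the remaining
# {"custom_resource": 0.5} is passed through as an additional per-worker resource.
# Note that the dict is mutated in place (CPU/GPU keys are popped).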
Example #3
def test_shutdown(ray_start_2_cpus, tmp_path):
    config = TestConfig()
    e = BackendExecutor(config, num_workers=2)
    e.start()
    assert len(e.worker_group) == 2
    e.shutdown()
    with pytest.raises(InactiveWorkerGroupError):
        e.start_training(lambda: 1, run_dir=tmp_path)
Example #4
def test_initialization_hook(ray_start_2_cpus):
    config = TestConfig()
    e = BackendExecutor(config, num_workers=2)

    def init_hook():
        import os
        os.environ["TEST"] = "1"

    e.start(initialization_hook=init_hook)

    def check():
        import os
        return os.getenv("TEST", "0")

    assert e.run(check) == ["1", "1"]
Example #5
def test_torch_start_shutdown(ray_start_2_cpus, init_method, tmp_path):
    torch_config = TorchConfig(backend="gloo", init_method=init_method)
    e = BackendExecutor(torch_config, num_workers=2)
    e.start()

    def check_process_group():
        import torch
        return (torch.distributed.is_initialized()
                and torch.distributed.get_world_size() == 2)

    e.start_training(check_process_group, run_dir=tmp_path)
    assert all(e.finish_training())

    e._backend.on_shutdown(e.worker_group, e._backend_config)

    e.start_training(check_process_group, run_dir=tmp_path)
    assert not any(e.finish_training())
Example #6
def test_train(ray_start_2_cpus, tmp_path):
    config = TestConfig()
    e = BackendExecutor(config, num_workers=2)
    e.start()

    e.start_training(lambda: 1, run_dir=tmp_path)
    assert e.finish_training() == [1, 1]
Example #7
File: test_backend.py  Project: hngenc/ray
def test_checkpoint(ray_start_2_cpus):
    def train():
        for i in range(2):
            sgd.save_checkpoint(epoch=i)

    config = TestConfig()
    e = BackendExecutor(config, num_workers=1)
    e.start()

    e.start_training(train)
    e.finish_training()

    latest_checkpoint = e.get_latest_checkpoint()
    assert latest_checkpoint is not None
    assert latest_checkpoint["epoch"] == 1
Example #8
def test_mismatch_checkpoint_report(ray_start_2_cpus, tmp_path):
    def train():
        if (sgd.world_rank()) == 0:
            sgd.save_checkpoint(epoch=0)
        else:
            sgd.report(iter=0)

    config = TestConfig()
    e = BackendExecutor(config, num_workers=2)
    e.start()
    e.start_training(train, run_dir=tmp_path)
    with pytest.raises(RuntimeError):
        e.finish_training()
Example #9
def test_worker_failure(ray_start_2_cpus, tmp_path):
    config = TestConfig()
    e = BackendExecutor(config, num_workers=2)
    e.start()

    def train_fail():
        ray.actor.exit_actor()

    new_execute_func = gen_execute_special(train_fail)
    with patch.object(WorkerGroup, "execute_async", new_execute_func):
        with pytest.raises(TrainingWorkerError):
            e.start_training(lambda: 1, run_dir=tmp_path)
            e.finish_training()
Example #10
def test_local_ranks(ray_start_2_cpus, tmp_path):
    config = TestConfig()
    e = BackendExecutor(config, num_workers=2)
    e.start()

    def train():
        return sgd.local_rank()

    e.start_training(train, run_dir=tmp_path)
    assert set(e.finish_training()) == {0, 1}
Example #11
def test_checkpoint(ray_start_2_cpus, tmp_path):
    def train():
        for i in range(2):
            sgd.save_checkpoint(epoch=i)

    config = TestConfig()
    e = BackendExecutor(config, num_workers=1)
    e.start()

    e.start_training(train, run_dir=tmp_path)
    e.finish_training()

    assert e.latest_checkpoint is not None
    assert e.latest_checkpoint["epoch"] == 1
Example #12
File: test_backend.py  Project: hngenc/ray
def test_start(ray_start_2_cpus):
    config = TestConfig()
    e = BackendExecutor(config, num_workers=2)
    with pytest.raises(InactiveWorkerGroupError):
        e.start_training(lambda: 1)
    e.start()
    assert len(e.worker_group) == 2
Example #13
def test_persisted_checkpoint_id(ray_start_2_cpus, tmp_path):
    def train():
        for i in range(2):
            sgd.save_checkpoint(epoch=i)

    config = TestConfig()
    e = BackendExecutor(config)
    e.start()
    e.start_training(train, run_dir=tmp_path, latest_checkpoint_id=100)
    e.finish_training()

    assert e.latest_checkpoint_id == 102
    assert e.latest_checkpoint is not None
    assert e.latest_checkpoint["epoch"] == 1
    assert e.latest_checkpoint_path is not None

    assert os.path.exists(e.latest_checkpoint_path)
Example #14
def test_no_exhaust(ray_start_2_cpus, tmp_path):
    """Tests if training can finish even if queue is not exhausted."""
    def train():
        for _ in range(2):
            sgd.report(loss=1)
        return 2

    config = TestConfig()
    e = BackendExecutor(config, num_workers=2)
    e.start()

    e.start_training(train, run_dir=tmp_path)
    output = e.finish_training()

    assert output == [2, 2]
Example #15
def test_execute_worker_failure(ray_start_2_cpus):
    config = TestConfig()
    e = BackendExecutor(config, num_workers=2)
    e.start()

    def train_fail():
        ray.actor.exit_actor()

    new_execute_func = gen_execute_special(train_fail)
    with patch.object(WorkerGroup, "execute_async", new_execute_func):
        with pytest.raises(RuntimeError):
            e.run(lambda: 1)
Example #16
def test_initialization_hook(ray_start_2_cpus, tmp_path):
    config = TestConfig()
    e = BackendExecutor(config, num_workers=2)

    def init_hook():
        import os
        os.environ["TEST"] = "1"

    e.start(initialization_hook=init_hook)

    def check():
        import os
        return os.getenv("TEST", "0")

    e.start_training(check, run_dir=tmp_path)
    assert e.finish_training() == ["1", "1"]
Example #17
def test_cuda_visible_devices(ray_2_node_2_gpu, worker_results, tmp_path):
    config = TestConfig()

    def get_resources():
        return os.environ["CUDA_VISIBLE_DEVICES"]

    num_workers, expected_results = worker_results

    e = BackendExecutor(config,
                        num_workers=num_workers,
                        num_cpus_per_worker=0,
                        num_gpus_per_worker=1)
    e.start()
    e.start_training(get_resources, tmp_path)
    results = e.finish_training()
    results.sort()
    assert results == expected_results
Example #18
def test_tensorflow_start(ray_start_2_cpus, tmp_path):
    num_workers = 2
    tensorflow_config = TensorflowConfig()
    e = BackendExecutor(tensorflow_config, num_workers=num_workers)
    e.start()

    def get_tf_config():
        import json
        import os
        return json.loads(os.environ["TF_CONFIG"])

    e.start_training(get_tf_config, run_dir=tmp_path)
    results = e.finish_training()
    assert len(results) == num_workers

    workers = [result["cluster"]["worker"] for result in results]
    assert all(worker == workers[0] for worker in workers)

    indexes = [result["task"]["index"] for result in results]
    assert len(set(indexes)) == num_workers
Example #19
def test_train_failure(ray_start_2_cpus, tmp_path):
    config = TestConfig()
    e = BackendExecutor(config, num_workers=2)
    e.start()

    with pytest.raises(SGDBackendError):
        e.fetch_next_result()

    with pytest.raises(SGDBackendError):
        e.finish_training()

    e.start_training(lambda: 1, run_dir=tmp_path)

    with pytest.raises(SGDBackendError):
        e.start_training(lambda: 2, run_dir=tmp_path)

    assert e.finish_training() == [1, 1]
Example #20
def test_execute(ray_start_2_cpus):
    config = TestConfig()
    e = BackendExecutor(config, num_workers=2)
    e.start()

    assert e.run(lambda: 1) == [1, 1]
Example #21
File: trainer.py  Project: haochihlin/ray
class Trainer:
    """A class for enabling seamless distributed deep learning.

    Directory structure:
    - A logdir is created during instantiation. This will hold all the
    results/checkpoints for the lifetime of the Trainer. By default, it will be
    of the form ``~/ray_results/sgd_<datestring>``.
    - A run_dir is created for each ``run`` call. This will
    hold the checkpoints and results for a single ``trainer.run()`` or
    ``trainer.run_iterator()`` call. It will be of the form ``run_<run_id>``.

    Args:
        backend (Union[str, BackendConfig]): The backend used for
            distributed communication. If configurations are needed,
            a subclass of ``BackendConfig`` can be passed in.
            Supported ``str`` values: {"torch"}.
        num_workers (int): The number of workers (Ray actors) to launch.
            Defaults to 1. Each worker will reserve 1 CPU by default. The
            number of CPUs reserved by each worker can be overridden with the
            ``resources_per_worker`` argument.
        use_gpu (bool): If True, training will be done on GPUs (1 per
            worker). Defaults to False. The number of GPUs reserved by each
            worker can be overridden with the ``resources_per_worker``
            argument.
        resources_per_worker (Optional[Dict]): If specified, the resources
            defined in this Dict will be reserved for each worker. The
            ``CPU`` and ``GPU`` keys (case-sensitive) can be defined to
            override the number of CPU/GPUs used by each worker.
        logdir (Optional[str]): Path to the file directory where logs
            should be persisted. If this is not specified, one will be
            generated.
        max_retries (int): Number of retries when Ray actors fail.
            Defaults to 3. Set to -1 for unlimited retries.
    """
    def __init__(
        self,
        backend: Union[str, BackendConfig],
        num_workers: int = 1,
        use_gpu: bool = False,
        resources_per_worker: Optional[Dict[str, float]] = None,
        logdir: Optional[str] = None,
        max_retries: int = 3,
    ):

        self._backend = backend
        self._num_workers = num_workers
        self._use_gpu = use_gpu
        self._resources_per_worker = resources_per_worker

        # Incremental unique run ID.
        self._run_id = 0

        self.logdir = self.create_logdir(logdir)

        # Setup executor.
        backend_config = self._get_backend_config(backend)

        num_cpus = 1
        num_gpus = int(use_gpu)

        if resources_per_worker:
            # Override CPU and GPU resources and remove from dict.
            num_cpus = resources_per_worker.pop("CPU", num_cpus)
            num_gpus = resources_per_worker.pop("GPU", num_gpus)
            if not use_gpu and num_gpus > 0:
                raise ValueError(
                    "`use_gpu` is False but `GPU` was found in "
                    "`resources_per_worker`. Either set `use_gpu` to True or "
                    "remove `GPU` from `resources_per_worker.")
            if use_gpu and num_gpus == 0:
                raise ValueError(
                    "`use_gpu` is True but `GPU` is set to 0 in "
                    "`resources_per_worker`. Either set `use_gpu` to False or "
                    "request a positive number of `GPU` in "
                    "`resources_per_worker.")

        self._executor = BackendExecutor(
            backend_config=backend_config,
            num_workers=num_workers,
            num_cpus_per_worker=num_cpus,
            num_gpus_per_worker=num_gpus,
            additional_resources_per_worker=resources_per_worker,
            max_retries=max_retries)

    def create_logdir(self, log_dir: Optional[Union[str, Path]]) -> Path:
        """Create logdir for the Trainer."""
        # Create directory for logs.
        log_dir = Path(log_dir) if log_dir else None
        if not log_dir:
            # Initialize timestamp for identifying this SGD training execution.
            timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
            log_dir = Path(f"sgd_{timestr}")
        log_dir = construct_path(log_dir, DEFAULT_RESULTS_DIR)
        log_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Trainer logs will be logged in: {log_dir}")
        return log_dir

    def create_run_dir(self):
        """Create rundir for the particular training run."""
        self.latest_run_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Run results will be logged in: {self.latest_run_dir}")

    def _get_backend_config(
            self, backend: Union[str, BackendConfig]) -> BackendConfig:
        """Gets the ``BackendConfig`` to use for training.

        Args:
            backend (Union[str, BackendConfig]): If a ``BackendConfig`` is
                passed in, then it will also be returned. If a ``str`` is
                passed in, then the default config for that backend will be
                returned.

        Returns:
            The ``BackendConfig`` that will be used to set up the
            ``BackendExecutor``.
        """

        if isinstance(backend, BackendConfig):
            return backend
        elif isinstance(backend, str):
            try:
                return BACKEND_NAME_TO_CONFIG_CLS[backend]()
            except KeyError:
                raise ValueError(f"Invalid backend: {backend}. "
                                 f"Supported string values are: "
                                 f"{BACKEND_NAME_TO_CONFIG_CLS.keys()}")
        else:
            raise TypeError(f"Invalid type for backend: {type(backend)}.")

    def start(self,
              initialization_hook: Optional[Callable[[], None]] = None,
              train_cls: Optional[S] = None,
              *args,
              **kwargs):
        """Starts the training execution service.

        Args:
            initialization_hook (Optional[Callable]): The function to call on
                each worker when it is instantiated.
            train_cls (Optional[cls]): The training class that each worker
                should be instantiated as.
            args, kwargs: The arguments to pass into ``train_cls.__init__``.
        """
        self._executor.start(initialization_hook)

    def run(self,
            train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
            config: Optional[Dict[str, Any]] = None,
            callbacks: Optional[List[SGDCallback]] = None,
            checkpoint: Optional[Union[Dict, str, Path]] = None,
            checkpoint_strategy: Optional[CheckpointStrategy] = None
            ) -> List[T]:
        """Runs a training function in a distributed manner.

        Args:
            train_func (Callable): The training function to execute.
                This can either take in no arguments or a ``config`` dict.
            config (Optional[Dict]): Configurations to pass into
                ``train_func``. If None then an empty Dict will be created.
            callbacks (Optional[List[SGDCallback]]): A list of Callbacks which
                will be executed during training. If this is not set,
                currently there are NO default Callbacks.
            checkpoint (Optional[Dict|str|Path]): The checkpoint data that
                should be loaded onto each worker and accessed by the training
                function via ``sgd.load_checkpoint()``. If this is a ``str`` or
                ``Path`` then the value is expected to be a path to a file
                that contains a serialized checkpoint dict. If this is
                ``None`` then no checkpoint will be loaded.
            checkpoint_strategy (Optional[CheckpointStrategy]): The
                configurations for saving checkpoints.

        Returns:
            A list of results from the training function. Each value in the
            list corresponds to the output of the training function from
            each worker.
        """
        # Create new log directory for this run.
        self._run_id += 1
        self.create_run_dir()

        # TODO(matt): Set default callbacks.
        callbacks = [] if callbacks is None else callbacks
        finished_with_errors = False

        for callback in callbacks:
            callback.start_training(logdir=self.latest_run_dir)

        train_func = self._get_train_func(train_func, config)

        try:
            iterator = SGDIterator(
                backend_executor=self._executor,
                train_func=train_func,
                checkpoint=checkpoint,
                checkpoint_strategy=checkpoint_strategy,
                run_dir=self.latest_run_dir,
            )
            for intermediate_result in iterator:
                for callback in callbacks:
                    callback.handle_result(intermediate_result)

            assert iterator.is_finished()
            return iterator.get_final_results()
        finally:
            for callback in callbacks:
                callback.finish_training(error=finished_with_errors)

    def run_iterator(
        self,
        train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
        config: Optional[Dict[str, Any]] = None,
        checkpoint: Optional[Union[Dict, str, Path]] = None,
        checkpoint_strategy: Optional[CheckpointStrategy] = None
    ) -> Iterator[List[Dict]]:
        """Same as ``run`` except returns an iterator over the results.

        This is useful if you want to have more customization of what to do
        with the intermediate results or how to use the ``Trainer`` with Ray
        Tune.

        .. code-block:: python

            def train_func(config):
                ...
                for _ in config["epochs"]:
                    metrics = train()
                    metrics = validate(...)
                    ray.sgd.report(**metrics)
                return model

            iterator = trainer.run_iterator(train_func, config=config)

            for result in iterator:
                do_stuff(result)
                latest_ckpt = trainer.get_latest_checkpoint()

            assert iterator.is_finished()
            model = iterator.get_final_results()[0]

        Args:
            train_func (Callable): The training function to execute.
                This can either take in no arguments or a ``config`` dict.
            config (Optional[Dict]): Configurations to pass into
                ``train_func``. If None then an empty Dict will be created.
            checkpoint (Optional[Dict|Path|str]): The checkpoint data that
                should be loaded onto each worker and accessed by the
                training function via ``sgd.load_checkpoint()``. If this is a
                ``str`` or ``Path`` then the value is expected to be a path
                to a file that contains a serialized checkpoint dict. If this
                is ``None`` then no checkpoint will be loaded.
            checkpoint_strategy (Optional[CheckpointStrategy]): The
                configurations for saving checkpoints.

        Returns:
            An Iterator over the intermediate results from ``sgd.report()``.
        """
        # Create new log directory for this run.
        self._run_id += 1
        self.create_run_dir()

        train_func = self._get_train_func(train_func, config)

        return SGDIterator(
            backend_executor=self._executor,
            train_func=train_func,
            checkpoint=checkpoint,
            checkpoint_strategy=checkpoint_strategy,
            run_dir=self.latest_run_dir,
        )

    def _get_train_func(self, train_func: Union[Callable[[], T],
                                                Callable[[Dict[str, Any]], T]],
                        config: Optional[Dict[str, Any]]) -> Callable[[], T]:
        """Validates and constructs the training function to execute.

        Args:
            train_func (Callable): The training function to execute.
                This can either take in no arguments or a ``config`` dict.
            config (Optional[Dict]): Configurations to pass into
                ``train_func``. If None then an empty Dict will be created.

        Returns:
            A valid training function.

        Raises:
            ValueError: if the input ``train_func`` is invalid.
        """
        signature = inspect.signature(train_func)
        num_params = len(signature.parameters)
        if num_params > 1:
            raise ValueError("train_func should take in a 0 or 1 arguments.")
        elif num_params == 1:
            config = {} if config is None else config
            return lambda: train_func(config)
        else:  # num_params == 0
            return train_func

    def execute(self, func: Callable[..., T], *args, **kwargs) -> List[T]:
        """Executes a function for all instances of ``self.train_cls``.

        Args:
            func (Callable): The function that should be executed.
                The first argument should be an instance of
                ``self.train_cls``.
            args, kwargs: The arguments to pass into ``func``.

        Returns:
            A list of results from ``func``. Each value in the
            list corresponds to the output of ``func`` from
            each worker.
        """
        raise NotImplementedError

    def execute_single(self, func: Callable[..., T], *args, **kwargs) -> T:
        """Executes a function on a single instance of ``self.train_cls``.

        Args:
            func (Callable): The function that should be executed.
                The first argument should be an instance of
                ``self.train_cls``.
            args, kwargs: The arguments to pass into ``func``.

        Returns:
            The output of ``func`` from a single worker.
        """
        raise NotImplementedError

    @property
    def latest_run_dir(self) -> Optional[Path]:
        """Path to the log directory for the latest call to ``run()``.

        Returns ``None`` if ``run()`` has not been called.
        """
        if self._run_id > 0:
            run_dir = Path(f"run_{self._run_id:03d}")
            return construct_path(run_dir, self.logdir)
        else:
            return None

    @property
    def latest_checkpoint_dir(self) -> Optional[Path]:
        """Path to the checkpoint directory.

        Returns ``None`` if ``run()`` has not been called or if
        ``sgd.checkpoint()`` has not been called from ``train_func`` within
        the most recent call to ``run``.
        """
        return self._executor.latest_checkpoint_dir

    @property
    def latest_checkpoint_path(self) -> Optional[Path]:
        """Path to the latest persisted checkpoint from the latest run.

        Returns ``None`` if ``run()`` has not been called or if
        ``sgd.checkpoint()`` has not been called from ``train_func`` within
        the most recent call to ``run``.
        """
        return self._executor.latest_checkpoint_path

    @property
    def latest_checkpoint(self) -> Optional[Dict]:
        """The latest saved checkpoint.

        This checkpoint may not be saved to disk.

        Returns ``None`` if ``run()`` has not been called or if
        ``sgd.checkpoint()`` has not been called from ``train_func``.
        """
        return self._executor.latest_checkpoint

    def shutdown(self):
        """Shuts down the training execution service."""
        self._executor.shutdown()

    def to_tune_trainable(
            self, train_func: Callable[[Dict[str, Any]],
                                       T]) -> Type[Trainable]:
        """Creates a Tune ``Trainable`` from the input training function.

        Args:
            func (Callable): The function that should be executed on each
                training worker.

        Returns:
            A Trainable that can directly be passed into ``tune.run()``.
        """
        if not TUNE_INSTALLED:
            raise ValueError("Tune is not installed. Please install ray["
                             "tune] to use the Tune integration.")

        if self._executor.is_started:
            raise RuntimeError("The Trainer must not be active to use "
                               "`to_tune_trainable`. Either shutdown the "
                               "Trainer or don't start it in the first place.")

        return _create_tune_trainable(train_func, self._backend,
                                      self._num_workers, self._use_gpu,
                                      self._resources_per_worker)
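Putting this class together end to end, a minimal driver sketch (assuming Ray is initialized, the "torch" backend string is registered in `BACKEND_NAME_TO_CONFIG_CLS`, and the names below are placeholders) might look like:

# Illustrative only; not taken from the Ray source.
def train_func(config):
    for epoch in range(config["epochs"]):
        sgd.report(epoch=epoch, loss=1.0 / (epoch + 1))
    return "done"

trainer = Trainer(backend="torch", num_workers=2)
trainer.start()
results = trainer.run(train_func, config={"epochs": 3})
print(results)                     # one return value per worker
print(trainer.latest_checkpoint)   # None in this sketch: no checkpoint was saved
trainer.shutdown()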
Example #22
File: trainer.py  Project: hngenc/ray
class Trainer:
    """A class for enabling seamless distributed deep learning.

    Args:
        backend (Union[str, BackendConfig]): The backend used for
            distributed communication. If configurations are needed,
            a subclass of ``BackendConfig`` can be passed in.
            Supported ``str`` values: {"torch"}.
        num_workers (int): The number of workers (Ray actors) to launch.
            Defaults to 1. Each worker will reserve 1 CPU by default.
        use_gpu (bool): If True, training will be done on GPUs (1 per
            worker). Defaults to False.
        resources_per_worker (Optional[Dict]): If specified, the resources
            defined in this Dict will be reserved for each worker.
    """

    def __init__(self,
                 backend: Union[str, BackendConfig],
                 num_workers: int = 1,
                 use_gpu: bool = False,
                 resources_per_worker: Optional[Dict[str, float]] = None):
        """A class for distributed training.

        Args:
            backend (Union[str, BackendConfig]): The backend used for
                distributed communication. If configurations are needed,
                a subclass of ``BackendConfig`` can be passed in.
                Supported ``str`` values: {"torch"}.
            num_workers (int): The number of workers (Ray actors) to launch.
                Defaults to 1. Each worker will reserve 1 CPU by default.
            use_gpu (bool): If True, training will be done on GPUs (1 per
                worker). Defaults to False.
            resources_per_worker (Optional[Dict]): If specified, the resources
                defined in this Dict will be reserved for each worker.
        """
        # Setup executor.
        backend_config = self._get_backend_config(backend)

        if resources_per_worker:
            raise NotImplementedError("`resources_per_worker` argument is not "
                                      "supported yet.")

        self._executor = BackendExecutor(backend_config, num_workers, 1,
                                         int(use_gpu))

    def _get_backend_config(
            self, backend: Union[str, BackendConfig]) -> BackendConfig:
        """Gets the ``BackendConfig`` to use for training.

        Args:
            backend (Union[str, BackendConfig]): If a ``BackendConfig`` is
                passed in, then it will also be returned. If a ``str`` is
                passed in, then the default config for that backend will be
                returned.

        Returns:
            The ``BackendConfig`` that will be used to set up the
            ``BackendExecutor``.
        """

        if isinstance(backend, BackendConfig):
            return backend
        elif isinstance(backend, str):
            try:
                return BACKEND_NAME_TO_CONFIG_CLS[backend]()
            except KeyError:
                raise ValueError(f"Invalid backend: {backend}. "
                                 f"Supported string values are: "
                                 f"{BACKEND_NAME_TO_CONFIG_CLS.keys()}")
        else:
            raise TypeError(f"Invalid type for backend: {type(backend)}.")

    def start(self,
              initialization_hook: Optional[Callable[[], None]] = None,
              train_cls: Optional[S] = None,
              *args,
              **kwargs):
        """Starts the training execution service.

        Args:
            initialization_hook (Optional[Callable]): The function to call on
                each worker when it is instantiated.
            train_cls (Optional[cls]): The training class that each worker
                should be instantiated as.
            args, kwargs: The arguments to pass into ``train_cls.__init__``.
        """
        self._executor.start(initialization_hook)

    def run(self,
            train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
            config: Optional[Dict[str, Any]] = None,
            callbacks: Optional[List[SGDCallback]] = None,
            checkpoint: Optional[Dict] = None) -> List[T]:
        """Runs a training function in a distributed manner.

        Args:
            train_func (Callable): The training function to execute.
                This can either take in no arguments or a ``config`` dict.
            config (Optional[Dict]): Configurations to pass into
                ``train_func``. If None then an empty Dict will be created.
            callbacks (Optional[List[SGDCallback]]): A list of Callbacks which
                will be executed during training. If this is not set,
                currently there are NO default Callbacks.
            checkpoint (Optional[Dict]): The checkpoint data that should be
                loaded onto each worker and accessed by the training function
                via ``sgd.load_checkpoint()``.

        Returns:
            A list of results from the training function. Each value in the
            list corresponds to the output of the training function from
            each worker.
        """
        train_func = self._get_train_func(train_func, config)
        # TODO(matt): Set default callbacks.
        callbacks = [] if callbacks is None else callbacks
        finished_with_errors = False

        try:
            for callback in callbacks:
                callback.start_training()
            self._executor.start_training(train_func, checkpoint)

            while True:
                intermediate_results = self._executor.fetch_next_result()
                if intermediate_results is None:
                    break
                else:
                    for callback in callbacks:
                        callback.handle_result(intermediate_results)

            return self._executor.finish_training()
        except InactiveWorkerGroupError:
            finished_with_errors = True
            raise RuntimeError(
                "This Trainer is not active. It is either shutdown already or "
                "never started in the first place. Either create a new "
                "Trainer or start this one.") from None
        except SGDBackendError:
            finished_with_errors = True
            raise RuntimeError("Training failed. You should not be seeing "
                               "this error and this is a bug. Please create "
                               "a new issue at "
                               "https://github.com/ray-project/ray.") from None
        finally:
            for callback in callbacks:
                callback.finish_training(error=finished_with_errors)

    def _get_train_func(
            self,
            train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
            config: Optional[Dict[str, Any]]) -> Callable[[], T]:
        """Validates and constructs the training function to execute.

        Args:
            train_func (Callable): The training function to execute.
                This can either take in no arguments or a ``config`` dict.
            config (Optional[Dict]): Configurations to pass into
                ``train_func``. If None then an empty Dict will be created.

        Returns:
            A valid training function.

        Raises:
            ValueError: if the input ``train_func`` is invalid.
        """
        signature = inspect.signature(train_func)
        num_params = len(signature.parameters)
        if num_params > 1:
            raise ValueError("train_func should take in a 0 or 1 arguments.")
        elif num_params == 1:
            config = {} if config is None else config
            return lambda: train_func(config)
        else:  # num_params == 0
            return train_func

    def execute(self, func: Callable[..., T], *args, **kwargs) -> List[T]:
        """Executes a function for all instances of ``self.train_cls``.

        Args:
            func (Callable): The function that should be executed.
                The first argument should be an instance of
                ``self.train_cls``.
            args, kwargs: The arguments to pass into ``func``.

        Returns:
            A list of results from ``func``. Each value in the
            list corresponds to the output of ``func`` from
            each worker.
        """
        raise NotImplementedError

    def execute_single(self, func: Callable[..., T], *args, **kwargs) -> T:
        """Executes a function on a single instance of ``self.train_cls``.

        Args:
            func (Callable): The function that should be executed.
                The first argument should be an instance of
                ``self.train_cls``.
            args, kwargs: The arguments to pass into ``func``.

        Returns:
            The output of ``func`` from a single worker.
        """
        raise NotImplementedError

    def get_latest_checkpoint(self) -> Optional[Dict]:
        """Gets the latest checkpoint for this Trainer."""
        return self._executor.get_latest_checkpoint()

    def shutdown(self):
        """Shuts down the training execution service."""
        self._executor.shutdown()

    def to_tune_trainable(
            self, train_func: Callable[[Dict[str, Any]], T]) -> Trainable:
        """Creates a Tune ``Trainable`` from the input training function.

        Args:
            func (Callable): The function that should be executed on each
                training worker.

        Returns:
            A Trainable that can directly be passed into ``tune.run()``.
        """

        def trainable_func(config: Dict[str, Any]) -> T:
            pass

        raise NotImplementedError
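For this older run() signature, where `checkpoint` is an in-memory dict made available to workers via `sgd.load_checkpoint()`, a hedged resume sketch (all names and values below are illustrative, not taken from the Ray source) could be:

# Illustrative only.
def train_func(config):
    checkpoint = sgd.load_checkpoint() or {}
    start_epoch = checkpoint.get("epoch", 0)
    for epoch in range(start_epoch, config["epochs"]):
        sgd.report(epoch=epoch)

trainer = Trainer(backend="torch", num_workers=2)
trainer.start()
results = trainer.run(train_func, config={"epochs": 5}, checkpoint={"epoch": 2})
trainer.shutdown()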