def training_loop(self) -> None:
        """Run distributed training for the current Tune trial.

        Validates the scaling config and wraps ``train_loop_per_worker`` with
        its config. It then starts a ``BackendExecutor`` sized by the scaling
        config and drives a ``TrainingIterator`` over the workers. Each round
        of results from the first worker is forwarded to ``tune.report``.

        The worker actors are shut down when training ends, whether it
        finishes normally or raises (e.g. from the training function, the
        iterator, or ``tune.report``).
        """
        scaling_config_dataclass = self._validate_and_get_scaling_config_data_class(
            self.scaling_config
        )

        train_loop_per_worker = construct_train_func(
            self._train_loop_per_worker,
            self._train_loop_config,
            fn_arg_name="train_loop_per_worker",
        )

        additional_resources_per_worker = (
            scaling_config_dataclass.additional_resources_per_worker
        )

        # Describe the current Tune trial (name, id, resources, working dir)
        # to the backend executor.
        trial_info = TrialInfo(
            name=session.get_trial_name(),
            id=session.get_trial_id(),
            resources=session.get_trial_resources(),
            logdir=os.getcwd(),
        )

        backend_executor = BackendExecutor(
            backend_config=self._backend_config,
            trial_info=trial_info,
            num_workers=scaling_config_dataclass.num_workers,
            num_cpus_per_worker=scaling_config_dataclass.num_cpus_per_worker,
            num_gpus_per_worker=scaling_config_dataclass.num_gpus_per_worker,
            additional_resources_per_worker=additional_resources_per_worker,
            # NOTE(review): presumably disables executor-level retries so that
            # failures surface to Tune instead — confirm against BackendExecutor.
            max_retries=0,
        )

        checkpoint_manager = self._checkpoint_manager_cls(
            preprocessor=self.preprocessor
        )

        # Start the remote actors.
        backend_executor.start(initialization_hook=None)

        # Once the actors are started they must be released on every exit
        # path, otherwise an exception here would leak the worker group.
        try:
            training_iterator = TrainingIterator(
                backend_executor=backend_executor,
                backend_config=self._backend_config,
                train_func=train_loop_per_worker,
                dataset_spec=self._ingest_spec,
                checkpoint_manager=checkpoint_manager,
                checkpoint=self.resume_from_checkpoint,
                checkpoint_strategy=None,
            )

            for results in training_iterator:
                # TODO(ml-team): add ability to report results from multiple workers.
                first_worker_results = results[0]

                tune.report(**first_worker_results)
        finally:
            # Shutdown workers.
            backend_executor.shutdown()
Beispiel #2
0
    def run_iterator(
        self,
        train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
        config: Optional[Dict[str, Any]] = None,
        dataset: Optional[Union[RayDataset, Dict[str, RayDataset]]] = None,
        checkpoint: Optional[Union[Dict, str, Path]] = None,
        checkpoint_strategy: Optional[CheckpointConfig] = None,
    ) -> "TrainingIterator":
        """Same as ``run`` except returns an iterator over the results.

        This is useful if you want to have more customization of what to do
        with the intermediate results or how to use the ``Trainer`` with Ray
        Tune.

        .. code-block:: python

            def train_func(config):
                ...
                for _ in range(config["epochs"]):
                    metrics = train()
                    metrics = validate(...)
                    ray.train.report(**metrics)
                return model

            iterator = trainer.run_iterator(train_func, config=config)

            for result in iterator:
                do_stuff(result)
                latest_ckpt = trainer.get_latest_checkpoint()

            assert iterator.is_finished()
            model = iterator.get_final_results()[0]

        Args:
            train_func: The training function to execute.
                This can either take in no arguments or a ``config`` dict.
            config (Optional[Dict]): Configurations to pass into
                ``train_func``. If None then an empty Dict will be created.
            dataset (Optional[Union[RayDataset, Dict[str, RayDataset]]]):
                Distributed Ray Dataset or DatasetPipeline to pass into the
                workers, which can be accessed from the training function via
                ``train.get_dataset_shard()``. Multiple Datasets can be passed
                in as a ``Dict`` that maps each name key to a Dataset value;
                each one is then accessed by passing its name as the
                ``dataset_name`` argument to ``train.get_dataset_shard()``.
            checkpoint (Optional[Dict|Path|str]): The checkpoint data that
                should be loaded onto each worker and accessed by the
                training function via ``train.load_checkpoint()``. If this is a
                ``str`` or ``Path`` then the value is expected to be a path
                to a file that contains a serialized checkpoint dict. If this
                is ``None`` then no checkpoint will be loaded.
            checkpoint_strategy (Optional[CheckpointConfig]): The
                configurations for saving checkpoints.

        Returns:
            An Iterator over the intermediate results from ``train.report()``.
            Call ``get_final_results()`` on it after iteration completes to
            obtain each worker's return value.
        """
        # Create new log directory for this run.
        self._run_id += 1
        self.create_run_dir()

        # Wrap the user function so it can be invoked with or without `config`.
        train_func = construct_train_func(train_func, config)

        dataset_spec = RayDatasetSpec(dataset_or_dict=dataset)

        return TrainingIterator(
            backend_executor=self._backend_executor,
            backend_config=self._backend_config,
            train_func=train_func,
            run_dir=self.latest_run_dir,
            dataset_spec=dataset_spec,
            checkpoint_manager=self.checkpoint_manager,
            checkpoint=checkpoint,
            checkpoint_strategy=checkpoint_strategy,
        )
Beispiel #3
0
    def run(
        self,
        train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
        config: Optional[Dict[str, Any]] = None,
        callbacks: Optional[List[TrainingCallback]] = None,
        dataset: Optional[Union[RayDataset, Dict[str, RayDataset]]] = None,
        checkpoint: Optional[Union[Dict, str, Path]] = None,
        checkpoint_strategy: Optional[CheckpointConfig] = None,
    ) -> List[T]:
        """Runs a training function in a distributed manner.

        Args:
            train_func: The training function to execute.
                This can either take in no arguments or a ``config`` dict.
            config (Optional[Dict]): Configurations to pass into
                ``train_func``. If None then an empty Dict will be created.
            callbacks (Optional[List[TrainingCallback]]): A list of Callbacks
                which will be executed during training. If this is not set,
                currently there are NO default Callbacks. Each callback's
                ``finish_training`` is called with ``error=True`` if training
                raised and ``error=False`` if it completed successfully.
            dataset (Optional[Union[RayDataset, Dict[str, RayDataset]]]):
                Distributed Ray :ref:`Dataset <dataset-api>` or
                :ref:`DatasetPipeline <dataset-pipeline-api>` to pass into the
                workers, which can be accessed from the training function via
                ``train.get_dataset_shard()``. Sharding will automatically be
                handled by the Trainer. Multiple Datasets can be passed in as
                a ``Dict`` that maps each name key to a Dataset value,
                and each Dataset can be accessed from the training function
                by passing in a `dataset_name` argument to
                ``train.get_dataset_shard()``.
            checkpoint (Optional[Dict|str|Path]): The checkpoint data that
                should be loaded onto each worker and accessed by the training
                function via ``train.load_checkpoint()``. If this is a ``str``
                or ``Path`` then the value is expected to be a path to a file
                that contains a serialized checkpoint dict. If this is
                ``None`` then no checkpoint will be loaded.
            checkpoint_strategy (Optional[CheckpointConfig]): The
                configurations for saving checkpoints.

        Returns:
            A list of results from the training function. Each value in the
            list corresponds to the output of the training function from
            each worker.
        """
        # Create new log directory for this run.
        self._run_id += 1
        self.create_run_dir()

        # TODO(matt): Set default callbacks.
        callbacks = [] if callbacks is None else callbacks

        # Prepare the worker function and dataset spec before any callback is
        # started. If this raises, no callback is left started-but-unfinished.
        train_func = construct_train_func(train_func, config)
        dataset_spec = RayDatasetSpec(dataset_or_dict=dataset)

        for callback in callbacks:
            callback.start_training(
                logdir=str(self.latest_run_dir), config=config or {}
            )

        # Assume failure until the final results are in hand. Any exception
        # (including KeyboardInterrupt or the assert below) is then reported
        # to the callbacks as an error; previously this flag was never set.
        finished_with_errors = True
        try:
            iterator = TrainingIterator(
                backend_executor=self._backend_executor,
                backend_config=self._backend_config,
                train_func=train_func,
                dataset_spec=dataset_spec,
                checkpoint_manager=self.checkpoint_manager,
                checkpoint=checkpoint,
                checkpoint_strategy=checkpoint_strategy,
                run_dir=self.latest_run_dir,
            )
            for intermediate_result in iterator:
                for callback in callbacks:
                    callback.process_results(intermediate_result)

            assert iterator.is_finished()
            final_results = iterator.get_final_results()
            finished_with_errors = False
            return final_results
        finally:
            for callback in callbacks:
                callback.finish_training(error=finished_with_errors)