Example #1
    def train(self):
        """Main training loop.

        Trigger remote training via ``train_remote`` on each
        worker. If using with Ray Tune, create a communication queue to
        retrieve intermediate results, and process those results. Finally,
        retrieve the training results from the rank 0 worker and return."""
        trainer = self.trainer
        trainer_ref = ray.put(trainer)
        # Don't pickle self.trainer when training remotely.
        self.trainer = None

        queue = None
        if TUNE_INSTALLED and is_session_enabled():
            # Create communication queue and send to all the workers.
            queue = Queue(actor_options={"num_cpus": 0})

        result_futures = self.executor.run_async(self.train_remote,
                                                 args=[trainer_ref, queue])

        results = process_results(result_futures, queue)

        results, state_dict, best_path = results[0]

        self.trainer = trainer
        self.trainer.model.load_state_dict(state_dict)
        if self.trainer.checkpoint_callback:
            self.trainer.checkpoint_callback.best_model_path = best_path

        return results
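
This example hands the result futures and the optional Tune queue to ``process_results``. A minimal sketch of how such a helper could behave, assuming workers put plain metric dicts on the queue (an assumption for illustration, not the library's actual implementation), is:

import ray
from ray import tune

def process_results_sketch(result_futures, queue):
    """Wait for remote training to finish while forwarding intermediate metrics."""
    unfinished = list(result_futures)
    while unfinished:
        # Poll the remote training tasks without blocking indefinitely so
        # the communication queue can be drained in between.
        _, unfinished = ray.wait(unfinished, timeout=1)
        if queue is not None:
            while not queue.empty():
                # Assumed payload: a dict of metrics reported by a worker.
                metrics = queue.get_nowait()
                tune.report(**metrics)
    # Every future is done; return the per-worker results in rank order.
    return ray.get(result_futures)
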
Example #2
    def start_training(self, trainer):
        """Main training loop.

        Set up the torch.distributed process group for each training
        worker. Then trigger remote training via ``train_remote`` on each
        worker. If using with Ray Tune, create a communication queue to
        retrieve intermediate results, and process those results. Finally,
        retrieve the training results from the rank 0 worker and return."""

        # Get rank 0 worker address and port for DDP connection.
        os.environ["MASTER_ADDR"] = ray.get(
            self.workers[0].get_node_ip.remote())
        os.environ["MASTER_PORT"] = str(
            ray.get(self.workers[0].execute.remote(find_free_port)))

        # Set environment variables for remote workers.
        keys = [
            "PL_GLOBAL_SEED", "PL_TORCH_DISTRIBUTED_BACKEND", "MASTER_ADDR",
            "MASTER_PORT"
        ]
        values = [os.getenv(k) for k in keys]
        ray.get([w.set_env_vars.remote(keys, values) for w in self.workers])

        self.global_to_local = self.get_local_ranks()

        model = self._model
        model_ref = ray.put(model)
        # Don't pickle the model when training remotely.
        self._model = None

        queue = None
        if TUNE_INSTALLED and is_session_enabled():
            # Create communication queue and send to all the workers.
            queue = Queue(actor_options={"num_cpus": 0})

        futures = [
            self.workers[i].execute.remote(self.train_remote, model_ref, i,
                                           queue)
            for i in range(self.num_workers)
        ]

        results = process_results(futures, queue)
        # Get the results, checkpoint path, and model weights from worker 0.
        results, best_path, state_dict = results[0]
        # Set the state for PTL using the output from remote training.
        self._results = results
        self._model = model
        self._model.load_state_dict(state_dict)
        if self.lightning_module.trainer.checkpoint_callback:
            self.lightning_module.trainer.checkpoint_callback\
                .best_model_path = best_path

        if queue:
            # Shutdown the queue.
            queue.shutdown()

        return results
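
The DDP rendezvous above depends on a ``find_free_port`` helper and on worker actors that expose ``get_node_ip``, ``set_env_vars``, and ``execute``. The sketch below shows plausible versions of those pieces; the actor class and method bodies are illustrative assumptions, not the actual worker implementation.

import os
import socket

import ray

def find_free_port():
    # Bind to port 0 so the OS assigns an unused port, then release it so
    # torch.distributed can bind to it when the process group is created.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]

@ray.remote
class TrainingWorkerSketch:
    def get_node_ip(self):
        # IP of the node this actor runs on; rank 0's IP becomes MASTER_ADDR.
        return ray.util.get_node_ip_address()

    def set_env_vars(self, keys, values):
        # Mirror the driver's environment (seed, backend, rendezvous info)
        # inside the remote worker process.
        for key, value in zip(keys, values):
            if value is not None:
                os.environ[key] = value

    def execute(self, fn, *args):
        # Run an arbitrary function (e.g. ``find_free_port``) on this worker.
        return fn(*args)
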
Example #3
    def train(self):
        """Main training loop.

        Set up the torch.distributed process group for each training
        worker. Then trigger remote training via ``train_remote`` on each
        worker. If using with Ray Tune, create a communication queue to
        retrieve intermediate results, and process those results. Finally,
        retrieve the training results from the rank 0 worker and return."""

        if "PL_GLOBAL_SEED" in os.environ:
            seed = os.environ["PL_GLOBAL_SEED"]
            ray.get([
                w.set_env_var.remote("PL_GLOBAL_SEED", seed)
                for w in self.workers
            ])

        # Get the rank 0 address for DDP connection.
        self.ddp_address = ray.get(
            self.workers[0].execute.remote(setup_address))

        self.global_to_local = self.get_local_ranks()

        trainer = self.trainer
        assert trainer is not None
        trainer_ref = ray.put(trainer)
        # Don't pickle self.trainer when training remotely.
        self.trainer = None

        queue = None
        if TUNE_INSTALLED and is_session_enabled():
            # Create communication queue and send to all the workers.
            queue = Queue(actor_options={"num_cpus": 0})

        futures = [
            self.workers[i].execute.remote(self.train_remote, trainer_ref, i,
                                           queue)
            for i in range(self.num_workers)
        ]

        results = process_results(futures, queue)
        results, best_path, state_dict = results[0]
        self.trainer = trainer
        self.trainer.model.load_state_dict(state_dict)
        if self.trainer.checkpoint_callback:
            self.trainer.checkpoint_callback.best_model_path = best_path

        if queue:
            # Shutdown the queue.
            queue.shutdown()

        return results
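
This variant asks the rank 0 worker for a complete init address via ``setup_address`` instead of exporting MASTER_ADDR and MASTER_PORT separately. A hedged sketch of such a helper, assuming it builds a ``tcp://`` init method string from the node IP and a free port, is:

import socket

import ray

def setup_address_sketch():
    # Hypothetical reimplementation: combine this node's IP with a port the
    # OS reports as free into a torch.distributed init_method string.
    ip = ray.util.get_node_ip_address()
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        port = s.getsockname()[1]
    return f"tcp://{ip}:{port}"
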
Example #4
    def execution_loop(self, trainer, tune_enabled: bool = True):
        """Main execution loop for training, testing, & prediction.

        Sets up the torch.distributed process group for each
        worker. Then trigger remote training/testing/eval via
        ``train_remote`` on each worker. If using with Ray Tune, create a
        communication queue to retrieve intermediate results, and process
        those results. Finally retrieve the training results from the rank 0
        worker and return."""

        # Sets environment variables for all workers.
        self._setup_env_vars()

        self.global_to_local = self.get_local_ranks()

        model = self._model
        model_ref = ray.put(model)
        # Don't pickle the model when training remotely.
        self._model = None

        queue = None
        if tune_enabled and TUNE_INSTALLED and is_session_enabled():
            # Create communication queue and send to all the workers.
            queue = Queue(actor_options={"num_cpus": 0})

        futures = [
            self.workers[i].execute.remote(self.execute_remote, model_ref, i,
                                           queue)
            for i in range(self.num_workers)
        ]

        results = process_results(futures, queue)
        # Get the results, checkpoint path, and model weights from worker 0.
        results, best_path, state_stream = results[0]
        state_dict = load_state_stream(state_stream, to_gpu=self.use_gpu)
        # Set the state for PTL using the output from remote training.
        self._results = results
        self._model = model
        self._model.load_state_dict(state_dict)
        if self.lightning_module.trainer.checkpoint_callback:
            self.lightning_module.trainer.checkpoint_callback \
                .best_model_path = best_path

        if queue:
            # Shutdown the queue.
            queue.shutdown()

        return results
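
Unlike the earlier variants, worker 0 here returns the weights as a serialized byte stream, and ``load_state_stream`` turns that stream back into a state dict on the driver. A minimal sketch, assuming the worker produced the stream with ``torch.save``, could look like:

import io

import torch

def load_state_stream_sketch(state_stream, to_gpu):
    # Deserialize the state dict from the raw bytes returned by worker 0,
    # mapping tensors onto the GPU only when the driver should use one.
    buffer = io.BytesIO(state_stream)
    map_location = "cuda" if to_gpu and torch.cuda.is_available() else "cpu"
    return torch.load(buffer, map_location=map_location)
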
Example #5
    def start_training(self, trainer):
        """Main training loop.

        Trigger remote training via ``train_remote`` on each
        worker. If using with Ray Tune, create a communication queue to
        retrieve intermediate results, and process those results. Finally,
        retrieve the training results from the rank 0 worker and return."""
        model = self._model
        model_ref = ray.put(model)
        # Don't pickle the model when training remotely.
        self._model = None

        queue = None
        if TUNE_INSTALLED and is_session_enabled():
            # Create communication queue and send to all the workers.
            queue = Queue(actor_options={"num_cpus": 0})

        result_futures = self.executor.run_remote(
            self.train_remote, args=[model_ref, queue])

        results = process_results(result_futures, queue)

        results, state_dict, best_path = results[0]
        self._results = results
        self._model = model
        self._model.load_state_dict(state_dict)
        self._model.trainer.accelerator.training_type_plugin = self
        if self.lightning_module.trainer.checkpoint_callback:
            self.lightning_module.trainer.checkpoint_callback \
                .best_model_path = best_path

        if queue:
            # Shutdown the queue.
            queue.shutdown()

        return results
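
On the worker side, the queue is what carries intermediate metrics back to the driver so they can be reported to Ray Tune. The sketch below shows one way a ``train_remote``-style function could use it and produce the ``(results, state_dict, best_path)`` tuple this example unpacks; the epoch loop, metric names, and missing checkpoint path are placeholders, not the real training code.

import ray

def train_remote_sketch(model_ref, queue=None, num_epochs=3):
    # Rebuild the model from the object store reference put by the driver.
    model = ray.get(model_ref)
    metrics = {}
    for epoch in range(num_epochs):
        # Placeholder for a real training epoch; only the reporting matters.
        metrics = {"epoch": epoch, "loss": 1.0 / (epoch + 1)}
        if queue is not None:
            # Intermediate results the driver forwards to Ray Tune.
            queue.put(metrics)
    # Mirror the (results, state_dict, best_path) tuple unpacked by the driver.
    return metrics, model.state_dict(), None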