Example #1
    def train_remote(self, model: ObjectRef, queue: Queue = None, **kwargs):
        """Training function to be executed on each remote worker."""
        self._model = ray.get(model)
        self.lightning_module.trainer.accelerator_connector\
            ._training_type_plugin = self
        self.lightning_module.trainer.accelerator.training_type_plugin = self

        hvd.init()
        self.global_rank = hvd.rank()
        self.local_rank = hvd.local_rank()
        self.world_size = hvd.size()
        rank_zero_only.rank = self.global_rank

        if queue is not None:
            # Initialize session.
            init_session(rank=self.global_rank, queue=queue)

        # Move the model to the appropriate device.
        super(HorovodRayPlugin, self).model_to_device()

        # TODO: Make changes in PTL to clean this up.
        super(HorovodRayPlugin, self).pre_dispatch()
        results = super(HorovodRayPlugin,
                        self).start_training(self.lightning_module.trainer)
        if self.global_rank != 0:
            # Only want results from the first worker.
            return None

        best_model_path = None
        if self.lightning_module.trainer.checkpoint_callback is not None:
            best_model_path = \
                self.lightning_module.trainer.checkpoint_callback\
                    .best_model_path

        return results, self.lightning_module.state_dict(), best_model_path
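For context, a hedged driver-side sketch of how this method might be dispatched; `workers`, `plugin`, and `queue` are assumptions (Ray actor handles exposing train_remote, the plugin instance, and the session queue), not ray_lightning's actual entry point:

    import ray

    # Hypothetical driver loop: ship the model through the object store,
    # start training on every worker, and keep only rank 0's result.
    model_ref = ray.put(plugin.lightning_module)
    futures = [w.train_remote.remote(model_ref, queue=queue) for w in workers]
    outputs = ray.get(futures)  # blocks until all workers finish
    results, state_dict, best_model_path = next(
        o for o in outputs if o is not None)  # every other rank returned None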
Example #2
    def train_remote(self,
                     model: LightningModule,
                     global_rank: int,
                     queue: Queue = None):
        """Training function to be executed on each remote worker."""
        assert isinstance(self, RayPlugin)
        # This method should be executed remotely in each worker.
        self._model = model
        self.lightning_module.trainer.accelerator_connector\
            ._training_type_plugin = self
        self.lightning_module.trainer.accelerator.training_type_plugin = self
        self.global_rank = global_rank

        if queue is not None:
            # Initialize session.
            init_session(rank=global_rank, queue=queue)

        # Calling new_process will call
        # transfer_distrib_spawn_state_on_fit_end.
        # We override that method and have it just set attributes.
        # Then we can just return those attributes here.
        super(RayPlugin,
              self).new_process(process_idx=global_rank,
                                trainer=self.lightning_module.trainer,
                                mp_queue=None)
        # Only need results from worker 0.
        if self.global_rank == 0:
            return self.results, self.best_model_path, self.model_state_dict
        else:
            return None
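The comment above refers to overriding transfer_distrib_spawn_state_on_fit_end so it stashes state on the plugin instead of pushing it through a multiprocessing queue. A minimal sketch of what such an override might look like, inferred from the attributes returned above (the exact hook signature varies across PyTorch Lightning versions):

    def transfer_distrib_spawn_state_on_fit_end(self, results):
        # Stash results as plain attributes rather than enqueueing them,
        # so train_remote can return them directly from the worker.
        self.results = results
        self.model_state_dict = self.lightning_module.state_dict()
        checkpoint_callback = self.lightning_module.trainer.checkpoint_callback
        self.best_model_path = (checkpoint_callback.best_model_path
                                if checkpoint_callback is not None else None)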
Example #3
    def train_remote(self,
                     trainer: Trainer,
                     global_rank: int,
                     queue: Queue = None):
        """Training function to be executed on each remote worker."""
        assert isinstance(self, RayAccelerator)
        # This method should be executed remotely in each worker.
        self.trainer = trainer
        self.trainer.accelerator_backend = self
        self.global_rank = global_rank
        model = self.trainer.model

        if queue is not None:
            # Initialize session.
            init_session(rank=global_rank, queue=queue)

        # Calling ddp_train will call transfer_distrib_spawn_state_on_fit_end.
        # We override that method and have it just set attributes.
        # Then we can just return those attributes here.
        super(RayAccelerator, self).ddp_train(
            process_idx=global_rank, mp_queue=None, model=model)
        return self.results, self.best_model_path, self.model_state_dict
Example #4
    def train_remote(self, trainer_ref: ObjectRef, queue: Queue = None):
        """Training function to be executed on each remote worker."""
        self.trainer = ray.get(trainer_ref)
        hvd.init()
        if queue is not None:
            # Initialize session.
            init_session(rank=hvd.rank(), queue=queue)
        if self.trainer.on_gpu:
            # Horovod assigns one local GPU per process.
            self.trainer.root_gpu = hvd.local_rank()

        # TODO: Make changes in PTL to clean this up.
        super(HorovodRayAccelerator, self).setup(self.trainer.model)
        results = super(HorovodRayAccelerator, self).train()
        if hvd.rank() != 0:
            # Only want results from the first worker.
            return None

        best_model_path = None
        if self.trainer.checkpoint_callback is not None:
            best_model_path = self.trainer.checkpoint_callback.best_model_path

        model = self.trainer.model
        return results, model.state_dict(), best_model_path
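On the driver side, the returned tuple would then be unpacked and the trained weights loaded back into the local model. A hedged sketch, assuming `outputs` was gathered with ray.get as in the sketch after Example #1:

    # Hypothetical consumption step; only rank 0 returned a tuple.
    results, state_dict, best_model_path = next(
        o for o in outputs if o is not None)
    trainer.model.load_state_dict(state_dict)  # restore trained weights locally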