def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
    self.mp_queue = mp_queue

    if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None:
        trainer.progress_bar_callback.disable()

    shared_params = find_shared_parameters(self.model)
    self.model_to_device()
    if is_overridden("on_post_move_to_device", self.lightning_module):
        self.model.module.on_post_move_to_device()
    else:
        set_shared_parameters(self.model.module, shared_params)

    trainer.accelerator.setup_optimizers(trainer)
    trainer.precision_plugin.connect(self._model, None, None)

    self.barrier("pre-run-stage")

    results = trainer.run_stage()

    self.__transfer_distrib_spawn_state_on_fit_end(trainer, results)

    # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
    self.barrier("end-process")

    # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
    if self.local_rank == 0:
        time.sleep(2)

    # ensure that spawned processes go through teardown before joining
    trainer._call_teardown_hook()
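
# A per-process entry point like `new_process` is typically driven by
# torch_xla's multiprocessing launcher. A minimal sketch of that launch
# pattern, not the plugin's actual wiring; `run_fn` and its payload are
# hypothetical stand-ins:
import torch_xla.distributed.xla_multiprocessing as xmp

def run_fn(index: int, payload: str) -> None:
    # xmp.spawn passes the process index as the first positional argument.
    print(f"process {index} received {payload}")

if __name__ == "__main__":
    # "fork" matches the start_method set in the setup() variants below.
    xmp.spawn(run_fn, args=("hello",), nprocs=8, start_method="fork")
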
def setup(self) -> None:
    shared_params = find_shared_parameters(self.model)
    self.model_to_device()
    if is_overridden("on_post_move_to_device", self.lightning_module):
        self.model.on_post_move_to_device()
    else:
        set_shared_parameters(self.model, shared_params)
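
# When `on_post_move_to_device` is overridden, the plugin skips
# `set_shared_parameters` and lets the LightningModule re-tie its weights
# itself. A sketch of such an override; the module and layer sizes are
# hypothetical:
import torch.nn as nn
import pytorch_lightning as pl

class ManualTyingModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(32, 10, bias=False)
        self.layer_3 = nn.Linear(32, 10, bias=False)

    def on_post_move_to_device(self) -> None:
        # Re-establish the tie after parameters were copied to the XLA device.
        self.layer_3.weight = self.layer_1.weight
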
def setup(self, trainer: "pl.Trainer") -> None:
    assert self.model, "self.model must be set before find_shared_parameters(self.model)"
    shared_params = find_shared_parameters(self.model)
    self.model_to_device()
    set_shared_parameters(self.model, shared_params)
    super().setup(trainer)

    if self.debug:
        os.environ["PT_XLA_DEBUG"] = str(1)
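
# The `debug` flag gates torch_xla's PT_XLA_DEBUG tracing. A sketch of how a
# user might enable it from the Trainer; the import path and plugin argument
# follow pytorch_lightning ~1.5 and should be treated as an assumption:
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import TPUSpawnPlugin

trainer = Trainer(tpu_cores=8, plugins=TPUSpawnPlugin(debug=True))
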
def test_auto_parameters_tying_tpus(tmpdir):
    model = WeightSharingModule()
    shared_params = find_shared_parameters(model)

    assert shared_params[0] == ["layer_1.weight", "layer_3.weight"]

    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=5, tpu_cores=8, max_epochs=1)
    trainer.fit(model)

    assert torch.all(torch.eq(model.layer_1.weight, model.layer_3.weight))
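
# `WeightSharingModule` is not shown in this section. A minimal sketch of a
# module with tied weights that would satisfy the assertions above; the layer
# sizes, loss, and dataloader are hypothetical:
import torch
import torch.nn as nn
import pytorch_lightning as pl

class WeightSharingModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(32, 10, bias=False)
        self.layer_2 = nn.Linear(10, 32, bias=False)
        self.layer_3 = nn.Linear(32, 10, bias=False)
        # Tie layer_3 to layer_1: both names resolve to a single tensor.
        self.layer_3.weight = self.layer_1.weight

    def forward(self, x):
        return self.layer_3(self.layer_2(self.layer_1(x)))

    def training_step(self, batch, batch_idx):
        return self.forward(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(torch.randn(64, 32), batch_size=8)
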
def setup(self, trainer: "pl.Trainer") -> None:
    self.accelerator.setup(trainer)

    if self.debug:
        os.environ["PT_XLA_DEBUG"] = str(1)

    shared_params = find_shared_parameters(self.model)
    self.model_to_device()
    set_shared_parameters(self.model.module, shared_params)
    self.setup_precision_plugin()

    if trainer.state.fn == TrainerFn.FITTING:
        self.setup_optimizers(trainer)
        optimizers_to_device(self.optimizers, self.root_device)
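
# `set_shared_parameters` re-ties weights after the device move. A simplified
# sketch of the idea; the real helper lives in
# pytorch_lightning.utilities.parameter_tying, so treat the details here as
# assumptions:
import torch.nn as nn

def set_shared_parameters_sketch(module: nn.Module, shared_params) -> None:
    for group in shared_params:
        canonical = module.get_parameter(group[0])
        for name in group[1:]:
            owner_path, _, attr = name.rpartition(".")
            owner = module.get_submodule(owner_path) if owner_path else module
            # Point every alias at the canonical (on-device) tensor.
            setattr(owner, attr, canonical)
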
def setup(self, trainer: "pl.Trainer") -> None:
    shared_params = find_shared_parameters(self.model)
    self.model_to_device()
    if is_overridden("on_post_move_to_device", self.lightning_module):
        self.model.on_post_move_to_device()
    else:
        set_shared_parameters(self.model, shared_params)
    super().setup(trainer)

    if self.debug:
        os.environ["PT_XLA_DEBUG"] = str(1)

    self.tpu_local_core_rank = xm.get_local_ordinal()
    self.tpu_global_core_rank = xm.get_ordinal()
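
# `xm` above is torch_xla.core.xla_model. A minimal sketch of the ordinal
# queries, runnable inside a spawned TPU process:
import torch_xla.core.xla_model as xm

def describe_tpu_process() -> str:
    # get_ordinal(): global rank across all TPU cores;
    # get_local_ordinal(): rank within the current host.
    return f"global={xm.get_ordinal()} local={xm.get_local_ordinal()}"
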
def setup(self, trainer: "pl.Trainer") -> None:
    self.start_method = "fork"
    self.accelerator.setup(trainer)

    if self.debug:
        os.environ["PT_XLA_DEBUG"] = str(1)

    shared_params = find_shared_parameters(self.model)
    self.model_to_device()
    if is_overridden("on_post_move_to_device", self.lightning_module):
        self.model.module.on_post_move_to_device()
    else:
        set_shared_parameters(self.model.module, shared_params)

    # Optimizers are created only after the model has been moved to the XLA
    # device so their state references the device parameters.
    self.setup_optimizers(trainer)
    self.setup_precision_plugin()
    optimizers_to_device(self.optimizers, self.root_device)
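
# `optimizers_to_device` (pytorch_lightning.utilities.optimizer) moves
# optimizer state tensors onto the target device. A simplified sketch of the
# underlying idea, not the library implementation:
import torch

def optimizer_state_to_device(optimizer: torch.optim.Optimizer, device: torch.device) -> None:
    for state in optimizer.state.values():
        for key, value in state.items():
            if isinstance(value, torch.Tensor):
                state[key] = value.to(device)
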
def test_find_shared_parameters(model, expected_shared_params):
    assert expected_shared_params == find_shared_parameters(model())
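
# `find_shared_parameters` groups parameter names that point at the same
# underlying tensor. A simplified sketch of the idea, not the library
# implementation:
from collections import defaultdict

import torch.nn as nn

def find_shared_parameters_sketch(module: nn.Module):
    groups = defaultdict(list)
    for mod_name, mod in module.named_modules():
        # Read _parameters directly so tied tensors are seen once per owning
        # module (named_parameters() deduplicates by default).
        for param_name, param in mod._parameters.items():
            if param is None:
                continue
            full_name = f"{mod_name}.{param_name}" if mod_name else param_name
            groups[id(param)].append(full_name)
    return [names for names in groups.values() if len(names) > 1]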