    def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
        self.mp_queue = mp_queue

        if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None:
            trainer.progress_bar_callback.disable()

        shared_params = find_shared_parameters(self.model)
        self.model_to_device()
        if is_overridden("on_post_move_to_device", self.lightning_module):
            self.model.module.on_post_move_to_device()
        else:
            set_shared_parameters(self.model.module, shared_params)

        trainer.accelerator.setup_optimizers(trainer)
        trainer.precision_plugin.connect(self._model, None, None)

        self.barrier("pre-run-stage")

        results = trainer.run_stage()

        self.__transfer_distrib_spawn_state_on_fit_end(trainer, results)

        # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
        self.barrier("end-process")

        # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
        if self.local_rank == 0:
            time.sleep(2)

        # ensure that spawned processes go through teardown before joining
        trainer._call_teardown_hook()
Example #2
    def setup(self) -> None:
        shared_params = find_shared_parameters(self.model)
        self.model_to_device()
        if is_overridden("on_post_move_to_device", self.lightning_module):
            self.model.on_post_move_to_device()
        else:
            set_shared_parameters(self.model, shared_params)
Example #3

    def setup(self, trainer: "pl.Trainer") -> None:
        assert self.model, "self.model must be set before find_shared_parameters(self.model)"
        shared_params = find_shared_parameters(self.model)
        self.model_to_device()
        set_shared_parameters(self.model, shared_params)
        super().setup(trainer)

        if self.debug:
            os.environ["PT_XLA_DEBUG"] = str(1)
Example #4
def test_auto_parameters_tying_tpus(tmpdir):

    model = WeightSharingModule()
    shared_params = find_shared_parameters(model)

    assert shared_params[0] == ["layer_1.weight", "layer_3.weight"]

    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=5, tpu_cores=8, max_epochs=1)
    trainer.fit(model)

    assert torch.all(torch.eq(model.layer_1.weight, model.layer_3.weight))
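# Note: the WeightSharingModule used above is not shown on this page. A minimal
# sketch of what it might look like, assuming a LightningModule whose layer_3
# weight is tied to layer_1 (matching the asserted shared parameter names
# ["layer_1.weight", "layer_3.weight"]); training_step, configure_optimizers and
# the dataloaders needed by trainer.fit() are omitted here.
import pytorch_lightning as pl
import torch.nn as nn


class WeightSharingModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # bias-free linear layers; layer_3 reuses layer_1's weight tensor
        self.layer_1 = nn.Linear(32, 10, bias=False)
        self.layer_2 = nn.Linear(10, 32, bias=False)
        self.layer_3 = nn.Linear(32, 10, bias=False)
        self.layer_3.weight = self.layer_1.weight

    def forward(self, x):
        return self.layer_3(self.layer_2(self.layer_1(x)))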
Example #5

    def setup(self, trainer: "pl.Trainer") -> None:
        self.accelerator.setup(trainer)

        if self.debug:
            os.environ["PT_XLA_DEBUG"] = str(1)

        shared_params = find_shared_parameters(self.model)
        self.model_to_device()
        set_shared_parameters(self.model.module, shared_params)
        self.setup_precision_plugin()

        if trainer.state.fn == TrainerFn.FITTING:
            self.setup_optimizers(trainer)
            optimizers_to_device(self.optimizers, self.root_device)
Example #6

    def setup(self, trainer: "pl.Trainer") -> None:
        shared_params = find_shared_parameters(self.model)
        self.model_to_device()
        if is_overridden("on_post_move_to_device", self.lightning_module):
            self.model.on_post_move_to_device()
        else:
            set_shared_parameters(self.model, shared_params)

        super().setup(trainer)

        if self.debug:
            os.environ["PT_XLA_DEBUG"] = str(1)

        self.tpu_local_core_rank = xm.get_local_ordinal()
        self.tpu_global_core_rank = xm.get_ordinal()
Example #7
    def setup(self, trainer: "pl.Trainer") -> None:
        self.start_method = "fork"
        self.accelerator.setup(trainer)
        self.setup_optimizers(trainer)
        self.setup_precision_plugin()
        optimizers_to_device(self.optimizers, self.root_device)

        if self.debug:
            os.environ["PT_XLA_DEBUG"] = str(1)

        shared_params = find_shared_parameters(self.model)
        self.model_to_device()
        if is_overridden("on_post_move_to_device", self.lightning_module):
            self.model.module.on_post_move_to_device()
        else:
            set_shared_parameters(self.model.module, shared_params)

        # set up the optimizers again after the move so they reference the parameters now on the XLA device
        self.setup_optimizers(trainer)
        self.precision_plugin.connect(self.model, None, None)
Example #8
def test_find_shared_parameters(model, expected_shared_params):

    assert expected_shared_params == find_shared_parameters(model())
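# The test above expects `model` and `expected_shared_params` to be injected by a
# pytest parametrization that is not reproduced on this page. A minimal sketch of
# such a parametrization, reusing the hypothetical WeightSharingModule sketched
# under Example #4 and assuming find_shared_parameters returns an empty list when
# no parameters are tied (the import path may vary between Lightning versions):
import pytest
import pytorch_lightning as pl
import torch.nn as nn
from pytorch_lightning.utilities.parameter_tying import find_shared_parameters


class NoSharingModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # independent layers, so no shared parameters should be reported
        self.layer_1 = nn.Linear(32, 10, bias=False)
        self.layer_2 = nn.Linear(10, 32, bias=False)


@pytest.mark.parametrize(
    ("model", "expected_shared_params"),
    [
        (NoSharingModule, []),
        (WeightSharingModule, [["layer_1.weight", "layer_3.weight"]]),
    ],
)
def test_find_shared_parameters(model, expected_shared_params):
    assert expected_shared_params == find_shared_parameters(model())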