    def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
        self.mp_queue = mp_queue

        if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None:
            trainer.progress_bar_callback.disable()

        shared_params = find_shared_parameters(self.model)
        self.model_to_device()
        if is_overridden("on_post_move_to_device", self.lightning_module):
            self.model.module.on_post_move_to_device()
        else:
            set_shared_parameters(self.model.module, shared_params)

        trainer.accelerator.setup_optimizers(trainer)
        trainer.precision_plugin.connect(self._model, None, None)

        self.barrier("pre-run-stage")

        results = trainer.run_stage()

        self.__transfer_distrib_spawn_state_on_fit_end(trainer, results)

        # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
        self.barrier("end-process")

        # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
        if self.local_rank == 0:
            time.sleep(2)

        # ensure that spawned processes go through teardown before joining
        trainer._call_teardown_hook()
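# The flow above re-applies parameter tying after model_to_device(), because moving a
# module onto the XLA device can break weight sharing between layers. Below is a minimal,
# hypothetical sketch of what the two helpers do; Lightning's actual
# find_shared_parameters / set_shared_parameters live in its utilities and may cover more
# edge cases. The names find_tied_parameters / retie_parameters are illustrative only.
from typing import Dict, List

from torch import nn


def find_tied_parameters(module: nn.Module) -> List[List[str]]:
    # Group parameter names that refer to the same underlying Parameter object.
    # remove_duplicate=False makes named_parameters yield every alias of a tied tensor
    # (available in recent PyTorch releases).
    groups: Dict[int, List[str]] = {}
    for name, param in module.named_parameters(remove_duplicate=False):
        groups.setdefault(id(param), []).append(name)
    return [names for names in groups.values() if len(names) > 1]


def retie_parameters(module: nn.Module, shared: List[List[str]]) -> None:
    # Point every name in a group back at the first name's Parameter object.
    for group in shared:
        first, *rest = group
        src = module.get_parameter(first)
        for name in rest:
            owner_path, _, attr = name.rpartition(".")
            owner = module.get_submodule(owner_path) if owner_path else module
            setattr(owner, attr, src)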
Example #2
    def setup(self) -> None:
        shared_params = find_shared_parameters(self.model)
        self.model_to_device()
        if is_overridden("on_post_move_to_device", self.lightning_module):
            self.model.on_post_move_to_device()
        else:
            set_shared_parameters(self.model, shared_params)
Example #3
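# ParameterSharingModule and BoringModel below come from the surrounding Lightning test
# module and are not defined in this snippet. A plausible, hypothetical stand-in for
# ParameterSharingModule (layer sizes assumed) ties layer_3.weight to layer_1.weight so
# both names refer to a single Parameter:
from torch import nn


class ParameterSharingStandIn(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(32, 10, bias=False)
        self.layer_2 = nn.Linear(10, 32, bias=False)
        self.layer_3 = nn.Linear(32, 10, bias=False)
        # tie the weights: layer_1 and layer_3 now share one tensor
        self.layer_3.weight = self.layer_1.weight

    def forward(self, x):
        return self.layer_3(self.layer_2(self.layer_1(x)))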
def test_set_shared_parameters():
    model = ParameterSharingModule()
    set_shared_parameters(model, [["layer_1.weight", "layer_3.weight"]])

    assert torch.all(torch.eq(model.layer_1.weight, model.layer_3.weight))

    class SubModule(nn.Module):
        def __init__(self, layer):
            super().__init__()
            self.layer = layer

        def forward(self, x):
            return self.layer(x)

    class NestedModule(BoringModel):
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(32, 10, bias=False)
            self.net_a = SubModule(self.layer)
            self.layer_2 = nn.Linear(10, 32, bias=False)
            self.net_b = SubModule(self.layer)

        def forward(self, x):
            x = self.net_a(x)
            x = self.layer_2(x)
            x = self.net_b(x)
            return x

    model = NestedModule()
    set_shared_parameters(
        model, [["layer.weight", "net_a.layer.weight", "net_b.layer.weight"]])

    assert torch.all(
        torch.eq(model.net_a.layer.weight, model.net_b.layer.weight))
Example #4
    def setup(self, trainer: "pl.Trainer") -> None:
        assert self.model, "self.model must be set before find_shared_parameters(self.model)"
        shared_params = find_shared_parameters(self.model)
        self.model_to_device()
        set_shared_parameters(self.model, shared_params)
        super().setup(trainer)

        if self.debug:
            os.environ["PT_XLA_DEBUG"] = str(1)
Example #5
    def setup(self, trainer: "pl.Trainer") -> None:
        self.accelerator.setup(trainer)

        if self.debug:
            os.environ["PT_XLA_DEBUG"] = str(1)

        shared_params = find_shared_parameters(self.model)
        self.model_to_device()
        set_shared_parameters(self.model.module, shared_params)
        self.setup_precision_plugin()

        if trainer.state.fn == TrainerFn.FITTING:
            self.setup_optimizers(trainer)
            optimizers_to_device(self.optimizers, self.root_device)
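# optimizers_to_device above moves tensors held in the optimizer state onto the root
# (XLA) device before training starts. A minimal sketch of that idea, using a
# hypothetical helper name (Lightning's real utility handles additional cases):
import torch


def move_optimizer_state(optimizer: torch.optim.Optimizer, device: torch.device) -> None:
    # Relocate every tensor stored in the optimizer's per-parameter state.
    for state in optimizer.state.values():
        for key, value in state.items():
            if isinstance(value, torch.Tensor):
                state[key] = value.to(device)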
Example #6
    def setup(self, trainer: "pl.Trainer") -> None:
        shared_params = find_shared_parameters(self.model)
        self.model_to_device()
        if is_overridden("on_post_move_to_device", self.lightning_module):
            self.model.on_post_move_to_device()
        else:
            set_shared_parameters(self.model, shared_params)

        super().setup(trainer)

        if self.debug:
            os.environ["PT_XLA_DEBUG"] = str(1)

        self.tpu_local_core_rank = xm.get_local_ordinal()
        self.tpu_global_core_rank = xm.get_ordinal()
Example #7
    def setup(self, trainer: "pl.Trainer") -> None:
        self.start_method = "fork"
        self.accelerator.setup(trainer)
        self.setup_optimizers(trainer)
        self.setup_precision_plugin()
        optimizers_to_device(self.optimizers, self.root_device)

        if self.debug:
            os.environ["PT_XLA_DEBUG"] = str(1)

        shared_params = find_shared_parameters(self.model)
        self.model_to_device()
        if is_overridden("on_post_move_to_device", self.lightning_module):
            self.model.module.on_post_move_to_device()
        else:
            set_shared_parameters(self.model.module, shared_params)

        self.setup_optimizers(trainer)
        self.precision_plugin.connect(self.model, None, None)
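# None of the setup hooks above are called directly by user code; the Trainer invokes
# them when a TPU run starts. A minimal usage sketch, assuming the PyTorch Lightning 1.x
# Trainer API (tpu_cores; newer releases use accelerator="tpu", devices=8) and a
# hypothetical MyTpuModel LightningModule:
import pytorch_lightning as pl

model = MyTpuModel()               # any LightningModule, e.g. one with tied weights
trainer = pl.Trainer(tpu_cores=8)  # selects the TPU spawn plugin/strategy shown above
trainer.fit(model)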