Example #1
    def test_state_dict(self):
        """Check that the ZeroRedundancyOptimizer exposes the expected state dict interface,
        irrespective of the sharding.
        """
        self.dist_init(self.rank)
        x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
        o = ZeroRedundancyOptimizer([x],
                                    optimizer_class=SGD,
                                    lr=0.1,
                                    momentum=0.9)
        x.backward()
        o.step()
        self.assertEqual(x, torch.tensor([0.9], device=DEVICE))
        self.assertEqual(o.optim.state[x]["momentum_buffer"],
                         torch.tensor([1.0], device=DEVICE))

        o.zero_grad()
        # Sync the state dict across replicas, even if there are none
        o.consolidate_state_dict()
        state_dict = o.state_dict()

        # Check that the state dict keys are PyTorch-compliant
        self.assertIn("param_groups", state_dict.keys())
        self.assertIn("state", state_dict.keys())

        # Check that the pulled state is what we expect, and that we have all the expected keys
        self.assertEqual(state_dict["param_groups"][0]["lr"], 0.1)
        self.assertEqual(state_dict["param_groups"][0]["momentum"], 0.9)
        self.assertFalse(state_dict["param_groups"][0]["nesterov"])
        self.assertEqual(state_dict["param_groups"][0]["weight_decay"], 0.0)
        self.assertEqual(state_dict["param_groups"][0]["dampening"], 0.0)

        # Check that the pulled state and the .param_groups attribute are in sync
        for k in state_dict["param_groups"][0].keys():
            if k != "params":
                self.assertEqual(state_dict["param_groups"][0][k],
                                 o.param_groups[0][k])

        # Check that it's correctly loaded
        o = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=0.01)
        o.load_state_dict(state_dict)

        # Check that state is correct and on proper device
        self.assertEqual(o.optim.state[x]["momentum_buffer"],
                         torch.tensor([1.0], device=DEVICE))

        # We should now be using a lr of 0.1, both within the optimizer
        # and as exposed by the .param_groups attribute
        self.assertEqual(o.param_groups[0]["lr"], 0.1)
        x.backward()
        o.step()
        self.assertEqual(x, torch.tensor([0.71], device=DEVICE))
        self.assertEqual(o.optim.state[x]["momentum_buffer"],
                         torch.tensor([1.9], device=DEVICE))

        # Check that the exposed param_groups are on the proper device
        self.assertEqual(o.param_groups[0]["params"][0].device, x.device)
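For reference, here is a minimal stand-alone sketch of the same state-dict round trip outside the test harness. It assumes a single-process gloo group on CPU and only uses the public ZeroRedundancyOptimizer API exercised above; the init address is a placeholder.

import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.optim import SGD

def main() -> None:
    # One-process group for illustration; a real launcher (torchrun, mp.spawn)
    # would provide rank/world_size instead
    dist.init_process_group(backend="gloo",
                            init_method="tcp://127.0.0.1:29500",
                            rank=0,
                            world_size=1)

    x = torch.tensor([1.0], requires_grad=True)
    o = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=0.1, momentum=0.9)

    x.backward()
    o.step()  # x -> 0.9, momentum buffer -> 1.0

    # Gather the sharded optimizer state before reading it
    o.consolidate_state_dict()
    state_dict = o.state_dict()

    # A freshly built wrapper restores both the state and the hyper-parameters
    o = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=0.01)
    o.load_state_dict(state_dict)
    assert o.param_groups[0]["lr"] == 0.1

    dist.destroy_process_group()

if __name__ == "__main__":
    main()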
Example #2
    def test_collect_shards(self):
        """ Check the state consolidation mechanism, and the state dict exposed by ZeroRedundancyOptimizer"""
        self.dist_init(self.rank)
        RECIPIENT_RANK = 0

        # Run a dummy step so that the optimizer state dict exists
        batch, input_width, hidden, target_width = 3, 20, 10, 5
        target = torch.rand((batch, target_width), device=self.device)
        inputs = torch.rand((batch, input_width), device=self.device)

        model = torch.nn.Sequential(torch.nn.Linear(input_width, hidden),
                                    torch.nn.Linear(hidden, target_width))
        model.to(self.device)

        loss_fn = torch.nn.L1Loss()
        loss_fn.to(self.device)

        # With SGD, momentum is required to have a state to shard
        optimizer = ZeroRedundancyOptimizer(model.parameters(),
                                            optimizer_class=SGD,
                                            lr=0.1,
                                            momentum=0.99)

        def closure():
            optimizer.zero_grad()
            output = model(inputs)
            loss = loss_fn(output, target)
            loss.backward()
            return loss

        _ = optimizer.step(closure=closure)

        # Update the optimizer state on the reference rank
        optimizer.consolidate_state_dict(to=RECIPIENT_RANK)

        # Fetch the state on the reference rank
        # - check that it has the correct size
        # - load it again
        if self.rank == RECIPIENT_RANK:
            optimizer_state_dict = optimizer.state_dict()
            self.assertEqual(len(optimizer_state_dict["state"]),
                             len(list(model.parameters())))
        else:
            optimizer_state_dict = {}

        optimizer_state_dict = _broadcast_object(
            optimizer_state_dict,
            src_rank=RECIPIENT_RANK,
            group=dist.group.WORLD,
            device=self.device,
        )

        # Load the optimizer state dict, check that no exception is raised
        optimizer.load_state_dict(optimizer_state_dict)
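The same consolidate-then-broadcast pattern can be written with the public torch.distributed.broadcast_object_list collective instead of the private _broadcast_object helper used above. A sketch, assuming an already-initialized process group and a ZeroRedundancyOptimizer built as in the example; collect_and_reload is a hypothetical helper name.

import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer

def collect_and_reload(optimizer: ZeroRedundancyOptimizer,
                       recipient_rank: int = 0) -> dict:
    # Move every shard's optimizer state onto the recipient rank
    optimizer.consolidate_state_dict(to=recipient_rank)

    # Only the recipient rank holds the full state dict; ship it to the others
    package = [optimizer.state_dict() if dist.get_rank() == recipient_rank else None]
    dist.broadcast_object_list(package, src=recipient_rank, group=dist.group.WORLD)
    full_state_dict = package[0]

    # Every rank can now reload the consolidated state
    optimizer.load_state_dict(full_state_dict)
    return full_state_dict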
Example #3
            def check_optimizer_equivalence(
                    optimizer: Type[torch.optim.Optimizer]):
                # Any model works. Add one different buffer per rank
                model = torch.nn.Sequential(
                    torch.nn.Linear(2, 3),
                    torch.nn.Linear(3, 3),
                    torch.nn.Linear(3, 3),
                )
                model.register_buffer("test_buffer",
                                      torch.ones((1)) * self.rank)
                model.to(self.device)

                sharded_optimizer = ZeroRedundancyOptimizer(
                    params=model.parameters(),
                    optimizer_class=optimizer,
                    lr=1e-3)
                sharded_ddp_model = DDP(module=model,
                                        device_ids=[self.rank],
                                        broadcast_buffers=True,
                                        find_unused_parameters=True)

                ddp_model_single = copy.deepcopy(model)
                ddp_model_single.to(self.device)

                ddp_optimizer = optimizer(ddp_model_single.parameters(),
                                          lr=1e-3)
                ddp_model = DDP(ddp_model_single,
                                device_ids=[self.rank],
                                broadcast_buffers=True,
                                find_unused_parameters=True)

                # The models should be synchronized across ranks at construction time; check that
                check_same_model_params(sharded_ddp_model, ddp_model,
                                        "Models differ from the start")

                def check_step():
                    input_tensor = torch.rand((64, 2))

                    def closure_ddp(input_tensor=input_tensor):
                        ddp_optimizer.zero_grad()
                        ddp_loss = ddp_model(input_tensor).abs().sum()
                        ddp_loss.backward()
                        return ddp_loss

                    def closure_sharded(input_tensor=input_tensor):
                        sharded_optimizer.zero_grad()
                        sharded_loss = sharded_ddp_model(
                            input_tensor).abs().sum()
                        sharded_loss.backward()
                        return sharded_loss

                    loss_ddp = cast(torch.Tensor,
                                    ddp_optimizer.step(closure=closure_ddp))
                    loss_sharded_optim = cast(
                        torch.Tensor,
                        sharded_optimizer.step(closure=closure_sharded))

                    assert torch.allclose(
                        loss_ddp, loss_sharded_optim
                    ), "Losses differ in between Pytorch optim and ZeroRedundancyOptimizer"

                    check_same_model_params(sharded_ddp_model, ddp_model,
                                            "Models differ after a step")

                # The models should stay the same across ranks
                for i in range(BATCHS):
                    check_step()

                    # Flip the models' trainability and check that parity is maintained;
                    # only start flipping after a few constant batches so both regimes are covered
                    if i > BATCHS // 2:
                        next(ddp_model.parameters()).requires_grad = bool(i % 2)
                        next(sharded_ddp_model.parameters()).requires_grad = bool(i % 2)

                # Check that the checkpoints are compatible
                reference_rank = 0
                # - get states
                ddp_state_dict = ddp_optimizer.state_dict()
                sharded_optimizer.consolidate_state_dict(to=reference_rank)
                sharded_optim_state_dict = [
                    sharded_optimizer.state_dict()
                    if self.rank == reference_rank else {}
                ]
                dist.broadcast_object_list(sharded_optim_state_dict,
                                           src=reference_rank,
                                           group=dist.group.WORLD)
                sharded_optim_state_dict = sharded_optim_state_dict[0]

                # - cross load the states
                # run one step and check that the models are still the same
                ddp_state_dict_ref = copy.deepcopy(ddp_state_dict)  # OSS will remove some states
                ddp_optimizer.load_state_dict(sharded_optim_state_dict)  # mix-up on purpose!
                sharded_optimizer.load_state_dict(ddp_state_dict)
                check_step()

                # - self-load, rewind, and check that nothing breaks
                # run one step and check that the models are still the same
                ddp_optimizer.load_state_dict(ddp_state_dict_ref)
                sharded_optimizer.load_state_dict(sharded_optim_state_dict)
                check_step()
            def check_optimizer_equivalence(
                    optimizer: Type[torch.optim.Optimizer]):
                # Any model works. Add one different buffer per rank
                model = torch.nn.Sequential(
                    torch.nn.Linear(2, 3),
                    torch.nn.Linear(3, 3),
                    torch.nn.Linear(3, 3),
                )
                model.register_buffer("test_buffer",
                                      torch.ones((1)) * self.rank)
                model.to(self.device)

                sharded_optimizer = ZeroRedundancyOptimizer(
                    params=model.parameters(), optim=optimizer, lr=1e-3)
                sharded_ddp_model = DDP(module=model,
                                        device_ids=[self.rank],
                                        broadcast_buffers=True)

                ddp_model_single = copy.deepcopy(model)
                ddp_model_single.to(self.device)

                ddp_optimizer = optimizer(ddp_model_single.parameters(),
                                          lr=1e-3)
                ddp_model = DDP(ddp_model_single,
                                device_ids=[self.rank],
                                broadcast_buffers=True)

                def check_same_model_params():
                    for pg, ddp_pg in zip(sharded_optimizer.param_groups,
                                          ddp_optimizer.param_groups):
                        for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
                            assert torch.allclose(
                                p, ddp_p, atol=1e-3
                            ), f"Model parameters differ in between Pytorch optim and ZeroRedundancyOptimizer \n{p} {ddp_p}"

                    for b, ddp_b in zip(sharded_ddp_model.buffers(),
                                        ddp_model.buffers()):
                        assert torch.allclose(
                            b, ddp_b
                        ), "Model buffers differ in between Pytorch optim and ZeroRedundancyOptimizer"

                # The models should be synchronized across ranks at construction time; check that
                check_same_model_params()

                # The models and losses should stay the same across multiple steps
                def check_step():
                    input_tensor = torch.rand((64, 2))

                    def closure_ddp(input_tensor=input_tensor):
                        ddp_optimizer.zero_grad()
                        ddp_loss = ddp_model(input_tensor).abs().sum()
                        ddp_loss.backward()
                        return ddp_loss

                    def closure_sharded(input_tensor=input_tensor):
                        sharded_optimizer.zero_grad()
                        sharded_loss = sharded_ddp_model(
                            input_tensor).abs().sum()
                        sharded_loss.backward()
                        return sharded_loss

                    loss_ddp = cast(torch.Tensor,
                                    ddp_optimizer.step(closure=closure_ddp))
                    loss_sharded_optim = cast(
                        torch.Tensor,
                        sharded_optimizer.step(closure=closure_sharded))

                    assert torch.allclose(
                        loss_ddp, loss_sharded_optim
                    ), "Losses differ in between Pytorch optim and ZeroRedundancyOptimizer"

                    check_same_model_params()

                for i in range(20):
                    check_step()

                # Test state dict save/load/equivalence with pytorch
                # - save state for both
                sharded_optimizer.consolidate_state_dict()
                sharded_optimizer_state_dict = (sharded_optimizer.state_dict()
                                                if self.rank == RECIPIENT_RANK
                                                else torch.zeros(1))
                ddp_state_dict = ddp_optimizer.state_dict()

                #  - sync the saved state with all the ranks
                exchange_list = [sharded_optimizer_state_dict]
                dist.broadcast_object_list(
                    exchange_list,
                    src=RECIPIENT_RANK,
                    group=dist.group.WORLD,
                )
                sharded_optimizer_state_dict = exchange_list[0]

                # - cross load the states
                ddp_optimizer.load_state_dict(sharded_optimizer_state_dict)
                sharded_optimizer.load_state_dict(ddp_state_dict)

                # - run one step, and check that the models are still the same
                check_step()
            def check_optimizer_equivalence(optimizer: Type[torch.optim.Optimizer]):
                # Any model works. Add one different buffer per rank
                model = torch.nn.Sequential(
                    torch.nn.Linear(2, 3),
                    torch.nn.Linear(3, 3),
                    torch.nn.Linear(3, 3),
                )
                model.register_buffer("test_buffer", torch.ones((1)) * self.rank)
                model.to(self.device)

                sharded_optimizer = ZeroRedundancyOptimizer(params=model.parameters(), optim=optimizer, lr=1e-3)
                sharded_ddp_model = DDP(module=model, device_ids=[self.rank], broadcast_buffers=True)

                ddp_model_single = copy.deepcopy(model)
                ddp_model_single.to(self.device)

                ddp_optimizer = optimizer(ddp_model_single.parameters(), lr=1e-3)
                ddp_model = DDP(ddp_model_single, device_ids=[self.rank], broadcast_buffers=True)

                def check_same_model_params():
                    for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
                        for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
                            assert torch.allclose(
                                p, ddp_p, atol=1e-3
                            ), f"Model parameters differ in between Pytorch optim and ZeroRedundancyOptimizer \n{p} {ddp_p}"

                    for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
                        assert torch.allclose(
                            b, ddp_b
                        ), "Model buffers differ in between Pytorch optim and ZeroRedundancyOptimizer"

                # The models should be synchronized across ranks at construction time; check that
                check_same_model_params()

                def check_step():
                    input_tensor = torch.rand((64, 2))

                    def closure_ddp(input_tensor=input_tensor):
                        ddp_optimizer.zero_grad()
                        ddp_loss = ddp_model(input_tensor).abs().sum()
                        ddp_loss.backward()
                        return ddp_loss

                    def closure_sharded(input_tensor=input_tensor):
                        sharded_optimizer.zero_grad()
                        sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
                        sharded_loss.backward()
                        return sharded_loss

                    loss_ddp = cast(torch.Tensor, ddp_optimizer.step(closure=closure_ddp))
                    loss_sharded_optim = cast(torch.Tensor, sharded_optimizer.step(closure=closure_sharded))

                    assert torch.allclose(
                        loss_ddp, loss_sharded_optim
                    ), "Losses differ in between Pytorch optim and ZeroRedundancyOptimizer"

                    check_same_model_params()

                # The models should stay the same across ranks
                for i in range(20):
                    check_step()

                # Check that the checkpoints are compatible
                reference_rank = 0
                # - get states
                ddp_state_dict = ddp_optimizer.state_dict()
                sharded_optimizer.consolidate_state_dict(recipient_rank=reference_rank)
                sharded_optim_state_dict = [sharded_optimizer.state_dict() if self.rank == reference_rank else {}]
                dist.broadcast_object_list(sharded_optim_state_dict, src=reference_rank, group=dist.group.WORLD)
                sharded_optim_state_dict = sharded_optim_state_dict[0]

                # - cross load the states
                # run one step and check that the models are still the same
                ddp_state_dict_ref = copy.deepcopy(ddp_state_dict)  # OSS will remove some states
                ddp_optimizer.load_state_dict(sharded_optim_state_dict)  # mix-up on purpose!
                sharded_optimizer.load_state_dict(ddp_state_dict)
                check_step()

                # - self-load, rewind, and check that nothing breaks
                # run one step and check that the models are still the same
                ddp_optimizer.load_state_dict(ddp_state_dict_ref)
                sharded_optimizer.load_state_dict(sharded_optim_state_dict)
                check_step()
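For completeness, a condensed sketch (not taken from the test suite) of the production pattern these equivalence checks validate: DDP averages gradients while ZeroRedundancyOptimizer shards the optimizer state, and an ordinary checkpoint is saved on rank 0 after consolidation. The model, step count, input shape, and checkpoint file name are placeholders, and one CUDA device per rank is assumed.

import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP

def train_and_checkpoint(rank: int, model: torch.nn.Module, steps: int = 20) -> None:
    model = model.to(rank)
    ddp_model = DDP(model, device_ids=[rank], broadcast_buffers=True)
    optimizer = ZeroRedundancyOptimizer(ddp_model.parameters(),
                                        optimizer_class=torch.optim.SGD,
                                        lr=1e-3,
                                        momentum=0.99)

    for _ in range(steps):
        optimizer.zero_grad()
        loss = ddp_model(torch.rand((64, 2), device=rank)).abs().sum()
        loss.backward()
        optimizer.step()

    # Gather the sharded state on rank 0, then save an ordinary checkpoint there
    optimizer.consolidate_state_dict(to=0)
    if rank == 0:
        torch.save(optimizer.state_dict(), "zero_optimizer_checkpoint.pt")
    dist.barrier()  # make sure the checkpoint exists before other ranks move on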