    def test_summon_full_param_writeback(
        self, writeback, cpu_offload, modify_outer
    ):
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(5, 5, bias=False)), nn.Linear(5, 3, bias=False)
            )
        ).cuda(self.rank)

        # set the value
        outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        inner_param = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )
        p = outer_param if modify_outer else inner_param

        with torch.no_grad():
            # This sets the local shard value
            p[0] = self.rank + 2

        with model._summon_full_params(writeback=writeback):
            with torch.no_grad():
                p.copy_(torch.zeros_like(p))

        if writeback:
            self.assertEqual(p.cpu()[0], 0)
        else:
            self.assertEqual(p.cpu()[0], self.rank + 2)
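
    # The tests in this file repeatedly look up the outer and inner flat
    # parameters through their wrapped-module paths. A small hypothetical
    # convenience (not part of FSDP or of the original test class) capturing
    # the same two lookups used above:
    def _get_outer_inner_flat_params(self, model):
        outer = model.get_parameter("_fsdp_wrapped_module.flat_param")
        inner = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )
        return outer, inner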

    def test_summon_full_params_respects_reshard_after_forward(
            self, mixed_precision):
        mixed_precision = MixedPrecision() if mixed_precision else None
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(5, 5, bias=False),
                     mixed_precision=mixed_precision),
                nn.Linear(5, 3, bias=False),
            ),
            mixed_precision=mixed_precision,
        ).cuda(self.rank)

        outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        inner_param = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )
        outer_full_param_size = outer_param.numel() * self.world_size

        # trigger lazy init
        model(torch.zeros(5).cuda(self.rank))
        # the root FSDP module keeps all params around
        self.assertEqual(outer_full_param_size,
                         outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        # summon_full_params should leave the parameters in the same state
        with model.summon_full_params(model):
            pass
        self.assertEqual(outer_full_param_size,
                         outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())
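
    # A compact restatement of the invariant asserted above, relying on the
    # same private `_full_param_padded` attribute the test already uses: after
    # a forward pass the root FSDP module keeps its padded full parameter
    # materialized, while non-root modules free theirs.
    # `_assert_root_keeps_full_param` is a hypothetical helper, not part of
    # the original test class.
    def _assert_root_keeps_full_param(self, outer_param, inner_param):
        expected_outer_size = outer_param.numel() * self.world_size
        self.assertEqual(expected_outer_size,
                         outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())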

    def test_summon_full_param_recursive(self, recurse, summon_outer):
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(5, 5, bias=False)), nn.Linear(5, 3, bias=False)
            )
        ).cuda(self.rank)

        global_inner_numel = self.get_model_param_count(nn.Linear(5, 5, bias=False))
        global_outer_numel = self.get_model_param_count(nn.Linear(5, 3, bias=False))

        shard_inner_numel = int(math.ceil(global_inner_numel / self.world_size))
        shard_outer_numel = int(math.ceil(global_outer_numel / self.world_size))

        outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        inner_param = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )
        self.assertEqual(shard_outer_numel, outer_param.numel())
        self.assertEqual(shard_inner_numel, inner_param.numel())

        model_to_summon = model if summon_outer else model[0]
        # outer is summoned if _summon_full_params is called on the outer FSDP module
        expected_outer_numel = global_outer_numel if summon_outer else shard_outer_numel

        # inner is summoned if _summon_full_params is called with recurse=True or on the inner FSDP module
        expected_inner_numel = (
            global_inner_numel if recurse or not summon_outer else shard_inner_numel
        )

        with model_to_summon._summon_full_params(recurse=recurse):
            self.assertEqual(expected_outer_numel, outer_param.numel())
            self.assertEqual(expected_inner_numel, inner_param.numel())
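
    # A minimal sketch of the shard-size arithmetic used above: each rank
    # holds ceil(global_numel / world_size) elements of a flat parameter, with
    # the last shard padded up to that size. `_expected_shard_numel` is a
    # hypothetical helper, not part of FSDP or of the original test class.
    @staticmethod
    def _expected_shard_numel(global_numel, world_size):
        # e.g. a 5x5 weight (25 elements) across 2 ranks -> 13 elements/rank
        return int(math.ceil(global_numel / world_size))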


def _run_test_summon_full_param_writeback(cls, writeback, cpu_offload,
                                          modify_outer):
    model = FSDP(
        nn.Sequential(FSDP(nn.Linear(5, 5, bias=False)),
                      nn.Linear(5, 3, bias=False))).cuda(cls.rank)

    # set the value
    outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
    inner_param = model.get_parameter(
        "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param")
    p = outer_param if modify_outer else inner_param

    with torch.no_grad():
        # This sets the local shard value
        p[0] = cls.rank + 2

    with model.summon_full_params(writeback=writeback):
        with torch.no_grad():
            p.copy_(torch.zeros_like(p))

    if writeback or cls.world_size == 1:
        # When world_size == 1, FSDP does not shard, so the parameter is never
        # replaced by a local shard and the write is always reflected.
        cls.assertEqual(p.cpu()[0], 0)
    else:
        cls.assertEqual(p.cpu()[0], cls.rank + 2)
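
# A hypothetical module-level helper restating the branch above: with
# writeback enabled, or with a single rank (where FSDP keeps the parameter
# unsharded), the zero written inside summon_full_params survives; otherwise
# the rank keeps its original shard value of rank + 2. Provided only as a
# compact summary of the expected outcome, not as part of the test utilities.
def _expected_value_after_summon(writeback, world_size, rank):
    return 0 if (writeback or world_size == 1) else rank + 2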

    def test_reshard_outside_forward_backward_iteration(
            self, rank0_only, offload_to_cpu, mixed_precision):
        mixed_precision = MixedPrecision() if mixed_precision else None
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(5, 5, bias=False),
                     mixed_precision=mixed_precision),
                nn.Linear(5, 1, bias=False),
            ),
            mixed_precision=mixed_precision,
        ).cuda(self.rank)

        outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        inner_param = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )
        outer_full_param_size = outer_param.numel() * self.world_size

        # First, let's validate our assumption about resharding

        output = model(torch.zeros(5).cuda(self.rank))
        # the root FSDP module keeps all params around
        self.assertEqual(outer_full_param_size,
                         outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        output.backward()
        # we reshard everything after backward() finishes
        self.assertEqual(0, outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        # Now let's repeat it with a summon_full_params call in between

        output = model(torch.zeros(5).cuda(self.rank))
        self.assertEqual(outer_full_param_size,
                         outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())
        with model.summon_full_params(
                model,
                rank0_only=rank0_only,
                writeback=not rank0_only,
                offload_to_cpu=offload_to_cpu,
        ):
            pass
        self.assertEqual(outer_full_param_size,
                         outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        output.backward()
        with model.summon_full_params(
                model,
                rank0_only=rank0_only,
                writeback=not rank0_only,
                offload_to_cpu=offload_to_cpu,
        ):
            pass
        self.assertEqual(0, outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())
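
    # In the summon_full_params calls above, writeback is paired as
    # `not rank0_only`, so full-parameter contents are only written back when
    # every rank has materialized them. A hypothetical helper expressing that
    # pairing (not part of FSDP or of the original test class):
    @staticmethod
    def _summon_kwargs(rank0_only, offload_to_cpu):
        return dict(
            rank0_only=rank0_only,
            writeback=not rank0_only,
            offload_to_cpu=offload_to_cpu,
        )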

    def test_params_are_unflattened(self):
        model = FSDP(nn.Linear(self.world_size, 1, bias=False)).cuda(self.rank)

        flattened_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        self.assertEqual(1, flattened_param.numel())

        with model._summon_full_params():
            a = model.weight.flatten().detach()
            b = flattened_param.detach()
            self.assertTrue(torch.equal(a, b))

    def test_reshard_outside_forward_backward_iteration(self):
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(5, 5, bias=False)), nn.Linear(5, 1, bias=False)
            )
        ).cuda(self.rank)

        outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        inner_param = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )
        outer_full_param_size = outer_param.numel() * self.world_size

        # First, let's validate our assumption about resharding

        output = model(torch.zeros(5).cuda(self.rank))
        # the root FSDP module keeps all params around
        self.assertEqual(
            outer_full_param_size, outer_param._full_param_padded.storage().size()
        )
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        output.backward()
        # we reshard everything after backward() finishes
        self.assertEqual(0, outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        # Now let's repeat it with a _summon_full_params call in between

        output = model(torch.zeros(5).cuda(self.rank))
        with model._summon_full_params():
            pass
        self.assertEqual(
            outer_full_param_size, outer_param._full_param_padded.storage().size()
        )
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        output.backward()
        with model._summon_full_params():
            pass
        self.assertEqual(0, outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

    def test_params_are_unflattened(self):
        layer_shape = (10, 12)
        model = nn.Linear(*layer_shape, bias=False).cuda(self.rank)
        fsdp_model = FSDP(deepcopy(model)).cuda(self.rank)

        flattened_param = fsdp_model.get_parameter(
            "_fsdp_wrapped_module.flat_param")
        # Each rank holds a shard of the flat parameter: half of the weight's
        # elements (this assertion assumes a world size of 2).
        self.assertEqual(layer_shape[0] * layer_shape[1] // 2,
                         flattened_param.numel())

        with fsdp_model.summon_full_params():
            self.assertEqual(fsdp_model.weight.shape, model.weight.shape)

    def test_summon_single_param(self):
        model = FSDP(nn.Linear(1, 1, bias=False)).cuda(self.rank)

        p = model.get_parameter("_fsdp_wrapped_module.flat_param")
        self.assertEqual(1, p.numel())

        with torch.no_grad():
            # This sets the local shard value
            p[0] = self.rank + 2

        with model.summon_full_params(model, writeback=True):
            self.assertEqual(1, p.numel())
            with torch.no_grad():
                p.copy_(torch.zeros_like(p))

        # Most ranks hold only padding for this single parameter, so their
        # write above landed in padding; only rank 0 observes the new value.
        if self.rank == 0:
            self.assertEqual(0, p[0])
        else:
            self.assertEqual(self.rank + 2, p[0])
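
    # The single-parameter case above relies on shard padding: the flat
    # parameter has one real element, which lands in rank 0's shard, while
    # every other rank's shard is pure padding, so their writes are dropped on
    # writeback. A hypothetical helper restating that expectation (not part of
    # the original test class):
    def _expected_single_param_value(self):
        return 0 if self.rank == 0 else self.rank + 2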