Example no. 1
    def test_multiple_wrapping(self):
        """
        This test simulates wrapping the module after training to run inference.
        This is required in cases where later in a session, the model is wrapped again in FSDP but
        contains nested FSDP wrappers within the module.
        """
        inner_model = InnerModel()
        model = FSDP(inner_model).cuda()
        optim = SGD(model.parameters(), lr=0.1)

        for i in range(3):
            input = torch.rand((1, 5), dtype=torch.float).cuda()
            input.requires_grad = True
            output = model(input)
            output.sum().backward()
            optim.step()
            optim.zero_grad()
        input = torch.rand((1, 5), dtype=torch.float).cuda()
        output = model(input)

        # Rewrap the inner model a second time
        rewrapped_model = FSDP(inner_model).cuda()
        rewrapped_output = rewrapped_model(input)

        self.assertEqual(output, rewrapped_output)
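The `InnerModel` referenced above is defined elsewhere in the test file; a plausible minimal stand-in matching what the docstring describes (a module that already contains a nested FSDP wrapper) might look like the following. This is an assumption for illustration, not the test's actual definition, and it presumes `torch.distributed` is already initialized.

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class InnerModel(nn.Module):
    # Hypothetical stand-in: the nested FSDP wrapper inside the module is
    # what makes re-wrapping the whole module in FSDP later non-trivial.
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(FSDP(nn.Linear(5, 5)))

    def forward(self, x):
        return self.layers(x)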
Example no. 2
 def wrap_alt(model, group=None) -> torch.nn.Module:
     model.block0.bias_module0 = FSDP(
         model.block0.bias_module0,
         process_group=group,
     )
     model.block0 = FSDP(model.block0, process_group=group)
     return model
Example no. 3
    def test_diff_ignored_modules_across_ranks(
            self, pass_ignored_modules_to_root: bool):
        """
        Tests ignoring different modules across ranks.

        Args:
            pass_ignored_modules_to_root (bool): If ``False``, does not pass
                any ignored modules (including those already ignored in child
                FSDP instances) to the root FSDP instance; if ``True``, passes
                all ignored modules (representing a superset of the children's
                ignored modules) to the root FSDP instance.
        """
        # To exercise different `FlatParameter` enumerations across ranks,
        # we wrap `layer3` with FSDP, where `layer3` is registered as a module
        # after `layer1`, which has a variable number of ignored modules
        model = ModelWithIgnoredModules(num_ignored=self.rank + 1).cuda()
        layer1_ignored_modules = [
            m for m in model.layer1.modules() if isinstance(m, IgnoredModule)
        ]
        model.layer1 = FSDP(model.layer1,
                            ignored_modules=layer1_ignored_modules)
        model.layer3 = FSDP(model.layer3)
        model_ignored_modules = [
            m for m in model.modules() if isinstance(m, IgnoredModule)
        ] if pass_ignored_modules_to_root else []
        wrapped_model = FSDP(model, ignored_modules=model_ignored_modules)
        optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
        self._train_model(wrapped_model, optim, 3)
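A sketch of the same pattern with plain modules (an illustration only, not the test's `ModelWithIgnoredModules` / `IgnoredModule` helpers): each rank ignores a different number of submodules in the child FSDP instance, and the root may optionally be given the superset of all ignored modules. Assumes an initialized process group with one GPU per rank.

import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

rank = dist.get_rank()
# A rank-dependent number of "ignored" layers, mirroring num_ignored=rank + 1.
ignored = [nn.Linear(3, 3) for _ in range(rank + 1)]
layer1 = nn.Sequential(nn.Linear(3, 3), *ignored)
model = nn.Sequential(layer1, nn.Linear(3, 3)).cuda()

# The child FSDP instance ignores its own extra modules ...
model[0] = FSDP(model[0], ignored_modules=ignored)
# ... and the root may receive the superset of every rank-local ignored module.
root = FSDP(model, ignored_modules=ignored)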
    def test_summon_full_params_respects_reshard_after_forward(
            self, mixed_precision):
        mixed_precision = MixedPrecision() if mixed_precision else None
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(5, 5, bias=False),
                     mixed_precision=mixed_precision),
                nn.Linear(5, 3, bias=False),
            ),
            mixed_precision=mixed_precision,
        ).cuda(self.rank)

        outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        inner_param = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )
        outer_full_param_size = outer_param.numel() * self.world_size

        # trigger lazy init
        model(torch.zeros(5).cuda(self.rank))
        # the root FSDP module keeps all params around
        self.assertEqual(outer_full_param_size,
                         outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        # Similarly, summon_full_params should show the same behavior
        with model.summon_full_params(model):
            pass
        self.assertEqual(outer_full_param_size,
                         outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())
 def test_ignored_modules_nested(self):
     """Tests that passing a module with nested FSDP modules does not
     error and still ignores non-FSDP modules' parameters."""
     # Initialize an FSDP-wrapped nested model that first wraps the nested
     # sequential's middle linear layer (`layer1[1]`) and then wraps the
     # overall model while ignoring the nested sequential (`layer1`)
     model = Model().cuda()
     model.layer1[1] = FSDP(model.layer1[1])
     wrapped_model = FSDP(model, ignored_modules=[model.layer1])
     # Check that the wrapped model's flattened parameter does not include
     # the ignored nested sequential's parameters
     nonwrapped_model = Model()
     total_numel = sum(p.numel() for p in nonwrapped_model.parameters())
     ignored_numel = sum(p.numel()
                         for p in nonwrapped_model.layer1.parameters())
     nonignored_numel = total_numel - ignored_numel
     with FSDP.summon_full_params(wrapped_model):
         flat_param_numel = wrapped_model.params[0].numel()
         self.assertEqual(flat_param_numel, nonignored_numel)
     # Check that we can run a few iterations
     device = torch.device("cuda")
     optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
     for _ in range(3):
         inp = wrapped_model.get_input(device)
         output = wrapped_model(*inp)
         loss = wrapped_model.get_loss(inp, output).to(device)
         wrapped_model.run_backward(loss)
         optim.step()
Example no. 6
 def test_no_params(self):
     """
     Test that device_id and CPU init work if the module has no params
     (they are effectively no-ops, but ensure FSDP does not assume the
     module has parameters during init).
     """
     # Test CPU
     no_params = nn.ReLU()
     module = FSDP(no_params)
     # Test CUDA
     no_params = nn.ReLU().cuda()
     module = FSDP(no_params)
     # Test CPU + device_id
     no_params = nn.ReLU()
     module = FSDP(no_params, device_id=torch.cuda.current_device())
     # For modules with no params, a wrong device_id raises an error about the
     # inconsistency between compute_device and device_id, since compute_device
     # falls back to torch.cuda.current_device() when there are no params.
     no_params = nn.ReLU().cuda()
     context = (
         self.assertRaisesRegex(
             AssertionError,
             f"Inconsistent.*cuda:{self.rank} vs cuda:0"
         )
     ) if self.rank != 0 else suppress()
     with context:
         module = FSDP(no_params, device_id=0)
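Since a parameterless module gives FSDP no parameters to infer a device from, its compute device falls back to the current CUDA device. A minimal sketch of the call that avoids the assertion above, assuming an initialized process group and one GPU per rank:

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Pin each rank to its own GPU so device_id matches the compute device.
torch.cuda.set_device(dist.get_rank())
module = FSDP(nn.ReLU(), device_id=torch.cuda.current_device())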
Example no. 7
    def test_fsdp_same_model_across_ranks(self):
        """
        FSDP broadcasts the model from rank 0 so that all ranks start off with
        the same values.
        """
        class MyModel(nn.Module):
            def __init__(self, rank):
                super().__init__()
                # Seed via rank to make model different across ranks
                torch.manual_seed(rank)
                torch.cuda.manual_seed(rank)
                self.lin = nn.Linear(10, 10, bias=False)
                self.register_buffer("buffer", torch.ones(1) * rank)

        m = MyModel(self.rank).cuda()
        _assert_module_states(m, process_group=self.process_group, assert_fn=self.assertNotEqual)
        # Passing sync_module_states=True into FSDP makes the model identical
        # across ranks during init.
        fsdp = FSDP(m, sync_module_states=True)
        with fsdp.summon_full_params(fsdp):
            _assert_module_states(fsdp, process_group=self.process_group, assert_fn=self.assertEqual)

        # sync_module_states also works with a CPU module when device_id is passed in
        m = MyModel(self.rank)
        _assert_module_states(m, process_group=self.process_group, assert_fn=self.assertNotEqual)
        # Passing sync_module_states=True into FSDP makes the model identical
        # across ranks during init.
        fsdp = FSDP(m, device_id=torch.cuda.current_device(), sync_module_states=True)
        with fsdp.summon_full_params(fsdp):
            _assert_module_states(fsdp, process_group=self.process_group, assert_fn=self.assertEqual)
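A minimal standalone sketch of the same synchronization, assuming an initialized process group and that each rank has set its CUDA device:

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Deliberately seed differently per rank so the local modules disagree, then
# let sync_module_states=True broadcast rank 0's parameters and buffers during
# FSDP init; device_id moves the CPU module to the right GPU first.
torch.manual_seed(dist.get_rank())
local = nn.Linear(10, 10, bias=False)
fsdp = FSDP(local, device_id=torch.cuda.current_device(), sync_module_states=True)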
Example no. 8
    def __init__(self, has_wrapping, sharding_strategy, mixed_precision=None):
        # to ensure determinism
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)
        super().__init__()

        if has_wrapping:
            self.net = FSDP(
                nn.Sequential(
                    nn.Linear(8, 16),
                    nn.ReLU(),
                    FSDP(
                        nn.Linear(16, 8),
                        device_id=torch.cuda.current_device(),
                        sharding_strategy=sharding_strategy,
                        mixed_precision=mixed_precision,
                    ),
                ),
                device_id=torch.cuda.current_device(),
                sharding_strategy=sharding_strategy,
                mixed_precision=mixed_precision,
            )
        else:
            self.net = nn.Sequential(
                nn.Linear(8, 16),
                nn.ReLU(),
                nn.Linear(16, 8)
            )

        self.out = nn.Linear(8, 4)
def _run_test_summon_full_param_writeback(cls, writeback, cpu_offload,
                                          modify_outer):
    model = FSDP(
        nn.Sequential(FSDP(nn.Linear(5, 5, bias=False)),
                      nn.Linear(5, 3, bias=False))).cuda(cls.rank)

    # set the value
    outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
    inner_param = model.get_parameter(
        "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param")
    p = outer_param if modify_outer else inner_param

    with torch.no_grad():
        # This sets the local shard value
        p[0] = cls.rank + 2

    with model.summon_full_params(writeback=writeback):
        with torch.no_grad():
            p.copy_(torch.zeros_like(p))

    if writeback or cls.world_size == 1:
        # When world_size == 1, FSDP does not shard, so the parameter is not
        # replaced by a local shard and the write is always reflected.
        cls.assertEqual(p.cpu()[0], 0)
    else:
        cls.assertEqual(p.cpu()[0], cls.rank + 2)
    def test_summon_full_param_writeback(
        self, writeback, cpu_offload, modify_outer
    ):
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(5, 5, bias=False)), nn.Linear(5, 3, bias=False)
            )
        ).cuda(self.rank)

        # set the value
        outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        inner_param = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )
        p = outer_param if modify_outer else inner_param

        with torch.no_grad():
            # This sets the local shard value
            p[0] = self.rank + 2

        with model._summon_full_params(writeback=writeback):
            with torch.no_grad():
                p.copy_(torch.zeros_like(p))

        if writeback:
            self.assertEqual(p.cpu()[0], 0)
        else:
            self.assertEqual(p.cpu()[0], self.rank + 2)
    def test_summon_full_param_recursive(self, recurse, summon_outer):
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(5, 5, bias=False)), nn.Linear(5, 3, bias=False)
            )
        ).cuda(self.rank)

        global_inner_numel = self.get_model_param_count(nn.Linear(5, 5, bias=False))
        global_outer_numel = self.get_model_param_count(nn.Linear(5, 3, bias=False))

        shard_inner_numel = int(math.ceil(global_inner_numel / self.world_size))
        shard_outer_numel = int(math.ceil(global_outer_numel / self.world_size))

        outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        inner_param = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )
        self.assertEqual(shard_outer_numel, outer_param.numel())
        self.assertEqual(shard_inner_numel, inner_param.numel())

        model_to_summon = model if summon_outer else model[0]
        # The outer flat param is summoned if _summon_full_params is called
        # on the outer FSDP module
        expected_outer_numel = global_outer_numel if summon_outer else shard_outer_numel

        # The inner flat param is summoned if _summon_full_params is called
        # with recursion or on the inner FSDP module
        expected_inner_numel = (
            global_inner_numel if recurse or not summon_outer else shard_inner_numel
        )

        with model_to_summon._summon_full_params(recurse=recurse):
            self.assertEqual(expected_outer_numel, outer_param.numel())
            self.assertEqual(expected_inner_numel, inner_param.numel())
    def test_reshard_outside_forward_backward_iteration(
            self, rank0_only, offload_to_cpu, mixed_precision):
        mixed_precision = MixedPrecision() if mixed_precision else None
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(5, 5, bias=False),
                     mixed_precision=mixed_precision),
                nn.Linear(5, 1, bias=False),
            ),
            mixed_precision=mixed_precision,
        ).cuda(self.rank)

        outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        inner_param = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )
        outer_full_param_size = outer_param.numel() * self.world_size

        # First, let's validate our assumption about resharding

        output = model(torch.zeros(5).cuda(self.rank))
        # the root FSDP module keeps all params around
        self.assertEqual(outer_full_param_size,
                         outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        output.backward()
        # we reshard everything after backward() finishes
        self.assertEqual(0, outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        # Now let's repeat it with a summon done in between

        output = model(torch.zeros(5).cuda(self.rank))
        self.assertEqual(outer_full_param_size,
                         outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())
        with model.summon_full_params(
                model,
                rank0_only=rank0_only,
                writeback=not rank0_only,
                offload_to_cpu=offload_to_cpu,
        ):
            pass
        self.assertEqual(outer_full_param_size,
                         outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        output.backward()
        with model.summon_full_params(
                model,
                rank0_only=rank0_only,
                writeback=not rank0_only,
                offload_to_cpu=offload_to_cpu,
        ):
            pass
        self.assertEqual(0, outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())
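A typical read-only inspection pattern using the same arguments exercised above: materialize the full parameters only on rank 0, offload them to CPU, and skip writeback since nothing is modified. A minimal sketch, assuming an initialized process group and a CUDA device per rank:

import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = FSDP(nn.Linear(5, 5, bias=False).cuda())
with FSDP.summon_full_params(
    model, rank0_only=True, writeback=False, offload_to_cpu=True
):
    # Only rank 0 sees the unsharded parameters; other ranks keep their shards.
    if dist.get_rank() == 0:
        full_numels = [p.numel() for p in model.parameters()]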
Example no. 13
def create_model(with_fsdp, with_checkpoint, model_hidden_dim):
    torch.manual_seed(0)
    model = Model(model_hidden_dim, with_fsdp, with_checkpoint)
    if with_fsdp:
        model.stem = FSDP(model.stem)
        model.blocks = FSDP(model.blocks)
        model.head = FSDP(model.head)

    return model
Example no. 14
def _create_model(compute_cycles, has_params: bool):
    model = FSDP(
        nn.Sequential(
            FSDP(Layer(compute_cycles, has_params)),
            FSDP(Layer(compute_cycles, has_params)),
            FSDP(Layer(compute_cycles, has_params)),
            FSDP(Layer(compute_cycles, has_params)),
        )).cuda()
    return model
Example no. 15
 def _get_simple_nested_model(self, *fsdp_args, **fsdp_kwargs):
     model = FSDP(
         nn.Sequential(
             FSDP(nn.Linear(10, 10, bias=False), *fsdp_args, **fsdp_kwargs),
             nn.Linear(10, 10, bias=False),
         ),
         *fsdp_args,
         **fsdp_kwargs,
     )
     return model
Example no. 16
 def _get_simple_nested_model(self, param_dtype, *fsdp_args, **fsdp_kwargs):
     model = FSDP(
         nn.Sequential(
             FSDP(LinearMixedPrecision(param_dtype).cuda(), *fsdp_args, **fsdp_kwargs),
             LinearMixedPrecision(param_dtype).cuda(),
         ),
         *fsdp_args,
         **fsdp_kwargs,
     )
     return model
Example no. 17
 def test_ignored_modules_invalid(self):
     """Tests that passing an FSDP module as an ignored module errors."""
     model = Model()
     model.layer1 = FSDP(model.layer1)
     # Passing an FSDP module as an ignored module should error
     with self.assertRaises(
         ValueError,
         msg="`ignored_modules` should not include FSDP modules",
     ):
         FSDP(model, ignored_modules=[model.layer1])
Example no. 18
    def test_fsdp_device_id(self, use_index):
        """
        If a CPU module is passed into FSDP with a device_id
        argument, it is moved to the GPU with that device_id.
        """
        dev_id = (
            torch.cuda.current_device() if use_index
            else torch.device("cuda", torch.cuda.current_device())
        )

        def _check_device_matches(fsdp, dev_id):
            devices = {p.device for p in fsdp.parameters()}
            self.assertEqual(1, len(devices))
            found_dev = devices.pop()
            if use_index and not isinstance(dev_id, torch.device):
                dev_id = torch.device("cuda", dev_id)
            self.assertEqual(found_dev, dev_id)

        mod = NestedWrappedModule(
            group=self.process_group,
            wrap_fsdp=True,
            wrap_everything=True,
            fsdp_init_mode=FSDPInitMode.CUDA_NEVER,
            device_id=dev_id
        )
        fsdp = FSDP(mod, device_id=dev_id)
        # Check FSDP parameters are moved.
        _check_device_matches(fsdp, dev_id)
        # device_id matching module device before FSDP construction
        # should not throw errors.
        mod = NestedWrappedModule(
            group=self.process_group,
            wrap_fsdp=True,
            wrap_everything=True,
            fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
            device_id=dev_id
        )
        fsdp = FSDP(mod, device_id=dev_id)
        _check_device_matches(fsdp, dev_id)
        # Passing in torch.device("cuda") should work.
        regex = "does not have explicit index"
        context = self.assertWarnsRegex(
            expected_warning=UserWarning, expected_regex=regex
        )
        with context:
            mod = NestedWrappedModule(
                group=self.process_group,
                wrap_fsdp=True,
                wrap_everything=True,
                fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
                device_id=torch.device("cuda")
            )
            fsdp = FSDP(mod, device_id=torch.device("cuda"))
        _check_device_matches(fsdp, torch.device("cuda", torch.cuda.current_device()))
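The behavior described in the docstring, sketched on a plain module rather than the test's `NestedWrappedModule` helper (assumes an initialized process group and that each rank has set its CUDA device):

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# A CPU-resident module passed with device_id ends up on that GPU after init.
cpu_module = nn.Linear(8, 8)
fsdp = FSDP(cpu_module, device_id=torch.cuda.current_device())
assert all(p.device.type == "cuda" for p in fsdp.parameters())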
Example no. 19
    def __init__(
        self,
        group: dist.ProcessGroup,
        wrap_fsdp: bool,
        cuda_init_mode: CUDAInitMode,
        delay_before_free_ms: int,
        deterministic: bool,
        **fsdp_kwargs,
    ):
        super().__init__(
            group=group,
            wrap_fsdp=wrap_fsdp,
            cuda_init_mode=cuda_init_mode,
            deterministic=deterministic,
        )
        self.group = group
        self.delay_before_free_ms = delay_before_free_ms
        self.wrap_fsdp = wrap_fsdp
        self.move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE
        if deterministic:
            # Give each rank different expert parameters
            torch.manual_seed(42 + self.rank)
        d_expert = 23
        d_shared = 12
        d_input = 8
        expert = _maybe_cuda(nn.Linear(d_expert, d_shared), self.move_to_cuda)

        self.num_expert_params = sum([p.numel() for p in expert.parameters()])
        for p in expert.parameters():
            p.expert = True  # type: ignore[attr-defined]

        if deterministic:
            # Keep all other parameters the same across ranks
            torch.manual_seed(0)

        shared = _maybe_cuda(nn.Linear(d_shared, d_expert), self.move_to_cuda)

        if wrap_fsdp:
            # we create a process group of size 1 for the expert params
            expert_group = torch.distributed.new_group(
                [group.rank()])  # world size 1 means no shard
            expert = FSDP(expert, expert_group,
                          **fsdp_kwargs)  # type: ignore[assignment]
            shared = FSDP(shared, group,
                          **fsdp_kwargs)  # type: ignore[assignment]

        self.module = nn.Sequential(
            _maybe_cuda(nn.Linear(d_input, d_shared), self.move_to_cuda),
            shared, expert,
            _maybe_cuda(nn.Linear(d_shared, d_input), self.move_to_cuda))
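A sketch of the expert/shared split built above, stripped of the test scaffolding: each rank wraps its expert in a process group of size 1, so the expert's parameters are not sharded (they stay rank-local), while the shared layer is sharded over the default group. Dimensions are the same placeholders as above; an initialized process group is assumed.

import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

expert = nn.Linear(23, 12)
shared = nn.Linear(12, 23)
expert_group = dist.new_group([dist.get_rank()])  # world size 1 -> no sharding
expert = FSDP(expert, process_group=expert_group)
shared = FSDP(shared)  # sharded over the default (global) process group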
Example no. 20
 def test_fsdp_cpu_init_stays_on_cpu(self):
     """
     Ensure that a module initialized on CPU stays on CPU after FSDP init
     and that a warning is logged.
     """
     torch.cuda.set_device(self.rank)
     regex = "Module is put on CPU"
     context = self.assertWarnsRegex(expected_warning=UserWarning,
                                     expected_regex=regex)
     with context:
         mod = NestedWrappedModule(
             group=self.process_group,
             wrap_fsdp=True,
             wrap_everything=True,
             fsdp_init_mode=FSDPInitMode.CUDA_NEVER,
         )
         fsdp = FSDP(mod)
     devices = {p.device for p in fsdp.parameters()}
     self.assertEqual(1, len(devices))
     self.assertEqual(torch.device("cpu"), devices.pop())
     fsdp = fsdp.cuda()
     # Ensure fwd + backward can be performed after moving to CUDA.
     # A CPU input also tests that the input is correctly moved to the
     # appropriate CUDA device.
     inp = mod.get_input(device=torch.device("cpu"))
     fsdp(inp[0]).sum().backward()
Example no. 21
 def test_wrong_state_dict_config(self):
     model = FSDP(Model(wrap_fsdp=True).cuda())
     with self.assertRaisesRegex(RuntimeError,
                                 "Expected state_dict_config of type"):
         with model.state_dict_type(model, StateDictType.FULL_STATE_DICT,
                                    LocalStateDictConfig()):
             pass
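For contrast with the failing combination tested above, the matching pairing of state dict type and config, sketched on a trivially wrapped module (assumes an initialized process group):

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType, FullStateDictConfig

# FULL_STATE_DICT pairs with FullStateDictConfig; the config options shown
# here are common choices, not values taken from the test.
model = FSDP(nn.Linear(4, 4).cuda())
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
    state = model.state_dict()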
Example no. 22
 def _init_model(
     self,
     nested_model: bool,
     sharding_strategy: ShardingStrategy,
     device: torch.device,
 ):
     fsdp_kwargs = {"sharding_strategy": sharding_strategy}
     if nested_model:
         model = NestedWrappedModule.init(
             self.process_group,
             FSDPInitMode.RECURSIVE,
             CUDAInitMode.CUDA_AFTER,
             fsdp_kwargs,
         )
         fsdp_model: FSDP = FSDP(
             model,
             self.process_group,
             **fsdp_kwargs,
         ).to(device)
     else:
         fsdp_model: FSDP = TransformerWithSharedParams.init(
             self.process_group,
             FSDPInitMode.RECURSIVE,
             CUDAInitMode.CUDA_BEFORE,
             fsdp_kwargs,
         )
     return fsdp_model
 def test_named_parameters_buffers(self, prefix: str, recurse: bool):
     """Tests that ``named_parameters()`` and ``named_buffers()`` for a
     top-level FSDP-wrapped model matches their behavior for the equivalent
     non-wrapped model."""
     model = NestedWrappedModule.init(
         self.process_group,
         FSDPInitMode.NO_FSDP,
         CUDAInitMode.CUDA_BEFORE,
         deterministic=True,
     )
     model.register_buffer("buffer", torch.ones(1))
     # `named_parameters()` and `named_buffers()` will contain FSDP prefixes
     # if called on a non-FSDP root module
     fsdp_model = FSDP(
         NestedWrappedModule.init(
             self.process_group,
             FSDPInitMode.NO_FSDP,
             CUDAInitMode.CUDA_BEFORE,
             deterministic=True,
         ),
         self.process_group,
     )
     fsdp_model.register_buffer("buffer", torch.ones(1))
     with FSDP.summon_full_params(fsdp_model):
         for call in ["named_parameters", "named_buffers"]:
             for (n1, p1), (n2, p2) in itertools.zip_longest(
                     getattr(fsdp_model, call)(prefix=prefix,
                                               recurse=recurse),
                     getattr(model, call)(prefix=prefix, recurse=recurse),
             ):
                 self.assertEqual(n1, n2)
                 self.assertEqual(p1, p2)
    def test_summon_full_params_equivalence(self, rank0_only, offload_to_cpu):
        offload = CPUOffload(offload_params=True)
        model = FSDP(DeterministicModel(wrap_fsdp=True, cpu_offload=offload),
                     cpu_offload=offload)
        local_model = DeterministicModel(wrap_fsdp=False)

        dev = (torch.device("cpu") if offload_to_cpu else torch.device(
            "cuda", torch.cuda.current_device()))

        params_to_compare = ([
            p.clone() for p in model.parameters()
        ] if rank0_only and self.rank != 0 else list(local_model.parameters()))

        with model.summon_full_params(
                model,
                recurse=True,
                rank0_only=rank0_only,
                writeback=not rank0_only,
                offload_to_cpu=offload_to_cpu,
        ):
            # The sleep below causes failures without the stream
            # synchronization from the summon_full_params fix.
            torch.cuda._sleep(1000000)
            # deepcopy() of FSDP params has issues, so clone instead
            fsdp_params = [p.clone() for p in model.parameters()]

        self.assertEqual(fsdp_params, params_to_compare)
    def test_params_count_and_value(self, rank0_only, offload_to_cpu,
                                    mixed_precision):
        mixed_precision = MixedPrecision() if mixed_precision else None
        fsdp_model = FSDP(
            NestedWrappedModule(
                group=dist.distributed_c10d._get_default_group(),
                wrap_fsdp=True,
                fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
                mixed_precision=mixed_precision,
            ),
            mixed_precision=mixed_precision,
        )
        model = NestedWrappedModule(
            group=dist.distributed_c10d._get_default_group(),
            wrap_fsdp=False,
            fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
        )

        dev = (torch.device("cpu") if offload_to_cpu else torch.device(
            "cuda", torch.cuda.current_device()))

        params_to_compare = ([p.to(dev) for p in model.module.parameters()]
                             if not rank0_only or self.rank == 0 else list(
                                 p.clone() for p in fsdp_model.parameters()))
        with fsdp_model.summon_full_params(fsdp_model,
                                           rank0_only=rank0_only,
                                           writeback=not rank0_only):
            for p1, p2 in itertools.zip_longest(fsdp_model.parameters(),
                                                params_to_compare):
                self.assertEqual(p1, p2)

        # CPU offload should restore the param device
        param = next(fsdp_model.parameters())
        self.assertTrue(
            param.device == torch.device("cuda", torch.cuda.current_device()))
    def test_summon_full_param_shard_value(self, mixed_precision):
        mixed_precision = MixedPrecision() if mixed_precision else None
        raw_model = nn.Linear(10, 11)
        raw_model_size = self.get_model_param_count(raw_model)
        expected_shard_size = self.get_expected_sharded_size(raw_model_size)

        model = FSDP(raw_model.cuda(self.rank),
                     mixed_precision=mixed_precision)
        self.assertEqual(expected_shard_size,
                         self.get_model_param_count(model))

        # we're assuming a single flattened param
        self.assertEqual(1, len(list(model.parameters())))

        my_shard = torch.clone(next(model.parameters()))

        with model.summon_full_params(model):
            self.assertEqual(raw_model_size, self.get_model_param_count(model))
            parameters = list(model.parameters())
            all_shards = FlatParamHandle.flatten_params(parameters,
                                                        requires_grad=False)
            my_slice = torch.chunk(all_shards, self.world_size)[self.rank]

            # shards are padded but the full_param tensor is not
            a, b = my_shard[0:my_slice.numel()], my_slice
            self.assertTrue(
                torch.equal(my_shard[0:my_slice.numel()].cpu(),
                            my_slice.cpu()))
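For concreteness, a sketch of the shard-size arithmetic behind this test, with a world size of 2 assumed purely for illustration:

import math

# nn.Linear(10, 11) has 10 * 11 weight elements plus 11 bias elements = 121.
# FSDP flattens them into one FlatParameter and pads it to divide evenly
# across ranks, so each rank holds ceil(121 / world_size) elements; the
# trailing element of the last shard is padding, which is why the comparison
# above slices my_shard down to my_slice.numel().
raw_numel = 10 * 11 + 11                         # 121
world_size = 2                                   # assumed for illustration
shard_numel = math.ceil(raw_numel / world_size)  # 61 per rank, 1 padding element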
Example no. 27
    def test_state_dict_with_ignored_modules(self):
        # Initialize an FSDP-wrapped model with an ignored module
        model = Model(wrap_fsdp=True).cuda()
        ignored_modules = [model.outer]
        ignored_param_to_param_name = {
            model.outer.bias: "outer.bias",
            model.outer.weight: "outer.weight",
        }
        fsdp_model = FSDP(model, ignored_modules=ignored_modules)
        with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
            sd = fsdp_model.state_dict()

        with FSDP.summon_full_params(fsdp_model):
            fsdp_params = deepcopy(list(fsdp_model.parameters()))
        # Check that the ignored parameters are not cloned
        for param, param_name in ignored_param_to_param_name.items():
            self.assertTrue(param_name in sd)
            self.assertEqual(param.data_ptr(), sd[param_name].data_ptr())
        # Check that the state dict can be loaded into a non-wrapped version of
        # the model
        nonwrapped_model = Model(wrap_fsdp=False).cuda()
        for param in nonwrapped_model.parameters():
            with torch.no_grad():
                param.zero_()

        nonwrapped_model.load_state_dict(sd)
        local_params = list(nonwrapped_model.parameters())
        for fsdp_param, local_param in zip(fsdp_params, local_params):
            self.assertEqual(fsdp_param, local_param)
Example no. 28
 def init(
     group: dist.ProcessGroup,
     fsdp_init_mode: FSDPInitMode,
     cuda_init_mode: CUDAInitMode,
     fsdp_kwargs: Optional[Dict[str, Any]] = None,
     deterministic: bool = False,
 ):
     """
     Initializes a :class:`NestedWrappedModule` instance, but unlike
     :meth:`NestedWrappedModule.init`, for the ``RECURSIVE`` init mode, this
     wraps with top-level FSDP and the ``always_wrap_policy()`` auto wrap
     policy.
     """
     super_ = super(AlwaysWrapNestedWrappedModule,
                    AlwaysWrapNestedWrappedModule)
     model = super_.init(
         group=group,
         fsdp_init_mode=FSDPInitMode.NO_FSDP,
         cuda_init_mode=cuda_init_mode,
         fsdp_kwargs=fsdp_kwargs,
         deterministic=deterministic,
     )
     if fsdp_init_mode == FSDPInitMode.NO_FSDP:
         return model
     elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
         fsdp_model = FSDP(model,
                           auto_wrap_policy=always_wrap_policy,
                           **fsdp_kwargs)
         if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
             fsdp_model = fsdp_model.cuda()
         return fsdp_model
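The `always_wrap_policy` used in the RECURSIVE branch above makes the auto-wrapper put every submodule in its own FSDP instance. A minimal sketch on a plain model, assuming an initialized process group:

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import always_wrap_policy

# Every submodule (including the parameterless ReLU) gets its own FSDP wrapper.
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
fsdp_model = FSDP(model.cuda(), auto_wrap_policy=always_wrap_policy)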
    def test_mixed_precision_resnet(self):
        """
        End-to-end test to ensure mixed precision + auto_wrap works
        for a ResNet model.
        """
        resnet_model = torchvision.models.resnet50().cuda()
        resnet_model = nn.SyncBatchNorm.convert_sync_batchnorm(
            resnet_model,
            process_group=dist.distributed_c10d._get_default_group())
        n_bn = sum(1 if isinstance(x, _BatchNorm) else 0
                   for x in resnet_model.modules())
        inp = torch.ones(1, 3, 1000, 1000, device='cuda')
        mp_config = MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
        )
        fsdp = FSDP(resnet_model,
                    auto_wrap_policy=size_based_auto_wrap_policy,
                    mixed_precision=mp_config)
        # BatchNorm units should be wrapped individually. Validate this by
        # ensuring the number of FSDP units wrapping a BatchNorm equals the
        # number of BN units in the original ResNet model.
        fsdp_bn = 0
        for module in fsdp.fsdp_modules(fsdp):
            wrapped_module = module.module.module
            if isinstance(wrapped_module, _BatchNorm):
                fsdp_bn += 1

        self.assertEqual(fsdp_bn, n_bn)
        # Would throw a type mismatch error without mixed precision auto-wrapping.
        loss = fsdp(inp).sum()
        loss.backward()
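The same ingredients on a generic model, sketched with the commonly documented usage of tuning the size threshold via functools.partial. The model and the min_num_params value are illustrative assumptions, not taken from the test above; an initialized process group is assumed.

import functools
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# fp16 for parameters, gradient reduction, and buffers, plus size-based
# auto wrapping: submodules above the threshold get their own FSDP unit.
policy = functools.partial(size_based_auto_wrap_policy, min_num_params=100_000)
mp_config = MixedPrecision(
    param_dtype=torch.float16,
    reduce_dtype=torch.float16,
    buffer_dtype=torch.float16,
)
model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 10)).cuda()
fsdp_model = FSDP(model, auto_wrap_policy=policy, mixed_precision=mp_config)
loss = fsdp_model(torch.ones(2, 1024, device="cuda")).sum()
loss.backward()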
Example no. 30
 def _init_model(
     self,
     nested_model: bool,
     sharding_strategy: ShardingStrategy,
     device: torch.device,
 ):
     group = dist.distributed_c10d._get_default_group()
     if nested_model:
         model = NestedWrappedModule(
             group,
             wrap_fsdp=True,
             sharding_strategy=sharding_strategy,
         )
         fsdp_model: FSDP = FSDP(
             model,
             group,
             sharding_strategy=sharding_strategy,
         ).to(device)
     else:
         fsdp_model: FSDP = self._get_wrapped_model(
             group,
             cuda_first=False,
             config={"sharding_strategy": sharding_strategy},
         )
     return fsdp_model