    def test_params_count_and_value(
        self,
        rank0_only: bool,
        offload_to_cpu: bool,
        mixed_precision: bool,
    ):
        mixed_precision = MixedPrecision() if mixed_precision else None
        model = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
        )
        fsdp_model = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
        )
        dev = (
            torch.device("cpu") if offload_to_cpu
            else torch.device("cuda", torch.cuda.current_device())
        )
        params_to_compare = (
            [p.to(dev) for p in model.module.parameters()]
            if not rank0_only or self.rank == 0
            else [p.clone() for p in fsdp_model.parameters()]
        )
        with FSDP.summon_full_params(
            fsdp_model, rank0_only=rank0_only, writeback=not rank0_only
        ):
            for p1, p2 in itertools.zip_longest(
                fsdp_model.parameters(), params_to_compare
            ):
                self.assertEqual(p1, p2)

        # Exiting the context should restore the original (CUDA) param device,
        # even with CPU offloading
        param = next(fsdp_model.parameters())
        self.assertEqual(
            param.device, torch.device("cuda", torch.cuda.current_device())
        )
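
The context manager exercised above is also the general way to access unsharded parameters on an FSDP model outside the test harness. A minimal sketch, assuming the script is launched with torchrun (one GPU per rank) and using a toy module in place of the test fixture:

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")  # assumes a `torchrun` launch
torch.cuda.set_device(dist.get_rank())

model = FSDP(nn.Linear(8, 8).cuda())

# Gather the full parameters only on rank 0, offloaded to CPU; `writeback`
# must be False when `rank0_only=True`.
with FSDP.summon_full_params(model, rank0_only=True, writeback=False,
                             offload_to_cpu=True):
    if dist.get_rank() == 0:
        print(sum(p.numel() for p in model.parameters()))
# On exit, the parameters are re-sharded and restored to their original device.
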
 def test_named_parameters_buffers(self, prefix: str, recurse: bool):
     """Tests that ``named_parameters()`` and ``named_buffers()`` for a
     top-level FSDP-wrapped model matches their behavior for the equivalent
     non-wrapped model."""
     model = NestedWrappedModule.init(
         self.process_group,
         FSDPInitMode.NO_FSDP,
         CUDAInitMode.CUDA_BEFORE,
         deterministic=True,
     )
     model.register_buffer("buffer", torch.ones(1))
     # `named_parameters()` and `named_buffers()` will contain FSDP prefixes
     # if called on a non-FSDP root module
     fsdp_model = FSDP(
         NestedWrappedModule.init(
             self.process_group,
             FSDPInitMode.NO_FSDP,
             CUDAInitMode.CUDA_BEFORE,
             deterministic=True,
         ),
         self.process_group,
     )
     fsdp_model.register_buffer("buffer", torch.ones(1))
     with FSDP.summon_full_params(fsdp_model):
         for call in ["named_parameters", "named_buffers"]:
             for (n1, p1), (n2, p2) in itertools.zip_longest(
                     getattr(fsdp_model, call)(prefix=prefix,
                                               recurse=recurse),
                     getattr(model, call)(prefix=prefix, recurse=recurse),
             ):
                 self.assertEqual(n1, n2)
                 self.assertEqual(p1, p2)
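
A practical use of the matching behavior tested above is reading full, cleanly named parameters and buffers from an FSDP model, e.g. for logging or export. A rough self-contained sketch under the same torchrun assumptions (module and buffer are illustrative):

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")  # assumes a `torchrun` launch
torch.cuda.set_device(dist.get_rank())

fsdp_model = FSDP(nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4)).cuda())
fsdp_model.register_buffer("buffer", torch.ones(1, device="cuda"))

# Inside summon_full_params(), named_parameters()/named_buffers() on the FSDP
# root yield the original (non-FSDP-prefixed) names with unsharded tensors.
with FSDP.summon_full_params(fsdp_model):
    for name, param in fsdp_model.named_parameters():
        print(name, tuple(param.shape))
    for name, buf in fsdp_model.named_buffers():
        print(name, tuple(buf.shape))
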
Example #3
 def _init_model(
     self,
     nested_model: bool,
     sharding_strategy: ShardingStrategy,
     device: torch.device,
 ):
     fsdp_kwargs = {"sharding_strategy": sharding_strategy}
     if nested_model:
         model = NestedWrappedModule.init(
             self.process_group,
             FSDPInitMode.RECURSIVE,
             CUDAInitMode.CUDA_AFTER,
             fsdp_kwargs,
         )
         fsdp_model: FSDP = FSDP(
             model,
             self.process_group,
             **fsdp_kwargs,
         ).to(device)
     else:
         fsdp_model: FSDP = TransformerWithSharedParams.init(
             self.process_group,
             FSDPInitMode.RECURSIVE,
             CUDAInitMode.CUDA_BEFORE,
             fsdp_kwargs,
         )
     return fsdp_model
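
The `sharding_strategy` kwarg threaded through `_init_model()` comes from `torch.distributed.fsdp.ShardingStrategy`. A hedged sketch of wrapping a plain module with an explicit strategy, again assuming a torchrun launch with one GPU per rank:

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

dist.init_process_group("nccl")  # assumes a `torchrun` launch
torch.cuda.set_device(dist.get_rank())

# FULL_SHARD shards parameters, gradients, and optimizer state across ranks;
# SHARD_GRAD_OP keeps parameters unsharded between forward and backward;
# NO_SHARD replicates the model (DDP-like behavior).
fsdp_model = FSDP(
    nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16)).cuda(),
    sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
)
fsdp_model(torch.randn(2, 16, device="cuda")).sum().backward()
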
Example #4
 def test_fsdp_cpu_init_stays_on_cpu(self):
     """Tests that passing a CPU module to FSDP preserves that the wrapped
     module is on CPU after FSDP initialization, albeit after loging a
     warning, and that FSDP moves CPU input to GPU before the forward."""
     torch.cuda.set_device(self.rank)
     regex = "Module is put on CPU"
     context = self.assertWarnsRegex(
         expected_warning=UserWarning, expected_regex=regex
     )
     with context:
         nested_wrapped_module = NestedWrappedModule.init(
             self.process_group,
             FSDPInitMode.RECURSIVE,
             CUDAInitMode.CUDA_NEVER,
         )
         fsdp_model = FSDP(nested_wrapped_module, self.process_group)
     devices = {p.device for p in fsdp_model.parameters()}
     self.assertEqual(1, len(devices))
     self.assertEqual(torch.device("cpu"), devices.pop())
     fsdp_model = fsdp_model.cuda()
     # Ensure fwd + backward can be performed after moving to CUDA.
     # CPU input also tests that input is correctly moved to appropriate
     # CUDA device.
     inp = fsdp_model.module.get_input(device=torch.device("cpu"))
     fsdp_model(*inp).sum().backward()
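
The same CPU-then-GPU flow can be written without the test helpers. A minimal sketch (toy module, torchrun launch assumed) mirroring what the test asserts: wrapping a CPU module warns and leaves the shards on CPU until the model is explicitly moved:

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")  # assumes a `torchrun` launch
torch.cuda.set_device(dist.get_rank())

fsdp_model = FSDP(nn.Linear(8, 8))  # CPU module: warns, stays on CPU
assert all(p.device == torch.device("cpu") for p in fsdp_model.parameters())

fsdp_model = fsdp_model.cuda()      # move the shards to the current GPU
inp = torch.randn(4, 8)             # CPU input; FSDP moves it to GPU
fsdp_model(inp).sum().backward()
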
Example #5
 def test_nested_module_apply(self):
     """Tests that ``apply()`` modifies parameter values in-place on a
     non-FSDP-root nested FSDP-wrapped model."""
     nested_wrapped_module = NestedWrappedModule.init(
         self.process_group,
         FSDPInitMode.RECURSIVE,
         CUDAInitMode.CUDA_AFTER,
     )
     self._check_apply(nested_wrapped_module)
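
`_check_apply()` is a helper from the test suite; the property it verifies, that `apply()` still mutates parameter values in place after FSDP wrapping, is what makes post-wrap weight initialization possible. A sketch of that pattern (toy model, torchrun launch assumed):

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")  # assumes a `torchrun` launch
torch.cuda.set_device(dist.get_rank())

fsdp_model = FSDP(nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)).cuda())

def init_weights(m: nn.Module) -> None:
    # Re-initialize Linear layers in place.
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        nn.init.zeros_(m.bias)

# FSDP's apply() gathers the full parameters before running `fn`, so the
# in-place initialization sees (and writes back to) the unsharded weights.
fsdp_model.apply(init_weights)
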
Example #6
 def test_module_device_mismatches_device_id(self):
     """Tests that specifying a ``device_id`` argument to FSDP for a GPU
     module that does not match the GPU device ID raises an error."""
     context = (
         self.assertRaisesRegex(
             RuntimeError,
             f"on rank {self.rank}.*cuda:0, but is on cuda:{self.rank}"
         ) if self.rank != 0 else suppress()
     )
     with context:
         NestedWrappedModule.init(
             self.process_group,
             FSDPInitMode.RECURSIVE,
             # Move wrapped modules to CUDA before wrapping with FSDP
             cuda_init_mode=CUDAInitMode.CUDA_BEFORE,
             # Should raise an error on nonzero ranks, which are given
             # `device_id=0` while their module is on `cuda:{rank}`
             fsdp_kwargs={"device_id": 0},
         )
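
The error case above comes from hard-coding `device_id=0` on every rank; the conventional fix is to derive the device from the local rank. A rough sketch of that pattern (torchrun sets the `LOCAL_RANK` environment variable; the module is illustrative):

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")  # assumes a `torchrun` launch
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# Each rank passes its own GPU, avoiding the cross-rank mismatch the test
# provokes with a fixed `device_id=0`.
fsdp_model = FSDP(nn.Linear(8, 8), device_id=local_rank)
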
 def test_raises_rank0_with_writeback(self):
     """Tests that ``summon_full_params()`` with both ``rank0_only=True``
     and ``writeback=True`` raises an error."""
     nested_wrapped_module = NestedWrappedModule.init(
         self.process_group,
         FSDPInitMode.RECURSIVE,
         CUDAInitMode.CUDA_BEFORE,
     )
     with self.assertRaisesRegex(ValueError, "is not supported"):
         with FSDP.summon_full_params(nested_wrapped_module,
                                      rank0_only=True,
                                      writeback=True):
             pass
Example #8
 def test_fsdp_modules(self):
     nested_wrapped_module = NestedWrappedModule.init(
         self.process_group,
         FSDPInitMode.RECURSIVE,
         CUDAInitMode.CUDA_BEFORE,
     )
     modules = FSDP.fsdp_modules(nested_wrapped_module)
     self.assertEqual(modules, [
         nested_wrapped_module.module.get_submodule("1"),
         nested_wrapped_module.module.get_submodule("1").get_submodule("0"),
         nested_wrapped_module.module.get_submodule("2"),
     ])
     modules = FSDP.fsdp_modules(nested_wrapped_module, root_only=True)
     self.assertEqual(modules, [
         nested_wrapped_module.module.get_submodule("1"),
         nested_wrapped_module.module.get_submodule("2"),
     ])
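
`FSDP.fsdp_modules()` is a static helper that returns every FSDP instance nested inside a module, optionally only the root ones. A brief sketch with explicit manual nesting instead of the test fixture (torchrun launch assumed):

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")  # assumes a `torchrun` launch
torch.cuda.set_device(dist.get_rank())

inner = FSDP(nn.Linear(8, 8).cuda())                   # nested instance
outer = FSDP(nn.Sequential(inner, nn.Linear(8, 8).cuda()))

print(len(FSDP.fsdp_modules(outer)))                   # 2: outer and inner
print(len(FSDP.fsdp_modules(outer, root_only=True)))   # 1: only the root
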
Example #9
    def test_cpu_init_with_sync_module_states(self):
        """Tests that passing ``sync_module_states=True`` raises an error for
        a CPU module since the synchronization requires GPU communication,
        while additionally passing ``device_id`` does not raise an error."""
        nested_wrapped_module = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_NEVER,
        )
        with self.assertRaisesRegex(
            ValueError,
            "Module has CPU parameters, but sync_module_states=True is specified."
        ):
            FSDP(nested_wrapped_module, self.process_group, sync_module_states=True)

        # Specifying device_id with sync_module_states=True works.
        FSDP(
            nested_wrapped_module,
            self.process_group,
            device_id=torch.cuda.current_device(),
            sync_module_states=True,
        )
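
`sync_module_states=True` broadcasts the module's parameters and buffers from rank 0 during FSDP initialization, which is why a CPU module needs `device_id` to place them on GPU first. A minimal sketch of the working combination (a toy module stands in for, e.g., a checkpoint loaded on CPU):

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")  # assumes a `torchrun` launch
torch.cuda.set_device(dist.get_rank())

# `device_id` moves the CPU module to this rank's GPU first, so the rank-0
# broadcast implied by `sync_module_states=True` can run over NCCL.
fsdp_model = FSDP(
    nn.Linear(8, 8),
    device_id=torch.cuda.current_device(),
    sync_module_states=True,
)
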
Example #10
    def test_fsdp_device_id(self, use_index):
        """
        Tests the FSDP ``device_id`` argument:
          - Wrapping a CPU module should move the module to the GPU matching
            ``device_id``
          - Wrapping a GPU module already on the GPU matching ``device_id``
            should not raise an error
          - Wrapping a GPU module already on GPU and passing a GPU device
            without an explicit device index (i.e. ``torch.device("cuda")``)
            warns
        """
        dev_id = (
            torch.cuda.current_device() if use_index
            else torch.device("cuda", torch.cuda.current_device())
        )

        def _check_device_matches(module, device_id):
            """Checks that the ``FlatParameter``s in ``module`` have device
            matching ``device_id``."""
            devices = {
                p.device for p in module.parameters()
                if isinstance(p, FlatParameter)
            }
            assert len(devices) > 0
            self.assertEqual(1, len(devices))
            found_device = devices.pop()
            if use_index and not isinstance(device_id, torch.device):
                device = torch.device("cuda", device_id)
            else:
                device = device_id
            self.assertEqual(found_device, device)

        # Check that FSDP parameters are moved to `device_id` for a CPU module
        nested_wrapped_module = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_NEVER,
            fsdp_kwargs={"device_id": dev_id},
        )
        _check_device_matches(nested_wrapped_module, dev_id)
        # Check that specifying `device_id` for a GPU module already on that
        # device does not raise an error
        nested_wrapped_module = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs={"device_id": dev_id},
        )
        _check_device_matches(nested_wrapped_module, dev_id)
        # Check that passing in `torch.device("cuda")` for a GPU module warns
        regex = "does not have explicit index"
        context = self.assertWarnsRegex(
            expected_warning=UserWarning, expected_regex=regex
        )
        with context:
            nested_wrapped_module = NestedWrappedModule.init(
                self.process_group,
                FSDPInitMode.RECURSIVE,
                CUDAInitMode.CUDA_BEFORE,
                fsdp_kwargs={"device_id": torch.device("cuda")}
            )
        _check_device_matches(
            nested_wrapped_module,
            torch.device("cuda", torch.cuda.current_device())
        )
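
As the test shows, `device_id` accepts either an integer index or a `torch.device` with an explicit index, while a bare `torch.device("cuda")` only triggers a warning and falls back to the current device. A short sketch of the two explicit forms (toy modules, torchrun launch assumed):

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")  # assumes a `torchrun` launch
torch.cuda.set_device(dist.get_rank())
idx = torch.cuda.current_device()

# Equivalent ways to pin a CPU module's shards to this rank's GPU.
m1 = FSDP(nn.Linear(8, 8), device_id=idx)
m2 = FSDP(nn.Linear(8, 8), device_id=torch.device("cuda", idx))

devices = {p.device for m in (m1, m2) for p in m.parameters()}
assert devices == {torch.device("cuda", idx)}
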