    def test_params_count_and_value(
        self,
        rank0_only: bool,
        offload_to_cpu: bool,
        mixed_precision: bool,
    ):
        mixed_precision = MixedPrecision() if mixed_precision else None
        model = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
        )
        fsdp_model = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
        )
        dev = (torch.device("cpu") if offload_to_cpu else torch.device(
            "cuda", torch.cuda.current_device()))
        params_to_compare = ([p.to(dev) for p in model.module.parameters()]
                             if not rank0_only or self.rank == 0 else list(
                                 p.clone() for p in fsdp_model.parameters()))
        with FSDP.summon_full_params(fsdp_model,
                                     rank0_only=rank0_only,
                                     writeback=not rank0_only):
            for p1, p2 in itertools.zip_longest(fsdp_model.parameters(),
                                                params_to_compare):
                self.assertEqual(p1, p2)

        # CPU offload should restore the param device
        param = next(fsdp_model.parameters())
        self.assertTrue(
            param.device == torch.device("cuda", torch.cuda.current_device()))
Example #2
 def test_named_parameters_buffers(self, prefix: str, recurse: bool):
     """Tests that ``named_parameters()`` and ``named_buffers()`` for a
     top-level FSDP-wrapped model matches their behavior for the equivalent
     non-wrapped model."""
     model = NestedWrappedModule.init(
         self.process_group,
         FSDPInitMode.NO_FSDP,
         CUDAInitMode.CUDA_BEFORE,
         deterministic=True,
     )
     model.register_buffer("buffer", torch.ones(1))
     # `named_parameters()` and `named_buffers()` will contain FSDP
     # prefixes if called on a non-FSDP root module
     fsdp_model = FSDP(
         NestedWrappedModule.init(
             self.process_group,
             FSDPInitMode.NO_FSDP,
             CUDAInitMode.CUDA_BEFORE,
             deterministic=True,
         ),
         self.process_group,
     )
     fsdp_model.register_buffer("buffer", torch.ones(1))
     with FSDP.summon_full_params(fsdp_model):
         for call in ["named_parameters", "named_buffers"]:
             for (n1, p1), (n2, p2) in itertools.zip_longest(
                     getattr(fsdp_model, call)(prefix=prefix,
                                               recurse=recurse),
                     getattr(model, call)(prefix=prefix, recurse=recurse),
             ):
                 self.assertEqual(n1, n2)
                 self.assertEqual(p1, p2)
Example #3
    def test_params_count_and_value(self, rank0_only, offload_to_cpu,
                                    mixed_precision):
        mixed_precision = MixedPrecision() if mixed_precision else None
        fsdp_model = FSDP(
            NestedWrappedModule(
                group=dist.distributed_c10d._get_default_group(),
                wrap_fsdp=True,
                fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
                mixed_precision=mixed_precision,
            ),
            mixed_precision=mixed_precision,
        )
        model = NestedWrappedModule(
            group=dist.distributed_c10d._get_default_group(),
            wrap_fsdp=False,
            fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
        )

        dev = (torch.device("cpu") if offload_to_cpu else torch.device(
            "cuda", torch.cuda.current_device()))

        params_to_compare = ([p.to(dev) for p in model.module.parameters()]
                             if not rank0_only or self.rank == 0 else list(
                                 p.clone() for p in fsdp_model.parameters()))
        with fsdp_model.summon_full_params(fsdp_model,
                                           rank0_only=rank0_only,
                                           writeback=not rank0_only):
            for p1, p2 in itertools.zip_longest(fsdp_model.parameters(),
                                                params_to_compare):
                self.assertEqual(p1, p2)

        # CPU offload should restore the param device
        param = next(fsdp_model.parameters())
        self.assertTrue(
            param.device == torch.device("cuda", torch.cuda.current_device()))
Example #4
 def test_fsdp_cpu_init_stays_on_cpu(self):
     """
     Ensure that a CPU module stays on CPU
     after FSDP init and that a warning is logged.
     """
     torch.cuda.set_device(self.rank)
     regex = "Module is put on CPU"
     context = self.assertWarnsRegex(expected_warning=UserWarning,
                                     expected_regex=regex)
     with context:
         mod = NestedWrappedModule(
             group=self.process_group,
             wrap_fsdp=True,
             wrap_everything=True,
             fsdp_init_mode=FSDPInitMode.CUDA_NEVER,
         )
         fsdp = FSDP(mod)
     devices = {p.device for p in fsdp.parameters()}
     self.assertEqual(1, len(devices))
     self.assertEqual(torch.device("cpu"), devices.pop())
     fsdp = fsdp.cuda()
     # Ensure fwd + backward can be performed after moving to CUDA.
     # A CPU input also tests that the input is correctly moved to the
     # appropriate CUDA device.
     inp = mod.get_input(device=torch.device("cpu"))
     fsdp(inp[0]).sum().backward()
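
For reference, the behavior checked above looks roughly like the following stand-alone sketch (a toy ``nn.Linear`` instead of ``NestedWrappedModule``; assumes the default process group is initialized and a GPU is visible on each rank):

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

cpu_module = nn.Linear(4, 4)      # constructed on CPU
fsdp_model = FSDP(cpu_module)     # warns "Module is put on CPU"; stays on CPU

assert all(p.device == torch.device("cpu") for p in fsdp_model.parameters())

fsdp_model = fsdp_model.cuda()            # explicit move, as in the test
out = fsdp_model(torch.randn(2, 4))       # FSDP moves the CPU input to the GPU
out.sum().backward()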
Example #5
    def test_fsdp_device_id(self, use_index):
        """
        If CPU module is passed into FSDP with device_id
        argument, it is moved to the GPU with that device_id.
        """
        dev_id = (
            torch.cuda.current_device() if use_index
            else torch.device("cuda", torch.cuda.current_device())
        )

        def _check_device_matches(fsdp, dev_id):
            devices = {p.device for p in fsdp.parameters()}
            self.assertEqual(1, len(devices))
            found_dev = devices.pop()
            if use_index and not isinstance(dev_id, torch.device):
                dev_id = torch.device("cuda", dev_id)
            self.assertEqual(found_dev, dev_id)

        mod = NestedWrappedModule(
            group=self.process_group,
            wrap_fsdp=True,
            wrap_everything=True,
            fsdp_init_mode=FSDPInitMode.CUDA_NEVER,
            device_id=dev_id
        )
        fsdp = FSDP(mod, device_id=dev_id)
        # Check FSDP parameters are moved.
        _check_device_matches(fsdp, dev_id)
        # device_id matching module device before FSDP construction
        # should not throw errors.
        mod = NestedWrappedModule(
            group=self.process_group,
            wrap_fsdp=True,
            wrap_everything=True,
            fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
            device_id=dev_id
        )
        fsdp = FSDP(mod, device_id=dev_id)
        _check_device_matches(fsdp, dev_id)
        # Passing in torch.device("cuda") should work.
        regex = "does not have explicit index"
        context = self.assertWarnsRegex(
            expected_warning=UserWarning, expected_regex=regex
        )
        with context:
            mod = NestedWrappedModule(
                group=self.process_group,
                wrap_fsdp=True,
                wrap_everything=True,
                fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
                device_id=torch.device("cuda")
            )
            fsdp = FSDP(mod, device_id=torch.device("cuda"))
        _check_device_matches(fsdp, torch.device("cuda", torch.cuda.current_device()))
Example #6
 def _init_model(
     self,
     nested_model: bool,
     sharding_strategy: ShardingStrategy,
     device: torch.device,
 ):
     fsdp_kwargs = {"sharding_strategy": sharding_strategy}
     if nested_model:
         model = NestedWrappedModule.init(
             self.process_group,
             FSDPInitMode.RECURSIVE,
             CUDAInitMode.CUDA_AFTER,
             fsdp_kwargs,
         )
         fsdp_model: FSDP = FSDP(
             model,
             self.process_group,
             **fsdp_kwargs,
         ).to(device)
     else:
         fsdp_model: FSDP = TransformerWithSharedParams.init(
             self.process_group,
             FSDPInitMode.RECURSIVE,
             CUDAInitMode.CUDA_BEFORE,
             fsdp_kwargs,
         )
     return fsdp_model
Example #7
 def test_fsdp_cpu_init_stays_on_cpu(self):
     """Tests that passing a CPU module to FSDP preserves that the wrapped
     module is on CPU after FSDP initialization, albeit after loging a
     warning, and that FSDP moves CPU input to GPU before the forward."""
     torch.cuda.set_device(self.rank)
     regex = "Module is put on CPU"
     context = self.assertWarnsRegex(
         expected_warning=UserWarning, expected_regex=regex
     )
     with context:
         nested_wrapped_module = NestedWrappedModule.init(
             self.process_group,
             FSDPInitMode.RECURSIVE,
             CUDAInitMode.CUDA_NEVER,
         )
         fsdp_model = FSDP(nested_wrapped_module, self.process_group)
     devices = {p.device for p in fsdp_model.parameters()}
     self.assertEqual(1, len(devices))
     self.assertEqual(torch.device("cpu"), devices.pop())
     fsdp_model = fsdp_model.cuda()
     # Ensure fwd + backward can be performed after moving to CUDA.
     # A CPU input also tests that the input is correctly moved to the
     # appropriate CUDA device.
     inp = fsdp_model.module.get_input(device=torch.device("cpu"))
     fsdp_model(*inp).sum().backward()
Example #8
 def test_params_count_and_value(self):
     fsdp_model = FSDP(
         NestedWrappedModule(
             group=dist.distributed_c10d._get_default_group(),
             wrap_fsdp=True,
             fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
         ))
     model = NestedWrappedModule(
         group=dist.distributed_c10d._get_default_group(),
         wrap_fsdp=False,
         fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
     )
     with fsdp_model.summon_full_params():
         for p1, p2 in itertools.zip_longest(fsdp_model.parameters(),
                                             model.module.parameters()):
             self.assertEqual(p1, p2)
Example #9
 def _init_model(
     self,
     nested_model: bool,
     sharding_strategy: ShardingStrategy,
     device: torch.device,
 ):
     group = dist.distributed_c10d._get_default_group()
     if nested_model:
         model = NestedWrappedModule(
             group,
             wrap_fsdp=True,
             sharding_strategy=sharding_strategy,
         )
         fsdp_model: FSDP = FSDP(
             model,
             group,
             sharding_strategy=sharding_strategy,
         ).to(device)
     else:
         fsdp_model: FSDP = self._get_wrapped_model(
             group,
             cuda_first=False,
             config={"sharding_strategy": sharding_strategy},
         )
     return fsdp_model
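
The ``_init_model`` helpers above only thread a ``sharding_strategy`` through to the FSDP constructor. A hedged stand-alone version of that construction (toy model; assumes an initialized process group):

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

fsdp_kwargs = {"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP}

# SHARD_GRAD_OP keeps the full parameters between forward and backward and
# shards only gradients and optimizer state; FULL_SHARD (the default) also
# frees the full parameters after forward, trading memory for all-gathers.
fsdp_model = FSDP(nn.Linear(8, 8).cuda(), **fsdp_kwargs)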
Example #10
 def test_nested_module_apply(self):
     """Tests that ``apply()`` modifies parameter values in-place on a
     non-FSDP-root nested FSDP-wrapped model."""
     nested_wrapped_module = NestedWrappedModule.init(
         self.process_group,
         FSDPInitMode.RECURSIVE,
         CUDAInitMode.CUDA_AFTER,
     )
     self._check_apply(nested_wrapped_module)
Example #11
 def test_nested_module_apply(self):
     """
     Checks apply() modifies weights appropriately on a nested FSDP instance.
     """
     nested_module = NestedWrappedModule(self.process_group,
                                         wrap_fsdp=True,
                                         wrap_everything=True)
     fsdp_module = FSDP(nested_module, self.process_group).cuda(self.rank)
     self._check_apply(fsdp_module)
Example #12
 def test_module_device_mismatches_device_id(self):
     """Tests that specifying a ``device_id`` argument to FSDP for a GPU
     module that does not match the GPU device ID raises an error."""
     context = (
         self.assertRaisesRegex(
             RuntimeError,
             f"on rank {self.rank}.*cuda:0, but is on cuda:{self.rank}"
         ) if self.rank != 0 else suppress()
     )
     with context:
         NestedWrappedModule.init(
             self.process_group,
             FSDPInitMode.RECURSIVE,
             # Move wrapped modules to CUDA before wrapping with FSDP
             cuda_init_mode=CUDAInitMode.CUDA_BEFORE,
             # Should raise error since rank 1 is given `device_id=0` when
             # the model is on cuda:1
             fsdp_kwargs={"device_id": 0},
         )
Example #13
    def test_raises_rank0_with_writeback(self):
        fsdp_model = FSDP(
            NestedWrappedModule(
                group=dist.distributed_c10d._get_default_group(),
                wrap_fsdp=True,
                fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
            ))

        with self.assertRaisesRegex(ValueError, "is not supported"):
            with fsdp_model.summon_full_params(rank0_only=True,
                                               writeback=True):
                pass
Example #14
 def test_raises_rank0_with_writeback(self):
     """Tests that ``summon_full_params()`` with both ``rank0_only=True``
     and ``writeback=True`` raises an error."""
     nested_wrapped_module = NestedWrappedModule.init(
         self.process_group,
         FSDPInitMode.RECURSIVE,
         CUDAInitMode.CUDA_BEFORE,
     )
     with self.assertRaisesRegex(ValueError, "is not supported"):
         with FSDP.summon_full_params(nested_wrapped_module,
                                      rank0_only=True,
                                      writeback=True):
             pass
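
The supported combination is ``rank0_only=True`` together with ``writeback=False``; a minimal hedged sketch (toy model; initialized process group assumed):

import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

fsdp_model = FSDP(nn.Linear(8, 8).cuda())

# rank0_only=True gathers the full parameters only on rank 0, so writing
# modified values back to the shards cannot be supported; use writeback=False.
with FSDP.summon_full_params(fsdp_model, rank0_only=True, writeback=False):
    if dist.get_rank() == 0:
        total_numel = sum(p.numel() for p in fsdp_model.parameters())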
Example #15
 def test_named_parameters_buffers(self, prefix: str, recurse: bool):
     fsdp_model = FSDP(
         NestedWrappedModule(
             group=dist.distributed_c10d._get_default_group(),
             wrap_fsdp=True,
             fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
         )
     )
     fsdp_model.register_buffer("buffer", torch.ones(1))
     model = NestedWrappedModule(
         group=dist.distributed_c10d._get_default_group(),
         wrap_fsdp=False,
         fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
     )
     model.register_buffer("buffer", torch.ones(1))
     with fsdp_model.summon_full_params(fsdp_model):
         for call in ["named_parameters", "named_buffers"]:
             for (n1, p1), (n2, p2) in itertools.zip_longest(
                 getattr(fsdp_model, call)(prefix=prefix, recurse=recurse),
                 getattr(model, call)(prefix=prefix, recurse=recurse),
             ):
                 self.assertEqual(n1, n2)
                 self.assertEqual(p1, p2)
Example #16
 def test_fsdp_modules(self):
     group = dist.distributed_c10d._get_default_group()
     model = NestedWrappedModule(group, wrap_fsdp=True)
     modules = FSDP.fsdp_modules(model)
     self.assertEqual(modules, [
         model.module.get_submodule("1"),
         model.module.get_submodule("1").get_submodule("0"),
         model.module.get_submodule("2"),
     ])
     modules = FSDP.fsdp_modules(model, root_only=True)
     self.assertEqual(modules, [
         model.module.get_submodule("1"),
         model.module.get_submodule("2"),
     ])
Example #17
 def test_fsdp_modules(self):
     nested_wrapped_module = NestedWrappedModule.init(
         self.process_group,
         FSDPInitMode.RECURSIVE,
         CUDAInitMode.CUDA_BEFORE,
     )
     modules = FSDP.fsdp_modules(nested_wrapped_module)
     self.assertEqual(modules, [
         nested_wrapped_module.module.get_submodule("1"),
         nested_wrapped_module.module.get_submodule("1").get_submodule("0"),
         nested_wrapped_module.module.get_submodule("2"),
     ])
     modules = FSDP.fsdp_modules(nested_wrapped_module, root_only=True)
     self.assertEqual(modules, [
         nested_wrapped_module.module.get_submodule("1"),
         nested_wrapped_module.module.get_submodule("2"),
     ])
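
``FSDP.fsdp_modules()`` walks the module tree and returns the FSDP instances it contains, optionally restricted to root instances. A hedged sketch on a hand-built nested wrapping (assumes an initialized process group; the expected results are noted in comments rather than asserted):

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

inner = FSDP(nn.Linear(8, 8).cuda())
outer = FSDP(nn.Sequential(inner, nn.Linear(8, 8).cuda()))

all_instances = FSDP.fsdp_modules(outer)                   # expected: [outer, inner]
root_instances = FSDP.fsdp_modules(outer, root_only=True)  # expected: [outer]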
Example #18
 def test_module_device_mismatches_device_id(self):
     """
     FSDP raises an error when the module is on a GPU that does
     not match device_id.
     """
     context = (self.assertRaisesRegex(
         RuntimeError,
         f"on rank {self.rank}.*cuda:0, but is on cuda:{self.rank}")
                if self.rank != 0 else suppress())
     with context:
         mod = NestedWrappedModule(
             group=self.process_group,
             wrap_fsdp=True,
             wrap_everything=True,
             # Would move module to current cuda device before
             # wrapping with FSDP
             fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
             # Rank 1 is given device id 0, but model is on cuda:1,
             # should throw errors.
             device_id=0)
Example #19
    def test_cpu_init_with_sync_module_raises(self):
        """
        A CPU module with sync_module_states=True throws an appropriate
        error because the synchronization requires GPU communication.
        """
        mod = NestedWrappedModule(
            group=self.process_group,
            wrap_fsdp=False,
            wrap_everything=True,
            fsdp_init_mode=FSDPInitMode.CUDA_NEVER,
        )
        with self.assertRaisesRegex(
                ValueError,
                "Module has CPU parameters, but sync_module_states=True is specified."
        ):
            FSDP(mod, sync_module_states=True)

        # Specifying device_id with sync_module_states=True works.
        FSDP(mod,
             device_id=torch.cuda.current_device(),
             sync_module_states=True)
Example #20
    def test_cpu_init_with_sync_module_states(self):
        """Tests that passing ``sync_module_states=True`` raises an error for
        a CPU module since the synchronization requires GPU communication,
        while additionally passing ``device_id`` does not raise an error."""
        nested_wrapped_module = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_NEVER,
        )
        with self.assertRaisesRegex(
            ValueError,
            "Module has CPU parameters, but sync_module_states=True is specified."
        ):
            FSDP(nested_wrapped_module, self.process_group, sync_module_states=True)

        # Specifying device_id with sync_module_states=True works.
        FSDP(
            nested_wrapped_module,
            self.process_group,
            device_id=torch.cuda.current_device(),
            sync_module_states=True,
        )
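
In user code, the working configuration that both ``sync_module_states`` tests end with looks roughly as follows (hedged sketch; assumes an initialized NCCL process group and one GPU per rank):

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

cpu_module = nn.Linear(8, 8)  # e.g. a checkpoint loaded on CPU

# sync_module_states=True broadcasts rank 0's parameters and buffers to the
# other ranks before sharding, which needs GPU communication; device_id moves
# the CPU module onto the GPU first so the broadcast can run.
fsdp_model = FSDP(
    cpu_module,
    device_id=torch.cuda.current_device(),
    sync_module_states=True,
)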
Example #21
    def test_fsdp_device_id(self, use_index):
        """
        Tests the FSDP ``device_id`` argument:
          - Wrapping a CPU module should move the module to the GPU matching
          ``device_id``
          - Wrapping a GPU module already on the GPU matching ``device_id``
          should not raise an error
          - Wrapping a GPU module already on GPU and passing a GPU device
          without specifying a device ID (i.e. ``torch.device("cuda")``) warns
        """
        dev_id = (
            torch.cuda.current_device() if use_index
            else torch.device("cuda", torch.cuda.current_device())
        )

        def _check_device_matches(module, device_id):
            """Checks that the ``FlatParameter``s in ``module`` have device
            matching ``device_id``."""
            devices = {
                p.device for p in module.parameters()
                if isinstance(p, FlatParameter)
            }
            assert len(devices) > 0
            self.assertEqual(1, len(devices))
            found_device = devices.pop()
            if use_index and not isinstance(device_id, torch.device):
                device = torch.device("cuda", device_id)
            else:
                device = device_id
            self.assertEqual(found_device, device)

        # Check that FSDP parameters are moved to `device_id` for a CPU module
        nested_wrapped_module = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_NEVER,
            fsdp_kwargs={"device_id": dev_id},
        )
        _check_device_matches(nested_wrapped_module, dev_id)
        # Check that specifying `device_id` for a GPU module already on that
        # device does not raise an error
        nested_wrapped_module = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs={"device_id": dev_id},
        )
        _check_device_matches(nested_wrapped_module, dev_id)
        # Check that passing in `torch.device("cuda")` for a GPU module warns
        regex = "does not have explicit index"
        context = self.assertWarnsRegex(
            expected_warning=UserWarning, expected_regex=regex
        )
        with context:
            nested_wrapped_module = NestedWrappedModule.init(
                self.process_group,
                FSDPInitMode.RECURSIVE,
                CUDAInitMode.CUDA_BEFORE,
                fsdp_kwargs={"device_id": torch.device("cuda")}
            )
        _check_device_matches(
            nested_wrapped_module,
            torch.device("cuda", torch.cuda.current_device())
        )
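
Outside the test harness, the three ``device_id`` cases from the docstring above look roughly like this hedged sketch (toy modules; assumes an initialized process group and one GPU per rank):

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

rank = dist.get_rank()
torch.cuda.set_device(rank)

# 1) CPU module + device_id: FSDP moves the module to cuda:<rank>.
fsdp_model = FSDP(nn.Linear(8, 8), device_id=rank)

# 2) Module already on cuda:<rank> + matching device_id: no error.
fsdp_model = FSDP(nn.Linear(8, 8).cuda(), device_id=torch.device("cuda", rank))

# 3) device_id=torch.device("cuda") without an index: works, but warns that
#    the device "does not have explicit index"; the current device is used.
fsdp_model = FSDP(nn.Linear(8, 8).cuda(), device_id=torch.device("cuda"))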
Example #22
    def test_communication(
        self,
        nested_model: bool,
        use_no_sync: bool,
        sharding_strategy: ShardingStrategy,
    ):
        """
        Tests FSDP's communication cost in terms of calls to collective
        communication primitives (i.e. all-gather and reduce-scatter).

        Arguments:
            nested_model (bool): If ``True``, uses ``NestedWrappedModule``,
                which has nested FSDP instances; if ``False``, uses the default
                model, which does not have nested FSDP instances.
            use_no_sync (bool): If ``True``, uses the ``no_sync()`` context
                manager to accumulate gradients for one iteration before
                synchronizing gradients in the second iteration; if ``False``,
                only checks the communication cost of normal execution.
        """
        # Initialize the model and inputs
        group = dist.distributed_c10d._get_default_group()
        device = torch.device("cuda")
        if nested_model:
            model = NestedWrappedModule(group,
                                        wrap_fsdp=True,
                                        sharding_strategy=sharding_strategy)
            fsdp_model: FSDP = FSDP(
                model, group, sharding_strategy=sharding_strategy).to(device)
        else:
            fsdp_model: FSDP = self._get_wrapped_model(
                group,
                cuda_first=False,
                config={"sharding_strategy": sharding_strategy},
            )
        batch = fsdp_model.module.get_input(device)

        # Count the number of FSDP instances
        num_fsdp = 0
        for m in fsdp_model.modules():  # includes `self`
            if isinstance(m, FSDP) and len(m.params) > 0:
                num_fsdp += 1

        # Count the number of all-gathers and reduce-scatters by mocking
        # `_all_gather_base()` and `_reduce_scatter_base()`
        #
        # with `no_sync()`:
        #   Forward: in `no_sync()` mode, the root does not free its full
        #   parameters, so there are `num_fsdp - 1` all-gathers.
        #   Backward: `num_fsdp - 1` all-gathers (only excluding the root)
        # without `no_sync()`:
        #   Forward: all instances free their full parameters, so there are
        #   `num_fsdp` all-gathers.
        #   Backward: `num_fsdp - 1` all-gathers (only excluding the root)
        expected_num_all_gather_no_sync = (num_fsdp - 1) + (num_fsdp - 1)
        expected_num_all_gather_sync = num_fsdp + (num_fsdp - 1)
        expected_num_reduce_scatter_no_sync = 0
        expected_num_reduce_scatter_sync = num_fsdp

        num_no_sync_iters = 3
        num_sync_iters = 3
        with patch("torch.distributed._all_gather_base") as mock_all_gather, \
                patch("torch.distributed._reduce_scatter_base") as mock_reduce_scatter:

            def reset_mocks():
                mock_all_gather.reset_mock()
                mock_reduce_scatter.reset_mock()

            if use_no_sync:
                # Check the communication cost when using `no_sync()`
                for i in range(num_no_sync_iters):
                    reset_mocks()
                    with fsdp_model.no_sync():
                        output = fsdp_model(*batch)
                        loss = fsdp_model.module.get_loss(batch, output)
                        loss.backward()
                    num_all_gather = mock_all_gather.call_count
                    num_reduce_scatter = mock_reduce_scatter.call_count
                    # in the first iteration, all fsdp instances including root
                    # need to all_gather shards in the forward pass.
                    if i == 0:
                        expected_num_all_gather_no_sync_updated = expected_num_all_gather_no_sync + 1
                        # in the first iteration, all fsdp instances need to all_gather shards
                        # in the forward pass
                        if sharding_strategy == ShardingStrategy.SHARD_GRAD_OP:
                            expected_num_all_gather_no_sync_updated = num_fsdp
                    else:
                        expected_num_all_gather_no_sync_updated = expected_num_all_gather_no_sync
                        # full parameters are not freed after first iteration in the no_sync mode
                        if sharding_strategy == ShardingStrategy.SHARD_GRAD_OP:
                            expected_num_all_gather_no_sync_updated = 0
                    self.assertEqual(
                        num_all_gather,
                        expected_num_all_gather_no_sync_updated,
                        f"Expected {expected_num_all_gather_no_sync_updated} "
                        f"all-gathers but saw {num_all_gather} all-gathers "
                        f"when using `no_sync()`")
                    self.assertEqual(
                        num_reduce_scatter,
                        expected_num_reduce_scatter_no_sync,
                        f"Expected {expected_num_reduce_scatter_no_sync} "
                        f"reduce-scatters but saw {num_reduce_scatter} "
                        "reduce-scatters when using `no_sync()`")

            # Check the normal communication cost (when not using `no_sync()`)
            for i in range(num_sync_iters):
                reset_mocks()
                output = fsdp_model(*batch)
                loss = fsdp_model.module.get_loss(batch, output)
                loss.backward()
                num_all_gather = mock_all_gather.call_count
                num_reduce_scatter = mock_reduce_scatter.call_count
                # previous non-sync iteration does not free full parameters for
                # the root instance.
                if use_no_sync and i == 0:
                    expected_num_all_gather_sync_updated = expected_num_all_gather_sync - 1
                    # previous non-sync iteration does not free full parameters
                    if sharding_strategy == ShardingStrategy.SHARD_GRAD_OP:
                        expected_num_all_gather_sync_updated = 0
                else:
                    expected_num_all_gather_sync_updated = expected_num_all_gather_sync
                    # no need to all_gather shards in the backward pass when in
                    # SHARD_GRAD_OP mode
                    if sharding_strategy == ShardingStrategy.SHARD_GRAD_OP:
                        expected_num_all_gather_sync_updated = num_fsdp
                self.assertEqual(
                    num_all_gather, expected_num_all_gather_sync_updated,
                    f"Expected {expected_num_all_gather_sync_updated} all-gathers "
                    f"but saw {num_all_gather} all-gathers when not using "
                    "`no_sync()`")
                self.assertEqual(
                    num_reduce_scatter, expected_num_reduce_scatter_sync,
                    f"Expected {expected_num_reduce_scatter_sync} reduce-"
                    f"scatters but saw {num_reduce_scatter} reduce-scatters "
                    "when not using `no_sync()`")
Example #23
    def test_communication(
        self,
        nested_model: bool,
        use_no_sync: bool,
    ):
        """
        Tests FSDP's communication cost in terms of calls to collective
        communication primitives (i.e. all-gather and reduce-scatter).

        Arguments:
            nested_model (bool): If ``True``, uses ``NestedWrappedModule``,
                which has nested FSDP instances; if ``False``, uses the default
                model, which does not have nested FSDP instances.
            use_no_sync (bool): If ``True``, uses the ``no_sync()`` context
                manager to accumulate gradients for one iteration before
                synchronizing gradients in the second iteration; if ``False``,
                only checks the communication cost of normal execution.
        """
        # Initialize the model and inputs
        group = dist.distributed_c10d._get_default_group()
        device = torch.device("cuda")
        if nested_model:
            model = NestedWrappedModule(group, wrap_fsdp=True)
            fsdp_model: FSDP = FSDP(model, group).to(device)
        else:
            fsdp_model: FSDP = self._get_wrapped_model(group, cuda_first=False)
        batch = fsdp_model.module.get_input(device)

        # Count the number of FSDP instances
        num_fsdp = 0
        for m in fsdp_model.modules():  # includes `self`
            if isinstance(m, FSDP) and len(m.params) > 0:
                num_fsdp += 1

        # Count the number of all-gathers and reduce-scatters by mocking
        # `_all_gather_base()` and `_reduce_scatter_base()`
        # Both with and without `no_sync()`:
        #   Forward: `num_fsdp` all-gathers
        #   Backward: `num_fsdp` - 1 all-gathers (only excluding the root)
        expected_num_all_gather_no_sync = num_fsdp + (num_fsdp - 1)
        expected_num_all_gather_sync = num_fsdp + (num_fsdp - 1)
        expected_num_reduce_scatter_no_sync = 0
        expected_num_reduce_scatter_sync = num_fsdp

        num_no_sync_iters = 3
        num_sync_iters = 3
        with patch("torch.distributed._all_gather_base") as mock_all_gather, \
                patch("torch.distributed._reduce_scatter_base") as mock_reduce_scatter:

            def reset_mocks():
                mock_all_gather.reset_mock()
                mock_reduce_scatter.reset_mock()

            if use_no_sync:
                # Check the communication cost when using `no_sync()`
                for _ in range(num_no_sync_iters):
                    reset_mocks()
                    with fsdp_model.no_sync():
                        output = fsdp_model(*batch)
                        loss = fsdp_model.module.get_loss(batch, output)
                        loss.backward()
                    num_all_gather = mock_all_gather.call_count
                    num_reduce_scatter = mock_reduce_scatter.call_count
                    assert num_all_gather == expected_num_all_gather_no_sync, \
                        f"Expected {expected_num_all_gather_no_sync} " \
                        f"all-gathers but saw {num_all_gather} all-gathers " \
                        f"when using `no_sync()`"
                    assert num_reduce_scatter == \
                        expected_num_reduce_scatter_no_sync, \
                        f"Expected {expected_num_reduce_scatter_no_sync} " \
                        f"reduce-scatters but saw {num_reduce_scatter} " \
                        "reduce-scatters when using `no_sync()`"

            # Check the normal communication cost (when not using `no_sync()`)
            for _ in range(num_sync_iters):
                reset_mocks()
                output = fsdp_model(*batch)
                loss = fsdp_model.module.get_loss(batch, output)
                loss.backward()
                num_all_gather = mock_all_gather.call_count
                num_reduce_scatter = mock_reduce_scatter.call_count
                assert num_all_gather == expected_num_all_gather_sync, \
                    f"Expected {expected_num_all_gather_sync} all-gathers " \
                    f"but saw {num_all_gather} all-gathers when not using " \
                    "`no_sync()`"
                assert num_reduce_scatter == \
                    expected_num_reduce_scatter_sync, \
                    f"Expected {expected_num_reduce_scatter_sync} reduce-" \
                    f"scatters but saw {num_reduce_scatter} reduce-scatters " \
                    "when not using `no_sync()`"