def test_params_count_and_value(
    self,
    rank0_only: bool,
    offload_to_cpu: bool,
    mixed_precision: bool,
):
    mixed_precision = MixedPrecision() if mixed_precision else None
    model = NestedWrappedModule.init(
        self.process_group,
        FSDPInitMode.NO_FSDP,
        CUDAInitMode.CUDA_BEFORE,
        deterministic=True,
    )
    fsdp_model = NestedWrappedModule.init(
        self.process_group,
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_BEFORE,
        deterministic=True,
    )
    dev = (
        torch.device("cpu")
        if offload_to_cpu
        else torch.device("cuda", torch.cuda.current_device())
    )
    params_to_compare = (
        [p.to(dev) for p in model.module.parameters()]
        if not rank0_only or self.rank == 0
        else list(p.clone() for p in fsdp_model.parameters())
    )
    with FSDP.summon_full_params(
        fsdp_model, rank0_only=rank0_only, writeback=not rank0_only
    ):
        for p1, p2 in itertools.zip_longest(
            fsdp_model.parameters(), params_to_compare
        ):
            self.assertEqual(p1, p2)
    # CPU offload should restore the param device
    param = next(fsdp_model.parameters())
    self.assertTrue(
        param.device == torch.device("cuda", torch.cuda.current_device())
    )
def test_named_parameters_buffers(self, prefix: str, recurse: bool): """Tests that ``named_parameters()`` and ``named_buffers()`` for a top-level FSDP-wrapped model matches their behavior for the equivalent non-wrapped model.""" model = NestedWrappedModule.init( self.process_group, FSDPInitMode.NO_FSDP, CUDAInitMode.CUDA_BEFORE, deterministic=True, ) model.register_buffer("buffer", torch.ones(1)) # `named_parameters()` and `named_buffers` will contain FSDP prefixes # if called on a non-FSDP root module fsdp_model = FSDP( NestedWrappedModule.init( self.process_group, FSDPInitMode.NO_FSDP, CUDAInitMode.CUDA_BEFORE, deterministic=True, ), self.process_group, ) fsdp_model.register_buffer("buffer", torch.ones(1)) with FSDP.summon_full_params(fsdp_model): for call in ["named_parameters", "named_buffers"]: for (n1, p1), (n2, p2) in itertools.zip_longest( getattr(fsdp_model, call)(prefix=prefix, recurse=recurse), getattr(model, call)(prefix=prefix, recurse=recurse), ): self.assertEqual(n1, n2) self.assertEqual(p1, p2)
def test_params_count_and_value(self, rank0_only, offload_to_cpu, mixed_precision):
    mixed_precision = MixedPrecision() if mixed_precision else None
    fsdp_model = FSDP(
        NestedWrappedModule(
            group=dist.distributed_c10d._get_default_group(),
            wrap_fsdp=True,
            fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
            mixed_precision=mixed_precision,
        ),
        mixed_precision=mixed_precision,
    )
    model = NestedWrappedModule(
        group=dist.distributed_c10d._get_default_group(),
        wrap_fsdp=False,
        fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
    )
    dev = (
        torch.device("cpu")
        if offload_to_cpu
        else torch.device("cuda", torch.cuda.current_device())
    )
    params_to_compare = (
        [p.to(dev) for p in model.module.parameters()]
        if not rank0_only or self.rank == 0
        else list(p.clone() for p in fsdp_model.parameters())
    )
    with fsdp_model.summon_full_params(
        fsdp_model, rank0_only=rank0_only, writeback=not rank0_only
    ):
        for p1, p2 in itertools.zip_longest(
            fsdp_model.parameters(), params_to_compare
        ):
            self.assertEqual(p1, p2)
    # CPU offload should restore the param device
    param = next(fsdp_model.parameters())
    self.assertTrue(
        param.device == torch.device("cuda", torch.cuda.current_device())
    )
def test_fsdp_cpu_init_stays_on_cpu(self): """ Ensure that CPU model input stays on CPU after FSDP init and we log a warning. """ torch.cuda.set_device(self.rank) regex = "Module is put on CPU" context = self.assertWarnsRegex(expected_warning=UserWarning, expected_regex=regex) with context: mod = NestedWrappedModule( group=self.process_group, wrap_fsdp=True, wrap_everything=True, fsdp_init_mode=FSDPInitMode.CUDA_NEVER, ) fsdp = FSDP(mod) devices = {p.device for p in fsdp.parameters()} self.assertEqual(1, len(devices)) self.assertEqual(torch.device("cpu"), devices.pop()) fsdp = fsdp.cuda() # Ensure fwd + backward can be performed after moving to CUDA. # CPU input also tests that input is correctly moved to appropriate # CUDA device. inp = mod.get_input(device=torch.device("cpu")) fsdp(inp[0]).sum().backward()
def test_fsdp_device_id(self, use_index): """ If CPU module is passed into FSDP with device_id argument, it is moved to the GPU with that device_id. """ dev_id = ( torch.cuda.current_device() if use_index else torch.device("cuda", torch.cuda.current_device()) ) def _check_device_matches(fsdp, dev_id): devices = {p.device for p in fsdp.parameters()} self.assertEqual(1, len(devices)) found_dev = devices.pop() if use_index and not isinstance(dev_id, torch.device): dev_id = torch.device("cuda", dev_id) self.assertEqual(found_dev, dev_id) mod = NestedWrappedModule( group=self.process_group, wrap_fsdp=True, wrap_everything=True, fsdp_init_mode=FSDPInitMode.CUDA_NEVER, device_id=dev_id ) fsdp = FSDP(mod, device_id=dev_id) # Check FSDP parameters are moved. _check_device_matches(fsdp, dev_id) # device_id matching module device before FSDP construction # should not throw errors. mod = NestedWrappedModule( group=self.process_group, wrap_fsdp=True, wrap_everything=True, fsdp_init_mode=FSDPInitMode.CUDA_BEFORE, device_id=dev_id ) fsdp = FSDP(mod, device_id=dev_id) _check_device_matches(fsdp, dev_id) # Passing in torch.device("cuda") should work. regex = "does not have explicit index" context = self.assertWarnsRegex( expected_warning=UserWarning, expected_regex=regex ) with context: mod = NestedWrappedModule( group=self.process_group, wrap_fsdp=True, wrap_everything=True, fsdp_init_mode=FSDPInitMode.CUDA_BEFORE, device_id=torch.device("cuda") ) fsdp = FSDP(mod, device_id=torch.device("cuda")) _check_device_matches(fsdp, torch.device("cuda", torch.cuda.current_device()))
def _init_model(
    self,
    nested_model: bool,
    sharding_strategy: ShardingStrategy,
    device: torch.device,
):
    fsdp_kwargs = {"sharding_strategy": sharding_strategy}
    if nested_model:
        model = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_AFTER,
            fsdp_kwargs,
        )
        fsdp_model: FSDP = FSDP(
            model,
            self.process_group,
            **fsdp_kwargs,
        ).to(device)
    else:
        fsdp_model: FSDP = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs,
        )
    return fsdp_model
def test_fsdp_cpu_init_stays_on_cpu(self): """Tests that passing a CPU module to FSDP preserves that the wrapped module is on CPU after FSDP initialization, albeit after loging a warning, and that FSDP moves CPU input to GPU before the forward.""" torch.cuda.set_device(self.rank) regex = "Module is put on CPU" context = self.assertWarnsRegex( expected_warning=UserWarning, expected_regex=regex ) with context: nested_wrapped_module = NestedWrappedModule.init( self.process_group, FSDPInitMode.RECURSIVE, CUDAInitMode.CUDA_NEVER, ) fsdp_model = FSDP(nested_wrapped_module, self.process_group) devices = {p.device for p in fsdp_model.parameters()} self.assertEqual(1, len(devices)) self.assertEqual(torch.device("cpu"), devices.pop()) fsdp_model = fsdp_model.cuda() # Ensure fwd + backward can be performed after moving to CUDA. # CPU input also tests that input is correctly moved to appropriate # CUDA device. inp = fsdp_model.module.get_input(device=torch.device("cpu")) fsdp_model(*inp).sum().backward()
def test_params_count_and_value(self):
    fsdp_model = FSDP(
        NestedWrappedModule(
            group=dist.distributed_c10d._get_default_group(),
            wrap_fsdp=True,
            fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
        )
    )
    model = NestedWrappedModule(
        group=dist.distributed_c10d._get_default_group(),
        wrap_fsdp=False,
        fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
    )
    with fsdp_model.summon_full_params():
        for p1, p2 in itertools.zip_longest(
            fsdp_model.parameters(), model.module.parameters()
        ):
            self.assertEqual(p1, p2)
def _init_model(
    self,
    nested_model: bool,
    sharding_strategy: ShardingStrategy,
    device: torch.device,
):
    group = dist.distributed_c10d._get_default_group()
    if nested_model:
        model = NestedWrappedModule(
            group,
            wrap_fsdp=True,
            sharding_strategy=sharding_strategy,
        )
        fsdp_model: FSDP = FSDP(
            model,
            group,
            sharding_strategy=sharding_strategy,
        ).to(device)
    else:
        fsdp_model: FSDP = self._get_wrapped_model(
            group,
            cuda_first=False,
            config={"sharding_strategy": sharding_strategy},
        )
    return fsdp_model
def test_nested_module_apply(self): """Tests that ``apply()`` modifies parameter values in-place on a non-FSDP-root nested FSDP-wrapped model.""" nested_wrapped_module = NestedWrappedModule.init( self.process_group, FSDPInitMode.RECURSIVE, CUDAInitMode.CUDA_AFTER, ) self._check_apply(nested_wrapped_module)
def test_nested_module_apply(self): """ Checks apply() modifies weights appropriately on a nested FSDP instance. """ nested_module = NestedWrappedModule(self.process_group, wrap_fsdp=True, wrap_everything=True) fsdp_module = FSDP(nested_module, self.process_group).cuda(self.rank) self._check_apply(fsdp_module)
def test_module_device_mismatches_device_id(self): """Tests that specifying a ``device_id`` argument to FSDP for a GPU module that does not match the GPU device ID raises an error.""" context = ( self.assertRaisesRegex( RuntimeError, f"on rank {self.rank}.*cuda:0, but is on cuda:{self.rank}" ) if self.rank != 0 else suppress() ) with context: NestedWrappedModule.init( self.process_group, FSDPInitMode.RECURSIVE, # Move wrapped modules to CUDA before wrapping with FSDP cuda_init_mode=CUDAInitMode.CUDA_BEFORE, # Should raise error since rank 1 is given `device_id=0` when # the model is on cuda:1 fsdp_kwargs={"device_id": 0}, )
def test_raises_rank0_with_writeback(self):
    fsdp_model = FSDP(
        NestedWrappedModule(
            group=dist.distributed_c10d._get_default_group(),
            wrap_fsdp=True,
            fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
        )
    )
    with self.assertRaisesRegex(ValueError, "is not supported"):
        with fsdp_model.summon_full_params(rank0_only=True, writeback=True):
            pass
def test_raises_rank0_with_writeback(self): """Tests that ``summon_full_params()`` with both ``rank0_only=True`` and ``writeback=True`` raises an error.""" nested_wrapped_module = NestedWrappedModule.init( self.process_group, FSDPInitMode.RECURSIVE, CUDAInitMode.CUDA_BEFORE, ) with self.assertRaisesRegex(ValueError, "is not supported"): with FSDP.summon_full_params(nested_wrapped_module, rank0_only=True, writeback=True): pass
def test_named_parameters_buffers(self, prefix: str, recurse: bool):
    fsdp_model = FSDP(
        NestedWrappedModule(
            group=dist.distributed_c10d._get_default_group(),
            wrap_fsdp=True,
            fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
        )
    )
    fsdp_model.register_buffer("buffer", torch.ones(1))
    model = NestedWrappedModule(
        group=dist.distributed_c10d._get_default_group(),
        wrap_fsdp=False,
        fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
    )
    model.register_buffer("buffer", torch.ones(1))
    with fsdp_model.summon_full_params(fsdp_model):
        for call in ["named_parameters", "named_buffers"]:
            for (n1, p1), (n2, p2) in itertools.zip_longest(
                getattr(fsdp_model, call)(prefix=prefix, recurse=recurse),
                getattr(model, call)(prefix=prefix, recurse=recurse),
            ):
                self.assertEqual(n1, n2)
                self.assertEqual(p1, p2)
def test_fsdp_modules(self):
    group = dist.distributed_c10d._get_default_group()
    model = NestedWrappedModule(group, wrap_fsdp=True)
    modules = FSDP.fsdp_modules(model)
    self.assertEqual(
        modules,
        [
            model.module.get_submodule("1"),
            model.module.get_submodule("1").get_submodule("0"),
            model.module.get_submodule("2"),
        ],
    )
    modules = FSDP.fsdp_modules(model, root_only=True)
    self.assertEqual(
        modules,
        [
            model.module.get_submodule("1"),
            model.module.get_submodule("2"),
        ],
    )
def test_fsdp_modules(self):
    nested_wrapped_module = NestedWrappedModule.init(
        self.process_group,
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_BEFORE,
    )
    modules = FSDP.fsdp_modules(nested_wrapped_module)
    self.assertEqual(
        modules,
        [
            nested_wrapped_module.module.get_submodule("1"),
            nested_wrapped_module.module.get_submodule("1").get_submodule("0"),
            nested_wrapped_module.module.get_submodule("2"),
        ],
    )
    modules = FSDP.fsdp_modules(nested_wrapped_module, root_only=True)
    self.assertEqual(
        modules,
        [
            nested_wrapped_module.module.get_submodule("1"),
            nested_wrapped_module.module.get_submodule("2"),
        ],
    )
def test_module_device_mismatches_device_id(self): """ FSDP raises errors when module is on a GPU that does not match device_id. """ context = (self.assertRaisesRegex( RuntimeError, f"on rank {self.rank}.*cuda:0, but is on cuda:{self.rank}") if self.rank != 0 else suppress()) with context: mod = NestedWrappedModule( group=self.process_group, wrap_fsdp=True, wrap_everything=True, # Would move module to current cuda device before # wrapping with FSDP fsdp_init_mode=FSDPInitMode.CUDA_BEFORE, # Rank 1 is given device id 0, but model is on cuda:1, # should throw errors. device_id=0)
def test_cpu_init_with_sync_module_raises(self): """ CPU module with sync_module_states=True throws appropriate error because it requires GPU comm. """ mod = NestedWrappedModule( group=self.process_group, wrap_fsdp=False, wrap_everything=True, fsdp_init_mode=FSDPInitMode.CUDA_NEVER, ) with self.assertRaisesRegex( ValueError, "Module has CPU parameters, but sync_module_states=True is specified." ): FSDP(mod, sync_module_states=True) # Specifying device_id with sync_module_states=True works. FSDP(mod, device_id=torch.cuda.current_device(), sync_module_states=True)
def test_cpu_init_with_sync_module_states(self): """Tests that passing ``sync_module_states=True`` raises an error for a CPU module since the synchronization requires GPU communication, while additionally passing ``device_id`` does not raise an error.""" nested_wrapped_module = NestedWrappedModule.init( self.process_group, FSDPInitMode.RECURSIVE, CUDAInitMode.CUDA_NEVER, ) with self.assertRaisesRegex( ValueError, "Module has CPU parameters, but sync_module_states=True is specified." ): FSDP(nested_wrapped_module, self.process_group, sync_module_states=True) # Specifying device_id with sync_module_states=True works. FSDP( nested_wrapped_module, self.process_group, device_id=torch.cuda.current_device(), sync_module_states=True, )
def test_fsdp_device_id(self, use_index): """ Tests the FSDP ``device_id`` argument: - Wrapping a CPU module should move the module to the GPU matching ``device_id`` - Wrapping a GPU module already on the GPU matching ``device_id`` should not raise an error - Wrapping a GPU module already on GPU and passing a GPU device without specifying a device ID (i.e. ``torch.device("cuda")``) warns """ dev_id = ( torch.cuda.current_device() if use_index else torch.device("cuda", torch.cuda.current_device()) ) def _check_device_matches(module, device_id): """Checks that the ``FlatParameter``s in ``module`` have device matching ``device_id``.""" devices = { p.device for p in module.parameters() if isinstance(p, FlatParameter) } assert len(devices) > 0 self.assertEqual(1, len(devices)) found_device = devices.pop() if use_index and not isinstance(device_id, torch.device): device = torch.device("cuda", device_id) else: device = device_id self.assertEqual(found_device, device) # Check that FSDP parameters are moved to `device_id` for a CPU module nested_wrapped_module = NestedWrappedModule.init( self.process_group, FSDPInitMode.RECURSIVE, CUDAInitMode.CUDA_NEVER, fsdp_kwargs={"device_id": dev_id}, ) _check_device_matches(nested_wrapped_module, dev_id) # Check that specifying `device_id` for a GPU module already on that # device does not raise an error nested_wrapped_module = NestedWrappedModule.init( self.process_group, FSDPInitMode.RECURSIVE, CUDAInitMode.CUDA_BEFORE, fsdp_kwargs={"device_id": dev_id}, ) _check_device_matches(nested_wrapped_module, dev_id) # Check that passing in `torch.device("cuda")` for a GPU module warns regex = "does not have explicit index" context = self.assertWarnsRegex( expected_warning=UserWarning, expected_regex=regex ) with context: nested_wrapped_module = NestedWrappedModule.init( self.process_group, FSDPInitMode.RECURSIVE, CUDAInitMode.CUDA_BEFORE, fsdp_kwargs={"device_id": torch.device("cuda")} ) _check_device_matches( nested_wrapped_module, torch.device("cuda", torch.cuda.current_device()) )
def test_communication(
    self,
    nested_model: bool,
    use_no_sync: bool,
    sharding_strategy: ShardingStrategy,
):
    """
    Tests FSDP's communication cost in terms of calls to collective
    communication primitives (i.e. all-gather and reduce-scatter).

    Arguments:
        nested_model (bool): If ``True``, uses ``NestedWrappedModule``,
            which has nested FSDP instances; if ``False``, uses the default
            model, which does not have nested FSDP instances.
        use_no_sync (bool): If ``True``, uses the ``no_sync()`` context
            manager to accumulate gradients for one iteration before
            synchronizing gradients in the second iteration; if ``False``,
            only checks the communication cost of normal execution.
    """
    # Initialize the model and inputs
    group = dist.distributed_c10d._get_default_group()
    device = torch.device("cuda")
    if nested_model:
        model = NestedWrappedModule(
            group, wrap_fsdp=True, sharding_strategy=sharding_strategy
        )
        fsdp_model: FSDP = FSDP(
            model, group, sharding_strategy=sharding_strategy
        ).to(device)
    else:
        fsdp_model: FSDP = self._get_wrapped_model(
            group,
            cuda_first=False,
            config={"sharding_strategy": sharding_strategy},
        )
    batch = fsdp_model.module.get_input(device)

    # Count the number of FSDP instances
    num_fsdp = 0
    for m in fsdp_model.modules():  # includes `self`
        if isinstance(m, FSDP) and len(m.params) > 0:
            num_fsdp += 1

    # Count the number of all-gathers and reduce-scatters by mocking
    # `_all_gather_base()` and `_reduce_scatter_base()`
    #
    # With `no_sync()`:
    #   Forward: in `no_sync()` mode, the root does not free its full
    #     parameters, thus there will be `num_fsdp` - 1 all-gathers
    #   Backward: `num_fsdp` - 1 all-gathers (only excluding the root)
    # Without `no_sync()`:
    #   Forward: all instances free full parameters, thus there will be
    #     `num_fsdp` all-gathers
    #   Backward: `num_fsdp` - 1 all-gathers (only excluding the root)
    expected_num_all_gather_no_sync = (num_fsdp - 1) + (num_fsdp - 1)
    expected_num_all_gather_sync = num_fsdp + (num_fsdp - 1)
    expected_num_reduce_scatter_no_sync = 0
    expected_num_reduce_scatter_sync = num_fsdp

    num_no_sync_iters = 3
    num_sync_iters = 3
    with patch("torch.distributed._all_gather_base") as mock_all_gather, \
            patch("torch.distributed._reduce_scatter_base") as mock_reduce_scatter:
        def reset_mocks():
            mock_all_gather.reset_mock()
            mock_reduce_scatter.reset_mock()

        if use_no_sync:
            # Check the communication cost when using `no_sync()`
            for i in range(num_no_sync_iters):
                reset_mocks()
                with fsdp_model.no_sync():
                    output = fsdp_model(*batch)
                    loss = fsdp_model.module.get_loss(batch, output)
                    loss.backward()
                num_all_gather = mock_all_gather.call_count
                num_reduce_scatter = mock_reduce_scatter.call_count
                # In the first iteration, all FSDP instances, including the
                # root, need to all-gather shards in the forward pass
                if i == 0:
                    expected_num_all_gather_no_sync_updated = expected_num_all_gather_no_sync + 1
                    if sharding_strategy == ShardingStrategy.SHARD_GRAD_OP:
                        # In the first iteration, all FSDP instances need to
                        # all-gather shards in the forward pass
                        expected_num_all_gather_no_sync_updated = num_fsdp
                else:
                    expected_num_all_gather_no_sync_updated = expected_num_all_gather_no_sync
                    if sharding_strategy == ShardingStrategy.SHARD_GRAD_OP:
                        # Full parameters are not freed after the first
                        # iteration in `no_sync()` mode
                        expected_num_all_gather_no_sync_updated = 0
                self.assertEqual(
                    num_all_gather,
                    expected_num_all_gather_no_sync_updated,
                    f"Expected {expected_num_all_gather_no_sync_updated} "
                    f"all-gathers but saw {num_all_gather} all-gathers "
                    f"when using `no_sync()`",
                )
                self.assertEqual(
                    num_reduce_scatter,
                    expected_num_reduce_scatter_no_sync,
                    f"Expected {expected_num_reduce_scatter_no_sync} "
                    f"reduce-scatters but saw {num_reduce_scatter} "
                    "reduce-scatters when using `no_sync()`",
                )

        # Check the normal communication cost (when not using `no_sync()`)
        for i in range(num_sync_iters):
            reset_mocks()
            output = fsdp_model(*batch)
            loss = fsdp_model.module.get_loss(batch, output)
            loss.backward()
            num_all_gather = mock_all_gather.call_count
            num_reduce_scatter = mock_reduce_scatter.call_count
            # The previous `no_sync()` iteration does not free full
            # parameters for the root instance
            if use_no_sync and i == 0:
                expected_num_all_gather_sync_updated = expected_num_all_gather_sync - 1
                if sharding_strategy == ShardingStrategy.SHARD_GRAD_OP:
                    # The previous `no_sync()` iteration does not free any
                    # full parameters
                    expected_num_all_gather_sync_updated = 0
            else:
                expected_num_all_gather_sync_updated = expected_num_all_gather_sync
                if sharding_strategy == ShardingStrategy.SHARD_GRAD_OP:
                    # No need to all-gather shards in the backward pass when
                    # in SHARD_GRAD_OP mode
                    expected_num_all_gather_sync_updated = num_fsdp
            self.assertEqual(
                num_all_gather,
                expected_num_all_gather_sync_updated,
                f"Expected {expected_num_all_gather_sync_updated} all-gathers "
                f"but saw {num_all_gather} all-gathers when not using "
                "`no_sync()`",
            )
            self.assertEqual(
                num_reduce_scatter,
                expected_num_reduce_scatter_sync,
                f"Expected {expected_num_reduce_scatter_sync} reduce-"
                f"scatters but saw {num_reduce_scatter} reduce-scatters "
                "when not using `no_sync()`",
            )
def test_communication(
    self,
    nested_model: bool,
    use_no_sync: bool,
):
    """
    Tests FSDP's communication cost in terms of calls to collective
    communication primitives (i.e. all-gather and reduce-scatter).

    Arguments:
        nested_model (bool): If ``True``, uses ``NestedWrappedModule``,
            which has nested FSDP instances; if ``False``, uses the default
            model, which does not have nested FSDP instances.
        use_no_sync (bool): If ``True``, uses the ``no_sync()`` context
            manager to accumulate gradients for one iteration before
            synchronizing gradients in the second iteration; if ``False``,
            only checks the communication cost of normal execution.
    """
    # Initialize the model and inputs
    group = dist.distributed_c10d._get_default_group()
    device = torch.device("cuda")
    if nested_model:
        model = NestedWrappedModule(group, wrap_fsdp=True)
        fsdp_model: FSDP = FSDP(model, group).to(device)
    else:
        fsdp_model: FSDP = self._get_wrapped_model(group, cuda_first=False)
    batch = fsdp_model.module.get_input(device)

    # Count the number of FSDP instances
    num_fsdp = 0
    for m in fsdp_model.modules():  # includes `self`
        if isinstance(m, FSDP) and len(m.params) > 0:
            num_fsdp += 1

    # Count the number of all-gathers and reduce-scatters by mocking
    # `_all_gather_base()` and `_reduce_scatter_base()`
    # Both with and without `no_sync()`:
    #   Forward: `num_fsdp` all-gathers
    #   Backward: `num_fsdp` - 1 all-gathers (only excluding the root)
    expected_num_all_gather_no_sync = num_fsdp + (num_fsdp - 1)
    expected_num_all_gather_sync = num_fsdp + (num_fsdp - 1)
    expected_num_reduce_scatter_no_sync = 0
    expected_num_reduce_scatter_sync = num_fsdp

    num_no_sync_iters = 3
    num_sync_iters = 3
    with patch("torch.distributed._all_gather_base") as mock_all_gather, \
            patch("torch.distributed._reduce_scatter_base") as mock_reduce_scatter:
        def reset_mocks():
            mock_all_gather.reset_mock()
            mock_reduce_scatter.reset_mock()

        if use_no_sync:
            # Check the communication cost when using `no_sync()`
            for _ in range(num_no_sync_iters):
                reset_mocks()
                with fsdp_model.no_sync():
                    output = fsdp_model(*batch)
                    loss = fsdp_model.module.get_loss(batch, output)
                    loss.backward()
                num_all_gather = mock_all_gather.call_count
                num_reduce_scatter = mock_reduce_scatter.call_count
                assert num_all_gather == expected_num_all_gather_no_sync, \
                    f"Expected {expected_num_all_gather_no_sync} " \
                    f"all-gathers but saw {num_all_gather} all-gathers " \
                    f"when using `no_sync()`"
                assert num_reduce_scatter == \
                    expected_num_reduce_scatter_no_sync, \
                    f"Expected {expected_num_reduce_scatter_no_sync} " \
                    f"reduce-scatters but saw {num_reduce_scatter} " \
                    "reduce-scatters when using `no_sync()`"

        # Check the normal communication cost (when not using `no_sync()`)
        for _ in range(num_sync_iters):
            reset_mocks()
            output = fsdp_model(*batch)
            loss = fsdp_model.module.get_loss(batch, output)
            loss.backward()
            num_all_gather = mock_all_gather.call_count
            num_reduce_scatter = mock_reduce_scatter.call_count
            assert num_all_gather == expected_num_all_gather_sync, \
                f"Expected {expected_num_all_gather_sync} all-gathers " \
                f"but saw {num_all_gather} all-gathers when not using " \
                "`no_sync()`"
            assert num_reduce_scatter == \
                expected_num_reduce_scatter_sync, \
                f"Expected {expected_num_reduce_scatter_sync} reduce-" \
                f"scatters but saw {num_reduce_scatter} reduce-scatters " \
                "when not using `no_sync()`"