def test_ignored_modules_transformer(self):
    """Tests that ignored modules' parameters are not flattened for a
    transformer model with shared parameters."""
    # Initialize an FSDP-wrapped transformer model that has FSDP ignore
    # the `nn.Transformer` module's parameters
    model: nn.Module = TransformerWithSharedParams.init(
        self.process_group,
        FSDPInitMode.NO_FSDP,
        CUDAInitMode.CUDA_BEFORE,
        deterministic=True,
    )
    wrapped_model = FSDP(
        model,
        self.process_group,
        ignored_modules=[model.transformer],
    )
    # Check that the wrapped model's flattened parameter does not include
    # the ignored transformer module's parameters
    nonwrapped_model: nn.Module = TransformerWithSharedParams.init(
        self.process_group,
        FSDPInitMode.NO_FSDP,
        CUDAInitMode.CUDA_BEFORE,
        deterministic=True,
    )
    total_numel = sum(p.numel() for p in nonwrapped_model.parameters())
    ignored_numel = sum(
        p.numel() for p in nonwrapped_model.transformer.parameters()
    )
    nonignored_numel = total_numel - ignored_numel
    with FSDP.summon_full_params(wrapped_model):
        flat_param_numel = wrapped_model.params[0].numel()
        self.assertEqual(flat_param_numel, nonignored_numel)
    # Check that we can run a few iterations
    optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
    self._train_model(wrapped_model, optim, 3)
def test_param_change_after_init(self, mixed_precision):
    """
    Tests that changes made in-place to FSDP model parameter values after
    FSDP initialization persist.
    """
    # Establish reference behavior
    fsdp_kwargs = {}
    if mixed_precision:
        fsdp_kwargs["mixed_precision"] = MixedPrecision()
    fsdp_model = TransformerWithSharedParams.init(
        self.process_group,
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_AFTER,
        fsdp_kwargs,
        deterministic=True,
    )
    input = fsdp_model.module.get_input(torch.device("cuda"))
    ref_output = fsdp_model(*input)
    # Initialize the same model but change its first parameter value
    # in-place after FSDP initialization
    new_fsdp_model = TransformerWithSharedParams.init(
        self.process_group,
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_AFTER,
        fsdp_kwargs,
        deterministic=True,
    )
    first_param = next(new_fsdp_model.parameters())
    nn.init.normal_(first_param.data)
    new_output = new_fsdp_model(*input)
    self.assertNotEqual(
        ref_output,
        new_output,
        msg="new_output did not reflect change to param after init",
    )
def _init_model(
    self,
    nested_model: bool,
    sharding_strategy: ShardingStrategy,
    device: torch.device,
):
    fsdp_kwargs = {"sharding_strategy": sharding_strategy}
    if nested_model:
        model = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_AFTER,
            fsdp_kwargs,
        )
        fsdp_model: FSDP = FSDP(
            model,
            self.process_group,
            **fsdp_kwargs,
        ).to(device)
    else:
        fsdp_model: FSDP = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs,
        )
    return fsdp_model
def _init_transformer_model(
    self,
    wrap: bool,
    device: torch.device = torch.device("cuda"),
    group=None,
    optim_class: Type[torch.optim.Optimizer] = torch.optim.Adam,
    use_multiple_param_groups: bool = False,
    use_diff_optim_inputs: bool = False,
):
    if use_multiple_param_groups or use_diff_optim_inputs:
        # Keep these as arguments for parity with `_init_nested_model()`;
        # these settings are not implemented since the transformer is
        # wrapped with FSDP at the top level, which means that there is
        # only a single flattened parameter, making these booleans vacuous
        raise NotImplementedError()
    if group is None:
        group = dist.distributed_c10d._get_default_group()
    model = TransformerWithSharedParams.init(
        group,
        FSDPInitMode.RECURSIVE if wrap else FSDPInitMode.NO_FSDP,
        CUDAInitMode.CUDA_BEFORE,
        deterministic=True,
    )
    optim = optim_class(model.parameters(), lr=0.01)
    return model, optim, None
def test_transformer_no_grad(self, mixed_precision):
    """Tests that for an FSDP-wrapped transformer model with shared
    parameters, after training for one iteration, running a forward pass in
    ``eval()`` mode gives the same output as running a forward pass in
    ``torch.no_grad()``."""
    fsdp_kwargs = {}
    if mixed_precision:
        fsdp_kwargs["mixed_precision"] = MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
        )
    else:
        fsdp_kwargs["mixed_precision"] = None
    fsdp_model = TransformerWithSharedParams.init(
        self.process_group,
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_AFTER,
        fsdp_kwargs,
    )
    self._train_for_several_steps(
        fsdp_model,
        num_steps=1,
        autocast=False,
        mixed_precision=fsdp_kwargs["mixed_precision"],
    )
    input = fsdp_model.module.get_input(torch.device("cuda"))
    # Run a forward pass in eval mode
    fsdp_model.eval()
    ref_output = fsdp_model(*input)
    # Run a forward pass in `no_grad()` and compare
    with torch.no_grad():
        no_grad_output = fsdp_model(*input)
    self.assertEqual(ref_output, no_grad_output)
def _test_mixed_precision_embedding_table(self, mp_config):
    # Basic test to ensure that integer inputs are not cast, which would
    # break modules such as embedding tables
    param_dtype = mp_config.param_dtype or torch.float32
    orig_reduce_scatter = dist._reduce_scatter_base
    test_reduce_scatter = partial(
        self._reduce_scatter_base_validate_mp,
        orig_reduce_scatter,
        mp_config,
    )
    with patch_reduce_scatter(test_reduce_scatter, param_dtype):
        # TODO: `test_mp_embedding_reduce()` fails if we do not wrap the
        # entire `TransformerWithSharedParams` with a single top-level FSDP
        model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            {"mixed_precision": mp_config},
        )
        fsdp_model = FSDP(model, mixed_precision=mp_config)
        optim = torch.optim.SGD(fsdp_model.parameters(), lr=0.1)
        for _ in range(6):
            inp = fsdp_model.module.get_input(torch.device("cuda"))
            # This would fail if we cast integer module inputs such as
            # those for embedding tables
            output = fsdp_model(*inp)
            loss = fsdp_model.module.get_loss(inp, output).cuda()
            self.assertEqual(loss.dtype, param_dtype)
            fsdp_model.module.run_backward(loss)
            optim.step()
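# A small illustrative sketch (not part of the test suite) of why the check
# above matters: embedding lookups require integer indices, so the mixed
# precision cast must leave integer module inputs alone. The module, shapes,
# and helper name here are hypothetical; the sketch relies on the `torch`/`nn`
# imports already present in this file.
def _sketch_embedding_requires_int_indices():
    emb = nn.Embedding(num_embeddings=10, embedding_dim=4)
    indices = torch.randint(0, 10, (3,))  # `int64` indices
    out = emb(indices)  # OK: integer indices are looked up directly
    assert out.shape == (3, 4)
    # By contrast, `emb(indices.to(torch.float16))` would fail, since embedding
    # lookups only accept integer index dtypes; hence low-precision casts must
    # skip integer module inputs.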
def test_pre_backward_hook_registration(self, cuda_first: bool):
    """Tests that FSDP pre-backward hooks are registered on forward pass
    outputs."""
    fsdp_model = TransformerWithSharedParams.init(
        self.process_group,
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_BEFORE if cuda_first else CUDAInitMode.CUDA_AFTER,
    )
    self._test_pre_backward_hook_registration(fsdp_model)
def test_transformer_module_apply(self):
    """Tests that ``apply()`` modifies parameter values in-place on an
    FSDP-wrapped transformer model with shared parameters."""
    transformer = TransformerWithSharedParams.init(
        self.process_group,
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_AFTER,
    )
    self._check_apply(transformer)
def test_apply_in_summon_raises_error(self):
    """Tests that calling ``apply()`` on an FSDP instance inside the
    ``summon_full_params()`` context raises an error."""
    transformer = TransformerWithSharedParams.init(
        self.process_group,
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_AFTER,
    )
    with transformer.summon_full_params(transformer):
        with self.assertRaisesRegex(ValueError, "expected to be in states"):
            transformer.apply(self._init_linear_weights)
def test_pre_backward_hook_registration_after_state_dict(self):
    """Tests that FSDP pre-backward hooks are registered on forward pass
    outputs after saving and loading the model from a checkpoint."""
    fsdp_model = TransformerWithSharedParams.init(
        self.process_group,
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_AFTER,
    )
    self._train_for_several_steps(fsdp_model, num_steps=2, autocast=False)
    state_dict = fsdp_model.state_dict()
    fsdp_model.load_state_dict(state_dict)
    self._test_pre_backward_hook_registration(fsdp_model)
def test_register_functions_called(self, cuda_first: bool, mixed_precision: bool):
    """Tests that ``_register_{pre|post}_backward_hooks()`` are called
    during the FSDP forward."""
    fsdp_kwargs = {}
    if mixed_precision:
        fsdp_kwargs["mixed_precision"] = MixedPrecision()
    fsdp_model = TransformerWithSharedParams.init(
        self.process_group,
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_BEFORE if cuda_first else CUDAInitMode.CUDA_AFTER,
        fsdp_kwargs,
    )
    input = fsdp_model.module.get_input(torch.device("cuda"))
    fsdp_model._register_pre_backward_hooks = mock.MagicMock(return_value=None)
    fsdp_model._register_post_backward_hooks = mock.MagicMock(return_value=None)
    self.assertFalse(fsdp_model._register_post_backward_hooks.called)
    self.assertFalse(fsdp_model._register_pre_backward_hooks.called)
    fsdp_model(*input)
    self.assertTrue(fsdp_model._register_post_backward_hooks.called)
    self.assertTrue(fsdp_model._register_pre_backward_hooks.called)
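# A tiny illustrative sketch (not part of the test suite) of the mocking
# pattern the test above relies on: replacing a method with a `MagicMock`
# records whether the FSDP forward invoked it via the mock's `called` flag.
# The helper name is hypothetical; it relies on this file's existing `mock`
# import (`unittest.mock`).
def _sketch_magicmock_called_flag():
    probe = mock.MagicMock(return_value=None)
    assert not probe.called  # not yet invoked
    probe("anything")
    assert probe.called  # set as soon as the mock is called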
def test_device_id_auto_wrap(self):
    """Tests that ``auto_wrap_policy`` propagates ``device_id`` to all
    nested FSDP instances."""
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer},
    )
    fsdp_kwargs = {
        "auto_wrap_policy": auto_wrap_policy,
        "device_id": torch.cuda.current_device(),
    }
    fsdp_model = TransformerWithSharedParams.init(
        self.process_group,
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_BEFORE,
        fsdp_kwargs,
    )
    for fsdp_module in FSDP.fsdp_modules(fsdp_model):
        self.assertEqual(
            fsdp_module.device_id,
            torch.device("cuda", torch.cuda.current_device()),
        )
def test_transformer_auto_wrap_policy(self):
    """Tests the ``transformer_auto_wrap_policy``."""
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            TransformerEncoderLayer,
            TransformerDecoderLayer,
        },
    )
    fsdp_kwargs = {"auto_wrap_policy": auto_wrap_policy}
    fsdp_model = TransformerWithSharedParams.init(
        self.process_group,
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_BEFORE,
        fsdp_kwargs,
    )
    modules = list(fsdp_model.modules())
    encoder_layers = set(fsdp_model.module.transformer.encoder.layers)
    decoder_layers = set(fsdp_model.module.transformer.decoder.layers)
    for module in modules:
        if (
            module is fsdp_model
            or module in encoder_layers
            or module in decoder_layers
        ):
            self.assertTrue(isinstance(module, FSDP))
        else:
            self.assertFalse(isinstance(module, FSDP))
def test_state_dict_rank0_offload_save_load_flow(self):
    """Tests saving a model checkpoint only on rank 0 and loading it only
    on rank 0 with ``sync_module_states=True`` to emulate the workflow for
    avoiding redundant CPU memory usage."""
    auto_wrap_policy = partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            TransformerEncoderLayer,
            TransformerDecoderLayer,
        },
    )
    fsdp_kwargs = {"auto_wrap_policy": auto_wrap_policy}
    fsdp_model = TransformerWithSharedParams.init(
        self.process_group,
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_BEFORE,
        fsdp_kwargs,
    )
    # Force model parameters and buffers to be nonzero
    with FSDP.summon_full_params(fsdp_model):
        for tensor in itertools.chain(fsdp_model.parameters(), fsdp_model.buffers()):
            if torch.count_nonzero(tensor) == 0:
                with torch.no_grad():
                    tensor.add_(
                        torch.tensor(1, dtype=tensor.dtype, device=tensor.device)
                    )
    with self._get_state_dict_mgr(fsdp_model, "state_dict", True):
        state_dict = deepcopy(_get_state_dict(fsdp_model))
    # Initialize a non-wrapped model on all ranks
    new_model = TransformerWithSharedParams.init(
        self.process_group,
        FSDPInitMode.NO_FSDP,
        CUDAInitMode.CUDA_BEFORE,
    )
    _zero_model(new_model, zero_buffers=True)
    # Only load the checkpoint on rank 0
    if self.rank == 0:
        new_model.load_state_dict(state_dict, strict=True)
    _assert_module_states(
        new_model,
        process_group=self.process_group,
        assert_fn=self.assertNotEqual,
    )
    # Broadcast the module states from rank 0 with `sync_module_states=True`
    new_fsdp_model = FSDP(
        new_model,
        device_id=torch.cuda.current_device(),
        auto_wrap_policy=auto_wrap_policy,
        sync_module_states=True,
    )
    # Check that the FSDP models are equal across ranks
    with FSDP.summon_full_params(new_fsdp_model):
        _assert_module_states(
            new_fsdp_model,
            process_group=self.process_group,
            assert_fn=self.assertEqual,
        )
    # Check that the FSDP models correctly loaded the checkpoint
    with FSDP.summon_full_params(fsdp_model):
        with FSDP.summon_full_params(new_fsdp_model):
            params = list(fsdp_model.parameters())
            params_new = list(new_fsdp_model.parameters())
            self.assertEqual(params, params_new)
def _test_grad_acc(
    self,
    batch_dim: int,
    configs: List[_GradAccConfig],
    cpu_offload: CPUOffload,
    backward_prefetch: Optional[BackwardPrefetch],
):
    """
    Tests gradient accumulation by comparing a run that trains sequentially
    through some batches while accumulating gradients with a run that
    trains on the concatenation of those batches in a single iteration.

    The last iteration always synchronizes gradients regardless of what is
    specified by the last element of ``configs``.

    Arguments:
        batch_dim (int): Batch dimension in the input tensor to be passed
            into the model for the forward pass.
        configs (List[_GradAccConfig]): :class:`list` of configurations
            specifying how gradients are accumulated; for example, a list
            corresponding to [(False, 2), (True, 2), (False, 2)] indicates
            to accumulate over 2 + 2 + 2 = 6 total iterations, where the
            first two do not use ``no_sync()``, the middle two do use
            ``no_sync()``, and the final two again do not use ``no_sync()``.
        cpu_offload (CPUOffload): Configures CPU offloading.
        backward_prefetch (Optional[BackwardPrefetch]): Specifies at which
            point to prefetch the next layer's full parameters during the
            backward pass, if at all.
    """
    # Gradient accumulation outside `no_sync()` is not currently compatible
    # with CPU offloading
    if cpu_offload.offload_params and \
            any(not config.use_no_sync for config in configs):
        return
    old_allow_tf32 = torch.backends.cuda.matmul.allow_tf32
    try:
        # Disable TF32 to prevent floating point drift
        torch.backends.cuda.matmul.allow_tf32 = False

        # Initialize the FSDP model and optimizer
        fsdp_kwargs = {
            "cpu_offload": cpu_offload,
            "backward_prefetch": backward_prefetch,
        }
        fsdp_model: FSDP = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_AFTER,
            fsdp_kwargs,
            deterministic=True,
            add_bn=False,  # disable batch norm since the test uses varying batch sizes
        )
        device = torch.device("cuda")
        optim = torch.optim.SGD(
            fsdp_model.parameters(),
            lr=0.01,
            momentum=0.9,
        )

        # Generate the sequence of batches, each containing the same data
        # but permuted
        def permute_tensor(x: torch.Tensor):
            return x.view(-1)[torch.randperm(x.numel())].view_as(x)

        batch: Tuple[torch.Tensor, ...] = fsdp_model.module.get_input(device)
        batches: List[Tuple[torch.Tensor, ...]] = [batch]
        num_iters_to_acc = sum(config.num_iters for config in configs)
        for _ in range(num_iters_to_acc - 1):
            batches.append(tuple(permute_tensor(t) for t in batch))
        for (batch1, batch2) in itertools.combinations(batches, r=2):
            for t1, t2 in zip(batch1, batch2):
                assert not torch.all(t1 == t2), \
                    "Check the test to make sure that batches are distinct"

        # Concatenate the batches along the given batch dimension
        concat_batch: Tuple[torch.Tensor, ...] = tuple(
            torch.cat(ts, dim=batch_dim) for ts in zip(*batches)
        )

        # Establish reference gradients using the concatenated batch
        fsdp_model.zero_grad()
        output = fsdp_model(*concat_batch)
        ref_loss = fsdp_model.module.get_loss(concat_batch, output)
        ref_loss.backward()
        ref_grads = [
            p.grad.detach().clone() for p in fsdp_model.parameters()
        ]

        # Compute and accumulate the gradients
        fsdp_model.zero_grad()
        losses = []
        batch_idx = 0
        for config in configs:
            sync_context = (
                fsdp_model.no_sync() if config.use_no_sync
                else contextlib.suppress()
            )
            with sync_context:
                for _ in range(config.num_iters):
                    if batch_idx == num_iters_to_acc - 1:
                        break  # always sync on the last iteration
                    batch = batches[batch_idx]
                    batch_idx += 1
                    output = fsdp_model(*batch)
                    loss = fsdp_model.module.get_loss(batch, output)
                    loss.backward()
                    losses.append(loss)
        output = fsdp_model(*batches[-1])
        loss = fsdp_model.module.get_loss(batches[-1], output)
        loss.backward()
        losses.append(loss)
        acc_loss = sum(losses)
        acc_grads = [
            p.grad.detach().clone() for p in fsdp_model.parameters()
        ]

        # Compare the losses and gradients
        torch.testing.assert_close(ref_loss, acc_loss)
        self.assertEqual(len(ref_grads), len(acc_grads))
        for ref_grad, acc_grad in zip(ref_grads, acc_grads):
            self.assertEqual(ref_grad.device, acc_grad.device)
            self.assertEqual(ref_grad.size(), acc_grad.size())
            self.assertEqual(ref_grad.dtype, acc_grad.dtype)
            torch.testing.assert_close(ref_grad, acc_grad)
        # Check that the optimizer step does not error
        optim.step()
    finally:
        torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32
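# A minimal, single-process sketch (not part of the test suite) of the
# invariant that `_test_grad_acc()` checks at FSDP scale: with a sum-reduction
# loss, accumulating gradients over micro-batches matches a single backward
# pass over the concatenated batch. The model, shapes, loss, and helper name
# are hypothetical; the sketch relies on this file's existing `torch`/`nn`
# imports.
def _sketch_grad_acc_equivalence():
    torch.manual_seed(0)
    model = nn.Linear(4, 2)
    full_batch = torch.randn(8, 4)
    target = torch.randn(8, 2)
    loss_fn = nn.MSELoss(reduction="sum")  # sum reduction makes accumulation exact

    # Reference: one backward pass over the full batch
    model.zero_grad()
    loss_fn(model(full_batch), target).backward()
    ref_grads = [p.grad.clone() for p in model.parameters()]

    # Accumulation: two backward passes over the two halves, grads summed in-place
    model.zero_grad()
    for half in (slice(0, 4), slice(4, 8)):
        loss_fn(model(full_batch[half]), target[half]).backward()
    for ref_grad, param in zip(ref_grads, model.parameters()):
        torch.testing.assert_close(ref_grad, param.grad)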