def test_fsdp_cpu_init_stays_on_cpu(self):
    """
    Ensure that a model constructed on CPU stays on CPU after FSDP init and
    that we log a warning.
    """
    torch.cuda.set_device(self.rank)
    regex = "Module is put on CPU"
    context = self.assertWarnsRegex(
        expected_warning=UserWarning, expected_regex=regex
    )
    with context:
        mod = NestedWrappedModule(
            group=self.process_group,
            wrap_fsdp=True,
            wrap_everything=True,
            fsdp_init_mode=FSDPInitMode.CUDA_NEVER,
        )
        fsdp = FSDP(mod)
    devices = {p.device for p in fsdp.parameters()}
    self.assertEqual(1, len(devices))
    self.assertEqual(torch.device("cpu"), devices.pop())
    fsdp = fsdp.cuda()
    # Ensure fwd + backward can be performed after moving to CUDA.
    # CPU input also tests that input is correctly moved to appropriate
    # CUDA device.
    inp = mod.get_input(device=torch.device("cpu"))
    fsdp(inp[0]).sum().backward()
def init(
    group: dist.ProcessGroup,
    fsdp_init_mode: FSDPInitMode,
    cuda_init_mode: CUDAInitMode,
    fsdp_kwargs: Optional[Dict[str, Any]] = None,
    deterministic: bool = False,
):
    """
    Initializes a :class:`NestedWrappedModule` instance, but unlike
    :meth:`NestedWrappedModule.init`, for the ``RECURSIVE`` init mode, this
    wraps with top-level FSDP and the ``always_wrap_policy()`` auto wrap
    policy.
    """
    if fsdp_kwargs is None:
        # Guard against `None` so that the `**fsdp_kwargs` expansion below
        # does not fail
        fsdp_kwargs = {}
    super_ = super(AlwaysWrapNestedWrappedModule, AlwaysWrapNestedWrappedModule)
    model = super_.init(
        group=group,
        fsdp_init_mode=FSDPInitMode.NO_FSDP,
        cuda_init_mode=cuda_init_mode,
        fsdp_kwargs=fsdp_kwargs,
        deterministic=deterministic,
    )
    if fsdp_init_mode == FSDPInitMode.NO_FSDP:
        return model
    elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
        fsdp_model = FSDP(model, auto_wrap_policy=always_wrap_policy, **fsdp_kwargs)
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
            fsdp_model = fsdp_model.cuda()
        return fsdp_model
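# Hedged usage sketch (illustrative, not part of the original tests): one way the
# `init()` above could be called from an `FSDPTest` subclass. The method name and
# the asserted post-condition are assumptions; `self.process_group` and the enums
# come from the surrounding test harness.
def _example_always_wrap_init(self):
    fsdp_model = AlwaysWrapNestedWrappedModule.init(
        self.process_group,
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_AFTER,
        fsdp_kwargs=None,
        deterministic=True,
    )
    # With `CUDA_AFTER`, `init()` calls `.cuda()` itself, so all parameters
    # should already be on a CUDA device here.
    assert all(p.device.type == "cuda" for p in fsdp_model.parameters())
    return fsdp_model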
def test_fsdp_cpu_init_stays_on_cpu(self):
    """Tests that passing a CPU module to FSDP preserves that the wrapped
    module is on CPU after FSDP initialization, albeit after logging a
    warning, and that FSDP moves CPU input to GPU before the forward."""
    torch.cuda.set_device(self.rank)
    regex = "Module is put on CPU"
    context = self.assertWarnsRegex(
        expected_warning=UserWarning, expected_regex=regex
    )
    with context:
        nested_wrapped_module = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_NEVER,
        )
        fsdp_model = FSDP(nested_wrapped_module, self.process_group)
    devices = {p.device for p in fsdp_model.parameters()}
    self.assertEqual(1, len(devices))
    self.assertEqual(torch.device("cpu"), devices.pop())
    fsdp_model = fsdp_model.cuda()
    # Ensure fwd + backward can be performed after moving to CUDA.
    # CPU input also tests that input is correctly moved to appropriate
    # CUDA device.
    inp = fsdp_model.module.get_input(device=torch.device("cpu"))
    fsdp_model(*inp).sum().backward()
def _get_wrapped_model(
    self,
    group,
    cuda_first=False,
    config=None,
    **model_kwargs,
) -> FullyShardedDataParallel:
    if config is None:
        config = {}
    move_to_cuda = not (
        "cpu_offload" in config and config["cpu_offload"].offload_params
    )
    if cuda_first:
        transformer = TransformerWithSharedParams(group, **model_kwargs)
        if move_to_cuda:
            transformer = transformer.cuda()
        model = FullyShardedDataParallel(transformer, group, **config)
    else:
        model = FullyShardedDataParallel(
            TransformerWithSharedParams(group, **model_kwargs),
            group,
            **config,
        )
        if move_to_cuda:
            model = model.cuda()
    return model
def init(
    group: dist.ProcessGroup,
    fsdp_init_mode: FSDPInitMode,
    cuda_init_mode: CUDAInitMode,
    fsdp_kwargs: Optional[Dict[str, Any]] = None,
    deterministic: bool = False,
    add_bn: bool = True,
) -> Union[nn.Module, FSDP]:
    """
    Initializes a :class:`TransformerWithSharedParams` instance.

    Args:
        fsdp_init_mode (FSDPInitMode): If ``NO_FSDP``, then does not wrap any
            modules with FSDP. If ``RECURSIVE``, then wraps with top-level
            FSDP. By default, the top-level FSDP uses the
            ``transformer_auto_wrap_policy()`` for encoder and decoder layers,
            but a different auto wrap policy may be specified via
            ``fsdp_kwargs``.
        cuda_init_mode (CUDAInitMode): Determines model movement to CUDA.
        fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments
            forwarded to the FSDP constructor.
        deterministic (bool): Whether to make the model deterministic across
            constructions.
        add_bn (bool): Whether to include batch norm in the model.
    """
    if fsdp_kwargs is None:
        fsdp_kwargs = {}
    if fsdp_init_mode == FSDPInitMode.NO_FSDP:
        return TransformerWithSharedParams(
            group, cuda_init_mode, add_bn, deterministic
        )
    elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
        # Default to the `transformer_auto_wrap_policy()`
        if "auto_wrap_policy" not in fsdp_kwargs:
            auto_wrap_policy = functools.partial(
                transformer_auto_wrap_policy,
                transformer_layer_cls={
                    TransformerEncoderLayer,
                    TransformerDecoderLayer,
                },
            )
        else:
            auto_wrap_policy = fsdp_kwargs.pop("auto_wrap_policy")
        fsdp_model = FSDP(
            TransformerWithSharedParams(group, cuda_init_mode, add_bn, deterministic),
            group,
            auto_wrap_policy=auto_wrap_policy,
            **fsdp_kwargs,
        )
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
            fsdp_model = fsdp_model.cuda()
        return fsdp_model
    raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}")
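# Hedged usage sketch (illustrative, not part of the original tests): the docstring
# above says a different auto wrap policy may be passed through `fsdp_kwargs`, so a
# caller could override the default transformer policy with `always_wrap_policy`.
# The method name is an assumption; `self.process_group` comes from the test harness.
def _example_transformer_with_custom_wrap_policy(self):
    fsdp_kwargs = {"auto_wrap_policy": always_wrap_policy}
    return TransformerWithSharedParams.init(
        self.process_group,
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_BEFORE,
        fsdp_kwargs=fsdp_kwargs,
        deterministic=True,
    )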
def _get_wrapped_model(
    self,
    group,
    cuda_first=False,
    ignore_modules=False,
    config=None,
    **model_kwargs,
) -> FullyShardedDataParallel:
    if config is None:
        config = {}
    move_to_cuda = not (
        "cpu_offload" in config and config["cpu_offload"].offload_params
    )
    transformer = TransformerWithSharedParams(group, **model_kwargs)
    if cuda_first and move_to_cuda:
        transformer = transformer.cuda()
    if ignore_modules:
        assert "ignored_modules" not in config, (
            "Do not pass in `ignored_modules` via `config`"
        )
        config["ignored_modules"] = transformer.get_ignored_modules()
    model = FullyShardedDataParallel(transformer, group, **config)
    if not cuda_first and move_to_cuda:
        model = model.cuda()
    return model
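# Hedged usage sketch (illustrative, not part of the original tests): one way to call
# the helper above. The method name and the config are assumptions; with
# `offload_params=True` the helper skips both `.cuda()` moves, and with
# `ignore_modules=True` it fills in `ignored_modules` from the transformer.
def _example_wrapped_model_with_ignored_modules(self):
    config = {"cpu_offload": CPUOffload(offload_params=True)}
    return self._get_wrapped_model(
        self.process_group,
        cuda_first=False,
        ignore_modules=True,
        config=config,
    )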
def _test_fsdp_parity(
    self,
    model_class: Type[FSDPTestModel],
    fsdp_init_mode: FSDPInitMode,
    cuda_init_mode: CUDAInitMode,
    ref_init_fn: Optional[Callable] = None,
    num_iters: int = 2,
    save_model: bool = True,
    cpu_offload: CPUOffload = CPUOffload(),
    backward_prefetch: Optional[BackwardPrefetch] = None,
    forward_prefetch: bool = False,
    sharding_strategy: Optional[ShardingStrategy] = None,
    mixed_precision: Optional[MixedPrecision] = None,
    enable_sharded_grad_scaler: bool = False,
    use_pure_fp16: bool = False,
    norm_type: Optional[Union[float, int]] = None,
    init_kwargs: Optional[Dict[str, Any]] = None,
    **fsdp_kwargs,
):
    """
    Tests FSDP training against a reference, which defaults to DDP but may be
    customized with ``ref_init_fn``.

    Args:
        model_class (Type[FSDPTestModel]): A model class that inherits from
            ``FSDPTestModel``, which defines the expected interface.
        fsdp_init_mode (FSDPInitMode): The mode to initialize the
            FSDP-wrapped model. This should not be ``NO_FSDP``.
        ref_init_fn (Optional[Callable]): A callable to invoke that wraps a
            non-wrapped model to construct the reference model, where this
            wrapper should provide data parallel semantics. If ``None``,
            then the callable defaults to the DDP constructor.
    """
    assert fsdp_init_mode != FSDPInitMode.NO_FSDP, (
        "Expects an FSDP init mode that wraps with FSDP"
    )
    if init_kwargs is None:
        init_kwargs = {}
    lr = 1e-2
    rank = self.process_group.rank()
    # Establish reference behavior with DDP
    model = model_class.init(
        self.process_group,
        FSDPInitMode.NO_FSDP,
        CUDAInitMode.CUDA_BEFORE,
        deterministic=True,
        **init_kwargs,
    )
    if ref_init_fn is None:
        ref_model = DDP(model, device_ids=[rank], output_device=rank)
    else:
        ref_model = ref_init_fn(model)
    if use_pure_fp16:
        ref_model = ref_model.half()
    ref_loss = self._train_for_several_steps(
        ref_model,
        num_iters,
        autocast=mixed_precision is not None,
        lr=lr,
        fsdp_cpu_offload=cpu_offload,
        mixed_precision=mixed_precision,
        norm_type=norm_type,
        enable_sharded_grad_scaler=enable_sharded_grad_scaler,
        use_pure_fp16=use_pure_fp16,
    )
    ddp_params = list(ref_model.parameters())
    # Check against FSDP behavior
    fsdp_kwargs.update(
        {
            "cpu_offload": cpu_offload,
            "backward_prefetch": backward_prefetch,
            "forward_prefetch": forward_prefetch,
            "sharding_strategy": sharding_strategy,
            "mixed_precision": mixed_precision,
        }
    )
    try:
        fsdp_model = model_class.init(
            self.process_group,
            fsdp_init_mode,
            cuda_init_mode,
            fsdp_kwargs,
            deterministic=True,
            **init_kwargs,
        )
    except Exception as e:
        raise ValueError(f"Initializing {model_class} raised error {str(e)}")
    if not isinstance(fsdp_model, FSDP):
        # Enforce that we wrap with top-level FSDP since we are comparing
        # assuming a data parallel reference and some test models may not
        # do so in their `init()` method
        fsdp_model = FSDP(fsdp_model, self.process_group, **fsdp_kwargs)
    if use_pure_fp16:
        # Change the model parameter dtype after FSDP initialization
        fsdp_model = fsdp_model.half()
    if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
        fsdp_model = fsdp_model.cuda()
    offload_params = cpu_offload is not None and cpu_offload.offload_params
    # Offloading parameters with `CUDA_AFTER` should raise an error during
    # lazy initialization due to the parameter devices not being CPU;
    # otherwise, all parameter devices should be CPU
    expects_device_error = (
        offload_params and cuda_init_mode == CUDAInitMode.CUDA_AFTER
    )
    expects_cpu_device = (
        offload_params and cuda_init_mode != CUDAInitMode.CUDA_AFTER
    )
    if expects_cpu_device:
        cpu_device = torch.device("cpu")
        for param in fsdp_model.parameters():
            self.assertEqual(param.device, cpu_device)
    context = (
        self.assertRaisesRegex(AssertionError, "Expected param to be on CPU")
        if expects_device_error
        else suppress()
    )
    with context:
        fsdp_loss = self._train_for_several_steps(
            fsdp_model,
            num_iters,
            autocast=False,
            lr=lr,
            fsdp_cpu_offload=cpu_offload,
            save_model=save_model,
            mixed_precision=mixed_precision,
            norm_type=norm_type,
            enable_sharded_grad_scaler=enable_sharded_grad_scaler,
            use_pure_fp16=use_pure_fp16,
        )
    # No need to check for parameter and loss parity if expecting an error
    if expects_device_error:
        return
    # Check parameter devices are CPU if offloading to CPU before calling
    # `get_full_params()`, which will cast the parameters to FP32
    if offload_params:
        for param in fsdp_model.parameters():
            self.assertEqual(param.device, cpu_device)
        fsdp_loss = fsdp_loss.cuda()
    fsdp_unsharded_params = get_full_params(fsdp_model)
    torch.testing.assert_allclose(ref_loss, fsdp_loss)
    # Do not check for parameter parity if using mixed precision since (1)
    # the DDP parameters are in FP16 (from `half()`) while the FSDP
    # parameters are in FP32 (from `summon_full_params()`) and (2) DDP runs
    # the optimizer in FP16 while FSDP runs it in FP32
    if mixed_precision is not None:
        self.assertEqual(
            ddp_params,
            fsdp_unsharded_params,
            exact_device=True,
            msg="FSDP did not match DDP",
        )
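# Hedged usage sketch (illustrative, not part of the original tests): a test built on
# `_test_fsdp_parity()` above. The test name and parameter combination are
# assumptions; it exercises the `CUDA_AFTER` + parameter-offload path, for which the
# harness above expects an "Expected param to be on CPU" error rather than parity.
def test_cpu_offload_with_cuda_after(self):
    self._test_fsdp_parity(
        NestedWrappedModule,
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_AFTER,
        cpu_offload=CPUOffload(offload_params=True),
    )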
def _test_identical_outputs(
    self,
    model_init_fn,
    *args,
    ref_ddp_fn=None,
    num_steps=2,
    fsdp_init_mode=FSDPInitMode.CUDA_AFTER,
    lr=0.01,
    cpu_offload=CPUOffload(),
    backward_prefetch=None,
    sharding_strategy=None,
    save_model=True,
    clip_norm=0.3,
    norm_type=None,
    **kwargs,
):
    group = dist.distributed_c10d._get_default_group()
    rank = group.rank()
    # Establish reference behavior with PyTorch DDP (+ optionally autocast).
    model = model_init_fn(group=group, wrap_fsdp=False).cuda()
    if ref_ddp_fn is None:
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[rank], output_device=rank
        )
    else:
        model = ref_ddp_fn(model)
    # DDP training
    ref_loss = self._train_for_several_steps(
        model, num_steps, autocast=False, lr=lr, fsdp_cpu_offload=cpu_offload
    )
    ref_full_params = list(model.parameters())
    # Confirm we get the same behavior using FullyShardedDataParallel.
    try:
        model = model_init_fn(
            group=group,
            wrap_fsdp=True,
            fsdp_init_mode=fsdp_init_mode,
            cpu_offload=cpu_offload,
            backward_prefetch=backward_prefetch,
            sharding_strategy=sharding_strategy,
        )
    except Exception as e:
        raise ValueError(f"model_init_fn {model_init_fn} got error {str(e)}")
    cpu_offload = cpu_offload or CPUOffload()  # disabled if not specified
    model = FullyShardedDataParallel(
        model,
        cpu_offload=cpu_offload,
        backward_prefetch=backward_prefetch,
        sharding_strategy=sharding_strategy,
    )
    # Call model.cuda() after init FSDP if specified.
    if fsdp_init_mode == FSDPInitMode.CUDA_AFTER:
        model = model.cuda()
    # Note that we don't do this check for FSDPInitMode.CUDA_AFTER since, in
    # the case of offloaded params, we expect the FSDP code to raise the
    # error that we check below.
    if fsdp_init_mode != FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params:
        for p in model.parameters():
            # Should be on CPU regardless of whether the param is sharded.
            self.assertEqual(
                p.device, torch.device("cpu"),
                f"Mismatch, cpu offload is {cpu_offload}",
            )
    only_check_err = (
        fsdp_init_mode == FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params
    )
    ctx = (
        self.assertRaisesRegex(AssertionError, "Expected param to be on CPU")
        if only_check_err
        else suppress()
    )
    with ctx:
        # FSDP training
        shard_loss = self._train_for_several_steps(
            model,
            num_steps,
            autocast=False,
            lr=lr,
            fsdp_cpu_offload=cpu_offload,
            save_model=save_model,
        )
    # We only check for errors in the case we have the following setup:
    #   model = FSDP(model, cpu_offload=True)
    #   model = model.cuda()
    # so skip the rest of this logic.
    if only_check_err:
        return
    # If CPU offload, the next call will move model params to GPU. Sanity
    # check that params are on CPU before.
    if cpu_offload.offload_params:
        device_set = {p.device for p in model.parameters()}
        self.assertEqual(
            {torch.device("cpu")}, device_set, f"Got device set {device_set}"
        )
    shard_full_params = get_full_params(model)
    if cpu_offload.offload_params:
        shard_loss = shard_loss.cuda()
    torch.testing.assert_allclose(ref_loss, shard_loss)
    self.assertEqual(
        ref_full_params,
        shard_full_params,
        exact_device=True,
        msg="FullyShardedDataParallel didn't match PyTorch DDP",
    )