Example #1
 def test_fsdp_cpu_init_stays_on_cpu(self):
     """
     Ensure that a module initialized on CPU stays on CPU
     after FSDP init and that we log a warning.
     """
     torch.cuda.set_device(self.rank)
     regex = "Module is put on CPU"
     context = self.assertWarnsRegex(expected_warning=UserWarning,
                                     expected_regex=regex)
     with context:
         mod = NestedWrappedModule(
             group=self.process_group,
             wrap_fsdp=True,
             wrap_everything=True,
             fsdp_init_mode=FSDPInitMode.CUDA_NEVER,
         )
         fsdp = FSDP(mod)
     devices = {p.device for p in fsdp.parameters()}
     self.assertEqual(1, len(devices))
     self.assertEqual(torch.device("cpu"), devices.pop())
     fsdp = fsdp.cuda()
     # Ensure fwd + backward can be performed after moving to CUDA.
     # CPU input also tests that input is correctly moved to appropriate
     # CUDA device.
     inp = mod.get_input(device=torch.device("cpu"))
     fsdp(inp[0]).sum().backward()
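The pattern exercised here can be reduced to a minimal standalone sketch. The snippet below is illustrative only; it assumes a default process group has already been initialized with a CUDA-capable backend, and ``nn.Linear`` stands in for the nested test module.

    import torch
    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    # Assumes torch.distributed.init_process_group() has already run and
    # torch.cuda.set_device(rank) has been called on this rank.
    cpu_module = nn.Linear(8, 8)      # module constructed on CPU
    fsdp_model = FSDP(cpu_module)     # expected to warn that the module is on CPU
    fsdp_model = fsdp_model.cuda()    # move the sharded parameters to the GPU
    inp = torch.randn(4, 8)           # CPU input; FSDP moves it to the CUDA device
    fsdp_model(inp).sum().backward()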
Example #2
 def init(
     group: dist.ProcessGroup,
     fsdp_init_mode: FSDPInitMode,
     cuda_init_mode: CUDAInitMode,
     fsdp_kwargs: Optional[Dict[str, Any]] = None,
     deterministic: bool = False,
 ):
     """
     Initializes a :class:`NestedWrappedModule` instance, but unlike
     :meth:`NestedWrappedModule.init`, for the ``RECURSIVE`` init mode, this
     wraps with top-level FSDP and the ``always_wrap_policy()`` auto wrap
     policy.
     """
     if fsdp_kwargs is None:
         fsdp_kwargs = {}
     super_ = super(AlwaysWrapNestedWrappedModule,
                    AlwaysWrapNestedWrappedModule)
     model = super_.init(
         group=group,
         fsdp_init_mode=FSDPInitMode.NO_FSDP,
         cuda_init_mode=cuda_init_mode,
         fsdp_kwargs=fsdp_kwargs,
         deterministic=deterministic,
     )
     if fsdp_init_mode == FSDPInitMode.NO_FSDP:
         return model
     elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
         fsdp_model = FSDP(model,
                           auto_wrap_policy=always_wrap_policy,
                           **fsdp_kwargs)
         if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
             fsdp_model = fsdp_model.cuda()
         return fsdp_model
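A hedged usage sketch of this helper follows; it assumes an initialized default process group (retrieved via the same private call used in Example #8), and the empty ``fsdp_kwargs`` dict is illustrative.

    import torch.distributed as dist

    # Illustrative only; assumes init_process_group() has already been called
    # and that each rank has set its CUDA device.
    group = dist.distributed_c10d._get_default_group()
    # RECURSIVE mode wraps the top level with FSDP and, via always_wrap_policy,
    # every submodule as well.
    fsdp_model = AlwaysWrapNestedWrappedModule.init(
        group,
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_BEFORE,
        fsdp_kwargs={},
        deterministic=True,
    )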
Example #3
 def test_fsdp_cpu_init_stays_on_cpu(self):
     """Tests that passing a CPU module to FSDP preserves that the wrapped
     module is on CPU after FSDP initialization, albeit after loging a
     warning, and that FSDP moves CPU input to GPU before the forward."""
     torch.cuda.set_device(self.rank)
     regex = "Module is put on CPU"
     context = self.assertWarnsRegex(
         expected_warning=UserWarning, expected_regex=regex
     )
     with context:
         nested_wrapped_module = NestedWrappedModule.init(
             self.process_group,
             FSDPInitMode.RECURSIVE,
             CUDAInitMode.CUDA_NEVER,
         )
         fsdp_model = FSDP(nested_wrapped_module, self.process_group)
     devices = {p.device for p in fsdp_model.parameters()}
     self.assertEqual(1, len(devices))
     self.assertEqual(torch.device("cpu"), devices.pop())
     fsdp_model = fsdp_model.cuda()
     # Ensure fwd + backward can be performed after moving to CUDA.
     # CPU input also tests that input is correctly moved to appropriate
     # CUDA device.
     inp = fsdp_model.module.get_input(device=torch.device("cpu"))
     fsdp_model(*inp).sum().backward()
Example #4
 def _get_wrapped_model(
     self,
     group,
     cuda_first=False,
     config=None,
     **model_kwargs,
 ) -> FullyShardedDataParallel:
     if config is None:
         config = {}
     move_to_cuda = not ("cpu_offload" in config
                         and config["cpu_offload"].offload_params)
     if cuda_first:
         transformer = TransformerWithSharedParams(group, **model_kwargs)
         if move_to_cuda:
             transformer = transformer.cuda()
         model = FullyShardedDataParallel(transformer, group, **config)
     else:
         model = FullyShardedDataParallel(
             TransformerWithSharedParams(group, **model_kwargs),
             group,
             **config,
         )
         if move_to_cuda:
             model = model.cuda()
     return model
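For reference, a hedged sketch of exercising the CPU-offload branch of this helper; ``CPUOffload`` is the FSDP config class used elsewhere in these examples, while the test method name and the ``self.process_group`` attribute are assumptions borrowed from the other examples.

    import torch
    from torch.distributed.fsdp import CPUOffload

    # Hypothetical test method on the same test class as the helper above.
    def test_cpu_offload_keeps_params_on_cpu(self):
        config = {"cpu_offload": CPUOffload(offload_params=True)}
        model = self._get_wrapped_model(self.process_group, config=config)
        # With offload_params=True, `move_to_cuda` above is False, so every
        # parameter should remain on CPU after wrapping.
        for p in model.parameters():
            self.assertEqual(torch.device("cpu"), p.device)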
Example #5
    def init(
        group: dist.ProcessGroup,
        fsdp_init_mode: FSDPInitMode,
        cuda_init_mode: CUDAInitMode,
        fsdp_kwargs: Optional[Dict[str, Any]] = None,
        deterministic: bool = False,
        add_bn: bool = True,
    ) -> Union[nn.Module, FSDP]:
        """
        Initializes a :class:`TransformerWithSharedParams` instance.

        Args:
            fsdp_init_mode (FSDPInitMode): If ``NO_FSDP``, then does not wrap
                any modules with FSDP. If ``RECURSIVE``, then wraps with
                top-level FSDP. By default, the top-level FSDP uses the
                ``transformer_auto_wrap_policy()`` for encoder and decoder
                layers, but a different auto wrap policy may be specified via
                ``fsdp_kwargs``.
            cuda_init_mode (CUDAInitMode): Determines model movement to CUDA.
            fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments
                forwarded to the FSDP constructor.
            deterministic (bool): Whether to make the model deterministic
                across constructions.
            add_bn (bool): Whether to include batch norm in the model.
        """
        if fsdp_kwargs is None:
            fsdp_kwargs = {}
        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
            return TransformerWithSharedParams(group, cuda_init_mode, add_bn,
                                               deterministic)
        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
            # Default to the `transformer_auto_wrap_policy()`
            if "auto_wrap_policy" not in fsdp_kwargs:
                auto_wrap_policy = functools.partial(
                    transformer_auto_wrap_policy,
                    transformer_layer_cls={
                        TransformerEncoderLayer,
                        TransformerDecoderLayer,
                    },
                )
            else:
                auto_wrap_policy = fsdp_kwargs.pop("auto_wrap_policy")
            fsdp_model = FSDP(
                TransformerWithSharedParams(group, cuda_init_mode, add_bn,
                                            deterministic),
                group,
                auto_wrap_policy=auto_wrap_policy,
                **fsdp_kwargs,
            )
            if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
                fsdp_model = fsdp_model.cuda()
            return fsdp_model
        raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}")
Example #6
 def _get_wrapped_model(
     self,
     group,
     cuda_first=False,
     ignore_modules=False,
     config=None,
     **model_kwargs,
 ) -> FullyShardedDataParallel:
     if config is None:
         config = {}
     move_to_cuda = not ("cpu_offload" in config
                         and config["cpu_offload"].offload_params)
     transformer = TransformerWithSharedParams(group, **model_kwargs)
     if cuda_first and move_to_cuda:
         transformer = transformer.cuda()
     if ignore_modules:
         assert "ignored_modules" not in config, \
             "Do not pass in `ignored_modules` via `config`"
         config["ignored_modules"] = transformer.get_ignored_modules()
     model = FullyShardedDataParallel(transformer, group, **config)
     if not cuda_first and move_to_cuda:
         model = model.cuda()
     return model
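A hedged sketch of exercising the ``ignore_modules`` branch; the test method name and the ``self.process_group`` attribute are assumptions, and ``get_ignored_modules()`` is the test-model helper referenced above.

    # Hypothetical test method on the same test class as the helper above.
    def test_wrap_with_ignored_modules(self):
        model = self._get_wrapped_model(self.process_group, ignore_modules=True)
        # The modules returned by `transformer.get_ignored_modules()` are passed
        # to FSDP as `ignored_modules`, so their parameters are neither
        # flattened nor sharded by FSDP.
        self.assertIsInstance(model, FullyShardedDataParallel)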
Example #7
    def _test_fsdp_parity(
        self,
        model_class: Type[FSDPTestModel],
        fsdp_init_mode: FSDPInitMode,
        cuda_init_mode: CUDAInitMode,
        ref_init_fn: Optional[Callable] = None,
        num_iters: int = 2,
        save_model: bool = True,
        cpu_offload: CPUOffload = CPUOffload(),
        backward_prefetch: Optional[BackwardPrefetch] = None,
        forward_prefetch: bool = False,
        sharding_strategy: Optional[ShardingStrategy] = None,
        mixed_precision: Optional[MixedPrecision] = None,
        enable_sharded_grad_scaler: bool = False,
        use_pure_fp16: bool = False,
        norm_type: Optional[Union[float, int]] = None,
        init_kwargs: Optional[Dict[str, Any]] = None,
        **fsdp_kwargs,
    ):
        """
        Tests FSDP training against a reference, which defaults to DDP but
        may be customized with ``ref_init_fn``.

        Args:
            model_class (Type[FSDPTestModel]): A model class that inherits from
                ``FSDPTestModel``, which defines the expected interface.
            fsdp_init_mode (FSDPInitMode): The mode to initialize the
                FSDP-wrapped model. This should not be ``NO_FSDP``.
            ref_init_fn (Optional[Callable]): A callable to invoke that wraps a
                non-wrapped model to construct the reference model, where this
                wrapper should provide data parallel semantics. If ``None``,
                then the callable defaults to the DDP constructor.
        """
        assert fsdp_init_mode != FSDPInitMode.NO_FSDP, "Expects an FSDP init mode that wraps with FSDP"
        if init_kwargs is None:
            init_kwargs = {}
        lr = 1e-2
        rank = self.process_group.rank()
        # Establish reference behavior with DDP
        model = model_class.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
            **init_kwargs,
        )
        if ref_init_fn is None:
            ref_model = DDP(model, device_ids=[rank], output_device=rank)
        else:
            ref_model = ref_init_fn(model)
        if use_pure_fp16:
            ref_model = ref_model.half()
        ref_loss = self._train_for_several_steps(
            ref_model,
            num_iters,
            autocast=mixed_precision is not None,
            lr=lr,
            fsdp_cpu_offload=cpu_offload,
            mixed_precision=mixed_precision,
            norm_type=norm_type,
            enable_sharded_grad_scaler=enable_sharded_grad_scaler,
            use_pure_fp16=use_pure_fp16,
        )
        ddp_params = list(ref_model.parameters())
        # Check against FSDP behavior
        fsdp_kwargs.update({
            "cpu_offload": cpu_offload,
            "backward_prefetch": backward_prefetch,
            "forward_prefetch": forward_prefetch,
            "sharding_strategy": sharding_strategy,
            "mixed_precision": mixed_precision,
        })
        try:
            fsdp_model = model_class.init(
                self.process_group,
                fsdp_init_mode,
                cuda_init_mode,
                fsdp_kwargs,
                deterministic=True,
                **init_kwargs,
            )
        except Exception as e:
            raise ValueError(
                f"Initializing {model_class} raised error {str(e)}")
        if not isinstance(fsdp_model, FSDP):
            # Enforce that we wrap with top-level FSDP since the comparison
            # assumes a data parallel reference, and some test models do not
            # wrap themselves in their `init()` method
            fsdp_model = FSDP(fsdp_model, self.process_group, **fsdp_kwargs)
        if use_pure_fp16:
            # Change the model parameter dtype after FSDP initialization
            fsdp_model = fsdp_model.half()
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
            fsdp_model = fsdp_model.cuda()
        offload_params = cpu_offload is not None and cpu_offload.offload_params
        # Offloading parameters with `CUDA_AFTER` should raise an error during
        # lazy initialization due to the parameter devices not being CPU;
        # otherwise, all parameter devices should be CPU
        expects_device_error = offload_params and cuda_init_mode == CUDAInitMode.CUDA_AFTER
        expects_cpu_device = offload_params and cuda_init_mode != CUDAInitMode.CUDA_AFTER
        if expects_cpu_device:
            cpu_device = torch.device("cpu")
            for param in fsdp_model.parameters():
                self.assertEqual(param.device, cpu_device)
        context = (self.assertRaisesRegex(AssertionError,
                                          "Expected param to be on CPU")
                   if expects_device_error else suppress())
        with context:
            fsdp_loss = self._train_for_several_steps(
                fsdp_model,
                num_iters,
                autocast=False,
                lr=lr,
                fsdp_cpu_offload=cpu_offload,
                save_model=save_model,
                mixed_precision=mixed_precision,
                norm_type=norm_type,
                enable_sharded_grad_scaler=enable_sharded_grad_scaler,
                use_pure_fp16=use_pure_fp16,
            )
        # No need to check for parameter and loss parity if expecting an error
        if expects_device_error:
            return
        # Check parameter devices are CPU if offloading to CPU before calling
        # `get_full_params()`, which will cast the parameters to FP32
        if offload_params:
            for param in fsdp_model.parameters():
                self.assertEqual(param.device, cpu_device)
            fsdp_loss = fsdp_loss.cuda()
        fsdp_unsharded_params = get_full_params(fsdp_model)
        torch.testing.assert_allclose(ref_loss, fsdp_loss)
        # Do not check for parameter parity if using mixed precision since (1)
        # the DDP parameters are in FP16 (from `half()`) while the FSDP
        # parameters are in FP32 (from `summon_full_params()`) and (2) DDP runs
        # the optimizer in FP16 while FSDP runs it in FP32
        if mixed_precision is None:
            self.assertEqual(
                ddp_params,
                fsdp_unsharded_params,
                exact_device=True,
                msg="FSDP did not match DDP",
            )
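A hedged sketch of how a test might drive this helper; the test method name and the CPU-offload configuration are illustrative, while the model class and the init/CUDA modes come from the earlier examples.

    from torch.distributed.fsdp import CPUOffload

    # Hypothetical test method on the same FSDP test class as the helper above.
    def test_transformer_parity_with_cpu_offload(self):
        self._test_fsdp_parity(
            TransformerWithSharedParams,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            cpu_offload=CPUOffload(offload_params=True),
        )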
Example #8
    def _test_identical_outputs(self,
                                model_init_fn,
                                *args,
                                ref_ddp_fn=None,
                                num_steps=2,
                                fsdp_init_mode=FSDPInitMode.CUDA_AFTER,
                                lr=0.01,
                                cpu_offload=CPUOffload(),
                                backward_prefetch=None,
                                sharding_strategy=None,
                                save_model=True,
                                clip_norm=0.3,
                                norm_type=None,
                                **kwargs):
        group = dist.distributed_c10d._get_default_group()
        rank = group.rank()
        # Establish reference behavior with PyTorch DDP.
        model = model_init_fn(group=group, wrap_fsdp=False).cuda()
        if ref_ddp_fn is None:
            model = nn.parallel.DistributedDataParallel(model,
                                                        device_ids=[rank],
                                                        output_device=rank)
        else:
            model = ref_ddp_fn(model)

        # DDP training
        ref_loss = self._train_for_several_steps(model,
                                                 num_steps,
                                                 autocast=False,
                                                 lr=lr,
                                                 fsdp_cpu_offload=cpu_offload)
        ref_full_params = list(model.parameters())

        # Confirm we get the same behavior using FullyShardedDataParallel.
        try:
            model = model_init_fn(
                group=group,
                wrap_fsdp=True,
                fsdp_init_mode=fsdp_init_mode,
                cpu_offload=cpu_offload,
                backward_prefetch=backward_prefetch,
                sharding_strategy=sharding_strategy,
            )
        except Exception as e:
            raise ValueError(
                f"model_Init_fn {model_init_fn} got error {str(e)}")

        cpu_offload = cpu_offload or CPUOffload()  # disabled if not specified.
        model = FullyShardedDataParallel(
            model,
            cpu_offload=cpu_offload,
            backward_prefetch=backward_prefetch,
            sharding_strategy=sharding_strategy,
        )
        # Call model.cuda() after init FSDP if specified.
        if fsdp_init_mode == FSDPInitMode.CUDA_AFTER:
            model = model.cuda()

        # We skip this check for FSDPInitMode.CUDA_AFTER since, when offloading
        # params, we expect the FSDP code to raise the error that we check for
        # below.
        if fsdp_init_mode != FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params:
            for p in model.parameters():
                # Should be on CPU regardless of if param is sharded.
                self.assertEqual(p.device, torch.device("cpu"),
                                 f"Mismatch, cpu offload is {cpu_offload}")

        only_check_err = fsdp_init_mode == FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params
        ctx = (self.assertRaisesRegex(AssertionError,
                                      "Expected param to be on CPU")
               if only_check_err else suppress())
        with ctx:
            # FSDP training
            shard_loss = self._train_for_several_steps(
                model,
                num_steps,
                autocast=False,
                lr=lr,
                fsdp_cpu_offload=cpu_offload,
                save_model=save_model,
            )
        # We only check for errors in the case we have the following setup:
        # model = FSDP(model, cpu_offload=True)
        # model = model.cuda()
        # so skip the rest of this logic.
        if only_check_err:
            return
        # If CPU offload, next call will change model params to GPU. Sanity
        # check that params are on CPU before.
        if cpu_offload.offload_params:
            device_set = {p.device for p in model.parameters()}
            self.assertEqual({torch.device("cpu")}, device_set,
                             f"Got device set {device_set}")
        shard_full_params = get_full_params(model)

        if cpu_offload.offload_params:
            shard_loss = shard_loss.cuda()
        torch.testing.assert_allclose(ref_loss, shard_loss)
        self.assertEqual(
            ref_full_params,
            shard_full_params,
            exact_device=True,
            msg="FullyShardedDataParallel didn't match PyTorch DDP",
        )
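For comparison with Example #7, a hedged sketch of driving this legacy helper; the test method name is hypothetical, and it assumes the ``NestedWrappedModule`` constructor from Example #1 forwards extra kwargs such as ``cpu_offload`` to its inner FSDP wrapping, as this helper requires of any ``model_init_fn``.

    from torch.distributed.fsdp import CPUOffload

    # Hypothetical test method on the same test class as the helper above.
    def test_nested_wrapped_model_cpu_offload(self):
        self._test_identical_outputs(
            NestedWrappedModule,
            fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
            cpu_offload=CPUOffload(offload_params=True),
        )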