Example 1
 def test_ignored_modules_transformer(self):
     """Tests that ignored modules' parameters are not flattened for a
     transformer model with shared parameters."""
     # Initialize an FSDP-wrapped transformer model that has FSDP ignore
     # the `nn.Transformer` module's parameters
     model: nn.Module = TransformerWithSharedParams.init(
         self.process_group,
         FSDPInitMode.NO_FSDP,
         CUDAInitMode.CUDA_BEFORE,
         deterministic=True,
     )
     wrapped_model = FSDP(
         model,
         self.process_group,
         ignored_modules=[model.transformer],
     )
     # Check that the wrapped model's flattened parameter does not include
     # the ignored transformer module's parameters
     nonwrapped_model: nn.Module = TransformerWithSharedParams.init(
         self.process_group,
         FSDPInitMode.NO_FSDP,
         CUDAInitMode.CUDA_BEFORE,
         deterministic=True,
     )
     total_numel = sum(p.numel() for p in nonwrapped_model.parameters())
     ignored_numel = sum(p.numel()
                         for p in nonwrapped_model.transformer.parameters())
     nonignored_numel = total_numel - ignored_numel
     with FSDP.summon_full_params(wrapped_model):
         flat_param_numel = wrapped_model.params[0].numel()
         self.assertEqual(flat_param_numel, nonignored_numel)
     # Check that we can run a few iterations
     optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
     self._train_model(wrapped_model, optim, 3)
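As a standalone illustration of the `ignored_modules` argument exercised above, here is a minimal sketch; `ToyModel` is an invented placeholder, and it assumes the default process group has already been initialized (e.g. under `torchrun`) with one GPU per rank.

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.trunk = nn.Linear(8, 8)
        self.head = nn.Linear(8, 2)

    def forward(self, x):
        return self.head(self.trunk(x))

model = ToyModel().cuda()
# `head`'s parameters are ignored: they are neither flattened nor sharded by
# FSDP and remain ordinary nn.Parameters on the wrapped module.
wrapped = FSDP(model, ignored_modules=[model.head])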
Example 2
 def test_param_change_after_init(self, mixed_precision):
     """
     Tests that in-place changes to FSDP model parameter values made after
     FSDP initialization persist.
     """
     # Establish reference behavior
     fsdp_kwargs = {}
     if mixed_precision:
         fsdp_kwargs["mixed_precision"] = MixedPrecision()
     fsdp_model = TransformerWithSharedParams.init(
         self.process_group,
         FSDPInitMode.RECURSIVE,
         CUDAInitMode.CUDA_AFTER,
         fsdp_kwargs,
         deterministic=True,
     )
     input = fsdp_model.module.get_input(torch.device("cuda"))
     ref_output = fsdp_model(*input)
     # Initialize the same model but change its first parameter value
     # in-place after FSDP initialization
     new_fsdp_model = TransformerWithSharedParams.init(
         self.process_group,
         FSDPInitMode.RECURSIVE,
         CUDAInitMode.CUDA_AFTER,
         fsdp_kwargs,
         deterministic=True,
     )
     first_param = next(new_fsdp_model.parameters())
     nn.init.normal_(first_param.data)
     new_output = new_fsdp_model(*input)
     self.assertNotEqual(
         ref_output,
         new_output,
         msg="new_output did not reflect change to param after init",
     )
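A related sketch of mutating parameters after wrapping: the usual pattern gathers the original parameters with `summon_full_params` and modifies them in place; `fsdp_model` stands for any already-wrapped module.

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Writes made inside summon_full_params() are persisted back to the local
# shards when the context exits (writeback=True is the default).
with FSDP.summon_full_params(fsdp_model):
    for module in fsdp_model.modules():
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight)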
Example 3
 def _init_model(
     self,
     nested_model: bool,
     sharding_strategy: ShardingStrategy,
     device: torch.device,
 ):
     fsdp_kwargs = {"sharding_strategy": sharding_strategy}
     if nested_model:
         model = NestedWrappedModule.init(
             self.process_group,
             FSDPInitMode.RECURSIVE,
             CUDAInitMode.CUDA_AFTER,
             fsdp_kwargs,
         )
         fsdp_model: FSDP = FSDP(
             model,
             self.process_group,
             **fsdp_kwargs,
         ).to(device)
     else:
         fsdp_model: FSDP = TransformerWithSharedParams.init(
             self.process_group,
             FSDPInitMode.RECURSIVE,
             CUDAInitMode.CUDA_BEFORE,
             fsdp_kwargs,
         )
     return fsdp_model
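For reference, a sketch of the `sharding_strategy` values that can be passed through `fsdp_kwargs` here; `my_module` is a placeholder for any nn.Module and an initialized process group is assumed.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

# FULL_SHARD shards parameters, gradients, and optimizer state (ZeRO-3 style);
# SHARD_GRAD_OP shards only gradients and optimizer state (ZeRO-2 style);
# NO_SHARD keeps everything replicated, similar to DDP.
fsdp_model = FSDP(my_module.cuda(), sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)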
Example 4
 def _init_transformer_model(
     self,
     wrap: bool,
     device: torch.device = torch.device("cuda"),
     group=None,
     optim_class: Type[torch.optim.Optimizer] = torch.optim.Adam,
     use_multiple_param_groups: bool = False,
     use_diff_optim_inputs: bool = False,
 ):
     if use_multiple_param_groups or use_diff_optim_inputs:
         # Keep these as arguments for parity with `_init_nested_model()`;
         # these settings are not implemented since the transformer is
         # wrapped with FSDP at the top level, which means that there is
         # only a single flattened parameter, making these booleans vacuous
         raise NotImplementedError()
     if group is None:
         group = dist.distributed_c10d._get_default_group()
     model = TransformerWithSharedParams.init(
         group,
         FSDPInitMode.RECURSIVE if wrap else FSDPInitMode.NO_FSDP,
         CUDAInitMode.CUDA_BEFORE,
         deterministic=True,
     )
     optim = optim_class(model.parameters(), lr=0.01)
     return model, optim, None
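The comment above about a single flattened parameter can be seen directly in a small sketch, assuming the flat-parameter FSDP these tests target and an initialized process group:

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4)).cuda()
fsdp_model = FSDP(model)  # single top-level wrap, no auto_wrap_policy
# All of the wrapped module's parameters are grouped into one FlatParameter,
# so there is nothing to split into multiple optimizer parameter groups.
print(len(list(fsdp_model.parameters())))  # 1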
Example 5
 def test_transformer_no_grad(self, mixed_precision):
     """Tests that for an FSDP-wrapped transformer model with shared
     parameters, after training for one iteration, running a forward pass in
     ``eval()`` mode gives the same output as running a forward pass in
     ``torch.no_grad()``."""
     fsdp_kwargs = {}
     if mixed_precision:
         fsdp_kwargs["mixed_precision"] = MixedPrecision(
             param_dtype=torch.float16,
             reduce_dtype=torch.float16,
             buffer_dtype=torch.float16,
         )
     else:
         fsdp_kwargs["mixed_precision"] = None
     fsdp_model = TransformerWithSharedParams.init(
         self.process_group,
         FSDPInitMode.RECURSIVE,
         CUDAInitMode.CUDA_AFTER,
         fsdp_kwargs,
     )
     self._train_for_several_steps(
         fsdp_model,
         num_steps=1,
         autocast=False,
         mixed_precision=fsdp_kwargs["mixed_precision"]
     )
     input = fsdp_model.module.get_input(torch.device("cuda"))
     # Run a forward in eval mode
     fsdp_model.eval()
     ref_output = fsdp_model(*input)
     # Run a forward in `no_grad()` and compare
     with torch.no_grad():
         no_grad_output = fsdp_model(*input)
     self.assertEqual(ref_output, no_grad_output)
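Continuing from the example above, the equivalence it checks corresponds to the usual inference pattern, sketched here with the same `fsdp_model` and `input`:

# eval() switches module behavior (e.g. dropout), while no_grad() disables
# autograd tracking; for inference the two are typically combined.
fsdp_model.eval()
with torch.no_grad():
    output = fsdp_model(*input)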
Example 6
 def _test_mixed_precision_embedding_table(self, mp_config):
     # Basic test to ensure int inputs are not cast, which would break
     # modules such as embedding tables.
     param_dtype = mp_config.param_dtype or torch.float32
     orig_reduce_scatter = dist._reduce_scatter_base
     test_reduce_scatter = partial(
         self._reduce_scatter_base_validate_mp,
         orig_reduce_scatter,
         mp_config,
     )
     with patch_reduce_scatter(test_reduce_scatter, param_dtype):
         # TODO: `test_mp_embedding_reduce()` fails if we do not wrap the
         # entire `TransformerWithSharedParams` with a single top-level FSDP
         model = TransformerWithSharedParams.init(
             self.process_group,
             FSDPInitMode.NO_FSDP,
             CUDAInitMode.CUDA_BEFORE,
             {"mixed_precision": mp_config},
         )
         fsdp_model = FSDP(model, mixed_precision=mp_config)
         optim = torch.optim.SGD(fsdp_model.parameters(), lr=0.1)
         for _ in range(6):
             inp = fsdp_model.module.get_input(torch.device("cuda"))
             # This would fail if we cast integer module inputs, such as for
             # embedding tables.
             output = fsdp_model(*inp)
             loss = fsdp_model.module.get_loss(inp, output).cuda()
             self.assertEqual(loss.dtype, param_dtype)
             fsdp_model.module.run_backward(loss)
             optim.step()
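Outside the test harness, the behavior being validated can be sketched with a toy embedding model; everything below (the Sequential, dtypes, shapes) is illustrative, and a default process group is assumed.

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

embed_model = nn.Sequential(nn.Embedding(100, 16), nn.Linear(16, 4)).cuda()
fsdp_embed = FSDP(
    embed_model,
    mixed_precision=MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    ),
)
token_ids = torch.randint(0, 100, (8,), device="cuda")
# The int64 token IDs pass through uncast; only floating-point tensors are
# converted to the low-precision param_dtype.
out = fsdp_embed(token_ids)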
Example 7
 def test_pre_backward_hook_registration(self, cuda_first: bool):
     """Tests that FSDP pre-backward hooks are registered on forward pass
     outputs."""
     fsdp_model = TransformerWithSharedParams.init(
         self.process_group,
         FSDPInitMode.RECURSIVE,
         CUDAInitMode.CUDA_BEFORE if cuda_first else CUDAInitMode.CUDA_AFTER,
     )
     self._test_pre_backward_hook_registration(fsdp_model)
Example 8
 def test_transformer_module_apply(self):
     """Tests that ``apply()`` modifies parameter values in-place on an
     FSDP-wrapped transformer model with shared parameters."""
     transformer = TransformerWithSharedParams.init(
         self.process_group,
         FSDPInitMode.RECURSIVE,
         CUDAInitMode.CUDA_AFTER,
     )
     self._check_apply(transformer)
Example 9
 def test_apply_in_summon_raises_error(self):
     """Tests that calling ``apply()`` on an FSDP instance inside the
     ``summon_full_params()`` context raises an error."""
     transformer = TransformerWithSharedParams.init(
         self.process_group,
         FSDPInitMode.RECURSIVE,
         CUDAInitMode.CUDA_AFTER,
     )
     with transformer.summon_full_params(transformer):
         with self.assertRaisesRegex(ValueError, "expected to be in states"):
             transformer.apply(self._init_linear_weights)
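By contrast, calling `apply()` outside the context is supported; a sketch using the same `transformer`, with an invented `init_linear_weights` helper (not the test's `_init_linear_weights`):

import torch.nn as nn

def init_linear_weights(module):
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)

# FSDP overrides apply() to gather the full parameters internally, which is
# why nesting it inside summon_full_params() is rejected above.
transformer.apply(init_linear_weights)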
Example 10
 def test_pre_backward_hook_registration_after_state_dict(self):
     """Tests that FSDP pre-backward hooks are registered on forward pass
     outputs after saving and loading the model from a checkpoint."""
     fsdp_model = TransformerWithSharedParams.init(
         self.process_group,
         FSDPInitMode.RECURSIVE,
         CUDAInitMode.CUDA_AFTER,
     )
     self._train_for_several_steps(fsdp_model, num_steps=2, autocast=False)
     state_dict = fsdp_model.state_dict()
     fsdp_model.load_state_dict(state_dict)
     self._test_pre_backward_hook_registration(fsdp_model)
Example 11
 def test_register_functions_called(self, cuda_first: bool, mixed_precision: bool):
     """Tests that ``_register_{pre|post}_backward_hooks()`` are called
     during the FSDP forward."""
     fsdp_kwargs = {}
     if mixed_precision:
         fsdp_kwargs["mixed_precision"] = MixedPrecision()
     fsdp_model = TransformerWithSharedParams.init(
         self.process_group,
         FSDPInitMode.RECURSIVE,
         CUDAInitMode.CUDA_BEFORE if cuda_first else CUDAInitMode.CUDA_AFTER,
         fsdp_kwargs,
     )
     input = fsdp_model.module.get_input(torch.device("cuda"))
     fsdp_model._register_pre_backward_hooks = mock.MagicMock(return_value=None)
     fsdp_model._register_post_backward_hooks = mock.MagicMock(return_value=None)
     self.assertFalse(fsdp_model._register_post_backward_hooks.called)
     self.assertFalse(fsdp_model._register_pre_backward_hooks.called)
     fsdp_model(*input)
     self.assertTrue(fsdp_model._register_post_backward_hooks.called)
     self.assertTrue(fsdp_model._register_pre_backward_hooks.called)
Example 12
 def test_device_id_auto_wrap(self):
     """Tests that ``auto_wrap_policy`` propagates ``device_id`` to all
     nested FSDP instances."""
     auto_wrap_policy = functools.partial(
         transformer_auto_wrap_policy,
         transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer},
     )
     fsdp_kwargs = {
         "auto_wrap_policy": auto_wrap_policy,
         "device_id": torch.cuda.current_device(),
     }
     fsdp_model = TransformerWithSharedParams.init(
         self.process_group,
         FSDPInitMode.RECURSIVE,
         CUDAInitMode.CUDA_BEFORE,
         fsdp_kwargs,
     )
     for fsdp_module in FSDP.fsdp_modules(fsdp_model):
         self.assertEqual(
             fsdp_module.device_id,
             torch.device("cuda", torch.cuda.current_device()),
         )
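A minimal sketch of what `device_id` does on its own, independent of auto wrapping; the `nn.Linear` is a stand-in for any CPU-initialized module, and an initialized process group is assumed.

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

cpu_model = nn.Linear(8, 8)  # constructed on CPU
# With device_id set, FSDP moves the module to that GPU during wrapping
# instead of requiring a .cuda() call beforehand.
fsdp_model = FSDP(cpu_model, device_id=torch.cuda.current_device())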
Example 13
 def test_transformer_auto_wrap_policy(self):
     """Tests the ``transformer_auto_wrap_policy``."""
     auto_wrap_policy = functools.partial(
         transformer_auto_wrap_policy,
         transformer_layer_cls={
             TransformerEncoderLayer, TransformerDecoderLayer
         },
     )
     fsdp_kwargs = {"auto_wrap_policy": auto_wrap_policy}
     fsdp_model = TransformerWithSharedParams.init(
         self.process_group,
         FSDPInitMode.RECURSIVE,
         CUDAInitMode.CUDA_BEFORE,
         fsdp_kwargs,
     )
     modules = list(fsdp_model.modules())
     encoder_layers = set(fsdp_model.module.transformer.encoder.layers)
     decoder_layers = set(fsdp_model.module.transformer.decoder.layers)
     for module in modules:
         if module is fsdp_model or module in encoder_layers or module in decoder_layers:
             self.assertTrue(isinstance(module, FSDP))
         else:
             self.assertFalse(isinstance(module, FSDP))
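For comparison, a sketch of the generic size-based policy, which wraps any submodule whose parameter count exceeds a threshold rather than matching on layer classes; the threshold value here is arbitrary.

import functools
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

auto_wrap_policy = functools.partial(
    size_based_auto_wrap_policy,
    min_num_params=1_000_000,  # wrap submodules with at least 1M parameters
)
# fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy)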
Example 14
 def test_state_dict_rank0_offload_save_load_flow(self):
     """Tests saving a model checkpoint only on rank 0 and loading it only
     on rank 0 with ``sync_module_states=True`` to emulate the workflow to
     avoid redundant CPU memory usage."""
     auto_wrap_policy = partial(
         transformer_auto_wrap_policy,
         transformer_layer_cls={
             TransformerEncoderLayer, TransformerDecoderLayer
         },
     )
     fsdp_kwargs = {"auto_wrap_policy": auto_wrap_policy}
     fsdp_model = TransformerWithSharedParams.init(
         self.process_group,
         FSDPInitMode.RECURSIVE,
         CUDAInitMode.CUDA_BEFORE,
         fsdp_kwargs,
     )
     # Force model parameters and buffers to be nonzero
     with FSDP.summon_full_params(fsdp_model):
         for tensor in itertools.chain(fsdp_model.parameters(),
                                       fsdp_model.buffers()):
             if torch.count_nonzero(tensor) == 0:
                 with torch.no_grad():
                     tensor.add_(
                         torch.tensor(1,
                                      dtype=tensor.dtype,
                                      device=tensor.device))
     with self._get_state_dict_mgr(fsdp_model, "state_dict", True):
         state_dict = deepcopy(_get_state_dict(fsdp_model))
     # Initialize a non-wrapped model on all ranks
     new_model = TransformerWithSharedParams.init(
         self.process_group,
         FSDPInitMode.NO_FSDP,
         CUDAInitMode.CUDA_BEFORE,
     )
     _zero_model(new_model, zero_buffers=True)
     # Only load the checkpoint on rank 0
     if self.rank == 0:
         new_model.load_state_dict(state_dict, strict=True)
     _assert_module_states(
         new_model,
         process_group=self.process_group,
         assert_fn=self.assertNotEqual,
     )
     # Broadcast the module states from rank 0 with `sync_module_states=True`
     new_fsdp_model = FSDP(
         new_model,
         device_id=torch.cuda.current_device(),
         auto_wrap_policy=auto_wrap_policy,
         sync_module_states=True,
     )
     # Check FSDP models are equal across ranks
     with FSDP.summon_full_params(new_fsdp_model):
         _assert_module_states(
             new_fsdp_model,
             process_group=self.process_group,
             assert_fn=self.assertEqual,
         )
     # Check FSDP models correctly loaded the checkpoint
     with FullyShardedDataParallel.summon_full_params(fsdp_model):
         with FullyShardedDataParallel.summon_full_params(new_fsdp_model):
             params = list(fsdp_model.parameters())
             params_new = list(new_fsdp_model.parameters())
             self.assertEqual(params, params_new)
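A condensed sketch of the workflow this test emulates; `build_model()` and "checkpoint.pt" are placeholders, and a default process group plus one GPU per rank are assumed.

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

new_model = build_model()  # placeholder: unwrapped model built on every rank
if dist.get_rank() == 0:
    # Only rank 0 materializes the checkpoint in CPU memory
    new_model.load_state_dict(torch.load("checkpoint.pt"))
# sync_module_states=True broadcasts rank 0's parameters and buffers to the
# other ranks during wrapping, so they never have to load the checkpoint.
new_fsdp_model = FSDP(
    new_model,
    device_id=torch.cuda.current_device(),
    sync_module_states=True,
)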
Example 15
    def _test_grad_acc(
        self,
        batch_dim: int,
        configs: List[_GradAccConfig],
        cpu_offload: CPUOffload,
        backward_prefetch: Optional[BackwardPrefetch],
    ):
        """
        Tests gradient accumulation by comparing a run that trains sequentially
        through some batches while accumulating gradients with a run that
        trains on the concatenation of those batches in a single iteration.

        The last iteration always synchronizes gradients regardless of what is
        specified by the last element of ``configs``.

        Arguments:
            batch_dim (int): Batch dimension in the input tensor to be passed
                into the model for the forward pass.
            configs (List[_GradAccConfig]): :class:`list` of configurations
                specifying how gradients are accumulated; for example, a list
                corresponding to [(False, 2), (True, 2), (False, 2)] indicates
                to accumulate over 2 + 2 + 2 = 6 total iterations, where the
                first two do not use ``no_sync()``, the middle two do use
                ``no_sync()``, and the final two again do not use
                ``no_sync()``.
            cpu_offload (CPUOffload): Configures CPU offloading.
            backward_prefetch (Optional[BackwardPrefetch]): Specifies at which
                point to prefetch the next layer's full parameters during the
                backward pass, if at all.
        """
        # Gradient accumulation outside `no_sync()` is not currently compatible
        # with CPU offloading
        if cpu_offload.offload_params and \
                any(not config.use_no_sync for config in configs):
            return
        old_allow_tf32 = torch.backends.cuda.matmul.allow_tf32
        try:
            # Disable TF32 to prevent floating point drift
            torch.backends.cuda.matmul.allow_tf32 = False

            # Initialize the FSDP model and optimizer
            fsdp_kwargs = {
                "cpu_offload": cpu_offload,
                "backward_prefetch": backward_prefetch,
            }
            fsdp_model: FSDP = TransformerWithSharedParams.init(
                self.process_group,
                FSDPInitMode.RECURSIVE,
                CUDAInitMode.CUDA_AFTER,
                fsdp_kwargs,
                deterministic=True,
                add_bn=False,  # disable BN since the test uses varying batch sizes
            )
            device = torch.device("cuda")
            optim = torch.optim.SGD(
                fsdp_model.parameters(),
                lr=0.01,
                momentum=0.9,
            )

            # Generate the sequence of batches, each containing the same data
            # but permuted
            def permute_tensor(x: torch.Tensor):
                return x.view(-1)[torch.randperm(x.numel())].view_as(x)

            batch: Tuple[torch.Tensor, ...] = \
                fsdp_model.module.get_input(device)
            batches: List[Tuple[torch.Tensor, ...]] = [batch]
            num_iters_to_acc = sum(config.num_iters for config in configs)
            for _ in range(num_iters_to_acc - 1):
                batches.append(tuple(permute_tensor(t) for t in batch))
            for (batch1, batch2) in itertools.combinations(batches, r=2):
                for t1, t2 in zip(batch1, batch2):
                    assert not torch.all(t1 == t2), \
                        "Check the test to make sure that batches are distinct"

            # Concatenate the batches along the given batch dimension
            concat_batch: Tuple[torch.Tensor, ...] = tuple(
                torch.cat(ts, dim=batch_dim) for ts in zip(*batches))

            # Establish reference gradients using the concatenated batch
            fsdp_model.zero_grad()
            output = fsdp_model(*concat_batch)
            ref_loss = fsdp_model.module.get_loss(concat_batch, output)
            ref_loss.backward()
            ref_grads = [
                p.grad.detach().clone() for p in fsdp_model.parameters()
            ]

            # Compute and accumulate the gradients
            fsdp_model.zero_grad()
            losses = []
            batch_idx = 0
            for config in configs:
                sync_context = fsdp_model.no_sync() if config.use_no_sync \
                    else contextlib.suppress()
                with sync_context:
                    for _ in range(config.num_iters):
                        if batch_idx == num_iters_to_acc - 1:
                            break  # always sync on the last iteration
                        batch = batches[batch_idx]
                        batch_idx += 1
                        output = fsdp_model(*batch)
                        loss = fsdp_model.module.get_loss(batch, output)
                        loss.backward()
                        losses.append(loss)
            output = fsdp_model(*batches[-1])
            loss = fsdp_model.module.get_loss(batches[-1], output)
            loss.backward()
            losses.append(loss)
            acc_loss = sum(losses)
            acc_grads = [
                p.grad.detach().clone() for p in fsdp_model.parameters()
            ]

            # Compare the losses and gradients
            torch.testing.assert_close(ref_loss, acc_loss)
            self.assertEqual(len(ref_grads), len(acc_grads))
            for ref_grad, acc_grad in zip(ref_grads, acc_grads):
                self.assertEqual(ref_grad.device, acc_grad.device)
                self.assertEqual(ref_grad.size(), acc_grad.size())
                self.assertEqual(ref_grad.dtype, acc_grad.dtype)
                torch.testing.assert_close(ref_grad, acc_grad)

            # Check that the optimizer step does not error
            optim.step()
        finally:
            torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32
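The accumulation pattern exercised above reduces, in application code, to the following sketch; `micro_batches`, `compute_loss`, and `optim` are placeholders, and `fsdp_model` is any FSDP-wrapped module.

fsdp_model.zero_grad()
with fsdp_model.no_sync():
    # Backward passes inside no_sync() skip the gradient reduce-scatter and
    # accumulate unsharded gradients locally (at extra memory cost).
    for batch in micro_batches[:-1]:
        compute_loss(fsdp_model, batch).backward()
# The final backward runs outside no_sync(), so gradients are reduced and
# re-sharded before the optimizer step.
compute_loss(fsdp_model, micro_batches[-1]).backward()
optim.step()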