def test_fsdp_modules(self):
     group = dist.distributed_c10d._get_default_group()
     model = NestedWrappedModule(group, wrap_fsdp=True)
     modules = FSDP.fsdp_modules(model)
     self.assertEqual(modules, [
         model.module.get_submodule("1"),
         model.module.get_submodule("1").get_submodule("0"),
         model.module.get_submodule("2"),
     ])
     modules = FSDP.fsdp_modules(model, root_only=True)
     self.assertEqual(modules, [
         model.module.get_submodule("1"),
         model.module.get_submodule("2"),
     ])
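
A minimal sketch of the fsdp_modules contract, for orientation before the remaining examples. The toy model, names, and wrapping below are illustrative assumptions rather than part of the tests, and an initialized process group with a GPU backend is assumed, as in the test fixtures above:

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Hypothetical nested wrapping: the inner block is wrapped first, then the root.
toy = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 4), nn.ReLU()))
toy[1] = FSDP(toy[1])
fsdp_toy = FSDP(toy)

all_units = FSDP.fsdp_modules(fsdp_toy)                   # root plus every nested FSDP instance
root_units = FSDP.fsdp_modules(fsdp_toy, root_only=True)  # only the root wrapper
assert fsdp_toy in all_units
assert root_units == [fsdp_toy]
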
    def test_mixed_precision_resnet(self):
        """
        End to end test to ensure mixed precision + auto_wrap works
        for a ResNet model.
        """
        resnet_model = torchvision.models.resnet50().cuda()
        resnet_model = nn.SyncBatchNorm.convert_sync_batchnorm(
            resnet_model,
            process_group=dist.distributed_c10d._get_default_group())
        n_bn = sum(1 if isinstance(x, _BatchNorm) else 0
                   for x in resnet_model.modules())
        inp = torch.ones(1, 3, 1000, 1000, device='cuda')
        mp_config = MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
        )
        fsdp = FSDP(resnet_model,
                    auto_wrap_policy=size_based_auto_wrap_policy,
                    mixed_precision=mp_config)
        # BatchNorm units should be wrapped individually. Validate this by
        # checking that the number of FSDP units wrapping a BatchNorm equals
        # the number of BatchNorm units in the original ResNet model.
        fsdp_bn = 0
        for module in fsdp.fsdp_modules(fsdp):
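            # Unwrap FSDP -> inner flattening wrapper -> original module to check
            # for BatchNorm (the double .module nesting is assumed to match the
            # FSDP version under test).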
            wrapped_module = module.module.module
            if isinstance(wrapped_module, _BatchNorm):
                fsdp_bn += 1

        self.assertEqual(fsdp_bn, n_bn)
        # Without the mixed-precision-aware auto wrapping validated above, this
        # forward/backward would raise a dtype mismatch error.
        loss = fsdp(inp).sum()
        loss.backward()
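
Side note, not part of the test above: size_based_auto_wrap_policy is typically customized with functools.partial to set the wrapping threshold. A minimal sketch, with an arbitrary threshold chosen purely for illustration:

import functools
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Wrap any submodule whose unwrapped parameter count reaches 1e5
# (the threshold is an illustrative choice, not taken from the test).
my_auto_wrap_policy = functools.partial(
    size_based_auto_wrap_policy, min_num_params=int(1e5)
)
# The partial is then passed as auto_wrap_policy= to the FSDP constructor.
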
 def _validate_mp_shard_freed(self, fsdp_model):
     """
     Ensures that the mixed precision shard is freed for all FSDP units.
     """
     fsdp_units = FSDP.fsdp_modules(fsdp_model)
     for fsdp in fsdp_units:
         for param in fsdp.params:
             self.assertEqual(0, param._mp_shard.storage().size())
 def test_state_dict_type(self):
     module = SkipModel(double_nest=True)
     with enable_wrap(wrapper_cls=FSDP):
         fsdp = wrap(module)
     with FSDP.state_dict_type(fsdp, StateDictType.LOCAL_STATE_DICT):
         pass
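     # Exiting the context manager should restore the previous state dict type,
     # which is the default FULL_STATE_DICT here.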
     for module in FSDP.fsdp_modules(fsdp):
         self.assertEqual(module._state_dict_type, StateDictType.FULL_STATE_DICT)
 def test_fsdp_modules(self):
     nested_wrapped_module = NestedWrappedModule.init(
         self.process_group,
         FSDPInitMode.RECURSIVE,
         CUDAInitMode.CUDA_BEFORE,
     )
     modules = FSDP.fsdp_modules(nested_wrapped_module)
     self.assertEqual(modules, [
         nested_wrapped_module.module.get_submodule("1"),
         nested_wrapped_module.module.get_submodule("1").get_submodule("0"),
         nested_wrapped_module.module.get_submodule("2"),
     ])
     modules = FSDP.fsdp_modules(nested_wrapped_module, root_only=True)
     self.assertEqual(modules, [
         nested_wrapped_module.module.get_submodule("1"),
         nested_wrapped_module.module.get_submodule("2"),
     ])
 def _validate_no_mp_shard(self, fsdp_model):
     """
     Validates that there is no mixed precision _mp_shard allocated
     when it is not expected to be.
     """
     fsdp_units = FSDP.fsdp_modules(fsdp_model)
     for fsdp in fsdp_units:
         for param in fsdp.params:
             self.assertFalse(hasattr(param, '_mp_shard'))
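
Together with _validate_mp_shard_freed above, this covers both sides of the param_dtype setting: when low-precision parameters are enabled, _mp_shard exists but its storage is expected to be freed, and when they are disabled the attribute should never be created. The forward check in Example #7 below makes the same distinction.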
Example #7
    def forward(self, tup):
        # Param and input should be the mixed precision type
        inp, cls, fsdp, mp_config, full_precision_param_dtype = tup
        expected_param_type = (
            mp_config.param_dtype if mp_config.param_dtype is not None
            else self._orig_param_type
        )
        expected_buffer_type = (
            mp_config.buffer_dtype if mp_config.buffer_dtype is not None
            else self._orig_buffer_dtype
        )
        cls.assertEqual(inp.dtype, expected_param_type)
        # Buffer should be in specified precision as well.
        cls.assertEqual(self.buffer.dtype, expected_buffer_type)

        # In FSDP, self.params should point to the right type.
        num_active_fsdp = 0
        for fsdp_module in FSDP.fsdp_modules(fsdp):
            fsdp_managed_params = fsdp_module.params
            # Single param assumption
            cls.assertEqual(1, len(fsdp_managed_params))
            for param in fsdp_managed_params:
                # FSDP unit is currently active if it is not using the param
                # local shard. This supports both FULL_SHARD and SHARD_GRAD_OP
                # cases. In FULL_SHARD, we have the additional property that
                # param._full_param_padded has not been freed.
                is_fsdp_unit_active = (
                    param._is_sharded and
                    (param.data.data_ptr() != param._local_shard.data_ptr())
                )
                if is_fsdp_unit_active:
                    num_active_fsdp += 1
                    # This FSDP unit is active; verify the param is in the
                    # mixed precision dtype.
                    cls.assertEqual(param.dtype, expected_param_type)
                    # _rebuild_full_param should have also freed the fp16 shard.
                    # Shard is never allocated if param_dtype mixed precision is not
                    # enabled.
                    if mp_config.param_dtype is not None:
                        cls.assertEqual(0, param._mp_shard.storage().size())
                    else:
                        cls.assertFalse(hasattr(param, '_mp_shard'))
                elif param._is_sharded:
                    # This FSDP unit is not active as full param has been
                    # freed or not yet allocated. Ensure param points to full
                    # precision param.
                    cls.assertEqual(param.dtype, full_precision_param_dtype)
        # We should have seen at least one active FSDP unit in the sharded case
        # (world_size > 1). When the param is not sharded (i.e., world_size == 1),
        # it is hard to tell whether an FSDP unit is active because param.data
        # always points to the local shard, so we rely on the forward pass
        # self.lin(inp) succeeding with a reduced-precision inp to implicitly
        # validate that the param is in the reduced precision as well.
        if cls.world_size > 1:
            cls.assertGreater(num_active_fsdp, 0)

        return (self.lin(inp), cls, fsdp, mp_config, full_precision_param_dtype)
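
The test-harness objects (the test case, the root FSDP wrapper, and the precision config) are threaded through the input and output tuple so that chained or nested FSDP-wrapped copies of this module can repeat the same dtype and shard-state checks within a single forward pass.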
Example #8
 def test_device_id_auto_wrap(self):
     """
     Test auto wrapping propagates the device id.
     """
     model = TransformerWithSharedParams(group=self.process_group)
     my_auto_wrap_policy = functools.partial(transformer_auto_wrap_policy,
                                             transformer_layer_cls={
                                                 TransformerEncoderLayer,
                                                 TransformerDecoderLayer
                                             })
     wrapped = FSDP(model,
                    auto_wrap_policy=my_auto_wrap_policy,
                    device_id=torch.cuda.current_device())
     # All FSDP instances should have device_id set
     for m in FSDP.fsdp_modules(wrapped):
         self.assertEqual(m.device_id,
                          torch.device("cuda", torch.cuda.current_device()))
Example #9
 def test_device_id_auto_wrap(self):
     """Tests that ``auto_wrap_policy`` propagates ``device_id`` to all
     nested FSDP instances."""
     auto_wrap_policy = functools.partial(
         transformer_auto_wrap_policy,
         transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer},
     )
     fsdp_kwargs = {
         "auto_wrap_policy": auto_wrap_policy,
         "device_id": torch.cuda.current_device(),
     }
     fsdp_model = TransformerWithSharedParams.init(
         self.process_group,
         FSDPInitMode.RECURSIVE,
         CUDAInitMode.CUDA_BEFORE,
         fsdp_kwargs,
     )
     for fsdp_module in FSDP.fsdp_modules(fsdp_model):
         self.assertEqual(
             fsdp_module.device_id,
             torch.device("cuda", torch.cuda.current_device()),
         )
Example #10
    def test_default_communication_hook_behavior(
        self,
        sharding_strategy: Optional[ShardingStrategy]
    ):
        """
        Tests FSDP's default communication hook's behavior and correctness.
        Arguments:
            sharding_strategy (Optional[ShardingStrategy]): Configures the FSDP algorithm.
        """
        m = torch.nn.Linear(1, 5, bias=False)
        inpt = torch.tensor([self.rank]).float().cuda(self.rank)

        net_default_hook = FSDP(
            m,
            device_id=torch.cuda.current_device(),
            sharding_strategy=sharding_strategy
        ).to(self.rank)

        # Check that default hook is set to `all_reduce`
        for entry in FSDP.fsdp_modules(net_default_hook):
            self.assertEqual(entry.communication_hook, default_hooks.allreduce_hook)

        for _ in range(4):

            # Clear gradients
            net_default_hook.zero_grad()
            loss = net_default_hook(inpt).sum()
            loss.backward()

            # Each rank feeds its own rank as input, so its local gradient on
            # the weight equals its rank; the default all-reduce hook should
            # average these gradients across ranks.
            grad = net_default_hook.params[0].grad
            expected_grad = (
                sum(i for i in range(dist.get_world_size())) / dist.get_world_size()
            )
            # Verify default hook produces expected gradients
            self.assertEqual(
                grad[0].item(),
                expected_grad,
                msg=f"Expected hook grad of {expected_grad} but got {grad[0].item()}")
Example #11
    def test_default_communication_hook_initialization(
        self,
        has_wrapping: bool,
        sharding_strategy: Optional[ShardingStrategy]
    ):
        """
        Tests FSDP's communication hook interface behavior.
        Arguments:
            has_wrapping (bool): Configures wrapping of a module.
            sharding_strategy (Optional[ShardingStrategy]): Configures the FSDP algorithm.
        """

        # Initialize a model
        fsdp_model_with_hook = self._init_model(
            Net(has_wrapping=has_wrapping, sharding_strategy=sharding_strategy),
            sharding_strategy=sharding_strategy
        )
        dummy_state = DummyState(process_group=None)

        # FSDP currently supports communication hooks only for the NO_SHARD
        # strategy. Check that a `NotImplementedError` is raised for other
        # strategies.
        if sharding_strategy != ShardingStrategy.NO_SHARD:
            # Check that default hook is set to None
            for entry in FSDP.fsdp_modules(fsdp_model_with_hook):
                self.assertIsNone(entry.communication_hook)
                self.assertIsNone(entry.communication_hook_state)

            with self.assertRaisesRegex(
                NotImplementedError,
                '^Communication hooks are currently only available for a NO_SHARD strategy.$'
            ):
                fsdp_model_with_hook.register_comm_hook(dummy_state, DummyHook.dummy_hook)

        else:

            # Check that default hook is set to `all_reduce`
            for entry in FSDP.fsdp_modules(fsdp_model_with_hook):
                self.assertEqual(entry.communication_hook, default_hooks.allreduce_hook)

            dummy_state = DummyState(process_group=None)

            fsdp_model_with_hook.register_comm_hook(
                dummy_state,
                DummyHook.dummy_hook
            )

            # Check that we can't register comm hook twice
            with self.assertRaisesRegex(AssertionError, '^communication hook can be only registered once$'):
                fsdp_model_with_hook.register_comm_hook(
                    dummy_state,
                    DummyHook.dummy_hook
                )

            # Check dummy hook was registered for the root and all submodules if any
            for entry in FSDP.fsdp_modules(fsdp_model_with_hook):
                self.assertEqual(
                    entry.communication_hook,
                    DummyHook.dummy_hook
                )
                self.assertEqual(
                    entry.communication_hook_state,
                    dummy_state
                )
Example #12
 def _get_submodules(self, fsdp_net):
     return [
         submodule for submodule in FSDP.fsdp_modules(fsdp_net)
         if not submodule.check_is_root()
     ]
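
Note that FSDP.fsdp_modules(fsdp_net, root_only=True) returns the complementary set, namely the root instances, so this helper collects exactly the nested, non-root FSDP units.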