def test_fsdp_modules(self):
    group = dist.distributed_c10d._get_default_group()
    model = NestedWrappedModule(group, wrap_fsdp=True)
    modules = FSDP.fsdp_modules(model)
    self.assertEqual(modules, [
        model.module.get_submodule("1"),
        model.module.get_submodule("1").get_submodule("0"),
        model.module.get_submodule("2"),
    ])
    modules = FSDP.fsdp_modules(model, root_only=True)
    self.assertEqual(modules, [
        model.module.get_submodule("1"),
        model.module.get_submodule("2"),
    ])
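# A minimal sketch (not part of the test suite) of how `FSDP.fsdp_modules` can
# be used outside a test to enumerate FSDP-wrapped submodules. The helper name
# `summarize_fsdp_units` is hypothetical; only the
# `FSDP.fsdp_modules(module, root_only=...)` calls mirror the API exercised by
# the tests in this section.
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def summarize_fsdp_units(model: nn.Module) -> None:
    # All FSDP instances, including nested ones.
    for unit in FSDP.fsdp_modules(model):
        print(type(unit.module).__name__)
    # Only the outermost (root) FSDP instances.
    roots = FSDP.fsdp_modules(model, root_only=True)
    print(f"{len(roots)} root FSDP unit(s)")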
def test_mixed_precision_resnet(self):
    """
    End to end test to ensure mixed precision + auto_wrap works
    for ResNet model.
    """
    resnet_model = torchvision.models.resnet50().cuda()
    resnet_model = nn.SyncBatchNorm.convert_sync_batchnorm(
        resnet_model, process_group=dist.distributed_c10d._get_default_group()
    )
    n_bn = sum(
        1 if isinstance(x, _BatchNorm) else 0 for x in resnet_model.modules()
    )
    inp = torch.ones(1, 3, 1000, 1000, device='cuda')
    mp_config = MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    )
    fsdp = FSDP(
        resnet_model,
        auto_wrap_policy=size_based_auto_wrap_policy,
        mixed_precision=mp_config,
    )
    # Batchnorm units should be wrapped individually. Validate this by
    # ensuring that the number of FSDP units wrapping a BatchNorm equals the
    # number of BatchNorm units in the original ResNet model.
    fsdp_bn = 0
    for module in fsdp.fsdp_modules(fsdp):
        wrapped_module = module.module.module
        if isinstance(wrapped_module, _BatchNorm):
            fsdp_bn += 1
    self.assertEqual(fsdp_bn, n_bn)
    # Would throw a type mismatch issue without mixed precision autowrapping.
    loss = fsdp(inp).sum()
    loss.backward()
def _validate_mp_shard_freed(self, fsdp_model):
    """
    Ensures that the mixed precision shard is freed for all FSDP units.
    """
    fsdp_units = FSDP.fsdp_modules(fsdp_model)
    for fsdp in fsdp_units:
        for param in fsdp.params:
            self.assertEqual(0, param._mp_shard.storage().size())
def test_state_dict_type(self):
    module = SkipModel(double_nest=True)
    with enable_wrap(wrapper_cls=FSDP):
        fsdp = wrap(module)
    with FSDP.state_dict_type(fsdp, StateDictType.LOCAL_STATE_DICT):
        pass
    for module in FSDP.fsdp_modules(fsdp):
        self.assertEqual(module._state_dict_type, StateDictType.FULL_STATE_DICT)
def test_fsdp_modules(self):
    nested_wrapped_module = NestedWrappedModule.init(
        self.process_group,
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_BEFORE,
    )
    modules = FSDP.fsdp_modules(nested_wrapped_module)
    self.assertEqual(modules, [
        nested_wrapped_module.module.get_submodule("1"),
        nested_wrapped_module.module.get_submodule("1").get_submodule("0"),
        nested_wrapped_module.module.get_submodule("2"),
    ])
    modules = FSDP.fsdp_modules(nested_wrapped_module, root_only=True)
    self.assertEqual(modules, [
        nested_wrapped_module.module.get_submodule("1"),
        nested_wrapped_module.module.get_submodule("2"),
    ])
def _validate_no_mp_shard(self, fsdp_model):
    """
    Validates that there is no mixed precision _mp_shard allocated
    when it is not expected to be.
    """
    fsdp_units = FSDP.fsdp_modules(fsdp_model)
    for fsdp in fsdp_units:
        for param in fsdp.params:
            self.assertFalse(hasattr(param, '_mp_shard'))
def forward(self, tup):
    # Param and input should be the mixed precision type
    inp, cls, fsdp, mp_config, full_precision_param_dtype = tup
    expected_param_type = (
        mp_config.param_dtype if mp_config.param_dtype is not None
        else self._orig_param_type
    )
    expected_buffer_type = (
        mp_config.buffer_dtype if mp_config.buffer_dtype is not None
        else self._orig_buffer_dtype
    )
    cls.assertEqual(inp.dtype, expected_param_type)
    # Buffer should be in the specified precision as well.
    cls.assertEqual(self.buffer.dtype, expected_buffer_type)
    # In FSDP, self.params should point to the right type.
    num_active_fsdp = 0
    for fsdp_module in FSDP.fsdp_modules(fsdp):
        fsdp_managed_params = fsdp_module.params
        # Single param assumption
        cls.assertEqual(1, len(fsdp_managed_params))
        for param in fsdp_managed_params:
            # FSDP unit is currently active if it is not using the param
            # local shard. This supports both FULL_SHARD and SHARD_GRAD_OP
            # cases. In FULL_SHARD, we have the additional property that
            # param._full_param_padded has not been freed.
            is_fsdp_unit_active = (
                param._is_sharded and
                (param.data.data_ptr() != param._local_shard.data_ptr())
            )
            if is_fsdp_unit_active:
                num_active_fsdp += 1
                # This FSDP unit is active; verify param points to the
                # mixed precision type.
                cls.assertEqual(param.dtype, expected_param_type)
                # _rebuild_full_param should have also freed the fp16 shard.
                # The shard is never allocated if param_dtype mixed precision
                # is not enabled.
                if mp_config.param_dtype is not None:
                    cls.assertEqual(0, param._mp_shard.storage().size())
                else:
                    cls.assertFalse(hasattr(param, '_mp_shard'))
            elif param._is_sharded:
                # This FSDP unit is not active as the full param has been
                # freed or not yet allocated. Ensure param points to the
                # full precision param.
                cls.assertEqual(param.dtype, full_precision_param_dtype)
    # We should have gotten at least one active FSDP unit for sharded
    # (world size > 1) cases. For cases where the param is not sharded
    # (i.e., world_size == 1), it is hard to check whether an FSDP unit is
    # active, as we'd always point to the local shard, so we rely on the
    # forward pass self.lin(inp) working with inp in reduced precision to
    # implicitly validate that the param is indeed in the reduced precision.
    if cls.world_size > 1:
        cls.assertGreater(num_active_fsdp, 0)
    return (self.lin(inp), cls, fsdp, mp_config, full_precision_param_dtype)
def test_device_id_auto_wrap(self):
    """
    Test auto wrapping propagates the device id.
    """
    model = TransformerWithSharedParams(group=self.process_group)
    my_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            TransformerEncoderLayer,
            TransformerDecoderLayer,
        },
    )
    wrapped = FSDP(
        model,
        auto_wrap_policy=my_auto_wrap_policy,
        device_id=torch.cuda.current_device(),
    )
    # All FSDP instances should have device_id set
    for m in FSDP.fsdp_modules(wrapped):
        self.assertEqual(
            m.device_id, torch.device("cuda", torch.cuda.current_device())
        )
def test_device_id_auto_wrap(self):
    """Tests that ``auto_wrap_policy`` propagates ``device_id`` to all
    nested FSDP instances."""
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer},
    )
    fsdp_kwargs = {
        "auto_wrap_policy": auto_wrap_policy,
        "device_id": torch.cuda.current_device(),
    }
    fsdp_model = TransformerWithSharedParams.init(
        self.process_group,
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_BEFORE,
        fsdp_kwargs,
    )
    for fsdp_module in FSDP.fsdp_modules(fsdp_model):
        self.assertEqual(
            fsdp_module.device_id,
            torch.device("cuda", torch.cuda.current_device()),
        )
def test_default_communication_hook_behavior(
    self, sharding_strategy: Optional[ShardingStrategy]
):
    """
    Tests FSDP's default communication hook's behavior and correctness.

    Arguments:
        sharding_strategy (Optional[ShardingStrategy]):
            Configures the FSDP algorithm.
    """
    m = torch.nn.Linear(1, 5, bias=False)
    inpt = torch.tensor([self.rank]).float().cuda(self.rank)
    net_default_hook = FSDP(
        m,
        device_id=torch.cuda.current_device(),
        sharding_strategy=sharding_strategy,
    ).to(self.rank)
    # Check that the default hook is set to `all_reduce`
    for entry in FSDP.fsdp_modules(net_default_hook):
        self.assertEqual(entry.communication_hook, default_hooks.allreduce_hook)
    for _ in range(4):
        # Clear gradients
        net_default_hook.zero_grad()
        loss = net_default_hook(inpt).sum()
        loss.backward()
        # For each worker, the gradient on the weight should be worker_rank.
        grad = net_default_hook.params[0].grad
        expected_grad = (
            sum(i for i in range(dist.get_world_size())) / dist.get_world_size()
        )
        # Verify that the default hook produces the expected gradients
        self.assertEqual(
            grad[0].item(), expected_grad,
            msg=f"Expected hook grad of {expected_grad} but got {grad[0].item()}",
        )
def test_default_communication_hook_initialization(
    self, has_wrapping: bool, sharding_strategy: Optional[ShardingStrategy]
):
    """
    Tests FSDP's communication hook interface behavior.

    Arguments:
        has_wrapping (bool): Configures wrapping of a module.
        sharding_strategy (Optional[ShardingStrategy]):
            Configures the FSDP algorithm.
    """
    # Initialize a model
    fsdp_model_with_hook = self._init_model(
        Net(has_wrapping=has_wrapping, sharding_strategy=sharding_strategy),
        sharding_strategy=sharding_strategy,
    )
    dummy_state = DummyState(process_group=None)
    # FSDP currently only supports communication hooks for the NO_SHARD
    # strategy. Check that a `NotImplementedError` is raised for other
    # strategies.
    if sharding_strategy != ShardingStrategy.NO_SHARD:
        # Check that the default hook is set to None
        for entry in FSDP.fsdp_modules(fsdp_model_with_hook):
            self.assertIsNone(entry.communication_hook)
            self.assertIsNone(entry.communication_hook_state)
        with self.assertRaisesRegex(
            NotImplementedError,
            '^Communication hooks are currently only available for a NO_SHARD strategy.$'
        ):
            fsdp_model_with_hook.register_comm_hook(dummy_state, DummyHook.dummy_hook)
    else:
        # Check that the default hook is set to `all_reduce`
        for entry in FSDP.fsdp_modules(fsdp_model_with_hook):
            self.assertEqual(entry.communication_hook, default_hooks.allreduce_hook)
        dummy_state = DummyState(process_group=None)
        fsdp_model_with_hook.register_comm_hook(
            dummy_state, DummyHook.dummy_hook
        )
        # Check that we can't register a comm hook twice
        with self.assertRaisesRegex(
            AssertionError, '^communication hook can be only registered once$'
        ):
            fsdp_model_with_hook.register_comm_hook(
                dummy_state, DummyHook.dummy_hook
            )
        # Check that the dummy hook was registered for the root and all
        # submodules, if any
        for entry in FSDP.fsdp_modules(fsdp_model_with_hook):
            self.assertEqual(
                entry.communication_hook, DummyHook.dummy_hook
            )
            self.assertEqual(
                entry.communication_hook_state, dummy_state
            )
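# A minimal sketch, assuming `torch.distributed` is already initialized, of
# registering a communication hook on a NO_SHARD FSDP model as exercised by
# the test above. `HookState` and `noop_hook` are hypothetical stand-ins for a
# user-defined state and hook; `register_comm_hook` and
# `ShardingStrategy.NO_SHARD` come from the test itself.
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

class HookState:
    pass

def noop_hook(state: HookState, grad: torch.Tensor) -> None:
    # A real hook would communicate `grad` across ranks (e.g., all-reduce it).
    pass

model = FSDP(nn.Linear(4, 4).cuda(), sharding_strategy=ShardingStrategy.NO_SHARD)
# A hook may be registered only once per FSDP instance; a second call is
# expected to fail, mirroring the assertion in the test above.
model.register_comm_hook(HookState(), noop_hook)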
def _get_submodules(self, fsdp_net):
    return [
        submodule for submodule in FSDP.fsdp_modules(fsdp_net)
        if not submodule.check_is_root()
    ]