def test_mixed_precision_resnet(self):
    """
    End-to-end test to ensure mixed precision + auto_wrap works for a ResNet model.
    """
    resnet_model = torchvision.models.resnet50().cuda()
    resnet_model = nn.SyncBatchNorm.convert_sync_batchnorm(
        resnet_model,
        process_group=dist.distributed_c10d._get_default_group(),
    )
    n_bn = sum(
        1 if isinstance(x, _BatchNorm) else 0 for x in resnet_model.modules()
    )
    inp = torch.ones(1, 3, 1000, 1000, device='cuda')
    mp_config = MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    )
    fsdp = FSDP(
        resnet_model,
        auto_wrap_policy=size_based_auto_wrap_policy,
        mixed_precision=mp_config,
    )
    # Batchnorm units should be wrapped individually. Validate this by
    # ensuring there are as many FSDP units wrapping BatchNorm as there are
    # BatchNorm units in the original ResNet model.
    fsdp_bn = 0
    for module in fsdp.fsdp_modules(fsdp):
        wrapped_module = module.module.module
        if isinstance(wrapped_module, _BatchNorm):
            fsdp_bn += 1
    self.assertEqual(fsdp_bn, n_bn)
    # Would throw a type mismatch issue without mixed precision autowrapping.
    loss = fsdp(inp).sum()
    loss.backward()
def test_save_and_load_after_forward_state_dict(self, mixed_precision):
    """
    Test that saving after some training results in params being updated as
    expected.
    """
    torch.cuda.set_device(self.rank)
    mixed_precision = MixedPrecision() if mixed_precision else None
    model = self._get_simple_nested_model(mixed_precision=mixed_precision)
    optim = torch.optim.SGD(model.parameters(), lr=0.1)
    initial_params = _get_full_detached_param(model)
    for _ in range(6):
        inp = torch.randn(1, 10, device=torch.cuda.current_device())
        output = model(*inp)
        loss = output.sum()
        expected_dtype = torch.float32 if mixed_precision is None else torch.float16
        self.assertEqual(expected_dtype, loss.dtype)
        loss.backward()
        optim.step()
    trained_params = _get_full_detached_param(model)
    # Ensure some training occurred
    self.assertNotEqual(initial_params, trained_params)
    # Save a copy of the state_dict
    state_dict = {k: v.clone() for k, v in model.state_dict().items()}
    _zero_model(model)
    # Ensure checkpointed params have the full param dtype
    for tensor in state_dict.values():
        self.assertEqual(tensor.dtype, torch.float32)
    # Load state_dict into the zeroed model
    model.load_state_dict(state_dict)
    loaded_params = _get_full_detached_param(model)
    self.assertEqual(loaded_params, trained_params)
def test_transformer_no_grad(self, mixed_precision):
    """Tests that for an FSDP-wrapped transformer model with shared
    parameters, after training for one iteration, running a forward pass in
    ``eval()`` mode gives the same output as running a forward pass in
    ``torch.no_grad()``."""
    fsdp_kwargs = {}
    if mixed_precision:
        fsdp_kwargs["mixed_precision"] = MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
        )
    else:
        fsdp_kwargs["mixed_precision"] = None
    fsdp_model = TransformerWithSharedParams.init(
        self.process_group,
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_AFTER,
        fsdp_kwargs,
    )
    self._train_for_several_steps(
        fsdp_model,
        num_steps=1,
        autocast=False,
        mixed_precision=fsdp_kwargs["mixed_precision"],
    )
    input = fsdp_model.module.get_input(torch.device("cuda"))
    # Run a forward in eval mode
    fsdp_model.eval()
    ref_output = fsdp_model(*input)
    # Run a forward in `no_grad()` and compare
    with torch.no_grad():
        no_grad_output = fsdp_model(*input)
    self.assertEqual(ref_output, no_grad_output)
def test_param_change_after_init(self, mixed_precision):
    """
    Tests that changes made in-place to FSDP model parameter values after
    FSDP initialization persist.
    """
    # Establish reference behavior
    fsdp_kwargs = {}
    if mixed_precision:
        fsdp_kwargs["mixed_precision"] = MixedPrecision()
    fsdp_model = TransformerWithSharedParams.init(
        self.process_group,
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_AFTER,
        fsdp_kwargs,
        deterministic=True,
    )
    input = fsdp_model.module.get_input(torch.device("cuda"))
    ref_output = fsdp_model(*input)
    # Initialize the same model but change its first parameter value
    # in-place after FSDP initialization
    new_fsdp_model = TransformerWithSharedParams.init(
        self.process_group,
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_AFTER,
        fsdp_kwargs,
        deterministic=True,
    )
    first_param = next(new_fsdp_model.parameters())
    nn.init.normal_(first_param.data)
    new_output = new_fsdp_model(*input)
    self.assertNotEqual(
        ref_output,
        new_output,
        msg="new_output did not reflect change to param after init",
    )
def test_summon_full_param_recursive(self, recurse, summon_outer, mixed_precision):
    mixed_precision = MixedPrecision() if mixed_precision else None
    model = FSDP(
        nn.Sequential(
            FSDP(nn.Linear(5, 5, bias=False), mixed_precision=mixed_precision),
            nn.Linear(5, 3, bias=False),
        ),
        mixed_precision=mixed_precision,
    ).cuda(self.rank)
    global_inner_numel = self.get_model_param_count(nn.Linear(5, 5, bias=False))
    global_outer_numel = self.get_model_param_count(nn.Linear(5, 3, bias=False))
    shard_inner_numel = int(math.ceil(global_inner_numel / self.world_size))
    shard_outer_numel = int(math.ceil(global_outer_numel / self.world_size))
    outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
    inner_param = model.get_parameter(
        "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
    )
    self.assertEqual(shard_outer_numel, outer_param.numel())
    self.assertEqual(shard_inner_numel, inner_param.numel())

    model_to_summon = model if summon_outer else model[0]
    # The outer param is summoned if `summon_full_params()` is called on the
    # outer FSDP module
    expected_outer_numel = global_outer_numel if summon_outer else shard_outer_numel
    # The inner param is summoned if `summon_full_params()` is called with
    # recursion or on the inner FSDP module
    expected_inner_numel = (
        global_inner_numel if recurse or not summon_outer else shard_inner_numel
    )
    with model_to_summon.summon_full_params(model_to_summon, recurse=recurse):
        self.assertEqual(expected_outer_numel, outer_param.numel())
        self.assertEqual(expected_inner_numel, inner_param.numel())
def test_mp_embedding_default(self):
    default_mp_config = MixedPrecision(
        param_dtype=torch.float16,
        buffer_dtype=torch.float16,
        reduce_dtype=torch.float16,
    )
    self._test_mixed_precision_embedding_table(mp_config=default_mp_config)
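# The all-float16 config built above is repeated verbatim in several tests in
# this file. A small helper like the following (hypothetical; `_full_fp16_mp_config`
# is not an existing helper in the test suite) is one way it could be built in a
# single place.
def _full_fp16_mp_config() -> MixedPrecision:
    # Cast parameters, gradient-reduction communication, and buffers to fp16.
    return MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    )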
def test_transformer_no_grad(self, mixed_precision):
    group = dist.distributed_c10d._get_default_group()
    mixed_precision = (
        MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
        )
        if mixed_precision
        else None
    )
    config = {"mixed_precision": mixed_precision}
    model = self._get_wrapped_model(group, config=config, cuda_first=False)
    # Train model for a step
    self._train_for_several_steps(
        model,
        num_steps=1,
        autocast=False,
        mixed_precision=config["mixed_precision"],
    )
    model.eval()  # no dropout for this test
    # Eval in standard mode (i.e., without no_grad)
    input = model.module.get_input(torch.device("cuda"))
    ref_output = model(*input)
    # Eval with no_grad and compare
    with torch.no_grad():
        no_grad_output = model(*input)
    self.assertEqual(ref_output, no_grad_output)
def test_params_count_and_value(self, rank0_only, offload_to_cpu, mixed_precision):
    mixed_precision = MixedPrecision() if mixed_precision else None
    fsdp_model = FSDP(
        NestedWrappedModule(
            group=dist.distributed_c10d._get_default_group(),
            wrap_fsdp=True,
            fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
            mixed_precision=mixed_precision,
        ),
        mixed_precision=mixed_precision,
    )
    model = NestedWrappedModule(
        group=dist.distributed_c10d._get_default_group(),
        wrap_fsdp=False,
        fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
    )
    dev = (
        torch.device("cpu")
        if offload_to_cpu
        else torch.device("cuda", torch.cuda.current_device())
    )
    params_to_compare = (
        [p.to(dev) for p in model.module.parameters()]
        if not rank0_only or self.rank == 0
        else [p.clone() for p in fsdp_model.parameters()]
    )
    with fsdp_model.summon_full_params(
        fsdp_model, rank0_only=rank0_only, writeback=not rank0_only
    ):
        for p1, p2 in itertools.zip_longest(
            fsdp_model.parameters(), params_to_compare
        ):
            self.assertEqual(p1, p2)

    # CPU offload should restore the param device
    param = next(fsdp_model.parameters())
    self.assertTrue(
        param.device == torch.device("cuda", torch.cuda.current_device())
    )
def test_mp_embedding_only_params_and_bufs(self):
    self._test_mixed_precision_embedding_table(
        mp_config=MixedPrecision(
            param_dtype=torch.float16,
            buffer_dtype=torch.float16,
        )
    )
def test_mp_embedding_params_and_reduce_diff(self):
    params_and_reduce_different = MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float32,
        buffer_dtype=torch.float16,
    )
    self._test_mixed_precision_embedding_table(
        mp_config=params_and_reduce_different
    )
def test_summon_full_param_shard_value(self, mixed_precision):
    mixed_precision = MixedPrecision() if mixed_precision else None
    raw_model = nn.Linear(10, 11)
    raw_model_size = self.get_model_param_count(raw_model)
    expected_shard_size = self.get_expected_sharded_size(raw_model_size)
    model = FSDP(raw_model.cuda(self.rank), mixed_precision=mixed_precision)
    self.assertEqual(expected_shard_size, self.get_model_param_count(model))

    # We are assuming a single flattened param
    self.assertEqual(1, len(list(model.parameters())))

    my_shard = torch.clone(next(model.parameters()))

    with model.summon_full_params(model):
        self.assertEqual(raw_model_size, self.get_model_param_count(model))
        parameters = list(model.parameters())
        all_shards = FlatParamHandle.flatten_params(parameters, requires_grad=False)
        my_slice = torch.chunk(all_shards, self.world_size)[self.rank]
        # Shards are padded but the full_param tensor is not
        a, b = my_shard[0:my_slice.numel()], my_slice
        self.assertTrue(torch.equal(a.cpu(), b.cpu()))
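# A minimal sketch of the sharding arithmetic the assertions above rely on (not
# part of the test suite; `expected_shard_numel` is a hypothetical name). Each
# rank's shard of the flat parameter is padded so that all shards have the same
# size, ceil(total_numel / world_size), with the trailing shard carrying any
# padding.
import math

def expected_shard_numel(total_numel: int, world_size: int) -> int:
    # Size of one rank's (padded) shard of the flat parameter.
    return int(math.ceil(total_numel / world_size))

# e.g. nn.Linear(10, 11) has 10 * 11 + 11 = 121 elements; on 2 ranks each shard
# holds ceil(121 / 2) = 61 elements.
assert expected_shard_numel(121, 2) == 61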
def test_summon_full_params_respects_reshard_after_forward(
    self, mixed_precision
):
    mixed_precision = MixedPrecision() if mixed_precision else None
    model = FSDP(
        nn.Sequential(
            FSDP(nn.Linear(5, 5, bias=False), mixed_precision=mixed_precision),
            nn.Linear(5, 3, bias=False),
        ),
        mixed_precision=mixed_precision,
    ).cuda(self.rank)
    outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
    inner_param = model.get_parameter(
        "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
    )
    outer_full_param_size = outer_param.numel() * self.world_size

    # Trigger lazy init
    model(torch.zeros(5).cuda(self.rank))

    # The root FSDP module keeps all params around
    self.assertEqual(
        outer_full_param_size, outer_param._full_param_padded.storage().size()
    )
    self.assertEqual(0, inner_param._full_param_padded.storage().size())

    # Similarly, summon_full_params should have the same behavior
    with model.summon_full_params(model):
        pass
    self.assertEqual(
        outer_full_param_size, outer_param._full_param_padded.storage().size()
    )
    self.assertEqual(0, inner_param._full_param_padded.storage().size())
def test_params_count_and_value(
    self,
    rank0_only: bool,
    offload_to_cpu: bool,
    mixed_precision: bool,
):
    mixed_precision = MixedPrecision() if mixed_precision else None
    model = NestedWrappedModule.init(
        self.process_group,
        FSDPInitMode.NO_FSDP,
        CUDAInitMode.CUDA_BEFORE,
        deterministic=True,
    )
    fsdp_model = NestedWrappedModule.init(
        self.process_group,
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_BEFORE,
        deterministic=True,
    )
    dev = (
        torch.device("cpu")
        if offload_to_cpu
        else torch.device("cuda", torch.cuda.current_device())
    )
    params_to_compare = (
        [p.to(dev) for p in model.module.parameters()]
        if not rank0_only or self.rank == 0
        else [p.clone() for p in fsdp_model.parameters()]
    )
    with FSDP.summon_full_params(
        fsdp_model, rank0_only=rank0_only, writeback=not rank0_only
    ):
        for p1, p2 in itertools.zip_longest(
            fsdp_model.parameters(), params_to_compare
        ):
            self.assertEqual(p1, p2)

    # CPU offload should restore the param device
    param = next(fsdp_model.parameters())
    self.assertTrue(
        param.device == torch.device("cuda", torch.cuda.current_device())
    )
def test_nested_wrapped_model_single_iteration_mixed_precision(
    self,
    cpu_offload: CPUOffload,
    sharding_strategy: Optional[ShardingStrategy],
    mixed_precision: bool,
):
    init_modes = self._get_init_modes_for_test(cpu_offload)
    mixed_precision = (
        MixedPrecision(
            param_dtype=torch.float16,
            buffer_dtype=torch.float16,
            reduce_dtype=torch.float16,
        )
        if mixed_precision
        else None
    )
    for cuda_init_mode in init_modes:
        with self.subTest(cuda_init_mode=cuda_init_mode):
            self._test_fsdp_parity(
                NestedWrappedModule,
                FSDPInitMode.RECURSIVE,
                # Only run one step for comparison, as usually grad scaler
                # is needed to avoid NaN after first step.
                num_iters=1,
                cuda_init_mode=cuda_init_mode,
                cpu_offload=cpu_offload,
                sharding_strategy=sharding_strategy,
                mixed_precision=mixed_precision,
            )
def test_reshard_outside_forward_backward_iteration(
    self, rank0_only, offload_to_cpu, mixed_precision
):
    mixed_precision = MixedPrecision() if mixed_precision else None
    model = FSDP(
        nn.Sequential(
            FSDP(nn.Linear(5, 5, bias=False), mixed_precision=mixed_precision),
            nn.Linear(5, 1, bias=False),
        ),
        mixed_precision=mixed_precision,
    ).cuda(self.rank)
    outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
    inner_param = model.get_parameter(
        "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
    )
    outer_full_param_size = outer_param.numel() * self.world_size

    # First, validate our assumptions about resharding
    output = model(torch.zeros(5).cuda(self.rank))
    # The root FSDP module keeps all params around
    self.assertEqual(
        outer_full_param_size, outer_param._full_param_padded.storage().size()
    )
    self.assertEqual(0, inner_param._full_param_padded.storage().size())

    output.backward()
    # We reshard everything after backward() finishes
    self.assertEqual(0, outer_param._full_param_padded.storage().size())
    self.assertEqual(0, inner_param._full_param_padded.storage().size())

    # Now repeat it with summon_full_params() called in between
    output = model(torch.zeros(5).cuda(self.rank))
    self.assertEqual(
        outer_full_param_size, outer_param._full_param_padded.storage().size()
    )
    self.assertEqual(0, inner_param._full_param_padded.storage().size())
    with model.summon_full_params(
        model,
        rank0_only=rank0_only,
        writeback=not rank0_only,
        offload_to_cpu=offload_to_cpu,
    ):
        pass
    self.assertEqual(
        outer_full_param_size, outer_param._full_param_padded.storage().size()
    )
    self.assertEqual(0, inner_param._full_param_padded.storage().size())

    output.backward()
    with model.summon_full_params(
        model,
        rank0_only=rank0_only,
        writeback=not rank0_only,
        offload_to_cpu=offload_to_cpu,
    ):
        pass
    self.assertEqual(0, outer_param._full_param_padded.storage().size())
    self.assertEqual(0, inner_param._full_param_padded.storage().size())
def test_summon_full_param_writeback(
    self, writeback, cpu_offload, mixed_precision, modify_outer
):
    mixed_precision = MixedPrecision() if mixed_precision else None
    return _run_test_summon_full_param_writeback(
        self,
        writeback,
        modify_outer,
        cpu_offload=cpu_offload,
        mixed_precision=mixed_precision,
    )
def test_summon_full_param_writeback(self, writeback, modify_outer, mixed_precision):
    mixed_precision = MixedPrecision() if mixed_precision else None
    return _run_test_summon_full_param_writeback(
        self,
        writeback,
        modify_outer=modify_outer,
        cpu_offload=CPUOffload(offload_params=False),
        mixed_precision=mixed_precision,
    )
def test_mp_batchnorm(self, convert_sync_bn):
    class BatchNormNet(nn.Module):
        def __init__(self, affine=True):
            super(BatchNormNet, self).__init__()
            self.fc1 = nn.Linear(2, 40, bias=False)
            self.bn = nn.BatchNorm1d(4, affine=affine)
            self.fc2 = nn.Linear(40, 4, bias=False)

        def forward(self, x):
            x = torch.reshape(self.fc1(x), (-1, 4, 10))
            x = self.bn(x)
            x = torch.reshape(x, (-1, 40))
            x = self.fc2(x)
            return F.softmax(x, dim=1)

    def never_wrap_policy(*args, **kwargs):
        return False

    net = BatchNormNet().cuda()
    if convert_sync_bn:
        net = nn.SyncBatchNorm.convert_sync_batchnorm(net)
    # FSDP detects that mixed precision + batchnorm will cause issues
    # and thus wraps batchnorm in a distinct FSDP unit that does not
    # use mixed precision.
    mp_config = MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    )
    with self.assertWarnsRegex(
        expected_warning=UserWarning,
        expected_regex="BatchNorm units will be wrapped as a separate",
    ):
        model = FSDP(
            net,
            mixed_precision=mp_config,
            auto_wrap_policy=never_wrap_policy,
        )

    bn = model.bn
    self.assertTrue(isinstance(bn, FSDP))
    # The policy should not have wrapped any other submodules
    self.assertFalse(isinstance(model.fc1, FSDP))
    self.assertFalse(isinstance(model.fc2, FSDP))
    self.assertEqual(None, bn.mixed_precision)
    self.assertNotEqual(None, model.mixed_precision)

    inp = torch.randn((1, 2), device='cuda')
    # Without the FSDP BN mixed precision fix, this would result in
    # "RuntimeError: Expected counts to have type Half but got Float"
    # for SyncBatchNorm.
    model(inp).sum().backward()
def test_save_and_load_after_forward_state_dict(
    self, mixed_precision, state_dict_rank0_and_offload
):
    """
    Test that saving after some training results in params being updated as
    expected.
    """
    torch.cuda.set_device(self.rank)
    mixed_precision = (
        MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
        )
        if mixed_precision
        else None
    )
    model = self._get_simple_nested_model(mixed_precision=mixed_precision)
    optim = torch.optim.SGD(model.parameters(), lr=0.1)
    initial_params = _get_full_detached_param(model)
    for _ in range(6):
        inp = torch.randn(1, 10, device=torch.cuda.current_device())
        output = model(*inp)
        loss = output.sum()
        expected_dtype = torch.float32 if mixed_precision is None else torch.float16
        self.assertEqual(expected_dtype, loss.dtype)
        loss.backward()
        optim.step()
    trained_params = _get_full_detached_param(model)
    # Ensure some training occurred
    self.assertNotEqual(initial_params, trained_params)
    # Save a copy of the state_dict
    fsd_mgr = self._get_full_state_dict_mgr(model, state_dict_rank0_and_offload)
    with fsd_mgr:
        state_dict = {k: v.clone() for k, v in model.state_dict().items()}
    self._validate_state_dict_contents(state_dict, state_dict_rank0_and_offload)
    _zero_model(model)
    # Ensure checkpointed params have the full param dtype
    for tensor in state_dict.values():
        self.assertEqual(tensor.dtype, torch.float32)
    # Load state_dict into the zeroed model
    if state_dict_rank0_and_offload:
        # Broadcast the state dict and move it back to GPU in
        # preparation for loading.
        state_dict = self._broadcast_state_dict(state_dict)
        for key in state_dict.keys():
            state_dict[key] = state_dict[key].cuda()
    model.load_state_dict(state_dict)
    loaded_params = _get_full_detached_param(model)
    self.assertEqual(loaded_params, trained_params)
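# A minimal sketch (not the `_broadcast_state_dict` helper used above, whose
# implementation is test-harness specific) of how a rank0-only full state_dict
# could be broadcast to all ranks before `load_state_dict`. Assumes the default
# process group is initialized and that rank 0 holds the populated state_dict;
# `broadcast_full_state_dict` is a hypothetical name.
import torch.distributed as dist

def broadcast_full_state_dict(state_dict, src: int = 0):
    # Wrap the (possibly empty) state_dict in a single-element list so that
    # broadcast_object_list can overwrite it in place on non-src ranks.
    obj_list = [state_dict if dist.get_rank() == src else None]
    dist.broadcast_object_list(obj_list, src=src)
    return obj_list[0]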
def test_nested_wrapped_model_single_iteration_mixed_precision(
    self, cpu_offload, sharding_strategy, mixed_precision
):
    init_modes = self._get_init_modes_for_test(cpu_offload)
    mixed_precision = MixedPrecision() if mixed_precision else None
    for fsdp_init_mode in init_modes:
        with self.subTest(fsdp_init_mode=fsdp_init_mode):
            self._test_identical_outputs(
                NestedWrappedModule,
                # Only run one step for comparison, as usually grad scaler
                # is needed to avoid NaN after first step.
                num_steps=1,
                fsdp_init_mode=fsdp_init_mode,
                cpu_offload=cpu_offload,
                sharding_strategy=sharding_strategy,
                mixed_precision=mixed_precision,
            )
def test_register_functions_called(self, cuda_first, mixed_precision):
    """Tests that _register_{pre|post}_backward_hooks are called during forward."""
    group = dist.distributed_c10d._get_default_group()
    mixed_precision = MixedPrecision() if mixed_precision else None
    config = {"mixed_precision": mixed_precision}
    model = self._get_wrapped_model(group, config=config, cuda_first=cuda_first)
    input = model.module.get_input(torch.device("cuda"))
    model._register_post_backward_hooks = mock.MagicMock(return_value=None)
    model._register_pre_backward_hooks = mock.MagicMock(return_value=None)
    self.assertFalse(model._register_post_backward_hooks.called)
    self.assertFalse(model._register_pre_backward_hooks.called)
    model(*input)
    self.assertTrue(model._register_post_backward_hooks.called)
    self.assertTrue(model._register_pre_backward_hooks.called)
def test_scaler_enabled(self, cpu_offload, sharding_strategy, mixed_precision):
    init_modes = self._get_init_modes_for_test(cpu_offload)
    mp = (
        MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
        )
        if mixed_precision
        else None
    )
    for fsdp_init_mode in init_modes:
        self._test_identical_outputs(
            NestedWrappedModule,
            fsdp_init_mode=fsdp_init_mode,
            cpu_offload=cpu_offload,
            sharding_strategy=sharding_strategy,
            mixed_precision=mp,
            enable_sharded_grad_scaler=True,
        )
def test_mixed_precision_embedding_table(self):
    # Basic test to ensure int inputs are not cast, which would break
    # modules such as embedding tables.
    mp_config = MixedPrecision()
    model = self._get_wrapped_model(
        group=torch.distributed.distributed_c10d._get_default_group(),
        config={"mixed_precision": mp_config},
    )
    optim = torch.optim.SGD(model.parameters(), lr=0.1)
    for _ in range(6):
        inp = model.module.get_input(torch.device("cuda"))
        # This would fail if we cast integer module inputs such as those for
        # embedding tables.
        output = model(*inp)
        loss = model.module.get_loss(inp, output).cuda()
        self.assertEqual(loss.dtype, mp_config.param_dtype)
        model.module.run_backward(loss)
        optim.step()
def test_nested_wrapped_model_single_iteration_mixed_precision(
    self,
    cpu_offload: CPUOffload,
    sharding_strategy: Optional[ShardingStrategy],
):
    mixed_precision = MixedPrecision(
        param_dtype=torch.float16,
        buffer_dtype=torch.float16,
        reduce_dtype=torch.float16,
    )
    self.run_subtests(
        self._get_subtest_config(cpu_offload),
        self._test_fsdp_parity,
        NestedWrappedModule,
        FSDPInitMode.RECURSIVE,
        cpu_offload=cpu_offload,
        sharding_strategy=sharding_strategy,
        num_iters=1,
        mixed_precision=mixed_precision,
    )
def test_register_functions_called(self, cuda_first: bool, mixed_precision: bool):
    """Tests that ``_register_{pre|post}_backward_hooks()`` are called during
    the FSDP forward."""
    fsdp_kwargs = {}
    if mixed_precision:
        fsdp_kwargs["mixed_precision"] = MixedPrecision()
    fsdp_model = TransformerWithSharedParams.init(
        self.process_group,
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_BEFORE if cuda_first else CUDAInitMode.CUDA_AFTER,
        fsdp_kwargs,
    )
    input = fsdp_model.module.get_input(torch.device("cuda"))
    fsdp_model._register_pre_backward_hooks = mock.MagicMock(return_value=None)
    fsdp_model._register_post_backward_hooks = mock.MagicMock(return_value=None)
    self.assertFalse(fsdp_model._register_post_backward_hooks.called)
    self.assertFalse(fsdp_model._register_pre_backward_hooks.called)
    fsdp_model(*input)
    self.assertTrue(fsdp_model._register_post_backward_hooks.called)
    self.assertTrue(fsdp_model._register_pre_backward_hooks.called)
def test_params_are_unflattenned(self, rank0_only, offload_to_cpu, mixed_precision):
    layer_shape = (10, 12)
    model = nn.Linear(*layer_shape, bias=False).cuda(self.rank)
    mixed_precision = MixedPrecision() if mixed_precision else None
    fsdp_model = FSDP(deepcopy(model), mixed_precision=mixed_precision).cuda(
        self.rank
    )

    def _get_flat_param():
        return fsdp_model.get_parameter("_fsdp_wrapped_module.flat_param")

    flattened_param = _get_flat_param()
    self.assertEqual(layer_shape[0] * layer_shape[1] / 2, flattened_param.numel())

    with fsdp_model.summon_full_params(
        fsdp_model,
        rank0_only=rank0_only,
        writeback=not rank0_only,
        offload_to_cpu=offload_to_cpu,
    ):
        if self.rank == 0 or not rank0_only:
            self.assertEqual(fsdp_model.weight.shape, model.weight.shape)
            expected_device = (
                torch.device("cpu")
                if offload_to_cpu
                else torch.device("cuda", torch.cuda.current_device())
            )
            self.assertTrue(expected_device == fsdp_model.weight.device)
        else:
            # A nonzero rank with rank0_only maintains the original params.
            flat_within_ctx = _get_flat_param()
            self.assertEqual(flat_within_ctx, flattened_param)
            self.assertEqual(
                flat_within_ctx.device, torch.device(torch.cuda.current_device())
            )

    # CPU offload should restore the param device
    param = next(fsdp_model.parameters())
    self.assertTrue(
        param.device == torch.device("cuda", torch.cuda.current_device())
    )
def test_fsdp_ddp_parity_with_grad_scaler(
    self,
    cpu_offload: CPUOffload,
    sharding_strategy: Optional[ShardingStrategy],
    mixed_precision: Optional[str],
):
    init_modes = self._get_init_modes_for_test(cpu_offload)
    mp = (
        MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
        )
        if mixed_precision is not None
        else None
    )
    for cuda_init_mode in init_modes:
        self._test_fsdp_parity(
            NestedWrappedModule,
            FSDPInitMode.RECURSIVE,
            cuda_init_mode=cuda_init_mode,
            cpu_offload=cpu_offload,
            sharding_strategy=sharding_strategy,
            mixed_precision=mp,
            enable_sharded_grad_scaler=True,
        )
def _check_low_precision_hook(
    self, state, hook, sharding_strategy, dtype, has_wrapping
):
    # Keep everything deterministic for input data
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    fsdp_with_hook = self._init_model(
        Net(has_wrapping=has_wrapping, sharding_strategy=sharding_strategy),
        sharding_strategy=sharding_strategy,
    )
    fsdp_with_hook.register_comm_hook(state, hook)

    mp_only_grad = MixedPrecision(reduce_dtype=dtype)
    fsdp_with_mp = self._init_model(
        Net(
            has_wrapping=has_wrapping,
            sharding_strategy=sharding_strategy,
            mixed_precision=mp_only_grad,
        ),
        sharding_strategy=sharding_strategy,
        mixed_precision=mp_only_grad,
    )

    optim_hook = torch.optim.SGD(fsdp_with_hook.parameters(), lr=0.1)
    optim_mp = torch.optim.SGD(fsdp_with_mp.parameters(), lr=0.1)

    in_data = torch.rand(16, 8).cuda()
    fsdp_with_hook.train()
    fsdp_with_mp.train()
    loss_hook = fsdp_with_hook(in_data).sum()
    loss_mp = fsdp_with_mp(in_data).sum()
    loss_hook.backward()
    # Make sure grads were cast to the parameter's precision
    self.assertEqual(fsdp_with_hook.params[0].dtype, state.parameter_type)
    loss_mp.backward()
    optim_hook.step()
    optim_mp.step()

    dist.barrier()

    for hook_param, mp_param in zip(
        fsdp_with_hook.parameters(), fsdp_with_mp.parameters()
    ):
        self.assertEqual(hook_param.grad, mp_param.grad)
def test_param_change_after_init(self, mixed_precision):
    group = dist.distributed_c10d._get_default_group()
    # Establish reference behavior.
    mixed_precision = MixedPrecision() if mixed_precision else None
    config = {"mixed_precision": mixed_precision}
    model = self._get_wrapped_model(group, config=config, cuda_first=False)
    model.eval()  # no dropout for this test
    input = model.module.get_input(torch.device("cuda"))
    ref_output = model(*input)
    # Change the weights in place.
    model = self._get_wrapped_model(group, cuda_first=False)
    model.eval()  # no dropout for this test
    first_param = next(model.parameters())
    nn.init.normal_(first_param.data)
    new_output = model(*input)
    self.assertNotEqual(
        ref_output,
        new_output,
        msg="new_output did not reflect change to param after init",
    )
def test_mp_embedding_reduce(self):
    self._test_mixed_precision_embedding_table(
        mp_config=MixedPrecision(reduce_dtype=torch.float16)
    )
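# For reference (descriptive only, not itself a test): the three MixedPrecision
# knobs exercised throughout the tests above are
#   param_dtype  - precision that parameters are cast to for forward/backward compute
#   reduce_dtype - precision used for gradient reduction communication
#   buffer_dtype - precision that buffers (e.g. BatchNorm running stats) are cast to
# so the config in test_mp_embedding_reduce pins down only the gradient
# communication precision and leaves the other knobs at their defaults.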