def test_inf_gradients_skip_optim_step(self):
    """Check that ``ShardedGradScaler.step`` skips the optimizer on inf grads.

    A single CPU parameter holding ``inf`` is given an ``inf`` gradient and
    handed to an SGD optimizer with ``lr=1.0``.  The scaler's ``step`` must
    return ``None`` and must leave the parameter untouched.
    """
    pg = DummyProcessGroup(0, 1)
    scaler = ShardedGradScaler(init_scale=2.0, process_group=pg, enabled=True)
    loss = torch.full((1,), 4.0, dtype=torch.float32, device="cpu")
    t0 = torch.tensor([float("inf")], dtype=torch.float32, device="cpu")
    t0.grad = t0.clone()
    opt = torch.optim.SGD([t0], lr=1.0)
    # The scaled loss is discarded on purpose; the call is made only for its
    # effect on the scaler (presumably it initializes the internal scale
    # tensor before ``step`` -- confirm against ShardedGradScaler).
    scaler.scale(loss)
    ret_val = scaler.step(opt)
    self.assertIsNone(ret_val)
    # ``SGD.step()`` without a closure also returns None, so the return value
    # alone cannot prove the step was skipped.  Had the update been applied,
    # the parameter would be ``inf - lr * inf`` == NaN rather than inf.
    self.assertTrue(torch.isinf(t0).all().item())
def test_grad_scaling(self):
    """Scale a nested structure of CPU tensors and verify every leaf.

    With ``init_scale=2.0`` a tensor of 4.0 must come back as 8.0 and a
    tensor of 8.0 as 16.0, whether it sits at the top level, inside a tuple
    or inside a list.  The scaler's scale tensor must live on the same
    device as the inputs.
    """
    process_group = DummyProcessGroup(0, 1)
    grad_scaler = ShardedGradScaler(
        init_scale=2.0, process_group=process_group, enabled=True
    )
    four = torch.full((1,), 4.0, dtype=torch.float32, device="cpu")
    eight = torch.full((1,), 8.0, dtype=torch.float32, device="cpu")

    # One bare tensor, one tuple and one list of tensors.
    scaled = grad_scaler.scale(
        [
            eight.clone(),
            (four.clone(), eight.clone()),
            [four.clone(), eight.clone()],
        ]
    )

    # Top-level tensor first, then the (4.0, 8.0) pairs in the tuple and list.
    self.assertTrue(scaled[0] == 16.0)
    for pair_index in (1, 2):
        self.assertTrue(scaled[pair_index][0] == 8.0)
        self.assertTrue(scaled[pair_index][1] == 16.0)

    self.assertTrue(grad_scaler._scale.device == eight.device)
def _train_for_several_steps(
    self,
    model: nn.Module,
    num_steps: int,
    autocast: bool,
    lr: float = 0.01,
    fsdp_cpu_offload: Optional[CPUOffload] = None,
    norm_type: Optional[Union[float, int]] = None,
    save_model: bool = False,
    mixed_precision: Optional[MixedPrecision] = None,
    enable_sharded_grad_scaler: bool = False,
    use_pure_fp16: bool = False,
):
    """Train ``model`` for ``num_steps`` SGD steps and return the last loss.

    Each step zeroes the gradients, runs forward under
    ``torch.cuda.amp.autocast``, scales the loss with a ``ShardedGradScaler``,
    checks the loss dtype, runs backward, optionally clips gradients, and
    then steps the optimizer through the scaler.

    Args:
        model: Module to train.  ``model.module`` must provide ``get_input``,
            ``get_loss`` and ``run_backward``.
        num_steps: Number of training iterations; must be at least 1 because
            the loss of the final iteration is returned.
        autocast: Passed as ``enabled`` to ``torch.cuda.amp.autocast``.
        lr: Learning rate for the SGD optimizer.
        fsdp_cpu_offload: When set and its ``offload_params`` is truthy, FSDP
            model parameters are asserted to be on CPU after forward and
            after backward.
        norm_type: When not ``None``, gradients are clipped to a max norm of
            0.3 with this norm type and the resulting norm is asserted.
        save_model: When true, every step is followed by a simulated
            save/load (clone ``state_dict``, zero the model, load it back).
        mixed_precision: Mixed precision config; its ``param_dtype`` is the
            expected loss dtype for FSDP models.
        enable_sharded_grad_scaler: Passed as ``enabled`` to
            ``ShardedGradScaler``.
        use_pure_fp16: When true, inputs are cast to half and the loss is
            expected to be ``torch.float16``.

    Returns:
        The detached loss from the final training step.

    Raises:
        ValueError: If ``num_steps`` is smaller than 1.
    """
    if num_steps < 1:
        # ``loss`` is only bound inside the loop; fail early and clearly
        # instead of with an UnboundLocalError at the final ``return``.
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    cpu_offload_params = fsdp_cpu_offload and fsdp_cpu_offload.offload_params
    model_device = next(model.parameters()).device
    is_fsdp = isinstance(model, FSDP)
    sharded_grad_scaler = ShardedGradScaler(enabled=enable_sharded_grad_scaler)
    # use SGD with momentum instead of Adam, since Adam is scale invariant
    # and this makes it bad for tests
    optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    for _ in range(num_steps):
        optim.zero_grad()
        # NOTE(review): the extent of this ``with`` block was reconstructed
        # from a flattened source (forward + loss inside, scaling outside)
        # -- confirm against the original formatting.
        with torch.cuda.amp.autocast(enabled=autocast):
            # Inputs always cuda regardless of cpu offloading, or model.device
            model_input = model.module.get_input(torch.device("cuda"))
            if use_pure_fp16 or (mixed_precision and not is_fsdp):
                if isinstance(model_input, torch.Tensor):
                    model_input = model_input.half()
                else:
                    model_input = tuple(x.half() for x in model_input)
            output = model(*model_input)
            # Post-forward, if CPU offloading model param should be on CPU.
            if cpu_offload_params and is_fsdp:
                for p in model.parameters():
                    # Params should always be on CPU, even if
                    # p._is_sharded=False
                    self.assertEqual(p.device, torch.device("cpu"))
            loss = model.module.get_loss(model_input, output).to(model_device)
        loss = sharded_grad_scaler.scale(loss)
        if not mixed_precision and not use_pure_fp16:
            self.assertEqual(
                loss.dtype,
                torch.float32,
                msg=(
                    "loss data type should be float32, as the original "
                    "parameter data type is float32."
                ),
            )
        elif use_pure_fp16:
            self.assertEqual(loss.dtype, torch.float16)
        # FSDP loss is fp16, DDP AMP loss is fp32
        elif is_fsdp:
            self.assertEqual(loss.dtype, mixed_precision.param_dtype)
        else:
            self.assertEqual(loss.dtype, torch.float32)
        model.module.run_backward(loss)
        if norm_type is not None:
            max_norm = 0.3
            if is_fsdp:
                model.clip_grad_norm_(max_norm, norm_type)
                total_norm_after_clip = _collect_total_grad_norm_fsdp(
                    model, norm_type, self.rank
                )
            else:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), max_norm, norm_type
                )
                total_norm_after_clip = _collect_total_grad_norm_local(
                    model, norm_type
                )
            self.assertTrue(total_norm_after_clip <= max_norm)
        # Post-backward, if CPU offloading model params should be on CPU.
        if cpu_offload_params and is_fsdp:
            for p in model.parameters():
                # Params should always be on CPU, even if
                # p._is_sharded=False
                self.assertEqual(p.device, torch.device("cpu"))
        # Unscale the gradients and step
        sharded_grad_scaler.step(optim)
        # Update the scale factor
        sharded_grad_scaler.update()
        # if save_model, simulate save + load.
        if save_model:
            state_dict = {k: v.clone() for k, v in model.state_dict().items()}
            # Zero params, if save/load state_dict did not work properly, this
            # would break the parity test with DDP.
            _zero_model(model)
            model.load_state_dict(state_dict)

    if is_fsdp:
        model._assert_state(TrainingState_.IDLE)
    return loss.detach()
def _run_test_mixed_precision_e2e(
    self,
    mp_config,
    cpu_offload,
    backward_prefetch,
    full_precision_param_dtype,
    sharding_strategy,
    sharded_grad_scaler,
):
    """Run an end-to-end mixed precision check on two FSDP models.

    Builds one model via ``_get_simple_model`` and one via
    ``_get_simple_nested_model``, then for each runs three iterations of
    forward, loss scaling, backward and optimizer step while the
    reduce-scatter collective is patched with a validating wrapper.  After
    each phase the dtypes of buffers, parameters, gradients, the loss and
    the ``state_dict`` tensors are asserted, together with the state of the
    mixed precision shard.

    Args:
        mp_config: Mixed precision config.  Its ``param_dtype`` and
            ``buffer_dtype`` attributes are read; ``None`` means the
            corresponding full precision / original dtype is expected.
        cpu_offload: CPU offload config; when its ``offload_params`` is
            falsy the model is moved to CUDA before training.
        backward_prefetch: Forwarded to the model builders.
        full_precision_param_dtype: dtype of the random input and the dtype
            parameters, gradients and checkpointed parameters must have.
        sharding_strategy: Forwarded to the model builders.
        sharded_grad_scaler: Passed as ``enabled`` to ``ShardedGradScaler``.
    """
    torch.cuda.set_device(self.rank)
    fsdp_models = [
        self._get_simple_model(param_dtype=full_precision_param_dtype,
                               sharding_strategy=sharding_strategy,
                               cpu_offload=cpu_offload,
                               mixed_precision=mp_config,
                               backward_prefetch=backward_prefetch),
        self._get_simple_nested_model(
            param_dtype=full_precision_param_dtype,
            sharding_strategy=sharding_strategy,
            cpu_offload=cpu_offload,
            mixed_precision=mp_config,
            backward_prefetch=backward_prefetch),
    ]
    for model in fsdp_models:
        if not cpu_offload.offload_params:
            model.cuda()
        # Patch reduce_scatter to add validation for mixed precision types.
        orig_reduce_scatter = dist._reduce_scatter_base
        test_reduce_scatter = partial(
            self._reduce_scatter_base_validate_mp,
            orig_reduce_scatter,
            mp_config,
        )
        with patch_reduce_scatter(test_reduce_scatter,
                                  full_precision_param_dtype):
            scaler = ShardedGradScaler(enabled=sharded_grad_scaler)
            optim = torch.optim.Adam(model.parameters())
            for _ in range(3):
                inp = torch.randn(3, 10, device='cuda',
                                  dtype=full_precision_param_dtype)
                # Forward pass of LinearMixedPrecision check casting of
                # inputs, params, buffers.
                # The test case and config travel with the input so that
                # the model's forward can run its own assertions.
                act, *_ = model((inp, self, model, mp_config,
                                 full_precision_param_dtype))
                # Buffers should be casted.
                for buf in model.buffers():
                    if mp_config.buffer_dtype is not None:
                        self.assertEqual(buf.dtype, mp_config.buffer_dtype)
                    else:
                        self.assertEqual(buf.dtype, _BUFFER_ORIG_DTYPE)
                # p._mp_shard should be freed.
                if model.params[0]._is_sharded:  # i.e. world_size > 1
                    # TODO: free the mixed precision shard after forward
                    # when world_size == 1 as well, currently when
                    # world_size == 1 it is only freed after backward.
                    if mp_config.param_dtype is not None:
                        self._validate_mp_shard_freed(model)
                    else:
                        # We never should have allocated an _mp_shard.
                        self._validate_no_mp_shard(model)
                loss = act.sum()
                loss = scaler.scale(loss)
                # The loss dtype follows the compute dtype: the mixed
                # precision param dtype when set, full precision otherwise.
                if mp_config.param_dtype is not None:
                    self.assertEqual(loss.dtype, mp_config.param_dtype)
                else:
                    self.assertEqual(loss.dtype, full_precision_param_dtype)
                # Will run patched reduce scatter that validates mixed_precision
                # types in backward.
                loss.backward()
                # Buffers stay casted even after backwards.
                for buf in model.buffers():
                    if mp_config.buffer_dtype is not None:
                        self.assertEqual(buf.dtype, mp_config.buffer_dtype)
                    else:
                        self.assertEqual(buf.dtype, _BUFFER_ORIG_DTYPE)
                # p._mp_shard should be freed.
                if mp_config.param_dtype is not None:
                    self._validate_mp_shard_freed(model)
                else:
                    self._validate_no_mp_shard(model)
                # Ensure params and grads are in full precision,
                # as after fwd/backward we maintain full precision shards.
                for param in model.parameters():
                    self.assertEqual(param.dtype, full_precision_param_dtype)
                    if param.grad is not None:
                        self.assertEqual(param.grad.dtype,
                                         full_precision_param_dtype)
                # Unscale the gradients and step
                scaler.step(optim)
                # Update the scale factor
                scaler.update()
                # Summon full params should be in full precision
                # NOTE(review): the extent of this ``with`` block was
                # reconstructed from a flattened source; confirm that the
                # state_dict checks further down belong outside of it.
                with model.summon_full_params(model):
                    # It is not expected for summon_full_params to allocate
                    # a mixed precision shard.
                    if mp_config.param_dtype is not None:
                        self._validate_mp_shard_freed(model)
                    else:
                        self._validate_no_mp_shard(model)
                    params = list(model.parameters())
                    for p in params:
                        self.assertEqual(p.dtype, full_precision_param_dtype)
                    # Note that buffers are cast only once and only restored
                    # to the original buffer dtype in state_dict, so
                    # summon_full_params is not expected to restore buffer
                    # types to their original.
                    # named_buffers is reused below to tell buffers apart
                    # from parameters in the state_dict.
                    named_buffers = dict(model.named_buffers())
                    for v in named_buffers.values():
                        if mp_config.buffer_dtype is not None:
                            self.assertEqual(v.dtype, mp_config.buffer_dtype)
                        else:
                            self.assertEqual(v.dtype, _BUFFER_ORIG_DTYPE)
                # state_dict should be in full precision
                state_dict = {
                    k: v.clone()
                    for k, v in model.state_dict().items()
                }
                for name, tensor in state_dict.items():
                    # Parameters and buffers are checkpointed in their
                    # original dtypes, which may be different.
                    if name in named_buffers.keys():
                        self.assertEqual(tensor.dtype, _BUFFER_ORIG_DTYPE)
                    else:
                        self.assertEqual(
                            tensor.dtype, full_precision_param_dtype,
                            f"{name}: {tensor.dtype} vs {full_precision_param_dtype}"
                        )
                # After state_dict, buffer's dtype should have been restored
                # to the mixed precision one.
                for buf in model.buffers():
                    if mp_config.buffer_dtype is not None:
                        self.assertEqual(buf.dtype, mp_config.buffer_dtype)
                    else:
                        self.assertEqual(buf.dtype, _BUFFER_ORIG_DTYPE)