def test_inf_gradients_skip_optim_step(self):
    """Check that an inf gradient makes ``ShardedGradScaler.step`` skip the step.

    ``t0`` holds ``inf`` and is given an ``inf`` gradient, so the scaler's
    inf/NaN check should fire. ``SGD.step()`` called without a closure returns
    ``None`` whether or not it actually runs, so the return value alone cannot
    prove the step was skipped. The test therefore also checks that ``t0`` is
    left untouched: a real SGD update with ``lr=1.0`` would presumably compute
    ``inf - inf`` and turn the parameter into NaN.
    """
    pg = DummyProcessGroup(0, 1)
    scaler = ShardedGradScaler(init_scale=2.0, process_group=pg, enabled=True)
    loss = torch.full((1,), 4.0, dtype=torch.float32, device="cpu")
    t0 = torch.tensor([float('inf')], dtype=torch.float32, device="cpu")
    t0.grad = t0.clone()
    opt = torch.optim.SGD([t0], lr=1.0)
    # Result unused; presumably called so the scaler's internal scale tensor
    # exists before step() is invoked.
    scaler.scale(loss)
    ret_val = scaler.step(opt)
    self.assertIsNone(ret_val)
    # The real evidence of a skipped step: the parameter is still inf.
    self.assertEqual(t0.item(), float('inf'))
 def test_grad_scaling(self):
     """Check that ``ShardedGradScaler.scale`` scales every tensor in nested outputs.

     With ``init_scale=2.0``, each tensor inside the list / tuple / list
     nesting should come back doubled (4.0 -> 8.0, 8.0 -> 16.0). The scaler's
     scale tensor should be on the same device as the outputs.
     """
     pg = DummyProcessGroup(0, 1)
     scaler = ShardedGradScaler(init_scale=2.0, process_group=pg, enabled=True)
     t0 = torch.full((1,), 4.0, dtype=torch.float32, device="cpu")
     t1 = torch.full((1,), 8.0, dtype=torch.float32, device="cpu")
     outputs = [t1.clone(), (t0.clone(), t1.clone()), [t0.clone(), t1.clone()]]
     outputs = scaler.scale(outputs)
     # One assertion per element, so a failure names the offending entry and
     # both values instead of a bare "False is not true".
     self.assertEqual(outputs[0].item(), 16.0)
     self.assertEqual(outputs[1][0].item(), 8.0)
     self.assertEqual(outputs[1][1].item(), 16.0)
     self.assertEqual(outputs[2][0].item(), 8.0)
     self.assertEqual(outputs[2][1].item(), 16.0)
     self.assertEqual(scaler._scale.device, t1.device)
Пример #3
0
    def _train_for_several_steps(
        self,
        model: nn.Module,
        num_steps: int,
        autocast: bool,
        lr: float = 0.01,
        fsdp_cpu_offload: Optional[CPUOffload] = None,
        norm_type: Optional[Union[float, int]] = None,
        save_model: bool = False,
        mixed_precision: Optional[MixedPrecision] = None,
        enable_sharded_grad_scaler: bool = False,
        use_pure_fp16: bool = False,
    ):
        """Train ``model`` for ``num_steps`` SGD steps, checking invariants on the way.

        Each step does the following:

        1. Zero the gradients.
        2. Run the forward pass, under ``torch.cuda.amp.autocast`` when
           ``autocast`` is set.
        3. Scale the loss with a ``ShardedGradScaler`` and check its dtype.
        4. Run backward via ``model.module.run_backward``.
        5. Optionally clip gradients.
        6. Step the optimizer and update the scale through the scaler.

        Args:
            model: Model under test (FSDP-wrapped or not). It must expose
                ``model.module`` with ``get_input``, ``get_loss`` and
                ``run_backward``.
            num_steps: Number of training iterations to run.
            autocast: Whether the forward pass runs under CUDA autocast.
            lr: SGD learning rate. Momentum is fixed at 0.9.
            fsdp_cpu_offload: If it requests ``offload_params`` and ``model``
                is FSDP, parameters are asserted to be on CPU after both the
                forward and the backward pass.
            norm_type: If not ``None``, gradients are clipped to a max norm
                of 0.3 with this norm type. The resulting total norm is then
                asserted to be at most 0.3.
            save_model: If set, every step snapshots ``state_dict()``, zeroes
                the parameters and reloads the snapshot.
            mixed_precision: Mixed-precision config. For non-FSDP models a
                truthy value casts inputs to fp16. It also selects the
                expected loss dtype.
            enable_sharded_grad_scaler: Passed as ``enabled`` to the
                ``ShardedGradScaler``.
            use_pure_fp16: Cast inputs to fp16 and expect an fp16 loss.

        Returns:
            The loss of the last step, as returned by
            ``sharded_grad_scaler.scale``, detached from the graph.
        """
        cpu_offload_params = fsdp_cpu_offload and fsdp_cpu_offload.offload_params

        def _assert_params_on_cpu():
            # With CPU offloading, FSDP params should always be on CPU, even if
            # p._is_sharded=False.
            if cpu_offload_params and isinstance(model, FSDP):
                for p in model.parameters():
                    self.assertEqual(p.device, torch.device("cpu"))

        model_device = next(model.parameters()).device
        sharded_grad_scaler = ShardedGradScaler(
            enabled=enable_sharded_grad_scaler)
        # use SGD with momentum instead of Adam, since Adam is scale invariant
        # and this makes it bad for tests
        optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
        for _ in range(num_steps):
            optim.zero_grad()
            with torch.cuda.amp.autocast(enabled=autocast):
                # Inputs always cuda regardless of cpu offloading, or model.device
                inp = model.module.get_input(torch.device("cuda"))
                if use_pure_fp16 or (mixed_precision
                                     and not isinstance(model, FSDP)):
                    if isinstance(inp, torch.Tensor):
                        inp = inp.half()
                    else:
                        inp = tuple(x.half() for x in inp)
                # NOTE(review): a bare Tensor input would be unpacked along
                # dim 0 here; presumably get_input() returns a tuple in practice.
                output = model(*inp)
                # Post-forward, if CPU offloading model param should be on CPU.
                _assert_params_on_cpu()

                loss = model.module.get_loss(inp, output).to(model_device)
            loss = sharded_grad_scaler.scale(loss)

            if not mixed_precision and not use_pure_fp16:
                # TestCase assertion rather than a bare ``assert`` (stripped
                # under ``python -O``), matching the branches below.
                self.assertEqual(
                    loss.dtype,
                    torch.float32,
                    msg="loss data type should be float32, as the original "
                    "parameter data type is float32.",
                )
            elif use_pure_fp16:
                self.assertEqual(loss.dtype, torch.float16)
            # FSDP loss is fp16, DDP AMP loss is fp32
            elif isinstance(model, FSDP):
                self.assertEqual(loss.dtype, mixed_precision.param_dtype)
            else:
                self.assertEqual(loss.dtype, torch.float32)
            model.module.run_backward(loss)
            if norm_type is not None:
                # NOTE(review): clipping happens before sharded_grad_scaler.step(),
                # so with the scaler enabled these gradients are still scaled;
                # both FSDP and non-FSDP paths do the same, keeping parity.
                max_norm = 0.3
                if isinstance(model, FSDP):
                    model.clip_grad_norm_(max_norm, norm_type)
                    total_norm_after_clip = _collect_total_grad_norm_fsdp(
                        model, norm_type, self.rank)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   max_norm, norm_type)
                    total_norm_after_clip = _collect_total_grad_norm_local(
                        model, norm_type)
                self.assertLessEqual(total_norm_after_clip, max_norm)
            # Post-backward, if CPU offloading model params should be on CPU.
            _assert_params_on_cpu()
            # Unscale the gradients and step
            sharded_grad_scaler.step(optim)
            # Update the scale factor
            sharded_grad_scaler.update()
            # if save_model, simulate save + load.
            if save_model:
                state_dict = {
                    k: v.clone()
                    for k, v in model.state_dict().items()
                }
                # Zero params, if save/load state_dict did not work properly, this
                # would break the parity test with DDP.
                _zero_model(model)
                model.load_state_dict(state_dict)

        if isinstance(model, FSDP):
            model._assert_state(TrainingState_.IDLE)
        return loss.detach()
Пример #4
0
    def _run_test_mixed_precision_e2e(
        self,
        mp_config,
        cpu_offload,
        backward_prefetch,
        full_precision_param_dtype,
        sharding_strategy,
        sharded_grad_scaler,
    ):
        """End-to-end mixed-precision dtype checks on two test models.

        The models come from ``_get_simple_model`` and
        ``_get_simple_nested_model`` (presumably FSDP-wrapped, given the
        ``summon_full_params`` / ``params[0]._is_sharded`` usage below). For
        each model, three forward/backward/optimizer steps run while
        ``dist._reduce_scatter_base`` is patched to validate mixed-precision
        dtypes. Each step asserts the following:

        * buffer dtypes after forward, after backward and after ``state_dict``;
        * mixed-precision shards are freed, or never allocated when
          ``mp_config.param_dtype`` is ``None``;
        * the loss dtype matches ``mp_config.param_dtype``, or
          ``full_precision_param_dtype`` when that is ``None``;
        * params and grads stay in full precision;
        * ``summon_full_params`` and ``state_dict`` expose full-precision
          parameters.

        Args:
            mp_config: Config object exposing ``param_dtype`` and
                ``buffer_dtype``. Either may be ``None``.
            cpu_offload: CPU offload config. When ``offload_params`` is false
                the model is moved to CUDA.
            backward_prefetch: Forwarded to the model builders.
            full_precision_param_dtype: Expected dtype of parameters, grads and
                checkpointed params. It is also the dtype of the random inputs.
            sharding_strategy: Forwarded to the model builders.
            sharded_grad_scaler: Boolean passed as ``enabled`` to
                ``ShardedGradScaler`` (despite the name, not a scaler instance).
        """
        torch.cuda.set_device(self.rank)
        fsdp_models = [
            self._get_simple_model(param_dtype=full_precision_param_dtype,
                                   sharding_strategy=sharding_strategy,
                                   cpu_offload=cpu_offload,
                                   mixed_precision=mp_config,
                                   backward_prefetch=backward_prefetch),
            self._get_simple_nested_model(
                param_dtype=full_precision_param_dtype,
                sharding_strategy=sharding_strategy,
                cpu_offload=cpu_offload,
                mixed_precision=mp_config,
                backward_prefetch=backward_prefetch),
        ]
        for model in fsdp_models:
            # With param offloading the model is not moved to CUDA here.
            if not cpu_offload.offload_params:
                model.cuda()

            # Patch reduce_scatter to add validation for mixed precision types.
            # The real collective is handed to the validator as its first
            # argument (presumably so it can delegate after checking dtypes).
            orig_reduce_scatter = dist._reduce_scatter_base
            test_reduce_scatter = partial(
                self._reduce_scatter_base_validate_mp,
                orig_reduce_scatter,
                mp_config,
            )
            with patch_reduce_scatter(test_reduce_scatter,
                                      full_precision_param_dtype):
                scaler = ShardedGradScaler(enabled=sharded_grad_scaler)
                optim = torch.optim.Adam(model.parameters())

                for _ in range(3):
                    inp = torch.randn(3,
                                      10,
                                      device='cuda',
                                      dtype=full_precision_param_dtype)
                    # Forward pass of LinearMixedPrecision check casting of
                    # inputs, params, buffers.
                    # The model is called with one tuple bundling the input with
                    # the test case, the model and the configs; only the first
                    # output is kept.
                    act, *_ = model((inp, self, model, mp_config,
                                     full_precision_param_dtype))
                    # Buffers should be cast.
                    for buf in model.buffers():
                        if mp_config.buffer_dtype is not None:
                            self.assertEqual(buf.dtype, mp_config.buffer_dtype)
                        else:
                            self.assertEqual(buf.dtype, _BUFFER_ORIG_DTYPE)
                    # p._mp_shard should be freed.
                    if model.params[0]._is_sharded:  # i.e. world_size > 1
                        # TODO: free the mixed precision shard after forward
                        # when world_size == 1 as well, currently when
                        # world_size == 1 it is only freed after backward.
                        if mp_config.param_dtype is not None:
                            self._validate_mp_shard_freed(model)
                        else:
                            # We never should have allocated an _mp_shard.
                            self._validate_no_mp_shard(model)

                    loss = act.sum()
                    loss = scaler.scale(loss)
                    # The (scaled) loss keeps the reduced param dtype when one
                    # is configured, otherwise the full-precision dtype.
                    if mp_config.param_dtype is not None:
                        self.assertEqual(loss.dtype, mp_config.param_dtype)
                    else:
                        self.assertEqual(loss.dtype,
                                         full_precision_param_dtype)
                    # Will run patched reduce scatter that validates mixed_precision
                    # types in backward.
                    loss.backward()
                    # Buffers stay cast even after backward.
                    for buf in model.buffers():
                        if mp_config.buffer_dtype is not None:
                            self.assertEqual(buf.dtype, mp_config.buffer_dtype)
                        else:
                            self.assertEqual(buf.dtype, _BUFFER_ORIG_DTYPE)
                    # p._mp_shard should be freed.
                    if mp_config.param_dtype is not None:
                        self._validate_mp_shard_freed(model)
                    else:
                        self._validate_no_mp_shard(model)

                    # Ensure params and grads are in full precision,
                    # as after fwd/backward we maintain full precision shards.
                    for param in model.parameters():
                        self.assertEqual(param.dtype,
                                         full_precision_param_dtype)
                        if param.grad is not None:
                            self.assertEqual(param.grad.dtype,
                                             full_precision_param_dtype)

                    # Unscale the gradients and step
                    scaler.step(optim)
                    # Update the scale factor
                    scaler.update()

                    # Summon full params should be in full precision
                    with model.summon_full_params(model):
                        # It is not expected for summon_full_params to allocate
                        # a mixed precision shard.
                        if mp_config.param_dtype is not None:
                            self._validate_mp_shard_freed(model)
                        else:
                            self._validate_no_mp_shard(model)
                        params = list(model.parameters())
                        for p in params:
                            self.assertEqual(p.dtype,
                                             full_precision_param_dtype)

                        # Note that buffers are cast only once and only restored
                        # to the original buffer dtype in state_dict, so
                        # summon_full_params is not expected to restore buffer
                        # types to their original.
                        # named_buffers is also reused after this `with` block
                        # (a `with` does not open a new scope) to tell buffers
                        # apart from params in the state_dict below.
                        named_buffers = dict(model.named_buffers())
                        for v in named_buffers.values():
                            if mp_config.buffer_dtype is not None:
                                self.assertEqual(v.dtype,
                                                 mp_config.buffer_dtype)
                            else:
                                self.assertEqual(v.dtype, _BUFFER_ORIG_DTYPE)

                    # state_dict should be in full precision
                    state_dict = {
                        k: v.clone()
                        for k, v in model.state_dict().items()
                    }
                    for name, tensor in state_dict.items():
                        # Parameters and buffers are checkpointed in their
                        # original dtypes, which may be different.
                        # NOTE(review): assumes state_dict keys for buffers match
                        # the names from named_buffers() taken inside
                        # summon_full_params - confirm for the nested model.
                        if name in named_buffers.keys():
                            self.assertEqual(tensor.dtype, _BUFFER_ORIG_DTYPE)
                        else:
                            self.assertEqual(
                                tensor.dtype, full_precision_param_dtype,
                                f"{name}: {tensor.dtype} vs {full_precision_param_dtype}"
                            )

                    # After state_dict, buffer's dtype should have been restored
                    # to the mixed precision one.
                    for buf in model.buffers():
                        if mp_config.buffer_dtype is not None:
                            self.assertEqual(buf.dtype, mp_config.buffer_dtype)
                        else:
                            self.assertEqual(buf.dtype, _BUFFER_ORIG_DTYPE)