class Model(Module):
    def __init__(self, wrap_fsdp, cpu_offload=CPUOffload(offload_params=False)):
        super().__init__()
        # Keep everything deterministic for model initialization.
        torch.manual_seed(0)
        self.inner = Linear(2, 2).cuda()
        if wrap_fsdp:
            self.inner = FSDP(self.inner, cpu_offload=cpu_offload)
        self.outer = Linear(2, 2).cuda()
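
    # The forward pass is not part of this excerpt; the sketch below is a
    # minimal completion, assuming the inner layer simply feeds the outer one.
    def forward(self, x):
        return self.outer(self.inner(x))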


class TestPureFP16(FSDPTest):
    def _dist_train(self, wrap_fsdp, cpu_offload=CPUOffload(offload_params=False)):
        # Keep everything deterministic for input data.
        torch.manual_seed(0)

        model = Model(wrap_fsdp, cpu_offload)
        if wrap_fsdp:
            model = FSDP(model, cpu_offload=cpu_offload)
        else:
            model = DistributedDataParallel(model, device_ids=[self.rank])
        model.half()

        optim = SGD(model.parameters(), lr=0.1)
        in_data = torch.rand(16, 2).cuda().half()
        in_data.requires_grad = True
        for _ in range(1):
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        if wrap_fsdp:
            get_full_params(model)

        return list(model.parameters())

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    def test_pure_fp16(self, cpu_offload):
        # DDP
        ddp_state = self._dist_train(wrap_fsdp=False)

        # FSDP
        fsdp_state = self._dist_train(wrap_fsdp=True, cpu_offload=cpu_offload)

        self.assertEqual(ddp_state, fsdp_state)


# Helper presumably defined on the FSDPTest base class; shown standalone in
# this excerpt.
def _test_identical_outputs(
    self,
    model_init_fn,
    *args,
    ref_ddp_fn=None,
    num_steps=2,
    fsdp_init_mode=FSDPInitMode.CUDA_AFTER,
    lr=0.01,
    cpu_offload=CPUOffload(),
    backward_prefetch=None,
    **kwargs,
):
    group = dist.distributed_c10d._get_default_group()
    rank = group.rank()
    # Establish reference behavior with PyTorch DDP (+ optionally autocast).
    model = model_init_fn(group=group, wrap_fsdp=False).cuda()
    if ref_ddp_fn is None:
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[rank], output_device=rank
        )
    else:
        model = ref_ddp_fn(model)
    ref_loss = self._train_for_several_steps(
        model, num_steps, autocast=False, lr=lr, fsdp_cpu_offload=cpu_offload
    )
    ref_full_params = list(model.parameters())

    # Confirm we get the same behavior using FullyShardedDataParallel.
    try:
        model = model_init_fn(
            group=group,
            wrap_fsdp=True,
            fsdp_init_mode=fsdp_init_mode,
            cpu_offload=cpu_offload,
            backward_prefetch=backward_prefetch,
        )
    except Exception as e:
        raise ValueError(f"model_init_fn {model_init_fn} got error {str(e)}") from e

    cpu_offload = cpu_offload or CPUOffload()  # Disabled if not specified.
    model = FullyShardedDataParallel(
        model, cpu_offload=cpu_offload, backward_prefetch=backward_prefetch
    )
    # Call model.cuda() after FSDP init if specified.
    if fsdp_init_mode == FSDPInitMode.CUDA_AFTER:
        model = model.cuda()

    # Skip this check for FSDPInitMode.CUDA_AFTER with offloaded params, since
    # in that case we expect FSDP to raise the error asserted on below.
    if fsdp_init_mode != FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params:
        for p in model.parameters():
            # Should be on CPU regardless of whether the param is sharded.
            self.assertEqual(
                p.device,
                torch.device("cpu"),
                f"Mismatch, cpu offload is {cpu_offload}",
            )

    only_check_err = (
        fsdp_init_mode == FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params
    )
    ctx = (
        self.assertRaisesRegex(AssertionError, "Expected param to be on CPU")
        if only_check_err
        else suppress()
    )
    with ctx:
        shard_loss = self._train_for_several_steps(
            model,
            num_steps,
            autocast=False,
            lr=lr,
            fsdp_cpu_offload=cpu_offload,
        )
    # We only check for errors in the case we have the following setup:
    #   model = FSDP(model, cpu_offload=CPUOffload(offload_params=True))
    #   model = model.cuda()
    # so skip the rest of this logic in that case.
    if only_check_err:
        return

    # If offloading params to CPU, the next call moves model params to GPU.
    # Sanity check that params are on CPU beforehand.
    if cpu_offload.offload_params:
        device_set = {p.device for p in model.parameters()}
        self.assertEqual(
            {torch.device("cpu")}, device_set, f"Got device set {device_set}"
        )
    get_full_params(model)
    shard_full_params = list(model.parameters())

    if cpu_offload.offload_params:
        shard_loss = shard_loss.cuda()
    torch.testing.assert_allclose(ref_loss, shard_loss)
    self.assertEqual(
        ref_full_params,
        shard_full_params,
        exact_device=True,
        msg="FullyShardedDataParallel didn't match PyTorch DDP",
    )


class TestParityWithDDP(FSDPTest):
    """
    Compare losses and parameter values after several updates when using
    PyTorch DDP vs. FullyShardedDataParallel.
    """

    def _get_init_modes_for_test(self, cpu_offload):
        modes = [FSDPInitMode.CUDA_AFTER, FSDPInitMode.CUDA_BEFORE]
        # Note that FSDPInitMode.CUDA_NEVER currently works only with CPU
        # offload, as we explicitly bring the param back to the CUDA device
        # there. In general it will not work, since we try to all_gather
        # p.data, which is on CPU, but NCCL only supports GPU tensors.
        if cpu_offload.offload_params:
            modes.append(FSDPInitMode.CUDA_NEVER)
        return modes

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    @parametrize(
        "backward_prefetch",
        [BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None],
    )
    def test_nested_wrapped_model(self, cpu_offload, backward_prefetch):
        init_modes = self._get_init_modes_for_test(cpu_offload)
        for fsdp_init_mode in init_modes:
            with self.subTest(fsdp_init_mode=fsdp_init_mode):
                self._test_identical_outputs(
                    NestedWrappedModule,
                    fsdp_init_mode=fsdp_init_mode,
                    cpu_offload=cpu_offload,
                    backward_prefetch=backward_prefetch,
                )

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    @parametrize(
        "backward_prefetch",
        [BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None],
    )
    def test_nested_all_wrapped_model(self, cpu_offload, backward_prefetch):
        init_modes = self._get_init_modes_for_test(cpu_offload)
        for fsdp_init_mode in init_modes:
            with self.subTest(fsdp_init_mode=fsdp_init_mode):
                model_fn = functools.partial(
                    NestedWrappedModule, wrap_everything=True
                )
                self._test_identical_outputs(
                    model_fn,
                    fsdp_init_mode=fsdp_init_mode,
                    cpu_offload=cpu_offload,
                    backward_prefetch=backward_prefetch,
                )

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    @parametrize(
        "backward_prefetch",
        [BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None],
    )
    def test_transformer_parameterized(self, cpu_offload, backward_prefetch):
        init_modes = self._get_init_modes_for_test(cpu_offload)
        for fsdp_init_mode in init_modes:
            with self.subTest(fsdp_init_mode=fsdp_init_mode):
                self._test_identical_outputs(
                    TransformerWithSharedParams,
                    fsdp_init_mode=fsdp_init_mode,
                    cpu_offload=cpu_offload,
                    backward_prefetch=backward_prefetch,
                )

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    @parametrize(
        "backward_prefetch",
        [BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None],
    )
    def test_delayed_optim_step(self, cpu_offload, backward_prefetch):
        # We use a model with a long CUDA delay right before the optimizer
        # step. This tests our streams logic, and that we don't start the
        # all_gather until after the optimization step completes.
        init_modes = self._get_init_modes_for_test(cpu_offload)
        for fsdp_init_mode in init_modes:
            with self.subTest(fsdp_init_mode=fsdp_init_mode):
                model_fn = functools.partial(
                    NestedWrappedModuleWithDelay, delay_after_loss_ms=250
                )
                self._test_identical_outputs(
                    model_fn,
                    fsdp_init_mode=fsdp_init_mode,
                    cpu_offload=cpu_offload,
                    backward_prefetch=backward_prefetch,
                )

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    @parametrize(
        "backward_prefetch",
        [BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None],
    )
    def test_delayed_reduce_scatter(self, cpu_offload, backward_prefetch):
        # We insert a delay in the torch.distributed._reduce_scatter_base op,
        # so that the post_backward_stream takes much longer than the backward
        # pass. This tests that we properly block at the end of the backward
        # pass for the reductions to finish.
        init_modes = self._get_init_modes_for_test(cpu_offload)
        for fsdp_init_mode in init_modes:
            with self.subTest(fsdp_init_mode=fsdp_init_mode):
                model_fn = functools.partial(
                    NestedWrappedModuleWithDelay, delay_before_reduction_ms=250
                )
                self._test_identical_outputs(
                    model_fn,
                    fsdp_init_mode=fsdp_init_mode,
                    cpu_offload=cpu_offload,
                    backward_prefetch=backward_prefetch,
                )

    def _dummy_ddp_fn(self, model):
        return DummyDDP(model)

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    @parametrize(
        "backward_prefetch",
        [BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None],
    )
    def test_mixture_of_experts(self, cpu_offload, backward_prefetch):
        init_modes = self._get_init_modes_for_test(cpu_offload)
        for fsdp_init_mode in init_modes:
            with self.subTest(fsdp_init_mode=fsdp_init_mode):
                self._test_identical_outputs(
                    MixtureOfExperts,
                    # MixtureOfExperts implements custom reduce logic, so the
                    # reference behavior should use that logic instead of
                    # PyTorch DDP.
                    ref_ddp_fn=self._dummy_ddp_fn,
                    fsdp_init_mode=fsdp_init_mode,
                    cpu_offload=cpu_offload,
                    backward_prefetch=backward_prefetch,
                )

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    @parametrize(
        "backward_prefetch",
        [BackwardPrefetch_.BACKWARD_PRE, BackwardPrefetch_.BACKWARD_POST, None],
    )
    def test_mixture_of_experts_with_delay_before_free(
        self, cpu_offload, backward_prefetch
    ):
        init_modes = self._get_init_modes_for_test(cpu_offload)
        for fsdp_init_mode in init_modes:
            with self.subTest(fsdp_init_mode=fsdp_init_mode):
                model_fn = functools.partial(
                    MixtureOfExperts, delay_before_free_ms=250
                )
                self._test_identical_outputs(
                    model_fn,
                    ref_ddp_fn=self._dummy_ddp_fn,
                    fsdp_init_mode=fsdp_init_mode,
                    cpu_offload=cpu_offload,
                    backward_prefetch=backward_prefetch,
                )
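

# Test-harness sketch: an assumption that this module is driven by the
# standard torch.testing._internal.common_utils helpers, where
# instantiate_parametrized_tests materializes the @parametrize variants of
# each test class and run_tests runs the suite when the file is executed.
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
)

instantiate_parametrized_tests(TestPureFP16)
instantiate_parametrized_tests(TestParityWithDDP)

if __name__ == "__main__":
    run_tests()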