def _test_identical_outputs(self, model_init_fn, ref_ddp_fn=None, num_steps=2, use_cuda=True, lr=0.01):
    group = dist.distributed_c10d._get_default_group()
    rank = group.rank()
    # Establish reference behavior with PyTorch DDP (+ optionally autocast).
    model = model_init_fn(group=group, wrap_fsdp=False).cuda()
    if ref_ddp_fn is None:
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[rank], output_device=rank
        )
    else:
        model = ref_ddp_fn(model)
    ref_loss = self._train_for_several_steps(model, num_steps, autocast=False, lr=lr)
    ref_full_params = list(model.parameters())

    # Confirm we get the same behavior using FullyShardedDataParallel.
    model = model_init_fn(group=group, wrap_fsdp=True)
    model = FullyShardedDataParallel(model)
    if use_cuda:
        model = model.cuda()
    else:
        assert next(model.parameters()).device == torch.device(
            "cpu"
        ), "module parameters should be placed on cpu if use_cuda is False."
    shard_loss = self._train_for_several_steps(model, num_steps, autocast=False, lr=lr)
    get_full_params(model)
    shard_full_params = list(model.parameters())
    torch.testing.assert_allclose(ref_loss, shard_loss)
    self.assertEqual(
        ref_full_params,
        shard_full_params,
        exact_device=True,
        msg="FullyShardedDataParallel didn't match PyTorch DDP",
    )
def _test_identical_outputs(
    self,
    model_init_fn,
    *args,
    ref_ddp_fn=None,
    num_steps=2,
    fsdp_init_mode=FSDPInitMode.CUDA_AFTER,
    lr=0.01,
    cpu_offload=CPUOffload(),
    backward_prefetch=None,
    **kwargs
):
    group = dist.distributed_c10d._get_default_group()
    rank = group.rank()
    # Establish reference behavior with PyTorch DDP (+ optionally autocast).
    model = model_init_fn(group=group, wrap_fsdp=False).cuda()
    if ref_ddp_fn is None:
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[rank], output_device=rank
        )
    else:
        model = ref_ddp_fn(model)
    ref_loss = self._train_for_several_steps(
        model, num_steps, autocast=False, lr=lr, fsdp_cpu_offload=cpu_offload
    )
    ref_full_params = list(model.parameters())

    # Confirm we get the same behavior using FullyShardedDataParallel.
    try:
        model = model_init_fn(
            group=group,
            wrap_fsdp=True,
            fsdp_init_mode=fsdp_init_mode,
            cpu_offload=cpu_offload,
            backward_prefetch=backward_prefetch,
        )
    except Exception as e:
        raise ValueError(f"model_init_fn {model_init_fn} got error {str(e)}")

    cpu_offload = cpu_offload or CPUOffload()  # disabled if not specified.
    model = FullyShardedDataParallel(
        model, cpu_offload=cpu_offload, backward_prefetch=backward_prefetch
    )
    # Call model.cuda() after FSDP init if specified.
    if fsdp_init_mode == FSDPInitMode.CUDA_AFTER:
        model = model.cuda()

    # Note that we don't do this check for FSDPInitMode.CUDA_AFTER since, when
    # params are offloaded, we expect FSDP to raise the error checked below.
    if fsdp_init_mode != FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params:
        for p in model.parameters():
            # Should be on CPU regardless of whether the param is sharded.
            self.assertEqual(
                p.device, torch.device("cpu"), f"Mismatch, cpu offload is {cpu_offload}"
            )

    only_check_err = fsdp_init_mode == FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params
    ctx = (
        self.assertRaisesRegex(AssertionError, "Expected param to be on CPU")
        if only_check_err
        else suppress()
    )
    with ctx:
        shard_loss = self._train_for_several_steps(
            model,
            num_steps,
            autocast=False,
            lr=lr,
            fsdp_cpu_offload=cpu_offload,
        )
    # We only check for errors in the case we have the following setup:
    #   model = FSDP(model, cpu_offload=True)
    #   model = model.cuda()
    # so skip the rest of this logic.
    if only_check_err:
        return
    # If CPU offload is enabled, the next call will move model params to GPU.
    # Sanity check that params are on CPU before that happens.
    if cpu_offload.offload_params:
        device_set = {p.device for p in model.parameters()}
        self.assertEqual(
            {torch.device("cpu")}, device_set, f"Got device set {device_set}"
        )
    get_full_params(model)
    shard_full_params = list(model.parameters())

    if cpu_offload.offload_params:
        shard_loss = shard_loss.cuda()
    torch.testing.assert_allclose(ref_loss, shard_loss)
    self.assertEqual(
        ref_full_params,
        shard_full_params,
        exact_device=True,
        msg="FullyShardedDataParallel didn't match PyTorch DDP",
    )
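

# Hedged usage sketch (not from this snippet): how a test built on the helper
# above might call it. "TestFSDPParity" and "NestedModel" are hypothetical
# names, and the FSDPTest base class (providing process-group setup and the
# unittest assertions used above) is an assumption; FSDPInitMode, CPUOffload,
# and the keyword arguments mirror the helper's signature.
class TestFSDPParity(FSDPTest):
    def test_nested_model_parity(self):
        # model_init_fn is expected to accept (group, wrap_fsdp,
        # fsdp_init_mode, cpu_offload, backward_prefetch), as called above.
        self._test_identical_outputs(
            NestedModel,
            fsdp_init_mode=FSDPInitMode.CUDA_AFTER,
            cpu_offload=CPUOffload(),  # offload disabled; full parity check runs
        )

    def test_nested_model_offload_error(self):
        # With CUDA_AFTER + offload_params=True, the helper only asserts that
        # FSDP raises "Expected param to be on CPU" during training.
        self._test_identical_outputs(
            NestedModel,
            fsdp_init_mode=FSDPInitMode.CUDA_AFTER,
            cpu_offload=CPUOffload(offload_params=True),
        )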