def __init__(self, group, wrap_fsdp, delay_before_free_ms=0):
    super().__init__(group, wrap_fsdp)
    self.group = group
    self.delay_before_free_ms = delay_before_free_ms
    self.wrap_fsdp = wrap_fsdp

    # "expert" params are different on each rank
    torch.manual_seed(42 + group.rank())
    d_expert = 23
    d_shared = 12
    d_input = 8
    expert = nn.Linear(d_expert, d_shared)

    self.num_expert_params = sum(p.numel() for p in expert.parameters())
    for p in expert.parameters():
        p.expert = True  # type: ignore[attr-defined]

    # everything else is shared
    torch.manual_seed(0)
    shared = nn.Linear(d_shared, d_expert)

    if wrap_fsdp:
        # we create a process group of size 1 for the expert params
        expert_group = torch.distributed.new_group([group.rank()])  # world size 1 means no shard
        expert = FullyShardedDataParallel(expert, expert_group)  # type: ignore[assignment]
        shared = FullyShardedDataParallel(shared, group)  # type: ignore[assignment]

    self.module = nn.Sequential(
        nn.Linear(d_input, d_shared),
        shared,
        expert,
        nn.Linear(d_shared, d_input),
    )
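# A minimal sketch of the wrapping trick above: a world-size-1 process group
# makes FSDP leave the rank-local "expert" unsharded, while shared layers
# shard over the full group. Layer sizes and the helper name are illustrative.
def _expert_wrap_sketch(group):
    expert_group = torch.distributed.new_group([group.rank()])  # this rank only
    expert = FullyShardedDataParallel(nn.Linear(23, 12), expert_group)
    shared = FullyShardedDataParallel(nn.Linear(12, 23), group)
    return expert, shared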
def test_multiple_wrapping(self):
    """
    This test simulates wrapping the module after training to run inference.
    It covers the case where, later in a session, the model is wrapped again
    in FSDP while it still contains nested FSDP wrappers from the original
    wrapping.
    """
    inner_model = InnerModel()
    model = FSDP(inner_model).cuda()
    optim = SGD(model.parameters(), lr=0.1)

    for _ in range(3):
        input = torch.rand((1, 5), dtype=torch.float).cuda()
        input.requires_grad = True
        output = model(input)
        output.sum().backward()
        optim.step()
        optim.zero_grad()
    input = torch.rand((1, 5), dtype=torch.float).cuda()
    output = model(input)

    # second time to rewrap the inner model
    rewrapped_model = FSDP(inner_model).cuda()
    rewrapped_output = rewrapped_model(input)

    self.assertEqual(output, rewrapped_output)
def test_input_type(self, input_cls):
    """Test FSDP with input being a list or a dict, only single GPU."""

    class Model(Module):
        def __init__(self):
            super().__init__()
            self.layer = Linear(4, 4)

        def forward(self, input):
            if isinstance(input, list):
                input = input[0]
            else:
                assert isinstance(input, dict), input
                input = input["in"]
            return self.layer(input)

    model = FSDP(Model()).cuda()
    optim = SGD(model.parameters(), lr=0.1)

    for _ in range(5):
        in_data = torch.rand(64, 4).cuda()
        in_data.requires_grad = True
        if input_cls is list:
            in_data = [in_data]
        else:
            self.assertTrue(input_cls is dict)
            in_data = {"in": in_data}

        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()
def _get_wrapped_model(
    self, group, cuda_first=False, **model_kwargs
) -> FullyShardedDataParallel:
    if cuda_first:
        # Move the module to GPU before wrapping with FSDP.
        model = FullyShardedDataParallel(
            TransformerWithSharedParams(group, **model_kwargs).cuda(), group
        )
    else:
        # Wrap with FSDP first, then move to GPU.
        model = FullyShardedDataParallel(
            TransformerWithSharedParams(group, **model_kwargs), group
        ).cuda()
    return model
def _dist_train(self, with_checkpoint, expected, model_hidden_dim, iterations):
    gpu_id = self.rank
    world_size = self.world_size

    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = create_model(
        with_fsdp=True,
        with_checkpoint=with_checkpoint,
        model_hidden_dim=model_hidden_dim,
    )
    model = model.cuda()
    model = FSDP(model)

    # We enable momentum so that after the first iteration, the optimizer
    # state is added to the total memory used.
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

    results = {}  # results of memory stats
    for iteration in range(iterations):
        get_cur_mem(gpu_id, results, f"iter {iteration}: start")

        out = model(batch)
        get_cur_mem(gpu_id, results, f"iter {iteration}: after fwd")

        out = sum(o.sum() for o in out[0])
        fake_loss = criterion(out, torch.tensor(0.0).cuda())
        get_cur_mem(gpu_id, results, f"iter {iteration}: after loss")

        fake_loss.backward()
        get_cur_mem(gpu_id, results, f"iter {iteration}: after bwd")

        optimizer.step()
        get_cur_mem(gpu_id, results, f"iter {iteration}: after step")

        # It is important to use `set_to_none` below, not optimizer.zero_grad(),
        # to reclaim memory.
        model.zero_grad(set_to_none=True)
        get_cur_mem(gpu_id, results, f"iter {iteration}: done")

    def cmp(results, expected):
        ret = ""
        self.assertEqual(results.keys(), expected.keys())
        for k, v in results.items():
            exp = expected[k]
            if abs(exp - v) > 1:  # allow 1MB rounding differences
                ret += f"{k}: got {v}, expected {exp}\n"
        return ret

    output = cmp(results, expected)
    self.assertEqual(output, "")
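# `get_cur_mem` is defined elsewhere in this suite; the following is only a
# sketch of the behavior `_dist_train` assumes of it (whole-MB readings, so
# the 1MB tolerance in `cmp` is meaningful), not the actual helper:
def _get_cur_mem_sketch(gpu_id, results, name):
    # Record the CUDA memory currently allocated on this device, rounded to MB.
    results[name] = round(torch.cuda.memory_allocated(gpu_id) / 1024 / 1024)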
def _test_identical_outputs(
    self, model_init_fn, ref_ddp_fn=None, num_steps=2, use_cuda=True, lr=0.01
):
    group = dist.distributed_c10d._get_default_group()
    rank = group.rank()
    # Establish reference behavior with PyTorch DDP (+ optionally autocast).
    model = model_init_fn(group=group, wrap_fsdp=False).cuda()
    if ref_ddp_fn is None:
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[rank], output_device=rank
        )
    else:
        model = ref_ddp_fn(model)
    ref_loss = self._train_for_several_steps(model, num_steps, autocast=False, lr=lr)
    ref_full_params = list(model.parameters())

    # Confirm we get the same behavior using FullyShardedDataParallel.
    model = model_init_fn(group=group, wrap_fsdp=True)
    model = FullyShardedDataParallel(model)
    if use_cuda:
        model = model.cuda()
    else:
        assert next(model.parameters()).device == torch.device(
            "cpu"
        ), "module parameters should be placed on cpu if use_cuda is False."
    shard_loss = self._train_for_several_steps(model, num_steps, autocast=False, lr=lr)
    get_full_params(model)
    shard_full_params = list(model.parameters())

    torch.testing.assert_allclose(ref_loss, shard_loss)
    self.assertEqual(
        ref_full_params,
        shard_full_params,
        exact_device=True,
        msg="FullyShardedDataParallel didn't match PyTorch DDP",
    )
def _dist_train(
    self, with_nested_trunk, freezing_method, freeze_after_wrap_fsdp, with_fsdp
):
    torch.manual_seed(0)
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = self._create_model(with_fsdp, with_nested_trunk, freeze_after_wrap_fsdp)
    model = model.cuda()

    # Freezing the trunk using requires_grad.
    if freezing_method == FreezingMethod.RequiresGrad:
        for param in model.trunk.parameters():
            param.requires_grad = False

    if with_fsdp:
        if not freeze_after_wrap_fsdp:
            model.fsdp_wrap()
        model = FSDP(model)
    else:
        model = DistributedDataParallel(model, device_ids=[self.rank])

    target = torch.tensor([0, 1], dtype=torch.long).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    for iteration in range(3):
        out = model(batch)
        fake_loss = criterion(out, target)
        optimizer.zero_grad()
        fake_loss.backward()
        if freezing_method == FreezingMethod.GradToNone:
            if with_fsdp:
                for param in model.module.module.trunk.parameters():
                    param.grad = None
            else:
                for param in model.module.trunk.parameters():
                    param.grad = None
        optimizer.step()

    if with_fsdp:
        get_full_params(model)

    return list(model.parameters())
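# A minimal side-by-side of the two freezing strategies compared above; the
# `layer` argument and helper name are illustrative. RequiresGrad stops
# autograd from computing trunk gradients at all, while GradToNone lets
# backward run and then drops the gradients so optimizer.step() leaves the
# trunk untouched.
def _freeze_sketch(layer, method):
    if method == FreezingMethod.RequiresGrad:
        for p in layer.parameters():
            p.requires_grad = False  # freeze before training
    elif method == FreezingMethod.GradToNone:
        for p in layer.parameters():
            p.grad = None  # drop grads after backward, before the step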
def _maybe_wrap_fsdp(model, wrap_fsdp, *args, **kwargs):
    return model if not wrap_fsdp else FullyShardedDataParallel(model, *args, **kwargs)
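# A usage sketch for `_maybe_wrap_fsdp` (the group and layers are placeholders):
# the helper gives tests a single call site that works for both wrapped and
# unwrapped runs.
def _example_maybe_wrap_usage(group):
    plain = _maybe_wrap_fsdp(nn.Linear(4, 4), False)  # returned as-is
    wrapped = _maybe_wrap_fsdp(nn.Linear(4, 4), True, group)  # FSDP(model, group)
    return plain, wrapped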
def _test_identical_outputs(
    self,
    model_init_fn,
    *args,
    ref_ddp_fn=None,
    num_steps=2,
    fsdp_init_mode=FSDPInitMode.CUDA_AFTER,
    lr=0.01,
    cpu_offload=CPUOffload(),
    backward_prefetch=None,
    **kwargs
):
    group = dist.distributed_c10d._get_default_group()
    rank = group.rank()
    # Establish reference behavior with PyTorch DDP (+ optionally autocast).
    model = model_init_fn(group=group, wrap_fsdp=False).cuda()
    if ref_ddp_fn is None:
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[rank], output_device=rank
        )
    else:
        model = ref_ddp_fn(model)
    ref_loss = self._train_for_several_steps(
        model, num_steps, autocast=False, lr=lr, fsdp_cpu_offload=cpu_offload
    )
    ref_full_params = list(model.parameters())

    # Confirm we get the same behavior using FullyShardedDataParallel.
    try:
        model = model_init_fn(
            group=group,
            wrap_fsdp=True,
            fsdp_init_mode=fsdp_init_mode,
            cpu_offload=cpu_offload,
            backward_prefetch=backward_prefetch,
        )
    except Exception as e:
        raise ValueError(f"model_init_fn {model_init_fn} got error {str(e)}")

    cpu_offload = cpu_offload or CPUOffload()  # disabled if not specified.
    model = FullyShardedDataParallel(
        model, cpu_offload=cpu_offload, backward_prefetch=backward_prefetch
    )
    # Call model.cuda() after init FSDP if specified.
    if fsdp_init_mode == FSDPInitMode.CUDA_AFTER:
        model = model.cuda()

    # Note that we don't do this check for FSDPInitMode.CUDA_AFTER since, in
    # the param-offload case, we expect the FSDP code to raise the error that
    # we check below.
    if fsdp_init_mode != FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params:
        for p in model.parameters():
            # Should be on CPU regardless of whether the param is sharded.
            self.assertEqual(
                p.device, torch.device("cpu"), f"Mismatch, cpu offload is {cpu_offload}"
            )

    only_check_err = (
        fsdp_init_mode == FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params
    )
    ctx = (
        self.assertRaisesRegex(AssertionError, "Expected param to be on CPU")
        if only_check_err
        else suppress()
    )
    with ctx:
        shard_loss = self._train_for_several_steps(
            model, num_steps, autocast=False, lr=lr, fsdp_cpu_offload=cpu_offload
        )
    # We only check for errors in the case we have the following setup:
    # model = FSDP(model, cpu_offload=True)
    # model = model.cuda()
    # so skip the rest of this logic.
    if only_check_err:
        return
    # If CPU offload, the next call will move model params to GPU. Sanity
    # check that params are on CPU before.
    if cpu_offload.offload_params:
        device_set = {p.device for p in model.parameters()}
        self.assertEqual(
            {torch.device("cpu")}, device_set, f"Got device set {device_set}"
        )
    get_full_params(model)
    shard_full_params = list(model.parameters())

    if cpu_offload.offload_params:
        shard_loss = shard_loss.cuda()
    torch.testing.assert_allclose(ref_loss, shard_loss)
    self.assertEqual(
        ref_full_params,
        shard_full_params,
        exact_device=True,
        msg="FullyShardedDataParallel didn't match PyTorch DDP",
    )
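# A minimal illustration of the two offload configurations exercised above,
# with nn.Linear standing in for the test models and the helper name being
# illustrative; `CPUOffload(offload_params=True)` keeps the sharded
# parameters on CPU between uses.
def _cpu_offload_sketch(offload_params):
    cfg = CPUOffload(offload_params=offload_params)
    return FullyShardedDataParallel(nn.Linear(4, 4), cpu_offload=cfg)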
def _maybe_wrap(layer):
    # `wrap_fsdp`, `group`, `args`, and `kwargs` are closed over from the
    # enclosing function.
    if wrap_fsdp:
        return FullyShardedDataParallel(layer, group, *args, **kwargs)
    return layer
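# A sketch of the enclosing pattern that `_maybe_wrap` relies on (the builder
# name and two-layer module are illustrative only): the nested helper captures
# the wrapping decision once and is applied per layer.
def _build_nested_module(group, wrap_fsdp, *args, **kwargs):
    def _maybe_wrap(layer):
        if wrap_fsdp:
            return FullyShardedDataParallel(layer, group, *args, **kwargs)
        return layer

    return nn.Sequential(_maybe_wrap(nn.Linear(8, 4)), _maybe_wrap(nn.Linear(4, 8)))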