def test_basic_save_and_load_state_dict(
    self, cpu_offload, fp16, state_dict_rank0_and_offload
):
    """
    Tests that we can save a state_dict and load it into a blank model with
    various configs such as fp16 and CPU offload, and that parameters match
    as expected.
    """
    for model_call in [
        partial(self._get_simple_nested_model, cpu_offload=cpu_offload),
        partial(self._get_simple_model, cpu_offload=cpu_offload),
    ]:
        model = model_call()
        full_state_dict_mgr = self._get_full_state_dict_mgr(
            model, state_dict_rank0_and_offload
        )
        with full_state_dict_mgr:
            fsdp_state_dict = _get_state_dict(
                model, cpu_offload.offload_params, fp16
            )

        self._validate_state_dict_contents(
            fsdp_state_dict, state_dict_rank0_and_offload
        )
        if fp16:
            # Verify that fp16 is the dtype.
            for tensor in fsdp_state_dict.values():
                self.assertEqual(tensor.dtype, torch.float16)

        model_new = model_call()
        if not cpu_offload.offload_params:
            model_new = model_new.cuda()
        if fp16:
            model_new.half()

        # Zero the model to ensure parameters are different.
        _zero_model(model_new)

        with FullyShardedDataParallel.summon_full_params(model):
            with FullyShardedDataParallel.summon_full_params(model_new):
                params = list(model.parameters())
                params_new = list(model_new.parameters())
                self.assertNotEqual(params, params_new)

        # Verify parameters are the same in the new model.
        if state_dict_rank0_and_offload:
            # Broadcast the state dict and move it back to GPU in
            # preparation for loading.
            fsdp_state_dict = self._broadcast_state_dict(fsdp_state_dict)
            for key in fsdp_state_dict.keys():
                fsdp_state_dict[key] = fsdp_state_dict[key].cuda()

        model_new.load_state_dict(fsdp_state_dict)
        with FullyShardedDataParallel.summon_full_params(model_new):
            with FullyShardedDataParallel.summon_full_params(model):
                params = list(model.parameters())
                params_new = list(model_new.parameters())
                self.assertEqual(params, params_new)
                if fp16:
                    for tensor in model_new.parameters():
                        self.assertEqual(tensor.dtype, torch.float16)
def _compare_models(self, model, model_new, assert_fn, check_fp16=False):
    """Compares the parameters of ``model`` and ``model_new`` with
    ``assert_fn`` while their full (unsharded) parameters are summoned,
    optionally checking that ``model_new`` holds fp16 parameters."""
    with FullyShardedDataParallel.summon_full_params(model):
        with FullyShardedDataParallel.summon_full_params(model_new):
            params = list(model.parameters())
            params_new = list(model_new.parameters())
            assert_fn(params, params_new)
            if check_fp16:
                for tensor in model_new.parameters():
                    self.assertEqual(tensor.dtype, torch.float16)
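# Hedged usage sketch: how the save/load tests above could be expressed in
# terms of the `_compare_models` helper. The surrounding names (`model_call`,
# `fsdp_state_dict`, `fp16`) are assumptions borrowed from
# test_basic_save_and_load_state_dict, and this method name is hypothetical;
# it is not part of the helper's contract.
def _example_compare_models_usage(self, model_call, fsdp_state_dict, fp16):
    model = model_call()
    model_new = model_call()
    _zero_model(model_new)
    # Before loading, the zeroed copy must differ from the trained model.
    self._compare_models(model, model_new, self.assertNotEqual)
    model_new.load_state_dict(fsdp_state_dict)
    # After loading, parameters must match (and be fp16 when requested).
    self._compare_models(model, model_new, self.assertEqual, check_fp16=fp16)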
def test_auto_wrap_smoke_test(self, cuda_init_mode, cpu_offload, use_device_id):
    # CPU offload and CUDA_AFTER don't work together as expected.
    if (cpu_offload.offload_params
            and cuda_init_mode == CUDAInitMode.CUDA_AFTER):
        return

    device = torch.device("cuda")
    torch.cuda.set_device(0)
    device_id = (torch.device("cuda", torch.cuda.current_device())
                 if use_device_id else None)

    # Use a random port in case the next test runs quickly; reusing the same
    # port would cause a conflict.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(find_free_port())

    file_name = tempfile.NamedTemporaryFile(delete=False).name
    torch.distributed.init_process_group(
        backend="nccl",
        init_method=f"{FILE_SCHEMA}_{file_name}",
        rank=0,
        world_size=1,
    )

    # NOTE: We move the model to CUDA after wrapping with FSDP to simulate
    # real use cases where the full model cannot be loaded onto GPU, but its
    # shards can.
    cuda_after_init = cuda_init_mode == CUDAInitMode.CUDA_AFTER
    try:
        sequential = TestFSDPWrap.NestedSequentialModel.get_model(
            cuda=(not cuda_after_init))
        my_auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy, min_num_params=40)
        model = FSDP(sequential,
                     cpu_offload=cpu_offload,
                     auto_wrap_policy=my_auto_wrap_policy,
                     device_id=device_id)
        TestFSDPWrap.NestedSequentialModel.verify_model(self, model)
        if cuda_after_init:
            model = model.cuda()
        input = torch.rand((1, 5), dtype=torch.float).to(device)
        output = model(input)
        loss = F.mse_loss(input, output)
        loss.backward()
    finally:
        torch.distributed.destroy_process_group()

    try:
        os.remove(file_name)
    except FileNotFoundError:
        pass
def test_state_dict_load_into_local_module(self):
    """
    Tests that FSDP's state_dict can be loaded into a local model.
    """
    model = self._initialize_model(wrap_fsdp=True)
    optim = SGD(model.parameters(), lr=0.1)
    in_data = torch.rand(64, 4, requires_grad=True, device=torch.device("cuda"))
    for _ in range(3):
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    with FullyShardedDataParallel.summon_full_params(model):
        fsdp_params = deepcopy(list(model.parameters()))

    # Get the FSDP state_dict. Note that by default we return the full
    # state_dict.
    fsdp_state_dict = model.state_dict()

    # Create a zeroed local model.
    blank_local_model = self._initialize_model(wrap_fsdp=False, wrap_ddp=False)
    for param in blank_local_model.parameters():
        with torch.no_grad():
            param.zero_()

    # Load FSDP's full state dict into the local model and verify params are
    # as expected.
    blank_local_model.load_state_dict(fsdp_state_dict)
    local_params = list(blank_local_model.parameters())
    for fsdp_param, local_param in zip(fsdp_params, local_params):
        self.assertEqual(fsdp_param, local_param)
def test_state_dict_rank0_offload_save_load_flow(self):
    # Test taking a checkpoint on rank 0 only and reloading it without
    # redundant CPU memory usage.
    model = TransformerWithSharedParams(
        group=dist.distributed_c10d._get_default_group())
    my_auto_wrap_policy = partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            TransformerEncoderLayer, TransformerDecoderLayer
        })
    model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy)
    ctx = self._get_state_dict_mgr(model, "state_dict", True)
    with ctx:
        state_dict = deepcopy(_get_state_dict(model))

    # All ranks initialize a non-FSDP model.
    grp = dist.distributed_c10d._get_default_group()
    model_new = TransformerWithSharedParams(group=grp)
    for p in model_new.parameters():
        with torch.no_grad():
            p.zero_()
    # Only rank 0 loads the checkpoint.
    if self.rank == 0:
        model_new.load_state_dict(state_dict)

    # TransformerWithSharedParams has a buffer of zeros, so we can't pass
    # self.assertNotEqual directly since the buffers would be equal. Instead,
    # just check that there is some difference in the model across ranks
    # before the state_dict is broadcast.
    with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close"):
        _validate(model_new, process_group=grp, assert_fn=self.assertEqual)

    # FSDP with sync_module_states=True broadcasts the checkpointed states.
    model_new = FSDP(model_new,
                     device_id=torch.cuda.current_device(),
                     auto_wrap_policy=my_auto_wrap_policy,
                     sync_module_states=True)
    # After wrapping with FSDP, the models are equal across ranks and have
    # loaded the checkpoint.
    with FSDP.summon_full_params(model_new):
        _validate(model_new, process_group=grp, assert_fn=self.assertEqual)

    with FullyShardedDataParallel.summon_full_params(model):
        with FullyShardedDataParallel.summon_full_params(model_new):
            params = list(model.parameters())
            params_new = list(model_new.parameters())
            self.assertEqual(params, params_new)
def test_basic_save_and_load_state_dict(self, cpu_offload, fp16):
    """
    Tests that we can save a state_dict and load it into a blank model with
    various configs such as fp16 and CPU offload, and that parameters match
    as expected.
    """
    for model_call in [
        partial(self._get_simple_nested_model, cpu_offload=cpu_offload),
        partial(self._get_simple_model, cpu_offload=cpu_offload),
    ]:
        model = model_call()

        fsdp_state_dict = _get_state_dict(model, cpu_offload.offload_params, fp16)
        if fp16:
            # Verify that fp16 is the dtype.
            for tensor in fsdp_state_dict.values():
                self.assertEqual(tensor.dtype, torch.float16)

        model_new = model_call()
        if not cpu_offload.offload_params:
            model_new = model_new.cuda()
        if fp16:
            model_new.half()

        # Zero the model to ensure parameters are different.
        _zero_model(model_new)

        with FullyShardedDataParallel.summon_full_params(model):
            with FullyShardedDataParallel.summon_full_params(model_new):
                params = list(model.parameters())
                params_new = list(model_new.parameters())
                self.assertNotEqual(params, params_new)

        # Verify parameters are the same in the new model.
        model_new.load_state_dict(fsdp_state_dict)
        with FullyShardedDataParallel.summon_full_params(model_new):
            with FullyShardedDataParallel.summon_full_params(model):
                params = list(model.parameters())
                params_new = list(model_new.parameters())
                self.assertEqual(params, params_new)
                if fp16:
                    for tensor in model_new.parameters():
                        self.assertEqual(tensor.dtype, torch.float16)
def test_state_dict_load_into_local_module(
    self, state_dict_type, state_dict_rank0_and_offload
):
    """
    Tests that FSDP's state_dict can be loaded into a local model.
    """
    if state_dict_rank0_and_offload and state_dict_type != "state_dict":
        return
    model = self._initialize_model(wrap_fsdp=True, register_buffers=True)
    optim = SGD(model.parameters(), lr=0.1)
    in_data = torch.rand(64, 4, requires_grad=True, device=torch.device("cuda"))
    for _ in range(3):
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    with FullyShardedDataParallel.summon_full_params(model):
        fsdp_params = deepcopy(list(model.parameters()))

    # Get the FSDP state_dict. Note that by default we return the full
    # state_dict.
    sd_mgr = self._get_state_dict_mgr(
        model, state_dict_type, state_dict_rank0_and_offload
    )
    with sd_mgr:
        fsdp_state_dict = model.state_dict()

    self._validate_state_dict_contents(
        fsdp_state_dict, state_dict_rank0_and_offload
    )

    # Create a zeroed local model.
    blank_local_model = self._initialize_model(
        wrap_fsdp=False,
        wrap_ddp=False,
        register_buffers=True,
    )
    for param in blank_local_model.parameters():
        with torch.no_grad():
            param.zero_()

    fsdp_state_dict = _gather_state_dict(fsdp_state_dict)

    # Load FSDP's full state dict into the local model and verify params are
    # as expected.
    if state_dict_rank0_and_offload:
        # Broadcast the state_dict and move it onto CUDA.
        fsdp_state_dict = self._broadcast_state_dict(fsdp_state_dict)
        for key in fsdp_state_dict.keys():
            fsdp_state_dict[key] = fsdp_state_dict[key].cuda()
    if self.rank == 0:
        blank_local_model.load_state_dict(fsdp_state_dict)
        local_params = list(blank_local_model.parameters())
        for fsdp_param, local_param in zip(fsdp_params, local_params):
            self.assertEqual(fsdp_param, local_param)
def test_distributed_checkpoint(self, state_dict_type) -> None:
    with enable_wrap(wrapper_cls=FSDP):
        torch.manual_seed(100)
        model = wrap(SkipModel(double_nest=True))
        torch.manual_seed(200)
        new_model = wrap(SkipModel(double_nest=True))

    with FullyShardedDataParallel.summon_full_params(
            model), FullyShardedDataParallel.summon_full_params(new_model):
        params = list(model.parameters())
        new_params = list(new_model.parameters())
        self.assertNotEqual(params, new_params)

    with tempfile.TemporaryDirectory() as path:
        paths = [path]
        dist.broadcast_object_list(paths)
        path = paths[0]
        writer = FileSystemWriter(path)
        reader = FileSystemReader(path)
        with FSDP.state_dict_type(model, state_dict_type), FSDP.state_dict_type(
                new_model, state_dict_type):
            state_dict = model.state_dict()
            save_state_dict(state_dict, writer)

        with FSDP.state_dict_type(model, state_dict_type), FSDP.state_dict_type(
                new_model, state_dict_type):
            state_dict = new_model.state_dict()
            load_state_dict(state_dict, reader)
            new_model.load_state_dict(state_dict)

    with FullyShardedDataParallel.summon_full_params(
            model), FullyShardedDataParallel.summon_full_params(new_model):
        params = list(model.parameters())
        new_params = list(new_model.parameters())
        self.assertEqual(params, new_params)
def test_summon_from_non_fsdp(self):
    class FSDPContainer(nn.Module):
        def __init__(self, fsdp_1, fsdp_2, fsdp_3):
            super().__init__()
            self.fsdp_1 = fsdp_1
            self.fsdp_2 = fsdp_2
            self.fsdp_3 = fsdp_3

    model_fsdp = FSDPContainer(
        FSDP(DeterministicModel(wrap_fsdp=True)),
        FSDP(DeterministicModel(wrap_fsdp=True)),
        DeterministicModel(wrap_fsdp=False),
    )
    model_no_fsdp = FSDPContainer(
        DeterministicModel(wrap_fsdp=False),
        DeterministicModel(wrap_fsdp=False),
        DeterministicModel(wrap_fsdp=False),
    )

    params_to_compare = list(model_no_fsdp.parameters())
    with FullyShardedDataParallel.summon_full_params(model_fsdp):
        fsdp_params = [p.clone() for p in model_fsdp.parameters()]

    self.assertEqual(params_to_compare, fsdp_params)
def test_state_dict_load_into_local_module(
    self,
    state_dict_type,
    state_dict_rank0_and_offload,
    fsdp_root,
):
    """
    Tests that FSDP's state_dict can be loaded into a local model.
    """
    if state_dict_rank0_and_offload and state_dict_type != "state_dict":
        return
    if not fsdp_root:
        model = self._get_non_fsdp_root_module()
    else:
        model = self._initialize_model(wrap_fsdp=True, register_buffers=True)
    optim = SGD(model.parameters(), lr=0.1)
    if not fsdp_root:
        in_data = torch.randn(
            1, 10, requires_grad=True, device=torch.device("cuda")
        )
    else:
        in_data = torch.rand(
            64, 4, requires_grad=True, device=torch.device("cuda")
        )
    for _ in range(3):
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    with FullyShardedDataParallel.summon_full_params(model):
        fsdp_params = deepcopy(list(model.parameters()))

    # Get the FSDP state_dict. Note that by default we return the full
    # state_dict.
    sd_mgr = self._get_state_dict_mgr(
        model, state_dict_type, state_dict_rank0_and_offload
    )
    with sd_mgr:
        fsdp_state_dict = model.state_dict()

    ignore_keys = [
        k for k in fsdp_state_dict.keys() if NON_ROOT_FSDP_PREFIX in k
    ]
    self._validate_state_dict_contents(
        model,
        fsdp_state_dict,
        state_dict_rank0_and_offload,
        ignore_keys=ignore_keys,
    )

    # Create a zeroed local model.
    if not fsdp_root:
        blank_local_model = self._get_non_fsdp_root_module(wrap=False)
    else:
        blank_local_model = self._initialize_model(
            wrap_fsdp=False, wrap_ddp=False, register_buffers=True
        )

    # Nothing should be FSDP.
    for mod in blank_local_model.modules():
        self.assertFalse(isinstance(mod, FSDP))

    for param in blank_local_model.parameters():
        with torch.no_grad():
            param.zero_()

    fsdp_state_dict = _gather_state_dict(fsdp_state_dict)

    # Load FSDP's full state dict into the local model and verify params are
    # as expected.
    if state_dict_rank0_and_offload:
        # Broadcast the state_dict and move it onto CUDA.
        if not isinstance(model, FSDP):
            # Some portions of the model on rank 0 might not be on CPU,
            # so move everything to CPU to avoid running into
            # https://github.com/pytorch/pytorch/issues/77113.
            for k, t in fsdp_state_dict.items():
                if t.device != torch.device("cpu"):
                    fsdp_state_dict[k] = t.cpu()
        fsdp_state_dict = self._broadcast_state_dict(fsdp_state_dict)
        for key in fsdp_state_dict.keys():
            fsdp_state_dict[key] = fsdp_state_dict[key].cuda()

    blank_local_model.load_state_dict(fsdp_state_dict, strict=True)
    local_params = list(blank_local_model.parameters())
    for fsdp_param, local_param in zip(fsdp_params, local_params):
        self.assertEqual(fsdp_param, local_param)
def test_state_dict_rank0_offload_save_load_flow(self):
    """Tests saving a model checkpoint only on rank 0 and loading it only on
    rank 0 with ``sync_module_states=True`` to emulate the workflow for
    avoiding redundant CPU memory usage."""
    auto_wrap_policy = partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            TransformerEncoderLayer, TransformerDecoderLayer
        },
    )
    fsdp_kwargs = {"auto_wrap_policy": auto_wrap_policy}
    fsdp_model = TransformerWithSharedParams.init(
        self.process_group,
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_BEFORE,
        fsdp_kwargs,
    )

    # Force model parameters and buffers to be nonzero.
    with FSDP.summon_full_params(fsdp_model):
        for tensor in itertools.chain(
            fsdp_model.parameters(), fsdp_model.buffers()
        ):
            if torch.count_nonzero(tensor) == 0:
                with torch.no_grad():
                    tensor.add_(
                        torch.tensor(1, dtype=tensor.dtype, device=tensor.device)
                    )

    with self._get_state_dict_mgr(fsdp_model, "state_dict", True):
        state_dict = deepcopy(_get_state_dict(fsdp_model))

    # Initialize a non-wrapped model on all ranks.
    new_model = TransformerWithSharedParams.init(
        self.process_group,
        FSDPInitMode.NO_FSDP,
        CUDAInitMode.CUDA_BEFORE,
    )
    _zero_model(new_model, zero_buffers=True)

    # Only load the checkpoint on rank 0.
    if self.rank == 0:
        new_model.load_state_dict(state_dict, strict=True)
    _assert_module_states(
        new_model,
        process_group=self.process_group,
        assert_fn=self.assertNotEqual,
    )

    # Broadcast the module states from rank 0 with `sync_module_states=True`.
    new_fsdp_model = FSDP(
        new_model,
        device_id=torch.cuda.current_device(),
        auto_wrap_policy=auto_wrap_policy,
        sync_module_states=True,
    )
    # Check that the FSDP models are equal across ranks.
    with FSDP.summon_full_params(new_fsdp_model):
        _assert_module_states(
            new_fsdp_model,
            process_group=self.process_group,
            assert_fn=self.assertEqual,
        )

    # Check that the FSDP models correctly loaded the checkpoint.
    with FullyShardedDataParallel.summon_full_params(fsdp_model):
        with FullyShardedDataParallel.summon_full_params(new_fsdp_model):
            params = list(fsdp_model.parameters())
            params_new = list(new_fsdp_model.parameters())
            self.assertEqual(params, params_new)
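# Illustrative sketch (not one of the tests above): the rank0-only checkpoint
# workflow that test_state_dict_rank0_offload_save_load_flow exercises,
# reduced to its essentials. ``model_fn`` and the function name are
# hypothetical; the FSDP APIs used (FullStateDictConfig, StateDictType,
# sync_module_states) are the public ones.
def _rank0_save_load_flow_sketch(model_fn, rank):
    import torch
    from torch.distributed.fsdp import FullStateDictConfig
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import StateDictType

    # Save: gather a full state_dict, offloaded to CPU and kept on rank 0 only.
    fsdp_model = FSDP(model_fn().cuda())
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
        state_dict = fsdp_model.state_dict()  # populated only on rank 0

    # Load: only rank 0 materializes the checkpoint into a fresh local module.
    new_model = model_fn()
    if rank == 0:
        new_model.load_state_dict(state_dict)

    # sync_module_states=True broadcasts rank 0's loaded weights to all ranks
    # during FSDP construction, avoiding a redundant CPU copy on every rank.
    return FSDP(
        new_model,
        device_id=torch.cuda.current_device(),
        sync_module_states=True,
    )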
def test_main_wrap_api(
    self, cpu_offload, backward_prefetch, forward_prefetch, cuda_init_mode
):
    if cuda_init_mode == CUDAInitMode.CUDA_AFTER and cpu_offload.offload_params:
        # They don't work together; this is expected.
        return

    move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE

    class Nested(nn.Module):
        def __init__(self):
            super().__init__()
            self.nested_lin = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)

        def forward(self, input):
            return self.nested_lin(input)

    class MyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.lin1 = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)
            self.lin2 = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)
            self.lin3 = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)
            self.lin4 = Nested()

        def forward(self, input):
            return self.lin4(self.lin3(self.lin2(self.lin1(input))))

    model = MyModel()
    wrapped_model = FSDP(
        model,
        auto_wrap_policy=functools.partial(
            size_based_auto_wrap_policy,
            min_num_params=0,  # wrap all modules
        ),
        cpu_offload=cpu_offload,
        backward_prefetch=backward_prefetch,
        forward_prefetch=forward_prefetch,
    )
    if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
        wrapped_model = wrapped_model.cuda()

    modules_in_fsdp_graph_order = [
        wrapped_model.module.lin1,
        wrapped_model.module.lin2,
        wrapped_model.module.lin3,
        wrapped_model.module.lin4.module.nested_lin,
        wrapped_model.module.lin4,
        wrapped_model,
    ]

    for module in modules_in_fsdp_graph_order:
        self.assertTrue(isinstance(module, FSDP))
        self._check_cpu_offload(module, cpu_offload)
        self._check_backward_prefetch(module, backward_prefetch)
        self._check_forward_prefetch(module, forward_prefetch)

    # Run the model a few times as a sanity check.
    optim = torch.optim.SGD(wrapped_model.parameters(), lr=1e-2, momentum=0.9)
    inp = torch.ones(1).cuda()
    for _ in range(6):
        optim.zero_grad()
        loss = wrapped_model(inp).sum()
        loss.backward()
        optim.step()

    # Since we ran with backward prefetch, verify the backward-prefetch-related
    # bookkeeping.
    for i, module in enumerate(modules_in_fsdp_graph_order):
        self.assertEqual(i, module._my_fsdp_idx_in_graph)
        self.assertTrue(module._fsdp_graph_order == modules_in_fsdp_graph_order)