def test_norm(device, norm_type, mixed_precision):
    """Test checkpoint_wrapper with different norm layers."""
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("Skip due to lack of GPU")

    # Get input, ref, checkpoint models and make them equal.
    in_data = torch.rand(2, 2, 3, 3).to(device)
    m_ref = get_model(norm_type, False, mixed_precision).to(device)
    m_cpt = get_model(norm_type, True, mixed_precision).to(device)
    m_cpt.load_state_dict(m_ref.state_dict())

    if torch_version() >= (1, 6, 0):
        # This assert fails on 1.5.1.
        assert objects_are_equal(m_ref.state_dict(), m_cpt.state_dict())

    if mixed_precision != "fp32":
        in_data = in_data.half()

    # Needed due to checkpointing.
    in_data.requires_grad = True
    for model in (m_ref, m_cpt):
        optim = SGD(model.parameters(), lr=0.1)
        if device == "cpu" and mixed_precision != "fp32":
            # Got: RuntimeError: "batch_norm"/"layer_norm" not implemented for 'Half'.
            with pytest.raises(RuntimeError):
                out = model(in_data)
            return
        else:
            # Everything else works.
            out = model(in_data)
            out.sum().backward()
            optim.step()

    if torch_version() >= (1, 6, 0):
        assert objects_are_equal(m_ref.state_dict(), m_cpt.state_dict())
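
# For context, a minimal sketch of the kind of model get_model() above might build: a small
# conv + norm block, optionally wrapped with checkpoint_wrapper. This is an illustration only;
# the real get_model, its norm_type handling, and the exact fairscale import path are
# assumptions and may differ by version.
def _toy_norm_model(use_checkpoint: bool) -> nn.Module:
    from fairscale.nn.checkpoint import checkpoint_wrapper  # import path may differ by version

    block = nn.Sequential(nn.Conv2d(2, 2, kernel_size=1), nn.BatchNorm2d(2), nn.ReLU())
    # Activation checkpointing recomputes the block during backward; it needs at least one
    # input that requires grad, hence `in_data.requires_grad = True` in the test above.
    return checkpoint_wrapper(block) if use_checkpoint else block
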
def _dist_worker(rank, world_size, files, wrap_middle, test_fn):
    # Get data from files.
    file1, file2, sd_before, sd_after, in_data = files
    sd_before = torch.load(sd_before, map_location=lambda storage, loc: storage.cuda(rank))
    if test_fn == "train":
        sd_after = torch.load(sd_after, map_location=lambda storage, loc: storage.cuda(rank))
    in_data = torch.load(in_data, map_location=lambda storage, loc: storage.cuda(rank))

    result = dist_init(rank=rank, world_size=world_size, filename=file1, filename_rpc=file2)
    assert result, "Dist init failed"

    fsdp_model = FSDP(
        # To debug: first make with_fsdp=False (no inner wrapping) work, then enable inner
        # wrapping and make that work.
        Model(with_fsdp=True, wrap_middle=wrap_middle),
        flatten_parameters=test_fn == "optim_state",
        mixed_precision=False,
        compute_dtype=torch.float16,
    )
    fsdp_model.load_state_dict(sd_before)

    if test_fn == "train":
        _train(fsdp_model, in_data)
        objects_are_equal(sd_after, fsdp_model.state_dict(), raise_exception=True)
    elif test_fn == "eval":
        _eval(fsdp_model, in_data)
    elif test_fn == "optim_state":
        optim = SGD(fsdp_model.parameters(), lr=0.1)
        for _ in range(3):
            out = fsdp_model(in_data)
            out.backward()
            optim.step()
        sd = fsdp_model.gather_full_optim_state_dict(optim)
        if rank == 0:
            # There should be 8 momentum buffers in the state.
            assert len(sd["state"].keys()) == 8
        else:
            assert sd is None, "only rank 0 should have the optim state"
    else:
        assert 0, f"invalid test_fn {test_fn}"

    teardown()
def test_named_params_ordering(self):
    """Test the ordering assumption made by consolidate_optimizer_state_dict:
    model.parameters() and model.named_parameters() yield parameters in the same order."""
    group = DummyProcessGroup(0, 1)
    model = TransformerWithSharedParams(group)
    named_pars = [p for n, p in model.named_parameters()]
    for i, p in enumerate(model.parameters()):
        assert objects_are_equal(p, named_pars[i])
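
# Why the ordering matters: optimizer state dicts refer to parameters by integer position
# rather than by name, so state can only be matched back to named parameters if parameters()
# and named_parameters() agree on order. A small illustrative helper (hypothetical, not part
# of this test suite) that relies on exactly that assumption:
def _name_optim_state(model, optim):
    names = [n for n, _ in model.named_parameters()]
    osd = optim.state_dict()
    # osd["param_groups"][0]["params"] holds positional ids into model.parameters()
    # (a single param group is assumed for simplicity).
    return {names[i]: osd["state"].get(i, {}) for i in osd["param_groups"][0]["params"]}
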
def test_shared_weight_mevo(temp_files, wrap_middle, test_fn):
    """Test FSDP with a model with shared weights."""
    if test_fn == "optim_state":
        if wrap_middle != "flat":
            pytest.skip("only support optim_state when root and middle part is flat")

    world_size = 2

    # Get ref.
    model = Model()
    sd_before = deepcopy(model.state_dict())
    in_data = (torch.rand(BS, SEQ) * (VOCAB - 1)).cuda().long()
    if test_fn == "train":
        _train(model, in_data, world_size)
        sd_after = deepcopy(model.state_dict())
        # Before and after state should not be equal.
        assert not objects_are_equal(sd_before, sd_after)

    # Save data
    torch.save(sd_before, temp_files[2])
    if test_fn == "train":
        torch.save(sd_after, temp_files[3])
    torch.save(in_data, temp_files[4])

    # Run FSDP
    mp.spawn(
        _dist_worker,
        (world_size, temp_files, wrap_middle, test_fn),
        nprocs=world_size,
    )
def _test_nested_wrapped_model(cls, rank, group, config=None):
    # Get reference state dict without any nested FSDP instances.
    model = NestedWrappedModule(group, None).cuda()
    model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank, process_group=group)
    cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])
    ref_state_dict = {k: v.clone() for k, v in model.module.state_dict().items()}

    # Create a nested FSDP-wrapped instance.
    if config["mixed_precision"]:
        config["compute_dtype"] = torch.float32
    model = NestedWrappedModule(group, config)
    model = FullyShardedDataParallel(model, group, **config).cuda()
    cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])

    # Round-trip state dict save/load/save.
    state_dict = {k: v.clone() for k, v in model.state_dict().items()}
    model.load_state_dict(state_dict)
    state_dict = model.state_dict()

    assert ref_state_dict.keys() == state_dict.keys(), f"{ref_state_dict.keys()} != {state_dict.keys()}"
    for key in ref_state_dict.keys():
        assert objects_are_equal(
            ref_state_dict[key], state_dict[key], raise_exception=False
        ), f"{key}, {ref_state_dict[key]} != {state_dict[key]}"
def test_shared_weight(temp_files, outer_flat, inner_flat, sharing):
    """Test FSDP with a model with shared weights."""
    outer_flat = outer_flat == "outer_flat"
    inner_flat = inner_flat == "inner_flat"
    world_size = 2

    # Get reference results.
    model = Model(sharing=sharing)
    sd_before = deepcopy(model.state_dict())
    in_data = torch.rand(1, 4).cuda()
    _train(model, in_data, world_size)
    sd_after = deepcopy(model.state_dict())
    # Before and after state should not be equal.
    assert not objects_are_equal(sd_before, sd_after)

    # Save data
    torch.save(sd_before, temp_files[2])
    torch.save(sd_after, temp_files[3])
    torch.save(in_data, temp_files[4])

    # Run FSDP
    mp.spawn(
        _dist_worker,
        (world_size, temp_files, outer_flat, inner_flat, sharing),
        nprocs=world_size,
    )
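
# For reference, "shared weights" in these tests means different submodules pointing at the
# same Parameter object. A minimal sketch of such a model (illustrative only; the real Model
# class and its `sharing` modes are defined elsewhere):
class _TiedLinears(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(4, 4)
        self.l2 = nn.Linear(4, 4)
        self.l2.weight = self.l1.weight  # tie: both layers read and update the same tensor

    def forward(self, x):
        return self.l2(self.l1(x))
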
def test_unflatten_params(self):
    for module_init_fn in self._get_module_init_fns():
        module = FlattenParamsWrapper(module_init_fn())
        buffers = {k.replace("_fpw_module.", "") for k, _ in module.named_buffers()}

        def clone_state_dict():
            return OrderedDict((k, v.clone()) for k, v in module.state_dict().items())

        ref_flat_param = module.flat_param.clone()
        with module.unflatten_params():
            ref_state_dict = clone_state_dict()
        assert not torch.all(ref_flat_param == 0)

        # confirm that unflatten_params reflects values from new_flat_param
        new_flat_param = torch.full_like(module.flat_param, fill_value=42.0)
        with module.unflatten_params(flat_param=new_flat_param):
            new_state_dict = clone_state_dict()
        assert new_state_dict.keys() == ref_state_dict.keys()
        for k, v in new_state_dict.items():
            if k in buffers:
                # buffers are not changed
                torch.testing.assert_allclose(v, ref_state_dict[k])
            else:
                # params reflect new_flat_param value
                assert torch.all(v == 42.0)

        # after context manager exits, we go back to previous (reference) state
        torch.testing.assert_allclose(module.flat_param, ref_flat_param)
        with module.unflatten_params():
            ref_state_dict2 = clone_state_dict()
        assert objects_are_equal(ref_state_dict, ref_state_dict2)

        # if we load the new_state_dict, then the flat param should match new_flat_param
        module.load_state_dict(new_state_dict)
        torch.testing.assert_allclose(module.flat_param, new_flat_param)
def _test_nested_wrapped_model_local_state_dict(cls, rank, group, config=None, local=None):
    # Create a nested FSDP-wrapped instance.
    model = NestedWrappedModule(group, config)
    model = FullyShardedDataParallel(model, group, **config).cuda()
    cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])

    # Round trip state dict save/load/save.
    ref_state_dict = {k: v.clone() for k, v in model.local_state_dict().items()}
    model.load_local_state_dict(ref_state_dict)
    state_dict = model.local_state_dict()

    assert ref_state_dict.keys() == state_dict.keys(), f"{ref_state_dict.keys()} != {state_dict.keys()}"
    for key in ref_state_dict.keys():
        assert objects_are_equal(
            ref_state_dict[key], state_dict[key], raise_exception=False
        ), f"{key}, {ref_state_dict[key]} != {state_dict[key]}"
def test_state_dict_equality(self):
    module = self._get_shared_params_transformer()
    ref_state_dict = module.state_dict()

    flat_module = FlattenParamsWrapper(module)
    flat_state_dict = flat_module.state_dict()

    assert objects_are_equal(ref_state_dict, flat_state_dict)
def _distributed_worker(gpu_id, world_size, with_fsdp, freezing_method, tempfile_name, unused, rank_0_output, expected_state):
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = _create_model(with_fsdp)
    model = model.cuda()

    # Freeze the trunk, either by disabling requires_grad here or by setting grads
    # to None after backward (grad_to_none) below.
    assert freezing_method in ["requires_grad", "grad_to_none"]
    if freezing_method == "requires_grad":
        for param in model.trunk.parameters():
            param.requires_grad = False

    if with_fsdp:
        model = FSDP(model)
    else:
        model = DistributedDataParallel(model, device_ids=[gpu_id])

    if gpu_id == 0:
        print(model)

    target = torch.LongTensor([0, 1]).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    for iteration in range(3):
        out = model(batch)
        fake_loss = criterion(out, target)
        print("Loss", iteration, ":", fake_loss.item())
        optimizer.zero_grad()
        fake_loss.backward()
        if freezing_method == "grad_to_none":
            for param in model.trunk.parameters():
                param.grad = None
        optimizer.step()

    if with_fsdp:
        fsdp_state = model.state_dict()
        # Move tensors to CPU to compare numerics.
        for k, v in fsdp_state.items():
            fsdp_state[k] = v.cpu()
        assert objects_are_equal(expected_state, fsdp_state, raise_exception=True)
    elif rank == 0:
        state_after = model.module.cpu().state_dict()
        torch.save(state_after, rank_0_output)

    teardown()
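
# The two freezing strategies exercised above, shown in isolation (a sketch; `trunk` is any
# nn.Module whose parameters should not be updated):
def _freeze_by_requires_grad(trunk):
    # Strategy 1: turn off requires_grad before training; no trunk grads are computed at all.
    for p in trunk.parameters():
        p.requires_grad = False


def _freeze_by_grad_to_none(trunk):
    # Strategy 2: let backward compute trunk grads, then drop them right before optimizer.step().
    for p in trunk.parameters():
        p.grad = None
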
def _dist_worker(rank, world_size, files, outer_flat, inner_flat, sharing):
    # Get data from files.
    file1, file2, sd_before, sd_after, in_data = files
    sd_before = torch.load(sd_before, map_location=lambda storage, loc: storage.cuda(rank))
    sd_after = torch.load(sd_after, map_location=lambda storage, loc: storage.cuda(rank))
    in_data = torch.load(in_data, map_location=lambda storage, loc: storage.cuda(rank))

    result = dist_init(rank=rank, world_size=world_size, filename=file1, filename_rpc=file2)
    assert result, "Dist init failed"

    fsdp_model = FSDP(
        Model(with_fsdp=True, inner_flat=inner_flat, sharing=sharing),
        flatten_parameters=outer_flat,
    )
    fsdp_model.load_state_dict(sd_before)

    _train(fsdp_model, in_data)

    objects_are_equal(sd_after, fsdp_model.state_dict(), raise_exception=True)

    teardown()
def test_load_state_dict(self):
    """Test that original (unwrapped) state_dict can be loaded in wrapped module."""
    for module_init_fn in self._get_module_init_fns():
        module = module_init_fn()
        ref_state_dict = module.state_dict()
        ref_output = self._get_output(module)

        module = module_init_fn(seed=1234)
        flat_module = FlattenParamsWrapper(module)

        # This should work without the unflatten_params context manager
        flat_module.load_state_dict(ref_state_dict)
        flat_output = self._get_output(flat_module)
        assert objects_are_equal(ref_output, flat_output)

        # And it should work with the context manager too
        with flat_module.unflatten_params():
            flat_module.load_state_dict(ref_state_dict)
        flat_output = self._get_output(flat_module)
        assert objects_are_equal(ref_output, flat_output)
def test_load_state_dict(self):
    module = self._get_shared_params_transformer()
    ref_state_dict = module.state_dict()
    ref_output = self._get_output(module)

    module = self._get_shared_params_transformer(seed=1234)
    flat_module = FlattenParamsWrapper(module)
    flat_module.load_state_dict(ref_state_dict)
    flat_output = self._get_output(flat_module)

    assert objects_are_equal(ref_output, flat_output)
def test_state_dict_equality(self):
    """Test that unflattened state dict matches original (unwrapped) one."""
    modules_to_test = [init_fn() for init_fn in self._get_module_init_fns()]
    for module in modules_to_test:
        ref_state_dict = module.state_dict()

        flat_module = FlattenParamsWrapper(module)
        flat_state_dict = flat_module.state_dict()

        assert (
            ref_state_dict.keys() == flat_state_dict.keys()
        ), f"{ref_state_dict.keys()} != {flat_state_dict.keys()}"
        assert objects_are_equal(ref_state_dict, flat_state_dict), f"{ref_state_dict} != {flat_state_dict}"
def test_flat_state_dict(self):
    flat_module = self._get_shared_params_transformer()
    flat_module = FlattenParamsWrapper(flat_module)
    ref_output = self._get_output(flat_module)

    flat_state_dict = flat_module.flat_state_dict()

    new_module = self._get_shared_params_transformer(seed=1234)
    new_module = FlattenParamsWrapper(new_module)
    new_module.load_state_dict(flat_state_dict)
    new_output = self._get_output(new_module)

    assert objects_are_equal(ref_output, new_output)
def test_flat_state_dict(self):
    """Test that flat state dict can be reloaded and produces the same results."""
    for module_init_fn in self._get_module_init_fns():
        flat_module = FlattenParamsWrapper(module_init_fn())
        ref_output = self._get_output(flat_module)

        flat_state_dict = flat_module.flat_state_dict()

        new_module = FlattenParamsWrapper(module_init_fn(seed=1234))
        new_module.load_state_dict(flat_state_dict)
        new_output = self._get_output(new_module)

        assert objects_are_equal(ref_output, new_output)
def _test_param_change_after_init(self, rank, group, config):
    # Establish reference behavior.
    model = self.get_wrapped_model(group, cuda_first=False, config=config)
    model.eval()  # no dropout for this test
    input = model.module.get_input(torch.device("cuda"))
    ref_output = model(*input)

    # Change the weights in place.
    model = self.get_wrapped_model(group, cuda_first=False, config=config)
    model.eval()  # no dropout for this test
    first_param = next(model.parameters())
    nn.init.normal_(first_param.data)
    new_output = model(*input)

    assert not objects_are_equal(ref_output, new_output), "new_output did not reflect change to param after init"
def _test_grad_acc(self, model, batch_dim, use_no_sync_context=True):
    # Generate two input batches. We'll test that we get the same grads if
    # we train on them sequentially while accumulating grads (with or without
    # no_sync) vs. concatenating the batches and training in one go.
    #
    # no_sync trades GPU memory for network bandwidth: gradients are kept local
    # (unreduced) across backward passes instead of being reduced after each one.
    batch1 = model.module.get_input(torch.device("cuda"))
    assert isinstance(batch1, tuple)
    batch2 = tuple(
        # This randomly permutes the values in a multi-dim tensor.
        x.view(-1)[torch.randperm(x.numel())].view_as(x)
        for x in batch1
    )
    for x, y in zip(batch1, batch2):
        assert not torch.all(x == y)

    # Concat the batches along batch dimension.
    concat_batch = tuple(torch.cat((x, y), dim=batch_dim) for (x, y) in zip(batch1, batch2))

    # Establish reference behavior on the concat batch.
    model.zero_grad()
    output = model(*concat_batch)
    ref_loss = model.module.get_loss(concat_batch, output)
    ref_loss.backward()
    ref_grads = [p.grad.detach().clone() for p in model.parameters()]

    # Test that we get the same results by accumulating grads.
    model.zero_grad()
    context = contextlib.suppress()
    if use_no_sync_context:
        context = model.no_sync()
    with context:
        # accumulate gradients from the first batch
        output = model(*batch1)
        loss1 = model.module.get_loss(batch1, output)
        loss1.backward()
    output = model(*batch2)
    loss2 = model.module.get_loss(batch2, output)
    loss2.backward()
    accumulated_loss = loss1 + loss2
    accumulated_grads = [p.grad.detach().clone() for p in model.parameters()]
    torch.testing.assert_allclose(ref_loss, accumulated_loss)
    assert objects_are_equal(ref_grads, accumulated_grads, raise_exception=True)
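
# The accumulation pattern being validated, reduced to its essence (a sketch; `model` is any
# DDP/FSDP wrapper exposing no_sync(), and `compute_loss` is a stand-in for the loss code above):
def _accumulate_two_batches(model, batch1, batch2, compute_loss):
    # First backward inside no_sync(): gradients stay local, no cross-rank reduction.
    with model.no_sync():
        compute_loss(model, batch1).backward()
    # Second backward outside the context: grads from both batches are reduced together.
    compute_loss(model, batch2).backward()
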
def test_multiple_forward_checkpoint(precision, flatten, wrap_bn):
    mixed_precision = precision == "mixed"
    flatten = flatten == "flatten"
    wrap_bn = wrap_bn == "auto_wrap_bn"
    fp32_reduce_scatter = True if mixed_precision else None

    if torch_version() < (1, 8, 0) and flatten:
        # 1.6 and 1.7 throw this error:
        #   RuntimeError: Trying to backward through the graph a second time, but the saved
        #   intermediate results have already been freed. Specify retain_graph=True when calling
        #   backward the first time.
        pytest.skip("older pytorch throws error when flatten is used")

    world_size = 2
    expected_losses = None
    # Ensure ddp == ddp+ckpt == fsdp == fsdp+ckpt.
    for with_fsdp in [False, True]:
        for with_checkpoint in [False, True]:
            # Get 4 files: 2 for dist_init and one per rank to save its losses.
            with temp_files_ctx(num=2 + world_size) as temp_files:
                mp.spawn(
                    _distributed_worker,
                    (
                        world_size,
                        with_fsdp,
                        with_checkpoint,
                        temp_files,
                        mixed_precision,
                        flatten,
                        wrap_bn,
                        fp32_reduce_scatter,
                    ),
                    nprocs=world_size,
                )
                final_losses = {}
                for rank in range(world_size):
                    with open(temp_files[2 + rank], "rb") as f:
                        final_losses[f"rank_{rank}"] = pickle.load(f)
                if expected_losses is None:
                    expected_losses = final_losses
                else:
                    print(f"fsdp: {with_fsdp} ckpt: {with_checkpoint}")
                    assert objects_are_equal(expected_losses, final_losses, raise_exception=True)
def _test_transformer(self, rank, group, config):
    autocast = config["mixed_precision"]

    # Train model for a step
    model = self.get_wrapped_model(group, cuda_first=False, config=config)
    self._train_for_several_steps(model, 1, autocast)
    model.eval()  # no dropout for this test

    # Eval in standard mode (i.e., without no_grad)
    input = model.module.get_input(torch.device("cuda"))
    ref_output = model(*input)

    # Eval with no_grad and compare
    with torch.no_grad():
        no_grad_output = model(*input)

    assert objects_are_equal(ref_output, no_grad_output, raise_exception=True)
def _test_identical_outputs(
    cls, model_init_fn, config, rank, group, num_steps=2, use_cuda=True, lr=0.01, ref_ddp_fn=None, norm_type=2,
):
    if config.get("mixed_precision", False):
        autocast = True
        # Force the compute dtype to be torch.float32 so that we get
        # identical results to PyTorch DDP when using autocast. Note that
        # this will cause the all-gather to happen in FP32, which is slower
        # than necessary in most cases.
        config["compute_dtype"] = torch.float32
    else:
        autocast = False

    # Establish reference behavior with PyTorch DDP (+ optionally autocast).
    model = model_init_fn(group=group, wrapper_config=None).cuda()
    if ref_ddp_fn is None:
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[rank], output_device=rank, process_group=group
        )
    else:
        model = ref_ddp_fn(model, group)
    ref_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
    ref_state_dict = model.module.state_dict()
    if config.get("cpu_offload", False):
        for k in ref_state_dict.keys():
            ref_state_dict[k] = ref_state_dict[k].cpu()

    # Confirm we get the same behavior using FullyShardedDataParallel.
    model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config)
    if use_cuda:
        model = model.cuda()
    else:
        assert next(model.parameters()).device == torch.device("cpu")
    shard_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
    shard_state_dict = model.state_dict()

    try:
        torch.testing.assert_allclose(ref_loss, shard_loss)
        assert objects_are_equal(ref_state_dict, shard_state_dict, raise_exception=True)
    except (AssertionError, RuntimeError) as e:
        raise Exception(f"FullyShardedDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}")

    if config.get("flatten_parameters", True):
        metadata = model.local_metadata_dict()
        assert isinstance(metadata, dict)
def _distributed_worker(
    rank,
    world_size,
    fsdp_config,
    fsdp_wrap_bn,
    ddp_mixed_precision,
    tempfile_name,
    unused,
    state_before,
    inputs,
    rank_0_output,
    state_after,
    sync_bn,
    conv_bias,
    linear_bias,
):
    torch.backends.cudnn.deterministic = True

    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    ddp = True
    if fsdp_config:
        ddp = False
        assert isinstance(fsdp_config, dict), str(fsdp_config)
        if fsdp_config["mixed_precision"]:
            # To match DDP in AMP -O1, we need fp32 reduce scatter.
            fsdp_config["fp32_reduce_scatter"] = True

    model = Model(conv_bias, linear_bias)
    model.load_state_dict(state_before)
    model = model.cuda()

    class DummyScaler:
        def scale(self, loss):
            return loss

        def step(self, optim):
            optim.step()

        def update(self):
            pass

    scaler = DummyScaler()
    if ddp:
        if sync_bn == "pytorch":
            model = pytorch_bn_converter(model)
        model = DDP(model, device_ids=[rank], broadcast_buffers=True)
        if ddp_mixed_precision:
            scaler = GradScaler()
    else:
        # Note: different ranks may wrap in a different order due to different random
        # seeds, but the results should be the same.
        if random.randint(0, 1) == 0:
            print(f"auto_wrap_bn {fsdp_wrap_bn}, then sync_bn {sync_bn}")
            if fsdp_wrap_bn:
                model = auto_wrap_bn(model, _single_rank_pg)
            if sync_bn == "pytorch":
                model = pytorch_bn_converter(model)
        else:
            print(f"sync_bn {sync_bn}, then auto_wrap_bn {fsdp_wrap_bn}")
            if sync_bn == "pytorch":
                model = pytorch_bn_converter(model)
            if fsdp_wrap_bn:
                model = auto_wrap_bn(model, _single_rank_pg)
        model = FSDP(model, **fsdp_config).cuda()
        if fsdp_config["mixed_precision"]:
            scaler = ShardedGradScaler()

    # Print the model for verification.
    if rank == 0:
        print(model)

    optim = SGD(model.parameters(), lr=0.1)
    loss_func = CrossEntropyLoss()

    for in_data in inputs[rank]:
        in_data = in_data.cuda()
        context = contextlib.suppress()
        if ddp and ddp_mixed_precision:
            in_data = in_data.half()
            context = torch.cuda.amp.autocast(enabled=True)
        if not ddp and fsdp_config["mixed_precision"]:
            context = torch.cuda.amp.autocast(enabled=True)
        with context:
            out = model(in_data)
            fake_label = torch.zeros(1, dtype=torch.long).cuda()
            loss = loss_func(out.unsqueeze(0), fake_label)
        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()
        optim.zero_grad()

    if ddp:
        # Save the rank 0 state_dict to the output file.
        if rank == 0:
            state_after = model.module.cpu().state_dict()
            torch.save(state_after, rank_0_output)
    else:
        model.assert_state(TrainingState.IDLE)
        # Ensure the final state equals state_after.
        fsdp_state = model.state_dict()
        # Move tensors to CPU to compare numerics.
        for k, v in fsdp_state.items():
            fsdp_state[k] = v.cpu()
        # Change False to True to enable this when you want to debug the mismatch.
        if False and rank == 0:

            def dump(d):
                for k, v in d.items():
                    print(k, v)

            dump(state_after)
            dump(fsdp_state)
        # If sync_bn is used, all ranks should have the same state, so we can compare with
        # rank 0 state on every rank. Otherwise, only compare rank 0 with rank 0.
        if sync_bn != "none" or rank == 0:
            assert objects_are_equal(state_after, fsdp_state, raise_exception=True)

    teardown()
def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim.SGD, transformer=False):
    """FSDP.gather_full_optim_state_dict() should return something very similar to optimizer.state_dict()"""
    # Establish reference behavior.
    if transformer:
        fsdp = self.get_wrapped_model(group, config=config).cuda()
        unwrapped_model = TransformerWithSharedParams(group).cuda()
    else:
        fsdp = FullyShardedDataParallel(NestedWrappedModule(group, wrapper_config=config), group, **config).cuda()
        unwrapped_model = NestedWrappedModule(group, wrapper_config=None).cuda()

    try:
        fsdp_optim = optim_fn(fsdp.parameters(), lr=0.01)
        optim_unwrapped = optim_fn(unwrapped_model.parameters(), lr=0.01)
    except TypeError:  # Adadelta
        fsdp_optim = optim_fn(fsdp.parameters())
        optim_unwrapped = optim_fn(unwrapped_model.parameters())

    fsdp_optim.zero_grad()
    optim_unwrapped.zero_grad()

    x = fsdp.module.get_input(torch.device("cuda"))
    output = fsdp(*x)
    loss = fsdp.module.get_loss(x, output).to("cuda")
    fsdp.module.run_backward(loss)
    fsdp_optim.step()

    output = unwrapped_model(*x)
    loss = unwrapped_model.get_loss(x, output)
    unwrapped_model.run_backward(loss)
    optim_unwrapped.step()
    unwrapped_sd = optim_unwrapped.state_dict()

    tstart = time()
    sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0)
    duration = time() - tstart
    # Switching from fairscale.optim.utils.broadcast_object to torch.broadcast_object_list will cause this to raise
    assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate"

    if fsdp.rank > 0:
        return

    assert_equal(len(sd["state"]), len(unwrapped_sd["state"]))
    assert_equal(len(sd["param_groups"][0]["params"]), len(unwrapped_sd["param_groups"][0]["params"]))
    assert_equal(
        sum([first_tensor_numel(v) for k, v in sd["state"].items()]),
        sum([first_tensor_numel(v) for k, v in unwrapped_sd["state"].items()]),
    )

    shard_sd = fsdp.get_shard_from_optim_state_dict(sd)

    original_shard_sd = fsdp_optim.state_dict()
    assert_equal(len(shard_sd["state"]), len(original_shard_sd["state"]))
    assert_equal(shard_sd.keys(), original_shard_sd.keys())
    original_shard_sd = recursive_copy_to_device(original_shard_sd, non_blocking=False, device="cpu")

    assert_equal(
        sum([first_tensor_numel(v) for k, v in shard_sd["state"].items()]),
        sum([first_tensor_numel(v) for k, v in original_shard_sd["state"].items()]),
    )
    assert objects_are_equal(shard_sd, original_shard_sd)
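
# The consolidate-and-reshard round trip exercised above, reduced to its essence (a sketch;
# assumes `model` is an FSDP instance and `optim` its optimizer, using the same two methods
# the test calls):
def _consolidate_and_reshard(model, optim):
    full_osd = model.gather_full_optim_state_dict(optim, recipient_rank=0)
    if full_osd is not None:  # only the recipient rank receives the consolidated dict
        shard_osd = model.get_shard_from_optim_state_dict(full_osd)
        optim.load_state_dict(shard_osd)  # e.g. when restoring from a consolidated checkpoint
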
def _test_output(self, module):
    ref_output = self._get_output(module)

    flat_module = FlattenParamsWrapper(module)
    flat_output = self._get_output(flat_module)

    assert objects_are_equal(ref_output, flat_output)
def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim.SGD, transformer=False):
    """FSDP.gather_full_optim_state_dict() should return something very similar to optimizer.state_dict()"""
    # Establish reference behavior.
    if transformer:
        unwrapped_model = TransformerWithSharedParams(group, wrapper_config=config).cuda()
        fsdp = self.get_wrapped_model(group, config=config).cuda()
    else:
        unwrapped_model = MixtureOfExperts(group, wrapper_config=None).cuda()
        fsdp = FullyShardedDataParallel(MixtureOfExperts(group, wrapper_config=config)).cuda()

    try:
        fsdp_optim = optim_fn(fsdp.parameters(), lr=0.01)
        optim_unwrapped = optim_fn(unwrapped_model.parameters(), lr=0.01)
    except TypeError:  # Adadelta
        fsdp_optim = optim_fn(fsdp.parameters())
        optim_unwrapped = optim_fn(unwrapped_model.parameters())

    fsdp_optim.zero_grad()
    optim_unwrapped.zero_grad()
    with torch.cuda.amp.autocast(enabled=True):
        x = fsdp.module.get_input(torch.device("cuda"))
        output = fsdp(*x)
        loss = fsdp.module.get_loss(x, output).to("cuda")
        fsdp.module.run_backward(loss)
        fsdp_optim.step()

        output = unwrapped_model(*x)
        loss = unwrapped_model.get_loss(x, output)
        unwrapped_model.run_backward(loss)
        optim_unwrapped.step()
    unwrapped_sd = optim_unwrapped.state_dict()

    if not transformer:
        no_broadcast_children = [x for x in fsdp._fsdp_instances if x.no_broadcast_optim_state]
        assert len(no_broadcast_children) == 1
        assert fsdp._fsdp_instances[-1].no_broadcast_optim_state

    torch.cuda.empty_cache()
    cuda_gb_before = torch.cuda.memory_stats(fsdp.rank)["allocated_bytes.all.current"] / 1024**3
    tstart = time()
    sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0)
    duration = time() - tstart
    assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate"

    cuda_gb_after = torch.cuda.memory_stats(fsdp.rank)["allocated_bytes.all.current"] / 1024**3
    mem_usg_gb = cuda_gb_after - cuda_gb_before
    assert mem_usg_gb == 0, f"gather_full_optim_state_dict used {mem_usg_gb:.2f} CUDA GB, max allowed is 0"
    assert cuda_gb_after > 0, "got 0 memory usage, logging is broken"

    if fsdp.rank > 0:
        assert sd is None
        return

    # assert whole state dict on CPU
    for k, v in sd["state"].items():
        for buffer_name, t in v.items():
            if torch.is_tensor(t):
                msg = f"got device {t.device} for {k}: {buffer_name}. expected CPU"
                assert t.device == torch.device("cpu"), msg

    unflat_state = sd["state"]
    assert "uncollected_local_ids" in sd
    shard_sd = fsdp.get_shard_from_optim_state_dict(sd)
    shard_sd = recursive_copy_to_device(shard_sd, non_blocking=False, device="cpu")
    state_after_get_shard = sd["state"]
    assert objects_are_equal(unflat_state, state_after_get_shard)  # no side effects.

    assert_equal(len(sd["state"]), len(unwrapped_sd["state"]))
    assert_equal(len(sd["param_groups"][0]["params"]), len(unwrapped_sd["param_groups"][0]["params"]))
    assert_equal(
        sum([first_tensor_numel(v) for k, v in sd["state"].items()]),
        sum([first_tensor_numel(v) for k, v in unwrapped_sd["state"].items()]),
    )

    original_shard_sd = fsdp_optim.state_dict()
    assert_equal(len(shard_sd["state"]), len(original_shard_sd["state"]))
    assert_equal(shard_sd.keys(), original_shard_sd.keys())
    original_shard_sd = recursive_copy_to_device(original_shard_sd, non_blocking=False, device="cpu")

    # Before asserting that the dicts are equal, we check keys individually to allow nice tracebacks.
    assert_equal(
        [first_tensor_numel(v) for k, v in shard_sd["state"].items()],
        [first_tensor_numel(v) for k, v in original_shard_sd["state"].items()],
    )
    assert_equal(
        [v for k, v in shard_sd["param_groups"][0].items()],
        [v for k, v in original_shard_sd["param_groups"][0].items()],
    )
    assert objects_are_equal(shard_sd["state"], original_shard_sd["state"])
    assert objects_are_equal({k: shard_sd[k] for k in original_shard_sd}, original_shard_sd)