def _test_consolidate_weights(self, config, rank, group, paths=None, transformer=False):
    """FSDP.consolidate_shard_weights() should rebuild a state dict that matches the full (unsharded) model state_dict()."""
    # Establish reference behavior.
    if transformer:
        fsdp = self.get_wrapped_model(group, config=config).cuda()
    else:
        fsdp = FullyShardedDataParallel(MixtureOfExperts(group, wrapper_config=config)).cuda()

    optim = Adam(fsdp.parameters(), lr=0.01)
    optim.zero_grad()
    with torch.cuda.amp.autocast(enabled=True):
        x = fsdp.module.get_input(torch.device("cuda"))
        output = fsdp(*x)
        loss = fsdp.module.get_loss(x, output).to("cuda")
        fsdp.module.run_backward(loss)
        optim.step()

    # Each worker saves its own shard of the checkpoint with local_state_dict().
    cp_data = {
        "weights": {k: v.cpu() for k, v in fsdp.local_state_dict().items()},
        "meta": fsdp.local_metadata_dict(),
    }
    torch.save(cp_data, paths[fsdp.rank])

    full_model_state_dict = fsdp.state_dict()
    torch.distributed.barrier()
    if fsdp.rank > 0:
        return

    # Rank 0 reloads every shard and consolidates them into a single state dict.
    all_checkpoints = [torch.load(p) for p in paths]
    consolidated_checkpoint = FullyShardedDataParallel.consolidate_shard_weights(
        shard_weights=[c["weights"] for c in all_checkpoints],
        shard_metadata=[c["meta"] for c in all_checkpoints],
    )

    # Compare key sets first so a mismatch produces a readable message
    # instead of a KeyError in the shape comparison below.
    full_model_extra = set(full_model_state_dict).difference(set(consolidated_checkpoint))
    consolidated_extra = set(consolidated_checkpoint).difference(set(full_model_state_dict))
    msg = f"full model extra keys: {full_model_extra}, consolidated extra {consolidated_extra}"
    assert set(full_model_state_dict.keys()) == set(consolidated_checkpoint.keys()), msg
    for k in full_model_state_dict.keys():
        assert consolidated_checkpoint[k].shape == full_model_state_dict[k].shape
def test_consolidate_missing_params():
    """Test that fairseq experts, which are saved independently from the rest of the model, can be consolidated."""
    desired_path = "decoder.layers.1.moe_layer.experts.0"
    shard_metadata = {
        "param_metadata": [
            {
                "fsdp_path": "",
                "params": {
                    "flat_param_0": {
                        "names": ["missing"],
                        "shapes": [(12, 4)],
                        "numels": [12 * 4],
                        "padding": 0,
                    }
                },
                "no_broadcast_optim_state": False,
                "shared_param_info": [],
            },
            {
                "fsdp_path": desired_path,
                "params": {
                    "flat_param_0": {
                        "names": ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"],
                        "shapes": [(4, 4), (4,), (4, 4), (4,)],
                        "numels": [16, 4, 16, 4],
                        "padding": 0,
                    }
                },
                "no_broadcast_optim_state": True,
                "shared_param_info": [],
            },
        ],
        "buffer_names": ["missing.buffer"],
    }
    # Only the expert flat param is provided; the "missing" entries are deliberately absent.
    shard_weights = {"decoder.layers.1.moe_layer.experts.0.flat_param_0": torch.randn(40, dtype=torch.float16)}

    consolidated_weights = FullyShardedDataParallel.consolidate_shard_weights(
        [shard_weights], [shard_metadata], strict=False
    )

    assert len(consolidated_weights) == 4
    for k in consolidated_weights:
        assert k.startswith(desired_path), f"{k} doesn't start with {desired_path}"
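# The test above relies on consolidate_shard_weights() unflattening "flat_param_0"
# according to the recorded "names"/"shapes"/"numels" metadata. The sketch below is
# not part of the original tests; it is a hypothetical helper illustrating that
# split-by-numels-then-reshape step for the expert shard, under the same metadata
# layout assumed above.
def _sketch_unflatten_expert_flat_param():
    prefix = "decoder.layers.1.moe_layer.experts.0"
    names = ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"]
    shapes = [(4, 4), (4,), (4, 4), (4,)]
    numels = [16, 4, 16, 4]
    flat = torch.randn(sum(numels), dtype=torch.float16)  # stands in for the saved flat param
    # Split the flat tensor into per-parameter chunks, then restore each original shape.
    unflattened = {
        f"{prefix}.{name}": chunk.view(shape)
        for name, shape, chunk in zip(names, shapes, flat.split(numels))
    }
    assert set(unflattened) == {f"{prefix}.{n}" for n in names}
    for name, shape in zip(names, shapes):
        assert unflattened[f"{prefix}.{name}"].shape == torch.Size(shape)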
def _worker(gpu_id: int, sync_file: str, world_size: int, embedding_size: int, flatten_parameters: bool):
    torch.manual_seed(0)
    torch.cuda.set_device(gpu_id)
    torch.distributed.init_process_group(
        backend="nccl",
        init_method=f"file://{sync_file}",
        world_size=world_size,
        rank=gpu_id,
    )
    process_group = torch.distributed.new_group()

    # Create a dummy model with dummy inputs and targets
    batch_size = 4
    input = torch.randn(size=(batch_size, 3, 32, 32)).cuda()
    target = torch.zeros(size=(batch_size, embedding_size)).cuda()
    model = _create_model(
        with_fsdp=True,
        process_group=process_group,
        embedding_size=embedding_size,
        flatten_parameters=flatten_parameters,
    )
    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

    # Train the model for a few epochs
    for epoch in range(2):
        out = model(input)
        loss = criterion(out, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Save one checkpoint per shard (one file per rank)
    cp_data = {
        "weights": {k: v.cpu() for k, v in model.local_state_dict().items()},
        "meta": model.local_metadata_dict(),
    }
    torch.save(cp_data, f"checkpoint_{gpu_id}.torch")

    # Wait for all files to be written to disk
    dist.barrier()  # type: ignore

    # Reconstruct a full checkpoint from the sharded checkpoints
    all_checkpoints = [_load_sharded_checkpoint(rank) for rank in range(world_size)]
    consolidated_checkpoint = FullyShardedDataParallel.consolidate_shard_weights(
        shard_weights=[c["weights"] for c in all_checkpoints],
        shard_metadata=[c["meta"] for c in all_checkpoints],
    )

    # Check that the reconstructed parameters are correct and of the right shape
    full_model = _create_model(with_fsdp=False, process_group=process_group, embedding_size=embedding_size)
    full_model_state_dict = full_model.state_dict()
    assert set(full_model_state_dict.keys()) == set(consolidated_checkpoint.keys())
    for k in full_model_state_dict.keys():
        assert consolidated_checkpoint[k].shape == full_model_state_dict[k].shape

    # Verify that the checkpoint can be loaded by an FSDP model
    loaded_model = _create_model(
        with_fsdp=True,
        process_group=process_group,
        embedding_size=embedding_size,
        flatten_parameters=flatten_parameters,
    )
    loaded_model.load_state_dict(consolidated_checkpoint)
    for m in loaded_model.modules():
        if isinstance(m, FullyShardedDataParallel):
            m._reset_lazy_init()

    # Verify that the saved model and the reloaded model give the same results
    with torch.no_grad():
        before_checkpoint_loss = criterion(model(input), target).item()
        after_checkpoint_loss = criterion(loaded_model(input), target).item()
        assert before_checkpoint_loss == after_checkpoint_loss
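# _load_sharded_checkpoint() is referenced in _worker() but not shown in this section.
# If it is not defined elsewhere in the file, a minimal sketch could look like the
# following, assuming the checkpoint_{rank}.torch files written by _worker() above
# (the map_location choice is an assumption, not taken from the original):
def _load_sharded_checkpoint(rank: int) -> Dict[str, Any]:
    # Load the per-rank checkpoint ("weights" + "meta") saved by _worker().
    return torch.load(f"checkpoint_{rank}.torch", map_location="cpu")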
def _consolidate_shards(cls, weights: List[Dict[str, torch.Tensor]], metadata: List[Dict[str, Any]]):
    logging.info("Consolidating shards...")
    return FullyShardedDataParallel.consolidate_shard_weights(weights, metadata)