Example #1
    def _test_consolidate_weights(self,
                                  config,
                                  rank,
                                  group,
                                  paths=None,
                                  transformer=False):
        """FSDP.gather_full_optim_state_dict() should return something very similar to optimizer.state_dict()"""
        # Establish reference behavior.

        if transformer:
            fsdp = self.get_wrapped_model(group, config=config).cuda()
        else:
            fsdp = FullyShardedDataParallel(
                MixtureOfExperts(group, wrapper_config=config)).cuda()

        optim = Adam(
            fsdp.parameters(),
            lr=0.01,
        )
        optim.zero_grad()
        with torch.cuda.amp.autocast(enabled=True):
            x = fsdp.module.get_input(torch.device("cuda"))
            output = fsdp(*x)
            loss = fsdp.module.get_loss(x, output).to("cuda")
            fsdp.module.run_backward(loss)
            optim.step()

        # each worker saves a checkpoint with local_state_dict
        cp_data = {
            "weights": {k: v.cpu() for k, v in fsdp.local_state_dict().items()},
            "meta": fsdp.local_metadata_dict(),
        }
        torch.save(cp_data, paths[fsdp.rank])
        full_model_state_dict = fsdp.state_dict()
        torch.distributed.barrier()
        if fsdp.rank > 0:
            return
        all_checkpoints = [torch.load(p) for p in paths]
        consolidated_checkpoint = FullyShardedDataParallel.consolidate_shard_weights(
            shard_weights=[c["weights"] for c in all_checkpoints],
            shard_metadata=[c["meta"] for c in all_checkpoints],
        )
        full_model_extra = set(full_model_state_dict).difference(set(consolidated_checkpoint))
        consolidated_extra = set(consolidated_checkpoint).difference(set(full_model_state_dict))
        msg = f"full model extra keys: {full_model_extra}, consolidated extra keys: {consolidated_extra}"
        # Compare the key sets first so a mismatch reports the informative message
        # instead of raising a KeyError in the shape check below.
        assert set(full_model_state_dict.keys()) == set(consolidated_checkpoint.keys()), msg
        for k in full_model_state_dict.keys():
            assert consolidated_checkpoint[k].shape == full_model_state_dict[k].shape
Example #2
import torch

from fairscale.nn.data_parallel import FullyShardedDataParallel


def test_consolidate_missing_params():
    """This tests that fairseq experts, which are saved independently from the rest of the model, can be consolidated."""
    desired_path = "decoder.layers.1.moe_layer.experts.0"
    shard_metadata = {
        "param_metadata": [
            {
                "fsdp_path": "",
                "params": {
                    "flat_param_0": {
                        "names": ["missing"],
                        "shapes": [(12, 4)],
                        "numels": [12 * 4],
                        "padding": 0
                    }
                },
                "no_broadcast_optim_state": False,
                "shared_param_info": [],
            },
            {
                "fsdp_path": desired_path,
                "params": {
                    "flat_param_0": {
                        "names":
                        ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"],
                        "shapes": [(4, 4), (4, ), (4, 4), (4, )],
                        "numels": [16, 4, 16, 4],
                        "padding":
                        0,
                    }
                },
                "no_broadcast_optim_state": True,
                "shared_param_info": [],
            },
        ],
        "buffer_names": ["missing.buffer"],
    }
    shard_weights = {
        "decoder.layers.1.moe_layer.experts.0.flat_param_0": torch.randn(40, dtype=torch.float16)
    }
    consolidated_weights = FullyShardedDataParallel.consolidate_shard_weights(
        [shard_weights], [shard_metadata], strict=False)
    assert len(consolidated_weights) == 4
    for k in consolidated_weights:
        assert k.startswith(desired_path), f"{k} doesn't start with {desired_path}"
Example #3

def _worker(gpu_id: int, sync_file: str, world_size: int, embedding_size: int, flatten_parameters: bool):
    torch.manual_seed(0)
    torch.cuda.set_device(gpu_id)
    torch.distributed.init_process_group(
        backend="nccl", init_method=f"file://{sync_file}", world_size=world_size, rank=gpu_id,
    )
    process_group = torch.distributed.new_group()

    # Create a dummy model with dummy inputs and targets
    batch_size = 4
    input = torch.randn(size=(batch_size, 3, 32, 32)).cuda()
    target = torch.zeros(size=(batch_size, embedding_size)).cuda()
    model = _create_model(
        with_fsdp=True,
        process_group=process_group,
        embedding_size=embedding_size,
        flatten_parameters=flatten_parameters,
    )
    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

    # Train the model for a few steps
    for epoch in range(2):
        out = model(input)
        loss = criterion(out, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Save one checkpoint per shard (one file per worker)
    cp_data = {
        "weights": {k: v.cpu() for k, v in model.local_state_dict().items()},
        "meta": model.local_metadata_dict(),
    }
    torch.save(cp_data, f"checkpoint_{gpu_id}.torch")

    # Wait for all files to be written on the disk
    dist.barrier()  # type: ignore

    # Reconstruct a full checkpoint from the sharded checkpoints
    all_checkpoints = [_load_sharded_checkpoint(rank) for rank in range(world_size)]
    consolidated_checkpoint = FullyShardedDataParallel.consolidate_shard_weights(
        shard_weights=[c["weights"] for c in all_checkpoints], shard_metadata=[c["meta"] for c in all_checkpoints],
    )

    # Check that the reconstructed parameters are correct and of the right shape
    full_model = _create_model(with_fsdp=False, process_group=process_group, embedding_size=embedding_size)
    full_model_state_dict = full_model.state_dict()
    assert set(full_model_state_dict.keys()) == set(consolidated_checkpoint.keys())
    for k in full_model_state_dict.keys():
        assert consolidated_checkpoint[k].shape == full_model_state_dict[k].shape

    # Verify that the checkpoint can be loaded by a FSDP model
    loaded_model = _create_model(
        with_fsdp=True,
        process_group=process_group,
        embedding_size=embedding_size,
        flatten_parameters=flatten_parameters,
    )
    loaded_model.load_state_dict(consolidated_checkpoint)
    for m in loaded_model.modules():
        if isinstance(m, FullyShardedDataParallel):
            m._reset_lazy_init()

    # Verify that the model saved and the model loaded give the same results
    with torch.no_grad():
        before_checkpoint_loss = criterion(model(input), target).item()
        after_checkpoint_loss = criterion(loaded_model(input), target).item()
        assert before_checkpoint_loss == after_checkpoint_loss
Example #4
    @classmethod
    def _consolidate_shards(cls, weights: List[Dict[str, torch.Tensor]],
                            metadata: List[Dict[str, Any]]):
        logging.info("Consolidating shards...")
        return FullyShardedDataParallel.consolidate_shard_weights(weights, metadata)
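
For reference, here is a minimal end-to-end sketch that is not taken from the examples above. It assumes each rank has already saved a dict with "weights" from local_state_dict() and "meta" from local_metadata_dict() under the hypothetical file pattern checkpoint_{rank}.torch used in Example #3; a single process can then rebuild and store one full state_dict.

import torch

from fairscale.nn.data_parallel import FullyShardedDataParallel


def consolidate_from_disk(world_size: int, out_path: str = "consolidated.torch"):
    # Load every per-rank shard saved as {"weights": ..., "meta": ...}.
    shards = [torch.load(f"checkpoint_{rank}.torch") for rank in range(world_size)]
    full_state_dict = FullyShardedDataParallel.consolidate_shard_weights(
        shard_weights=[s["weights"] for s in shards],
        shard_metadata=[s["meta"] for s in shards],
    )
    # The result is an ordinary state_dict, so it can be written out as a single
    # checkpoint and later loaded into a non-sharded (or FSDP-wrapped) model.
    torch.save(full_state_dict, out_path)
    return full_state_dict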