Code Example #1
def test_scaler_cpu_offload_breaks():
    device = torch.device("cuda")
    torch.cuda.set_device(0)

    # Use a random port in case the next test runs quickly; reusing the same port would cause a conflict.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(random.randint(2000, 3000))
    torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

    try:
        scaler = ShardedGradScaler()
        model = FullyShardedDataParallel(nn.Linear(5, 5),
                                         cpu_offload=True,
                                         mixed_precision=True)
        optim = torch.optim.SGD(model.parameters(), lr=1e-3)

        input = torch.rand((1, 5), dtype=torch.float).to(device)
        optim.zero_grad()
        with autocast():
            output = model(input)
            loss = F.mse_loss(input, output)

        scaler.scale(loss).backward()
        # TODO (Min): Need to fix. Details in issue #421.
        with pytest.raises(RuntimeError):
            scaler.step(optim)
            scaler.update()

    finally:
        # Clean-up is important or the next test in this file may fail to init the PG.
        torch.distributed.destroy_process_group()
        del os.environ["MASTER_ADDR"]
        del os.environ["MASTER_PORT"]
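For contrast, here is a minimal sketch of the path that is expected to work: the same single-rank setup as the test above, but without cpu_offload. The import paths are assumptions based on fairscale's layout; torch, nn, F and autocast are assumed to be imported as in the test.

from fairscale.nn.data_parallel import FullyShardedDataParallel
from fairscale.optim.grad_scaler import ShardedGradScaler

def sharded_scaler_happy_path(device):
    # No cpu_offload here, so scaler.step() is expected to succeed.
    model = FullyShardedDataParallel(nn.Linear(5, 5), mixed_precision=True)
    optim = torch.optim.SGD(model.parameters(), lr=1e-3)
    scaler = ShardedGradScaler()

    inp = torch.rand((1, 5), dtype=torch.float).to(device)
    optim.zero_grad()
    with autocast():
        output = model(inp)
        loss = F.mse_loss(inp, output)

    scaler.scale(loss).backward()  # scale the loss before backward
    scaler.step(optim)             # unscale gradients and step the optimizer
    scaler.update()                # adjust the scale factor for the next iteration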
Code Example #2
 def __init__(self, embedding_size: int, with_fsdp: bool, process_group):
     super().__init__()
     self.conv1 = self._conv_block(3, embedding_size)
     self.conv2: nn.Module = self._conv_block(embedding_size,
                                              embedding_size // 2)
     self.conv3: nn.Module = self._conv_block(embedding_size // 2,
                                              embedding_size)
     self.pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
     self.flatten = nn.Flatten(start_dim=1)
     self.relu = nn.ReLU()
     self.fc1: nn.Module = nn.Linear(embedding_size, 2 * embedding_size)
     self.fc2: nn.Module = nn.Linear(2 * embedding_size, 2 * embedding_size)
     self.fc3: nn.Module = nn.Linear(2 * embedding_size, embedding_size + 1)
     self.fc4: nn.Module = nn.Linear(embedding_size + 1, embedding_size)
     if with_fsdp:
         self.conv2 = FullyShardedDataParallel(self.conv2,
                                               process_group=process_group)
         self.conv3 = FullyShardedDataParallel(self.conv3,
                                               process_group=process_group,
                                               flatten_parameters=False)
         self.fc1 = FullyShardedDataParallel(self.fc1,
                                             process_group=process_group)
         self.fc3 = FullyShardedDataParallel(self.fc3,
                                             process_group=process_group,
                                             flatten_parameters=False)
Code Example #3
 def init_fsdp_model_from_weights(
     cls,
     model: FullyShardedDataParallel,
     checkpoint: Dict[str, Any],
     weights_path: List[str],
     strict: bool = True,
     head_index: int = -1,
 ):
     """
     Load the weights of the checkpoint into the FSDP model:
     - Take into account the type of checkpoint to decide how
       to perform the load (sharded or consolidated load)
     - Take into account the head_index (-1 if trunk else >= 0)
       to find the appropriate weights for the head
     """
     if checkpoint["type"] == CheckpointItemType.slice_list.name:
         # Hack for checkpoints consolidated with the "layers" format
         # instead of the new "classy_state_dict" format: in that case
         # the slices are directly saved under "layers" and do not take
         # into account the 'weights_path' variable
         if "classy_state_dict" not in checkpoint:
             weights = checkpoint["layers"]
         else:
             weights = cls._extract_weights(checkpoint, weights_path,
                                            head_index)
         if weights is not None:
             SlicedCheckpointLoader.load_slice_state_dict(model,
                                                          weights,
                                                          strict=strict)
         else:
             raise ValueError(
                 f"Could not find weights path: {weights_path}")
     elif checkpoint["type"] == CheckpointItemType.consolidated.name:
         weights = cls._extract_weights(checkpoint, weights_path,
                                        head_index)
         if weights is not None:
             out = model.load_state_dict(weights, strict=False)
             cls._check_load_state_dict_out(out, strict=strict)
         elif strict:
             raise ValueError(
                 f"Could not find weights path: {weights_path}")
     else:
         weights = cls._extract_weights(checkpoint, weights_path,
                                        head_index)
         if weights is not None:
             out = model.load_local_state_dict(weights, strict=False)
             cls._check_load_state_dict_out(out, strict=strict)
         elif strict:
             raise ValueError(
                 f"Could not find weights path: {weights_path}")
Code Example #4
File: checkpoint.py  Project: QuentinDuval/vissl
    def init_fsdp_model_from_weights(
        cls,
        model: FullyShardedDataParallel,
        checkpoint: Dict[str, Any],
        weights_path: List[str],
    ):
        """
        Load the weights of the checkpoint into the FSDP model:
        Take into account the type of checkpoint to decide how
        to perform the load (sharded or consolidated load)
        """

        if checkpoint["type"] == CheckpointItemType.slice_list.name:
            SlicedCheckpointLoader.init_model_weights(model, checkpoint)
        elif checkpoint["type"] == CheckpointItemType.consolidated.name:
            weights = cls._extract_weights(checkpoint, weights_path)
            model.load_state_dict(weights)
        else:
            weights = cls._extract_weights(checkpoint, weights_path)
            model.load_local_state_dict(weights)
Code Example #5
def SwavPrototypesHeadFSDP(
    model_config: AttrDict,
    dims: List[int],
    use_bn: bool,
    num_clusters: int,
    use_bias: bool = True,
    return_embeddings: bool = True,
    skip_last_bn: bool = True,
    normalize_feats: bool = True,
):
    """
    SwAV head-specific FSDP wrapping: we keep full precision for the prototypes.

    This is important for convergence:
    since we "normalize" this layer in the update hook, we need to keep its
    weights in full precision. Its output goes into the loss and is used
    for clustering, so we need that in full precision as well.
    """

    head = SwAVPrototypesHead(
        model_config=model_config,
        dims=dims,
        use_bn=use_bn,
        num_clusters=num_clusters,
        use_bias=use_bias,
        return_embeddings=return_embeddings,
        skip_last_bn=skip_last_bn,
        normalize_feats=normalize_feats,
    )

    fp32_fsdp_config = model_config.FSDP_CONFIG.copy()
    fp32_fsdp_config["flatten_parameters"] = False
    fp32_fsdp_config["mixed_precision"] = False
    fp32_fsdp_config["fp32_reduce_scatter"] = False
    fp32_fsdp_config["compute_dtype"] = torch.float32

    for j in range(head.nmb_heads):
        module = getattr(head, "prototypes" + str(j))
        module = FullyShardedDataParallel(module=module, **fp32_fsdp_config)
        setattr(head, "prototypes" + str(j), module)
    return FullyShardedDataParallel(head)
Code Example #6
def _create_model(embedding_size: int,
                  with_fsdp: bool,
                  process_group,
                  flatten_parameters: bool = True):
    model = ConvolutionalModel(with_fsdp=with_fsdp,
                               process_group=process_group,
                               embedding_size=embedding_size).cuda()
    if with_fsdp:
        return FullyShardedDataParallel(model,
                                        process_group=process_group,
                                        flatten_parameters=flatten_parameters)
    else:
        return model
Code Example #7
    def _test_consolidate_weights(self,
                                  config,
                                  rank,
                                  group,
                                  paths=None,
                                  transformer=False):
        """FSDP.gather_full_optim_state_dict() should return something very similar to optimizer.state_dict()"""
        # Establish reference behavior.

        if transformer:
            fsdp = self.get_wrapped_model(group, config=config).cuda()
        else:
            fsdp = FullyShardedDataParallel(
                MixtureOfExperts(group, wrapper_config=config)).cuda()

        optim = Adam(
            fsdp.parameters(),
            lr=0.01,
        )
        optim.zero_grad()
        with torch.cuda.amp.autocast(enabled=True):
            x = fsdp.module.get_input(torch.device("cuda"))
            output = fsdp(*x)
            loss = fsdp.module.get_loss(x, output).to("cuda")
            fsdp.module.run_backward(loss)
            optim.step()

        # each worker saves a checkpoint with local_state_dict
        cp_data = {
            "weights": {k: v.cpu() for k, v in fsdp.local_state_dict().items()},
            "meta": fsdp.local_metadata_dict(),
        }
        torch.save(cp_data, paths[fsdp.rank])
        full_model_state_dict = fsdp.state_dict()
        torch.distributed.barrier()
        if fsdp.rank > 0:
            return
        all_checkpoints = [torch.load(p) for p in paths]
        consolidated_checkpoint = FullyShardedDataParallel.consolidate_shard_weights(
            shard_weights=[c["weights"] for c in all_checkpoints],
            shard_metadata=[c["meta"] for c in all_checkpoints],
        )
        full_model_extra = set(full_model_state_dict).difference(
            set(consolidated_checkpoint))
        consolidated_extra = set(consolidated_checkpoint).difference(
            set(full_model_state_dict))
        msg = f"full model extra keys: {full_model_extra}, consolidated extra {consolidated_extra}"
        assert set(full_model_state_dict.keys()) == set(consolidated_checkpoint.keys()), msg
        for k in full_model_state_dict.keys():
            assert consolidated_checkpoint[k].shape == full_model_state_dict[k].shape
Code Example #8
def test_consolidate_missing_params():
    """This tests that fairseq experts, which are saved independently from the rest of the model, can be consolidated."""
    desired_path = "decoder.layers.1.moe_layer.experts.0"
    shard_metadata = {
        "param_metadata": [
            {
                "fsdp_path": "",
                "params": {
                    "flat_param_0": {
                        "names": ["missing"],
                        "shapes": [(12, 4)],
                        "numels": [12 * 4],
                        "padding": 0
                    }
                },
                "no_broadcast_optim_state": False,
                "shared_param_info": [],
            },
            {
                "fsdp_path": desired_path,
                "params": {
                    "flat_param_0": {
                        "names": ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"],
                        "shapes": [(4, 4), (4,), (4, 4), (4,)],
                        "numels": [16, 4, 16, 4],
                        "padding": 0,
                    }
                },
                "no_broadcast_optim_state": True,
                "shared_param_info": [],
            },
        ],
        "buffer_names": ["missing.buffer"],
    }
    shard_weights = {
        "decoder.layers.1.moe_layer.experts.0.flat_param_0": torch.randn(40, dtype=torch.float16)
    }
    consolidated_weights = FullyShardedDataParallel.consolidate_shard_weights(
        [shard_weights], [shard_metadata], strict=False)
    assert len(consolidated_weights) == 4
    for k in consolidated_weights:
        assert k.startswith(desired_path), f"{k} doesn't start with {desired_path}"
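Under the metadata above, the consolidated dict maps fully qualified parameter names to un-flattened tensors; a sketch of what the assertions imply, with names and shapes taken directly from shard_metadata:

# consolidated_weights is expected to contain:
#   "decoder.layers.1.moe_layer.experts.0.fc1.weight" -> tensor of shape (4, 4)
#   "decoder.layers.1.moe_layer.experts.0.fc1.bias"   -> tensor of shape (4,)
#   "decoder.layers.1.moe_layer.experts.0.fc2.weight" -> tensor of shape (4, 4)
#   "decoder.layers.1.moe_layer.experts.0.fc2.bias"   -> tensor of shape (4,)
# i.e. the 40-element flat_param_0 (16 + 4 + 16 + 4, padding 0) split back into 4 named tensors.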
Code Example #9
    def _test_consolidated_optimizer(self,
                                     config,
                                     rank,
                                     group,
                                     optim_fn=torch.optim.SGD,
                                     transformer=False):
        """FSDP.gather_full_optim_state_dict() should return something very similar to optimizer.state_dict()"""
        # Establish reference behavior.

        if transformer:
            fsdp = self.get_wrapped_model(group, config=config).cuda()
            unwrapped_model = TransformerWithSharedParams(group).cuda()
        else:
            fsdp = FullyShardedDataParallel(
                NestedWrappedModule(group, wrapper_config=config), group,
                **config).cuda()
            unwrapped_model = NestedWrappedModule(group,
                                                  wrapper_config=None).cuda()

        try:
            fsdp_optim = optim_fn(
                fsdp.parameters(),
                lr=0.01,
            )
            optim_unwrapped = optim_fn(unwrapped_model.parameters(), lr=0.01)
        except TypeError:  # Adadelta
            fsdp_optim = optim_fn(fsdp.parameters())
            optim_unwrapped = optim_fn(unwrapped_model.parameters())

        fsdp_optim.zero_grad()
        optim_unwrapped.zero_grad()

        x = fsdp.module.get_input(torch.device("cuda"))
        output = fsdp(*x)
        loss = fsdp.module.get_loss(x, output).to("cuda")
        fsdp.module.run_backward(loss)
        fsdp_optim.step()

        output = unwrapped_model(*x)
        loss = unwrapped_model.get_loss(x, output)
        unwrapped_model.run_backward(loss)
        optim_unwrapped.step()
        unwrapped_sd = optim_unwrapped.state_dict()

        tstart = time()
        sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0)
        duration = time() - tstart
        # Switching from fairscale.optim.utils.broadcast_object to torch.broadcast_object_list will cause this to raise
        assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate"

        if fsdp.rank > 0:
            return

        assert_equal(len(sd["state"]), len(unwrapped_sd["state"]))
        assert_equal(len(sd["param_groups"][0]["params"]),
                     len(unwrapped_sd["param_groups"][0]["params"]))
        assert_equal(
            sum([first_tensor_numel(v) for k, v in sd["state"].items()]),
            sum([
                first_tensor_numel(v)
                for k, v in unwrapped_sd["state"].items()
            ]),
        )

        shard_sd = fsdp.get_shard_from_optim_state_dict(sd)

        original_shard_sd = fsdp_optim.state_dict()
        assert_equal(len(shard_sd["state"]), len(original_shard_sd["state"]))
        assert_equal(shard_sd.keys(), original_shard_sd.keys())
        original_shard_sd = recursive_copy_to_device(original_shard_sd,
                                                     non_blocking=False,
                                                     device="cpu")

        assert_equal(
            sum([first_tensor_numel(v) for k, v in shard_sd["state"].items()]),
            sum([
                first_tensor_numel(v)
                for k, v in original_shard_sd["state"].items()
            ]),
        )
        assert objects_are_equal(shard_sd, original_shard_sd)
Code Example #10
    def _test_consolidated_optimizer(self,
                                     config,
                                     rank,
                                     group,
                                     optim_fn=torch.optim.SGD,
                                     transformer=False):
        """FSDP.gather_full_optim_state_dict() should return something very similar to optimizer.state_dict()"""
        # Establish reference behavior.

        if transformer:
            unwrapped_model = TransformerWithSharedParams(
                group, wrapper_config=config).cuda()
            fsdp = self.get_wrapped_model(group, config=config).cuda()
        else:
            unwrapped_model = MixtureOfExperts(group,
                                               wrapper_config=None).cuda()
            fsdp = FullyShardedDataParallel(
                MixtureOfExperts(group, wrapper_config=config)).cuda()

        try:
            fsdp_optim = optim_fn(
                fsdp.parameters(),
                lr=0.01,
            )
            optim_unwrapped = optim_fn(unwrapped_model.parameters(), lr=0.01)
        except TypeError:  # Adadelta
            fsdp_optim = optim_fn(fsdp.parameters())
            optim_unwrapped = optim_fn(unwrapped_model.parameters())

        fsdp_optim.zero_grad()
        optim_unwrapped.zero_grad()
        with torch.cuda.amp.autocast(enabled=True):
            x = fsdp.module.get_input(torch.device("cuda"))
            output = fsdp(*x)
            loss = fsdp.module.get_loss(x, output).to("cuda")
            fsdp.module.run_backward(loss)
            fsdp_optim.step()

            output = unwrapped_model(*x)
            loss = unwrapped_model.get_loss(x, output)
            unwrapped_model.run_backward(loss)
            optim_unwrapped.step()
        unwrapped_sd = optim_unwrapped.state_dict()

        if not transformer:
            no_broadcast_children = [
                x for x in fsdp._fsdp_instances if x.no_broadcast_optim_state
            ]
            assert len(no_broadcast_children) == 1
            assert fsdp._fsdp_instances[-1].no_broadcast_optim_state
        torch.cuda.empty_cache()
        cuda_gb_before = torch.cuda.memory_stats(
            fsdp.rank)["allocated_bytes.all.current"] / 1024**3
        tstart = time()
        sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0)
        duration = time() - tstart
        assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate"

        cuda_gb_after = torch.cuda.memory_stats(
            fsdp.rank)["allocated_bytes.all.current"] / 1024**3
        mem_usg_gb = cuda_gb_after - cuda_gb_before
        assert mem_usg_gb == 0, f"gather_full_optim_state_dict used {mem_usg_gb:.2f} CUDA GB, max allowed is 0"
        assert cuda_gb_after > 0, "got 0 memory usage, logging is broken"

        if fsdp.rank > 0:
            assert sd is None
            return

        # assert whole state dict on CPU
        for k, v in sd["state"].items():
            for buffer_name, t in v.items():
                if torch.is_tensor(t):
                    msg = f"got device {t.device} for {k}: {buffer_name}. expected CPU"
                    assert t.device == torch.device("cpu"), msg

        unflat_state = sd["state"]
        assert "uncollected_local_ids" in sd
        shard_sd = fsdp.get_shard_from_optim_state_dict(sd)
        shard_sd = recursive_copy_to_device(shard_sd,
                                            non_blocking=False,
                                            device="cpu")
        state_after_get_shard = sd["state"]
        assert objects_are_equal(unflat_state,
                                 state_after_get_shard)  # no side effects.

        assert_equal(len(sd["state"]), len(unwrapped_sd["state"]))
        assert_equal(len(sd["param_groups"][0]["params"]),
                     len(unwrapped_sd["param_groups"][0]["params"]))
        assert_equal(
            sum([first_tensor_numel(v) for k, v in sd["state"].items()]),
            sum([
                first_tensor_numel(v)
                for k, v in unwrapped_sd["state"].items()
            ]),
        )

        original_shard_sd = fsdp_optim.state_dict()
        assert_equal(len(shard_sd["state"]), len(original_shard_sd["state"]))
        assert_equal(shard_sd.keys(), original_shard_sd.keys())
        original_shard_sd = recursive_copy_to_device(original_shard_sd,
                                                     non_blocking=False,
                                                     device="cpu")
        # Before asserting that the dicts are equal, we check keys individually to allow nice tracebacks.
        assert_equal(
            [first_tensor_numel(v) for k, v in shard_sd["state"].items()],
            [
                first_tensor_numel(v)
                for k, v in original_shard_sd["state"].items()
            ],
        )
        assert_equal(
            [v for k, v in shard_sd["param_groups"][0].items()],
            [v for k, v in original_shard_sd["param_groups"][0].items()],
        )
        assert objects_are_equal(shard_sd["state"], original_shard_sd["state"])
        assert objects_are_equal({k: shard_sd[k] for k in original_shard_sd},
                                 original_shard_sd)
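Distilled from the two tests above, a minimal save/restore sketch for FSDP optimizer state; the file name is hypothetical, and fsdp / fsdp_optim are assumed to be set up as in the tests:

# Gather the full optimizer state on rank 0 and save it.
full_osd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0)
if fsdp.rank == 0:
    torch.save(full_osd, "optim_state.torch")  # hypothetical path

# To restore, load the full state dict and extract this rank's shard.
full_osd = torch.load("optim_state.torch", map_location="cpu")
shard_osd = fsdp.get_shard_from_optim_state_dict(full_osd)
fsdp_optim.load_state_dict(shard_osd)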
Code Example #11
def _worker(gpu_id: int, sync_file: str, world_size: int, embedding_size: int, flatten_parameters: bool):
    torch.manual_seed(0)
    torch.cuda.set_device(gpu_id)
    torch.distributed.init_process_group(
        backend="nccl", init_method=f"file://{sync_file}", world_size=world_size, rank=gpu_id,
    )
    process_group = torch.distributed.new_group()

    # Create a dummy model with dummy inputs and targets
    batch_size = 4
    input = torch.randn(size=(batch_size, 3, 32, 32)).cuda()
    target = torch.zeros(size=(batch_size, embedding_size)).cuda()
    model = _create_model(
        with_fsdp=True,
        process_group=process_group,
        embedding_size=embedding_size,
        flatten_parameters=flatten_parameters,
    )
    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

    # Train the model for a few epochs
    for epoch in range(2):
        out = model(input)
        loss = criterion(out, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Save one checkpoint per shard
    cp_data = {
        "weights": {k: v.cpu() for k, v in model.local_state_dict().items()},
        "meta": model.local_metadata_dict(),
    }
    torch.save(cp_data, f"checkpoint_{gpu_id}.torch")

    # Wait for all files to be written to disk
    dist.barrier()  # type: ignore

    # Reconstruct a full checkpoint from the sharded checkpoints
    all_checkpoints = [_load_sharded_checkpoint(rank) for rank in range(world_size)]
    consolidated_checkpoint = FullyShardedDataParallel.consolidate_shard_weights(
        shard_weights=[c["weights"] for c in all_checkpoints], shard_metadata=[c["meta"] for c in all_checkpoints],
    )

    # Check that the reconstructed parameters are correct and of the right shape
    full_model = _create_model(with_fsdp=False, process_group=process_group, embedding_size=embedding_size)
    full_model_state_dict = full_model.state_dict()
    assert set(full_model_state_dict.keys()) == set(consolidated_checkpoint.keys())
    for k in full_model_state_dict.keys():
        assert consolidated_checkpoint[k].shape == full_model_state_dict[k].shape

    # Verify that the checkpoint can be loaded by an FSDP model
    loaded_model = _create_model(
        with_fsdp=True,
        process_group=process_group,
        embedding_size=embedding_size,
        flatten_parameters=flatten_parameters,
    )
    loaded_model.load_state_dict(consolidated_checkpoint)
    for m in loaded_model.modules():
        if isinstance(m, FullyShardedDataParallel):
            m._reset_lazy_init()

    # Verify that the model saved and the model loaded give the same results
    with torch.no_grad():
        before_checkpoint_loss = criterion(model(input), target).item()
        after_checkpoint_loss = criterion(loaded_model(input), target).item()
        assert before_checkpoint_loss == after_checkpoint_loss
Code Example #12
File: checkpoint.py  Project: QuentinDuval/vissl
 def _consolidate_shards(cls, weights: List[Dict[str, torch.Tensor]],
                         metadata: List[Dict[str, Any]]):
     logging.info("Consolidating shards...")
     return FullyShardedDataParallel.consolidate_shard_weights(
         weights, metadata)
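A minimal end-to-end sketch tying this helper to Examples #7 and #11: each rank saves {"weights": fsdp.local_state_dict(), "meta": fsdp.local_metadata_dict()}, then a single process rebuilds the full state dict. The checkpoint_{rank}.torch layout follows Example #11; the output path is hypothetical:

# Load every rank's sharded checkpoint from disk (one file per rank).
shard_files = [f"checkpoint_{rank}.torch" for rank in range(world_size)]
shards = [torch.load(path, map_location="cpu") for path in shard_files]

# Rebuild a consolidated, non-sharded state dict and save it.
consolidated = FullyShardedDataParallel.consolidate_shard_weights(
    shard_weights=[s["weights"] for s in shards],
    shard_metadata=[s["meta"] for s in shards],
)
torch.save(consolidated, "consolidated_checkpoint.torch")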
Code Example #13
def _layer_memory_tracking_fsdp_worker(gpu_id: int, sync_files: Tuple[str, str],
                                       world_size: int):
    dist_init(world_size=world_size,
              rank=gpu_id,
              filename=sync_files[0],
              filename_rpc=sync_files[1])
    torch.backends.cudnn.deterministic = True

    # Create different inputs on each GPU
    batch_size = 16
    torch.manual_seed(gpu_id)
    fake_inputs = torch.randn(size=(batch_size, 10)).cuda(gpu_id)
    fake_targets = torch.randn(size=(batch_size, 10)).cuda(gpu_id)
    fake_criterion = nn.MSELoss()

    # Create a global group and a tracker around it
    group = dist.new_group()
    group = ProcessGroupTracker(group)

    # Create a simple model
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    model = nn.Sequential(
        nn.Linear(10, 10).cuda(gpu_id),
        nn.ReLU(),
        FullyShardedDataParallel(
            nn.Linear(10, 10).cuda(gpu_id),
            flatten_parameters=False,
            process_group=group,
        ),
        nn.ReLU(),
        FullyShardedDataParallel(
            nn.Linear(10, 10).cuda(gpu_id),
            flatten_parameters=True,
            process_group=group,
        ),
    )
    model = model.cuda(gpu_id)
    dist_model = FullyShardedDataParallel(model,
                                          flatten_parameters=False,
                                          process_group=group)

    # Track the model on a forward / backward pass
    tracker = LayerwiseMemoryTracker()
    tracker.monitor(dist_model)
    fake_criterion(dist_model(fake_inputs), fake_targets).backward()
    tracker.stop()

    # Check the results of all-gather tracking (a feature specific to FSDP)
    all_gathered_traces = [
        (t.module_name, t.all_gathered, t.cumul_all_gathered)
        for t in tracker.memory_traces if t.all_gathered > 0
    ]
    assert all_gathered_traces == [
        ("_fsdp_wrapped_module._fpw_module.0", 440, 440),
        ("_fsdp_wrapped_module._fpw_module.2._fsdp_wrapped_module._fpw_module",
         440, 880),
        ("_fsdp_wrapped_module._fpw_module.4._fsdp_wrapped_module._fpw_module",
         440, 880),
        ("_fsdp_wrapped_module._fpw_module.4._fsdp_wrapped_module._fpw_module",
         440, 0),
        ("_fsdp_wrapped_module._fpw_module.2._fsdp_wrapped_module._fpw_module",
         440, 0),
    ], all_gathered_traces
Code Example #14
File: test_larc_fsdp.py  Project: zlapp/vissl
    def _norm_computation_worker(gpu_id: int, sync_file: str, world_size: int):
        init_distributed_on_file(world_size=world_size,
                                 gpu_id=gpu_id,
                                 sync_file=sync_file)
        torch.manual_seed(0)
        torch.backends.cudnn.deterministic = True

        num_iterations = 10
        batch_size = 128
        torch.manual_seed(gpu_id)
        fake_inputs = torch.randn(size=(num_iterations, batch_size, 129))
        fake_targets = torch.randn(size=(num_iterations, batch_size))

        losses = {}
        for with_fsdp in [False, True]:
            torch.manual_seed(0)
            torch.cuda.manual_seed(0)
            losses[with_fsdp] = []

            # Create a simple model
            model = nn.Sequential(nn.Linear(129, 128), nn.ReLU(),
                                  nn.Linear(128, 10))
            model = model.cuda(gpu_id)

            # Setting up FSDP vs DDP with LARC
            larc_config = {
                "clip": False,
                "trust_coefficient": 0.01,
                "eps": 0.00000001
            }
            optimizer = optim.SGD(model.parameters(),
                                  lr=1e-2,
                                  weight_decay=1e-4,
                                  momentum=0.9)
            if with_fsdp:
                model = FullyShardedDataParallel(model,
                                                 flatten_parameters=False)
                optimizer = LARC_FSDP(optimizer,
                                      distributed_norm=True,
                                      **larc_config)
            else:
                model = DistributedDataParallel(model, device_ids=[gpu_id])
                optimizer = LARC_FSDP(optimizer,
                                      distributed_norm=False,
                                      **larc_config)

            # Training loop
            criterion = nn.MSELoss()
            for iteration in range(num_iterations):
                fake_input = fake_inputs[iteration].cuda(gpu_id)
                fake_target = fake_targets[iteration].cuda(gpu_id)
                output = model(fake_input)
                loss = criterion(output.sum(axis=-1), fake_target)
                if gpu_id == 0:
                    losses[with_fsdp].append(loss.item())
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        if gpu_id == 0:
            for with_fsdp in [False, True]:
                print(losses[with_fsdp])
                if world_size > 1:
                    losses[with_fsdp] = [
                        round(loss, 5) for loss in losses[with_fsdp]
                    ]
            assert losses[False] == losses[True]
Code Example #15
    def _layer_memory_tracking_worker(gpu_id: int, sync_file: str,
                                      world_size: int):
        init_distributed_on_file(world_size=world_size,
                                 gpu_id=gpu_id,
                                 sync_file=sync_file)
        torch.manual_seed(0)
        torch.backends.cudnn.deterministic = True
        torch.manual_seed(gpu_id)

        batch_size = 16
        fake_inputs = torch.randn(size=(batch_size, 10)).cuda(gpu_id)
        fake_targets = torch.randn(size=(batch_size, 10)).cuda(gpu_id)
        fake_criterion = nn.MSELoss()

        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

        # Create a global group and a tracker around it
        group = dist.new_group()
        group = ProcessGroupTracker(group)

        # Create a simple model
        model = nn.Sequential(
            nn.Linear(10, 10).cuda(gpu_id),
            nn.ReLU(),
            FullyShardedDataParallel(
                nn.Linear(10, 10).cuda(gpu_id),
                flatten_parameters=False,
                process_group=group,
            ),
            nn.ReLU(),
            FullyShardedDataParallel(
                nn.Linear(10, 10).cuda(gpu_id),
                flatten_parameters=True,
                process_group=group,
            ),
        )
        model = model.cuda(gpu_id)
        model = FullyShardedDataParallel(model,
                                         flatten_parameters=False,
                                         process_group=group)

        # Setup the tracking of the model
        tracker = LayerwiseMemoryTracker()
        tracker.monitor(model)

        # Fake forward / backward pass
        fake_criterion(model(fake_inputs), fake_targets).backward()

        # Collect the results of all-gather tracking (the feature specific to FSDP)
        tracker.stop()
        all_gathered_traces = [
            (t.module_name, t.all_gathered, t.cumul_all_gathered)
            for t in tracker.memory_traces if t.all_gathered > 0
        ]
        assert all_gathered_traces == [
            ("_fsdp_wrapped_module.0", 440, 440),
            ("_fsdp_wrapped_module.2._fsdp_wrapped_module", 440, 880),
            ("_fsdp_wrapped_module.4._fsdp_wrapped_module._fpw_module", 440,
             880),
            ("_fsdp_wrapped_module.4._fsdp_wrapped_module._fpw_module", 440,
             0),
            ("_fsdp_wrapped_module.2._fsdp_wrapped_module", 440, 0),
        ]
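The tracker API used in the last two examples is not tied to FSDP wrapping; below is a minimal sketch, assuming the same imports and setup as the test above and assuming LayerwiseMemoryTracker also accepts plain (non-FSDP) modules. Only FSDP-wrapped layers would report non-zero all_gathered values:

plain_model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10)).cuda(gpu_id)
tracker = LayerwiseMemoryTracker()
tracker.monitor(plain_model)
fake_criterion(plain_model(fake_inputs), fake_targets).backward()
tracker.stop()

# Each trace carries per-layer memory info; all_gathered stays 0 without FSDP all-gathers.
for trace in tracker.memory_traces:
    print(trace.module_name, trace.all_gathered, trace.cumul_all_gathered)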