Example #1
    def __init__(self, group, wrap_fsdp, delay_before_free_ms=0):
        super().__init__(group, wrap_fsdp)
        self.group = group
        self.delay_before_free_ms = delay_before_free_ms
        self.wrap_fsdp = wrap_fsdp

        # "expert" params are different on each rank
        torch.manual_seed(42 + group.rank())
        d_expert = 23
        d_shared = 12
        d_input = 8
        expert = nn.Linear(d_expert, d_shared)

        self.num_expert_params = sum([p.numel() for p in expert.parameters()])
        for p in expert.parameters():
            p.expert = True  # type: ignore[attr-defined]

        # everything else is shared
        torch.manual_seed(0)

        shared = nn.Linear(d_shared, d_expert)

        if wrap_fsdp:
            # Create a process group of size 1 for the expert params.
            expert_group = torch.distributed.new_group(
                [group.rank()])  # world size 1 means no sharding
            expert = FullyShardedDataParallel(
                expert, expert_group)  # type: ignore[assignment]

            shared = FullyShardedDataParallel(
                shared, group)  # type: ignore[assignment]

        self.module = nn.Sequential(nn.Linear(d_input, d_shared), shared,
                                    expert, nn.Linear(d_shared, d_input))
Example #2
    def test_multiple_wrapping(self):
        """
        This test simulates wrapping the module after training to run inference.
        It covers the case where, later in a session, the model is wrapped in FSDP
        again even though it already contains nested FSDP wrappers.
        """
        inner_model = InnerModel()
        model = FSDP(inner_model).cuda()
        optim = SGD(model.parameters(), lr=0.1)

        for i in range(3):
            input = torch.rand((1, 5), dtype=torch.float).cuda()
            input.requires_grad = True
            output = model(input)
            output.sum().backward()
            optim.step()
            optim.zero_grad()
        input = torch.rand((1, 5), dtype=torch.float).cuda()
        output = model(input)

        # Rewrap the inner model in FSDP a second time; the output should match.
        rewrapped_model = FSDP(inner_model).cuda()
        rewrapped_output = rewrapped_model(input)

        self.assertEqual(output, rewrapped_output)
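
The test above references an InnerModel defined elsewhere in its file. Below is a minimal sketch consistent with the (1, 5) inputs used above, assuming the torch.nn Linear/Module/Sequential and FSDP imports used elsewhere in these examples; names and layer sizes are illustrative, not the original definition.

class InnerModel(Module):
    def __init__(self):
        super().__init__()
        # A nested FSDP wrapper inside a plain module is exactly what the
        # rewrapping test above exercises.
        self.layers = Sequential(FSDP(Linear(5, 5)))

    def forward(self, x):
        return self.layers(x)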
Example #3
    def test_input_type(self, input_cls):
        """Test FSDP with input being a list or a dict, only single GPU."""

        class Model(Module):
            def __init__(self):
                super().__init__()
                self.layer = Linear(4, 4)

            def forward(self, input):
                if isinstance(input, list):
                    input = input[0]
                else:
                    assert isinstance(input, dict), input
                    input = input["in"]
                return self.layer(input)

        model = FSDP(Model()).cuda()
        optim = SGD(model.parameters(), lr=0.1)

        for _ in range(5):
            in_data = torch.rand(64, 4).cuda()
            in_data.requires_grad = True
            if input_cls is list:
                in_data = [in_data]
            else:
                self.assertTrue(input_cls is dict)
                in_data = {"in": in_data}

            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()
Example #4
    def _get_wrapped_model(
        self, group, cuda_first=False, **model_kwargs
    ) -> FullyShardedDataParallel:
        if cuda_first:
            model = FullyShardedDataParallel(
                TransformerWithSharedParams(group, **model_kwargs).cuda(), group
            )
        else:
            model = FullyShardedDataParallel(
                TransformerWithSharedParams(group, **model_kwargs), group
            ).cuda()
        return model
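
A hedged usage sketch of the helper above, assuming an already-initialized default process group and a CUDA-capable host; both init orders should end up with a CUDA-resident FSDP model, only the point at which .cuda() runs differs.

    # Hypothetical call site (not part of the original test suite).
    group = dist.distributed_c10d._get_default_group()
    model = self._get_wrapped_model(group, cuda_first=False)
    self.assertEqual(next(model.parameters()).device.type, "cuda")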
Example #5
    def _dist_train(self, with_checkpoint, expected, model_hidden_dim,
                    iterations):
        gpu_id = self.rank
        world_size = self.world_size

        batch = torch.randn(size=(2, 3, 224, 224)).cuda()

        model = create_model(
            with_fsdp=True,
            with_checkpoint=with_checkpoint,
            model_hidden_dim=model_hidden_dim,
        )
        model = model.cuda()
        model = FSDP(model)

        # We enable momentum so that after the first iteration, the optimizer state is added
        # to the total memory used.
        criterion = nn.MSELoss()
        optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

        results = {}  # results of memory stats
        for iteration in range(iterations):
            get_cur_mem(gpu_id, results, f"iter {iteration}: start")

            out = model(batch)
            get_cur_mem(gpu_id, results, f"iter {iteration}: after fwd")

            out = sum(o.sum() for o in out[0])
            fake_loss = criterion(out, torch.tensor(0.0).cuda())
            get_cur_mem(gpu_id, results, f"iter {iteration}: after loss")

            fake_loss.backward()
            get_cur_mem(gpu_id, results, f"iter {iteration}: after bwd")

            optimizer.step()
            get_cur_mem(gpu_id, results, f"iter {iteration}: after step")

            # It is important to use `set_to_none=True` here rather than
            # optimizer.zero_grad(), so that gradient memory is actually reclaimed.
            model.zero_grad(set_to_none=True)
            get_cur_mem(gpu_id, results, f"iter {iteration}: done")

        def cmp(results, expected):
            ret = ""
            self.assertEqual(results.keys(), expected.keys())
            for k, v in results.items():
                exp = expected[k]
                if abs(exp - v) > 1:  # allow 1MB rounding differences
                    ret += f"{k}: got {v}, expected {exp}\n"
            return ret

        output = cmp(results, expected)
        self.assertEqual(output, "")
Example #6
    def _test_identical_outputs(self,
                                model_init_fn,
                                ref_ddp_fn=None,
                                num_steps=2,
                                use_cuda=True,
                                lr=0.01):
        group = dist.distributed_c10d._get_default_group()
        rank = group.rank()
        # Establish reference behavior with PyTorch DDP (+ optionally autocast).
        model = model_init_fn(group=group, wrap_fsdp=False).cuda()
        if ref_ddp_fn is None:
            model = nn.parallel.DistributedDataParallel(model,
                                                        device_ids=[rank],
                                                        output_device=rank)
        else:
            model = ref_ddp_fn(model)
        ref_loss = self._train_for_several_steps(model,
                                                 num_steps,
                                                 autocast=False,
                                                 lr=lr)
        ref_full_params = list(model.parameters())

        # Confirm we get the same behavior using FullyShardedDataParallel.
        model = model_init_fn(group=group, wrap_fsdp=True)
        model = FullyShardedDataParallel(model)
        if use_cuda:
            model = model.cuda()
        else:
            assert next(model.parameters()).device == torch.device(
                "cpu"
            ), "module parameters should be placed on cpu if use_cuda is False."
        shard_loss = self._train_for_several_steps(model,
                                                   num_steps,
                                                   autocast=False,
                                                   lr=lr)
        get_full_params(model)
        shard_full_params = list(model.parameters())

        torch.testing.assert_allclose(ref_loss, shard_loss)
        self.assertEqual(
            ref_full_params,
            shard_full_params,
            exact_device=True,
            msg="FullyShardedDataParallel didn't match PyTorch DDP",
        )
Example #7
    def _dist_train(
        self, with_nested_trunk, freezing_method, freeze_after_wrap_fsdp, with_fsdp
    ):
        torch.manual_seed(0)
        batch = torch.randn(size=(2, 3, 224, 224)).cuda()

        model = self._create_model(with_fsdp, with_nested_trunk, freeze_after_wrap_fsdp)
        model = model.cuda()

        # Freeze the trunk by setting requires_grad to False on its parameters.
        if freezing_method == FreezingMethod.RequiresGrad:
            for param in model.trunk.parameters():
                param.requires_grad = False

        if with_fsdp:
            if not freeze_after_wrap_fsdp:
                model.fsdp_wrap()
            model = FSDP(model)
        else:
            model = DistributedDataParallel(model, device_ids=[self.rank])

        target = torch.tensor([0, 1], dtype=torch.long).cuda()
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

        for iteration in range(3):
            out = model(batch)
            fake_loss = criterion(out, target)
            optimizer.zero_grad()
            fake_loss.backward()
            if freezing_method == FreezingMethod.GradToNone:
                if with_fsdp:
                    for param in model.module.module.trunk.parameters():
                        param.grad = None
                else:
                    for param in model.module.trunk.parameters():
                        param.grad = None
            optimizer.step()

        if with_fsdp:
            get_full_params(model)

        return list(model.parameters())
Example #8
def _maybe_wrap_fsdp(model, wrap_fsdp, *args, **kwargs):
    return (
        model if not wrap_fsdp
        else FullyShardedDataParallel(model, *args, **kwargs)
    )
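
A minimal usage sketch, assuming a toy nn.Linear module and an already-initialized default process group; the same call site yields either the plain module or its FSDP wrapper depending on the flag.

# Hypothetical call sites: the reference (unwrapped) and FSDP code paths stay
# identical except for the wrap_fsdp flag.
ref_model = _maybe_wrap_fsdp(nn.Linear(4, 4), wrap_fsdp=False)
fsdp_model = _maybe_wrap_fsdp(nn.Linear(4, 4).cuda(), wrap_fsdp=True)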
Example #9
    def _test_identical_outputs(
        self,
        model_init_fn,
        *args,
        ref_ddp_fn=None,
        num_steps=2,
        fsdp_init_mode=FSDPInitMode.CUDA_AFTER,
        lr=0.01,
        cpu_offload=CPUOffload(),
        backward_prefetch=None,
        **kwargs
    ):
        group = dist.distributed_c10d._get_default_group()
        rank = group.rank()
        # Establish reference behavior with PyTorch DDP (+ optionally autocast).
        model = model_init_fn(group=group, wrap_fsdp=False).cuda()
        if ref_ddp_fn is None:
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[rank], output_device=rank
            )
        else:
            model = ref_ddp_fn(model)
        ref_loss = self._train_for_several_steps(
            model, num_steps, autocast=False, lr=lr, fsdp_cpu_offload=cpu_offload
        )
        ref_full_params = list(model.parameters())

        # Confirm we get the same behavior using FullyShardedDataParallel.
        try:
            model = model_init_fn(
                group=group,
                wrap_fsdp=True,
                fsdp_init_mode=fsdp_init_mode,
                cpu_offload=cpu_offload,
                backward_prefetch=backward_prefetch
            )
        except Exception as e:
            raise ValueError(f"model_Init_fn {model_init_fn} got error {str(e)}")

        cpu_offload = cpu_offload or CPUOffload()  # disabled if not specified.
        model = FullyShardedDataParallel(model, cpu_offload=cpu_offload, backward_prefetch=backward_prefetch)
        # Call model.cuda() after init FSDP if specified.
        if fsdp_init_mode == FSDPInitMode.CUDA_AFTER:
            model = model.cuda()

        # Note that we skip this check for FSDPInitMode.CUDA_AFTER, since in the
        # offload-params case we expect FSDP to raise the error that is checked
        # below.
        if fsdp_init_mode != FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params:
            for p in model.parameters():
                # Should be on CPU regardless of whether the param is sharded.
                self.assertEqual(p.device, torch.device("cpu"), f"Mismatch, cpu offload is {cpu_offload}")

        only_check_err = fsdp_init_mode == FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params
        ctx = (
            self.assertRaisesRegex(AssertionError, "Expected param to be on CPU")
            if only_check_err else suppress()
        )
        with ctx:
            shard_loss = self._train_for_several_steps(
                model, num_steps, autocast=False, lr=lr,
                fsdp_cpu_offload=cpu_offload,
            )
        # We only check for the error in the following setup:
        #   model = FSDP(model, cpu_offload=True)
        #   model = model.cuda()
        # so skip the rest of the checks in that case.
        if only_check_err:
            return
        # With CPU offload, the next call moves model params to GPU, so sanity
        # check that they are still on CPU beforehand.
        if cpu_offload.offload_params:
            device_set = {p.device for p in model.parameters()}
            self.assertEqual(
                {torch.device("cpu")},
                device_set,
                f"Got device set {device_set}"
            )
        get_full_params(model)
        shard_full_params = list(model.parameters())

        if cpu_offload.offload_params:
            shard_loss = shard_loss.cuda()
        torch.testing.assert_allclose(ref_loss, shard_loss)
        self.assertEqual(
            ref_full_params,
            shard_full_params,
            exact_device=True,
            msg="FullyShardedDataParallel didn't match PyTorch DDP",
        )
Example #10
    def _maybe_wrap(layer):
        if wrap_fsdp:
            return FullyShardedDataParallel(layer, group, *args, **kwargs)
        return layer
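
A short sketch of how such a closure is typically used, assuming wrap_fsdp, group, and any extra FSDP arguments are captured from the enclosing scope (the enclosing function is not part of the example):

    # Hypothetical enclosing scope: wrap each layer individually so that nested
    # FSDP instances end up inside a plain nn.Sequential.
    module = nn.Sequential(
        _maybe_wrap(nn.Linear(8, 4)),
        nn.ReLU(),
        _maybe_wrap(nn.Linear(4, 8)),
    )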