コード例 #1
0
ファイル: common_fsdp.py プロジェクト: standbyme/pytorch
    def __init__(self, group, wrap_fsdp, delay_before_free_ms=0):
        """Assemble a small linear stack containing one rank-specific layer.

        ``self.module`` chains four linear layers: an input projection
        (8 -> 12), a "shared" layer (12 -> 23), an "expert" layer (23 -> 12)
        and an output projection (12 -> 8). The expert layer is initialised
        under a seed derived from ``group.rank()``; every other layer is
        initialised after the seed has been reset to 0.

        Args:
            group: Process group. ``group.rank()`` seeds the expert layer
                and, when wrapping, is the only member of the expert's
                process group.
            wrap_fsdp: When true, wrap the expert and shared layers in
                ``FullyShardedDataParallel``.
            delay_before_free_ms: Stored on the instance; not used further
                by this constructor.
        """
        super().__init__(group, wrap_fsdp)
        self.group = group
        self.delay_before_free_ms = delay_before_free_ms
        self.wrap_fsdp = wrap_fsdp

        d_expert, d_shared, d_input = 23, 12, 8

        # Rank-dependent seed: the expert weights differ from rank to rank.
        torch.manual_seed(42 + group.rank())
        expert_layer = nn.Linear(d_expert, d_shared)

        # Tag each expert parameter and total its element count in one pass.
        expert_numel = 0
        for param in expert_layer.parameters():
            param.expert = True  # type: ignore[attr-defined]
            expert_numel += param.numel()
        self.num_expert_params = expert_numel

        # Common seed: all remaining layers are created from the same state.
        torch.manual_seed(0)
        shared_layer = nn.Linear(d_shared, d_expert)

        if wrap_fsdp:
            # A single-member process group (world size 1) means the expert
            # parameters are not sharded.
            solo_group = torch.distributed.new_group([group.rank()])
            expert_layer = FullyShardedDataParallel(
                expert_layer, solo_group)  # type: ignore[assignment]
            shared_layer = FullyShardedDataParallel(
                shared_layer, group)  # type: ignore[assignment]

        # The projections are built last, input first, so they consume the
        # RNG in the same order as the layers appear in the sequence.
        in_proj = nn.Linear(d_input, d_shared)
        out_proj = nn.Linear(d_shared, d_input)
        self.module = nn.Sequential(
            in_proj, shared_layer, expert_layer, out_proj)
コード例 #2
0
 def _get_wrapped_model(
     self, group, cuda_first=False, **model_kwargs
 ) -> FullyShardedDataParallel:
     """Return a ``TransformerWithSharedParams`` wrapped in FSDP on CUDA.

     Args:
         group: Passed both to the inner model and to the FSDP wrapper.
         cuda_first: If true, ``.cuda()`` is called on the inner model
             before wrapping; otherwise it is called on the FSDP wrapper
             after wrapping.
         **model_kwargs: Forwarded to ``TransformerWithSharedParams``.

     Returns:
         The FSDP-wrapped model.
     """
     inner = TransformerWithSharedParams(group, **model_kwargs)
     if cuda_first:
         # Move the bare module first, then wrap it.
         return FullyShardedDataParallel(inner.cuda(), group)
     # Wrap first, then move the wrapper.
     return FullyShardedDataParallel(inner, group).cuda()
コード例 #3
0
ファイル: common_fsdp.py プロジェクト: standbyme/pytorch
    def _test_identical_outputs(self,
                                model_init_fn,
                                ref_ddp_fn=None,
                                num_steps=2,
                                use_cuda=True,
                                lr=0.01):
        """Check that an FSDP-wrapped model trains like a reference model.

        The model is built twice via ``model_init_fn``: once unwrapped (then
        wrapped by DDP or ``ref_ddp_fn``) and once with ``wrap_fsdp=True``
        (then wrapped in an outer ``FullyShardedDataParallel``). Both are
        trained with ``self._train_for_several_steps`` and their losses and
        parameters are compared.

        Args:
            model_init_fn: Callable invoked with ``group=`` and
                ``wrap_fsdp=`` keyword arguments; returns the module.
            ref_ddp_fn: Optional callable that wraps the reference model.
                When ``None``, ``DistributedDataParallel`` is used with the
                current rank as device.
            num_steps: Forwarded to ``_train_for_several_steps``.
            use_cuda: If true, the FSDP model is moved with ``.cuda()``;
                otherwise its first parameter is asserted to be on CPU.
            lr: Forwarded to ``_train_for_several_steps``.
        """
        pg = dist.distributed_c10d._get_default_group()
        rank = pg.rank()

        # Reference run: plain module on GPU, wrapped by DDP or the
        # caller-supplied wrapper.
        net = model_init_fn(group=pg, wrap_fsdp=False).cuda()
        if ref_ddp_fn is not None:
            net = ref_ddp_fn(net)
        else:
            net = nn.parallel.DistributedDataParallel(
                net, device_ids=[rank], output_device=rank
            )
        ref_loss = self._train_for_several_steps(
            net, num_steps, autocast=False, lr=lr
        )
        ref_full_params = list(net.parameters())

        # Candidate run: same initializer, this time with FSDP wrapping.
        net = model_init_fn(group=pg, wrap_fsdp=True)
        net = FullyShardedDataParallel(net)
        if not use_cuda:
            first_param_device = next(net.parameters()).device
            assert first_param_device == torch.device(
                "cpu"
            ), "module parameters should be placed on cpu if use_cuda is False."
        else:
            net = net.cuda()
        shard_loss = self._train_for_several_steps(
            net, num_steps, autocast=False, lr=lr
        )
        # NOTE(review): presumably makes the full (unsharded) parameters
        # visible on the model before they are listed -- confirm against
        # the helper's definition.
        get_full_params(net)
        shard_full_params = list(net.parameters())

        torch.testing.assert_allclose(ref_loss, shard_loss)
        self.assertEqual(
            ref_full_params,
            shard_full_params,
            exact_device=True,
            msg="FullyShardedDataParallel didn't match PyTorch DDP",
        )
コード例 #4
0
def _maybe_wrap_fsdp(model, wrap_fsdp, *args, **kwargs):
    return (
        model if not wrap_fsdp
        else FullyShardedDataParallel(model, *args, **kwargs)
    )
コード例 #5
0
    def _test_identical_outputs(
        self,
        model_init_fn,
        *args,
        ref_ddp_fn=None,
        num_steps=2,
        fsdp_init_mode=FSDPInitMode.CUDA_AFTER,
        lr=0.01,
        cpu_offload=CPUOffload(),
        backward_prefetch=None,
        **kwargs
    ):
        """Check that an FSDP-wrapped model trains like a reference model.

        The model is built twice via ``model_init_fn``: once unwrapped (then
        wrapped by DDP or ``ref_ddp_fn``) and once with ``wrap_fsdp=True``
        (then wrapped in an outer ``FullyShardedDataParallel``). Both are
        trained with ``self._train_for_several_steps`` and their losses and
        parameters are compared.

        When ``fsdp_init_mode`` is ``FSDPInitMode.CUDA_AFTER`` and
        ``cpu_offload.offload_params`` is set, training is only expected to
        raise ``AssertionError`` matching "Expected param to be on CPU", and
        the method returns without comparing anything.

        Args:
            model_init_fn: Callable invoked with ``group=`` and
                ``wrap_fsdp=`` keyword arguments (plus ``fsdp_init_mode=``,
                ``cpu_offload=`` and ``backward_prefetch=`` for the FSDP
                build); returns the module.
            *args: Accepted but not used in this method body.
            ref_ddp_fn: Optional callable that wraps the reference model.
                When ``None``, ``DistributedDataParallel`` is used with the
                current rank as device.
            num_steps: Forwarded to ``_train_for_several_steps``.
            fsdp_init_mode: ``FSDPInitMode`` value; with ``CUDA_AFTER`` the
                outer FSDP model is moved with ``.cuda()`` after wrapping.
            lr: Forwarded to ``_train_for_several_steps``.
            cpu_offload: CPU offload configuration. A falsy value is
                replaced by ``CPUOffload()`` before the outer FSDP wrap.
            backward_prefetch: Forwarded to ``model_init_fn`` and to the
                outer ``FullyShardedDataParallel``.
            **kwargs: Accepted but not used in this method body.

        Raises:
            ValueError: If ``model_init_fn`` raises while building the FSDP
                variant; the original exception is attached as the cause.
        """
        group = dist.distributed_c10d._get_default_group()
        rank = group.rank()
        # Establish reference behavior with PyTorch DDP (+ optionally autocast).
        model = model_init_fn(group=group, wrap_fsdp=False).cuda()
        if ref_ddp_fn is None:
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[rank], output_device=rank
            )
        else:
            model = ref_ddp_fn(model)
        # NOTE(review): cpu_offload is used here and in model_init_fn below
        # before it is normalized with `or CPUOffload()`; a caller passing
        # None forwards None to both -- confirm that is intended.
        ref_loss = self._train_for_several_steps(
            model, num_steps, autocast=False, lr=lr, fsdp_cpu_offload=cpu_offload
        )
        ref_full_params = list(model.parameters())

        # Confirm we get the same behavior using FullyShardedDataParallel.
        try:
            model = model_init_fn(
                group=group,
                wrap_fsdp=True,
                fsdp_init_mode=fsdp_init_mode,
                cpu_offload=cpu_offload,
                backward_prefetch=backward_prefetch
            )
        except Exception as e:
            # Chain explicitly so the traceback marks the initializer's
            # failure as the direct cause of this error.
            raise ValueError(
                f"model_Init_fn {model_init_fn} got error {str(e)}"
            ) from e

        cpu_offload = cpu_offload or CPUOffload()  # disabled if not specified.
        model = FullyShardedDataParallel(model, cpu_offload=cpu_offload, backward_prefetch=backward_prefetch)
        # Call model.cuda() after init FSDP if specified.
        if fsdp_init_mode == FSDPInitMode.CUDA_AFTER:
            model = model.cuda()

        # Note that we don't do this check for FSDPInitMode.CUDA_AFTER since we
        # expect FSDP code to raise error that we check below, in the case of
        # offload params.
        if fsdp_init_mode != FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params:
            for p in model.parameters():
                # Should be on CPU regardless of if param is sharded.
                self.assertEqual(p.device, torch.device("cpu"), f"Mismatch, cpu offload is {cpu_offload}")

        only_check_err = fsdp_init_mode == FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params
        # Either expect the specific assertion failure, or use a no-op
        # context (suppress() with no arguments suppresses nothing).
        ctx = (
            self.assertRaisesRegex(AssertionError, "Expected param to be on CPU")
            if only_check_err else suppress()
        )
        with ctx:
            shard_loss = self._train_for_several_steps(
                model, num_steps, autocast=False, lr=lr,
                fsdp_cpu_offload=cpu_offload,
            )
        # We only check for errors in the case we have the following setup:
        # model = FSDP(model, cpu_offload=True)
        # model = model.cuda()
        # so skip the rest of this logic.
        if only_check_err:
            return
        # If CPU offload, next call will change model params to GPU. Sanity
        # check that params are on CPU before.
        if cpu_offload.offload_params:
            device_set = {p.device for p in model.parameters()}
            self.assertEqual(
                {torch.device("cpu")},
                device_set,
                f"Got device set {device_set}"
            )
        get_full_params(model)
        shard_full_params = list(model.parameters())

        if cpu_offload.offload_params:
            # The reference loss is on GPU; move the offloaded loss there
            # before comparing.
            shard_loss = shard_loss.cuda()
        torch.testing.assert_allclose(ref_loss, shard_loss)
        self.assertEqual(
            ref_full_params,
            shard_full_params,
            exact_device=True,
            msg="FullyShardedDataParallel didn't match PyTorch DDP",
        )
コード例 #6
0
 def _maybe_wrap(layer):
     """Return ``layer`` wrapped in FSDP when the enclosing ``wrap_fsdp`` is set.

     ``wrap_fsdp``, ``group``, ``args`` and ``kwargs`` are read from the
     surrounding scope; when wrapping, ``group`` and the extra arguments
     are forwarded to ``FullyShardedDataParallel``.
     """
     if not wrap_fsdp:
         return layer
     return FullyShardedDataParallel(layer, group, *args, **kwargs)