예제 #1
0
    def _test_identical_outputs(
        self,
        model_init_fn,
        ref_ddp_fn=None,
        num_steps=2,
        use_cuda=True,
        lr=0.01,
    ):
        """Check that an FSDP-wrapped model trains identically to a reference.

        A reference model is built, moved to CUDA, wrapped (by default in
        ``DistributedDataParallel``) and trained for ``num_steps``. A second
        model is then built, wrapped in ``FullyShardedDataParallel`` and
        trained with the same settings. The two losses and the two parameter
        lists are asserted to match.

        Args:
            model_init_fn: Callable invoked as
                ``model_init_fn(group=..., wrap_fsdp=<bool>)`` that returns
                the module to train.
            ref_ddp_fn: Optional callable that wraps the reference model.
                When ``None`` the reference is wrapped in
                ``nn.parallel.DistributedDataParallel`` pinned to this rank.
            num_steps: Forwarded to ``_train_for_several_steps``.
            use_cuda: When true the FSDP model is moved to CUDA after
                wrapping; otherwise its first parameter is asserted to be
                on the CPU.
            lr: Learning rate forwarded to ``_train_for_several_steps``.
        """
        process_group = dist.distributed_c10d._get_default_group()
        rank = process_group.rank()
        # Both training runs must use exactly the same settings.
        train_kwargs = dict(autocast=False, lr=lr)

        # --- Reference run -------------------------------------------------
        ref_model = model_init_fn(group=process_group, wrap_fsdp=False).cuda()
        ref_model = (
            ref_ddp_fn(ref_model)
            if ref_ddp_fn is not None
            else nn.parallel.DistributedDataParallel(
                ref_model, device_ids=[rank], output_device=rank
            )
        )
        ref_loss = self._train_for_several_steps(
            ref_model, num_steps, **train_kwargs
        )
        ref_full_params = list(ref_model.parameters())

        # --- FullyShardedDataParallel run ----------------------------------
        fsdp_model = FullyShardedDataParallel(
            model_init_fn(group=process_group, wrap_fsdp=True)
        )
        if not use_cuda:
            first_param_device = next(fsdp_model.parameters()).device
            assert first_param_device == torch.device(
                "cpu"
            ), "module parameters should be placed on cpu if use_cuda is False."
        else:
            fsdp_model = fsdp_model.cuda()
        shard_loss = self._train_for_several_steps(
            fsdp_model, num_steps, **train_kwargs
        )
        # The return value is unused; presumably this rebuilds the unsharded
        # parameters on the model in place -- confirm against the helper.
        get_full_params(fsdp_model)
        shard_full_params = list(fsdp_model.parameters())

        # --- Compare -------------------------------------------------------
        torch.testing.assert_allclose(ref_loss, shard_loss)
        self.assertEqual(
            ref_full_params,
            shard_full_params,
            exact_device=True,
            msg="FullyShardedDataParallel didn't match PyTorch DDP",
        )
예제 #2
0
    def _test_identical_outputs(
        self,
        model_init_fn,
        *args,
        ref_ddp_fn=None,
        num_steps=2,
        fsdp_init_mode=FSDPInitMode.CUDA_AFTER,
        lr=0.01,
        cpu_offload=CPUOffload(),
        backward_prefetch=None,
        **kwargs
    ):
        """Check that an FSDP-wrapped model trains identically to a reference.

        A reference model is built, moved to CUDA, wrapped (by default in
        ``DistributedDataParallel``) and trained for ``num_steps``. A second
        model is built and wrapped in ``FullyShardedDataParallel`` with the
        requested offload / prefetch settings and trained the same way. The
        losses and the full parameter lists are then asserted to match.

        When ``fsdp_init_mode`` is ``FSDPInitMode.CUDA_AFTER`` and
        ``cpu_offload.offload_params`` is truthy, training is instead expected
        to raise ``AssertionError`` matching "Expected param to be on CPU",
        and the method returns as soon as that has been verified.

        Args:
            model_init_fn: Callable that returns the module to train. It is
                called with ``group`` and ``wrap_fsdp`` keywords, and for the
                FSDP run additionally with ``fsdp_init_mode``,
                ``cpu_offload`` and ``backward_prefetch``.
            *args: Accepted but not used by this method.
            ref_ddp_fn: Optional callable that wraps the reference model.
                When ``None`` the reference is wrapped in
                ``nn.parallel.DistributedDataParallel`` pinned to this rank.
            num_steps: Forwarded to ``_train_for_several_steps``.
            fsdp_init_mode: Forwarded to ``model_init_fn``. Here it is only
                compared against ``FSDPInitMode.CUDA_AFTER``, in which case
                ``.cuda()`` is called on the model after FSDP wrapping.
            lr: Learning rate forwarded to ``_train_for_several_steps``.
            cpu_offload: Offload configuration. A falsy value is replaced by
                ``CPUOffload()``, but only after ``model_init_fn`` and the
                reference training run have received the original value.
                The default instance is created once at definition time and
                shared between calls.
            backward_prefetch: Forwarded to ``model_init_fn`` and to the
                ``FullyShardedDataParallel`` constructor.
            **kwargs: Accepted but not used by this method.

        Raises:
            ValueError: If ``model_init_fn`` raises while building the FSDP
                model. The original exception is attached as ``__cause__``.
        """
        group = dist.distributed_c10d._get_default_group()
        rank = group.rank()
        # Establish reference behavior with PyTorch DDP.
        model = model_init_fn(group=group, wrap_fsdp=False).cuda()
        if ref_ddp_fn is None:
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[rank], output_device=rank
            )
        else:
            model = ref_ddp_fn(model)
        ref_loss = self._train_for_several_steps(
            model, num_steps, autocast=False, lr=lr, fsdp_cpu_offload=cpu_offload
        )
        ref_full_params = list(model.parameters())

        # Confirm we get the same behavior using FullyShardedDataParallel.
        try:
            model = model_init_fn(
                group=group,
                wrap_fsdp=True,
                fsdp_init_mode=fsdp_init_mode,
                cpu_offload=cpu_offload,
                backward_prefetch=backward_prefetch,
            )
        except Exception as e:
            # Chain explicitly so the traceback shows the original failure as
            # the direct cause of this ValueError.
            raise ValueError(
                f"model_Init_fn {model_init_fn} got error {str(e)}"
            ) from e

        cpu_offload = cpu_offload or CPUOffload()  # disabled if not specified.
        model = FullyShardedDataParallel(
            model, cpu_offload=cpu_offload, backward_prefetch=backward_prefetch
        )

        cuda_after = fsdp_init_mode == FSDPInitMode.CUDA_AFTER
        # Call model.cuda() after init FSDP if specified.
        if cuda_after:
            model = model.cuda()

        # Note that we don't do this check for FSDPInitMode.CUDA_AFTER since we
        # expect FSDP code to raise error that we check below, in the case of
        # offload params.
        if not cuda_after and cpu_offload.offload_params:
            for p in model.parameters():
                # Should be on CPU regardless of if param is sharded.
                self.assertEqual(
                    p.device,
                    torch.device("cpu"),
                    f"Mismatch, cpu offload is {cpu_offload}",
                )

        # FSDP(model, cpu_offload=...) followed by model.cuda() is expected to
        # fail during training; in that setup only the error is verified.
        only_check_err = cuda_after and cpu_offload.offload_params
        ctx = (
            self.assertRaisesRegex(AssertionError, "Expected param to be on CPU")
            if only_check_err
            else suppress()  # no exception types given: acts as a no-op context
        )
        with ctx:
            shard_loss = self._train_for_several_steps(
                model,
                num_steps,
                autocast=False,
                lr=lr,
                fsdp_cpu_offload=cpu_offload,
            )
        # We only check for errors in the case we have the following setup:
        # model = FSDP(model, cpu_offload=True)
        # model = model.cuda()
        # so skip the rest of this logic.
        if only_check_err:
            return

        # If CPU offload, next call will change model params to GPU. Sanity
        # check that params are on CPU before.
        if cpu_offload.offload_params:
            device_set = {p.device for p in model.parameters()}
            self.assertEqual(
                {torch.device("cpu")},
                device_set,
                f"Got device set {device_set}",
            )
        get_full_params(model)
        shard_full_params = list(model.parameters())

        if cpu_offload.offload_params:
            # Move the loss to CUDA so it can be compared with the reference
            # loss, whose model was placed on CUDA above.
            shard_loss = shard_loss.cuda()
        torch.testing.assert_allclose(ref_loss, shard_loss)
        self.assertEqual(
            ref_full_params,
            shard_full_params,
            exact_device=True,
            msg="FullyShardedDataParallel didn't match PyTorch DDP",
        )