def test_summon_full_params_equivalence(self, rank0_only, offload_to_cpu):
        offload = CPUOffload(offload_params=True)
        model = FSDP(DeterministicModel(wrap_fsdp=True, cpu_offload=offload),
                     cpu_offload=offload)
        local_model = DeterministicModel(wrap_fsdp=False)

        dev = (torch.device("cpu") if offload_to_cpu else torch.device(
            "cuda", torch.cuda.current_device()))

        params_to_compare = ([
            p.clone() for p in model.parameters()
        ] if rank0_only and self.rank != 0 else list(local_model.parameters()))

        with model.summon_full_params(
                model,
                recurse=True,
                rank0_only=rank0_only,
                writeback=not rank0_only,
                offload_to_cpu=offload_to_cpu,
        ):
            # The sleep below triggers a failure if summon_full_params does
            # not synchronize streams correctly.
            torch.cuda._sleep(1000000)
            # deepcopy() of FSDP parameters has issues, so clone instead
            fsdp_params = [p.clone() for p in model.parameters()]

        self.assertEqual(fsdp_params, params_to_compare)
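For orientation, here is a minimal sketch of the summon_full_params pattern the test above exercises. It assumes torch.distributed is already initialized with a CUDA-capable backend; the plain nn.Linear module and helper name are illustrative only.

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def snapshot_full_params(fsdp_model):
    # Inside the context the sharded flat parameters are temporarily replaced
    # by the full, unsharded parameters, so clones taken here have the global
    # shapes; on exit the local shards are restored.
    with FSDP.summon_full_params(fsdp_model, recurse=True, writeback=False):
        return [p.detach().clone() for p in fsdp_model.parameters()]


# Usage on each rank: full_params = snapshot_full_params(FSDP(nn.Linear(4, 4).cuda()))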
Example #2
    def test_summon_full_param_writeback(self, writeback, modify_outer):
        return _run_test_summon_full_param_writeback(
            self,
            writeback,
            cpu_offload=CPUOffload(offload_params=False),
            modify_outer=modify_outer,
        )
Example #3
    def _dist_train(self,
                    wrap_fsdp,
                    cpu_offload=CPUOffload(offload_params=False)):
        # keep everything deterministic for input data
        torch.manual_seed(0)

        model = Model(wrap_fsdp, cpu_offload)
        if wrap_fsdp:
            model = FSDP(model, cpu_offload=cpu_offload)
        else:
            model = DistributedDataParallel(model, device_ids=[self.rank])
        model.half()
        optim = SGD(model.parameters(), lr=0.1)

        in_data = torch.rand(16, 2).cuda().half()
        in_data.requires_grad = True
        for _ in range(1):
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        if wrap_fsdp:
            get_full_params(model)

        return list(model.parameters())
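The helper above toggles CPU offloading through its cpu_offload argument. Below is a minimal sketch of that configuration in isolation, assuming a default process group has already been initialized; the tiny Linear module is illustrative only.

import torch
import torch.nn as nn
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP

# offload_params=True keeps each rank's parameter shard (and its gradients)
# in CPU memory between uses; FSDP moves data to the GPU only around
# computation and communication.
fsdp_model = FSDP(nn.Linear(2, 2).cuda(),
                  cpu_offload=CPUOffload(offload_params=True))
# The locally held shards now report a CPU device (see the equivalence tests).
assert all(p.device == torch.device("cpu") for p in fsdp_model.parameters())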
Example #4
    def __init__(self, wrap_fsdp, cpu_offload=CPUOffload(offload_params=False)):
        super().__init__()
        # keep everything deterministic for model initialization
        torch.manual_seed(0)
        self.inner = torch.nn.Linear(2, 2).cuda()
        if wrap_fsdp:
            self.inner = FullyShardedDataParallel(self.inner, cpu_offload=cpu_offload)
        self.outer = torch.nn.Linear(2, 2).cuda()
Example #5
    def test_mixed_precision_e2e_full_shard(self):
        mp = default_mp if not nccl_supports_bf16 else mp_diff_buffer_and_reduce
        self._run_test_mixed_precision_e2e(
            mp_config=mp,
            cpu_offload=CPUOffload(offload_params=True),
            backward_prefetch=None,
            full_precision_param_dtype=torch.float64,
            sharding_strategy=ShardingStrategy.FULL_SHARD,
        )
Example #6
    def test_summon_full_param_writeback(self, writeback, modify_outer,
                                         mixed_precision):
        mixed_precision = MixedPrecision() if mixed_precision else None
        return _run_test_summon_full_param_writeback(
            self,
            writeback,
            modify_outer=modify_outer,
            cpu_offload=CPUOffload(offload_params=False),
            mixed_precision=mixed_precision,
        )
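The mixed_precision flag above expands to a MixedPrecision config. Here is a minimal sketch of that config on its own, assuming an initialized process group; the fp16 choice mirrors the e2e tests in this collection.

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

# Forward/backward compute, gradient reduction, and buffers run in fp16,
# while the sharded parameters themselves are kept in full precision.
mp_config = MixedPrecision(
    param_dtype=torch.float16,
    reduce_dtype=torch.float16,
    buffer_dtype=torch.float16,
)
fsdp_model = FSDP(nn.Linear(8, 8).cuda(), mixed_precision=mp_config)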
Example #7
class TestPureFP16(FSDPTest):

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    def test_pure_fp16(self, cpu_offload: CPUOffload):
        """Tests pure FP16 training, including when the parameter's dtype is
        changed after FSDP initialization and before training."""
        self._test_fsdp_parity(
            NestedWrappedModule,
            FSDPInitMode.RECURSIVE,
            cuda_init_mode=CUDAInitMode.CUDA_AFTER,
            # Run one iteration to avoid NaN without a gradient scaler
            num_iters=1,
            cpu_offload=cpu_offload,
            use_pure_fp16=True,
        )
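A sketch of the pure-FP16 pattern that test_pure_fp16 covers, where the parameter dtype is changed after FSDP initialization and before training; it assumes an initialized process group and, as in the test, runs only a single iteration since there is no gradient scaler.

import torch
import torch.nn as nn
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP

fsdp_model = FSDP(nn.Linear(2, 2).cuda(),
                  cpu_offload=CPUOffload(offload_params=False))
fsdp_model.half()  # cast parameters to fp16 after wrapping
optim = torch.optim.SGD(fsdp_model.parameters(), lr=0.1)
inp = torch.rand(4, 2, device="cuda", dtype=torch.float16)
fsdp_model(inp).sum().backward()
optim.step()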
Example #8
class TestPureFP16(FSDPTest):
    def _dist_train(self,
                    wrap_fsdp,
                    cpu_offload=CPUOffload(offload_params=False)):
        # keep everything deterministic for input data
        torch.manual_seed(0)

        model = Model(wrap_fsdp, cpu_offload)
        if wrap_fsdp:
            model = FSDP(model, cpu_offload=cpu_offload)
        else:
            model = DistributedDataParallel(model, device_ids=[self.rank])
        model.half()
        optim = SGD(model.parameters(), lr=0.1)

        in_data = torch.rand(16, 2).cuda().half()
        in_data.requires_grad = True
        for _ in range(1):
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        if wrap_fsdp:
            get_full_params(model)

        return list(model.parameters())

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True),
         CPUOffload(offload_params=False)],
    )
    def test_pure_fp16(self, cpu_offload):
        # DDP
        ddp_state = self._dist_train(wrap_fsdp=False)

        # FSDP
        fsdp_state = self._dist_train(wrap_fsdp=True, cpu_offload=cpu_offload)

        self.assertEqual(ddp_state, fsdp_state)
Example #9
    def test_mixed_precision_no_reshard_after_forward(self):
        # Note that we don't exercise all possible different configs so as to
        # not increase test TTS too much.
        mp = default_mp if not nccl_supports_bf16 else mp_diff_buffer_and_reduce
        self._run_test_mixed_precision_e2e(
            mp_config=mp,
            cpu_offload=CPUOffload(offload_params=True),
            backward_prefetch=None,
            full_precision_param_dtype=torch.float64,
            sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
        )
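For reference, a minimal sketch of how the sharding-strategy and prefetch knobs used above are passed directly to FSDP; it assumes an initialized process group, and the particular values and layer sizes here are illustrative rather than the exact ones the test uses.

import torch.nn as nn
from torch.distributed.fsdp import (
    BackwardPrefetch,
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
)

# SHARD_GRAD_OP shards gradients and optimizer state but does not reshard
# parameters between forward and backward, which is the "no reshard after
# forward" behavior checked above; BACKWARD_PRE prefetches the next set of
# parameters before the current gradient computation.
fsdp_model = FSDP(
    nn.Linear(8, 8).cuda(),
    sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    cpu_offload=CPUOffload(offload_params=False),
)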
Example #10
class TestSummonFullParamsNoShard(FSDPTest):
    @property
    def world_size(self):
        return 1  # does not shard

    @skip_if_lt_x_gpu(2)
    @parametrize("writeback", [True, False])
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True),
         CPUOffload(offload_params=False)],
    )
    @parametrize("modify_outer", [True, False])
    def test_summon_full_param_writeback(self, writeback, cpu_offload,
                                         modify_outer):
        return _run_test_summon_full_param_writeback(
            self,
            writeback,
            cpu_offload,
            modify_outer,
        )
Example #11
    def test_summon_full_params_equivalence(self):
        offload = CPUOffload(offload_params=True)
        model = FSDP(DeterministicModel(wrap_fsdp=True, cpu_offload=offload),
                     cpu_offload=offload)
        local_model = DeterministicModel(wrap_fsdp=False)

        with model.summon_full_params(recurse=True):
            # The sleep below triggers a failure if summon_full_params does
            # not synchronize streams correctly.
            torch.cuda._sleep(1000000)
            fsdp_params = deepcopy(list(model.parameters()))

        self.assertEqual(fsdp_params, list(local_model.parameters()))
Example #12
    def test_summon_full_params_equivalence(self, rank0_only, offload_to_cpu):
        offload = CPUOffload(offload_params=True)
        model = FSDP(
            DeterministicModel(wrap_fsdp=True, cpu_offload=offload), cpu_offload=offload
        )
        local_model = DeterministicModel(wrap_fsdp=False)

        params_to_compare = (
            [p.clone() for p in model.parameters()]
            if rank0_only and self.rank != 0
            else list(local_model.parameters())
        )

        writeback = not rank0_only

        with model.summon_full_params(
            model,
            recurse=True,
            rank0_only=rank0_only,
            writeback=writeback,
            offload_to_cpu=offload_to_cpu,
        ):
            if writeback:
                with torch.no_grad():
                    for p in model.parameters():
                        p.add_(1)
                    for p in params_to_compare:
                        p.add_(1)
            # The sleep below triggers a failure if summon_full_params does
            # not synchronize streams correctly.
            torch.cuda._sleep(1000000)
            # deepcopy() of FSDP parameters has issues, so clone instead
            fsdp_params = [p.clone() for p in model.parameters()]

        self.assertEqual(fsdp_params, params_to_compare)

        # CPU offload is enabled for main API, so we should point back to CPU
        for param in model.parameters():
            self.assertEqual(param.device, torch.device("cpu"))
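A condensed sketch of the rank0_only / offload_to_cpu combination this test parametrizes, assuming torch.distributed is initialized; note that the tests above always pair rank0_only=True with writeback=False.

import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

fsdp_model = FSDP(nn.Linear(4, 4).cuda())
with FSDP.summon_full_params(
    fsdp_model,
    rank0_only=True,      # only rank 0 gathers the full parameters
    writeback=False,      # the tests above disable writeback when rank0_only is set
    offload_to_cpu=True,  # the gathered full parameters live in CPU memory
):
    if dist.get_rank() == 0:
        full_params = [p.detach().clone() for p in fsdp_model.parameters()]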
Example #13
class TestFSDPStateDict(FSDPTest):
    @property
    def world_size(self):
        return 2

    def _broadcast_state_dict(self, state_dict):
        olist = [state_dict if self.rank == 0 else None]
        dist.broadcast_object_list(olist)
        return olist[0]

    def _compare_models(self, model, model_new, assert_fn, check_fp16=False):
        with FullyShardedDataParallel.summon_full_params(model):
            with FullyShardedDataParallel.summon_full_params(model_new):
                params = list(model.parameters())
                params_new = list(model_new.parameters())
                assert_fn(params, params_new)
                if check_fp16:
                    for tensor in model_new.parameters():
                        self.assertEqual(tensor.dtype, torch.float16)

    def _get_simple_nested_model(self,
                                 *fsdp_args,
                                 wrap=True,
                                 checkpoint_wrap=False,
                                 **fsdp_kwargs):
        if wrap:
            lin1 = nn.Linear(10, 10, bias=False).cuda()
            lin2 = nn.Linear(10, 10, bias=False).cuda()
            if checkpoint_wrap:
                lin1 = checkpoint_wrapper(lin1)
                lin2 = checkpoint_wrapper(lin2)
            seq = nn.Sequential(FSDP(lin1, *fsdp_args, **fsdp_kwargs), lin2)
            if checkpoint_wrap:
                seq = checkpoint_wrapper(seq)
            model = FSDP(seq, *fsdp_args, **fsdp_kwargs)
        else:
            model = nn.Sequential(
                nn.Linear(10, 10, bias=False).cuda(),
                nn.Linear(10, 10, bias=False).cuda())
        return model

    def _get_simple_model(self,
                          *fsdp_args,
                          checkpoint_wrap=False,
                          **fsdp_kwargs):
        lin = nn.Linear(10, 10, bias=False).cuda()
        if checkpoint_wrap:
            lin = checkpoint_wrapper(lin)
        model = FSDP(lin, *fsdp_args, **fsdp_kwargs)
        return model

    def _get_non_fsdp_root_module(self, *fsdp_args, wrap=True, **fsdp_kwargs):
        class FSDPContainer(nn.Module):
            def __init__(self, fsdp_1, fsdp_2):
                super().__init__()
                self.non_fsdp_lin = nn.Linear(10, 10, bias=False).cuda()
                self.fsdp_1 = fsdp_1
                self.fsdp_2 = fsdp_2

            def forward(self, x):
                x = self.non_fsdp_lin(x)
                x = self.fsdp_1(x)
                x = self.fsdp_2(x)
                return x

        return FSDPContainer(
            self._get_simple_nested_model(*fsdp_args, wrap=wrap,
                                          **fsdp_kwargs),
            self._get_simple_nested_model(*fsdp_args, wrap=wrap,
                                          **fsdp_kwargs),
        )

    def _get_state_dict_mgr(
        self,
        model: nn.Module,
        state_dict_type: str,
        state_dict_rank0_and_offload: bool,
    ):
        _state_dict_type = STATE_DICT_MAPPING[state_dict_type]
        if state_dict_type == "state_dict":
            config = FullStateDictConfig(
                rank0_only=state_dict_rank0_and_offload,
                offload_to_cpu=state_dict_rank0_and_offload,
            )
        else:
            config = None
        return FSDP.state_dict_type(model, _state_dict_type, config)

    def _validate_state_dict_contents(self,
                                      model,
                                      fsdp_state_dict,
                                      state_dict_rank0_and_offload,
                                      ignore_keys=None):
        if state_dict_rank0_and_offload:
            if self.rank == 0:
                self.assertNotEqual(fsdp_state_dict, {})
                for key, tensor in fsdp_state_dict.items():
                    if ignore_keys and key in ignore_keys:
                        continue
                    self.assertEqual(
                        tensor.device,
                        torch.device("cpu"),
                        f"{key} is unexpectedly on device {tensor.device}",
                    )
            else:
                # For non-FSDP roots, the non-FSDP portion can still have
                # parameters in the state dict on nonzero ranks, so bypass the
                # check for now.
                if isinstance(model, FSDP):
                    self.assertEqual(fsdp_state_dict, {})

    @skip_if_lt_x_gpu(2)
    @parametrize("checkpoint_wrap", ["first", "second", "both"])
    def test_fsdp_state_dict_with_activation_checkpoint(self, checkpoint_wrap):
        """Tests saving the state dict, zeroing a target model's parameters, and
        loading the state dict, where the source and target models may have a
        checkpoint wrapper."""
        for model_call in [
                partial(self._get_simple_model),
                partial(self._get_simple_nested_model)
        ]:
            model = model_call(
                checkpoint_wrap=(checkpoint_wrap in ["first", "both"]))
            state_dict = _get_state_dict(model, False, False)
            # Possibly wrap new model in activation checkpoint wrapper to test save/
            # load with this wrapper
            model_new = model_call(
                checkpoint_wrap=(checkpoint_wrap in ["second", "both"]))
            _zero_model(model_new)
            self._compare_models(model, model_new, self.assertNotEqual)
            # Would fail if checkpoint_wrapper did not correctly implement state_dict pre/post hooks
            model_new.load_state_dict(state_dict, strict=True)
            self._compare_models(model, model_new, self.assertEqual)

    @skip_if_lt_x_gpu(2)
    def test_state_dict_rank0_offload_save_load_flow(self):
        """Tests saving a model checkpoint only on rank 0 and loading it only
        on rank 0 with ``sync_module_states=True`` to emulate the workflow to
        avoid redundant CPU memory usage."""
        auto_wrap_policy = partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={
                TransformerEncoderLayer, TransformerDecoderLayer
            },
        )
        fsdp_kwargs = {"auto_wrap_policy": auto_wrap_policy}
        fsdp_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs,
        )
        # Force model parameters and buffers to be nonzero
        with FSDP.summon_full_params(fsdp_model):
            for tensor in itertools.chain(fsdp_model.parameters(),
                                          fsdp_model.buffers()):
                if torch.count_nonzero(tensor) == 0:
                    with torch.no_grad():
                        tensor.add_(
                            torch.tensor(1,
                                         dtype=tensor.dtype,
                                         device=tensor.device))
        with self._get_state_dict_mgr(fsdp_model, "state_dict", True):
            state_dict = deepcopy(_get_state_dict(fsdp_model))
        # Initialize a non-wrapped model on all ranks
        new_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
        )
        _zero_model(new_model, zero_buffers=True)
        # Only load the checkpoint on rank 0
        if self.rank == 0:
            new_model.load_state_dict(state_dict, strict=True)
        _assert_module_states(
            new_model,
            process_group=self.process_group,
            assert_fn=self.assertNotEqual,
        )
        # Broadcast the module states from rank 0 with `sync_module_states=True`
        new_fsdp_model = FSDP(
            new_model,
            device_id=torch.cuda.current_device(),
            auto_wrap_policy=auto_wrap_policy,
            sync_module_states=True,
        )
        # Check FSDP models are equal across ranks
        with FSDP.summon_full_params(new_fsdp_model):
            _assert_module_states(
                new_fsdp_model,
                process_group=self.process_group,
                assert_fn=self.assertEqual,
            )
        # Check FSDP models correctly loaded the checkpoint
        with FullyShardedDataParallel.summon_full_params(fsdp_model):
            with FullyShardedDataParallel.summon_full_params(new_fsdp_model):
                params = list(fsdp_model.parameters())
                params_new = list(new_fsdp_model.parameters())
                self.assertEqual(params, params_new)

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True),
         CPUOffload(offload_params=False)],
    )
    @parametrize("fp16", [True, False])
    @parametrize("state_dict_rank0_and_offload", [True, False])
    def test_basic_save_and_load_state_dict(self, state_dict_type, cpu_offload,
                                            fp16,
                                            state_dict_rank0_and_offload):
        """
        Tests that we can save a state_dict and load it into a blank model
        with various configs such as fp16 and cpu offload and parameters
        match as expected.
        """
        if state_dict_rank0_and_offload and state_dict_type != "state_dict":
            return
        for model_call in [
                partial(self._get_non_fsdp_root_module,
                        cpu_offload=cpu_offload),
                partial(self._get_simple_nested_model,
                        cpu_offload=cpu_offload),
                partial(self._get_simple_model, cpu_offload=cpu_offload),
        ]:
            model = model_call()

            ctx = self._get_state_dict_mgr(model, state_dict_type,
                                           state_dict_rank0_and_offload)
            with ctx:
                fsdp_state_dict = _get_state_dict(model,
                                                  cpu_offload.offload_params,
                                                  fp16)

            ignore_keys = [
                k for k in fsdp_state_dict.keys() if NON_ROOT_FSDP_PREFIX in k
            ]

            self._validate_state_dict_contents(
                model,
                fsdp_state_dict,
                state_dict_rank0_and_offload,
                ignore_keys=ignore_keys,
            )
            if fp16:
                # Verify that the saved tensors are fp16
                for tensor in fsdp_state_dict.values():
                    self.assertEqual(tensor.dtype, torch.float16)

            model_new = model_call()
            if not cpu_offload.offload_params:
                model_new = model_new.cuda()
            if fp16:
                model_new.half()

            # zero the model to ensure parameters are different.
            _zero_model(model_new)
            self._compare_models(model, model_new, self.assertNotEqual)

            # Verify parameters are the same in the new model.
            if state_dict_rank0_and_offload:
                # Broadcast the state dict and move it back to GPU in
                # preparation for loading.
                if not isinstance(model, FSDP):
                    # Move everything to CPU to avoid running into
                    # https://github.com/pytorch/pytorch/issues/77113, some params
                    # will still be on GPU for non FSDP root modules.
                    for k in fsdp_state_dict.keys():
                        fsdp_state_dict[k] = fsdp_state_dict[k].cpu()
                fsdp_state_dict = self._broadcast_state_dict(fsdp_state_dict)
                for key in fsdp_state_dict.keys():
                    fsdp_state_dict[key] = fsdp_state_dict[key].cuda()
            with FSDP.state_dict_type(model_new,
                                      STATE_DICT_MAPPING[state_dict_type]):
                model_new.load_state_dict(fsdp_state_dict, strict=True)

            self._compare_models(model,
                                 model_new,
                                 self.assertEqual,
                                 check_fp16=fp16)

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    @parametrize("mixed_precision", [True, False])
    @parametrize("state_dict_rank0_and_offload", [True, False])
    def test_save_and_load_after_forward_state_dict(
            self, state_dict_type, mixed_precision,
            state_dict_rank0_and_offload):
        """
        Test that saving after some training results in params being updated as
        expected.
        """
        if state_dict_rank0_and_offload and state_dict_type != "state_dict":
            return
        torch.cuda.set_device(self.rank)
        mixed_precision = (MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
        ) if mixed_precision else None)
        model = self._get_simple_nested_model(mixed_precision=mixed_precision)
        optim = torch.optim.SGD(model.parameters(), lr=0.1)
        initial_params = get_full_params(model)
        for _ in range(6):
            inp = torch.randn(1, 10, device=torch.cuda.current_device())
            output = model(*inp)
            loss = output.sum()
            expected_dtype = torch.float32 if mixed_precision is None else torch.float16
            self.assertEqual(expected_dtype, loss.dtype)
            loss.backward()
            optim.step()

        trained_params = get_full_params(model)
        # Ensure some training occurred
        self.assertNotEqual(initial_params, trained_params)
        # Save a copy of the state_dict
        fsd_mgr = self._get_state_dict_mgr(model, state_dict_type,
                                           state_dict_rank0_and_offload)
        with fsd_mgr:
            state_dict = model.state_dict()
            if state_dict_type == "state_dict":
                state_dict = {k: v.clone() for k, v in state_dict.items()}
            else:
                for sharded_tensor in state_dict.values():
                    shard = sharded_tensor._local_shards[0]
                    shard.tensor = shard.tensor.clone().detach_()
        self._validate_state_dict_contents(model, state_dict,
                                           state_dict_rank0_and_offload)
        _zero_model(model)

        # Ensure checkpointed params have the full param dtype
        for tensor in state_dict.values():
            self.assertEqual(tensor.dtype, torch.float32)

        # Load state_dict into zeroed model
        if state_dict_rank0_and_offload:
            # Broadcast the state dict and move it back to GPU in
            # preparation for loading.
            state_dict = self._broadcast_state_dict(state_dict)
            for key in state_dict.keys():
                state_dict[key] = state_dict[key].cuda()

        with FSDP.state_dict_type(model, STATE_DICT_MAPPING[state_dict_type]):
            model.load_state_dict(state_dict, strict=True)
        loaded_params = get_full_params(model)
        self.assertEqual(loaded_params, trained_params)

    def _initialize_model(
        self,
        wrap_fsdp: bool,
        wrap_ddp: bool = True,
        register_buffers: bool = False,
    ):
        # keep everything deterministic for input data
        torch.manual_seed(0)

        model = Model(wrap_fsdp, register_buffers=register_buffers).cuda()
        if wrap_fsdp:
            model = FSDP(model)
        elif wrap_ddp:
            model = DistributedDataParallel(model, device_ids=[self.rank])
        return model

    @staticmethod
    def _state_dict(model: Module, state_dict_type: str):
        try:
            enum_val = STATE_DICT_MAPPING[state_dict_type]
        except KeyError:
            raise ValueError(f"No state_dict type for {state_dict_type}")

        with FSDP.state_dict_type(model, enum_val):
            return model.state_dict()

    @staticmethod
    def _load_state_dict(model: Module, state_dict_type: str,
                         state_dict: Dict[str, Any]):
        try:
            enum_val = STATE_DICT_MAPPING[state_dict_type]
        except KeyError:
            raise ValueError(f"No state_dict for {state_dict_type}")

        with FSDP.state_dict_type(model, enum_val):
            return model.load_state_dict(state_dict, strict=True)

    def _dist_train(self, wrap_fsdp: bool, state_dict_type: str = ""):
        # TODO: Move this test to common_fsdp.
        model = self._initialize_model(wrap_fsdp)
        optim = SGD(model.parameters(), lr=0.1)

        in_data = torch.rand(64,
                             4,
                             requires_grad=True,
                             device=torch.device("cuda"))
        for _ in range(3):
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        if wrap_fsdp:
            blank_model = FSDP(Model(True).cuda())
            _zero_model(blank_model)
            state_dict = self._state_dict(model, state_dict_type)
            self._load_state_dict(blank_model, state_dict_type, state_dict)
            return get_full_params(blank_model)
        else:
            return list(model.parameters())

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    def test_state_dict_save_load_flow(self, state_dict_type):
        fsdp_params = self._dist_train(wrap_fsdp=True,
                                       state_dict_type=state_dict_type)
        ddp_params = self._dist_train(wrap_fsdp=False)
        self.assertEqual(ddp_params, fsdp_params)

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    def test_fsdp_state_dict_keys(self, state_dict_type):
        state_dict = self._state_dict(self._initialize_model(True),
                                      state_dict_type)
        if state_dict_type == "local_state_dict":
            self.assertEqual(set(["flat_param", "inner.flat_param"]),
                             state_dict.keys())
        elif state_dict_type in ("state_dict", "sharded_state_dict"):
            # Keys should match local model.
            local_model = self._initialize_model(wrap_fsdp=False,
                                                 wrap_ddp=False)
            local_keys = local_model.state_dict().keys()
            self.assertEqual(state_dict.keys(), local_keys)
        else:
            raise NotImplementedError(f"No test for {state_dict_type}!")

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _UNFLATTENED_STATE_DICT_IMPLS)
    @parametrize("state_dict_rank0_and_offload", [True, False])
    @parametrize("fsdp_root", [True, False])
    def test_state_dict_load_into_local_module(
        self,
        state_dict_type,
        state_dict_rank0_and_offload,
        fsdp_root,
    ):
        """
        Tests that FSDP's state_dict can be loaded into a local model.
        """
        if state_dict_rank0_and_offload and state_dict_type != "state_dict":
            return
        if not fsdp_root:
            model = self._get_non_fsdp_root_module()
        else:
            model = self._initialize_model(wrap_fsdp=True,
                                           register_buffers=True)
        optim = SGD(model.parameters(), lr=0.1)
        if not fsdp_root:
            in_data = torch.randn(1,
                                  10,
                                  requires_grad=True,
                                  device=torch.device("cuda"))
        else:
            in_data = torch.rand(64,
                                 4,
                                 requires_grad=True,
                                 device=torch.device("cuda"))
        for _ in range(3):
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        with FullyShardedDataParallel.summon_full_params(model):
            fsdp_params = deepcopy(list(model.parameters()))

        # get FSDP state_dict. Note that by default we return full_state_dict.
        sd_mgr = self._get_state_dict_mgr(model, state_dict_type,
                                          state_dict_rank0_and_offload)
        with sd_mgr:
            fsdp_state_dict = model.state_dict()

        ignore_keys = [
            k for k in fsdp_state_dict.keys() if NON_ROOT_FSDP_PREFIX in k
        ]
        self._validate_state_dict_contents(
            model,
            fsdp_state_dict,
            state_dict_rank0_and_offload,
            ignore_keys=ignore_keys,
        )
        # Create zeroed local model
        if not fsdp_root:
            blank_local_model = self._get_non_fsdp_root_module(wrap=False)
        else:
            blank_local_model = self._initialize_model(wrap_fsdp=False,
                                                       wrap_ddp=False,
                                                       register_buffers=True)

        # Nothing should be FSDP
        for mod in blank_local_model.modules():
            self.assertFalse(isinstance(mod, FSDP))

        for param in blank_local_model.parameters():
            with torch.no_grad():
                param.zero_()

        fsdp_state_dict = _gather_state_dict(fsdp_state_dict)

        # Load fsdp's full state dict into the local and verify params are as
        # expected.
        if state_dict_rank0_and_offload:
            # Broadcast + CUDA state_dict
            if not isinstance(model, FSDP):
                # Some portions of the model on rank 0 might not be on CPU,
                # move everything to CPU to avoid running into
                # https://github.com/pytorch/pytorch/issues/77113.
                for k, t in fsdp_state_dict.items():
                    if t.device != torch.device("cpu"):
                        fsdp_state_dict[k] = t.cpu()
            fsdp_state_dict = self._broadcast_state_dict(fsdp_state_dict)
            for key in fsdp_state_dict.keys():
                fsdp_state_dict[key] = fsdp_state_dict[key].cuda()

        # if self.rank == 0:
        blank_local_model.load_state_dict(fsdp_state_dict, strict=True)
        local_params = list(blank_local_model.parameters())
        for fsdp_param, local_param in zip(fsdp_params, local_params):
            self.assertEqual(fsdp_param, local_param)

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    @parametrize("double_nest", [True])
    def test_state_dict_skip_module(self, state_dict_type, double_nest):
        torch.cuda.set_device(self.rank)

        def _create_module(wrap_fsdp=True):
            LINEAR_SKIP = "linear_skip"
            ctx = enable_wrap(wrapper_cls=FSDP) if wrap_fsdp else suppress()
            with ctx:
                module = SkipModel(double_nest=double_nest)
                # Full name of linear_skip param tensors in SkipModel, as would be
                # stored in checkpoint.
                linear_skip_tensor_names = [
                    k for k in dict(module.named_parameters()).keys()
                    if LINEAR_SKIP in k
                ]
                # skip SkipModule
                linear_skip = getattr(module, LINEAR_SKIP)
                delattr(module, LINEAR_SKIP)
                # Wrap FSDP
                fsdp = wrap(module)
                # reattach
                setattr(module, LINEAR_SKIP, linear_skip)
                return fsdp, linear_skip_tensor_names

        fsdp, linear_skip_tensor_names = _create_module()
        # Run a forward pass
        inp = torch.randn((1, 10), device=torch.cuda.current_device())
        loss = fsdp(inp)
        loss.sum().backward()

        with FSDP.state_dict_type(fsdp, STATE_DICT_MAPPING[state_dict_type]):
            state_dict = fsdp.state_dict()
        if self.rank == 0 and state_dict_type != "local_state_dict":
            sd_keys = list(state_dict.keys())
            expected = list(SkipModel(double_nest=False).state_dict().keys())
            self.assertEqual(sorted(sd_keys), sorted(expected))
            # TODO: parameters in linear_skip_tensor_names should not be handled
            # by FSDP.state_dict(). Have a check once this is implemented in
            # FSDP.state_dict().

        # Check that it can be loaded into FSDP.
        new_fsdp, _ = _create_module()
        _zero_model(new_fsdp)
        for (p1, p2) in zip(fsdp.parameters(), new_fsdp.parameters()):
            self.assertNotEqual(p1, p2)
        with FSDP.state_dict_type(new_fsdp,
                                  STATE_DICT_MAPPING[state_dict_type]):
            if state_dict_type != "local_state_dict":
                # FlatParameter has not supported deepcopy yet.
                state_dict = deepcopy(state_dict)
            new_fsdp.load_state_dict(state_dict, strict=True)
        for (p1, p2) in zip(fsdp.parameters(), new_fsdp.parameters()):
            self.assertEqual(p1, p2)

        # Test that the checkpoint can be loaded into a local model.
        local, _ = _create_module(wrap_fsdp=False)
        for param in local.parameters():
            with torch.no_grad():
                param.zero_()

        with fsdp.summon_full_params(fsdp):
            for (p1, p2) in zip(fsdp.parameters(), local.parameters()):
                self.assertNotEqual(p1, p2)

        if state_dict_type == "local_state_dict":
            return
        state_dict = _gather_state_dict(state_dict)
        with fsdp.summon_full_params(fsdp):
            if self.rank == 0:
                local.load_state_dict(state_dict, strict=True)
                for (p1, p2) in zip(fsdp.parameters(), local.parameters()):
                    self.assertEqual(p1, p2)

    @skip_if_lt_x_gpu(2)
    def test_wrong_state_dict_config(self):
        model = FSDP(Model(wrap_fsdp=True).cuda())
        with self.assertRaisesRegex(RuntimeError,
                                    "Expected state_dict_config of type"):
            with model.state_dict_type(model, StateDictType.FULL_STATE_DICT,
                                       LocalStateDictConfig()):
                pass

    @skip_if_lt_x_gpu(2)
    @parametrize("prefix", [True, False])
    @parametrize("ignore_inner", [True, False])
    def test_state_dict_with_ignored_modules(self, prefix, ignore_inner):
        # Initialize an FSDP-wrapped model with an ignored module that includes
        # both parameters and a buffer
        model = Model(wrap_fsdp=True,
                      register_buffers=True,
                      ignore_inner=ignore_inner).cuda()
        ignored_modules = [model.outer]
        ignored_tensor_to_tensor_name = {
            model.outer.bias: "outer.bias",
            model.outer.weight: "outer.weight",
        }
        if ignore_inner:
            ignored_tensor_to_tensor_name = {
                **ignored_tensor_to_tensor_name,
                model.inner.bias: "inner.bias",
                model.inner.weight: "inner.weight",
            }
        # Note that when model.inner is not ignored this test also ensures
        # non-ignored buffers are not cloned.
        buffer_to_buffer_name = {
            model.inner.buffer: "inner.buffer",
            model.outer.buffer: "outer.buffer",
        }
        fsdp_model = FSDP(model, ignored_modules=ignored_modules)
        prefix_str = "foo." if prefix else ""
        with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
            sd1 = fsdp_model.state_dict(prefix=prefix_str)
        with FSDP.summon_full_params(fsdp_model):
            fsdp_params = deepcopy(list(fsdp_model.parameters()))
        # Check that the ignored parameters and all buffers are not cloned
        for tensor, tensor_name in {
                **ignored_tensor_to_tensor_name,
                **buffer_to_buffer_name,
        }.items():
            prefixed_tensor_name = f"{prefix_str}{tensor_name}"
            self.assertTrue(prefixed_tensor_name in sd1)
            self.assertEqual(tensor.data_ptr(),
                             sd1[prefixed_tensor_name].data_ptr(),
                             f"{prefixed_tensor_name}")
        # Check that the state dict can be loaded into a non-wrapped version of
        # the model
        nonwrapped_model = Model(wrap_fsdp=False, register_buffers=True).cuda()
        for param in nonwrapped_model.parameters():
            with torch.no_grad():
                param.zero_()

        to_load = {k[len(prefix_str):]: v for k, v in sd1.items()}
        nonwrapped_model.load_state_dict(to_load, strict=True)
        local_params = list(nonwrapped_model.parameters())
        for fsdp_param, local_param in zip(fsdp_params, local_params):
            self.assertEqual(fsdp_param, local_param)
        # Check that if we save a state dict again, the ignored parameters and
        # buffer still have the same data pointer
        with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
            sd2 = fsdp_model.state_dict(prefix=prefix_str)
        for tensor, tensor_name in {
                **ignored_tensor_to_tensor_name,
                **buffer_to_buffer_name,
        }.items():
            prefixed_tensor_name = f"{prefix_str}{tensor_name}"
            self.assertTrue(prefixed_tensor_name in sd2)
            self.assertEqual(tensor.data_ptr(),
                             sd2[prefixed_tensor_name].data_ptr())
            self.assertEqual(sd1[prefixed_tensor_name].data_ptr(),
                             sd2[prefixed_tensor_name].data_ptr())

    @skip_if_lt_x_gpu(2)
    def test_state_dict_type(self):
        module = SkipModel(double_nest=True)
        with enable_wrap(wrapper_cls=FSDP):
            fsdp = wrap(module)
        with FSDP.state_dict_type(fsdp, StateDictType.LOCAL_STATE_DICT):
            pass
        for module in FSDP.fsdp_modules(fsdp):
            self.assertEqual(module._state_dict_type,
                             StateDictType.FULL_STATE_DICT)
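A condensed sketch of the save/load flow that test_state_dict_rank0_offload_save_load_flow above walks through: the full state dict is gathered on rank 0 only and offloaded to CPU, loaded on rank 0 into a fresh local module, and then broadcast to the other ranks via sync_module_states when re-wrapping. It assumes an initialized process group; the checkpoint path is illustrative.

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)

fsdp_model = FSDP(nn.Linear(8, 8).cuda())

# Save: gather the full (unflattened) state dict on rank 0 only, on CPU.
cfg = FullStateDictConfig(rank0_only=True, offload_to_cpu=True)
with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
    state_dict = fsdp_model.state_dict()
if dist.get_rank() == 0:
    torch.save(state_dict, "ckpt.pt")  # illustrative path

# Load: only rank 0 loads the checkpoint into a fresh local module; wrapping
# with sync_module_states=True then broadcasts the loaded values to all ranks.
new_model = nn.Linear(8, 8).cuda()
if dist.get_rank() == 0:
    new_model.load_state_dict(torch.load("ckpt.pt"))
new_fsdp_model = FSDP(new_model,
                      device_id=torch.cuda.current_device(),
                      sync_module_states=True)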
Example #14
class TestGradAcc(FSDPTest):
    """Tests ``FullyShardedDataParallel``'s gradient accumulation via both its
    ``no_sync()`` context manager and without the context manager."""
    def _test_grad_acc(
        self,
        batch_dim: int,
        configs: List[_GradAccConfig],
        cpu_offload: CPUOffload,
        backward_prefetch: Optional[BackwardPrefetch],
    ):
        """
        Tests gradient accumulation by comparing a run that trains sequentially
        through some batches while accumulating gradients with a run that
        trains on the concatenation of those batches in a single iteration.

        The last iteration always synchronizes gradients regardless of what is
        specified by the last element of ``configs``.

        Arguments:
            batch_dim (int): Batch dimension in the input tensor to be passed
                into the model for the forward pass.
            configs (List[_GradAccConfig]): :class:`list` of configurations
                specifying how gradients are accumulated; for example, a list
                corresponding to [(False, 2), (True, 2), (False, 2)] indicates
                to accumulate over 2 + 2 + 2 = 6 total iterations, where the
                first two do not use ``no_sync()``, the middle two do use
                ``no_sync()``, and the final two again do not use
                ``no_sync()``.
            cpu_offload (CPUOffload): Configures CPU offloading.
            backward_prefetch (Optional[BackwardPrefetch]): Specifies at which
                point to prefetch the next layer's full parameters during the
                backward pass, if at all.
        """
        # Gradient accumulation outside `no_sync()` is not currently compatible
        # with CPU offloading
        if cpu_offload.offload_params and \
                any(not config.use_no_sync for config in configs):
            return
        old_allow_tf32 = torch.backends.cuda.matmul.allow_tf32
        try:
            # Disable TF32 to prevent floating point drift
            torch.backends.cuda.matmul.allow_tf32 = False

            # Initialize the FSDP model and optimizer
            group = dist.distributed_c10d._get_default_group()
            fsdp_model: FSDP = self._get_wrapped_model(
                group,
                cuda_first=False,
                add_bn=False,
                config={
                    "cpu_offload": cpu_offload,
                    "backward_prefetch": backward_prefetch,
                },
            )  # disable BN since the test uses varying batch sizes
            fsdp_model.eval()  # disable dropout
            device = torch.device("cuda")
            optim = torch.optim.SGD(
                fsdp_model.parameters(),
                lr=0.01,
                momentum=0.9,
            )

            # Generate the sequence of batches, each containing the same data
            # but permuted
            def permute_tensor(x: torch.Tensor):
                return x.view(-1)[torch.randperm(x.numel())].view_as(x)

            batch: Tuple[torch.Tensor, ...] = \
                fsdp_model.module.get_input(device)
            batches: List[Tuple[torch.Tensor, ...]] = [batch]
            num_iters_to_acc = sum(config.num_iters for config in configs)
            for _ in range(num_iters_to_acc - 1):
                batches.append(tuple(permute_tensor(t) for t in batch))
            for (batch1, batch2) in itertools.combinations(batches, r=2):
                for t1, t2 in zip(batch1, batch2):
                    assert not torch.all(t1 == t2), \
                        "Check the test to make sure that batches are distinct"

            # Concatenate the batches along the given batch dimension
            concat_batch: Tuple[torch.Tensor, ...] = tuple(
                torch.cat(ts, dim=batch_dim) for ts in zip(*batches))

            # Establish reference gradients using the concatenated batch
            fsdp_model.zero_grad()
            output = fsdp_model(*concat_batch)
            ref_loss = fsdp_model.module.get_loss(concat_batch, output)
            ref_loss.backward()
            ref_grads = [
                p.grad.detach().clone() for p in fsdp_model.parameters()
            ]

            # Compute and accumulate the gradients
            fsdp_model.zero_grad()
            losses = []
            batch_idx = 0
            for config in configs:
                sync_context = fsdp_model.no_sync() if config.use_no_sync \
                    else contextlib.suppress()
                with sync_context:
                    for _ in range(config.num_iters):
                        if batch_idx == num_iters_to_acc - 1:
                            break  # always sync on the last iteration
                        batch = batches[batch_idx]
                        batch_idx += 1
                        output = fsdp_model(*batch)
                        loss = fsdp_model.module.get_loss(batch, output)
                        loss.backward()
                        losses.append(loss)
            output = fsdp_model(*batches[-1])
            loss = fsdp_model.module.get_loss(batches[-1], output)
            loss.backward()
            losses.append(loss)
            acc_loss = sum(losses)
            acc_grads = [
                p.grad.detach().clone() for p in fsdp_model.parameters()
            ]

            # Compare the losses and gradients
            torch.testing.assert_close(ref_loss, acc_loss)
            self.assertEqual(len(ref_grads), len(acc_grads))
            for ref_grad, acc_grad in zip(ref_grads, acc_grads):
                self.assertEqual(ref_grad.device, acc_grad.device)
                self.assertEqual(ref_grad.size(), acc_grad.size())
                self.assertEqual(ref_grad.dtype, acc_grad.dtype)
                torch.testing.assert_close(ref_grad, acc_grad)

            # Check that the optimizer step does not error
            optim.step()
        finally:
            torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32

    @skip_if_lt_x_gpu(2)
    @parametrize("configs", [
        _GradAccConfigs([
            _GradAccConfig(use_no_sync=True, num_iters=3),
            _GradAccConfig(use_no_sync=False, num_iters=3),
            _GradAccConfig(use_no_sync=True, num_iters=3),
        ]),
        _GradAccConfigs([
            _GradAccConfig(use_no_sync=False, num_iters=3),
            _GradAccConfig(use_no_sync=True, num_iters=3),
            _GradAccConfig(use_no_sync=False, num_iters=3),
        ]),
    ])
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=False),
         CPUOffload(offload_params=True)],
    )
    @parametrize(
        "backward_prefetch",
        [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None],
    )
    def test_grad_acc(
        self,
        configs: _GradAccConfigs,
        cpu_offload: CPUOffload,
        backward_prefetch: Optional[BackwardPrefetch],
    ):
        """
        Tests gradient accumulation.

        This exercises gradient accumulation inside and outside the
        ``no_sync()`` context manager, in particular by interleaving the two.
        It tests both interleaving starting with (and ending with, resp.)
        inside versus outside ``no_sync()`` to ensure that initial conditions
        (and final conditions, resp.) do not affect the correctness. This test
        also checks for compatibility with the CPU offload and backward
        prefetch options.

        NOTE: Gradient accumulation without using the ``no_sync()`` context
        manager is not currently compatible with CPU offloading, so those tests
        are vacuous.
        """
        self._test_grad_acc(
            batch_dim=1,
            configs=configs.configs,
            cpu_offload=cpu_offload,
            backward_prefetch=backward_prefetch,
        )
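A minimal sketch of the no_sync() accumulation pattern that _test_grad_acc compares against single-batch training, assuming an initialized process group; the model and batch shapes are arbitrary.

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

fsdp_model = FSDP(nn.Linear(8, 8).cuda())
optim = torch.optim.SGD(fsdp_model.parameters(), lr=0.01)
batches = [torch.randn(4, 8, device="cuda") for _ in range(3)]

# Accumulate gradients locally for all but the last microbatch ...
with fsdp_model.no_sync():
    for batch in batches[:-1]:
        fsdp_model(batch).sum().backward()
# ... then let the final backward outside no_sync() reduce-scatter the
# accumulated gradients across ranks before the optimizer step.
fsdp_model(batches[-1]).sum().backward()
optim.step()
optim.zero_grad()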
Example #15
class TestSummonFullParams(FSDPTest):
    @property
    def world_size(self):
        return 2

    def get_model_param_count(self, m):
        return sum([p.numel() for p in m.parameters()])

    # sharding pads so that all ranks hold shards of equal size, using the
    # least padding necessary
    def get_expected_sharded_size(self, global_size):
        return int(math.ceil(global_size / self.world_size))

    @skip_if_lt_x_gpu(2)
    @parametrize("writeback", [True, False])
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True),
         CPUOffload(offload_params=False)],
    )
    @parametrize("mixed_precision", [True, False])
    @parametrize("modify_outer", [True, False])
    def test_summon_full_param_writeback(self, writeback, cpu_offload,
                                         mixed_precision, modify_outer):
        mixed_precision = MixedPrecision() if mixed_precision else None
        return _run_test_summon_full_param_writeback(
            self,
            writeback,
            modify_outer,
            cpu_offload=cpu_offload,
            mixed_precision=mixed_precision,
        )

    @skip_if_lt_x_gpu(2)
    @parametrize("mixed_precision", [True, False])
    def test_summon_full_param_shard_value(self, mixed_precision):
        mixed_precision = MixedPrecision() if mixed_precision else None
        raw_model = nn.Linear(10, 11)
        raw_model_size = self.get_model_param_count(raw_model)
        expected_shard_size = self.get_expected_sharded_size(raw_model_size)

        model = FSDP(raw_model.cuda(self.rank),
                     mixed_precision=mixed_precision)
        self.assertEqual(expected_shard_size,
                         self.get_model_param_count(model))

        # we're assuming a single flattened param
        self.assertEqual(1, len(list(model.parameters())))

        my_shard = torch.clone(next(model.parameters()))

        with model.summon_full_params(model):
            self.assertEqual(raw_model_size, self.get_model_param_count(model))
            parameters = list(model.parameters())
            all_shards = FlatParamHandle.flatten_params(parameters,
                                                        requires_grad=False)
            my_slice = torch.chunk(all_shards, self.world_size)[self.rank]

            # shards are padded but the full_param tensor is not
            a, b = my_shard[0:my_slice.numel()], my_slice
            self.assertTrue(torch.equal(a.cpu(), b.cpu()))

    @skip_if_lt_x_gpu(2)
    @parametrize("recurse", [True, False])
    @parametrize("summon_outer", [True, False])
    @parametrize("mixed_precision", [True, False])
    def test_summon_full_param_recursive(self, recurse, summon_outer,
                                         mixed_precision):
        mixed_precision = MixedPrecision() if mixed_precision else None
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(5, 5, bias=False),
                     mixed_precision=mixed_precision),
                nn.Linear(5, 3, bias=False),
            ),
            mixed_precision=mixed_precision,
        ).cuda(self.rank)

        global_inner_numel = self.get_model_param_count(
            nn.Linear(5, 5, bias=False))
        global_outer_numel = self.get_model_param_count(
            nn.Linear(5, 3, bias=False))

        shard_inner_numel = int(math.ceil(global_inner_numel /
                                          self.world_size))
        shard_outer_numel = int(math.ceil(global_outer_numel /
                                          self.world_size))

        outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        inner_param = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )
        self.assertEqual(shard_outer_numel, outer_param.numel())
        self.assertEqual(shard_inner_numel, inner_param.numel())

        model_to_summon = model if summon_outer else model[0]
        # outer is summoned if summon_full_params is called on the outer FSDP module
        expected_outer_numel = global_outer_numel if summon_outer else shard_outer_numel

        # inner is summoned if summon_full_params is called with recursion or
        # on the inner FSDP module
        expected_inner_numel = (global_inner_numel if recurse
                                or not summon_outer else shard_inner_numel)

        with model_to_summon.summon_full_params(model_to_summon,
                                                recurse=recurse):
            self.assertEqual(expected_outer_numel, outer_param.numel())
            self.assertEqual(expected_inner_numel, inner_param.numel())

    @skip_if_lt_x_gpu(2)
    def test_cannot_summon_full_params_from_forward(self):
        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.a = nn.Parameter(torch.zeros(5))

            def forward(self, fsdp_module):
                with fsdp_module.summon_full_params(fsdp_module):
                    pass

        model = FSDP(MyModule()).cuda(self.rank)
        with self.assertRaisesRegex(ValueError,
                                    "current state is TrainingState_.FORWARD"):
            model(model)

    @skip_if_lt_x_gpu(2)
    def test_cannot_summon_full_params_from_backward(self):
        model = FSDP(nn.Linear(2, 1)).cuda(self.rank)

        output = model(torch.ones(2).cuda(self.rank))

        def bad_backwards_hook(tensor):
            with model.summon_full_params(model):
                pass
            return None

        self.assertTrue(output.requires_grad)
        output.register_hook(bad_backwards_hook)

        with self.assertRaisesRegex(
                ValueError, "current state is TrainingState_.BACKWARD_PRE"):
            output.backward()

    @skip_if_lt_x_gpu(2)
    @parametrize("mixed_precision", [True, False])
    def test_summon_full_params_respects_reshard_after_forward(
            self, mixed_precision):
        mixed_precision = MixedPrecision() if mixed_precision else None
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(5, 5, bias=False),
                     mixed_precision=mixed_precision),
                nn.Linear(5, 3, bias=False),
            ),
            mixed_precision=mixed_precision,
        ).cuda(self.rank)

        outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        inner_param = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )
        outer_full_param_size = outer_param.numel() * self.world_size

        # trigger lazy init
        model(torch.zeros(5).cuda(self.rank))
        # the root FSDP module keeps all params around
        self.assertEqual(outer_full_param_size,
                         outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        # similarly summon_full_params should have the same behavior
        with model.summon_full_params(model):
            pass
        self.assertEqual(outer_full_param_size,
                         outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

    @skip_if_lt_x_gpu(2)
    def test_summon_single_param(self):
        model = FSDP(nn.Linear(1, 1, bias=False)).cuda(self.rank)

        p = model.get_parameter("_fsdp_wrapped_module.flat_param")
        self.assertEqual(1, p.numel())

        with torch.no_grad():
            # This sets the local shard value
            p[0] = self.rank + 2

        with model.summon_full_params(model, writeback=True):
            self.assertEqual(1, p.numel())
            with torch.no_grad():
                p.copy_(torch.zeros_like(p))

        # most ranks hold no data and wrote to padding so only rank zero will observe the above write
        if self.rank == 0:
            self.assertEqual(0, p[0])
        else:
            self.assertEqual(self.rank + 2, p[0])

    @skip_if_lt_x_gpu(2)
    @parametrize("rank0_only", [True, False])
    @parametrize("offload_to_cpu", [True, False])
    def test_summon_full_params_equivalence(self, rank0_only, offload_to_cpu):
        offload = CPUOffload(offload_params=True)
        model = FSDP(DeterministicModel(wrap_fsdp=True, cpu_offload=offload),
                     cpu_offload=offload)
        local_model = DeterministicModel(wrap_fsdp=False)

        params_to_compare = ([
            p.clone() for p in model.parameters()
        ] if rank0_only and self.rank != 0 else list(local_model.parameters()))

        writeback = not rank0_only

        with model.summon_full_params(
                model,
                recurse=True,
                rank0_only=rank0_only,
                writeback=writeback,
                offload_to_cpu=offload_to_cpu,
        ):
            if writeback:
                with torch.no_grad():
                    for p in model.parameters():
                        p.add_(1)
                    for p in params_to_compare:
                        p.add_(1)
            # The sleep below triggers a failure if summon_full_params does
            # not synchronize streams correctly.
            torch.cuda._sleep(1000000)
            # deepcopy() of FSDP parameters has issues, so clone instead
            fsdp_params = [p.clone() for p in model.parameters()]

        self.assertEqual(fsdp_params, params_to_compare)

        # CPU offload is enabled for main API, so we should point back to CPU
        for param in model.parameters():
            self.assertEqual(param.device, torch.device("cpu"))

    @skip_if_lt_x_gpu(2)
    def test_summon_from_non_fsdp(self):
        class FSDPContainer(nn.Module):
            def __init__(self, fsdp_1, fsdp_2, fsdp_3):
                super().__init__()
                self.fsdp_1 = fsdp_1
                self.fsdp_2 = fsdp_2
                self.fsdp_3 = fsdp_3

        model_fsdp = FSDPContainer(
            FSDP(DeterministicModel(wrap_fsdp=True)),
            FSDP(DeterministicModel(wrap_fsdp=True)),
            DeterministicModel(wrap_fsdp=False),
        )
        model_no_fsdp = FSDPContainer(
            DeterministicModel(wrap_fsdp=False),
            DeterministicModel(wrap_fsdp=False),
            DeterministicModel(wrap_fsdp=False),
        )

        params_to_compare = list(model_no_fsdp.parameters())
        with FSDP.summon_full_params(model_fsdp):
            fsdp_params = [p.clone() for p in model_fsdp.parameters()]

        self.assertEqual(params_to_compare, fsdp_params)

    @skip_if_lt_x_gpu(2)
    @parametrize("rank0_only", [True, False])
    @parametrize("offload_to_cpu", [True, False])
    @parametrize("mixed_precision", [True, False])
    def test_reshard_outside_forward_backward_iteration(
            self, rank0_only, offload_to_cpu, mixed_precision):
        mixed_precision = MixedPrecision() if mixed_precision else None
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(5, 5, bias=False),
                     mixed_precision=mixed_precision),
                nn.Linear(5, 1, bias=False),
            ),
            mixed_precision=mixed_precision,
        ).cuda(self.rank)

        outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        inner_param = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )
        outer_full_param_size = outer_param.numel() * self.world_size

        # First, let's validate our assumption about resharding

        output = model(torch.zeros(5).cuda(self.rank))
        # the root FSDP module keeps all params around
        self.assertEqual(outer_full_param_size,
                         outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        output.backward()
        # we reshard everything after backward() finishes
        self.assertEqual(0, outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        # Now let's repeat it with a summon_full_params() call in between

        output = model(torch.zeros(5).cuda(self.rank))
        self.assertEqual(outer_full_param_size,
                         outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())
        with model.summon_full_params(
                model,
                rank0_only=rank0_only,
                writeback=not rank0_only,
                offload_to_cpu=offload_to_cpu,
        ):
            pass
        self.assertEqual(outer_full_param_size,
                         outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        output.backward()
        with model.summon_full_params(
                model,
                rank0_only=rank0_only,
                writeback=not rank0_only,
                offload_to_cpu=offload_to_cpu,
        ):
            pass
        self.assertEqual(0, outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

    @skip_if_lt_x_gpu(2)
    @parametrize("rank0_only", [True, False])
    @parametrize("offload_to_cpu", [True, False])
    @parametrize("mixed_precision", [True, False])
    def test_params_are_unflattenned(self, rank0_only, offload_to_cpu,
                                     mixed_precision):
        layer_shape = (10, 12)
        model = nn.Linear(*layer_shape, bias=False).cuda(self.rank)
        mixed_precision = MixedPrecision() if mixed_precision else None
        fsdp_model = FSDP(deepcopy(model),
                          mixed_precision=mixed_precision).cuda(self.rank)

        def _get_flat_param():
            return fsdp_model.get_parameter("_fsdp_wrapped_module.flat_param")

        flattened_param = _get_flat_param()
        self.assertEqual(layer_shape[0] * layer_shape[1] // 2,
                         flattened_param.numel())

        with fsdp_model.summon_full_params(
                fsdp_model,
                rank0_only=rank0_only,
                writeback=not rank0_only,
                offload_to_cpu=offload_to_cpu,
        ):
            if self.rank == 0 or not rank0_only:
                self.assertEqual(fsdp_model.weight.shape, model.weight.shape)
                expected_device = (torch.device("cpu")
                                   if offload_to_cpu else torch.device(
                                       "cuda", torch.cuda.current_device()))
                self.assertTrue(expected_device == fsdp_model.weight.device)
            else:
                # Nonzero rank with rank0_only maintains original params.
                flat_within_ctx = _get_flat_param()
                self.assertEqual(flat_within_ctx, flattened_param)
                self.assertEqual(flat_within_ctx.device,
                                 torch.device(torch.cuda.current_device()))

        # Exiting the context should restore the original (CUDA) param device, even with offload_to_cpu
        param = next(fsdp_model.parameters())
        self.assertTrue(
            param.device == torch.device("cuda", torch.cuda.current_device()))

    @skip_if_lt_x_gpu(2)
    @parametrize("rank0_only", [True, False])
    @parametrize("offload_to_cpu", [True, False])
    @parametrize("mixed_precision", [True, False])
    def test_params_count_and_value(
        self,
        rank0_only: bool,
        offload_to_cpu: bool,
        mixed_precision: bool,
    ):
        mixed_precision = MixedPrecision() if mixed_precision else None
        model = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
        )
        fsdp_model = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
        )
        dev = (torch.device("cpu") if offload_to_cpu else torch.device(
            "cuda", torch.cuda.current_device()))
        params_to_compare = ([p.to(dev) for p in model.module.parameters()]
                             if not rank0_only or self.rank == 0 else list(
                                 p.clone() for p in fsdp_model.parameters()))
        with FSDP.summon_full_params(fsdp_model,
                                     rank0_only=rank0_only,
                                     writeback=not rank0_only):
            for p1, p2 in itertools.zip_longest(fsdp_model.parameters(),
                                                params_to_compare):
                self.assertEqual(p1, p2)

        # Exiting the context should restore the original (CUDA) param device
        param = next(fsdp_model.parameters())
        self.assertTrue(
            param.device == torch.device("cuda", torch.cuda.current_device()))

    @skip_if_lt_x_gpu(2)
    def test_raises_rank0_with_writeback(self):
        """Tests that ``summon_full_params()`` with both ``rank0_only=True``
        and ``writeback=True`` raises an error."""
        nested_wrapped_module = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
        )
        with self.assertRaisesRegex(ValueError, "is not supported"):
            with FSDP.summon_full_params(nested_wrapped_module,
                                         rank0_only=True,
                                         writeback=True):
                pass

    @skip_if_lt_x_gpu(2)
    @parametrize("prefix", ["", "test_prefix"])
    @parametrize("recurse", [False, True])
    def test_named_parameters_buffers(self, prefix: str, recurse: bool):
        """Tests that ``named_parameters()`` and ``named_buffers()`` for a
        top-level FSDP-wrapped model matches their behavior for the equivalent
        non-wrapped model."""
        model = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
        )
        model.register_buffer("buffer", torch.ones(1))
        # `named_parameters()` and `named_buffers()` will contain FSDP prefixes
        # if called on a non-FSDP root module
        fsdp_model = FSDP(
            NestedWrappedModule.init(
                self.process_group,
                FSDPInitMode.NO_FSDP,
                CUDAInitMode.CUDA_BEFORE,
                deterministic=True,
            ),
            self.process_group,
        )
        fsdp_model.register_buffer("buffer", torch.ones(1))
        with FSDP.summon_full_params(fsdp_model):
            for call in ["named_parameters", "named_buffers"]:
                for (n1, p1), (n2, p2) in itertools.zip_longest(
                        getattr(fsdp_model, call)(prefix=prefix,
                                                  recurse=recurse),
                        getattr(model, call)(prefix=prefix, recurse=recurse),
                ):
                    self.assertEqual(n1, n2)
                    self.assertEqual(p1, p2)
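
A minimal sketch of the usage pattern that the tests above exercise, assuming torch.distributed is already initialized with a CUDA backend; the helper name and layer size below are illustrative, not taken from the test suite:

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def edit_full_params(rank: int) -> None:
    # Hypothetical helper: gather the full, unsharded parameters and edit them.
    model = FSDP(nn.Linear(8, 8, bias=False).cuda(rank))
    with FSDP.summon_full_params(
        model,
        recurse=True,          # also unshard any nested FSDP instances
        writeback=True,        # persist in-place edits back to the local shard
        rank0_only=False,      # materialize full params on every rank
        offload_to_cpu=False,  # keep the gathered params on GPU
    ):
        # Inside the context, model.weight has its original (8, 8) shape.
        with torch.no_grad():
            model.weight.add_(1.0)
    # On exit, each rank again holds only its flattened local shard.
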
Example #16
    def _test_identical_outputs(self,
                                model_init_fn,
                                *args,
                                ref_ddp_fn=None,
                                num_steps=2,
                                fsdp_init_mode=FSDPInitMode.CUDA_AFTER,
                                lr=0.01,
                                cpu_offload=CPUOffload(),
                                backward_prefetch=None,
                                sharding_strategy=None,
                                save_model=True,
                                clip_norm=0.3,
                                norm_type=None,
                                **kwargs):
        group = dist.distributed_c10d._get_default_group()
        rank = group.rank()
        # Establish reference behavior with PyTorch DDP (+ optionally autocast).
        model = model_init_fn(group=group, wrap_fsdp=False).cuda()
        if ref_ddp_fn is None:
            model = nn.parallel.DistributedDataParallel(model,
                                                        device_ids=[rank],
                                                        output_device=rank)
        else:
            model = ref_ddp_fn(model)

        # DDP training
        ref_loss = self._train_for_several_steps(model,
                                                 num_steps,
                                                 autocast=False,
                                                 lr=lr,
                                                 fsdp_cpu_offload=cpu_offload)
        ref_full_params = list(model.parameters())

        # Confirm we get the same behavior using FullyShardedDataParallel.
        try:
            model = model_init_fn(
                group=group,
                wrap_fsdp=True,
                fsdp_init_mode=fsdp_init_mode,
                cpu_offload=cpu_offload,
                backward_prefetch=backward_prefetch,
                sharding_strategy=sharding_strategy,
            )
        except Exception as e:
            raise ValueError(
                f"model_Init_fn {model_init_fn} got error {str(e)}")

        cpu_offload = cpu_offload or CPUOffload()  # disabled if not specified.
        model = FullyShardedDataParallel(
            model,
            cpu_offload=cpu_offload,
            backward_prefetch=backward_prefetch,
            sharding_strategy=sharding_strategy,
        )
        # Call model.cuda() after init FSDP if specified.
        if fsdp_init_mode == FSDPInitMode.CUDA_AFTER:
            model = model.cuda()

        # Note that we don't do this check for FSDPInitMode.CUDA_AFTER since,
        # in the case of offloaded params, we expect FSDP to raise the error
        # that is checked below.
        if fsdp_init_mode != FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params:
            for p in model.parameters():
                # Should be on CPU regardless of whether the param is sharded.
                self.assertEqual(p.device, torch.device("cpu"),
                                 f"Mismatch, cpu offload is {cpu_offload}")

        only_check_err = fsdp_init_mode == FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params
        ctx = (self.assertRaisesRegex(AssertionError,
                                      "Expected param to be on CPU")
               if only_check_err else suppress())
        with ctx:
            # FSDP training
            shard_loss = self._train_for_several_steps(
                model,
                num_steps,
                autocast=False,
                lr=lr,
                fsdp_cpu_offload=cpu_offload,
                save_model=save_model,
            )
        # We only check for errors in the case we have the following setup:
        # model = FSDP(model, cpu_offload=True)
        # model = model.cuda()
        # so skip the rest of this logic.
        if only_check_err:
            return
        # If CPU offloading is enabled, the next call will move the model
        # params to GPU. Sanity check that the params are on CPU beforehand.
        if cpu_offload.offload_params:
            device_set = {p.device for p in model.parameters()}
            self.assertEqual({torch.device("cpu")}, device_set,
                             f"Got device set {device_set}")
        shard_full_params = get_full_params(model)

        if cpu_offload.offload_params:
            shard_loss = shard_loss.cuda()
        torch.testing.assert_allclose(ref_loss, shard_loss)
        self.assertEqual(
            ref_full_params,
            shard_full_params,
            exact_device=True,
            msg="FullyShardedDataParallel didn't match PyTorch DDP",
        )
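
As the device assertions above indicate, wrapping with CPUOffload(offload_params=True) keeps the sharded parameters on CPU, and moving the module to CUDA only after wrapping (the CUDA_AFTER case) is expected to fail. A minimal sketch of the supported CUDA_BEFORE path; the helper name and layer size are illustrative:

import torch
import torch.nn as nn
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP

def build_offloaded_model(rank: int) -> FSDP:
    # Hypothetical helper: move the module to GPU first, then wrap with
    # parameter offloading enabled.
    model = FSDP(nn.Linear(4, 4).cuda(rank),
                 cpu_offload=CPUOffload(offload_params=True))
    # With offload_params=True the sharded parameters live on CPU between
    # iterations, which is what the assertions above check.
    assert all(p.device == torch.device("cpu") for p in model.parameters())
    return model
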
Example #17
    def _test_fsdp_parity(
        self,
        model_class: Type[FSDPTestModel],
        fsdp_init_mode: FSDPInitMode,
        cuda_init_mode: CUDAInitMode,
        ref_init_fn: Optional[Callable] = None,
        num_iters: int = 2,
        save_model: bool = True,
        cpu_offload: CPUOffload = CPUOffload(),
        backward_prefetch: Optional[BackwardPrefetch] = None,
        forward_prefetch: bool = False,
        sharding_strategy: Optional[ShardingStrategy] = None,
        mixed_precision: Optional[MixedPrecision] = None,
        enable_sharded_grad_scaler: bool = False,
        use_pure_fp16: bool = False,
        norm_type: Optional[Union[float, int]] = None,
        init_kwargs: Optional[Dict[str, Any]] = None,
        **fsdp_kwargs,
    ):
        """
        Tests FSDP training against a reference, which defaults to DDP but
        may be customized with ``ref_init_fn``.

        Args:
            model_class (Type[FSDPTestModel]): A model class that inherits from
                ``FSDPTestModel``, which defines the expected interface.
            fsdp_init_mode (FSDPInitMode): The mode to initialize the
                FSDP-wrapped model. This should not be ``NO_FSDP``.
            ref_init_fn (Optional[Callable]): A callable to invoke that wraps a
                non-wrapped model to construct the reference model, where this
                wrapper should provide data parallel semantics. If ``None``,
                then the callable defaults to the DDP constructor.
        """
        assert fsdp_init_mode != FSDPInitMode.NO_FSDP, "Expects an FSDP init mode that wraps with FSDP"
        if init_kwargs is None:
            init_kwargs = {}
        lr = 1e-2
        rank = self.process_group.rank()
        # Establish reference behavior with DDP
        model = model_class.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
            **init_kwargs,
        )
        if ref_init_fn is None:
            ref_model = DDP(model, device_ids=[rank], output_device=rank)
        else:
            ref_model = ref_init_fn(model)
        if use_pure_fp16:
            ref_model = ref_model.half()
        ref_loss = self._train_for_several_steps(
            ref_model,
            num_iters,
            autocast=mixed_precision is not None,
            lr=lr,
            fsdp_cpu_offload=cpu_offload,
            mixed_precision=mixed_precision,
            norm_type=norm_type,
            enable_sharded_grad_scaler=enable_sharded_grad_scaler,
            use_pure_fp16=use_pure_fp16,
        )
        ddp_params = list(ref_model.parameters())
        # Check against FSDP behavior
        fsdp_kwargs.update({
            "cpu_offload": cpu_offload,
            "backward_prefetch": backward_prefetch,
            "forward_prefetch": forward_prefetch,
            "sharding_strategy": sharding_strategy,
            "mixed_precision": mixed_precision,
        })
        try:
            fsdp_model = model_class.init(
                self.process_group,
                fsdp_init_mode,
                cuda_init_mode,
                fsdp_kwargs,
                deterministic=True,
                **init_kwargs,
            )
        except Exception as e:
            raise ValueError(
                f"Initializing {model_class} raised error {str(e)}")
        if not isinstance(fsdp_model, FSDP):
            # Enforce that we wrap with top-level FSDP since we are comparing
            # assuming a data parallel reference and some test models may not
            # do so in their `init()` method
            fsdp_model = FSDP(fsdp_model, self.process_group, **fsdp_kwargs)
        if use_pure_fp16:
            # Change the model parameter dtype after FSDP initialization
            fsdp_model = fsdp_model.half()
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
            fsdp_model = fsdp_model.cuda()
        offload_params = cpu_offload is not None and cpu_offload.offload_params
        # Offloading parameters with `CUDA_AFTER` should raise an error during
        # lazy initialization due to the parameter devices not being CPU;
        # otherwise, all parameter devices should be CPU
        expects_device_error = offload_params and cuda_init_mode == CUDAInitMode.CUDA_AFTER
        expects_cpu_device = offload_params and cuda_init_mode != CUDAInitMode.CUDA_AFTER
        if expects_cpu_device:
            cpu_device = torch.device("cpu")
            for param in fsdp_model.parameters():
                self.assertEqual(param.device, cpu_device)
        context = (self.assertRaisesRegex(AssertionError,
                                          "Expected param to be on CPU")
                   if expects_device_error else suppress())
        with context:
            fsdp_loss = self._train_for_several_steps(
                fsdp_model,
                num_iters,
                autocast=False,
                lr=lr,
                fsdp_cpu_offload=cpu_offload,
                save_model=save_model,
                mixed_precision=mixed_precision,
                norm_type=norm_type,
                enable_sharded_grad_scaler=enable_sharded_grad_scaler,
                use_pure_fp16=use_pure_fp16,
            )
        # No need to check for parameter and loss parity if expecting an error
        if expects_device_error:
            return
        # Check parameter devices are CPU if offloading to CPU before calling
        # `get_full_params()`, which will cast the parameters to FP32
        if offload_params:
            for param in fsdp_model.parameters():
                self.assertEqual(param.device, cpu_device)
            fsdp_loss = fsdp_loss.cuda()
        fsdp_unsharded_params = get_full_params(fsdp_model)
        torch.testing.assert_allclose(ref_loss, fsdp_loss)
        # Do not check for parameter parity if using mixed precision since (1)
        # the DDP parameters are in FP16 (from `half()`) while the FSDP
        # parameters are in FP32 (from `summon_full_params()`) and (2) DDP runs
        # the optimizer in FP16 while FSDP runs it in FP32
        if mixed_precision is None:
            self.assertEqual(
                ddp_params,
                fsdp_unsharded_params,
                exact_device=True,
                msg="FSDP did not match DDP",
            )
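
# Illustrative sketch only: the MixedPrecision policies referenced below
# (default_mp, mp_only_reduce, mp_only_param_and_buf, ...) are defined earlier
# in the original file. A policy is built by choosing a dtype per tensor role;
# the dtypes here are arbitrary and not used by the tests.
_example_mp_sketch = MixedPrecision(
    param_dtype=torch.float16,    # dtype for params in forward/backward compute
    reduce_dtype=torch.float32,   # dtype for gradient reduction
    buffer_dtype=torch.bfloat16,  # dtype for buffers
)
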
mp_configs = [
    default_mp, mp_only_reduce, mp_only_param_and_buf, mp_no_mixed_precision
]
if nccl_supports_bf16:
    mp_diff_buffer_and_reduce = MixedPrecision(param_dtype=torch.float16,
                                               buffer_dtype=torch.bfloat16,
                                               reduce_dtype=torch.float32)
    mp_configs.extend([mp_diff_buffer_and_reduce])

# Buffer original dtype, which can differ from model params.
_BUFFER_ORIG_DTYPE = torch.float64

params = "mp_config,cpu_offload,backward_prefetch,full_precision_param_dtype"
cpu_offload_config = [
    CPUOffload(offload_params=True),
    CPUOffload(offload_params=False)
]
backward_prefetch_config = [
    BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST
]
full_precision_param_dtype_config = [torch.float32, torch.float64]
configs = list(
    product(
        mp_configs,
        cpu_offload_config,
        backward_prefetch_config,
        full_precision_param_dtype_config,
    ))

test_name_mapping = {
Example #19
class TestNoSync(FSDPTest):
    """Tests ``FullyShardedDataParallel``'s gradient accumulation via its
    ``no_sync()`` context manager."""
    def _test_no_sync(
        self,
        batch_dim: int,
        num_iters_to_acc: int,
        cpu_offload: CPUOffload,
        backward_prefetch: Optional[BackwardPrefetch],
    ):
        """
        Tests ``no_sync()`` by comparing a run that trains sequentially through
        some batches while accumulating gradients with a run that trains on the
        concatenation of those batches in a single iteration. The number of
        batches, i.e. the number of iterations for which to accumulate
        gradients, is given by ``num_iters_to_acc``.

        Arguments:
            batch_dim (int): Batch dimension in the input tensor to be passed
                into the model for the forward pass.
            num_iters_to_acc (int): Number of iterations for which to
                accumulate gradients; all but the last iteration are run using
                the ``no_sync()`` context manager so that gradients are not
                synchronized until the final iteration.
            cpu_offload (CPUOffload): Configures CPU offloading.
            backward_prefetch (Optional[BackwardPrefetch]): Specifies at which
                point to prefetch the next layer's full parameters during the
                backward pass, if at all.
        """
        old_allow_tf32 = torch.backends.cuda.matmul.allow_tf32
        try:
            # Disable TF32 to prevent floating point drift
            torch.backends.cuda.matmul.allow_tf32 = False

            # Initialize the FSDP model and optimizer
            group = dist.distributed_c10d._get_default_group()
            fsdp_model: FSDP = self._get_wrapped_model(
                group,
                cuda_first=False,
                add_bn=False,
                cpu_offload=cpu_offload,
                backward_prefetch=backward_prefetch,
            )  # disable BN since the test uses varying batch sizes
            fsdp_model.eval()  # disable dropout
            device = torch.device("cuda")
            optim = torch.optim.SGD(fsdp_model.parameters(),
                                    lr=0.01,
                                    momentum=0.9)

            # Generate the sequence of batches, each containing the same data but
            # permuted
            def permute_tensor(x: torch.Tensor):
                return x.view(-1)[torch.randperm(x.numel())].view_as(x)

            batch: Tuple[torch.Tensor,
                         ...] = fsdp_model.module.get_input(device)
            batches: List[Tuple[torch.Tensor, ...]] = [batch]
            for _ in range(num_iters_to_acc - 1):
                batches.append(tuple(permute_tensor(t) for t in batch))
            for (batch1, batch2) in itertools.combinations(batches, r=2):
                for t1, t2 in zip(batch1, batch2):
                    assert not torch.all(t1 == t2)

            # Concatenate the batches along the given batch dimension
            concat_batch: Tuple[torch.Tensor, ...] = tuple(
                torch.cat(ts, dim=batch_dim) for ts in zip(*batches))

            # Establish reference gradients using the concatenated batch
            fsdp_model.zero_grad()
            output = fsdp_model(*concat_batch)
            ref_loss = fsdp_model.module.get_loss(concat_batch, output)
            ref_loss.backward()
            ref_grads = [
                p.grad.detach().clone() for p in fsdp_model.parameters()
            ]

            # Compute the gradients by accumulating via `no_sync()`
            fsdp_model.zero_grad()
            losses = []
            with fsdp_model.no_sync():
                for batch in batches[:-1]:  # accumulate for all but the last batch
                    output = fsdp_model(*batch)
                    loss = fsdp_model.module.get_loss(batch, output)
                    loss.backward()
                    losses.append(loss)
            output = fsdp_model(*batches[-1])
            loss = fsdp_model.module.get_loss(batches[-1], output)
            loss.backward()
            losses.append(loss)
            acc_loss = sum(losses)
            acc_grads = [
                p.grad.detach().clone() for p in fsdp_model.parameters()
            ]

            # Compare the losses and gradients
            torch.testing.assert_allclose(ref_loss, acc_loss)
            assert len(ref_grads) == len(acc_grads)
            for ref_grad, acc_grad in zip(ref_grads, acc_grads):
                assert ref_grad.device == acc_grad.device
                assert ref_grad.size() == acc_grad.size()
                assert ref_grad.dtype == acc_grad.dtype
                torch.testing.assert_allclose(ref_grad, acc_grad)

            # Check that the optimizer step does not error
            optim.step()
        finally:
            torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "num_iters_to_acc",
        [2, 4],
    )
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=False),
         CPUOffload(offload_params=True)],
    )
    @parametrize(
        "backward_prefetch",
        [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None])
    def test_no_sync(
        self,
        num_iters_to_acc: int,
        cpu_offload: CPUOffload,
        backward_prefetch: Optional[BackwardPrefetch],
    ):
        """Tests the ``no_sync()`` context manager."""
        assert num_iters_to_acc >= 2, \
            "Accumulate for at least 2 iterations to be nontrivial"
        self._test_no_sync(
            batch_dim=1,
            num_iters_to_acc=num_iters_to_acc,
            cpu_offload=cpu_offload,
            backward_prefetch=backward_prefetch,
        )
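
The accumulation pattern that _test_no_sync verifies can be summarized with a short, hedged sketch; the helper below and its arguments are illustrative and not part of the test file:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def accumulate_gradients(fsdp_model: FSDP, batches, optimizer) -> None:
    # All but the last micro-batch run under no_sync(), so gradients are
    # accumulated locally without being reduced across ranks.
    with fsdp_model.no_sync():
        for batch in batches[:-1]:
            fsdp_model(batch).sum().backward()
    # The final micro-batch runs outside no_sync(), triggering the reduction.
    fsdp_model(batches[-1]).sum().backward()
    optimizer.step()
    optimizer.zero_grad()
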
Example #20
from torch.testing._internal.common_utils import (
    TEST_WITH_DEV_DBG_ASAN,
    parametrize,
    run_tests,
)

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)

params = "cpu_offload,backward_prefetch,forward_prefetch,sharding_strategy"
cpu_offload_config = [CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
backward_prefetch_config = [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None]
forward_prefetch_config = ["forward_prefetch", "no_forward_prefetch"]
sharding_strategy_config = [ShardingStrategy.SHARD_GRAD_OP, None, ShardingStrategy.NO_SHARD]
configs = list(itertools.product(cpu_offload_config,
                                 backward_prefetch_config,
                                 forward_prefetch_config,
                                 sharding_strategy_config))
test_name_mapping = {
    str(CPUOffload(offload_params=True)): "offload_true",
    str(CPUOffload(offload_params=False)): "offload_false",
    str(BackwardPrefetch.BACKWARD_PRE): "backward_prefetch_pre",
    str(BackwardPrefetch.BACKWARD_POST): "backward_prefetch_post",
    "forward_prefetch": "forward_prefetch",
    "no_forward_prefetch": "no_forward_prefetch",
    str(ShardingStrategy.SHARD_GRAD_OP): "shard_grad_op",
    str(ShardingStrategy.NO_SHARD): "no_shard",
}
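
# These configs are typically combined with @parametrize via a name function
# built from test_name_mapping. The helper below is an illustrative sketch, not
# the one used by the original file:
def _config_name_sketch(*config) -> str:
    return "_".join(
        test_name_mapping[str(value)] if value is not None else "none"
        for value in config
    )
# e.g. @parametrize(params, configs, name_fn=_config_name_sketch)
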
class TestSummonFullParams(FSDPTest):
    @property
    def world_size(self):
        return 2

    def get_model_param_count(self, m):
        return sum([p.numel() for p in m.parameters()])

    # shards are padded minimally so that every rank holds a shard of the same size
    def get_expected_sharded_size(self, global_size):
        return int(math.ceil(global_size / self.world_size))

    @skip_if_lt_x_gpu(2)
    @parametrize("writeback", [True, False])
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    @parametrize("modify_outer", [True, False])
    def test_summon_full_param_writeback(
        self, writeback, cpu_offload, modify_outer
    ):
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(5, 5, bias=False)), nn.Linear(5, 3, bias=False)
            )
        ).cuda(self.rank)

        # set the value
        outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        inner_param = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )
        p = outer_param if modify_outer else inner_param

        with torch.no_grad():
            # This sets the local shard value
            p[0] = self.rank + 2

        with model._summon_full_params(writeback=writeback):
            with torch.no_grad():
                p.copy_(torch.zeros_like(p))

        if writeback:
            self.assertEqual(p.cpu()[0], 0)
        else:
            self.assertEqual(p.cpu()[0], self.rank + 2)

    @skip_if_lt_x_gpu(2)
    def test_summon_full_param_shard_value(self):

        raw_model = nn.Linear(10, 11)
        raw_model_size = self.get_model_param_count(raw_model)
        expected_shard_size = self.get_expected_sharded_size(raw_model_size)

        model = FSDP(raw_model.cuda(self.rank))
        self.assertEqual(expected_shard_size, self.get_model_param_count(model))

        # we're assuming a single flattened param
        self.assertEqual(1, len(list(model.parameters())))

        my_shard = torch.clone(next(model.parameters()))

        with model._summon_full_params():
            self.assertEqual(raw_model_size, self.get_model_param_count(model))
            parameters = list(model.parameters())
            all_shards = FlatParameter(parameters, requires_grad=False)
            my_slice = torch.chunk(all_shards, self.world_size)[self.rank]

            # shards are padded but the full_param tensor is not
            a, b = my_shard[0 : my_slice.numel()], my_slice
            self.assertTrue(torch.equal(a.cpu(), b.cpu()))

    @skip_if_lt_x_gpu(2)
    @parametrize("recurse", [True, False])
    @parametrize("summon_outer", [True, False])
    def test_summon_full_param_recursive(self, recurse, summon_outer):
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(5, 5, bias=False)), nn.Linear(5, 3, bias=False)
            )
        ).cuda(self.rank)

        global_inner_numel = self.get_model_param_count(nn.Linear(5, 5, bias=False))
        global_outer_numel = self.get_model_param_count(nn.Linear(5, 3, bias=False))

        shard_inner_numel = int(math.ceil(global_inner_numel / self.world_size))
        shard_outer_numel = int(math.ceil(global_outer_numel / self.world_size))

        outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        inner_param = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )
        self.assertEqual(shard_outer_numel, outer_param.numel())
        self.assertEqual(shard_inner_numel, inner_param.numel())

        model_to_summon = model if summon_outer else model[0]
        # outer is summoned if _summon_full_params is called on the outer FSDP module
        expected_outer_numel = global_outer_numel if summon_outer else shard_outer_numel

        # inner is summoned if _summon_full_params is called with recursion or on the inner FSDP module
        expected_inner_numel = (
            global_inner_numel if recurse or not summon_outer else shard_inner_numel
        )

        with model_to_summon._summon_full_params(recurse=recurse):
            self.assertEqual(expected_outer_numel, outer_param.numel())
            self.assertEqual(expected_inner_numel, inner_param.numel())

    @skip_if_lt_x_gpu(2)
    def test_cannot_summon_full_params_from_forward(self):
        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.a = nn.Parameter(torch.zeros(5))

            def forward(self, fsdp_module):
                with fsdp_module._summon_full_params():
                    pass

        model = FSDP(MyModule()).cuda(self.rank)
        with self.assertRaisesRegex(
            ValueError, "current state is TrainingState_.FORWARD"
        ):
            model(model)

    @skip_if_lt_x_gpu(2)
    def test_cannot_summon_full_params_from_backward(self):
        model = FSDP(nn.Linear(2, 1)).cuda(self.rank)

        output = model(torch.ones(2).cuda(self.rank))

        def bad_backwards_hook(tensor):
            with model._summon_full_params():
                pass
            return None

        self.assertTrue(output.requires_grad)
        output.register_hook(bad_backwards_hook)

        with self.assertRaisesRegex(
            ValueError, "current state is TrainingState_.BACKWARD_PRE"
        ):
            output.backward()

    @skip_if_lt_x_gpu(2)
    def test_summon_full_params_respects_reshard_after_forward(self):
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(5, 5, bias=False)), nn.Linear(5, 3, bias=False)
            )
        ).cuda(self.rank)

        outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        inner_param = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )
        outer_full_param_size = outer_param.numel() * self.world_size

        # trigger lazy init
        model(torch.zeros(5).cuda(self.rank))

        # the root FSDP module keeps all params around
        self.assertEqual(
            outer_full_param_size, outer_param._full_param_padded.storage().size()
        )
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        # similarly _summon_full_params should have the same behavior
        with model._summon_full_params():
            pass
        self.assertEqual(
            outer_full_param_size, outer_param._full_param_padded.storage().size()
        )
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

    @skip_if_lt_x_gpu(2)
    def test_summon_single_param(self):
        model = FSDP(nn.Linear(1, 1, bias=False)).cuda(self.rank)

        p = model.get_parameter("_fsdp_wrapped_module.flat_param")
        self.assertEqual(1, p.numel())

        with torch.no_grad():
            # This sets the local shard value
            p[0] = self.rank + 2

        with model._summon_full_params(writeback=True):
            self.assertEqual(1, p.numel())
            with torch.no_grad():
                p.copy_(torch.zeros_like(p))

        # Most ranks hold no data and wrote into padding, so only rank 0 observes the write above.
        if self.rank == 0:
            self.assertEqual(0, p[0])
        else:
            self.assertEqual(self.rank + 2, p[0])

    @skip_if_lt_x_gpu(2)
    def test_reshard_outside_forward_backward_iteration(self):
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(5, 5, bias=False)), nn.Linear(5, 1, bias=False)
            )
        ).cuda(self.rank)

        outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        inner_param = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )
        outer_full_param_size = outer_param.numel() * self.world_size

        # First, let's validate our assumption about resharding

        output = model(torch.zeros(5).cuda(self.rank))
        # the root FSDP module keeps all params around
        self.assertEqual(
            outer_full_param_size, outer_param._full_param_padded.storage().size()
        )
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        output.backward()
        # we reshard everything after backward() finishes
        self.assertEqual(0, outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        # Now let's repeat it with a _summon_full_params() call in between

        output = model(torch.zeros(5).cuda(self.rank))
        with model._summon_full_params():
            pass
        self.assertEqual(
            outer_full_param_size, outer_param._full_param_padded.storage().size()
        )
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        output.backward()
        with model._summon_full_params():
            pass
        self.assertEqual(0, outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

    @skip_if_lt_x_gpu(2)
    def test_params_are_unflatenned(self):
        model = FSDP(nn.Linear(self.world_size, 1, bias=False)).cuda(self.rank)

        flattened_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        self.assertEqual(1, flattened_param.numel())

        with model._summon_full_params():
            a = model.weight.flatten().detach()
            b = flattened_param.detach()
            self.assertTrue(torch.equal(a, b))

    @skip_if_lt_x_gpu(2)
    def test_params_count_and_value(self):
        fsdp_model = FSDP(
            NestedWrappedModule(
                group=dist.distributed_c10d._get_default_group(),
                wrap_fsdp=True,
                fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
            )
        )
        model = NestedWrappedModule(
            group=dist.distributed_c10d._get_default_group(),
            wrap_fsdp=False,
            fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
        )
        with fsdp_model._summon_full_params():
            for p1, p2 in itertools.zip_longest(
                fsdp_model.parameters(), model.module.parameters()
            ):
                self.assertEqual(p1, p2)
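
This older version of the test class calls the private _summon_full_params() instance method; the newer examples above use the public static form. A hedged sketch of the equivalent public call (the helper name is illustrative):

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def gather_full_params_sketch(model: FSDP):
    # Public, static form of the gather performed by the tests above.
    with FSDP.summon_full_params(model, recurse=True, writeback=False):
        return [p.detach().clone() for p in model.parameters()]
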
Example #22
class TestFSDPStateDict(FSDPTest):
    @property
    def world_size(self):
        return 2

    def _broadcast_state_dict(self, state_dict):
        olist = [state_dict if self.rank == 0 else None]
        dist.broadcast_object_list(olist)
        return olist[0]

    def _get_simple_nested_model(self, *fsdp_args, **fsdp_kwargs):
        model = FSDP(
            nn.Sequential(
                FSDP(
                    nn.Linear(10, 10, bias=False).cuda(), *fsdp_args,
                    **fsdp_kwargs),
                nn.Linear(10, 10, bias=False).cuda(),
            ),
            *fsdp_args,
            **fsdp_kwargs,
        )
        return model

    def _get_simple_model(self, *fsdp_args, **fsdp_kwargs):
        model = FSDP(
            nn.Linear(10, 10, bias=False).cuda(), *fsdp_args, **fsdp_kwargs)
        return model

    def _get_full_state_dict_mgr(self, model, state_dict_rank0_and_offload):
        return FSDP.state_dict_type(
            model, StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(
                rank0_only=state_dict_rank0_and_offload,
                offload_to_cpu=state_dict_rank0_and_offload,
            ))

    def _validate_state_dict_contents(self,
                                      fsdp_state_dict,
                                      state_dict_rank0_and_offload,
                                      ignore_keys=None):
        if state_dict_rank0_and_offload:
            if self.rank == 0:
                self.assertNotEqual(fsdp_state_dict, {})
                for key, tensor in fsdp_state_dict.items():
                    if ignore_keys and key in ignore_keys:
                        continue
                    self.assertEqual(
                        tensor.device, torch.device("cpu"),
                        f"{key} is unexpectedly on device {tensor.device}")
            else:
                self.assertEqual(fsdp_state_dict, {})

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True),
         CPUOffload(offload_params=False)],
    )
    @parametrize("fp16", [True, False])
    @parametrize("state_dict_rank0_and_offload", [True, False])
    def test_basic_save_and_load_state_dict(self, cpu_offload, fp16,
                                            state_dict_rank0_and_offload):
        """
        Tests that we can save a state_dict and load it into a blank model
        with various configs, such as fp16 and CPU offload, and that the
        parameters match as expected.
        """
        for model_call in [
                partial(self._get_simple_nested_model,
                        cpu_offload=cpu_offload),
                partial(self._get_simple_model, cpu_offload=cpu_offload),
        ]:
            model = model_call()
            full_state_dict_mgr = self._get_full_state_dict_mgr(
                model, state_dict_rank0_and_offload)
            with full_state_dict_mgr:
                fsdp_state_dict = _get_state_dict(model,
                                                  cpu_offload.offload_params,
                                                  fp16)

            self._validate_state_dict_contents(fsdp_state_dict,
                                               state_dict_rank0_and_offload)
            if fp16:
                # Verify that the saved tensors are fp16
                for tensor in fsdp_state_dict.values():
                    self.assertEqual(tensor.dtype, torch.float16)

            model_new = model_call()
            if not cpu_offload.offload_params:
                model_new = model_new.cuda()
            if fp16:
                model_new.half()

            # zero the model to ensure parameters are different.
            _zero_model(model_new)

            with FullyShardedDataParallel.summon_full_params(model):
                with FullyShardedDataParallel.summon_full_params(model_new):
                    params = list(model.parameters())
                    params_new = list(model_new.parameters())
                    self.assertNotEqual(params, params_new)

            # Verify parameters are the same in the new model.
            if state_dict_rank0_and_offload:
                # Broadcast the state dict and move it back to GPU in
                # preparation for loading.
                fsdp_state_dict = self._broadcast_state_dict(fsdp_state_dict)
                for key in fsdp_state_dict.keys():
                    fsdp_state_dict[key] = fsdp_state_dict[key].cuda()

            model_new.load_state_dict(fsdp_state_dict)
            with FullyShardedDataParallel.summon_full_params(model_new):
                with FullyShardedDataParallel.summon_full_params(model):
                    params = list(model.parameters())
                    params_new = list(model_new.parameters())
                    self.assertEqual(params, params_new)
                    if fp16:
                        for tensor in model_new.parameters():
                            self.assertEqual(tensor.dtype, torch.float16)

    @skip_if_lt_x_gpu(2)
    @parametrize("mixed_precision", [True, False])
    @parametrize("state_dict_rank0_and_offload", [True, False])
    def test_save_and_load_after_forward_state_dict(
            self, mixed_precision, state_dict_rank0_and_offload):
        """
        Test that saving after some training results in params being updated as
        expected.
        """
        torch.cuda.set_device(self.rank)
        mixed_precision = MixedPrecision() if mixed_precision else None
        model = self._get_simple_nested_model(mixed_precision=mixed_precision)
        optim = torch.optim.SGD(model.parameters(), lr=0.1)
        initial_params = _get_full_detached_param(model)
        for _ in range(6):
            inp = torch.randn(1, 10, device=torch.cuda.current_device())
            output = model(*inp)
            loss = output.sum()
            expected_dtype = torch.float32 if mixed_precision is None else torch.float16
            self.assertEqual(expected_dtype, loss.dtype)
            loss.backward()
            optim.step()

        trained_params = _get_full_detached_param(model)
        # Ensure some training occurred
        self.assertNotEqual(initial_params, trained_params)
        # Save a copy of the state_dict
        fsd_mgr = self._get_full_state_dict_mgr(model,
                                                state_dict_rank0_and_offload)
        with fsd_mgr:
            state_dict = {k: v.clone() for k, v in model.state_dict().items()}
        self._validate_state_dict_contents(state_dict,
                                           state_dict_rank0_and_offload)
        _zero_model(model)

        # Ensure checkpointed params have the full param dtype
        for tensor in state_dict.values():
            self.assertEqual(tensor.dtype, torch.float32)

        # Load state_dict into zeroed model
        if state_dict_rank0_and_offload:
            # Broadcast the state dict and move it back to GPU in
            # preparation for loading.
            state_dict = self._broadcast_state_dict(state_dict)
            for key in state_dict.keys():
                state_dict[key] = state_dict[key].cuda()

        model.load_state_dict(state_dict)
        loaded_params = _get_full_detached_param(model)
        self.assertEqual(loaded_params, trained_params)

    def _initialize_model(self, wrap_fsdp: bool, wrap_ddp: bool = True):
        # keep everything deterministic for input data
        torch.manual_seed(0)

        model = Model(wrap_fsdp).cuda()
        if wrap_fsdp:
            model = FSDP(model)
        elif wrap_ddp:
            model = DistributedDataParallel(model, device_ids=[self.rank])
        return model

    @staticmethod
    def _state_dict(model: Module, state_dict_type: str):
        try:
            enum_val = STATE_DICT_MAPPING[state_dict_type]
        except KeyError:
            raise ValueError(f"No state_dict type for {state_dict_type}")

        with FSDP.state_dict_type(model, enum_val):
            return model.state_dict()

    @staticmethod
    def _load_state_dict(model: Module, state_dict_type: str,
                         state_dict: Dict[str, Any]):
        try:
            enum_val = STATE_DICT_MAPPING[state_dict_type]
        except KeyError:
            raise ValueError(f"No state_dict for {state_dict_type}")

        with FSDP.state_dict_type(model, enum_val):
            return model.load_state_dict(state_dict)

    def _dist_train(self, wrap_fsdp: bool, state_dict_type: str = ""):
        # TODO: Move this test to common_fsdp.
        model = self._initialize_model(wrap_fsdp)
        optim = SGD(model.parameters(), lr=0.1)

        in_data = torch.rand(64,
                             4,
                             requires_grad=True,
                             device=torch.device("cuda"))
        for _ in range(3):
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        if wrap_fsdp:
            blank_model = FSDP(Model(True).cuda())
            _zero_model(blank_model)
            state_dict = self._state_dict(model, state_dict_type)
            self._load_state_dict(blank_model, state_dict_type, state_dict)
            return get_full_params(blank_model)
        else:
            return list(model.parameters())

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    def test_state_dict_save_load_flow(self, state_dict_type):
        fsdp_params = self._dist_train(wrap_fsdp=True,
                                       state_dict_type=state_dict_type)
        ddp_params = self._dist_train(wrap_fsdp=False)
        self.assertEqual(ddp_params, fsdp_params)

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    def test_fsdp_state_dict_keys(self, state_dict_type):
        state_dict = self._state_dict(self._initialize_model(True),
                                      state_dict_type)
        if state_dict_type == "local_state_dict":
            self.assertEqual(set(["flat_param", "inner.flat_param"]),
                             state_dict.keys())
        elif state_dict_type == "state_dict":
            # Keys should match local model.
            local_model = self._initialize_model(wrap_fsdp=False,
                                                 wrap_ddp=False)
            local_keys = local_model.state_dict().keys()
            self.assertEqual(state_dict.keys(), local_keys)
        else:
            raise NotImplementedError(f"No test for {state_dict_type}!")

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_rank0_and_offload", [True, False])
    def test_state_dict_load_into_local_module(self,
                                               state_dict_rank0_and_offload):
        """
        Tests that FSDP's state_dict can be loaded into a local model.
        """
        model = self._initialize_model(wrap_fsdp=True)
        optim = SGD(model.parameters(), lr=0.1)
        in_data = torch.rand(64,
                             4,
                             requires_grad=True,
                             device=torch.device("cuda"))
        for _ in range(3):
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        with FullyShardedDataParallel.summon_full_params(model):
            fsdp_params = deepcopy(list(model.parameters()))

        # get FSDP state_dict. Note that by default we return full_state_dict.
        sd_mgr = self._get_full_state_dict_mgr(model,
                                               state_dict_rank0_and_offload)
        with sd_mgr:
            fsdp_state_dict = model.state_dict()

        self._validate_state_dict_contents(fsdp_state_dict,
                                           state_dict_rank0_and_offload)
        # Create zeroed local model
        blank_local_model = self._initialize_model(wrap_fsdp=False,
                                                   wrap_ddp=False)
        for param in blank_local_model.parameters():
            with torch.no_grad():
                param.zero_()

        # Load fsdp's full state dict into the local and verify params are as
        # expected.
        if state_dict_rank0_and_offload:
            # Broadcast + CUDA state_dict
            fsdp_state_dict = self._broadcast_state_dict(fsdp_state_dict)
            for key in fsdp_state_dict.keys():
                fsdp_state_dict[key] = fsdp_state_dict[key].cuda()

        blank_local_model.load_state_dict(fsdp_state_dict)
        local_params = list(blank_local_model.parameters())
        for fsdp_param, local_param in zip(fsdp_params, local_params):
            self.assertEqual(fsdp_param, local_param)

    @skip_if_lt_x_gpu(2)
    @parametrize("double_nest", [True])
    def test_state_dict_skip_module(self, double_nest):
        torch.cuda.set_device(self.rank)

        def _create_module(wrap_fsdp=True):
            LINEAR_SKIP = "linear_skip"
            ctx = enable_wrap(wrapper_cls=FSDP) if wrap_fsdp else suppress()
            with ctx:
                module = SkipModel(double_nest=double_nest)
                # Full name of linear_skip param tensors in SkipModel, as would be
                # stored in checkpoint.
                linear_skip_tensor_names = [
                    k for k in dict(module.named_parameters()).keys()
                    if LINEAR_SKIP in k
                ]
                # skip SkipModule
                linear_skip = getattr(module, LINEAR_SKIP)
                delattr(module, LINEAR_SKIP)
                # Wrap FSDP
                fsdp = wrap(module)
                # reattach
                setattr(module, LINEAR_SKIP, linear_skip)
                return fsdp, linear_skip_tensor_names

        fsdp, linear_skip_tensor_names = _create_module()
        # Run a forward pass
        inp = torch.randn((1, 10), device=torch.cuda.current_device())
        loss = fsdp(inp)
        loss.sum().backward()

        state_dict = fsdp.state_dict()
        if self.rank == 0:
            sd_keys = list(state_dict.keys())
            expected = list(SkipModel(double_nest=False).state_dict().keys())
            self.assertEqual(sorted(sd_keys), sorted(expected))
            # TODO: parameters in linear_skip_tensor_names should not be handled
            # by FSDP.state_dict(). Have a check once this is implemented in
            # FSDP.state_dict().

        # Check that it can be loaded into FSDP.
        new_fsdp, _ = _create_module()
        _zero_model(new_fsdp)
        for (p1, p2) in zip(fsdp.parameters(), new_fsdp.parameters()):
            self.assertNotEqual(p1, p2)
        new_fsdp.load_state_dict(deepcopy(state_dict))
        for (p1, p2) in zip(fsdp.parameters(), new_fsdp.parameters()):
            self.assertEqual(p1, p2)

        # Test that the checkpoint can be loaded into a local model.
        local, _ = _create_module(wrap_fsdp=False)
        for param in local.parameters():
            with torch.no_grad():
                param.zero_()

        with fsdp.summon_full_params(fsdp):
            for (p1, p2) in zip(fsdp.parameters(), local.parameters()):
                self.assertNotEqual(p1, p2)

        local.load_state_dict(deepcopy(state_dict))
        with fsdp.summon_full_params(fsdp):
            for (p1, p2) in zip(fsdp.parameters(), local.parameters()):
                self.assertEqual(p1, p2)

    @skip_if_lt_x_gpu(2)
    def test_wrong_state_dict_config(self):
        model = FSDP(Model(wrap_fsdp=True).cuda())
        with self.assertRaisesRegex(RuntimeError,
                                    "Expected state_dict_config of type"):
            with model.state_dict_type(model, StateDictType.FULL_STATE_DICT,
                                       LocalStateDictConfig()):
                pass

    @skip_if_lt_x_gpu(2)
    def test_state_dict_with_ignored_modules(self):
        # Initialize an FSDP-wrapped model with an ignored module
        model = Model(wrap_fsdp=True).cuda()
        ignored_modules = [model.outer]
        ignored_param_to_param_name = {
            model.outer.bias: "outer.bias",
            model.outer.weight: "outer.weight",
        }
        fsdp_model = FSDP(model, ignored_modules=ignored_modules)
        with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
            sd = fsdp_model.state_dict()

        with FSDP.summon_full_params(fsdp_model):
            fsdp_params = deepcopy(list(fsdp_model.parameters()))

        # Check that the ignored parameters are not cloned
        for param, param_name in ignored_param_to_param_name.items():
            self.assertTrue(param_name in sd)
            self.assertEqual(param.data_ptr(), sd[param_name].data_ptr())
        # Check that the state dict can be loaded into a non-wrapped version of
        # the model
        nonwrapped_model = Model(wrap_fsdp=False).cuda()
        for param in nonwrapped_model.parameters():
            with torch.no_grad():
                param.zero_()

        nonwrapped_model.load_state_dict(sd)
        local_params = list(nonwrapped_model.parameters())
        for fsdp_param, local_param in zip(fsdp_params, local_params):
            self.assertEqual(fsdp_param, local_param)
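
The state-dict round trips above follow a common pattern. A minimal hedged sketch of saving a full state dict with rank0_only and CPU offload; the function name and checkpoint path are placeholders:

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def save_full_state_dict_sketch(model: FSDP, path: str = "checkpoint.pt") -> None:
    cfg = FullStateDictConfig(rank0_only=True, offload_to_cpu=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
        state_dict = model.state_dict()  # empty dict on nonzero ranks
    if dist.get_rank() == 0:
        torch.save(state_dict, path)
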
Example #23
class TestFSDPStateDict(FSDPTest):
    @property
    def world_size(self):
        return 2

    def _get_simple_nested_model(self, *fsdp_args, **fsdp_kwargs):
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(10, 10, bias=False), *fsdp_args, **fsdp_kwargs),
                nn.Linear(10, 10, bias=False),
            ),
            *fsdp_args,
            **fsdp_kwargs,
        )
        return model

    def _get_simple_model(self, *fsdp_args, **fsdp_kwargs):
        model = FSDP(nn.Linear(10, 10, bias=False), *fsdp_args, **fsdp_kwargs)
        return model

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    @parametrize("fp16", [True, False])
    def test_basic_save_and_load_state_dict(self, cpu_offload, fp16):
        """
        Tests that we can save a state_dict and load it into a blank model
        with various configs such as fp16 and cpu offload and parameters
        match as expected.
        """
        for model_call in [
            partial(self._get_simple_nested_model, cpu_offload=cpu_offload),
            partial(self._get_simple_model, cpu_offload=cpu_offload),
        ]:
            model = model_call()
            fsdp_state_dict = _get_state_dict(model, cpu_offload.offload_params, fp16)
            if fp16:
                # Verify that the saved tensors are fp16
                for tensor in fsdp_state_dict.values():
                    self.assertEqual(tensor.dtype, torch.float16)

            model_new = model_call()
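            # With offload_params=True the sharded parameters are expected to
            # live on CPU, so only move the new model to GPU when CPU offload
            # is disabled.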
            if not cpu_offload.offload_params:
                model_new = model_new.cuda()
            if fp16:
                model_new.half()

            # Zero the new model so that its parameters differ from the saved ones.
            _zero_model(model_new)

            with model.summon_full_params(), model_new.summon_full_params():
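                # summon_full_params gathers the full, unsharded parameters so
                # the two models can be compared directly here.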
                params = list(model.parameters())
                params_new = list(model_new.parameters())
                self.assertNotEqual(params, params_new)

            # Verify parameters are the same in the new model.
            model_new.load_state_dict(fsdp_state_dict)
            with model_new.summon_full_params():
                with model.summon_full_params():
                    params = list(model.parameters())
                    params_new = list(model_new.parameters())
                    self.assertEqual(params, params_new)
                    if fp16:
                        for tensor in model_new.parameters():
                            self.assertEqual(tensor.dtype, torch.float16)
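
    # Minimal sketch of the round trip exercised above (simplified; the test
    # itself goes through the _get_state_dict helper):
    #
    #   model = self._get_simple_model(cpu_offload=cpu_offload)
    #   sd = model.state_dict()
    #   model_new = self._get_simple_model(cpu_offload=cpu_offload)
    #   model_new.load_state_dict(sd)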

    @skip_if_lt_x_gpu(2)
    def test_save_and_load_after_forward_state_dict(self):
        """
        Test that saving after some training results in params being updated as
        expected.
        """
        torch.cuda.set_device(self.rank)
        model = self._get_wrapped_model(group=torch.distributed.distributed_c10d._get_default_group())
        optim = torch.optim.SGD(model.parameters(), lr=0.1)
        initial_params = _get_full_detached_param(model)
        for _ in range(6):
            inp = model.module.get_input(torch.device("cuda"))
            output = model(*inp)
            loss = model.module.get_loss(inp, output).cuda()
            model.module.run_backward(loss)
            optim.step()

        trained_params = _get_full_detached_param(model)
        # Ensure some training occurred
        self.assertNotEqual(initial_params, trained_params)
        # Save a copy of the state_dict
        state_dict = {k: v.clone() for k, v in model.state_dict().items()}
        _zero_model(model)

        # Load state_dict into zeroed model
        model.load_state_dict(state_dict)
        loaded_params = _get_full_detached_param(model)
        self.assertEqual(loaded_params, trained_params)

    def _initialize_model(self, wrap_fsdp: bool, wrap_ddp: bool = True):
        # keep everything deterministic for input data
        torch.manual_seed(0)

        model = Model(wrap_fsdp).cuda()
        if wrap_fsdp:
            model = FSDP(model)
        elif wrap_ddp:
            model = DistributedDataParallel(model, device_ids=[self.rank])
        return model

    @staticmethod
    def _state_dict(model: Module, state_dict_type: str):
        try:
            enum_val = STATE_DICT_MAPPING[state_dict_type]
        except KeyError:
            raise ValueError(f"No state_dict type for {state_dict_type}")

        with FSDP.state_dict_type(model, enum_val):
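            # Under this context manager, every FSDP instance in `model`
            # produces the chosen flavor of state_dict (e.g. full vs. local).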
            return model.state_dict()

    @staticmethod
    def _load_state_dict(
        model: Module, state_dict_type: str, state_dict: Dict[str, Any]
    ):
        try:
            enum_val = STATE_DICT_MAPPING[state_dict_type]
        except KeyError:
            raise ValueError(f"No state_dict type for {state_dict_type}")

        with FSDP.state_dict_type(model, enum_val):
            return model.load_state_dict(state_dict)
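
    # Sketch of how the two helpers above fit together (hypothetical snippet,
    # not executed by the tests):
    #
    #   sd = self._state_dict(model, "state_dict")
    #   self._load_state_dict(blank_model, "state_dict", sd)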

    def _dist_train(self, wrap_fsdp: bool, state_dict_type: str = ""):
        # TODO: Move this test to common_fsdp.
        model = self._initialize_model(wrap_fsdp)
        optim = SGD(model.parameters(), lr=0.1)

        in_data = torch.rand(64, 4, requires_grad=True, device=torch.device("cuda"))
        for _ in range(3):
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        if wrap_fsdp:
            blank_model = FSDP(Model(True).cuda())
            _zero_model(blank_model)
            state_dict = self._state_dict(model, state_dict_type)
            self._load_state_dict(blank_model, state_dict_type, state_dict)
            return get_full_params(blank_model)
        else:
            return list(model.parameters())

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    def test_state_dict_save_load_flow(self, state_dict_type):
        fsdp_params = self._dist_train(wrap_fsdp=True, state_dict_type=state_dict_type)
        ddp_params = self._dist_train(wrap_fsdp=False)
        self.assertEqual(ddp_params, fsdp_params)

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    def test_fsdp_state_dict_keys(self, state_dict_type):
        state_dict = self._state_dict(self._initialize_model(True), state_dict_type)
        if state_dict_type == "local_state_dict":
            self.assertEqual(set(["flat_param", "inner.flat_param"]), state_dict.keys())
        elif state_dict_type == "state_dict":
            # Keys should match local model.
            local_model = self._initialize_model(wrap_fsdp=False, wrap_ddp=False)
            local_keys = local_model.state_dict().keys()
            self.assertEqual(state_dict.keys(), local_keys)
        else:
            raise NotImplementedError(f"No test for {state_dict_type}!")

    @skip_if_lt_x_gpu(2)
    def test_state_dict_load_into_local_module(self):
        """
        Tests that FSDP's state_dict can be loaded into a local model.
        """
        model = self._initialize_model(wrap_fsdp=True)
        optim = SGD(model.parameters(), lr=0.1)
        in_data = torch.rand(64, 4, requires_grad=True, device=torch.device("cuda"))
        for _ in range(3):
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        with model.summon_full_params():
            fsdp_params = deepcopy(list(model.parameters()))

        # Get the FSDP state_dict. Note that by default this returns the full
        # state_dict.
        fsdp_state_dict = model.state_dict()
        # Create zeroed local model
        blank_local_model = self._initialize_model(wrap_fsdp=False, wrap_ddp=False)
        for param in blank_local_model.parameters():
            with torch.no_grad():
                param.zero_()

        # Load FSDP's full state_dict into the local model and verify that the
        # params match as expected.
        blank_local_model.load_state_dict(fsdp_state_dict)
        local_params = list(blank_local_model.parameters())
        for fsdp_param, local_param in zip(fsdp_params, local_params):
            self.assertEqual(fsdp_param, local_param)
Example #24
class TestParityWithDDP(FSDPTest):
    """
    Compare losses and parameter values after several updates when using
    PyTorch DDP vs. FullyShardedDataParallel.
    """
    def _get_init_modes_for_test(self, cpu_offload):
        modes = [FSDPInitMode.CUDA_AFTER, FSDPInitMode.CUDA_BEFORE]
        # Note that FSDPInitMode.CUDA_NEVER currently works only with CPU
        # offload, since that path explicitly brings the param back to the
        # CUDA device. In general it does not work, because we would try to
        # all_gather p.data while it is on CPU, and NCCL only supports GPU
        # tensors.
        if cpu_offload.offload_params:
            modes.append(FSDPInitMode.CUDA_NEVER)

        return modes
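
    # The init modes come from common_fsdp's FSDPInitMode (assumed semantics):
    # CUDA_BEFORE moves the module to GPU before wrapping with FSDP,
    # CUDA_AFTER moves it after wrapping, and CUDA_NEVER leaves it on CPU,
    # which only makes sense together with CPU offload per the note above.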

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True),
         CPUOffload(offload_params=False)])
    @parametrize(
        "backward_prefetch",
        [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None])
    def test_nested_wrapped_model(self, cpu_offload, backward_prefetch):
        init_modes = self._get_init_modes_for_test(cpu_offload)
        for fsdp_init_mode in init_modes:
            with self.subTest(fsdp_init_mode=fsdp_init_mode):
                self._test_identical_outputs(
                    NestedWrappedModule,
                    fsdp_init_mode=fsdp_init_mode,
                    cpu_offload=cpu_offload,
                    backward_prefetch=backward_prefetch,
                )

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True),
         CPUOffload(offload_params=False)])
    @parametrize(
        "backward_prefetch",
        [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None])
    @parametrize("clip_norm_type", [2.0, None])
    def test_nested_all_wrapped_model(self, cpu_offload, backward_prefetch,
                                      clip_norm_type):
        init_modes = self._get_init_modes_for_test(cpu_offload)
        for fsdp_init_mode in init_modes:
            with self.subTest(fsdp_init_mode=fsdp_init_mode):
                model_fn = functools.partial(NestedWrappedModule,
                                             wrap_everything=True)
                self._test_identical_outputs(
                    model_fn,
                    fsdp_init_mode=fsdp_init_mode,
                    cpu_offload=cpu_offload,
                    backward_prefetch=backward_prefetch,
                    norm_type=clip_norm_type,
                )

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True),
         CPUOffload(offload_params=False)])
    @parametrize(
        "backward_prefetch",
        [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None])
    @parametrize("clip_norm_type", [2.0, None])
    def test_transformer_parameterized(self, cpu_offload, backward_prefetch,
                                       clip_norm_type):
        init_modes = self._get_init_modes_for_test(cpu_offload)
        for fsdp_init_mode in init_modes:
            with self.subTest(fsdp_init_mode=fsdp_init_mode):
                self._test_identical_outputs(
                    TransformerWithSharedParams,
                    fsdp_init_mode=fsdp_init_mode,
                    cpu_offload=cpu_offload,
                    backward_prefetch=backward_prefetch,
                    norm_type=clip_norm_type,
                )

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True),
         CPUOffload(offload_params=False)])
    @parametrize(
        "backward_prefetch",
        [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None])
    def test_delayed_optim_step(self, cpu_offload, backward_prefetch):
        # We use a model with a long CUDA delay right before the optimizer step.
        # This tests our streams logic, and that we don't start the allgather
        # until after the optimization step completes.
        init_modes = self._get_init_modes_for_test(cpu_offload)
        for fsdp_init_mode in init_modes:
            with self.subTest(fsdp_init_mode=fsdp_init_mode):
                model_fn = functools.partial(NestedWrappedModuleWithDelay,
                                             delay_after_loss_ms=250)
                self._test_identical_outputs(
                    model_fn,
                    fsdp_init_mode=fsdp_init_mode,
                    cpu_offload=cpu_offload,
                    backward_prefetch=backward_prefetch,
                )

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True),
         CPUOffload(offload_params=False)])
    @parametrize(
        "backward_prefetch",
        [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None])
    def test_delayed_reduce_scatter(self, cpu_offload, backward_prefetch):
        # We insert a delay in the torch.distributed._reduce_scatter_base op, so that
        # the post_backward_stream takes much longer than the backward pass.
        # This tests that we properly block at the end of the backward pass for
        # the reductions to finish.
        init_modes = self._get_init_modes_for_test(cpu_offload)
        for fsdp_init_mode in init_modes:
            with self.subTest(fsdp_init_mode=fsdp_init_mode):
                model_fn = functools.partial(NestedWrappedModuleWithDelay,
                                             delay_before_reduction_ms=250)
                self._test_identical_outputs(
                    model_fn,
                    fsdp_init_mode=fsdp_init_mode,
                    cpu_offload=cpu_offload,
                    backward_prefetch=backward_prefetch,
                )

    def _dummy_ddp_fn(self, model):
        return DummyDDP(model)

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True),
         CPUOffload(offload_params=False)])
    @parametrize(
        "backward_prefetch",
        [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None])
    @parametrize("clip_norm_type", [2.0, None])
    def test_mixture_of_experts(self, cpu_offload, backward_prefetch,
                                clip_norm_type):
        init_modes = self._get_init_modes_for_test(cpu_offload)
        for fsdp_init_mode in init_modes:
            with self.subTest(fsdp_init_mode=fsdp_init_mode):
                self._test_identical_outputs(
                    MixtureOfExperts,
                    # MixtureOfExperts implements custom reduce logic, so the reference
                    # behavior should use that logic instead of PyTorch DDP.
                    ref_ddp_fn=self._dummy_ddp_fn,
                    fsdp_init_mode=fsdp_init_mode,
                    cpu_offload=cpu_offload,
                    backward_prefetch=backward_prefetch,
                    norm_type=clip_norm_type,
                )

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True),
         CPUOffload(offload_params=False)])
    @parametrize(
        "backward_prefetch",
        [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None])
    def test_mixture_of_experts_with_delay_before_free(self, cpu_offload,
                                                       backward_prefetch):
        init_modes = self._get_init_modes_for_test(cpu_offload)
        for fsdp_init_mode in init_modes:
            with self.subTest(fsdp_init_mode=fsdp_init_mode):
                model_fn = functools.partial(MixtureOfExperts,
                                             delay_before_free_ms=250)
                self._test_identical_outputs(
                    model_fn,
                    ref_ddp_fn=self._dummy_ddp_fn,
                    fsdp_init_mode=fsdp_init_mode,
                    cpu_offload=cpu_offload,
                    backward_prefetch=backward_prefetch,
                )