    def test_save_and_load_after_forward_state_dict(self):
        """
        Test that saving after some training results in params being updated as
        expected.
        """
        torch.cuda.set_device(self.rank)
        model = self._get_wrapped_model(group=torch.distributed.distributed_c10d._get_default_group())
        optim = torch.optim.SGD(model.parameters(), lr=0.1)
        initial_params = _get_full_detached_param(model)
        for _ in range(6):
            inp = model.module.get_input(torch.device("cuda"))
            output = model(*inp)
            loss = model.module.get_loss(inp, output).cuda()
            model.module.run_backward(loss)
            optim.step()

        trained_params = _get_full_detached_param(model)
        # Ensure some training occurred
        self.assertNotEqual(initial_params, trained_params)
        # Save a copy of the state_dict
        state_dict = {k: v.clone() for k, v in model.state_dict().items()}
        _zero_model(model)

        # Load state_dict into zeroed model
        model.load_state_dict(state_dict)
        loaded_params = _get_full_detached_param(model)
        self.assertEqual(loaded_params, trained_params)
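The helpers _get_full_detached_param and _zero_model used throughout these examples are defined elsewhere in the test file and are not shown here. A minimal sketch of what they might look like, assuming they rely on FSDP.summon_full_params to materialize the unsharded parameters (the names and bodies below are illustrative, not the suite's actual code):

# Hypothetical sketches of the helpers used above; the real implementations
# live elsewhere in the PyTorch test suite and may differ.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def _get_full_detached_param_sketch(fsdp_model):
    # Gather the full (unsharded) parameters and return detached CPU copies
    # so they can be compared later without tracking gradients.
    with FSDP.summon_full_params(fsdp_model):
        return [p.detach().clone().cpu() for p in fsdp_model.parameters()]

def _zero_model_sketch(fsdp_model):
    # Zero every parameter in place so a later load_state_dict is observable.
    with FSDP.summon_full_params(fsdp_model):
        with torch.no_grad():
            for p in fsdp_model.parameters():
                p.zero_()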
Example 2
    def test_save_and_load_after_forward_state_dict(self, mixed_precision):
        """
        Test that saving after some training results in params being updated as
        expected.
        """
        torch.cuda.set_device(self.rank)
        mixed_precision = MixedPrecision() if mixed_precision else None
        model = self._get_simple_nested_model(mixed_precision=mixed_precision)
        optim = torch.optim.SGD(model.parameters(), lr=0.1)
        initial_params = _get_full_detached_param(model)
        for _ in range(6):
            inp = torch.randn(1, 10, device=torch.cuda.current_device())
            output = model(*inp)
            loss = output.sum()
            expected_dtype = torch.float32 if mixed_precision is None else torch.float16
            self.assertEqual(expected_dtype, loss.dtype)
            loss.backward()
            optim.step()

        trained_params = _get_full_detached_param(model)
        # Ensure some training occurred
        self.assertNotEqual(initial_params, trained_params)
        # Save a copy of the state_dict
        state_dict = {k: v.clone() for k, v in model.state_dict().items()}
        _zero_model(model)

        # Ensure checkpointed params have the full param dtype
        for tensor in state_dict.values():
            self.assertEqual(tensor.dtype, torch.float32)

        # Load state_dict into zeroed model
        model.load_state_dict(state_dict)
        loaded_params = _get_full_detached_param(model)
        self.assertEqual(loaded_params, trained_params)
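_get_simple_nested_model is likewise not shown on this page; presumably it builds a small nested module, wraps it in FSDP, and forwards the mixed_precision config. A sketch under that assumption (illustrative names only):

# A minimal sketch (an assumption) of a nested-FSDP model factory that accepts
# a MixedPrecision config, similar in spirit to _get_simple_nested_model.
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision

def get_simple_nested_model_sketch(mixed_precision=None):
    inner = FSDP(nn.Linear(10, 10, bias=False).cuda(),
                 mixed_precision=mixed_precision)
    seq = nn.Sequential(inner, nn.Linear(10, 10, bias=False).cuda())
    # Wrap the outer module as well so both levels share the same precision policy.
    return FSDP(seq, mixed_precision=mixed_precision)

Per the assertions in the test above, passing MixedPrecision() is expected to make the forward (and hence the loss) run in torch.float16, while the checkpointed parameters remain torch.float32.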
Example 3
    def test_basic_save_and_load_state_dict(self, cpu_offload, fp16,
                                            state_dict_rank0_and_offload):
        """
        Tests that we can save a state_dict and load it into a blank model
        with various configs such as fp16 and cpu offload and parameters
        match as expected.
        """
        for model_call in [
                partial(self._get_simple_nested_model,
                        cpu_offload=cpu_offload),
                partial(self._get_simple_model, cpu_offload=cpu_offload),
        ]:
            model = model_call()
            full_state_dict_mgr = self._get_full_state_dict_mgr(
                model, state_dict_rank0_and_offload)
            with full_state_dict_mgr:
                fsdp_state_dict = _get_state_dict(model,
                                                  cpu_offload.offload_params,
                                                  fp16)

            self._validate_state_dict_contents(fsdp_state_dict,
                                               state_dict_rank0_and_offload)
            if fp16:
                # Verify that the saved tensors are fp16
                for tensor in fsdp_state_dict.values():
                    self.assertEqual(tensor.dtype, torch.float16)

            model_new = model_call()
            if not cpu_offload.offload_params:
                model_new = model_new.cuda()
            if fp16:
                model_new.half()

            # zero the model to ensure parameters are different.
            _zero_model(model_new)

            with FullyShardedDataParallel.summon_full_params(model):
                with FullyShardedDataParallel.summon_full_params(model_new):
                    params = list(model.parameters())
                    params_new = list(model_new.parameters())
                    self.assertNotEqual(params, params_new)

            # Verify parameters are the same in the new model.
            if state_dict_rank0_and_offload:
                # Broadcast the state dict and move it back to GPU in
                # preparation for loading.
                fsdp_state_dict = self._broadcast_state_dict(fsdp_state_dict)
                for key in fsdp_state_dict.keys():
                    fsdp_state_dict[key] = fsdp_state_dict[key].cuda()

            model_new.load_state_dict(fsdp_state_dict)
            with FullyShardedDataParallel.summon_full_params(model_new):
                with FullyShardedDataParallel.summon_full_params(model):
                    params = list(model.parameters())
                    params_new = list(model_new.parameters())
                    self.assertEqual(params, params_new)
                    if fp16:
                        for tensor in model_new.parameters():
                            self.assertEqual(tensor.dtype, torch.float16)
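_get_full_state_dict_mgr (the full_state_dict_mgr context above) is another helper defined outside these snippets. A plausible sketch, assuming it returns the FSDP.state_dict_type context manager configured with a FullStateDictConfig, is:

# Hypothetical sketch of a full-state-dict context-manager factory; the real
# helper in the test suite may differ.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig, StateDictType

def get_full_state_dict_mgr_sketch(model, rank0_and_offload):
    cfg = FullStateDictConfig(
        rank0_only=rank0_and_offload,      # materialize the dict on rank 0 only
        offload_to_cpu=rank0_and_offload,  # and move it to CPU to save GPU memory
    )
    return FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg)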
Example 4
    def test_save_and_load_after_forward_state_dict(
            self, mixed_precision, state_dict_rank0_and_offload):
        """
        Test that saving after some training results in params being updated as
        expected.
        """
        torch.cuda.set_device(self.rank)
        mixed_precision = MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
        ) if mixed_precision else None
        model = self._get_simple_nested_model(mixed_precision=mixed_precision)
        optim = torch.optim.SGD(model.parameters(), lr=0.1)
        initial_params = _get_full_detached_param(model)
        for _ in range(6):
            inp = torch.randn(1, 10, device=torch.cuda.current_device())
            output = model(*inp)
            loss = output.sum()
            expected_dtype = torch.float32 if mixed_precision is None else torch.float16
            self.assertEqual(expected_dtype, loss.dtype)
            loss.backward()
            optim.step()

        trained_params = _get_full_detached_param(model)
        # Ensure some training occurred
        self.assertNotEqual(initial_params, trained_params)
        # Save a copy of the state_dict
        fsd_mgr = self._get_full_state_dict_mgr(model,
                                                state_dict_rank0_and_offload)
        with fsd_mgr:
            state_dict = {k: v.clone() for k, v in model.state_dict().items()}
        self._validate_state_dict_contents(state_dict,
                                           state_dict_rank0_and_offload)
        _zero_model(model)

        # Ensure checkpointed params have the full param dtype
        for tensor in state_dict.values():
            self.assertEqual(tensor.dtype, torch.float32)

        # Load state_dict into zeroed model
        if state_dict_rank0_and_offload:
            # Broadcast the state dict and move it back to GPU in
            # preparation for loading.
            state_dict = self._broadcast_state_dict(state_dict)
            for key in state_dict.keys():
                state_dict[key] = state_dict[key].cuda()

        model.load_state_dict(state_dict)
        loaded_params = _get_full_detached_param(model)
        self.assertEqual(loaded_params, trained_params)
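When state_dict_rank0_and_offload is set, only rank 0 ends up with a populated state dict, so it must be broadcast before every rank can call load_state_dict. A sketch of what _broadcast_state_dict could do, assuming it uses torch.distributed.broadcast_object_list (the helper's real implementation is not shown here):

# Hypothetical sketch of broadcasting a CPU state dict from rank 0 to all ranks.
import torch.distributed as dist

def broadcast_state_dict_sketch(state_dict):
    # Wrap the dict in a one-element list so it can be broadcast as a Python object;
    # non-zero ranks pass None and receive rank 0's copy in its place.
    obj_list = [state_dict if dist.get_rank() == 0 else None]
    dist.broadcast_object_list(obj_list, src=0)
    return obj_list[0]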
Example 5
 def test_fsdp_state_dict_with_activation_checkpoint(self, checkpoint_wrap):
     for model_call in [
             partial(self._get_simple_model),
             partial(self._get_simple_nested_model)
     ]:
         model = model_call(
             checkpoint_wrap=(checkpoint_wrap in ["first", "both"]))
         state_dict = _get_state_dict(model, False, False)
         # Possibly wrap the new model in an activation checkpoint wrapper to
         # test save/load with this wrapper
         model_new = model_call(
             checkpoint_wrap=(checkpoint_wrap in ["second", "both"]))
         _zero_model(model_new)
         self._compare_models(model, model_new, self.assertNotEqual)
         # Would fail if checkpoint_wrapper did not correctly implement state_dict pre/post hooks
         model_new.load_state_dict(state_dict)
         self._compare_models(model, model_new, self.assertEqual)
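The checkpoint_wrap argument presumably controls whether the model factory wraps its module in an activation-checkpointing wrapper before applying FSDP. A hedged sketch of that idea (not the suite's actual factory):

# A minimal sketch (an assumption) of wrapping a submodule with activation
# checkpointing before handing it to FSDP.
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def get_checkpoint_wrapped_model_sketch(checkpoint_wrap=False):
    lin = nn.Linear(10, 10, bias=False).cuda()
    if checkpoint_wrap:
        # checkpoint_wrapper installs state_dict pre/post hooks that strip its
        # own prefix, which is what the load in the test above depends on.
        lin = checkpoint_wrapper(lin)
    return FSDP(lin)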
Example 6
    def test_basic_save_and_load_state_dict(self, cpu_offload, fp16):
        """
        Tests that we can save a state_dict and load it into a blank model
        with various configs such as fp16 and cpu offload and parameters
        match as expected.
        """
        for model_call in [
                partial(self._get_simple_nested_model,
                        cpu_offload=cpu_offload),
                partial(self._get_simple_model, cpu_offload=cpu_offload),
        ]:
            model = model_call()
            fsdp_state_dict = _get_state_dict(model,
                                              cpu_offload.offload_params, fp16)
            if fp16:
                # Verify that the saved tensors are fp16
                for tensor in fsdp_state_dict.values():
                    self.assertEqual(tensor.dtype, torch.float16)

            model_new = model_call()
            if not cpu_offload.offload_params:
                model_new = model_new.cuda()
            if fp16:
                model_new.half()

            # zero the model to ensure parameters are different.
            _zero_model(model_new)

            with FullyShardedDataParallel.summon_full_params(model):
                with FullyShardedDataParallel.summon_full_params(model_new):
                    params = list(model.parameters())
                    params_new = list(model_new.parameters())
                    self.assertNotEqual(params, params_new)

            # Verify parameters are the same in the new model.
            model_new.load_state_dict(fsdp_state_dict)
            with FullyShardedDataParallel.summon_full_params(model_new):
                with FullyShardedDataParallel.summon_full_params(model):
                    params = list(model.parameters())
                    params_new = list(model_new.parameters())
                    self.assertEqual(params, params_new)
                    if fp16:
                        for tensor in model_new.parameters():
                            self.assertEqual(tensor.dtype, torch.float16)
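cpu_offload here is an FSDP CPUOffload config; a minimal sketch of how a factory like _get_simple_model might apply it (illustrative only):

# Hypothetical sketch: passing a CPUOffload config through to FSDP.
import torch.nn as nn
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def get_simple_model_with_offload_sketch(cpu_offload=None):
    # With CPUOffload(offload_params=True), sharded parameters live on CPU
    # between uses, which is why the tests above skip the .cuda() move on the
    # freshly built model in that case.
    return FSDP(nn.Linear(10, 10, bias=False).cuda(),
                cpu_offload=cpu_offload or CPUOffload(offload_params=False))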
Example 7
 def test_fsdp_state_dict_with_activation_checkpoint(self, checkpoint_wrap):
     """Tests saving the state dict, zeroing a target model's parameters, and
     loading the state dict, where the source and target models may have a
     checkpoint wrapper."""
     for model_call in [
             partial(self._get_simple_model),
             partial(self._get_simple_nested_model)
     ]:
         model = model_call(
             checkpoint_wrap=(checkpoint_wrap in ["first", "both"]))
         state_dict = _get_state_dict(model, False, False)
         # Possibly wrap the new model in an activation checkpoint wrapper to
         # test save/load with this wrapper
         model_new = model_call(
             checkpoint_wrap=(checkpoint_wrap in ["second", "both"]))
         _zero_model(model_new)
         self._compare_models(model, model_new, self.assertNotEqual)
         # Would fail if checkpoint_wrapper did not correctly implement state_dict pre/post hooks
         model_new.load_state_dict(state_dict, strict=True)
         self._compare_models(model, model_new, self.assertEqual)
Example 8
    def _dist_train(self, wrap_fsdp: bool, state_dict_type: str = ""):
        # TODO: Move this test to common_fsdp.
        model = self._initialize_model(wrap_fsdp)
        optim = SGD(model.parameters(), lr=0.1)

        in_data = torch.rand(64, 4, requires_grad=True, device=torch.device("cuda"))
        for _ in range(3):
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        if wrap_fsdp:
            blank_model = FSDP(Model(True).cuda())
            _zero_model(blank_model)
            state_dict = self._state_dict(model, state_dict_type)
            self._load_state_dict(blank_model, state_dict_type, state_dict)
            return get_full_params(blank_model)
        else:
            return list(model.parameters())
Example 9
    def _dist_train(self,
                    wrap_fsdp: bool,
                    state_dict_type: str = "",
                    with_context: bool = False):
        # TODO: Move this test to common_fsdp.
        model = self._initialize_model(wrap_fsdp)
        optim = SGD(model.parameters(), lr=0.1)

        in_data = torch.rand(64,
                             4,
                             requires_grad=True,
                             device=torch.device("cuda"))
        for _ in range(3):
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        if wrap_fsdp:
            blank_model = FSDP(Model(True).cuda())
            _zero_model(blank_model)
            if with_context:
                state_dict_type = {
                    "state_dict": StateDictType.FULL_STATE_DICT,
                    "local_state_dict": StateDictType.LOCAL_STATE_DICT,
                    "sharded_state_dict": StateDictType.SHARDED_STATE_DICT,
                }[state_dict_type]
                with model.state_dict_type(state_dict_type):
                    state_dict = model.state_dict()
                with blank_model.state_dict_type(state_dict_type):
                    blank_model.load_state_dict(state_dict)
            else:
                state_dict = self._state_dict(model, state_dict_type)
                self._load_state_dict(blank_model, state_dict_type, state_dict)
            return get_full_params(blank_model)
        else:
            return list(model.parameters())
Example 10
    def test_state_dict_skip_module(self, state_dict_type, double_nest):
        torch.cuda.set_device(self.rank)

        def _create_module(wrap_fsdp=True):
            LINEAR_SKIP = "linear_skip"
            ctx = enable_wrap(wrapper_cls=FSDP) if wrap_fsdp else suppress()
            with ctx:
                module = SkipModel(double_nest=double_nest)
                # Full names of the linear_skip parameter tensors in SkipModel,
                # as they would be stored in a checkpoint.
                linear_skip_tensor_names = [
                    k for k in dict(module.named_parameters()).keys()
                    if LINEAR_SKIP in k
                ]
                # skip SkipModule
                linear_skip = getattr(module, LINEAR_SKIP)
                delattr(module, LINEAR_SKIP)
                # Wrap FSDP
                fsdp = wrap(module)
                # reattach
                setattr(module, LINEAR_SKIP, linear_skip)
                return fsdp, linear_skip_tensor_names

        fsdp, linear_skip_tensor_names = _create_module()
        # Run a forward pass
        inp = torch.randn((1, 10), device=torch.cuda.current_device())
        loss = fsdp(inp)
        loss.sum().backward()

        with FSDP.state_dict_type(fsdp, STATE_DICT_MAPPING[state_dict_type]):
            state_dict = fsdp.state_dict()
        if self.rank == 0 and state_dict_type != "local_state_dict":
            sd_keys = list(state_dict.keys())
            expected = list(SkipModel(double_nest=False).state_dict().keys())
            self.assertEqual(sorted(sd_keys), sorted(expected))
            # TODO: parameters in linear_skip_tensor_names should not be handled
            # by FSDP.state_dict(). Have a check once this is implemented in
            # FSDP.state_dict().

        # Check that it can be loaded into FSDP.
        new_fsdp, _ = _create_module()
        _zero_model(new_fsdp)
        for (p1, p2) in zip(fsdp.parameters(), new_fsdp.parameters()):
            self.assertNotEqual(p1, p2)
        with FSDP.state_dict_type(new_fsdp,
                                  STATE_DICT_MAPPING[state_dict_type]):
            if state_dict_type != "local_state_dict":
                # FlatParameter has not supported deepcopy yet.
                state_dict = deepcopy(state_dict)
            new_fsdp.load_state_dict(state_dict, strict=True)
        for (p1, p2) in zip(fsdp.parameters(), new_fsdp.parameters()):
            self.assertEqual(p1, p2)

        # Test that the checkpoint can be loaded into a local model.
        local, _ = _create_module(wrap_fsdp=False)
        for param in local.parameters():
            with torch.no_grad():
                param.zero_()

        with fsdp.summon_full_params(fsdp):
            for (p1, p2) in zip(fsdp.parameters(), local.parameters()):
                self.assertNotEqual(p1, p2)

        if state_dict_type == "local_state_dict":
            return
        state_dict = _gather_state_dict(state_dict)
        with fsdp.summon_full_params(fsdp):
            if self.rank == 0:
                local.load_state_dict(state_dict, strict=True)
                for (p1, p2) in zip(fsdp.parameters(), local.parameters()):
                    self.assertEqual(p1, p2)
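STATE_DICT_MAPPING maps the string parametrization of these tests to StateDictType values; given the dict spelled out inline in Example 9, it presumably is:

# Presumed definition of STATE_DICT_MAPPING, mirroring the inline dict in Example 9.
from torch.distributed.fsdp import StateDictType

STATE_DICT_MAPPING = {
    "state_dict": StateDictType.FULL_STATE_DICT,
    "local_state_dict": StateDictType.LOCAL_STATE_DICT,
    "sharded_state_dict": StateDictType.SHARDED_STATE_DICT,
}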
Example 11
    def test_basic_save_and_load_state_dict(self, state_dict_type, cpu_offload,
                                            fp16,
                                            state_dict_rank0_and_offload):
        """
        Tests that we can save a state_dict and load it into a blank model
        with various configs such as fp16 and cpu offload and parameters
        match as expected.
        """
        if state_dict_rank0_and_offload and state_dict_type != "state_dict":
            return
        for model_call in [
                partial(self._get_non_fsdp_root_module,
                        cpu_offload=cpu_offload),
                partial(self._get_simple_nested_model,
                        cpu_offload=cpu_offload),
                partial(self._get_simple_model, cpu_offload=cpu_offload),
        ]:
            model = model_call()

            ctx = self._get_state_dict_mgr(model, state_dict_type,
                                           state_dict_rank0_and_offload)
            with ctx:
                fsdp_state_dict = _get_state_dict(model,
                                                  cpu_offload.offload_params,
                                                  fp16)

            ignore_keys = [
                k for k in fsdp_state_dict.keys() if NON_ROOT_FSDP_PREFIX in k
            ]

            self._validate_state_dict_contents(
                model,
                fsdp_state_dict,
                state_dict_rank0_and_offload,
                ignore_keys=ignore_keys,
            )
            if fp16:
                # Verify that the saved tensors are fp16
                for tensor in fsdp_state_dict.values():
                    self.assertEqual(tensor.dtype, torch.float16)

            model_new = model_call()
            if not cpu_offload.offload_params:
                model_new = model_new.cuda()
            if fp16:
                model_new.half()

            # zero the model to ensure parameters are different.
            _zero_model(model_new)
            self._compare_models(model, model_new, self.assertNotEqual)

            # Verify parameters are the same in the new model.
            if state_dict_rank0_and_offload:
                # Broadcast the state dict and move it back to GPU in
                # preparation for loading.
                if not isinstance(model, FSDP):
                    # Move everything to CPU to avoid running into
                    # https://github.com/pytorch/pytorch/issues/77113, some params
                    # will still be on GPU for non FSDP root modules.
                    for k in fsdp_state_dict.keys():
                        fsdp_state_dict[k] = fsdp_state_dict[k].cpu()
                fsdp_state_dict = self._broadcast_state_dict(fsdp_state_dict)
                for key in fsdp_state_dict.keys():
                    fsdp_state_dict[key] = fsdp_state_dict[key].cuda()
            with FSDP.state_dict_type(model_new,
                                      STATE_DICT_MAPPING[state_dict_type]):
                model_new.load_state_dict(fsdp_state_dict, strict=True)

            self._compare_models(model,
                                 model_new,
                                 self.assertEqual,
                                 check_fp16=fp16)
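_compare_models is another helper defined outside these snippets; judging from how Examples 3 and 6 perform the same check inline, a plausible sketch (an assumption) is:

# Hypothetical sketch of _compare_models: gather full parameters on both models
# and apply the given assertion (assertEqual or assertNotEqual) pairwise.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def compare_models_sketch(model_a, model_b, assert_fn, check_fp16=False):
    with FSDP.summon_full_params(model_a):
        with FSDP.summon_full_params(model_b):
            params_a = list(model_a.parameters())
            params_b = list(model_b.parameters())
            assert_fn(params_a, params_b)
            if check_fp16:
                # Mirrors the fp16 dtype check done inline in Examples 3 and 6.
                for p in model_b.parameters():
                    assert p.dtype == torch.float16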
Example 12
 def test_state_dict_rank0_offload_save_load_flow(self):
     """Tests saving a model checkpoint only on rank 0 and loading it only
     on rank 0 with ``sync_module_states=True`` to emulate the workflow to
     avoid redundant CPU memory usage."""
     auto_wrap_policy = partial(
         transformer_auto_wrap_policy,
         transformer_layer_cls={
             TransformerEncoderLayer, TransformerDecoderLayer
         },
     )
     fsdp_kwargs = {"auto_wrap_policy": auto_wrap_policy}
     fsdp_model = TransformerWithSharedParams.init(
         self.process_group,
         FSDPInitMode.RECURSIVE,
         CUDAInitMode.CUDA_BEFORE,
         fsdp_kwargs,
     )
     # Force model parameters and buffers to be nonzero
     with FSDP.summon_full_params(fsdp_model):
         for tensor in itertools.chain(fsdp_model.parameters(),
                                       fsdp_model.buffers()):
             if torch.count_nonzero(tensor) == 0:
                 with torch.no_grad():
                     tensor.add_(
                         torch.tensor(1,
                                      dtype=tensor.dtype,
                                      device=tensor.device))
     with self._get_state_dict_mgr(fsdp_model, "state_dict", True):
         state_dict = deepcopy(_get_state_dict(fsdp_model))
     # Initialize a non-wrapped model on all ranks
     new_model = TransformerWithSharedParams.init(
         self.process_group,
         FSDPInitMode.NO_FSDP,
         CUDAInitMode.CUDA_BEFORE,
     )
     _zero_model(new_model, zero_buffers=True)
     # Only load the checkpoint on rank 0
     if self.rank == 0:
         new_model.load_state_dict(state_dict, strict=True)
     _assert_module_states(
         new_model,
         process_group=self.process_group,
         assert_fn=self.assertNotEqual,
     )
     # Broadcast the module states from rank 0 with `sync_module_states=True`
     new_fsdp_model = FSDP(
         new_model,
         device_id=torch.cuda.current_device(),
         auto_wrap_policy=auto_wrap_policy,
         sync_module_states=True,
     )
     # Check FSDP models are equal across ranks
     with FSDP.summon_full_params(new_fsdp_model):
         _assert_module_states(
             new_fsdp_model,
             process_group=self.process_group,
             assert_fn=self.assertEqual,
         )
     # Check FSDP models correctly loaded the checkpoint
     with FullyShardedDataParallel.summon_full_params(fsdp_model):
         with FullyShardedDataParallel.summon_full_params(new_fsdp_model):
             params = list(fsdp_model.parameters())
             params_new = list(new_fsdp_model.parameters())
             self.assertEqual(params, params_new)
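_assert_module_states checks whether a (possibly non-FSDP) module's states agree across ranks. A sketch of how such a helper could work, assuming it gathers every rank's tensors with torch.distributed.all_gather_object and compares them against rank 0 (again illustrative, not the suite's actual code):

# Hypothetical sketch of a cross-rank module-state consistency check.
import torch
import torch.distributed as dist

def assert_module_states_sketch(module, process_group, assert_fn):
    # Collect detached CPU copies of all parameters and buffers on this rank.
    local_states = [t.detach().cpu() for t in
                    list(module.parameters()) + list(module.buffers())]
    world_size = dist.get_world_size(process_group)
    gathered = [None for _ in range(world_size)]
    dist.all_gather_object(gathered, local_states, group=process_group)
    # Compare rank 0's states against every other rank's states.
    for other_rank_states in gathered[1:]:
        for t0, t_other in zip(gathered[0], other_rank_states):
            assert_fn(t0, t_other)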