Example #1
    def test_basic_save_and_load_state_dict(self, cpu_offload, fp16,
                                            state_dict_rank0_and_offload):
        """
        Tests that we can save a state_dict and load it into a blank model
        with various configs, such as fp16 and CPU offload, and that the
        parameters match as expected.
        """
        for model_call in [
                partial(self._get_simple_nested_model,
                        cpu_offload=cpu_offload),
                partial(self._get_simple_model, cpu_offload=cpu_offload),
        ]:
            model = model_call()
            full_state_dict_mgr = self._get_full_state_dict_mgr(
                model, state_dict_rank0_and_offload)
            with full_state_dict_mgr:
                fsdp_state_dict = _get_state_dict(model,
                                                  cpu_offload.offload_params,
                                                  fp16)

            self._validate_state_dict_contents(fsdp_state_dict,
                                               state_dict_rank0_and_offload)
            if fp16:
                # Verify that every state dict tensor is fp16
                for tensor in fsdp_state_dict.values():
                    self.assertEqual(tensor.dtype, torch.float16)

            model_new = model_call()
            if not cpu_offload.offload_params:
                model_new = model_new.cuda()
            if fp16:
                model_new.half()

            # zero the model to ensure parameters are different.
            _zero_model(model_new)

            with FullyShardedDataParallel.summon_full_params(model):
                with FullyShardedDataParallel.summon_full_params(model_new):
                    params = list(model.parameters())
                    params_new = list(model_new.parameters())
                    self.assertNotEqual(params, params_new)

            # Verify parameters are the same in the new model.
            if state_dict_rank0_and_offload:
                # Broadcast the state dict and move it back to GPU in
                # preparation for loading.
                fsdp_state_dict = self._broadcast_state_dict(fsdp_state_dict)
                for key in fsdp_state_dict.keys():
                    fsdp_state_dict[key] = fsdp_state_dict[key].cuda()

            model_new.load_state_dict(fsdp_state_dict)
            with FullyShardedDataParallel.summon_full_params(model_new):
                with FullyShardedDataParallel.summon_full_params(model):
                    params = list(model.parameters())
                    params_new = list(model_new.parameters())
                    self.assertEqual(params, params_new)
                    if fp16:
                        for tensor in model_new.parameters():
                            self.assertEqual(tensor.dtype, torch.float16)
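
A note on the helper used above: `_get_full_state_dict_mgr` is test-harness code that is not shown in these snippets. A minimal sketch of what such a context manager presumably amounts to, assuming FSDP's public `StateDictType` / `FullStateDictConfig` API (the function name `full_state_dict_ctx` and the helper's exact behavior are assumptions):

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig, StateDictType

def full_state_dict_ctx(model, rank0_and_offload: bool):
    # When rank0_and_offload is True, only rank 0 receives the full state
    # dict and the gathered tensors are offloaded to CPU; the other ranks
    # get an empty dict, which is what _validate_state_dict_contents checks.
    cfg = FullStateDictConfig(
        rank0_only=rank0_and_offload,
        offload_to_cpu=rank0_and_offload,
    )
    return FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg)

# Usage mirrors the test above:
#     with full_state_dict_ctx(model, state_dict_rank0_and_offload):
#         fsdp_state_dict = model.state_dict()
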
Example #2
    def test_state_dict_rank0_offload_save_load_flow(self):
        # Tests taking a checkpoint on rank 0 only and reloading it
        # without redundant CPU memory usage.
        model = TransformerWithSharedParams(
            group=dist.distributed_c10d._get_default_group())
        my_auto_wrap_policy = partial(transformer_auto_wrap_policy,
                                      transformer_layer_cls={
                                          TransformerEncoderLayer,
                                          TransformerDecoderLayer
                                      })
        model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy)
        ctx = self._get_state_dict_mgr(model, "state_dict", True)
        with ctx:
            state_dict = deepcopy(_get_state_dict(model))

        # All ranks initialize non-FSDP model
        grp = dist.distributed_c10d._get_default_group()
        model_new = TransformerWithSharedParams(group=grp)
        for p in model_new.parameters():
            with torch.no_grad():
                p.zero_()
        # Only rank 0 loads the checkpoint
        if self.rank == 0:
            model_new.load_state_dict(state_dict)

        # TransformerWithSharedParams has a buffer of zeros, so we cannot use
        # self.assertNotEqual directly since the buffers would compare equal.
        # Instead, just check that there is some difference in the model across
        # ranks before the state_dict is broadcast.
        with self.assertRaisesRegex(AssertionError,
                                    "Tensor-likes are not close"):
            _validate(model_new, process_group=grp, assert_fn=self.assertEqual)
        # FSDP with sync_module_states=True broadcasts the checkpointed states.
        model_new = FSDP(model_new,
                         device_id=torch.cuda.current_device(),
                         auto_wrap_policy=my_auto_wrap_policy,
                         sync_module_states=True)
        # After wrapping with FSDP, models are equal across ranks and have
        # loaded the checkpoint.
        with FSDP.summon_full_params(model_new):
            _validate(model_new, process_group=grp, assert_fn=self.assertEqual)

        with FullyShardedDataParallel.summon_full_params(model):
            with FullyShardedDataParallel.summon_full_params(model_new):
                params = list(model.parameters())
                params_new = list(model_new.parameters())
                self.assertEqual(params, params_new)
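
`_validate` above is another test helper whose definition is not included here. A rough sketch of what a cross-rank consistency check like this might look like, assuming it gathers each rank's parameters and buffers with `torch.distributed.all_gather_object` and applies `assert_fn` between the local copy and every gathered copy (the helper name and exact semantics are assumptions):

import torch.distributed as dist

def validate_across_ranks(module, process_group, assert_fn):
    # Move states to CPU so the comparison is independent of device placement.
    local_states = [
        t.detach().cpu()
        for t in list(module.parameters()) + list(module.buffers())
    ]
    world_size = dist.get_world_size(process_group)
    gathered = [None] * world_size
    dist.all_gather_object(gathered, local_states, group=process_group)
    # assert_fn (e.g. assertEqual) raises AssertionError if states differ.
    for other_states in gathered:
        assert_fn(local_states, other_states)
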
Example #3
 def test_fsdp_state_dict_with_activation_checkpoint(self, checkpoint_wrap):
     for model_call in [
             partial(self._get_simple_model),
             partial(self._get_simple_nested_model)
     ]:
         model = model_call(
             checkpoint_wrap=(checkpoint_wrap in ["first", "both"]))
         state_dict = _get_state_dict(model, False, False)
         # Possibly wrap new model in activation checkpoint wrapper to test save/
         # load with this wrapper
         model_new = model_call(
             checkpoint_wrap=(checkpoint_wrap in ["second", "both"]))
         _zero_model(model_new)
         self._compare_models(model, model_new, self.assertNotEqual)
         # Would fail if checkpoint_wrapper did not correctly implement state_dict pre/post hooks
         model_new.load_state_dict(state_dict)
         self._compare_models(model, model_new, self.assertEqual)
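
`_get_simple_model(checkpoint_wrap=...)` is a test fixture not shown here. A plausible sketch of what it builds, assuming the `checkpoint_wrapper` utility from `torch.distributed.algorithms._checkpoint.checkpoint_wrapper` is applied before FSDP wrapping (the module layout is an assumption for illustration):

import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def get_simple_model(checkpoint_wrap: bool = False):
    module = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))
    if checkpoint_wrap:
        # checkpoint_wrapper registers state_dict pre/post hooks so that the
        # wrapped module's keys match the unwrapped module's keys, which is
        # what the load_state_dict call in the test above relies on.
        module = checkpoint_wrapper(module)
    return FSDP(module)
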
Example #4
    def test_basic_save_and_load_state_dict(self, cpu_offload, fp16):
        """
        Tests that we can save a state_dict and load it into a blank model
        with various configs, such as fp16 and CPU offload, and that the
        parameters match as expected.
        """
        for model_call in [
                partial(self._get_simple_nested_model,
                        cpu_offload=cpu_offload),
                partial(self._get_simple_model, cpu_offload=cpu_offload),
        ]:
            model = model_call()
            fsdp_state_dict = _get_state_dict(model,
                                              cpu_offload.offload_params, fp16)
            if fp16:
                # Verify that every state dict tensor is fp16
                for tensor in fsdp_state_dict.values():
                    self.assertEqual(tensor.dtype, torch.float16)

            model_new = model_call()
            if not cpu_offload.offload_params:
                model_new = model_new.cuda()
            if fp16:
                model_new.half()

            # zero the model to ensure parameters are different.
            _zero_model(model_new)

            with FullyShardedDataParallel.summon_full_params(model):
                with FullyShardedDataParallel.summon_full_params(model_new):
                    params = list(model.parameters())
                    params_new = list(model_new.parameters())
                    self.assertNotEqual(params, params_new)

            # Verify parameters are the same in the new model.
            model_new.load_state_dict(fsdp_state_dict)
            with FullyShardedDataParallel.summon_full_params(model_new):
                with FullyShardedDataParallel.summon_full_params(model):
                    params = list(model.parameters())
                    params_new = list(model_new.parameters())
                    self.assertEqual(params, params_new)
                    if fp16:
                        for tensor in model_new.parameters():
                            self.assertEqual(tensor.dtype, torch.float16)
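
For context on the `cpu_offload` parameter threaded through these tests: it is presumably an FSDP `CPUOffload` config passed to the constructor by the model fixtures. A minimal sketch under that assumption (the wrapped `nn.Linear` is a placeholder):

import torch.nn as nn
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

cpu_offload = CPUOffload(offload_params=True)
model = FSDP(nn.Linear(10, 10).cuda(), cpu_offload=cpu_offload)
# With offload_params=True, FSDP keeps the sharded parameters on CPU between
# forward/backward passes, which is why the test skips `model_new.cuda()` in
# that configuration.
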
Example #5
 def test_fsdp_state_dict_with_activation_checkpoint(self, checkpoint_wrap):
     """Tests saving the state dict, zeroing a target model's parameters, and
     loading the state dict, where the source and target models may have a
     checkpoint wrapper."""
     for model_call in [
             partial(self._get_simple_model),
             partial(self._get_simple_nested_model)
     ]:
         model = model_call(
             checkpoint_wrap=(checkpoint_wrap in ["first", "both"]))
         state_dict = _get_state_dict(model, False, False)
         # Possibly wrap new model in activation checkpoint wrapper to test save/
         # load with this wrapper
         model_new = model_call(
             checkpoint_wrap=(checkpoint_wrap in ["second", "both"]))
         _zero_model(model_new)
         self._compare_models(model, model_new, self.assertNotEqual)
         # Would fail if checkpoint_wrapper did not correctly implement state_dict pre/post hooks
         model_new.load_state_dict(state_dict, strict=True)
         self._compare_models(model, model_new, self.assertEqual)
Example #6
    def test_basic_save_and_load_state_dict(self, state_dict_type, cpu_offload,
                                            fp16,
                                            state_dict_rank0_and_offload):
        """
        Tests that we can save a state_dict and load it into a blank model
        with various configs, such as fp16 and CPU offload, and that the
        parameters match as expected.
        """
        if state_dict_rank0_and_offload and state_dict_type != "state_dict":
            return
        for model_call in [
                partial(self._get_non_fsdp_root_module,
                        cpu_offload=cpu_offload),
                partial(self._get_simple_nested_model,
                        cpu_offload=cpu_offload),
                partial(self._get_simple_model, cpu_offload=cpu_offload),
        ]:
            model = model_call()

            ctx = self._get_state_dict_mgr(model, state_dict_type,
                                           state_dict_rank0_and_offload)
            with ctx:
                fsdp_state_dict = _get_state_dict(model,
                                                  cpu_offload.offload_params,
                                                  fp16)

            ignore_keys = [
                k for k in fsdp_state_dict.keys() if NON_ROOT_FSDP_PREFIX in k
            ]

            self._validate_state_dict_contents(
                model,
                fsdp_state_dict,
                state_dict_rank0_and_offload,
                ignore_keys=ignore_keys,
            )
            if fp16:
                # Verify that every state dict tensor is fp16
                for tensor in fsdp_state_dict.values():
                    self.assertEqual(tensor.dtype, torch.float16)

            model_new = model_call()
            if not cpu_offload.offload_params:
                model_new = model_new.cuda()
            if fp16:
                model_new.half()

            # zero the model to ensure parameters are different.
            _zero_model(model_new)
            self._compare_models(model, model_new, self.assertNotEqual)

            # Verify parameters are the same in the new model.
            if state_dict_rank0_and_offload:
                # Broadcast the state dict and move it back to GPU in
                # preparation for loading.
                if not isinstance(model, FSDP):
                    # Move everything to CPU to avoid running into
                    # https://github.com/pytorch/pytorch/issues/77113, since
                    # some params will still be on GPU for non-FSDP root modules.
                    for k in fsdp_state_dict.keys():
                        fsdp_state_dict[k] = fsdp_state_dict[k].cpu()
                fsdp_state_dict = self._broadcast_state_dict(fsdp_state_dict)
                for key in fsdp_state_dict.keys():
                    fsdp_state_dict[key] = fsdp_state_dict[key].cuda()
            with FSDP.state_dict_type(model_new,
                                      STATE_DICT_MAPPING[state_dict_type]):
                model_new.load_state_dict(fsdp_state_dict, strict=True)

            self._compare_models(model,
                                 model_new,
                                 self.assertEqual,
                                 check_fp16=fp16)
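
`STATE_DICT_MAPPING` is a module-level constant in the test file that is not reproduced in these snippets; it most likely maps the parametrized string names onto FSDP's `StateDictType` enum, roughly as sketched below (the exact keys are assumptions):

from torch.distributed.fsdp import StateDictType

STATE_DICT_MAPPING = {
    "state_dict": StateDictType.FULL_STATE_DICT,
    "local_state_dict": StateDictType.LOCAL_STATE_DICT,
    "sharded_state_dict": StateDictType.SHARDED_STATE_DICT,
}
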
Example #7
 def test_state_dict_rank0_offload_save_load_flow(self):
     """Tests saving a model checkpoint only on rank 0 and loading it only
     on rank 0 with ``sync_module_states=True`` to emulate the workflow to
     avoid redundant CPU memory usage."""
     auto_wrap_policy = partial(
         transformer_auto_wrap_policy,
         transformer_layer_cls={
             TransformerEncoderLayer, TransformerDecoderLayer
         },
     )
     fsdp_kwargs = {"auto_wrap_policy": auto_wrap_policy}
     fsdp_model = TransformerWithSharedParams.init(
         self.process_group,
         FSDPInitMode.RECURSIVE,
         CUDAInitMode.CUDA_BEFORE,
         fsdp_kwargs,
     )
     # Force model parameters and buffers to be nonzero
     with FSDP.summon_full_params(fsdp_model):
         for tensor in itertools.chain(fsdp_model.parameters(),
                                       fsdp_model.buffers()):
             if torch.count_nonzero(tensor) == 0:
                 with torch.no_grad():
                     tensor.add_(
                         torch.tensor(1,
                                      dtype=tensor.dtype,
                                      device=tensor.device))
     with self._get_state_dict_mgr(fsdp_model, "state_dict", True):
         state_dict = deepcopy(_get_state_dict(fsdp_model))
     # Initialize a non-wrapped model on all ranks
     new_model = TransformerWithSharedParams.init(
         self.process_group,
         FSDPInitMode.NO_FSDP,
         CUDAInitMode.CUDA_BEFORE,
     )
     _zero_model(new_model, zero_buffers=True)
     # Only load the checkpoint on rank 0
     if self.rank == 0:
         new_model.load_state_dict(state_dict, strict=True)
     _assert_module_states(
         new_model,
         process_group=self.process_group,
         assert_fn=self.assertNotEqual,
     )
     # Broadcast the module states from rank 0 with `sync_module_states=True`
     new_fsdp_model = FSDP(
         new_model,
         device_id=torch.cuda.current_device(),
         auto_wrap_policy=auto_wrap_policy,
         sync_module_states=True,
     )
     # Check FSDP models are equal across ranks
     with FSDP.summon_full_params(new_fsdp_model):
         _assert_module_states(
             new_fsdp_model,
             process_group=self.process_group,
             assert_fn=self.assertEqual,
         )
     # Check FSDP models correctly loaded the checkpoint
     with FullyShardedDataParallel.summon_full_params(fsdp_model):
         with FullyShardedDataParallel.summon_full_params(new_fsdp_model):
             params = list(fsdp_model.parameters())
             params_new = list(new_fsdp_model.parameters())
             self.assertEqual(params, params_new)
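
`_zero_model` is a shared test helper whose definition is not shown. A plausible sketch of its behavior, assuming it zeroes parameters (and optionally buffers) in place, summoning full parameters first when the model is FSDP-wrapped (the function name and the `zero_buffers` handling are assumptions based on the call sites above):

import contextlib

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def zero_model(model, zero_buffers: bool = False):
    # For an FSDP-wrapped model, gather the full parameters before writing.
    ctx = (
        FSDP.summon_full_params(model)
        if isinstance(model, FSDP)
        else contextlib.nullcontext()
    )
    with ctx, torch.no_grad():
        for param in model.parameters():
            param.zero_()
        if zero_buffers:
            for buffer in model.buffers():
                buffer.zero_()
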