Example #1
    def test_basic_save_and_load_state_dict(self, cpu_offload, fp16,
                                            state_dict_rank0_and_offload):
        """
        Tests that we can save a state_dict and load it into a blank model
        with various configs, such as fp16 and CPU offload, and that the
        parameters match as expected.
        """
        for model_call in [
                partial(self._get_simple_nested_model,
                        cpu_offload=cpu_offload),
                partial(self._get_simple_model, cpu_offload=cpu_offload),
        ]:
            model = model_call()
            full_state_dict_mgr = self._get_full_state_dict_mgr(
                model, state_dict_rank0_and_offload)
            with full_state_dict_mgr:
                fsdp_state_dict = _get_state_dict(model,
                                                  cpu_offload.offload_params,
                                                  fp16)

            self._validate_state_dict_contents(fsdp_state_dict,
                                               state_dict_rank0_and_offload)
            if fp16:
                # Verify that the saved state_dict tensors are fp16
                for tensor in fsdp_state_dict.values():
                    self.assertEqual(tensor.dtype, torch.float16)

            model_new = model_call()
            if not cpu_offload.offload_params:
                model_new = model_new.cuda()
            if fp16:
                model_new.half()

            # Zero the model to ensure its parameters differ from the saved ones.
            _zero_model(model_new)

            with FullyShardedDataParallel.summon_full_params(model):
                with FullyShardedDataParallel.summon_full_params(model_new):
                    params = list(model.parameters())
                    params_new = list(model_new.parameters())
                    self.assertNotEqual(params, params_new)

            # Verify parameters are the same in the new model.
            if state_dict_rank0_and_offload:
                # Broadcast the state dict and move it back to GPU in
                # preparation for loading.
                fsdp_state_dict = self._broadcast_state_dict(fsdp_state_dict)
                for key in fsdp_state_dict.keys():
                    fsdp_state_dict[key] = fsdp_state_dict[key].cuda()

            model_new.load_state_dict(fsdp_state_dict)
            with FullyShardedDataParallel.summon_full_params(model_new):
                with FullyShardedDataParallel.summon_full_params(model):
                    params = list(model.parameters())
                    params_new = list(model_new.parameters())
                    self.assertEqual(params, params_new)
                    if fp16:
                        for tensor in model_new.parameters():
                            self.assertEqual(tensor.dtype, torch.float16)
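
The test above relies on a self._broadcast_state_dict helper that is not shown in this excerpt. As a rough, purely illustrative sketch (the helper name and behavior are assumptions based on how it is used here), a rank-0 broadcast of a full state_dict can be done with dist.broadcast_object_list; moving the tensors back to the GPU is left to the caller, as in the test above.

import torch.distributed as dist

def _broadcast_state_dict_sketch(state_dict):
    # Illustrative sketch only: rank 0 holds the full state_dict (possibly
    # offloaded to CPU); every other rank receives a copy of it.
    object_list = [state_dict if dist.get_rank() == 0 else None]
    dist.broadcast_object_list(object_list, src=0)
    return object_list[0]
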
Example #2
 def _compare_models(self, model, model_new, assert_fn, check_fp16=False):
     with FullyShardedDataParallel.summon_full_params(model):
         with FullyShardedDataParallel.summon_full_params(model_new):
             params = list(model.parameters())
             params_new = list(model_new.parameters())
             assert_fn(params, params_new)
             if check_fp16:
                 for tensor in model_new.parameters():
                     self.assertEqual(tensor.dtype, torch.float16)
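
Given the helper above, the two pairs of nested summon_full_params blocks in Example #1 collapse into two calls. A usage sketch, meant to live inside a test method like the one in Example #1 (model, model_new, and fp16 come from that test):

# Before loading the state_dict: parameters should differ between the
# saved model and the zeroed model.
self._compare_models(model, model_new, self.assertNotEqual)
# After load_state_dict: parameters should match, and stay fp16 if requested.
self._compare_models(model, model_new, self.assertEqual, check_fp16=fp16)
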
Example #3
    def test_auto_wrap_smoke_test(self, cuda_init_mode, cpu_offload,
                                  use_device_id):
        # CPU offload and CUDA_AFTER initialization don't work together as expected.
        if (cpu_offload.offload_params
                and cuda_init_mode == CUDAInitMode.CUDA_AFTER):
            return

        device = torch.device("cuda")
        torch.cuda.set_device(0)
        device_id = (torch.device("cuda", torch.cuda.current_device())
                     if use_device_id else None)

        # Use a random port in case tests run in quick succession; reusing the same port would cause a conflict.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(find_free_port())

        file_name = tempfile.NamedTemporaryFile(delete=False).name
        torch.distributed.init_process_group(
            backend="nccl",
            init_method=f"{FILE_SCHEMA}_{file_name}",
            rank=0,
            world_size=1,
        )

        # NOTE: We move the model to CUDA after wrapping with FSDP to simulate real
        # use cases where the full model cannot fit on the GPU, but its shards can.
        cuda_after_init = cuda_init_mode == CUDAInitMode.CUDA_AFTER
        try:
            sequential = TestFSDPWrap.NestedSequentialModel.get_model(
                cuda=(not cuda_after_init))
            my_auto_wrap_policy = functools.partial(
                size_based_auto_wrap_policy, min_num_params=40)
            model = FSDP(sequential,
                         cpu_offload=cpu_offload,
                         auto_wrap_policy=my_auto_wrap_policy,
                         device_id=device_id)
            TestFSDPWrap.NestedSequentialModel.verify_model(self, model)
            if cuda_after_init:
                model = model.cuda()
            input = torch.rand((1, 5), dtype=torch.float).to(device)
            output = model(input)
            loss = F.mse_loss(input, output)
            loss.backward()
        finally:
            torch.distributed.destroy_process_group()

        try:
            os.remove(file_name)
        except FileNotFoundError:
            pass
Example #4
    def test_state_dict_load_into_local_module(self):
        """
        Tests that FSDP's state_dict can be loaded into a local model.
        """
        model = self._initialize_model(wrap_fsdp=True)
        optim = SGD(model.parameters(), lr=0.1)
        in_data = torch.rand(64,
                             4,
                             requires_grad=True,
                             device=torch.device("cuda"))
        for _ in range(3):
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        with FullyShardedDataParallel.summon_full_params(model):
            fsdp_params = deepcopy(list(model.parameters()))

        # get FSDP state_dict. Note that by default we return the full state_dict.
        fsdp_state_dict = model.state_dict()
        # Create zeroed local model
        blank_local_model = self._initialize_model(wrap_fsdp=False,
                                                   wrap_ddp=False)
        for param in blank_local_model.parameters():
            with torch.no_grad():
                param.zero_()

        # Load FSDP's full state_dict into the local model and verify the
        # params are as expected.
        blank_local_model.load_state_dict(fsdp_state_dict)
        local_params = list(blank_local_model.parameters())
        for fsdp_param, local_param in zip(fsdp_params, local_params):
            self.assertEqual(fsdp_param, local_param)
Example #5
    def test_state_dict_rank0_offload_save_load_flow(self):
        # Test taking the checkpoint on rank 0 only, and reloading it
        # without redundant CPU memory usage.
        model = TransformerWithSharedParams(
            group=dist.distributed_c10d._get_default_group())
        my_auto_wrap_policy = partial(transformer_auto_wrap_policy,
                                      transformer_layer_cls={
                                          TransformerEncoderLayer,
                                          TransformerDecoderLayer
                                      })
        model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy)
        ctx = self._get_state_dict_mgr(model, "state_dict", True)
        with ctx:
            state_dict = deepcopy(_get_state_dict(model))

        # All ranks initialize non-FSDP model
        grp = dist.distributed_c10d._get_default_group()
        model_new = TransformerWithSharedParams(group=grp)
        for p in model_new.parameters():
            with torch.no_grad():
                p.zero_()
        # Only rank 0 loads the checkpoint
        if self.rank == 0:
            model_new.load_state_dict(state_dict)

        # TransformerWithSharedParams has a buffer of zeros, so we can't use
        # self.assertNotEqual directly since the buffers would compare equal.
        # Instead, just check that there is some difference in the model across
        # ranks before the state_dict is broadcast.
        with self.assertRaisesRegex(AssertionError,
                                    "Tensor-likes are not close"):
            _validate(model_new, process_group=grp, assert_fn=self.assertEqual)
        # FSDP with sync_module_states=True broadcasts the checkpointed states.
        model_new = FSDP(model_new,
                         device_id=torch.cuda.current_device(),
                         auto_wrap_policy=my_auto_wrap_policy,
                         sync_module_states=True)
        # After wrapping with FSDP, the models are equal across ranks and have loaded the checkpoint.
        with FSDP.summon_full_params(model_new):
            _validate(model_new, process_group=grp, assert_fn=self.assertEqual)

        with FullyShardedDataParallel.summon_full_params(model):
            with FullyShardedDataParallel.summon_full_params(model_new):
                params = list(model.parameters())
                params_new = list(model_new.parameters())
                self.assertEqual(params, params_new)
Example #6
    def test_basic_save_and_load_state_dict(self, cpu_offload, fp16):
        """
        Tests that we can save a state_dict and load it into a blank model
        with various configs, such as fp16 and CPU offload, and that the
        parameters match as expected.
        """
        for model_call in [
                partial(self._get_simple_nested_model,
                        cpu_offload=cpu_offload),
                partial(self._get_simple_model, cpu_offload=cpu_offload),
        ]:
            model = model_call()
            fsdp_state_dict = _get_state_dict(model,
                                              cpu_offload.offload_params, fp16)
            if fp16:
                # Verify that the saved state_dict tensors are fp16
                for tensor in fsdp_state_dict.values():
                    self.assertEqual(tensor.dtype, torch.float16)

            model_new = model_call()
            if not cpu_offload.offload_params:
                model_new = model_new.cuda()
            if fp16:
                model_new.half()

            # Zero the model to ensure its parameters differ from the saved ones.
            _zero_model(model_new)

            with FullyShardedDataParallel.summon_full_params(model):
                with FullyShardedDataParallel.summon_full_params(model_new):
                    params = list(model.parameters())
                    params_new = list(model_new.parameters())
                    self.assertNotEqual(params, params_new)

            # Verify parameters are the same in the new model.
            model_new.load_state_dict(fsdp_state_dict)
            with FullyShardedDataParallel.summon_full_params(model_new):
                with FullyShardedDataParallel.summon_full_params(model):
                    params = list(model.parameters())
                    params_new = list(model_new.parameters())
                    self.assertEqual(params, params_new)
                    if fp16:
                        for tensor in model_new.parameters():
                            self.assertEqual(tensor.dtype, torch.float16)
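
The _zero_model helper called above is not defined in this excerpt. A minimal sketch of what such a helper might do (the zero_buffers flag is an assumption based on how the helper is called elsewhere in these tests); for an FSDP-wrapped model the zeroing would additionally be done under FSDP.summon_full_params so the full parameters are touched:

import torch

def _zero_model_sketch(model, zero_buffers=False):
    # Illustrative sketch: zero all parameters (and optionally buffers) in
    # place so that a later load_state_dict visibly changes the model.
    with torch.no_grad():
        for param in model.parameters():
            param.zero_()
        if zero_buffers:
            for buf in model.buffers():
                buf.zero_()
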
Example #7
    def test_state_dict_load_into_local_module(
        self, state_dict_type, state_dict_rank0_and_offload
    ):
        """
        Tests that FSDP's state_dict can be loaded into a local model.
        """
        if state_dict_rank0_and_offload and state_dict_type != "state_dict":
            return
        model = self._initialize_model(wrap_fsdp=True, register_buffers=True)
        optim = SGD(model.parameters(), lr=0.1)
        in_data = torch.rand(64, 4, requires_grad=True, device=torch.device("cuda"))
        for _ in range(3):
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        with FullyShardedDataParallel.summon_full_params(model):
            fsdp_params = deepcopy(list(model.parameters()))

        # get FSDP state_dict. Note that by default we return full_state_dict.
        sd_mgr = self._get_state_dict_mgr(
            model, state_dict_type, state_dict_rank0_and_offload
        )
        with sd_mgr:
            fsdp_state_dict = model.state_dict()

        self._validate_state_dict_contents(
            fsdp_state_dict, state_dict_rank0_and_offload
        )
        # Create zeroed local model
        blank_local_model = self._initialize_model(
            wrap_fsdp=False, wrap_ddp=False, register_buffers=True,
        )
        for param in blank_local_model.parameters():
            with torch.no_grad():
                param.zero_()

        fsdp_state_dict = _gather_state_dict(fsdp_state_dict)

        # Load FSDP's full state_dict into the local model and verify the
        # params are as expected.
        if state_dict_rank0_and_offload:
            # Broadcast + CUDA state_dict
            fsdp_state_dict = self._broadcast_state_dict(fsdp_state_dict)
            for key in fsdp_state_dict.keys():
                fsdp_state_dict[key] = fsdp_state_dict[key].cuda()

        if self.rank == 0:
            blank_local_model.load_state_dict(fsdp_state_dict)
            local_params = list(blank_local_model.parameters())
            for fsdp_param, local_param in zip(fsdp_params, local_params):
                self.assertEqual(fsdp_param, local_param)
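
The self._get_state_dict_mgr helper used above is also not shown. A plausible, purely illustrative sketch is a thin wrapper that maps the parametrized arguments onto FSDP.state_dict_type; only the full "state_dict" flavor is handled here, and the helper name is assumed:

import contextlib
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig, StateDictType

def _get_state_dict_mgr_sketch(model, state_dict_type, rank0_and_offload):
    # Illustrative sketch: translate the rank0_and_offload flag into a
    # FullStateDictConfig for the full "state_dict" flavor; other flavors
    # would map onto their own StateDictType members in the same way.
    if state_dict_type == "state_dict":
        cfg = FullStateDictConfig(
            rank0_only=rank0_and_offload,
            offload_to_cpu=rank0_and_offload,
        )
        return FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg)
    return contextlib.nullcontext()
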
Example #8
    def test_distributed_checkpoint(self, state_dict_type) -> None:
        with enable_wrap(wrapper_cls=FSDP):
            torch.manual_seed(100)
            model = wrap(SkipModel(double_nest=True))
            torch.manual_seed(200)
            new_model = wrap(SkipModel(double_nest=True))

        with FullyShardedDataParallel.summon_full_params(
                model), FullyShardedDataParallel.summon_full_params(new_model):
            params = list(model.parameters())
            new_params = list(new_model.parameters())
            self.assertNotEqual(params, new_params)

        with tempfile.TemporaryDirectory() as path:
            paths = [path]
            dist.broadcast_object_list(paths)
            path = paths[0]
            writer = FileSystemWriter(path)
            reader = FileSystemReader(path)
            with FSDP.state_dict_type(model,
                                      state_dict_type), FSDP.state_dict_type(
                                          new_model, state_dict_type):
                state_dict = model.state_dict()

            save_state_dict(state_dict, writer)

            with FSDP.state_dict_type(model,
                                      state_dict_type), FSDP.state_dict_type(
                                          new_model, state_dict_type):
                state_dict = new_model.state_dict()
                load_state_dict(state_dict, reader)
                new_model.load_state_dict(state_dict)

        with FullyShardedDataParallel.summon_full_params(
                model), FullyShardedDataParallel.summon_full_params(new_model):
            params = list(model.parameters())
            new_params = list(new_model.parameters())
            self.assertEqual(params, new_params)
Example #9
    def test_summon_from_non_fsdp(self):
        class FSDPContainer(nn.Module):
            def __init__(self, fsdp_1, fsdp_2, fsdp_3):
                super().__init__()
                self.fsdp_1 = fsdp_1
                self.fsdp_2 = fsdp_2
                self.fsdp_3 = fsdp_3

        model_fsdp = FSDPContainer(
            FSDP(DeterministicModel(wrap_fsdp=True)),
            FSDP(DeterministicModel(wrap_fsdp=True)),
            DeterministicModel(wrap_fsdp=False),
        )
        model_no_fsdp = FSDPContainer(
            DeterministicModel(wrap_fsdp=False),
            DeterministicModel(wrap_fsdp=False),
            DeterministicModel(wrap_fsdp=False),
        )

        params_to_compare = list(model_no_fsdp.parameters())
        with FullyShardedDataParallel.summon_full_params(model_fsdp):
            fsdp_params = [p.clone() for p in model_fsdp.parameters()]

        self.assertEqual(params_to_compare, fsdp_params)
Example #10
    def test_state_dict_load_into_local_module(
        self,
        state_dict_type,
        state_dict_rank0_and_offload,
        fsdp_root,
    ):
        """
        Tests that FSDP's state_dict can be loaded into a local model.
        """
        if state_dict_rank0_and_offload and state_dict_type != "state_dict":
            return
        if not fsdp_root:
            model = self._get_non_fsdp_root_module()
        else:
            model = self._initialize_model(wrap_fsdp=True,
                                           register_buffers=True)
        optim = SGD(model.parameters(), lr=0.1)
        if not fsdp_root:
            in_data = torch.randn(1,
                                  10,
                                  requires_grad=True,
                                  device=torch.device("cuda"))
        else:
            in_data = torch.rand(64,
                                 4,
                                 requires_grad=True,
                                 device=torch.device("cuda"))
        for _ in range(3):
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        with FullyShardedDataParallel.summon_full_params(model):
            fsdp_params = deepcopy(list(model.parameters()))

        # get FSDP state_dict. Note that by default we return full_state_dict.
        sd_mgr = self._get_state_dict_mgr(model, state_dict_type,
                                          state_dict_rank0_and_offload)
        with sd_mgr:
            fsdp_state_dict = model.state_dict()

        ignore_keys = [
            k for k in fsdp_state_dict.keys() if NON_ROOT_FSDP_PREFIX in k
        ]
        self._validate_state_dict_contents(
            model,
            fsdp_state_dict,
            state_dict_rank0_and_offload,
            ignore_keys=ignore_keys,
        )
        # Create zeroed local model
        if not fsdp_root:
            blank_local_model = self._get_non_fsdp_root_module(wrap=False)
        else:
            blank_local_model = self._initialize_model(wrap_fsdp=False,
                                                       wrap_ddp=False,
                                                       register_buffers=True)

        # Nothing should be FSDP
        for mod in blank_local_model.modules():
            self.assertFalse(isinstance(mod, FSDP))

        for param in blank_local_model.parameters():
            with torch.no_grad():
                param.zero_()

        fsdp_state_dict = _gather_state_dict(fsdp_state_dict)

        # Load FSDP's full state_dict into the local model and verify the
        # params are as expected.
        if state_dict_rank0_and_offload:
            # Broadcast + CUDA state_dict
            if not isinstance(model, FSDP):
                # Some portions of the model on rank 0 might not be on CPU,
                # move everything to CPU to avoid running into
                # https://github.com/pytorch/pytorch/issues/77113.
                for k, t in fsdp_state_dict.items():
                    if t.device != torch.device("cpu"):
                        fsdp_state_dict[k] = t.cpu()
            fsdp_state_dict = self._broadcast_state_dict(fsdp_state_dict)
            for key in fsdp_state_dict.keys():
                fsdp_state_dict[key] = fsdp_state_dict[key].cuda()

        # if self.rank == 0:
        blank_local_model.load_state_dict(fsdp_state_dict, strict=True)
        local_params = list(blank_local_model.parameters())
        for fsdp_param, local_param in zip(fsdp_params, local_params):
            self.assertEqual(fsdp_param, local_param)
Example #11
 def test_state_dict_rank0_offload_save_load_flow(self):
     """Tests saving a model checkpoint only on rank 0 and loading it only
     on rank 0 with ``sync_module_states=True`` to emulate the workflow to
     avoid redundant CPU memory usage."""
     auto_wrap_policy = partial(
         transformer_auto_wrap_policy,
         transformer_layer_cls={
             TransformerEncoderLayer, TransformerDecoderLayer
         },
     )
     fsdp_kwargs = {"auto_wrap_policy": auto_wrap_policy}
     fsdp_model = TransformerWithSharedParams.init(
         self.process_group,
         FSDPInitMode.RECURSIVE,
         CUDAInitMode.CUDA_BEFORE,
         fsdp_kwargs,
     )
     # Force model parameters and buffers to be nonzero
     with FSDP.summon_full_params(fsdp_model):
         for tensor in itertools.chain(fsdp_model.parameters(),
                                       fsdp_model.buffers()):
             if torch.count_nonzero(tensor) == 0:
                 with torch.no_grad():
                     tensor.add_(
                         torch.tensor(1,
                                      dtype=tensor.dtype,
                                      device=tensor.device))
     with self._get_state_dict_mgr(fsdp_model, "state_dict", True):
         state_dict = deepcopy(_get_state_dict(fsdp_model))
     # Initialize a non-wrapped model on all ranks
     new_model = TransformerWithSharedParams.init(
         self.process_group,
         FSDPInitMode.NO_FSDP,
         CUDAInitMode.CUDA_BEFORE,
     )
     _zero_model(new_model, zero_buffers=True)
     # Only load the checkpoint on rank 0
     if self.rank == 0:
         new_model.load_state_dict(state_dict, strict=True)
     _assert_module_states(
         new_model,
         process_group=self.process_group,
         assert_fn=self.assertNotEqual,
     )
     # Broadcast the module states from rank 0 with `sync_module_states=True`
     new_fsdp_model = FSDP(
         new_model,
         device_id=torch.cuda.current_device(),
         auto_wrap_policy=auto_wrap_policy,
         sync_module_states=True,
     )
     # Check FSDP models are equal across ranks
     with FSDP.summon_full_params(new_fsdp_model):
         _assert_module_states(
             new_fsdp_model,
             process_group=self.process_group,
             assert_fn=self.assertEqual,
         )
     # Check FSDP models correctly loaded the checkpoint
     with FullyShardedDataParallel.summon_full_params(fsdp_model):
         with FullyShardedDataParallel.summon_full_params(new_fsdp_model):
             params = list(fsdp_model.parameters())
             params_new = list(new_fsdp_model.parameters())
             self.assertEqual(params, params_new)
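
_assert_module_states is referenced above but not shown. A rough sketch of a cross-rank comparison along those lines could use dist.all_gather_object; the helper name and exact semantics here are assumptions, not the library's implementation:

import itertools
import torch.distributed as dist

def _assert_module_states_sketch(module, process_group, assert_fn):
    # Illustrative sketch: gather every rank's parameters and buffers as CPU
    # tensors and compare each rank's copy against rank 0's with assert_fn
    # (e.g. self.assertEqual or self.assertNotEqual).
    local_states = [
        t.detach().cpu()
        for t in itertools.chain(module.parameters(), module.buffers())
    ]
    world_size = dist.get_world_size(process_group)
    gathered = [None for _ in range(world_size)]
    dist.all_gather_object(gathered, local_states, group=process_group)
    for other_rank_states in gathered[1:]:
        assert_fn(gathered[0], other_rank_states)
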
Example #12
    def test_main_wrap_api(self, cpu_offload, backward_prefetch,
                           forward_prefetch, cuda_init_mode):

        if cuda_init_mode == CUDAInitMode.CUDA_AFTER and cpu_offload.offload_params:
            # These are not expected to work together.
            return

        move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE

        class Nested(nn.Module):
            def __init__(self):
                super().__init__()
                self.nested_lin = _maybe_cuda(nn.Linear(1, 1, bias=False),
                                              move_to_cuda)

            def forward(self, input):
                return self.nested_lin(input)

        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin1 = _maybe_cuda(nn.Linear(1, 1, bias=False),
                                        move_to_cuda)
                self.lin2 = _maybe_cuda(nn.Linear(1, 1, bias=False),
                                        move_to_cuda)
                self.lin3 = _maybe_cuda(nn.Linear(1, 1, bias=False),
                                        move_to_cuda)
                self.lin4 = Nested()

            def forward(self, input):
                return self.lin4(self.lin3(self.lin2(self.lin1(input))))

        model = MyModel()
        wrapped_model = FSDP(
            model,
            auto_wrap_policy=functools.partial(
                size_based_auto_wrap_policy,
                min_num_params=0,  # wrap all modules
            ),
            cpu_offload=cpu_offload,
            backward_prefetch=backward_prefetch,
            forward_prefetch=forward_prefetch,
        )
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
            wrapped_model = wrapped_model.cuda()

        modules_in_fsdp_graph_order = [
            wrapped_model.module.lin1, wrapped_model.module.lin2,
            wrapped_model.module.lin3,
            wrapped_model.module.lin4.module.nested_lin,
            wrapped_model.module.lin4, wrapped_model
        ]

        for module in modules_in_fsdp_graph_order:
            self.assertTrue(isinstance(module, FSDP))
            self._check_cpu_offload(module, cpu_offload)
            self._check_backward_prefetch(module, backward_prefetch)
            self._check_forward_prefetch(module, forward_prefetch)

        # Run the model a few times as a sanity check.
        optim = torch.optim.SGD(wrapped_model.parameters(),
                                lr=1e-2,
                                momentum=0.9)
        inp = torch.ones(1).cuda()
        for _ in range(6):
            optim.zero_grad()
            loss = wrapped_model(inp).sum()
            loss.backward()
            optim.step()

        # Since we ran with backward prefetch, verify backward prefetch related
        # data.
        for i, module in enumerate(modules_in_fsdp_graph_order):
            self.assertEqual(i, module._my_fsdp_idx_in_graph)
            self.assertTrue(
                module._fsdp_graph_order == modules_in_fsdp_graph_order)
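
_maybe_cuda, used when building the nested model above, is a small helper that is not defined in this excerpt. A minimal sketch of the behavior the test relies on (move the module to the GPU only when the init mode is CUDA_BEFORE):

import torch.nn as nn

def _maybe_cuda_sketch(module: nn.Module, move_to_cuda: bool) -> nn.Module:
    # Illustrative sketch: move the module to the current CUDA device before
    # FSDP wrapping only when requested; otherwise leave it on CPU so that
    # FSDP(device_id=...) or a later .cuda() call can move it.
    return module.cuda() if move_to_cuda else module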