Example #1
 @contextmanager
 def model_sharded_context(self) -> Generator:
     log.detail(f"{self.__class__.__name__}: entered model_sharded_context.")
     with enable_wrap(
         wrapper_cls=FullyShardedDataParallel,
         process_group=self.process_group,
         cpu_offload=self.cpu_offload,
         backward_prefetch=self.backward_prefetch,
         mixed_precision=self.mixed_precision_config,
         device_id=self.root_device.index,
         **self.kwargs,
     ):
         yield
Example #2
 def test_wrap(self, wrap_method):
     if wrap_method == WrapMethod.WRAP_API:
         with enable_wrap(wrapper_cls=FSDP,
                          process_group=self.process_group):
             layer = wrap(nn.Linear(5, 5))
     else:
         assert wrap_method == WrapMethod.FSDP_CTOR
         layer = FSDP(nn.Linear(5, 5),
                      process_group=self.process_group,
                      auto_wrap_policy=functools.partial(
                          size_based_auto_wrap_policy, min_num_params=1))
     self.assertTrue(isinstance(layer, FSDP))
     self.assertEqual(layer.rank, self.process_group.rank())
     self.assertEqual(layer.world_size, self.process_group.size())
Example #3
    def test_wrap_disabled_outside_context(self):
        pg = self.process_group

        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                # wrap() is a no-op outside enable_wrap(), so self.lin stays a plain nn.Linear
                self.lin = wrap(nn.Linear(5, 5), process_group=pg)

        model = MyModel()
        with enable_wrap(wrapper_cls=FSDP, process_group=pg):
            model = wrap(model)

        self.assertTrue(isinstance(model, FSDP))
        self.assertFalse(isinstance(model.lin, FSDP))
        self.assertTrue(isinstance(model.lin, nn.Linear))
Example #4
 def _create_module(wrap_fsdp=True):
     LINEAR_SKIP = "linear_skip"
     ctx = enable_wrap(wrapper_cls=FSDP) if wrap_fsdp else suppress()
     with ctx:
         # `double_nest` is supplied by the enclosing test that defines this helper (not shown here)
         module = SkipModel(double_nest=double_nest)
         # Full name of linear_skip param tensors in SkipModel, as would be
         # stored in checkpoint.
         linear_skip_tensor_names = [
             k for k in dict(module.named_parameters()).keys()
             if LINEAR_SKIP in k
         ]
         # skip SkipModule
         linear_skip = getattr(module, LINEAR_SKIP)
         delattr(module, LINEAR_SKIP)
         # Wrap FSDP
         fsdp = wrap(module)
         # reattach
         setattr(module, LINEAR_SKIP, linear_skip)
         return fsdp, linear_skip_tensor_names
Example #5
 def test_always_wrap_with_ignored_modules(self, wrap_method: WrapMethod):
     sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
     ignored_modules = [sequential[1], sequential[2][0]]
     fsdp_kwargs = {
         "process_group": self.process_group,
         "auto_wrap_policy": always_wrap_policy,
         "ignored_modules": ignored_modules,
     }
     if wrap_method == WrapMethod.FSDP_CTOR:
         model = FSDP(sequential, **fsdp_kwargs)
     elif wrap_method == WrapMethod.WRAP_API:
         with enable_wrap(wrapper_cls=FSDP, **fsdp_kwargs):
             model = wrap(sequential)
     else:
         assert 0, f"Unsupported wrap method: {wrap_method}"
     # All non-ignored modules should be wrapped with FSDP
     self.assertTrue(isinstance(model, FSDP))
     self.assertTrue(isinstance(model.module[0], FSDP))
     self.assertTrue(isinstance(model.module[1], nn.Linear))
     self.assertTrue(isinstance(model.module[2], FSDP))
     self.assertTrue(isinstance(model.module[2].module[0], nn.Linear))
     self.assertTrue(isinstance(model.module[2].module[1], FSDP))
Example #6
    def test_distributed_checkpoint(self, state_dict_type) -> None:
        with enable_wrap(wrapper_cls=FSDP):
            torch.manual_seed(100)
            model = wrap(SkipModel(double_nest=True))
            torch.manual_seed(200)
            new_model = wrap(SkipModel(double_nest=True))

        with FullyShardedDataParallel.summon_full_params(
                model), FullyShardedDataParallel.summon_full_params(new_model):
            params = list(model.parameters())
            new_params = list(new_model.parameters())
            self.assertNotEqual(params, new_params)

        with tempfile.TemporaryDirectory() as path:
            # Broadcast rank 0's temp-dir path so every rank uses the same checkpoint location.
            paths = [path]
            dist.broadcast_object_list(paths)
            path = paths[0]
            writer = FileSystemWriter(path)
            reader = FileSystemReader(path)
            with FSDP.state_dict_type(model,
                                      state_dict_type), FSDP.state_dict_type(
                                          new_model, state_dict_type):
                state_dict = model.state_dict()

            save_state_dict(state_dict, writer)

            with FSDP.state_dict_type(model,
                                      state_dict_type), FSDP.state_dict_type(
                                          new_model, state_dict_type):
                state_dict = new_model.state_dict()
                load_state_dict(state_dict, reader)
                new_model.load_state_dict(state_dict)

        with FullyShardedDataParallel.summon_full_params(
                model), FullyShardedDataParallel.summon_full_params(new_model):
            params = list(model.parameters())
            new_params = list(new_model.parameters())
            self.assertEqual(params, new_params)
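The snippets above all rely on the same handful of imports and on an already-initialized default process group. Below is a minimal, self-contained sketch of the core `enable_wrap`/`wrap` pattern they share; the helper name `build_sharded_model` is illustrative only, and the policy kwargs mirror Examples #2 and #5:
 import functools
 import torch.nn as nn
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.wrap import enable_wrap, size_based_auto_wrap_policy, wrap

 def build_sharded_model() -> FSDP:
     # Assumes torch.distributed.init_process_group(...) has already been called,
     # since FSDP shards parameters over the default process group.
     # Inside enable_wrap(), bare wrap() calls apply wrapper_cls with these kwargs;
     # outside the context, wrap() returns the module unchanged (see Example #3).
     with enable_wrap(
         wrapper_cls=FSDP,
         auto_wrap_policy=functools.partial(size_based_auto_wrap_policy, min_num_params=1),
     ):
         return wrap(nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)))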