Code example #1
def _run_test_summon_full_param_writeback(cls, writeback, modify_outer,
                                          *fsdp_args, **fsdp_kwargs):
    with enable_wrap(wrapper_cls=FSDP, *fsdp_args, **fsdp_kwargs):
        lin1 = wrap(nn.Linear(5, 5, bias=False).cuda(cls.rank))
        lin2 = nn.Linear(5, 3, bias=False).cuda(cls.rank)
        model = wrap(nn.Sequential(lin1, lin2))

    # Pick the flat parameter (outer or inner) whose local shard we will modify
    outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
    inner_param = model.get_parameter(
        "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param")
    p = outer_param if modify_outer else inner_param

    with torch.no_grad():
        # This sets the local shard value
        p[0] = cls.rank + 2

    with model.summon_full_params(model, writeback=writeback):
        with torch.no_grad():
            p.copy_(torch.zeros_like(p))

    if writeback or cls.world_size == 1:
        # When world_size == 1, FSDP does not shard, so the parameter is not
        # replaced by a local shard and the write is always reflected.
        cls.assertEqual(p.cpu()[0], 0)
    else:
        cls.assertEqual(p.cpu()[0], cls.rank + 2)
Code example #2
File: test_fsdp_meta.py  Project: yuguo68/pytorch
 def __init__(self, device):
     super().__init__()
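     # Only lin1 and l3 are wrapped here; lin2 bubbles up into the enclosing
     # FSDP unit when the full model is wrapped (see code example #6).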
     self.lin1 = MyLinear(2, 2, bias=False, device=device)
     self.lin1 = wrap(self.lin1)
     self.lin2 = MyLinear(2, 2, bias=False, device=device)
     self.l3 = MyModel(device=device)
     self.l3 = wrap(self.l3)
Code example #3
 def configure_sharded_model(self) -> None:
     # the model is already wrapped with FSDP: no need to wrap again!
     if isinstance(self.layer, FullyShardedDataParallel):
         return
     for i, layer in enumerate(self.layer):
         if i % 2 == 0:
             self.layer[i] = wrap(layer)
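     # Finally wrap the container itself so the remaining layers are managed
     # by the outer FSDP unit.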
     self.layer = wrap(self.layer)
Code example #4
File: test_wrap.py  Project: timgates42/pytorch
 def test_auto_wrap_with_ignored_modules(self, wrap_method: WrapMethod):
     sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
     ignored_modules = [sequential[1], sequential[2][0]]
     my_auto_wrap_policy = functools.partial(
         size_based_auto_wrap_policy,
         min_num_params=40,
     )
     fsdp_kwargs = {
         "process_group": self.process_group,
         "auto_wrap_policy": my_auto_wrap_policy,
         "ignored_modules": ignored_modules,
     }
     if wrap_method == WrapMethod.FSDP_CTOR:
         model = FSDP(sequential, **fsdp_kwargs)
     elif wrap_method == WrapMethod.WRAP_API:
         with enable_wrap(wrapper_cls=FSDP, **fsdp_kwargs):
             model = wrap(sequential)
     else:
         assert 0, f"Unsupported wrap method: {wrap_method}"
     # Since the 2nd linear (`sequential[1]`) is ignored, the wrapping
     # policy does not exceed the parameter threshold before the inner
     # sequential (`sequential[2]`) anymore; hence, it flattens
     # `sequential[0]` and `sequential[2][0]` into `model` and leaves
     # `sequential[1]` and `sequential[2][1]` as-is since they are ignored
     self.assertTrue(isinstance(model, FSDP))
     self.assertTrue(isinstance(model.module[0], nn.Linear))
     self.assertTrue(isinstance(model.module[1], nn.Linear))
     self.assertTrue(isinstance(model.module[2], nn.Sequential))
     self.assertTrue(isinstance(model.module[2][0], nn.Linear))
     self.assertTrue(isinstance(model.module[2][1], nn.Linear))
Code example #5
File: test_wrap.py  Project: timgates42/pytorch
 def test_wrap_override_defaults(self):
     new_process_group = DummyProcessGroup(rank=0, size=2)
     with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
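         # Kwargs passed directly to wrap() override the defaults configured
         # via enable_wrap() (here, the process group).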
         layer = wrap(nn.Linear(5, 5), process_group=new_process_group)
     self.assertTrue(isinstance(layer, FSDP))
     self.assertTrue(layer.process_group is new_process_group)
     self.assertEqual(layer.rank, 0)
     self.assertEqual(layer.world_size, 2)
Code example #6
File: test_fsdp_meta.py  Project: yuguo68/pytorch
    def _test_nested_model_with_meta_device(self, auto_wrap, meta_module_fn, init_fn=None):
        if auto_wrap:
            module = meta_module_fn()
            is_meta = next(module.parameters()).is_meta
            fsdp_meta = FSDP(
                module,
                auto_wrap_policy=always_wrap,
                param_init_fn=init_fn,
            )
            meta_opt = torch.optim.SGD(fsdp_meta.parameters(), lr=1e-3)
            module_regular = NestedModel(device="cuda")
            _reset_params_if_meta(is_meta, module_regular)
            fsdp_regular = FSDP(
                module_regular,
                auto_wrap_policy=always_wrap,
            )
            regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3)
        else:
            with enable_wrap(
                wrapper_cls=FSDP, param_init_fn=init_fn,
            ):
                module = meta_module_fn()
                is_meta = next(module.parameters()).is_meta
                # Non-FSDP modules will still be initialized because they
                # bubble up to be part of a larger FSDP unit.
                fsdp_meta = wrap(module)
                meta_opt = torch.optim.SGD(fsdp_meta.parameters(), lr=1e-3)

            # Init and reset parameters before wrapping so that reset_params
            # matches up with meta device's initialization.
            module_regular = NestedModel(device="cuda")
            _reset_params_if_meta(is_meta, module_regular)
            with enable_wrap(wrapper_cls=FSDP):
                module_regular.lin1 = wrap(module_regular.lin1)
                module_regular.l3 = wrap(module_regular.l3)
                fsdp_regular = wrap(module_regular)
                regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3)

        # Compare it before training
        self._compare_fsdp(fsdp_meta, fsdp_regular)
        inp = torch.randn(10, 2, device='cuda')
        fsdp_meta(inp).sum().backward()
        fsdp_regular(inp).sum().backward()
        meta_opt.step()
        regular_opt.step()
        self._compare_fsdp(fsdp_meta, fsdp_regular)
Code example #7
 def test_state_dict_type(self):
     module = SkipModel(double_nest=True)
     with enable_wrap(wrapper_cls=FSDP):
         fsdp = wrap(module)
     with FSDP.state_dict_type(fsdp, StateDictType.LOCAL_STATE_DICT):
         pass
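     # Outside the context manager, every FSDP module should be back to the
     # default state dict type (FULL_STATE_DICT).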
     for module in FSDP.fsdp_modules(fsdp):
         self.assertEqual(module._state_dict_type, StateDictType.FULL_STATE_DICT)
Code example #8
    def test_distributed_checkpoint(self, state_dict_type) -> None:
        with enable_wrap(wrapper_cls=FSDP):
            torch.manual_seed(100)
            model = wrap(SkipModel(double_nest=True))
            torch.manual_seed(200)
            new_model = wrap(SkipModel(double_nest=True))
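            # Different seeds give the two models different initial parameters,
            # which the assertNotEqual below confirms.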

        with FullyShardedDataParallel.summon_full_params(
                model), FullyShardedDataParallel.summon_full_params(new_model):
            params = list(model.parameters())
            new_params = list(new_model.parameters())
            self.assertNotEqual(params, new_params)

        with tempfile.TemporaryDirectory() as path:
            paths = [path]
            dist.broadcast_object_list(paths)
            path = paths[0]
            writer = FileSystemWriter(path)
            reader = FileSystemReader(path)
            with FSDP.state_dict_type(model,
                                      state_dict_type), FSDP.state_dict_type(
                                          new_model, state_dict_type):
                state_dict = model.state_dict()

            save_state_dict(state_dict, writer)

            with FSDP.state_dict_type(model,
                                      state_dict_type), FSDP.state_dict_type(
                                          new_model, state_dict_type):
                state_dict = new_model.state_dict()
                load_state_dict(state_dict, reader)
                new_model.load_state_dict(state_dict)

        with FullyShardedDataParallel.summon_full_params(
                model), FullyShardedDataParallel.summon_full_params(new_model):
            params = list(model.parameters())
            new_params = list(new_model.parameters())
            self.assertEqual(params, new_params)
Code example #9
File: test_wrap.py  Project: timgates42/pytorch
 def test_wrap(self, wrap_method):
     if wrap_method == WrapMethod.WRAP_API:
         with enable_wrap(wrapper_cls=FSDP,
                          process_group=self.process_group):
             layer = wrap(nn.Linear(5, 5))
     else:
         assert wrap_method == WrapMethod.FSDP_CTOR
         layer = FSDP(nn.Linear(5, 5),
                      process_group=self.process_group,
                      auto_wrap_policy=functools.partial(
                          size_based_auto_wrap_policy, min_num_params=1))
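     # Both construction paths should yield an FSDP instance tied to
     # self.process_group.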
     self.assertTrue(isinstance(layer, FSDP))
     self.assertEqual(layer.rank, self.process_group.rank())
     self.assertEqual(layer.world_size, self.process_group.size())
Code example #10
File: test_wrap.py  Project: timgates42/pytorch
    def test_wrap_disabled_outside_context(self):
        pg = self.process_group

        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
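                # wrap() is a no-op outside an enable_wrap() context, so
                # self.lin stays a plain nn.Linear.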
                self.lin = wrap(nn.Linear(5, 5), process_group=pg)

        model = MyModel()
        with enable_wrap(wrapper_cls=FSDP, process_group=pg):
            model = wrap(model)

        self.assertTrue(isinstance(model, FSDP))
        self.assertFalse(isinstance(model.lin, FSDP))
        self.assertTrue(isinstance(model.lin, nn.Linear))
Code example #11
 def _create_module(wrap_fsdp=True):
     LINEAR_SKIP = "linear_skip"
     ctx = enable_wrap(wrapper_cls=FSDP) if wrap_fsdp else suppress()
     with ctx:
         module = SkipModel(double_nest=double_nest)
         # Full names of the linear_skip parameter tensors in SkipModel, as
         # they would be stored in a checkpoint.
         linear_skip_tensor_names = [
             k for k in dict(module.named_parameters()).keys()
             if LINEAR_SKIP in k
         ]
         # Detach linear_skip so the wrap() call below does not manage it
         linear_skip = getattr(module, LINEAR_SKIP)
         delattr(module, LINEAR_SKIP)
         # Wrap the remaining module with FSDP
         fsdp = wrap(module)
         # Reattach the skipped submodule
         setattr(module, LINEAR_SKIP, linear_skip)
         return fsdp, linear_skip_tensor_names
Code example #12
File: test_wrap.py  Project: timgates42/pytorch
 def test_always_wrap_with_ignored_modules(self, wrap_method: WrapMethod):
     sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
     ignored_modules = [sequential[1], sequential[2][0]]
     fsdp_kwargs = {
         "process_group": self.process_group,
         "auto_wrap_policy": always_wrap_policy,
         "ignored_modules": ignored_modules,
     }
     if wrap_method == WrapMethod.FSDP_CTOR:
         model = FSDP(sequential, **fsdp_kwargs)
     elif wrap_method == WrapMethod.WRAP_API:
         with enable_wrap(wrapper_cls=FSDP, **fsdp_kwargs):
             model = wrap(sequential)
     else:
         assert 0, f"Unsupported wrap method: {wrap_method}"
     # All non-ignored modules should be wrapped with FSDP
     self.assertTrue(isinstance(model, FSDP))
     self.assertTrue(isinstance(model.module[0], FSDP))
     self.assertTrue(isinstance(model.module[1], nn.Linear))
     self.assertTrue(isinstance(model.module[2], FSDP))
     self.assertTrue(isinstance(model.module[2].module[0], nn.Linear))
     self.assertTrue(isinstance(model.module[2].module[1], FSDP))
Code example #13
File: test_wrap.py  Project: timgates42/pytorch
 def __init__(self):
     super().__init__()
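     # `pg` is the process group from the enclosing test; see code example #10
     # for the full context.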
     self.lin = wrap(nn.Linear(5, 5), process_group=pg)
Code example #14
 def __init__(self, double_nest):
     super().__init__()
     self.linear = nn.Linear(10, 10, bias=False).cuda()
     self.linear_skip = SkipModule().cuda()
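     # NestedLinear wraps its own inner linear only when fsdp_wrap=True
     # (see code example #15).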
     self.nested_linear = wrap(NestedLinear(fsdp_wrap=double_nest))
Code example #15
 def __init__(self, fsdp_wrap):
     super().__init__()
     if fsdp_wrap:
         self.nested_linear = wrap(nn.Linear(10, 10, bias=False).cuda())
     else:
         self.nested_linear = nn.Linear(10, 10, bias=False).cuda()
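
The snippets above all follow the same basic pattern: build submodules inside an enable_wrap(wrapper_cls=FSDP, ...) context, call wrap(...) on the pieces that should become their own FSDP units, and wrap the root module last. A minimal sketch of that pattern follows; it assumes torch.distributed and a default process group have already been initialized, and the module names and sizes are illustrative rather than taken from the tests above.

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import enable_wrap, wrap


def build_model():
    # Defaults passed to enable_wrap() apply to every wrap() call in the block.
    with enable_wrap(wrapper_cls=FSDP):
        inner = wrap(nn.Linear(5, 5, bias=False))  # becomes its own FSDP unit
        outer = nn.Linear(5, 3, bias=False)        # flattened into the root unit
        # Wrap the root last so remaining parameters join the outermost unit.
        model = wrap(nn.Sequential(inner, outer))
    return model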