示例#1
0
 def test_wrap_override_defaults(self):
     """Check that a ``process_group`` passed directly to ``wrap`` overrides
     the default supplied to ``enable_wrap``.

     The override group is built with rank 0 and size 2, so the wrapped
     layer's ``rank`` and ``world_size`` are expected to equal those values.
     """
     new_process_group = DummyProcessGroup(rank=0, size=2)
     with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
         layer = wrap(nn.Linear(5, 5), process_group=new_process_group)
     # assertIsInstance reports the actual type on failure, whereas
     # assertTrue(isinstance(...)) only reports "False is not true".
     self.assertIsInstance(layer, FSDP)
     self.assertEqual(layer.rank, 0)
     self.assertEqual(layer.world_size, 2)
示例#2
0
 def test_wrap(self, wrap_method):
     """Check that a single ``nn.Linear`` becomes an FSDP instance bound to
     ``self.process_group``, whichever wrapping route is taken.

     Args:
         wrap_method: ``WrapMethod.WRAP_API`` to wrap via ``wrap`` inside an
             ``enable_wrap`` context, or ``WrapMethod.FSDP_CTOR`` to wrap via
             the ``FSDP`` constructor with ``default_auto_wrap_policy``.
     """
     if wrap_method == WrapMethod.WRAP_API:
         with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
             layer = wrap(nn.Linear(5, 5))
     else:
         # A unittest assertion (unlike a bare ``assert``) is not stripped
         # under ``python -O``. An unexpected wrap_method therefore fails
         # loudly instead of silently taking the constructor path.
         self.assertEqual(wrap_method, WrapMethod.FSDP_CTOR)
         layer = FSDP(
             nn.Linear(5, 5),
             process_group=self.process_group,
             # NOTE(review): min_num_params=1 presumably lowers the policy's
             # size threshold so this small layer qualifies — confirm.
             fsdp_auto_wrap_policy=functools.partial(
                 default_auto_wrap_policy, min_num_params=1
             ),
         )
     self.assertIsInstance(layer, FSDP)
     self.assertEqual(layer.rank, self.process_group.rank())
     self.assertEqual(layer.world_size, self.process_group.size())
示例#3
0
    def test_wrap_disabled_outside_context(self):
        """Check that ``wrap`` leaves a module unwrapped when called outside an
        ``enable_wrap`` context.

        ``MyModel.__init__`` calls ``wrap`` on its submodule before any
        context is entered, so that submodule is expected to stay a plain
        ``nn.Linear``. Only the outer model, wrapped inside the context, is
        expected to become FSDP.
        """
        pg = self.process_group

        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                # Runs at construction time, i.e. outside the enable_wrap
                # block below.
                self.lin = wrap(nn.Linear(5, 5), process_group=pg)

        model = MyModel()
        with enable_wrap(wrapper_cls=FSDP, process_group=pg):
            model = wrap(model)

        self.assertIsInstance(model, FSDP)
        # NOTE(review): ``model`` is now the FSDP wrapper; ``.lin`` presumably
        # resolves through it to the inner MyModel's attribute — confirm.
        self.assertNotIsInstance(model.lin, FSDP)
        self.assertIsInstance(model.lin, nn.Linear)
示例#4
0
 def test_wrap_disabled_outside_context(self):
     """Check that ``wrap`` returns a plain ``nn.Linear`` when no
     ``enable_wrap`` context is active around the call."""
     layer = wrap(nn.Linear(5, 5))
     # assertIsInstance shows the actual type on failure.
     self.assertIsInstance(layer, nn.Linear)
示例#5
0
 def test_wrap(self):
     """Check that ``wrap`` inside ``enable_wrap`` turns an ``nn.Linear`` into
     an FSDP instance.

     The instance's ``rank`` and ``world_size`` are expected to match the
     process group given to the context.
     """
     with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
         layer = wrap(nn.Linear(5, 5))
     # assertIsInstance shows the actual type on failure.
     self.assertIsInstance(layer, FSDP)
     self.assertEqual(layer.rank, self.process_group.rank())
     self.assertEqual(layer.world_size, self.process_group.size())
示例#6
0
 def __init__(self):
     """Initialise the module and store a 5->5 linear layer, passed through
     ``wrap`` with the enclosing scope's ``pg``, as ``self.lin``."""
     super().__init__()
     linear = nn.Linear(5, 5)
     self.lin = wrap(linear, process_group=pg)