def test_wrap_override_defaults(self):
    """Check that keyword arguments given to ``wrap`` override the defaults
    installed by ``enable_wrap``.

    ``enable_wrap`` registers ``self.process_group`` as the default, but the
    explicit ``process_group=new_process_group`` passed to ``wrap`` must take
    precedence. The resulting FSDP module should therefore report the
    dummy group's rank (0) and size (2).
    """
    new_process_group = DummyProcessGroup(rank=0, size=2)
    with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
        layer = wrap(nn.Linear(5, 5), process_group=new_process_group)
    # assertIsInstance names the actual type on failure, whereas
    # assertTrue(isinstance(...)) only reports "False is not true".
    self.assertIsInstance(layer, FSDP)
    self.assertEqual(layer.rank, 0)
    self.assertEqual(layer.world_size, 2)
def test_wrap(self, wrap_method):
    """Wrap a single ``nn.Linear`` using the requested method and verify that
    the result is an FSDP module bound to ``self.process_group``.

    Args:
        wrap_method: a ``WrapMethod`` member. ``WRAP_API`` goes through the
            ``enable_wrap`` / ``wrap`` context API; ``FSDP_CTOR`` constructs
            ``FSDP`` directly with an auto-wrap policy.
    """
    if wrap_method == WrapMethod.WRAP_API:
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
            layer = wrap(nn.Linear(5, 5))
    else:
        # A unittest assertion instead of a bare ``assert``: the latter is
        # stripped under ``python -O``, which would let an unexpected
        # parameter silently fall through to the constructor path.
        self.assertEqual(wrap_method, WrapMethod.FSDP_CTOR)
        layer = FSDP(
            nn.Linear(5, 5),
            process_group=self.process_group,
            # Presumably a threshold of 1 parameter so this small Linear
            # qualifies under the policy — confirm against the policy's docs.
            fsdp_auto_wrap_policy=functools.partial(
                default_auto_wrap_policy, min_num_params=1
            ),
        )
    self.assertIsInstance(layer, FSDP)
    self.assertEqual(layer.rank, self.process_group.rank())
    self.assertEqual(layer.world_size, self.process_group.size())
def test_wrap_disabled_outside_context(self):
    """Check that ``wrap`` calls made outside ``enable_wrap`` do not wrap.

    ``MyModel`` calls ``wrap`` on its submodule in ``__init__``. The model is
    built before the context is entered, so that inner call must leave the
    ``nn.Linear`` unwrapped. Wrapping the whole model afterwards, inside the
    context, must yield an FSDP module whose ``lin`` is still a plain Linear.
    """
    pg = self.process_group

    class MyModel(nn.Module):
        def __init__(self):
            super().__init__()
            # Runs outside enable_wrap, so this wrap() call should not wrap.
            self.lin = wrap(nn.Linear(5, 5), process_group=pg)

    model = MyModel()
    with enable_wrap(wrapper_cls=FSDP, process_group=pg):
        model = wrap(model)

    # assertIsInstance / assertNotIsInstance give type-aware failure messages.
    self.assertIsInstance(model, FSDP)
    self.assertNotIsInstance(model.lin, FSDP)
    self.assertIsInstance(model.lin, nn.Linear)
def test_wrap_disabled_outside_context(self):
    """Check that ``wrap`` with no active ``enable_wrap`` context yields a
    plain ``nn.Linear`` rather than a wrapped module."""
    layer = wrap(nn.Linear(5, 5))
    # assertIsInstance reports the actual type if the check fails.
    self.assertIsInstance(layer, nn.Linear)
def test_wrap(self):
    """Check that ``wrap`` inside ``enable_wrap`` produces an FSDP module
    whose rank and world size match ``self.process_group``."""
    with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
        layer = wrap(nn.Linear(5, 5))
    # assertIsInstance reports the actual type if the check fails.
    self.assertIsInstance(layer, FSDP)
    self.assertEqual(layer.rank, self.process_group.rank())
    self.assertEqual(layer.world_size, self.process_group.size())
def __init__(self):
    # NOTE: zero-argument super() only works when this def sits inside a
    # class body; it is a method of an enclosing model class.
    super().__init__()
    # ``pg`` is not assigned here — it is captured from an enclosing scope.
    linear = nn.Linear(5, 5)
    self.lin = wrap(linear, process_group=pg)