Code example #1
 def test_wrap_override_defaults(self):
     new_process_group = DummyProcessGroup(rank=0, size=2)
     with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
         layer = wrap(nn.Linear(5, 5), process_group=new_process_group)
     self.assertTrue(isinstance(layer, FSDP))
     self.assertEqual(layer.rank, 0)
     self.assertEqual(layer.world_size, 2)
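
These snippets are taken from PyTorch's FSDP wrapping tests and omit their imports and test scaffolding. The sketch below shows a minimal setup under which example #1 would run; the DummyProcessGroup class here is a hypothetical stand-in for PyTorch's internal test helper of the same name, assumed only to report a fixed rank and world size.

# Assumed imports for the snippets on this page (adjust to your PyTorch
# version; the wrap-policy helper names in particular vary across releases).
import functools
import os
import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.fsdp.wrap import enable_wrap, wrap

# Hypothetical stand-in for the internal DummyProcessGroup test helper:
# a fake process group that only reports a rank and a world size.
class DummyProcessGroup:
    def __init__(self, rank: int, size: int):
        self._rank = rank
        self._size = size

    def rank(self) -> int:
        return self._rank

    def size(self) -> int:
        return self._size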
Code example #2
    def test_scaling_unscaling_sparse(self):
        pg = DummyProcessGroup(0, 1)
        scaler = ShardedGradScaler(init_scale=2.0,
                                   process_group=pg,
                                   enabled=True)
        inv_scale = torch.full((1,), 0.5, dtype=torch.float, device="cpu")
        found_inf = torch.full((1,), 0.0, dtype=torch.float, device="cpu")

        i = torch.tensor([[0, 1, 1], [2, 0, 2]],
                         device="cpu",
                         dtype=torch.int64)
        v = torch.tensor([16.0, 32.0, 64.0], dtype=torch.float, device="cpu")
        s = torch.sparse_coo_tensor(i,
                                    v,
                                    torch.Size([2, 3]),
                                    device="cpu",
                                    dtype=torch.float)

        # unscale sparse tensors
        s1 = s.clone()
        s1.grad = s.clone()
        opt = torch.optim.SGD([s1], lr=1.0)
        found_inf.zero_()
        found_inf = scaler._unscale_grads_(opt, inv_scale,
                                           found_inf)[s1.device]
        self.assertEqual(found_inf, 0.0)
        self.assertEqual(s1.grad.to_dense(), (s / 2).to_dense())

        # unscale sparse tensor: inf
        v = torch.tensor([16.0, 32.0, float('inf')],
                         dtype=torch.float,
                         device="cpu")
        s1.grad = torch.sparse_coo_tensor(i,
                                          v,
                                          torch.Size([2, 3]),
                                          device="cpu",
                                          dtype=torch.float)
        found_inf.zero_()
        found_inf = scaler._unscale_grads_(opt, inv_scale,
                                           found_inf)[s1.device]
        self.assertEqual(found_inf, 1.0)

        # unscale sparse tensor: overflow (marked as inf)
        i = torch.tensor([[1, 1, 1], [0, 0, 2]],
                         device="cpu",
                         dtype=torch.int64)
        # coalescing sparse tensor here will cause the value to be Inf
        v = torch.tensor([2**15, 2**15, 1.0],
                         dtype=torch.float16,
                         device="cpu")
        s1 = torch.sparse_coo_tensor(i,
                                     v,
                                     torch.Size([2, 3]),
                                     device="cpu",
                                     dtype=torch.float16)
        s1.grad = s1.clone()
        found_inf.zero_()
        found_inf = scaler._unscale_grads_(opt, inv_scale,
                                           found_inf)[s1.device]
        self.assertEqual(found_inf, 1.0)
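
The test above exercises ShardedGradScaler's private _unscale_grads_ hook directly. In normal training the scaler is driven through its public API, which mirrors torch.cuda.amp.GradScaler. A minimal sketch of that loop, assuming model, inputs, optimizer, and a constructed ShardedGradScaler named scaler already exist:

# Typical scale -> step -> update loop (a sketch; `model`, `inputs`,
# `optimizer`, and the ShardedGradScaler `scaler` are assumed to exist).
loss = model(inputs).sum()
scaler.scale(loss).backward()  # backprop scaled gradients
scaler.step(optimizer)         # unscales grads first; skips the step on inf/nan
scaler.update()                # grow or shrink the scale for the next iteration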
Code example #3
 def test_grad_scaling(self):
     pg = DummyProcessGroup(0, 1)
     scaler = ShardedGradScaler(init_scale=2.0, process_group=pg, enabled=True)
     t0 = torch.full((1,), 4.0, dtype=torch.float32, device="cpu")
     t1 = torch.full((1,), 8.0, dtype=torch.float32, device="cpu")
     outputs = [t1.clone(), (t0.clone(), t1.clone()), [t0.clone(), t1.clone()]]
     outputs = scaler.scale(outputs)
     self.assertTrue(outputs[0] == 16.0 and outputs[1][0] == 8.0 and outputs[1][1] == 16.0)
     self.assertTrue(outputs[2][0] == 8.0 and outputs[2][1] == 16.0)
     self.assertTrue(scaler._scale.device == t1.device)
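
With init_scale=2.0, scale() multiplies every tensor it encounters by the scale factor, recursing into the tuple and the list: t1's 8.0 becomes 16.0 and t0's 4.0 becomes 8.0, which is exactly what the assertions verify.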
Code example #4
 def test_inf_gradients_skip_optim_step(self):
     pg = DummyProcessGroup(0, 1)
     scaler = ShardedGradScaler(init_scale=2.0, process_group=pg, enabled=True)
     loss = torch.full((1,), 4.0, dtype=torch.float32, device="cpu")
     t0 = torch.tensor([float('inf')], dtype=torch.float32, device="cpu")
     t0.grad = t0.clone()
     opt = torch.optim.SGD([t0], lr=1.0)
     scaler.scale(loss)
     ret_val = scaler.step(opt)
     self.assertTrue(ret_val is None)
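
The gradient here is already inf, so unscaling during step() flags it and the optimizer step is skipped; step() signals the skip by returning None instead of the optimizer's return value.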
Code example #5
 def setUp(self) -> None:
     super().setUp()
     # For all the tests here, we use a fake group
     self.process_group = DummyProcessGroup(rank=0, size=1)
Code example #6
class TestAutoWrap(TestCase):
    def setUp(self) -> None:
        super().setUp()
        # For all the tests here, we use a fake group
        self.process_group = DummyProcessGroup(rank=0, size=1)

    @parametrize("wrap_method", [WrapMethod.FSDP_CTOR, WrapMethod.WRAP_API])
    def test_wrap(self, wrap_method):
        if wrap_method == WrapMethod.WRAP_API:
            with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
                layer = wrap(nn.Linear(5, 5))
        else:
            assert wrap_method == WrapMethod.FSDP_CTOR
            layer = FSDP(
                nn.Linear(5, 5),
                process_group=self.process_group,
                fsdp_auto_wrap_policy=functools.partial(default_auto_wrap_policy, min_num_params=1)
            )
        self.assertTrue(isinstance(layer, FSDP))
        self.assertEqual(layer.rank, self.process_group.rank())
        self.assertEqual(layer.world_size, self.process_group.size())

    def test_wrap_disabled_outside_context(self):
        layer = wrap(nn.Linear(5, 5))
        self.assertTrue(isinstance(layer, nn.Linear))

    def test_wrap_override_defaults(self):
        new_process_group = DummyProcessGroup(rank=0, size=2)
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
            layer = wrap(nn.Linear(5, 5), process_group=new_process_group)
        self.assertTrue(isinstance(layer, FSDP))
        self.assertEqual(layer.rank, 0)
        self.assertEqual(layer.world_size, 2)

    def test_auto_wrap_api(self):
        """
        Test to ensure that with auto wrap, we wrap child modules correctly based on min_num_params.
        A single ``nn.Linear(5, 5)`` does not exceed the bucket size, but two combined do.
        """
        sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
        my_auto_wrap_policy = functools.partial(
            default_auto_wrap_policy, min_num_params=40
        )
        model = FSDP(
            sequential,
            process_group=self.process_group,
            fsdp_auto_wrap_policy=my_auto_wrap_policy
        )

        TestFSDPWrap.NestedSequentialModel.verify_model(self, model)

    def test_auto_wrap_preset_exclude_wrap(self):
        """
        Test to ensure excluded modules are not wrapped, regardless of whether the total param size exceeds
        min_num_params. The default_auto_wrap_policy excludes {nn.ModuleList, nn.ModuleDict} from wrapping.
        """
        sequential = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])
        my_auto_wrap_policy = functools.partial(
            default_auto_wrap_policy, min_num_params=40
        )

        model = FSDP(
            sequential,
            process_group=self.process_group,
            fsdp_auto_wrap_policy=my_auto_wrap_policy
        )

        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model[0], nn.Linear))
        self.assertTrue(isinstance(model[1], nn.Linear))

    def test_auto_wrap_preset_exclude_wrap_include_children(self):
        """
        Test to ensure excluded modules are not wrapped, but their children are if the param size exceeds
        min_num_params.
        """
        sequential = nn.ModuleList([nn.Linear(10, 10)])
        my_auto_wrap_policy = functools.partial(
            default_auto_wrap_policy, min_num_params=40
        )
        model = FSDP(sequential, process_group=self.process_group, fsdp_auto_wrap_policy=my_auto_wrap_policy)

        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model[0], FSDP))

    def test_auto_wrap_preset_force_leaf(self):
        """
        Test to ensure force-leaf modules and their children are not wrapped. The
        default_auto_wrap_policy treats modules of type {nn.MultiheadAttention} as leaves that must not be wrapped.
        """
        sequential = nn.Sequential(nn.Linear(10, 10), nn.MultiheadAttention(100, 1))
        my_auto_wrap_policy = functools.partial(
            default_auto_wrap_policy, min_num_params=40
        )
        model = FSDP(sequential, process_group=self.process_group, fsdp_auto_wrap_policy=my_auto_wrap_policy)
        self.assertTrue(isinstance(model.module[0], FSDP))
        # Assert children of multihead attention are not wrapped
        self.assertTrue(isinstance(model.module[1], nn.MultiheadAttention))
        self.assertTrue(isinstance(model.module[1].out_proj, nn.Linear))

    def test_auto_wrap_preset_force_leaf_custom(self):
        """
        Test to ensure force-leaf modules are not wrapped.
        """
        my_auto_wrap_policy = functools.partial(
            default_auto_wrap_policy,
            min_num_params=40,
            force_leaf_modules=default_auto_wrap_policy.FORCE_LEAF_MODULES.union(
                {nn.Linear}
            ),
        )
        sequential = nn.Sequential(
            nn.Linear(10, 10), nn.ModuleList([nn.Linear(10, 10)])
        )
        model = FSDP(sequential, process_group=self.process_group, fsdp_auto_wrap_policy=my_auto_wrap_policy)
        # Model was wrapped in FSDP as no inner modules were wrapped.
        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model.module[0], nn.Linear))
        self.assertTrue(isinstance(model.module[1], nn.ModuleList))

    @unittest.skipIf(not torch.cuda.is_available(), "Test Requires CUDA")
    @parametrize("fsdp_init_mode", [FSDPInitMode.CUDA_BEFORE, FSDPInitMode.CUDA_AFTER])
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=False), CPUOffload(offload_params=True)]
    )
    def test_auto_wrap_smoke_test(self, fsdp_init_mode, cpu_offload):
        # CPU offload and CUDA after don't work together as expected.
        if (
            cpu_offload.offload_params and fsdp_init_mode == FSDPInitMode.CUDA_AFTER
        ):
            return

        device = torch.device("cuda")
        torch.cuda.set_device(0)

        # Use a random port in case the next test runs quickly; reusing the same port would cause a conflict.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(find_free_port())
        torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

        # NOTE: We move the model to CUDA after init with FSDP to simulate real use
        # cases where the full model cannot be loaded onto the GPU, but its shards can.
        cuda_after_init = fsdp_init_mode == FSDPInitMode.CUDA_AFTER
        try:
            sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=(not cuda_after_init))
            my_auto_wrap_policy = functools.partial(
                default_auto_wrap_policy, min_num_params=40
            )
            model = FSDP(sequential, cpu_offload=cpu_offload, fsdp_auto_wrap_policy=my_auto_wrap_policy)
            TestFSDPWrap.NestedSequentialModel.verify_model(self, model)
            if cuda_after_init:
                model = model.cuda()
            input = torch.rand((1, 5), dtype=torch.float).to(device)
            output = model(input)
            loss = F.mse_loss(input, output)
            loss.backward()
        finally:
            torch.distributed.destroy_process_group()
            del os.environ["MASTER_ADDR"]
            del os.environ["MASTER_PORT"]
Code example #7
File: test_wrap.py  Project: timgates42/pytorch
class TestAutoWrap(TestCase):
    def setUp(self) -> None:
        super().setUp()
        # For all the tests here, we use a fake group
        self.process_group = DummyProcessGroup(rank=0, size=1)

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    @parametrize("wrap_method", [WrapMethod.FSDP_CTOR, WrapMethod.WRAP_API])
    def test_wrap(self, wrap_method):
        if wrap_method == WrapMethod.WRAP_API:
            with enable_wrap(wrapper_cls=FSDP,
                             process_group=self.process_group):
                layer = wrap(nn.Linear(5, 5))
        else:
            assert wrap_method == WrapMethod.FSDP_CTOR
            layer = FSDP(nn.Linear(5, 5),
                         process_group=self.process_group,
                         auto_wrap_policy=functools.partial(
                             size_based_auto_wrap_policy, min_num_params=1))
        self.assertTrue(isinstance(layer, FSDP))
        self.assertEqual(layer.rank, self.process_group.rank())
        self.assertEqual(layer.world_size, self.process_group.size())

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    def test_wrap_disabled_outside_context(self):
        pg = self.process_group

        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin = wrap(nn.Linear(5, 5), process_group=pg)

        model = MyModel()
        with enable_wrap(wrapper_cls=FSDP, process_group=pg):
            model = wrap(model)

        self.assertTrue(isinstance(model, FSDP))
        self.assertFalse(isinstance(model.lin, FSDP))
        self.assertTrue(isinstance(model.lin, nn.Linear))

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    def test_wrap_override_defaults(self):
        new_process_group = DummyProcessGroup(rank=0, size=2)
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
            layer = wrap(nn.Linear(5, 5), process_group=new_process_group)
        self.assertTrue(isinstance(layer, FSDP))
        self.assertTrue(layer.process_group is new_process_group)
        self.assertEqual(layer.rank, 0)
        self.assertEqual(layer.world_size, 2)

    @unittest.skipIf(not torch.cuda.is_available(), "Test Requires CUDA")
    def test_always_wrap(self):
        """
        Test to ensure that if `always_wrap_policy` is
        passed into FSDP, all submodules are wrapped.
        """
        seq = TestFSDPWrap.NestedSequentialModel.get_model(cuda=True)
        model = FSDP(seq,
                     process_group=self.process_group,
                     auto_wrap_policy=always_wrap_policy)
        TestFSDPWrap.NestedSequentialModel.verify_model_all_wrapped(
            self, model)

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    def test_transformer_auto_wrap_policy(self):
        """Tests the ``transformer_auto_wrap_policy``."""
        auto_wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={
                TransformerEncoderLayer, TransformerDecoderLayer
            },
        )
        fsdp_kwargs = {"auto_wrap_policy": auto_wrap_policy}
        fsdp_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs,
        )
        modules = list(fsdp_model.modules())
        encoder_layers = set(fsdp_model.module.transformer.encoder.layers)
        decoder_layers = set(fsdp_model.module.transformer.decoder.layers)
        for module in modules:
            if module is fsdp_model or module in encoder_layers or module in decoder_layers:
                self.assertTrue(isinstance(module, FSDP))
            else:
                self.assertFalse(isinstance(module, FSDP))

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    def test_auto_wrap_api(self):
        """
        Test to ensure that with auto wrap, we wrap child modules correctly based on min_num_params.
        A single ``nn.Linear(5, 5)`` does not exceed the bucket size, but two combined do.
        """
        sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
        my_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy,
                                                min_num_params=40)
        model = FSDP(sequential,
                     process_group=self.process_group,
                     auto_wrap_policy=my_auto_wrap_policy)

        TestFSDPWrap.NestedSequentialModel.verify_model(self, model)

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    def test_auto_wrap_preset_exclude_wrap(self):
        """
        Test to ensure excluded modules are not wrapped, regardless of whether the total param size exceeds
        min_num_params. The size_based_auto_wrap_policy excludes {nn.ModuleList, nn.ModuleDict} from wrapping.
        """
        sequential = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])
        my_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy,
                                                min_num_params=40)

        model = FSDP(sequential,
                     process_group=self.process_group,
                     auto_wrap_policy=my_auto_wrap_policy)

        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model[0], nn.Linear))
        self.assertTrue(isinstance(model[1], nn.Linear))

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    def test_auto_wrap_preset_exclude_wrap_include_children(self):
        """
        Test to ensure excluded modules are not wrapped, but their children are if the param size exceeds
        min_num_params.
        """
        sequential = nn.ModuleList([nn.Linear(10, 10)])
        my_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy,
                                                min_num_params=40)
        model = FSDP(sequential,
                     process_group=self.process_group,
                     auto_wrap_policy=my_auto_wrap_policy)

        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model[0], FSDP))

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    def test_auto_wrap_preset_force_leaf(self):
        """
        Test to ensure force-leaf modules and their children are not wrapped. The
        size_based_auto_wrap_policy treats modules of type {nn.MultiheadAttention} as leaves that must not be wrapped.
        """
        sequential = nn.Sequential(nn.Linear(10, 10),
                                   nn.MultiheadAttention(100, 1))
        my_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy,
                                                min_num_params=40)
        model = FSDP(sequential,
                     process_group=self.process_group,
                     auto_wrap_policy=my_auto_wrap_policy)
        self.assertTrue(isinstance(model.module[0], FSDP))
        # Assert children of multihead attention are not wrapped
        self.assertTrue(isinstance(model.module[1], nn.MultiheadAttention))
        self.assertTrue(isinstance(model.module[1].out_proj, nn.Linear))

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    def test_auto_wrap_preset_force_leaf_custom(self):
        """
        Test to ensure force-leaf modules are not wrapped.
        """
        my_auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy,
            min_num_params=40,
            force_leaf_modules=size_based_auto_wrap_policy.FORCE_LEAF_MODULES.union(
                {nn.Linear}
            ),
        )
        sequential = nn.Sequential(nn.Linear(10, 10),
                                   nn.ModuleList([nn.Linear(10, 10)]))
        model = FSDP(sequential,
                     process_group=self.process_group,
                     auto_wrap_policy=my_auto_wrap_policy)
        # Model was wrapped in FSDP as no inner modules were wrapped.
        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model.module[0], nn.Linear))
        self.assertTrue(isinstance(model.module[1], nn.ModuleList))

    @unittest.skipIf(not torch.cuda.is_available(), "Test Requires CUDA")
    @parametrize("cuda_init_mode",
                 [CUDAInitMode.CUDA_BEFORE, CUDAInitMode.CUDA_AFTER])
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=False),
         CPUOffload(offload_params=True)])
    @parametrize("use_device_id", [True, False])
    def test_auto_wrap_smoke_test(self, cuda_init_mode, cpu_offload,
                                  use_device_id):
        # CPU offload and CUDA after don't work together as expected.
        if (cpu_offload.offload_params
                and cuda_init_mode == CUDAInitMode.CUDA_AFTER):
            return

        device = torch.device("cuda")
        torch.cuda.set_device(0)
        device_id = (torch.device("cuda", torch.cuda.current_device())
                     if use_device_id else None)

        # Use a random port in case the next test runs quickly; reusing the same port would cause a conflict.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(find_free_port())

        file_name = tempfile.NamedTemporaryFile(delete=False).name
        torch.distributed.init_process_group(
            backend="nccl",
            init_method=f"{FILE_SCHEMA}_{file_name}",
            rank=0,
            world_size=1,
        )

        # NOTE: We move the model to CUDA after init with FSDP to simulate real use
        # cases where the full model cannot be loaded onto the GPU, but its shards can.
        cuda_after_init = cuda_init_mode == CUDAInitMode.CUDA_AFTER
        try:
            sequential = TestFSDPWrap.NestedSequentialModel.get_model(
                cuda=(not cuda_after_init))
            my_auto_wrap_policy = functools.partial(
                size_based_auto_wrap_policy, min_num_params=40)
            model = FSDP(sequential,
                         cpu_offload=cpu_offload,
                         auto_wrap_policy=my_auto_wrap_policy,
                         device_id=device_id)
            TestFSDPWrap.NestedSequentialModel.verify_model(self, model)
            if cuda_after_init:
                model = model.cuda()
            input = torch.rand((1, 5), dtype=torch.float).to(device)
            output = model(input)
            loss = F.mse_loss(input, output)
            loss.backward()
        finally:
            torch.distributed.destroy_process_group()

        try:
            os.remove(file_name)
        except FileNotFoundError:
            pass

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    @parametrize("wrap_method", [WrapMethod.FSDP_CTOR, WrapMethod.WRAP_API])
    def test_always_wrap_with_ignored_modules(self, wrap_method: WrapMethod):
        sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
        ignored_modules = [sequential[1], sequential[2][0]]
        fsdp_kwargs = {
            "process_group": self.process_group,
            "auto_wrap_policy": always_wrap_policy,
            "ignored_modules": ignored_modules,
        }
        if wrap_method == WrapMethod.FSDP_CTOR:
            model = FSDP(sequential, **fsdp_kwargs)
        elif wrap_method == WrapMethod.WRAP_API:
            with enable_wrap(wrapper_cls=FSDP, **fsdp_kwargs):
                model = wrap(sequential)
        else:
            assert 0, f"Unsupported wrap method: {wrap_method}"
        # All non-ignored modules should be wrapped with FSDP
        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model.module[0], FSDP))
        self.assertTrue(isinstance(model.module[1], nn.Linear))
        self.assertTrue(isinstance(model.module[2], FSDP))
        self.assertTrue(isinstance(model.module[2].module[0], nn.Linear))
        self.assertTrue(isinstance(model.module[2].module[1], FSDP))

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    @parametrize("wrap_method", [WrapMethod.FSDP_CTOR, WrapMethod.WRAP_API])
    def test_auto_wrap_with_ignored_modules(self, wrap_method: WrapMethod):
        sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
        ignored_modules = [sequential[1], sequential[2][0]]
        my_auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy,
            min_num_params=40,
        )
        fsdp_kwargs = {
            "process_group": self.process_group,
            "auto_wrap_policy": my_auto_wrap_policy,
            "ignored_modules": ignored_modules,
        }
        if wrap_method == WrapMethod.FSDP_CTOR:
            model = FSDP(sequential, **fsdp_kwargs)
        elif wrap_method == WrapMethod.WRAP_API:
            with enable_wrap(wrapper_cls=FSDP, **fsdp_kwargs):
                model = wrap(sequential)
        else:
            assert 0, f"Unsupported wrap method: {wrap_method}"
        # Since the 2nd linear (`sequential[1]`) is ignored, the wrapping
        # policy no longer exceeds the parameter threshold at the inner
        # sequential (`sequential[2]`); hence, it flattens
        # `sequential[0]` and `sequential[2][1]` into `model` and leaves
        # `sequential[1]` and `sequential[2][0]` as-is since they are ignored
        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model.module[0], nn.Linear))
        self.assertTrue(isinstance(model.module[1], nn.Linear))
        self.assertTrue(isinstance(model.module[2], nn.Sequential))
        self.assertTrue(isinstance(model.module[2][0], nn.Linear))
        self.assertTrue(isinstance(model.module[2][1], nn.Linear))
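
Examples #6 and #7 show the same test class at two points in the API's history: the constructor argument fsdp_auto_wrap_policy was later renamed to auto_wrap_policy, and default_auto_wrap_policy to size_based_auto_wrap_policy, with transformer_auto_wrap_policy, always_wrap_policy, and ignored_modules added alongside. Code written against one spelling will not run against the other.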
Code example #8
class TestAutoWrap(unittest.TestCase):
    def setUp(self) -> None:
        # For all the tests here, we use a fake group
        self.process_group = DummyProcessGroup(rank=0, size=1)

    def test_wrap(self):
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
            layer = wrap(nn.Linear(5, 5))
        self.assertTrue(isinstance(layer, FSDP))
        self.assertEqual(layer.rank, self.process_group.rank())
        self.assertEqual(layer.world_size, self.process_group.size())

    def test_wrap_disabled_outside_context(self):
        layer = wrap(nn.Linear(5, 5))
        self.assertTrue(isinstance(layer, nn.Linear))

    def test_wrap_override_defaults(self):
        new_process_group = DummyProcessGroup(rank=0, size=2)
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
            layer = wrap(nn.Linear(5, 5), process_group=new_process_group)
        self.assertTrue(isinstance(layer, FSDP))
        self.assertEqual(layer.rank, 0)
        self.assertEqual(layer.world_size, 2)

    def test_auto_wrap(self):
        """
        Test to ensure that with auto wrap, we wrap child modules correctly based on min_num_params.
        A single ``nn.Linear(5, 5)`` does not exceed the bucket size, but two combined do.
        """
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
            sequential = nn.Sequential(
                nn.Linear(5, 5), nn.Linear(5, 5),
                nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)))
            my_auto_wrap_policy = functools.partial(default_auto_wrap_policy,
                                                    min_num_params=40)
            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model.module[0], nn.Linear))
        self.assertTrue(isinstance(model.module[1], nn.Linear))
        self.assertTrue(isinstance(model.module[2], FSDP))
        self.assertTrue(isinstance(model.module[2].module[0], nn.Linear))
        self.assertTrue(isinstance(model.module[2].module[1], nn.Linear))

    def test_auto_wrap_preset_exclude_wrap(self):
        """
        Test to ensure excluded modules are not wrapped, regardless of whether the total param size exceeds
        min_num_params. The default_auto_wrap_policy excludes {nn.ModuleList, nn.ModuleDict} from wrapping.
        """
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
            sequential = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])
            my_auto_wrap_policy = functools.partial(default_auto_wrap_policy,
                                                    min_num_params=40)
            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
        self.assertTrue(isinstance(model, nn.ModuleList))
        self.assertTrue(isinstance(model[0], nn.Linear))
        self.assertTrue(isinstance(model[1], nn.Linear))

    def test_auto_wrap_preset_exclude_wrap_include_children(self):
        """
        Test to ensure excluded modules are not wrapped, but their children are if the param size exceeds
        min_num_params.
        """
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
            sequential = nn.ModuleList([nn.Linear(10, 10)])
            my_auto_wrap_policy = functools.partial(default_auto_wrap_policy,
                                                    min_num_params=40)
            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
        self.assertTrue(isinstance(model, nn.ModuleList))
        self.assertTrue(isinstance(model[0], FSDP))

    def test_auto_wrap_preset_force_leaf(self):
        """
        Test to ensure force-leaf modules and their children are not wrapped. The
        default_auto_wrap_policy treats modules of type {nn.MultiheadAttention} as leaves that must not be wrapped.
        """
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
            sequential = nn.Sequential(nn.Linear(10, 10),
                                       nn.MultiheadAttention(100, 1))
            my_auto_wrap_policy = functools.partial(default_auto_wrap_policy,
                                                    min_num_params=40)
            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
        self.assertTrue(isinstance(model.module[0], FSDP))
        # Assert children of multihead attention are not wrapped
        self.assertTrue(isinstance(model.module[1], nn.MultiheadAttention))
        self.assertTrue(isinstance(model.module[1].out_proj, nn.Linear))

    def test_auto_wrap_preset_force_leaf_custom(self):
        """
        Test to ensure force-leaf modules are not wrapped.
        """
        my_auto_wrap_policy = functools.partial(
            default_auto_wrap_policy,
            min_num_params=40,
            force_leaf_modules=default_auto_wrap_policy.FORCE_LEAF_MODULES.union(
                {nn.Linear}
            ),
        )
        with enable_wrap(
                auto_wrap_policy=my_auto_wrap_policy,
                wrapper_cls=FSDP,
                process_group=self.process_group,
        ):
            sequential = nn.Sequential(nn.Linear(10, 10),
                                       nn.ModuleList([nn.Linear(10, 10)]))
            model = auto_wrap(sequential)
        # Model was wrapped in FSDP as no inner modules were wrapped.
        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model.module[0], nn.Linear))
        self.assertTrue(isinstance(model.module[1], nn.ModuleList))

    @unittest.skipIf(not torch.cuda.is_available(), "Test Requires CUDA")
    def test_auto_wrap_smoke_test(self):
        device = torch.device("cuda")
        torch.cuda.set_device(0)

        # Use a random port in case the next test runs quickly; reusing the same port would cause a conflict.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(find_free_port())
        torch.distributed.init_process_group(backend="nccl",
                                             rank=0,
                                             world_size=1)

        try:
            with enable_wrap(wrapper_cls=FSDP):
                sequential = nn.Sequential(
                    nn.Linear(5, 5), nn.Linear(5, 5),
                    nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)))
                my_auto_wrap_policy = functools.partial(
                    default_auto_wrap_policy, min_num_params=40)
                model = auto_wrap(sequential,
                                  auto_wrap_policy=my_auto_wrap_policy)
            model.to(device)
            input = torch.rand((1, 5), dtype=torch.float).to(device)
            output = model(input)
            loss = F.mse_loss(input, output)
            loss.backward()
        finally:
            torch.distributed.destroy_process_group()
            del os.environ["MASTER_ADDR"]
            del os.environ["MASTER_PORT"]