Example #1
    def test_auto_wrap_smoke_test(self):
        device = torch.device("cuda")
        torch.cuda.set_device(0)

        # Use a random port in case the next test runs quickly; reusing the same port would cause a conflict.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(find_free_port())
        torch.distributed.init_process_group(backend="nccl",
                                             rank=0,
                                             world_size=1)

        try:
            with enable_wrap(wrapper_cls=FSDP):
                sequential = nn.Sequential(
                    nn.Linear(5, 5), nn.Linear(5, 5),
                    nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)))
                my_auto_wrap_policy = functools.partial(
                    default_auto_wrap_policy, min_num_params=40)
                model = auto_wrap(sequential,
                                  auto_wrap_policy=my_auto_wrap_policy)
            model.to(device)
            input = torch.rand((1, 5), dtype=torch.float).to(device)
            output = model(input)
            loss = F.mse_loss(input, output)
            loss.backward()
        finally:
            torch.distributed.destroy_process_group()
            del os.environ["MASTER_ADDR"]
            del os.environ["MASTER_PORT"]
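
A quick sanity check on the parameter math behind min_num_params=40 in this example (a standalone sketch, not part of the test suite): each nn.Linear(5, 5) holds 5 * 5 weights plus 5 biases, i.e. 30 parameters, so a single layer stays below the threshold while the inner nn.Sequential of two such layers (60 parameters) exceeds it and gets wrapped.

    import torch.nn as nn

    layer = nn.Linear(5, 5)
    n_params = sum(p.numel() for p in layer.parameters())
    print(n_params)      # 30 -> below min_num_params=40, left unwrapped on its own
    print(2 * n_params)  # 60 -> the inner nn.Sequential crosses the threshold and is wrapped
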
Example #2
 def test_auto_wrap_preset_force_leaf_custom(self, wrap_method):
     """
     Test to ensure force-leaf modules are not wrapped.
     """
     my_auto_wrap_policy = functools.partial(
         default_auto_wrap_policy,
         min_num_params=40,
         force_leaf_modules=default_auto_wrap_policy.FORCE_LEAF_MODULES.
         union({nn.Linear}),
     )
     sequential = nn.Sequential(nn.Linear(10, 10),
                                nn.ModuleList([nn.Linear(10, 10)]))
     if wrap_method == WrapMethod.WRAP_API:
         with enable_wrap(
                 auto_wrap_policy=my_auto_wrap_policy,
                 wrapper_cls=FSDP,
                 process_group=self.process_group,
         ):
             model = auto_wrap(sequential)
     else:
         assert wrap_method == WrapMethod.FSDP_CTOR
         model = FSDP(sequential,
                      process_group=self.process_group,
                      fsdp_auto_wrap_policy=my_auto_wrap_policy)
     # Model was wrapped in FSDP as no inner modules were wrapped.
     self.assertTrue(isinstance(model, FSDP))
     self.assertTrue(isinstance(model.module[0], nn.Linear))
     self.assertTrue(isinstance(model.module[1], nn.ModuleList))
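
Example #2 extends default_auto_wrap_policy.FORCE_LEAF_MODULES with nn.Linear, so no inner module qualifies for wrapping and only the root becomes an FSDP instance. Below is a small sketch for inspecting the defaults before extending them; the import path and the EXCLUDE_WRAP_MODULES attribute name are assumptions based on the wrap utilities used in these tests.

    import functools
    import torch.nn as nn
    # Import path varies by PyTorch/fairscale version; adjust as needed (assumption).
    from torch.distributed.fsdp.wrap import default_auto_wrap_policy

    # Defaults, per the docstrings in these examples (EXCLUDE_WRAP_MODULES is an assumed attribute name):
    print(default_auto_wrap_policy.FORCE_LEAF_MODULES)    # {nn.MultiheadAttention}
    print(default_auto_wrap_policy.EXCLUDE_WRAP_MODULES)  # {nn.ModuleList, nn.ModuleDict}

    my_auto_wrap_policy = functools.partial(
        default_auto_wrap_policy,
        min_num_params=40,
        force_leaf_modules=default_auto_wrap_policy.FORCE_LEAF_MODULES.union({nn.Linear}),
    )
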
Example #3
    def test_auto_wrap_preset_exclude_wrap(self, wrap_method):
        """
        Test to ensure excluded modules are not wrapped, regardless of whether the total param size is greater
        than min_num_params. The default_auto_wrap_policy excludes wrapping for {nn.ModuleList, nn.ModuleDict}.
        """
        sequential = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])
        my_auto_wrap_policy = functools.partial(default_auto_wrap_policy,
                                                min_num_params=40)

        if wrap_method == WrapMethod.WRAP_API:
            with enable_wrap(wrapper_cls=FSDP,
                             process_group=self.process_group):
                model = auto_wrap(sequential,
                                  auto_wrap_policy=my_auto_wrap_policy)
        else:
            assert wrap_method == WrapMethod.FSDP_CTOR
            model = FSDP(sequential,
                         process_group=self.process_group,
                         fsdp_auto_wrap_policy=my_auto_wrap_policy)

        # Note that the outermost module will be an FSDP instance for the
        # FSDP_CTOR approach, because we call the FSDP ctor directly, so the
        # returned object will be an FSDP instance. If we don't want to shard
        # the outermost module based on the policy, we can apply the policy
        # manually to the outermost instance and skip the sharding.
        if wrap_method == WrapMethod.WRAP_API:
            self.assertTrue(isinstance(model, nn.ModuleList))
        else:
            self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model[0], nn.Linear))
        self.assertTrue(isinstance(model[1], nn.Linear))
Example #4
 def test_auto_wrap_preset_exclude_wrap_include_children(self):
     """
     Test to ensure excluded modules are not wrapped, but their children are wrapped if their param size is
     greater than min_num_params.
     """
     with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
         sequential = nn.ModuleList([nn.Linear(10, 10)])
         my_auto_wrap_policy = functools.partial(default_auto_wrap_policy,
                                                 min_num_params=40)
         model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
     self.assertTrue(isinstance(model, nn.ModuleList))
     self.assertTrue(isinstance(model[0], FSDP))
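
The contrast with Example #5 comes down to parameter counts (a standalone sketch, not part of the test): nn.Linear(10, 10) alone carries 110 parameters, above min_num_params=40, so the child is wrapped even though its nn.ModuleList parent is excluded from wrapping.

    import torch.nn as nn

    # 10 * 10 weights + 10 biases = 110 parameters -> above the 40-parameter threshold,
    # so the child gets wrapped while the excluded nn.ModuleList parent does not.
    print(sum(p.numel() for p in nn.Linear(10, 10).parameters()))  # 110
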
Example #5
 def test_auto_wrap_preset_exclude_wrap(self):
     """
     Test to ensure excluded modules are not wrapped, regardless of whether the total param size is greater
     than min_num_params. The default_auto_wrap_policy excludes wrapping for {nn.ModuleList, nn.ModuleDict}.
     """
     with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
         sequential = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])
         my_auto_wrap_policy = functools.partial(default_auto_wrap_policy,
                                                 min_num_params=40)
         model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
     self.assertTrue(isinstance(model, nn.ModuleList))
     self.assertTrue(isinstance(model[0], nn.Linear))
     self.assertTrue(isinstance(model[1], nn.Linear))
Example #6
 def test_auto_wrap_preset_force_leaf(self):
     """
     Test to ensure force-leaf modules are not wrapped and that their children are not wrapped either. The
     default_auto_wrap_policy forces leaf modules of type {nn.MultiheadAttention} to not be wrapped.
     """
     with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
         sequential = nn.Sequential(nn.Linear(10, 10),
                                    nn.MultiheadAttention(100, 1))
         my_auto_wrap_policy = functools.partial(default_auto_wrap_policy,
                                                 min_num_params=40)
         model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
     self.assertTrue(isinstance(model.module[0], FSDP))
     # Assert children of multihead attention are not wrapped
     self.assertTrue(isinstance(model.module[1], nn.MultiheadAttention))
     self.assertTrue(isinstance(model.module[1].out_proj, nn.Linear))
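
For context on Example #6 (a standalone sketch, not part of the test): nn.MultiheadAttention(100, 1) has roughly 40k parameters, far above min_num_params=40, yet it stays unwrapped because nn.MultiheadAttention is in the default force-leaf set; the nn.Linear(10, 10) sibling, with 110 parameters, is wrapped instead.

    import torch.nn as nn

    mha = nn.MultiheadAttention(100, 1)
    # Roughly 40k parameters, well above min_num_params=40, yet left unwrapped
    # because nn.MultiheadAttention is a default force-leaf module.
    print(sum(p.numel() for p in mha.parameters()))
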
Example #7
 def test_auto_wrap(self):
     """
     Test to ensure that with auto wrap we wrap child modules correctly based on min_num_params.
     A single ``nn.Linear(5, 5)`` does not exceed the bucket size, but two of them combined do.
     """
     with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
         sequential = nn.Sequential(
             nn.Linear(5, 5), nn.Linear(5, 5),
             nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)))
         my_auto_wrap_policy = functools.partial(default_auto_wrap_policy,
                                                 min_num_params=40)
         model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
     self.assertTrue(isinstance(model, FSDP))
     self.assertTrue(isinstance(model.module[0], nn.Linear))
     self.assertTrue(isinstance(model.module[1], nn.Linear))
     self.assertTrue(isinstance(model.module[2], FSDP))
     self.assertTrue(isinstance(model.module[2].module[0], nn.Linear))
     self.assertTrue(isinstance(model.module[2].module[1], nn.Linear))
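
A small sketch (not part of the test) for visualizing which submodules ended up wrapped after auto_wrap, using nn.Module.named_modules on the model built above:

    for name, module in model.named_modules():
        if isinstance(module, FSDP):
            print(name or "<root>")  # prints the root FSDP wrapper and the wrapped inner nn.Sequential
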
Example #8
    def test_auto_wrap_smoke_test(self, wrap_method, fsdp_init_mode,
                                  cpu_offload):
        # CPU offload and moving the model to CUDA after init don't work together as expected.
        if (cpu_offload.offload_params
                and fsdp_init_mode == FSDPInitMode.CUDA_AFTER):
            return

        device = torch.device("cuda")
        torch.cuda.set_device(0)

        # Use a random port in case the next test runs quickly; reusing the same port would cause a conflict.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(find_free_port())
        torch.distributed.init_process_group(backend="nccl",
                                             rank=0,
                                             world_size=1)

        # NOTE: We move the model to CUDA after init with FSDP to simulate real
        # use cases where the full model cannot be loaded onto the GPU, but its
        # shards can.
        cuda_after_init = fsdp_init_mode == FSDPInitMode.CUDA_AFTER
        try:
            sequential = TestFSDPWrap.NestedSequentialModel.get_model(
                cuda=(not cuda_after_init))
            my_auto_wrap_policy = functools.partial(default_auto_wrap_policy,
                                                    min_num_params=40)
            if wrap_method == WrapMethod.WRAP_API:
                with enable_wrap(wrapper_cls=FSDP, cpu_offload=cpu_offload):
                    model = auto_wrap(sequential,
                                      auto_wrap_policy=my_auto_wrap_policy)
            else:
                model = FSDP(sequential,
                             cpu_offload=cpu_offload,
                             fsdp_auto_wrap_policy=my_auto_wrap_policy)
            TestFSDPWrap.NestedSequentialModel.verify_model(self, model)
            if cuda_after_init:
                model = model.cuda()
            input = torch.rand((1, 5), dtype=torch.float).to(device)
            output = model(input)
            loss = F.mse_loss(input, output)
            loss.backward()
        finally:
            torch.distributed.destroy_process_group()
            del os.environ["MASTER_ADDR"]
            del os.environ["MASTER_PORT"]
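
The cpu_offload argument in Example #8 is injected by the test parametrization. Below is a hedged sketch of how such a config might be constructed, assuming CPUOffload from torch.distributed.fsdp is the object carrying the offload_params flag checked at the top of the test:

    from torch.distributed.fsdp import CPUOffload  # assumed source of the cpu_offload config object

    cpu_offload = CPUOffload(offload_params=True)
    model = FSDP(sequential,
                 cpu_offload=cpu_offload,
                 fsdp_auto_wrap_policy=my_auto_wrap_policy)
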
Example #9
    def test_auto_wrap_foo(self, wrap_method):
        """
        Test to ensure that with auto wrap we wrap child modules correctly based on min_num_params.
        A single ``nn.Linear(5, 5)`` does not exceed the bucket size, but two of them combined do.
        """
        sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
        my_auto_wrap_policy = functools.partial(default_auto_wrap_policy,
                                                min_num_params=40)
        if wrap_method == WrapMethod.WRAP_API:
            with enable_wrap(wrapper_cls=FSDP,
                             process_group=self.process_group):
                model = auto_wrap(sequential,
                                  auto_wrap_policy=my_auto_wrap_policy)
        else:
            assert wrap_method == WrapMethod.FSDP_CTOR
            model = FSDP(sequential,
                         process_group=self.process_group,
                         fsdp_auto_wrap_policy=my_auto_wrap_policy)

        TestFSDPWrap.NestedSequentialModel.verify_model(self, model)
Example #10
    def test_auto_wrap_preset_exclude_wrap_include_children(self, wrap_method):
        """
        Test to ensure excluded modules are not wrapped, but their children are wrapped if their param size is
        greater than min_num_params.
        """
        sequential = nn.ModuleList([nn.Linear(10, 10)])
        my_auto_wrap_policy = functools.partial(default_auto_wrap_policy,
                                                min_num_params=40)
        if wrap_method == WrapMethod.WRAP_API:
            with enable_wrap(wrapper_cls=FSDP,
                             process_group=self.process_group):
                model = auto_wrap(sequential,
                                  auto_wrap_policy=my_auto_wrap_policy)
        else:
            assert wrap_method == WrapMethod.FSDP_CTOR
            model = FSDP(sequential,
                         process_group=self.process_group,
                         fsdp_auto_wrap_policy=my_auto_wrap_policy)

        if wrap_method == WrapMethod.WRAP_API:
            self.assertTrue(isinstance(model, nn.ModuleList))
        else:
            self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model[0], FSDP))