Example #1
@contextlib.contextmanager
def fsdp_enable_wrap(cfg: DistributedTrainingConfig,
                     use_sharded_state: bool = False):
    try:
        from fairscale.nn import enable_wrap
    except ImportError:
        raise ImportError(
            "Cannot find FullyShardedDataParallel. "
            "Please install fairscale with: pip install fairscale")
    if cfg.memory_efficient_fp16:
        assert cfg.fp16  # memory_efficient_fp16 should imply fp16
    group = dist_utils.get_data_parallel_group()
    if group is None and cfg.distributed_world_size == 1:
        from fairscale.utils.testing import DummyProcessGroup
        group = DummyProcessGroup(rank=0, size=1)
    fsdp_config = {
        "process_group": group,
        "reshard_after_forward": not cfg.no_reshard_after_forward,
        "mixed_precision": cfg.fp16 and not cfg.memory_efficient_fp16,
        "fp32_reduce_scatter": cfg.fp32_reduce_scatter,
        "flatten_parameters": True,
        "cpu_offload": cfg.cpu_offload,
        "compute_dtype": torch.float16 if cfg.fp16 else torch.float32,
        "bucket_cap_mb": cfg.bucket_cap_mb,
    }
    with enable_wrap(use_sharded_state=use_sharded_state, **fsdp_config):
        yield
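A minimal usage sketch (the `cfg` object and `build_model` factory are hypothetical, not part of the snippet above): inside the context, fairscale's `wrap` applies FullyShardedDataParallel using the defaults configured here.

from fairscale.nn import wrap

with fsdp_enable_wrap(cfg, use_sharded_state=False):
    # wrap() picks up process_group, mixed_precision, etc. from enable_wrap above.
    model = wrap(build_model(cfg))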
Example #2
    @contextmanager
    def model_sharded_context(self) -> Generator:
        log.detail(
            f"{self.__class__.__name__}: entered model_sharded_context.")
        precision = self.precision_plugin.precision

        def wrap_policy(*args, **kwargs):
            return default_auto_wrap_policy(*args,
                                            **kwargs,
                                            min_num_params=self.min_num_params)

        with enable_wrap(
                wrapper_cls=FullyShardedDataParallel,
                auto_wrap_policy=wrap_policy,
                process_group=self.process_group,
                cpu_offload=self.cpu_offload,
                move_grads_to_cpu=self.move_grads_to_cpu,
                flatten_parameters=self.flatten_parameters,
                mixed_precision=(precision
                                 in (PrecisionType.MIXED, PrecisionType.HALF)),
                reshard_after_forward=self.reshard_after_forward,
                fp32_reduce_scatter=self.fp32_reduce_scatter,
                compute_dtype=self.compute_dtype,
                bucket_cap_mb=self.bucket_cap_mb,
                state_dict_device=self.state_dict_device,
        ):
            yield

        log.detail(
            f"{self.__class__.__name__}: exiting model_sharded_context.")
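For reference, Lightning typically enters this context around the LightningModule's `configure_sharded_model` hook, so layers wrapped there inherit the `enable_wrap` defaults; a rough sketch with illustrative names (not taken from the snippet above):

from fairscale.nn import wrap

class MyModel(LightningModule):
    def configure_sharded_model(self):
        # Runs while model_sharded_context() is active, so wrap() uses the
        # plugin's enable_wrap settings (cpu_offload, mixed_precision, ...).
        self.layer = wrap(nn.Linear(32, 32))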
Example #3
    def _auto_wrap_smoke_test(self, enable_mixed_precision):
        device = torch.device("cuda")
        torch.cuda.set_device(0)

        # Use a random port in case the next test runs quickly; reusing the same port would cause a conflict.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(random.randint(2000, 3000))
        torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

        try:
            with enable_wrap(wrapper_cls=FSDP, mixed_precision=enable_mixed_precision):
                sequential = nn.Sequential(
                    nn.Linear(5, 5), nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
                )
                my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
                model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
            model.to(device)
            input = torch.rand((1, 5), dtype=torch.float).to(device)

            with autocast(enabled=enable_mixed_precision):
                output = model(input)
                loss = F.mse_loss(input, output)
            loss.backward()
        finally:
            torch.distributed.destroy_process_group()
            del os.environ["MASTER_ADDR"]
            del os.environ["MASTER_PORT"]
Example #4
def auto_wrap_big_layers(module: nn.Module, fsdp_config: AttrDict):
    """
    Automatically wrap the bigger layers in the module.
    """
    with enable_wrap(auto_wrap_policy=_BigConvAutoWrapPolicy(
            fsdp_config.AUTO_WRAP_THRESHOLD),
                     wrapper_cls=_FSDP_WRAPPER,
                     **fsdp_config):
        return auto_wrap(module)
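`_BigConvAutoWrapPolicy` and `_FSDP_WRAPPER` are project-specific names from vissl; below is a hedged sketch of what such a parameter-count policy could look like, assuming fairscale's `(module, recurse, unwrapped_params)` policy call signature (the actual implementation may differ).

class SizeThresholdAutoWrapPolicy:
    """Hypothetical stand-in for _BigConvAutoWrapPolicy."""

    def __init__(self, threshold: int):
        self.threshold = threshold

    def __call__(self, module, recurse, unwrapped_params):
        if recurse:
            # Always keep traversing into children.
            return True
        # Wrap a module once its unwrapped parameter count reaches the threshold.
        return unwrapped_params >= self.threshold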
Example #5
def _create_model(fsdp_config, compute_cycles, has_params: bool):
    with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
        model = wrap(
            nn.Sequential(
                wrap(Layer(compute_cycles, has_params)),
                wrap(Layer(compute_cycles, has_params)),
                wrap(Layer(compute_cycles, has_params)),
                wrap(Layer(compute_cycles, has_params)),
            )).cuda()
    return model
Example #6
 def test_auto_wrap_preset_exclude_wrap_include_children(self):
     """
     Test to ensure excluded modules are not wrapped, but their children are wrapped if their param size
     exceeds min_num_params.
     """
     with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
         sequential = nn.ModuleList([nn.Linear(10, 10)])
         my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
         model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
     assert isinstance(model, nn.ModuleList)
     assert isinstance(model[0], FSDP)
Example #7
 def test_auto_wrap_preset_force_leaf(self):
     """
     Test to ensure force-leaf modules are not wrapped, and children are not wrapped.
     """
     with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
         sequential = nn.Sequential(nn.Linear(10, 10), nn.MultiheadAttention(100, 1))
         my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
         model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
     assert isinstance(model.module[0], FSDP)
     # Assert children of multihead attention are not wrapped
     assert isinstance(model.module[1], nn.MultiheadAttention)
     assert isinstance(model.module[1].out_proj, nn.Linear)
Example #8
 def test_auto_wrap_preset_exclude_wrap(self):
     """
     Test to ensure excluded modules are not wrapped, regardless of whether the total param size is greater
     than min_num_params.
     """
     with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
         sequential = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])
         my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
         model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
     assert isinstance(model, nn.ModuleList)
     assert isinstance(model[0], nn.Linear)
     assert isinstance(model[1], nn.Linear)
Example #9
 def test_auto_wrap_preset_blocklist(self):
     """
     Test to ensure blocklisted modules are not wrapped, and children are not wrapped.
     """
     with enable_wrap(process_group=self.process_group,
                      flatten_parameters=False):
         sequential = nn.Sequential(nn.Linear(10, 10),
                                    nn.MultiheadAttention(100, 1))
         model = auto_wrap(sequential, min_num_params=40)
     assert isinstance(model.module[0], FSDP)
     # Assert children of multihead attention are not wrapped
     assert isinstance(model.module[1], nn.MultiheadAttention)
     assert isinstance(model.module[1].out_proj, nn.Linear)
Example #10
 def test_auto_wrap_preset_blocklist_custom(self):
     """
     Test to ensure blocklisted modules are not wrapped.
     """
     with enable_wrap(module_blocklist=[nn.Linear],
                      process_group=self.process_group,
                      flatten_parameters=False):
         sequential = nn.Sequential(nn.Linear(10, 10),
                                    nn.ModuleList([nn.Linear(10, 10)]))
         model = auto_wrap(sequential, min_num_params=40)
     # Model was wrapped in FSDP as no inner modules were wrapped.
     assert isinstance(model, FSDP)
     assert isinstance(model.module[0], nn.Linear)
     assert isinstance(model.module[1], nn.ModuleList)
Example #11
 def test_auto_wrap(self):
     """
     Test to ensure that, with auto wrap, child modules are wrapped correctly based on min_num_params:
     a single ``nn.Linear(5, 5)`` does not exceed ``min_num_params``, but the two linears in the inner
     ``nn.Sequential`` combined do.
     """
     with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
         sequential = nn.Sequential(
             nn.Linear(5, 5), nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
         )
         my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
         model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
     assert isinstance(model, FSDP)
     assert isinstance(model.module[0], nn.Linear)
     assert isinstance(model.module[1], nn.Linear)
     assert isinstance(model.module[2], FSDP)
     assert isinstance(model.module[2].module[0], nn.Linear)
     assert isinstance(model.module[2].module[1], nn.Linear)
Example #12
    @contextmanager
    def model_sharded_context(self) -> Generator:
        precision = self.lightning_module.trainer.precision

        def wrap_policy(*args, **kwargs):
            return default_auto_wrap_policy(*args, **kwargs, min_num_params=self.min_num_params)

        with enable_wrap(
            wrapper_cls=FullyShardedDataParallel,
            auto_wrap_policy=wrap_policy,
            process_group=self.process_group,
            cpu_offload=self.cpu_offload,
            move_grads_to_cpu=self.move_grads_to_cpu,
            flatten_parameters=self.flatten_parameters,
            mixed_precision=precision == "mixed",
            reshard_after_forward=self.reshard_after_forward,
            fp32_reduce_scatter=self.fp32_reduce_scatter,
            compute_dtype=self.compute_dtype,
            bucket_cap_mb=self.bucket_cap_mb,
            state_dict_device=self.state_dict_device,
        ):
            yield
Example #13
 def test_auto_wrap_preset_force_leaf_custom(self):
     """
     Test to ensure force-leaf modules are not wrapped.
     """
     my_auto_wrap_policy = functools.partial(
         default_auto_wrap_policy,
         min_num_params=40,
         force_leaf_modules=default_auto_wrap_policy.FORCE_LEAF_MODULES.union({nn.Linear}),
     )
     with enable_wrap(
         auto_wrap_policy=my_auto_wrap_policy,
         wrapper_cls=FSDP,
         process_group=self.process_group,
         flatten_parameters=False,
     ):
         sequential = nn.Sequential(nn.Linear(10, 10), nn.ModuleList([nn.Linear(10, 10)]))
         model = auto_wrap(sequential)
     # Model was wrapped in FSDP as no inner modules were wrapped.
     assert isinstance(model, FSDP)
     assert isinstance(model.module[0], nn.Linear)
     assert isinstance(model.module[1], nn.ModuleList)
Example #14
    def _auto_wrap_smoke_test(self, enable_mixed_precision):
        from torch.cuda.amp import autocast

        device = torch.device("cuda")
        torch.cuda.set_device(0)
        torch.distributed.init_process_group(backend="nccl",
                                             rank=0,
                                             world_size=1)

        with enable_wrap(mixed_precision=enable_mixed_precision):
            sequential = nn.Sequential(
                nn.Linear(5, 5), nn.Linear(5, 5),
                nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)))
            model = auto_wrap(sequential, min_num_params=40)
        model.to(device)
        input = torch.rand((1, 5), dtype=torch.float).to(device)

        with autocast(enabled=enable_mixed_precision):
            output = model(input)
            loss = F.mse_loss(input, output)
        loss.backward()
        torch.distributed.destroy_process_group()
Example #15
 def test_wrap_override_defaults(self):
     with enable_wrap(wrapper_cls=FSDP, flatten_parameters=False, process_group=self.process_group):
         layer = wrap(nn.Linear(5, 5), flatten_parameters=True)
     assert isinstance(layer, FSDP)
     assert layer.flatten_parameters
Example #16
 def test_wrap(self):
     with enable_wrap(flatten_parameters=False,
                      process_group=self.process_group):
         layer = wrap(nn.Linear(5, 5))
     assert isinstance(layer, FSDP)
     assert layer.flatten_parameters is False
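Worth noting, per fairscale's documented behavior (not shown in the tests above): outside an `enable_wrap` context, `wrap()` returns the module unchanged, which lets the same model code run with or without FSDP.

# Outside any enable_wrap(...) block, wrap() is a no-op.
layer = wrap(nn.Linear(5, 5))
assert isinstance(layer, nn.Linear)
assert not isinstance(layer, FSDP)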