@contextlib.contextmanager  # consumed via a `with` block; wrap() calls inside see this config
def fsdp_enable_wrap(cfg: DistributedTrainingConfig, use_sharded_state: bool = False):
    try:
        from fairscale.nn import enable_wrap
    except ImportError:
        raise ImportError(
            "Cannot find FullyShardedDataParallel. "
            "Please install fairscale with: pip install fairscale"
        )
    if cfg.memory_efficient_fp16:
        assert cfg.fp16  # memory_efficient_fp16 should imply fp16
    group = dist_utils.get_data_parallel_group()
    if group is None and cfg.distributed_world_size == 1:
        from fairscale.utils.testing import DummyProcessGroup

        group = DummyProcessGroup(rank=0, size=1)
    fsdp_config = {
        "process_group": group,
        "reshard_after_forward": not cfg.no_reshard_after_forward,
        "mixed_precision": cfg.fp16 and not cfg.memory_efficient_fp16,
        "fp32_reduce_scatter": cfg.fp32_reduce_scatter,
        "flatten_parameters": True,
        "cpu_offload": cfg.cpu_offload,
        "compute_dtype": torch.float16 if cfg.fp16 else torch.float32,
        "bucket_cap_mb": cfg.bucket_cap_mb,
    }
    with enable_wrap(use_sharded_state=use_sharded_state, **fsdp_config):
        yield

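# Hedged usage sketch (assumption, not part of the snippet above): with the
# contextmanager decorator in place, any module passed to fairscale's `wrap`
# inside the block picks up the fsdp_config built above, while modules created
# outside the block stay unwrapped. `build_sharded_model` and its `cfg` argument
# (a populated DistributedTrainingConfig) are hypothetical, for illustration only.
import torch.nn as nn
from fairscale.nn import wrap


def build_sharded_model(cfg):
    with fsdp_enable_wrap(cfg, use_sharded_state=False):
        # The inner wrap() shards a submodule on its own; the outer wrap()
        # shards the remaining parameters of the enclosing Sequential.
        return wrap(nn.Sequential(nn.Linear(16, 16), nn.ReLU(), wrap(nn.Linear(16, 16))))
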
def model_sharded_context(self) -> Generator:
    log.detail(f"{self.__class__.__name__}: entered model_sharded_context.")
    precision = self.precision_plugin.precision

    def wrap_policy(*args, **kwargs):
        return default_auto_wrap_policy(*args, **kwargs, min_num_params=self.min_num_params)

    with enable_wrap(
        wrapper_cls=FullyShardedDataParallel,
        auto_wrap_policy=wrap_policy,
        process_group=self.process_group,
        cpu_offload=self.cpu_offload,
        move_grads_to_cpu=self.move_grads_to_cpu,
        flatten_parameters=self.flatten_parameters,
        mixed_precision=(precision in (PrecisionType.MIXED, PrecisionType.HALF)),
        reshard_after_forward=self.reshard_after_forward,
        fp32_reduce_scatter=self.fp32_reduce_scatter,
        compute_dtype=self.compute_dtype,
        bucket_cap_mb=self.bucket_cap_mb,
        state_dict_device=self.state_dict_device,
    ):
        yield

    log.detail(f"{self.__class__.__name__}: exiting model_sharded_context.")

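# Hedged usage sketch (assumption, not from the strategy code above): the strategy
# is expected to enter model_sharded_context around
# LightningModule.configure_sharded_model, so user code can call fairscale's `wrap`
# there and inherit the enable_wrap settings. `MyModel` is a hypothetical module
# shown only to illustrate that pattern.
import pytorch_lightning as pl
import torch.nn as nn
from fairscale.nn import wrap


class MyModel(pl.LightningModule):
    def configure_sharded_model(self) -> None:
        # wrap() only has an effect while the strategy's enable_wrap context is
        # active; outside of it the module is returned unchanged.
        self.block = wrap(nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2)))
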
def _auto_wrap_smoke_test(self, enable_mixed_precision):
    device = torch.device("cuda")
    torch.cuda.set_device(0)

    # Use a random port in case the next test runs quickly; reusing the same port would cause a conflict.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(random.randint(2000, 3000))
    torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

    try:
        with enable_wrap(wrapper_cls=FSDP, mixed_precision=enable_mixed_precision):
            sequential = nn.Sequential(
                nn.Linear(5, 5), nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
            )
            my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
        model.to(device)
        input = torch.rand((1, 5), dtype=torch.float).to(device)

        with autocast(enabled=enable_mixed_precision):
            output = model(input)
            loss = F.mse_loss(input, output)
            loss.backward()
    finally:
        torch.distributed.destroy_process_group()
        del os.environ["MASTER_ADDR"]
        del os.environ["MASTER_PORT"]

def auto_wrap_big_layers(module: nn.Module, fsdp_config: AttrDict):
    """
    Automatically wrap the bigger layers in the module.
    """
    with enable_wrap(
        auto_wrap_policy=_BigConvAutoWrapPolicy(fsdp_config.AUTO_WRAP_THRESHOLD),
        wrapper_cls=_FSDP_WRAPPER,
        **fsdp_config,
    ):
        return auto_wrap(module)

def _create_model(fsdp_config, compute_cycles, has_params: bool):
    with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
        model = wrap(
            nn.Sequential(
                wrap(Layer(compute_cycles, has_params)),
                wrap(Layer(compute_cycles, has_params)),
                wrap(Layer(compute_cycles, has_params)),
                wrap(Layer(compute_cycles, has_params)),
            )
        ).cuda()
    return model

def test_auto_wrap_preset_exclude_wrap_include_children(self):
    """
    Test to ensure excluded modules are not wrapped, but children are if param size is greater
    than min_num_params.
    """
    with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
        sequential = nn.ModuleList([nn.Linear(10, 10)])
        my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
        model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)

    assert isinstance(model, nn.ModuleList)
    assert isinstance(model[0], FSDP)

def test_auto_wrap_preset_force_leaf(self):
    """
    Test to ensure force-leaf modules are not wrapped, and children are not wrapped.
    """
    with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
        sequential = nn.Sequential(nn.Linear(10, 10), nn.MultiheadAttention(100, 1))
        my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
        model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)

    assert isinstance(model.module[0], FSDP)
    # Assert children of multihead attention are not wrapped
    assert isinstance(model.module[1], nn.MultiheadAttention)
    assert isinstance(model.module[1].out_proj, nn.Linear)

def test_auto_wrap_preset_exclude_wrap(self):
    """
    Test to ensure excluded modules are not wrapped, regardless of whether the total param size
    is greater than the min_num_params.
    """
    with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
        sequential = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])
        my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
        model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)

    assert isinstance(model, nn.ModuleList)
    assert isinstance(model[0], nn.Linear)
    assert isinstance(model[1], nn.Linear)

def test_auto_wrap_preset_blocklist(self):
    """
    Test to ensure blocklisted modules are not wrapped, and children are not wrapped.
    """
    with enable_wrap(process_group=self.process_group, flatten_parameters=False):
        sequential = nn.Sequential(nn.Linear(10, 10), nn.MultiheadAttention(100, 1))
        model = auto_wrap(sequential, min_num_params=40)

    assert isinstance(model.module[0], FSDP)
    # Assert children of multihead attention are not wrapped
    assert isinstance(model.module[1], nn.MultiheadAttention)
    assert isinstance(model.module[1].out_proj, nn.Linear)

def test_auto_wrap_preset_blocklist_custom(self):
    """
    Test to ensure blocklisted modules are not wrapped.
    """
    with enable_wrap(module_blocklist=[nn.Linear], process_group=self.process_group, flatten_parameters=False):
        sequential = nn.Sequential(nn.Linear(10, 10), nn.ModuleList([nn.Linear(10, 10)]))
        model = auto_wrap(sequential, min_num_params=40)

    # Model was wrapped in FSDP as no inner modules were wrapped.
    assert isinstance(model, FSDP)
    assert isinstance(model.module[0], nn.Linear)
    assert isinstance(model.module[1], nn.ModuleList)

def test_auto_wrap(self):
    """
    Test to ensure with auto wrap, we wrap child modules correctly based on the min_num_params.
    A single ``nn.Linear(5, 5)`` (30 parameters) does not exceed ``min_num_params=40``, but two
    combined do.
    """
    with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
        sequential = nn.Sequential(
            nn.Linear(5, 5), nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
        )
        my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
        model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)

    assert isinstance(model, FSDP)
    assert isinstance(model.module[0], nn.Linear)
    assert isinstance(model.module[1], nn.Linear)
    assert isinstance(model.module[2], FSDP)
    assert isinstance(model.module[2].module[0], nn.Linear)
    assert isinstance(model.module[2].module[1], nn.Linear)

def model_sharded_context(self) -> Generator:
    precision = self.lightning_module.trainer.precision

    def wrap_policy(*args, **kwargs):
        return default_auto_wrap_policy(*args, **kwargs, min_num_params=self.min_num_params)

    with enable_wrap(
        wrapper_cls=FullyShardedDataParallel,
        auto_wrap_policy=wrap_policy,
        process_group=self.process_group,
        cpu_offload=self.cpu_offload,
        move_grads_to_cpu=self.move_grads_to_cpu,
        flatten_parameters=self.flatten_parameters,
        mixed_precision=(precision == "mixed"),
        reshard_after_forward=self.reshard_after_forward,
        fp32_reduce_scatter=self.fp32_reduce_scatter,
        compute_dtype=self.compute_dtype,
        bucket_cap_mb=self.bucket_cap_mb,
        state_dict_device=self.state_dict_device,
    ):
        yield

def test_auto_wrap_preset_force_leaf_custom(self):
    """
    Test to ensure force-leaf modules are not wrapped.
    """
    my_auto_wrap_policy = functools.partial(
        default_auto_wrap_policy,
        min_num_params=40,
        force_leaf_modules=default_auto_wrap_policy.FORCE_LEAF_MODULES.union({nn.Linear}),
    )
    with enable_wrap(
        auto_wrap_policy=my_auto_wrap_policy,
        wrapper_cls=FSDP,
        process_group=self.process_group,
        flatten_parameters=False,
    ):
        sequential = nn.Sequential(nn.Linear(10, 10), nn.ModuleList([nn.Linear(10, 10)]))
        model = auto_wrap(sequential)

    # Model was wrapped in FSDP as no inner modules were wrapped.
    assert isinstance(model, FSDP)
    assert isinstance(model.module[0], nn.Linear)
    assert isinstance(model.module[1], nn.ModuleList)

def _auto_wrap_smoke_test(self, enable_mixed_precision):
    from torch.cuda.amp import autocast

    device = torch.device("cuda")
    torch.cuda.set_device(0)
    torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

    with enable_wrap(mixed_precision=enable_mixed_precision):
        sequential = nn.Sequential(
            nn.Linear(5, 5), nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
        )
        model = auto_wrap(sequential, min_num_params=40)
    model.to(device)
    input = torch.rand((1, 5), dtype=torch.float).to(device)

    with autocast(enabled=enable_mixed_precision):
        output = model(input)
        loss = F.mse_loss(input, output)
        loss.backward()

    torch.distributed.destroy_process_group()

def test_wrap_override_defaults(self):
    with enable_wrap(wrapper_cls=FSDP, flatten_parameters=False, process_group=self.process_group):
        layer = wrap(nn.Linear(5, 5), flatten_parameters=True)
    assert isinstance(layer, FSDP)
    assert layer.flatten_parameters

def test_wrap(self):
    with enable_wrap(flatten_parameters=False, process_group=self.process_group):
        layer = wrap(nn.Linear(5, 5))
    assert isinstance(layer, FSDP)
    assert layer.flatten_parameters is False