def configure_sharded_model(self) -> None:
    # the model is already wrapped with FSDP: no need to wrap again!
    if isinstance(self.layer, FullyShardedDataParallel):
        return
    for i, layer in enumerate(self.layer):
        if i % 2 == 0:
            self.layer[i] = wrap(layer)
    self.layer = wrap(self.layer)

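# A minimal sketch of the assumed surrounding context (not shown in the source):
# configure_sharded_model above is the PyTorch Lightning hook that the FSDP
# strategy calls inside an enable_wrap context, so `self.layer` is presumed to
# be an nn.Sequential built earlier, e.g.:
import pytorch_lightning as pl
import torch.nn as nn


class DemoModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # plain (unwrapped) layers; wrapping is deferred to
        # configure_sharded_model, where every other layer gets its own
        # FSDP instance before the whole stack is wrapped once more
        self.layer = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))
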
def _create_model(fsdp_config, compute_cycles, has_params: bool):
    with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
        model = wrap(
            nn.Sequential(
                wrap(Layer(compute_cycles, has_params)),
                wrap(Layer(compute_cycles, has_params)),
                wrap(Layer(compute_cycles, has_params)),
                wrap(Layer(compute_cycles, has_params)),
            )
        ).cuda()
    return model

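# Hypothetical call site for the helper above; the config keys mirror fairscale
# FSDP constructor arguments (illustrative values only), and Layer is assumed
# to be a toy module whose forward simulates `compute_cycles` of work.
fsdp_config = {"flatten_parameters": True, "mixed_precision": False}
model = _create_model(fsdp_config, compute_cycles=0, has_params=True)
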
def fsdp_wrap(module, min_num_params: Optional[int] = None, **kwargs):
    """
    Helper to wrap layers/modules in FSDP. This falls back to a no-op if
    fairscale is not available.

    Args:
        module (nn.Module): module to (maybe) wrap
        min_num_params (int, Optional): minimum number of layer params to wrap
    """
    try:
        from fairscale.nn import wrap

        if min_num_params is not None:
            num_params = sum(p.numel() for p in module.parameters())
            if num_params >= min_num_params:
                return wrap(module, **kwargs)
            else:
                return module
        else:
            return wrap(module, **kwargs)
    except ImportError:
        return module

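# A minimal usage sketch for fsdp_wrap: layers below the size threshold are
# returned unchanged, larger ones get their own FSDP instance. The layer sizes
# and the 1_000_000-parameter threshold are arbitrary illustrations. Note that
# fairscale's wrap() only actually wraps inside an enable_wrap context (see
# test_wrap_disabled_outside_context below), so this is assumed to run under
# `with enable_wrap(wrapper_cls=FSDP, ...)`.
import torch.nn as nn

blocks = [nn.Linear(512, 512) for _ in range(4)]
blocks = [fsdp_wrap(block, min_num_params=1_000_000) for block in blocks]
model = fsdp_wrap(nn.Sequential(*blocks))
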
def test_wrap_override_defaults(self):
    with enable_wrap(wrapper_cls=FSDP, flatten_parameters=False, process_group=self.process_group):
        layer = wrap(nn.Linear(5, 5), flatten_parameters=True)
    assert isinstance(layer, FSDP)
    assert layer.flatten_parameters

def test_wrap_disabled_outside_context(self):
    layer = wrap(nn.Linear(5, 5))
    assert isinstance(layer, nn.Linear)

def configure_sharded_model(self) -> None:
    for i, layer in enumerate(self.layer):
        if i % 2 == 0:
            self.layer[i] = wrap(layer)
    self.layer = wrap(self.layer)

def test_wrap(self):
    # wrapper_cls=FSDP added for consistency with test_wrap_override_defaults
    # above; enable_wrap needs to know which wrapper class to apply
    with enable_wrap(wrapper_cls=FSDP, flatten_parameters=False, process_group=self.process_group):
        layer = wrap(nn.Linear(5, 5))
    assert isinstance(layer, FSDP)
    assert layer.flatten_parameters is False