Example #1
def configure_sharded_model(self) -> None:
    # the layer is already wrapped with FSDP: no need to wrap again
    if isinstance(self.layer, FullyShardedDataParallel):
        return
    # wrap every other submodule individually, then wrap the container itself
    for i, layer in enumerate(self.layer):
        if i % 2 == 0:
            self.layer[i] = wrap(layer)
    self.layer = wrap(self.layer)
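For context, configure_sharded_model is the PyTorch Lightning hook that the fairscale-based sharded strategy invokes inside an enable_wrap() context, so the bare wrap() calls above actually apply FSDP there. A minimal sketch of a module this hook could operate on (the module and layer sizes are illustrative, not from the original):

import pytorch_lightning as pl
import torch.nn as nn

class DemoModel(pl.LightningModule):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        # nn.Sequential supports both iteration and item assignment,
        # which the hook above relies on (self.layer[i] = wrap(layer))
        self.layer = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))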
Example #2
def _create_model(fsdp_config, compute_cycles, has_params: bool):
    # wrap each Layer individually, then wrap the enclosing Sequential
    with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
        model = wrap(
            nn.Sequential(
                wrap(Layer(compute_cycles, has_params)),
                wrap(Layer(compute_cycles, has_params)),
                wrap(Layer(compute_cycles, has_params)),
                wrap(Layer(compute_cycles, has_params)),
            )
        ).cuda()
    return model
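A quick usage sketch; Layer is the test's toy module (its constructor signature is taken from the calls above), and the config key shown is a standard fairscale FSDP argument:

# Hypothetical call: every key in fsdp_config is forwarded to each FSDP
# wrapper via enable_wrap(); needs an initialized process group and a GPU.
model = _create_model(
    fsdp_config={"flatten_parameters": True},
    compute_cycles=0,
    has_params=True,
)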
from typing import Optional

def fsdp_wrap(module, min_num_params: Optional[int] = None, **kwargs):
    """
    Helper to wrap layers/modules in FSDP. Falls back to a no-op if
    fairscale is not available.

    Args:
        module (nn.Module): module to (maybe) wrap
        min_num_params (int, optional): only wrap modules with at least this
            many parameters
    """
    try:
        from fairscale.nn import wrap

        if min_num_params is not None:
            num_params = sum(p.numel() for p in module.parameters())
            if num_params >= min_num_params:
                return wrap(module, **kwargs)
            else:
                return module
        else:
            return wrap(module, **kwargs)
    except ImportError:
        return module
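A brief usage sketch of the helper (sizes are illustrative):

import torch.nn as nn

small = nn.Linear(4, 4)    # 20 parameters: below the threshold, returned as-is
big = nn.Linear(512, 512)  # ~262k parameters: wrapped in FSDP, provided
                           # fairscale is installed and an enable_wrap()
                           # context is active (wrap() is a no-op otherwise)
maybe_small = fsdp_wrap(small, min_num_params=1_000)
maybe_big = fsdp_wrap(big, min_num_params=1_000)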
Example #4
def test_wrap_override_defaults(self):
    # kwargs passed to wrap() override the defaults set by enable_wrap()
    with enable_wrap(wrapper_cls=FSDP, flatten_parameters=False, process_group=self.process_group):
        layer = wrap(nn.Linear(5, 5), flatten_parameters=True)
    assert isinstance(layer, FSDP)
    assert layer.flatten_parameters
Example #5
def test_wrap_disabled_outside_context(self):
    # outside an enable_wrap() context, wrap() is a no-op
    layer = wrap(nn.Linear(5, 5))
    assert isinstance(layer, nn.Linear)
Example #6
def configure_sharded_model(self) -> None:
    # wrap every other submodule individually, then wrap the container itself
    for i, layer in enumerate(self.layer):
        if i % 2 == 0:
            self.layer[i] = wrap(layer)
    self.layer = wrap(self.layer)
Example #7
def test_wrap(self):
    with enable_wrap(wrapper_cls=FSDP, flatten_parameters=False,
                     process_group=self.process_group):
        layer = wrap(nn.Linear(5, 5))
    assert isinstance(layer, FSDP)
    assert layer.flatten_parameters is False
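Taken together, these tests pin down the wrap()/enable_wrap() contract. A condensed sketch (assumes fairscale is installed and a default process group has been initialized, e.g. via torch.distributed.init_process_group):

import torch.nn as nn
from fairscale.nn import enable_wrap, wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Outside the context, wrap() is a no-op (Example #5).
assert isinstance(wrap(nn.Linear(5, 5)), nn.Linear)

# Inside the context, wrap() applies FSDP with the context's defaults;
# per-call kwargs override them (Examples #4 and #7).
with enable_wrap(wrapper_cls=FSDP, flatten_parameters=False):
    fsdp_layer = wrap(nn.Linear(5, 5))
assert isinstance(fsdp_layer, FSDP)
assert fsdp_layer.flatten_parameters is False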