def _init_param_attributes(self, p: Parameter) -> None:
    """
    We manage several attributes on each Parameter instance. The first two
    are set by :func:`_shard_parameters`:

        ``_is_sharded``: ``True`` if the Parameter is sharded or ``False``
            if the Parameter is intentionally not sharded (in which case we
            will all-reduce grads for this param). Currently the only way
            ``_is_sharded = False`` is when ``world_size == 1``.
        ``_orig_size``: the size of the original Parameter (before sharding)

    A few attributes are set here:

        ``_local_shard``: a single shard of the parameter. This is needed to
            recover the shard after rebuilding the full parameter in forward
            and backward.
        ``_full_param_padded``: the full weight (padded to be evenly
            divisible by ``world_size``), used for computation in the
            forward and backward pass. It is initialized with the
            appropriate size and then has its storage freed. This will be
            resized in place and only materialized (via all-gather) as
            needed.

    Another attribute is set by :func:`_register_post_backward_hooks`:

        ``_shard_bwd_hook``: it holds the parameter's AccumulateGrad object
            and the registered post-backward hook handle.
    """
    assert hasattr(p, "_is_sharded") and hasattr(
        p, "_orig_size"
    ), "Parameters should have been sharded during construction."
    if hasattr(p, "_local_shard"):
        # If CPU offloading, p._local_shard should have been placed on CPU
        # during its first lazy construction.
        if self.cpu_offload.offload_params:
            assert p._local_shard.device == torch.device(  # type: ignore[attr-defined]
                "cpu"
            ), (
                "Expected p._local_shard to be on CPU, "
                f"but it's on {p._local_shard.device}"  # type: ignore[attr-defined]
            )
        return

    # A single shard of the parameters. This also puts p._local_shard on CPU
    # when we are CPU offloading, since p.data is on CPU during init.
    if self.cpu_offload.offload_params:
        assert p.device == torch.device(
            "cpu"
        ), "Expected param to be on CPU when cpu_offloading is enabled."
    p._local_shard = p.data  # type: ignore[attr-defined]
    if self.cpu_offload.offload_params:
        # Pin the memory to enable faster CPU -> GPU transfer. Note that
        # pin_memory() returns a pinned copy, so the result must be assigned
        # back rather than discarded.
        assert p._local_shard.device == torch.device("cpu")  # type: ignore[attr-defined]
        p._local_shard = p._local_shard.pin_memory()  # type: ignore[attr-defined]
        # When offloading parameters, we also move the grad shard to CPU
        # during the backward pass. It is important to pre-allocate the CPU
        # grad shard in pinned memory so that the transfer can be
        # non-blocking.
        p._cpu_grad = torch.zeros_like(  # type: ignore[attr-defined]
            p, device=torch.device("cpu")
        ).pin_memory()

    # We also maintain a full-sized parameter of type self.compute_dtype.
    # We resize its storage to size 0 at init (here) and only materialize it
    # as needed. The storage may contain padding elements so that it is
    # evenly divisible by world_size, although these padding elements will
    # be removed before the relevant computation.
    if p._is_sharded:  # type: ignore[attr-defined]
        p._full_param_padded = torch.zeros(  # type: ignore[attr-defined]
            p.numel() * self.world_size,
            device=self.compute_device,
            dtype=self.compute_dtype,
        )
        _free_storage(p._full_param_padded)  # type: ignore[attr-defined]
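
# The `_free_storage` helper called above is not defined in this excerpt. The
# sketch below is an assumption about how it (and a matching `_alloc_storage`,
# which is purely illustrative) could be written using the standard
# `Tensor.storage().resize_()` mechanism the comments describe: shrinking the
# storage to zero releases the memory while the tensor object keeps its shape,
# so `_full_param_padded` can later be re-materialized in place as the
# all-gather target and freed again after use.
import torch


def _free_storage(data: torch.Tensor) -> None:
    """Free the memory behind ``data`` while keeping the tensor object alive."""
    if data.storage().size() > 0:
        assert data.storage_offset() == 0, "Freeing a view is not supported."
        # The tensor keeps its size metadata; only the backing memory goes away.
        data.storage().resize_(0)


def _alloc_storage(data: torch.Tensor, size: torch.Size) -> None:
    """Re-materialize storage for ``data``, e.g. right before an all-gather."""
    if data.storage().size() == 0:
        data.storage().resize_(size.numel())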