def load_factors_from_dir(self, compute_inverses: bool = True) -> None:
    """Load factors from `factor_checkpoint_dir`."""
    if self.factor_checkpoint_dir is None:
        raise ValueError('factor_checkpoint_dir is None.')
    if not os.path.isdir(self.factor_checkpoint_dir):
        warnings.warn(
            f'factor_checkpoint_dir={self.factor_checkpoint_dir} '
            'is not a directory. Skipping KFAC checkpoint load.',
        )
        return

    for name, layer in self._layers.values():
        if (
            cast(GPTNeoXAssignment, self._assignment).factor_worker(
                name,
                'A',
            )
            == get_rank()
        ):
            filepath = os.path.join(self.factor_checkpoint_dir, name)
            if os.path.exists(filepath):
                logger.info(
                    f'loading KFAC factors for {name} on rank '
                    f'{get_rank()}',
                )
                state_dict = torch.load(filepath)
                layer.load_state_dict(state_dict)
                if compute_inverses:
                    layer.compute_a_inv(damping=self.damping)
                    layer.compute_g_inv(damping=self.damping)
def load_state_dict(
    self,
    state_dict: dict[str, Any],
    compute_inverses: bool = True,
) -> None:
    """Load state dict."""
    layers = state_dict.pop('layers', None)
    super().load_state_dict(state_dict, compute_inverses=False)
    if self.factor_checkpoint_dir is not None:
        self.load_factors_from_dir(compute_inverses)
        return
    if layers is None:
        return
    for found_name, layer_state_dict in layers.items():
        for name, layer in self._layers.values():
            if (
                found_name == name
                and cast(
                    GPTNeoXAssignment,
                    self._assignment,
                ).factor_worker(name, 'A')
                == get_rank()
            ):
                assert isinstance(layer_state_dict['A'], torch.Tensor)
                assert isinstance(layer_state_dict['G'], torch.Tensor)
                layer.load_state_dict(layer_state_dict)
                if compute_inverses:
                    layer.compute_a_inv(damping=self.damping)
                    layer.compute_g_inv(damping=self.damping)
    torch.distributed.barrier()
def broadcast_g_inv(
    self,
    src: int,
    group: dist.ProcessGroup | None = None,
) -> None:
    """Initiate G inv broadcast and store future to result.

    Note:
        all ranks must enter this function even if the rank is not
        a part of the inverse broadcast group.

    Args:
        src (int): src rank that computed G inverse.
        group (ProcessGroup): process group to which src should broadcast
            G inv. All ranks in group should enter this function.
            Defaults to None, the default process group.
    """
    if self.g_inv is None:
        if get_rank() == src:
            raise RuntimeError(
                f'Attempt to broadcast G inv from src={src} but this rank '
                'has not computed G inv yet.',
            )
        assert isinstance(self.g_factor, torch.Tensor)
        self.g_inv = torch.empty(
            self.g_factor.shape,
            device=self.g_factor.device,
            dtype=self.inv_dtype,
        )
    self.g_inv = self.tdc.broadcast(  # type: ignore
        self.g_inv,
        src=src,
        group=group,
        symmetric=self.symmetric_factors and self.symmetry_aware,
    )
def broadcast_grad(
    self,
    src: int,
    group: dist.ProcessGroup | None = None,
) -> None:
    """Broadcast preconditioned gradient and store future to result.

    Note:
        all ranks must enter this function.

    Args:
        src (int): src rank that preconditioned the gradient.
        group (ProcessGroup): process group to which src should broadcast
            the gradient. All ranks in group should enter this function.
            Defaults to None, the default process group.
    """
    if self.grad is None:
        if get_rank() == src:
            raise RuntimeError(
                f'Attempt to broadcast gradient from src={src} but this '
                'rank has not computed the preconditioned gradient yet.',
            )
        self.grad = torch.empty_like(self.module.get_grad())
    self.grad = self.tdc.broadcast(  # type: ignore
        self.grad,
        src=src,
        group=group,
    )
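# For context, a minimal sketch (illustrative only, not part of the layer
# API) of the allocate-then-broadcast pattern the two methods above rely on:
# ranks that did not compute the value allocate an empty buffer of the
# expected shape so that every rank can participate in the collective.
from __future__ import annotations

import torch
import torch.distributed as dist


def broadcast_from(
    src: int,
    tensor: torch.Tensor | None,
    shape: tuple[int, ...],
    dtype: torch.dtype,
) -> torch.Tensor:
    if tensor is None:
        # Placeholder buffer on non-src ranks; contents are overwritten.
        tensor = torch.empty(shape, dtype=dtype)
    dist.broadcast(tensor, src=src)
    return tensor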
def reduce_g_factor(
    self,
    group: torch.distributed.ProcessGroup | None = None,
) -> None:  # pragma: no cover
    """Initiate reduction of G and store future to result.

    Note:
        all ranks should enter this function.

    Args:
        group (ProcessGroup): ignored because the correct group depends
            on whether the parallelism is on the output or input of the
            layer.
    """
    if self.primary_rank is None:
        raise RuntimeError('primary rank has not been set yet.')
    valid = (torch.distributed.ProcessGroup, type(None))
    if not isinstance(self.data_parallel_group, valid) or not isinstance(
        self.pipe_parallel_peer_group,
        valid,
    ):
        raise RuntimeError(
            'data_parallel_group or pipe_parallel_peer_group has not '
            'been set yet.',
        )
    if self.parallelism == 'input':
        super().reduce_g_factor(self.pipe_parallel_peer_group)
    elif self.parallelism == 'output':
        if get_rank() != self.primary_rank:
            return
        super().reduce_g_factor(self.data_parallel_group)
    else:
        raise AssertionError('Unreachable.')
def save_factors_to_dir(self) -> None:
    """Save factors to `factor_checkpoint_dir`.

    Saves the state dict for each layer to a separate file in
    `self.factor_checkpoint_dir` with the filename as the name of the
    layer.

    Note:
        only the inverse worker for the layer will save the layer.
    """
    if self.factor_checkpoint_dir is None:
        raise ValueError('factor_checkpoint_dir is None')
    if get_rank() == 0:
        os.makedirs(self.factor_checkpoint_dir, exist_ok=True)
    torch.distributed.barrier()

    for name, layer in self._layers.values():
        if get_rank() == self._assignment.inv_worker(name, 'A'):
            layer_state_dict = layer.state_dict()
            filepath = os.path.join(self.factor_checkpoint_dir, name)
            logger.info(f'saving KFAC factors for {name} to {filepath}')
            torch.save(layer_state_dict, filepath)
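# A hypothetical checkpointing flow for the directory-based factor
# checkpoints handled above (the `preconditioner` object is assumed to exist
# and to have been constructed with factor_checkpoint_dir set):
preconditioner.save_factors_to_dir()
# ... after a restart, on the same ranks ...
preconditioner.load_factors_from_dir(compute_inverses=True)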
def broadcast_a_inv(
    self,
    src: int,
    group: dist.ProcessGroup | None = None,
) -> None:
    """Initiate A inv broadcast and store future to result.

    Note:
        all ranks must enter this function even if the rank is not
        a part of the inverse broadcast group.

    Args:
        src (int): src rank that computed A inverse.
        group (ProcessGroup): process group to which src should broadcast
            A inv. All ranks in group should enter this function.
            Defaults to None, the default process group.
    """
    if self.qa is None or (
        not self.prediv_eigenvalues and self.da is None
    ):
        if get_rank() == src:
            raise RuntimeError(
                f'Attempt to broadcast A inv from src={src} but this rank '
                'has not computed A inv yet.',
            )
        assert isinstance(self.a_factor, torch.Tensor)
        self.qa = torch.empty(
            self.a_factor.shape,
            device=self.a_factor.device,
            dtype=self.inv_dtype,
        )
        self.da = torch.empty(
            self.a_factor.shape[0],
            device=self.a_factor.device,
            dtype=self.inv_dtype,
        )
    self.qa = self.tdc.broadcast(  # type: ignore
        self.qa,
        src=src,
        group=group,
    )
    if not self.prediv_eigenvalues:
        assert self.da is not None
        self.da = self.tdc.broadcast(  # type: ignore
            self.da,
            src=src,
            group=group,
        )
def state_dict(self, include_factors: bool = True) -> dict[str, Any]:
    """Get state dict.

    Note:
        all ranks must enter this.
    """
    state_dict = super().state_dict(include_factors=False)
    if not include_factors:
        return state_dict

    if self.factor_checkpoint_dir is not None:
        self.save_factors_to_dir()
        return state_dict

    partition: list[tuple[str, dict[str, Any]]] = []
    for name, layer in self._layers.values():
        # Inv worker for A and G is the same for GPT training
        if get_rank() == self._assignment.inv_worker(name, 'A'):
            layer_state_dict = layer.state_dict()
            assert layer_state_dict['A'] is not None
            assert layer_state_dict['G'] is not None
            # Move to CPU where we have more RAM
            layer_state_dict['A'] = layer_state_dict['A'].cpu()
            layer_state_dict['G'] = layer_state_dict['G'].cpu()
            partition.append((name, layer_state_dict))

    partitions = [None for _ in range(get_world_size())]
    # Use gloo group because we moved data to CPU for more RAM
    group = torch.distributed.new_group(backend='gloo')
    torch.distributed.all_gather_object(partitions, partition, group=group)

    layers = {}
    for partition in partitions:  # type: ignore
        for name, layer_state_dict in partition:
            layers[name] = layer_state_dict
    state_dict['layers'] = layers
    torch.distributed.barrier(group)
    return state_dict
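# A hedged usage sketch for the bundled state dict above (the file name is
# illustrative and the `preconditioner` object is assumed to exist). All
# ranks enter state_dict()/load_state_dict(); only rank 0 writes to disk:
state = preconditioner.state_dict()
if get_rank() == 0:
    torch.save(state, 'kfac_state.pt')
# On restore (again on all ranks):
preconditioner.load_state_dict(
    torch.load('kfac_state.pt'),
    compute_inverses=True,
)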
def step(self) -> None:
    """Perform one K-FAC step.

    Note:
        This function should always be called before `optimizer.step()` as
        it modifies the gradients and does not modify the weights.

    Note:
        Gradients must be averaged across ranks before calling `step()`.
        This condition is guaranteed to be true if using the
        `DistributedDataParallel` model wrapper as gradients are
        communicated during `loss.backward()`.
    """
    if (
        not self._update_factors_in_hook
        and self.steps % self.factor_update_steps == 0
    ):
        for name, layer in reversed(list(self._layers.values())):
            self._mini_steps[name] = 0
            layer.update_a_factor(alpha=self.factor_decay)
            layer.reduce_a_factor(self._assignment.factor_group(name, 'A'))
            layer.update_g_factor(alpha=self.factor_decay)
            layer.reduce_g_factor(self._assignment.factor_group(name, 'G'))

    # Flush last allreduce bucket from forward/backward pass.
    # Will be a no-op if bucketing was not used
    self._tdc.flush_allreduce_buckets()

    # Compute Inverses
    if self.steps % self.inv_update_steps == 0:
        for name, layer in reversed(list(self._layers.values())):
            if get_rank() == self._assignment.inv_worker(name, 'A'):
                layer.compute_a_inv(damping=self.damping)
            if (
                self._assignment.broadcast_inverses()
                and self._assignment.is_grad_worker(name)
            ):
                layer.broadcast_a_inv(
                    src=self._assignment.inv_worker(name, 'A'),
                    group=self._assignment.grad_worker_group(name),
                )

            if get_rank() == self._assignment.inv_worker(name, 'G'):
                layer.compute_g_inv(damping=self.damping)
            if (
                self._assignment.broadcast_inverses()
                and self._assignment.is_grad_worker(name)
            ):
                layer.broadcast_g_inv(
                    src=self._assignment.inv_worker(name, 'G'),
                    group=self._assignment.grad_worker_group(name),
                )
        self._tdc.flush_allreduce_buckets()

    # Compute Preconditioned Gradients
    for name, layer in reversed(list(self._layers.values())):
        if self._assignment.is_grad_worker(name):
            layer.preconditioned_grad(damping=self.damping)
        if self._assignment.broadcast_gradients():
            layer.broadcast_grad(
                src=self._assignment.src_grad_worker(name),
                group=self._assignment.grad_receiver_group(name),
            )
    self._tdc.flush_allreduce_buckets()

    scale = None if self.kl_clip is None else self._compute_grad_scale()

    # Update gradients in-place
    for _, layer in reversed(list(self._layers.values())):
        layer.update_grad(scale=scale)

    self._steps += 1
    self._mini_steps = defaultdict(int)
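# A minimal training-loop sketch showing where step() fits (the
# `preconditioner`, `model`, `loader`, `optimizer`, and `criterion` objects
# are assumed to be defined elsewhere). step() preconditions gradients
# in-place, so it runs after backward() and before optimizer.step():
for x, y in loader:
    optimizer.zero_grad()
    criterion(model(x), y).backward()
    preconditioner.step()
    optimizer.step()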
def __init__(
    self,
    model: torch.nn.Module,
    *,
    factor_update_steps: Callable[[int], int] | int = 1,
    inv_update_steps: Callable[[int], int] | int = 1,
    # KFAC hyperparameters
    damping: Callable[[int], float] | float = 0.001,
    factor_decay: Callable[[int], float] | float = 0.95,
    kl_clip: Callable[[int], float] | float = 0.001,
    lr: Callable[[int], float] | float = 0.1,
    # Distribution strategy
    accumulation_steps: int = 1,
    allreduce_bucket_cap_mb: float = 25.0,
    assignment_strategy: (
        AssignmentStrategy | str
    ) = AssignmentStrategy.COMPUTE,
    compute_method: ComputeMethod | str = ComputeMethod.EIGEN,
    compute_eigenvalue_outer_product: bool = False,
    symmetry_aware: bool = False,
    # DeepSpeed 3D parallelism
    data_parallel_group: torch.distributed.ProcessGroup | None = None,
    model_parallel_group: torch.distributed.ProcessGroup | None = None,
    pipeline_parallel_group: torch.distributed.ProcessGroup | None = None,
    # Optional other parameters
    grad_scaler: (
        torch.cuda.amp.GradScaler | Callable[[], float] | None
    ) = None,
    factor_dtype: torch.dtype | None = None,
    inv_dtype: torch.dtype = torch.float32,
    factor_checkpoint_dir: str | None = None,
    skip_layers: list[str] | None = None,
    update_factors_in_hook: bool = True,
    loglevel: int = logging.DEBUG,
) -> None:
    """Init KFACPreconditioner.

    Args:
        model (torch.nn.Module): model to precondition with KFAC.
        factor_update_steps (Callable, int): steps between computing and
            updating the running average of the Kronecker factors or
            callable that takes the K-FAC step and returns the value.
        inv_update_steps (Callable, int): steps between recomputing and
            communicating the second-order information or callable that
            takes the K-FAC step and returns the value.
        damping (Callable, float): Tikhonov damping parameter or a
            callable that takes the K-FAC step and returns the damping
            parameter as a float (default: 0.001).
        factor_decay (Callable, float): running average coefficient for
            Kronecker factors or callable that takes the K-FAC step and
            returns the factor_decay (default: 0.95).
        kl_clip (Callable, float): clipping parameter for gradient
            scaling or a callable that takes the K-FAC step and returns
            a float. If None, no scaling/clipping will be applied
            (default: 0.001).
        lr (Callable, float): learning rate or callable that takes the
            K-FAC step and returns learning rate (default: 0.1).
        accumulation_steps (int): number of forward/backward passes
            between optimization steps (default: 1).
        allreduce_bucket_cap_mb (float): maximum size in megabytes for
            allreduce bucketing. If zero, bucketing is not used
            (default: 25).
        assignment_strategy (AssignmentStrategy, str): See
            `AssignmentStrategy` for more details
            (default: AssignmentStrategy.COMPUTE).
        compute_method (ComputeMethod, str): See `ComputeMethod` for more
            details (default: ComputeMethod.EIGEN).
        compute_eigenvalue_outer_product (bool): when using the eigen
            compute method, precompute the element-wise inverse of the
            outer product of eigenvectors on the eigen decomposition
            worker to reduce computation in the gradient preconditioning
            stage. `colocate_factors` must be True (default: False).
        symmetry_aware (bool): communicate only the upper triangle of
            symmetric matrices. Can reduce communication time when
            factors are large (default: False).
        data_parallel_group (ProcessGroup): DeepSpeed data parallel
            group.
        model_parallel_group (ProcessGroup): DeepSpeed model parallel
            group.
        pipeline_parallel_group (ProcessGroup): DeepSpeed pipeline
            parallel group.
        grad_scaler (torch.cuda.amp.GradScaler or callable): Gradient
            scaler used for Torch AMP training. Used to unscale the G
            factors as they are accumulated during the backward pass.
            Alternatively can be a callable which will return the
            current scale (default: None).
        factor_dtype (torch.dtype): force data type for storing factors.
            If None, defaults to data type of intermediate values in
            forward/backward pass (default: None).
        inv_dtype (torch.dtype): force data type for storing
            second-order data (e.g., inverses or eigen decompositions)
            (default: torch.float32).
        factor_checkpoint_dir (str): directory to store factor
            checkpoints in.
        skip_layers (list): list of module names to ignore when
            registering layers. Passing the name of parent modules will
            prevent recursively registering child modules of the parent.
            Case-insensitive (default: []).
        update_factors_in_hook (bool): If True, running average of
            factors is updated in the module hook and the async
            communication is started. Otherwise, this will be performed
            at the start of step() (default: True).
        loglevel (int): logging level (default: logging.DEBUG).
    """
    if deepspeed_import_error is not None:  # pragma: no cover
        raise deepspeed_import_error

    warnings.warn(
        'KFAC support for GPT-NeoX training is experimental.',
        ExperimentalFeatureWarning,
    )

    if not isinstance(model, PipelineModule):
        raise ValueError(
            'model must be an instance of deepspeed.pipe.PipelineModule. '
            f'Got an instance of {type(model)}.',
        )

    if allreduce_bucket_cap_mb < 0:
        raise ValueError('allreduce_bucket_cap_mb must be >= 0')
    if isinstance(assignment_strategy, str):
        assignment_strategy = AssignmentStrategy[
            assignment_strategy.upper()
        ]
    if isinstance(compute_method, str):
        compute_method = ComputeMethod[compute_method.upper()]

    self.allreduce_bucket_cap_mb = allreduce_bucket_cap_mb
    self.assignment_strategy = assignment_strategy
    self.compute_eigenvalue_outer_product = (
        compute_eigenvalue_outer_product
    )
    self.compute_method = compute_method
    self.grad_scaler = grad_scaler
    self.factor_dtype = factor_dtype
    self.inv_dtype = inv_dtype
    self.factor_checkpoint_dir = factor_checkpoint_dir
    self.skip_layers = [] if skip_layers is None else skip_layers
    self.symmetry_aware = symmetry_aware
    self.data_parallel_group = data_parallel_group
    self.model_parallel_group = model_parallel_group
    self.pipeline_parallel_group = pipeline_parallel_group

    if self.allreduce_bucket_cap_mb > 0:
        self.allreduce_method = AllreduceMethod.ALLREDUCE_BUCKETED
    else:
        self.allreduce_method = AllreduceMethod.ALLREDUCE
    self.tdc = TorchDistributedCommunicator(
        bucket_cap_mb=self.allreduce_bucket_cap_mb,
    )

    layer_kwargs = dict(
        allreduce_method=self.allreduce_method,
        grad_scaler=self.grad_scaler,
        factor_dtype=self.factor_dtype,
        inv_dtype=self.inv_dtype,
        symmetry_aware=self.symmetry_aware,
        tdc=self.tdc,
    )

    if self.compute_method == ComputeMethod.EIGEN:
        pass
    elif self.compute_method == ComputeMethod.INVERSE:
        raise ValueError('Inverse method not supported with GPT NeoX.')
    else:
        raise AssertionError(
            f'Unknown compute_method={self.compute_method}',
        )

    kfac_layers = register_modules(
        model,
        model_parallel_group=self.model_parallel_group,
        skip_layers=self.skip_layers,
        **layer_kwargs,
    )

    data_parallel_ranks = [
        -1 for _ in range(get_world_size(self.data_parallel_group))
    ]
    torch.distributed.all_gather_object(
        object_list=data_parallel_ranks,
        obj=get_rank(),
        group=self.data_parallel_group,
    )

    for name, kfac_layer in kfac_layers.values():
        logger.log(
            loglevel,
            f'Registered name="{name}": {repr(kfac_layer)} on '
            f'global-rank={get_rank()} and '
            f'pipeline-rank={get_rank(self.pipeline_parallel_group)}',
        )

    if self.assignment_strategy == AssignmentStrategy.COMPUTE:
        cost_func = lambda n: n**3  # noqa: E731
    elif self.assignment_strategy == AssignmentStrategy.MEMORY:
        cost_func = lambda n: n**2  # noqa: E731
    else:
        raise AssertionError(
            f'Unknown assignment_strategy={self.assignment_strategy}',
        )

    work = {
        name: {
            'A': cost_func(kfac_layer.module.a_factor_shape[0]),
            'G': cost_func(kfac_layer.module.g_factor_shape[0]),
        }
        for name, kfac_layer in kfac_layers.values()
    }

    assignment = GPTNeoXAssignment(
        work,
        local_rank=get_rank(),
        topology=model.topology(),
        data_parallel_group=self.data_parallel_group,
        model_parallel_group=self.model_parallel_group,
    )
    logger.log(loglevel, f'KFAC layer assignments: {assignment}')

    # Set primary rank for each layer
    for name, kfac_layer in kfac_layers.values():
        # Inv worker for A and G should be same because GPTNeoXAssignment
        # uses gradient worker fraction 1/world_size (mem-opt)
        if not isinstance(kfac_layer, GPTNeoXKFACEigenLayer):
            raise AssertionError(
                'GPTNeoXKFACPreconditioner only supports '
                'GPTNeoXKFACEigenLayer.',
            )
        kfac_layer.primary_rank = assignment.factor_worker(name, 'A')
        kfac_layer.data_parallel_group = assignment.data_parallel_group
        kfac_layer.pipe_parallel_peer_group = (
            assignment.pipe_parallel_peer_group
        )

    defaults = {
        'allreduce_bucket_cap_mb': self.allreduce_bucket_cap_mb,
        'allreduce_method': self.allreduce_method,
        'assignment_strategy': self.assignment_strategy,
        'compute_eigenvalue_outer_product': (
            self.compute_eigenvalue_outer_product
        ),
        'compute_method': self.compute_method,
        'grad_scaler': self.grad_scaler is not None,
        'factor_checkpoint_dir': self.factor_checkpoint_dir,
        'factor_dtype': self.factor_dtype,
        'inv_dtype': self.inv_dtype,
        'skip_layers': self.skip_layers,
        'symmetry_aware': self.symmetry_aware,
    }

    super().__init__(
        kfac_layers,
        factor_update_steps=factor_update_steps,
        inv_update_steps=inv_update_steps,
        factor_decay=factor_decay,
        damping=damping,
        kl_clip=kl_clip,
        lr=lr,
        accumulation_steps=accumulation_steps,
        assignment=assignment,
        update_factors_in_hook=update_factors_in_hook,
        defaults=defaults,
        tdc=self.tdc,
        loglevel=loglevel,
    )
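# A hedged construction sketch for the GPT-NeoX preconditioner initialized
# above (the group variables, step counts, and path are illustrative;
# assumes `model` is a deepspeed.pipe.PipelineModule and that the DeepSpeed
# data/model/pipeline process groups were created elsewhere):
preconditioner = GPTNeoXKFACPreconditioner(
    model,
    factor_update_steps=10,
    inv_update_steps=100,
    data_parallel_group=dp_group,
    model_parallel_group=mp_group,
    pipeline_parallel_group=pp_group,
    factor_checkpoint_dir='/tmp/kfac-factors',
)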
def preconditioned_grad(
    self,
    damping: float = 0.001,
) -> None:  # pragma: no cover
    """Compute preconditioned gradient of each weight in module.

    Note:
        Unlike KFACEigenLayer, every rank in the model parallel group
        should enter this function.

    Preconditioned gradients can be applied to the actual gradients with
    `update_grad()`. Note the steps are separate in the event that
    intermediate steps will be applied to the preconditioned gradient.

    Args:
        damping (float, optional): damping to use if preconditioning
            using the eigendecomposition method (default: 0.001).
    """
    if self.primary_rank is None:
        raise RuntimeError('primary rank has not been set yet.')
    if get_rank() == self.primary_rank and (
        self.qa is None
        or self.qg is None
        or (not self.prediv_eigenvalues and self.da is None)
        or (not self.prediv_eigenvalues and self.dg is None)
        or (self.prediv_eigenvalues and self.dgda is None)
    ):
        raise RuntimeError(
            'Eigendecompositions for both A and G have not been computed',
        )

    grad_partition = self.module.get_weight_grad()
    grad = gather_from_model_parallel_region(
        grad_partition,
        dst=self.primary_rank,
        model_parallel_group=self.model_parallel_group,
        dim=-1 if self.parallelism == 'input' else 0,
    )

    if self.module.has_bias():
        bias_grad_partition = self.module.get_bias_grad()
        # Bias is only actually partitioned if parallelism is done on
        # output
        if self.parallelism == 'output':
            bias_grad = gather_from_model_parallel_region(
                bias_grad_partition,
                dst=self.primary_rank,
                model_parallel_group=self.model_parallel_group,
                dim=0,
            )
        else:
            bias_grad = bias_grad_partition
    else:
        bias_grad = None

    if grad is not None:
        # Only perform preconditioning on worker that got the full
        # gradient
        grad_shape = grad.size()
        if self.module.has_bias():
            assert bias_grad is not None
            bias_grad_shape = bias_grad.size()
            grad = torch.cat([grad, bias_grad.view(-1, 1)], 1)

        # mypy won't know these are not none because they are properties
        assert self.da is not None
        assert self.dg is not None
        assert self.qa is not None
        assert self.qg is not None

        grad_type = grad.dtype
        grad = grad.to(self.qa.dtype)
        v1 = self.qg.t() @ grad @ self.qa
        if self.prediv_eigenvalues:
            v2 = v1 * self.dgda
        else:
            v2 = v1 / (torch.outer(self.dg, self.da) + damping)
        grad = (self.qg @ v2 @ self.qa.t()).to(grad_type)

        if self.module.has_bias():
            weight_grad = grad[:, :-1].view(grad_shape)
            bias_grad = grad[:, -1:].view(bias_grad_shape).contiguous()
        else:
            weight_grad = grad.view(grad_shape)

        weight_grads = list(
            split_tensor_along_dim(
                weight_grad,
                get_world_size(self.model_parallel_group),
                dim=-1 if self.parallelism == 'input' else 0,
                contiguous_split_chunks=True,
            ),
        )
        if self.module.has_bias() and self.parallelism == 'output':
            assert bias_grad is not None
            bias_grads = list(
                split_tensor_along_dim(
                    bias_grad,
                    get_world_size(self.model_parallel_group),
                    dim=0,
                    contiguous_split_chunks=True,
                ),
            )
    else:
        weight_grads = [
            torch.zeros_like(grad_partition)
            for _ in range(get_world_size(self.model_parallel_group))
        ]
        if self.parallelism == 'output':
            bias_grads = [
                torch.zeros_like(bias_grad_partition)
                for _ in range(get_world_size(self.model_parallel_group))
            ]

    # PyTorch NCCL does not support scatter but we can emulate it
    # with reduce_scatter where the reduction operation is sum and the
    # non_src ranks contribute zero filled tensors
    if get_world_size(self.model_parallel_group) > 1:
        torch.distributed.reduce_scatter(
            grad_partition,
            weight_grads,
            group=self.model_parallel_group,
        )
    else:
        grad_partition = weight_grads[0]

    if self.module.has_bias():
        if get_world_size(self.model_parallel_group) > 1:
            if self.parallelism == 'output':
                torch.distributed.reduce_scatter(
                    bias_grad_partition,
                    bias_grads,
                    group=self.model_parallel_group,
                )
                bias_grad = bias_grad_partition
            else:
                torch.distributed.broadcast(
                    bias_grad,
                    src=self.primary_rank,
                    group=self.model_parallel_group,
                )
        assert bias_grad is not None
        self.grad = torch.cat([grad_partition, bias_grad.view(-1, 1)], 1)
    else:
        self.grad = grad_partition
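# For reference, a minimal single-process sketch of the eigen-based
# preconditioning math used above (illustrative only; `qa`/`qg` are the
# eigenvectors of the A/G factors and `da`/`dg` their eigenvalues):
import torch


def precondition(
    grad: torch.Tensor,
    qa: torch.Tensor,
    da: torch.Tensor,
    qg: torch.Tensor,
    dg: torch.Tensor,
    damping: float = 0.001,
) -> torch.Tensor:
    # Rotate the gradient into the eigenbases of G and A.
    v1 = qg.t() @ grad @ qa
    # Divide element-wise by the damped outer product of eigenvalues.
    v2 = v1 / (torch.outer(dg, da) + damping)
    # Rotate back to the original basis.
    return qg @ v2 @ qa.t()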
def __init__(
    self,
    model: torch.nn.Module,
    *,
    factor_update_steps: Callable[[int], int] | int = 1,
    inv_update_steps: Callable[[int], int] | int = 1,
    # KFAC hyperparameters
    damping: Callable[[int], float] | float = 0.001,
    factor_decay: Callable[[int], float] | float = 0.95,
    kl_clip: Callable[[int], float] | float = 0.001,
    lr: Callable[[int], float] | float = 0.1,
    # Distribution strategy
    accumulation_steps: int = 1,
    allreduce_bucket_cap_mb: float = 25.0,
    assignment_strategy: (
        AssignmentStrategy | str
    ) = AssignmentStrategy.COMPUTE,
    colocate_factors: bool = True,
    compute_method: ComputeMethod | str = ComputeMethod.EIGEN,
    compute_eigenvalue_outer_product: bool = True,
    grad_worker_fraction: (
        DistributedStrategy | float
    ) = DistributedStrategy.COMM_OPT,
    symmetry_aware: bool = False,
    # Optional other parameters
    grad_scaler: (
        torch.cuda.amp.GradScaler | Callable[[], float] | None
    ) = None,
    factor_dtype: torch.dtype | None = None,
    inv_dtype: torch.dtype = torch.float32,
    skip_layers: list[str] | None = None,
    update_factors_in_hook: bool = True,
    loglevel: int = logging.DEBUG,
) -> None:
    """Init KFACPreconditioner.

    Args:
        model (torch.nn.Module): model to precondition with KFAC.
        factor_update_steps (Callable, int): steps between computing and
            updating the running average of the Kronecker factors or
            callable that takes the K-FAC step and returns the value.
        inv_update_steps (Callable, int): steps between recomputing and
            communicating the second-order information or callable that
            takes the K-FAC step and returns the value.
        damping (Callable, float): Tikhonov damping parameter or a
            callable that takes the K-FAC step and returns the damping
            parameter as a float (default: 0.001).
        factor_decay (Callable, float): running average coefficient for
            Kronecker factors or callable that takes the K-FAC step and
            returns the factor_decay (default: 0.95).
        kl_clip (Callable, float): clipping parameter for gradient
            scaling or a callable that takes the K-FAC step and returns
            a float. If None, no scaling/clipping will be applied
            (default: 0.001).
        lr (Callable, float): learning rate or callable that takes the
            K-FAC step and returns learning rate (default: 0.1).
        accumulation_steps (int): number of forward/backward passes
            between optimization steps (default: 1).
        allreduce_bucket_cap_mb (float): maximum size in megabytes for
            allreduce bucketing. If zero, bucketing is not used
            (default: 25).
        assignment_strategy (AssignmentStrategy, str): See
            `AssignmentStrategy` for more details
            (default: AssignmentStrategy.COMPUTE).
        colocate_factors (bool): assign both factors for a single layer
            to the same worker. Recommended when num_layers < world_size
            (default: True).
        compute_method (ComputeMethod, str): See `ComputeMethod` for more
            details (default: ComputeMethod.EIGEN).
        compute_eigenvalue_outer_product (bool): when using the eigen
            compute method, precompute the element-wise inverse of the
            outer product of eigenvectors on the eigen decomposition
            worker to reduce computation in the gradient preconditioning
            stage. `colocate_factors` must be True (default: True).
        grad_worker_fraction (DistributedStrategy, float): controls the
            fraction of workers assigned as gradient workers for each
            layer. Optionally, predefined configurations can be passed
            using the DistributedStrategy enum
            (default: DistributedStrategy.COMM_OPT).
        symmetry_aware (bool): communicate only the upper triangle of
            symmetric matrices. Can reduce communication time when
            factors are large (default: False).
        grad_scaler (torch.cuda.amp.GradScaler or callable): Gradient
            scaler used for Torch AMP training. Used to unscale the G
            factors as they are accumulated during the backward pass.
            Alternatively can be a callable which will return the
            current scale (default: None).
        factor_dtype (torch.dtype): force data type for storing factors.
            If None, defaults to data type of intermediate values in
            forward/backward pass (default: None).
        inv_dtype (torch.dtype): force data type for storing
            second-order data (e.g., inverses or eigen decompositions)
            (default: torch.float32).
        skip_layers (list[str]): regex patterns that, if matched, will
            cause the layer to not be registered. The patterns will be
            applied against the layer's name and class name.
        update_factors_in_hook (bool): If True, running average of
            factors is updated in the module hook and the async
            communication is started. Otherwise, this will be performed
            at the start of step() (default: True).
        loglevel (int): logging level (default: logging.DEBUG).
    """
    if allreduce_bucket_cap_mb < 0:
        raise ValueError('allreduce_bucket_cap_mb must be >= 0')
    if (
        compute_method == ComputeMethod.EIGEN
        and compute_eigenvalue_outer_product
        and not colocate_factors
    ):
        raise ValueError(
            'colocate_factors must be True to use '
            'compute_eigenvalue_outer_product',
        )
    if isinstance(assignment_strategy, str):
        assignment_strategy = AssignmentStrategy[
            assignment_strategy.upper()
        ]
    if isinstance(compute_method, str):
        compute_method = ComputeMethod[compute_method.upper()]

    size = get_world_size()
    if isinstance(grad_worker_fraction, DistributedStrategy):
        distributed_strategy = grad_worker_fraction
        if distributed_strategy == DistributedStrategy.COMM_OPT:
            grad_worker_fraction = 1.0
        elif distributed_strategy == DistributedStrategy.HYBRID_OPT:
            grad_worker_fraction = 0.5
        elif distributed_strategy == DistributedStrategy.MEM_OPT:
            grad_worker_fraction = 1.0 / size
        else:
            raise AssertionError(f'Unknown enum {grad_worker_fraction}')
    else:
        if not 0 <= grad_worker_fraction <= 1:
            raise ValueError('grad_worker_fraction must be in [0, 1]')
        if grad_worker_fraction == 0:
            grad_worker_fraction = 1.0 / size
        if size % max(1, round(size * grad_worker_fraction)) != 0:
            raise ValueError(
                'grad_worker_fraction must produce groups of equal size',
            )
        if grad_worker_fraction == 1:
            grad_worker_fraction = 1.0  # ensure float
            distributed_strategy = DistributedStrategy.COMM_OPT
        elif grad_worker_fraction <= 1 / size:
            distributed_strategy = DistributedStrategy.MEM_OPT
        else:
            distributed_strategy = DistributedStrategy.HYBRID_OPT
    assert isinstance(grad_worker_fraction, float)

    if (
        not colocate_factors
        and distributed_strategy is DistributedStrategy.MEM_OPT
    ):
        warnings.warn(
            'grad_worker_frac=1/world_size (MEM_OPT) requires '
            'colocate_factors=True. Enabling colocate_factors.',
        )
        colocate_factors = True

    self.allreduce_bucket_cap_mb = allreduce_bucket_cap_mb
    self.assignment_strategy = assignment_strategy
    self.colocate_factors = colocate_factors
    self.compute_eigenvalue_outer_product = (
        compute_eigenvalue_outer_product
    )
    self.compute_method = compute_method
    self.distributed_strategy = distributed_strategy
    self.grad_worker_fraction = grad_worker_fraction
    self.grad_scaler = grad_scaler
    self.factor_dtype = factor_dtype
    self.inv_dtype = inv_dtype
    self.skip_layers = [] if skip_layers is None else skip_layers
    self.symmetry_aware = symmetry_aware

    if self.allreduce_bucket_cap_mb > 0:
        self.allreduce_method = AllreduceMethod.ALLREDUCE_BUCKETED
    else:
        self.allreduce_method = AllreduceMethod.ALLREDUCE
    self.tdc = TorchDistributedCommunicator(
        bucket_cap_mb=self.allreduce_bucket_cap_mb,
    )

    layer_kwargs = dict(
        allreduce_method=self.allreduce_method,
        grad_scaler=self.grad_scaler,
        factor_dtype=self.factor_dtype,
        inv_dtype=self.inv_dtype,
        symmetry_aware=self.symmetry_aware,
        tdc=self.tdc,
    )

    layer_type: type[KFACBaseLayer]
    if self.compute_method == ComputeMethod.EIGEN:
        layer_type = KFACEigenLayer
        layer_kwargs['prediv_eigenvalues'] = (
            self.compute_eigenvalue_outer_product
        )
    elif self.compute_method == ComputeMethod.INVERSE:
        layer_type = KFACInverseLayer
    else:
        raise AssertionError(
            f'Unknown compute_method={self.compute_method}',
        )

    kfac_layers = register_modules(
        model,
        kfac_layer_type=layer_type,
        skip_layers=self.skip_layers,
        **layer_kwargs,
    )
    for name, kfac_layer in kfac_layers.values():
        logger.log(
            loglevel,
            f'Registered name="{name}": {repr(kfac_layer)}',
        )

    if self.assignment_strategy == AssignmentStrategy.COMPUTE:
        cost_func = lambda n: n**3  # noqa: E731
    elif self.assignment_strategy == AssignmentStrategy.MEMORY:
        cost_func = lambda n: n**2  # noqa: E731
    else:
        raise AssertionError(
            f'Unknown assignment_strategy={self.assignment_strategy}',
        )

    work = {
        name: {
            'A': cost_func(kfac_layer.module.a_factor_shape[0]),
            'G': cost_func(kfac_layer.module.g_factor_shape[0]),
        }
        for name, kfac_layer in kfac_layers.values()
    }

    new_group = cast(
        Callable[[List[int]], dist.ProcessGroup],
        dist.new_group,
    )
    mock_new_group: Callable[[list[int]], None] = lambda x: None
    assignment = KAISAAssignment(
        work,
        local_rank=get_rank(),
        world_size=get_world_size(),
        grad_worker_fraction=self.grad_worker_fraction,
        group_func=new_group if dist.is_initialized() else mock_new_group,
        colocate_factors=self.colocate_factors,
    )
    logger.log(loglevel, f'KFAC layer assignments: {assignment}')

    defaults = {
        'allreduce_bucket_cap_mb': self.allreduce_bucket_cap_mb,
        'allreduce_method': self.allreduce_method,
        'assignment_strategy': self.assignment_strategy,
        'colocate_factors': self.colocate_factors,
        'compute_eigenvalue_outer_product': (
            self.compute_eigenvalue_outer_product
        ),
        'compute_method': self.compute_method,
        'distributed_strategy': self.distributed_strategy,
        'grad_worker_fraction': self.grad_worker_fraction,
        'grad_scaler': self.grad_scaler is not None,
        'factor_dtype': self.factor_dtype,
        'inv_dtype': self.inv_dtype,
        'skip_layers': self.skip_layers,
        'symmetry_aware': self.symmetry_aware,
    }

    super().__init__(
        kfac_layers,
        factor_update_steps=factor_update_steps,
        inv_update_steps=inv_update_steps,
        factor_decay=factor_decay,
        damping=damping,
        kl_clip=kl_clip,
        lr=lr,
        accumulation_steps=accumulation_steps,
        assignment=assignment,
        update_factors_in_hook=update_factors_in_hook,
        defaults=defaults,
        tdc=self.tdc,
        loglevel=loglevel,
    )
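# A brief construction sketch (hyperparameter values are illustrative;
# assumes `model` is an existing torch.nn.Module and torch.distributed has
# been initialized) showing how grad_worker_fraction maps to the predefined
# strategies handled above. With world_size = 8, for example:
#   DistributedStrategy.COMM_OPT   -> grad_worker_fraction = 1.0
#   DistributedStrategy.HYBRID_OPT -> grad_worker_fraction = 0.5
#   DistributedStrategy.MEM_OPT    -> grad_worker_fraction = 1/8
preconditioner = KFACPreconditioner(
    model,
    factor_update_steps=10,
    inv_update_steps=100,
    grad_worker_fraction=DistributedStrategy.MEM_OPT,
    allreduce_bucket_cap_mb=25.0,
)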
def test_distributed_not_initialized() -> None:
    """Test rank/world_size functions when not using distributed."""
    assert get_rank() == 0
    assert get_world_size() == 1