def test_base_preconditioner_init() -> None:
    """Test BaseKFACPreconditioner initialization."""
    factor_update_steps = 1
    inv_update_steps = 2
    damping = 0.003
    factor_decay = 0.95
    kl_clip = 0.001
    lr = 0.1
    accumulation_steps = 1

    preconditioner = BaseKFACPreconditioner(
        layers=example_layers(),
        assignment=LazyAssignment(),
        tdc=TorchDistributedCommunicator(),
        factor_update_steps=factor_update_steps,
        inv_update_steps=inv_update_steps,
        damping=damping,
        factor_decay=factor_decay,
        kl_clip=kl_clip,
        lr=lr,
        accumulation_steps=accumulation_steps,
    )

    assert preconditioner.damping == damping
    assert preconditioner.factor_decay == factor_decay
    assert preconditioner.kl_clip == kl_clip
    assert preconditioner.lr == lr
    assert preconditioner.factor_update_steps == factor_update_steps
    assert preconditioner.inv_update_steps == inv_update_steps
    assert preconditioner.steps == 0

    preconditioner = BaseKFACPreconditioner(
        layers=example_layers(),
        assignment=LazyAssignment(),
        tdc=TorchDistributedCommunicator(),
        damping=lambda x: damping,
        factor_decay=lambda x: factor_decay,
        kl_clip=lambda x: kl_clip,
        lr=lambda x: lr,
    )
    assert preconditioner.damping == damping
    assert preconditioner.factor_decay == factor_decay
    assert preconditioner.kl_clip == kl_clip
    assert preconditioner.lr == lr

    defaults = {'default1': None, 'default2': None}
    preconditioner2 = BaseKFACPreconditioner(
        layers=example_layers(),
        assignment=LazyAssignment(),
        tdc=TorchDistributedCommunicator(),
        defaults=defaults,
    )
    assert repr(preconditioner) != repr(preconditioner2)
    # repr() should list all parameters including those passed in with the
    # defaults parameter
    assert 'default1' in repr(preconditioner2)
    assert 'default2' in repr(preconditioner2)
def test_nonsymmetric_eigen() -> None:
    """Test nonsymmetric eigen decomposition."""
    batch_size, in_features, out_features = 2, 5, 5
    module = torch.nn.Linear(in_features, out_features)
    x = torch.rand([batch_size, in_features])
    y = torch.rand([batch_size, out_features])
    loss = (module(x) - y).sum()
    loss.backward()

    with mock.patch.object(
        LinearModuleHelper,
        'has_symmetric_factors',
        return_value=False,
    ):
        module_helper = LinearModuleHelper(module)
        layer = KFACEigenLayer(
            module=module_helper,
            tdc=TorchDistributedCommunicator(),
        )
        assert not layer.symmetric_factors

        layer.save_layer_input([x])
        layer.save_layer_grad_output((y,))
        layer.update_a_factor()
        layer.update_g_factor()
        layer.compute_a_inv()
        layer.compute_g_inv()
        layer.preconditioned_grad()
        layer.update_grad()
def allreduce(
    shape: list[int],
    tensor_count: int,
    bucket_cap_mb: float,
    symmetric: bool = False,
    expect_raises: type[BaseException] | None = None,
) -> None:
    """Test allreduce in distributed environment."""
    try:
        world_size = torch.distributed.get_world_size()
        comm = TorchDistributedCommunicator(bucket_cap_mb)

        tensors = []
        for _ in range(tensor_count):
            t = torch.ones(shape, dtype=torch.float32)
            tensors.append(
                comm.allreduce_bucketed(t, symmetric=symmetric),
            )

        if world_size > 1:
            with pytest.raises(RuntimeError):
                comm._new_allreduce_bucket(None)

        comm.flush_allreduce_buckets()

        for tensor in tensors:
            if isinstance(tensor, Future):
                tensor = tensor.wait()
            assert isinstance(tensor, torch.Tensor)
            assert torch.sum(tensor).item() == world_size * torch.numel(
                tensor,
            )
    except Exception as e:
        if expect_raises is not None and not isinstance(e, expect_raises):
            raise
def simple_allreduce(
    shape: list[int],
    symmetric: bool = False,
    expect_raises: type[BaseException] | None = None,
) -> None:
    """Test unbucketed allreduce in distributed environment."""
    try:
        world_size = torch.distributed.get_world_size()
        comm = TorchDistributedCommunicator()

        t = torch.ones(shape)
        t_res = comm.allreduce(t, symmetric=symmetric)
        if isinstance(t_res, Future):
            t_res = t_res.wait()
        assert isinstance(t_res, torch.Tensor)
        assert torch.sum(t_res).item() == torch.numel(t_res) * world_size
    except Exception as e:
        if expect_raises is not None and not isinstance(e, expect_raises):
            raise
def test_grad_scale_no_layers() -> None:
    """Test computing grad scale with no layers has no divide by 0 error."""
    p = BaseKFACPreconditioner(
        layers=example_layers(),
        assignment=LazyAssignment(),
        tdc=TorchDistributedCommunicator(),
    )
    p._layers = {}
    assert p._compute_grad_scale() == 1.0
def simple_broadcast(
    shape: list[int],
    symmetric: bool = False,
    expect_raises: type[BaseException] | None = None,
) -> None:
    """Test broadcast in distributed environment."""
    try:
        rank = torch.distributed.get_rank()
        comm = TorchDistributedCommunicator()

        t = rank * torch.ones(shape)
        t_res = comm.broadcast(t, src=0, symmetric=symmetric)
        if isinstance(t_res, Future):
            t_res = t_res.wait()
        assert isinstance(t_res, torch.Tensor)
        # Rank 0 broadcasts its tensor so the result should be all zeros
        assert torch.sum(t_res).item() == 0
    except Exception as e:
        if expect_raises is not None and not isinstance(e, expect_raises):
            raise
def test_base_preconditioner_callable_hyperparams() -> None:
    """Test BaseKFACPreconditioner supports callable hyperparams."""
    p = BaseKFACPreconditioner(
        example_layers(),
        assignment=LazyAssignment(),
        tdc=TorchDistributedCommunicator(),
        factor_update_steps=lambda x: x * 2,
        inv_update_steps=lambda x: x * 3,
        damping=lambda x: x * 5,
        factor_decay=lambda x: x * 7,
        kl_clip=lambda x: x * 9,
    )
    for x in range(0, 10):
        p._steps = x
        assert p.factor_update_steps == x * 2
        assert p.inv_update_steps == x * 3
        assert p.damping == x * 5
        assert p.factor_decay == x * 7
        assert p.kl_clip == x * 9

    p = BaseKFACPreconditioner(
        example_layers(),
        assignment=LazyAssignment(),
        tdc=TorchDistributedCommunicator(),
        factor_update_steps=lambda x: 2,
        inv_update_steps=lambda x: 3,
        damping=lambda x: 5,
        factor_decay=lambda x: 7,
        kl_clip=lambda x: 9,
    )
    for x in range(0, 10):
        p._steps = x
        assert p.factor_update_steps == 2
        assert p.inv_update_steps == 3
        assert p.damping == 5
        assert p.factor_decay == 7
        assert p.kl_clip == 9
def example_layers() -> dict[torch.nn.Module, tuple[str, KFACBaseLayer]]:
    """Return registered layers of LeNet with KFAC."""
    return register_modules(
        LeNet(),
        kfac_layer_type=KFACInverseLayer,
        allreduce_method=AllreduceMethod.ALLREDUCE,
        grad_scaler=None,
        factor_dtype=None,
        inv_dtype=torch.float32,
        skip_layers=[],
        symmetry_aware=False,
        tdc=TorchDistributedCommunicator(),
    )
def test_register_modules(
    model: torch.nn.Module,
    layer_type: type[KFACBaseLayer],
    skip_layers: list[str],
    expected_count: int,
) -> None:
    """Test register_modules."""
    kwargs = dict(
        allreduce_method=AllreduceMethod.ALLREDUCE,
        grad_scaler=None,
        factor_dtype=None,
        inv_dtype=torch.float32,
        symmetry_aware=False,
        tdc=TorchDistributedCommunicator(),
    )
    kfac_layers = register_modules(
        model,
        layer_type,
        skip_layers=skip_layers,
        **kwargs,
    )
    assert len(kfac_layers) == expected_count
def allreduce(
    shape: list[int],
    tensor_count: int,
    bucket_cap_mb: float,
    symmetric: bool = False,
) -> None:
    """Test allreduce in distributed environment."""
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    comm = TorchDistributedCommunicator(bucket_cap_mb)

    if world_size == 1:
        group = None
    else:
        # Exclude rank 0
        group = torch.distributed.new_group(
            [i for i in range(world_size) if i >= 1],
        )
    group_size = torch.distributed.get_world_size(group)

    tensors = []
    for _ in range(tensor_count):
        t = torch.ones(shape, dtype=torch.float32)
        if group is None or rank > 0:
            tensors.append(
                comm.allreduce_bucketed(
                    t,
                    symmetric=symmetric,
                    group=group,
                ),
            )

    comm.flush_allreduce_buckets()
    # All buckets should be removed now so calling again shouldn't be
    # an issue
    comm.flush_allreduce_buckets()

    for tensor in tensors:
        if isinstance(tensor, Future):
            tensor = tensor.wait()
        assert isinstance(tensor, torch.Tensor)
        assert torch.sum(tensor).item() == group_size * torch.numel(
            tensor,
        )
def __init__(
    self,
    model: torch.nn.Module,
    *,
    factor_update_steps: Callable[[int], int] | int = 1,
    inv_update_steps: Callable[[int], int] | int = 1,
    # KFAC hyperparameters
    damping: Callable[[int], float] | float = 0.001,
    factor_decay: Callable[[int], float] | float = 0.95,
    kl_clip: Callable[[int], float] | float = 0.001,
    lr: Callable[[int], float] | float = 0.1,
    # Distribution strategy
    accumulation_steps: int = 1,
    allreduce_bucket_cap_mb: float = 25.0,
    assignment_strategy: (
        AssignmentStrategy | str
    ) = AssignmentStrategy.COMPUTE,
    compute_method: ComputeMethod | str = ComputeMethod.EIGEN,
    compute_eigenvalue_outer_product: bool = False,
    symmetry_aware: bool = False,
    # DeepSpeed 3D parallelism
    data_parallel_group: torch.distributed.ProcessGroup | None = None,
    model_parallel_group: torch.distributed.ProcessGroup | None = None,
    pipeline_parallel_group: torch.distributed.ProcessGroup | None = None,
    # Optional other parameters
    grad_scaler: (
        torch.cuda.amp.GradScaler | Callable[[], float] | None
    ) = None,
    factor_dtype: torch.dtype | None = None,
    inv_dtype: torch.dtype = torch.float32,
    factor_checkpoint_dir: str | None = None,
    skip_layers: list[str] | None = None,
    update_factors_in_hook: bool = True,
    loglevel: int = logging.DEBUG,
) -> None:
    """Init GPTNeoXKFACPreconditioner.

    Args:
        model (torch.nn.Module): model to precondition with KFAC.
        factor_update_steps (Callable, int): steps between computing and
            updating the running average of the Kronecker factors or
            callable that takes the K-FAC step and returns the value.
        inv_update_steps (Callable, int): steps between recomputing and
            communicating the second-order information or callable that
            takes the K-FAC step and returns the value.
        damping (Callable, float): Tikhonov damping parameter or a callable
            that takes the K-FAC step and returns the damping parameter as
            a float (default: 0.001).
        factor_decay (Callable, float): running average coefficient for
            Kronecker factors or callable that takes the K-FAC step and
            returns the factor_decay (default: 0.95).
        kl_clip (Callable, float): clipping parameter for gradient scaling
            or a callable that takes the K-FAC step and returns a float.
            If None, no scaling/clipping will be applied (default: 0.001).
        lr (Callable, float): learning rate or callable that takes the
            K-FAC step and returns the learning rate (default: 0.1).
        accumulation_steps (int): number of forward/backward passes between
            optimization steps (default: 1).
        allreduce_bucket_cap_mb (float): maximum size in megabytes for
            allreduce bucketing. If zero, bucketing is not used
            (default: 25).
        assignment_strategy (AssignmentStrategy, str): See
            `AssignmentStrategy` for more details
            (default: AssignmentStrategy.COMPUTE).
        compute_method (ComputeMethod, str): See `ComputeMethod` for more
            details (default: ComputeMethod.EIGEN).
        compute_eigenvalue_outer_product (bool): when using the eigen
            compute method, precompute the element-wise inverse of the
            outer product of eigenvectors on the eigen decomposition worker
            to reduce computation in the gradient preconditioning stage
            (default: False).
        symmetry_aware (bool): communicate only the upper triangle of
            symmetric matrices. Can reduce communication time when factors
            are large (default: False).
        data_parallel_group (ProcessGroup): DeepSpeed data parallel group.
        model_parallel_group (ProcessGroup): DeepSpeed model parallel
            group.
        pipeline_parallel_group (ProcessGroup): DeepSpeed pipeline parallel
            group.
        grad_scaler (torch.cuda.amp.GradScaler or callable): Gradient
            scaler used for Torch AMP training. Used to unscale the G
            factors as they are accumulated during the backward pass.
            Alternatively can be a callable which will return the current
            scale (default: None).
        factor_dtype (torch.dtype): force data type for storing factors.
            If None, defaults to data type of intermediate values in
            forward/backward pass (default: None).
        inv_dtype (torch.dtype): force data type for storing second-order
            data (e.g., inverses or eigen decompositions)
            (default: torch.float32).
        factor_checkpoint_dir (str): directory to store factor checkpoints
            in.
        skip_layers (list): list of module names to ignore when registering
            layers. Passing the name of parent modules will prevent
            recursively registering child modules of the parent.
            Case-insensitive (default: []).
        update_factors_in_hook (bool): If True, the running average of the
            factors is updated in the module hook and the asynchronous
            communication is started. Otherwise, this will be performed at
            the start of step() (default: True).
        loglevel (int): logging level (default: logging.DEBUG).
    """
    if deepspeed_import_error is not None:  # pragma: no cover
        raise deepspeed_import_error

    warnings.warn(
        'KFAC support for GPT-NeoX training is experimental.',
        ExperimentalFeatureWarning,
    )

    if not isinstance(model, PipelineModule):
        raise ValueError(
            'model must be an instance of deepspeed.pipe.PipelineModule. '
            f'Got an instance of {type(model)}.',
        )

    if allreduce_bucket_cap_mb < 0:
        raise ValueError('allreduce_bucket_cap_mb must be >= 0')
    if isinstance(assignment_strategy, str):
        assignment_strategy = AssignmentStrategy[
            assignment_strategy.upper()
        ]
    if isinstance(compute_method, str):
        compute_method = ComputeMethod[compute_method.upper()]

    self.allreduce_bucket_cap_mb = allreduce_bucket_cap_mb
    self.assignment_strategy = assignment_strategy
    self.compute_eigenvalue_outer_product = (
        compute_eigenvalue_outer_product
    )
    self.compute_method = compute_method
    self.grad_scaler = grad_scaler
    self.factor_dtype = factor_dtype
    self.inv_dtype = inv_dtype
    self.factor_checkpoint_dir = factor_checkpoint_dir
    self.skip_layers = [] if skip_layers is None else skip_layers
    self.symmetry_aware = symmetry_aware
    self.data_parallel_group = data_parallel_group
    self.model_parallel_group = model_parallel_group
    self.pipeline_parallel_group = pipeline_parallel_group

    if self.allreduce_bucket_cap_mb > 0:
        self.allreduce_method = AllreduceMethod.ALLREDUCE_BUCKETED
    else:
        self.allreduce_method = AllreduceMethod.ALLREDUCE
    self.tdc = TorchDistributedCommunicator(
        bucket_cap_mb=self.allreduce_bucket_cap_mb,
    )

    layer_kwargs = dict(
        allreduce_method=self.allreduce_method,
        grad_scaler=self.grad_scaler,
        factor_dtype=self.factor_dtype,
        inv_dtype=self.inv_dtype,
        symmetry_aware=self.symmetry_aware,
        tdc=self.tdc,
    )

    if self.compute_method == ComputeMethod.EIGEN:
        pass
    elif self.compute_method == ComputeMethod.INVERSE:
        raise ValueError('Inverse method not supported with GPT NeoX.')
    else:
        raise AssertionError(
            f'Unknown compute_method={self.compute_method}',
        )

    kfac_layers = register_modules(
        model,
        model_parallel_group=self.model_parallel_group,
        skip_layers=self.skip_layers,
        **layer_kwargs,
    )

    data_parallel_ranks = [
        -1 for _ in range(get_world_size(self.data_parallel_group))
    ]
    torch.distributed.all_gather_object(
        object_list=data_parallel_ranks,
        obj=get_rank(),
        group=self.data_parallel_group,
    )

    for name, kfac_layer in kfac_layers.values():
        logger.log(
            loglevel,
            f'Registered name="{name}": {repr(kfac_layer)} on '
            f'global-rank={get_rank()} and '
            f'pipeline-rank={get_rank(self.pipeline_parallel_group)}',
        )

    if self.assignment_strategy == AssignmentStrategy.COMPUTE:
        cost_func = lambda n: n**3  # noqa: E731
    elif self.assignment_strategy == AssignmentStrategy.MEMORY:
        cost_func = lambda n: n**2  # noqa: E731
    else:
        raise AssertionError(
            f'Unknown assignment_strategy={self.assignment_strategy}',
        )

    work = {
        name: {
            'A': cost_func(kfac_layer.module.a_factor_shape[0]),
            'G': cost_func(kfac_layer.module.g_factor_shape[0]),
        }
        for name, kfac_layer in kfac_layers.values()
    }

    assignment = GPTNeoXAssignment(
        work,
        local_rank=get_rank(),
        topology=model.topology(),
        data_parallel_group=self.data_parallel_group,
        model_parallel_group=self.model_parallel_group,
    )
    logger.log(loglevel, f'KFAC layer assignments: {assignment}')

    # Set primary rank for each layer
    for name, kfac_layer in kfac_layers.values():
        # Inv worker for A and G should be same because GPTNeoXAssignment
        # uses gradient worker fraction 1/world_size (mem-opt)
        if not isinstance(kfac_layer, GPTNeoXKFACEigenLayer):
            raise AssertionError(
                'GPTNeoXKFACPreconditioner only supports '
                'GPTNeoXKFACEigenLayer.',
            )
        kfac_layer.primary_rank = assignment.factor_worker(name, 'A')
        kfac_layer.data_parallel_group = assignment.data_parallel_group
        kfac_layer.pipe_parallel_peer_group = (
            assignment.pipe_parallel_peer_group
        )

    defaults = {
        'allreduce_bucket_cap_mb': self.allreduce_bucket_cap_mb,
        'allreduce_method': self.allreduce_method,
        'assignment_strategy': self.assignment_strategy,
        'compute_eigenvalue_outer_product': (
            self.compute_eigenvalue_outer_product
        ),
        'compute_method': self.compute_method,
        'grad_scaler': self.grad_scaler is not None,
        'factor_checkpoint_dir': self.factor_checkpoint_dir,
        'factor_dtype': self.factor_dtype,
        'inv_dtype': self.inv_dtype,
        'skip_layers': self.skip_layers,
        'symmetry_aware': self.symmetry_aware,
    }

    super().__init__(
        kfac_layers,
        factor_update_steps=factor_update_steps,
        inv_update_steps=inv_update_steps,
        factor_decay=factor_decay,
        damping=damping,
        kl_clip=kl_clip,
        lr=lr,
        accumulation_steps=accumulation_steps,
        assignment=assignment,
        update_factors_in_hook=update_factors_in_hook,
        defaults=defaults,
        tdc=self.tdc,
        loglevel=loglevel,
    )
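# A hedged construction sketch (not part of the source) for the constructor
# above. GPTNeoXKFACPreconditioner requires an already-built
# deepspeed.pipe.PipelineModule and the DeepSpeed parallel process groups, so
# this only shows how the preconditioner would be instantiated once those
# objects exist; `pipeline_model`, the three groups, and the hyperparameter
# values are assumptions created elsewhere by the GPT-NeoX/DeepSpeed setup.


def example_gpt_neox_kfac_setup(
    pipeline_model: 'PipelineModule',
    data_parallel_group: 'torch.distributed.ProcessGroup',
    model_parallel_group: 'torch.distributed.ProcessGroup',
    pipeline_parallel_group: 'torch.distributed.ProcessGroup',
) -> 'GPTNeoXKFACPreconditioner':
    """Hedged sketch: build the preconditioner for a pipeline-parallel model."""
    return GPTNeoXKFACPreconditioner(
        pipeline_model,
        factor_update_steps=10,
        inv_update_steps=100,
        damping=0.001,
        data_parallel_group=data_parallel_group,
        model_parallel_group=model_parallel_group,
        pipeline_parallel_group=pipeline_parallel_group,
    )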
def test_base_preconditioner_init_raises() -> None:
    """Test BaseKFACPreconditioner raises."""
    with pytest.raises(ValueError):
        BaseKFACPreconditioner(
            example_layers(),
            assignment=LazyAssignment(),
            tdc=TorchDistributedCommunicator(),
            factor_update_steps=-1,
        )

    with pytest.raises(ValueError):
        BaseKFACPreconditioner(
            example_layers(),
            assignment=LazyAssignment(),
            tdc=TorchDistributedCommunicator(),
            inv_update_steps=-1,
        )

    with pytest.raises(ValueError):
        BaseKFACPreconditioner(
            example_layers(),
            assignment=LazyAssignment(),
            tdc=TorchDistributedCommunicator(),
            damping=-1,
        )

    with pytest.raises(ValueError):
        BaseKFACPreconditioner(
            example_layers(),
            assignment=LazyAssignment(),
            tdc=TorchDistributedCommunicator(),
            factor_decay=-1,
        )

    with pytest.raises(ValueError):
        BaseKFACPreconditioner(
            example_layers(),
            assignment=LazyAssignment(),
            tdc=TorchDistributedCommunicator(),
            factor_decay=2,
        )

    with pytest.raises(ValueError):
        BaseKFACPreconditioner(
            example_layers(),
            assignment=LazyAssignment(),
            tdc=TorchDistributedCommunicator(),
            kl_clip=-1,
        )

    with pytest.raises(ValueError):
        BaseKFACPreconditioner(
            example_layers(),
            assignment=LazyAssignment(),
            tdc=TorchDistributedCommunicator(),
            lr=-1,
        )

    with pytest.raises(ValueError):
        BaseKFACPreconditioner(
            example_layers(),
            assignment=LazyAssignment(),
            tdc=TorchDistributedCommunicator(),
            accumulation_steps=-1,
        )

    with pytest.warns():
        BaseKFACPreconditioner(
            example_layers(),
            assignment=LazyAssignment(),
            tdc=TorchDistributedCommunicator(),
            factor_update_steps=3,
            inv_update_steps=2,
        )
def e2e() -> None:
    """Helper to run training in simulated distributed environment."""
    batch_size = 2
    model = TinyModel()
    criterion = torch.nn.MSELoss(reduction='sum')
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
    tdc = TorchDistributedCommunicator()

    layers = register_modules(
        model,
        KFACInverseLayer,
        allreduce_method=AllreduceMethod.ALLREDUCE,
        grad_scaler=None,
        factor_dtype=None,
        inv_dtype=torch.float32,
        skip_layers=[],
        symmetry_aware=False,
        tdc=tdc,
    )
    preconditioner = BaseKFACPreconditioner(
        layers=layers,
        assignment=LazyAssignment(broadcast=broadcast),
        tdc=tdc,
        accumulation_steps=accumulation_steps,
        **kfac_args,
    )

    for i in range(1, 10):
        x = torch.rand(batch_size, 10)
        y = torch.rand(batch_size, 10)
        y_pred = model(x)
        if i % accumulation_steps == 0:
            loss = criterion(y_pred, y)
            loss.backward()
            grad_weight_linear2 = model.linear2.weight.grad
            grad_bias_linear2 = model.linear2.bias.grad
            preconditioner.step()
            # Verify gradient was preconditioned
            assert not torch.equal(
                grad_weight_linear2,
                model.linear2.weight.grad,
            )
            assert not torch.equal(
                grad_bias_linear2,
                model.linear2.bias.grad,
            )
            optimizer.step()
            optimizer.zero_grad()

    # Test state dict computes inverses
    state_dict = preconditioner.state_dict()
    for _, layer in preconditioner._layers.values():
        layer = cast(KFACInverseLayer, layer)
        layer.a_factor = None
        layer.g_factor = None
        layer.a_inv = None
        layer.g_inv = None
    preconditioner.load_state_dict(state_dict)
    for _, layer in preconditioner._layers.values():
        layer = cast(KFACInverseLayer, layer)
        assert isinstance(layer.a_inv, torch.Tensor)
        assert isinstance(layer.g_inv, torch.Tensor)

    # Test grad hook supports tensor input rather than tuple
    preconditioner._save_grad_output(
        model.linear1,
        torch.rand(batch_size, 10),
        torch.rand(batch_size, 20),
    )

    # Test hook additional functionality
    if preconditioner._update_factors_in_hook:
        # Reset preconditioner to ensure hooks trigger
        preconditioner._steps = 0
        preconditioner._mini_steps = defaultdict(int)
        preconditioner._accumulation_steps = 100

        # Do forward/backward pass to verify hooks trigger and we
        # have temp factors for batch
        x = torch.rand(batch_size, 10)
        y = torch.rand(batch_size, 10)
        loss = criterion(model(x), y)
        loss.backward()
        mem_usage = preconditioner.memory_usage()
        for mem in mem_usage.values():
            assert mem > 0
        preconditioner.reset_batch()

        # Make sure hooks do not trigger when model is not in training mode
        model.eval()
        x = torch.rand(batch_size, 10)
        y = torch.rand(batch_size, 10)
        loss = criterion(model(x), y)
        loss.backward()
        mem_usage = preconditioner.memory_usage()
        for key, mem in mem_usage.items():
            if 'batch' in key:
                assert mem == 0
def test_empty_state_dict() -> None:
    """Test state dict functionality with no factors."""
    p1 = BaseKFACPreconditioner(
        layers=example_layers(),
        assignment=LazyAssignment(),
        tdc=TorchDistributedCommunicator(),
        factor_update_steps=1,
        inv_update_steps=3,
        damping=5,
        factor_decay=0.7,
        kl_clip=11,
        lr=13,
        accumulation_steps=17,
        update_factors_in_hook=False,
        defaults={'default1': 19},
    )
    p1._steps = 99
    state_dict = p1.state_dict(include_factors=False)

    # include_factors=True should add entries for the factors even though
    # they are None at this point
    assert state_dict != p1.state_dict(include_factors=True)

    # We filled p1 with non-default values so we can load the
    # state_dict of p1 into p2 and see what is loaded
    p2 = BaseKFACPreconditioner(
        layers=example_layers(),
        assignment=LazyAssignment(),
        tdc=TorchDistributedCommunicator(),
    )
    p2.load_state_dict(state_dict, compute_inverses=False)

    assert p1.factor_update_steps == p2.factor_update_steps
    assert p1.inv_update_steps == p2.inv_update_steps
    assert p1.damping == p2.damping
    assert p1.factor_decay == p2.factor_decay
    assert p1.kl_clip == p2.kl_clip
    assert p1.lr == p2.lr

    # We only load the hyperparameters and training state
    assert p1._accumulation_steps != p2._accumulation_steps
    assert p1._update_factors_in_hook != p2._update_factors_in_hook
    assert p1._defaults != p2._defaults

    # Steps should be loaded
    assert p1._steps == p2._steps

    p3 = BaseKFACPreconditioner(
        layers=example_layers(),
        assignment=LazyAssignment(),
        tdc=TorchDistributedCommunicator(),
        factor_update_steps=lambda x: 1,
        inv_update_steps=lambda x: 3,
        damping=lambda x: 5,
        factor_decay=lambda x: 0.7,
        kl_clip=lambda x: 11,
        lr=lambda x: 13,
    )
    state_dict = p3.state_dict()
    assert 'factor_update_steps' not in state_dict
    assert 'inv_update_steps' not in state_dict
    assert 'damping' not in state_dict
    assert 'factor_decay' not in state_dict
    assert 'kl_clip' not in state_dict
    assert 'lr' not in state_dict
    p3.load_state_dict(state_dict, compute_inverses=False)
    assert p3.factor_update_steps == 1
    assert p3.inv_update_steps == 3
    assert p3.damping == 5
    assert p3.factor_decay == 0.7
    assert p3.kl_clip == 11
    assert p3.lr == 13

    # Check warns user if they set compute_inverses but the state dict was
    # created with include_factors=False
    with pytest.warns():
        p3.load_state_dict({'steps': 0}, compute_inverses=True)

    # Should cause no problems... but doesn't do much but set steps!
    p3.load_state_dict({'steps': 0}, compute_inverses=False)

    # Check mismatch in registered layers
    del state_dict['layers'][list(state_dict['layers'].keys()).pop()]
    with pytest.raises(ValueError):
        p3.load_state_dict(state_dict, compute_inverses=False)
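# The state-dict test above exercises checkpointing; below is a small hedged
# sketch (not taken from the source) of the corresponding save/restore pattern
# built on the state_dict()/load_state_dict() methods shown in that test. The
# function name and the default file path are placeholders.


def save_and_restore_preconditioner(
    preconditioner: BaseKFACPreconditioner,
    path: str = 'kfac_state.pt',
) -> None:
    """Hedged sketch: round-trip a preconditioner state dict through a file."""
    import torch

    # Persist hyperparameters, step count, and (optionally) the factors
    torch.save(preconditioner.state_dict(include_factors=True), path)

    # Restore into the same (or an identically constructed) preconditioner;
    # compute_inverses=True recomputes second-order data from the factors
    state_dict = torch.load(path)
    preconditioner.load_state_dict(state_dict, compute_inverses=True)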
def __init__(
    self,
    model: torch.nn.Module,
    *,
    factor_update_steps: Callable[[int], int] | int = 1,
    inv_update_steps: Callable[[int], int] | int = 1,
    # KFAC hyperparameters
    damping: Callable[[int], float] | float = 0.001,
    factor_decay: Callable[[int], float] | float = 0.95,
    kl_clip: Callable[[int], float] | float = 0.001,
    lr: Callable[[int], float] | float = 0.1,
    # Distribution strategy
    accumulation_steps: int = 1,
    allreduce_bucket_cap_mb: float = 25.0,
    assignment_strategy: (
        AssignmentStrategy | str
    ) = AssignmentStrategy.COMPUTE,
    colocate_factors: bool = True,
    compute_method: ComputeMethod | str = ComputeMethod.EIGEN,
    compute_eigenvalue_outer_product: bool = True,
    grad_worker_fraction: (
        DistributedStrategy | float
    ) = DistributedStrategy.COMM_OPT,
    symmetry_aware: bool = False,
    # Optional other parameters
    grad_scaler: (
        torch.cuda.amp.GradScaler | Callable[[], float] | None
    ) = None,
    factor_dtype: torch.dtype | None = None,
    inv_dtype: torch.dtype = torch.float32,
    skip_layers: list[str] | None = None,
    update_factors_in_hook: bool = True,
    loglevel: int = logging.DEBUG,
) -> None:
    """Init KFACPreconditioner.

    Args:
        model (torch.nn.Module): model to precondition with KFAC.
        factor_update_steps (Callable, int): steps between computing and
            updating the running average of the Kronecker factors or
            callable that takes the K-FAC step and returns the value.
        inv_update_steps (Callable, int): steps between recomputing and
            communicating the second-order information or callable that
            takes the K-FAC step and returns the value.
        damping (Callable, float): Tikhonov damping parameter or a callable
            that takes the K-FAC step and returns the damping parameter as
            a float (default: 0.001).
        factor_decay (Callable, float): running average coefficient for
            Kronecker factors or callable that takes the K-FAC step and
            returns the factor_decay (default: 0.95).
        kl_clip (Callable, float): clipping parameter for gradient scaling
            or a callable that takes the K-FAC step and returns a float.
            If None, no scaling/clipping will be applied (default: 0.001).
        lr (Callable, float): learning rate or callable that takes the
            K-FAC step and returns the learning rate (default: 0.1).
        accumulation_steps (int): number of forward/backward passes between
            optimization steps (default: 1).
        allreduce_bucket_cap_mb (float): maximum size in megabytes for
            allreduce bucketing. If zero, bucketing is not used
            (default: 25).
        assignment_strategy (AssignmentStrategy, str): See
            `AssignmentStrategy` for more details
            (default: AssignmentStrategy.COMPUTE).
        colocate_factors (bool): assign both factors for a single layer to
            the same worker. Recommended when num_layers < world_size
            (default: True).
        compute_method (ComputeMethod, str): See `ComputeMethod` for more
            details (default: ComputeMethod.EIGEN).
        compute_eigenvalue_outer_product (bool): when using the eigen
            compute method, precompute the element-wise inverse of the
            outer product of eigenvectors on the eigen decomposition worker
            to reduce computation in the gradient preconditioning stage.
            `colocate_factors` must be True (default: True).
        grad_worker_fraction (DistributedStrategy, float): controls the
            fraction of workers assigned as gradient workers for each
            layer. Optionally, predefined configurations can be passed
            using the DistributedStrategy enum
            (default: DistributedStrategy.COMM_OPT).
        symmetry_aware (bool): communicate only the upper triangle of
            symmetric matrices. Can reduce communication time when factors
            are large (default: False).
        grad_scaler (torch.cuda.amp.GradScaler or callable): Gradient
            scaler used for Torch AMP training. Used to unscale the G
            factors as they are accumulated during the backward pass.
            Alternatively can be a callable which will return the current
            scale (default: None).
        factor_dtype (torch.dtype): force data type for storing factors.
            If None, defaults to data type of intermediate values in
            forward/backward pass (default: None).
        inv_dtype (torch.dtype): force data type for storing second-order
            data (e.g., inverses or eigen decompositions)
            (default: torch.float32).
        skip_layers (list[str]): regex patterns that, if matched, will
            cause the layer to not be registered. The patterns will be
            applied against the layer's name and class name.
        update_factors_in_hook (bool): If True, the running average of the
            factors is updated in the module hook and the asynchronous
            communication is started. Otherwise, this will be performed at
            the start of step() (default: True).
        loglevel (int): logging level (default: logging.DEBUG).
    """
    if allreduce_bucket_cap_mb < 0:
        raise ValueError('allreduce_bucket_cap_mb must be >= 0')
    if (
        compute_method == ComputeMethod.EIGEN
        and compute_eigenvalue_outer_product
        and not colocate_factors
    ):
        raise ValueError(
            'colocate_factors must be True to use '
            'compute_eigenvalue_outer_product',
        )
    if isinstance(assignment_strategy, str):
        assignment_strategy = AssignmentStrategy[
            assignment_strategy.upper()
        ]
    if isinstance(compute_method, str):
        compute_method = ComputeMethod[compute_method.upper()]

    size = get_world_size()
    if isinstance(grad_worker_fraction, DistributedStrategy):
        distributed_strategy = grad_worker_fraction
        if distributed_strategy == DistributedStrategy.COMM_OPT:
            grad_worker_fraction = 1.0
        elif distributed_strategy == DistributedStrategy.HYBRID_OPT:
            grad_worker_fraction = 0.5
        elif distributed_strategy == DistributedStrategy.MEM_OPT:
            grad_worker_fraction = 1.0 / size
        else:
            raise AssertionError(f'Unknown enum {grad_worker_fraction}')
    else:
        if not 0 <= grad_worker_fraction <= 1:
            raise ValueError('grad_worker_fraction must be in [0, 1]')
        if grad_worker_fraction == 0:
            grad_worker_fraction = 1.0 / size
        if size % max(1, round(size * grad_worker_fraction)) != 0:
            raise ValueError(
                'grad_worker_fraction must produce groups of '
                'equal size',
            )
        if grad_worker_fraction == 1:
            grad_worker_fraction = 1.0  # ensure float
            distributed_strategy = DistributedStrategy.COMM_OPT
        elif grad_worker_fraction <= 1 / size:
            distributed_strategy = DistributedStrategy.MEM_OPT
        else:
            distributed_strategy = DistributedStrategy.HYBRID_OPT
    assert isinstance(grad_worker_fraction, float)

    if (
        not colocate_factors
        and distributed_strategy is DistributedStrategy.MEM_OPT
    ):
        warnings.warn(
            'grad_worker_frac=1/world_size (MEM_OPT) requires '
            'colocate_factors=True. Enabling colocate_factors.',
        )
        colocate_factors = True

    self.allreduce_bucket_cap_mb = allreduce_bucket_cap_mb
    self.assignment_strategy = assignment_strategy
    self.colocate_factors = colocate_factors
    self.compute_eigenvalue_outer_product = (
        compute_eigenvalue_outer_product
    )
    self.compute_method = compute_method
    self.distributed_strategy = distributed_strategy
    self.grad_worker_fraction = grad_worker_fraction
    self.grad_scaler = grad_scaler
    self.factor_dtype = factor_dtype
    self.inv_dtype = inv_dtype
    self.skip_layers = [] if skip_layers is None else skip_layers
    self.symmetry_aware = symmetry_aware

    if self.allreduce_bucket_cap_mb > 0:
        self.allreduce_method = AllreduceMethod.ALLREDUCE_BUCKETED
    else:
        self.allreduce_method = AllreduceMethod.ALLREDUCE
    self.tdc = TorchDistributedCommunicator(
        bucket_cap_mb=self.allreduce_bucket_cap_mb,
    )

    layer_kwargs = dict(
        allreduce_method=self.allreduce_method,
        grad_scaler=self.grad_scaler,
        factor_dtype=self.factor_dtype,
        inv_dtype=self.inv_dtype,
        symmetry_aware=self.symmetry_aware,
        tdc=self.tdc,
    )

    layer_type: type[KFACBaseLayer]
    if self.compute_method == ComputeMethod.EIGEN:
        layer_type = KFACEigenLayer
        layer_kwargs['prediv_eigenvalues'] = (
            self.compute_eigenvalue_outer_product
        )
    elif self.compute_method == ComputeMethod.INVERSE:
        layer_type = KFACInverseLayer
    else:
        raise AssertionError(
            f'Unknown compute_method={self.compute_method}',
        )

    kfac_layers = register_modules(
        model,
        kfac_layer_type=layer_type,
        skip_layers=self.skip_layers,
        **layer_kwargs,
    )
    for name, kfac_layer in kfac_layers.values():
        logger.log(
            loglevel,
            f'Registered name="{name}": {repr(kfac_layer)}',
        )

    if self.assignment_strategy == AssignmentStrategy.COMPUTE:
        cost_func = lambda n: n**3  # noqa: E731
    elif self.assignment_strategy == AssignmentStrategy.MEMORY:
        cost_func = lambda n: n**2  # noqa: E731
    else:
        raise AssertionError(
            f'Unknown assignment_strategy={self.assignment_strategy}',
        )

    work = {
        name: {
            'A': cost_func(kfac_layer.module.a_factor_shape[0]),
            'G': cost_func(kfac_layer.module.g_factor_shape[0]),
        }
        for name, kfac_layer in kfac_layers.values()
    }

    new_group = cast(
        Callable[[List[int]], dist.ProcessGroup],
        dist.new_group,
    )
    mock_new_group: Callable[[list[int]], None] = lambda x: None
    assignment = KAISAAssignment(
        work,
        local_rank=get_rank(),
        world_size=get_world_size(),
        grad_worker_fraction=self.grad_worker_fraction,
        group_func=new_group if dist.is_initialized() else mock_new_group,
        colocate_factors=self.colocate_factors,
    )
    logger.log(loglevel, f'KFAC layer assignments: {assignment}')

    defaults = {
        'allreduce_bucket_cap_mb': self.allreduce_bucket_cap_mb,
        'allreduce_method': self.allreduce_method,
        'assignment_strategy': self.assignment_strategy,
        'colocate_factors': self.colocate_factors,
        'compute_eigenvalue_outer_product': (
            self.compute_eigenvalue_outer_product
        ),
        'compute_method': self.compute_method,
        'distributed_strategy': self.distributed_strategy,
        'grad_worker_fraction': self.grad_worker_fraction,
        'grad_scaler': self.grad_scaler is not None,
        'factor_dtype': self.factor_dtype,
        'inv_dtype': self.inv_dtype,
        'skip_layers': self.skip_layers,
        'symmetry_aware': self.symmetry_aware,
    }

    super().__init__(
        kfac_layers,
        factor_update_steps=factor_update_steps,
        inv_update_steps=inv_update_steps,
        factor_decay=factor_decay,
        damping=damping,
        kl_clip=kl_clip,
        lr=lr,
        accumulation_steps=accumulation_steps,
        assignment=assignment,
        update_factors_in_hook=update_factors_in_hook,
        defaults=defaults,
        tdc=self.tdc,
        loglevel=loglevel,
    )
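# The constructor above documents a fairly large API surface; the following is
# a minimal, hedged usage sketch (not taken from the source) of how a
# KFACPreconditioner is typically wired into a training loop. It assumes the
# upstream package layout (`from kfac.preconditioner import
# KFACPreconditioner`) and a single process, so no torch.distributed setup is
# shown; the model and hyperparameter values are arbitrary.


def example_kfac_training_step() -> None:
    """Hedged sketch: precondition gradients with K-FAC before SGD."""
    import torch

    from kfac.preconditioner import KFACPreconditioner  # assumed import path

    model = torch.nn.Sequential(
        torch.nn.Linear(10, 20),
        torch.nn.ReLU(),
        torch.nn.Linear(20, 10),
    )
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    preconditioner = KFACPreconditioner(
        model,
        factor_update_steps=10,
        inv_update_steps=100,
        lr=lambda step: optimizer.param_groups[0]['lr'],
    )

    x = torch.rand(4, 10)
    y = torch.rand(4, 10)
    loss = torch.nn.functional.mse_loss(model(x), y)

    optimizer.zero_grad()
    loss.backward()
    preconditioner.step()  # precondition the gradients in-place
    optimizer.step()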
def precondition() -> None:
    """Precondition layer in distributed environment."""
    in_features = 10
    out_features = 5
    batch_size = 2
    module = torch.nn.Linear(in_features, out_features)
    module_helper = LinearModuleHelper(module)
    layer = kfac_layer(
        module=module_helper,
        tdc=TorchDistributedCommunicator(),
        **kwargs,
    )

    # Compute gradient
    x = torch.rand([batch_size, in_features])
    y = torch.rand([batch_size, out_features])
    loss = (module(x) - y).sum()
    loss.backward()
    weight_grad = module.weight.grad
    bias_grad = module.bias.grad

    # Stage 1: save intermediate variables
    layer.save_layer_input([x])
    layer.save_layer_input([x])
    layer.save_layer_grad_output((y,))
    layer.save_layer_grad_output((y,))

    # Stage 2: compute factors
    layer.update_a_factor()
    layer.update_g_factor()
    if 'factor_dtype' in kwargs:
        assert (
            layer.a_factor is not None
            and layer.a_factor.dtype == kwargs['factor_dtype']
        )
        assert (
            layer.g_factor is not None
            and layer.g_factor.dtype == kwargs['factor_dtype']
        )

    # Stage 3: reduce factors
    layer.reduce_a_factor()
    layer.reduce_g_factor()
    if layer.allreduce_method == AllreduceMethod.ALLREDUCE_BUCKETED:
        layer.tdc.flush_allreduce_buckets()

    # Stage 4: compute second-order info
    if dist.get_rank() == 0:
        layer.compute_a_inv()
        layer.compute_g_inv()

    # Stage 5: communicate second-order info
    if world_size > 1 and strategy == DistributedStrategy.COMM_OPT:
        layer.broadcast_a_inv(src=0)
        layer.broadcast_g_inv(src=0)

    if 'inv_dtype' in kwargs:
        if kfac_layer == KFACInverseLayer:
            layer = cast(KFACInverseLayer, layer)
            assert (
                layer.a_inv is not None
                and layer.a_inv.dtype == kwargs['inv_dtype']
            )
            assert (
                layer.g_inv is not None
                and layer.g_inv.dtype == kwargs['inv_dtype']
            )
        elif kfac_layer == KFACEigenLayer:
            layer = cast(KFACEigenLayer, layer)
            assert (
                layer.qa is not None
                and layer.qa.dtype == kwargs['inv_dtype']
            )
            assert (
                layer.qg is not None
                and layer.qg.dtype == kwargs['inv_dtype']
            )
            if (
                'prediv_eigenvalues' in kwargs
                and kwargs['prediv_eigenvalues']
            ):
                assert (
                    layer.dgda is not None
                    and layer.dgda.dtype == kwargs['inv_dtype']
                )
            else:
                assert (
                    layer.da is not None
                    and layer.da.dtype == kwargs['inv_dtype']
                )
                assert (
                    layer.dg is not None
                    and layer.dg.dtype == kwargs['inv_dtype']
                )
        else:
            raise AssertionError

    # Stage 6: compute and communicate preconditioned gradient
    if strategy == DistributedStrategy.COMM_OPT or (
        strategy == DistributedStrategy.MEM_OPT and dist.get_rank() == 0
    ):
        layer.preconditioned_grad()
    if strategy == DistributedStrategy.MEM_OPT:
        layer.broadcast_grad(src=0)

    # Stage 7: update gradient
    layer.update_grad()

    # Make sure gradient changed due to preconditioning
    assert not torch.equal(weight_grad, module.weight.grad)
    assert not torch.equal(bias_grad, module.bias.grad)
def test_kfac_layers(layer_type: type[KFACBaseLayer]) -> None:
    """Test KFACBaseLayer implementation."""
    batch_size, in_features, out_features = 2, 5, 5
    module = torch.nn.Linear(in_features, out_features)
    x = torch.rand([batch_size, in_features])
    y = torch.rand([batch_size, out_features])
    loss = (module(x) - y).sum()
    loss.backward()

    module_helper = LinearModuleHelper(module)
    layer = layer_type(
        module=module_helper,
        tdc=TorchDistributedCommunicator(),
    )

    assert 'LinearModuleHelper' in repr(layer)
    assert layer_type.__name__ in repr(layer)

    # Cannot reduce factors, update gradient, or compute inverses
    with pytest.raises(RuntimeError):
        layer.reduce_a_factor()
    with pytest.raises(RuntimeError):
        layer.reduce_g_factor()
    with pytest.raises(RuntimeError):
        layer.update_grad()
    with pytest.raises(RuntimeError):
        layer.compute_a_inv()
    with pytest.raises(RuntimeError):
        layer.compute_g_inv()
    with pytest.raises(RuntimeError):
        layer.preconditioned_grad()

    # Broadcasts should fail because src rank has not computed the data
    with mock.patch('torch.distributed.get_rank', return_value=0):
        with pytest.raises(RuntimeError):
            layer.broadcast_grad(src=0)
        with pytest.raises(RuntimeError):
            layer.broadcast_a_inv(src=0)
        with pytest.raises(RuntimeError):
            layer.broadcast_g_inv(src=0)

    state_dict = layer.state_dict()
    assert 'A' in state_dict and state_dict['A'] is None
    assert 'G' in state_dict and state_dict['G'] is None
    with pytest.raises(KeyError):
        # state_dict must have A and G keys
        layer.load_state_dict({})
    layer.load_state_dict(state_dict)

    mem_usage = layer.memory_usage()
    for key in mem_usage:
        assert mem_usage[key] == 0

    layer.save_layer_input([x])
    layer.save_layer_grad_output((y,))

    # Layer memory usage should reflect temp factors for current batch
    mem_usage = layer.memory_usage()
    assert mem_usage['a_batch'] > 0
    assert mem_usage['g_batch'] > 0

    # Check clear current batch
    layer.reset_batch()
    mem_usage = layer.memory_usage()
    assert mem_usage['a_batch'] == 0
    assert mem_usage['g_batch'] == 0

    # Should not raise an error even though no batch data has been
    # accumulated
    layer.update_a_factor()
    layer.update_g_factor()

    # Repeat twice: the first pass initializes the factors, the second adds
    # to the factors
    layer.save_layer_input([x])
    layer.save_layer_grad_output((y,))
    layer.update_a_factor()
    layer.update_g_factor()
    layer.save_layer_input([x])
    layer.save_layer_grad_output((y,))
    layer.update_a_factor()
    layer.update_g_factor()

    # The current batch has been flushed so the factors should have size but
    # there are no temp factors anymore
    mem_usage = layer.memory_usage()
    assert mem_usage['a_factors'] > 0
    assert mem_usage['g_factors'] > 0
    assert mem_usage['a_batch'] == 0
    assert mem_usage['g_batch'] == 0

    state_dict = layer.state_dict()
    assert isinstance(state_dict['A'], torch.Tensor)
    assert isinstance(state_dict['G'], torch.Tensor)
    layer.load_state_dict(state_dict)
    assert layer.a_factor is not None
    assert layer.g_factor is not None
    assert torch.equal(layer.a_factor, state_dict['A'])
    assert torch.equal(layer.g_factor, state_dict['G'])

    # Check gradient scaling. We haven't computed the preconditioned gradient
    # so just fake one
    grad = module_helper.get_grad()
    layer.grad = grad
    layer.update_grad(scale=10)
    assert torch.equal(10 * grad, module_helper.get_grad())