def test_base_preconditioner_init() -> None:
    """Test BaseKFACPreconditioner initialize."""
    factor_update_steps = 1
    inv_update_steps = 2
    damping = 0.003
    factor_decay = 0.95
    kl_clip = 0.001
    lr = 0.1
    accumulation_steps = 1

    preconditioner = BaseKFACPreconditioner(
        layers=example_layers(),
        assignment=LazyAssignment(),
        tdc=TorchDistributedCommunicator(),
        factor_update_steps=factor_update_steps,
        inv_update_steps=inv_update_steps,
        damping=damping,
        factor_decay=factor_decay,
        kl_clip=kl_clip,
        lr=lr,
        accumulation_steps=accumulation_steps,
    )
    assert preconditioner.damping == damping
    assert preconditioner.factor_decay == factor_decay
    assert preconditioner.kl_clip == kl_clip
    assert preconditioner.lr == lr
    assert preconditioner.factor_update_steps == factor_update_steps
    assert preconditioner.inv_update_steps == inv_update_steps
    assert preconditioner.steps == 0

    preconditioner = BaseKFACPreconditioner(
        layers=example_layers(),
        assignment=LazyAssignment(),
        tdc=TorchDistributedCommunicator(),
        damping=lambda x: damping,
        factor_decay=lambda x: factor_decay,
        kl_clip=lambda x: kl_clip,
        lr=lambda x: lr,
    )
    assert preconditioner.damping == damping
    assert preconditioner.factor_decay == factor_decay
    assert preconditioner.kl_clip == kl_clip
    assert preconditioner.lr == lr

    defaults = {'default1': None, 'default2': None}
    preconditioner2 = BaseKFACPreconditioner(
        layers=example_layers(),
        assignment=LazyAssignment(),
        tdc=TorchDistributedCommunicator(),
        defaults=defaults,
    )

    assert repr(preconditioner) != repr(preconditioner2)
    # repr() should list all parameters including those passed in with the
    # defaults parameter
    assert 'default1' in repr(preconditioner2)
    assert 'default2' in repr(preconditioner2)
def test_nonsymmetric_eigen() -> None:
    """Test nonsymmetric eigen decomposition."""
    batch_size, in_features, out_features = 2, 5, 5
    module = torch.nn.Linear(in_features, out_features)
    x = torch.rand([batch_size, in_features])
    y = torch.rand([batch_size, out_features])
    loss = (module(x) - y).sum()
    loss.backward()

    with mock.patch.object(
            LinearModuleHelper,
            'has_symmetric_factors',
            return_value=False,
    ):
        module_helper = LinearModuleHelper(module)
        layer = KFACEigenLayer(
            module=module_helper,
            tdc=TorchDistributedCommunicator(),
        )
    assert not layer.symmetric_factors

    layer.save_layer_input([x])
    layer.save_layer_grad_output((y, ))
    layer.update_a_factor()
    layer.update_g_factor()
    layer.compute_a_inv()
    layer.compute_g_inv()
    layer.preconditioned_grad()
    layer.update_grad()
    def allreduce(
        shape: list[int],
        tensor_count: int,
        bucket_cap_mb: float,
        symmetric: bool = False,
        expect_raises: type[BaseException] | None = None,
    ) -> None:
        """Test allreduce in distributed environment."""
        try:
            world_size = torch.distributed.get_world_size()
            comm = TorchDistributedCommunicator(bucket_cap_mb)

            tensors = []
            for _ in range(tensor_count):
                t = torch.ones(shape, dtype=torch.float32)
                tensors.append(
                    comm.allreduce_bucketed(t, symmetric=symmetric),
                )
            if world_size > 1:
                with pytest.raises(RuntimeError):
                    comm._new_allreduce_bucket(None)
            comm.flush_allreduce_buckets()

            for tensor in tensors:
                if isinstance(tensor, Future):
                    tensor = tensor.wait()
                assert isinstance(tensor, torch.Tensor)
                assert torch.sum(tensor).item() == world_size * torch.numel(
                    tensor, )
        except Exception as e:
            if expect_raises is not None and not isinstance(e, expect_raises):
                raise
    def simple_allreduce(
        shape: list[int],
        symmetric: bool = False,
        expect_raises: type[BaseException] | None = None,
    ) -> None:
        try:
            world_size = torch.distributed.get_world_size()
            comm = TorchDistributedCommunicator()

            t = torch.ones(shape)
            t_res = comm.allreduce(t, symmetric=symmetric)
            if isinstance(t_res, Future):
                t_res = t_res.wait()
            assert isinstance(t_res, torch.Tensor)
            assert torch.sum(t_res).item() == torch.numel(t_res) * world_size
        except Exception as e:
            if expect_raises is not None and not isinstance(e, expect_raises):
                raise
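# A minimal local sketch of exercising TorchDistributedCommunicator.allreduce()
# outside of a multi-worker launcher, assuming a single-process 'gloo' group
# initialized over a file store. The helper name and init_method path are
# illustrative only; the allreduce/Future handling mirrors simple_allreduce().
def _allreduce_smoke_test(tmp_file: str) -> None:
    torch.distributed.init_process_group(
        backend='gloo',
        init_method=f'file://{tmp_file}',
        rank=0,
        world_size=1,
    )
    try:
        comm = TorchDistributedCommunicator()
        t_res = comm.allreduce(torch.ones([4, 4]))
        if isinstance(t_res, Future):
            t_res = t_res.wait()
        assert isinstance(t_res, torch.Tensor)
        # With a single rank, allreduce is an identity: sum == numel
        assert torch.sum(t_res).item() == 16
    finally:
        torch.distributed.destroy_process_group()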
def test_grad_scale_no_layers() -> None:
    """Test computing grad scale with no layers has no divide by 0 error."""
    p = BaseKFACPreconditioner(
        layers=example_layers(),
        assignment=LazyAssignment(),
        tdc=TorchDistributedCommunicator(),
    )
    p._layers = {}
    assert p._compute_grad_scale() == 1.0
    def simple_broadcast(
        shape: list[int],
        symmetric: bool = False,
        expect_raises: type[BaseException] | None = None,
    ) -> None:
        try:
            rank = torch.distributed.get_rank()
            comm = TorchDistributedCommunicator()

            t = rank * torch.ones(shape)
            t_res = comm.broadcast(t, src=0, symmetric=symmetric)
            if isinstance(t_res, Future):
                t_res = t_res.wait()
            assert isinstance(t_res, torch.Tensor)
            # Rank 0 will broadcast and it should be all zeros
            assert torch.sum(t_res).item() == 0
        except Exception as e:
            if expect_raises is not None and not isinstance(e, expect_raises):
                raise
def test_base_preconditioner_callable_hyperparams() -> None:
    """Test BaseKFACPreconditioner supports callable hyperparams."""
    p = BaseKFACPreconditioner(
        example_layers(),
        assignment=LazyAssignment(),
        tdc=TorchDistributedCommunicator(),
        factor_update_steps=lambda x: x * 2,
        inv_update_steps=lambda x: x * 3,
        damping=lambda x: x * 5,
        factor_decay=lambda x: x * 7,
        kl_clip=lambda x: x * 9,
    )

    for x in range(0, 10):
        p._steps = x
        assert p.factor_update_steps == x * 2
        assert p.inv_update_steps == x * 3
        assert p.damping == x * 5
        assert p.factor_decay == x * 7
        assert p.kl_clip == x * 9

    p = BaseKFACPreconditioner(
        example_layers(),
        assignment=LazyAssignment(),
        tdc=TorchDistributedCommunicator(),
        factor_update_steps=lambda x: 2,
        inv_update_steps=lambda x: 3,
        damping=lambda x: 5,
        factor_decay=lambda x: 7,
        kl_clip=lambda x: 9,
    )

    for x in range(0, 10):
        p._steps = x
        assert p.factor_update_steps == 2
        assert p.inv_update_steps == 3
        assert p.damping == 5
        assert p.factor_decay == 7
        assert p.kl_clip == 9
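# The hyperparameter callables above receive the current K-FAC step, so
# step-dependent schedules can be expressed directly. A small illustrative
# sketch (the schedule values are arbitrary, not recommendations):
def example_damping_schedule(step: int) -> float:
    """Decay damping by 10x after 100 K-FAC steps."""
    return 0.01 if step < 100 else 0.001


# e.g. BaseKFACPreconditioner(..., damping=example_damping_schedule)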
def example_layers() -> dict[torch.nn.Module, tuple[str, KFACBaseLayer]]:
    """Return register layers of LeNet with KFAC."""
    return register_modules(
        LeNet(),
        kfac_layer_type=KFACInverseLayer,
        allreduce_method=AllreduceMethod.ALLREDUCE,
        grad_scaler=None,
        factor_dtype=None,
        inv_dtype=torch.float32,
        skip_layers=[],
        symmetry_aware=False,
        tdc=TorchDistributedCommunicator(),
    )
def test_register_modules(
    model: torch.nn.Module,
    layer_type: type[KFACBaseLayer],
    skip_layers: list[str],
    expected_count: int,
) -> None:
    """Test register_modules."""
    kwargs = dict(
        allreduce_method=AllreduceMethod.ALLREDUCE,
        grad_scaler=None,
        factor_dtype=None,
        inv_dtype=torch.float32,
        symmetry_aware=False,
        tdc=TorchDistributedCommunicator(),
    )
    kfac_layers = register_modules(
        model,
        layer_type,
        skip_layers=skip_layers,
        **kwargs,
    )
    assert len(kfac_layers) == expected_count
    def allreduce(
        shape: list[int],
        tensor_count: int,
        bucket_cap_mb: float,
        symmetric: bool = False,
    ) -> None:
        """Test allreduce in distributed environment."""
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
        comm = TorchDistributedCommunicator(bucket_cap_mb)

        if world_size == 1:
            group = None
        else:
            # Exclude rank 0
            group = torch.distributed.new_group(
                [i for i in range(world_size) if i >= 1], )
        group_size = torch.distributed.get_world_size(group)

        tensors = []
        for _ in range(tensor_count):
            t = torch.ones(shape, dtype=torch.float32)
            if group is None or rank > 0:
                tensors.append(
                    comm.allreduce_bucketed(
                        t,
                        symmetric=symmetric,
                        group=group,
                    ), )
        comm.flush_allreduce_buckets()
        # All buckets should be removed now so calling again shouldn't be
        # an issue
        comm.flush_allreduce_buckets()

        for tensor in tensors:
            if isinstance(tensor, Future):
                tensor = tensor.wait()
            assert isinstance(tensor, torch.Tensor)
            assert torch.sum(tensor).item() == group_size * torch.numel(
                tensor, )
    def __init__(
        self,
        model: torch.nn.Module,
        *,
        factor_update_steps: Callable[[int], int] | int = 1,
        inv_update_steps: Callable[[int], int] | int = 1,
        # KFAC hyperparameters
        damping: Callable[[int], float] | float = 0.001,
        factor_decay: Callable[[int], float] | float = 0.95,
        kl_clip: Callable[[int], float] | float = 0.001,
        lr: Callable[[int], float] | float = 0.1,
        # Distribution strategy
        accumulation_steps: int = 1,
        allreduce_bucket_cap_mb: float = 25.0,
        assignment_strategy: (
            AssignmentStrategy | str
        ) = AssignmentStrategy.COMPUTE,
        compute_method: ComputeMethod | str = ComputeMethod.EIGEN,
        compute_eigenvalue_outer_product: bool = False,
        symmetry_aware: bool = False,
        # DeepSpeed 3D parallelism
        data_parallel_group: torch.distributed.ProcessGroup | None = None,
        model_parallel_group: torch.distributed.ProcessGroup | None = None,
        pipeline_parallel_group: torch.distributed.ProcessGroup | None = None,
        # Optional other parameters
        grad_scaler: (
            torch.cuda.amp.GradScaler | Callable[[], float] | None
        ) = None,
        factor_dtype: torch.dtype | None = None,
        inv_dtype: torch.dtype = torch.float32,
        factor_checkpoint_dir: str | None = None,
        skip_layers: list[str] | None = None,
        update_factors_in_hook: bool = True,
        loglevel: int = logging.DEBUG,
    ) -> None:
        """Init KFACPreconditioner.

        Args:
            model (torch.nn.Module): model to precondition with KFAC.
            factor_update_steps (Callable, int): steps between computing and
                updating the running average of the Kronecker factors or
                callable that takes the K-FAC step and returns the value.
            inv_update_steps (Callable, int): steps between recomputing and
                communicating the second-order information or callable that
                takes the K-FAC step and returns the value.
            damping (Callable, float): Tikhonov damping parameter or a callable
                that takes the K-FAC step and returns the damping parameter
                as a float (default: 0.001).
            factor_decay (Callable, float): running average coefficient for
                Kronecker factors or callable that takes the K-FAC step and
                returns the factor_decay (default: 0.95).
            kl_clip (Callable, float): clipping parameter for gradient scaling
                or a callable that takes the K-FAC step and returns a float.
                If None, no scaling/clipping will be applied (default: 0.001).
            lr (Callable, float): learning rate or callable that takes the
                K-FAC step and returns learning rate (default: 0.1).
            accumulation_steps (int): number of forward/backward passes
                between optimization steps (default: 1).
            allreduce_bucket_cap_mb (float): maximum size in megabytes for
                allreduce bucketing. If zero, bucketing is not used
                (default: 25).
            assignment_strategy (AssignmentStrategy, str): See
                `AssignmentStrategy` for more details
                (default: AssignmentStrategy.COMPUTE).
            compute_method (ComputeMethod, str): See `ComputeMethod` for more
                details (default: ComputeMethod.EIGEN).
            compute_eigenvalue_outer_product (bool): when using the eigen
                compute method, precompute the element-wise inverse of the
                outer product of eigenvectors on the eigen decomposition
                worker to reduce computation in the gradient preconditioning
                stage. `colocate_factors` must be True (default: False).
            symmetry_aware (bool): communicate only the upper triangle of
                symmetric matrices. Can reduce communication time when factors
                are large (default: False).
            data_parallel_group (ProcessGroup): DeepSpeed data parallel group.
            model_parallel_group (ProcessGroup): DeepSpeed model parallel
                group.
            pipeline_parallel_group (ProcessGroup): DeepSpeed pipeline parallel
                group.
            grad_scaler (torch.cuda.amp.GradScaler or callable): Gradient
                scaler used for Torch AMP training. Used to unscale the G
                factors as they are accumulated during the backward pass.
                Alternatively can be a callable which will return the current
                scale (default: None).
            factor_dtype (torch.dtype): force data type for storing factors.
                If None, defaults to data type of intermediate values in
                forward/backward pass (default: None).
            inv_dtype (torch.dtype): force data type for storing second-order
                data (e.g., inverses or eigen decompositions)
                (default: torch.float32).
            factor_checkpoint_dir (str): directory in which to store factor
                checkpoints (default: None).
            skip_layers (list): list of module names to ignore when registering
                layers. Passing the name of parent modules will prevent
                recursively registering child modules of the parent.
                Case-insensitive (default: []).
            update_factors_in_hook (bool): If True, running average of factors
                is updated in the module hook and the async communication is
                started. Otherwise, this will be performed at the start of
                step() (default: True).
            loglevel (int): logging level (default: logging.DEBUG).
        """
        if deepspeed_import_error is not None:  # pragma: no cover
            raise deepspeed_import_error

        warnings.warn(
            'KFAC support for GPT-NeoX training is experimental.',
            ExperimentalFeatureWarning,
        )

        if not isinstance(model, PipelineModule):
            raise ValueError(
                'model must be an instance of deepspeed.pipe.PipelineModule. '
                f'Got an instance of {type(model)}.',
            )

        if allreduce_bucket_cap_mb < 0:
            raise ValueError('allreduce_bucket_cap_mb must be >= 0')
        if isinstance(assignment_strategy, str):
            assignment_strategy = AssignmentStrategy[
                assignment_strategy.upper()
            ]
        if isinstance(compute_method, str):
            compute_method = ComputeMethod[compute_method.upper()]

        self.allreduce_bucket_cap_mb = allreduce_bucket_cap_mb
        self.assignment_strategy = assignment_strategy
        self.compute_eigenvalue_outer_product = (
            compute_eigenvalue_outer_product
        )
        self.compute_method = compute_method
        self.grad_scaler = grad_scaler
        self.factor_dtype = factor_dtype
        self.inv_dtype = inv_dtype
        self.factor_checkpoint_dir = factor_checkpoint_dir
        self.skip_layers = [] if skip_layers is None else skip_layers
        self.symmetry_aware = symmetry_aware

        self.data_parallel_group = data_parallel_group
        self.model_parallel_group = model_parallel_group
        self.pipeline_parallel_group = pipeline_parallel_group

        if self.allreduce_bucket_cap_mb > 0:
            self.allreduce_method = AllreduceMethod.ALLREDUCE_BUCKETED
        else:
            self.allreduce_method = AllreduceMethod.ALLREDUCE
        self.tdc = TorchDistributedCommunicator(
            bucket_cap_mb=self.allreduce_bucket_cap_mb,
        )

        layer_kwargs = dict(
            allreduce_method=self.allreduce_method,
            grad_scaler=self.grad_scaler,
            factor_dtype=self.factor_dtype,
            inv_dtype=self.inv_dtype,
            symmetry_aware=self.symmetry_aware,
            tdc=self.tdc,
        )

        if self.compute_method == ComputeMethod.EIGEN:
            pass
        elif self.compute_method == ComputeMethod.INVERSE:
            raise ValueError('Inverse method not supported with GPT NeoX.')
        else:
            raise AssertionError(
                f'Unknown compute_method={self.compute_method}',
            )

        kfac_layers = register_modules(
            model,
            model_parallel_group=self.model_parallel_group,
            skip_layers=self.skip_layers,
            **layer_kwargs,
        )

        data_parallel_ranks = [
            -1 for _ in range(get_world_size(self.data_parallel_group))
        ]
        torch.distributed.all_gather_object(
            object_list=data_parallel_ranks,
            obj=get_rank(),
            group=self.data_parallel_group,
        )

        for name, kfac_layer in kfac_layers.values():
            logger.log(
                loglevel,
                f'Registered name="{name}": {repr(kfac_layer)} on '
                f'global-rank={get_rank()} and '
                f'pipeline-rank={get_rank(self.pipeline_parallel_group)}',
            )

        if self.assignment_strategy == AssignmentStrategy.COMPUTE:
            cost_func = lambda n: n**3  # noqa: E731
        elif self.assignment_strategy == AssignmentStrategy.MEMORY:
            cost_func = lambda n: n**2  # noqa: E731
        else:
            raise AssertionError(
                f'Unknown assignment_strategy={self.assignment_strategy}',
            )

        work = {
            name: {
                'A': cost_func(kfac_layer.module.a_factor_shape[0]),
                'G': cost_func(kfac_layer.module.g_factor_shape[0]),
            }
            for name, kfac_layer in kfac_layers.values()
        }

        assignment = GPTNeoXAssignment(
            work,
            local_rank=get_rank(),
            topology=model.topology(),
            data_parallel_group=self.data_parallel_group,
            model_parallel_group=self.model_parallel_group,
        )
        logger.log(loglevel, f'KFAC layer assignments: {assignment}')

        # Set primary rank for each layer
        for name, kfac_layer in kfac_layers.values():
            # Inv workers for A and G should be the same because
            # GPTNeoXAssignment
            # uses gradient worker fraction 1/world_size (mem-opt)
            if not isinstance(kfac_layer, GPTNeoXKFACEigenLayer):
                raise AssertionError(
                    'GPTNeoXKFACPreconditioner only supports '
                    'GPTNeoXKFACEigenLayer.',
                )
            kfac_layer.primary_rank = assignment.factor_worker(name, 'A')
            kfac_layer.data_parallel_group = assignment.data_parallel_group
            kfac_layer.pipe_parallel_peer_group = (
                assignment.pipe_parallel_peer_group
            )

        defaults = {
            'allreduce_bucket_cap_mb': self.allreduce_bucket_cap_mb,
            'allreduce_method': self.allreduce_method,
            'assignment_strategy': self.assignment_strategy,
            'compute_eigenvalue_outer_product': (
                self.compute_eigenvalue_outer_product
            ),
            'compute_method': self.compute_method,
            'grad_scaler': self.grad_scaler is not None,
            'factor_checkpoint_dir': self.factor_checkpoint_dir,
            'factor_dtype': self.factor_dtype,
            'inv_dtype': self.inv_dtype,
            'skip_layers': self.skip_layers,
            'symmetry_aware': self.symmetry_aware,
        }

        super().__init__(
            kfac_layers,
            factor_update_steps=factor_update_steps,
            inv_update_steps=inv_update_steps,
            factor_decay=factor_decay,
            damping=damping,
            kl_clip=kl_clip,
            lr=lr,
            accumulation_steps=accumulation_steps,
            assignment=assignment,
            update_factors_in_hook=update_factors_in_hook,
            defaults=defaults,
            tdc=self.tdc,
            loglevel=loglevel,
        )
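# Minimal construction sketch for the GPT-NeoX preconditioner defined above,
# assuming the class is exported as GPTNeoXKFACPreconditioner (the name used
# in its own error messages) and that `pipe_model` is an already-built
# deepspeed.pipe.PipelineModule. Building the PipelineModule and the DeepSpeed
# engine is omitted here because it depends on the GPT-NeoX setup; the
# hyperparameter values are placeholders.
def _build_gpt_neox_preconditioner(
    pipe_model: torch.nn.Module,
) -> BaseKFACPreconditioner:
    return GPTNeoXKFACPreconditioner(
        pipe_model,
        factor_update_steps=10,
        inv_update_steps=100,
        damping=0.003,
    )


# In the training loop, preconditioner.step() would be called after each
# backward pass and before the optimizer step, as in the e2e() helper later
# in this file.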
def test_base_preconditioner_init_raises() -> None:
    """Test BaseKFACPreconditioner raises."""
    with pytest.raises(ValueError):
        BaseKFACPreconditioner(
            example_layers(),
            assignment=LazyAssignment(),
            tdc=TorchDistributedCommunicator(),
            factor_update_steps=-1,
        )

    with pytest.raises(ValueError):
        BaseKFACPreconditioner(
            example_layers(),
            assignment=LazyAssignment(),
            tdc=TorchDistributedCommunicator(),
            inv_update_steps=-1,
        )

    with pytest.raises(ValueError):
        BaseKFACPreconditioner(
            example_layers(),
            assignment=LazyAssignment(),
            tdc=TorchDistributedCommunicator(),
            damping=-1,
        )

    with pytest.raises(ValueError):
        BaseKFACPreconditioner(
            example_layers(),
            assignment=LazyAssignment(),
            tdc=TorchDistributedCommunicator(),
            factor_decay=-1,
        )

    with pytest.raises(ValueError):
        BaseKFACPreconditioner(
            example_layers(),
            assignment=LazyAssignment(),
            tdc=TorchDistributedCommunicator(),
            factor_decay=2,
        )

    with pytest.raises(ValueError):
        BaseKFACPreconditioner(
            example_layers(),
            assignment=LazyAssignment(),
            tdc=TorchDistributedCommunicator(),
            kl_clip=-1,
        )

    with pytest.raises(ValueError):
        BaseKFACPreconditioner(
            example_layers(),
            assignment=LazyAssignment(),
            tdc=TorchDistributedCommunicator(),
            lr=-1,
        )

    with pytest.raises(ValueError):
        BaseKFACPreconditioner(
            example_layers(),
            assignment=LazyAssignment(),
            tdc=TorchDistributedCommunicator(),
            accumulation_steps=-1,
        )

    with pytest.warns():
        BaseKFACPreconditioner(
            example_layers(),
            assignment=LazyAssignment(),
            tdc=TorchDistributedCommunicator(),
            factor_update_steps=3,
            inv_update_steps=2,
        )
    def e2e() -> None:
        """Helper to run training in simulated distributed environment."""
        batch_size = 2
        model = TinyModel()
        criterion = torch.nn.MSELoss(reduction='sum')
        optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

        tdc = TorchDistributedCommunicator()
        layers = register_modules(
            model,
            KFACInverseLayer,
            allreduce_method=AllreduceMethod.ALLREDUCE,
            grad_scaler=None,
            factor_dtype=None,
            inv_dtype=torch.float32,
            skip_layers=[],
            symmetry_aware=False,
            tdc=tdc,
        )
        preconditioner = BaseKFACPreconditioner(
            layers=layers,
            assignment=LazyAssignment(broadcast=broadcast),
            tdc=tdc,
            accumulation_steps=accumulation_steps,
            **kfac_args,
        )

        for i in range(1, 10):
            x = torch.rand(batch_size, 10)
            y = torch.rand(batch_size, 10)
            y_pred = model(x)
            if i % accumulation_steps == 0:
                loss = criterion(y_pred, y)
                loss.backward()
                grad_weight_linear2 = model.linear2.weight.grad
                grad_bias_linear2 = model.linear2.bias.grad
                preconditioner.step()
                # Verify gradient was preconditioned
                assert not torch.equal(
                    grad_weight_linear2,
                    model.linear2.weight.grad,
                )
                assert not torch.equal(
                    grad_bias_linear2,
                    model.linear2.bias.grad,
                )
                optimizer.step()
                optimizer.zero_grad()

        # Test state dict computes inverses
        state_dict = preconditioner.state_dict()
        for _, layer in preconditioner._layers.values():
            layer = cast(KFACInverseLayer, layer)
            layer.a_factor = None
            layer.g_factor = None
            layer.a_inv = None
            layer.g_inv = None
        preconditioner.load_state_dict(state_dict)
        for _, layer in preconditioner._layers.values():
            layer = cast(KFACInverseLayer, layer)
            assert isinstance(layer.a_inv, torch.Tensor)
            assert isinstance(layer.g_inv, torch.Tensor)

        # Test grad hook supports tensor input rather than tuple
        preconditioner._save_grad_output(
            model.linear1,
            torch.rand(batch_size, 10),
            torch.rand(batch_size, 20),
        )

        # Test hook additional functionality
        if preconditioner._update_factors_in_hook:
            # Reset preconditioner to ensure hooks trigger
            preconditioner._steps = 0
            preconditioner._mini_steps = defaultdict(int)
            preconditioner._accumulation_steps = 100

            # Do forward/backward pass to verify hooks trigger and we
            # have temp factors for batch
            x = torch.rand(batch_size, 10)
            y = torch.rand(batch_size, 10)
            loss = criterion(model(x), y)
            loss.backward()
            mem_usage = preconditioner.memory_usage()
            for mem in mem_usage.values():
                assert mem > 0
            preconditioner.reset_batch()

            # Make sure hooks do not trigger when model is not in training mode
            model.eval()
            x = torch.rand(batch_size, 10)
            y = torch.rand(batch_size, 10)
            loss = criterion(model(x), y)
            loss.backward()
            mem_usage = preconditioner.memory_usage()
            for key, mem in mem_usage.items():
                if 'batch' in key:
                    assert mem == 0
def test_empty_state_dict() -> None:
    """Test state dict functionality with no factors."""
    p1 = BaseKFACPreconditioner(
        layers=example_layers(),
        assignment=LazyAssignment(),
        tdc=TorchDistributedCommunicator(),
        factor_update_steps=1,
        inv_update_steps=3,
        damping=5,
        factor_decay=0.7,
        kl_clip=11,
        lr=13,
        accumulation_steps=17,
        update_factors_in_hook=False,
        defaults={'default1': 19},
    )
    p1._steps = 99
    state_dict = p1.state_dict(include_factors=False)
    # include_factors=True should add entries for the factors even though
    # they are None at this point
    assert state_dict != p1.state_dict(include_factors=True)

    # We filled p1 with non-default values so we can load the
    # state_dict of p1 into p2 and see what is loaded
    p2 = BaseKFACPreconditioner(
        layers=example_layers(),
        assignment=LazyAssignment(),
        tdc=TorchDistributedCommunicator(),
    )
    p2.load_state_dict(state_dict, compute_inverses=False)

    assert p1.factor_update_steps == p2.factor_update_steps
    assert p1.inv_update_steps == p2.inv_update_steps
    assert p1.damping == p2.damping
    assert p1.factor_decay == p2.factor_decay
    assert p1.kl_clip == p2.kl_clip
    assert p1.lr == p2.lr

    # We only load the hyperparameters and training state
    assert p1._accumulation_steps != p2._accumulation_steps
    assert p1._update_factors_in_hook != p2._update_factors_in_hook
    assert p1._defaults != p2._defaults

    # Steps should be loaded
    assert p1._steps == p2._steps

    p3 = BaseKFACPreconditioner(
        layers=example_layers(),
        assignment=LazyAssignment(),
        tdc=TorchDistributedCommunicator(),
        factor_update_steps=lambda x: 1,
        inv_update_steps=lambda x: 3,
        damping=lambda x: 5,
        factor_decay=lambda x: 0.7,
        kl_clip=lambda x: 11,
        lr=lambda x: 13,
    )
    state_dict = p3.state_dict()
    assert 'factor_update_steps' not in state_dict
    assert 'inv_update_steps' not in state_dict
    assert 'damping' not in state_dict
    assert 'factor_decay' not in state_dict
    assert 'kl_clip' not in state_dict
    assert 'lr' not in state_dict

    p3.load_state_dict(state_dict, compute_inverses=False)
    assert p3.factor_update_steps == 1
    assert p3.inv_update_steps == 3
    assert p3.damping == 5
    assert p3.factor_decay == 0.7
    assert p3.kl_clip == 11
    assert p3.lr == 13

    # Check warns user if they set compute_inverses but the state dict was
    # created with include_factors=False
    with pytest.warns():
        p3.load_state_dict({'steps': 0}, compute_inverses=True)

    # Should cause no problems... but doesn't do much but set steps!
    p3.load_state_dict({'steps': 0}, compute_inverses=False)

    # Check mismatch in registered layers
    del state_dict['layers'][list(state_dict['layers'].keys()).pop()]
    with pytest.raises(ValueError):
        p3.load_state_dict(state_dict, compute_inverses=False)
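# Sketch of how preconditioner state might be checkpointed to disk with
# torch.save/torch.load. The file name and helper are illustrative; only
# state_dict()/load_state_dict() are taken from the tests above.
def _checkpoint_roundtrip_sketch(
    preconditioner: BaseKFACPreconditioner,
    path: str = 'kfac_state.pt',
) -> None:
    torch.save({'kfac': preconditioner.state_dict()}, path)
    checkpoint = torch.load(path)
    preconditioner.load_state_dict(checkpoint['kfac'], compute_inverses=False)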
    def __init__(
        self,
        model: torch.nn.Module,
        *,
        factor_update_steps: Callable[[int], int] | int = 1,
        inv_update_steps: Callable[[int], int] | int = 1,
        # KFAC hyperparameters
        damping: Callable[[int], float] | float = 0.001,
        factor_decay: Callable[[int], float] | float = 0.95,
        kl_clip: Callable[[int], float] | float = 0.001,
        lr: Callable[[int], float] | float = 0.1,
        # Distribution strategy
        accumulation_steps: int = 1,
        allreduce_bucket_cap_mb: float = 25.0,
        assignment_strategy: (AssignmentStrategy
                              | str) = AssignmentStrategy.COMPUTE,
        colocate_factors: bool = True,
        compute_method: ComputeMethod | str = ComputeMethod.EIGEN,
        compute_eigenvalue_outer_product: bool = True,
        grad_worker_fraction: (DistributedStrategy
                               | float) = DistributedStrategy.COMM_OPT,
        symmetry_aware: bool = False,
        # Optional other parameters
        grad_scaler: (torch.cuda.amp.GradScaler | Callable[[], float]
                      | None) = None,
        factor_dtype: torch.dtype | None = None,
        inv_dtype: torch.dtype = torch.float32,
        skip_layers: list[str] | None = None,
        update_factors_in_hook: bool = True,
        loglevel: int = logging.DEBUG,
    ) -> None:
        """Init KFACPreconditioner.

        Args:
            model (torch.nn.Module): model to precondition with KFAC.
            factor_update_steps (Callable, int): steps between computing and
                updating the running average of the Kronecker factors or
                callable that takes the K-FAC step and returns the value.
            inv_update_steps (Callable, int): steps between recomputing and
                communicating the second-order information or callable that
                takes the K-FAC step and returns the value.
            damping (Callable, float): Tikhonov damping parameter or a callable
                that takes the K-FAC step and returns the damping parameter
                as a float (default: 0.001).
            factor_decay (Callable, float): running average coefficient for
                Kronecker factors or callable that takes the K-FAC step and
                returns the factor_decay (default: 0.95).
            kl_clip (Callable, float): clipping parameter for gradient scaling
                or a callable that takes the K-FAC step and returns a float.
                If None, no scaling/clipping will be applied (default: 0.001).
            lr (Callable, float): learning rate or callable that takes the
                K-FAC step and returns learning rate (default: 0.1).
            accumulation_steps (int): number of forward/backward passes
                between optimization steps (default: 1).
            allreduce_bucket_cap_mb (float): maximum size in megabytes for
                allreduce bucketing. If zero, bucketing is not used
                (default: 25).
            assignment_strategy (AssignmentStrategy, str): See
                `AssignmentStrategy` for more details
                (default: AssignmentStrategy.COMPUTE).
            colocate_factors (bool): assign both factors for a single layer to
                the same worker. Recommended when num_layers < world_size
                (default: True).
            compute_method (ComputeMethod, str): See `ComputeMethod` for more
                details (default: ComputeMethod.EIGEN).
            compute_eigenvalue_outer_product (bool): when using the eigen
                compute method, precompute the element-wise inverse of the
                outer product of eigenvectors on the eigen decomposition
                worker to reduce computation in the gradient preconditioning
                stage. `colocate_factors` must be True (default: True).
            grad_worker_fraction (DistributedStrategy, float): controls the
                fraction of workers assigned as gradient workers for each
                layer. Optionally, predefined configurations can be passed
                using the DistributedStrategy enum
                (default: DistributedStrategy.COMM_OPT).
            symmetry_aware (bool): communicate only the upper triangle of
                symmetric matrices. Can reduce communication time when factors
                are large (default: False).
            grad_scaler (torch.cuda.amp.GradScaler or callable): Gradient
                scaler used for Torch AMP training. Used to unscale the G
                factors as they are accumulated during the backward pass.
                Alternatively can be a callable which will return the current
                scale (default: None).
            factor_dtype (torch.dtype): force data type for storing factors.
                If None, defaults to data type of intermediate values in
                forward/backward pass (default: None).
            inv_dtype (torch.dtype): force data type for storing second-order
                data (e.g., inverses or eigen decompositions)
                (default: torch.float32).
            skip_layers (list[str]): regex patterns that if matched, will cause
                the layer to not be registered. The patterns will be applied
                against the layer's name and class name.
            update_factors_in_hook (bool): If True, running average of factors
                is updated in the module hook and the async communication is
                started. Otherwise, this will be performed at the start of
                step() (default: True).
            loglevel (int): logging level (default: logging.DEBUG).
        """
        if allreduce_bucket_cap_mb < 0:
            raise ValueError('allreduce_bucket_cap_mb must be >= 0')
        if (compute_method == ComputeMethod.EIGEN
                and compute_eigenvalue_outer_product and not colocate_factors):
            raise ValueError(
                'colocate_factors must be True to use '
                'compute_eigenvalue_outer_product', )
        if isinstance(assignment_strategy, str):
            assignment_strategy = AssignmentStrategy[
                assignment_strategy.upper()]
        if isinstance(compute_method, str):
            compute_method = ComputeMethod[compute_method.upper()]

        size = get_world_size()
        if isinstance(grad_worker_fraction, DistributedStrategy):
            distributed_strategy = grad_worker_fraction
            if distributed_strategy == DistributedStrategy.COMM_OPT:
                grad_worker_fraction = 1.0
            elif distributed_strategy == DistributedStrategy.HYBRID_OPT:
                grad_worker_fraction = 0.5
            elif distributed_strategy == DistributedStrategy.MEM_OPT:
                grad_worker_fraction = 1.0 / size
            else:
                raise AssertionError(f'Unknown enum {grad_worker_fraction}')
        else:
            if not 0 <= grad_worker_fraction <= 1:
                raise ValueError('grad_worker_fraction must be in [0, 1]')
            if grad_worker_fraction == 0:
                grad_worker_fraction = 1.0 / size
            if size % max(1, round(size * grad_worker_fraction)) != 0:
                raise ValueError(
                    'grad_worker_fraction must produce groups of '
                    'equal size', )
            if grad_worker_fraction == 1:
                grad_worker_fraction = 1.0  # ensure float
                distributed_strategy = DistributedStrategy.COMM_OPT
            elif grad_worker_fraction <= 1 / size:
                distributed_strategy = DistributedStrategy.MEM_OPT
            else:
                distributed_strategy = DistributedStrategy.HYBRID_OPT
        assert isinstance(grad_worker_fraction, float)
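        # Worked example of the mapping above, assuming world_size == 8:
        #   COMM_OPT   -> grad_worker_fraction = 1.0 (all 8 ranks are
        #                 gradient workers)
        #   HYBRID_OPT -> grad_worker_fraction = 0.5 (gradient-worker
        #                 groups of 4)
        #   MEM_OPT    -> grad_worker_fraction = 1/8 (one gradient worker)
        # A float of 0.25 gives groups of round(8 * 0.25) = 2 and is
        # classified as HYBRID_OPT.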

        if (not colocate_factors
                and distributed_strategy is DistributedStrategy.MEM_OPT):
            warnings.warn(
                'grad_worker_frac=1/world_size (MEM_OPT) requires '
                'colocate_factors=True. Enabling colocate_factors.', )
            colocate_factors = True

        self.allreduce_bucket_cap_mb = allreduce_bucket_cap_mb
        self.assignment_strategy = assignment_strategy
        self.colocate_factors = colocate_factors
        self.compute_eigenvalue_outer_product = (
            compute_eigenvalue_outer_product)
        self.compute_method = compute_method
        self.distributed_strategy = distributed_strategy
        self.grad_worker_fraction = grad_worker_fraction
        self.grad_scaler = grad_scaler
        self.factor_dtype = factor_dtype
        self.inv_dtype = inv_dtype
        self.skip_layers = [] if skip_layers is None else skip_layers
        self.symmetry_aware = symmetry_aware

        if self.allreduce_bucket_cap_mb > 0:
            self.allreduce_method = AllreduceMethod.ALLREDUCE_BUCKETED
        else:
            self.allreduce_method = AllreduceMethod.ALLREDUCE
        self.tdc = TorchDistributedCommunicator(
            bucket_cap_mb=self.allreduce_bucket_cap_mb, )

        layer_kwargs = dict(
            allreduce_method=self.allreduce_method,
            grad_scaler=self.grad_scaler,
            factor_dtype=self.factor_dtype,
            inv_dtype=self.inv_dtype,
            symmetry_aware=self.symmetry_aware,
            tdc=self.tdc,
        )

        layer_type: type[KFACBaseLayer]
        if self.compute_method == ComputeMethod.EIGEN:
            layer_type = KFACEigenLayer
            layer_kwargs['prediv_eigenvalues'] = (
                self.compute_eigenvalue_outer_product
            )
        elif self.compute_method == ComputeMethod.INVERSE:
            layer_type = KFACInverseLayer
        else:
            raise AssertionError(
                f'Unknown compute_method={self.compute_method}', )

        kfac_layers = register_modules(
            model,
            kfac_layer_type=layer_type,
            skip_layers=self.skip_layers,
            **layer_kwargs,
        )
        for name, kfac_layer in kfac_layers.values():
            logger.log(
                loglevel,
                f'Registered name="{name}": {repr(kfac_layer)}',
            )

        if self.assignment_strategy == AssignmentStrategy.COMPUTE:
            cost_func = lambda n: n**3  # noqa: E731
        elif self.assignment_strategy == AssignmentStrategy.MEMORY:
            cost_func = lambda n: n**2  # noqa: E731
        else:
            raise AssertionError(
                f'Unknown assignment_strategy={self.assignment_strategy}', )

        work = {
            name: {
                'A': cost_func(kfac_layer.module.a_factor_shape[0]),
                'G': cost_func(kfac_layer.module.g_factor_shape[0]),
            }
            for name, kfac_layer in kfac_layers.values()
        }

        new_group = cast(
            Callable[[List[int]], dist.ProcessGroup],
            dist.new_group,
        )
        mock_new_group: Callable[[list[int]], None] = lambda x: None
        assignment = KAISAAssignment(
            work,
            local_rank=get_rank(),
            world_size=get_world_size(),
            grad_worker_fraction=self.grad_worker_fraction,
            group_func=new_group if dist.is_initialized() else mock_new_group,
            colocate_factors=self.colocate_factors,
        )
        logger.log(loglevel, f'KFAC layer assignments: {assignment}')

        defaults = {
            'allreduce_bucket_cap_mb': self.allreduce_bucket_cap_mb,
            'allreduce_method': self.allreduce_method,
            'assignment_strategy': self.assignment_strategy,
            'colocate_factors': self.colocate_factors,
            'compute_eigenvalue_outer_product': (
                self.compute_eigenvalue_outer_product
            ),
            'compute_method': self.compute_method,
            'distributed_strategy': self.distributed_strategy,
            'grad_worker_fraction': self.grad_worker_fraction,
            'grad_scaler': self.grad_scaler is not None,
            'factor_dtype': self.factor_dtype,
            'inv_dtype': self.inv_dtype,
            'skip_layers': self.skip_layers,
            'symmetry_aware': self.symmetry_aware,
        }

        super().__init__(
            kfac_layers,
            factor_update_steps=factor_update_steps,
            inv_update_steps=inv_update_steps,
            factor_decay=factor_decay,
            damping=damping,
            kl_clip=kl_clip,
            lr=lr,
            accumulation_steps=accumulation_steps,
            assignment=assignment,
            update_factors_in_hook=update_factors_in_hook,
            defaults=defaults,
            tdc=self.tdc,
            loglevel=loglevel,
        )
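# Minimal usage sketch for the preconditioner defined above. Construction of
# the preconditioner (e.g., KFACPreconditioner(model, ...)) is assumed to have
# happened elsewhere; the call order here (loss.backward() ->
# preconditioner.step() -> optimizer.step()) mirrors the e2e() helper earlier
# in this file, and all other objects are placeholders.
def _kfac_training_step_sketch(
    model: torch.nn.Module,
    preconditioner: BaseKFACPreconditioner,
    optimizer: torch.optim.Optimizer,
    criterion: torch.nn.Module,
    x: torch.Tensor,
    y: torch.Tensor,
) -> None:
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    preconditioner.step()  # precondition the gradients before the update
    optimizer.step()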
    def precondition() -> None:
        """Precondition layer in distributed environment."""
        in_features = 10
        out_features = 5
        batch_size = 2
        module = torch.nn.Linear(in_features, out_features)
        module_helper = LinearModuleHelper(module)

        layer = kfac_layer(
            module=module_helper,
            tdc=TorchDistributedCommunicator(),
            **kwargs,
        )

        # Compute gradient
        x = torch.rand([batch_size, in_features])
        y = torch.rand([batch_size, out_features])
        loss = (module(x) - y).sum()
        loss.backward()
        weight_grad = module.weight.grad
        bias_grad = module.bias.grad

        # Stage 1: save intermediate variables
        layer.save_layer_input([x])
        layer.save_layer_input([x])
        layer.save_layer_grad_output((y, ))
        layer.save_layer_grad_output((y, ))

        # Stage 2: compute factors
        layer.update_a_factor()
        layer.update_g_factor()
        if 'factor_dtype' in kwargs:
            assert (layer.a_factor is not None
                    and layer.a_factor.dtype == kwargs['factor_dtype'])
            assert (layer.g_factor is not None
                    and layer.g_factor.dtype == kwargs['factor_dtype'])

        # Stage 3: reduce factors
        layer.reduce_a_factor()
        layer.reduce_g_factor()
        if layer.allreduce_method == AllreduceMethod.ALLREDUCE_BUCKETED:
            layer.tdc.flush_allreduce_buckets()

        # Stage 4: compute second-order info
        if dist.get_rank() == 0:
            layer.compute_a_inv()
            layer.compute_g_inv()

        # Stage 5: communicate second-order info
        if world_size > 1 and strategy == DistributedStrategy.COMM_OPT:
            layer.broadcast_a_inv(src=0)
            layer.broadcast_g_inv(src=0)

        if 'inv_dtype' in kwargs:
            if kfac_layer == KFACInverseLayer:
                layer = cast(KFACInverseLayer, layer)
                assert (layer.a_inv is not None
                        and layer.a_inv.dtype == kwargs['inv_dtype'])
                assert (layer.g_inv is not None
                        and layer.g_inv.dtype == kwargs['inv_dtype'])
            elif kfac_layer == KFACEigenLayer:
                layer = cast(KFACEigenLayer, layer)
                assert (layer.qa is not None
                        and layer.qa.dtype == kwargs['inv_dtype'])
                assert (layer.qg is not None
                        and layer.qg.dtype == kwargs['inv_dtype'])
                if ('prediv_eigenvalues' in kwargs
                        and kwargs['prediv_eigenvalues']):
                    assert (layer.dgda is not None
                            and layer.dgda.dtype == kwargs['inv_dtype'])
                else:
                    assert (layer.da is not None
                            and layer.da.dtype == kwargs['inv_dtype'])
                    assert (layer.dg is not None
                            and layer.dg.dtype == kwargs['inv_dtype'])
            else:
                raise AssertionError

        # Stage 6: compute and communicate preconditioned gradient
        if strategy == DistributedStrategy.COMM_OPT or (
                strategy == DistributedStrategy.MEM_OPT
                and dist.get_rank() == 0):
            layer.preconditioned_grad()
        if strategy == DistributedStrategy.MEM_OPT:
            layer.broadcast_grad(src=0)

        # Stage 7: update gradient
        layer.update_grad()

        # Make sure gradient changed due to preconditioning
        assert not torch.equal(weight_grad, module.weight.grad)
        assert not torch.equal(bias_grad, module.bias.grad)
def test_kfac_layers(layer_type: type[KFACBaseLayer]) -> None:
    """Test KFACBaseLayer implementation."""
    batch_size, in_features, out_features = 2, 5, 5
    module = torch.nn.Linear(in_features, out_features)
    x = torch.rand([batch_size, in_features])
    y = torch.rand([batch_size, out_features])
    loss = (module(x) - y).sum()
    loss.backward()

    module_helper = LinearModuleHelper(module)
    layer = layer_type(
        module=module_helper,
        tdc=TorchDistributedCommunicator(),
    )

    assert 'LinearModuleHelper' in repr(layer)
    assert layer_type.__name__ in repr(layer)

    # Cannot reduce factors, update gradient, or compute inverses
    with pytest.raises(RuntimeError):
        layer.reduce_a_factor()
    with pytest.raises(RuntimeError):
        layer.reduce_g_factor()
    with pytest.raises(RuntimeError):
        layer.update_grad()
    with pytest.raises(RuntimeError):
        layer.compute_a_inv()
    with pytest.raises(RuntimeError):
        layer.compute_g_inv()
    with pytest.raises(RuntimeError):
        layer.preconditioned_grad()

    # Broadcasts should fail because src rank has not computed the data
    with mock.patch('torch.distributed.get_rank', return_value=0):
        with pytest.raises(RuntimeError):
            layer.broadcast_grad(src=0)
        with pytest.raises(RuntimeError):
            layer.broadcast_a_inv(src=0)
        with pytest.raises(RuntimeError):
            layer.broadcast_g_inv(src=0)

    state_dict = layer.state_dict()
    assert 'A' in state_dict and state_dict['A'] is None
    assert 'G' in state_dict and state_dict['G'] is None
    with pytest.raises(KeyError):
        # state_dict must have A and G keys
        layer.load_state_dict({})
    layer.load_state_dict(state_dict)

    mem_usage = layer.memory_usage()
    for key in mem_usage:
        assert mem_usage[key] == 0

    layer.save_layer_input([x])
    layer.save_layer_grad_output((y, ))

    # layer memory usage should reflect temp factors for current batch
    mem_usage = layer.memory_usage()
    assert mem_usage['a_batch'] > 0
    assert mem_usage['g_batch'] > 0

    # Check clear current batch
    layer.reset_batch()
    mem_usage = layer.memory_usage()
    assert mem_usage['a_batch'] == 0
    assert mem_usage['g_batch'] == 0
    # Should not raise an error even though no batch data has been accumulated
    layer.update_a_factor()
    layer.update_g_factor()

    # Repeat twice: the first pass initializes the factors and the second
    # adds to them
    layer.save_layer_input([x])
    layer.save_layer_grad_output((y, ))
    layer.update_a_factor()
    layer.update_g_factor()
    layer.save_layer_input([x])
    layer.save_layer_grad_output((y, ))
    layer.update_a_factor()
    layer.update_g_factor()

    # The current batch has been flushed, so the factors should have size but
    # there are no temp factors anymore
    mem_usage = layer.memory_usage()
    assert mem_usage['a_factors'] > 0
    assert mem_usage['g_factors'] > 0
    assert mem_usage['a_batch'] == 0
    assert mem_usage['g_batch'] == 0

    state_dict = layer.state_dict()
    assert isinstance(state_dict['A'], torch.Tensor)
    assert isinstance(state_dict['G'], torch.Tensor)
    layer.load_state_dict(state_dict)
    assert layer.a_factor is not None
    assert layer.g_factor is not None
    assert torch.equal(layer.a_factor, state_dict['A'])
    assert torch.equal(layer.g_factor, state_dict['G'])

    # Check gradient scaling. We haven't computed the preconditioned gradient
    # so just fake one
    grad = module_helper.get_grad()
    layer.grad = grad
    layer.update_grad(scale=10)
    assert torch.equal(10 * grad, module_helper.get_grad())