Example #1
    def load_factors_from_dir(self, compute_inverses: bool = True) -> None:
        """Load factors from `factor_checkpoint_dir`."""
        if self.factor_checkpoint_dir is None:
            raise ValueError('factor_checkpoint_dir is None.')

        if not os.path.isdir(self.factor_checkpoint_dir):
            warnings.warn(
                f'factor_checkpoint_dir={self.factor_checkpoint_dir} '
                'is not a directory. Skipping KFAC checkpoint load.',
            )
            return

        for name, layer in self._layers.values():
            if (
                cast(GPTNeoXAssignment, self._assignment).factor_worker(
                    name,
                    'A',
                )
                == get_rank()
            ):
                filepath = os.path.join(self.factor_checkpoint_dir, name)
                if os.path.exists(filepath):
                    logger.info(
                        f'loading KFAC factors for {name} on rank '
                        f'{get_rank()}',
                    )
                    state_dict = torch.load(filepath)
                    layer.load_state_dict(state_dict)
                    if compute_inverses:
                        layer.compute_a_inv(damping=self.damping)
                        layer.compute_g_inv(damping=self.damping)
Example #2
    def load_state_dict(
        self,
        state_dict: dict[str, Any],
        compute_inverses: bool = True,
    ) -> None:
        """Load state dict."""
        layers = state_dict.pop('layers', None)
        super().load_state_dict(state_dict, compute_inverses=False)

        if self.factor_checkpoint_dir is not None:
            self.load_factors_from_dir(compute_inverses)
            return

        if layers is None:
            return

        for found_name, layer_state_dict in layers.items():
            for name, layer in self._layers.values():
                if (
                    found_name == name
                    and cast(
                        GPTNeoXAssignment,
                        self._assignment,
                    ).factor_worker(name, 'A')
                    == get_rank()
                ):
                    assert isinstance(layer_state_dict['A'], torch.Tensor)
                    assert isinstance(layer_state_dict['G'], torch.Tensor)

                    layer.load_state_dict(layer_state_dict)
                    if compute_inverses:
                        layer.compute_a_inv(damping=self.damping)
                        layer.compute_g_inv(damping=self.damping)

        torch.distributed.barrier()
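Both code paths above are driven by `load_state_dict()`: when `factor_checkpoint_dir` is set, the factors come from per-layer files; otherwise they are pulled from the `layers` entry of the state dict. A hypothetical resume sketch under that assumption (the checkpoint filename, the `'kfac'` key, and the `preconditioner` object are placeholders, not part of the example above):

# Every rank must call load_state_dict(); it ends in a collective barrier.
checkpoint = torch.load('checkpoint.pt', map_location='cpu')
preconditioner.load_state_dict(checkpoint['kfac'], compute_inverses=True)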
Example #3
    def broadcast_g_inv(
        self,
        src: int,
        group: dist.ProcessGroup | None = None,
    ) -> None:
        """Initiate G inv broadcast and store future to result.

        Note:
            All ranks must enter this function even if the rank is not
            part of the inverse broadcast group.

        Args:
            src (int): src rank that computed G inverse.
            group (ProcessGroup): process group to which src should broadcast
                G inv. All ranks in group should enter this function.
                Defaults to None, the default process group.
        """
        if self.g_inv is None:
            if get_rank() == src:
                raise RuntimeError(
                    f'Attempt to broadcast G inv from src={src} but this rank '
                    'has not computed G inv yet.', )
            assert isinstance(self.g_factor, torch.Tensor)
            self.g_inv = torch.empty(
                self.g_factor.shape,
                device=self.g_factor.device,
                dtype=self.inv_dtype,
            )

        self.g_inv = self.tdc.broadcast(  # type: ignore
            self.g_inv,
            src=src,
            group=group,
            symmetric=self.symmetric_factors and self.symmetry_aware,
        )
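Both this method and `broadcast_grad()` in the next example follow the same pattern: every rank enters the collective, and ranks that did not compute the result first allocate an empty receive buffer of the expected shape and dtype. A stripped-down sketch of that pattern with plain `torch.distributed` (the function and argument names are illustrative only):

from __future__ import annotations

import torch
import torch.distributed as dist


def broadcast_result(
    value: torch.Tensor | None,
    src: int,
    shape: tuple[int, ...],
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    """Non-src ranks allocate a buffer, then all ranks join the broadcast."""
    if value is None:
        if dist.get_rank() == src:
            raise RuntimeError('src rank has no result to broadcast')
        value = torch.empty(shape, dtype=dtype, device=device)
    # Collective call: every rank in the group must reach this line.
    dist.broadcast(value, src=src)
    return value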
Example #4
    def broadcast_grad(
        self,
        src: int,
        group: dist.ProcessGroup | None = None,
    ) -> None:
        """Broadcast preconditioned gradient and store future to result.

        Note:
            All ranks must enter this function.

        Args:
            src (int): src rank that preconditioned the gradient.
            group (ProcessGroup): process group to which src should broadcast
                the gradient. All ranks in group should enter this function.
                Defaults to None, the default process group.
        """
        if self.grad is None:
            if get_rank() == src:
                raise RuntimeError(
                    f'Attempt to broadcast gradient from src={src} but this '
                    'rank has not computed the preconditioned gradient yet.', )
            self.grad = torch.empty_like(self.module.get_grad())

        self.grad = self.tdc.broadcast(  # type: ignore
            self.grad,
            src=src,
            group=group,
        )
Example #5
    def reduce_g_factor(
        self,
        group: torch.distributed.ProcessGroup | None = None,
    ) -> None:  # pragma: no cover
        """Initiate reduction of G and store future to result.

        Note:
            All ranks should enter this function.

        Args:
            group (ProcessGroup): ignored because the correct group depends
                on whether the parallelism is over the input or the output of
                the layer.
        """
        if self.primary_rank is None:
            raise RuntimeError('primary rank has not been set yet.')
        valid = (torch.distributed.ProcessGroup, type(None))
        if not isinstance(self.data_parallel_group, valid) or not isinstance(
                self.pipe_parallel_peer_group,
                valid,
        ):
            raise RuntimeError(
                'data_parallel_group or pipe_parallel_peer_group has not '
                'been set yet.', )

        if self.parallelism == 'input':
            super().reduce_g_factor(self.pipe_parallel_peer_group)
        elif self.parallelism == 'output':
            if get_rank() != self.primary_rank:
                return
            super().reduce_g_factor(self.data_parallel_group)
        else:
            raise AssertionError('Unreachable.')
Example #6
    def save_factors_to_dir(self) -> None:
        """Save factors to `factor_checkpoint_dir`.

        Saves the state dict for each layer to a separate file in
        `self.factor_checkpoint_dir`, using the layer name as the filename.
        Note: only the inverse worker for a layer will save that layer.
        """
        if self.factor_checkpoint_dir is None:
            raise ValueError('factor_checkpoint_dir is None')

        if get_rank() == 0:
            os.makedirs(self.factor_checkpoint_dir, exist_ok=True)
        torch.distributed.barrier()

        for name, layer in self._layers.values():
            if get_rank() == self._assignment.inv_worker(name, 'A'):
                layer_state_dict = layer.state_dict()
                filepath = os.path.join(self.factor_checkpoint_dir, name)
                logger.info(f'saving KFAC factors for {name} to {filepath}')
                torch.save(layer_state_dict, filepath)
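The first half of the method is a generic pattern for shared filesystems: rank 0 creates the directory, then every rank waits at a barrier before writing. A minimal standalone sketch of that pattern, assuming `torch.distributed` is already initialized:

import os

import torch.distributed as dist


def make_shared_dir(path: str) -> None:
    # Only one rank touches the filesystem to avoid redundant mkdir calls.
    if dist.get_rank() == 0:
        os.makedirs(path, exist_ok=True)
    # No rank proceeds until the directory is guaranteed to exist.
    dist.barrier()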
Example #7
    def broadcast_a_inv(
        self,
        src: int,
        group: dist.ProcessGroup | None = None,
    ) -> None:
        """Initiate A inv broadcast and store future to result.

        Note:
            All ranks must enter this function even if the rank is not
            part of the inverse broadcast group.

        Args:
            src (int): src rank that computed A inverse.
            group (ProcessGroup): process group to which src should broadcast
                A inv. All ranks in group should enter this function.
                Defaults to None, the default process group.
        """
        if self.qa is None or (not self.prediv_eigenvalues
                               and self.da is None):
            if get_rank() == src:
                raise RuntimeError(
                    f'Attempt to broadcast A inv from src={src} but this rank '
                    'has not computed A inv yet.', )
            assert isinstance(self.a_factor, torch.Tensor)
            self.qa = torch.empty(
                self.a_factor.shape,
                device=self.a_factor.device,
                dtype=self.inv_dtype,
            )
            self.da = torch.empty(
                self.a_factor.shape[0],
                device=self.a_factor.device,
                dtype=self.inv_dtype,
            )

        self.qa = self.tdc.broadcast(  # type: ignore
            self.qa,
            src=src,
            group=group,
        )
        if not self.prediv_eigenvalues:
            assert self.da is not None
            self.da = self.tdc.broadcast(  # type: ignore
                self.da,
                src=src,
                group=group,
            )
Example #8
    def state_dict(self, include_factors: bool = True) -> dict[str, Any]:
        """Get state dict.

        Note:
            All ranks must enter this function.
        """
        state_dict = super().state_dict(include_factors=False)

        if not include_factors:
            return state_dict

        if self.factor_checkpoint_dir is not None:
            self.save_factors_to_dir()
            return state_dict

        partition: list[tuple[str, dict[str, Any]]] = []
        for name, layer in self._layers.values():
            # Inv worker for A and G is the same for GPT training
            if get_rank() == self._assignment.inv_worker(name, 'A'):
                layer_state_dict = layer.state_dict()
                assert layer_state_dict['A'] is not None
                assert layer_state_dict['G'] is not None
                # Move to CPU where we have more RAM
                layer_state_dict['A'] = layer_state_dict['A'].cpu()
                layer_state_dict['G'] = layer_state_dict['G'].cpu()
                partition.append((name, layer_state_dict))

        partitions = [None for _ in range(get_world_size())]
        # Use gloo group because we moved data to CPU for more RAM
        group = torch.distributed.new_group(backend='gloo')
        torch.distributed.all_gather_object(partitions, partition, group=group)

        layers = {}
        for partition in partitions:  # type: ignore
            for name, layer_state_dict in partition:
                layers[name] = layer_state_dict
        state_dict['layers'] = layers

        torch.distributed.barrier(group)

        return state_dict
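A hedged sketch of how this state dict might sit inside a full training checkpoint; the surrounding objects and the filename are placeholders, not part of the example above:

# Every rank must call state_dict() because it contains collectives,
# but only one rank needs to write the combined checkpoint to disk.
state = {
    'model': model.state_dict(),
    'kfac': preconditioner.state_dict(include_factors=True),
}
if get_rank() == 0:
    torch.save(state, 'checkpoint.pt')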
Example #9
    def step(self) -> None:
        """Perform one K-FAC step.

        Note:
            This function should always be called before `optimizer.step()` as
            it modifies the gradients and does not modify the weights.

        Note:
            Gradients must be averaged across ranks before calling `step()`.
            This condition is guaranteed to be true when using the
            `DistributedDataParallel` model wrapper as gradients are
            communicated during `loss.backward()`.
        """
        if (
            not self._update_factors_in_hook
            and self.steps % self.factor_update_steps == 0
        ):
            for name, layer in reversed(list(self._layers.values())):
                self._mini_steps[name] = 0
                layer.update_a_factor(alpha=self.factor_decay)
                layer.reduce_a_factor(self._assignment.factor_group(name, 'A'))
                layer.update_g_factor(alpha=self.factor_decay)
                layer.reduce_g_factor(self._assignment.factor_group(name, 'G'))

        # Flush last allreduce bucket from forward/backward pass.
        # Will be a no-op if bucketing was not used
        self._tdc.flush_allreduce_buckets()

        # Compute Inverses
        if self.steps % self.inv_update_steps == 0:
            for name, layer in reversed(list(self._layers.values())):
                if get_rank() == self._assignment.inv_worker(name, 'A'):
                    layer.compute_a_inv(damping=self.damping)
                if (
                    self._assignment.broadcast_inverses()
                    and self._assignment.is_grad_worker(name)
                ):
                    layer.broadcast_a_inv(
                        src=self._assignment.inv_worker(name, 'A'),
                        group=self._assignment.grad_worker_group(name),
                    )
                if get_rank() == self._assignment.inv_worker(name, 'G'):
                    layer.compute_g_inv(damping=self.damping)
                if (
                    self._assignment.broadcast_inverses()
                    and self._assignment.is_grad_worker(name)
                ):
                    layer.broadcast_g_inv(
                        src=self._assignment.inv_worker(name, 'G'),
                        group=self._assignment.grad_worker_group(name),
                    )
            self._tdc.flush_allreduce_buckets()

        # Compute Preconditioned Gradients
        for name, layer in reversed(list(self._layers.values())):
            if self._assignment.is_grad_worker(name):
                layer.preconditioned_grad(damping=self.damping)
            if self._assignment.broadcast_gradients():
                layer.broadcast_grad(
                    src=self._assignment.src_grad_worker(name),
                    group=self._assignment.grad_receiver_group(name),
                )
        self._tdc.flush_allreduce_buckets()

        scale = None if self.kl_clip is None else self._compute_grad_scale()

        # Update gradients in-place
        for _, layer in reversed(list(self._layers.values())):
            layer.update_grad(scale=scale)

        self._steps += 1
        self._mini_steps = defaultdict(int)
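Per the notes above, `step()` belongs between the backward pass and `optimizer.step()`. A minimal training-loop sketch under that assumption (the `loader`, `model`, `loss_fn`, `optimizer`, and `preconditioner` objects are placeholders):

for batch, target in loader:
    optimizer.zero_grad()
    loss = loss_fn(model(batch), target)
    loss.backward()        # DDP averages gradients across ranks here
    preconditioner.step()  # precondition the averaged gradients in-place
    optimizer.step()       # apply the now-preconditioned gradients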
Example #10
    def __init__(
        self,
        model: torch.nn.Module,
        *,
        factor_update_steps: Callable[[int], int] | int = 1,
        inv_update_steps: Callable[[int], int] | int = 1,
        # KFAC hyperparameters
        damping: Callable[[int], float] | float = 0.001,
        factor_decay: Callable[[int], float] | float = 0.95,
        kl_clip: Callable[[int], float] | float = 0.001,
        lr: Callable[[int], float] | float = 0.1,
        # Distribution strategy
        accumulation_steps: int = 1,
        allreduce_bucket_cap_mb: float = 25.0,
        assignment_strategy: (
            AssignmentStrategy | str
        ) = AssignmentStrategy.COMPUTE,
        compute_method: ComputeMethod | str = ComputeMethod.EIGEN,
        compute_eigenvalue_outer_product: bool = False,
        symmetry_aware: bool = False,
        # DeepSpeed 3D parallelism
        data_parallel_group: torch.distributed.ProcessGroup | None = None,
        model_parallel_group: torch.distributed.ProcessGroup | None = None,
        pipeline_parallel_group: torch.distributed.ProcessGroup | None = None,
        # Optional other parameters
        grad_scaler: (
            torch.cuda.amp.GradScaler | Callable[[], float] | None
        ) = None,
        factor_dtype: torch.dtype | None = None,
        inv_dtype: torch.dtype = torch.float32,
        factor_checkpoint_dir: str | None = None,
        skip_layers: list[str] | None = None,
        update_factors_in_hook: bool = True,
        loglevel: int = logging.DEBUG,
    ) -> None:
        """Init KFACPreconditioner.

        Args:
            model (torch.nn.Module): model to precondition with KFAC.
            factor_update_steps (Callable, int): steps between computing and
                updating the running average of the Kronecker factors or
                callable that takes the K-FAC step and returns the value.
            inv_update_steps (Callable, int): steps between recomputing and
                communicating the second-order information or callable that
                takes the K-FAC step and returns the value.
            damping (Callable, float): Tikhonov damping parameter or a callable
                that takes the K-FAC step and returns the damping parameter
                as a float (default: 0.001).
            factor_decay (Callable, float): running average coefficient for
                Kronecker factors or callable that takes the K-FAC step and
                returns the factor_decay (default: 0.95).
            kl_clip (Callable, float): clipping parameter for gradient scaling
                or a callable that takes the K-FAC step and returns a float.
                If None, no scaling/clipping will be applied (default: 0.001).
            lr (Callable, float): learning rate or callable that takes the
                K-FAC step and returns learning rate (default: 0.1).
            accumulation_steps (int): number of forward/backward passes
                between optimization steps (default: 1).
            allreduce_bucket_cap_mb (float): maximum size in megabytes for
                allreduce bucketing. If zero, bucketing is not used
                (default: 25).
            assignment_strategy (AssignmentStrategy, str): See
                `AssignmentStrategy` for more details
                (default: AssignmentStrategy.COMPUTE).
            compute_method (ComputeMethod, str): See `ComputeMethod` for more
                details (default: ComputeMethod.EIGEN).
            compute_eigenvalue_outer_product (bool): when using the eigen
                compute method, precompute the element-wise inverse of the
                outer product of eigenvalues on the eigendecomposition worker
                to reduce computation in the gradient preconditioning stage.
                `colocate_factors` must be True (default: False).
            symmetry_aware (bool): communicate only the upper triangle of
                symmetric matrices. Can reduce communication time when factors
                are large (default: False).
            data_parallel_group (ProcessGroup): DeepSpeed data parallel group.
            model_parallel_group (ProcessGroup): DeepSpeed model parallel
                group.
            pipeline_parallel_group (ProcessGroup): DeepSpeed pipeline parallel
                group.
            grad_scaler (torch.cuda.amp.GradScaler or callable): Gradient
                scaler used for Torch AMP training. Used to unscale the G
                factors as they are accumulated during the backward pass.
                Alternatively can be a callable which will return the current
                scale (default: None).
            factor_dtype (torch.dtype): force data type for storing factors.
                If None, defaults to data type of intermediate values in
                forward/backward pass (default: None).
            inv_dtype (torch.dtype): force data type for storing second-order
                data (e.g., inverses or eigen decompositions)
                (default: torch.float32).
            factor_checkpoint_dir (str): directory to store factors
                checkpoints in.
            skip_layers (list): list of module names to ignore when registering
                layers. Passing the name of parent modules will prevent
                recursively registering child modules of the parent.
                Case-insensitive (default: []).
            update_factors_in_hook (bool): If True, running average of factors
                is updated in the module hook and the async communication is
                started. Otherwise, this will be performed at the start of
                step() (default: True).
            loglevel (int): logging level (default: logging.DEBUG).
        """
        if deepspeed_import_error is not None:  # pragma: no cover
            raise deepspeed_import_error

        warnings.warn(
            'KFAC support for GPT-NeoX training is experimental.',
            ExperimentalFeatureWarning,
        )

        if not isinstance(model, PipelineModule):
            raise ValueError(
                'model must be an instance of deepspeed.pipe.PipelineModule. '
                f'Got an instance of {type(model)}.',
            )

        if allreduce_bucket_cap_mb < 0:
            raise ValueError('allreduce_bucket_cap_mb must be >= 0')
        if isinstance(assignment_strategy, str):
            assignment_strategy = AssignmentStrategy[
                assignment_strategy.upper()
            ]
        if isinstance(compute_method, str):
            compute_method = ComputeMethod[compute_method.upper()]

        self.allreduce_bucket_cap_mb = allreduce_bucket_cap_mb
        self.assignment_strategy = assignment_strategy
        self.compute_eigenvalue_outer_product = (
            compute_eigenvalue_outer_product
        )
        self.compute_method = compute_method
        self.grad_scaler = grad_scaler
        self.factor_dtype = factor_dtype
        self.inv_dtype = inv_dtype
        self.factor_checkpoint_dir = factor_checkpoint_dir
        self.skip_layers = [] if skip_layers is None else skip_layers
        self.symmetry_aware = symmetry_aware

        self.data_parallel_group = data_parallel_group
        self.model_parallel_group = model_parallel_group
        self.pipeline_parallel_group = pipeline_parallel_group

        if self.allreduce_bucket_cap_mb > 0:
            self.allreduce_method = AllreduceMethod.ALLREDUCE_BUCKETED
        else:
            self.allreduce_method = AllreduceMethod.ALLREDUCE
        self.tdc = TorchDistributedCommunicator(
            bucket_cap_mb=self.allreduce_bucket_cap_mb,
        )

        layer_kwargs = dict(
            allreduce_method=self.allreduce_method,
            grad_scaler=self.grad_scaler,
            factor_dtype=self.factor_dtype,
            inv_dtype=self.inv_dtype,
            symmetry_aware=self.symmetry_aware,
            tdc=self.tdc,
        )

        if self.compute_method == ComputeMethod.EIGEN:
            pass
        elif self.compute_method == ComputeMethod.INVERSE:
            raise ValueError('Inverse method not supported with GPT NeoX.')
        else:
            raise AssertionError(
                f'Unknown compute_method={self.compute_method}',
            )

        kfac_layers = register_modules(
            model,
            model_parallel_group=self.model_parallel_group,
            skip_layers=self.skip_layers,
            **layer_kwargs,
        )

        data_parallel_ranks = [
            -1 for _ in range(get_world_size(self.data_parallel_group))
        ]
        torch.distributed.all_gather_object(
            object_list=data_parallel_ranks,
            obj=get_rank(),
            group=self.data_parallel_group,
        )

        for name, kfac_layer in kfac_layers.values():
            logger.log(
                loglevel,
                f'Registered name="{name}": {repr(kfac_layer)} on '
                f'global-rank={get_rank()} and '
                f'pipeline-rank={get_rank(self.pipeline_parallel_group)}',
            )

        if self.assignment_strategy == AssignmentStrategy.COMPUTE:
            cost_func = lambda n: n**3  # noqa: E731
        elif self.assignment_strategy == AssignmentStrategy.MEMORY:
            cost_func = lambda n: n**2  # noqa: E731
        else:
            raise AssertionError(
                f'Unknown assignment_strategy={self.assignment_strategy}',
            )

        work = {
            name: {
                'A': cost_func(kfac_layer.module.a_factor_shape[0]),
                'G': cost_func(kfac_layer.module.g_factor_shape[0]),
            }
            for name, kfac_layer in kfac_layers.values()
        }

        assignment = GPTNeoXAssignment(
            work,
            local_rank=get_rank(),
            topology=model.topology(),
            data_parallel_group=self.data_parallel_group,
            model_parallel_group=self.model_parallel_group,
        )
        logger.log(loglevel, f'KFAC layer assignments: {assignment}')

        # Set primary rank for each layer
        for name, kfac_layer in kfac_layers.values():
            # Inv worker for A and G should be the same because
            # GPTNeoXAssignment uses gradient worker fraction 1/world_size
            # (MEM_OPT)
            if not isinstance(kfac_layer, GPTNeoXKFACEigenLayer):
                raise AssertionError(
                    'GPTNeoXKFACPreconditioner only supports '
                    'GPTNeoXKFACEigenLayer.',
                )
            kfac_layer.primary_rank = assignment.factor_worker(name, 'A')
            kfac_layer.data_parallel_group = assignment.data_parallel_group
            kfac_layer.pipe_parallel_peer_group = (
                assignment.pipe_parallel_peer_group
            )

        defaults = {
            'allreduce_bucket_cap_mb': self.allreduce_bucket_cap_mb,
            'allreduce_method': self.allreduce_method,
            'assignment_strategy': self.assignment_strategy,
            'compute_eigenvalue_outer_product': (
                self.compute_eigenvalue_outer_product
            ),
            'compute_method': self.compute_method,
            'grad_scaler': self.grad_scaler is not None,
            'factor_checkpoint_dir': self.factor_checkpoint_dir,
            'factor_dtype': self.factor_dtype,
            'inv_dtype': self.inv_dtype,
            'skip_layers': self.skip_layers,
            'symmetry_aware': self.symmetry_aware,
        }

        super().__init__(
            kfac_layers,
            factor_update_steps=factor_update_steps,
            inv_update_steps=inv_update_steps,
            factor_decay=factor_decay,
            damping=damping,
            kl_clip=kl_clip,
            lr=lr,
            accumulation_steps=accumulation_steps,
            assignment=assignment,
            update_factors_in_hook=update_factors_in_hook,
            defaults=defaults,
            tdc=self.tdc,
            loglevel=loglevel,
        )
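A hypothetical construction sketch for this preconditioner; the import path, hyperparameter values, and process-group variables are assumptions for illustration only:

# Assumed import path; the model must be a deepspeed.pipe.PipelineModule.
from kfac.gpt_neox.preconditioner import GPTNeoXKFACPreconditioner

preconditioner = GPTNeoXKFACPreconditioner(
    model,                              # a deepspeed PipelineModule
    factor_update_steps=10,
    inv_update_steps=100,
    damping=0.001,
    factor_checkpoint_dir='/tmp/kfac-factors',  # enables per-layer factor files
    data_parallel_group=dp_group,       # groups provided by the training setup
    model_parallel_group=mp_group,
    pipeline_parallel_group=pp_group,
)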
Example #11
    def preconditioned_grad(
        self,
        damping: float = 0.001,
    ) -> None:  # pragma: no cover
        """Compute precondition gradient of each weight in module.

        Note:
            Unlike KFACEigenLayer, every rank in the model parallel group
            should enter this function.

        Preconditioned gradients can be applied to the actual gradients with
        `update_grad()`. The two steps are kept separate in case intermediate
        operations need to be applied to the preconditioned gradient.

        Args:
            damping (float, optional): damping to use if preconditioning using
                the eigendecomposition method (default: 0.001).
        """
        if self.primary_rank is None:
            raise RuntimeError('primary rank has not been set yet.')

        if get_rank() == self.primary_rank and (
                self.qa is None or self.qg is None or
            (not self.prediv_eigenvalues and self.da is None) or
            (not self.prediv_eigenvalues and self.dg is None) or
            (self.prediv_eigenvalues and self.dgda is None)):
            raise RuntimeError(
                'Eigendecompositions for both A and G have not been computed',
            )

        grad_partition = self.module.get_weight_grad()
        grad = gather_from_model_parallel_region(
            grad_partition,
            dst=self.primary_rank,
            model_parallel_group=self.model_parallel_group,
            dim=-1 if self.parallelism == 'input' else 0,
        )

        if self.module.has_bias():
            bias_grad_partition = self.module.get_bias_grad()
            # Bias is only actually partitioned if parallelism is done on
            # output
            if self.parallelism == 'output':
                bias_grad = gather_from_model_parallel_region(
                    bias_grad_partition,
                    dst=self.primary_rank,
                    model_parallel_group=self.model_parallel_group,
                    dim=0,
                )
            else:
                bias_grad = bias_grad_partition
        else:
            bias_grad = None

        if grad is not None:
            # Only perform preconditioning on the worker that received the
            # full gradient
            grad_shape = grad.size()
            if self.module.has_bias():
                assert bias_grad is not None
                bias_grad_shape = bias_grad.size()
                grad = torch.cat([grad, bias_grad.view(-1, 1)], 1)

            # mypy won't know these are not none because they are properties
            assert self.da is not None
            assert self.dg is not None
            assert self.qa is not None
            assert self.qg is not None

            grad_type = grad.dtype
            grad = grad.to(self.qa.dtype)

            v1 = self.qg.t() @ grad @ self.qa
            if self.prediv_eigenvalues:
                v2 = v1 * self.dgda
            else:
                v2 = v1 / (torch.outer(self.dg, self.da) + damping)

            grad = (self.qg @ v2 @ self.qa.t()).to(grad_type)

            if self.module.has_bias():
                weight_grad = grad[:, :-1].view(grad_shape)
                bias_grad = grad[:, -1:].view(bias_grad_shape).contiguous()
            else:
                weight_grad = grad.view(grad_shape)

            weight_grads = list(
                split_tensor_along_dim(
                    weight_grad,
                    get_world_size(self.model_parallel_group),
                    dim=-1 if self.parallelism == 'input' else 0,
                    contiguous_split_chunks=True,
                ), )
            if self.module.has_bias() and self.parallelism == 'output':
                assert bias_grad is not None
                bias_grads = list(
                    split_tensor_along_dim(
                        bias_grad,
                        get_world_size(self.model_parallel_group),
                        dim=0,
                        contiguous_split_chunks=True,
                    ), )
        else:
            weight_grads = [
                torch.zeros_like(grad_partition)
                for _ in range(get_world_size(self.model_parallel_group))
            ]
            if self.parallelism == 'output':
                bias_grads = [
                    torch.zeros_like(bias_grad_partition)
                    for _ in range(get_world_size(self.model_parallel_group))
                ]

        # PyTorch NCCL does not support scatter, but we can emulate it with
        # reduce_scatter where the reduction operation is a sum and the
        # non-src ranks contribute zero-filled tensors.
        if get_world_size(self.model_parallel_group) > 1:
            torch.distributed.reduce_scatter(
                grad_partition,
                weight_grads,
                group=self.model_parallel_group,
            )
        else:
            grad_partition = weight_grads[0]

        if self.module.has_bias():
            if get_world_size(self.model_parallel_group) > 1:
                if self.parallelism == 'output':
                    torch.distributed.reduce_scatter(
                        bias_grad_partition,
                        bias_grads,
                        group=self.model_parallel_group,
                    )
                    bias_grad = bias_grad_partition
                else:
                    torch.distributed.broadcast(
                        bias_grad,
                        src=self.primary_rank,
                        group=self.model_parallel_group,
                    )
            assert bias_grad is not None
            self.grad = torch.cat([grad_partition, bias_grad.view(-1, 1)], 1)
        else:
            self.grad = grad_partition
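The core of the eigen-basis preconditioning above is the identity: with A = Q_A diag(d_A) Q_A^T and G = Q_G diag(d_G) Q_G^T, dividing Q_G^T (grad) Q_A element-wise by (d_G d_A^T + damping) and rotating back is equivalent to solving ((A ⊗ G) + damping * I) against the column-major vectorized gradient. A small self-contained numerical check of that equivalence (illustrative only, not library code):

import torch

torch.manual_seed(0)
in_dim, out_dim, damping = 4, 3, 1e-3
a = torch.randn(in_dim, in_dim, dtype=torch.float64)
a = a @ a.t()  # stands in for the A (input covariance) factor
g = torch.randn(out_dim, out_dim, dtype=torch.float64)
g = g @ g.t()  # stands in for the G (gradient covariance) factor
grad = torch.randn(out_dim, in_dim, dtype=torch.float64)

da, qa = torch.linalg.eigh(a)
dg, qg = torch.linalg.eigh(g)

# Eigen-basis preconditioning, mirroring the method above.
v1 = qg.t() @ grad @ qa
v2 = v1 / (torch.outer(dg, da) + damping)
precond = qg @ v2 @ qa.t()

# Direct solve against the damped Kronecker-factored curvature.
curvature = torch.kron(a, g) + damping * torch.eye(
    in_dim * out_dim, dtype=torch.float64,
)
vec = grad.t().reshape(-1)  # column-major vectorization of the gradient
expected = torch.linalg.solve(curvature, vec).reshape(in_dim, out_dim).t()

torch.testing.assert_close(precond, expected)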
Example #12
    def __init__(
        self,
        model: torch.nn.Module,
        *,
        factor_update_steps: Callable[[int], int] | int = 1,
        inv_update_steps: Callable[[int], int] | int = 1,
        # KFAC hyperparameters
        damping: Callable[[int], float] | float = 0.001,
        factor_decay: Callable[[int], float] | float = 0.95,
        kl_clip: Callable[[int], float] | float = 0.001,
        lr: Callable[[int], float] | float = 0.1,
        # Distribution strategy
        accumulation_steps: int = 1,
        allreduce_bucket_cap_mb: float = 25.0,
        assignment_strategy: (AssignmentStrategy
                              | str) = AssignmentStrategy.COMPUTE,
        colocate_factors: bool = True,
        compute_method: ComputeMethod | str = ComputeMethod.EIGEN,
        compute_eigenvalue_outer_product: bool = True,
        grad_worker_fraction: (DistributedStrategy
                               | float) = DistributedStrategy.COMM_OPT,
        symmetry_aware: bool = False,
        # Optional other parameters
        grad_scaler: (torch.cuda.amp.GradScaler | Callable[[], float]
                      | None) = None,
        factor_dtype: torch.dtype | None = None,
        inv_dtype: torch.dtype = torch.float32,
        skip_layers: list[str] | None = None,
        update_factors_in_hook: bool = True,
        loglevel: int = logging.DEBUG,
    ) -> None:
        """Init KFACPreconditioner.

        Args:
            model (torch.nn.Module): model to precondition with KFAC.
            factor_update_steps (Callable, int): steps between computing and
                updating the running average of the Kronecker factors or
                callable that takes the K-FAC step and returns the value.
            inv_update_steps (Callable, int): steps between recomputing and
                communicating the second-order information or callable that
                takes the K-FAC step and returns the value.
            damping (Callable, float): Tikhonov damping parameter or a callable
                that takes the K-FAC step and returns the damping parameter
                as a float (default: 0.001).
            factor_decay (Callable, float): running average coefficient for
                Kronecker factors or callable that takes the K-FAC step and
                returns the factor_decay (default: 0.95).
            kl_clip (Callable, float): clipping parameter for gradient scaling
                or a callable that takes the K-FAC step and returns a float.
                If None, no scaling/clipping will be applied (default: 0.001).
            lr (Callable, float): learning rate or callable that takes the
                K-FAC step and returns learning rate (default: 0.1).
            accumulation_steps (int): number of forward/backward passes
                between optimization steps (default: 1).
            allreduce_bucket_cap_mb (float): maximum size in megabytes for
                allreduce bucketing. If zero, bucketing is not used
                (default: 25).
            assignment_strategy (AssignmentStrategy, str): See
                `AssignmentStrategy` for more details
                (default: AssignmentStrategy.COMPUTE).
            colocate_factors (bool): assign both factors for a single layer to
                the same worker. Recommended when num_layers < world_size
                (default: True).
            compute_method (ComputeMethod, str): See `ComputeMethod` for more
                details (default: ComputeMethod.EIGEN).
            compute_eigenvalue_outer_product (bool): when using the eigen
                compute method, precompute the element-wise inverse of the
                outer product of eigenvalues on the eigendecomposition worker
                to reduce computation in the gradient preconditioning stage.
                `colocate_factors` must be True (default: True).
            grad_worker_fraction (DistributedStrategy, float): controls the
                fraction of workers assigned as gradient workers for each
                layer. Optionally, predefined configurations can be passed
                using the DistributedStrategy enum
                (default: DistributedStrategy.COMM_OPT).
            symmetry_aware (bool): communicate only the upper triangle of
                symmetric matrices. Can reduce communication time when factors
                are large (default: False).
            grad_scaler (torch.cuda.amp.GradScaler or callable): Gradient
                scaler used for Torch AMP training. Used to unscale the G
                factors as they are accumulated during the backward pass.
                Alternatively can be a callable which will return the current
                scale (default: None).
            factor_dtype (torch.dtype): force data type for storing factors.
                If None, defaults to data type of intermediate values in
                forward/backward pass (default: None).
            inv_dtype (torch.dtype): force data type for storing second-order
                data (e.g., inverses or eigen decompositions)
                (default: torch.float32).
            skip_layers (list[str]): regex patterns that if matched, will cause
                the layer to not be registered. The patterns will be applied
                against the layer's name and class name.
            update_factors_in_hook (bool): If True, running average of factors
                is updated in the module hook and the async communication is
                started. Otherwise, this will be performed at the start of
                step() (default: True).
            loglevel (int): logging level (default: logging.DEBUG).
        """
        if allreduce_bucket_cap_mb < 0:
            raise ValueError('allreduce_bucket_cap_mb must be >= 0')
        if (compute_method == ComputeMethod.EIGEN
                and compute_eigenvalue_outer_product and not colocate_factors):
            raise ValueError(
                'colocate_factors must be True to use '
                'compute_eigenvalue_outer_product', )
        if isinstance(assignment_strategy, str):
            assignment_strategy = AssignmentStrategy[
                assignment_strategy.upper()]
        if isinstance(compute_method, str):
            compute_method = ComputeMethod[compute_method.upper()]

        size = get_world_size()
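        # Illustrative mapping: with a world size of 8, COMM_OPT corresponds
        # to grad_worker_fraction=1.0 (all ranks are gradient workers for
        # every layer), HYBRID_OPT to 0.5 (gradient-worker groups of 4), and
        # MEM_OPT to 1/8 (a single gradient worker per layer, trading extra
        # communication for minimal memory).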
        if isinstance(grad_worker_fraction, DistributedStrategy):
            distributed_strategy = grad_worker_fraction
            if distributed_strategy == DistributedStrategy.COMM_OPT:
                grad_worker_fraction = 1.0
            elif distributed_strategy == DistributedStrategy.HYBRID_OPT:
                grad_worker_fraction = 0.5
            elif distributed_strategy == DistributedStrategy.MEM_OPT:
                grad_worker_fraction = 1.0 / size
            else:
                raise AssertionError(f'Unknown enum {grad_worker_fraction}')
        else:
            if not 0 <= grad_worker_fraction <= 1:
                raise ValueError('grad_worker_fraction must be in [0, 1]')
            if grad_worker_fraction == 0:
                grad_worker_fraction = 1.0 / size
            if size % max(1, round(size * grad_worker_fraction)) != 0:
                raise ValueError(
                    'grad_worker_fraction must produce groups of '
                    'equal size', )
            if grad_worker_fraction == 1:
                grad_worker_fraction = 1.0  # ensure float
                distributed_strategy = DistributedStrategy.COMM_OPT
            elif grad_worker_fraction <= 1 / size:
                distributed_strategy = DistributedStrategy.MEM_OPT
            else:
                distributed_strategy = DistributedStrategy.HYBRID_OPT
        assert isinstance(grad_worker_fraction, float)

        if (not colocate_factors
                and distributed_strategy is DistributedStrategy.MEM_OPT):
            warnings.warn(
                'grad_worker_frac=1/world_size (MEM_OPT) requires '
                'colocate_factors=True. Enabling colocate_factors.', )
            colocate_factors = True

        self.allreduce_bucket_cap_mb = allreduce_bucket_cap_mb
        self.assignment_strategy = assignment_strategy
        self.colocate_factors = colocate_factors
        self.compute_eigenvalue_outer_product = (
            compute_eigenvalue_outer_product)
        self.compute_method = compute_method
        self.distributed_strategy = distributed_strategy
        self.grad_worker_fraction = grad_worker_fraction
        self.grad_scaler = grad_scaler
        self.factor_dtype = factor_dtype
        self.inv_dtype = inv_dtype
        self.skip_layers = [] if skip_layers is None else skip_layers
        self.symmetry_aware = symmetry_aware

        if self.allreduce_bucket_cap_mb > 0:
            self.allreduce_method = AllreduceMethod.ALLREDUCE_BUCKETED
        else:
            self.allreduce_method = AllreduceMethod.ALLREDUCE
        self.tdc = TorchDistributedCommunicator(
            bucket_cap_mb=self.allreduce_bucket_cap_mb, )

        layer_kwargs = dict(
            allreduce_method=self.allreduce_method,
            grad_scaler=self.grad_scaler,
            factor_dtype=self.factor_dtype,
            inv_dtype=self.inv_dtype,
            symmetry_aware=self.symmetry_aware,
            tdc=self.tdc,
        )

        layer_type: type[KFACBaseLayer]
        if self.compute_method == ComputeMethod.EIGEN:
            layer_type = KFACEigenLayer
            layer_kwargs['prediv_eigenvalues'] = (
                self.compute_eigenvalue_outer_product
            )
        elif self.compute_method == ComputeMethod.INVERSE:
            layer_type = KFACInverseLayer
        else:
            raise AssertionError(
                f'Unknown compute_method={self.compute_method}', )

        kfac_layers = register_modules(
            model,
            kfac_layer_type=layer_type,
            skip_layers=self.skip_layers,
            **layer_kwargs,
        )
        for name, kfac_layer in kfac_layers.values():
            logger.log(
                loglevel,
                f'Registered name="{name}": {repr(kfac_layer)}',
            )

        if self.assignment_strategy == AssignmentStrategy.COMPUTE:
            cost_func = lambda n: n**3  # noqa: E731
        elif self.assignment_strategy == AssignmentStrategy.MEMORY:
            cost_func = lambda n: n**2  # noqa: E731
        else:
            raise AssertionError(
                f'Unknown assignment_strategy={self.assignment_strategy}', )

        work = {
            name: {
                'A': cost_func(kfac_layer.module.a_factor_shape[0]),
                'G': cost_func(kfac_layer.module.g_factor_shape[0]),
            }
            for name, kfac_layer in kfac_layers.values()
        }

        new_group = cast(
            Callable[[List[int]], dist.ProcessGroup],
            dist.new_group,
        )
        mock_new_group: Callable[[list[int]], None] = lambda x: None
        assignment = KAISAAssignment(
            work,
            local_rank=get_rank(),
            world_size=get_world_size(),
            grad_worker_fraction=self.grad_worker_fraction,
            group_func=new_group if dist.is_initialized() else mock_new_group,
            colocate_factors=self.colocate_factors,
        )
        logger.log(loglevel, f'KFAC layer assignments: {assignment}')

        defaults = {
            'allreduce_bucket_cap_mb': self.allreduce_bucket_cap_mb,
            'allreduce_method': self.allreduce_method,
            'assignment_strategy': self.assignment_strategy,
            'colocate_factors': self.colocate_factors,
            'compute_eigenvalue_outer_product': (
                self.compute_eigenvalue_outer_product
            ),
            'compute_method': self.compute_method,
            'distributed_strategy': self.distributed_strategy,
            'grad_worker_fraction': self.grad_worker_fraction,
            'grad_scaler': self.grad_scaler is not None,
            'factor_dtype': self.factor_dtype,
            'inv_dtype': self.inv_dtype,
            'skip_layers': self.skip_layers,
            'symmetry_aware': self.symmetry_aware,
        }

        super().__init__(
            kfac_layers,
            factor_update_steps=factor_update_steps,
            inv_update_steps=inv_update_steps,
            factor_decay=factor_decay,
            damping=damping,
            kl_clip=kl_clip,
            lr=lr,
            accumulation_steps=accumulation_steps,
            assignment=assignment,
            update_factors_in_hook=update_factors_in_hook,
            defaults=defaults,
            tdc=self.tdc,
            loglevel=loglevel,
        )
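A hypothetical construction sketch for the general preconditioner above; the import path and argument values are assumptions for illustration only:

# Assumed import path for the preconditioner defined above.
from kfac.preconditioner import KFACPreconditioner

preconditioner = KFACPreconditioner(
    model,
    factor_update_steps=10,
    inv_update_steps=100,
    damping=0.003,
    grad_worker_fraction=0.5,  # HYBRID_OPT: half the ranks per gradient-worker group
    allreduce_bucket_cap_mb=25.0,
    skip_layers=['embedding', 'lm_head'],  # example patterns, not defaults
)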
Example #13
def test_distributed_not_initialized() -> None:
    """Test rank/world_size functions when not using distributed."""
    assert get_rank() == 0
    assert get_world_size() == 1