Example #1
    def extensions(self, global_step):
        """Return list of BackPACK extensions required for the computation.

        Args:
            global_step (int): The current iteration number.

        Returns:
            list: (Potentially empty) list with required BackPACK quantities.
        """
        ext = []

        if self.is_active(global_step):
            ext.append(
                extensions.BatchGradTransforms(
                    transforms={"hist_2d": self._compute_histogram}))

        if self._adapt_schedule(global_step):
            if self._adapt_policy == "abs_max":
                ext.append(
                    extensions.BatchGradTransforms(
                        transforms={
                            "grad_batch_abs_max": transform_grad_batch_abs_max,
                            "param_abs_max": transform_param_abs_max,
                        }))
            elif self._adapt_policy == "min_max":
                ext.append(
                    extensions.BatchGradTransforms(
                        transforms={
                            "grad_batch_min_max": transform_grad_batch_min_max,
                            "param_min_max": transform_param_min_max,
                        }))
            else:
                raise ValueError("Invalid adaptation policy")

        return ext
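
For orientation, here is a minimal sketch of how such an extension list is typically consumed. The model, loss, and the hand-built transform below are placeholders (not part of the code above), and the snippet assumes a BackPACK release that still provides `BatchGradTransforms`, as the quantity above does; reading the results from a per-parameter `grad_batch_transforms` attribute is likewise an assumption here.

import torch
from backpack import backpack, extend, extensions

# Placeholder model and loss; in the setting above, `ext` would come from
# `quantity.extensions(global_step)` instead of being built by hand.
model = extend(torch.nn.Linear(10, 2))
lossfunc = extend(torch.nn.CrossEntropyLoss())

x, y = torch.rand(8, 10), torch.randint(0, 2, (8,))
loss = lossfunc(model(x), y)

ext = [
    extensions.BatchGradTransforms(
        transforms={"grad_abs_max": lambda batch_grad: batch_grad.abs().max()})
]

with backpack(*ext):
    loss.backward()

for param in model.parameters():
    # Transform results end up on each parameter; Cockpit reads them from the
    # `grad_batch_transforms` attribute (attribute name assumed here).
    print(param.grad_batch_transforms["grad_abs_max"])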
Example #2
    def _end_search_dir_projection_info(self):
        """Compute information for individual gradient projections onto search dir.

        The search direction at an end point is inferred from the model parameters.

        We want to compute dᵀgᵢ / ||d||₂² where d is the search direction and gᵢ are
        individual gradients. However, this fraction cannot be aggregated among
        parameters. We have to split the computation into components that can be
        aggregated, namely dᵀgᵢ and ||d||₂².
        """
        def compute_end_projection_info(batch_grad):
            """Compute information to project individual gradients onto the gradient."""
            batch_size = batch_grad.shape[0]

            end_param = batch_grad._param_weakref()
            start_param = self._get_info("params", end=False)[0][id(end_param)]

            # Search direction at the end point: displacement of this parameter
            # between the start and the end of the step.
            search_dir_flat = (end_param.data - start_param).flatten()
            batch_grad_flat = batch_grad.data.flatten(start_dim=1)

            search_dir_l2_squared = (search_dir_flat**2).sum()
            # Rescale by batch_size, presumably to undo the 1/batch_size factor
            # that BackPACK's individual gradients carry for a mean-reduced loss.
            dot_products = torch.einsum("ni,i->n",
                                        batch_size * batch_grad_flat,
                                        search_dir_flat)

            return {
                "dot_products": dot_products,
                "search_dir_l2_squared": search_dir_l2_squared,
            }

        return extensions.BatchGradTransforms(
            {"end_projection_info": compute_end_projection_info})
Example #3
    def _start_search_dir_projection_info(self):
        """Compute information for individual gradient projections onto search dir.

        The search direction at a start point depends on the optimizer.

        We want to compute dᵀgᵢ / ||d||₂² where d is the search direction and gᵢ are
        individual gradients. However, this fraction cannot be aggregated among
        parameters. We have to split the computation into components that can be
        aggregated, namely dᵀgᵢ and ||d||₂².
        """
        def compute_start_projection_info(batch_grad):
            """Compute information to project individual gradients onto the gradient."""
            batch_size = batch_grad.shape[0]

            # TODO Currently only correctly implemented for SGD! Make more general
            warnings.warn(
                "Alpha will only be correct if optimizer is SGD with momentum 0"
            )
            # For SGD with zero momentum the step is -lr * grad, so the search
            # direction is proportional to the negative gradient, recovered here
            # by summing the individual gradients over the batch dimension.
            search_dir_flat = -1 * (batch_grad.data.sum(0).flatten())
            batch_grad_flat = batch_grad.data.flatten(start_dim=1)

            search_dir_l2_squared = (search_dir_flat**2).sum()
            # As in the end-point version, rescale by batch_size to undo the
            # 1/batch_size factor of mean-reduced individual gradients.
            dot_products = torch.einsum("ni,i->n",
                                        batch_size * batch_grad_flat,
                                        search_dir_flat)

            return {
                "dot_products": dot_products,
                "search_dir_l2_squared": search_dir_l2_squared,
            }

        return extensions.BatchGradTransforms(
            {"start_projection_info": compute_start_projection_info})
Example #4
    def extensions(self, global_step):
        """Return list of BackPACK extensions required for the computation.

        Args:
            global_step (int): The current iteration number.

        Returns:
            list: (Potentially empty) list with required BackPACK quantities.
        """
        ext = []

        if self.is_active(global_step):
            ext.append(
                extensions.BatchGradTransforms(
                    transforms={"hist_1d": self._compute_histogram}))

        if self._adapt_schedule(global_step):
            ext.append(
                extensions.BatchGradTransforms(
                    transforms={
                        "grad_batch_abs_max": transform_grad_batch_abs_max
                    }))

        return ext
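
Each entry of the `transforms` dict maps a key to a callable that receives one parameter's stack of individual gradients, a tensor of shape `[N, *param.shape]` with `N` the batch size, and may return anything. The function below is an illustrative stand-in with the same flavor as `transform_grad_batch_abs_max`; it is not that helper's actual implementation:

import torch


def illustrative_grad_batch_abs_max(batch_grad: torch.Tensor) -> torch.Tensor:
    """Largest absolute entry over all individual gradients of one parameter.

    ``batch_grad`` has shape ``[N, *param.shape]`` with ``N`` the batch size.
    """
    return batch_grad.abs().max()


# Dummy individual gradients for a 3x2 parameter and a batch of 4 samples.
print(illustrative_grad_batch_abs_max(torch.randn(4, 3, 2)))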
Example #5
def BatchGradTransforms_BatchL2Grad():
    """Compute individual gradient ℓ₂ norms via individual gradients."""
    return extensions.BatchGradTransforms({"batch_l2": batch_l2_transform})
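
For reference, a hedged sketch of what the `batch_l2` transform could compute, assuming it follows BackPACK's `BatchL2Grad` convention of per-sample squared ℓ₂ norms; this is illustrative, not the library's `batch_l2_transform`:

import torch


def illustrative_batch_l2(batch_grad: torch.Tensor) -> torch.Tensor:
    """Per-sample squared ℓ₂ norm of the individual gradients, shape [N]."""
    return (batch_grad ** 2).flatten(start_dim=1).sum(1)


print(illustrative_batch_l2(torch.randn(4, 3, 2)))  # 4 values, one per sample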
Example #6
def BatchGradTransforms_SumGradSquared():
    """Compute sum of squared individual gradients via individual gradients."""
    return extensions.BatchGradTransforms(
        {"sum_grad_squared": sum_grad_squared_transform})
Example #7
def BatchGradTransforms_BatchDotGrad():
    """Compute pairwise individual gradient dot products via individual gradients."""
    return extensions.BatchGradTransforms({"batch_dot": batch_dot_transform})
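
And a `batch_dot` transform matching the semantics of BackPACK's `BatchDotGrad` would return the `[N, N]` matrix of pairwise dot products between flattened individual gradients (illustrative only):

import torch


def illustrative_batch_dot(batch_grad: torch.Tensor) -> torch.Tensor:
    """Pairwise dot products of individual gradients, shape [N, N]."""
    flat = batch_grad.flatten(start_dim=1)
    return torch.einsum("if,jf->ij", flat, flat)


print(illustrative_batch_dot(torch.randn(4, 3, 2)).shape)  # torch.Size([4, 4])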