Ejemplo n.º 1
0
def test_get_cov(
    a: torch.Tensor,
    b: torch.Tensor | None,
    scale: float | None,
    expected: torch.Tensor,
) -> None:
    """Check get_cov against either an expected result or a ValueError.

    Inputs where ``a`` is not 2D, or where ``b`` is supplied with a shape
    different from ``a``, must be rejected with ``ValueError``. Every other
    input must produce exactly ``expected``; when ``b`` is omitted the result
    must additionally be exactly equal to its own transpose.
    """
    mismatched_b = b is not None and b.shape != a.shape
    should_raise = len(a.shape) != 2 or mismatched_b

    if should_raise:
        with pytest.raises(ValueError):
            get_cov(a, b, scale)
        return

    result = get_cov(a, b, scale)
    assert torch.equal(result, expected)
    if b is None:
        # Without a second operand the result must be exactly symmetric.
        assert torch.equal(result, result.t())
Ejemplo n.º 2
0
    def get_g_factor(self, g: torch.Tensor) -> torch.Tensor:
        """Compute the G factor from the gradient w.r.t. the layer output.

        Args:
            g (torch.Tensor): tensor with shape batch_size * out_dim. Any
                extra leading dimensions are folded into the batch dimension.

        Returns:
            torch.Tensor: the result of ``get_cov`` on ``g`` flattened to 2D.
        """
        out_dim = g.size(-1)
        flat_grad = g.reshape(-1, out_dim)
        return get_cov(flat_grad)
Ejemplo n.º 3
0
 def get_a_factor(self, a: torch.Tensor) -> torch.Tensor:
     """Compute the A factor from the input seen in the forward pass.

     The input is first converted by ``self._extract_patches``. The product
     of the sizes of dimensions 1 and 2 of the resulting tensor is treated
     as the spatial size. Presumably the patch tensor is
     (batch, out_h, out_w, patch_dim) -- TODO confirm against
     ``_extract_patches``.

     Args:
         a (torch.Tensor): input activation passed to the layer.

     Returns:
         torch.Tensor: ``get_cov`` of the flattened patches. If the layer has
         a bias, ``append_bias_ones`` is applied first. Every row is divided
         by the spatial size.
     """
     patches = self._extract_patches(a)
     spatial_size = patches.size(1) * patches.size(2)
     rows = patches.view(-1, patches.size(-1))
     if self.has_bias():
         rows = append_bias_ones(rows)
     # Scaling happens after the bias column is added, so it is scaled too.
     return get_cov(rows / spatial_size)
Ejemplo n.º 4
0
    def get_a_factor(self, a: torch.Tensor) -> torch.Tensor:
        """Compute A factor with the input from the forward pass.

        Args:
            a (torch.Tensor): tensor with shape batch_size * in_dim. Any
                extra leading dimensions are folded into the batch dimension.

        Returns:
            torch.Tensor: ``get_cov`` of the flattened input. If the layer has
            a bias, ``append_bias_ones`` is applied first (presumably adding
            a constant column for the bias term -- confirm).
        """
        # reshape (not view): forward inputs may be non-contiguous (e.g. a
        # transposed activation), which makes view raise. reshape returns
        # the same view without copying whenever view would have worked.
        a = a.reshape(-1, a.size(-1))
        if self.has_bias():
            a = append_bias_ones(a)
        return get_cov(a)
Ejemplo n.º 5
0
    def get_g_factor(self, g: torch.Tensor) -> torch.Tensor:
        """Compute the G factor from the gradient w.r.t. the layer output.

        Args:
            g (torch.Tensor): tensor with shape batch_size * n_filters *
                out_h * out_w. Here n_filters plays the role of the output
                dimension (analogous to a Linear layer).

        Returns:
            torch.Tensor: ``get_cov`` of the gradient rearranged so that each
            row holds the n_filters values at one spatial position of one
            sample. Every row is divided by out_h * out_w.
        """
        out_h, out_w = g.size(2), g.size(3)
        # Move the filter axis last: (N, C, H, W) -> (N, H, C, W) -> (N, H, W, C).
        channels_last = g.transpose(1, 2).transpose(2, 3)
        rows = channels_last.reshape(-1, channels_last.size(-1))
        return get_cov(rows / (out_h * out_w))