Beispiel #1
0
    def forward(self, x):  # pylint: disable=W
        """Run the SO(3) convolution on ``x``.

        The input is taken to the spectral domain with ``SO3_fft_real``.
        It is combined via ``so3_mm`` with the transform of the scaled
        kernel (``so3_rft`` over ``self.grid``). The result is brought back
        with ``SO3_ifft_real`` and offset by ``self.bias``.

        :x:      [batch, feature_in,  beta, alpha, gamma]
        :return: [batch, feature_out, beta, alpha, gamma]
        """
        # Input layout checks: feature count first, then the three
        # spatial axes, which must each have length 2 * b_in.
        assert x.size(1) == self.nfeature_in
        spatial = 2 * self.b_in
        for axis in (2, 3, 4):
            assert x.size(axis) == spatial

        # [l * m * n, batch, feature_in, complex]
        x_hat = SO3_fft_real.apply(x, self.b_out)
        # [l * m * n, feature_in, feature_out, complex]
        kernel_hat = so3_rft(self.kernel * self.scaling, self.b_out, self.grid)
        assert x_hat.size(0) == kernel_hat.size(0)
        assert x_hat.size(2) == kernel_hat.size(1)

        # [l * m * n, batch, feature_out, complex]
        out_hat = so3_mm(x_hat, kernel_hat)
        assert out_hat.size(0) == x_hat.size(0)
        assert out_hat.size(1) == x_hat.size(1)
        assert out_hat.size(2) == kernel_hat.size(2)

        # [batch, feature_out, beta, alpha, gamma]
        out = SO3_ifft_real.apply(out_hat)
        return out + self.bias
Beispiel #2
0
def test_so3_rfft(b_in, b_out, device):
    """Check the fast ``so3_rfft`` against the naive ``so3_rft``.

    ``so3_rft`` computes an unweighted transform. The random signal is
    therefore multiplied by the SO(3) quadrature weights (one per beta
    sample) before being passed to it. The two spectra must agree to
    within 1e-4 of the mean magnitude of the fast result.
    """
    side = 2 * b_in
    signal = torch.randn(side, side, side, dtype=torch.float, device=device)  # [beta, alpha, gamma]

    from s2cnn.soft.so3_fft import so3_rfft
    fast = so3_rfft(signal, b_out=b_out)

    from s2cnn import so3_rft, so3_soft_grid
    import lie_learn.spaces.S3 as S3

    # Weight every beta slice so the unweighted naive transform matches.
    beta_weights = torch.tensor(S3.quadrature_weights(b_in), dtype=torch.float, device=device)
    weighted = torch.einsum("bac,b->bac", (signal, beta_weights))
    naive = so3_rft(weighted.view(-1), b_out, so3_soft_grid(b_in))

    max_err = (fast - naive).abs().max().item()
    tolerance = 1e-4 * fast.abs().mean().item()
    assert max_err < tolerance
Beispiel #3
0
    def call(self, x):
        """Apply the SO(3) convolution to ``x`` (Keras ``Layer.call``).

        ``x`` goes to the spectral domain with ``so3_rfft``. It is combined
        with the transform of the scaled kernel (``so3_rft`` over
        ``self.grid``) and brought back with ``so3_rifft``. When
        ``self.use_bias`` is set, ``self.bias`` is added.

        :x:      presumably [batch, feature_in, beta, alpha, gamma] -- TODO confirm
        :return: [batch, feature_out, beta, alpha, gamma]
        """
        in_shape = K.int_shape(x)
        assert in_shape[1] == self.nfeature_in
        assert in_shape[2] == 2 * self.b_in
        assert in_shape[3] == 2 * self.b_in
        assert in_shape[4] == 2 * self.b_in

        x = so3_rfft(x, self.b_out)
        y = so3_rft(self.kernel * self.scaling, self.b_out, self.grid)

        x_shape = K.int_shape(x)
        y_shape = K.int_shape(y)
        assert x_shape[0] == y_shape[0]
        assert x_shape[2] == y_shape[1]

        # NOTE(review): for N-D inputs K.dot contracts the LAST axis of x with
        # the second-to-last axis of y. With the layouts
        #   x: [l * m * n, batch, feature_in, complex]
        #   y: [l * m * n, feature_in, feature_out, complex]
        # that is not the per-degree SO(3) block product that s2cnn's PyTorch
        # layer computes with so3_mm, and the shape asserts below would very
        # likely fail. Verify and replace with an so3_mm equivalent.
        z = K.dot(x, y)  # [l * m * n, batch, feature_out, complex]
        z_shape = K.int_shape(z)
        assert z_shape[0] == x_shape[0]
        assert z_shape[1] == x_shape[1]
        assert z_shape[2] == y_shape[2]
        z = so3_rifft(z)  # [batch, feature_out, beta, alpha, gamma]

        if self.use_bias:
            # Keep the addition symbolic. Evaluating to numpy (K.eval) and
            # re-wrapping with K.constant fails on graph-mode tensors. It also
            # turns the output into a constant, cutting gradients to both the
            # kernel and the bias.
            z = z + self.bias

        return z
Beispiel #4
0
Compare the naive so3_rft with the fast so3_rfft
'''
import torch

from s2cnn.soft.so3_fft import so3_rfft
from s2cnn import so3_rft, so3_soft_grid
import lie_learn.spaces.S3 as S3

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Bandwidths of the input signal and of the computed spectrum.
b_in, b_out = 6, 6

# Random real input signal to transform: [beta, alpha, gamma].
x = torch.randn(2 * b_in, 2 * b_in, 2 * b_in, dtype=torch.float,
                device=device)

# Fast transform.
y1 = so3_rfft(x, b_out=b_out)

# Naive transform. so3_rft does not weight its input, so apply the
# quadrature weights (one per beta sample) to the signal first.
weights = torch.tensor(S3.quadrature_weights(b_in), dtype=torch.float,
                       device=device)
x = torch.einsum("bac,b->bac", (x, weights))
y2 = so3_rft(x.view(-1), b_out, so3_soft_grid(b_in))

# The two spectra must agree up to a tolerance relative to the mean magnitude.
max_err = (y1 - y2).abs().max().item()
assert max_err < 1e-4 * y1.abs().mean().item()
Beispiel #5
0
b = 6  # bandwidth
# Random real input signal to transform: [beta, alpha, gamma].
x = torch.randn(2 * b, 2 * b, 2 * b, dtype=torch.float, device="cuda")

# CUDA kernels are launched asynchronously. Without torch.cuda.synchronize()
# around each timed region, the perf_counter deltas would mostly measure
# kernel *launch* time, and the first region would also absorb the still
# pending randn above.
# NOTE(review): each transform is timed on a single cold call, so any
# one-time setup it performs is included; consider a warm-up run.


# Fast version
from s2cnn.soft.gpu.so3_fft import so3_rfft
torch.cuda.synchronize()
t = time.perf_counter()

y1 = so3_rfft(x)

torch.cuda.synchronize()
print("so3_rfft: {}s".format(time.perf_counter() - t))


# Equivalent version but using the naive version
from s2cnn import so3_rft, so3_soft_grid
import lie_learn.spaces.S3 as S3
torch.cuda.synchronize()
t = time.perf_counter()

# so3_rft computes an unweighted Fourier transform, so apply the
# quadrature weights (one per beta sample) to the input first.
weights = torch.tensor(S3.quadrature_weights(b), dtype=torch.float, device="cuda")
x = torch.einsum("bac,b->bac", (x, weights))

y2 = so3_rft(x.view(-1), b, so3_soft_grid(b))

torch.cuda.synchronize()
print("so3_rft: {}s".format(time.perf_counter() - t))


# Compare values
assert (y1 - y2).abs().max().item() < 1e-4 * y1.abs().mean().item()