def forward(self, x):  # pylint: disable=W
    '''
    :x:      [batch, feature_in,  beta, alpha, gamma]
    :return: [batch, feature_out, beta, alpha, gamma]
    '''
    assert x.size(1) == self.nfeature_in
    assert x.size(2) == 2 * self.b_in
    assert x.size(3) == 2 * self.b_in
    assert x.size(4) == 2 * self.b_in

    x = SO3_fft_real.apply(x, self.b_out)  # [l * m * n, batch, feature_in, complex]
    y = so3_rft(self.kernel * self.scaling, self.b_out, self.grid)  # [l * m * n, feature_in, feature_out, complex]
    assert x.size(0) == y.size(0)
    assert x.size(2) == y.size(1)
    z = so3_mm(x, y)  # [l * m * n, batch, feature_out, complex]
    assert z.size(0) == x.size(0)
    assert z.size(1) == x.size(1)
    assert z.size(2) == y.size(2)
    z = SO3_ifft_real.apply(z)  # [batch, feature_out, beta, alpha, gamma]

    z = z + self.bias

    return z
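# A minimal usage sketch for the layer whose forward pass is shown above. It
# assumes this forward belongs to s2cnn's SO3Convolution layer (constructor
# SO3Convolution(nfeature_in, nfeature_out, b_in, b_out, grid)) and that
# so3_near_identity_grid() is available with its default arguments; the
# bandwidths and channel counts below are illustrative.
import torch
from s2cnn import SO3Convolution, so3_near_identity_grid

b_in, b_out = 6, 4   # input / output bandwidth
f_in, f_out = 8, 16  # input / output feature channels

grid = so3_near_identity_grid()
conv = SO3Convolution(f_in, f_out, b_in, b_out, grid)

x = torch.randn(10, f_in, 2 * b_in, 2 * b_in, 2 * b_in)  # [batch, feature_in, beta, alpha, gamma]
y = conv(x)  # [batch, feature_out, 2 * b_out, 2 * b_out, 2 * b_out]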
def test_so3_rfft(b_in, b_out, device):
    x = torch.randn(2 * b_in, 2 * b_in, 2 * b_in, dtype=torch.float, device=device)  # [beta, alpha, gamma]

    from s2cnn.soft.so3_fft import so3_rfft
    y1 = so3_rfft(x, b_out=b_out)

    from s2cnn import so3_rft, so3_soft_grid
    import lie_learn.spaces.S3 as S3

    # so3_ft computes a non-weighted Fourier transform
    weights = torch.tensor(S3.quadrature_weights(b_in), dtype=torch.float, device=device)
    x2 = torch.einsum("bac,b->bac", (x, weights))
    y2 = so3_rft(x2.view(-1), b_out, so3_soft_grid(b_in))

    assert (y1 - y2).abs().max().item() < 1e-4 * y1.abs().mean().item()
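# Hedged sketch of one way to drive the test above with pytest; the bandwidth
# pairs and the fixed CPU device below are illustrative choices, not taken
# from the original test suite.
import pytest
import torch

@pytest.mark.parametrize("b_in, b_out", [(6, 6), (6, 4), (8, 8)])
def test_so3_rfft_cpu(b_in, b_out):
    test_so3_rfft(b_in, b_out, torch.device("cpu"))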
def call(self, x):
    assert K.int_shape(x)[1] == self.nfeature_in
    assert K.int_shape(x)[2] == 2 * self.b_in
    assert K.int_shape(x)[3] == 2 * self.b_in
    assert K.int_shape(x)[4] == 2 * self.b_in

    x = so3_rfft(x, self.b_out)
    y = so3_rft(self.kernel * self.scaling, self.b_out, self.grid)
    assert K.int_shape(x)[0] == K.int_shape(y)[0]
    assert K.int_shape(x)[2] == K.int_shape(y)[1]
    z = K.dot(x, y)  # [l * m * n, batch, feature_out, complex]
    assert K.int_shape(z)[0] == K.int_shape(x)[0]
    assert K.int_shape(z)[1] == K.int_shape(x)[1]
    assert K.int_shape(z)[2] == K.int_shape(y)[2]
    z = so3_rifft(z)  # [batch, feature_out, beta, alpha, gamma]

    if self.use_bias:
        z = K.eval(z)
        z = z + self.bias
        z = K.constant(z)

    return z
'''
Compare so3_ft with so3_fft
'''
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

b_in, b_out = 6, 6  # bandwidth

# random input data to be Fourier transformed
x = torch.randn(2 * b_in, 2 * b_in, 2 * b_in, dtype=torch.float, device=device)  # [beta, alpha, gamma]

# Fast version
from s2cnn.soft.so3_fft import so3_rfft
y1 = so3_rfft(x, b_out=b_out)

# Equivalent result using the naive version
from s2cnn import so3_rft, so3_soft_grid
import lie_learn.spaces.S3 as S3

# so3_ft computes a non-weighted Fourier transform
weights = torch.tensor(S3.quadrature_weights(b_in), dtype=torch.float, device=device)
x = torch.einsum("bac,b->bac", (x, weights))
y2 = so3_rft(x.view(-1), b_out, so3_soft_grid(b_in))

# Compare values
assert (y1 - y2).abs().max().item() < 1e-4 * y1.abs().mean().item()
import time
import torch

b = 6  # bandwidth

# random input data to be Fourier transformed
x = torch.randn(2 * b, 2 * b, 2 * b, dtype=torch.float, device="cuda")  # [beta, alpha, gamma]

# Fast version
from s2cnn.soft.gpu.so3_fft import so3_rfft
t = time.perf_counter()
y1 = so3_rfft(x)
print("so3_rfft: {}s".format(time.perf_counter() - t))

# Equivalent result using the naive version
from s2cnn import so3_rft, so3_soft_grid
import lie_learn.spaces.S3 as S3

t = time.perf_counter()
# so3_ft computes a non-weighted Fourier transform
weights = torch.tensor(S3.quadrature_weights(b), dtype=torch.float, device="cuda")
x = torch.einsum("bac,b->bac", (x, weights))
y2 = so3_rft(x.view(-1), b, so3_soft_grid(b))
print("so3_rft: {}s".format(time.perf_counter() - t))

# Compare values
assert (y1 - y2).abs().max().item() < 1e-4 * y1.abs().mean().item()
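# CUDA kernel launches are asynchronous, so the perf_counter differences above
# can under-report the actual GPU compute time. A minimal sketch of a
# synchronized timing helper (the name time_on_gpu is made up here for
# illustration):
import time
import torch

def time_on_gpu(fn, *args):
    torch.cuda.synchronize()  # finish any pending GPU work before starting the clock
    t = time.perf_counter()
    out = fn(*args)
    torch.cuda.synchronize()  # wait for the kernels launched by fn to complete
    return out, time.perf_counter() - t

# Example: y1, dt = time_on_gpu(so3_rfft, x)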