def forward(self, batch):
    """Embed the batch, then run the stack of (convolution, activation) layers.

    ``self.layers[0]`` is the embedding. ``self.layers[1]`` is the first
    (convolution, activation) pair; its output replaces the embedded features.
    Every later pair in ``self.layers[2:]`` is applied as a residual update that
    is multiplied by ``mask.unsqueeze(-1)`` before being added.

    Spherical harmonics of ``diff_geo`` are recomputed only when a layer's
    ``set_of_l_filters`` differs from the one used by the previous layer.

    :param batch: raw batch, unpacked by ``constants`` into
        (features, _, mask, diff_geo, radii)
    :return: the final features
    """
    features, _, mask, diff_geo, radii = constants(batch)
    features = self.layers[0](features)

    # Inputs to every convolution are divided by sqrt(avg_n_atoms).
    norm = self.avg_n_atoms ** 0.5

    def conv_then_act(conv, act, feats, harmonics):
        # One convolution followed by its activation.
        out = conv(feats.div(norm), diff_geo, mask, y=harmonics, radii=radii,
                   custom_backward=CUSTOM_BACKWARD)
        return act(out)

    first_conv, first_act = self.layers[1]
    current_ls = first_conv.set_of_l_filters
    harmonics = spherical_harmonics_xyz(current_ls, diff_geo)
    # NOTE(review): unlike the later layers, this output is not multiplied by
    # the mask -- confirm padded positions are handled downstream.
    features = conv_then_act(first_conv, first_act, features, harmonics)

    for conv, act in self.layers[2:]:
        # Reuse the cached harmonics unless this layer needs different l's.
        if conv.set_of_l_filters != current_ls:
            current_ls = conv.set_of_l_filters
            harmonics = spherical_harmonics_xyz(current_ls, diff_geo)
        update = conv_then_act(conv, act, features, harmonics)
        features = features + update * mask.unsqueeze(-1)

    return features
def test_sh_parity(self):
    """Check the parity identity Y_l(-x) == (-1)^l * Y_l(x) for l = 0..7."""
    with o3.torch_default_dtype(torch.float64):
        for degree in range(8):
            point = torch.randn(3)
            sign = (-1) ** degree
            expected = sign * o3.spherical_harmonics_xyz(degree, point)
            mirrored = o3.spherical_harmonics_xyz(degree, -point)
            # Tolerance is relative to the magnitude of the harmonics.
            tolerance = 1e-10 * expected.abs().max()
            self.assertLess((expected - mirrored).abs().max(), tolerance)
def test_sh_cuda_ordered_full(self):
    """Spherical harmonics for the degree list [0, ..., 10] agree on CPU and CUDA.

    The test is reported as skipped (not passed) when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        # skipTest makes the runner report a skip; the previous print-and-return
        # made the test silently count as a pass.
        self.skipTest("CUDA is not available")
    with o3.torch_default_dtype(torch.float64):
        degrees = list(range(10 + 1))
        x = torch.randn(10, 3)
        Y_cpu = o3.spherical_harmonics_xyz(degrees, x)
        Y_cuda = o3.spherical_harmonics_xyz(degrees, x.cuda()).cpu()
        self.assertLess((Y_cpu - Y_cuda).abs().max(), 1e-7)
def test_sh_cuda_single(self):
    """Spherical harmonics of each single degree l = 0..10 agree on CPU and CUDA.

    The test is reported as skipped (not passed) when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        # skipTest makes the runner report a skip; the previous print-and-return
        # made the test silently count as a pass.
        self.skipTest("CUDA is not available")
    with o3.torch_default_dtype(torch.float64):
        for degree in range(10 + 1):
            x = torch.randn(10, 3)
            Y_cpu = o3.spherical_harmonics_xyz(degree, x)
            Y_cuda = o3.spherical_harmonics_xyz(degree, x.cuda()).cpu()
            self.assertLess((Y_cpu - Y_cuda).abs().max(), 1e-7)
def test_sh_closure(self):
    """
    Orthonormality of the spherical harmonics, estimated by Monte Carlo.

    integral of Ylm * Yjn over the unit sphere = delta_lj delta_mn.
    The sample mean is multiplied by 4 pi (the integral of 1 over the unit
    sphere) to turn it into an integral estimate.
    """
    with o3.torch_default_dtype(torch.float64):
        degree_pairs = [(l1, l2) for l1 in range(4) for l2 in range(l1, 4)]
        for l1, l2 in degree_pairs:
            samples = torch.randn(200000, 3)
            Y1 = o3.spherical_harmonics_xyz(l1, samples)
            Y2 = o3.spherical_harmonics_xyz(l2, samples)
            # gram[m, n] ~ integral of Y_{l1 m} * Y_{l2 n}
            products = Y1.view(2 * l1 + 1, 1, -1) * Y2.view(1, 2 * l2 + 1, -1)
            gram = products.mean(-1) * (4 * math.pi)
            if l1 == l2:
                target = torch.eye(2 * l1 + 1)
            else:
                target = torch.zeros_like(gram)
            self.assertLess((gram - target).pow(2).max(), 1e-4)
def test_sh_norm(self):
    """Mean over dim 0 of Y_l(x)^2 should be 1 / (4 pi) for l = 0..14.

    Dim 0 is presumably the m index (a [lm, points] layout); the check is an
    RMS error over all degrees and points.
    """
    with o3.torch_default_dtype(torch.float64):
        target = 1 / (4 * math.pi)
        squared_means = torch.stack([
            o3.spherical_harmonics_xyz(degree, torch.randn(10, 3)).pow(2).mean(0)
            for degree in range(15)
        ])
        rms_error = (squared_means - target).pow(2).mean().sqrt()
        self.assertLess(rms_error, 1e-10)
def test_clebsch_gordan_sh_norm(self):
    """sqrt(4 pi) * Q @ Y_lf(x) must have unit norm.

    Q is clebsch_gordan(l_out, l_in, l_f). The check covers every
    l_out, l_in in 0..5 and every l_f allowed by the triangle inequality.
    """
    with o3.torch_default_dtype(torch.float64):
        scale = math.sqrt(4 * math.pi)
        triples = (
            (l_out, l_in, l_f)
            for l_out in range(6)
            for l_in in range(6)
            for l_f in range(abs(l_out - l_in), l_out + l_in + 1)
        )
        for l_out, l_in, l_f in triples:
            Q = o3.clebsch_gordan(l_out, l_in, l_f)
            point = torch.randn(1, 3)
            Y = o3.spherical_harmonics_xyz(l_f, point).view(2 * l_f + 1)
            QY = (scale * Q) @ Y
            self.assertLess(abs(QY.norm() - 1), 1e-10)
def test1(self):
    """Test irr_repr and clebsch_gordan equivariance.

    The kernel built from the rotated points, times D_in on the right, must
    match D_out times the kernel built from the unrotated points.
    """
    def kernel(Q, l_f, points):
        # Contract Q with the harmonics: one (i, j) matrix per point z.
        Y = o3.spherical_harmonics_xyz(l_f, points)
        return torch.einsum("ijk,kz->zij", (Q, Y))

    with torch_default_dtype(torch.float64):
        l_in, l_out = 3, 2
        for l_f in range(abs(l_in - l_out), l_in + l_out + 1):
            r = torch.randn(100, 3)
            Q = o3.clebsch_gordan(l_out, l_in, l_f)
            abc = torch.randn(3)
            D_in = o3.irr_repr(l_in, *abc)
            D_out = o3.irr_repr(l_out, *abc)

            rotated = kernel(Q, l_f, r @ o3.rot(*abc).t())
            W1 = torch.einsum("zij,jk->zik", (rotated, D_in))

            W = kernel(Q, l_f, r)
            W2 = torch.einsum("ij,zjk->zik", (D_out, W))

            self.assertLess((W1 - W2).norm(), 1e-5 * W.norm(), l_f)
def __init__(self, Rs, act, n):
    '''
    Map the signal to the sphere, apply the nonlinearity pointwise, and project it back.

    The signal on the sphere is a quasiregular representation of O3, so a
    pointwise operation can be applied to these representations.

    :param Rs: input representation; after ``rs.simplify`` every entry must have
        the same multiplicity, the l's must be exactly 0, 1, ..., len(Rs) - 1, and
        the parities must be either all equal or alternate as p0 * (-1)^l
    :param act: activation function (called on 1-D tensors here to detect its parity)
    :param n: number of points on the sphere (the higher, the more accurate);
        the stored sample set has 2 * n points (each point and its antipode)
    :raises ValueError: if p0 == -1 and ``act`` is neither even nor odd on the probe grid
    '''
    super().__init__()

    Rs = rs.simplify(Rs)
    mul0, _, p0 = Rs[0]
    # NOTE(review): these are asserts, so the checks disappear under `python -O`.
    # All entries share the multiplicity of the first one.
    assert all(mul0 == mul for mul, _, _ in Rs)
    # Exactly one entry per l, for l = 0 .. len(Rs) - 1, in order.
    assert [l for _, l, _ in Rs] == list(range(len(Rs)))
    # Parities are uniform (all p0) or alternating (p0 * (-1)^l).
    assert all(p == p0 for _, l, p in Rs) or all(p == p0 * (-1)**l for _, l, p in Rs)

    if p0 == +1 or p0 == 0:
        self.Rs_out = Rs
    if p0 == -1:
        # NOTE(review): the uniform and alternating cases both reach this branch
        # and get identical handling -- confirm that is intended.
        # Probe act on [0, 10] and its mirror to decide whether it is even or odd.
        x = torch.linspace(0, 10, 256)
        a1, a2 = act(x), act(-x)
        # NOTE(review): if act is identically 0 on [0, 10], both tests below are
        # `< 0` and fail, so ValueError is raised.
        if (a1 - a2).abs().max() < a1.abs().max() * 1e-10:
            # p_act = 1 (even activation): the output parities are flipped
            self.Rs_out = [(mul, l, -p) for mul, l, p in Rs]
        elif (a1 + a2).abs().max() < a1.abs().max() * 1e-10:
            # p_act = -1 (odd activation): the output parities are unchanged
            self.Rs_out = Rs
        else:
            # p_act = 0: the activation has no definite parity
            raise ValueError("warning! the parity is violated")
    # NOTE(review): for any other p0 value, self.Rs_out is never assigned.

    # Sample points with their antipodes, so the sample set is symmetric under x -> -x.
    # Presumably spherical_harmonics_xyz normalizes x onto the unit sphere -- TODO confirm.
    x = torch.randn(n, 3)
    x = torch.cat([x, -x])
    Y = o3.spherical_harmonics_xyz(list(range(len(Rs))), x)  # [lm, z]
    # Registered as a buffer so it is stored in the module state, not as a parameter.
    self.register_buffer('Y', Y)
    self.act = act