예제 #1
0
 def forward(self, batch):
     """Apply the embedding, one plain block, then masked residual blocks.

     ``self.layers[0]`` is called on the raw features. ``self.layers[1]`` is
     unpacked as a (convolution-like layer, activation) pair and applied
     without a skip connection. Every later pair produces an update that is
     multiplied by the mask and added onto the running features.

     The spherical harmonics are only recomputed when a layer asks for a
     different ``set_of_l_filters`` than the one currently cached.

     :param batch: input batch, unpacked by ``constants``
     :return: the features after the last block
     """
     features, _, mask, diff_geo, radii = constants(batch)
     features = self.layers[0](features)

     # First block: no mask multiplication and no residual connection.
     first_conv, first_act = self.layers[1]
     current_ls = first_conv.set_of_l_filters
     harmonics = spherical_harmonics_xyz(current_ls, diff_geo)
     features = first_act(
         first_conv(features.div(self.avg_n_atoms**0.5),
                    diff_geo,
                    mask,
                    y=harmonics,
                    radii=radii,
                    custom_backward=CUSTOM_BACKWARD))

     # Remaining blocks: masked residual updates.
     for conv, activation in self.layers[2:]:
         if conv.set_of_l_filters != current_ls:
             current_ls = conv.set_of_l_filters
             harmonics = spherical_harmonics_xyz(current_ls, diff_geo)
         update = activation(
             conv(features.div(self.avg_n_atoms**0.5),
                  diff_geo,
                  mask,
                  y=harmonics,
                  radii=radii,
                  custom_backward=CUSTOM_BACKWARD))
         features = features + update * mask.unsqueeze(-1)
     return features
예제 #2
0
 def test_sh_parity(self):
     """
     Parity relation of the spherical harmonics, checked for l = 0..7:
     (-1)^l Y(x) = Y(-x)
     """
     max_l = 7
     with o3.torch_default_dtype(torch.float64):
         for degree in range(max_l + 1):
             point = torch.randn(3)
             expected = (-1)**degree * o3.spherical_harmonics_xyz(
                 degree, point)
             mirrored = o3.spherical_harmonics_xyz(degree, -point)
             # Relative tolerance, scaled by the largest expected component.
             worst = (expected - mirrored).abs().max()
             self.assertLess(worst, 1e-10 * expected.abs().max())
예제 #3
0
 def test_sh_cuda_ordered_full(self):
     """Compare CPU and CUDA spherical harmonics for the list l = 0..10.

     The same random points are evaluated on the CPU and on the GPU and the
     largest absolute difference must stay below 1e-7.

     When no CUDA device is available the test is reported as *skipped*
     through ``self.skipTest`` instead of printing a message and passing,
     so the test report does not claim a check that never ran.
     """
     if not torch.cuda.is_available():
         # skipTest raises, so nothing below runs on CPU-only machines.
         self.skipTest(
             "Cuda is not available! test_sh_cuda_ordered_full skipped!")

     with o3.torch_default_dtype(torch.float64):
         l = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
         x = torch.randn(10, 3)
         x_cuda = x.cuda()
         Y1 = o3.spherical_harmonics_xyz(l, x)
         Y2 = o3.spherical_harmonics_xyz(l, x_cuda).cpu()
         self.assertLess((Y1 - Y2).abs().max(), 1e-7)
예제 #4
0
 def test_sh_cuda_single(self):
     """Compare CPU and CUDA spherical harmonics one degree at a time.

     For each l in 0..10 the same random points are evaluated on the CPU and
     on the GPU and the largest absolute difference must stay below 1e-7.

     When no CUDA device is available the test is reported as *skipped*
     through ``self.skipTest`` instead of printing a message and passing,
     so the test report does not claim a check that never ran.
     """
     if not torch.cuda.is_available():
         # skipTest raises, so nothing below runs on CPU-only machines.
         self.skipTest(
             "Cuda is not available! test_sh_cuda_single skipped!")

     with o3.torch_default_dtype(torch.float64):
         for l in range(10 + 1):
             x = torch.randn(10, 3)
             x_cuda = x.cuda()
             Y1 = o3.spherical_harmonics_xyz(l, x)
             Y2 = o3.spherical_harmonics_xyz(l, x_cuda).cpu()
             self.assertLess((Y1 - Y2).abs().max(), 1e-7)
예제 #5
0
 def test_sh_closure(self):
     """
     Sampled check of orthonormality for all pairs l1 <= l2 <= 3:
     integral of Ylm * Yjn = delta_lj delta_mn
     integral of 1 over the unit sphere = 4 pi

     The integral is estimated as the mean over random points times 4 pi.
     """
     n_samples = 200000
     max_l = 3
     with o3.torch_default_dtype(torch.float64):
         for l1 in range(0, max_l + 1):
             for l2 in range(l1, max_l + 1):
                 points = torch.randn(n_samples, 3)
                 left = o3.spherical_harmonics_xyz(l1, points)
                 right = o3.spherical_harmonics_xyz(l2, points)
                 # Outer product over the two m indices, averaged over
                 # the sample axis.
                 products = (left.view(2 * l1 + 1, 1, -1) *
                             right.view(1, 2 * l2 + 1, -1))
                 gram = products.mean(-1) * (4 * math.pi)
                 # Same degree: expect the identity; otherwise expect zero.
                 if l1 == l2:
                     residual = gram - torch.eye(2 * l1 + 1)
                 else:
                     residual = gram
                 self.assertLess(residual.pow(2).max(), 1e-4)
예제 #6
0
 def test_sh_norm(self):
     """Check the normalisation of the harmonics for l = 0..14.

     For each degree the squared values are averaged over the first axis
     and compared with 1 / (4 pi); the root-mean-square deviation over all
     degrees must stay below 1e-10.
     """
     with o3.torch_default_dtype(torch.float64):
         mean_squares = []
         for degree in range(15):
             values = o3.spherical_harmonics_xyz(degree, torch.randn(10, 3))
             mean_squares.append(values.pow(2).mean(0))
         deviation = torch.stack(mean_squares) - 1 / (4 * math.pi)
         rms = deviation.pow(2).mean().sqrt()
         self.assertLess(rms, 1e-10)
예제 #7
0
 def test_clebsch_gordan_sh_norm(self):
     """Check that sqrt(4 pi) * Q @ Y has unit norm.

     Runs over every l_out, l_in in 0..5 and every l_f in
     |l_out - l_in| .. l_out + l_in, using one random point per case.
     """
     scale = math.sqrt(4 * math.pi)
     with o3.torch_default_dtype(torch.float64):
         for l_out in range(6):
             for l_in in range(6):
                 lowest = abs(l_out - l_in)
                 highest = l_out + l_in
                 for l_f in range(lowest, highest + 1):
                     Q = o3.clebsch_gordan(l_out, l_in, l_f)
                     point = torch.randn(1, 3)
                     Y = o3.spherical_harmonics_xyz(l_f, point)
                     Y = Y.view(2 * l_f + 1)
                     QY = scale * Q @ Y
                     self.assertLess(abs(QY.norm() - 1), 1e-10)
예제 #8
0
    def test1(self):
        """Test irr_repr and clebsch_gordan equivariance.

        For l_in = 3, l_out = 2 and every allowed l_f, the kernel built from
        rotated points and multiplied by D_in on the right must match the
        kernel built from the original points and multiplied by D_out on
        the left.
        """
        l_in, l_out = 3, 2

        with torch_default_dtype(torch.float64):
            for l_f in range(abs(l_in - l_out), l_in + l_out + 1):
                points = torch.randn(100, 3)
                Q = o3.clebsch_gordan(l_out, l_in, l_f)

                def kernel(pts):
                    # Contract Q with the harmonics of pts into "zij".
                    harmonics = o3.spherical_harmonics_xyz(l_f, pts)
                    return torch.einsum("ijk,kz->zij", (Q, harmonics))

                angles = torch.randn(3)
                D_in = o3.irr_repr(l_in, *angles)
                D_out = o3.irr_repr(l_out, *angles)

                # Rotate the points, then apply D_in on the input index.
                rotated_kernel = kernel(points @ o3.rot(*angles).t())
                lhs = torch.einsum("zij,jk->zik", (rotated_kernel, D_in))

                # Keep the points fixed and apply D_out on the output index.
                plain_kernel = kernel(points)
                rhs = torch.einsum("ij,zjk->zik", (D_out, plain_kernel))

                # l_f is passed as the failure message.
                self.assertLess((lhs - rhs).norm(),
                                1e-5 * plain_kernel.norm(), l_f)
예제 #9
0
    def __init__(self, Rs, act, n):
        '''
        Map to the sphere, apply the non-linearity pointwise and project back.

        The signal on the sphere is a quasiregular representation of O3,
        and a pointwise operation can be applied to such representations.

        The simplified ``Rs`` must use one common multiplicity, list the
        degrees 0, 1, 2, ... in order, and have parities that are either all
        equal to the first one or equal to it times (-1)^l.

        :param Rs: input representation
        :param act: activation function
        :param n: number of points on the sphere (the higher the more
            accurate); the negated points are added too, giving 2n samples
        :raises ValueError: if the first parity is -1 and ``act`` is
            neither even nor odd on the probe points
        '''
        super().__init__()

        Rs = rs.simplify(Rs)
        mul0, _, p0 = Rs[0]
        assert all(mul0 == mul for mul, _, _ in Rs)
        assert [l for _, l, _ in Rs] == list(range(len(Rs)))
        assert all(p == p0 for _, _, p in Rs) or all(
            p == p0 * (-1)**l for _, l, p in Rs)

        if p0 == -1:
            # Probe act on [0, 10] to find out whether it is even or odd.
            probe = torch.linspace(0, 10, 256)
            at_pos, at_neg = act(probe), act(-probe)
            tolerance = at_pos.abs().max() * 1e-10
            if (at_pos - at_neg).abs().max() < tolerance:
                # p_act = 1: even activation, every output parity flips
                self.Rs_out = [(mul, l, -p) for mul, l, p in Rs]
            elif (at_pos + at_neg).abs().max() < tolerance:
                # p_act = -1: odd activation, parities are kept
                self.Rs_out = Rs
            else:
                # p_act = 0: neither even nor odd
                raise ValueError("warning! the parity is violated")
        elif p0 == +1 or p0 == 0:
            self.Rs_out = Rs

        # Random sample points together with their negations.
        half = torch.randn(n, 3)
        points = torch.cat([half, -half])
        degrees = list(range(len(Rs)))
        harmonics = o3.spherical_harmonics_xyz(degrees, points)  # [lm, z]
        self.register_buffer('Y', harmonics)
        self.act = act