def test_sh_parity():
    """
    Check the parity property Y(-x) == (-1)^l Y(x) for every degree l in 0..11.
    """
    torch.set_default_dtype(torch.float64)
    max_l = 11
    for degree in range(max_l + 1):
        vec = torch.randn(3)
        sign = (-1)**degree
        expected = sign * o3.spherical_harmonics([degree], vec)
        flipped = o3.spherical_harmonics([degree], -vec)
        worst = (expected - flipped).abs().max()
        assert worst < 1e-10
def test_sh_normalization():
    """
    Check the three normalization conventions for every degree l in 0..11:

    - 'integral': the mean of the squared components equals 1 / (4 pi)
    - 'norm': the component vector has unit Euclidean norm
    - 'component': the mean of the squared components equals 1
    """
    torch.set_default_dtype(torch.float64)
    integral_target = 1 / (4 * math.pi)
    for degree in range(12):
        y = o3.spherical_harmonics([degree], torch.randn(3), 'integral')
        assert abs(y.pow(2).mean() - integral_target) < 1e-10

        y = o3.spherical_harmonics([degree], torch.randn(3), 'norm')
        assert abs(y.norm() - 1) < 1e-10

        y = o3.spherical_harmonics([degree], torch.randn(3), 'component')
        assert abs(y.pow(2).mean() - 1) < 1e-10
Example No. 3
0
    def forward(self, data) -> torch.Tensor:
        """Compute one feature vector per graph in ``data``.

        Builds a radius graph (cutoff 1.1) over ``data.pos``, seeds each node
        with the summed spherical harmonics of its incoming edge vectors, then
        runs two tensor-product message-passing rounds (``self.tp1`` and
        ``self.tp2``) and finally sums node features per graph.

        Args:
            data: batch object exposing ``pos`` (node positions, presumably
                ``(num_nodes, 3)`` -- TODO confirm) and ``batch`` (graph index
                of each node).

        Returns:
            Tensor with one row per graph, rows indexed by ``data.batch``.
        """
        num_neighbors = 2  # typical number of neighbors
        num_nodes = 4  # typical number of nodes
        # Explicit row count for every node-level scatter. Without ``dim_size``
        # a torch_scatter-style scatter sizes its output as ``index.max() + 1``,
        # so a trailing node with no incoming edge would be dropped and the
        # per-graph scatter over ``data.batch`` below would get mismatched
        # lengths.
        total_nodes = len(data.pos)

        edge_src, edge_dst = radius_graph(
            data.pos, 1.1,
            data.batch)  # tensors of indices representing the graph
        edge_vec = data.pos[edge_src] - data.pos[edge_dst]
        edge_sh = o3.spherical_harmonics(self.sh,
                                         edge_vec,
                                         normalization='component')

        # For each node, the initial features are the sum of the spherical
        # harmonics of the neighbors, scaled by 1/sqrt(num_neighbors).
        node_features = scatter(edge_sh, edge_dst, dim=0,
                                dim_size=total_nodes).div(num_neighbors**0.5)

        # For each edge, tensor product the features on the source node with
        # the spherical harmonics, then aggregate on the destination node.
        edge_features = self.tp1(node_features[edge_src], edge_sh)
        node_features = scatter(edge_features, edge_dst, dim=0,
                                dim_size=total_nodes).div(num_neighbors**0.5)

        edge_features = self.tp2(node_features[edge_src], edge_sh)
        node_features = scatter(edge_features, edge_dst, dim=0,
                                dim_size=total_nodes).div(num_neighbors**0.5)

        # For each graph, all the node's features are summed
        return scatter(node_features, data.batch, dim=0).div(num_nodes**0.5)
def test_sh_equivariance2():
    """Rotating the input must equal applying the Wigner D matrix to the output.

    Exercises ``o3.rot``, ``o3.rep`` and ``o3.spherical_harmonics`` for
    degrees 0 through 6.

    NOTE(review): another ``test_sh_equivariance2`` is defined later in this
    file; under pytest that later definition shadows this one.
    """
    torch.set_default_dtype(torch.float64)
    Rs = list(range(7))

    angles = o3.rand_angles()
    rotation = o3.rot(*angles)
    wigner = o3.rep(Rs, *angles)

    points = torch.randn(10, 3)

    rotate_first = o3.spherical_harmonics(Rs, points @ rotation.T)
    rotate_after = o3.spherical_harmonics(Rs, points) @ wigner.T

    assert (rotate_first - rotate_after).abs().max() < 1e-10
def test_sh_equivariance2():
    """Rotating the input must equal applying ``irreps.D`` to the output.

    Exercises ``o3.rot``, ``Irreps.D`` and ``o3.spherical_harmonics`` on the
    irreps ``0e + 1o + 2e + 3o + 4e``.
    """
    torch.set_default_dtype(torch.float64)
    irreps = o3.Irreps("0e + 1o + 2e + 3o + 4e")

    angles = o3.rand_angles()
    rotation = o3.rot(*angles)
    wigner = irreps.D(*angles)

    points = torch.randn(10, 3)

    rotate_first = o3.spherical_harmonics(irreps, points @ rotation.T)
    rotate_after = o3.spherical_harmonics(irreps, points) @ wigner.T

    assert (rotate_first - rotate_after).abs().max() < 1e-10
Example No. 6
0
    def forward(self, z, pos, batch=None):
        """Predict one pooled output value per graph from atom types and positions.

        Args:
            z: atom-type tensor, one entry per node; asserted to be 1-D and
                ``torch.long``.
            pos: node positions; edge vectors are ``pos[row] - pos[col]``
                (presumably shape ``(num_nodes, 3)`` -- TODO confirm).
            batch: graph index of each node. Defaults to all zeros, i.e. a
                single graph.

        Returns:
            The per-node scalar ``h`` pooled per graph with
            ``scatter(..., reduce=self.readout)``, multiplied by ``self.scale``
            when that is set.
        """
        assert z.dim() == 1 and z.dtype == torch.long
        batch = torch.zeros_like(z) if batch is None else batch

        h = self.embedding(z)

        # Edges between nodes within ``self.cutoff`` (presumably restricted to
        # the same graph via ``batch`` -- TODO confirm radius_graph semantics).
        edge_index = radius_graph(pos,
                                  r=self.cutoff,
                                  batch=batch,
                                  max_num_neighbors=1000)
        row, col = edge_index
        edge_vec = pos[row] - pos[col]
        # Edge spherical harmonics divided by sqrt(self.num_neighbors),
        # presumably to keep neighbor sums at unit scale -- TODO confirm.
        edge_sh = o3.spherical_harmonics(self.Rs_sh, edge_vec,
                                         'component') / self.num_neighbors**0.5
        edge_len = edge_vec.norm(dim=1)
        edge_weight = self.radial(edge_len)
        # Cosine envelope (cos(pi * r / cutoff) + 1) / 2: equals 1 at r = 0
        # and 0 at r = cutoff.
        edge_c = (pi * edge_len / self.cutoff).cos().add(1).div(2)

        # Every layer except the last: convolution, gate, optional shortcut.
        for conv, act, shortcut in self.layers[:-1]:
            with torch.autograd.profiler.record_function("Layer"):
                if shortcut:
                    s = shortcut(h)

                h = conv(h, edge_index, edge_weight, edge_c,
                         edge_sh)  # convolution
                h = act(h)  # gate non linearity

                if shortcut:
                    # Residual mix. Where m == 1 this is sqrt(1/2) * (s + h);
                    # where m == 0 it is sqrt(1/2) * s + h (assumes output_mask
                    # is 0/1 and s is zero outside the mask -- TODO confirm).
                    m = shortcut.output_mask
                    h = 0.5**0.5 * s + (1 + (0.5**0.5 - 1) * m) * h

        # Final layer: called without an activation or shortcut.
        with torch.autograd.profiler.record_function("Layer"):
            h = self.layers[-1](h, edge_index, edge_weight, edge_c, edge_sh)

        # Collapse the output channels into one scalar per node. Every output
        # irrep must be a single scalar (mul == 1, l == 0). Even ones (p == 1)
        # are added as-is; odd ones (p == -1) are squared and halved so the
        # sum stays even. Channels with any other p are skipped.
        # NOTE(review): if self.Rs_out is empty, s stays the int 0 and
        # s.view(...) raises AttributeError.
        s = 0
        for i, (mul, l, p) in enumerate(self.Rs_out):
            assert mul == 1 and l == 0
            if p == 1:
                s += h[:, i]
            if p == -1:
                s += h[:, i].pow(2).mul(0.5)  # odd^2 = even
        h = s.view(-1, 1)

        # Optional affine rescale: multiply by self.std, then add self.mean.
        if self.mean is not None and self.std is not None:
            h = h * self.std + self.mean

        # Optional per-atom-type offset added to each node's value.
        if self.atomref is not None:
            h = h + self.atomref(z)

        # Pool node values per graph using the configured reduction.
        out = scatter(h, batch, dim=0, reduce=self.readout)

        if self.scale is not None:
            out = self.scale * out

        return out
def test_sh_closure():
    """
    Monte-Carlo check of orthonormality on the unit sphere:

    integral of Ylm * Yjn = delta_lj delta_mn

    The surface integral is estimated as 4 pi (the area of the unit sphere)
    times the sample mean over random directions.
    """
    torch.set_default_dtype(torch.float64)
    samples = torch.randn(300_000, 3)
    harmonics = [o3.spherical_harmonics([degree], samples) for degree in range(4)]
    for l1, y1 in enumerate(harmonics):
        for l2, y2 in enumerate(harmonics):
            outer = y1.unsqueeze(2) * y2.unsqueeze(1)
            gram = outer.mean(0) * 4 * math.pi
            if l1 != l2:
                # Different degrees must be orthogonal.
                assert gram.abs().max() < 0.01
                continue
            identity = torch.eye(2 * l1 + 1)
            assert (gram - identity).abs().max() < 0.01
def test_all_sh():
    """Evaluate every degree 0..11, both as one list and one degree at a time.

    Beyond checking that the calls succeed, pin the output layout: one row per
    input point, with ``2l + 1`` components per requested degree concatenated
    along the last axis.
    """
    ls = list(range(11 + 1))
    pos = torch.randn(4, 3)
    combined = o3.spherical_harmonics(ls, pos)
    assert combined.shape == (len(pos), sum(2 * l + 1 for l in ls))
    for l in ls:
        single = o3.spherical_harmonics(l, pos)
        assert single.shape == (len(pos), 2 * l + 1)
def test_zeros():
    """Under 'norm', a zero input vector gives [1, 0, 0, 0] for degrees 0 and 1."""
    expected = torch.tensor([[1, 0, 0, 0.0]])
    actual = o3.spherical_harmonics([0, 1], torch.zeros(1, 3),
                                    normalization='norm')
    assert torch.allclose(actual, expected)