def test_weird_irreps():
    """Exercise unusual but legal irreps specifications and reject invalid ones."""
    # The irreps may be given as a plain string.
    o3.spherical_harmonics("0e + 1o", torch.randn(1, 3), False)

    # Unusual multiplicities: the output width must still equal irreps.dim.
    mixed = o3.Irreps("1x0e + 4x1o + 3x2e")
    values = o3.spherical_harmonics(mixed, torch.randn(7, 3), True)
    assert values.shape[-1] == mixed.dim

    # Wrong parity: with a vector ("1o") input, l = 1 cannot be even.
    with pytest.raises(ValueError):
        o3.SphericalHarmonics(
            irreps_out="1x0e + 4x1e + 3x2e",
            normalize=True,
            normalization='integral',
            irreps_in="1o",
        )

    # A pseudovector ("1e") input makes the all-even output parities valid.
    _ = o3.SphericalHarmonics(
        irreps_in="1e",
        irreps_out="1x0e + 4x1e + 3x2e",
        normalize=True,
    )

    # A compound input irreps such as "1e + 3o" is invalid.
    with pytest.raises(ValueError):
        _ = o3.SphericalHarmonics(
            irreps_in="1e + 3o",
            irreps_out="1x0e + 4x1e + 3x2e",
            normalize=True,
        )
def test_parity(float_tolerance, l):
    r"""
    Spherical harmonics of degree l have parity (-1)^l: (-1)^l Y(x) = Y(-x).
    """
    point = torch.randn(3)
    sign = (-1)**l
    expected = sign * o3.spherical_harmonics(l, point, False)
    flipped = o3.spherical_harmonics(l, -point, False)
    assert (expected - flipped).abs().max() < float_tolerance
def test_equivariance(float_tolerance):
    """Rotating the input equals applying the Wigner D-matrix to the output (l <= 5)."""
    irreps = o3.Irreps.spherical_harmonics(5)
    points = torch.randn(2, 3)
    angles = o3.rand_angles()

    rotation = o3.angles_to_matrix(*angles)
    rotate_first = o3.spherical_harmonics(irreps, points @ rotation.T, False)

    wigner = irreps.D_from_angles(*angles)
    rotate_after = o3.spherical_harmonics(irreps, points, False) @ wigner.T

    assert (rotate_first - rotate_after).abs().max() < 10 * float_tolerance
Ejemplo n.º 4
0
    def forward(self, data) -> torch.Tensor:
        """Evaluate the network on a batch of point clouds.

        Builds a radius graph (r = 1.1) over ``data.pos``. The spherical
        harmonics of the unnormalized edge vectors serve as edge attributes.
        Two tensor-product / aggregation rounds follow, and finally the node
        features of each graph in ``data.batch`` are summed.

        Parameters
        ----------
        data
            object with ``pos`` (node positions) and ``batch`` (graph index of
            each node) attributes.

        Returns
        -------
        torch.Tensor
            one row per graph, divided by ``sqrt(num_nodes)``.
        """
        num_neighbors = 2  # typical number of neighbors
        num_nodes = 4  # typical number of nodes
        # Number of nodes in the whole batch. It is passed as ``dim_size`` so
        # the per-node scatters always return one row per node, even when the
        # highest-indexed nodes have no incoming edge. Otherwise the final
        # scatter over ``data.batch`` would see too few rows.
        num_batch_nodes = data.pos.shape[0]

        edge_src, edge_dst = radius_graph(
            x=data.pos, r=1.1,
            batch=data.batch)  # tensors of indices representing the graph
        edge_vec = data.pos[edge_src] - data.pos[edge_dst]
        edge_sh = o3.spherical_harmonics(
            l=self.irreps_sh,
            x=edge_vec,
            normalize=False,  # here we don't normalize otherwise it would not be a polynomial
            normalization='component')

        # For each node, the initial features are the sum of the spherical harmonics of the neighbors
        node_features = scatter(edge_sh, edge_dst, dim=0,
                                dim_size=num_batch_nodes).div(num_neighbors**0.5)

        # For each edge, tensor product the features on the source node with the spherical harmonics
        edge_features = self.tp1(node_features[edge_src], edge_sh)
        node_features = scatter(edge_features, edge_dst, dim=0,
                                dim_size=num_batch_nodes).div(num_neighbors**0.5)

        edge_features = self.tp2(node_features[edge_src], edge_sh)
        node_features = scatter(edge_features, edge_dst, dim=0,
                                dim_size=num_batch_nodes).div(num_neighbors**0.5)

        # For each graph, all the node's features are summed
        return scatter(node_features, data.batch, dim=0).div(num_nodes**0.5)
Ejemplo n.º 5
0
    def forward(self, data: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Run message passing on one (possibly batched) graph.

        ``self.preprocess`` splits ``data`` into the batch vector, node inputs,
        the edge list and the edge vectors. Edge attributes are the spherical
        harmonics of degrees ``0..self.lmax`` of the normalized edge vectors.
        Edge scalars are a radial embedding of the edge length.

        Returns the per-node outputs. When ``self.pool_nodes`` is set, it
        instead returns those outputs scattered over ``batch`` (one row per
        graph) and divided by ``sqrt(self.num_nodes)``.
        """
        batch, node_inputs, edge_src, edge_dst, edge_vec = self.preprocess(data)
        del data  # drop the reference to the input container

        degrees = range(self.lmax + 1)
        edge_attr = o3.spherical_harmonics(degrees, edge_vec, True, normalization='component')

        # Radial embedding of the edge length. The smooth_finite basis with
        # cutoff=True already goes to zero at max_radius, so no extra cutoff
        # is applied.
        radial = soft_one_hot_linspace(
            edge_vec.norm(dim=1),
            0.0,
            self.max_radius,
            self.number_of_basis,
            basis='smooth_finite',
            cutoff=True,
        )
        edge_length_embedding = radial.mul(self.number_of_basis**0.5)

        # This model does not use node attributes: every node gets a constant 1.
        node_attr = node_inputs.new_ones(node_inputs.shape[0], 1)

        node_outputs = self.mp(node_inputs, node_attr, edge_src, edge_dst, edge_attr, edge_length_embedding)

        if not self.pool_nodes:
            return node_outputs
        num_graphs = int(batch.max()) + 1
        return scatter(node_outputs, batch, num_graphs).div(self.num_nodes**0.5)
Ejemplo n.º 6
0
    def forward(self, data) -> torch.Tensor:
        """Evaluate the network on a batch of typed point clouds.

        Builds a radius graph (r = 10.0) over ``data.pos``. Edge attributes are
        the product (``self.mul``) of a one-hot encoding of the (source, target)
        type pair and the spherical harmonics of the unnormalized edge vector.
        Two tensor-product / aggregation rounds follow, and the node features
        of each graph in ``data.batch`` are summed.

        Parameters
        ----------
        data
            object with ``pos`` (node positions), ``z`` (integer node types in
            ``[0, self.num_z)``) and ``batch`` (graph index of each node).

        Returns
        -------
        torch.Tensor
            one row per graph, divided by ``sqrt(num_nodes)``.
        """
        num_neighbors = 3  # typical number of neighbors
        num_nodes = 4  # typical number of nodes
        num_z = self.num_z  # number of atom types
        # Total number of nodes in the batch; used as ``dim_size`` so that each
        # per-node scatter returns exactly one row per node, even when the last
        # nodes have no incoming edge.
        num_batch_nodes = data.pos.shape[0]

        # graph
        edge_src, edge_dst = radius_graph(data.pos, 10.0, data.batch)

        # spherical harmonics
        edge_vec = data.pos[edge_src] - data.pos[edge_dst]
        edge_sh = o3.spherical_harmonics(self.irreps_sh, edge_vec, normalize=False, normalization='component')

        # edge types
        edge_zz = num_z * data.z[edge_src] + data.z[edge_dst]  # from 0 to num_z^2 - 1
        # scaling by num_z gives the one-hot vector (num_z^2 entries) a mean square of 1
        edge_zz = torch.nn.functional.one_hot(edge_zz, num_z**2).mul(num_z)
        edge_zz = edge_zz.to(edge_sh.dtype)

        # edge attributes
        edge_attr = self.mul(edge_zz, edge_sh)

        # For each node, the initial features are the sum of the spherical harmonics of the neighbors
        node_features = scatter(edge_sh, edge_dst, dim=0,
                                dim_size=num_batch_nodes).div(num_neighbors**0.5)

        # For each edge, tensor product the features on the source node with the spherical harmonics
        edge_features = self.tp1(node_features[edge_src], edge_attr)
        node_features = scatter(edge_features, edge_dst, dim=0,
                                dim_size=num_batch_nodes).div(num_neighbors**0.5)

        edge_features = self.tp2(node_features[edge_src], edge_attr)
        node_features = scatter(edge_features, edge_dst, dim=0,
                                dim_size=num_batch_nodes).div(num_neighbors**0.5)

        # For each graph, all the node's features are summed
        return scatter(node_features, data.batch, dim=0).div(num_nodes**0.5)
Ejemplo n.º 7
0
    def signal_xyz(self, signal, r):
        r"""Evaluate the signal at the directions of the given points.

        .. math::

            f(\vec x / \|\vec x\|)

        Only the direction of each point matters: the spherical harmonics are
        computed with ``normalize=True``.

        Parameters
        ----------
        signal : `torch.Tensor`
            tensor of shape ``(*A, self.dim)``

        r : `torch.Tensor`
            tensor of shape ``(*B, 3)``

        Returns
        -------
        `torch.Tensor`
            tensor of shape ``(*A, *B)``

        Examples
        --------
        >>> s = SphericalTensor(3, 1, -1)
        >>> s.signal_xyz(s.randn(2, 1, 3, -1), torch.randn(2, 4, 3)).shape
        torch.Size([2, 1, 3, 2, 4])
        """
        basis = o3.spherical_harmonics(self, r, normalize=True)
        num_coeffs = (self.lmax + 1)**2

        # Flatten both batch shapes, contract over the coefficient axis,
        # then restore (*A, *B).
        flat_basis = basis.reshape(-1, num_coeffs)
        flat_signal = signal.reshape(-1, num_coeffs)
        values = torch.einsum('bi,ai->ab', flat_basis, flat_signal)

        out_shape = signal.shape[:-1] + r.shape[:-1]
        return values.reshape(out_shape)
def test_zeros():
    """At the origin (unnormalized input, 'norm' normalization) only the l = 0 entry is non-zero."""
    origin = torch.zeros(1, 3)
    result = o3.spherical_harmonics([0, 1], origin, False, normalization='norm')
    expected = torch.tensor([[1, 0, 0, 0.0]])
    assert torch.allclose(result, expected)
Ejemplo n.º 9
0
def test():
    """Smoke test: run a ``Convolution`` from 5 input points to 2 output points."""
    from torch_cluster import radius
    from e3nn.math import soft_one_hot_linspace

    conv = Convolution(
        irreps_node_input='0e + 1e',
        irreps_node_output='0e + 1e',
        irreps_node_attr_input='2x0e',
        irreps_node_attr_output='3x0e',
        irreps_edge_attr='0e + 1e',
        num_edge_scalar_attr=4,
        radial_layers=1,
        radial_neurons=50,
        num_neighbors=3.0,
    )

    pos_in = torch.randn(5, 3)
    pos_out = torch.randn(2, 3)

    node_input = torch.randn(5, 4)  # '0e + 1e' -> 1 + 3 components
    node_attr_input = torch.randn(5, 2)  # '2x0e'
    node_attr_output = torch.randn(2, 3)  # '3x0e'

    cutoff = 2.0
    edge_src, edge_dst = radius(pos_out, pos_in, r=cutoff)
    edge_vec = pos_in[edge_src] - pos_out[edge_dst]
    edge_attr = o3.spherical_harmonics([0, 1], edge_vec, True)
    edge_scalar_attr = soft_one_hot_linspace(
        x=edge_vec.norm(dim=1),
        start=0.0,
        end=cutoff,
        number=4,
        basis='smooth_finite',
        cutoff=True,
    )

    conv(node_input, node_attr_input, node_attr_output,
         edge_src, edge_dst, edge_attr, edge_scalar_attr)
Ejemplo n.º 10
0
def test_sh_same(float_tolerance):
    """xyz-based and (alpha, beta)-based spherical harmonics agree for l = 0..4."""
    for degree in range(5):
        points = torch.randn(10, 3)
        alpha, beta = o3.xyz_to_angles(points)

        from_xyz = o3.spherical_harmonics(degree, points, True)
        from_angles = o3.spherical_harmonics_alpha_beta(degree, alpha, beta)
        assert (from_xyz - from_angles).abs().max() < float_tolerance
def test_module(normalization, normalize):
    """The SphericalHarmonics module matches the functional API, scripts, and is equivariant."""
    irreps = o3.Irreps("0e + 1o + 3o")
    module = o3.SphericalHarmonics(irreps, normalize, normalization)
    scripted = assert_auto_jitable(module)

    points = torch.randn(11, 3)
    from_module = scripted(points)
    reference = o3.spherical_harmonics(irreps, points, normalize, normalization)
    assert torch.allclose(from_module, reference)

    assert_equivariant(module)
Ejemplo n.º 12
0
    def forward(self, data: Union[Data, Dict[str,
                                             torch.Tensor]]) -> torch.Tensor:
        """Evaluate the network.

        Parameters
        ----------
        data : `torch_geometric.data.Data` or dict
            data object containing
            - ``pos`` the position of the nodes (atoms)
            - ``x`` the input features of the nodes, optional
            - ``z`` the attributes of the nodes, for instance the atom type, optional
            - ``batch`` the graph to which each node belongs, optional

        Returns
        -------
        `torch.Tensor`
            the node features produced by the last layer or, when
            ``self.reduce_output`` is true, those features scattered over
            ``batch`` (one row per graph) and divided by ``sqrt(self.num_nodes)``.
        """
        if 'batch' in data:
            batch = data['batch']
        else:
            # No batch vector given: every node is assigned to graph 0.
            batch = data['pos'].new_zeros(data['pos'].shape[0],
                                          dtype=torch.long)

        # Connect nodes within self.max_radius; row 0 is used as the edge
        # source, row 1 as the destination.
        edge_index = radius_graph(data['pos'], self.max_radius, batch)
        edge_src = edge_index[0]
        edge_dst = edge_index[1]
        edge_vec = data['pos'][edge_src] - data['pos'][edge_dst]
        # Angular part: spherical harmonics of the normalized edge vectors.
        edge_sh = o3.spherical_harmonics(self.irreps_edge_attr,
                                         edge_vec,
                                         True,
                                         normalization='component')
        edge_length = edge_vec.norm(dim=1)
        # Radial part: Gaussian basis on [0, max_radius] without a built-in
        # cutoff, rescaled by sqrt(number_of_basis).
        edge_length_embedded = soft_one_hot_linspace(
            x=edge_length,
            start=0.0,
            end=self.max_radius,
            number=self.number_of_basis,
            basis='gaussian',
            cutoff=False).mul(self.number_of_basis**0.5)
        # Since the radial basis has cutoff=False, a smooth envelope of
        # edge_length / max_radius is applied to the edge attributes instead
        # (presumably smooth_cutoff decays to 0 at 1 — confirm).
        edge_attr = smooth_cutoff(
            edge_length / self.max_radius)[:, None] * edge_sh

        if self.input_has_node_in and 'x' in data:
            assert self.irreps_in is not None
            x = data['x']
        else:
            # No input features: every node starts from a constant scalar 1.
            # NOTE(review): a model configured with input features but called
            # without 'x' lands here and presumably trips this assert.
            assert self.irreps_in is None
            x = data['pos'].new_ones((data['pos'].shape[0], 1))

        if self.input_has_node_attr and 'z' in data:
            z = data['z']
        else:
            # No node attributes: only valid when the model expects a single scalar "0e".
            assert self.irreps_node_attr == o3.Irreps("0e")
            z = data['pos'].new_ones((data['pos'].shape[0], 1))

        # Each layer receives node features, node attributes, the graph and edge data.
        for lay in self.layers:
            x = lay(x, z, edge_src, edge_dst, edge_attr, edge_length_embedded)

        if self.reduce_output:
            # One row per graph, divided by sqrt(self.num_nodes).
            return scatter(x, batch, dim=0).div(self.num_nodes**0.5)
        else:
            return x
def test_normalization(float_tolerance, l):
    """Check the three normalization conventions on random directions.

    - 'integral': mean of Y^2 over the components is 1 / (4 pi)
    - 'norm': Y has unit Euclidean norm
    - 'component': mean of Y^2 over the components is 1
    """
    def sh(normalization):
        # A fresh random direction is drawn for each call.
        return o3.spherical_harmonics(l,
                                      torch.randn(3),
                                      normalize=True,
                                      normalization=normalization)

    mean_square = sh('integral').pow(2).mean()
    assert abs(mean_square - 1 / (4 * math.pi)) < float_tolerance

    length = sh('norm').norm()
    assert abs(length - 1) < float_tolerance

    mean_square = sh('component').pow(2).mean()
    assert abs(mean_square - 1) < float_tolerance
Ejemplo n.º 14
0
    def sum_of_diracs(self, positions: torch.Tensor,
                      values: torch.Tensor) -> torch.Tensor:
        r"""Sum (almost-) dirac deltas

        .. math::

            f(x) = \sum_i v_i \delta^L(\vec r_i)

        where :math:`\delta^L` is the approximation of a dirac delta.

        Parameters
        ----------
        positions : `torch.Tensor`
            :math:`\vec r_i` tensor of shape ``(..., N, 3)``

        values : `torch.Tensor`
            :math:`v_i` tensor of shape ``(..., N)``

        Returns
        -------
        `torch.Tensor`
            tensor of shape ``(..., self.dim)``, on the same device and with
            the same dtype as ``positions``.

        Examples
        --------
        >>> s = SphericalTensor(7, 1, -1)
        >>> pos = torch.tensor([
        ...     [1.0, 0.0, 0.0],
        ...     [0.0, 1.0, 0.0],
        ... ])
        >>> val = torch.tensor([
        ...     -1.0,
        ...     1.0,
        ... ])
        >>> x = s.sum_of_diracs(pos, val)
        >>> s.signal_xyz(x, torch.eye(3)).mul(10.0).round()
        tensor([-10.,  10.,  -0.])

        >>> s.sum_of_diracs(torch.empty(1, 0, 2, 3), torch.empty(2, 0, 1)).shape
        torch.Size([2, 0, 64])

        >>> s.sum_of_diracs(torch.randn(1, 3, 2, 3), torch.randn(2, 1, 1)).shape
        torch.Size([2, 3, 64])
        """
        # Broadcast the batch dimensions of positions and values together.
        # values gets a trailing axis so it lines up with the xyz axis, which
        # is then dropped again.
        positions, values = torch.broadcast_tensors(positions, values[...,
                                                                      None])
        values = values[..., 0]

        if positions.numel() == 0:
            # Allocate with the inputs' dtype/device; plain torch.zeros would
            # use the global defaults (e.g. CPU float32 for CUDA float64 input).
            return positions.new_zeros(values.shape[:-1] + (self.dim, ))

        y = o3.spherical_harmonics(self, positions, True)  # [..., N, dim]
        v = values[..., None]

        return 4 * pi / (self.lmax + 1)**2 * (y * v).sum(-2)
Ejemplo n.º 15
0
    def forward(self, data) -> torch.Tensor:
        """Predict one value per graph from atom types and positions.

        ``data`` is indexed with the keys ``'z'`` (atom types), ``'pos'``
        (positions) and ``'batch'`` (graph index of each atom). The per-atom
        outputs, plus ``self.atomref[z]`` when set, are scattered over
        ``batch`` into one row per graph.
        """
        node_atom = data['z']
        node_pos = data['pos']
        batch = data['batch']

        # The graph: atoms within self.max_radius, at most 1000 neighbors each
        edge_src, edge_dst = radius_graph(node_pos,
                                          r=self.max_radius,
                                          batch=batch,
                                          max_num_neighbors=1000)

        # Edge attributes: spherical harmonics l = 0..sh_lmax of the normalized edge vectors
        edge_vec = node_pos[edge_src] - node_pos[edge_dst]
        edge_sh = o3.spherical_harmonics(l=range(self.sh_lmax + 1),
                                         x=edge_vec,
                                         normalize=True,
                                         normalization='component')

        # Edge length embedding: smooth_finite basis with cutoff, rescaled by sqrt(num_basis)
        edge_length = edge_vec.norm(dim=1)
        edge_length_embedding = soft_one_hot_linspace(
            edge_length,
            0.0,
            self.max_radius,
            self.num_basis,
            basis='smooth_finite',
            cutoff=True,
        ).mul(self.num_basis**0.5)

        # Every atom starts from the same constant input feature 1.
        node_input = node_pos.new_ones(node_pos.shape[0], 1)

        # Map atom types 1, 6, 7, 8, 9 (presumably atomic numbers of H, C, N,
        # O, F) to classes 0..4, then one-hot encode scaled by sqrt(5).
        # NOTE(review): other values map to -1 (or index out of range for
        # z >= 10); confirm the data only contains these elements.
        node_attr = node_atom.new_tensor([-1, 0, -1, -1, -1, -1, 1, 2, 3,
                                          4])[node_atom]
        node_attr = torch.nn.functional.one_hot(node_attr, 5).mul(5**0.5)

        node_outputs = self.mp(node_features=node_input,
                               node_attr=node_attr,
                               edge_src=edge_src,
                               edge_dst=edge_dst,
                               edge_attr=edge_sh,
                               edge_scalars=edge_length_embedding)

        # Channel 0 plus half the square of channel 1 (presumably channel 1 is
        # a parity-odd scalar, made even by squaring — confirm).
        node_outputs = node_outputs[:, 0] + node_outputs[:, 1].pow(2).mul(0.5)
        node_outputs = node_outputs.view(-1, 1)

        node_outputs = node_outputs.div(self.num_nodes**0.5)

        # Optional per-atom-type reference values.
        if self.atomref is not None:
            node_outputs = node_outputs + self.atomref[node_atom]
        # for target=7, MAE of 75eV

        # Aggregate the atom contributions of each graph.
        outputs = scatter(node_outputs, batch, dim=0)

        return outputs
Ejemplo n.º 16
0
def main():
    """Plot the kernels C(l_out, l_in, l) . Y_l on the sphere and save them as a PNG.

    For every l in ``|l_out - l_in| .. l_out + l_in``, the Clebsch-Gordan
    coefficients are contracted with the spherical harmonics sampled on a
    SOFT grid. Every (i, j) entry is drawn as a sphere coloured with the 'bwr'
    colormap. The figure is written to ``kernels{l_in}to{l_out}.png``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--l_in", type=int, required=True)
    parser.add_argument("--l_out", type=int, required=True)
    parser.add_argument("--n", type=int, default=30, help="size of the SOFT grid")
    parser.add_argument("--dpi", type=float, default=100)
    parser.add_argument("--sep", type=float, default=0.5, help="space between matrices")

    args = parser.parse_args()

    torch.set_default_dtype(torch.float64)
    x, y, z, alpha, beta = spherical_surface(args.n)

    l_min = abs(args.l_out - args.l_in)
    l_max = args.l_out + args.l_in
    kernels = [
        torch.einsum(
            "ijk,k...->ij...",
            (o3.clebsch_gordan(args.l_out, args.l_in, l), o3.spherical_harmonics(l, alpha, beta)),
        )
        for l in range(l_min, l_max + 1)
    ]
    f = torch.stack(kernels)

    nf, dim_out, dim_in, *_ = f.size()

    # Rescale to [0, 1] for the colormap, with 0 mapped to 0.5.
    f = 0.5 + 0.5 * f / f.abs().max()

    total_width = nf * dim_in + (nf - 1) * args.sep
    fig = plt.figure(figsize=(total_width, dim_out), dpi=args.dpi)

    # Size of one sub-plot in figure-relative coordinates (loop invariant).
    width = 1 / total_width
    height = 1 / dim_out
    cmap = plt.get_cmap("bwr")
    a = 0.6  # half-extent of the 3D axes

    for index in range(nf):
        x_offset = index * (dim_in + args.sep)
        for i in range(dim_out):
            for j in range(dim_in):
                rect = [(x_offset + j) * width, (dim_out - i - 1) * height, width, height]
                ax = fig.add_axes(rect, projection='3d')

                colors = cmap(f[index, i, j].detach().cpu().numpy())
                ax.plot_surface(x.numpy(), y.numpy(), z.numpy(), rstride=1, cstride=1, facecolors=colors)
                ax.set_axis_off()

                ax.set_xlim3d(-a, a)
                ax.set_ylim3d(-a, a)
                ax.set_zlim3d(-a, a)

                ax.view_init(90, 0)

    plt.savefig("kernels{}to{}.png".format(args.l_in, args.l_out), transparent=True)
Ejemplo n.º 17
0
    def __init__(self, irreps_in, irreps_out, irreps_sh, diameter, num_radial_basis, steps=(1.0, 1.0, 1.0), **kwargs):
        """Set up an equivariant convolution over a regular 3D lattice.

        Parameters
        ----------
        irreps_in, irreps_out, irreps_sh
            anything accepted by ``o3.Irreps`` for the input, output and
            spherical-harmonics representations.
        diameter : float
            the kernel covers lattice offsets whose per-axis coordinate lies
            within ``diameter / 2``.
        num_radial_basis : int
            number of radial basis functions on ``[0, diameter / 2]``.
        steps : sequence of 3 floats
            lattice spacing along x, y and z.
        **kwargs
            stored in ``self.kwargs``. ``padding`` defaults to half the lattice
            size along each axis when not given.
        """
        super().__init__()

        self.irreps_in = o3.Irreps(irreps_in)
        self.irreps_out = o3.Irreps(irreps_out)
        self.irreps_sh = o3.Irreps(irreps_sh)

        self.num_radial_basis = num_radial_basis

        # self-connection
        self.sc = Linear(self.irreps_in, self.irreps_out)

        # connection with neighbors: one symmetric 1D grid of offsets per axis
        radius = diameter / 2
        axes = []
        for axis in range(3):
            step = steps[axis]
            half_count = math.floor(radius / step)
            axes.append(torch.arange(-half_count, half_count + 1.0) * step)

        lattice = torch.stack(torch.meshgrid(*axes), dim=-1)  # [x, y, z, R^3]
        self.register_buffer('lattice', lattice)

        if 'padding' not in kwargs:
            # half of the (odd) lattice size along each spatial axis
            kwargs['padding'] = tuple(size // 2 for size in lattice.shape[:3])
        self.kwargs = kwargs

        # Radial embedding of each offset's length; goes to zero at the radius.
        emb = soft_one_hot_linspace(
            x=lattice.norm(dim=-1),
            start=0.0,
            end=radius,
            number=self.num_radial_basis,
            basis='smooth_finite',
            cutoff=True,
        )
        self.register_buffer('emb', emb)

        # Angular part of each offset.
        sh = o3.spherical_harmonics(
            l=self.irreps_sh,
            x=lattice,
            normalize=True,
            normalization='component'
        )  # [x, y, z, irreps_sh.dim]
        self.register_buffer('sh', sh)

        self.tp = FullyConnectedTensorProduct(self.irreps_in, self.irreps_sh, self.irreps_out, shared_weights=False)

        # One set of tensor-product weights per radial basis function.
        self.weight = torch.nn.Parameter(torch.randn(self.num_radial_basis, self.tp.weight_numel))
Ejemplo n.º 18
0
    def test_spherical_harmonics(self):
        """
        Check that irr_repr, compose and spherical_harmonics are compatible:

        Y(Z(alpha) Y(beta) Z(gamma) x) = D(alpha, beta, gamma) Y(x)
        with x = Z(a) Y(b) eta
        """
        for order in range(7):
            with o3.torch_default_dtype(torch.float64):
                a, b = torch.rand(2)
                alpha, beta, gamma = torch.rand(3)

                # Harmonics evaluated at the rotated point.
                rotated_a, rotated_b, _ = o3.compose(alpha, beta, gamma, a, b, 0)
                y_rotated = o3.spherical_harmonics(order, rotated_a, rotated_b)

                # Wigner D-matrix applied to the harmonics at the original point.
                y = o3.spherical_harmonics(order, a, b)
                wigner = o3.irr_repr(order, alpha, beta, gamma)
                d_times_y = wigner @ y

                self.assertLess((y_rotated - d_times_y).abs().max(), 1e-10 * y.abs().max())
Ejemplo n.º 19
0
    def message(self, x_i, x_j, pos_i, pos_j, cell_offsets):
        """ Message according to eqs 3-4 in the paper

        Builds the message from the two endpoint features, the squared edge
        length and the spherical harmonics of the (cell-offset corrected)
        relative position. The harmonics are appended to the message. When
        ``self.update_pos`` is set, a position message is prepended as well.
        """
        # relative position, shifted by the periodic cell offset
        rel_pos = (pos_i - pos_j) + cell_offsets
        # NOTE: this is the *squared* distance despite the name
        dist = rel_pos.pow(2).sum(-1, keepdims=True)
        # rel_pos is overwritten with its spherical-harmonic embedding
        rel_pos = spherical_harmonics(self.irreps_rel_pos, rel_pos, normalize=True, normalization='component')
        # message = self.message_net(torch.cat((x_i, x_j, dist, rel_pos), dim=-1))
        message = self.message_layer_1(torch.cat((x_i, x_j, dist, rel_pos), dim=-1))
        message = self.message_layer_2(torch.cat((message, rel_pos), dim=-1))
        message = torch.cat((message, rel_pos), dim=-1) # <---- pass the relative position along
        if self.update_pos:  # TODO: currently no updated...
            # NOTE(review): uses pos_i - pos_j without cell_offsets, unlike the
            # relative position above; confirm this is intended for periodic systems.
            pos_message = (pos_i - pos_j) * self.pos_net(message)
            # torch geometric does not support tuple outputs.
            message = torch.cat((pos_message, message), dim=-1)

        return message
def test_recurrence_relation(float_tolerance, l):
    """Y^{l+1}(x) is proportional to W3j(l+1, l, 1) . Y^l(x) . x.

    Its Jacobian equals (l + 1) / alpha * W3j(l+1, l, 1) . Y^l(x), where alpha
    is the proportionality factor.
    """
    if torch.get_default_dtype() != torch.float64 and l > 6:
        pytest.xfail('we expect this to fail for high l and single precision')

    x = torch.randn(3, requires_grad=True)

    def sh_next(points):
        return o3.spherical_harmonics(l + 1, points, False)

    direct = sh_next(x)
    w3j = o3.wigner_3j(l + 1, l, 1)
    via_recurrence = torch.einsum('ijk,j,k->i', w3j,
                                  o3.spherical_harmonics(l, x, False), x)

    alpha = via_recurrence.norm() / direct.norm()

    direction_gap = direct / direct.norm() - via_recurrence / via_recurrence.norm()
    assert direction_gap.abs().max() < 10 * float_tolerance

    jacobian = torch.autograd.functional.jacobian(sh_next, x)
    expected = (l + 1) / alpha * torch.einsum('ijk,j->ik', w3j,
                                              o3.spherical_harmonics(l, x, False))

    assert (jacobian - expected).abs().max() < 100 * float_tolerance
def test_closure():
    r"""
    integral of Ylm * Yjn = delta_lj delta_mn
    integral of 1 over the unit sphere = 4 pi

    The integrals are estimated by Monte Carlo over 10^6 random directions.
    """
    points = torch.randn(1_000_000, 3)
    harmonics = [o3.spherical_harmonics(degree, points, True) for degree in range(4)]
    for l1, first in enumerate(harmonics):
        for l2, second in enumerate(harmonics):
            gram = first[:, :, None] * second[:, None, :]
            gram = gram.mean(0) * 4 * math.pi
            if l1 != l2:
                assert gram.abs().max() < 0.01
                continue
            identity = torch.eye(2 * l1 + 1)
            assert (gram - identity).abs().max() < 0.01
Ejemplo n.º 22
0
    def forward(self, data: Union[Data, Dict[str,
                                             torch.Tensor]]) -> torch.Tensor:
        """Run message passing on a graph with precomputed edges.

        ``data`` must provide ``pos``, ``edge_index``, ``edge_attr``,
        ``node_input`` and ``node_attr``. ``batch`` is optional and defaults
        to a single graph. The given ``edge_attr`` is concatenated with the
        spherical harmonics (l = 0..self.lmax) of the edge vectors.

        Returns the per-node outputs or, when ``self.pool_nodes`` is set, their
        scatter over ``batch`` divided by ``sqrt(self.num_nodes)``.
        """
        if 'batch' not in data:
            batch = data['pos'].new_zeros(data['pos'].shape[0], dtype=torch.long)
        else:
            batch = data['batch']

        edge_index = data['edge_index']
        edge_src, edge_dst = edge_index[0], edge_index[1]

        pos = data['pos']
        edge_vec = pos[edge_src] - pos[edge_dst]
        edge_sh = o3.spherical_harmonics(range(self.lmax + 1), edge_vec, True, normalization='component')
        edge_attr = torch.cat([data['edge_attr'], edge_sh], dim=1)

        # The cosine basis with cutoff=True already goes to zero at max_radius,
        # so no additional smooth cutoff is applied.
        radial = soft_one_hot_linspace(
            edge_vec.norm(dim=1),
            0.0,
            self.max_radius,
            self.number_of_basis,
            basis='cosine',
            cutoff=True,
        )
        edge_length_embedding = radial.mul(self.number_of_basis**0.5)

        node_outputs = self.mp(data['node_input'], data['node_attr'], edge_src,
                               edge_dst, edge_attr, edge_length_embedding)

        if not self.pool_nodes:
            return node_outputs
        return scatter(node_outputs, batch, dim=0).div(self.num_nodes**0.5)
Ejemplo n.º 23
0
    def forward(self, data) -> torch.Tensor:
        """Evaluate the network on a batch of point clouds.

        Builds a radius graph (r = 2.5) over ``data.pos``. Edge attributes are
        the spherical harmonics of the normalized edge vectors. The edge length
        is embedded on ``[0.5, 2.5]`` with a smooth_finite basis that has a
        cutoff. Initial node features are the summed edge attributes of
        incoming edges. They pass through ``conv``, ``gate`` and ``final``, and
        the result is summed per graph.

        Parameters
        ----------
        data
            object with ``pos`` (node positions) and ``batch`` (graph index of
            each node) attributes.

        Returns
        -------
        torch.Tensor
            one row per graph, divided by ``sqrt(num_nodes)``.
        """
        num_nodes = 4  # typical number of nodes

        edge_src, edge_dst = radius_graph(x=data.pos, r=2.5, batch=data.batch)
        edge_vec = data.pos[edge_src] - data.pos[edge_dst]
        edge_attr = o3.spherical_harmonics(l=self.irreps_sh,
                                           x=edge_vec,
                                           normalize=True,
                                           normalization='component')
        edge_length_embedded = soft_one_hot_linspace(x=edge_vec.norm(dim=1),
                                                     start=0.5,
                                                     end=2.5,
                                                     number=3,
                                                     basis='smooth_finite',
                                                     cutoff=True) * 3**0.5

        # One row per node in the batch (dim_size), even for trailing nodes
        # without incoming edges, so the final scatter over data.batch lines up.
        x = scatter(edge_attr, edge_dst, dim=0,
                    dim_size=data.pos.shape[0]).div(self.num_neighbors**0.5)

        x = self.conv(x, edge_src, edge_dst, edge_attr, edge_length_embedded)
        x = self.gate(x)
        x = self.final(x, edge_src, edge_dst, edge_attr, edge_length_embedded)

        return scatter(x, data.batch, dim=0).div(num_nodes**0.5)
Ejemplo n.º 24
0
    def forward(self, node_atom, node_pos, batch) -> torch.Tensor:
        """Predict one value per graph from atom types, positions and graph indices.

        Per-atom outputs are optionally de-standardized (``self.std``,
        ``self.mean``) and shifted by ``self.atomref[node_atom]``. They are then
        scattered over ``batch`` and optionally multiplied by ``self.scale``.
        """
        # The graph: atoms within self.max_radius, at most 1000 neighbors each
        edge_src, edge_dst = radius_graph(
            node_pos,
            r=self.max_radius,
            batch=batch,
            max_num_neighbors=1000
        )

        # Edge attributes: spherical harmonics l = 0..lmax of the normalized edge vectors
        edge_vec = node_pos[edge_src] - node_pos[edge_dst]
        edge_sh = o3.spherical_harmonics(
            l=range(self.lmax + 1),
            x=edge_vec,
            normalize=True,
            normalization='component'
        )

        # Edge length embedding
        edge_length = edge_vec.norm(dim=1)
        edge_length_embedding = soft_one_hot_linspace(
            edge_length,
            0.0,
            self.max_radius,
            self.number_of_basis,
            basis='cosine',  # the cosine basis with cutoff = True goes to zero at max_radius
            cutoff=True,  # no need for an additional smooth cutoff
        ).mul(self.number_of_basis**0.5)

        # Every atom starts from the same constant input feature 1.
        node_input = node_pos.new_ones(node_pos.shape[0], 1)

        # Map atom types 1, 6, 7, 8, 9 (presumably atomic numbers of H, C, N,
        # O, F) to classes 0..4, one-hot encoded and scaled by sqrt(5).
        # NOTE(review): other values map to -1 (or index out of range for z >= 10).
        node_attr = node_atom.new_tensor([-1, 0, -1, -1, -1, -1, 1, 2, 3, 4])[node_atom]
        node_attr = torch.nn.functional.one_hot(node_attr, 5).mul(5**0.5)

        node_outputs = self.mp(
            node_features=node_input,
            node_attr=node_attr,
            edge_src=edge_src,
            edge_dst=edge_dst,
            edge_attr=edge_sh,
            edge_scalars=edge_length_embedding
        )

        # Channel 0 plus half the square of channel 1 (presumably channel 1 is
        # a parity-odd scalar, made even by squaring — confirm).
        node_outputs = node_outputs[:, 0] + node_outputs[:, 1].pow(2).mul(0.5)
        node_outputs = node_outputs.view(-1, 1)

        node_outputs = node_outputs.div(self.num_nodes**0.5)

        # Undo target standardization when statistics are provided.
        if self.mean is not None and self.std is not None:
            node_outputs = node_outputs * self.std + self.mean

        # Optional per-atom-type reference values.
        if self.atomref is not None:
            node_outputs = node_outputs + self.atomref[node_atom]
        # for target=7, MAE of 75eV

        # Aggregate the atom contributions of each graph.
        outputs = scatter(node_outputs, batch, dim=0)

        if self.scale is not None:
            outputs = self.scale * outputs

        return outputs
def test_weird_call():
    """Degrees may repeat and be unordered; the input may have several batch dims."""
    degrees = [4, 1, 2, 3, 3, 1, 0]
    points = torch.randn(2, 1, 2, 3)
    o3.spherical_harmonics(degrees, points, False)
 def func(pos):
     """Spherical harmonics of degrees ``ls`` at ``pos``, third argument ``False``
     (presumably ``normalize=False``, i.e. the input is not projected to the sphere).

     ``ls`` is a free variable from the enclosing scope, which is not shown here.
     """
     return o3.spherical_harmonics(ls, pos, False)
Ejemplo n.º 27
0
    def forward(self, z, pos, batch=None):
        """Predict one value per graph from atom types ``z`` and positions ``pos``.

        Parameters
        ----------
        z : 1-D ``torch.long`` tensor of atom types (asserted below)
        pos : ``(N, 3)`` tensor of positions (asserted below)
        batch : optional graph index of each atom; defaults to a single graph

        Returns one row per graph: the combined scalar outputs, optionally
        de-standardized, shifted by ``atomref`` and multiplied by ``scale``.
        """
        assert z.dim() == 1 and z.dtype == torch.long
        assert pos.dim() == 2 and pos.shape[1] == 3
        batch = torch.zeros_like(z) if batch is None else batch

        edge_src, edge_dst = radius_graph(pos,
                                          r=self.cutoff,
                                          batch=batch,
                                          max_num_neighbors=1000)
        edge_vec = pos[edge_src] - pos[edge_dst]
        edge_sh = o3.spherical_harmonics(self.irreps_sh, edge_vec, True,
                                         'component')
        edge_len = edge_vec.norm(dim=1)
        # radial features: soft one-hot of the length (default basis) fed through self.radial
        edge_len_emb = self.radial(
            soft_one_hot_linspace(edge_len, 0.0, self.cutoff,
                                  self.rad_gaussians))
        # cosine envelope: 1 at zero length, 0 at self.cutoff
        edge_c = (pi * edge_len / self.cutoff).cos().add(1).div(2)
        edge_sh = edge_c[:, None] * edge_sh / self.num_neighbors**0.5

        # z : [1, 6, 7, 8, 9] -> [0, 1, 2, 3, 4]
        node_z = z.new_tensor([-1, 0, -1, -1, -1, -1, 1, 2, 3, 4])[z]
        # edge_zz = 5 * node_z[edge_src] + node_z[edge_dst]

        node_z = torch.nn.functional.one_hot(node_z, 5).mul(5**0.5)
        # edge_zz = torch.nn.functional.one_hot(edge_zz, 25).mul(5.0)

        # edge_attr = self.mul(edge_zz, edge_sh)
        edge_attr = edge_sh

        # Initial node features: sum of the (enveloped) edge harmonics over
        # outgoing edges, one row per atom. Column 0 is then overwritten with 1.
        h = scatter(edge_sh, edge_src, dim=0, dim_size=len(pos))
        h[:, 0] = 1
        h = self.mul_node(node_z, h)

        # debug output of feature statistics
        print_std('h', h)

        for conv, act in self.layers[:-1]:
            with torch.autograd.profiler.record_function("Layer"):
                h = conv(h, node_z, edge_src, edge_dst, edge_len_emb,
                         edge_attr)  # convolution
                print_std('post conv', h)
                h = act(h)  # gate non linearity
                print_std('post gate', h)

        with torch.autograd.profiler.record_function("Layer"):
            h = self.layers[-1](h, node_z, edge_src, edge_dst, edge_len_emb,
                                edge_attr)

        print_std('h out', h)

        # Combine the output scalars: even ones are added, odd ones squared and
        # halved. Column i equals irrep i because every irrep is 1x0 (asserted).
        s = 0
        for i, (mul, (l, p)) in enumerate(self.irreps_out):
            assert mul == 1 and l == 0
            if p == 1:
                s += h[:, i]
            if p == -1:
                s += h[:, i].pow(2).mul(0.5)  # odd^2 = even
        h = s.view(-1, 1)

        print_std('h out+', h)

        # for the scatter we normalize
        h = h / self.num_atoms**0.5

        # Undo target standardization when statistics are provided.
        if self.mean is not None and self.std is not None:
            h = h * self.std + self.mean

        if self.atomref is not None:
            h = h + self.atomref[z]
        # for target=7, MAE of 75eV

        out = scatter(h, batch, dim=0)

        if self.scale is not None:
            out = self.scale * out

        return out
Ejemplo n.º 28
0
    def _forward(self, data):
        """Compute one prediction per graph for a (periodic) batch of structures.

        Only the periodic path (``self.use_pbc``) is implemented. When
        ``self.otf_graph`` is set, ``data`` is mutated: its ``edge_index``,
        ``cell_offsets`` and ``neighbors`` are replaced by a freshly computed
        graph.
        """

        pos = data.pos
        batch = data.batch

        if self.otf_graph:
            # Build the periodic graph on the fly and store it back into data.
            edge_index, cell_offsets, neighbors = radius_graph_pbc(
                data, self.cutoff, 50, data.pos.device
            )
            data.edge_index = edge_index
            data.cell_offsets = cell_offsets
            data.neighbors = neighbors

        if self.use_pbc:
            out = get_pbc_distances(
                pos,
                data.edge_index,
                data.cell,
                data.cell_offsets,
                data.neighbors,
                return_offsets=True,
            )

            edge_index = out["edge_index"]
            cell_offsets = out["offsets"]
        else:
            # NOTE(review): non-periodic path is unimplemented; the radius_graph
            # result is computed and then discarded by the raise below.
            edge_index = radius_graph(pos, r=self.cutoff, batch=batch)
            raise NotImplementedError

        # construct the node and edge attributes
        rel_pos = (pos[edge_index[0]] - pos[edge_index[1]]) + cell_offsets
        # NOTE: squared distance despite the name
        edge_dist = rel_pos.pow(2).sum(-1, keepdims=True)

        # Squared distance minus the radius of the source atom, the target atom, and both.
        # NOTE(review): subtracts a radius from a squared distance; confirm units are intended.
        edge_dist_radii_1 = edge_dist - self.atom_radii[data.atomic_numbers.long()[edge_index[0]]][:, None]
        edge_dist_radii_2 = edge_dist - self.atom_radii[data.atomic_numbers.long()[edge_index[1]]][:, None]
        edge_dist_radii_12 = edge_dist - self.atom_radii[data.atomic_numbers.long()[edge_index[0]]][:, None] - self.atom_radii[
            data.atomic_numbers.long()[edge_index[1]]][:, None]

        edge_attr = spherical_harmonics(self.attr_irreps, rel_pos, normalize=True, normalization='component')
        node_attr = self.node_attribute_net(edge_index, edge_attr)
        # Pad node_attr for nodes beyond the largest index in edge_index, giving
        # them the first unit vector e_0 as attribute.
        # NOTE(review): only correct if all isolated nodes have the highest
        # indices; isolated nodes in the middle are not handled.
        if (data.contains_isolated_nodes() and edge_index.max().item() + 1 != data.num_nodes):
            nr_add_attr = data.num_nodes - (edge_index.max().item() + 1)
            add_attr = node_attr.new_tensor(np.tile(np.eye(node_attr.shape[-1])[0,:], (nr_add_attr,1)))
            #add_attr = node_attr.new_tensor(np.zeros((nr_add_attr, node_attr.shape[-1])))
            node_attr = torch.cat((node_attr, add_attr), -2)

        # node_attr, edge_attr = self.attribute_net(pos, edge_index)
        # Initial node features: lookup of each atom's atomic number in atom_map.
        x = self.atom_map[data.atomic_numbers.long()]
        x = self.embedding_layer_1(x, node_attr)
        x = self.embedding_layer_2(x, node_attr)
        x = self.embedding_layer_3(x, node_attr)

        # The main layers
        for layer in self.layers:
            x, pos = layer(x, pos, edge_index, edge_dist, edge_dist_radii_1, edge_dist_radii_2, edge_dist_radii_12, edge_attr, node_attr)

        # Output head: per-node layers, mean pooling per graph, then per-graph layers
        x = self.head_pre_pool_layer_1(x, node_attr)
        x = self.head_pre_pool_layer_2(x)
        x = global_mean_pool(x, batch)
        x = self.head_post_pool_layer_1(x)
        x = self.head_post_pool_layer_2(x)

        # Return the result
        return x
 def f(x):
     """Degree ``l + 1`` spherical harmonics of ``x`` with third argument ``False``
     (presumably ``normalize=False``).

     ``l`` is a free variable from the enclosing scope. The single-tensor
     signature presumably exists so the function can be passed to
     ``torch.autograd.functional.jacobian``; confirm against the caller.
     """
     return o3.spherical_harmonics(l + 1, x, False)
Ejemplo n.º 30
0
 def __init__(self, alpha, beta, lmax):
     """Precompute spherical harmonics of degrees 0..lmax at the angles (alpha, beta).

     Uses the ``spherical_harmonics(l, alpha, beta)`` call form. The results
     for all degrees are concatenated along dim 0 and stored in the ``sh`` buffer.
     """
     super().__init__()
     # one tensor per degree, concatenated along dim 0
     # (presumably (lmax + 1)**2 rows in total — confirm shape)
     sh = torch.cat(
         [o3.spherical_harmonics(l, alpha, beta) for l in range(lmax + 1)])
     # buffer: saved with the module state and moved by .to(), but not trained
     self.register_buffer("sh", sh)