예제 #1
0
파일: rs_test.py 프로젝트: soupwaylee/e3nn
def test_tensor_product_symmetry():
    """Swapping which input slot holds the selection rule gives the same product.

    ``mul_ab`` receives the explicit representation first and the selection
    rule in the second slot; ``mul_ba`` receives them the other way around.
    The representation filled in for the selection-rule slot must coincide,
    and feeding the transposed input to ``mul_ba`` must reproduce the output
    of ``mul_ab``.
    """
    with o3.torch_default_dtype(torch.float64):
        rs_input = [(3, 0), (2, 1), (5, 2)]
        rs_output = [(1, 0), (2, 1), (2, 2), (2, 0), (2, 1), (1, 2)]

        mul_ab = rs.TensorProduct(rs_input, o3.selection_rule, rs_output)
        mul_ba = rs.TensorProduct(o3.selection_rule, rs_input, rs_output)

        # The representation chosen for the selection-rule slot must agree
        # between the two orderings.
        assert mul_ab.Rs_in2 == mul_ba.Rs_in1

        features = torch.randn(rs.dim(rs_input), rs.dim(mul_ab.Rs_in2))
        out_ab = mul_ab(features)
        out_ba = mul_ba(features.T)

        max_err = (out_ab - out_ba).abs().max()
        assert max_err < 1e-10
예제 #2
0
    def __init__(self,
                 Rs_in,
                 Rs_out,
                 RadialModel,
                 r,
                 r_eps=0,
                 selection_rule=o3.selection_rule_in_out_sh,
                 normalization='component'):
        """
        Build a kernel evaluated at a fixed set of positions ``r``.

        Spherical harmonics of the filter representation are precomputed for
        every position whose norm exceeds ``r_eps`` and stored in the ``Y``
        buffer.

        :param Rs_in: list of triplet (multiplicity, representation order, parity)
        :param Rs_out: list of triplet (multiplicity, representation order, parity)
        :param RadialModel: Class(d), trainable model: R -> R^d
        :param tensor r: [..., 3]
        :param float r_eps: distance considered as zero
        :param selection_rule: function of signature (l_in, p_in, l_out, p_out) -> [l_filter]
        :param normalization: either 'norm' or 'component'
        representation order = nonnegative integer
        parity = 0 (no parity), 1 (even), -1 (odd)
        """
        super().__init__()

        self.Rs_in = rs.convention(Rs_in)
        self.Rs_out = rs.convention(Rs_out)
        self.check_input_output(selection_rule)

        # Keep the leading shape of r; the last axis must hold x, y, z.
        *self.size, xyz = r.size()
        assert xyz == 3
        r = r.reshape(-1, 3)  # [batch, space]
        self.register_buffer('radii', r.norm(2, dim=1))  # [batch]
        self.r_eps = r_eps

        self.tp = rs.TensorProduct(self.Rs_in,
                                   selection_rule,
                                   self.Rs_out,
                                   normalization,
                                   sorted=True)
        # Representation the tensor product assigned to its second (filter) slot.
        self.Rs_f = self.tp.Rs_in2

        # Harmonics only for positions farther than r_eps from the origin.
        Y = rsh.spherical_harmonics_xyz(
            [(1, l, p) for _, l, p in self.Rs_f],
            r[self.radii > self.r_eps])  # [batch, l_filter * m_filter]

        # Normalize the spherical harmonics
        if normalization == 'component':
            Y.mul_(math.sqrt(4 * math.pi))
        elif normalization == 'norm':
            # Build the per-component scale on Y's dtype and device so the
            # in-place multiply also works when r is on a GPU or is not in the
            # default dtype (a CPU/default-dtype tensor would fail or lose
            # precision here).
            diag = math.sqrt(4 * math.pi) * torch.cat([
                torch.ones(2 * l + 1, dtype=Y.dtype, device=Y.device) / math.sqrt(2 * l + 1)
                for _, l, _ in self.Rs_f
            ])
            Y.mul_(diag)

        self.register_buffer('Y', Y)
        self.R = RadialModel(rs.mul_dim(self.Rs_f))

        # Positions within r_eps of the origin are excluded from Y. When any
        # exist, a KernelLinear is kept, presumably to handle those
        # zero-distance entries (TODO confirm against forward()).
        if (self.radii <= self.r_eps).any():
            self.linear = KernelLinear(self.Rs_in, self.Rs_out)
        else:
            self.linear = None
예제 #3
0
파일: kernel_mod.py 프로젝트: wudangt/e3nn
    def __init__(self, Rs_in, Rs_out, RadialModel,
                 selection_rule=o3.selection_rule_in_out_sh,
                 normalization='component',
                 allow_unused_inputs=False,
                 allow_zero_outputs=False):
        """
        :param Rs_in: list of triplet (multiplicity, representation order, parity)
        :param Rs_out: list of triplet (multiplicity, representation order, parity)
        :param RadialModel: Class(d), trainable model: R -> R^d
        :param selection_rule: function of signature (l_in, p_in, l_out, p_out) -> [l_filter]
        :param normalization: either 'norm' or 'component'
        :param bool allow_unused_inputs: if True, skip ``self.check_input(selection_rule)``
            (presumably the check that every input reaches some output)
        :param bool allow_zero_outputs: if True, skip ``self.check_output(selection_rule)``
            (presumably the check that every output receives some input)
        representation order = nonnegative integer
        parity = 0 (no parity), 1 (even), -1 (odd)
        """
        super().__init__()

        self.Rs_in = rs.convention(Rs_in)
        self.Rs_out = rs.convention(Rs_out)
        if not allow_unused_inputs:
            self.check_input(selection_rule)
        if not allow_zero_outputs:
            self.check_output(selection_rule)

        self.normalization = normalization

        # Pass the convention-normalized Rs_out, exactly as Rs_in is passed,
        # so both sides of the tensor product see the same representation format.
        self.tp = rs.TensorProduct(self.Rs_in, selection_rule, self.Rs_out, normalization, sorted=True)
        # Representation the tensor product assigned to its second (filter) slot.
        self.Rs_f = self.tp.Rs_in2

        # Representation order of each filter entry.
        self.Ls = [l for _, l, _ in self.Rs_f]
        self.R = RadialModel(rs.mul_dim(self.Rs_f))

        self.linear = KernelLinear(self.Rs_in, self.Rs_out)
예제 #4
0
파일: rs_test.py 프로젝트: soupwaylee/e3nn
def test_tensor_product_to_dense():
    """``to_dense`` yields a tensor shaped [dim(Rs_out), dim(Rs_1), dim(Rs_2)]."""
    with o3.torch_default_dtype(torch.float64):
        rs_left = [(3, 0), (2, 1), (5, 2)]
        rs_right = [(1, 0), (2, 1), (2, 2), (2, 0), (2, 1), (1, 2)]

        product = rs.TensorProduct(rs_left, rs_right, o3.selection_rule)

        dense_shape = product.to_dense().shape
        assert dense_shape == (
            rs.dim(product.Rs_out),
            rs.dim(rs_left),
            rs.dim(rs_right),
        )
예제 #5
0
    def __init__(self, Rs_in, Rs_out):
        """
        Set up a tensor product from Rs_in to Rs_out with a learnable weight vector.

        The selection rule is ``o3.selection_rule_in_out_sh`` with ``lmax=0``
        bound, presumably restricting the second input to order 0.

        :param Rs_in: list of triplet (multiplicity, representation order, parity)
        :param Rs_out: list of triplet (multiplicity, representation order, parity)
        representation order = nonnegative integer
        parity = 0 (no parity), 1 (even), -1 (odd)
        """
        super().__init__()

        rule_lmax0 = partial(o3.selection_rule_in_out_sh, lmax=0)
        self.tp = rs.TensorProduct(Rs_in, rule_lmax0, Rs_out, sorted=False)
        # One standard-normal weight per dimension of the second input slot.
        n_weights = rs.dim(self.tp.Rs_in2)
        self.weight = torch.nn.Parameter(torch.randn(n_weights))
예제 #6
0
파일: rs_test.py 프로젝트: soupwaylee/e3nn
def test_tensor_product_in_in_normalization_norm(Rs_in1, Rs_in2):
    """With 'norm' normalization, mean output norms stay within a factor 10 of 1."""
    with o3.torch_default_dtype(torch.float64):
        product = rs.TensorProduct(
            Rs_in1, Rs_in2, o3.selection_rule, normalization='norm'
        )

        batch = 10
        left = rs.randn(batch, Rs_in1, normalization='norm')
        right = rs.randn(batch, Rs_in2, normalization='norm')

        norm_of = Norm(product.Rs_out, normalization='norm')
        mean_norms = norm_of(product(left, right)).mean(0)

        # |log10(v)| < 1  <=>  0.1 < v < 10
        assert (mean_norms.log10().abs() < 1).all()
예제 #7
0
def initialize_edges(x,
                     Rs_in,
                     pos,
                     edge_index_dict,
                     lmax,
                     self_edge=1.,
                     symmetric_edges=False):
    """Initialize edge features of DataEdgeNeighbors using node features and SphericalTensor.

    For each edge ``(target, source)``, the features of the two endpoint nodes
    are contracted with the change-of-basis ``Q`` from ``rs.reduce_tensor``.
    The result is then tensored with a spherical-tensor encoding of the edge
    vector ``pos[source] - pos[target]``.

    Args:
        x (torch.tensor shape [N, rs.dim(Rs_in)]): Node features.
        Rs_in (rs.TY_RS_STRICT): Representation list of input.
        pos (torch.tensor shape [N, 3]): Cartesian coordinates of nodes.
        edge_index_dict (dict): Maps ``(target, source)`` node-index pairs to a
            sortable value (presumably the edge index). Edges are processed,
            and their features stacked, in ascending order of these values.
        lmax (int > 0): Maximum L to use for SphericalTensor projection of radial distance vectors
        self_edge (float, optional): L=0 feature for self edges (edges whose vector is zero). Defaults to 1.
        symmetric_edges (bool, optional): Constrain edge features to be symmetric in node index. Defaults to False

    Returns:
        edge_x: Edge features, stacked along a new leading edge dimension.
        Rs_edge (rs.TY_RS_STRICT): Representation list of edge features.

    Raises:
        ValueError: if ``edge_index_dict`` is empty.
    """
    from e3nn.tensor import SphericalTensor

    if not edge_index_dict:
        # Previously surfaced as an opaque "not enough values to unpack" error.
        raise ValueError("edge_index_dict is empty; at least one edge is required")

    if symmetric_edges:
        Rs, Q = rs.reduce_tensor('ij=ji', i=Rs_in)
    else:
        Rs, Q = rs.reduce_tensor('ij', i=Rs_in, j=Rs_in)
    # View Q as [k, i, j] so it can contract one feature vector per endpoint.
    Q = Q.reshape(-1, rs.dim(Rs_in), rs.dim(Rs_in))
    # One copy of every order l <= lmax with parity (-1)**l.
    Rs_sph = [(1, l, (-1)**l) for l in range(lmax + 1)]
    tp_kernel = rs.TensorProduct(Rs, Rs_sph, o3.selection_rule)

    edge_x = []
    for (target, source), _ in sorted(edge_index_dict.items(), key=lambda item: item[1]):
        x_target = x[target]
        x_source = x[source]
        vector = (pos[source] - pos[target]).reshape(-1, 3)
        # zeros_like matches vector's dtype and device; torch.allclose rejects
        # mixed dtypes, so a default-dtype zeros tensor would fail for e.g. float64 pos.
        if torch.allclose(vector, torch.zeros_like(vector)):
            # Zero-length edge: only the L=0 entry (index 0 of Rs_sph) is set.
            signal = torch.zeros(rs.dim(Rs_sph), dtype=vector.dtype, device=vector.device)
            signal[0] = self_edge
        else:
            signal = SphericalTensor.from_geometry(vector, lmax=lmax).signal
            if symmetric_edges:
                # Average with the reversed vector so swapping the endpoints
                # (which negates the vector) yields the same encoding.
                signal += SphericalTensor.from_geometry(-vector,
                                                        lmax=lmax).signal
                signal *= 0.5
        output = torch.einsum('kij,i,j->k', Q, x_target, x_source)
        output = tp_kernel(output, signal)
        edge_x.append(output)
    edge_x = torch.stack(edge_x, dim=0)
    return edge_x, tp_kernel.Rs_out
예제 #8
0
파일: rs_test.py 프로젝트: soupwaylee/e3nn
def test_tensor_product_equal_TensorProduct():
    """The functional ``rs.tensor_product`` matrix agrees with ``rs.TensorProduct``."""
    with o3.torch_default_dtype(torch.float64):
        rs_left = [(3, 0), (2, 1), (5, 2)]
        rs_right = [(1, 0), (2, 1), (2, 2), (2, 0), (2, 1), (1, 2)]

        rs_expected, mixing = rs.tensor_product(
            rs_left, rs_right, o3.selection_rule, sorted=True
        )
        module = rs.TensorProduct(rs_left, rs_right, o3.selection_rule)

        a = rs.randn(1, rs_left)
        b = rs.randn(1, rs_right)

        from_module = module(a, b)

        # Outer product laid out as [i, j, batch], flattened to [i * j, batch],
        # mixed, then transposed back to [batch, out].
        outer = torch.einsum('zi,zj->ijz', a, b)
        flat = outer.reshape(rs.dim(rs_left) * rs.dim(rs_right), -1)
        from_matrix = (mixing @ flat).T

        assert rs.dim(rs_expected) == from_module.shape[1]
        max_diff = (from_module - from_matrix).abs().max()
        assert max_diff < 1e-10 * from_module.abs().max()
예제 #9
0
파일: rs_test.py 프로젝트: soupwaylee/e3nn
def test_tensor_product_left_right():
    """Every evaluation path of TensorProduct matches the two-argument call.

    Checks the single-argument form on a precomputed outer product, and the
    two-argument form with the private ``_complete`` switch set to 'in1' and
    to 'in2'.
    """
    with o3.torch_default_dtype(torch.float64):
        rs_left = [(3, 0), (2, 1), (5, 2)]
        rs_right = [(1, 0), (2, 1), (2, 2), (2, 0), (2, 1), (1, 2)]

        mul = rs.TensorProduct(rs_left, rs_right, o3.selection_rule)

        x1 = rs.randn(2, rs_left)
        x2 = rs.randn(2, rs_right)

        reference = mul(x1, x2)
        tolerance = 1e-10 * reference.abs().max()

        def assert_matches(candidate):
            assert (reference - candidate).abs().max() < tolerance

        # Single-argument form: outer product laid out as [batch, in1, in2].
        assert_matches(mul(torch.einsum('zi,zj->zij', x1, x2)))

        for mode in ('in1', 'in2'):
            mul._complete = mode
            assert_matches(mul(x1, x2))
예제 #10
0
 def __matmul__(self, other):
     """Return the tensor product ``self @ other`` as an IrrepTensor.

     Combines ``self.signal`` and ``other.signal`` through an
     ``rs.TensorProduct`` built from ``self.Rs``, ``other.Rs`` and
     ``o3.selection_rule``; the result carries that product's ``Rs_out``.
     """
     # TODO: better handle mismatch of feature indices.
     product = rs.TensorProduct(self.Rs, other.Rs, o3.selection_rule)
     combined_signal = product(self.signal, other.signal)
     return IrrepTensor(combined_signal, product.Rs_out)