Ejemplo n.º 1
0
    def __init__(self, Rs_in, selection_rule=o3.selection_rule):
        """Precompute the mixing matrix of the tensor square of ``Rs_in``.

        Args:
            Rs_in: input representation. It is passed through ``rs.simplify``
                and the result is stored as ``self.Rs_in``.
            selection_rule: rule forwarded to ``rs.tensor_square`` to choose
                which output orders are produced (default ``o3.selection_rule``).

        Side effects:
            Sets ``self.Rs_out`` to the output representation returned by
            ``rs.tensor_square(..., sorted=True)`` and registers the returned
            matrix as the buffer ``'mixing_matrix'``.
        """
        super().__init__()

        self.Rs_in = rs.simplify(Rs_in)

        # Build the matrix from the *stored* (simplified) representation, not the raw
        # argument, so that self.Rs_in and the buffer always describe the same input
        # layout. This also avoids reading Rs_in twice: a one-shot iterable could
        # already have been exhausted by rs.simplify above.
        self.Rs_out, mixing_matrix = rs.tensor_square(self.Rs_in, selection_rule, sorted=True)
        self.register_buffer('mixing_matrix', mixing_matrix)
Ejemplo n.º 2
0
    def test_tensor_square_norm(self):
        """Check equivariance and row-orthonormality of the tensor-square tensor ``Q``.

        For every input representation in ``cases``, ``Q`` (indexed
        ``[out, in1, in2]``, as fixed by the einsum subscripts below) must give
        the same result whether a random rotation is applied on the output side
        or on both input sides. After flattening the two input indices, its rows
        must form an orthonormal set (``'component'`` normalization is requested).
        """
        cases = [
            [(1, 0), (2, 1), (4, 3)],
        ]
        tol = 1e-10

        for Rs_in in cases:
            with o3.torch_default_dtype(torch.float64):
                Rs_out, Q = rs.tensor_square(
                    Rs_in,
                    o3.selection_rule,
                    normalization='component',
                    sorted=True,
                )

                angles = o3.rand_angles()
                rep_in = rs.rep(Rs_in, *angles)
                rep_out = rs.rep(Rs_out, *angles)

                # Rotate on the output index vs. on both input indices; the two must agree.
                rotated_out = torch.einsum("ijk,il->ljk", Q, rep_out)
                rotated_in = torch.einsum("li,mj,kij->klm", rep_in, rep_in, Q)

                residual_rms = (rotated_out - rotated_in).pow(2).mean().sqrt()
                reference_rms = rotated_out.pow(2).mean().sqrt()
                self.assertLess(residual_rms / reference_rms, tol)

                # One row per output channel; the Gram matrix of the rows must be the identity.
                num_out = Q.size(0)
                rows = Q.reshape(num_out, -1)
                gram = rows @ rows.t()
                identity = torch.eye(num_out)

                self.assertLess((gram - identity).pow(2).mean().sqrt(), tol)
Ejemplo n.º 3
0
    def __init__(self,
                 Rs_in,
                 Rs_out,
                 linear=True,
                 allow_change_output=False,
                 allow_zero_outputs=False):
        """Set up a map from the tensor square of ``Rs_in`` to ``Rs_out``.

        Args:
            Rs_in: input representation of ``(mul, l, p)`` entries. It is simplified
                and stored as ``self.Rs_in``.
            Rs_out: requested output representation. It is simplified and stored
                as ``self.Rs_out``; it may be filtered (see ``allow_change_output``).
            linear: if True, a scalar entry ``(1, 0, 1)`` is prepended to the input
                before squaring. Presumably it is fed a constant in ``forward``, so
                the square also contains terms linear in the input (confirm there).
                The flag is stored as ``self.linear``.
            allow_change_output: if True, drop every ``self.Rs_out`` entry whose
                ``l`` does not occur in the tensor square.
            allow_zero_outputs: if True (and ``allow_change_output`` is False),
                skip the check that every requested ``l`` occurs in the tensor square.

        Raises:
            AssertionError: if neither flag is set and some requested output ``l``
                does not occur in the tensor square.
        """
        super().__init__()

        self.Rs_in = rs.simplify(Rs_in)
        self.Rs_out = rs.simplify(Rs_out)

        # Only generate the output orders that were requested. The lambda captures a
        # frozenset that is never rebound, so the filter cannot change meaning if it
        # were ever evaluated after the later "available" lookup below.
        requested_ls = frozenset(l for _, l, _ in self.Rs_out)
        selection_rule = partial(o3.selection_rule, lfilter=lambda l: l in requested_ls)

        if linear:
            Rs_square_in = [(1, 0, 1)] + self.Rs_in
        else:
            Rs_square_in = self.Rs_in
        self.linear = linear

        Rs_ts, T = rs.tensor_square(Rs_square_in, selection_rule)
        register_sparse_buffer(self, 'T', T)  # [out, in1 * in2]

        # NOTE(review): availability is checked on l only, not on parity p.
        available_ls = frozenset(l for _, l, _ in Rs_ts)
        if allow_change_output:
            self.Rs_out = [(mul, l, p) for mul, l, p in self.Rs_out if l in available_ls]
        elif not allow_zero_outputs:
            missing_ls = sorted({l for _, l, _ in self.Rs_out if l not in available_ls})
            # Raised explicitly (instead of a bare ``assert``) so the check survives
            # ``python -O``. AssertionError is kept so existing callers see the same type.
            if missing_ls:
                raise AssertionError(
                    "tensor square of Rs_in produces no output with l in {}; "
                    "pass allow_change_output=True or allow_zero_outputs=True".format(missing_ls)
                )

        self.kernel = KernelLinear(Rs_ts, self.Rs_out)  # [out, in, w]
Ejemplo n.º 4
0
def test_tensor_square_norm():
    """Check that the ``'component'``-normalized tensor-square matrix has orthonormal rows.

    For each input representation, ``Q @ Q.t()`` (densified) must match the
    identity of size ``rs.dim(Rs_out)`` to within an RMS error of 1e-10.
    """
    representations = [
        [(1, 0), (1, 1)],
    ]

    for Rs_in in representations:
        with o3.torch_default_dtype(torch.float64):
            Rs_out, Q = rs.tensor_square(
                Rs_in,
                o3.selection_rule,
                normalization='component',
                sorted=True,
            )

            # The product is densified before comparison; Q is presumably sparse here.
            product = Q @ Q.t()
            gram = product.to_dense()
            expected = torch.eye(rs.dim(Rs_out))

            rms_error = (gram - expected).pow(2).mean().sqrt()
            assert rms_error < 1e-10