def __init__(self, Rs_in, selection_rule=o3.selection_rule):
    """Tensor square of ``Rs_in`` with a fixed mixing matrix.

    :param Rs_in: input representation list; stored simplified as ``self.Rs_in``
    :param selection_rule: selection rule forwarded to ``rs.tensor_square``
        (defaults to ``o3.selection_rule``)

    Sets ``self.Rs_out`` (sorted output representation list) and registers
    ``mixing_matrix`` as a buffer, so it is saved with the module but is not a
    trainable parameter.
    """
    super().__init__()

    self.Rs_in = rs.simplify(Rs_in)

    # Build the mixing matrix from the same simplified list that is stored on the
    # module, so ``self.Rs_in`` describes exactly the input layout that
    # ``mixing_matrix`` was built for (the raw argument may be unsimplified or use
    # a different tuple convention).
    self.Rs_out, mixing_matrix = rs.tensor_square(self.Rs_in, selection_rule, sorted=True)
    self.register_buffer('mixing_matrix', mixing_matrix)
def test_tensor_square_norm(self):
    """Tensor-square mixing tensor ``Q`` must be equivariant and have orthonormal rows.

    Both properties are checked in float64 with a 1e-10 tolerance.
    """
    representations = [[(1, 0), (2, 1), (4, 3)]]
    for Rs_in in representations:
        with o3.torch_default_dtype(torch.float64):
            Rs_out, Q = rs.tensor_square(Rs_in, o3.selection_rule, normalization='component', sorted=True)

            # Equivariance at random angles: transforming Q's output index by D_out^T
            # must agree with transforming both input indices by D_in^T.
            angles = o3.rand_angles()
            D_in = rs.rep(Rs_in, *angles)
            D_out = rs.rep(Rs_out, *angles)
            lhs = torch.einsum("ijk,il->ljk", (Q, D_out))
            rhs = torch.einsum("li,mj,kij->klm", (D_in, D_in, Q))
            relative_error = (lhs - rhs).pow(2).mean().sqrt() / lhs.pow(2).mean().sqrt()
            self.assertLess(relative_error, 1e-10)

            # Orthonormality: with Q flattened to [n_out, in * in], its Gram matrix is the identity.
            n_out = Q.size(0)
            flat = Q.reshape(n_out, -1)
            gram_error = ((flat @ flat.t()) - torch.eye(n_out)).pow(2).mean().sqrt()
            self.assertLess(gram_error, 1e-10)
def __init__(self, Rs_in, Rs_out, linear=True, allow_change_output=False, allow_zero_outputs=False):
    """Tensor square of ``Rs_in`` mapped onto ``Rs_out`` through a ``KernelLinear``.

    :param Rs_in: input representation list; stored simplified as ``self.Rs_in``
    :param Rs_out: requested output representation list; stored simplified as ``self.Rs_out``
    :param linear: if True, a scalar entry ``(1, 0, 1)`` is prepended to the input before
        squaring (presumably fed a constant by ``forward`` so that linear terms appear in
        the square -- confirm against ``forward``)
    :param allow_change_output: if True, entries of ``Rs_out`` whose ``l`` the tensor square
        cannot produce are dropped from ``self.Rs_out``
    :param allow_zero_outputs: if True (and ``allow_change_output`` is False), such entries
        are kept in ``self.Rs_out`` without error
    :raises AssertionError: if some requested ``l`` cannot be produced and neither flag
        allows it
    """
    super().__init__()

    self.Rs_in = rs.simplify(Rs_in)
    self.Rs_out = rs.simplify(Rs_out)

    # Only produce tensor-square components whose l is requested in the output.
    # The filter closes over its own frozen set, so no later rebinding of a local name
    # can change what it accepts (a late-binding closure over a reused name is fragile).
    requested_ls = frozenset(l for _, l, _ in self.Rs_out)
    selection_rule = partial(o3.selection_rule, lfilter=lambda l: l in requested_ls)

    if linear:
        Rs_in = [(1, 0, 1)] + self.Rs_in
    else:
        Rs_in = self.Rs_in
    self.linear = linear

    Rs_ts, T = rs.tensor_square(Rs_in, selection_rule)
    register_sparse_buffer(self, 'T', T)  # [out, in1 * in2]

    available_ls = {l for _, l, _ in Rs_ts}
    if allow_change_output:
        self.Rs_out = [(mul, l, p) for mul, l, p in self.Rs_out if l in available_ls]
    elif not allow_zero_outputs:
        missing_ls = sorted({l for _, l, _ in self.Rs_out} - available_ls)
        if missing_ls:
            # Explicit raise (not ``assert``) so the check survives ``python -O``;
            # AssertionError is kept so existing callers see the same exception type.
            raise AssertionError(
                "Rs_out requests l={} which the tensor square of Rs_in cannot produce; "
                "pass allow_change_output=True or allow_zero_outputs=True".format(missing_ls)
            )

    self.kernel = KernelLinear(Rs_ts, self.Rs_out)  # [out, in, w]
def test_tensor_square_norm():
    """Rows of the tensor-square matrix ``Q`` must be orthonormal, i.e. ``Q @ Q.t()`` is the identity."""
    cases = [[(1, 0), (1, 1)]]
    for Rs_in in cases:
        with o3.torch_default_dtype(torch.float64):
            Rs_out, Q = rs.tensor_square(Rs_in, o3.selection_rule, normalization='component', sorted=True)
            # ``to_dense`` is applied to the product, so Q may be a sparse tensor here.
            gram = (Q @ Q.t()).to_dense()
            expected = torch.eye(rs.dim(Rs_out))
            rms_error = (gram - expected).pow(2).mean().sqrt()
            assert rms_error < 1e-10