示例#1
0
    def test_kron(self):
        """Check that ``o3.kron`` with three operands equals the nested two-operand form.

        Three random matrices of shapes (4, 4), (3, 5) and (6, 6) are combined
        once in a single variadic call and once by nesting calls; both results
        must agree up to ``torch.allclose`` tolerance.
        """
        # NOTE(review): presumably makes float64 the default dtype inside the
        # block so the random operands are double precision -- confirm against
        # o3.torch_default_dtype.
        with o3.torch_default_dtype(torch.float64):
            shapes = [(4, 4), (3, 5), (6, 6)]
            a, b, c = (torch.randn(rows, cols) for rows, cols in shapes)

            flat = o3.kron(a, b, c)
            inner = o3.kron(b, c)
            nested = o3.kron(a, inner)
            assert torch.allclose(flat, nested)
示例#2
0
        def representation(alpha, beta, gamma, parity=None):
            """Build the matrix ``Q @ kron(...) @ Q.T`` for the given angles.

            One factor is produced per index in ``f0`` from the matching
            ``kw_Rs`` entry, the factors are combined with ``o3.kron``, and the
            result is conjugated by ``Q``.

            Args:
                alpha, beta, gamma: angles forwarded unchanged to every factor.
                parity: forwarded to ``rep`` for non-callable entries, and to
                    callable entries only when ``has_parity`` is truthy.
            """
            def factor_for(entry):
                # Non-callable entries are delegated to ``rep``.
                if not callable(entry):
                    return rep(entry, alpha, beta, gamma, parity)
                # Callable entries receive parity only when ``has_parity`` is set.
                if has_parity:
                    return entry(alpha, beta, gamma, parity)
                return entry(alpha, beta, gamma)

            factors = [factor_for(kw_Rs[idx]) for idx in f0]
            return Q @ o3.kron(*factors) @ Q.T
示例#3
0
 def to_irrep_transformation(self):
     """Return ``(Rs_out, matrix)`` obtained by reducing ``self.formula``.

     ``rs.reduce_tensor`` is called with ``self.formula`` and one keyword per
     character found before the first ``=`` in the formula, each mapped to the
     same ``Rs``.  The returned ``Q`` is then multiplied on the right by the
     ``o3.kron`` of ``self.tensor.dim()`` copies of
     ``o3.xyz_to_irreducible_basis()``, reshaped to ``(3**dim, 3**dim)``.
     """
     rank = self.tensor.dim()

     # The basis is evaluated once and the same object is repeated ``rank`` times.
     basis = o3.xyz_to_irreducible_basis()
     change = o3.kron(*([basis] * rank))

     Rs = [(1, 1)]  # vectors
     # Text left of the first "="; each character is used as a keyword name.
     lhs_indices = self.formula.partition("=")[0]
     index_kwargs = {name: Rs for name in lhs_indices}
     Rs_out, Q = rs.reduce_tensor(self.formula, **index_kwargs)

     side = 3**rank
     change_matrix = change.reshape(side, side)
     return Rs_out, torch.einsum('ab,bc->ac', Q, change_matrix)
示例#4
0
 def representation(alpha, beta, gamma):
     """Return ``Q @ kron(...) @ Q.T`` for the given angles.

     One factor ``rep(kw_Rs[i], alpha, beta, gamma)`` is built for each index
     ``i`` in ``f0``; the factors are combined with ``o3.kron`` and the result
     is conjugated by ``Q``.
     """
     factors = [rep(kw_Rs[idx], alpha, beta, gamma) for idx in f0]
     combined = o3.kron(*factors)
     return Q @ combined @ Q.T