def test_SphericalTensor(self):
    """Construct SphericalTensors (mul=1 and mul=3) from signals of length mul * (lmax + 1) ** 2."""
    # Scope float64 to this test only: a bare torch.set_default_dtype(...) is never
    # undone and silently changes the default dtype for every test that runs later.
    with o3.torch_default_dtype(torch.float64):
        lmax = 6
        for mul in (1, 3):
            sph = sphten.SphericalTensor(torch.randn(mul * (lmax + 1) ** 2), mul, lmax)
            # Constructor must record the multiplicity and lmax it was given.
            assert sph.mul == mul
            assert sph.lmax == lmax
def test_add(self):
    """Summing two SphericalTensors keeps `mul` and takes the larger of the two lmax values."""
    mul = 1
    lmax = 4
    # Two one-hot signals set on distinct coefficients (indices 0 and 3).
    first = torch.zeros((lmax + 1) ** 2)
    second = first.clone()
    first[0] = 1.
    second[3] = 1.
    lhs = sphten.SphericalTensor(first, mul, lmax)
    rhs = sphten.SphericalTensor(second, mul, lmax)
    total = lhs + rhs
    assert total.mul == mul
    assert total.lmax == max(lhs.lmax, rhs.lmax)
def test_mul_and_dot(self):
    """Product of two SphericalTensors has Rs [(mul_dim, 0, 0)]; `dot` is smoke-tested."""
    mul = 1
    lmax = 4
    # Two one-hot signals set on distinct coefficients (indices 0 and 3).
    first = torch.zeros((lmax + 1) ** 2)
    second = first.clone()
    first[0] = 1.
    second[3] = 1.
    lhs = sphten.SphericalTensor(first, mul, lmax)
    rhs = sphten.SphericalTensor(second, mul, lmax)
    product = lhs * rhs
    expected_rs = [(rs.mul_dim(lhs.Rs), 0, 0)]
    assert rs.are_equal(product.Rs, expected_rs)
    # Only checks that dot runs; its return value is not inspected.
    lhs.dot(rhs)
def test_sh_dirac(self):
    """For degrees 0..4, a tensor built by spherical_harmonics_dirac evaluates to ~1 at its own direction."""
    with o3.torch_default_dtype(torch.float64):
        for degree in range(5):
            # Created inside the context so the angle tensors pick up float64.
            alpha = torch.tensor(1.2)
            beta = torch.tensor(2.1)
            xyz = torch.stack(o3.angles_to_xyz(alpha, beta), dim=-1)
            coeffs = sphten.spherical_harmonics_dirac(xyz, degree)
            value = sphten.SphericalTensor(coeffs, 1, degree).value(alpha, beta)
            self.assertAlmostEqual(value.item(), 1)
def test_change_lmax(self):
    """change_lmax on an lmax=0 tensor yields the requested lmax with a signal sized to its Rs."""
    lmax = 0
    mul = 1
    signal = torch.zeros(rs.dim([(mul, lmax)]))
    sph = sphten.SphericalTensor(signal, mul, lmax)
    lmax_new = 5
    sph_new = sph.change_lmax(lmax_new)
    # Previously unchecked: the result must actually carry the new lmax.
    assert sph_new.lmax == lmax_new
    # Signal length must stay consistent with the new representation.
    assert sph_new.signal.shape[0] == rs.dim(sph_new.Rs)