def test_SO3Vec_cat(self, batch1, batch2, batch3, channels1, channels2, channels3, maxl1, maxl2, maxl3):
    """Check ``so3_torch.cat`` on SO3Vec objects.

    When the batch sizes agree, concatenating SO3Vecs must produce a vector
    whose tau equals ``SO3Tau.cat`` of the input taus. When the batch sizes
    differ, the concatenation must raise ``RuntimeError``. Both the two-way
    and the three-way concatenation are checked.
    """
    tau1 = [channels1] * (maxl1 + 1)
    tau2 = [channels2] * (maxl2 + 1)
    # Each vector uses its own maxl so that maxl3 is actually exercised.
    tau3 = [channels3] * (maxl3 + 1)

    tau12 = SO3Tau.cat([tau1, tau2])
    tau123 = SO3Tau.cat([tau1, tau2, tau3])

    vec1 = SO3Vec.randn(tau1, batch1)
    vec2 = SO3Vec.randn(tau2, batch2)
    vec3 = SO3Vec.randn(tau3, batch3)

    # Two-way concatenation.
    if batch1 == batch2:
        vec12 = so3_torch.cat([vec1, vec2])
        assert vec12.tau == tau12
    else:
        with pytest.raises(RuntimeError):
            so3_torch.cat([vec1, vec2])

    # Three-way concatenation.
    if batch1 == batch2 == batch3:
        vec123 = so3_torch.cat([vec1, vec2, vec3])
        assert vec123.tau == tau123
    else:
        with pytest.raises(RuntimeError):
            so3_torch.cat([vec1, vec2, vec3])
def test_CGProduct(self, batch, maxl1, maxl2, maxl, channels):
    """Check that ``CGProduct`` commutes with rotation.

    The test compares two paths:

    - rotate both inputs with ``apply_wigner(D, dir='left')``, then take the
      CG product;
    - take the CG product, then rotate the result the same way.

    Every part of the two results must agree to within 1e-6 (max absolute
    difference). All tensors use ``torch.double``.
    """
    maxl_all = max(maxl1, maxl2, maxl)
    # Only the first output of gen_rot (passed to apply_wigner) is used here.
    D, _, _ = rot.gen_rot(maxl_all)

    cg_dict = CGDict(maxl=maxl_all, dtype=torch.double)
    cg_prod = CGProduct(maxl=maxl, dtype=torch.double, cg_dict=cg_dict)

    tau1 = SO3Tau([channels] * (maxl1 + 1))
    tau2 = SO3Tau([channels] * (maxl2 + 1))

    vec1 = SO3Vec.randn(tau1, batch, dtype=torch.double)
    vec2 = SO3Vec.randn(tau2, batch, dtype=torch.double)

    # Rotated inputs.
    vec1i = vec1.apply_wigner(D, dir='left')
    vec2i = vec2.apply_wigner(D, dir='left')

    # Product of the original inputs, and product of the rotated inputs.
    vec_prod = cg_prod(vec1, vec2)
    veci_prod = cg_prod(vec1i, vec2i)

    # Rotate the product of the original inputs.
    vecf_prod = vec_prod.apply_wigner(D, dir='left')

    parts_i = list(veci_prod)
    parts_f = list(vecf_prod)
    # zip() truncates to the shorter input; make sure no part goes unchecked.
    assert len(parts_i) == len(parts_f)

    diff = [(p1 - p2).abs().max() for p1, p2 in zip(parts_i, parts_f)]
    assert all(d < 1e-6 for d in diff), diff