Example #1
0
    def test_SO3Vec_cat(self, batch1, batch2, batch3, channels1, channels2, channels3, maxl1, maxl2, maxl3):
        """Check ``so3_torch.cat`` on two and three random ``SO3Vec`` objects.

        Each input ``K`` has ``channelsK`` channels for every ``l`` in
        ``0..maxlK``. When all batch arguments being concatenated compare
        equal, the result's ``tau`` must equal ``SO3Tau.cat`` of the input
        taus. Otherwise ``so3_torch.cat`` is expected to raise
        ``RuntimeError``.
        """
        tau1 = [channels1] * (maxl1 + 1)
        tau2 = [channels2] * (maxl2 + 1)
        # Use maxl3 so the third input's maximum l is parametrized independently.
        tau3 = [channels3] * (maxl3 + 1)

        tau12 = SO3Tau.cat([tau1, tau2])
        tau123 = SO3Tau.cat([tau1, tau2, tau3])

        vec1 = SO3Vec.randn(tau1, batch1)
        vec2 = SO3Vec.randn(tau2, batch2)
        vec3 = SO3Vec.randn(tau3, batch3)

        # Two-way concatenation: only valid when the batch arguments agree.
        if batch1 == batch2:
            vec12 = so3_torch.cat([vec1, vec2])

            assert vec12.tau == tau12
        else:
            with pytest.raises(RuntimeError):
                so3_torch.cat([vec1, vec2])

        # Three-way concatenation: all three batch arguments must agree.
        if batch1 == batch2 == batch3:
            vec123 = so3_torch.cat([vec1, vec2, vec3])

            assert vec123.tau == tau123
        else:
            with pytest.raises(RuntimeError):
                so3_torch.cat([vec1, vec2, vec3])
Example #2
0
    def test_CGProduct(self, batch, maxl1, maxl2, maxl, channels):
        """Check that ``CGProduct`` commutes with ``apply_wigner``.

        Both inputs are rotated with the ``D`` returned by ``rot.gen_rot``
        before taking the product. The product of the original inputs is
        rotated afterwards. Each part of the two results must agree to
        within ``1e-6`` (max absolute difference) in double precision.
        """
        maxl_all = max(maxl1, maxl2, maxl)
        # Only D is needed from gen_rot; the other two return values are unused.
        D, _, _ = rot.gen_rot(maxl_all)

        # Size the CG table to the largest l among both inputs and the output.
        cg_dict = CGDict(maxl=maxl_all, dtype=torch.double)
        cg_prod = CGProduct(maxl=maxl, dtype=torch.double, cg_dict=cg_dict)

        tau1 = SO3Tau([channels] * (maxl1 + 1))
        tau2 = SO3Tau([channels] * (maxl2 + 1))

        vec1 = SO3Vec.randn(tau1, batch, dtype=torch.double)
        vec2 = SO3Vec.randn(tau2, batch, dtype=torch.double)

        # Path A: rotate the inputs, then take the product.
        vec1i = vec1.apply_wigner(D, dir='left')
        vec2i = vec2.apply_wigner(D, dir='left')
        veci_prod = cg_prod(vec1i, vec2i)

        # Path B: take the product, then rotate the result.
        vec_prod = cg_prod(vec1, vec2)
        vecf_prod = vec_prod.apply_wigner(D, dir='left')

        # Max absolute deviation of each part between the two paths.
        diff = [(p1 - p2).abs().max() for p1, p2 in zip(veci_prod, vecf_prod)]
        # Guard against a vacuous pass if zip yields no parts.
        assert diff
        assert all(d < 1e-6 for d in diff), diff