예제 #1
0
    def test_SO3Vec_cat(self, batch1, batch2, batch3, channels1, channels2, channels3, maxl1, maxl2, maxl3):
        """Check ``so3_torch.cat`` on SO3Vecs whose batch sizes may differ.

        Three random SO3Vecs are built from taus of ``maxl + 1`` entries, each
        entry equal to that vector's channel count. When all batch sizes of
        the concatenated vectors agree, the result's ``tau`` must equal
        ``SO3Tau.cat`` of the input taus. Otherwise the concatenation must
        raise ``RuntimeError``.
        """
        # maxl + 1 entries each (presumably one per l = 0..maxl), all set to
        # the vector's channel count.
        tau1 = [channels1] * (maxl1 + 1)
        tau2 = [channels2] * (maxl2 + 1)
        # Must use maxl3 (not maxl2) so the third vector's max l actually
        # follows its own parameter.
        tau3 = [channels3] * (maxl3 + 1)

        tau12 = SO3Tau.cat([tau1, tau2])
        tau123 = SO3Tau.cat([tau1, tau2, tau3])

        vec1 = SO3Vec.randn(tau1, batch1)
        vec2 = SO3Vec.randn(tau2, batch2)
        vec3 = SO3Vec.randn(tau3, batch3)

        # Two-way concatenation: only valid when the batch sizes match.
        if batch1 == batch2:
            vec12 = so3_torch.cat([vec1, vec2])

            assert vec12.tau == tau12
        else:
            with pytest.raises(RuntimeError):
                so3_torch.cat([vec1, vec2])

        # Three-way concatenation: only valid when all three batch sizes match.
        if batch1 == batch2 == batch3:
            vec123 = so3_torch.cat([vec1, vec2, vec3])

            assert vec123.tau == tau123
        else:
            with pytest.raises(RuntimeError):
                so3_torch.cat([vec1, vec2, vec3])
예제 #2
0
    def test_cat(self):
        """Check ``SO3Tau.cat`` and the ``&`` / ``&=`` operators on SO3Tau.

        Both combine taus entrywise: values at the same position are summed,
        and a shorter tau acts as if padded with zeros (e.g. ``[1, 2, 3]``
        combined with ``[1, 1]`` gives ``[2, 3, 3]``).
        """
        tau1 = SO3Tau([1, 2, 3])
        tau2 = SO3Tau([1, 1])
        tau3 = SO3Tau([0, 0, 2])

        tau = SO3Tau.cat([tau1, tau2])
        assert list(tau) == [2, 3, 3]

        # Exact-type check: cat must return SO3Tau itself, not a subclass or
        # some other container.
        assert type(tau) is SO3Tau

        # Smoke-test the string conversion used by print().
        print(tau)

        tau = tau1 & tau2
        assert list(tau) == [2, 3, 3]

        tau1 &= tau2
        assert list(tau1) == [2, 3, 3]

        # NOTE: tau1 now holds [2, 3, 3] after the &= above, so the three-way
        # combination is [2+1+0, 3+1+0, 3+0+2]. Pin it explicitly so that a
        # defect shared by `&` and `cat` cannot hide behind the comparison below.
        tau123 = (tau1 & tau2) & tau3
        assert list(tau123) == [3, 4, 5]

        assert SO3Tau.cat([tau1, tau2, tau3]) == tau123