Example #1
    def test_cg_prod_tau_check(self, maxl1, maxl2, chan1, chan2, set_tau1,
                               set_tau2):
        # Build a random representation: for each weight l, a complex-valued tensor
        # of shape nbatch + (tau[l], 2*l + 1, 2), i.e. tau[l] channels per weight.
        rand_rep = lambda tau, nbatch: [
            torch.rand(nbatch + (t, 2 * l + 1, 2)).double()
            for l, t in enumerate(tau)
        ]

        tau1 = [chan1] * (maxl1 + 1)
        tau2 = [chan2] * (maxl2 + 1)

        rep1 = rand_rep(tau1, (2, ))
        rep2 = rand_rep(tau2, (2, ))

        tau1_in = tau1 if set_tau1 else None
        tau2_in = tau2 if set_tau2 else None

        if (set_tau1 and set_tau2) and chan1 != chan2:
            with pytest.raises(ValueError) as e:
                cg_prod = CGProduct(tau1_in, tau2_in, maxl=2)
            return
        else:
            cg_prod = CGProduct(tau1_in, tau2_in, maxl=2)

        if set_tau1 and set_tau2:
            tau_out = cg_prod.tau_out
        else:
            with pytest.raises(ValueError) as e:
                tau_out = cg_prod.tau_out
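
As a side note, here is a minimal standalone sketch of the tau bookkeeping exercised above: when both taus are given and their channel counts agree, the output type is available before any data is passed in. The import path and the concrete numbers are assumptions for illustration only.

from cormorant.cg_lib import CGProduct  # import path assumed

tau1 = [3, 3]        # 3 channels each for l = 0, 1
tau2 = [3, 3, 3]     # 3 channels each for l = 0, 1, 2
cg_prod = CGProduct(tau1, tau2, maxl=2)
print(cg_prod.tau_out)  # multiplicities of the product representation, truncated at maxl=2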
Example #2
        def test_cg_prod_double(self, maxl, dtype):

            cg_prod = CGProduct(maxl=maxl, dtype=dtype)
            cg_prod.double()
            assert cg_prod.dtype == torch.double
            assert cg_prod.cg_dict.dtype == torch.double
            assert all(
                [t.dtype == torch.double for t in cg_prod.cg_dict.values()])
Example #3
        def test_cg_prod_float(self, maxl, dtype):

            cg_prod = CGProduct(maxl=maxl, dtype=dtype)
            cg_prod.float()
            assert cg_prod.dtype == torch.float
            assert cg_prod.cg_dict.dtype == torch.float
            assert all(
                [t.dtype == torch.float for t in cg_prod.cg_dict.values()])
Example #4
        def test_cg_prod_half(self, maxl, dtype):

            cg_prod = CGProduct(maxl=maxl, dtype=dtype)
            cg_prod.half()
            assert cg_prod.dtype == torch.half
            assert cg_prod.cg_dict.dtype == torch.half
            assert all(
                [t.dtype == torch.half for t in cg_prod.cg_dict.values()])
Example #5
    def test_cg_prod_to_device(self, device1, device2):

        cg_prod = CGProduct(maxl=1, device=device1)

        cg_prod.to(device=device2)
        assert cg_prod.device == device2
        assert cg_prod.cg_dict.device == device2
        assert all([t.device == device2 for t in cg_prod.cg_dict.values()])
Example #6
        def test_cg_prod_cpu(self, maxl, device):

            cg_prod = CGProduct(maxl=maxl, device=device)
            cg_prod.cpu()
            assert cg_prod.device == torch.device('cpu')
            assert cg_prod.cg_dict.device == torch.device('cpu')
            assert all([
                t.device == torch.device('cpu')
                for t in cg_prod.cg_dict.values()
            ])
Example #7
        def test_cg_prod_cuda(self, maxl, device):

            if not torch.cuda.is_available():
                return

            cg_prod = CGProduct(maxl=maxl, device=device)
            cg_prod.cuda()
            assert cg_prod.device == torch.device('cuda')
            assert cg_prod.cg_dict.device == torch.device('cuda')
            assert all([
                t.device == torch.device('cuda')
                for t in cg_prod.cg_dict.values()
            ])
Example #8
    def test_cg_prod_set_from_cg_dict(self, maxl, dtype):

        cg_dict = CGDict(maxl=1, dtype=torch.float)

        if dtype in [torch.half, torch.double]:
            # If the CGProduct dtype does not match the CGDict dtype, a ValueError is raised
            with pytest.raises(ValueError):
                cg_prod = CGProduct(maxl=maxl, dtype=dtype, cg_dict=cg_dict)
        else:
            cg_prod = CGProduct(maxl=maxl, dtype=dtype, cg_dict=cg_dict)

            assert cg_prod.dtype == (torch.float if dtype is None else dtype)
            assert cg_prod.device == torch.device('cpu')
            assert cg_prod.maxl == (maxl if maxl is not None else 1)
            assert cg_prod.cg_dict
            assert cg_prod.cg_dict.maxl == (max(1, maxl)
                                            if maxl is not None else 1)
Example #9
    def test_cg_prod_cg_dict_dtype(self, maxl, dtype):

        cg_prod = CGProduct(maxl=maxl, dtype=dtype)
        assert cg_prod.dtype == (torch.float if dtype is None else dtype)
        assert cg_prod.device == torch.device('cpu')
        assert cg_prod.maxl == maxl
        assert cg_prod.cg_dict
        assert cg_prod.cg_dict.maxl == maxl
Example #10
    def __init__(self,
                 tau_in,
                 tau_pos,
                 maxl,
                 num_channels,
                 level_gain,
                 weight_init,
                 device=None,
                 dtype=None,
                 cg_dict=None):
        super().__init__(maxl=maxl,
                         device=device,
                         dtype=dtype,
                         cg_dict=cg_dict)
        device, dtype, cg_dict = self.device, self.dtype, self.cg_dict

        self.tau_in = tau_in
        self.tau_pos = tau_pos

        # Operations linear in input reps
        self.cg_aggregate = CGProduct(tau_pos,
                                      tau_in,
                                      maxl=self.maxl,
                                      aggregate=True,
                                      device=self.device,
                                      dtype=self.dtype,
                                      cg_dict=self.cg_dict)
        tau_ag = list(self.cg_aggregate.tau)

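        # Quadratic CG self-product of the input representations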
        self.cg_power = CGProduct(tau_in,
                                  tau_in,
                                  maxl=self.maxl,
                                  device=self.device,
                                  dtype=self.dtype,
                                  cg_dict=self.cg_dict)
        tau_sq = list(self.cg_power.tau)

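        # Concatenate the aggregated, input, and squared reps, then mix to num_channels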
        self.cat_mix = CatMixReps([tau_ag, tau_in, tau_sq],
                                  num_channels,
                                  maxl=self.maxl,
                                  weight_init=weight_init,
                                  gain=level_gain,
                                  device=self.device,
                                  dtype=self.dtype)
        self.tau = self.cat_mix.tau
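
A hedged sketch of how this level module might be instantiated; the class name CGLevel and every concrete value below are hypothetical placeholders meant only to illustrate the constructor signature above.

# Hypothetical usage; CGLevel stands in for the (unnamed) class defined above.
tau_in = [4, 4, 4]    # input multiplicities for l = 0, 1, 2
tau_pos = [1, 1, 1]   # position (e.g. spherical-harmonic) multiplicities
level = CGLevel(tau_in, tau_pos, maxl=2, num_channels=8,
                level_gain=1.0, weight_init='randn')
print(level.tau)      # output type produced by the final CatMixReps mixing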
Example #11
    def test_CGProduct(self, batch, maxl1, maxl2, maxl, channels):
        maxl_all = max(maxl1, maxl2, maxl)
        D, R, _ = rot.gen_rot(maxl_all)

        cg_dict = CGDict(maxl=maxl_all, dtype=torch.double)
        cg_prod = CGProduct(maxl=maxl, dtype=torch.double, cg_dict=cg_dict)

        tau1 = SO3Tau([channels] * (maxl1 + 1))
        tau2 = SO3Tau([channels] * (maxl2 + 1))

        vec1 = SO3Vec.randn(tau1, batch, dtype=torch.double)
        vec2 = SO3Vec.randn(tau2, batch, dtype=torch.double)

        vec1i = vec1.apply_wigner(D, dir='left')
        vec2i = vec2.apply_wigner(D, dir='left')

        vec_prod = cg_prod(vec1, vec2)
        veci_prod = cg_prod(vec1i, vec2i)

        vecf_prod = vec_prod.apply_wigner(D, dir='left')

        # Equivariance check: rotating the inputs and then taking the CG product
        # should agree with taking the product first and then rotating.
        diff = [(p1 - p2).abs().max() for p1, p2 in zip(veci_prod, vecf_prod)]
        assert all([d < 1e-6 for d in diff])
Example #12
    def test_no_maxl_w_cg_dict(self, maxl):
        cg_dict = CGDict(maxl=maxl)
        cg_prod = CGProduct(cg_dict=cg_dict)

        assert cg_prod.cg_dict is not None
        assert cg_prod.maxl is not None
Example #13
    def test_no_maxl(self):
        with pytest.raises(ValueError) as e_info:
            cg_prod = CGProduct()
Example #14
    def test_cg_prod_to(self, dtype1, dtype2, device1, device2):

        cg_prod = CGProduct(maxl=1, dtype=dtype1, device=device1)

        cg_prod.to(device2, dtype2)
        assert cg_prod.dtype == dtype2
        assert cg_prod.cg_dict.dtype == dtype2
        assert all([t.dtype == dtype2 for t in cg_prod.cg_dict.values()])
        assert cg_prod.device == device2
        assert cg_prod.cg_dict.device == device2
        assert all([t.device == device2 for t in cg_prod.cg_dict.values()])

    # Check that .half() works as expected
    @pytest.mark.parametrize('dtype',
                             [None, torch.half, torch.float, torch.double])
    def test_cg_prod_half(self, maxl, dtype):

        cg_prod = CGProduct(maxl=maxl, dtype=dtype)
        cg_prod.half()
        assert cg_prod.dtype == torch.half
        assert cg_prod.cg_dict.dtype == torch.half
        assert all(
            [t.dtype == torch.half for t in cg_prod.cg_dict.values()])

    # Check that .float() works as expected
    @pytest.mark.parametrize('dtype',
                             [None, torch.half, torch.float, torch.double])
    def test_cg_prod_float(self, maxl, dtype):

        cg_prod = CGProduct(maxl=maxl, dtype=dtype)
        cg_prod.float()
        assert cg_prod.dtype == torch.float
        assert cg_prod.cg_dict.dtype == torch.float
        assert all(
            [t.dtype == torch.float for t in cg_prod.cg_dict.values()])

    # Check that .double() works as expected
    @pytest.mark.parametrize('dtype',
                             [None, torch.half, torch.float, torch.double])
    def test_cg_prod_double(self, maxl, dtype):

        cg_prod = CGProduct(maxl=maxl, dtype=dtype)
        cg_prod.double()
        assert cg_prod.dtype == torch.double
        assert cg_prod.cg_dict.dtype == torch.double
        assert all(
            [t.dtype == torch.double for t in cg_prod.cg_dict.values()])

    # Check that .cpu() works as expected
    @pytest.mark.parametrize('device', devices)
    def test_cg_prod_cpu(self, maxl, device):

        cg_prod = CGProduct(maxl=maxl, device=device)
        cg_prod.cpu()
        assert cg_prod.device == torch.device('cpu')
        assert cg_prod.cg_dict.device == torch.device('cpu')
        assert all([
            t.device == torch.device('cpu')
            for t in cg_prod.cg_dict.values()
        ])

    # Check that .cuda() works as expected
    @pytest.mark.parametrize('device', devices)
    def test_cg_prod_cuda(self, maxl, device):

        if not torch.cuda.is_available():
            return

        cg_prod = CGProduct(maxl=maxl, device=device)
        cg_prod.cuda()
        assert cg_prod.device == torch.device('cuda')
        assert cg_prod.cg_dict.device == torch.device('cuda')
        assert all([
            t.device == torch.device('cuda')
            for t in cg_prod.cg_dict.values()
        ])
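
Several of the examples above hand a precomputed CGDict to CGProduct (Examples #8, #11, #12). A minimal sketch of that sharing pattern, with the import path assumed:

import torch
from cormorant.cg_lib import CGDict, CGProduct  # import path assumed

cg_dict = CGDict(maxl=3, dtype=torch.double)  # build the CG coefficient tables once
prod_a = CGProduct(maxl=2, dtype=torch.double, cg_dict=cg_dict)
prod_b = CGProduct(maxl=3, dtype=torch.double, cg_dict=cg_dict)
# Both products read their coefficients from the shared dictionary; the dtype
# must match the dictionary's, as Example #8 demonstrates.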