Example #1
0
    def test_cg_mod_double(self, maxl, dtype):
        """Check that ``CGModule.double()`` casts the module and its cg_dict to ``torch.double``."""
        module = CGModule(maxl=maxl, dtype=dtype)
        module.double()

        target = torch.double
        assert module.dtype == target
        assert module.cg_dict.dtype == target
        # Every tensor held by the cg_dict must have been cast as well.
        for tensor in module.cg_dict.values():
            assert tensor.dtype == target
Example #2
0
    def test_cg_mod_float(self, maxl, dtype):
        """Check that ``CGModule.float()`` casts the module and its cg_dict to ``torch.float``."""
        module = CGModule(maxl=maxl, dtype=dtype)
        module.float()

        target = torch.float
        assert module.dtype == target
        assert module.cg_dict.dtype == target
        # Every tensor held by the cg_dict must have been cast as well.
        for tensor in module.cg_dict.values():
            assert tensor.dtype == target
Example #3
0
    def test_cg_mod_to_device(self, device1, device2):
        """Check that ``CGModule.to(device=...)`` relocates the module and its cg_dict."""
        module = CGModule(maxl=1, device=device1)
        module.to(device=device2)

        assert module.device == device2
        assert module.cg_dict.device == device2
        # Each tensor stored in the cg_dict should follow the module.
        # NOTE(review): if device2 can be an index-less ``torch.device('cuda')``,
        # tensors report ``cuda:0`` and this equality fails — confirm the
        # fixture values.
        for tensor in module.cg_dict.values():
            assert tensor.device == device2
Example #4
0
    def test_cg_mod_cpu(self, maxl, device):
        """Check that ``CGModule.cpu()`` moves the module and every cg_dict tensor to the CPU."""
        cpu = torch.device('cpu')
        module = CGModule(maxl=maxl, device=device)
        module.cpu()

        assert module.device == cpu
        assert module.cg_dict.device == cpu
        for tensor in module.cg_dict.values():
            assert tensor.device == cpu
Example #5
0
    def test_cg_mod_half(self, maxl, dtype):
        """Check that ``CGModule.half()`` casts the module and its cg_dict to ``torch.half``."""
        # Leftover debug ``print`` calls removed: they only added noise to the
        # captured test output and asserted nothing.
        cg_mod = CGModule(maxl=maxl, dtype=dtype)
        cg_mod.half()
        assert cg_mod.dtype == torch.half
        assert cg_mod.cg_dict.dtype == torch.half
        assert all(t.dtype == torch.half for t in cg_mod.cg_dict.values())
Example #6
0
    def test_cg_mod_device(self, dtype):
        """Construct a CGModule from ``dtype`` alone and check its defaults.

        ``torch.long`` is expected to be rejected with ``ValueError``. Any other
        dtype should yield a CPU module that keeps the requested dtype and has
        neither a ``maxl`` nor a ``cg_dict``.
        """
        if dtype == torch.long:
            with pytest.raises(ValueError):
                CGModule(dtype=dtype)
            return

        module = CGModule(dtype=dtype)
        assert module.dtype == dtype
        assert module.device == torch.device('cpu')
        assert module.maxl is None
        assert module.cg_dict is None
Example #7
0
    def test_cg_mod_cuda(self, maxl, device):
        """Check that ``CGModule.cuda()`` moves the module and every cg_dict tensor to CUDA.

        The test is skipped (not silently passed) when CUDA is unavailable.
        """
        if not torch.cuda.is_available():
            # A bare ``return`` would be reported as a pass; make the skip visible.
            pytest.skip('CUDA is not available')

        cg_mod = CGModule(maxl=maxl, device=device)
        cg_mod.cuda()
        # Compare device *types* rather than ``torch.device`` objects. Tensors
        # moved with ``.cuda()`` report an indexed device such as ``cuda:0``,
        # and ``torch.device('cuda:0') != torch.device('cuda')``.
        assert cg_mod.device.type == 'cuda'
        assert cg_mod.cg_dict.device.type == 'cuda'
        assert all(t.device.type == 'cuda' for t in cg_mod.cg_dict.values())
Example #8
0
    def test_cg_mod_set_from_cg_dict(self, maxl, dtype):
        """Construct a CGModule on top of a pre-built CGDict (``maxl=1``, ``torch.float``).

        Requesting ``torch.half`` or ``torch.double`` (a dtype that conflicts with
        the dict) must raise ``ValueError``. Otherwise the test checks:

        * the module's dtype and device;
        * the module's ``maxl``;
        * the ``maxl`` of the module's cg_dict.
        """
        cg_dict = CGDict(maxl=1, dtype=torch.float)

        if dtype in [torch.half, torch.double]:
            # If the data type requested for CGModule does not match the
            # CGDict's, construction must throw an error.
            with pytest.raises(ValueError):
                CGModule(maxl=maxl, dtype=dtype, cg_dict=cg_dict)
        else:
            cg_mod = CGModule(maxl=maxl, dtype=dtype, cg_dict=cg_dict)

            # Expected values are computed up front. The previous inline form
            # ``assert a == b if c else d`` parses as ``assert (a == b) if c else d``,
            # so the ``else`` branch asserted a bare truthy value and the
            # comparison was silently skipped.
            expected_dtype = torch.float if dtype is None else dtype
            expected_maxl = maxl if maxl is not None else 1
            expected_dict_maxl = max(1, maxl) if maxl is not None else 1

            assert cg_mod.dtype == expected_dtype
            assert cg_mod.device == torch.device('cpu')
            assert cg_mod.maxl == expected_maxl
            assert cg_mod.cg_dict
            assert cg_mod.cg_dict.maxl == expected_dict_maxl
Example #9
0
    def test_cg_mod_cg_dict_dtype(self, maxl, dtype):
        """Check that a CGModule built with ``maxl`` owns a cg_dict with the same ``maxl``.

        The module is expected to use ``torch.float`` when ``dtype`` is None and
        to live on the CPU.
        """
        cg_mod = CGModule(maxl=maxl, dtype=dtype)
        # Parenthesised on purpose: ``a == b if c else d`` parses as
        # ``(a == b) if c else d``, which made the non-None case vacuous.
        assert cg_mod.dtype == (torch.float if dtype is None else dtype)
        assert cg_mod.device == torch.device('cpu')
        assert cg_mod.maxl == maxl
        assert cg_mod.cg_dict
        assert cg_mod.cg_dict.maxl == maxl
Example #10
0
 def test_cg_mod_nodict(self):
     """Check that a CGModule created without arguments has no maxl and no cg_dict.

     It is expected to default to the CPU and ``torch.float``.
     """
     expected_device = torch.device('cpu')
     expected_dtype = torch.float

     module = CGModule()
     assert module.maxl is None
     assert not module.cg_dict
     assert module.device == expected_device
     assert module.dtype == expected_dtype