def test_cg_prod_to_device(self, device1, device2):
    """Moving a CGProduct with ``.to(device=...)`` must relocate everything.

    Builds a ``CGProduct`` (``maxl=1``) on ``device1``, moves it to
    ``device2``, then checks three things all report ``device2``:
    the module itself, its ``cg_dict``, and every tensor in ``cg_dict``.
    """
    product = CGProduct(maxl=1, device=device1)
    product.to(device=device2)

    expected = device2
    assert product.device == expected
    assert product.cg_dict.device == expected

    on_target = [tensor.device == expected
                 for tensor in product.cg_dict.values()]
    assert all(on_target)
def test_cg_prod_to(self, dtype1, dtype2, device1, device2):
    """Check that ``.to(device, dtype)`` converts both device and dtype.

    After the call, the module, its ``cg_dict`` and every tensor stored in
    ``cg_dict`` must report ``dtype2`` and ``device2``.
    """
    cg_prod = CGProduct(maxl=1, dtype=dtype1, device=device1)
    cg_prod.to(device2, dtype2)

    assert cg_prod.dtype == dtype2
    assert cg_prod.cg_dict.dtype == dtype2
    assert all(t.dtype == dtype2 for t in cg_prod.cg_dict.values())

    assert cg_prod.device == device2
    assert cg_prod.cg_dict.device == device2
    assert all(t.device == device2 for t in cg_prod.cg_dict.values())


# Check that .half() works as expected
@pytest.mark.parametrize('dtype', [None, torch.half, torch.float, torch.double])
def test_cg_prod_half(self, maxl, dtype):
    """``.half()`` must convert the module, ``cg_dict`` and its tensors to half."""
    cg_prod = CGProduct(maxl=maxl, dtype=dtype)
    cg_prod.half()
    assert cg_prod.dtype == torch.half
    assert cg_prod.cg_dict.dtype == torch.half
    # Precision conversion: check each tensor's dtype. Comparing ``t.device``
    # against a dtype would never be equal.
    assert all(t.dtype == torch.half for t in cg_prod.cg_dict.values())


# Check that .float() works as expected
@pytest.mark.parametrize('dtype', [None, torch.half, torch.float, torch.double])
def test_cg_prod_float(self, maxl, dtype):
    """``.float()`` must convert the module, ``cg_dict`` and its tensors to float32."""
    cg_prod = CGProduct(maxl=maxl, dtype=dtype)
    cg_prod.float()
    assert cg_prod.dtype == torch.float
    assert cg_prod.cg_dict.dtype == torch.float
    # Check the tensors' dtype (not device).
    assert all(t.dtype == torch.float for t in cg_prod.cg_dict.values())


# Check that .double() works as expected
@pytest.mark.parametrize('dtype', [None, torch.half, torch.float, torch.double])
def test_cg_prod_double(self, maxl, dtype):
    """``.double()`` must convert the module, ``cg_dict`` and its tensors to float64."""
    cg_prod = CGProduct(maxl=maxl, dtype=dtype)
    cg_prod.double()
    assert cg_prod.dtype == torch.double
    assert cg_prod.cg_dict.dtype == torch.double
    # Check the tensors' dtype (not device).
    assert all(t.dtype == torch.double for t in cg_prod.cg_dict.values())


# Check that .cpu() works as expected
@pytest.mark.parametrize('device', devices)
def test_cg_prod_cpu(self, maxl, device):
    """``.cpu()`` must move the module, ``cg_dict`` and its tensors to the CPU."""
    cg_prod = CGProduct(maxl=maxl, device=device)
    cg_prod.cpu()
    cpu = torch.device('cpu')
    assert cg_prod.device == cpu
    assert cg_prod.cg_dict.device == cpu
    assert all(t.device == cpu for t in cg_prod.cg_dict.values())


# Check that .cuda() works as expected
@pytest.mark.skipif(not torch.cuda.is_available(),
                    reason="CUDA is not available")
@pytest.mark.parametrize('device', devices)
def test_cg_prod_cuda(self, maxl, device):
    """``.cuda()`` must move the module, ``cg_dict`` and its tensors to CUDA.

    The test is reported as skipped when CUDA is unavailable. A bare ``return``
    would make it look like it passed.
    """
    cg_prod = CGProduct(maxl=maxl, device=device)
    cg_prod.cuda()
    cuda = torch.device('cuda')
    assert cg_prod.device == cuda
    assert cg_prod.cg_dict.device == cuda
    assert all(t.device == cuda for t in cg_prod.cg_dict.values())