def test_cg_prod_tau_check(self, maxl1, maxl2, chan1, chan2, set_tau1, set_tau2):
    """Check how CGProduct validates explicitly supplied input taus.

    ``tau1``/``tau2`` are lists with ``chan1``/``chan2`` repeated
    ``maxl1 + 1``/``maxl2 + 1`` times. Each is passed to the constructor only
    when its ``set_tau*`` flag is true (otherwise ``None`` is passed).

    * Both taus given with different channel counts: construction must raise
      ``ValueError``.
    * Both taus given otherwise: reading ``tau_out`` must succeed.
    * At least one tau missing: reading ``tau_out`` must raise ``ValueError``.
    """
    tau1 = [chan1] * (maxl1 + 1)
    tau2 = [chan2] * (maxl2 + 1)

    tau1_in = tau1 if set_tau1 else None
    tau2_in = tau2 if set_tau2 else None
    both_set = set_tau1 and set_tau2

    if both_set and chan1 != chan2:
        with pytest.raises(ValueError):
            CGProduct(tau1_in, tau2_in, maxl=2)
        # Construction failed as expected; there is no object left to inspect.
        return

    cg_prod = CGProduct(tau1_in, tau2_in, maxl=2)
    if both_set:
        # Only checks that the property is readable without raising.
        _ = cg_prod.tau_out
    else:
        with pytest.raises(ValueError):
            _ = cg_prod.tau_out
def test_cg_prod_double(self, maxl, dtype):
    """Check that ``.double()`` converts the module and its CG coefficients to float64."""
    cg_prod = CGProduct(maxl=maxl, dtype=dtype)
    cg_prod.double()
    assert cg_prod.dtype == torch.double
    assert cg_prod.cg_dict.dtype == torch.double
    # Check the dtype (not the device) of every stored coefficient tensor.
    assert all(t.dtype == torch.double for t in cg_prod.cg_dict.values())
def test_cg_prod_float(self, maxl, dtype):
    """Check that ``.float()`` converts the module and its CG coefficients to float32."""
    cg_prod = CGProduct(maxl=maxl, dtype=dtype)
    cg_prod.float()
    assert cg_prod.dtype == torch.float
    assert cg_prod.cg_dict.dtype == torch.float
    # Check the dtype (not the device) of every stored coefficient tensor.
    assert all(t.dtype == torch.float for t in cg_prod.cg_dict.values())
def test_cg_prod_half(self, maxl, dtype):
    """Check that ``.half()`` converts the module and its CG coefficients to float16."""
    cg_prod = CGProduct(maxl=maxl, dtype=dtype)
    cg_prod.half()
    assert cg_prod.dtype == torch.half
    assert cg_prod.cg_dict.dtype == torch.half
    # Check the dtype (not the device) of every stored coefficient tensor.
    assert all(t.dtype == torch.half for t in cg_prod.cg_dict.values())
def test_cg_prod_to_device(self, device1, device2):
    """Moving a CGProduct with ``.to(device=...)`` must also move its CG tensors."""
    prod = CGProduct(maxl=1, device=device1)
    prod.to(device=device2)

    assert prod.device == device2
    coeffs = prod.cg_dict
    assert coeffs.device == device2
    assert all(tensor.device == device2 for tensor in coeffs.values())
def test_cg_prod_cpu(self, maxl, device):
    """``.cpu()`` must place the module and every CG coefficient tensor on the CPU."""
    prod = CGProduct(maxl=maxl, device=device)
    prod.cpu()

    host = torch.device('cpu')
    assert prod.device == host
    assert prod.cg_dict.device == host
    assert all(tensor.device == host for tensor in prod.cg_dict.values())
def test_cg_prod_cuda(self, maxl, device):
    """Check that ``.cuda()`` moves the module and its CG coefficients to a GPU.

    The test is reported as skipped (instead of silently passing) when CUDA is
    not available.
    """
    if not torch.cuda.is_available():
        pytest.skip('CUDA is not available')
    cg_prod = CGProduct(maxl=maxl, device=device)
    cg_prod.cuda()
    # Compare the device *type*: tensors moved with ``.cuda()`` report an
    # indexed device such as ``cuda:0``, which does not compare equal to the
    # unindexed ``torch.device('cuda')``.
    assert cg_prod.device.type == 'cuda'
    assert cg_prod.cg_dict.device.type == 'cuda'
    assert all(t.device.type == 'cuda' for t in cg_prod.cg_dict.values())
def test_cg_prod_set_from_cg_dict(self, maxl, dtype):
    """Check building a CGProduct on top of an existing float32 CGDict (maxl=1).

    Requesting ``torch.half`` or ``torch.double`` conflicts with the dict's
    dtype and must raise ``ValueError``. Otherwise the product must be float32
    on the CPU, and ``maxl`` must default to 1 when not given. The dict's
    ``maxl`` must be at least 1 and at least the requested ``maxl``.
    """
    cg_dict = CGDict(maxl=1, dtype=torch.float)

    if dtype in [torch.half, torch.double]:
        # If the data type in CGProduct does not match CGDict, throw an error.
        with pytest.raises(ValueError):
            CGProduct(maxl=maxl, dtype=dtype, cg_dict=cg_dict)
        return

    cg_prod = CGProduct(maxl=maxl, dtype=dtype, cg_dict=cg_dict)
    # The conditional expressions are parenthesized on purpose. Without the
    # parentheses, ``a == b if c else d`` parses as ``(a == b) if c else d``,
    # and the ``else`` branch would assert a bare truthy value.
    assert cg_prod.dtype == (torch.float if dtype is None else dtype)
    assert cg_prod.device == torch.device('cpu')
    assert cg_prod.maxl == (maxl if maxl is not None else 1)
    assert cg_prod.cg_dict
    assert cg_prod.cg_dict.maxl == (max(1, maxl) if maxl is not None else 1)
def test_cg_prod_cg_dict_dtype(self, maxl, dtype):
    """Check that a CGProduct builds its own CGDict with the requested settings.

    ``dtype=None`` must fall back to ``torch.float``; the device must be the
    CPU, and both the product and its dict must use the requested ``maxl``.
    """
    cg_prod = CGProduct(maxl=maxl, dtype=dtype)
    # Parenthesized so the comparison applies to both branches of the
    # conditional expression (otherwise the non-None branch is vacuous).
    assert cg_prod.dtype == (torch.float if dtype is None else dtype)
    assert cg_prod.device == torch.device('cpu')
    assert cg_prod.maxl == maxl
    assert cg_prod.cg_dict
    assert cg_prod.cg_dict.maxl == maxl
def __init__(self, tau_in, tau_pos, maxl, num_channels, level_gain, weight_init,
             device=None, dtype=None, cg_dict=None):
    """Build this level's two CG products and the layer that mixes their outputs.

    Parameters
    ----------
    tau_in :
        Type of the input representation. It is used as an operand of both
        CG products and as one of the reps concatenated by ``CatMixReps``.
    tau_pos :
        Type of the first operand of the aggregating CG product.
        NOTE(review): presumably the positional/edge representation — confirm.
    maxl : int
        Maximum ``l``. It is passed to the base class, and the resolved
        ``self.maxl`` is used for all sub-modules.
    num_channels :
        Channel count forwarded to ``CatMixReps``.
    level_gain :
        Forwarded to ``CatMixReps`` as ``gain``.
    weight_init :
        Forwarded to ``CatMixReps`` as ``weight_init``.
    device, dtype, cg_dict : optional
        Forwarded to the base class. The sub-modules receive ``self.device``,
        ``self.dtype`` and ``self.cg_dict`` rather than these raw arguments.
    """
    super().__init__(maxl=maxl, device=device, dtype=dtype, cg_dict=cg_dict)

    self.tau_in = tau_in
    self.tau_pos = tau_pos

    # Operations linear in input reps: aggregated CG product of tau_pos x tau_in.
    self.cg_aggregate = CGProduct(tau_pos, tau_in, maxl=self.maxl, aggregate=True,
                                  device=self.device, dtype=self.dtype,
                                  cg_dict=self.cg_dict)
    tau_ag = list(self.cg_aggregate.tau)

    # CG product of the input with itself.
    self.cg_power = CGProduct(tau_in, tau_in, maxl=self.maxl,
                              device=self.device, dtype=self.dtype,
                              cg_dict=self.cg_dict)
    tau_sq = list(self.cg_power.tau)

    # Combine [aggregate, input, square] types with CatMixReps.
    self.cat_mix = CatMixReps([tau_ag, tau_in, tau_sq], num_channels,
                              maxl=self.maxl, weight_init=weight_init,
                              gain=level_gain, device=self.device,
                              dtype=self.dtype)
    self.tau = self.cat_mix.tau
def test_CGProduct(self, batch, maxl1, maxl2, maxl, channels):
    """CGProduct must commute with rotations.

    Applying the Wigner-D matrices from ``rot.gen_rot`` to both inputs before
    the product must give the same result as applying them to the product's
    output, up to an absolute tolerance of 1e-6.
    """
    top_l = max(maxl1, maxl2, maxl)
    wigner_d, _rot_mat, _ = rot.gen_rot(top_l)

    shared_dict = CGDict(maxl=top_l, dtype=torch.double)
    product = CGProduct(maxl=maxl, dtype=torch.double, cg_dict=shared_dict)

    left_tau = SO3Tau([channels] * (maxl1 + 1))
    right_tau = SO3Tau([channels] * (maxl2 + 1))
    left = SO3Vec.randn(left_tau, batch, dtype=torch.double)
    right = SO3Vec.randn(right_tau, batch, dtype=torch.double)

    left_rot = left.apply_wigner(wigner_d, dir='left')
    right_rot = right.apply_wigner(wigner_d, dir='left')

    plain = product(left, right)
    rotate_then_multiply = product(left_rot, right_rot)
    multiply_then_rotate = plain.apply_wigner(wigner_d, dir='left')

    errors = [
        (a - b).abs().max()
        for a, b in zip(rotate_then_multiply, multiply_then_rotate)
    ]
    assert all(err < 1e-6 for err in errors)
def test_no_maxl_w_cg_dict(self, maxl):
    """Omitting ``maxl`` is allowed when a ``cg_dict`` is supplied instead."""
    prod = CGProduct(cg_dict=CGDict(maxl=maxl))
    assert prod.cg_dict is not None
    assert prod.maxl is not None
def test_no_maxl(self):
    """Constructing a CGProduct with neither ``maxl`` nor ``cg_dict`` must fail."""
    with pytest.raises(ValueError):
        CGProduct()
def test_cg_prod_to(self, dtype1, dtype2, device1, device2):
    """Check that ``.to(device, dtype)`` converts the module and its CG tensors."""
    cg_prod = CGProduct(maxl=1, dtype=dtype1, device=device1)
    cg_prod.to(device2, dtype2)
    assert cg_prod.dtype == dtype2
    assert cg_prod.cg_dict.dtype == dtype2
    assert all(t.dtype == dtype2 for t in cg_prod.cg_dict.values())
    assert cg_prod.device == device2
    assert cg_prod.cg_dict.device == device2
    assert all(t.device == device2 for t in cg_prod.cg_dict.values())

# Check that .half() works as expected
@pytest.mark.parametrize('dtype', [None, torch.half, torch.float, torch.double])
def test_cg_prod_half(self, maxl, dtype):
    """Check that ``.half()`` converts the module and its CG coefficients to float16."""
    cg_prod = CGProduct(maxl=maxl, dtype=dtype)
    cg_prod.half()
    assert cg_prod.dtype == torch.half
    assert cg_prod.cg_dict.dtype == torch.half
    # Check the dtype (not the device) of every stored coefficient tensor.
    assert all(t.dtype == torch.half for t in cg_prod.cg_dict.values())

# Check that .float() works as expected
@pytest.mark.parametrize('dtype', [None, torch.half, torch.float, torch.double])
def test_cg_prod_float(self, maxl, dtype):
    """Check that ``.float()`` converts the module and its CG coefficients to float32."""
    cg_prod = CGProduct(maxl=maxl, dtype=dtype)
    cg_prod.float()
    assert cg_prod.dtype == torch.float
    assert cg_prod.cg_dict.dtype == torch.float
    # Check the dtype (not the device) of every stored coefficient tensor.
    assert all(t.dtype == torch.float for t in cg_prod.cg_dict.values())

# Check that .double() works as expected
@pytest.mark.parametrize('dtype', [None, torch.half, torch.float, torch.double])
def test_cg_prod_double(self, maxl, dtype):
    """Check that ``.double()`` converts the module and its CG coefficients to float64."""
    cg_prod = CGProduct(maxl=maxl, dtype=dtype)
    cg_prod.double()
    assert cg_prod.dtype == torch.double
    assert cg_prod.cg_dict.dtype == torch.double
    # Check the dtype (not the device) of every stored coefficient tensor.
    assert all(t.dtype == torch.double for t in cg_prod.cg_dict.values())

# Check that .cpu() works as expected
@pytest.mark.parametrize('device', devices)
def test_cg_prod_cpu(self, maxl, device):
    """Check that ``.cpu()`` moves the module and its CG coefficients to the CPU."""
    cg_prod = CGProduct(maxl=maxl, device=device)
    cg_prod.cpu()
    assert cg_prod.device == torch.device('cpu')
    assert cg_prod.cg_dict.device == torch.device('cpu')
    assert all(t.device == torch.device('cpu') for t in cg_prod.cg_dict.values())

# Check that .cuda() works as expected
@pytest.mark.parametrize('device', devices)
def test_cg_prod_cuda(self, maxl, device):
    """Check that ``.cuda()`` moves the module and its CG coefficients to a GPU.

    The test is reported as skipped (instead of silently passing) when CUDA is
    not available.
    """
    if not torch.cuda.is_available():
        pytest.skip('CUDA is not available')
    cg_prod = CGProduct(maxl=maxl, device=device)
    cg_prod.cuda()
    # Compare the device *type*: tensors moved with ``.cuda()`` report an
    # indexed device such as ``cuda:0``, which does not compare equal to the
    # unindexed ``torch.device('cuda')``.
    assert cg_prod.device.type == 'cuda'
    assert cg_prod.cg_dict.device.type == 'cuda'
    assert all(t.device.type == 'cuda' for t in cg_prod.cg_dict.values())