def test_SO3Vec_init_channels(self, batch, maxl, channels):
    """Check that a random SO3Vec built with a uniform channel count reports that tau.

    Every l from 0 to ``maxl`` (inclusive) is given ``channels`` channels.
    """
    expected_tau = [channels for _ in range(maxl + 1)]
    vec = SO3Vec.rand(batch, expected_tau)
    assert vec.tau == expected_tau
def test_SO3Vec_init_arb_tau(self, batch, maxl, channels):
    """Check that a random SO3Vec built with a random per-l channel count reports that tau.

    For each l in ``0..maxl`` the channel count is drawn uniformly from
    ``1..channels`` (``torch.randint``'s upper bound is exclusive).
    """
    # Convert to plain Python ints so SO3Vec.rand receives, and the equality
    # check compares against, an ordinary list of ints rather than a
    # torch.Tensor, whose comparison semantics against the stored tau are
    # not guaranteed.
    tau_list = torch.randint(1, channels + 1, [maxl + 1]).tolist()
    test_vec = SO3Vec.rand(batch, tau_list)
    assert test_vec.tau == tau_list
def test_mix_SO3Vec(batch, maxl, channels1, channels2):
    """Mix a random SO3Vec with a random SO3Weight and check the result's channels.

    The vector has ``channels1`` channels at every l from 0 to ``maxl``; the
    weight is built from ``(tau_in, tau_out)`` with ``channels2`` output
    channels at every l.
    """
    tau_in = [channels1] * (maxl + 1)
    tau_out = [channels2] * (maxl + 1)
    test_vec = SO3Vec.rand(batch, tau_in)
    test_weight = SO3Weight.rand(tau_in, tau_out)
    mixed = mix(test_weight, test_vec)
    # The weight is constructed as tau_in -> tau_out, so the mixed result is
    # expected to carry tau_out channels at every l.
    assert list(mixed.tau) == tau_out
def test_cg_product_dict_maxl(self, maxl_dict, maxl_prod, maxl1, maxl2, chan, batch):
    """Check cg_product against the maxl of the CG dictionary it is given.

    If ``maxl_dict`` is at least ``maxl_prod``, ``maxl1`` and ``maxl2``, the
    product must succeed and its tau must match the tau predicted by
    ``cg_product_tau``. Otherwise ``cg_product`` must raise a ValueError whose
    message starts with 'CG Dictionary maxl'.
    """
    cg_dict = CGDict(maxl=maxl_dict, dtype=torch.double)

    tau1, tau2 = [chan] * (maxl1 + 1), [chan] * (maxl2 + 1)

    rep1 = SO3Vec.rand(batch, tau1, dtype=torch.double)
    rep2 = SO3Vec.rand(batch, tau2, dtype=torch.double)

    if all(maxl_dict >= maxl for maxl in [maxl_prod, maxl1, maxl2]):
        cg_prod = cg_product(cg_dict, rep1, rep2, maxl=maxl_prod)

        tau_out = cg_prod.tau
        # cg_product_tau is called without a maxl, while the product itself was
        # computed with maxl=maxl_prod. NOTE(review): assumes cg_product drops
        # every l > maxl_prod, so only the first maxl_prod + 1 entries of the
        # prediction are compared — confirm against cg_product.
        tau_pred = list(cg_product_tau(tau1, tau2))[:maxl_prod + 1]

        # Test to make sure the output type matches the expected output type
        assert list(tau_out) == tau_pred
    else:
        # The error message is checked only on this branch: e_info is bound only
        # here, and because cg_product raises, no product result exists to inspect.
        with pytest.raises(ValueError) as e_info:
            cg_product(cg_dict, rep1, rep2, maxl=maxl_prod)
        assert str(e_info.value).startswith('CG Dictionary maxl')
def test_apply_euler(self, batch, channels, maxl):
    """Smoke-test so3_torch.apply_wigner on a random double-precision SO3Vec.

    The vector has ``channels`` channels at each l up to ``maxl`` and the
    Wigner-D matrices come from ``SO3WignerD.euler``. Nothing is asserted: the
    test only checks that the call completes without raising.
    """
    vec_tau = SO3Tau([channels for _ in range(maxl + 1)])
    vec = SO3Vec.rand(batch, vec_tau, dtype=torch.double)
    wigner_d = SO3WignerD.euler(maxl, dtype=torch.double)
    so3_torch.apply_wigner(wigner_d, vec)