def test_cg_dict_init(self, maxl):
    cg_dict = CGDict(maxl=maxl)

    assert len(cg_dict.keys()) == (maxl + 1)**2

    for key, val in cg_dict.items():
        # Each CG coefficient matrix should be orthogonal: C @ C^T == I.
        # (The original `assert val * val.t()` is ill-formed: the truthiness of
        # a multi-element tensor is ambiguous in PyTorch.)
        assert torch.allclose(val @ val.t(), torch.eye(val.shape[0], dtype=val.dtype))
def test_cg_dict_device_dtype(self, maxl, device, dtype):
    if (device == torch.device('cuda')) and (not torch.cuda.is_available()):
        with pytest.raises(AssertionError) as e_info:
            cg_dict = CGDict(maxl=maxl, device=device, dtype=dtype)
    else:
        cg_dict = CGDict(maxl=maxl, device=device, dtype=dtype)
def test_cg_dict_to(self, maxl, dtype1, dtype2):
    cg_dict = CGDict(maxl=maxl, dtype=dtype1)

    assert cg_dict.dtype == cg_dict[(0, 0)].dtype
    assert cg_dict.dtype == dtype1

    cg_dict.to(dtype=dtype2)

    assert cg_dict.dtype == cg_dict[(0, 0)].dtype
    assert cg_dict.dtype == dtype2
def test_cg_dict_update_maxl(self, maxl1, maxl2):
    maxl = max(maxl1, maxl2)

    cg_dict = CGDict(maxl=maxl1)
    cg_dict.update_maxl(maxl2)

    assert cg_dict.maxl == maxl
    assert set(cg_dict.keys()) == {(l1, l2) for l1 in range(maxl + 1)
                                   for l2 in range(maxl + 1)}
def test_cg_product_covariance(self, maxl_cg, maxl1, maxl2, channels, batch):
    maxl = max(maxl_cg, maxl1, maxl2)

    cg_dict = CGDict(maxl=maxl, dtype=torch.double)

    rand_rep = lambda tau, batch: [torch.rand(batch + (t, 2 * l + 1, 2)).double()
                                   for l, t in enumerate(tau)]

    angles = torch.rand(3)
    D, R = gen_rot(angles, maxl)

    tau1 = [channels] * (maxl1 + 1)
    tau2 = [channels] * (maxl2 + 1)

    rep1 = rand_rep(tau1, (batch,))
    rep2 = rand_rep(tau2, (batch,))

    # Rotate the output of the CG product...
    cg_prod = cg_product(cg_dict, rep1, rep2, maxl=maxl_cg)
    cg_prod_rot_out = rot.rotate_rep(D, cg_prod)

    # ...and compare it against the CG product of the rotated inputs.
    rep1_rot = rot.rotate_rep(D, rep1)
    rep2_rot = rot.rotate_rep(D, rep2)
    cg_prod_rot_in = cg_product(cg_dict, rep1_rot, rep2_rot, maxl=maxl_cg)

    for part1, part2 in zip(cg_prod_rot_out, cg_prod_rot_in):
        assert torch.allclose(part1, part2)
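# The covariance identity verified by the test above, written out: for a
# rotation R acting block-wise through the Wigner matrices D(R) = \oplus_l D^l(R),
# the CG product commutes with rotations,
#
#     cg_product(D(R)·rep1, D(R)·rep2) == D(R)·cg_product(rep1, rep2),
#
# so rotating the inputs first or the output last must give the same tensors.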
def test_cg_dict_uninit(self):
    cg_dict = CGDict()

    assert not cg_dict

    with pytest.raises(ValueError) as e_info:
        cg_dict[(0, 0)]
def test_spherical_harmonics_vs_scipy(self, maxl, batch, conj):
    cg_dict = CGDict(maxl=maxl, dtype=torch.double)

    pos = torch.rand(batch + (3,), dtype=torch.double)

    sh = spherical_harmonics(cg_dict, pos, maxl, normalize=True, conj=conj, sh_norm='qm')
    sh_sp = sph_harms_from_scipy(pos, maxl, conj=conj)

    for part1, part2 in zip(sh, sh_sp):
        assert torch.allclose(part1, part2)
def test_cg_mod_set_from_cg_dict(self, maxl, dtype):
    cg_dict = CGDict(maxl=1, dtype=torch.float)

    if dtype in [torch.half, torch.double]:
        # If the data type of CGModule does not match the CGDict, throw an error
        with pytest.raises(ValueError):
            cg_mod = CGModule(maxl=maxl, dtype=dtype, cg_dict=cg_dict)
    else:
        cg_mod = CGModule(maxl=maxl, dtype=dtype, cg_dict=cg_dict)

        # The conditional expressions are parenthesized so the asserts compare
        # against the intended value; without parentheses, Python parses
        # `a == b if cond else c` as `(a == b) if cond else c`, which silently
        # turns the assert into a truthy no-op in the else case.
        assert cg_mod.dtype == (torch.float if dtype is None else dtype)
        assert cg_mod.device == torch.device('cpu')
        assert cg_mod.maxl == (maxl if maxl is not None else 1)
        assert cg_mod.cg_dict
        assert cg_mod.cg_dict.maxl == (max(1, maxl) if maxl is not None else 1)
def test_spherical_rel_harmonics_vs_scipy(self, maxl, batch, natoms1, natoms2, conj):
    cg_dict = CGDict(maxl=maxl, dtype=torch.double)

    pos1 = torch.rand(batch + (natoms1, 3), dtype=torch.double)
    pos2 = torch.rand(batch + (natoms2, 3), dtype=torch.double)

    sh, norms = spherical_harmonics_rel(cg_dict, pos1, pos2, maxl, conj=conj, sh_norm='qm')
    sh_sp, norms_sp = sph_harms_rel_from_scipy(pos1, pos2, maxl, conj=conj)

    for l, (part1, part2) in enumerate(zip(sh, sh_sp)):
        if l == 0:
            continue
        assert torch.allclose(part1, part2)

    assert torch.allclose(norms, norms_sp)
def test_cg_product_dict_maxl(self, maxl_dict, maxl_prod, maxl1, maxl2, chan, batch):
    cg_dict = CGDict(maxl=maxl_dict, dtype=torch.double)

    tau1, tau2 = [chan] * (maxl1 + 1), [chan] * (maxl2 + 1)

    rep1 = SO3Vec.rand(batch, tau1, dtype=torch.double)
    rep2 = SO3Vec.rand(batch, tau2, dtype=torch.double)

    if all(maxl_dict >= maxl for maxl in [maxl_prod, maxl1, maxl2]):
        cg_prod = cg_product(cg_dict, rep1, rep2, maxl=maxl_prod)

        tau_out = cg_prod.tau
        # Truncate the predicted type at maxl_prod, matching the product above.
        tau_pred = cg_product_tau(tau1, tau2, maxl=maxl_prod)

        # Test to make sure the output type matches the expected output type
        assert list(tau_out) == list(tau_pred)
    else:
        # The dictionary is too small for the requested product, so it must fail.
        with pytest.raises(ValueError) as e_info:
            cg_prod = cg_product(cg_dict, rep1, rep2, maxl=maxl_prod)
        assert str(e_info.value).startswith('CG Dictionary maxl')
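# Sketch of the multiplicity bookkeeping checked above (cormorant's channel-wise
# convention, assuming every irrep carries the same channel count C, as in this
# test): each input pair (l1, l2) contributes one copy of C channels to every
# output weight l in the Clebsch-Gordan range, so
#
#     tau_out[l] = C * #{(l1, l2) : |l1 - l2| <= l <= l1 + l2},   for l <= maxl,
#
# with all higher weights truncated away by the maxl cutoff.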
def test_spherical_harmonics(self, maxl, channels, conj):
    D, R, angles = rot.gen_rot(maxl, dtype=torch.double)
    D = SO3WignerD(D)

    if not conj:
        R = R.t()

    pos = torch.randn((channels, 3), dtype=torch.double)
    posr = rot.rotate_cart_vec(R, pos)

    cg_dict = CGDict(maxl, dtype=torch.double)

    sph_harms = spherical_harmonics(cg_dict, pos, maxl, conj=conj)
    sph_harmsr = spherical_harmonics(cg_dict, posr, maxl, conj=conj)

    dir = 'left' if conj else 'right'
    sph_harmsd = so3_torch.apply_wigner(D, sph_harms, dir=dir)

    diff = [(p1 - p2).abs().max() for p1, p2 in zip(sph_harmsr, sph_harmsd)]
    assert all(d < 1e-6 for d in diff)
def test_CGProduct(self, batch, maxl1, maxl2, maxl, channels):
    maxl_all = max(maxl1, maxl2, maxl)
    D, R, _ = rot.gen_rot(maxl_all)

    cg_dict = CGDict(maxl=maxl_all, dtype=torch.double)
    cg_prod = CGProduct(maxl=maxl, dtype=torch.double, cg_dict=cg_dict)

    tau1 = SO3Tau([channels] * (maxl1 + 1))
    tau2 = SO3Tau([channels] * (maxl2 + 1))

    vec1 = SO3Vec.randn(tau1, batch, dtype=torch.double)
    vec2 = SO3Vec.randn(tau2, batch, dtype=torch.double)

    vec1i = vec1.apply_wigner(D, dir='left')
    vec2i = vec2.apply_wigner(D, dir='left')

    vec_prod = cg_prod(vec1, vec2)
    veci_prod = cg_prod(vec1i, vec2i)
    vecf_prod = vec_prod.apply_wigner(D, dir='left')

    diff = [(p1 - p2).abs().max() for p1, p2 in zip(veci_prod, vecf_prod)]
    assert all(d < 1e-6 for d in diff)
import warnings

import torch
from torch import nn

from cormorant.cg_lib import CGDict


class CGModule(nn.Module):
    """
    Clebsch-Gordan module. This functions identically to a normal PyTorch
    nn.Module, except that it adds the ability to specify a Clebsch-Gordan
    dictionary, and has additional tracking behavior set up to allow the
    CG Dictionary to be compatible with the DataParallel module.

    If `cg_dict` is specified upon instantiation, then the specified
    `cg_dict` is set as the Clebsch-Gordan dictionary for the CG module.

    If `cg_dict` is not specified, and `maxl` is specified, then CGModule
    will attempt to set the local `cg_dict` based upon the global
    `cormorant.cg_lib.global_cg_dicts`. If the dictionary has not been initialized
    with the appropriate `dtype`, `device`, and `maxl`, it will be initialized
    and stored in the `global_cg_dicts`, and then set to the local `cg_dict`.

    In this way, if there are many modules that need `CGDicts`, only a single
    `CGDict` will be initialized and automatically set up.

    Parameters
    ----------
    cg_dict : :class:`CGDict`, optional
        Specify an input CGDict to use for Clebsch-Gordan operations.
    maxl : :class:`int`, optional
        Maximum weight to initialize the Clebsch-Gordan dictionary.
    device : :class:`torch.torch.device`, optional
        Device to initialize the module and Clebsch-Gordan dictionary to.
    dtype : :class:`torch.torch.dtype`, optional
        Data type to initialize the module and Clebsch-Gordan dictionary to.
    """
    def __init__(self, cg_dict=None, maxl=None, device=None, dtype=None, *args, **kwargs):
        self._init_device_dtype(device, dtype)
        self._init_cg_dict(cg_dict, maxl)

        super().__init__(*args, **kwargs)

    def _init_device_dtype(self, device, dtype):
        """
        Initialize the default device and data type.

        device : :class:`torch.torch.device`, optional
            Set device for CGDict and related. If unset defaults to torch.device('cpu').

        dtype : :class:`torch.torch.dtype`, optional
            Set dtype for CGDict and related. If unset defaults to torch.float.
        """
        if device is None:
            self._device = torch.device('cpu')
        else:
            self._device = device

        if dtype is None:
            self._dtype = torch.float
        else:
            if not (dtype == torch.half or dtype == torch.float or dtype == torch.double):
                raise ValueError('CG Module only takes internal data types of half/float/double. Got: {}'.format(dtype))
            self._dtype = dtype

    def _init_cg_dict(self, cg_dict, maxl):
        """
        Initialize the Clebsch-Gordan dictionary.

        If `cg_dict` is set, check the following:

        - The dtype of `cg_dict` matches with self.
        - The device of `cg_dict` matches with self.
        - The desired `maxl` <= `cg_dict.maxl`, so that the CGDict will
          contain all necessary coefficients.

        If `cg_dict` is not set, but `maxl` is set, initialize a new CGDict.
        """
        # If cg_dict is defined, check it has the right properties
        if cg_dict is not None:
            if cg_dict.dtype != self.dtype:
                raise ValueError('CGDict dtype ({}) does not match CGModule() dtype ({})'.format(cg_dict.dtype, self.dtype))

            if cg_dict.device != self.device:
                raise ValueError('CGDict device ({}) does not match CGModule() device ({})'.format(cg_dict.device, self.device))

            if maxl is None:
                warnings.warn('maxl is not defined, setting maxl based upon CGDict maxl ({})!'.format(cg_dict.maxl))
                maxl = cg_dict.maxl  # fall back to the dictionary's maxl
            elif maxl > cg_dict.maxl:
                warnings.warn('CGDict maxl ({}) is smaller than CGModule() maxl ({}). Updating!'.format(cg_dict.maxl, maxl))
                cg_dict.update_maxl(maxl)

            self.cg_dict = cg_dict
            self._maxl = maxl

        # If cg_dict is not defined but maxl is, initialize a fresh CGDict
        elif cg_dict is None and maxl is not None:
            self.cg_dict = CGDict(maxl=maxl, device=self.device, dtype=self.dtype)
            self._maxl = maxl

        else:
            self.cg_dict = None
            self._maxl = None

    @property
    def device(self):
        return self._device

    @property
    def dtype(self):
        return self._dtype

    @property
    def maxl(self):
        return self._maxl

    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)

        # _parse_to returns 3 or 4 values depending on the torch version;
        # only device and dtype are needed here.
        device, dtype = torch._C._nn._parse_to(*args, **kwargs)[:2]

        if self.cg_dict is not None:
            self.cg_dict.to(device=device, dtype=dtype)

        if device is not None:
            self._device = device

        if dtype is not None:
            self._dtype = dtype

        return self

    def cuda(self, device=None):
        if device is None:
            device = torch.device('cuda')
        elif device in range(torch.cuda.device_count()):
            device = torch.device('cuda:{}'.format(device))
        else:
            raise ValueError('Incorrect choice of device!')

        super().cuda(device=device)

        if self.cg_dict is not None:
            self.cg_dict.to(device=device)

        self._device = device

        return self

    def cpu(self):
        super().cpu()

        if self.cg_dict is not None:
            self.cg_dict.to(device=torch.device('cpu'))

        self._device = torch.device('cpu')

        return self

    def half(self):
        super().half()

        if self.cg_dict is not None:
            self.cg_dict.to(dtype=torch.half)

        self._dtype = torch.half

        return self

    def float(self):
        super().float()

        if self.cg_dict is not None:
            self.cg_dict.to(dtype=torch.float)

        self._dtype = torch.float

        return self

    def double(self):
        super().double()

        if self.cg_dict is not None:
            self.cg_dict.to(dtype=torch.double)

        self._dtype = torch.double

        return self
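# A minimal usage sketch of the sharing behavior documented above. It relies
# only on behaviors exercised by the tests in this file (direct CGModule
# instantiation and the `cg_dict`/`maxl` attributes): two modules built with
# matching maxl/device/dtype can share one CGDict, so the coefficients are
# computed only once.
block1 = CGModule(maxl=2, dtype=torch.double)
block2 = CGModule(maxl=2, cg_dict=block1.cg_dict, dtype=torch.double)

assert block2.cg_dict is block1.cg_dict  # same dictionary object, reused
assert block2.maxl == 2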
def test_no_maxl_w_cg_dict(self, maxl):
    cg_dict = CGDict(maxl=maxl)
    cg_prod = CGProduct(cg_dict=cg_dict)

    assert cg_prod.cg_dict is not None
    assert cg_prod.maxl is not None
def test_cg_dict_init(self, maxl):
    cg_dict = CGDict(maxl=maxl)

    assert set(cg_dict.keys()) == {(l1, l2) for l1 in range(maxl + 1)
                                   for l2 in range(maxl + 1)}
def test_cg_aggregate(self, maxl_dict, maxl_prod, maxl1, maxl2, chan, batch, atom1, atom2):
    if any(maxl_dict < maxl for maxl in [maxl_prod, maxl1, maxl2]):
        return

    cg_dict = CGDict(maxl=maxl_dict, dtype=torch.double)

    rand_rep = lambda tau, batch: [torch.rand(batch + (t, 2 * l + 1, 2)).double()
                                   for l, t in enumerate(tau)]

    tau1, tau2 = [chan] * (maxl1 + 1), [chan] * (maxl2 + 1)
    batch1 = (batch, atom1, atom2)
    batch2 = (batch, atom2)

    rep1 = rand_rep(tau1, batch1)
    rep2 = rand_rep(tau2, batch2)

    # Calculate CG Aggregate and compare it to explicit calculation
    cg_agg = cg_product(cg_dict, rep1, rep2, maxl=maxl_prod, aggregate=True)

    cg_agg_explicit = [torch.zeros_like(p) for p in cg_agg]
    for bidx in range(batch):
        for aidx1 in range(atom1):
            for aidx2 in range(atom2):
                rep1_sub = [p[bidx, aidx1, aidx2] for p in rep1]
                rep2_sub = [p[bidx, aidx2] for p in rep2]
                cg_out = cg_product(cg_dict, rep1_sub, rep2_sub, maxl=maxl_prod, aggregate=False)
                out_ell = [(p.shape[-2] - 1) // 2 for p in cg_out]
                for ell in out_ell:
                    cg_agg_explicit[ell][bidx, aidx1] += cg_out[ell]

    for part1, part2 in zip(cg_agg, cg_agg_explicit):
        assert torch.allclose(part1, part2)

    # Repeat with the order of the arguments swapped
    cg_agg = cg_product(cg_dict, rep2, rep1, maxl=maxl_prod, aggregate=True)

    cg_agg_explicit = [torch.zeros_like(p) for p in cg_agg]
    for bidx in range(batch):
        for aidx1 in range(atom1):
            for aidx2 in range(atom2):
                rep1_sub = [p[bidx, aidx1, aidx2] for p in rep1]
                rep2_sub = [p[bidx, aidx2] for p in rep2]
                cg_out = cg_product(cg_dict, rep2_sub, rep1_sub, maxl=maxl_prod, aggregate=False)
                out_ell = [(p.shape[-2] - 1) // 2 for p in cg_out]
                for ell in out_ell:
                    cg_agg_explicit[ell][bidx, aidx1] += cg_out[ell]

    for part1, part2 in zip(cg_agg, cg_agg_explicit):
        assert torch.allclose(part1, part2)
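# What aggregate=True computes, written out (this is exactly the explicit loop
# in the test above): the pairwise CG products are summed over the second atom
# index,
#
#     cg_agg[l][b, i] = sum_j cg_product(rep1[b, i, j], rep2[b, j])[l],
#
# so the aggregated output keeps batch shape (batch, atom1) rather than the
# pairwise shape (batch, atom1, atom2).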
def __init__(
    self,
    observation_space: ObservationSpace,
    action_space: ActionSpace,
    min_max_distance: Tuple[float, float],
    network_width: int,
    maxl: int,
    num_cg_levels: int,
    num_channels_hidden: int,
    num_channels_per_element: int,
    num_gaussians: int,
    bag_scale: int,
    beta: Optional[float] = None,
    device=None,
):
    super().__init__(observation_space, action_space)

    self.device = device
    self.dtype = torch.float

    self.zs = self.observation_space.zs
    self.zs_tensor = torch.tensor(self.zs, dtype=self.dtype, device=self.device)

    self.min_distance, self.max_distance = min_max_distance
    assert self.min_distance < self.max_distance

    self.beta = beta

    self.max_sh = maxl
    self.num_cg_levels = num_cg_levels
    self.num_channels_hidden = num_channels_hidden
    self.num_channels_per_element = num_channels_per_element
    self.num_gaussians = num_gaussians

    self.num_channels_out = len(self.zs) * self.num_channels_per_element
    self.channel_offsets = torch.arange(start=0,
                                        end=self.num_channels_per_element,
                                        dtype=torch.long,
                                        device=self.device).unsqueeze(0)

    self.cg_dict = CGDict(maxl=self.max_sh, device=self.device, dtype=self.dtype)

    self.cg_model = Cormorant(
        maxl=self.max_sh,  # Cutoff in CG operations (default: [3])
        max_sh=self.max_sh,  # Number of spherical harmonic powers to use (default: [3])
        num_cg_levels=self.num_cg_levels,  # Number of CG levels (default: 4)
        num_channels=[self.num_channels_hidden] * self.num_cg_levels + [self.num_channels_out],
        num_species=len(self.zs),
        cutoff_type=['soft'],  # Types of cutoffs to include
        hard_cut_rad=min(self.max_distance, 2.1),  # Radius of hard cutoff (in AA)
        soft_cut_rad=min(self.max_distance, 2.1),  # Radius of soft cutoff (in AA)
        soft_cut_width=0.2,  # Width of SOFT cutoff in Angstroms (default: 0.2)
        weight_init='rand',  # Weight initialization function to use (default: rand)
        level_gain=[10.0],  # Gain at each level (default: [10.])
        charge_power=2,  # Maximum power to take in one-hot (default: 2)
        basis_set=[3, 3],  # Radial basis set sizes (default: [3, 3])
        charge_scale=max(self.zs),
        bag_scale=bag_scale,
        device=self.device,
        dtype=self.dtype,
        cg_dict=self.cg_dict,
    )

    self.cg_mix = CormorantMixer(
        tau_in=SO3Tau([self.num_channels_per_element] * (self.max_sh + 1)),
        tau_other=SO3Tau([self.num_channels_per_element]),
        maxl=self.max_sh,
        num_channels=self.num_channels_per_element,
        level_gain=10.0,
        weight_init='rand',
        device=self.device,
        dtype=self.dtype,
        cg_dict=self.cg_dict,
    )

    self.sph_harms = SphericalHarmonics(maxl=self.max_sh,
                                        conj=False,
                                        sh_norm='qm',
                                        device=self.device,
                                        dtype=self.dtype,
                                        cg_dict=self.cg_dict)

    self.atomic_scalars = AtomicScalars(maxl=self.max_sh, full_scalars=True, device=self.device, dtype=self.dtype)

    self.num_latent = self.atomic_scalars.get_output_dim(self.num_channels_out)
    self.num_latent_element = self.atomic_scalars.get_output_dim(self.num_channels_per_element)

    # Focus
    self.phi_focus = MLP(
        input_dim=self.num_latent,
        output_dims=(network_width, 1),
    )

    # Element
    self.phi_element = MLP(
        input_dim=self.num_latent,
        output_dims=(network_width, len(self.zs)),
    )

    # Distance: Gaussian mixture model
    self.phi_d = MLP(
        input_dim=self.num_latent_element,
        output_dims=(network_width, 2 * self.num_gaussians),
    )
    self.pad_zeros = torch.nn.ConstantPad1d(padding=(0, 1), value=0.0)  # Pad with one 0.0 to the right

    self.distance_half_width = torch.tensor((self.max_distance - self.min_distance) / 2,
                                            dtype=self.dtype,
                                            device=self.device)
    self.distance_center = torch.tensor((self.min_distance + self.max_distance) / 2,
                                        dtype=self.dtype,
                                        device=self.device)
    self.distance_log_stds = torch.nn.Parameter(torch.log(
        torch.tensor([0.1] * self.num_gaussians, dtype=self.dtype, device=self.device)),
        requires_grad=True)  # (gaussians, )

    # Value function
    self.phi_trans = MLP(
        input_dim=self.num_latent,
        output_dims=(network_width, network_width),
    )
    self.phi_v = MLP(
        input_dim=network_width,
        output_dims=(network_width, 1),
    )

    self.to(self.device)