def test_SO3Vec_cat(self, batch1, batch2, batch3, channels1, channels2, channels3, maxl1, maxl2, maxl3):
    """
    Check ``so3_torch.cat`` on two and on three random ``SO3Vec`` objects.

    When the batch shapes agree, the tau of the concatenated vector must
    equal ``SO3Tau.cat`` of the input taus. When they differ, the call is
    expected to raise ``RuntimeError``.

    Parameters
    ----------
    batch1, batch2, batch3
        Batch shapes passed to ``SO3Vec.randn`` for each vector.
    channels1, channels2, channels3 : int
        Number of channels used at every l of each vector.
    maxl1, maxl2, maxl3 : int
        Largest l of each vector; each tau has ``maxl + 1`` entries.
    """
    tau1 = [channels1] * (maxl1 + 1)
    tau2 = [channels2] * (maxl2 + 1)
    # Fixed: this used to be built from maxl2, which left maxl3 unused and
    # meant the third vector never had its own maximum weight.
    tau3 = [channels3] * (maxl3 + 1)

    tau12 = SO3Tau.cat([tau1, tau2])
    tau123 = SO3Tau.cat([tau1, tau2, tau3])

    vec1 = SO3Vec.randn(tau1, batch1)
    vec2 = SO3Vec.randn(tau2, batch2)
    vec3 = SO3Vec.randn(tau3, batch3)

    # Two-vector concatenation.
    if batch1 == batch2:
        vec12 = so3_torch.cat([vec1, vec2])
        assert vec12.tau == tau12
    else:
        with pytest.raises(RuntimeError):
            so3_torch.cat([vec1, vec2])

    # Three-vector concatenation.
    if batch1 == batch2 == batch3:
        vec123 = so3_torch.cat([vec1, vec2, vec3])
        assert vec123.tau == tau123
    else:
        with pytest.raises(RuntimeError):
            so3_torch.cat([vec1, vec2, vec3])
def test_channels(self):
    """
    Check ``SO3Tau.channels`` for a non-uniform and a uniform tau.

    A tau whose entries differ is expected to report ``None``; a tau whose
    entries are all 3 is expected to report 3.
    """
    tau = SO3Tau([3, 2, 1])
    # Identity comparison is the correct way to test for the None singleton.
    assert tau.channels is None

    tau = SO3Tau([3] * 2)
    assert tau.channels == 3
def __init__(self, max_sh, basis_set, num_channels, mix=False, device=torch.device('cpu'), dtype=torch.float):
    """
    Set up the learnable trigonometric radial basis and optional mixing layers.

    Parameters
    ----------
    max_sh : int
        One linear layer (when mixing) and one tau entry are created for
        each of the ``max_sh + 1`` levels.
    basis_set : pair of int
        Unpacked as ``(trig_basis, rpow)``; both must be non-negative.
    num_channels : int
        Number of output channels used when mixing is enabled.
    mix : bool or str
        ``'cplx'`` or ``True``: linear layers with ``2 * num_channels``
        outputs. ``'real'``: linear layers with ``num_channels`` outputs.
        ``'none'`` or ``False``: no linear layers.
    device : torch.device
        Device for the created tensors and layers.
    dtype : torch.dtype
        Data type for the created tensors and layers.

    Raises
    ------
    ValueError
        If ``mix`` is not one of the accepted values.
    """
    super(RadPolyTrig, self).__init__()

    trig_basis, rpow = basis_set
    self.rpow = rpow
    self.max_sh = max_sh

    assert (trig_basis >= 0 and rpow >= 0)

    self.num_rad = (trig_basis + 1) * (rpow + 1)
    self.num_channels = num_channels

    num_trig = trig_basis + 1

    # This instantiates a set of functions sin(2*pi*n*x/a), cos(2*pi*n*x/a) with a=1.
    scales = torch.cat([torch.arange(num_trig), torch.arange(num_trig)])
    scales = scales.view(1, 1, 1, -1).to(device=device, dtype=dtype)

    phases = torch.cat([torch.zeros(num_trig), pi / 2 * torch.ones(num_trig)])
    phases = phases.view(1, 1, 1, -1).to(device=device, dtype=dtype)

    # This avoids the sin(0*r + 0) = 0 part from wasting computations.
    phases[0, 0, 0, 0] = pi / 2

    # Now, make the above learnable
    self.scales = nn.Parameter(scales)
    self.phases = nn.Parameter(phases)

    def build_mixers(out_features):
        # One linear layer per level, all mapping from 2 * num_rad inputs.
        return nn.ModuleList([
            nn.Linear(2 * self.num_rad, out_features).to(device=device, dtype=dtype)
            for _ in range(max_sh + 1)
        ])

    # If desired, mix the radial components to a desired shape
    self.mix = mix
    if (mix == 'cplx') or (mix is True):
        self.linear = build_mixers(2 * self.num_channels)
        self.tau = SO3Tau((num_channels, ) * (max_sh + 1))
    elif mix == 'real':
        self.linear = build_mixers(self.num_channels)
        self.tau = SO3Tau((num_channels, ) * (max_sh + 1))
    elif (mix == 'none') or (mix is False):
        self.linear = None
        self.tau = SO3Tau((self.num_rad, ) * (max_sh + 1))
    else:
        raise ValueError(
            'Can only specify mix = real, cplx, or none! {}'.format(mix))

    self.zero = torch.tensor(0, device=device, dtype=dtype)
def test_init(self, maxl, num_channels):
    """
    Check that ``SO3Tau`` can be built from a list and from another ``SO3Tau``.

    In both cases the result must be exactly an ``SO3Tau`` whose entries are
    ``num_channels`` repeated ``maxl + 1`` times.
    """
    expected = [num_channels] * (maxl + 1)

    from_list = SO3Tau([num_channels] * (maxl + 1))
    assert type(from_list) is SO3Tau
    assert list(from_list) == expected

    # Constructing from an existing SO3Tau must give an equivalent SO3Tau.
    from_tau = SO3Tau(from_list)
    assert type(from_tau) is SO3Tau
    assert list(from_tau) == expected
def test_SO3Weight_init(self, maxl, channels1, channels2):
    """
    Check that an ``SO3Weight`` built from a weight list reports the right taus.

    The weight list comes from ``rand_weight_list`` with an input tau of
    ``channels1`` and an output tau of ``channels2`` at each of the
    ``maxl + 1`` levels.
    """
    num_levels = maxl + 1
    tau_in = SO3Tau([channels1] * num_levels)
    tau_out = SO3Tau([channels2] * num_levels)

    weight = SO3Weight(rand_weight_list(tau_in, tau_out))

    assert isinstance(weight, SO3Weight)
    assert weight.tau_in == SO3Tau(tau_in)
    assert weight.tau_out == SO3Tau(tau_out)
def __init__(self, taus_in, maxl=None):
    """
    Record the input taus and compute the tau of their concatenation.

    Parameters
    ----------
    taus_in : iterable
        Input multiplicities. Falsy entries are dropped; the rest are
        converted to ``SO3Tau``.
    maxl : int, optional
        Largest l kept in the output. Defaults to the largest ``maxl`` among
        the retained input taus.
    """
    super().__init__()

    kept = [SO3Tau(tau) for tau in taus_in if tau]
    self.taus_in = kept

    if maxl is None:
        maxl = max(tau.maxl for tau in kept)
    self.maxl = maxl

    # Fold the taus together with ``&`` and truncate to maxl + 1 entries.
    combined = reduce(lambda acc, tau: acc & tau, kept)
    self.tau_out = combined[:self.maxl + 1]
def __init__(self, tau_in=None, cat=True, device=None, dtype=None):
    """
    Store the input tau and precompute the output tau and sign tensors.

    Parameters
    ----------
    tau_in : iterable of int, optional
        Input multiplicity. When ``None``, ``tau``, ``signs`` and ``conj``
        are all set to ``None``.
    cat : bool
        If true, every entry of the output tau is ``sum(tau_in)``; otherwise
        the output tau is a copy of ``tau_in``.
    device, dtype
        Forwarded to the parent constructor; the tensors created here are
        placed on ``self.device`` with ``self.dtype``.
    """
    super().__init__(device=device, dtype=dtype)
    self.tau_in = tau_in
    self.cat = cat

    if self.tau_in is not None:
        if cat:
            self.tau = SO3Tau([sum(tau_in)] * len(tau_in))
        else:
            self.tau = SO3Tau([t for t in tau_in])

        # For each ell, a column of (-1)**m for m = -ell, ..., ell.
        self.signs = [
            torch.tensor(-1.).pow(torch.arange(-ell, ell + 1).float()).to(
                device=self.device, dtype=self.dtype).unsqueeze(-1)
            for ell in range(len(tau_in) + 1)
        ]
        self.conj = torch.tensor([1., -1.]).to(device=self.device,
                                               dtype=self.dtype)
    else:
        self.tau = None
        self.signs = None
        # Fixed: conj was left undefined on this branch, so reading
        # ``self.conj`` would fail when no tau_in was given.
        self.conj = None
def forward(self, rep):
    """
    Linearly mix a representation.

    Parameters
    ----------
    rep : :obj:`list` of :obj:`torch.Tensor`
        Representation to mix.

    Returns
    -------
    rep : :obj:`list` of :obj:`torch.Tensor`
        Mixed representation.

    Raises
    ------
    ValueError
        If the tau derived from ``rep`` differs from ``self.tau_in``.
    """
    rep_tau = SO3Tau.from_rep(rep)
    if rep_tau == self.tau_in:
        return so3_torch.mix(self.weights, rep)

    message = ('Tau of input rep does not match initialized tau!'
               ' rep: {} tau: {}')
    raise ValueError(message.format(rep_tau, self.tau_in))
def test_from_rep(self, batch, tau0):
    """
    Check that ``SO3Tau.from_rep`` recovers the tau of a random representation.

    The representation has one double tensor per l, of shape
    ``batch + (tau0[l], 2 * l + 1, 2)``.
    """
    rep = []
    for ell, mult in enumerate(tau0):
        shape = batch + (mult, 2 * ell + 1, 2)
        rep.append(torch.rand(shape).double())

    recovered = SO3Tau.from_rep(rep)

    assert type(recovered) is SO3Tau
    assert list(recovered) == list(tau0)
def test_cat(self):
    """
    Check ``SO3Tau.cat`` and the ``&`` / ``&=`` operators against fixed values.

    Also checks that chaining ``&`` over three taus agrees with a single
    three-way ``SO3Tau.cat``.
    """
    first = SO3Tau([1, 2, 3])
    second = SO3Tau([1, 1])
    third = SO3Tau([0, 0, 2])
    expected_pair = [2, 3, 3]

    via_cat = SO3Tau.cat([first, second])
    assert list(via_cat) == expected_pair
    assert type(via_cat) is SO3Tau
    print(via_cat)

    via_and = first & second
    assert list(via_and) == expected_pair

    # In-place form; `first` now holds the concatenated tau.
    first &= second
    assert list(first) == expected_pair

    chained = (first & second) & third
    assert SO3Tau.cat([first, second, third]) == chained
def __init__(self, tau_in, tau_out, real=False, weight_init='randn', gain=1, device=None, dtype=None):
    """
    Create the learnable mixing weights between ``tau_in`` and ``tau_out``.

    Parameters
    ----------
    tau_in : list of int or SO3Tau
        Multiplicity of the input representation.
    tau_out : list of int, SO3Tau or int
        Multiplicity of the output representation. A plain ``int`` is
        expanded to that many channels at every l of ``tau_in``.
    real : bool
        Stored on the instance as ``self.real``.
    weight_init : str
        ``'randn'`` draws the weights with ``SO3Weight.randn``; ``'rand'``
        draws them with ``SO3Weight.rand`` and rescales them via
        ``2 * w - 1``.
    gain : float
        Each weight is scaled by ``gain / max(shape)`` for its own shape.
    device, dtype
        Forwarded to the parent constructor and to the weight factories.

    Raises
    ------
    NotImplementedError
        If ``weight_init`` is neither ``'randn'`` nor ``'rand'``.
    """
    super().__init__(device=device, dtype=dtype)
    tau_in = SO3Tau(tau_in)

    # Allow one to set the output tau to a pre-specified number of output channels.
    if type(tau_out) is int:
        tau_out = [tau_out] * len(tau_in)
    else:
        tau_out = SO3Tau(tau_out)

    self.tau_in = SO3Tau(tau_in)
    self.tau_out = SO3Tau(tau_out)
    self.real = real

    # Fixed: these strings were compared with `is`, which tests object
    # identity and only worked by accident of string interning.
    if weight_init == 'randn':
        weights = SO3Weight.randn(self.tau_in, self.tau_out,
                                  device=device, dtype=dtype)
    elif weight_init == 'rand':
        weights = SO3Weight.rand(self.tau_in, self.tau_out,
                                 device=device, dtype=dtype)
        weights = 2 * weights - 1
    else:
        raise NotImplementedError(
            'weight_init can only be randn or rand for now')

    # One gain per weight, normalised by that weight's largest dimension.
    gain = [gain / max(shape) for shape in weights.shapes]
    weights = gain * weights

    self.weights = weights.as_parameter()
def cg_product_tau(tau1, tau2, maxl=inf):
    """
    Calculate the output multiplicity of the CG product of two SO3 vectors,
    given the multiplicities of the two input SO3 vectors.

    For every pair ``(l1, l2)`` and every output ``l`` from ``|l1 - l2|`` up
    to ``min(l1 + l2, maxl)``, ``tau1[l1]`` is added to the output
    multiplicity at ``l``.

    Parameters
    ----------
    tau1 : :class:`list` of :class:`int`, :class:`SO3Tau`.
        Multiplicity of first representation.
    tau2 : :class:`list` of :class:`int`, :class:`SO3Tau`.
        Multiplicity of second representation.
    maxl : :class:`int`
        Largest weight to include in CG Product.

    Return
    ------
    tau : :class:`SO3Tau`
        Multiplicity of output representation.
    """
    tau1 = SO3Tau(tau1)
    tau2 = SO3Tau(tau2)

    top1 = tau1.maxl
    top2 = tau2.maxl
    top_out = min(top1 + top2, maxl)

    counts = [0] * (top_out + 1)
    for l1 in range(top1 + 1):
        contribution = tau1[l1]
        for l2 in range(top2 + 1):
            lowest = abs(l1 - l2)
            highest = min(l1 + l2, maxl)
            for l_out in range(lowest, highest + 1):
                counts[l_out] += contribution

    return SO3Tau(counts)
def test_add(self):
    """
    Check that ``+``, ``+=`` and ``sum`` on ``SO3Tau`` concatenate entries.

    Covers ``SO3Tau + SO3Tau``, ``tuple + SO3Tau``, ``list + SO3Tau``,
    in-place addition, and ``sum`` over a mix of ``SO3Tau``, list and tuple.
    Every result must be exactly an ``SO3Tau``.
    """
    tau1 = SO3Tau([1, 2, 3])
    tau2 = SO3Tau([1, 1])

    def check_concatenation(result):
        # Recomputed on every call so a mutated operand would be noticed.
        assert type(result) is SO3Tau
        assert list(result) == list(tau1) + list(tau2)

    check_concatenation(tau1 + tau2)
    check_concatenation(tuple(tau1) + tau2)
    check_concatenation(list(tau1) + tau2)

    in_place = SO3Tau(tau1)
    in_place += tau2
    check_concatenation(in_place)

    total = sum([SO3Tau([3, 2, 1]), [1], (2, 3)])
    assert type(total) is SO3Tau
    assert list(total) == [3, 2, 1, 1, 2, 3]
def test_CGProduct(self, batch, maxl1, maxl2, maxl, channels):
    """
    Check that ``CGProduct`` commutes with applying a Wigner matrix.

    Two random vectors are multiplied and then transformed, and separately
    transformed and then multiplied. The two results must agree to within
    ``1e-6`` in every part.
    """
    largest_l = max(maxl1, maxl2, maxl)
    D, _, _ = rot.gen_rot(largest_l)

    cg_dict = CGDict(maxl=largest_l, dtype=torch.double)
    cg_prod = CGProduct(maxl=maxl, dtype=torch.double, cg_dict=cg_dict)

    tau1 = SO3Tau([channels] * (maxl1 + 1))
    tau2 = SO3Tau([channels] * (maxl2 + 1))

    vec1 = SO3Vec.randn(tau1, batch, dtype=torch.double)
    vec2 = SO3Vec.randn(tau2, batch, dtype=torch.double)

    # Transform the inputs first ...
    rotated1 = vec1.apply_wigner(D, dir='left')
    rotated2 = vec2.apply_wigner(D, dir='left')

    product = cg_prod(vec1, vec2)
    product_of_rotated = cg_prod(rotated1, rotated2)

    # ... and compare against transforming the product of the originals.
    rotated_product = product.apply_wigner(D, dir='left')

    max_errors = [
        (lhs - rhs).abs().max()
        for lhs, rhs in zip(product_of_rotated, rotated_product)
    ]
    assert all(err < 1e-6 for err in max_errors)
def __init__(self, taus_in, tau_out, maxl=None, real=False, weight_init='randn', gain=1, device=None, dtype=None):
    """
    Chain a ``CatReps`` module into a ``MixReps`` module.

    Parameters
    ----------
    taus_in : iterable
        Input multiplicities, handed to ``CatReps``.
    tau_out
        Requested output multiplicity, handed to ``MixReps``.
    maxl : int, optional
        Forwarded to ``CatReps``.
    real, weight_init, gain
        Forwarded unchanged to ``MixReps``.
    device, dtype
        Forwarded to the parent constructor and to ``MixReps``.
    """
    super().__init__(device=device, dtype=dtype)

    concatenate = CatReps(taus_in, maxl=maxl)
    self.cat_reps = concatenate

    # The mixer's input tau is the tau reported by the concatenation stage.
    mixer_options = dict(real=real, weight_init=weight_init, gain=gain,
                         device=device, dtype=dtype)
    self.mix_reps = MixReps(concatenate.tau, tau_out, **mixer_options)

    self.taus_in = taus_in
    self.tau_out = SO3Tau(self.mix_reps)
def tau(self):
    """Return a single-entry ``SO3Tau`` holding ``self.channels_out``."""
    multiplicities = [self.channels_out]
    return SO3Tau(multiplicities)
def test_apply_euler(self, batch, channels, maxl):
    """
    Smoke-test ``so3_torch.apply_wigner`` with an Euler-angle Wigner matrix.

    Only checks that the call runs without raising; the result is not
    inspected.
    """
    double = torch.double
    multiplicity = SO3Tau([channels] * (maxl + 1))

    random_vec = SO3Vec.rand(batch, multiplicity, dtype=double)
    wigner_d = SO3WignerD.euler(maxl, dtype=double)

    so3_torch.apply_wigner(wigner_d, random_vec)
def __init__(
    self,
    observation_space: ObservationSpace,
    action_space: ActionSpace,
    min_max_distance: Tuple[float, float],
    network_width: int,
    maxl: int,
    num_cg_levels: int,
    num_channels_hidden: int,
    num_channels_per_element: int,
    num_gaussians: int,
    bag_scale: int,
    beta: Optional[float] = None,
    device=None,
):
    """
    Build the Cormorant-based feature extractor and the MLP heads.

    Parameters
    ----------
    observation_space, action_space
        Forwarded to the parent constructor. ``self.observation_space.zs``
        is read afterwards to obtain the list stored as ``self.zs``.
    min_max_distance
        Pair unpacked into ``self.min_distance`` and ``self.max_distance``;
        the minimum must be strictly smaller than the maximum.
    network_width
        First entry of ``output_dims`` for every MLP head, and the input
        size of ``phi_v``.
    maxl
        Stored as ``self.max_sh`` and passed as ``maxl`` (and, for
        ``Cormorant``, also ``max_sh``) to the CG components.
    num_cg_levels
        Number of CG levels passed to ``Cormorant``.
    num_channels_hidden
        Channel count used for each of the ``num_cg_levels`` hidden levels.
    num_channels_per_element
        Channels per entry of ``self.zs``; the final level has
        ``len(self.zs) * num_channels_per_element`` channels.
    num_gaussians
        ``phi_d`` outputs ``2 * num_gaussians`` values and
        ``distance_log_stds`` has ``num_gaussians`` entries.
    bag_scale
        Forwarded to ``Cormorant``.
    beta
        Stored as ``self.beta``; not used further in this constructor.
    device
        Device for the tensors and modules created here; the whole module
        is moved to it at the end.
    """
    super().__init__(observation_space, action_space)

    self.device = device
    self.dtype = torch.float

    self.zs = self.observation_space.zs
    self.zs_tensor = torch.tensor(self.zs, dtype=self.dtype, device=self.device)

    self.min_distance, self.max_distance = min_max_distance
    assert self.min_distance < self.max_distance

    self.beta = beta

    self.max_sh = maxl
    self.num_cg_levels = num_cg_levels
    self.num_channels_hidden = num_channels_hidden
    self.num_channels_per_element = num_channels_per_element
    self.num_gaussians = num_gaussians

    # The last Cormorant level carries one group of channels per entry of zs.
    self.num_channels_out = len(self.zs) * self.num_channels_per_element

    # Row vector [0, 1, ..., num_channels_per_element - 1] with a leading axis of size 1.
    self.channel_offsets = torch.arange(start=0,
                                        end=self.num_channels_per_element,
                                        dtype=torch.long,
                                        device=self.device).unsqueeze(0)

    # A single CGDict instance is shared by all CG components below.
    self.cg_dict = CGDict(maxl=self.max_sh, device=self.device, dtype=self.dtype)

    self.cg_model = Cormorant(
        maxl=self.max_sh,  # Cutoff in CG operations (default: [3])
        max_sh=self.max_sh,  # Number of spherical harmonic powers to use (default: [3])
        num_cg_levels=self.num_cg_levels,  # Number of CG levels (default: 4)
        # Hidden width for every level, followed by the output width.
        num_channels=[self.num_channels_hidden] * self.num_cg_levels + [self.num_channels_out],
        num_species=len(self.zs),
        cutoff_type=['soft'],  # Types of cutoffs to include
        hard_cut_rad=min(self.max_distance, 2.1),  # Radius of hard cutoff (in AA)
        soft_cut_rad=min(self.max_distance, 2.1),  # Radius of soft cutoff (in AA)
        soft_cut_width=0.2,  # Width of SOFT cutoff in Angstroms (default: 0.2)
        weight_init='rand',  # Weight initialization function to use (default: rand)
        level_gain=[10.0],  # Gain at each level (default: [10.])
        charge_power=2,  # Maximum power to take in one-hot (default: 2)
        # NOTE(review): the previous comment here described a gaussian mask,
        # which does not match this argument — confirm the meaning of basis_set.
        basis_set=[3, 3],
        charge_scale=max(self.zs),
        bag_scale=bag_scale,
        device=self.device,
        dtype=self.dtype,
        cg_dict=self.cg_dict,
    )

    # tau_in has num_channels_per_element channels at each of the max_sh + 1
    # levels; tau_other has a single entry.
    self.cg_mix = CormorantMixer(
        tau_in=SO3Tau([self.num_channels_per_element] * (self.max_sh + 1)),
        tau_other=SO3Tau([self.num_channels_per_element]),
        maxl=self.max_sh,
        num_channels=self.num_channels_per_element,
        level_gain=10.0,
        weight_init='rand',
        device=self.device,
        dtype=self.dtype,
        cg_dict=self.cg_dict,
    )

    self.sph_harms = SphericalHarmonics(maxl=self.max_sh,
                                        conj=False,
                                        sh_norm='qm',
                                        device=self.device,
                                        dtype=self.dtype,
                                        cg_dict=self.cg_dict)

    self.atomic_scalars = AtomicScalars(maxl=self.max_sh, full_scalars=True, device=self.device, dtype=self.dtype)

    # Input sizes of the heads, as reported by AtomicScalars for the full
    # output channel count and for a single element's channels.
    self.num_latent = self.atomic_scalars.get_output_dim(self.num_channels_out)
    self.num_latent_element = self.atomic_scalars.get_output_dim(self.num_channels_per_element)

    # Focus
    self.phi_focus = MLP(
        input_dim=self.num_latent,
        output_dims=(network_width, 1),
    )

    # Element
    self.phi_element = MLP(
        input_dim=self.num_latent,
        output_dims=(network_width, len(self.zs)),
    )

    # Distance: Gaussian Mixture Model
    self.phi_d = MLP(
        input_dim=self.num_latent_element,
        output_dims=(network_width, 2 * self.num_gaussians),
    )

    self.pad_zeros = torch.nn.ConstantPad1d(padding=(0, 1), value=0.0)  # Pad with one 0.0 to the right

    # Half-width and midpoint of the [min_distance, max_distance] interval.
    self.distance_half_width = torch.tensor((self.max_distance - self.min_distance) / 2,
                                            dtype=self.dtype,
                                            device=self.device)
    self.distance_center = torch.tensor((self.min_distance + self.max_distance) / 2,
                                        dtype=self.dtype,
                                        device=self.device)

    # Learnable log standard deviations, initialised to log(0.1).
    self.distance_log_stds = torch.nn.Parameter(torch.log(
        torch.tensor([0.1] * self.num_gaussians, dtype=self.dtype, device=self.device)),
                                                requires_grad=True)  # (gaussians, )

    # Value function
    self.phi_trans = MLP(
        input_dim=self.num_latent,
        output_dims=(network_width, network_width),
    )
    self.phi_v = MLP(
        input_dim=network_width,
        output_dims=(network_width, 1),
    )

    self.to(self.device)
def tau(self):
    """Return an ``SO3Tau`` built from an empty list."""
    no_entries = []
    return SO3Tau(no_entries)