def test_SO3Scalar_check_cplx_fail(self, batch, maxl, channels):
    """Check that SO3Scalar raises ValueError for two malformed inputs.

    Each part is built with shape ``batch + (t, 1, 1)`` and then
    ``batch + (t, 1, 3)``; both constructions are expected to fail.
    Presumably the rejection is due to the trailing (complex) axis not
    having size 2 -- confirm against ``SO3Scalar``'s validation.
    """
    multiplicities = [channels] * (maxl + 1)

    # Trailing dimensions (1, 1).
    bad_parts = [torch.rand(batch + (mult, 1, 1)) for mult in multiplicities]
    with pytest.raises(ValueError):
        SO3Scalar(bad_parts)

    # Trailing dimensions (1, 3).
    bad_parts = [torch.rand(batch + (mult, 1, 3)) for mult in multiplicities]
    with pytest.raises(ValueError):
        SO3Scalar(bad_parts)
def test_so3_scalar_so3_scalar_mul(self, maxl, num_middle):
    """Compare ``SO3Scalar * SO3Scalar`` with an elementwise numpy product.

    Both operands hold ``maxl`` random parts of shape
    ``(2,) + (3,) * num_middle + (4, 2)``.  Every part is converted with
    ``numpy_from_complex`` (presumably to a complex numpy array -- see
    that helper) and the library product must agree with the numpy
    product to within a summed absolute error of 1e-6.
    """
    part_shape = (2,) + (3,) * num_middle + (4, 2)

    lhs = SO3Scalar([torch.randn(part_shape) for _ in range(maxl)])
    lhs_np = [numpy_from_complex(part) for part in lhs]

    rhs = SO3Scalar([torch.randn(part_shape) for _ in range(maxl)])
    rhs_np = [numpy_from_complex(part) for part in rhs]

    # Reference result computed entirely in numpy.
    expected = [a * b for a, b in zip(lhs_np, rhs_np)]

    # Result computed through the SO3Scalar multiplication operator.
    actual = [numpy_from_complex(part) for part in lhs * rhs]

    for got, want in zip(actual, expected):
        assert np.sum(np.abs(got - want)) < 1E-6
def test_SO3Scalar_check_batch_fail(self, batch, channels):
    """Check SO3Scalar's handling of per-part batch shapes.

    ``batch`` is a sequence of batch-shape tuples, one per part.  When
    every entry is the same, construction must succeed; otherwise a
    ``ValueError`` is expected.
    """
    maxl = len(batch) - 1
    tau = torch.randint(1, channels + 1, [maxl + 1])
    parts = [torch.rand(dims + (mult, 2)) for dims, mult in zip(batch, tau)]

    distinct_batch_shapes = len(set(batch))
    if distinct_batch_shapes != 1:
        # Mismatched batch shapes must be rejected.
        with pytest.raises(ValueError):
            SO3Scalar(parts)
        return

    # Consistent batch shapes: construction must not raise.
    SO3Scalar(parts)
def test_SO3Scalar_init_arb_tau(self, batch, maxl, channels):
    """Check that ``SO3Scalar.rand`` reports the tau it was built with.

    The multiplicities are drawn at random from ``[1, channels]``, one
    per part (``maxl + 1`` parts in total).
    """
    # NOTE(review): this is an integer tensor, not a list, so the equality
    # below relies on ``.tau`` supporting comparison with a tensor -- confirm.
    expected_tau = torch.randint(1, channels + 1, [maxl + 1])

    scalar = SO3Scalar.rand(batch, expected_tau)

    assert scalar.tau == expected_tau
def test_SO3Scalar_init_channels(self, batch, maxl, channels):
    """Check that ``SO3Scalar.rand`` reports the tau it was built with.

    Every one of the ``maxl + 1`` parts uses the same multiplicity,
    ``channels``.
    """
    expected_tau = [channels] * (maxl + 1)

    scalar = SO3Scalar.rand(batch, expected_tau)

    assert scalar.tau == expected_tau
def test_mix_SO3Scalar(batch, maxl, channels1, channels2):
    """Smoke test: ``mix`` must run on a random SO3Weight / SO3Scalar pair.

    The scalar uses ``channels1`` channels per part and the weight maps
    ``channels1`` to ``channels2`` channels, for ``maxl + 1`` parts.  No
    assertion is made on the result; the test only fails if ``mix``
    raises.
    """
    num_parts = maxl + 1
    tau_in = [channels1] * num_parts
    tau_out = [channels2] * num_parts

    scalar = SO3Scalar.rand(batch, tau_in)
    weight = SO3Weight.rand(tau_in, tau_out)

    # Shapes are printed to help diagnose a failure (visible with ``pytest -s``).
    print(scalar.shapes, weight.shapes)

    mix(weight, scalar)
def test_SO3Scalar_mul_scalar(self, batch, maxl, channels):
    """Check multiplication of an SO3Scalar by a plain Python number.

    Both operand orders are exercised (``2 * x`` and ``x * 2.0``); in
    each case every part of the result must equal twice the matching
    input part.
    """
    tau = [channels] * (maxl + 1)
    base = SO3Scalar([torch.rand(batch + (mult, 2)) for mult in tau])

    # Number on the left.
    doubled = 2 * base
    for original, result in zip(base, doubled):
        assert torch.allclose(2 * original, result)

    # Number on the right.
    doubled = base * 2.0
    for original, result in zip(base, doubled):
        assert torch.allclose(2 * original, result)
def test_SO3Scalar_mul_list(self, batch, maxl, channels):
    """Check multiplication of an SO3Scalar by a list of Python floats.

    The list holds one random factor per part.  Both operand orders are
    exercised, and each part of the result must equal the matching input
    part scaled by its own factor.
    """
    tau = [channels] * (maxl + 1)
    base = SO3Scalar([torch.rand(batch + (mult, 2)) for mult in tau])
    factors = [torch.rand(1).item() for _ in base]

    # List on the left.
    scaled = factors * base
    for part, factor, result in zip(base, factors, scaled):
        assert torch.allclose(factor * part, result)

    # List on the right.
    scaled = base * factors
    for part, factor, result in zip(base, factors, scaled):
        assert torch.allclose(part * factor, result)
def forward(self, norms, edge_mask):
    """
    Evaluate the radial basis functions and wrap them in an SO3Scalar.

    Parameters
    ----------
    norms : :obj:`torch.Tensor`
        Norms (distances) at which the basis is evaluated.
    edge_mask : :obj:`torch.Tensor`
        Mask that is multiplied with ``norms > 0``.  Wherever the
        combined mask is false the basis values are replaced by
        ``self.zero``.
        NOTE(review): the product must be a valid ``torch.where``
        condition -- confirm the dtype passed by callers.

    Returns
    -------
    :class:`SO3Scalar`
        Built from a list of tensors:

        * ``self.mix == 'cplx'``: one tensor per module in
          ``self.linear``, each viewed as
          ``norms.shape + (self.num_channels, 2)``.
        * ``self.mix == 'real'``: one tensor per module in
          ``self.linear``, each viewed as
          ``norms.shape + (self.num_channels,)`` and then paired with a
          zero imaginary part along a new last axis.
        * otherwise: the unmixed product viewed as
          ``norms.shape + (self.num_rad, 2)``, with the same tensor
          object repeated ``self.max_sh + 1`` times.
    """
    # Remembered so the outputs can be viewed back onto the input layout.
    base_shape = norms.shape

    # Entries that are masked out or have zero norm are excluded; a zero
    # norm would otherwise produce infinities in the inverse powers.
    valid = (edge_mask * (norms > 0)).unsqueeze(-1)
    dist = norms.unsqueeze(-1)

    # Inverse powers dist**0, dist**-1, ..., dist**-rpow on a new last axis.
    inverse_powers = []
    for exponent in range(self.rpow + 1):
        inverse_powers.append(
            torch.where(valid, dist.pow(-exponent), self.zero))
    rad_powers = torch.stack(inverse_powers, dim=-1)

    # Trigonometric factor sin(2*pi*scale*r + phase).
    angle = (2 * pi * self.scales) * dist + self.phases
    rad_trig = torch.where(valid, torch.sin(angle), self.zero).unsqueeze(-1)

    # Combine the power and trig factors and flatten them into one axis.
    flat_shape = base_shape + (1, 2 * self.num_rad)
    rad_prod = (rad_powers * rad_trig).view(flat_shape)

    # Apply the linear mixing function, if desired.
    if self.mix == 'cplx':
        mixed_shape = base_shape + (self.num_channels, 2)
        radial_functions = [
            layer(rad_prod).view(mixed_shape) for layer in self.linear
        ]
    elif self.mix == 'real':
        mixed_shape = base_shape + (self.num_channels,)
        radial_functions = [
            layer(rad_prod).view(mixed_shape) for layer in self.linear
        ]
        # Hack because a real-valued SO3Scalar class has not been
        # implemented yet: append an all-zero imaginary component.
        # TODO: Implement real-valued SO3Scalar and fix this...
        radial_functions = [
            torch.stack([real_part, torch.zeros_like(real_part)], dim=-1)
            for real_part in radial_functions
        ]
    else:
        unmixed = rad_prod.view(base_shape + (self.num_rad, 2))
        radial_functions = [unmixed] * (self.max_sh + 1)

    return SO3Scalar(radial_functions)
def test_so3_scalar_so3_vector_mul(self, maxl, num_middle):
    """Compare SO3Scalar / SO3Vec multiplication with a numpy reference.

    The scalar parts have shape ``lead + (4, 2)`` and the vector parts
    ``lead + (4, 2*l + 1, 2)`` with ``lead = (2,) + (3,) * num_middle``.
    The reference product broadcasts each scalar part over the vector's
    extra axis.  Both ``vector * scalar`` and ``scalar * vector`` must
    match it to within a summed absolute error of 1e-6.
    """
    lead = (2,) + (3,) * num_middle

    scalar = SO3Scalar([torch.randn(lead + (4, 2)) for _ in range(maxl)])
    scalar_np = [numpy_from_complex(part) for part in scalar]

    vector = SO3Vec(
        [torch.randn(lead + (4, 2 * ell + 1, 2)) for ell in range(maxl)])
    vector_np = [numpy_from_complex(part) for part in vector]

    # Numpy reference: add a trailing axis to the scalar so it broadcasts.
    expected = [
        np.expand_dims(s_part, -1) * v_part
        for s_part, v_part in zip(scalar_np, vector_np)
    ]

    def check_against_reference(product):
        """Assert that every part of ``product`` matches the reference."""
        actual = [numpy_from_complex(part) for part in product]
        for got, want in zip(actual, expected):
            assert np.sum(np.abs(got - want)) < 1E-6

    check_against_reference(vector * scalar)
    check_against_reference(scalar * vector)
def forward(self, reps):
    r"""
    Performs the forward pass.

    Parameters
    ----------
    reps : :class:`SO3Vec <cormorant.so3_lib.SO3Vec>`
        Input SO3 Vector.

    Returns
    -------
    dot_products : :class:`SO3Scalar <cormorant.so3_lib.SO3Scalar>`
        SO3 scalars representing a Matrix of form :math:`(\psi_i \cdot \psi_j)_c`,
        where c is a channel index with :math:`|C| = \sum_l \tau_l`.

    Raises
    ------
    ValueError
        If ``self.tau_in`` is set and differs from ``reps.tau``.
    """
    if self.tau_in is not None and self.tau_in != reps.tau:
        raise ValueError(
            'Initialized tau not consistent with tau from forward! {} {}'.
            format(self.tau_in, reps.tau))

    signs = self.signs
    conj = self.conj

    # Insert a singleton axis at two different positions so that the
    # products below broadcast over every (i, j) pair along the axis
    # that sits fourth from the end of each part (presumably the atom
    # axis -- TODO confirm against the SO3Vec layout).
    reps1 = [part.unsqueeze(-4) for part in reps]
    reps2 = [part.unsqueeze(-5) for part in reps]

    # Reverse the second-to-last axis of the second operand and apply
    # the per-part sign factors.
    reps2 = [part.flip(-2) * sign for part, sign in zip(reps2, signs)]

    # Real part: elementwise product weighted by ``conj``, summed over
    # the last two axes.
    dot_product_r = [(part1 * part2 * conj).sum(dim=(-2, -1))
                     for part1, part2 in zip(reps1, reps2)]
    # Imaginary part: same reduction, with the last axis of the second
    # operand reversed so that its two components are swapped.
    dot_product_i = [(part1 * part2.flip(-1)).sum(dim=(-2, -1))
                     for part1, part2 in zip(reps1, reps2)]

    # Pack (real, imaginary) into a new trailing axis.
    dot_products = [
        torch.stack([prod_r, prod_i], dim=-1)
        for prod_r, prod_i in zip(dot_product_r, dot_product_i)
    ]

    if self.cat:
        # Concatenate all parts along the second-to-last axis and reuse
        # that single tensor for every part of the output.
        dot_products = torch.cat(dot_products, dim=-2)
        dot_products = [dot_products] * len(reps)

    return SO3Scalar(dot_products)
def test_covariance(self, tau, num_channels, maxl, basis, edge_net_type, sample_batch):
    """Check that the edge-level output is unchanged by a rotation.

    The atom representations are transformed with the Wigner matrices
    from ``rot.gen_rot`` and passed through the same
    ``CormorantEdgeLevel``.  Each of the ``maxl`` output parts must
    agree with the unrotated result to within 1e-5 (maximum absolute
    difference).

    ``edge_net_type`` selects the input edge representation: ``None``
    for no edge input, ``'rand'`` for random parts of shape
    ``(batch_size, natoms, natoms, tau, 2)``.  Any other value raises
    ``ValueError``.
    """
    data, __, __ = sample_batch
    device, dtype = data['positions'].device, data['positions'].dtype
    sph_harms = SphericalHarmonicsRel(maxl - 1, conj=True,
                                      device=device, dtype=dtype,
                                      cg_dict=None)
    batch_size, natoms = data['positions'].shape[:2]
    D, R, __ = rot.gen_rot(maxl, device=device, dtype=dtype)

    # Setup Input
    atom_reps, atom_mask, edge_scalars, edge_mask, atom_positions = prep_input(
        data, tau, maxl)
    atom_positions_rot = rot.rotate_cart_vec(R, atom_positions)
    atom_reps_rot = atom_reps.apply_wigner(D)

    # Calculate spherical harmonics and radial functions
    __, norms = sph_harms(atom_positions, atom_positions)
    # NOTE(review): ``norms_rot`` is computed but never used; the rotated
    # pass below is fed the unrotated ``norms``.  Kept as in the original
    # so that the values being compared do not change.
    __, norms_rot = sph_harms(atom_positions_rot, atom_positions_rot)
    rad_funcs = RadialFilters([maxl - 1], [basis, basis], [num_channels], 1,
                              device=device, dtype=dtype)
    rad_func_levels = rad_funcs(norms, edge_mask * (norms > 0))
    tau_pos = rad_funcs.tau[0]

    # Build the initial edge network
    if edge_net_type is None:
        edge_reps = None
    elif edge_net_type == 'rand':
        # Create the random parts on the same device/dtype as the rest of
        # the test so they can be combined with the other tensors.
        reps = [
            torch.randn((batch_size, natoms, natoms, tau, 2),
                        device=device, dtype=dtype)
            for _ in range(maxl)
        ]
        edge_reps = SO3Scalar(reps)
    else:
        raise ValueError(
            'Unsupported edge_net_type: {!r}'.format(edge_net_type))

    # Build Edge layer
    tau_atoms = [tau] * maxl
    # Without an input edge network there are no edge multiplicities.
    tau_edge = [] if edge_net_type is None else tau_atoms
    edge_lvl = CormorantEdgeLevel(tau_atoms, tau_edge, tau_pos,
                                  num_channels, maxl,
                                  cutoff_type='soft',
                                  device=device, dtype=dtype,
                                  hard_cut_rad=1.73,
                                  soft_cut_rad=1.73,
                                  soft_cut_width=0.2)

    output_edge_reps = edge_lvl(edge_reps, atom_reps, rad_func_levels[0],
                                edge_mask, norms)
    output_edge_reps_rot = edge_lvl(edge_reps, atom_reps_rot,
                                    rad_func_levels[0], edge_mask, norms)

    # The two outputs must agree part by part.
    for i in range(maxl):
        max_diff = torch.max(
            torch.abs(output_edge_reps[i] - output_edge_reps_rot[i]))
        assert max_diff < 1E-5
import torch
import pytest

from cormorant.so3_lib import SO3Tau, SO3Scalar


def rand_scalar(batch, tau):
    """Build an SO3Scalar whose parts are uniform-random tensors.

    One part is created per entry ``t`` of ``tau``, with shape
    ``batch + (t, 2)``.
    """
    return SO3Scalar([torch.rand(batch + (t, 2)) for t in tau])


class TestSO3Scalar:
    """Tests for constructing :class:`SO3Scalar` objects."""

    @pytest.mark.parametrize('batch', [(1,), (2,), (7,), (1,1), (2, 2), (7, 7)])
    @pytest.mark.parametrize('maxl', range(3))
    @pytest.mark.parametrize('channels', range(1, 3))
    def test_SO3Scalar_init_channels(self, batch, maxl, channels):
        """``SO3Scalar.rand`` must report the uniform tau it was given."""
        tau_list = [channels] * (maxl + 1)

        test_vec = SO3Scalar.rand(batch, tau_list)

        assert test_vec.tau == tau_list

    @pytest.mark.parametrize('batch', [(1,), (2,), (7,), (1,1), (2, 2), (7, 7)])
    @pytest.mark.parametrize('maxl', range(4))
    @pytest.mark.parametrize('channels', range(1, 4))
    def test_SO3Scalar_init_arb_tau(self, batch, maxl, channels):
        """``SO3Scalar.rand`` must report the random tau it was given."""
        # NOTE(review): this is an integer tensor, not a list, so the
        # equality below relies on ``.tau`` supporting comparison with a
        # tensor -- confirm.
        tau_list = torch.randint(1, channels + 1, [maxl + 1])

        test_vec = SO3Scalar.rand(batch, tau_list)

        assert test_vec.tau == tau_list