def __init__(self, Rs_in, mul, lmax, Rs_out, size=5, layers=3):
    """Stack of `layers` blocks (ImageConvolution -> S2Activation -> LowPassFilter)
    followed by a LearnableTensorSquare head.

    :param Rs_in: input representation
    :param mul: number of copies of the activation irreps l = 0..lmax
    :param lmax: maximum l used by the convolutions and the S2 activation
    :param Rs_out: output representation of the head
    :param size: convolution kernel size (padding is size // 2)
    :param layers: number of conv / activation / filter blocks
    """
    super().__init__()
    current = rs.simplify(Rs_in)
    target = rs.simplify(Rs_out)
    act_irreps = list(range(lmax + 1))

    self.mul = mul
    blocks = []
    for _ in range(layers):
        conv = ImageConvolution(
            current, mul * act_irreps, size,
            lmax=lmax, fuzzy_pixels=True, padding=size // 2,
        )
        # pointwise nonlinearity evaluated on a sphere grid
        act = S2Activation(act_irreps, swish, res=60)
        current = mul * act.Rs_out
        pool = LowPassFilter(scale=2.0, stride=2)
        blocks.append(torch.nn.ModuleList([conv, act, pool]))
    self.layers = torch.nn.ModuleList(blocks)
    self.tail = LearnableTensorSquare(current, target)
def __init__(self, Rs_in, mul, lmax, Rs_out, layers=3):
    """`layers` linear LearnableTensorSquare layers sharing a single swish
    S2Activation over l = 0..lmax, followed by a LearnableTensorSquare tail.

    :param Rs_in: input representation
    :param mul: number of copies of the activation irreps
    :param lmax: maximum l of the S2 activation
    :param Rs_out: output representation
    :param layers: number of hidden layers
    """
    super().__init__()
    self.Rs_in = rs.simplify(Rs_in)
    self.Rs_out = rs.simplify(Rs_out)
    self.act = S2Activation(list(range(lmax + 1)), swish, res=20 * (lmax + 1))

    Rs = self.Rs_in
    linears = []
    for _ in range(layers):
        linears.append(
            LearnableTensorSquare(Rs, mul * self.act.Rs_in, linear=True, allow_zero_outputs=True)
        )
        # after the S2 nonlinearity (presumably applied in forward) the features are mul * act.Rs_out
        Rs = mul * self.act.Rs_out
    self.layers = torch.nn.ModuleList(linears)
    self.tail = LearnableTensorSquare(Rs, self.Rs_out)
def __init__(self, Rs_in, mul, lmax, Rs_out, layers=3):
    """Network of linear LearnableTensorSquare layers interleaved with two S2
    activations: swish on irreps of parity (-1)^l and tanh on irreps of
    parity -(-1)^l, each over l = 0..lmax.

    :param Rs_in: input representation
    :param mul: number of copies of each activation's irreps
    :param lmax: maximum l of the activations
    :param Rs_out: output representation
    :param layers: number of hidden layers
    """
    super().__init__()
    Rs = self.Rs_in = rs.simplify(Rs_in)
    self.Rs_out = rs.simplify(Rs_out)

    def make_act(p_val, p_arg, act):
        # irreps (1, l, p_val * p_arg^l) for l = 0..lmax
        irreps = [(1, l, p_val * p_arg ** l) for l in range(lmax + 1)]
        return S2Activation(irreps, act, res=20 * (lmax + 1))

    self.act1 = make_act(1, -1, swish)
    self.act2 = make_act(-1, -1, tanh)
    self.mul = mul

    linears = []
    for _ in range(layers):
        hidden_in = mul * (self.act1.Rs_in + self.act2.Rs_in)
        linears.append(LearnableTensorSquare(Rs, hidden_in, linear=True, allow_zero_outputs=True))
        # representation after both S2 nonlinearities
        Rs = mul * (self.act1.Rs_out + self.act2.Rs_out)
    self.layers = torch.nn.ModuleList(linears)
    self.tail = LearnableTensorSquare(Rs, self.Rs_out)
def __init__(self, Rs_in, Rs_out, linear=True, allow_change_output=False, allow_zero_outputs=False):
    """Learnable map built on the tensor square of the input.

    :param Rs_in: input representation
    :param Rs_out: output representation
    :param linear: if True a constant scalar (1, 0, 1) is prepended to the input
        before squaring, so the square also contains the 1 (x) x terms
    :param allow_change_output: drop from Rs_out every l that the square cannot produce
    :param allow_zero_outputs: if False (and allow_change_output is False), assert
        that every output l can be produced by the square
    """
    super().__init__()
    self.Rs_in = rs.simplify(Rs_in)
    self.Rs_out = rs.simplify(Rs_out)
    self.linear = linear

    wanted_ls = [l for _, l, _ in self.Rs_out]
    rule = partial(o3.selection_rule, lfilter=lambda l: l in wanted_ls)

    squared_input = [(1, 0, 1)] + self.Rs_in if linear else self.Rs_in
    Rs_ts, T = rs.tensor_square(squared_input, rule)
    register_sparse_buffer(self, 'T', T)  # [out, in1 * in2]

    produced_ls = [l for _, l, _ in Rs_ts]
    if allow_change_output:
        self.Rs_out = [(mul, l, p) for mul, l, p in self.Rs_out if l in produced_ls]
    elif not allow_zero_outputs:
        assert all(l in produced_ls for _, l, _ in self.Rs_out)
    self.kernel = KernelLinear(Rs_ts, self.Rs_out)  # [out, in, w]
def __init__(self, Rs_1, Rs_2, selection_rule=o3.selection_rule):
    """Fixed tensor product of two representations.

    The mixing matrix returned by rs.tensor_product is stored as a buffer and
    the (simplified) output representation as ``self.Rs_out``.
    """
    super().__init__()
    self.Rs_1 = rs.simplify(Rs_1)
    self.Rs_2 = rs.simplify(Rs_2)
    # NOTE: the product is computed from the arguments as passed, not from the simplified copies
    product_Rs, mixing = rs.tensor_product(Rs_1, Rs_2, selection_rule)
    self.Rs_out = rs.simplify(product_Rs)
    self.register_buffer('mixing_matrix', mixing)
def __init__(self, Rs_1, Rs_2, selection_rule=o3.selection_rule):
    """Elementwise tensor product of two representations that have the same
    total multiplicity (asserted).
    """
    super().__init__()
    first, second = rs.simplify(Rs_1), rs.simplify(Rs_2)
    assert sum(mul for mul, _, _ in first) == sum(mul for mul, _, _ in second)

    product_Rs, mixing = rs.elementwise_tensor_product(first, second, selection_rule)
    self.register_buffer("mixing_matrix", mixing)
    self.Rs_out = rs.simplify(product_Rs)
def __init__(self, Rs_in, Rs_out, Rs_sh, RadialModel, groups=math.inf, normalization='component'):
    """Message-passing convolution: a linear self-interaction, a grouped weighted
    tensor product of the inputs with ``Rs_sh`` whose weights are not owned by the
    product, a radial model sized to those weights, and a final linear layer.

    :param Rs_in: input representation
    :param Rs_out: output representation
    :param Rs_sh: representation combined with the inputs in the tensor product
    :param RadialModel: constructor called with the number of tensor-product weights
    :param groups: number of groups of the tensor product
    :param normalization: normalization forwarded to the tensor product
    """
    super().__init__(aggr='add', flow='target_to_source')
    self.Rs_in = rs.simplify(Rs_in)
    self.Rs_out = rs.simplify(Rs_out)
    self.Rs_sh = Rs_sh
    self.normalization = normalization

    self.lin1 = Linear(Rs_in, Rs_out, allow_unused_inputs=True, allow_zero_outputs=True)
    self.tp = GroupedWeightedTensorProduct(
        Rs_in, Rs_sh, Rs_out,
        groups=groups, normalization=normalization, own_weight=False,
    )
    # presumably the radial model supplies the tensor-product weights in forward()
    self.rm = RadialModel(self.tp.nweight)
    self.lin2 = Linear(Rs_out, Rs_out)
def WeightedTensorProduct(Rs_in1, Rs_in2, Rs_out, normalization='component', own_weight=True):
    """Fully connected weighted tensor product.

    Every (input1, input2, output) irrep triple that satisfies the triangle
    inequality |l1 - l2| <= l_out <= l1 + l2 and the parity rule p1 * p2 == p_out
    becomes a 'uvw' instruction of a CustomWeightedTensorProduct.
    """
    Rs_in1 = rs.simplify(Rs_in1)
    Rs_in2 = rs.simplify(Rs_in2)
    Rs_out = rs.simplify(Rs_out)

    instr = []
    for i_1, (_, l_1, p_1) in enumerate(Rs_in1):
        for i_2, (_, l_2, p_2) in enumerate(Rs_in2):
            for i_out, (_, l_out, p_out) in enumerate(Rs_out):
                triangle_ok = abs(l_1 - l_2) <= l_out <= l_1 + l_2
                if triangle_ok and p_1 * p_2 == p_out:
                    instr.append((i_1, i_2, i_out, 'uvw'))

    return CustomWeightedTensorProduct(Rs_in1, Rs_in2, Rs_out, instr, normalization, own_weight)
def __init__(self, Rs_in, Rs_out, Rs_sh, RadialModel, normalization='component'):
    """
    :param Rs_in: input representation
    :param Rs_out: output representation
    :param Rs_sh: representation combined with the inputs in the tensor product
        (the spherical harmonics representation)
    :param RadialModel: model constructor, called with the number of weights of the tensor product
    :param normalization: normalization forwarded to WeightedTensorProduct
    """
    super().__init__(aggr='add', flow='target_to_source')
    self.Rs_in = rs.simplify(Rs_in)
    self.Rs_out = rs.simplify(Rs_out)

    # the tensor product does not own its weights; the radial model is sized to provide them
    self.tp = WeightedTensorProduct(Rs_in, Rs_sh, Rs_out, normalization, own_weight=False)
    self.rm = RadialModel(self.tp.nweight)
    self.Rs_sh = Rs_sh
    self.normalization = normalization
def spherical_harmonics_xyz(Rs, xyz):
    """
    spherical harmonics

    :param Rs: list of L's
    :param xyz: tensor of shape [..., 3]
    :return: tensor of shape [..., m]

    Zero-length vectors have no direction. Instead of producing NaN (division by
    a zero norm), their output is zero except for the l=0 components, which are
    set to sqrt(1 / (4 pi)).
    """
    import math  # local import: keeps this block self-contained

    Rs = rs.simplify(Rs)
    *size, _ = xyz.shape
    xyz = xyz.reshape(-1, 3)

    d = torch.norm(xyz, 2, dim=1)
    nonzero = d > 0
    # new tensor: the in-place rotation below never touches the caller's input
    xyz = xyz[nonzero] / d[nonzero, None]

    sh = None
    if xyz.device.type == 'cuda' and not xyz.requires_grad and rs.lmax(Rs) <= 10:  # pragma: no cover
        try:
            sh = spherical_harmonics_xyz_cuda(Rs, xyz)
        except ImportError:
            pass

    if sh is None:
        # if z > x, rotate x-axis with z-axis
        s = xyz[:, 2].abs() > xyz[:, 0].abs()
        xyz[s] = xyz[s] @ xyz.new_tensor([[0, 0, 1], [0, 1, 0], [-1, 0, 0]])

        alpha = torch.atan2(xyz[:, 1], xyz[:, 0])
        z = xyz[:, 2]
        y = (xyz[:, 0].pow(2) + xyz[:, 1].pow(2)).sqrt()

        sh = spherical_harmonics_alpha_z_y(Rs, alpha, z, y)

        # rotate back
        sh[s] = sh[s] @ _rep_zx(tuple(Rs), xyz.dtype, xyz.device)

    if not bool(nonzero.all()):
        # re-insert rows for the zero vectors: only the constant l=0 harmonic is non-zero
        out = sh.new_zeros(len(d), sh.shape[1])
        out[~nonzero] = math.sqrt(1 / (4 * math.pi)) * torch.cat([
            sh.new_ones(1) if l == 0 else sh.new_zeros(2 * l + 1)
            for mul, l, _ in Rs
            for _ in range(mul)
        ])
        out[nonzero] = sh
        sh = out

    return sh.reshape(*size, sh.shape[1])
def __init__(self, Rs, act, res, normalization='component', lmax_out=None, random_rot=False):
    '''
    map to the sphere, apply the non linearity point wise and project back
    the signal on the sphere is a quasiregular representation of O3
    and we can apply a pointwise operation on these representations

    :param Rs: input representation of the form [(1, l, p0 * u^l) for l in [0, ..., lmax]]
    :param act: activation function
    :param res: resolution of the grid on the sphere (the higher the more accurate)
    :param normalization: either 'norm' or 'component'
    :param lmax_out: maximum l of the output (defaults to the input lmax)
    :param random_rot: rotate randomly the grid
    :raises ValueError: if p0 == -1 and act is neither even nor odd
    '''
    super().__init__()

    Rs = rs.simplify(Rs)
    _, _, p0 = Rs[0]
    _, lmax, _ = Rs[-1]
    # exactly one copy of each l = 0, 1, ..., lmax
    assert all(mul == 1 for mul, _, _ in Rs)
    assert [l for _, l, _ in Rs] == [l for l in range(lmax + 1)]
    # parity is either constant in l (u = 1) or alternating as p0 * (-1)^l (u = -1)
    if all(p == p0 for _, l, p in Rs):
        u = 1
    elif all(p == p0 * (-1)**l for _, l, p in Rs):
        u = -1
    else:
        # NOTE(review): stripped under `python -O`; `u` would then be undefined below
        assert False, "the parity of the input is not well defined"

    self.Rs_in = Rs

    # the input transforms as : A_l ---> p0 * u^l * A_l
    # the sphere signal transforms as : f(r) ---> p0 * f(u * r)

    if lmax_out is None:
        lmax_out = lmax

    if p0 == +1 or p0 == 0:
        self.Rs_out = [(1, l, p0 * u**l) for l in range(lmax_out + 1)]

    if p0 == -1:
        # probe act on [0, 10] to classify it as even, odd, or neither
        x = torch.linspace(0, 10, 256)
        a1, a2 = act(x), act(-x)
        if (a1 - a2).abs().max() < a1.abs().max() * 1e-10:
            # p_act = 1
            self.Rs_out = [(1, l, u**l) for l in range(lmax_out + 1)]
        elif (a1 + a2).abs().max() < a1.abs().max() * 1e-10:
            # p_act = -1
            self.Rs_out = [(1, l, -u**l) for l in range(lmax_out + 1)]
        else:
            # p_act = 0
            raise ValueError("warning! the parity is violated")

    # grid transforms: coefficients up to lmax -> sphere, and sphere -> coefficients up to lmax_out
    self.to_s2 = s2grid.ToS2Grid(lmax, res, normalization=normalization)
    self.from_s2 = s2grid.FromS2Grid(res, lmax_out, normalization=normalization, lmax_in=lmax)
    self.act = act
    self.random_rot = random_rot
def spherical_harmonics_xyz_cuda(Rs, xyz):  # pragma: no cover
    """
    cuda version of spherical_harmonics_xyz

    :param Rs: list of L's
    :param xyz: tensor of shape [..., 3] (normalized to unit length here)
    :return: tensor of shape [..., m]
    :raises ImportError: if the compiled e3nn.cuda_rsh extension is unavailable
    """
    from e3nn import cuda_rsh  # pylint: disable=no-name-in-module, import-outside-toplevel

    Rs = rs.simplify(Rs)
    *size, _ = xyz.size()
    xyz = xyz.reshape(-1, 3)
    xyz = xyz / torch.norm(xyz, 2, -1, keepdim=True)

    lmax = rs.lmax(Rs)
    # one row per (l, m) for l = 0..lmax; presumably filled in place by the extension
    out = xyz.new_empty(((lmax + 1)**2, xyz.size(0)))  # [ filters, batch_size]
    cuda_rsh.real_spherical_harmonics(out, xyz)

    # (-1)^L same as (pi-theta) -> (-1)^(L+m) and 'quantum' norm (-1)^m combined
    # h - halved
    # each lh covers l = 2*lh (4*lh + 1 rows of +1) and l = 2*lh + 1 (4*lh + 3 rows of -1),
    # i.e. every row of degree l is multiplied by (-1)^l
    norm_coef = [elem for lh in range((lmax + 1) // 2) for elem in [1.] * (4 * lh + 1) + [-1.] * (4 * lh + 3)]
    if lmax % 2 == 0:
        # even lmax: the last degree l = lmax is not covered by the pairs above
        norm_coef.extend([1.] * (2 * lmax + 1))
    norm_coef = out.new_tensor(norm_coef).unsqueeze(1)
    out.mul_(norm_coef)

    # select (and repeat per multiplicity) the degree blocks requested by Rs
    if not rs.are_equal(Rs, list(range(lmax + 1))):
        out = torch.cat([out[l**2: (l + 1)**2] for mul, l, _ in Rs for _ in range(mul)])

    return out.T.reshape(*size, out.shape[0])
def kernel_geometric(Rs_in, Rs_out, selection_rule=o3.selection_rule_in_out_sh, normalization='component'):
    """Build the mixing matrix of an equivariant convolution kernel.

    :param Rs_in: input representation
    :param Rs_out: output representation
    :param selection_rule: rule passed to rs.tensor_product to choose the filter irreps
    :param normalization: 'component' or 'norm', the scaling of the spherical harmonics
    :return: (Rs_f, mixing_matrix) with mixing_matrix of shape [out, in, Y, R]
    :raises ValueError: if normalization is neither 'component' nor 'norm'
    """
    # previously an unknown value left `diag` unbound and failed with UnboundLocalError
    if normalization not in ('component', 'norm'):
        raise ValueError(
            "normalization must be 'component' or 'norm', got {!r}".format(normalization)
        )

    # Compute Clebsh-Gordan coefficients
    Rs_f, Q = rs.tensor_product(Rs_in, selection_rule, Rs_out, normalization)  # [out, in, Y]

    # Sort filters representation
    Rs_f, perm = rs.sort(Rs_f)
    Rs_f = rs.simplify(Rs_f)
    Q = torch.einsum('ijk,lk->ijl', Q, perm)
    del perm

    # Normalize the spherical harmonics
    if normalization == 'component':
        diag = torch.ones(rs.irrep_dim(Rs_f))
    else:  # 'norm'
        diag = torch.cat([torch.ones(2 * l + 1) / math.sqrt(2 * l + 1) for _, l, _ in Rs_f])
    norm_Y = math.sqrt(4 * math.pi) * torch.diag(diag)  # [Y, Y]

    # Matrix to dispatch the spherical harmonics
    mat_Y = rs.map_irrep_to_Rs(Rs_f)  # [Rs_f, Y]
    mat_Y = mat_Y @ norm_Y

    # Create the radial model: R+ -> R^n_path
    mat_R = rs.map_mul_to_Rs(Rs_f)  # [Rs_f, R]

    mixing_matrix = torch.einsum('ijk,ky,kw->ijyw', Q, mat_Y, mat_R)  # [out, in, Y, R]
    return Rs_f, mixing_matrix
def __init__(self, Rs, act, n):
    '''
    Pointwise nonlinearity on SO3: the features are sent to signal values at
    n random rotations, act is applied point wise, and the result is projected back.
    This works because a signal on SO3 carries the regular representation.

    :param Rs: input representation, l = 0, 1, 2, ... with multiplicity mul0 * (2l + 1) and no parity
    :param act: activation function
    :param n: number of sampled rotations (the higher the more accurate)
    '''
    super().__init__()
    Rs = rs.simplify(Rs)
    mul0 = Rs[0][0]
    assert all(mul == mul0 * (2 * l + 1) for mul, l, _ in Rs)
    assert [l for _, l, _ in Rs] == list(range(len(Rs)))
    assert all(p == 0 for _, _, p in Rs)
    self.Rs_out = Rs

    n_degrees = len(Rs)
    rows = []
    for R in [o3.rand_rot() for _ in range(n)]:
        abc = o3.rot_to_abc(R)
        # flattened Wigner matrices of every degree, scaled by sqrt(2l + 1)
        rows.append(torch.cat([
            o3.irr_repr(l, *abc).flatten() * (2 * l + 1) ** 0.5
            for l in range(n_degrees)
        ]))
    Z = torch.stack(rows)  # [z, lmn]
    Z.div_(Z.shape[1] ** 0.5)
    self.register_buffer('Z', Z)
    self.act = act
def __init__(self, Rs_in, selection_rule=o3.selection_rule):
    """Fixed tensor square of ``Rs_in`` (output sorted); the mixing matrix is a buffer."""
    super().__init__()
    self.Rs_in = rs.simplify(Rs_in)
    squared_Rs, mixing = rs.tensor_square(Rs_in, selection_rule, sorted=True)
    self.Rs_out = squared_Rs
    self.register_buffer('mixing_matrix', mixing)
def __init__(self, Rs, normalization='component'):
    """Output representation: one even scalar per input multiplicity channel.

    :param Rs: input representation
    :param normalization: stored for use elsewhere in the module
    """
    super().__init__()
    self.Rs_in = rs.simplify(Rs)
    channels = sum(mul for mul, _, _ in self.Rs_in)
    self.Rs_out = [(channels, 0, +1)]
    self.normalization = normalization
def spherical_harmonics_alpha_z_y(Rs, alpha, z, y):
    """
    cpu spherical harmonics from the azimuthal angle alpha and the z / y components

    :param Rs: list of L's
    :param alpha: tensor of azimuthal angles
    :param z: tensor with the same number of elements as alpha
    :param y: tensor with the same number of elements as alpha
    :return: tensor of shape alpha.shape + (m,)
    """
    Rs = rs.simplify(Rs)
    azimuthal = spherical_harmonics_alpha(rs.lmax(Rs), alpha.flatten())  # [z, m]
    polar = spherical_harmonics_z(Rs, z.flatten(), y.flatten())  # [z, l * m]
    combined = mul_m_lm(Rs, azimuthal, polar)
    return combined.reshape(alpha.shape + (polar.shape[1],))
def __init__(self, Rs, acts):
    '''
    Can be used only with scalar fields

    :param acts: list of tuple (multiplicity, activation); a multiplicity of -1
        takes all the channels not claimed by the other entries
    :raises ValueError: if an input with parity ends up with undefined parity
    '''
    super().__init__()

    Rs = rs.simplify(Rs)
    acts = copy.deepcopy(acts)  # acts is mutated below; keep the caller's list intact

    n1 = sum(mul for mul, _, _ in Rs)  # total input channels
    n2 = sum(mul for mul, _ in acts if mul > 0)  # channels explicitly assigned

    for i, (mul, act) in enumerate(acts):
        if mul == -1:
            acts[i] = (n1 - n2, act)
            assert n1 - n2 >= 0

    assert n1 == sum(mul for mul, _ in acts)

    # split entries of Rs and acts in place until they line up one-to-one
    # with equal multiplicities
    i = 0
    while i < len(Rs):
        mul_r, l, p_r = Rs[i]
        mul_a, act = acts[i]

        if mul_r < mul_a:
            acts[i] = (mul_r, act)
            acts.insert(i + 1, (mul_a - mul_r, act))

        if mul_a < mul_r:
            Rs[i] = (mul_a, l, p_r)
            Rs.insert(i + 1, (mul_r - mul_a, l, p_r))
        i += 1

    # probe each activation on [0, 10]: even -> p_act = 1, odd -> -1, neither -> 0
    x = torch.linspace(0, 10, 256)

    Rs_out = []
    for (mul, l, p_in), (mul_a, act) in zip(Rs, acts):
        assert mul == mul_a

        a1, a2 = act(x), act(-x)
        if (a1 - a2).abs().max() < a1.abs().max() * 1e-10:
            p_act = 1
        elif (a1 + a2).abs().max() < a1.abs().max() * 1e-10:
            p_act = -1
        else:
            p_act = 0

        # odd inputs take the activation's parity; even / parity-less inputs keep theirs
        p = p_act if p_in == -1 else p_in
        # NOTE(review): output l is always 0; scalar inputs are expected (docstring) but not asserted
        Rs_out.append((mul, 0, p))

        if p_in != 0 and p == 0:
            raise ValueError("warning! the parity is violated")

    self.Rs_out = Rs_out
    self.acts = acts
def spherical_harmonics_xyz(Rs, xyz, normalization='none'):
    """
    spherical harmonics

    :param Rs: list of L's
    :param xyz: tensor of shape [..., 3]
    :param normalization: 'none', 'component' (scaled by sqrt(4 pi)) or 'norm'
        (degree l scaled by sqrt(4 pi / (2l + 1))); any other value behaves like 'none'
    :return: tensor of shape [..., m]; zero vectors get only the constant l=0 components
    """
    Rs = rs.simplify(Rs)
    *size, _ = xyz.shape
    xyz = xyz.reshape(-1, 3)
    d = torch.norm(xyz, 2, dim=1)
    # only non-zero vectors have a direction; zero vectors are re-inserted below
    xyz = xyz[d > 0]
    xyz = xyz / d[d > 0, None]

    sh = None

    if xyz.device.type == 'cuda' and not xyz.requires_grad and rs.lmax(Rs) <= 10:  # pragma: no cover
        try:
            sh = _spherical_harmonics_xyz_cuda(Rs, xyz)
        except ImportError:
            pass

    if sh is None:
        # if z > x, rotate x-axis with z-axis
        s = xyz[:, 2].abs() > xyz[:, 0].abs()
        xyz[s] = xyz[s] @ xyz.new_tensor([[0, 0, 1], [0, 1, 0], [-1, 0, 0]])

        alpha = torch.atan2(xyz[:, 1], xyz[:, 0])
        z = xyz[:, 2]
        y = xyz[:, :2].norm(dim=1)

        sh = spherical_harmonics_alpha_z_y(Rs, alpha, z, y)

        # rotate back
        sh[s] = sh[s] @ _rep_zx(tuple(Rs), xyz.dtype, xyz.device)

    if len(d) > len(sh):
        # some inputs were zero vectors: their rows keep only the l=0 value sqrt(1 / (4 pi))
        out = sh.new_zeros(len(d), sh.shape[1])
        out[d == 0] = math.sqrt(1 / (4 * math.pi)) * torch.cat([
            sh.new_ones(1) if l == 0 else sh.new_zeros(2 * l + 1)
            for mul, l, p in Rs
            for _ in range(mul)
        ])
        out[d > 0] = sh
        sh = out

    if normalization == 'component':
        sh.mul_(math.sqrt(4 * math.pi))
    if normalization == 'norm':
        sh.mul_(
            torch.cat([
                math.sqrt(4 * math.pi / (2 * l + 1)) * sh.new_ones(2 * l + 1)
                for mul, l, p in Rs
                for _ in range(mul)
            ]))

    return sh.reshape(*size, sh.shape[1])
def __init__(self, Rs_in1, Rs_in2, Rs_out, allow_change_output=False):
    """Learnable map built on the tensor product of two inputs.

    :param Rs_in1: first input representation
    :param Rs_in2: second input representation
    :param Rs_out: output representation
    :param allow_change_output: drop from Rs_out every l the product cannot produce;
        otherwise assert that all output l's are produced
    """
    super().__init__()
    self.Rs_in1 = rs.simplify(Rs_in1)
    self.Rs_in2 = rs.simplify(Rs_in2)
    self.Rs_out = rs.simplify(Rs_out)

    requested_ls = [l for _, l, _ in self.Rs_out]
    rule = partial(o3.selection_rule, lfilter=lambda l: l in requested_ls)
    Rs_tp, T = rs.tensor_product(self.Rs_in1, self.Rs_in2, rule)
    register_sparse_buffer(self, 'T', T)  # [out, in1 * in2]

    produced_ls = [l for _, l, _ in Rs_tp]
    if not allow_change_output:
        assert all(l in produced_ls for _, l, _ in self.Rs_out)
    else:
        self.Rs_out = [(mul, l, p) for mul, l, p in self.Rs_out if l in produced_ls]
    self.kernel = KernelLinear(Rs_tp, self.Rs_out)  # [out, in, w]
def test_weighted_tensor_product():
    """Equivariance check: rotating the inputs equals rotating the output.

    The test runs in float64 and restores the previous default dtype afterwards,
    so it does not change the dtype seen by other tests.
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        Rs_in1 = rs.simplify([1] * 20 + [2] * 4)
        Rs_in2 = rs.simplify([0] * 10 + [1] * 10 + [2] * 5)
        Rs_out = rs.simplify([0] * 3 + [1] * 4)

        tp = WeightedTensorProduct(Rs_in1, Rs_in2, Rs_out, groups=2)

        x1 = rs.randn(20, Rs_in1)
        x2 = rs.randn(20, Rs_in2)

        angles = o3.rand_angles()

        z1 = tp(x1, x2) @ rs.rep(Rs_out, *angles).T
        z2 = tp(x1 @ rs.rep(Rs_in1, *angles).T, x2 @ rs.rep(Rs_in2, *angles).T)

        # also checks that the backward pass runs
        z1.sum().backward()

        assert torch.allclose(z1, z2)
    finally:
        torch.set_default_dtype(previous_dtype)
def __init__(self, Rs_in, Rs_out):
    """
    :param Rs_in: list of triplet (multiplicity, representation order, parity)
    :param Rs_out: list of triplet (multiplicity, representation order, parity)

    representation order = nonnegative integer
    parity = 0 (no parity), 1 (even), -1 (odd)

    One random-normal weight is created for every (output channel, input channel)
    pair whose (order, parity) match.
    """
    super().__init__()
    self.Rs_in = rs.simplify(Rs_in)
    self.Rs_out = rs.simplify(Rs_out)

    n_path = sum(
        mul_out * mul_in
        for mul_out, l_out, p_out in self.Rs_out
        for mul_in, l_in, p_in in self.Rs_in
        if (l_out, p_out) == (l_in, p_in)
    )
    self.weight = torch.nn.Parameter(torch.randn(n_path))
def spherical_harmonics_z(Rs, z, y=None):
    """
    z component of the spherical harmonics, computed by ``legendre``
    (useful to perform fourier transform)

    :param Rs: list of L's; each parity must be 0 or (-1)^l
    :param z: tensor of shape [...]
    :param y: optional tensor forwarded to ``legendre``
    :return: tensor of shape [..., l * m]
    """
    Rs = rs.simplify(Rs)
    for _, l, p in Rs:
        assert p in [0, (-1) ** l]
    # one entry per copy: multiplicity mul repeats degree l mul times
    degrees = [l for mul, l, _ in Rs for _ in range(mul)]
    return legendre(degrees, z, y)  # [..., l * m]
def __init__(self, Rs_in, Rs_out, lmax=3):
    """Convolution with a fixed Gaussian radial model and spherical harmonics
    l = 0..lmax of parity (-1)^l.

    :param Rs_in: input representation
    :param Rs_out: output representation
    :param lmax: maximum degree of the spherical harmonics
    """
    super().__init__(aggr='add', flow='target_to_source')
    RadialModel = partial(
        GaussianRadialModel,
        max_radius=1.2,
        min_radius=0.0,
        number_of_basis=3,
        h=100,
        L=2,
        act=swish,
    )
    Rs_sh = [(1, l, (-1) ** l) for l in range(lmax + 1)]

    self.Rs_in = rs.simplify(Rs_in)
    self.Rs_out = rs.simplify(Rs_out)

    self.lin1 = Linear(Rs_in, Rs_out, allow_unused_inputs=True, allow_zero_outputs=True)
    self.tp = GroupedWeightedTensorProduct(Rs_in, Rs_sh, Rs_out, own_weight=False)
    # the radial model is sized to the tensor product's (not owned) weights
    self.rm = RadialModel(self.tp.nweight)
    self.lin2 = Linear(Rs_out, Rs_out)
    self.Rs_sh = Rs_sh
def kernel_linear(Rs_in, Rs_out):
    """Clebsch-Gordan tensor of a linear (l-preserving) kernel.

    :return: Q from rs.tensor_product, indexed [out, in, w]
    :raises ValueError: if the filter representation does not simplify to exactly one entry
    """
    def selection_rule(l_in, p_in, l_out, p_out):
        # only l_in == l_out couples, through an l=0 filter, and the output
        # parity must match the input's or be 0
        allowed = l_in == l_out and p_out in [0, p_in]
        return [0] if allowed else []

    Rs_f, Q = rs.tensor_product(Rs_in, selection_rule, Rs_out)  # [out, in, w]
    ((_n_path, l, p),) = rs.simplify(Rs_f)
    assert l == 0 and p in [0, 1]
    return Q
def __init__(self, Rs_in, mul, lmax, Rs_out, layers=3):
    """Network of `layers` blocks (TensorSquare -> Linear -> S2Activation) and a
    (TensorSquare -> Linear) tail producing Rs_out.

    :param Rs_in: input representation
    :param mul: number of copies of the activation irreps l = 0..lmax
    :param lmax: maximum l kept by the hidden tensor squares and activations
    :param Rs_out: output representation
    :param layers: number of hidden blocks
    """
    super().__init__()

    Rs = rs.simplify(Rs_in)
    Rs_out = rs.simplify(Rs_out)

    self.layers = []
    for _ in range(layers):
        # tensor product: nonlinear and mixes the l's
        tp = TensorSquare(Rs, selection_rule=partial(o3.selection_rule, lmax=lmax))

        # direct sum: the features are concatenated with their tensor square
        Rs = Rs + tp.Rs_out

        # linear: learned but don't mix l's
        Rs_act = [(1, l) for l in range(lmax + 1)]
        lin = Linear(Rs, mul * Rs_act, allow_unused_inputs=True)

        # s2 nonlinearity
        act = S2Activation(Rs_act, swish, res=20 * (lmax + 1))
        Rs = mul * act.Rs_out

        self.layers += [torch.nn.ModuleList([tp, lin, act])]

    self.layers = torch.nn.ModuleList(self.layers)

    # keep only the degrees that appear in Rs_out
    def lfilter(l):
        return l in [j for _, j, _ in Rs_out]

    tp = TensorSquare(Rs, selection_rule=partial(o3.selection_rule, lfilter=lfilter))
    Rs = Rs + tp.Rs_out
    lin = Linear(Rs, Rs_out, allow_unused_inputs=True)

    self.tail = torch.nn.ModuleList([tp, lin])
def __init__(self, Rs_out, scalar_activation, gate_activation):
    """
    :param Rs_out: list of triplet (multiplicity, representation order, parity)
    :param scalar_activation: nonlinear function applied on l=0 channels
    :param gate_activation: nonlinear function applied on the gates
    :raises ValueError: if any parity is non-zero (use GatedBlockParity instead)

    Every non-scalar entry (mul, l) needs mul extra scalar gates, appended after
    the features in ``self.Rs_in``.
    """
    super().__init__()
    Rs_out = rs.simplify(Rs_out)

    self.scalar_act = scalar_activation
    self.gate_act = gate_activation

    if any(p != 0 for _, _, p in Rs_out):
        raise ValueError("use GatedBlockParity instead")

    features = [(mul, l) for mul, l, _ in Rs_out]
    gates = [(mul, 0) for mul, l, _ in Rs_out if l != 0]

    self.Rs = features
    self.Rs_in = rs.simplify(features + gates)
def from_irrep_tensor(cls, irrep_tensor):
    """Build an instance from an irrep tensor whose parity-stripped representation
    contains each L at most once.

    The components are sorted by L and written into a dense signal of length
    (Lmax + 1)^2; degrees absent from the input stay zero. The signal is allocated
    with the input tensor's dtype and device.

    :param irrep_tensor: object with ``Rs`` and a 1D ``tensor``
    :raises ValueError: if some L has multiplicity greater than 1
    """
    Rs_remove_p = [(mul, L) for mul, L, p in irrep_tensor.Rs]
    Rs, perm = rs.sort(Rs_remove_p)
    Rs = rs.simplify(Rs)
    mul, Ls, _ = zip(*Rs)
    if max(mul) > 1:
        raise ValueError(
            "Cannot have multiplicity greater than 1 for any L. This tensor has a simplified Rs of {}".format(Rs)
        )
    Lmax = max(Ls)
    sorted_tensor = torch.einsum('ij,...j->...i', perm.to_dense(), irrep_tensor.tensor)

    # previously torch.zeros: silently cast/moved the data to the default dtype on CPU
    signal = sorted_tensor.new_zeros((Lmax + 1)**2)
    Rs_idx = 0
    for L in range(Lmax + 1):
        if Rs[Rs_idx][1] == L:
            ten_slice = slice(rs.dim(Rs[:Rs_idx]), rs.dim(Rs[:Rs_idx + 1]))
            signal[L ** 2: (L + 1) ** 2] = sorted_tensor[ten_slice]
            Rs_idx += 1
    return cls(signal)
def __init__(self, Rs, act, n):
    '''
    map to the sphere, apply the non linearity point wise and project back
    the signal on the sphere is a quasiregular representation of O3
    and we can apply a pointwise operation on these representations

    :param Rs: input representation
    :param act: activation function
    :param n: number of point on the sphere (the higher the more accurate);
        2 * n points are actually used (each point and its opposite)
    :raises ValueError: if p0 == -1 and act is neither even nor odd
    '''
    super().__init__()

    Rs = rs.simplify(Rs)
    mul0, _, p0 = Rs[0]
    # same multiplicity for every l, and l = 0, 1, 2, ... without gaps
    assert all(mul0 == mul for mul, _, _ in Rs)
    assert [l for _, l, _ in Rs] == list(range(len(Rs)))
    # parity either constant in l or alternating as p0 * (-1)^l
    assert all(p == p0 for _, l, p in Rs) or all(p == p0 * (-1)**l for _, l, p in Rs)

    if p0 == +1 or p0 == 0:
        self.Rs_out = Rs

    if p0 == -1:
        # probe act on [0, 10] to classify it as even, odd, or neither
        x = torch.linspace(0, 10, 256)
        a1, a2 = act(x), act(-x)
        if (a1 - a2).abs().max() < a1.abs().max() * 1e-10:
            # p_act = 1
            self.Rs_out = [(mul, l, -p) for mul, l, p in Rs]
        elif (a1 + a2).abs().max() < a1.abs().max() * 1e-10:
            # p_act = -1
            self.Rs_out = Rs
        else:
            # p_act = 0
            raise ValueError("warning! the parity is violated")

    # random points together with their opposites, so the sample set is symmetric under inversion
    x = torch.randn(n, 3)
    x = torch.cat([x, -x])
    # NOTE(review): the shape comment below says [lm, z]; the spherical_harmonics_xyz
    # variants visible alongside this code return [..., m], i.e. [z, lm] -- confirm
    Y = o3.spherical_harmonics_xyz(list(range(len(Rs))), x)  # [lm, z]
    self.register_buffer('Y', Y)
    self.act = act
def test_simplify():
    """Adjacent l=0 entries written in different notations merge into one (mul, l, p) triple."""
    expected = [(3, 0, 0)]
    result = rs.simplify([(1, 0), 0, (1, 0)])
    assert result == expected