def test_SO3Vec_cat(self, batch1, batch2, batch3, channels1, channels2, channels3, maxl1, maxl2, maxl3):
    """Concatenating SO3Vecs concatenates their channel multiplicities (tau).

    ``so3_torch.cat`` must succeed and produce ``SO3Tau.cat`` of the input
    taus when all batch shapes agree, and raise ``RuntimeError`` otherwise.
    """
    tau1 = [channels1] * (maxl1 + 1)
    tau2 = [channels2] * (maxl2 + 1)
    # Each vector uses its own maxl; previously tau3 reused maxl2, which left
    # the maxl3 parameter unused.
    tau3 = [channels3] * (maxl3 + 1)

    tau12 = SO3Tau.cat([tau1, tau2])
    tau123 = SO3Tau.cat([tau1, tau2, tau3])

    vec1 = SO3Vec.randn(tau1, batch1)
    vec2 = SO3Vec.randn(tau2, batch2)
    vec3 = SO3Vec.randn(tau3, batch3)

    if batch1 == batch2:
        vec12 = so3_torch.cat([vec1, vec2])
        assert vec12.tau == tau12
    else:
        with pytest.raises(RuntimeError):
            so3_torch.cat([vec1, vec2])

    if batch1 == batch2 == batch3:
        vec123 = so3_torch.cat([vec1, vec2, vec3])
        assert vec123.tau == tau123
    else:
        with pytest.raises(RuntimeError):
            so3_torch.cat([vec1, vec2, vec3])
def test_covariance(self, tau, num_channels, maxl, sample_batch):
    """Check that ``CormorantAtomLevel`` is covariant under rotations.

    Applying the Wigner-D matrix to the layer's output must agree (to 1e-5)
    with running the layer on rotated atom representations and on spherical
    harmonics of rotated positions.
    """
    data, _, _ = sample_batch
    positions = data['positions']
    device, dtype = positions.device, positions.dtype

    sph_harms = SphericalHarmonicsRel(maxl - 1, conj=True, device=device, dtype=dtype, cg_dict=None)
    D, R, _ = rot.gen_rot(maxl, device=device, dtype=dtype)

    # Build Atom layer
    tlist = [tau] * maxl
    print(tlist)
    atom_lvl = CormorantAtomLevel(tlist, tlist, maxl, num_channels, 1, 'rand',
                                  device=device, dtype=dtype, cg_dict=None)

    def edge_reps_from(pos):
        # Tile each harmonic `tau` times along the channel axis (-3).
        harmonics, _ = sph_harms(pos, pos)
        return SO3Vec([torch.cat([sph_l] * tau, axis=-3) for sph_l in harmonics])

    # Setup Input
    atom_rep, atom_mask, _, _, atom_positions = prep_input(data, tau, maxl)
    atom_positions_rot = rot.rotate_cart_vec(R, atom_positions)

    edge_reps = edge_reps_from(atom_positions)
    print(edge_reps.shapes)
    print(atom_rep.shapes)

    # Path 1: run the layer, then rotate its output.
    output = atom_lvl(atom_rep, edge_reps, atom_mask).apply_wigner(D)

    # Path 2: rotate the inputs, then run the layer.
    atom_rep_rot = atom_rep.apply_wigner(D)
    edge_reps_rot = edge_reps_from(atom_positions_rot)
    output_from_rot = atom_lvl(atom_rep_rot, edge_reps_rot, atom_mask)

    for ell in range(maxl):
        assert torch.max(torch.abs(output_from_rot[ell] - output[ell])) < 1E-5
def test_SO3Vec_check_cplx_fail(self, batch, maxl, channels):
    """SO3Vec must reject parts whose trailing (real/imag) axis is not of size 2."""
    tau = [channels] * (maxl + 1)
    for bad_last_dim in (1, 3):
        parts = [torch.rand(batch + (t, 2*l+1, bad_last_dim)) for l, t in enumerate(tau)]
        with pytest.raises(ValueError):
            SO3Vec(parts)
def test_SO3Vec_check_batch_fail(self, batch, channels):
    """SO3Vec accepts its parts only when they all share one batch shape.

    ``batch`` holds one batch-shape tuple per degree l.
    """
    maxl = len(batch) - 1
    tau = torch.randint(1, channels + 1, [maxl + 1])
    parts = [torch.rand(b + (t, 2*l+1, 2)) for l, (b, t) in enumerate(zip(batch, tau))]

    batches_agree = len(set(batch)) == 1
    if batches_agree:
        SO3Vec(parts)
        return

    with pytest.raises(ValueError):
        SO3Vec(parts)
def test_so3_vector_so3_vector_mul(self):
    """SO3Vec * SO3Vec is an elementwise complex product and emits a RuntimeWarning."""
    maxl = 2
    middle_dims = (3, 3)

    def random_vec():
        return SO3Vec([torch.randn((2,) + middle_dims + (4, 2 * l + 1, 2)) for l in range(maxl)])

    vector1 = random_vec()
    vector1_numpy = [numpy_from_complex(part) for part in vector1]
    vector2 = random_vec()
    vector2_numpy = [numpy_from_complex(part) for part in vector2]

    expected = [a * b for a, b in zip(vector1_numpy, vector2_numpy)]

    with pytest.warns(RuntimeWarning):
        product = vector1 * vector2
    product_numpy = [numpy_from_complex(part) for part in product]

    for got, want in zip(product_numpy, expected):
        assert np.sum(np.abs(got - want)) < 1E-6
def forward(self, atom_features, atom_mask, ignore, edge_mask, norms):
    """
    Forward pass for :class:`InputLinear` layer.

    Parameters
    ----------
    atom_features : :class:`torch.Tensor`
        Input atom features, i.e., a one-hot embedding of the atom type,
        atom charge, and any other related inputs.
    atom_mask : :class:`torch.Tensor`
        Mask used to account for padded atoms for unequal batch sizes.
    ignore : :class:`torch.Tensor`
        Edge features. Unused. Included only for pedagogical purposes.
    edge_mask : :class:`torch.Tensor`
        Unused. Included only for pedagogical purposes.
    norms : :class:`torch.Tensor`
        Unused. Included only for pedagogical purposes.

    Returns
    -------
    :class:`SO3Vec`
        Processed atom features to be used as input to Clebsch-Gordan layers
        as part of Cormorant: a single l=0 part of shape
        ``atom_features.shape[:2] + (channels_out, 1, 2)``.
    """
    # Padded atoms (mask False) are replaced by self.zero.
    masked = torch.where(atom_mask.unsqueeze(-1), self.lin(atom_features), self.zero)
    target_shape = atom_features.shape[0:2] + (self.channels_out, 1, 2)
    return SO3Vec([masked.view(target_shape)])
def normalize_alms(a_lms: SO3Vec) -> SO3Vec:
    """Rescale ``a_lms`` so that sum_l sum_m |a_lm|^2 = 1 for every batch entry.

    The normalization constant comes from ``get_normalization_constant`` and
    is clamped at 1e-10 before the square root to avoid dividing by zero.
    """
    k = get_normalization_constant(a_lms)  # [batches]
    # Append three singleton axes so the scale broadcasts over (taus, ms, 2).
    scale = torch.sqrt(k.clamp(min=1e-10))[..., None, None, None]  # [batches, 1, 1, 1]
    return SO3Vec([part / scale for part in a_lms])
def test_SO3Vec_init_arb_tau(self, batch, maxl, channels):
    """SO3Vec.rand keeps an arbitrary (randomly drawn) per-degree tau."""
    tau_list = torch.randint(1, channels + 1, [maxl + 1])
    assert SO3Vec.rand(batch, tau_list).tau == tau_list
def test_SO3Vec_init_channels(self, batch, maxl, channels):
    """SO3Vec.rand keeps a constant per-degree tau of ``channels``."""
    n_degrees = maxl + 1
    tau_list = [channels] * n_degrees
    assert SO3Vec.rand(batch, tau_list).tau == tau_list
def concat_so3vecs(so3vecs: List[SO3Vec]) -> SO3Vec:
    """Concatenate SO3Vecs along the batch dimension (dim 0), degree by degree.

    Each part is laid out as (batches, taus, ms, 2). All inputs must carry the
    same set of degrees (``ells``) as the first one.
    """
    assert all(vec.ells == so3vecs[0].ells for vec in so3vecs)
    # zip(*so3vecs) groups the parts of equal degree across all inputs.
    return SO3Vec([torch.cat(parts, dim=0) for parts in zip(*so3vecs)])
def select_taus(vec: SO3Vec, indices: torch.Tensor) -> SO3Vec:
    """Gather channels (taus) from each degree of ``vec`` using ``indices``.

    Each part is laid out as (..., taus, ms, 2). Gathering is done along dim 1
    with ``indices`` broadcast over the (ms, 2) axes, so the result per degree
    is (..., sliced_taus, ms, 2).
    """
    def gather_part(ell):
        index = indices[..., None, None].expand(-1, -1, 2 * ell + 1, 2)
        return torch.gather(vec[ell], dim=1, index=index)

    return SO3Vec([gather_part(ell) for ell in vec.ells])
def select_atomic_covariats(vec: SO3Vec, focus: torch.Tensor) -> SO3Vec:
    """Contract the atom axis of every degree of ``vec`` with ``focus``.

    Per degree: vec [batches, atoms, taus, ms, 2] and focus [batches, atoms]
    give [batches, taus, ms, 2].
    """
    return SO3Vec([
        torch.einsum('ba,batmx->btmx', focus, vec[ell])  # type: ignore
        for ell in vec.ells
    ])
def estimate_alms(y_lms_conj: SO3Vec) -> SO3Vec:
    """Average each degree of ``y_lms_conj`` over all leading (batch) axes.

    The last three axes (taus, ms, 2) are kept; the averaged axes are retained
    with size 1 (``keepdim=True``).
    """
    means = []
    for ell in y_lms_conj.ells:
        part = y_lms_conj[ell]
        batch_axes = list(range(part.dim() - 3))
        means.append(part.mean(dim=batch_axes, keepdim=True))
    return SO3Vec(means)
def select_element(vec: SO3Vec, element_oh: torch.Tensor) -> SO3Vec:
    """Contract the tau axis of each degree of ``vec`` with ``element_oh``.

    Per degree: vec [batches, taus, ms, 2] and element_oh [batches, taus]
    give [batches, ms, 2], which gets a singleton tau axis re-inserted:
    [batches, 1, ms, 2].
    """
    return SO3Vec([
        torch.einsum('bt,btmx->bmx', element_oh, vec[ell]).unsqueeze(dim=-3)  # type: ignore
        for ell in vec.ells
    ])
def test_SO3Vec_mul_scalar(self, batch, maxl, channels):
    """Multiplying an SO3Vec by a Python scalar (either side) scales every part."""
    tau = [channels] * (maxl + 1)
    vec0 = SO3Vec([torch.rand(batch + (t, 2*l+1, 2)) for l, t in enumerate(tau)])

    for scale in (lambda v: 2 * v, lambda v: v * 2.0):
        scaled = scale(vec0)
        assert all(torch.allclose(2*original, result) for original, result in zip(vec0, scaled))
def test_mix_SO3Vec(batch, maxl, channels1, channels2):
    """Smoke test: ``mix`` accepts an SO3Weight mapping tau_in to tau_out."""
    n_degrees = maxl + 1
    tau_in = [channels1] * n_degrees
    tau_out = [channels2] * n_degrees

    test_vec = SO3Vec.rand(batch, tau_in)
    test_weight = SO3Weight.rand(tau_in, tau_out)
    print(test_vec.shapes, test_weight.shapes)

    mix(test_weight, test_vec)
def test_SO3Vec_add_list(self, batch, maxl, channels):
    """Adding a list of per-degree scalars to an SO3Vec (either side) offsets each part."""
    tau = [channels] * (maxl + 1)
    vec0 = SO3Vec([torch.rand(batch + (t, 2*l+1, 2)) for l, t in enumerate(tau)])
    offsets = [torch.rand(1).item() for _ in vec0]

    def check(result, combine):
        assert all(torch.allclose(combine(part, offset), got)
                   for part, offset, got in zip(vec0, offsets, result))

    check(offsets + vec0, lambda part, offset: offset + part)
    check(vec0 + offsets, lambda part, offset: part + offset)
def prep_input(data, taus, maxl):
    """Build random SO3Vec atom features for a batch, plus masks and positions.

    Parameters
    ----------
    data : dict
        Batch dictionary; ``'positions'``, ``'atom_mask'`` and ``'edge_mask'``
        are read.
    taus : int
        Number of channels used for every degree ``l``.
    maxl : int
        Number of degrees generated (``l = 0 .. maxl - 1``).

    Returns
    -------
    tuple
        ``(atom_scalars, atom_mask, edge_scalars, edge_mask, atom_positions)``.
        ``atom_scalars`` is an SO3Vec whose degree-``l`` part has shape
        ``positions.shape[:2] + (taus, 2*l + 1, 2)``; ``edge_scalars`` is an
        empty placeholder tensor.
    """
    atom_positions = data['positions']
    batch_shape = atom_positions.shape[:2]

    # Create the features on the positions' device/dtype so they can be fed to
    # layers constructed with that device/dtype (as test_covariance does).
    atom_scalar_list = [
        torch.randn(batch_shape + (taus, 2 * l + 1, 2),
                    device=atom_positions.device, dtype=atom_positions.dtype)
        for l in range(maxl)
    ]
    atom_scalars = SO3Vec(atom_scalar_list)

    atom_mask = data['atom_mask']
    edge_mask = data['edge_mask']
    edge_scalars = torch.tensor([])

    return atom_scalars, atom_mask, edge_scalars, edge_mask, atom_positions
def test_cg_product_dict_maxl(self, maxl_dict, maxl_prod, maxl1, maxl2, chan, batch):
    """cg_product works iff the CG dictionary covers every requested maxl.

    When ``maxl_dict`` is large enough, the output tau must match
    ``cg_product_tau``; otherwise a ValueError whose message starts with
    'CG Dictionary maxl' must be raised.
    """
    cg_dict = CGDict(maxl=maxl_dict, dtype=torch.double)

    tau1, tau2 = [chan] * (maxl1 + 1), [chan] * (maxl2 + 1)
    rep1 = SO3Vec.rand(batch, tau1, dtype=torch.double)
    rep2 = SO3Vec.rand(batch, tau2, dtype=torch.double)

    # The assertions used to follow the if/else, so the success path hit an
    # undefined `e_info` and the failure path an undefined `cg_prod`.
    if all(maxl_dict >= maxl for maxl in [maxl_prod, maxl1, maxl2]):
        cg_prod = cg_product(cg_dict, rep1, rep2, maxl=maxl_prod)

        tau_out = cg_prod.tau
        # The product is truncated at maxl_prod, so the prediction must be too.
        tau_pred = cg_product_tau(tau1, tau2, maxl=maxl_prod)

        # Test to make sure the output type matches the expected output type
        assert list(tau_out) == list(tau_pred)
    else:
        with pytest.raises(ValueError) as e_info:
            cg_product(cg_dict, rep1, rep2, maxl=maxl_prod)

        assert str(e_info.value).startswith('CG Dictionary maxl')
def test_CGProduct(self, batch, maxl1, maxl2, maxl, channels):
    """CGProduct commutes with rotations: product(D a, D b) == D product(a, b)."""
    maxl_all = max(maxl1, maxl2, maxl)
    D, R, _ = rot.gen_rot(maxl_all)
    cg_dict = CGDict(maxl=maxl_all, dtype=torch.double)
    cg_prod = CGProduct(maxl=maxl, dtype=torch.double, cg_dict=cg_dict)

    vec1 = SO3Vec.randn(SO3Tau([channels] * (maxl1 + 1)), batch, dtype=torch.double)
    vec2 = SO3Vec.randn(SO3Tau([channels] * (maxl2 + 1)), batch, dtype=torch.double)

    rotated1 = vec1.apply_wigner(D, dir='left')
    rotated2 = vec2.apply_wigner(D, dir='left')

    product = cg_prod(vec1, vec2)
    product_of_rotated = cg_prod(rotated1, rotated2)
    rotated_product = product.apply_wigner(D, dir='left')

    max_errors = [(lhs - rhs).abs().max() for lhs, rhs in zip(product_of_rotated, rotated_product)]
    assert all(err < 1e-6 for err in max_errors)
def test_so3_scalar_so3_vector_mul(self, maxl, num_middle):
    """SO3Scalar * SO3Vec (either order) multiplies each part elementwise, broadcasting over m."""
    middle_dims = (3,) * num_middle

    scalar = SO3Scalar([torch.randn((2,) + middle_dims + (4, 2)) for _ in range(maxl)])
    scalar_numpy = [numpy_from_complex(part) for part in scalar]
    vector = SO3Vec([torch.randn((2,) + middle_dims + (4, 2 * ell + 1, 2)) for ell in range(maxl)])
    vector_numpy = [numpy_from_complex(part) for part in vector]

    # The scalar has no m axis, so add one before multiplying.
    expected = [np.expand_dims(s, -1) * v for s, v in zip(scalar_numpy, vector_numpy)]

    def assert_matches(product):
        product_numpy = [numpy_from_complex(part) for part in product]
        for got, want in zip(product_numpy, expected):
            assert np.sum(np.abs(got - want)) < 1E-6

    assert_matches(vector * scalar)
    assert_matches(scalar * vector)
def test_apply_euler(self, batch, channels, maxl):
    """Smoke test: an Euler-angle Wigner-D can be applied to a random SO3Vec."""
    vec = SO3Vec.rand(batch, SO3Tau([channels] * (maxl + 1)), dtype=torch.double)
    so3_torch.apply_wigner(SO3WignerD.euler(maxl, dtype=torch.double), vec)
def unsqueeze_so3vec(vec: SO3Vec, dim: int) -> SO3Vec:
    """Return a new SO3Vec with ``unsqueeze(dim)`` applied to every part."""
    return SO3Vec(list(map(lambda part: part.unsqueeze(dim), vec)))
import torch
import pytest

from cormorant.so3_lib import SO3Tau, SO3Vec, SO3Scalar


def rand_vec(batch, tau):
    """Random SO3Vec whose degree-l part has shape ``batch + (tau[l], 2l+1, 2)``."""
    return SO3Vec([torch.rand(batch + (t, 2*l+1, 2)) for l, t in enumerate(tau)])


class TestSO3Vec:
    """Construction tests for :class:`SO3Vec`."""

    @pytest.mark.parametrize('batch', [(1,), (2,), (7,), (1,1), (2, 2), (7, 7)])
    @pytest.mark.parametrize('maxl', range(3))
    @pytest.mark.parametrize('channels', range(1, 3))
    def test_SO3Vec_init_channels(self, batch, maxl, channels):
        """SO3Vec.rand keeps a constant per-degree tau of ``channels``."""
        n_degrees = maxl + 1
        tau_list = [channels] * n_degrees
        assert SO3Vec.rand(batch, tau_list).tau == tau_list

    @pytest.mark.parametrize('batch', [(1,), (2,), (7,), (1,1), (2, 2), (7, 7)])
    @pytest.mark.parametrize('maxl', range(4))
    @pytest.mark.parametrize('channels', range(1, 4))
    def test_SO3Vec_init_arb_tau(self, batch, maxl, channels):
        """SO3Vec.rand keeps an arbitrary (randomly drawn) per-degree tau."""
        tau_list = torch.randint(1, channels + 1, [maxl + 1])
        assert SO3Vec.rand(batch, tau_list).tau == tau_list
def step(self, observations: List[ObservationType], actions: Optional[np.ndarray] = None) -> dict:
    """Run the policy and value heads on a batch of observations.

    The action is built in four consecutive parts: focus atom, element,
    distance and orientation. For each part:
    - if ``actions`` is given, the value is read from ``actions``;
    - else if ``self.training`` is set, it is sampled from the distribution;
    - otherwise the arg-max (mode) of the distribution is used.

    Parameters
    ----------
    observations : List[ObservationType]
        Batch of observations, parsed with ``self.parse_observations``.
    actions : Optional[np.ndarray]
        Previously taken actions to evaluate. The columns read are 0 (focus),
        1 (element), 2 (distance) and 3:6 (orientation).

    Returns
    -------
    dict
        ``'a'``: actions tensor (batches, subactions); ``'logp'``: summed log
        probability of the four parts; ``'ent'``: entropy summed over the focus
        and element parts only; ``'v'``: value estimate; ``'dists'``: the four
        distributions. ``'actions'`` (converted via ``self.to_action_space``)
        is only added when ``actions`` was not given.
    """
    data = self.parse_observations(observations)

    # Cast action to tensor
    if actions is not None:
        actions = torch.as_tensor(actions, dtype=torch.float, device=self.device)

    # SO3Vec (batches, atoms, taus, ms, 2)
    covariats = self.cg_model(data)

    # Compute invariants
    invariats = self.atomic_scalars(covariats)  # (batches, atoms, inv_feats)

    # Focus
    focus_logits = self.phi_focus(invariats)  # (batches, atoms, 1)
    focus_logits = focus_logits.squeeze(-1)  # (batches, atoms)
    focus_probs = masked_softmax(focus_logits, mask=data['focus_mask'])  # (batches, atoms)
    focus_dist = torch.distributions.Categorical(probs=focus_probs)

    # focus: (batches, 1)
    if actions is not None:
        # Stored actions are floats; round before casting to an index.
        focus = torch.round(actions[:, :1]).long()
    elif self.training:
        focus = focus_dist.sample().unsqueeze(-1)
    else:
        focus = torch.argmax(focus_probs, dim=-1).unsqueeze(-1)

    focus_oh = to_one_hot(focus, num_classes=self.observation_space.canvas_space.size, device=self.device)  # (batches, atoms)

    # Keep only the features of the focused atom.
    focused_cov = so3_tools.select_atomic_covariats(covariats, focus_oh)  # (batches, taus, ms, 2)
    focused_inv = so3_tools.select_atomic_invariats(invariats, focus_oh)  # (batches, feats)

    # Element
    element_logits = self.phi_element(focused_inv)  # (batches, zs)
    element_probs = masked_softmax(element_logits, mask=data['element_mask'])  # (batches, zs)
    element_dist = torch.distributions.Categorical(probs=element_probs)

    # element: (batches, 1)
    if actions is not None:
        element = torch.round(actions[:, 1:2]).long()
    elif self.training:
        element = element_dist.sample().unsqueeze(-1)
    else:
        element = torch.argmax(element_probs, dim=-1).unsqueeze(-1)

    # Crop element
    # Select the block of num_channels_per_element channels belonging to the
    # chosen element: channel index = offset + element * num_channels_per_element.
    offsets = self.channel_offsets.expand(len(observations), -1)  # (batches, channels_per_element)
    indices = offsets + element * self.num_channels_per_element
    element_cov = so3_tools.select_taus(focused_cov, indices=indices)
    element_inv = self.atomic_scalars(element_cov)  # (batches, inv_feats)

    # Distance: Gaussian mixture model
    # gmm_log_probs, d_mean_trans: (batches, gaussians)
    gmm_log_probs, d_mean_trans = self.phi_d(element_inv).split(self.num_gaussians, dim=-1)
    # tanh bounds the means to distance_center +/- distance_half_width.
    distance_mean = torch.tanh(d_mean_trans) * self.distance_half_width + self.distance_center
    # Standard deviations are learned in log space and clamped from below at 1e-6.
    distance_dist = GaussianMixtureModel(log_probs=gmm_log_probs,
                                         means=distance_mean,
                                         stds=torch.exp(self.distance_log_stds).clamp(1e-6))

    # distance: (batches, 1)
    if actions is not None:
        distance = actions[:, 2:3]
    elif self.training:
        # Ensure that the sampled distance is > 0
        distance = distance_dist.sample().clamp(0.001).unsqueeze(-1)
    else:
        distance = distance_dist.argmax().unsqueeze(-1)

    # Condition on distance
    # Broadcast the scalar distance over the element's channels, then turn it
    # into a single-degree SO3Vec and mix it with the element covariants.
    transformed_d = distance.unsqueeze(1).unsqueeze(1).expand(-1, self.num_channels_per_element, 1, -1)
    transformed_d = self.pad_zeros(transformed_d)
    distance_so3 = SO3Vec([transformed_d])
    cond_cov = self.cg_mix(element_cov, distance_so3)

    so3_dist = self.get_so3_distribution(a_lms=cond_cov, empty=data['empty'])

    # so3: (batches, 3)
    if actions is not None:
        orientation = actions[..., 3:6]
    elif self.training:
        orientation = so3_dist.sample()
    else:
        orientation = so3_dist.argmax()

    # Log prob
    log_prob_list = [
        focus_dist.log_prob(focus.squeeze(-1)),
        element_dist.log_prob(element.squeeze(-1)),
        distance_dist.log_prob(distance.squeeze(-1)),
        so3_dist.log_prob(orientation),
    ]
    log_prob = torch.stack(log_prob_list, dim=-1).sum(dim=-1)  # (batches, )

    # Entropy
    # Only the two categorical parts contribute here; the distance and
    # orientation distributions are not included.
    entropy_list = [
        focus_dist.entropy(),
        element_dist.entropy(),
    ]
    entropy = torch.stack(entropy_list, dim=-1).sum(dim=-1)  # (batches, )

    # Value function
    # atom_mask: (batches, atoms)
    # invariants: (batches, atoms, feats)
    trans_invariats = self.phi_trans(invariats)
    # Sum transformed invariants over atoms, weighted by data['value_mask'].
    value_feats = torch.einsum(  # type: ignore
        'ba,baf->bf', data['value_mask'].to(self.dtype), trans_invariats)  # (batches, inv_feats)
    value = self.phi_v(value_feats).squeeze(-1)  # (batches, )

    # Action
    response: Dict[str, Any] = {}
    if actions is None:
        actions = torch.cat([focus.float(), element.float(), distance, orientation], dim=-1)

        # Build correspond action in action space
        response['actions'] = [self.to_action_space(a, o) for a, o in zip(actions, observations)]

    response.update({
        'a': actions,  # (batches, subactions)
        'logp': log_prob,  # (batches, )
        'ent': entropy,  # (batches, )
        'v': value,  # (batches, )
        'dists': [focus_dist, element_dist, distance_dist, so3_dist],
    })

    return response
def spherical_harmonics(cg_dict, pos, maxsh, normalize=True, conj=False, sh_norm='unit'):
    r"""
    Functional form of the Spherical Harmonics. See documentation of
    :class:`SphericalHarmonics` for details.

    Parameters
    ----------
    cg_dict : CGDict
        Clebsch-Gordan dictionary. It is used by :func:`cg_product` to build
        degrees ``l >= 2`` recursively and indexed to rescale them.
    pos : :class:`torch.Tensor`
        Positions with a trailing axis of size 3; all leading axes are kept
        as batch axes of the output.
    maxsh : int
        Maximum degree ``l`` computed.
    normalize : bool
        If True, positions are scaled to unit length first; zero-length
        vectors are mapped to zero.
    conj : bool
        Forwarded to ``pos_to_rep``.
    sh_norm : str
        ``'qm'`` keeps the normalization as computed; ``'unit'`` multiplies
        degree ``l`` by ``sqrt(4 pi / (2l + 1))``.

    Returns
    -------
    :class:`SO3Vec`
        One part per degree ``l = 0 .. maxsh``, each reshaped to
        ``pos.shape[:-1] + part.shape[1:]`` (for ``l = 0`` that is
        ``pos.shape[:-1] + (1, 1, 2)``).

    Raises
    ------
    ValueError
        If ``sh_norm`` is neither ``'qm'`` nor ``'unit'``.
    """
    # Validate before doing any (possibly expensive) CG products.
    if sh_norm not in ('qm', 'unit'):
        raise ValueError(
            'Incorrect choice of spherial harmonic normalization!')

    s = pos.shape[:-1]
    # reshape (not view) so that non-contiguous inputs are accepted too.
    pos = pos.reshape(-1, 3)

    if normalize:
        norm = pos.norm(dim=-1, keepdim=True)
        mask = (norm > 0)
        # Divide by a safe denominator: with a plain `pos / norm`, zero-length
        # vectors produce inf/nan that torch.where masks in the forward pass
        # but which still turn the gradients into nan in the backward pass.
        safe_norm = torch.where(mask, norm, torch.ones_like(norm))
        pos = torch.where(mask, pos / safe_norm, torch.zeros_like(pos))

    # Degree 0 is the constant 1/sqrt(4 pi) with zero imaginary part.
    psi0 = torch.full(s + (1, ), sqrt(1 / (4 * pi)), dtype=pos.dtype, device=pos.device)
    psi0 = torch.stack([psi0, torch.zeros_like(psi0)], -1)
    psi0 = psi0.view(-1, 1, 1, 2)

    sph_harms = [psi0]
    if maxsh >= 1:
        psi1 = pos_to_rep(pos, conj=conj)
        psi1 *= sqrt(3 / (4 * pi))
        sph_harms.append(psi1)
    if maxsh >= 2:
        new_psi = psi1
        for l in range(2, maxsh + 1):
            # Y_l is obtained from the degree-l component of Y_{l-1} (x) Y_1, using
            # Y^{m1}_{l1} \otimes Y^{m2}_{l2} = \sqrt((2*l1+1)(2*l2+1)/4*\pi*(2*l3+1)) <l1 0 l2 0|l3 0> <l1 m1 l2 m2|l3 m3> Y^{m3}_{l3}
            new_psi = cg_product(cg_dict, [new_psi], [psi1], minl=0, maxl=l, ignore_check=True)[-1]
            # Index of <l-1 0 1 0|l 0> in the (1, l-1) CG block:
            # 5*l-4 = (l)^2 -(l-2)^2 + (l-1) + 1, notice indexing starts at l=2
            cg_coeff = cg_dict[(1, l - 1)][5 * (l - 1) + 1, 3 * (l - 1) + 1]
            new_psi *= sqrt((4 * pi * (2 * l + 1)) / (3 * (2 * l - 1))) / cg_coeff
            sph_harms.append(new_psi)

    sph_harms = [part.reshape(s + part.shape[1:]) for part in sph_harms]

    if sh_norm == 'unit':
        sph_harms = [
            part * sqrt((4 * pi) / (2 * ell + 1))
            for ell, part in enumerate(sph_harms)
        ]
    # sh_norm == 'qm': keep the harmonics exactly as computed above.

    return SO3Vec(sph_harms)
def forward(self, features, atom_mask, edge_features, edge_mask, norms):
    """
    Forward pass for :class:`InputMPNN` layer.

    Parameters
    ----------
    features : :class:`torch.Tensor`
        Input atom features, i.e., a one-hot embedding of the atom type,
        atom charge, and any other related inputs.
    atom_mask : :class:`torch.Tensor`
        Mask used to account for padded atoms for unequal batch sizes.
    edge_features : :class:`torch.Tensor`
        Unused. Included only for pedagogical purposes.
    edge_mask : :class:`torch.Tensor`
        Mask used to account for padded edges for unequal batch sizes.
    norms : :class:`torch.Tensor`
        Matrix of relative distances between pairs of atoms.

    Returns
    -------
    :class:`SO3Vec`
        Processed atom features to be used as input to Clebsch-Gordan layers
        as part of Cormorant: a single l=0 part of shape
        ``features.shape[:2] + (channels_out, 1, 2)``.
    """
    # Trailing singleton axis for the atom mask handed to each MLP.
    atom_mask = atom_mask.unsqueeze(-1)
    # The final output is reshaped using the input's leading (batch, atom) axes.
    in_shape = features.shape

    # No separate edge network: masked, learnable radial functions play the
    # role of the adjacency matrix at every message-passing level.
    for mlp, rad_filt, mask in zip(self.mlps, self.rad_filts, self.masks):
        rad = rad_filt(norms, edge_mask)

        # TODO: Real-valued SO3Scalar so we don't need any hacks
        # Hack: keep only the real component of the l=0 radial part and add the
        # trailing axis that MaskLevel expects.
        rad = rad[0][..., 0].unsqueeze(-1)

        # Mask the radial function, then drop the axis again for the einsum.
        edge = mask(rad, edge_mask, norms).squeeze(-1)

        # Message passing as a batched matrix product.
        # b: batch, a: atom, c: channel, x: summed-over neighbour atom.
        messages = torch.einsum('baxc,bxc->bac', edge, features)

        # Messages are concatenated with the current features before the masked MLP.
        features = mlp(torch.cat([messages, features], dim=-1), mask=atom_mask)

    # Reinterpret the MLP output as complex numbers (last axis = real/imag).
    out = features.view(in_shape[0:2] + (self.channels_out, 1, 2))

    return SO3Vec([out])