예제 #1
0
    def test_SO3Vec_cat(self, batch1, batch2, batch3, channels1, channels2, channels3, maxl1, maxl2, maxl3):
        """Concatenating SO3Vecs merges their taus when batch dims match,
        and raises RuntimeError when they do not."""
        tau1 = [channels1] * (maxl1+1)
        tau2 = [channels2] * (maxl2+1)
        # BUG FIX: tau3 was previously built with maxl2 instead of maxl3,
        # so vec3 did not exercise its own maxl parametrization.
        tau3 = [channels3] * (maxl3+1)

        tau12 = SO3Tau.cat([tau1, tau2])
        tau123 = SO3Tau.cat([tau1, tau2, tau3])

        vec1 = SO3Vec.randn(tau1, batch1)
        vec2 = SO3Vec.randn(tau2, batch2)
        vec3 = SO3Vec.randn(tau3, batch3)

        if batch1 == batch2:
            vec12 = so3_torch.cat([vec1, vec2])

            assert vec12.tau == tau12
        else:
            with pytest.raises(RuntimeError):
                vec12 = so3_torch.cat([vec1, vec2])

        if batch1 == batch2 == batch3:
            vec123 = so3_torch.cat([vec1, vec2, vec3])

            assert vec123.tau == tau123
        else:
            # BUG FIX: the result was previously assigned to the misleading
            # name vec12.
            with pytest.raises(RuntimeError):
                vec123 = so3_torch.cat([vec1, vec2, vec3])
예제 #2
0
    def test_covariance(self, tau, num_channels, maxl, sample_batch):
        """Check that CormorantAtomLevel is rotation-covariant: applying the
        layer to rotated inputs must equal applying the layer first and then
        rotating its output with the matching Wigner-D matrices."""
        # setup the environment
        # env = build_environment(tau, maxl, num_channels)
        # datasets, data, num_species, charge_scale, sph_harms = env
        data, __, __ = sample_batch
        device, dtype = data['positions'].device, data['positions'].dtype
        # Relative spherical harmonics up to l = maxl - 1, conjugated.
        sph_harms = SphericalHarmonicsRel(maxl - 1,
                                          conj=True,
                                          device=device,
                                          dtype=dtype,
                                          cg_dict=None)
        # D: Wigner-D matrices; R: the corresponding Cartesian rotation.
        D, R, __ = rot.gen_rot(maxl, device=device, dtype=dtype)

        # Build Atom layer
        tlist = [tau] * maxl
        print(tlist)
        atom_lvl = CormorantAtomLevel(tlist,
                                      tlist,
                                      maxl,
                                      num_channels,
                                      1,
                                      'rand',
                                      device=device,
                                      dtype=dtype,
                                      cg_dict=None)

        # Setup Input
        atom_rep, atom_mask, edge_scalars, edge_mask, atom_positions = prep_input(
            data, tau, maxl)
        atom_positions_rot = rot.rotate_cart_vec(R, atom_positions)

        # Get nonrotated data
        spherical_harmonics, norms = sph_harms(atom_positions, atom_positions)
        # Repeat each harmonic tau times along the channel axis so the edge
        # representation matches the layer's channel count.
        edge_rep_list = [
            torch.cat([sph_l] * tau, axis=-3) for sph_l in spherical_harmonics
        ]
        edge_reps = SO3Vec(edge_rep_list)
        print(edge_reps.shapes)
        print(atom_rep.shapes)

        # Get Rotated output
        output = atom_lvl(atom_rep, edge_reps, atom_mask)
        output = output.apply_wigner(D)

        # Get rotated outputdata
        atom_rep_rot = atom_rep.apply_wigner(D)
        spherical_harmonics_rot, norms = sph_harms(atom_positions_rot,
                                                   atom_positions_rot)
        edge_rep_list_rot = [
            torch.cat([sph_l] * tau, axis=-3)
            for sph_l in spherical_harmonics_rot
        ]
        edge_reps_rot = SO3Vec(edge_rep_list_rot)
        output_from_rot = atom_lvl(atom_rep_rot, edge_reps_rot, atom_mask)

        # The two paths must agree part-by-part within numerical tolerance.
        for i in range(maxl):
            assert (torch.max(torch.abs(output_from_rot[i] - output[i])) <
                    1E-5)
예제 #3
0
    def test_SO3Vec_check_cplx_fail(self, batch, maxl, channels):
        """SO3Vec must reject parts whose trailing (complex) dimension is not 2."""
        tau = [channels] * (maxl + 1)

        # Try both an undersized (1) and an oversized (3) complex dimension.
        for bad_cplx in (1, 3):
            parts = []
            for ell, t in enumerate(tau):
                parts.append(torch.rand(batch + (t, 2 * ell + 1, bad_cplx)))

            with pytest.raises(ValueError) as e:
                SO3Vec(parts)
예제 #4
0
    def test_SO3Vec_check_batch_fail(self, batch, channels):
        """SO3Vec construction must fail with ValueError when the parts carry
        inconsistent batch dimensions, and succeed when they all match."""
        maxl = len(batch) - 1
        tau = torch.randint(1, channels + 1, [maxl + 1])

        parts = []
        for ell, (b, t) in enumerate(zip(batch, tau)):
            parts.append(torch.rand(b + (t, 2 * ell + 1, 2)))

        if len(set(batch)) != 1:
            # Mismatched batch dimensions must be rejected.
            with pytest.raises(ValueError) as e:
                SO3Vec(parts)
        else:
            SO3Vec(parts)
예제 #5
0
    def test_so3_vector_so3_vector_mul(self):
        """Elementwise SO3Vec * SO3Vec matches the numpy complex product."""
        maxl = 2
        mid = (3, 3)

        vec_a = SO3Vec([torch.randn((2,) + mid + (4, 2 * ell + 1, 2)) for ell in range(maxl)])
        vec_b = SO3Vec([torch.randn((2,) + mid + (4, 2 * ell + 1, 2)) for ell in range(maxl)])

        vec_a_np = [numpy_from_complex(part) for part in vec_a]
        vec_b_np = [numpy_from_complex(part) for part in vec_b]
        expected = [pa * pb for pa, pb in zip(vec_a_np, vec_b_np)]

        # Multiplying two SO3Vecs is expected to emit a RuntimeWarning.
        with pytest.warns(RuntimeWarning):
            product = vec_a * vec_b

        product_np = [numpy_from_complex(part) for part in product]
        for got, want in zip(product_np, expected):
            assert np.sum(np.abs(got - want)) < 1E-6
예제 #6
0
    def forward(self, atom_features, atom_mask, ignore, edge_mask, norms):
        """
        Forward pass for :class:`InputLinear` layer.

        Parameters
        ----------
        atom_features : :class:`torch.Tensor`
            Input atom features, i.e., a one-hot embedding of the atom type,
            atom charge, and any other related inputs.
        atom_mask : :class:`torch.Tensor`
            Mask used to account for padded atoms for unequal batch sizes.
        ignore : :class:`torch.Tensor`
            Edge features. Unused. Included only for pedagogical purposes.
        edge_mask : :class:`torch.Tensor`
            Unused. Included only for pedagogical purposes.
        norms : :class:`torch.Tensor`
            Unused. Included only for pedagogical purposes.

        Returns
        -------
        :class:`SO3Vec`
            Processed atom features to be used as input to Clebsch-Gordan layers
            as part of Cormorant.
        """
        # Broadcast the atom mask over the feature dimension.
        atom_mask = atom_mask.unsqueeze(-1)

        # Zero the padded atoms, apply the linear layer everywhere else.
        out = torch.where(atom_mask, self.lin(atom_features), self.zero)
        # Reshape to (batch, atoms, channels_out, 1, 2): a single l=0 part
        # stored as (real, imag) pairs.
        out = out.view(atom_features.shape[0:2] + (self.channels_out, 1, 2))

        return SO3Vec([out])
예제 #7
0
def normalize_alms(a_lms: SO3Vec) -> SO3Vec:
    """Rescale *a_lms* so that \\sum_l \\sum_m |a_lm|^2 = 1 per batch entry."""
    norm_sq = get_normalization_constant(a_lms)  # [batches]
    # Clamp to avoid division by zero, then add three trailing singleton
    # dims so the divisor broadcasts over (taus, ms, 2).
    divisor = torch.sqrt(norm_sq.clamp(min=1e-10))
    divisor = divisor.reshape(divisor.shape + (1, 1, 1))  # [batches, 1, 1, 1]
    return SO3Vec([part / divisor for part in a_lms])
예제 #8
0
    def test_SO3Vec_init_arb_tau(self, batch, maxl, channels):
        """SO3Vec.rand reproduces an arbitrary (random) tau exactly."""
        random_tau = torch.randint(1, channels + 1, [maxl + 1])
        vec = SO3Vec.rand(batch, random_tau)
        assert vec.tau == random_tau
예제 #9
0
    def test_SO3Vec_init_channels(self, batch, maxl, channels):
        """SO3Vec.rand reproduces a uniform per-l channel count exactly."""
        uniform_tau = [channels for _ in range(maxl + 1)]
        vec = SO3Vec.rand(batch, uniform_tau)
        assert vec.tau == uniform_tau
예제 #10
0
def concat_so3vecs(so3vecs: List[SO3Vec]) -> SO3Vec:
    """Concatenate SO3Vecs along the batch (first) dimension.

    Each part of an SO3Vec has dimensions (batches, taus, ms, 2) per ell.
    """
    # All inputs must cover the same set of ells.
    reference = so3vecs[0]
    assert all(vec.ells == reference.ells for vec in so3vecs)

    # zip(*so3vecs) groups the parts of equal ell across all vectors.
    return SO3Vec([torch.cat(parts, dim=0) for parts in zip(*so3vecs)])
예제 #11
0
def select_taus(vec: SO3Vec, indices: torch.Tensor) -> SO3Vec:
    """Gather the tau channels given by *indices* from every part of *vec*."""
    # vec (per ell): (..., taus, ms, 2)
    selected = []
    for ell in vec.ells:
        # Expand the channel indices over the (ms, 2) trailing dims so the
        # same channels are gathered for every m and complex component.
        idx = indices.unsqueeze(-1).unsqueeze(-1)
        idx = idx.expand(-1, -1, 2 * ell + 1, 2)
        selected.append(torch.gather(vec[ell], dim=1, index=idx))

    return SO3Vec(selected)  # (..., sliced_taus, ms, 2)
예제 #12
0
def select_atomic_covariats(vec: SO3Vec, focus: torch.Tensor) -> SO3Vec:
    """Contract the atom dimension of *vec* with per-atom weights *focus*.

    vec (per ell): [batches, atoms, taus, ms, 2]; focus: [batches, atoms].
    """
    contracted = [
        torch.einsum('ba,batmx->btmx', focus, vec[ell])  # type: ignore
        for ell in vec.ells
    ]
    return SO3Vec(contracted)  # (batches, taus, ms, 2)
예제 #13
0
def estimate_alms(y_lms_conj: SO3Vec) -> SO3Vec:
    """Estimate a_lm coefficients by averaging over all batch dimensions.

    Each part of the input has shape (batches..., taus, ms, 2); the mean is
    taken over everything except the trailing (taus, ms, 2) dims.
    """
    averaged = []
    for ell in y_lms_conj.ells:
        part = y_lms_conj[ell]
        batch_dims = list(range(part.dim() - 3))
        averaged.append(part.mean(dim=batch_dims, keepdim=True))
    return SO3Vec(averaged)
예제 #14
0
def select_element(vec: SO3Vec, element_oh: torch.Tensor) -> SO3Vec:
    """Contract the tau dimension of *vec* with one-hot weights *element_oh*.

    vec (per ell): [batches, taus, ms, 2]; element_oh: [batches, taus].
    """
    parts = []
    for ell in vec.ells:
        contracted = torch.einsum('bt,btmx->bmx', element_oh, vec[ell])  # type: ignore # [batches, ms, 2]
        # Reinstate a singleton tau dimension.
        parts.append(contracted.unsqueeze(dim=-3))  # [batches, 1, ms, 2]

    return SO3Vec(parts)  # [batches, 1, ms, 2]
예제 #15
0
    def test_SO3Vec_mul_scalar(self, batch, maxl, channels):
        """Left and right scalar multiplication act part-wise on an SO3Vec."""
        tau = [channels] * (maxl + 1)
        base = SO3Vec([torch.rand(batch + (t, 2 * ell + 1, 2)) for ell, t in enumerate(tau)])

        # Check multiplication from both sides (int and float scalars).
        for scaled in (2 * base, base * 2.0):
            for orig_part, scaled_part in zip(base, scaled):
                assert torch.allclose(2 * orig_part, scaled_part)
예제 #16
0
def test_mix_SO3Vec(batch, maxl, channels1, channels2):
    """mix() should accept an SO3Weight/SO3Vec pair with matching input tau."""
    tau_in = [channels1] * (maxl + 1)
    tau_out = [channels2] * (maxl + 1)

    vec = SO3Vec.rand(batch, tau_in)
    weight = SO3Weight.rand(tau_in, tau_out)

    print(vec.shapes, weight.shapes)
    mix(weight, vec)
예제 #17
0
    def test_SO3Vec_add_list(self, batch, maxl, channels):
        """Adding a list of scalars to an SO3Vec works part-wise, from either side."""
        tau = [channels] * (maxl + 1)
        base = SO3Vec([torch.rand(batch + (t, 2 * ell + 1, 2)) for ell, t in enumerate(tau)])

        # One independent scalar per part.
        shifts = [torch.rand(1).item() for _ in base]

        left_sum = shifts + base
        for part, shift, summed in zip(base, shifts, left_sum):
            assert torch.allclose(shift + part, summed)

        right_sum = base + shifts
        for part, shift, summed in zip(base, shifts, right_sum):
            assert torch.allclose(part + shift, summed)
예제 #18
0
def prep_input(data, taus, maxl):
    """Build random SO3Vec atom features plus masks and positions from a batch."""
    atom_positions = data['positions']
    batch_shape = atom_positions.shape[:2]

    # One random part per ell = 0..maxl-1, each (batch, atoms, taus, 2l+1, 2).
    atom_scalars = SO3Vec([
        torch.randn(batch_shape + (taus, 2 * ell + 1, 2))
        for ell in range(maxl)
    ])

    edge_scalars = torch.tensor([])
    return atom_scalars, data['atom_mask'], edge_scalars, data['edge_mask'], atom_positions
예제 #19
0
    def test_cg_product_dict_maxl(self, maxl_dict, maxl_prod, maxl1, maxl2,
                                  chan, batch):
        """cg_product must succeed (with the expected output tau) when the
        CGDict covers all requested maxl values, and raise a descriptive
        ValueError otherwise."""
        cg_dict = CGDict(maxl=maxl_dict, dtype=torch.double)

        tau1, tau2 = [chan] * (maxl1 + 1), [chan] * (maxl2 + 1)

        rep1 = SO3Vec.rand(batch, tau1, dtype=torch.double)
        rep2 = SO3Vec.rand(batch, tau2, dtype=torch.double)

        if all(maxl_dict >= maxl for maxl in [maxl_prod, maxl1, maxl2]):
            cg_prod = cg_product(cg_dict, rep1, rep2, maxl=maxl_prod)

            # BUG FIX: these checks previously sat after the raising call
            # inside the pytest.raises block below, making them unreachable
            # dead code. They belong on the success path.
            tau_out = cg_prod.tau
            # NOTE(review): the prediction must be truncated to maxl_prod to
            # match the product above — assumes cg_product_tau takes a maxl
            # keyword; confirm against the cormorant cg_lib API.
            tau_pred = cg_product_tau(tau1, tau2, maxl=maxl_prod)

            # Test to make sure the output type matches the expected output type
            assert list(tau_out) == list(tau_pred)
        else:
            with pytest.raises(ValueError) as e_info:
                cg_prod = cg_product(cg_dict, rep1, rep2, maxl=maxl_prod)

            assert str(e_info.value).startswith('CG Dictionary maxl')
예제 #20
0
    def test_CGProduct(self, batch, maxl1, maxl2, maxl, channels):
        """The CG product is covariant: rotating the inputs and then taking
        the product equals taking the product and then rotating the output."""
        maxl_all = max(maxl1, maxl2, maxl)
        D, R, _ = rot.gen_rot(maxl_all)

        cg_dict = CGDict(maxl=maxl_all, dtype=torch.double)
        cg_prod = CGProduct(maxl=maxl, dtype=torch.double, cg_dict=cg_dict)

        tau1 = SO3Tau([channels] * (maxl1 + 1))
        tau2 = SO3Tau([channels] * (maxl2 + 1))

        vec1 = SO3Vec.randn(tau1, batch, dtype=torch.double)
        vec2 = SO3Vec.randn(tau2, batch, dtype=torch.double)

        # Path 1: rotate the inputs first, then take the product.
        rotated_then_prod = cg_prod(vec1.apply_wigner(D, dir='left'),
                                    vec2.apply_wigner(D, dir='left'))

        # Path 2: take the product first, then rotate the result.
        prod_then_rotated = cg_prod(vec1, vec2).apply_wigner(D, dir='left')

        # Both paths must agree part-by-part within numerical tolerance.
        assert all((p1 - p2).abs().max() < 1e-6
                   for p1, p2 in zip(rotated_then_prod, prod_then_rotated))
예제 #21
0
    def test_so3_scalar_so3_vector_mul(self, maxl, num_middle):
        """SO3Scalar * SO3Vec matches the numpy complex product, from either side."""
        middle = (3,) * num_middle
        scalar = SO3Scalar([torch.randn((2,) + middle + (4, 2)) for _ in range(maxl)])
        vector = SO3Vec([torch.randn((2,) + middle + (4, 2 * ell + 1, 2)) for ell in range(maxl)])

        scalar_np = [numpy_from_complex(part) for part in scalar]
        vector_np = [numpy_from_complex(part) for part in vector]
        # Broadcast the scalar over the m dimension before multiplying.
        expected = [np.expand_dims(s, -1) * v for s, v in zip(scalar_np, vector_np)]

        # The product must commute: vector * scalar == scalar * vector.
        for product in (vector * scalar, scalar * vector):
            product_np = [numpy_from_complex(part) for part in product]
            for got, want in zip(product_np, expected):
                assert np.sum(np.abs(got - want)) < 1E-6
예제 #22
0
 def test_apply_euler(self, batch, channels, maxl):
     """A Wigner-D built from Euler angles applies cleanly to a random SO3Vec."""
     tau = SO3Tau([channels] * (maxl + 1))
     random_vec = SO3Vec.rand(batch, tau, dtype=torch.double)
     wigner_d = SO3WignerD.euler(maxl, dtype=torch.double)
     so3_torch.apply_wigner(wigner_d, random_vec)
예제 #23
0
def unsqueeze_so3vec(vec: SO3Vec, dim: int) -> SO3Vec:
    """Insert a singleton dimension at *dim* in every part of *vec*."""
    parts = [part.unsqueeze(dim) for part in vec]
    return SO3Vec(parts)
예제 #24
0
import torch
import pytest

from cormorant.so3_lib import SO3Tau, SO3Vec, SO3Scalar

# PEP 8 (E731): a named callable should be a def, not a lambda assignment.
def rand_vec(batch, tau):
    """Build a random SO3Vec with one part of shape batch + (t, 2*l+1, 2) per (l, t)."""
    return SO3Vec([torch.rand(batch + (t, 2*l+1, 2)) for l, t in enumerate(tau)])

class TestSO3Vec():
    """Construction tests for SO3Vec over various batch shapes and taus."""

    @pytest.mark.parametrize('batch', [(1,), (2,), (7,), (1,1), (2, 2), (7, 7)])
    @pytest.mark.parametrize('maxl', range(3))
    @pytest.mark.parametrize('channels', range(1, 3))
    def test_SO3Vec_init_channels(self, batch, maxl, channels):
        """A uniform per-l channel count is reproduced exactly by SO3Vec.rand."""
        uniform_tau = [channels] * (maxl + 1)
        vec = SO3Vec.rand(batch, uniform_tau)
        assert vec.tau == uniform_tau

    @pytest.mark.parametrize('batch', [(1,), (2,), (7,), (1,1), (2, 2), (7, 7)])
    @pytest.mark.parametrize('maxl', range(4))
    @pytest.mark.parametrize('channels', range(1, 4))
    def test_SO3Vec_init_arb_tau(self, batch, maxl, channels):
        """A random (arbitrary) tau is reproduced exactly by SO3Vec.rand."""
        random_tau = torch.randint(1, channels + 1, [maxl + 1])
        vec = SO3Vec.rand(batch, random_tau)
        assert vec.tau == random_tau
예제 #25
0
File: agent.py  Project: gncs/molgym
    def step(self, observations: List[ObservationType], actions: Optional[np.ndarray] = None) -> dict:
        """Run one policy step.

        Parses the observations, then either samples or evaluates the four
        sub-actions (focus atom, element, distance, SO(3) orientation) and
        computes the log-probability, entropy, and value estimate.

        Parameters
        ----------
        observations : List[ObservationType]
            Batch of environment observations; parsed via ``parse_observations``.
        actions : Optional[np.ndarray]
            If given, a (batches, >=6) array of sub-actions
            [focus, element, distance, orientation(3)] is evaluated instead
            of sampling new ones.

        Returns
        -------
        dict
            Keys 'a' (actions tensor), 'logp', 'ent', 'v', 'dists', and —
            only when sampling — 'actions' mapped back to the action space.
        """
        data = self.parse_observations(observations)

        # Cast action to tensor
        if actions is not None:
            actions = torch.as_tensor(actions, dtype=torch.float, device=self.device)

        # SO3Vec (batches, atoms, taus, ms, 2)
        covariats = self.cg_model(data)

        # Compute invariants
        invariats = self.atomic_scalars(covariats)  # (batches, atoms, inv_feats)

        # Focus
        focus_logits = self.phi_focus(invariats)  # (batches, atoms, 1)
        focus_logits = focus_logits.squeeze(-1)  # (batches, atoms)
        focus_probs = masked_softmax(focus_logits, mask=data['focus_mask'])  # (batches, atoms)
        focus_dist = torch.distributions.Categorical(probs=focus_probs)

        # focus: (batches, 1)
        # Evaluate a given action, sample during training, or take the argmax.
        if actions is not None:
            focus = torch.round(actions[:, :1]).long()
        elif self.training:
            focus = focus_dist.sample().unsqueeze(-1)
        else:
            focus = torch.argmax(focus_probs, dim=-1).unsqueeze(-1)

        focus_oh = to_one_hot(focus, num_classes=self.observation_space.canvas_space.size,
                              device=self.device)  # (batches, atoms)

        focused_cov = so3_tools.select_atomic_covariats(covariats, focus_oh)  # (batches, taus, ms, 2)
        focused_inv = so3_tools.select_atomic_invariats(invariats, focus_oh)  # (batches, feats)

        # Element
        element_logits = self.phi_element(focused_inv)  # (batches, zs)
        element_probs = masked_softmax(element_logits, mask=data['element_mask'])  # (batches, zs)
        element_dist = torch.distributions.Categorical(probs=element_probs)

        # element: (batches, 1)
        if actions is not None:
            element = torch.round(actions[:, 1:2]).long()
        elif self.training:
            element = element_dist.sample().unsqueeze(-1)
        else:
            element = torch.argmax(element_probs, dim=-1).unsqueeze(-1)

        # Crop element
        # Select the channel block belonging to the chosen element.
        offsets = self.channel_offsets.expand(len(observations), -1)  # (batches, channels_per_element)
        indices = offsets + element * self.num_channels_per_element
        element_cov = so3_tools.select_taus(focused_cov, indices=indices)
        element_inv = self.atomic_scalars(element_cov)  # (batches, inv_feats)

        # Distance: Gaussian mixture model
        # gmm_log_probs, d_mean_trans: (batches, gaussians)
        gmm_log_probs, d_mean_trans = self.phi_d(element_inv).split(self.num_gaussians, dim=-1)
        # Squash the raw means into [center - half_width, center + half_width].
        distance_mean = torch.tanh(d_mean_trans) * self.distance_half_width + self.distance_center
        distance_dist = GaussianMixtureModel(log_probs=gmm_log_probs,
                                             means=distance_mean,
                                             stds=torch.exp(self.distance_log_stds).clamp(1e-6))

        # distance: (batches, 1)
        if actions is not None:
            distance = actions[:, 2:3]
        elif self.training:
            # Ensure that the sampled distance is > 0
            distance = distance_dist.sample().clamp(0.001).unsqueeze(-1)
        else:
            distance = distance_dist.argmax().unsqueeze(-1)

        # Condition on distance
        transformed_d = distance.unsqueeze(1).unsqueeze(1).expand(-1, self.num_channels_per_element, 1, -1)
        transformed_d = self.pad_zeros(transformed_d)
        distance_so3 = SO3Vec([transformed_d])
        cond_cov = self.cg_mix(element_cov, distance_so3)

        so3_dist = self.get_so3_distribution(a_lms=cond_cov, empty=data['empty'])

        # so3: (batches, 3)
        if actions is not None:
            orientation = actions[..., 3:6]
        elif self.training:
            orientation = so3_dist.sample()
        else:
            orientation = so3_dist.argmax()

        # Log prob
        # Sub-actions are treated as independent: log probs sum.
        log_prob_list = [
            focus_dist.log_prob(focus.squeeze(-1)),
            element_dist.log_prob(element.squeeze(-1)),
            distance_dist.log_prob(distance.squeeze(-1)),
            so3_dist.log_prob(orientation),
        ]
        log_prob = torch.stack(log_prob_list, dim=-1).sum(dim=-1)  # (batches, )

        # Entropy
        # NOTE(review): only the focus and element entropies are included;
        # distance and orientation entropies are omitted — confirm intended.
        entropy_list = [
            focus_dist.entropy(),
            element_dist.entropy(),
        ]
        entropy = torch.stack(entropy_list, dim=-1).sum(dim=-1)  # (batches, )

        # Value function
        # atom_mask: (batches, atoms)
        # invariants: (batches, atoms, feats)
        trans_invariats = self.phi_trans(invariats)
        value_feats = torch.einsum(  # type: ignore
            'ba,baf->bf', data['value_mask'].to(self.dtype), trans_invariats)  # (batches, inv_feats)
        value = self.phi_v(value_feats).squeeze(-1)  # (batches, )

        # Action
        response: Dict[str, Any] = {}
        if actions is None:
            actions = torch.cat([focus.float(), element.float(), distance, orientation], dim=-1)

            # Build correspond action in action space
            response['actions'] = [self.to_action_space(a, o) for a, o in zip(actions, observations)]

        response.update({
            'a': actions,  # (batches, subactions)
            'logp': log_prob,  # (batches, )
            'ent': entropy,  # (batches, )
            'v': value,  # (batches, )
            'dists': [focus_dist, element_dist, distance_dist, so3_dist],
        })

        return response
예제 #26
0
def spherical_harmonics(cg_dict,
                        pos,
                        maxsh,
                        normalize=True,
                        conj=False,
                        sh_norm='unit'):
    r"""
    Functional form of the Spherical Harmonics. See documentation of
    :class:`SphericalHarmonics` for details.

    Computes :math:`Y_\ell` for :math:`\ell = 0, \dots, maxsh` at the given
    Cartesian positions, building :math:`\ell \geq 2` recursively from the
    Clebsch-Gordan product :math:`Y_{\ell-1} \otimes Y_1`.

    Parameters
    ----------
    cg_dict : CG-coefficient dictionary
        Supplies the Clebsch-Gordan coefficients; indexed by (l1, l2) pairs.
    pos : :class:`torch.Tensor`
        Cartesian positions with trailing dimension 3; all leading
        dimensions are treated as batch dimensions.
    maxsh : int
        Maximum spherical-harmonic degree :math:`\ell` to compute.
    normalize : bool
        If True, positions are normalized to unit vectors first; zero-length
        vectors are left at zero.
    conj : bool
        Passed through to :func:`pos_to_rep` — presumably selects the
        complex-conjugate convention; TODO confirm.
    sh_norm : str
        'qm' keeps the quantum-mechanical normalization as computed;
        'unit' rescales each :math:`\ell` by :math:`\sqrt{4\pi/(2\ell+1)}`.

    Returns
    -------
    :class:`SO3Vec`
        The spherical harmonics, one part per :math:`\ell`.
    """
    # Remember the batch shape; positions are flattened for computation and
    # each part is reshaped back at the end.
    s = pos.shape[:-1]

    pos = pos.view(-1, 3)

    if normalize:
        norm = pos.norm(dim=-1, keepdim=True)
        mask = (norm > 0)
        # pos /= norm
        # pos[pos == inf] = 0
        # Guard against division by zero: zero-length vectors stay zero.
        pos = torch.where(mask, pos / norm, torch.zeros_like(pos))

    # l = 0 harmonic: the constant 1/sqrt(4*pi), stored as (real, imag).
    psi0 = torch.full(s + (1, ),
                      sqrt(1 / (4 * pi)),
                      dtype=pos.dtype,
                      device=pos.device)
    psi0 = torch.stack([psi0, torch.zeros_like(psi0)], -1)
    psi0 = psi0.view(-1, 1, 1, 2)

    sph_harms = [psi0]
    if maxsh >= 1:
        # l = 1 harmonic comes directly from the (normalized) positions.
        psi1 = pos_to_rep(pos, conj=conj)
        psi1 *= sqrt(3 / (4 * pi))
        sph_harms.append(psi1)

    if maxsh >= 2:
        # Build each higher l from the CG product of the previous harmonic
        # with psi1, keeping only the top (l) component, then rescale so the
        # product reproduces the correctly normalized Y_l.
        new_psi = psi1
        for l in range(2, maxsh + 1):
            new_psi = cg_product(cg_dict, [new_psi], [psi1],
                                 minl=0,
                                 maxl=l,
                                 ignore_check=True)[-1]
            # Use equation Y^{m1}_{l1} \otimes Y^{m2}_{l2} = \sqrt((2*l1+1)(2*l2+1)/4*\pi*(2*l3+1)) <l1 0 l2 0|l3 0> <l1 m1 l2 m2|l3 m3> Y^{m3}_{l3}
            # cg_coeff = CGcoeffs[1*(CGmaxL+1) + l-1][5*(l-1)+1, 3*(l-1)+1] # 5*l-4 = (l)^2 -(l-2)^2 + (l-1) + 1, notice indexing starts at l=2
            cg_coeff = cg_dict[(
                1, l - 1
            )][5 * (l - 1) + 1, 3 * (l - 1) +
               1]  # 5*l-4 = (l)^2 -(l-2)^2 + (l-1) + 1, notice indexing starts at l=2
            new_psi *= sqrt(
                (4 * pi * (2 * l + 1)) / (3 * (2 * l - 1))) / cg_coeff
            sph_harms.append(new_psi)
    # Restore the original batch shape on every part.
    sph_harms = [part.view(s + part.shape[1:]) for part in sph_harms]

    if sh_norm == 'qm':
        pass
    elif sh_norm == 'unit':
        sph_harms = [
            part * sqrt((4 * pi) / (2 * ell + 1))
            for ell, part in enumerate(sph_harms)
        ]
    else:
        raise ValueError(
            'Incorrect choice of spherial harmonic normalization!')

    return SO3Vec(sph_harms)
예제 #27
0
    def forward(self, features, atom_mask, edge_features, edge_mask, norms):
        """
        Forward pass for :class:`InputMPNN` layer.

        Parameters
        ----------
        features : :class:`torch.Tensor`
            Input atom features, i.e., a one-hot embedding of the atom type,
            atom charge, and any other related inputs.
        atom_mask : :class:`torch.Tensor`
            Mask used to account for padded atoms for unequal batch sizes.
        edge_features : :class:`torch.Tensor`
            Unused. Included only for pedagogical purposes.
        edge_mask : :class:`torch.Tensor`
            Mask used to account for padded edges for unequal batch sizes.
        norms : :class:`torch.Tensor`
            Matrix of relative distances between pairs of atoms.

        Returns
        -------
        :class:`SO3Vec`
            Processed atom features to be used as input to Clebsch-Gordan layers
            as part of Cormorant. Contains a single l=0 part of shape
            (batch, atoms, channels_out, 1, 2).
        """
        # Unsqueeze the atom mask to match the appropriate dimensions later
        atom_mask = atom_mask.unsqueeze(-1)

        # Get the shape of the input to reshape at the end
        s = features.shape

        # Loop over MPNN levels. There is no "edge network" here.
        # Instead, there is just masked radial functions, that take
        # the role of the adjacency matrix.
        for mlp, rad_filt, mask in zip(self.mlps, self.rad_filts, self.masks):
            # Construct the learnable radial functions
            rad = rad_filt(norms, edge_mask)

            # TODO: Real-valued SO3Scalar so we don't need any hacks
            # Convert to a form that MaskLevel expects
            # Hack to account for the lack of real-valued SO3Scalar and
            # structure of RadialFilters: keep only the real component.
            rad = rad[0][..., 0].unsqueeze(-1)

            # OLD:
            # Convert to a form that MaskLevel expects
            # rad[0] = rad[0].unsqueeze(-1)

            # Mask the position function if desired
            edge = mask(rad, edge_mask, norms)
            # Convert to a form that MatMul expects
            edge = edge.squeeze(-1)

            # Now pass messages using matrix multiplication with the edge features
            # Einsum b: batch, a: atom, c: channel, x: to be summed over
            features_mp = torch.einsum('baxc,bxc->bac', edge, features)

            # Concatenate the passed messages with the original features
            features_mp = torch.cat([features_mp, features], dim=-1)

            # Now apply a masked MLP
            features = mlp(features_mp, mask=atom_mask)

        # The output are the MLP features reshaped into a set of complex numbers.
        # Shape: (batch, atoms, channels_out, 2l+1=1, complex=2) — an l=0 part.
        out = features.view(s[0:2] + (self.channels_out, 1, 2))

        return SO3Vec([out])