예제 #1
0
 def forward(self, rs):
     """Evaluate this module's functions at a batch of points.

     Each output value is ``anorms[m] * prod_i(rs[..., i] ** ls[m, i]) *
     sum_k(coeffs[k] * exp(-zetas[k] * rs[..., 3]))``.

     Args:
         rs: tensor whose last dimension has 4 entries. The first three are
             raised to the integer powers in ``self.ls``. The fourth is
             multiplied by ``-self.zetas`` inside the exponential. It is
             presumably the precomputed squared norm of the first three
             entries, but that is not checked here. Any number of leading
             batch dimensions is accepted.

     Returns:
         Tensor of shape ``(*batch, n_functions)``. This assumes ``self.ls``
         is an ``(n_functions, 3)`` integer table and that ``self.zetas``,
         ``self.coeffs`` and ``self.anorms`` are 1-D. TODO: confirm against
         the constructor.
     """
     # Separate the three powered components from the fourth channel; the
     # Ellipsis keeps every leading batch dimension intact.
     rs, rs_2 = rs[..., :3], rs[..., 3]
     # Insert a function axis just before the component axis (not after the
     # first dim, which breaks for >1 batch dim). Each point is then raised
     # to every exponent triple in ``self.ls``. The product over the
     # components gives one angular factor per function.
     # NOTE(review): assumes pow_int broadcasts like ``**`` over arbitrary
     # leading dims. This matches the original for 2-D input; confirm for
     # higher ranks.
     angulars = pow_int(rs[..., None, :], self.ls).prod(dim=-1)
     # One exponential per entry along the last axis of ``self.zetas``.
     exps = torch.exp(-self.zetas * rs_2[..., None])
     # Coefficient-weighted sum of the exponentials -> one radial value per point.
     radials = (self.coeffs * exps).sum(dim=-1)
     # Broadcast the radial value across the function axis and apply the
     # per-function normalisation factors.
     phis = self.anorms * angulars * radials[..., None]
     return phis
예제 #2
0
def test_pow_int():
    """pow_int must agree with floating-point ``**`` under broadcasting."""
    # Four random 3-vectors, given a singleton axis so they broadcast
    # against every row of the exponent table below.
    bases = torch.randn(4, 3)[:, None, :]
    # Two rows of integer exponents, including a zero exponent.
    powers = torch.tensor([(1, 2, 3), (0, 1, 2)])
    actual = pow_int(bases, powers)
    assert_allclose(actual, bases ** powers.float())