Esempio n. 1
0
 def __init__(
     self,
     mol,
     basis,
     n_orbitals,
     cusp_correction=True,
     rc_scaling=1.0,
     eps=1e-6,
 ):
     """Set up a bias-free linear map from ``basis`` to orbitals, plus optional cusp correction.

     Args:
         mol: molecule; ``len(mol)``, ``mol.charges`` and ``mol.coords`` are read.
         basis: basis set; ``len(basis)`` is the input width of the linear map and
             ``basis.get_cusp_info(...)`` is queried when cusp correction is on.
         n_orbitals: number of orbitals produced by the linear map.
         cusp_correction: if true, build a ``CuspCorrection`` with per-atom cutoff
             radii and register the ``basis_cusp_info`` buffer; otherwise
             ``self.cusp_corr`` is ``None`` and no buffer is registered.
         rc_scaling: initial cutoff radius of each atom is ``rc_scaling / charge``.
         eps: forwarded to ``CuspCorrection``.
     """
     super().__init__()
     self.n_atoms = len(mol)
     self.n_orbitals = n_orbitals
     self.basis = basis
     self.mo_coeff = nn.Linear(len(basis), n_orbitals, bias=False)
     if not cusp_correction:
         self.cusp_corr = None
         return
     cutoffs = rc_scaling / mol.charges.float()
     nuc_dists = pairwise_distance(mol.coords, mol.coords)
     # Ratio of each inter-atomic distance to the sum of the two cutoffs. The
     # identity adds 1 on the diagonal, so (assuming zero self-distances) an
     # atom's own entry is exactly 1 and the row minimum never exceeds 1.
     identity = torch.eye(len(mol), out=torch.empty_like(nuc_dists))
     ratios = identity + nuc_dists / (cutoffs[:, None] + cutoffs)
     shrink = ratios.min(dim=-1).values
     if (shrink < 0.99).any():
         log.warning('Reducing cusp-correction cutoffs due to overlaps')
     # shrink each cutoff by its atom's tightest overlap ratio
     cutoffs = cutoffs * shrink
     self.cusp_corr = CuspCorrection(mol.charges, n_orbitals, cutoffs, eps=eps)
     self.register_buffer('basis_cusp_info', basis.get_cusp_info(cutoffs).t())
Esempio n. 2
0
 def forward(self, rs):  # noqa: C901
     """Evaluate psi for a batch of electron positions ``rs``.

     Args:
         rs: electron coordinates. Its first two dimensions are read as
             ``(batch_dim, n_elec)`` and ``n_elec`` must equal
             ``self.confs.shape[1]``. Presumably shaped
             ``(batch_dim, n_elec, 3)`` -- TODO confirm.

     Returns:
         ``(psi, sign)`` when ``self.return_log`` is true, else ``psi``.
         On the non-sloglindet log path, ``psi`` is ``log|.|`` of the
         combined determinants (shifted back by the exp-normalize offset)
         and ``sign`` is its detached sign. Cusp terms and the Jastrow ``J``
         are added to ``psi`` on the log path and multiplied in as
         ``exp(...)`` otherwise.
         NOTE(review): the sloglindet branch produces a ``(sign, psi)`` pair
         from ``sloglindet``; if that branch can run with ``return_log``
         false, the multiplicative cusp/Jastrow combination below would be
         applied to its ``psi`` -- confirm this combination is excluded.
     """
     batch_dim, n_elec = rs.shape[:2]
     # electron count must match the width of the configuration table
     assert n_elec == self.confs.shape[1]
     n_atoms = len(self.mol)
     coords = self.mol.coords
     # Nuclei are stacked in front of all batch_dim * n_elec electrons, so the
     # first n_atoms rows of diffs_nuc are nucleus-nucleus differences.
     # NOTE(review): the view of self.mo's output below only works if self.mo
     # discards those n_atoms leading rows -- confirm.
     diffs_nuc = pairwise_diffs(torch.cat([coords,
                                           rs.flatten(end_dim=1)]), coords)
     dists_elec = pairwise_distance(rs, rs)
     if self.omni:
         # electron-nucleus distances as [batch_dim, n_elec, n_atoms], taken
         # from channel 3 of the electron rows (presumably squared distances,
         # hence the sqrt -- confirm against pairwise_diffs)
         dists_nuc = (diffs_nuc[n_atoms:, :,
                                3].sqrt().view(batch_dim, n_elec, n_atoms))
     xs = self.mo(diffs_nuc)
     # get orbitals as [bs, 1, i, mu]
     xs = xs.view(batch_dim, 1, n_elec, -1)
     # get jastrow J and backflow fs (as [bs, q, i, mu/nu])
     J, fs = self.omni(dists_nuc, dists_elec) if self.omni else (None, None)
     # orbital backflow: modify the orbitals before determinants are built
     if fs is not None and self.backflow_type == 'orbital':
         xs = self._backflow_op(xs, fs)
     # form dets as [bs, q, p, i, nu]
     # columns [:n_up] of confs pick orbitals for the spin-up electrons, the
     # remaining columns for the spin-down electrons
     conf_up, conf_down = self.confs[:, :self.n_up], self.confs[:,
                                                                self.n_up:]
     # the transpose moves the configuration axis p in front of electron axis i
     det_up = xs[:, :, :self.n_up, conf_up].transpose(-3, -2)
     det_down = xs[:, :, self.n_up:, conf_down].transpose(-3, -2)
     # determinant backflow: split axis 1 of fs into (q, n_conf) so each
     # configuration gets its own backflow, then apply it per spin block
     if fs is not None and self.backflow_type == 'det':
         n_conf = len(self.confs)
         # NOTE(review): (name, size) pairs are the named-shape form of
         # Tensor.unflatten; check it is supported by the PyTorch in use
         fs = fs.unflatten(1,
                           ((None, fs.shape[1] // n_conf), (None, n_conf)))
         det_up = self._backflow_op(det_up, fs[..., :self.n_up, :self.n_up])
         det_down = self._backflow_op(det_down,
                                      fs[..., self.n_up:, :self.n_down])
         # with open-shell systems, part of the backflow output is not used
     # sloglindet path: all determinants are combined in log space in one call
     if self.use_sloglindet == 'always' or (
             self.use_sloglindet == 'training' and not self.sampling):
         # size of the backflow axis q
         bf_dim = det_up.shape[-4]
         if isinstance(self.conf_coeff, nn.Linear):
             # repeat the configuration weights for each of the q backflow
             # channels, flattened in (q, p) order like the dets below, and
             # scaled by 1/sqrt(q)
             conf_coeff = self.conf_coeff.weight[0]
             conf_coeff = conf_coeff.expand(bf_dim,
                                            -1).flatten() / np.sqrt(bf_dim)
         else:
             # NOTE(review): a single unit coefficient; presumably only valid
             # when there is exactly one (q, p) determinant pair -- confirm
             conf_coeff = det_up.new_ones(1)
         # merge the q and p axes into a single determinant axis
         det_up = det_up.flatten(start_dim=-4, end_dim=-3).contiguous()
         det_down = det_down.flatten(start_dim=-4, end_dim=-3).contiguous()
         sign, psi = sloglindet(conf_coeff, det_up, det_down)
         # the sign is excluded from the autograd graph
         sign = sign.detach()
     else:
         if self.return_log:
             # presumably (sign, log|det|) pairs per determinant
             sign_up, det_up = eval_log_slater(det_up)
             sign_down, det_down = eval_log_slater(det_down)
             xs = det_up + det_down
             # per-sample maximum over all remaining (q, p) log-determinants
             xs_shift = xs.flatten(start_dim=1).max(dim=-1).values
             # the exp-normalize trick, to avoid over/underflow of the exponential
             xs = sign_up * sign_down * torch.exp(xs -
                                                  xs_shift[:, None, None])
         else:
             det_up = eval_slater(det_up)
             det_down = eval_slater(det_down)
             xs = det_up * det_down
         # combine configurations via conf_coeff, drop the trailing singleton
         # axis, then average over the remaining last axis (the backflow axis
         # q, assuming conf_coeff reduces p to size 1 -- confirm)
         psi = self.conf_coeff(xs).squeeze(dim=-1).mean(dim=-1)
         if self.return_log:
             # back to log space, undoing the exp-normalize shift
             psi, sign = psi.abs().log() + xs_shift, psi.sign().detach()
     # electron-electron cusp terms: same-spin pairs within each spin slice
     # (via triu_flat, presumably the flattened upper triangle) and all
     # up-down pairs
     if self.cusp_same:
         cusp_same = self.cusp_same(
             torch.cat(
                 [
                     triu_flat(dists_elec[:, idxs, idxs])
                     for idxs in self.spin_slices
                 ],
                 dim=1,
             ))
         cusp_anti = self.cusp_anti(
             dists_elec[:, :self.n_up, self.n_up:].flatten(start_dim=1))
         psi = (psi + cusp_same + cusp_anti if self.return_log else psi *
                torch.exp(cusp_same + cusp_anti))
     # Jastrow factor: additive on the log path, an exp(J) factor otherwise
     if J is not None:
         psi = psi + J if self.return_log else psi * torch.exp(J)
     return (psi, sign) if self.return_log else psi
Esempio n. 3
0
 def forward(self, rs, coords):
     """Evaluate the SchNet model on electron positions ``rs`` and nuclei ``coords``.

     Returns:
         A 3-tuple: the output of ``self.orbital`` with its trailing
         singleton dimension squeezed and summed over the last remaining
         dimension, followed by two ``None`` placeholders.
     """
     elec_elec = pairwise_distance(rs, rs)
     elec_nuc = pairwise_distance(rs, coords)
     features = self.schnet(elec_elec, elec_nuc)
     head_out = self.orbital(features).squeeze(dim=-1)
     total = head_out.sum(dim=-1)
     return total, None, None