Exemplo n.º 1
0
 def forward(self, rs):  # noqa: C901
     """Evaluate the wave function for a batch of electron configurations.

     Orbitals are evaluated for all electrons, optionally modified by an
     orbital- or determinant-level backflow returned by ``self.omni``,
     assembled into spin-up/spin-down Slater matrices for every
     configuration in ``self.confs``, combined via ``self.conf_coeff``, and
     finally multiplied by the electronic cusp factors and the Jastrow
     factor ``J``.

     Args:
         rs: electron coordinates with ``rs.shape[:2] == (batch_dim,
             n_elec)``; presumably ``[batch_dim, n_elec, 3]`` -- TODO
             confirm.

     Returns:
         ``(psi, sign)`` if ``self.return_log`` is set, where ``psi`` is the
         log-magnitude (``log|psi|`` here, or the value returned by
         ``sloglindet``) and ``sign`` the detached sign; otherwise ``psi``.

     Raises:
         AssertionError: if the number of electrons in ``rs`` differs from
             ``self.confs.shape[1]``.
     """
     batch_dim, n_elec = rs.shape[:2]
     assert n_elec == self.confs.shape[1]
     n_atoms = len(self.mol)
     coords = self.mol.coords
     # differences between [nuclei; all electrons of the batch] and nuclei
     diffs_nuc = pairwise_diffs(
         torch.cat([coords, rs.flatten(end_dim=1)]), coords)
     dists_elec = pairwise_distance(rs, rs)
     if self.omni:
         # electron rows only; component 3 is square-rooted to obtain the
         # electron-nucleus distances (presumably it holds the squared
         # distance), reshaped to [bs, i, n_atoms]
         dists_nuc = diffs_nuc[n_atoms:, :, 3].sqrt().view(
             batch_dim, n_elec, n_atoms)
     xs = self.mo(diffs_nuc)
     # get orbitals as [bs, 1, i, mu]
     xs = xs.view(batch_dim, 1, n_elec, -1)
     # get jastrow J and backflow fs (as [bs, q, i, mu/nu])
     J, fs = self.omni(dists_nuc, dists_elec) if self.omni else (None, None)
     if fs is not None and self.backflow_type == 'orbital':
         xs = self._backflow_op(xs, fs)
     # form dets as [bs, q, p, i, nu]
     conf_up = self.confs[:, :self.n_up]
     conf_down = self.confs[:, self.n_up:]
     det_up = xs[:, :, :self.n_up, conf_up].transpose(-3, -2)
     det_down = xs[:, :, self.n_up:, conf_down].transpose(-3, -2)
     if fs is not None and self.backflow_type == 'det':
         n_conf = len(self.confs)
         # split the backflow dim into (q, n_conf); sizes are given as
         # unnamed (name, size) pairs
         fs = fs.unflatten(
             1, ((None, fs.shape[1] // n_conf), (None, n_conf)))
         # with open-shell systems, part of the backflow output is not used
         det_up = self._backflow_op(det_up, fs[..., :self.n_up, :self.n_up])
         det_down = self._backflow_op(det_down,
                                      fs[..., self.n_up:, :self.n_down])
     if self.use_sloglindet == 'always' or (
             self.use_sloglindet == 'training' and not self.sampling):
         # NOTE(review): with return_log unset, psi from sloglindet is
         # presumably still a log value but is combined multiplicatively
         # below -- confirm this configuration is not used
         bf_dim = det_up.shape[-4]
         if isinstance(self.conf_coeff, nn.Linear):
             conf_coeff = self.conf_coeff.weight[0]
             # one coefficient per (backflow channel, configuration) pair,
             # matching the flattening of det_up/det_down below
             conf_coeff = conf_coeff.expand(bf_dim,
                                            -1).flatten() / np.sqrt(bf_dim)
         else:
             conf_coeff = det_up.new_ones(1)
         # merge the backflow and configuration dims: [bs, q * p, i, nu]
         det_up = det_up.flatten(start_dim=-4, end_dim=-3).contiguous()
         det_down = det_down.flatten(start_dim=-4, end_dim=-3).contiguous()
         sign, psi = sloglindet(conf_coeff, det_up, det_down)
         sign = sign.detach()
     else:
         if self.return_log:
             sign_up, det_up = eval_log_slater(det_up)
             sign_down, det_down = eval_log_slater(det_down)
             xs = det_up + det_down
             xs_shift = xs.flatten(start_dim=1).max(dim=-1).values
             # replace infinite shifts (presumably -inf when all
             # determinants of a sample vanish) by 0; otherwise the
             # subtraction below gives -inf - (-inf) = nan
             xs_shift = xs_shift.where(~torch.isinf(xs_shift),
                                       xs_shift.new_tensor(0))
             # the exp-normalize trick, to avoid over/underflow of the
             # exponential
             xs = sign_up * sign_down * torch.exp(xs -
                                                  xs_shift[:, None, None])
         else:
             det_up = eval_slater(det_up)
             det_down = eval_slater(det_down)
             xs = det_up * det_down
         # apply conf_coeff over the last (configuration) dim and average
         # over the backflow dim (xs presumably [bs, q, p] here)
         psi = self.conf_coeff(xs).squeeze(dim=-1).mean(dim=-1)
         if self.return_log:
             # add the shift back in log space
             psi, sign = psi.abs().log() + xs_shift, psi.sign().detach()
     if self.cusp_same:
         # same-spin pairs: triu_flat of each spin block (presumably its
         # flattened upper triangle)
         cusp_same = self.cusp_same(
             torch.cat(
                 [
                     triu_flat(dists_elec[:, idxs, idxs])
                     for idxs in self.spin_slices
                 ],
                 dim=1,
             ))
         # opposite-spin pairs: the up-down off-diagonal block
         cusp_anti = self.cusp_anti(
             dists_elec[:, :self.n_up, self.n_up:].flatten(start_dim=1))
         psi = (psi + cusp_same + cusp_anti if self.return_log else psi *
                torch.exp(cusp_same + cusp_anti))
     if J is not None:
         psi = psi + J if self.return_log else psi * torch.exp(J)
     return (psi, sign) if self.return_log else psi
Exemplo n.º 2
0
def test_torch_gto_aos(gtowf, grids):
    """Check that every squared GTO basis function integrates to one.

    The squared basis values on the quadrature grid, weighted by the grid
    weights, are summed over the grid points and compared against 1.
    """
    points = torch.tensor(grids.coords)
    quad_weights = torch.tensor(grids.weights)
    diffs = pairwise_diffs(points, gtowf.mol.coords)
    aos_squared = gtowf.mo.basis(diffs) ** 2
    norms = torch.sum(aos_squared * quad_weights.unsqueeze(1), dim=0)
    assert_allclose(norms, 1)
Exemplo n.º 3
0
 def forward_from_rs(self, rs, coords):
     """Evaluate the module from positions rather than from differences.

     ``rs`` is stacked below ``coords`` (``torch.cat`` along dim 0), the
     differences of every stacked point to ``coords`` are computed with
     ``pairwise_diffs``, and the result is passed to ``self``. Presumably
     ``rs`` is a flat ``[n_points, 3]`` array -- TODO confirm.
     """
     points = torch.cat([coords, rs])
     return self(pairwise_diffs(points, coords))
Exemplo n.º 4
0
 def forward(self, rs):  # noqa: C901
     """Evaluate the wave function for a batch of electron configurations.

     ``self.omni`` (if set) returns a Jastrow factor ``J``, a backflow ``fs``
     and a real-space backflow ``ps``; ``ps`` shifts the electron positions
     before the orbitals are evaluated. The (possibly backflowed) orbitals
     are assembled into Slater matrices for every configuration in
     ``self.confs``, combined via ``self.conf_coeff``, and multiplied by the
     electronic cusp factors and ``J``.

     Args:
         rs: electron coordinates with ``rs.shape[:2] == (batch_dim,
             n_elec)``; presumably ``[batch_dim, n_elec, 3]`` -- TODO
             confirm.

     Returns:
         ``(psi, sign)`` if ``self.return_log`` is set, where ``psi`` is the
         log-magnitude (``log|psi|`` here, or the value returned by
         ``sloglindet``) and ``sign`` the detached sign; otherwise ``psi``.

     Raises:
         AssertionError: if the number of electrons in ``rs`` differs from
             ``self.confs.shape[1]``.
     """
     batch_dim, n_elec = rs.shape[:2]
     assert n_elec == self.confs.shape[1]
     # computed from the original rs, i.e. before the real-space backflow ps
     # is added below; in this method only the cusp factors use them
     dists_elec = pairwise_self_distance(rs, full=True)
     # get jastrow J, backflow fs (as [bs, q, i, mu/nu]), and real-space
     # backflow ps (as [bs, i, 3])
     coords = self.mol.coords
     # omni receives the nuclear coordinates extended by self.dummy_coords
     J, fs, ps = (self.omni(
         rs, torch.cat([self.mol.coords, self.dummy_coords], dim=0))
                  if self.omni else (None, None, None))
     if ps is not None:
         rs = rs + ps
     # differences between [nuclei; all (shifted) electrons of the batch]
     # and nuclei
     diffs_nuc = pairwise_diffs(torch.cat([coords,
                                           rs.flatten(end_dim=1)]), coords)
     if self.omni:
         # electron rows only; the last component is square-rooted to get
         # electron-nucleus distances (presumably it holds the squared
         # distance), reshaped to [bs, i, -1]
         dists_nuc = (diffs_nuc[len(coords):, :,
                                -1].sqrt().view(batch_dim, n_elec, -1))
     xs = self.mo(diffs_nuc)
     # get orbitals as [bs, 1, i, mu]
     xs = xs.view(batch_dim, 1, n_elec, -1)
     if fs is not None and self.backflow_type == 'orbital':
         xs = self._backflow_op(xs, fs, dists_nuc)
     # form dets as [bs, q, p, i, nu]
     n_up = self.n_up
     conf_up, conf_down = self.confs[:, :n_up], self.confs[:, n_up:]
     det_up = xs[:, :, :n_up, conf_up].transpose(-3, -2)
     det_down = xs[:, :, n_up:, conf_down].transpose(-3, -2)
     if fs is not None and self.backflow_type == 'det':
         n_conf = len(self.confs)
         if self.full_determinant:
             # place the up and down matrices as diagonal blocks of a single
             # n_elec x n_elec matrix (off-diagonal blocks start at zero)
             # and apply the backflow to the whole matrix
             fs = fs.unflatten(1, (fs.shape[1] // n_conf, n_conf))
             det_full = fs.new_zeros((*det_up.shape[:3], n_elec, n_elec))
             det_full[..., :n_up, :n_up] = det_up
             det_full[..., n_up:, n_up:] = det_down
             det_up = det_full = self._backflow_op(det_full, fs, dists_nuc)
             # the full matrix now lives in det_up; det_down becomes an
             # empty 0 x 0 matrix (presumably treated as determinant 1 by
             # eval_log_slater / sloglindet -- confirm)
             # NOTE(review): its leading dims are taken from the
             # pre-backflow det_down (dim 1 is 1 here, from the view above)
             # while det_up may now have q > 1; confirm this is compatible,
             # especially in the sloglindet path where both are flattened
             # over (q, p)
             det_down = fs.new_empty((*det_down.shape[:3], 0, 0))
         else:
             # separate backflows for the up and down matrices
             fs = (
                 fs[0].unflatten(1, (fs[0].shape[1] // n_conf, n_conf)),
                 fs[1].unflatten(1, (fs[1].shape[1] // n_conf, n_conf)),
             )
             det_up = self._backflow_op(det_up, fs[0], dists_nuc[:, :n_up])
             det_down = self._backflow_op(det_down, fs[1], dists_nuc[:,
                                                                     n_up:])
     if self.use_sloglindet == 'always' or (
             self.use_sloglindet == 'training' and not self.sampling):
         bf_dim = det_up.shape[-4]
         if isinstance(self.conf_coeff, nn.Linear):
             conf_coeff = self.conf_coeff.weight[0]
             # one coefficient per (backflow channel, configuration) pair,
             # matching the flattening of det_up/det_down below
             # NOTE(review): scaled by 1/sqrt(bf_dim) here, whereas the
             # other path averages over the backflow dim (1/bf_dim);
             # presumably only a constant factor in psi -- confirm
             conf_coeff = conf_coeff.expand(bf_dim,
                                            -1).flatten() / np.sqrt(bf_dim)
         else:
             conf_coeff = det_up.new_ones(1)
         # merge the backflow and configuration dims: [bs, q * p, i, nu]
         det_up = det_up.flatten(start_dim=-4, end_dim=-3).contiguous()
         det_down = det_down.flatten(start_dim=-4, end_dim=-3).contiguous()
         sign, psi = sloglindet(conf_coeff, det_up, det_down)
         sign = sign.detach()
     else:
         if self.return_log:
             sign_up, det_up = eval_log_slater(det_up)
             sign_down, det_down = eval_log_slater(det_down)
             xs = det_up + det_down
             xs_shift = xs.flatten(start_dim=1).max(dim=-1).values
             # replace -inf shifts, to avoid running into nans (see sloglindet)
             xs_shift = xs_shift.where(~torch.isinf(xs_shift),
                                       xs_shift.new_tensor(0))
             # the exp-normalize trick, to avoid over/underflow of the exponential
             xs = sign_up * sign_down * torch.exp(xs -
                                                  xs_shift[:, None, None])
         else:
             det_up = eval_slater(det_up)
             det_down = eval_slater(det_down)
             xs = det_up * det_down
         # apply conf_coeff over the last (configuration) dim and average
         # over the backflow dim (xs presumably [bs, q, p] here)
         psi = self.conf_coeff(xs).squeeze(dim=-1).mean(dim=-1)
         if self.return_log:
             # add the shift back in log space
             psi, sign = psi.abs().log() + xs_shift, psi.sign().detach()
     if self.cusp_same:
         # same-spin pairs: triu_flat of each spin block (presumably its
         # flattened upper triangle)
         cusp_same = self.cusp_same(
             torch.cat(
                 [
                     triu_flat(dists_elec[:, idxs, idxs])
                     for idxs in self.spin_slices
                 ],
                 dim=1,
             ))
         # opposite-spin pairs: the up-down off-diagonal block
         cusp_anti = self.cusp_anti(dists_elec[:, :n_up,
                                               n_up:].flatten(start_dim=1))
         psi = (psi + cusp_same + cusp_anti if self.return_log else psi *
                torch.exp(cusp_same + cusp_anti))
     if J is not None:
         psi = psi + J if self.return_log else psi * torch.exp(J)
     return (psi, sign) if self.return_log else psi