def forward(self, det_up: torch.Tensor, det_down: torch.Tensor) -> Tuple:
    """Combine spin-up and spin-down determinant tensors via ``sloglindet``.

    Args:
        det_up: spin-up determinant matrices, forwarded unchanged to
            ``sloglindet``.
        det_down: spin-down determinant matrices, forwarded unchanged to
            ``sloglindet``.

    Returns:
        The ``(sign, log_psi)`` pair produced by
        ``sloglindet(self.w, det_up, det_down)``, with the sign detached from
        the autograd graph so that gradients flow only through ``log_psi``.
    """
    result_sign, result_log = sloglindet(self.w, det_up, det_down)
    # no gradient is propagated through the sign
    return result_sign.detach(), result_log
def forward(self, rs):  # noqa: C901
    """Evaluate the wave-function ansatz for a batch of electron configurations.

    Builds orbitals from electron-nucleus differences (``self.mo``), optionally
    obtains a Jastrow factor ``J`` and backflow ``fs`` from ``self.omni``,
    assembles spin-up/spin-down determinants for every configuration in
    ``self.confs``, combines them with ``self.conf_coeff`` and finally applies
    the optional same-spin/opposite-spin cusp terms and the Jastrow factor.

    Args:
        rs: electron coordinates; ``rs.shape[:2]`` must be
            ``(batch_dim, n_elec)`` with ``n_elec == self.confs.shape[1]``
            (presumably ``[batch_dim, n_elec, 3]`` -- confirm against callers).

    Returns:
        ``(psi, sign)`` when ``self.return_log`` is set, where ``psi`` is the
        log-magnitude and ``sign`` is detached from the autograd graph;
        otherwise just ``psi``.
    """
    batch_dim, n_elec = rs.shape[:2]
    assert n_elec == self.confs.shape[1]
    n_atoms = len(self.mol)
    coords = self.mol.coords
    # nuclei are prepended to the (batch-flattened) electrons, so the first
    # n_atoms rows of diffs_nuc are nucleus-nucleus entries
    diffs_nuc = pairwise_diffs(torch.cat([coords, rs.flatten(end_dim=1)]), coords)
    dists_elec = pairwise_distance(rs, rs)
    if self.omni:
        # electron rows only; component 3 is square-rooted into distances
        dists_nuc = (
            diffs_nuc[n_atoms:, :, 3].sqrt().view(batch_dim, n_elec, n_atoms)
        )
    xs = self.mo(diffs_nuc)
    # get orbitals as [bs, 1, i, mu]
    xs = xs.view(batch_dim, 1, n_elec, -1)
    # get jastrow J and backflow fs (as [bs, q, i, mu/nu])
    J, fs = self.omni(dists_nuc, dists_elec) if self.omni else (None, None)
    if fs is not None and self.backflow_type == 'orbital':
        xs = self._backflow_op(xs, fs)
    # form dets as [bs, q, p, i, nu]
    conf_up, conf_down = self.confs[:, :self.n_up], self.confs[:, self.n_up:]
    det_up = xs[:, :, :self.n_up, conf_up].transpose(-3, -2)
    det_down = xs[:, :, self.n_up:, conf_down].transpose(-3, -2)
    if fs is not None and self.backflow_type == 'det':
        n_conf = len(self.confs)
        # NOTE(review): legacy named-shape form of Tensor.unflatten; newer
        # PyTorch releases document a plain tuple of ints -- confirm against
        # the pinned torch version
        fs = fs.unflatten(1, ((None, fs.shape[1] // n_conf), (None, n_conf)))
        det_up = self._backflow_op(det_up, fs[..., :self.n_up, :self.n_up])
        det_down = self._backflow_op(det_down, fs[..., self.n_up:, :self.n_down])
        # with open-shell systems, part of the backflow output is not used
    if self.use_sloglindet == 'always' or (
        self.use_sloglindet == 'training' and not self.sampling
    ):
        # NOTE(review): this branch always yields a log-magnitude in psi;
        # presumably it is only enabled together with return_log -- confirm
        bf_dim = det_up.shape[-4]
        if isinstance(self.conf_coeff, nn.Linear):
            conf_coeff = self.conf_coeff.weight[0]
            conf_coeff = conf_coeff.expand(bf_dim, -1).flatten() / np.sqrt(bf_dim)
        else:
            conf_coeff = det_up.new_ones(1)
        det_up = det_up.flatten(start_dim=-4, end_dim=-3).contiguous()
        det_down = det_down.flatten(start_dim=-4, end_dim=-3).contiguous()
        sign, psi = sloglindet(conf_coeff, det_up, det_down)
        sign = sign.detach()
    else:
        if self.return_log:
            sign_up, det_up = eval_log_slater(det_up)
            sign_down, det_down = eval_log_slater(det_down)
            xs = det_up + det_down
            # the exp-normalize trick, to avoid over/underflow of the exponential
            xs_shift = xs.flatten(start_dim=1).max(dim=-1).values
            # replace infinite shifts: when all determinants vanish the shift
            # is -inf and xs - xs_shift would be -inf - (-inf) = nan
            xs_shift = xs_shift.where(~torch.isinf(xs_shift), xs_shift.new_tensor(0))
            xs = sign_up * sign_down * torch.exp(xs - xs_shift[:, None, None])
        else:
            det_up = eval_slater(det_up)
            det_down = eval_slater(det_down)
            xs = det_up * det_down
        psi = self.conf_coeff(xs).squeeze(dim=-1).mean(dim=-1)
        if self.return_log:
            psi, sign = psi.abs().log() + xs_shift, psi.sign().detach()
    if self.cusp_same:
        cusp_same = self.cusp_same(
            torch.cat(
                [
                    triu_flat(dists_elec[:, idxs, idxs])
                    for idxs in self.spin_slices
                ],
                dim=1,
            )
        )
        cusp_anti = self.cusp_anti(
            dists_elec[:, :self.n_up, self.n_up:].flatten(start_dim=1)
        )
        psi = (
            psi + cusp_same + cusp_anti
            if self.return_log
            else psi * torch.exp(cusp_same + cusp_anti)
        )
    if J is not None:
        psi = psi + J if self.return_log else psi * torch.exp(J)
    return (psi, sign) if self.return_log else psi
def forward(self, rs):  # noqa: C901
    """Evaluate the wave-function ansatz for a batch of electron configurations.

    Optionally obtains a Jastrow factor ``J``, determinant/orbital backflow
    ``fs`` and a real-space backflow shift ``ps`` from ``self.omni``, shifts
    the electron coordinates by ``ps``, builds orbitals from electron-nucleus
    differences (``self.mo``), assembles determinants for every configuration
    in ``self.confs`` (either separate spin blocks or, with
    ``self.full_determinant``, one block-diagonal matrix over all electrons),
    combines them with ``self.conf_coeff`` and finally applies the optional
    same-spin/opposite-spin cusp terms and the Jastrow factor.

    Args:
        rs: electron coordinates; ``rs.shape[:2]`` must be
            ``(batch_dim, n_elec)`` with ``n_elec == self.confs.shape[1]``
            (presumably ``[batch_dim, n_elec, 3]``, matching ``ps``).

    Returns:
        ``(psi, sign)`` when ``self.return_log`` is set, where ``psi`` is the
        log-magnitude and ``sign`` is detached from the autograd graph;
        otherwise just ``psi``.
    """
    batch_dim, n_elec = rs.shape[:2]
    assert n_elec == self.confs.shape[1]
    # NOTE(review): computed before the real-space backflow shift below, so
    # the cusp terms see the unshifted electron positions -- confirm intent
    dists_elec = pairwise_self_distance(rs, full=True)
    # get jastrow J, backflow fs (as [bs, q, i, mu/nu]), and real-space
    # backflow ps (as [bs, i, 3])
    coords = self.mol.coords
    J, fs, ps = (
        self.omni(rs, torch.cat([coords, self.dummy_coords], dim=0))
        if self.omni
        else (None, None, None)
    )
    if ps is not None:
        rs = rs + ps
    # nuclei are prepended to the (batch-flattened) electrons
    diffs_nuc = pairwise_diffs(torch.cat([coords, rs.flatten(end_dim=1)]), coords)
    if self.omni:
        # electron rows only; the last component is square-rooted into distances
        dists_nuc = (
            diffs_nuc[len(coords):, :, -1].sqrt().view(batch_dim, n_elec, -1)
        )
    xs = self.mo(diffs_nuc)
    # get orbitals as [bs, 1, i, mu]
    xs = xs.view(batch_dim, 1, n_elec, -1)
    if fs is not None and self.backflow_type == 'orbital':
        xs = self._backflow_op(xs, fs, dists_nuc)
    # form dets as [bs, q, p, i, nu]
    n_up = self.n_up
    conf_up, conf_down = self.confs[:, :n_up], self.confs[:, n_up:]
    det_up = xs[:, :, :n_up, conf_up].transpose(-3, -2)
    det_down = xs[:, :, n_up:, conf_down].transpose(-3, -2)
    if fs is not None and self.backflow_type == 'det':
        n_conf = len(self.confs)
        if self.full_determinant:
            # a single determinant over all electrons, spin blocks on the diagonal
            fs = fs.unflatten(1, (fs.shape[1] // n_conf, n_conf))
            det_full = fs.new_zeros((*det_up.shape[:3], n_elec, n_elec))
            det_full[..., :n_up, :n_up] = det_up
            det_full[..., n_up:, n_up:] = det_down
            det_up = self._backflow_op(det_full, fs, dists_nuc)
            # the down-spin factor becomes an empty (0x0) matrix; it takes the
            # leading dims of the *backflowed* det_up so that both tensors
            # flatten to the same number of entries in the sloglindet branch
            det_down = fs.new_empty((*det_up.shape[:-2], 0, 0))
        else:
            fs = (
                fs[0].unflatten(1, (fs[0].shape[1] // n_conf, n_conf)),
                fs[1].unflatten(1, (fs[1].shape[1] // n_conf, n_conf)),
            )
            det_up = self._backflow_op(det_up, fs[0], dists_nuc[:, :n_up])
            det_down = self._backflow_op(det_down, fs[1], dists_nuc[:, n_up:])
    if self.use_sloglindet == 'always' or (
        self.use_sloglindet == 'training' and not self.sampling
    ):
        # NOTE(review): this branch always yields a log-magnitude in psi;
        # presumably it is only enabled together with return_log -- confirm
        bf_dim = det_up.shape[-4]
        if isinstance(self.conf_coeff, nn.Linear):
            conf_coeff = self.conf_coeff.weight[0]
            conf_coeff = conf_coeff.expand(bf_dim, -1).flatten() / np.sqrt(bf_dim)
        else:
            conf_coeff = det_up.new_ones(1)
        det_up = det_up.flatten(start_dim=-4, end_dim=-3).contiguous()
        det_down = det_down.flatten(start_dim=-4, end_dim=-3).contiguous()
        sign, psi = sloglindet(conf_coeff, det_up, det_down)
        sign = sign.detach()
    else:
        if self.return_log:
            sign_up, det_up = eval_log_slater(det_up)
            sign_down, det_down = eval_log_slater(det_down)
            xs = det_up + det_down
            # the exp-normalize trick, to avoid over/underflow of the exponential
            xs_shift = xs.flatten(start_dim=1).max(dim=-1).values
            # replace -inf shifts, to avoid running into nans (see sloglindet)
            xs_shift = xs_shift.where(~torch.isinf(xs_shift), xs_shift.new_tensor(0))
            xs = sign_up * sign_down * torch.exp(xs - xs_shift[:, None, None])
        else:
            det_up = eval_slater(det_up)
            det_down = eval_slater(det_down)
            xs = det_up * det_down
        psi = self.conf_coeff(xs).squeeze(dim=-1).mean(dim=-1)
        if self.return_log:
            psi, sign = psi.abs().log() + xs_shift, psi.sign().detach()
    if self.cusp_same:
        cusp_same = self.cusp_same(
            torch.cat(
                [
                    triu_flat(dists_elec[:, idxs, idxs])
                    for idxs in self.spin_slices
                ],
                dim=1,
            )
        )
        cusp_anti = self.cusp_anti(dists_elec[:, :n_up, n_up:].flatten(start_dim=1))
        psi = (
            psi + cusp_same + cusp_anti
            if self.return_log
            else psi * torch.exp(cusp_same + cusp_anti)
        )
    if J is not None:
        psi = psi + J if self.return_log else psi * torch.exp(J)
    return (psi, sign) if self.return_log else psi