Пример #1
0
 def forward(self, x, Y, edges_elec, edges_nuc):
     """Aggregate electron and nucleus messages and transform the sum.

     Both edge sets are embedded by the shared ``self.w``. Electron-pair
     embeddings are weighted by ``self.h(x)`` of the partner electron
     ``j``; electron-nucleus embeddings are weighted by ``Y``. Each message
     set is summed over its partner axis (``dim=-2``), the two sums are
     added, and the result is passed through ``self.g``.
     """
     h = self.h(x)
     # The axis just before the last two axes of ``edges_nuc`` is used as
     # the electron count; any leading axes are ignored here.
     *_, n_elec = edges_nuc.shape[:-2]
     # Pair indices from ``idx_perm`` (presumably ordered pairs i != j --
     # confirm against idx_perm).
     ii, jj = idx_perm(n_elec, 2, x.device)
     msg_elec = self.w(edges_elec[..., ii, jj, :]) * h[..., jj, :]
     # ``Y`` gets a singleton axis at -3 so it broadcasts across the
     # electron axis of the nucleus edges.
     msg_nuc = self.w(edges_nuc) * Y[..., None, :, :]
     total = msg_elec.sum(dim=-2) + msg_nuc.sum(dim=-2)
     return self.g(total)
Пример #2
0
def idx_pair_spin(n_up, n_down, device=torch.device('cpu')):  # noqa: B008
    """Split the electron-pair indices from ``idx_perm`` into spin blocks.

    Electrons with index ``< n_up`` are treated as spin-up and the
    remaining ``n_down`` as spin-down.

    Args:
        n_up: number of spin-up electrons (may be 0).
        n_down: number of spin-down electrons (may be 0).
        device: device passed through to ``idx_perm``.

    Returns:
        A list of four ``(label, ij)`` tuples in the order up-up, up-down,
        down-up, down-down. ``label`` is ``'same'`` for the same-spin blocks
        and ``'anti'`` for the opposite-spin blocks. ``ij`` has shape
        ``(2, n_rows, n_cols)``: ``ij[0]`` holds the first and ``ij[1]`` the
        second electron index of each pair, and ``n_rows`` is the size of
        the first electron's spin channel.
    """
    ij = idx_perm(n_up + n_down, 2, device=device)
    mask = ij < n_up
    first_up, second_up = mask[0], mask[1]

    def spin_block(select, n_rows):
        # ``view(2, n_rows, -1)`` raises for an empty selection when
        # ``n_rows == 0`` (a spin channel with no electrons): the ``-1`` size
        # is ambiguous for a 0-element tensor. Compute it explicitly instead;
        # for ``n_rows > 0`` this is exactly the size ``-1`` would infer, and
        # ``view`` still rejects a non-divisible selection as before.
        pairs = ij[:, select]
        n_cols = pairs.shape[1] // n_rows if n_rows else 0
        return pairs.view(2, n_rows, n_cols)

    return [
        ('same', spin_block(first_up & second_up, n_up)),
        ('anti', spin_block(first_up & ~second_up, n_up)),
        ('anti', spin_block(~first_up & second_up, n_down)),
        ('same', spin_block(~first_up & ~second_up, n_down)),
    ]
Пример #3
0
 def forward(self, rs, xs):
     """Displace electron coordinates by a cut-off backflow correction.

     Args:
         rs: electron coordinates. The electron axis is ``-2`` and the last
             axis holds the coordinate components; any number of leading
             batch axes is allowed.
         xs: per-electron features, indexed along the same electron axis
             as ``rs``.

     Returns:
         ``rs + 1e-4 * cutoff * (bf_elec + bf_nuc)``, where ``bf_elec`` sums
         ``self.bf_elec``-weighted electron-electron difference vectors,
         ``bf_nuc`` sums ``self.bf_nuc``-weighted electron-nucleus difference
         vectors, and ``cutoff`` is the product over nuclei of
         ``backflow_cutoff`` applied to the electron-nucleus distances.
     """
     # Read the electron count relative to the trailing axes, like the rest
     # of this method (which indexes with ``...``). The previous
     # ``rs.shape[:2]`` picked the wrong axis whenever ``rs`` was not
     # exactly rank 3.
     n_elec = rs.shape[-2]
     # Pair indices from ``idx_perm`` (presumably ordered pairs i != j --
     # confirm against idx_perm).
     i, j = idx_perm(n_elec, 2, rs.device)
     diffs_elec = rs[..., i, :] - rs[..., j, :]
     # One scalar weight per pair, from the product of the two electrons'
     # features; the trailing singleton broadcasts over the components.
     w_elec = self.bf_elec(xs[..., i, :] * xs[..., j, :]).squeeze(dim=-1)
     bf_elec = (w_elec[..., None] * diffs_elec).sum(dim=-2)
     diffs_nuc = rs[..., :, None, :] - self.mol.coords
     bf_nuc = (self.bf_nuc(xs)[..., None] * diffs_nuc).sum(dim=-2)
     # Product over the nucleus axis (assuming backflow_cutoff is
     # elementwise, this leaves one factor per electron).
     cutoff = backflow_cutoff(diffs_nuc.norm(dim=-1)).prod(dim=-1)
     return rs + 1e-4 * cutoff[..., None] * (bf_elec + bf_nuc)