Esempio n. 1
0
 def __init__(
     self,
     mol,
     dist_feat_dim,
     n_up,
     n_down,
     n_orbitals,
     n_channels,
     *,
     embedding_dim=128,
     with_jastrow=True,
     n_jastrow_layers=3,
     with_backflow=True,
     n_backflow_layers=3,
     with_r_backflow=False,
     schnet_kwargs=None,
     subnet_kwargs=None,
 ):
     """Build the SchNet electron embedding and its optional output heads.

     Args:
         mol: object whose ``len()`` is passed to ``ElectronicSchNet``; it is
             also handed to ``Backflow`` when ``with_r_backflow`` is true.
         dist_feat_dim: distance-feature size forwarded to ``ElectronicSchNet``.
         n_up, n_down: forwarded positionally to ``ElectronicSchNet``
             (presumably the spin-up / spin-down electron counts -- confirm).
         n_orbitals: output width of every backflow network.
         n_channels: how many backflow networks to build.
         embedding_dim: SchNet embedding width; also the input width of the
             Jastrow and backflow networks.
         with_jastrow: build ``self.jastrow``; otherwise set
             ``self.forward_jastrow`` to ``None``.
         n_jastrow_layers: ``n_layers`` passed to ``get_log_dnn`` for the
             Jastrow network.
         with_backflow: build ``self.backflow``; otherwise set
             ``self.forward_backflow`` to ``None``.
         n_backflow_layers: ``n_layers`` passed to ``get_log_dnn`` for each
             backflow network.
         with_r_backflow: build ``self.r_backflow``; otherwise set
             ``self.forward_r_backflow`` to ``None``.
         schnet_kwargs: extra keyword arguments for ``ElectronicSchNet``.
         subnet_kwargs: keyword arguments bound into ``SubnetFactory`` through
             ``functools.partial``.
     """
     super().__init__()
     # Submodules are created in a fixed order (SchNet, Jastrow, backflow,
     # r-backflow). Keep that order: it decides module registration order
     # and, presumably, the order in which random initialization is drawn.
     mol_size = len(mol)
     subnet_factory = partial(SubnetFactory, **(subnet_kwargs or {}))
     self.schnet = ElectronicSchNet(
         n_up,
         n_down,
         mol_size,
         dist_feat_dim=dist_feat_dim,
         embedding_dim=embedding_dim,
         subnet_metafactory=subnet_factory,
         **(schnet_kwargs or {}),
     )
     # NOTE(review): assigning None to a ``forward_*`` attribute presumably
     # shadows a same-named method to mark that head as disabled -- confirm
     # against the class definition.
     if not with_jastrow:
         self.forward_jastrow = None
     else:
         self.jastrow = get_log_dnn(
             embedding_dim, 1, SSP, n_layers=n_jastrow_layers
         )
     if not with_backflow:
         self.forward_backflow = None
     else:
         channel_nets = []
         for _ in range(n_channels):
             channel_nets.append(
                 get_log_dnn(
                     embedding_dim,
                     n_orbitals,
                     SSP,
                     n_layers=n_backflow_layers,
                     last_bias=True,
                 )
             )
         self.backflow = nn.ModuleList(channel_nets)
     if not with_r_backflow:
         self.forward_r_backflow = None
     else:
         self.r_backflow = Backflow(mol, embedding_dim)
     self._cache = {}
Esempio n. 2
0
 def g_subnet(self):
     r"""Build and return the :math:`\mathbf g` network.

     The network is produced by ``get_log_dnn`` with SSP activations. Its
     input width is ``self.kernel_dim``, its output width is
     ``self.embedding_dim``, and it has ``self.n_layers_g`` layers.
     """
     n_in, n_out = self.kernel_dim, self.embedding_dim
     depth = self.n_layers_g
     return get_log_dnn(n_in, n_out, SSP, n_layers=depth)
Esempio n. 3
0
 def __init__(self,
              embedding_dim,
              activation_factory=SSP,
              *,
              n_layers=3,
              sum_first=True):
     """Create a network mapping an embedding to a single output feature.

     Args:
         embedding_dim: input width of the network.
         activation_factory: activation factory passed to ``get_log_dnn``.
         n_layers: number of layers passed to ``get_log_dnn``.
         sum_first: stored as ``self.sum_first``; how it is used is decided
             outside this constructor.
     """
     super().__init__()
     out_features = 1
     self.net = get_log_dnn(
         embedding_dim, out_features, activation_factory, n_layers=n_layers
     )
     self.sum_first = sum_first
Esempio n. 4
0
 def __init__(
     self,
     embedding_dim,
     n_orbitals,
     n_backflows,
     activation_factory=SSP,
     *,
     n_layers=3,
 ):
     """Create ``n_backflows`` independent backflow networks.

     Each network comes from ``get_log_dnn``. It has input width
     ``embedding_dim``, output width ``n_orbitals``, the given activation
     factory, ``n_layers`` layers, and a bias on its last layer. The networks
     are stored in ``self.nets`` as an ``nn.ModuleList``.
     """
     super().__init__()
     module_list = []
     for _ in range(n_backflows):
         module_list.append(
             get_log_dnn(
                 embedding_dim,
                 n_orbitals,
                 activation_factory,
                 n_layers=n_layers,
                 last_bias=True,
             )
         )
     self.nets = nn.ModuleList(module_list)
Esempio n. 5
0
 def w_subnet(self):
     r"""Build and return the :math:`\mathbf w` network.

     The network is produced by ``get_log_dnn`` with SSP activations. Its
     input width is ``self.dist_feat_dim``, its output width is
     ``self.kernel_dim``, and it has ``self.n_layers_w`` layers.
     """
     n_in, n_out = self.dist_feat_dim, self.kernel_dim
     depth = self.n_layers_w
     return get_log_dnn(n_in, n_out, SSP, n_layers=depth)
Esempio n. 6
0
 def subnets_factory(embedding_dim):
     """Return a pair of three-layer SSP networks over an embedding.

     The first network has a single output. The second has ``len(mol)``
     outputs, where ``mol`` is taken from the enclosing scope. The networks
     are built in that order.
     """
     single_out_net = get_log_dnn(embedding_dim, 1, SSP, n_layers=3)
     mol_sized_net = get_log_dnn(embedding_dim, len(mol), SSP, n_layers=3)
     return single_out_net, mol_sized_net