def __init__(
    self,
    mol,
    dist_feat_dim,
    n_up,
    n_down,
    n_orbitals,
    n_channels,
    *,
    embedding_dim=128,
    with_jastrow=True,
    n_jastrow_layers=3,
    with_backflow=True,
    n_backflow_layers=3,
    with_r_backflow=False,
    schnet_kwargs=None,
    subnet_kwargs=None,
):
    """Assemble the electronic SchNet embedding and its optional output heads.

    Args:
        mol: molecule object; ``len(mol)`` is handed to ``ElectronicSchNet``
            as its third positional argument (presumably the number of
            nuclei -- confirm), and ``mol`` itself is passed to ``Backflow``
            when ``with_r_backflow`` is set.
        dist_feat_dim: size of the distance-feature vector given to SchNet.
        n_up: number of spin-up electrons, forwarded to SchNet.
        n_down: number of spin-down electrons, forwarded to SchNet.
        n_orbitals: output width of every orbital backflow net.
        n_channels: how many orbital backflow nets to build.
        embedding_dim: width of the SchNet embedding consumed by all heads.
        with_jastrow: build the single-output Jastrow net ``self.jastrow``.
        n_jastrow_layers: layer count for the Jastrow net.
        with_backflow: build ``self.backflow`` (one net per channel).
        n_backflow_layers: layer count for each orbital backflow net.
        with_r_backflow: build ``self.r_backflow = Backflow(mol, embedding_dim)``.
        schnet_kwargs: extra keyword arguments for ``ElectronicSchNet``.
        subnet_kwargs: keyword arguments bound into ``SubnetFactory`` via
            ``functools.partial`` and passed as ``subnet_metafactory``.

    A disabled head does not get its module attribute. Instead, the matching
    ``forward_jastrow`` / ``forward_backflow`` / ``forward_r_backflow``
    instance attribute is set to ``None``. This presumably shadows a method
    of that name so callers can skip the term -- confirm against the class.
    """
    super().__init__()

    # Modules are created in the same order as before (schnet, jastrow,
    # backflow, r_backflow). RNG-driven initialisation and parameter ordering
    # depend on it.
    n_nuclei = len(mol)
    metafactory = partial(SubnetFactory, **(subnet_kwargs or {}))
    extra_schnet_options = schnet_kwargs or {}
    self.schnet = ElectronicSchNet(
        n_up,
        n_down,
        n_nuclei,
        dist_feat_dim=dist_feat_dim,
        embedding_dim=embedding_dim,
        subnet_metafactory=metafactory,
        **extra_schnet_options,
    )

    if not with_jastrow:
        self.forward_jastrow = None
    else:
        self.jastrow = get_log_dnn(
            embedding_dim, 1, SSP, n_layers=n_jastrow_layers
        )

    if not with_backflow:
        self.forward_backflow = None
    else:
        orbital_heads = nn.ModuleList()
        for _channel in range(n_channels):
            orbital_heads.append(
                get_log_dnn(
                    embedding_dim,
                    n_orbitals,
                    SSP,
                    n_layers=n_backflow_layers,
                    last_bias=True,
                )
            )
        self.backflow = orbital_heads

    if not with_r_backflow:
        self.forward_r_backflow = None
    else:
        self.r_backflow = Backflow(mol, embedding_dim)

    self._cache = {}
def g_subnet(self):
    r"""Build the :math:`\mathbf g` network.

    Returns the result of ``get_log_dnn`` mapping ``self.kernel_dim`` inputs
    to ``self.embedding_dim`` outputs, with ``SSP`` activations and
    ``self.n_layers_g`` layers.
    """
    n_in = self.kernel_dim
    n_out = self.embedding_dim
    depth = self.n_layers_g
    return get_log_dnn(n_in, n_out, SSP, n_layers=depth)
def __init__(self, embedding_dim, activation_factory=SSP, *, n_layers=3, sum_first=True):
    """Create a single-output ``get_log_dnn`` over an ``embedding_dim`` input.

    Args:
        embedding_dim: input width of the network.
        activation_factory: activation passed to ``get_log_dnn`` (default ``SSP``).
        n_layers: number of layers of the network.
        sum_first: flag stored on the instance for use elsewhere (presumably
            whether embeddings are summed before or after the net -- confirm
            against the forward pass).
    """
    super().__init__()
    self.sum_first = sum_first
    n_outputs = 1
    self.net = get_log_dnn(
        embedding_dim, n_outputs, activation_factory, n_layers=n_layers
    )
def __init__(
    self,
    embedding_dim,
    n_orbitals,
    n_backflows,
    activation_factory=SSP,
    *,
    n_layers=3,
):
    """Hold ``n_backflows`` independent networks, each mapping to ``n_orbitals``.

    Args:
        embedding_dim: input width of every network.
        n_orbitals: output width of every network.
        n_backflows: number of networks stored in ``self.nets``.
        activation_factory: activation passed to ``get_log_dnn`` (default ``SSP``).
        n_layers: number of layers of each network.

    Every network is built with ``last_bias=True``.
    """
    super().__init__()
    heads = nn.ModuleList()
    # Nets are built one at a time in index order, as before.
    for _idx in range(n_backflows):
        head = get_log_dnn(
            embedding_dim,
            n_orbitals,
            activation_factory,
            n_layers=n_layers,
            last_bias=True,
        )
        heads.append(head)
    self.nets = heads
def w_subnet(self):
    r"""Build the :math:`\mathbf w` network.

    Returns the result of ``get_log_dnn`` mapping ``self.dist_feat_dim``
    distance features to ``self.kernel_dim`` outputs, with ``SSP``
    activations and ``self.n_layers_w`` layers.
    """
    n_in, n_out = self.dist_feat_dim, self.kernel_dim
    return get_log_dnn(n_in, n_out, SSP, n_layers=self.n_layers_w)
def subnets_factory(embedding_dim):
    """Return a pair of 3-layer ``get_log_dnn`` nets over ``embedding_dim`` inputs.

    The first net has a single output. The second has ``len(mol)`` outputs,
    where ``mol`` is a free variable resolved from the enclosing scope
    (presumably one output per nucleus -- confirm). The nets are created in
    that order.
    """
    depth = 3
    single_output_net = get_log_dnn(embedding_dim, 1, SSP, n_layers=depth)
    mol_sized_net = get_log_dnn(embedding_dim, len(mol), SSP, n_layers=depth)
    return single_output_net, mol_sized_net