def __call__(self, x):
    """Apply a multilayer perceptron to ``x``.

    Every width in ``self.sizes`` except the last defines a ``Dense`` layer
    followed by ``tanh``; the last width defines a final, purely linear
    ``Dense`` layer whose output is returned.
    """
    hidden_widths = self.sizes[:-1]
    h = x
    for width in hidden_widths:
        h = jnp.tanh(Dense(width)(h))
    return Dense(self.sizes[-1])(h)
def __call__(self, x):
    """Map flattened Cartesian geometries to two eigenvalues plus one extra output.

    Each row of ``x`` is reshaped to ``(self.n_atoms, 3)`` and converted to the
    vector of pairwise distances ``|R_i - R_j|`` for ``i < j`` (upper-triangle,
    row-major order). The distances pass through one ``Dense`` + ``tanh`` layer
    per width in ``self.sizes``.

    From each network output row ``out``, entries 0 and 1 form the diagonal and
    entry 2 the off-diagonal of a symmetric 2x2 matrix. Its eigenvalues
    (ascending, from ``jnp.linalg.eigh``) are returned followed by ``out[-1]``.

    Args:
        x: batch of flattened geometries; each row must hold
            ``3 * self.n_atoms`` values.

    Returns:
        Array of shape ``(batch, 3)``: ``[eig_low, eig_high, out[-1]]`` per row.
    """

    def f_bond_length(xyz):
        """Return the distances between all atom pairs i < j of one geometry."""
        pos = jnp.reshape(xyz, (self.n_atoms, 3))
        rows, cols = jnp.triu_indices(self.n_atoms, 1)
        return jnp.linalg.norm(pos[rows] - pos[cols], axis=1)

    def f_adiab(out):
        """Diagonalize [[out0, out2], [out2, out1]] and append ``out[-1]``."""
        W = jnp.array([[out[0], out[2]], [out[2], out[1]]])
        w, _ = jnp.linalg.eigh(W)
        # ``w`` has shape (2,) while ``out[-1]`` is a scalar: ``jnp.stack`` on the
        # two raises a shape mismatch, so join them with a length-1 slice instead.
        return jnp.concatenate((w, out[-1:]))

    x = vmap(f_bond_length)(x)

    # NOTE(review): tanh follows every layer, including the last one, so all
    # matrix elements fed to ``f_adiab`` lie in [-1, 1] — confirm intended.
    for size in self.sizes:
        x = jnp.tanh(Dense(size)(x))

    return vmap(f_adiab)(x)
def __call__(self, x):
    """Map flattened Cartesian geometries to the eigenvalues of a 2x2 model matrix.

    Each geometry is encoded as its flattened Coulomb matrix, with
    ``C_ii = 0.5 * Z_i**2.4`` and ``C_ij = Z_i * Z_j / |R_i - R_j|``. The encoding
    passes through ``Dense`` + ReLU hidden layers (widths ``self.sizes[:-1]``)
    and a linear output layer (width ``self.sizes[-1]``), all in float64.

    Output entries 0 and 1 form the diagonal and entry 2 the off-diagonal of a
    symmetric 2x2 matrix. Its ascending eigenvalues are returned.

    Args:
        x: batch of flattened geometries; each row must hold
            ``3 * self.n_atoms`` values.

    Returns:
        Array of shape ``(batch, 2)`` with the two eigenvalues per row.
    """

    def f_Coulomb_Matrix(xyz):
        """Return the flattened ``(n_atoms * n_atoms,)`` Coulomb matrix of one geometry."""
        # Nuclear charges are hard-coded (7, 1, 1, 1); their count must match
        # ``self.n_atoms``.
        z_atoms = jnp.array([7., 1., 1., 1.])
        diag = jnp.diag_indices(self.n_atoms)
        charges = jnp.multiply(z_atoms[:, None], z_atoms[None, :])
        charges = charges.at[diag].set(0.5 * z_atoms**2.4)

        pos = jnp.reshape(xyz, (self.n_atoms, 3))
        diff = pos[:, None] - pos[None, :]
        # Replace the zero self-difference vectors with a dummy non-zero vector
        # so the norm (and its gradient) stays finite on the diagonal.
        diff = diff.at[diag].set(1.)
        inv_r = 1. / jnp.linalg.norm(diff, axis=2)
        # The dummy vector has norm sqrt(3); force the diagonal to exactly 1 so the
        # Coulomb diagonal is 0.5 * Z**2.4 and not 0.5 * Z**2.4 / sqrt(3).
        inv_r = inv_r.at[diag].set(1.)
        return jnp.multiply(charges, inv_r).ravel()

    def f_adiab(out):
        """Return the ascending eigenvalues of [[out0, out2], [out2, out1]]."""
        W = jnp.diag(jnp.array([out[0], out[1]]))
        W = W.at[0, 1].set(out[2])
        W = W.at[1, 0].set(out[2])
        w, _ = jnp.linalg.eigh(W)
        return w

    x = vmap(f_Coulomb_Matrix)(x)

    for size in self.sizes[:-1]:
        x = linen.relu(Dense(size, dtype=jnp.float64)(x))
    x = Dense(self.sizes[-1], dtype=jnp.float64)(x)

    return vmap(f_adiab)(x)
def __call__(self, x):
    """Feed interatomic distances of each geometry through a tanh MLP.

    Each row of ``x`` is reshaped to ``(self.n_atoms, 3)`` and turned into the
    distances ``|R_i - R_j|`` for all pairs ``i < j`` (upper-triangle,
    row-major order). One ``Dense`` layer followed by ``tanh`` is applied per
    width in ``self.sizes``. The last layer is also tanh-activated, so every
    returned value lies in [-1, 1].

    Args:
        x: batch of flattened geometries; each row must hold
            ``3 * self.n_atoms`` values.

    Returns:
        The activated output of the last layer, one row per geometry.
    """
    n = self.n_atoms
    upper_rows, upper_cols = jnp.triu_indices(n, 1)

    def f_bond_length(xyz):
        pos = jnp.reshape(xyz, (n, 3))
        return jnp.linalg.norm(pos[upper_rows] - pos[upper_cols], axis=1)

    h = vmap(f_bond_length)(x)
    for width in self.sizes:
        h = jnp.tanh(Dense(width)(h))
    return h
def __call__(self, x):
    """Apply a ReLU multilayer perceptron to ``x``.

    Every width in ``self.widths`` except the last defines a ``Dense`` layer
    followed by ReLU; the last width defines the final linear ``Dense`` layer.
    """
    h = x
    for hidden in self.widths[:-1]:
        h = Dense(hidden)(h)
        h = nn.relu(h)
    return Dense(self.widths[-1])(h)
def __call__(self, x):
    """Map flattened geometries to the eigenvalues of a coordinate-dependent 2x2 matrix.

    Pipeline:
      1. A learnable parameter ``'lambd'`` (shape ``(1,)``), scaled by 3 and cast
         to ``self.dtype``, sets the range ``l`` of the Morse variables
         ``exp(-r / l)``. Here ``r`` are the pairwise distances (``i < j``,
         upper-triangle order). The Morse variables are then transformed by
         ``f_poly``, which is defined elsewhere (the original comment calls it a
         PIP).
      2. ``Dense`` + ReLU hidden layers (widths ``self.sizes[:-1]``) and a linear
         output layer (width ``self.sizes[-1]``), in float64, produce ``u``.
      3. Per geometry, ``q`` is the scalar triple product
         ``e1 . (e2 x e3)`` of the unit vectors ``R_0 - R_k`` for k = 1, 2, 3
         (atom 0 is N and atoms 1-3 are H, per the original comment). The
         symmetric matrix ``[[u0, c], [c, u1]]`` with
         ``c = q * u2 + q**3 * u3`` is diagonalized.

    Args:
        x: batch of flattened geometries; each row must hold
            ``3 * self.n_atoms`` values.

    Returns:
        Array of shape ``(batch, 2)`` with the ascending eigenvalues per row.
    """
    lambd = self.param(
        'lambd',
        self.lambd_init,  # Initialization function
        (1, ))
    l = jnp.asarray(3. * lambd, self.dtype)

    def f_bond_length(xyz):
        """Distances between all atom pairs i < j, in upper-triangle order."""
        pos = jnp.reshape(xyz, (self.n_atoms, 3))
        rows, cols = jnp.triu_indices(self.n_atoms, 1)
        return jnp.linalg.norm(pos[rows] - pos[cols], axis=1)

    def dot_cross_product(xyz):
        """Triple product of the unit vectors from atoms 1, 2, 3 to atom 0."""
        pos = jnp.reshape(xyz, (self.n_atoms, 3))

        def unit_bond(k):
            v = pos[0, :] - pos[k, :]
            return v / jnp.linalg.norm(v)

        e1 = unit_bond(1)
        e2 = unit_bond(2)
        e3 = unit_bond(3)
        return jnp.dot(e1, jnp.cross(e2, e3))

    def f_morse(xyz):
        """Morse variables exp(-r / l) of all distances, passed through ``f_poly``."""
        return f_poly(jnp.exp(-f_bond_length(xyz) / l))

    q_NHHH = vmap(dot_cross_product)(x)
    h = vmap(f_morse)(x)

    for size in self.sizes[:-1]:
        h = linen.relu(Dense(size, dtype=jnp.float64)(h))
    u = Dense(self.sizes[-1], dtype=jnp.float64)(h)

    def f_adiab(out, q):
        """Eigenvalues of [[out0, c], [c, out1]] with c = q*out2 + q**3*out3."""
        coupling = q * out[2] + (q**3) * out[3]
        W = jnp.diag(jnp.array([out[0], out[1]]))
        W = W.at[0, 1].set(coupling).at[1, 0].set(coupling)
        return jnp.linalg.eigh(W)[0]

    return vmap(f_adiab, (0, 0))(u, q_NHHH)