def __call__(self, x):
    """Run a Dense+tanh hidden stack followed by a linear Dense output.

    Hidden widths come from ``self.sizes[:-1]``; the output width is
    ``self.sizes[-1]`` and no activation is applied to it.
    """
    hidden_widths = self.sizes[:-1]
    for width in hidden_widths:
        hidden_layer = Dense(width)
        x = jnp.tanh(hidden_layer(x))
    output_layer = Dense(self.sizes[-1])
    return output_layer(x)
    def __call__(self, x):
        """Predict two adiabatic energies plus one extra output per geometry.

        Each geometry is turned into its pairwise interatomic distances and
        fed through a Dense+tanh stack (tanh follows every layer, including
        the last). The outputs are then mapped to energies.

        Args:
            x: batch of flattened Cartesian coordinates; each row is reshaped
                to ``(self.n_atoms, 3)``.

        Returns:
            Array of shape ``(batch, 3)``. The first two entries are the
            eigenvalues (ascending, as returned by ``jnp.linalg.eigh``) of the
            symmetric 2x2 matrix ``[[y0, y2], [y2, y1]]``, where ``y`` are the
            network outputs. The third entry is the last network output
            ``y[-1]``.
        """
        def f_bond_length(x):
            # reshape (n_atoms, 3)
            x = jnp.reshape(x, (self.n_atoms, 3))
            # all pairwise difference vectors
            z = x[:, None] - x[None, :]
            # keep the strict upper triangle (i < j, lexicographic order)
            i0 = jnp.triu_indices(self.n_atoms, 1)
            diff = z[i0]
            # pairwise distances
            r = jnp.linalg.norm(diff, axis=1)
            return r

        x = vmap(f_bond_length)(x)

        # NN: every layer (including the output layer) is followed by tanh.
        for size in self.sizes:
            x = Dense(size)(x)
            x = jnp.tanh(x)

        # Adiabatic energies
        def f_adiab(x):
            # Indexing x[2] requires self.sizes[-1] >= 3.
            w00 = x[0]
            w11 = x[1]
            w01 = x[2]
            W = jnp.diag(jnp.array([w00, w11]))
            W = W.at[0, 1].set(w01)
            W = W.at[1, 0].set(w01)
            w, _ = jnp.linalg.eigh(W)
            # w has shape (2,) while x[-1] is a scalar. jnp.stack needs equal
            # shapes and raised here, so join them with a length-1 slice.
            return jnp.concatenate((w, x[-1:]))

        return vmap(f_adiab)(x)
    def __call__(self, x):
        """Predict two adiabatic energies from Coulomb-matrix features.

        Args:
            x: batch of flattened Cartesian coordinates; each row is reshaped
                to ``(self.n_atoms, 3)``.

        Returns:
            Array of shape ``(batch, 2)``: the ascending eigenvalues of the
            symmetric 2x2 matrix ``[[y0, y2], [y2, y1]]``, where ``y`` are the
            first three outputs of the final Dense layer.
        """
        def f_Coulomb_Matrix(x):
            """Return the flattened ``(n_atoms, n_atoms)`` Coulomb matrix."""
            # NOTE(review): charges are hard-coded for N, H, H, H; they only
            # match the geometry when self.n_atoms == 4.
            z_atoms = jnp.array([7., 1., 1., 1.])
            diag = jnp.diag_indices(self.n_atoms)
            # Numerators: Z_i * Z_j off the diagonal, 0.5 * Z_i**2.4 on it.
            M = jnp.multiply(z_atoms[:, None], z_atoms[None, :])
            M = M.at[diag].set(0.5 * z_atoms**2.4)

            x = jnp.reshape(x, (self.n_atoms, 3))
            r = x[:, None] - x[None, :]
            # Self-differences are zero vectors. Replace them before taking
            # the norm so neither the norm nor its gradient divides by zero.
            r = r.at[diag].set(1.)
            inv_r = 1. / jnp.linalg.norm(r, axis=2)
            # The diagonal has no distance factor. The placeholder above left
            # 1/sqrt(3) there, which rescaled M_ii, so reset it to 1.
            inv_r = inv_r.at[diag].set(1.)

            return jnp.multiply(M, inv_r).ravel()

        # Adiabatic energies
        def f_adiab(x):
            # Indexing x[2] requires self.sizes[-1] >= 3.
            w00 = x[0]
            w11 = x[1]
            w01 = x[2]
            W = jnp.diag(jnp.array([w00, w11]))
            W = W.at[0, 1].set(w01)
            W = W.at[1, 0].set(w01)
            w, _ = jnp.linalg.eigh(W)
            return w

        x = vmap(f_Coulomb_Matrix)(x)

        # NN: ReLU hidden layers, linear output layer.
        for size in self.sizes[:-1]:
            x = Dense(size, dtype=jnp.float64)(x)
            x = linen.relu(x)
        x = Dense(self.sizes[-1], dtype=jnp.float64)(x)

        x = vmap(f_adiab)(x)
        return x
    def __call__(self, x):
        """Embed geometries as pairwise distances, then run a tanh MLP.

        Args:
            x: batch of flattened Cartesian coordinates; each row is reshaped
                to ``(self.n_atoms, 3)``.

        Returns:
            The output of the last Dense layer after tanh (tanh is applied
            after every layer, including the last).
        """
        n = self.n_atoms
        # strict upper-triangle pairs (i < j) in lexicographic order
        rows, cols = jnp.triu_indices(n, 1)

        def pair_distances(flat_coords):
            xyz = jnp.reshape(flat_coords, (n, 3))
            deltas = xyz[rows] - xyz[cols]
            return jnp.linalg.norm(deltas, axis=1)

        h = vmap(pair_distances)(x)

        for width in self.sizes:
            layer = Dense(width)
            h = jnp.tanh(layer(h))

        return h
# Example #5
# 0
 def __call__(self, x):
     """Run a ReLU MLP; the final Dense layer (``self.widths[-1]``) is linear."""
     h = x
     for width in self.widths[:-1]:
         h = Dense(width)(h)
         h = nn.relu(h)
     output_layer = Dense(self.widths[-1])
     return output_layer(h)
    def __call__(self, x):
        """Predict two adiabatic energies from Morse-variable PIP features.

        Pipeline:
        - Each geometry becomes Morse variables ``exp(-r_ij / l)`` of its
          pairwise distances, with ``l = 3 * lambd`` and ``lambd`` a learned
          parameter. These are passed through ``f_poly``.
        - The features go through a ReLU MLP with a linear output layer.
        - The outputs ``u`` form the symmetric 2x2 matrix
          ``[[u0, c], [c, u1]]`` with ``c = q * u2 + q**3 * u3``, where ``q``
          is the normalized triple product of the N-H vectors (atom 0 minus
          atoms 1, 2, 3).

        Args:
            x: batch of flattened Cartesian coordinates; each row is reshaped
                to ``(self.n_atoms, 3)``.

        Returns:
            Array of shape ``(batch, 2)`` with the ascending eigenvalues.
        """
        lambd = self.param(
            'lambd',
            self.lambd_init,  # Initialization function
            (1, ))
        ell = jnp.asarray(3. * lambd, self.dtype)

        n = self.n_atoms
        # strict upper-triangle pairs (i < j) in lexicographic order
        rows, cols = jnp.triu_indices(n, 1)

        def unit(vec):
            return vec / jnp.linalg.norm(vec)

        def umbrella_coordinate(flat_coords):
            # (R_N - R_H1) . [(R_N - R_H2) x (R_N - R_H3)] / (r_NH1 r_NH2 r_NH3)
            xyz = jnp.reshape(flat_coords, (n, 3))
            e1 = unit(xyz[0, :] - xyz[1, :])
            e2 = unit(xyz[0, :] - xyz[2, :])
            e3 = unit(xyz[0, :] - xyz[3, :])
            return jnp.dot(e1, jnp.cross(e2, e3))

        def morse_features(flat_coords):
            xyz = jnp.reshape(flat_coords, (n, 3))
            pairwise = xyz[:, None] - xyz[None, :]
            dist = jnp.linalg.norm(pairwise[rows, cols], axis=1)
            # Morse variables followed by the permutation-invariant polynomials
            return f_poly(jnp.exp(-dist / ell))

        q_NHHH = vmap(umbrella_coordinate)(x)
        h = vmap(morse_features)(x)

        # NN: ReLU hidden layers, linear output layer.
        for width in self.sizes[:-1]:
            h = linen.relu(Dense(width, dtype=jnp.float64)(h))
        u = Dense(self.sizes[-1], dtype=jnp.float64)(h)

        def two_state_eigenvalues(out, q):
            coupling = q * out[2] + (q**3) * out[3]
            W = jnp.diag(jnp.array([out[0], out[1]]))
            W = W.at[0, 1].set(coupling).at[1, 0].set(coupling)
            eigvals, _ = jnp.linalg.eigh(W)
            return eigvals

        return vmap(two_state_eigenvalues, (0, 0))(u, q_NHHH)