def comp_ef(self, dcoord, dbox, dtype, tnatoms, name, reuse=None):
    """Build the energy / force / virial graph on top of the se_r descriptor.

    The environment matrix is produced by ``op_module.prod_env_mat_r``,
    fed through ``self._net`` to obtain atomic energies, and the gradient
    of those energies with respect to the descriptor is contracted back to
    forces and virials by the ``prod_force_se_r`` / ``prod_virial_se_r``
    custom ops.

    Parameters
    ----------
    dcoord
        Coordinate tensor, forwarded to ``op_module.prod_env_mat_r``.
    dbox
        Box tensor, forwarded to ``op_module.prod_env_mat_r``.
    dtype
        Atom-type tensor, forwarded to ``op_module.prod_env_mat_r``.
    tnatoms
        Atom-count tensor passed to every custom op. Presumably it has the
        same ``[nloc, nall, n_type0, ...]`` layout as the ``natoms`` argument
        of ``prod_force_virial`` (TODO confirm).
    name
        Forwarded unchanged to ``self._net``.
    reuse
        Forwarded unchanged to ``self._net`` as ``reuse=``.

    Returns
    -------
    energy
        Sum along axis 1 of the atomic energies after reshaping them to
        ``[-1, self.natoms[0]]``, i.e. one value per row (presumably per
        frame).
    force
        Output of ``op_module.prod_force_se_r``.
    virial
        First output of ``op_module.prod_virial_se_r``. The per-atom virial
        that op also returns is discarded.
    """
    # Creating the mesh constant first is the same graph-construction order
    # as building it inline in the call's argument list.
    mesh = tf.constant(self.default_mesh)
    env_mat, env_mat_deriv, rij, nlist = op_module.prod_env_mat_r(
        dcoord, dtype, tnatoms, dbox, mesh,
        self.t_avg, self.t_std,
        rcut=self.rcut,
        rcut_smth=self.rcut_smth,
        sel=self.sel,
    )

    # Flatten so that each row holds one ndescrpt-long descriptor.
    flat_descrpt = tf.reshape(env_mat, [-1, self.ndescrpt])
    atom_energy = self._net(flat_descrpt, name, reuse=reuse)

    # Group natoms[0] atomic energies per row, then total each row.
    per_row_energy = tf.reshape(atom_energy, [-1, self.natoms[0]])
    energy = tf.reduce_sum(per_row_energy, axis=1)

    # tf.gradients sums over ys, so this is d(sum of atom energies)/d(descriptor).
    # A single xs tensor always yields a one-element list.
    (grad_descrpt,) = tf.gradients(atom_energy, flat_descrpt)
    grad_descrpt_flat = tf.reshape(
        grad_descrpt, [-1, self.natoms[0] * self.ndescrpt])

    force = op_module.prod_force_se_r(
        grad_descrpt_flat, env_mat_deriv, nlist, tnatoms)
    virial, _atom_virial = op_module.prod_virial_se_r(
        grad_descrpt_flat, env_mat_deriv, rij, nlist, tnatoms)
    return energy, force, virial
def prod_force_virial(self, atom_ener, natoms):
    """Compute force and virial from the atomic energies.

    Differentiates ``atom_ener`` with respect to the cached, flattened
    descriptor ``self.descrpt_reshape``. The result is then passed to the
    se_r force and virial custom ops together with the cached descriptor
    derivative, neighbor list and ``rij``. Those ``self.*`` tensors are
    presumably populated by an earlier descriptor-building call (confirm
    against the enclosing class).

    Parameters
    ----------
    atom_ener
        The atomic energy tensor.
    natoms
        Atom-count tensor of length Ntypes + 2:
        natoms[0] is the number of local atoms, natoms[1] is the total
        number of atoms held by this processor, and natoms[i] for
        2 <= i < Ntypes + 2 is the number of type-i atoms.

    Returns
    -------
    force
        Output of ``op_module.prod_force_se_r``.
    virial
        Total virial from ``op_module.prod_virial_se_r``.
    atom_virial
        Per-atom virial from ``op_module.prod_virial_se_r``.
    """
    # A single xs tensor always yields a one-element gradient list.
    (net_deriv,) = tf.gradients(atom_ener, self.descrpt_reshape)
    # Flatten to one row of natoms[0] * ndescrpt derivatives.
    flat_deriv = tf.reshape(net_deriv, [-1, natoms[0] * self.ndescrpt])

    force = op_module.prod_force_se_r(
        flat_deriv, self.descrpt_deriv, self.nlist, natoms)
    virial, atom_virial = op_module.prod_virial_se_r(
        flat_deriv, self.descrpt_deriv, self.rij, self.nlist, natoms)
    return force, virial, atom_virial
def prod_force_virial(
    self,
    atom_ener: tf.Tensor,
    natoms: tf.Tensor,
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
    """Compute force and virial, recording histogram summaries along the way.

    The gradient of ``atom_ener`` with respect to ``self.descrpt_reshape``
    is flattened and passed to the se_r force and virial custom ops. A
    histogram summary is emitted for the raw gradient and for each of the
    three results.

    Parameters
    ----------
    atom_ener
        The atomic energy.
    natoms
        Atom-count tensor of length Ntypes + 2:
        natoms[0] is the number of local atoms, natoms[1] is the total
        number of atoms held by this processor, and natoms[i] for
        2 <= i < Ntypes + 2 is the number of type-i atoms.

    Returns
    -------
    force
        The force on atoms.
    virial
        The total virial.
    atom_virial
        The atomic virial.
    """
    # A single xs tensor always yields a one-element gradient list.
    (d_ener_d_descrpt,) = tf.gradients(atom_ener, self.descrpt_reshape)
    tf.summary.histogram('net_derivative', d_ener_d_descrpt)

    flat_deriv = tf.reshape(
        d_ener_d_descrpt, [-1, natoms[0] * self.ndescrpt])

    force = op_module.prod_force_se_r(
        flat_deriv, self.descrpt_deriv, self.nlist, natoms)
    virial, atom_virial = op_module.prod_virial_se_r(
        flat_deriv, self.descrpt_deriv, self.rij, self.nlist, natoms)

    # Summaries are emitted in the same order as the outputs are returned.
    tf.summary.histogram('force', force)
    tf.summary.histogram('virial', virial)
    tf.summary.histogram('atom_virial', atom_virial)
    return force, virial, atom_virial