def prod_force_virial(
        self,
        atom_ener: tf.Tensor,
        natoms: tf.Tensor,
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
    """Derive forces and virials from the atomic energies.

    The gradient of ``atom_ener`` with respect to the cached descriptor
    ``self.descrpt`` is fed, together with the cached ``descrpt_deriv``,
    ``nlist``, ``axis`` and ``rij`` tensors, into the ``prod_force`` and
    ``prod_virial`` custom ops. Histogram summaries are recorded for the
    network derivative and for each returned tensor.

    Parameters
    ----------
    atom_ener
        The atomic energy.
    natoms
        Atom counts, a tensor of length ``Ntypes + 2``:
        ``natoms[0]`` is the number of local atoms,
        ``natoms[1]`` is the total number of atoms held by this processor,
        ``natoms[i]`` (``2 <= i < Ntypes + 2``) is the number of type-i atoms.

    Returns
    -------
    force
        The force on atoms.
    virial
        The total virial.
    atom_virial
        The atomic virial.
    """
    (net_deriv,) = tf.gradients(atom_ener, self.descrpt)
    tf.summary.histogram('net_derivative', net_deriv)

    # One row per leading index (presumably one frame) holding all
    # natoms[0] * ndescrpt derivative entries.
    flat_deriv = tf.reshape(net_deriv, [-1, natoms[0] * self.ndescrpt])
    sel_kwargs = {'n_a_sel': self.nnei_a, 'n_r_sel': self.nnei_r}

    force = op_module.prod_force(
        flat_deriv, self.descrpt_deriv, self.nlist, self.axis, natoms,
        **sel_kwargs)
    virial, atom_virial = op_module.prod_virial(
        flat_deriv, self.descrpt_deriv, self.rij, self.nlist, self.axis,
        natoms, **sel_kwargs)

    for tag, tensor in (('force', force),
                        ('virial', virial),
                        ('atom_virial', atom_virial)):
        tf.summary.histogram(tag, tensor)
    return force, virial, atom_virial
def comp_ef(self, dcoord, dbox, dtype, tnatoms, name, reuse=None):
    """Build the energy, force and virial tensors for the network model.

    The descriptor is computed with ``op_module.descrpt`` and fed through
    ``self._net`` to obtain atomic energies. Forces and virials come from
    the ``prod_force`` / ``prod_virial`` ops applied to the derivative of
    the atomic energies with respect to the flattened descriptor.

    Side effect: ``axis``, ``nlist`` and ``descrpt`` returned by the
    descriptor op are stored on ``self``.

    Parameters
    ----------
    dcoord, dbox, dtype, tnatoms
        Forwarded unchanged to ``op_module.descrpt``; ``tnatoms`` is also
        forwarded to the force and virial ops.
    name, reuse
        Forwarded to ``self._net``.

    Returns
    -------
    energy
        Sum of the atomic energies over each row of a
        ``[-1, self.natoms[0]]`` reshape (presumably one row per frame).
    force
        Output of ``op_module.prod_force``.
    virial
        First output of ``op_module.prod_virial``.
    """
    descrpt, descrpt_deriv, rij, nlist, axis, _rot_mat = op_module.descrpt(
        dcoord, dtype, tnatoms, dbox,
        tf.constant(self.default_mesh),
        self.t_avg, self.t_std,
        rcut_a=self.rcut_a, rcut_r=self.rcut_r,
        sel_a=self.sel_a, sel_r=self.sel_r,
        axis_rule=self.axis_rule)
    self.axis, self.nlist, self.descrpt = axis, nlist, descrpt

    flat_descrpt = tf.reshape(descrpt, [-1, self.ndescrpt])
    atom_ener = self._net(flat_descrpt, name, reuse=reuse)
    per_frame_ener = tf.reshape(atom_ener, [-1, self.natoms[0]])
    energy = tf.reduce_sum(per_frame_ener, axis=1)

    (net_deriv,) = tf.gradients(atom_ener, flat_descrpt)
    flat_net_deriv = tf.reshape(
        net_deriv, [-1, self.natoms[0] * self.ndescrpt])

    force = op_module.prod_force(
        flat_net_deriv, descrpt_deriv, nlist, axis, tnatoms,
        n_a_sel=self.nnei_a, n_r_sel=self.nnei_r)
    virial, _atom_virial = op_module.prod_virial(
        flat_net_deriv, descrpt_deriv, rij, nlist, axis, tnatoms,
        n_a_sel=self.nnei_a, n_r_sel=self.nnei_r)
    return energy, force, virial
def prod_force_virial(self, atom_ener, natoms):
    """Derive forces and virials from the atomic energies.

    Differentiates ``atom_ener`` with respect to the cached descriptor
    ``self.descrpt`` and hands the flattened derivative, together with the
    cached ``descrpt_deriv``, ``nlist``, ``axis`` and ``rij`` tensors, to
    the ``prod_force`` and ``prod_virial`` custom ops.

    Parameters
    ----------
    atom_ener
        The atomic energy tensor.
    natoms
        Atom-count tensor; ``natoms[0]`` sets the width of the flattened
        derivative (``natoms[0] * self.ndescrpt``) and the whole tensor is
        forwarded to both ops.

    Returns
    -------
    tuple
        ``(force, virial, atom_virial)``, where ``force`` is the output of
        ``prod_force`` and ``virial, atom_virial`` are the two outputs of
        ``prod_virial``.
    """
    (energy_grad,) = tf.gradients(atom_ener, self.descrpt)
    row_width = natoms[0] * self.ndescrpt
    flat_grad = tf.reshape(energy_grad, [-1, row_width])
    sel_kwargs = {'n_a_sel': self.nnei_a, 'n_r_sel': self.nnei_r}

    force = op_module.prod_force(
        flat_grad, self.descrpt_deriv, self.nlist, self.axis, natoms,
        **sel_kwargs)
    virial, atom_virial = op_module.prod_virial(
        flat_grad, self.descrpt_deriv, self.rij, self.nlist, self.axis,
        natoms, **sel_kwargs)
    return force, virial, atom_virial
def comp_interpl_ef(self, dcoord, dbox, dtype, tnatoms, name, reuse=None):
    """Build energy, force and virial for a network/pair-table blend.

    Per atom, the energy is ``lambda * E_tab + (1 - lambda) * E_net``,
    where ``lambda`` comes from ``op_module.soft_min_switch``, ``E_tab``
    from ``op_module.pair_tab`` and ``E_net`` from ``self._net``.
    The force and virial are the sum of three contributions:

    * the network part, from ``prod_force`` / ``prod_virial``;
    * the switching-function part, from ``soft_min_force`` /
      ``soft_min_virial`` driven by ``E_tab - E_net``;
    * the pair-table part (``tab_force`` and the summed
      ``tab_atom_virial``).

    Parameters
    ----------
    dcoord, dbox, dtype, tnatoms
        Forwarded to ``op_module.descrpt``. ``dtype`` and ``tnatoms`` are
        also forwarded to the switch, table, force and virial ops.
    name, reuse
        Forwarded to ``self._net``.

    Returns
    -------
    energy
        Sum of the blended atomic energies over each row of a
        ``[-1, self.natoms[0]]`` reshape (presumably one row per frame).
    force
        Network force + switch force + table force.
    virial
        Network virial + switch virial + table virial.
    """
    # The descriptor op may return extra trailing outputs (comp_ef unpacks a
    # sixth one, rot_mat). Only the first five are needed here, so accept
    # either arity instead of failing on a fixed-size unpack.
    descrpt, descrpt_deriv, rij, nlist, axis, *_ = op_module.descrpt(
        dcoord, dtype, tnatoms, dbox,
        tf.constant(self.default_mesh),
        self.t_avg, self.t_std,
        rcut_a=self.rcut_a, rcut_r=self.rcut_r,
        sel_a=self.sel_a, sel_r=self.sel_r,
        axis_rule=self.axis_rule)
    inputs_reshape = tf.reshape(descrpt, [-1, self.ndescrpt])
    net_atom_ener = self._net(inputs_reshape, name, reuse=reuse)

    sw_lambda, sw_deriv = op_module.soft_min_switch(
        dtype, rij, nlist, tnatoms,
        sel_a=self.sel_a, sel_r=self.sel_r,
        alpha=self.smin_alpha, rmin=self.sw_rmin, rmax=self.sw_rmax)
    inv_sw_lambda = 1.0 - sw_lambda

    tab_atom_ener, tab_force, tab_atom_virial = op_module.pair_tab(
        self.tab_info, self.tab_data, dtype, rij, nlist, tnatoms, sw_lambda,
        sel_a=self.sel_a, sel_r=self.sel_r)

    # Uses the UNweighted energies: this difference drives the
    # switching-function force and virial below.
    energy_diff = tab_atom_ener - tf.reshape(
        net_atom_ener, [-1, self.natoms[0]])

    weighted_tab_ener = (tf.reshape(sw_lambda, [-1])
                         * tf.reshape(tab_atom_ener, [-1]))
    weighted_net_ener = tf.reshape(inv_sw_lambda, [-1]) * net_atom_ener
    energy_raw = tf.reshape(weighted_tab_ener + weighted_net_ener,
                            [-1, self.natoms[0]])
    energy = tf.reduce_sum(energy_raw, axis=1)

    # Differentiate the WEIGHTED network energy. sw_lambda is built from rij,
    # not from inputs_reshape, so the gradient carries the (1 - lambda)
    # factor without differentiating lambda itself.
    (net_deriv,) = tf.gradients(weighted_net_ener, inputs_reshape)
    net_deriv_reshape = tf.reshape(
        net_deriv, [-1, self.natoms[0] * self.ndescrpt])

    sel_kwargs = {'n_a_sel': self.nnei_a, 'n_r_sel': self.nnei_r}

    force = op_module.prod_force(
        net_deriv_reshape, descrpt_deriv, nlist, axis, tnatoms, **sel_kwargs)
    sw_force = op_module.soft_min_force(
        energy_diff, sw_deriv, nlist, tnatoms, **sel_kwargs)
    force = force + sw_force + tab_force

    virial, _ = op_module.prod_virial(
        net_deriv_reshape, descrpt_deriv, rij, nlist, axis, tnatoms,
        **sel_kwargs)
    sw_virial, _ = op_module.soft_min_virial(
        energy_diff, sw_deriv, rij, nlist, tnatoms, **sel_kwargs)
    # pair_tab yields per-atom virials (9 components per atom); sum them
    # over the natoms[1] atoms of each row to get the total.
    tab_virial = tf.reduce_sum(
        tf.reshape(tab_atom_virial, [-1, self.natoms[1], 9]), axis=1)
    virial = virial + sw_virial + tab_virial
    return energy, force, virial