def build(self,
          coord_,
          atype_,
          natoms,
          box,
          mesh,
          input_dict,
          frz_model=None,
          suffix='',
          reuse=None):
    """Build the energy-model graph: descriptor -> fitting -> energy/force/virial.

    When ``self.srtab`` is set, the network atomic energy is blended with a
    tabulated pair energy through a soft-min switch: the table term is
    weighted by ``sw_lambda`` and the network term by ``1 - sw_lambda``.
    Switch-derivative and table contributions are added to force and virial.

    Parameters
    ----------
    coord_
        Coordinate tensor, reshaped here to ``[-1, natoms[1] * 3]``.
    atype_
        Atom-type tensor, reshaped here to ``[-1, natoms[1]]``.
    natoms
        Atom counts. ``natoms[0]`` is used as the per-frame count for atomic
        energies. ``natoms[1]`` is used for coordinates, forces and per-atom
        virials, as in the reshapes below.
    box, mesh
        Forwarded to ``self.descrpt.build`` / ``self.descrpt.get_feed_dict``.
    input_dict : dict
        Extra inputs forwarded to the descriptor and fitting nets. It is
        mutated in place: ``'type_embedding'`` is set when ``self.typeebd``
        is not None.
    frz_model : optional
        If None, the descriptor is built from scratch. Otherwise the
        descriptor sub-graph is imported from this frozen model through
        ``self._import_graph_def_from_frz_model``.
    suffix : str
        Appended to the ``model_attr`` scope name and to the output tensor
        names (``o_energy``, ``o_force``, ...).
    reuse : optional
        Variable-reuse flag forwarded to ``tf.variable_scope`` and to the
        sub-builders.

    Returns
    -------
    dict
        Keys ``'energy'``, ``'force'``, ``'virial'``, ``'atom_ener'``,
        ``'atom_virial'``, ``'coord'`` and ``'atype'``.

    Side effects
    ------------
    Sets ``self.tab_info`` / ``self.tab_data`` when ``self.srtab`` is not
    None. On the frozen-model path, calls
    ``self.descrpt.pass_tensors_from_frz_model``.
    """
    with tf.variable_scope('model_attr' + suffix, reuse=reuse):
        # These are not used below. Creating the named constants records the
        # type map, model type and model version as nodes in the graph.
        t_tmap = tf.constant(' '.join(self.type_map),
                             name='tmap',
                             dtype=tf.string)
        t_mt = tf.constant(self.model_type,
                           name='model_type',
                           dtype=tf.string)
        t_ver = tf.constant(MODEL_VERSION,
                            name='model_version',
                            dtype=tf.string)

        if self.srtab is not None:
            # Store the short-range table as non-trainable float64 variables.
            # op_module.pair_tab consumes them further down.
            tab_info, tab_data = self.srtab.get()
            self.tab_info = tf.get_variable(
                't_tab_info',
                tab_info.shape,
                dtype=tf.float64,
                trainable=False,
                initializer=tf.constant_initializer(tab_info,
                                                    dtype=tf.float64))
            self.tab_data = tf.get_variable(
                't_tab_data',
                tab_data.shape,
                dtype=tf.float64,
                trainable=False,
                initializer=tf.constant_initializer(tab_data,
                                                    dtype=tf.float64))

    coord = tf.reshape(coord_, [-1, natoms[1] * 3])
    atype = tf.reshape(atype_, [-1, natoms[1]])

    # Type embedding, if configured, reaches the descriptor and fitting nets
    # through input_dict.
    if self.typeebd is not None:
        type_embedding = self.typeebd.build(
            self.ntypes,
            reuse=reuse,
            suffix=suffix,
        )
        input_dict['type_embedding'] = type_embedding

    if frz_model is None:
        # Build the descriptor from scratch.
        dout = self.descrpt.build(coord_,
                                  atype_,
                                  natoms,
                                  box,
                                  mesh,
                                  input_dict,
                                  suffix=suffix,
                                  reuse=reuse)
        dout = tf.identity(dout, name='o_descriptor')
    else:
        # Import the descriptor sub-graph from the frozen model. The
        # descrpt_attr constants are created explicitly here, presumably
        # because descrpt.build() is skipped on this path.
        tf.constant(self.rcut,
                    name='descrpt_attr/rcut',
                    dtype=GLOBAL_TF_FLOAT_PRECISION)
        tf.constant(self.ntypes,
                    name='descrpt_attr/ntypes',
                    dtype=tf.int32)
        feed_dict = self.descrpt.get_feed_dict(coord_, atype_, natoms, box,
                                               mesh)
        return_elements = [*self.descrpt.get_tensor_names(), 'o_descriptor:0']
        imported_tensors = self._import_graph_def_from_frz_model(
            frz_model, feed_dict, return_elements)
        # The last element is 'o_descriptor:0'. The remaining elements match
        # descrpt.get_tensor_names() and are handed back to the descriptor.
        dout = imported_tensors[-1]
        self.descrpt.pass_tensors_from_frz_model(*imported_tensors[:-1])

    if self.srtab is not None:
        nlist, rij, sel_a, sel_r = self.descrpt.get_nlist()
        # Total neighbour counts: the last cumulative sum equals the sum of
        # the per-type selections.
        nnei_a = np.cumsum(sel_a)[-1]
        nnei_r = np.cumsum(sel_r)[-1]

    atom_ener = self.fitting.build(dout,
                                   natoms,
                                   input_dict,
                                   reuse=reuse,
                                   suffix=suffix)

    if self.srtab is not None:
        sw_lambda, sw_deriv = op_module.soft_min_switch(
            atype,
            rij,
            nlist,
            natoms,
            sel_a=sel_a,
            sel_r=sel_r,
            alpha=self.smin_alpha,
            rmin=self.sw_rmin,
            rmax=self.sw_rmax)
        inv_sw_lambda = 1.0 - sw_lambda
        # NOTICE:
        # pair_tab's atom energy is not scaled by sw_lambda (it is scaled
        # below); its force and virial are scaled.
        tab_atom_ener, tab_force, tab_atom_virial = op_module.pair_tab(
            self.tab_info,
            self.tab_data,
            atype,
            rij,
            nlist,
            natoms,
            sw_lambda,
            sel_a=sel_a,
            sel_r=sel_r)
        # Unscaled table-minus-network energy. It feeds the switch-derivative
        # force/virial corrections below.
        energy_diff = tab_atom_ener - tf.reshape(atom_ener, [-1, natoms[0]])
        tab_atom_ener = tf.reshape(sw_lambda, [-1]) * tf.reshape(
            tab_atom_ener, [-1])
        # atom_ener is rebound to the scaled network energy. The force and
        # virial from prod_force_virial below are derived from this value.
        atom_ener = tf.reshape(inv_sw_lambda, [-1]) * atom_ener
        energy_raw = tab_atom_ener + atom_ener
    else:
        energy_raw = atom_ener

    energy_raw = tf.reshape(energy_raw, [-1, natoms[0]],
                            name='o_atom_energy' + suffix)
    energy = tf.reduce_sum(global_cvt_2_ener_float(energy_raw),
                           axis=1,
                           name='o_energy' + suffix)

    force, virial, atom_virial = self.descrpt.prod_force_virial(atom_ener,
                                                                natoms)

    if self.srtab is not None:
        sw_force = op_module.soft_min_force(energy_diff,
                                            sw_deriv,
                                            nlist,
                                            natoms,
                                            n_a_sel=nnei_a,
                                            n_r_sel=nnei_r)
        force = force + sw_force + tab_force

    force = tf.reshape(force, [-1, 3 * natoms[1]], name="o_force" + suffix)

    if self.srtab is not None:
        sw_virial, sw_atom_virial = op_module.soft_min_virial(
            energy_diff,
            sw_deriv,
            rij,
            nlist,
            natoms,
            n_a_sel=nnei_a,
            n_r_sel=nnei_r)
        atom_virial = atom_virial + sw_atom_virial + tab_atom_virial
        # Frame virial gets the per-atom table virial summed over atoms.
        virial = virial + sw_virial \
            + tf.reduce_sum(tf.reshape(tab_atom_virial, [-1, natoms[1], 9]),
                            axis=1)

    virial = tf.reshape(virial, [-1, 9], name="o_virial" + suffix)
    atom_virial = tf.reshape(atom_virial, [-1, 9 * natoms[1]],
                             name="o_atom_virial" + suffix)

    model_dict = {
        'energy': energy,
        'force': force,
        'virial': virial,
        'atom_ener': energy_raw,
        'atom_virial': atom_virial,
        'coord': coord,
        'atype': atype,
    }
    return model_dict
def comp_ef(self, dcoord, dbox, dtype, tnatoms, name, reuse=None):
    """Return (energy, force, virial) of the network blended with a pair table.

    The descriptor from ``op_module.prod_env_mat_a`` is fed to ``self._net``.
    The resulting atomic energy is mixed with the tabulated pair energy from
    ``op_module.pair_tab`` using the ``soft_min_switch`` weight: the table
    term is weighted by ``sw_lambda`` and the network term by
    ``1 - sw_lambda``.

    Forces and virials are the sum of three contributions:
    - the network term (``prod_force_se_a`` / ``prod_virial_se_a``, driven by
      the gradient of the scaled network energy);
    - the switch-derivative correction (``soft_min_force`` /
      ``soft_min_virial``);
    - the table term.

    Per-atom virials are computed by the ops but not returned.

    Parameters
    ----------
    dcoord, dbox
        Coordinate and box tensors for ``prod_env_mat_a``.
    dtype
        Presumably the atom-type tensor. It fills the slot ``build`` uses for
        ``atype`` in ``soft_min_switch`` / ``pair_tab``. It is not a TF dtype.
    tnatoms
        Atom-count tensor passed to the ops.
    name, reuse
        Forwarded to ``self._net``.

    Returns
    -------
    tuple
        ``(energy, force, virial)``. ``energy`` is the per-frame sum of atomic
        energies reshaped to ``[-1, self.natoms[0]]``.
    """
    env_mat, env_mat_deriv, rij, nlist = op_module.prod_env_mat_a(
        dcoord,
        dtype,
        tnatoms,
        dbox,
        tf.constant(self.default_mesh),
        self.t_avg,
        self.t_std,
        rcut_a=self.rcut_a,
        rcut_r=self.rcut_r,
        rcut_r_smth=self.rcut_r_smth,
        sel_a=self.sel_a,
        sel_r=self.sel_r)
    net_input = tf.reshape(env_mat, [-1, self.ndescrpt])
    net_ener = self._net(net_input, name, reuse=reuse)

    weight, weight_deriv = op_module.soft_min_switch(
        dtype,
        rij,
        nlist,
        tnatoms,
        sel_a=self.sel_a,
        sel_r=self.sel_r,
        alpha=self.smin_alpha,
        rmin=self.sw_rmin,
        rmax=self.sw_rmax)
    complement = 1.0 - weight

    pair_ener, pair_force, pair_atom_virial = op_module.pair_tab(
        self.tab_info,
        self.tab_data,
        dtype,
        rij,
        nlist,
        tnatoms,
        weight,
        sel_a=self.sel_a,
        sel_r=self.sel_r)
    # Difference of the unscaled energies; drives the switch corrections.
    ener_gap = pair_ener - tf.reshape(net_ener, [-1, self.natoms[0]])
    weighted_pair_ener = tf.reshape(weight, [-1]) * tf.reshape(pair_ener, [-1])
    # The gradient below is taken with respect to this scaled network energy.
    weighted_net_ener = tf.reshape(complement, [-1]) * net_ener
    atom_energy = weighted_pair_ener + weighted_net_ener

    atom_energy = tf.reshape(atom_energy, [-1, self.natoms[0]])
    energy = tf.reduce_sum(atom_energy, axis=1)

    grad_input = tf.gradients(weighted_net_ener, net_input)[0]
    grad_flat = tf.reshape(grad_input, [-1, self.natoms[0] * self.ndescrpt])

    net_force = op_module.prod_force_se_a(grad_flat,
                                          env_mat_deriv,
                                          nlist,
                                          tnatoms,
                                          n_a_sel=self.nnei_a,
                                          n_r_sel=self.nnei_r)
    switch_force = op_module.soft_min_force(ener_gap,
                                            weight_deriv,
                                            nlist,
                                            tnatoms,
                                            n_a_sel=self.nnei_a,
                                            n_r_sel=self.nnei_r)
    force = net_force + switch_force + pair_force

    net_virial, _net_atom_virial = op_module.prod_virial_se_a(
        grad_flat,
        env_mat_deriv,
        rij,
        nlist,
        tnatoms,
        n_a_sel=self.nnei_a,
        n_r_sel=self.nnei_r)
    switch_virial, _switch_atom_virial = op_module.soft_min_virial(
        ener_gap,
        weight_deriv,
        rij,
        nlist,
        tnatoms,
        n_a_sel=self.nnei_a,
        n_r_sel=self.nnei_r)
    # Only the frame virial is assembled. The per-atom pieces above are
    # discarded, and the table's per-atom virial is summed over atoms.
    virial = net_virial + switch_virial + tf.reduce_sum(
        tf.reshape(pair_atom_virial, [-1, self.natoms[1], 9]), axis=1)

    return energy, force, virial