class EnerModel:
    """Energy model.

    Parameters
    ----------
    descrpt
            Descriptor
    fitting
            Fitting net
    typeebd
            Type embedding net (optional). When provided, its ``build`` output
            is stored in ``input_dict['type_embedding']`` before the
            descriptor is built.
    type_map
            Mapping atom type to the name (str) of the type.
            For example `type_map[1]` gives the name of the type 1.
    data_stat_nbatch
            Number of frames used for data statistic
    data_stat_protect
            Protect parameter for atomic energy regression
    use_srtab
            The table for the short-range pairwise interaction added on top of DP.
            The table is a text data file with (N_t + 1) * N_t / 2 + 1 columns.
            The first column is the distance between atoms.
            The second to the last columns are energies for pairs of certain types.
            For example we have two atom types, 0 and 1.
            The columns from 2nd to 4th are for 0-0, 0-1 and 1-1 correspondingly.
    smin_alpha
            The short-range tabulated interaction will be switched according to
            the distance of the nearest neighbor. This distance is calculated by
            softmin. This parameter is the decaying parameter in the softmin.
            It is only required when `use_srtab` is provided.
    sw_rmin
            The lower boundary of the interpolation between short-range
            tabulated interaction and DP. It is only required when `use_srtab`
            is provided.
    sw_rmax
            The upper boundary of the interpolation between short-range
            tabulated interaction and DP. It is only required when `use_srtab`
            is provided.

    Raises
    ------
    ValueError
            If `use_srtab` is provided but any of `smin_alpha`, `sw_rmin`,
            `sw_rmax` is missing.
    """
    model_type = 'ener'

    def __init__(
            self,
            descrpt,
            fitting,
            typeebd=None,
            type_map: List[str] = None,
            data_stat_nbatch: int = 10,
            data_stat_protect: float = 1e-2,
            use_srtab: str = None,
            smin_alpha: float = None,
            sw_rmin: float = None,
            sw_rmax: float = None,
    ) -> None:
        """Store the sub-models and options; load the pair table if requested."""
        # descriptor
        self.descrpt = descrpt
        self.rcut = self.descrpt.get_rcut()
        self.ntypes = self.descrpt.get_ntypes()
        # fitting
        self.fitting = fitting
        self.numb_fparam = self.fitting.get_numb_fparam()
        # type embedding
        self.typeebd = typeebd
        # other inputs
        self.type_map = [] if type_map is None else type_map
        self.data_stat_nbatch = data_stat_nbatch
        self.data_stat_protect = data_stat_protect
        self.srtab_name = use_srtab
        if self.srtab_name is not None:
            # These are forwarded to op_module.soft_min_switch in build() and are
            # documented as required with use_srtab; reject a partial config up
            # front instead of failing later during graph construction.
            missing = [
                arg_name
                for arg_name, value in (
                    ('smin_alpha', smin_alpha),
                    ('sw_rmin', sw_rmin),
                    ('sw_rmax', sw_rmax),
                )
                if value is None
            ]
            if missing:
                raise ValueError(
                    'use_srtab is set, so the following must also be '
                    'provided: ' + ', '.join(missing))
            self.srtab = PairTab(self.srtab_name)
            self.smin_alpha = smin_alpha
            self.sw_rmin = sw_rmin
            self.sw_rmax = sw_rmax
        else:
            self.srtab = None

    def get_rcut(self):
        """Return the cutoff radius taken from the descriptor."""
        return self.rcut

    def get_ntypes(self):
        """Return the number of atom types taken from the descriptor."""
        return self.ntypes

    def get_type_map(self):
        """Return the type map (an empty list if none was given)."""
        return self.type_map

    def data_stat(self, data):
        """Compute input and output statistics from ``data``.

        Statistics are collected per system with ``self.data_stat_nbatch``.
        The merged statistics feed the descriptor/fitting input statistics.
        The per-system statistics feed the fitting output statistics.
        """
        all_stat = make_stat_input(data, self.data_stat_nbatch, merge_sys=False)
        m_all_stat = merge_sys_stat(all_stat)
        self._compute_input_stat(m_all_stat, protection=self.data_stat_protect)
        self._compute_output_stat(all_stat)

    def _compute_input_stat(self, all_stat, protection=1e-2):
        """Forward merged statistics to the descriptor and the fitting net."""
        self.descrpt.compute_input_stats(all_stat['coord'],
                                         all_stat['box'],
                                         all_stat['type'],
                                         all_stat['natoms_vec'],
                                         all_stat['default_mesh'],
                                         all_stat)
        self.fitting.compute_input_stats(all_stat, protection=protection)

    def _compute_output_stat(self, all_stat):
        """Forward per-system statistics to the fitting net's output stats."""
        self.fitting.compute_output_stats(all_stat)

    def build(self,
              coord_,
              atype_,
              natoms,
              box,
              mesh,
              input_dict,
              frz_model=None,
              suffix='',
              reuse=None):
        """Build the energy-model graph.

        Parameters
        ----------
        coord_
                Coordinates; reshaped to ``[-1, natoms[1] * 3]``.
        atype_
                Atom types; reshaped to ``[-1, natoms[1]]``.
        natoms
                Atom counts. ``natoms[1]`` is the per-frame atom count used for
                coordinates, forces and atomic virials. ``natoms[0]`` is the
                per-frame count used for atomic energies (presumably local
                atoms -- confirm).
        box
                Simulation box, passed through to the descriptor.
        mesh
                Mesh, passed through to the descriptor.
        input_dict
                Extra inputs for descriptor and fitting. ``'type_embedding'``
                is written into it when a type embedding net is configured.
        frz_model
                Frozen model passed to ``load_graph_def``. If given, the
                descriptor output is imported from it (``o_descriptor:0``)
                instead of being built.
        suffix
                Suffix appended to scope and output tensor names.
        reuse
                Variable-reuse flag forwarded to scopes and sub-builders.

        Returns
        -------
        dict
                Keys ``energy``, ``force``, ``virial``, ``atom_ener``,
                ``atom_virial``, ``coord`` and ``atype``.
        """
        with tf.variable_scope('model_attr' + suffix, reuse=reuse):
            # Metadata nodes: created only for their presence in the graph.
            tf.constant(' '.join(self.type_map),
                        name='tmap',
                        dtype=tf.string)
            tf.constant(self.model_type,
                        name='model_type',
                        dtype=tf.string)
            tf.constant(MODEL_VERSION,
                        name='model_version',
                        dtype=tf.string)

            if self.srtab is not None:
                tab_info, tab_data = self.srtab.get()
                self.tab_info = tf.get_variable(
                    't_tab_info',
                    tab_info.shape,
                    dtype=tf.float64,
                    trainable=False,
                    initializer=tf.constant_initializer(tab_info, dtype=tf.float64))
                self.tab_data = tf.get_variable(
                    't_tab_data',
                    tab_data.shape,
                    dtype=tf.float64,
                    trainable=False,
                    initializer=tf.constant_initializer(tab_data, dtype=tf.float64))

        coord = tf.reshape(coord_, [-1, natoms[1] * 3])
        atype = tf.reshape(atype_, [-1, natoms[1]])

        # type embedding if any
        if self.typeebd is not None:
            type_embedding = self.typeebd.build(
                self.ntypes,
                reuse=reuse,
                suffix=suffix,
            )
            input_dict['type_embedding'] = type_embedding

        if frz_model is None:
            dout = self.descrpt.build(coord_,
                                      atype_,
                                      natoms,
                                      box,
                                      mesh,
                                      input_dict,
                                      suffix=suffix,
                                      reuse=reuse)
            dout = tf.identity(dout, name='o_descriptor')
        else:
            tf.constant(self.rcut,
                        name='descrpt_attr/rcut',
                        dtype=GLOBAL_TF_FLOAT_PRECISION)
            tf.constant(self.ntypes,
                        name='descrpt_attr/ntypes',
                        dtype=tf.int32)
            feed_dict = self.descrpt.get_feed_dict(coord_, atype_, natoms, box, mesh)
            return_elements = [
                *self.descrpt.get_tensor_names(),
                'o_descriptor:0',
            ]
            imported_tensors = self._import_graph_def_from_frz_model(
                frz_model, feed_dict, return_elements)
            # Last element is the descriptor output; the rest belong to the descriptor.
            dout = imported_tensors[-1]
            self.descrpt.pass_tensors_from_frz_model(*imported_tensors[:-1])

        if self.srtab is not None:
            nlist, rij, sel_a, sel_r = self.descrpt.get_nlist()
            nnei_a = np.cumsum(sel_a)[-1]
            nnei_r = np.cumsum(sel_r)[-1]

        atom_ener = self.fitting.build(dout,
                                       natoms,
                                       input_dict,
                                       reuse=reuse,
                                       suffix=suffix)

        if self.srtab is not None:
            sw_lambda, sw_deriv = op_module.soft_min_switch(
                atype,
                rij,
                nlist,
                natoms,
                sel_a=sel_a,
                sel_r=sel_r,
                alpha=self.smin_alpha,
                rmin=self.sw_rmin,
                rmax=self.sw_rmax)
            inv_sw_lambda = 1.0 - sw_lambda
            # NOTICE:
            # pair_tab returns the tabulated atom energy unscaled, while the
            # force and virial it returns are already scaled (it receives
            # sw_lambda).
            tab_atom_ener, tab_force, tab_atom_virial = op_module.pair_tab(
                self.tab_info,
                self.tab_data,
                atype,
                rij,
                nlist,
                natoms,
                sw_lambda,
                sel_a=sel_a,
                sel_r=sel_r)
            # Tabulated minus DP atom energy; consumed by soft_min_force and
            # soft_min_virial below.
            energy_diff = tab_atom_ener - tf.reshape(atom_ener, [-1, natoms[0]])
            tab_atom_ener = tf.reshape(sw_lambda, [-1]) * tf.reshape(tab_atom_ener, [-1])
            atom_ener = tf.reshape(inv_sw_lambda, [-1]) * atom_ener
            energy_raw = tab_atom_ener + atom_ener
        else:
            energy_raw = atom_ener

        energy_raw = tf.reshape(energy_raw, [-1, natoms[0]], name='o_atom_energy' + suffix)
        energy = tf.reduce_sum(global_cvt_2_ener_float(energy_raw),
                               axis=1,
                               name='o_energy' + suffix)

        # With srtab, atom_ener has already been scaled by (1 - sw_lambda) above.
        # The terms from the derivative of sw_lambda are added separately below.
        force, virial, atom_virial = self.descrpt.prod_force_virial(atom_ener, natoms)

        if self.srtab is not None:
            sw_force = op_module.soft_min_force(energy_diff,
                                                sw_deriv,
                                                nlist,
                                                natoms,
                                                n_a_sel=nnei_a,
                                                n_r_sel=nnei_r)
            force = force + sw_force + tab_force

        force = tf.reshape(force, [-1, 3 * natoms[1]], name="o_force" + suffix)

        if self.srtab is not None:
            sw_virial, sw_atom_virial = op_module.soft_min_virial(energy_diff,
                                                                  sw_deriv,
                                                                  rij,
                                                                  nlist,
                                                                  natoms,
                                                                  n_a_sel=nnei_a,
                                                                  n_r_sel=nnei_r)
            atom_virial = atom_virial + sw_atom_virial + tab_atom_virial
            virial = virial + sw_virial \
                + tf.reduce_sum(tf.reshape(tab_atom_virial, [-1, natoms[1], 9]), axis=1)

        virial = tf.reshape(virial, [-1, 9], name="o_virial" + suffix)
        atom_virial = tf.reshape(atom_virial, [-1, 9 * natoms[1]], name="o_atom_virial" + suffix)

        model_dict = {}
        model_dict['energy'] = energy
        model_dict['force'] = force
        model_dict['virial'] = virial
        model_dict['atom_ener'] = energy_raw
        model_dict['atom_virial'] = atom_virial
        model_dict['coord'] = coord
        model_dict['atype'] = atype

        return model_dict

    def _import_graph_def_from_frz_model(self, frz_model, feed_dict, return_elements):
        """Import ``frz_model`` into the current graph and return the requested tensors.

        ``feed_dict`` is used as the ``input_map`` of the import.
        """
        _, graph_def = load_graph_def(frz_model)
        return tf.import_graph_def(graph_def, input_map=feed_dict, return_elements=return_elements)
class IntplInter(Inter):
    """Test interaction that blends a network atom energy with a tabulated pair potential.

    The pair table is read from 'tab.xvg'. It is switched in through the
    soft-min switch (``smin_alpha``, ``sw_rmin``, ``sw_rmax``), and the forces
    and virials include the switch-derivative and table contributions.
    """

    def setUp(self, data, sess=None):
        """Set up the base interaction, then load the pair table into TF variables.

        ``_make_tab`` is called for ``data.get_ntypes()`` types and the table is
        then read from 'tab.xvg' (presumably the file ``_make_tab`` writes).
        Its info/data arrays become non-trainable float64 variables inside the
        'tab' variable scope, opened with ``tf.AUTO_REUSE``.
        """
        Inter.setUp(self, data, sess=sess)
        # tabulated
        _make_tab(data.get_ntypes())
        self.srtab = PairTab('tab.xvg')
        self.smin_alpha = 0.3
        self.sw_rmin = 1
        self.sw_rmax = 3.45
        info_arr, data_arr = self.srtab.get()

        def _constant_var(var_name, values):
            # Non-trainable float64 variable initialised from ``values``.
            return tf.get_variable(var_name,
                                   values.shape,
                                   dtype=tf.float64,
                                   trainable=False,
                                   initializer=tf.constant_initializer(values))

        with tf.variable_scope('tab', reuse=tf.AUTO_REUSE):
            self.tab_info = _constant_var('t_tab_info', info_arr)
            self.tab_data = _constant_var('t_tab_data', data_arr)

    def tearDown(self):
        """Delete the table file used by setUp."""
        os.remove('tab.xvg')

    def comp_ef(self, dcoord, dbox, dtype, tnatoms, name, reuse=None):
        """Build energy, force and virial tensors for the switched model.

        The per-atom energy is
        ``sw_lambda * E_tab + (1 - sw_lambda) * E_net``.
        The force adds the network force (from the gradient of the scaled
        network energy), the soft-min switch force and the table force.
        The virial is assembled the same way, with the per-atom table virial
        summed over ``self.natoms[1]`` atoms.

        Returns
        -------
        tuple
                ``(energy, force, virial)`` tensors.
        """
        mesh_t = tf.constant(self.default_mesh)
        env_mat, env_mat_deriv, rij, nlist = op_module.prod_env_mat_a(
            dcoord,
            dtype,
            tnatoms,
            dbox,
            mesh_t,
            self.t_avg,
            self.t_std,
            rcut_a=self.rcut_a,
            rcut_r=self.rcut_r,
            rcut_r_smth=self.rcut_r_smth,
            sel_a=self.sel_a,
            sel_r=self.sel_r)
        net_input = tf.reshape(env_mat, [-1, self.ndescrpt])
        net_ener = self._net(net_input, name, reuse=reuse)

        sw_lambda, sw_deriv = op_module.soft_min_switch(
            dtype,
            rij,
            nlist,
            tnatoms,
            sel_a=self.sel_a,
            sel_r=self.sel_r,
            alpha=self.smin_alpha,
            rmin=self.sw_rmin,
            rmax=self.sw_rmax)
        one_minus_lambda = 1.0 - sw_lambda
        tab_ener, tab_force, tab_atom_virial = op_module.pair_tab(
            self.tab_info,
            self.tab_data,
            dtype,
            rij,
            nlist,
            tnatoms,
            sw_lambda,
            sel_a=self.sel_a,
            sel_r=self.sel_r)

        # Table minus network energy; drives the switch-derivative terms.
        ener_gap = tab_ener - tf.reshape(net_ener, [-1, self.natoms[0]])
        weighted_tab = tf.reshape(sw_lambda, [-1]) * tf.reshape(tab_ener, [-1])
        weighted_net = tf.reshape(one_minus_lambda, [-1]) * net_ener
        per_atom = tf.reshape(weighted_tab + weighted_net, [-1, self.natoms[0]])
        energy = tf.reduce_sum(per_atom, axis=1)

        d_net = tf.gradients(weighted_net, net_input)[0]
        d_net_flat = tf.reshape(d_net, [-1, self.natoms[0] * self.ndescrpt])

        force = op_module.prod_force_se_a(d_net_flat,
                                          env_mat_deriv,
                                          nlist,
                                          tnatoms,
                                          n_a_sel=self.nnei_a,
                                          n_r_sel=self.nnei_r)
        switch_force = op_module.soft_min_force(ener_gap,
                                                sw_deriv,
                                                nlist,
                                                tnatoms,
                                                n_a_sel=self.nnei_a,
                                                n_r_sel=self.nnei_r)
        force = force + switch_force + tab_force

        virial, _ = op_module.prod_virial_se_a(d_net_flat,
                                               env_mat_deriv,
                                               rij,
                                               nlist,
                                               tnatoms,
                                               n_a_sel=self.nnei_a,
                                               n_r_sel=self.nnei_r)
        switch_virial, _ = op_module.soft_min_virial(ener_gap,
                                                     sw_deriv,
                                                     rij,
                                                     nlist,
                                                     tnatoms,
                                                     n_a_sel=self.nnei_a,
                                                     n_r_sel=self.nnei_r)
        virial = virial + switch_virial \
            + tf.reduce_sum(tf.reshape(tab_atom_virial, [-1, self.natoms[1], 9]), axis=1)

        return energy, force, virial