Example #1
0
    def build (self,
               learning_rate,
               natoms,
               model_dict,
               label_dict,
               suffix):
        """Build the energy + energy-dipole training loss.

        The energy dipole of a frame is ``sum_i (e_i - E / N) * r_i``, where
        ``e_i`` are the predicted atomic energies, ``E`` the predicted total
        energy, ``N`` the number of atoms and ``r_i`` the coordinates.  It is
        compared with ``label_dict['energy_dipole']`` by mean-squared error and
        combined with the energy MSE using learning-rate dependent prefactors.

        Args:
            learning_rate: current learning rate.  Each prefactor is
                ``limit + (start - limit) * learning_rate /
                self.starter_learning_rate``, i.e. ``start`` at the starter
                rate and ``limit`` as the rate approaches zero.
            natoms: not used; the atom count is read from the shape of
                ``model_dict['atom_ener']``.  Kept for interface compatibility.
            model_dict: model outputs; reads 'coord', 'energy', 'atom_ener'.
            label_dict: labels; reads 'energy', 'energy_dipole',
                'find_energy', 'find_energy_dipole'.  The ``find_*`` values
                multiply the matching prefactor (presumably 0/1 flags marking
                whether the label is present -- confirm).
            suffix: appended to the names of the loss tensors.

        Returns:
            (l2_loss, more_loss): the total weighted loss and a dict holding
            the unweighted 'l2_ener_loss' and 'l2_ener_dipole_loss'.
        """
        coord = model_dict['coord']
        energy = model_dict['energy']
        atom_ener = model_dict['atom_ener']
        nframes = tf.shape(atom_ener)[0]
        # Atom count per frame, taken from the prediction's shape.  Kept in its
        # own local (a scalar tensor) so the ``natoms`` argument is not shadowed.
        nloc = tf.shape(atom_ener)[1]
        nloc_ener = global_cvt_2_ener_float(nloc)

        # build energy dipole: subtract the per-frame mean atomic energy
        # (broadcast over the atom axis), then contract with the coordinates.
        mean_atom_ener = tf.reshape(energy / nloc_ener, [-1, 1])
        # ``energy`` and ``atom_ener`` may be held in different float
        # precisions (energy vs. network precision); align before subtracting.
        atom_ener0 = atom_ener - tf.cast(mean_atom_ener, atom_ener.dtype)
        coord = tf.reshape(coord, [nframes, nloc, 3])
        atom_ener0 = tf.reshape(atom_ener0, [nframes, 1, nloc])
        # [nframes, 1, nloc] x [nframes, nloc, 3] -> [nframes, 1, 3]
        ener_dipole = tf.reshape(tf.matmul(atom_ener0, coord), [nframes, 3])

        energy_hat = label_dict['energy']
        ener_dipole_hat = label_dict['energy_dipole']
        find_energy = label_dict['find_energy']
        find_ener_dipole = label_dict['find_energy_dipole']

        l2_ener_loss = tf.reduce_mean(tf.square(energy - energy_hat), name='l2_' + suffix)

        ener_dipole_reshape = tf.reshape(ener_dipole, [-1])
        ener_dipole_hat_reshape = tf.reshape(ener_dipole_hat, [-1])
        # Own op name, following the ``l2_<term>_`` + suffix convention, so it
        # does not collide with the energy loss op above.
        l2_ener_dipole_loss = tf.reduce_mean(
            tf.square(ener_dipole_reshape - ener_dipole_hat_reshape),
            name='l2_ener_dipole_' + suffix)

        atom_norm_ener = 1. / nloc_ener
        pref_e = global_cvt_2_ener_float(find_energy * (self.limit_pref_e + (self.start_pref_e - self.limit_pref_e) * learning_rate / self.starter_learning_rate))
        pref_ed = global_cvt_2_tf_float(find_ener_dipole * (self.limit_pref_ed + (self.start_pref_ed - self.limit_pref_ed) * learning_rate / self.starter_learning_rate))

        l2_loss = 0
        more_loss = {}
        l2_loss += atom_norm_ener * (pref_e * l2_ener_loss)
        l2_loss += global_cvt_2_ener_float(pref_ed * l2_ener_dipole_loss)
        more_loss['l2_ener_loss'] = l2_ener_loss
        more_loss['l2_ener_dipole_loss'] = l2_ener_dipole_loss

        self.l2_loss_summary = tf.summary.scalar('l2_loss', tf.sqrt(l2_loss))
        # Per-atom RMSE of the energy; convert the energy-precision numerator to
        # network precision so the division is well-typed in mixed precision.
        self.l2_loss_ener_summary = tf.summary.scalar(
            'l2_ener_loss',
            global_cvt_2_tf_float(tf.sqrt(l2_ener_loss)) / global_cvt_2_tf_float(nloc))
        self.l2_ener_dipole_loss_summary = tf.summary.scalar('l2_ener_dipole_loss', tf.sqrt(l2_ener_dipole_loss))

        self.l2_l = l2_loss
        self.l2_more = more_loss
        return l2_loss, more_loss
Example #2
0
    def build(self,
              coord_,
              atype_,
              natoms,
              box,
              mesh,
              input_dict,
              frz_model=None,
              suffix='',
              reuse=None):
        """Build the energy model graph and return its output tensors.

        Builds (or imports from a frozen model) the descriptor, applies the
        fitting net to obtain atomic energies, optionally blends them with a
        short-range pair table through a soft-min switch, and derives the
        total energy, forces and virials.

        Args:
            coord_: atomic coordinates; reshaped here to
                ``[-1, natoms[1] * 3]``.
            atype_: atom types; reshaped here to ``[-1, natoms[1]]``.
            natoms: atom-count tensor.  As used below, ``natoms[0]`` is the
                number of atoms with predicted energies and ``natoms[1]`` the
                number of atoms carrying coordinates/forces (presumably local
                vs. local + ghost atoms -- confirm against the descriptor).
            box: forwarded to the descriptor.
            mesh: forwarded to the descriptor.
            input_dict: extra inputs forwarded to descriptor and fitting.
                Mutated in place: ``'type_embedding'`` is added when
                ``self.typeebd`` is set.
            frz_model: optional frozen model; when not None the descriptor
                sub-graph (including ``o_descriptor``) is imported from it
                instead of being built.
            suffix: appended to scope and output tensor names.
            reuse: forwarded to variable scopes and sub-builders.

        Returns:
            dict with keys 'energy', 'force', 'virial', 'atom_ener',
            'atom_virial', 'coord', 'atype'.
        """

        # Named graph constants describing the model (type map, model type,
        # version) plus, if a pair table is configured, its non-trainable data.
        with tf.variable_scope('model_attr' + suffix, reuse=reuse):
            t_tmap = tf.constant(' '.join(self.type_map),
                                 name='tmap',
                                 dtype=tf.string)
            t_mt = tf.constant(self.model_type,
                               name='model_type',
                               dtype=tf.string)
            t_ver = tf.constant(MODEL_VERSION,
                                name='model_version',
                                dtype=tf.string)

            if self.srtab is not None:
                tab_info, tab_data = self.srtab.get()
                self.tab_info = tf.get_variable(
                    't_tab_info',
                    tab_info.shape,
                    dtype=tf.float64,
                    trainable=False,
                    initializer=tf.constant_initializer(tab_info,
                                                        dtype=tf.float64))
                self.tab_data = tf.get_variable(
                    't_tab_data',
                    tab_data.shape,
                    dtype=tf.float64,
                    trainable=False,
                    initializer=tf.constant_initializer(tab_data,
                                                        dtype=tf.float64))

        coord = tf.reshape(coord_, [-1, natoms[1] * 3])
        atype = tf.reshape(atype_, [-1, natoms[1]])

        # type embedding if any
        if self.typeebd is not None:
            type_embedding = self.typeebd.build(
                self.ntypes,
                reuse=reuse,
                suffix=suffix,
            )
            input_dict['type_embedding'] = type_embedding

        if frz_model is None:
            dout \
                = self.descrpt.build(coord_,
                                     atype_,
                                     natoms,
                                     box,
                                     mesh,
                                     input_dict,
                                     suffix = suffix,
                                     reuse = reuse)
            dout = tf.identity(dout, name='o_descriptor')
        else:
            # The descriptor is not built in this branch, so record rcut and
            # ntypes as named graph constants here.
            tf.constant(self.rcut,
                        name='descrpt_attr/rcut',
                        dtype=GLOBAL_TF_FLOAT_PRECISION)
            tf.constant(self.ntypes,
                        name='descrpt_attr/ntypes',
                        dtype=tf.int32)
            feed_dict = self.descrpt.get_feed_dict(coord_, atype_, natoms, box,
                                                   mesh)
            # The descriptor's own tensors followed by its output; the last
            # imported tensor is the descriptor output itself.
            return_elements = [
                *self.descrpt.get_tensor_names(), 'o_descriptor:0'
            ]
            imported_tensors \
                = self._import_graph_def_from_frz_model(frz_model, feed_dict, return_elements)
            dout = imported_tensors[-1]
            self.descrpt.pass_tensors_from_frz_model(*imported_tensors[:-1])

        if self.srtab is not None:
            nlist, rij, sel_a, sel_r = self.descrpt.get_nlist()
            # Total neighbour counts (last element of the running sum).
            nnei_a = np.cumsum(sel_a)[-1]
            nnei_r = np.cumsum(sel_r)[-1]

        atom_ener = self.fitting.build(dout,
                                       natoms,
                                       input_dict,
                                       reuse=reuse,
                                       suffix=suffix)

        if self.srtab is not None:
            sw_lambda, sw_deriv \
                = op_module.soft_min_switch(atype,
                                            rij,
                                            nlist,
                                            natoms,
                                            sel_a = sel_a,
                                            sel_r = sel_r,
                                            alpha = self.smin_alpha,
                                            rmin = self.sw_rmin,
                                            rmax = self.sw_rmax)
            inv_sw_lambda = 1.0 - sw_lambda
            # NOTICE:
            # atom energy is not scaled,
            # force and virial are scaled
            tab_atom_ener, tab_force, tab_atom_virial \
                = op_module.pair_tab(self.tab_info,
                                      self.tab_data,
                                      atype,
                                      rij,
                                      nlist,
                                      natoms,
                                      sw_lambda,
                                      sel_a = sel_a,
                                      sel_r = sel_r)
            # Difference between table and NN atomic energies, consumed by the
            # soft_min_force / soft_min_virial ops below.
            energy_diff = tab_atom_ener - tf.reshape(atom_ener,
                                                     [-1, natoms[0]])
            # Blend: sw_lambda * E_tab + (1 - sw_lambda) * E_nn, per atom.
            tab_atom_ener = tf.reshape(sw_lambda, [-1]) * tf.reshape(
                tab_atom_ener, [-1])
            atom_ener = tf.reshape(inv_sw_lambda, [-1]) * atom_ener
            energy_raw = tab_atom_ener + atom_ener
        else:
            energy_raw = atom_ener

        energy_raw = tf.reshape(energy_raw, [-1, natoms[0]],
                                name='o_atom_energy' + suffix)
        energy = tf.reduce_sum(global_cvt_2_ener_float(energy_raw),
                               axis=1,
                               name='o_energy' + suffix)

        # Uses the NN atomic energy only (the (1 - sw_lambda)-scaled one when a
        # pair table is active); switch and table contributions are added below.
        force, virial, atom_virial \
            = self.descrpt.prod_force_virial (atom_ener, natoms)

        if self.srtab is not None:
            sw_force \
                = op_module.soft_min_force(energy_diff,
                                           sw_deriv,
                                           nlist,
                                           natoms,
                                           n_a_sel = nnei_a,
                                           n_r_sel = nnei_r)
            force = force + sw_force + tab_force

        force = tf.reshape(force, [-1, 3 * natoms[1]], name="o_force" + suffix)

        if self.srtab is not None:
            sw_virial, sw_atom_virial \
                = op_module.soft_min_virial (energy_diff,
                                             sw_deriv,
                                             rij,
                                             nlist,
                                             natoms,
                                             n_a_sel = nnei_a,
                                             n_r_sel = nnei_r)
            atom_virial = atom_virial + sw_atom_virial + tab_atom_virial
            virial = virial + sw_virial \
                     + tf.reduce_sum(tf.reshape(tab_atom_virial, [-1, natoms[1], 9]), axis = 1)

        virial = tf.reshape(virial, [-1, 9], name="o_virial" + suffix)
        atom_virial = tf.reshape(atom_virial, [-1, 9 * natoms[1]],
                                 name="o_atom_virial" + suffix)

        model_dict = {}
        model_dict['energy'] = energy
        model_dict['force'] = force
        model_dict['virial'] = virial
        model_dict['atom_ener'] = energy_raw
        model_dict['atom_virial'] = atom_virial
        model_dict['coord'] = coord
        model_dict['atype'] = atype

        return model_dict
Example #3
0
    def build (self,
               learning_rate,
               natoms,
               model_dict,
               label_dict,
               suffix):
        """Assemble the weighted energy/force/virial/atom-energy training loss.

        Every term is a mean-squared error between prediction and label.  A
        term enters the total only when the matching ``self.has_*`` flag is
        truthy, scaled by the learning-rate dependent prefactor
        ``find * (limit + (start - limit) * learning_rate /
        self.starter_learning_rate)``.  The unweighted term is always reported
        in ``more_loss``.

        Args:
            learning_rate: current learning rate used to interpolate each
                prefactor between ``start_pref_*`` and ``limit_pref_*``.
            natoms: atom-count tensor; ``natoms[0]`` normalises the energy
                and virial terms and the energy/virial summaries.
            model_dict: reads 'energy', 'force', 'virial', 'atom_ener'.
            label_dict: reads the matching labels plus 'atom_pref' and the
                'find_*' entries that multiply each prefactor.
            suffix: appended to the names of the loss tensors.

        Returns:
            (l2_loss, more_loss): the total weighted loss and a dict of the
            individual unweighted terms.
        """
        energy, force, virial, atom_ener = [
            model_dict[key] for key in ('energy', 'force', 'virial', 'atom_ener')
        ]
        (energy_hat, force_hat, virial_hat, atom_ener_hat, atom_pref,
         find_energy, find_force, find_virial,
         find_atom_ener, find_atom_pref) = [
            label_dict[key] for key in (
                'energy', 'force', 'virial', 'atom_ener', 'atom_pref',
                'find_energy', 'find_force', 'find_virial',
                'find_atom_ener', 'find_atom_pref')
        ]

        def _flat(tensor):
            # Collapse a tensor to rank 1.
            return tf.reshape(tensor, [-1])

        def _schedule(find, start, limit):
            # ``start`` at the starter learning rate, ``limit`` as the learning
            # rate goes to zero, gated by the label's ``find`` value.
            return find * (limit + (start - limit) * learning_rate / self.starter_learning_rate)

        l2_ener_loss = tf.reduce_mean(tf.square(energy - energy_hat), name='l2_' + suffix)

        f_pred = _flat(force)
        f_label = _flat(force_hat)
        f_weight = _flat(atom_pref)
        f_err = f_label - f_pred
        if self.relative_f is not None:
            # Divide each per-atom 3-vector error by (|F_label| + relative_f).
            f_label_xyz = tf.reshape(force_hat, [-1, 3])
            f_scale = tf.reshape(tf.norm(f_label_xyz, axis = 1), [-1, 1]) + self.relative_f
            f_err = _flat(tf.reshape(f_err, [-1, 3]) / f_scale)
        f_err_sq = tf.square(f_err)
        l2_force_loss = tf.reduce_mean(f_err_sq, name = "l2_force_" + suffix)
        l2_pref_force_loss = tf.reduce_mean(tf.multiply(f_err_sq, f_weight), name = "l2_pref_force_" + suffix)

        v_pred = _flat(virial)
        v_label = _flat(virial_hat)
        l2_virial_loss = tf.reduce_mean(tf.square(v_label - v_pred), name = "l2_virial_" + suffix)

        ae_pred = _flat(atom_ener)
        ae_label = _flat(atom_ener_hat)
        l2_atom_ener_loss = tf.reduce_mean(tf.square(ae_label - ae_pred), name = "l2_atom_ener_" + suffix)

        atom_norm = 1. / global_cvt_2_tf_float(natoms[0])
        atom_norm_ener = 1. / global_cvt_2_ener_float(natoms[0])
        pref_e = global_cvt_2_ener_float(_schedule(find_energy, self.start_pref_e, self.limit_pref_e))
        pref_f = global_cvt_2_tf_float(_schedule(find_force, self.start_pref_f, self.limit_pref_f))
        pref_v = global_cvt_2_tf_float(_schedule(find_virial, self.start_pref_v, self.limit_pref_v))
        pref_ae = global_cvt_2_tf_float(_schedule(find_atom_ener, self.start_pref_ae, self.limit_pref_ae))
        pref_pf = global_cvt_2_tf_float(_schedule(find_atom_pref, self.start_pref_pf, self.limit_pref_pf))

        # (report key, unweighted term, enabled flag, builder of weighted term);
        # builders are deferred so disabled terms add no ops to the graph.
        loss_terms = (
            ('l2_ener_loss', l2_ener_loss, self.has_e,
             lambda: atom_norm_ener * (pref_e * l2_ener_loss)),
            ('l2_force_loss', l2_force_loss, self.has_f,
             lambda: global_cvt_2_ener_float(pref_f * l2_force_loss)),
            ('l2_virial_loss', l2_virial_loss, self.has_v,
             lambda: global_cvt_2_ener_float(atom_norm * (pref_v * l2_virial_loss))),
            ('l2_atom_ener_loss', l2_atom_ener_loss, self.has_ae,
             lambda: global_cvt_2_ener_float(pref_ae * l2_atom_ener_loss)),
            ('l2_pref_force_loss', l2_pref_force_loss, self.has_pf,
             lambda: global_cvt_2_ener_float(pref_pf * l2_pref_force_loss)),
        )
        l2_loss = 0
        more_loss = {}
        for report_key, raw_term, enabled, build_weighted in loss_terms:
            if enabled:
                l2_loss += build_weighted()
            more_loss[report_key] = raw_term

        # only used when tensorboard was set as true
        self.l2_loss_summary = tf.summary.scalar('l2_loss', tf.sqrt(l2_loss))
        self.l2_loss_ener_summary = tf.summary.scalar('l2_ener_loss', global_cvt_2_tf_float(tf.sqrt(l2_ener_loss)) / global_cvt_2_tf_float(natoms[0]))
        self.l2_loss_force_summary = tf.summary.scalar('l2_force_loss', tf.sqrt(l2_force_loss))
        self.l2_loss_virial_summary = tf.summary.scalar('l2_virial_loss', tf.sqrt(l2_virial_loss) / global_cvt_2_tf_float(natoms[0]))

        self.l2_l = l2_loss
        self.l2_more = more_loss
        return l2_loss, more_loss