Esempio n. 1
0
    def build(self,
              coord_,
              atype_,
              natoms,
              box,
              mesh,
              input_dict,
              frz_model=None,
              suffix='',
              reuse=None):
        """Assemble the energy model graph and return its output tensors.

        The descriptor is either built from scratch or imported from a frozen
        model, the fitting net turns it into per-atom energies, and forces and
        virials are derived through ``self.descrpt.prod_force_virial``.  When
        ``self.srtab`` is set, the fitted energy is blended with a tabulated
        pair interaction via ``op_module.soft_min_switch`` /
        ``op_module.pair_tab`` and the matching force and virial corrections
        are added.

        Parameters
        ----------
        coord_
            Coordinates.  Forwarded unchanged to the descriptor; a copy
            reshaped to ``[-1, natoms[1] * 3]`` is returned under ``'coord'``.
        atype_
            Atom types.  Forwarded unchanged to the descriptor; a copy
            reshaped to ``[-1, natoms[1]]`` is returned under ``'atype'`` and
            used by the tabulation ops.
        natoms
            Indexable atom-count container.  ``natoms[0]`` is the per-frame
            width used when reshaping energies, ``natoms[1]`` the one used for
            coordinates, types, forces and atomic virials.
        box, mesh
            Forwarded to the descriptor.
        input_dict
            Dictionary forwarded to the descriptor and the fitting net.
            NOTE: it is mutated in place; ``'type_embedding'`` is added when
            ``self.typeebd`` is not ``None``.
        frz_model
            ``None`` to build the descriptor here.  Any other value is handed
            to ``self._import_graph_def_from_frz_model`` and the descriptor
            output is taken from the imported graph instead.
        suffix
            String appended to the variable scope and to the output node names.
        reuse
            Forwarded to ``tf.variable_scope`` and to the sub-builders.

        Returns
        -------
        dict
            Tensors under the keys ``'energy'``, ``'force'``, ``'virial'``,
            ``'atom_ener'``, ``'atom_virial'``, ``'coord'`` and ``'atype'``.
        """
        with tf.variable_scope('model_attr' + suffix, reuse=reuse):
            # These constants are created only for the named nodes they add to
            # the graph; the Python handles are not needed afterwards.
            tf.constant(' '.join(self.type_map),
                        name='tmap',
                        dtype=tf.string)
            tf.constant(self.model_type,
                        name='model_type',
                        dtype=tf.string)
            tf.constant(MODEL_VERSION,
                        name='model_version',
                        dtype=tf.string)

            if self.srtab is not None:
                # Store the table as non-trainable float64 variables.
                tab_info, tab_data = self.srtab.get()
                self.tab_info = tf.get_variable(
                    't_tab_info',
                    tab_info.shape,
                    dtype=tf.float64,
                    trainable=False,
                    initializer=tf.constant_initializer(tab_info,
                                                        dtype=tf.float64))
                self.tab_data = tf.get_variable(
                    't_tab_data',
                    tab_data.shape,
                    dtype=tf.float64,
                    trainable=False,
                    initializer=tf.constant_initializer(tab_data,
                                                        dtype=tf.float64))

        coord = tf.reshape(coord_, [-1, natoms[1] * 3])
        atype = tf.reshape(atype_, [-1, natoms[1]])

        # type embedding if any
        if self.typeebd is not None:
            type_embedding = self.typeebd.build(
                self.ntypes,
                reuse=reuse,
                suffix=suffix,
            )
            # Handed to the descriptor / fitting net through the shared dict.
            input_dict['type_embedding'] = type_embedding

        if frz_model is None:
            dout = self.descrpt.build(coord_,
                                      atype_,
                                      natoms,
                                      box,
                                      mesh,
                                      input_dict,
                                      suffix=suffix,
                                      reuse=reuse)
            dout = tf.identity(dout, name='o_descriptor')
        else:
            # The descriptor is not built in this branch, so its attribute
            # nodes are created here (presumably mirroring what
            # ``descrpt.build`` would register -- TODO confirm).
            tf.constant(self.rcut,
                        name='descrpt_attr/rcut',
                        dtype=GLOBAL_TF_FLOAT_PRECISION)
            tf.constant(self.ntypes,
                        name='descrpt_attr/ntypes',
                        dtype=tf.int32)
            feed_dict = self.descrpt.get_feed_dict(coord_, atype_, natoms, box,
                                                   mesh)
            # The descriptor output is requested last so it can be split off
            # from the descriptor's own tensors below.
            return_elements = [
                *self.descrpt.get_tensor_names(), 'o_descriptor:0'
            ]
            imported_tensors = self._import_graph_def_from_frz_model(
                frz_model, feed_dict, return_elements)
            dout = imported_tensors[-1]
            self.descrpt.pass_tensors_from_frz_model(*imported_tensors[:-1])

        if self.srtab is not None:
            nlist, rij, sel_a, sel_r = self.descrpt.get_nlist()
            # Last element of the running sum, i.e. the total of each list.
            nnei_a = np.cumsum(sel_a)[-1]
            nnei_r = np.cumsum(sel_r)[-1]

        atom_ener = self.fitting.build(dout,
                                       natoms,
                                       input_dict,
                                       reuse=reuse,
                                       suffix=suffix)

        if self.srtab is not None:
            sw_lambda, sw_deriv = op_module.soft_min_switch(
                atype,
                rij,
                nlist,
                natoms,
                sel_a=sel_a,
                sel_r=sel_r,
                alpha=self.smin_alpha,
                rmin=self.sw_rmin,
                rmax=self.sw_rmax)
            inv_sw_lambda = 1.0 - sw_lambda
            # NOTICE:
            # atom energy is not scaled,
            # force and virial are scaled
            tab_atom_ener, tab_force, tab_atom_virial = op_module.pair_tab(
                self.tab_info,
                self.tab_data,
                atype,
                rij,
                nlist,
                natoms,
                sw_lambda,
                sel_a=sel_a,
                sel_r=sel_r)
            # Tabulated minus fitted atomic energy, taken before either one is
            # scaled; consumed by soft_min_force / soft_min_virial below.
            energy_diff = tab_atom_ener - tf.reshape(atom_ener,
                                                     [-1, natoms[0]])
            # Blend: sw_lambda weights the table, (1 - sw_lambda) the network.
            tab_atom_ener = tf.reshape(sw_lambda, [-1]) * tf.reshape(
                tab_atom_ener, [-1])
            atom_ener = tf.reshape(inv_sw_lambda, [-1]) * atom_ener
            energy_raw = tab_atom_ener + atom_ener
        else:
            energy_raw = atom_ener

        energy_raw = tf.reshape(energy_raw, [-1, natoms[0]],
                                name='o_atom_energy' + suffix)
        energy = tf.reduce_sum(global_cvt_2_ener_float(energy_raw),
                               axis=1,
                               name='o_energy' + suffix)

        # With a table present, ``atom_ener`` here is the network part already
        # scaled by (1 - sw_lambda); the table and switch terms are added next.
        force, virial, atom_virial = self.descrpt.prod_force_virial(
            atom_ener, natoms)

        if self.srtab is not None:
            sw_force = op_module.soft_min_force(energy_diff,
                                                sw_deriv,
                                                nlist,
                                                natoms,
                                                n_a_sel=nnei_a,
                                                n_r_sel=nnei_r)
            force = force + sw_force + tab_force

        force = tf.reshape(force, [-1, 3 * natoms[1]], name="o_force" + suffix)

        if self.srtab is not None:
            sw_virial, sw_atom_virial = op_module.soft_min_virial(
                energy_diff,
                sw_deriv,
                rij,
                nlist,
                natoms,
                n_a_sel=nnei_a,
                n_r_sel=nnei_r)
            atom_virial = atom_virial + sw_atom_virial + tab_atom_virial
            # pair_tab only returns the atomic virial; sum it over the atom
            # axis to obtain its contribution to the total virial.
            virial = virial + sw_virial \
                     + tf.reduce_sum(tf.reshape(tab_atom_virial, [-1, natoms[1], 9]), axis=1)

        virial = tf.reshape(virial, [-1, 9], name="o_virial" + suffix)
        atom_virial = tf.reshape(atom_virial, [-1, 9 * natoms[1]],
                                 name="o_atom_virial" + suffix)

        model_dict = {}
        model_dict['energy'] = energy
        model_dict['force'] = force
        model_dict['virial'] = virial
        model_dict['atom_ener'] = energy_raw
        model_dict['atom_virial'] = atom_virial
        model_dict['coord'] = coord
        model_dict['atype'] = atype

        return model_dict
Esempio n. 2
0
    def comp_ef(self, dcoord, dbox, dtype, tnatoms, name, reuse=None):
        """Build energy, force and virial tensors for the table-blended model.

        The descriptor produced by ``op_module.prod_env_mat_a`` is fed to
        ``self._net``; the resulting atomic energy is mixed with the tabulated
        pair energy from ``op_module.pair_tab`` using the weight returned by
        ``op_module.soft_min_switch``.  Force and virial are the sum of the
        network, switch and table contributions.

        Parameters
        ----------
        dcoord, dbox, dtype, tnatoms
            Inputs forwarded to the custom ops (coordinates, box, atom types
            and atom counts, judging by their position in the op calls).
        name
            Forwarded to ``self._net``.
        reuse
            Forwarded to ``self._net``.

        Returns
        -------
        tuple
            ``(energy, force, virial)``.
        """
        # Keyword groups shared by several op calls.
        sel_kw = {'sel_a': self.sel_a, 'sel_r': self.sel_r}

        # Descriptor, its derivative, rij and the neighbour list.
        mesh = tf.constant(self.default_mesh)
        env_mat, env_mat_deriv, rij, nlist = op_module.prod_env_mat_a(
            dcoord, dtype, tnatoms, dbox, mesh, self.t_avg, self.t_std,
            rcut_a=self.rcut_a,
            rcut_r=self.rcut_r,
            rcut_r_smth=self.rcut_r_smth,
            **sel_kw)

        # Atomic energy from the network.
        net_input = tf.reshape(env_mat, [-1, self.ndescrpt])
        net_ener = self._net(net_input, name, reuse=reuse)

        # Switching weight and the tabulated pair interaction.
        sw_lambda, sw_deriv = op_module.soft_min_switch(
            dtype, rij, nlist, tnatoms,
            alpha=self.smin_alpha,
            rmin=self.sw_rmin,
            rmax=self.sw_rmax,
            **sel_kw)
        one_minus_sw = 1.0 - sw_lambda
        tab_ener, tab_force, tab_atom_virial = op_module.pair_tab(
            self.tab_info, self.tab_data,
            dtype, rij, nlist, tnatoms, sw_lambda,
            **sel_kw)

        # Unscaled table-minus-network energy, used by the switch derivatives.
        energy_diff = tab_ener - tf.reshape(net_ener, [-1, self.natoms[0]])

        # Blend both energies and sum over the atoms of each frame.
        tab_part = tf.reshape(sw_lambda, [-1]) * tf.reshape(tab_ener, [-1])
        net_part = tf.reshape(one_minus_sw, [-1]) * net_ener
        atom_total = tab_part + net_part
        energy = tf.reduce_sum(
            tf.reshape(atom_total, [-1, self.natoms[0]]), axis=1)

        # Derivative of the scaled network energy w.r.t. the descriptor.
        d_net = tf.gradients(net_part, net_input)[0]
        d_net = tf.reshape(d_net, [-1, self.natoms[0] * self.ndescrpt])

        nnei_kw = {'n_a_sel': self.nnei_a, 'n_r_sel': self.nnei_r}

        # Force: network + switch + table.
        force = op_module.prod_force_se_a(
            d_net, env_mat_deriv, nlist, tnatoms, **nnei_kw)
        sw_force = op_module.soft_min_force(
            energy_diff, sw_deriv, nlist, tnatoms, **nnei_kw)
        force = force + sw_force
        force = force + tab_force

        # Virial: network + switch + table.  The per-atom virials returned by
        # the ops are not combined here; only the total virial is returned.
        virial, _net_atom_virial = op_module.prod_virial_se_a(
            d_net, env_mat_deriv, rij, nlist, tnatoms, **nnei_kw)
        sw_virial, _sw_atom_virial = op_module.soft_min_virial(
            energy_diff, sw_deriv, rij, nlist, tnatoms, **nnei_kw)
        virial = virial + sw_virial
        virial = virial + tf.reduce_sum(
            tf.reshape(tab_atom_virial, [-1, self.natoms[1], 9]), axis=1)

        return energy, force, virial