Ejemplo n.º 1
0
    def __init__(self,
                 rcut: float,
                 sel_a: List[int],
                 sel_r: List[int],
                 axis_rule: List[int]
    ) -> None:
        """
        Store the descriptor settings and build a private graph and session
        in which `op_module.descrpt` is wired to placeholders.

        Parameters
        ----------
        rcut
                Cut-off stored as `rcut_r`; `rcut_a` is always set to -1.
        sel_a
                List whose length defines `ntypes`; its total becomes `nnei_a`.
        sel_r
                List of the same length as `sel_a`; its total becomes `nnei_r`.
        axis_rule
                Forwarded unchanged to `op_module.descrpt`.
        """
        self.sel_a = sel_a
        self.sel_r = sel_r
        self.axis_rule = axis_rule
        self.rcut_r = rcut

        # sel_a and sel_r must describe the same number of types
        self.ntypes = len(self.sel_a)
        assert(self.ntypes == len(self.sel_r))
        self.rcut_a = -1

        # neighbor totals: the last element of the running sum is the total
        self.nnei_a = np.cumsum(self.sel_a)[-1]
        self.nnei_r = np.cumsum(self.sel_r)[-1]
        self.nnei = self.nnei_a + self.nnei_r

        # descriptor width: 4 values per `a` neighbor, 1 per `r` neighbor
        self.ndescrpt_a = self.nnei_a * 4
        self.ndescrpt_r = self.nnei_r * 1
        self.ndescrpt = self.ndescrpt_a + self.ndescrpt_r

        # no statistics yet
        self.davg = None
        self.dstd = None

        self.place_holders = {}
        # zero average / unit deviation are fed to the op in the private graph
        stat_shape = [self.ntypes, self.ndescrpt]
        avg_zero = np.zeros(stat_shape).astype(GLOBAL_NP_FLOAT_PRECISION)
        std_ones = np.ones(stat_shape).astype(GLOBAL_NP_FLOAT_PRECISION)

        sub_graph = tf.Graph()
        with sub_graph.as_default():
            name_pfx = 'd_lf_'
            # (dict key, dtype, shape, node-name tag); created in this order
            placeholder_specs = (
                ('coord', GLOBAL_NP_FLOAT_PRECISION, [None, None], 't_coord'),
                ('box', GLOBAL_NP_FLOAT_PRECISION, [None, None], 't_box'),
                ('type', tf.int32, [None, None], 't_type'),
                ('natoms_vec', tf.int32, [self.ntypes+2], 't_natoms'),
                ('default_mesh', tf.int32, [None], 't_mesh'),
            )
            for key, dtype, shape, tag in placeholder_specs:
                self.place_holders[key] = tf.placeholder(dtype, shape, name = name_pfx + tag)

            inputs = self.place_holders
            t_avg = tf.constant(avg_zero)
            t_std = tf.constant(std_ones)
            # only the descriptor output is kept; the other five are discarded
            self.stat_descrpt, _deriv, _rij, _nlist, _axis, _rot_mat \
                = op_module.descrpt(inputs['coord'],
                                    inputs['type'],
                                    inputs['natoms_vec'],
                                    inputs['box'],
                                    inputs['default_mesh'],
                                    t_avg,
                                    t_std,
                                    rcut_a = self.rcut_a,
                                    rcut_r = self.rcut_r,
                                    sel_a = self.sel_a,
                                    sel_r = self.sel_r,
                                    axis_rule = self.axis_rule)
        self.sub_sess = tf.Session(graph = sub_graph, config=default_tf_session_config)
Ejemplo n.º 2
0
    def build (self,
               coord_,
               atype_,
               natoms,
               box_,
               mesh,
               suffix = '',
               reuse = None):
        """
        Add the descriptor op and its statistics variables to the graph.

        Parameters
        ----------
        coord_
                Reshaped to [-1, natoms[1] * 3] before being passed to the op.
        atype_
                Reshaped to [-1, natoms[1]] before being passed to the op.
        natoms
                Passed to the op as is; element 1 sets the reshape widths.
        box_
                Reshaped to [-1, 9] before being passed to the op.
        mesh
                Passed to the op as is.
        suffix
                Appended to the variable scope name 'descrpt_attr'.
        reuse
                Forwarded to `tf.variable_scope`.

        Returns
        -------
        The descriptor tensor reshaped to [-1, self.ndescrpt]. It is also
        stored as `self.descrpt`, next to `self.descrpt_deriv`, `self.rij`,
        `self.nlist`, `self.axis` and `self.rot_mat`.
        """
        # use zero average / unit deviation when no statistics are stored
        stat_shape = [self.ntypes, self.ndescrpt]
        davg = self.davg
        if davg is None:
            davg = np.zeros(stat_shape)
        dstd = self.dstd
        if dstd is None:
            dstd = np.ones(stat_shape)

        with tf.variable_scope('descrpt_attr' + suffix, reuse = reuse) :
            # These two constants are not referenced again in this method;
            # the calls are kept because they add named nodes to the graph.
            tf.constant(np.max([self.rcut_r, self.rcut_a]),
                        name = 'rcut',
                        dtype = global_tf_float_precision)
            tf.constant(self.ntypes,
                        name = 'ntypes',
                        dtype = tf.int32)
            var_opts = dict(dtype = global_tf_float_precision,
                            trainable = False)
            self.t_avg = tf.get_variable('t_avg',
                                         davg.shape,
                                         initializer = tf.constant_initializer(davg),
                                         **var_opts)
            self.t_std = tf.get_variable('t_std',
                                         dstd.shape,
                                         initializer = tf.constant_initializer(dstd),
                                         **var_opts)

        coord = tf.reshape (coord_, [-1, natoms[1] * 3])
        box   = tf.reshape (box_, [-1, 9])
        atype = tf.reshape (atype_, [-1, natoms[1]])

        op_attrs = dict(rcut_a = self.rcut_a,
                        rcut_r = self.rcut_r,
                        sel_a = self.sel_a,
                        sel_r = self.sel_r,
                        axis_rule = self.axis_rule)
        (self.descrpt,
         self.descrpt_deriv,
         self.rij,
         self.nlist,
         self.axis,
         self.rot_mat) = op_module.descrpt (coord,
                                            atype,
                                            natoms,
                                            box,
                                            mesh,
                                            self.t_avg,
                                            self.t_std,
                                            **op_attrs)
        # one row of ndescrpt values per atom
        self.descrpt = tf.reshape(self.descrpt, [-1, self.ndescrpt])
        return self.descrpt
Ejemplo n.º 3
0
    def comp_ef(self, dcoord, dbox, dtype, tnatoms, name, reuse=None):
        """
        Build the energy, force and virial tensors from the descriptor op
        and the network `self._net`.

        `dcoord`, `dtype`, `tnatoms` and `dbox` are passed to
        `op_module.descrpt` together with a constant built from
        `self.default_mesh`; `name` and `reuse` are forwarded to `self._net`.
        The op's `axis`, `nlist` and descriptor outputs are stored on `self`.

        Returns
        -------
        energy, force, virial
                `energy` is the network output reshaped to
                [-1, self.natoms[0]] and summed over axis 1; `force` and
                `virial` come from `op_module.prod_force` and
                `op_module.prod_virial`.
        """
        mesh = tf.constant(self.default_mesh)
        descrpt_attrs = dict(rcut_a = self.rcut_a,
                             rcut_r = self.rcut_r,
                             sel_a = self.sel_a,
                             sel_r = self.sel_r,
                             axis_rule = self.axis_rule)
        # the sixth output (rot_mat) is not used here
        descrpt, descrpt_deriv, rij, nlist, axis, _rot_mat = op_module.descrpt(
            dcoord, dtype, tnatoms, dbox, mesh, self.t_avg, self.t_std,
            **descrpt_attrs)
        self.axis, self.nlist, self.descrpt = axis, nlist, descrpt

        # network energy per atom, then summed per frame
        net_input = tf.reshape(descrpt, [-1, self.ndescrpt])
        atom_ener = self._net(net_input, name, reuse=reuse)
        energy = tf.reduce_sum(
            tf.reshape(atom_ener, [-1, self.natoms[0]]),
            axis=1)

        # derivative of the atomic energies w.r.t. the descriptor, one row per frame
        net_deriv = tf.gradients(atom_ener, net_input)[0]
        net_deriv_reshape = tf.reshape(
            net_deriv, [-1, self.natoms[0] * self.ndescrpt])

        sel_counts = dict(n_a_sel=self.nnei_a, n_r_sel=self.nnei_r)
        force = op_module.prod_force(
            net_deriv_reshape, descrpt_deriv, nlist, axis, tnatoms,
            **sel_counts)
        # the per-atom virial is computed by the op but not returned
        virial, _atom_vir = op_module.prod_virial(
            net_deriv_reshape, descrpt_deriv, rij, nlist, axis, tnatoms,
            **sel_counts)
        return energy, force, virial
Ejemplo n.º 4
0
    def build(self,
              coord_: tf.Tensor,
              atype_: tf.Tensor,
              natoms: tf.Tensor,
              box_: tf.Tensor,
              mesh: tf.Tensor,
              input_dict: dict,
              reuse: bool = None,
              suffix: str = '') -> tf.Tensor:
        """
        Build the computational graph for the descriptor

        Parameters
        ----------
        coord_
                The coordinate of atoms
        atype_
                The type of atoms
        natoms
                The number of atoms. This tensor has the length of Ntypes + 2
                natoms[0]: number of local atoms
                natoms[1]: total number of atoms held by this processor
                natoms[i]: 2 <= i < Ntypes+2, number of type i atoms
        box_
                The box. It is reshaped to [-1, 9] before being used.
        mesh
                For historical reasons, only the length of the Tensor matters.
                if size of mesh == 6, pbc is assumed.
                if size of mesh == 0, no-pbc is assumed.
        input_dict
                Dictionary for additional inputs. Not read by this method.
        reuse
                Whether the variables in the scope should be reused;
                forwarded to `tf.variable_scope`.
        suffix
                Name suffix to identify this descriptor

        Returns
        -------
        descriptor
                The output descriptor, reshaped to [-1, self.ndescrpt]
        """
        # fall back to zero average / unit deviation when no statistics exist
        stat_shape = [self.ntypes, self.ndescrpt]
        davg = self.davg
        davg = np.zeros(stat_shape) if davg is None else davg
        dstd = self.dstd
        dstd = np.ones(stat_shape) if dstd is None else dstd

        with tf.variable_scope('descrpt_attr' + suffix, reuse=reuse):
            # Not referenced again in this method; the calls are kept because
            # they add the named nodes 'rcut' and 'ntypes' to the graph.
            tf.constant(np.max([self.rcut_r, self.rcut_a]),
                        name='rcut',
                        dtype=GLOBAL_TF_FLOAT_PRECISION)
            tf.constant(self.ntypes, name='ntypes', dtype=tf.int32)

            def _stat_variable(var_name, values):
                # non-trainable variable initialised from the given array
                return tf.get_variable(
                    var_name,
                    values.shape,
                    dtype=GLOBAL_TF_FLOAT_PRECISION,
                    trainable=False,
                    initializer=tf.constant_initializer(values))

            self.t_avg = _stat_variable('t_avg', davg)
            self.t_std = _stat_variable('t_std', dstd)

        coord = tf.reshape(coord_, [-1, natoms[1] * 3])
        box = tf.reshape(box_, [-1, 9])
        atype = tf.reshape(atype_, [-1, natoms[1]])

        outputs = op_module.descrpt(coord,
                                    atype,
                                    natoms,
                                    box,
                                    mesh,
                                    self.t_avg,
                                    self.t_std,
                                    rcut_a=self.rcut_a,
                                    rcut_r=self.rcut_r,
                                    sel_a=self.sel_a,
                                    sel_r=self.sel_r,
                                    axis_rule=self.axis_rule)
        (self.descrpt, self.descrpt_deriv, self.rij,
         self.nlist, self.axis, self.rot_mat) = outputs

        self.descrpt = tf.reshape(self.descrpt, [-1, self.ndescrpt])

        # histogram summaries of the descriptor, rij and nlist tensors
        for tag, tensor in (('descrpt', self.descrpt),
                            ('rij', self.rij),
                            ('nlist', self.nlist)):
            tf.summary.histogram(tag, tensor)

        return self.descrpt
Ejemplo n.º 5
0
    def comp_interpl_ef(self, dcoord, dbox, dtype, tnatoms, name, reuse=None):
        """
        Build energy, force and virial tensors for the network energy
        blended with a tabulated pair term through a soft-min switch.

        The per-atom energy is
        ``sw_lambda * tab_atom_ener + (1 - sw_lambda) * atom_ener``, where
        `atom_ener` comes from `self._net`, `tab_atom_ener` from
        `op_module.pair_tab` and `sw_lambda` from
        `op_module.soft_min_switch`.

        `dcoord`, `dtype`, `tnatoms` and `dbox` are passed to
        `op_module.descrpt` together with a constant built from
        `self.default_mesh`; `name` and `reuse` are forwarded to `self._net`.

        Returns
        -------
        energy, force, virial
                `energy` is the blended per-atom energy reshaped to
                [-1, self.natoms[0]] and summed over axis 1. `force` and
                `virial` are the sums of the network, switch and table
                contributions.
        """
        # Only the first five outputs are used. The op may return further
        # outputs (e.g. a rotation matrix); the starred target absorbs them
        # so the unpacking does not fail on the extra values.
        descrpt, descrpt_deriv, rij, nlist, axis, *_unused \
            = op_module.descrpt (dcoord,
                                 dtype,
                                 tnatoms,
                                 dbox,
                                 tf.constant(self.default_mesh),
                                 self.t_avg,
                                 self.t_std,
                                 rcut_a = self.rcut_a,
                                 rcut_r = self.rcut_r,
                                 sel_a = self.sel_a,
                                 sel_r = self.sel_r,
                                 axis_rule = self.axis_rule)
        inputs_reshape = tf.reshape(descrpt, [-1, self.ndescrpt])
        atom_ener = self._net(inputs_reshape, name, reuse=reuse)

        # switch value and its derivative
        sw_lambda, sw_deriv \
            = op_module.soft_min_switch(dtype,
                                        rij,
                                        nlist,
                                        tnatoms,
                                        sel_a = self.sel_a,
                                        sel_r = self.sel_r,
                                        alpha = self.smin_alpha,
                                        rmin = self.sw_rmin,
                                        rmax = self.sw_rmax)
        inv_sw_lambda = 1.0 - sw_lambda
        # tabulated pair energy, force and per-atom virial
        tab_atom_ener, tab_force, tab_atom_virial \
            = op_module.pair_tab(
                self.tab_info,
                self.tab_data,
                dtype,
                rij,
                nlist,
                tnatoms,
                sw_lambda,
                sel_a = self.sel_a,
                sel_r = self.sel_r)
        # The difference must use the energies before they are scaled by the
        # switch below; it feeds the switch force and virial terms.
        energy_diff = tab_atom_ener - tf.reshape(atom_ener,
                                                 [-1, self.natoms[0]])
        tab_atom_ener = tf.reshape(sw_lambda, [-1]) * tf.reshape(
            tab_atom_ener, [-1])
        atom_ener = tf.reshape(inv_sw_lambda, [-1]) * atom_ener
        energy_raw = tab_atom_ener + atom_ener

        energy_raw = tf.reshape(energy_raw, [-1, self.natoms[0]])
        energy = tf.reduce_sum(energy_raw, axis=1)

        # atom_ener is already scaled by (1 - sw_lambda) at this point, so the
        # gradient below is that of the scaled network energy.
        net_deriv = tf.gradients(atom_ener, inputs_reshape)[0]
        net_deriv_reshape = tf.reshape(net_deriv,
                                       [-1, self.natoms[0] * self.ndescrpt])

        force = op_module.prod_force(net_deriv_reshape,
                                     descrpt_deriv,
                                     nlist,
                                     axis,
                                     tnatoms,
                                     n_a_sel=self.nnei_a,
                                     n_r_sel=self.nnei_r)
        sw_force \
            = op_module.soft_min_force(energy_diff,
                                       sw_deriv,
                                       nlist,
                                       tnatoms,
                                       n_a_sel = self.nnei_a,
                                       n_r_sel = self.nnei_r)
        force = force + sw_force + tab_force

        # The per-atom virials of the network and switch terms are returned
        # by the ops but not combined or returned here.
        virial, _atom_vir = op_module.prod_virial(net_deriv_reshape,
                                                  descrpt_deriv,
                                                  rij,
                                                  nlist,
                                                  axis,
                                                  tnatoms,
                                                  n_a_sel=self.nnei_a,
                                                  n_r_sel=self.nnei_r)
        sw_virial, _sw_atom_virial \
            = op_module.soft_min_virial (energy_diff,
                                         sw_deriv,
                                         rij,
                                         nlist,
                                         tnatoms,
                                         n_a_sel = self.nnei_a,
                                         n_r_sel = self.nnei_r)
        # the table term only provides a per-atom virial, so sum it over atoms
        virial = virial + sw_virial \
                 + tf.reduce_sum(tf.reshape(tab_atom_virial, [-1, self.natoms[1], 9]), axis = 1)

        return energy, force, virial