Example #1
class DescrptSeR(DescrptSe):
    """DeepPot-SE constructed from radial information of atomic configurations.
    
    The embedding takes the distance between atoms as input.

    Parameters
    ----------
    rcut
            The cut-off radius
    rcut_smth
            From where the environment matrix should be smoothed
    sel : list[int]
            sel[i] specifies the maximum number of type i atoms in the cut-off radius
    neuron : list[int]
            Number of neurons in each hidden layer of the embedding net
    resnet_dt
            Time-step `dt` in the resnet construction:
            :math:`y = x + dt * \phi (Wx + b)`
    trainable
            If the weights of embedding net are trainable.
    seed
            Random seed for initializing the network parameters.
    type_one_side
            If True, build N_types embedding nets; otherwise, build N_types^2 embedding nets
    exclude_types : List[List[int]]
            The excluded pairs of types which have no interaction with each other.
            For example, `[[0, 1]]` means no interaction between type 0 and type 1.
    activation_function
            The activation function in the embedding net. Supported options are {0}
    precision
            The precision of the embedding net parameters. Supported options are {1}
    uniform_seed
            Only for backward compatibility; retrieves the old behavior of using the random seed
    """
    @docstring_parameter(list_to_doc(ACTIVATION_FN_DICT.keys()),
                         list_to_doc(PRECISION_DICT.keys()))
    def __init__(self,
                 rcut: float,
                 rcut_smth: float,
                 sel: List[int],
                 neuron: List[int] = [24, 48, 96],
                 resnet_dt: bool = False,
                 trainable: bool = True,
                 seed: int = None,
                 type_one_side: bool = True,
                 exclude_types: List[List[int]] = [],
                 set_davg_zero: bool = False,
                 activation_function: str = 'tanh',
                 precision: str = 'default',
                 uniform_seed: bool = False) -> None:
        """
        Constructor
        """
        # args = ClassArg()\
        #        .add('sel',      list,   must = True) \
        #        .add('rcut',     float,  default = 6.0) \
        #        .add('rcut_smth',float,  default = 0.5) \
        #        .add('neuron',   list,   default = [10, 20, 40]) \
        #        .add('resnet_dt',bool,   default = False) \
        #        .add('trainable',bool,   default = True) \
        #        .add('seed',     int) \
        #        .add('type_one_side', bool, default = False) \
        #        .add('exclude_types', list, default = []) \
        #        .add('set_davg_zero', bool, default = False) \
        #        .add("activation_function", str, default = "tanh") \
        #        .add("precision",           str, default = "default")
        # class_data = args.parse(jdata)
        self.sel_r = sel
        self.rcut = rcut
        self.rcut_smth = rcut_smth
        self.filter_neuron = neuron
        self.filter_resnet_dt = resnet_dt
        self.seed = seed
        self.uniform_seed = uniform_seed
        self.seed_shift = embedding_net_rand_seed_shift(self.filter_neuron)
        self.trainable = trainable
        self.filter_activation_fn = get_activation_func(activation_function)
        self.filter_precision = get_precision(precision)
        self.exclude_types = set()
        for tt in exclude_types:
            assert (len(tt) == 2)
            self.exclude_types.add((tt[0], tt[1]))
            self.exclude_types.add((tt[1], tt[0]))
        self.set_davg_zero = set_davg_zero
        self.type_one_side = type_one_side

        # descrpt config
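        # radial-only descriptor: no angular channels are kept, so sel_a is
        # all zeros and ndescrpt reduces to the radial neighbor count nnei_r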
        self.sel_a = [0 for ii in range(len(self.sel_r))]
        self.ntypes = len(self.sel_r)
        # number of neighbors and number of descriptors
        self.nnei_a = np.cumsum(self.sel_a)[-1]
        self.nnei_r = np.cumsum(self.sel_r)[-1]
        self.nnei = self.nnei_a + self.nnei_r
        self.ndescrpt_a = self.nnei_a * 4
        self.ndescrpt_r = self.nnei_r * 1
        self.ndescrpt = self.nnei_r
        self.useBN = False
        self.davg = None
        self.dstd = None
        self.embedding_net_variables = None

        self.place_holders = {}
        avg_zero = np.zeros([self.ntypes,
                             self.ndescrpt]).astype(GLOBAL_NP_FLOAT_PRECISION)
        std_ones = np.ones([self.ntypes,
                            self.ndescrpt]).astype(GLOBAL_NP_FLOAT_PRECISION)
        sub_graph = tf.Graph()
        with sub_graph.as_default():
            name_pfx = 'd_ser_'
            for ii in ['coord', 'box']:
                self.place_holders[ii] = tf.placeholder(
                    GLOBAL_NP_FLOAT_PRECISION, [None, None],
                    name=name_pfx + 't_' + ii)
            self.place_holders['type'] = tf.placeholder(tf.int32, [None, None],
                                                        name=name_pfx +
                                                        't_type')
            self.place_holders['natoms_vec'] = tf.placeholder(
                tf.int32, [self.ntypes + 2], name=name_pfx + 't_natoms')
            self.place_holders['default_mesh'] = tf.placeholder(
                tf.int32, [None], name=name_pfx + 't_mesh')
            self.stat_descrpt, descrpt_deriv, rij, nlist \
                = op_module.prod_env_mat_r(self.place_holders['coord'],
                                         self.place_holders['type'],
                                         self.place_holders['natoms_vec'],
                                         self.place_holders['box'],
                                         self.place_holders['default_mesh'],
                                         tf.constant(avg_zero),
                                         tf.constant(std_ones),
                                         rcut = self.rcut,
                                         rcut_smth = self.rcut_smth,
                                         sel = self.sel_r)
            self.sub_sess = tf.Session(graph=sub_graph,
                                       config=default_tf_session_config)

    def get_rcut(self):
        """
        Returns the cut-off radius
        """
        return self.rcut

    def get_ntypes(self):
        """
        Returns the number of atom types
        """
        return self.ntypes

    def get_dim_out(self):
        """
        Returns the output dimension of this descriptor
        """
        return self.filter_neuron[-1]

    def get_nlist(self):
        """
        Returns
        -------
        nlist
                Neighbor list
        rij
                The relative distance between the neighbor and the center atom.
        sel_a
                The number of neighbors with full information
        sel_r
                The number of neighbors with only radial information
        """
        return self.nlist, self.rij, self.sel_a, self.sel_r

    def compute_input_stats(self, data_coord, data_box, data_atype, natoms_vec,
                            mesh, input_dict):
        """
        Compute the statistics (avg and std) of the training data. The input will be normalized by the statistics.
        
        Parameters
        ----------
        data_coord
                The coordinates. Can be generated by deepmd.model.make_stat_input
        data_box
                The box. Can be generated by deepmd.model.make_stat_input
        data_atype
                The atom types. Can be generated by deepmd.model.make_stat_input
        natoms_vec
                The vector for the number of atoms of the system and different types of atoms. Can be generated by deepmd.model.make_stat_input
        mesh
                The mesh for neighbor searching. Can be generated by deepmd.model.make_stat_input
        input_dict
                Dictionary for additional input
        """
        all_davg = []
        all_dstd = []
        sumr = []
        sumn = []
        sumr2 = []
        for cc, bb, tt, nn, mm in zip(data_coord, data_box, data_atype,
                                      natoms_vec, mesh):
            sysr,sysr2,sysn \
                = self._compute_dstats_sys_se_r(cc,bb,tt,nn,mm)
            sumr.append(sysr)
            sumn.append(sysn)
            sumr2.append(sysr2)
        sumr = np.sum(sumr, axis=0)
        sumn = np.sum(sumn, axis=0)
        sumr2 = np.sum(sumr2, axis=0)
        for type_i in range(self.ntypes):
            davgunit = [sumr[type_i] / sumn[type_i]]
            dstdunit = [
                self._compute_std(sumr2[type_i], sumr[type_i], sumn[type_i])
            ]
            davg = np.tile(davgunit, self.ndescrpt)
            dstd = np.tile(dstdunit, self.ndescrpt)
            all_davg.append(davg)
            all_dstd.append(dstd)

        if not self.set_davg_zero:
            self.davg = np.array(all_davg)
        self.dstd = np.array(all_dstd)

    def build(self,
              coord_: tf.Tensor,
              atype_: tf.Tensor,
              natoms: tf.Tensor,
              box_: tf.Tensor,
              mesh: tf.Tensor,
              input_dict: dict,
              reuse: bool = None,
              suffix: str = '') -> tf.Tensor:
        """
        Build the computational graph for the descriptor

        Parameters
        ----------
        coord_
                The coordinate of atoms
        atype_
                The type of atoms
        natoms
                The number of atoms. This tensor has the length of Ntypes + 2
                natoms[0]: number of local atoms
                natoms[1]: total number of atoms held by this processor
                natoms[i]: 2 <= i < Ntypes+2, number of type i atoms
        mesh
                For historical reasons, only the length of the Tensor matters.
                If size of mesh == 6, pbc is assumed.
                If size of mesh == 0, no-pbc is assumed.
        input_dict
                Dictionary for additional inputs
        reuse
                Whether the weights in the networks should be reused when getting the variables.
        suffix
                Name suffix to identify this descriptor

        Returns
        -------
        descriptor
                The output descriptor
        """
        davg = self.davg
        dstd = self.dstd
        with tf.variable_scope('descrpt_attr' + suffix, reuse=reuse):
            if davg is None:
                davg = np.zeros([self.ntypes, self.ndescrpt])
            if dstd is None:
                dstd = np.ones([self.ntypes, self.ndescrpt])
            t_rcut = tf.constant(self.rcut,
                                 name='rcut',
                                 dtype=GLOBAL_TF_FLOAT_PRECISION)
            t_ntypes = tf.constant(self.ntypes, name='ntypes', dtype=tf.int32)
            t_ndescrpt = tf.constant(self.ndescrpt,
                                     name='ndescrpt',
                                     dtype=tf.int32)
            t_sel = tf.constant(self.sel_a, name='sel', dtype=tf.int32)
            self.t_avg = tf.get_variable(
                't_avg',
                davg.shape,
                dtype=GLOBAL_TF_FLOAT_PRECISION,
                trainable=False,
                initializer=tf.constant_initializer(davg))
            self.t_std = tf.get_variable(
                't_std',
                dstd.shape,
                dtype=GLOBAL_TF_FLOAT_PRECISION,
                trainable=False,
                initializer=tf.constant_initializer(dstd))

        coord = tf.reshape(coord_, [-1, natoms[1] * 3])
        box = tf.reshape(box_, [-1, 9])
        atype = tf.reshape(atype_, [-1, natoms[1]])

        self.descrpt, self.descrpt_deriv, self.rij, self.nlist \
            = op_module.prod_env_mat_r(coord,
                                      atype,
                                      natoms,
                                      box,
                                      mesh,
                                      self.t_avg,
                                      self.t_std,
                                      rcut = self.rcut,
                                      rcut_smth = self.rcut_smth,
                                      sel = self.sel_r)

        self.descrpt_reshape = tf.reshape(self.descrpt, [-1, self.ndescrpt])
        self._identity_tensors(suffix=suffix)

        # only used when tensorboard is set to true
        tf.summary.histogram('descrpt', self.descrpt)
        tf.summary.histogram('rij', self.rij)
        tf.summary.histogram('nlist', self.nlist)

        self.dout = self._pass_filter(self.descrpt_reshape,
                                      natoms,
                                      suffix=suffix,
                                      reuse=reuse,
                                      trainable=self.trainable)
        tf.summary.histogram('embedding_net_output', self.dout)

        return self.dout

    def prod_force_virial(
            self, atom_ener: tf.Tensor,
            natoms: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        """
        Compute force and virial

        Parameters
        ----------
        atom_ener
                The atomic energy
        natoms
                The number of atoms. This tensor has the length of Ntypes + 2
                natoms[0]: number of local atoms
                natoms[1]: total number of atoms held by this processor
                natoms[i]: 2 <= i < Ntypes+2, number of type i atoms

        Returns
        -------
        force
                The force on atoms
        virial
                The total virial
        atom_virial
                The atomic virial
        """
        [net_deriv] = tf.gradients(atom_ener, self.descrpt_reshape)
        tf.summary.histogram('net_derivative', net_deriv)
        net_deriv_reshape = tf.reshape(net_deriv,
                                       [-1, natoms[0] * self.ndescrpt])
        force \
            = op_module.prod_force_se_r (net_deriv_reshape,
                                         self.descrpt_deriv,
                                         self.nlist,
                                         natoms)
        virial, atom_virial \
            = op_module.prod_virial_se_r (net_deriv_reshape,
                                          self.descrpt_deriv,
                                          self.rij,
                                          self.nlist,
                                          natoms)
        tf.summary.histogram('force', force)
        tf.summary.histogram('virial', virial)
        tf.summary.histogram('atom_virial', atom_virial)

        return force, virial, atom_virial

    def _pass_filter(self,
                     inputs,
                     natoms,
                     reuse=None,
                     suffix='',
                     trainable=True):
        start_index = 0
        inputs = tf.reshape(inputs, [-1, self.ndescrpt * natoms[0]])
        output = []
        if not self.type_one_side:
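            # one embedding pass per center-atom type here; _filter_r further
            # splits the neighbors by type, giving N_types^2 embedding nets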
            for type_i in range(self.ntypes):
                inputs_i = tf.slice(inputs, [0, start_index * self.ndescrpt],
                                    [-1, natoms[2 + type_i] * self.ndescrpt])
                inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt])
                layer = self._filter_r(
                    tf.cast(inputs_i, self.filter_precision),
                    type_i,
                    name='filter_type_' + str(type_i) + suffix,
                    natoms=natoms,
                    reuse=reuse,
                    trainable=trainable,
                    activation_fn=self.filter_activation_fn)
                layer = tf.reshape(layer, [
                    tf.shape(inputs)[0],
                    natoms[2 + type_i] * self.get_dim_out()
                ])
                output.append(layer)
                start_index += natoms[2 + type_i]
        else:
            inputs_i = inputs
            inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt])
            type_i = -1
            layer = self._filter_r(tf.cast(inputs_i, self.filter_precision),
                                   type_i,
                                   name='filter_type_all' + suffix,
                                   natoms=natoms,
                                   reuse=reuse,
                                   trainable=trainable,
                                   activation_fn=self.filter_activation_fn)
            layer = tf.reshape(
                layer, [tf.shape(inputs)[0], natoms[0] * self.get_dim_out()])
            output.append(layer)
        output = tf.concat(output, axis=1)
        return output

    def _compute_dstats_sys_se_r(self, data_coord, data_box, data_atype,
                                 natoms_vec, mesh):
        dd_all \
            = run_sess(self.sub_sess, self.stat_descrpt,
                                feed_dict = {
                                    self.place_holders['coord']: data_coord,
                                    self.place_holders['type']: data_atype,
                                    self.place_holders['natoms_vec']: natoms_vec,
                                    self.place_holders['box']: data_box,
                                    self.place_holders['default_mesh']: mesh,
                                })
        natoms = natoms_vec
        dd_all = np.reshape(dd_all, [-1, self.ndescrpt * natoms[0]])
        start_index = 0
        sysr = []
        sysn = []
        sysr2 = []
        for type_i in range(self.ntypes):
            end_index = start_index + self.ndescrpt * natoms[2 + type_i]
            dd = dd_all[:, start_index:end_index]
            dd = np.reshape(dd, [-1, self.ndescrpt])
            start_index = end_index
            # compute
            dd = np.reshape(dd, [-1, 1])
            ddr = dd[:, :1]
            sumr = np.sum(ddr)
            sumn = dd.shape[0]
            sumr2 = np.sum(np.multiply(ddr, ddr))
            sysr.append(sumr)
            sysn.append(sumn)
            sysr2.append(sumr2)
        return sysr, sysr2, sysn

    def _compute_std(self, sumv2, sumv, sumn):
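        # population standard deviation sqrt(E[v^2] - E[v]^2), floored at
        # 1e-2 so that normalizing by 1/std never blows up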
        val = np.sqrt(sumv2 / sumn - np.multiply(sumv / sumn, sumv / sumn))
        if np.abs(val) < 1e-2:
            val = 1e-2
        return val

    def _filter_r(self,
                  inputs,
                  type_input,
                  natoms,
                  activation_fn=tf.nn.tanh,
                  stddev=1.0,
                  bavg=0.0,
                  name='linear',
                  reuse=None,
                  trainable=True):
        # natom x nei
        outputs_size = [1] + self.filter_neuron
        with tf.variable_scope(name, reuse=reuse):
            start_index = 0
            xyz_scatter_total = []
            for type_i in range(self.ntypes):
                # cut-out inputs
                # with natom x nei_type_i
                inputs_i = tf.slice(inputs, [0, start_index],
                                    [-1, self.sel_r[type_i]])
                start_index += self.sel_r[type_i]
                shape_i = inputs_i.get_shape().as_list()
                # with (natom x nei_type_i) x 1
                xyz_scatter = tf.reshape(inputs_i, [-1, 1])
                if (type_input, type_i) not in self.exclude_types:
                    xyz_scatter = embedding_net(
                        xyz_scatter,
                        self.filter_neuron,
                        self.filter_precision,
                        activation_fn=activation_fn,
                        resnet_dt=self.filter_resnet_dt,
                        name_suffix="_" + str(type_i),
                        stddev=stddev,
                        bavg=bavg,
                        seed=self.seed,
                        trainable=trainable,
                        uniform_seed=self.uniform_seed,
                        initial_variables=self.embedding_net_variables,
                    )
                    if (not self.uniform_seed) and (self.seed is not None):
                        self.seed += self.seed_shift
                    # natom x nei_type_i x out_size
                    xyz_scatter = tf.reshape(
                        xyz_scatter, (-1, shape_i[1], outputs_size[-1]))
                else:
                    natom = tf.shape(inputs)[0]
                    xyz_scatter = tf.cast(
                        tf.fill((natom, shape_i[1], outputs_size[-1]), 0.),
                        GLOBAL_TF_FLOAT_PRECISION)
                xyz_scatter_total.append(xyz_scatter)

            # natom x nei x outputs_size
            xyz_scatter = tf.concat(xyz_scatter_total, axis=1)
            # natom x outputs_size
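            # average the embedding over all neighbors, then rescale by 1/5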
            res_rescale = 1. / 5.
            result = tf.reduce_mean(xyz_scatter, axis=1) * res_rescale

        return result
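
A minimal usage sketch for the class above, assuming a two-type system; the import path and the neighbor limits 46/92 are illustrative assumptions, not taken from the code:

from deepmd.descriptor import DescrptSeR  # assumed import path

# hypothetical two-type system: at most 46 type-0 and 92 type-1
# neighbors inside the 6.0 cut-off radius
descrpt = DescrptSeR(
    rcut=6.0,
    rcut_smth=0.5,
    sel=[46, 92],
    neuron=[24, 48, 96],
)
# statistics data can be prepared by deepmd.model.make_stat_input; then
#   descrpt.compute_input_stats(coord, box, atype, natoms_vec, mesh, {})
# normalizes the inputs, and
#   dout = descrpt.build(coord_t, atype_t, natoms_t, box_t, mesh_t, {})
# returns the per-atom descriptor tensor of width neuron[-1].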
Example #2
File: ener.py Project: njzjz/deepmd-kit
class EnerFitting(Fitting):
    r"""Fitting the energy of the system. The force and the virial can also be trained.

    The potential energy :math:`E` is a fitting network function of the descriptor :math:`\mathcal{D}`:

    .. math::
        E(\mathcal{D}) = \mathcal{L}^{(n)} \circ \mathcal{L}^{(n-1)}
        \circ \cdots \circ \mathcal{L}^{(1)} \circ \mathcal{L}^{(0)}

    The first :math:`n` hidden layers :math:`\mathcal{L}^{(0)}, \cdots, \mathcal{L}^{(n-1)}` are given by

    .. math::
        \mathbf{y}=\mathcal{L}(\mathbf{x};\mathbf{w},\mathbf{b})=
            \boldsymbol{\phi}(\mathbf{x}^T\mathbf{w}+\mathbf{b})

    where :math:`\mathbf{x} \in \mathbb{R}^{N_1}` is the input vector and :math:`\mathbf{y} \in \mathbb{R}^{N_2}`
    is the output vector. :math:`\mathbf{w} \in \mathbb{R}^{N_1 \times N_2}` and
    :math:`\mathbf{b} \in \mathbb{R}^{N_2}` are weights and biases, respectively,
    both of which are trainable if `trainable[i]` is `True`. :math:`\boldsymbol{\phi}`
    is the activation function.

    The output layer :math:`\mathcal{L}^{(n)}` is given by

    .. math::
        \mathbf{y}=\mathcal{L}^{(n)}(\mathbf{x};\mathbf{w},\mathbf{b})=
            \mathbf{x}^T\mathbf{w}+\mathbf{b}

    where :math:`\mathbf{x} \in \mathbb{R}^{N_{n-1}}` is the input vector and :math:`\mathbf{y} \in \mathbb{R}`
    is the output scalar. :math:`\mathbf{w} \in \mathbb{R}^{N_{n-1}}` and
    :math:`\mathbf{b} \in \mathbb{R}` are weights and bias, respectively,
    both of which are trainable if `trainable[n]` is `True`.

    Parameters
    ----------
    descrpt
            The descriptor :math:`\mathcal{D}`
    neuron
            Number of neurons :math:`N` in each hidden layer of the fitting net
    resnet_dt
            Time-step `dt` in the resnet construction:
            :math:`y = x + dt * \phi (Wx + b)`
    numb_fparam
            Number of frame parameters
    numb_aparam
            Number of atomic parameters
    rcond
            The condition number for the regression of atomic energy.
    tot_ener_zero
            Force the total energy to zero. Useful for the charge fitting.
    trainable
            If the weights of fitting net are trainable. 
            Suppose that we have :math:`N_l` hidden layers in the fitting net, 
            this list is of length :math:`N_l + 1`, specifying if the hidden layers and the output layer are trainable.
    seed
            Random seed for initializing the network parameters.
    atom_ener
            Specifying atomic energy contribution in vacuum. The `set_davg_zero` key in the descriptor should be set.
    activation_function
            The activation function :math:`\boldsymbol{\phi}` in the fitting net. Supported options are {0}
    precision
            The precision of the fitting net parameters. Supported options are {1}
    uniform_seed
            Only for backward compatibility; retrieves the old behavior of using the random seed
    """
    @docstring_parameter(list_to_doc(ACTIVATION_FN_DICT.keys()),
                         list_to_doc(PRECISION_DICT.keys()))
    def __init__(self,
                 descrpt: tf.Tensor,
                 neuron: List[int] = [120, 120, 120],
                 resnet_dt: bool = True,
                 numb_fparam: int = 0,
                 numb_aparam: int = 0,
                 rcond: float = 1e-3,
                 tot_ener_zero: bool = False,
                 trainable: List[bool] = None,
                 seed: int = None,
                 atom_ener: List[float] = [],
                 activation_function: str = 'tanh',
                 precision: str = 'default',
                 uniform_seed: bool = False) -> None:
        """
        Constructor
        """
        # model param
        self.ntypes = descrpt.get_ntypes()
        self.dim_descrpt = descrpt.get_dim_out()
        # args = ()\
        #        .add('numb_fparam',      int,    default = 0)\
        #        .add('numb_aparam',      int,    default = 0)\
        #        .add('neuron',           list,   default = [120,120,120], alias = 'n_neuron')\
        #        .add('resnet_dt',        bool,   default = True)\
        #        .add('rcond',            float,  default = 1e-3) \
        #        .add('tot_ener_zero',    bool,   default = False) \
        #        .add('seed',             int)               \
        #        .add('atom_ener',        list,   default = [])\
        #        .add("activation_function", str,    default = "tanh")\
        #        .add("precision",           str, default = "default")\
        #        .add("trainable",        [list, bool], default = True)
        self.numb_fparam = numb_fparam
        self.numb_aparam = numb_aparam
        self.n_neuron = neuron
        self.resnet_dt = resnet_dt
        self.rcond = rcond
        self.seed = seed
        self.uniform_seed = uniform_seed
        self.seed_shift = one_layer_rand_seed_shift()
        self.tot_ener_zero = tot_ener_zero
        self.fitting_activation_fn = get_activation_func(activation_function)
        self.fitting_precision = get_precision(precision)
        self.trainable = trainable
        if self.trainable is None:
            self.trainable = [True for ii in range(len(self.n_neuron) + 1)]
        if type(self.trainable) is bool:
            self.trainable = [self.trainable] * (len(self.n_neuron) + 1)
        assert (len(self.trainable) == len(self.n_neuron) +
                1), 'length of trainable should be that of n_neuron + 1'
        self.atom_ener = []
        self.atom_ener_v = atom_ener
        for at, ae in enumerate(atom_ener):
            if ae is not None:
                self.atom_ener.append(
                    tf.constant(ae,
                                self.fitting_precision,
                                name="atom_%d_ener" % at))
            else:
                self.atom_ener.append(None)
        self.useBN = False
        self.bias_atom_e = np.zeros(self.ntypes, dtype=np.float64)
        # data requirement
        if self.numb_fparam > 0:
            add_data_requirement('fparam',
                                 self.numb_fparam,
                                 atomic=False,
                                 must=True,
                                 high_prec=False)
            self.fparam_avg = None
            self.fparam_std = None
            self.fparam_inv_std = None
        if self.numb_aparam > 0:
            add_data_requirement('aparam',
                                 self.numb_aparam,
                                 atomic=True,
                                 must=True,
                                 high_prec=False)
            self.aparam_avg = None
            self.aparam_std = None
            self.aparam_inv_std = None

        self.fitting_net_variables = None
        self.mixed_prec = None

    def get_numb_fparam(self) -> int:
        """
        Get the number of frame parameters
        """
        return self.numb_fparam

    def get_numb_aparam(self) -> int:
        """
        Get the number of atomic parameters
        """
        return self.numb_aparam

    def compute_output_stats(self, all_stat: dict) -> None:
        """
        Compute the output statistics

        Parameters
        ----------
        all_stat
                must have the following components:
                all_stat['energy'] of shape n_sys x n_batch x n_frame
                can be prepared by model.make_stat_input
        """
        self.bias_atom_e = self._compute_output_stats(all_stat,
                                                      rcond=self.rcond)

    def _compute_output_stats(self, all_stat, rcond=1e-3):
        data = all_stat['energy']
        # data[sys_idx][batch_idx][frame_idx]
        sys_ener = np.array([])
        for ss in range(len(data)):
            sys_data = []
            for ii in range(len(data[ss])):
                for jj in range(len(data[ss][ii])):
                    sys_data.append(data[ss][ii][jj])
            sys_data = np.concatenate(sys_data)
            sys_ener = np.append(sys_ener, np.average(sys_data))
        data = all_stat['natoms_vec']
        sys_tynatom = np.array([])
        nsys = len(data)
        for ss in range(len(data)):
            sys_tynatom = np.append(sys_tynatom,
                                    data[ss][0].astype(np.float64))
        sys_tynatom = np.reshape(sys_tynatom, [nsys, -1])
        sys_tynatom = sys_tynatom[:, 2:]
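        # keep only the per-type atom counts; columns 0 and 1 of natoms_vec
        # are the local and total atom numbers (natoms[0] and natoms[1])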
        if len(self.atom_ener) > 0:
            # Atomic energies stats are incorrect if atomic energies are assigned.
            # In this situation, we directly use these assigned energies instead of computing stats.
            # This will make the loss decrease quickly
            assigned_atom_ener = np.array(
                list((ee for ee in self.atom_ener_v if ee is not None)))
            assigned_ener_idx = list((ii
                                      for ii, ee in enumerate(self.atom_ener_v)
                                      if ee is not None))
            # np.dot out size: nframe
            sys_ener -= np.dot(sys_tynatom[:, assigned_ener_idx],
                               assigned_atom_ener)
            sys_tynatom[:, assigned_ener_idx] = 0.
        energy_shift,resd,rank,s_value \
            = np.linalg.lstsq(sys_tynatom, sys_ener, rcond = rcond)
        if len(self.atom_ener) > 0:
            for ii in assigned_ener_idx:
                energy_shift[ii] = self.atom_ener_v[ii]
        return energy_shift

    def compute_input_stats(self,
                            all_stat: dict,
                            protection: float = 1e-2) -> None:
        """
        Compute the input statistics

        Parameters
        ----------
        all_stat
                if numb_fparam > 0 must have all_stat['fparam']
                if numb_aparam > 0 must have all_stat['aparam']
                can be prepared by model.make_stat_input
        protection
                Division-by-zero protection
        """
        # stat fparam
        if self.numb_fparam > 0:
            cat_data = np.concatenate(all_stat['fparam'], axis=0)
            cat_data = np.reshape(cat_data, [-1, self.numb_fparam])
            self.fparam_avg = np.average(cat_data, axis=0)
            self.fparam_std = np.std(cat_data, axis=0)
            for ii in range(self.fparam_std.size):
                if self.fparam_std[ii] < protection:
                    self.fparam_std[ii] = protection
            self.fparam_inv_std = 1. / self.fparam_std
        # stat aparam
        if self.numb_aparam > 0:
            sys_sumv = []
            sys_sumv2 = []
            sys_sumn = []
            for ss_ in all_stat['aparam']:
                ss = np.reshape(ss_, [-1, self.numb_aparam])
                sys_sumv.append(np.sum(ss, axis=0))
                sys_sumv2.append(np.sum(np.multiply(ss, ss), axis=0))
                sys_sumn.append(ss.shape[0])
            sumv = np.sum(sys_sumv, axis=0)
            sumv2 = np.sum(sys_sumv2, axis=0)
            sumn = np.sum(sys_sumn)
            self.aparam_avg = (sumv) / sumn
            self.aparam_std = self._compute_std(sumv2, sumv, sumn)
            for ii in range(self.aparam_std.size):
                if self.aparam_std[ii] < protection:
                    self.aparam_std[ii] = protection
            self.aparam_inv_std = 1. / self.aparam_std

    def _compute_std(self, sumv2, sumv, sumn):
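        # population standard deviation sqrt(E[v^2] - E[v]^2); callers floor
        # the result with `protection` before inverting it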
        return np.sqrt(sumv2 / sumn - np.multiply(sumv / sumn, sumv / sumn))

    def _build_lower(self,
                     start_index,
                     natoms,
                     inputs,
                     fparam=None,
                     aparam=None,
                     bias_atom_e=0.0,
                     suffix='',
                     reuse=None):
        # cut-out inputs
        inputs_i = tf.slice(inputs, [0, start_index * self.dim_descrpt],
                            [-1, natoms * self.dim_descrpt])
        inputs_i = tf.reshape(inputs_i, [-1, self.dim_descrpt])
        layer = inputs_i
        if fparam is not None:
            ext_fparam = tf.tile(fparam, [1, natoms])
            ext_fparam = tf.reshape(ext_fparam, [-1, self.numb_fparam])
            ext_fparam = tf.cast(ext_fparam, self.fitting_precision)
            layer = tf.concat([layer, ext_fparam], axis=1)
        if aparam is not None:
            ext_aparam = tf.slice(aparam, [0, start_index * self.numb_aparam],
                                  [-1, natoms * self.numb_aparam])
            ext_aparam = tf.reshape(ext_aparam, [-1, self.numb_aparam])
            ext_aparam = tf.cast(ext_aparam, self.fitting_precision)
            layer = tf.concat([layer, ext_aparam], axis=1)

        for ii in range(0, len(self.n_neuron)):
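            # use a residual (skip) connection whenever two consecutive
            # hidden layers have the same width; otherwise a plain layer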
            if ii >= 1 and self.n_neuron[ii] == self.n_neuron[ii - 1]:
                layer += one_layer(
                    layer,
                    self.n_neuron[ii],
                    name='layer_' + str(ii) + suffix,
                    reuse=reuse,
                    seed=self.seed,
                    use_timestep=self.resnet_dt,
                    activation_fn=self.fitting_activation_fn,
                    precision=self.fitting_precision,
                    trainable=self.trainable[ii],
                    uniform_seed=self.uniform_seed,
                    initial_variables=self.fitting_net_variables,
                    mixed_prec=self.mixed_prec)
            else:
                layer = one_layer(layer,
                                  self.n_neuron[ii],
                                  name='layer_' + str(ii) + suffix,
                                  reuse=reuse,
                                  seed=self.seed,
                                  activation_fn=self.fitting_activation_fn,
                                  precision=self.fitting_precision,
                                  trainable=self.trainable[ii],
                                  uniform_seed=self.uniform_seed,
                                  initial_variables=self.fitting_net_variables,
                                  mixed_prec=self.mixed_prec)
            if (not self.uniform_seed) and (self.seed is not None):
                self.seed += self.seed_shift
        final_layer = one_layer(layer,
                                1,
                                activation_fn=None,
                                bavg=bias_atom_e,
                                name='final_layer' + suffix,
                                reuse=reuse,
                                seed=self.seed,
                                precision=self.fitting_precision,
                                trainable=self.trainable[-1],
                                uniform_seed=self.uniform_seed,
                                initial_variables=self.fitting_net_variables,
                                mixed_prec=self.mixed_prec,
                                final_layer=True)
        if (not self.uniform_seed) and (self.seed is not None):
            self.seed += self.seed_shift

        return final_layer

    @cast_precision
    def build(
        self,
        inputs: tf.Tensor,
        natoms: tf.Tensor,
        input_dict: dict = None,
        reuse: bool = None,
        suffix: str = '',
    ) -> tf.Tensor:
        """
        Build the computational graph for fitting net

        Parameters
        ----------
        inputs
                The input descriptor
        input_dict
                Additional dict for inputs. 
                if numb_fparam > 0, should have input_dict['fparam']
                if numb_aparam > 0, should have input_dict['aparam']
        natoms
                The number of atoms. This tensor has the length of Ntypes + 2
                natoms[0]: number of local atoms
                natoms[1]: total number of atoms held by this processor
                natoms[i]: 2 <= i < Ntypes+2, number of type i atoms
        reuse
                Whether the weights in the networks should be reused when getting the variables.
        suffix
                Name suffix to identify this descriptor

        Returns
        -------
        ener
                The system energy
        """
        if input_dict is None:
            input_dict = {}
        bias_atom_e = self.bias_atom_e
        if self.numb_fparam > 0 and (self.fparam_avg is None
                                     or self.fparam_inv_std is None):
            raise RuntimeError(
                'No data stat result. One should do data statistics before build'
            )
        if self.numb_aparam > 0 and (self.aparam_avg is None
                                     or self.aparam_inv_std is None):
            raise RuntimeError(
                'No data stat result. One should do data statistics before build'
            )

        with tf.variable_scope('fitting_attr' + suffix, reuse=reuse):
            t_dfparam = tf.constant(self.numb_fparam,
                                    name='dfparam',
                                    dtype=tf.int32)
            t_daparam = tf.constant(self.numb_aparam,
                                    name='daparam',
                                    dtype=tf.int32)
            if self.numb_fparam > 0:
                t_fparam_avg = tf.get_variable(
                    't_fparam_avg',
                    self.numb_fparam,
                    dtype=GLOBAL_TF_FLOAT_PRECISION,
                    trainable=False,
                    initializer=tf.constant_initializer(self.fparam_avg))
                t_fparam_istd = tf.get_variable(
                    't_fparam_istd',
                    self.numb_fparam,
                    dtype=GLOBAL_TF_FLOAT_PRECISION,
                    trainable=False,
                    initializer=tf.constant_initializer(self.fparam_inv_std))
            if self.numb_aparam > 0:
                t_aparam_avg = tf.get_variable(
                    't_aparam_avg',
                    self.numb_aparam,
                    dtype=GLOBAL_TF_FLOAT_PRECISION,
                    trainable=False,
                    initializer=tf.constant_initializer(self.aparam_avg))
                t_aparam_istd = tf.get_variable(
                    't_aparam_istd',
                    self.numb_aparam,
                    dtype=GLOBAL_TF_FLOAT_PRECISION,
                    trainable=False,
                    initializer=tf.constant_initializer(self.aparam_inv_std))

        inputs = tf.reshape(inputs, [-1, self.dim_descrpt * natoms[0]])
        if len(self.atom_ener):
            # only for atom_ener
            nframes = input_dict.get('nframes')
            if nframes is not None:
                # like inputs, but we don't want to add a dependency on inputs
                inputs_zero = tf.zeros((nframes, self.dim_descrpt * natoms[0]),
                                       dtype=self.fitting_precision)
            else:
                inputs_zero = tf.zeros_like(inputs,
                                            dtype=self.fitting_precision)

        if bias_atom_e is not None:
            assert (len(bias_atom_e) == self.ntypes)

        fparam = None
        aparam = None
        if self.numb_fparam > 0:
            fparam = input_dict['fparam']
            fparam = tf.reshape(fparam, [-1, self.numb_fparam])
            fparam = (fparam - t_fparam_avg) * t_fparam_istd
        if self.numb_aparam > 0:
            aparam = input_dict['aparam']
            aparam = tf.reshape(aparam, [-1, self.numb_aparam])
            aparam = (aparam - t_aparam_avg) * t_aparam_istd
            aparam = tf.reshape(aparam, [-1, self.numb_aparam * natoms[0]])

        type_embedding = input_dict.get('type_embedding', None)
        if type_embedding is not None:
            atype_embed = embed_atom_type(self.ntypes, natoms, type_embedding)
            atype_embed = tf.tile(atype_embed, [tf.shape(inputs)[0], 1])
        else:
            atype_embed = None

        if atype_embed is None:
            start_index = 0
            outs_list = []
            for type_i in range(self.ntypes):
                if bias_atom_e is None:
                    type_bias_ae = 0.0
                else:
                    type_bias_ae = bias_atom_e[type_i]
                final_layer = self._build_lower(start_index,
                                                natoms[2 + type_i],
                                                inputs,
                                                fparam,
                                                aparam,
                                                bias_atom_e=type_bias_ae,
                                                suffix='_type_' + str(type_i) +
                                                suffix,
                                                reuse=reuse)
                # concat the results
                if type_i < len(
                        self.atom_ener) and self.atom_ener[type_i] is not None:
                    zero_layer = self._build_lower(start_index,
                                                   natoms[2 + type_i],
                                                   inputs_zero,
                                                   fparam,
                                                   aparam,
                                                   bias_atom_e=type_bias_ae,
                                                   suffix='_type_' +
                                                   str(type_i) + suffix,
                                                   reuse=True)
                    final_layer += self.atom_ener[type_i] - zero_layer
                final_layer = tf.reshape(
                    final_layer, [tf.shape(inputs)[0], natoms[2 + type_i]])
                outs_list.append(final_layer)
                start_index += natoms[2 + type_i]
            # concat the results
            # concat once may be faster than multiple concat
            outs = tf.concat(outs_list, axis=1)
        # with type embedding
        else:
            if len(self.atom_ener) > 0:
                raise RuntimeError(
                    "setting atom_ener is not supported by type embedding")
            atype_embed = tf.cast(atype_embed, self.fitting_precision)
            type_shape = atype_embed.get_shape().as_list()
            inputs = tf.concat(
                [tf.reshape(inputs, [-1, self.dim_descrpt]), atype_embed],
                axis=1)
            self.dim_descrpt = self.dim_descrpt + type_shape[1]
            inputs = tf.reshape(inputs, [-1, self.dim_descrpt * natoms[0]])
            final_layer = self._build_lower(0,
                                            natoms[0],
                                            inputs,
                                            fparam,
                                            aparam,
                                            bias_atom_e=0.0,
                                            suffix=suffix,
                                            reuse=reuse)
            outs = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms[0]])
            # add atom energy bias; TF will broadcast to all batches
            # tf.repeat is available in TF>=2.1 or TF 1.15
            _TF_VERSION = Version(TF_VERSION)
            if (Version('1.15') <= _TF_VERSION < Version('2') or _TF_VERSION >=
                    Version('2.1')) and self.bias_atom_e is not None:
                outs += tf.repeat(
                    tf.Variable(self.bias_atom_e,
                                dtype=self.fitting_precision,
                                trainable=False,
                                name="bias_atom_ei"), natoms[2:])

        if self.tot_ener_zero:
            force_tot_ener = 0.0
            outs = tf.reshape(outs, [-1, natoms[0]])
            outs_mean = tf.reshape(tf.reduce_mean(outs, axis=1), [-1, 1])
            outs_mean = outs_mean - tf.ones_like(
                outs_mean, dtype=GLOBAL_TF_FLOAT_PRECISION) * (
                    force_tot_ener / global_cvt_2_tf_float(natoms[0]))
            outs = outs - outs_mean
            outs = tf.reshape(outs, [-1])

        tf.summary.histogram('fitting_net_output', outs)
        return tf.reshape(outs, [-1])

    def init_variables(
        self,
        graph: tf.Graph,
        graph_def: tf.GraphDef,
        suffix: str = "",
    ) -> None:
        """
        Init the fitting net variables with the given dict

        Parameters
        ----------
        graph : tf.Graph
            The input frozen model graph
        graph_def : tf.GraphDef
            The input frozen model graph_def
        suffix : str
            suffix to name scope
        """
        self.fitting_net_variables = get_fitting_net_variables_from_graph_def(
            graph_def)

    def enable_compression(self, model_file: str, suffix: str = "") -> None:
        """
        Set the fitting net attributes from the frozen model_file when fparam or aparam is not zero

        Parameters
        ----------
        model_file : str
            The input frozen model file
        suffix : str, optional
                The suffix of the scope
        """
        if self.numb_fparam > 0 or self.numb_aparam > 0:
            graph, _ = load_graph_def(model_file)
        if self.numb_fparam > 0:
            self.fparam_avg = get_tensor_by_name_from_graph(
                graph, 'fitting_attr%s/t_fparam_avg' % suffix)
            self.fparam_inv_std = get_tensor_by_name_from_graph(
                graph, 'fitting_attr%s/t_fparam_istd' % suffix)
        if self.numb_aparam > 0:
            self.aparam_avg = get_tensor_by_name_from_graph(
                graph, 'fitting_attr%s/t_aparam_avg' % suffix)
            self.aparam_inv_std = get_tensor_by_name_from_graph(
                graph, 'fitting_attr%s/t_aparam_istd' % suffix)

    def enable_mixed_precision(self, mixed_prec: dict = None) -> None:
        """
        Receive the mixed precision setting.

        Parameters
        ----------
        mixed_prec
                The mixed precision setting used in the embedding net
        """
        self.mixed_prec = mixed_prec
        self.fitting_precision = get_precision(mixed_prec['output_prec'])
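
A minimal sketch of how the fitting net above sits on top of a descriptor, assuming `descrpt` is a built descriptor instance (e.g. the DescrptSeR from Example #1) and `dout` is its output tensor; the names below are illustrative:

fitting = EnerFitting(
    descrpt,                  # queried for get_ntypes() and get_dim_out()
    neuron=[120, 120, 120],
    resnet_dt=True,
)
# all_stat can be prepared by deepmd.model.make_stat_input; then
#   fitting.compute_output_stats(all_stat)
# fits the per-type atomic energy biases by least squares, and
#   ener = fitting.build(dout, natoms_t, {})
# returns the total energy tensor for each frame.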
Example #3
File: se_a.py Project: njzjz/deepmd-kit
class DescrptSeA(DescrptSe):
    r"""DeepPot-SE constructed from all information (both angular and radial) of
    atomic configurations. The embedding takes the distance between atoms as input.

    The descriptor :math:`\mathcal{D}^i \in \mathcal{R}^{M_1 \times M_2}` is given by [1]_

    .. math::
        \mathcal{D}^i = (\mathcal{G}^i)^T \mathcal{R}^i (\mathcal{R}^i)^T \mathcal{G}^i_<

    where :math:`\mathcal{R}^i \in \mathbb{R}^{N \times 4}` is the coordinate
    matrix, and each row of :math:`\mathcal{R}^i` can be constructed as follows

    .. math::
        (\mathcal{R}^i)_j = [
        \begin{array}{cccc}
            s(r_{ji}) & \frac{s(r_{ji})x_{ji}}{r_{ji}} & \frac{s(r_{ji})y_{ji}}{r_{ji}} & \frac{s(r_{ji})z_{ji}}{r_{ji}}
        \end{array}
        ]

    where :math:`\mathbf{R}_{ji}=\mathbf{R}_j-\mathbf{R}_i = (x_{ji}, y_{ji}, z_{ji})` is
    the relative coordinate and :math:`r_{ji}=\lVert \mathbf{R}_{ji} \rVert` is its norm.
    The switching function :math:`s(r)` is defined as:

    .. math::
        s(r)=
        \begin{cases}
        \frac{1}{r}, & r<r_s \\
        \frac{1}{r} \{ {(\frac{r - r_s}{ r_c - r_s})}^3 (-6 {(\frac{r - r_s}{ r_c - r_s})}^2 +15 \frac{r - r_s}{ r_c - r_s} -10) +1 \}, & r_s \leq r<r_c \\
        0, & r \geq r_c
        \end{cases}

    Each row of the embedding matrix :math:`\mathcal{G}^i \in \mathbb{R}^{N \times M_1}` consists of outputs
    of an embedding network :math:`\mathcal{N}` of :math:`s(r_{ji})`:

    .. math::
        (\mathcal{G}^i)_j = \mathcal{N}(s(r_{ji}))

    :math:`\mathcal{G}^i_< \in \mathbb{R}^{N \times M_2}` takes the first :math:`M_2` columns of
    :math:`\mathcal{G}^i`. The equation of the embedding network :math:`\mathcal{N}` can be found at
    :meth:`deepmd.utils.network.embedding_net`.

    Parameters
    ----------
    rcut
            The cut-off radius :math:`r_c`
    rcut_smth
            From where the environment matrix should be smoothed :math:`r_s`
    sel : list[int]
            sel[i] specifies the maximum number of type i atoms in the cut-off radius
    neuron : list[int]
            Number of neurons in each hidden layer of the embedding net :math:`\mathcal{N}`
    axis_neuron
            Number of the axis neuron :math:`M_2` (number of columns of the sub-matrix of the embedding matrix)
    resnet_dt
            Time-step `dt` in the resnet construction:
            :math:`y = x + dt * \phi (Wx + b)`
    trainable
            If the weights of embedding net are trainable.
    seed
            Random seed for initializing the network parameters.
    type_one_side
            If True, build N_types embedding nets; otherwise, build N_types^2 embedding nets
    exclude_types : List[List[int]]
            The excluded pairs of types which have no interaction with each other.
            For example, `[[0, 1]]` means no interaction between type 0 and type 1.
    set_davg_zero
            Set the shift of embedding net input to zero.
    activation_function
            The activation function in the embedding net. Supported options are {0}
    precision
            The precision of the embedding net parameters. Supported options are {1}
    uniform_seed
            Only for backward compatibility; retrieves the old behavior of using the random seed
    
    References
    ----------
    .. [1] Linfeng Zhang, Jiequn Han, Han Wang, Wissam A. Saidi, Roberto Car, and E. Weinan. 2018.
       End-to-end symmetry preserving inter-atomic potential energy model for finite and extended
       systems. In Proceedings of the 32nd International Conference on Neural Information Processing
       Systems (NIPS'18). Curran Associates Inc., Red Hook, NY, USA, 4441–4451.
    """
    @docstring_parameter(list_to_doc(ACTIVATION_FN_DICT.keys()),
                         list_to_doc(PRECISION_DICT.keys()))
    def __init__(self,
                 rcut: float,
                 rcut_smth: float,
                 sel: List[int],
                 neuron: List[int] = [24, 48, 96],
                 axis_neuron: int = 8,
                 resnet_dt: bool = False,
                 trainable: bool = True,
                 seed: int = None,
                 type_one_side: bool = True,
                 exclude_types: List[List[int]] = [],
                 set_davg_zero: bool = False,
                 activation_function: str = 'tanh',
                 precision: str = 'default',
                 uniform_seed: bool = False) -> None:
        """
        Constructor
        """
        if rcut < rcut_smth:
            raise RuntimeError(
                "rcut_smth (%f) should be no more than rcut (%f)!" %
                (rcut_smth, rcut))
        self.sel_a = sel
        self.rcut_r = rcut
        self.rcut_r_smth = rcut_smth
        self.filter_neuron = neuron
        self.n_axis_neuron = axis_neuron
        self.filter_resnet_dt = resnet_dt
        self.seed = seed
        self.uniform_seed = uniform_seed
        self.seed_shift = embedding_net_rand_seed_shift(self.filter_neuron)
        self.trainable = trainable
        self.compress_activation_fn = get_activation_func(activation_function)
        self.filter_activation_fn = get_activation_func(activation_function)
        self.filter_precision = get_precision(precision)
        self.exclude_types = set()
        for tt in exclude_types:
            assert (len(tt) == 2)
            self.exclude_types.add((tt[0], tt[1]))
            self.exclude_types.add((tt[1], tt[0]))
        self.set_davg_zero = set_davg_zero
        self.type_one_side = type_one_side

        # descrpt config
        self.sel_r = [0 for ii in range(len(self.sel_a))]
        self.ntypes = len(self.sel_a)
        assert (self.ntypes == len(self.sel_r))
        self.rcut_a = -1
        # number of neighbors and number of descriptors
        self.nnei_a = np.cumsum(self.sel_a)[-1]
        self.nnei_r = np.cumsum(self.sel_r)[-1]
        self.nnei = self.nnei_a + self.nnei_r
        self.ndescrpt_a = self.nnei_a * 4
        self.ndescrpt_r = self.nnei_r * 1
        self.ndescrpt = self.ndescrpt_a + self.ndescrpt_r
        self.useBN = False
        self.dstd = None
        self.davg = None
        self.compress = False
        self.embedding_net_variables = None
        self.mixed_prec = None
        self.place_holders = {}
        nei_type = np.array([])
        for ii in range(self.ntypes):
            nei_type = np.append(nei_type,
                                 ii * np.ones(self.sel_a[ii]))  # like a mask
        self.nei_type = tf.constant(nei_type, dtype=tf.int32)

        avg_zero = np.zeros([self.ntypes,
                             self.ndescrpt]).astype(GLOBAL_NP_FLOAT_PRECISION)
        std_ones = np.ones([self.ntypes,
                            self.ndescrpt]).astype(GLOBAL_NP_FLOAT_PRECISION)
        sub_graph = tf.Graph()
        with sub_graph.as_default():
            name_pfx = 'd_sea_'
            for ii in ['coord', 'box']:
                self.place_holders[ii] = tf.placeholder(
                    GLOBAL_NP_FLOAT_PRECISION, [None, None],
                    name=name_pfx + 't_' + ii)
            self.place_holders['type'] = tf.placeholder(tf.int32, [None, None],
                                                        name=name_pfx +
                                                        't_type')
            self.place_holders['natoms_vec'] = tf.placeholder(
                tf.int32, [self.ntypes + 2], name=name_pfx + 't_natoms')
            self.place_holders['default_mesh'] = tf.placeholder(
                tf.int32, [None], name=name_pfx + 't_mesh')
            self.stat_descrpt, descrpt_deriv, rij, nlist \
                = op_module.prod_env_mat_a(self.place_holders['coord'],
                                         self.place_holders['type'],
                                         self.place_holders['natoms_vec'],
                                         self.place_holders['box'],
                                         self.place_holders['default_mesh'],
                                         tf.constant(avg_zero),
                                         tf.constant(std_ones),
                                         rcut_a = self.rcut_a,
                                         rcut_r = self.rcut_r,
                                         rcut_r_smth = self.rcut_r_smth,
                                         sel_a = self.sel_a,
                                         sel_r = self.sel_r)
        self.sub_sess = tf.Session(graph=sub_graph,
                                   config=default_tf_session_config)
        self.original_sel = None

    def get_rcut(self) -> float:
        """
        Returns the cut-off radius
        """
        return self.rcut_r

    def get_ntypes(self) -> int:
        """
        Returns the number of atom types
        """
        return self.ntypes

    def get_dim_out(self) -> int:
        """
        Returns the output dimension of this descriptor
        """
        return self.filter_neuron[-1] * self.n_axis_neuron

    def get_dim_rot_mat_1(self) -> int:
        """
        Returns the first dimension of the rotation matrix. The rotation is of shape dim_1 x 3
        """
        return self.filter_neuron[-1]

    def get_nlist(self) -> Tuple[tf.Tensor, tf.Tensor, List[int], List[int]]:
        """
        Returns
        -------
        nlist
                Neighbor list
        rij
                The relative distance between the neighbor and the center atom.
        sel_a
                The number of neighbors with full information
        sel_r
                The number of neighbors with only radial information
        """
        return self.nlist, self.rij, self.sel_a, self.sel_r

    def compute_input_stats(self, data_coord: list, data_box: list,
                            data_atype: list, natoms_vec: list, mesh: list,
                            input_dict: dict) -> None:
        """
        Compute the statistics (avg and std) of the training data. The input will be normalized by the statistics.
        
        Parameters
        ----------
        data_coord
                The coordinates. Can be generated by deepmd.model.make_stat_input
        data_box
                The box. Can be generated by deepmd.model.make_stat_input
        data_atype
                The atom types. Can be generated by deepmd.model.make_stat_input
        natoms_vec
                The vector for the number of atoms of the system and different types of atoms. Can be generated by deepmd.model.make_stat_input
        mesh
                The mesh for neighbor searching. Can be generated by deepmd.model.make_stat_input
        input_dict
                Dictionary for additional input
        """
        all_davg = []
        all_dstd = []
        sumr = []
        suma = []
        sumn = []
        sumr2 = []
        suma2 = []
        for cc, bb, tt, nn, mm in zip(data_coord, data_box, data_atype,
                                      natoms_vec, mesh):
            sysr, sysr2, sysa, sysa2, sysn \
                = self._compute_dstats_sys_smth(cc, bb, tt, nn, mm)
            sumr.append(sysr)
            suma.append(sysa)
            sumn.append(sysn)
            sumr2.append(sysr2)
            suma2.append(sysa2)
        sumr = np.sum(sumr, axis=0)
        suma = np.sum(suma, axis=0)
        sumn = np.sum(sumn, axis=0)
        sumr2 = np.sum(sumr2, axis=0)
        suma2 = np.sum(suma2, axis=0)
        for type_i in range(self.ntypes):
            davgunit = [sumr[type_i] / (sumn[type_i] + 1e-15), 0, 0, 0]
            dstdunit = [
                self._compute_std(sumr2[type_i], sumr[type_i], sumn[type_i]),
                self._compute_std(suma2[type_i], suma[type_i], sumn[type_i]),
                self._compute_std(suma2[type_i], suma[type_i], sumn[type_i]),
                self._compute_std(suma2[type_i], suma[type_i], sumn[type_i])
            ]
            davg = np.tile(davgunit, self.ndescrpt // 4)
            dstd = np.tile(dstdunit, self.ndescrpt // 4)
            all_davg.append(davg)
            all_dstd.append(dstd)

        if not self.set_davg_zero:
            self.davg = np.array(all_davg)
        self.dstd = np.array(all_dstd)
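
    # A minimal usage sketch (variable names hypothetical), with the lists
    # produced by deepmd.model.make_stat_input:
    #
    #     descrpt.compute_input_stats(coords, boxes, atypes, natoms, meshes, {})
    #     # descrpt.davg / descrpt.dstd now hold [ntypes, ndescrpt] arrays
    #     # used to shift and scale the environment matrix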

    def enable_compression(
        self,
        min_nbor_dist: float,
        model_file: str = 'frozen_model.pb',
        table_extrapolate: float = 5,
        table_stride_1: float = 0.01,
        table_stride_2: float = 0.1,
        check_frequency: int = -1,
        suffix: str = "",
    ) -> None:
        """
        Receive the statistics (distance, max_nbor_size and env_mat_range) of the training data.
        
        Parameters
        ----------
        min_nbor_dist
                The nearest distance between atoms
        model_file
                The original frozen model, which will be compressed by the program
        table_extrapolate
                The scale of model extrapolation
        table_stride_1
                The uniform stride of the first table
        table_stride_2
                The uniform stride of the second table
        check_frequency
                The overflow check frequency
        suffix : str, optional
                The suffix of the scope
        """
        # do some checks before the model compression process
        assert (
            not self.filter_resnet_dt
        ), "Model compression error: descriptor resnet_dt must be false!"
        for tt in self.exclude_types:
            if (tt[0] not in range(self.ntypes)) or (tt[1] not in range(
                    self.ntypes)):
                raise RuntimeError("exclude types" + str(tt) +
                                   " must within the number of atomic types " +
                                   str(self.ntypes) + "!")
        if (self.ntypes * self.ntypes - len(self.exclude_types) == 0):
            raise RuntimeError(
                "empty embedding-net are not supported in model compression!")

        for ii in range(len(self.filter_neuron) - 1):
            if self.filter_neuron[ii] * 2 != self.filter_neuron[ii + 1]:
                raise NotImplementedError(
                    "Model Compression error: descriptor neuron [%s] is not supported by model compression! "
                    "The size of the next layer of the neural network must be twice the size of the previous layer."
                    % ','.join([str(item) for item in self.filter_neuron]))

        self.compress = True
        self.table = DPTabulate(self,
                                self.filter_neuron,
                                model_file,
                                self.type_one_side,
                                self.exclude_types,
                                self.compress_activation_fn,
                                suffix=suffix)
        self.table_config = [
            table_extrapolate, table_stride_1, table_stride_2, check_frequency
        ]
        self.lower, self.upper \
            = self.table.build(min_nbor_dist,
                               table_extrapolate,
                               table_stride_1,
                               table_stride_2)

        graph, _ = load_graph_def(model_file)
        self.davg = get_tensor_by_name_from_graph(
            graph, 'descrpt_attr%s/t_avg' % suffix)
        self.dstd = get_tensor_by_name_from_graph(
            graph, 'descrpt_attr%s/t_std' % suffix)
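
    # Sketch of a typical invocation (argument values hypothetical), normally
    # reached through the `dp compress` command line:
    #
    #     descrpt.enable_compression(min_nbor_dist=0.5,
    #                                model_file='frozen_model.pb')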

    def enable_mixed_precision(self, mixed_prec: dict = None) -> None:
        """
        Receive the mixed precision setting.

        Parameters
        ----------
        mixed_prec
                The mixed precision setting used in the embedding net
        """
        self.mixed_prec = mixed_prec
        self.filter_precision = get_precision(mixed_prec['output_prec'])

    def build(self,
              coord_: tf.Tensor,
              atype_: tf.Tensor,
              natoms: tf.Tensor,
              box_: tf.Tensor,
              mesh: tf.Tensor,
              input_dict: dict,
              reuse: bool = None,
              suffix: str = '') -> tf.Tensor:
        """
        Build the computational graph for the descriptor

        Parameters
        ----------
        coord_
                The coordinate of atoms
        atype_
                The type of atoms
        natoms
                The number of atoms. This tensor has the length of Ntypes + 2
                natoms[0]: number of local atoms
                natoms[1]: total number of atoms held by this processor
                natoms[i]: 2 <= i < Ntypes+2, number of type i atoms
        mesh
                For historical reasons, only the length of the Tensor matters.
                if size of mesh == 6, pbc is assumed. 
                if size of mesh == 0, no-pbc is assumed. 
        input_dict
                Dictionary for additional inputs
        reuse
                The weights in the networks should be reused when getting the variable.
        suffix
                Name suffix to identify this descriptor

        Returns
        -------
        descriptor
                The output descriptor
        """
        davg = self.davg
        dstd = self.dstd
        with tf.variable_scope('descrpt_attr' + suffix, reuse=reuse):
            if davg is None:
                davg = np.zeros([self.ntypes, self.ndescrpt])
            if dstd is None:
                dstd = np.ones([self.ntypes, self.ndescrpt])
            t_rcut = tf.constant(np.max([self.rcut_r, self.rcut_a]),
                                 name='rcut',
                                 dtype=GLOBAL_TF_FLOAT_PRECISION)
            t_ntypes = tf.constant(self.ntypes, name='ntypes', dtype=tf.int32)
            t_ndescrpt = tf.constant(self.ndescrpt,
                                     name='ndescrpt',
                                     dtype=tf.int32)
            t_sel = tf.constant(self.sel_a, name='sel', dtype=tf.int32)
            t_original_sel = tf.constant(self.original_sel if self.original_sel
                                         is not None else self.sel_a,
                                         name='original_sel',
                                         dtype=tf.int32)
            self.t_avg = tf.get_variable(
                't_avg',
                davg.shape,
                dtype=GLOBAL_TF_FLOAT_PRECISION,
                trainable=False,
                initializer=tf.constant_initializer(davg))
            self.t_std = tf.get_variable(
                't_std',
                dstd.shape,
                dtype=GLOBAL_TF_FLOAT_PRECISION,
                trainable=False,
                initializer=tf.constant_initializer(dstd))

        with tf.control_dependencies([t_sel, t_original_sel]):
            coord = tf.reshape(coord_, [-1, natoms[1] * 3])
        box = tf.reshape(box_, [-1, 9])
        atype = tf.reshape(atype_, [-1, natoms[1]])

        self.descrpt, self.descrpt_deriv, self.rij, self.nlist \
            = op_module.prod_env_mat_a (coord,
                                       atype,
                                       natoms,
                                       box,
                                       mesh,
                                       self.t_avg,
                                       self.t_std,
                                       rcut_a = self.rcut_a,
                                       rcut_r = self.rcut_r,
                                       rcut_r_smth = self.rcut_r_smth,
                                       sel_a = self.sel_a,
                                       sel_r = self.sel_r)
        # only used when tensorboard is set to true
        tf.summary.histogram('descrpt', self.descrpt)
        tf.summary.histogram('rij', self.rij)
        tf.summary.histogram('nlist', self.nlist)

        self.descrpt_reshape = tf.reshape(self.descrpt, [-1, self.ndescrpt])
        self._identity_tensors(suffix=suffix)

        self.dout, self.qmat = self._pass_filter(self.descrpt_reshape,
                                                 atype,
                                                 natoms,
                                                 input_dict,
                                                 suffix=suffix,
                                                 reuse=reuse,
                                                 trainable=self.trainable)

        # only used when tensorboard is set to true
        tf.summary.histogram('embedding_net_output', self.dout)
        return self.dout
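
    # A minimal build sketch (placeholder tensors hypothetical): the returned
    # tensor is the per-atom descriptor consumed by the fitting network.
    #
    #     dout = descrpt.build(t_coord, t_type, t_natoms, t_box, t_mesh,
    #                          input_dict={}, reuse=False)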

    def get_rot_mat(self) -> tf.Tensor:
        """
        Get rotational matrix
        """
        return self.qmat

    def prod_force_virial(
            self, atom_ener: tf.Tensor,
            natoms: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        """
        Compute force and virial

        Parameters
        ----------
        atom_ener
                The atomic energy
        natoms
                The number of atoms. This tensor has the length of Ntypes + 2
                natoms[0]: number of local atoms
                natoms[1]: total number of atoms held by this processor
                natoms[i]: 2 <= i < Ntypes+2, number of type i atoms

        Returns
        -------
        force
                The force on atoms
        virial
                The total virial
        atom_virial
                The atomic virial
        """
        [net_deriv] = tf.gradients(atom_ener, self.descrpt_reshape)
        tf.summary.histogram('net_derivative', net_deriv)
        net_deriv_reshape = tf.reshape(net_deriv,
                                       [-1, natoms[0] * self.ndescrpt])
        force \
            = op_module.prod_force_se_a (net_deriv_reshape,
                                          self.descrpt_deriv,
                                          self.nlist,
                                          natoms,
                                          n_a_sel = self.nnei_a,
                                          n_r_sel = self.nnei_r)
        virial, atom_virial \
            = op_module.prod_virial_se_a (net_deriv_reshape,
                                           self.descrpt_deriv,
                                           self.rij,
                                           self.nlist,
                                           natoms,
                                           n_a_sel = self.nnei_a,
                                           n_r_sel = self.nnei_r)
        tf.summary.histogram('force', force)
        tf.summary.histogram('virial', virial)
        tf.summary.histogram('atom_virial', atom_virial)

        return force, virial, atom_virial

    def _pass_filter(self,
                     inputs,
                     atype,
                     natoms,
                     input_dict,
                     reuse=None,
                     suffix='',
                     trainable=True):
        if input_dict is not None:
            type_embedding = input_dict.get('type_embedding', None)
        else:
            type_embedding = None
        start_index = 0
        inputs = tf.reshape(inputs, [-1, self.ndescrpt * natoms[0]])
        output = []
        output_qmat = []
        if not (self.type_one_side
                and len(self.exclude_types) == 0) and type_embedding is None:
            for type_i in range(self.ntypes):
                inputs_i = tf.slice(inputs, [0, start_index * self.ndescrpt],
                                    [-1, natoms[2 + type_i] * self.ndescrpt])
                inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt])
                if self.type_one_side:
                    # reuse NN parameters for all types to support type_one_side along with exclude_types
                    reuse = tf.AUTO_REUSE
                    filter_name = 'filter_type_all' + suffix
                else:
                    filter_name = 'filter_type_' + str(type_i) + suffix
                layer, qmat = self._filter(
                    inputs_i,
                    type_i,
                    name=filter_name,
                    natoms=natoms,
                    reuse=reuse,
                    trainable=trainable,
                    activation_fn=self.filter_activation_fn)
                layer = tf.reshape(layer, [
                    tf.shape(inputs)[0],
                    natoms[2 + type_i] * self.get_dim_out()
                ])
                qmat = tf.reshape(qmat, [
                    tf.shape(inputs)[0],
                    natoms[2 + type_i] * self.get_dim_rot_mat_1() * 3
                ])
                output.append(layer)
                output_qmat.append(qmat)
                start_index += natoms[2 + type_i]
        else:
            inputs_i = inputs
            inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt])
            type_i = -1
            layer, qmat = self._filter(inputs_i,
                                       type_i,
                                       name='filter_type_all' + suffix,
                                       natoms=natoms,
                                       reuse=reuse,
                                       trainable=trainable,
                                       activation_fn=self.filter_activation_fn,
                                       type_embedding=type_embedding)
            layer = tf.reshape(
                layer, [tf.shape(inputs)[0], natoms[0] * self.get_dim_out()])
            qmat = tf.reshape(qmat, [
                tf.shape(inputs)[0], natoms[0] * self.get_dim_rot_mat_1() * 3
            ])
            output.append(layer)
            output_qmat.append(qmat)
        output = tf.concat(output, axis=1)
        output_qmat = tf.concat(output_qmat, axis=1)
        return output, output_qmat

    def _compute_dstats_sys_smth(self, data_coord, data_box, data_atype,
                                 natoms_vec, mesh):
        dd_all \
            = run_sess(self.sub_sess, self.stat_descrpt,
                                feed_dict = {
                                    self.place_holders['coord']: data_coord,
                                    self.place_holders['type']: data_atype,
                                    self.place_holders['natoms_vec']: natoms_vec,
                                    self.place_holders['box']: data_box,
                                    self.place_holders['default_mesh']: mesh,
                                })
        natoms = natoms_vec
        dd_all = np.reshape(dd_all, [-1, self.ndescrpt * natoms[0]])
        start_index = 0
        sysr = []
        sysa = []
        sysn = []
        sysr2 = []
        sysa2 = []
        for type_i in range(self.ntypes):
            end_index = start_index + self.ndescrpt * natoms[2 + type_i]
            dd = dd_all[:, start_index:end_index]
            dd = np.reshape(dd, [-1, self.ndescrpt])
            start_index = end_index
            # compute
            dd = np.reshape(dd, [-1, 4])
            ddr = dd[:, :1]
            dda = dd[:, 1:]
            sumr = np.sum(ddr)
            suma = np.sum(dda) / 3.
            sumn = dd.shape[0]
            sumr2 = np.sum(np.multiply(ddr, ddr))
            suma2 = np.sum(np.multiply(dda, dda)) / 3.
            sysr.append(sumr)
            sysa.append(suma)
            sysn.append(sumn)
            sysr2.append(sumr2)
            sysa2.append(suma2)
        return sysr, sysr2, sysa, sysa2, sysn

    def _compute_std(self, sumv2, sumv, sumn):
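        # std = sqrt(E[v^2] - E[v]^2) from the accumulated sums, floored at
        # 1e-2 so that the 1/std normalization never blows up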
        if sumn == 0:
            return 1e-2
        val = np.sqrt(sumv2 / sumn - np.multiply(sumv / sumn, sumv / sumn))
        if np.abs(val) < 1e-2:
            val = 1e-2
        return val

    def _concat_type_embedding(
        self,
        xyz_scatter,
        nframes,
        natoms,
        type_embedding,
    ):
        '''Concatenate `type_embedding` of neighbors and `xyz_scatter`.
        If not self.type_one_side, concatenate `type_embedding` of center atoms as well.

        Parameters
        ----------
        xyz_scatter:
                shape is [nframes*natoms[0]*self.nnei, 1]
        nframes:
                shape is []
        natoms:
                shape is [1+1+self.ntypes]
        type_embedding:
                shape is [self.ntypes, Y] where Y=jdata['type_embedding']['neuron'][-1]

        Returns
        -------
            embedding:
                environment of each atom represented by embedding.
        '''
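        # e.g. (hypothetical sizes) te_out_dim = 3, type_one_side = True:
        #   concat of xyz_scatter [N, 1] and nei_embed [N, 3] gives
        #   embedding_input [N, 4], with N = nframes * natoms[0] * self.nnei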
        te_out_dim = type_embedding.get_shape().as_list()[-1]
        nei_embed = tf.nn.embedding_lookup(
            type_embedding,
            tf.cast(self.nei_type,
                    dtype=tf.int32))  # shape is [self.nnei, 1+te_out_dim]
        nei_embed = tf.tile(
            nei_embed,
            (nframes * natoms[0],
             1))  # shape is [nframes*natoms[0]*self.nnei, te_out_dim]
        nei_embed = tf.reshape(nei_embed, [-1, te_out_dim])
        embedding_input = tf.concat(
            [xyz_scatter, nei_embed],
            1)  # shape is [nframes*natoms[0]*self.nnei, 1+te_out_dim]
        if not self.type_one_side:
            atm_embed = embed_atom_type(
                self.ntypes, natoms,
                type_embedding)  # shape is [natoms[0], te_out_dim]
            atm_embed = tf.tile(
                atm_embed,
                (nframes, self.nnei
                 ))  # shape is [nframes*natoms[0], self.nnei*te_out_dim]
            atm_embed = tf.reshape(
                atm_embed,
                [-1, te_out_dim
                 ])  # shape is [nframes*natoms[0]*self.nnei, te_out_dim]
            embedding_input = tf.concat(
                [embedding_input, atm_embed], 1
            )  # shape is [nframes*natoms[0]*self.nnei, 1+te_out_dim+te_out_dim]
        return embedding_input

    def _filter_lower(
        self,
        type_i,
        type_input,
        start_index,
        incrs_index,
        inputs,
        nframes,
        natoms,
        type_embedding=None,
        is_exclude=False,
        activation_fn=None,
        bavg=0.0,
        stddev=1.0,
        trainable=True,
        suffix='',
    ):
        """
        input env matrix, returns R.G
        """
        outputs_size = [1] + self.filter_neuron
        # cut-out inputs
        # with natom x (nei_type_i x 4)
        inputs_i = tf.slice(inputs, [0, start_index * 4],
                            [-1, incrs_index * 4])
        shape_i = inputs_i.get_shape().as_list()
        natom = tf.shape(inputs_i)[0]
        # with (natom x nei_type_i) x 4
        inputs_reshape = tf.reshape(inputs_i, [-1, 4])
        # with (natom x nei_type_i) x 1
        xyz_scatter = tf.reshape(tf.slice(inputs_reshape, [0, 0], [-1, 1]),
                                 [-1, 1])
        if type_embedding is not None:
            xyz_scatter = self._concat_type_embedding(xyz_scatter, nframes,
                                                      natoms, type_embedding)
            if self.compress:
                raise RuntimeError(
                    'compression of type embedded descriptor is not supported at the moment'
                )
        # natom x 4 x outputs_size
        if self.compress and (not is_exclude):
            info = [
                self.lower, self.upper, self.upper * self.table_config[0],
                self.table_config[1], self.table_config[2],
                self.table_config[3]
            ]
            if self.type_one_side:
                net = 'filter_-1_net_' + str(type_i)
            else:
                net = 'filter_' + str(type_input) + '_net_' + str(type_i)
            return op_module.tabulate_fusion_se_a(
                tf.cast(self.table.data[net], self.filter_precision),
                info,
                xyz_scatter,
                tf.reshape(inputs_i, [natom, shape_i[1] // 4, 4]),
                last_layer_size=outputs_size[-1])
        else:
            if (not is_exclude):
                # with (natom x nei_type_i) x out_size
                xyz_scatter = embedding_net(
                    xyz_scatter,
                    self.filter_neuron,
                    self.filter_precision,
                    activation_fn=activation_fn,
                    resnet_dt=self.filter_resnet_dt,
                    name_suffix=suffix,
                    stddev=stddev,
                    bavg=bavg,
                    seed=self.seed,
                    trainable=trainable,
                    uniform_seed=self.uniform_seed,
                    initial_variables=self.embedding_net_variables,
                    mixed_prec=self.mixed_prec)
                if (not self.uniform_seed) and (self.seed is not None):
                    self.seed += self.seed_shift
            else:
                # we can safely return the final xyz_scatter filled with zero directly
                return tf.cast(tf.fill((natom, 4, outputs_size[-1]), 0.),
                               self.filter_precision)
            # natom x nei_type_i x out_size
            xyz_scatter = tf.reshape(xyz_scatter,
                                     (-1, shape_i[1] // 4, outputs_size[-1]))
            # When using tf.reshape(inputs_i, [-1, shape_i[1]//4, 4]) below
            # [588 24] -> [588 6 4] correct
            # but if sel is zero
            # [588 0] -> [147 0 4] incorrect; the correct one is [588 0 4]
            # So we need to explicitly assign the shape to tf.shape(inputs_i)[0] instead of -1
            # natom x 4 x outputs_size
            return tf.matmul(tf.reshape(inputs_i, [natom, shape_i[1] // 4, 4]),
                             xyz_scatter,
                             transpose_a=True)

    @cast_precision
    def _filter(self,
                inputs,
                type_input,
                natoms,
                type_embedding=None,
                activation_fn=tf.nn.tanh,
                stddev=1.0,
                bavg=0.0,
                name='linear',
                reuse=None,
                trainable=True):
        nframes = tf.shape(tf.reshape(inputs,
                                      [-1, natoms[0], self.ndescrpt]))[0]
        # natom x (nei x 4)
        shape = inputs.get_shape().as_list()
        outputs_size = [1] + self.filter_neuron
        outputs_size_2 = self.n_axis_neuron
        all_excluded = all([(type_input, type_i) in self.exclude_types
                            for type_i in range(self.ntypes)])
        if all_excluded:
            # all types are excluded so result and qmat should be zeros
            # we can safely return a zero matrix...
            # See also https://stackoverflow.com/a/34725458/9567349
            # result: natom x outputs_size x outputs_size_2
            # qmat: natom x outputs_size x 3
            natom = tf.shape(inputs)[0]
            result = tf.cast(
                tf.fill((natom, outputs_size_2, outputs_size[-1]), 0.),
                GLOBAL_TF_FLOAT_PRECISION)
            qmat = tf.cast(tf.fill((natom, outputs_size[-1], 3), 0.),
                           GLOBAL_TF_FLOAT_PRECISION)
            return result, qmat

        with tf.variable_scope(name, reuse=reuse):
            start_index = 0
            type_i = 0
            # natom x 4 x outputs_size
            if type_embedding is None:
                rets = []
                for type_i in range(self.ntypes):
                    ret = self._filter_lower(type_i,
                                             type_input,
                                             start_index,
                                             self.sel_a[type_i],
                                             inputs,
                                             nframes,
                                             natoms,
                                             type_embedding=type_embedding,
                                             is_exclude=(type_input, type_i)
                                             in self.exclude_types,
                                             activation_fn=activation_fn,
                                             stddev=stddev,
                                             bavg=bavg,
                                             trainable=trainable,
                                             suffix="_" + str(type_i))
                    if (type_input, type_i) not in self.exclude_types:
                        # adding zero is meaningless; skip
                        rets.append(ret)
                    start_index += self.sel_a[type_i]
                # faster to use accumulate_n than multiple add
                xyz_scatter_1 = tf.accumulate_n(rets)
            else:
                xyz_scatter_1 = self._filter_lower(
                    type_i,
                    type_input,
                    start_index,
                    np.cumsum(self.sel_a)[-1],
                    inputs,
                    nframes,
                    natoms,
                    type_embedding=type_embedding,
                    is_exclude=False,
                    activation_fn=activation_fn,
                    stddev=stddev,
                    bavg=bavg,
                    trainable=trainable)
            # natom x nei x outputs_size
            # xyz_scatter = tf.concat(xyz_scatter_total, axis=1)
            # natom x nei x 4
            # inputs_reshape = tf.reshape(inputs, [-1, shape[1]//4, 4])
            # natom x 4 x outputs_size
            # xyz_scatter_1 = tf.matmul(inputs_reshape, xyz_scatter, transpose_a = True)
            if self.original_sel is None:
                # shape[1] = nnei * 4
                nnei = shape[1] / 4
            else:
                nnei = tf.cast(
                    tf.Variable(np.sum(self.original_sel),
                                dtype=tf.int32,
                                trainable=False,
                                name="nnei"), self.filter_precision)
            xyz_scatter_1 = xyz_scatter_1 / nnei
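            # dividing by nnei keeps the averaged embedding on the same scale
            # regardless of the (padded) neighbor count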
            # natom x 4 x outputs_size_2
            xyz_scatter_2 = tf.slice(xyz_scatter_1, [0, 0, 0],
                                     [-1, -1, outputs_size_2])
            # # natom x 3 x outputs_size_2
            # qmat = tf.slice(xyz_scatter_2, [0,1,0], [-1, 3, -1])
            # natom x 3 x outputs_size_1
            qmat = tf.slice(xyz_scatter_1, [0, 1, 0], [-1, 3, -1])
            # natom x outputs_size_1 x 3
            qmat = tf.transpose(qmat, perm=[0, 2, 1])
            # natom x outputs_size x outputs_size_2
            result = tf.matmul(xyz_scatter_1, xyz_scatter_2, transpose_a=True)
            # natom x (outputs_size x outputs_size_2)
            result = tf.reshape(result,
                                [-1, outputs_size_2 * outputs_size[-1]])

        return result, qmat

    def init_variables(
        self,
        graph: tf.Graph,
        graph_def: tf.GraphDef,
        suffix: str = "",
    ) -> None:
        """
        Init the embedding net variables with the given dict

        Parameters
        ----------
        graph : tf.Graph
            The input frozen model graph
        graph_def : tf.GraphDef
            The input frozen model graph_def
        suffix : str, optional
            The suffix of the scope
        """
        super().init_variables(graph=graph, graph_def=graph_def, suffix=suffix)
        try:
            self.original_sel = get_tensor_by_name_from_graph(
                graph, 'descrpt_attr%s/original_sel' % suffix)
        except GraphWithoutTensorError:
            # original_sel is not restored in old graphs, assume sel never changed before
            pass
        # check sel == original sel?
        try:
            sel = get_tensor_by_name_from_graph(graph,
                                                'descrpt_attr%s/sel' % suffix)
        except GraphWithoutTensorError:
            # sel is not restored in old graphs
            pass
        else:
            if not np.array_equal(np.array(self.sel_a), sel):
                if not self.set_davg_zero:
                    raise RuntimeError(
                        "Adjusting sel is only supported when `set_davg_zero` is true!"
                    )
                # since set_davg_zero is set, self.davg is safely zero
                self.davg = np.zeros([self.ntypes, self.ndescrpt
                                      ]).astype(GLOBAL_NP_FLOAT_PRECISION)
                new_dstd = np.ones([self.ntypes, self.ndescrpt
                                    ]).astype(GLOBAL_NP_FLOAT_PRECISION)
                # shape of davg and dstd is (ntypes, ndescrpt), ndescrpt = 4*sel
                n_descpt = np.array(self.sel_a) * 4
                n_descpt_old = np.array(sel) * 4
                end_index = np.cumsum(n_descpt)
                end_index_old = np.cumsum(n_descpt_old)
                start_index = np.roll(end_index, 1)
                start_index[0] = 0
                start_index_old = np.roll(end_index_old, 1)
                start_index_old[0] = 0

                for nn, oo, ii, jj in zip(n_descpt, n_descpt_old, start_index,
                                          start_index_old):
                    if nn < oo:
                        # new size is smaller, copy part of std
                        new_dstd[:, ii:ii + nn] = self.dstd[:, jj:jj + nn]
                    else:
                        # new size is larger, copy all, the rest remains 1
                        new_dstd[:, ii:ii + oo] = self.dstd[:, jj:jj + oo]
                self.dstd = new_dstd
                if self.original_sel is None:
                    self.original_sel = sel
Example #4
class TypeEmbedNet():
    """

    Parameters
    ----------
    neuron : list[int]
            Number of neurons in each hidden layer of the embedding net
    resnet_dt
            Time-step `dt` in the resnet construction:
            y = x + dt * \phi (Wx + b)
    activation_function
            The activation function in the embedding net. Supported options are {0}
    precision
            The precision of the embedding net parameters. Supported options are {1}        
    trainable
            If the weights of embedding net are trainable.
    seed
            Random seed for initializing the network parameters.
    uniform_seed
            Only for the purpose of backward compatibility, retrieves the old behavior of using the random seed
    """
    @docstring_parameter(list_to_doc(ACTIVATION_FN_DICT.keys()),
                         list_to_doc(PRECISION_DICT.keys()))
    def __init__(
        self,
        neuron: List[int] = [],
        resnet_dt: bool = False,
        activation_function: str = 'tanh',
        precision: str = 'default',
        trainable: bool = True,
        seed: int = None,
        uniform_seed: bool = False,
    ) -> None:
        """
        Constructor
        """
        self.neuron = neuron
        self.seed = seed
        self.filter_resnet_dt = resnet_dt
        self.filter_precision = get_precision(precision)
        self.filter_activation_fn = get_activation_func(activation_function)
        self.trainable = trainable
        self.uniform_seed = uniform_seed

    def build(
        self,
        ntypes: int,
        reuse=None,
        suffix='',
    ):
        """
        Build the computational graph for the type embedding net

        Parameters
        ----------
        ntypes
                Number of atom types.
        reuse
                The weights in the networks should be reused when getting the variable.
        suffix
                Name suffix to identify this descriptor

        Returns
        -------
        embedded_types
                The computational graph for embedded types        
        """
        types = tf.convert_to_tensor(list(range(ntypes)), dtype=tf.int32)
        ebd_type = tf.cast(
            tf.one_hot(tf.cast(types, dtype=tf.int32), int(ntypes)),
            self.filter_precision)
        ebd_type = tf.reshape(ebd_type, [-1, ntypes])
        name = 'type_embed_net' + suffix
        with tf.variable_scope(name, reuse=reuse):
            ebd_type = embedding_net(ebd_type,
                                     self.neuron,
                                     activation_fn=self.filter_activation_fn,
                                     precision=self.filter_precision,
                                     resnet_dt=self.filter_resnet_dt,
                                     seed=self.seed,
                                     trainable=self.trainable,
                                     uniform_seed=self.uniform_seed)
        ebd_type = tf.reshape(ebd_type,
                              [-1, self.neuron[-1]])  # nnei * neuron[-1]
        self.ebd_type = tf.identity(ebd_type, name='t_typeebd')
        return self.ebd_type
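
# A minimal usage sketch (sizes hypothetical):
#
#     type_embedding = TypeEmbedNet(neuron=[8])
#     t_typeebd = type_embedding.build(ntypes=3)  # shape [3, 8]
#     # descriptors consume it via input_dict['type_embedding']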
Example #5
class DescrptSeT(DescrptSe):
    """DeepPot-SE constructed from all information (both angular and radial) of atomic
    configurations.
    
    The embedding takes angles between two neighboring atoms as input.

    Parameters
    ----------
    rcut
            The cut-off radius
    rcut_smth
            From where the environment matrix should be smoothed
    sel : list[str]
            sel[i] specifies the maximum number of type i atoms in the cut-off radius
    neuron : list[int]
            Number of neurons in each hidden layer of the embedding net
    resnet_dt
            Time-step `dt` in the resnet construction:
            y = x + dt * \phi (Wx + b)
    trainable
            If the weights of embedding net are trainable.
    seed
            Random seed for initializing the network parameters.
    set_davg_zero
            Set the shift of embedding net input to zero.
    activation_function
            The activation function in the embedding net. Supported options are {0}
    precision
            The precision of the embedding net parameters. Supported options are {1}
    uniform_seed
            Only for the purpose of backward compatibility, retrieves the old behavior of using the random seed
    """
    @docstring_parameter(list_to_doc(ACTIVATION_FN_DICT.keys()),
                         list_to_doc(PRECISION_DICT.keys()))
    def __init__(self,
                 rcut: float,
                 rcut_smth: float,
                 sel: List[str],
                 neuron: List[int] = [24, 48, 96],
                 resnet_dt: bool = False,
                 trainable: bool = True,
                 seed: int = None,
                 set_davg_zero: bool = False,
                 activation_function: str = 'tanh',
                 precision: str = 'default',
                 uniform_seed: bool = False) -> None:
        """
        Constructor
        """
        if rcut < rcut_smth:
            raise RuntimeError(
                "rcut_smth (%f) should be no more than rcut (%f)!" %
                (rcut_smth, rcut))
        self.sel_a = sel
        self.rcut_r = rcut
        self.rcut_r_smth = rcut_smth
        self.filter_neuron = neuron
        self.filter_resnet_dt = resnet_dt
        self.seed = seed
        self.uniform_seed = uniform_seed
        self.seed_shift = embedding_net_rand_seed_shift(self.filter_neuron)
        self.trainable = trainable
        self.filter_activation_fn = get_activation_func(activation_function)
        self.filter_precision = get_precision(precision)
        # self.exclude_types = set()
        # for tt in exclude_types:
        #     assert(len(tt) == 2)
        #     self.exclude_types.add((tt[0], tt[1]))
        #     self.exclude_types.add((tt[1], tt[0]))
        self.set_davg_zero = set_davg_zero

        # descrpt config
        self.sel_r = [0 for ii in range(len(self.sel_a))]
        self.ntypes = len(self.sel_a)
        assert (self.ntypes == len(self.sel_r))
        self.rcut_a = -1
        # numb of neighbors and numb of descrptors
        self.nnei_a = np.cumsum(self.sel_a)[-1]
        self.nnei_r = np.cumsum(self.sel_r)[-1]
        self.nnei = self.nnei_a + self.nnei_r
        self.ndescrpt_a = self.nnei_a * 4
        self.ndescrpt_r = self.nnei_r * 1
        self.ndescrpt = self.ndescrpt_a + self.ndescrpt_r
        self.useBN = False
        self.dstd = None
        self.davg = None
        self.compress = False
        self.embedding_net_variables = None

        self.place_holders = {}
        avg_zero = np.zeros([self.ntypes,
                             self.ndescrpt]).astype(GLOBAL_NP_FLOAT_PRECISION)
        std_ones = np.ones([self.ntypes,
                            self.ndescrpt]).astype(GLOBAL_NP_FLOAT_PRECISION)
        sub_graph = tf.Graph()
        with sub_graph.as_default():
            name_pfx = 'd_sea_'
            for ii in ['coord', 'box']:
                self.place_holders[ii] = tf.placeholder(
                    GLOBAL_NP_FLOAT_PRECISION, [None, None],
                    name=name_pfx + 't_' + ii)
            self.place_holders['type'] = tf.placeholder(tf.int32, [None, None],
                                                        name=name_pfx +
                                                        't_type')
            self.place_holders['natoms_vec'] = tf.placeholder(
                tf.int32, [self.ntypes + 2], name=name_pfx + 't_natoms')
            self.place_holders['default_mesh'] = tf.placeholder(
                tf.int32, [None], name=name_pfx + 't_mesh')
            self.stat_descrpt, descrpt_deriv, rij, nlist \
                = op_module.prod_env_mat_a(self.place_holders['coord'],
                                         self.place_holders['type'],
                                         self.place_holders['natoms_vec'],
                                         self.place_holders['box'],
                                         self.place_holders['default_mesh'],
                                         tf.constant(avg_zero),
                                         tf.constant(std_ones),
                                         rcut_a = self.rcut_a,
                                         rcut_r = self.rcut_r,
                                         rcut_r_smth = self.rcut_r_smth,
                                         sel_a = self.sel_a,
                                         sel_r = self.sel_r)
        self.sub_sess = tf.Session(graph=sub_graph,
                                   config=default_tf_session_config)

    def get_rcut(self) -> float:
        """
        Returns the cut-off radius
        """
        return self.rcut_r

    def get_ntypes(self) -> int:
        """
        Returns the number of atom types
        """
        return self.ntypes

    def get_dim_out(self) -> int:
        """
        Returns the output dimension of this descriptor
        """
        return self.filter_neuron[-1]

    def get_nlist(self) -> Tuple[tf.Tensor, tf.Tensor, List[int], List[int]]:
        """
        Returns
        -------
        nlist
                Neighbor list
        rij
                The relative distance between the neighbor and the center atom.
        sel_a
                The number of neighbors with full information
        sel_r
                The number of neighbors with only radial information
        """
        return self.nlist, self.rij, self.sel_a, self.sel_r

    def compute_input_stats(self, data_coord: list, data_box: list,
                            data_atype: list, natoms_vec: list, mesh: list,
                            input_dict: dict) -> None:
        """
        Compute the statistics (avg and std) of the training data. The input will be normalized by the statistics.
        
        Parameters
        ----------
        data_coord
                The coordinates. Can be generated by deepmd.model.make_stat_input
        data_box
                The box. Can be generated by deepmd.model.make_stat_input
        data_atype
                The atom types. Can be generated by deepmd.model.make_stat_input
        natoms_vec
                The vector for the number of atoms of the system and different types of atoms. Can be generated by deepmd.model.make_stat_input
        mesh
                The mesh for neighbor searching. Can be generated by deepmd.model.make_stat_input
        input_dict
                Dictionary for additional input
        """
        all_davg = []
        all_dstd = []
        sumr = []
        suma = []
        sumn = []
        sumr2 = []
        suma2 = []
        for cc, bb, tt, nn, mm in zip(data_coord, data_box, data_atype,
                                      natoms_vec, mesh):
            sysr, sysr2, sysa, sysa2, sysn \
                = self._compute_dstats_sys_smth(cc, bb, tt, nn, mm)
            sumr.append(sysr)
            suma.append(sysa)
            sumn.append(sysn)
            sumr2.append(sysr2)
            suma2.append(sysa2)
        sumr = np.sum(sumr, axis=0)
        suma = np.sum(suma, axis=0)
        sumn = np.sum(sumn, axis=0)
        sumr2 = np.sum(sumr2, axis=0)
        suma2 = np.sum(suma2, axis=0)
        for type_i in range(self.ntypes):
            # 1e-15 guards against types absent from the data (sumn == 0)
            davgunit = [sumr[type_i] / (sumn[type_i] + 1e-15), 0, 0, 0]
            dstdunit = [
                self._compute_std(sumr2[type_i], sumr[type_i], sumn[type_i]),
                self._compute_std(suma2[type_i], suma[type_i], sumn[type_i]),
                self._compute_std(suma2[type_i], suma[type_i], sumn[type_i]),
                self._compute_std(suma2[type_i], suma[type_i], sumn[type_i])
            ]
            davg = np.tile(davgunit, self.ndescrpt // 4)
            dstd = np.tile(dstdunit, self.ndescrpt // 4)
            all_davg.append(davg)
            all_dstd.append(dstd)

        if not self.set_davg_zero:
            self.davg = np.array(all_davg)
        self.dstd = np.array(all_dstd)

    def enable_compression(
        self,
        min_nbor_dist: float,
        model_file: str = 'frozen_model.pb',
        table_extrapolate: float = 5,
        table_stride_1: float = 0.01,
        table_stride_2: float = 0.1,
        check_frequency: int = -1,
        suffix: str = "",
    ) -> None:
        """
        Receive the statistics (distance, max_nbor_size and env_mat_range) of the training data.

        Parameters
        ----------
        min_nbor_dist
                The nearest distance between atoms
        model_file
                The original frozen model, which will be compressed by the program
        table_extrapolate
                The scale of model extrapolation
        table_stride_1
                The uniform stride of the first table
        table_stride_2
                The uniform stride of the second table
        check_frequency
                The overflow check frequency
        suffix : str, optional
                The suffix of the scope
        """
        assert (
            not self.filter_resnet_dt
        ), "Model compression error: descriptor resnet_dt must be false!"

        for ii in range(len(self.filter_neuron) - 1):
            if self.filter_neuron[ii] * 2 != self.filter_neuron[ii + 1]:
                raise NotImplementedError(
                    "Model Compression error: descriptor neuron [%s] is not supported by model compression! "
                    "The size of the next layer of the neural network must be twice the size of the previous layer."
                    % ','.join([str(item) for item in self.filter_neuron]))

        self.compress = True
        self.table = DPTabulate(self,
                                self.filter_neuron,
                                model_file,
                                activation_fn=self.filter_activation_fn,
                                suffix=suffix)
        self.table_config = [
            table_extrapolate, table_stride_1 * 10, table_stride_2 * 10,
            check_frequency
        ]
        self.lower, self.upper \
            = self.table.build(min_nbor_dist,
                               table_extrapolate,
                               table_stride_1 * 10,
                               table_stride_2 * 10)

        graph, _ = load_graph_def(model_file)
        self.davg = get_tensor_by_name_from_graph(
            graph, 'descrpt_attr%s/t_avg' % suffix)
        self.dstd = get_tensor_by_name_from_graph(
            graph, 'descrpt_attr%s/t_std' % suffix)

    def build(self,
              coord_: tf.Tensor,
              atype_: tf.Tensor,
              natoms: tf.Tensor,
              box_: tf.Tensor,
              mesh: tf.Tensor,
              input_dict: dict,
              reuse: bool = None,
              suffix: str = '') -> tf.Tensor:
        """
        Build the computational graph for the descriptor

        Parameters
        ----------
        coord_
                The coordinate of atoms
        atype_
                The type of atoms
        natoms
                The number of atoms. This tensor has the length of Ntypes + 2
                natoms[0]: number of local atoms
                natoms[1]: total number of atoms held by this processor
                natoms[i]: 2 <= i < Ntypes+2, number of type i atoms
        mesh
                For historical reasons, only the length of the Tensor matters.
                if size of mesh == 6, pbc is assumed. 
                if size of mesh == 0, no-pbc is assumed. 
        input_dict
                Dictionary for additional inputs
        reuse
                The weights in the networks should be reused when getting the variable.
        suffix
                Name suffix to identify this descriptor

        Returns
        -------
        descriptor
                The output descriptor
        """
        davg = self.davg
        dstd = self.dstd
        with tf.variable_scope('descrpt_attr' + suffix, reuse=reuse):
            if davg is None:
                davg = np.zeros([self.ntypes, self.ndescrpt])
            if dstd is None:
                dstd = np.ones([self.ntypes, self.ndescrpt])
            t_rcut = tf.constant(np.max([self.rcut_r, self.rcut_a]),
                                 name='rcut',
                                 dtype=GLOBAL_TF_FLOAT_PRECISION)
            t_ntypes = tf.constant(self.ntypes, name='ntypes', dtype=tf.int32)
            t_ndescrpt = tf.constant(self.ndescrpt,
                                     name='ndescrpt',
                                     dtype=tf.int32)
            t_sel = tf.constant(self.sel_a, name='sel', dtype=tf.int32)
            self.t_avg = tf.get_variable(
                't_avg',
                davg.shape,
                dtype=GLOBAL_TF_FLOAT_PRECISION,
                trainable=False,
                initializer=tf.constant_initializer(davg))
            self.t_std = tf.get_variable(
                't_std',
                dstd.shape,
                dtype=GLOBAL_TF_FLOAT_PRECISION,
                trainable=False,
                initializer=tf.constant_initializer(dstd))

        coord = tf.reshape(coord_, [-1, natoms[1] * 3])
        box = tf.reshape(box_, [-1, 9])
        atype = tf.reshape(atype_, [-1, natoms[1]])

        self.descrpt, self.descrpt_deriv, self.rij, self.nlist \
            = op_module.prod_env_mat_a (coord,
                                       atype,
                                       natoms,
                                       box,
                                       mesh,
                                       self.t_avg,
                                       self.t_std,
                                       rcut_a = self.rcut_a,
                                       rcut_r = self.rcut_r,
                                       rcut_r_smth = self.rcut_r_smth,
                                       sel_a = self.sel_a,
                                       sel_r = self.sel_r)

        self.descrpt_reshape = tf.reshape(self.descrpt, [-1, self.ndescrpt])
        self._identity_tensors(suffix=suffix)

        self.dout, self.qmat = self._pass_filter(self.descrpt_reshape,
                                                 atype,
                                                 natoms,
                                                 input_dict,
                                                 suffix=suffix,
                                                 reuse=reuse,
                                                 trainable=self.trainable)

        return self.dout

    def prod_force_virial(
            self, atom_ener: tf.Tensor,
            natoms: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        """
        Compute force and virial

        Parameters
        ----------
        atom_ener
                The atomic energy
        natoms
                The number of atoms. This tensor has the length of Ntypes + 2
                natoms[0]: number of local atoms
                natoms[1]: total number of atoms held by this processor
                natoms[i]: 2 <= i < Ntypes+2, number of type i atoms

        Returns
        -------
        force
                The force on atoms
        virial
                The total virial
        atom_virial
                The atomic virial
        """
        [net_deriv] = tf.gradients(atom_ener, self.descrpt_reshape)
        net_deriv_reshape = tf.reshape(net_deriv,
                                       [-1, natoms[0] * self.ndescrpt])
        force \
            = op_module.prod_force_se_a (net_deriv_reshape,
                                          self.descrpt_deriv,
                                          self.nlist,
                                          natoms,
                                          n_a_sel = self.nnei_a,
                                          n_r_sel = self.nnei_r)
        virial, atom_virial \
            = op_module.prod_virial_se_a (net_deriv_reshape,
                                           self.descrpt_deriv,
                                           self.rij,
                                           self.nlist,
                                           natoms,
                                           n_a_sel = self.nnei_a,
                                           n_r_sel = self.nnei_r)
        return force, virial, atom_virial

    def _pass_filter(self,
                     inputs,
                     atype,
                     natoms,
                     input_dict,
                     reuse=None,
                     suffix='',
                     trainable=True):
        start_index = 0
        inputs = tf.reshape(inputs, [-1, self.ndescrpt * natoms[0]])
        output = []
        output_qmat = []
        inputs_i = inputs
        inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt])
        type_i = -1
        layer, qmat = self._filter(inputs_i,
                                   type_i,
                                   name='filter_type_all' + suffix,
                                   natoms=natoms,
                                   reuse=reuse,
                                   trainable=trainable,
                                   activation_fn=self.filter_activation_fn)
        layer = tf.reshape(
            layer, [tf.shape(inputs)[0], natoms[0] * self.get_dim_out()])
        # qmat  = tf.reshape(qmat,  [tf.shape(inputs)[0], natoms[0] * self.get_dim_rot_mat_1() * 3])
        output.append(layer)
        # output_qmat.append(qmat)
        output = tf.concat(output, axis=1)
        # output_qmat = tf.concat(output_qmat, axis = 1)
        return output, None

    def _compute_dstats_sys_smth(self, data_coord, data_box, data_atype,
                                 natoms_vec, mesh):
        dd_all \
            = run_sess(self.sub_sess, self.stat_descrpt,
                                feed_dict = {
                                    self.place_holders['coord']: data_coord,
                                    self.place_holders['type']: data_atype,
                                    self.place_holders['natoms_vec']: natoms_vec,
                                    self.place_holders['box']: data_box,
                                    self.place_holders['default_mesh']: mesh,
                                })
        natoms = natoms_vec
        dd_all = np.reshape(dd_all, [-1, self.ndescrpt * natoms[0]])
        start_index = 0
        sysr = []
        sysa = []
        sysn = []
        sysr2 = []
        sysa2 = []
        for type_i in range(self.ntypes):
            end_index = start_index + self.ndescrpt * natoms[2 + type_i]
            dd = dd_all[:, start_index:end_index]
            dd = np.reshape(dd, [-1, self.ndescrpt])
            start_index = end_index
            # compute
            dd = np.reshape(dd, [-1, 4])
            ddr = dd[:, :1]
            dda = dd[:, 1:]
            sumr = np.sum(ddr)
            suma = np.sum(dda) / 3.
            sumn = dd.shape[0]
            sumr2 = np.sum(np.multiply(ddr, ddr))
            suma2 = np.sum(np.multiply(dda, dda)) / 3.
            sysr.append(sumr)
            sysa.append(suma)
            sysn.append(sumn)
            sysr2.append(sumr2)
            sysa2.append(suma2)
        return sysr, sysr2, sysa, sysa2, sysn

    def _compute_std(self, sumv2, sumv, sumn):
        # guard against empty types, mirroring the earlier descriptor's version
        if sumn == 0:
            return 1e-2
        val = np.sqrt(sumv2 / sumn - np.multiply(sumv / sumn, sumv / sumn))
        if np.abs(val) < 1e-2:
            val = 1e-2
        return val

    @cast_precision
    def _filter(self,
                inputs,
                type_input,
                natoms,
                activation_fn=tf.nn.tanh,
                stddev=1.0,
                bavg=0.0,
                name='linear',
                reuse=None,
                trainable=True):
        # natom x (nei x 4)
        shape = inputs.get_shape().as_list()
        outputs_size = [1] + self.filter_neuron
        with tf.variable_scope(name, reuse=reuse):
            start_index_i = 0
            result = None
            for type_i in range(self.ntypes):
                # cut-out inputs
                # with natom x (nei_type_i x 4)
                inputs_i = tf.slice(inputs, [0, start_index_i * 4],
                                    [-1, self.sel_a[type_i] * 4])
                start_index_j = start_index_i
                start_index_i += self.sel_a[type_i]
                nei_type_i = self.sel_a[type_i]
                shape_i = inputs_i.get_shape().as_list()
                assert (shape_i[1] == nei_type_i * 4)
                # with natom x nei_type_i x 4
                env_i = tf.reshape(inputs_i, [-1, nei_type_i, 4])
                # with natom x nei_type_i x 3
                env_i = tf.slice(env_i, [0, 0, 1], [-1, -1, -1])
                for type_j in range(type_i, self.ntypes):
                    # with natom x (nei_type_j x 4)
                    inputs_j = tf.slice(inputs, [0, start_index_j * 4],
                                        [-1, self.sel_a[type_j] * 4])
                    start_index_j += self.sel_a[type_j]
                    nei_type_j = self.sel_a[type_j]
                    shape_j = inputs_j.get_shape().as_list()
                    assert (shape_j[1] == nei_type_j * 4)
                    # with natom x nei_type_j x 4
                    env_j = tf.reshape(inputs_j, [-1, nei_type_j, 4])
                    # with natom x nei_type_j x 3
                    env_j = tf.slice(env_j, [0, 0, 1], [-1, -1, -1])
                    # with natom x nei_type_i x nei_type_j
                    env_ij = tf.einsum('ijm,ikm->ijk', env_i, env_j)
                    # with (natom x nei_type_i x nei_type_j) x 1
                    ebd_env_ij = tf.reshape(env_ij, [-1, 1])
                    if self.compress:
                        info = [
                            self.lower, self.upper,
                            self.upper * self.table_config[0],
                            self.table_config[1], self.table_config[2],
                            self.table_config[3]
                        ]
                        net = 'filter_' + str(type_i) + '_net_' + str(type_j)
                        res_ij = op_module.tabulate_fusion_se_t(
                            tf.cast(self.table.data[net],
                                    self.filter_precision),
                            info,
                            ebd_env_ij,
                            env_ij,
                            last_layer_size=outputs_size[-1])
                    else:
                        # with (natom x nei_type_i x nei_type_j) x out_size
                        ebd_env_ij = embedding_net(
                            ebd_env_ij,
                            self.filter_neuron,
                            self.filter_precision,
                            activation_fn=activation_fn,
                            resnet_dt=self.filter_resnet_dt,
                            name_suffix=f"_{type_i}_{type_j}",
                            stddev=stddev,
                            bavg=bavg,
                            seed=self.seed,
                            trainable=trainable,
                            uniform_seed=self.uniform_seed,
                            initial_variables=self.embedding_net_variables,
                        )
                        if (not self.uniform_seed) and (self.seed is not None):
                            self.seed += self.seed_shift
                        # with natom x nei_type_i x nei_type_j x out_size
                        ebd_env_ij = tf.reshape(
                            ebd_env_ij,
                            [-1, nei_type_i, nei_type_j, outputs_size[-1]])
                        # with natom x out_size
                        res_ij = tf.einsum('ijk,ijkm->im', env_ij, ebd_env_ij)
                    res_ij = res_ij * (1.0 / float(nei_type_i) /
                                       float(nei_type_j))
                    if result is None:
                        result = res_ij
                    else:
                        result += res_ij
        return result, None
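The pair of einsum contractions above is the core of this _filter: env_ij holds the dot products between the relative coordinates of type-i and type-j neighbors, and the second contraction folds the embedding-net output back over those pair products. A minimal NumPy shape sketch (random stand-ins for the environment matrices; illustrative only, not deepmd-kit code):

import numpy as np

natom, nei_i, nei_j, out_size = 5, 3, 4, 8
env_i = np.random.rand(natom, nei_i, 3)   # relative coordinates, type-i neighbors
env_j = np.random.rand(natom, nei_j, 3)   # relative coordinates, type-j neighbors
# dot product between every type-i / type-j neighbor pair
env_ij = np.einsum('ijm,ikm->ijk', env_i, env_j)     # natom x nei_i x nei_j
ebd = np.random.rand(natom, nei_i, nei_j, out_size)  # embedding-net output
# fold the embedding back over the pair products -> one vector per atom
res = np.einsum('ijk,ijkm->im', env_ij, ebd)         # natom x out_size
assert res.shape == (natom, out_size)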
Example #6
class DescrptSeAEf(Descriptor):
    """

    Parameters
    ----------
    rcut
            The cut-off radius
    rcut_smth
            From where the environment matrix should be smoothed
    sel : list[str]
            sel[i] specifies the maximum number of type i atoms in the cut-off radius
    neuron : list[int]
            Number of neurons in each hidden layer of the embedding net
    axis_neuron
            Number of the axis neuron (number of columns of the sub-matrix of the embedding matrix)
    resnet_dt
            Time-step `dt` in the resnet construction:
            y = x + dt * \phi (Wx + b)
    trainable
            If the weights of embedding net are trainable.
    seed
            Random seed for initializing the network parameters.
    type_one_side
            If true, build N_types embedding nets; otherwise, build N_types^2 embedding nets
    exclude_types : List[List[int]]
            The excluded pairs of types which have no interaction with each other.
            For example, `[[0, 1]]` means no interaction between type 0 and type 1.
    set_davg_zero
            Set the shift of embedding net input to zero.
    activation_function
            The activation function in the embedding net. Supported options are {0}
    precision
            The precision of the embedding net parameters. Supported options are {1}
    uniform_seed
            Only for the purpose of backward compatibility, retrieves the old behavior of using the random seed
    """
    @docstring_parameter(list_to_doc(ACTIVATION_FN_DICT.keys()),
                         list_to_doc(PRECISION_DICT.keys()))
    def __init__(self,
                 rcut: float,
                 rcut_smth: float,
                 sel: List[str],
                 neuron: List[int] = [24, 48, 96],
                 axis_neuron: int = 8,
                 resnet_dt: bool = False,
                 trainable: bool = True,
                 seed: int = None,
                 type_one_side: bool = True,
                 exclude_types: List[List[int]] = [],
                 set_davg_zero: bool = False,
                 activation_function: str = 'tanh',
                 precision: str = 'default',
                 uniform_seed: bool = False) -> None:
        """
        Constructor
        """
        self.descrpt_para = DescrptSeAEfLower(
            op_module.descrpt_se_a_ef_para,
            rcut,
            rcut_smth,
            sel,
            neuron,
            axis_neuron,
            resnet_dt,
            trainable,
            seed,
            type_one_side,
            exclude_types,
            set_davg_zero,
            activation_function,
            precision,
            uniform_seed,
        )
        self.descrpt_vert = DescrptSeAEfLower(
            op_module.descrpt_se_a_ef_vert,
            rcut,
            rcut_smth,
            sel,
            neuron,
            axis_neuron,
            resnet_dt,
            trainable,
            seed,
            type_one_side,
            exclude_types,
            set_davg_zero,
            activation_function,
            precision,
            uniform_seed,
        )

    def get_rcut(self) -> float:
        """
        Returns the cut-off radius
        """
        return self.descrpt_vert.rcut_r

    def get_ntypes(self) -> int:
        """
        Returns the number of atom types
        """
        return self.descrpt_vert.ntypes

    def get_dim_out(self) -> int:
        """
        Returns the output dimension of this descriptor
        """
        return self.descrpt_vert.get_dim_out() + self.descrpt_para.get_dim_out()

    def get_dim_rot_mat_1(self) -> int:
        """
        Returns the first dimension of the rotation matrix. The rotation is of shape dim_1 x 3
        """
        return self.descrpt_vert.filter_neuron[-1]

    def get_rot_mat(self) -> tf.Tensor:
        """
        Get rotational matrix
        """
        return self.qmat

    def get_nlist(self) -> Tuple[tf.Tensor, tf.Tensor, List[int], List[int]]:
        """
        Returns
        -------
        nlist
                Neighbor list
        rij
                The relative distance between the neighbor and the center atom.
        sel_a
                The number of neighbors with full information
        sel_r
                The number of neighbors with only radial information
        """
        return \
            self.descrpt_vert.nlist, \
            self.descrpt_vert.rij, \
            self.descrpt_vert.sel_a, \
            self.descrpt_vert.sel_r

    def compute_input_stats(self, data_coord: list, data_box: list,
                            data_atype: list, natoms_vec: list, mesh: list,
                            input_dict: dict) -> None:
        """
        Compute the statistics (avg and std) of the training data. The input will be normalized by the statistics.
        
        Parameters
        ----------
        data_coord
                The coordinates. Can be generated by deepmd.model.make_stat_input
        data_box
                The box. Can be generated by deepmd.model.make_stat_input
        data_atype
                The atom types. Can be generated by deepmd.model.make_stat_input
        natoms_vec
                The vector for the number of atoms of the system and different types of atoms. Can be generated by deepmd.model.make_stat_input
        mesh
                The mesh for neighbor searching. Can be generated by deepmd.model.make_stat_input
        input_dict
                Dictionary for additional input
        """
        self.descrpt_vert.compute_input_stats(data_coord, data_box, data_atype,
                                              natoms_vec, mesh, input_dict)
        self.descrpt_para.compute_input_stats(data_coord, data_box, data_atype,
                                              natoms_vec, mesh, input_dict)

    def build(self,
              coord_: tf.Tensor,
              atype_: tf.Tensor,
              natoms: tf.Tensor,
              box_: tf.Tensor,
              mesh: tf.Tensor,
              input_dict: dict,
              reuse: bool = None,
              suffix: str = '') -> tf.Tensor:
        """
        Build the computational graph for the descriptor

        Parameters
        ----------
        coord_
                The coordinate of atoms
        atype_
                The type of atoms
        natoms
                The number of atoms. This tensor has the length of Ntypes + 2
                natoms[0]: number of local atoms
                natoms[1]: total number of atoms held by this processor
                natoms[i]: 2 <= i < Ntypes+2, number of type i atoms
        box_
                The box of the system
        mesh
                For historical reasons, only the length of the Tensor matters.
                If size of mesh == 6, PBC is assumed.
                If size of mesh == 0, no PBC is assumed.
        input_dict
                Dictionary for additional inputs. Should have 'efield'.
        reuse
                Whether the weights in the network should be reused when getting the variables.
        suffix
                Name suffix to identify this descriptor

        Returns
        -------
        descriptor
                The output descriptor
        """
        self.dout_vert = self.descrpt_vert.build(coord_, atype_, natoms, box_,
                                                 mesh, input_dict)
        self.dout_para = self.descrpt_para.build(coord_,
                                                 atype_,
                                                 natoms,
                                                 box_,
                                                 mesh,
                                                 input_dict,
                                                 reuse=True)
        coord = tf.reshape(coord_, [-1, natoms[1] * 3])
        nframes = tf.shape(coord)[0]
        self.dout_vert = tf.reshape(
            self.dout_vert,
            [nframes * natoms[0],
             self.descrpt_vert.get_dim_out()])
        self.dout_para = tf.reshape(
            self.dout_para,
            [nframes * natoms[0],
             self.descrpt_para.get_dim_out()])
        self.dout = tf.concat([self.dout_vert, self.dout_para], axis=1)
        self.dout = tf.reshape(self.dout,
                               [nframes, natoms[0] * self.get_dim_out()])
        self.qmat = self.descrpt_vert.qmat + self.descrpt_para.qmat

        tf.summary.histogram('embedding_net_output', self.dout)

        return self.dout

    def prod_force_virial(
            self, atom_ener: tf.Tensor,
            natoms: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        """
        Compute force and virial

        Parameters
        ----------
        atom_ener
                The atomic energy
        natoms
                The number of atoms. This tensor has the length of Ntypes + 2
                natoms[0]: number of local atoms
                natoms[1]: total number of atoms held by this processor
                natoms[i]: 2 <= i < Ntypes+2, number of type i atoms

        Returns
        -------
        force
                The force on atoms
        virial
                The total virial
        atom_virial
                The atomic virial
        """
        f_vert, v_vert, av_vert \
            = self.descrpt_vert.prod_force_virial(atom_ener, natoms)
        f_para, v_para, av_para \
            = self.descrpt_para.prod_force_virial(atom_ener, natoms)
        force = f_vert + f_para
        virial = v_vert + v_para
        atom_vir = av_vert + av_para
        return force, virial, atom_vir
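Because DescrptSeAEf simply concatenates the outputs of its two lower descriptors, its output dimension is the sum of the two parts. A hedged usage sketch (argument values are illustrative; assumes the deepmd-kit TF backend and its custom ops are importable):

descrpt = DescrptSeAEf(rcut=6.0, rcut_smth=0.5, sel=[46, 92],
                       neuron=[25, 50, 100], axis_neuron=16)
assert descrpt.get_dim_out() == (descrpt.descrpt_vert.get_dim_out()
                                 + descrpt.descrpt_para.get_dim_out())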
Example #7
class DipoleFittingSeA():
    """
    Fit the atomic dipole with descriptor se_a
    
    Parameters
    ----------
    descrpt : tf.Tensor
            The descriptor
    neuron : List[int]
            Number of neurons in each hidden layer of the fitting net
    resnet_dt : bool
            Time-step `dt` in the resnet construction:
            y = x + dt * \phi (Wx + b)
    sel_type : List[int]
            The atom types selected to have an atomic dipole prediction. If None, all atoms are selected.
    seed : int
            Random seed for initializing the network parameters.
    activation_function : str
            The activation function in the fitting net. Supported options are {0}
    precision : str
            The precision of the fitting net parameters. Supported options are {1}
    uniform_seed
            Only for the purpose of backward compatibility, retrieves the old behavior of using the random seed
    """
    @docstring_parameter(list_to_doc(ACTIVATION_FN_DICT.keys()),
                         list_to_doc(PRECISION_DICT.keys()))
    def __init__(self,
                 descrpt: tf.Tensor,
                 neuron: List[int] = [120, 120, 120],
                 resnet_dt: bool = True,
                 sel_type: List[int] = None,
                 seed: int = None,
                 activation_function: str = 'tanh',
                 precision: str = 'default',
                 uniform_seed: bool = False) -> None:
        """
        Constructor
        """
        if not isinstance(descrpt, DescrptSeA):
            raise RuntimeError('DipoleFittingSeA only supports DescrptSeA')
        self.ntypes = descrpt.get_ntypes()
        self.dim_descrpt = descrpt.get_dim_out()
        # args = ClassArg()\
        #        .add('neuron',           list,   default = [120,120,120], alias = 'n_neuron')\
        #        .add('resnet_dt',        bool,   default = True)\
        #        .add('sel_type',         [list,int],   default = [ii for ii in range(self.ntypes)], alias = 'dipole_type')\
        #        .add('seed',             int)\
        #        .add("activation_function", str, default = "tanh")\
        #        .add('precision',           str,    default = "default")
        # class_data = args.parse(jdata)
        self.n_neuron = neuron
        self.resnet_dt = resnet_dt
        self.sel_type = sel_type
        if self.sel_type is None:
            self.sel_type = [ii for ii in range(self.ntypes)]
        self.seed = seed
        self.uniform_seed = uniform_seed
        self.seed_shift = one_layer_rand_seed_shift()
        self.fitting_activation_fn = get_activation_func(activation_function)
        self.fitting_precision = get_precision(precision)
        self.dim_rot_mat_1 = descrpt.get_dim_rot_mat_1()
        self.dim_rot_mat = self.dim_rot_mat_1 * 3
        self.useBN = False

    def get_sel_type(self) -> int:
        """
        Get selected type
        """
        return self.sel_type

    def get_out_size(self) -> int:
        """
        Get the output size. Should be 3
        """
        return 3

    def build(self,
              input_d: tf.Tensor,
              rot_mat: tf.Tensor,
              natoms: tf.Tensor,
              reuse: bool = None,
              suffix: str = '') -> tf.Tensor:
        """
        Build the computational graph for fitting net
        
        Parameters
        ----------
        input_d
                The input descriptor
        rot_mat
                The rotation matrix from the descriptor.
        natoms
                The number of atoms. This tensor has the length of Ntypes + 2
                natoms[0]: number of local atoms
                natoms[1]: total number of atoms held by this processor
                natoms[i]: 2 <= i < Ntypes+2, number of type i atoms
        reuse
                Whether the weights in the network should be reused when getting the variables.
        suffix
                Name suffix to identify this descriptor

        Returns
        -------
        dipole
                The atomic dipole.
        """
        start_index = 0
        inputs = tf.cast(
            tf.reshape(input_d, [-1, self.dim_descrpt * natoms[0]]),
            self.fitting_precision)
        rot_mat = tf.reshape(rot_mat, [-1, self.dim_rot_mat * natoms[0]])

        count = 0
        for type_i in range(self.ntypes):
            # cut-out inputs
            inputs_i = tf.slice(inputs, [0, start_index * self.dim_descrpt],
                                [-1, natoms[2 + type_i] * self.dim_descrpt])
            inputs_i = tf.reshape(inputs_i, [-1, self.dim_descrpt])
            rot_mat_i = tf.slice(rot_mat, [0, start_index * self.dim_rot_mat],
                                 [-1, natoms[2 + type_i] * self.dim_rot_mat])
            rot_mat_i = tf.reshape(rot_mat_i, [-1, self.dim_rot_mat_1, 3])
            start_index += natoms[2 + type_i]
            if type_i not in self.sel_type:
                continue
            layer = inputs_i
            for ii in range(0, len(self.n_neuron)):
                if ii >= 1 and self.n_neuron[ii] == self.n_neuron[ii - 1]:
                    layer += one_layer(
                        layer,
                        self.n_neuron[ii],
                        name='layer_' + str(ii) + '_type_' + str(type_i) +
                        suffix,
                        reuse=reuse,
                        seed=self.seed,
                        use_timestep=self.resnet_dt,
                        activation_fn=self.fitting_activation_fn,
                        precision=self.fitting_precision,
                        uniform_seed=self.uniform_seed)
                else:
                    layer = one_layer(layer,
                                      self.n_neuron[ii],
                                      name='layer_' + str(ii) + '_type_' +
                                      str(type_i) + suffix,
                                      reuse=reuse,
                                      seed=self.seed,
                                      activation_fn=self.fitting_activation_fn,
                                      precision=self.fitting_precision,
                                      uniform_seed=self.uniform_seed)
                if (not self.uniform_seed) and (self.seed is not None):
                    self.seed += self.seed_shift
            # (nframes x natoms) x naxis
            final_layer = one_layer(layer,
                                    self.dim_rot_mat_1,
                                    activation_fn=None,
                                    name='final_layer_type_' + str(type_i) +
                                    suffix,
                                    reuse=reuse,
                                    seed=self.seed,
                                    precision=self.fitting_precision,
                                    uniform_seed=self.uniform_seed)
            if (not self.uniform_seed) and (self.seed is not None):
                self.seed += self.seed_shift
            # (nframes x natoms) x 1 x naxis
            final_layer = tf.reshape(final_layer, [
                tf.shape(inputs)[0] * natoms[2 + type_i], 1, self.dim_rot_mat_1
            ])
            # (nframes x natoms) x 1 x 3(coord)
            final_layer = tf.matmul(final_layer, rot_mat_i)
            # nframes x natoms x 3
            final_layer = tf.reshape(
                final_layer, [tf.shape(inputs)[0], natoms[2 + type_i], 3])

            # concat the results
            if count == 0:
                outs = final_layer
            else:
                outs = tf.concat([outs, final_layer], axis=1)
            count += 1

        tf.summary.histogram('fitting_net_output', outs)
        return tf.cast(tf.reshape(outs, [-1]), GLOBAL_TF_FLOAT_PRECISION)
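The dipole's rotational equivariance comes from the final matmul in build: the fitting net emits a per-atom vector of length dim_rot_mat_1, which the atom's rotation matrix projects onto Cartesian space. A NumPy shape sketch (random stand-ins; illustrative only):

import numpy as np

naxis = 4                           # dim_rot_mat_1 of the descriptor
hidden = np.random.rand(1, naxis)   # final_layer output for one atom
rot = np.random.rand(naxis, 3)      # that atom's rotation matrix
dipole = hidden @ rot               # 1 x 3 atomic dipole
assert dipole.shape == (1, 3)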
Example #8
File: polar.py Project: njzjz/deepmd-kit
class GlobalPolarFittingSeA():
    """
    Fit the system polarizability with descriptor se_a

    Parameters
    ----------
    descrpt : tf.Tensor
            The descriptor
    neuron : List[int]
            Number of neurons in each hidden layer of the fitting net
    resnet_dt : bool
            Time-step `dt` in the resnet construction:
            y = x + dt * \phi (Wx + b)
    sel_type : List[int]
            The atom types selected to have an atomic polarizability prediction
    fit_diag : bool
            Fit the diagonal part of the rotational invariant polarizability matrix, which will be converted to normal polarizability matrix by contracting with the rotation matrix.
    scale : List[float]
            The output of the fitting net (polarizability matrix) for type i atom will be scaled by scale[i]
    diag_shift : List[float]
            The diagonal part of the polarizability matrix of type i will be shifted by diag_shift[i]. The shift operation is carried out after scale.        
    seed : int
            Random seed for initializing the network parameters.
    activation_function : str
            The activation function in the fitting net. Supported options are {0}
    precision : str
            The precision of the fitting net parameters. Supported options are {1}
    """
    @docstring_parameter(list_to_doc(ACTIVATION_FN_DICT.keys()), list_to_doc(PRECISION_DICT.keys()))
    def __init__(self,
                 descrpt: tf.Tensor,
                 neuron: List[int] = [120, 120, 120],
                 resnet_dt: bool = True,
                 sel_type: List[int] = None,
                 fit_diag: bool = True,
                 scale: List[float] = None,
                 diag_shift: List[float] = None,
                 seed: int = None,
                 activation_function: str = 'tanh',
                 precision: str = 'default') -> None:
        """
        Constructor            
        """
        if not isinstance(descrpt, DescrptSeA):
            raise RuntimeError('GlobalPolarFittingSeA only supports DescrptSeA')
        self.ntypes = descrpt.get_ntypes()
        self.dim_descrpt = descrpt.get_dim_out()
        self.polar_fitting = PolarFittingSeA(descrpt,
                                             neuron,
                                             resnet_dt,
                                             sel_type,
                                             fit_diag,
                                             scale,
                                             diag_shift,
                                             seed,
                                             activation_function,
                                             precision)

    def get_sel_type(self) -> int:
        """
        Get selected atom types
        """
        return self.polar_fitting.get_sel_type()

    def get_out_size(self) -> int:
        """
        Get the output size. Should be 9
        """
        return self.polar_fitting.get_out_size()

    def build(self,
              input_d,
              rot_mat,
              natoms,
              reuse=None,
              suffix='') -> tf.Tensor:
        """
        Build the computational graph for fitting net
        
        Parameters
        ----------
        input_d
                The input descriptor
        rot_mat
                The rotation matrix from the descriptor.
        natoms
                The number of atoms. This tensor has the length of Ntypes + 2
                natoms[0]: number of local atoms
                natoms[1]: total number of atoms held by this processor
                natoms[i]: 2 <= i < Ntypes+2, number of type i atoms
        reuse
                Whether the weights in the network should be reused when getting the variables.
        suffix
                Name suffix to identify this descriptor

        Returns
        -------
        polar
                The system polarizability        
        """
        inputs = tf.reshape(input_d, [-1, self.dim_descrpt * natoms[0]])
        outs = self.polar_fitting.build(input_d, rot_mat, natoms, reuse, suffix)
        # nframes x natoms x 9
        outs = tf.reshape(outs, [tf.shape(inputs)[0], -1, 9])
        outs = tf.reduce_sum(outs, axis=1)
        tf.summary.histogram('fitting_net_output', outs)
        return tf.reshape(outs, [-1])
    
    def init_variables(self,
                       graph: tf.Graph,
                       graph_def: tf.GraphDef,
                       suffix : str = "",
    ) -> None:
        """
        Init the fitting net variables with the given dict

        Parameters
        ----------
        graph : tf.Graph
            The input frozen model graph
        graph_def : tf.GraphDef
            The input frozen model graph_def
        suffix : str
            suffix to name scope
        """
        self.polar_fitting.init_variables(graph=graph, graph_def=graph_def, suffix=suffix)


    def enable_mixed_precision(self, mixed_prec: dict = None) -> None:
        """
        Receive the mixed precision setting.

        Parameters
        ----------
        mixed_prec
                The mixed precision setting used in the embedding net
        """
        self.polar_fitting.enable_mixed_precision(mixed_prec)
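GlobalPolarFittingSeA only reduces the atomic polarizabilities produced by PolarFittingSeA to one tensor per frame; the tf.reduce_sum in build does all the work. A NumPy sketch of that reduction (random stand-ins; illustrative only):

import numpy as np

nframes, natoms = 2, 6
atomic_polar = np.random.rand(nframes, natoms, 9)  # flattened 3x3 per atom
global_polar = atomic_polar.sum(axis=1)            # what tf.reduce_sum(axis=1) does
assert global_polar.shape == (nframes, 9)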
Example #9
File: polar.py Project: njzjz/deepmd-kit
class PolarFittingSeA(Fitting):
    """
    Fit the atomic polarizability with descriptor se_a
    """
    @docstring_parameter(list_to_doc(ACTIVATION_FN_DICT.keys()), list_to_doc(PRECISION_DICT.keys()))
    def __init__(self,
                 descrpt: tf.Tensor,
                 neuron: List[int] = [120, 120, 120],
                 resnet_dt: bool = True,
                 sel_type: List[int] = None,
                 fit_diag: bool = True,
                 scale: List[float] = None,
                 shift_diag: bool = True,  # YWolfeee: let the user decide whether to shift the diagonal
                 # diag_shift : List[float] = None, YWolfeee: a user-assigned shift is no longer supported
                 seed: int = None,
                 activation_function: str = 'tanh',
                 precision: str = 'default',
                 uniform_seed: bool = False) -> None:
        """
        Constructor

        Parameters
        ----------
        descrpt : tf.Tensor
                The descriptor
        neuron : List[int]
                Number of neurons in each hidden layer of the fitting net
        resnet_dt : bool
                Time-step `dt` in the resnet construction:
                y = x + dt * \phi (Wx + b)
        sel_type : List[int]
                The atom types selected to have an atomic polarizability prediction. If None, all atoms are selected.
        fit_diag : bool
                Fit the diagonal part of the rotational invariant polarizability matrix, which will be converted to normal polarizability matrix by contracting with the rotation matrix.
        scale : List[float]
                The output of the fitting net (polarizability matrix) for type i atom will be scaled by scale[i]
        diag_shift : List[float]
                The diagonal part of the polarizability matrix of type i will be shifted by diag_shift[i]. The shift operation is carried out after scale.        
        seed : int
                Random seed for initializing the network parameters.
        activation_function : str
                The activation function in the fitting net. Supported options are {0}
        precision : str
                The precision of the fitting net parameters. Supported options are {1}
        uniform_seed
                Only for the purpose of backward compatibility, retrieves the old behavior of using the random seed
        """
        if not isinstance(descrpt, DescrptSeA):
            raise RuntimeError('PolarFittingSeA only supports DescrptSeA')
        self.ntypes = descrpt.get_ntypes()
        self.dim_descrpt = descrpt.get_dim_out()
        # args = ClassArg()\
        #        .add('neuron',           list,   default = [120,120,120], alias = 'n_neuron')\
        #        .add('resnet_dt',        bool,   default = True)\
        #        .add('fit_diag',         bool,   default = True)\
        #        .add('diag_shift',       [list,float], default = [0.0 for ii in range(self.ntypes)])\
        #        .add('scale',            [list,float], default = [1.0 for ii in range(self.ntypes)])\
        #        .add('sel_type',         [list,int],   default = [ii for ii in range(self.ntypes)], alias = 'pol_type')\
        #        .add('seed',             int)\
        #        .add("activation_function", str ,   default = "tanh")\
        #        .add('precision',           str,    default = "default")
        # class_data = args.parse(jdata)
        self.n_neuron = neuron
        self.resnet_dt = resnet_dt
        self.sel_type = sel_type
        self.fit_diag = fit_diag
        self.seed = seed
        self.uniform_seed = uniform_seed
        self.seed_shift = one_layer_rand_seed_shift()
        #self.diag_shift = diag_shift
        self.shift_diag = shift_diag
        self.scale = scale
        self.fitting_activation_fn = get_activation_func(activation_function)
        self.fitting_precision = get_precision(precision)
        if self.sel_type is None:
            self.sel_type = [ii for ii in range(self.ntypes)]
        if self.scale is None:
            self.scale = [1.0 for ii in range(self.ntypes)]
        #if self.diag_shift is None:
        #    self.diag_shift = [0.0 for ii in range(self.ntypes)]
        if type(self.sel_type) is not list:
            self.sel_type = [self.sel_type]
        self.constant_matrix = np.zeros(len(self.sel_type)) # len(sel_type) x 1, store the average diagonal value
        #if type(self.diag_shift) is not list:
        #    self.diag_shift = [self.diag_shift]
        if type(self.scale) is not list:
            self.scale = [self.scale]
        self.dim_rot_mat_1 = descrpt.get_dim_rot_mat_1()
        self.dim_rot_mat = self.dim_rot_mat_1 * 3
        self.useBN = False
        self.fitting_net_variables = None
        self.mixed_prec = None

    def get_sel_type(self) -> List[int]:
        """
        Get selected atom types
        """
        return self.sel_type

    def get_out_size(self) -> int:
        """
        Get the output size. Should be 9
        """
        return 9

    def compute_input_stats(self, 
                            all_stat, 
                            protection = 1e-2):
        """
        Compute the input statistics

        Parameters
        ----------
        all_stat
                Dictionary of inputs. 
                can be prepared by model.make_stat_input
        protection
                Divided-by-zero protection
        """
        if 'polarizability' not in all_stat:
            self.avgeig = np.zeros([9])
            warnings.warn('no polarizability data, cannot do data stat. use zeros as guess')
            return
        data = all_stat['polarizability']
        all_tmp = []
        for ss in range(len(data)):
            tmp = np.concatenate(data[ss], axis=0)
            tmp = np.reshape(tmp, [-1, 3, 3])
            tmp, _ = np.linalg.eig(tmp)
            tmp = np.absolute(tmp)
            tmp = np.sort(tmp, axis=1)
            all_tmp.append(tmp)
        all_tmp = np.concatenate(all_tmp, axis=1)
        self.avgeig = np.average(all_tmp, axis=0)

        # YWolfeee: support polar normalization, initialize to a more appropriate point 
        if self.shift_diag:
            mean_polar = np.zeros([len(self.sel_type), 9])
            sys_matrix, polar_bias = [], []
            for ss in range(len(all_stat['type'])):
                atom_has_polar = [w for w in all_stat['type'][ss][0] if (w in self.sel_type)]   # select atom with polar
                if all_stat['find_atomic_polarizability'][ss] > 0.0:
                    for itype in range(len(self.sel_type)): # Atomic polar mode, should specify the atoms
                        index_lis = [index for index, w in enumerate(atom_has_polar)
                                     if w == self.sel_type[itype]]  # select index in this type

                        sys_matrix.append(np.zeros((1, len(self.sel_type))))
                        sys_matrix[-1][0, itype] = len(index_lis)

                        polar_bias.append(np.sum(
                            all_stat['atomic_polarizability'][ss].reshape((-1, 9))[index_lis],
                            axis=0).reshape((1, 9)))
                else:   # No atomic polar in this system, so it should have global polar
                    if not all_stat['find_polarizability'][ss] > 0.0:  # no global polar data either; skip this system
                        continue
                    # Till here, we have global polar
                    sys_matrix.append(np.zeros((1, len(self.sel_type))))  # add a row to the equations
                    for itype in range(len(self.sel_type)):  # count the selected atoms of each type
                        index_lis = [index for index, w in enumerate(atom_has_polar)
                                     if w == self.sel_type[itype]]  # select index in this type
                        sys_matrix[-1][0, itype] = len(index_lis)

                    # add polar_bias
                    polar_bias.append(all_stat['polarizability'][ss].reshape((1, 9)))

            matrix, bias = np.concatenate(sys_matrix, axis=0), np.concatenate(polar_bias, axis=0)
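            # solve the (possibly over-determined) linear system
            #     matrix @ atom_polar = bias
            # where matrix[s, t] counts selected type-t atoms in system s and
            # bias[s] is that system's polarizability; the least-squares
            # solution is the average polarizability per selected type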
            atom_polar, _, _, _ = np.linalg.lstsq(matrix, bias, rcond=1e-3)
            for itype in range(len(self.sel_type)):
                self.constant_matrix[itype] = np.mean(
                    np.diagonal(atom_polar[itype].reshape((3, 3))))

    @cast_precision
    def build(self,
              input_d: tf.Tensor,
              rot_mat: tf.Tensor,
              natoms: tf.Tensor,
              reuse: bool = None,
              suffix: str = ''):
        """
        Build the computational graph for fitting net
        
        Parameters
        ----------
        input_d
                The input descriptor
        rot_mat
                The rotation matrix from the descriptor.
        natoms
                The number of atoms. This tensor has the length of Ntypes + 2
                natoms[0]: number of local atoms
                natoms[1]: total number of atoms held by this processor
                natoms[i]: 2 <= i < Ntypes+2, number of type i atoms
        reuse
                Whether the weights in the network should be reused when getting the variables.
        suffix
                Name suffix to identify this descriptor

        Returns
        -------
        atomic_polar
                The atomic polarizability        
        """
        start_index = 0
        inputs = tf.reshape(input_d, [-1, self.dim_descrpt * natoms[0]])
        rot_mat = tf.reshape(rot_mat, [-1, self.dim_rot_mat * natoms[0]])

        count = 0
        outs_list = []
        for type_i in range(self.ntypes):
            # cut-out inputs
            inputs_i = tf.slice(inputs, [0, start_index * self.dim_descrpt],
                                [-1, natoms[2 + type_i] * self.dim_descrpt])
            inputs_i = tf.reshape(inputs_i, [-1, self.dim_descrpt])
            rot_mat_i = tf.slice(rot_mat, [0, start_index * self.dim_rot_mat],
                                 [-1, natoms[2 + type_i] * self.dim_rot_mat])
            rot_mat_i = tf.reshape(rot_mat_i, [-1, self.dim_rot_mat_1, 3])
            start_index += natoms[2 + type_i]
            if type_i not in self.sel_type:
                continue
            layer = inputs_i
            for ii in range(0, len(self.n_neuron)):
                if ii >= 1 and self.n_neuron[ii] == self.n_neuron[ii - 1]:
                    layer += one_layer(
                        layer, self.n_neuron[ii],
                        name='layer_' + str(ii) + '_type_' + str(type_i) + suffix,
                        reuse=reuse, seed=self.seed, use_timestep=self.resnet_dt,
                        activation_fn=self.fitting_activation_fn,
                        precision=self.fitting_precision,
                        uniform_seed=self.uniform_seed,
                        initial_variables=self.fitting_net_variables,
                        mixed_prec=self.mixed_prec)
                else:
                    layer = one_layer(
                        layer, self.n_neuron[ii],
                        name='layer_' + str(ii) + '_type_' + str(type_i) + suffix,
                        reuse=reuse, seed=self.seed,
                        activation_fn=self.fitting_activation_fn,
                        precision=self.fitting_precision,
                        uniform_seed=self.uniform_seed,
                        initial_variables=self.fitting_net_variables,
                        mixed_prec=self.mixed_prec)
                if (not self.uniform_seed) and (self.seed is not None):
                    self.seed += self.seed_shift
            if self.fit_diag:
                bavg = np.zeros(self.dim_rot_mat_1)
                # bavg[0] = self.avgeig[0]
                # bavg[1] = self.avgeig[1]
                # bavg[2] = self.avgeig[2]
                # (nframes x natoms) x naxis
                final_layer = one_layer(
                    layer, self.dim_rot_mat_1, activation_fn=None,
                    name='final_layer_type_' + str(type_i) + suffix,
                    reuse=reuse, seed=self.seed, bavg=bavg,
                    precision=self.fitting_precision,
                    uniform_seed=self.uniform_seed,
                    initial_variables=self.fitting_net_variables,
                    mixed_prec=self.mixed_prec, final_layer=True)
                if (not self.uniform_seed) and (self.seed is not None):
                    self.seed += self.seed_shift
                # (nframes x natoms) x naxis
                final_layer = tf.reshape(
                    final_layer,
                    [tf.shape(inputs)[0] * natoms[2 + type_i], self.dim_rot_mat_1])
                # (nframes x natoms) x naxis x naxis
                final_layer = tf.matrix_diag(final_layer)
            else:
                bavg = np.zeros(self.dim_rot_mat_1 * self.dim_rot_mat_1)
                # bavg[0*self.dim_rot_mat_1+0] = self.avgeig[0]
                # bavg[1*self.dim_rot_mat_1+1] = self.avgeig[1]
                # bavg[2*self.dim_rot_mat_1+2] = self.avgeig[2]
                # (nframes x natoms) x (naxis x naxis)
                final_layer = one_layer(
                    layer, self.dim_rot_mat_1 * self.dim_rot_mat_1,
                    activation_fn=None,
                    name='final_layer_type_' + str(type_i) + suffix,
                    reuse=reuse, seed=self.seed, bavg=bavg,
                    precision=self.fitting_precision,
                    uniform_seed=self.uniform_seed,
                    initial_variables=self.fitting_net_variables,
                    mixed_prec=self.mixed_prec, final_layer=True)
                if (not self.uniform_seed) and (self.seed is not None):
                    self.seed += self.seed_shift
                # (nframes x natoms) x naxis x naxis
                final_layer = tf.reshape(
                    final_layer,
                    [tf.shape(inputs)[0] * natoms[2 + type_i],
                     self.dim_rot_mat_1, self.dim_rot_mat_1])
                # (nframes x natoms) x naxis x naxis, symmetrized
                final_layer = final_layer + tf.transpose(final_layer, perm=[0, 2, 1])
            # (nframes x natoms) x naxis x 3(coord)
            final_layer = tf.matmul(final_layer, rot_mat_i)
            # (nframes x natoms) x 3(coord) x 3(coord)
            final_layer = tf.matmul(rot_mat_i, final_layer, transpose_a=True)
            # nframes x natoms x 3 x 3
            final_layer = tf.reshape(
                final_layer, [tf.shape(inputs)[0], natoms[2 + type_i], 3, 3])
            # shift and scale
            sel_type_idx = self.sel_type.index(type_i)
            final_layer = final_layer * self.scale[sel_type_idx]
            final_layer = final_layer + self.constant_matrix[sel_type_idx] * tf.eye(
                3, batch_shape=[tf.shape(inputs)[0], natoms[2 + type_i]],
                dtype=GLOBAL_TF_FLOAT_PRECISION)

            # concat the results
            outs_list.append(final_layer)
            count += 1
        outs = tf.concat(outs_list, axis=1)

        tf.summary.histogram('fitting_net_output', outs)
        return tf.reshape(outs, [-1])

    def init_variables(self,
                       graph: tf.Graph,
                       graph_def: tf.GraphDef,
                       suffix : str = "",
    ) -> None:
        """
        Init the fitting net variables with the given dict

        Parameters
        ----------
        graph : tf.Graph
            The input frozen model graph
        graph_def : tf.GraphDef
            The input frozen model graph_def
        suffix : str
            suffix to name scope
        """
        self.fitting_net_variables = get_fitting_net_variables_from_graph_def(graph_def)


    def enable_mixed_precision(self, mixed_prec: dict = None) -> None:
        """
        Receive the mixed precision setting.

        Parameters
        ----------
        mixed_prec
                The mixed precision setting used in the embedding net
        """
        self.mixed_prec = mixed_prec
        self.fitting_precision = get_precision(mixed_prec['output_prec'])
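A closing note on the fit_diag branch of PolarFittingSeA.build: the network predicts only a diagonal, and rot^T diag(d) rot is symmetric (and rotationally equivariant) by construction, which is exactly what a polarizability tensor requires. A NumPy sketch (random stand-ins; illustrative only):

import numpy as np

naxis = 4
rot = np.random.rand(naxis, 3)    # one atom's rotation matrix
d = np.random.rand(naxis)         # fit_diag output for that atom
polar = rot.T @ np.diag(d) @ rot  # 3 x 3 polarizability
assert np.allclose(polar, polar.T)  # symmetric by construction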