コード例 #1
0
    def comp_ef(self, dcoord, dbox, dtype, tnatoms, name, reuse=None):
        """
        Build the graph nodes for energy, force and virial.

        The environment matrix is produced by `op_module.prod_env_mat_r`,
        fed through `self._net` to obtain per-atom energies, and the gradient
        of those energies with respect to the environment matrix is passed to
        the force and virial ops.

        Parameters
        ----------
        dcoord
                Coordinates, forwarded unchanged to `prod_env_mat_r`
        dbox
                Box, forwarded unchanged to `prod_env_mat_r`
        dtype
                Atom types, forwarded unchanged to `prod_env_mat_r`
        tnatoms
                Atom-count vector, forwarded to the env-mat, force and
                virial ops
        name
                Name handed to `self._net`
        reuse
                Reuse flag handed to `self._net`

        Returns
        -------
        energy
                Per-atom energies reshaped to rows of `self.natoms[0]` and
                summed along axis 1
        force
                Output of `op_module.prod_force_se_r`
        virial
                First output of `op_module.prod_virial_se_r`
        """
        mesh = tf.constant(self.default_mesh)
        env_mat, env_mat_deriv, rij, nlist = op_module.prod_env_mat_r(
            dcoord,
            dtype,
            tnatoms,
            dbox,
            mesh,
            self.t_avg,
            self.t_std,
            rcut=self.rcut,
            rcut_smth=self.rcut_smth,
            sel=self.sel)

        # one row of ndescrpt entries per atom goes into the network
        net_input = tf.reshape(env_mat, [-1, self.ndescrpt])
        atom_ener = self._net(net_input, name, reuse=reuse)
        ener_per_frame = tf.reshape(atom_ener, [-1, self.natoms[0]])
        energy = tf.reduce_sum(ener_per_frame, axis=1)

        # d(atom_ener) / d(env_mat), regrouped to one row per frame
        dener_denv = tf.gradients(atom_ener, net_input)[0]
        dener_denv = tf.reshape(dener_denv,
                                [-1, self.natoms[0] * self.ndescrpt])

        force = op_module.prod_force_se_r(
            dener_denv, env_mat_deriv, nlist, tnatoms)
        # the second output (per-atom virial) is not returned
        virial, _ = op_module.prod_virial_se_r(
            dener_denv, env_mat_deriv, rij, nlist, tnatoms)
        return energy, force, virial
コード例 #2
0
    def __init__(self,
                 rcut: float,
                 rcut_smth: float,
                 sel: List[int],
                 neuron: List[int] = None,
                 resnet_dt: bool = False,
                 trainable: bool = True,
                 seed: int = None,
                 type_one_side: bool = True,
                 exclude_types: List[List[int]] = None,
                 set_davg_zero: bool = False,
                 activation_function: str = 'tanh',
                 precision: str = 'default',
                 uniform_seed: bool = False) -> None:
        """
        Constructor

        Stores the configuration, derives the neighbor/descriptor sizes from
        `sel`, and builds a private graph and session that evaluate
        `op_module.prod_env_mat_r` with zero average and unit standard
        deviation (output kept as `self.stat_descrpt`).

        Parameters
        ----------
        rcut
                Passed as `rcut` to the `prod_env_mat_r` op
        rcut_smth
                Passed as `rcut_smth` to the `prod_env_mat_r` op
        sel
                One integer per atom type; its length defines the number of
                types and its sum the number of neighbors/descriptor entries.
                Passed as `sel` to the `prod_env_mat_r` op
        neuron
                Sizes stored as `self.filter_neuron`. Defaults to
                [24, 48, 96] when omitted or None
        resnet_dt
                Stored as `self.filter_resnet_dt`
        trainable
                Stored as `self.trainable`
        seed
                Stored as `self.seed`
        type_one_side
                Stored as `self.type_one_side`
        exclude_types
                Pairs of type indices. Each pair is recorded in both orders
                in `self.exclude_types`. Defaults to no pairs when omitted
                or None
        set_davg_zero
                Stored as `self.set_davg_zero`
        activation_function
                Name resolved through `get_activation_func`
        precision
                Name resolved through `get_precision`
        uniform_seed
                Stored as `self.uniform_seed`
        """
        # None sentinels instead of mutable defaults: the neuron list is kept
        # on the instance, so a shared default list would be aliased by every
        # object constructed without an explicit value.
        neuron = [24, 48, 96] if neuron is None else list(neuron)
        if exclude_types is None:
            exclude_types = []

        self.sel_r = sel
        self.rcut = rcut
        self.rcut_smth = rcut_smth
        self.filter_neuron = neuron
        self.filter_resnet_dt = resnet_dt
        self.seed = seed
        self.uniform_seed = uniform_seed
        self.seed_shift = embedding_net_rand_seed_shift(self.filter_neuron)
        self.trainable = trainable
        self.filter_activation_fn = get_activation_func(activation_function)
        self.filter_precision = get_precision(precision)
        # store every excluded pair in both orders so lookups are symmetric
        self.exclude_types = set()
        for tt in exclude_types:
            assert len(tt) == 2, \
                'each entry of exclude_types must hold exactly two types'
            self.exclude_types.add((tt[0], tt[1]))
            self.exclude_types.add((tt[1], tt[0]))
        self.set_davg_zero = set_davg_zero
        self.type_one_side = type_one_side

        # descrpt config
        self.sel_a = [0] * len(self.sel_r)
        self.ntypes = len(self.sel_r)
        # numb of neighbors and numb of descrptors
        self.nnei_a = np.cumsum(self.sel_a)[-1]
        self.nnei_r = np.cumsum(self.sel_r)[-1]
        self.nnei = self.nnei_a + self.nnei_r
        self.ndescrpt_a = self.nnei_a * 4
        self.ndescrpt_r = self.nnei_r * 1
        self.ndescrpt = self.nnei_r
        self.useBN = False
        self.davg = None
        self.dstd = None
        self.embedding_net_variables = None

        self.place_holders = {}
        avg_zero = np.zeros([self.ntypes,
                             self.ndescrpt]).astype(GLOBAL_NP_FLOAT_PRECISION)
        std_ones = np.ones([self.ntypes,
                            self.ndescrpt]).astype(GLOBAL_NP_FLOAT_PRECISION)
        # separate graph: the op is fed zero average / unit std, so its
        # output is the un-normalized environment matrix
        sub_graph = tf.Graph()
        with sub_graph.as_default():
            name_pfx = 'd_ser_'
            for ii in ['coord', 'box']:
                self.place_holders[ii] = tf.placeholder(
                    GLOBAL_NP_FLOAT_PRECISION, [None, None],
                    name=name_pfx + 't_' + ii)
            self.place_holders['type'] = tf.placeholder(tf.int32, [None, None],
                                                        name=name_pfx +
                                                        't_type')
            self.place_holders['natoms_vec'] = tf.placeholder(
                tf.int32, [self.ntypes + 2], name=name_pfx + 't_natoms')
            self.place_holders['default_mesh'] = tf.placeholder(
                tf.int32, [None], name=name_pfx + 't_mesh')
            # only the first output is kept; the rest are not needed here
            self.stat_descrpt, _, _, _ \
                = op_module.prod_env_mat_r(self.place_holders['coord'],
                                         self.place_holders['type'],
                                         self.place_holders['natoms_vec'],
                                         self.place_holders['box'],
                                         self.place_holders['default_mesh'],
                                         tf.constant(avg_zero),
                                         tf.constant(std_ones),
                                         rcut = self.rcut,
                                         rcut_smth = self.rcut_smth,
                                         sel = self.sel_r)
            self.sub_sess = tf.Session(graph=sub_graph,
                                       config=default_tf_session_config)
コード例 #3
0
    def build(self,
              coord_: tf.Tensor,
              atype_: tf.Tensor,
              natoms: tf.Tensor,
              box_: tf.Tensor,
              mesh: tf.Tensor,
              input_dict: dict,
              reuse: bool = None,
              suffix: str = '') -> tf.Tensor:
        """
        Assemble the descriptor part of the computational graph

        Parameters
        ----------
        coord_
                The coordinate of atoms; reshaped to rows of natoms[1] * 3
        atype_
                The type of atoms; reshaped to rows of natoms[1]
        natoms
                The number of atoms. This tensor has the length of Ntypes + 2
                natoms[0]: number of local atoms
                natoms[1]: total number of atoms held by this processor
                natoms[i]: 2 <= i < Ntypes+2, number of type i atoms
        box_
                The simulation box; reshaped to rows of 9
        mesh
                For historical reasons, only the length of the Tensor matters.
                if size of mesh == 6, pbc is assumed.
                if size of mesh == 0, no-pbc is assumed.
        input_dict
                Dictionary for additional inputs (not read by this method)
        reuse
                Whether the variables in the scopes opened here are reused
        suffix
                Name suffix to identify this descriptor

        Returns
        -------
        descriptor
                The output of `self._pass_filter`, also kept as `self.dout`
        """
        # fall back to zero average / unit deviation when no statistics
        # have been stored on the instance
        avg_init = self.davg
        if avg_init is None:
            avg_init = np.zeros([self.ntypes, self.ndescrpt])
        std_init = self.dstd
        if std_init is None:
            std_init = np.ones([self.ntypes, self.ndescrpt])

        with tf.variable_scope('descrpt_attr' + suffix, reuse=reuse):
            # The return values of these constants are not used in this
            # method; the calls add named nodes to the graph in this scope.
            tf.constant(self.rcut,
                        name='rcut',
                        dtype=GLOBAL_TF_FLOAT_PRECISION)
            tf.constant(self.ntypes, name='ntypes', dtype=tf.int32)
            tf.constant(self.ndescrpt, name='ndescrpt', dtype=tf.int32)
            tf.constant(self.sel_a, name='sel', dtype=tf.int32)
            self.t_avg = tf.get_variable(
                't_avg',
                avg_init.shape,
                dtype=GLOBAL_TF_FLOAT_PRECISION,
                trainable=False,
                initializer=tf.constant_initializer(avg_init))
            self.t_std = tf.get_variable(
                't_std',
                std_init.shape,
                dtype=GLOBAL_TF_FLOAT_PRECISION,
                trainable=False,
                initializer=tf.constant_initializer(std_init))

        # one row per frame for each input
        frame_coord = tf.reshape(coord_, [-1, natoms[1] * 3])
        frame_box = tf.reshape(box_, [-1, 9])
        frame_atype = tf.reshape(atype_, [-1, natoms[1]])

        env_outputs = op_module.prod_env_mat_r(
            frame_coord,
            frame_atype,
            natoms,
            frame_box,
            mesh,
            self.t_avg,
            self.t_std,
            rcut=self.rcut,
            rcut_smth=self.rcut_smth,
            sel=self.sel_r)
        self.descrpt, self.descrpt_deriv, self.rij, self.nlist = env_outputs

        self.descrpt_reshape = tf.reshape(self.descrpt, [-1, self.ndescrpt])
        self._identity_tensors(suffix=suffix)

        # only used when tensorboard was set as true
        for tag, tensor in (('descrpt', self.descrpt),
                            ('rij', self.rij),
                            ('nlist', self.nlist)):
            tf.summary.histogram(tag, tensor)

        self.dout = self._pass_filter(self.descrpt_reshape,
                                      natoms,
                                      suffix=suffix,
                                      reuse=reuse,
                                      trainable=self.trainable)
        tf.summary.histogram('embedding_net_output', self.dout)

        return self.dout