コード例 #1
0
 def setUp(self, data):
     """Build the shared fixture for the se_ar descriptor tests.

     Opens a TF session and caches the atom/type counts reported by
     ``data``. Constructs a ``DescrptSeAR`` from an angular ('a') and a
     radial ('r') parameter set, and prepares identity normalisation
     arrays for each sub-descriptor. Also builds a default mesh and the
     graph input placeholders.

     Args:
         data: test data provider exposing ``get_natoms()`` and
             ``get_ntypes()``.
     """
     self.sess = tf.Session()
     self.data = data
     self.natoms = data.get_natoms()
     self.ntypes = data.get_ntypes()

     # Angular ('a') and radial ('r') sub-descriptor settings.
     # NOTE(review): the radial set uses rcut_smth (6.5) > rcut (6); a
     # smoothing onset beyond the cutoff looks unintended -- confirm.
     angular = dict(sel=[12, 24], rcut=4, rcut_smth=3.5,
                    neuron=[5, 10, 20], seed=1)
     radial = dict(sel=[20, 40], rcut=6, rcut_smth=6.5,
                   neuron=[10, 20, 40], seed=1)
     self.descrpt = DescrptSeAR({'a': angular, 'r': radial})
     self.ndescrpt = self.descrpt.get_dim_out()

     # Identity normalisation (zero mean, unit std) per atom type, one
     # array per sub-descriptor, ordered [angular, radial].
     sub_dims = (self.descrpt.descrpt_a.ndescrpt,
                 self.descrpt.descrpt_r.ndescrpt)
     self.avg = [np.zeros([self.ntypes, dim]) for dim in sub_dims]
     self.std = [np.ones([self.ntypes, dim]) for dim in sub_dims]

     # Six-entry int32 mesh: first three entries 0, last three 2.
     self.default_mesh = np.zeros(6, dtype=np.int32)
     self.default_mesh[3:] = 2

     # Graph inputs; the leading None dimension is left free (presumably
     # the number of frames).
     n_atoms = self.natoms[0]
     self.coord = tf.placeholder(global_tf_float_precision,
                                 [None, n_atoms * 3], name='t_coord')
     self.box = tf.placeholder(global_tf_float_precision,
                               [None, 9], name='t_box')
     self.type = tf.placeholder(tf.int32, [None, n_atoms], name="t_type")
     self.tnatoms = tf.placeholder(tf.int32, [None], name="t_natoms")
コード例 #2
0
    def _init_param(self, jdata):
        """Initialise the training components from the parsed input ``jdata``.

        Builds, in order, the descriptor, the fitting net, the model, the
        learning-rate schedule, the loss and the training options.

        Sections read from ``jdata``:

        * ``model``, with ``descriptor`` and ``fitting_net``
        * ``learning_rate``
        * ``training``
        * ``loss`` (optional)

        Raises:
            RuntimeError: for an unknown descriptor, fitting, learning-rate
                or loss type, or an unsupported descriptor/fitting pairing.
        """
        # model config
        model_param = j_must_have(jdata, 'model')
        descrpt_param = j_must_have(model_param, 'descriptor')
        fitting_param = j_must_have(model_param, 'fitting_net')

        # descriptor
        descrpt_type = j_must_have(descrpt_param, 'type')
        if descrpt_type == 'loc_frame':
            self.descrpt = DescrptLocFrame(descrpt_param)
        elif descrpt_type == 'se_a':
            self.descrpt = DescrptSeA(descrpt_param)
        elif descrpt_type == 'se_r':
            self.descrpt = DescrptSeR(descrpt_param)
        elif descrpt_type == 'se_ar':
            self.descrpt = DescrptSeAR(descrpt_param)
        else:
            raise RuntimeError('unknow model type ' + descrpt_type)

        # fitting net; 'type' is optional and defaults to an energy fitting.
        fitting_type = fitting_param.get('type', 'ener')
        if fitting_type == 'ener':
            self.fitting = EnerFitting(fitting_param, self.descrpt)
        elif fitting_type == 'wfc':
            self.fitting = WFCFitting(fitting_param, self.descrpt)
        elif fitting_type == 'dipole':
            if descrpt_type == 'se_a':
                self.fitting = DipoleFittingSeA(fitting_param, self.descrpt)
            else:
                raise RuntimeError(
                    'fitting dipole only supports descrptors: se_a')
        elif fitting_type == 'polar':
            if descrpt_type == 'loc_frame':
                self.fitting = PolarFittingLocFrame(fitting_param,
                                                    self.descrpt)
            elif descrpt_type == 'se_a':
                self.fitting = PolarFittingSeA(fitting_param, self.descrpt)
            else:
                raise RuntimeError(
                    'fitting polar only supports descrptors: loc_frame and se_a'
                )
        elif fitting_type == 'global_polar':
            # Only se_a is accepted here; the message must say so.
            if descrpt_type == 'se_a':
                self.fitting = GlobalPolarFittingSeA(fitting_param,
                                                     self.descrpt)
            else:
                raise RuntimeError(
                    'fitting global_polar only supports descrptors: se_a')
        else:
            raise RuntimeError('unknow fitting type ' + fitting_type)

        # init model
        # infer model type by fitting_type
        # NOTE(review): the energy branch compares against Model.model_type
        # instead of the literal 'ener' used elsewhere; presumably they are
        # equal -- confirm.
        if fitting_type == Model.model_type:
            self.model = Model(model_param, self.descrpt, self.fitting)
        elif fitting_type == 'wfc':
            self.model = WFCModel(model_param, self.descrpt, self.fitting)
        elif fitting_type == 'dipole':
            self.model = DipoleModel(model_param, self.descrpt, self.fitting)
        elif fitting_type == 'polar':
            self.model = PolarModel(model_param, self.descrpt, self.fitting)
        elif fitting_type == 'global_polar':
            self.model = GlobalPolarModel(model_param, self.descrpt,
                                          self.fitting)
        else:
            raise RuntimeError('get unknown fitting type when building model')

        # learning rate; 'type' is optional and defaults to exponential decay.
        lr_param = j_must_have(jdata, 'learning_rate')
        lr_type = lr_param.get('type', 'exp')
        if lr_type == 'exp':
            self.lr = LearningRateExp(lr_param)
        else:
            raise RuntimeError('unknown learning_rate type ' + lr_type)

        # loss
        # infer loss type by fitting_type
        # The 'loss' section is optional: a missing or null entry selects the
        # standard loss and passes None as its parameters.
        loss_param = jdata.get('loss')
        if loss_param is None:
            loss_type = 'std'
        else:
            loss_type = loss_param.get('type', 'std')

        if fitting_type == 'ener':
            if loss_type == 'std':
                self.loss = EnerStdLoss(
                    loss_param, starter_learning_rate=self.lr.start_lr())
            elif loss_type == 'ener_dipole':
                self.loss = EnerDipoleLoss(
                    loss_param, starter_learning_rate=self.lr.start_lr())
            else:
                raise RuntimeError('unknow loss type')
        elif fitting_type == 'wfc':
            self.loss = TensorLoss(loss_param,
                                   model=self.model,
                                   tensor_name='wfc',
                                   tensor_size=self.model.get_out_size(),
                                   label_name='wfc')
        elif fitting_type == 'dipole':
            self.loss = TensorLoss(loss_param,
                                   model=self.model,
                                   tensor_name='dipole',
                                   tensor_size=3,
                                   label_name='dipole')
        elif fitting_type == 'polar':
            self.loss = TensorLoss(loss_param,
                                   model=self.model,
                                   tensor_name='polar',
                                   tensor_size=9,
                                   label_name='polarizability')
        elif fitting_type == 'global_polar':
            self.loss = TensorLoss(loss_param,
                                   model=self.model,
                                   tensor_name='global_polar',
                                   tensor_size=9,
                                   atomic=False,
                                   label_name='polarizability')
        else:
            raise RuntimeError(
                'get unknown fitting type when building loss function')

        # training
        training_param = j_must_have(jdata, 'training')

        tr_args = (ClassArg()
                   .add('numb_test', int, default=1)
                   .add('disp_file', str, default='lcurve.out')
                   .add('disp_freq', int, default=100)
                   .add('save_freq', int, default=1000)
                   .add('save_ckpt', str, default='model.ckpt')
                   .add('display_in_training', bool, default=True)
                   .add('timing_in_training', bool, default=True)
                   .add('profiling', bool, default=False)
                   .add('profiling_file', str, default='timeline.json')
                   .add('sys_probs', list)
                   .add('auto_prob_style', str, default="prob_sys_size"))
        tr_data = tr_args.parse(training_param)
        self.numb_test = tr_data['numb_test']
        self.disp_file = tr_data['disp_file']
        self.disp_freq = tr_data['disp_freq']
        self.save_freq = tr_data['save_freq']
        self.save_ckpt = tr_data['save_ckpt']
        self.display_in_training = tr_data['display_in_training']
        self.timing_in_training = tr_data['timing_in_training']
        self.profiling = tr_data['profiling']
        self.profiling_file = tr_data['profiling_file']
        self.sys_probs = tr_data['sys_probs']
        self.auto_prob_style = tr_data['auto_prob_style']
        self.useBN = False
        # Frame parameters are only taken from energy fittings; every other
        # fitting type is treated as having none.
        if fitting_type == 'ener' and self.fitting.get_numb_fparam() > 0:
            self.numb_fparam = self.fitting.get_numb_fparam()
        else:
            self.numb_fparam = 0
コード例 #3
0
class Inter:
    """Shared test scaffolding for the se_ar descriptor.

    ``setUp`` builds the descriptor, normalisation arrays, mesh and input
    placeholders. ``comp_ef`` wires the descriptor through a linear
    per-atom readout to produce energy, force and virial tensors.
    """

    def setUp(self, data):
        """Build the shared fixture for the se_ar descriptor tests.

        Opens a TF session and caches the atom/type counts reported by
        ``data``. Constructs a ``DescrptSeAR`` from an angular ('a') and a
        radial ('r') parameter set, and prepares identity normalisation
        arrays for each sub-descriptor. Also builds a default mesh and the
        graph input placeholders.

        Args:
            data: test data provider exposing ``get_natoms()`` and
                ``get_ntypes()``.
        """
        self.sess = tf.Session()
        self.data = data
        self.natoms = data.get_natoms()
        self.ntypes = data.get_ntypes()

        # Angular ('a') and radial ('r') sub-descriptor settings.
        # NOTE(review): the radial set uses rcut_smth (6.5) > rcut (6); a
        # smoothing onset beyond the cutoff looks unintended -- confirm.
        angular = dict(sel=[12, 24], rcut=4, rcut_smth=3.5,
                       neuron=[5, 10, 20], seed=1)
        radial = dict(sel=[20, 40], rcut=6, rcut_smth=6.5,
                      neuron=[10, 20, 40], seed=1)
        self.descrpt = DescrptSeAR({'a': angular, 'r': radial})
        self.ndescrpt = self.descrpt.get_dim_out()

        # Identity normalisation (zero mean, unit std) per atom type, one
        # array per sub-descriptor, ordered [angular, radial].
        sub_dims = (self.descrpt.descrpt_a.ndescrpt,
                    self.descrpt.descrpt_r.ndescrpt)
        self.avg = [np.zeros([self.ntypes, dim]) for dim in sub_dims]
        self.std = [np.ones([self.ntypes, dim]) for dim in sub_dims]

        # Six-entry int32 mesh: first three entries 0, last three 2.
        self.default_mesh = np.zeros(6, dtype=np.int32)
        self.default_mesh[3:] = 2

        # Graph inputs; the leading None dimension is left free (presumably
        # the number of frames).
        n_atoms = self.natoms[0]
        self.coord = tf.placeholder(global_tf_float_precision,
                                    [None, n_atoms * 3], name='t_coord')
        self.box = tf.placeholder(global_tf_float_precision,
                                  [None, 9], name='t_box')
        self.type = tf.placeholder(tf.int32, [None, n_atoms], name="t_type")
        self.tnatoms = tf.placeholder(tf.int32, [None], name="t_natoms")

    def _net(self, inputs, name, reuse=False):
        """Linear readout: project each descriptor row onto one weight vector.

        The weight variable ``net_w`` (length ``get_dim_out()``) is created
        or reused inside variable scope ``name``. It is initialised from
        ``self.net_w_i``, which this class never sets (NOTE(review):
        presumably provided by the concrete test case).

        Returns:
            A 1-D tensor with one value per descriptor row.
        """
        dim = self.descrpt.get_dim_out()
        with tf.variable_scope(name, reuse=reuse):
            weights = tf.get_variable('net_w', [dim],
                                      global_tf_float_precision,
                                      tf.constant_initializer(self.net_w_i))
        rows = tf.reshape(inputs, [-1, dim])
        column = tf.reshape(weights, [dim, 1])
        projected = tf.matmul(rows, column)
        return tf.reshape(projected, [-1])

    def comp_ef(self, dcoord, dbox, dtype, tnatoms, name, reuse=None):
        """Build energy, force and virial tensors from the given inputs.

        The descriptor is built with ``suffix=name``, and the same ``name``
        and ``reuse`` are passed on to the readout's variable scope. Atomic
        outputs are reshaped into rows of ``natoms[0]`` atoms, and each row
        is summed to give the energy.

        Returns:
            ``(energy, force, virial)``. The third output of
            ``prod_force_virial`` is discarded.
        """
        descriptor = self.descrpt.build(dcoord, dtype, tnatoms, dbox,
                                        self.default_mesh,
                                        suffix=name, reuse=reuse)
        flat = tf.reshape(descriptor, [-1, self.descrpt.get_dim_out()])
        per_atom = tf.reshape(self._net(flat, name, reuse=reuse),
                              [-1, self.natoms[0]])
        energy = tf.reduce_sum(per_atom, axis=1)
        force, virial, _ = self.descrpt.prod_force_virial(per_atom, tnatoms)
        return energy, force, virial