def __init__(self, jdata):
    """Create the two sub-descriptors described by ``jdata``.

    ``jdata`` is parsed with ``ClassArg``; it must provide the dict entries
    ``'a'`` and ``'r'``, which are stored and handed to ``DescrptSeA`` and
    ``DescrptSeR`` respectively.  Both sub-descriptors have to report the
    same number of atom types.
    """
    # Both sections are mandatory dictionaries.
    spec = (ClassArg()
            .add('a', dict, must=True)
            .add('r', dict, must=True))
    parsed = spec.parse(jdata)

    self.param_a = parsed['a']
    self.param_r = parsed['r']

    self.descrpt_a = DescrptSeA(self.param_a)
    self.descrpt_r = DescrptSeR(self.param_r)

    # The two parts must agree on the number of atom types.
    ntypes_a = self.descrpt_a.get_ntypes()
    ntypes_r = self.descrpt_r.get_ntypes()
    assert (ntypes_a == ntypes_r)
Exemple #2
0
 def test(self):
     """Check fparam/aparam statistics computed by ``EnerFitting``.

     Fake data is generated for two systems, reference averages and
     standard deviations are obtained from the brute-force helpers, and
     they are compared with the values that
     ``fitting.compute_input_stats`` stores on the fitting object.
     """
     model_cfg = j_loader(input_json)['model']
     descrpt = DescrptSeA(model_cfg['descriptor'])
     fitting = EnerFitting(model_cfg['fitting_net'], descrpt)

     # Settings passed to the fake-data generator.
     avgs = [0, 10]
     stds = [2, 0.4]
     sys_natoms = [10, 100]
     sys_nframes = [5, 2]
     all_data = _make_fake_data(sys_natoms, sys_nframes, avgs, stds)

     n_param = len(avgs)
     ref_favg, ref_fstd = _brute_fparam(all_data, n_param)
     ref_aavg, ref_astd = _brute_aparam(all_data, n_param)

     fitting.compute_input_stats(all_data, protection=1e-2)

     for idx in range(n_param):
         self.assertAlmostEqual(ref_favg[idx], fitting.fparam_avg[idx])
         self.assertAlmostEqual(ref_fstd[idx], fitting.fparam_std[idx])
         self.assertAlmostEqual(ref_aavg[idx], fitting.aparam_avg[idx])
         self.assertAlmostEqual(ref_astd[idx], fitting.aparam_std[idx])
Exemple #3
0
    def test_model(self):
        """Compare the se_a energy model's outputs with stored references.

        The model (``DescrptSeA`` + ``EnerFitting`` + ``Model``) is assembled
        from ``water_se_a.json``, its statistics are computed from the test
        data, and the first test frame is evaluated in a TensorFlow session.
        Every component of energy, force and virial must match the
        hard-coded reference values to 10 decimal places.
        """
        jfile = 'water_se_a.json'
        with open(jfile) as fp:
            jdata = json.load(fp)
        # NOTE(review): the result is not used (DataSystem receives
        # run_opt=None); the call is kept only in case construction has
        # side effects -- confirm and drop.
        RunOptions(None)
        systems = j_must_have(jdata, 'systems')
        set_pfx = j_must_have(jdata, 'set_prefix')
        # These keys are still looked up (presumably j_must_have fails when
        # one is absent), but their values are not needed: the test pins the
        # batch and test sizes to 1 below.
        for required_key in ('batch_size', 'numb_test', 'stop_batch'):
            j_must_have(jdata, required_key)
        batch_size = 1
        test_size = 1
        rcut = j_must_have(jdata['model']['descriptor'], 'rcut')

        data = DataSystem(systems,
                          set_pfx,
                          batch_size,
                          test_size,
                          rcut,
                          run_opt=None)

        test_data = data.get_test()
        numb_test = 1

        descrpt = DescrptSeA(jdata['model']['descriptor'])
        fitting = EnerFitting(jdata['model']['fitting_net'], descrpt)
        model = Model(jdata['model'], descrpt, fitting)

        input_data = {
            'coord': [test_data['coord']],
            'box': [test_data['box']],
            'type': [test_data['type']],
            'natoms_vec': [test_data['natoms_vec']],
            'default_mesh': [test_data['default_mesh']]
        }
        model._compute_dstats(input_data)
        model.bias_atom_e = data.compute_energy_shift()

        # NOTE(review): the placeholders are created in the original order,
        # including ones the fetched tensors may not depend on, so the graph
        # is built exactly as it was when the references were produced.
        t_prop_c = tf.placeholder(tf.float32, [5], name='t_prop_c')
        t_energy = tf.placeholder(global_ener_float_precision, [None],
                                  name='t_energy')
        t_force = tf.placeholder(global_tf_float_precision, [None],
                                 name='t_force')
        t_virial = tf.placeholder(global_tf_float_precision, [None],
                                  name='t_virial')
        t_atom_ener = tf.placeholder(global_tf_float_precision, [None],
                                     name='t_atom_ener')
        t_coord = tf.placeholder(global_tf_float_precision, [None],
                                 name='i_coord')
        t_type = tf.placeholder(tf.int32, [None], name='i_type')
        t_natoms = tf.placeholder(tf.int32, [model.ntypes + 2],
                                  name='i_natoms')
        t_box = tf.placeholder(global_tf_float_precision, [None, 9],
                               name='i_box')
        t_mesh = tf.placeholder(tf.int32, [None], name='i_mesh')
        is_training = tf.placeholder(tf.bool)
        t_fparam = None

        model_pred = model.build(t_coord,
                                 t_type,
                                 t_natoms,
                                 t_box,
                                 t_mesh,
                                 t_fparam,
                                 suffix="se_a",
                                 reuse=False)
        energy = model_pred['energy']
        force = model_pred['force']
        virial = model_pred['virial']
        # Not compared below; the lookup is kept so a missing 'atom_ener'
        # output still fails the test.
        atom_ener = model_pred['atom_ener']

        feed_dict_test = {
            t_prop_c: test_data['prop_c'],
            t_energy: test_data['energy'][:numb_test],
            t_force: np.reshape(test_data['force'][:numb_test, :], [-1]),
            t_virial: np.reshape(test_data['virial'][:numb_test, :], [-1]),
            t_atom_ener: np.reshape(test_data['atom_ener'][:numb_test, :],
                                    [-1]),
            t_coord: np.reshape(test_data['coord'][:numb_test, :], [-1]),
            t_box: test_data['box'][:numb_test, :],
            t_type: np.reshape(test_data['type'][:numb_test, :], [-1]),
            t_natoms: test_data['natoms_vec'],
            t_mesh: test_data['default_mesh'],
            is_training: False
        }

        # The context manager closes the session even if evaluation fails.
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            e, f, v = sess.run([energy, force, virial],
                               feed_dict=feed_dict_test)

        e = e.reshape([-1])
        f = f.reshape([-1])
        v = v.reshape([-1])
        refe = [6.135449167779321300e+01]
        reff = [
            7.799691562262310585e-02, 9.423098804815030483e-02,
            3.790560997388224204e-03, 1.432522403799846578e-01,
            1.148392791403983204e-01, -1.321871172563671148e-02,
            -7.318966526325138000e-02, 6.516069212737778116e-02,
            5.406418483320515412e-04, 5.870713761026503247e-02,
            -1.605402669549013672e-01, -5.089516979826595386e-03,
            -2.554593467731766654e-01, 3.092063507347833987e-02,
            1.510355029451411479e-02, 4.869271842355533952e-02,
            -1.446113274345035005e-01, -1.126524434771078789e-03
        ]
        refv = [
            -6.076776685178300053e-01, 1.103174323630009418e-01,
            1.984250991380156690e-02, 1.103174323630009557e-01,
            -3.319759402259439551e-01, -6.007404107650986258e-03,
            1.984250991380157036e-02, -6.007404107650981921e-03,
            -1.200076017439753642e-03
        ]
        refe = np.reshape(refe, [-1])
        reff = np.reshape(reff, [-1])
        refv = np.reshape(refv, [-1])

        # Element-wise comparison: energy first, then force, then virial.
        places = 10
        for got, ref in ((e, refe), (f, reff), (v, refv)):
            for ii in range(got.size):
                self.assertAlmostEqual(got[ii], ref[ii], places=places)
Exemple #4
0
class DescrptSeAR():
    """Composite descriptor that pairs a ``DescrptSeA`` with a ``DescrptSeR``.

    The two sub-descriptors are configured from the ``'a'`` and ``'r'``
    entries of the input dictionary.  Their outputs are concatenated per
    atom, and their force/virial contributions are summed.
    """

    def __init__(self, jdata):
        """Parse ``jdata`` and construct both sub-descriptors.

        ``jdata`` must contain the dict entries ``'a'`` and ``'r'``.  The two
        sub-descriptors must report the same number of atom types.
        """
        args = (ClassArg()
                .add('a', dict, must=True)
                .add('r', dict, must=True))
        class_data = args.parse(jdata)
        self.param_a = class_data['a']
        self.param_r = class_data['r']
        self.descrpt_a = DescrptSeA(self.param_a)
        self.descrpt_r = DescrptSeR(self.param_r)
        assert self.descrpt_a.get_ntypes() == self.descrpt_r.get_ntypes(), \
            'the "a" and "r" descriptors must have the same number of types'
        # Filled by compute_input_stats(); None until it has been called.
        self.davg = None
        self.dstd = None

    def get_rcut(self):
        """Return the larger of the two sub-descriptor cut-off radii."""
        return np.max([self.descrpt_a.get_rcut(), self.descrpt_r.get_rcut()])

    def get_ntypes(self):
        """Return the number of types as reported by the ``r`` part."""
        return self.descrpt_r.get_ntypes()

    def get_dim_out(self):
        """Return the summed output dimension of both sub-descriptors."""
        return (self.descrpt_a.get_dim_out() + self.descrpt_r.get_dim_out())

    def get_nlist_a(self):
        """Return ``(nlist, rij, sel_a, sel_r)`` of the ``a`` part."""
        part = self.descrpt_a
        return part.nlist, part.rij, part.sel_a, part.sel_r

    def get_nlist_r(self):
        """Return ``(nlist, rij, sel_a, sel_r)`` of the ``r`` part."""
        part = self.descrpt_r
        return part.nlist, part.rij, part.sel_a, part.sel_r

    def compute_input_stats(self, data_coord, data_box, data_atype, natoms_vec,
                            mesh):
        """Compute input statistics on both parts and collect the results.

        All arguments are forwarded unchanged to each sub-descriptor.
        Afterwards ``self.davg`` and ``self.dstd`` hold two-element lists
        ``[a part, r part]`` with the sub-descriptors' ``davg`` / ``dstd``.
        """
        self.descrpt_a.compute_input_stats(data_coord, data_box, data_atype,
                                           natoms_vec, mesh)
        self.descrpt_r.compute_input_stats(data_coord, data_box, data_atype,
                                           natoms_vec, mesh)
        self.davg = [self.descrpt_a.davg, self.descrpt_r.davg]
        self.dstd = [self.descrpt_a.dstd, self.descrpt_r.dstd]

    def build(self, coord_, atype_, natoms, box, mesh, suffix='', reuse=None):
        """Build both sub-descriptors and return their concatenated output.

        The ``a`` part is built with ``suffix + '_a'`` and the ``r`` part
        with ``suffix`` unchanged.  Each output is reshaped to
        ``[-1, dim_out]``, the two are concatenated along axis 1, and the
        result is reshaped to ``[-1, natoms[0] * self.get_dim_out()]``.
        The result is also stored in ``self.dout``.

        ``self.davg`` / ``self.dstd`` are not read by this method.
        """
        self.dout_a = self.descrpt_a.build(coord_,
                                           atype_,
                                           natoms,
                                           box,
                                           mesh,
                                           suffix=suffix + '_a',
                                           reuse=reuse)
        self.dout_r = self.descrpt_r.build(coord_,
                                           atype_,
                                           natoms,
                                           box,
                                           mesh,
                                           suffix=suffix,
                                           reuse=reuse)
        self.dout_a = tf.reshape(self.dout_a,
                                 [-1, self.descrpt_a.get_dim_out()])
        self.dout_r = tf.reshape(self.dout_r,
                                 [-1, self.descrpt_r.get_dim_out()])
        self.dout = tf.concat([self.dout_a, self.dout_r], axis=1)
        self.dout = tf.reshape(self.dout, [-1, natoms[0] * self.get_dim_out()])
        return self.dout

    def prod_force_virial(self, atom_ener, natoms):
        """Return summed ``(force, virial, atom_virial)`` of both parts."""
        f_a, v_a, av_a = self.descrpt_a.prod_force_virial(atom_ener, natoms)
        f_r, v_r, av_r = self.descrpt_r.prod_force_virial(atom_ener, natoms)
        force = f_a + f_r
        virial = v_a + v_r
        atom_virial = av_a + av_r
        return force, virial, atom_virial
Exemple #5
0
    def test_model(self):
        """Compare the se_a polarizability model's output with stored references.

        The model (``DescrptSeA`` + ``PolarFittingSeA`` + ``PolarModel``) is
        assembled from ``polar_se_a.json``, its statistics are computed from
        the test data, and the first test frame is evaluated in a TensorFlow
        session.  Every component of the ``'polar'`` output must match the
        hard-coded reference values to 6 decimal places.
        """
        jfile = 'polar_se_a.json'
        with open(jfile) as fp:
            jdata = json.load(fp)
        # NOTE(review): the result is not used (DataSystem receives
        # run_opt=None); the call is kept only in case construction has
        # side effects -- confirm and drop.
        RunOptions(None)
        systems = j_must_have(jdata, 'systems')
        set_pfx = j_must_have(jdata, 'set_prefix')
        # These keys are still looked up (presumably j_must_have fails when
        # one is absent), but their values are not needed: the test pins the
        # batch and test sizes to 1 below.
        for required_key in ('batch_size', 'numb_test', 'stop_batch'):
            j_must_have(jdata, required_key)
        batch_size = 1
        test_size = 1
        rcut = j_must_have(jdata['model']['descriptor'], 'rcut')

        data = DataSystem(systems,
                          set_pfx,
                          batch_size,
                          test_size,
                          rcut,
                          run_opt=None)

        test_data = data.get_test()
        numb_test = 1

        descrpt = DescrptSeA(jdata['model']['descriptor'])
        fitting = PolarFittingSeA(jdata['model']['fitting_net'], descrpt)
        model = PolarModel(jdata['model'], descrpt, fitting)

        model._compute_dstats([test_data['coord']], [test_data['box']],
                              [test_data['type']], [test_data['natoms_vec']],
                              [test_data['default_mesh']])

        # NOTE(review): t_energy, t_force, t_virial and t_atom_ener are never
        # fed or fetched in this test.  They are still created, in the
        # original order, so the graph is built exactly as it was when the
        # references were produced.
        t_prop_c = tf.placeholder(tf.float32, [5], name='t_prop_c')
        t_energy = tf.placeholder(global_ener_float_precision, [None],
                                  name='t_energy')
        t_force = tf.placeholder(global_tf_float_precision, [None],
                                 name='t_force')
        t_virial = tf.placeholder(global_tf_float_precision, [None],
                                  name='t_virial')
        t_atom_ener = tf.placeholder(global_tf_float_precision, [None],
                                     name='t_atom_ener')
        t_coord = tf.placeholder(global_tf_float_precision, [None],
                                 name='i_coord')
        t_type = tf.placeholder(tf.int32, [None], name='i_type')
        t_natoms = tf.placeholder(tf.int32, [model.ntypes + 2],
                                  name='i_natoms')
        t_box = tf.placeholder(global_tf_float_precision, [None, 9],
                               name='i_box')
        t_mesh = tf.placeholder(tf.int32, [None], name='i_mesh')
        is_training = tf.placeholder(tf.bool)
        t_fparam = None

        model_pred = model.build(t_coord,
                                 t_type,
                                 t_natoms,
                                 t_box,
                                 t_mesh,
                                 t_fparam,
                                 suffix="polar_se_a",
                                 reuse=False)
        polar = model_pred['polar']

        feed_dict_test = {
            t_prop_c: test_data['prop_c'],
            t_coord: np.reshape(test_data['coord'][:numb_test, :], [-1]),
            t_box: test_data['box'][:numb_test, :],
            t_type: np.reshape(test_data['type'][:numb_test, :], [-1]),
            t_natoms: test_data['natoms_vec'],
            t_mesh: test_data['default_mesh'],
            is_training: False
        }

        # The context manager closes the session even if evaluation fails.
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            [p] = sess.run([polar], feed_dict=feed_dict_test)

        p = p.reshape([-1])
        refp = [
            3.39695248e+01, 2.16564043e+01, 8.18501479e-01, 2.16564043e+01,
            1.38211789e+01, 5.22775159e-01, 8.18501479e-01, 5.22775159e-01,
            1.97847218e-02, 8.08467431e-01, 3.42081126e+00, -2.01072261e-01,
            3.42081126e+00, 1.54924596e+01, -9.06153697e-01, -2.01072261e-01,
            -9.06153697e-01, 5.30193262e-02
        ]

        places = 6
        for ii in range(p.size):
            self.assertAlmostEqual(p[ii], refp[ii], places=places)
Exemple #6
0
    def _init_param(self, jdata):
        """Create descriptor, fitting net, model, learning rate and loss from ``jdata``.

        ``jdata`` must contain the sections ``'model'`` (with ``'descriptor'``
        and ``'fitting_net'``), ``'learning_rate'`` and ``'training'``; the
        ``'loss'`` section is optional.  The following attributes are set:
        ``descrpt``, ``fitting``, ``model``, ``lr``, ``loss``, the training
        options (``numb_test``, ``disp_file``, ``disp_freq``, ``save_freq``,
        ``save_ckpt``, ``display_in_training``, ``timing_in_training``,
        ``profiling``, ``profiling_file``, ``sys_probs``,
        ``auto_prob_style``), ``useBN`` and ``numb_fparam``.

        Raises
        ------
        RuntimeError
            If the descriptor, fitting, learning-rate or loss type is not
            recognised, or if the fitting type is combined with a descriptor
            type it does not support.
        """
        # model config
        model_param = j_must_have(jdata, 'model')
        descrpt_param = j_must_have(model_param, 'descriptor')
        fitting_param = j_must_have(model_param, 'fitting_net')

        # descriptor
        descrpt_type = j_must_have(descrpt_param, 'type')
        if descrpt_type == 'loc_frame':
            self.descrpt = DescrptLocFrame(descrpt_param)
        elif descrpt_type == 'se_a':
            self.descrpt = DescrptSeA(descrpt_param)
        elif descrpt_type == 'se_r':
            self.descrpt = DescrptSeR(descrpt_param)
        elif descrpt_type == 'se_ar':
            self.descrpt = DescrptSeAR(descrpt_param)
        else:
            raise RuntimeError('unknow model type ' + descrpt_type)

        # fitting net; the type defaults to 'ener' when it cannot be read.
        # Only lookup failures are caught so that KeyboardInterrupt and
        # SystemExit are no longer swallowed.
        try:
            fitting_type = fitting_param['type']
        except (KeyError, TypeError):
            fitting_type = 'ener'
        if fitting_type == 'ener':
            self.fitting = EnerFitting(fitting_param, self.descrpt)
        elif fitting_type == 'wfc':
            self.fitting = WFCFitting(fitting_param, self.descrpt)
        elif fitting_type == 'dipole':
            if descrpt_type == 'se_a':
                self.fitting = DipoleFittingSeA(fitting_param, self.descrpt)
            else:
                raise RuntimeError(
                    'fitting dipole only supports descrptors: se_a')
        elif fitting_type == 'polar':
            if descrpt_type == 'loc_frame':
                self.fitting = PolarFittingLocFrame(fitting_param,
                                                    self.descrpt)
            elif descrpt_type == 'se_a':
                self.fitting = PolarFittingSeA(fitting_param, self.descrpt)
            else:
                raise RuntimeError(
                    'fitting polar only supports descrptors: loc_frame and se_a'
                )
        elif fitting_type == 'global_polar':
            if descrpt_type == 'se_a':
                self.fitting = GlobalPolarFittingSeA(fitting_param,
                                                     self.descrpt)
            else:
                # Only se_a is accepted by the branch above, so the message
                # must not advertise loc_frame.
                raise RuntimeError(
                    'fitting global_polar only supports descrptors: se_a'
                )
        else:
            raise RuntimeError('unknow fitting type ' + fitting_type)

        # init model
        # infer model type by fitting_type
        if fitting_type == Model.model_type:
            self.model = Model(model_param, self.descrpt, self.fitting)
        elif fitting_type == 'wfc':
            self.model = WFCModel(model_param, self.descrpt, self.fitting)
        elif fitting_type == 'dipole':
            self.model = DipoleModel(model_param, self.descrpt, self.fitting)
        elif fitting_type == 'polar':
            self.model = PolarModel(model_param, self.descrpt, self.fitting)
        elif fitting_type == 'global_polar':
            self.model = GlobalPolarModel(model_param, self.descrpt,
                                          self.fitting)
        else:
            raise RuntimeError('get unknown fitting type when building model')

        # learning rate; the type defaults to 'exp' when it cannot be read.
        lr_param = j_must_have(jdata, 'learning_rate')
        try:
            lr_type = lr_param['type']
        except (KeyError, TypeError):
            lr_type = 'exp'
        if lr_type == 'exp':
            self.lr = LearningRateExp(lr_param)
        else:
            raise RuntimeError('unknown learning_rate type ' + lr_type)

        # loss
        # infer loss type by fitting_type
        # A missing or unusable 'loss' section falls back to no parameters
        # and the 'std' loss type.
        try:
            loss_param = jdata['loss']
            loss_type = loss_param.get('type', 'std')
        except (KeyError, TypeError, AttributeError):
            loss_param = None
            loss_type = 'std'

        if fitting_type == 'ener':
            if loss_type == 'std':
                self.loss = EnerStdLoss(
                    loss_param, starter_learning_rate=self.lr.start_lr())
            elif loss_type == 'ener_dipole':
                self.loss = EnerDipoleLoss(
                    loss_param, starter_learning_rate=self.lr.start_lr())
            else:
                raise RuntimeError('unknow loss type')
        elif fitting_type == 'wfc':
            self.loss = TensorLoss(loss_param,
                                   model=self.model,
                                   tensor_name='wfc',
                                   tensor_size=self.model.get_out_size(),
                                   label_name='wfc')
        elif fitting_type == 'dipole':
            self.loss = TensorLoss(loss_param,
                                   model=self.model,
                                   tensor_name='dipole',
                                   tensor_size=3,
                                   label_name='dipole')
        elif fitting_type == 'polar':
            self.loss = TensorLoss(loss_param,
                                   model=self.model,
                                   tensor_name='polar',
                                   tensor_size=9,
                                   label_name='polarizability')
        elif fitting_type == 'global_polar':
            self.loss = TensorLoss(loss_param,
                                   model=self.model,
                                   tensor_name='global_polar',
                                   tensor_size=9,
                                   atomic=False,
                                   label_name='polarizability')
        else:
            raise RuntimeError(
                'get unknown fitting type when building loss function')

        # training
        training_param = j_must_have(jdata, 'training')

        tr_args = (ClassArg()
                   .add('numb_test', int, default=1)
                   .add('disp_file', str, default='lcurve.out')
                   .add('disp_freq', int, default=100)
                   .add('save_freq', int, default=1000)
                   .add('save_ckpt', str, default='model.ckpt')
                   .add('display_in_training', bool, default=True)
                   .add('timing_in_training', bool, default=True)
                   .add('profiling', bool, default=False)
                   .add('profiling_file', str, default='timeline.json')
                   .add('sys_probs', list)
                   .add('auto_prob_style', str, default="prob_sys_size"))
        tr_data = tr_args.parse(training_param)
        self.numb_test = tr_data['numb_test']
        self.disp_file = tr_data['disp_file']
        self.disp_freq = tr_data['disp_freq']
        self.save_freq = tr_data['save_freq']
        self.save_ckpt = tr_data['save_ckpt']
        self.display_in_training = tr_data['display_in_training']
        self.timing_in_training = tr_data['timing_in_training']
        self.profiling = tr_data['profiling']
        self.profiling_file = tr_data['profiling_file']
        self.sys_probs = tr_data['sys_probs']
        self.auto_prob_style = tr_data['auto_prob_style']
        self.useBN = False
        # Only the 'ener' fitting is asked for its number of frame parameters.
        if fitting_type == 'ener' and self.fitting.get_numb_fparam() > 0:
            self.numb_fparam = self.fitting.get_numb_fparam()
        else:
            self.numb_fparam = 0
    def test_model(self):
        """Compare the se_a energy model with ``aparam`` input against stored references.

        The model (``DescrptSeA`` + ``EnerFitting`` + ``Model``) is assembled
        from ``water_se_a_aparam.json``; ``aparam`` is loaded from
        ``system/set.000/aparam.npy`` and passed to the model through the
        input dictionary.  The first test frame is evaluated in a TensorFlow
        session and every component of energy, force and virial must match
        the hard-coded reference values to 10 decimal places.
        """
        jfile = 'water_se_a_aparam.json'
        with open(jfile) as fp:
            jdata = json.load(fp)
        # NOTE(review): the result is not used (DataSystem receives
        # run_opt=None); the call is kept only in case construction has
        # side effects -- confirm and drop.
        RunOptions(None)
        systems = j_must_have(jdata, 'systems')
        set_pfx = j_must_have(jdata, 'set_prefix')
        # These keys are still looked up (presumably j_must_have fails when
        # one is absent), but their values are not needed: the test pins the
        # batch and test sizes to 1 below.
        for required_key in ('batch_size', 'numb_test', 'stop_batch'):
            j_must_have(jdata, required_key)
        batch_size = 1
        test_size = 1
        rcut = j_must_have(jdata['model']['descriptor'], 'rcut')

        data = DataSystem(systems,
                          set_pfx,
                          batch_size,
                          test_size,
                          rcut,
                          run_opt=None)

        test_data = data.get_test()
        # manually set aparam
        test_data['aparam'] = np.load('system/set.000/aparam.npy')
        numb_test = 1

        descrpt = DescrptSeA(jdata['model']['descriptor'])
        fitting = EnerFitting(jdata['model']['fitting_net'], descrpt)
        model = Model(jdata['model'], descrpt, fitting)

        input_data = {
            'coord': [test_data['coord']],
            'box': [test_data['box']],
            'type': [test_data['type']],
            'natoms_vec': [test_data['natoms_vec']],
            'default_mesh': [test_data['default_mesh']],
            'aparam': [test_data['aparam']],
        }
        model._compute_dstats(input_data)
        model.bias_atom_e = data.compute_energy_shift()

        # NOTE(review): the placeholders are created in the original order,
        # including ones the fetched tensors may not depend on, so the graph
        # is built exactly as it was when the references were produced.
        t_prop_c = tf.placeholder(tf.float32, [5], name='t_prop_c')
        t_energy = tf.placeholder(global_ener_float_precision, [None],
                                  name='t_energy')
        t_force = tf.placeholder(global_tf_float_precision, [None],
                                 name='t_force')
        t_virial = tf.placeholder(global_tf_float_precision, [None],
                                  name='t_virial')
        t_atom_ener = tf.placeholder(global_tf_float_precision, [None],
                                     name='t_atom_ener')
        t_coord = tf.placeholder(global_tf_float_precision, [None],
                                 name='i_coord')
        t_type = tf.placeholder(tf.int32, [None], name='i_type')
        t_natoms = tf.placeholder(tf.int32, [model.ntypes + 2],
                                  name='i_natoms')
        t_box = tf.placeholder(global_tf_float_precision, [None, 9],
                               name='i_box')
        t_mesh = tf.placeholder(tf.int32, [None], name='i_mesh')
        t_aparam = tf.placeholder(global_tf_float_precision, [None],
                                  name='i_aparam')
        is_training = tf.placeholder(tf.bool)
        input_dict = {'aparam': t_aparam}

        model_pred = model.build(t_coord,
                                 t_type,
                                 t_natoms,
                                 t_box,
                                 t_mesh,
                                 input_dict,
                                 suffix="se_a_aparam",
                                 reuse=False)
        energy = model_pred['energy']
        force = model_pred['force']
        virial = model_pred['virial']
        # Not compared below; the lookup is kept so a missing 'atom_ener'
        # output still fails the test.
        atom_ener = model_pred['atom_ener']

        feed_dict_test = {
            t_prop_c: test_data['prop_c'],
            t_energy: test_data['energy'][:numb_test],
            t_force: np.reshape(test_data['force'][:numb_test, :], [-1]),
            t_virial: np.reshape(test_data['virial'][:numb_test, :], [-1]),
            t_atom_ener: np.reshape(test_data['atom_ener'][:numb_test, :],
                                    [-1]),
            t_coord: np.reshape(test_data['coord'][:numb_test, :], [-1]),
            t_box: test_data['box'][:numb_test, :],
            t_type: np.reshape(test_data['type'][:numb_test, :], [-1]),
            t_natoms: test_data['natoms_vec'],
            t_mesh: test_data['default_mesh'],
            t_aparam: np.reshape(test_data['aparam'][:numb_test, :], [-1]),
            is_training: False
        }

        # The context manager closes the session even if evaluation fails.
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            e, f, v = sess.run([energy, force, virial],
                               feed_dict=feed_dict_test)

        e = e.reshape([-1])
        f = f.reshape([-1])
        v = v.reshape([-1])
        refe = [61.35473702079649]
        reff = [
            7.789591210641927388e-02, 9.411176646369459609e-02,
            3.785806413688173194e-03, 1.430830954178063386e-01,
            1.146964190520970150e-01, -1.320340288927138173e-02,
            -7.308720494747594776e-02, 6.508269338140809657e-02,
            5.398739145542804643e-04, 5.863268336973800898e-02,
            -1.603409523950408699e-01, -5.083084610994957619e-03,
            -2.551569799443983988e-01, 3.087934885732580501e-02,
            1.508590526622844222e-02, 4.863249399791078065e-02,
            -1.444292753594846324e-01, -1.125098094204559241e-03
        ]
        refv = [
            -6.069498397488943819e-01, 1.101778888191114192e-01,
            1.981907430646132409e-02, 1.101778888191114608e-01,
            -3.315612988100872793e-01, -5.999739184898976799e-03,
            1.981907430646132756e-02, -5.999739184898974197e-03,
            -1.198656608172396325e-03
        ]
        refe = np.reshape(refe, [-1])
        reff = np.reshape(reff, [-1])
        refv = np.reshape(refv, [-1])

        # Element-wise comparison: energy first, then force, then virial.
        places = 10
        for got, ref in ((e, refe), (f, reff), (v, refv)):
            for ii in range(got.size):
                self.assertAlmostEqual(got[ii], ref[ii], places=places)
    def test_model(self):
        """Compare the se_a energy model with ``fparam`` input against stored references.

        The model (``DescrptSeA`` + ``EnerFitting`` + ``Model``) is assembled
        from ``water_se_a_fparam.json``; ``fparam`` from the test data is
        passed to the model through the input dictionary.  The first test
        frame is evaluated in a TensorFlow session and every component of
        energy, force and virial must match the hard-coded reference values
        to 10 decimal places.
        """
        jfile = 'water_se_a_fparam.json'
        with open(jfile) as fp:
            jdata = json.load(fp)
        # NOTE(review): the result is not used (DataSystem receives
        # run_opt=None); the call is kept only in case construction has
        # side effects -- confirm and drop.
        RunOptions(None)
        systems = j_must_have(jdata, 'systems')
        set_pfx = j_must_have(jdata, 'set_prefix')
        # These keys are still looked up (presumably j_must_have fails when
        # one is absent), but their values are not needed: the test pins the
        # batch and test sizes to 1 below.
        for required_key in ('batch_size', 'numb_test', 'stop_batch'):
            j_must_have(jdata, required_key)
        batch_size = 1
        test_size = 1
        rcut = j_must_have(jdata['model']['descriptor'], 'rcut')

        data = DataSystem(systems,
                          set_pfx,
                          batch_size,
                          test_size,
                          rcut,
                          run_opt=None)

        test_data = data.get_test()
        numb_test = 1

        descrpt = DescrptSeA(jdata['model']['descriptor'])
        fitting = EnerFitting(jdata['model']['fitting_net'], descrpt)
        model = Model(jdata['model'], descrpt, fitting)

        model._compute_dstats([test_data['coord']], [test_data['box']],
                              [test_data['type']], [test_data['natoms_vec']],
                              [test_data['default_mesh']])
        model.bias_atom_e = data.compute_energy_shift()

        # NOTE(review): the placeholders are created in the original order,
        # including ones the fetched tensors may not depend on, so the graph
        # is built exactly as it was when the references were produced.
        t_prop_c = tf.placeholder(tf.float32, [5], name='t_prop_c')
        t_energy = tf.placeholder(global_ener_float_precision, [None],
                                  name='t_energy')
        t_force = tf.placeholder(global_tf_float_precision, [None],
                                 name='t_force')
        t_virial = tf.placeholder(global_tf_float_precision, [None],
                                  name='t_virial')
        t_atom_ener = tf.placeholder(global_tf_float_precision, [None],
                                     name='t_atom_ener')
        t_coord = tf.placeholder(global_tf_float_precision, [None],
                                 name='i_coord')
        t_type = tf.placeholder(tf.int32, [None], name='i_type')
        t_natoms = tf.placeholder(tf.int32, [model.ntypes + 2],
                                  name='i_natoms')
        t_box = tf.placeholder(global_tf_float_precision, [None, 9],
                               name='i_box')
        t_mesh = tf.placeholder(tf.int32, [None], name='i_mesh')
        t_fparam = tf.placeholder(global_tf_float_precision, [None],
                                  name='i_fparam')
        is_training = tf.placeholder(tf.bool)
        input_dict = {'fparam': t_fparam}

        model_pred = model.build(t_coord,
                                 t_type,
                                 t_natoms,
                                 t_box,
                                 t_mesh,
                                 input_dict,
                                 suffix="se_a_fparam",
                                 reuse=False)
        energy = model_pred['energy']
        force = model_pred['force']
        virial = model_pred['virial']
        # Not compared below; the lookup is kept so a missing 'atom_ener'
        # output still fails the test.
        atom_ener = model_pred['atom_ener']

        feed_dict_test = {
            t_prop_c: test_data['prop_c'],
            t_energy: test_data['energy'][:numb_test],
            t_force: np.reshape(test_data['force'][:numb_test, :], [-1]),
            t_virial: np.reshape(test_data['virial'][:numb_test, :], [-1]),
            t_atom_ener: np.reshape(test_data['atom_ener'][:numb_test, :],
                                    [-1]),
            t_coord: np.reshape(test_data['coord'][:numb_test, :], [-1]),
            t_box: test_data['box'][:numb_test, :],
            t_type: np.reshape(test_data['type'][:numb_test, :], [-1]),
            t_natoms: test_data['natoms_vec'],
            t_mesh: test_data['default_mesh'],
            t_fparam: np.reshape(test_data['fparam'][:numb_test, :], [-1]),
            is_training: False
        }

        # The context manager closes the session even if evaluation fails.
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            e, f, v = sess.run([energy, force, virial],
                               feed_dict=feed_dict_test)

        e = e.reshape([-1])
        f = f.reshape([-1])
        v = v.reshape([-1])
        refe = [6.135136929183754972e+01]
        reff = [
            7.761477777656561328e-02, 9.383013575207051205e-02,
            3.776776376267230399e-03, 1.428268971463224069e-01,
            1.143858253900619654e-01, -1.318441687719179231e-02,
            -7.271897092708884403e-02, 6.494907553857684479e-02,
            5.355599592111062821e-04, 5.840910251709752199e-02,
            -1.599042555763417750e-01, -5.067165555590445389e-03,
            -2.546246315216804113e-01, 3.073296814647456451e-02,
            1.505994759166155023e-02, 4.849282500878367153e-02,
            -1.439937492508420736e-01, -1.120701494357654411e-03
        ]
        refv = [
            -6.054303146013112480e-01, 1.097859194719944115e-01,
            1.977605183964963390e-02, 1.097859194719943976e-01,
            -3.306167096812382966e-01, -5.978855662865613894e-03,
            1.977605183964964083e-02, -5.978855662865616497e-03,
            -1.196331922996723236e-03
        ]
        refe = np.reshape(refe, [-1])
        reff = np.reshape(reff, [-1])
        refv = np.reshape(refv, [-1])

        # Element-wise comparison: energy first, then force, then virial.
        places = 10
        for got, ref in ((e, refe), (f, reff), (v, refv)):
            for ii in range(got.size):
                self.assertAlmostEqual(got[ii], ref[ii], places=places)