Example #1
0
    def _init_param(self, jdata):
        """Configure the trainer from the parsed input dictionary.

        Reads the ``model`` (``descriptor`` and ``fitting_net``),
        ``learning_rate``, optional ``loss`` and ``training`` sections of
        ``jdata`` and stores the resulting objects and options on ``self``
        (``descrpt``, ``fitting``, ``model``, ``lr``, ``loss`` and the
        training settings).

        Parameters
        ----------
        jdata : dict
            Parsed input script.  ``model``, ``learning_rate`` and
            ``training`` are fetched with ``j_must_have``; ``loss`` is
            optional.

        Raises
        ------
        RuntimeError
            If the descriptor, fitting, learning-rate or loss type is not
            one of the values handled below, or if the fitting type does
            not accept the selected descriptor.
        """
        # model config
        model_param = j_must_have(jdata, 'model')
        descrpt_param = j_must_have(model_param, 'descriptor')
        fitting_param = j_must_have(model_param, 'fitting_net')

        # descriptor
        descrpt_type = j_must_have(descrpt_param, 'type')
        if descrpt_type == 'loc_frame':
            self.descrpt = DescrptLocFrame(descrpt_param)
        elif descrpt_type == 'se_a':
            self.descrpt = DescrptSeA(descrpt_param)
        elif descrpt_type == 'se_r':
            self.descrpt = DescrptSeR(descrpt_param)
        elif descrpt_type == 'se_ar':
            self.descrpt = DescrptSeAR(descrpt_param)
        else:
            raise RuntimeError('unknow model type ' + descrpt_type)

        # fitting net; the type defaults to 'ener' when it is not given.
        # Only lookup failures are handled so that unrelated errors (and
        # KeyboardInterrupt/SystemExit) are no longer swallowed.
        try:
            fitting_type = fitting_param['type']
        except (KeyError, TypeError):
            fitting_type = 'ener'
        if fitting_type == 'ener':
            self.fitting = EnerFitting(fitting_param, self.descrpt)
        elif fitting_type == 'wfc':
            self.fitting = WFCFitting(fitting_param, self.descrpt)
        elif fitting_type == 'dipole':
            if descrpt_type == 'se_a':
                self.fitting = DipoleFittingSeA(fitting_param, self.descrpt)
            else:
                raise RuntimeError(
                    'fitting dipole only supports descrptors: se_a')
        elif fitting_type == 'polar':
            if descrpt_type == 'loc_frame':
                self.fitting = PolarFittingLocFrame(fitting_param,
                                                    self.descrpt)
            elif descrpt_type == 'se_a':
                self.fitting = PolarFittingSeA(fitting_param, self.descrpt)
            else:
                raise RuntimeError(
                    'fitting polar only supports descrptors: loc_frame and se_a'
                )
        elif fitting_type == 'global_polar':
            if descrpt_type == 'se_a':
                self.fitting = GlobalPolarFittingSeA(fitting_param,
                                                     self.descrpt)
            else:
                # Only se_a is accepted above, so the message names se_a only.
                raise RuntimeError(
                    'fitting global_polar only supports descrptors: se_a')
        else:
            raise RuntimeError('unknow fitting type ' + fitting_type)

        # init model
        # infer model type by fitting_type
        if fitting_type == Model.model_type:
            self.model = Model(model_param, self.descrpt, self.fitting)
        elif fitting_type == 'wfc':
            self.model = WFCModel(model_param, self.descrpt, self.fitting)
        elif fitting_type == 'dipole':
            self.model = DipoleModel(model_param, self.descrpt, self.fitting)
        elif fitting_type == 'polar':
            self.model = PolarModel(model_param, self.descrpt, self.fitting)
        elif fitting_type == 'global_polar':
            self.model = GlobalPolarModel(model_param, self.descrpt,
                                          self.fitting)
        else:
            raise RuntimeError('get unknown fitting type when building model')

        # learning rate; the type defaults to 'exp' when it is not given
        lr_param = j_must_have(jdata, 'learning_rate')
        try:
            lr_type = lr_param['type']
        except (KeyError, TypeError):
            lr_type = 'exp'
        if lr_type == 'exp':
            self.lr = LearningRateExp(lr_param)
        else:
            raise RuntimeError('unknown learning_rate type ' + lr_type)

        # loss
        # infer loss type by fitting_type; a missing or unusable 'loss'
        # section falls back to loss_param=None and loss type 'std'
        try:
            loss_param = jdata['loss']
            loss_type = loss_param.get('type', 'std')
        except (KeyError, TypeError, AttributeError):
            loss_param = None
            loss_type = 'std'

        if fitting_type == 'ener':
            if loss_type == 'std':
                self.loss = EnerStdLoss(
                    loss_param, starter_learning_rate=self.lr.start_lr())
            elif loss_type == 'ener_dipole':
                self.loss = EnerDipoleLoss(
                    loss_param, starter_learning_rate=self.lr.start_lr())
            else:
                raise RuntimeError('unknow loss type')
        elif fitting_type == 'wfc':
            self.loss = TensorLoss(loss_param,
                                   model=self.model,
                                   tensor_name='wfc',
                                   tensor_size=self.model.get_out_size(),
                                   label_name='wfc')
        elif fitting_type == 'dipole':
            self.loss = TensorLoss(loss_param,
                                   model=self.model,
                                   tensor_name='dipole',
                                   tensor_size=3,
                                   label_name='dipole')
        elif fitting_type == 'polar':
            self.loss = TensorLoss(loss_param,
                                   model=self.model,
                                   tensor_name='polar',
                                   tensor_size=9,
                                   label_name='polarizability')
        elif fitting_type == 'global_polar':
            self.loss = TensorLoss(loss_param,
                                   model=self.model,
                                   tensor_name='global_polar',
                                   tensor_size=9,
                                   atomic=False,
                                   label_name='polarizability')
        else:
            raise RuntimeError(
                'get unknown fitting type when building loss function')

        # training
        training_param = j_must_have(jdata, 'training')

        tr_args = (
            ClassArg()
            .add('numb_test', int, default=1)
            .add('disp_file', str, default='lcurve.out')
            .add('disp_freq', int, default=100)
            .add('save_freq', int, default=1000)
            .add('save_ckpt', str, default='model.ckpt')
            .add('display_in_training', bool, default=True)
            .add('timing_in_training', bool, default=True)
            .add('profiling', bool, default=False)
            .add('profiling_file', str, default='timeline.json')
            .add('sys_probs', list)
            .add('auto_prob_style', str, default="prob_sys_size")
        )
        tr_data = tr_args.parse(training_param)
        self.numb_test = tr_data['numb_test']
        self.disp_file = tr_data['disp_file']
        self.disp_freq = tr_data['disp_freq']
        self.save_freq = tr_data['save_freq']
        self.save_ckpt = tr_data['save_ckpt']
        self.display_in_training = tr_data['display_in_training']
        self.timing_in_training = tr_data['timing_in_training']
        self.profiling = tr_data['profiling']
        self.profiling_file = tr_data['profiling_file']
        self.sys_probs = tr_data['sys_probs']
        self.auto_prob_style = tr_data['auto_prob_style']
        self.useBN = False
        # frame parameters are only queried for the energy fitting
        if fitting_type == 'ener' and self.fitting.get_numb_fparam() > 0:
            self.numb_fparam = self.fitting.get_numb_fparam()
        else:
            self.numb_fparam = 0
Example #2
0
    def test_model(self):
        """Check the ``polar`` output of the ``se_a`` polar model.

        Loads ``polar_se_a.json``, builds ``DescrptSeA`` /
        ``PolarFittingSeA`` / ``PolarModel`` from it, evaluates the model on
        the first test frame and compares the flattened result with the
        stored reference values to 6 decimal places.
        """
        jfile = 'polar_se_a.json'
        with open(jfile) as fp:
            jdata = json.load(fp)
        run_opt = RunOptions(None)
        systems = j_must_have(jdata, 'systems')
        set_pfx = j_must_have(jdata, 'set_prefix')
        # The two values read here are overridden right below; the lookups
        # are kept as they were (presumably j_must_have fails on a missing
        # key -- confirm against its definition).
        batch_size = j_must_have(jdata, 'batch_size')
        test_size = j_must_have(jdata, 'numb_test')
        batch_size = 1
        test_size = 1
        stop_batch = j_must_have(jdata, 'stop_batch')
        rcut = j_must_have(jdata['model']['descriptor'], 'rcut')

        data = DataSystem(systems,
                          set_pfx,
                          batch_size,
                          test_size,
                          rcut,
                          run_opt=None)

        test_data = data.get_test()
        numb_test = 1

        descrpt = DescrptSeA(jdata['model']['descriptor'])
        fitting = PolarFittingSeA(jdata['model']['fitting_net'], descrpt)
        model = PolarModel(jdata['model'], descrpt, fitting)

        model._compute_dstats([test_data['coord']], [test_data['box']],
                              [test_data['type']], [test_data['natoms_vec']],
                              [test_data['default_mesh']])

        # Several placeholders below are not passed to model.build; they are
        # kept so that the graph is created exactly as before.
        t_prop_c = tf.placeholder(tf.float32, [5], name='t_prop_c')
        t_energy = tf.placeholder(global_ener_float_precision, [None],
                                  name='t_energy')
        t_force = tf.placeholder(global_tf_float_precision, [None],
                                 name='t_force')
        t_virial = tf.placeholder(global_tf_float_precision, [None],
                                  name='t_virial')
        t_atom_ener = tf.placeholder(global_tf_float_precision, [None],
                                     name='t_atom_ener')
        t_coord = tf.placeholder(global_tf_float_precision, [None],
                                 name='i_coord')
        t_type = tf.placeholder(tf.int32, [None], name='i_type')
        t_natoms = tf.placeholder(tf.int32, [model.ntypes + 2],
                                  name='i_natoms')
        t_box = tf.placeholder(global_tf_float_precision, [None, 9],
                               name='i_box')
        t_mesh = tf.placeholder(tf.int32, [None], name='i_mesh')
        is_training = tf.placeholder(tf.bool)
        t_fparam = None

        model_pred = model.build(t_coord,
                                 t_type,
                                 t_natoms,
                                 t_box,
                                 t_mesh,
                                 t_fparam,
                                 suffix="polar_se_a",
                                 reuse=False)
        polar = model_pred['polar']

        feed_dict_test = {
            t_prop_c: test_data['prop_c'],
            t_coord: np.reshape(test_data['coord'][:numb_test, :], [-1]),
            t_box: test_data['box'][:numb_test, :],
            t_type: np.reshape(test_data['type'][:numb_test, :], [-1]),
            t_natoms: test_data['natoms_vec'],
            t_mesh: test_data['default_mesh'],
            is_training: False
        }

        # Run inside a context manager so the session is closed even when
        # initialisation or evaluation raises.
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            [p] = sess.run([polar], feed_dict=feed_dict_test)

        p = p.reshape([-1])
        refp = [
            3.39695248e+01, 2.16564043e+01, 8.18501479e-01, 2.16564043e+01,
            1.38211789e+01, 5.22775159e-01, 8.18501479e-01, 5.22775159e-01,
            1.97847218e-02, 8.08467431e-01, 3.42081126e+00, -2.01072261e-01,
            3.42081126e+00, 1.54924596e+01, -9.06153697e-01, -2.01072261e-01,
            -9.06153697e-01, 5.30193262e-02
        ]

        places = 6
        for ii in range(p.size):
            self.assertAlmostEqual(p[ii], refp[ii], places=places)
Example #3
0
class NNPTrainer (object):
    def __init__(self, 
                 jdata, 
                 run_opt):
        self.run_opt = run_opt
        self._init_param(jdata)

    def _init_param(self, jdata):
        """Configure the trainer from the parsed input dictionary.

        Reads the ``model`` (``descriptor`` and ``fitting_net``),
        ``learning_rate``, optional ``loss`` and ``training`` sections of
        ``jdata`` and stores the resulting objects and options on ``self``
        (``descrpt``, ``fitting``, ``model``, ``lr``, ``loss`` and the
        training settings).

        Parameters
        ----------
        jdata : dict
            Parsed input script.  ``model``, ``learning_rate`` and
            ``training`` are fetched with ``j_must_have``; ``loss`` is
            optional.

        Raises
        ------
        RuntimeError
            If the descriptor, fitting or learning-rate type is not one of
            the values handled below, or if the ``polar`` fitting is
            combined with an unsupported descriptor.
        """
        # model config
        model_param = j_must_have(jdata, 'model')
        descrpt_param = j_must_have(model_param, 'descriptor')
        fitting_param = j_must_have(model_param, 'fitting_net')

        # descriptor
        descrpt_type = j_must_have(descrpt_param, 'type')
        if descrpt_type == 'loc_frame':
            self.descrpt = DescrptLocFrame(descrpt_param)
        elif descrpt_type == 'se_a':
            self.descrpt = DescrptSeA(descrpt_param)
        elif descrpt_type == 'se_r':
            self.descrpt = DescrptSeR(descrpt_param)
        elif descrpt_type == 'se_ar':
            self.descrpt = DescrptSeAR(descrpt_param)
        else:
            raise RuntimeError('unknow model type ' + descrpt_type)

        # fitting net; the type defaults to 'ener' when it is not given.
        # Only lookup failures are handled so that unrelated errors (and
        # KeyboardInterrupt/SystemExit) are no longer swallowed.
        try:
            fitting_type = fitting_param['type']
        except (KeyError, TypeError):
            fitting_type = 'ener'
        if fitting_type == 'ener':
            self.fitting = EnerFitting(fitting_param, self.descrpt)
        elif fitting_type == 'wfc':
            self.fitting = WFCFitting(fitting_param, self.descrpt)
        elif fitting_type == 'polar':
            if descrpt_type == 'loc_frame':
                self.fitting = PolarFittingLocFrame(fitting_param, self.descrpt)
            elif descrpt_type == 'se_a':
                self.fitting = PolarFittingSeA(fitting_param, self.descrpt)
            else:
                raise RuntimeError('fitting polar only supports descrptors: loc_frame and se_a')
        else:
            raise RuntimeError('unknow fitting type ' + fitting_type)

        # init model
        # infer model type by fitting_type
        if fitting_type == Model.model_type:
            self.model = Model(model_param, self.descrpt, self.fitting)
        elif fitting_type == WFCModel.model_type:
            self.model = WFCModel(model_param, self.descrpt, self.fitting)
        elif fitting_type == PolarModel.model_type:
            self.model = PolarModel(model_param, self.descrpt, self.fitting)
        else:
            raise RuntimeError('get unknown fitting type when building model')

        # learning rate; the type defaults to 'exp' when it is not given
        lr_param = j_must_have(jdata, 'learning_rate')
        try:
            lr_type = lr_param['type']
        except (KeyError, TypeError):
            lr_type = 'exp'
        if lr_type == 'exp':
            self.lr = LearningRateExp(lr_param)
        else:
            raise RuntimeError('unknown learning_rate type ' + lr_type)

        # loss
        # infer loss type by fitting_type; without a 'loss' section the
        # loss classes receive None
        try:
            loss_param = jdata['loss']
        except (KeyError, TypeError):
            loss_param = None
        if fitting_type == 'ener':
            self.loss = EnerStdLoss(loss_param, starter_learning_rate=self.lr.start_lr())
        elif fitting_type == 'wfc':
            self.loss = WFCLoss(loss_param, model=self.model)
        elif fitting_type == 'polar':
            self.loss = TensorLoss(loss_param,
                                   model=self.model,
                                   tensor_name='polar',
                                   tensor_size=9,
                                   label_name='polarizability')
        else:
            raise RuntimeError('get unknown fitting type when building loss function')

        # training
        training_param = j_must_have(jdata, 'training')

        tr_args = (
            ClassArg()
            .add('numb_test', int, default=1)
            .add('disp_file', str, default='lcurve.out')
            .add('disp_freq', int, default=100)
            .add('save_freq', int, default=1000)
            .add('save_ckpt', str, default='model.ckpt')
            .add('display_in_training', bool, default=True)
            .add('timing_in_training', bool, default=True)
            .add('profiling', bool, default=False)
            .add('profiling_file', str, default='timeline.json')
            .add('sys_weights', list)
        )
        tr_data = tr_args.parse(training_param)
        self.numb_test = tr_data['numb_test']
        self.disp_file = tr_data['disp_file']
        self.disp_freq = tr_data['disp_freq']
        self.save_freq = tr_data['save_freq']
        self.save_ckpt = tr_data['save_ckpt']
        self.display_in_training = tr_data['display_in_training']
        self.timing_in_training = tr_data['timing_in_training']
        self.profiling = tr_data['profiling']
        self.profiling_file = tr_data['profiling_file']
        self.sys_weights = tr_data['sys_weights']
        self.useBN = False
        # frame parameters are only queried for the energy fitting
        if fitting_type == 'ener' and self.fitting.get_numb_fparam() > 0:
            self.numb_fparam = self.fitting.get_numb_fparam()
        else:
            self.numb_fparam = 0


    def _message (self, msg) :
        self.run_opt.message(msg)

    def build(self, data):
        """Build the learning rate, network and training op for ``data``.

        Records ``ntypes``, ``batch_size`` and ``type_map`` on ``self``,
        lets the model collect its statistics through
        ``self.model.data_stat(data)`` and then builds the graph pieces
        under the device chosen by ``tf.train.replica_device_setter``.

        Parameters
        ----------
        data
            Data object providing ``get_ntypes``, ``get_batch_size`` and
            ``get_type_map``; it is also handed to ``model.data_stat`` and
            ``_build_network``.
        """
        self.ntypes = self.model.get_ntypes()
        assert self.ntypes == data.get_ntypes(), "ntypes should match that found in data"

        self.batch_size = data.get_batch_size()

        if self.numb_fparam > 0:
            fparam_note = "training with %d frame parameter(s)" % self.numb_fparam
        else:
            fparam_note = "training without frame parameter"
        self._message(fparam_note)

        self.type_map = data.get_type_map()

        self.model.data_stat(data)

        opts = self.run_opt
        worker_device = "/job:%s/task:%d/%s" % (
            opts.my_job_name, opts.my_task_index, opts.my_device)
        device_setter = tf.train.replica_device_setter(
            worker_device=worker_device,
            cluster=opts.cluster_spec)

        # Order matters: the network needs the learning rate and the
        # training op needs the loss built by the network step.
        with tf.device(device_setter):
            self._build_lr()
            self._build_network(data)
            self._build_training()


    def _build_lr(self):
        """Create the global step and the learning-rate tensor.

        Also resets ``self._extra_train_ops`` to an empty list.
        """
        self._extra_train_ops = []
        step = tf.train.get_or_create_global_step()
        self.global_step = step
        self.learning_rate = self.lr.build(step)
        self._message("built lr")

    def _build_network(self, data):
        """Create the input placeholders, then build the model and the loss.

        One placeholder plus a ``find_<key>`` placeholder is created for
        every entry of ``data.get_data_dict()`` except ``'type'``; entries
        flagged ``high_prec`` use ``global_ener_float_precision``, all
        others ``global_tf_float_precision``.  The results are stored in
        ``self.place_holders``, ``self.model_pred``, ``self.l2_l`` and
        ``self.l2_more``.
        """
        holders = {}
        self.place_holders = holders
        data_dict = data.get_data_dict()
        for key in data_dict.keys():
            if key == 'type':
                continue
            if data_dict[key]['high_prec']:
                dtype = global_ener_float_precision
            else:
                dtype = global_tf_float_precision
            holders[key] = tf.placeholder(dtype, [None], name='t_' + key)
            holders['find_' + key] = tf.placeholder(tf.float32, name='t_find_' + key)

        # Inputs that do not come from the data dictionary.
        holders['type'] = tf.placeholder(tf.int32, [None], name='t_type')
        holders['natoms_vec'] = tf.placeholder(tf.int32, [self.ntypes + 2], name='t_natoms')
        holders['default_mesh'] = tf.placeholder(tf.int32, [None], name='t_mesh')
        holders['is_training'] = tf.placeholder(tf.bool)

        self.model_pred = self.model.build(
            holders['coord'],
            holders['type'],
            holders['natoms_vec'],
            holders['box'],
            holders['default_mesh'],
            holders,
            suffix="",
            reuse=False)

        self.l2_l, self.l2_more = self.loss.build(
            self.learning_rate,
            holders['natoms_vec'],
            self.model_pred,
            holders,
            suffix="test")

        self._message("built network")

    def _build_training(self):
        """Create the Adam training op for ``self.l2_l``.

        When ``run_opt.is_distrib`` is set the optimizer is wrapped in a
        ``SyncReplicasOptimizer`` sized by the number of ``"worker"`` tasks
        and its session hook is stored as ``self.sync_replicas_hook``.
        The final op, grouped with ``self._extra_train_ops``, is stored as
        ``self.train_op``.
        """
        variables = tf.trainable_variables()
        optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
        if self.run_opt.is_distrib:
            n_workers = self.run_opt.cluster_spec.num_tasks("worker")
            optimizer = tf.train.SyncReplicasOptimizer(
                optimizer,
                replicas_to_aggregate=n_workers,
                total_num_replicas=n_workers,
                name="sync_replicas")
            self.sync_replicas_hook = optimizer.make_session_run_hook(self.run_opt.is_chief)
        gradients = tf.gradients(self.l2_l, variables)
        step_op = optimizer.apply_gradients(zip(gradients, variables),
                                            global_step=self.global_step,
                                            name='train_step')
        self.train_op = tf.group(step_op, *self._extra_train_ops)
        self._message("built training")

    def _init_sess_serial(self):
        """Create the session and saver for non-distributed training.

        Depending on ``run_opt.init_mode`` the variables are initialised
        from scratch, restored from ``run_opt.init_model`` (with the global
        step reset to 0) or restored from ``run_opt.restart``.  Except for
        a restart, ``self.disp_file`` is truncated.

        Raises
        ------
        RuntimeError
            If ``run_opt.init_mode`` is none of the three known modes.
        """
        session_config = tf.ConfigProto(
            intra_op_parallelism_threads=self.run_opt.num_intra_threads,
            inter_op_parallelism_threads=self.run_opt.num_inter_threads)
        self.sess = tf.Session(config=session_config)
        self.saver = tf.train.Saver()

        mode = self.run_opt.init_mode
        checkpoint = None
        if mode == 'init_from_scratch':
            self._message("initialize model from scratch")
        elif mode == 'init_from_model':
            checkpoint = self.run_opt.init_model
            self._message("initialize from model %s" % checkpoint)
        elif mode == 'restart':
            checkpoint = self.run_opt.restart
            self._message("restart from model %s" % checkpoint)
        else:
            raise RuntimeError("unkown init mode")

        # Every mode starts with a full variable initialisation.
        self.sess.run(tf.global_variables_initializer())
        if mode != 'init_from_scratch':
            self.saver.restore(self.sess, checkpoint)
        if mode == 'init_from_model':
            self.sess.run(self.global_step.assign(0))
        if mode != 'restart':
            # start with an empty display file
            open(self.disp_file, "w").close()

    def _init_sess_distrib(self):
        """Create the monitored session used for distributed training.

        The checkpoint directory is ``<cwd>/<save_ckpt>``.  When starting
        from scratch the chief removes and recreates that directory and
        truncates ``self.disp_file``; on restart the chief only checks that
        the directory exists.  ``self.saver`` is set to ``None``; the saver
        is handed to the session scaffold instead.

        Raises
        ------
        RuntimeError
            For ``init_from_model`` (not supported here) and for unknown
            init modes.
        """
        cwd = os.getcwd()
        ckpt_dir = os.path.join(cwd, self.save_ckpt)
        assert _is_subdir(ckpt_dir, cwd), "the checkpoint dir must be a subdir of the current dir"

        mode = self.run_opt.init_mode
        if mode == 'init_from_model':
            raise RuntimeError("distributed training does not support %s" % mode)
        if mode != 'init_from_scratch' and mode != 'restart':
            raise RuntimeError("unkown init mode")

        if mode == 'init_from_scratch':
            self._message("initialize model from scratch")
            if self.run_opt.is_chief:
                if os.path.exists(ckpt_dir):
                    shutil.rmtree(ckpt_dir)
                if not os.path.exists(ckpt_dir):
                    os.makedirs(ckpt_dir)
                # start with an empty display file
                open(self.disp_file, "w").close()
        else:
            self._message("restart from model %s" % ckpt_dir)
            if self.run_opt.is_chief:
                assert os.path.isdir(ckpt_dir), "the checkpoint dir %s should exists" % ckpt_dir

        scaffold_saver = tf.train.Saver(max_to_keep=1)
        self.saver = None
        session_config = tf.ConfigProto(
            intra_op_parallelism_threads=self.run_opt.num_intra_threads,
            inter_op_parallelism_threads=self.run_opt.num_inter_threads)
        # Only the sync-replicas hook is installed; no StopAtStepHook and no
        # save_checkpoint_steps are passed to the monitored session.
        session_hooks = [self.sync_replicas_hook]
        scaffold = tf.train.Scaffold(saver=scaffold_saver)
        # Use monitor session for distributed computation
        self.sess = tf.train.MonitoredTrainingSession(
            master=self.run_opt.server.target,
            is_chief=self.run_opt.is_chief,
            config=session_config,
            hooks=session_hooks,
            scaffold=scaffold,
            checkpoint_dir=ckpt_dir)

    def train (self, 
               data, 
               stop_batch) :
        if self.run_opt.is_distrib :
            self._init_sess_distrib()
        else :
            self._init_sess_serial()

        self.print_head()
        fp = None
        if self.run_opt.is_chief :
            fp = open(self.disp_file, "a")

        cur_batch = self.sess.run(self.global_step)
        self.cur_batch = cur_batch
        self.run_opt.message("start training at lr %.2e (== %.2e), final lr will be %.2e" % 
                             (self.sess.run(self.learning_rate),
                              self.lr.value(cur_batch), 
                              self.lr.value(stop_batch)) 
        )

        prf_options = None
        prf_run_metadata = None
        if self.profiling :
            prf_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            prf_run_metadata = tf.RunMetadata()

        train_time = 0
        while cur_batch < stop_batch :
            batch_data = data.get_batch (sys_weights = self.sys_weights)
            cur_batch_size = batch_data["coord"].shape[0]
            feed_dict_batch = {}
            for kk in batch_data.keys():
                if kk == 'find_type' or kk == 'type' :
                    continue
                if 'find_' in kk :
                    feed_dict_batch[self.place_holders[kk]] = batch_data[kk]
                else:
                    feed_dict_batch[self.place_holders[kk]] = np.reshape(batch_data[kk], [-1])
            for ii in ['type'] :
                feed_dict_batch[self.place_holders[ii]] = np.reshape(batch_data[ii], [-1])
            for ii in ['natoms_vec', 'default_mesh'] :
                feed_dict_batch[self.place_holders[ii]] = batch_data[ii]
            feed_dict_batch[self.place_holders['is_training']] = True

            if self.display_in_training and cur_batch == 0 :
                self.test_on_the_fly(fp, data, feed_dict_batch)
            if self.timing_in_training : tic = time.time()
            self.sess.run([self.train_op], feed_dict = feed_dict_batch, options=prf_options, run_metadata=prf_run_metadata)
            if self.timing_in_training : toc = time.time()
            if self.timing_in_training : train_time += toc - tic
            cur_batch = self.sess.run(self.global_step)
            self.cur_batch = cur_batch

            if self.display_in_training and (cur_batch % self.disp_freq == 0) :
                tic = time.time()
                self.test_on_the_fly(fp, data, feed_dict_batch)
                toc = time.time()
                test_time = toc - tic
                if self.timing_in_training :
                    self._message("batch %7d training time %.2f s, testing time %.2f s"
                                  % (cur_batch, train_time, test_time))
                    train_time = 0
                if self.save_freq > 0 and cur_batch % self.save_freq == 0 and self.run_opt.is_chief :
                    if self.saver is not None :
                        self.saver.save (self.sess, os.getcwd() + "/" + self.save_ckpt)
                        self._message("saved checkpoint %s" % self.save_ckpt)
        if self.run_opt.is_chief: 
            fp.close ()
        if self.profiling and self.run_opt.is_chief :
            fetched_timeline = timeline.Timeline(prf_run_metadata.step_stats)
            chrome_trace = fetched_timeline.generate_chrome_trace_format()
            with open(self.profiling_file, 'w') as f:
                f.write(chrome_trace)

    def get_global_step (self) :
        return self.sess.run(self.global_step)

    def print_head (self) :
        if self.run_opt.is_chief:
            fp = open(self.disp_file, "a")
            print_str = "# %5s" % 'batch'
            print_str += self.loss.print_header()
            print_str += '   %8s\n' % 'lr'
            fp.write(print_str)
            fp.close ()

    def test_on_the_fly (self,
                         fp,
                         data,
                         feed_dict_batch) :
        test_data = data.get_test ()
        feed_dict_test = {}
        for kk in test_data.keys():
            if kk == 'find_type' or kk == 'type' :
                continue
            if 'find_' in kk:
                feed_dict_test[self.place_holders[kk]] = test_data[kk]
            else:
                feed_dict_test[self.place_holders[kk]] = np.reshape(test_data[kk][:self.numb_test], [-1])
        for ii in ['type'] :
            feed_dict_test[self.place_holders[ii]] = np.reshape(test_data[ii][:self.numb_test], [-1])            
        for ii in ['natoms_vec', 'default_mesh'] :
            feed_dict_test[self.place_holders[ii]] = test_data[ii]
        feed_dict_test[self.place_holders['is_training']] = False

        cur_batch = self.cur_batch
        current_lr = self.sess.run(self.learning_rate)
        if self.run_opt.is_chief:
            print_str = "%7d" % cur_batch
            print_str += self.loss.print_on_training(self.sess,
                                                     test_data['natoms_vec'],
                                                     feed_dict_test,
                                                     feed_dict_batch)
            print_str += "   %8.1e\n" % current_lr
            fp.write(print_str)
            fp.flush ()