def _init_sess_serial(self):
    """Create a single-process ``tf.Session`` and initialise the model.

    ``self.run_opt.init_mode`` selects how variables are set up:

    * ``'init_from_scratch'``: run the global variable initializer and
      truncate ``self.disp_file``.
    * ``'init_from_model'``: initialise, restore from
      ``self.run_opt.init_model``, reset ``self.global_step`` to 0 and
      truncate ``self.disp_file``.
    * ``'restart'``: initialise and restore from ``self.run_opt.restart``;
      ``self.disp_file`` is left untouched.

    Sets ``self.sess`` and ``self.saver``.

    Raises:
        RuntimeError: if ``init_mode`` is none of the modes above. The mode
            is validated before the session is created, so an invalid
            configuration does not leave an unclosed session behind.
    """
    init_mode = self.run_opt.init_mode
    if init_mode not in ('init_from_scratch', 'init_from_model', 'restart'):
        raise RuntimeError("unknown init mode %s" % init_mode)

    self.sess = tf.Session(
        config=tf.ConfigProto(
            intra_op_parallelism_threads=self.run_opt.num_intra_threads,
            inter_op_parallelism_threads=self.run_opt.num_inter_threads))
    self.saver = tf.train.Saver()

    if init_mode == 'init_from_scratch':
        self._message("initialize model from scratch")
        self.sess.run(tf.global_variables_initializer())
        with open(self.disp_file, "w"):
            pass  # truncate the display file for a fresh run
    elif init_mode == 'init_from_model':
        self._message("initialize from model %s" % self.run_opt.init_model)
        self.sess.run(tf.global_variables_initializer())
        self.saver.restore(self.sess, self.run_opt.init_model)
        # Keep the restored weights but start the step counter from zero.
        self.sess.run(self.global_step.assign(0))
        with open(self.disp_file, "w"):
            pass  # truncate the display file
    else:  # 'restart' (validated above)
        self._message("restart from model %s" % self.run_opt.restart)
        self.sess.run(tf.global_variables_initializer())
        self.saver.restore(self.sess, self.run_opt.restart)
def _init_sess_distrib(self):
    """Create a ``tf.train.MonitoredTrainingSession`` for distributed training.

    Checkpoints live in ``<cwd>/<self.save_ckpt>``, which must lie inside
    the current working directory. ``self.run_opt.init_mode`` selects:

    * ``'init_from_scratch'``: the chief deletes (if present) and recreates
      the checkpoint directory; every worker truncates ``self.disp_file``.
    * ``'restart'``: the chief checks that the checkpoint directory exists.
      The monitored session is then pointed at that directory through its
      ``checkpoint_dir`` argument.

    Sets ``self.sess`` to the monitored session. ``self.saver`` is set to
    None because the saver is handed to the session's scaffold instead.

    Raises:
        AssertionError: if the checkpoint directory is not inside the
            current working directory, or (chief only, ``'restart'``) if it
            does not exist.
        RuntimeError: for ``'init_from_model'`` (unsupported in distributed
            mode) or an unknown init mode.
    """
    cwd = os.getcwd()
    ckpt_dir = os.path.join(cwd, self.save_ckpt)
    # Explicit raise rather than ``assert``: this check is the only guard in
    # front of ``shutil.rmtree`` below and must not disappear under
    # ``python -O``. AssertionError is kept so the exception type is unchanged.
    if not _is_subdir(ckpt_dir, cwd):
        raise AssertionError(
            "the checkpoint dir must be a subdir of the current dir")

    init_mode = self.run_opt.init_mode
    if init_mode == 'init_from_scratch':
        self._message("initialize model from scratch")
        if self.run_opt.is_chief:
            if os.path.exists(ckpt_dir):
                shutil.rmtree(ckpt_dir)
            # The directory cannot exist at this point, so create it directly.
            os.makedirs(ckpt_dir)
        with open(self.disp_file, "w"):
            pass  # truncate the display file for a fresh run
    elif init_mode == 'init_from_model':
        raise RuntimeError(
            "distributed training does not support %s" % init_mode)
    elif init_mode == 'restart':
        self._message("restart from model %s" % ckpt_dir)
        if self.run_opt.is_chief and not os.path.isdir(ckpt_dir):
            raise AssertionError(
                "the checkpoint dir %s should exist" % ckpt_dir)
    else:
        raise RuntimeError("unknown init mode %s" % init_mode)

    # The saver (max_to_keep=1) is owned by the scaffold passed to the
    # monitored session, so no separate saver is exposed on self.
    saver = tf.train.Saver(max_to_keep=1)
    self.saver = None
    config = tf.ConfigProto(
        intra_op_parallelism_threads=self.run_opt.num_intra_threads,
        inter_op_parallelism_threads=self.run_opt.num_inter_threads)
    hooks = [self.sync_replicas_hook]
    scaffold = tf.train.Scaffold(saver=saver)
    self.sess = tf.train.MonitoredTrainingSession(
        master=self.run_opt.server.target,
        is_chief=self.run_opt.is_chief,
        config=config,
        hooks=hooks,
        scaffold=scaffold,
        checkpoint_dir=ckpt_dir)
def setUp(self):
    """Build a two-frame fixture and the TF placeholders used by the tests.

    One frame holds 6 local atoms (types in ``dtype``) in a box with
    diagonal 13, a fixed neighbour list, ``dem_deriv`` (presumably
    environment-matrix derivatives) and the expected forces. The per-frame
    arrays are tiled to ``self.nframes`` identical rows.
    """
    config = tf.ConfigProto()
    if int(os.environ.get("DP_AUTO_PARALLELIZATION", 0)):
        rewriter = config.graph_options.rewrite_options.custom_optimizers.add()
        rewriter.name = "dpparallel"
    # The session context is entered here with no matching __exit__ in this method.
    self.sess = self.test_session(config=config).__enter__()
    self.nframes = 2

    # ---- single-frame input data -------------------------------------
    self.dcoord = [
        12.83, 2.56, 2.18,
        12.09, 2.87, 2.74,
        0.25, 3.32, 1.68,
        3.36, 3.0, 1.81,
        3.51, 2.51, 2.6,
        4.27, 3.22, 1.56,
    ]
    self.dtype = [0, 1, 1, 0, 1, 1]
    self.dbox = [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0]
    # One row of 10 neighbour slots per local atom; -1 entries are empty
    # slots (their rows in dem_deriv below are all zero).
    self.dnlist = [
        33, -1, -1, -1, -1, 1, 32, 34, 35, -1,
        0, 33, -1, -1, -1, 32, 34, 35, -1, -1,
        6, 3, -1, -1, -1, 7, 4, 5, -1, -1,
        6, -1, -1, -1, -1, 4, 5, 2, 7, -1,
        3, 6, -1, -1, -1, 5, 2, 7, -1, -1,
        3, 6, -1, -1, -1, 4, 2, 7, -1, -1,
    ]

    # 12 values per neighbour slot (ndescrpt = 4 * nnei and tem_deriv holds
    # ndescrpt * 3 values per atom), in the same slot order as dnlist.
    empty = [0.0] * 12
    em_rows = [
        # atom 0: neighbours 33, -, -, -, -, 1, 32, 34, 35, -
        [0.13227682739491875, 0.01648776318803519, -0.013864709953575083,
         0.12967498112414713, 0.0204174282700489, -0.017169201045268437,
         0.0204174282700489, -0.031583528930688706, -0.0021400703852459233,
         -0.01716920104526844, -0.0021400703852459233, -0.03232887285478848],
        empty, empty, empty, empty,
        [-0.7946522798827726, 0.33289487400494444, 0.6013584820734476,
         0.15412158847174678, -0.502001299580599, -0.9068410573068878,
         -0.502001299580599, -0.833906252681877, 0.3798928753582899,
         -0.9068410573068878, 0.3798928753582899, -0.3579459969766471],
        [0.4206262499369199, 0.761133214171572, -0.5007455356391932,
         -0.6442543005863454, 0.635525177045359, -0.4181086691087898,
         0.6355251770453592, 0.15453235677768898, -0.75657759172067,
         -0.4181086691087898, -0.75657759172067, -0.49771716703202185],
        [0.12240657396947655, -0.0016631327984983461, 0.013970315507385892,
         0.12123416269111335, -0.0020346719145638054, 0.017091244082335703,
         -0.002034671914563806, -0.028490045221941415, -0.00023221799024912971,
         0.017091244082335703, -0.00023221799024912971, -0.026567059102687942],
        [0.057945707686107975, 0.008613551142529565, -0.008091517739952026,
         0.056503423854730866, 0.009417127630974357, -0.008846392623036528,
         0.009417127630974357, -0.005448318729873151, -0.0013150043088297543,
         -0.008846392623036528, -0.0013150043088297541, -0.005612854948377751],
        empty,
        # atom 1: neighbours 0, 33, -, -, -, 32, 34, 35, -, -
        [0.7946522798827726, -0.33289487400494444, -0.6013584820734476,
         0.15412158847174678, -0.502001299580599, -0.9068410573068878,
         -0.502001299580599, -0.833906252681877, 0.3798928753582899,
         -0.9068410573068878, 0.3798928753582899, -0.3579459969766471],
        [0.06884320605436924, 0.002095928989945659, -0.01499395354345747,
         0.0668001797461137, 0.0023216922720068383, -0.016609029330510533,
         0.0023216922720068383, -0.009387797963986713, -0.0005056613145120282,
         -0.016609029330510533, -0.0005056613145120282, -0.005841058553679004],
        empty, empty, empty,
        [0.3025931001933299, 0.11738525438534331, -0.2765074881076981,
         0.034913562192579815, 0.15409432322878, -0.3629777391611269,
         0.15409432322878003, -0.30252938969021487, -0.14081032984698866,
         -0.3629777391611269, -0.14081032984698866, -0.030620805157591004],
        [0.06555082496658332, -0.005338981218997747, -0.002076270474054677,
         0.06523884623439505, -0.00599162877720186, -0.0023300778578007205,
         -0.00599162877720186, -0.007837034455273667, 0.00018978009701544363,
         -0.0023300778578007205, 0.00018978009701544363, -0.008251237047966105],
        [0.014091999096200191, 0.0009521621010946066, -0.00321014651226182,
         0.013676554858123476, 0.0009667394698497006, -0.0032592930697789946,
         0.0009667394698497006, -0.0005658690612028018, -0.00022022250471479668,
         -0.0032592930697789937, -0.00022022250471479666, 0.00011127514881492382],
        empty, empty,
        # atom 2: neighbours 6, 3, -, -, -, 7, 4, 5, -, -
        [-0.4206262499369199, -0.761133214171572, 0.5007455356391932,
         -0.6442543005863454, 0.635525177045359, -0.4181086691087898,
         0.6355251770453592, 0.15453235677768898, -0.75657759172067,
         -0.4181086691087898, -0.75657759172067, -0.49771716703202185],
        [0.17265177804411166, -0.01776481317495682, 0.007216955352326217,
         0.1708538944675734, -0.023853120077098278, 0.009690330031321191,
         -0.02385312007709828, -0.05851427595224925, -0.0009970757588497682,
         0.00969033003132119, -0.0009970757588497682, -0.06056355425469288],
        empty, empty, empty,
        [-0.3025931001933299, -0.11738525438534331, 0.2765074881076981,
         0.034913562192579815, 0.15409432322878, -0.3629777391611269,
         0.15409432322878003, -0.30252938969021487, -0.14081032984698866,
         -0.3629777391611269, -0.14081032984698866, -0.030620805157591004],
        [0.13298898711407747, -0.03304327593938735, 0.03753063440029181,
         0.11967949867634801, -0.0393666881596552, 0.044712781613435545,
         -0.0393666881596552, -0.02897797727002851, -0.01110961751744871,
         0.044712781613435545, -0.011109617517448708, -0.026140939946396612],
        [0.09709214772325653, -0.00241522755530488, -0.0028982730663658636,
         0.09699249715361474, -0.0028489422636695603, -0.0034187307164034813,
         -0.00284894226366956, -0.017464112635362926, 8.504305264685245e-05,
         -0.003418730716403481, 8.504305264685245e-05, -0.017432930182725747],
        empty, empty,
        # atom 3: neighbours 6, -, -, -, -, 4, 5, 2, 7, -
        [-0.1322768273949186, -0.016487763188035173, 0.013864709953575069,
         0.12967498112414702, 0.020417428270048884, -0.017169201045268423,
         0.02041742827004888, -0.03158352893068868, -0.002140070385245921,
         -0.017169201045268423, -0.002140070385245921, -0.03232887285478844],
        empty, empty, empty, empty,
        [0.1802999914938216, -0.5889799722131493, 0.9495799552007915,
         -1.070225697321266, -0.18728687322613707, 0.30195230581356786,
         -0.18728687322613707, -0.5157546277429348, -0.9863775323243197,
         0.30195230581356786, -0.9863775323243197, 0.4627237303364723],
        [1.0053013143052718, 0.24303987818369216, -0.2761816797541954,
         0.8183357773897718, 0.45521877564245394, -0.517294063230061,
         0.45521877564245394, -0.9545617219529918, -0.1250601031984763,
         -0.517294063230061, -0.1250601031984763, -0.922500859133019],
        [-0.17265177804411166, 0.01776481317495682, -0.007216955352326217,
         0.1708538944675734, -0.023853120077098278, 0.009690330031321191,
         -0.02385312007709828, -0.05851427595224925, -0.0009970757588497682,
         0.00969033003132119, -0.0009970757588497682, -0.06056355425469288],
        [-0.06884320605436924, -0.002095928989945659, 0.01499395354345747,
         0.0668001797461137, 0.0023216922720068383, -0.016609029330510533,
         0.0023216922720068383, -0.009387797963986713, -0.0005056613145120282,
         -0.016609029330510533, -0.0005056613145120282, -0.005841058553679004],
        empty,
        # atom 4: neighbours 3, 6, -, -, -, 5, 2, 7, -, -
        [-0.1802999914938216, 0.5889799722131493, -0.9495799552007915,
         -1.070225697321266, -0.18728687322613707, 0.30195230581356786,
         -0.18728687322613707, -0.5157546277429348, -0.9863775323243197,
         0.30195230581356786, -0.9863775323243197, 0.4627237303364723],
        [-0.12240657396947667, 0.0016631327984983487, -0.013970315507385913,
         0.12123416269111348, -0.002034671914563809, 0.01709124408233573,
         -0.002034671914563809, -0.028490045221941467, -0.00023221799024913015,
         0.01709124408233573, -0.00023221799024913015, -0.026567059102687987],
        empty, empty, empty,
        [0.2602591506940697, 0.24313683814840728, -0.3561441009497795,
         -0.19841405298242495, 0.23891499072173572, -0.3499599864093028,
         0.23891499072173572, -0.23095714382387694, -0.32693630309290145,
         -0.34995998640930287, -0.32693630309290145, 0.02473856993038946],
        [-0.13298898711407747, 0.03304327593938735, -0.03753063440029181,
         0.11967949867634801, -0.0393666881596552, 0.044712781613435545,
         -0.0393666881596552, -0.02897797727002851, -0.01110961751744871,
         0.044712781613435545, -0.011109617517448708, -0.026140939946396612],
        [-0.0655508249665835, 0.005338981218997763, 0.002076270474054683,
         0.0652388462343952, -0.005991628777201879, -0.0023300778578007283,
         -0.005991628777201879, -0.007837034455273709, 0.0001897800970154443,
         -0.002330077857800728, 0.0001897800970154443, -0.008251237047966148],
        empty, empty,
        # atom 5: neighbours 3, 6, -, -, -, 4, 2, 7, -, -
        [-1.0053013143052718, -0.24303987818369216, 0.2761816797541954,
         0.8183357773897718, 0.45521877564245394, -0.517294063230061,
         0.45521877564245394, -0.9545617219529918, -0.1250601031984763,
         -0.517294063230061, -0.1250601031984763, -0.922500859133019],
        [-0.057945707686107864, -0.008613551142529548, 0.00809151773995201,
         0.05650342385473076, 0.009417127630974336, -0.00884639262303651,
         0.009417127630974336, -0.005448318729873148, -0.0013150043088297515,
         -0.00884639262303651, -0.0013150043088297513, -0.005612854948377747],
        empty, empty, empty,
        [-0.2602591506940697, -0.24313683814840728, 0.3561441009497795,
         -0.19841405298242495, 0.23891499072173572, -0.3499599864093028,
         0.23891499072173572, -0.23095714382387694, -0.32693630309290145,
         -0.34995998640930287, -0.32693630309290145, 0.02473856993038946],
        [-0.09709214772325653, 0.00241522755530488, 0.0028982730663658636,
         0.09699249715361474, -0.0028489422636695603, -0.0034187307164034813,
         -0.00284894226366956, -0.017464112635362926, 8.504305264685245e-05,
         -0.003418730716403481, 8.504305264685245e-05, -0.017432930182725747],
        [-0.014091999096200191, -0.0009521621010946064, 0.0032101465122618194,
         0.013676554858123474, 0.0009667394698497003, -0.0032592930697789933,
         0.0009667394698497003, -0.0005658690612028016, -0.0002202225047147966,
         -0.0032592930697789933, -0.0002202225047147966, 0.00011127514881492362],
        empty, empty,
    ]
    self.dem_deriv = [value for row in em_rows for value in row]

    # Replicate every per-frame array into nframes identical rows.
    for attr in ("dcoord", "dtype", "dbox", "dnlist", "dem_deriv"):
        one_frame = np.reshape(getattr(self, attr), [1, -1])
        setattr(self, attr, np.tile(one_frame, [self.nframes, 1]))

    # Expected x/y/z values for all 48 atoms. Only atoms 0-7 and 32-35
    # (the centres and the atoms referenced in dnlist) are non-zero.
    self.expected_force = (
        [9.44498, -13.86254, 10.52884, -19.42688, 8.09273, 19.64478,
         4.81771, 11.39255, 12.38830, -16.65832, 6.65153, -10.15585,
         1.16660, -14.43259, 22.97076, 22.86479, 7.42726, -11.41943,
         -7.67893, -7.23287, -11.33442, -4.51184, -3.80588, -2.44935]
        + [0.0] * (3 * 24)  # atoms 8-31
        + [1.16217, 6.16192, -28.79094, 3.81076, -0.01986, -1.01629,
           3.65869, -0.49195, -0.07437, 1.35028, 0.11969, -0.29201]
        + [0.0] * (3 * 12)  # atoms 36-47
    )

    self.sel = [5, 5]
    # Cumulative slot offsets: [0, sel[0], sel[0] + sel[1]].
    self.sec = np.array([0] + list(np.cumsum(self.sel)), dtype=int)
    self.rcut = 6.0
    self.rcut_smth = 0.8
    # [nloc, nall, ...]; the trailing 2 and 4 match the type-0 / type-1
    # counts in dtype (presumably per-type atom counts).
    self.dnatoms = [6, 48, 2, 4]
    self.nloc, self.nall = self.dnatoms[0], self.dnatoms[1]
    self.nnei = self.sec[-1]
    self.ndescrpt = 4 * self.nnei
    self.ntypes = np.max(self.dtype) + 1

    net_deriv = [10 - ii * 0.01 for ii in range(self.nloc * self.ndescrpt)]
    self.dnet_deriv = np.tile(np.reshape(net_deriv, [1, -1]),
                              [self.nframes, 1])

    descrpt_len = self.dnatoms[0] * self.ndescrpt
    self.tnet_deriv = tf.placeholder(
        GLOBAL_TF_FLOAT_PRECISION, [None, descrpt_len], name='t_net_deriv')
    self.tem_deriv = tf.placeholder(
        GLOBAL_TF_FLOAT_PRECISION, [None, descrpt_len * 3], name='t_em_deriv')
    self.tnlist = tf.placeholder(
        tf.int32, [None, self.dnatoms[0] * self.nnei], name="t_nlist")
    self.tnatoms = tf.placeholder(tf.int32, [None], name="t_natoms")