import os
import json
import glob
import shutil
import numpy as np


def make_train(iter_index, json_file, base_dir="./"):
    """Prepare the training directories for one RiD iteration."""
    json_file = os.path.abspath(json_file)
    with open(json_file, 'r') as fp:
        jdata = json.load(fp)
    numb_model = jdata["numb_model"]
    res_iter = jdata["res_iter"]
    # abs path
    base_dir = os.path.abspath(base_dir) + "/"
    iter_name = make_iter_name(iter_index)
    train_path = base_dir + iter_name + "/" + train_name + "/"
    data_path = train_path + "data/"
    cwd = os.getcwd() + "/"
    create_path(train_path)
    os.makedirs(data_path)
    collect_data(iter_index, base_dir)
    # create train dirs
    log_task("create train dirs")
    for ii in range(numb_model):
        work_path = train_path + ("%03d/" % ii)
        old_model_path = work_path + "old_model/"
        create_path(work_path)
        os.chdir(work_path)
        os.symlink("../data", "./data")
        os.chdir(cwd)
        if iter_index >= 1:
            prev_iter_index = iter_index - 1
            prev_iter_name = make_iter_name(prev_iter_index)
            prev_train_path = base_dir + prev_iter_name + "/" + train_name + "/"
            prev_work_path = prev_train_path + ("%03d/" % ii)
            prev_model_files = glob.glob(prev_work_path + "model.ckpt.*")
            prev_model_files += [prev_work_path + "checkpoint"]
            create_path(old_model_path)
            # keep symlinks in old_model/ for bookkeeping and real copies in
            # work_path so training can restart from the previous checkpoint
            os.chdir(old_model_path)
            for jj in prev_model_files:
                os.symlink(os.path.relpath(jj), os.path.basename(jj))
            os.chdir(cwd)
            for jj in prev_model_files:
                shutil.copy(jj, work_path)
    print("Training files have been prepared.")
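# A sketch of the layout make_train produces, assuming numb_model == 4 and the
# "02.train" value the driver below globs for (iteration number illustrative):
#
#   iter.000007/02.train/
#       data/               shared training set written by collect_data
#       000/ ... 003/       one working dir per model
#           data -> ../data
#           old_model/      symlinks to the previous iteration's
#                           model.ckpt.* and checkpoint files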
def main(out_dir, mol_dir, rid_json, machine_json, cv_file, init_model,
         record_name="record.txt"):
    """Drive the RiD workflow: run the eight tasks of each iteration
    (enhanced sampling, restrained MD, training) in order, resuming from
    the last (iteration, task) pair recorded in record.txt."""
    out_dir = os.path.abspath(out_dir)
    mol_dir = os.path.abspath(mol_dir)
    rid_json = os.path.abspath(rid_json)
    cv_file = os.path.abspath(cv_file)
    with open(rid_json, 'r') as fp:
        jdata = json.load(fp)
    record_file = os.path.join(out_dir, record_name)
    checkpoint = get_checkpoint(record_file)
    max_tasks = 10
    number_tasks = 8
    iter_numb = int(jdata['numb_iter'])
    prev_model = init_model
    if sum(checkpoint) < 0:
        print("prepare gen_rid")
        gen_rid(out_dir, mol_dir, rid_json)
    for iter_idx in range(iter_numb):
        if iter_idx > 0:
            prev_model = glob.glob(
                out_dir + "/" + make_iter_name(iter_idx - 1) + "/02.train/*pb")
        for tag in range(number_tasks):
            # skip everything at or before the recorded checkpoint
            if iter_idx * max_tasks + tag <= checkpoint[0] * max_tasks + checkpoint[1]:
                continue
            elif tag == 0:
                print("prepare gen_enhc")
                enhcMD.make_enhc(iter_idx, rid_json, prev_model, mol_dir,
                                 cv_file, base_dir=out_dir)
            elif tag == 1:
                print("run enhanced MD")
                enhcMD.run_enhc(iter_idx, rid_json, machine_json, base_dir=out_dir)
            elif tag == 2:
                print("post-process enhanced MD")
                enhcMD.post_enhc(iter_idx, rid_json, machine_json, base_dir=out_dir)
            elif tag == 3:
                print("prepare gen_res")
                resMD.make_res(iter_index=iter_idx, json_file=rid_json,
                               cv_file=cv_file, mol_path=mol_dir, base_dir=out_dir)
            elif tag == 4:
                print("run restrained MD")
                resMD.run_res(iter_index=iter_idx, json_file=rid_json,
                              machine_json=machine_json, base_dir=out_dir)
            elif tag == 5:
                print("post-process restrained MD")
                resMD.post_res(iter_index=iter_idx, json_file=rid_json,
                               machine_json=machine_json, cv_file=cv_file,
                               base_dir=out_dir)
            elif tag == 6:
                print("prepare gen train")
                train.make_train(iter_index=iter_idx, json_file=rid_json,
                                 base_dir=out_dir)
            elif tag == 7:
                print("run training")
                train.run_train(iter_index=iter_idx, json_file=rid_json,
                                machine_json=machine_json, cv_file=cv_file,
                                base_dir=out_dir)
            record_iter(record_file, iter_idx, tag)
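# A minimal command-line entry point for the driver above. This wiring is a
# hedged sketch, not part of the original module; argument names simply mirror
# the parameters of main().
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Run the RiD workflow.")
    parser.add_argument("out_dir", help="directory that will hold iter.XXXXXX/")
    parser.add_argument("mol_dir", help="directory with topol.top and *.gro confs")
    parser.add_argument("rid_json", help="RiD parameter file, e.g. rid.json")
    parser.add_argument("machine_json", help="machine/resource settings file")
    parser.add_argument("cv_file", help="collective-variable definition file")
    parser.add_argument("--init-model", nargs="*", default=[],
                        help="graph.pb files used before the first training")
    args = parser.parse_args()
    main(args.out_dir, args.mol_dir, args.rid_json, args.machine_json,
         args.cv_file, args.init_model)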
def run_enhc(iter_index, json_file, machine_json, base_dir='./'):
    json_file = os.path.abspath(json_file)
    base_dir = os.path.abspath(base_dir) + "/"
    iter_name = make_iter_name(iter_index)
    work_path = base_dir + iter_name + "/" + enhc_name + "/"
    with open(json_file, 'r') as fp:
        jdata = json.load(fp)
    gmx_prep = jdata["gmx_prep"]
    gmx_run = jdata["gmx_run"]
    enhc_thread = jdata["enhc_thread"]
    gmx_run = gmx_run + (" -nt %d" % enhc_thread)
    gmx_prep_log = "gmx_grompp.log"
    gmx_run_log = "gmx_mdrun.log"
    # assuming at least one walker
    graph_files = glob.glob(work_path + make_walker_name(0) + "/*.pb")
    if len(graph_files) != 0:
        gmx_run = gmx_run + " -plumed " + enhc_plm
    else:
        gmx_run = gmx_run + " -plumed " + enhc_bf_plm
    gmx_prep_cmd = cmd_append_log(gmx_prep, gmx_prep_log)
    gmx_run_cmd = cmd_append_log(gmx_run, gmx_run_log)
    all_task = list(filter(lambda x: os.path.isdir(x),
                           glob.glob(work_path + "/[0-9]*[0-9]")))
    all_task.sort()
    all_task_basedir = [os.path.relpath(ii, work_path) for ii in all_task]
    print('run_enhc:work_path', work_path)
    print('run_enhc:gmx_prep_cmd:', gmx_prep_cmd)
    print('run_enhc:gmx_run_cmd:', gmx_run_cmd)
    print('run_enhc:all_task:', all_task)
    print('run_enhc:all_task_basedir:', all_task_basedir)
    machine = set_machine(machine_json, target="enhcMD")
    resources = set_resource(machine_json, target="enhcMD")
    gmx_prep_task = [Task(command=gmx_prep_cmd, task_work_path=ii,
                          outlog='gmx_grompp.log', errlog='gmx_grompp.log')
                     for ii in all_task_basedir]
    gmx_prep_submission = Submission(work_base=work_path, machine=machine,
                                     resources=resources,
                                     task_list=gmx_prep_task)
    gmx_prep_submission.run_submission()
    gmx_run_task = [Task(command=gmx_run_cmd, task_work_path=ii,
                         outlog='gmx_mdrun.log', errlog='gmx_mdrun.log')
                    for ii in all_task_basedir]
    gmx_run_submission = Submission(work_base=work_path, machine=machine,
                                    resources=resources,
                                    task_list=gmx_run_task)
    gmx_run_submission.run_submission()
def post_enhc(iter_index, json_file, machine_json, base_dir="./"):
    base_dir = os.path.abspath(base_dir) + "/"
    iter_name = make_iter_name(iter_index)
    work_path = base_dir + iter_name + "/" + enhc_name + "/"
    json_file = os.path.abspath(json_file)
    with open(json_file, 'r') as fp:
        jdata = json.load(fp)
    gmx_split = jdata["gmx_split_traj"]
    gmx_split_log = "gmx_split.log"
    gmx_split_cmd = cmd_append_log(gmx_split, gmx_split_log)
    all_task = list(filter(lambda x: os.path.isdir(x),
                           glob.glob(work_path + "/[0-9]*[0-9]")))
    all_task.sort()
    cwd = os.getcwd()
    numb_walkers = jdata["numb_walkers"]
    for ii in range(numb_walkers):
        walker_path = work_path + make_walker_name(ii) + "/"
        os.chdir(walker_path)
        if os.path.isdir("confs"):
            shutil.rmtree("confs")
        os.makedirs("confs")
        os.chdir(cwd)
    print('rid.py:post_enhc:gmx_split_cmd', gmx_split_cmd)
    print('rid.py:post_enhc:work path', work_path)
    machine = set_machine(machine_json, target="post_enhc")
    resources = set_resource(machine_json, target="post_enhc")
    all_task_relpath = [os.path.relpath(ii, work_path) for ii in all_task]
    gmx_split_task = [Task(command=gmx_split_cmd, task_work_path=ii,
                           outlog='gmx_split.log', errlog='gmx_split.log')
                      for ii in all_task_relpath]
    gmx_split_submission = Submission(work_base=work_path, resources=resources,
                                      machine=machine,
                                      task_list=gmx_split_task)
    gmx_split_submission.run_submission()
    for ii in range(numb_walkers):
        walker_path = work_path + make_walker_name(ii) + "/"
        # drop the first (time) column of the PLUMED output, keep the CV values
        angles = np.loadtxt(walker_path + enhc_out_plm)
        np.savetxt(walker_path + enhc_out_angle, angles[:, 1:], fmt="%.6f")
    print("Post-processing of enhanced sampling finished.")
def check_new_data(iter_index, train_path, base_path):
    """If the new data file is empty, link the previous iteration's models
    into train_path and return True; otherwise return False."""
    new_data_file = os.path.join(train_path, 'data/data.new.raw')
    if os.stat(new_data_file).st_size == 0:
        prev_iter_index = iter_index - 1
        prev_train_path = base_path + make_iter_name(prev_iter_index) \
            + "/" + train_name + "/"
        prev_models = glob.glob(prev_train_path + "*.pb")
        for ii in prev_models:
            model_name = os.path.basename(ii)
            os.symlink(ii, os.path.join(train_path, model_name))
        return True
    else:
        return False
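# Hypothetical usage, mirroring the call in run_train below: if iter.000003
# produced an empty data/data.new.raw, the iter.000002 graph.*.pb models are
# symlinked into the current train path and training is skipped entirely.
#
#     if check_new_data(3, train_path, base_dir):
#         return  # reuse the previous models unchanged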
def collect_data(iter_index, base_dir):
    iter_name = make_iter_name(iter_index)
    train_path = base_dir + iter_name + "/" + train_name + "/"
    data_path = train_path + "data/"
    data_file = train_path + "data/data.raw"
    data_old_file = train_path + "data/data.old.raw"
    data_new_file = train_path + "data/data.new.raw"
    cwd = os.getcwd() + "/"
    # collect data
    log_task("collect data upto %d" % (iter_index))
    if iter_index == 0:
        ii = 0
        this_raw = base_dir + make_iter_name(ii) + "/" + res_name + "/data.raw"
        os.chdir(data_path)
        os.symlink(os.path.relpath(this_raw), os.path.basename(data_new_file))
        os.symlink(os.path.basename(data_new_file), os.path.basename(data_file))
        os.chdir(cwd)
        open(data_old_file, "w").close()
    else:
        prev_iter_index = iter_index - 1
        prev_data_file = base_dir + make_iter_name(prev_iter_index) \
            + "/" + train_name + "/data/data.raw"
        this_raw = base_dir + make_iter_name(iter_index) \
            + "/" + res_name + "/data.raw"
        os.chdir(data_path)
        os.symlink(os.path.relpath(prev_data_file),
                   os.path.basename(data_old_file))
        os.symlink(os.path.relpath(this_raw),
                   os.path.basename(data_new_file))
        os.chdir(cwd)
        with open(data_file, "wb") as fo:
            with open(data_old_file, "rb") as f0, open(data_new_file, "rb") as f1:
                shutil.copyfileobj(f0, fo)
                shutil.copyfileobj(f1, fo)
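# Resulting layout of the train data directory (iteration number illustrative;
# train_name is "02.train" in the driver above, res_name is the restrained-MD
# step directory):
#
#   iter.000004/02.train/data/
#       data.old.raw -> previous iteration's data.raw (empty regular file at iter 0)
#       data.new.raw -> this iteration's <res_name>/data.raw
#       data.raw        concatenation old + new (symlink to data.new.raw at iter 0)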
def post_res(iter_index, json_file, machine_json, cv_file, base_dir="./"):
    json_file = os.path.abspath(json_file)
    machine_json = os.path.abspath(machine_json)
    cv_file = os.path.abspath(cv_file)
    base_dir = os.path.abspath(base_dir) + "/"
    iter_name = make_iter_name(iter_index)
    res_path = base_dir + iter_name + "/" + res_name + "/"
    cwd = os.getcwd()
    with open(json_file, 'r') as fp:
        jdata = json.load(fp)
    os.chdir(res_path)
    all_task = list(filter(lambda x: os.path.isdir(x), glob.glob("[0-9]*[0-9]")))
    if len(all_task) == 0:
        np.savetxt(res_path + 'data.raw', [], fmt="%.6e")
        os.chdir(cwd)
        return
    all_task.sort()
    all_task_reldir = [os.path.relpath(ii, res_path) for ii in all_task]
    centers = []
    force = []
    ndim = 0
    _conf_file = os.path.abspath(all_task[0] + "/conf.gro")
    cv_dim_list = cal_cv_dim(_conf_file, cv_file)
    cv_dih_dim = cv_dim_list[0]
    cmpf_cmd = "python3 {}/cmpf.py".format(LIB_PATH)
    cmpf_cmd += " -c %d" % cv_dih_dim
    cmpf_log = "cmpf.log"
    print("rid.post_res.post_res:cmpf_cmd:", cmpf_cmd)
    cmpf_resources = set_resource(machine_json, target="cmpf")
    machine = set_machine(machine_json, target="cmpf")
    cmpf_task = [Task(command=cmpf_cmd, task_work_path="{}".format(ii),
                      outlog=cmpf_log, errlog=cmpf_log)
                 for ii in all_task_reldir]
    cmpf_submission = Submission(work_base=res_path, machine=machine,
                                 resources=cmpf_resources,
                                 task_list=cmpf_task)
    cmpf_submission.run_submission()
    print('cmpf done')
    abs_res_path = os.getcwd()
    for work_path in all_task:
        os.chdir(work_path)
        this_centers = np.loadtxt('centers.out')
        centers = np.append(centers, this_centers)
        this_force = np.loadtxt('force.out')
        force = np.append(force, this_force)
        ndim = this_force.size
        assert ndim == this_centers.size, \
            "center size differs from force size in " + work_path
        os.chdir(abs_res_path)
    os.chdir(cwd)
    centers = np.reshape(centers, [-1, ndim])
    force = np.reshape(force, [-1, ndim])
    data = np.concatenate((centers, force), axis=1)
    np.savetxt(res_path + 'data.raw', data, fmt="%.6e")
    norm_force = np.linalg.norm(force, axis=1)
    log_task("min|f| = %e  max|f| = %e  avg|f| = %e" %
             (np.min(norm_force), np.max(norm_force), np.average(norm_force)))
    print("Post-processing of restrained MD finished.")
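# Each row of data.raw holds 2 * ndim values: the ndim restraint centers
# followed by the ndim mean forces computed by cmpf.py, e.g. for ndim == 2:
#
#   c0 c1 f0 f1
#
# This is the format collect_data above concatenates into the training set.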
def run_res(iter_index, json_file, machine_json, base_dir="./"):
    json_file = os.path.abspath(json_file)
    with open(json_file, 'r') as fp:
        jdata = json.load(fp)
    gmx_prep = jdata["gmx_prep"]
    gmx_run = jdata["gmx_run"]
    res_thread = jdata["res_thread"]
    gmx_run = gmx_run + (" -nt %d" % res_thread)
    gmx_run = gmx_run + " -plumed " + res_plm
    # gmx_cont_run = gmx_run + " -cpi state.cpt"
    gmx_prep_log = "gmx_grompp.log"
    gmx_run_log = "gmx_mdrun.log"
    gmx_prep_cmd = cmd_append_log(gmx_prep, gmx_prep_log)
    gmx_run_cmd = cmd_append_log(gmx_run, gmx_run_log)
    # gmx_cont_run_cmd = cmd_append_log(gmx_cont_run, gmx_run_log)
    base_dir = os.path.abspath(base_dir) + "/"
    iter_name = make_iter_name(iter_index)
    res_path = base_dir + iter_name + "/" + res_name + "/"
    if not os.path.isdir(res_path):
        raise RuntimeError("do not see any restrained simulation (%s)." % res_path)
    all_task = list(filter(lambda x: os.path.isdir(x),
                           glob.glob(res_path + "/[0-9]*[0-9]")))
    print('run_res:all_task_propose:', all_task)
    print('run_res:gmx_prep_cmd:', gmx_prep_cmd)
    print('run_res:gmx_run_cmd:', gmx_run_cmd)
    # print('run_res:gmx_cont_run_cmd:', gmx_cont_run_cmd)
    if len(all_task) == 0:
        return None
    all_task.sort()
    all_task_basedir = [os.path.relpath(ii, res_path) for ii in all_task]
    res_resources = set_resource(machine_json, target="resMD")
    machine = set_machine(machine_json, target="resMD")
    gmx_prep_task = [Task(command=gmx_prep_cmd, task_work_path=ii,
                          outlog='gmx_grompp.log', errlog='gmx_grompp.log')
                     for ii in all_task_basedir]
    gmx_prep_submission = Submission(work_base=res_path, machine=machine,
                                     resources=res_resources,
                                     task_list=gmx_prep_task)
    gmx_prep_submission.run_submission()
    gmx_run_task = [Task(command=gmx_run_cmd, task_work_path=ii,
                         outlog='gmx_mdrun.log', errlog='gmx_mdrun.log')
                    for ii in all_task_basedir]
    gmx_run_submission = Submission(work_base=res_path, machine=machine,
                                    resources=res_resources,
                                    task_list=gmx_run_task)
    gmx_run_submission.run_submission()
def make_res(iter_index, json_file, cv_file, mol_path, base_dir="./",
             custom_mdp=None):
    json_file = os.path.abspath(json_file)
    cv_file = os.path.abspath(cv_file)
    with open(json_file, 'r') as fp:
        jdata = json.load(fp)
    numb_walkers = jdata["numb_walkers"]
    bias_nsteps = jdata["bias_nsteps"]
    bias_frame_freq = jdata["bias_frame_freq"]
    nsteps = jdata["res_nsteps"]
    frame_freq = jdata["res_frame_freq"]
    sel_threshold = jdata["sel_threshold"]
    max_sel = jdata["max_sel"]
    cluster_threshold = jdata["cluster_threshold"]
    init_numb_cluster_upper = int(jdata["init_numb_cluster_upper"])
    init_numb_cluster_lower = int(jdata["init_numb_cluster_lower"])
    init_numb_cluster = [init_numb_cluster_lower, init_numb_cluster_upper]
    base_dir = os.path.abspath(base_dir) + "/"
    iter_name = make_iter_name(iter_index)
    enhc_path = base_dir + iter_name + "/" + enhc_name + "/"
    res_path = base_dir + iter_name + "/" + res_name + "/"
    create_path(res_path)
    cwd = os.getcwd()
    _conf_file = enhc_path + make_walker_name(0) + "/" + "conf.gro"
    cv_dim_list = cal_cv_dim(_conf_file, cv_file)
    cv_dim = sum(cv_dim_list)
    cv_dih_dim = cv_dim_list[0]
    ret_list = [True for ii in range(numb_walkers)]
    weight = jdata["cv_weight_for_cluster"]
    if isinstance(weight, list):
        assert len(weight) == cv_dim, \
            "Number of values in the weight list is not equal to the number of CVs."
    elif isinstance(weight, (float, int)):
        assert weight != 0
    else:
        raise TypeError("Invalid type of weight of CVs for clustering. "
                        "Please use int or list instead.")
    # check if we have graphs (NN models) from the enhanced sampling step
    for walker_idx in range(numb_walkers):
        cls_sel = None
        walker_path = enhc_path + walker_format % walker_idx + "/"
        graph_files = glob.glob(walker_path + "*.pb")
        if len(graph_files) != 0:
            # NN models exist: keep only frames whose model deviation exceeds
            # sel_threshold, then cluster the selection
            cluster_threshold = np.loadtxt(base_dir + "cluster_threshold.dat")
            os.chdir(walker_path)
            models = glob.glob("*.pb")
            std_message = make_std(cv_dim, dataset=enhc_out_angle, models=models,
                                   threshold=sel_threshold, output="sel.out",
                                   output_angle="sel.angle.out")
            os.system('echo "{}" > sel.log'.format(std_message))
            log_task("select with threshold %f" % sel_threshold)
            os.chdir(cwd)
            sel_idx = []
            sel_angles = np.array([])
            with open(walker_path + "sel.out") as fp:
                for line in fp:
                    sel_idx += [int(x) for x in line.split()]
            if len(sel_idx) != 0:
                sel_angles = np.reshape(
                    np.loadtxt(walker_path + 'sel.angle.out'), [-1, cv_dim])
            else:
                np.savetxt(walker_path + 'num_of_cluster.dat', [0], fmt='%d')
                np.savetxt(walker_path + 'cls.sel.out', [], fmt='%d')
                continue
        else:
            # no NN models yet (first iteration): take every conf and tune the
            # cluster threshold toward the requested number of clusters
            cluster_threshold = jdata["cluster_threshold"]
            sel_idx = range(
                len(glob.glob(walker_path + enhc_out_conf + "conf*gro")))
            sel_angles = np.loadtxt(walker_path + enhc_out_angle)
            sel_angles = np.reshape(sel_angles, [-1, cv_dim])
            np.savetxt(walker_path + 'sel.out', sel_idx, fmt='%d')
            np.savetxt(walker_path + 'sel.angle.out', sel_angles, fmt='%.6f')
            cls_sel, cluster_threshold = make_threshold(
                walker_idx, walker_path, base_dir, sel_angles,
                cluster_threshold, init_numb_cluster, cv_dih_dim, weight)
        if cls_sel is None:
            cls_sel = sel_from_cluster(sel_angles, cluster_threshold,
                                       cv_dih_dim, weight)
        conf_start = 0
        conf_every = 1
        sel_idx = np.array(sel_idx, dtype=int)
        assert len(sel_idx) == sel_angles.shape[0], \
            "{} selected indexes don't match {} selected angles.".format(
                len(sel_idx), sel_angles.shape[0])
        sel_idx = config_cls(sel_idx, cls_sel, max_sel, walker_path,
                             cluster_threshold, sel_angles)
        res_angles = np.loadtxt(walker_path + enhc_out_angle)
        res_angles = np.reshape(res_angles, [-1, cv_dim])
        res_angles = res_angles[sel_idx]
        np.savetxt(walker_path + 'cls.sel.out', sel_idx, fmt='%d')
        np.savetxt(walker_path + 'cls.sel.angle.out', res_angles, fmt='%.6f')
        res_confs = []
        for ii in sel_idx:
            res_confs.append(walker_path + enhc_out_conf + ("conf%d.gro" % ii))
        assert len(res_confs) == res_angles.shape[0], \
            "number of enhc out confs does not match number of out angles"
        assert len(sel_idx) == res_angles.shape[0], \
            "number of enhc out confs does not match number of selected indexes"
        nconf = len(res_confs)
        if nconf == 0:
            ret_list[walker_idx] = False
            continue
        sel_list = make_sel_list(nconf, sel_idx)
        log_task("selected %d confs, indexes: %s" % (nconf, sel_list))
        make_conf(nconf, res_path, walker_idx, walker_path, sel_idx, jdata,
                  mol_path, conf_start=conf_start, conf_every=conf_every,
                  custom_mdp=custom_mdp)
        make_res_plumed(nconf, jdata, res_path, walker_idx, sel_idx,
                        res_angles, _conf_file, cv_file,
                        conf_start=conf_start, conf_every=conf_every)
    print("Restrained MD has been prepared.")
def make_enhc(iter_index, json_file, graph_files, mol_dir, cv_file,
              base_dir='./', custom_mdp=None):
    base_dir = os.path.abspath(base_dir) + "/"
    json_file = os.path.abspath(json_file)
    cv_file = os.path.abspath(cv_file)
    graph_files.sort()
    with open(json_file, 'r') as fp:
        jdata = json.load(fp)
    numb_walkers = jdata["numb_walkers"]
    enhc_trust_lvl_1 = jdata["bias_trust_lvl_1"]
    enhc_trust_lvl_2 = jdata["bias_trust_lvl_2"]
    nsteps = jdata["bias_nsteps"]
    frame_freq = jdata["bias_frame_freq"]
    num_of_cluster_threshold = jdata["num_of_cluster_threshold"]
    dt = jdata["bias_dt"]
    temperature = jdata["bias_temperature"]
    iter_name = make_iter_name(iter_index)
    work_path = base_dir + iter_name + "/" + enhc_name + "/"
    mol_path = os.path.abspath(mol_dir) + "/"
    conf_list = glob.glob(mol_path + "*gro")
    conf_list.sort()
    assert len(conf_list) >= numb_walkers, \
        "not enough conf files in mol dir %s" % mol_path
    create_path(work_path)
    mol_files = ["topol.top"]
    for walker_idx in range(numb_walkers):
        walker_path = work_path + make_walker_name(walker_idx) + "/"
        create_path(walker_path)
        make_grompp(walker_path + "grompp.mdp", "bias", nsteps, frame_freq,
                    temperature=temperature, dt=dt, define='',
                    custom_mdp=custom_mdp)
        # make_grompp(walker_path + "grompp_restraint.mdp", "res", nsteps, frame_freq, temperature=temperature, dt=dt, define='-DPOSRE')
        for ii in mol_files:
            checkfile(walker_path + ii)
            shutil.copy(mol_path + ii, walker_path)
        # copy conf file
        conf_file = conf_list[walker_idx]
        checkfile(walker_path + "conf.gro")
        shutil.copy(conf_file, walker_path + "conf.gro")
        checkfile(walker_path + "conf_init.gro")
        shutil.copy(conf_file, walker_path + "conf_init.gro")
        # if a previous confout.gro exists, use it as the initial conf
        if iter_index > 0:
            prev_enhc_path = base_dir + make_iter_name(iter_index - 1) + "/" \
                + enhc_name + "/" + make_walker_name(walker_idx) + "/"
            prev_enhc_path = os.path.abspath(prev_enhc_path) + "/"
            if os.path.isfile(prev_enhc_path + "confout.gro"):
                os.remove(walker_path + "conf.gro")
                rel_prev_enhc_path = os.path.relpath(
                    prev_enhc_path + "confout.gro", walker_path)
                os.symlink(rel_prev_enhc_path, walker_path + "conf.gro")
            else:
                raise RuntimeError("cannot find prev output conf file "
                                   + prev_enhc_path + 'confout.gro')
            log_task("use conf of iter " + make_iter_name(iter_index - 1)
                     + " walker " + make_walker_name(walker_idx))
            enhc_trust_lvl_1, enhc_trust_lvl_2 = adjust_lvl(
                prev_enhc_path, num_of_cluster_threshold, jdata)
        np.savetxt(walker_path + 'trust_lvl1.dat', [enhc_trust_lvl_1],
                   fmt='%.6f')
        make_plumed(walker_path, "dpbias", conf_file, cv_file)
        make_plumed(walker_path, "bf", conf_file, cv_file)
        prep_graph(graph_files, walker_path)
        # config plumed
        graph_list = get_graph_list(graph_files)
        conf_enhc_plumed(walker_path + enhc_plm, "enhc", graph_list,
                         enhc_trust_lvl_1=enhc_trust_lvl_1,
                         enhc_trust_lvl_2=enhc_trust_lvl_2,
                         frame_freq=frame_freq, enhc_out_plm=enhc_out_plm)
        conf_enhc_plumed(walker_path + enhc_bf_plm, "bf", graph_list,
                         frame_freq=frame_freq, enhc_out_plm=enhc_out_plm)
        if len(graph_list) == 0:
            log_task("brute force MD without NN acc")
        else:
            log_task("use NN model(s): " + graph_list)
        log_task("set trust l1 and l2: %f %f" %
                 (enhc_trust_lvl_1, enhc_trust_lvl_2))
    print("Enhanced sampling has been prepared.")
def test_make_name(self):
    self.assertEqual(utils.make_iter_name(5), "iter.000005")
    self.assertEqual(utils.make_walker_name(5), "005")
def run_train(iter_index, json_file, machine_json, cv_file, base_dir="./"):
    json_file = os.path.abspath(json_file)
    cv_file = os.path.abspath(cv_file)
    with open(json_file, 'r') as fp:
        jdata = json.load(fp)
    numb_model = jdata["numb_model"]
    train_thread = jdata["train_thread"]
    res_iter = jdata["res_iter"]
    base_dir = os.path.abspath(base_dir) + "/"
    iter_name = make_iter_name(iter_index)
    train_path = base_dir + iter_name + "/" + train_name + "/"
    # nothing new to train on: reuse the previous models and return early
    if check_new_data(iter_index, train_path, base_dir):
        return
    enhc_path = base_dir + iter_name + "/" + enhc_name + "/"
    _conf_file = enhc_path + "000/conf.gro"
    cv_dim_list = cal_cv_dim(_conf_file, cv_file)
    cwd = os.getcwd()
    neurons = jdata["neurons"]
    batch_size = jdata["batch_size"]
    if iter_index < res_iter:
        numb_epoches = jdata["numb_epoches"]
        starter_lr = jdata["starter_lr"]
        decay_steps = jdata["decay_steps"]
        decay_rate = jdata["decay_rate"]
        cmdl_args = ""
    else:
        # after res_iter iterations, restart from the old model and mix in a
        # fixed ratio of old data
        numb_epoches = jdata["res_numb_epoches"]
        starter_lr = jdata["res_starter_lr"]
        decay_steps = jdata["res_decay_steps"]
        decay_rate = jdata["res_decay_rate"]
        old_ratio = jdata["res_olddata_ratio"]
        cmdl_args = " --restart --use-mix --old-ratio %f " % old_ratio
    if jdata["resnet"]:
        cmdl_args += " --resnet "
    cmdl_args += " -n "
    for nn in neurons:
        cmdl_args += "%d " % nn
    cmdl_args += " -c "
    for cv_dim in cv_dim_list:
        cmdl_args += "%d " % cv_dim
    cmdl_args += " -b " + str(batch_size)
    cmdl_args += " -e " + str(numb_epoches)
    cmdl_args += " -l " + str(starter_lr)
    cmdl_args += " --decay-steps " + str(decay_steps)
    cmdl_args += " --decay-rate " + str(decay_rate)
    train_cmd = "python3 {}/train.py -t {:d}".format(NN_PATH, train_thread)
    train_cmd += cmdl_args
    train_cmd = cmd_append_log(train_cmd, "train.log")
    freez_cmd = "python3 {}/freeze.py -o graph.pb".format(NN_PATH)
    freez_cmd = cmd_append_log(freez_cmd, "freeze.log")
    task_dirs = [("%03d" % ii) for ii in range(numb_model)]
    print('lib.modeling.run_train:train_cmd:', train_cmd)
    print('lib.modeling.run_train:freez_cmd:', freez_cmd)
    print('lib.modeling.run_train:train_path:', train_path)
    print('lib.modeling.run_train:task_dirs:', task_dirs)
    resources = set_resource(machine_json, target="train")
    machine = set_machine(machine_json, target="train")
    train_task = [Task(command=train_cmd, task_work_path=ii,
                       outlog='train.log', errlog='train.log')
                  for ii in task_dirs]
    train_submission = Submission(work_base=train_path, machine=machine,
                                  resources=resources, task_list=train_task)
    train_submission.run_submission()
    freez_task = [Task(command=freez_cmd, task_work_path=ii,
                       outlog='freeze.log', errlog='freeze.log')
                  for ii in task_dirs]
    freez_submission = Submission(work_base=train_path, machine=machine,
                                  resources=resources, task_list=freez_task)
    freez_submission.run_submission()
    os.chdir(train_path)
    for ii in range(numb_model):
        os.symlink("%03d/graph.pb" % ii, "graph.%03d.pb" % ii)
    os.chdir(cwd)
    print("Training finished!")
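# For illustration, with a hypothetical rid.json containing
#   "neurons": [64, 64], "batch_size": 128, "numb_epoches": 2000,
#   "starter_lr": 0.008, "decay_steps": 120, "decay_rate": 0.96,
#   "resnet": true, "train_thread": 8
# and cv_dim_list == [10, 2], the assembled training command is roughly:
#
#   python3 <NN_PATH>/train.py -t 8 --resnet -n 64 64 -c 10 2 \
#       -b 128 -e 2000 -l 0.008 --decay-steps 120 --decay-rate 0.96
#
# plus whatever log redirection cmd_append_log(..., "train.log") appends.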