Example 1
def make_train(iter_index, json_file, base_dir="./"):
    json_file = os.path.abspath(json_file)
    with open(json_file, 'r') as fp:
        jdata = json.load(fp)
    # template_dir = jdata["template_dir"]
    numb_model = jdata["numb_model"]
    res_iter = jdata["res_iter"]

    # abs path
    base_dir = os.path.abspath(base_dir) + "/"
    iter_name = make_iter_name(iter_index)
    train_path = base_dir + iter_name + "/" + train_name + "/"
    data_path = train_path + "data/"
    cwd = os.getcwd() + "/"
    create_path(train_path)
    os.makedirs(data_path)
    collect_data(iter_index, base_dir)

    # create train dirs
    log_task("create train dirs")
    for ii in range(numb_model):
        work_path = train_path + ("%03d/" % ii)
        old_model_path = work_path + "old_model/"
        create_path(work_path)
        os.chdir(work_path)
        os.symlink("../data", "./data")
        os.chdir(cwd)
        if iter_index >= 1:
            prev_iter_index = iter_index - 1
            prev_iter_name = make_iter_name(prev_iter_index)
            prev_train_path = base_dir + prev_iter_name + "/" + train_name + "/"
            prev_work_path = prev_train_path + ("%03d/" % ii)
            prev_model_files = glob.glob(prev_work_path + "model.ckpt.*")
            prev_model_files = prev_model_files + \
                [prev_work_path + "checkpoint"]
            create_path(old_model_path)
            os.chdir(old_model_path)
            # keep symlinks under old_model/ as a record, and copy the files
            # into the work dir so training can restart from them
            for jj in prev_model_files:
                os.symlink(os.path.relpath(jj), os.path.basename(jj))
            os.chdir(cwd)
            for jj in prev_model_files:
                shutil.copy(jj, work_path)
    print("Training files have prepared.")
Example 2
def main(out_dir, mol_dir, rid_json, machine_json, cv_file, init_model, record_name="record.txt"):
    out_dir = os.path.abspath(out_dir)
    mol_dir = os.path.abspath(mol_dir)
    rid_json = os.path.abspath(rid_json)
    cv_file = os.path.abspath(cv_file)
    with open(rid_json, 'r') as fp:
        jdata = json.load(fp)
    record_file = os.path.join(out_dir, record_name)
    checkpoint = get_checkpoint(record_file)
    max_tasks = 10
    number_tasks = 8
    iter_numb = int(jdata['numb_iter'])
    prev_model = init_model

    if sum(checkpoint) < 0:
        print("prepare gen_rid")
        gen_rid(out_dir, mol_dir, rid_json)
    
    for iter_idx in range(iter_numb):
        if iter_idx > 0:
            prev_model = glob.glob(out_dir + "/" + make_iter_name(iter_idx-1) + "/02.train/*pb")
        for tag in range(number_tasks):
            if iter_idx * max_tasks + tag <= checkpoint[0] * max_tasks + checkpoint[1]:
                continue
            elif tag == 0:
                print("prepare gen_enhc")
                enhcMD.make_enhc(iter_idx, rid_json, prev_model, mol_dir, cv_file, base_dir=out_dir)
                record_iter(record_file, iter_idx, tag)
            elif tag == 1:
                print("run enhanced MD")
                enhcMD.run_enhc(iter_idx, rid_json, machine_json, base_dir=out_dir)
                record_iter(record_file, iter_idx, tag)
            elif tag == 2:
                print("prepare post enhc")
                enhcMD.post_enhc(iter_idx, rid_json, machine_json, base_dir=out_dir)
                record_iter(record_file, iter_idx, tag)
            elif tag == 3:
                print("prepare gen_res")
                resMD.make_res(iter_index=iter_idx, json_file=rid_json, cv_file=cv_file, mol_path=mol_dir, base_dir=out_dir)
                record_iter(record_file, iter_idx, tag)
            elif tag == 4:
                print("prepare run res")
                resMD.run_res(iter_index=iter_idx, json_file=rid_json, machine_json=machine_json, base_dir=out_dir)
                record_iter(record_file, iter_idx, tag)
            elif tag == 5:
                print("prepare post res")
                resMD.post_res(iter_index=iter_idx, json_file=rid_json, machine_json=machine_json, cv_file=cv_file, base_dir=out_dir)
                record_iter(record_file, iter_idx, tag)
            elif tag == 6:
                print("prepare gen train")
                train.make_train(iter_index=iter_idx, json_file=rid_json, base_dir=out_dir)
                record_iter(record_file, iter_idx, tag)
            elif tag == 7:
                print("prepare run train")
                train.run_train(iter_index=iter_idx, json_file=rid_json, machine_json=machine_json, cv_file=cv_file, base_dir=out_dir)
                record_iter(record_file, iter_idx, tag)
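The two record helpers are not shown on this page. A minimal sketch consistent with how main() uses them (sum(checkpoint) < 0 signals a fresh run; otherwise the pair (last_iter, last_tag) positions the restart) could look like this:

import os

def get_checkpoint(record_file):
    # Return the last recorded [iter, tag] pair, or [-1, -1] for a fresh run.
    if not os.path.isfile(record_file):
        return [-1, -1]
    with open(record_file) as fp:
        lines = [ll.split() for ll in fp if ll.strip()]
    if not lines:
        return [-1, -1]
    return [int(lines[-1][0]), int(lines[-1][1])]

def record_iter(record_file, iter_idx, tag):
    # Append a completed (iteration, task) pair to the record file.
    with open(record_file, "a") as fp:
        fp.write("%d %d\n" % (iter_idx, tag))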
Example 3
def run_enhc(iter_index,
             json_file,
             machine_json,
             base_dir='./'):
    json_file = os.path.abspath(json_file)
    base_dir = os.path.abspath(base_dir) + "/"
    iter_name = make_iter_name(iter_index)
    work_path = base_dir + iter_name + "/" + enhc_name + "/"

    with open(json_file, 'r') as fp:
        jdata = json.load(fp)
    gmx_prep = jdata["gmx_prep"]
    gmx_run = jdata["gmx_run"]
    enhc_thread = jdata["enhc_thread"]
    gmx_run = gmx_run + (" -nt %d" % enhc_thread)
    gmx_prep_log = "gmx_grompp.log"
    gmx_run_log = "gmx_mdrun.log"
    # assuming at least one walker
    graph_files = glob.glob(work_path + (make_walker_name(0)) + "/*.pb")
    if len(graph_files) != 0:
        gmx_run = gmx_run + " -plumed " + enhc_plm
    else:
        gmx_run = gmx_run + " -plumed " + enhc_bf_plm
    gmx_prep_cmd = cmd_append_log(gmx_prep, gmx_prep_log)
    gmx_run_cmd = cmd_append_log(gmx_run, gmx_run_log)
    numb_walkers = jdata["numb_walkers"]

    all_task = list(
        filter(lambda x: os.path.isdir(x),
               glob.glob(work_path + "/[0-9]*[0-9]")))
    all_task.sort()

    all_task_basedir = [os.path.relpath(ii, work_path) for ii in all_task]
    print('run_enhc:work_path', work_path)
    print('run_enhc:gmx_prep_cmd:', gmx_prep_cmd)
    print('run_enhc:gmx_run_cmd:', gmx_run_cmd)
    print('run_enhc:all_task:', all_task)
    print('run_enhc:all_task_basedir:', all_task_basedir)

    machine = set_machine(machine_json, target="enhcMD")
    resources = set_resource(machine_json, target="enhcMD")

    gmx_prep_task = [Task(command=gmx_prep_cmd, task_work_path=ii,
                          outlog='gmx_grompp.log', errlog='gmx_grompp.log') for ii in all_task_basedir]
    gmx_prep_submission = Submission(
        work_base=work_path, machine=machine, resources=resources, task_list=gmx_prep_task)

    gmx_prep_submission.run_submission()

    gmx_run_task = [Task(command=gmx_run_cmd, task_work_path=ii,
                         outlog='gmx_mdrun.log', errlog='gmx_mdrun.log') for ii in all_task_basedir]
    gmx_run_submission = Submission(
        work_base=work_path, machine=machine, resources=resources, task_list=gmx_run_task)
    gmx_run_submission.run_submission()
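cmd_append_log is also defined elsewhere; a plausible sketch, assuming it merely redirects the command's output to the given log file (the Task objects then reuse the same file as outlog/errlog):

def cmd_append_log(cmd, log_file):
    # Assumed behavior: send both stdout and stderr of the shell command to log_file.
    return cmd + " 1> " + log_file + " 2> " + log_file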
Example 4
def post_enhc(iter_index, json_file, machine_json, base_dir="./"):
    base_dir = os.path.abspath(base_dir) + "/"
    iter_name = make_iter_name(iter_index)
    work_path = base_dir + iter_name + "/" + enhc_name + "/"
    json_file = os.path.abspath(json_file)
    with open(json_file, 'r') as fp:
        jdata = json.load(fp)
    gmx_split = jdata["gmx_split_traj"]
    gmx_split_log = "gmx_split.log"
    gmx_split_cmd = cmd_append_log(gmx_split, gmx_split_log)

    all_task = list(
        filter(lambda x: os.path.isdir(x),
               glob.glob(work_path + "/[0-9]*[0-9]")))
    all_task.sort()

    cwd = os.getcwd()
    numb_walkers = jdata["numb_walkers"]
    for ii in range(numb_walkers):
        walker_path = work_path + make_walker_name(ii) + "/"
        os.chdir(walker_path)
        if os.path.isdir("confs"):
            shutil.rmtree("confs")
        os.makedirs("confs")
        os.chdir(cwd)

    print('rid.py:post_enhc:gmx_split_cmd', gmx_split_cmd)
    print('rid.py:post_enhc:work path', work_path)

    machine = set_machine(machine_json, target="post_enhc")
    resources = set_resource(machine_json, target="post_enhc")
    all_task_relpath = [os.path.relpath(ii, work_path) for ii in all_task]
    gmx_split_task = [
        Task(command=gmx_split_cmd,
             task_work_path=ii,
             outlog='gmx_split.log',
             errlog='gmx_split.log') for ii in all_task_relpath
    ]
    gmx_split_submission = Submission(work_base=work_path,
                                      resources=resources,
                                      machine=machine,
                                      task_list=gmx_split_task)
    gmx_split_submission.run_submission()

    for ii in range(numb_walkers):
        walker_path = work_path + make_walker_name(ii) + "/"
        angles = np.loadtxt(walker_path + enhc_out_plm)
        # drop the first (time) column of the PLUMED output, keep only CV values
        np.savetxt(walker_path + enhc_out_angle, angles[:, 1:], fmt="%.6f")
    print("Post process of enhanced sampling finished.")
Example 5
def check_new_data(iter_index, train_path, base_path):
    # check if new data is empty
    new_data_file = os.path.join(train_path, 'data/data.new.raw')
    if os.stat(new_data_file).st_size == 0:
        prev_iter_index = iter_index - 1
        prev_train_path = base_path + \
            make_iter_name(prev_iter_index) + "/" + train_name + "/"
        prev_models = glob.glob(prev_train_path + "*.pb")
        for ii in prev_models:
            model_name = os.path.basename(ii)
            os.symlink(ii, os.path.join(train_path, model_name))
        return True
    else:
        return False
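run_train (Example 12) uses this function as an early-exit guard:

# Usage pattern from run_train: skip training entirely when no new data
# was produced; the previous iteration's models are reused via symlinks.
if check_new_data(iter_index, train_path, base_dir):
    return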
Example 6
def collect_data(iter_index, base_dir):
    iter_name = make_iter_name(iter_index)
    train_path = base_dir + iter_name + "/" + train_name + "/"
    data_path = train_path + "data/"
    data_file = train_path + "data/data.raw"
    data_old_file = train_path + "data/data.old.raw"
    data_new_file = train_path + "data/data.new.raw"
    cwd = os.getcwd() + "/"
    # collect data
    log_task("collect data upto %d" % (iter_index))
    if iter_index == 0:
        ii = 0
        this_raw = base_dir + make_iter_name(ii) + "/" + res_name + "/data.raw"
        os.chdir(data_path)
        os.symlink(os.path.relpath(this_raw), os.path.basename(data_new_file))
        os.symlink(os.path.basename(data_new_file),
                   os.path.basename(data_file))
        os.chdir(cwd)
        open(data_old_file, "w").close()
    else:
        prev_iter_index = iter_index - 1
        prev_data_file = base_dir + \
            make_iter_name(prev_iter_index) + "/" + \
            train_name + "/data/data.raw"
        this_raw = base_dir + \
            make_iter_name(iter_index) + "/" + res_name + "/data.raw"
        os.chdir(data_path)
        os.symlink(os.path.relpath(prev_data_file),
                   os.path.basename(data_old_file))
        os.symlink(os.path.relpath(this_raw), os.path.basename(data_new_file))
        os.chdir(cwd)
        with open(data_file, "wb") as fo:
            with open(data_old_file, "rb") as f0, open(data_new_file,
                                                       "rb") as f1:
                shutil.copyfileobj(f0, fo)
                shutil.copyfileobj(f1, fo)
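For an iteration beyond the first, the data directory therefore ends up with two symlinks and one concatenated file, roughly (a sketch of the layout, not verbatim output):

# data.old.raw -> previous iteration's train/data/data.raw   (symlink)
# data.new.raw -> this iteration's restrained-MD data.raw    (symlink)
# data.raw     =  data.old.raw followed by data.new.raw (byte-wise concatenation)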
Example 7
def post_res(iter_index, json_file, machine_json, cv_file, base_dir="./"):
    json_file = os.path.abspath(json_file)
    machine_json = os.path.abspath(machine_json)
    cv_file = os.path.abspath(cv_file)
    base_dir = os.path.abspath(base_dir) + "/"
    iter_name = make_iter_name(iter_index)
    res_path = base_dir + iter_name + "/" + res_name + "/"
    cwd = os.getcwd()

    with open(json_file, 'r') as fp:
        jdata = json.load(fp)

    os.chdir(res_path)
    all_task = glob.glob("/[0-9]*[0-9]")
    all_task = list(
        filter(lambda x: os.path.isdir(x), glob.glob("[0-9]*[0-9]")))
    if len(all_task) == 0:
        np.savetxt(res_path + 'data.raw', [], fmt="%.6e")
        os.chdir(cwd)
        return
    all_task.sort()
    all_task_reldir = [os.path.relpath(ii, res_path) for ii in all_task]

    centers = []
    force = []
    ndim = 0

    _conf_file = os.path.abspath(all_task[0] + "/conf.gro")
    cv_dim_list = cal_cv_dim(_conf_file, cv_file)
    cv_dih_dim = cv_dim_list[0]

    cmpf_cmd = "python3 {}/cmpf.py".format(LIB_PATH)
    cmpf_cmd += " -c %d" % cv_dih_dim
    cmpf_log = "cmpf.log"

    print("rid.post_res.post_res:cmpf_cmd:", cmpf_cmd)

    cmpf_resources = set_resource(machine_json, target="cmpf")
    machine = set_machine(machine_json, target="cmpf")

    cmpf_task = [
        Task(command=cmpf_cmd,
             task_work_path="{}".format(ii),
             outlog=cmpf_log,
             errlog=cmpf_log) for ii in all_task_reldir
    ]
    cmpf_submission = Submission(work_base=res_path,
                                 machine=machine,
                                 resources=cmpf_resources,
                                 task_list=cmpf_task)
    cmpf_submission.run_submission()
    print('cmpf done')

    abs_res_path = os.getcwd()
    for work_path in all_task:
        os.chdir(work_path)
        this_centers = np.loadtxt('centers.out')
        centers = np.append(centers, this_centers)
        this_force = np.loadtxt('force.out')
        force = np.append(force, this_force)
        ndim = this_force.size
        assert (ndim == this_centers.size
                ), "center size differs from force size in " + work_path
        os.chdir(abs_res_path)

    os.chdir(cwd)
    centers = np.reshape(centers, [-1, ndim])
    force = np.reshape(force, [-1, ndim])
    data = np.concatenate((centers, force), axis=1)
    np.savetxt(res_path + 'data.raw', data, fmt="%.6e")

    norm_force = np.linalg.norm(force, axis=1)
    log_task("min|f| = %e  max|f| = %e  avg|f| = %e" %
             (np.min(norm_force), np.max(norm_force), np.average(norm_force)))
    print("min|f| = %e  max|f| = %e  avg|f| = %e" %
          (np.min(norm_force), np.max(norm_force), np.average(norm_force)))
    print('Saving cmpf finished!')
    print("Post process of restrained MD finished.")
    print(os.getcwd())
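Each row of the resulting data.raw holds a restraint center followed by the corresponding mean-force estimate, each ndim columns wide. A short sketch for reading it back (assumes a non-empty file):

import numpy as np

data = np.loadtxt("data.raw")   # shape: (n_frames, 2 * ndim)
ndim = data.shape[1] // 2
centers, forces = data[:, :ndim], data[:, ndim:]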
Example 8
def run_res(iter_index, json_file, machine_json, base_dir="./"):
    json_file = os.path.abspath(json_file)
    with open(json_file, 'r') as fp:
        jdata = json.load(fp)
    gmx_prep = jdata["gmx_prep"]
    gmx_run = jdata["gmx_run"]
    res_thread = jdata["res_thread"]
    gmx_run = gmx_run + (" -nt %d" % res_thread)
    gmx_run = gmx_run + " -plumed " + res_plm
    # gmx_cont_run = gmx_run + " -cpi state.cpt"
    gmx_prep_log = "gmx_grompp.log"
    gmx_run_log = "gmx_mdrun.log"
    gmx_prep_cmd = cmd_append_log(gmx_prep, gmx_prep_log)
    gmx_run_cmd = cmd_append_log(gmx_run, gmx_run_log)
    # gmx_cont_run_cmd = cmd_append_log (gmx_cont_run, gmx_run_log)

    base_dir = os.path.abspath(base_dir) + "/"
    iter_name = make_iter_name(iter_index)
    res_path = base_dir + iter_name + "/" + res_name + "/"

    if not os.path.isdir(res_path):
        raise RuntimeError("restrained simulation path not found (%s)." %
                           res_path)

    all_task = list(
        filter(lambda x: os.path.isdir(x),
               glob.glob(res_path + "/[0-9]*[0-9]")))
    print('run_res:all_task_propose:', all_task)
    print('run_res:gmx_prep_cmd:', gmx_prep_cmd)
    print('run_res:gmx_run_cmd:', gmx_run_cmd)
    # print('run_res:gmx_cont_run_cmd:', gmx_cont_run_cmd)

    if len(all_task) == 0:
        return None
    all_task.sort()
    all_task_basedir = [os.path.relpath(ii, res_path) for ii in all_task]

    res_resources = set_resource(machine_json, target="resMD")
    machine = set_machine(machine_json, target="resMD")

    gmx_prep_task = [
        Task(command=gmx_prep_cmd,
             task_work_path=ii,
             outlog='gmx_grompp.log',
             errlog='gmx_grompp.log') for ii in all_task_basedir
    ]
    gmx_prep_submission = Submission(work_base=res_path,
                                     machine=machine,
                                     resources=res_resources,
                                     task_list=gmx_prep_task)
    gmx_prep_submission.run_submission()

    gmx_run_task = [
        Task(command=gmx_run_cmd,
             task_work_path=ii,
             outlog='gmx_mdrun.log',
             errlog='gmx_mdrun.log') for ii in all_task_basedir
    ]
    gmx_run_submission = Submission(work_base=res_path,
                                    machine=machine,
                                    resources=res_resources,
                                    task_list=gmx_run_task)
    gmx_run_submission.run_submission()
Example 9
def make_res(iter_index,
             json_file,
             cv_file,
             mol_path,
             base_dir="./",
             custom_mdp=None):

    json_file = os.path.abspath(json_file)
    with open(json_file, 'r') as fp:
        jdata = json.load(fp)
    cv_file = os.path.abspath(cv_file)

    numb_walkers = jdata["numb_walkers"]
    bias_nsteps = jdata["bias_nsteps"]
    bias_frame_freq = jdata["bias_frame_freq"]
    nsteps = jdata["res_nsteps"]
    frame_freq = jdata["res_frame_freq"]
    sel_threshold = jdata["sel_threshold"]
    max_sel = jdata["max_sel"]
    cluster_threshold = jdata["cluster_threshold"]
    init_numb_cluster_upper = int(jdata["init_numb_cluster_upper"])
    init_numb_cluster_lower = int(jdata["init_numb_cluster_lower"])
    init_numb_cluster = [init_numb_cluster_lower, init_numb_cluster_upper]

    base_dir = os.path.abspath(base_dir) + "/"
    iter_name = make_iter_name(iter_index)
    enhc_path = base_dir + iter_name + "/" + enhc_name + "/"
    res_path = base_dir + iter_name + "/" + res_name + "/"
    create_path(res_path)

    cwd = os.getcwd()
    _conf_file = enhc_path + make_walker_name(0) + "/" + "conf.gro"
    cv_dim_list = cal_cv_dim(_conf_file, cv_file)
    cv_dim = sum(cv_dim_list)
    cv_dih_dim = cv_dim_list[0]
    ret_list = [True for ii in range(numb_walkers)]

    weight = jdata["cv_weight_for_cluster"]
    if isinstance(weight, list):
        assert len(weight) == cv_dim, \
            "Number of values in the weight list is not equal to the number of CVs."
    elif isinstance(weight, (float, int)):
        assert weight != 0
    else:
        raise TypeError(
            "Invalid type of CV weight for clustering. Please use int, float or list instead."
        )

    # check if we have graph in enhc
    for walker_idx in range(numb_walkers):
        cls_sel = None
        walker_path = enhc_path + walker_format % walker_idx + "/"
        graph_files = glob.glob(walker_path + "*.pb")
        if len(graph_files) != 0:
            cluster_threshold = np.loadtxt(base_dir + "cluster_threshold.dat")
            os.chdir(walker_path)
            models = glob.glob("*.pb")
            std_message = make_std(cv_dim,
                                   dataset=enhc_out_angle,
                                   models=models,
                                   threshold=sel_threshold,
                                   output="sel.out",
                                   output_angle="sel.angle.out")
            os.system('echo "{}" > sel.log'.format(std_message))
            log_task("select with threshold %f" % sel_threshold)
            os.chdir(cwd)

            sel_idx = []
            sel_angles = np.array([])
            with open(walker_path + "sel.out") as fp:
                for line in fp:
                    sel_idx += [int(x) for x in line.split()]
            if len(sel_idx) != 0:
                sel_angles = np.reshape(
                    np.loadtxt(walker_path + 'sel.angle.out'), [-1, cv_dim])
            else:
                np.savetxt(walker_path + 'num_of_cluster.dat', [0], fmt='%d')
                np.savetxt(walker_path + 'cls.sel.out', [], fmt='%d')
                continue
        else:
            cluster_threshold = jdata["cluster_threshold"]
            sel_idx = range(
                len(glob.glob(walker_path + enhc_out_conf + "conf*gro")))
            sel_angles = np.loadtxt(walker_path + enhc_out_angle)
            sel_angles = np.reshape(sel_angles, [-1, cv_dim])
            np.savetxt(walker_path + 'sel.out', sel_idx, fmt='%d')
            np.savetxt(walker_path + 'sel.angle.out', sel_angles, fmt='%.6f')
            cls_sel, cluster_threshold = make_threshold(
                walker_idx, walker_path, base_dir, sel_angles,
                cluster_threshold, init_numb_cluster, cv_dih_dim, weight)
        if cls_sel is None:
            print(sel_angles, cluster_threshold, cv_dih_dim)
            cls_sel = sel_from_cluster(sel_angles, cluster_threshold,
                                       cv_dih_dim, weight)

        conf_start = 0
        conf_every = 1

        sel_idx = np.array(sel_idx, dtype=int)
        assert (
            len(sel_idx) == sel_angles.shape[0]
        ), "{} selected indexes don't match {} selected angles.".format(
            len(sel_idx), sel_angles.shape[0])
        sel_idx = config_cls(sel_idx, cls_sel, max_sel, walker_path,
                             cluster_threshold, sel_angles)

        res_angles = np.loadtxt(walker_path + enhc_out_angle)
        res_angles = np.reshape(res_angles, [-1, cv_dim])
        res_angles = res_angles[sel_idx]
        np.savetxt(walker_path + 'cls.sel.out', sel_idx, fmt='%d')
        np.savetxt(walker_path + 'cls.sel.angle.out', res_angles, fmt='%.6f')
        res_confs = []
        for ii in sel_idx:
            res_confs.append(walker_path + enhc_out_conf + ("conf%d.gro" % ii))

        assert (len(res_confs) == res_angles.shape[0]
                ), "number of enhc out conf does not match out angle"
        assert (len(sel_idx) == res_angles.shape[0]
                ), "number of enhc out conf does not match number sel"
        nconf = len(res_confs)
        if nconf == 0:
            ret_list[walker_idx] = False
            continue

        sel_list = make_sel_list(nconf, sel_idx)
        log_task("selected %d confs, indexes: %s" % (nconf, sel_list))
        make_conf(nconf,
                  res_path,
                  walker_idx,
                  walker_path,
                  sel_idx,
                  jdata,
                  mol_path,
                  conf_start=0,
                  conf_every=1,
                  custom_mdp=custom_mdp)
        make_res_plumed(nconf,
                        jdata,
                        res_path,
                        walker_idx,
                        sel_idx,
                        res_angles,
                        _conf_file,
                        cv_file,
                        conf_start=0,
                        conf_every=1)
    print("Restrained MD has been prepared.")
Example 10
def make_enhc(iter_index,
              json_file,
              graph_files,
              mol_dir,
              cv_file,
              base_dir='./',
              custom_mdp=None):
    base_dir = os.path.abspath(base_dir) + "/"
    json_file = os.path.abspath(json_file)
    cv_file = os.path.abspath(cv_file)
    graph_files.sort()
    with open(json_file, 'r') as fp:
        jdata = json.load(fp)
    numb_walkers = jdata["numb_walkers"]
    enhc_trust_lvl_1 = jdata["bias_trust_lvl_1"]
    enhc_trust_lvl_2 = jdata["bias_trust_lvl_2"]
    nsteps = jdata["bias_nsteps"]
    frame_freq = jdata["bias_frame_freq"]
    num_of_cluster_threshold = jdata["num_of_cluster_threshold"]
    dt = jdata["bias_dt"]
    temperature = jdata["bias_temperature"]

    iter_name = make_iter_name(iter_index)
    work_path = base_dir + iter_name + "/" + enhc_name + "/"
    mol_path = os.path.abspath(mol_dir) + "/"

    conf_list = glob.glob(mol_path + "*gro")
    conf_list.sort()
    assert (len(conf_list) >=
            numb_walkers), "not enough conf files in mol dir %s" % mol_path

    create_path(work_path)

    mol_files = ["topol.top"]
    for walker_idx in range(numb_walkers):
        walker_path = work_path + make_walker_name(walker_idx) + "/"
        create_path(walker_path)

        make_grompp(walker_path + "grompp.mdp", "bias", nsteps,
                    frame_freq, temperature=temperature, dt=dt, define='', custom_mdp=custom_mdp)
        # make_grompp(walker_path + "grompp_restraint.mdp", "res", nsteps, frame_freq, temperature=temperature, dt=dt, define='-DPOSRE')

        for ii in mol_files:
            checkfile(walker_path + ii)
            shutil.copy(mol_path + ii, walker_path)

        # copy conf file
        conf_file = conf_list[walker_idx]
        checkfile(walker_path + "conf.gro")
        shutil.copy(conf_file, walker_path + "conf.gro")
        checkfile(walker_path + "conf_init.gro")
        shutil.copy(conf_file, walker_path + "conf_init.gro")

        # if have prev confout.gro, use as init conf
        if iter_index > 0:
            prev_enhc_path = base_dir + \
                make_iter_name(iter_index-1) + "/" + enhc_name + \
                "/" + make_walker_name(walker_idx) + "/"
            prev_enhc_path = os.path.abspath(prev_enhc_path) + "/"
            if os.path.isfile(prev_enhc_path + "confout.gro"):
                os.remove(walker_path + "conf.gro")
                rel_prev_enhc_path = os.path.relpath(
                    prev_enhc_path + "confout.gro", walker_path)
                os.symlink(rel_prev_enhc_path, walker_path + "conf.gro")
            else:
                raise RuntimeError(
                    "cannot find prev output conf file  " + prev_enhc_path + 'confout.gro')
            log_task("use conf of iter " + make_iter_name(iter_index -
                                                          1) + " walker " + make_walker_name(walker_idx))

            enhc_trust_lvl_1, enhc_trust_lvl_2 = adjust_lvl(
                prev_enhc_path, num_of_cluster_threshold, jdata)

        np.savetxt(walker_path+'trust_lvl1.dat',
                   [enhc_trust_lvl_1], fmt='%.6f')

        make_plumed(walker_path, "dpbias", conf_file, cv_file)
        make_plumed(walker_path, "bf", conf_file, cv_file)

        prep_graph(graph_files, walker_path)
        # config plumed
        graph_list = get_graph_list(graph_files)
        conf_enhc_plumed(walker_path + enhc_plm, "enhc", graph_list, enhc_trust_lvl_1=enhc_trust_lvl_1,
                         enhc_trust_lvl_2=enhc_trust_lvl_2, frame_freq=frame_freq, enhc_out_plm=enhc_out_plm)
        conf_enhc_plumed(walker_path + enhc_bf_plm, "bf", graph_list,
                         frame_freq=frame_freq, enhc_out_plm=enhc_out_plm)

        if len(graph_list) == 0:
            log_task("brute force MD without NN acc")
        else:
            log_task("use NN model(s): " + graph_list)
            log_task("set trust l1 and l2: %f %f" %
                     (enhc_trust_lvl_1, enhc_trust_lvl_2))
    print("Enhanced sampling has prepared.")
Example 11
def test_make_name(self):
    self.assertTrue(utils.make_iter_name(5) == "iter.000005")
    self.assertTrue(utils.make_walker_name(5) == "005")
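The helpers under test are not shown on this page; implementations consistent with these assertions would be:

def make_iter_name(iter_index):
    # e.g. make_iter_name(5) == "iter.000005"
    return "iter.%06d" % iter_index

def make_walker_name(walker_index):
    # e.g. make_walker_name(5) == "005"
    return "%03d" % walker_index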
Example 12
def run_train(iter_index, json_file, machine_json, cv_file, base_dir="./"):
    json_file = os.path.abspath(json_file)
    cv_file = os.path.abspath(cv_file)
    with open(json_file, 'r') as fp:
        jdata = json.load(fp)
    numb_model = jdata["numb_model"]
    train_thread = jdata["train_thread"]
    res_iter = jdata["res_iter"]
    base_dir = os.path.abspath(base_dir) + "/"
    iter_name = make_iter_name(iter_index)
    train_path = base_dir + iter_name + "/" + train_name + "/"
    if check_new_data(iter_index, train_path, base_dir):
        return

    enhc_path = base_dir + iter_name + "/" + enhc_name + "/"
    _conf_file = enhc_path + "000/conf.gro"
    cv_dim_list = cal_cv_dim(_conf_file, cv_file)

    cwd = os.getcwd()
    neurons = jdata["neurons"]
    batch_size = jdata["batch_size"]
    if iter_index < res_iter:
        numb_epoches = jdata["numb_epoches"]
        starter_lr = jdata["starter_lr"]
        decay_steps = jdata["decay_steps"]
        decay_rate = jdata["decay_rate"]
        cmdl_args = ""
    else:
        numb_epoches = jdata["res_numb_epoches"]
        starter_lr = jdata["res_starter_lr"]
        decay_steps = jdata["res_decay_steps"]
        decay_rate = jdata["res_decay_rate"]
        old_ratio = jdata["res_olddata_ratio"]
        cmdl_args = " --restart --use-mix --old-ratio %f " % old_ratio

    if jdata["resnet"]:
        cmdl_args += " --resnet "
    cmdl_args += " -n "
    for nn in neurons:
        cmdl_args += "%d " % nn
    cmdl_args += " -c "
    for cv_dim in cv_dim_list:
        cmdl_args += "%d " % cv_dim
    cmdl_args += " -b " + str(batch_size)
    cmdl_args += " -e " + str(numb_epoches)
    cmdl_args += " -l " + str(starter_lr)
    cmdl_args += " --decay-steps " + str(decay_steps)
    cmdl_args += " --decay-rate " + str(decay_rate)

    train_cmd = "python3 {}/train.py -t {:d}".format(NN_PATH, train_thread)
    train_cmd += cmdl_args
    train_cmd = cmd_append_log(train_cmd, "train.log")
    freez_cmd = "python3 {}/freeze.py -o graph.pb".format(NN_PATH)
    freez_cmd = cmd_append_log(freez_cmd, "freeze.log")
    task_dirs = [("%03d" % ii) for ii in range(numb_model)]

    print('lib.modeling.run_train:train_cmd:', train_cmd)
    print('lib.modeling.run_train:freez_cmd:', freez_cmd)
    print('lib.modeling.run_train:train_path:', train_path)
    print('lib.modeling.run_train:task_dirs:', task_dirs)

    resources = set_resource(machine_json, target="train")
    machine = set_machine(machine_json, target="train")

    train_task = [
        Task(command=train_cmd,
             task_work_path=ii,
             outlog='train.log',
             errlog='train.log') for ii in task_dirs
    ]
    train_submission = Submission(work_base=train_path,
                                  machine=machine,
                                  resources=resources,
                                  task_list=train_task)
    train_submission.run_submission()

    freez_task = [
        Task(command=freez_cmd,
             task_work_path=ii,
             outlog='freeze.log',
             errlog='freeze.log') for ii in task_dirs
    ]
    freez_submission = Submission(work_base=train_path,
                                  machine=machine,
                                  resources=resources,
                                  task_list=freez_task)
    freez_submission.run_submission()

    os.chdir(train_path)
    for ii in range(numb_model):
        os.symlink("%03d/graph.pb" % ii, "graph.%03d.pb" % ii)
    os.chdir(cwd)

    print("Training finished!")