Example #1
def run_temp(iter_index, json_file):
    iter_name = make_iter_name(iter_index)
    work_path = iter_name + "/" + temp_name + "/"

    with open(json_file, 'r') as fp:
        jdata = json.load(fp)
    gmx_prep = jdata["gmx_prep"]
    gmx_run = jdata["gmx_run"]
    temp_thread = jdata["temp_thread"]
    gmx_run = gmx_run + (" -nt %d " % temp_thread)
    gmx_prep_log = "gmx_grompp.log"
    gmx_run_log = "gmx_mdrun.log"
    gmx_run = gmx_run + " -plumed " + temp_plm
    gmx_prep_cmd = cmd_append_log(gmx_prep, gmx_prep_log)
    gmx_run_cmd = cmd_append_log(gmx_run, gmx_run_log)
    numb_walkers = jdata["numb_walkers"]

    all_task = glob.glob(work_path + "/[0-9]*[0-9]")
    all_task.sort()

    global exec_machine
    exec_hosts(MachineLocal, gmx_prep_cmd, 1, all_task, None)
    if len(all_task) == 1:
        exec_hosts(MachineLocal, gmx_run_cmd, temp_thread, all_task, None)
    else:
        exec_hosts_batch(exec_machine, gmx_run_cmd, temp_thread, all_task,
                         None)
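
These examples come from the same driver module and assume module-level names (make_iter_name, make_walker_name, cmd_append_log, exec_hosts, exec_hosts_batch, exec_batch, exec_batch_group, temp_name, temp_plm, ...) defined elsewhere in the package. As a rough orientation only, minimal sketches of two helpers used throughout, inferred from how they are called above, might look like this; the real signatures and naming scheme live in the package itself:

def make_iter_name(iter_index):
    # hypothetical sketch: zero-padded per-iteration directory name,
    # e.g. make_iter_name(3) -> "iter.000003"
    return "iter.%06d" % iter_index

def cmd_append_log(cmd, log_file):
    # hypothetical sketch: redirect a command's stdout/stderr to a log file
    return cmd + " 1> " + log_file + " 2> " + log_file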
Example #2
def run_res(iter_index, json_file, exec_machine=MachineLocal):
    with open(json_file, 'r') as fp:
        jdata = json.load(fp)
    gmx_prep = jdata["gmx_prep"]
    gmx_run = jdata["gmx_run"]
    res_thread = jdata["res_thread"]
    gmx_run = gmx_run + (" -nt %d " % res_thread)
    gmx_run = gmx_run + " -plumed " + res_plm
    gmx_cont_run = gmx_run + " -cpi state.cpt "
    gmx_prep_log = "gmx_grompp.log"
    gmx_run_log = "gmx_mdrun.log"
    gmx_prep_cmd = cmd_append_log(gmx_prep, gmx_prep_log)
    gmx_run_cmd = cmd_append_log(gmx_run, gmx_run_log)
    gmx_cont_run_cmd = cmd_append_log(gmx_cont_run, gmx_run_log)
    res_group_size = jdata['res_group_size']
    batch_jobs = jdata['batch_jobs']
    batch_time_limit = jdata['batch_time_limit']
    batch_modules = jdata['batch_modules']
    batch_sources = jdata['batch_sources']
    
    iter_name = make_iter_name(iter_index)
    res_path = iter_name + "/" + res_name + "/"
    base_path = os.getcwd() + "/"

    if not os.path.isdir(res_path):
        raise RuntimeError("no restrained simulation directory found (%s)." % res_path)

    all_task_propose = glob.glob(res_path + "/[0-9]*[0-9]")
    if len(all_task_propose) == 0:
        return
    all_task_propose.sort()
    if batch_jobs:
        all_task = all_task_propose
    else:
        all_task = []
        all_cont_task = []
        for ii in all_task_propose:
            # a task with confout.gro has finished; one with only state.cpt
            # is resumed from the checkpoint via gmx_cont_run_cmd
            if not os.path.isfile(os.path.join(ii, "confout.gro")):
                if os.path.isfile(os.path.join(ii, "state.cpt")):
                    all_cont_task.append(ii)
                else:
                    all_task.append(ii)

    if batch_jobs:
        exec_hosts(MachineLocal, gmx_prep_cmd, 1, all_task, None)
        exec_batch_group(gmx_run_cmd,
                         res_thread,
                         1,
                         all_task,
                         task_args=None,
                         group_size=res_group_size,
                         time_limit=batch_time_limit,
                         modules=batch_modules,
                         sources=batch_sources)
    else:
        if len(all_task) == 1:
            exec_hosts(MachineLocal, gmx_prep_cmd, 1, all_task, None)
            exec_hosts(MachineLocal, gmx_run_cmd, res_thread, all_task, None)
        elif len(all_task) > 1:
            exec_hosts(MachineLocal, gmx_prep_cmd, 1, all_task, None)
            exec_hosts_batch(exec_machine, gmx_run_cmd, res_thread, all_task,
                             None)
        if len(all_cont_task) == 1:
            exec_hosts(MachineLocal, gmx_cont_run_cmd, res_thread,
                       all_cont_task, None)
        elif len(all_cont_task) > 1:
            exec_hosts_batch(exec_machine, gmx_cont_run_cmd, res_thread,
                             all_cont_task, None)
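
For reference, an illustrative slice of the JSON configuration run_res consumes; the keys are exactly the jdata[...] lookups above, the values are placeholders rather than recommendations:

res_config = {
    "gmx_prep": "gmx grompp",    # grompp command line
    "gmx_run": "gmx mdrun",      # mdrun command line, extended with -nt/-plumed above
    "res_thread": 8,
    "res_group_size": 4,
    "batch_jobs": False,
    "batch_time_limit": "24:00:00",
    "batch_modules": [],
    "batch_sources": [],
}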
Example #3
def run_enhc(iter_index, json_file):
    iter_name = make_iter_name(iter_index)
    work_path = iter_name + "/" + enhc_name + "/"

    with open(json_file, 'r') as fp:
        jdata = json.load(fp)
    gmx_prep = jdata["gmx_prep"]
    gmx_run = jdata["gmx_run"]
    enhc_thread = jdata["bias_thread"]
    gmx_run = gmx_run + (" -nt %d " % enhc_thread)
    gmx_prep_log = "gmx_grompp.log"
    gmx_run_log = "gmx_mdrun.log"
    # assuming at least one walker
    graph_files = glob.glob(work_path + (make_walker_name(0)) + "/*.pb")
    if len(graph_files) != 0:
        gmx_run = gmx_run + " -plumed " + enhc_plm
    else:
        gmx_run = gmx_run + " -plumed " + enhc_bf_plm
    gmx_prep_cmd = cmd_append_log(gmx_prep, gmx_prep_log)
    gmx_run_cmd = cmd_append_log(gmx_run, gmx_run_log)
    numb_walkers = jdata["numb_walkers"]
    batch_jobs = jdata['batch_jobs']
    batch_time_limit = jdata['batch_time_limit']
    batch_modules = jdata['batch_modules']
    batch_sources = jdata['batch_sources']

    all_task = glob.glob(work_path + "/[0-9]*[0-9]")
    all_task.sort()

    global exec_machine
    exec_hosts(MachineLocal, gmx_prep_cmd, 1, all_task, None)
    if batch_jobs:
        exec_batch(gmx_run_cmd,
                   enhc_thread,
                   1,
                   all_task,
                   task_args=None,
                   time_limit=batch_time_limit,
                   modules=batch_modules,
                   sources=batch_sources)
    else:
        if len(all_task) == 1:
            exec_hosts(MachineLocal, gmx_run_cmd, enhc_thread, all_task, None)
        else:
            exec_hosts_batch(exec_machine, gmx_run_cmd, enhc_thread, all_task,
                             None)
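
run_enhc probes the first walker's directory for *.pb graph files to decide between the biased PLUMED input (enhc_plm) and the brute-force one (enhc_bf_plm). A minimal sketch of the assumed make_walker_name helper, inferred from the "[0-9]*[0-9]" task glob (hypothetical; the real helper lives in the package):

def make_walker_name(walker_index):
    # hypothetical sketch: zero-padded per-walker subdirectory name,
    # e.g. make_walker_name(0) -> "000"
    return "%03d" % walker_index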
Example #4
def run_enhc(iter_index, json_file):
    iter_name = make_iter_name(iter_index)
    work_path = iter_name + "/" + enhc_name + "/"

    with open(json_file, 'r') as fp:
        jdata = json.load(fp)
    bPosre = jdata.get("gmx_posre", False)
    gmx_prep = jdata["gmx_prep"]
    if bPosre:
        gmx_prep += " -f grompp_restraint.mdp -r conf_init.gro"
    gmx_run = jdata["gmx_run"]
    enhc_thread = jdata["bias_thread"]
    gmx_run = gmx_run + (" -nt %d " % enhc_thread)
    gmx_prep_log = "gmx_grompp.log"
    gmx_run_log = "gmx_mdrun.log"
    # assuming at least one walker
    graph_files = glob.glob(work_path + (make_walker_name(0)) + "/*.pb")
    if len(graph_files) != 0:
        gmx_run = gmx_run + " -plumed " + enhc_plm
    else:
        gmx_run = gmx_run + " -plumed " + enhc_bf_plm
    gmx_prep_cmd = cmd_append_log(gmx_prep, gmx_prep_log)
    gmx_run_cmd = cmd_append_log(gmx_run, gmx_run_log)
    numb_walkers = jdata["numb_walkers"]
    batch_jobs = jdata['batch_jobs']
    batch_time_limit = jdata['batch_time_limit']
    batch_modules = jdata['batch_modules']
    batch_sources = jdata['batch_sources']

    # print('debug', glob.glob(work_path + "/[0-9]*[0-9]"))
    # all_task = glob.glob(work_path + "/[0-9]*[0-9]")
    all_task = [ii for ii in glob.glob(work_path + "/[0-9]*[0-9]")
                if os.path.isdir(ii)]
    all_task.sort()

    # all_task_basedir = [os.path.relpath(ii, work_path) for ii in all_task]
    # print('run_enhc:work_path', work_path)
    # print('run_enhc:gmx_prep_cmd:', gmx_prep_cmd)
    # print('run_enhc:gmx_run_cmd:', gmx_run_cmd)
    # print('run_enhc:all_task:', all_task)
    # print('run_enhc:all_task_basedir:', all_task_basedir)
    # print('run_enhc:batch_jobs:', batch_jobs)
    #
    # lazy_local_context = LazyLocalContext(local_root='./', work_profile=None)
    # # pbs = PBS(context=lazy_local_context)
    # slurm = Slurm(context=lazy_local_context)
    # gmx_prep_task = [Task(command=gmx_prep_cmd, task_work_path=ii, outlog='gmx_grompp.log', errlog='gmx_grompp.err') for
    #                  ii in all_task_basedir]
    # gmx_prep_submission = Submission(work_base=work_path, resources=resources, batch=slurm, task_list=gmx_prep_task)
    #
    # gmx_prep_submission.run_submission()
    #
    # gmx_run_task = [Task(command=gmx_run_cmd, task_work_path=ii, outlog='gmx_mdrun.log', errlog='gmx_mdrun.log') for ii
    #                 in all_task_basedir]
    # gmx_run_submission = Submission(work_base=work_path, resources=resources, batch=slurm, task_list=gmx_run_task)
    # gmx_run_submission.run_submission()

    global exec_machine
    exec_hosts(MachineLocal, gmx_prep_cmd, 1, all_task, None)
    if batch_jobs:
        exec_batch(gmx_run_cmd,
                   enhc_thread,
                   1,
                   all_task,
                   task_args=None,
                   time_limit=batch_time_limit,
                   modules=batch_modules,
                   sources=batch_sources)
    else:
        if len(all_task) == 1:
            exec_hosts(MachineLocal, gmx_run_cmd, enhc_thread, all_task, None)
        else:
            exec_hosts_batch(exec_machine, gmx_run_cmd, enhc_thread, all_task,
                             None)
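
This variant adds position-restraint support via the gmx_posre flag. An illustrative configuration fragment covering the keys this run_enhc reads (values are placeholders):

enhc_config = {
    "gmx_prep": "gmx grompp",
    "gmx_run": "gmx mdrun",
    "gmx_posre": True,       # switches grompp to grompp_restraint.mdp -r conf_init.gro
    "bias_thread": 8,
    "numb_walkers": 4,
    "batch_jobs": False,
    "batch_time_limit": "24:00:00",
    "batch_modules": [],
    "batch_sources": [],
}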
Example #5
def run_train(iter_index, json_file, exec_machine=MachineLocal):
    with open(json_file, 'r') as fp:
        jdata = json.load(fp)
    numb_model = jdata["numb_model"]
    train_thread = jdata["train_thread"]
    res_iter = jdata["res_iter"]

    iter_name = make_iter_name(iter_index)
    train_path = iter_name + "/" + train_name + "/"
    base_path = os.getcwd() + "/"

    # check if new data is empty
    new_data_file = os.path.join(train_path, 'data/data.new.raw')
    if os.stat(new_data_file).st_size == 0:
        prev_iter_index = iter_index - 1
        prev_train_path = base_path + make_iter_name(
            prev_iter_index) + "/" + train_name + "/"
        prev_models = glob.glob(prev_train_path + "*.pb")
        for ii in prev_models:
            model_name = os.path.basename(ii)
            os.symlink(ii, os.path.join(train_path, model_name))
        return

    neurons = jdata["neurons"]
    batch_size = jdata["batch_size"]
    if iter_index < res_iter:
        numb_epoches = jdata["numb_epoches"]
        starter_lr = jdata["starter_lr"]
        decay_steps = jdata["decay_steps"]
        decay_rate = jdata["decay_rate"]
        cmdl_args = ""
    else:
        numb_epoches = jdata["res_numb_epoches"]
        starter_lr = jdata["res_starter_lr"]
        decay_steps = jdata["res_decay_steps"]
        decay_rate = jdata["res_decay_rate"]
        old_ratio = jdata["res_olddata_ratio"]
        cmdl_args = " --restart --use-mix --old-ratio %f " % old_ratio

    if jdata["resnet"]:
        cmdl_args += " --resnet "
    cmdl_args += " -n "
    for nn in neurons:
        cmdl_args += "%d " % nn
    cmdl_args += " -b " + str(batch_size)
    cmdl_args += " -e " + str(numb_epoches)
    cmdl_args += " -l " + str(starter_lr)
    cmdl_args += " --decay-steps " + str(decay_steps)
    cmdl_args += " --decay-rate " + str(decay_rate)

    train_cmd = "../main.py -t %d" % train_thread
    train_cmd += cmdl_args
    train_cmd = cmd_append_log(train_cmd, "train.log")
    freez_cmd = "../freeze.py -o graph.pb"
    freez_cmd = cmd_append_log(freez_cmd, "freeze.log")
    task_dirs = [("%03d" % ii) for ii in range(numb_model)]

    batch_jobs = jdata['batch_jobs']
    batch_time_limit = jdata['batch_time_limit']
    batch_modules = jdata['batch_modules']
    batch_sources = jdata['batch_sources']

    os.chdir(train_path)
    if batch_jobs:
        exec_batch(train_cmd,
                   train_thread,
                   1,
                   task_dirs,
                   task_args=None,
                   time_limit=batch_time_limit,
                   modules=batch_modules,
                   sources=batch_sources)
    else:
        if len(task_dirs) == 1:
            exec_hosts(MachineLocal, train_cmd, train_thread, task_dirs, None)
        else:
            exec_hosts_batch(exec_machine, train_cmd, train_thread, task_dirs,
                             None)

    exec_hosts(MachineLocal, freez_cmd, 1, task_dirs, None)
    for ii in range(numb_model):
        os.symlink("%03d/graph.pb" % ii, "graph.%03d.pb" % ii)
    os.chdir(base_path)
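
For concreteness, this is the command the string concatenation above assembles, under illustrative settings (train_thread=4, resnet=True, neurons=[64, 64], batch_size=32, numb_epoches=100, starter_lr=0.001, decay_steps=100, decay_rate=0.96), before cmd_append_log wraps it with the log redirection:

train_cmd = ("../main.py -t 4 --resnet  -n 64 64 "
             " -b 32 -e 100 -l 0.001 --decay-steps 100 --decay-rate 0.96")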
Example #6
def run_train(iter_index,
              json_file,
              exec_machine=MachineLocal,
              data_dir="data",
              data_name="data",
              sits_iter=False):
    with open(json_file, 'r') as fp:
        jdata = json.load(fp)
    cmd_env = jdata.get("cmd_sources", [])
    sits_param = jdata.get("sits_settings", None)

    numb_model = jdata["numb_model"]
    train_thread = jdata["train_thread"]
    res_iter = jdata["res_iter"]

    iter_name = make_iter_name(iter_index)
    if sits_param is not None and sits_iter:
        iter_name = join("sits", make_iter_name(iter_index))
    train_path = join(iter_name, train_name)
    base_path = os.getcwd() + "/"

    # check if new data is empty
    new_data_file = os.path.join(train_path, data_dir, data_name + '.new.raw')
    filesize = (os.stat(new_data_file).st_size
                if os.path.exists(new_data_file) else 0)
    if filesize == 0 and not sits_iter:
        prev_iter_index = iter_index - 1
        prev_train_path = join(base_path, make_iter_name(prev_iter_index),
                               train_name) + "/"
        prev_models = glob.glob(join(prev_train_path, "*.pb"))
        for ii in prev_models:
            model_name = os.path.basename(ii)
            os.symlink(ii, join(train_path, model_name))
        return

    neurons = jdata["neurons"]
    batch_size = jdata["batch_size"]
    if iter_index < res_iter:
        numb_epoches = jdata["numb_epoches"]
        starter_lr = jdata["starter_lr"]
        decay_steps = jdata["decay_steps"]
        decay_rate = jdata["decay_rate"]
        cmdl_args = ""
    else:
        numb_epoches = jdata["res_numb_epoches"]
        starter_lr = jdata["res_starter_lr"]
        decay_steps = jdata["res_decay_steps"]
        decay_rate = jdata["res_decay_rate"]
        old_ratio = jdata["res_olddata_ratio"]
        cmdl_args = " --restart --use-mix --old-ratio %f " % old_ratio

    if jdata["resnet"]:
        cmdl_args += " --resnet "
    cmdl_args += " -n "
    for nn in neurons:
        cmdl_args += "%d " % nn
    cmdl_args += " -b " + str(batch_size)
    cmdl_args += " -e " + str(numb_epoches)
    cmdl_args += " -l " + str(starter_lr)
    cmdl_args += " --decay-steps " + str(decay_steps)
    cmdl_args += " --decay-rate " + str(decay_rate)

    train_cmd = "../main.py -t %d" % train_thread
    train_cmd += cmdl_args
    train_cmd = cmd_append_log(train_cmd, "train.log", env=cmd_env)
    freez_cmd = "../freeze.py -o graph.pb"
    freez_cmd = cmd_append_log(freez_cmd, "freeze.log", env=cmd_env)
    task_dirs = [("%03d" % ii) for ii in range(numb_model)]

    batch_jobs = jdata['batch_jobs']
    batch_time_limit = jdata['batch_time_limit']
    batch_modules = jdata['batch_modules']
    batch_sources = jdata['batch_sources']

    # print('lib.modeling.run_train:train_cmd:', train_cmd)
    # print('lib.modeling.run_train:freez_cmd:', freez_cmd)
    # print('lib.modeling.run_train:train_path:', train_path)
    # print('lib.modeling.run_train:task_dirs:', task_dirs)

    # lazy_local_context = LazyLocalContext(local_root='./', work_profile=None)
    # # pbs = PBS(context=lazy_local_context)
    # slurm = Slurm(context=lazy_local_context)

    # train_task = [Task(command=train_cmd, task_work_path=ii, outlog='train.log', errlog='train.log') for ii in
    #               task_dirs]
    # train_submission = Submission(work_base=train_path, resources=resources, batch=slurm, task_list=train_task)
    # train_submission.run_submission()

    # freez_task = [Task(command=freez_cmd, task_work_path=ii, outlog='freeze.log', errlog='freeze.log') for ii in
    #               task_dirs]
    # freez_submission = Submission(work_base=train_path, resources=resources, batch=slurm, task_list=freez_task)
    # freez_submission.run_submission()

    os.chdir(train_path)
    if batch_jobs:
        exec_batch(train_cmd,
                   train_thread,
                   1,
                   task_dirs,
                   task_args=None,
                   time_limit=batch_time_limit,
                   modules=batch_modules,
                   sources=batch_sources)
    else:
        if len(task_dirs) == 1:
            exec_hosts(MachineLocal, train_cmd, train_thread, task_dirs, None)
        else:
            exec_hosts_batch(exec_machine, train_cmd, train_thread, task_dirs,
                             None)

    # exec_hosts(MachineLocal, freez_cmd, 1, task_dirs, None)
    for task_dir in task_dirs:
        exec_hosts(MachineLocal, freez_cmd, 1, [task_dir], None)
    for ii in range(numb_model):
        os.symlink("%03d/graph.pb" % ii, "graph.%03d.pb" % ii)
    os.chdir(base_path)
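
Examples #6 through #8 pass env=cmd_env into cmd_append_log. A minimal sketch of such an env-aware variant, assuming cmd_sources is a list of shell files to source before the command (hypothetical; the real helper is part of the package):

def cmd_append_log(cmd, log_file, env=None):
    # hypothetical sketch: source environment files, then redirect
    # the command's stdout/stderr to the log
    prefix = ""
    for src in (env or []):
        prefix += "source %s; " % src
    return prefix + cmd + " 1> " + log_file + " 2> " + log_file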
Example #7
def run_res(iter_index, json_file, exec_machine=MachineLocal):
    with open(json_file, 'r') as fp:
        jdata = json.load(fp)
    cmd_env = jdata.get("cmd_sources", [])
    sits_param = jdata.get("sits_settings", None)
    gmx_prep = jdata["gmx_prep"]
    bPosre = jdata.get("gmx_posre", False)
    if sits_param is not None:
        if not bPosre:
            gmx_prep += " -f grompp_sits.mdp"
        else:
            gmx_prep += " -f grompp_sits_restraint.mdp -r conf_init.gro"
        if sits_param.get("sits_energrp", None) not in ["Protein", "MOL"]:
            gmx_prep += " -n index.ndx"
    gmx_run = jdata["gmx_run"]
    res_thread = jdata["res_thread"]
    gmx_run = gmx_run + (" -nt %d " % res_thread)
    gmx_run = gmx_run + " -plumed " + res_plm
    gmx_cont_run = gmx_run + " -cpi state.cpt "
    gmx_prep_log = "gmx_grompp.log"
    gmx_run_log = "gmx_mdrun.log"
    gmx_prep_cmd = cmd_append_log(gmx_prep, gmx_prep_log, env=cmd_env)
    gmx_run_cmd = cmd_append_log(gmx_run, gmx_run_log, env=cmd_env)
    gmx_cont_run_cmd = cmd_append_log(gmx_cont_run, gmx_run_log, env=cmd_env)
    res_group_size = jdata['res_group_size']
    batch_jobs = jdata['batch_jobs']
    batch_time_limit = jdata['batch_time_limit']
    batch_modules = jdata['batch_modules']
    batch_sources = jdata['batch_sources']

    iter_name = make_iter_name(iter_index)
    res_path = iter_name + "/" + res_name + "/"
    base_path = os.getcwd() + "/"

    if not os.path.isdir(res_path):
        raise RuntimeError("no restrained simulation directory found (%s)." %
                           res_path)

    # all_task_propose = glob.glob(res_path + "/[0-9]*[0-9]")
    # assume that
    # TODO
    all_task_propose = [ii for ii in glob.glob(res_path + "/[0-9]*[0-9]")
                        if os.path.isdir(ii)]
    # print('lib.modeling.run_res:all_task_propose:', all_task_propose)
    # print('lib.modeling.run_res:gmx_prep_cmd:', gmx_prep_cmd)
    # print('lib.modeling.run_res:gmx_run_cmd:', gmx_run_cmd)
    # print('lib.modeling.run_res:gmx_cont_run_cmd:', gmx_cont_run_cmd)
    # raise RuntimeError('lib.modeling.run_res:debug')

    if len(all_task_propose) == 0:
        return
    all_task_propose.sort()
    if batch_jobs:
        all_task = all_task_propose
    else:
        all_task = []
        all_cont_task = []
        for ii in all_task_propose:
            if not os.path.isfile(os.path.join(ii, "confout.gro")):
                if os.path.isfile(os.path.join(ii, "state.cpt")):
                    all_cont_task.append(ii)
                else:
                    all_task.append(ii)
    # if len(all_task) == 0:
    #     return None
    # all_task.sort()

    # all_task_basedir = [os.path.relpath(ii, res_path) for ii in all_task]
    # lazy_local_context = LazyLocalContext(local_root='./', work_profile=None)
    # slurm = Slurm(context=lazy_local_context)
    # # pbs = PBS(context=lazy_local_context)

    # gmx_prep_task = [Task(command=gmx_prep_cmd, task_work_path=ii, outlog='gmx_grompp.log', errlog='gmx_grompp.err') for
    #                  ii in all_task_basedir]
    # gmx_prep_submission = Submission(work_base=res_path, resources=res_resources, batch=slurm, task_list=gmx_prep_task)
    # gmx_prep_submission.run_submission()

    # gmx_run_task = [Task(command=gmx_run_cmd, task_work_path=ii, outlog='gmx_mdrun.log', errlog='gmx_mdrun.log') for ii
    #                 in all_task_basedir]
    # gmx_run_submission = Submission(work_base=res_path, resources=res_resources, batch=slurm, task_list=gmx_run_task)
    # gmx_run_submission.run_submission()

    if batch_jobs:
        exec_hosts(MachineLocal, gmx_prep_cmd, 1, all_task, None)
        exec_batch_group(gmx_run_cmd,
                         res_thread,
                         1,
                         all_task,
                         task_args=None,
                         group_size=res_group_size,
                         time_limit=batch_time_limit,
                         modules=batch_modules,
                         sources=batch_sources)
    else:
        if len(all_task) == 1:
            exec_hosts(MachineLocal, gmx_prep_cmd, 1, all_task, None)
            exec_hosts(MachineLocal, gmx_run_cmd, res_thread, all_task, None)
        elif len(all_task) > 1:
            exec_hosts(MachineLocal, gmx_prep_cmd, 1, all_task, None)
            exec_hosts_batch(exec_machine, gmx_run_cmd, res_thread, all_task,
                             None)
        if len(all_cont_task) == 1:
            exec_hosts(MachineLocal, gmx_cont_run_cmd, res_thread,
                       all_cont_task, None)
        elif len(all_cont_task) > 1:
            exec_hosts_batch(exec_machine, gmx_cont_run_cmd, res_thread,
                             all_cont_task, None)
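
This run_res variant additionally consults sits_settings when building the grompp line. An illustrative fragment matching the lookups above (values are placeholders):

sits_config_fragment = {
    "sits_settings": {
        # any group other than "Protein"/"MOL" makes grompp read "-n index.ndx"
        "sits_energrp": "Water",
    },
    "gmx_posre": False,  # False selects grompp_sits.mdp; True adds the restraint mdp and -r
}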
Example #8
def train_ori(iter_index,
              json_file,
              exec_machine=MachineLocal,
              data_dir="data",
              data_name="data000"):
    with open(json_file, 'r') as fp:
        jdata = json.load(fp)
    cmd_env = jdata.get("cmd_sources", [])
    sits_param = jdata.get("sits_settings", None)
    res_cmpf_error = jdata["res_cmpf_error"]

    train_ori_name = "03.train_ori"
    iter_name = make_iter_name(iter_index)
    res_path = iter_name + "/" + res_name + "/"
    base_path = os.getcwd() + "/"

    all_task = glob.glob(res_path + "/[0-9]*[0-9]")
    if len(all_task) == 0:
        np.savetxt(res_path + data_name + '.raw', [], fmt="%.6e")
    else:
        all_task.sort()
        centers = []
        force = []
        ndim = 0

        for work_path in all_task:
            os.chdir(work_path)
            this_centers = np.loadtxt('centers.out')
            centers = np.append(centers, this_centers)
            this_force = np.loadtxt('force_000.out')
            force = np.append(force, this_force)
            ndim = this_force.size
            assert ndim == this_centers.size, \
                "center size differs from force size in " + work_path
            os.chdir(base_path)

        centers = np.reshape(centers, [-1, ndim])
        force = np.reshape(force, [-1, ndim])
        data = np.concatenate((centers, force), axis=1)
        np.savetxt(res_path + data_name + '.raw', data, fmt="%.6e")

        norm_force = np.linalg.norm(force, axis=1)
        log_task(
            "min|f| = %e  max|f| = %e  avg|f| = %e" %
            (np.min(norm_force), np.max(norm_force), np.average(norm_force)))

    template_dir = jdata["template_dir"]
    numb_model = jdata["numb_model"]
    res_iter = jdata["res_iter"]

    iter_name = make_iter_name(iter_index)

    train_path = join(iter_name, train_ori_name)
    data_path = join(train_path, data_dir)

    data_file = join(data_path, data_name + ".raw")
    data_old_file = join(data_path, data_name + ".old.raw")
    data_new_file = join(data_path, data_name + ".new.raw")
    templ_train_path = join(template_dir, train_name)

    create_path(train_path)
    os.makedirs(data_path)
    copy_file_list(train_files, templ_train_path, train_path)
    replace(join(train_path, "model.py"), r"\./data", "./" + data_dir)
    replace(join(train_path, "model.py"), r"data\.", data_name + ".")
    replace(join(train_path, "main.py"), r"\./data", "./" + data_dir)
    replace(join(train_path, "main.py"), r"data\.raw", data_name + ".raw")

    # collect data
    log_task("collect data upto %d" % (iter_index))
    if iter_index == 0:
        ii = 0
        this_raw = join(base_path, make_iter_name(ii), res_name,
                        data_name + ".raw")
        os.chdir(data_path)
        os.symlink(os.path.relpath(this_raw), os.path.basename(data_new_file))
        os.symlink(os.path.basename(data_new_file),
                   os.path.basename(data_file))
        os.chdir(base_path)
        open(data_old_file, "w").close()
    else:
        prev_iter_index = iter_index - 1
        prev_data_file = join(base_path, make_iter_name(prev_iter_index),
                              train_ori_name, data_dir, data_name + ".raw")
        this_raw = join(base_path, make_iter_name(iter_index), res_name,
                        data_name + ".raw")
        os.chdir(data_path)
        os.symlink(os.path.relpath(prev_data_file),
                   os.path.basename(data_old_file))
        os.symlink(os.path.relpath(this_raw), os.path.basename(data_new_file))
        os.chdir(base_path)
        with open(data_file, "wb") as fo:
            with open(data_old_file, "rb") as f0, open(data_new_file,
                                                       "rb") as f1:
                shutil.copyfileobj(f0, fo)
                shutil.copyfileobj(f1, fo)

    # create train dirs
    log_task("create train dirs")
    for ii in range(numb_model):
        work_path = join(train_path, ("%03d" % ii))
        old_model_path = join(work_path, "old_model")

        create_path(work_path)
        os.chdir(work_path)
        os.symlink(join("..", data_dir), data_dir)
        os.chdir(base_path)
        if iter_index >= 1:
            prev_iter_index = iter_index - 1
            prev_iter_name = make_iter_name(prev_iter_index)
            prev_train_path = prev_iter_name + "/" + train_ori_name + "/"
            prev_train_path = os.path.abspath(prev_train_path) + "/"
            prev_work_path = prev_train_path + ("%03d/" % ii)
            prev_model_files = glob.glob(
                join(prev_work_path,
                     "model.ckpt.*")) + [join(prev_work_path, "checkpoint")]
            # prev_model_files += [join(prev_work_path, "checkpoint")]
            create_path(old_model_path)
            os.chdir(old_model_path)
            for ff in prev_model_files:
                os.symlink(os.path.relpath(ff), os.path.basename(ff))
                # shutil.copy(ff, old_model_path)
            os.chdir(base_path)
            for ff in prev_model_files:
                shutil.copy(ff, work_path)

    numb_model = jdata["numb_model"]
    train_thread = jdata["train_thread"]
    res_iter = jdata["res_iter"]

    iter_name = make_iter_name(iter_index)
    # if sits_param is not None:
    #     if sits_iter:
    #         iter_name = join("sits", make_iter_name(iter_index))
    base_path = os.getcwd() + "/"

    # check if new data is empty
    new_data_file = os.path.join(train_path, data_dir, data_name + '.new.raw')
    filesize = (os.stat(new_data_file).st_size
                if os.path.exists(new_data_file) else 0)
    if filesize == 0:
        prev_iter_index = iter_index - 1
        prev_train_path = join(base_path, make_iter_name(prev_iter_index),
                               train_ori_name) + "/"
        prev_models = glob.glob(join(prev_train_path, "*.pb"))
        for ii in prev_models:
            model_name = os.path.basename(ii)
            os.symlink(ii, join(train_path, model_name))
    else:
        neurons = jdata["neurons"]
        batch_size = jdata["batch_size"]
        if iter_index < res_iter:
            numb_epoches = jdata["numb_epoches"]
            starter_lr = jdata["starter_lr"]
            decay_steps = jdata["decay_steps"]
            decay_rate = jdata["decay_rate"]
            cmdl_args = ""
        else:
            numb_epoches = jdata["res_numb_epoches"]
            starter_lr = jdata["res_starter_lr"]
            decay_steps = jdata["res_decay_steps"]
            decay_rate = jdata["res_decay_rate"]
            old_ratio = jdata["res_olddata_ratio"]
            cmdl_args = " --restart --use-mix --old-ratio %f " % old_ratio

        if jdata["resnet"]:
            cmdl_args += " --resnet "
        cmdl_args += " -n "
        for nn in neurons:
            cmdl_args += "%d " % nn
        cmdl_args += " -b " + str(batch_size)
        cmdl_args += " -e " + str(numb_epoches)
        cmdl_args += " -l " + str(starter_lr)
        cmdl_args += " --decay-steps " + str(decay_steps)
        cmdl_args += " --decay-rate " + str(decay_rate)

        train_cmd = "../main.py -t %d" % train_thread
        train_cmd += cmdl_args
        train_cmd = cmd_append_log(train_cmd, "train.log", env=cmd_env)
        freez_cmd = "../freeze.py -o graph.pb"
        freez_cmd = cmd_append_log(freez_cmd, "freeze.log", env=cmd_env)
        task_dirs = [("%03d" % ii) for ii in range(numb_model)]

        batch_jobs = jdata['batch_jobs']
        batch_time_limit = jdata['batch_time_limit']
        batch_modules = jdata['batch_modules']
        batch_sources = jdata['batch_sources']

        os.chdir(train_path)
        if batch_jobs:
            exec_batch(train_cmd,
                       train_thread,
                       1,
                       task_dirs,
                       task_args=None,
                       time_limit=batch_time_limit,
                       modules=batch_modules,
                       sources=batch_sources)
        else:
            if len(task_dirs) == 1:
                exec_hosts(MachineLocal, train_cmd, train_thread, task_dirs,
                           None)
            else:
                exec_hosts_batch(exec_machine, train_cmd, train_thread,
                                 task_dirs, None)

        # exec_hosts(MachineLocal, freez_cmd, 1, task_dirs, None)
        for task_dir in task_dirs:
            exec_hosts(MachineLocal, freez_cmd, 1, [task_dir], None)
        for ii in range(numb_model):
            os.symlink("%03d/graph.pb" % ii, "graph.%03d.pb" % ii)
        os.chdir(base_path)
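
train_ori patches the copied training templates with a replace(file, pattern, substitution) helper; the patterns are regular expressions, hence the raw strings above. A minimal sketch of what such a helper could look like (hypothetical; the real one ships with the package):

import re

def replace(file_name, pattern, subst):
    # hypothetical sketch: in-place regex substitution over the whole file
    with open(file_name, 'r') as fp:
        contents = fp.read()
    with open(file_name, 'w') as fp:
        fp.write(re.sub(pattern, subst, contents))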