Example #1
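This snippet relies on module-level imports from its original file. A minimal sketch of what they would look like (the NumPy and TensorFlow 1.x imports are certain from the calls below; the deepmd helper import path is an assumption):

import json
import numpy as np
import tensorflow as tf  # TF 1.x API: tf.Session, tf.global_variables_initializer
# project helpers used below; the exact module path is an assumption:
# from deepmd import RunOptions, DataSystem, LearingRate, NNPModel, j_must_have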
def compute_efv(jfile):
    with open(jfile, 'r') as fp:
        jdata = json.load(fp)
    run_opt = RunOptions(None)
    systems = j_must_have(jdata, 'systems')
    set_pfx = j_must_have(jdata, 'set_prefix')
    batch_size = j_must_have(jdata, 'batch_size')
    test_size = j_must_have(jdata, 'numb_test')
    # override the json settings: evaluate a single frame per batch
    batch_size = 1
    test_size = 1
    rcut = j_must_have(jdata, 'rcut')

    data = DataSystem(systems, set_pfx, batch_size, test_size, rcut, run_opt)

    tot_numb_batches = sum(data.get_nbatches())
    lr = LearingRate(jdata, tot_numb_batches)

    model = NNPModel(jdata, run_opt=run_opt)
    model.build(data, lr)

    test_data = data.get_test()

    # feed the first numb_test frames; per-frame arrays are flattened
    # to rank-1, matching the model's placeholders
    feed_dict_test = {
        model.t_prop_c: test_data["prop_c"],
        model.t_energy: test_data["energy"][:model.numb_test],
        model.t_force: np.reshape(test_data["force"][:model.numb_test, :], [-1]),
        model.t_virial: np.reshape(test_data["virial"][:model.numb_test, :], [-1]),
        model.t_atom_ener: np.reshape(test_data["atom_ener"][:model.numb_test, :], [-1]),
        model.t_atom_pref: np.reshape(test_data["atom_pref"][:model.numb_test, :], [-1]),
        model.t_coord: np.reshape(test_data["coord"][:model.numb_test, :], [-1]),
        model.t_box: test_data["box"][:model.numb_test, :],
        model.t_type: np.reshape(test_data["type"][:model.numb_test, :], [-1]),
        model.t_natoms: test_data["natoms_vec"],
        model.t_mesh: test_data["default_mesh"],
        model.t_fparam: np.reshape(test_data["fparam"][:model.numb_test, :], [-1]),
        model.is_training: False,
    }

    # evaluate energy, force and virial for the test frames
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    [e, f, v] = sess.run([model.energy, model.force, model.virial],
                         feed_dict=feed_dict_test)
    return e, f, v
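A hedged usage sketch; 'input.json' is a hypothetical file supplying the keys the function reads (systems, set_prefix, batch_size, numb_test, rcut):

if __name__ == '__main__':
    e, f, v = compute_efv('input.json')  # hypothetical input file
    print('predicted energies:', e)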
Example #2
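This example assumes the project-helper imports sketched for Example #1, plus the standard-library modules it calls directly:

import argparse
import json
import time
import tensorflow as tf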
def _main():
    default_num_inter_threads = 0
    parser = argparse.ArgumentParser(
        description="*** Train a model. ***")
    parser.add_argument('INPUT',
                        help='the input json database')
    parser.add_argument('-t', '--inter-threads', type=int,
                        default=default_num_inter_threads,
                        help='Sets "inter_op_parallelism_threads" for TensorFlow '
                             '(default: %d); "intra_op_parallelism_threads" is taken '
                             'from the env variable OMP_NUM_THREADS.'
                             % default_num_inter_threads)
    parser.add_argument('--init-model', type=str,
                        help='Initialize the model from the provided checkpoint.')
    parser.add_argument('--restart', type=str,
                        help='Restart the training from the provided checkpoint.')
    args = parser.parse_args()

    # load json database
    with open(args.INPUT, 'r') as fp:
        jdata = json.load(fp)

    # init params and run options
    systems = j_must_have(jdata, 'systems')
    set_pfx = j_must_have(jdata, 'set_prefix')
    numb_sys = len(systems)
    seed = jdata.get('seed')
    batch_size = j_must_have(jdata, 'batch_size')
    test_size = j_must_have(jdata, 'numb_test')
    stop_batch = j_must_have(jdata, 'stop_batch')
    rcut = j_must_have(jdata, 'rcut')
    print("#")
    print("# find %d system(s): " % numb_sys)
    data = DataSystem(systems, set_pfx, batch_size, test_size, rcut)
    print("#")
    tot_numb_batches = sum(data.get_nbatches())
    lr = LearingRate(jdata, tot_numb_batches)
    final_lr = lr.value(stop_batch)
    run_opt = RunOptions(args)
    print("# run with intra_op_parallelism_threads = %d, inter_op_parallelism_threads = %d " % 
          (run_opt.num_intra_threads, run_opt.num_inter_threads))

    # start tf
    tf.reset_default_graph()
    with tf.Session(
            config=tf.ConfigProto(intra_op_parallelism_threads=run_opt.num_intra_threads, 
                                  inter_op_parallelism_threads=run_opt.num_inter_threads
            )) as sess:
        # init the model
        model = NNPModel(sess, jdata, run_opt=run_opt)
        # build the model with stats from the first system
        model.build(data, lr)
        # train the model with the provided systems in a cyclic way
        start_time = time.time()
        cur_batch = model.get_global_step()
        print("# start training, start lr is %e, final lr will be %e" % (lr.value(cur_batch), final_lr))
        model.print_head()
        model.train(data, stop_batch)
        print("# finished training")
        end_time = time.time()
        print("# running time: %.3f s" % (end_time - start_time))
Example #3
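Beyond the imports sketched for Example #2, this variant uses sys and the project's SLURM/done-queue helpers (tf_config_from_slurm, create_done_queue, wait_done_queue, connect_done_queue, fill_done_queue); their import path below is an assumption:

import sys
# from deepmd.cluster import tf_config_from_slurm, create_done_queue, \
#     wait_done_queue, connect_done_queue, fill_done_queue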
def _main():
    default_num_inter_threads = 0
    parser = argparse.ArgumentParser(description="*** Train a model. ***")
    parser.add_argument('INPUT', help='the input json database')
    parser.add_argument(
        '-t',
        '--inter-threads',
        type=int,
        default=default_num_inter_threads,
        help='Sets "inter_op_parallelism_threads" for TensorFlow '
        '(default: %d); "intra_op_parallelism_threads" is taken from the '
        'env variable OMP_NUM_THREADS.' % default_num_inter_threads)
    parser.add_argument(
        '--init-model',
        type=str,
        help='Initialize the model from the provided checkpoint.')
    parser.add_argument(
        '--restart',
        type=str,
        help='Restart the training from the provided checkpoint.')
    args = parser.parse_args()

    # load json database
    with open(args.INPUT, 'r') as fp:
        jdata = json.load(fp)

    # Setup cluster for distributed training
    ps_num = j_must_have(jdata, 'ps_num')
    cluster, my_job_name, my_task_index = tf_config_from_slurm(
        ps_number=ps_num)
    cluster_spec = tf.train.ClusterSpec(cluster)
    server = tf.train.Server(server_or_cluster_def=cluster_spec,
                             job_name=my_job_name,
                             task_index=my_task_index)
    if my_job_name == "ps":
        queue = create_done_queue(cluster_spec, my_task_index)
        print("create queue")
        wait_done_queue(cluster_spec, server, queue, my_task_index)
        #server.join()
    elif my_job_name == "worker":
        is_chief = (my_task_index == 0)
        done_ops = connect_done_queue(cluster_spec, my_task_index)

        # init params and run options
        systems = j_must_have(jdata, 'systems')
        set_pfx = j_must_have(jdata, 'set_prefix')
        numb_sys = len(systems)
        seed = jdata.get('seed')
        batch_size = j_must_have(jdata, 'batch_size')
        test_size = j_must_have(jdata, 'numb_test')
        stop_batch = j_must_have(jdata, 'stop_batch')
        rcut = j_must_have(jdata, 'rcut')
        data = DataSystem(systems, set_pfx, batch_size, test_size, rcut)
        tot_numb_batches = sum(data.get_nbatches())
        lr = LearingRate(jdata, tot_numb_batches)
        final_lr = lr.value(stop_batch)
        run_opt = RunOptions(args)
        if is_chief:
            print("#")
            print("# find %d system(s): " % numb_sys)
            print("#")
            print(
                "# run with intra_op_parallelism_threads = %d, inter_op_parallelism_threads = %d "
                % (run_opt.num_intra_threads, run_opt.num_inter_threads))
        run_opt.cluster = cluster_spec
        run_opt.server = server
        run_opt.is_chief = is_chief
        run_opt.my_job_name = my_job_name
        run_opt.my_task_index = my_task_index

        # init the model
        model = NNPModel(jdata, run_opt=run_opt)
        # build the model with stats from the first system
        model.build(data, lr)
        start_time = time.time()
        cur_batch = 0
        if is_chief:
            print("# start training, start lr is %e, final lr will be %e" %
                  (lr.value(cur_batch), final_lr))
            sys.stdout.flush()
            #model.print_head()
        # train the model with the provided systems in a cyclic way
        model.train(data, stop_batch)
        end_time = time.time()
        if is_chief:
            print("# finished training")
            print("# running time: %.3f s" % (end_time - start_time))
        fill_done_queue(cluster_spec, server, done_ops, my_task_index)
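A hedged launch sketch: tf_config_from_slurm derives each task's ps or worker role from the SLURM environment, so every allocated task runs the same command (the script name is hypothetical):

srun python train_dist.py input.json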