コード例 #1
0
ファイル: train.py プロジェクト: thunderdruid/deepmd-kit
def _main():
    """Command-line entry point: train a model described by a JSON input file.

    The single positional argument ``INPUT`` is the path of a JSON file.
    The keys ``systems``, ``set_prefix``, ``num_threads``, ``batch_size``
    and ``stop_batch`` are required (fetched through ``j_must_have``);
    ``seed`` is optional and defaults to ``None``.

    The model is built from the first system and then trained on the
    systems one after another, wrapping around the list, until the batch
    counter reaches ``stop_batch``.
    """
    parser = argparse.ArgumentParser(description="*** Train a model. ***")
    parser.add_argument('INPUT', help='the input json database ')
    args = parser.parse_args()

    # load json database; the context manager guarantees the file handle is
    # released even when the JSON is malformed and json.load raises
    with open(args.INPUT, 'r') as fp:
        jdata = json.load(fp)

    # init params
    systems = j_must_have(jdata, 'systems')
    set_pfx = j_must_have(jdata, 'set_prefix')
    numb_sys = len(systems)
    seed = jdata.get('seed')  # optional key, None when absent
    num_threads = j_must_have(jdata, 'num_threads')
    batch_size = j_must_have(jdata, 'batch_size')
    stop_batch = j_must_have(jdata, 'stop_batch')

    # report every system and accumulate the total number of batches, which
    # is handed to the learning-rate object below
    tot_numb_batches = 0
    print("#")
    print("# using %d system(s): " % numb_sys)
    for _sys in systems:
        s_data = DataScan(_sys, set_pfx)
        numb_batches = s_data.get_sys_numb_batch(batch_size)
        tot_numb_batches += numb_batches
        print("# %s has %d batches, and was copied by %s " %
              (_sys, numb_batches, str(s_data.get_ncopies())))
    print("#")
    lr = LearingRate(jdata, tot_numb_batches)
    final_lr = lr.value(stop_batch)

    # start tf
    tf.reset_default_graph()
    with tf.Session(config=tf.ConfigProto(
            intra_op_parallelism_threads=num_threads)) as sess:
        # init the model
        model = NNPModel(jdata, sess)
        # build the model with stats from the first system
        data = DataSets(systems[0], set_pfx, seed=seed, do_norm=False)
        model.build(data, lr)
        # train the model with the provided systems in a cyclic way
        start_time = time.time()
        count = 0
        # NOTE(review): presumably non-zero when resuming from a checkpoint;
        # confirm against NNPModel.get_global_step
        cur_batch = model.get_global_step()
        cur_stop_batch = cur_batch
        print("# start training, start lr is %e, final lr will be %e" %
              (lr.value(cur_stop_batch), final_lr))
        model.print_head()
        while True:
            cur_sys = systems[count % numb_sys]
            data = DataSets(cur_sys, set_pfx, seed=seed, do_norm=False)
            cur_batch = cur_stop_batch
            # advance by one pass over this system, but never beyond the
            # requested total number of batches
            cur_stop_batch = min(
                cur_stop_batch + data.get_sys_numb_batch(batch_size),
                stop_batch)
            print("# train with %s that has %d batches" %
                  (cur_sys, cur_stop_batch - cur_batch))
            model.train(data, cur_stop_batch)
            if cur_stop_batch == stop_batch:
                break
            count += 1
        print("# finished training")
        end_time = time.time()
        print("# running time: %.3f s" % (end_time - start_time))
コード例 #2
0
ファイル: dp_test.py プロジェクト: neojie/mldp
def train_ener(inputs):
    """
    Evaluate a trained model on the *training* data of one system.

    deepmd-kit has function test_ener which deals with test data only;
    `train_ener` is for train data only.

    Parameters
    ----------
    inputs : mapping
        Read with the keys ``'rand_seed'``, ``'system'``, ``'set_prefix'``,
        ``'shuffle_test'``, ``'model'`` and ``'detail_file'``.
        ``'rand_seed'`` and ``'detail_file'`` may be ``None``.

    Returns
    -------
    tuple
        ``(numb_test, fparam0, natoms, l2e, l2ea, l2f, l2v)`` where
        ``fparam0`` is ``fparam[0][0]`` when the model takes frame
        parameters and ``None`` otherwise.

    Side effects
    ------------
    Prints a summary of the errors.  When ``inputs['detail_file']`` is not
    ``None``, writes ``<detail_file>.e.tr.out``, ``.f.tr.out`` and
    ``.v.tr.out`` into the ``inputs['system']`` directory, each holding the
    reference values next to the predicted ones.
    """
    if inputs['rand_seed'] is not None:
        # keep the seed inside the 32-bit range accepted by np.random.seed
        np.random.seed(inputs['rand_seed'] % (2**32))

    data = DataSets(inputs['system'],
                    inputs['set_prefix'],
                    shuffle_test=inputs['shuffle_test'])
    train_data = get_train_data(data)

    natoms = len(train_data["type"][0])
    # Every training frame is evaluated, so the frame count comes straight
    # from the first axis of the box array.
    numb_test = train_data["box"].shape[0]

    dp = DeepPot(inputs['model'])
    coord = train_data["coord"].reshape([numb_test, -1])
    box = train_data["box"]
    atype = train_data["type"][0]
    # frame / atomic parameters are passed only when the model declares them
    fparam = train_data["fparam"] if dp.get_dim_fparam() > 0 else None
    aparam = train_data["aparam"] if dp.get_dim_aparam() > 0 else None

    detail_file = inputs['detail_file']
    # atomic evaluation is requested exactly when detail files are written
    atomic = detail_file is not None

    ret = dp.eval(coord,
                  box,
                  atype,
                  fparam=fparam,
                  aparam=aparam,
                  atomic=atomic)
    energy = ret[0].reshape([numb_test, 1])
    force = ret[1].reshape([numb_test, -1])
    virial = ret[2].reshape([numb_test, 9])

    # NOTE(review): l2err is defined outside this block; presumably an
    # L2-norm style error of the difference array - confirm.
    l2e = (l2err(energy - train_data["energy"].reshape([-1, 1])))
    l2f = (l2err(force - train_data["force"]))
    l2v = (l2err(virial - train_data["virial"]))
    l2ea = l2e / natoms
    l2va = l2v / natoms

    print("# number of train data : %d " % numb_test)
    print("Energy L2err        : %e eV" % l2e)
    print("Energy L2err/Natoms : %e eV" % l2ea)
    print("Force  L2err        : %e eV/A" % l2f)
    print("Virial L2err        : %e eV" % l2v)
    print("Virial L2err/Natoms : %e eV" % l2va)

    if detail_file is not None:
        # each output file has the reference columns first, predictions after
        pe = np.concatenate((np.reshape(train_data["energy"], [-1, 1]),
                             np.reshape(energy, [-1, 1])),
                            axis=1)
        np.savetxt(os.path.join(inputs['system'], detail_file + ".e.tr.out"),
                   pe,
                   header='data_e pred_e')
        pf = np.concatenate((np.reshape(train_data["force"], [-1, 3]),
                             np.reshape(force, [-1, 3])),
                            axis=1)
        np.savetxt(os.path.join(inputs['system'], detail_file + ".f.tr.out"),
                   pf,
                   header='data_fx data_fy data_fz pred_fx pred_fy pred_fz')
        pv = np.concatenate((np.reshape(train_data["virial"], [-1, 9]),
                             np.reshape(virial, [-1, 9])),
                            axis=1)
        np.savetxt(
            os.path.join(inputs['system'], detail_file + ".v.tr.out"),
            pv,
            header=
            'data_vxx data_vxy data_vxz data_vyx data_vyy data_vyz data_vzx data_vzy data_vzz pred_vxx pred_vxy pred_vxz pred_vyx pred_vyy pred_vyz pred_vzx pred_vzy pred_vzz'
        )

    # fparam is None for models without frame parameters; indexing it
    # unconditionally would raise TypeError after all the work is done.
    first_fparam = fparam[0][0] if fparam is not None else None
    return numb_test, first_fparam, natoms, l2e, l2ea, l2f, l2v