Exemplo n.º 1
0
 def __init__(self, rand_pert=0.1, seed=1, box_scale=20, nframes=1):
     """Build a small six-atom test configuration replicated over frames.

     Parameters
     ----------
     rand_pert : float
         scale applied to the ``dp_random.random`` perturbation added to
         the reference coordinates
     seed : int
         value passed to ``dp_random.seed`` before the perturbation and the
         ``efield`` draws
     box_scale : float
         the cell of every frame is ``box_scale`` times the 3x3 identity
     nframes : int
         number of frames (stored on ``self`` before ``_copy_nframes`` is
         called; presumably consumed there — confirm)

     After construction, atoms are sorted by type (stable w.r.t. the
     original index). ``coord`` and ``cell`` are reshaped to
     ``(nframes, -1)``, and ``efield`` has the same shape as ``coord``.
     """
     self.nframes = nframes
     # Two clusters of three atoms each, the second shifted along x.
     self.coord = self._copy_nframes(np.array([
         [0.0, 0.0, 0.1], [1.1, 0.0, 0.1], [0.0, 1.1, 0.1],
         [4.0, 0.0, 0.0], [5.1, 0.0, 0.0], [4.0, 1.1, 0.0],
     ]))
     dp_random.seed(seed)
     self.coord += rand_pert * dp_random.random(self.coord.shape)

     # Frame parameters; the atomic parameters repeat them once per atom.
     frame_param = np.array([[0.1, 0.2]])
     self.fparam = self._copy_nframes(frame_param)
     self.aparam = self._copy_nframes(np.tile(frame_param, [1, 6]))

     self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=int)
     self.cell = self._copy_nframes(box_scale * np.eye(3)).reshape([self.nframes, -1])
     self.natoms = len(self.atype)

     # lexsort uses its LAST key as the primary key: sort by type, and break
     # ties by the original atom index.
     self.idx_map = np.lexsort((np.arange(self.natoms), self.atype))
     by_atom = self.coord.reshape([self.nframes, -1, 3])
     self.coord = by_atom[:, self.idx_map, :].reshape([self.nframes, -1])
     # Drawn after the perturbation so the random stream order is unchanged.
     self.efield = dp_random.random(self.coord.shape)
     self.atype = self.atype[self.idx_map]
     self.datype = self._copy_nframes(self.atype)
Exemplo n.º 2
0
 def test_ener_shift(self):
     """The data system's energy shift must match EnerFitting's output stats."""
     dp_random.seed(0)
     systems = ['system_0', 'system_1']
     data = DeepmdDataSystem(systems, 5, 10, 1.0)
     data.add('energy', 1, must=True)
     expected = data.compute_energy_shift(rcond=1)
     # Sample statistics only after the reference shift has been computed.
     stat = make_stat_input(data, 4, merge_sys=False)
     np.testing.assert_almost_equal(
         expected, EnerFitting._compute_output_stats(stat, rcond=1))
Exemplo n.º 3
0
 def test_ener_shift(self):
     """A fitting net's output stats must equal the data system's energy shift."""
     dp_random.seed(0)
     data = DeepmdDataSystem(['system_0', 'system_1'], 5, 10, 1.0)
     data.add('energy', 1, must=True)
     reference_shift = data.compute_energy_shift(rcond=1)
     stat = make_stat_input(data, 4, merge_sys=False)

     descriptor = DescrptSeA(6.0, 5.8, [46, 92],
                             neuron=[25, 50, 100], axis_neuron=16)
     fit_net = EnerFitting(descriptor, neuron=[240, 240, 240], resnet_dt=True)
     np.testing.assert_almost_equal(
         reference_shift, fit_net._compute_output_stats(stat, rcond=1))
Exemplo n.º 4
0
 def test_ener_shift_assigned(self):
     """An assigned ``atom_ener`` must be kept without changing the total shift.

     Type 0 gets a random assigned atomic energy. The remaining types are left
     to ``_compute_output_stats``. The weighted total must match the
     unconstrained ``compute_energy_shift`` result.
     """
     dp_random.seed(0)
     assigned = dp_random.random()
     data = DeepmdDataSystem(['system_0'], 5, 10, 1.0)
     data.add('energy', 1, must=True)
     stat = make_stat_input(data, 4, merge_sys=False)

     descriptor = DescrptSeA(6.0, 5.8, [46, 92],
                             neuron=[25, 50, 100], axis_neuron=16)
     fit_net = EnerFitting(descriptor,
                           neuron=[240, 240, 240],
                           resnet_dt=True,
                           atom_ener=[assigned, None, None])
     shift = fit_net._compute_output_stats(stat, rcond=1)

     # The assigned entry is returned unchanged.
     np.testing.assert_almost_equal(assigned, shift[0])

     # Weighted totals agree. Entries 2: of the first natoms_vec row are
     # presumably per-type atom counts — confirm against DeepmdDataSystem.
     counts = data.natoms_vec[0][2:]
     np.testing.assert_almost_equal(
         np.dot(data.compute_energy_shift(rcond=1), counts),
         np.dot(shift, counts))
Exemplo n.º 5
0
    def test_merge_all_stat(self):
        """Energy statistics must agree however the merged stats are built.

        Three identical ``DeepmdDataSystem`` instances are built from the same
        seed. Statistics are then gathered three ways:
        ``make_stat_input(..., merge_sys=True)``, ``merge_sys_stat`` applied to
        the unmerged result, and the ``_make_all_stat_ref`` reference. Only the
        ``energy`` entry is compared.
        """
        def make_data():
            # Reseed so every data system is constructed from the same state.
            dp_random.seed(0)
            sys_data = DeepmdDataSystem(['system_0', 'system_1'], 5, 10, 1.0)
            sys_data.add('energy', 1, must=True)
            return sys_data

        data0 = make_data()
        data1 = make_data()
        data2 = make_data()

        # Reseed before each sampling pass so all three draw identical batches.
        dp_random.seed(0)
        all_stat_0 = make_stat_input(data0, 10, merge_sys=False)
        dp_random.seed(0)
        all_stat_1 = make_stat_input(data1, 10, merge_sys=True)
        all_stat_2 = merge_sys_stat(all_stat_0)
        dp_random.seed(0)
        all_stat_3 = _make_all_stat_ref(data2, 10)

        # Only check that the energy is concatenated correctly.
        key = 'energy'
        d1 = np.array(all_stat_1[key])
        d2 = np.array(all_stat_2[key])
        d3 = np.array(all_stat_3[key])
        self._comp_data(d1, d2)
        self._comp_data(d1, d3)
Exemplo n.º 6
0
def test(
    *,
    model: str,
    system: str,
    set_prefix: str,
    numb_test: int,
    rand_seed: Optional[int],
    shuffle_test: bool,
    detail_file: Optional[str],
    atomic: bool,
    **kwargs,
):
    """Test model predictions.

    Parameters
    ----------
    model : str
        path where model is stored
    system : str
        system directory (expanded with ``expand_sys_str``)
    set_prefix : str
        string prefix of set
    numb_test : int
        number of tests to do
    rand_seed : Optional[int]
        seed for random generator; reduced modulo 2**32 before seeding
    shuffle_test : bool
        whether to shuffle tests
    detail_file : Optional[str]
        file where test details will be output
    atomic : bool
        whether per atom quantities should be computed
    **kwargs
        accepted and ignored

    Raises
    ------
    RuntimeError
        if no valid system was found, or if the model type has no
        per-system test routine here
    """
    all_sys = expand_sys_str(system)
    if len(all_sys) == 0:
        raise RuntimeError("Did not find valid system")
    err_coll = []

    # init random seed
    if rand_seed is not None:
        dp_random.seed(rand_seed % (2 ** 32))

    # init model
    dp = DeepPotential(model)
    model_type = dp.model_type

    # Fail fast on model types without a per-system test below. Previously
    # these left ``err`` unbound and crashed with UnboundLocalError.
    # ("wfc" has a summary printer but no per-system routine in this function.)
    if model_type not in ("ener", "dipole", "polar", "global_polar"):
        raise RuntimeError(f"Testing is not supported for model type: {model_type}")

    # Only energy models remap atom types through the model's type map.
    tmap = dp.get_type_map() if model_type == "ener" else None

    for cc, sys_path in enumerate(all_sys):
        log.info("# ---------------output of dp test--------------- ")
        log.info(f"# testing system : {sys_path}")

        # create data class
        data = DeepmdData(sys_path, set_prefix, shuffle_test=shuffle_test, type_map=tmap)

        if model_type == "ener":
            err = test_ener(
                dp,
                data,
                sys_path,
                numb_test,
                detail_file,
                atomic,
                append_detail=(cc != 0),  # first system creates the detail file
            )
        elif model_type == "dipole":
            err = test_dipole(dp, data, numb_test, detail_file, atomic)
        elif model_type == "polar":
            err = test_polar(dp, data, numb_test, detail_file, atomic=atomic)
        else:  # "global_polar": should not appear in this new version
            log.warning("Global polar model is not currently supported. Please directly use the polar mode and change loss parameters.")
            err = test_polar(dp, data, numb_test, detail_file, atomic=False)    # YWolfeee: downward compatibility
        log.info("# ----------------------------------------------- ")
        err_coll.append(err)

    avg_err = weighted_average(err_coll)

    if len(all_sys) != len(err_coll):
        log.warning("Not all systems are tested! Check if the systems are valid")

    if len(all_sys) > 1:
        log.info("# ----------weighted average of errors----------- ")
        log.info(f"# number of systems : {len(all_sys)}")
        if model_type == "ener":
            print_ener_sys_avg(avg_err)
        elif model_type == "dipole":
            print_dipole_sys_avg(avg_err)
        else:  # "polar" or "global_polar"
            print_polar_sys_avg(avg_err)
        log.info("# ----------------------------------------------- ")
Exemplo n.º 7
0
def _do_work(jdata: Dict[str, Any], run_opt: RunOptions, is_compress: bool = False):
    """Run serial model training, or save a compressed model when compressing.

    Parameters
    ----------
    jdata : Dict[str, Any]
        arguments read from json/yaml control file
    run_opt : RunOptions
        object with run configuration
    is_compress : bool
        indicates whether in model compress mode; if True no data is loaded
        and ``model.save_compressed()`` is called instead of training

    Raises
    ------
    RuntimeError
        If ``jdata`` has no ``"training"`` section, or if an unsupported
        modifier type is selected for model
    """
    # Explicit check rather than ``assert``, which is stripped under ``python -O``.
    if "training" not in jdata:
        raise RuntimeError("the control file must contain a 'training' section")

    # init the model
    model = DPTrainer(jdata, run_opt=run_opt, is_compress=is_compress)
    rcut = model.model.get_rcut()
    type_map = model.model.get_type_map()
    # An empty type map is passed to the data loaders as None.
    ipt_type_map = type_map if len(type_map) > 0 else None

    # Read after DPTrainer is built, exactly where the original code read it.
    training = jdata["training"]

    # init random seed of data systems
    seed = training.get("seed", None)
    if seed is not None:
        # Offset by rank to avoid the same batch sequence among workers, then
        # wrap into 32 bits (presumably the range dp_random.seed accepts).
        seed = (seed + run_opt.my_rank) % (2 ** 32)
    dp_random.seed(seed)

    # setup data modifier
    modifier = get_modifier(jdata["model"].get("modifier", None))

    # decouple the training data from the model compress process
    train_data = None
    valid_data = None
    if not is_compress:
        train_data = get_data(training["training_data"], rcut, ipt_type_map, modifier)
        train_data.print_summary("training")
        if training.get("validation_data", None) is not None:
            valid_data = get_data(training["validation_data"], rcut, ipt_type_map, modifier)
            valid_data.print_summary("validation")

    # get training info
    stop_batch = j_must_have(training, "numb_steps")
    model.build(train_data, stop_batch)

    if is_compress:
        model.save_compressed()
        log.info("finished compressing")
        return

    # train the model with the provided systems in a cyclic way;
    # perf_counter is monotonic, so the elapsed time is immune to clock changes
    start_time = time.perf_counter()
    model.train(train_data, valid_data)
    elapsed = time.perf_counter() - start_time
    log.info("finished training")
    log.info(f"wall time: {elapsed:.3f} s")