コード例 #1
0
ファイル: test.py プロジェクト: zhouwei25/deepmd-kit
def test(args):
    """Run ``dp test`` on every system matched by ``args.system``.

    The model at ``args.model`` is inspected with ``DeepEval`` to choose the
    evaluator class, the per-system test routine and the printer for the
    system-averaged errors. Each system is tested in turn, and the errors
    are combined with ``weighted_average``. The averaged errors are printed
    only when more than one system was tested.

    Note: ``args.system`` is overwritten with each individual system path
    before the per-system routine is called.

    Raises:
        RuntimeError: if the model type is not one of 'ener', 'dipole',
            'polar', 'global_polar' or 'wfc'.
    """
    de = DeepEval(args.model)
    all_sys = expand_sys_str(args.system)
    if len(all_sys) == 0:
        print('Did not find valid system')

    # model type -> (evaluator class, per-system test, printer of averages).
    # `cc` is the system index; only the energy test uses it, so that the
    # detail output of later systems is appended rather than overwritten.
    dispatch = {
        'ener': (
            DeepPot,
            lambda dp, cc: test_ener(dp, args, append_detail=(cc != 0)),
            print_ener_sys_avg,
        ),
        'dipole': (
            DeepDipole,
            lambda dp, cc: test_dipole(dp, args),
            print_dipole_sys_avg,
        ),
        'polar': (
            DeepPolar,
            lambda dp, cc: test_polar(dp, args, global_polar=False),
            print_polar_sys_avg,
        ),
        'global_polar': (
            DeepGlobalPolar,
            lambda dp, cc: test_polar(dp, args, global_polar=True),
            print_polar_sys_avg,
        ),
        'wfc': (
            DeepWFC,
            lambda dp, cc: test_wfc(dp, args),
            print_wfc_sys_avg,
        ),
    }
    if de.model_type not in dispatch:
        raise RuntimeError('unknow model type ' + de.model_type)
    model_cls, run_test, print_sys_avg = dispatch[de.model_type]

    if not all_sys:
        # Nothing to test: skip loading the evaluator and averaging an
        # empty error collection.
        return

    dp = model_cls(args.model)
    err_coll = []
    siz_coll = []
    for cc, ii in enumerate(all_sys):
        args.system = ii
        print("# ---------------output of dp test--------------- ")
        print("# testing system : " + ii)
        err, siz = run_test(dp, cc)
        print("# ----------------------------------------------- ")
        err_coll.append(err)
        siz_coll.append(siz)
    avg_err = weighted_average(err_coll, siz_coll)
    if len(all_sys) > 1:
        print("# ----------weighted average of errors----------- ")
        print("# number of systems : %d" % len(all_sys))
        print_sys_avg(avg_err)
        print("# ----------------------------------------------- ")
コード例 #2
0
def test(args):
    """Dispatch ``args`` to the test routine matching the model's type.

    The model at ``args.model`` is opened with ``DeepEval`` only to read its
    ``model_type``. The selected routine receives ``args`` unchanged.

    Raises:
        RuntimeError: if the model type is not 'ener', 'polar' or 'wfc'.
    """
    model_type = DeepEval(args.model).model_type
    runner = {
        'ener': test_ener,
        'polar': test_polar,
        'wfc': test_wfc,
    }.get(model_type)
    if runner is None:
        raise RuntimeError('unknow model type ' + model_type)
    runner(args)
コード例 #3
0
    def __init__(self,
                 model_file=None,
                 coord_file=None,
                 energy_file=None,
                 force_file=None,
                 grad_file=None,
                 box_file=None,
                 atom_types=None,
                 length_unit="A",
                 energy_unit="Eh",
                 force_unit="Eh/Bohr",
                 is_pbc=False,
                 verbose=True):
        """Load a DeePMD energy model together with reference data arrays.

        Parameters
        ----------
        model_file : str
            Path to the model. Its ``DeepEval`` model type must be ``'ener'``.
        coord_file : str
            File read with ``numpy.load``; must have shape ``(nframe, natom, 3)``.
        energy_file : str
            File read with ``numpy.load``; must have shape ``(nframe,)``.
        force_file, grad_file : str, optional
            Exactly one must be given. The data from ``grad_file`` is negated
            before being stored as forces. Either must have shape
            ``(nframe, natom, 3)``.
        box_file : str, optional
            File read with ``numpy.load``; must have shape ``(nframe, 9)``.
            Required when ``is_pbc`` is true. Otherwise, when omitted, zero
            boxes are used.
        atom_types : str or array_like
            Either a path read with ``numpy.loadtxt`` or a sequence of ints.
            Must contain exactly ``natom`` entries.
        length_unit, energy_unit, force_unit : str
            Unit names passed to the ``get_*_unit_converter`` helpers. The
            loaded data is multiplied by the returned conversion factors.
        is_pbc : bool
            Whether the system is periodic.
        verbose : bool
            Stored as ``self.verbose``.

        Raises
        ------
        ValueError
            If not exactly one of ``force_file``/``grad_file`` is given, or if
            ``is_pbc`` is true but no ``box_file`` is given.
        AssertionError
            If the model type is not ``'ener'`` or the loaded arrays have
            inconsistent shapes.
        """
        # Validate argument combinations before any expensive loading.
        # Previously a missing force source only surfaced later as an
        # AttributeError on self._force_data.
        if (force_file is None) == (grad_file is None):
            raise ValueError(
                "exactly one of force_file or grad_file must be given")
        if is_pbc and box_file is None:
            raise ValueError("box_file is required when is_pbc is True")

        self.is_pbc = is_pbc
        self.verbose = verbose

        if isinstance(atom_types, str):
            self._atm_type = numpy.loadtxt(atom_types, dtype=int)
        else:
            self._atm_type = numpy.asarray(atom_types, dtype=int)

        self.model_file = model_file
        model_type = DeepEval(model_file).model_type
        assert model_type == 'ener'
        self.dp = DeepPot(model_file)

        length_unit, length_unit_converter = get_length_unit_converter(
            length_unit)
        energy_unit, energy_unit_converter = get_energy_unit_converter(
            energy_unit)
        force_unit, force_unit_converter = get_force_unit_converter(force_unit)

        self._coord_data = numpy.load(coord_file) * length_unit_converter
        self._energy_data = numpy.load(energy_file) * energy_unit_converter
        if force_file is not None:
            self._force_data = numpy.load(force_file) * force_unit_converter
        else:
            # Forces are the negative energy gradient.
            self._force_data = -numpy.load(grad_file) * force_unit_converter

        self.nframe = self._coord_data.shape[0]
        self.natom = self._coord_data.shape[1]
        # reshape returns a new array, so the result must be assigned.
        # This flattens e.g. a 0-d (single atom) or (1, natom) input.
        self._atm_type = self._atm_type.reshape(self.natom)

        if box_file is None:
            # Only reachable when is_pbc is False (checked above).
            self._box_data = numpy.zeros([self.nframe, 9])
        else:
            self._box_data = numpy.load(box_file) * length_unit_converter

        assert self._box_data.shape == (self.nframe, 9)
        assert self._coord_data.shape == (self.nframe, self.natom, 3)
        assert self._force_data.shape == (self.nframe, self.natom, 3)
        assert self._energy_data.shape == (self.nframe, )
        assert self._atm_type.shape == (self.natom, )

        self.dump_info(model_file=model_file,
                       atom_types=atom_types,
                       is_pbc=is_pbc,
                       coord_file=coord_file,
                       energy_file=energy_file,
                       force_file=force_file,
                       grad_file=grad_file,
                       box_file=box_file,
                       length_unit=length_unit,
                       energy_unit=energy_unit,
                       force_unit=force_unit)